Shortcuts

Source code for torcheeg.trainers.domain_adaption.dan

from typing import List, Optional

import torch
import torch.nn as nn

from .mmd_like import _MMDLikeTrainer


def guassian_kernel(x_source,
                    x_target,
                    mul_kernel=2.0,
                    num_kernels=5,
                    sigma=None):
    r'''
    Sum several Gaussian (RBF) kernel matrices computed over the concatenation
    of source and target samples.

    Pairwise squared Euclidean distances ``D`` are computed once between all
    rows of ``cat([x_source, x_target])``. The result is
    ``sum_i exp(-D / b_i)`` with bandwidths
    ``b_i = b * mul_kernel ** i / mul_kernel ** (num_kernels // 2)`` for
    ``i`` in ``range(num_kernels)``.

    Args:
        x_source (torch.Tensor): Source samples; dim 0 indexes samples. Any
            trailing dimensions beyond the second are flattened into a single
            feature vector per sample.
        x_target (torch.Tensor): Target samples, with the same per-sample
            feature size as ``x_source``.
        mul_kernel (float): Ratio between consecutive bandwidths.
            (default: ``2.0``)
        num_kernels (int): Number of Gaussian kernels that are summed.
            (default: ``5``)
        sigma (float or torch.Tensor, optional): Base bandwidth ``b``. When
            falsy (``None`` or ``0``), ``b`` is the mean off-diagonal squared
            distance, computed without gradient. A tensor passed here is NOT
            modified. (default: ``None``)

    Returns:
        torch.Tensor: ``(n_source + n_target, n_source + n_target)`` kernel
        matrix. With a single sample in total and no ``sigma``, the automatic
        bandwidth divides by zero.
    '''
    total = torch.cat([x_source, x_target], dim=0)
    # Allow (N, C, T, ...) features by flattening them per sample; 2-D input is untouched.
    if total.dim() > 2:
        total = total.flatten(start_dim=1)
    n_samples = int(total.shape[0])

    # (1, N, D) - (N, 1, D) broadcasts to all pairwise differences.
    l2_distance = ((total.unsqueeze(0) - total.unsqueeze(1))**2).sum(dim=2)

    if sigma:
        bandwidth = sigma
    else:
        # Diagonal entries are zero, so this is the mean over the N*(N-1)
        # off-diagonal pairs; detached so the bandwidth carries no gradient.
        bandwidth = torch.sum(l2_distance.detach()) / (n_samples**2 - n_samples)
    # Out-of-place on purpose: `/=` would mutate a tensor passed in as `sigma`.
    bandwidth = bandwidth / (mul_kernel**(num_kernels // 2))
    bandwidth_list = [bandwidth * (mul_kernel**i) for i in range(num_kernels)]
    return sum(torch.exp(-l2_distance / b) for b in bandwidth_list)


def maximum_mean_discrepancy(x_source,
                             x_target,
                             mul_kernel=2.0,
                             num_kernels=5,
                             sigma=None):
    r'''
    Biased multi-kernel maximum mean discrepancy (MMD) between two batches.

    Builds the summed Gaussian kernel matrix over ``cat([x_source,
    x_target])`` with :func:`guassian_kernel`. It then returns
    ``mean(K_ss) + mean(K_tt) - mean(K_st) - mean(K_ts)``, where ``K_ss``,
    ``K_tt`` and ``K_st``/``K_ts`` are the source-source, target-target and
    cross blocks.

    Args:
        x_source (torch.Tensor): Source features; dim 0 indexes samples.
        x_target (torch.Tensor): Target features; dim 0 indexes samples. The
            number of target samples may differ from the number of source
            samples.
        mul_kernel (float): Ratio between consecutive kernel bandwidths.
            (default: ``2.0``)
        num_kernels (int): Number of Gaussian kernels summed.
            (default: ``5``)
        sigma (float or torch.Tensor, optional): Base bandwidth; automatic
            when falsy. (default: ``None``)

    Returns:
        torch.Tensor: Scalar MMD estimate.
    '''
    n_source = int(x_source.shape[0])
    kernels = guassian_kernel(x_source,
                              x_target,
                              mul_kernel=mul_kernel,
                              num_kernels=num_kernels,
                              sigma=sigma)
    k_ss = kernels[:n_source, :n_source]
    k_tt = kernels[n_source:, n_source:]
    k_st = kernels[:n_source, n_source:]
    k_ts = kernels[n_source:, :n_source]
    # Average each block on its own. For equal batch sizes this equals
    # mean(K_ss + K_tt - K_st - K_ts). For unequal sizes the blocks have
    # different shapes, so adding them elementwise would raise or silently
    # broadcast into a wrong value.
    return k_ss.mean() + k_tt.mean() - k_st.mean() - k_ts.mean()


class DANTrainer(_MMDLikeTrainer):
    r'''
    This class supports the implementation of Deep Adaptation Network (DAN) for deep domain adaptation.

    NOTE: DAN belongs to unsupervised domain adaptation methods, which only use labeled source and unlabeled target data. This means that the target dataset does not have to return labels.

    - Paper: Long M, Cao Y, Wang J, et al. Learning transferable features with deep adaptation networks[C]//International conference on machine learning. PMLR, 2015: 97-105.
    - URL: https://proceedings.mlr.press/v37/long15
    - Related Project: https://github.com/jindongwang/transferlearning/blob/cfaf1174dff7390a861cc4abd5ede37dfa1063f5/code/deep/DAN/DAN.py

    .. code-block:: python

        trainer = DANTrainer(extractor,
                             classifier,
                             num_classes=10,
                             devices=1,
                             weight_domain=1.0,
                             accelerator='gpu')
        trainer.fit(source_loader, target_loader, val_loader)
        trainer.test(test_loader)

    Args:
        extractor (nn.Module): The feature extraction model learns the feature representation of the EEG signal. Its source and target features are pulled together by the domain loss, a multi-kernel (sum of Gaussian kernels) maximum mean discrepancy (MMD).
        classifier (nn.Module): The classification model learns the classification task with the source labeled data based on the feature of the feature extraction model. The dimension of its output should be equal to the number of categories in the dataset. The output layer does not need to have a softmax activation function.
        num_classes (int): The number of categories in the dataset.
        lr (float): The learning rate. (default: :obj:`0.0001`)
        weight_decay (float): The weight decay. (default: :obj:`0.0`)
        weight_domain (float): The weight of the DAN loss. (default: :obj:`1.0`)
        weight_scheduler (bool): Whether to use a scheduler for the weight of the DAN loss, growing from 0 to 1 following the schedule from the DANN paper. (default: :obj:`True`)
        lr_scheduler_gamma (float): Coefficient of the learning-rate schedule; forwarded unchanged to the parent trainer, which implements the schedule. (default: :obj:`0.0`)
        lr_scheduler_decay (float): Decay exponent of the learning-rate schedule; forwarded unchanged to the parent trainer. (default: :obj:`0.75`)
        warmup_epochs (int): The number of epochs for the warmup phase, during which the weight of the DAN loss is 0. (default: :obj:`0`)
        devices (int): The number of devices to use. (default: :obj:`1`)
        accelerator (str): The accelerator to use. Available options are: 'cpu', 'gpu'. (default: :obj:`"cpu"`)
        metrics (list of str, optional): The metrics to use. Available options are: 'precision', 'recall', 'f1score', 'accuracy'. :obj:`None` means :obj:`["accuracy"]`. (default: :obj:`None`)

    .. automethod:: fit
    .. automethod:: test
    '''

    def __init__(self,
                 extractor: nn.Module,
                 classifier: nn.Module,
                 num_classes: int,
                 lr: float = 1e-4,
                 weight_decay: float = 0.0,
                 weight_domain: float = 1.0,
                 weight_scheduler: bool = True,
                 lr_scheduler_gamma: float = 0.0,
                 lr_scheduler_decay: float = 0.75,
                 warmup_epochs: int = 0,
                 devices: int = 1,
                 accelerator: str = "cpu",
                 metrics: Optional[List[str]] = None):
        # Build a fresh default list per instance. A list literal as the
        # default would be a single object shared by every trainer, so any
        # in-place change made downstream would leak into later instances.
        if metrics is None:
            metrics = ["accuracy"]
        # NOTE(review): the exact form of the lr schedule (presumably the
        # DANN-style lr / (1 + gamma * p) ** decay) lives in _MMDLikeTrainer;
        # confirm there before documenting it further.
        super(DANTrainer, self).__init__(extractor=extractor,
                                         classifier=classifier,
                                         num_classes=num_classes,
                                         lr=lr,
                                         weight_decay=weight_decay,
                                         weight_domain=weight_domain,
                                         weight_scheduler=weight_scheduler,
                                         lr_scheduler_gamma=lr_scheduler_gamma,
                                         lr_scheduler_decay=lr_scheduler_decay,
                                         warmup_epochs=warmup_epochs,
                                         devices=devices,
                                         accelerator=accelerator,
                                         metrics=metrics)

    def _domain_loss_fn(self, x_source_feat: torch.Tensor,
                        x_target_feat: torch.Tensor) -> torch.Tensor:
        r'''
        Compute the DAN domain loss: the MMD between source and target features.

        If the two batches differ in length, both are truncated to the
        shorter length first.

        Args:
            x_source_feat (torch.Tensor): Extracted source features; dim 0 indexes samples.
            x_target_feat (torch.Tensor): Extracted target features; dim 0 indexes samples.

        Returns:
            torch.Tensor: Scalar MMD loss.
        '''
        batch_size = min(len(x_source_feat), len(x_target_feat))
        if len(x_source_feat) != len(x_target_feat):
            # Use the smaller batch size so both domains contribute equally
            # sized batches to the MMD estimate.
            x_source_feat = x_source_feat[:batch_size]
            x_target_feat = x_target_feat[:batch_size]
        return maximum_mean_discrepancy(x_source_feat, x_target_feat)

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources