Shortcuts

Source code for torcheeg.trainers.domain_adaption.ada

from typing import List, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from .mmd_like import _MMDLikeTrainer


class WalkerLoss(nn.Module):
    '''
    Walker loss of associative domain adaptation.

    Penalises round-trip probabilities that end on a sample whose label differs
    from the label of the starting sample. The target distribution of every row
    is uniform over the samples sharing the row's label.
    '''
    def forward(self, P_aba, y):
        '''
        Args:
            P_aba (torch.Tensor): Round-trip probabilities. It is compared
                element-wise with an N x N target, where N is the number of
                labels in :obj:`y`.
            y (torch.Tensor): Labels of the N samples. Any shape holding N
                elements is accepted; it is flattened before use.

        Returns:
            torch.Tensor: Scalar KL divergence between the target distribution
            and :obj:`P_aba`, summed over all entries and divided by N.
        '''
        # Flatten so that a column-shaped label tensor (N, 1) still broadcasts
        # to an N x N equality matrix instead of an (N, 1) one.
        labels = y.reshape(-1)
        same_label = torch.eq(labels.reshape(-1, 1), labels).to(P_aba.dtype)

        # Row-normalise: uniform over the samples that share the row's label.
        # Built from labels only, so it never carries gradients.
        p_target = same_label / same_label.sum(dim=1, keepdim=True)

        # `reduction='sum'` is the supported spelling of the deprecated
        # `size_average=False`; the small constant keeps log() finite at 0.
        L_walker = F.kl_div(torch.log(1e-8 + P_aba),
                            p_target,
                            reduction='sum')

        return L_walker / p_target.shape[0]


class VisitLoss(nn.Module):
    '''
    Visit loss of associative domain adaptation.

    Pushes the visit distribution towards the uniform distribution, so that
    every column of :obj:`P_b` receives the same probability mass.
    '''
    def forward(self, P_b):
        '''
        Args:
            P_b (torch.Tensor): Visit probabilities. The size of dimension 1 is
                used as the number of entries of the uniform target.

        Returns:
            torch.Tensor: Scalar KL divergence between the uniform target and
            :obj:`P_b`, summed over all entries.
        '''
        num_visited = P_b.shape[1]

        # Create the uniform target directly with the dtype and on the device
        # of the input, instead of building a float32 CPU tensor and moving it.
        p_visit = torch.ones(
            (1, num_visited), dtype=P_b.dtype, device=P_b.device) / num_visited

        # `reduction='sum'` is the supported spelling of the deprecated
        # `size_average=False`; the small constant keeps log() finite at 0.
        L_visit = F.kl_div(torch.log(1e-8 + P_b), p_visit, reduction='sum')

        # The target always has a single row, so this division is by 1; it is
        # kept to mirror the normalisation used by the walker loss.
        return L_visit / p_visit.shape[0]


class AssociationMatrix(nn.Module):
    '''
    Computes the association probabilities between two batches of features.

    Each sample is flattened to a vector, pairwise dot products are taken
    between the two batches and turned into transition probabilities with a
    softmax.
    '''
    def __init__(self):
        super(AssociationMatrix, self).__init__()

    def forward(self, X_source, X_target):
        '''
        Args:
            X_source (torch.Tensor): Source features; everything after the
                first dimension is flattened.
            X_target (torch.Tensor): Target features; everything after the
                first dimension is flattened.

        Returns:
            tuple: :obj:`(P_aba, P_b)`, where :obj:`P_aba` is the product of
            the source-to-target and target-to-source transition matrices and
            :obj:`P_b` is the source-to-target matrix averaged over its rows
            (the row dimension is kept).
        '''
        source_flat = X_source.reshape(X_source.size(0), -1)
        target_flat = X_target.reshape(X_target.size(0), -1)

        # Pairwise dot-product similarity, one row per source sample.
        similarity = torch.mm(source_flat, target_flat.t())

        # Transition probabilities in both directions.
        source_to_target = F.softmax(similarity, dim=1)
        target_to_source = F.softmax(similarity.t(), dim=1)

        round_trip = torch.mm(source_to_target, target_to_source)
        visit = source_to_target.mean(dim=0, keepdim=True)

        return round_trip, visit


class AssociativeLoss(nn.Module):
    '''
    Weighted sum of the visit loss and the walker loss.

    Args:
        walker_weight (float): The factor applied to the walker loss. (default: :obj:`1.0`)
        visit_weight (float): The factor applied to the visit loss. (default: :obj:`1.0`)
    '''
    def __init__(self, walker_weight=1., visit_weight=1.):
        super(AssociativeLoss, self).__init__()

        self.walker_weight = walker_weight
        self.visit_weight = visit_weight

        # Sub-modules: association probabilities and the two loss terms.
        self.matrix = AssociationMatrix()
        self.walker = WalkerLoss()
        self.visit = VisitLoss()

    def forward(self, X_source, X_target, y):
        '''
        Args:
            X_source (torch.Tensor): Source features.
            X_target (torch.Tensor): Target features.
            y (torch.Tensor): Labels passed to the walker loss.

        Returns:
            torch.Tensor: :obj:`visit_weight * L_visit + walker_weight * L_walker`.
        '''
        round_trip, visit_probs = self.matrix(X_source, X_target)

        weighted_visit = self.visit_weight * self.visit(visit_probs)
        weighted_walker = self.walker_weight * self.walker(round_trip, y)

        return weighted_visit + weighted_walker


class ADATrainer(_MMDLikeTrainer):
    r'''
    This class supports the implementation of Associative Domain Adaptation (ADA) for deep domain adaptation.

    NOTE: ADA belongs to unsupervised domain adaptation methods, which only use labeled source data and unlabeled target data. This means that the target dataset does not have to contain labels.

    - Paper: Haeusser P, Frerix T, Mordvintsev A, et al. Associative domain adaptation[C]//Proceedings of the IEEE international conference on computer vision. 2017: 2765-2773.
    - URL: https://arxiv.org/abs/1708.00938
    - Related Project: https://github.com/stes/torch-assoc

    .. code-block:: python

        trainer = ADATrainer(extractor,
                             classifier,
                             num_classes=10,
                             devices=1,
                             weight_visit=0.6,
                             accelerator='gpu')
        trainer.fit(source_loader, target_loader, val_loader)
        trainer.test(test_loader)

    Args:
        extractor (nn.Module): The feature extraction model. Its output for the source and target batches is fed to the associative (walker + visit) loss.
        classifier (nn.Module): The classification model learns the classification task with the source labeled data based on the feature of the feature extraction model. The dimension of its output should be equal to the number of categories in the dataset. The output layer does not need to have a softmax activation function.
        num_classes (int): The number of categories in the dataset.
        lr (float): The learning rate. (default: :obj:`0.0001`)
        weight_walker (float): The weight of the walker loss. (default: :obj:`1.0`)
        weight_visit (float): The weight of the visit loss. (default: :obj:`1.0`)
        weight_domain (float): The weight of the associative loss. (default: :obj:`1.0`)
        weight_decay (float): The weight decay. (default: :obj:`0.0`)
        weight_scheduler (bool): Whether to use a scheduler for the weight of the associative loss, growing from 0 to 1 following the schedule from the DANN paper. (default: :obj:`True`)
        lr_scheduler_gamma (float): Learning-rate scheduler setting, forwarded unchanged to the parent trainer; see the parent class for its exact meaning. (default: :obj:`0.0`)
        lr_scheduler_decay (float): Learning-rate scheduler setting, forwarded unchanged to the parent trainer; see the parent class for its exact meaning. (default: :obj:`0.75`)
        warmup_epochs (int): The number of epochs for the warmup phase, during which only the task loss is optimized. (default: :obj:`0`)
        devices (int): The number of devices to use. (default: :obj:`1`)
        accelerator (str): The accelerator to use. Available options are: 'cpu', 'gpu'. (default: :obj:`"cpu"`)
        metrics (list of str): The metrics to use. Available options are: 'precision', 'recall', 'f1score', 'accuracy'. (default: :obj:`["accuracy"]`)

    .. automethod:: fit
    .. automethod:: test
    '''
    def __init__(self,
                 extractor: nn.Module,
                 classifier: nn.Module,
                 num_classes: int,
                 lr: float = 1e-4,
                 weight_walker: float = 1.0,
                 weight_visit: float = 1.0,
                 weight_domain: float = 1.0,
                 weight_decay: float = 0.0,
                 weight_scheduler: bool = True,
                 lr_scheduler_gamma: float = 0.0,
                 lr_scheduler_decay: float = 0.75,
                 warmup_epochs: int = 0,
                 devices: int = 1,
                 accelerator: str = "cpu",
                 metrics: List[str] = ["accuracy"]):
        # NOTE(review): the mutable default of `metrics` is kept for interface
        # compatibility; in this class it is only forwarded, never mutated.
        super(ADATrainer,
              self).__init__(extractor=extractor,
                             classifier=classifier,
                             num_classes=num_classes,
                             lr=lr,
                             weight_decay=weight_decay,
                             weight_domain=weight_domain,
                             weight_scheduler=weight_scheduler,
                             lr_scheduler_gamma=lr_scheduler_gamma,
                             lr_scheduler_decay=lr_scheduler_decay,
                             warmup_epochs=warmup_epochs,
                             devices=devices,
                             accelerator=accelerator,
                             metrics=metrics)
        self.weight_walker = weight_walker
        self.weight_visit = weight_visit

        self._assoc_fn = AssociativeLoss(weight_walker, weight_visit)

    def training_step(self, batch: Tuple[torch.Tensor],
                      batch_idx: int) -> torch.Tensor:
        '''
        Run one optimisation step on a paired source/target batch.

        Args:
            batch: A pair :obj:`((x_source, y_source), (x_target, _))`; the
                second element of the target pair is ignored.
            batch_idx (int): Index of the batch (unused).

        Returns:
            torch.Tensor: The task loss, plus the scheduled, weighted domain
            loss once the warmup phase is over.
        '''
        (x_source, y_source), (x_target, _) = batch

        x_source_feat = self.extractor(x_source)
        y_source_pred = self.classifier(x_source_feat)
        x_target_feat = self.extractor(x_target)

        domain_loss = self._domain_loss_fn(x_source_feat, x_target_feat,
                                           y_source)
        task_loss = self.ce_fn(y_source_pred, y_source)

        # During warmup only the supervised task loss is optimised.
        if self.current_epoch >= self.warmup_epochs:
            loss = task_loss + self.scheduled_weight_domain * domain_loss
        else:
            loss = task_loss

        # Log each component under its own name; logging the combined `loss`
        # for both entries made the two progress-bar values indistinguishable.
        self.log("train_domain_loss",
                 self.train_domain_loss(domain_loss),
                 prog_bar=True,
                 on_epoch=False,
                 logger=False,
                 on_step=True)
        self.log("train_task_loss",
                 self.train_task_loss(task_loss),
                 prog_bar=True,
                 on_epoch=False,
                 logger=False,
                 on_step=True)
        for i, metric_value in enumerate(self.train_metrics.values()):
            self.log(f"train_{self.metrics[i]}",
                     metric_value(y_source_pred, y_source),
                     prog_bar=True,
                     on_epoch=False,
                     logger=False,
                     on_step=True)

        return loss

    def _domain_loss_fn(self, x_source_feat: torch.Tensor,
                        x_target_feat: torch.Tensor,
                        y_source: torch.Tensor) -> torch.Tensor:
        '''Return the associative loss of the source and target features.'''
        return self._assoc_fn(x_source_feat, x_target_feat, y_source)

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources