Introduction to the trainers Module

Welcome to the guide on TorchEEG's trainers module! This module provides a suite of trainers, built on PyTorch Lightning, for model training. The trainers handle a wide array of models, from discriminative to generative, and include support for contrastive learning and fine-tuning. They run on a variety of hardware configurations, from a single CPU or GPU to multiple GPUs, and you can tailor them to your specific needs by extending the existing classes.

Discriminative Models with ClassifierTrainer

The simplest, yet highly effective, approach for discriminative models is to train with a classification loss function such as cross-entropy. For EEG signal recognition, this recipe is encapsulated in ClassifierTrainer.

from torcheeg.datasets import DEAPDataset
from torcheeg import transforms

from torcheeg.model_selection import KFoldGroupbyTrial
from torcheeg.datasets.constants import \
    DEAP_CHANNEL_LOCATION_DICT
from torch.utils.data import DataLoader
from torcheeg.models import CCNN

from torcheeg.trainers import ClassifierTrainer

import pytorch_lightning as pl

dataset = DEAPDataset(
    io_path='./examples_trainers_1/deap',
    root_path='./data_preprocessed_python',
    offline_transform=transforms.Compose([
        transforms.BandDifferentialEntropy(apply_to_baseline=True),
        transforms.ToGrid(DEAP_CHANNEL_LOCATION_DICT, apply_to_baseline=True)
    ]),
    online_transform=transforms.Compose(
        [transforms.BaselineRemoval(),
         transforms.ToTensor()]),
    label_transform=transforms.Compose([
        transforms.Select('valence'),
        transforms.Binary(5.0),
    ]),
    num_worker=8)

k_fold = KFoldGroupbyTrial(n_splits=10,
                           split_path='./examples_trainers_1/split',
                           shuffle=True,
                           random_state=42)

for i, (train_dataset, val_dataset) in enumerate(k_fold.split(dataset)):
    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)

    model = CCNN(num_classes=2, in_channels=4, grid_size=(9, 9))

    trainer = ClassifierTrainer(model=model,
                                num_classes=2,
                                lr=1e-4,
                                weight_decay=1e-4,
                                devices=1,
                                accelerator="gpu")
    trainer.fit(train_loader,
                val_loader,
                max_epochs=50,
                default_root_dir=f'./examples_trainers_1/model/{i}',
                callbacks=[pl.callbacks.ModelCheckpoint(save_last=True)],
                enable_progress_bar=True,
                enable_model_summary=True,
                limit_val_batches=0.0)  # skip the validation loop during training
    score = trainer.test(val_loader,
                         enable_progress_bar=True,
                         enable_model_summary=True)[0]
    print(f'Fold {i} test accuracy: {score["test_accuracy"]:.4f}')

Within ClassifierTrainer, you can specify the accelerator and the number of devices to use. In trainer.fit, you can pass any parameter supported by the Trainer class in pytorch_lightning. If you wish to modify the training procedure, you can extend the existing trainers:

class MyClassifierTrainer(ClassifierTrainer):
    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        # your loss function
        loss = self.ce_fn(y_hat, y)

        # log to prog_bar
        self.log("train_loss",
                 self.train_loss(loss),
                 prog_bar=True,
                 on_epoch=False,
                 logger=False,
                 on_step=True)

        for i, metric_value in enumerate(self.train_metrics.values()):
            self.log(f"train_{self.metrics[i]}",
                     metric_value(y_hat, y),
                     prog_bar=True,
                     on_epoch=False,
                     logger=False,
                     on_step=True)

        return loss
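
Because MyClassifierTrainer only overrides training_step, it remains a drop-in replacement for ClassifierTrainer and accepts the same constructor arguments:

trainer = MyClassifierTrainer(model=model,
                              num_classes=2,
                              lr=1e-4,
                              weight_decay=1e-4,
                              devices=1,
                              accelerator="gpu")
trainer.fit(train_loader, val_loader, max_epochs=50)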

By referring to the related documents, you can also learn about the evaluation metrics supported in the trainers module. With a few adjustments, the trainer can report metrics such as accuracy, recall, precision, and F1 score.

trainer = ClassifierTrainer(model=model,
                            num_classes=2,
                            lr=1e-4,
                            weight_decay=1e-4,
                            metrics=['accuracy', 'recall', 'precision', 'f1score'],
                            accelerator="gpu")
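
After testing, each requested metric should appear in the returned score dictionaries. A minimal sketch, assuming the keys follow the test_<metric> pattern seen with test_accuracy above:

score = trainer.test(val_loader)[0]
# key names assumed to follow the "test_<metric>" pattern shown earlier
print(f'accuracy: {score["test_accuracy"]:.4f}, '
      f'f1 score: {score["test_f1score"]:.4f}')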

Domain Adaptation Methods

A challenging aspect of EEG-based emotion recognition is the cross-subject problem. Even when evoked by the same stimulus, EEG signal patterns can shift in distribution from one individual to another, which degrades the performance of trained models on unseen subjects. The trainers module addresses this issue with a collection of domain adaptation algorithms. These algorithms use EEG samples from known subjects (the source domain) and from the unknown subjects used for testing (the target domain), optimizing models with specific loss functions and training strategies to extract domain-invariant features or to transfer knowledge from the source to the target domain. Cross-domain trainers such as CORALTrainer, ADATrainer, DANNTrainer, DDCTrainer, and DANTrainer cover different application scenarios and help users tackle cross-domain problems across diverse models and datasets.

from torcheeg.datasets import DEAPDataset
from torcheeg import transforms

from torcheeg.model_selection import LeaveOneSubjectOut
from torcheeg.datasets.constants import \
    DEAP_CHANNEL_LOCATION_DICT
from torch.utils.data import DataLoader
from torcheeg.models import CCNN

from torcheeg.trainers import CORALTrainer

import pytorch_lightning as pl

dataset = DEAPDataset(
    io_path='./examples_trainers_2/deap',
    root_path='./data_preprocessed_python',
    offline_transform=transforms.Compose([
        transforms.BandDifferentialEntropy(apply_to_baseline=True),
        transforms.ToGrid(DEAP_CHANNEL_LOCATION_DICT, apply_to_baseline=True)
    ]),
    online_transform=transforms.Compose(
        [transforms.BaselineRemoval(),
         transforms.ToTensor()]),
    label_transform=transforms.Compose([
        transforms.Select('valence'),
        transforms.Binary(5.0),
    ]),
    num_worker=8)

k_fold = LeaveOneSubjectOut(split_path='./examples_trainers_2/split')


# Split CCNN into a feature extractor (the convolutional blocks) and a
# classification head (the linear layers), since the cross-domain trainers
# expect these two modules separately.
class Extractor(CCNN):
    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = x.flatten(start_dim=1)
        return x


class Classifier(CCNN):
    def forward(self, x):
        x = self.lin1(x)
        x = self.lin2(x)
        return x


for i, (train_dataset, val_dataset) in enumerate(k_fold.split(dataset)):
    source_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
    target_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)

    extractor = Extractor(in_channels=4, num_classes=2)
    classifier = Classifier(in_channels=4, num_classes=2)

    trainer = CORALTrainer(extractor=extractor,
                           classifier=classifier,
                           num_classes=2,
                           lr=1e-4,
                           weight_decay=0.0,
                           accelerator='gpu')
    # fit takes the source-domain loader, the target-domain loader,
    # and a validation loader (here the target loader doubles as validation)
    trainer.fit(source_loader,
                target_loader,
                target_loader,
                max_epochs=50,
                default_root_dir=f'./examples_trainers_2/model/{i}',
                callbacks=[pl.callbacks.ModelCheckpoint(save_last=True)],
                enable_progress_bar=True,
                enable_model_summary=True,
                limit_val_batches=0.0)  # skip the validation loop during training
    score = trainer.test(target_loader,
                         enable_progress_bar=True,
                         enable_model_summary=True)[0]
    print(f'Fold {i} test accuracy: {score["test_accuracy"]:.4f}')
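
The other cross-domain trainers can be dropped into the same loop. A minimal sketch, assuming DANTrainer accepts the same extractor/classifier arguments as CORALTrainer (verify against the trainer documentation):

from torcheeg.trainers import DANTrainer

trainer = DANTrainer(extractor=extractor,
                     classifier=classifier,
                     num_classes=2,
                     lr=1e-4,
                     weight_decay=0.0,
                     accelerator='gpu')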

Generative Models

For generative models, we provide trainers for VAEs, GANs, normalizing flows, and diffusion models. These trainers aim to train models that generate EEG signals closely mimicking the real distribution. We also provide conditional versions, which use categories as conditions for category-driven EEG sample generation.

from torcheeg.trainers import CDDPMTrainer
from torcheeg.models import BCUNet

model = BCUNet(in_channels=4)
trainer = CDDPMTrainer(model, accelerator='gpu')

from torcheeg.trainers import CWGANGPTrainer
from torcheeg.models import BCGenerator, BCDiscriminator

g_model = BCGenerator(in_channels=128)
d_model = BCDiscriminator(in_channels=4)

trainer = CWGANGPTrainer(g_model,
                         d_model,
                         accelerator='gpu')

from torcheeg.trainers import CGlowTrainer
from torcheeg.models import BCGlow

model = BCGlow(in_channels=4)
trainer = CGlowTrainer(model, accelerator='gpu')
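
Once constructed, a generative trainer is fitted on DataLoaders that yield (signal, label) batches, which lets the conditional trainers read categories as conditions. A minimal sketch, assuming the generative trainers share the fit interface used by ClassifierTrainer and reusing the DEAP dataset defined earlier:

from torch.utils.data import DataLoader

train_loader = DataLoader(dataset, batch_size=64, shuffle=True)
# a sketch; verify the exact fit signature in each trainer's documentation
trainer.fit(train_loader, train_loader, max_epochs=50)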

TorchEEG supports common evaluation metrics for generative models, such as FID. To use these metrics, you need to provide additional parameters to the trainer, such as metric_extractor, metric_classifier, and metric_num_features. For details, please refer to the related documents. Here is an example:

import torch.nn as nn

class Extractor(nn.Module):

    def __init__(self, in_channels=4):
        super().__init__()
        self.conv1 = nn.Sequential(
            nn.ZeroPad2d((1, 2, 1, 2)),
            nn.Conv2d(in_channels, 64, kernel_size=4, stride=1), nn.ReLU())
        self.conv2 = nn.Sequential(nn.ZeroPad2d((1, 2, 1, 2)),
                                   nn.Conv2d(64, 128, kernel_size=4, stride=1),
                                   nn.ReLU())
        self.conv3 = nn.Sequential(nn.ZeroPad2d((1, 2, 1, 2)),
                                   nn.Conv2d(128, 256, kernel_size=4, stride=1),
                                   nn.ReLU())
        self.conv4 = nn.Sequential(nn.ZeroPad2d((1, 2, 1, 2)),
                                   nn.Conv2d(256, 64, kernel_size=4, stride=1),
                                   nn.ReLU())

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)

        x = x.flatten(start_dim=1)
        return x


class Classifier(nn.Module):

    def __init__(self, in_channels=4, num_classes=2):
        super().__init__()
        self.conv1 = nn.Sequential(
            nn.ZeroPad2d((1, 2, 1, 2)),
            nn.Conv2d(in_channels, 64, kernel_size=4, stride=1), nn.ReLU())
        self.conv2 = nn.Sequential(nn.ZeroPad2d((1, 2, 1, 2)),
                                   nn.Conv2d(64, 128, kernel_size=4, stride=1),
                                   nn.ReLU())
        self.conv3 = nn.Sequential(nn.ZeroPad2d((1, 2, 1, 2)),
                                   nn.Conv2d(128, 256, kernel_size=4, stride=1),
                                   nn.ReLU())
        self.conv4 = nn.Sequential(nn.ZeroPad2d((1, 2, 1, 2)),
                                   nn.Conv2d(256, 64, kernel_size=4, stride=1),
                                   nn.ReLU())

        self.lin1 = nn.Linear(9 * 9 * 64, 1024)
        self.lin2 = nn.Linear(1024, num_classes)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)

        x = x.flatten(start_dim=1)
        x = self.lin1(x)
        x = self.lin2(x)
        return x

g_model = BCGenerator(in_channels=128)
d_model = BCDiscriminator(in_channels=4)

extractor = Extractor()
classifier = Classifier()
# you may need to load state dict from your trained extractor, classifier

trainer = CWGANGPTrainer(g_model,
                         d_model,
                         metric_extractor=extractor,
                         metric_classifier=classifier,
                         metric_num_features=9 * 9 * 64,
                         metrics=['fid'],
                         accelerator='gpu')
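
With the metric modules attached, testing should report FID alongside the usual outputs. A minimal sketch, reusing the DEAP dataset defined earlier; the test_fid key assumes the test_<metric> naming pattern used by the classifier trainers:

from torch.utils.data import DataLoader

test_loader = DataLoader(dataset, batch_size=64, shuffle=False)
score = trainer.test(test_loader)[0]
# "test_fid" assumes the "test_<metric>" key pattern shown earlier
print(f'FID: {score["test_fid"]:.4f}')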
