
Apply the Domain Adaptation Algorithm on the SEED Dataset

In this tutorial, you'll learn how to use TorchEEG's Associative Domain Adaptation (ADA) algorithm to tackle the cross-subject emotion recognition problem on the SEED dataset.

Step 1: Initialize the Dataset

We use the SEED dataset, which is supported by TorchEEG. Each EEG sample lasts 1 second and, at the 200 Hz sampling rate, consists of 200 data points.

In offline preprocessing, each electrode's EEG signal is decomposed into five frequency sub-bands, and the differential entropy of each sub-band serves as a feature (here we load the precomputed, LDS-smoothed de_LDS features that ship with SEED). The features are then mapped onto a 2D electrode grid and saved locally.
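For intuition, the differential entropy of a band-passed EEG segment is usually computed under a Gaussian assumption, in which case it reduces to a closed form based on the segment variance. Below is a minimal illustrative sketch of that formula (the released SEED features additionally smooth these values with a linear dynamical system, which is not shown here):

import numpy as np

def differential_entropy(band_signal):
    # under a Gaussian assumption, DE = 0.5 * log(2 * pi * e * variance),
    # where the variance is estimated from the band-passed segment
    variance = np.var(band_signal)
    return 0.5 * np.log(2 * np.pi * np.e * variance)

# e.g. one 1-second segment (200 data points) of a single sub-band
print(differential_entropy(np.random.randn(200)))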

For online processing, we convert all EEG features into tensors for neural network input. In addition, we apply after_hook_normalize as a session-level hook to reduce the effect of inter-trial variance on cross-subject emotion recognition.

from torcheeg.datasets import SEEDFeatureDataset
from torcheeg import transforms
from torcheeg.datasets.constants import SEED_CHANNEL_LOCATION_DICT
from torcheeg.transforms import after_hook_normalize
from pytorch_lightning.callbacks import ModelCheckpoint
import numpy as np

dataset = SEEDFeatureDataset(
    io_path=f'./examples_seed_domain_adaption/seed',
    root_path='./ExtractedFeatures',
    # normalize features across each session to reduce inter-trial variance
    after_session=after_hook_normalize,
    offline_transform=transforms.Compose(
        [transforms.ToGrid(SEED_CHANNEL_LOCATION_DICT)]),
    # SEED has no baseline signals, so no baseline removal is applied
    online_transform=transforms.ToTensor(),
    label_transform=transforms.Compose(
        [transforms.Select('emotion'),
         # map the raw labels {-1, 0, 1} to {0, 1, 2}
         transforms.Lambda(lambda x: x + 1)]),
    # use the precomputed, LDS-smoothed differential entropy features
    feature=['de_LDS'],
    num_worker=4)
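Before moving on, it can be worth checking what a single sample looks like. The expected shape below assumes the five de_LDS sub-bands and TorchEEG's 9x9 SEED electrode grid:

# quick sanity check on the pipeline output
eeg, label = dataset[0]
print(eeg.shape)  # expected: torch.Size([5, 9, 9]) under the assumptions above
print(label)      # an integer in {0, 1, 2} after the label transform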

Step 2: Define the Model

We need to split the CCNN model into two parts: the feature extractor and the classifier. The feature extractor is trained to produce domain-invariant features, while the classifier is trained to predict the emotion label from them.

from torcheeg.models import CCNN


class Extractor(CCNN):
    # reuse CCNN's convolutional blocks as the feature extractor
    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = x.flatten(start_dim=1)
        return x


class Classifier(CCNN):
    # reuse CCNN's fully connected layers as the classifier head
    def forward(self, x):
        x = self.lin1(x)
        x = self.lin2(x)
        return x
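To confirm that the two halves compose back into a full forward pass, we can push a dummy batch through them. A minimal sketch, assuming a (batch, 5, 9, 9) input grid:

import torch

extractor = Extractor(in_channels=5, num_classes=3)
classifier = Classifier(in_channels=5, num_classes=3)

dummy = torch.randn(2, 5, 9, 9)  # (batch, sub-bands, grid height, grid width)
features = extractor(dummy)      # flattened feature vectors
logits = classifier(features)    # (batch, 3) emotion logits
print(features.shape, logits.shape)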

Step 3: Split the Dataset

In this example, we do not consider the cross-session problem: the model is trained and tested on data from the same session. Within each session, we use the leave-one-subject-out strategy for evaluation, where each subject is in turn left out of the training set and used as the test set.

from torcheeg.model_selection import Subcategory

# split the dataset into one subset per recording session
subset = Subcategory(
    criteria='session_id',
    split_path=f'./examples_seed_domain_adaption/split/session')
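To see what the split produces, we can enumerate the per-session subsets and, within each, the leave-one-subject-out folds. This only inspects the splits; training follows in the next step:

from torcheeg.model_selection import LeaveOneSubjectOut

for j, sub_dataset in enumerate(subset.split(dataset)):
    loo = LeaveOneSubjectOut(
        split_path=f'./examples_seed_domain_adaption/split/loo_{j}')
    for i, (train_dataset, test_dataset) in enumerate(loo.split(sub_dataset)):
        print(f'session {j}, fold {i}: '
              f'{len(train_dataset)} train / {len(test_dataset)} test samples')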

Step 4: Train the Model

For training, we use the ADA algorithm to learn domain-invariant features across subjects: the labeled data of the remaining subjects form the source domain, while the unlabeled data of the left-out subject form the target domain.
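At the core of ADA is an associative loss between source and target embeddings: a "walker" term rewards round trips (source -> target -> source) that end on a source sample with the same label, and a "visit" term encourages all target samples to be visited. The sketch below is a conceptual illustration of these two terms, not TorchEEG's ADATrainer implementation; the function name and the visit_weight argument are purely illustrative:

import torch
import torch.nn.functional as F

def associative_loss(source_emb, target_emb, source_labels, visit_weight=0.1):
    # similarity between every source and target embedding
    sim = source_emb @ target_emb.t()  # (n_source, n_target)
    p_st = F.softmax(sim, dim=1)       # source -> target transition probabilities
    p_ts = F.softmax(sim.t(), dim=1)   # target -> source transition probabilities
    p_sts = p_st @ p_ts                # round-trip probabilities

    # walker loss: a round trip should land on a source sample with the same label
    same_label = (source_labels.unsqueeze(0) == source_labels.unsqueeze(1)).float()
    walker_target = same_label / same_label.sum(dim=1, keepdim=True)
    walker_loss = -(walker_target * torch.log(p_sts + 1e-8)).sum(dim=1).mean()

    # visit loss: every target sample should be visited with equal probability
    visit_prob = p_st.mean(dim=0)
    uniform = torch.full_like(visit_prob, 1.0 / visit_prob.numel())
    visit_loss = -(uniform * torch.log(visit_prob + 1e-8)).sum()

    return walker_loss + visit_weight * visit_loss

# e.g. associative_loss(extractor(source_x), extractor(target_x), source_y)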

from torch.utils.data import DataLoader
from torcheeg.model_selection import LeaveOneSubjectOut
from torcheeg.trainers import ADATrainer

scores = []
for j, sub_dataset in enumerate(subset.split(dataset)):
    # within each session, leave one subject out for testing
    loo = LeaveOneSubjectOut(
        split_path=f'./examples_seed_domain_adaption/split/loo_{j}')
    for i, (train_dataset, test_dataset) in enumerate(loo.split(sub_dataset)):

        # five input sub-bands and three emotion classes
        extractor = Extractor(in_channels=5, num_classes=3)
        classifier = Classifier(in_channels=5, num_classes=3)

        # labeled source domain: data from the remaining subjects
        source_loader = DataLoader(train_dataset,
                                   batch_size=128,
                                   shuffle=True,
                                   num_workers=4,
                                   drop_last=True)

        # unlabeled target domain: data from the left-out subject
        target_loader = DataLoader(test_dataset,
                                   batch_size=128,
                                   shuffle=True,
                                   num_workers=4,
                                   drop_last=True)

        # evaluation loader for the left-out subject
        test_loader = DataLoader(test_dataset,
                                 batch_size=128,
                                 shuffle=True,
                                 num_workers=4)

        trainer = ADATrainer(extractor=extractor,
                             classifier=classifier,
                             num_classes=3,
                             lr=1e-4,
                             weight_decay=0.0,
                             accelerator='gpu')

        trainer.fit(
            source_loader,
            target_loader,
            test_loader,
            max_epochs=50,
            default_root_dir=f'./examples_seed_domain_adaption/model/ses_{j}_sub_{i}',
            callbacks=[ModelCheckpoint(save_last=True)],
            enable_progress_bar=True,
            enable_model_summary=True,
            limit_val_batches=0.0)
        score = trainer.test(test_loader,
                             enable_progress_bar=True,
                             enable_model_summary=True)[0]
        scores.append(score['test_accuracy'])

    # running mean of leave-one-subject-out accuracy, reported as a percentage
    print(f'final accuracy: {100 * np.array(scores).mean():>0.1f}%')
