
MTCNN

class torcheeg.models.MTCNN(in_channels: int = 8, grid_size: Tuple[int, int] = (8, 9), num_classes: int = 2, dropout: float = 0.2)

Multi-Task Convolutional Neural Network (MT-CNN). For more details, please refer to the following information.

Below is a recommended setup for emotion recognition on the DEAP dataset:

from torcheeg.datasets import DEAPDataset
from torcheeg import transforms
from torcheeg.datasets.constants import DEAP_CHANNEL_LIST
from torcheeg.models import MTCNN
from torcheeg.datasets.constants.emotion_recognition.utils import format_channel_location_dict
from torch.utils.data import DataLoader

# 8 x 9 electrode layout (matches grid_size=(8, 9)); '-' marks an empty cell
DEAP_LOCATION_LIST = [['-', '-', 'AF3', 'FP1', '-', 'FP2', 'AF4', '-', '-'],
                      ['F7', '-', 'F3', '-', 'FZ', '-', 'F4', '-', 'F8'],
                      ['-', 'FC5', '-', 'FC1', '-', 'FC2', '-', 'FC6', '-'],
                      ['T7', '-', 'C3', '-', 'CZ', '-', 'C4', '-', 'T8'],
                      ['-', 'CP5', '-', 'CP1', '-', 'CP2', '-', 'CP6', '-'],
                      ['P7', '-', 'P3', '-', 'PZ', '-', 'P4', '-', 'P8'],
                      ['-', '-', '-', 'PO3', '-', 'PO4', '-', '-', '-'],
                      ['-', '-', '-', 'O1', 'OZ', 'O2', '-', '-', '-']]
DEAP_CHANNEL_LOCATION_DICT = format_channel_location_dict(DEAP_CHANNEL_LIST, DEAP_LOCATION_LIST)

dataset = DEAPDataset(root_path='./data_preprocessed_python',
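                      online_transform=transforms.Compose([
                          # Per electrode: differential entropy + PSD over the
                          # default 4 frequency bands -> 8 features (in_channels=8)
                          transforms.Concatenate([
                              transforms.BandDifferentialEntropy(),
                              transforms.BandPowerSpectralDensity()
                          ]),
                          transforms.ToGrid(DEAP_CHANNEL_LOCATION_DICT),
                          # ToGrid outputs a NumPy array; convert it to a torch.Tensor
                          transforms.ToTensor()
                      ]),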
                      label_transform=transforms.Compose([
                          transforms.Select('valence'),
                          transforms.Binary(5.0),
                      ]))

model = MTCNN(num_classes=2, in_channels=8, grid_size=(8, 9), dropout=0.2)

x, y = next(iter(DataLoader(dataset, batch_size=64)))
output = model(x)  # [batch_size, num_classes], here [64, 2]
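The ToGrid step in the example places each electrode's feature vector at a fixed (row, column) cell of the 8 x 9 layout, which is why grid_size is (8, 9). The stdlib-only sketch below illustrates that mapping under stated assumptions: `location_dict_from_list` and `to_grid` are hypothetical helpers (not TorchEEG APIs) that play a role similar to `format_channel_location_dict` and `ToGrid`, cells without an electrode are assumed to be zero-filled, and only three electrodes are used for brevity.

```python
from typing import Dict, List, Sequence, Tuple

# The 8 x 9 DEAP layout from the example above; '-' marks a cell with no electrode.
DEAP_LOCATION_LIST = [['-', '-', 'AF3', 'FP1', '-', 'FP2', 'AF4', '-', '-'],
                      ['F7', '-', 'F3', '-', 'FZ', '-', 'F4', '-', 'F8'],
                      ['-', 'FC5', '-', 'FC1', '-', 'FC2', '-', 'FC6', '-'],
                      ['T7', '-', 'C3', '-', 'CZ', '-', 'C4', '-', 'T8'],
                      ['-', 'CP5', '-', 'CP1', '-', 'CP2', '-', 'CP6', '-'],
                      ['P7', '-', 'P3', '-', 'PZ', '-', 'P4', '-', 'P8'],
                      ['-', '-', '-', 'PO3', '-', 'PO4', '-', '-', '-'],
                      ['-', '-', '-', 'O1', 'OZ', 'O2', '-', '-', '-']]


def location_dict_from_list(channel_list: Sequence[str],
                            location_list: List[List[str]]) -> Dict[str, Tuple[int, int]]:
    """Hypothetical helper: map each recorded channel to its (row, col) grid cell."""
    positions = {}
    for row, names in enumerate(location_list):
        for col, name in enumerate(names):
            if name != '-':
                positions[name] = (row, col)
    # Keep only channels that appear in the layout, in recording order.
    return {ch: positions[ch] for ch in channel_list if ch in positions}


def to_grid(features: List[List[float]], channel_list: Sequence[str],
            location_dict: Dict[str, Tuple[int, int]],
            grid_size: Tuple[int, int]) -> List[List[List[float]]]:
    """Hypothetical helper: [num_electrodes, num_features] -> [num_features, rows, cols].

    Cells without an electrode stay 0.0 (an assumption of this sketch).
    """
    rows, cols = grid_size
    num_features = len(features[0])
    grid = [[[0.0] * cols for _ in range(rows)] for _ in range(num_features)]
    for ch, feats in zip(channel_list, features):
        r, c = location_dict[ch]
        for k, value in enumerate(feats):
            grid[k][r][c] = value  # feature k of this electrode goes to plane k
    return grid


# Toy input: three electrodes with 8 features each (e.g. 4 DE + 4 PSD bands).
channels = ['FP1', 'AF3', 'OZ']
features = [[float(8 * i + k) for k in range(8)] for i in range(len(channels))]
loc = location_dict_from_list(channels, DEAP_LOCATION_LIST)
grid = to_grid(features, channels, loc, grid_size=(8, 9))
print(loc)  # {'FP1': (0, 3), 'AF3': (0, 2), 'OZ': (7, 4)}
print(len(grid), len(grid[0]), len(grid[0][0]))  # 8 8 9
```

The feature dimension becomes the leading "channel" axis, so a batch of such grids has the [n, in_channels, 8, 9] shape that the model consumes.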
Parameters:
  • in_channels (int) – The feature dimension of each electrode, i.e., N in the paper. (default: 8)

  • grid_size (tuple) – Spatial dimensions (rows, columns) of the grid-like EEG representation. (default: (8, 9))

  • num_classes (int) – The number of classes to predict. (default: 2)

  • dropout (float) – Probability of an element being zeroed in the dropout layers. (default: 0.2)
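The dropout parameter is a per-element zeroing probability. Assuming MTCNN uses standard PyTorch dropout layers (their exact type and placement are not stated on this page), those layers apply "inverted" dropout: during training each element is zeroed with probability p and the survivors are scaled by 1/(1 - p), while in evaluation mode the layer is the identity. The stdlib sketch below illustrates that behaviour; `inverted_dropout` is a hypothetical helper, not a TorchEEG or PyTorch function.

```python
import random
from typing import List, Optional


def inverted_dropout(values: List[float], p: float, training: bool = True,
                     rng: Optional[random.Random] = None) -> List[float]:
    """Hypothetical helper: element-wise inverted dropout on a flat list.

    p is restricted to [0, 1) here to avoid dividing by zero when scaling.
    """
    if not 0.0 <= p < 1.0:
        raise ValueError("p must be in [0, 1)")
    if not training or p == 0.0:
        return list(values)  # evaluation mode (or p == 0): identity
    rng = rng or random.Random()
    scale = 1.0 / (1.0 - p)  # keeps the expected value of each element unchanged
    return [0.0 if rng.random() < p else v * scale for v in values]


out = inverted_dropout([1.0] * 10_000, p=0.2, rng=random.Random(0))
zeroed = sum(v == 0.0 for v in out) / len(out)
print(f"zeroed fraction: {zeroed:.3f} (close to p = 0.2)")
print(f"mean after dropout: {sum(out) / len(out):.3f} (close to 1.0)")
print(inverted_dropout([1.0, 2.0], p=0.2, training=False))  # [1.0, 2.0]
```

Because of the rescaling, no correction is needed at inference time; calling `model.eval()` before evaluation is what switches the dropout layers to the identity.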

forward(x: Tensor) → Tensor
Parameters:

x (torch.Tensor) – Grid-like EEG signal representation. With the default settings, the expected input shape is [n, 8, 8, 9], where n is the batch size, 8 is in_channels, and (8, 9) is grid_size.

Returns:

The predicted probability that each sample belongs to each class.

Return type:

torch.Tensor of shape [number of samples, number of classes]
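Because forward expects [n, in_channels, *grid_size] and returns one row of values per sample, a caller typically validates the batch shape and takes the arg-max of each row to obtain class labels. The stdlib sketch below shows both steps on plain lists; `check_mtcnn_input` and `predict_labels` are hypothetical helpers, and the `softmax` function is only relevant under the assumption that the returned values are unnormalized scores rather than probabilities (softmax does not change the arg-max either way).

```python
import math
from typing import List, Sequence, Tuple


def check_mtcnn_input(shape: Sequence[int], in_channels: int = 8,
                      grid_size: Tuple[int, int] = (8, 9)) -> None:
    """Hypothetical helper: raise ValueError unless shape == [n, in_channels, *grid_size]."""
    expected = (in_channels, *grid_size)
    if len(shape) != 4 or tuple(shape[1:]) != expected:
        raise ValueError(f"expected [n, {in_channels}, {grid_size[0]}, {grid_size[1]}], "
                         f"got {list(shape)}")


def softmax(row: Sequence[float]) -> List[float]:
    """Normalize one row of scores into probabilities (max-subtracted for stability)."""
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def predict_labels(outputs: Sequence[Sequence[float]]) -> List[int]:
    """Hypothetical helper: index of the largest value in each [num_classes] row."""
    return [max(range(len(row)), key=lambda j: row[j]) for row in outputs]


check_mtcnn_input((64, 8, 8, 9))  # matches the defaults, so nothing is raised
outputs = [[0.3, 1.2], [2.0, -0.5]]  # toy stand-in for a [2, num_classes=2] output
print(predict_labels(outputs))  # [1, 0]
print([softmax(r) for r in outputs])
```

With real model output, the same arg-max is usually taken directly on the tensor along the class dimension.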
