MTCNN
- class torcheeg.models.MTCNN(in_channels: int = 8, grid_size: Tuple[int, int] = (8, 9), num_classes: int = 2, dropout: float = 0.2)[source]
Multi-Task Convolutional Neural Network (MT-CNN). For more details, please refer to the following information.
Paper: Rudakov E, Laurent L, Cousin V, et al. Multi-Task CNN model for emotion recognition from EEG Brain maps[C]//2021 4th International Conference on Bio-Engineering for Smart Technologies (BioSMART). IEEE, 2021: 1-4.
Related Project: https://github.com/dolphin-in-a-coma/multi-task-cnn-eeg-emotion/
Below is a recommended suite for use in emotion recognition tasks:
```python
from torch.utils.data import DataLoader

from torcheeg import transforms
from torcheeg.datasets import DEAPDataset
from torcheeg.datasets.constants import DEAP_CHANNEL_LIST
from torcheeg.datasets.constants.emotion_recognition.utils import format_channel_location_dict
from torcheeg.models import MTCNN

DEAP_LOCATION_LIST = [
    ['-', '-', 'AF3', 'FP1', '-', 'FP2', 'AF4', '-', '-'],
    ['F7', '-', 'F3', '-', 'FZ', '-', 'F4', '-', 'F8'],
    ['-', 'FC5', '-', 'FC1', '-', 'FC2', '-', 'FC6', '-'],
    ['T7', '-', 'C3', '-', 'CZ', '-', 'C4', '-', 'T8'],
    ['-', 'CP5', '-', 'CP1', '-', 'CP2', '-', 'CP6', '-'],
    ['P7', '-', 'P3', '-', 'PZ', '-', 'P4', '-', 'P8'],
    ['-', '-', '-', 'PO3', '-', 'PO4', '-', '-', '-'],
    ['-', '-', '-', 'O1', 'OZ', 'O2', '-', '-', '-'],
]
DEAP_CHANNEL_LOCATION_DICT = format_channel_location_dict(DEAP_CHANNEL_LIST, DEAP_LOCATION_LIST)

dataset = DEAPDataset(
    root_path='./data_preprocessed_python',
    online_transform=transforms.Compose([
        transforms.Concatenate([
            transforms.BandDifferentialEntropy(),
            transforms.BandPowerSpectralDensity()
        ]),
        transforms.ToGrid(DEAP_CHANNEL_LOCATION_DICT)
    ]),
    label_transform=transforms.Compose([
        transforms.Select('valence'),
        transforms.Binary(5.0),
    ]))

model = MTCNN(num_classes=2, in_channels=8, grid_size=(8, 9), dropout=0.2)

x, y = next(iter(DataLoader(dataset, batch_size=64)))
model(x)
```
- Parameters:
  - in_channels (int) – The feature dimension of each electrode, i.e., \(N\) in the paper. (default: 8)
  - grid_size (tuple) – Spatial dimensions of the grid-like EEG representation. (default: (8, 9))
  - num_classes (int) – The number of classes to predict. (default: 2)
  - dropout (float) – Probability of an element being zeroed in the dropout layers. (default: 0.2)
- forward(x: Tensor) → Tensor[source]
  - Parameters:
    - x (torch.Tensor) – EEG signal representation; the expected input shape is [n, 8, 8, 9], where n is the batch size, 8 corresponds to in_channels, and (8, 9) corresponds to grid_size.
  - Returns:
    - The predicted probability that the samples belong to each class.
  - Return type:
    - torch.Tensor[number of samples, number of classes]