STNet
- class torcheeg.models.STNet(chunk_size: int = 128, grid_size: Tuple[int, int] = (9, 9), num_classes: int = 2, dropout: float = 0.2)
Spatio-temporal Network (STNet). For more details, please refer to the following information.
Paper: Zhang Z, Zhong S, Liu Y. GANSER: A Self-supervised Data Augmentation Framework for EEG-based Emotion Recognition[J]. IEEE Transactions on Affective Computing, 2022.
Related Project: https://github.com/tczhangzhi/GANSER-PyTorch
Below is a recommended suite for use in emotion recognition tasks:
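from torcheeg import transforms
from torcheeg.datasets import DEAPDataset
from torcheeg.datasets.constants import DEAP_CHANNEL_LOCATION_DICT  # import path may differ across torcheeg versions
from torcheeg.models import STNet

# Project each EEG chunk onto the 9x9 electrode grid offline (results are cached under io_path),
# convert to a tensor at load time, and binarize the valence rating at a threshold of 5.0.
dataset = DEAPDataset(
    io_path='./deap',
    root_path='./data_preprocessed_python',
    offline_transform=transforms.Compose([
        transforms.ToGrid(DEAP_CHANNEL_LOCATION_DICT)
    ]),
    online_transform=transforms.ToTensor(),
    label_transform=transforms.Compose([
        transforms.Select('valence'),
        transforms.Binary(5.0),
    ]))

model = STNet(num_classes=2, chunk_size=128, grid_size=(9, 9), dropout=0.2)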
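The `ToGrid` transform in the suite above places each electrode's signal at its scalp position on a 2D grid, which is how a [channels, T] chunk becomes the [T, 9, 9] representation STNet consumes. The NumPy snippet below is a minimal sketch of that idea under stated assumptions, not torcheeg's implementation: the helper `project_to_grid` and the three-electrode coordinate table `TOY_LOCATIONS` are hypothetical, and a real layout such as `DEAP_CHANNEL_LOCATION_DICT` covers every channel.

```python
import numpy as np

# Hypothetical (row, column) grid coordinates for a few electrodes.
# A real channel-location table covers every recorded channel.
TOY_LOCATIONS = {"FP1": (0, 3), "FP2": (0, 5), "CZ": (4, 4)}


def project_to_grid(eeg, channel_names, locations, grid_size=(9, 9)):
    """Map an EEG chunk of shape [channels, T] to a grid of shape [T, H, W].

    Grid cells that have no electrode assigned stay zero.
    """
    eeg = np.asarray(eeg, dtype=np.float32)
    num_points = eeg.shape[1]
    grid = np.zeros((num_points, *grid_size), dtype=np.float32)
    for channel_index, name in enumerate(channel_names):
        row, col = locations[name]
        # Copy the whole time series of this channel into its grid cell.
        grid[:, row, col] = eeg[channel_index]
    return grid


# Three channels with 128 data points each, i.e. one chunk with chunk_size=128.
chunk = np.random.randn(3, 128)
grid = project_to_grid(chunk, ["FP1", "FP2", "CZ"], TOY_LOCATIONS)
print(grid.shape)  # (128, 9, 9)
```

Stacking `n` such grids along a new leading axis gives the [n, 128, 9, 9] batch shape described for `forward`.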
- Parameters
chunk_size (int) – Number of data points included in each EEG chunk, i.e., \(T\) in the paper. (default: 128)
grid_size (tuple) – Spatial dimensions of the grid-like EEG representation. (default: (9, 9))
num_classes (int) – The number of classes to predict. (default: 2)
dropout (float) – Probability of an element being zeroed in the dropout layers. (default: 0.2)
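These parameters fix the shapes the model works with: `chunk_size` sets the temporal dimension of the input, `grid_size` sets its two spatial dimensions, and `num_classes` sets the width of the output. The helpers below are hypothetical (not part of torcheeg) and simply spell out that relationship as a pre-flight check before a batch is passed to the model.

```python
from typing import Tuple


def expected_shapes(
    batch_size: int,
    chunk_size: int = 128,
    grid_size: Tuple[int, int] = (9, 9),
    num_classes: int = 2,
) -> Tuple[Tuple[int, ...], Tuple[int, ...]]:
    """Return (input_shape, output_shape) implied by the constructor arguments."""
    input_shape = (batch_size, chunk_size, *grid_size)
    output_shape = (batch_size, num_classes)
    return input_shape, output_shape


def check_input_shape(shape, chunk_size: int = 128, grid_size: Tuple[int, int] = (9, 9)) -> None:
    """Raise ValueError unless `shape` is [n, chunk_size, grid height, grid width]."""
    expected_tail = (chunk_size, *grid_size)
    if len(shape) != 4 or tuple(shape[1:]) != expected_tail:
        raise ValueError(
            f"expected [n, {chunk_size}, {grid_size[0]}, {grid_size[1]}], got {list(shape)}"
        )


print(expected_shapes(32))  # ((32, 128, 9, 9), (32, 2))
check_input_shape((32, 128, 9, 9))  # passes silently
```

A common mistake this catches is passing a channel-first layout such as [n, 9, 9, 128] instead of putting the temporal axis second.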
- forward(x: Tensor) → Tensor
- Parameters
x (torch.Tensor) – EEG signal representation; the ideal input shape is [n, 128, 9, 9]. Here, n corresponds to the batch size, 128 corresponds to chunk_size, and (9, 9) corresponds to grid_size.
- Returns
The predicted probability that each sample belongs to each class.
- Return type
torch.Tensor[number of samples, number of classes]
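Because the returned tensor has one row per sample and one column per class, the predicted label for each sample is the index of the largest entry in its row. The sketch below uses NumPy arrays in place of `torch.Tensor`, and the helper name `to_predictions` is hypothetical. It assumes the outputs are unnormalized scores and applies a softmax; if they are already probabilities, skip that step. Softmax preserves the ordering within each row, so the predicted labels are the same either way (with a real tensor, `outputs.argmax(dim=1)` yields them directly).

```python
import numpy as np


def to_predictions(outputs):
    """Turn an [n_samples, n_classes] array into (probabilities, labels)."""
    outputs = np.asarray(outputs, dtype=np.float64)
    # Subtract each row's maximum for numerical stability before exponentiating.
    shifted = outputs - outputs.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    probs = exp / exp.sum(axis=1, keepdims=True)
    # The predicted class is the column with the highest score in each row.
    labels = probs.argmax(axis=1)
    return probs, labels


# Two samples, two classes (num_classes=2).
probs, labels = to_predictions([[2.0, 0.5], [-1.0, 1.0]])
print(labels)  # [0 1]
```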