SSTEmotionNet
- class torcheeg.models.SSTEmotionNet(grid_size: Tuple[int, int] = (32, 32), spectral_in_channels: int = 5, temporal_in_channels: int = 25, spectral_depth: int = 16, temporal_depth: int = 22, spectral_growth_rate: int = 12, temporal_growth_rate: int = 24, num_dense_block: int = 3, hid_channels: int = 50, densenet_dropout: float = 0.0, task_dropout: float = 0.0, num_classes: int = 3)
Spatial-Spectral-Temporal based Attention 3D Dense Network (SST-EmotionNet) for EEG emotion recognition. For more details, please refer to the following information.
Paper: Jia Z, Lin Y, Cai X, et al. Sst-emotionnet: Spatial-spectral-temporal based attention 3d dense network for eeg emotion recognition[C]//Proceedings of the 28th ACM International Conference on Multimedia. 2020: 2909-2917.
Related Project: https://github.com/ziyujia/SST-EmotionNet
Related Project: https://github.com/LexieLiu01/SST-Emotion-Net-Pytorch-Version-
Below is a recommended suite for use in emotion recognition tasks:
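    from torcheeg import transforms
    from torcheeg.datasets import DEAPDataset
    from torcheeg.datasets.constants.emotion_recognition.deap import DEAP_CHANNEL_LOCATION_DICT
    from torcheeg.models import SSTEmotionNet

    dataset = DEAPDataset(
        io_path='./deap',
        root_path='./data_preprocessed_python',
        offline_transform=transforms.Compose([
            transforms.BaselineRemoval(),
            # Concatenate spectral features and the downsampled signal per electrode
            transforms.Concatenate([
                transforms.Compose([
                    transforms.BandDifferentialEntropy(sampling_rate=128),
                    transforms.MeanStdNormalize()
                ]),
                transforms.Compose([
                    transforms.Downsample(num_points=32),
                    transforms.MinMaxNormalize()
                ])
            ]),
            # Project the electrodes onto a 2D grid
            transforms.ToInterpolatedGrid(DEAP_CHANNEL_LOCATION_DICT)
        ]),
        online_transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize((16, 16))
        ]),
        label_transform=transforms.Compose([
            transforms.Select('valence'),
            transforms.Binary(5.0),
        ]))

    # grid_size matches the Resize target; the channel counts match the concatenated features
    model = SSTEmotionNet(temporal_in_channels=32,
                          spectral_in_channels=4,
                          grid_size=(16, 16),
                          num_classes=2)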
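The model arguments in the example above follow from the transform pipeline. The sketch below only does the bookkeeping. It assumes that BandDifferentialEntropy yields four band features per electrode (which is what spectral_in_channels=4 implies) and that Downsample(num_points=32) keeps 32 time points per electrode. The helper name sst_input_shape is hypothetical and is not part of torcheeg.

```python
from typing import Tuple


def sst_input_shape(batch_size: int, num_bands: int, num_points: int,
                    grid_size: Tuple[int, int]) -> Tuple[int, int, int, int]:
    """Return the tensor shape the model would receive (hypothetical helper).

    num_bands  -> spectral_in_channels (assumed: one 2D map per frequency band)
    num_points -> temporal_in_channels (one 2D map per downsampled time point)
    """
    # Spectral and temporal maps are stacked along the channel axis.
    channels = num_bands + num_points
    return (batch_size, channels, grid_size[0], grid_size[1])


# 4 band features + 32 time points = 36 stacked maps on a 16 x 16 grid
shape = sst_input_shape(batch_size=64, num_bands=4, num_points=32, grid_size=(16, 16))
print(shape)  # (64, 36, 16, 16)
```

If the number of bands or the number of downsampled points changes, spectral_in_channels and temporal_in_channels would need to change with them.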
- Parameters:
grid_size (tuple) – Spatial dimensions of grid-like EEG representation. (default: (32, 32))
spectral_in_channels (int) – How many 2D maps are stacked in the 3D spatial-spectral representation. (default: 5)
temporal_in_channels (int) – How many 2D maps are stacked in the 3D spatial-temporal representation. (default: 25)
spectral_depth (int) – The number of layers in the spatial-spectral stream. (default: 16)
temporal_depth (int) – The number of layers in the spatial-temporal stream. (default: 22)
spectral_growth_rate (int) – The growth rate of the spatial-spectral stream. (default: 12)
temporal_growth_rate (int) – The growth rate of the spatial-temporal stream. (default: 24)
num_dense_block (int) – The number of A3DBs to add to the end. (default: 3)
hid_channels (int) – The basic hidden channels in the network blocks. (default: 50)
densenet_dropout (float) – Probability of an element being zeroed in the dropout layers of the densenet blocks. (default: 0.0)
task_dropout (float) – Probability of an element being zeroed in the dropout layers of the task-specific classification blocks. (default: 0.0)
num_classes (int) – The number of classes to predict. (default: 3)
- forward(x: Tensor)
- Parameters:
x (torch.Tensor) – EEG signal representation. The ideal input shape is [n, 30, 16, 16]. Here, n corresponds to the batch size, 30 corresponds to the sum of spectral_in_channels (e.g., 5) and temporal_in_channels (e.g., 25), and (16, 16) corresponds to grid_size (here set to (16, 16)). Note that the first spectral_in_channels channels should represent spectral information.
- Returns:
The predicted probability that each sample belongs to each class.
- Return type:
torch.Tensor[number of samples, number of classes]
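The channel layout described for forward can be checked before calling the model. The sketch below is a minimal, framework-free illustration: it validates a shape tuple and returns slices for the spectral and temporal parts of the channel axis. It assumes the temporal maps occupy the channels that follow the first spectral_in_channels channels. The helper name split_streams is hypothetical, and the sketch does not describe how the model splits its input internally.

```python
from typing import Sequence, Tuple


def split_streams(shape: Sequence[int],
                  spectral_in_channels: int = 5,
                  temporal_in_channels: int = 25,
                  grid_size: Tuple[int, int] = (16, 16)) -> Tuple[slice, slice]:
    """Validate an input shape and return (spectral, temporal) channel slices."""
    if len(shape) != 4:
        raise ValueError(f'expected a 4D shape [n, c, h, w], got {tuple(shape)}')
    _, channels, height, width = shape
    expected = spectral_in_channels + temporal_in_channels
    if channels != expected:
        raise ValueError(f'expected {expected} channels, got {channels}')
    if (height, width) != tuple(grid_size):
        raise ValueError(f'expected grid {tuple(grid_size)}, got {(height, width)}')
    # Spectral maps come first; the remaining channels are assumed to be temporal.
    spectral = slice(0, spectral_in_channels)
    temporal = slice(spectral_in_channels, expected)
    return spectral, temporal


spectral, temporal = split_streams((8, 30, 16, 16))
print(spectral, temporal)  # slice(0, 5, None) slice(5, 30, None)
```

The returned slices can index the channel axis of a tensor or array, for example x[:, spectral] and x[:, temporal].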