SSTEmotionNet
- class torcheeg.models.SSTEmotionNet(grid_size: Tuple[int, int] = (32, 32), spectral_in_channels: int = 5, temporal_in_channels: int = 25, spectral_depth: int = 16, temporal_depth: int = 22, spectral_growth_rate: int = 12, temporal_growth_rate: int = 24, num_dense_block: int = 3, hid_channels: int = 50, densenet_dropout: float = 0.0, task_dropout: float = 0.0, num_classes: int = 3)
Spatial-Spectral-Temporal based Attention 3D Dense Network (SST-EmotionNet) for EEG emotion recognition. For more details, please refer to the following information.
Paper: Jia Z, Lin Y, Cai X, et al. Sst-emotionnet: Spatial-spectral-temporal based attention 3d dense network for eeg emotion recognition[C]//Proceedings of the 28th ACM International Conference on Multimedia. 2020: 2909-2917.
Related Project: https://github.com/ziyujia/SST-EmotionNet
Related Project: https://github.com/LexieLiu01/SST-Emotion-Net-Pytorch-Version-
Below is a recommended suite for use in emotion recognition tasks:
from torcheeg.datasets import DEAPDataset
from torcheeg import transforms
from torcheeg.datasets.constants import DEAP_CHANNEL_LOCATION_DICT
from torcheeg.models import SSTEmotionNet
from torch.utils.data import DataLoader

dataset = DEAPDataset(root_path='./data_preprocessed_python',
                      offline_transform=transforms.Compose([
                          transforms.BaselineRemoval(),
                          transforms.Concatenate([
                              transforms.Compose([
                                  transforms.BandDifferentialEntropy(sampling_rate=128),
                                  transforms.MeanStdNormalize()
                              ]),
                              transforms.Compose([
                                  transforms.Downsample(num_points=32),
                                  transforms.MinMaxNormalize()
                              ])
                          ]),
                          transforms.ToInterpolatedGrid(DEAP_CHANNEL_LOCATION_DICT)
                      ]),
                      online_transform=transforms.Compose([
                          transforms.ToTensor(),
                          transforms.Resize((16, 16))
                      ]),
                      label_transform=transforms.Compose([
                          transforms.Select('valence'),
                          transforms.Binary(5.0),
                      ]))

model = SSTEmotionNet(temporal_in_channels=32, spectral_in_channels=4, grid_size=(16, 16), num_classes=2)
x, y = next(iter(DataLoader(dataset, batch_size=64)))
model(x)
- Parameters:
grid_size (tuple) – Spatial dimensions of the grid-like EEG representation. (default: (32, 32))
spectral_in_channels (int) – Number of 2D maps stacked in the 3D spatial-spectral representation. (default: 5)
temporal_in_channels (int) – Number of 2D maps stacked in the 3D spatial-temporal representation. (default: 25)
spectral_depth (int) – The number of layers in the spatial-spectral stream. (default: 16)
temporal_depth (int) – The number of layers in the spatial-temporal stream. (default: 22)
spectral_growth_rate (int) – The growth rate of the spatial-spectral stream. (default: 12)
temporal_growth_rate (int) – The growth rate of the spatial-temporal stream. (default: 24)
num_dense_block (int) – The number of A3DBs to add to the end. (default: 3)
hid_channels (int) – The basic hidden channels in the network blocks. (default: 50)
densenet_dropout (float) – Probability of an element being zeroed in the dropout layers of the DenseNet blocks. (default: 0.0)
task_dropout (float) – Probability of an element being zeroed in the dropout layers of the task-specific classification blocks. (default: 0.0)
num_classes (int) – The number of classes to predict. (default: 3)
- forward(x: Tensor)
- Parameters:
x (torch.Tensor) – EEG signal representation. The ideal input shape is [n, 30, 16, 16]. Here, n corresponds to the batch size, 30 corresponds to the sum of spectral_in_channels (e.g., 5) and temporal_in_channels (e.g., 25), and (16, 16) corresponds to grid_size. Note that the first spectral_in_channels channels should represent spectral information.
- Returns:
the predicted probability that the samples belong to the classes.
- Return type:
torch.Tensor[number of samples, number of classes]
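The channel layout expected by forward can be checked before a batch is passed to the model. The sketch below is illustrative only: expected_input_shape and channel_slices are hypothetical helpers, not part of torcheeg. It assumes nothing beyond what is stated above, namely that the channel axis holds spectral_in_channels spectral maps followed by temporal_in_channels temporal maps.

```python
from typing import Tuple


def expected_input_shape(batch_size: int,
                         grid_size: Tuple[int, int],
                         spectral_in_channels: int = 5,
                         temporal_in_channels: int = 25) -> Tuple[int, int, int, int]:
    """Return the [n, channels, height, width] shape described for forward().

    Hypothetical helper: the channel axis is the sum of the spectral and
    temporal map counts.
    """
    channels = spectral_in_channels + temporal_in_channels
    return (batch_size, channels, grid_size[0], grid_size[1])


def channel_slices(spectral_in_channels: int = 5,
                   temporal_in_channels: int = 25) -> Tuple[slice, slice]:
    """Return slices selecting the spectral and temporal channels.

    Hypothetical helper: spectral maps come first, temporal maps follow.
    """
    total = spectral_in_channels + temporal_in_channels
    return slice(0, spectral_in_channels), slice(spectral_in_channels, total)


# Class-signature channel counts (5 + 25) on a 16 x 16 grid.
print(expected_input_shape(8, (16, 16)))

# Channel counts used in the DEAP example (4 band features + 32 time points).
print(expected_input_shape(64, (16, 16), spectral_in_channels=4, temporal_in_channels=32))
```

With the channel counts from the DEAP example (4 spectral, 32 temporal), the channel axis has 36 entries rather than 30, so the model must be constructed with matching spectral_in_channels and temporal_in_channels values.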