
Source code for torcheeg.models.cnn.sst_emotion_net

from collections import OrderedDict
from typing import Tuple

import torch
import torch.nn as nn


class SSTEmotionNet(nn.Module):
    r'''
    Spatial-Spectral-Temporal based Attention 3D Dense Network (SST-EmotionNet) for EEG emotion recognition. For more details, please refer to the following information.

    - Paper: Jia Z, Lin Y, Cai X, et al. SST-EmotionNet: Spatial-spectral-temporal based attention 3D dense network for EEG emotion recognition[C]//Proceedings of the 28th ACM International Conference on Multimedia. 2020: 2909-2917.
    - URL: https://dl.acm.org/doi/abs/10.1145/3394171.3413724
    - Related Project: https://github.com/ziyujia/SST-EmotionNet
    - Related Project: https://github.com/LexieLiu01/SST-Emotion-Net-Pytorch-Version

    Below is a recommended suite for use in emotion recognition tasks:

    .. code-block:: python

        from torcheeg.datasets import DEAPDataset
        from torcheeg import transforms
        from torcheeg.datasets.constants import DEAP_CHANNEL_LOCATION_DICT
        from torcheeg.models import SSTEmotionNet
        from torch.utils.data import DataLoader

        dataset = DEAPDataset(root_path='./data_preprocessed_python',
                              offline_transform=transforms.Compose([
                                  transforms.BaselineRemoval(),
                                  transforms.Concatenate([
                                      transforms.Compose([
                                          transforms.BandDifferentialEntropy(sampling_rate=128),
                                          transforms.MeanStdNormalize()
                                      ]),
                                      transforms.Compose([
                                          transforms.Downsample(num_points=32),
                                          transforms.MinMaxNormalize()
                                      ])
                                  ]),
                                  transforms.ToInterpolatedGrid(DEAP_CHANNEL_LOCATION_DICT)
                              ]),
                              online_transform=transforms.Compose([
                                  transforms.ToTensor(),
                                  transforms.Resize((16, 16))
                              ]),
                              label_transform=transforms.Compose([
                                  transforms.Select('valence'),
                                  transforms.Binary(5.0),
                              ]))

        model = SSTEmotionNet(temporal_in_channels=32,
                              spectral_in_channels=4,
                              grid_size=(16, 16),
                              num_classes=2)

        x, y = next(iter(DataLoader(dataset, batch_size=64)))
        model(x)

    Args:
        grid_size (tuple): Spatial dimensions of the grid-like EEG representation. (default: :obj:`(32, 32)`)
        spectral_in_channels (int): How many 2D maps are stacked in the 3D spatial-spectral representation. (default: :obj:`5`)
        temporal_in_channels (int): How many 2D maps are stacked in the 3D spatial-temporal representation. (default: :obj:`25`)
        spectral_depth (int): The number of layers in the spatial-spectral stream. (default: :obj:`16`)
        temporal_depth (int): The number of layers in the spatial-temporal stream. (default: :obj:`22`)
        spectral_growth_rate (int): The growth rate of the spatial-spectral stream. (default: :obj:`12`)
        temporal_growth_rate (int): The growth rate of the spatial-temporal stream. (default: :obj:`24`)
        num_dense_block (int): The number of A3DBs appended to each stream. (default: :obj:`3`)
        hid_channels (int): The basic hidden channels in the network blocks. (default: :obj:`50`)
        densenet_dropout (float): Probability of an element to be zeroed in the dropout layers of the densenet blocks. (default: :obj:`0.0`)
        task_dropout (float): Probability of an element to be zeroed in the dropout layers of the task-specific classification blocks. (default: :obj:`0.0`)
        num_classes (int): The number of classes to predict. (default: :obj:`3`)
    '''

    def __init__(self,
                 grid_size: Tuple[int, int] = (32, 32),
                 spectral_in_channels: int = 5,
                 temporal_in_channels: int = 25,
                 spectral_depth: int = 16,
                 temporal_depth: int = 22,
                 spectral_growth_rate: int = 12,
                 temporal_growth_rate: int = 24,
                 num_dense_block: int = 3,
                 hid_channels: int = 50,
                 densenet_dropout: float = 0.0,
                 task_dropout: float = 0.0,
                 num_classes: int = 3):
        super(SSTEmotionNet, self).__init__()
        self.grid_size = grid_size
        self.spectral_in_channels = spectral_in_channels
        self.temporal_in_channels = temporal_in_channels
        self.spectral_depth = spectral_depth
        self.spectral_growth_rate = spectral_growth_rate
        self.temporal_growth_rate = temporal_growth_rate
        self.num_dense_block = num_dense_block
        self.hid_channels = hid_channels
        self.densenet_dropout = densenet_dropout
        self.task_dropout = task_dropout
        self.num_classes = num_classes

        assert grid_size[0] >= 16 and grid_size[1] >= 16, \
            'The height and width of the grid must be greater than or equal to 16. Please upsample the EEG grid.'

        # Spatial-spectral stream: a 3D DenseNet over the stacked frequency-band maps.
        self.spatial_spectral = DenseNet3D(grid_size=grid_size,
                                           in_channels=spectral_in_channels,
                                           depth=spectral_depth,
                                           num_dense_block=num_dense_block,
                                           growth_rate=spectral_growth_rate,
                                           reduction=0.5,
                                           bottleneck=True,
                                           dropout=densenet_dropout)
        # Spatial-temporal stream: a 3D DenseNet over the stacked time-point maps.
        self.spatial_temporal = DenseNet3D(grid_size=grid_size,
                                           in_channels=temporal_in_channels,
                                           depth=temporal_depth,
                                           num_dense_block=num_dense_block,
                                           growth_rate=temporal_growth_rate,
                                           bottleneck=True,
                                           subsample_initial_block=True,
                                           dropout=densenet_dropout)

        layers = []
        spectral_out, temporal_out = self.get_feature_dims()
        layers.append(nn.Linear(spectral_out + temporal_out, hid_channels))
        layers.append(nn.Dropout(p=task_dropout))
        layers.append(nn.Linear(hid_channels, num_classes))
        self.layers = nn.ModuleList(layers)

    def get_feature_dims(self):
        # Run a mock batch through both streams to infer their flattened output sizes.
        mock_eeg_s = torch.randn(2, self.grid_size[0], self.grid_size[1],
                                 self.spectral_in_channels)
        mock_eeg_t = torch.randn(2, self.grid_size[0], self.grid_size[1],
                                 self.temporal_in_channels)
        spectral_output = self.spatial_spectral(mock_eeg_s)
        temporal_output = self.spatial_temporal(mock_eeg_t)
        return spectral_output.shape[1], temporal_output.shape[1]
    def forward(self, x: torch.Tensor):
        r'''
        Args:
            x (torch.Tensor): EEG signal representation, the ideal input shape is :obj:`[n, 30, 16, 16]`. Here, :obj:`n` corresponds to the batch size, :obj:`30` corresponds to the sum of :obj:`spectral_in_channels` (e.g., 5) and :obj:`temporal_in_channels` (e.g., 25), and :obj:`(16, 16)` corresponds to :obj:`grid_size`. It is worth noting that the first :obj:`spectral_in_channels` channels should represent spectral information.

        Returns:
            torch.Tensor[number of samples, number of classes]: the predicted probability that the samples belong to the classes.
        '''
        assert x.shape[1] == (
            self.spectral_in_channels + self.temporal_in_channels
        ), f'The input number of channels is {x.shape[1]}, but the expected number of channels is the number of spectral channels {self.spectral_in_channels} plus the number of temporal channels {self.temporal_in_channels}.'

        # Split the stacked input into its spectral and temporal halves.
        spectral_input = x[:, :self.spectral_in_channels]
        temporal_input = x[:, self.spectral_in_channels:]

        # Move channels last: [n, c, h, w] -> [n, h, w, c], as expected by DenseNet3D.
        spectral_input = spectral_input.permute(0, 2, 3, 1)
        temporal_input = temporal_input.permute(0, 2, 3, 1)

        spectral_output = self.spatial_spectral(spectral_input)
        temporal_output = self.spatial_temporal(temporal_input)
        output = torch.cat([spectral_output, temporal_output], dim=1)

        for layer in self.layers:
            output = layer(output)
        return output
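A minimal smoke-test sketch, not part of the module source, may help confirm the expected shapes. It assumes the default channel split of 5 spectral plus 25 temporal maps; the batch size of 4 and the 16 x 16 grid are arbitrary choices:

.. code-block:: python

    import torch

    # Hedged smoke test: the defaults give spectral_in_channels + temporal_in_channels
    # = 5 + 25 = 30 stacked maps, so the input is [batch, 30, height, width].
    model = SSTEmotionNet(grid_size=(16, 16), num_classes=2)
    x = torch.randn(4, 30, 16, 16)
    y = model(x)
    assert y.shape == (4, 2)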
class DenseNet3D(nn.Module):
    def __init__(self,
                 grid_size,
                 in_channels,
                 depth=40,
                 num_dense_block=3,
                 growth_rate=12,
                 bottleneck=False,
                 reduction=0.0,
                 dropout=None,
                 subsample_initial_block=False):
        super(DenseNet3D, self).__init__()
        self.grid_size = grid_size
        self.in_channels = in_channels

        if reduction != 0.0:
            assert reduction <= 1.0 and reduction > 0.0, 'reduction value must lie between 0.0 and 1.0.'
        assert (depth - 4) % 3 == 0, 'Depth must be 3N + 4.'

        # Number of conv blocks per dense block; a bottleneck splits each block in two.
        count = int((depth - 4) / 3)
        if bottleneck:
            count = count // 2
        num_layers = [count for _ in range(num_dense_block)]

        num_filters = 2 * growth_rate
        compression = 1.0 - reduction

        if subsample_initial_block:
            initial_kernel = (5, 5, 3)
            initial_strides = (2, 2, 1)
        else:
            initial_kernel = (3, 3, 1)
            initial_strides = (1, 1, 1)

        layers = []
        if subsample_initial_block:
            conv_layer = nn.Conv3d(1,
                                   num_filters,
                                   initial_kernel,
                                   stride=initial_strides,
                                   padding=(2, 2, 1),
                                   bias=False)
        else:
            conv_layer = nn.Conv3d(1,
                                   num_filters,
                                   initial_kernel,
                                   stride=initial_strides,
                                   padding=(1, 1, 0),
                                   bias=False)
        layers.append(("conv1", conv_layer))

        if subsample_initial_block:
            layers.append(("batch1", nn.BatchNorm3d(num_filters, eps=1.1e-5)))
            layers.append(("active1", nn.ReLU()))
            layers.append(("maxpool",
                           nn.MaxPool3d((2, 2, 2),
                                        stride=(2, 2, 2),
                                        padding=(0, 0, 1))))
        self.conv_layer = nn.Sequential(OrderedDict(layers))

        grid_height, grid_width, grid_channels = self.get_feature_dims()

        layers = []
        # Alternate attention and dense blocks, with a transition (which halves
        # every dimension of the feature grid) after all dense blocks but the last.
        for block_idx in range(num_dense_block - 1):
            layers.append(
                Attention(grid_size=(grid_height, grid_width),
                          in_channels=grid_channels))
            layers.append(
                DenseBlock(num_layers[block_idx],
                           num_filters,
                           growth_rate,
                           bottleneck=bottleneck,
                           dropout=dropout))
            num_filters = num_filters + growth_rate * num_layers[block_idx]
            layers.append(
                Transition(num_filters, num_filters, compression=compression))
            num_filters = int(num_filters * compression)
            grid_height = int(grid_height / 2)
            grid_width = int(grid_width / 2)
            grid_channels = int(grid_channels / 2)

        layers.append(
            Attention(grid_size=(grid_height, grid_width),
                      in_channels=grid_channels))
        layers.append(
            DenseBlock(num_layers[-1],
                       num_filters,
                       growth_rate,
                       bottleneck=bottleneck,
                       dropout=dropout))
        num_filters = num_filters + growth_rate * num_layers[-1]
        self.layers = nn.ModuleList(layers)

        final_layers = []
        final_layers.append(nn.BatchNorm3d(num_filters, eps=1.1e-5))
        final_layers.append(nn.ReLU())
        final_layers.append(
            nn.AvgPool3d((grid_height, grid_width, grid_channels)))
        self.final_layers = nn.ModuleList(final_layers)

    def get_feature_dims(self):
        # Probe the initial convolution with a mock batch to get the grid it emits.
        mock_eeg = torch.randn(2, self.grid_size[0], self.grid_size[1],
                               self.in_channels)
        mock_eeg = mock_eeg.unsqueeze(1)
        mock_eeg = self.conv_layer(mock_eeg)
        return mock_eeg.shape[2], mock_eeg.shape[3], mock_eeg.shape[4]

    def forward(self, x):
        x = x.unsqueeze(1)
        x = self.conv_layer(x)
        for layer in self.layers:
            x = layer(x)
        for layer in self.final_layers:
            x = layer(x)
        x = x.view(x.shape[0], -1)
        return x


class DenseBlock(nn.Module):
    def __init__(self,
                 num_layers,
                 num_filters,
                 growth_rate,
                 bottleneck=False,
                 dropout=None):
        super(DenseBlock, self).__init__()
        layers = []
        for i in range(num_layers):
            conv_block = ConvBlock(num_filters, growth_rate, bottleneck,
                                   dropout)
            num_filters = num_filters + growth_rate
            layers.append(conv_block)
        self.layers = nn.ModuleList(layers)

    def forward(self, x):
        # Dense connectivity: each block sees the concatenation of all earlier outputs.
        for layer in self.layers:
            cb = layer(x)
            x = torch.cat([x, cb], dim=1)
        return x


class ConvBlock(nn.Module):
    def __init__(self,
                 input_channel,
                 num_filters,
                 bottleneck=False,
                 dropout=None,
                 conv1x1=True):
        super(ConvBlock, self).__init__()
        layers = []
        layers.append(nn.BatchNorm3d(input_channel, eps=1.1e-5))
        layers.append(nn.ReLU())

        if bottleneck:
            # 1x1x1 bottleneck before the spatial convolution.
            inter_channel = num_filters * 4
            layers.append(
                nn.Conv3d(input_channel,
                          inter_channel, (1, 1, 1),
                          padding=0,
                          bias=False))
            layers.append(nn.BatchNorm3d(inter_channel, eps=1.1e-5))
            layers.append(nn.ReLU())
            layers.append(
                nn.Conv3d(inter_channel,
                          num_filters, (3, 3, 1),
                          padding=(1, 1, 0),
                          bias=False))
        if conv1x1:
            layers.append(
                nn.Conv3d(num_filters,
                          num_filters, (1, 1, 1),
                          padding=(0, 0, 0),
                          bias=False))
        layers.append(
            nn.Conv3d(num_filters,
                      num_filters, (1, 1, 3),
                      padding=(0, 0, 1),
                      bias=False))
        if dropout:
            layers.append(nn.Dropout(dropout))
        self.layers = nn.ModuleList(layers)

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


class Transition(nn.Module):
    def __init__(self, input_channel, num_filters, compression=1.0):
        super(Transition, self).__init__()
        layers = []
        layers.append(nn.BatchNorm3d(input_channel, eps=1.1e-5))
        layers.append(nn.ReLU())
        layers.append(
            nn.Conv3d(input_channel,
                      int(num_filters * compression), (1, 1, 1),
                      padding=0,
                      bias=False))
        # Halve every dimension of the feature grid.
        layers.append(nn.AvgPool3d((2, 2, 2), stride=(2, 2, 2)))
        self.layers = nn.ModuleList(layers)

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


class Attention(nn.Module):
    def __init__(self, grid_size, in_channels):
        super(Attention, self).__init__()
        num_spatial = int(grid_size[0]) * int(grid_size[1])
        # Spatial attention: pool over the depth axis, then gate each grid location.
        self.spatial_pool = nn.AvgPool3d(kernel_size=[1, 1, in_channels])
        self.spatial_dense = nn.Linear(num_spatial, num_spatial)
        # Spectral/temporal attention: pool over the grid, then gate each depth slice.
        self.temporal_pool = nn.AvgPool3d(
            kernel_size=[grid_size[0], grid_size[1], 1])
        self.temporal_dense = nn.Linear(in_channels, in_channels)

    def forward(self, x):
        out = x
        # Average over the filter dimension to get a single descriptor volume.
        x = torch.mean(x, dim=1)
        x = x.unsqueeze(1)
        num_spatial = x.shape[2] * x.shape[3]
        num_temporal = x.shape[-1]

        spatial = self.spatial_pool(x)
        spatial = spatial.view(-1, num_spatial)
        spatial = self.spatial_dense(spatial)
        spatial = torch.sigmoid(spatial)
        spatial = spatial.view(x.shape[0], 1, x.shape[2], x.shape[3], 1)
        out = out * spatial

        temporal = self.temporal_pool(x)
        temporal = temporal.view(-1, num_temporal)
        temporal = self.temporal_dense(temporal)
        temporal = torch.sigmoid(temporal)
        temporal = temporal.view(x.shape[0], 1, 1, 1, x.shape[-1])
        out = out * temporal
        return out
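For illustration, here is a hedged sketch, again not part of the module source, exercising the building blocks standalone. The layout [batch, filters, height, width, depth] matches what DenseNet3D passes between its blocks, and the concrete sizes are assumptions:

.. code-block:: python

    import torch

    # The attention block gates its input spatially and along depth, but never
    # reshapes it: output shape equals input shape.
    attention = Attention(grid_size=(16, 16), in_channels=5)
    feat = torch.randn(4, 24, 16, 16, 5)  # [batch, filters, height, width, depth]
    assert attention(feat).shape == feat.shape

    # depth must satisfy depth = 3N + 4; with depth=16 and bottleneck=True each
    # dense block holds (16 - 4) // 3 // 2 = 2 conv blocks. The stream consumes
    # channels-last input [batch, height, width, depth] and returns flat features.
    stream = DenseNet3D(grid_size=(16, 16), in_channels=5, depth=16,
                        bottleneck=True, reduction=0.5)
    out = stream(torch.randn(4, 16, 16, 5))
    print(out.shape)  # [4, flattened feature dimension]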