# Source code for torcheeg.models.cnn.msrn

from typing import Tuple

import torch
import torch.nn as nn


class MSRN(torch.nn.Module):
    r'''
    Multi-scale Residual Network (MSRN). For more details, please refer to the following information.

    - Paper: Li J, Hua H, Xu Z, et al. Cross-subject EEG emotion recognition combined with connectivity features and meta-transfer learning[J]. Computers in Biology and Medicine, 2022, 145: 105519.
    - URL: https://www.sciencedirect.com/science/article/abs/pii/S0010482522003110
    - Related Project: https://github.com/ljy-scut/MTL-MSRN

    Below is a recommended suite for use in emotion recognition tasks:

    .. code-block:: python

        dataset = DEAPDataset(io_path=f'./deap',
                    root_path='./data_preprocessed_python',
                    offline_transform=transforms.Compose([
                        transforms.BandSignal(),
                        transforms.PearsonCorrelation()
                    ]),
                    online_transform=transforms.ToTensor(),
                    label_transform=transforms.Compose([
                        transforms.Select('valence'),
                        transforms.Binary(5.0),
                    ]))
        model = MSRN(num_classes=2, in_channels=4)

    Args:
        in_channels (int): The feature dimension of each electrode. (default: :obj:`4`)
        hid_channels (int): The number of kernels in the first convolutional layers. (default: :obj:`64`)
        num_classes (int): The number of classes to predict. (default: :obj:`2`)
        dropout (float): Probability of an element to be zeroed in the dropout layers. (default: :obj:`0.5`)
    '''
    def __init__(self,
                 in_channels: int = 4,
                 hid_channels: int = 64,
                 num_classes: int = 2,
                 dropout: float = 0.5):
        super().__init__()
        self.in_channels = in_channels
        self.hid_channels = hid_channels
        self.num_classes = num_classes
        self.dropout = dropout

        # Stem: a 3x3 convolution that lifts the input to `hid_channels`,
        # followed by a stride-2 max pooling that halves the spatial size.
        self.conv1 = nn.Sequential(
            nn.Conv2d(self.in_channels,
                      self.hid_channels,
                      kernel_size=3,
                      stride=1,
                      padding=1), nn.BatchNorm2d(self.hid_channels),
            nn.ReLU())
        self.pool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Three parallel branches that differ only in kernel size. They are
        # created in this order (3x3, 5x5, 7x7) so that parameter
        # initialisation consumes the random generator in a fixed sequence.
        self.ms_module_conv3x3 = self._make_ms_branch(kernel_size=3)
        self.ms_module_conv5x5 = self._make_ms_branch(kernel_size=5)
        self.ms_module_conv7x7 = self._make_ms_branch(kernel_size=7)

        # Global average pooling collapses each branch to one value per channel.
        self.pool2 = nn.AdaptiveAvgPool2d(output_size=1)

        # Each branch ends with `hid_channels * 4` channels, and the three
        # branches are concatenated, hence `hid_channels * 12` features.
        self.cls = nn.Sequential(
            nn.Linear(self.hid_channels * 12, self.hid_channels * 12),
            nn.ReLU(), nn.Dropout(dropout),
            nn.Linear(self.hid_channels * 12, self.num_classes))

    def _make_ms_branch(self, kernel_size: int) -> nn.Sequential:
        r'''
        Build one multi-scale branch as a flat stack of Conv2d-BatchNorm2d-ReLU stages.

        Args:
            kernel_size (int): Side length of the square convolution kernel. The padding is :obj:`kernel_size // 2`, so for odd kernel sizes the stride-1 stages keep the spatial size.

        Returns:
            nn.Sequential: The 15-layer branch, ending with :obj:`hid_channels * 4` output channels.
        '''
        padding = kernel_size // 2
        hid = self.hid_channels
        # (input channels, output channels, stride) of every stage. The first
        # stage projects down to `in_channels` and the second back up to
        # `hid_channels`; these widths are kept unchanged so that existing
        # checkpoints keep matching parameter shapes.
        stages = (
            (hid, self.in_channels, 1),
            (self.in_channels, hid, 1),
            (hid, hid, 2),
            (hid, hid * 2, 2),
            (hid * 2, hid * 4, 2),
        )
        layers = []
        for n_in, n_out, stride in stages:
            layers.append(
                nn.Conv2d(n_in,
                          n_out,
                          kernel_size=kernel_size,
                          stride=stride,
                          padding=padding))
            layers.append(nn.BatchNorm2d(n_out))
            layers.append(nn.ReLU())
        # A flat Sequential keeps the sub-module indices (0, 1, 2, ...) and
        # therefore the state_dict keys of the branch stable.
        return nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        r'''
        Args:
            x (torch.Tensor): EEG signal representation, the ideal input shape is :obj:`[n, 4, 32, 32]`. Here, :obj:`n` corresponds to the batch size, :obj:`4` corresponds to :obj:`in_channels`, and :obj:`(32, 32)` corresponds to :obj:`grid_size`.

        Returns:
            torch.Tensor[number of sample, number of classes]: the unnormalized class scores (logits) of the samples; no softmax is applied inside the model.
        '''
        x = self.conv1(x)
        x = self.pool1(x)

        # Run the three branches on the same stem features and pool each
        # result down to a single value per channel.
        x1 = self.pool2(self.ms_module_conv3x3(x))
        x2 = self.pool2(self.ms_module_conv5x5(x))
        x3 = self.pool2(self.ms_module_conv7x7(x))

        x = torch.cat([x1, x2, x3], dim=1)
        # Drop the two trailing 1x1 spatial dimensions while keeping the batch
        # dimension, even when the batch holds a single sample.
        x = torch.flatten(x, start_dim=1)
        x = self.cls(x)
        return x