FBMSNet¶
- class torcheeg.models.FBMSNet(in_channels: int, num_electrodes: int, chunk_size: int, num_classes: int = 4, stride_factor: int = 4, temporal: str = 'LogVarLayer', num_feature: int = 36, dilatability: int = 8)[source]¶
FBMSNet, a novel multiscale temporal convolutional neural network for MI decoding tasks, employs MixedConv to extract multiscale temporal features, which enhance the intra-class compactness and improve the inter-class separability under the joint supervision of the center loss and cross-entropy loss.
Paper: FBMSNet: A Filter-Bank Multi-Scale Convolutional Neural Network for EEG-Based Motor Imagery Decoding
Related Project: https://github.com/Want2Vanish/FBMSNet
Below is an example of how to use this model. First, we transform the EEG signal into several non-overlapping frequency bands with torcheeg.transforms.BandSignal:
```python
from torcheeg.datasets import BCICIV2aDataset
from torcheeg import transforms
from torcheeg.models import FBMSNet
from torch.utils.data import DataLoader

freq_range_per_band = {
    'sub band1': [4, 8],
    'sub band2': [8, 12],
    'sub band3': [12, 16],
    'sub band4': [16, 20],
    'sub band5': [20, 24],
    'sub band6': [24, 28],
    'sub band7': [28, 32],
    'sub band8': [32, 36],
    'sub band9': [36, 40]
}

dataset = BCICIV2aDataset(root_path='./BCICIV_2a_mat',
                          chunk_size=512,
                          offline_transform=transforms.BandSignal(
                              band_dict=freq_range_per_band,
                              sampling_rate=250),
                          online_transform=transforms.ToTensor(),
                          label_transform=transforms.Compose([
                              transforms.Select('label'),
                              transforms.Lambda(lambda x: x - 1)
                          ]))

model = FBMSNet(num_classes=4,
                num_electrodes=22,
                chunk_size=512,
                in_channels=9)

x, y = next(iter(DataLoader(dataset, batch_size=64)))
model(x)
```
- Parameters:
  - num_electrodes (int) – The number of electrodes.
  - chunk_size (int) – The number of data points included in each EEG chunk.
  - in_channels (int) – The number of channels of the signal corresponding to each electrode. If the original signal is used as input, in_channels should be set to 1; if the original signal is split into multiple sub-bands, in_channels should be set to the number of bands. (default: 9)
  - num_classes (int) – The number of classes to predict. (default: 4)
  - stride_factor (int) – The stride factor. Please make sure chunk_size is a multiple of stride_factor so that the model can be initialized successfully. (default: 4)
  - temporal (str) – The temporal layer used, with options including VarLayer, StdLayer, LogVarLayer, MeanLayer, and MaxLayer, which compute statistics over the temporal dimension using different techniques. (default: 'LogVarLayer')
  - num_feature (int) – The number of MixedConv output channels, each of which can stand for a kind of feature. (default: 36)
  - dilatability (int) – The expansion multiple of the channels after the input bands pass through the spatial convolutional block. (default: 8)
- forward(x)[source]¶
  - Parameters:
    - x (torch.Tensor) – EEG signal representation. The ideal input shape is [n, in_channels, num_electrodes, chunk_size]. Here, n corresponds to the batch size.
  - Returns:
    The predicted probability that the samples belong to the classes.
  - Return type:
    torch.Tensor[batch size, number of classes]
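As a sanity check on the shapes above, here is a small sketch (the helper name fbmsnet_output_shape is illustrative, not a torcheeg API) that maps the documented input shape to the expected output shape:

```python
def fbmsnet_output_shape(input_shape, num_classes=4):
    """Given an input of shape [n, in_channels, num_electrodes, chunk_size],
    forward() produces class scores of shape [n, num_classes]."""
    n, in_channels, num_electrodes, chunk_size = input_shape
    return (n, num_classes)

# a batch of 64 samples, 9 sub-bands, 22 electrodes, 512 time points
print(fbmsnet_output_shape((64, 9, 22, 512)))  # → (64, 4)
```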