FBMSNet
- class torcheeg.models.FBMSNet(in_channels: int, num_electrodes: int, chunk_size: int, num_classes: int = 4, stride_factor: int = 4, temporal: str = 'LogVarLayer', num_feature: int = 36, dilatability: int = 8)
FBMSNet is a novel multiscale temporal convolutional neural network for motor imagery (MI) decoding. It employs Mixed Conv to extract multiscale temporal features and, with the joint supervision of the center loss and the cross-entropy loss, enhances intra-class compactness and improves inter-class separability.
Paper: FBMSNet: A Filter-Bank Multi-Scale Convolutional Neural Network for EEG-Based Motor Imagery Decoding
Related Project: https://github.com/Want2Vanish/FBMSNet
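The joint supervision described above amounts to adding a weighted center-loss term to the usual cross-entropy objective. The following is a minimal NumPy sketch of that combination, not the training code of TorchEEG or of the paper: the function names and the weight lam are hypothetical, and the class centers are taken as given rather than learned.

```python
import numpy as np

def cross_entropy(log_probs, labels):
    """Mean negative log-likelihood; log_probs has shape [n, num_classes]."""
    labels = np.asarray(labels)
    return float(-np.mean(log_probs[np.arange(len(labels)), labels]))

def center_loss(features, labels, centers):
    """Half the mean squared distance between each feature and its class center."""
    diffs = features - centers[np.asarray(labels)]  # [n, feature_dim]
    return float(0.5 * np.mean(np.sum(diffs ** 2, axis=1)))

def joint_loss(log_probs, features, labels, centers, lam=0.01):
    # lam is a hypothetical trade-off weight, not a value taken from the paper.
    return cross_entropy(log_probs, labels) + lam * center_loss(features, labels, centers)

# Toy batch: 2 samples, 2 classes, 2-D features.
log_probs = np.log(np.array([[0.5, 0.5], [0.25, 0.75]]))
features = np.array([[2.0, 0.0], [0.0, 1.0]])
centers = np.array([[0.0, 0.0], [0.0, 1.0]])
labels = np.array([0, 1])
loss = joint_loss(log_probs, features, labels, centers)
```

The center term is zero only when every feature sits on its class center, which is how it pulls same-class features together (intra-class compactness) while cross-entropy keeps the classes apart. In real training the centers are updated as well; conventions also differ on whether the center term is summed or averaged over the batch.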
Below is an example of how to use this model. First, split the EEG signal into several non-overlapping frequency bands with torcheeg.transforms.BandSignal:

```python
from torch.utils.data import DataLoader
from torcheeg import transforms
from torcheeg.datasets import BCICIV2aDataset
from torcheeg.models import FBMSNet

# Define 9 non-overlapping frequency bands, each 4 Hz wide, spanning 4 to 40 Hz.
freq_range_per_band = {
    'sub band1': [4, 8], 'sub band2': [8, 12], 'sub band3': [12, 16],
    'sub band4': [16, 20], 'sub band5': [20, 24], 'sub band6': [24, 28],
    'sub band7': [28, 32], 'sub band8': [32, 36], 'sub band9': [36, 40],
}

dataset = BCICIV2aDataset(
    io_path='./tmp_out/bciciv2a/band_9_filters',
    root_path='./BCICIV_2a_mat',
    chunk_size=512,
    # Band-pass each chunk into the 9 sub-bands defined above.
    offline_transform=transforms.BandSignal(band_dict=freq_range_per_band, sampling_rate=250),
    online_transform=transforms.ToTensor(),
    label_transform=transforms.Compose([
        transforms.Select('label'),
        # Labels are 1-4; shift them to 0-3 to match num_classes=4.
        transforms.Lambda(lambda x: x - 1),
    ]))
data = DataLoader(dataset)

# in_channels equals the number of sub-bands.
model = FBMSNet(num_classes=4, num_electrodes=22, chunk_size=512, in_channels=9)
```
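The nine bands in the example follow a regular 4 Hz grid, so the band dictionary can also be generated rather than typed out. A minimal sketch; make_band_dict is a hypothetical helper, not part of TorchEEG, and it assumes (high - low) is a multiple of width:

```python
def make_band_dict(low=4, high=40, width=4):
    """Return {'sub band1': [low, low + width], ...} tiling [low, high) with equal bands.

    Assumes (high - low) is a multiple of width; otherwise the last band overshoots high.
    """
    return {
        f'sub band{i + 1}': [start, start + width]
        for i, start in enumerate(range(low, high, width))
    }

freq_range_per_band = make_band_dict()
# One input channel per band, so this is the value to pass as in_channels.
in_channels = len(freq_range_per_band)
```

The generated dictionary has the same keys and ranges as the literal one in the example, so it can be passed as band_dict unchanged, and len() of it gives the in_channels value.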
There are two ways to use the model. The first, calling the model directly, is the most direct way to get predictions, but it does not output the decoded features:
```python
x, y = next(iter(data))
pred = model(x)
```
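The pred returned by model(x) holds one score per class for each sample. To turn it into class labels and a batch accuracy, take the arg-max over the class axis. A NumPy sketch with hypothetical helper names; with a PyTorch tensor, pred.argmax(dim=1) gives the same labels, or the tensor can be converted first with pred.detach().cpu().numpy():

```python
import numpy as np

def predicted_labels(pred):
    """Index of the highest-scoring class for each row of a [n, num_classes] array."""
    return np.argmax(np.asarray(pred), axis=1)

def accuracy(pred, y):
    """Fraction of samples whose predicted label equals the target label."""
    return float(np.mean(predicted_labels(pred) == np.asarray(y)))

# Toy scores for 2 samples and 4 classes, standing in for the model output.
pred = np.array([[0.1, 0.7, 0.1, 0.1],
                 [0.6, 0.2, 0.1, 0.1]])
y = np.array([1, 2])
```

Because the arg-max is unchanged by any monotonic transform, the same code works whether the scores are probabilities or log-probabilities.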
To obtain the decoded features, use the decoder method. To obtain predictions from those features, pass them to the classifier method:

```python
x, y = next(iter(data))
feature = model.decoder(x)
pred = model.classifier(feature)
```
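The second approach relies on forward(x) being the composition classifier(decoder(x)), which lets the intermediate features be reused (for example, for the center loss). The toy class below only illustrates that contract: ToyTwoStageModel is hypothetical, and its tanh encoder and softmax head stand in for FBMSNet's actual convolutional layers.

```python
import numpy as np

class ToyTwoStageModel:
    """Hypothetical stand-in mirroring the decoder / classifier / forward split."""

    def __init__(self, in_dim, feature_dim, num_classes, seed=0):
        rng = np.random.default_rng(seed)
        self.w_dec = rng.standard_normal((in_dim, feature_dim))
        self.w_cls = rng.standard_normal((feature_dim, num_classes))

    def decoder(self, x):
        # [n, in_dim] -> [n, feature_dim]
        return np.tanh(x @ self.w_dec)

    def classifier(self, feature):
        # [n, feature_dim] -> class probabilities via a numerically stable softmax
        logits = feature @ self.w_cls
        logits = logits - logits.max(axis=1, keepdims=True)
        exp = np.exp(logits)
        return exp / exp.sum(axis=1, keepdims=True)

    def forward(self, x):
        return self.classifier(self.decoder(x))

x = np.random.default_rng(1).standard_normal((3, 8))
toy = ToyTwoStageModel(in_dim=8, feature_dim=5, num_classes=4)
features = toy.decoder(x)        # [3, 5]: reusable intermediate features
pred = toy.classifier(features)  # same result as toy.forward(x)
```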
- Parameters:
num_electrodes (int) – The number of electrodes.
chunk_size (int) – The number of data points in each EEG chunk.
in_channels (int) – The number of channels of the signal for each electrode. Set it to 1 when the original signal is used as input, or to the number of sub-bands when the signal is split into multiple sub-bands. (required; no default)
num_classes (int) – The number of classes to predict. (default: 4)
stride_factor (int) – The stride factor. chunk_size must be a multiple of stride_factor; otherwise the model cannot be initialized. (default: 4)
temporal (str) – The temporal layer, which computes a statistic along the temporal dimension. Options are VarLayer, StdLayer, LogVarLayer, MeanLayer, and MaxLayer. (default: LogVarLayer)
num_feature (int) – The number of output channels of the Mixed Conv layer; each channel can be viewed as a different kind of feature. (default: 36)
dilatability (int) – The factor by which the number of channels is expanded in the spatial convolution blocks. (default: 8)
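The stride_factor and temporal parameters can be illustrated with a NumPy sketch. It assumes (our reading, not stated on this page) that the time axis of the convolution output, taken here to have length chunk_size, is split into stride_factor equal, non-overlapping windows and that the chosen statistic is computed in each window. TEMPORAL_STATS and temporal_features are hypothetical names, and the variance convention (ddof=1) and the LogVarLayer clamp range are assumptions.

```python
import numpy as np

# Hypothetical NumPy stand-ins for the temporal layers; each reduces the last axis.
# ddof=1 (sample variance) and the LogVarLayer clamp range are assumptions.
TEMPORAL_STATS = {
    'VarLayer': lambda w: np.var(w, axis=-1, ddof=1),
    'StdLayer': lambda w: np.std(w, axis=-1, ddof=1),
    'LogVarLayer': lambda w: np.log(np.clip(np.var(w, axis=-1, ddof=1), 1e-6, 1e6)),
    'MeanLayer': lambda w: np.mean(w, axis=-1),
    'MaxLayer': lambda w: np.max(w, axis=-1),
}

def temporal_features(x, stride_factor=4, temporal='LogVarLayer'):
    """Map x of shape [n, channels, length] to [n, channels * stride_factor]."""
    n, channels, length = x.shape
    if length % stride_factor != 0:
        raise ValueError('length must be a multiple of stride_factor')
    # Split the time axis into stride_factor contiguous, equal-length windows.
    windows = x.reshape(n, channels, stride_factor, length // stride_factor)
    stats = TEMPORAL_STATS[temporal](windows)  # [n, channels, stride_factor]
    return stats.reshape(n, channels * stride_factor)

# Toy input: 2 samples, 3 channels, 512 time points.
x = np.random.default_rng(0).standard_normal((2, 3, 512))
feats = temporal_features(x, stride_factor=4, temporal='LogVarLayer')  # shape (2, 12)
```

Under this reading, the divisibility check mirrors the chunk_size constraint on stride_factor: the reshape into equal windows is only possible when the length divides evenly.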
- decoder(x)
- Parameters:
x (torch.Tensor) – EEG signal representation; the expected input shape is [n, in_channels, num_electrodes, chunk_size], where n is the batch size and in_channels is the number of sub-bands.
- Returns:
The extracted deep features.
- Return type:
torch.Tensor[size of batch, length of deep feature code]
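Before calling decoder (or forward), it can help to check the input layout, since a single sample from the dataset, assumed here to have shape [in_channels, num_electrodes, chunk_size], carries no batch axis. A minimal sketch with a hypothetical helper, check_fbmsnet_input; it relies only on .ndim and .shape, so it should work for NumPy arrays and PyTorch tensors alike (for a tensor, add the batch axis with x.unsqueeze(0)):

```python
import numpy as np

def check_fbmsnet_input(x, in_channels, num_electrodes, chunk_size):
    """Return the batch size n if x has shape [n, in_channels, num_electrodes, chunk_size]."""
    if x.ndim != 4:
        raise ValueError(f'expected a 4-D input, got {x.ndim}-D')
    n, c, e, t = x.shape
    if (c, e, t) != (in_channels, num_electrodes, chunk_size):
        raise ValueError(f'expected [n, {in_channels}, {num_electrodes}, {chunk_size}], '
                         f'got {list(x.shape)}')
    return n

# A single band-split sample lacks the batch axis; add it before calling the model.
sample = np.zeros((9, 22, 512))
batch = sample[np.newaxis]  # shape (1, 9, 22, 512)
```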
- classifier(feature)
Given the features output by decoder, the classifier outputs the predicted probability that each sample belongs to each class.
- Parameters:
feature (torch.Tensor) – The extracted deep features, with expected shape [batch size, 1152], where 1152 is the feature dimension under the default settings.
- Returns:
The predicted probability that the samples belong to the classes.
- Return type:
torch.Tensor[size of batch, num_classes]
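The default feature dimension of 1152 is consistent with num_feature × dilatability × stride_factor = 36 × 8 × 4, that is, one temporal statistic per expanded channel per time window. Assuming that relationship holds for other settings (our inference, not stated in this reference), the expected feature length can be computed with a hypothetical helper like the one below; when in doubt, check model.decoder(x).shape[1] directly.

```python
def fbmsnet_feature_dim(num_feature=36, dilatability=8, stride_factor=4):
    """Assumed decoder output length: one statistic per expanded channel per time window."""
    expanded_channels = num_feature * dilatability  # channels after the spatial blocks
    return expanded_channels * stride_factor
```

Under this assumption, chunk_size does not enter the formula, because the temporal statistic collapses each window to a single value.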
- forward(x)
- Parameters:
x (torch.Tensor) – EEG signal representation; the expected input shape is [n, in_channels, num_electrodes, chunk_size], where n is the batch size.
- Returns:
The predicted probability that the samples belong to the classes.
- Return type:
torch.Tensor[size of batch, number of classes]