FBMSNet

class torcheeg.models.FBMSNet(in_channels: int, num_electrodes: int, chunk_size: int, num_classes: int = 4, stride_factor: int = 4, temporal: str = 'LogVarLayer', num_feature: int = 36, dilatability: int = 8)[source][source]

FBMSNet, a novel multiscale temporal convolutional neural network for MI decoding tasks, employs mixed depthwise convolution (MixConv) to extract multiscale temporal features, which enhance intra-class compactness and improve inter-class separability under the joint supervision of the center loss and the cross-entropy loss.

Below is an example of how to use this model. First, transform the EEG signal into several non-overlapping frequency bands with torcheeg.transforms.BandSignal:

from torcheeg.datasets import BCICIV2aDataset
from torcheeg import transforms
from torcheeg.models import FBMSNet
from torch.utils.data import DataLoader

freq_range_per_band = {
    'sub band1': [4, 8],
    'sub band2': [8, 12],
    'sub band3': [12, 16],
    'sub band4': [16, 20],
    'sub band5': [20, 24],
    'sub band6': [24, 28],
    'sub band7': [28, 32],
    'sub band8': [32, 36],
    'sub band9': [36, 40]
}
dataset = BCICIV2aDataset(root_path='./BCICIV_2a_mat',
                          chunk_size=512,
                          offline_transform=transforms.BandSignal(band_dict=freq_range_per_band,
                                                                  sampling_rate=250),
                          online_transform=transforms.ToTensor(),
                          label_transform=transforms.Compose([
                              transforms.Select('label'),
                              transforms.Lambda(lambda x: x - 1)
                          ]))

model = FBMSNet(num_classes=4, num_electrodes=22, chunk_size=512, in_channels=9)

x, y = next(iter(DataLoader(dataset, batch_size=64)))
model(x)
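Since FBMSNet expects one input channel per sub-band, the `in_channels=9` passed to the model must match the number of entries in `freq_range_per_band`. A minimal, torcheeg-free sketch of that consistency check, reproducing the nine contiguous 4 Hz bands from the example above:

```python
# Nine contiguous 4 Hz sub-bands covering 4-40 Hz, as in the example above.
freq_range_per_band = {
    f'sub band{i + 1}': [4 + 4 * i, 8 + 4 * i] for i in range(9)
}

# in_channels passed to FBMSNet must equal the number of sub-bands,
# because BandSignal stacks one band-filtered copy of the signal per band.
in_channels = len(freq_range_per_band)

# Sanity-check that the bands tile 4-40 Hz without gaps or overlaps.
bands = list(freq_range_per_band.values())
assert all(lo < hi for lo, hi in bands)
assert all(bands[i][1] == bands[i + 1][0] for i in range(len(bands) - 1))
print(in_channels)  # 9
```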
Parameters:
  • num_electrodes (int) – The number of electrodes.

  • chunk_size (int) – Number of data points included in each EEG chunk.

  • in_channels (int) – The number of channels of the signal corresponding to each electrode. If the original signal is used as input, in_channels is set to 1; if the original signal is split into multiple sub-bands, in_channels is set to the number of bands. (default: 9)

  • num_classes (int) – The number of classes to predict. (default: 4)

  • stride_factor (int) – The stride factor. chunk_size must be a multiple of stride_factor for the model to initialize successfully. (default: 4)

  • temporal (str) – The temporal layer used, with options including VarLayer, StdLayer, LogVarLayer, MeanLayer, and MaxLayer, used to compute statistics using different techniques in the temporal dimension. (default: LogVarLayer)

  • num_feature (int) – The number of MixConv output channels, each representing a different kind of feature. (default: 36)

  • dilatability (int) – The expansion factor applied to the channels after the input bands pass through the spatial convolutional blocks. (default: 8)
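The stride_factor constraint and the temporal layer can be illustrated without torcheeg. Assuming LogVarLayer computes the log of the variance within each of stride_factor temporal segments (a reading of the layer based on its name, not on the source), a pure-Python sketch:

```python
import math
import statistics

chunk_size = 512
stride_factor = 4

# chunk_size must be a multiple of stride_factor so the temporal axis
# can be reshaped into stride_factor equal-length segments.
assert chunk_size % stride_factor == 0
segment_len = chunk_size // stride_factor  # 128 points per segment

# A toy single-channel signal standing in for one electrode's chunk.
signal = [math.sin(0.05 * t) for t in range(chunk_size)]

# Assumed behaviour of LogVarLayer: log-variance per temporal segment,
# reducing chunk_size points to stride_factor statistics.
segments = [signal[i * segment_len:(i + 1) * segment_len]
            for i in range(stride_factor)]
log_var = [math.log(statistics.pvariance(seg)) for seg in segments]

assert len(log_var) == stride_factor
```

If chunk_size were not divisible by stride_factor, the reshape above would be impossible, which is why the docstring requires the divisibility condition at model init.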

forward(x)[source][source]
Parameters:

x (torch.Tensor) – EEG signal representation; the expected input shape is [n, in_channels, num_electrodes, chunk_size], where n is the batch size.

Returns:

The predicted probability that the samples belong to the classes.

Return type:

torch.Tensor of shape [batch size, number of classes]
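Since forward returns one score per class for each sample, class predictions are just an arg-max over the last dimension. A library-free sketch with a hypothetical 2-sample output of shape [batch size, number of classes]:

```python
# Hypothetical model output for a batch of 2 samples and 4 classes.
scores = [
    [0.1, 2.3, 0.4, -1.0],  # sample 0: class 1 scores highest
    [1.5, 0.2, 0.1, 0.9],   # sample 1: class 0 scores highest
]

# Predicted label per sample = index of the largest score.
predictions = [row.index(max(row)) for row in scores]
print(predictions)  # [1, 0]
```

With a real model, the equivalent call is `model(x).argmax(dim=-1)` on the returned tensor.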
