Shortcuts

Source code for torcheeg.models.cnn.fbmsnet

import torch.nn.functional as F
from typing import Tuple, Optional
import torch
import torch.nn as nn
from .fbcnet import VarLayer, MaxLayer, StdLayer, LogVarLayer, LinearWithConstraint, MeanLayer, swish, Conv2dWithConstraint


## CONV_SAME_PADDING
def _calc_same_pad(i: int, k: int, s: int, d: int):
    return max((-(i // -s) - 1) * s + (k - 1) * d + 1 - i, 0)


def _same_pad_arg(input_size, kernel_size, stride, dilation, **_):
    """Build the ``F.pad`` argument list for 2D TensorFlow-style 'SAME' padding.

    Args:
        input_size: ``(height, width)`` of the input.
        kernel_size: ``(kernel_h, kernel_w)``.
        stride: indexable ``(stride_h, stride_w)``.
        dilation: indexable ``(dilation_h, dilation_w)``.

    Returns:
        ``[left, right, top, bottom]``; when the total padding is odd, the
        extra element goes to the right/bottom side.
    """
    in_h, in_w = input_size
    k_h, k_w = kernel_size
    total_h = _calc_same_pad(in_h, k_h, stride[0], dilation[0])
    total_w = _calc_same_pad(in_w, k_w, stride[1], dilation[1])
    left = total_w // 2
    top = total_h // 2
    return [left, total_w - left, top, total_h - top]


def conv2d_same(x,
                weight: torch.Tensor,
                bias: Optional[torch.Tensor] = None,
                stride: Tuple[int, int] = (1, 1),
                padding: Tuple[int, int] = (0, 0),
                dilation: Tuple[int, int] = (1, 1),
                groups: int = 1):
    """2D convolution with TensorFlow-style 'SAME' padding computed on the fly.

    The last two dimensions of ``x`` and ``weight`` are treated as
    ``(height, width)``; the input is zero-padded (odd amounts put the extra
    element on the right/bottom) and then convolved with no further padding.

    Args:
        x: input tensor.
        weight: convolution weight.
        bias: optional bias.
        stride: ``(stride_h, stride_w)``.
        padding: accepted only for signature compatibility; it is ignored,
            since the SAME padding is always recomputed from the shapes.
        dilation: ``(dilation_h, dilation_w)``.
        groups: number of convolution groups.

    Returns:
        The result of ``F.conv2d`` on the padded input.
    """
    pad_spec = _same_pad_arg(x.size()[-2:], weight.size()[-2:], stride,
                             dilation)
    padded = F.pad(x, pad_spec)
    return F.conv2d(padded, weight, bias, stride, (0, 0), dilation, groups)


class SamePadConv2d(nn.Conv2d):
    """2D convolution with TensorFlow-like 'SAME' padding.

    The underlying ``nn.Conv2d`` is created with ``padding=0``; the required
    padding is instead computed per call from the actual input size by
    :func:`conv2d_same`.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 dilation=1,
                 groups=1,
                 bias=True):
        super().__init__(in_channels,
                         out_channels,
                         kernel_size,
                         stride=stride,
                         padding=0,
                         dilation=dilation,
                         groups=groups,
                         bias=bias)

    def forward(self, x):
        """Pad ``x`` for 'SAME' output size and convolve it."""
        return conv2d_same(x,
                           self.weight,
                           self.bias,
                           stride=self.stride,
                           padding=self.padding,
                           dilation=self.dilation,
                           groups=self.groups)


## MIX_CONV
def create_conv2d_pad(in_chs, out_chs, kernel_size, **kwargs):
    """Create a 2D convolution whose padding is resolved by :func:`get_padding_value`.

    Args:
        in_chs: number of input channels.
        out_chs: number of output channels.
        kernel_size: kernel size (int or tuple).
        **kwargs: forwarded to the convolution constructor. ``padding`` is
            removed and interpreted separately (default ``''``); ``bias``
            defaults to ``False`` when not given.

    Returns:
        A :class:`SamePadConv2d` when the padding is dynamic, otherwise an
        ``nn.Conv2d`` with the static padding (applied only to the width
        axis when ``kernel_size`` is a tuple).
    """
    pad_spec = kwargs.pop('padding', '')
    kwargs.setdefault('bias', False)
    static_pad, dynamic = get_padding_value(pad_spec, kernel_size, **kwargs)
    if dynamic:
        return SamePadConv2d(in_chs, out_chs, kernel_size, **kwargs)
    # Tuple kernels here are (1, k) temporal kernels, so pad only the width.
    conv_pad = (0, static_pad) if isinstance(kernel_size, tuple) else static_pad
    return nn.Conv2d(in_chs, out_chs, kernel_size, padding=conv_pad, **kwargs)


def _is_static_pad(kernel_size, stride=1, dilation=1, **_):
    return stride == 1 and (dilation * (kernel_size - 1)) % 2 == 0


def _get_padding(kernel_size, stride=1, dilation=1, **_):
    if isinstance(kernel_size, tuple):
        kernel_size = max(kernel_size)
    padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2
    return padding


def _split_channels(num_chan, num_groups):
    split = [num_chan // num_groups for _ in range(num_groups)]
    split[0] += num_chan - sum(split)
    return split


def get_padding_value(padding, kernel_size, **kwargs):
    """Resolve a padding spec into ``(static_padding, is_dynamic)``.

    In this module the static padding is always ``0``:

    * the string ``'valid'`` (case-insensitive) means no padding,
      ``(0, False)``;
    * every other value -- ``'same'``, any other string (including ``''``),
      and non-string values alike -- selects dynamic TensorFlow-style 'SAME'
      padding, ``(0, True)``. Explicit numeric paddings are therefore not
      honoured.

    ``kernel_size`` and ``kwargs`` are accepted for interface compatibility
    and are not used.
    """
    is_valid = isinstance(padding, str) and padding.lower() == 'valid'
    if is_valid:
        return 0, False
    return 0, True


class MixedConv2d(nn.ModuleDict):
    """ Mixed Grouped Convolution

    Based on MDConv and GroupedConv in MixNet impl:
      https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mixnet/custom_layers.py

    The input channels are split into ``len(kernel_size)`` groups (see
    :func:`_split_channels`; any remainder goes to the first group). Each
    group is convolved with its own kernel, and the group outputs are
    concatenated along the channel dimension.

    Args:
        in_channels: total number of input channels.
        out_channels: total number of output channels, split the same way.
        kernel_size: a single kernel size, or a list with one kernel size per
            group.
        stride, padding, dilation: forwarded to :func:`create_conv2d_pad` for
            every group.
        depthwise: if True, each group's convolution uses ``groups`` equal to
            that group's number of input channels, i.e. a depthwise
            convolution. This requires each group's output channel count to
            be a multiple of its input channel count.
        **kwargs: extra keyword arguments forwarded to
            :func:`create_conv2d_pad`.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=3,
                 stride=1,
                 padding='same',
                 dilation=1,
                 depthwise=False,
                 **kwargs):
        super(MixedConv2d, self).__init__()

        kernel_size = kernel_size if isinstance(kernel_size,
                                                list) else [kernel_size]
        num_groups = len(kernel_size)
        in_splits = _split_channels(in_channels, num_groups)
        out_splits = _split_channels(out_channels, num_groups)
        self.in_channels = sum(in_splits)
        self.out_channels = sum(out_splits)

        for idx, (k, in_ch,
                  out_ch) in enumerate(zip(kernel_size, in_splits, out_splits)):
            # A depthwise convolution has one group per *input* channel.
            # Using out_ch here would crash whenever out_ch != in_ch,
            # e.g. for depth multipliers > 1.
            conv_groups = in_ch if depthwise else 1
            self.add_module(
                str(idx),
                create_conv2d_pad(in_ch,
                                  out_ch,
                                  k,
                                  stride=stride,
                                  padding=padding,
                                  dilation=dilation,
                                  groups=conv_groups,
                                  **kwargs))
        # Channel sizes used to split the input in forward().
        self.splits = in_splits

    def forward(self, x):
        """Apply each group's convolution to its channel slice and concatenate.

        ``x`` is split along dim 1 into ``self.splits`` chunks, which are
        paired with the sub-convolutions in registration order.
        """
        x_split = torch.split(x, self.splits, 1)
        x_out = [conv(chunk) for chunk, conv in zip(x_split, self.values())]
        return torch.cat(x_out, 1)


class FBMSNet(nn.Module):
    r'''
    FBMSNet, a novel multiscale temporal convolutional neural network for MI decoding tasks, employs Mixed Conv to extract multiscale temporal features which enhance the intra-class compactness and improve the inter-class separability with the joint supervision of the center loss and cross-entropy loss.

    - Paper: FBMSNet: A Filter-Bank Multi-Scale Convolutional Neural Network for EEG-Based Motor Imagery Decoding
    - URL: https://ieeexplore.ieee.org/document/9837422
    - Related Project: https://github.com/Want2Vanish/FBMSNet

    Below is an example to explain how to use this model. First, the EEG signal should be split into several non-overlapping frequency bands by :obj:`torcheeg.transforms.BandSignal`.

    .. code-block:: python

        from torcheeg.datasets import BCICIV2aDataset
        from torcheeg import transforms
        from torcheeg.models import FBMSNet
        from torch.utils.data import DataLoader

        freq_range_per_band = {
            'sub band1': [4, 8],
            'sub band2': [8, 12],
            'sub band3': [12, 16],
            'sub band4': [16, 20],
            'sub band5': [20, 24],
            'sub band6': [24, 28],
            'sub band7': [28, 32],
            'sub band8': [32, 36],
            'sub band9': [36, 40]
        }
        dataset = BCICIV2aDataset(root_path='./BCICIV_2a_mat',
                                  chunk_size=512,
                                  offline_transform=transforms.BandSignal(band_dict=freq_range_per_band,
                                                                          sampling_rate=250),
                                  online_transform=transforms.ToTensor(),
                                  label_transform=transforms.Compose(
                                      [transforms.Select('label'),
                                       transforms.Lambda(lambda x: x - 1)]))

        model = FBMSNet(num_classes=4,
                        num_electrodes=22,
                        chunk_size=512,
                        in_channels=9)

        x, y = next(iter(DataLoader(dataset, batch_size=64)))
        model(x)

    Args:
        in_channels (int): The number of channels of the signal corresponding to each electrode. If the original signal is used as input, in_channels is set to 1; if the original signal is split into multiple sub-bands, in_channels is set to the number of bands (:obj:`9` in the example above).
        num_electrodes (int): The number of electrodes.
        chunk_size (int): Number of data points included in each EEG chunk.
        num_classes (int): The number of classes to predict. (default: :obj:`4`)
        stride_factor (int): The number of equal-length temporal windows the feature map is divided into before the temporal layer. :obj:`chunk_size` must be a multiple of :obj:`stride_factor`; otherwise a :obj:`ValueError` is raised. (default: :obj:`4`)
        temporal (str): The temporal layer used, with options including VarLayer, StdLayer, LogVarLayer, MeanLayer, and MaxLayer, used to compute statistics using different techniques in the temporal dimension. Any other value raises :obj:`NotImplementedError`. (default: :obj:`LogVarLayer`)
        num_feature (int): The number of Mixed Conv output channels which can stand for various kinds of feature. (default: :obj:`36`)
        dilatability (int): The expansion multiple of the channels after the input bands pass through spatial convolutional blocks. (default: :obj:`8`)
    '''

    def __init__(self,
                 in_channels: int,
                 num_electrodes: int,
                 chunk_size: int,
                 num_classes: int = 4,
                 stride_factor: int = 4,
                 temporal: str = 'LogVarLayer',
                 num_feature: int = 36,
                 dilatability: int = 8):
        super(FBMSNet, self).__init__()
        self.num_classes = num_classes
        self.in_channels = in_channels
        self.num_electrodes = num_electrodes
        self.chunk_size = chunk_size
        self.stride_factor = stride_factor

        # Validate explicitly instead of wrapping construction in a bare
        # ``except:``. The old handler reported every failure (e.g. an unknown
        # ``temporal``) as a chunk-size problem and discarded the real cause.
        # ValueError/NotImplementedError still subclass Exception, so existing
        # ``except Exception`` handlers keep working.
        if stride_factor <= 0:
            raise ValueError(
                "Model init failed: stride_factor must be a positive integer, "
                f"got {stride_factor}.")
        if chunk_size % stride_factor != 0:
            raise ValueError(
                "Model init failed: The Chunksize must be a multiple of "
                f"stride_factor (got chunk_size={chunk_size}, "
                f"stride_factor={stride_factor}). Please modify values of "
                "stride_factor or chunk_size.")

        # Temporal aggregator options, keyed by the accepted `temporal` names.
        temporal_layers = {
            'VarLayer': VarLayer,
            'StdLayer': StdLayer,
            'LogVarLayer': LogVarLayer,
            'MeanLayer': MeanLayer,
            'MaxLayer': MaxLayer,
        }
        if temporal not in temporal_layers:
            raise NotImplementedError(
                f"Unsupported temporal layer '{temporal}'; expected one of "
                f"{sorted(temporal_layers)}.")

        # Submodules are created in the original order so parameter
        # initialisation (and state_dict key order) is unchanged.
        self.mixConv2d = nn.Sequential(
            MixedConv2d(in_channels=in_channels,
                        out_channels=num_feature,
                        kernel_size=[(1, 15), (1, 31), (1, 63), (1, 125)],
                        stride=1,
                        padding='',
                        dilation=1,
                        depthwise=False),
            nn.BatchNorm2d(num_feature),
        )
        self.scb = self.SCB(in_chan=num_feature,
                            out_chan=num_feature * dilatability,
                            num_electrodes=int(num_electrodes))
        # Temporal aggregator, applied along dim 3 (within each window).
        self.temporal_layer = temporal_layers[temporal](dim=3)

        self.center_dim = self.feature_dim(in_channels, num_electrodes,
                                           chunk_size)[-1]
        self.fc = self.LastBlock(self.center_dim, num_classes)

    def SCB(self,
            in_chan,
            out_chan,
            num_electrodes,
            weight_norm=True,
            *args,
            **kwargs):
        """Spatial convolution block.

        A constrained convolution with a ``(num_electrodes, 1)`` kernel and
        ``groups=in_chan``, followed by BatchNorm and swish.
        """
        return nn.Sequential(
            Conv2dWithConstraint(in_chan,
                                 out_chan, (num_electrodes, 1),
                                 groups=in_chan,
                                 max_norm=2,
                                 weight_norm=weight_norm,
                                 padding=0), nn.BatchNorm2d(out_chan), swish())

    def LastBlock(self, inF, outF, weight_norm=True, *args, **kwargs):
        """Classifier head: constrained linear layer followed by LogSoftmax over classes."""
        return nn.Sequential(
            LinearWithConstraint(inF,
                                 outF,
                                 max_norm=0.5,
                                 weight_norm=weight_norm,
                                 *args,
                                 **kwargs), nn.LogSoftmax(dim=1))

    def _forward_features(self, x: torch.Tensor) -> torch.Tensor:
        """Run the network up to (excluding) the classifier head.

        Returns the flattened temporal features, ``[n, num_features]``.
        """
        x = self.mixConv2d(x)
        x = self.scb(x)
        # Divide the time axis into `stride_factor` windows. The window length
        # is computed from dim 3 explicitly (not -1). If the SCB output height
        # is not 1 (an input with more electrodes than num_electrodes), the
        # reshape then fails loudly instead of silently folding the extra rows
        # into the windows.
        x = x.reshape([
            *x.shape[0:2], self.stride_factor,
            x.shape[3] // self.stride_factor
        ])
        x = self.temporal_layer(x)
        return torch.flatten(x, start_dim=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        r'''
        Args:
            x (torch.Tensor): EEG signal representation, the ideal input shape is :obj:`[n, in_channel, num_electrodes, chunk_size]`. Here, :obj:`n` corresponds to the batch size.

        Returns:
            torch.Tensor[size of batch, number of classes]: The log-probabilities (output of :obj:`LogSoftmax`) that the samples belong to the classes.
        '''
        return self.fc(self._forward_features(x))

    def feature_dim(self, in_channels, num_electrodes, chunk_size):
        """Return the size of the flattened features for a dummy input of the given shape.

        The probe runs in eval mode under ``torch.no_grad()``. This keeps the
        dummy all-ones batch from updating the BatchNorm running statistics
        and from building an autograd graph. The previous training/eval mode
        is restored afterwards.
        """
        data = torch.ones((1, in_channels, num_electrodes, chunk_size))
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                x = self._forward_features(data)
        finally:
            self.train(was_training)
        return x.size()

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources