Shortcuts

Source code for torcheeg.models.cnn.fbmsnet

import torch.nn.functional as F
from typing import Tuple, Optional
import torch
import torch.nn as nn
from .fbcnet import VarLayer, MaxLayer, StdLayer, LogVarLayer, LinearWithConstraint, MeanLayer, swish, Conv2dWithConstraint


## CONV_SAME_PADDING
def _calc_same_pad(i: int, k: int, s: int, d: int):
    return max((-(i // -s) - 1) * s + (k - 1) * d + 1 - i, 0)


def _same_pad_arg(input_size, kernel_size, stride, dilation, **_):
    """Return the ``F.pad`` argument list for TensorFlow-style 'SAME' padding.

    Args:
        input_size: (height, width) of the input.
        kernel_size: (height, width) of the kernel.
        stride: (height, width) stride.
        dilation: (height, width) dilation.

    Returns:
        ``[left, right, top, bottom]``; when a total is odd, the extra pixel
        goes on the right/bottom side.
    """
    (in_h, in_w), (ker_h, ker_w) = input_size, kernel_size
    total_h = _calc_same_pad(in_h, ker_h, stride[0], dilation[0])
    total_w = _calc_same_pad(in_w, ker_w, stride[1], dilation[1])
    left, top = total_w // 2, total_h // 2
    return [left, total_w - left, top, total_h - top]


def conv2d_same(x,
                weight: torch.Tensor,
                bias: Optional[torch.Tensor] = None,
                stride: Tuple[int, int] = (1, 1),
                padding: Tuple[int, int] = (0, 0),
                dilation: Tuple[int, int] = (1, 1),
                groups: int = 1):
    """Apply ``F.conv2d`` with TensorFlow-style 'SAME' padding.

    The padding is computed from the input's last two dimensions and the
    kernel's last two dimensions, applied with ``F.pad`` (extra pixel on the
    right/bottom when odd), and the convolution then runs with zero padding.

    Note:
        ``padding`` is accepted only to mirror the ``F.conv2d`` signature; it
        is ignored because the SAME padding replaces it.
    """
    pad_arg = _same_pad_arg(x.size()[-2:], weight.size()[-2:], stride, dilation)
    padded = F.pad(x, pad_arg)
    return F.conv2d(padded, weight, bias, stride, (0, 0), dilation, groups)


class SamePadConv2d(nn.Conv2d):
    """2D convolution with TensorFlow-like 'SAME' padding.

    The underlying ``nn.Conv2d`` is built with ``padding=0``. The actual,
    possibly asymmetric, padding is recomputed from the input size on every
    forward call by :func:`conv2d_same`.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 dilation=1,
                 groups=1,
                 bias=True):
        super().__init__(in_channels,
                         out_channels,
                         kernel_size,
                         stride=stride,
                         padding=0,
                         dilation=dilation,
                         groups=groups,
                         bias=bias)

    def forward(self, x):
        # self.padding is (0, 0); conv2d_same ignores it and pads dynamically.
        return conv2d_same(x,
                           self.weight,
                           self.bias,
                           stride=self.stride,
                           padding=self.padding,
                           dilation=self.dilation,
                           groups=self.groups)


## MIX_CONV
def create_conv2d_pad(in_chs, out_chs, kernel_size, **kwargs):
    """Build the convolution used for a single MixedConv2d group.

    ``padding`` (default ``''``) is removed from ``kwargs`` and resolved by
    :func:`get_padding_value`; ``bias`` defaults to ``False`` unless supplied.
    If dynamic padding is requested, a :class:`SamePadConv2d` is returned;
    otherwise a plain ``nn.Conv2d`` with the resolved static padding.
    """
    requested = kwargs.pop('padding', '')
    if 'bias' not in kwargs:
        kwargs['bias'] = False
    static_pad, use_same = get_padding_value(requested, kernel_size, **kwargs)
    if use_same:
        return SamePadConv2d(in_chs, out_chs, kernel_size, **kwargs)
    # For tuple kernels only the last axis is padded, which presumes a
    # (1, k)-shaped kernel.
    if isinstance(kernel_size, tuple):
        static_pad = (0, static_pad)
    return nn.Conv2d(in_chs, out_chs, kernel_size, padding=static_pad, **kwargs)


def _is_static_pad(kernel_size, stride=1, dilation=1, **_):
    return stride == 1 and (dilation * (kernel_size - 1)) % 2 == 0


def _get_padding(kernel_size, stride=1, dilation=1, **_):
    if isinstance(kernel_size, tuple):
        kernel_size = max(kernel_size)
    padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2
    return padding


def _split_channels(num_chan, num_groups):
    split = [num_chan // num_groups for _ in range(num_groups)]
    split[0] += num_chan - sum(split)
    return split


def get_padding_value(padding, kernel_size, **kwargs):
    """Resolve a padding spec into ``(static_padding, is_dynamic)``.

    In this module the static padding is always ``0``. Only the string
    ``'valid'`` (case-insensitive) selects static zero padding. Every other
    value, including ``'same'``, ``''``, any other string, and non-string
    values such as ints, selects dynamic TensorFlow-style 'SAME' padding.

    ``kernel_size`` and ``kwargs`` are accepted for interface compatibility
    and are not used.
    """
    is_valid = isinstance(padding, str) and padding.lower() == 'valid'
    return 0, not is_valid


class MixedConv2d(nn.ModuleDict):
    """ Mixed Grouped Convolution

    Based on MDConv and GroupedConv in MixNet impl:
      https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mixnet/custom_layers.py

    The input channels are split into ``len(kernel_size)`` groups (see
    ``_split_channels``: near-equal sizes, with the first group absorbing any
    remainder). Each group is convolved with its own kernel size, and the
    per-group outputs are concatenated along dim 1.

    Args:
        in_channels: total number of input channels.
        out_channels: total number of output channels.
        kernel_size: a single kernel size, or a list of kernel sizes (one
            group per entry).
        stride, padding, dilation: forwarded to every per-group convolution.
        depthwise: if True, each per-group convolution uses ``groups`` equal
            to that group's number of input channels (depthwise convolution).
        **kwargs: forwarded to ``create_conv2d_pad``.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=3,
                 stride=1,
                 padding='same',
                 dilation=1,
                 depthwise=False,
                 **kwargs):
        super(MixedConv2d, self).__init__()

        kernel_size = kernel_size if isinstance(kernel_size,
                                                list) else [kernel_size]
        num_groups = len(kernel_size)
        in_splits = _split_channels(in_channels, num_groups)
        out_splits = _split_channels(out_channels, num_groups)
        self.in_channels = sum(in_splits)
        self.out_channels = sum(out_splits)

        for idx, (k, in_ch,
                  out_ch) in enumerate(zip(kernel_size, in_splits, out_splits)):
            # Depthwise conv needs one group per *input* channel; nn.Conv2d
            # requires in_channels % groups == 0, which groups=out_ch breaks
            # whenever out_ch does not divide in_ch.
            conv_groups = in_ch if depthwise else 1
            self.add_module(
                str(idx),
                create_conv2d_pad(in_ch,
                                  out_ch,
                                  k,
                                  stride=stride,
                                  padding=padding,
                                  dilation=dilation,
                                  groups=conv_groups,
                                  **kwargs))
        # Per-group input channel counts used to split x in forward().
        self.splits = in_splits

    def forward(self, x):
        # Split along the channel dim (1) and pair each chunk with its conv,
        # in insertion order.
        x_split = torch.split(x, self.splits, 1)
        x_out = [conv(chunk) for chunk, conv in zip(x_split, self.values())]
        return torch.cat(x_out, 1)


class FBMSNet(nn.Module):
    r'''
    FBMSNet, a novel multiscale temporal convolutional neural network for MI decoding tasks, employs Mixed Conv to extract multiscale temporal features which enhance the intra-class compactness and improve the inter-class separability with the joint supervision of the cross-entropy loss and center loss.

    - Paper: FBMSNet: A Filter-Bank Multi-Scale Convolutional Neural Network for EEG-Based Motor Imagery Decoding
    - URL: https://ieeexplore.ieee.org/document/9837422
    - Related Project: https://github.com/Want2Vanish/FBMSNet

    Below is an example to explain how to use this model. Firstly we should transform eeg signal to several non-overlapping frequency bands by :obj:`torcheeg.transforms.BandSignal`

    .. code-block:: python

        from torcheeg.datasets import BCICIV2aDataset
        from torcheeg import transforms
        from torcheeg.models import FBMSNet
        from torch.utils.data import DataLoader

        freq_range_per_band = {
            'sub band1': [4, 8],
            'sub band2': [8, 12],
            'sub band3': [12, 16],
            'sub band4': [16, 20],
            'sub band5': [20, 24],
            'sub band6': [24, 28],
            'sub band7': [28, 32],
            'sub band8': [32, 36],
            'sub band9': [36, 40]
        }
        dataset = BCICIV2aDataset(root_path='./BCICIV_2a_mat',
                                  chunk_size=512,
                                  offline_transform=transforms.BandSignal(band_dict=freq_range_per_band,
                                                                          sampling_rate=250),
                                  online_transform=transforms.ToTensor(),
                                  label_transform=transforms.Compose([
                                      transforms.Select('label'),
                                      transforms.Lambda(lambda x: x - 1)
                                  ]))

        model = FBMSNet(num_classes=4,
                        num_electrodes=22,
                        chunk_size=512,
                        in_channels=9)

        x, y = next(iter(DataLoader(dataset, batch_size=64)))
        model(x)

    Args:
        num_electrodes (int): The number of electrodes.
        chunk_size (int): Number of data points included in each EEG chunk.
        in_channels (int): The number of channels of the signal corresponding to each electrode. If the original signal is used as input, in_channels is set to 1; if the original signal is split into multiple sub-bands, in_channels is set to the number of bands. (required; :obj:`9` in the example above)
        num_classes (int): The number of classes to predict. (default: :obj:`4`)
        stride_factor (int): The stride factor. Please make sure the chunk_size parameter is a multiple of stride_factor parameter in order to init model successfully. (default: :obj:`4`)
        temporal (str): The temporal layer used, with options including VarLayer, StdLayer, LogVarLayer, MeanLayer, and MaxLayer, used to compute statistics using different techniques in the temporal dimension. (default: :obj:`LogVarLayer`)
        num_feature (int): The number of Mixed Conv output channels which can stand for various kinds of feature. (default: :obj:`36`)
        dilatability (int): The expansion multiple of the channels after the input bands pass through spatial convolutional blocks. (default: :obj:`8`)

    Raises:
        ValueError: If ``stride_factor`` is not positive or ``chunk_size`` is not a multiple of ``stride_factor``.
        NotImplementedError: If ``temporal`` is not one of the supported layer names.
    '''

    def __init__(self,
                 in_channels: int,
                 num_electrodes: int,
                 chunk_size: int,
                 num_classes: int = 4,
                 stride_factor: int = 4,
                 temporal: str = 'LogVarLayer',
                 num_feature: int = 36,
                 dilatability: int = 8):
        super(FBMSNet, self).__init__()

        # Validate up front so that bad arguments fail with a precise error
        # instead of an unrelated exception from the dummy forward pass.
        # forward() reshapes the time axis (length chunk_size, preserved by the
        # stride-1 'SAME' mixed conv) into stride_factor equal windows.
        if stride_factor <= 0 or chunk_size % stride_factor != 0:
            raise ValueError(
                "Model init failed: The Chunksize must be a multiple of stride_factor.Please modify values of stride_factor or chunk_size."
            )
        temporal_layers = {
            'VarLayer': VarLayer,
            'StdLayer': StdLayer,
            'LogVarLayer': LogVarLayer,
            'MeanLayer': MeanLayer,
            'MaxLayer': MaxLayer,
        }
        if temporal not in temporal_layers:
            raise NotImplementedError(
                f"Unsupported temporal layer {temporal!r}; expected one of {sorted(temporal_layers)}."
            )

        self.num_classes = num_classes
        self.in_channels = in_channels
        self.num_electrodes = num_electrodes
        self.chunk_size = chunk_size
        self.stride_factor = stride_factor

        # Multi-scale temporal convolution: one kernel length per channel group.
        self.mixConv2d = nn.Sequential(
            MixedConv2d(in_channels=in_channels,
                        out_channels=num_feature,
                        kernel_size=[(1, 15), (1, 31), (1, 63), (1, 125)],
                        stride=1,
                        padding='',
                        dilation=1,
                        depthwise=False),
            nn.BatchNorm2d(num_feature),
        )
        self.scb = self.SCB(in_chan=num_feature,
                            out_chan=num_feature * dilatability,
                            num_electrodes=int(num_electrodes))

        # Temporal aggregator applied over the last (window) dimension.
        self.temporal_layer = temporal_layers[temporal](dim=3)

        self.center_dim = self.feature_dim(in_channels, num_electrodes,
                                           chunk_size)[-1]
        self.fc = self.LastBlock(self.center_dim, num_classes)

    def SCB(self, in_chan, out_chan, num_electrodes, weight_norm=True, *args,
            **kwargs):
        """Spatial convolution block.

        A grouped (``groups=in_chan``) convolution whose kernel spans all
        ``num_electrodes`` rows, followed by BatchNorm and swish. Extra
        ``*args``/``**kwargs`` are accepted but not used.
        """
        return nn.Sequential(
            Conv2dWithConstraint(in_chan,
                                 out_chan, (num_electrodes, 1),
                                 groups=in_chan,
                                 max_norm=2,
                                 weight_norm=weight_norm,
                                 padding=0), nn.BatchNorm2d(out_chan), swish())

    def LastBlock(self, inF, outF, weight_norm=True, *args, **kwargs):
        """Classification head: constrained linear layer + ``LogSoftmax(dim=1)``."""
        return nn.Sequential(
            LinearWithConstraint(inF,
                                 outF,
                                 *args,
                                 max_norm=0.5,
                                 weight_norm=weight_norm,
                                 **kwargs), nn.LogSoftmax(dim=1))

    def forward(self, x):
        r'''
        Args:
            x (torch.Tensor): EEG signal representation, the ideal input shape is :obj:`[n, in_channels, num_electrodes, chunk_size]`. Here, :obj:`n` corresponds to the batch size

        Returns:
            torch.Tensor[size of batch, number of classes]: The predicted log-probabilities (output of :obj:`nn.LogSoftmax`) that the samples belong to the classes.
        '''
        x = self.mixConv2d(x)
        x = self.scb(x)
        # Split the last (time) axis into stride_factor equal windows; integer
        # division avoids float rounding on long chunks.
        x = x.reshape([
            *x.shape[0:2], self.stride_factor,
            x.shape[3] // self.stride_factor
        ])
        x = self.temporal_layer(x)
        x = torch.flatten(x, start_dim=1)
        return self.fc(x)

    def feature_dim(self, in_channels, num_electrodes, chunk_size):
        """Return the size of the flattened features fed to the classifier.

        Runs a dummy all-ones input of shape
        ``(1, in_channels, num_electrodes, chunk_size)`` through the feature
        extractor. The probe runs in eval mode under ``torch.no_grad()``: in
        training mode the BatchNorm layers would fold the dummy input into
        their running statistics. The previous training flag is restored
        afterwards.
        """
        data = torch.ones((1, in_channels, num_electrodes, chunk_size))
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                x = self.mixConv2d(data)
                x = self.scb(x)
                x = x.reshape([*x.shape[0:2], self.stride_factor, -1])
                x = self.temporal_layer(x)
                x = torch.flatten(x, start_dim=1)
        finally:
            self.train(was_training)
        return x.size()
Read the Docs v: latest
Versions
latest
stable
v1.1.2
v1.1.1
v1.1.0
v1.0.11
v1.0.10
v1.0.9
v1.0.8.post1
v1.0.8
v1.0.7
v1.0.6
v1.0.4
v1.0.3
v1.0.2
v1.0.1
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources