Source code for torcheeg.models.cnn.fbcnet

import torch
import torch.nn as nn


class Conv2dWithConstraint(nn.Conv2d):
    def __init__(self, *args, weight_norm=True, max_norm=1, **kwargs):
        self.max_norm = max_norm
        self.weight_norm = weight_norm
        super(Conv2dWithConstraint, self).__init__(*args, **kwargs)

    def forward(self, x):
        if self.weight_norm:
            self.weight.data = torch.renorm(self.weight.data,
                                            p=2,
                                            dim=0,
                                            maxnorm=self.max_norm)
        return super(Conv2dWithConstraint, self).forward(x)
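
# Illustrative sketch (not part of the original torcheeg source): the helper
# below is hypothetical and only demonstrates how Conv2dWithConstraint clips
# each output filter's L2 norm to `max_norm` every time forward() is called.
def _demo_conv2d_with_constraint():
    conv = Conv2dWithConstraint(1, 4, (3, 3), max_norm=1.0)
    with torch.no_grad():
        conv.weight.mul_(10.0)  # inflate the weights well beyond the constraint
    _ = conv(torch.randn(2, 1, 8, 8))  # forward pass applies torch.renorm
    # each of the 4 filters now has an L2 norm of at most 1.0
    print(conv.weight.flatten(1).norm(p=2, dim=1))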


class LinearWithConstraint(nn.Linear):
    def __init__(self, *args, weight_norm=True, max_norm=1, **kwargs):
        self.max_norm = max_norm
        self.weight_norm = weight_norm
        super(LinearWithConstraint, self).__init__(*args, **kwargs)

    def forward(self, x):
        if self.weight_norm:
            self.weight.data = torch.renorm(self.weight.data,
                                            p=2,
                                            dim=0,
                                            maxnorm=self.max_norm)
        return super(LinearWithConstraint, self).forward(x)


class VarLayer(nn.Module):
    def __init__(self, dim):
        super(VarLayer, self).__init__()
        self.dim = dim

    def forward(self, x):
        return x.var(dim=self.dim, keepdim=True)


class StdLayer(nn.Module):
    def __init__(self, dim):
        super(StdLayer, self).__init__()
        self.dim = dim

    def forward(self, x):
        return x.std(dim=self.dim, keepdim=True)


class LogVarLayer(nn.Module):
    def __init__(self, dim):
        super(LogVarLayer, self).__init__()
        self.dim = dim

    def forward(self, x):
        return torch.log(
            torch.clamp(x.var(dim=self.dim, keepdim=True), 1e-6, 1e6))
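
# Illustrative sketch (not part of the original torcheeg source): LogVarLayer is
# FBCNet's default temporal layer; it summarizes each temporal window by the log
# of its variance, clamped to avoid log(0) and overflow. The helper name below
# is hypothetical, for illustration only.
def _demo_log_var_layer():
    layer = LogVarLayer(dim=3)
    # (batch, spatial filters, strided windows, samples per window)
    x = torch.randn(2, 8, 4, 250)
    print(layer(x).shape)  # torch.Size([2, 8, 4, 1])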


class MeanLayer(nn.Module):
    def __init__(self, dim):
        super(MeanLayer, self).__init__()
        self.dim = dim

    def forward(self, x):
        return x.mean(dim=self.dim, keepdim=True)


class MaxLayer(nn.Module):
    def __init__(self, dim):
        super(MaxLayer, self).__init__()
        self.dim = dim

    def forward(self, x):
        ma, ima = x.max(dim=self.dim, keepdim=True)
        return ma


class swish(nn.Module):
    def __init__(self):
        super(swish, self).__init__()

    def forward(self, x):
        return x * torch.sigmoid(x)
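
# Illustrative sketch (not part of the original torcheeg source): swish(x) is
# x * sigmoid(x), which matches PyTorch's built-in nn.SiLU; the hypothetical
# helper below simply checks that the two agree on random input.
def _demo_swish():
    x = torch.randn(4, 8)
    print(torch.allclose(swish()(x), nn.SiLU()(x)))  # True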


class FBCNet(nn.Module):
    r'''
    An Efficient Multi-view Convolutional Neural Network for Brain-Computer Interface. For more details, please refer to the following information.

    - Paper: Mane R, Chew E, Chua K, et al. FBCNet: A multi-view convolutional neural network for brain-computer interface[J]. arXiv preprint arXiv:2104.01233, 2021.
    - URL: https://arxiv.org/abs/2104.01233
    - Related Project: https://github.com/ravikiran-mane/FBCNet

    Below is a recommended suite for use in emotion recognition tasks:

    .. code-block:: python

        from torcheeg.datasets import DEAPDataset
        from torcheeg import transforms
        from torcheeg.models import FBCNet
        from torch.utils.data import DataLoader

        dataset = DEAPDataset(root_path='./data_preprocessed_python',
                              chunk_size=512,
                              num_baseline=1,
                              baseline_chunk_size=512,
                              offline_transform=transforms.BandSignal(),
                              online_transform=transforms.ToTensor(),
                              label_transform=transforms.Compose([
                                  transforms.Select('valence'),
                                  transforms.Binary(5.0),
                              ]))

        model = FBCNet(num_classes=2,
                       num_electrodes=32,
                       chunk_size=512,
                       in_channels=4,
                       num_S=32)

        x, y = next(iter(DataLoader(dataset, batch_size=64)))
        model(x)

    Args:
        num_electrodes (int): The number of electrodes. (default: :obj:`20`)
        chunk_size (int): Number of data points included in each EEG chunk. (default: :obj:`1000`)
        in_channels (int): The number of channels of the signal corresponding to each electrode. If the original signal is used as input, in_channels is set to 1; if the original signal is split into multiple sub-bands, in_channels is set to the number of bands. (default: :obj:`9`)
        num_S (int): The number of spatial convolution blocks. (default: :obj:`32`)
        num_classes (int): The number of classes to predict. (default: :obj:`2`)
        temporal (str): The temporal layer used, with options including VarLayer, StdLayer, LogVarLayer, MeanLayer, and MaxLayer, used to compute statistics using different techniques in the temporal dimension. (default: :obj:`LogVarLayer`)
        stride_factor (int): The stride factor. (default: :obj:`4`)
        weight_norm (bool): Whether to use the weight renormalization technique in Conv2dWithConstraint. (default: :obj:`True`)
    '''
    def __init__(self,
                 num_electrodes: int = 20,
                 chunk_size: int = 1000,
                 in_channels: int = 9,
                 num_S: int = 32,
                 num_classes: int = 2,
                 temporal: str = 'LogVarLayer',
                 stride_factor: int = 4,
                 weight_norm: bool = True):
        super(FBCNet, self).__init__()

        self.num_electrodes = num_electrodes
        self.chunk_size = chunk_size
        self.num_classes = num_classes
        self.in_channels = in_channels
        self.num_S = num_S
        self.temporal = temporal
        self.stride_factor = stride_factor
        self.weight_norm = weight_norm

        assert chunk_size % stride_factor == 0, f'chunk_size should be divisible by stride_factor, chunk_size={chunk_size}, stride_factor={stride_factor} does not meet the requirements.'

        self.scb = self.SCB(num_S,
                            num_electrodes,
                            self.in_channels,
                            weight_norm=weight_norm)

        if temporal == 'VarLayer':
            self.temporal_layer = VarLayer(dim=3)
        elif temporal == 'StdLayer':
            self.temporal_layer = StdLayer(dim=3)
        elif temporal == 'LogVarLayer':
            self.temporal_layer = LogVarLayer(dim=3)
        elif temporal == 'MeanLayer':
            self.temporal_layer = MeanLayer(dim=3)
        elif temporal == 'MaxLayer':
            self.temporal_layer = MaxLayer(dim=3)
        else:
            raise NotImplementedError

        self.last_layer = self.last_block(self.num_S * self.in_channels *
                                          self.stride_factor,
                                          num_classes,
                                          weight_norm=weight_norm)

    def SCB(self, num_S, num_electrodes, in_channels, weight_norm=True):
        return nn.Sequential(
            Conv2dWithConstraint(in_channels,
                                 num_S * in_channels, (num_electrodes, 1),
                                 groups=in_channels,
                                 max_norm=2,
                                 weight_norm=weight_norm,
                                 padding=0),
            nn.BatchNorm2d(num_S * in_channels), swish())

    def last_block(self, in_channels, out_channels, weight_norm=True):
        return nn.Sequential(
            LinearWithConstraint(in_channels,
                                 out_channels,
                                 max_norm=0.5,
                                 weight_norm=weight_norm), nn.LogSoftmax(dim=1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        r'''
        Args:
            x (torch.Tensor): EEG signal representation, the ideal input shape is :obj:`[n, in_channels, num_electrodes, chunk_size]`. Here, :obj:`n` corresponds to the batch size.

        Returns:
            torch.Tensor[number of samples, number of classes]: the predicted probability that the samples belong to the classes.
        '''
        x = self.scb(x)
        x = x.reshape([
            *x.shape[0:2], self.stride_factor,
            int(x.shape[3] / self.stride_factor)
        ])
        x = self.temporal_layer(x)
        x = torch.flatten(x, start_dim=1)
        x = self.last_layer(x)
        return x

    def feature_dim(self):
        return self.num_S * self.in_channels * self.stride_factor
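
# Illustrative sketch (not part of the original torcheeg source): a minimal
# forward pass on random data with the same hyperparameters as the docstring
# example; the helper name is hypothetical, for illustration only.
def _demo_fbcnet_forward():
    model = FBCNet(num_classes=2,
                   num_electrodes=32,
                   chunk_size=512,
                   in_channels=4,
                   num_S=32)
    # [batch, in_channels (frequency sub-bands), electrodes, time points]
    x = torch.randn(64, 4, 32, 512)
    print(model(x).shape)  # torch.Size([64, 2])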