Shortcuts

Source code for torcheeg.models.cnn.tslanet

import math
import warnings

import torch
import torch.nn as nn


def _trunc_normal_(tensor, mean, std, a, b):
    def norm_cdf(x):
        return (1. + math.erf(x / math.sqrt(2.))) / 2.

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
                      "The distribution of values may be incorrect.",
                      stacklevel=2)

    l = norm_cdf((a - mean) / std)
    u = norm_cdf((b - mean) / std)

    tensor.uniform_(2 * l - 1, 2 * u - 1)

    tensor.erfinv_()

    tensor.mul_(std * math.sqrt(2.))
    tensor.add_(mean)

    tensor.clamp_(min=a, max=b)
    return tensor


def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    """Fill ``tensor`` in place with truncated-normal samples, outside autograd.

    The fill runs under ``torch.no_grad()`` so it can be applied directly to
    ``nn.Parameter`` objects without recording history.

    Args:
        tensor: tensor (or parameter) to overwrite.
        mean: mean of the underlying normal. Defaults to 0.
        std: standard deviation of the underlying normal. Defaults to 1.
        a: lower truncation bound. Defaults to -2.
        b: upper truncation bound. Defaults to 2.

    Returns:
        The same ``tensor``, filled in place.
    """
    with torch.no_grad():
        filled = _trunc_normal_(tensor, mean, std, a, b)
    return filled


def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True):
    """Stochastic depth: zero whole entries of ``x`` along dim 0 with probability ``drop_prob``.

    One Bernoulli draw is made per index of dim 0 and broadcast over all other
    dims, so each sample is either kept entirely or dropped entirely.

    Args:
        x: input tensor; dim 0 is the axis whose entries are dropped independently.
        drop_prob: probability of dropping an entry.
        training: when False the function is a no-op.
        scale_by_keep: if True (and the keep probability is positive), kept
            entries are divided by the keep probability to preserve the expectation.

    Returns:
        ``x`` itself when not training or ``drop_prob == 0``; otherwise a new
        masked tensor.
    """
    if not training or drop_prob == 0.:
        return x

    keep_prob = 1 - drop_prob
    # (batch, 1, 1, ...) so the mask broadcasts over every non-batch dim.
    mask_shape = [x.shape[0]] + [1] * (x.ndim - 1)
    mask = x.new_empty(mask_shape).bernoulli_(keep_prob)
    if scale_by_keep and keep_prob > 0.0:
        mask.div_(keep_prob)
    return x * mask


class DropPath(nn.Module):
    """Module form of :func:`drop_path` (per-sample stochastic depth).

    The module's ``self.training`` flag is forwarded as ``training``, so in
    eval mode the input is returned unchanged.

    Args:
        drop_prob: probability of dropping each sample's branch.
        scale_by_keep: rescale kept samples by ``1 / (1 - drop_prob)``.
    """

    def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
        super().__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    def forward(self, x):
        return drop_path(
            x,
            drop_prob=self.drop_prob,
            training=self.training,
            scale_by_keep=self.scale_by_keep,
        )


class ICB(nn.Module):
    """Cross-gated two-branch 1-D convolution block.

    Two parallel convolutions expand ``in_features`` to ``hidden_features``:
    a pointwise one (kernel 1) and a local one (kernel 3, stride 1, padding 1,
    so the length is preserved). Each branch's raw output is multiplied by
    the GELU + dropout activated output of the *other* branch. The two
    products are summed and projected back to ``in_features`` with a
    pointwise convolution.

    ``forward`` transposes dims 1 and 2 around the convolutions. The input is
    therefore channels-last, ``(batch, length, in_features)``, and the output
    has the same shape.

    Args:
        in_features: number of input/output channels.
        hidden_features: width of the two expanded branches.
        drop: dropout probability applied to the activated branches.
    """

    def __init__(self, in_features, hidden_features, drop=0.):
        super().__init__()
        # Kept in this order: parameter init draws from the global RNG conv1 -> conv2 -> conv3.
        self.conv1 = nn.Conv1d(in_features, hidden_features, kernel_size=1)
        self.conv2 = nn.Conv1d(in_features, hidden_features, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv1d(hidden_features, in_features, kernel_size=1)
        self.drop = nn.Dropout(drop)
        self.act = nn.GELU()

    def forward(self, x):
        h = x.transpose(1, 2)  # channels-first for Conv1d

        point = self.conv1(h)
        point_gate = self.drop(self.act(point))

        local = self.conv2(h)
        local_gate = self.drop(self.act(local))

        # Each branch's raw features are modulated by the other branch's activation.
        mixed = point * local_gate + local * point_gate
        return self.conv3(mixed).transpose(1, 2)


class PatchEmbed(nn.Module):
    """Split a multichannel sequence into overlapping patches and embed each one.

    A single ``nn.Conv1d`` with ``kernel_size=patch_size`` and a stride of half
    a patch (at least 1) both extracts the patches and projects each to
    ``embed_dim`` features.

    Args:
        seq_len: length of the input along its last dimension.
        patch_size: length of each patch (the convolution kernel size).
        in_chans: number of input channels.
        embed_dim: output feature dimension per patch.

    Attributes:
        num_patches: number of patches produced for an input of length
            ``seq_len``. This equals the Conv1d output length.

    Raises:
        ValueError: if ``patch_size < 1`` or ``seq_len < patch_size``.
    """

    def __init__(self, seq_len, patch_size=8, in_chans=3, embed_dim=384):
        super().__init__()
        if patch_size < 1:
            raise ValueError(f"patch_size must be >= 1, got {patch_size}.")
        if seq_len < patch_size:
            raise ValueError(
                f"seq_len ({seq_len}) must be >= patch_size ({patch_size}) "
                "to produce at least one patch.")
        # 50% overlap between neighbouring patches. Clamped so patch_size == 1
        # still yields a valid (non-zero) stride.
        stride = max(patch_size // 2, 1)
        # Conv1d output length without padding/dilation: floor((L - k) / s) + 1,
        # computed in exact integer arithmetic.
        self.num_patches = int((seq_len - patch_size) // stride + 1)
        self.proj = nn.Conv1d(in_chans, embed_dim,
                              kernel_size=patch_size, stride=stride)

    def forward(self, x):
        # (B, in_chans, L) -> (B, embed_dim, num_patches) -> (B, num_patches, embed_dim)
        x_out = self.proj(x).flatten(2).transpose(1, 2)
        return x_out


class AdaptiveSpectralBlock(nn.Module):
    """Learnable global filter applied in the frequency domain along dim 1.

    ``forward`` expects a real tensor ``(B, N, C)`` and proceeds in four steps:

    1. Take an orthonormal real FFT over the N axis.
    2. Multiply every frequency bin by a learnable per-channel complex weight.
    3. When ``adaptive_filter`` is set, add a second, separately weighted copy
       restricted to bins whose channel-summed energy, relative to the
       per-sample median, exceeds a learnable threshold.
    4. Transform back with an inverse real FFT of length N and cast to the
       input dtype.

    Args:
        dim: channel count C (size of the last input dimension).
        adaptive_filter: enable the energy-gated second filter branch.
    """

    def __init__(self, dim, adaptive_filter=True):
        super().__init__()
        self.adaptive_filter = adaptive_filter

        # Complex weights are stored as (real, imag) pairs of shape (dim, 2) for
        # torch.view_as_complex. The randn draws are immediately overwritten by
        # trunc_normal_, but they are kept so initialisation consumes the global
        # RNG exactly as before.
        self.complex_weight_high = nn.Parameter(
            torch.randn(dim, 2, dtype=torch.float32) * 0.02)
        self.complex_weight = nn.Parameter(
            torch.randn(dim, 2, dtype=torch.float32) * 0.02)
        for weight_param in (self.complex_weight_high, self.complex_weight):
            trunc_normal_(weight_param, std=.02)
        # Initial threshold drawn uniformly from [0, 1).
        self.threshold_param = nn.Parameter(torch.rand(1))

    def create_adaptive_high_freq_mask(self, x_fft):
        """Return a ``(B, F, 1)`` mask marking high-energy frequency bins of ``x_fft``.

        Each bin's energy is ``|X|^2`` summed over channels, then normalised by
        the per-sample median energy. The hard comparison with
        ``threshold_param`` uses a straight-through estimator: the forward
        value is the 0/1 indicator, while gradients reach ``threshold_param``
        through the non-detached term.
        """
        n_batch, _n_freq, _n_chan = x_fft.shape

        energy = torch.abs(x_fft).pow(2).sum(dim=-1)  # (B, F)

        median = energy.view(n_batch, -1).median(dim=1, keepdim=True).values
        median = median.view(n_batch, 1)

        # Small epsilon guards against division by a zero median.
        normalized = energy / (median + 1e-6)

        hard = (normalized > self.threshold_param).float()
        soft_mask = (hard - self.threshold_param).detach() + self.threshold_param
        return soft_mask.unsqueeze(-1)

    def forward(self, x_in):
        B, N, C = x_in.shape

        in_dtype = x_in.dtype
        # The FFT is always computed in float32, whatever the input precision.
        signal = x_in.to(torch.float32)

        spectrum = torch.fft.rfft(signal, dim=1, norm='ortho')  # (B, N // 2 + 1, C)
        filtered = spectrum * torch.view_as_complex(self.complex_weight)

        if self.adaptive_filter:
            gate = self.create_adaptive_high_freq_mask(spectrum)
            filtered = filtered + spectrum * gate * torch.view_as_complex(self.complex_weight_high)

        restored = torch.fft.irfft(filtered, n=N, dim=1, norm='ortho')
        return restored.to(in_dtype).view(B, N, C)


class TSLANetLayer(nn.Module):
    """One TSLANet layer with a single residual connection.

    The residual branch is ``norm1 -> AdaptiveSpectralBlock -> norm2 -> ICB``,
    followed by stochastic depth. The branch output is added to the layer input.

    Args:
        dim: embedding width.
        mlp_ratio: ICB hidden width as a multiple of ``dim`` (truncated to int).
        drop: dropout probability used inside the ICB.
        drop_path: stochastic-depth probability for the residual branch; a
            non-positive value replaces it with an identity.
        norm_layer: callable that builds a normalisation layer from ``dim``.
    """

    def __init__(self, dim, mlp_ratio=3., drop=0., drop_path=0., norm_layer=nn.LayerNorm):
        super().__init__()
        # Submodule creation order (and thus RNG use / state_dict order) is preserved.
        self.norm1 = norm_layer(dim)
        self.ASB = AdaptiveSpectralBlock(dim)
        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()
        self.norm2 = norm_layer(dim)
        self.ICB = ICB(in_features=dim,
                       hidden_features=int(dim * mlp_ratio),
                       drop=drop)

    def forward(self, x):
        branch = self.norm1(x)
        branch = self.ASB(branch)
        branch = self.norm2(branch)
        branch = self.ICB(branch)
        return x + self.drop_path(branch)


class TSLANet(nn.Module):
    r'''
    A time series lightweight adaptive network for EEG classification. For more details, please refer to the following information.

    - Paper: Eldele E, Ragab M, Chen Z, et al. TSLANet: Rethinking Transformers for Time Series Representation Learning[C]//Forty-first International Conference on Machine Learning.
    - URL: https://openreview.net/pdf?id=CGR3vpX63X
    - Related Project: https://github.com/emadeldeen24/TSLANet

    Below is a quick start example:

    .. code-block:: python

        import torch
        from torcheeg.models import TSLANet

        model = TSLANet(num_classes=5,
                        chunk_size=3000,
                        patch_size=200,
                        num_electrodes=1)

        # batch_size, num_electrodes, chunk_size
        x = torch.randn(32, 1, 3000)
        model(x)

    The pipeline runs as follows:

    - The input is cut into overlapping patches of length ``patch_size``, with a
      hop of half a patch.
    - Each patch is embedded into ``emb_dim`` features, and a learned positional
      embedding is added.
    - ``depth`` TSLANet layers are applied.
    - The patch tokens are averaged, and a linear head produces one score per class.

    Args:
        chunk_size (int): Number of data points in each EEG segment. (default: :obj:`3000`)
        patch_size (int): Size of each patch the input sequence is divided into. (default: :obj:`200`)
        num_electrodes (int): The number of EEG channels. (default: :obj:`1`)
        emb_dim (int): Dimension of the embedding space. (default: :obj:`128`)
        dropout_rate (float): Dropout rate for regularization. It is also the largest stochastic-depth rate; the rates grow linearly from 0 across the layers. (default: :obj:`0.15`)
        depth (int): Number of TSLANet layers in the network. (default: :obj:`2`)
        num_classes (int): The number of classes to classify. (default: :obj:`2`)
    '''

    def __init__(self,
                 chunk_size: int = 3000,
                 patch_size: int = 200,
                 num_electrodes: int = 1,
                 emb_dim: int = 128,
                 dropout_rate: float = 0.15,
                 depth: int = 2,
                 num_classes: int = 2):
        super().__init__()
        self.emb_dim = emb_dim

        self.patch_embed = PatchEmbed(
            seq_len=chunk_size, patch_size=patch_size, in_chans=num_electrodes, embed_dim=emb_dim
        )
        num_patches = self.patch_embed.num_patches

        self.pos_embed = nn.Parameter(torch.zeros(
            1, num_patches, emb_dim), requires_grad=True)
        self.pos_drop = nn.Dropout(p=dropout_rate)

        # NOTE: not used by forward(). It is kept so existing state_dict
        # checkpoints and the parameter-initialisation order stay compatible.
        self.input_layer = nn.Linear(patch_size, emb_dim)

        # Stochastic-depth rate per layer: linearly spaced from 0 to dropout_rate.
        dpr = [rate.item() for rate in torch.linspace(0, dropout_rate, depth)]
        self.tsla_blocks = nn.ModuleList([
            TSLANetLayer(dim=emb_dim, drop=dropout_rate, drop_path=rate)
            for rate in dpr
        ])

        self.head = nn.Linear(emb_dim, num_classes)

        trunc_normal_(self.pos_embed, std=.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        # Linear: truncated-normal weights (std 0.02) and zero bias.
        # LayerNorm: unit weight and zero bias.
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            # Affine parameters may be None (e.g. elementwise_affine=False).
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            if m.weight is not None:
                nn.init.constant_(m.weight, 1.0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        r'''
        Args:
            x (torch.Tensor): EEG signal representation, the ideal input shape is :obj:`[n, num_electrodes, chunk_size]`.

        Returns:
            torch.Tensor[number of sample, number of classes]: unnormalized class scores (logits) produced by the linear head.
        '''
        x = self.patch_embed(x)  # (n, num_patches, emb_dim)
        x = x + self.pos_embed
        x = self.pos_drop(x)
        for tsla_blk in self.tsla_blocks:
            x = tsla_blk(x)
        x = x.mean(1)  # average over patch tokens -> (n, emb_dim)
        x = self.head(x)
        return x

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources