Source code for torcheeg.models.gnn.lggnet

import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.module import Module
from torch.nn.parameter import Parameter


class ChannelAttention(nn.Module):
    def __init__(self, in_channels, ratio=2):
        super(ChannelAttention, self).__init__()
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # Shared MLP implemented as 1x1 convolutions with a reduction ratio.
        self.fc1 = nn.Conv2d(in_channels, in_channels // ratio, 1, bias=False)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Conv2d(in_channels // ratio, in_channels, 1, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_pool_out = self.avg_pool(x)
        max_pool_out = self.max_pool(x)
        avg_pool_out = self.fc2(self.relu1(self.fc1(avg_pool_out)))
        max_pool_out = self.fc2(self.relu1(self.fc1(max_pool_out)))
        out = max_pool_out + avg_pool_out
        return self.sigmoid(out)
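
# A minimal usage sketch for ChannelAttention (the tensor sizes below are
# illustrative assumptions, not values from the paper): it squeezes the
# spatial grid with max- and average-pooling and returns one sigmoid weight
# per channel.
def _channel_attention_example():
    att = ChannelAttention(32, ratio=2)
    x = torch.randn(4, 32, 64, 25)  # [batch, channels, height, width]
    out = att(x)
    assert out.shape == (4, 32, 1, 1)  # one attention weight per channel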


class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()
        assert kernel_size in (3, 7), 'Kernel size must be 3 or 7.'
        padding = 3 if kernel_size == 7 else 1
        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        max_pool_out, _ = torch.max(x, dim=1, keepdim=True)
        avg_pool_out = torch.mean(x, dim=1, keepdim=True)
        out = torch.cat([avg_pool_out, max_pool_out], dim=1)
        out = self.conv1(out)
        return self.sigmoid(out)
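
# A minimal usage sketch for SpatialAttention (illustrative sizes): it
# collapses the channel dimension into max- and mean-maps and returns a
# single-channel attention map over the spatial grid.
def _spatial_attention_example():
    att = SpatialAttention(kernel_size=7)
    x = torch.randn(4, 32, 64, 25)  # [batch, channels, height, width]
    out = att(x)
    assert out.shape == (4, 1, 64, 25)  # one attention weight per location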


class CBAMBlock(nn.Module):
    def __init__(self, in_channels, ratio=2, kernel_size=7):
        super(CBAMBlock, self).__init__()
        self.cha_att = ChannelAttention(in_channels, ratio=ratio)
        self.spa_att = SpatialAttention(kernel_size=kernel_size)

    def forward(self, x):
        out = x * self.cha_att(x)
        out = out * self.spa_att(out)
        return out
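
# A minimal usage sketch for CBAMBlock (illustrative sizes): channel
# attention followed by spatial attention, each rescaling the input, so the
# output shape always matches the input shape.
def _cbam_shape_check():
    att = CBAMBlock(32, ratio=2, kernel_size=7)
    x = torch.randn(4, 32, 64, 25)  # [batch, channels, height, width]
    out = att(x)
    assert out.shape == x.shape  # attention only rescales, never reshapes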


class PowerLayer(nn.Module):
    def __init__(self, kernel_size, stride):
        super(PowerLayer, self).__init__()
        self.pooling = nn.AvgPool2d(kernel_size=(1, kernel_size),
                                    stride=(1, stride))

    def forward(self, x):
        # Log of the average signal power over a sliding temporal window.
        return torch.log(self.pooling(x.pow(2)))
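
# A minimal sketch of what PowerLayer computes (assumed sizes, for
# illustration only): log(mean(x^2)) over a sliding window along the time
# axis, i.e. the log-power of the signal.
def _power_layer_example():
    layer = PowerLayer(kernel_size=16, stride=4)
    x = torch.randn(4, 64, 32, 128)  # [batch, kernels, electrodes, time]
    out = layer(x)
    # With a window of 16 and stride of 4: (128 - 16) // 4 + 1 = 29 steps.
    assert out.shape == (4, 64, 32, 29)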


class Aggregator:
    '''Averages electrode features within each brain region. A plain helper
    class rather than an nn.Module, since it holds no learnable parameters.'''
    def __init__(self, region_list):
        self.region_list = region_list

    def forward(self, x):
        output = []
        for region_index in range(len(self.region_list)):
            region_x = x[:, self.region_list[region_index], :]
            aggr_region_x = torch.mean(region_x, dim=1)
            output.append(aggr_region_x)
        return torch.stack(output, dim=1)
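
# A minimal usage sketch for Aggregator, with a toy three-region layout over
# six electrodes (the index grouping is an assumption, not a real montage):
def _aggregator_example():
    region_list = [[0, 1], [2, 3], [4, 5]]
    aggregator = Aggregator(region_list)
    x = torch.randn(4, 6, 10)  # [batch, electrodes, features]
    out = aggregator.forward(x)  # plain class, so forward is called directly
    assert out.shape == (4, 3, 10)  # one averaged embedding per region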


class GraphConvolution(Module):
    def __init__(self, in_channels, out_channels, bias=True):
        super(GraphConvolution, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.weight = Parameter(torch.FloatTensor(in_channels, out_channels))

        if bias:
            self.bias = Parameter(
                torch.zeros((1, 1, out_channels), dtype=torch.float32))
        else:
            self.register_parameter('bias', None)
        nn.init.xavier_uniform_(self.weight, gain=1.414)

    def reset_parameters(self):
        # Uniform re-initialization scaled by the output dimension. Note that
        # __init__ already applies Xavier initialization, so this method only
        # takes effect if called explicitly.
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)

    def forward(self, x, adj):
        output = torch.matmul(x, self.weight)
        if self.bias is not None:
            # The bias is subtracted rather than added; since it is learned
            # from a zero initialization, the sign is absorbed in training.
            # The guard avoids a crash when the layer is built with bias=False.
            output = output - self.bias
        output = F.relu(torch.matmul(adj, output))
        return output
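
# A minimal usage sketch for GraphConvolution (illustrative sizes): given
# node features and a batched, already-normalized adjacency matrix, it
# applies a learned linear map, then neighborhood aggregation and ReLU.
def _graph_convolution_example():
    gcn = GraphConvolution(in_channels=10, out_channels=32)
    x = torch.randn(4, 7, 10)  # [batch, nodes, features]
    # Identity adjacency: every node only attends to itself.
    adj = torch.eye(7).unsqueeze(0).repeat(4, 1, 1)
    out = gcn(x, adj)
    assert out.shape == (4, 7, 32)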


class LGGNet(nn.Module):
    r'''
    Local-Global-Graph Networks (LGGNet). For more details, please refer to the following information.

    - Paper: Ding Y, Robinson N, Zeng Q, et al. LGGNet: Learning from local-global-graph representations for brain-computer interface[J]. arXiv preprint arXiv:2105.02786, 2021.
    - URL: https://arxiv.org/abs/2105.02786
    - Related Project: https://github.com/yi-ding-cs/LGG

    Below is a recommended suite for use in emotion recognition tasks:

    .. code-block:: python

        from torch.utils.data import DataLoader

        from torcheeg.datasets import SEEDDataset
        from torcheeg.models import LGGNet
        from torcheeg import transforms
        from torcheeg.datasets.constants import SEED_GENERAL_REGION_LIST

        dataset = SEEDDataset(root_path='./Preprocessed_EEG',
                              offline_transform=transforms.Compose([
                                  transforms.MeanStdNormalize(),
                                  transforms.To2d()
                              ]),
                              online_transform=transforms.Compose([
                                  transforms.ToTensor()
                              ]),
                              label_transform=transforms.Compose([
                                  transforms.Select('emotion'),
                                  transforms.Lambda(lambda x: x + 1)
                              ]))

        model = LGGNet(region_list=SEED_GENERAL_REGION_LIST,
                       chunk_size=128,
                       num_electrodes=32,
                       hid_channels=32,
                       num_classes=2)

        x, y = next(iter(DataLoader(dataset, batch_size=64)))
        model(x)

    The current built-in :obj:`region_list` includes:

    - torcheeg.datasets.constants.emotion_recognition.amigos.AMIGOS_GENERAL_REGION_LIST
    - torcheeg.datasets.constants.emotion_recognition.amigos.AMIGOS_FRONTAL_REGION_LIST
    - torcheeg.datasets.constants.emotion_recognition.amigos.AMIGOS_HEMISPHERE_REGION_LIST
    - torcheeg.datasets.constants.emotion_recognition.deap.DEAP_GENERAL_REGION_LIST
    - ...
    - torcheeg.datasets.constants.emotion_recognition.dreamer.DREAMER_GENERAL_REGION_LIST
    - ...
    - torcheeg.datasets.constants.emotion_recognition.mahnob.MAHNOB_GENERAL_REGION_LIST
    - ...
    - torcheeg.datasets.constants.emotion_recognition.seed.SEED_GENERAL_REGION_LIST
    - ...

    Args:
        region_list (list): The local graph structure defined according to the 10-20 system, where the electrodes are divided into different brain regions.
        in_channels (int): The feature dimension of each electrode. (default: :obj:`1`)
        num_electrodes (int): The number of electrodes. (default: :obj:`32`)
        chunk_size (int): Number of data points included in each EEG chunk. (default: :obj:`128`)
        sampling_rate (int): The sampling rate of the EEG signals, i.e., :math:`f_s` in the paper. (default: :obj:`128`)
        num_T (int): The number of multi-scale 1D temporal kernels in the dynamic temporal layer, i.e., :math:`T` kernels in the paper. (default: :obj:`64`)
        hid_channels (int): The number of hidden nodes in the first fully connected layer. (default: :obj:`32`)
        dropout (float): Probability of an element to be zeroed in the dropout layers. (default: :obj:`0.5`)
        pool_kernel_size (int): The kernel size of pooling layers in the temporal blocks. (default: :obj:`16`)
        pool_stride (int): The stride of pooling layers in the temporal blocks. (default: :obj:`4`)
        num_classes (int): The number of classes to predict. (default: :obj:`2`)
    '''
    def __init__(self,
                 region_list,
                 in_channels: int = 1,
                 num_electrodes: int = 32,
                 chunk_size: int = 128,
                 sampling_rate: int = 128,
                 num_T: int = 64,
                 hid_channels: int = 32,
                 dropout: float = 0.5,
                 pool_kernel_size: int = 16,
                 pool_stride: int = 4,
                 num_classes: int = 2):
        super(LGGNet, self).__init__()
        self.region_list = region_list
        # Multi-scale temporal kernel lengths, as fractions of sampling_rate.
        self.inception_window = [0.5, 0.25, 0.125]
        self.num_classes = num_classes
        self.in_channels = in_channels
        self.num_electrodes = num_electrodes
        self.chunk_size = chunk_size
        self.sampling_rate = sampling_rate
        self.num_T = num_T
        self.hid_channels = hid_channels
        self.dropout = dropout
        self.pool_kernel_size = pool_kernel_size
        self.pool_stride = pool_stride

        self.t_block1 = self.temporal_block(
            self.in_channels, self.num_T,
            (1, int(self.inception_window[0] * self.sampling_rate)),
            self.pool_kernel_size, self.pool_stride)
        self.t_block2 = self.temporal_block(
            self.in_channels, self.num_T,
            (1, int(self.inception_window[1] * self.sampling_rate)),
            self.pool_kernel_size, self.pool_stride)
        self.t_block3 = self.temporal_block(
            self.in_channels, self.num_T,
            (1, int(self.inception_window[2] * self.sampling_rate)),
            self.pool_kernel_size, self.pool_stride)

        self.bn_t1 = nn.BatchNorm2d(self.num_T)
        self.bn_t2 = nn.BatchNorm2d(self.num_T)
        self.cbam = CBAMBlock(num_electrodes)
        self.conv1x1 = nn.Sequential(
            nn.Conv2d(num_T, num_T, kernel_size=(1, 1), stride=(1, 1)),
            nn.LeakyReLU(), nn.AvgPool2d((1, 2)))
        self.avg_pool = nn.AvgPool2d((1, 2))

        feature_dim = self.feature_dim

        self.local_filter_weight = nn.Parameter(torch.FloatTensor(
            self.num_electrodes, feature_dim),
                                                requires_grad=True)
        self.local_filter_bias = nn.Parameter(torch.zeros(
            (1, self.num_electrodes, 1), dtype=torch.float32),
                                              requires_grad=True)

        self.aggregate = Aggregator(self.region_list)
        num_region = len(self.region_list)

        self.global_adj = nn.Parameter(torch.FloatTensor(
            num_region, num_region),
                                       requires_grad=True)
        self.bn_g1 = nn.BatchNorm1d(num_region)
        self.bn_g2 = nn.BatchNorm1d(num_region)
        self.gcn = GraphConvolution(feature_dim, hid_channels)
        self.fc = nn.Sequential(
            nn.Dropout(p=dropout),
            nn.Linear(int(num_region * hid_channels), num_classes))

        nn.init.xavier_uniform_(self.local_filter_weight)
        nn.init.xavier_uniform_(self.global_adj)

    def temporal_block(self, in_channels, out_channels, kernel_size,
                       pool_kernel_size, pool_stride):
        return nn.Sequential(
            nn.Conv2d(in_channels,
                      out_channels,
                      kernel_size=kernel_size,
                      stride=(1, 1)),
            PowerLayer(kernel_size=pool_kernel_size, stride=pool_stride))
    def forward(self, x):
        r'''
        Args:
            x (torch.Tensor): EEG signal representation, the ideal input shape is :obj:`[n, 1, 32, 128]`. Here, :obj:`n` corresponds to the batch size, :obj:`32` corresponds to :obj:`num_electrodes`, and :obj:`128` corresponds to :obj:`chunk_size`.

        Returns:
            torch.Tensor[number of sample, number of classes]: the predicted probability that the samples belong to the classes.
        '''
        # Multi-scale temporal convolutions, concatenated along the time axis.
        t1 = self.t_block1(x)
        t2 = self.t_block2(x)
        t3 = self.t_block3(x)
        x = torch.cat((t1, t2, t3), dim=-1)
        x = self.bn_t1(x)
        # Move electrodes into the channel dimension for CBAM attention.
        x = x.permute(0, 2, 1, 3)
        x = self.cbam(x)
        x = self.avg_pool(x)
        x = x.flatten(start_dim=2)
        # Local graph: electrode-wise filtering, then region-wise aggregation.
        x = self.local_filter(x)
        x = self.aggregate.forward(x)
        # Global graph: learned adjacency over regions, then graph convolution.
        adj = self.get_adj(x)
        x = self.bn_g1(x)
        x = self.gcn(x, adj)
        x = self.bn_g2(x)
        x = x.view(x.shape[0], -1)
        x = self.fc(x)
        return x
    @property
    def feature_dim(self):
        # Infer the flattened per-electrode feature dimension by running a
        # mock EEG chunk through the temporal blocks.
        mock_eeg = torch.randn(
            (1, self.in_channels, self.num_electrodes, self.chunk_size))

        t1 = self.t_block1(mock_eeg)
        t2 = self.t_block2(mock_eeg)
        t3 = self.t_block3(mock_eeg)
        mock_eeg = torch.cat((t1, t2, t3), dim=-1)
        mock_eeg = self.bn_t1(mock_eeg)
        mock_eeg = self.conv1x1(mock_eeg)
        mock_eeg = self.bn_t2(mock_eeg)
        mock_eeg = mock_eeg.permute(0, 2, 1, 3)
        mock_eeg = mock_eeg.flatten(start_dim=2)

        return mock_eeg.shape[-1]

    def local_filter(self, x):
        # Element-wise learnable filter applied per electrode, followed by a
        # ReLU with a learnable offset.
        w = self.local_filter_weight.unsqueeze(0).repeat(x.shape[0], 1, 1)
        x = F.relu(torch.mul(x, w) - self.local_filter_bias)
        return x

    def get_adj(self, x, self_loop=True):
        # Data-driven adjacency: pairwise similarity between region
        # embeddings, gated by the learnable symmetrized global adjacency.
        adj = torch.bmm(x, x.permute(0, 2, 1))
        num_nodes = adj.shape[-1]
        adj = F.relu(adj * (self.global_adj + self.global_adj.transpose(1, 0)))
        if self_loop:
            adj = adj + torch.eye(num_nodes).to(x.device)
        # Symmetric normalization D^{-1/2} A D^{-1/2}, guarding zero-degree
        # rows to avoid division by zero.
        rowsum = torch.sum(adj, dim=-1)
        mask = torch.zeros_like(rowsum)
        mask[rowsum == 0] = 1
        rowsum += mask
        d_inv_sqrt = torch.pow(rowsum, -0.5)
        d_mat_inv_sqrt = torch.diag_embed(d_inv_sqrt)
        adj = torch.bmm(torch.bmm(d_mat_inv_sqrt, adj), d_mat_inv_sqrt)
        return adj
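
# A minimal end-to-end shape smoke test for LGGNet with random inputs and a
# toy region list over 32 electrodes (the grouping below is an assumption for
# illustration, not one of the built-in region-list constants):
def _lggnet_smoke_test():
    region_list = [list(range(i, i + 4)) for i in range(0, 32, 4)]  # 8 regions
    model = LGGNet(region_list=region_list,
                   chunk_size=128,
                   num_electrodes=32,
                   hid_channels=32,
                   num_classes=2)
    x = torch.randn(2, 1, 32, 128)  # [batch, in_channels, electrodes, time]
    out = model(x)
    assert out.shape == (2, 2)  # [batch, num_classes]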