Shortcuts

Source code for torcheeg.models.gnn.dgcnn

import torch
import torch.nn as nn
import torch.nn.functional as F


class GraphConvolution(nn.Module):
    """Graph-convolution layer computing ``adj @ x @ weight`` plus an optional bias.

    Args:
        in_channels (int): Size of the last dimension of the input features.
        out_channels (int): Size of the last dimension of the output features.
        bias (bool): If True, a learnable bias initialised to zeros is added
            to the output. (default: False)
    """

    def __init__(self, in_channels: int, out_channels: int, bias: bool=False):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels

        # Storage is allocated uninitialised and then filled in place with
        # Xavier-normal values.
        kernel = torch.FloatTensor(in_channels, out_channels)
        self.weight = nn.Parameter(kernel)
        nn.init.xavier_normal_(self.weight)

        if bias:
            self.bias = nn.Parameter(torch.FloatTensor(out_channels))
            nn.init.zeros_(self.bias)
        else:
            self.bias = None

    def forward(self, x: torch.Tensor, adj: torch.Tensor) -> torch.Tensor:
        """Apply ``adj`` to ``x``, project with the weight and add the bias if any.

        Args:
            x (torch.Tensor): Input features whose last dimension is
                ``in_channels``.
            adj (torch.Tensor): Matrix left-multiplied onto ``x``.

        Returns:
            torch.Tensor: ``adj @ x @ weight``, plus ``bias`` when one exists.
        """
        aggregated = adj @ x
        projected = aggregated @ self.weight
        if self.bias is None:
            return projected
        return projected + self.bias


class Linear(nn.Module):
    """Thin wrapper around ``nn.Linear`` with Xavier-normal weights and zero bias.

    Args:
        in_channels (int): Number of input features.
        out_channels (int): Number of output features.
        bias (bool): Whether the wrapped layer has a bias term. (default: True)
    """

    def __init__(self, in_channels: int, out_channels: int, bias: bool=True):
        super().__init__()
        layer = nn.Linear(in_channels, out_channels, bias=bias)
        # Replace nn.Linear's default initialisation.
        nn.init.xavier_normal_(layer.weight)
        if layer.bias is not None:
            nn.init.zeros_(layer.bias)
        self.linear = layer

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Return the wrapped linear layer applied to ``inputs``."""
        outputs = self.linear(inputs)
        return outputs


def normalize_A(A: torch.Tensor, symmetry: bool=False) -> torch.Tensor:
    """Return ``D^{-1/2} A' D^{-1/2}`` for the rectified adjacency ``A'``.

    Negative entries of ``A`` are clamped to zero first. ``D`` is the diagonal
    matrix of row sums of ``A'``; ``1e-10`` is added to each row sum before the
    square root so that an all-zero row does not divide by zero.

    Args:
        A (torch.Tensor): Square adjacency matrix.
        symmetry (bool): If True, the rectified matrix is replaced by the sum
            of itself and its transpose before normalising. (default: False)

    Returns:
        torch.Tensor: The normalised matrix, same shape as ``A``.
    """
    weights = F.relu(A)
    if symmetry:
        weights = weights + weights.transpose(0, 1)

    # Both modes share the same degree normalisation.
    inv_sqrt_degree = 1 / torch.sqrt(weights.sum(1) + 1e-10)
    scaling = torch.diag_embed(inv_sqrt_degree)
    return torch.matmul(torch.matmul(scaling, weights), scaling)


def generate_cheby_adj(A: torch.Tensor, num_layers: int) -> list:
    """Build the list of support matrices ``[I, A, A @ A, ...]``.

    Entry ``k`` is the ``k``-th matrix power of ``A`` (each new entry is the
    previous one right-multiplied by ``A``); entry 0 is the identity.

    Args:
        A (torch.Tensor): Square matrix.
        num_layers (int): Number of support matrices to produce.

    Returns:
        list: ``num_layers`` tensors; empty when ``num_layers`` is 0.
    """
    support = []
    for order in range(num_layers):
        if order == 0:
            # Create the identity directly with A's dtype and device so every
            # support matrix is compatible with A (and no CPU->device copy).
            support.append(
                torch.eye(A.shape[1], dtype=A.dtype, device=A.device))
        elif order == 1:
            support.append(A)
        else:
            support.append(torch.matmul(support[-1], A))
    return support


class Chebynet(nn.Module):
    """Sum of ``num_layers`` graph convolutions, one per support matrix.

    Each support matrix produced by ``generate_cheby_adj`` gets its own
    ``GraphConvolution``; the outputs are summed and passed through ReLU.

    Args:
        in_channels (int): Feature dimension of the input.
        num_layers (int): Number of support matrices / graph convolutions.
            Must be at least 1.
        out_channels (int): Feature dimension of the output.

    Raises:
        ValueError: If ``num_layers`` is smaller than 1.
    """

    def __init__(self, in_channels: int, num_layers: int, out_channels: int):
        super().__init__()
        # With no layers, forward() would have nothing to sum; fail early
        # with a clear message instead.
        if num_layers < 1:
            raise ValueError(
                f'num_layers must be at least 1, got {num_layers}.')
        self.num_layers = num_layers
        self.gc1 = nn.ModuleList(
            [GraphConvolution(in_channels, out_channels)
             for _ in range(num_layers)])

    def forward(self, x: torch.Tensor, L: torch.Tensor) -> torch.Tensor:
        """Apply every graph convolution to ``x`` and combine the results.

        Args:
            x (torch.Tensor): Input features.
            L (torch.Tensor): Matrix from which the support matrices are built.

        Returns:
            torch.Tensor: ReLU of the summed graph-convolution outputs.
        """
        supports = generate_cheby_adj(L, self.num_layers)
        pairs = zip(self.gc1, supports)
        first_conv, first_support = next(pairs)
        result = first_conv(x, first_support)
        for conv, support in pairs:
            result = result + conv(x, support)
        return F.relu(result)


class DGCNN(nn.Module):
    r'''
    Dynamical Graph Convolutional Neural Networks (DGCNN). For more details, please refer to the following information.

    - Paper: Song T, Zheng W, Song P, et al. EEG emotion recognition using dynamical graph convolutional neural networks[J]. IEEE Transactions on Affective Computing, 2018, 11(3): 532-541.
    - URL: https://ieeexplore.ieee.org/abstract/document/8320798
    - Related Project: https://github.com/xueyunlong12589/DGCNN

    Below is a recommended suite for use in emotion recognition tasks:

    .. code-block:: python

        dataset = SEEDDataset(io_path=f'./seed',
                              root_path='./Preprocessed_EEG',
                              offline_transform=transforms.BandDifferentialEntropy(band_dict={
                                  "delta": [1, 4],
                                  "theta": [4, 8],
                                  "alpha": [8, 14],
                                  "beta": [14, 31],
                                  "gamma": [31, 49]
                              }),
                              online_transform=transforms.Compose([
                                  transforms.ToTensor()
                              ]),
                              label_transform=transforms.Compose([
                                  transforms.Select('emotion'),
                                  transforms.Lambda(lambda x: x + 1)
                              ]))
        model = DGCNN(in_channels=5, num_electrodes=62, hid_channels=32, num_layers=2, num_classes=2)

    Args:
        in_channels (int): The feature dimension of each electrode. (default: :obj:`5`)
        num_electrodes (int): The number of electrodes. (default: :obj:`62`)
        num_layers (int): The number of graph convolutional layers. (default: :obj:`2`)
        hid_channels (int): The number of output channels of the graph convolution for each electrode. (default: :obj:`32`)
        num_classes (int): The number of classes to predict. (default: :obj:`2`)
    '''
    def __init__(self,
                 in_channels: int = 5,
                 num_electrodes: int = 62,
                 num_layers: int = 2,
                 hid_channels: int = 32,
                 num_classes: int = 2):
        super().__init__()
        self.in_channels = in_channels
        self.num_electrodes = num_electrodes
        self.hid_channels = hid_channels
        self.num_layers = num_layers
        self.num_classes = num_classes

        self.layer1 = Chebynet(in_channels, num_layers, hid_channels)
        self.BN1 = nn.BatchNorm1d(in_channels)
        # The graph-convolution output is flattened over electrodes and
        # hidden channels before the first fully connected layer.
        self.fc1 = Linear(num_electrodes * hid_channels, 64)
        self.fc2 = Linear(64, num_classes)
        # Learnable adjacency matrix between electrodes.
        self.A = nn.Parameter(torch.FloatTensor(num_electrodes, num_electrodes))
        nn.init.xavier_normal_(self.A)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        r'''
        Args:
            x (torch.Tensor): EEG signal representation, the ideal input shape is :obj:`[n, 62, 5]`. Here, :obj:`n` corresponds to the batch size, :obj:`62` corresponds to :obj:`num_electrodes`, and :obj:`5` corresponds to :obj:`in_channels`.

        Returns:
            torch.Tensor[number of sample, number of classes]: the unnormalized class scores (logits) of the samples; no softmax is applied.
        '''
        # BatchNorm1d was built with in_channels features, so move that axis
        # to dim 1 for normalisation and back afterwards.
        x = self.BN1(x.transpose(1, 2)).transpose(1, 2)
        L = normalize_A(self.A)
        result = self.layer1(x, L)
        result = result.reshape(x.shape[0], -1)
        result = F.relu(self.fc1(result))
        result = self.fc2(result)
        return result

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources