
Source code for torcheeg.models.pyg.rgnn

from typing import Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Batch
from torch_geometric.nn import SGConv, global_add_pool
from torch_scatter import scatter_add


def maybe_num_electrodes(index: torch.Tensor, num_electrodes: Union[int, None] = None) -> int:
    # infer the number of electrodes from the largest node index if not given explicitly
    return index.max().item() + 1 if num_electrodes is None else num_electrodes


def add_remaining_self_loops(edge_index: torch.Tensor,
                             edge_weight: Union[torch.Tensor, None] = None,
                             fill_value: float = 1.0,
                             num_electrodes: Union[int, None] = None) -> Tuple[torch.Tensor, torch.Tensor]:
    num_electrodes = maybe_num_electrodes(edge_index, num_electrodes)
    row, col = edge_index

    mask = row != col

    inv_mask = ~mask
    # print("inv_mask", inv_mask)

    loop_weight = torch.full((num_electrodes, ),
                             fill_value,
                             dtype=None if edge_weight is None else edge_weight.dtype,
                             device=edge_index.device)

    if edge_weight is not None:
        assert edge_weight.numel() == edge_index.size(1)
        remaining_edge_weight = edge_weight[inv_mask]

        if remaining_edge_weight.numel() > 0:
            loop_weight[row[inv_mask]] = remaining_edge_weight

        edge_weight = torch.cat([edge_weight[mask], loop_weight], dim=0)
    loop_index = torch.arange(0, num_electrodes, dtype=row.dtype, device=row.device)
    loop_index = loop_index.unsqueeze(0).repeat(2, 1)
    edge_index = torch.cat([edge_index[:, mask], loop_index], dim=1)

    return edge_index, edge_weight
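
# A minimal illustrative sketch (not part of the original source): on a toy
# 3-electrode graph, missing self-loops are appended with ``fill_value``,
# while the weight of an already-existing self-loop is preserved.
def _demo_add_remaining_self_loops():
    edge_index = torch.tensor([[0, 1, 1], [1, 0, 1]])  # node 1 already has a self-loop
    edge_weight = torch.tensor([0.5, -0.5, 2.0])
    new_index, new_weight = add_remaining_self_loops(edge_index, edge_weight, num_electrodes=3)
    # new_index:  [[0, 1, 0, 1, 2],
    #              [1, 0, 0, 1, 2]]
    # new_weight: [0.5, -0.5, 1.0, 2.0, 1.0] -- node 1 keeps its weight 2.0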


class NewSGConv(SGConv):
    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 num_layers: int = 1,
                 cached: bool = False,
                 bias: bool = True):
        super(NewSGConv, self).__init__(in_channels, out_channels, K=num_layers, cached=cached, bias=bias)

    # allow negative edge weights
    @staticmethod
    def norm(edge_index: torch.Tensor,
             num_electrodes: int,
             adj: torch.Tensor,
             improved: bool = False,
             dtype: Union[torch.dtype, None] = None):
        if adj is None:
            adj = torch.ones((edge_index.size(1), ), dtype=dtype, device=edge_index.device)

        fill_value = 2.0 if improved else 1.0
        edge_index, adj = add_remaining_self_loops(edge_index, adj, fill_value, num_electrodes)
        row, col = edge_index
        # compute degrees from the absolute edge weights, so that learned
        # negative weights still yield a real-valued deg^{-1/2}
        deg = scatter_add(torch.abs(adj), row, dim=0, dim_size=num_electrodes)
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0

        return edge_index, deg_inv_sqrt[row] * adj * deg_inv_sqrt[col]

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor, adj: Union[None, torch.Tensor] = None) -> torch.Tensor:
        if not self.cached or self.cached_result is None:
            edge_index, norm = NewSGConv.norm(edge_index, x.size(0), adj, dtype=x.dtype)

            # K rounds of propagation followed by a single linear layer
            for _ in range(self.K):
                x = self.propagate(edge_index, x=x, norm=norm)
            self.cached_result = x

        return self.lin(self.cached_result)

    def message(self, x_j: torch.Tensor, norm: torch.Tensor) -> torch.Tensor:
        # x_j: (batch_size*num_electrodes*num_electrodes, in_channels)
        # norm: (batch_size*num_electrodes*num_electrodes, )
        return norm.view(-1, 1) * x_j
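
# A minimal illustrative sketch (not part of the original source): unlike the
# stock SGConv normalization, ``NewSGConv.norm`` computes node degrees from
# ``torch.abs(adj)``, so the learned edge weights may be negative while
# deg^{-1/2} stays real-valued; the sign of each edge survives in the
# returned ``norm`` vector.
def _demo_signed_normalization():
    edge_index = torch.tensor([[0, 1], [1, 0]])
    adj = torch.tensor([-0.8, -0.8])
    edge_index, norm = NewSGConv.norm(edge_index, 2, adj)
    # deg = |-0.8| + 1.0 (self-loop) = 1.8 for both nodes, so
    # norm = [-0.8/1.8, -0.8/1.8, 1.0/1.8, 1.0/1.8] -- negative entries preserved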


class RGNN(torch.nn.Module):
    r'''
    Regularized Graph Neural Networks (RGNN). For more details, please refer to the following information.

    - Paper: Zhong P, Wang D, Miao C. EEG-based emotion recognition using regularized graph neural networks[J]. IEEE Transactions on Affective Computing, 2020.
    - URL: https://ieeexplore.ieee.org/abstract/document/9091308
    - Related Project: https://github.com/zhongpeixiang/RGNN

    Below is a recommended suite for use in emotion recognition tasks:

    .. code-block:: python

        import torch
        from torcheeg import transforms
        from torcheeg.datasets import SEEDDataset
        from torcheeg.models import RGNN
        from torcheeg.transforms.pyg import ToG
        from torcheeg.datasets.constants import SEED_STANDARD_ADJACENCY_MATRIX
        from torch_geometric.data import DataLoader

        dataset = SEEDDataset(root_path='./Preprocessed_EEG',
                              offline_transform=transforms.BandDifferentialEntropy(),
                              online_transform=ToG(SEED_STANDARD_ADJACENCY_MATRIX),
                              label_transform=transforms.Compose([
                                  transforms.Select('emotion'),
                                  transforms.Lambda(lambda x: int(x) + 1),
                              ]),
                              num_worker=8)

        model = RGNN(adj=torch.Tensor(SEED_STANDARD_ADJACENCY_MATRIX),
                     in_channels=5,
                     num_electrodes=62,
                     hid_channels=32,
                     num_layers=2,
                     num_classes=3,
                     dropout=0.7,
                     learn_edge_weights=True)

        x, y = next(iter(DataLoader(dataset, batch_size=64)))
        model(x)

    Args:
        adj (torch.Tensor): The adjacency matrix corresponding to the EEG representation, where 1.0 means the node is adjacent and 0.0 means the node is not adjacent. The matrix shape should be [num_electrodes, num_electrodes].
        num_electrodes (int): The number of electrodes. (default: :obj:`62`)
        in_channels (int): The feature dimension of each electrode. (default: :obj:`5`)
        num_layers (int): The number of graph convolutional layers. (default: :obj:`2`)
        hid_channels (int): The number of hidden nodes in the first fully connected layer. (default: :obj:`32`)
        num_classes (int): The number of classes to predict. (default: :obj:`3`)
        dropout (float): Probability of an element to be zeroed in the dropout layers at the output fully-connected layer. (default: :obj:`0.7`)
        learn_edge_weights (bool): Whether to learn a set of parameters to adjust the adjacency matrix. (default: :obj:`True`)
    '''
    def __init__(self,
                 adj: Union[torch.Tensor, list],
                 num_electrodes: int = 62,
                 in_channels: int = 5,
                 num_layers: int = 2,
                 hid_channels: int = 32,
                 num_classes: int = 3,
                 dropout: float = 0.7,
                 learn_edge_weights: bool = True):
        super(RGNN, self).__init__()
        self.num_electrodes = num_electrodes
        self.in_channels = in_channels
        self.hid_channels = hid_channels
        self.num_layers = num_layers
        self.num_classes = num_classes
        self.dropout = dropout
        self.learn_edge_weights = learn_edge_weights

        self.xs, self.ys = torch.tril_indices(self.num_electrodes, self.num_electrodes, offset=0)
        if isinstance(adj, list):
            adj = torch.tensor(adj)
        # keep only the lower triangular values (including the diagonal)
        adj = adj.reshape(self.num_electrodes, self.num_electrodes)[self.xs, self.ys]
        self.adj = nn.Parameter(adj, requires_grad=learn_edge_weights)

        self.conv1 = NewSGConv(in_channels=in_channels, out_channels=hid_channels, num_layers=num_layers)
        self.fc = nn.Linear(hid_channels, num_classes)
    def forward(self, data: Batch) -> torch.Tensor:
        r'''
        Args:
            data (torch_geometric.data.Batch): EEG signal representation, the ideal input shape of data.x is :obj:`[n, 62, 5]`. Here, :obj:`n` corresponds to the batch size, :obj:`62` corresponds to the number of electrodes, and :obj:`5` corresponds to :obj:`in_channels`.

        Returns:
            torch.Tensor[number of samples, number of classes]: the predicted class scores (unnormalized logits) for the samples.
        '''
        batch_size = data.num_graphs
        x, edge_index = data.x, data.edge_index

        adj = torch.zeros((self.num_electrodes, self.num_electrodes), device=edge_index.device)
        adj[self.xs.to(adj.device), self.ys.to(adj.device)] = self.adj
        # copy values from the lower triangle to the upper triangle
        adj = adj + adj.transpose(1, 0) - torch.diag(adj.diagonal())
        adj = adj.reshape(-1).repeat(batch_size)

        x = F.relu(self.conv1(x, edge_index, adj))
        x = global_add_pool(x, data.batch, size=batch_size)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.fc(x)
        return x
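
# A minimal illustrative sketch (not part of the original source): an
# end-to-end toy forward pass with a random adjacency matrix standing in for
# SEED_STANDARD_ADJACENCY_MATRIX. Each sample is a fully connected graph, so
# that ``edge_index`` lines up with the dense ``adj`` built inside ``forward``.
def _demo_rgnn_forward():
    from torch_geometric.data import Data

    num_electrodes, in_channels = 62, 5
    model = RGNN(adj=torch.rand(num_electrodes, num_electrodes),
                 num_electrodes=num_electrodes,
                 in_channels=in_channels)

    # one fully connected graph (including self-loops) per sample
    row = torch.arange(num_electrodes).repeat_interleave(num_electrodes)
    col = torch.arange(num_electrodes).repeat(num_electrodes)
    edge_index = torch.stack([row, col])
    graphs = [
        Data(x=torch.randn(num_electrodes, in_channels), edge_index=edge_index)
        for _ in range(4)
    ]

    out = model(Batch.from_data_list(graphs))
    # out.shape == (4, 3): one score per class for each of the 4 samples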