# Source code for torcheeg.transforms.pyg

import torch
import numpy as np

from typing import List
from torch_geometric.data import Data


class ToG:
    r'''
    A transformation method for constructing a graph representation of EEG signals, the results of which are applied to the input of the :obj:`torch_geometric` model. In the graph, nodes correspond to electrodes, and edges correspond to associations between electrodes (e.g., spatially adjacent or functionally connected).

    :obj:`TorchEEG` provides some common graph structures. Consider using the following adjacency matrices depending on the dataset (with different EEG acquisition systems):

    - datasets.constants.emotion_recognition.deap.DEAP_ADJACENCY_MATRIX
    - datasets.constants.emotion_recognition.dreamer.DREAMER_ADJACENCY_MATRIX
    - datasets.constants.emotion_recognition.seed.SEED_ADJACENCY_MATRIX
    - ...

    .. code-block:: python

        transform = ToG(adj=DEAP_ADJACENCY_MATRIX)
        transform(np.random.randn(32, 128)).x.shape
        >>> torch.Size([32, 128])

    Args:
        adj (list): An adjacency matrix represented by a square 2D array, each element in the adjacency matrix represents the electrode-to-electrode edge weight. Please keep the order of electrodes in the rows and columns of the adjacency matrix consistent with the EEG signal to be transformed. Integer or boolean matrices are converted to the default floating point type.
        complete_graph (bool): Whether to build as a complete graph. If False, only construct edges between electrodes based on non-zero elements; if True, construct edges between all electrodes and give the non-existing edges a near-zero placeholder weight of 1e-6 so that they are kept in the sparse representation. (default: :obj:`False`)

    Raises:
        ValueError: If :obj:`adj` is not a square 2D matrix.

    .. automethod:: __call__
    '''
    def __init__(self, adj: List[List], complete_graph: bool = False):
        # Always work on a private copy: the in-place fill below must never
        # modify the matrix owned by the caller (list, ndarray or tensor).
        adj = torch.as_tensor(adj).detach().clone()

        if adj.dim() != 2 or adj.size(0) != adj.size(1):
            raise ValueError(
                'adj must be a square 2D adjacency matrix, but got shape '
                f'{tuple(adj.shape)}.')

        # An integer/boolean tensor cannot store the 1e-6 placeholder used
        # below (it would be truncated back to 0), and edge weights should be
        # floating point like the node features produced in __call__.
        if not (adj.is_floating_point() or adj.is_complex()):
            adj = adj.to(torch.get_default_dtype())

        if complete_graph:
            # Zero entries are dropped by to_sparse(); a tiny non-zero weight
            # keeps every electrode pair as an explicit edge.
            adj[adj == 0] = 1e-6

        self.adj = adj.to_sparse()

    def __call__(self, x: np.ndarray) -> "Data":
        r'''
        Args:
            x (np.ndarray): The input EEG signals in shape of [number of electrodes, number of data points].

        Returns:
            torch_geometric.data.Data: The graph representation data types that torch_geometric can accept. Nodes correspond to electrodes, and edges are determined via the given adjacency matrix.
        '''
        # Non-zero positions of the adjacency matrix become the edge list.
        data = Data(edge_index=self.adj._indices())
        # Node features: one row per electrode, cast to float32.
        data.x = torch.as_tensor(x).float()
        # Weights are stored in the same order as the columns of edge_index.
        data.edge_weight = self.adj._values()
        return data