Source code for torcheeg.transforms.numpy.shape

import numpy as np

from typing import List, Dict, Tuple
from scipy.interpolate import griddata


class PickElectrode:
    r'''
    Select parts of electrode signals based on a given electrode index list.

    .. code-block:: python

        transform = PickElectrode(PickElectrode.to_index_list(
            ['FP1', 'AF3', 'F3', 'F7', 'FC5', 'FC1', 'C3', 'T7', 'CP5', 'CP1',
             'P3', 'P7', 'PO3', 'O1', 'FP2', 'AF4', 'F4', 'F8', 'FC6', 'FC2',
             'C4', 'T8', 'CP6', 'CP2', 'P4', 'P8', 'PO4', 'O2'],
            DEAP_CHANNEL_LIST))
        transform(np.random.randn(32, 128)).shape
        >>> (28, 128)

    Args:
        pick_list (list): Selected electrode list. Should consist of integers representing the corresponding electrode indices. :obj:`to_index_list` can be used to obtain an index list when we only know the names of the electrode and not their indices.

    .. automethod:: __call__
    '''
    def __init__(self, pick_list: List[int]):
        self.pick_list = pick_list

    def __call__(self, x: np.ndarray) -> np.ndarray:
        r'''
        Args:
            x (np.ndarray): The input EEG signals in shape of [number of electrodes, number of data points].

        Returns:
            np.ndarray: The output signals with the shape of [number of picked electrodes, number of data points]. An empty :obj:`pick_list` yields an empty selection along the first axis.

        Raises:
            AssertionError: If the largest picked index is not smaller than the number of electrodes in :obj:`x`.
        '''
        # max() of an empty sequence raises ValueError, so only run the
        # bounds check when there is something to check.
        if len(self.pick_list) > 0:
            largest_index = max(self.pick_list)
            if not largest_index < x.shape[0]:
                # Raised explicitly (instead of a bare assert) so the check
                # still runs under `python -O`; the exception type is kept.
                raise AssertionError(
                    f'The index {largest_index} of the specified electrode is out of bounds {x.shape[0]}.'
                )
        return x[self.pick_list]

    @staticmethod
    def to_index_list(electrode_list: List[str],
                      dataset_electrode_list: List[str],
                      strict_mode=False) -> List[int]:
        r'''
        Args:
            electrode_list (list): picked electrode name, consisting of strings.
            dataset_electrode_list (list): The description of the electrode information contained in the EEG signal in the dataset, consisting of strings. For the electrode position information, please refer to constants grouped by dataset :obj:`datasets.constants`.
            strict_mode: (bool): Whether to use strict mode. In strict mode, unmatched picked electrode names are thrown as errors. Otherwise, unmatched picked electrode names are automatically ignored. (default: :obj:`False`)

        Returns:
            list: Selected electrode list, consisting of integers representing the corresponding electrode indices.

        Raises:
            KeyError: In strict mode, if a picked electrode name is not present in :obj:`dataset_electrode_list`.
        '''
        # Map electrode name -> position; on duplicate names the last
        # occurrence wins.
        dataset_electrode_dict = {
            name: index
            for index, name in enumerate(dataset_electrode_list)
        }
        if strict_mode:
            return [
                dataset_electrode_dict[electrode]
                for electrode in electrode_list
            ]
        return [
            dataset_electrode_dict[electrode] for electrode in electrode_list
            if electrode in dataset_electrode_dict
        ]

    def __repr__(self):
        return f"{self.__class__.__name__}()"
class To2d:
    r'''
    Taking the electrode index as the row index and the temporal index as the column index, a two-dimensional EEG signal representation with the size of [number of electrodes, number of data points] is formed. While PyTorch performs convolution on the 2d tensor, an additional channel dimension is required, thus we prepend an additional leading dimension.

    .. code-block:: python

        transform = To2d()
        transform(np.random.randn(32, 128)).shape
        >>> (1, 32, 128)

    .. automethod:: __call__
    '''
    def __call__(self, x: np.ndarray) -> np.ndarray:
        r'''
        Args:
            x (np.ndarray): The input EEG signals in shape of [number of electrodes, number of data points].

        Returns:
            np.ndarray: The transformed results with the shape of [1, number of electrodes, number of data points].
        '''
        # Insert a new leading axis of length 1.
        return x[np.newaxis, ...]

    def __repr__(self):
        # Matches the representation used by the other transforms in this module.
        return f"{self.__class__.__name__}()"
class ToGrid:
    r'''
    A transform method to project the EEG signals of different channels onto the grid according to the electrode positions to form a 3D EEG signal representation with the size of [number of data points, width of grid, height of grid]. For the electrode position information, please refer to constants grouped by dataset:

    - datasets.constants.emotion_recognition.deap.DEAP_CHANNEL_LOCATION_DICT
    - datasets.constants.emotion_recognition.dreamer.DREAMER_CHANNEL_LOCATION_DICT
    - datasets.constants.emotion_recognition.seed.SEED_CHANNEL_LOCATION_DICT
    - ...

    .. code-block:: python

        transform = ToGrid(DEAP_CHANNEL_LOCATION_DICT)
        transform(np.random.randn(32, 128)).shape
        >>> (128, 9, 9)

    The grid is at least 9 x 9 and grows automatically when an electrode location lies outside of that range. Grid cells without an electrode are filled with zeros.

    Args:
        channel_location (dict): Electrode location information. Represented in dictionary form, where :obj:`key` corresponds to the electrode name and :obj:`value` corresponds to the row index and column index of the electrode on the grid. The i-th entry of the dictionary is paired with the i-th row of the input signal.

    .. automethod:: __call__
    '''
    # Smallest grid edge ever produced; keeps the historical 9 x 9 output
    # for every location dictionary that already fitted into it.
    _MIN_GRID_SIZE = 9

    def __init__(self, channel_location: Dict[str, Tuple[int, int]]):
        self.channel_location = channel_location

    def _grid_shape(self) -> Tuple[int, int]:
        # Smallest (rows, cols) that holds every electrode, never below the minimum.
        n_rows = n_cols = self._MIN_GRID_SIZE
        for loc_x, loc_y in self.channel_location.values():
            n_rows = max(n_rows, loc_x + 1)
            n_cols = max(n_cols, loc_y + 1)
        return n_rows, n_cols

    def __call__(self, x: np.ndarray) -> np.ndarray:
        r'''
        Args:
            x (np.ndarray): The input EEG signals in shape of [number of electrodes, number of data points].

        Returns:
            np.ndarray: The projected results with the shape of [number of data points, width of grid, height of grid].
        '''
        n_rows, n_cols = self._grid_shape()
        # rows x cols x timestep, zero everywhere no electrode is placed
        outputs = np.zeros([n_rows, n_cols, x.shape[-1]])
        # The i-th location in the dictionary receives the i-th channel of x.
        for i, (loc_x, loc_y) in enumerate(self.channel_location.values()):
            outputs[loc_x][loc_y] = x[i]
        # Move the time axis to the front: timestep x rows x cols
        return outputs.transpose(2, 0, 1)

    def __repr__(self):
        return f"{self.__class__.__name__}()"
class ToInterpolatedGrid:
    r'''
    A transform method to project the EEG signals of different channels onto a regular grid according to the electrode positions, forming a 3D EEG signal representation with the size of [number of data points, width of grid, height of grid]. For the electrode position information, please refer to constants grouped by dataset:

    - datasets.constants.emotion_recognition.deap.DEAP_CHANNEL_LOCATION_DICT
    - datasets.constants.emotion_recognition.dreamer.DREAMER_CHANNEL_LOCATION_DICT
    - datasets.constants.emotion_recognition.seed.SEED_CHANNEL_LOCATION_DICT
    - ...

    .. code-block:: python

        transform = ToInterpolatedGrid(DEAP_CHANNEL_LOCATION_DICT)
        transform(np.random.randn(32, 128)).shape
        >>> (128, 9, 9)

    Grid values between electrodes are obtained with cubic interpolation; grid points for which the interpolation yields no value are set to 0. The grid always has 9 x 9 points, evenly spaced between the smallest and the largest electrode coordinate along each axis.

    Args:
        channel_location (dict): Electrode location information. Represented in dictionary form, where :obj:`key` corresponds to the electrode name and :obj:`value` corresponds to the row index and column index of the electrode on the grid.

    .. automethod:: __call__
    '''
    def __init__(self, channel_location: Dict[str, Tuple[int, int]]):
        self.channel_location = channel_location
        # One (row, column) pair per electrode, in dictionary order.
        self.location_array = np.array(list(channel_location.values()))

        row_coords = self.location_array[:, 0]
        col_coords = self.location_array[:, 1]
        # A complex step tells np.mgrid to produce that many points with
        # both end points included.
        self.grid_x, self.grid_y = np.mgrid[
            row_coords.min():row_coords.max():9j,
            col_coords.min():col_coords.max():9j]

    def __call__(self, x: np.ndarray) -> np.ndarray:
        r'''
        Args:
            x (np.ndarray): The input EEG signals in shape of [number of electrodes, number of data points].

        Returns:
            np.ndarray: The projected results with the shape of [number of data points, width of grid, height of grid].
        '''
        query_points = (self.grid_x, self.grid_y)
        # Iterate over time: every row of the transposed signal holds the
        # values of all electrodes at one data point.
        frames = [
            griddata(self.location_array,
                     electrode_values,
                     query_points,
                     method='cubic',
                     fill_value=0)
            for electrode_values in x.transpose(1, 0)
        ]
        return np.array(frames)

    def __repr__(self):
        return f"{self.__class__.__name__}()"