Shortcuts

Source code for torcheeg.transforms.hooks.after_hook

import numpy as np
import torch

from typing import List, Union


[docs]def after_hook_normalize( data: List[Union[np.ndarray, torch.Tensor]], eps: float = 1e-6) -> List[Union[np.ndarray, torch.Tensor]]: r''' A common hook function used to normalize the signal of the whole trial/session/subject after dividing it into chunks and transforming the divided chunks. It is used as follows: .. code-block:: python from torcheeg.datasets import DEAPDataset from torcheeg.transforms import after_hook_normalize dataset = DEAPDataset(root_path='./data_preprocessed_python', offline_transform=transforms.Compose([ transforms.BandDifferentialEntropy(), transforms.ToGrid(DEAP_CHANNEL_LOCATION_DICT) ]), online_transform=transforms.ToTensor(), after_trial=after_hook_normalize, num_worker=4, label_transform=transforms.Compose([ transforms.Select('valence'), transforms.Binary(5.0), ])) If you want to pass in parameters, use partial to generate a new function: .. code-block:: python from functools import partial from torcheeg.datasets import DEAPDataset from torcheeg.transforms import after_hook_normalize DEAPDataset(root_path='./data_preprocessed_python', offline_transform=transforms.Compose([ transforms.BandDifferentialEntropy(), transforms.ToGrid(DEAP_CHANNEL_LOCATION_DICT) ]), online_transform=transforms.ToTensor(), after_trial=partial(after_hook_normalize, eps=1e-5), num_worker=4, label_transform=transforms.Compose([ transforms.Select('valence'), transforms.Binary(5.0), ])) Args: data (list): A list of :obj:`np.ndarray` or :obj:`torch.Tensor`, one of which corresponds to an EEG signal in trial. eps (float): The term added to the denominator to improve numerical stability (default: :obj:`1e-6`) Returns: list: The normalized results of a trial. It is a list of :obj:`np.ndarray` or :obj:`torch.Tensor`, one of which corresponds to an EEG signal in trial. ''' if isinstance(data[0], np.ndarray): data = np.stack(data, axis=0) min_v = data.min(axis=0, keepdims=True) max_v = data.max(axis=0, keepdims=True) data = (data - min_v) / (max_v - min_v + eps) return [sample for sample in data] elif isinstance(data[0], torch.Tensor): data = torch.stack(data, dim=0) min_v, _ = data.min(axis=0, keepdims=True) max_v, _ = data.max(axis=0, keepdims=True) data = (data - min_v) / (max_v - min_v + eps) return [sample for sample in data] else: raise ValueError( 'The after_hook_normalize only supports np.ndarray and torch.Tensor. Please make sure the outputs of offline_transform ({}) are np.ndarray or torch.Tensor.' .format(type(data[0])))
[docs]def after_hook_running_norm( data: List[Union[np.ndarray, torch.Tensor]], decay_rate: float = 0.9, eps: float = 1e-6) -> List[Union[np.ndarray, torch.Tensor]]: r''' A common hook function used to normalize the signal of the whole trial/session/subject after dividing it into chunks and transforming the divided chunks. It is used as follows: .. code-block:: python from torcheeg.datasets import DEAPDataset from torcheeg.transforms import after_hook_running_norm from torcheeg.datasets.constants import DEAP_CHANNEL_LOCATION_DICT dataset = DEAPDataset(root_path='./data_preprocessed_python', offline_transform=transforms.Compose([ transforms.BandDifferentialEntropy(), transforms.ToGrid(DEAP_CHANNEL_LOCATION_DICT) ]), online_transform=transforms.ToTensor(), after_trial=after_hook_running_norm, label_transform=transforms.Compose([ transforms.Select('valence'), transforms.Binary(5.0), ])) If you want to pass in parameters, use partial to generate a new function: .. code-block:: python from functools import partial from torcheeg.datasets import DEAPDataset from torcheeg.transforms import after_hook_running_norm dataset = DEAPDataset(root_path='./data_preprocessed_python', offline_transform=transforms.Compose([ transforms.BandDifferentialEntropy(), transforms.ToGrid(DEAP_CHANNEL_LOCATION_DICT) ]), online_transform=transforms.ToTensor(), after_trial=partial(after_hook_running_norm, decay_rate=0.9, eps=1e-6), label_transform=transforms.Compose([ transforms.Select('valence'), transforms.Binary(5.0), ])) Args: data (list): A list of :obj:`np.ndarray` or :obj:`torch.Tensor`, one of which corresponds to an EEG signal in trial. decay_rate (float): The decay rate used in the running normalization (default: :obj:`0.9`) eps (float): The term added to the denominator to improve numerical stability (default: :obj:`1e-6`) Returns: list: The normalized results of a trial. It is a list of :obj:`np.ndarray` or :obj:`torch.Tensor`, one of which corresponds to an EEG signal in trial. ''' if isinstance(data[0], np.ndarray): data = np.stack(data, axis=0) running_mean = np.zeros_like(data[0]) running_var = np.zeros_like(data[0]) for i, current_sample in enumerate(data): running_mean = decay_rate * running_mean + ( 1 - decay_rate) * current_sample running_var = decay_rate * running_var + ( 1 - decay_rate) * np.square(current_sample - running_mean) data[i] = (data[i] - running_mean) / np.sqrt(running_var + eps) return [sample for sample in data] elif isinstance(data[0], torch.Tensor): data = torch.stack(data, dim=0) running_mean = torch.zeros_like(data[0]) running_var = torch.zeros_like(data[0]) for i, current_sample in enumerate(data): running_mean = decay_rate * running_mean + ( 1 - decay_rate) * current_sample running_var = decay_rate * running_var + ( 1 - decay_rate) * torch.square(current_sample - running_mean) data[i] = (data[i] - running_mean) / torch.sqrt(running_var + eps) return [sample for sample in data] else: raise ValueError( 'The after_hook_running_norm only supports np.ndarray and torch.Tensor. Please make sure the outputs of offline_transform ({}) are np.ndarray or torch.Tensor.' .format(type(data[0])))
[docs]def after_hook_linear_dynamical_system( data: List[Union[np.ndarray, torch.Tensor]], V0: float = 0.01, A: float = 1, T: float = 0.0001, C: float = 1, sigma: float = 1) -> List[Union[np.ndarray, torch.Tensor]]: r''' A common hook function used to normalize the signal of the whole trial/session/subject after dividing it into chunks and transforming the divided chunks. It is used as follows: .. code-block:: python from torcheeg.datasets import DEAPDataset from torcheeg.transforms import after_hook_linear_dynamical_system dataset = DEAPDataset(root_path='./data_preprocessed_python', offline_transform=transforms.Compose([ transforms.BandDifferentialEntropy(), transforms.ToGrid(DEAP_CHANNEL_LOCATION_DICT) ]), online_transform=transforms.ToTensor(), after_trial=after_hook_linear_dynamical_system, label_transform=transforms.Compose([ transforms.Select('valence'), transforms.Binary(5.0), ])) If you want to pass in parameters, use partial to generate a new function: .. code-block:: python from functools import partial from torcheeg.datasets import DEAPDataset from torcheeg.transforms import after_hook_linear_dynamical_system dataset = DEAPDataset(root_path='./data_preprocessed_python', offline_transform=transforms.Compose([ transforms.BandDifferentialEntropy(), transforms.ToGrid(DEAP_CHANNEL_LOCATION_DICT) ]), online_transform=transforms.ToTensor(), after_trial=partial(after_hook_linear_dynamical_system, V0=0.01, A=1, T=0.0001, C=1, sigma=1), label_transform=transforms.Compose([ transforms.Select('valence'), transforms.Binary(5.0), ])) Args: data (list): A list of :obj:`np.ndarray` or :obj:`torch.Tensor`, one of which corresponds to an EEG signal in trial. V0 (float): The initial variance of the linear dynamical system (default: :obj:`0.01`) A (float): The coefficient of the linear dynamical system (default: :obj:`1`) T (float): The term added to the diagonal of the covariance matrix (default: :obj:`0.0001`) C (float): The coefficient of the linear dynamical system (default: :obj:`1`) sigma (float): The variance of the linear dynamical system (default: :obj:`1`) Returns: list: The normalized results of a trial. It is a list of :obj:`np.ndarray` or :obj:`torch.Tensor`, one of which corresponds to an EEG signal in trial. ''' if isinstance(data[0], np.ndarray): # save the data[0].shape and flatten them shape = data[0].shape data = np.stack([sample.flatten() for sample in data], axis=0) ave = np.mean(data, axis=0) u0 = ave X = data.transpose((1, 0)) [m, n] = X.shape P = np.zeros((m, n)) u = np.zeros((m, n)) V = np.zeros((m, n)) K = np.zeros((m, n)) K[:, 0] = (V0 * C / (C * V0 * C + sigma)) * np.ones((m, )) u[:, 0] = u0 + K[:, 0] * (X[:, 0] - C * u0) V[:, 0] = (np.ones((m, )) - K[:, 0] * C) * V0 for i in range(1, n): P[:, i - 1] = A * V[:, i - 1] * A + T K[:, i] = P[:, i - 1] * C / (C * P[:, i - 1] * C + sigma) u[:, i] = A * u[:, i - 1] + K[:, i] * (X[:, i] - C * A * u[:, i - 1]) V[:, i] = (np.ones((m, )) - K[:, i] * C) * P[:, i - 1] X = u return [sample.reshape(shape) for sample in X.transpose((1, 0))] elif isinstance(data[0], torch.Tensor): shape = data[0].shape data = torch.stack([sample.flatten() for sample in data], dim=0) ave = torch.mean(data, dim=0) u0 = ave X = data.transpose(1, 0) [m, n] = X.shape P = torch.zeros((m, n)) u = torch.zeros((m, n)) V = torch.zeros((m, n)) K = torch.zeros((m, n)) K[:, 0] = (V0 * C / (C * V0 * C + sigma)) * torch.ones((m, )) u[:, 0] = u0 + K[:, 0] * (X[:, 0] - C * u0) V[:, 0] = (torch.ones((m, )) - K[:, 0] * C) * V0 for i in range(1, n): P[:, i - 1] = A * V[:, i - 1] * A + T K[:, i] = P[:, i - 1] * C / (C * P[:, i - 1] * C + sigma) u[:, i] = A * u[:, i - 1] + K[:, i] * (X[:, i] - C * A * u[:, i - 1]) V[:, i] = (torch.ones((m, )) - K[:, i] * C) * P[:, i - 1] X = u return [sample.reshape(shape) for sample in X.transpose(1, 0)] else: raise ValueError( 'The after_hook_linear_dynamical_system only supports np.ndarray and torch.Tensor. Please make sure the outputs of offline_transform ({}) are np.ndarray or torch.Tensor.' .format(type(data[0])))
Read the Docs v: latest
Versions
latest
stable
v1.1.2
v1.1.1
v1.1.0
v1.0.11
v1.0.10
v1.0.9
v1.0.8.post1
v1.0.8
v1.0.7
v1.0.6
v1.0.4
v1.0.3
v1.0.2
v1.0.1
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources