# Source code for torcheeg.datasets.module.emotion_recognition.dreamer

from typing import Callable, Dict, Tuple, Union

from ...constants.emotion_recognition.dreamer import (
    DREAMER_ADJACENCY_MATRIX, DREAMER_CHANNEL_LOCATION_DICT)
from ...functional.emotion_recognition.dreamer import dreamer_constructor
from ..base_dataset import BaseDataset


class DREAMERDataset(BaseDataset):
    r'''
    A multi-modal database consisting of electroencephalogram and electrocardiogram signals recorded during affect elicitation by means of audio-visual stimuli. This class generates training samples and test samples according to the given parameters, and caches the generated results in a unified input and output format (IO). The relevant information of the dataset is as follows:

    - Author: Katsigiannis et al.
    - Year: 2017
    - Download URL: https://zenodo.org/record/546113
    - Reference: Katsigiannis S, Ramzan N. DREAMER: A database for emotion recognition through EEG and ECG signals from wireless low-cost off-the-shelf devices[J]. IEEE journal of biomedical and health informatics, 2017, 22(1): 98-107.
    - Stimulus: 18 movie clips.
    - Signals: Electroencephalogram (14 channels at 128Hz), and electrocardiogram (2 channels at 256Hz) of 23 subjects.
    - Rating: Arousal, valence, like/dislike, dominance, familiarity (all on a scale from 1 to 5).

    In order to use this dataset, the download file :obj:`DREAMER.mat` is required.

    Each element of the dataset is a 2-tuple :obj:`(signal, label)`. The baseline chunk that belongs to a sample is read and handed to :obj:`online_transform` (so that the transform can use it), but it is not returned as a separate element.

    An example dataset for CNN-based methods:

    .. code-block:: python

        dataset = DREAMERDataset(io_path='./dreamer',
                                 mat_path='./DREAMER.mat',
                                 offline_transform=transforms.Compose([
                                     transforms.BandDifferentialEntropy(),
                                     transforms.ToGrid(DREAMER_CHANNEL_LOCATION_DICT)
                                 ]),
                                 online_transform=transforms.ToTensor(),
                                 label_transform=transforms.Compose([
                                     transforms.Select('valence'),
                                     transforms.Binary(3.0),
                                 ]))
        print(dataset[0])
        # EEG signal (torch.Tensor[128, 9, 9]),
        # label (int)

    Another example dataset for CNN-based methods:

    .. code-block:: python

        dataset = DREAMERDataset(io_path='./dreamer',
                                 mat_path='./DREAMER.mat',
                                 online_transform=transforms.Compose([
                                     transforms.To2d(),
                                     transforms.ToTensor()
                                 ]),
                                 label_transform=transforms.Compose([
                                     transforms.Select(['valence', 'arousal']),
                                     transforms.Binary(3.0),
                                     transforms.BinariesToCategory()
                                 ]))
        print(dataset[0])
        # EEG signal (torch.Tensor[1, 14, 128]),
        # label (int)

    An example dataset for GNN-based methods:

    .. code-block:: python

        dataset = DREAMERDataset(io_path='./dreamer',
                                 mat_path='./DREAMER.mat',
                                 online_transform=transforms.Compose([
                                     transforms.pyg.ToG(DREAMER_ADJACENCY_MATRIX)
                                 ]),
                                 label_transform=transforms.Compose([
                                     transforms.Select('arousal'),
                                     transforms.Binary(3.0)
                                 ]))
        print(dataset[0])
        # EEG signal (torch_geometric.data.Data),
        # label (int)

    In particular, TorchEEG utilizes the producer-consumer model to allow multi-process data preprocessing. If your data preprocessing is time consuming, consider increasing :obj:`num_worker` for higher speedup. If running under Windows, please use the proper idiom in the main module:

    .. code-block:: python

        if __name__ == '__main__':
            dataset = DREAMERDataset(io_path='./dreamer',
                                     mat_path='./DREAMER.mat',
                                     online_transform=transforms.Compose([
                                         transforms.pyg.ToG(DREAMER_ADJACENCY_MATRIX)
                                     ]),
                                     label_transform=transforms.Compose([
                                         transforms.Select('arousal'),
                                         transforms.Binary(3.0)
                                     ]),
                                     num_worker=4)
            print(dataset[0])
            # EEG signal (torch_geometric.data.Data),
            # label (int)

    Args:
        mat_path (str): Downloaded data files in pickled matlab formats (default: :obj:`'./DREAMER.mat'`)
        chunk_size (int): Number of data points included in each EEG chunk as training or test samples. (default: :obj:`128`)
        overlap (int): The number of overlapping data points between different chunks when dividing EEG chunks. (default: :obj:`0`)
        num_channel (int): Number of channels used, of which the first 14 channels are EEG signals. (default: :obj:`14`)
        num_baseline (int): Number of baseline signal chunks used. (default: :obj:`61`)
        baseline_chunk_size (int): Number of data points included in each baseline signal chunk. The baseline signal in the DREAMER dataset has a total of 7808 data points. (default: :obj:`128`)
        online_transform (Callable, optional): The transformation applied every time an element is accessed. It is called with the keyword arguments :obj:`eeg` (the EEG chunk) and :obj:`baseline` (the corresponding baseline chunk), and the :obj:`'eeg'` entry of its output is used as the first value of each element in the dataset. If omitted, the cached EEG chunk is returned unchanged. (default: :obj:`None`)
        offline_transform (Callable, optional): The usage is the same as :obj:`online_transform`, but executed before generating IO intermediate results. (default: :obj:`None`)
        label_transform (Callable, optional): The transformation of the label. It is called with the information dictionary of the sample as keyword argument :obj:`y`, and the :obj:`'y'` entry of its output is used as the second value of each element in the dataset. If omitted, the information dictionary itself is returned. (default: :obj:`None`)
        io_path (str): The path to generated unified data IO, cached as an intermediate result. (default: :obj:`./io/dreamer`)
        num_worker (int): How many subprocesses to use for data processing. (default: :obj:`0`)
        verbose (bool): Whether to display logs during processing, such as progress bars, etc. (default: :obj:`True`)
        cache_size (int): Maximum size database may grow to; used to size the memory mapping. If database grows larger than :obj:`cache_size`, an exception will be raised and the user must close and reopen. (default: :obj:`64 * 1024 * 1024 * 1024`)
    '''
    # Dataset-level constants re-exported on the class for convenient access.
    channel_location_dict = DREAMER_CHANNEL_LOCATION_DICT
    adjacency_matrix = DREAMER_ADJACENCY_MATRIX

    def __init__(self,
                 mat_path: str = './DREAMER.mat',
                 chunk_size: int = 128,
                 overlap: int = 0,
                 num_channel: int = 14,
                 num_baseline: int = 61,
                 baseline_chunk_size: int = 128,
                 online_transform: Union[None, Callable] = None,
                 offline_transform: Union[None, Callable] = None,
                 label_transform: Union[None, Callable] = None,
                 io_path: str = './io/dreamer',
                 num_worker: int = 0,
                 verbose: bool = True,
                 cache_size: int = 64 * 1024 * 1024 * 1024):
        # The constructor must run before the base class is initialised with
        # ``io_path``: it is the step that produces the IO results stored there.
        # Only the offline transform is forwarded; online and label transforms
        # are applied lazily in ``__getitem__``.
        dreamer_constructor(mat_path=mat_path,
                            chunk_size=chunk_size,
                            overlap=overlap,
                            num_channel=num_channel,
                            num_baseline=num_baseline,
                            baseline_chunk_size=baseline_chunk_size,
                            transform=offline_transform,
                            io_path=io_path,
                            num_worker=num_worker,
                            verbose=verbose,
                            cache_size=cache_size)
        super().__init__(io_path)

        # Keep every configuration value so that ``repr_body`` can report it.
        self.mat_path = mat_path
        self.chunk_size = chunk_size
        self.overlap = overlap
        self.num_channel = num_channel
        self.num_baseline = num_baseline
        self.baseline_chunk_size = baseline_chunk_size
        self.online_transform = online_transform
        self.offline_transform = offline_transform
        self.label_transform = label_transform
        self.num_worker = num_worker
        self.verbose = verbose
        self.cache_size = cache_size

    def __getitem__(self, index: int) -> Tuple:
        '''
        Return the :obj:`(signal, label)` pair stored at position :obj:`index`.

        Args:
            index (int): Positional row index into :obj:`self.info`.

        Returns:
            Tuple: The (optionally transformed) EEG chunk and the (optionally transformed) label. Without transforms these are the cached EEG chunk and the information dictionary of the sample.
        '''
        info = self.info.iloc[index].to_dict()

        # EEG chunk and baseline chunk are both looked up by their string ids.
        eeg_index = str(info['clip_id'])
        eeg = self.eeg_io.read_eeg(eeg_index)

        baseline_index = str(info['baseline_id'])
        baseline = self.eeg_io.read_eeg(baseline_index)

        signal = eeg
        label = info

        if self.online_transform:
            # The baseline is only an input of the transform; solely the
            # transformed 'eeg' entry is kept.
            signal = self.online_transform(eeg=eeg, baseline=baseline)['eeg']

        if self.label_transform:
            label = self.label_transform(y=info)['y']

        return signal, label

    @property
    def repr_body(self) -> Dict:
        '''Configuration shown in the representation: the base class entries extended by this dataset's own parameters.'''
        return dict(
            super().repr_body, **{
                'mat_path': self.mat_path,
                'chunk_size': self.chunk_size,
                'overlap': self.overlap,
                'num_channel': self.num_channel,
                'num_baseline': self.num_baseline,
                'baseline_chunk_size': self.baseline_chunk_size,
                'online_transform': self.online_transform,
                'offline_transform': self.offline_transform,
                'label_transform': self.label_transform,
                'num_worker': self.num_worker,
                'verbose': self.verbose,
                'cache_size': self.cache_size
            })