# Source code for torcheeg.datasets.module.emotion_recognition.dreamer

from typing import Callable, Union, Tuple

from ..base_dataset import BaseDataset
from ...functional.emotion_recognition.dreamer import dreamer_constructor


class DREAMERDataset(BaseDataset):
    r'''
    A multi-modal database consisting of electroencephalogram and
    electrocardiogram signals recorded during affect elicitation by means of
    audio-visual stimuli. This class generates training samples and test
    samples according to the given parameters, and caches the generated
    results in a unified input and output format (IO). The relevant
    information of the dataset is as follows:

    - Author: Katsigiannis et al.
    - Year: 2017
    - Download URL: https://zenodo.org/record/546113
    - Reference: Katsigiannis S, Ramzan N. DREAMER: A database for emotion
      recognition through EEG and ECG signals from wireless low-cost
      off-the-shelf devices[J]. IEEE journal of biomedical and health
      informatics, 2017, 22(1): 98-107.
    - Stimulus: 18 movie clips.
    - Signals: Electroencephalogram (14 channels at 128Hz), and
      electrocardiogram (2 channels at 256Hz) of 23 subjects.
    - Rating: Arousal, valence, like/dislike, dominance, familiarity (all on
      a scale from 1 to 5).

    An example dataset for CNN-based methods:

    .. code-block:: python

        dataset = DREAMERDataset(io_path='./dreamer',
                                 mat_path='./DREAMER.mat',
                                 offline_transform=transforms.Compose([
                                     transforms.BandDifferentialEntropy(),
                                     transforms.ToGrid(DREAMER_CHANNEL_LOCATION_DICT)
                                 ]),
                                 online_transform=transforms.ToTensor(),
                                 label_transform=transforms.Compose([
                                     transforms.Select('valence'),
                                     transforms.Binary(3.0),
                                 ]))
        print(dataset[0])
        # EEG signal (torch.Tensor[128, 9, 9]),
        # corresponding baseline signal (torch.Tensor[128, 9, 9]),
        # label (int)

    Another example dataset for CNN-based methods:

    .. code-block:: python

        dataset = DREAMERDataset(io_path='./dreamer',
                                 mat_path='./DREAMER.mat',
                                 online_transform=transforms.Compose([
                                     transforms.ToTensor(),
                                     transforms.Lambda(lambda x: x.unsqueeze(0))
                                 ]),
                                 label_transform=transforms.Compose([
                                     transforms.Select(['valence', 'arousal']),
                                     transforms.Binary(3.0),
                                     transforms.BinariesToCategory()
                                 ]))
        print(dataset[0])
        # EEG signal (torch.Tensor[1, 14, 128]),
        # corresponding baseline signal (torch.Tensor[1, 14, 128]),
        # label (int)

    An example dataset for GNN-based methods:

    .. code-block:: python

        dataset = DREAMERDataset(io_path='./dreamer',
                                 mat_path='./DREAMER.mat',
                                 online_transform=transforms.Compose([
                                     transforms.ToG(DREAMER_ADJACENCY_MATRIX)
                                 ]),
                                 label_transform=transforms.Compose([
                                     transforms.Select('arousal'),
                                     transforms.Binary(3.0)
                                 ]))
        print(dataset[0])
        # EEG signal (torch_geometric.data.Data),
        # corresponding baseline signal (torch_geometric.data.Data),
        # label (int)

    In particular, TorchEEG utilizes the producer-consumer model to allow
    multi-process data preprocessing. If your data preprocessing is time
    consuming, consider increasing :obj:`num_worker` for higher speedup.

    Args:
        mat_path (str): Path to the downloaded data file in MATLAB (``.mat``)
            format. (default: :obj:`'./DREAMER.mat'`)
        chunk_size (int): Number of data points included in each EEG chunk
            as training or test samples. (default: :obj:`128`)
        overlap (int): The number of overlapping data points between
            different chunks when dividing EEG chunks. (default: :obj:`0`)
        channel_num (int): Number of channels used, of which the first 14
            channels are EEG signals. (default: :obj:`14`)
        baseline_num (int): Number of baseline signal chunks used.
            (default: :obj:`61`)
        baseline_chunk_size (int): Number of data points included in each
            baseline signal chunk. The baseline signal in the DREAMER dataset
            has a total of 7808 data points. (default: :obj:`128`)
        online_transform (Callable, optional): The transformation of the EEG
            signals and baseline EEG signals. The input is a
            :obj:`np.ndarray`, and the output is used as the first and second
            value of each element in the dataset. (default: :obj:`None`)
        offline_transform (Callable, optional): The usage is the same as
            :obj:`online_transform`, but executed before generating IO
            intermediate results. (default: :obj:`None`)
        label_transform (Callable, optional): The transformation of the
            label. The input is an information dictionary, and the output is
            used as the third value of each element in the dataset.
            (default: :obj:`None`)
        io_path (str): The path to generated unified data IO, cached as an
            intermediate result. (default: :obj:`./io/dreamer`)
        num_worker (int): How many subprocesses to use for data processing.
            (default: :obj:`1`)
        verbose (bool): Whether to display logs during processing, such as
            progress bars, etc. (default: :obj:`True`)
    '''
    def __init__(self,
                 mat_path: str = './DREAMER.mat',
                 chunk_size: int = 128,
                 overlap: int = 0,
                 channel_num: int = 14,
                 baseline_num: int = 61,
                 baseline_chunk_size: int = 128,
                 online_transform: Union[None, Callable] = None,
                 offline_transform: Union[None, Callable] = None,
                 label_transform: Union[None, Callable] = None,
                 io_path: str = './io/dreamer',
                 num_worker: int = 1,
                 verbose: bool = True):
        # Build (or reuse) the cached IO at ``io_path`` before the base class
        # opens it; ``offline_transform`` is applied at this stage only.
        dreamer_constructor(mat_path=mat_path,
                            chunk_size=chunk_size,
                            overlap=overlap,
                            channel_num=channel_num,
                            baseline_num=baseline_num,
                            baseline_chunk_size=baseline_chunk_size,
                            transform=offline_transform,
                            io_path=io_path,
                            num_worker=num_worker,
                            verbose=verbose)
        super().__init__(io_path)

        # Keep every constructor argument so the configuration can be
        # inspected (or the dataset re-created) later.
        self.mat_path = mat_path
        self.chunk_size = chunk_size
        self.overlap = overlap
        self.channel_num = channel_num
        self.baseline_num = baseline_num
        self.baseline_chunk_size = baseline_chunk_size
        self.online_transform = online_transform
        self.offline_transform = offline_transform
        self.label_transform = label_transform
        self.io_path = io_path
        self.num_worker = num_worker
        self.verbose = verbose

    def __getitem__(self, index: int) -> Tuple:
        '''
        Return the ``index``-th sample.

        Args:
            index (int): Row position in ``self.info``.

        Returns:
            tuple: ``(eeg, baseline, *labels)``. If ``label_transform``
            returns a list, its items are appended; if it returns a dict
            (or no ``label_transform`` is set, so the raw info dict is
            used), its values are appended; otherwise the single returned
            object is appended.
        '''
        # NOTE(review): ``self.info`` (a pandas DataFrame, judging by
        # ``.iloc``) and ``self.eeg_io`` are presumably set up by
        # ``BaseDataset.__init__`` — confirm.
        info = self.info.iloc[index].to_dict()

        eeg = self._read_transformed(info['clip_id'])
        baseline = self._read_transformed(info['baseline_id'])

        if self.label_transform is not None:
            info = self.label_transform(info)

        if isinstance(info, list):
            return (eeg, baseline, *info)
        if isinstance(info, dict):
            return (eeg, baseline, *info.values())
        return eeg, baseline, info

    def _read_transformed(self, record_id):
        # Read one cached record from the IO store and apply online_transform.
        # The IO store is keyed by strings, hence the explicit ``str``.
        signal = self.eeg_io.read_eeg(str(record_id))
        if self.online_transform is not None:
            signal = self.online_transform(signal)
        return signal