# Source code for torcheeg.datasets.module.emotion_recognition.dreamer
import os
from typing import Callable, Dict, Tuple, Union, Any
import scipy.io as scio
from ..base_dataset import BaseDataset
from ....utils import get_random_dir_path
class DREAMERDataset(BaseDataset):
    r'''
    A multi-modal database consisting of electroencephalogram and electrocardiogram signals recorded during affect elicitation by means of audio-visual stimuli. This class generates training samples and test samples according to the given parameters, and caches the generated results in a unified input and output format (IO). The relevant information of the dataset is as follows:

    - Author: Katsigiannis et al.
    - Year: 2017
    - Download URL: https://zenodo.org/record/546113
    - Reference: Katsigiannis S, Ramzan N. DREAMER: A database for emotion recognition through EEG and ECG signals from wireless low-cost off-the-shelf devices[J]. IEEE journal of biomedical and health informatics, 2017, 22(1): 98-107.
    - Stimulus: 18 movie clips.
    - Signals: Electroencephalogram (14 channels at 128Hz), and electrocardiogram (2 channels at 256Hz) of 23 subjects.
    - Rating: Arousal, valence, like/dislike, dominance, familiarity (all on a scale from 1 to 5). Only valence, arousal and dominance are stored in the meta information of each sample, and can therefore be selected as labels.

    In order to use this dataset, the download file :obj:`DREAMER.mat` is required.

    An example dataset for CNN-based methods:

    .. code-block:: python

        from torcheeg.datasets import DREAMERDataset
        from torcheeg import transforms
        from torcheeg.datasets.constants import DREAMER_CHANNEL_LOCATION_DICT

        dataset = DREAMERDataset(mat_path='./DREAMER.mat',
                                 offline_transform=transforms.Compose([
                                     transforms.BandDifferentialEntropy(),
                                     transforms.ToGrid(DREAMER_CHANNEL_LOCATION_DICT)
                                 ]),
                                 online_transform=transforms.ToTensor(),
                                 label_transform=transforms.Compose([
                                     transforms.Select('valence'),
                                     transforms.Binary(3.0),
                                 ]))
        print(dataset[0])
        # EEG signal (torch.Tensor[4, 9, 9]),
        # label (int)

    Another example dataset for CNN-based methods:

    .. code-block:: python

        from torcheeg.datasets import DREAMERDataset
        from torcheeg import transforms

        dataset = DREAMERDataset(mat_path='./DREAMER.mat',
                                 online_transform=transforms.Compose([
                                     transforms.To2d(),
                                     transforms.ToTensor()
                                 ]),
                                 label_transform=transforms.Compose([
                                     transforms.Select(['valence', 'arousal']),
                                     transforms.Binary(3.0),
                                     transforms.BinariesToCategory()
                                 ]))
        print(dataset[0])
        # EEG signal (torch.Tensor[1, 14, 128]),
        # label (int)

    An example dataset for GNN-based methods:

    .. code-block:: python

        from torcheeg.datasets import DREAMERDataset
        from torcheeg import transforms
        from torcheeg.datasets.constants import DREAMER_ADJACENCY_MATRIX
        from torcheeg.transforms.pyg import ToG

        dataset = DREAMERDataset(mat_path='./DREAMER.mat',
                                 online_transform=transforms.Compose([
                                     ToG(DREAMER_ADJACENCY_MATRIX)
                                 ]),
                                 label_transform=transforms.Compose([
                                     transforms.Select('arousal'),
                                     transforms.Binary(3.0)
                                 ]))
        print(dataset[0])
        # EEG signal (torch_geometric.data.Data),
        # label (int)

    Each element of the dataset is a 2-tuple :obj:`(signal, label)`. The per-trial baseline signal is not returned directly; it is only passed to :obj:`online_transform` (and :obj:`offline_transform`) as the :obj:`baseline` argument.

    Args:
        mat_path (str): Downloaded data files in pickled matlab formats (default: :obj:`'./DREAMER.mat'`)
        chunk_size (int): Number of data points included in each EEG chunk as training or test samples. If set to -1, the EEG signal of a trial is used as a sample of a chunk. (default: :obj:`128`)
        overlap (int): The number of overlapping data points between different chunks when dividing EEG chunks. Must be smaller than the chunk size. (default: :obj:`0`)
        num_channel (int): Number of channels used, of which the first 14 channels are EEG signals. (default: :obj:`14`)
        num_baseline (int): Number of baseline signal chunks used. The chunks are averaged into a single baseline signal per trial. (default: :obj:`61`)
        baseline_chunk_size (int): Number of data points included in each baseline signal chunk. The baseline signal in the DREAMER dataset has a total of 7808 data points. (default: :obj:`128`)
        online_transform (Callable, optional): The transformation of the EEG signals and baseline EEG signals. It is called with the keyword arguments :obj:`eeg` and :obj:`baseline` (both :obj:`np.ndarray`), and the :obj:`eeg` entry of its output is used as the first value of each element in the dataset. (default: :obj:`None`)
        offline_transform (Callable, optional): The usage is the same as :obj:`online_transform`, but executed before generating IO intermediate results. Both the :obj:`eeg` and the :obj:`baseline` entries of its output are cached. (default: :obj:`None`)
        label_transform (Callable, optional): The transformation of the label. It is called with the information dictionary as the keyword argument :obj:`y`, and the :obj:`y` entry of its output is used as the second value of each element in the dataset. (default: :obj:`None`)
        before_trial (Callable, optional): The hook performed on the trial to which the sample belongs. It is performed before the offline transformation and thus typically used to implement context-dependent sample transformations, such as moving averages, etc. The input of this hook function is a 2D EEG signal with shape (number of electrodes, number of data points), whose ideal output shape is also (number of electrodes, number of data points). (default: :obj:`None`)
        after_trial (Callable, optional): The hook performed on the trial to which the sample belongs. It is performed after the offline transformation and thus typically used to implement context-dependent sample transformations, such as moving averages, etc. The input and output of this hook function should be a sequence of dictionaries representing a sequence of EEG samples. Each dictionary contains two key-value pairs, indexed by :obj:`eeg` (the EEG signal matrix) and :obj:`key` (the index in the database) respectively. (default: :obj:`None`)
        after_session (Callable, optional): Hook forwarded unchanged to :class:`BaseDataset`; see its documentation for when it is invoked. (default: :obj:`None`)
        after_subject (Callable, optional): Hook forwarded unchanged to :class:`BaseDataset`; see its documentation for when it is invoked. (default: :obj:`None`)
        io_path (str): The path to generated unified data IO, cached as an intermediate result. If set to None, a random path will be generated. (default: :obj:`None`)
        io_size (int): Maximum size database may grow to; used to size the memory mapping. If database grows larger than ``io_size``, an exception will be raised and the user must close and reopen. (default: :obj:`1048576`)
        io_mode (str): Storage mode of EEG signal. When io_mode is set to :obj:`lmdb`, TorchEEG provides an efficient database (LMDB) for storing EEG signals. LMDB may not perform well on limited operating systems, where a file system based EEG signal storage is also provided. When io_mode is set to :obj:`pickle`, pickle-based persistence files are used. When io_mode is set to :obj:`memory`, memory is used. (default: :obj:`lmdb`)
        num_worker (int): Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process. (default: :obj:`0`)
        verbose (bool): Whether to display logs during processing, such as progress bars, etc. (default: :obj:`True`)
    '''

    def __init__(self,
                 mat_path: str = './DREAMER.mat',
                 chunk_size: int = 128,
                 overlap: int = 0,
                 num_channel: int = 14,
                 num_baseline: int = 61,
                 baseline_chunk_size: int = 128,
                 online_transform: Union[None, Callable] = None,
                 offline_transform: Union[None, Callable] = None,
                 label_transform: Union[None, Callable] = None,
                 before_trial: Union[None, Callable] = None,
                 after_trial: Union[Callable, None] = None,
                 after_session: Union[Callable, None] = None,
                 after_subject: Union[Callable, None] = None,
                 io_path: Union[None, str] = None,
                 io_size: int = 1048576,
                 io_mode: str = 'lmdb',
                 num_worker: int = 0,
                 verbose: bool = True):
        if io_path is None:
            io_path = get_random_dir_path(dir_prefix='datasets')

        # pass all arguments to super class
        params = {
            'mat_path': mat_path,
            'chunk_size': chunk_size,
            'overlap': overlap,
            'num_channel': num_channel,
            'num_baseline': num_baseline,
            'baseline_chunk_size': baseline_chunk_size,
            'online_transform': online_transform,
            'offline_transform': offline_transform,
            'label_transform': label_transform,
            'before_trial': before_trial,
            'after_trial': after_trial,
            'after_session': after_session,
            'after_subject': after_subject,
            'io_path': io_path,
            'io_size': io_size,
            'io_mode': io_mode,
            'num_worker': num_worker,
            'verbose': verbose
        }
        super().__init__(**params)
        # save all arguments to __dict__ so they are reachable as attributes
        # (e.g. self.online_transform in __getitem__, self.mat_path in repr_body)
        self.__dict__.update(params)

    @staticmethod
    def process_record(file: Any = None,
                       mat_path: str = './DREAMER.mat',
                       chunk_size: int = 128,
                       overlap: int = 0,
                       num_channel: int = 14,
                       num_baseline: int = 61,
                       baseline_chunk_size: int = 128,
                       before_trial: Union[None, Callable] = None,
                       offline_transform: Union[None, Callable] = None,
                       **kwargs):
        '''
        Yield the IO records of one subject.

        For every trial of the subject, one baseline record ``{'eeg', 'key'}``
        is yielded right before the first EEG clip of that trial, followed by
        one record ``{'eeg', 'key', 'info'}`` per EEG clip. Record keys have the
        form ``f'{subject}_{n}'`` where ``n`` is a running counter per subject.
        A trial shorter than one chunk yields no records at all.

        Args:
            file: Subject index into ``DREAMER.Data`` (the records returned by
                :meth:`set_records` are such indices).
            mat_path (str): Path of :obj:`DREAMER.mat`.
            chunk_size (int): Clip length in data points; ``<= 0`` means one
                clip spanning the whole trial.
            overlap (int): Overlapping data points between consecutive clips.
            num_channel (int): Number of leading channels to keep.
            num_baseline (int): Number of baseline chunks averaged together.
            baseline_chunk_size (int): Length of each baseline chunk.
            before_trial (Callable, optional): Hook applied to each trial's
                ``(channel, timestep)`` signal before chunking.
            offline_transform (Callable, optional): Called with ``eeg=`` and
                ``baseline=``; its ``'eeg'``/``'baseline'`` outputs are stored.

        Raises:
            ValueError: If the chunk size minus ``overlap`` is not positive,
                which would otherwise make the chunking loop never terminate.
        '''
        subject = file
        mat_data = scio.loadmat(mat_path,
                                verify_compressed_data_integrity=False)

        # All per-subject fields live under DREAMER.Data[0, subject]; look the
        # nested structs up once instead of once per field per trial.
        subject_data = mat_data['DREAMER'][0, 0]['Data'][0, subject]
        subject_eeg = subject_data['EEG'][0, 0]
        baselines = subject_eeg['baseline'][0, 0]
        stimuli = subject_eeg['stimuli'][0, 0]
        scores_valence = subject_data['ScoreValence'][0, 0]
        scores_arousal = subject_data['ScoreArousal'][0, 0]
        scores_dominance = subject_data['ScoreDominance'][0, 0]

        # Count the trials of *this* subject (18 in DREAMER).
        trial_len = len(stimuli)

        write_pointer = 0
        # loop for each trial
        for trial_id in range(trial_len):
            # extract baseline signals: timestep x channel -> channel x timestep
            trial_baseline_sample = baselines[trial_id, 0]
            trial_baseline_sample = trial_baseline_sample[:, :num_channel].swapaxes(
                1, 0)  # channel(14), timestep(61*128)
            # Average num_baseline chunks into one (channel, baseline_chunk_size)
            # template. Use the channel count actually present so that a
            # num_channel larger than the available channels does not break
            # the reshape.
            trial_baseline_sample = trial_baseline_sample[:, :num_baseline *
                                                          baseline_chunk_size].reshape(
                                                              trial_baseline_sample.shape[0],
                                                              num_baseline,
                                                              baseline_chunk_size).mean(
                                                                  axis=1)  # channel(14), timestep(128)

            # record the common meta info
            trial_meta_info = {
                'subject_id': subject,
                'trial_id': trial_id,
                'valence': scores_valence[trial_id, 0],
                'arousal': scores_arousal[trial_id, 0],
                'dominance': scores_dominance[trial_id, 0],
            }

            trial_samples = stimuli[trial_id, 0]
            trial_samples = trial_samples[:, :num_channel].swapaxes(
                1, 0)  # channel(14), timestep(n*128)

            if before_trial is not None:
                trial_samples = before_trial(trial_samples)

            num_points = trial_samples.shape[1]
            # chunk_size <= 0 means "use the whole trial as a single chunk"
            dynamic_chunk_size = num_points if chunk_size <= 0 else chunk_size
            # calculate moving step
            step = dynamic_chunk_size - overlap
            if step <= 0:
                raise ValueError(
                    f'The chunk size ({dynamic_chunk_size}) must be larger than '
                    f'overlap ({overlap}); otherwise chunking never advances.')

            # chunk with chunk size
            start_at = 0
            end_at = dynamic_chunk_size
            while end_at <= num_points:
                clip_sample = trial_samples[:, start_at:end_at]

                t_eeg = clip_sample
                t_baseline = trial_baseline_sample
                if offline_transform is not None:
                    t = offline_transform(eeg=clip_sample,
                                          baseline=trial_baseline_sample)
                    t_eeg = t['eeg']
                    t_baseline = t['baseline']

                # Put the baseline signal into IO once per trial, right before
                # the trial's first clip (using the baseline produced alongside
                # that first clip).
                if 'baseline_id' not in trial_meta_info:
                    trial_base_id = f'{subject}_{write_pointer}'
                    yield {'eeg': t_baseline, 'key': trial_base_id}
                    write_pointer += 1
                    trial_meta_info['baseline_id'] = trial_base_id

                clip_id = f'{subject}_{write_pointer}'
                write_pointer += 1

                # record meta info for each signal
                record_info = {
                    'start_at': start_at,
                    'end_at': end_at,
                    'clip_id': clip_id
                }
                record_info.update(trial_meta_info)
                yield {'eeg': t_eeg, 'key': clip_id, 'info': record_info}

                start_at += step
                end_at = start_at + dynamic_chunk_size

    def set_records(self, mat_path: str = './DREAMER.mat', **kwargs):
        '''
        Return the list of subject indices ``[0, ..., n_subjects - 1]`` found in
        ``DREAMER.Data`` of the given file.

        Raises:
            FileNotFoundError: If ``mat_path`` does not exist.
        '''
        # An explicit raise (rather than assert) keeps the check under ``python -O``.
        if not os.path.exists(mat_path):
            raise FileNotFoundError(
                f'mat_path ({mat_path}) does not exist. Please download the dataset and set the mat_path to the downloaded path.'
            )
        mat_data = scio.loadmat(mat_path,
                                verify_compressed_data_integrity=False)

        subject_len = len(mat_data['DREAMER'][0, 0]['Data'][0])  # 23
        return list(range(subject_len))

    def __getitem__(self, index: int) -> Tuple:
        '''
        Return ``(signal, label)`` for the sample at ``index``.

        ``signal`` is the cached EEG clip, passed through ``online_transform``
        (together with the trial's cached baseline) when one is set. ``label`` is
        the sample's info dictionary, passed through ``label_transform`` when one
        is set.
        '''
        info = self.read_info(index)

        eeg_record = str(info['_record_id'])
        eeg = self.read_eeg(eeg_record, str(info['clip_id']))

        signal = eeg
        label = info

        if self.online_transform is not None:
            # The baseline is only consumed by the online transform, so it is
            # read lazily to avoid an unnecessary IO access per sample.
            baseline = self.read_eeg(eeg_record, str(info['baseline_id']))
            signal = self.online_transform(eeg=eeg, baseline=baseline)['eeg']

        if self.label_transform is not None:
            label = self.label_transform(y=info)['y']

        return signal, label

    @property
    def repr_body(self) -> Dict:
        # Report every constructor argument (on top of the base class fields).
        return dict(
            super().repr_body, **{
                'mat_path': self.mat_path,
                'chunk_size': self.chunk_size,
                'overlap': self.overlap,
                'num_channel': self.num_channel,
                'num_baseline': self.num_baseline,
                'baseline_chunk_size': self.baseline_chunk_size,
                'online_transform': self.online_transform,
                'offline_transform': self.offline_transform,
                'label_transform': self.label_transform,
                'before_trial': self.before_trial,
                'after_trial': self.after_trial,
                'after_session': self.after_session,
                'after_subject': self.after_subject,
                'io_mode': self.io_mode,
                'num_worker': self.num_worker,
                'verbose': self.verbose,
                'io_size': self.io_size
            })