Shortcuts

Source code for torcheeg.datasets.module.mne_raw_dataset

from typing import Any, Callable, Dict, Tuple, Union, List
import mne
import numpy as np

from .base_dataset import BaseDataset
from ...utils import get_random_dir_path

[docs]class MNERawDataset(BaseDataset): ''' Process a list of MNE Raw objects and corresponding information dictionaries. This dataset is particularly useful for working with pre-loaded MNE Raw objects, such as those obtained from various EEG datasets like Physionet EEG Motor Movement/Imagery Dataset. The dataset splits the continuous EEG data into epochs based on the specified chunk size and overlap. Each epoch is associated with the corresponding information from the info_list. .. code-block:: python import mne from torcheeg.datasets import MNERawDataset from torcheeg import transforms subject_id = 22 event_codes = [5, 6, 9, 10, 13, 14] physionet_paths = mne.datasets.eegbci.load_data( subject_id, event_codes, update_path=False) # Load each of the files raw_list = [mne.io.read_raw_edf(path, preload=True, stim_channel='auto') for path in physionet_paths] info_list = [{"trial_id": event_code, "subject_id": subject_id} for event_code in event_codes] dataset = MNERawDataset(raw_list=raw_list, info_list=info_list, chunk_size=500, overlap=0, online_transform=transforms.ToTensor(), label_transform=transforms.Select('trial_id')) Args: raw_list (List): A list of MNE Raw objects containing the EEG data. info_list (List): A list of dictionaries containing metadata for each Raw object. Each dictionary should correspond to the Raw object at the same index in raw_list. chunk_size (int): The size of each epoch in samples. (default: :obj:`3000`) overlap (int): The number of overlapping samples between consecutive epochs. (default: :obj:`0`) online_transform (Callable, optional): The transformation of the EEG signals and baseline EEG signals. The input is a :obj:`np.ndarray`, and the ouput is used as the first and second value of each element in the dataset. (default: :obj:`None`) offline_transform (Callable, optional): The usage is the same as :obj:`online_transform`, but executed before generating IO intermediate results. (default: :obj:`None`) label_transform (Callable, optional): The transformation of the label. The input is an information dictionary, and the ouput is used as the third value of each element in the dataset. (default: :obj:`None`) io_path (str): The path to generated unified data IO, cached as an intermediate result. If set to None, a random path will be generated. (default: :obj:`None`) io_size (int): Maximum size database may grow to; used to size the memory mapping. If database grows larger than ``map_size``, an exception will be raised and the user must close and reopen. (default: :obj:`1048576`) io_mode (str): Storage mode of EEG signal. When io_mode is set to :obj:`lmdb`, TorchEEG provides an efficient database (LMDB) for storing EEG signals. LMDB may not perform well on limited operating systems, where a file system based EEG signal storage is also provided. When io_mode is set to :obj:`pickle`, pickle-based persistence files are used. When io_mode is set to :obj:`memory`, memory are used. (default: :obj:`lmdb`) num_worker (int): Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process. (default: :obj:`0`) verbose (bool): Whether to display logs during processing, such as progress bars, etc. (default: :obj:`True`) ''' def __init__(self, raw_list: List, info_list: List, chunk_size: int = 3000, overlap: int = 0, online_transform: Union[None, Callable] = None, offline_transform: Union[None, Callable] = None, label_transform: Union[None, Callable] = None, io_path: Union[None, str] = None, io_size: int = 1048576, io_mode: str = 'lmdb', num_worker: int = 0, verbose: bool = True, **kwargs): if io_path is None: io_path = get_random_dir_path(dir_prefix='datasets') params = { 'raw_list': raw_list, 'info_list': info_list, 'chunk_size': chunk_size, 'overlap': overlap, 'online_transform': online_transform, 'offline_transform': offline_transform, 'label_transform': label_transform, 'io_path': io_path, 'io_size': io_size, 'io_mode': io_mode, 'num_worker': num_worker, 'verbose': verbose } params.update(kwargs) super().__init__(**params) self.__dict__.update(params) @staticmethod def read_record(record: str, **kwargs) -> Dict: return {} @staticmethod def process_record(record: Tuple, offline_transform: Union[None, Callable] = None, chunk_size: int = 500, overlap: int = 0, **kwargs): raw, info = record data, times = raw[:, :] # Calculate the step size step = chunk_size - overlap # Generate chunks for i in range(0, data.shape[1] - chunk_size + 1, step): chunk = data[:, i:i+chunk_size] if offline_transform is not None: chunk = offline_transform(eeg=chunk)['eeg'] clip_id = f"{info['subject_id']}_{info['trial_id']}_{i}" record_info = { **info, 'start_at': times[i], 'end_at': times[i+chunk_size-1], 'clip_id': clip_id } yield {'eeg': chunk, 'key': clip_id, 'info': record_info} def set_records(self, raw_list: List[mne.io.BaseRaw], info_list: List[Dict[str, Any]], **kwargs): return list(zip(raw_list, info_list)) def __getitem__(self, index: int) -> Tuple[Any, Any]: info = self.read_info(index) eeg_index = str(info['clip_id']) eeg_record = str(info['_record_id']) eeg = self.read_eeg(eeg_record, eeg_index) signal = eeg label = info if self.online_transform: signal = self.online_transform(eeg=eeg)['eeg'] if self.label_transform: label = self.label_transform(y=info)['y'] return signal, label @property def repr_body(self) -> Dict: return dict( super().repr_body, **{ 'chunk_size': self.chunk_size, 'overlap': self.overlap, 'online_transform': self.online_transform, 'offline_transform': self.offline_transform, 'label_transform': self.label_transform, 'num_worker': self.num_worker, 'verbose': self.verbose, 'io_size': self.io_size })

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources