Shortcuts

Source code for torcheeg.datasets.module.mne_dataset

from typing import Callable, Dict, List, Tuple, Union

import mne

from ..functional.mne import mne_constructor
from .base_dataset import BaseDataset


[docs]class MNEDataset(BaseDataset): r''' A generic EEG analysis dataset that allows creating datasets from :obj:`mne.Epochs`, and caches the generated results in a unified input and output format (IO). It is generally used to support custom datasets or datasets not yet provided by TorchEEG. :obj:`MNEDataset` allows a list of :obj:`mne.Epochs` and the corresponding description dictionary as input, and divides the signals in :obj:`mne.Epochs` into several segments according to the configuration information provided by the user. These segments will be annotated by the description dictionary elements and the event type annotated in :obj:`mne.Epochs`. Here is an example case shows the use of :obj:`MNEDataset`: .. code-block:: python # subject index and run index of mne.Epochs metadata_list = [{ 'subject': 1, 'run': 3 }, { 'subject': 1, 'run': 7 }, { 'subject': 1, 'run': 11 }] epochs_list = [] for metadata in metadata_list: physionet_path = mne.datasets.eegbci.load_data(metadata['subject'], metadata['run'], update_path=False)[0] raw = mne.io.read_raw_edf(physionet_path, preload=True, stim_channel='auto') events, _ = mne.events_from_annotations(raw) picks = mne.pick_types(raw.info, meg=False, eeg=True, stim=False, eog=False, exclude='bads') # init Epochs with raw EEG signals and corresponding event annotations epochs_list.append(mne.Epochs(raw, events, picks=picks)) # split into chunks of 160 data points (1s) dataset = MNEDataset(epochs_list=epochs_list, metadata_list=metadata_list, chunk_size=160, overlap=0, num_channel=60, io_path=io_path, offline_transform=transforms.Compose( [transforms.BandDifferentialEntropy()]), online_transform=transforms.ToTensor(), label_transform=transforms.Compose([ transforms.Select('event') ]), num_worker=2) print(dataset[0]) # EEG signal (torch.Tensor[60, 4]), # coresponding baseline signal (torch.Tensor[60, 4]), # label (int) In particular, TorchEEG utilizes the producer-consumer model to allow multi-process data preprocessing. If your data preprocessing is time consuming, consider increasing :obj:`num_worker` for higher speedup. If running under Windows, please use the proper idiom in the main module: .. code-block:: python if __name__ == '__main__': # subject index and run index of mne.Epochs metadata_list = [{ 'subject': 1, 'run': 3 }, { 'subject': 1, 'run': 7 }, { 'subject': 1, 'run': 11 }] epochs_list = [] for metadata in metadata_list: physionet_path = mne.datasets.eegbci.load_data(metadata['subject'], metadata['run'], update_path=False)[0] raw = mne.io.read_raw_edf(physionet_path, preload=True, stim_channel='auto') events, _ = mne.events_from_annotations(raw) picks = mne.pick_types(raw.info, meg=False, eeg=True, stim=False, eog=False, exclude='bads') # init Epochs with raw EEG signals and corresponding event annotations epochs_list.append(mne.Epochs(raw, events, picks=picks)) # split into chunks of 160 data points (1s) dataset = MNEDataset(epochs_list=epochs_list, metadata_list=metadata_list, chunk_size=160, overlap=0, num_channel=60, io_path=io_path, offline_transform=transforms.Compose( [transforms.BandDifferentialEntropy()]), online_transform=transforms.ToTensor(), label_transform=transforms.Compose([ transforms.Select('event') ]), num_worker=2) print(dataset[0]) # EEG signal (torch.Tensor[60, 4]), # coresponding baseline signal (torch.Tensor[60, 4]), # label (int) Args: epochs_list (list): A list of :obj:`mne.Epochs`. :obj:`MNEDataset` will divide the signals in :obj:`mne.Epochs` into several segments according to the :obj:`chunk_size` and :obj:`overlap` information provided by the user. The divided segments will be transformed and cached in a unified input and output format (IO) for accessing. metadata_list (list): A list of dictionaries of the same length as :obj:`epochs_list`. Each of these dictionaries is annotated with meta-information about :obj:`mne.Epochs`, such as subject index, experimental dates, etc. These annotated meta-information will be added to the element corresponding to :obj:`mne.Epochs` for use as labels for the sample. chunk_size (int): Number of data points included in each EEG chunk as training or test samples. If set to -1, the EEG signal of a trial is used as a sample of a chunk. If set to -1, the EEG signal is not segmented, and the length of the chunk is the length of the event. (default: :obj:`-1`) overlap (int): The number of overlapping data points between different chunks when dividing EEG chunks. (default: :obj:`0`) num_channel (int): Number of channels used. If set to -1, all electrodes are used (default: :obj:`-1`) online_transform (Callable, optional): The transformation of the EEG signals and baseline EEG signals. The input is a :obj:`np.ndarray`, and the ouput is used as the first and second value of each element in the dataset. (default: :obj:`None`) offline_transform (Callable, optional): The usage is the same as :obj:`online_transform`, but executed before generating IO intermediate results. (default: :obj:`None`) label_transform (Callable, optional): The transformation of the label. The input is an information dictionary, and the ouput is used as the third value of each element in the dataset. (default: :obj:`None`) before_trial (Callable, optional): The hook performed on the trial to which the sample belongs. It is performed before the offline transformation and thus typically used to implement context-dependent sample transformations, such as moving averages, etc. The input and output of this hook function should be a :obj:`mne.Epoch`. after_trial (Callable, optional): The hook performed on the trial to which the sample belongs. It is performed after the offline transformation and thus typically used to implement context-dependent sample transformations, such as moving averages, etc. The input and output of this hook function should be a sequence of dictionaries representing a sequence of EEG samples. Each dictionary contains two key-value pairs, indexed by :obj:`eeg` (the EEG signal matrix) and :obj:`key` (the index in the database) respectively io_path (str): The path to generated unified data IO, cached as an intermediate result. (default: :obj:`./io/deap`) num_worker (str): How many subprocesses to use for data processing. (default: :obj:`0`) verbose (bool): Whether to display logs during processing, such as progress bars, etc. (default: :obj:`True`) verbose (bool): Whether to display logs during processing, such as progress bars, etc. (default: :obj:`True`) cache_size (int): Maximum size database may grow to; used to size the memory mapping. If database grows larger than ``map_size``, an exception will be raised and the user must close and reopen. (default: :obj:`64 * 1024 * 1024 * 1024`) ''' def __init__(self, epochs_list: List[mne.Epochs], metadata_list: List[Dict], chunk_size: int = -1, overlap: int = 0, num_channel: int = -1, online_transform: Union[None, Callable] = None, offline_transform: Union[None, Callable] = None, label_transform: Union[None, Callable] = None, before_trial: Union[None, Callable] = None, after_trial: Union[Callable, None] = None, io_path: str = './io/mne', num_worker: int = 0, verbose: bool = True, cache_size: int = 64 * 1024 * 1024 * 1024): mne_constructor(epochs_list=epochs_list, metadata_list=metadata_list, chunk_size=chunk_size, overlap=overlap, num_channel=num_channel, before_trial=before_trial, transform=offline_transform, after_trial=after_trial, io_path=io_path, num_worker=num_worker, verbose=verbose, cache_size=cache_size) super().__init__(io_path) self.chunk_size = chunk_size self.overlap = overlap self.num_channel = num_channel self.online_transform = online_transform self.offline_transform = offline_transform self.label_transform = label_transform self.before_trial = before_trial self.after_trial = after_trial self.num_worker = num_worker self.verbose = verbose self.cache_size = cache_size def __getitem__(self, index: int) -> Tuple: info = self.info.iloc[index].to_dict() eeg_index = str(info['clip_id']) eeg = self.eeg_io.read_eeg(eeg_index) signal = eeg label = info if self.online_transform: signal = self.online_transform(eeg=eeg)['eeg'] if self.label_transform: label = self.label_transform(y=info)['y'] return signal, label @property def repr_body(self) -> Dict: return dict( super().repr_body, **{ 'chunk_size': self.chunk_size, 'overlap': self.overlap, 'num_channel': self.num_channel, 'online_transform': self.online_transform, 'offline_transform': self.offline_transform, 'label_transform': self.label_transform, 'before_trial': self.before_trial, 'after_trial': self.after_trial, 'num_worker': self.num_worker, 'verbose': self.verbose, 'cache_size': self.cache_size })

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources