Source code for torcheeg.datasets.module.folder_dataset

from pathlib import Path
from typing import Any, Callable, Dict, Tuple, Union

import mne
import numpy as np

from .base_dataset import BaseDataset
from ...utils import get_random_dir_path


def default_read_fn(file_path, **kwargs):
    # Load EEG file
    raw = mne.io.read_raw(file_path)
    # Convert raw to epochs
    epochs = mne.make_fixed_length_epochs(raw, duration=1)
    # Return EEG data
    return epochs
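
# A hedged sketch, not part of the library: a custom read function can be
# passed in place of default_read_fn above. As the class docstring below
# notes, any extra keyword arguments given to FolderDataset at construction
# time are forwarded to the read function, so the `epoch_duration` parameter
# here is an illustrative assumption, not a FolderDataset argument.
def custom_read_fn(file_path, epoch_duration: float = 2.0, **kwargs):
    # Load the EEG recording; mne.io.read_raw infers the format from the file suffix
    raw = mne.io.read_raw(file_path)
    # Split the continuous recording into fixed-length epochs of `epoch_duration` seconds
    epochs = mne.make_fixed_length_epochs(raw, duration=epoch_duration)
    return epochs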


class FolderDataset(BaseDataset):
    '''
    Read EEG samples and their corresponding labels from a fixed folder structure. This class supports two common file structures, :obj:`subject_in_label` and :obj:`label_in_subject`. Here, :obj:`subject_in_label` corresponds to the following file structure:

    .. code-block:: shell

        tree
        # outputs
        label01
        |- sub01.edf
        |- sub02.edf
        label02
        |- sub01.edf
        |- sub02.edf

    And :obj:`label_in_subject` corresponds to the following file structure:

    .. code-block:: shell

        tree
        # outputs
        sub01
        |- label01.edf
        |- label02.edf
        sub02
        |- label01.edf
        |- label02.edf

    An example dataset for GNN-based methods:

    .. code-block:: python

        import os

        import mne
        import numpy as np

        from torcheeg.datasets import FolderDataset
        from torcheeg import transforms

        sfreq = 128  # Sampling rate
        n_channels = 14  # Number of channels
        duration = 5  # Data collected for 5 seconds
        num_files = 2  # Number of subject files generated per label folder

        for label in ['label01', 'label02']:
            os.makedirs(os.path.join('./root_folder', label), exist_ok=True)
            for i in range(num_files):
                n_samples = sfreq * duration
                data = np.random.randn(n_channels, n_samples)
                ch_names = [f'ch_{j+1:03}' for j in range(n_channels)]
                ch_types = ['eeg'] * n_channels
                info = mne.create_info(ch_names, sfreq, ch_types)
                raw = mne.io.RawArray(data, info)

                file_name = f'sub{i+1}.fif'
                file_path = os.path.join('./root_folder', label, file_name)
                raw.save(file_path)

        label_map = {'label01': 0, 'label02': 1}
        dataset = FolderDataset(root_path='./root_folder',
                                structure='subject_in_label',
                                num_channel=14,
                                online_transform=transforms.ToTensor(),
                                label_transform=transforms.Compose([
                                    transforms.Select('label'),
                                    transforms.Lambda(lambda x: label_map[x])
                                ]),
                                num_worker=4)

    Args:
        root_path (str): The path to the root folder. (default: :obj:`'./folder'`)
        structure (str): Folder structure, which determines how labels and subjects are mapped to EEG signal samples. Please refer to the description of the two folder structures above to select the correct value. (default: :obj:`'subject_in_label'`)
        read_fn (Callable): Method for reading files in a folder. By default, this class reads files using :obj:`mne.io.read_raw`. Users may also pass in a custom file reading method; its first parameter is the file path, and the remaining parameters are the additional keyword arguments passed in when the class is initialized. For example, if you pass :obj:`chunk_size=32` to :obj:`FolderDataset`, then :obj:`chunk_size` will be received here.
        online_transform (Callable, optional): The transformation of the EEG signals and baseline EEG signals. The input is a :obj:`np.ndarray`, and the output is used as the first and second value of each element in the dataset. (default: :obj:`None`)
        offline_transform (Callable, optional): The usage is the same as :obj:`online_transform`, but executed before generating IO intermediate results. (default: :obj:`None`)
        label_transform (Callable, optional): The transformation of the label. The input is an information dictionary, and the output is used as the third value of each element in the dataset. (default: :obj:`None`)
        io_path (str): The path to the generated unified data IO, cached as an intermediate result. If set to None, a random path will be generated. (default: :obj:`None`)
        io_size (int): The maximum size the database may grow to; used to size the memory mapping. If the database grows larger than ``map_size``, an exception will be raised and the user must close and reopen it. (default: :obj:`1048576`)
        io_mode (str): Storage mode of the EEG signals. When io_mode is set to :obj:`lmdb`, TorchEEG provides an efficient database (LMDB) for storing EEG signals. LMDB may not perform well on limited operating systems, so a file-system-based EEG signal storage is also provided. When io_mode is set to :obj:`pickle`, pickle-based persistence files are used. When io_mode is set to :obj:`memory`, in-memory storage is used. (default: :obj:`lmdb`)
        num_worker (int): Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process. (default: :obj:`0`)
        verbose (bool): Whether to display logs during processing, such as progress bars, etc. (default: :obj:`True`)
    '''

    def __init__(self,
                 root_path: str = './folder',
                 structure: str = 'subject_in_label',
                 read_fn: Union[None, Callable] = default_read_fn,
                 online_transform: Union[None, Callable] = None,
                 offline_transform: Union[None, Callable] = None,
                 label_transform: Union[None, Callable] = None,
                 io_path: Union[None, str] = None,
                 io_size: int = 1048576,
                 io_mode: str = 'lmdb',
                 num_worker: int = 0,
                 verbose: bool = True,
                 **kwargs):
        if io_path is None:
            io_path = get_random_dir_path(dir_prefix='datasets')

        # pass all arguments to super class
        params = {
            'root_path': root_path,
            'structure': structure,
            'read_fn': read_fn,
            'online_transform': online_transform,
            'offline_transform': offline_transform,
            'label_transform': label_transform,
            'io_path': io_path,
            'io_size': io_size,
            'io_mode': io_mode,
            'num_worker': num_worker,
            'verbose': verbose
        }
        params.update(kwargs)
        super().__init__(**params)
        # save all arguments to __dict__
        self.__dict__.update(params)

    @staticmethod
    def process_record(file: Any = None,
                       offline_transform: Union[None, Callable] = None,
                       read_fn: Union[None, Callable] = None,
                       **kwargs):
        file_path, subject_id, label = file

        trial_samples = read_fn(file_path, **kwargs)
        events = [i[0] for i in trial_samples.events]
        # the time interval between consecutive events is the same
        events.append(events[-1] + np.diff(events)[0])

        write_pointer = 0
        for i, trial_signal in enumerate(trial_samples.get_data()):
            t_eeg = trial_signal
            if offline_transform is not None:
                t = offline_transform(eeg=trial_signal)
                t_eeg = t['eeg']

            clip_id = f'{subject_id}_{label}_{write_pointer}'
            write_pointer += 1

            record_info = {
                'subject_id': subject_id,
                'start_at': events[i],
                'end_at': events[i + 1],
                'clip_id': clip_id,
                'label': label
            }

            yield {'eeg': t_eeg, 'key': clip_id, 'info': record_info}

    def set_records(self,
                    root_path: str = './folder',
                    structure: str = 'subject_in_label',
                    **kwargs):
        # get all the subfolders
        subfolders = [str(i) for i in Path(root_path).iterdir() if i.is_dir()]

        # get all the files in the subfolders
        file_path_list = []
        for subfolder in subfolders:
            file_path_list += [
                str(i) for i in Path(subfolder).iterdir() if i.is_file()
            ]

        # derive the subject id and label from the file path
        if structure == 'subject_in_label':
            # file name (without extension) is the subject, parent folder is the label
            subjects = [i.split('/')[-1].split('.')[0] for i in file_path_list]
            labels = [i.split('/')[-2] for i in file_path_list]
        elif structure == 'label_in_subject':
            # parent folder is the subject, file name (without extension) is the label
            subjects = [i.split('/')[-2] for i in file_path_list]
            labels = [i.split('/')[-1].split('.')[0] for i in file_path_list]
        else:
            raise ValueError('Unknown folder structure: {}'.format(structure))

        file_path_subject_label = list(zip(file_path_list, subjects, labels))
        return file_path_subject_label

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        info = self.read_info(index)

        eeg_index = str(info['clip_id'])
        eeg_record = str(info['_record_id'])
        eeg = self.read_eeg(eeg_record, eeg_index)

        signal = eeg
        label = info

        if self.online_transform:
            signal = self.online_transform(eeg=eeg)['eeg']

        if self.label_transform:
            label = self.label_transform(y=info)['y']

        return signal, label

    @property
    def repr_body(self) -> Dict:
        return dict(
            super().repr_body, **{
                'root_path': self.root_path,
                'structure': self.structure,
                'read_fn': self.read_fn,
                'online_transform': self.online_transform,
                'offline_transform': self.offline_transform,
                'label_transform': self.label_transform,
                'num_worker': self.num_worker,
                'verbose': self.verbose,
                'io_size': self.io_size
            })
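
# A minimal usage sketch under stated assumptions: a ./root_folder tree laid
# out as label01/sub01.fif, label02/sub01.fif, ... (the subject_in_label
# structure), and the hypothetical custom_read_fn defined above. The extra
# keyword argument epoch_duration is not a FolderDataset parameter; it is
# forwarded to read_fn, as described in the read_fn docstring entry.
if __name__ == '__main__':
    from torcheeg import transforms

    label_map = {'label01': 0, 'label02': 1}
    dataset = FolderDataset(root_path='./root_folder',
                            structure='subject_in_label',
                            read_fn=custom_read_fn,
                            epoch_duration=2.0,
                            online_transform=transforms.ToTensor(),
                            label_transform=transforms.Compose([
                                transforms.Select('label'),
                                transforms.Lambda(lambda x: label_map[x])
                            ]),
                            num_worker=0)

    # each element is a (signal, label) pair, as returned by __getitem__
    signal, label = dataset[0]
    print(signal.shape, label)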