Shortcuts

Source code for torcheeg.datasets.module.pair_dataset

from typing import Any, Callable, List, Optional, Tuple

import numpy as np
import pandas as pd
from torch.utils.data import Dataset

from .base_dataset import BaseDataset


[docs]class PairDataset(Dataset): r''' A dataset class for pairing multiple datasets. This class combines multiple datasets based on a specified join key and join type. It is particularly useful for constructing multimodal datasets, such as merging dataset A and dataset B to simultaneously access both modalities for the same subject during training. Below is a quick start example: .. code-block:: python from torcheeg.datasets import PairDataset, HMCDataset dataset_eeg = HMCDataset(root_path='./HMC/recordings', channels=['EEG F4-M1', 'EEG C4-M1', 'EEG O2-M1', 'EEG C3-M2']) dataset_ecg = HMCDataset(root_path='./HMC/recordings', channels=['ECG']) dataset = PairDataset(datasets=[dataset_eeg, dataset_ecg], join_key='clip_id') # Returns a tuple containing both EEG and ECG data: # (dataset_eeg[0][0], dataset_eeg[0][1], dataset_ecg[0][0], dataset_ecg[0][1]) dataset[0] Args: datasets (List[BaseDataset]): A list of datasets to be paired. Each dataset should inherit from BaseDataset. join_key (str): The key used to join the datasets. This should be a column name present in all datasets' info DataFrames. Common join keys could be 'subject_id', 'trial_id', 'clip_id', etc. (default: :obj:`'subject_id'`) join_type (str): The type of join to perform. Valid options are: - 'inner': Only keeps matching records from all datasets - 'outer': Keeps all records, filling missing matches with None - 'left': Keeps all records from the first dataset - 'right': Keeps all records from the last dataset (default: :obj:`'inner'`) pair_info_fn (Callable, optional): A custom function to pair the datasets. If provided, this function will be used instead of the default pairing logic. The function should take a list of info DataFrames as input and return a DataFrame with appropriate index columns. This is useful when you need custom pairing logic beyond simple joins. (default: :obj:`None`) distinct_key (Optional[str]): A key to ensure distinct pairs based on a specific column. This is useful when you want to avoid duplicate pairs based on certain criteria. For example, using 'trial_id' would ensure no duplicate trials are paired. (default: :obj:`None`) ''' def __init__(self, datasets: List[BaseDataset], join_key: str = 'subject_id', join_type: str = 'inner', pair_info_fn: Optional[Callable] = None, distinct_key: Optional[str] = None): self.datasets = datasets if pair_info_fn is None: info_list = [dataset.info.copy() for dataset in self.datasets] for i, info in enumerate(info_list): # print(len(info)) info[f'index_{i}'] = np.arange(len(info)) # Merge all datasets based on the join key self.info = info_list[0] for i, info in enumerate(info_list[1:], start=1): columns_to_drop = set(info.columns).intersection( set(self.info.columns)) - {join_key} info = info.drop(columns=columns_to_drop) self.info = pd.merge( self.info, info, on=join_key, how=join_type) else: self.info = pair_info_fn([dataset.info.copy() for dataset in self.datasets]) if not all(f'index_{i}' in self.info.columns for i in range(len(datasets))): raise ValueError( "The custom pair_info_fn must return a DataFrame with appropriate index columns.") if distinct_key is not None: if distinct_key not in self.info.columns: raise ValueError( f"distinct_key '{distinct_key}' not found in merged DataFrame columns") new_info = self.info.sample(frac=1, random_state=42) new_info = new_info.drop_duplicates( subset=distinct_key, keep='first') new_info = new_info.reset_index(drop=True) self.info = new_info def __getitem__(self, index: int) -> Tuple[Any, ...]: info = self.info.iloc[index].to_dict() # Retrieve the signals and labels from all datasets signals = [] labels = [] for i, dataset in enumerate(self.datasets): index_i = info[f'index_{i}'] signal, label = dataset[index_i] signals.append(signal) labels.append(label) return (*signals, *labels) def get_labels(self) -> List: labels = [] for i in range(len(self)): *_, label = self.__getitem__(i) labels.append(label) return labels def __len__(self): return len(self.info)

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources