Shortcuts

Source code for torcheeg.model_selection.subcategory

import logging
import os
import re
from copy import copy
from typing import Dict, Generator, Tuple, Union

import pandas as pd

from torcheeg.datasets.module.base_dataset import BaseDataset

from ..utils import get_random_dir_path

log = logging.getLogger('torcheeg')


class Subcategory:
    r'''A tool class for separating out subsets of specified categories, often
    used to extract data for a certain type of paradigm, or for a certain type
    of task. Each subset in the formed subset list contains only one type of
    data.

    Common usage:

    .. code-block:: python

        from torcheeg.datasets import M3CVDataset
        from torcheeg.model_selection import Subcategory
        from torcheeg import transforms
        from torcheeg.utils import DataLoader

        cv = Subcategory()
        dataset = M3CVDataset(root_path='./aistudio',
                              online_transform=transforms.Compose(
                                  [transforms.To2d(),
                                   transforms.ToTensor()]),
                              label_transform=transforms.Compose([
                                  transforms.Select('subject_id'),
                                  transforms.StringToInt()
                              ]))
        for subdataset in cv.split(dataset):
            loader = DataLoader(subdataset)
            ...

    TorchEEG supports the division of training and test sets within each
    subset after dividing the data into subsets. The sample code is as
    follows:

    .. code-block:: python

        cv = Subcategory()
        dataset = M3CVDataset(root_path='./aistudio',
                              online_transform=transforms.Compose(
                                  [transforms.To2d(),
                                   transforms.ToTensor()]),
                              label_transform=transforms.Compose([
                                  transforms.Select('subject_id'),
                                  transforms.StringToInt()
                              ]))
        for i, subdataset in enumerate(cv.split(dataset)):
            train_dataset, test_dataset = train_test_split(
                dataset=subdataset, split_path=f'./split{i}')
            train_loader = DataLoader(train_dataset)
            test_loader = DataLoader(test_dataset)
            ...

    For the already divided training and testing sets, TorchEEG recommends
    using two :obj:`Subcategory` to extract their subcategories respectively.
    On this basis, the :obj:`zip` function can be used to combine the subsets.
    It is worth noting that it is necessary to ensure that the training and
    test sets have the same number and variety of classes.

    .. code-block:: python

        train_cv = Subcategory()
        train_dataset = M3CVDataset(root_path='./aistudio',
                                    online_transform=transforms.Compose(
                                        [transforms.To2d(),
                                         transforms.ToTensor()]),
                                    label_transform=transforms.Compose([
                                        transforms.Select('subject_id'),
                                        transforms.StringToInt()
                                    ]))

        val_cv = Subcategory()
        val_dataset = M3CVDataset(root_path='./aistudio',
                                  subset='Calibration',
                                  num_channel=65,
                                  online_transform=transforms.Compose(
                                      [transforms.To2d(),
                                       transforms.ToTensor()]),
                                  label_transform=transforms.Compose([
                                      transforms.Select('subject_id'),
                                      transforms.StringToInt()
                                  ]))

        for train_dataset, val_dataset in zip(train_cv.split(train_dataset),
                                              val_cv.split(val_dataset)):
            train_loader = DataLoader(train_dataset)
            val_loader = DataLoader(val_dataset)
            ...

    Args:
        criteria (str): The classification criteria according to which we
            extract subsets of data for the including categories.
            (default: :obj:`'task'`)
        split_path (str): The path to data partition information. If the path
            exists, read the existing partition from the path. If the path
            does not exist, the current division method will be saved for
            next use. If set to None, a random path will be generated.
            (default: :obj:`None`)
    '''

    def __init__(self,
                 criteria: str = 'task',
                 split_path: Union[None, str] = None):
        if split_path is None:
            # No path given: persist the partition in a fresh random directory
            # so it can be reused on a later run.
            split_path = get_random_dir_path(dir_prefix='model_selection')

        self.criteria = criteria
        self.split_path = split_path

    def split_info_constructor(self, info: pd.DataFrame) -> None:
        '''Write one ``{category}.csv`` file into ``self.split_path`` for each
        distinct value of the ``self.criteria`` column in ``info``.

        Args:
            info (pd.DataFrame): the annotation table of a dataset; must
                contain a column named ``self.criteria``.
        '''
        # Kept as an assert for interface compatibility (existing callers may
        # rely on AssertionError being raised for an unknown criteria column).
        assert self.criteria in list(
            info.columns
        ), f'Unsupported criteria {self.criteria}, please select one of the following options {list(info.columns)}.'

        category_list = list(set(info[self.criteria]))

        for category in category_list:
            subset_info = info[info[self.criteria] == category]
            subset_info.to_csv(os.path.join(self.split_path,
                                            f'{category}.csv'),
                               index=False)

    @property
    def category_list(self):
        '''The sorted list of categories recovered from the ``*.csv`` files in
        ``self.split_path``.

        File stems are converted back to :obj:`int` when possible so numeric
        categories keep their original type and numeric sort order; all other
        stems (e.g. task-name strings) are returned unchanged. The previous
        implementation forced ``int()`` on every stem, which crashed with a
        :obj:`ValueError` whenever the criteria values were not integers.
        '''
        indice_files = list(os.listdir(self.split_path))

        def indice_file_to_category(indice_file):
            # splitext avoids the old unescaped-dot regex and handles stems
            # that themselves contain dots.
            stem = os.path.splitext(indice_file)[0]
            try:
                return int(stem)
            except ValueError:
                return stem

        category_list = list(set(map(indice_file_to_category, indice_files)))
        category_list.sort()

        return category_list

    def split(self,
              dataset: BaseDataset) -> Generator[BaseDataset, None, None]:
        '''Yield one shallow-copied sub-dataset per category.

        If ``self.split_path`` does not exist yet, the partition is built from
        ``dataset.info`` and saved there; otherwise the existing partition is
        reused.

        Args:
            dataset (BaseDataset): the dataset to be partitioned.

        Yields:
            BaseDataset: a copy of ``dataset`` whose ``info`` is restricted to
            a single category.
        '''
        if not os.path.exists(self.split_path):
            log.info('📊 | Create the split of train and test set.')
            log.info(
                f'😊 | Please set \033[92msplit_path\033[0m to \033[92m{self.split_path}\033[0m for the next run, if you want to use the same setting for the experiment.'
            )
            os.makedirs(self.split_path)
            self.split_info_constructor(dataset.info)
        else:
            log.info(
                f'📊 | Detected existing split of train and test set, use existing split from {self.split_path}.'
            )
            log.info(
                '💡 | If the dataset is re-generated, you need to re-generate the split of the dataset instead of using the previous split.'
            )

        for category in self.category_list:
            subset_info = pd.read_csv(
                os.path.join(self.split_path, f'{category}.csv'))
            # Shallow copy: the subsets share the underlying EEG storage and
            # only differ in their annotation tables.
            subset_dataset = copy(dataset)
            subset_dataset.info = subset_info
            yield subset_dataset

    @property
    def repr_body(self) -> Dict:
        # Constructor arguments surfaced for __repr__.
        return {'criteria': self.criteria, 'split_path': self.split_path}

    def __repr__(self) -> str:
        # init info
        format_string = self.__class__.__name__ + '('
        for i, (k, v) in enumerate(self.repr_body.items()):
            # line end
            if i:
                format_string += ', '
            # str param
            if isinstance(v, str):
                format_string += f"{k}='{v}'"
            else:
                format_string += f"{k}={v}"
        format_string += ')'
        return format_string
Read the Docs v: latest
Versions
latest
stable
v1.1.2
v1.1.1
v1.1.0
v1.0.11
v1.0.10
v1.0.9
v1.0.8.post1
v1.0.8
v1.0.7
v1.0.6
v1.0.4
v1.0.3
v1.0.2
v1.0.1
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources