Label Transforms

transforms.Select

class torcheeg.transforms.Select(key: Union[str, List])

Select the value stored under a given key in the label information dictionary.

from torcheeg.transforms import Select
transform = Select(key='valence')
transform(y={'valence': 4.5, 'arousal': 5.5, 'subject_id': 7})['y']
>>> 4.5

Select can also take a list of keys, in which case the selected values are returned as a list in the same order as the keys. This is useful for multi-label classification or multi-task learning.

transform = Select(key=['valence', 'arousal'])
transform(y={'valence': 4.5, 'arousal': 5.5, 'subject_id': 7})['y']
>>> [4.5, 5.5]
Parameters

key (str or list) – A single key, or a list of keys, to select from the dictionary.
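The selection logic can be approximated in plain Python. The sketch below is a hypothetical stand-in (`SelectSketch` is not part of TorchEEG). It assumes the result is returned under a `'y'` key, which the `['y']` indexing in the examples above suggests.

```python
from typing import Any, Dict, List, Union


class SelectSketch:
    """Hypothetical re-implementation of key-based label selection (not the TorchEEG class)."""

    def __init__(self, key: Union[str, List[str]]):
        self.key = key
        # A list (or tuple) of keys means "return several values as a list".
        self.select_list = isinstance(key, (list, tuple))

    def __call__(self, *args: Any, y: Dict[str, Any], **kwargs: Any) -> Dict[str, Any]:
        if not isinstance(y, dict):
            raise TypeError(f"expected a label dict, got {type(y).__name__}")
        if self.select_list:
            # Keep the order of the requested keys.
            value = [y[k] for k in self.key]
        else:
            value = y[self.key]
        # Assumption: the selected value is wrapped under 'y', matching the examples.
        return {'y': value}


info = {'valence': 4.5, 'arousal': 5.5, 'subject_id': 7}
print(SelectSketch(key='valence')(y=info)['y'])               # 4.5
print(SelectSketch(key=['valence', 'arousal'])(y=info)['y'])  # [4.5, 5.5]
```

A missing key raises `KeyError` in this sketch. That is ordinary dictionary behaviour, not a documented TorchEEG guarantee.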

__call__(*args, y: Dict, **kwargs) -> Union[int, float, List]
Parameters

y (dict) – A dictionary describing the EEG signal sample, usually the last element returned for each sample by a Dataset.

Returns

Selected value or selected value list.

Return type

int, float, or list

transforms.Binary

class torcheeg.transforms.Binary(threshold: float)

Binarize labels using a fixed threshold. Labels greater than the threshold are set to 1, and labels less than the threshold are set to 0.

from torcheeg.transforms import Binary
transform = Binary(threshold=5.0)
transform(y=4.5)['y']
>>> 0

Binary can also binarize a list of labels at once, applying the same threshold to each.

transform = Binary(threshold=5.0)
transform(y=[4.5, 5.5])['y']
>>> [0, 1]
Parameters

threshold (float) – Threshold used during binarization.
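Here is a plain-Python sketch of the thresholding rule, covering both the scalar and the list case. `BinarySketch` is a hypothetical name, not the TorchEEG class. The text above does not say what happens when a label is exactly equal to the threshold. This sketch assumes such a label maps to 1.

```python
from typing import Any, Dict, List, Union


class BinarySketch:
    """Hypothetical threshold binarizer (not the TorchEEG class)."""

    def __init__(self, threshold: float):
        self.threshold = threshold

    def _binarize(self, label: Union[int, float]) -> int:
        # Assumption: a label equal to the threshold is mapped to 1.
        return int(label >= self.threshold)

    def __call__(self, *args: Any, y: Union[int, float, List], **kwargs: Any) -> Dict[str, Any]:
        if isinstance(y, list):
            # The same threshold is applied to every label in the list.
            return {'y': [self._binarize(v) for v in y]}
        return {'y': self._binarize(y)}


transform = BinarySketch(threshold=5.0)
print(transform(y=4.5)['y'])         # 0
print(transform(y=[4.5, 5.5])['y'])  # [0, 1]
```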

__call__(*args, y: Union[int, float, List], **kwargs) -> Union[int, List]
Parameters

y (int, float, or list) – The input label or list of labels.

Returns

The output label or list of labels after binarization.

Return type

int or list

transforms.BinariesToCategory

class torcheeg.transforms.BinariesToCategory

Convert multiple binary labels into a single multiclass label. Each multiclass label corresponds to one combination of the binary label values. This is commonly used to merge two binary classification tasks into a single four-class classification task.
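The examples below can be read as base-2 numbers, with the first binary label as the most significant bit. For n binary labels b_0 through b_{n-1}, that reading gives the mapping below. It reproduces all four two-label examples. Extending it to more than two labels is an assumption, since only the two-label case is shown here.

```latex
c = \sum_{i=0}^{n-1} b_i \, 2^{\,n-1-i}, \qquad b_i \in \{0, 1\}

% Two labels (n = 2): c = 2 b_0 + b_1
% [0, 0] -> 0,  [0, 1] -> 1,  [1, 0] -> 2,  [1, 1] -> 3
```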

from torcheeg.transforms import BinariesToCategory
transform = BinariesToCategory()
transform(y=[0, 0])['y']
>>> 0
transform(y=[0, 1])['y']
>>> 1
transform(y=[1, 0])['y']
>>> 2
transform(y=[1, 1])['y']
>>> 3
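The mapping in these examples can be reproduced with a short helper that accumulates the bits from left to right. `binaries_to_category` is a hypothetical function, not a TorchEEG API. It assumes every element is 0 or 1 and treats the first element as the most significant bit, which matches the four examples above.

```python
from typing import List


def binaries_to_category(y: List[int]) -> int:
    """Hypothetical helper: read a list of 0/1 labels as a base-2 number (first item = MSB)."""
    if any(b not in (0, 1) for b in y):
        raise ValueError(f"expected only 0/1 labels, got {y}")
    category = 0
    for bit in y:
        # Shift the running value one bit left and append the next label.
        category = category * 2 + bit
    return category


for labels in ([0, 0], [0, 1], [1, 0], [1, 1]):
    print(labels, '->', binaries_to_category(labels))
# [0, 0] -> 0
# [0, 1] -> 1
# [1, 0] -> 2
# [1, 1] -> 3
```

Each distinct list of bits yields a distinct integer, so no two label combinations share a class.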
__call__(*args, y: List, **kwargs) -> int
Parameters

y (list) – A list of binary labels (each 0 or 1).

Returns

The converted multiclass label.

Return type

int
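These three transforms are typically chained: select the raw ratings, binarize each one, then merge the binary labels. Their end-to-end effect on one sample can be sketched in plain Python. The helper `valence_arousal_to_quadrant`, the key names, the 5.0 threshold, and the placeholder signal are illustrative assumptions drawn from the examples above. Treating a rating equal to the threshold as 1 is also an assumption. Within TorchEEG, label transforms are usually chained with a composition transform such as `transforms.Compose` rather than a hand-written function.

```python
from typing import Dict, Sequence


def valence_arousal_to_quadrant(info: Dict[str, float],
                                keys: Sequence[str] = ('valence', 'arousal'),
                                threshold: float = 5.0) -> int:
    """Hypothetical end-to-end label pipeline: select -> binarize -> combine."""
    # Step 1 (like Select): pull the requested ratings out of the info dictionary.
    ratings = [info[k] for k in keys]
    # Step 2 (like Binary): threshold every rating; equality maps to 1 here (assumption).
    bits = [int(r >= threshold) for r in ratings]
    # Step 3 (like BinariesToCategory): read the bits as base 2, first bit most significant.
    category = 0
    for bit in bits:
        category = category * 2 + bit
    return category


# Placeholder sample: (signal, info), with the info dictionary as the last element.
sample = ([[0.1, 0.2], [0.3, 0.4]], {'valence': 4.5, 'arousal': 5.5, 'subject_id': 7})
eeg, info = sample
print(valence_arousal_to_quadrant(info))  # 1  (low valence, high arousal)
```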