# Source code for torcheeg.transforms.any.compose
from typing import Any, Callable, Dict, List

from ..base_transform import BaseTransform
class Compose(BaseTransform):
    r'''
    Compose several transforms together. Consistent with :obj:`torchvision.transforms.Compose`'s behavior.

    .. code-block:: python

        import torch
        from torcheeg import transforms

        t = transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize(size=(64, 64)),
            transforms.RandomNoise(p=0.1),
            transforms.RandomMask(p=0.1)
        ])
        t(eeg=torch.randn(128, 9, 9))['eeg'].shape
        >>> (128, 64, 64)

    :obj:`Compose` supports transforms with different data dependencies. The above example combines multiple torch-based transforms; the following example shows a sequence of numpy-based transforms.

    .. code-block:: python

        import numpy as np
        from torcheeg import transforms

        t = transforms.Compose([
            transforms.BandDifferentialEntropy(),
            transforms.MeanStdNormalize(),
            transforms.ToGrid(DEAP_CHANNEL_LOCATION_DICT)
        ])
        t(eeg=np.random.randn(32, 128))['eeg'].shape
        >>> (128, 9, 9)

    Args:
        transforms (list): The list of transforms to compose. Each transform is
            called with the current data as keyword arguments, and its return
            value is unpacked (``**``) as the keyword arguments of the next
            transform, so every transform but the last must return a mapping.

    .. automethod:: __call__
    '''

    def __init__(self, transforms: List[Callable]):
        super().__init__()
        # Stored as given (not copied), matching torchvision's Compose.
        self.transforms = transforms

    def __call__(self, *args, **kwargs) -> Dict[str, Any]:
        r'''
        Apply every transform in order, feeding each one's output to the next.

        Args:
            **kwargs: The data to transform, passed by name (e.g. ``eeg=...``).

        Returns:
            dict: The value returned by the last transform, or a dict of the
            given keyword arguments if :obj:`transforms` is empty.

        Raises:
            KeyError: If any positional argument is given.
        '''
        # Data must be named so each transform can pick out the keys it uses.
        if args:
            raise KeyError("Please pass data as named parameters.")
        for t in self.transforms:
            kwargs = t(**kwargs)
        return kwargs

    def __repr__(self) -> str:
        # One transform per line, comma-separated: "Compose(\n t0,\n t1\n)".
        body = ','.join(f'\n {t}' for t in self.transforms)
        return f'{self.__class__.__name__}({body}\n)'