VanillaTransformer
- class torcheeg.models.VanillaTransformer(num_electrodes: int = 32, chunk_size: int = 128, t_patch_size: int = 32, hid_channels: int = 32, depth: int = 3, heads: int = 4, head_channels: int = 8, mlp_channels: int = 64, num_classes: int = 2)
A vanilla version of the transformer adapted for EEG analysis. For more details, please refer to the following information.
It is worth noting that this model was not originally designed for EEG analysis, but it shows good performance and can serve as a good starting point for research.
Paper: Vaswani A, Shazeer N, Parmar N, et al. Attention is all you need[J]. Advances in neural information processing systems, 2017, 30.
Related Project: https://github.com/huggingface/transformers
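The core operation of the cited paper is scaled dot-product attention, softmax(Q Kᵀ / √d_k) V. The following is a minimal NumPy sketch of that formula for illustration only; it is not TorchEEG's implementation, and the function name and random inputs are hypothetical.

```python
import numpy as np


def scaled_dot_product_attention(q: np.ndarray, k: np.ndarray, v: np.ndarray):
    """Compute softmax(Q K^T / sqrt(d_k)) V.

    q: (n_q, d_k), k: (n_k, d_k), v: (n_k, d_v) -> output (n_q, d_v).
    """
    d_k = q.shape[-1]
    scores = q @ k.T / np.sqrt(d_k)                          # (n_q, n_k) similarities
    scores = scores - scores.max(axis=-1, keepdims=True)     # stabilise the exponent
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)  # each row sums to 1
    return weights @ v, weights


rng = np.random.default_rng(0)
q = rng.standard_normal((4, 8))
k = rng.standard_normal((4, 8))
v = rng.standard_normal((4, 8))
out, weights = scaled_dot_product_attention(q, k, v)
print(out.shape)  # (4, 8)
```

Each output row is a weighted average of the rows of v, with weights set by how strongly the corresponding query matches each key.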
Below is a recommended suite for use in emotion recognition tasks:
from torcheeg.datasets import DEAPDataset
from torcheeg.models import VanillaTransformer
from torcheeg import transforms
from torch.utils.data import DataLoader

# Offline: store each EEG chunk in a 2D representation; labels: valence binarized at 5.0.
dataset = DEAPDataset(root_path='./data_preprocessed_python',
                      offline_transform=transforms.To2d(),
                      online_transform=transforms.Compose([
                          transforms.ToTensor(),
                      ]),
                      label_transform=transforms.Compose([
                          transforms.Select('valence'),
                          transforms.Binary(5.0),
                      ]))

# 128-point chunks from 32 electrodes, split into temporal patches of 32 points.
model = VanillaTransformer(chunk_size=128,
                           num_electrodes=32,
                           t_patch_size=32,
                           hid_channels=32,
                           depth=3,
                           heads=4,
                           head_channels=64,
                           mlp_channels=64,
                           num_classes=2)

# Run one batch of 64 samples through the model.
x, y = next(iter(DataLoader(dataset, batch_size=64)))
model(x)
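To see what one batch from the suite above looks like without downloading DEAP, here is a NumPy mock-up. It assumes that To2d prepends a single channel axis to each [32, 128] chunk and that Binary(5.0) maps valence values greater than or equal to 5.0 to 1 and smaller values to 0. The helper names to_2d and select_and_binarize and the random data are hypothetical stand-ins, not TorchEEG functions.

```python
import numpy as np


# Hypothetical stand-in for transforms.To2d (assumed: add a leading channel axis).
def to_2d(eeg: np.ndarray) -> np.ndarray:
    return eeg[np.newaxis, ...]                      # [32, 128] -> [1, 32, 128]


# Hypothetical stand-in for Select('valence') followed by Binary(5.0)
# (assumed: values >= threshold become 1, the rest become 0).
def select_and_binarize(label: dict, key: str = "valence", threshold: float = 5.0) -> int:
    return int(label[key] >= threshold)


rng = np.random.default_rng(0)
chunks = [rng.standard_normal((32, 128)) for _ in range(64)]   # 32 electrodes x 128 points
records = [{"valence": float(v), "arousal": 5.0} for v in rng.uniform(1.0, 9.0, size=64)]

x = np.stack([to_2d(c) for c in chunks])            # one batch of 64 samples
y = np.array([select_and_binarize(r) for r in records])
print(x.shape, y.shape)  # (64, 1, 32, 128) (64,)
```

Under these assumptions the batch shape matches the ideal input shape [n, 1, 32, 128] documented for forward.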
- Parameters:
num_electrodes (int) – The number of electrodes. (default: 32)
chunk_size (int) – Number of data points included in each EEG chunk. (default: 128)
t_patch_size (int) – The size of each input patch along the temporal (chunk_size) dimension. (default: 32)
hid_channels (int) – The feature dimension of each embedded patch. (default: 32)
depth (int) – The number of attention layers for each transformer block. (default: 3)
heads (int) – The number of attention heads for each attention layer. (default: 4)
head_channels (int) – The dimension of each attention head for each attention layer. (default: 8)
mlp_channels (int) – The number of hidden nodes in the fully connected layer of each transformer block. (default: 64)
num_classes (int) – The number of classes to predict. (default: 2)
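The sketch below shows, under stated assumptions, how these hyperparameters could interact in a generic transformer encoder: the time axis is split into chunk_size // t_patch_size patches, each attention layer works with heads * head_channels inner channels before projecting back to hid_channels, and depth stacks attention and MLP sublayers with residual connections. It is a NumPy illustration with random weights, a ReLU MLP, and no LayerNorm; TorchEEG's actual layer layout may differ, and all function names are hypothetical.

```python
import numpy as np


def num_temporal_patches(chunk_size: int = 128, t_patch_size: int = 32) -> int:
    # Assumption: the time axis is cut into non-overlapping patches of t_patch_size points.
    if chunk_size % t_patch_size != 0:
        raise ValueError("chunk_size must be divisible by t_patch_size")
    return chunk_size // t_patch_size


def _softmax(s: np.ndarray) -> np.ndarray:
    s = s - s.max(axis=-1, keepdims=True)
    e = np.exp(s)
    return e / e.sum(axis=-1, keepdims=True)


def multi_head_self_attention(x, heads=4, head_channels=8, seed=0):
    """x: (tokens, hid_channels) -> (tokens, hid_channels), with random weights."""
    rng = np.random.default_rng(seed)
    tokens, hid = x.shape
    inner = heads * head_channels                      # width of all heads together
    w_qkv = rng.standard_normal((hid, 3 * inner)) / np.sqrt(hid)
    w_out = rng.standard_normal((inner, hid)) / np.sqrt(inner)

    def split_heads(t):                                # (tokens, inner) -> (heads, tokens, head_channels)
        return t.reshape(tokens, heads, head_channels).transpose(1, 0, 2)

    q, k, v = (split_heads(t) for t in np.split(x @ w_qkv, 3, axis=-1))
    attn = _softmax(q @ k.transpose(0, 2, 1) / np.sqrt(head_channels))  # (heads, tokens, tokens)
    out = (attn @ v).transpose(1, 0, 2).reshape(tokens, inner)          # concatenate the heads
    return out @ w_out                                 # project back to hid_channels


def feed_forward(x, mlp_channels=64, seed=1):
    rng = np.random.default_rng(seed)
    hid = x.shape[-1]
    w1 = rng.standard_normal((hid, mlp_channels)) / np.sqrt(hid)
    w2 = rng.standard_normal((mlp_channels, hid)) / np.sqrt(mlp_channels)
    return np.maximum(x @ w1, 0.0) @ w2                # ReLU hidden layer of mlp_channels units


def encoder(x, depth=3, heads=4, head_channels=8, mlp_channels=64):
    # depth repetitions of (attention + MLP), each with a residual connection; LayerNorm omitted.
    for layer in range(depth):
        x = x + multi_head_self_attention(x, heads, head_channels, seed=2 * layer)
        x = x + feed_forward(x, mlp_channels, seed=2 * layer + 1)
    return x


tokens = num_temporal_patches(128, 32)                 # 4 patches
x = np.random.default_rng(42).standard_normal((tokens, 32))  # hid_channels = 32
print(encoder(x).shape)  # (4, 32)
```

In this sketch, changing heads or head_channels alters only the inner attention width; the output still has hid_channels features per patch.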
- forward(x: Tensor) → Tensor
- Parameters:
x (torch.Tensor) – EEG signal representation; the ideal input shape is [n, 1, 32, 128]. Here, n corresponds to the batch size, 32 corresponds to num_electrodes, and 128 corresponds to chunk_size.
- Returns:
the predicted probability that the samples belong to the classes.
- Return type:
torch.Tensor[number of samples, number of classes]
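A small, hypothetical pre- and post-processing sketch: check_input_shape rejects inputs whose shape does not match the documented ideal [n, 1, num_electrodes, chunk_size], and predicted_classes turns each row of the [number of samples, number of classes] output into a class index. Neither helper is part of TorchEEG. Plain tuples and lists are used so the sketch runs without PyTorch.

```python
from typing import List, Sequence


def check_input_shape(shape: Sequence[int], num_electrodes: int = 32, chunk_size: int = 128) -> int:
    """Return the batch size n if shape is [n, 1, num_electrodes, chunk_size]."""
    if len(shape) != 4:
        raise ValueError(f"expected a 4D input, got shape {tuple(shape)}")
    n, channels, electrodes, points = shape
    if (channels, electrodes, points) != (1, num_electrodes, chunk_size):
        raise ValueError(f"expected [n, 1, {num_electrodes}, {chunk_size}], got {list(shape)}")
    return n


def predicted_classes(scores: Sequence[Sequence[float]]) -> List[int]:
    """Index of the largest entry in each row of an [n, num_classes] output."""
    return [max(range(len(row)), key=lambda i: row[i]) for row in scores]


print(check_input_shape((64, 1, 32, 128)))             # 64
print(predicted_classes([[0.2, 0.8], [0.9, 0.1]]))     # [1, 0]
```

Because torch.Size is a tuple subclass, check_input_shape(x.shape) also accepts a tensor's shape; for a tensor output, out.argmax(dim=1) performs the same row-wise selection. Taking the argmax gives the same class whether the rows hold probabilities or unnormalized scores.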