Transformer

torcheeg.models.SimpleViT

class torcheeg.models.SimpleViT(in_channels: int = 5, grid_size: Tuple[int, int] = (9, 9), patch_size: int = 3, hid_channels: int = 32, depth: int = 3, heads: int = 4, head_channels: int = 8, mlp_channels: int = 64, num_classes: int = 2)[source]

Bases: Module

A Simple and Effective Vision Transformer (SimpleViT). The authors of the Vision Transformer (ViT) present a few minor modifications that dramatically improve the performance of plain ViT models. For more details, please refer to the following information.

  • Paper: Beyer L, Zhai X, Kolesnikov A. Better plain ViT baselines for ImageNet-1k. arXiv preprint arXiv:2205.01580, 2022.

  • URL: https://arxiv.org/abs/2205.01580

It is worth noting that this model was not designed for EEG analysis, but it shows good performance and can serve as a good starting point for research.

Below is a recommended configuration for emotion recognition tasks:

from torcheeg import transforms
from torcheeg.datasets import DEAPDataset
# The import path of this constant may differ across torcheeg versions.
from torcheeg.datasets.constants import DEAP_CHANNEL_LOCATION_DICT
from torcheeg.models import SimpleViT

dataset = DEAPDataset(io_path='./deap',
                      root_path='./data_preprocessed_python',
                      # Extract differential entropy in five frequency bands
                      # and arrange the 32 electrodes on a 9 x 9 grid.
                      offline_transform=transforms.Compose([
                          transforms.BandDifferentialEntropy(band_dict={
                              "delta": [1, 4],
                              "theta": [4, 8],
                              "alpha": [8, 14],
                              "beta": [14, 31],
                              "gamma": [31, 49]
                          }),
                          transforms.ToGrid(DEAP_CHANNEL_LOCATION_DICT)
                      ]),
                      online_transform=transforms.Compose([
                          transforms.ToTensor(),
                      ]),
                      # Binarize the valence rating at the midpoint 5.0.
                      label_transform=transforms.Compose([
                          transforms.Select('valence'),
                          transforms.Binary(5.0),
                      ]))
model = SimpleViT(hid_channels=32,
                  depth=3,
                  heads=4,
                  mlp_channels=64,
                  grid_size=(9, 9),
                  patch_size=3,
                  num_classes=2,
                  in_channels=5,
                  head_channels=64)
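
The tensors produced by ToGrid have shape (in_channels, 9, 9), so the model consumes batches of shape (batch_size, in_channels, *grid_size). A minimal shape sanity check, assuming a random tensor in place of real DEAP samples (the batch size of 2 is illustrative only):

import torch

# Hypothetical mini-batch: 2 samples, 5 band features per electrode on a 9 x 9 grid.
mock_eeg = torch.randn(2, 5, 9, 9)

# Call the module itself (not .forward) so that registered hooks also run.
pred = model(mock_eeg)
print(pred.shape)  # expected: torch.Size([2, 2]), i.e. (batch_size, num_classes)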
Parameters
  • in_channels (int) – The feature dimension of each electrode. (default: 5)

  • grid_size (tuple) – Spatial dimensions of the grid-like EEG representation. (default: (9, 9))

  • patch_size (int) – The size (resolution) of each input patch; see the sketch after this list. (default: 3)

  • hid_channels (int) – The feature dimension of each embedded patch. (default: 32)

  • depth (int) – The number of attention layers in each transformer block. (default: 3)

  • heads (int) – The number of attention heads in each attention layer. (default: 4)

  • head_channels (int) – The dimension of each attention head in each attention layer. (default: 8)

  • mlp_channels (int) – The number of hidden nodes in the fully connected layer of each transformer block. (default: 64)

  • num_classes (int) – The number of classes to predict. (default: 2)
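
In a standard ViT-style patch embedding, the grid is split into non-overlapping patch_size x patch_size tiles, so each dimension of grid_size should be divisible by patch_size. A minimal sketch of the resulting token count, assuming SimpleViT follows this standard scheme (the helper below is illustrative, not part of torcheeg):

def num_patches(grid_size=(9, 9), patch_size=3):
    # Non-overlapping tiles: each spatial dimension must divide evenly.
    assert grid_size[0] % patch_size == 0 and grid_size[1] % patch_size == 0
    return (grid_size[0] // patch_size) * (grid_size[1] // patch_size)

print(num_patches())  # 9 tokens: a 9 x 9 grid cut into 3 x 3 patches

Under this assumption, each patch carries in_channels * patch_size ** 2 = 5 * 9 = 45 raw features before being projected to hid_channels.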

forward(x: Tensor) → Tensor[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
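
The practical difference is easy to observe with a forward hook; the hook below is illustrative:

import torch

def shape_hook(module, inputs, output):
    # Fires only when the module is invoked as model(x).
    print('output shape:', output.shape)

handle = model.register_forward_hook(shape_hook)
model(torch.randn(2, 5, 9, 9))          # hook runs
model.forward(torch.randn(2, 5, 9, 9))  # hook is silently skipped
handle.remove()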

training: bool