STNet

class torcheeg.models.STNet(chunk_size: int = 128, grid_size: Tuple[int, int] = (9, 9), num_classes: int = 2, dropout: float = 0.2)

Spatio-temporal Network (STNet). For more details, please refer to the original paper.

Below is a recommended setup for emotion recognition tasks:

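from torcheeg import transforms
from torcheeg.datasets import DEAPDataset
from torcheeg.models import STNet
# The import path of this constant may differ between TorchEEG versions.
from torcheeg.datasets.constants.emotion_recognition.deap import DEAP_CHANNEL_LOCATION_DICT

# Map each EEG chunk onto a 9x9 electrode grid and binarize valence ratings at 5.0.
dataset = DEAPDataset(io_path='./deap',
                      root_path='./data_preprocessed_python',
                      offline_transform=transforms.Compose([
                          transforms.ToGrid(DEAP_CHANNEL_LOCATION_DICT)
                      ]),
                      online_transform=transforms.ToTensor(),
                      label_transform=transforms.Compose([
                          transforms.Select('valence'),
                          transforms.Binary(5.0),
                      ]))
model = STNet(num_classes=2, chunk_size=128, grid_size=(9, 9), dropout=0.2)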
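To make the data side of this setup more concrete, the sketch below imitates what the two transform chains are described as producing: each chunk of shape [channels, chunk_size] is scattered onto a 9 x 9 electrode grid, giving [chunk_size, 9, 9], and the valence rating is thresholded at 5.0. It is a simplified NumPy illustration, not TorchEEG's implementation. `to_grid`, `binarize`, `TOY_CHANNELS`, and the three-electrode `TOY_LOCATION_DICT` are hypothetical stand-ins for `transforms.ToGrid`, `transforms.Binary`, and `DEAP_CHANNEL_LOCATION_DICT`, and the rule "ratings above the threshold map to 1" is an assumption.

```python
import numpy as np

# Hypothetical channel-to-grid mapping: electrode name -> (row, col) in a 9x9 grid.
# These positions are illustrative only; the real DEAP_CHANNEL_LOCATION_DICT
# covers all 32 DEAP electrodes.
TOY_LOCATION_DICT = {'FP1': (0, 3), 'FZ': (2, 4), 'OZ': (8, 4)}
TOY_CHANNELS = ['FP1', 'FZ', 'OZ']


def to_grid(eeg, channels, location_dict, grid_size=(9, 9)):
    """Scatter a [num_channels, chunk_size] chunk onto a [chunk_size, H, W] grid.

    Grid cells without an electrode stay zero.
    """
    num_channels, chunk_size = eeg.shape
    if num_channels != len(channels):
        raise ValueError("one channel name is needed per row of eeg")
    grid = np.zeros((chunk_size, *grid_size), dtype=eeg.dtype)
    for idx, name in enumerate(channels):
        row, col = location_dict[name]
        grid[:, row, col] = eeg[idx]  # the whole time series goes to one cell
    return grid


def binarize(value, threshold=5.0):
    """Assumed rule: ratings above the threshold -> 1, otherwise -> 0."""
    return int(value > threshold)


chunk = np.random.randn(3, 128).astype(np.float32)  # 3 channels x 128 time points
grid = to_grid(chunk, TOY_CHANNELS, TOY_LOCATION_DICT)
print(grid.shape)  # (128, 9, 9)
print(binarize(7.2), binarize(3.1))  # 1 0
```

The time axis ends up first, which matches the [n, 128, 9, 9] input layout that forward expects once samples are batched.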

Parameters
  • chunk_size (int) – Number of data points included in each EEG chunk, i.e., \(T\) in the paper. (default: 128)

  • grid_size (tuple) – Spatial dimensions of the grid-like EEG representation. (default: (9, 9))

  • num_classes (int) – The number of classes to predict. (default: 2)

  • dropout (float) – Probability of an element being zeroed in the dropout layers. (default: 0.2)
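Assuming STNet's dropout layers are standard `torch.nn.Dropout` modules, the `dropout` value behaves as "inverted dropout": during training each element is zeroed with probability p and the surviving elements are scaled by 1 / (1 - p), so the expected activation is unchanged; in evaluation mode the layer passes its input through untouched. A minimal NumPy sketch of that rule (the helper `inverted_dropout` is hypothetical, and unlike PyTorch it rejects p = 1 to avoid dividing by zero):

```python
import numpy as np


def inverted_dropout(x, p, training=True, rng=None):
    """Zero each element with probability p and rescale survivors by 1 / (1 - p).

    Returns x unchanged when not training or when p == 0.
    """
    if not 0.0 <= p < 1.0:
        raise ValueError("p must be in [0, 1)")
    if not training or p == 0.0:
        return x
    rng = np.random.default_rng() if rng is None else rng
    keep = rng.random(x.shape) >= p  # True with probability 1 - p
    return np.where(keep, x / (1.0 - p), 0.0)


x = np.ones((4, 1024))
y = inverted_dropout(x, p=0.2, rng=np.random.default_rng(0))
# Entries are either 0.0 (dropped) or 1.25 (kept and rescaled by 1 / 0.8),
# so the mean stays close to the original value of 1.0.
print(np.unique(y), y.mean())
```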

forward(x: Tensor) → Tensor
Parameters

x (torch.Tensor) – EEG signal representation. With the default arguments, the expected input shape is [n, 128, 9, 9], where n is the batch size, 128 corresponds to chunk_size, and (9, 9) corresponds to grid_size.
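Samples prepared one at a time, such as the [128, 9, 9] grids produced by the recommended setup, have to be stacked along a new leading batch dimension before they are passed to forward. The NumPy sketch below shows that stacking together with a hypothetical `check_stnet_input` helper (not part of TorchEEG) that compares a batch against the chunk_size and grid_size the model was built with:

```python
import numpy as np


def check_stnet_input(batch, chunk_size=128, grid_size=(9, 9)):
    """Hypothetical helper: raise unless batch has shape [n, chunk_size, *grid_size]."""
    expected_tail = (chunk_size, *grid_size)
    if batch.ndim != 4 or tuple(batch.shape[1:]) != expected_tail:
        raise ValueError(
            f"expected shape [n, {chunk_size}, {grid_size[0]}, {grid_size[1]}], "
            f"got {list(batch.shape)}"
        )
    return batch.shape[0]  # the batch size n


# Four individual samples, each [chunk_size, H, W] = [128, 9, 9].
samples = [np.random.randn(128, 9, 9).astype(np.float32) for _ in range(4)]
batch = np.stack(samples, axis=0)  # new leading axis -> [4, 128, 9, 9]
n = check_stnet_input(batch)
print(batch.shape, n)  # (4, 128, 9, 9) 4
```

In PyTorch, `torch.stack` (or a `DataLoader` with its default collation) performs the same stacking, and `torch.from_numpy(batch)` converts such an array into a tensor. A model built with a different chunk_size or grid_size needs the matching trailing dimensions.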

Returns

The predicted probability that each sample belongs to each class.

Return type

torch.Tensor[number of samples, number of classes]
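The output has one row per sample and one column per class. If those values are unnormalized scores (logits), as is common for PyTorch classifiers trained with a cross-entropy loss, a softmax over the class dimension turns each row into probabilities, and an argmax selects the predicted class. A NumPy sketch under that assumption, where `logits` is made-up example data rather than real STNet output:

```python
import numpy as np


def softmax(logits, axis=1):
    """Numerically stable softmax along the class axis."""
    shifted = logits - logits.max(axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)


# Made-up output for 3 samples with num_classes=2: [number of samples, number of classes].
logits = np.array([[2.0, -1.0],
                   [0.5, 0.5],
                   [-3.0, 1.0]])
probs = softmax(logits)       # each row now sums to 1
pred = probs.argmax(axis=1)   # predicted class index per sample (ties -> lowest index)
print(pred)  # [0 0 1]
```

With PyTorch tensors, the equivalent calls are `torch.softmax(output, dim=1)` and `output.argmax(dim=1)`; if the output is already a probability distribution, the softmax step can be skipped and the argmax applied directly.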
