ATCNet
- class torcheeg.models.ATCNet(in_channels: int = 1, num_classes: int = 4, num_windows: int = 3, num_electrodes: int = 22, conv_pool_size: int = 7, F1: int = 16, D: int = 2, tcn_kernel_size: int = 4, tcn_depth: int = 2, chunk_size: int = 1125)
ATCNet: An attention-based temporal convolutional network for EEG-based motor imagery classification. For more details, please refer to the following paper:
Paper: H. Altaheri, G. Muhammad and M. Alsulaiman, “Physics-Informed Attention Temporal Convolutional Network for EEG-Based Motor Imagery Classification,” in IEEE Transactions on Industrial Informatics, vol. 19, no. 2, pp. 2249-2258, Feb. 2023, doi: 10.1109/TII.2022.3197419.
Below is a quick start example:
```python
import torch
from torcheeg.models import ATCNet

model = ATCNet(in_channels=1,
               num_classes=4,
               num_windows=3,
               num_electrodes=22,
               chunk_size=1125)

# shape: (batch_size, in_channels, num_electrodes, chunk_size)
x = torch.rand(2, 1, 22, 1125)
output = model(x)  # shape: (batch_size, num_classes), here (2, 4)
```
- Parameters:
in_channels (int) – The number of input channels per electrode. Use 1 for raw EEG signals, or set it to the number of frequency bands if the signal is decomposed into multiple sub-bands. (default: 1)
num_classes (int) – The number of classes to predict. (default: 4)
num_windows (int) – The number of overlapping temporal windows into which the output of the convolutional block is split; each window is passed through the attention and TCN blocks. (default: 3)
num_electrodes (int) – The number of EEG electrodes (channels) in the input signal. (default: 22)
conv_pool_size (int) – The kernel size of the second average pooling layer in the convolutional block. Larger values downsample the time axis more strongly. (default: 7)
F1 (int) – The number of temporal feature maps (filters) in the first convolutional layer. (default: 16)
D (int) – The spatial filter multiplier, i.e. the number of spatial filters learned for each temporal feature map in the second convolutional layer, so that layer produces F1 * D feature maps. (default: 2)
tcn_kernel_size (int) – The kernel size of the convolutional layers in the Temporal Convolutional Network (TCN) block. (default: 4)
tcn_depth (int) – The number of stacked layers in the TCN block; each additional layer enlarges the receptive field. (default: 2)
chunk_size (int) – The number of time points in each EEG segment (chunk). (default: 1125)
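The temporal parameters interact: `conv_pool_size` sets how much the convolutional block shortens the sequence, and `num_windows` then splits that shortened sequence into overlapping windows. The sketch below estimates these lengths under assumptions taken from the ATCNet paper rather than verified against this implementation: a first average pooling of size 8 and a second of size `conv_pool_size` (each flooring the length), with windows sliding by one step so each has length `T_c - num_windows + 1`. It also reports the F1 * D feature maps of the spatial convolution. The helper `atcnet_shapes` and its `first_pool_size` argument are hypothetical; check the source if exact sizes matter.

```python
def atcnet_shapes(chunk_size: int = 1125, conv_pool_size: int = 7,
                  num_windows: int = 3, F1: int = 16, D: int = 2,
                  first_pool_size: int = 8) -> dict:
    """Hypothetical helper estimating ATCNet's intermediate sizes.

    Assumes (from the paper, not verified against torcheeg): two average
    pooling stages with stride equal to kernel size, and sliding windows
    with stride 1 over the pooled sequence.
    """
    # Each non-overlapping pooling stage floors the sequence length.
    t_c = chunk_size // first_pool_size // conv_pool_size
    if num_windows > t_c:
        raise ValueError("num_windows cannot exceed the pooled sequence length")
    window_len = t_c - num_windows + 1
    return {
        "feature_maps": F1 * D,       # D spatial filters per temporal map
        "pooled_length": t_c,         # sequence length after the conv block
        "window_length": window_len,  # length of each sliding window
    }


print(atcnet_shapes())
# {'feature_maps': 32, 'pooled_length': 20, 'window_length': 18}
```

Under the same assumptions, `conv_pool_size=4` would turn a 1125-sample chunk into 35 steps instead of 20, giving the attention and TCN blocks a finer temporal grid at a higher compute cost.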
- forward(x)
- Parameters:
x (torch.Tensor) – EEG signal representation. The ideal input shape is [n, 1, 22, 1125]. Here, n corresponds to the batch size, 1 corresponds to in_channels, 22 corresponds to num_electrodes, and 1125 corresponds to chunk_size.
- Returns:
The predicted probability that the samples belong to the classes.
- Return type:
torch.Tensor[size of batch, number of classes]
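If a batch is stored as [n, num_electrodes, chunk_size] (for example [n, 22, 1125]), it needs an explicit channel axis before it is passed to forward. The sketch below assumes raw signals with in_channels=1; the helpers `add_channel_axis` and `to_labels` are hypothetical, and the random `scores` tensor only stands in for `model(x)` so the snippet runs without building the model. Taking `argmax` over the class dimension yields the same labels whether the returned scores are probabilities or unnormalized logits.

```python
import torch


def add_channel_axis(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: turn [n, num_electrodes, chunk_size] into
    [n, 1, num_electrodes, chunk_size] for raw EEG (in_channels=1)."""
    if x.dim() == 3:
        return x.unsqueeze(1)  # insert the in_channels axis at dim 1
    if x.dim() == 4:
        return x  # already [n, in_channels, num_electrodes, chunk_size]
    raise ValueError(f"expected a 3D or 4D tensor, got shape {tuple(x.shape)}")


def to_labels(scores: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: map [n, num_classes] scores to class indices."""
    return scores.argmax(dim=1)


eeg = torch.rand(2, 22, 1125)  # [n, num_electrodes, chunk_size]
x = add_channel_axis(eeg)      # [2, 1, 22, 1125]

# Placeholder for `scores = model(x)`, which returns [n, num_classes].
scores = torch.rand(2, 4)
labels = to_labels(scores)     # shape [2], values in 0..3
```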