ATCNet
- class torcheeg.models.ATCNet(in_channels: int = 1, num_classes: int = 4, num_windows: int = 3, num_electrodes: int = 22, conv_pool_size: int = 7, F1: int = 16, D: int = 2, tcn_kernel_size: int = 4, tcn_depth: int = 2, chunk_size: int = 1125)[source]
ATCNet: An attention-based temporal convolutional network for EEG-based motor imagery classification. For more details, please refer to the following information:
Paper: H. Altaheri, G. Muhammad and M. Alsulaiman, “Physics-Informed Attention Temporal Convolutional Network for EEG-Based Motor Imagery Classification,” in IEEE Transactions on Industrial Informatics, vol. 19, no. 2, pp. 2249-2258, Feb. 2023, doi: 10.1109/TII.2022.3197419.
```python
import torch
from torcheeg.models import ATCNet

model = ATCNet(in_channels=1,
               num_classes=4,
               num_windows=3,
               num_electrodes=22,
               chunk_size=128)
input = torch.rand(2, 1, 22, 128)  # (batch_size, in_channels, num_electrodes, chunk_size)
output = model(input)
```
- Parameters:
  - in_channels (int) – The number of channels of the signal corresponding to each electrode. If the original signal is used as input, in_channels is set to 1; if the original signal is split into multiple sub-bands, in_channels is set to the number of bands. (default: 1)
  - num_electrodes (int) – The number of electrodes. (default: 22)
  - num_classes (int) – The number of classes to predict. (default: 4)
  - num_windows (int) – The number of sliding windows after the conv block. (default: 3)
  - conv_pool_size (int) – The kernel size of the second average pooling layer in the conv block. (default: 7)
  - F1 (int) – The channel size of the temporal feature maps in the conv block. (default: 16)
  - D (int) – The number of second conv layer filters linked to each temporal feature map from the previous layer in the conv block. (default: 2)
  - tcn_kernel_size (int) – The kernel size of the conv layers in the TCN block. (default: 4)
  - tcn_depth (int) – The number of times the TCN block is repeated. (default: 2)
  - chunk_size (int) – The number of data points included in each EEG chunk. (default: 1125)
- forward(x)[source]
  - Parameters:
    - x (torch.Tensor) – EEG signal representation; the ideal input shape is [n, 22, 1125]. Here, n corresponds to the batch size, 22 corresponds to num_electrodes, and 1125 corresponds to chunk_size.
  - Returns:
    The predicted probability that the samples belong to the classes.
  - Return type:
    torch.Tensor[size of batch, number of classes]
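Since the forward pass returns a tensor of shape [batch size, number of classes], the predicted label for each sample is the index of its highest score. A minimal sketch (using a stand-in tensor in place of actual model output, so only torch is assumed):

```python
import torch

# Stand-in for model output: a (batch_size=2, num_classes=4) score tensor.
output = torch.tensor([[0.1, 0.7, 0.1, 0.1],
                       [0.8, 0.1, 0.05, 0.05]])

# The predicted class for each sample is the index of its highest score.
predicted = output.argmax(dim=1)
print(predicted)  # tensor([1, 0])
```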