# Source code for torcheeg.models.transformer.atcnet
import torch.nn as nn
import torch.nn.functional as F
import torch
class ATCNet(nn.Module):
    r'''
    ATCNet: An attention-based temporal convolutional network for EEG-based motor imagery classification. For more details, please refer to the following information:

    - Paper: H. Altaheri, G. Muhammad and M. Alsulaiman, "Physics-Informed Attention Temporal Convolutional Network for EEG-Based Motor Imagery Classification," in IEEE Transactions on Industrial Informatics, vol. 19, no. 2, pp. 2249-2258, Feb. 2023, doi: 10.1109/TII.2022.3197419.
    - URL: https://github.com/Altaheri/EEG-ATCNet

    Below is a quick start example:

    .. code-block:: python

        import torch

        from torcheeg.models import ATCNet

        model = ATCNet(in_channels=1,
                       num_classes=4,
                       num_windows=3,
                       num_electrodes=22,
                       chunk_size=1125)
        # shape: (batch_size, in_channels, num_electrodes, chunk_size)
        input = torch.rand(2, 1, 22, 1125)
        output = model(input)

    The model is assembled in three stages: a convolutional block that reduces the electrode axis to a single row
    and downsamples time; :obj:`num_windows` overlapping sliding windows over the downsampled sequence, each with its
    own multi-head self-attention layer and dense classifier; and a stack of :obj:`tcn_depth` residual temporal
    convolution blocks per window. The per-window class scores are averaged.

    Args:
        in_channels (int): The number of input channels per electrode. Use 1 for raw EEG signals, or set to the number of
            frequency bands if the signal is decomposed into multiple sub-bands. (default: :obj:`1`)
        num_classes (int): The number of classes to predict. (default: :obj:`4`)
        num_windows (int): The number of temporal sliding windows after the convolutional block. Controls the temporal
            resolution of the attention mechanism. Must be at least 1 and no larger than the sequence length produced
            by the convolutional block. (default: :obj:`3`)
        num_electrodes (int): The number of EEG electrodes/channels in the input signal. (default: :obj:`22`)
        conv_pool_size (int): The kernel size of the second average pooling layer in the convolutional block. Affects the
            temporal downsampling rate. (default: :obj:`7`)
        F1 (int): The number of temporal feature maps (filters) in the first convolutional layer. (default: :obj:`16`)
        D (int): The spatial filter multiplier - number of filters per temporal feature map in the second convolutional
            layer. (default: :obj:`2`)
        tcn_kernel_size (int): The kernel size of the convolutional layers in the Temporal Convolutional Network (TCN)
            block. (default: :obj:`4`)
        tcn_depth (int): The number of TCN layers/iterations in the model. Each layer increases the receptive field.
            (default: :obj:`2`)
        chunk_size (int): The number of time points in each EEG segment/chunk. (default: :obj:`1125`)

    Raises:
        ValueError: If :obj:`num_windows` is smaller than 1, or if the convolutional block leaves fewer time steps
            than :obj:`num_windows`.
    '''

    def __init__(self,
                 in_channels: int = 1,
                 num_classes: int = 4,
                 num_windows: int = 3,
                 num_electrodes: int = 22,
                 conv_pool_size: int = 7,
                 F1: int = 16,
                 D: int = 2,
                 tcn_kernel_size: int = 4,
                 tcn_depth: int = 2,
                 chunk_size: int = 1125,
                 ):
        super(ATCNet, self).__init__()
        # forward() averages over the windows, so there must be at least one.
        if num_windows < 1:
            raise ValueError(
                'num_windows must be at least 1, got {}.'.format(num_windows))

        self.in_channels = in_channels
        self.num_classes = num_classes
        self.num_windows = num_windows
        self.num_electrodes = num_electrodes
        self.pool_size = conv_pool_size
        self.F1 = F1
        self.D = D
        self.tcn_kernel_size = tcn_kernel_size
        self.tcn_depth = tcn_depth
        self.chunk_size = chunk_size

        F2 = F1 * D
        # NOTE: the batch-norm layers use the default eps. Passing `False` as
        # the second positional argument would set eps=0 (not affine=False)
        # and remove the numerical-stability term from the normalisation.
        self.conv_block = nn.Sequential(
            # Temporal convolution along the time axis.
            nn.Conv2d(in_channels, F1, (1, int(chunk_size / 2 + 1)),
                      stride=1, padding='same', bias=False),
            nn.BatchNorm2d(F1),
            # Depthwise convolution spanning all electrodes; height becomes 1.
            nn.Conv2d(F1, F2, (num_electrodes, 1), padding=0, groups=F1),
            nn.BatchNorm2d(F2),
            nn.ELU(),
            nn.AvgPool2d((1, 8)),
            nn.Dropout2d(0.1),
            nn.Conv2d(F2, F2, (1, 16), bias=False, padding='same'),
            nn.BatchNorm2d(F2),
            nn.ELU(),
            nn.AvgPool2d((1, self.pool_size)),
            nn.Dropout2d(0.1)
        )
        self.__build_model()

    def __build_model(self):
        '''
        Register the per-window submodules by tracing a dummy batch.

        Layer sizes depend on the output shape of :obj:`conv_block`, so an all-zero batch of two samples is pushed
        through the network under :obj:`torch.no_grad` and each submodule is created as soon as its input shape is
        known. The control flow mirrors :obj:`forward`.
        '''
        with torch.no_grad():
            x = torch.zeros(2, self.in_channels,
                            self.num_electrodes, self.chunk_size)
            x = self.conv_block(x)
            # The depthwise convolution collapsed the electrode axis to one row.
            x = x[:, :, -1, :]
            # (batch, features, time) -> (batch, time, features)
            x = x.permute(0, 2, 1)
            self.__chan_dim, self.__embed_dim = x.shape[1:]
            self.win_len = self.__chan_dim - self.num_windows + 1
            if self.win_len < 1:
                raise ValueError(
                    'The convolutional block yields {} time steps, which is '
                    'fewer than num_windows={}. Increase chunk_size or '
                    'reduce num_windows/conv_pool_size.'.format(
                        self.__chan_dim, self.num_windows))

            for i in range(self.num_windows):
                start = i
                stop = x.shape[1] - self.num_windows + i + 1
                window = x[:, start:stop, :]

                self.__add_msa(i)
                attended = self.get_submodule(
                    'msa' + str(i))(window, window, window)[0]
                self.__add_msa_drop(i)
                # Dropout is applied to the attention output, which is then
                # added back to the window as a residual connection.
                attended = self.get_submodule('msa_drop' + str(i))(attended)
                window = torch.add(window, attended)

                for j in range(self.tcn_depth):
                    # NOTE(review): the index (i + 1) * j is not unique per
                    # (window, depth) pair: it is 0 for every window when
                    # j == 0, so later registrations replace 'tcn0' and all
                    # windows share it. Kept unchanged because renaming would
                    # alter state_dict keys of existing checkpoints.
                    tcn_name = 'tcn' + str((i + 1) * j)
                    self.__add_tcn((i + 1) * j, window.shape[1])
                    out = self.get_submodule(tcn_name)(window)
                    if window.shape[1] != out.shape[1]:
                        # Project the residual branch to the TCN width.
                        self.__add_recov(i)
                        window = self.get_submodule('re' + str(i))(window)
                    window = F.elu(torch.add(window, out))

                window = window[:, -1, :]
                self.__dense_dim = window.shape[-1]
                self.__add_dense(i)
                window = self.get_submodule('dense' + str(i))(window)

    def __add_msa(self, index: int):
        '''Register the two-head self-attention layer ``msa<index>``.'''
        self.add_module('msa' + str(index), nn.MultiheadAttention(
            embed_dim=self.__embed_dim,
            num_heads=2,
            batch_first=True))

    def __add_msa_drop(self, index):
        '''Register the dropout layer ``msa_drop<index>`` used after attention.'''
        self.add_module('msa_drop' + str(index), nn.Dropout(0.3))

    def __add_tcn(self, index: int, num_electrodes: int):
        '''
        Register the temporal convolution block ``tcn<index>``.

        Args:
            index (int): Suffix of the registered module name.
            num_electrodes (int): Number of input channels of the first :obj:`Conv1d` (dimension 1 of the tensor fed
                to the block).
        '''
        self.add_module('tcn' + str(index),
                        nn.Sequential(
                            nn.Conv1d(num_electrodes, 32,
                                      self.tcn_kernel_size, padding='same'),
                            nn.BatchNorm1d(32),
                            nn.ELU(),
                            nn.Dropout(0.3),
                            nn.Conv1d(32, 32, self.tcn_kernel_size,
                                      padding='same'),
                            nn.BatchNorm1d(32),
                            nn.ELU(),
                            nn.Dropout(0.3))
                        )

    def __add_recov(self, index: int):
        '''Register ``re<index>``, which maps a window of length :obj:`win_len` to the 32 TCN channels.'''
        self.add_module('re' + str(index),
                        nn.Conv1d(self.win_len, 32, 4, padding='same'))

    def __add_dense(self, index: int):
        '''Register the linear classifier ``dense<index>`` for one window.'''
        self.add_module('dense' + str(index),
                        nn.Linear(self.__dense_dim, self.num_classes))

    def forward(self, x):
        r'''
        Args:
            x (torch.Tensor): EEG signal representation, the ideal input shape is :obj:`[n, 1, 22, 1125]`. Here, :obj:`n` corresponds to the batch size, :obj:`1` corresponds to :obj:`in_channels`, :obj:`22` corresponds to :obj:`num_electrodes`, and :obj:`1125` corresponds to :obj:`chunk_size`.

        Returns:
            torch.Tensor[size of batch, number of classes]: The unnormalized class scores (logits), averaged over the sliding windows. No softmax is applied.
        '''
        x = self.conv_block(x)
        # The depthwise convolution collapsed the electrode axis to one row.
        x = x[:, :, -1, :]
        # (batch, features, time) -> (batch, time, features)
        x = x.permute(0, 2, 1)

        logits_sum = None
        for i in range(self.num_windows):
            start = i
            stop = x.shape[1] - self.num_windows + i + 1
            window = x[:, start:stop, :]

            attended = self.get_submodule(
                'msa' + str(i))(window, window, window)[0]
            # Dropout on the attention output, then the residual connection.
            attended = self.get_submodule('msa_drop' + str(i))(attended)
            window = torch.add(window, attended)

            for j in range(self.tcn_depth):
                out = self.get_submodule('tcn' + str((i + 1) * j))(window)
                if window.shape[1] != out.shape[1]:
                    window = self.get_submodule('re' + str(i))(window)
                window = F.elu(torch.add(window, out))

            window = window[:, -1, :]
            logits = self.get_submodule('dense' + str(i))(window)
            logits_sum = logits if logits_sum is None else logits_sum.add(logits)

        return logits_sum / self.num_windows