
Source code for torcheeg.models.transformer.atcnet

import torch.nn as nn
import torch.nn.functional as F
import torch

class ATCNet(nn.Module):
    r'''
    ATCNet: An attention-based temporal convolutional network for EEG-based motor imagery classification. For more details, please refer to the following information:

    - Paper: H. Altaheri, G. Muhammad and M. Alsulaiman, "Physics-Informed Attention Temporal Convolutional Network for EEG-Based Motor Imagery Classification," in IEEE Transactions on Industrial Informatics, vol. 19, no. 2, pp. 2249-2258, Feb. 2023, doi: 10.1109/TII.2022.3197419.

    .. code-block:: python

        import torch

        from torcheeg.models import ATCNet

        model = ATCNet(in_channels=1,
                       num_classes=4,
                       num_windows=3,
                       num_electrodes=22,
                       chunk_size=1125)
        input = torch.rand(2, 1, 22, 1125)  # (batch_size, in_channels, num_electrodes, chunk_size)
        output = model(input)

    Args:
        in_channels (int): The number of channels of the signal corresponding to each electrode. If the original signal is used as input, in_channels is set to 1; if the original signal is split into multiple sub-bands, in_channels is set to the number of bands. (default: :obj:`1`)
        num_classes (int): The number of classes to predict. (default: :obj:`4`)
        num_windows (int): The number of sliding windows after conv block. Must be at least 1 and no larger than the temporal length produced by the conv block. (default: :obj:`3`)
        num_electrodes (int): The number of electrodes if the input is EEG signal. (default: :obj:`22`)
        conv_pool_size (int): The size of the second average pooling layer kernel in the conv block. (default: :obj:`7`)
        F1 (int): The channel size of the temporal feature maps in conv block. (default: :obj:`16`)
        D (int): The number of second conv layer's filters linked to each temporal feature map in the previous layer in conv block. (default: :obj:`2`)
        tcn_kernel_size (int): The size of conv layers kernel in the TCN block. (default: :obj:`4`)
        tcn_depth (int): The times of TCN loop. (default: :obj:`2`)
        chunk_size (int): The number of data points included in each EEG chunk. (default: :obj:`1125`)

    Raises:
        ValueError: If :obj:`num_windows` is smaller than 1, or if the conv block
            output is too short to hold :obj:`num_windows` sliding windows.
    '''

    # Output width of every TCN / residual-recovery convolution.
    _TCN_CHANNELS = 32

    def __init__(self,
                 in_channels: int = 1,
                 num_classes: int = 4,
                 num_windows: int = 3,
                 num_electrodes: int = 22,
                 conv_pool_size: int = 7,
                 F1: int = 16,
                 D: int = 2,
                 tcn_kernel_size: int = 4,
                 tcn_depth: int = 2,
                 chunk_size: int = 1125):
        super(ATCNet, self).__init__()
        if num_windows < 1:
            raise ValueError(
                f'num_windows must be at least 1, got {num_windows}.')

        self.in_channels = in_channels
        self.num_classes = num_classes
        self.num_windows = num_windows
        self.num_electrodes = num_electrodes
        self.pool_size = conv_pool_size
        self.F1 = F1
        self.D = D
        self.tcn_kernel_size = tcn_kernel_size
        self.tcn_depth = tcn_depth
        self.chunk_size = chunk_size

        F2 = F1 * D
        # BatchNorm eps is left at its default on purpose: a positional
        # ``False`` after num_features is interpreted as ``eps=0`` and makes
        # the normalisation return NaN for constant inputs.
        self.conv_block = nn.Sequential(
            nn.Conv2d(in_channels, F1, (1, chunk_size // 2 + 1),
                      stride=1, padding='same', bias=False),
            nn.BatchNorm2d(F1),
            # Depth-wise spatial conv: collapses the electrode axis to size 1.
            nn.Conv2d(F1, F2, (num_electrodes, 1), padding=0, groups=F1),
            nn.BatchNorm2d(F2),
            nn.ELU(),
            nn.AvgPool2d((1, 8)),
            nn.Dropout2d(0.1),
            nn.Conv2d(F2, F2, (1, 16), bias=False, padding='same'),
            nn.BatchNorm2d(F2),
            nn.ELU(),
            nn.AvgPool2d((1, self.pool_size)),
            nn.Dropout2d(0.1))

        self._build_model()

    @staticmethod
    def _tcn_name(window: int, depth: int) -> str:
        '''Return the registered submodule name of a TCN layer.'''
        # NOTE(review): (window + 1) * depth is 0 for the first layer of every
        # window, so all windows share the single module 'tcn0' (and deeper
        # layers can collide when tcn_depth > 2). The scheme is kept as-is so
        # that state_dict keys stay stable; confirm whether per-window TCNs
        # were intended before changing it.
        return 'tcn' + str((window + 1) * depth)

    def _build_model(self):
        '''Infer the conv block output shape and register the per-window heads.'''
        # Probe the conv block in eval mode so that constructing the model
        # neither updates BatchNorm running statistics nor draws dropout masks.
        was_training = self.conv_block.training
        self.conv_block.eval()
        with torch.no_grad():
            probe = torch.zeros(1, self.in_channels, self.num_electrodes,
                                self.chunk_size)
            feat = self.conv_block(probe)  # (1, F2, 1, T)
        self.conv_block.train(was_training)

        self._embed_dim = feat.shape[1]
        self._chan_dim = feat.shape[3]
        self.win_len = self._chan_dim - self.num_windows + 1
        if self.win_len < 1:
            raise ValueError(
                f'The conv block yields {self._chan_dim} time steps, which is '
                f'too short for num_windows={self.num_windows}; increase '
                f'chunk_size or decrease conv_pool_size / num_windows.')
        # The 'same'-padded TCN convolutions keep the last axis unchanged, so
        # the dense layer always sees embed_dim features.
        self._dense_dim = self._embed_dim

        for i in range(self.num_windows):
            self._add_msa(i)
            self._add_msa_drop(i)
            width = self.win_len
            for j in range(self.tcn_depth):
                name = self._tcn_name(i, j)
                if name not in self._modules:
                    self._add_tcn(name, width)
                if width != self._TCN_CHANNELS:
                    # Residual branch needs a projection to match TCN width.
                    self._add_recov(i)
                    width = self._TCN_CHANNELS
            self._add_dense(i)

    def _add_msa(self, index: int):
        '''Register the multi-head self-attention layer of window ``index``.'''
        self.add_module(
            'msa' + str(index),
            nn.MultiheadAttention(embed_dim=self._embed_dim,
                                  num_heads=2,
                                  batch_first=True))

    def _add_msa_drop(self, index: int):
        '''Register the dropout applied to the attention output of window ``index``.'''
        self.add_module('msa_drop' + str(index), nn.Dropout(0.3))

    def _add_tcn(self, name: str, in_width: int):
        '''Register a two-layer Conv1d/BatchNorm/ELU/Dropout stack under ``name``.'''
        width = self._TCN_CHANNELS
        self.add_module(
            name,
            nn.Sequential(
                nn.Conv1d(in_width, width, self.tcn_kernel_size,
                          padding='same'),
                nn.BatchNorm1d(width),
                nn.ELU(),
                nn.Dropout(0.3),
                nn.Conv1d(width, width, self.tcn_kernel_size,
                          padding='same'),
                nn.BatchNorm1d(width),
                nn.ELU(),
                nn.Dropout(0.3)))

    def _add_recov(self, index: int):
        '''Register the projection that maps the residual input to the TCN width.'''
        self.add_module(
            're' + str(index),
            nn.Conv1d(self.win_len, self._TCN_CHANNELS, 4, padding='same'))

    def _add_dense(self, index: int):
        '''Register the classification layer of window ``index``.'''
        self.add_module('dense' + str(index),
                        nn.Linear(self._dense_dim, self.num_classes))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        r'''
        Args:
            x (torch.Tensor): EEG signal representation, the ideal input shape is :obj:`[n, 1, 22, 1125]`. Here, :obj:`n` corresponds to the batch size, :obj:`1` corresponds to :obj:`in_channels`, :obj:`22` corresponds to :obj:`num_electrodes`, and :obj:`1125` corresponds to :obj:`chunk_size`.

        Returns:
            torch.Tensor[size of batch, number of classes]: The predicted probability that the samples belong to the classes.
        '''
        x = self.conv_block(x)
        x = x[:, :, -1, :]  # drop the collapsed electrode axis
        x = x.permute(0, 2, 1)  # (n, time steps, features)

        logits_sum = None
        for i in range(self.num_windows):
            end = x.shape[1] - self.num_windows + i + 1
            window = x[:, i:end, :]

            # Self-attention with dropout, added back as a residual.
            attended = self.get_submodule('msa' + str(i))(window, window,
                                                          window)[0]
            attended = self.get_submodule('msa_drop' + str(i))(attended)
            window = torch.add(window, attended)

            # Conv1d treats dim 1 (window positions) as its channel axis.
            for j in range(self.tcn_depth):
                out = self.get_submodule(self._tcn_name(i, j))(window)
                if window.shape[1] != out.shape[1]:
                    window = self.get_submodule('re' + str(i))(window)
                window = F.elu(torch.add(window, out))

            logits = self.get_submodule('dense' + str(i))(window[:, -1, :])
            logits_sum = logits if logits_sum is None else logits_sum.add(
                logits)

        # Average the per-window logits, then convert to probabilities.
        x = logits_sum / self.num_windows
        return F.softmax(x, dim=1)
Read the Docs v: latest
On Read the Docs
Project Home

Free document hosting provided by Read the Docs.


Access comprehensive developer documentation for PyTorch

View Docs


Get in-depth tutorials for beginners and advanced developers

View Tutorials


Find development resources and get your questions answered

View Resources