MHANet
- class torcheeg.models.MHANet(num_electrodes: int = 64, chunk_size: int = 64, num_heads: int = 16, bias: bool = False, num_classes: int = 2)
The MHANet model is based on the paper “MHANet: Multi-scale Hybrid Attention Network for Auditory Attention Detection”. For more details, see the paper and the project repository below.
Paper: Li L, Fan C, Zhang H, et al. MHANet: Multi-scale Hybrid Attention Network for Auditory Attention Detection. International Joint Conference on Artificial Intelligence (IJCAI), 2025.
Related Project: https://github.com/fchest/MHANet
Below is a recommended suite for use in auditory attention detection tasks:
```python
from torch.utils.data import DataLoader

from torcheeg import transforms
from torcheeg.datasets import DTUDataset
from torcheeg.models import MHANet

dataset = DTUDataset(root_path='./DATA_preproc',
                     # Applied once during preprocessing and cached:
                     # min-max scaling along the last (time) axis, then To2d.
                     offline_transform=transforms.Compose([
                         transforms.MinMaxNormalize(axis=-1),
                         transforms.To2d()
                     ]),
                     online_transform=transforms.ToTensor(),
                     # Select the attended speaker and shift it to a 0-based class index.
                     label_transform=transforms.Compose([
                         transforms.Select('attended_speaker'),
                         transforms.Lambda(lambda x: x - 1)
                     ]))

model = MHANet(num_electrodes=64, chunk_size=64, num_heads=16, bias=False, num_classes=2)

# Forward one batch of 64 samples; the output has shape [64, num_classes].
x, y = next(iter(DataLoader(dataset, batch_size=64)))
model(x)
```
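To make the preprocessing in this suite concrete, here is a minimal NumPy sketch of what the two custom steps do to a single sample. It assumes the textbook min-max formula (x - min) / (max - min) and that `attended_speaker` is 1-based, which the `x - 1` lambda suggests. `min_max_normalize` and `to_class_index` are hypothetical helpers, not torcheeg APIs, and torcheeg's own `MinMaxNormalize` may handle edge cases such as constant channels differently.

```python
import numpy as np


def min_max_normalize(eeg: np.ndarray, axis: int = -1) -> np.ndarray:
    """Hypothetical helper: rescale each slice along `axis` to [0, 1]."""
    x_min = eeg.min(axis=axis, keepdims=True)
    x_max = eeg.max(axis=axis, keepdims=True)
    # Guard against constant channels, where max == min would divide by zero.
    span = np.where(x_max - x_min == 0, 1.0, x_max - x_min)
    return (eeg - x_min) / span


def to_class_index(attended_speaker: int) -> int:
    """Hypothetical mirror of Lambda(lambda x: x - 1): labels 1/2 become 0/1."""
    return attended_speaker - 1


# One synthetic sample: 64 electrodes x 64 time points.
rng = np.random.default_rng(0)
sample = rng.normal(size=(64, 64))
normalized = min_max_normalize(sample, axis=-1)
print(normalized.shape)                        # (64, 64)
print(to_class_index(1), to_class_index(2))    # 0 1
```

With `axis=-1`, every electrode's time series is scaled independently, so each row of the result spans exactly [0, 1].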
- Parameters:
num_electrodes (int) – The number of electrodes. (default: 64)
chunk_size (int) – The number of time samples in each EEG chunk, i.e., the last dimension of the input; with 1-second windows this equals the sampling rate. (default: 64)
num_heads (int) – The number of attention heads. (default: 16)
bias (bool) – Whether to use bias in convolution layers. (default: False)
num_classes (int) – The number of classes. (default: 2)
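Because `chunk_size` is the number of time samples per window (the last input dimension, per the `forward` description), it equals the sampling rate multiplied by the window length in seconds. The defaults line up with 1-second windows only if the data are sampled at 64 Hz, which is an assumption here. The sketch below uses a hypothetical `segment` helper, not a torcheeg function, to show how a continuous recording maps onto the [n, num_electrodes, chunk_size] layout that `forward` expects.

```python
import numpy as np


def segment(recording: np.ndarray, chunk_size: int) -> np.ndarray:
    """Hypothetical helper: split [num_electrodes, num_samples] into
    non-overlapping windows of shape [n, num_electrodes, chunk_size]."""
    num_electrodes, num_samples = recording.shape
    n = num_samples // chunk_size  # trailing samples that do not fill a window are dropped
    trimmed = recording[:, : n * chunk_size]
    # [E, n*T] -> [E, n, T] -> [n, E, T]
    return trimmed.reshape(num_electrodes, n, chunk_size).transpose(1, 0, 2)


sampling_rate = 64       # Hz, assumed
window_seconds = 1.0     # assumed window length
chunk_size = int(sampling_rate * window_seconds)  # 64 samples per chunk

# 10 s of 64-channel EEG plus 5 stray samples that do not fill a window.
recording = np.zeros((64, 10 * sampling_rate + 5))
chunks = segment(recording, chunk_size)
print(chunks.shape)  # (10, 64, 64)
```

Halving the window to 0.5 s at the same assumed rate would give `chunk_size=32`, and the model would then need to be constructed with that value.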
- forward(x: Tensor) → Tensor
- Parameters:
x (torch.Tensor) – EEG signal representation. The ideal input shape is [n, 64, 64]. Here, n corresponds to the batch size, the first 64 corresponds to num_electrodes, and the second 64 corresponds to chunk_size.
- Returns:
The predicted probability that the samples belong to the classes.
- Return type:
torch.Tensor[size of batch, number of classes]
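To turn the [batch, num_classes] output into speaker decisions, take the highest-scoring class in each row. This is a minimal NumPy sketch with a hypothetical `decode_predictions` helper. It assumes the 1-based `attended_speaker` labels implied by the `x - 1` label transform in the suite above. Because softmax is monotonic, the argmax is the same whether the model returns probabilities or raw scores. For a torch output, call `.detach().cpu().numpy()` on it first.

```python
import numpy as np


def decode_predictions(scores) -> np.ndarray:
    """Hypothetical helper: map [batch, num_classes] output to 1-based speaker labels."""
    scores = np.asarray(scores)
    class_index = scores.argmax(axis=1)  # highest-scoring class per sample
    return class_index + 1               # undo the x - 1 label shift from the suite


# Stand-in for model(x).detach().cpu().numpy() with batch size 3 and 2 classes.
scores = np.array([[0.9, 0.1],
                   [0.2, 0.8],
                   [0.6, 0.4]])
print(decode_predictions(scores))  # [1 2 1]
```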