DARNet
- class torcheeg.models.DARNet(num_electrodes: int = 62, chunk_size: int = 64, d_model: int = 16, num_heads: int = 8, attn_dropout: float = 0.1, num_classes: int = 2)
The DARNet model is based on the paper “DARNet: Dual Attention Refinement Network with Spatiotemporal Construction for Auditory Attention Detection”. For more details, please refer to the following information.
Paper: Yan S, Fan C, Zhang H, et al. DARNet: Dual attention refinement network with spatiotemporal construction for auditory attention detection[J]. Advances in Neural Information Processing Systems, 2024, 37: 31688-31707.
URL: https://openreview.net/forum?id=jWGGEDYORs&noteId=0A27gTqMH0
Related Project: https://github.com/fchest/DARNet.git
Below is a recommended suite for use in auditory attention detection tasks:
```python
from torcheeg.models import DARNet
from torcheeg.datasets import DTUDataset
from torcheeg import transforms
from torch.utils.data import DataLoader

dataset = DTUDataset(root_path='./DATA_preproc',
                     offline_transform=transforms.Compose([
                         transforms.MinMaxNormalize(axis=-1),
                         transforms.To2d()
                     ]),
                     online_transform=transforms.ToTensor(),
                     label_transform=transforms.Compose([
                         transforms.Select('attended_speaker'),
                         # Map attended_speaker labels to 0-based class indices
                         transforms.Lambda(lambda x: x - 1)
                     ]))

model = DARNet(num_electrodes=64,
               chunk_size=64,
               d_model=16,
               num_heads=8,
               attn_dropout=0.1,
               num_classes=2)

# Fetch one batch and run a forward pass
x, y = next(iter(DataLoader(dataset, batch_size=64)))
model(x)
```
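The two data-side steps in the suite above can be approximated in plain NumPy to see what they do to a single chunk. This is a sketch under stated assumptions, not torcheeg's implementation: it assumes `MinMaxNormalize(axis=-1)` rescales each electrode's time series to [0, 1] independently, and that the `attended_speaker` labels are 1-based, which is what the `x - 1` lambda implies. `min_max_normalize` and `to_class_index` are hypothetical helpers written for illustration.

```python
import numpy as np


def min_max_normalize(eeg: np.ndarray, axis: int = -1, eps: float = 1e-12) -> np.ndarray:
    """Rescale values along `axis` to [0, 1] (assumed behaviour of MinMaxNormalize)."""
    lo = eeg.min(axis=axis, keepdims=True)
    hi = eeg.max(axis=axis, keepdims=True)
    # eps guards against division by zero for a flat (constant) channel
    return (eeg - lo) / (hi - lo + eps)


def to_class_index(attended_speaker: int) -> int:
    """Shift an assumed 1-based speaker label to a 0-based class index."""
    return attended_speaker - 1


# Synthetic chunk: 64 electrodes x 64 time points
rng = np.random.default_rng(0)
chunk = rng.normal(size=(64, 64))
normed = min_max_normalize(chunk, axis=-1)

print(normed.shape)                             # → (64, 64)
print([to_class_index(s) for s in (1, 2)])      # → [0, 1]
```

Normalizing along the last axis means every electrode is scaled by its own minimum and maximum over the chunk, so channels with very different amplitudes end up on a comparable range.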
- Parameters:
num_electrodes (int) – The number of electrodes. (default: 62)
chunk_size (int) – The number of data points in each EEG chunk, i.e., the length of the temporal dimension of the input (equal to the sampling rate when each chunk spans one second). (default: 64)
d_model (int) – The embedding dimension of the model. (default: 16)
num_heads (int) – The number of attention heads. (default: 8)
attn_dropout (float) – The dropout rate for the attention layers. (default: 0.1)
num_classes (int) – The number of classes. (default: 2)
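Because the model expects a fixed number of time points per sample (the second 64 in the input shape corresponds to `chunk_size`), a continuous recording has to be cut into fixed-length windows before it reaches DARNet. The sketch below shows one way to do this; `split_into_chunks` is a hypothetical helper, not a torcheeg function, and the 64 Hz rate and 60 s duration are made-up numbers rather than properties of any particular dataset.

```python
from typing import Optional

import numpy as np


def split_into_chunks(eeg: np.ndarray, chunk_size: int, step: Optional[int] = None) -> np.ndarray:
    """Cut [num_electrodes, num_samples] into [num_chunks, num_electrodes, chunk_size].

    `step` defaults to `chunk_size` (non-overlapping windows); trailing samples
    that do not fill a whole chunk are dropped.
    """
    step = chunk_size if step is None else step
    _, num_samples = eeg.shape
    if num_samples < chunk_size:
        raise ValueError("recording is shorter than one chunk")
    starts = range(0, num_samples - chunk_size + 1, step)
    return np.stack([eeg[:, s:s + chunk_size] for s in starts])


# Hypothetical recording: 64 electrodes, 60 s at an assumed 64 Hz
sampling_rate = 64
recording = np.zeros((64, 60 * sampling_rate))

batch = split_into_chunks(recording, chunk_size=64)
print(batch.shape)                    # → (60, 64, 64), i.e. [n, num_electrodes, chunk_size]
print(64 / sampling_rate)             # → 1.0 (window length in seconds)
```

Passing `step=32` instead yields 50% overlapping windows, which roughly doubles the number of samples drawn from the same recording at the cost of correlated neighbours.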
- forward(x: Tensor) → Tensor
- Parameters:
x (torch.Tensor) – EEG signal representation. The ideal input shape is [n, 64, 64]. Here, n corresponds to the batch size, the first 64 corresponds to num_electrodes, and the second 64 corresponds to chunk_size.
- Returns:
The predicted probability that the samples belong to the classes.
- Return type:
torch.Tensor[size of batch, number of classes]
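The output holds one score per class for every sample in the batch. The NumPy sketch below shows how a [batch, num_classes] array of this kind could be turned into predicted class indices and a batch accuracy. It does not assume whether the raw outputs are already normalized: a softmax is applied anyway, and because softmax is monotonic it leaves the argmax unchanged either way. The `scores` and `labels` arrays are made-up values for illustration, not DARNet outputs.

```python
import numpy as np


def softmax(scores: np.ndarray, axis: int = -1) -> np.ndarray:
    """Numerically stable softmax along `axis`."""
    shifted = scores - scores.max(axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)


# Made-up output for a batch of 4 samples with num_classes=2
scores = np.array([[2.0, -1.0],
                   [0.1, 0.3],
                   [-0.5, 1.5],
                   [0.4, -0.2]])
labels = np.array([0, 1, 0, 0])  # made-up 0-based ground truth

probs = softmax(scores)           # each row now sums to 1
preds = probs.argmax(axis=-1)     # predicted class index per sample
accuracy = (preds == labels).mean()
print(preds, accuracy)            # → [0 1 1 0] 0.75
```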