DARNet

class torcheeg.models.DARNet(num_electrodes: int = 62, chunk_size: int = 64, d_model: int = 16, num_heads: int = 8, attn_dropout: float = 0.1, num_classes: int = 2)

The DARNet model is based on the paper “DARNet: Dual Attention Refinement Network with Spatiotemporal Construction for Auditory Attention Detection”.

The following example shows a recommended setup for auditory attention detection tasks:

from torcheeg.models import DARNet
from torcheeg.datasets import DTUDataset
from torcheeg import transforms
from torch.utils.data import DataLoader

dataset = DTUDataset(root_path='./DATA_preproc',
                      offline_transform=transforms.Compose([
                          transforms.MinMaxNormalize(axis=-1),
                          transforms.To2d()
                      ]),
                      online_transform=transforms.ToTensor(),
                      label_transform=transforms.Compose([
                          transforms.Select('attended_speaker'),
                          transforms.Lambda(lambda x: x - 1)
                      ]))

model = DARNet(num_electrodes=64,
               chunk_size=64,
               d_model=16,
               num_heads=8,
               attn_dropout=0.1,
               num_classes=2)

x, y = next(iter(DataLoader(dataset, batch_size=64)))
model(x)
Parameters:
  • num_electrodes (int) – The number of electrodes. (default: 62)

  • chunk_size (int) – The number of sample points in each EEG chunk, i.e. the length of the time dimension of the input. (default: 64)

  • d_model (int) – The embedding dimension of the model. (default: 16)

  • num_heads (int) – The number of attention heads. (default: 8)

  • attn_dropout (float) – The dropout rate for attention layers. (default: 0.1)

  • num_classes (int) – The number of classes. (default: 2)
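
As the forward input shape [n, num_electrodes, chunk_size] indicates, chunk_size counts the samples in one chunk. The sketch below shows one way to derive it, assuming a fixed sampling rate and window length. `samples_per_chunk` is a hypothetical helper and is not part of torcheeg.

```python
def samples_per_chunk(sampling_rate_hz: float, window_seconds: float) -> int:
    """Return the number of samples in one window (hypothetical helper)."""
    if sampling_rate_hz <= 0 or window_seconds <= 0:
        raise ValueError("sampling rate and window length must be positive")
    return int(round(sampling_rate_hz * window_seconds))


# A 1-second window at 64 Hz gives 64 samples, matching the default chunk_size.
chunk_size = samples_per_chunk(sampling_rate_hz=64, window_seconds=1.0)
print(chunk_size)  # 64
```

With a 1-second window the sample count equals the sampling rate in Hz; for any other window length the two values differ, so chunk_size should follow the window actually used.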

forward(x: Tensor) → Tensor
Parameters:

x (torch.Tensor) – EEG signal representation, the ideal input shape is [n, 64, 64]. Here, n corresponds to the batch size, the first 64 corresponds to num_electrodes, and the second 64 corresponds to chunk_size.

Returns:

The predicted probability that the samples belong to the classes.

Return type:

torch.Tensor[size of batch, number of classes]
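
The shape contract described above can be sketched as a small pre-flight check. `check_darnet_input` is a hypothetical helper, not a torcheeg API; it works on a plain shape tuple (for example `tuple(x.shape)`), so it runs without PyTorch installed.

```python
from typing import Sequence, Tuple


def check_darnet_input(shape: Sequence[int],
                       num_electrodes: int = 64,
                       chunk_size: int = 64,
                       num_classes: int = 2) -> Tuple[int, int]:
    """Validate an input shape and return the documented output shape.

    Hypothetical helper: it only mirrors the documented contract
    [n, num_electrodes, chunk_size] -> [n, num_classes].
    """
    if len(shape) != 3:
        raise ValueError(
            f"expected 3 dimensions [n, electrodes, samples], got {len(shape)}")
    n, electrodes, samples = shape
    if electrodes != num_electrodes:
        raise ValueError(
            f"expected {num_electrodes} electrodes, got {electrodes}")
    if samples != chunk_size:
        raise ValueError(
            f"expected {chunk_size} samples per chunk, got {samples}")
    return (n, num_classes)


# A batch of 64 chunks, each with 64 electrodes and 64 samples.
print(check_darnet_input((64, 64, 64)))  # (64, 2)
```

Running a check like this before the forward pass turns a mismatch between the dataset and the constructor arguments, such as a 62-electrode default against 64-channel data, into an explicit error message.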
