LGGNet
- class torcheeg.models.LGGNet(region_list, in_channels: int = 1, num_electrodes: int = 32, chunk_size: int = 128, sampling_rate: int = 128, num_T: int = 64, hid_channels: int = 32, dropout: float = 0.5, pool_kernel_size: int = 16, pool_stride: int = 4, num_classes: int = 2)
Local-Global-Graph Networks (LGGNet). For more details, please refer to the following information.
Paper: Ding Y, Robinson N, Zeng Q, et al. LGGNet: learning from Local-global-graph representations for brain-computer interface[J]. arXiv preprint arXiv:2105.02786, 2021.
Related Project: https://github.com/yi-ding-cs/LGG
Below is a recommended suite for use in emotion recognition tasks:
from torch.utils.data import DataLoader
from torcheeg.datasets import SEEDDataset
from torcheeg.models import LGGNet
from torcheeg import transforms
from torcheeg.datasets.constants.emotion_recognition.seed import SEED_GENERAL_REGION_LIST

dataset = SEEDDataset(root_path='./Preprocessed_EEG',
                      offline_transform=transforms.Compose([
                          transforms.MeanStdNormalize(),
                          transforms.To2d()  # adds the in_channels axis: [1, electrodes, samples]
                      ]),
                      online_transform=transforms.Compose([
                          transforms.ToTensor()
                      ]),
                      label_transform=transforms.Compose([
                          transforms.Select('emotion'),
                          transforms.Lambda(lambda x: x + 1)  # SEED labels -1/0/1 -> 0/1/2
                      ]))

# SEED: 62 electrodes, 200 Hz preprocessed signals, 3 emotion classes.
# chunk_size assumes SEEDDataset's default 200-sample (1 s) chunks.
model = LGGNet(region_list=SEED_GENERAL_REGION_LIST,
               chunk_size=200,
               sampling_rate=200,
               num_electrodes=62,
               hid_channels=32,
               num_classes=3)

x, y = next(iter(DataLoader(dataset, batch_size=64)))
model(x)
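The label_transform in the suite above adds 1 to each SEED emotion label. Assuming SEED's raw labels are -1 (negative), 0 (neutral) and 1 (positive), the shift produces the zero-based class indices 0, 1 and 2 that losses such as torch.nn.CrossEntropyLoss expect, giving three classes in total. A minimal stdlib sketch of the mapping; shift_labels is a hypothetical helper, not a torcheeg function:

```python
from typing import List


# Hypothetical helper mirroring transforms.Lambda(lambda x: x + 1) in the suite.
# Assumption: SEED emotion labels are -1 (negative), 0 (neutral), 1 (positive).
def shift_labels(labels: List[int]) -> List[int]:
    """Map SEED's {-1, 0, 1} emotion labels to zero-based indices {0, 1, 2}."""
    return [label + 1 for label in labels]


raw = [-1, 0, 1, 1, -1, 0]
shifted = shift_labels(raw)
num_classes = len(set(shifted))  # number of distinct classes after shifting
print(shifted)      # [0, 1, 2, 2, 0, 1]
print(num_classes)  # 3
```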
The current built-in region_list includes:
torcheeg.datasets.constants.emotion_recognition.amigos.AMIGOS_GENERAL_REGION_LIST
torcheeg.datasets.constants.emotion_recognition.amigos.AMIGOS_FRONTAL_REGION_LIST
torcheeg.datasets.constants.emotion_recognition.amigos.AMIGOS_HEMISPHERE_REGION_LIST
torcheeg.datasets.constants.emotion_recognition.deap.DEAP_GENERAL_REGION_LIST
…
torcheeg.datasets.constants.emotion_recognition.dreamer.DREAMER_GENERAL_REGION_LIST
…
torcheeg.datasets.constants.emotion_recognition.mahnob.MAHNOB_GENERAL_REGION_LIST
…
torcheeg.datasets.constants.emotion_recognition.seed.SEED_GENERAL_REGION_LIST
…
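Each built-in region_list groups electrodes into brain regions, and the local graph stage roughly summarizes every region as one node. The stdlib sketch below assumes a region list is a list of lists of electrode indices; the 6-electrode layout TOY_REGION_LIST and the helper region_means are hypothetical. It reduces each region to the mean of its electrodes' feature vectors. Mean pooling is an assumption of this sketch, and the learnable local filter that LGGNet applies before aggregation is omitted.

```python
from typing import List, Sequence

# Hypothetical toy layout: 6 electrodes split into 3 brain regions.
# Each inner list holds electrode indices (rows of the EEG chunk).
TOY_REGION_LIST: List[List[int]] = [[0, 1], [2, 3, 4], [5]]


def region_means(features: Sequence[Sequence[float]],
                 region_list: List[List[int]]) -> List[List[float]]:
    """Average per-electrode feature vectors within each region.

    features[e] is the feature vector of electrode e; the result holds one
    vector per region (one node per region).
    """
    nodes = []
    for region in region_list:
        if not region:
            raise ValueError("each region must contain at least one electrode")
        dim = len(features[region[0]])
        summed = [0.0] * dim
        for electrode in region:
            for d, value in enumerate(features[electrode]):
                summed[d] += value
        nodes.append([s / len(region) for s in summed])
    return nodes


# 6 electrodes with 2 features each
feats = [[1.0, 2.0], [3.0, 4.0],
         [0.0, 0.0], [3.0, 3.0], [6.0, 9.0],
         [5.0, 5.0]]
print(region_means(feats, TOY_REGION_LIST))
# [[2.0, 3.0], [3.0, 4.0], [5.0, 5.0]]
```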
- Parameters:
region_list (list) – The local graph structure defined according to the 10-20 system, where the electrodes are divided into different brain regions.
in_channels (int) – The feature dimension of each electrode. (default: 1)
num_electrodes (int) – The number of electrodes. (default: 32)
chunk_size (int) – Number of data points included in each EEG chunk. (default: 128)
sampling_rate (int) – The sampling rate of the EEG signals, i.e., \(f_s\) in the paper. (default: 128)
num_T (int) – The number of multi-scale 1D temporal kernels in the dynamic temporal layer, i.e., \(T\) kernels in the paper. (default: 64)
hid_channels (int) – The number of hidden nodes in the first fully connected layer. (default: 32)
dropout (float) – Probability of an element to be zeroed in the dropout layers. (default: 0.5)
pool_kernel_size (int) – The kernel size of the pooling layers in the temporal blocks. (default: 16)
pool_stride (int) – The stride of the pooling layers in the temporal blocks. (default: 4)
num_classes (int) – The number of classes to predict. (default: 2)
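The temporal kernel lengths depend on sampling_rate, and the time steps that survive pooling depend on chunk_size, pool_kernel_size and pool_stride. The sketch below estimates both under stated assumptions: window lengths of 0.5, 0.25 and 0.125 s for the three kernel scales (an assumption, not read from the torcheeg source), a stride-1 convolution without padding, and pooling that uses pool_kernel_size and pool_stride. It applies the standard output-length formula floor((L - k) / s) + 1. The names ASSUMED_WINDOWS_S, window_output_length and temporal_branch_shapes are hypothetical.

```python
from typing import List, Tuple

# Assumed window lengths (seconds) of the multi-scale temporal kernels.
ASSUMED_WINDOWS_S: Tuple[float, ...] = (0.5, 0.25, 0.125)


def window_output_length(length: int, kernel: int, stride: int) -> int:
    """Output length of an unpadded 1D convolution or pooling window."""
    if kernel > length:
        raise ValueError(f"kernel {kernel} longer than input {length}")
    return (length - kernel) // stride + 1


def temporal_branch_shapes(chunk_size: int = 128,
                           sampling_rate: int = 128,
                           pool_kernel_size: int = 16,
                           pool_stride: int = 4) -> List[Tuple[int, int]]:
    """Return (kernel_length, time_steps_after_pooling) for each branch."""
    shapes = []
    for window in ASSUMED_WINDOWS_S:
        kernel = int(window * sampling_rate)                    # kernel length in samples
        conv_len = window_output_length(chunk_size, kernel, 1)  # stride 1, no padding
        pooled = window_output_length(conv_len, pool_kernel_size, pool_stride)
        shapes.append((kernel, pooled))
    return shapes


print(temporal_branch_shapes())  # defaults from the signature above
# [(64, 13), (32, 21), (16, 25)]
```

Under these assumptions a chunk shorter than the longest kernel cannot be processed, and the sketch raises ValueError in that case.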
- forward(x)
- Parameters:
x (torch.Tensor) – EEG signal representation; the ideal input shape is [n, 1, 32, 128]. Here, n corresponds to the batch size, 1 corresponds to in_channels, 32 corresponds to num_electrodes, and 128 corresponds to chunk_size.
- Returns:
The predicted probability that the samples belong to the classes.
- Return type:
torch.Tensor[number of samples, number of classes]
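The returned tensor has one row per sample and one column per class. A minimal stdlib sketch of turning such rows into predicted labels: a softmax (needed only if a row holds unnormalized scores rather than probabilities) followed by an argmax. Because softmax preserves the order of the scores, the predicted class is the same whether or not it is applied. The helpers softmax and predict_labels and the example scores are hypothetical.

```python
import math
from typing import List, Sequence


def softmax(scores: Sequence[float]) -> List[float]:
    """Numerically stable softmax over one row of class scores."""
    m = max(scores)                       # subtract the max to avoid overflow
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def predict_labels(output: Sequence[Sequence[float]]) -> List[int]:
    """Argmax per row of a [number of samples, number of classes] output."""
    return [max(range(len(row)), key=lambda i: row[i]) for row in output]


# Hypothetical output for 2 samples and 3 classes
output = [[2.0, 0.5, -1.0],
          [0.1, 0.2, 3.0]]
probs = [softmax(row) for row in output]
print(predict_labels(output))  # [0, 2]
print(predict_labels(probs))   # [0, 2] (softmax keeps the ordering)
```

In PyTorch the same two steps are torch.softmax(out, dim=1) and out.argmax(dim=1) on the returned tensor.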