BDiscriminator
- class torcheeg.models.BDiscriminator(in_channels: int = 4, grid_size: Tuple[int, int] = (9, 9), hid_channels: int = 64)
TorchEEG provides an EEG feature discriminator based on a CNN architecture and GAN. It distinguishes real EEG grid representations of different frequency bands from fake ones, such as those produced by BGenerator.
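
    import torch
    from torcheeg.models import BGenerator, BDiscriminator

    g_model = BGenerator(in_channels=128)
    d_model = BDiscriminator(in_channels=4)
    # Sample one noise vector and generate a fake EEG grid from it.
    z = torch.normal(mean=0, std=1, size=(1, 128))
    fake_X = g_model(z)
    disc_X = d_model(fake_X)  # discriminator score, shape [1, 1]
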
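In adversarial training, the discriminator's score is turned into a loss that teaches it to separate real grids from generated ones. The sketch below shows one discriminator update. So that it runs without TorchEEG, it swaps in tiny hypothetical stand-ins for BGenerator and BDiscriminator (plain `nn.Sequential` models with the same shapes as the example above: noise [n, 128] to grid [n, 4, 9, 9] to score [n, 1]). It also assumes the discriminator output is an unnormalized logit and uses the standard binary cross-entropy GAN objective; TorchEEG's own training setup may use a different loss.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical stand-ins (NOT TorchEEG models) with matching shapes:
# noise [n, 128] -> grid [n, 4, 9, 9] -> score [n, 1].
g_model = nn.Sequential(nn.Linear(128, 4 * 9 * 9), nn.Unflatten(1, (4, 9, 9)))
d_model = nn.Sequential(nn.Flatten(), nn.Linear(4 * 9 * 9, 1))
d_opt = torch.optim.Adam(d_model.parameters(), lr=1e-4)


def discriminator_step(real_X: torch.Tensor) -> float:
    """One discriminator update with the BCE GAN loss, assuming logit outputs."""
    n = real_X.shape[0]
    z = torch.normal(mean=0, std=1, size=(n, 128))
    fake_X = g_model(z).detach()  # detach: this step must not update the generator
    real_score = d_model(real_X)  # [n, 1]
    fake_score = d_model(fake_X)  # [n, 1]
    # Convention assumed here: real samples are labeled 1, generated samples 0.
    loss = (F.binary_cross_entropy_with_logits(real_score, torch.ones_like(real_score))
            + F.binary_cross_entropy_with_logits(fake_score, torch.zeros_like(fake_score)))
    d_opt.zero_grad()
    loss.backward()
    d_opt.step()
    return loss.item()


real_X = torch.randn(8, 4, 9, 9)  # placeholder for a batch of real EEG grids
print(discriminator_step(real_X))
```

With TorchEEG installed, the stand-ins would be replaced by `BGenerator(in_channels=128)` and `BDiscriminator(in_channels=4)` as in the example above; a full GAN loop would alternate this step with a generator update.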
- Parameters:
in_channels (int) – The feature dimension of each electrode. (default: 4)
grid_size (tuple) – Spatial dimensions of the grid-like EEG representation. (default: (9, 9))
hid_channels (int) – The number of hidden nodes in the first fully connected layer. (default: 64)
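To show how the three constructor arguments can fit together in a CNN discriminator, here is a minimal sketch. `SketchDiscriminator` and its extra `conv_channels` argument are hypothetical, and the layer layout (two convolutions, then a fully connected layer of `hid_channels` nodes, then a single-output head) is an assumption for illustration, not TorchEEG's actual BDiscriminator architecture.

```python
from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


class SketchDiscriminator(nn.Module):
    """Hypothetical discriminator using BDiscriminator-style arguments."""

    def __init__(self, in_channels: int = 4, grid_size: Tuple[int, int] = (9, 9),
                 hid_channels: int = 64, conv_channels: int = 32):
        super().__init__()
        # Convolutions over the electrode grid; in_channels is the feature
        # dimension of each electrode (e.g. one value per frequency band).
        self.features = nn.Sequential(
            nn.Conv2d(in_channels, conv_channels, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(conv_channels, conv_channels, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU(0.2),
        )
        # Infer the flattened feature size from a dummy input, so any grid_size works.
        with torch.no_grad():
            n_flat = self.features(torch.zeros(1, in_channels, *grid_size)).numel()
        self.fc1 = nn.Linear(n_flat, hid_channels)  # first fully connected layer
        self.fc2 = nn.Linear(hid_channels, 1)       # one score per sample

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [n, in_channels, H, W] -> [n, 1]
        h = self.features(x).flatten(start_dim=1)
        return self.fc2(F.leaky_relu(self.fc1(h), 0.2))


d = SketchDiscriminator(in_channels=4, grid_size=(9, 9), hid_channels=64)
print(d(torch.randn(2, 4, 9, 9)).shape)
```

Inferring the flattened size with a dummy forward pass keeps `grid_size` a free parameter instead of hard-coding the 9x9 layout into the linear layer.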
- forward(x: Tensor)
- Parameters:
x (torch.Tensor) – EEG signal representation. The ideal input shape is [n, 4, 9, 9], where n corresponds to the batch size, 4 corresponds to in_channels, and (9, 9) corresponds to grid_size.
- Returns:
A prediction, for each sample in the batch, of whether it is fake. Here, n corresponds to the batch size.
- Return type:
torch.Tensor[n, 1]
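The shape contract above can be checked before calling forward, and the [n, 1] output reduced to one value per sample. The helpers below, `check_disc_input` and `scores_to_probs`, are hypothetical and not part of TorchEEG. `scores_to_probs` assumes the scores are unnormalized logits and that real samples were labeled 1 during training; a model trained with another convention or loss needs a different mapping.

```python
from typing import Tuple

import torch


def check_disc_input(x: torch.Tensor, in_channels: int = 4,
                     grid_size: Tuple[int, int] = (9, 9)) -> int:
    """Hypothetical helper: check that x is [n, in_channels, *grid_size]; return n."""
    expected = (in_channels, *grid_size)
    if x.dim() != 4 or tuple(x.shape[1:]) != expected:
        raise ValueError(f"expected shape [n, {in_channels}, {grid_size[0]}, "
                         f"{grid_size[1]}], got {list(x.shape)}")
    return x.shape[0]


def scores_to_probs(scores: torch.Tensor, n: int) -> torch.Tensor:
    """Hypothetical helper: map [n, 1] scores to [n] probabilities of being real,
    assuming logit scores and the label convention real = 1."""
    if tuple(scores.shape) != (n, 1):
        raise ValueError(f"expected scores of shape [{n}, 1], got {list(scores.shape)}")
    return torch.sigmoid(scores).squeeze(1)


x = torch.randn(3, 4, 9, 9)
n = check_disc_input(x)
scores = torch.tensor([[2.0], [0.0], [-2.0]])  # stand-in for d_model(x)
print(scores_to_probs(scores, n))
```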