BDiscriminator
- class torcheeg.models.BDiscriminator(in_channels: int = 4, grid_size: Tuple[int, int] = (9, 9), hid_channels: int = 64)
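TorchEEG provides an EEG feature discriminator based on a CNN architecture. It is the discriminator half of a GAN that generates EEG grid representations of different frequency bands, and it predicts whether a given grid representation is a real sample or a generated (fake) one.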
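    import torch
    from torcheeg.models import BGenerator, BDiscriminator

    g_model = BGenerator(in_channels=128)
    d_model = BDiscriminator(in_channels=4)
    # Sample one 128-dimensional noise vector (batch size 1)
    z = torch.normal(mean=0, std=1, size=(1, 128))
    fake_X = g_model(z)       # generated EEG grid representation
    disc_X = d_model(fake_X)  # real/fake prediction, shape [1, 1]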
- Parameters:
in_channels (int) – The feature dimension of each electrode. (default: 4)
grid_size (tuple) – Spatial dimensions of the grid-like EEG representation. (default: (9, 9))
hid_channels (int) – The number of hidden nodes in the first fully connected layer. (default: 64)
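Because in_channels and grid_size fix the layout the discriminator expects, it can help to validate input shapes before calling the model. The following is a minimal sketch, assuming the model expects exactly [n, in_channels, *grid_size] as documented for forward; check_discriminator_input is a hypothetical helper, not part of TorchEEG. It accepts any sequence of ints, so a torch.Size from x.shape works as well.

```python
from typing import Sequence, Tuple


def check_discriminator_input(shape: Sequence[int],
                              in_channels: int = 4,
                              grid_size: Tuple[int, int] = (9, 9)) -> Tuple[int, int]:
    """Hypothetical helper (not part of TorchEEG).

    Checks that `shape` follows the [n, in_channels, H, W] layout with
    (H, W) == grid_size, and returns the documented output shape (n, 1).
    """
    if len(shape) != 4:
        raise ValueError(f"expected a 4-D input [n, C, H, W], got {len(shape)}-D")
    n, c, h, w = shape
    if c != in_channels:
        raise ValueError(f"expected {in_channels} channels (in_channels), got {c}")
    if (h, w) != tuple(grid_size):
        raise ValueError(f"expected grid {tuple(grid_size)} (grid_size), got {(h, w)}")
    # One real/fake prediction per sample in the batch.
    return (n, 1)


# Default configuration: a batch of 8 grids, 4 features per electrode, 9x9 layout.
print(check_discriminator_input((8, 4, 9, 9)))

# A model built with in_channels=5 and grid_size=(11, 11) would expect that layout instead.
print(check_discriminator_input((2, 5, 11, 11), in_channels=5, grid_size=(11, 11)))
```

Keeping the defaults in the helper identical to the constructor defaults means the two only need to be changed together.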
- forward(x: Tensor)
- Parameters:
x (torch.Tensor) – EEG signal representation. The ideal input shape is [n, 4, 9, 9], where n corresponds to the batch size, 4 corresponds to in_channels, and (9, 9) corresponds to grid_size.
- Returns:
A prediction, for each sample, of whether it is a fake (generated) sample. Here, n corresponds to the batch size.
- Return type:
torch.Tensor[n, 1]
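The [n, 1] output is typically consumed by a binary cross-entropy GAN objective. The sketch below illustrates this in pure Python under two assumptions not stated in this reference: that the scores are raw logits (no sigmoid applied inside the model), and that real samples are labelled 1 and generated samples 0. The function names bce_with_logits, discriminator_loss and generator_loss are hypothetical.

```python
import math
from typing import Sequence


def bce_with_logits(logit: float, target: float) -> float:
    """Numerically stable binary cross-entropy on a raw (pre-sigmoid) score.

    max(z, 0) - z * y + log(1 + exp(-|z|)) equals
    -[y * log(s) + (1 - y) * log(1 - s)] with s = sigmoid(z).
    """
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))


def discriminator_loss(real_scores: Sequence[float], fake_scores: Sequence[float]) -> float:
    """Assumed labelling: real samples -> 1, generated samples -> 0."""
    real = sum(bce_with_logits(z, 1.0) for z in real_scores) / len(real_scores)
    fake = sum(bce_with_logits(z, 0.0) for z in fake_scores) / len(fake_scores)
    return real + fake


def generator_loss(fake_scores: Sequence[float]) -> float:
    """Non-saturating generator loss: push generated samples towards label 1."""
    return sum(bce_with_logits(z, 1.0) for z in fake_scores) / len(fake_scores)


# The [n, 1] outputs flattened to n scores (illustrative numbers, not model output).
real_scores = [2.0, 1.5]
fake_scores = [-1.0, -2.5]
print(discriminator_loss(real_scores, fake_scores), generator_loss(fake_scores))
```

In PyTorch the same quantity is usually computed with torch.nn.functional.binary_cross_entropy_with_logits on the [n, 1] output against torch.ones_like or torch.zeros_like targets; if the model's output turns out to be a probability rather than a logit, torch.nn.functional.binary_cross_entropy is the matching choice.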