BCDiscriminator
- class torcheeg.models.BCDiscriminator(in_channels: int = 4, grid_size: Tuple[int, int] = (9, 9), hid_channels: int = 64, num_classes: int = 3)
GAN-based methods formulate a zero-sum game between the generator and the discriminator. The generator is optimized to produce simulated samples that the discriminator cannot distinguish from real ones, and the discriminator is optimized to identify the fake samples produced by the generator. This class provides a baseline implementation of the discriminator. In addition to the sample, the discriminator receives the expected class label, which guides it to judge whether the sample fits the data distribution of that class.
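import torch
from torcheeg.models import BCGenerator, BCDiscriminator

g_model = BCGenerator(in_channels=128, num_classes=3)
d_model = BCDiscriminator(in_channels=4, num_classes=3)
# Sample one latent vector and one class label
z = torch.normal(mean=0, std=1, size=(1, 128))
y = torch.randint(low=0, high=3, size=(1,))
# Generate a fake sample for label y, then score it with the discriminator
fake_X = g_model(z, y)
disc_X = d_model(fake_X, y)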
- Parameters:
in_channels (int) – The feature dimension of each electrode. (default: 4)
grid_size (tuple) – Spatial dimensions of the grid-like EEG representation. (default: (9, 9))
hid_channels (int) – The number of hidden nodes in the first fully connected layer. (default: 64)
num_classes (int) – The number of classes. (default: 3)
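The parameters above could map onto a label-conditioned discriminator roughly as follows. This is a minimal sketch in plain PyTorch and not torcheeg's actual implementation: the class name SketchConditionalDiscriminator, the layer layout, and the choice to embed each label into a sample-shaped tensor and concatenate it along the channel axis are all assumptions made for illustration.

```python
import torch
import torch.nn as nn


class SketchConditionalDiscriminator(nn.Module):
    """Hypothetical stand-in, not torcheeg's BCDiscriminator."""

    def __init__(self, in_channels=4, grid_size=(9, 9), hid_channels=64, num_classes=3):
        super().__init__()
        self.in_channels = in_channels
        self.grid_size = grid_size
        # Assumption: each class label is embedded into a tensor shaped like one sample
        self.label_embedding = nn.Embedding(
            num_classes, in_channels * grid_size[0] * grid_size[1])
        # The input has 2 * in_channels channels: the sample plus its label map
        self.features = nn.Sequential(
            nn.Conv2d(in_channels * 2, hid_channels, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2),
        )
        self.head = nn.Linear(hid_channels * grid_size[0] * grid_size[1], 1)

    def forward(self, x, y):
        # x: [n, in_channels, H, W]; y: [n] integer class labels
        label_map = self.label_embedding(y).view(-1, self.in_channels, *self.grid_size)
        h = torch.cat([x, label_map], dim=1)
        h = self.features(h)
        return self.head(h.flatten(start_dim=1))  # one real/fake score per sample


d_model = SketchConditionalDiscriminator()
x = torch.randn(2, 4, 9, 9)
y = torch.randint(low=0, high=3, size=(2,))
score = d_model(x, y)
print(score.shape)  # torch.Size([2, 1])
```

Concatenating a label map with the sample is only one way to condition on the class; other designs, such as projecting the label into the final layer, would satisfy the same interface.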
- forward(x: Tensor, y: Tensor)
- Parameters:
x (torch.Tensor) – EEG signal representation; the ideal input shape is [n, 4, 9, 9]. Here, n corresponds to the batch size, 4 corresponds to in_channels, and (9, 9) corresponds to grid_size.
y (torch.Tensor) – Category labels (int) for a batch of samples. The shape should be [n,]. Here, n corresponds to the batch size.
- Returns:
A prediction of whether each given sample is fake. Here, n corresponds to the batch size.
- Return type:
torch.Tensor[n, 1]
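The [n, 1] output is typically consumed by a binary real-versus-fake loss when training the discriminator. The sketch below shows one common formulation, under two assumptions that this page does not confirm: that the output is a raw logit rather than a probability, and that real samples are given target 1 and fake samples target 0. The helper discriminator_loss and the placeholder stand_in_discriminator are hypothetical names; any callable with the signature d_model(x, y) returning [n, 1] could be passed in.

```python
import torch
import torch.nn.functional as F


def discriminator_loss(d_model, real_x, fake_x, y):
    # Assumption: d_model(x, y) returns raw logits of shape [n, 1]
    real_logits = d_model(real_x, y)
    # detach() keeps this step from sending gradients back into the generator
    fake_logits = d_model(fake_x.detach(), y)
    # Assumed convention: target 1 for real samples, target 0 for fake samples
    real_loss = F.binary_cross_entropy_with_logits(
        real_logits, torch.ones_like(real_logits))
    fake_loss = F.binary_cross_entropy_with_logits(
        fake_logits, torch.zeros_like(fake_logits))
    return real_loss + fake_loss


def stand_in_discriminator(x, y):
    # Hypothetical placeholder: ignores the labels, exists only to make the sketch runnable
    return x.flatten(start_dim=1).mean(dim=1, keepdim=True)


real_X = torch.randn(8, 4, 9, 9)
fake_X = torch.randn(8, 4, 9, 9)
y = torch.randint(low=0, high=3, size=(8,))
loss = discriminator_loss(stand_in_discriminator, real_X, fake_X, y)
```

If the model's final layer already applied a sigmoid, F.binary_cross_entropy would be the matching loss in place of the logits variant.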