BCGenerator
- class torcheeg.models.BCGenerator(in_channels: int = 128, out_channels: int = 4, grid_size: Tuple[int, int] = (9, 9), num_classes: int = 3)
GAN-based methods formulate a zero-sum game between the generator and the discriminator. The generator is optimized to produce simulated samples that the discriminator cannot distinguish from real ones, and the discriminator is optimized to detect the fake samples produced by the generator. This class provides a baseline implementation of the generator. In particular, the expected labels are additionally provided to guide the generator to generate samples of the specified class from random noise.
Related Project: https://github.com/eriklindernoren/PyTorch-GAN/
import torch
from torcheeg.models.gan.bcgan import BCGenerator

g_model = BCGenerator(in_channels=128, num_classes=3)
z = torch.normal(mean=0, std=1, size=(1, 128))
y = torch.randint(low=0, high=3, size=(1,))
fake_X = g_model(z, y)
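The label conditioning described above can be illustrated with a minimal sketch. It assumes one common scheme (embed the label, multiply it element-wise with the noise, then project to the grid) and is not taken from the torcheeg source; the class `TinyConditionalGenerator` and its layers are hypothetical stand-ins, not BCGenerator itself.

```python
import torch
from torch import nn


class TinyConditionalGenerator(nn.Module):
    """Hypothetical stand-in for illustration, NOT torcheeg's BCGenerator."""

    def __init__(self, in_channels=128, out_channels=4, grid_size=(9, 9), num_classes=3):
        super().__init__()
        self.out_channels = out_channels
        self.grid_size = grid_size
        # One learnable vector per class, with the same width as the noise vector.
        self.label_embedding = nn.Embedding(num_classes, in_channels)
        # A single linear layer maps the conditioned noise to every grid cell.
        self.project = nn.Linear(in_channels, out_channels * grid_size[0] * grid_size[1])

    def forward(self, x, y):
        # Condition the noise on the label by element-wise multiplication (an assumption).
        conditioned = x * self.label_embedding(y)
        out = self.project(conditioned)
        # Reshape the flat output to [n, out_channels, height, width].
        return out.view(-1, self.out_channels, *self.grid_size)


g = TinyConditionalGenerator()
z = torch.randn(2, 128)        # two noise vectors
y = torch.tensor([0, 2])       # one class label per noise vector
fake = g(z, y)
print(fake.shape)  # torch.Size([2, 4, 9, 9])
```

Because the label enters the computation before the projection, the same noise vector yields different outputs for different classes, which is what lets a trained generator target a specified class.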
- Parameters:
in_channels (int) – The input feature dimension (of noise vectors). (default: 128)
out_channels (int) – The generated feature dimension of each electrode. (default: 4)
grid_size (tuple) – Spatial dimensions of grid-like EEG representation. (default: (9, 9))
num_classes (int) – The number of classes. (default: 3)
- forward(x: Tensor, y: Tensor)
- Parameters:
x (torch.Tensor) – A random noise vector; the expected input shape is [n, 128]. Here, n corresponds to the batch size, and 128 corresponds to in_channels.
y (torch.Tensor) – Category labels (int) for a batch of samples. The shape should be [n,]. Here, n corresponds to the batch size.
- Returns:
The generated fake EEG signals. Here, 4 corresponds to out_channels, and (9, 9) corresponds to grid_size.
- Return type:
torch.Tensor[n, 4, 9, 9]
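The input contract of forward (noise of shape [n, in_channels], integer labels of shape [n,]) can be sketched as a small helper. The function `make_generator_inputs` is a hypothetical convenience for illustration, not part of torcheeg; it only uses standard PyTorch calls.

```python
import torch


def make_generator_inputs(batch_size, in_channels=128, num_classes=3, label=None):
    """Build (z, y) matching the documented input shapes (hypothetical helper)."""
    # Standard-normal noise, one row per sample: shape [n, in_channels].
    z = torch.randn(batch_size, in_channels)
    if label is None:
        # Random integer labels in [0, num_classes): shape [n,].
        y = torch.randint(low=0, high=num_classes, size=(batch_size,))
    else:
        if not 0 <= label < num_classes:
            raise ValueError(f"label must be in [0, {num_classes}), got {label}")
        # The same class for every sample, stored as int64.
        y = torch.full((batch_size,), label, dtype=torch.long)
    return z, y


z, y = make_generator_inputs(batch_size=5, label=1)
print(z.shape, y.shape)  # torch.Size([5, 128]) torch.Size([5])
```

With a generator constructed as in the example at the top of this page, `g_model(z, y)` would then be expected to return a tensor of shape [5, 4, 9, 9] under the default settings.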