# Source code for torcheeg.models.gan.bcgan
from typing import Tuple
import torch
import torch.nn as nn
class BCGenerator(nn.Module):
    r'''
    GAN-based methods formulate a zero-sum game between the generator and the
    discriminator. The generator is optimized to generate simulation samples
    that are indistinguishable by the discriminator, and the discriminator is
    optimized to discriminate false samples generated by the generator. This
    class provides a baseline implementation for the generator. In particular,
    the expected labels are additionally provided to guide the generator to
    generate samples of the specified class from the random noise.

    - Related Project: https://github.com/eriklindernoren/PyTorch-GAN/

    .. code-block:: python

        import torch
        from torcheeg.models.gan.bcgan import BCGenerator

        g_model = BCGenerator(in_channels=128, num_classes=3)
        z = torch.normal(mean=0, std=1, size=(1, 128))
        y = torch.randint(low=0, high=3, size=(1,))
        fake_X = g_model(z, y)

    Args:
        in_channels (int): The input feature dimension (of noise vectors). (default: :obj:`128`)
        out_channels (int): The generated feature dimension of each electrode. (default: :obj:`4`)
        grid_size (tuple): Spatial dimensions of grid-like EEG representation. (default: :obj:`(9, 9)`)
        num_classes (int): The number of classes. (default: :obj:`3`)
    '''

    def __init__(self,
                 in_channels: int = 128,
                 out_channels: int = 4,
                 grid_size: Tuple[int, int] = (9, 9),
                 num_classes: int = 3):
        super(BCGenerator, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        # NOTE(review): grid_size is stored but the deconv stack below is
        # hard-coded to produce a 9x9 output (3 -> 5 -> 5 -> 9 -> 9); other
        # grid sizes would not be honored — confirm against upstream intent.
        self.grid_size = grid_size
        self.num_classes = num_classes

        # Maps a class index to a vector of the same size as the noise, so the
        # two can be concatenated. (Attribute name keeps the historical
        # spelling "label_embeding" for checkpoint/state-dict compatibility.)
        self.label_embeding = nn.Embedding(num_classes, in_channels)

        # Project [noise ++ label embedding] (2C) up to a 4C x 3 x 3 seed map.
        self.deproj = nn.Sequential(
            nn.Linear(in_channels * 2, in_channels * 4 * 3 * 3), nn.LeakyReLU())

        # Upsample 3x3 -> 5x5: (3-1)*2 - 2*1 + 3 = 5.
        self.deconv1 = nn.Sequential(
            nn.ConvTranspose2d(in_channels * 4,
                               in_channels * 2,
                               kernel_size=3,
                               stride=2,
                               padding=1,
                               bias=True), nn.BatchNorm2d(in_channels * 2),
            nn.LeakyReLU())
        # Refine at 5x5 (stride 1 keeps the spatial size).
        self.deconv2 = nn.Sequential(
            nn.ConvTranspose2d(in_channels * 2,
                               in_channels * 2,
                               kernel_size=3,
                               stride=1,
                               padding=1,
                               bias=True), nn.BatchNorm2d(in_channels * 2),
            nn.LeakyReLU())
        # Upsample 5x5 -> 9x9: (5-1)*2 - 2*1 + 3 = 9.
        self.deconv3 = nn.Sequential(
            nn.ConvTranspose2d(in_channels * 2,
                               in_channels,
                               kernel_size=3,
                               stride=2,
                               padding=1,
                               bias=True), nn.BatchNorm2d(in_channels),
            nn.LeakyReLU())
        # Final projection to out_channels at 9x9; no norm/activation so the
        # output range is unconstrained.
        self.deconv4 = nn.ConvTranspose2d(in_channels,
                                          out_channels,
                                          kernel_size=3,
                                          stride=1,
                                          padding=1,
                                          bias=True)

    def forward(self, x: torch.Tensor, y: torch.Tensor):
        r'''
        Args:
            x (torch.Tensor): a random vector, the ideal input shape is :obj:`[n, 128]`. Here, :obj:`n` corresponds to the batch size, and :obj:`128` corresponds to :obj:`in_channels`.
            y (torch.Tensor): Category labels (int) for a batch of samples The shape should be :obj:`[n,]`. Here, :obj:`n` corresponds to the batch size.

        Returns:
            torch.Tensor[n, 4, 9, 9]: the generated fake EEG signals. Here, :obj:`4` corresponds to the :obj:`out_channels`, and :obj:`(9, 9)` corresponds to the :obj:`grid_size`.
        '''
        # Condition the noise on the class by concatenating the label embedding.
        label_emb = self.label_embeding(y)
        x = torch.cat([x, label_emb], dim=-1)
        # Seed feature map, then progressively upsample to the output grid.
        x = self.deproj(x)
        x = x.view(-1, self.in_channels * 4, 3, 3)
        x = self.deconv1(x)
        x = self.deconv2(x)
        x = self.deconv3(x)
        x = self.deconv4(x)
        return x
class BCDiscriminator(nn.Module):
    r'''
    GAN-based methods formulate a zero-sum game between the generator and the
    discriminator. The generator is optimized to generate simulation samples
    that are indistinguishable by the discriminator, and the discriminator is
    optimized to discriminate false samples generated by the generator. This
    class provides a baseline implementation for the discriminator. In
    particular, the expected labels are additionally provided to guide the
    discriminator to distinguish whether the sample fits the data distribution
    of the class.

    .. code-block:: python

        g_model = BCGenerator(in_channels=128, num_classes=3)
        d_model = BCDiscriminator(in_channels=4, num_classes=3)
        z = torch.normal(mean=0, std=1, size=(1, 128))
        y = torch.randint(low=0, high=3, size=(1,))
        fake_X = g_model(z, y)
        disc_X = d_model(fake_X, y)

    Args:
        in_channels (int): The feature dimension of each electrode. (default: :obj:`4`)
        grid_size (tuple): Spatial dimensions of grid-like EEG representation. (default: :obj:`(9, 9)`)
        hid_channels (int): The number of hidden nodes in the first fully connected layer. (default: :obj:`64`)
        num_classes (int): The number of classes. (default: :obj:`3`)
    '''

    def __init__(self,
                 in_channels: int = 4,
                 grid_size: Tuple[int, int] = (9, 9),
                 hid_channels: int = 64,
                 num_classes: int = 3):
        super(BCDiscriminator, self).__init__()

        self.in_channels = in_channels
        self.grid_size = grid_size
        self.hid_channels = hid_channels
        self.num_classes = num_classes

        # Maps a class index to a full-size grid (in_channels x H x W) so it
        # can be stacked with the EEG input along the channel axis. (Attribute
        # name keeps the historical spelling "label_embeding" for
        # checkpoint/state-dict compatibility.)
        self.label_embeding = nn.Embedding(
            num_classes, in_channels * grid_size[0] * grid_size[1])

        # Input has 2 * in_channels channels: the sample plus its label grid.
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels * 2,
                      hid_channels,
                      kernel_size=3,
                      stride=1,
                      padding=1,
                      bias=True), nn.BatchNorm2d(hid_channels), nn.LeakyReLU())
        # Downsample (stride 2), e.g. 9x9 -> 5x5.
        self.conv2 = nn.Sequential(
            nn.Conv2d(hid_channels,
                      hid_channels * 2,
                      kernel_size=3,
                      stride=2,
                      padding=1,
                      bias=True), nn.BatchNorm2d(hid_channels * 2),
            nn.LeakyReLU())
        # Refine at the same resolution.
        self.conv3 = nn.Sequential(
            nn.Conv2d(hid_channels * 2,
                      hid_channels * 2,
                      kernel_size=3,
                      stride=1,
                      padding=1,
                      bias=True), nn.BatchNorm2d(hid_channels * 2),
            nn.LeakyReLU())
        # Downsample again, e.g. 5x5 -> 3x3.
        self.conv4 = nn.Sequential(
            nn.Conv2d(hid_channels * 2,
                      hid_channels * 4,
                      kernel_size=3,
                      stride=2,
                      padding=1,
                      bias=True), nn.BatchNorm2d(hid_channels * 4),
            nn.LeakyReLU())
        # Single real/fake logit per sample.
        self.proj = nn.Linear(self.feature_dim, 1)

    @property
    def feature_dim(self) -> int:
        # Infer the flattened size after conv4 by tracing a mock batch through
        # the conv stack; this keeps the Linear layer correct for any
        # grid_size/hid_channels combination.
        was_training = self.training
        # Switch to eval so the mock (all-zeros) pass does not update the
        # BatchNorm running statistics as a side effect.
        self.eval()
        try:
            with torch.no_grad():
                mock_y = torch.randint(low=0, high=self.num_classes, size=(1, ))
                mock_eeg = torch.zeros(1, self.in_channels, *self.grid_size)
                label_emb = self.label_embeding(mock_y)
                label_emb = label_emb.reshape(mock_eeg.shape)
                mock_eeg = torch.cat([mock_eeg, label_emb], dim=1)
                mock_eeg = self.conv1(mock_eeg)
                mock_eeg = self.conv2(mock_eeg)
                mock_eeg = self.conv3(mock_eeg)
                mock_eeg = self.conv4(mock_eeg)
        finally:
            # Restore whichever mode the module was in before the probe.
            self.train(was_training)
        return mock_eeg.flatten(start_dim=1).shape[-1]

    def forward(self, x: torch.Tensor, y: torch.Tensor):
        r'''
        Args:
            x (torch.Tensor): EEG signal representation, the ideal input shape is :obj:`[n, 4, 9, 9]`. Here, :obj:`n` corresponds to the batch size, :obj:`4` corresponds to the :obj:`in_channels`, and :obj:`(9, 9)` corresponds to the :obj:`grid_size`.
            y (torch.Tensor): Category labels (int) for a batch of samples The shape should be :obj:`[n,]`. Here, :obj:`n` corresponds to the batch size.

        Returns:
            torch.Tensor[n, 1]: Predicts the result of whether a given sample is a fake sample or not. Here, :obj:`n` corresponds to the batch size.
        '''
        # Turn each label into an in_channels x H x W grid and stack it with
        # the sample along the channel axis.
        label_emb = self.label_embeding(y)
        label_emb = label_emb.reshape(x.shape)
        x = torch.cat([x, label_emb], dim=1)
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = x.flatten(start_dim=1)
        x = self.proj(x)
        return x