Shortcuts

Source code for torcheeg.models.gan.bcgan

from typing import Tuple

import torch
import torch.nn as nn


class BCGenerator(nn.Module):
    r'''
    GAN-based methods formulate a zero-sum game between the generator and the discriminator. The generator is optimized to generate simulation samples that are indistinguishable by the discriminator, and the discriminator is optimized to discriminate false samples generated by the generator. This class provides a baseline implementation for the generator. In particular, the expected labels are additionally provided to guide the generator to generate samples of the specified class from the random noise.

    - Related Project: https://github.com/eriklindernoren/PyTorch-GAN/

    .. code-block:: python

        import torch
        from torcheeg.models.gan.bcgan import BCGenerator

        g_model = BCGenerator(in_channels=128, num_classes=3)
        z = torch.normal(mean=0, std=1, size=(1, 128))
        y = torch.randint(low=0, high=3, size=(1,))
        fake_X = g_model(z, y)

    The noise vector is concatenated with a learned label embedding. It is then projected to a small feature map and upsampled by two stride-2 transposed convolutions. The size of that initial map and the ``output_padding`` of the two stride-2 layers are derived from :obj:`grid_size`, so the spatial size of the output always equals :obj:`grid_size`. For the default ``(9, 9)``, the initial map is ``3 x 3`` and no output padding is used.

    Args:
        in_channels (int): The input feature dimension (of noise vectors). (default: :obj:`128`)
        out_channels (int): The generated feature dimension of each electrode. (default: :obj:`4`)
        grid_size (tuple): Spatial dimensions of grid-like EEG representation. (default: :obj:`(9, 9)`)
        num_classes (int): The number of classes. (default: :obj:`3`)

    Raises:
        ValueError: If any entry of :obj:`grid_size` is smaller than 1.
    '''
    def __init__(self,
                 in_channels: int = 128,
                 out_channels: int = 4,
                 grid_size: Tuple[int, int] = (9, 9),
                 num_classes: int = 3):
        super(BCGenerator, self).__init__()
        if min(grid_size) < 1:
            raise ValueError(
                f'grid_size entries must be positive, got {tuple(grid_size)}.')

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.grid_size = grid_size
        self.num_classes = num_classes

        # The misspelled attribute name is kept on purpose: renaming it would
        # break loading of existing state_dicts.
        self.label_embeding = nn.Embedding(num_classes, in_channels)

        # Work out, per spatial axis, which starting size and output paddings
        # make the two stride-2 transposed convolutions land on grid_size.
        base_h, pad1_h, pad2_h = self._upsampling_plan(grid_size[0])
        base_w, pad1_w, pad2_w = self._upsampling_plan(grid_size[1])
        self._base_size = (base_h, base_w)

        self.deproj = nn.Sequential(
            nn.Linear(in_channels * 2, in_channels * 4 * base_h * base_w),
            nn.LeakyReLU())

        self.deconv1 = nn.Sequential(
            nn.ConvTranspose2d(in_channels * 4,
                               in_channels * 2,
                               kernel_size=3,
                               stride=2,
                               padding=1,
                               output_padding=(pad1_h, pad1_w),
                               bias=True), nn.BatchNorm2d(in_channels * 2),
            nn.LeakyReLU())
        # Stride 1 with kernel 3 / padding 1 keeps the spatial size unchanged.
        self.deconv2 = nn.Sequential(
            nn.ConvTranspose2d(in_channels * 2,
                               in_channels * 2,
                               kernel_size=3,
                               stride=1,
                               padding=1,
                               bias=True), nn.BatchNorm2d(in_channels * 2),
            nn.LeakyReLU())
        self.deconv3 = nn.Sequential(
            nn.ConvTranspose2d(in_channels * 2,
                               in_channels,
                               kernel_size=3,
                               stride=2,
                               padding=1,
                               output_padding=(pad2_h, pad2_w),
                               bias=True), nn.BatchNorm2d(in_channels),
            nn.LeakyReLU())
        self.deconv4 = nn.ConvTranspose2d(in_channels,
                                          out_channels,
                                          kernel_size=3,
                                          stride=1,
                                          padding=1,
                                          bias=True)

    @staticmethod
    def _upsampling_plan(size: int) -> Tuple[int, int, int]:
        r'''Return ``(base, pad_first, pad_second)`` for one spatial axis.

        A transposed convolution with kernel 3, stride 2 and padding 1 maps a
        length ``n`` to ``2 * n - 1 + output_padding``, with ``output_padding``
        in ``{0, 1}``. Inverting this twice gives the starting length ``base``
        and the paddings the first and second stride-2 layers need to grow
        ``base`` into ``size``.
        '''
        mid, pad_second = divmod(size + 1, 2)
        base, pad_first = divmod(mid + 1, 2)
        return base, pad_first, pad_second

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        r'''
        Args:
            x (torch.Tensor): A random vector. The ideal input shape is :obj:`[n, 128]`. Here, :obj:`n` corresponds to the batch size, and :obj:`128` corresponds to :obj:`in_channels`.
            y (torch.Tensor): Integer category labels for a batch of samples, with values in ``[0, num_classes)``. The shape should be :obj:`[n,]`. Here, :obj:`n` corresponds to the batch size.

        Returns:
            torch.Tensor[n, 4, 9, 9]: The generated fake EEG signals. Here, :obj:`4` corresponds to :obj:`out_channels`, and :obj:`(9, 9)` corresponds to :obj:`grid_size`.
        '''
        label_emb = self.label_embeding(y)
        # Condition the noise on the class by appending the label embedding.
        x = torch.cat([x, label_emb], dim=-1)
        x = self.deproj(x)
        x = x.view(-1, self.in_channels * 4, *self._base_size)
        x = self.deconv1(x)
        x = self.deconv2(x)
        x = self.deconv3(x)
        x = self.deconv4(x)
        return x
class BCDiscriminator(nn.Module):
    r'''
    GAN-based methods formulate a zero-sum game between the generator and the discriminator. The generator is optimized to generate simulation samples that are indistinguishable by the discriminator, and the discriminator is optimized to discriminate false samples generated by the generator. This class provides a baseline implementation for the discriminator. In particular, the expected labels are additionally provided to guide the discriminator to distinguish whether the sample fits the data distribution of the class.

    .. code-block:: python

        g_model = BCGenerator(in_channels=128, num_classes=3)
        d_model = BCDiscriminator(in_channels=4, num_classes=3)
        z = torch.normal(mean=0, std=1, size=(1, 128))
        y = torch.randint(low=0, high=3, size=(1,))
        fake_X = g_model(z, y)
        disc_X = d_model(fake_X, y)

    Args:
        in_channels (int): The feature dimension of each electrode. (default: :obj:`4`)
        grid_size (tuple): Spatial dimensions of grid-like EEG representation. (default: :obj:`(9, 9)`)
        hid_channels (int): The number of output channels of the first convolutional layer. Later convolutional layers use 2x and 4x this width. (default: :obj:`64`)
        num_classes (int): The number of classes. (default: :obj:`3`)
    '''
    def __init__(self,
                 in_channels: int = 4,
                 grid_size: Tuple[int, int] = (9, 9),
                 hid_channels: int = 64,
                 num_classes: int = 3):
        super(BCDiscriminator, self).__init__()
        self.in_channels = in_channels
        self.grid_size = grid_size
        self.hid_channels = hid_channels
        self.num_classes = num_classes

        # One embedding vector per class, reshaped in forward() to a full
        # [in_channels, *grid_size] map. The misspelled attribute name is kept
        # so that existing state_dicts still load.
        self.label_embeding = nn.Embedding(
            num_classes, in_channels * grid_size[0] * grid_size[1])

        # conv1 sees the sample and its label map stacked on the channel axis.
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels * 2,
                      hid_channels,
                      kernel_size=3,
                      stride=1,
                      padding=1,
                      bias=True), nn.BatchNorm2d(hid_channels), nn.LeakyReLU())
        self.conv2 = nn.Sequential(
            nn.Conv2d(hid_channels,
                      hid_channels * 2,
                      kernel_size=3,
                      stride=2,
                      padding=1,
                      bias=True), nn.BatchNorm2d(hid_channels * 2),
            nn.LeakyReLU())
        self.conv3 = nn.Sequential(
            nn.Conv2d(hid_channels * 2,
                      hid_channels * 2,
                      kernel_size=3,
                      stride=1,
                      padding=1,
                      bias=True), nn.BatchNorm2d(hid_channels * 2),
            nn.LeakyReLU())
        self.conv4 = nn.Sequential(
            nn.Conv2d(hid_channels * 2,
                      hid_channels * 4,
                      kernel_size=3,
                      stride=2,
                      padding=1,
                      bias=True), nn.BatchNorm2d(hid_channels * 4),
            nn.LeakyReLU())

        self.proj = nn.Linear(self.feature_dim, 1)

    @staticmethod
    def _conv_out_size(size: int, conv: nn.Conv2d, axis: int) -> int:
        r'''Output length of ``conv`` along spatial ``axis`` (0 = height, 1 = width) for input length ``size``, using the standard ``Conv2d`` output-size formula.'''
        return (size + 2 * conv.padding[axis] - conv.dilation[axis] *
                (conv.kernel_size[axis] - 1) - 1) // conv.stride[axis] + 1

    @property
    def feature_dim(self) -> int:
        r'''Number of features in the flattened output of ``conv4``, i.e. the input size of ``proj``.

        This is computed from the convolution hyper-parameters instead of a dummy forward pass. A dummy pass in training mode would update the BatchNorm running statistics and consume the global RNG. It would also fail when the grid shrinks to ``1 x 1`` (BatchNorm needs more than one value per channel in training mode).
        '''
        height, width = self.grid_size
        for block in (self.conv1, self.conv2, self.conv3, self.conv4):
            conv = block[0]
            height = self._conv_out_size(height, conv, 0)
            width = self._conv_out_size(width, conv, 1)
        return self.conv4[0].out_channels * height * width

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        r'''
        Args:
            x (torch.Tensor): EEG signal representation. The ideal input shape is :obj:`[n, 4, 9, 9]`. Here, :obj:`n` corresponds to the batch size, :obj:`4` corresponds to :obj:`in_channels`, and :obj:`(9, 9)` corresponds to :obj:`grid_size`.
            y (torch.Tensor): Integer category labels for a batch of samples, with values in ``[0, num_classes)``. The shape should be :obj:`[n,]`. Here, :obj:`n` corresponds to the batch size.

        Returns:
            torch.Tensor[n, 1]: The prediction of whether each given sample is a fake sample or not. Here, :obj:`n` corresponds to the batch size.
        '''
        label_emb = self.label_embeding(y)
        # Turn each label vector into a map shaped like the sample itself.
        label_emb = label_emb.reshape(x.shape)
        x = torch.cat([x, label_emb], dim=1)
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = x.flatten(start_dim=1)
        x = self.proj(x)
        return x
Read the Docs v: latest
Versions
latest
stable
v1.1.2
v1.1.1
v1.1.0
v1.0.11
v1.0.10
v1.0.9
v1.0.8.post1
v1.0.8
v1.0.7
v1.0.6
v1.0.4
v1.0.3
v1.0.2
v1.0.1
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources