Shortcuts

Source code for torcheeg.models.gan.bgan

from typing import Tuple

import torch
import torch.nn as nn


class BGenerator(nn.Module):
    r'''
    TorchEEG provides an EEG feature generator based on CNN architecture and GAN
    for generating EEG grid representations of different frequency bands from
    random noise vectors. This generator is unconditional: it takes no class label.

    .. code-block:: python

        g_model = BGenerator(in_channels=128)
        z = torch.normal(mean=0, std=1, size=(1, 128))
        fake_X = g_model(z)

    Args:
        in_channels (int): The input feature dimension (of noise vectors). (default: :obj:`128`)
        out_channels (int): The generated feature dimension of each electrode. (default: :obj:`4`)
        grid_size (tuple): Spatial dimensions of grid-like EEG representation. It is
            stored as an attribute but does not influence the layers: the fixed
            deconvolution stack always upsamples a 3x3 seed to a 9x9 output.
            (default: :obj:`(9, 9)`)
    '''

    def __init__(self,
                 in_channels: int = 128,
                 out_channels: int = 4,
                 grid_size: Tuple[int, int] = (9, 9)):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        # NOTE: kept for API compatibility; the output size is fixed at 9x9 (see below).
        self.grid_size = grid_size

        # Project the noise vector to a (in_channels * 4) x 3 x 3 seed feature map.
        self.deproj = nn.Sequential(
            nn.Linear(in_channels, in_channels * 4 * 3 * 3), nn.LeakyReLU())

        # ConvTranspose2d output size = (H - 1) * stride - 2 * padding + kernel_size,
        # so the spatial size evolves 3 -> 5 -> 5 -> 9 -> 9 through deconv1..deconv4.
        self.deconv1 = nn.Sequential(
            nn.ConvTranspose2d(in_channels * 4,
                               in_channels * 2,
                               kernel_size=3,
                               stride=2,
                               padding=1,
                               bias=True), nn.BatchNorm2d(in_channels * 2),
            nn.LeakyReLU())
        self.deconv2 = nn.Sequential(
            nn.ConvTranspose2d(in_channels * 2,
                               in_channels * 2,
                               kernel_size=3,
                               stride=1,
                               padding=1,
                               bias=True), nn.BatchNorm2d(in_channels * 2),
            nn.LeakyReLU())
        self.deconv3 = nn.Sequential(
            nn.ConvTranspose2d(in_channels * 2,
                               in_channels,
                               kernel_size=3,
                               stride=2,
                               padding=1,
                               bias=True), nn.BatchNorm2d(in_channels),
            nn.LeakyReLU())
        # Final layer: no normalization/activation, emits out_channels maps.
        self.deconv4 = nn.ConvTranspose2d(in_channels,
                                          out_channels,
                                          kernel_size=3,
                                          stride=1,
                                          padding=1,
                                          bias=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        r'''
        Args:
            x (torch.Tensor): a random vector, the ideal input shape is :obj:`[n, 128]`.
                Here, :obj:`n` corresponds to the batch size, and :obj:`128` corresponds
                to :obj:`in_channels`.

        Returns:
            torch.Tensor[n, 4, 9, 9]: the generated fake EEG signals. Here, :obj:`4`
            corresponds to the :obj:`out_channels`, and :obj:`(9, 9)` is the fixed
            spatial size produced by the deconvolution stack.
        '''
        x = self.deproj(x)
        # -1 lets the batch dimension be inferred (a single 1-D noise vector
        # also works and yields a batch of one).
        x = x.view(-1, self.in_channels * 4, 3, 3)
        x = self.deconv1(x)
        x = self.deconv2(x)
        x = self.deconv3(x)
        x = self.deconv4(x)
        return x
class BDiscriminator(nn.Module):
    r'''
    TorchEEG provides an EEG discriminator based on CNN architecture, intended to be
    paired with :obj:`BGenerator` in a GAN. It scores grid representations of
    different frequency bands, predicting whether a given sample is real or generated.

    .. code-block:: python

        g_model = BGenerator(in_channels=128)
        d_model = BDiscriminator(in_channels=4)
        z = torch.normal(mean=0, std=1, size=(1, 128))
        fake_X = g_model(z)
        disc_X = d_model(fake_X)

    Args:
        in_channels (int): The feature dimension of each electrode. (default: :obj:`4`)
        grid_size (tuple): Spatial dimensions of grid-like EEG representation. (default: :obj:`(9, 9)`)
        hid_channels (int): The number of output channels of the first convolutional
            layer; later layers use :obj:`2 * hid_channels` and :obj:`4 * hid_channels`.
            (default: :obj:`64`)
    '''

    def __init__(self,
                 in_channels: int = 4,
                 grid_size: Tuple[int, int] = (9, 9),
                 hid_channels: int = 64):
        super().__init__()
        self.in_channels = in_channels
        self.grid_size = grid_size
        self.hid_channels = hid_channels

        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels,
                      hid_channels,
                      kernel_size=3,
                      stride=1,
                      padding=1,
                      bias=True), nn.BatchNorm2d(hid_channels), nn.LeakyReLU())
        # stride=2 halves (rounding up) the spatial size.
        self.conv2 = nn.Sequential(
            nn.Conv2d(hid_channels,
                      hid_channels * 2,
                      kernel_size=3,
                      stride=2,
                      padding=1,
                      bias=True), nn.BatchNorm2d(hid_channels * 2),
            nn.LeakyReLU())
        self.conv3 = nn.Sequential(
            nn.Conv2d(hid_channels * 2,
                      hid_channels * 2,
                      kernel_size=3,
                      stride=1,
                      padding=1,
                      bias=True), nn.BatchNorm2d(hid_channels * 2),
            nn.LeakyReLU())
        self.conv4 = nn.Sequential(
            nn.Conv2d(hid_channels * 2,
                      hid_channels * 4,
                      kernel_size=3,
                      stride=2,
                      padding=1,
                      bias=True), nn.BatchNorm2d(hid_channels * 4),
            nn.LeakyReLU())
        # The linear head's input width depends on grid_size; measure it by
        # probing the conv stack (see feature_dim).
        self.proj = nn.Linear(self.feature_dim, 1)

    @property
    def feature_dim(self) -> int:
        r'''
        Number of features produced by the conv stack for one sample, i.e. the
        input width of :obj:`proj`.

        The value is measured by running a zero tensor through
        :obj:`conv1`..:obj:`conv4`. The probe runs with those blocks in eval mode,
        so BatchNorm running statistics are not updated and 1x1 feature maps do not
        fail. Each submodule's previous training flag is restored afterwards. The
        probe is created on the same device/dtype as the conv parameters.
        '''
        stages = (self.conv1, self.conv2, self.conv3, self.conv4)
        saved_modes = [(module, module.training) for stage in stages
                       for module in stage.modules()]
        reference = next(self.conv1.parameters())
        try:
            for stage in stages:
                stage.eval()
            with torch.no_grad():
                mock_eeg = torch.zeros(1,
                                       self.in_channels,
                                       *self.grid_size,
                                       device=reference.device,
                                       dtype=reference.dtype)
                for stage in stages:
                    mock_eeg = stage(mock_eeg)
        finally:
            # Restore exactly the per-module modes the caller had set.
            for module, mode in saved_modes:
                module.training = mode
        return mock_eeg.flatten(start_dim=1).shape[-1]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        r'''
        Args:
            x (torch.Tensor): EEG signal representation, the ideal input shape is
                :obj:`[n, 4, 9, 9]`. Here, :obj:`n` corresponds to the batch size,
                :obj:`4` corresponds to the :obj:`in_channels`, and :obj:`(9, 9)`
                corresponds to the :obj:`grid_size`.

        Returns:
            torch.Tensor[n, 1]: an unnormalized score (no sigmoid is applied) of
            whether each sample is real or fake. Here, :obj:`n` corresponds to the
            batch size.
        '''
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = x.flatten(start_dim=1)
        x = self.proj(x)
        return x

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources