Shortcuts

Source code for torcheeg.trainers.gan_trainer

import math
from typing import List, Tuple

import torch
import torch.autograd as autograd
import torch.nn as nn
import torchmetrics
from torch.utils.data import DataLoader

from .basic_trainer import BasicTrainer


def gradient_penalty(model, real, fake, label=None):
    device = real.device
    real = real.data
    fake = fake.data
    alpha = torch.rand(real.size(0), *([1] * (len(real.shape) - 1))).to(device)
    inputs = alpha * real + ((1 - alpha) * fake)
    inputs.requires_grad_()

    if label is None:
        outputs = model(inputs)
    else:
        outputs = model(inputs, label)

    gradient = autograd.grad(outputs=outputs,
                             inputs=inputs,
                             grad_outputs=torch.ones_like(outputs).to(device),
                             create_graph=True,
                             retain_graph=True,
                             only_inputs=True)[0]

    gradient = gradient.flatten(1)
    return ((gradient.norm(2, dim=1) - 1)**2).mean()


[docs]class GANTrainer(BasicTrainer): r''' This class provide the implementation for WGAN-GP. It trains a zero-sum game between the generator and the discriminator, just like the traditional generative networks. The generator is optimized to generate simulation samples that are indistinguishable by the discriminator, and the discriminator is optimized to discriminate false samples generated by the generator. Compared with vanilla GAN, with WGAN-GP we can improve the stability of learning, get rid of problems like mode collapse, and provide meaningful learning curves useful for debugging and hyperparameter searches. Thus, existing work typically uses WGAN-GP to generate simulated EEG signals. For more details, please refer to the following information. - Paper: Gulrajani I, Ahmed F, Arjovsky M, et al. Improved training of wasserstein gans[J]. Advances in neural information processing systems, 2017, 30. - URL: https://arxiv.org/abs/1704.00028 - Related Project: https://github.com/eriklindernoren/PyTorch-GAN .. code-block:: python g_model = BGenerator(in_channels=128) d_model = BDiscriminator(in_channels=4) trainer = GANTrainer(generator, discriminator) trainer.fit(train_loader, val_loader) trainer.test(test_loader) The class provides the following hook functions for inserting additional implementations in the training, validation and testing lifecycle: - :obj:`before_training_epoch`: executed before each epoch of training starts - :obj:`before_training_step`: executed before each batch of training starts - :obj:`on_training_step`: the training process for each batch - :obj:`after_training_step`: execute after the training of each batch - :obj:`after_training_epoch`: executed after each epoch of training - :obj:`before_validation_epoch`: executed before each round of validation starts - :obj:`before_validation_step`: executed before the validation of each batch - :obj:`on_validation_step`: validation process for each batch - :obj:`after_validation_step`: executed after the validation of each batch - :obj:`after_validation_epoch`: executed after each round of validation - :obj:`before_test_epoch`: executed before each round of test starts - :obj:`before_test_step`: executed before the test of each batch - :obj:`on_test_step`: test process for each batch - :obj:`after_test_step`: executed after the test of each batch - :obj:`after_test_epoch`: executed after each round of test If you want to customize some operations, you just need to inherit the class and override the hook function: .. code-block:: python class MyGANTrainer(GANTrainer): def before_training_epoch(self, epoch_id: int, num_epochs: int): # Do something here. super().before_training_epoch(epoch_id, num_epochs) If you want to use multiple GPUs for parallel computing, you need to specify the GPU indices you want to use in the python file: .. code-block:: python trainer = GANTrainer(generator, discriminator, device_ids=[1, 2, 7]) trainer.fit(train_loader, val_loader) trainer.test(test_loader) Then, you can use the :obj:`torch.distributed.launch` or :obj:`torchrun` to run your python file. .. code-block:: shell python -m torch.distributed.launch \ --nproc_per_node=3 \ --nnodes=1 \ --node_rank=0 \ --master_addr="localhost" \ --master_port=2345 \ your_python_file.py Here, :obj:`nproc_per_node` is the number of GPUs you specify. Args: generator (nn.Module): The generator model for EEG signal generation, whose inputs are Gaussian distributed random vectors, outputs are generated EEG signals. The dimensions of the input vector should be defined on the :obj:`in_channel` attribute. The output layer does not need to have a softmax activation function. discriminator (nn.Module): The discriminator model to determine whether the EEG signal is real or generated, and the dimension of its output should be equal to the one (i.e., the score to distinguish the real and the fake). The output layer does not need to have a sigmoid activation function. generator_lr (float): The learning rate of the generator. (defualt: :obj:`0.0001`) discriminator_lr (float): The learning rate of the discriminator. (defualt: :obj:`0.0001`) lambd (float): The weight of gradient penalty loss to trade-off between the adversarial training loss and gradient penalty loss. (defualt: :obj:`1.0`) weight_decay: (float): The weight decay (L2 penalty). (defualt: :obj:`0.0`) device_ids (list): Use cpu if the list is empty. If the list contains indices of multiple GPUs, it needs to be launched with :obj:`torch.distributed.launch` or :obj:`torchrun`. (defualt: :obj:`[]`) ddp_sync_bn (bool): Whether to replace batch normalization in network structure with cross-GPU synchronized batch normalization. Only valid when the length of :obj:`device_ids` is greater than one. (defualt: :obj:`True`) ddp_replace_sampler (bool): Whether to replace sampler in dataloader with :obj:`DistributedSampler`. Only valid when the length of :obj:`device_ids` is greater than one. (defualt: :obj:`True`) ddp_val (bool): Whether to use multi-GPU acceleration for the validation set. For experiments where data input order is sensitive, :obj:`ddp_val` should be set to :obj:`False`. Only valid when the length of :obj:`device_ids` is greater than one. (defualt: :obj:`True`) ddp_test (bool): Whether to use multi-GPU acceleration for the test set. For experiments where data input order is sensitive, :obj:`ddp_test` should be set to :obj:`False`. Only valid when the length of :obj:`device_ids` is greater than one. (defualt: :obj:`True`) .. automethod:: fit .. automethod:: test .. automethod:: sample ''' def __init__(self, generator: nn.Module, discriminator: nn.Module, generator_lr: float = 1e-4, discriminator_lr: float = 1e-4, lambd: float = 1.0, weight_decay: float = 0.0, device_ids: List[int] = [], ddp_sync_bn: bool = True, ddp_replace_sampler: bool = True, ddp_val: bool = True, ddp_test: bool = True): super(GANTrainer, self).__init__(modules={ 'generator': generator, 'discriminator': discriminator, }, device_ids=device_ids, ddp_sync_bn=ddp_sync_bn, ddp_replace_sampler=ddp_replace_sampler, ddp_val=ddp_val, ddp_test=ddp_test) self.generator_lr = generator_lr self.discriminator_lr = discriminator_lr self.weight_decay = weight_decay self.lambd = lambd self.generator_optimizer = torch.optim.Adam(generator.parameters(), lr=generator_lr, weight_decay=weight_decay) self.discriminator_optimizer = torch.optim.Adam( discriminator.parameters(), lr=discriminator_lr, weight_decay=weight_decay) self.loss_fn = nn.BCEWithLogitsLoss() # init metric self.train_g_loss = torchmetrics.MeanMetric().to(self.device) self.train_d_loss = torchmetrics.MeanMetric().to(self.device) self.val_g_loss = torchmetrics.MeanMetric().to(self.device) self.val_d_loss = torchmetrics.MeanMetric().to(self.device) self.test_g_loss = torchmetrics.MeanMetric().to(self.device) self.test_d_loss = torchmetrics.MeanMetric().to(self.device) def before_training_epoch(self, epoch_id: int, num_epochs: int): self.log(f"Epoch {epoch_id}\n-------------------------------") def on_training_step(self, train_batch: Tuple, batch_id: int, num_batches: int): self.train_g_loss.reset() self.train_d_loss.reset() X = train_batch[0].to(self.device) y = train_batch[1].to(self.device) self.generator_optimizer.zero_grad() assert hasattr( self.modules['generator'], 'in_channels' ), 'The generator must have the property in_channels to generate a batch of latent codes for the corresponding dimension.' z = torch.normal(mean=0, std=1, size=(X.shape[0], self.modules['generator'].in_channels)).to( self.device) gen_X = self.modules['generator'](z) g_loss = -torch.mean(self.modules['discriminator'](gen_X)) g_loss.backward() self.generator_optimizer.step() # backpropagation for discriminator self.discriminator_optimizer.zero_grad() real_loss = self.modules['discriminator'](X) fake_loss = self.modules['discriminator'](gen_X.detach()) gp_term = gradient_penalty(self.modules['discriminator'], X, gen_X) d_loss = -torch.mean(real_loss) + torch.mean( fake_loss) + self.lambd * gp_term d_loss.backward() self.discriminator_optimizer.step() # log five times log_step = math.ceil(num_batches / 5) if batch_id % log_step == 0: self.train_g_loss.update(g_loss) self.train_d_loss.update(d_loss) train_g_loss = self.train_g_loss.compute() train_d_loss = self.train_d_loss.compute() # if not distributed, world_size is 1 batch_id = batch_id * self.world_size num_batches = num_batches * self.world_size if self.is_main: self.log( f"g_loss: {train_g_loss:>8f}, d_loss: {train_d_loss:>8f} [{batch_id:>5d}/{num_batches:>5d}]" ) def before_validation_epoch(self, epoch_id: int, num_epochs: int): self.val_g_loss.reset() self.val_d_loss.reset() def on_validation_step(self, val_batch: Tuple, batch_id: int, num_batches: int): X = val_batch[0].to(self.device) y = val_batch[1].to(self.device) # for g_loss z = torch.normal(mean=0, std=1, size=(X.shape[0], self.modules['generator'].in_channels)).to( self.device) gen_X = self.modules['generator'](z) # g_loss = self.loss_fn(self.modules['discriminator'](gen_X), valid) g_loss = -torch.mean(self.modules['discriminator'](gen_X)) # for d_loss real_loss = self.modules['discriminator'](X) fake_loss = self.modules['discriminator'](gen_X.detach()) gp_term = gradient_penalty(self.modules['discriminator'], X, gen_X) d_loss = -torch.mean(real_loss) + torch.mean( fake_loss) + self.lambd * gp_term self.val_g_loss.update(g_loss) self.val_d_loss.update(d_loss) def after_validation_epoch(self, epoch_id: int, num_epochs: int): val_g_loss = self.val_g_loss.compute() val_d_loss = self.val_d_loss.compute() self.log(f"\ng_loss: {val_g_loss:>8f}, d_loss: {val_d_loss:>8f}") def before_test_epoch(self): self.test_g_loss.reset() self.test_d_loss.reset() def on_test_step(self, test_batch: Tuple, batch_id: int, num_batches: int): X = test_batch[0].to(self.device) y = test_batch[1].to(self.device) # for g_loss z = torch.normal(mean=0, std=1, size=(X.shape[0], self.modules['generator'].in_channels)).to( self.device) gen_X = self.modules['generator'](z) g_loss = -torch.mean(self.modules['discriminator'](gen_X)) # for d_loss real_loss = self.modules['discriminator'](X) fake_loss = self.modules['discriminator'](gen_X.detach()) gp_term = gradient_penalty(self.modules['discriminator'], X, gen_X) d_loss = -torch.mean(real_loss) + torch.mean( fake_loss) + self.lambd * gp_term self.test_g_loss.update(g_loss) self.test_d_loss.update(d_loss) def after_test_epoch(self): test_g_loss = self.test_g_loss.compute() test_d_loss = self.test_d_loss.compute() self.log(f"\ng_loss: {test_g_loss:>8f}, d_loss: {test_d_loss:>8f}")
[docs] def fit(self, train_loader: DataLoader, val_loader: DataLoader, num_epochs: int = 1, **kwargs): r''' Args: train_loader (DataLoader): Iterable DataLoader for traversing the training data batch (torch.utils.data.dataloader.DataLoader, torch_geometric.loader.DataLoader, etc). val_loader (DataLoader): Iterable DataLoader for traversing the validation data batch (torch.utils.data.dataloader.DataLoader, torch_geometric.loader.DataLoader, etc). num_epochs (int): training epochs. (defualt: :obj:`1`) ''' train_loader = self.on_reveive_dataloader(train_loader, mode='train') val_loader = self.on_reveive_dataloader(val_loader, mode='val') for t in range(num_epochs): if hasattr(train_loader, 'need_to_set_epoch'): train_loader.sampler.set_epoch(t) if hasattr(val_loader, 'need_to_set_epoch'): val_loader.sampler.set_epoch(t) num_batches = len(train_loader) # set model to train mode for k, m in self.modules.items(): self.modules[k].train() # hook self.before_training_epoch(t + 1, num_epochs, **kwargs) for batch_id, train_batch in enumerate(train_loader): # hook self.before_training_step(batch_id, num_batches, **kwargs) # hook self.on_training_step(train_batch, batch_id, num_batches, **kwargs) # hook self.after_training_step(batch_id, num_batches, **kwargs) # hook self.after_training_epoch(t + 1, num_epochs, **kwargs) # set model to val mode for k, m in self.modules.items(): self.modules[k].eval() num_batches = len(val_loader) # hook self.before_validation_epoch(t + 1, num_epochs, **kwargs) for batch_id, val_batch in enumerate(val_loader): # hook self.before_validation_step(batch_id, num_batches, **kwargs) # hook self.on_validation_step(val_batch, batch_id, num_batches, **kwargs) # hook self.after_validation_step(batch_id, num_batches, **kwargs) # hook self.after_validation_epoch(t + 1, num_epochs, **kwargs) return self
[docs] def test(self, test_loader: DataLoader, **kwargs): r''' Args: test_loader (DataLoader): Iterable DataLoader for traversing the test data batch (torch.utils.data.dataloader.DataLoader, torch_geometric.loader.DataLoader, etc). ''' test_loader = self.on_reveive_dataloader(test_loader, mode='test') for k, m in self.modules.items(): self.modules[k].eval() num_batches = len(test_loader) self.before_test_epoch(**kwargs) for batch_id, test_batch in enumerate(test_loader): # hook self.before_test_step(batch_id, num_batches, **kwargs) # hook self.on_test_step(test_batch, batch_id, num_batches, **kwargs) # hook self.after_test_step(batch_id, num_batches, **kwargs) self.after_test_epoch(**kwargs)
[docs] def sample(self, num_samples: int) -> torch.Tensor: """ Samples from the latent space and return generated results. Args: num_samples (int): Number of samples. Returns: torch.Tensor: the generated samples. """ self.modules['generator'].eval() with torch.no_grad(): z = torch.randn(num_samples, self.modules['generator'].in_channels).to( self.device) return self.modules['generator'](z)
[docs]class CGANTrainer(GANTrainer): r''' This class provide the implementation for WGAN-GP. It trains a zero-sum game between the generator and the discriminator, just like the traditional generative networks. The generator is optimized to generate simulation samples that are indistinguishable by the discriminator, and the discriminator is optimized to discriminate false samples generated by the generator. Compared with vanilla GAN, with WGAN-GP we can improve the stability of learning, get rid of problems like mode collapse, and provide meaningful learning curves useful for debugging and hyperparameter searches. Thus, existing work typically uses WGAN-GP to generate simulated EEG signals. In particular, the expected labels are additionally provided to guide the discriminator to distinguish whether the sample fits the data distribution of the class. For more details, please refer to the following information. - Paper: Gulrajani I, Ahmed F, Arjovsky M, et al. Improved training of wasserstein gans[J]. Advances in neural information processing systems, 2017, 30. - URL: https://arxiv.org/abs/1704.00028 - Related Project: https://github.com/eriklindernoren/PyTorch-GAN .. code-block:: python g_model = BCGenerator(in_channels=128, num_classes=2) d_model = BCDiscriminator(in_channels=4, num_classes=2) trainer = CGANTrainer(generator, discriminator) trainer.fit(train_loader, val_loader) trainer.test(test_loader) The class provides the following hook functions for inserting additional implementations in the training, validation and testing lifecycle: - :obj:`before_training_epoch`: executed before each epoch of training starts - :obj:`before_training_step`: executed before each batch of training starts - :obj:`on_training_step`: the training process for each batch - :obj:`after_training_step`: execute after the training of each batch - :obj:`after_training_epoch`: executed after each epoch of training - :obj:`before_validation_epoch`: executed before each round of validation starts - :obj:`before_validation_step`: executed before the validation of each batch - :obj:`on_validation_step`: validation process for each batch - :obj:`after_validation_step`: executed after the validation of each batch - :obj:`after_validation_epoch`: executed after each round of validation - :obj:`before_test_epoch`: executed before each round of test starts - :obj:`before_test_step`: executed before the test of each batch - :obj:`on_test_step`: test process for each batch - :obj:`after_test_step`: executed after the test of each batch - :obj:`after_test_epoch`: executed after each round of test If you want to customize some operations, you just need to inherit the class and override the hook function: .. code-block:: python class MyGANTrainer(GANTrainer): def before_training_epoch(self, epoch_id: int, num_epochs: int): # Do something here. super().before_training_epoch(epoch_id, num_epochs) If you want to use multiple GPUs for parallel computing, you need to specify the GPU indices you want to use in the python file: .. code-block:: python trainer = GANTrainer(generator, discriminator, device_ids=[1, 2, 7]) trainer.fit(train_loader, val_loader) trainer.test(test_loader) Then, you can use the :obj:`torch.distributed.launch` or :obj:`torchrun` to run your python file. .. code-block:: shell python -m torch.distributed.launch \ --nproc_per_node=3 \ --nnodes=1 \ --node_rank=0 \ --master_addr="localhost" \ --master_port=2345 \ your_python_file.py Here, :obj:`nproc_per_node` is the number of GPUs you specify. Args: generator (nn.Module): The generator model for EEG signal generation, whose inputs are Gaussian distributed random vectors, outputs are generated EEG signals. The dimensions of the input vector should be defined on the :obj:`in_channel` attribute. The output layer does not need to have a softmax activation function. discriminator (nn.Module): The discriminator model to determine whether the EEG signal is real or generated, and the dimension of its output should be equal to the one (i.e., the score to distinguish the real and the fake). The output layer does not need to have a sigmoid activation function. generator_lr (float): The learning rate of the generator. (defualt: :obj:`0.0001`) discriminator_lr (float): The learning rate of the discriminator. (defualt: :obj:`0.0001`) lambd (float): The weight of gradient penalty loss to trade-off between the adversarial training loss and gradient penalty loss. (defualt: :obj:`1.0`) weight_decay: (float): The weight decay (L2 penalty). (defualt: :obj:`0.0`) device_ids (list): Use cpu if the list is empty. If the list contains indices of multiple GPUs, it needs to be launched with :obj:`torch.distributed.launch` or :obj:`torchrun`. (defualt: :obj:`[]`) ddp_sync_bn (bool): Whether to replace batch normalization in network structure with cross-GPU synchronized batch normalization. Only valid when the length of :obj:`device_ids` is greater than one. (defualt: :obj:`True`) ddp_replace_sampler (bool): Whether to replace sampler in dataloader with :obj:`DistributedSampler`. Only valid when the length of :obj:`device_ids` is greater than one. (defualt: :obj:`True`) ddp_val (bool): Whether to use multi-GPU acceleration for the validation set. For experiments where data input order is sensitive, :obj:`ddp_val` should be set to :obj:`False`. Only valid when the length of :obj:`device_ids` is greater than one. (defualt: :obj:`True`) ddp_test (bool): Whether to use multi-GPU acceleration for the test set. For experiments where data input order is sensitive, :obj:`ddp_test` should be set to :obj:`False`. Only valid when the length of :obj:`device_ids` is greater than one. (defualt: :obj:`True`) .. automethod:: fit .. automethod:: test .. automethod:: sample ''' def __init__(self, generator: nn.Module, discriminator: nn.Module, generator_lr: float = 1e-4, discriminator_lr: float = 1e-4, lambd: float = 1.0, weight_decay: float = 0.0, device_ids: List[int] = [], ddp_sync_bn: bool = True, ddp_replace_sampler: bool = True, ddp_val: bool = True, ddp_test: bool = True): super(CGANTrainer, self).__init__( generator=generator, discriminator=discriminator, generator_lr=generator_lr, discriminator_lr=discriminator_lr, lambd=lambd, weight_decay=weight_decay, device_ids=device_ids, ddp_sync_bn=ddp_sync_bn, ddp_replace_sampler=ddp_replace_sampler, ddp_val=ddp_val, ddp_test=ddp_test, ) def on_training_step(self, train_batch: Tuple, batch_id: int, num_batches: int): self.train_g_loss.reset() self.train_d_loss.reset() X = train_batch[0].to(self.device) y = train_batch[1].to(self.device) self.generator_optimizer.zero_grad() assert hasattr( self.modules['generator'], 'in_channels' ), 'The generator must have the property in_channels to generate a batch of latent codes for the corresponding dimension.' z = torch.normal(mean=0, std=1, size=(X.shape[0], self.modules['generator'].in_channels)).to( self.device) gen_X = self.modules['generator'](z, y) g_loss = -torch.mean(self.modules['discriminator'](gen_X, y)) g_loss.backward() self.generator_optimizer.step() # backpropagation for discriminator self.discriminator_optimizer.zero_grad() real_loss = self.modules['discriminator'](X, y) fake_loss = self.modules['discriminator'](gen_X.detach(), y) gp_term = gradient_penalty(self.modules['discriminator'], X, gen_X, y) d_loss = -torch.mean(real_loss) + torch.mean( fake_loss) + self.lambd * gp_term d_loss.backward() self.discriminator_optimizer.step() # log five times log_step = math.ceil(num_batches / 5) if batch_id % log_step == 0: self.train_g_loss.update(g_loss) self.train_d_loss.update(d_loss) train_g_loss = self.train_g_loss.compute() train_d_loss = self.train_d_loss.compute() # if not distributed, world_size is 1 batch_id = batch_id * self.world_size num_batches = num_batches * self.world_size if self.is_main: self.log( f"g_loss: {train_g_loss:>8f}, d_loss: {train_d_loss:>8f} [{batch_id:>5d}/{num_batches:>5d}]" ) def on_validation_step(self, val_batch: Tuple, batch_id: int, num_batches: int): X = val_batch[0].to(self.device) y = val_batch[1].to(self.device) # for g_loss z = torch.normal(mean=0, std=1, size=(X.shape[0], self.modules['generator'].in_channels)).to( self.device) gen_X = self.modules['generator'](z, y) g_loss = -torch.mean(self.modules['discriminator'](gen_X, y)) # for d_loss real_loss = self.modules['discriminator'](X, y) fake_loss = self.modules['discriminator'](gen_X.detach(), y) gp_term = gradient_penalty(self.modules['discriminator'], X, gen_X, y) d_loss = -torch.mean(real_loss) + torch.mean( fake_loss) + self.lambd * gp_term self.val_g_loss.update(g_loss) self.val_d_loss.update(d_loss) def on_test_step(self, test_batch: Tuple, batch_id: int, num_batches: int): X = test_batch[0].to(self.device) y = test_batch[1].to(self.device) # for g_loss z = torch.normal(mean=0, std=1, size=(X.shape[0], self.modules['generator'].in_channels)).to( self.device) gen_X = self.modules['generator'](z, y) g_loss = -torch.mean(self.modules['discriminator'](gen_X, y)) # for d_loss real_loss = self.modules['discriminator'](X, y) fake_loss = self.modules['discriminator'](gen_X.detach(), y) gp_term = gradient_penalty(self.modules['discriminator'], X, gen_X, y) d_loss = -torch.mean(real_loss) + torch.mean( fake_loss) + self.lambd * gp_term self.test_g_loss.update(g_loss) self.test_d_loss.update(d_loss)
[docs] def sample(self, num_samples: int, labels: torch.Tensor = None) -> torch.Tensor: """ Samples from the latent space and return generated results. Args: num_samples (int): Number of samples. labels (torch.Tensor): Category labels (int) for a batch of samples The shape should be :obj:`[n,]`. Here, :obj:`n` corresponds to the batch size. If not provided, a batch of randomly generated categories will be used. Returns: torch.Tensor: the generated samples. """ if labels: assert len( labels ) == num_samples, f'labels ({len(labels)}) should be the same length as num_samples ({num_samples}).' assert isinstance( labels, torch.Tensor ), f'labels should be torch.Tensor instances, the current input is {type(labels)}' else: labels = torch.randint(low=0, high=self.modules['generator'].num_classes, size=(num_samples, )) labels = labels.long().to(self.device) self.modules['generator'].eval() with torch.no_grad(): z = torch.randn(num_samples, self.modules['generator'].in_channels).to( self.device) return self.modules['generator'](z, labels)

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources