CWGANGPTrainer
- class torcheeg.trainers.CWGANGPTrainer(generator: Module, discriminator: Module, generator_lr: float = 0.0001, discriminator_lr: float = 0.0001, weight_decay: float = 0.0, weight_gradient_penalty: float = 1.0, latent_channels: int | None = None, devices: int = 1, accelerator: str = 'cpu', metrics: List[str] = [], metric_extractor: Module | None = None, metric_classifier: Module | None = None, metric_num_features: int | None = None)
This class provides an implementation of conditional WGAN-GP. As in other generative adversarial networks, it trains a generator and a discriminator in a zero-sum game: the generator is optimized to produce simulated samples that the discriminator cannot tell apart from real ones, and the discriminator is optimized to identify the samples produced by the generator. Compared with the vanilla GAN, WGAN-GP improves training stability, mitigates problems such as mode collapse, and yields meaningful learning curves that are useful for debugging and hyperparameter searches. For these reasons, existing work typically uses WGAN-GP to generate simulated EEG signals. In this conditional variant, the expected labels are additionally provided so that the discriminator can judge whether a sample fits the data distribution of the given class. For more details, please refer to the following information.
Paper: Gulrajani I, Ahmed F, Arjovsky M, et al. Improved training of wasserstein gans[J]. Advances in neural information processing systems, 2017, 30.
Related Project: https://github.com/eriklindernoren/PyTorch-GAN
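The gradient penalty that gives WGAN-GP its name can be sketched as follows. In Gulrajani et al., the critic is penalized by (||grad D(x_hat)||_2 - 1)^2, evaluated at random interpolations x_hat between real and generated samples. This is a minimal illustration and not the trainer's internal code: ToyCritic and the function gradient_penalty are hypothetical names, and passing the labels as a second argument to the critic is an assumption about how the conditioning is wired.

```python
import torch
import torch.nn as nn


def gradient_penalty(discriminator: nn.Module, real_x: torch.Tensor,
                     fake_x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Two-sided gradient penalty (Gulrajani et al., 2017), conditioned on labels y."""
    batch_size = real_x.size(0)
    # One interpolation coefficient per sample, broadcast over the other dims.
    alpha = torch.rand(batch_size, *([1] * (real_x.dim() - 1)),
                       device=real_x.device)
    x_hat = (alpha * real_x + (1 - alpha) * fake_x.detach()).requires_grad_(True)
    scores = discriminator(x_hat, y)
    # Gradient of the critic scores with respect to the interpolated inputs.
    grads, = torch.autograd.grad(outputs=scores, inputs=x_hat,
                                 grad_outputs=torch.ones_like(scores),
                                 create_graph=True)
    grad_norm = grads.reshape(batch_size, -1).norm(2, dim=1)
    return ((grad_norm - 1) ** 2).mean()


class ToyCritic(nn.Module):
    """Hypothetical critic: flattened signal + one-hot label -> one raw score."""

    def __init__(self, num_features: int, num_classes: int):
        super().__init__()
        self.num_classes = num_classes
        self.net = nn.Linear(num_features + num_classes, 1)

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        one_hot = nn.functional.one_hot(y, self.num_classes).float()
        return self.net(torch.cat([x.flatten(1), one_hot], dim=1))


torch.manual_seed(0)
critic = ToyCritic(num_features=4 * 16, num_classes=3)
real = torch.randn(8, 4, 16)           # placeholder "real" signals
fake = torch.randn(8, 4, 16)           # placeholder generated signals
labels = torch.randint(0, 3, (8,))

weight_gradient_penalty = 1.0          # same default as the trainer argument
gp = gradient_penalty(critic, real, fake, labels)
# Critic loss: Wasserstein estimate plus the weighted penalty.
d_loss = (critic(fake, labels).mean() - critic(real, labels).mean()
          + weight_gradient_penalty * gp)
```

The weight_gradient_penalty argument of the trainer plays the role of the multiplier on gp in the last line; larger values enforce the unit-gradient-norm constraint more strongly at the expense of the adversarial term.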
g_model = BGenerator(in_channels=128)
d_model = BDiscriminator(in_channels=4)
trainer = CWGANGPTrainer(g_model, d_model)
trainer.fit(train_loader, val_loader)
trainer.test(test_loader)
- Parameters:
generator (nn.Module) – The generator model for EEG signal generation, whose inputs are Gaussian-distributed random vectors and whose outputs are generated EEG signals. The dimension of the input vector should be defined on the in_channels attribute. The output layer does not need a softmax activation function.
discriminator (nn.Module) – The discriminator model that determines whether an EEG signal is real or generated. The dimension of its output should be one (i.e., the score that distinguishes real from fake). The output layer does not need a sigmoid activation function.
generator_lr (float) – The learning rate of the generator. (default: 0.0001)
discriminator_lr (float) – The learning rate of the discriminator. (default: 0.0001)
weight_gradient_penalty (float) – The weight of the gradient penalty loss, which trades off the adversarial training loss against the gradient penalty loss. (default: 1.0)
weight_decay (float) – The weight decay (L2 penalty). (default: 0.0)
latent_channels (int) – The dimension of the latent vector. If not specified, it will be inferred from the in_channels attribute of the generator. (default: None)
devices (int) – The number of GPUs to use. (default: 1)
accelerator (str) – The accelerator to use. Available options are: 'cpu', 'gpu'. (default: "cpu")
metrics (List[str]) – The metrics to use. Available options are: 'fid', 'is'. (default: [])
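The requirements placed on the two models above (an in_channels attribute on the generator, an unactivated output, a single score from the discriminator) can be illustrated with a pair of toy modules. This is a sketch under stated assumptions: ToyGenerator, ToyDiscriminator, and build_trainer are hypothetical names, the output shape (4, 9, 9) and the three classes are placeholders, and the two-argument forward(input, label) signature is an assumption about how the conditional trainer calls the models. Only the constructor arguments are taken from the signature documented here.

```python
import math

import torch
import torch.nn as nn


class ToyGenerator(nn.Module):
    """Hypothetical conditional generator: (latent vector, label) -> signal."""

    def __init__(self, in_channels: int = 128, num_classes: int = 3,
                 out_shape=(4, 9, 9)):
        super().__init__()
        self.in_channels = in_channels  # used when latent_channels is None
        self.out_shape = tuple(out_shape)
        self.embed = nn.Embedding(num_classes, in_channels)
        self.net = nn.Sequential(
            nn.Linear(in_channels, 256), nn.ReLU(),
            nn.Linear(256, math.prod(self.out_shape)))

    def forward(self, z: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        h = z * self.embed(y)  # condition the latent vector on the label
        return self.net(h).view(-1, *self.out_shape)  # no softmax on the output


class ToyDiscriminator(nn.Module):
    """Hypothetical conditional discriminator: (signal, label) -> one raw score."""

    def __init__(self, in_channels: int = 4, num_classes: int = 3,
                 grid=(9, 9)):
        super().__init__()
        self.in_channels = in_channels
        num_features = in_channels * grid[0] * grid[1]
        self.embed = nn.Embedding(num_classes, num_features)
        self.net = nn.Sequential(
            nn.Linear(num_features, 256), nn.LeakyReLU(0.2),
            nn.Linear(256, 1))

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        h = x.flatten(1) * self.embed(y)
        return self.net(h)  # raw score, no sigmoid


def build_trainer(generator: nn.Module, discriminator: nn.Module):
    # Imported lazily so the toy models can be checked without torcheeg installed.
    from torcheeg.trainers import CWGANGPTrainer
    return CWGANGPTrainer(generator, discriminator,
                          generator_lr=1e-4, discriminator_lr=1e-4,
                          weight_gradient_penalty=1.0, accelerator='cpu')


torch.manual_seed(0)
g_model = ToyGenerator()
d_model = ToyDiscriminator()
z = torch.randn(8, g_model.in_channels)  # Gaussian latent vectors
y = torch.randint(0, 3, (8,))            # placeholder class labels
fake = g_model(z, y)
score = d_model(fake, y)
```

Calling build_trainer(g_model, d_model) would then hand both modules to the trainer; it is left uncalled here so that the shape checks do not depend on torcheeg being available.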
- fit(train_loader: DataLoader, val_loader: DataLoader, max_epochs: int = 300, *args, **kwargs) → Any
- Parameters:
train_loader (DataLoader) – Iterable DataLoader for traversing the training data in batches (torch.utils.data.dataloader.DataLoader, torch_geometric.loader.DataLoader, etc.).
val_loader (DataLoader) – Iterable DataLoader for traversing the validation data in batches (torch.utils.data.dataloader.DataLoader, torch_geometric.loader.DataLoader, etc.).
max_epochs (int) – Maximum number of epochs to train the model. (default: 300)
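As a rough illustration of the two loader arguments, the sketch below builds a training and a validation DataLoader from random placeholder tensors. The data shape (4, 9, 9), the three classes, the 80/20 split, and the batch size are all assumptions chosen for the example, as is the convention that each batch is a (signal, label) pair.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, random_split

torch.manual_seed(0)

# Placeholder data standing in for an EEG dataset: 100 samples with labels.
signals = torch.randn(100, 4, 9, 9)
labels = torch.randint(0, 3, (100,))
dataset = TensorDataset(signals, labels)

# Hold out 20 samples for validation.
train_set, val_set = random_split(dataset, [80, 20])
train_loader = DataLoader(train_set, batch_size=16, shuffle=True)
val_loader = DataLoader(val_set, batch_size=16, shuffle=False)

# Each batch is a (signal, label) pair.
x_batch, y_batch = next(iter(train_loader))

# With a trainer constructed as documented above, training would then be:
# trainer.fit(train_loader, val_loader, max_epochs=50)
```

Passing a smaller max_epochs than the default of 300, as in the commented call, is a convenient way to smoke-test a configuration before a full run.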