
CWGANGPTrainer

class torcheeg.trainers.CWGANGPTrainer(generator: Module, discriminator: Module, generator_lr: float = 0.0001, discriminator_lr: float = 0.0001, weight_decay: float = 0.0, weight_gradient_penalty: float = 1.0, latent_channels: int | None = None, devices: int = 1, accelerator: str = 'cpu', metrics: List[str] = [], metric_extractor: Module | None = None, metric_classifier: Module | None = None, metric_num_features: int | None = None)

This class provides an implementation of a conditional WGAN-GP (Wasserstein GAN with gradient penalty). As in traditional generative adversarial networks, it trains the generator and the discriminator in a zero-sum game: the generator is optimized to produce simulated samples that the discriminator cannot distinguish from real ones, and the discriminator is optimized to tell real samples apart from those produced by the generator. Compared with a vanilla GAN, WGAN-GP improves training stability, mitigates problems such as mode collapse, and yields loss curves that are meaningful for debugging and hyperparameter search. For these reasons, existing work commonly uses WGAN-GP to generate simulated EEG signals. In this conditional variant, the expected class labels are additionally provided so that the discriminator judges whether a sample fits the data distribution of its class. A typical usage example is shown below.

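from torcheeg.models import BCGenerator, BCDiscriminator
from torcheeg.trainers import CWGANGPTrainer

# train_loader, val_loader and test_loader are assumed to be defined elsewhere
g_model = BCGenerator(in_channels=128, num_classes=3)
d_model = BCDiscriminator(in_channels=4, num_classes=3)
trainer = CWGANGPTrainer(g_model, d_model)
trainer.fit(train_loader, val_loader)
trainer.test(test_loader)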
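The adversarial objective described above can be written out as the standard conditional WGAN-GP losses. This is the textbook formulation from the WGAN-GP literature with labels added to both networks, given here as a reference under that assumption; TorchEEG's internal loss may differ in details. Here x is a real sample with label y, z is the latent vector (its size corresponds to latent_channels), and λ corresponds to weight_gradient_penalty.

```latex
\begin{aligned}
\hat{x} &= \epsilon\, x + (1 - \epsilon)\, G(z, y), \qquad \epsilon \sim \mathcal{U}[0, 1], \quad z \sim \mathcal{N}(0, I) \\
\mathcal{L}_D &= \mathbb{E}_{z, y}\big[D(G(z, y), y)\big] - \mathbb{E}_{(x, y)}\big[D(x, y)\big]
  + \lambda\, \mathbb{E}_{\hat{x}, y}\Big[\big(\lVert \nabla_{\hat{x}} D(\hat{x}, y) \rVert_2 - 1\big)^2\Big] \\
\mathcal{L}_G &= -\,\mathbb{E}_{z, y}\big[D(G(z, y), y)\big]
\end{aligned}
```

Because D outputs an unbounded score rather than a probability, no sigmoid is applied to the discriminator output.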
Parameters:
  • generator (nn.Module) – The generator model for EEG signal generation. Its inputs are Gaussian-distributed random vectors and its outputs are generated EEG signals. The dimension of the input vector should be defined in the in_channels attribute. The output layer does not need a softmax activation function.

  • discriminator (nn.Module) – The discriminator model that determines whether an EEG signal is real or generated. Its output dimension should be one (i.e., a single score that separates real from fake samples). The output layer does not need a sigmoid activation function.

  • generator_lr (float) – The learning rate of the generator. (default: 0.0001)

  • discriminator_lr (float) – The learning rate of the discriminator. (default: 0.0001)

  • weight_gradient_penalty (float) – The weight of gradient penalty loss to trade-off between the adversarial training loss and gradient penalty loss. (default: 1.0)

  • weight_decay (float) – The weight decay (L2 penalty). (default: 0.0)

  • latent_channels (int) – The dimension of the latent vector. If not specified, it will be inferred from the in_channels attribute of the generator. (default: None)

  • devices (int) – The number of GPUs to use. (default: 1)

  • accelerator (str) – The accelerator to use. Available options are: ‘cpu’, ‘gpu’. (default: "cpu")

  • metrics (List[str]) – The metrics to use. Available options are: ‘fid’ (Fréchet Inception Distance) and ‘is’ (Inception Score). (default: [])
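To make the roles of weight_gradient_penalty and latent_channels concrete, here is a minimal sketch of one conditional WGAN-GP loss computation in plain PyTorch. It assumes the standard two-sided gradient penalty; ToyGenerator, ToyCritic, resolve_latent_channels, gradient_penalty, critic_loss and generator_loss are hypothetical names written for illustration, not TorchEEG APIs, and TorchEEG's own training step may be organized differently.

```python
import math

import torch
import torch.nn.functional as F
from torch import nn


class ToyGenerator(nn.Module):
    """Hypothetical conditional generator: (latent z, label y) -> sample."""

    def __init__(self, in_channels: int, out_shape: tuple, num_classes: int):
        super().__init__()
        self.in_channels = in_channels  # latent size, exposed as an attribute
        self.out_shape = tuple(out_shape)
        self.num_classes = num_classes
        self.net = nn.Linear(in_channels + num_classes, math.prod(self.out_shape))

    def forward(self, z: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        y_onehot = F.one_hot(y, self.num_classes).float()
        return self.net(torch.cat([z, y_onehot], dim=1)).view(-1, *self.out_shape)


class ToyCritic(nn.Module):
    """Hypothetical conditional critic: (sample x, label y) -> one unbounded score."""

    def __init__(self, in_features: int, num_classes: int):
        super().__init__()
        self.num_classes = num_classes
        self.net = nn.Linear(in_features + num_classes, 1)  # no sigmoid

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        y_onehot = F.one_hot(y, self.num_classes).float()
        return self.net(torch.cat([x.flatten(1), y_onehot], dim=1))


def resolve_latent_channels(generator: nn.Module, latent_channels=None) -> int:
    # Documented fallback: use generator.in_channels when latent_channels is unset.
    return latent_channels if latent_channels is not None else generator.in_channels


def gradient_penalty(critic, real, fake, y):
    # One interpolation coefficient per sample, broadcast over the other dims.
    eps = torch.rand(real.size(0), *([1] * (real.dim() - 1)), device=real.device)
    x_hat = (eps * real + (1 - eps) * fake).requires_grad_(True)
    scores = critic(x_hat, y)
    # create_graph=True so the penalty itself can be backpropagated.
    (grads,) = torch.autograd.grad(scores.sum(), x_hat, create_graph=True)
    return ((grads.flatten(1).norm(2, dim=1) - 1) ** 2).mean()


def critic_loss(critic, real, fake, y, weight_gradient_penalty: float = 1.0):
    fake = fake.detach()  # the critic step must not update the generator
    wasserstein = critic(fake, y).mean() - critic(real, y).mean()
    return wasserstein + weight_gradient_penalty * gradient_penalty(critic, real, fake, y)


def generator_loss(critic, fake, y):
    # The generator tries to raise the critic's score on its own samples.
    return -critic(fake, y).mean()
```

A larger weight_gradient_penalty pushes the critic's gradient norm more strongly toward 1 (the Lipschitz constraint WGAN relies on), at the cost of giving the Wasserstein term relatively less influence.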

fit(train_loader: DataLoader, val_loader: DataLoader, max_epochs: int = 300, *args, **kwargs) → Any
Parameters:
  • train_loader (DataLoader) – Iterable DataLoader for traversing batches of training data (torch.utils.data.dataloader.DataLoader, torch_geometric.loader.DataLoader, etc.).

  • val_loader (DataLoader) – Iterable DataLoader for traversing batches of validation data (torch.utils.data.dataloader.DataLoader, torch_geometric.loader.DataLoader, etc.).

  • max_epochs (int) – Maximum number of epochs to train the model. (default: 300)
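For trying out fit without a real dataset, the sketch below builds training and validation loaders from random tensors. make_loader is a hypothetical helper; the (4, 9, 9) sample shape, the three classes, and the assumption that each batch is a (signal, label) pair are illustrative choices for a conditional setup, not requirements stated in this documentation.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def make_loader(num_samples: int, batch_size: int, shuffle: bool, seed: int = 0) -> DataLoader:
    """Hypothetical helper: random (signal, label) pairs for smoke-testing only."""
    g = torch.Generator().manual_seed(seed)
    # Assumed sample layout: 4 feature channels on a 9x9 electrode grid.
    x = torch.randn(num_samples, 4, 9, 9, generator=g)
    # Assumed 3 classes, used as the conditioning labels.
    y = torch.randint(0, 3, (num_samples,), generator=g)
    return DataLoader(TensorDataset(x, y), batch_size=batch_size, shuffle=shuffle)


train_loader = make_loader(64, batch_size=16, shuffle=True)
val_loader = make_loader(32, batch_size=16, shuffle=False, seed=1)
```

With a trainer constructed as in the class example, these loaders would be passed as trainer.fit(train_loader, val_loader, max_epochs=10); in practice the random tensors would be replaced by a real EEG dataset.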

test(test_loader: DataLoader, *args, **kwargs) → List[Dict[str, float]]
Parameters:
  • test_loader (DataLoader) – Iterable DataLoader for traversing batches of test data (torch.utils.data.dataloader.DataLoader, torch_geometric.loader.DataLoader, etc.).
