CBetaVAETrainer

class torcheeg.trainers.CBetaVAETrainer(encoder: Module, decoder: Module, lr: float = 0.0001, weight_decay: float = 0.0, beta: float = 1.0, devices: int = 1, accelerator: str = 'cpu', metrics: List[str] = [], metric_extractor: Module | None = None, metric_classifier: Module | None = None, metric_num_features: int | None = None)

This class provides the implementation of BetaVAE training. A variational autoencoder consists of two parts, an encoder and a decoder. The encoder compresses the input into the latent space. The decoder takes a sample drawn from the latent space and reconstructs a signal as close as possible to the ground truth. The distribution of the latent vector is pushed toward a Gaussian by a KL divergence term, with sampling made differentiable through the reparameterization trick. This class implements training, testing, and inference on new EEG signals for variational autoencoders.
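The objective described above can be sketched in a few lines. This is a minimal illustration and not the trainer's actual code: it uses plain Python lists instead of tensors so the arithmetic stays visible, it assumes the encoder's second output is a log-variance, and it assumes a mean squared error reconstruction term. The function names are hypothetical.

```python
import math
import random


def reparameterize(mu, logvar, rng):
    """Sample z = mu + sigma * eps with eps ~ N(0, 1), element-wise."""
    return [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0)
            for m, lv in zip(mu, logvar)]


def kl_to_standard_normal(mu, logvar):
    """Closed-form KL(N(mu, diag(exp(logvar))) || N(0, I))."""
    return 0.5 * sum(m * m + math.exp(lv) - 1.0 - lv
                     for m, lv in zip(mu, logvar))


def beta_vae_loss(x, x_rec, mu, logvar, beta=1.0):
    """Reconstruction error (MSE, an assumption) plus beta times the KL term."""
    rec = sum((a - b) ** 2 for a, b in zip(x, x_rec)) / len(x)
    return rec + beta * kl_to_standard_normal(mu, logvar)


rng = random.Random(0)
z = reparameterize([0.5, -0.5], [0.0, 0.0], rng)  # one latent sample
# The KL term is 0.0 when the posterior already equals N(0, I).
print(kl_to_standard_normal([0.0, 0.0], [0.0, 0.0]))
```

With beta equal to 1 this reduces to the standard VAE objective; larger values weight the KL term more heavily relative to reconstruction.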

  • Paper: Gulrajani I, Ahmed F, Arjovsky M, et al. Improved training of wasserstein gans[J]. Advances in neural information processing systems, 2017, 30.

  • Paper: Higgins I, Matthey L, Pal A, et al. beta-vae: Learning basic visual concepts with a constrained variational framework[C]//International conference on learning representations. 2017.

  • URL: https://arxiv.org/abs/1704.00028

  • Related Project: https://github.com/eriklindernoren/PyTorch-GAN

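from torcheeg.models import BCEncoder, BCDecoder
from torcheeg.trainers import CBetaVAETrainer

# train_loader, val_loader and test_loader are DataLoader objects prepared beforehand
encoder = BCEncoder(in_channels=4, num_classes=2)
decoder = BCDecoder(in_channels=64, out_channels=4, num_classes=2)
trainer = CBetaVAETrainer(encoder, decoder)
trainer.fit(train_loader, val_loader)
trainer.test(test_loader)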
Parameters:
  • encoder (nn.Module) – The encoder, whose inputs are EEG signals, outputs are two batches of vectors of the same dimension, representing the mean and variance estimated in the reparameterization trick.

  • decoder (nn.Module) – The decoder generating EEG signals from hidden variables encoded by the encoder. The decoder of CVAE should have an additional input, which is the label of the EEG signal to be generated.

  • lr (float) – The learning rate. (default: 0.0001)

  • weight_decay (float) – The weight decay (L2 penalty). (default: 0.0)

  • beta (float) – The weight of the KL divergence in the loss function. When beta is 1, the model is a standard VAE. (default: 1.0)

  • devices (int) – The number of devices to use. (default: 1)

  • accelerator (str) – The accelerator to use. Available options are: ‘cpu’, ‘gpu’. (default: "cpu")

  • metrics (list of str) – The metrics to use. Available options are: ‘fid’, ‘is’. (default: [])

  • metric_extractor (nn.Module) – The feature extraction model used to calculate the FID and IS metrics. (default: None)

  • metric_classifier (nn.Module) – The classification model used to calculate the IS metric. (default: None)

  • metric_num_features (int) – The number of features extracted by the feature extraction model. (default: None)

fit(train_loader: DataLoader, val_loader: DataLoader, max_epochs: int = 300, *args, **kwargs) → Any
Parameters:
  • train_loader (DataLoader) – Iterable DataLoader for traversing the training data batch (torch.utils.data.dataloader.DataLoader, torch_geometric.loader.DataLoader, etc).

  • val_loader (DataLoader) – Iterable DataLoader for traversing the validation data batch (torch.utils.data.dataloader.DataLoader, torch_geometric.loader.DataLoader, etc).

  • max_epochs (int) – Maximum number of epochs to train the model. (default: 300)

test(test_loader: DataLoader, *args, **kwargs) → List[Dict[str, float]]
Parameters:
  • test_loader (DataLoader) – Iterable DataLoader for traversing the test data batch (torch.utils.data.dataloader.DataLoader, torch_geometric.loader.DataLoader, etc).
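As a rough sketch of how the three loaders passed to fit and test might be prepared, the snippet below builds them from random stand-in tensors with torch.utils.data. The sample shape (4, 9, 9), the binary labels, the 70/15/15 split, and the batch size are all hypothetical, and the assumption that each batch is a (signal, label) pair is not confirmed by this page.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, random_split

# Hypothetical stand-in data: 100 samples shaped (4, 9, 9) with binary labels.
signals = torch.randn(100, 4, 9, 9)
labels = torch.randint(0, 2, (100,))
dataset = TensorDataset(signals, labels)

# Fixed seed so the split is reproducible.
generator = torch.Generator().manual_seed(0)
train_set, val_set, test_set = random_split(
    dataset, [70, 15, 15], generator=generator)

# Shuffle only the training data; validation and test order does not matter.
train_loader = DataLoader(train_set, batch_size=16, shuffle=True)
val_loader = DataLoader(val_set, batch_size=16)
test_loader = DataLoader(test_set, batch_size=16)
```

In practice the TensorDataset would be replaced by a real EEG dataset; the loaders are then passed as trainer.fit(train_loader, val_loader) and trainer.test(test_loader).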
