Shortcuts

Source code for torcheeg.trainers.generative.beta_vae

import logging
import warnings
from itertools import chain
from typing import Any, Dict, List, Tuple

import pytorch_lightning as pl
import torch
import torch.nn as nn
import torchmetrics
from torch.utils.data import DataLoader
from torchmetrics.image.inception import InceptionScore

from .utils import FrechetInceptionDistance

_EVALUATE_OUTPUT = List[Dict[str, float]]  # 1 dict per DataLoader

log = logging.getLogger(__name__)

[docs]class BetaVAETrainer(pl.LightningModule): r''' This class provide the implementation for BetaVAE training. The variational autoencoder consists of two parts, an encoder, and a decoder. The encoder compresses the input into the latent space. The decoder receives as input the information sampled from the latent space and produces it as similar as possible to ground truth. The latent vector should approach the gaussian distribution supervised by KL divergence based on the variation trick. This class implement the training, test, and new EEG inference of variational autoencoders. - Paper: Gulrajani I, Ahmed F, Arjovsky M, et al. Improved training of wasserstein gans[J]. Advances in neural information processing systems, 2017, 30. - URL: https://arxiv.org/abs/1704.00028 - Related Project: https://github.com/eriklindernoren/PyTorch-GAN .. code-block:: python encoder = BEncoder(in_channels=4) decoder = BDecoder(in_channels=64, out_channels=4) trainer = BetaVAETrainer(encoder, decoder) trainer.fit(train_loader, val_loader) trainer.test(test_loader) Args: encoder (nn.Module): The encoder, whose inputs are EEG signals, outputs are two batches of vectors of the same dimension, representing the mean and variance estimated in the reparameterization trick. decoder (nn.Module): The decoder generating EEG signals from hidden variables encoded by the encoder. lr (float): The learning rate. (default: :obj:`0.0001`) weight_decay: (float): The weight decay (L2 penalty). (default: :obj:`0.0`) beta: (float): The weight of the KL divergence in the loss function. When beta is 1, the model is a standard VAE. (default: :obj:`1.0`) devices (int): The number of devices to use. (default: :obj:`1`) accelerator (str): The accelerator to use. Available options are: 'cpu', 'gpu'. (default: :obj:`"cpu"`) metrics (list of str): The metrics to use. Available options are: 'fid', 'is'. Due to the time-consuming generation process, these indicators will only be calculated and printed during test. (default: :obj:`[]`) metric_extractor (nn.Module): The feature extraction model used to calculate the FID and IS metrics. (default: :obj:`None`) metric_classifier (nn.Module): The classification model used to calculate the IS metric. (default: :obj:`None`) metric_num_features (int): The number of features extracted by the feature extraction model. (default: :obj:`None`) .. automethod:: fit .. automethod:: test ''' def __init__(self, encoder: nn.Module, decoder: nn.Module, lr: float = 1e-4, weight_decay: float = 0.0, beta: float = 1.0, devices: int = 1, accelerator: str = "cpu", metrics: List[str] = [], metric_extractor: nn.Module = None, metric_classifier: nn.Module = None, metric_num_features: int = None): super().__init__() self.encoder = encoder self.decoder = decoder self.lr = lr self.weight_decay = weight_decay self.beta = beta self.devices = devices self.accelerator = accelerator self.metrics = metrics self.mse_fn = nn.MSELoss() self.metric_extractor = metric_extractor self.metric_classifier = metric_classifier self.metric_num_features = metric_num_features self.init_metrics(metrics) def init_metrics(self, metrics) -> None: self.train_rec_loss = torchmetrics.MeanMetric() self.val_rec_loss = torchmetrics.MeanMetric() self.test_rec_loss = torchmetrics.MeanMetric() self.train_kld_loss = torchmetrics.MeanMetric() self.val_kld_loss = torchmetrics.MeanMetric() self.test_kld_loss = torchmetrics.MeanMetric() if 'fid' in metrics: assert not self.metric_extractor is None, 'The metric_extractor should be specified.' if hasattr(self.metric_extractor, 'in_channels') and self.metric_num_features is None: warnings.warn( f'No metric_num_features specified, use metric_extractor.in_channels ({self.metric_extractor.in_channels}) as metric_num_features.' ) self.metric_num_features = self.metric_extractor.in_channels assert not self.metric_num_features is None, 'The metric_num_features should be specified.' self.test_fid = FrechetInceptionDistance(self.metric_extractor, self.metric_num_features) if 'is' in metrics: assert not self.metric_extractor is None, 'The metric_classifier should be specified.' self.test_is = InceptionScore(self.metric_classifier)
[docs] def fit(self, train_loader: DataLoader, val_loader: DataLoader, max_epochs: int = 300, *args, **kwargs) -> Any: r''' Args: train_loader (DataLoader): Iterable DataLoader for traversing the training data batch (:obj:`torch.utils.data.dataloader.DataLoader`, :obj:`torch_geometric.loader.DataLoader`, etc). val_loader (DataLoader): Iterable DataLoader for traversing the validation data batch (:obj:`torch.utils.data.dataloader.DataLoader`, :obj:`torch_geometric.loader.DataLoader`, etc). max_epochs (int): Maximum number of epochs to train the model. (default: :obj:`300`) ''' trainer = pl.Trainer(devices=self.devices, accelerator=self.accelerator, max_epochs=max_epochs, inference_mode=False, *args, **kwargs) return trainer.fit(self, train_loader, val_loader)
[docs] def test(self, test_loader: DataLoader, *args, **kwargs) -> _EVALUATE_OUTPUT: r''' Args: test_loader (DataLoader): Iterable DataLoader for traversing the test data batch (torch.utils.data.dataloader.DataLoader, torch_geometric.loader.DataLoader, etc). ''' trainer = pl.Trainer(devices=self.devices, accelerator=self.accelerator, inference_mode=False, *args, **kwargs) return trainer.test(self, test_loader)
def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor: std = torch.exp(0.5 * logvar) eps = torch.randn_like(std) return eps * std + mu def forward(self, latent: torch.Tensor) -> torch.Tensor: return self.decoder(latent) def predict_step(self, batch: Tuple[torch.Tensor], batch_idx: int, dataloader_idx: int = 0, random: bool = True) -> torch.Tensor: x, _ = batch try: mu, log_var = self.encoder(x) # should return two vectors, one for mu, one for log_var # otherwise, re-raising the error except Exception as exc: raise ValueError( 'The encoder should return two vectors, one for mu, one for log_var.' ) from exc latent = self.reparameterize(mu, log_var) # sample from random instead of encoded latent if random: latent = torch.normal(mean=0, std=1, size=latent.shape).type_as(x) return self(latent) def training_step(self, batch: Tuple[torch.Tensor], batch_idx: int) -> torch.Tensor: x, _ = batch try: mu, log_var = self.encoder(x) # should return two vectors, one for mu, one for log_var # otherwise, re-raising the error except Exception as exc: raise ValueError( 'The encoder should return two vectors, one for mu, one for log_var.' ) from exc latent = self.reparameterize(mu, log_var) rec_x = self.decoder(latent) rec_loss = self.mse_fn(rec_x, x) kld_loss = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp()) loss = rec_loss + self.beta * kld_loss self.log("train_rec_loss", self.train_rec_loss(rec_loss), prog_bar=True, on_epoch=False, logger=False, on_step=True) self.log("train_kld_loss", self.train_kld_loss(kld_loss), prog_bar=True, on_epoch=False, logger=False, on_step=True) return loss def on_train_epoch_end(self) -> None: self.log("train_rec_loss", self.train_rec_loss.compute(), prog_bar=False, on_epoch=True, on_step=False, logger=True) self.log("train_kld_loss", self.train_kld_loss.compute(), prog_bar=False, on_epoch=True, on_step=False, logger=True) # print the metrics str = "\n[Train] " for key, value in self.trainer.logged_metrics.items(): if key.startswith("train_"): str += f"{key}: {value:.3f} " print(str + '\n') # reset the metrics self.train_rec_loss.reset() self.train_kld_loss.reset() @torch.enable_grad() def validation_step(self, batch: Tuple[torch.Tensor], batch_idx: int) -> torch.Tensor: x, _ = batch try: mu, log_var = self.encoder(x) # should return two vectors, one for mu, one for log_var # otherwise, re-raising the error except Exception as exc: raise ValueError( 'The encoder should return two vectors, one for mu, one for log_var.' ) from exc latent = self.reparameterize(mu, log_var) rec_x = self.decoder(latent) rec_loss = self.mse_fn(rec_x, x) kld_loss = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp()) loss = rec_loss + self.beta * kld_loss self.val_rec_loss.update(rec_loss) self.val_kld_loss.update(kld_loss) return loss def on_validation_epoch_end(self) -> None: self.log("val_rec_loss", self.val_rec_loss.compute(), prog_bar=False, on_epoch=True, on_step=False, logger=True) self.log("val_kld_loss", self.val_kld_loss.compute(), prog_bar=False, on_epoch=True, on_step=False, logger=True) # print the metrics str = "\n[VAL] " for key, value in self.trainer.logged_metrics.items(): if key.startswith("val_"): str += f"{key}: {value:.3f} " print(str + '\n') # reset the metrics self.val_rec_loss.reset() self.val_kld_loss.reset() @torch.enable_grad() def test_step(self, batch: Tuple[torch.Tensor], batch_idx: int) -> torch.Tensor: x, _ = batch try: mu, log_var = self.encoder(x) # should return two vectors, one for mu, one for log_var # otherwise, re-raising the error except Exception as exc: raise ValueError( 'The encoder should return two vectors, one for mu, one for log_var.' ) from exc latent = self.reparameterize(mu, log_var) rec_x = self.decoder(latent) rec_loss = self.mse_fn(rec_x, x) kld_loss = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp()) loss = rec_loss + self.beta * kld_loss self.test_rec_loss.update(rec_loss) self.test_kld_loss.update(kld_loss) # sample from random instead of encoded latent latent = torch.normal(mean=0, std=1, size=latent.shape).type_as(x) gen_x = self.decoder(latent) if 'fid' in self.metrics: self.test_fid.update(x, real=True) self.test_fid.update(gen_x, real=False) if 'is' in self.metrics: self.test_is.update(gen_x) return loss def on_test_epoch_end(self) -> None: self.log("test_rec_loss", self.test_rec_loss.compute(), prog_bar=False, on_epoch=True, on_step=False, logger=True) self.log("test_kld_loss", self.test_kld_loss.compute(), prog_bar=False, on_epoch=True, on_step=False, logger=True) if 'fid' in self.metrics: self.log("test_fid", self.test_fid.compute(), prog_bar=False, on_epoch=True, on_step=False, logger=True) if 'is' in self.metrics: self.log("test_is", self.test_is.compute()[0], prog_bar=False, on_epoch=True, on_step=False, logger=True) # print the metrics str = "\n[TEST] " for key, value in self.trainer.logged_metrics.items(): if key.startswith("test_"): str += f"{key}: {value:.3f} " print(str + '\n') # reset the metrics self.test_rec_loss.reset() self.test_kld_loss.reset() if 'fid' in self.metrics: self.test_fid.reset() if 'is' in self.metrics: self.test_is.reset() def configure_optimizers(self): return torch.optim.Adam(chain(self.encoder.parameters(), self.decoder.parameters()), lr=self.lr, weight_decay=self.weight_decay)
[docs]class CBetaVAETrainer(BetaVAETrainer): r''' This class provide the implementation for BetaVAE training. The variational autoencoder consists of two parts, an encoder, and a decoder. The encoder compresses the input into the latent space. The decoder receives as input the information sampled from the latent space and produces it as similar as possible to ground truth. The latent vector should approach the gaussian distribution supervised by KL divergence based on the variation trick. This class implement the training, test, and new EEG inference of variational autoencoders. - Paper: Gulrajani I, Ahmed F, Arjovsky M, et al. Improved training of wasserstein gans[J]. Advances in neural information processing systems, 2017, 30. - Paper: Higgins I, Matthey L, Pal A, et al. beta-vae: Learning basic visual concepts with a constrained variational framework[C]//International conference on learning representations. 2017. - URL: https://arxiv.org/abs/1704.00028 - Related Project: https://github.com/eriklindernoren/PyTorch-GAN .. code-block:: python encoder = BCEncoder(in_channels=4, num_classes=2) decoder = BCDecoder(in_channels=64, out_channels=4, num_classes=2) trainer = CVAETrainer(encoder, decoder) trainer.fit(train_loader, val_loader) trainer.test(test_loader) Args: encoder (nn.Module): The encoder, whose inputs are EEG signals, outputs are two batches of vectors of the same dimension, representing the mean and variance estimated in the reparameterization trick. decoder (nn.Module): The decoder generating EEG signals from hidden variables encoded by the encoder. The decoder of CVAE should have an additional input, which is the label of the EEG signal to be generated. lr (float): The learning rate. (default: :obj:`0.0001`) weight_decay: (float): The weight decay (L2 penalty). (default: :obj:`0.0`) beta: (float): The weight of the KL divergence in the loss function. When beta is 1, the model is a standard VAE. (default: :obj:`1.0`) devices (int): The number of devices to use. (default: :obj:`1`) accelerator (str): The accelerator to use. Available options are: 'cpu', 'gpu'. (default: :obj:`"cpu"`) metrics (list of str): The metrics to use. Available options are: 'fid', 'is'. (default: :obj:`[]`) metric_extractor (nn.Module): The feature extraction model used to calculate the FID and IS metrics. (default: :obj:`None`) metric_classifier (nn.Module): The classification model used to calculate the IS metric. (default: :obj:`None`) metric_num_features (int): The number of features extracted by the feature extraction model. (default: :obj:`None`) .. automethod:: fit .. automethod:: test ''' def __init__(self, encoder: nn.Module, decoder: nn.Module, lr: float = 1e-4, weight_decay: float = 0.0, beta: float = 1.0, devices: int = 1, accelerator: str = "cpu", metrics: List[str] = [], metric_extractor: nn.Module = None, metric_classifier: nn.Module = None, metric_num_features: int = None): super(CBetaVAETrainer, self).__init__(encoder, decoder, lr, weight_decay, beta, devices, accelerator, metrics, metric_extractor, metric_classifier, metric_num_features) def training_step(self, batch: Tuple[torch.Tensor], batch_idx: int) -> torch.Tensor: x, y = batch try: mu, log_var = self.encoder(x, y) # should return two vectors, one for mu, one for log_var # otherwise, re-raising the error except Exception as exc: raise ValueError( 'The encoder should return two vectors, one for mu, one for log_var.' ) from exc latent = self.reparameterize(mu, log_var) rec_x = self.decoder(latent, y) rec_loss = self.mse_fn(rec_x, x) kld_loss = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp()) loss = rec_loss + self.beta * kld_loss self.log("train_rec_loss", self.train_rec_loss(rec_loss), prog_bar=True, on_epoch=False, logger=False, on_step=True) self.log("train_kld_loss", self.train_kld_loss(kld_loss), prog_bar=True, on_epoch=False, logger=False, on_step=True) return loss @torch.enable_grad() def validation_step(self, batch: Tuple[torch.Tensor], batch_idx: int) -> torch.Tensor: x, y = batch try: mu, log_var = self.encoder(x, y) # should return two vectors, one for mu, one for log_var # otherwise, re-raising the error except Exception as exc: raise ValueError( 'The encoder should return two vectors, one for mu, one for log_var.' ) from exc latent = self.reparameterize(mu, log_var) rec_x = self.decoder(latent, y) rec_loss = self.mse_fn(rec_x, x) kld_loss = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp()) loss = rec_loss + self.beta * kld_loss self.val_rec_loss.update(rec_loss) self.val_kld_loss.update(kld_loss) return loss @torch.enable_grad() def test_step(self, batch: Tuple[torch.Tensor], batch_idx: int) -> torch.Tensor: x, y = batch try: mu, log_var = self.encoder(x, y) # should return two vectors, one for mu, one for log_var # otherwise, re-raising the error except Exception as exc: raise ValueError( 'The encoder should return two vectors, one for mu, one for log_var.' ) from exc latent = self.reparameterize(mu, log_var) rec_x = self.decoder(latent, y) rec_loss = self.mse_fn(rec_x, x) kld_loss = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp()) self.test_rec_loss.update(rec_loss) self.test_kld_loss.update(kld_loss) # sample from random instead of encoded latent latent = torch.normal(mean=0, std=1, size=latent.shape).type_as(x) gen_x = self.decoder(latent, y) if 'fid' in self.metrics: self.test_fid.update(x, real=True) self.test_fid.update(gen_x, real=False) if 'is' in self.metrics: self.test_is.update(gen_x) return rec_loss, kld_loss def forward(self, latent: torch.Tensor, y: torch.Tensor) -> torch.Tensor: return self.decoder(latent, y) def predict_step(self, batch: Tuple[torch.Tensor], batch_idx: int, dataloader_idx: int = 0, random: bool = True) -> torch.Tensor: x, y = batch try: mu, log_var = self.encoder(x, y) # should return two vectors, one for mu, one for log_var # otherwise, re-raising the error except Exception as exc: raise ValueError( 'The encoder should return two vectors, one for mu, one for log_var.' ) from exc latent = self.reparameterize(mu, log_var) # sample from random instead of encoded latent if random: latent = torch.normal(mean=0, std=1, size=latent.shape).type_as(x) return self(latent, y)

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources