torcheeg.trainers

trainers.BasicTrainer

class torcheeg.trainers.BasicTrainer(modules: Dict, device_ids: List[int] = [], ddp_sync_bn: bool = True, ddp_replace_sampler: bool = True, ddp_val: bool = True, ddp_test: bool = True)

A generic trainer class for EEG analysis. It provides the interfaces that all trainers use to implement the routines common to training deep learning models. Subclasses must implement on_training_step, on_validation_step and on_test_step.

The class provides the following hook functions for inserting additional implementations in the training, validation and testing lifecycle:

  • before_training_epoch: executed before each epoch of training starts

  • before_training_step: executed before each batch of training starts

  • on_training_step: the training process for each batch

  • after_training_step: executed after the training of each batch

  • after_training_epoch: executed after each epoch of training

  • before_validation_epoch: executed before each round of validation starts

  • before_validation_step: executed before the validation of each batch

  • on_validation_step: validation process for each batch

  • after_validation_step: executed after the validation of each batch

  • after_validation_epoch: executed after each round of validation

  • before_test_epoch: executed before each round of test starts

  • before_test_step: executed before the test of each batch

  • on_test_step: test process for each batch

  • after_test_step: executed after the test of each batch

  • after_test_epoch: executed after each round of test
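The order in which these hooks fire can be pictured with the simplified, pure-Python stand-in below. SketchTrainer and MyTrainer are hypothetical classes written only for illustration, not torcheeg's BasicTrainer: the real trainer also moves data between devices, handles distributed sampling and passes its own arguments to each hook. The assumption that a validation round follows every training epoch is part of the sketch.

```python
from typing import Callable, List, Sequence

# Hook names as listed above: five hooks per stage.
HOOK_NAMES = [
    f"{when}_{stage}_{unit}"
    for stage in ("training", "validation", "test")
    for when, unit in (("before", "epoch"), ("before", "step"), ("on", "step"),
                       ("after", "step"), ("after", "epoch"))
]


def _make_hook(name: str) -> Callable:
    # Default hook: record its own name so the call order can be inspected.
    def hook(self, *args, **kwargs) -> None:
        self.calls.append(name)
    return hook


class SketchTrainer:
    """Hypothetical stand-in for a hook-based trainer (not torcheeg code)."""

    def __init__(self) -> None:
        self.calls: List[str] = []

    def _run_stage(self, stage: str, loader: Sequence, *epoch_args) -> None:
        # One pass over a loader, wrapped in the epoch- and step-level hooks.
        getattr(self, f"before_{stage}_epoch")(*epoch_args)
        num_batches = len(loader)
        for batch_id, batch in enumerate(loader):
            getattr(self, f"before_{stage}_step")(batch_id, num_batches)
            getattr(self, f"on_{stage}_step")(batch, batch_id, num_batches)
            getattr(self, f"after_{stage}_step")(batch_id, num_batches)
        getattr(self, f"after_{stage}_epoch")(*epoch_args)

    def fit(self, train_loader: Sequence, val_loader: Sequence, num_epochs: int = 1) -> None:
        # Assumed order: a validation round follows every training epoch.
        for epoch_id in range(1, num_epochs + 1):
            self._run_stage("training", train_loader, epoch_id, num_epochs)
            self._run_stage("validation", val_loader, epoch_id, num_epochs)

    def test(self, test_loader: Sequence) -> None:
        self._run_stage("test", test_loader)


# Attach the default (recording) implementation of every hook.
for _name in HOOK_NAMES:
    setattr(SketchTrainer, _name, _make_hook(_name))


class MyTrainer(SketchTrainer):
    # Override one hook and delegate to the default implementation via super().
    def before_training_epoch(self, epoch_id: int, num_epochs: int, **kwargs) -> None:
        self.calls.append(f"epoch {epoch_id}/{num_epochs}")
        super().before_training_epoch(epoch_id, num_epochs, **kwargs)


trainer = MyTrainer()
trainer.fit(train_loader=[0, 1], val_loader=[0], num_epochs=1)
trainer.test(test_loader=[0])
print(trainer.calls)
```

Each step-level hook runs once per batch and each epoch-level hook once per pass over a loader, so custom logic should go into the hook whose granularity matches it.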

If you want to use multiple GPUs for parallel computing, specify the GPU indices to use in your Python script. Because BasicTrainer must be subclassed and takes its networks as a dictionary, MyTrainer below stands for your own subclass:

trainer = MyTrainer({'model': model}, device_ids=[1, 2, 7])
trainer.fit(train_loader, val_loader)
trainer.test(test_loader)

Then, launch your Python script with torch.distributed.launch or torchrun.

python -m torch.distributed.launch \
    --nproc_per_node=3 \
    --nnodes=1 \
    --node_rank=0 \
    --master_addr="localhost" \
    --master_port=2345 \
    your_python_file.py

Here, nproc_per_node is the number of GPUs you specify.
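PyTorch documents torchrun as the successor to the deprecated torch.distributed.launch module. The sketch below builds the torchrun equivalent of the command above from the same device_ids list, so that --nproc_per_node always matches the number of GPUs given to the trainer. torchrun_command is a hypothetical helper written for illustration. The flags are the ones used above, which torchrun also accepts. One difference to keep in mind: torch.distributed.launch may pass the local rank as a command-line argument, whereas torchrun sets the LOCAL_RANK environment variable. Check that your torcheeg version works with the launcher you choose.

```python
import shlex
from typing import List, Sequence


def torchrun_command(device_ids: Sequence[int], script: str,
                     master_addr: str = "localhost", master_port: int = 2345) -> List[str]:
    """Hypothetical helper: single-node torchrun command, one process per listed GPU."""
    if len(device_ids) < 2:
        # With zero or one device, the script does not need a distributed launcher.
        raise ValueError("a distributed launcher is only needed for two or more GPUs")
    return [
        "torchrun",
        f"--nproc_per_node={len(device_ids)}",  # one process per GPU index
        "--nnodes=1",
        "--node_rank=0",
        f"--master_addr={master_addr}",
        f"--master_port={master_port}",
        script,
    ]


# device_ids=[1, 2, 7] in the trainer -> three processes.
print(shlex.join(torchrun_command([1, 2, 7], "your_python_file.py")))
```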

Parameters
  • modules (Dict) – A dictionary of the neural networks managed by the trainer. The trainer uses it to import, export and move the networks between devices.

  • device_ids (list) – If the list is empty, the CPU is used. If the list contains the indices of multiple GPUs, the script must be launched with torch.distributed.launch or torchrun. (default: [])

  • ddp_sync_bn (bool) – Whether to replace the batch normalization layers in the network with cross-GPU synchronized batch normalization. Only takes effect when device_ids contains more than one index. (default: True)

  • ddp_replace_sampler (bool) – Whether to replace the sampler of each DataLoader with a DistributedSampler. Only takes effect when device_ids contains more than one index. (default: True)

  • ddp_val (bool) – Whether to use multi-GPU acceleration for the validation set. For experiments that are sensitive to the order of the input data, set ddp_val to False. Only takes effect when device_ids contains more than one index. (default: True)

  • ddp_test (bool) – Whether to use multi-GPU acceleration for the test set. For experiments that are sensitive to the order of the input data, set ddp_test to False. Only takes effect when device_ids contains more than one index. (default: True)
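The sketch below shows why order-sensitive experiments should disable ddp_val or ddp_test. It mimics, in pure Python, how PyTorch's torch.utils.data.distributed.DistributedSampler assigns indices when shuffle=False and drop_last=False. First, the index list is padded by repeating leading indices until it divides evenly among the processes. Then each process takes every num_replicas-th index, starting at its rank. This is a simplified illustration that omits shuffling and seeding. shard_indices is a hypothetical helper name.

```python
import math
from typing import List


def shard_indices(dataset_len: int, num_replicas: int, rank: int) -> List[int]:
    """Hypothetical helper mimicking DistributedSampler(shuffle=False, drop_last=False)."""
    if dataset_len <= 0:
        raise ValueError("dataset must not be empty")
    num_samples = math.ceil(dataset_len / num_replicas)  # per-process count
    total_size = num_samples * num_replicas
    indices = list(range(dataset_len))
    padding = total_size - dataset_len
    # Pad by wrapping around to the start so every process gets the same count.
    indices += (indices * math.ceil(padding / dataset_len))[:padding]
    # Each process takes every num_replicas-th index, starting at its rank.
    return indices[rank:total_size:num_replicas]


# 10 validation samples on 3 GPUs (device_ids=[1, 2, 7]).
for rank in range(3):
    print(rank, shard_indices(10, 3, rank))
```

With 10 samples and 3 processes, samples 0 and 1 are assigned twice, once as padding. Each process also sees an interleaved subset rather than the data in its original order.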

fit(train_loader: DataLoader, val_loader: DataLoader, num_epochs: int = 1, **kwargs)
Parameters
  • train_loader (DataLoader) – Iterable DataLoader for traversing the training data batch (torch.utils.data.dataloader.DataLoader, torch_geometric.loader.DataLoader, etc).

  • val_loader (DataLoader) – Iterable DataLoader for traversing the validation data batch (torch.utils.data.dataloader.DataLoader, torch_geometric.loader.DataLoader, etc).

  • num_epochs (int) – The number of training epochs. (default: 1)

test(test_loader: DataLoader, **kwargs)
Parameters

test_loader (DataLoader) – Iterable DataLoader for traversing the test data batch (torch.utils.data.dataloader.DataLoader, torch_geometric.loader.DataLoader, etc).

trainers.ClassificationTrainer

class torcheeg.trainers.ClassificationTrainer(model: Module, num_classes: Optional[int] = None, lr: float = 0.0001, weight_decay: float = 0.0, device_ids: List[int] = [], ddp_sync_bn: bool = True, ddp_replace_sampler: bool = True, ddp_val: bool = True, ddp_test: bool = True)

A generic trainer class for EEG classification.

trainer = ClassificationTrainer(model)
trainer.fit(train_loader, val_loader)
trainer.test(test_loader)

The class provides the following hook functions for inserting additional implementations in the training, validation and testing lifecycle:

  • before_training_epoch: executed before each epoch of training starts

  • before_training_step: executed before each batch of training starts

  • on_training_step: the training process for each batch

  • after_training_step: executed after the training of each batch

  • after_training_epoch: executed after each epoch of training

  • before_validation_epoch: executed before each round of validation starts

  • before_validation_step: executed before the validation of each batch

  • on_validation_step: validation process for each batch

  • after_validation_step: executed after the validation of each batch

  • after_validation_epoch: executed after each round of validation

  • before_test_epoch: executed before each round of test starts

  • before_test_step: executed before the test of each batch

  • on_test_step: test process for each batch

  • after_test_step: executed after the test of each batch

  • after_test_epoch: executed after each round of test

To customize the training, validation or test process, inherit from the class and override the corresponding hook functions:

from torcheeg.trainers import ClassificationTrainer

class MyClassificationTrainer(ClassificationTrainer):
    def before_training_epoch(self, epoch_id: int, num_epochs: int, **kwargs):
        # Custom logic to run before each training epoch goes here.
        super().before_training_epoch(epoch_id, num_epochs, **kwargs)

If you want to use multiple GPUs for parallel computing, you need to specify the GPU indices you want to use in the python file:

trainer = ClassificationTrainer(model, device_ids=[1, 2, 7])
trainer.fit(train_loader, val_loader)
trainer.test(test_loader)

Then, launch your Python script with torch.distributed.launch or torchrun.

python -m torch.distributed.launch \
    --nproc_per_node=3 \
    --nnodes=1 \
    --node_rank=0 \
    --master_addr="localhost" \
    --master_port=2345 \
    your_python_file.py

Here, nproc_per_node is the number of GPUs you specify.

Parameters
  • model (nn.Module) – The classification model, and the dimension of its output should be equal to the number of categories in the dataset. The output layer does not need to have a softmax activation function.

  • num_classes (int, optional) – The number of categories in the dataset. If None, the number of categories is inferred from the num_classes attribute of the model. (default: None)

  • lr (float) – The learning rate. (default: 0.0001)

  • weight_decay (float) – The weight decay (L2 penalty). (default: 0.0)

  • device_ids (list) – If the list is empty, the CPU is used. If the list contains the indices of multiple GPUs, the script must be launched with torch.distributed.launch or torchrun. (default: [])

  • ddp_sync_bn (bool) – Whether to replace the batch normalization layers in the network with cross-GPU synchronized batch normalization. Only takes effect when device_ids contains more than one index. (default: True)

  • ddp_replace_sampler (bool) – Whether to replace the sampler of each DataLoader with a DistributedSampler. Only takes effect when device_ids contains more than one index. (default: True)

  • ddp_val (bool) – Whether to use multi-GPU acceleration for the validation set. For experiments that are sensitive to the order of the input data, set ddp_val to False. Only takes effect when device_ids contains more than one index. (default: True)

  • ddp_test (bool) – Whether to use multi-GPU acceleration for the test set. For experiments that are sensitive to the order of the input data, set ddp_test to False. Only takes effect when device_ids contains more than one index. (default: True)
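PyTorch's torch.nn.CrossEntropyLoss expects raw logits and applies log-softmax internally. This is why a classification model normally has no softmax output layer. Assuming ClassificationTrainer uses a loss of this kind, the pure-Python sketch below shows what that loss computes for a single sample. It takes the negative log of the target class's softmax probability, evaluated directly from the logits with the log-sum-exp trick. cross_entropy_from_logits is a hypothetical helper, not part of torcheeg.

```python
import math
from typing import Sequence


def cross_entropy_from_logits(logits: Sequence[float], target: int) -> float:
    """Hypothetical helper: -log(softmax(logits)[target]) via a stable log-sum-exp."""
    m = max(logits)  # subtract the max so exp() cannot overflow
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]


logits = [2.0, 0.5, -1.0]  # raw outputs of a 3-class model, no softmax applied
print(cross_entropy_from_logits(logits, target=0))
```

If the model already ended in a softmax, the loss would apply a second softmax to values in [0, 1]. That would squash the predicted distribution towards uniform and weaken the training signal.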

fit(train_loader: DataLoader, val_loader: DataLoader, num_epochs: int = 1, **kwargs)
Parameters
  • train_loader (DataLoader) – Iterable DataLoader for traversing the training data batch (torch.utils.data.dataloader.DataLoader, torch_geometric.loader.DataLoader, etc).

  • val_loader (DataLoader) – Iterable DataLoader for traversing the validation data batch (torch.utils.data.dataloader.DataLoader, torch_geometric.loader.DataLoader, etc).

  • num_epochs (int) – The number of training epochs. (default: 1)

test(test_loader: DataLoader, **kwargs)
Parameters

test_loader (DataLoader) – Iterable DataLoader for traversing the test data batch (torch.utils.data.dataloader.DataLoader, torch_geometric.loader.DataLoader, etc).

trainers.CORALTrainer

class torcheeg.trainers.CORALTrainer(extractor: Module, classifier: Module, match_mean: bool = True, lambd: float = 1.0, num_classes: Optional[int] = None, lr: float = 0.0001, weight_decay: float = 0.0, device_ids: List[int] = [], ddp_sync_bn: bool = True, ddp_replace_sampler: bool = True, ddp_val: bool = True, ddp_test: bool = True)

Individual differences between subjects and the nonstationarity of EEG signals make it difficult for deep learning models trained on one set of subjects to correctly classify test samples from unseen subjects, because the training and test sets come from different data distributions. Domain adaptation addresses this distribution shift between training and test sets and can thus improve performance in subject-independent (cross-subject) scenarios. This class implements CORrelation ALignment (CORAL) for deep domain adaptation.

NOTE: CORAL is an unsupervised domain adaptation method, which uses only labeled source data and unlabeled target data. The target dataset therefore does not need to return labels.

trainer = CORALTrainer(extractor, classifier)
trainer.fit(source_loader, target_loader, val_loader)
trainer.test(test_loader)

The class provides the following hook functions for inserting additional implementations in the training, validation and testing lifecycle:

  • before_training_epoch: executed before each epoch of training starts

  • before_training_step: executed before each batch of training starts

  • on_training_step: the training process for each batch

  • after_training_step: executed after the training of each batch

  • after_training_epoch: executed after each epoch of training

  • before_validation_epoch: executed before each round of validation starts

  • before_validation_step: executed before the validation of each batch

  • on_validation_step: validation process for each batch

  • after_validation_step: executed after the validation of each batch

  • after_validation_epoch: executed after each round of validation

  • before_test_epoch: executed before each round of test starts

  • before_test_step: executed before the test of each batch

  • on_test_step: test process for each batch

  • after_test_step: executed after the test of each batch

  • after_test_epoch: executed after each round of test

To customize the training, validation or test process, inherit from the class and override the corresponding hook functions:

from torcheeg.trainers import CORALTrainer

class MyCORALTrainer(CORALTrainer):
    def before_training_epoch(self, epoch_id: int, num_epochs: int, **kwargs):
        # Custom logic to run before each training epoch goes here.
        super().before_training_epoch(epoch_id, num_epochs, **kwargs)

If you want to use multiple GPUs for parallel computing, you need to specify the GPU indices you want to use in the python file:

trainer = CORALTrainer(extractor, classifier, device_ids=[1, 2, 7])
trainer.fit(source_loader, target_loader, val_loader)
trainer.test(test_loader)

Then, launch your Python script with torch.distributed.launch or torchrun.

python -m torch.distributed.launch \
    --nproc_per_node=3 \
    --nnodes=1 \
    --node_rank=0 \
    --master_addr="localhost" \
    --master_port=2345 \
    your_python_file.py

Here, nproc_per_node is the number of GPUs you specify.

Parameters
  • extractor (nn.Module) – The feature extraction model. It learns a feature representation of EEG signals by bringing the correlation (covariance) matrices of the source and target features closer together.

  • classifier (nn.Module) – The classification model. It learns the classification task from labeled source data using the features produced by the extractor. The dimension of its output should be equal to the number of categories in the dataset. The output layer does not need a softmax activation function.

  • lambd (float) – The weight of the CORAL loss, which trades off the classification loss against the CORAL loss. (default: 1.0)

  • match_mean (bool) – Whether to match the means of the source-domain and target-domain samples. If False, only the second moments are matched. (default: True)

  • num_classes (int, optional) – The number of categories in the dataset. If None, the number of categories is inferred from the num_classes attribute of the model. (default: None)

  • lr (float) – The learning rate. (default: 0.0001)

  • weight_decay (float) – The weight decay (L2 penalty). (default: 0.0)

  • device_ids (list) – If the list is empty, the CPU is used. If the list contains the indices of multiple GPUs, the script must be launched with torch.distributed.launch or torchrun. (default: [])

  • ddp_sync_bn (bool) – Whether to replace the batch normalization layers in the network with cross-GPU synchronized batch normalization. Only takes effect when device_ids contains more than one index. (default: True)

  • ddp_replace_sampler (bool) – Whether to replace the sampler of each DataLoader with a DistributedSampler. Only takes effect when device_ids contains more than one index. (default: True)

  • ddp_val (bool) – Whether to use multi-GPU acceleration for the validation set. For experiments that are sensitive to the order of the input data, set ddp_val to False. Only takes effect when device_ids contains more than one index. (default: True)

  • ddp_test (bool) – Whether to use multi-GPU acceleration for the test set. For experiments that are sensitive to the order of the input data, set ddp_test to False. Only takes effect when device_ids contains more than one index. (default: True)
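The CORAL loss itself fits in a few lines. The pure-Python sketch below follows the Deep CORAL formulation: the squared Frobenius distance between the source and target feature covariance matrices, scaled by 1/(4d^2), where d is the feature dimension. When match_mean is set, the sketch adds the squared distance between the feature means as its first-moment term. That term is an assumption about what matching the means looks like, and torcheeg's normalization and mean term may differ. coral_loss and its helpers are hypothetical names.

```python
from typing import List, Sequence

Matrix = Sequence[Sequence[float]]


def feature_mean(x: Matrix) -> List[float]:
    # Column-wise mean over the n samples (rows) of an n x d feature matrix.
    n = len(x)
    return [sum(row[j] for row in x) / n for j in range(len(x[0]))]


def covariance(x: Matrix) -> List[List[float]]:
    # Unbiased d x d covariance of the features (requires n >= 2).
    n, d = len(x), len(x[0])
    mu = feature_mean(x)
    return [[sum((row[i] - mu[i]) * (row[j] - mu[j]) for row in x) / (n - 1)
             for j in range(d)] for i in range(d)]


def coral_loss(source: Matrix, target: Matrix, match_mean: bool = True) -> float:
    """Hypothetical CORAL-style loss between source and target feature batches."""
    d = len(source[0])
    cs, ct = covariance(source), covariance(target)
    # Second moment: squared Frobenius norm of the covariance gap, / (4 d^2).
    loss = sum((cs[i][j] - ct[i][j]) ** 2 for i in range(d) for j in range(d)) / (4 * d * d)
    if match_mean:
        # Assumed first-moment term: squared Euclidean distance between the means.
        ms, mt = feature_mean(source), feature_mean(target)
        loss += sum((a - b) ** 2 for a, b in zip(ms, mt))
    return loss


source = [[0.0, 1.0], [1.0, 0.0], [2.0, 2.0]]
target = [[v + 5.0 for v in row] for row in source]  # same spread, shifted mean
print(coral_loss(source, target, match_mean=False), coral_loss(source, target))  # 0.0 50.0
```

A shifted target batch has the same covariance as the source, so only the mean term responds to the shift. This illustrates what the match_mean switch adds.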

fit(source_loader: DataLoader, target_loader: DataLoader, val_loader: DataLoader, num_epochs: int = 1, **kwargs)
Parameters
  • source_loader (DataLoader) – Iterable DataLoader for traversing the data batch from the source domain (torch.utils.data.dataloader.DataLoader, torch_geometric.loader.DataLoader, etc).

  • target_loader (DataLoader) – Iterable DataLoader for traversing the training data batch from the target domain (torch.utils.data.dataloader.DataLoader, torch_geometric.loader.DataLoader, etc). The target dataset does not have to return labels.

  • val_loader (DataLoader) – Iterable DataLoader for traversing the validation data batch (torch.utils.data.dataloader.DataLoader, torch_geometric.loader.DataLoader, etc).

  • num_epochs (int) – The number of training epochs. (default: 1)

test(test_loader: DataLoader, **kwargs)
Parameters

test_loader (DataLoader) – Iterable DataLoader for traversing the test data batch (torch.utils.data.dataloader.DataLoader, torch_geometric.loader.DataLoader, etc).

trainers.DDCTrainer

class torcheeg.trainers.DDCTrainer(extractor: Module, classifier: Module, lambd: float = 1.0, adaption_factor: bool = False, num_classes: Optional[int] = None, lr: float = 0.0001, weight_decay: float = 0.0, device_ids: List[int] = [], ddp_sync_bn: bool = True, ddp_replace_sampler: bool = True, ddp_val: bool = True, ddp_test: bool = True)

Individual differences between subjects and the nonstationarity of EEG signals make it difficult for deep learning models trained on one set of subjects to correctly classify test samples from unseen subjects, because the training and test sets come from different data distributions. Domain adaptation addresses this distribution shift between training and test sets and can thus improve performance in subject-independent (cross-subject) scenarios. This class implements Deep Domain Confusion (DDC) for deep domain adaptation.

NOTE: DDC is an unsupervised domain adaptation method, which uses only labeled source data and unlabeled target data. The target dataset therefore does not need to return labels.

trainer = DDCTrainer(extractor, classifier)
trainer.fit(source_loader, target_loader, val_loader)
trainer.test(test_loader)

The class provides the following hook functions for inserting additional implementations in the training, validation and testing lifecycle:

  • before_training_epoch: executed before each epoch of training starts

  • before_training_step: executed before each batch of training starts

  • on_training_step: the training process for each batch

  • after_training_step: executed after the training of each batch

  • after_training_epoch: executed after each epoch of training

  • before_validation_epoch: executed before each round of validation starts

  • before_validation_step: executed before the validation of each batch

  • on_validation_step: validation process for each batch

  • after_validation_step: executed after the validation of each batch

  • after_validation_epoch: executed after each round of validation

  • before_test_epoch: executed before each round of test starts

  • before_test_step: executed before the test of each batch

  • on_test_step: test process for each batch

  • after_test_step: executed after the test of each batch

  • after_test_epoch: executed after each round of test

To customize the training, validation or test process, inherit from the class and override the corresponding hook functions:

from torcheeg.trainers import DDCTrainer

class MyDDCTrainer(DDCTrainer):
    def before_training_epoch(self, epoch_id: int, num_epochs: int, **kwargs):
        # Custom logic to run before each training epoch goes here.
        super().before_training_epoch(epoch_id, num_epochs, **kwargs)

If you want to use multiple GPUs for parallel computing, you need to specify the GPU indices you want to use in the python file:

trainer = DDCTrainer(extractor, classifier, device_ids=[1, 2, 7])
trainer.fit(source_loader, target_loader, val_loader)
trainer.test(test_loader)

Then, launch your Python script with torch.distributed.launch or torchrun.

python -m torch.distributed.launch \
    --nproc_per_node=3 \
    --nnodes=1 \
    --node_rank=0 \
    --master_addr="localhost" \
    --master_port=2345 \
    your_python_file.py

Here, nproc_per_node is the number of GPUs you specify.

Parameters
  • extractor (nn.Module) – The feature extraction model. It learns a domain-invariant feature representation of EEG signals by reducing the maximum mean discrepancy (MMD) between the source and target features.

  • classifier (nn.Module) – The classification model. It learns the classification task from labeled source data using the features produced by the extractor. The dimension of its output should be equal to the number of categories in the dataset. The output layer does not need a softmax activation function.

  • lambd (float) – The weight of the DDC loss, which trades off the classification loss against the DDC loss. (default: 1.0)

  • adaption_factor (bool) – Whether to scale the cross-domain loss term by the adaptation factor. This factor was first proposed in DANN but is effective in many other settings as well. (default: False)

  • num_classes (int, optional) – The number of categories in the dataset. If None, the number of categories is inferred from the num_classes attribute of the model. (default: None)

  • lr (float) – The learning rate. (default: 0.0001)

  • weight_decay (float) – The weight decay (L2 penalty). (default: 0.0)

  • device_ids (list) – If the list is empty, the CPU is used. If the list contains the indices of multiple GPUs, the script must be launched with torch.distributed.launch or torchrun. (default: [])

  • ddp_sync_bn (bool) – Whether to replace the batch normalization layers in the network with cross-GPU synchronized batch normalization. Only takes effect when device_ids contains more than one index. (default: True)

  • ddp_replace_sampler (bool) – Whether to replace the sampler of each DataLoader with a DistributedSampler. Only takes effect when device_ids contains more than one index. (default: True)

  • ddp_val (bool) – Whether to use multi-GPU acceleration for the validation set. For experiments that are sensitive to the order of the input data, set ddp_val to False. Only takes effect when device_ids contains more than one index. (default: True)

  • ddp_test (bool) – Whether to use multi-GPU acceleration for the test set. For experiments that are sensitive to the order of the input data, set ddp_test to False. Only takes effect when device_ids contains more than one index. (default: True)
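DDC's domain term is a maximum mean discrepancy (MMD) between source and target features. The pure-Python sketch below computes a biased estimate of the squared MMD with a Gaussian (RBF) kernel. The kernel and its bandwidth are assumptions chosen for illustration and may differ from what torcheeg uses. rbf_kernel and mmd2 are hypothetical names, not torcheeg's API.

```python
import math
from typing import Sequence

Vector = Sequence[float]


def rbf_kernel(a: Vector, b: Vector, sigma: float) -> float:
    # Gaussian kernel exp(-||a - b||^2 / (2 sigma^2)).
    sq_dist = sum((x - y) ** 2 for x, y in zip(a, b))
    return math.exp(-sq_dist / (2.0 * sigma * sigma))


def mmd2(source: Sequence[Vector], target: Sequence[Vector], sigma: float = 1.0) -> float:
    """Hypothetical helper: biased estimate of the squared MMD between two feature sets."""
    def mean_kernel(xs, ys):
        return sum(rbf_kernel(x, y, sigma) for x in xs for y in ys) / (len(xs) * len(ys))
    return mean_kernel(source, source) + mean_kernel(target, target) - 2.0 * mean_kernel(source, target)


source = [[0.0], [1.0]]
print(mmd2(source, [[0.5], [1.5]]), mmd2(source, [[3.0], [4.0]]))
```

The further the target features drift from the source features, the larger the value. A DDC-style objective would add lambd times this quantity, computed on extractor features of a source batch and a target batch, to the classification loss.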

fit(source_loader: DataLoader, target_loader: DataLoader, val_loader: DataLoader, num_epochs: int = 1, **kwargs)
Parameters
  • source_loader (DataLoader) – Iterable DataLoader for traversing the data batch from the source domain (torch.utils.data.dataloader.DataLoader, torch_geometric.loader.DataLoader, etc).

  • target_loader (DataLoader) – Iterable DataLoader for traversing the training data batch from the target domain (torch.utils.data.dataloader.DataLoader, torch_geometric.loader.DataLoader, etc). The target dataset does not have to return labels.

  • val_loader (DataLoader) – Iterable DataLoader for traversing the validation data batch (torch.utils.data.dataloader.DataLoader, torch_geometric.loader.DataLoader, etc).

  • num_epochs (int) – The number of training epochs. (default: 1)

test(test_loader: DataLoader, **kwargs)
Parameters

test_loader (DataLoader) – Iterable DataLoader for traversing the test data batch (torch.utils.data.dataloader.DataLoader, torch_geometric.loader.DataLoader, etc).

trainers.DANNTrainer

class torcheeg.trainers.DANNTrainer(extractor: Module, classifier: Module, domain_classifier: Module, lambd: float = 1.0, adaption_factor: bool = False, num_classes: Optional[int] = None, lr: float = 0.0001, weight_decay: float = 0.0, device_ids: List[int] = [], ddp_sync_bn: bool = True, ddp_replace_sampler: bool = True, ddp_val: bool = True, ddp_test: bool = True)

Individual differences between subjects and the nonstationarity of EEG signals make it difficult for deep learning models trained on one set of subjects to correctly classify test samples from unseen subjects, because the training and test sets come from different data distributions. Domain adaptation addresses this distribution shift between training and test sets and can thus improve performance in subject-independent (cross-subject) scenarios. This class implements Domain Adversarial Neural Networks (DANN) for deep domain adaptation.

NOTE: DANN is an unsupervised domain adaptation method, which uses only labeled source data and unlabeled target data. The target dataset therefore does not need to return labels.

trainer = DANNTrainer(extractor, classifier, domain_classifier)
trainer.fit(source_loader, target_loader, val_loader)
trainer.test(test_loader)

The class provides the following hook functions for inserting additional implementations in the training, validation and testing lifecycle:

  • before_training_epoch: executed before each epoch of training starts

  • before_training_step: executed before each batch of training starts

  • on_training_step: the training process for each batch

  • after_training_step: executed after the training of each batch

  • after_training_epoch: executed after each epoch of training

  • before_validation_epoch: executed before each round of validation starts

  • before_validation_step: executed before the validation of each batch

  • on_validation_step: validation process for each batch

  • after_validation_step: executed after the validation of each batch

  • after_validation_epoch: executed after each round of validation

  • before_test_epoch: executed before each round of test starts

  • before_test_step: executed before the test of each batch

  • on_test_step: test process for each batch

  • after_test_step: executed after the test of each batch

  • after_test_epoch: executed after each round of test

To customize the training, validation or test process, inherit from the class and override the corresponding hook functions:

from torcheeg.trainers import DANNTrainer

class MyDANNTrainer(DANNTrainer):
    def before_training_epoch(self, epoch_id: int, num_epochs: int, **kwargs):
        # Custom logic to run before each training epoch goes here.
        super().before_training_epoch(epoch_id, num_epochs, **kwargs)

If you want to use multiple GPUs for parallel computing, you need to specify the GPU indices you want to use in the python file:

trainer = DANNTrainer(extractor, classifier, domain_classifier, device_ids=[1, 2, 7])
trainer.fit(source_loader, target_loader, val_loader)
trainer.test(test_loader)

Then, launch your Python script with torch.distributed.launch or torchrun.

python -m torch.distributed.launch \
    --nproc_per_node=3 \
    --nnodes=1 \
    --node_rank=0 \
    --master_addr="localhost" \
    --master_port=2345 \
    your_python_file.py

Here, nproc_per_node is the number of GPUs you specify.

Parameters
  • extractor (nn.Module) – The feature extraction model. It learns a domain-invariant feature representation of EEG signals by making the source and target features indistinguishable to the domain classifier.

  • classifier (nn.Module) – The classification model. It learns the classification task from labeled source data using the features produced by the extractor. The dimension of its output should be equal to the number of categories in the dataset. The output layer does not need a softmax activation function.

  • domain_classifier (nn.Module) – The domain classification model. It learns to discriminate between the source and target domains using the features produced by the extractor. The dimension of its output should be equal to the number of domains (two: source and target). The output layer needs neither a softmax activation function nor a gradient reversal layer, since the latter is already included in the implementation.

  • lambd (float) – The weight of the DANN loss, which trades off the classification loss against the DANN loss. (default: 1.0)

  • adaption_factor (bool) – Whether to scale the cross-domain loss term by the adaptation factor. This factor was first proposed in DANN but is effective in many other settings as well. (default: False)

  • num_classes (int, optional) – The number of categories in the dataset. If None, the number of categories is inferred from the num_classes attribute of the model. (default: None)

  • lr (float) – The learning rate. (default: 0.0001)

  • weight_decay (float) – The weight decay (L2 penalty). (default: 0.0)

  • device_ids (list) – If the list is empty, the CPU is used. If the list contains the indices of multiple GPUs, the script must be launched with torch.distributed.launch or torchrun. (default: [])

  • ddp_sync_bn (bool) – Whether to replace the batch normalization layers in the network with cross-GPU synchronized batch normalization. Only takes effect when device_ids contains more than one index. (default: True)

  • ddp_replace_sampler (bool) – Whether to replace the sampler of each DataLoader with a DistributedSampler. Only takes effect when device_ids contains more than one index. (default: True)

  • ddp_val (bool) – Whether to use multi-GPU acceleration for the validation set. For experiments that are sensitive to the order of the input data, set ddp_val to False. Only takes effect when device_ids contains more than one index. (default: True)

  • ddp_test (bool) – Whether to use multi-GPU acceleration for the test set. For experiments that are sensitive to the order of the input data, set ddp_test to False. Only takes effect when device_ids contains more than one index. (default: True)
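The adaptation factor referred to by adaption_factor comes from the DANN paper. There, the weight of the domain loss ramps up from 0 towards 1 as training progresses, following lambda_p = 2 / (1 + exp(-gamma * p)) - 1 with gamma = 10 and p the fraction of training completed. The pure-Python sketch below implements that schedule. Whether torcheeg uses the same gamma and the same definition of p is an assumption. training_progress and adaptation_factor are hypothetical names.

```python
import math


def training_progress(epoch_id: int, num_epochs: int, batch_id: int, num_batches: int) -> float:
    """Hypothetical helper: fraction of training completed (epoch_id and batch_id are 0-based)."""
    return (epoch_id * num_batches + batch_id) / (num_epochs * num_batches)


def adaptation_factor(progress: float, gamma: float = 10.0) -> float:
    """DANN-style schedule: 0 at the start of training, approaching 1 at the end."""
    if not 0.0 <= progress <= 1.0:
        raise ValueError("progress must lie in [0, 1]")
    return 2.0 / (1.0 + math.exp(-gamma * progress)) - 1.0


lambd = 1.0  # trade-off weight, playing the role of the lambd parameter
for epoch_id in range(4):
    p = training_progress(epoch_id, num_epochs=4, batch_id=0, num_batches=100)
    # Effective weight applied to the domain loss at the start of each epoch.
    print(epoch_id, round(lambd * adaptation_factor(p), 3))
```

Starting the domain term near zero lets the classifier learn from the labeled source data before the adversarial signal, which is noisy early in training, gains full weight.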

fit(source_loader: DataLoader, target_loader: DataLoader, val_loader: DataLoader, num_epochs: int = 1, **kwargs)
Parameters
  • source_loader (DataLoader) – Iterable DataLoader for traversing the data batch from the source domain (torch.utils.data.dataloader.DataLoader, torch_geometric.loader.DataLoader, etc).

  • target_loader (DataLoader) – Iterable DataLoader for traversing the training data batch from the target domain (torch.utils.data.dataloader.DataLoader, torch_geometric.loader.DataLoader, etc). The target dataset does not have to return labels.

  • val_loader (DataLoader) – Iterable DataLoader for traversing the validation data batch (torch.utils.data.dataloader.DataLoader, torch_geometric.loader.DataLoader, etc).

  • num_epochs (int) – The number of training epochs. (default: 1)

test(test_loader: DataLoader, **kwargs)
Parameters

test_loader (DataLoader) – Iterable DataLoader for traversing the test data batch (torch.utils.data.dataloader.DataLoader, torch_geometric.loader.DataLoader, etc).

trainers.DANTrainer

class torcheeg.trainers.DANTrainer(extractor: Module, classifier: Module, lambd: float = 1.0, adaption_factor: bool = False, num_classes: Optional[int] = None, lr: float = 0.0001, weight_decay: float = 0.0, device_ids: List[int] = [], ddp_sync_bn: bool = True, ddp_replace_sampler: bool = True, ddp_val: bool = True, ddp_test: bool = True)

Individual differences between subjects and the nonstationarity of EEG signals make it difficult for deep learning models trained on one set of subjects to correctly classify test samples from unseen subjects, because the training and test sets come from different data distributions. Domain adaptation addresses this distribution shift between training and test sets and can thus improve performance in subject-independent (cross-subject) scenarios. This class implements Deep Adaptation Network (DAN) for deep domain adaptation.

NOTE: DAN is an unsupervised domain adaptation method, which uses only labeled source data and unlabeled target data. The target dataset therefore does not need to return labels.

trainer = DANTrainer(extractor, classifier)
trainer.fit(source_loader, target_loader, val_loader)
trainer.test(test_loader)
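DAN measures the discrepancy between domains with a multi-kernel MMD (MK-MMD), which combines several Gaussian kernels of different bandwidths instead of relying on a single one. The pure-Python sketch below uses an equally weighted combination over a fixed set of bandwidths and a biased estimate of the squared MMD. The bandwidths, the equal weighting and the function names multi_kernel and mk_mmd2 are assumptions for illustration. The original DAN method also optimizes the kernel weights and adapts several layers.

```python
import math
from typing import Sequence

Vector = Sequence[float]


def multi_kernel(a: Vector, b: Vector, sigmas: Sequence[float]) -> float:
    # Equally weighted average of Gaussian kernels with different bandwidths.
    sq_dist = sum((x - y) ** 2 for x, y in zip(a, b))
    return sum(math.exp(-sq_dist / (2.0 * s * s)) for s in sigmas) / len(sigmas)


def mk_mmd2(source: Sequence[Vector], target: Sequence[Vector],
            sigmas: Sequence[float] = (0.5, 1.0, 2.0, 4.0)) -> float:
    """Hypothetical helper: biased estimate of the squared multi-kernel MMD."""
    def mean_kernel(xs, ys):
        return sum(multi_kernel(x, y, sigmas) for x in xs for y in ys) / (len(xs) * len(ys))
    return mean_kernel(source, source) + mean_kernel(target, target) - 2.0 * mean_kernel(source, target)


source = [[0.0, 0.0], [1.0, 0.5], [0.5, 1.0]]
near = [[x + 0.1 for x in row] for row in source]
far = [[x + 3.0 for x in row] for row in source]
print(mk_mmd2(source, near), mk_mmd2(source, far))
```

Mixing bandwidths keeps the measure sensitive to both small and large shifts between domains. With a single kernel, a bandwidth that is too small or too large would make one of these cases hard to detect.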

The class provides the following hook functions for inserting additional implementations in the training, validation and testing lifecycle:

  • before_training_epoch: executed before each epoch of training starts

  • before_training_step: executed before each batch of training starts

  • on_training_step: the training process for each batch

  • after_training_step: executed after the training of each batch

  • after_training_epoch: executed after each epoch of training

  • before_validation_epoch: executed before each round of validation starts

  • before_validation_step: executed before the validation of each batch

  • on_validation_step: validation process for each batch

  • after_validation_step: executed after the validation of each batch

  • after_validation_epoch: executed after each round of validation

  • before_test_epoch: executed before each round of test starts

  • before_test_step: executed before the test of each batch

  • on_test_step: test process for each batch

  • after_test_step: executed after the test of each batch

  • after_test_epoch: executed after each round of test

To customize the training, validation or test process, inherit from the class and override the corresponding hook functions:

from torcheeg.trainers import DANTrainer

class MyDANTrainer(DANTrainer):
    def before_training_epoch(self, epoch_id: int, num_epochs: int, **kwargs):
        # Custom logic to run before each training epoch goes here.
        super().before_training_epoch(epoch_id, num_epochs, **kwargs)

If you want to use multiple GPUs for parallel computing, you need to specify the GPU indices you want to use in the python file:

trainer = DANTrainer(extractor, classifier, device_ids=[1, 2, 7])
trainer.fit(source_loader, target_loader, val_loader)
trainer.test(test_loader)

Then, launch your Python script with torch.distributed.launch or torchrun.

python -m torch.distributed.launch \
    --nproc_per_node=3 \
    --nnodes=1 \
    --node_rank=0 \
    --master_addr="localhost" \
    --master_port=2345 \
    your_python_file.py

Here, nproc_per_node is the number of GPUs you specify.

Parameters
  • extractor (nn.Module) – The feature extraction model, which learns feature representations of EEG signals by reducing the maximum mean discrepancy (MMD) between the source and target feature distributions.

  • classifier (nn.Module) – The classification model, which learns the classification task from labeled source data based on the features produced by the feature extraction model. The dimension of its output should equal the number of categories in the dataset. The output layer does not need a softmax activation function.

  • lambd (float) – The weight of the DAN loss, trading off between the classification loss and the DAN loss. (default: 1.0)

  • adaption_factor (bool) – Whether to scale the cross-domain loss term with the adaptation factor, which was first proposed in DANN but also works well for many other methods. (default: True)

  • num_classes (int, optional) – The number of categories in the dataset. If None, the number of categories will be inferred from the num_classes attribute of the model. (default: None)

  • lr (float) – The learning rate. (default: 0.0001)

  • weight_decay (float) – The weight decay (L2 penalty). (default: 0.0)

  • device_ids (list) – Use the CPU if the list is empty. If the list contains indices of multiple GPUs, the script needs to be launched with torch.distributed.launch or torchrun. (default: [])

  • ddp_sync_bn (bool) – Whether to replace batch normalization layers in the network with cross-GPU synchronized batch normalization. Only valid when the length of device_ids is greater than one. (default: True)

  • ddp_replace_sampler (bool) – Whether to replace the sampler in the DataLoader with DistributedSampler. Only valid when the length of device_ids is greater than one. (default: True)

  • ddp_val (bool) – Whether to use multi-GPU acceleration for the validation set. For experiments that are sensitive to the order of input data, ddp_val should be set to False. Only valid when the length of device_ids is greater than one. (default: True)

  • ddp_test (bool) – Whether to use multi-GPU acceleration for the test set. For experiments that are sensitive to the order of input data, ddp_test should be set to False. Only valid when the length of device_ids is greater than one. (default: True)
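The lambd and adaption_factor parameters control how the classification loss and the domain-adaptation loss are combined. The sketch below shows one plausible combination. It assumes the DANN schedule factor = 2 / (1 + exp(-10 p)) - 1, where p in [0, 1] is the training progress. The helper names, the constant 10, and the way progress is measured are assumptions. torcheeg may compute these details differently.

```python
import math


def dann_adaptation_factor(step: int, total_steps: int, gamma: float = 10.0) -> float:
    """DANN-style schedule that rises smoothly from 0 toward 1 as training progresses."""
    progress = step / max(total_steps, 1)  # p in [0, 1]
    return 2.0 / (1.0 + math.exp(-gamma * progress)) - 1.0


def total_loss(cls_loss: float, transfer_loss: float, lambd: float = 1.0,
               use_adaption_factor: bool = True, step: int = 0, total_steps: int = 1) -> float:
    """Classification loss plus the weighted (and optionally scheduled) transfer loss."""
    factor = dann_adaptation_factor(step, total_steps) if use_adaption_factor else 1.0
    return cls_loss + lambd * factor * transfer_loss


print(total_loss(0.7, 0.5, step=0, total_steps=100))    # factor is 0 at the start
print(total_loss(0.7, 0.5, step=100, total_steps=100))  # factor is close to 1 at the end
```

Ramping the factor up keeps the noisy adaptation term from dominating early training, while the classifier is still learning useful features.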

fit(source_loader: DataLoader, target_loader: DataLoader, val_loader: DataLoader, num_epochs: int = 1, **kwargs)
Parameters
  • source_loader (DataLoader) – Iterable DataLoader for traversing the data batch from the source domain (torch.utils.data.dataloader.DataLoader, torch_geometric.loader.DataLoader, etc).

  • target_loader (DataLoader) – Iterable DataLoader for traversing the training data batch from the target domain (torch.utils.data.dataloader.DataLoader, torch_geometric.loader.DataLoader, etc). The target dataset does not have to return labels.

  • val_loader (DataLoader) – Iterable DataLoader for traversing the validation data batch (torch.utils.data.dataloader.DataLoader, torch_geometric.loader.DataLoader, etc).

  • num_epochs (int) – The number of training epochs. (default: 1)

test(test_loader: DataLoader, **kwargs)
Parameters

test_loader (DataLoader) – Iterable DataLoader for traversing the test data batch (torch.utils.data.dataloader.DataLoader, torch_geometric.loader.DataLoader, etc).

trainers.ADATrainer

class torcheeg.trainers.ADATrainer(extractor: Module, classifier: Module, lambd: float = 1.0, adaption_factor: bool = False, num_classes: Optional[int] = None, lr: float = 0.0001, walker_weight: float = 1.0, visit_weight: float = 1.0, weight_decay: float = 0.0, device_ids: List[int] = [], ddp_sync_bn: bool = True, ddp_replace_sampler: bool = True, ddp_val: bool = True, ddp_test: bool = True)

Individual differences and the non-stationarity of EEG signals make it difficult for deep learning models trained on a set of training subjects to correctly classify test samples from unseen subjects, because the training and test sets come from different data distributions. Domain adaptation addresses this distribution shift between the training and test sets and can thus achieve good performance in subject-independent (cross-subject) scenarios. This class implements Associative Domain Adaptation (ADA) for deep domain adaptation.

NOTE: ADA is an unsupervised domain adaptation method, which uses only labeled source data and unlabeled target data. This means that the target dataset does not have to return labels.

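from torcheeg.trainers import ADATrainer

# extractor, classifier and the data loaders are assumed to be defined beforehand
trainer = ADATrainer(extractor, classifier)
trainer.fit(source_loader, target_loader, val_loader)
trainer.test(test_loader)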
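In the original ADA formulation, source embeddings "walk" to target embeddings and back through softmax-normalized similarity scores. The associations are shaped by two losses:

  • A walker loss rewards round trips that end at a source sample of the same class.

  • A visit loss encourages every target sample to be visited equally often.

The pure-Python sketch below follows that formulation on toy embeddings. `associative_losses` is a hypothetical helper. Combining the two terms with `walker_weight` and `visit_weight` is an assumption about how the constructor parameters are used, not torcheeg's code.

```python
import math
from typing import List, Sequence, Tuple


def softmax(row: Sequence[float]) -> List[float]:
    """Numerically stable softmax over one row of scores."""
    top = max(row)
    exps = [math.exp(v - top) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def associative_losses(src: List[Sequence[float]], src_labels: List[int],
                       tgt: List[Sequence[float]]) -> Tuple[float, float]:
    """Return (walker_loss, visit_loss) for batches of source and target embeddings."""
    n_src, n_tgt = len(src), len(tgt)
    # Similarity matrix M[i][j] = <src_i, tgt_j>
    sim = [[sum(a * b for a, b in zip(s, t)) for t in tgt] for s in src]
    # Transition probabilities: source -> target (rows of M), target -> source (rows of M^T)
    p_st = [softmax(row) for row in sim]
    p_ts = [softmax([sim[i][j] for i in range(n_src)]) for j in range(n_tgt)]
    # Round-trip probability: source i -> any target j -> source k
    p_round = [[sum(p_st[i][j] * p_ts[j][k] for j in range(n_tgt)) for k in range(n_src)]
               for i in range(n_src)]
    # Walker loss: cross-entropy against a uniform distribution over same-class sources
    walker = 0.0
    for i in range(n_src):
        same = [k for k in range(n_src) if src_labels[k] == src_labels[i]]
        walker -= sum(math.log(p_round[i][k] + 1e-8) for k in same) / len(same)
    walker /= n_src
    # Visit loss: cross-entropy between a uniform distribution and target visit probabilities
    visit_prob = [sum(p_st[i][j] for i in range(n_src)) / n_src for j in range(n_tgt)]
    visit = -sum(math.log(p + 1e-8) for p in visit_prob) / n_tgt
    return walker, visit


src = [[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.1, 0.9]]
src_labels = [0, 0, 1, 1]
tgt = [[1.0, 0.1], [0.1, 1.0]]
walker, visit = associative_losses(src, src_labels, tgt)
walker_weight, visit_weight = 1.0, 1.0  # mirror the constructor parameters (assumption)
assoc_loss = walker_weight * walker + visit_weight * visit
print(walker, visit, assoc_loss)
```

When the labels do not match the cluster structure (for example, `[0, 1, 0, 1]`), round trips end more often at samples of a different class, and the walker loss grows.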

The class provides the following hook functions for inserting additional implementations in the training, validation and testing lifecycle:

  • before_training_epoch: executed before each epoch of training starts

  • before_training_step: executed before each batch of training starts

  • on_training_step: the training process for each batch

  • after_training_step: executed after the training of each batch

  • after_training_epoch: executed after each epoch of training

  • before_validation_epoch: executed before each round of validation starts

  • before_validation_step: executed before the validation of each batch

  • on_validation_step: validation process for each batch

  • after_validation_step: executed after the validation of each batch

  • after_validation_epoch: executed after each round of validation

  • before_test_epoch: executed before each round of test starts

  • before_test_step: executed before the test of each batch

  • on_test_step: test process for each batch

  • after_test_step: executed after the test of each batch

  • after_test_epoch: executed after each round of test

If you want to customize some operations, you just need to inherit the class and override the hook function:

class MyADATrainer(ADATrainer):
    def before_training_epoch(self, epoch_id: int, num_epochs: int):
        # Do something here.
        super().before_training_epoch(epoch_id, num_epochs)

If you want to use multiple GPUs for parallel computing, you need to specify the GPU indices you want to use in the Python file:

trainer = ADATrainer(extractor, classifier, device_ids=[1, 2, 7])
trainer.fit(source_loader, target_loader, val_loader)
trainer.test(test_loader)

Then, you can use torch.distributed.launch or torchrun to run your Python file.

python -m torch.distributed.launch \
    --nproc_per_node=3 \
    --nnodes=1 \
    --node_rank=0 \
    --master_addr="localhost" \
    --master_port=2345 \
    your_python_file.py

Here, nproc_per_node is the number of GPUs you specify.

Parameters
  • extractor (nn.Module) – The feature extraction model, which learns feature representations of EEG signals by building associations between source and target embeddings.

  • classifier (nn.Module) – The classification model, which learns the classification task from labeled source data based on the features produced by the feature extraction model. The dimension of its output should equal the number of categories in the dataset. The output layer does not need a softmax activation function.

  • lambd (float) – The weight of the ADA loss, trading off between the classification loss and the ADA loss. (default: 1.0)

  • adaption_factor (bool) – Whether to scale the cross-domain loss term with the adaptation factor, which was first proposed in DANN but also works well for many other methods. (default: False)

  • num_classes (int, optional) – The number of categories in the dataset. If None, the number of categories will be inferred from the num_classes attribute of the model. (default: None)

  • lr (float) – The learning rate. (default: 0.0001)

  • walker_weight (float) – The weight of the walker loss. (default: 1.0)

  • visit_weight (float) – The weight of the visit loss. (default: 1.0)

  • weight_decay (float) – The weight decay (L2 penalty). (default: 0.0)

  • device_ids (list) – Use the CPU if the list is empty. If the list contains indices of multiple GPUs, the script needs to be launched with torch.distributed.launch or torchrun. (default: [])

  • ddp_sync_bn (bool) – Whether to replace batch normalization layers in the network with cross-GPU synchronized batch normalization. Only valid when the length of device_ids is greater than one. (default: True)

  • ddp_replace_sampler (bool) – Whether to replace the sampler in the DataLoader with DistributedSampler. Only valid when the length of device_ids is greater than one. (default: True)

  • ddp_val (bool) – Whether to use multi-GPU acceleration for the validation set. For experiments that are sensitive to the order of input data, ddp_val should be set to False. Only valid when the length of device_ids is greater than one. (default: True)

  • ddp_test (bool) – Whether to use multi-GPU acceleration for the test set. For experiments that are sensitive to the order of input data, ddp_test should be set to False. Only valid when the length of device_ids is greater than one. (default: True)
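With ddp_replace_sampler enabled, each process reads a different shard of the data. The pure-Python sketch below mimics how PyTorch's DistributedSampler assigns indices when shuffling is disabled and drop_last is False:

  • The index list is padded by repeating its head until its length divides evenly by the number of processes.

  • Each rank then takes every num_replicas-th index, starting at its own rank.

The function name `shard_indices` is hypothetical. The sketch only illustrates why order-sensitive evaluation may call for ddp_val=False or ddp_test=False.

```python
import math
from typing import List


def shard_indices(dataset_len: int, num_replicas: int, rank: int) -> List[int]:
    """Mimic DistributedSampler(shuffle=False, drop_last=False) index assignment."""
    num_samples = math.ceil(dataset_len / num_replicas)  # samples per rank
    total_size = num_samples * num_replicas
    indices = list(range(dataset_len))
    padding = total_size - len(indices)
    if padding > 0:
        # Repeat indices from the start so that every rank gets num_samples items
        indices += (indices * math.ceil(padding / len(indices)))[:padding]
    # Rank r takes indices r, r + num_replicas, r + 2 * num_replicas, ...
    return indices[rank:total_size:num_replicas]


for rank in range(3):
    print(rank, shard_indices(dataset_len=10, num_replicas=3, rank=rank))
```

With 10 samples on 3 GPUs, samples 0 and 1 are evaluated twice, and no single process sees the data in its original order. Disabling DDP for validation or testing avoids both effects.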

fit(source_loader: DataLoader, target_loader: DataLoader, val_loader: DataLoader, num_epochs: int = 1, **kwargs)
Parameters
  • source_loader (DataLoader) – Iterable DataLoader for traversing the data batch from the source domain (torch.utils.data.dataloader.DataLoader, torch_geometric.loader.DataLoader, etc).

  • target_loader (DataLoader) – Iterable DataLoader for traversing the training data batch from the target domain (torch.utils.data.dataloader.DataLoader, torch_geometric.loader.DataLoader, etc). The target dataset does not have to return labels.

  • val_loader (DataLoader) – Iterable DataLoader for traversing the validation data batch (torch.utils.data.dataloader.DataLoader, torch_geometric.loader.DataLoader, etc).

  • num_epochs (int) – The number of training epochs. (default: 1)

test(test_loader: DataLoader, **kwargs)
Parameters

test_loader (DataLoader) – Iterable DataLoader for traversing the test data batch (torch.utils.data.dataloader.DataLoader, torch_geometric.loader.DataLoader, etc).

trainers.GANTrainer

class torcheeg.trainers.GANTrainer(generator: Module, discriminator: Module, generator_lr: float = 0.0001, discriminator_lr: float = 0.0001, lambd: float = 1.0, weight_decay: float = 0.0, device_ids: List[int] = [], ddp_sync_bn: bool = True, ddp_replace_sampler: bool = True, ddp_val: bool = True, ddp_test: bool = True)

This class provides an implementation of WGAN-GP. It trains the generator and the discriminator in a zero-sum game, as in traditional generative adversarial networks. The generator is optimized to produce simulated samples that the discriminator cannot distinguish from real ones. The discriminator is optimized to tell real samples apart from the fake samples produced by the generator. Compared with the vanilla GAN, WGAN-GP improves training stability, alleviates problems such as mode collapse, and provides meaningful learning curves that are useful for debugging and hyperparameter search. Existing work therefore commonly uses WGAN-GP to generate simulated EEG signals. For more details, please refer to the original WGAN-GP paper.

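from torcheeg.models import BGenerator, BDiscriminator
from torcheeg.trainers import GANTrainer

g_model = BGenerator(in_channels=128)
d_model = BDiscriminator(in_channels=4)
# the data loaders are assumed to be defined beforehand
trainer = GANTrainer(g_model, d_model)
trainer.fit(train_loader, val_loader)
trainer.test(test_loader)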

The class provides the following hook functions for inserting additional implementations in the training, validation and testing lifecycle:

  • before_training_epoch: executed before each epoch of training starts

  • before_training_step: executed before each batch of training starts

  • on_training_step: the training process for each batch

  • after_training_step: executed after the training of each batch

  • after_training_epoch: executed after each epoch of training

  • before_validation_epoch: executed before each round of validation starts

  • before_validation_step: executed before the validation of each batch

  • on_validation_step: validation process for each batch

  • after_validation_step: executed after the validation of each batch

  • after_validation_epoch: executed after each round of validation

  • before_test_epoch: executed before each round of test starts

  • before_test_step: executed before the test of each batch

  • on_test_step: test process for each batch

  • after_test_step: executed after the test of each batch

  • after_test_epoch: executed after each round of test

If you want to customize some operations, you just need to inherit the class and override the hook function:

class MyGANTrainer(GANTrainer):
    def before_training_epoch(self, epoch_id: int, num_epochs: int):
        # Do something here.
        super().before_training_epoch(epoch_id, num_epochs)

If you want to use multiple GPUs for parallel computing, you need to specify the GPU indices you want to use in the Python file:

trainer = GANTrainer(generator, discriminator, device_ids=[1, 2, 7])
trainer.fit(train_loader, val_loader)
trainer.test(test_loader)

Then, you can use torch.distributed.launch or torchrun to run your Python file.

python -m torch.distributed.launch \
    --nproc_per_node=3 \
    --nnodes=1 \
    --node_rank=0 \
    --master_addr="localhost" \
    --master_port=2345 \
    your_python_file.py

Here, nproc_per_node is the number of GPUs you specify.

Parameters
  • generator (nn.Module) – The generator model for EEG signal generation. Its inputs are Gaussian-distributed random vectors, and its outputs are generated EEG signals. The dimension of the input vector should be defined by the in_channels attribute. The output layer does not need a softmax activation function.

  • discriminator (nn.Module) – The discriminator model, which determines whether an EEG signal is real or generated. The dimension of its output should be one (i.e., a single score distinguishing real from fake). The output layer does not need a sigmoid activation function.

  • generator_lr (float) – The learning rate of the generator. (default: 0.0001)

  • discriminator_lr (float) – The learning rate of the discriminator. (default: 0.0001)

  • lambd (float) – The weight of the gradient penalty loss, trading off between the adversarial training loss and the gradient penalty loss. (default: 1.0)

  • weight_decay (float) – The weight decay (L2 penalty). (default: 0.0)

  • device_ids (list) – Use the CPU if the list is empty. If the list contains indices of multiple GPUs, the script needs to be launched with torch.distributed.launch or torchrun. (default: [])

  • ddp_sync_bn (bool) – Whether to replace batch normalization layers in the network with cross-GPU synchronized batch normalization. Only valid when the length of device_ids is greater than one. (default: True)

  • ddp_replace_sampler (bool) – Whether to replace the sampler in the DataLoader with DistributedSampler. Only valid when the length of device_ids is greater than one. (default: True)

  • ddp_val (bool) – Whether to use multi-GPU acceleration for the validation set. For experiments that are sensitive to the order of input data, ddp_val should be set to False. Only valid when the length of device_ids is greater than one. (default: True)

  • ddp_test (bool) – Whether to use multi-GPU acceleration for the test set. For experiments that are sensitive to the order of input data, ddp_test should be set to False. Only valid when the length of device_ids is greater than one. (default: True)
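The lambd parameter weights the gradient penalty of WGAN-GP. The sketch below evaluates the critic and generator objectives for a hypothetical toy critic D(x) = sum(w_i * x_i^2), whose input gradient has the closed form 2 * w_i * x_i, so no autograd library is needed. The penalty is lambd * (||grad D(x_hat)||_2 - 1)^2, evaluated at random interpolations x_hat between real and generated samples. `toy_critic`, `toy_critic_grad`, and the loss helpers are assumptions for illustration. A PyTorch implementation would obtain the gradient with autograd on the real discriminator.

```python
import math
import random
from typing import List, Sequence

WEIGHTS = [0.5, -0.3, 0.8]  # parameters of the hypothetical toy critic


def toy_critic(x: Sequence[float]) -> float:
    """Toy critic D(x) = sum(w_i * x_i ** 2)."""
    return sum(w * v * v for w, v in zip(WEIGHTS, x))


def toy_critic_grad(x: Sequence[float]) -> List[float]:
    """Closed-form input gradient of the toy critic: 2 * w_i * x_i."""
    return [2.0 * w * v for w, v in zip(WEIGHTS, x)]


def gradient_penalty(real: List[Sequence[float]], fake: List[Sequence[float]],
                     rng: random.Random) -> float:
    """Mean of (||grad D(x_hat)||_2 - 1)^2 over random real/fake interpolations."""
    total = 0.0
    for r, f in zip(real, fake):
        eps = rng.random()  # interpolation coefficient in [0, 1)
        x_hat = [eps * a + (1.0 - eps) * b for a, b in zip(r, f)]
        grad_norm = math.sqrt(sum(g * g for g in toy_critic_grad(x_hat)))
        total += (grad_norm - 1.0) ** 2
    return total / len(real)


def critic_loss(real: List[Sequence[float]], fake: List[Sequence[float]],
                lambd: float, rng: random.Random) -> float:
    """WGAN-GP critic objective: E[D(fake)] - E[D(real)] + lambd * GP."""
    mean_real = sum(toy_critic(x) for x in real) / len(real)
    mean_fake = sum(toy_critic(x) for x in fake) / len(fake)
    return mean_fake - mean_real + lambd * gradient_penalty(real, fake, rng)


def generator_loss(fake: List[Sequence[float]]) -> float:
    """The generator tries to raise the critic's score on its samples: -E[D(fake)]."""
    return -sum(toy_critic(x) for x in fake) / len(fake)


real = [[1.0, 0.5, -0.2], [0.8, 0.1, 0.4]]
fake = [[0.1, -0.4, 0.9], [0.0, 0.2, 0.3]]
print(critic_loss(real, fake, lambd=1.0, rng=random.Random(0)), generator_loss(fake))
```

Note that the critic output is an unbounded score rather than a probability. This is consistent with the remark above that the discriminator does not need a sigmoid output layer.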

fit(train_loader: DataLoader, val_loader: DataLoader, num_epochs: int = 1, **kwargs)
Parameters
  • train_loader (DataLoader) – Iterable DataLoader for traversing the training data batch (torch.utils.data.dataloader.DataLoader, torch_geometric.loader.DataLoader, etc).

  • val_loader (DataLoader) – Iterable DataLoader for traversing the validation data batch (torch.utils.data.dataloader.DataLoader, torch_geometric.loader.DataLoader, etc).

  • num_epochs (int) – The number of training epochs. (default: 1)

test(test_loader: DataLoader, **kwargs)
Parameters

test_loader (DataLoader) – Iterable DataLoader for traversing the test data batch (torch.utils.data.dataloader.DataLoader, torch_geometric.loader.DataLoader, etc).

sample(num_samples: int) → Tensor

Sample from the latent space and return the generated results.

Parameters

num_samples (int) – Number of samples.

Returns

The generated samples.

Return type

torch.Tensor

trainers.CGANTrainer

class torcheeg.trainers.CGANTrainer(generator: Module, discriminator: Module, generator_lr: float = 0.0001, discriminator_lr: float = 0.0001, lambd: float = 1.0, weight_decay: float = 0.0, device_ids: List[int] = [], ddp_sync_bn: bool = True, ddp_replace_sampler: bool = True, ddp_val: bool = True, ddp_test: bool = True)

This class provides an implementation of conditional WGAN-GP. It trains the generator and the discriminator in a zero-sum game, as in traditional generative adversarial networks. The generator is optimized to produce simulated samples that the discriminator cannot distinguish from real ones. The discriminator is optimized to tell real samples apart from the fake samples produced by the generator. Compared with the vanilla GAN, WGAN-GP improves training stability, alleviates problems such as mode collapse, and provides meaningful learning curves that are useful for debugging and hyperparameter search. Existing work therefore commonly uses WGAN-GP to generate simulated EEG signals. In addition, the expected class labels are provided to guide the discriminator in judging whether a sample fits the data distribution of its class. For more details, please refer to the original WGAN-GP paper.

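from torcheeg.models import BCGenerator, BCDiscriminator
from torcheeg.trainers import CGANTrainer

g_model = BCGenerator(in_channels=128, num_classes=2)
d_model = BCDiscriminator(in_channels=4, num_classes=2)
# the data loaders are assumed to be defined beforehand
trainer = CGANTrainer(g_model, d_model)
trainer.fit(train_loader, val_loader)
trainer.test(test_loader)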

The class provides the following hook functions for inserting additional implementations in the training, validation and testing lifecycle:

  • before_training_epoch: executed before each epoch of training starts

  • before_training_step: executed before each batch of training starts

  • on_training_step: the training process for each batch

  • after_training_step: executed after the training of each batch

  • after_training_epoch: executed after each epoch of training

  • before_validation_epoch: executed before each round of validation starts

  • before_validation_step: executed before the validation of each batch

  • on_validation_step: validation process for each batch

  • after_validation_step: executed after the validation of each batch

  • after_validation_epoch: executed after each round of validation

  • before_test_epoch: executed before each round of test starts

  • before_test_step: executed before the test of each batch

  • on_test_step: test process for each batch

  • after_test_step: executed after the test of each batch

  • after_test_epoch: executed after each round of test

If you want to customize some operations, you just need to inherit the class and override the hook function:

class MyCGANTrainer(CGANTrainer):
    def before_training_epoch(self, epoch_id: int, num_epochs: int):
        # Do something here.
        super().before_training_epoch(epoch_id, num_epochs)

If you want to use multiple GPUs for parallel computing, you need to specify the GPU indices you want to use in the Python file:

trainer = CGANTrainer(generator, discriminator, device_ids=[1, 2, 7])
trainer.fit(train_loader, val_loader)
trainer.test(test_loader)

Then, you can use torch.distributed.launch or torchrun to run your Python file.

python -m torch.distributed.launch \
    --nproc_per_node=3 \
    --nnodes=1 \
    --node_rank=0 \
    --master_addr="localhost" \
    --master_port=2345 \
    your_python_file.py

Here, nproc_per_node is the number of GPUs you specify.

Parameters
  • generator (nn.Module) – The generator model for EEG signal generation. Its inputs are Gaussian-distributed random vectors, and its outputs are generated EEG signals. The dimension of the input vector should be defined by the in_channels attribute. The output layer does not need a softmax activation function.

  • discriminator (nn.Module) – The discriminator model, which determines whether an EEG signal is real or generated. The dimension of its output should be one (i.e., a single score distinguishing real from fake). The output layer does not need a sigmoid activation function.

  • generator_lr (float) – The learning rate of the generator. (default: 0.0001)

  • discriminator_lr (float) – The learning rate of the discriminator. (default: 0.0001)

  • lambd (float) – The weight of the gradient penalty loss, trading off between the adversarial training loss and the gradient penalty loss. (default: 1.0)

  • weight_decay (float) – The weight decay (L2 penalty). (default: 0.0)

  • device_ids (list) – Use the CPU if the list is empty. If the list contains indices of multiple GPUs, the script needs to be launched with torch.distributed.launch or torchrun. (default: [])

  • ddp_sync_bn (bool) – Whether to replace batch normalization layers in the network with cross-GPU synchronized batch normalization. Only valid when the length of device_ids is greater than one. (default: True)

  • ddp_replace_sampler (bool) – Whether to replace the sampler in the DataLoader with DistributedSampler. Only valid when the length of device_ids is greater than one. (default: True)

  • ddp_val (bool) – Whether to use multi-GPU acceleration for the validation set. For experiments that are sensitive to the order of input data, ddp_val should be set to False. Only valid when the length of device_ids is greater than one. (default: True)

  • ddp_test (bool) – Whether to use multi-GPU acceleration for the test set. For experiments that are sensitive to the order of input data, ddp_test should be set to False. Only valid when the length of device_ids is greater than one. (default: True)

fit(train_loader: DataLoader, val_loader: DataLoader, num_epochs: int = 1, **kwargs)
Parameters
  • train_loader (DataLoader) – Iterable DataLoader for traversing the training data batch (torch.utils.data.dataloader.DataLoader, torch_geometric.loader.DataLoader, etc).

  • val_loader (DataLoader) – Iterable DataLoader for traversing the validation data batch (torch.utils.data.dataloader.DataLoader, torch_geometric.loader.DataLoader, etc).

  • num_epochs (int) – The number of training epochs. (default: 1)

test(test_loader: DataLoader, **kwargs)
Parameters

test_loader (DataLoader) – Iterable DataLoader for traversing the test data batch (torch.utils.data.dataloader.DataLoader, torch_geometric.loader.DataLoader, etc).

sample(num_samples: int, labels: Optional[Tensor] = None) → Tensor

Sample from the latent space and return the generated results.

Parameters
  • num_samples (int) – Number of samples.

  • labels (torch.Tensor) – Category labels (int) for a batch of samples. The shape should be [n,], where n corresponds to the batch size. If not provided, a batch of randomly generated categories will be used.

Returns

The generated samples.

Return type

torch.Tensor
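The label handling described for sample can be sketched as follows. When labels are omitted, class indices are drawn uniformly at random. Each Gaussian latent vector is then paired with its label before being passed to the generator. The `sample_conditional` helper, the standard-normal latent distribution, and the callable generator(z, y) signature are assumptions for illustration. Here, latent_dim plays the role of the generator's in_channels attribute.

```python
import random
from typing import Callable, List, Optional, Sequence, Tuple


def sample_conditional(generator: Callable[[List[float], int], Sequence[float]],
                       num_samples: int, latent_dim: int, num_classes: int,
                       labels: Optional[List[int]] = None,
                       seed: Optional[int] = None) -> Tuple[List[Sequence[float]], List[int]]:
    """Draw Gaussian latent vectors and (optionally random) labels, then generate samples."""
    rng = random.Random(seed)
    if labels is None:
        # No labels given: pick a random category for every sample
        labels = [rng.randrange(num_classes) for _ in range(num_samples)]
    if len(labels) != num_samples:
        raise ValueError("labels must have shape [num_samples,]")
    outputs = []
    for y in labels:
        z = [rng.gauss(0.0, 1.0) for _ in range(latent_dim)]  # z ~ N(0, I)
        outputs.append(generator(z, y))
    return outputs, labels


def toy_generator(z: List[float], y: int) -> List[float]:
    """Hypothetical stand-in generator: shifts the latent vector by the class index."""
    return [v + y for v in z]


samples, used_labels = sample_conditional(toy_generator, num_samples=4, latent_dim=8,
                                          num_classes=2, seed=0)
print(used_labels, len(samples), len(samples[0]))
```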

trainers.VAETrainer

class torcheeg.trainers.VAETrainer(encoder: Module, decoder: Module, lr: float = 0.0001, weight_decay: float = 0.0, beta: float = 1.0, device_ids: List[int] = [], ddp_sync_bn: bool = True, ddp_replace_sampler: bool = True, ddp_val: bool = True, ddp_test: bool = True)

The variational autoencoder consists of two parts: an encoder and a decoder. The encoder compresses the input into the latent space. The decoder receives a vector sampled from the latent space and produces an output as similar as possible to the ground truth. A KL-divergence term pushes the latent distribution toward a Gaussian prior, and sampling relies on the reparameterization trick. This class implements the training, testing, and generation of new EEG signals for variational autoencoders.

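from torcheeg.models import BEncoder, BDecoder
from torcheeg.trainers import VAETrainer

encoder = BEncoder(in_channels=4)
decoder = BDecoder(in_channels=64, out_channels=4)
# the data loaders are assumed to be defined beforehand
trainer = VAETrainer(encoder, decoder)
trainer.fit(train_loader, val_loader)
trainer.test(test_loader)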
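The VAE objective described above can be sketched for a diagonal Gaussian posterior:

  • The reparameterization trick draws z = mu + sigma * eps, with eps ~ N(0, I).

  • The KL divergence to the standard normal prior has the closed form -0.5 * sum(1 + log sigma^2 - mu^2 - sigma^2).

The sketch assumes the encoder outputs the mean and the log-variance (a common parameterization) and uses a mean-squared reconstruction error. `beta` mirrors the constructor parameter. The function names are hypothetical.

```python
import math
import random
from typing import List, Sequence


def reparameterize(mu: Sequence[float], log_var: Sequence[float],
                   rng: random.Random) -> List[float]:
    """z = mu + sigma * eps; the randomness lives in eps, so an autograd framework
    could propagate gradients through mu and sigma."""
    return [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0) for m, lv in zip(mu, log_var)]


def kl_to_standard_normal(mu: Sequence[float], log_var: Sequence[float]) -> float:
    """KL(N(mu, sigma^2) || N(0, 1)), summed over latent dimensions."""
    return -0.5 * sum(1.0 + lv - m * m - math.exp(lv) for m, lv in zip(mu, log_var))


def vae_loss(x: Sequence[float], x_rec: Sequence[float], mu: Sequence[float],
             log_var: Sequence[float], beta: float = 1.0) -> float:
    """Mean squared reconstruction error plus the beta-weighted KL term (beta-VAE)."""
    rec = sum((a - b) ** 2 for a, b in zip(x, x_rec)) / len(x)
    return rec + beta * kl_to_standard_normal(mu, log_var)


rng = random.Random(0)
mu, log_var = [0.5, -0.2], [0.0, -1.0]
z = reparameterize(mu, log_var, rng)
kl = kl_to_standard_normal(mu, log_var)
print(z, kl)
```

Setting beta above 1 weights the KL term more heavily, trading reconstruction quality for a latent space that is closer to the prior.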

The class provides the following hook functions for inserting additional implementations in the training, validation and testing lifecycle:

  • before_training_epoch: executed before each epoch of training starts

  • before_training_step: executed before each batch of training starts

  • on_training_step: the training process for each batch

  • after_training_step: executed after the training of each batch

  • after_training_epoch: executed after each epoch of training

  • before_validation_epoch: executed before each round of validation starts

  • before_validation_step: executed before the validation of each batch

  • on_validation_step: validation process for each batch

  • after_validation_step: executed after the validation of each batch

  • after_validation_epoch: executed after each round of validation

  • before_test_epoch: executed before each round of test starts

  • before_test_step: executed before the test of each batch

  • on_test_step: test process for each batch

  • after_test_step: executed after the test of each batch

  • after_test_epoch: executed after each round of test

If you want to customize some operations, you just need to inherit the class and override the hook function:

class MyVAETrainer(VAETrainer):
    def before_training_epoch(self, epoch_id: int, num_epochs: int):
        # Do something here.
        super().before_training_epoch(epoch_id, num_epochs)

If you want to use multiple GPUs for parallel computing, you need to specify the GPU indices you want to use in the Python file:

trainer = VAETrainer(encoder, decoder, device_ids=[1, 2, 7])
trainer.fit(train_loader, val_loader)
trainer.test(test_loader)

Then, you can use torch.distributed.launch or torchrun to run your Python file.

python -m torch.distributed.launch \
    --nproc_per_node=3 \
    --nnodes=1 \
    --node_rank=0 \
    --master_addr="localhost" \
    --master_port=2345 \
    your_python_file.py

Here, nproc_per_node is the number of GPUs you specify.

Parameters
  • encoder (nn.Module) – The encoder. Its inputs are EEG signals, and its outputs are two batches of vectors of the same dimension, representing the mean and variance estimated for the reparameterization trick.

  • decoder (nn.Module) – The decoder, which generates EEG signals from the latent variables encoded by the encoder. The dimension of its input vector should be defined by the in_channels attribute.

  • lr (float) – The learning rate. (default: 0.0001)

  • weight_decay (float) – The weight decay (L2 penalty). (default: 0.0)

  • beta (float) – The weight of the KL divergence in the loss function. Please refer to β-VAE. (default: 1.0)

  • device_ids (list) – Use the CPU if the list is empty. If the list contains indices of multiple GPUs, the script needs to be launched with torch.distributed.launch or torchrun. (default: [])

  • ddp_sync_bn (bool) – Whether to replace batch normalization layers in the network with cross-GPU synchronized batch normalization. Only valid when the length of device_ids is greater than one. (default: True)

  • ddp_replace_sampler (bool) – Whether to replace the sampler in the DataLoader with DistributedSampler. Only valid when the length of device_ids is greater than one. (default: True)

  • ddp_val (bool) – Whether to use multi-GPU acceleration for the validation set. For experiments that are sensitive to the order of input data, ddp_val should be set to False. Only valid when the length of device_ids is greater than one. (default: True)

  • ddp_test (bool) – Whether to use multi-GPU acceleration for the test set. For experiments that are sensitive to the order of input data, ddp_test should be set to False. Only valid when the length of device_ids is greater than one. (default: True)

fit(train_loader: DataLoader, val_loader: DataLoader, num_epochs: int = 1, **kwargs)

Train the model on the training set and use the validation set to validate the results of each round of training.

Parameters
  • train_loader (DataLoader) – Iterable DataLoader for traversing the training data batch (torch.utils.data.dataloader.DataLoader, torch_geometric.loader.DataLoader, etc).

  • val_loader (DataLoader) – Iterable DataLoader for traversing the validation data batch (torch.utils.data.dataloader.DataLoader, torch_geometric.loader.DataLoader, etc).

  • num_epochs (int) – The number of training epochs. (default: 1)

test(test_loader: DataLoader, **kwargs)

Evaluate the performance of the model on the test set.

Parameters

test_loader (DataLoader) – Iterable DataLoader for traversing the test data batch (torch.utils.data.dataloader.DataLoader, torch_geometric.loader.DataLoader, etc).

sample(num_samples: int) → Tensor

Sample from the latent space and return the generated results.

Parameters

num_samples (int) – Number of samples.

Returns

The generated samples.

Return type

torch.Tensor

trainers.CVAETrainer

class torcheeg.trainers.CVAETrainer(encoder: Module, decoder: Module, lr: float = 0.0001, weight_decay: float = 0.0, beta: float = 1.0, device_ids: List[int] = [], ddp_sync_bn: bool = True, ddp_replace_sampler: bool = True, ddp_val: bool = True, ddp_test: bool = True)

The variational autoencoder consists of two parts: an encoder and a decoder. The encoder compresses the input into the latent space. The decoder receives a vector sampled from the latent space and produces an output as similar as possible to the ground truth. A KL-divergence term pushes the latent distribution toward a Gaussian prior, and sampling relies on the reparameterization trick. This class implements the training, testing, and generation of new EEG signals for conditional variational autoencoders, in which the encoder and decoder additionally receive the class label.

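from torcheeg.models import BCEncoder, BCDecoder
from torcheeg.trainers import CVAETrainer

encoder = BCEncoder(in_channels=4, num_classes=2)
decoder = BCDecoder(in_channels=64, out_channels=4, num_classes=2)
# the data loaders are assumed to be defined beforehand
trainer = CVAETrainer(encoder, decoder)
trainer.fit(train_loader, val_loader)
trainer.test(test_loader)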
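One common way to condition a VAE on class labels is to append a one-hot class code to the encoder input and to the latent vector fed to the decoder. The sketch below shows only that idea. It is an assumption for illustration, not a description of how BCEncoder or BCDecoder are built internally, and the helper names are hypothetical.

```python
from typing import List, Sequence


def one_hot(label: int, num_classes: int) -> List[float]:
    """Encode a class index as a one-hot vector."""
    if not 0 <= label < num_classes:
        raise ValueError("label out of range")
    vec = [0.0] * num_classes
    vec[label] = 1.0
    return vec


def condition(features: Sequence[float], label: int, num_classes: int) -> List[float]:
    """Append the one-hot class code so a network sees both the input and its label."""
    return list(features) + one_hot(label, num_classes)


encoder_input = condition([0.3, -1.2, 0.8], label=1, num_classes=2)  # x conditioned on y
decoder_input = condition([0.05, 0.4], label=1, num_classes=2)       # z conditioned on y
print(encoder_input, decoder_input)
```

At generation time, choosing the label passed to the decoder selects which class of EEG signal is produced.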

The class provides the following hook functions for inserting additional implementations in the training, validation and testing lifecycle:

  • before_training_epoch: executed before each epoch of training starts

  • before_training_step: executed before each batch of training starts

  • on_training_step: the training process for each batch

  • after_training_step: executed after the training of each batch

  • after_training_epoch: executed after each epoch of training

  • before_validation_epoch: executed before each round of validation starts

  • before_validation_step: executed before the validation of each batch

  • on_validation_step: validation process for each batch

  • after_validation_step: executed after the validation of each batch

  • after_validation_epoch: executed after each round of validation

  • before_test_epoch: executed before each round of test starts

  • before_test_step: executed before the test of each batch

  • on_test_step: test process for each batch

  • after_test_step: executed after the test of each batch

  • after_test_epoch: executed after each round of test

If you want to customize some operations, you just need to inherit the class and override the hook function:

class MyCVAETrainer(CVAETrainer):
    def before_training_epoch(self, epoch_id: int, num_epochs: int):
        # Do something here.
        super().before_training_epoch(epoch_id, num_epochs)

If you want to use multiple GPUs for parallel computing, you need to specify the GPU indices you want to use in the Python file:

trainer = CVAETrainer(encoder, decoder, device_ids=[1, 2, 7])
trainer.fit(train_loader, val_loader)
trainer.test(test_loader)

Then, you can use torch.distributed.launch or torchrun to run your Python file.

python -m torch.distributed.launch \
    --nproc_per_node=3 \
    --nnodes=1 \
    --node_rank=0 \
    --master_addr="localhost" \
    --master_port=2345 \
    your_python_file.py

Here, nproc_per_node is the number of GPUs specified in device_ids (three in this example).
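Each process started by the launcher must end up on exactly one of the GPUs in device_ids. The helper below, select_gpu, is a hypothetical sketch of that mapping. It relies on the fact that torchrun exports a LOCAL_RANK environment variable to every worker process; the assumption that torcheeg maps local ranks to device_ids in list order is not confirmed by this page.

```python
import os
from typing import List, Optional


def select_gpu(device_ids: List[int], local_rank: Optional[int] = None) -> int:
    """Hypothetical helper: pick this process's GPU index from device_ids."""
    if local_rank is None:
        # torchrun sets LOCAL_RANK for each worker process it spawns.
        local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    if not 0 <= local_rank < len(device_ids):
        raise ValueError(
            f"local rank {local_rank} has no entry in device_ids={device_ids}; "
            "make sure --nproc_per_node equals len(device_ids)"
        )
    return device_ids[local_rank]


# With device_ids=[1, 2, 7] and --nproc_per_node=3, ranks 0, 1, 2 map to GPUs 1, 2, 7.
print([select_gpu([1, 2, 7], rank) for rank in range(3)])  # [1, 2, 7]
```

The explicit range check turns a mismatch between nproc_per_node and device_ids into an immediate, readable error instead of an index failure deep inside training.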

Parameters
  • encoder (nn.Module) – The encoder. It takes EEG signals as input and outputs two batches of vectors of the same dimension, representing the mean and variance used in the reparameterization trick.

  • decoder (nn.Module) – The decoder, which generates EEG signals from the latent variables produced by the encoder. The dimension of its input vector should be defined by its in_channels attribute.

  • lr (float) – The learning rate. (default: 0.0001)

  • weight_decay (float) – The weight decay (L2 penalty). (default: 0.0)

  • beta (float) – The weight of the KL divergence term in the loss function; see β-VAE. (default: 1.0)

  • device_ids (list) – The indices of the GPUs to use. If the list is empty, the CPU is used. If it contains the indices of multiple GPUs, the script must be launched with torch.distributed.launch or torchrun. (default: [])

  • ddp_sync_bn (bool) – Whether to replace the batch normalization layers in the network with cross-GPU synchronized batch normalization. Only takes effect when device_ids contains more than one index. (default: True)

  • ddp_replace_sampler (bool) – Whether to replace the sampler of each DataLoader with a DistributedSampler. Only takes effect when device_ids contains more than one index. (default: True)

  • ddp_val (bool) – Whether to use multi-GPU acceleration on the validation set. Set it to False for experiments that are sensitive to the order of input data. Only takes effect when device_ids contains more than one index. (default: True)

  • ddp_test (bool) – Whether to use multi-GPU acceleration on the test set. Set it to False for experiments that are sensitive to the order of input data. Only takes effect when device_ids contains more than one index. (default: True)
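The encoder outputs and the beta parameter above combine in the β-VAE objective. The sketch below uses plain Python lists instead of tensors and assumes, as many VAE implementations do, that the encoder's second output is a log-variance and that reconstruction is scored with mean squared error; neither detail is confirmed for CVAETrainer. The function names are hypothetical.

```python
import math
import random
from typing import List, Sequence


def reparameterize(mu: Sequence[float], log_var: Sequence[float], rng: random.Random) -> List[float]:
    # z = mu + sigma * eps with eps ~ N(0, 1); in a tensor framework this keeps
    # the sampling step differentiable with respect to mu and sigma.
    return [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0) for m, lv in zip(mu, log_var)]


def kl_to_standard_normal(mu: Sequence[float], log_var: Sequence[float]) -> float:
    # Closed-form KL(N(mu, sigma^2) || N(0, 1)), summed over latent dimensions.
    return -0.5 * sum(1.0 + lv - m * m - math.exp(lv) for m, lv in zip(mu, log_var))


def beta_vae_loss(
    x: Sequence[float],
    x_rec: Sequence[float],
    mu: Sequence[float],
    log_var: Sequence[float],
    beta: float = 1.0,
) -> float:
    # Reconstruction term (MSE) plus the KL term weighted by beta.
    recon = sum((a - b) ** 2 for a, b in zip(x, x_rec)) / len(x)
    return recon + beta * kl_to_standard_normal(mu, log_var)


print(kl_to_standard_normal([1.0], [0.0]))  # 0.5
```

With beta = 1.0 this reduces to the ordinary VAE loss; larger values push the latent codes harder toward the standard normal prior.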

fit(train_loader: DataLoader, val_loader: DataLoader, num_epochs: int = 1, **kwargs)

Train the model on the training set and use the validation set to evaluate the results of each training epoch.

Parameters
  • train_loader (DataLoader) – An iterable DataLoader that yields batches of training data (e.g., torch.utils.data.dataloader.DataLoader or torch_geometric.loader.DataLoader).

  • val_loader (DataLoader) – An iterable DataLoader that yields batches of validation data (e.g., torch.utils.data.dataloader.DataLoader or torch_geometric.loader.DataLoader).

  • num_epochs (int) – The number of training epochs. (default: 1)

test(test_loader: DataLoader, **kwargs)

Evaluate the performance of the model on the test set.

Parameters

test_loader (DataLoader) – An iterable DataLoader that yields batches of test data (e.g., torch.utils.data.dataloader.DataLoader or torch_geometric.loader.DataLoader).

sample(num_samples: int, labels: Optional[Tensor] = None) → Tensor

Sample from the latent space and return the generated results.

Parameters
  • num_samples (int) – The number of samples.

  • labels (torch.Tensor) – Category labels (int) for a batch of samples. The shape should be [n,], where n corresponds to the batch size. If not provided, a batch of randomly generated categories is used.

Returns

The generated samples.

Return type

torch.Tensor
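Conceptually, sampling from a conditional VAE means drawing latent vectors from the standard normal prior and decoding them together with class labels, drawing random labels when none are given. In the sketch below, the decode callable, latent_dim (assumed to match the decoder's in_channels), and num_classes are hypothetical stand-ins for the decoder module and its attributes, and the requirement that labels has length num_samples is an assumption based on the shape [n,] described above.

```python
import random
from typing import Callable, List, Optional, Sequence, TypeVar

T = TypeVar("T")


def sample_cvae(
    decode: Callable[[List[float], int], T],
    num_samples: int,
    latent_dim: int,
    num_classes: int,
    labels: Optional[Sequence[int]] = None,
    seed: Optional[int] = None,
) -> List[T]:
    rng = random.Random(seed)
    # Latent codes come from the prior N(0, I) rather than from the encoder.
    z = [[rng.gauss(0.0, 1.0) for _ in range(latent_dim)] for _ in range(num_samples)]
    if labels is None:
        # Mirrors the documented default: random categories when labels are omitted.
        labels = [rng.randrange(num_classes) for _ in range(num_samples)]
    if len(labels) != num_samples:
        raise ValueError("labels must have shape [num_samples,]")
    return [decode(zi, yi) for zi, yi in zip(z, labels)]


# Toy decoder that just reports the latent size and label it received.
out = sample_cvae(lambda z, y: (len(z), y), num_samples=3, latent_dim=64, num_classes=2, labels=[0, 1, 1])
print(out)  # [(64, 0), (64, 1), (64, 1)]
```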

trainers.DDPMTrainer

class torcheeg.trainers.DDPMTrainer(unet: Module, lr: float = 0.0003, beta_timesteps: int = 1000, beta_start: float = 0.0001, beta_end: float = 0.02, device_ids: List[int] = [], ddp_sync_bn: bool = True, ddp_replace_sampler: bool = True, ddp_val: bool = True, ddp_test: bool = True)

The diffusion model consists of two processes: the forward process and the backward process. The forward process gradually adds Gaussian noise to an image until it becomes random noise, while the backward process removes that noise. An attention-based UNet is trained to model the backward process: starting from random noise, it gradually denoises its input until an image is generated, so the trained UNet can generate simulated images from random noise. This class implements training, testing, and new-sample inference for DDPM.

unet = BUNet(in_channels=4)
trainer = DDPMTrainer(unet)
trainer.fit(train_loader, val_loader)
trainer.test(test_loader)
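The forward process does not have to be simulated one step at a time: under the standard DDPM formulation, a noisy sample at step t can be drawn directly from the clean input. The function below is a minimal, framework-free sketch of that closed form; q_sample and its arguments are illustrative names, and alpha_bar_t denotes the cumulative product of (1 - beta) up to step t.

```python
import math
from typing import List, Sequence


def q_sample(x0: Sequence[float], alpha_bar_t: float, noise: Sequence[float]) -> List[float]:
    """Forward diffusion in one jump: x_t = sqrt(a) * x0 + sqrt(1 - a) * noise, a = alpha_bar_t."""
    if not 0.0 <= alpha_bar_t <= 1.0:
        raise ValueError("alpha_bar_t must lie in [0, 1]")
    keep = math.sqrt(alpha_bar_t)        # share of the clean signal that survives
    add = math.sqrt(1.0 - alpha_bar_t)   # share of Gaussian noise mixed in
    return [keep * x + add * e for x, e in zip(x0, noise)]


print(q_sample([1.0, 2.0], 1.0, [5.0, -5.0]))  # [1.0, 2.0]: no noise yet
print(q_sample([1.0, 2.0], 0.0, [5.0, -5.0]))  # [5.0, -5.0]: pure noise
```

Because of this shortcut, training can pick a random t for each batch instead of running the whole noising chain.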

The class provides the following hook functions for inserting additional implementations in the training, validation and testing lifecycle:

  • before_training_epoch: executed before each epoch of training starts

  • before_training_step: executed before each batch of training starts

  • on_training_step: the training process for each batch

  • after_training_step: executed after the training of each batch

  • after_training_epoch: executed after each epoch of training

  • before_validation_epoch: executed before each round of validation starts

  • before_validation_step: executed before the validation of each batch

  • on_validation_step: validation process for each batch

  • after_validation_step: executed after the validation of each batch

  • after_validation_epoch: executed after each round of validation

  • before_test_epoch: executed before each round of testing starts

  • before_test_step: executed before the test of each batch

  • on_test_step: test process for each batch

  • after_test_step: executed after the test of each batch

  • after_test_epoch: executed after each round of testing

To customize specific operations, inherit from the class and override the corresponding hook function:

class MyDDPMTrainer(DDPMTrainer):
    def before_training_epoch(self, epoch_id: int, num_epochs: int):
        # Do something here.
        super().before_training_epoch(epoch_id, num_epochs)

To use multiple GPUs for parallel computing, specify the indices of the GPUs to use in your Python file:

trainer = DDPMTrainer(unet, device_ids=[1, 2, 7])
trainer.fit(train_loader, val_loader)
trainer.test(test_loader)

Then, launch the Python file with torch.distributed.launch or torchrun:

python -m torch.distributed.launch \
    --nproc_per_node=3 \
    --nnodes=1 \
    --node_rank=0 \
    --master_addr="localhost" \
    --master_port=2345 \
    your_python_file.py

Here, nproc_per_node is the number of GPUs specified in device_ids (three in this example).

Parameters
  • unet (nn.Module) – The attention-based UNet that learns the backward (denoising) process.

  • lr (float) – The learning rate. (default: 0.0003)

  • beta_timesteps (int) – The number of timesteps of the linear beta scheduler, i.e., the number of diffusion steps. (default: 1000)

  • beta_start (float) – The start point of the linear beta scheduler used to sample betas. (default: 1e-4)

  • beta_end (float) – The end point of the linear beta scheduler used to sample betas. (default: 2e-2)

  • device_ids (list) – The indices of the GPUs to use. If the list is empty, the CPU is used. If it contains the indices of multiple GPUs, the script must be launched with torch.distributed.launch or torchrun. (default: [])

  • ddp_sync_bn (bool) – Whether to replace the batch normalization layers in the network with cross-GPU synchronized batch normalization. Only takes effect when device_ids contains more than one index. (default: True)

  • ddp_replace_sampler (bool) – Whether to replace the sampler of each DataLoader with a DistributedSampler. Only takes effect when device_ids contains more than one index. (default: True)

  • ddp_val (bool) – Whether to use multi-GPU acceleration on the validation set. Set it to False for experiments that are sensitive to the order of input data. Only takes effect when device_ids contains more than one index. (default: True)

  • ddp_test (bool) – Whether to use multi-GPU acceleration on the test set. Set it to False for experiments that are sensitive to the order of input data. Only takes effect when device_ids contains more than one index. (default: True)
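beta_start, beta_end, and beta_timesteps together define a linear noise schedule. A minimal sketch follows, assuming the betas are evenly spaced between the two endpoints (as "linear" suggests) and that the cumulative products alpha_bar are derived from them as in standard DDPM; linear_beta_schedule is a hypothetical name.

```python
from typing import List, Tuple


def linear_beta_schedule(
    beta_start: float = 1e-4, beta_end: float = 2e-2, beta_timesteps: int = 1000
) -> Tuple[List[float], List[float]]:
    if beta_timesteps < 2:
        raise ValueError("beta_timesteps must be at least 2")
    step = (beta_end - beta_start) / (beta_timesteps - 1)
    # Evenly spaced noise levels from beta_start to beta_end (like torch.linspace).
    betas = [beta_start + i * step for i in range(beta_timesteps)]
    # alpha_bar_t = prod_{s <= t} (1 - beta_s): the share of signal left after t steps.
    alpha_bars: List[float] = []
    running = 1.0
    for beta in betas:
        running *= 1.0 - beta
        alpha_bars.append(running)
    return betas, alpha_bars


betas, alpha_bars = linear_beta_schedule()
print(len(betas), betas[0], round(betas[-1], 6))  # 1000 0.0001 0.02
```

With the defaults, alpha_bar at the last step falls well below 1e-3, so the final noisy sample is essentially pure Gaussian noise.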

fit(train_loader: DataLoader, val_loader: DataLoader, num_epochs: int = 1, **kwargs)

Train the model on the training set and use the validation set to evaluate the results of each training epoch.

Parameters
  • train_loader (DataLoader) – An iterable DataLoader that yields batches of training data (e.g., torch.utils.data.dataloader.DataLoader or torch_geometric.loader.DataLoader).

  • val_loader (DataLoader) – An iterable DataLoader that yields batches of validation data (e.g., torch.utils.data.dataloader.DataLoader or torch_geometric.loader.DataLoader).

  • num_epochs (int) – The number of training epochs. (default: 1)

test(test_loader: DataLoader, **kwargs)

Evaluate the performance of the model on the test set.

Parameters

test_loader (DataLoader) – An iterable DataLoader that yields batches of test data (e.g., torch.utils.data.dataloader.DataLoader or torch_geometric.loader.DataLoader).

sample(num_samples: int, sample_size: Tuple[int]) → Tensor

Sample from the latent space and return the generated results.

Parameters
  • num_samples (int) – The number of samples.

  • sample_size (tuple) – The shape of a single sample.

Returns

The generated samples.

Return type

torch.Tensor
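For intuition about what sampling from a DDPM involves, here is a framework-free sketch of the standard DDPM ancestral sampling loop: start from Gaussian noise and repeatedly remove the noise predicted by the network. predict_noise is a hypothetical stand-in for the trained UNet, each sample is a flat list of prod(sample_size) values, and the variance choice sigma_t^2 = beta_t is one common option, not necessarily the one torcheeg uses.

```python
import math
import random
from typing import Callable, List, Optional, Sequence

NoisePredictor = Callable[[List[float], int], List[float]]


def ddpm_sample(
    predict_noise: NoisePredictor,
    num_samples: int,
    sample_size: Sequence[int],
    beta_start: float = 1e-4,
    beta_end: float = 2e-2,
    beta_timesteps: int = 1000,
    seed: Optional[int] = None,
) -> List[List[float]]:
    if beta_timesteps < 2:
        raise ValueError("beta_timesteps must be at least 2")
    rng = random.Random(seed)
    step = (beta_end - beta_start) / (beta_timesteps - 1)
    betas = [beta_start + i * step for i in range(beta_timesteps)]
    alpha_bars: List[float] = []
    running = 1.0
    for beta in betas:
        running *= 1.0 - beta
        alpha_bars.append(running)

    dim = math.prod(sample_size)
    samples = []
    for _ in range(num_samples):
        x = [rng.gauss(0.0, 1.0) for _ in range(dim)]  # x_T ~ N(0, I)
        for t in reversed(range(beta_timesteps)):
            beta, alpha_bar = betas[t], alpha_bars[t]
            eps = predict_noise(x, t)
            coef = beta / math.sqrt(1.0 - alpha_bar)
            # Estimated mean of x_{t-1}: (x_t - coef * eps) / sqrt(alpha_t), alpha_t = 1 - beta_t.
            mean = [(xi - coef * ei) / math.sqrt(1.0 - beta) for xi, ei in zip(x, eps)]
            # Fresh noise on every step except the final one.
            x = [m + math.sqrt(beta) * rng.gauss(0.0, 1.0) for m in mean] if t > 0 else mean
        samples.append(x)
    return samples


# A dummy predictor that always returns zero noise, with a short 10-step schedule.
out = ddpm_sample(lambda x, t: [0.0] * len(x), num_samples=2, sample_size=(4, 8), beta_timesteps=10, seed=0)
print(len(out), len(out[0]))  # 2 32
```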

trainers.CDDPMTrainer

class torcheeg.trainers.CDDPMTrainer(unet: Module, lr: float = 0.0003, beta_timesteps: int = 1000, beta_start: float = 0.0001, beta_end: float = 0.02, device_ids: List[int] = [], ddp_sync_bn: bool = True, ddp_replace_sampler: bool = True, ddp_val: bool = True, ddp_test: bool = True)

The diffusion model consists of two processes: the forward process and the backward process. The forward process gradually adds Gaussian noise to an image until it becomes random noise, while the backward process removes that noise. An attention-based UNet is trained to model the backward process: starting from random noise, it gradually denoises its input until an image is generated, so the trained UNet can generate simulated images from random noise. In particular, the conditional UNet receives additional label information that guides the results of the denoising process. This class implements training, testing, and new-sample inference for the conditional DDPM.

unet = BCUNet(in_channels=4, num_classes=2)
trainer = CDDPMTrainer(unet)
trainer.fit(train_loader, val_loader)
trainer.test(test_loader)
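The label conditioning described above typically enters training through the noise-prediction objective: the denoiser receives the class label alongside the noisy input and the timestep. The function below is a hypothetical, framework-free sketch of one term of that objective under the standard DDPM formulation; predict_noise stands in for a conditional UNet such as BCUNet, whose real call signature may differ.

```python
import math
import random
from typing import Callable, List, Sequence

CondNoisePredictor = Callable[[List[float], int, int], List[float]]


def conditional_noise_loss(
    x0: Sequence[float],
    label: int,
    t: int,
    alpha_bars: Sequence[float],
    predict_noise: CondNoisePredictor,
    rng: random.Random,
) -> float:
    eps = [rng.gauss(0.0, 1.0) for _ in x0]
    a = alpha_bars[t]
    # Noise the clean sample straight to step t (closed-form forward process).
    x_t = [math.sqrt(a) * x + math.sqrt(1.0 - a) * e for x, e in zip(x0, eps)]
    # The label is the extra input that steers denoising toward one class.
    eps_hat = predict_noise(x_t, t, label)
    # Mean squared error between the true and the predicted noise.
    return sum((e - h) ** 2 for e, h in zip(eps, eps_hat)) / len(eps)


def echo(x_t: List[float], t: int, y: int) -> List[float]:
    # With alpha_bar = 0, x_t is pure noise, so echoing the input is a perfect prediction.
    return list(x_t)


print(conditional_noise_loss([1.0, 2.0], 1, 0, [0.0], echo, random.Random(0)))  # 0.0
```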

The class provides the following hook functions for inserting additional implementations in the training, validation and testing lifecycle:

  • before_training_epoch: executed before each epoch of training starts

  • before_training_step: executed before each batch of training starts

  • on_training_step: the training process for each batch

  • after_training_step: executed after the training of each batch

  • after_training_epoch: executed after each epoch of training

  • before_validation_epoch: executed before each round of validation starts

  • before_validation_step: executed before the validation of each batch

  • on_validation_step: validation process for each batch

  • after_validation_step: executed after the validation of each batch

  • after_validation_epoch: executed after each round of validation

  • before_test_epoch: executed before each round of testing starts

  • before_test_step: executed before the test of each batch

  • on_test_step: test process for each batch

  • after_test_step: executed after the test of each batch

  • after_test_epoch: executed after each round of testing

To customize specific operations, inherit from the class and override the corresponding hook function:

class MyCDDPMTrainer(CDDPMTrainer):
    def before_training_epoch(self, epoch_id: int, num_epochs: int):
        # Do something here.
        super().before_training_epoch(epoch_id, num_epochs)

To use multiple GPUs for parallel computing, specify the indices of the GPUs to use in your Python file:

trainer = CDDPMTrainer(unet, device_ids=[1, 2, 7])
trainer.fit(train_loader, val_loader)
trainer.test(test_loader)

Then, launch the Python file with torch.distributed.launch or torchrun:

python -m torch.distributed.launch \
    --nproc_per_node=3 \
    --nnodes=1 \
    --node_rank=0 \
    --master_addr="localhost" \
    --master_port=2345 \
    your_python_file.py

Here, nproc_per_node is the number of GPUs specified in device_ids (three in this example).

Parameters
  • unet (nn.Module) – The conditional, attention-based UNet that learns the backward (denoising) process, using label information to guide denoising.

  • lr (float) – The learning rate. (default: 0.0003)

  • beta_timesteps (int) – The number of timesteps of the linear beta scheduler, i.e., the number of diffusion steps. (default: 1000)

  • beta_start (float) – The start point of the linear beta scheduler used to sample betas. (default: 1e-4)

  • beta_end (float) – The end point of the linear beta scheduler used to sample betas. (default: 2e-2)

  • device_ids (list) – The indices of the GPUs to use. If the list is empty, the CPU is used. If it contains the indices of multiple GPUs, the script must be launched with torch.distributed.launch or torchrun. (default: [])

  • ddp_sync_bn (bool) – Whether to replace the batch normalization layers in the network with cross-GPU synchronized batch normalization. Only takes effect when device_ids contains more than one index. (default: True)

  • ddp_replace_sampler (bool) – Whether to replace the sampler of each DataLoader with a DistributedSampler. Only takes effect when device_ids contains more than one index. (default: True)

  • ddp_val (bool) – Whether to use multi-GPU acceleration on the validation set. Set it to False for experiments that are sensitive to the order of input data. Only takes effect when device_ids contains more than one index. (default: True)

  • ddp_test (bool) – Whether to use multi-GPU acceleration on the test set. Set it to False for experiments that are sensitive to the order of input data. Only takes effect when device_ids contains more than one index. (default: True)

fit(train_loader: DataLoader, val_loader: DataLoader, num_epochs: int = 1, **kwargs)

Train the model on the training set and use the validation set to evaluate the results of each training epoch.

Parameters
  • train_loader (DataLoader) – An iterable DataLoader that yields batches of training data (e.g., torch.utils.data.dataloader.DataLoader or torch_geometric.loader.DataLoader).

  • val_loader (DataLoader) – An iterable DataLoader that yields batches of validation data (e.g., torch.utils.data.dataloader.DataLoader or torch_geometric.loader.DataLoader).

  • num_epochs (int) – The number of training epochs. (default: 1)

test(test_loader: DataLoader, **kwargs)

Evaluate the performance of the model on the test set.

Parameters

test_loader (DataLoader) – An iterable DataLoader that yields batches of test data (e.g., torch.utils.data.dataloader.DataLoader or torch_geometric.loader.DataLoader).

sample(num_samples: int, sample_size: Tuple[int], labels: Optional[Tensor] = None) → Tensor

Sample from the latent space and return the generated results.

Parameters
  • num_samples (int) – The number of samples.

  • sample_size (tuple) – The shape of a single sample.

  • labels (torch.Tensor) – Category labels (int) for a batch of samples. The shape should be [n,], where n corresponds to the batch size. If not provided, a batch of randomly generated categories is used.

Returns

The generated samples.

Return type

torch.Tensor

trainers.GlowTrainer

class torcheeg.trainers.GlowTrainer(glow: Module, lr: float = 0.0001, grad_norm_clip: float = 50.0, loss_scale: float = 0.001, device_ids: List[int] = [], ddp_sync_bn: bool = True, ddp_replace_sampler: bool = True, ddp_val: bool = True, ddp_test: bool = True)

This class implements training, testing, and the generation of new EEG signals for normalizing flow-based models. Glow trains an encoder that maps the input to a latent variable and makes this latent variable follow the standard normal distribution. By design, the encoder is invertible, so once it is trained, its inverse serves as the corresponding decoder and can generate samples from a Gaussian distribution. In particular, compared with vanilla normalizing flow-based models, Glow is an easy-to-use flow-based model that replaces the permutation of channel axes with an invertible 1x1 convolution.

Below is a recommended suite for use in EEG generation:

model = BGlow(in_channels=4)
trainer = GlowTrainer(model)
trainer.fit(train_loader, val_loader)
trainer.test(test_loader)

Below is a recommended suite for use in conditional EEG generation:

model = BGlow(in_channels=4, num_classes=2)
trainer = GlowTrainer(model)
trainer.fit(train_loader, val_loader)
trainer.test(test_loader)
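The invertible 1x1 convolution mentioned above mixes channels at every position with one learned square matrix W, so it can be undone exactly with the inverse of W, and its contribution to the log-likelihood is the number of positions times log|det W|. The sketch below restricts itself to two channels so the inverse has a closed form; the function names are illustrative and not part of torcheeg.

```python
import math
from typing import List, Sequence, Tuple

Matrix = Sequence[Sequence[float]]


def _det(w: Matrix) -> float:
    return w[0][0] * w[1][1] - w[0][1] * w[1][0]


def invconv_forward(x: Sequence[Sequence[float]], w: Matrix) -> Tuple[List[List[float]], float]:
    """Apply the same 2x2 channel-mixing matrix at every position of a 2-channel signal."""
    det = _det(w)
    if det == 0.0:
        raise ValueError("W must be invertible")
    y = [[w[0][0] * a + w[0][1] * b, w[1][0] * a + w[1][1] * b] for a, b in x]
    # Each position contributes log|det W| to the log-determinant of the Jacobian.
    logdet = len(x) * math.log(abs(det))
    return y, logdet


def invconv_inverse(y: Sequence[Sequence[float]], w: Matrix) -> List[List[float]]:
    """Undo invconv_forward using the closed-form inverse of a 2x2 matrix."""
    det = _det(w)
    inv = [[w[1][1] / det, -w[0][1] / det], [-w[1][0] / det, w[0][0] / det]]
    return [[inv[0][0] * a + inv[0][1] * b, inv[1][0] * a + inv[1][1] * b] for a, b in y]


w = [[2.0, 1.0], [1.0, 1.0]]           # det = 1, so logdet = 0
x = [[1.0, 2.0], [3.0, -1.0]]          # two positions, two channels each
y, logdet = invconv_forward(x, w)
print(invconv_inverse(y, w), logdet)   # [[1.0, 2.0], [3.0, -1.0]] 0.0
```

A fixed channel permutation is the special case where W is a permutation matrix (|det W| = 1); learning W lets the model choose how channels are mixed.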

The class provides the following hook functions for inserting additional implementations in the training, validation and testing lifecycle:

  • before_training_epoch: executed before each epoch of training starts

  • before_training_step: executed before each batch of training starts

  • on_training_step: the training process for each batch

  • after_training_step: executed after the training of each batch

  • after_training_epoch: executed after each epoch of training

  • before_validation_epoch: executed before each round of validation starts

  • before_validation_step: executed before the validation of each batch

  • on_validation_step: validation process for each batch

  • after_validation_step: executed after the validation of each batch

  • after_validation_epoch: executed after each round of validation

  • before_test_epoch: executed before each round of testing starts

  • before_test_step: executed before the test of each batch

  • on_test_step: test process for each batch

  • after_test_step: executed after the test of each batch

  • after_test_epoch: executed after each round of testing

To customize specific operations, inherit from the class and override the corresponding hook function:

class MyGlowTrainer(GlowTrainer):
    def before_training_epoch(self, epoch_id: int, num_epochs: int):
        # Do something here.
        super().before_training_epoch(epoch_id, num_epochs)

To use multiple GPUs for parallel computing, specify the indices of the GPUs to use in your Python file:

trainer = GlowTrainer(model, device_ids=[1, 2, 7])
trainer.fit(train_loader, val_loader)
trainer.test(test_loader)

Then, launch the Python file with torch.distributed.launch or torchrun:

python -m torch.distributed.launch \
    --nproc_per_node=3 \
    --nnodes=1 \
    --node_rank=0 \
    --master_addr="localhost" \
    --master_port=2345 \
    your_python_file.py

Here, nproc_per_node is the number of GPUs specified in device_ids (three in this example).

Parameters
  • glow (nn.Module) – The Glow model: an invertible network that maps EEG signals to latent variables following a standard normal distribution.

  • lr (float) – The learning rate. (default: 0.0001)

  • grad_norm_clip (float) – The maximum gradient norm used for gradient clipping. (default: 50.0)

  • loss_scale (float) – A scaling coefficient used in the training loss. (default: 0.001)

  • device_ids (list) – The indices of the GPUs to use. If the list is empty, the CPU is used. If it contains the indices of multiple GPUs, the script must be launched with torch.distributed.launch or torchrun. (default: [])

  • ddp_sync_bn (bool) – Whether to replace the batch normalization layers in the network with cross-GPU synchronized batch normalization. Only takes effect when device_ids contains more than one index. (default: True)

  • ddp_replace_sampler (bool) – Whether to replace the sampler of each DataLoader with a DistributedSampler. Only takes effect when device_ids contains more than one index. (default: True)

  • ddp_val (bool) – Whether to use multi-GPU acceleration on the validation set. Set it to False for experiments that are sensitive to the order of input data. Only takes effect when device_ids contains more than one index. (default: True)

  • ddp_test (bool) – Whether to use multi-GPU acceleration on the test set. Set it to False for experiments that are sensitive to the order of input data. Only takes effect when device_ids contains more than one index. (default: True)
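The name grad_norm_clip suggests norm-based gradient clipping, as performed in PyTorch by torch.nn.utils.clip_grad_norm_. Assuming the trainer uses the parameter that way, the rule is: compute the L2 norm over all gradient entries and, if it exceeds the limit, rescale every entry by the same factor. The sketch below reproduces that rule on a plain list; clip_by_global_norm is a hypothetical name.

```python
import math
from typing import List, Sequence, Tuple


def clip_by_global_norm(
    grads: Sequence[float], max_norm: float, eps: float = 1e-6
) -> Tuple[List[float], float]:
    # L2 norm over all gradient entries, treated as one long vector.
    total_norm = math.sqrt(sum(g * g for g in grads))
    clip_coef = max_norm / (total_norm + eps)
    if clip_coef < 1.0:
        # Shrink every entry by the same factor, preserving the gradient direction.
        return [g * clip_coef for g in grads], total_norm
    return list(grads), total_norm


clipped, norm = clip_by_global_norm([30.0, 40.0], max_norm=5.0)
print(norm, [round(g, 4) for g in clipped])  # 50.0 [3.0, 4.0]
```

Scaling all entries together, rather than clamping each one, keeps the update pointing in the same direction while bounding its size, which helps with the occasional large gradients that flow models can produce.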

fit(train_loader: DataLoader, val_loader: DataLoader, num_epochs: int = 1, **kwargs)

Train the model on the training set and use the validation set to evaluate the results of each training epoch.

Parameters
  • train_loader (DataLoader) – An iterable DataLoader that yields batches of training data (e.g., torch.utils.data.dataloader.DataLoader or torch_geometric.loader.DataLoader).

  • val_loader (DataLoader) – An iterable DataLoader that yields batches of validation data (e.g., torch.utils.data.dataloader.DataLoader or torch_geometric.loader.DataLoader).

  • num_epochs (int) – The number of training epochs. (default: 1)

test(test_loader: DataLoader, **kwargs)

Evaluate the performance of the model on the test set.

Parameters

test_loader (DataLoader) – An iterable DataLoader that yields batches of test data (e.g., torch.utils.data.dataloader.DataLoader or torch_geometric.loader.DataLoader).

sample(num_samples: int, temperature: float = 1.0, labels: Optional[Tensor] = None) → Tensor

Sample from the latent space and return the generated results.

Parameters
  • num_samples (int) – The number of samples.

  • temperature (float) – The temperature hyperparameter used when sampling from the Gaussian distribution. (default: 1.0)

  • labels (torch.Tensor) – Category labels (int) for a batch of samples. The shape should be [n,], where n corresponds to the batch size. If not provided, a batch of randomly generated categories is used.

Returns

The generated samples.

Return type

torch.Tensor
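The temperature parameter scales the spread of the Gaussian from which latent variables are drawn before they pass through the inverse of the flow; lowering it is commonly used to trade sample diversity for sample quality. The hypothetical helper below shows only that latent-sampling step, assuming latents are drawn as temperature times a standard normal sample (a common Glow convention); decoding with the trained model is omitted.

```python
import random
from typing import List, Optional


def sample_latent(
    num_samples: int, latent_dim: int, temperature: float = 1.0, seed: Optional[int] = None
) -> List[List[float]]:
    if temperature < 0.0:
        raise ValueError("temperature must be non-negative")
    rng = random.Random(seed)
    # temperature * N(0, 1) is distributed as N(0, temperature^2).
    return [[temperature * rng.gauss(0.0, 1.0) for _ in range(latent_dim)] for _ in range(num_samples)]


hot = sample_latent(1, 4, temperature=1.0, seed=0)
cool = sample_latent(1, 4, temperature=0.5, seed=0)
print(all(abs(c) <= abs(h) for c, h in zip(cool[0], hot[0])))  # True
```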
