torcheeg.trainers
A collection of trainers that implement different training strategies, such as vanilla classification and domain adaptation.
trainers.BasicTrainer
- class torcheeg.trainers.BasicTrainer(modules: Dict, device_ids: List[int] = [], ddp_sync_bn: bool = True, ddp_replace_sampler: bool = True, ddp_val: bool = True, ddp_test: bool = True)
A generic trainer class for EEG analysis. It provides the interfaces that all trainers share for the steps common to training deep learning models. A class that inherits from it must implement on_training_step, on_validation_step, and on_test_step.

The class provides the following hook functions for inserting custom logic into the training, validation, and test lifecycle:
- before_training_epoch: executed before each training epoch starts
- before_training_step: executed before each training batch starts
- on_training_step: the training process for each batch
- after_training_step: executed after the training of each batch
- after_training_epoch: executed after each training epoch
- before_validation_epoch: executed before each round of validation starts
- before_validation_step: executed before the validation of each batch
- on_validation_step: the validation process for each batch
- after_validation_step: executed after the validation of each batch
- after_validation_epoch: executed after each round of validation
- before_test_epoch: executed before each round of testing starts
- before_test_step: executed before the test of each batch
- on_test_step: the test process for each batch
- after_test_step: executed after the test of each batch
- after_test_epoch: executed after each round of testing
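The plain-Python sketch below shows the order in which these hooks fire. It records hook names instead of doing any work.

This is a hypothetical simplification, not torcheeg's implementation:
- LifecycleSketch and its methods are illustrative names.
- It assumes that one validation round follows every training epoch inside fit.

```python
from typing import List, Sequence


class LifecycleSketch:
    """Hypothetical model of the hook order; records names, trains nothing."""

    def __init__(self) -> None:
        self.calls: List[str] = []

    def _run_phase(self, phase: str, batches: Sequence) -> None:
        # Epoch-level hooks wrap the per-batch hooks of one phase.
        self.calls.append(f"before_{phase}_epoch")
        for _ in batches:
            self.calls.append(f"before_{phase}_step")
            self.calls.append(f"on_{phase}_step")
            self.calls.append(f"after_{phase}_step")
        self.calls.append(f"after_{phase}_epoch")

    def fit(self, train_batches: Sequence, val_batches: Sequence, num_epochs: int = 1) -> None:
        for _ in range(num_epochs):
            self._run_phase("training", train_batches)
            # Assumption: one validation round runs after every training epoch.
            self._run_phase("validation", val_batches)

    def test(self, test_batches: Sequence) -> None:
        self._run_phase("test", test_batches)


trainer = LifecycleSketch()
trainer.fit(train_batches=["batch0", "batch1"], val_batches=["batch0"], num_epochs=1)
print(trainer.calls)
```

In a real subclass, overriding a hook adds behavior at one of these recorded positions.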
To use multiple GPUs for parallel computing, specify the indices of the GPUs in your Python file:

    trainer = BasicTrainer(model, device_ids=[1, 2, 7])
    trainer.fit(train_loader, val_loader)
    trainer.test(test_loader)

Then launch the Python file with torch.distributed.launch or torchrun:

    python -m torch.distributed.launch \
        --nproc_per_node=3 \
        --nnodes=1 \
        --node_rank=0 \
        --master_addr="localhost" \
        --master_port=2345 \
        your_python_file.py

Here, nproc_per_node is the number of GPUs you specify.

- Parameters
  - modules (Dict) – A dictionary that stores the neural networks, used for importing, exporting, and moving them between devices.
  - device_ids (list) – Use the CPU if the list is empty. If the list contains the indices of multiple GPUs, the script must be launched with torch.distributed.launch or torchrun. (default: [])
  - ddp_sync_bn (bool) – Whether to replace the batch normalization layers in the network with cross-GPU synchronized batch normalization. Only takes effect when device_ids contains more than one GPU. (default: True)
  - ddp_replace_sampler (bool) – Whether to replace the sampler of each DataLoader with DistributedSampler. Only takes effect when device_ids contains more than one GPU. (default: True)
  - ddp_val (bool) – Whether to use multi-GPU acceleration for the validation set. For experiments that are sensitive to the order of the input data, set ddp_val to False. Only takes effect when device_ids contains more than one GPU. (default: True)
  - ddp_test (bool) – Whether to use multi-GPU acceleration for the test set. For experiments that are sensitive to the order of the input data, set ddp_test to False. Only takes effect when device_ids contains more than one GPU. (default: True)
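The advice to disable ddp_val or ddp_test for order-sensitive experiments follows from how distributed sampling splits a dataset.

The sketch below is a simplified re-implementation of how torch.utils.data.DistributedSampler assigns indices when shuffle=False and drop_last=False. It assumes that the sampler installed by ddp_replace_sampler behaves this way; the function name ddp_partition is hypothetical. Two properties matter here:
- Each process receives an interleaved subset of the indices, not a contiguous block.
- The indices are padded by repetition so that every process gets the same number of samples.

```python
import math
from typing import List


def ddp_partition(num_samples: int, num_replicas: int, rank: int) -> List[int]:
    """Indices seen by one process, mimicking DistributedSampler(shuffle=False, drop_last=False)."""
    indices = list(range(num_samples))
    # Every replica must receive the same number of samples.
    per_replica = math.ceil(num_samples / num_replicas)
    total_size = per_replica * num_replicas
    padding = total_size - num_samples
    if padding:
        # Pad by repeating indices from the start of the dataset.
        indices += (indices * math.ceil(padding / num_samples))[:padding]
    # Interleave: process `rank` takes every num_replicas-th index starting at `rank`.
    return indices[rank:total_size:num_replicas]


for rank in range(3):
    print(rank, ddp_partition(num_samples=10, num_replicas=3, rank=rank))
```

With 10 samples and three GPUs, samples 0 and 1 are evaluated twice, and no single process sees the data in its original order.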
- fit(train_loader: DataLoader, val_loader: DataLoader, num_epochs: int = 1, **kwargs)
- Parameters
  - train_loader (DataLoader) – An iterable DataLoader over the training batches (e.g., torch.utils.data.dataloader.DataLoader or torch_geometric.loader.DataLoader).
  - val_loader (DataLoader) – An iterable DataLoader over the validation batches (e.g., torch.utils.data.dataloader.DataLoader or torch_geometric.loader.DataLoader).
  - num_epochs (int) – The number of training epochs. (default: 1)
trainers.ClassificationTrainer
- class torcheeg.trainers.ClassificationTrainer(model: Module, num_classes: Optional[int] = None, lr: float = 0.0001, weight_decay: float = 0.0, device_ids: List[int] = [], ddp_sync_bn: bool = True, ddp_replace_sampler: bool = True, ddp_val: bool = True, ddp_test: bool = True)
A generic trainer class for EEG classification.

    trainer = ClassificationTrainer(model)
    trainer.fit(train_loader, val_loader)
    trainer.test(test_loader)
The class provides the following hook functions for inserting custom logic into the training, validation, and test lifecycle:

- before_training_epoch: executed before each training epoch starts
- before_training_step: executed before each training batch starts
- on_training_step: the training process for each batch
- after_training_step: executed after the training of each batch
- after_training_epoch: executed after each training epoch
- before_validation_epoch: executed before each round of validation starts
- before_validation_step: executed before the validation of each batch
- on_validation_step: the validation process for each batch
- after_validation_step: executed after the validation of each batch
- after_validation_epoch: executed after each round of validation
- before_test_epoch: executed before each round of testing starts
- before_test_step: executed before the test of each batch
- on_test_step: the test process for each batch
- after_test_step: executed after the test of each batch
- after_test_epoch: executed after each round of testing
To customize an operation, inherit from the class and override the corresponding hook function:

    from torcheeg.trainers import ClassificationTrainer

    class MyClassificationTrainer(ClassificationTrainer):
        def before_training_epoch(self, epoch_id: int, num_epochs: int):
            # Custom logic to run before each training epoch goes here.
            super().before_training_epoch(epoch_id, num_epochs)
To use multiple GPUs for parallel computing, specify the indices of the GPUs in your Python file:

    trainer = ClassificationTrainer(model, device_ids=[1, 2, 7])
    trainer.fit(train_loader, val_loader)
    trainer.test(test_loader)

Then launch the Python file with torch.distributed.launch or torchrun:

    python -m torch.distributed.launch \
        --nproc_per_node=3 \
        --nnodes=1 \
        --node_rank=0 \
        --master_addr="localhost" \
        --master_port=2345 \
        your_python_file.py

Here, nproc_per_node is the number of GPUs you specify.

- Parameters
  - model (nn.Module) – The classification model. The dimension of its output should equal the number of categories in the dataset, and the output layer does not need a softmax activation function.
  - num_classes (int, optional) – The number of categories in the dataset. If None, the number of categories is inferred from the num_classes attribute of the model. (default: None)
  - lr (float) – The learning rate. (default: 0.0001)
  - weight_decay (float) – The weight decay (L2 penalty). (default: 0.0)
  - device_ids (list) – Use the CPU if the list is empty. If the list contains the indices of multiple GPUs, the script must be launched with torch.distributed.launch or torchrun. (default: [])
  - ddp_sync_bn (bool) – Whether to replace the batch normalization layers in the network with cross-GPU synchronized batch normalization. Only takes effect when device_ids contains more than one GPU. (default: True)
  - ddp_replace_sampler (bool) – Whether to replace the sampler of each DataLoader with DistributedSampler. Only takes effect when device_ids contains more than one GPU. (default: True)
  - ddp_val (bool) – Whether to use multi-GPU acceleration for the validation set. For experiments that are sensitive to the order of the input data, set ddp_val to False. Only takes effect when device_ids contains more than one GPU. (default: True)
  - ddp_test (bool) – Whether to use multi-GPU acceleration for the test set. For experiments that are sensitive to the order of the input data, set ddp_test to False. Only takes effect when device_ids contains more than one GPU. (default: True)
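The model's output layer can omit softmax if, as assumed here, the trainer computes a cross-entropy loss on raw logits. torch.nn.CrossEntropyLoss works this way: it applies log-softmax internally.

The plain-Python sketch below computes cross-entropy directly from logits, using the numerically stable log-sum-exp form. cross_entropy_from_logits is a hypothetical helper. If the model added its own softmax, softmax would effectively be applied twice.

```python
import math
from typing import List


def cross_entropy_from_logits(logits: List[float], target: int) -> float:
    """Cross-entropy of one sample, computed from raw (unnormalized) logits."""
    # Log-sum-exp with the maximum subtracted for numerical stability.
    max_logit = max(logits)
    log_sum_exp = max_logit + math.log(sum(math.exp(z - max_logit) for z in logits))
    # -log(softmax(logits)[target]) == logsumexp(logits) - logits[target]
    return log_sum_exp - logits[target]


logits = [2.0, 1.0, 0.1]  # raw model output for 3 classes, no softmax applied
print(cross_entropy_from_logits(logits, target=0))
```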
- fit(train_loader: DataLoader, val_loader: DataLoader, num_epochs: int = 1, **kwargs)
- Parameters
  - train_loader (DataLoader) – An iterable DataLoader over the training batches (e.g., torch.utils.data.dataloader.DataLoader or torch_geometric.loader.DataLoader).
  - val_loader (DataLoader) – An iterable DataLoader over the validation batches (e.g., torch.utils.data.dataloader.DataLoader or torch_geometric.loader.DataLoader).
  - num_epochs (int) – The number of training epochs. (default: 1)
trainers.CORALTrainer
- class torcheeg.trainers.CORALTrainer(extractor: Module, classifier: Module, match_mean: bool = True, lambd: float = 1.0, num_classes: Optional[int] = None, lr: float = 0.0001, weight_decay: float = 0.0, device_ids: List[int] = [], ddp_sync_bn: bool = True, ddp_replace_sampler: bool = True, ddp_val: bool = True, ddp_test: bool = True)
Individual differences and the non-stationarity of EEG signals make it difficult for deep learning models trained on the subjects in the training set to correctly classify test samples from unseen subjects. The training and test sets come from different data distributions. Domain adaptation addresses this distribution shift and can thereby improve performance in subject-independent (cross-subject) scenarios. This class implements CORrelation ALignment (CORAL) for deep domain adaptation.
NOTE: CORAL is an unsupervised domain adaptation method. It uses only labeled source data and unlabeled target data, so the target dataset does not have to return labels.
Paper: Sun B, Saenko K. Deep CORAL: Correlation alignment for deep domain adaptation[C]//European conference on computer vision. Springer, Cham, 2016: 443-450.
URL: https://link.springer.com/chapter/10.1007/978-3-030-04239-4_39
Related Project: https://github.com/adapt-python/adapt/blob/master/adapt/feature_based/_deepCORAL.py
    trainer = CORALTrainer(extractor, classifier)
    trainer.fit(source_loader, target_loader, val_loader)
    trainer.test(test_loader)
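Deep CORAL aligns the second-order statistics of the two domains. For d-dimensional features with covariance matrices C_S and C_T, the CORAL loss of Sun and Saenko is ||C_S - C_T||_F^2 / (4d^2). This loss is added to the classification loss with weight lambd.

The stdlib-only sketch below computes that loss for small feature batches given as lists of rows. It makes these assumptions:
- The function names are hypothetical.
- The optional mean-matching term stands in for match_mean. Its form (mean squared difference of the feature means) is an assumption, as is how torcheeg weights it.

```python
from typing import List

Matrix = List[List[float]]  # n samples x d features


def feature_means(features: Matrix) -> List[float]:
    n = len(features)
    return [sum(column) / n for column in zip(*features)]


def covariance(features: Matrix) -> Matrix:
    """Unbiased d x d covariance of an n x d batch (requires n >= 2)."""
    n = len(features)
    means = feature_means(features)
    centered = [[x - m for x, m in zip(row, means)] for row in features]
    d = len(means)
    return [
        [sum(row[i] * row[j] for row in centered) / (n - 1) for j in range(d)]
        for i in range(d)
    ]


def coral_loss(source: Matrix, target: Matrix, match_mean: bool = False) -> float:
    d = len(source[0])
    cs, ct = covariance(source), covariance(target)
    # Squared Frobenius distance between the covariances, scaled by 1 / (4 d^2).
    loss = sum((cs[i][j] - ct[i][j]) ** 2 for i in range(d) for j in range(d)) / (4 * d * d)
    if match_mean:
        # Hypothetical first-moment term: mean squared difference of feature means.
        ms, mt = feature_means(source), feature_means(target)
        loss += sum((a - b) ** 2 for a, b in zip(ms, mt)) / d
    return loss


source = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
shifted = [[x + 5.0 for x in row] for row in source]  # same covariance, different mean
classification_loss = 0.7  # placeholder value for illustration
lambd = 1.0
total = classification_loss + lambd * coral_loss(source, shifted)  # combined objective
print(coral_loss(source, shifted), coral_loss(source, shifted, match_mean=True), total)
```

The shifted batch has the same covariance as the source batch, so only the mean term reacts to the shift. This is why matching only the second moment may be too weak when the domains differ mainly in their means.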
The class provides the following hook functions for inserting custom logic into the training, validation, and test lifecycle:

- before_training_epoch: executed before each training epoch starts
- before_training_step: executed before each training batch starts
- on_training_step: the training process for each batch
- after_training_step: executed after the training of each batch
- after_training_epoch: executed after each training epoch
- before_validation_epoch: executed before each round of validation starts
- before_validation_step: executed before the validation of each batch
- on_validation_step: the validation process for each batch
- after_validation_step: executed after the validation of each batch
- after_validation_epoch: executed after each round of validation
- before_test_epoch: executed before each round of testing starts
- before_test_step: executed before the test of each batch
- on_test_step: the test process for each batch
- after_test_step: executed after the test of each batch
- after_test_epoch: executed after each round of testing
To customize an operation, inherit from the class and override the corresponding hook function:

    from torcheeg.trainers import CORALTrainer

    class MyCORALTrainer(CORALTrainer):
        def before_training_epoch(self, epoch_id: int, num_epochs: int):
            # Custom logic to run before each training epoch goes here.
            super().before_training_epoch(epoch_id, num_epochs)
To use multiple GPUs for parallel computing, specify the indices of the GPUs in your Python file:

    trainer = CORALTrainer(extractor, classifier, device_ids=[1, 2, 7])
    trainer.fit(source_loader, target_loader, val_loader)
    trainer.test(test_loader)

Then launch the Python file with torch.distributed.launch or torchrun:

    python -m torch.distributed.launch \
        --nproc_per_node=3 \
        --nnodes=1 \
        --node_rank=0 \
        --master_addr="localhost" \
        --master_port=2345 \
        your_python_file.py

Here, nproc_per_node is the number of GPUs you specify.

- Parameters
  - extractor (nn.Module) – The feature extraction model. It learns a feature representation of the EEG signal by bringing the correlation (covariance) matrices of the source and target features closer together.
  - classifier (nn.Module) – The classification model. It learns the classification task from the labeled source data, using the features produced by the feature extraction model. The dimension of its output should equal the number of categories in the dataset, and the output layer does not need a softmax activation function.
  - lambd (float) – The weight of the CORAL loss, which trades off the classification loss against the CORAL loss. (default: 1.0)
  - match_mean (bool) – Whether to also match the means of the source-domain and target-domain samples. If False, only the second moments are matched. (default: True)
  - num_classes (int, optional) – The number of categories in the dataset. If None, the number of categories is inferred from the num_classes attribute of the model. (default: None)
  - lr (float) – The learning rate. (default: 0.0001)
  - weight_decay (float) – The weight decay (L2 penalty). (default: 0.0)
  - device_ids (list) – Use the CPU if the list is empty. If the list contains the indices of multiple GPUs, the script must be launched with torch.distributed.launch or torchrun. (default: [])
  - ddp_sync_bn (bool) – Whether to replace the batch normalization layers in the network with cross-GPU synchronized batch normalization. Only takes effect when device_ids contains more than one GPU. (default: True)
  - ddp_replace_sampler (bool) – Whether to replace the sampler of each DataLoader with DistributedSampler. Only takes effect when device_ids contains more than one GPU. (default: True)
  - ddp_val (bool) – Whether to use multi-GPU acceleration for the validation set. For experiments that are sensitive to the order of the input data, set ddp_val to False. Only takes effect when device_ids contains more than one GPU. (default: True)
  - ddp_test (bool) – Whether to use multi-GPU acceleration for the test set. For experiments that are sensitive to the order of the input data, set ddp_test to False. Only takes effect when device_ids contains more than one GPU. (default: True)
- fit(source_loader: DataLoader, target_loader: DataLoader, val_loader: DataLoader, num_epochs: int = 1, **kwargs)
- Parameters
  - source_loader (DataLoader) – An iterable DataLoader over the batches from the source domain (e.g., torch.utils.data.dataloader.DataLoader or torch_geometric.loader.DataLoader).
  - target_loader (DataLoader) – An iterable DataLoader over the training batches from the target domain (e.g., torch.utils.data.dataloader.DataLoader or torch_geometric.loader.DataLoader). The target dataset does not have to return labels.
  - val_loader (DataLoader) – An iterable DataLoader over the validation batches (e.g., torch.utils.data.dataloader.DataLoader or torch_geometric.loader.DataLoader).
  - num_epochs (int) – The number of training epochs. (default: 1)
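The fit methods of the domain adaptation trainers consume a source loader and a target loader at the same time. One common way to pair them is to zip the two loaders, so that each step sees one labeled source batch and one target batch. With zip, an epoch ends when the shorter loader is exhausted.

The sketch below shows this scheme. It rests on these assumptions:
- paired_batches is a hypothetical helper, not a torcheeg function.
- Whether torcheeg pairs its loaders this way is unconfirmed.
- Because target batches may or may not carry labels, only their inputs are kept.

```python
from typing import Any, Iterable, Iterator, Tuple


def paired_batches(source_loader: Iterable, target_loader: Iterable) -> Iterator[Tuple[Any, Any, Any]]:
    """Yield (x_source, y_source, x_target) for one epoch (hypothetical pairing scheme)."""
    # zip stops at the shorter loader: min(len(source), len(target)) steps per epoch.
    for source_batch, target_batch in zip(source_loader, target_loader):
        x_source, y_source = source_batch  # source batches are labeled
        # Target batches may be (x, y) pairs or bare inputs; labels are ignored either way.
        if isinstance(target_batch, (tuple, list)):
            x_target = target_batch[0]
        else:
            x_target = target_batch
        yield x_source, y_source, x_target


source = [("xs0", 0), ("xs1", 1), ("xs2", 0)]
target = ["xt0", ("xt1", 9)]
for step in paired_batches(source, target):
    print(step)
```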
trainers.DDCTrainer
- class torcheeg.trainers.DDCTrainer(extractor: Module, classifier: Module, lambd: float = 1.0, adaption_factor: bool = False, num_classes: Optional[int] = None, lr: float = 0.0001, weight_decay: float = 0.0, device_ids: List[int] = [], ddp_sync_bn: bool = True, ddp_replace_sampler: bool = True, ddp_val: bool = True, ddp_test: bool = True)
Individual differences and the non-stationarity of EEG signals make it difficult for deep learning models trained on the subjects in the training set to correctly classify test samples from unseen subjects. The training and test sets come from different data distributions. Domain adaptation addresses this distribution shift and can thereby improve performance in subject-independent (cross-subject) scenarios. This class implements Deep Domain Confusion (DDC) for deep domain adaptation.
NOTE: DDC is an unsupervised domain adaptation method. It uses only labeled source data and unlabeled target data, so the target dataset does not have to return labels.
Paper: Tzeng E, Hoffman J, Zhang N, et al. Deep domain confusion: Maximizing for domain invariance[J]. arXiv preprint arXiv:1412.3474, 2014.
Related Project: https://github.com/syorami/DDC-transfer-learning/blob/master/DDC.py
    trainer = DDCTrainer(extractor, classifier)
    trainer.fit(source_loader, target_loader, val_loader)
    trainer.test(test_loader)
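DDC penalizes the maximum mean discrepancy (MMD) between the source and target representations and adds it, weighted by lambd, to the classification loss.

The plain-Python sketch below uses the simplest form of MMD: the squared distance between the mean feature vectors of the two batches, which is MMD with a linear kernel. Whether torcheeg's DDCTrainer uses a linear or a Gaussian kernel is an assumption here, and linear_mmd is a hypothetical name.

```python
from typing import List

Matrix = List[List[float]]  # n samples x d features


def mean_embedding(features: Matrix) -> List[float]:
    n = len(features)
    return [sum(column) / n for column in zip(*features)]


def linear_mmd(source: Matrix, target: Matrix) -> float:
    """Squared MMD with a linear kernel: ||mean(source) - mean(target)||^2."""
    ms, mt = mean_embedding(source), mean_embedding(target)
    return sum((a - b) ** 2 for a, b in zip(ms, mt))


source = [[0.0, 0.0], [2.0, 2.0]]
target = [[1.0, 1.0], [3.0, 3.0]]
lambd = 1.0
print(lambd * linear_mmd(source, target))  # domain term added to the classification loss
```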
The class provides the following hook functions for inserting custom logic into the training, validation, and test lifecycle:

- before_training_epoch: executed before each training epoch starts
- before_training_step: executed before each training batch starts
- on_training_step: the training process for each batch
- after_training_step: executed after the training of each batch
- after_training_epoch: executed after each training epoch
- before_validation_epoch: executed before each round of validation starts
- before_validation_step: executed before the validation of each batch
- on_validation_step: the validation process for each batch
- after_validation_step: executed after the validation of each batch
- after_validation_epoch: executed after each round of validation
- before_test_epoch: executed before each round of testing starts
- before_test_step: executed before the test of each batch
- on_test_step: the test process for each batch
- after_test_step: executed after the test of each batch
- after_test_epoch: executed after each round of testing
To customize an operation, inherit from the class and override the corresponding hook function:

    from torcheeg.trainers import DDCTrainer

    class MyDDCTrainer(DDCTrainer):
        def before_training_epoch(self, epoch_id: int, num_epochs: int):
            # Custom logic to run before each training epoch goes here.
            super().before_training_epoch(epoch_id, num_epochs)
To use multiple GPUs for parallel computing, specify the indices of the GPUs in your Python file:

    trainer = DDCTrainer(extractor, classifier, device_ids=[1, 2, 7])
    trainer.fit(source_loader, target_loader, val_loader)
    trainer.test(test_loader)

Then launch the Python file with torch.distributed.launch or torchrun:

    python -m torch.distributed.launch \
        --nproc_per_node=3 \
        --nnodes=1 \
        --node_rank=0 \
        --master_addr="localhost" \
        --master_port=2345 \
        your_python_file.py

Here, nproc_per_node is the number of GPUs you specify.

- Parameters
  - extractor (nn.Module) – The feature extraction model. It learns a feature representation of the EEG signal by reducing the maximum mean discrepancy (MMD) between the source and target features.
  - classifier (nn.Module) – The classification model. It learns the classification task from the labeled source data, using the features produced by the feature extraction model. The dimension of its output should equal the number of categories in the dataset, and the output layer does not need a softmax activation function.
  - lambd (float) – The weight of the DDC loss, which trades off the classification loss against the DDC loss. (default: 1.0)
  - adaption_factor (bool) – Whether to scale the cross-domain loss term by the adaptation factor. This factor was first proposed in DANN and works well in many other settings. (default: False)
  - num_classes (int, optional) – The number of categories in the dataset. If None, the number of categories is inferred from the num_classes attribute of the model. (default: None)
  - lr (float) – The learning rate. (default: 0.0001)
  - weight_decay (float) – The weight decay (L2 penalty). (default: 0.0)
  - device_ids (list) – Use the CPU if the list is empty. If the list contains the indices of multiple GPUs, the script must be launched with torch.distributed.launch or torchrun. (default: [])
  - ddp_sync_bn (bool) – Whether to replace the batch normalization layers in the network with cross-GPU synchronized batch normalization. Only takes effect when device_ids contains more than one GPU. (default: True)
  - ddp_replace_sampler (bool) – Whether to replace the sampler of each DataLoader with DistributedSampler. Only takes effect when device_ids contains more than one GPU. (default: True)
  - ddp_val (bool) – Whether to use multi-GPU acceleration for the validation set. For experiments that are sensitive to the order of the input data, set ddp_val to False. Only takes effect when device_ids contains more than one GPU. (default: True)
  - ddp_test (bool) – Whether to use multi-GPU acceleration for the test set. For experiments that are sensitive to the order of the input data, set ddp_test to False. Only takes effect when device_ids contains more than one GPU. (default: True)
- fit(source_loader: DataLoader, target_loader: DataLoader, val_loader: DataLoader, num_epochs: int = 1, **kwargs)
- Parameters
  - source_loader (DataLoader) – An iterable DataLoader over the batches from the source domain (e.g., torch.utils.data.dataloader.DataLoader or torch_geometric.loader.DataLoader).
  - target_loader (DataLoader) – An iterable DataLoader over the training batches from the target domain (e.g., torch.utils.data.dataloader.DataLoader or torch_geometric.loader.DataLoader). The target dataset does not have to return labels.
  - val_loader (DataLoader) – An iterable DataLoader over the validation batches (e.g., torch.utils.data.dataloader.DataLoader or torch_geometric.loader.DataLoader).
  - num_epochs (int) – The number of training epochs. (default: 1)
trainers.DANNTrainer
- class torcheeg.trainers.DANNTrainer(extractor: Module, classifier: Module, domain_classifier: Module, lambd: float = 1.0, adaption_factor: bool = False, num_classes: Optional[int] = None, lr: float = 0.0001, weight_decay: float = 0.0, device_ids: List[int] = [], ddp_sync_bn: bool = True, ddp_replace_sampler: bool = True, ddp_val: bool = True, ddp_test: bool = True)
Individual differences and the non-stationarity of EEG signals make it difficult for deep learning models trained on the subjects in the training set to correctly classify test samples from unseen subjects. The training and test sets come from different data distributions. Domain adaptation addresses this distribution shift and can thereby improve performance in subject-independent (cross-subject) scenarios. This class implements Domain Adversarial Neural Networks (DANN) for deep domain adaptation.
NOTE: DANN is an unsupervised domain adaptation method. It uses only labeled source data and unlabeled target data, so the target dataset does not have to return labels.
Paper: Ganin Y, Ustinova E, Ajakan H, et al. Domain-adversarial training of neural networks[J]. The journal of machine learning research, 2016, 17(1): 2096-2030.
Related Project: https://github.com/fungtion/DANN/blob/master/train/main.py
    trainer = DANNTrainer(extractor, classifier, domain_classifier)
    trainer.fit(source_loader, target_loader, val_loader)
    trainer.test(test_loader)
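DANN places a gradient reversal layer (GRL) between the feature extractor and the domain classifier. The GRL follows two rules:
- In the forward pass, it is the identity.
- In the backward pass, it multiplies the gradient by a negative coefficient.

As a result, the domain classifier still learns to separate the domains, while the extractor is pushed to make them indistinguishable.

The scalar toy below illustrates only these two rules. A real GRL would be a custom autograd function, and GradientReversal is a hypothetical class.

```python
class GradientReversal:
    """Toy scalar gradient reversal layer (illustration only, no autograd)."""

    def __init__(self, coeff: float = 1.0) -> None:
        self.coeff = coeff

    def forward(self, x: float) -> float:
        # Identity: features reach the domain classifier unchanged.
        return x

    def backward(self, grad_output: float) -> float:
        # Flip and scale the gradient flowing back into the feature extractor.
        return -self.coeff * grad_output


# One-parameter "extractor": feature = w * x.
grl = GradientReversal(coeff=0.5)
x, w = 2.0, 3.0
feature = grl.forward(w * x)
grad_domain_loss_wrt_feature = 0.8  # pretend value coming from the domain classifier
# Chain rule through feature = w * x, where d(feature)/dw = x.
grad_w = grl.backward(grad_domain_loss_wrt_feature) * x
print(feature, grad_w)
```

The extractor receives the reversed gradient, so a gradient-descent update on w increases the domain classifier's loss instead of decreasing it.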
The class provides the following hook functions for inserting custom logic into the training, validation, and test lifecycle:

- before_training_epoch: executed before each training epoch starts
- before_training_step: executed before each training batch starts
- on_training_step: the training process for each batch
- after_training_step: executed after the training of each batch
- after_training_epoch: executed after each training epoch
- before_validation_epoch: executed before each round of validation starts
- before_validation_step: executed before the validation of each batch
- on_validation_step: the validation process for each batch
- after_validation_step: executed after the validation of each batch
- after_validation_epoch: executed after each round of validation
- before_test_epoch: executed before each round of testing starts
- before_test_step: executed before the test of each batch
- on_test_step: the test process for each batch
- after_test_step: executed after the test of each batch
- after_test_epoch: executed after each round of testing
To customize an operation, inherit from the class and override the corresponding hook function:

    from torcheeg.trainers import DANNTrainer

    class MyDANNTrainer(DANNTrainer):
        def before_training_epoch(self, epoch_id: int, num_epochs: int):
            # Custom logic to run before each training epoch goes here.
            super().before_training_epoch(epoch_id, num_epochs)
To use multiple GPUs for parallel computing, specify the indices of the GPUs in your Python file:

    trainer = DANNTrainer(extractor, classifier, domain_classifier, device_ids=[1, 2, 7])
    trainer.fit(source_loader, target_loader, val_loader)
    trainer.test(test_loader)

Then launch the Python file with torch.distributed.launch or torchrun:

    python -m torch.distributed.launch \
        --nproc_per_node=3 \
        --nnodes=1 \
        --node_rank=0 \
        --master_addr="localhost" \
        --master_port=2345 \
        your_python_file.py

Here, nproc_per_node is the number of GPUs you specify.

- Parameters
  - extractor (nn.Module) – The feature extraction model. It learns a feature representation of the EEG signal that the domain classifier cannot use to tell the source and target domains apart.
  - classifier (nn.Module) – The classification model. It learns the classification task from the labeled source data, using the features produced by the feature extraction model. The dimension of its output should equal the number of categories in the dataset, and the output layer does not need a softmax activation function.
  - domain_classifier (nn.Module) – The domain classification model. It learns to discriminate between the source and target domains, using the features produced by the feature extraction model. Its output should predict the domain (source or target) of each sample rather than the dataset categories. The output layer needs neither a softmax activation function nor a gradient reversal layer, since the latter is already included in the implementation.
  - lambd (float) – The weight of the DANN loss, which trades off the classification loss against the DANN loss. (default: 1.0)
  - adaption_factor (bool) – Whether to scale the cross-domain loss term by the adaptation factor. This factor was first proposed in DANN and works well in many other settings. (default: False)
  - num_classes (int, optional) – The number of categories in the dataset. If None, the number of categories is inferred from the num_classes attribute of the model. (default: None)
  - lr (float) – The learning rate. (default: 0.0001)
  - weight_decay (float) – The weight decay (L2 penalty). (default: 0.0)
  - device_ids (list) – Use the CPU if the list is empty. If the list contains the indices of multiple GPUs, the script must be launched with torch.distributed.launch or torchrun. (default: [])
  - ddp_sync_bn (bool) – Whether to replace the batch normalization layers in the network with cross-GPU synchronized batch normalization. Only takes effect when device_ids contains more than one GPU. (default: True)
  - ddp_replace_sampler (bool) – Whether to replace the sampler of each DataLoader with DistributedSampler. Only takes effect when device_ids contains more than one GPU. (default: True)
  - ddp_val (bool) – Whether to use multi-GPU acceleration for the validation set. For experiments that are sensitive to the order of the input data, set ddp_val to False. Only takes effect when device_ids contains more than one GPU. (default: True)
  - ddp_test (bool) – Whether to use multi-GPU acceleration for the test set. For experiments that are sensitive to the order of the input data, set ddp_test to False. Only takes effect when device_ids contains more than one GPU. (default: True)
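The adaptation factor behind adaption_factor comes from the DANN paper. It ramps the weight of the domain loss from 0 to 1 over training as lambda_p = 2 / (1 + exp(-gamma * p)) - 1, where p is the training progress in [0, 1] and gamma = 10.

The sketch below computes that schedule. Its assumptions:
- adaptation_factor is a hypothetical helper name.
- Progress p is measured in optimization steps.
- The factor is combined with lambd as lambd * factor; this is an assumption about how a trainer might use it, not torcheeg's code.

```python
import math


def adaptation_factor(step: int, total_steps: int, gamma: float = 10.0) -> float:
    """DANN schedule: 2 / (1 + exp(-gamma * p)) - 1 with progress p = step / total_steps."""
    p = step / total_steps
    return 2.0 / (1.0 + math.exp(-gamma * p)) - 1.0


total_steps = 100
lambd = 1.0
for step in (0, 10, 50, 100):
    factor = adaptation_factor(step, total_steps)
    # Hypothetical weighting of the domain loss at this step.
    print(step, round(factor, 4), lambd * factor)
```

Starting at 0 keeps the noisy domain signal from dominating early training, when the features are still poor.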
- fit(source_loader: DataLoader, target_loader: DataLoader, val_loader: DataLoader, num_epochs: int = 1, **kwargs)
- Parameters
  - source_loader (DataLoader) – An iterable DataLoader over the batches from the source domain (e.g., torch.utils.data.dataloader.DataLoader or torch_geometric.loader.DataLoader).
  - target_loader (DataLoader) – An iterable DataLoader over the training batches from the target domain (e.g., torch.utils.data.dataloader.DataLoader or torch_geometric.loader.DataLoader). The target dataset does not have to return labels.
  - val_loader (DataLoader) – An iterable DataLoader over the validation batches (e.g., torch.utils.data.dataloader.DataLoader or torch_geometric.loader.DataLoader).
  - num_epochs (int) – The number of training epochs. (default: 1)
trainers.DANTrainer
- class torcheeg.trainers.DANTrainer(extractor: Module, classifier: Module, lambd: float = 1.0, adaption_factor: bool = False, num_classes: Optional[int] = None, lr: float = 0.0001, weight_decay: float = 0.0, device_ids: List[int] = [], ddp_sync_bn: bool = True, ddp_replace_sampler: bool = True, ddp_val: bool = True, ddp_test: bool = True)
Individual differences and the non-stationarity of EEG signals make it difficult for deep learning models trained on the subjects in the training set to correctly classify test samples from unseen subjects. The training and test sets come from different data distributions. Domain adaptation addresses this distribution shift and can thereby improve performance in subject-independent (cross-subject) scenarios. This class implements the Deep Adaptation Network (DAN) for deep domain adaptation.
NOTE: DAN is an unsupervised domain adaptation method. It uses only labeled source data and unlabeled target data, so the target dataset does not have to return labels.
Paper: Long M, Cao Y, Wang J, et al. Learning transferable features with deep adaptation networks[C]//International conference on machine learning. PMLR, 2015: 97-105.
Related Project: https://github.com/jindongwang/transferlearning/blob/cfaf1174dff7390a861cc4abd5ede37dfa1063f5/code/deep/DAN/DAN.py
    trainer = DANTrainer(extractor, classifier)
    trainer.fit(source_loader, target_loader, val_loader)
    trainer.test(test_loader)
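DAN generalizes the single-kernel MMD of DDC to a multi-kernel MMD (MK-MMD), which combines several Gaussian kernels of different bandwidths.

The plain-Python sketch below computes the simple biased estimate of squared MMD, mean k(s, s') + mean k(t, t') - 2 mean k(s, t), with an equal-weight sum of Gaussian kernels. Its assumptions:
- The bandwidth values, the equal weights, and the function names are illustrative, not torcheeg's configuration.
- The DAN paper goes further: it learns the kernel weights and uses a linear-time estimator, neither of which is shown here.

```python
import math
from typing import List, Sequence

Matrix = List[List[float]]  # n samples x d features


def multi_kernel(a: Sequence[float], b: Sequence[float], bandwidths: Sequence[float]) -> float:
    # Equal-weight sum of Gaussian kernels exp(-||a - b||^2 / (2 * sigma^2)).
    sq_dist = sum((x - y) ** 2 for x, y in zip(a, b))
    return sum(math.exp(-sq_dist / (2.0 * sigma * sigma)) for sigma in bandwidths)


def mk_mmd(source: Matrix, target: Matrix, bandwidths: Sequence[float] = (0.5, 1.0, 2.0)) -> float:
    """Biased estimate of squared MMD under the multi-kernel above."""

    def mean_kernel(xs: Matrix, ys: Matrix) -> float:
        total = sum(multi_kernel(x, y, bandwidths) for x in xs for y in ys)
        return total / (len(xs) * len(ys))

    return (
        mean_kernel(source, source)
        + mean_kernel(target, target)
        - 2.0 * mean_kernel(source, target)
    )


source = [[0.0, 0.0], [0.0, 1.0]]
target = [[3.0, 3.0], [4.0, 4.0]]
print(mk_mmd(source, target), mk_mmd(source, source))
```

Mixing several bandwidths makes the discrepancy sensitive to differences at both small and large distances, without having to pick a single kernel width in advance.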
The class provides the following hook functions for inserting custom logic into the training, validation, and test lifecycle:

- before_training_epoch: executed before each training epoch starts
- before_training_step: executed before each training batch starts
- on_training_step: the training process for each batch
- after_training_step: executed after the training of each batch
- after_training_epoch: executed after each training epoch
- before_validation_epoch: executed before each round of validation starts
- before_validation_step: executed before the validation of each batch
- on_validation_step: the validation process for each batch
- after_validation_step: executed after the validation of each batch
- after_validation_epoch: executed after each round of validation
- before_test_epoch: executed before each round of testing starts
- before_test_step: executed before the test of each batch
- on_test_step: the test process for each batch
- after_test_step: executed after the test of each batch
- after_test_epoch: executed after each round of testing
To customize an operation, inherit from the class and override the corresponding hook function:

    from torcheeg.trainers import DANTrainer

    class MyDANTrainer(DANTrainer):
        def before_training_epoch(self, epoch_id: int, num_epochs: int):
            # Custom logic to run before each training epoch goes here.
            super().before_training_epoch(epoch_id, num_epochs)
To use multiple GPUs for parallel computing, specify the indices of the GPUs in your Python file:

    trainer = DANTrainer(extractor, classifier, device_ids=[1, 2, 7])
    trainer.fit(source_loader, target_loader, val_loader)
    trainer.test(test_loader)

Then launch the Python file with torch.distributed.launch or torchrun:

    python -m torch.distributed.launch \
        --nproc_per_node=3 \
        --nnodes=1 \
        --node_rank=0 \
        --master_addr="localhost" \
        --master_port=2345 \
        your_python_file.py

Here, nproc_per_node is the number of GPUs you specify.

- Parameters
extractor (nn.Module) – The feature extraction model, which learns a feature representation of EEG signals that brings the feature distributions of the source and target data close together.
classifier (nn.Module) – The classification model, trained on labeled source data using the features produced by the feature extraction model. Its output dimension should equal the number of categories in the dataset. The output layer does not need a softmax activation function.
lambd (float) – The weight of the DAN loss, trading off the classification loss against the DAN loss. (default: 1.0)
adaption_factor (bool) – Whether to scale the cross-domain loss term by the adaptation factor, which was first proposed in DANN but is effective in many other settings. (default: True)
num_classes (int, optional) – The number of categories in the dataset. If None, the number of categories is inferred from the num_classes attribute of the model. (default: None)
lr (float) – The learning rate. (default: 0.0001)
weight_decay (float) – The weight decay (L2 penalty). (default: 0.0)
device_ids (list) – The CPU is used if the list is empty. If the list contains the indices of multiple GPUs, the script must be launched with torch.distributed.launch or torchrun. (default: [])
ddp_sync_bn (bool) – Whether to replace the batch normalization layers in the network with cross-GPU synchronized batch normalization. Only valid when the length of device_ids is greater than one. (default: True)
ddp_replace_sampler (bool) – Whether to replace the sampler in the dataloader with DistributedSampler. Only valid when the length of device_ids is greater than one. (default: True)
ddp_val (bool) – Whether to use multi-GPU acceleration for the validation set. For experiments that are sensitive to the order of the input data, set ddp_val to False. Only valid when the length of device_ids is greater than one. (default: True)
ddp_test (bool) – Whether to use multi-GPU acceleration for the test set. For experiments that are sensitive to the order of the input data, set ddp_test to False. Only valid when the length of device_ids is greater than one. (default: True)
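The interplay of lambd and adaption_factor can be pictured with the following sketch. It assumes the cross-domain term is added to the classification loss as cls_loss + lambd * factor * domain_loss, and that the factor follows the schedule from the DANN paper, 2 / (1 + exp(-10 p)) - 1, where p is the training progress in [0, 1]. The function names are hypothetical and torcheeg's internal formula may differ.

```python
import math


def dann_adaptation_factor(progress: float, gamma: float = 10.0) -> float:
    """DANN schedule: grows smoothly from 0 (start) toward 1 (end of training)."""
    return 2.0 / (1.0 + math.exp(-gamma * progress)) - 1.0


def total_loss(cls_loss: float, domain_loss: float, lambd: float = 1.0,
               progress: float = 1.0, adaption_factor: bool = True) -> float:
    # Hypothetical combination: the cross-domain term is weighted by lambd and,
    # if adaption_factor is enabled, also by the DANN schedule.
    factor = dann_adaptation_factor(progress) if adaption_factor else 1.0
    return cls_loss + lambd * factor * domain_loss
```

Ramping the factor up from zero lets the classifier learn from source labels before the cross-domain term starts to dominate.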
- fit(source_loader: DataLoader, target_loader: DataLoader, val_loader: DataLoader, num_epochs: int = 1, **kwargs)
- Parameters
source_loader (DataLoader) – Iterable DataLoader for traversing the data batch from the source domain (torch.utils.data.dataloader.DataLoader, torch_geometric.loader.DataLoader, etc).
target_loader (DataLoader) – Iterable DataLoader for traversing the training data batch from the target domain (torch.utils.data.dataloader.DataLoader, torch_geometric.loader.DataLoader, etc). The target dataset does not have to return labels.
val_loader (DataLoader) – Iterable DataLoader for traversing the validation data batch (torch.utils.data.dataloader.DataLoader, torch_geometric.loader.DataLoader, etc).
num_epochs (int) – The number of training epochs. (default: 1)
trainers.ADATrainer
- class torcheeg.trainers.ADATrainer(extractor: Module, classifier: Module, lambd: float = 1.0, adaption_factor: bool = False, num_classes: Optional[int] = None, lr: float = 0.0001, walker_weight: float = 1.0, visit_weight: float = 1.0, weight_decay: float = 0.0, device_ids: List[int] = [], ddp_sync_bn: bool = True, ddp_replace_sampler: bool = True, ddp_val: bool = True, ddp_test: bool = True)
Individual differences and the non-stationarity of EEG signals make it difficult for deep learning models trained on one group of subjects to correctly classify samples from unseen subjects, because the training and test sets come from different data distributions. Domain adaptation addresses this distribution shift between the training and test sets and can therefore achieve good performance in subject-independent (cross-subject) scenarios. This class implements Associative Domain Adaptation (ADA) for deep domain adaptation.
NOTE: ADA is an unsupervised domain adaptation method, which uses only labeled source data and unlabeled target data. Therefore, the target dataset does not need to return labels.
Paper: Haeusser P, Frerix T, Mordvintsev A, et al. Associative domain adaptation[C]//Proceedings of the IEEE international conference on computer vision. 2017: 2765-2773.
Related Project: https://github.com/stes/torch-associative
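from torcheeg.trainers import ADATrainer

# extractor and classifier are nn.Module instances; the loaders are defined elsewhere.
trainer = ADATrainer(extractor, classifier)
trainer.fit(source_loader, target_loader, val_loader)
trainer.test(test_loader)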
The class provides the following hook functions for inserting additional implementations in the training, validation and testing lifecycle:
- before_training_epoch: executed before each epoch of training starts
- before_training_step: executed before the training of each batch starts
- on_training_step: the training process for each batch
- after_training_step: executed after the training of each batch
- after_training_epoch: executed after each epoch of training
- before_validation_epoch: executed before each round of validation starts
- before_validation_step: executed before the validation of each batch
- on_validation_step: the validation process for each batch
- after_validation_step: executed after the validation of each batch
- after_validation_epoch: executed after each round of validation
- before_test_epoch: executed before each round of testing starts
- before_test_step: executed before the test of each batch
- on_test_step: the test process for each batch
- after_test_step: executed after the test of each batch
- after_test_epoch: executed after each round of testing
To customize behavior, inherit from the class and override the relevant hook functions:
from torcheeg.trainers import ADATrainer

class MyADATrainer(ADATrainer):
    def before_training_epoch(self, epoch_id: int, num_epochs: int):
        # Do something here.
        super().before_training_epoch(epoch_id, num_epochs)
To use multiple GPUs for parallel computing, specify the indices of the GPUs to use in your Python file:
trainer = ADATrainer(extractor, classifier, device_ids=[1, 2, 7])
trainer.fit(source_loader, target_loader, val_loader)
trainer.test(test_loader)
Then, run your Python file with torch.distributed.launch or torchrun:
python -m torch.distributed.launch \
    --nproc_per_node=3 \
    --nnodes=1 \
    --node_rank=0 \
    --master_addr="localhost" \
    --master_port=2345 \
    your_python_file.py
Here, nproc_per_node is the number of GPUs specified in device_ids.
- Parameters
extractor (nn.Module) – The feature extraction model, which learns a feature representation of EEG signals in which target samples can be associated with source samples of the same class.
classifier (nn.Module) – The classification model, trained on labeled source data using the features produced by the feature extraction model. Its output dimension should equal the number of categories in the dataset. The output layer does not need a softmax activation function.
lambd (float) – The weight of the ADA loss, trading off the classification loss against the ADA loss. (default: 1.0)
adaption_factor (bool) – Whether to scale the cross-domain loss term by the adaptation factor, which was first proposed in DANN but is effective in many other settings. (default: False)
num_classes (int, optional) – The number of categories in the dataset. If None, the number of categories is inferred from the num_classes attribute of the model. (default: None)
lr (float) – The learning rate. (default: 0.0001)
walker_weight (float) – The weight of the walker loss. (default: 1.0)
visit_weight (float) – The weight of the visit loss. (default: 1.0)
weight_decay (float) – The weight decay (L2 penalty). (default: 0.0)
device_ids (list) – The CPU is used if the list is empty. If the list contains the indices of multiple GPUs, the script must be launched with torch.distributed.launch or torchrun. (default: [])
ddp_sync_bn (bool) – Whether to replace the batch normalization layers in the network with cross-GPU synchronized batch normalization. Only valid when the length of device_ids is greater than one. (default: True)
ddp_replace_sampler (bool) – Whether to replace the sampler in the dataloader with DistributedSampler. Only valid when the length of device_ids is greater than one. (default: True)
ddp_val (bool) – Whether to use multi-GPU acceleration for the validation set. For experiments that are sensitive to the order of the input data, set ddp_val to False. Only valid when the length of device_ids is greater than one. (default: True)
ddp_test (bool) – Whether to use multi-GPU acceleration for the test set. For experiments that are sensitive to the order of the input data, set ddp_test to False. Only valid when the length of device_ids is greater than one. (default: True)
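The walker and visit losses weighted by walker_weight and visit_weight can be sketched in plain Python, following the description in Haeusser et al. (2017): embeddings are compared by dot product, a round trip source → target → source should end at a source sample of the same class (walker loss), and every target sample should be visited equally often (visit loss). The helper names, the eps constant, and the way ada_loss combines the terms are assumptions, not torcheeg's code.

```python
import math
from typing import List, Sequence, Tuple

Matrix = List[List[float]]


def _softmax_rows(m: Matrix) -> Matrix:
    out = []
    for row in m:
        peak = max(row)  # subtract the row max for numerical stability
        exps = [math.exp(v - peak) for v in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out


def _dot(u: Sequence[float], v: Sequence[float]) -> float:
    return sum(x * y for x, y in zip(u, v))


def association_losses(src: Matrix, tgt: Matrix, src_labels: Sequence[int],
                       eps: float = 1e-8) -> Tuple[float, float]:
    """Return (walker_loss, visit_loss) for source/target embeddings."""
    sim = [[_dot(a, b) for b in tgt] for a in src]      # [n_src, n_tgt]
    p_st = _softmax_rows(sim)                            # source -> target
    p_ts = _softmax_rows([list(c) for c in zip(*sim)])   # target -> source
    # Round-trip probabilities source -> target -> source, shape [n_src, n_src].
    p_sts = [[_dot(row, col) for col in zip(*p_ts)] for row in p_st]

    # Walker loss: round trips should land uniformly on same-class source samples.
    n_src = len(src)
    walker = 0.0
    for i in range(n_src):
        same = [j for j in range(n_src) if src_labels[j] == src_labels[i]]
        for j in same:
            walker -= (1.0 / len(same)) * math.log(p_sts[i][j] + eps)
    walker /= n_src

    # Visit loss: cross-entropy between a uniform target and the visit probabilities.
    n_tgt = len(tgt)
    visit_prob = [sum(p_st[i][j] for i in range(n_src)) / n_src for j in range(n_tgt)]
    visit = -sum(math.log(p + eps) for p in visit_prob) / n_tgt
    return walker, visit


def ada_loss(walker: float, visit: float, lambd: float = 1.0,
             walker_weight: float = 1.0, visit_weight: float = 1.0) -> float:
    # Hypothetical weighting of the association terms by the trainer parameters.
    return lambd * (walker_weight * walker + visit_weight * visit)
```

With two orthogonal source and target embeddings, the visit loss reaches its minimum of log 2, since both target samples are visited with probability 0.5.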
- fit(source_loader: DataLoader, target_loader: DataLoader, val_loader: DataLoader, num_epochs: int = 1, **kwargs)
- Parameters
source_loader (DataLoader) – Iterable DataLoader for traversing the data batch from the source domain (torch.utils.data.dataloader.DataLoader, torch_geometric.loader.DataLoader, etc).
target_loader (DataLoader) – Iterable DataLoader for traversing the training data batch from the target domain (torch.utils.data.dataloader.DataLoader, torch_geometric.loader.DataLoader, etc). The target dataset does not have to return labels.
val_loader (DataLoader) – Iterable DataLoader for traversing the validation data batch (torch.utils.data.dataloader.DataLoader, torch_geometric.loader.DataLoader, etc).
num_epochs (int) – The number of training epochs. (default: 1)
trainers.GANTrainer
- class torcheeg.trainers.GANTrainer(generator: Module, discriminator: Module, generator_lr: float = 0.0001, discriminator_lr: float = 0.0001, lambd: float = 1.0, weight_decay: float = 0.0, device_ids: List[int] = [], ddp_sync_bn: bool = True, ddp_replace_sampler: bool = True, ddp_val: bool = True, ddp_test: bool = True)
This class provides an implementation of WGAN-GP. As in traditional generative adversarial networks, it trains the generator and the discriminator in a zero-sum game: the generator is optimized to produce simulated samples that the discriminator cannot distinguish from real ones, and the discriminator is optimized to tell the fake samples produced by the generator apart from real ones. Compared with the vanilla GAN, WGAN-GP improves training stability, alleviates problems such as mode collapse, and provides meaningful learning curves that are useful for debugging and hyperparameter searches. For these reasons, existing work commonly uses WGAN-GP to generate simulated EEG signals. For more details, refer to the following information.
Paper: Gulrajani I, Ahmed F, Arjovsky M, et al. Improved training of wasserstein gans[J]. Advances in neural information processing systems, 2017, 30.
Related Project: https://github.com/eriklindernoren/PyTorch-GAN
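from torcheeg.models import BGenerator, BDiscriminator
from torcheeg.trainers import GANTrainer

# train_loader, val_loader, and test_loader are defined elsewhere.
g_model = BGenerator(in_channels=128)
d_model = BDiscriminator(in_channels=4)
trainer = GANTrainer(g_model, d_model)
trainer.fit(train_loader, val_loader)
trainer.test(test_loader)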
The class provides the following hook functions for inserting additional implementations in the training, validation and testing lifecycle:
- before_training_epoch: executed before each epoch of training starts
- before_training_step: executed before the training of each batch starts
- on_training_step: the training process for each batch
- after_training_step: executed after the training of each batch
- after_training_epoch: executed after each epoch of training
- before_validation_epoch: executed before each round of validation starts
- before_validation_step: executed before the validation of each batch
- on_validation_step: the validation process for each batch
- after_validation_step: executed after the validation of each batch
- after_validation_epoch: executed after each round of validation
- before_test_epoch: executed before each round of testing starts
- before_test_step: executed before the test of each batch
- on_test_step: the test process for each batch
- after_test_step: executed after the test of each batch
- after_test_epoch: executed after each round of testing
To customize behavior, inherit from the class and override the relevant hook functions:
from torcheeg.trainers import GANTrainer

class MyGANTrainer(GANTrainer):
    def before_training_epoch(self, epoch_id: int, num_epochs: int):
        # Do something here.
        super().before_training_epoch(epoch_id, num_epochs)
To use multiple GPUs for parallel computing, specify the indices of the GPUs to use in your Python file:
trainer = GANTrainer(generator, discriminator, device_ids=[1, 2, 7])
trainer.fit(train_loader, val_loader)
trainer.test(test_loader)
Then, run your Python file with torch.distributed.launch or torchrun:
python -m torch.distributed.launch \
    --nproc_per_node=3 \
    --nnodes=1 \
    --node_rank=0 \
    --master_addr="localhost" \
    --master_port=2345 \
    your_python_file.py
Here, nproc_per_node is the number of GPUs specified in device_ids.
- Parameters
generator (nn.Module) – The generator model for EEG signal generation. Its inputs are Gaussian-distributed random vectors and its outputs are generated EEG signals. The dimension of the input vector should be defined in the in_channels attribute. The output layer does not need a softmax activation function.
discriminator (nn.Module) – The discriminator model, which determines whether an EEG signal is real or generated. Its output dimension should be one (i.e., a single score distinguishing real from fake samples). The output layer does not need a sigmoid activation function.
generator_lr (float) – The learning rate of the generator. (default: 0.0001)
discriminator_lr (float) – The learning rate of the discriminator. (default: 0.0001)
lambd (float) – The weight of the gradient penalty loss, trading off the adversarial training loss against the gradient penalty loss. (default: 1.0)
weight_decay (float) – The weight decay (L2 penalty). (default: 0.0)
device_ids (list) – The CPU is used if the list is empty. If the list contains the indices of multiple GPUs, the script must be launched with torch.distributed.launch or torchrun. (default: [])
ddp_sync_bn (bool) – Whether to replace the batch normalization layers in the network with cross-GPU synchronized batch normalization. Only valid when the length of device_ids is greater than one. (default: True)
ddp_replace_sampler (bool) – Whether to replace the sampler in the dataloader with DistributedSampler. Only valid when the length of device_ids is greater than one. (default: True)
ddp_val (bool) – Whether to use multi-GPU acceleration for the validation set. For experiments that are sensitive to the order of the input data, set ddp_val to False. Only valid when the length of device_ids is greater than one. (default: True)
ddp_test (bool) – Whether to use multi-GPU acceleration for the test set. For experiments that are sensitive to the order of the input data, set ddp_test to False. Only valid when the length of device_ids is greater than one. (default: True)
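For reference, the WGAN-GP objectives from Gulrajani et al. (2017) are written out below, with D the discriminator (critic), G the generator, z a Gaussian latent vector, P_r the real-data distribution, and P_g the generator distribution. According to the parameter description above, lambd plays the role of λ; note that the paper used λ = 10, whereas this trainer defaults to 1.0. Whether torcheeg adds further terms is not covered by this sketch.

```latex
\mathcal{L}_D = \mathbb{E}_{\tilde{x} \sim \mathbb{P}_g}\big[D(\tilde{x})\big]
              - \mathbb{E}_{x \sim \mathbb{P}_r}\big[D(x)\big]
              + \lambda \, \mathbb{E}_{\hat{x}}\Big[\big(\lVert \nabla_{\hat{x}} D(\hat{x}) \rVert_2 - 1\big)^2\Big],
\qquad \hat{x} = \epsilon x + (1 - \epsilon)\tilde{x}, \quad \epsilon \sim U[0, 1]

\mathcal{L}_G = -\,\mathbb{E}_{z \sim \mathcal{N}(0, I)}\big[D(G(z))\big]
```

The penalty keeps the critic's gradient norm close to 1 on points interpolated between real and generated samples, which is why the discriminator's output layer needs no sigmoid.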
- fit(train_loader: DataLoader, val_loader: DataLoader, num_epochs: int = 1, **kwargs)
- Parameters
train_loader (DataLoader) – Iterable DataLoader for traversing the training data batch (torch.utils.data.dataloader.DataLoader, torch_geometric.loader.DataLoader, etc).
val_loader (DataLoader) – Iterable DataLoader for traversing the validation data batch (torch.utils.data.dataloader.DataLoader, torch_geometric.loader.DataLoader, etc).
num_epochs (int) – The number of training epochs. (default: 1)
trainers.CGANTrainer
- class torcheeg.trainers.CGANTrainer(generator: Module, discriminator: Module, generator_lr: float = 0.0001, discriminator_lr: float = 0.0001, lambd: float = 1.0, weight_decay: float = 0.0, device_ids: List[int] = [], ddp_sync_bn: bool = True, ddp_replace_sampler: bool = True, ddp_val: bool = True, ddp_test: bool = True)
This class provides an implementation of a conditional WGAN-GP. As in traditional generative adversarial networks, it trains the generator and the discriminator in a zero-sum game: the generator is optimized to produce simulated samples that the discriminator cannot distinguish from real ones, and the discriminator is optimized to tell the fake samples produced by the generator apart from real ones. Compared with the vanilla GAN, WGAN-GP improves training stability, alleviates problems such as mode collapse, and provides meaningful learning curves that are useful for debugging and hyperparameter searches. For these reasons, existing work commonly uses WGAN-GP to generate simulated EEG signals. In addition, the expected category labels are provided to guide the discriminator in judging whether a sample fits the data distribution of its class. For more details, refer to the following information.
Paper: Gulrajani I, Ahmed F, Arjovsky M, et al. Improved training of wasserstein gans[J]. Advances in neural information processing systems, 2017, 30.
Related Project: https://github.com/eriklindernoren/PyTorch-GAN
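from torcheeg.models import BCGenerator, BCDiscriminator
from torcheeg.trainers import CGANTrainer

# train_loader, val_loader, and test_loader are defined elsewhere.
g_model = BCGenerator(in_channels=128, num_classes=2)
d_model = BCDiscriminator(in_channels=4, num_classes=2)
trainer = CGANTrainer(g_model, d_model)
trainer.fit(train_loader, val_loader)
trainer.test(test_loader)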
The class provides the following hook functions for inserting additional implementations in the training, validation and testing lifecycle:
- before_training_epoch: executed before each epoch of training starts
- before_training_step: executed before the training of each batch starts
- on_training_step: the training process for each batch
- after_training_step: executed after the training of each batch
- after_training_epoch: executed after each epoch of training
- before_validation_epoch: executed before each round of validation starts
- before_validation_step: executed before the validation of each batch
- on_validation_step: the validation process for each batch
- after_validation_step: executed after the validation of each batch
- after_validation_epoch: executed after each round of validation
- before_test_epoch: executed before each round of testing starts
- before_test_step: executed before the test of each batch
- on_test_step: the test process for each batch
- after_test_step: executed after the test of each batch
- after_test_epoch: executed after each round of testing
To customize behavior, inherit from the class and override the relevant hook functions:
from torcheeg.trainers import CGANTrainer

class MyCGANTrainer(CGANTrainer):
    def before_training_epoch(self, epoch_id: int, num_epochs: int):
        # Do something here.
        super().before_training_epoch(epoch_id, num_epochs)
To use multiple GPUs for parallel computing, specify the indices of the GPUs to use in your Python file:
trainer = CGANTrainer(generator, discriminator, device_ids=[1, 2, 7])
trainer.fit(train_loader, val_loader)
trainer.test(test_loader)
Then, run your Python file with torch.distributed.launch or torchrun:
python -m torch.distributed.launch \
    --nproc_per_node=3 \
    --nnodes=1 \
    --node_rank=0 \
    --master_addr="localhost" \
    --master_port=2345 \
    your_python_file.py
Here, nproc_per_node is the number of GPUs specified in device_ids.
- Parameters
generator (nn.Module) – The generator model for EEG signal generation. Its inputs are Gaussian-distributed random vectors and its outputs are generated EEG signals. The dimension of the input vector should be defined in the in_channels attribute. The output layer does not need a softmax activation function.
discriminator (nn.Module) – The discriminator model, which determines whether an EEG signal is real or generated. Its output dimension should be one (i.e., a single score distinguishing real from fake samples). The output layer does not need a sigmoid activation function.
generator_lr (float) – The learning rate of the generator. (default: 0.0001)
discriminator_lr (float) – The learning rate of the discriminator. (default: 0.0001)
lambd (float) – The weight of the gradient penalty loss, trading off the adversarial training loss against the gradient penalty loss. (default: 1.0)
weight_decay (float) – The weight decay (L2 penalty). (default: 0.0)
device_ids (list) – The CPU is used if the list is empty. If the list contains the indices of multiple GPUs, the script must be launched with torch.distributed.launch or torchrun. (default: [])
ddp_sync_bn (bool) – Whether to replace the batch normalization layers in the network with cross-GPU synchronized batch normalization. Only valid when the length of device_ids is greater than one. (default: True)
ddp_replace_sampler (bool) – Whether to replace the sampler in the dataloader with DistributedSampler. Only valid when the length of device_ids is greater than one. (default: True)
ddp_val (bool) – Whether to use multi-GPU acceleration for the validation set. For experiments that are sensitive to the order of the input data, set ddp_val to False. Only valid when the length of device_ids is greater than one. (default: True)
ddp_test (bool) – Whether to use multi-GPU acceleration for the test set. For experiments that are sensitive to the order of the input data, set ddp_test to False. Only valid when the length of device_ids is greater than one. (default: True)
- fit(train_loader: DataLoader, val_loader: DataLoader, num_epochs: int = 1, **kwargs)
- Parameters
train_loader (DataLoader) – Iterable DataLoader for traversing the training data batch (torch.utils.data.dataloader.DataLoader, torch_geometric.loader.DataLoader, etc).
val_loader (DataLoader) – Iterable DataLoader for traversing the validation data batch (torch.utils.data.dataloader.DataLoader, torch_geometric.loader.DataLoader, etc).
num_epochs (int) – The number of training epochs. (default: 1)
- test(test_loader: DataLoader, **kwargs)
- Parameters
test_loader (DataLoader) – Iterable DataLoader for traversing the test data batch (torch.utils.data.dataloader.DataLoader, torch_geometric.loader.DataLoader, etc).
- sample(num_samples: int, labels: Optional[Tensor] = None) -> Tensor
Sample from the latent space and return the generated results.
- Parameters
num_samples (int) – Number of samples.
labels (torch.Tensor) – Category labels (int) for a batch of samples. The shape should be [n,], where n corresponds to the batch size. If not provided, a batch of randomly generated categories will be used.
- Returns
the generated samples.
- Return type
torch.Tensor
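The documented contract of sample() for the labels argument can be sketched in plain Python. resolve_sample_labels and its num_classes and seed arguments are hypothetical, and the check that n equals num_samples is an assumption, since the text above only ties n to the batch size. With a trained trainer, a call such as trainer.sample(num_samples=4, labels=torch.tensor([0, 1, 0, 1])) would request two samples of each class.

```python
import random
from typing import List, Optional, Sequence


def resolve_sample_labels(num_samples: int, num_classes: int,
                          labels: Optional[Sequence[int]] = None,
                          seed: Optional[int] = None) -> List[int]:
    """Hypothetical helper mirroring the documented behavior of sample()."""
    if labels is None:
        # No labels given: draw one random category per requested sample.
        rng = random.Random(seed)
        return [rng.randrange(num_classes) for _ in range(num_samples)]
    labels = list(labels)
    # Assumption: one label per generated sample, i.e. shape [n,] with n == num_samples.
    if len(labels) != num_samples:
        raise ValueError("labels must have shape [n,] with n == num_samples")
    return labels
```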
trainers.VAETrainer
- class torcheeg.trainers.VAETrainer(encoder: Module, decoder: Module, lr: float = 0.0001, weight_decay: float = 0.0, beta: float = 1.0, device_ids: List[int] = [], ddp_sync_bn: bool = True, ddp_replace_sampler: bool = True, ddp_val: bool = True, ddp_test: bool = True)
A variational autoencoder consists of two parts: an encoder and a decoder. The encoder compresses the input into a latent space. The decoder takes a vector sampled from the latent space as input and produces an output as similar as possible to the ground truth. A KL-divergence term pushes the latent distribution toward a Gaussian distribution, and sampling is kept differentiable by the reparameterization trick. This class implements training, testing, and the generation of new EEG samples for variational autoencoders.
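from torcheeg.models import BEncoder, BDecoder
from torcheeg.trainers import VAETrainer

# train_loader, val_loader, and test_loader are defined elsewhere.
encoder = BEncoder(in_channels=4)
decoder = BDecoder(in_channels=64, out_channels=4)
trainer = VAETrainer(encoder, decoder)
trainer.fit(train_loader, val_loader)
trainer.test(test_loader)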
The class provides the following hook functions for inserting additional implementations in the training, validation and testing lifecycle:
- before_training_epoch: executed before each epoch of training starts
- before_training_step: executed before the training of each batch starts
- on_training_step: the training process for each batch
- after_training_step: executed after the training of each batch
- after_training_epoch: executed after each epoch of training
- before_validation_epoch: executed before each round of validation starts
- before_validation_step: executed before the validation of each batch
- on_validation_step: the validation process for each batch
- after_validation_step: executed after the validation of each batch
- after_validation_epoch: executed after each round of validation
- before_test_epoch: executed before each round of testing starts
- before_test_step: executed before the test of each batch
- on_test_step: the test process for each batch
- after_test_step: executed after the test of each batch
- after_test_epoch: executed after each round of testing
To customize behavior, inherit from the class and override the relevant hook functions:
from torcheeg.trainers import VAETrainer

class MyVAETrainer(VAETrainer):
    def before_training_epoch(self, epoch_id: int, num_epochs: int):
        # Do something here.
        super().before_training_epoch(epoch_id, num_epochs)
To use multiple GPUs for parallel computing, specify the indices of the GPUs to use in your Python file:
trainer = VAETrainer(encoder, decoder, device_ids=[1, 2, 7])
trainer.fit(train_loader, val_loader)
trainer.test(test_loader)
Then, run your Python file with torch.distributed.launch or torchrun:
python -m torch.distributed.launch \
    --nproc_per_node=3 \
    --nnodes=1 \
    --node_rank=0 \
    --master_addr="localhost" \
    --master_port=2345 \
    your_python_file.py
Here, nproc_per_node is the number of GPUs specified in device_ids.
- Parameters
encoder (nn.Module) – The encoder. Its inputs are EEG signals and its outputs are two batches of vectors of the same dimension, representing the mean and variance estimated for the reparameterization trick.
decoder (nn.Module) – The decoder, which generates EEG signals from the latent variables encoded by the encoder. The dimension of its input vector should be defined in the in_channels attribute.
lr (float) – The learning rate. (default: 0.0001)
weight_decay (float) – The weight decay (L2 penalty). (default: 0.0)
beta (float) – The weight of the KL divergence in the loss function; see betaVAE. (default: 1.0)
device_ids (list) – The CPU is used if the list is empty. If the list contains the indices of multiple GPUs, the script must be launched with torch.distributed.launch or torchrun. (default: [])
ddp_sync_bn (bool) – Whether to replace the batch normalization layers in the network with cross-GPU synchronized batch normalization. Only valid when the length of device_ids is greater than one. (default: True)
ddp_replace_sampler (bool) – Whether to replace the sampler in the dataloader with DistributedSampler. Only valid when the length of device_ids is greater than one. (default: True)
ddp_val (bool) – Whether to use multi-GPU acceleration for the validation set. For experiments that are sensitive to the order of the input data, set ddp_val to False. Only valid when the length of device_ids is greater than one. (default: True)
ddp_test (bool) – Whether to use multi-GPU acceleration for the test set. For experiments that are sensitive to the order of the input data, set ddp_test to False. Only valid when the length of device_ids is greater than one. (default: True)
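The role of beta can be illustrated with a small sketch of the standard beta-VAE objective. It assumes the encoder's second output is a log-variance (a common convention; the text above only says "variance") and uses mean squared error as the reconstruction term; both choices, and the function names, are assumptions rather than torcheeg's code.

```python
import math
import random
from typing import List, Sequence


def reparameterize(mu: Sequence[float], logvar: Sequence[float],
                   rng: random.Random) -> List[float]:
    # z = mu + sigma * eps with eps ~ N(0, 1) and sigma = exp(0.5 * logvar).
    return [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0) for m, lv in zip(mu, logvar)]


def kl_to_standard_normal(mu: Sequence[float], logvar: Sequence[float]) -> float:
    # Closed-form KL(N(mu, sigma^2) || N(0, 1)), summed over latent dimensions.
    return -0.5 * sum(1.0 + lv - m * m - math.exp(lv) for m, lv in zip(mu, logvar))


def beta_vae_loss(x: Sequence[float], x_rec: Sequence[float],
                  mu: Sequence[float], logvar: Sequence[float],
                  beta: float = 1.0) -> float:
    # Mean squared reconstruction error plus the beta-weighted KL divergence.
    mse = sum((a - b) ** 2 for a, b in zip(x, x_rec)) / len(x)
    return mse + beta * kl_to_standard_normal(mu, logvar)
```

Setting beta above 1.0 trades reconstruction quality for a latent space closer to the standard normal prior, which is what makes sampling new EEG signals from N(0, 1) sensible.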
- fit(train_loader: DataLoader, val_loader: DataLoader, num_epochs: int = 1, **kwargs)
Train the model on the training set and evaluate it on the validation set after each training epoch.
- Parameters
train_loader (DataLoader) – Iterable DataLoader for traversing the training data batch (torch.utils.data.dataloader.DataLoader, torch_geometric.loader.DataLoader, etc).
val_loader (DataLoader) – Iterable DataLoader for traversing the validation data batch (torch.utils.data.dataloader.DataLoader, torch_geometric.loader.DataLoader, etc).
num_epochs (int) – The number of training epochs. (default: 1)
trainers.CVAETrainer
- class torcheeg.trainers.CVAETrainer(encoder: Module, decoder: Module, lr: float = 0.0001, weight_decay: float = 0.0, beta: float = 1.0, device_ids: List[int] = [], ddp_sync_bn: bool = True, ddp_replace_sampler: bool = True, ddp_val: bool = True, ddp_test: bool = True)
A variational autoencoder consists of two parts: an encoder and a decoder. The encoder compresses the input into a latent space. The decoder takes a vector sampled from the latent space as input and produces an output as similar as possible to the ground truth. A KL-divergence term pushes the latent distribution toward a Gaussian distribution, and sampling is kept differentiable by the reparameterization trick. This class implements training, testing, and the generation of new EEG samples for conditional variational autoencoders, in which category labels are additionally provided to the encoder and decoder.
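One common way to condition a VAE on category labels is to append a one-hot code of the label to the latent vector before decoding, as sketched below. Whether BCEncoder and BCDecoder use this scheme or, for example, a learned label embedding is not stated here, so treat the helper names and the concatenation order as assumptions.

```python
from typing import List, Sequence


def one_hot(label: int, num_classes: int) -> List[float]:
    # Vector of zeros with a single 1.0 at the label's index.
    vec = [0.0] * num_classes
    vec[label] = 1.0
    return vec


def condition_latent(z: Sequence[float], label: int, num_classes: int) -> List[float]:
    # Append the one-hot class code so the decoder knows which category to generate.
    return list(z) + one_hot(label, num_classes)
```

Because the class is supplied explicitly, the latent vector only needs to capture within-class variation, and the same z can be decoded under different labels.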
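from torcheeg.models import BCEncoder, BCDecoder
from torcheeg.trainers import CVAETrainer

# train_loader, val_loader, and test_loader are defined elsewhere.
encoder = BCEncoder(in_channels=4, num_classes=2)
decoder = BCDecoder(in_channels=64, out_channels=4, num_classes=2)
trainer = CVAETrainer(encoder, decoder)
trainer.fit(train_loader, val_loader)
trainer.test(test_loader)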
The class provides the following hook functions for inserting additional implementations in the training, validation and testing lifecycle:
- before_training_epoch: executed before each epoch of training starts
- before_training_step: executed before the training of each batch starts
- on_training_step: the training process for each batch
- after_training_step: executed after the training of each batch
- after_training_epoch: executed after each epoch of training
- before_validation_epoch: executed before each round of validation starts
- before_validation_step: executed before the validation of each batch
- on_validation_step: the validation process for each batch
- after_validation_step: executed after the validation of each batch
- after_validation_epoch: executed after each round of validation
- before_test_epoch: executed before each round of testing starts
- before_test_step: executed before the test of each batch
- on_test_step: the test process for each batch
- after_test_step: executed after the test of each batch
- after_test_epoch: executed after each round of testing
To customize behavior, inherit from the class and override the relevant hook functions:
from torcheeg.trainers import CVAETrainer

class MyCVAETrainer(CVAETrainer):
    def before_training_epoch(self, epoch_id: int, num_epochs: int):
        # Do something here.
        super().before_training_epoch(epoch_id, num_epochs)
To use multiple GPUs for parallel computing, specify the indices of the GPUs to use in your Python file:
trainer = CVAETrainer(encoder, decoder, device_ids=[1, 2, 7])
trainer.fit(train_loader, val_loader)
trainer.test(test_loader)
Then, run your Python file with torch.distributed.launch or torchrun:
python -m torch.distributed.launch \
    --nproc_per_node=3 \
    --nnodes=1 \
    --node_rank=0 \
    --master_addr="localhost" \
    --master_port=2345 \
    your_python_file.py
Here, nproc_per_node is the number of GPUs specified in device_ids.
- Parameters
encoder (nn.Module) – The encoder. Its inputs are EEG signals and its outputs are two batches of vectors of the same dimension, representing the mean and variance estimated for the reparameterization trick.
decoder (nn.Module) – The decoder, which generates EEG signals from the latent variables encoded by the encoder. The dimension of its input vector should be defined in the in_channels attribute.
lr (float) – The learning rate. (default: 0.0001)
weight_decay (float) – The weight decay (L2 penalty). (default: 0.0)
beta (float) – The weight of the KL divergence in the loss function; see betaVAE. (default: 1.0)
device_ids (list) – The CPU is used if the list is empty. If the list contains the indices of multiple GPUs, the script must be launched with torch.distributed.launch or torchrun. (default: [])
ddp_sync_bn (bool) – Whether to replace the batch normalization layers in the network with cross-GPU synchronized batch normalization. Only valid when the length of device_ids is greater than one. (default: True)
ddp_replace_sampler (bool) – Whether to replace the sampler in the dataloader with DistributedSampler. Only valid when the length of device_ids is greater than one. (default: True)
ddp_val (bool) – Whether to use multi-GPU acceleration for the validation set. For experiments that are sensitive to the order of the input data, set ddp_val to False. Only valid when the length of device_ids is greater than one. (default: True)
ddp_test (bool) – Whether to use multi-GPU acceleration for the test set. For experiments that are sensitive to the order of the input data, set ddp_test to False. Only valid when the length of device_ids is greater than one. (default: True)
- fit(train_loader: DataLoader, val_loader: DataLoader, num_epochs: int = 1, **kwargs)
Train the model on the training set and evaluate it on the validation set after each training epoch.
- Parameters
train_loader (DataLoader) – Iterable DataLoader for traversing the training data batch (torch.utils.data.dataloader.DataLoader, torch_geometric.loader.DataLoader, etc).
val_loader (DataLoader) – Iterable DataLoader for traversing the validation data batch (torch.utils.data.dataloader.DataLoader, torch_geometric.loader.DataLoader, etc).
num_epochs (int) – The number of training epochs. (default: 1)
- test(test_loader: DataLoader, **kwargs)
Evaluate the performance of the model on the test set.
- Parameters
test_loader (DataLoader) – Iterable DataLoader for traversing the test data batch (torch.utils.data.dataloader.DataLoader, torch_geometric.loader.DataLoader, etc).
- sample(num_samples: int, labels: Optional[Tensor] = None) -> Tensor
Sample from the latent space and return the generated results.
- Parameters
num_samples (int) – Number of samples.
labels (torch.Tensor) – Category labels (int) for a batch of samples The shape should be
[n,]
. Here,n
corresponds to the batch size. If not provided, a batch of randomly generated categories will be used.
- Returns
the generated samples.
- Return type
torch.Tensor
trainers.DDPMTrainer
- class torcheeg.trainers.DDPMTrainer(unet: Module, lr: float = 0.0003, beta_timesteps: int = 1000, beta_start: float = 0.0001, beta_end: float = 0.02, device_ids: List[int] = [], ddp_sync_bn: bool = True, ddp_replace_sampler: bool = True, ddp_val: bool = True, ddp_test: bool = True)
The diffusion model consists of two processes: the forward process and the backward process. The forward process gradually adds Gaussian noise to an image until it becomes random noise, while the backward process removes that noise step by step. An attention-based UNet is trained to model the backward process: starting from random noise, it gradually denoises the input until an image is generated, so the trained UNet can be used to generate simulated images from random noise. This class implements the training, testing, and new-sample inference of DDPM.
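```python
from torcheeg.models import BUNet
from torcheeg.trainers import DDPMTrainer

# train_loader, val_loader and test_loader are existing DataLoaders.
unet = BUNet(in_channels=4)
trainer = DDPMTrainer(unet)
trainer.fit(train_loader, val_loader)
trainer.test(test_loader)
```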
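The forward process described above has a closed form: with a linear beta schedule and a_bar_t = prod_{s<=t} (1 - beta_s), a noisy sample at step t is x_t = sqrt(a_bar_t) * x_0 + sqrt(1 - a_bar_t) * eps. The sketch below implements that jump and the standard DDPM noise-prediction loss in plain PyTorch, using the default beta_timesteps, beta_start and beta_end from the DDPMTrainer signature. It assumes the trainer follows the standard DDPM formulation; q_sample, ddpm_loss and ToyDenoiser are hypothetical names, and ToyDenoiser merely stands in for BUNet.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Linear beta schedule using the DDPMTrainer defaults.
beta_timesteps, beta_start, beta_end = 1000, 1e-4, 2e-2
betas = torch.linspace(beta_start, beta_end, beta_timesteps)
alpha_bars = torch.cumprod(1.0 - betas, dim=0)  # a_bar_t = prod_{s<=t} (1 - beta_s)


def q_sample(x0, t, noise):
    """Forward process in one jump: x_t = sqrt(a_bar_t) * x0 + sqrt(1 - a_bar_t) * noise."""
    a_bar = alpha_bars[t].view(-1, *([1] * (x0.dim() - 1)))  # broadcast over sample dims
    return a_bar.sqrt() * x0 + (1.0 - a_bar).sqrt() * noise


class ToyDenoiser(nn.Module):
    """Hypothetical stand-in for a UNet: a 1x1 convolution that ignores the timestep."""

    def __init__(self, channels: int):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=1)

    def forward(self, x, t):
        return self.conv(x)


def ddpm_loss(model, x0):
    """Pick a random timestep per sample, add noise, and regress the added noise."""
    t = torch.randint(0, beta_timesteps, (x0.shape[0],))
    noise = torch.randn_like(x0)
    return F.mse_loss(model(q_sample(x0, t, noise), t), noise)


x0 = torch.randn(8, 4, 9, 9)  # toy batch: 8 samples, 4 channels, 9x9 grid
loss = ddpm_loss(ToyDenoiser(4), x0)
loss.backward()
```

At the last step a_bar_t is close to zero, which is why x_T is treated as pure noise when sampling starts.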
The class provides the following hook functions for inserting additional logic into the training, validation, and testing lifecycle:
- before_training_epoch: executed before each training epoch starts
- before_training_step: executed before each training batch starts
- on_training_step: the training process for each batch
- after_training_step: executed after each training batch
- after_training_epoch: executed after each training epoch
- before_validation_epoch: executed before each round of validation starts
- before_validation_step: executed before each validation batch
- on_validation_step: the validation process for each batch
- after_validation_step: executed after each validation batch
- after_validation_epoch: executed after each round of validation
- before_test_epoch: executed before each round of testing starts
- before_test_step: executed before each test batch
- on_test_step: the test process for each batch
- after_test_step: executed after each test batch
- after_test_epoch: executed after each round of testing
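As a rough mental model of the order in which these hooks run, the sketch below records the hook names for one epoch of fit followed by test. It assumes validation runs once after every training epoch, as the fit documentation states; HookOrderSketch is a hypothetical class, and torcheeg's actual loop (including the arguments passed to each hook) may differ.

```python
from typing import Iterable, List


class HookOrderSketch:
    """Hypothetical, simplified loop that records when each hook would fire."""

    def __init__(self) -> None:
        self.calls: List[str] = []

    def call_hook(self, name: str) -> None:
        # A real trainer would call an overridable method here instead of logging its name.
        self.calls.append(name)

    def _run_batches(self, stage: str, loader: Iterable) -> None:
        for _batch in loader:
            self.call_hook(f"before_{stage}_step")
            self.call_hook(f"on_{stage}_step")
            self.call_hook(f"after_{stage}_step")

    def fit(self, train_loader: Iterable, val_loader: Iterable, num_epochs: int = 1) -> None:
        for _epoch in range(num_epochs):
            self.call_hook("before_training_epoch")
            self._run_batches("training", train_loader)
            self.call_hook("after_training_epoch")
            # Validation follows every training epoch.
            self.call_hook("before_validation_epoch")
            self._run_batches("validation", val_loader)
            self.call_hook("after_validation_epoch")

    def test(self, test_loader: Iterable) -> None:
        self.call_hook("before_test_epoch")
        self._run_batches("test", test_loader)
        self.call_hook("after_test_epoch")


loop = HookOrderSketch()
loop.fit(train_loader=[0], val_loader=[0], num_epochs=1)  # one batch per loader
loop.test(test_loader=[0])
```

With more epochs, the training and validation block simply repeats; the step hooks fire once per batch.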
To customize an operation, subclass the trainer and override the corresponding hook function:
```python
from torcheeg.trainers import DDPMTrainer

class MyDDPMTrainer(DDPMTrainer):
    def before_training_epoch(self, epoch_id: int, num_epochs: int):
        # Do something here.
        super().before_training_epoch(epoch_id, num_epochs)
```
To use multiple GPUs for parallel computing, specify the indices of the GPUs to use in your Python file:
```python
from torcheeg.trainers import DDPMTrainer

trainer = DDPMTrainer(unet, device_ids=[1, 2, 7])
trainer.fit(train_loader, val_loader)
trainer.test(test_loader)
```
Then, run your Python file with torch.distributed.launch or torchrun:
```bash
python -m torch.distributed.launch \
    --nproc_per_node=3 \
    --nnodes=1 \
    --node_rank=0 \
    --master_addr="localhost" \
    --master_port=2345 \
    your_python_file.py
```
Here, nproc_per_node is the number of GPUs you specified.
- Parameters
unet (nn.Module) – The attention-based UNet trained to model the backward process, i.e., to denoise its input step by step (e.g., BUNet).
lr (float) – The learning rate. (default: 0.0003)
beta_timesteps (int) – The number of timesteps of the linear beta scheduler, i.e., the number of diffusion steps. (default: 1000)
beta_start (float) – The start point of the linear beta scheduler used to sample betas. (default: 1e-4)
beta_end (float) – The end point of the linear beta scheduler used to sample betas. (default: 2e-2)
device_ids (list) – Use the CPU if the list is empty. If the list contains the indices of multiple GPUs, the script must be launched with torch.distributed.launch or torchrun. (default: [])
ddp_sync_bn (bool) – Whether to replace the batch normalization layers in the network with cross-GPU synchronized batch normalization. Only takes effect when device_ids contains more than one GPU index. (default: True)
ddp_replace_sampler (bool) – Whether to replace the sampler of the DataLoader with DistributedSampler. Only takes effect when device_ids contains more than one GPU index. (default: True)
ddp_val (bool) – Whether to use multi-GPU acceleration for the validation set. For experiments that are sensitive to the order of the input data, set ddp_val to False. Only takes effect when device_ids contains more than one GPU index. (default: True)
ddp_test (bool) – Whether to use multi-GPU acceleration for the test set. For experiments that are sensitive to the order of the input data, set ddp_test to False. Only takes effect when device_ids contains more than one GPU index. (default: True)
- fit(train_loader: DataLoader, val_loader: DataLoader, num_epochs: int = 1, **kwargs)
Train the model on the training set and validate the results on the validation set after each training epoch.
- Parameters
train_loader (DataLoader) – Iterable DataLoader for traversing the training data batches (torch.utils.data.dataloader.DataLoader, torch_geometric.loader.DataLoader, etc.).
val_loader (DataLoader) – Iterable DataLoader for traversing the validation data batches (torch.utils.data.dataloader.DataLoader, torch_geometric.loader.DataLoader, etc.).
num_epochs (int) – The number of training epochs. (default: 1)
trainers.CDDPMTrainer
- class torcheeg.trainers.CDDPMTrainer(unet: Module, lr: float = 0.0003, beta_timesteps: int = 1000, beta_start: float = 0.0001, beta_end: float = 0.02, device_ids: List[int] = [], ddp_sync_bn: bool = True, ddp_replace_sampler: bool = True, ddp_val: bool = True, ddp_test: bool = True)
The diffusion model consists of two processes: the forward process and the backward process. The forward process gradually adds Gaussian noise to an image until it becomes random noise, while the backward process removes that noise step by step. An attention-based UNet is trained to model the backward process: starting from random noise, it gradually denoises the input until an image is generated, so the trained UNet can be used to generate simulated images from random noise. In the conditional variant, the UNet additionally receives label information, which guides the denoising process toward samples of the given category. This class implements the training, testing, and new-sample inference of the conditional DDPM.
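```python
from torcheeg.models import BCUNet
from torcheeg.trainers import CDDPMTrainer

# train_loader, val_loader and test_loader are existing DataLoaders.
unet = BCUNet(in_channels=4, num_classes=2)
trainer = CDDPMTrainer(unet)
trainer.fit(train_loader, val_loader)
trainer.test(test_loader)
```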
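The backward process can be sketched as standard DDPM ancestral sampling, here conditioned on labels in the way the sample method describes: if labels is omitted, random categories of shape [n,] are drawn. This is a sketch under assumptions, not the library's implementation: the conditional network is assumed to take (x_t, t, y), the per-step variance is assumed to be beta_t, and ToyConditionalDenoiser and sample_conditional are hypothetical names standing in for BCUNet and CDDPMTrainer.sample.

```python
import torch
import torch.nn as nn

# Linear beta schedule with the CDDPMTrainer defaults.
beta_timesteps, beta_start, beta_end = 1000, 1e-4, 2e-2
betas = torch.linspace(beta_start, beta_end, beta_timesteps)
alphas = 1.0 - betas
alpha_bars = torch.cumprod(alphas, dim=0)


class ToyConditionalDenoiser(nn.Module):
    """Hypothetical stand-in for a conditional UNet: predicts noise from (x_t, t, y)."""

    def __init__(self, channels: int, num_classes: int):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=1)
        self.label_emb = nn.Embedding(num_classes, channels)

    def forward(self, x, t, y):
        # A per-class, per-channel offset makes the prediction depend on the label.
        return self.conv(x) + self.label_emb(y)[:, :, None, None]


@torch.no_grad()
def sample_conditional(model, num_samples, sample_size, num_classes, labels=None):
    if labels is None:
        # No labels given: draw random categories, shape [n,].
        labels = torch.randint(0, num_classes, (num_samples,))
    x = torch.randn(num_samples, *sample_size)  # x_T: pure Gaussian noise
    for t in reversed(range(beta_timesteps)):
        t_batch = torch.full((num_samples,), t, dtype=torch.long)
        eps = model(x, t_batch, labels)
        # Mean of p(x_{t-1} | x_t) computed from the predicted noise.
        mean = (x - betas[t] / (1.0 - alpha_bars[t]).sqrt() * eps) / alphas[t].sqrt()
        if t > 0:
            # Add fresh noise with variance beta_t at every step except the last.
            x = mean + betas[t].sqrt() * torch.randn_like(x)
        else:
            x = mean
    return x


model = ToyConditionalDenoiser(channels=4, num_classes=2)
samples = sample_conditional(model, 2, (4, 9, 9), num_classes=2, labels=torch.tensor([0, 1]))
```

An untrained toy model produces meaningless values; the point is the shape of the loop, where the same labels tensor is passed to the network at every denoising step.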
The class provides the following hook functions for inserting additional logic into the training, validation, and testing lifecycle:
- before_training_epoch: executed before each training epoch starts
- before_training_step: executed before each training batch starts
- on_training_step: the training process for each batch
- after_training_step: executed after each training batch
- after_training_epoch: executed after each training epoch
- before_validation_epoch: executed before each round of validation starts
- before_validation_step: executed before each validation batch
- on_validation_step: the validation process for each batch
- after_validation_step: executed after each validation batch
- after_validation_epoch: executed after each round of validation
- before_test_epoch: executed before each round of testing starts
- before_test_step: executed before each test batch
- on_test_step: the test process for each batch
- after_test_step: executed after each test batch
- after_test_epoch: executed after each round of testing
To customize an operation, subclass the trainer and override the corresponding hook function:
```python
from torcheeg.trainers import CDDPMTrainer

class MyCDDPMTrainer(CDDPMTrainer):
    def before_training_epoch(self, epoch_id: int, num_epochs: int):
        # Do something here.
        super().before_training_epoch(epoch_id, num_epochs)
```
To use multiple GPUs for parallel computing, specify the indices of the GPUs to use in your Python file:
```python
from torcheeg.trainers import CDDPMTrainer

trainer = CDDPMTrainer(unet, device_ids=[1, 2, 7])
trainer.fit(train_loader, val_loader)
trainer.test(test_loader)
```
Then, run your Python file with torch.distributed.launch or torchrun:
```bash
python -m torch.distributed.launch \
    --nproc_per_node=3 \
    --nnodes=1 \
    --node_rank=0 \
    --master_addr="localhost" \
    --master_port=2345 \
    your_python_file.py
```
Here, nproc_per_node is the number of GPUs you specified.
- Parameters
unet (nn.Module) – The conditional, attention-based UNet trained to model the backward process; in addition to the noisy input, it receives label information that guides the denoising (e.g., BCUNet).
lr (float) – The learning rate. (default: 0.0003)
beta_timesteps (int) – The number of timesteps of the linear beta scheduler, i.e., the number of diffusion steps. (default: 1000)
beta_start (float) – The start point of the linear beta scheduler used to sample betas. (default: 1e-4)
beta_end (float) – The end point of the linear beta scheduler used to sample betas. (default: 2e-2)
device_ids (list) – Use the CPU if the list is empty. If the list contains the indices of multiple GPUs, the script must be launched with torch.distributed.launch or torchrun. (default: [])
ddp_sync_bn (bool) – Whether to replace the batch normalization layers in the network with cross-GPU synchronized batch normalization. Only takes effect when device_ids contains more than one GPU index. (default: True)
ddp_replace_sampler (bool) – Whether to replace the sampler of the DataLoader with DistributedSampler. Only takes effect when device_ids contains more than one GPU index. (default: True)
ddp_val (bool) – Whether to use multi-GPU acceleration for the validation set. For experiments that are sensitive to the order of the input data, set ddp_val to False. Only takes effect when device_ids contains more than one GPU index. (default: True)
ddp_test (bool) – Whether to use multi-GPU acceleration for the test set. For experiments that are sensitive to the order of the input data, set ddp_test to False. Only takes effect when device_ids contains more than one GPU index. (default: True)
- fit(train_loader: DataLoader, val_loader: DataLoader, num_epochs: int = 1, **kwargs)
Train the model on the training set and validate the results on the validation set after each training epoch.
- Parameters
train_loader (DataLoader) – Iterable DataLoader for traversing the training data batches (torch.utils.data.dataloader.DataLoader, torch_geometric.loader.DataLoader, etc.).
val_loader (DataLoader) – Iterable DataLoader for traversing the validation data batches (torch.utils.data.dataloader.DataLoader, torch_geometric.loader.DataLoader, etc.).
num_epochs (int) – The number of training epochs. (default: 1)
- test(test_loader: DataLoader, **kwargs)
Evaluate the performance of the model on the test set.
- Parameters
test_loader (DataLoader) – Iterable DataLoader for traversing the test data batches (torch.utils.data.dataloader.DataLoader, torch_geometric.loader.DataLoader, etc.).
- sample(num_samples: int, sample_size: Tuple[int], labels: Optional[Tensor] = None) → Tensor
Samples from the latent space and returns the generated results.
- Parameters
num_samples (int) – The number of samples to generate.
sample_size (tuple) – The shape of a single sample.
labels (torch.Tensor) – Category labels (int) for a batch of samples. The shape should be [n,], where n corresponds to the batch size. If not provided, a batch of randomly generated categories is used.
- Returns
The generated samples.
- Return type
torch.Tensor
trainers.GlowTrainer
- class torcheeg.trainers.GlowTrainer(glow: Module, lr: float = 0.0001, grad_norm_clip: float = 50.0, loss_scale: float = 0.001, device_ids: List[int] = [], ddp_sync_bn: bool = True, ddp_replace_sampler: bool = True, ddp_val: bool = True, ddp_test: bool = True)
This class implements the training, testing, and new-EEG inference of normalizing flow-based models. Glow trains an encoder that maps the input to a hidden variable following the standard normal distribution. By design, the encoder is invertible, so once it is trained, the corresponding decoder can generate samples from a Gaussian distribution through the inverse operation. Compared with vanilla normalizing flow-based models, Glow is an easy-to-use flow-based model that replaces the fixed permutation of the channel axis with an invertible 1x1 convolution.
Below is a recommended suite for EEG generation:
```python
from torcheeg.models import BGlow
from torcheeg.trainers import GlowTrainer

model = BGlow(in_channels=4)
trainer = GlowTrainer(model)
trainer.fit(train_loader, val_loader)
trainer.test(test_loader)
```
Below is a recommended suite for conditional EEG generation:
```python
from torcheeg.models import BGlow
from torcheeg.trainers import GlowTrainer

model = BGlow(in_channels=4, num_classes=2)
trainer = GlowTrainer(model)
trainer.fit(train_loader, val_loader)
trainer.test(test_loader)
```
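The invertible 1x1 convolution mentioned above is a learned C x C matrix W applied to the channel vector at every spatial position; a fixed channel permutation is the special case where W is a permutation matrix. Because W is invertible, the layer can be undone exactly, and its contribution to the log-likelihood is height * width * log|det W|. The sketch below follows that formulation in plain PyTorch; InvertibleConv1x1 is a hypothetical name, and BGlow's internals may differ.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class InvertibleConv1x1(nn.Module):
    """Hypothetical sketch of Glow's invertible 1x1 convolution (learned channel mixing)."""

    def __init__(self, channels: int):
        super().__init__()
        # Start from a random orthogonal matrix, which is guaranteed to be invertible.
        weight, _ = torch.linalg.qr(torch.randn(channels, channels))
        self.weight = nn.Parameter(weight)

    def forward(self, x):
        # x: [batch, channels, height, width]; each channel vector is multiplied by W.
        height, width = x.shape[2], x.shape[3]
        z = F.conv2d(x, self.weight[:, :, None, None])
        # log|det| of the Jacobian, added to the flow's log-likelihood.
        logdet = height * width * torch.linalg.slogdet(self.weight).logabsdet
        return z, logdet

    def inverse(self, z):
        # Undo the channel mixing with the inverse matrix.
        return F.conv2d(z, torch.linalg.inv(self.weight)[:, :, None, None])


layer = InvertibleConv1x1(channels=4)
x = torch.randn(2, 4, 9, 9)
z, logdet = layer(x)
x_rec = layer.inverse(z)  # recovers x up to floating-point error
```

Learning W instead of fixing a permutation lets the model choose how channels are mixed between coupling layers, at the cost of computing a determinant and an inverse.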
The class provides the following hook functions for inserting additional logic into the training, validation, and testing lifecycle:
- before_training_epoch: executed before each training epoch starts
- before_training_step: executed before each training batch starts
- on_training_step: the training process for each batch
- after_training_step: executed after each training batch
- after_training_epoch: executed after each training epoch
- before_validation_epoch: executed before each round of validation starts
- before_validation_step: executed before each validation batch
- on_validation_step: the validation process for each batch
- after_validation_step: executed after each validation batch
- after_validation_epoch: executed after each round of validation
- before_test_epoch: executed before each round of testing starts
- before_test_step: executed before each test batch
- on_test_step: the test process for each batch
- after_test_step: executed after each test batch
- after_test_epoch: executed after each round of testing
To customize an operation, subclass the trainer and override the corresponding hook function:
```python
from torcheeg.trainers import GlowTrainer

class MyGlowTrainer(GlowTrainer):
    def before_training_epoch(self, epoch_id: int, num_epochs: int):
        # Do something here.
        super().before_training_epoch(epoch_id, num_epochs)
```
To use multiple GPUs for parallel computing, specify the indices of the GPUs to use in your Python file:
```python
from torcheeg.trainers import GlowTrainer

trainer = GlowTrainer(model, device_ids=[1, 2, 7])
trainer.fit(train_loader, val_loader)
trainer.test(test_loader)
```
Then, run your Python file with torch.distributed.launch or torchrun:
```bash
python -m torch.distributed.launch \
    --nproc_per_node=3 \
    --nnodes=1 \
    --node_rank=0 \
    --master_addr="localhost" \
    --master_port=2345 \
    your_python_file.py
```
Here, nproc_per_node is the number of GPUs you specified.
- Parameters
glow (nn.Module) – The normalizing flow-based model to train (e.g., BGlow).
lr (float) – The learning rate. (default: 0.0001)
grad_norm_clip (float) – The maximum norm used when clipping gradients. (default: 50.0)
loss_scale (float) – The factor by which the loss is scaled. (default: 0.001)
device_ids (list) – Use the CPU if the list is empty. If the list contains the indices of multiple GPUs, the script must be launched with torch.distributed.launch or torchrun. (default: [])
ddp_sync_bn (bool) – Whether to replace the batch normalization layers in the network with cross-GPU synchronized batch normalization. Only takes effect when device_ids contains more than one GPU index. (default: True)
ddp_replace_sampler (bool) – Whether to replace the sampler of the DataLoader with DistributedSampler. Only takes effect when device_ids contains more than one GPU index. (default: True)
ddp_val (bool) – Whether to use multi-GPU acceleration for the validation set. For experiments that are sensitive to the order of the input data, set ddp_val to False. Only takes effect when device_ids contains more than one GPU index. (default: True)
ddp_test (bool) – Whether to use multi-GPU acceleration for the test set. For experiments that are sensitive to the order of the input data, set ddp_test to False. Only takes effect when device_ids contains more than one GPU index. (default: True)
- fit(train_loader: DataLoader, val_loader: DataLoader, num_epochs: int = 1, **kwargs)
Train the model on the training set and validate the results on the validation set after each training epoch.
- Parameters
train_loader (DataLoader) – Iterable DataLoader for traversing the training data batches (torch.utils.data.dataloader.DataLoader, torch_geometric.loader.DataLoader, etc.).
val_loader (DataLoader) – Iterable DataLoader for traversing the validation data batches (torch.utils.data.dataloader.DataLoader, torch_geometric.loader.DataLoader, etc.).
num_epochs (int) – The number of training epochs. (default: 1)
- test(test_loader: DataLoader, **kwargs)
Evaluate the performance of the model on the test set.
- Parameters
test_loader (DataLoader) – Iterable DataLoader for traversing the test data batches (torch.utils.data.dataloader.DataLoader, torch_geometric.loader.DataLoader, etc.).
- sample(num_samples: int, temperature: float = 1.0, labels: Optional[Tensor] = None) → Tensor
Samples from the latent space and returns the generated results.
- Parameters
num_samples (int) – The number of samples to generate.
temperature (float) – The temperature hyper-parameter used when sampling from the Gaussian distribution. (default: 1.0)
labels (torch.Tensor) – Category labels (int) for a batch of samples. The shape should be [n,], where n corresponds to the batch size. If not provided, a batch of randomly generated categories is used.
- Returns
The generated samples.
- Return type
torch.Tensor
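In Glow, the temperature is conventionally applied by scaling the standard deviation of the latent Gaussian, z ~ N(0, T^2 I), before the invertible model maps z back to data space; lower temperatures typically trade sample diversity for fidelity. The helper below shows only that latent-sampling step, assuming GlowTrainer.sample follows this convention; sample_latent is a hypothetical name, and decoding through the trained flow is omitted.

```python
import torch


def sample_latent(num_samples, latent_shape, temperature=1.0, generator=None):
    """Draw z ~ N(0, temperature^2 * I) by scaling standard normal noise."""
    return temperature * torch.randn(num_samples, *latent_shape, generator=generator)


gen = torch.Generator().manual_seed(0)
z_cold = sample_latent(10000, (4,), temperature=0.5, generator=gen)
z_warm = sample_latent(10000, (4,), temperature=1.0, generator=gen)
```

The empirical standard deviation of z_cold is close to 0.5 and that of z_warm close to 1.0, so the temperature directly controls how far samples stray from the mode of the latent distribution.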