Shortcuts

Source code for torcheeg.trainers.classification_trainer

import math
from typing import List, Tuple, Optional

import torch
import numpy as np
import torch.nn as nn
import torchmetrics
from torch.utils.data import DataLoader

from .basic_trainer import BasicTrainer


class ClassificationTrainer(BasicTrainer):
    r'''
    A generic trainer class for EEG classification.

    .. code-block:: python

        trainer = ClassificationTrainer(model)
        trainer.fit(train_loader, val_loader)
        trainer.test(test_loader)

    The class provides the following hook functions for inserting additional implementations in the training, validation and testing lifecycle:

    - :obj:`before_training_epoch`: executed before each epoch of training starts
    - :obj:`before_training_step`: executed before each batch of training starts
    - :obj:`on_training_step`: the training process for each batch
    - :obj:`after_training_step`: execute after the training of each batch
    - :obj:`after_training_epoch`: executed after each epoch of training
    - :obj:`before_validation_epoch`: executed before each round of validation starts
    - :obj:`before_validation_step`: executed before the validation of each batch
    - :obj:`on_validation_step`: validation process for each batch
    - :obj:`after_validation_step`: executed after the validation of each batch
    - :obj:`after_validation_epoch`: executed after each round of validation
    - :obj:`before_test_epoch`: executed before each round of test starts
    - :obj:`before_test_step`: executed before the test of each batch
    - :obj:`on_test_step`: test process for each batch
    - :obj:`after_test_step`: executed after the test of each batch
    - :obj:`after_test_epoch`: executed after each round of test

    If you want to customize some operations, you just need to inherit the class and override the hook function:

    .. code-block:: python

        class MyClassificationTrainer(ClassificationTrainer):
            def before_training_epoch(self, epoch_id: int, num_epochs: int):
                # Do something here.
                super().before_training_epoch(epoch_id, num_epochs)

    If you want to use multiple GPUs for parallel computing, you need to specify the GPU indices you want to use in the python file:

    .. code-block:: python

        trainer = ClassificationTrainer(model, device_ids=[1, 2, 7])
        trainer.fit(train_loader, val_loader)
        trainer.test(test_loader)

    Then, you can use the :obj:`torch.distributed.launch` or :obj:`torchrun` to run your python file.

    .. code-block:: shell

        python -m torch.distributed.launch \
            --nproc_per_node=3 \
            --nnodes=1 \
            --node_rank=0 \
            --master_addr="localhost" \
            --master_port=2345 \
            your_python_file.py

    Here, :obj:`nproc_per_node` is the number of GPUs you specify.

    Args:
        model (nn.Module): The classification model, and the dimension of its output should be equal to the number of categories in the dataset. The output layer does not need to have a softmax activation function.
        num_classes (int, optional): The number of categories in the dataset. If :obj:`None`, the number of categories will be inferred from the attribute :obj:`num_classes` of the model. (default: :obj:`None`)
        lr (float): The learning rate. (default: :obj:`0.0001`)
        weight_decay (float): The weight decay (L2 penalty). (default: :obj:`0.0`)
        device_ids (list, optional): Use cpu if the list is empty or :obj:`None`. If the list contains indices of multiple GPUs, it needs to be launched with :obj:`torch.distributed.launch` or :obj:`torchrun`. (default: :obj:`None`)
        ddp_sync_bn (bool): Whether to replace batch normalization in network structure with cross-GPU synchronized batch normalization. Only valid when the length of :obj:`device_ids` is greater than one. (default: :obj:`True`)
        ddp_replace_sampler (bool): Whether to replace sampler in dataloader with :obj:`DistributedSampler`. Only valid when the length of :obj:`device_ids` is greater than one. (default: :obj:`True`)
        ddp_val (bool): Whether to use multi-GPU acceleration for the validation set. For experiments where data input order is sensitive, :obj:`ddp_val` should be set to :obj:`False`. Only valid when the length of :obj:`device_ids` is greater than one. (default: :obj:`True`)
        ddp_test (bool): Whether to use multi-GPU acceleration for the test set. For experiments where data input order is sensitive, :obj:`ddp_test` should be set to :obj:`False`. Only valid when the length of :obj:`device_ids` is greater than one. (default: :obj:`True`)

    Raises:
        ValueError: If :obj:`num_classes` is :obj:`None` and the model has no :obj:`num_classes` attribute.

    .. automethod:: fit
    .. automethod:: test
    '''
    def __init__(self,
                 model: nn.Module,
                 num_classes: Optional[int] = None,
                 lr: float = 1e-4,
                 weight_decay: float = 0.0,
                 device_ids: Optional[List[int]] = None,
                 ddp_sync_bn: bool = True,
                 ddp_replace_sampler: bool = True,
                 ddp_val: bool = True,
                 ddp_test: bool = True):
        # A literal `[]` default would be one list object shared by every
        # trainer ever constructed; build a fresh empty list per call instead.
        if device_ids is None:
            device_ids = []

        super(ClassificationTrainer,
              self).__init__(modules={'model': model},
                             device_ids=device_ids,
                             ddp_sync_bn=ddp_sync_bn,
                             ddp_replace_sampler=ddp_replace_sampler,
                             ddp_val=ddp_val,
                             ddp_test=ddp_test)
        self.lr = lr
        self.weight_decay = weight_decay

        # An explicit argument wins; otherwise fall back to the model's own
        # `num_classes` attribute, and fail loudly if neither is available.
        if num_classes is not None:
            self.num_classes = num_classes
        elif hasattr(model, 'num_classes'):
            self.num_classes = model.num_classes
        else:
            raise ValueError('The number of classes is not specified.')

        self.optimizer = torch.optim.Adam(model.parameters(),
                                          lr=lr,
                                          weight_decay=weight_decay)
        self.loss_fn = nn.CrossEntropyLoss()

        # init metric: one (mean loss, top-1 accuracy) pair per phase
        self.train_loss = torchmetrics.MeanMetric().to(self.device)
        self.train_accuracy = torchmetrics.Accuracy(
            task='multiclass', num_classes=self.num_classes,
            top_k=1).to(self.device)

        self.val_loss = torchmetrics.MeanMetric().to(self.device)
        self.val_accuracy = torchmetrics.Accuracy(
            task='multiclass', num_classes=self.num_classes,
            top_k=1).to(self.device)

        self.test_loss = torchmetrics.MeanMetric().to(self.device)
        self.test_accuracy = torchmetrics.Accuracy(
            task='multiclass', num_classes=self.num_classes,
            top_k=1).to(self.device)

    def before_training_epoch(self, epoch_id: int, num_epochs: int, **kwargs):
        r'''
        Log a header line announcing the epoch that is about to start.

        Args:
            epoch_id (int): Index of the epoch, printed in the header.
            num_epochs (int): Total number of epochs (unused here).
        '''
        self.log(f"Epoch {epoch_id}\n-------------------------------")

    def on_training_step(self, train_batch: Tuple, batch_id: int,
                         num_batches: int, **kwargs):
        r'''
        Run one optimization step on a batch and log progress about five times per epoch.

        The logged loss and accuracy describe the current batch only, because both training metrics are reset at the start of every step.

        Args:
            train_batch (Tuple): The batch; element 0 is fed to the model and element 1 is used as the target.
            batch_id (int): Index of the batch within the epoch.
            num_batches (int): Number of batches in the epoch.
        '''
        self.train_accuracy.reset()
        self.train_loss.reset()

        X = train_batch[0].to(self.device)
        y = train_batch[1].to(self.device)

        # compute prediction error
        pred = self.modules['model'](X)
        loss = self.loss_fn(pred, y)

        # backpropagation
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # log five times; clamp to 1 so the modulo below can never divide by
        # zero if num_batches is reported as 0
        log_step = max(1, math.ceil(num_batches / 5))
        if batch_id % log_step == 0:
            self.train_loss.update(loss)
            self.train_accuracy.update(pred.argmax(1), y)

            train_loss = self.train_loss.compute()
            train_accuracy = 100 * self.train_accuracy.compute()

            # if not distributed, world_size is 1
            batch_id = batch_id * self.world_size
            num_batches = num_batches * self.world_size
            if self.is_main:
                self.log(
                    f"loss: {train_loss:>8f}, accuracy: {train_accuracy:>0.1f}% [{batch_id:>5d}/{num_batches:>5d}]"
                )

    def before_validation_epoch(self, epoch_id: int, num_epochs: int,
                                **kwargs):
        r'''
        Clear the validation metrics so the coming round starts from scratch.

        Args:
            epoch_id (int): Index of the epoch (unused here).
            num_epochs (int): Total number of epochs (unused here).
        '''
        self.val_accuracy.reset()
        self.val_loss.reset()

    def on_validation_step(self, val_batch: Tuple, batch_id: int,
                           num_batches: int, **kwargs):
        r'''
        Evaluate the model on one validation batch and accumulate loss and accuracy.

        Args:
            val_batch (Tuple): The batch; element 0 is fed to the model and element 1 is used as the target.
            batch_id (int): Index of the batch (unused here).
            num_batches (int): Number of batches (unused here).
        '''
        X = val_batch[0].to(self.device)
        y = val_batch[1].to(self.device)

        pred = self.modules['model'](X)

        self.val_loss.update(self.loss_fn(pred, y))
        self.val_accuracy.update(pred.argmax(1), y)

    def after_validation_epoch(self, epoch_id: int, num_epochs: int,
                               **kwargs):
        r'''
        Log the loss and accuracy (as a percentage) accumulated over the validation round.

        Args:
            epoch_id (int): Index of the epoch (unused here).
            num_epochs (int): Total number of epochs (unused here).
        '''
        val_accuracy = 100 * self.val_accuracy.compute()
        val_loss = self.val_loss.compute()
        self.log(f"\nloss: {val_loss:>8f}, accuracy: {val_accuracy:>0.1f}%")

    def before_test_epoch(self, **kwargs):
        r'''
        Clear the test metrics so the coming round starts from scratch.
        '''
        self.test_loss.reset()
        self.test_accuracy.reset()

    def on_test_step(self, test_batch: Tuple, batch_id: int, num_batches: int,
                     **kwargs):
        r'''
        Evaluate the model on one test batch and accumulate loss and accuracy.

        Args:
            test_batch (Tuple): The batch; element 0 is fed to the model and element 1 is used as the target.
            batch_id (int): Index of the batch (unused here).
            num_batches (int): Number of batches (unused here).
        '''
        X = test_batch[0].to(self.device)
        y = test_batch[1].to(self.device)

        pred = self.modules['model'](X)

        self.test_loss.update(self.loss_fn(pred, y))
        self.test_accuracy.update(pred.argmax(1), y)

    def after_test_epoch(self, **kwargs):
        r'''
        Log the loss and accuracy (as a percentage) accumulated over the test round.
        '''
        test_accuracy = 100 * self.test_accuracy.compute()
        test_loss = self.test_loss.compute()
        self.log(f"\nloss: {test_loss:>8f}, accuracy: {test_accuracy:>0.1f}%")

    def test(self, test_loader: DataLoader, **kwargs):
        r'''
        Args:
            test_loader (DataLoader): Iterable DataLoader for traversing the test data batch (torch.utils.data.dataloader.DataLoader, torch_geometric.loader.DataLoader, etc).
        '''
        super().test(test_loader=test_loader, **kwargs)

    def fit(self,
            train_loader: DataLoader,
            val_loader: DataLoader,
            num_epochs: int = 1,
            **kwargs):
        r'''
        Args:
            train_loader (DataLoader): Iterable DataLoader for traversing the training data batch (torch.utils.data.dataloader.DataLoader, torch_geometric.loader.DataLoader, etc).
            val_loader (DataLoader): Iterable DataLoader for traversing the validation data batch (torch.utils.data.dataloader.DataLoader, torch_geometric.loader.DataLoader, etc).
            num_epochs (int): training epochs. (default: :obj:`1`)
        '''
        super().fit(train_loader=train_loader,
                    val_loader=val_loader,
                    num_epochs=num_epochs,
                    **kwargs)

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources