FocalLossTrainer

class torcheeg.trainers.FocalLossTrainer(model: Module, num_classes: int, class_frequency: List[int] | DataLoader, gamma: float = 0.5, rule: str = 'reweight', beta_reweight: float = 0.9999, drw_epochs: int = 160, lr: float = 0.001, weight_decay: float = 0.0, devices: int = 1, accelerator: str = 'cpu', metrics: List[str] = ['accuracy'])

A trainer class for EEG classification that uses the focal loss to cope with imbalanced datasets.
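The focal loss, introduced by Lin et al. for dense object detection, scales the cross-entropy term of each sample by (1 - p_t)^gamma, where p_t is the predicted probability of the true class, so samples the model already classifies confidently contribute less to the loss. The snippet below is a minimal, framework-free sketch of that per-sample formula in plain Python. It is not TorchEEG's implementation, and the helper name `focal_loss` and its `class_weights` argument are hypothetical.

```python
import math
from typing import List, Optional


def focal_loss(logits: List[float], target: int, gamma: float = 0.5,
               class_weights: Optional[List[float]] = None) -> float:
    """Per-sample focal loss: -w_t * (1 - p_t) ** gamma * log(p_t)."""
    # Log-softmax computed stably by subtracting the largest logit first.
    m = max(logits)
    log_sum = math.log(sum(math.exp(z - m) for z in logits))
    log_p_t = (logits[target] - m) - log_sum
    p_t = math.exp(log_p_t)
    # Optional per-class weight for the true class (1.0 when no weights are given).
    w_t = 1.0 if class_weights is None else class_weights[target]
    # (1 - p_t) ** gamma shrinks the loss of samples that are already well classified.
    return -w_t * (1.0 - p_t) ** gamma * log_p_t


# Hypothetical two-class logits for one sample whose true class is 0.
print(focal_loss([2.0, 0.0], target=0, gamma=0.0))  # plain cross-entropy
print(focal_loss([2.0, 0.0], target=0, gamma=2.0))  # smaller, since the sample is easy
```

With gamma=0.0 and no class weights the value equals the ordinary cross-entropy; increasing gamma pushes the loss of easy samples toward zero.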

from torcheeg.models import CCNN
from torcheeg.trainers import FocalLossTrainer

model = CCNN(in_channels=5, num_classes=2)
trainer = FocalLossTrainer(model, num_classes=2, class_frequency=[10, 20], gamma=1.0)
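The example above passes class_frequency as a list of per-class sample counts. When the counts are not known in advance, one way to obtain such a list is to count the labels in each batch, as in the hypothetical helper `count_class_frequency` below (not part of TorchEEG). It assumes every batch unpacks into an (inputs, labels) pair; loaders that yield other structures would need a different way to reach the labels. Since the trainer also accepts a DataLoader for class_frequency directly, this is only useful when you want to inspect or reuse the counts yourself.

```python
from collections import Counter
from typing import Iterable, List, Sequence, Tuple


def count_class_frequency(batches: Iterable[Tuple[object, Sequence[int]]],
                          num_classes: int) -> List[int]:
    """Return the number of samples per class, in class-index order."""
    counts: Counter = Counter()
    for _, labels in batches:
        # int() accepts Python ints as well as 0-d tensors produced by iterating a label tensor.
        counts.update(int(y) for y in labels)
    # Counter returns 0 for classes that never appear.
    return [counts[c] for c in range(num_classes)]


# Toy stand-in for a DataLoader: two batches of (inputs, labels).
toy_batches = [(None, [0, 1, 1]), (None, [1, 0, 1])]
print(count_class_frequency(toy_batches, num_classes=2))  # [2, 4]
```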
Parameters:
  • model (nn.Module) – The classification model. The dimension of its output should equal the number of classes in the dataset, and the output layer does not need a softmax activation.

  • num_classes (int) – The number of classes in the dataset.

  • class_frequency (List[int] or DataLoader) – The frequency (sample count) of each class in the dataset. It can be given either as a list of integers or as a dataloader, in which case the frequency of each class is computed by traversing the data batches (torch.utils.data.dataloader.DataLoader, torch_geometric.loader.DataLoader, etc.).

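  • gamma (float) – The focusing parameter of the focal loss. Larger values down-weight the loss of well-classified samples more strongly, and a value of 0 recovers the standard cross-entropy loss. (default: 0.5)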

  • rule (str) – The rule used to adjust the weight of each class. Available options are: ‘none’, ‘reweight’, ‘drw’ (deferred re-balancing optimization schedule). (default: ‘reweight’)

  • beta_reweight (float) – The beta parameter for reweighting. It is only used when rule is ‘reweight’ or ‘drw’. (default: 0.9999)

  • drw_epochs (int) – The number of epochs to use DRW. It is only used when rule is ‘drw’. (default: 160)

  • lr (float) – The learning rate. (default: 0.001)

  • weight_decay (float) – The weight decay. (default: 0.0)

  • devices (int) – The number of devices to use. (default: 1)

  • accelerator (str) – The accelerator to use. Available options are: ‘cpu’, ‘gpu’. (default: ‘cpu’)

  • metrics (List[str]) – The metrics to compute. (default: [‘accuracy’])
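The rule, beta_reweight and drw_epochs parameters together determine the per-class weights. The sketch below shows one common way to derive such weights, under two assumptions about what TorchEEG does internally: that ‘reweight’ follows the "effective number of samples" scheme of Cui et al. (class-balanced loss), in which a class with n samples gets a weight proportional to (1 - beta) / (1 - beta^n), and that ‘drw’ follows the LDAM-DRW reference code of Cao et al., which keeps uniform weights for the first drw_epochs epochs and switches to class-balanced weights afterwards. The function names `class_balanced_weights` and `class_weights_for_epoch` are hypothetical.

```python
from typing import List


def class_balanced_weights(class_frequency: List[int], beta: float) -> List[float]:
    """Weights proportional to (1 - beta) / (1 - beta ** n), normalised to sum to the class count."""
    # Every class must have at least one sample, otherwise 1 - beta ** 0 == 0.
    raw = [(1.0 - beta) / (1.0 - beta ** n) for n in class_frequency]
    scale = len(class_frequency) / sum(raw)
    return [w * scale for w in raw]


def class_weights_for_epoch(class_frequency: List[int], rule: str, epoch: int,
                            beta_reweight: float = 0.9999,
                            drw_epochs: int = 160) -> List[float]:
    """Per-class weights for a given epoch under the 'none', 'reweight' or 'drw' rule."""
    if rule == "none":
        return [1.0] * len(class_frequency)
    if rule == "reweight":
        return class_balanced_weights(class_frequency, beta_reweight)
    if rule == "drw":
        # Deferred re-weighting (assumed schedule): beta = 0 gives uniform weights
        # until drw_epochs, after which class-balanced weights are used.
        beta = 0.0 if epoch < drw_epochs else beta_reweight
        return class_balanced_weights(class_frequency, beta)
    raise ValueError(f"Unknown rule: {rule!r}")


print(class_weights_for_epoch([10, 20], "reweight", epoch=0))  # minority class weighted higher
print(class_weights_for_epoch([10, 20], "drw", epoch=0))       # [1.0, 1.0]
```

Normalising the weights to sum to the number of classes keeps the overall loss scale comparable to the unweighted case, so the learning rate does not need retuning when the rule changes.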
