
ClassifierTrainer

class torcheeg.trainers.ClassifierTrainer(model: Module, num_classes: int, lr: float = 0.001, weight_decay: float = 0.0, devices: int = 1, accelerator: str = 'cpu', verbose: bool = True, metrics: List[str] = ['accuracy'])

A generic trainer class for EEG classification.

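from torcheeg.trainers import ClassifierTrainer

# model, train_loader, val_loader and test_loader are assumed to be defined already
trainer = ClassifierTrainer(model, num_classes=2)  # num_classes is required by the signature; 2 is illustrative
trainer.fit(train_loader, val_loader)
trainer.test(test_loader)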
Parameters:
  • model (nn.Module) – The classification model. The dimension of its output should equal the number of categories in the dataset. The output layer does not need a softmax activation function.

  • num_classes (int, optional) – The number of categories in the dataset. If None, the number of categories is inferred from the num_classes attribute of the model. (default: None)

  • lr (float) – The learning rate. (default: 0.001)

  • weight_decay (float) – The weight decay. (default: 0.0)

  • devices (int) – The number of devices to use. (default: 1)

  • accelerator (str) – The accelerator to use. Available options are ‘cpu’ and ‘gpu’. (default: ‘cpu’)

  • metrics (list of str) – The metrics to compute. Available options are ‘precision’, ‘recall’, ‘f1score’, ‘accuracy’, ‘matthews’, ‘auroc’, and ‘kappa’. (default: [‘accuracy’])
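The model description notes that the output layer does not need a softmax. A common reason, assuming the trainer computes a cross-entropy loss on raw logits (as PyTorch's nn.CrossEntropyLoss does, which expects unnormalized scores), is that the loss already applies the softmax internally. The framework-free sketch below is plain Python, not torcheeg code; its helper names are hypothetical. It shows that cross-entropy computed directly from logits equals the negative log of the softmax probability, and that passing already-softmaxed outputs through the same loss gives a different value.

```python
import math
from typing import List


def softmax(logits: List[float]) -> List[float]:
    # Subtract the largest logit before exponentiating for numerical stability
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy_from_logits(logits: List[float], target: int) -> float:
    # -log softmax(z)[target] == logsumexp(z) - z[target]; the softmax is implicit
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]


def cross_entropy_from_probs(probs: List[float], target: int) -> float:
    # Negative log-likelihood of the target class given explicit probabilities
    return -math.log(probs[target])


logits = [2.0, 0.5, -1.0]  # raw scores from a hypothetical 3-class model
loss_from_logits = cross_entropy_from_logits(logits, target=0)
loss_via_softmax = cross_entropy_from_probs(softmax(logits), target=0)
loss_double_softmax = cross_entropy_from_logits(softmax(logits), target=0)
print(abs(loss_from_logits - loss_via_softmax) < 1e-12)  # True
print(loss_double_softmax > loss_from_logits)            # True: softmax applied twice
```

In this sketch, adding a softmax layer to the model would mean the loss squashes the scores a second time, which changes the loss value and flattens its gradients.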

fit(train_loader: DataLoader, val_loader: DataLoader, max_epochs: int = 300, *args, **kwargs) → Any
Parameters:
  • train_loader (DataLoader) – An iterable DataLoader that yields batches of training data (e.g., torch.utils.data.dataloader.DataLoader or torch_geometric.loader.DataLoader).

  • val_loader (DataLoader) – An iterable DataLoader that yields batches of validation data (e.g., torch.utils.data.dataloader.DataLoader or torch_geometric.loader.DataLoader).

  • max_epochs (int) – Maximum number of epochs to train the model. (default: 300)
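To make the role of max_epochs concrete, here is a framework-free sketch of the kind of epoch loop a fit method like this typically runs: each epoch makes one full pass over the training batches and then evaluates on every validation batch. The function fit_sketch and its train_step/val_step callables are hypothetical and are not torcheeg's implementation; unlike a "maximum" epoch count, this sketch always runs all max_epochs epochs.

```python
from typing import Any, Callable, Iterable, List


def fit_sketch(
    train_batches: Iterable,
    val_batches: Iterable,
    train_step: Callable[[Any], float],
    val_step: Callable[[Any], float],
    max_epochs: int = 300,
) -> List[float]:
    """Hypothetical training loop; returns the mean validation loss per epoch.

    Both batch iterables must be re-iterable (a DataLoader or a list is; a
    generator is not), because they are traversed once per epoch.
    """
    val_history: List[float] = []
    for _epoch in range(max_epochs):
        # One full pass over the training batches
        for batch in train_batches:
            train_step(batch)
        # Evaluate on every validation batch after the training pass
        val_losses = [val_step(batch) for batch in val_batches]
        mean_loss = sum(val_losses) / len(val_losses) if val_losses else float("nan")
        val_history.append(mean_loss)
    return val_history


seen: List[int] = []


def record_step(batch: int) -> float:
    seen.append(batch)  # remember which batches were trained on
    return 0.0


# Toy usage: three "training batches", two "validation batches", two epochs
history = fit_sketch([1, 2, 3], [4, 5], record_step, lambda b: float(b), max_epochs=2)
print(history)  # [4.5, 4.5]
print(seen)     # [1, 2, 3, 1, 2, 3]
```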

test(test_loader: DataLoader, *args, **kwargs) → List[Dict[str, float]]
Parameters:
  • test_loader (DataLoader) – An iterable DataLoader that yields batches of test data (e.g., torch.utils.data.dataloader.DataLoader or torch_geometric.loader.DataLoader).
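test returns a list of dictionaries mapping metric names to float values. Presumably these are the metrics chosen through the metrics constructor argument; that link, and the exact key names, are assumptions, since this page does not list them. For reference, the framework-free sketch below computes two of the documented options, ‘accuracy’ and ‘kappa’ (taken here to mean Cohen's kappa), from integer labels. It is not torcheeg's implementation, which may differ in details such as multi-class averaging.

```python
from collections import Counter
from typing import Sequence


def accuracy(y_true: Sequence[int], y_pred: Sequence[int]) -> float:
    # Fraction of samples whose predicted class matches the label
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


def cohen_kappa(y_true: Sequence[int], y_pred: Sequence[int]) -> float:
    # (p_o - p_e) / (1 - p_e): observed agreement corrected for the agreement
    # expected by chance from each side's class frequencies
    n = len(y_true)
    p_o = accuracy(y_true, y_pred)
    true_counts, pred_counts = Counter(y_true), Counter(y_pred)
    p_e = sum(true_counts[c] * pred_counts[c] for c in true_counts) / (n * n)
    if p_e == 1.0:
        return float("nan")  # undefined (0/0) when both sides use one identical class
    return (p_o - p_e) / (1.0 - p_e)


# Toy labels for a 3-class problem
y_true = [0, 0, 1, 1, 2, 2]
y_pred = [0, 1, 1, 1, 2, 0]
print(accuracy(y_true, y_pred))     # 4 of 6 correct, about 0.667
print(cohen_kappa(y_true, y_pred))  # about 0.5: observed 2/3 vs chance 1/3
```

Kappa is lower than accuracy here because part of the observed agreement would be expected even from predictions that ignore the input.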
