ClassifierTrainer
- class torcheeg.trainers.ClassifierTrainer(model: Module, num_classes: int, lr: float = 0.001, weight_decay: float = 0.0, devices: int = 1, accelerator: str = 'cpu', metrics: List[str] = ['accuracy'])
A generic trainer class for EEG classification.
    trainer = ClassifierTrainer(model)
    trainer.fit(train_loader, val_loader)
    trainer.test(test_loader)
- Parameters:
  - model (nn.Module) – The classification model. Its output dimension must equal the number of categories in the dataset. The output layer does not need a softmax activation; the model can return raw, unnormalized scores (logits).
  - num_classes (int, optional) – The number of categories in the dataset. If None, the number of categories is inferred from the model's num_classes attribute. (default: None)
  - lr (float) – The learning rate. (default: 0.001)
  - weight_decay (float) – The weight decay. (default: 0.0)
  - devices (int) – The number of devices to use. (default: 1)
  - accelerator (str) – The accelerator to use. Available options: ‘cpu’, ‘gpu’. (default: "cpu")
  - metrics (list of str) – The metrics to report. Available options: ‘precision’, ‘recall’, ‘f1score’, ‘accuracy’, ‘matthews’, ‘auroc’, and ‘kappa’. (default: ["accuracy"])
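The model contract described above can be illustrated in plain Python. This sketch assumes, as is standard for PyTorch classifiers, that the trainer applies a cross-entropy loss such as torch.nn.CrossEntropyLoss, which takes unnormalized logits and applies log-softmax internally; that is why no softmax layer is needed. It also shows the num_classes fallback to a model attribute. The names resolve_num_classes, cross_entropy_from_logits, predict and ToyModel are hypothetical and are not torcheeg APIs.

```python
import math
from typing import Optional, Sequence

# Hypothetical helpers for illustration; not part of torcheeg.


def resolve_num_classes(model: object, num_classes: Optional[int] = None) -> int:
    """Prefer the explicit argument, else fall back to model.num_classes."""
    if num_classes is not None:
        return num_classes
    inferred = getattr(model, "num_classes", None)
    if inferred is None:
        raise ValueError("pass num_classes or give the model a num_classes attribute")
    return inferred


def cross_entropy_from_logits(logits: Sequence[float], target: int) -> float:
    """Per-sample cross-entropy computed directly from raw logits.

    -log softmax(logits)[target] = logsumexp(logits) - logits[target];
    subtracting the max first keeps exp() from overflowing.
    """
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]


def predict(logits: Sequence[float]) -> int:
    """Softmax is monotonic, so the argmax of the logits is the predicted class."""
    return max(range(len(logits)), key=lambda i: logits[i])


class ToyModel:
    num_classes = 3  # hypothetical attribute used when num_classes is None


model = ToyModel()
n = resolve_num_classes(model)   # falls back to model.num_classes
logits = [2.0, 0.5, -1.0]        # raw output of a final linear layer, length n
loss = cross_entropy_from_logits(logits, target=0)
label = predict(logits)
```

Adding a softmax layer in front of such a loss would normalize twice, which still trains but distorts the gradients; returning logits avoids that.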
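For intuition about two of the metric options, the sketch below computes ‘accuracy’ and ‘kappa’ (Cohen's kappa) from integer predictions in pure Python. It is not torcheeg's implementation; the function names are hypothetical, and the multiclass options such as precision, recall, F1 and AUROC involve averaging choices that are not shown here.

```python
from collections import Counter
from typing import Sequence

# Hypothetical helpers for intuition only; not torcheeg's metric implementations.


def accuracy(preds: Sequence[int], targets: Sequence[int]) -> float:
    """Fraction of samples whose predicted class equals the label."""
    return sum(p == t for p, t in zip(preds, targets)) / len(targets)


def cohen_kappa(preds: Sequence[int], targets: Sequence[int]) -> float:
    """Cohen's kappa: (p_o - p_e) / (1 - p_e), agreement corrected for chance.

    Undefined (division by zero) when p_e == 1, e.g. one class everywhere.
    """
    n = len(targets)
    p_o = accuracy(preds, targets)
    pred_counts, target_counts = Counter(preds), Counter(targets)
    # Agreement expected if predictions were independent of the labels.
    p_e = sum(pred_counts[c] * target_counts[c] for c in target_counts) / (n * n)
    return (p_o - p_e) / (1 - p_e)


targets = [0, 0, 1, 1, 1, 2]
preds = [0, 1, 1, 1, 0, 2]
acc = accuracy(preds, targets)       # 4 of 6 correct
kappa = cohen_kappa(preds, targets)  # 5/11, about 0.4545
```

Accuracy alone can look good on imbalanced EEG labels (for example, a model that always predicts the majority class); kappa discounts that chance agreement, which is one reason to request several metrics.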
- fit(train_loader: DataLoader, val_loader: DataLoader, max_epochs: int = 300, *args, **kwargs) → Any
- Parameters:
  - train_loader (DataLoader) – An iterable DataLoader that yields batches of training data (e.g., torch.utils.data.dataloader.DataLoader, torch_geometric.loader.DataLoader).
  - val_loader (DataLoader) – An iterable DataLoader that yields batches of validation data (e.g., torch.utils.data.dataloader.DataLoader, torch_geometric.loader.DataLoader).
  - max_epochs (int) – The maximum number of epochs to train the model. (default: 300)
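What fit's parameters control can be illustrated with a framework-free sketch: up to max_epochs full passes over train_loader, each followed by a pass over val_loader. This is not torcheeg's implementation; sketch_fit, train_step and the toy one-parameter model are hypothetical, and the real trainer also handles devices, the optimizer, metric logging and the extra *args/**kwargs, none of which are modeled here.

```python
from statistics import mean
from typing import Callable, Iterable, List, Sequence, Tuple

# Hypothetical batch layout: (inputs, labels), like the (x, y) pairs an EEG loader typically yields.
Batch = Tuple[Sequence[float], Sequence[float]]


def sketch_fit(
    train_step: Callable[[Batch], float],
    val_step: Callable[[Batch], float],
    train_loader: Iterable[Batch],
    val_loader: Iterable[Batch],
    max_epochs: int = 300,
) -> List[Tuple[float, float]]:
    """Hypothetical loop, not torcheeg's code: (mean train loss, mean val loss) per epoch."""
    history = []
    for _ in range(max_epochs):
        # Both loaders are traversed once per epoch, so they must be re-iterable
        # (a torch DataLoader is; a one-shot generator would be empty after epoch 1).
        train_losses = [train_step(batch) for batch in train_loader]
        val_losses = [val_step(batch) for batch in val_loader]
        history.append((mean(train_losses), mean(val_losses)))
    return history


# Toy model y_hat = w * x, trained with squared error and plain SGD.
state = {"w": 0.0}
LR = 0.1


def squared_error(batch: Batch) -> float:
    xs, ys = batch
    return mean((state["w"] * x - y) ** 2 for x, y in zip(xs, ys))


def train_step(batch: Batch) -> float:
    loss = squared_error(batch)  # loss measured before the update
    xs, ys = batch
    grad = mean(2 * (state["w"] * x - y) * x for x, y in zip(xs, ys))
    state["w"] -= LR * grad
    return loss


train_loader = [([1.0, 2.0], [2.0, 4.0]), ([3.0], [6.0])]  # lists are re-iterable
val_loader = [([4.0], [8.0])]
history = sketch_fit(train_step, squared_error, train_loader, val_loader, max_epochs=20)
```

Passing a generator instead of a list as train_loader would leave every epoch after the first with no batches, and statistics.mean would then raise StatisticsError on the empty list.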