RegressorTrainer
- class torcheeg.trainers.RegressorTrainer(model: Module, lr: float = 0.001, weight_decay: float = 0.0, devices: int = 1, accelerator: str = 'cpu', metrics: List[str] = ['mse'])
A generic trainer class for EEG regression.
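    from torcheeg.trainers import RegressorTrainer

    # model, train_loader, val_loader and test_loader are assumed to be defined already
    trainer = RegressorTrainer(model)
    trainer.fit(train_loader, val_loader)
    trainer.test(test_loader)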
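To make the roles of `fit`, `test`, `lr` and `weight_decay` concrete, here is a toy, pure-Python stand-in. It is a sketch under stated assumptions, not torcheeg's implementation: `ToyRegressorTrainer`, its `max_epochs` argument and the list-of-batches "loaders" are hypothetical. It fits a one-variable linear model y = w * x + b with plain gradient descent (swapped in for whatever optimizer torcheeg actually uses) and applies weight decay in the coupled form used by PyTorch optimizers such as `torch.optim.SGD`, where `weight_decay * param` is added to each gradient.

```python
from typing import List, Tuple

# Hypothetical data layout: each "loader" is a list of batches,
# and each batch is a list of (feature, target) pairs.
Batch = List[Tuple[float, float]]


class ToyRegressorTrainer:
    """Hypothetical toy trainer fitting y = w * x + b with plain gradient descent.

    Not torcheeg's implementation; it only mirrors the fit/test workflow
    and the roles of lr and weight_decay.
    """

    def __init__(self, lr: float = 0.001, weight_decay: float = 0.0) -> None:
        self.lr = lr
        self.weight_decay = weight_decay
        self.w = 0.0
        self.b = 0.0

    def _mse(self, loader: List[Batch]) -> float:
        # Mean squared error over every sample in every batch
        errors = [(self.w * x + self.b - y) ** 2 for batch in loader for x, y in batch]
        return sum(errors) / len(errors)

    def fit(self, train_loader: List[Batch], val_loader: List[Batch],
            max_epochs: int = 100) -> List[float]:
        val_history = []
        for _ in range(max_epochs):
            for batch in train_loader:
                n = len(batch)
                # Gradients of the batch MSE loss with respect to w and b
                grad_w = sum(2 * (self.w * x + self.b - y) * x for x, y in batch) / n
                grad_b = sum(2 * (self.w * x + self.b - y) for x, y in batch) / n
                # Coupled weight decay: add weight_decay * param to each gradient,
                # which pulls the parameters toward zero (an L2 penalty)
                grad_w += self.weight_decay * self.w
                grad_b += self.weight_decay * self.b
                # Gradient step scaled by the learning rate
                self.w -= self.lr * grad_w
                self.b -= self.lr * grad_b
            # Validation pass at the end of each epoch
            val_history.append(self._mse(val_loader))
        return val_history

    def test(self, test_loader: List[Batch]) -> float:
        # Evaluate the trained parameters on held-out data
        return self._mse(test_loader)


def line(xs: List[float]) -> Batch:
    # Noiseless targets from y = 2x + 1
    return [(x, 2 * x + 1) for x in xs]


train_loader = [line([0.0, 0.5, 1.0]), line([0.25, 0.75])]
val_loader = [line([0.4, 0.6])]
test_loader = [line([0.1, 0.9])]

trainer = ToyRegressorTrainer(lr=0.3)
trainer.fit(train_loader, val_loader, max_epochs=500)
print(trainer.test(test_loader))  # prints a value very close to 0.0
```

Raising `weight_decay` in this toy pulls w and b toward zero, so on this noiseless data the test error grows; on noisy real data a small amount of decay can instead reduce overfitting.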
- Parameters:
  - model (nn.Module) – The regression model that outputs continuous values. The dimension of its output should match the number of target variables to predict.
  - lr (float) – The learning rate. (default: 0.001)
  - weight_decay (float) – The weight decay. (default: 0.0)
  - devices (int) – The number of devices to use. (default: 1)
  - accelerator (str) – The accelerator to use. Available options are: 'cpu', 'gpu'. (default: 'cpu')
  - metrics (list of str) – The metrics to use. Available options are: 'mse' (Mean Squared Error), 'mae' (Mean Absolute Error), 'rmse' (Root Mean Squared Error), 'r2' (R-squared score). (default: ['mse'])
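The four metric names correspond to standard regression formulas. The sketch below computes them for a single target variable using only the standard library; `regression_metrics` is a hypothetical helper, not part of torcheeg, and torcheeg's own computation may differ in details such as how scores are averaged over multiple target variables.

```python
import math
from typing import Dict, Sequence


def regression_metrics(y_true: Sequence[float], y_pred: Sequence[float],
                       metrics: Sequence[str] = ("mse",)) -> Dict[str, float]:
    """Hypothetical helper: compute the named metrics for a single target variable."""
    if len(y_true) != len(y_pred) or not y_true:
        raise ValueError("y_true and y_pred must be non-empty and of equal length")
    n = len(y_true)
    residuals = [p - t for t, p in zip(y_true, y_pred)]
    ss_res = sum(r * r for r in residuals)          # residual sum of squares
    mean_t = sum(y_true) / n
    ss_tot = sum((t - mean_t) ** 2 for t in y_true)  # total sum of squares
    mse = ss_res / n

    available = {
        # Mean Squared Error: average squared residual
        "mse": lambda: mse,
        # Mean Absolute Error: average absolute residual
        "mae": lambda: sum(abs(r) for r in residuals) / n,
        # Root Mean Squared Error: square root of MSE, in the target's units
        "rmse": lambda: math.sqrt(mse),
        # R-squared: 1 - SS_res / SS_tot; undefined (NaN here) for constant targets
        "r2": lambda: 1.0 - ss_res / ss_tot if ss_tot > 0 else float("nan"),
    }
    unknown = [m for m in metrics if m not in available]
    if unknown:
        raise ValueError(f"unknown metric(s): {unknown}")
    return {m: available[m]() for m in metrics}


scores = regression_metrics([3.0, -0.5, 2.0, 7.0], [2.5, 0.0, 2.0, 8.0],
                            metrics=["mse", "mae", "rmse", "r2"])
print(scores)
```

For these inputs the residuals are -0.5, 0.5, 0 and 1, giving an MSE of 0.375, an MAE of 0.5, an RMSE of about 0.612 and an R-squared of about 0.949.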