
RegressorTrainer

class torcheeg.trainers.RegressorTrainer(model: Module, lr: float = 0.001, weight_decay: float = 0.0, devices: int = 1, accelerator: str = 'cpu', metrics: List[str] = ['mse'])

A generic trainer class for EEG regression.

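from torcheeg.trainers import RegressorTrainer

# model, train_loader, val_loader and test_loader are assumed to be defined already
trainer = RegressorTrainer(model)
trainer.fit(train_loader, val_loader)
trainer.test(test_loader)
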
Parameters:
  • model (nn.Module) – The regression model that outputs continuous values. The dimension of its output should match the number of target variables to predict.

  • lr (float) – The learning rate. (default: 0.001)

  • weight_decay (float) – The weight decay. (default: 0.0)

  • devices (int) – The number of devices to use. (default: 1)

  • accelerator (str) – The accelerator to use. Available options are: ‘cpu’, ‘gpu’. (default: ‘cpu’)

  • metrics (list of str) – The metrics to use. Available options are: ‘mse’ (Mean Squared Error), ‘mae’ (Mean Absolute Error), ‘rmse’ (Root Mean Squared Error), ‘r2’ (R-squared score). (default: ["mse"])
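
The four metric names refer to the standard regression definitions. The plain-Python sketch below shows those definitions for a single target variable. It is for reference only and is not torcheeg's implementation; in particular, how the library aggregates scores across several target variables is not stated here and is not assumed.

```python
import math
from typing import Sequence


def mse(y_true: Sequence[float], y_pred: Sequence[float]) -> float:
    """Mean Squared Error: the average of the squared residuals."""
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)


def mae(y_true: Sequence[float], y_pred: Sequence[float]) -> float:
    """Mean Absolute Error: the average of the absolute residuals."""
    return sum(abs(t - p) for t, p in zip(y_true, y_pred)) / len(y_true)


def rmse(y_true: Sequence[float], y_pred: Sequence[float]) -> float:
    """Root Mean Squared Error: the square root of the MSE, in the target's units."""
    return math.sqrt(mse(y_true, y_pred))


def r2(y_true: Sequence[float], y_pred: Sequence[float]) -> float:
    """R-squared: 1 - (residual sum of squares) / (total sum of squares)."""
    mean = sum(y_true) / len(y_true)
    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
    ss_tot = sum((t - mean) ** 2 for t in y_true)
    return 1.0 - ss_res / ss_tot


# Toy values, not EEG data.
y_true = [3.0, -0.5, 2.0, 7.0]
y_pred = [2.5, 0.0, 2.0, 8.0]
print(mse(y_true, y_pred))            # → 0.375
print(mae(y_true, y_pred))            # → 0.5
print(round(r2(y_true, y_pred), 4))   # → 0.9486
```

MSE and RMSE penalize large errors more heavily than MAE. R² is unitless, and it falls below 0 when the predictions are worse than always predicting the mean target.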

fit(train_loader: DataLoader, val_loader: DataLoader, max_epochs: int = 300, *args, **kwargs) → Any

test(test_loader: DataLoader, *args, **kwargs) → List[Dict[str, float]]
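
The fit and test methods expect a model and DataLoaders that already exist. The sketch below builds toy stand-ins with plain PyTorch. It makes several assumptions. Random tensors replace a real EEG dataset. The shapes (32 channels, 128 samples per trial, 2 continuous targets) are hypothetical. Each batch is assumed to be an (input, target) pair, which is what TensorDataset yields. The main point is the requirement stated for the model parameter: the model's output must have one value per target variable, so that its shape matches the target batch.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical shapes chosen for illustration only.
NUM_CHANNELS, NUM_SAMPLES, NUM_TARGETS = 32, 128, 2


def make_loader(num_trials: int, shuffle: bool) -> DataLoader:
    # Random tensors stand in for EEG trials and their continuous labels.
    eeg = torch.randn(num_trials, NUM_CHANNELS, NUM_SAMPLES)
    targets = torch.randn(num_trials, NUM_TARGETS)
    return DataLoader(TensorDataset(eeg, targets), batch_size=16, shuffle=shuffle)


train_loader = make_loader(160, shuffle=True)
val_loader = make_loader(32, shuffle=False)
test_loader = make_loader(32, shuffle=False)

# A deliberately tiny regressor whose last layer emits NUM_TARGETS values.
model = nn.Sequential(
    nn.Flatten(),                                # (batch, 32, 128) -> (batch, 4096)
    nn.Linear(NUM_CHANNELS * NUM_SAMPLES, 64),
    nn.ReLU(),
    nn.Linear(64, NUM_TARGETS),                  # (batch, 64) -> (batch, 2)
)

x, y = next(iter(val_loader))
print(model(x).shape, y.shape)  # both torch.Size([16, 2])

# These objects could then be passed to RegressorTrainer(model) and to
# trainer.fit(train_loader, val_loader) / trainer.test(test_loader).
```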
