LDAMLossTrainer
- class torcheeg.trainers.LDAMLossTrainer(model: Module, num_classes: int, class_frequency: List[int] | DataLoader, max_margin: float = 0.5, scaling: float = 30, rule: str = 'none', beta_reweight: float = 0.9999, drw_epochs: int = 160, lr: float = 0.001, weight_decay: float = 0.0, devices: int = 1, accelerator: str = 'cpu', metrics: List[str] = ['accuracy'])
A trainer class for EEG classification that uses the label-distribution-aware margin (LDAM) loss to handle imbalanced datasets.
Paper: Cao K, Wei C, Gaidon A, et al. Learning imbalanced datasets with label-distribution-aware margin loss[J]. Advances in neural information processing systems, 2019, 32.
Related Project: https://github.com/kaidic/LDAM-DRW
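from torcheeg.models import CCNN
from torcheeg.trainers import LDAMLossTrainer

# CCNN with 5 input channels and a 2-class output
model = CCNN(in_channels=5, num_classes=2)
# class_frequency: 10 samples of class 0 and 20 samples of class 1
trainer = LDAMLossTrainer(model,
                          num_classes=2,
                          class_frequency=[10, 20],
                          max_margin=0.5,
                          scaling=30)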
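In the LDAM paper cited above, a class j with n_j training samples receives a margin proportional to n_j^(-1/4), so rarer classes get larger margins; the margin is subtracted from the true-class logit, the logits are multiplied by a scale factor, and softmax cross-entropy is applied. The following is a minimal pure-Python sketch of that loss for a single sample, assuming it matches the LDAM-DRW reference project linked above, where max_margin sets the largest (rarest-class) margin and scaling is the logit scale. It is not torcheeg's implementation; the helper names ldam_margins and ldam_loss are hypothetical.

```python
import math
from typing import List


def ldam_margins(class_frequency: List[int], max_margin: float = 0.5) -> List[float]:
    """Hypothetical helper: per-class margins proportional to n_j ** (-1/4),
    rescaled so that the largest margin (rarest class) equals max_margin."""
    raw = [n ** -0.25 for n in class_frequency]
    top = max(raw)
    return [max_margin * r / top for r in raw]


def ldam_loss(logits: List[float], target: int, margins: List[float],
              scaling: float = 30.0) -> float:
    """Hypothetical helper: LDAM loss for one sample. Subtract the target
    class's margin from its logit, scale all logits, then apply softmax
    cross-entropy."""
    adjusted = list(logits)
    adjusted[target] -= margins[target]
    z = [scaling * v for v in adjusted]
    # Log-sum-exp with the maximum subtracted for numerical stability
    zmax = max(z)
    log_sum = zmax + math.log(sum(math.exp(v - zmax) for v in z))
    return log_sum - z[target]


# Margins for the class_frequency=[10, 20] example above
margins = ldam_margins([10, 20], max_margin=0.5)
loss = ldam_loss([1.0, 0.0], target=0, margins=margins, scaling=30.0)
```

Because the margin is subtracted only from the true-class logit, the model must separate that class by a wider gap before the loss becomes small, and the gap demanded is largest for the least frequent class.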
- Parameters:
model (nn.Module) – The classification model, and the dimension of its output should be equal to the number of categories in the dataset. The output layer does not need to have a softmax activation function.
num_classes (int) – The number of classes in the dataset.
class_frequency (List[int] or DataLoader) – The number of samples in each class. It can be a list of integers, or a data loader (torch.utils.data.dataloader.DataLoader, torch_geometric.loader.DataLoader, etc.) whose batches are iterated over to count the frequency of each class.
max_margin (float) – The maximum margin, assigned to the rarest class; the margin of every other class is proportional to n_j^(-1/4), where n_j is that class's frequency. (default: 0.5)
scaling (float) – The scaling factor applied to the margin-adjusted logits before the cross-entropy is computed. (default: 30)
rule (str) – The rule for adjusting the weight of each class. Available options are: ‘none’, ‘reweight’, and ‘drw’ (deferred re-balancing optimization schedule). (default: 'none')
beta_reweight (float) – The beta parameter of the re-weighting. Only used when rule is ‘reweight’ or ‘drw’. (default: 0.9999)
drw_epochs (int) – The number of epochs trained before DRW starts re-weighting the classes. Only used when rule is ‘drw’. (default: 160)
lr (float) – The learning rate. (default: 0.001)
weight_decay (float) – The weight decay. (default: 0.0)
devices (int) – The number of devices to use. (default: 1)
accelerator (str) – The accelerator to use. Available options are: ‘cpu’, ‘gpu’. (default: "cpu")
metrics (List[str]) – The metrics to evaluate. (default: ['accuracy'])
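The interaction between class_frequency, rule, beta_reweight, and drw_epochs can be sketched in pure Python. This sketch assumes the class weights follow the "effective number of samples" formula, weight_j ∝ (1 - beta) / (1 - beta^n_j), rescaled to sum to the number of classes, and that DRW simply switches these weights on after drw_epochs epochs, as in the LDAM-DRW reference project linked above. Whether torcheeg normalizes the weights the same way is an assumption, and the helpers count_class_frequency, effective_number_weights, and class_weights are hypothetical.

```python
from collections import Counter
from typing import Iterable, List


def count_class_frequency(labels: Iterable[int], num_classes: int) -> List[int]:
    """Hypothetical helper: count samples per class, e.g. from labels
    collected while iterating over a DataLoader's batches."""
    counts = Counter(labels)
    return [counts.get(c, 0) for c in range(num_classes)]


def effective_number_weights(class_frequency: List[int], beta: float) -> List[float]:
    """Hypothetical helper: weights proportional to (1 - beta) / (1 - beta ** n_j),
    rescaled to sum to the number of classes. Assumes 0 <= beta < 1 and that
    every class has at least one sample."""
    raw = [(1.0 - beta) / (1.0 - beta ** n) for n in class_frequency]
    total = sum(raw)
    return [w * len(class_frequency) / total for w in raw]


def class_weights(class_frequency: List[int], rule: str, beta_reweight: float,
                  epoch: int, drw_epochs: int) -> List[float]:
    """Hypothetical helper: per-class loss weights for the current epoch."""
    uniform = [1.0] * len(class_frequency)
    if rule == 'none':
        return uniform
    if rule == 'reweight':
        return effective_number_weights(class_frequency, beta_reweight)
    if rule == 'drw':
        # Deferred re-weighting: train unweighted first, then switch on the
        # effective-number weights once drw_epochs epochs have elapsed.
        if epoch < drw_epochs:
            return uniform
        return effective_number_weights(class_frequency, beta_reweight)
    raise ValueError(f"unknown rule: {rule!r}")


freq = count_class_frequency([0, 1, 1, 0, 1, 1], num_classes=2)
early = class_weights(freq, 'drw', 0.9999, epoch=10, drw_epochs=160)
late = class_weights(freq, 'drw', 0.9999, epoch=170, drw_epochs=160)
```

With beta close to 1 the weights approach inverse class frequency, while beta = 0 makes them uniform. Rescaling so the weights sum to the number of classes keeps their mean at 1, so switching re-weighting on under DRW does not change the overall scale of the loss.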