torcheeg.losses

torcheeg.losses.FocalLoss

class torcheeg.losses.FocalLoss(alpha: float = 1.0, gamma: float = 2.0, reduction: str = 'mean', label_smooth: float = 0.05, num_classes: int = 2)[source]

Focal loss adds a modulating factor \((1 - p_t)^{\gamma}\) to the standard cross entropy criterion. Setting \(\gamma > 0\) reduces the relative loss for well-classified examples (\(p_t > 0.5\)), putting more focus on hard, misclassified examples. As experiments demonstrate, the focal loss enables training highly accurate models in the presence of vast numbers of easy negative examples.

import torch
from torcheeg.losses import FocalLoss

inputs = torch.randn(3, 5, requires_grad=True)
targets = torch.empty(3, dtype=torch.long).random_(5)
loss = FocalLoss(num_classes=5)
output = loss(inputs, targets)

This version further adds support for multi-class classification and label smoothing to the original implementation.
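As an illustration of both ideas, the following is a minimal sketch in plain PyTorch (not the torcheeg implementation; the helper name and the equal spreading of the smoothing mass are assumptions). It softens the one-hot targets and applies the \((1 - p_t)^{\gamma}\) modulating factor to the per-sample cross entropy:

import torch
import torch.nn.functional as F

def focal_loss_sketch(logits, targets, alpha=1.0, gamma=2.0,
                      label_smooth=0.05, num_classes=2):
    # soften the one-hot targets: scale the true class by 1 - label_smooth
    # and spread the smoothing mass evenly over all classes
    one_hot = F.one_hot(targets, num_classes).float()
    soft_targets = one_hot * (1 - label_smooth) + label_smooth / num_classes

    log_probs = F.log_softmax(logits, dim=-1)
    ce = -(soft_targets * log_probs).sum(dim=-1)              # per-sample cross entropy
    p_t = log_probs.exp().gather(1, targets.unsqueeze(1)).squeeze(1)
    focal = alpha * (1 - p_t) ** gamma * ce                   # easy samples (large p_t) are down-weighted
    return focal.mean()

inputs = torch.randn(3, 5, requires_grad=True)
targets = torch.empty(3, dtype=torch.long).random_(5)
output = focal_loss_sketch(inputs, targets, num_classes=5)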

Parameters
  • alpha (float) – The hyperparameter alpha in the focal loss. (default: 1.0)

  • gamma (float) – The hyperparameter gamma in the focal loss. (default: 2.0)

  • reduction (str) – Specifies the reduction to apply to the output. Options include 'none', 'mean' and 'sum'. 'none': no reduction will be applied; 'mean': the sum of the output will be divided by the number of elements in the output; 'sum': the output will be summed. (default: 'mean') See the short sketch after this list.

  • label_smooth (float) – The label smoothing factor used to soften the one-hot targets. (default: 0.05)

  • num_classes (int) – The number of classes to predict. (default: 2)
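
A short sketch of the three reduction modes, assuming they behave as described above:

import torch
from torcheeg.losses import FocalLoss

inputs = torch.randn(4, 5, requires_grad=True)
targets = torch.empty(4, dtype=torch.long).random_(5)

# 'none' keeps one loss value per sample; 'mean' averages them; 'sum' adds them up
per_sample = FocalLoss(num_classes=5, reduction='none')(inputs, targets)
averaged = FocalLoss(num_classes=5, reduction='mean')(inputs, targets)
summed = FocalLoss(num_classes=5, reduction='sum')(inputs, targets)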

forward(inputs: Tensor, targets: Tensor) → Tensor[source]
Parameters
  • inputs (torch.Tensor) – The predictions from the model.

  • targets (torch.Tensor) – The ground-truth labels.

Returns

The loss value.

Return type

torch.Tensor

torcheeg.losses.VirtualAdversarialTrainingLoss

class torcheeg.losses.VirtualAdversarialTrainingLoss(xi: float = 10.0, eps: float = 1.0, iterations: int = 1)[source]

Virtual adversarial training loss is a regularization method based on the virtual adversarial loss, defined as the robustness of the conditional label distribution around each input data point against local perturbation. Because the smoothing directions are only "virtually" adversarial, the loss does not require label information and is hence applicable to semi-supervised learning.

import torch
from torch import nn
from torcheeg.losses import VirtualAdversarialTrainingLoss

inputs = torch.randn(3, 5, requires_grad=True)
model = nn.Linear(5, 5)
loss = VirtualAdversarialTrainingLoss()
output = loss(model, inputs)

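For intuition, here is a minimal sketch of the general virtual adversarial training procedure in plain PyTorch (an illustration of the idea, not the torcheeg implementation; the helper name, the KL-divergence distance and the 2-D input assumption are mine). The adversarial direction is estimated by power iteration, scaled by eps, and the loss penalizes the resulting shift in the predictive distribution:

import torch
import torch.nn as nn
import torch.nn.functional as F

def vat_loss_sketch(model, x, xi=10.0, eps=1.0, iterations=1):
    with torch.no_grad():
        pred = F.softmax(model(x), dim=-1)                   # reference distribution, no labels needed

    # random unit-norm starting direction (assumes 2-D inputs)
    d = torch.randn_like(x)
    d = d / (d.norm(dim=1, keepdim=True) + 1e-8)

    for _ in range(iterations):                              # power iteration for the most sensitive direction
        d.requires_grad_()
        pred_hat = F.log_softmax(model(x + xi * d), dim=-1)
        adv_distance = F.kl_div(pred_hat, pred, reduction='batchmean')
        grad = torch.autograd.grad(adv_distance, d)[0]
        d = grad / (grad.norm(dim=1, keepdim=True) + 1e-8)
        d = d.detach()

    # penalize the distribution shift caused by the virtual adversarial perturbation
    pred_hat = F.log_softmax(model(x + eps * d), dim=-1)
    return F.kl_div(pred_hat, pred, reduction='batchmean')

model = nn.Linear(5, 5)
inputs = torch.randn(3, 5)
output = vat_loss_sketch(model, inputs)
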
Parameters
  • xi (float) – The hyperparameter xi in the virtual adversarial training loss. (default: 10.0)

  • eps (float) – The hyperparameter eps in the virtual adversarial training loss. (default: 1.0)

  • iterations (int) – The number of iterations used to compute the adversarial noise. (default: 1)

forward(model: Module, x: Tensor) → Tensor[source]
Parameters
  • model (nn.Module) – The classification model. Its forward function should take a torch.Tensor as input and return a torch.Tensor corresponding to the predictions.

  • x (torch.Tensor) – The input data for the model.

Returns

The loss value.

Return type

torch.Tensor