
CenterLossTrainer

class torcheeg.trainers.CenterLossTrainer(extractor, classifier, feature_dim: int, num_classes: int, lammda: float = 0.0005, lr: float = 0.001, weight_decay: float = 0.0, devices: int = 1, accelerator: str = 'cpu', metrics: List[str] = ['accuracy'])

A trainer for classification models that consist of an extractor and a classifier. The center loss pulls the output of the extractor toward the mean of the decoded features of the same class (the class center). Please refer to the following information to understand how the center loss works.
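For reference, the standard center-loss formulation (Wen et al., ECCV 2016) is written below in a batch-averaged form. It assumes that x_i is the extractor output for the i-th of m samples in a batch, c_{y_i} is the learned center of that sample's class y_i, L_CE is the cross-entropy of the classifier's prediction, and λ plays the role of the lammda parameter. Whether CenterLossTrainer averages or sums over the batch is an assumption here, not something this page states.

```latex
\mathcal{L}_C = \frac{1}{2m}\sum_{i=1}^{m}\left\lVert \mathbf{x}_i - \mathbf{c}_{y_i}\right\rVert_2^2,
\qquad
\mathcal{L} = \mathcal{L}_{CE} + \lambda\,\mathcal{L}_C
```

Minimizing the first term pulls every feature code toward its class center, while the cross-entropy term keeps the classes separable.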

The model must contain an extractor block, which generates the deep feature code, and a classifier connected to the extractor, which predicts the class that the feature code belongs to. First, prepare an extractor model and a classifier model, which respectively decode the input and classify the decoded output. Here we take FBMSNet as an example. torcheeg.models.FBMSNet already contains the extractor and classifier layers, so all we need to do is inherit from the model to define an extractor and a classifier, and then override the forward method.

import torch
from torcheeg.models import FBMSNet
from torcheeg.trainers import CenterLossTrainer

class Extractor(FBMSNet):
    def forward(self, x):
        x = self.mixConv2d(x)
        x = self.scb(x)
        x = x.reshape([
            *x.shape[0:2], self.stride_factor,
            int(x.shape[3] / self.stride_factor)
        ])
        x = self.temporal_layer(x)
        return torch.flatten(x, start_dim=1)

class Classifier(FBMSNet):
    def forward(self, x):
        return self.fc(x)

extractor  = Extractor(num_classes=4,
                       num_electrodes=22,
                       chunk_size=512,
                       in_channels=9)

classifier = Classifier(num_classes=4,
                        num_electrodes=22,
                        chunk_size=512,
                        in_channels=9)

trainer = CenterLossTrainer(extractor=extractor,
                            classifier=classifier,
                            num_classes=4,
                            feature_dim=1152)
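feature_dim must equal the width of the extractor's flattened output, so the value 1152 above is presumably tied to that particular FBMSNet configuration. One way to avoid hard-coding it is to probe the extractor with a dummy sample, as in the minimal PyTorch sketch below. The helper infer_feature_dim and the toy_extractor model are hypothetical, and the (in_channels, num_electrodes, chunk_size) input layout is an assumption about what FBMSNet expects.

```python
import torch
from torch import nn


def infer_feature_dim(extractor: nn.Module, input_shape: tuple) -> int:
    """Return the flattened width of `extractor`'s output for one dummy sample.

    `input_shape` excludes the batch dimension.
    """
    was_training = extractor.training
    extractor.eval()  # avoid dropout and batch-norm statistic updates during the probe
    with torch.no_grad():
        features = extractor(torch.zeros(1, *input_shape))
    extractor.train(was_training)  # restore the original mode
    return features.flatten(start_dim=1).shape[1]


# Toy stand-in extractor (hypothetical, not FBMSNet) for inputs shaped
# (batch, in_channels=9, num_electrodes=22, chunk_size=512)
toy_extractor = nn.Sequential(
    nn.Conv2d(9, 4, kernel_size=(22, 1)),  # -> (batch, 4, 1, 512)
    nn.AvgPool2d(kernel_size=(1, 16)),     # -> (batch, 4, 1, 32)
    nn.Flatten(start_dim=1),               # -> (batch, 128)
)

print(infer_feature_dim(toy_extractor, (9, 22, 512)))  # 128
```

With the Extractor defined above, the equivalent call would be infer_feature_dim(extractor, (9, 22, 512)), under the same input-layout assumption.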
Parameters:
  • extractor (nn.Module) – The extractor which transforms EEG signals into a 1D feature code.

  • classifier (nn.Module) – The classifier that predicts, from the extractor output, which class the signals belong to.

  • feature_dim (int) – The dimension of the extractor's output code; the per-class mean of these codes can loosely be regarded as the class “center”.

  • num_classes (int) – The number of categories in the dataset.

  • lammda (float) – The weight of the center loss in the total loss. (default: 5e-4)

  • lr (float) – The learning rate. (default: 0.001)

  • weight_decay (float) – The weight decay. (default: 0.0)

  • devices (int) – The number of devices to use. (default: 1)

  • accelerator (str) – The accelerator to use. Available options are: ‘cpu’, ‘gpu’. (default: "cpu")

  • metrics (list of str) – The metrics to use. Available options are: ‘precision’, ‘recall’, ‘f1score’, ‘accuracy’, ‘matthews’, ‘auroc’, and ‘kappa’. (default: ['accuracy'])
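The lammda, lr, and weight_decay parameters can be pictured with the minimal PyTorch sketch below of a single optimization step. It assumes that the total loss is cross-entropy plus lammda times a batch-averaged center loss, and that the class centers are learnable parameters updated by the same Adam optimizer as the models. CenterLoss, training_step, and the toy models are hypothetical names, not CenterLossTrainer's internals; the original center-loss paper instead updates the centers with a separate running rule, so treat this as one common variant.

```python
import torch
import torch.nn.functional as F
from torch import nn


class CenterLoss(nn.Module):
    """Hypothetical center-loss module holding one learnable center per class."""

    def __init__(self, num_classes: int, feature_dim: int):
        super().__init__()
        self.centers = nn.Parameter(torch.randn(num_classes, feature_dim))

    def forward(self, features: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        # Pick the center of each sample's class: (batch, feature_dim)
        batch_centers = self.centers[labels]
        # Half the squared Euclidean distance to the class center, averaged over the batch
        return 0.5 * (features - batch_centers).pow(2).sum(dim=1).mean()


def training_step(extractor, classifier, center_loss, optimizer, x, y, lammda=5e-4):
    features = extractor(x)        # (batch, feature_dim)
    logits = classifier(features)  # (batch, num_classes)
    # Total loss: classification loss plus the weighted center loss
    loss = F.cross_entropy(logits, y) + lammda * center_loss(features, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()


# Toy models (hypothetical): 16 input values -> 8-dim features -> 4 classes
torch.manual_seed(0)
toy_extractor = nn.Sequential(nn.Flatten(), nn.Linear(16, 8))
toy_classifier = nn.Linear(8, 4)
center_loss = CenterLoss(num_classes=4, feature_dim=8)

# One optimizer over the extractor, classifier, and class centers
params = [*toy_extractor.parameters(), *toy_classifier.parameters(), *center_loss.parameters()]
optimizer = torch.optim.Adam(params, lr=1e-3, weight_decay=0.0)

x, y = torch.randn(32, 16), torch.randint(0, 4, (32,))
print(training_step(toy_extractor, toy_classifier, center_loss, optimizer, x, y))
```

Because lammda is small by default (5e-4), the center term acts as a regularizer on the feature space rather than dominating the classification loss.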

fit(train_loader: DataLoader, val_loader: DataLoader, max_epochs: int = 300, *args, **kwargs) → Any
Parameters:
  • train_loader (DataLoader) – Iterable DataLoader for traversing the training data batches (torch.utils.data.dataloader.DataLoader, torch_geometric.loader.DataLoader, etc.).

  • val_loader (DataLoader) – Iterable DataLoader for traversing the validation data batches (torch.utils.data.dataloader.DataLoader, torch_geometric.loader.DataLoader, etc.).

  • max_epochs (int) – Maximum number of epochs to train the model. (default: 300)

test(test_loader: DataLoader, *args, **kwargs) → List[Dict[str, float]]
Parameters:
  • test_loader (DataLoader) – Iterable DataLoader for traversing the test data batches (torch.utils.data.dataloader.DataLoader, torch_geometric.loader.DataLoader, etc.).
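fit and test consume ordinary DataLoaders. The minimal sketch below builds training, validation, and test loaders from random tensors with standard torch.utils.data utilities. The random data, split sizes, and batch size are purely illustrative, and it assumes, without confirmation from this page, that each batch is an (x, y) pair with x shaped (batch, in_channels, num_electrodes, chunk_size) to match the FBMSNet example.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, random_split

# Hypothetical random data standing in for a real EEG dataset:
# 30 trials shaped (in_channels=9, num_electrodes=22, chunk_size=512), 4 classes
torch.manual_seed(0)
eeg = torch.randn(30, 9, 22, 512)
labels = torch.randint(0, 4, (30,))
dataset = TensorDataset(eeg, labels)

# 20 / 5 / 5 split into training, validation, and test subsets
train_set, val_set, test_set = random_split(
    dataset, [20, 5, 5], generator=torch.Generator().manual_seed(0)
)

train_loader = DataLoader(train_set, batch_size=8, shuffle=True)
val_loader = DataLoader(val_set, batch_size=8)
test_loader = DataLoader(test_set, batch_size=8)

# Each batch is an (eeg, label) pair
x, y = next(iter(train_loader))
print(x.shape, y.shape)  # torch.Size([8, 9, 22, 512]) torch.Size([8])
```

Under that assumption, the loaders could then be passed to the trainer from the earlier example as trainer.fit(train_loader, val_loader) and trainer.test(test_loader).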
