
CenterLossTrainer

class torcheeg.trainers.CenterLossTrainer(extractor, classifier, feature_dim: int, num_classes: int, lammda: float = 0.0005, lr: float = 0.001, weight_decay: float = 0.0, devices: int = 1, accelerator: str = 'cpu', metrics: List[str] = ['accuracy', 'precision', 'recall', 'f1score'])

A trainer for classification models that consist of an extractor and a classifier. The center loss pulls the output of the extractor toward the mean of the extracted features of the same class. Please refer to the following information to understand how the center loss works.
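The loss itself is not spelled out on this page. The sketch below assumes the standard center loss formulation of Wen et al. (2016): each class y has a learnable center c_y of length feature_dim, the center term is half the squared Euclidean distance between a feature code and its class center (averaged here over the batch), and the total loss is cross-entropy plus lammda times the center term. `CenterLoss` and `total_loss` are hypothetical names, and details such as averaging versus summing or how the centers are updated may differ from what `CenterLossTrainer` does internally.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class CenterLoss(nn.Module):
    """Hypothetical center loss: pulls each feature code toward its class center."""

    def __init__(self, num_classes: int, feature_dim: int):
        super().__init__()
        # One learnable center per class, with the same width as the extractor output.
        self.centers = nn.Parameter(torch.randn(num_classes, feature_dim))

    def forward(self, features: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        # Select the center assigned to each sample's label: (batch_size, feature_dim).
        batch_centers = self.centers[labels]
        # Half the squared Euclidean distance, averaged over the batch.
        return 0.5 * (features - batch_centers).pow(2).sum(dim=1).mean()


def total_loss(logits, features, labels, center_loss, lammda=5e-4):
    # Classification loss plus the center term weighted by lammda (assumed combination).
    return F.cross_entropy(logits, labels) + lammda * center_loss(features, labels)
```

A larger lammda makes the feature codes of each class more compact at the cost of giving the cross-entropy term relatively less weight.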

The model must contain an extractor block, which generates the deep feature code, and a classifier connected to the extractor, which determines the class to which the feature code belongs. First, prepare an extractor model and a classifier model, which respectively encode the input and classify the encoded output. Here we take FBMSNet as an example. torcheeg.models.FBMSNet already contains the layers needed for both parts, so all we need to do is inherit from the model to define an extractor and a classifier, and then override the forward method.

import torch
from torcheeg.models import FBMSNet
from torcheeg.trainers import CenterLossTrainer

# Extractor: reuse FBMSNet's layers up to the temporal layer.
class Extractor(FBMSNet):
    def forward(self, x):
        x = self.mixConv2d(x)
        x = self.scb(x)
        x = x.reshape([
            *x.shape[0:2], self.stride_factor,
            int(x.shape[3] / self.stride_factor)
        ])
        x = self.temporal_layer(x)
        # Flatten to a (batch_size, feature_dim) feature code.
        return torch.flatten(x, start_dim=1)

# Classifier: apply FBMSNet's fc layer to the feature code.
class Classifier(FBMSNet):
    def forward(self, x):
        return self.fc(x)

extractor  = Extractor(num_classes=4,
                       num_electrodes=22,
                       chunk_size=512,
                       in_channels=9)

classifier = Classifier(num_classes=4,
                        num_electrodes=22,
                        chunk_size=512,
                        in_channels=9)

# feature_dim must equal the width of the extractor's flattened output.
trainer = CenterLossTrainer(extractor=extractor,
                            classifier=classifier,
                            num_classes=4,
                            feature_dim=1152)
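Whatever model is used, the extractor has to return a (batch_size, feature_dim) tensor, and the feature_dim passed to the trainer has to match that width. The sketch below uses hypothetical toy modules in plain PyTorch (`ToyExtractor`, `ToyClassifier`, `infer_feature_dim`), not FBMSNet, to show one way of reading feature_dim off an extractor with a single dummy forward pass. The sample shape (9, 22, 512) mirrors the in_channels, num_electrodes and chunk_size used above and is an assumption about the input layout.

```python
import torch
import torch.nn as nn


class ToyExtractor(nn.Module):
    # Hypothetical extractor: temporal convolution, pooling, then flatten to a 1D code.
    def __init__(self, in_channels: int = 9):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, 8, kernel_size=(1, 16), stride=(1, 8))
        self.pool = nn.AdaptiveAvgPool2d((4, 4))

    def forward(self, x):
        x = self.pool(torch.relu(self.conv(x)))
        return torch.flatten(x, start_dim=1)  # (batch_size, feature_dim)


class ToyClassifier(nn.Module):
    # Hypothetical classifier: one linear layer mapping the feature code to class scores.
    def __init__(self, feature_dim: int, num_classes: int):
        super().__init__()
        self.fc = nn.Linear(feature_dim, num_classes)

    def forward(self, x):
        return self.fc(x)


def infer_feature_dim(extractor: nn.Module, sample_shape) -> int:
    # Run one all-zero sample through the extractor and read off its output width.
    was_training = extractor.training
    extractor.eval()  # avoid batch-norm/dropout side effects on the dummy sample
    with torch.no_grad():
        feature_dim = extractor(torch.zeros(1, *sample_shape)).shape[1]
    extractor.train(was_training)  # restore the original mode
    return feature_dim
```

Applied to an FBMSNet-based extractor, the same dummy-pass idea should yield the value to pass as feature_dim, provided the sample shape matches the layout that model expects.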
Parameters:
  • extractor (nn.Module) – The extractor, which transforms EEG signals into one-dimensional feature codes.

  • classifier (nn.Module) – The classifier, which predicts from the extractor output which class the signals belong to.

  • feature_dim (int) – The dimension of the feature code produced by the extractor. Each class center has this dimension and can loosely be regarded as the mean of that class's feature codes.

  • num_classes (int) – The number of categories in the dataset.

  • lammda (float) – The weight of the center loss in the total loss. (default: 5e-4)

  • lr (float) – The learning rate. (default: 0.001)

  • weight_decay (float) – The weight decay. (default: 0.0)

  • devices (int) – The number of devices to use. (default: 1)

  • accelerator (str) – The accelerator to use. Available options are: ‘cpu’, ‘gpu’. (default: ‘cpu’)

  • metrics (list of str) – The metrics to use. Available options are: ‘precision’, ‘recall’, ‘f1score’, ‘accuracy’, ‘matthews’, ‘auroc’, and ‘kappa’. (default: ['accuracy', 'precision', 'recall', 'f1score'])
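The "center" mentioned under feature_dim can be made concrete: for a batch of feature codes, the mean code of each class is a vector of length feature_dim. A minimal sketch in plain PyTorch that computes these per-class means with `index_add_` and `bincount`; `class_means` is a hypothetical helper, and this page does not say whether CenterLossTrainer estimates its centers this way or learns them as parameters.

```python
import torch


def class_means(features: torch.Tensor, labels: torch.Tensor, num_classes: int) -> torch.Tensor:
    """Mean feature code per class; rows for classes absent from the batch stay zero."""
    feature_dim = features.shape[1]
    # Sum the feature codes of each class into its own row.
    sums = torch.zeros(num_classes, feature_dim, dtype=features.dtype, device=features.device)
    sums.index_add_(0, labels, features)
    # Count samples per class; clamp avoids dividing by zero for empty classes.
    counts = torch.bincount(labels, minlength=num_classes).clamp(min=1)
    return sums / counts.unsqueeze(1).to(features.dtype)
```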

fit(train_loader: DataLoader, val_loader: DataLoader, max_epochs: int = 300, *args, **kwargs) → Any
Parameters:
  • train_loader (DataLoader) – Iterable DataLoader for traversing the training data batch (torch.utils.data.dataloader.DataLoader, torch_geometric.loader.DataLoader, etc).

  • val_loader (DataLoader) – Iterable DataLoader for traversing the validation data batch (torch.utils.data.dataloader.DataLoader, torch_geometric.loader.DataLoader, etc).

  • max_epochs (int) – Maximum number of epochs to train the model. (default: 300)
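The two loaders passed to fit can be built from any dataset that yields samples with labels. A sketch, assuming that batches of (signal, label) pairs are what the trainer iterates over, that splits in-memory tensors into training and validation loaders with `torch.utils.data`; `make_loaders` is a hypothetical helper, and in practice the dataset would usually come from torcheeg.datasets rather than random tensors.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, random_split


def make_loaders(x: torch.Tensor, y: torch.Tensor, val_fraction: float = 0.2,
                 batch_size: int = 32, seed: int = 0):
    # Pair each EEG sample with its label.
    dataset = TensorDataset(x, y)
    n_val = int(len(dataset) * val_fraction)
    # A seeded generator makes the train/validation split reproducible.
    generator = torch.Generator().manual_seed(seed)
    train_set, val_set = random_split(dataset, [len(dataset) - n_val, n_val],
                                      generator=generator)
    # Shuffle only the training data; keep validation order fixed.
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False)
    return train_loader, val_loader
```

The resulting loaders would then be passed as trainer.fit(train_loader, val_loader, max_epochs=...).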

test(test_loader: DataLoader, *args, **kwargs) → List[Dict[str, float]]
Parameters:
  • test_loader (DataLoader) – Iterable DataLoader for traversing the test data batch (torch.utils.data.dataloader.DataLoader, torch_geometric.loader.DataLoader, etc).
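test returns a List[Dict[str, float]] mapping metric names to values. When the model is evaluated several times, for example once per cross-validation fold, the results can be averaged per metric. A stdlib-only sketch; `average_test_results` is a hypothetical helper, and the metric key names in the usage below are made up, so check the actual keys by printing what test returns.

```python
from statistics import mean
from typing import Dict, List


def average_test_results(runs: List[List[Dict[str, float]]]) -> Dict[str, float]:
    # Each run is the list returned by one test() call.
    # First merge the dicts within a run, then average each metric across runs.
    merged = [{k: v for d in run for k, v in d.items()} for run in runs]
    keys = sorted(set().union(*merged))
    return {k: mean(m[k] for m in merged if k in m) for k in keys}
```

For example, two folds with hypothetical results [{"test_accuracy": 0.5}] and [{"test_accuracy": 0.7}] average to roughly 0.6.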
