CenterLossTrainer
- class torcheeg.trainers.CenterLossTrainer(decoder, classifier, feature_dim: int, num_classes: int, lammda: float = 0.0005, lr: float = 0.001, weight_decay: float = 0.0, devices: int = 1, accelerator: str = 'cpu', metrics: List[str] = ['accuracy', 'precision', 'recall', 'f1score'])
A trainer for classification models that consist of a decoder and a classifier. The center loss pulls the output of the decoder toward the mean of the decoded features within the same class (the class “center”). Please refer to the following information to understand how the center loss works.
Paper: FBMSNet: A Filter-Bank Multi-Scale Convolutional Neural Network for EEG-Based Motor Imagery Decoding
Related Project: https://github.com/Want2Vanish/FBMSNet
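As a rough illustration of the idea above, the sketch below follows the center-loss formulation of Wen et al. (2016): each class y has a center c_y, the center term is ½‖x_i − c_{y_i}‖² averaged over the batch, and it is added to the cross-entropy with weight `lammda`. The names `CenterLoss` and `total_loss` are hypothetical, and CenterLossTrainer's internal implementation (for example, how the centers are stored or updated) may differ.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class CenterLoss(nn.Module):
    """Hypothetical center-loss module: one learnable center per class."""

    def __init__(self, num_classes: int, feature_dim: int):
        super().__init__()
        # One center vector per class, learned jointly with the network.
        self.centers = nn.Parameter(torch.randn(num_classes, feature_dim))

    def forward(self, features: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        # Select the center belonging to each sample's label: (batch, feature_dim).
        batch_centers = self.centers[labels]
        # Half the squared Euclidean distance to the class center, averaged over the batch.
        return 0.5 * (features - batch_centers).pow(2).sum(dim=1).mean()


def total_loss(logits: torch.Tensor, features: torch.Tensor, labels: torch.Tensor,
               center_loss: CenterLoss, lammda: float = 5e-4) -> torch.Tensor:
    # Cross-entropy drives classification; the weighted center term pulls
    # decoded features of the same class toward their center.
    return F.cross_entropy(logits, labels) + lammda * center_loss(features, labels)
```

Here the centers are an `nn.Parameter` learned by gradient descent, which is a common simplification; Wen et al. instead update the centers with a dedicated rule controlled by a separate rate.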
trainer = CenterLossTrainer(decoder=decoder,
                            classifier=classifier,
                            num_classes=your_classes,
                            feature_dim=your_decoded_dim)
trainer.fit(train_loader, val_loader)
trainer.test(test_loader)
The model is required to contain a decoder block, which generates the deep feature code, and a classifier connected to the decoder, which judges which class the feature code belongs to. First, we prepare a decoder model and a classifier model, which respectively decode the input and classify the decoded output. Here we take FBMSNet as an example. torcheeg.models.FBMSNet already contains decoder and classifier methods, so all we need to do is inherit from the model to define a decoder and a classifier, and then override the forward method.
from torcheeg.models import FBMSNet
from torcheeg.trainers import CenterLossTrainer

class FBMSDecoder(FBMSNet):
    def forward(self, x):
        # Return only the deep feature code produced by the decoder.
        return self.decoder(x)

class FBMSClassifier(FBMSNet):
    def forward(self, x):
        # Map a feature code to class scores.
        return self.classifier(x)

decoder = FBMSDecoder(num_classes=4, num_electrodes=22, chunk_size=512, in_channels=9)
classifier = FBMSClassifier(num_classes=4, num_electrodes=22, chunk_size=512, in_channels=9)
trainer = CenterLossTrainer(decoder=decoder, classifier=classifier, num_classes=4, feature_dim=1152)
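The value passed as feature_dim must match the width of the decoder's output code. One way to check it is a dummy forward pass, sketched below. `infer_feature_dim` is a hypothetical helper, not part of TorchEEG, and the input shape (in_channels, num_electrodes, chunk_size) for FBMSNet is an assumption.

```python
import torch
import torch.nn as nn


def infer_feature_dim(decoder: nn.Module, input_shape: tuple) -> int:
    """Hypothetical helper: per-sample size of `decoder`'s output for one dummy batch.

    `input_shape` excludes the batch axis, e.g. (in_channels, num_electrodes, chunk_size).
    """
    was_training = decoder.training
    decoder.eval()  # disable Dropout and BatchNorm statistic updates for the dummy pass
    with torch.no_grad():
        out = decoder(torch.zeros(1, *input_shape))
    decoder.train(was_training)  # restore the original mode
    # The trainer expects a 1D feature code per sample: (batch, feature_dim).
    if out.dim() != 2:
        raise ValueError(f"expected output of shape (batch, feature_dim), got {tuple(out.shape)}")
    return out.shape[1]
```

Under the input-shape assumption above, `infer_feature_dim(decoder, (9, 22, 512))` would report the value to pass as feature_dim for the FBMSDecoder example.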
Custom models are also supported. For example:
import torch.nn as nn
from torcheeg.trainers import CenterLossTrainer

class MyDecoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(128, 64)  # (input dim, decoded dim)

    def forward(self, x):
        return self.layer(x)

class MyClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(64, 2)  # (decoded dim, num_classes)

    def forward(self, x):
        return self.layer(x)

decoder = MyDecoder()
classifier = MyClassifier()
trainer = CenterLossTrainer(decoder=decoder, classifier=classifier, num_classes=2, feature_dim=64)
- Parameters:
decoder (nn.Module) – The decoder that transforms EEG signals into a 1D feature code.
classifier (nn.Module) – The classifier that predicts, from the decoder output, which class the signals belong to.
feature_dim (int) – The dimension of the decoder's output code. The per-class mean of these codes can loosely be regarded as the class “center”.
num_classes (int) – The number of categories in the dataset.
lammda (float) – The weight of the center loss in the total loss. (default: 5e-4)
lr (float) – The learning rate. (default: 0.001)
weight_decay (float) – The weight decay. (default: 0.0)
devices (int) – The number of devices to use. (default: 1)
accelerator (str) – The accelerator to use. Available options are: ‘cpu’, ‘gpu’. (default: "cpu")
metrics (list of str) – The metrics to use. Available options are: ‘precision’, ‘recall’, ‘f1score’, ‘accuracy’, ‘matthews’, ‘auroc’, and ‘kappa’. (default: ['accuracy', 'precision', 'recall', 'f1score'])
- fit(train_loader: DataLoader, val_loader: DataLoader, max_epochs: int = 300, *args, **kwargs) → Any
- Parameters:
train_loader (DataLoader) – Iterable DataLoader for traversing the training data batch (torch.utils.data.dataloader.DataLoader, torch_geometric.loader.DataLoader, etc.).
val_loader (DataLoader) – Iterable DataLoader for traversing the validation data batch (torch.utils.data.dataloader.DataLoader, torch_geometric.loader.DataLoader, etc.).
max_epochs (int) – Maximum number of epochs to train the model. (default: 300)
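A minimal sketch of loaders that could be passed to fit, assuming the trainer unpacks each batch as an (x, y) pair of inputs and integer labels. The random data and the helper `make_loaders` are hypothetical and sized for the MyDecoder/MyClassifier example (128 input features, 2 classes).

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def make_loaders(num_train=64, num_val=16, input_dim=128, num_classes=2,
                 batch_size=16, seed=0):
    """Hypothetical helper: random (x, y) loaders for a quick smoke test."""
    g = torch.Generator().manual_seed(seed)  # reproducible random data

    def random_split(n):
        x = torch.randn(n, input_dim, generator=g)                # inputs
        y = torch.randint(0, num_classes, (n,), generator=g)      # integer class labels
        return TensorDataset(x, y)

    train_loader = DataLoader(random_split(num_train), batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(random_split(num_val), batch_size=batch_size)
    return train_loader, val_loader
```

Under that assumption, the loaders would be used as `trainer.fit(train_loader, val_loader, max_epochs=10)`.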