BYOLTrainer
- class torcheeg.trainers.BYOLTrainer(extractor: Module, extract_channels: int, proj_channels: int = 256, proj_hid_channels: int = 512, lr: float = 0.0001, weight_decay: float = 0.0, moving_average_decay=0.99, devices: int = 1, accelerator: str = 'cpu', metrics: List[str] = ['acc_top1'])
This class implements Bootstrap Your Own Latent (BYOL) for self-supervised pre-training.
Paper: Grill J B, Strub F, Altché F, et al. Bootstrap your own latent - a new approach to self-supervised learning[J]. Advances in neural information processing systems, 2020, 33: 21271-21284.
URL: https://proceedings.neurips.cc/paper/2020/hash/f3ada80d5c4ee70142b17b8192b2958e-Abstract.html
Related Project: https://github.com/lucidrains/byol-pytorch
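The core BYOL objective from the paper can be sketched in plain PyTorch: an online network plus a predictor regresses the target network's projection of a second view, using a normalized MSE (equivalently 2 - 2 · cosine similarity), and the target weights track the online weights through an exponential moving average. This is a generic illustration of the technique, not TorchEEG's internal code; the toy networks, their sizes, and the helper names `byol_loss` and `ema_update` are hypothetical.

```python
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F


def byol_loss(p: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
    """Normalized MSE between online prediction p and target projection z (= 2 - 2 * cosine)."""
    p = F.normalize(p, dim=-1)
    z = F.normalize(z.detach(), dim=-1)  # stop-gradient: the target branch is not trained by backprop
    return (2 - 2 * (p * z).sum(dim=-1)).mean()


@torch.no_grad()
def ema_update(online: nn.Module, target: nn.Module, decay: float = 0.99) -> None:
    """target <- decay * target + (1 - decay) * online, parameter by parameter."""
    for p_o, p_t in zip(online.parameters(), target.parameters()):
        p_t.mul_(decay).add_(p_o, alpha=1 - decay)


# Hypothetical toy networks: `online` stands in for extractor + projection head.
online = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 16))
predictor = nn.Sequential(nn.Linear(16, 64), nn.ReLU(), nn.Linear(64, 16))
target = copy.deepcopy(online)
for param in target.parameters():
    param.requires_grad_(False)  # the target is only updated through ema_update

optimizer = torch.optim.Adam(list(online.parameters()) + list(predictor.parameters()), lr=1e-4)

x = torch.randn(8, 32)
view1 = x + 0.1 * torch.randn_like(x)  # two random "augmentations" of the same batch
view2 = x + 0.1 * torch.randn_like(x)

# Symmetrized loss: each view's online prediction regresses the other view's target projection.
loss = byol_loss(predictor(online(view1)), target(view2)) + byol_loss(
    predictor(online(view2)), target(view1)
)
optimizer.zero_grad()
loss.backward()
optimizer.step()
ema_update(online, target, decay=0.99)
```

Only the online network and predictor receive gradients; the target changes slowly through the moving average, which is what the `moving_average_decay` argument in the signature controls.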
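from torcheeg.trainers import BYOLTrainer

# extractor, train_loader and val_loader are assumed to be defined elsewhere
trainer = BYOLTrainer(extractor, extract_channels=256, devices=1, accelerator='gpu')
trainer.fit(train_loader, val_loader)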
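The usage above assumes an `extractor` already exists and that `extract_channels` matches the size of its output features. A minimal sketch of such a model follows, assuming inputs shaped (batch, 4, 9, 9), i.e. 4 frequency bands on a 9 x 9 electrode grid as produced by the DEAP preprocessing shown in this section, and assuming the trainer expects a flat (batch, extract_channels) output. `GridExtractor` and its layer sizes are hypothetical.

```python
import torch
import torch.nn as nn


class GridExtractor(nn.Module):
    """Hypothetical extractor: maps (batch, 4, 9, 9) band-by-grid inputs to (batch, out_features)."""

    def __init__(self, in_channels: int = 4, out_features: int = 256):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),  # -> (batch, 128, 1, 1)
            nn.Flatten(),             # -> (batch, 128)
        )
        self.fc = nn.Linear(128, out_features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc(self.conv(x))


# out_features must equal the extract_channels value passed to the trainer (256 above).
extractor = GridExtractor(out_features=256)
features = extractor(torch.randn(8, 4, 9, 9))
print(features.shape)  # torch.Size([8, 256])
```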
NOTE: The first element of each batch in train_loader and val_loader should be a two-tuple, representing two random transformations (views) of the data. You can use Contrastive to achieve this functionality.

from torcheeg import transforms
from torcheeg.datasets import DEAPDataset
from torcheeg.datasets.constants import DEAP_CHANNEL_LOCATION_DICT
from torcheeg.trainers import BYOLTrainer

contras_dataset = DEAPDataset(
    io_path='./io/deap',
    root_path='./data_preprocessed_python',
    offline_transform=transforms.Compose([
        transforms.BandDifferentialEntropy(sampling_rate=128, apply_to_baseline=True),
        transforms.BaselineRemoval(),
        transforms.ToGrid(DEAP_CHANNEL_LOCATION_DICT)
    ]),
    online_transform=transforms.Compose([
        transforms.ToTensor(),
        # Contrastive produces two randomly augmented views of each sample
        transforms.Contrastive(transforms.Compose(
            [transforms.RandomMask(p=0.5),
             transforms.RandomNoise(p=0.5)]),
            num_views=2)
    ]),
    chunk_size=128,
    baseline_chunk_size=128,
    num_baseline=3)

# extractor is assumed to be defined elsewhere; train_loader and val_loader
# are assumed to be DataLoaders built from contras_dataset
trainer = BYOLTrainer(extractor, extract_channels=256, devices=1, accelerator='gpu')
trainer.fit(train_loader, val_loader)
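To see what "the first element of each batch is a two-tuple" means in practice, the sketch below builds a toy dataset whose samples are ((view1, view2), label) and lets PyTorch's default collate function batch them. `TwoViews`, `add_noise`, and `ToyGridDataset` are hypothetical stand-ins that only mimic the output structure of the contrastive transform; they are not TorchEEG's implementation.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class TwoViews:
    """Hypothetical view generator: applies `transform` independently `num_views` times."""

    def __init__(self, transform, num_views: int = 2):
        self.transform = transform
        self.num_views = num_views

    def __call__(self, x: torch.Tensor):
        return tuple(self.transform(x) for _ in range(self.num_views))


def add_noise(x: torch.Tensor, std: float = 0.1) -> torch.Tensor:
    """Illustrative augmentation: additive Gaussian noise."""
    return x + std * torch.randn_like(x)


class ToyGridDataset(Dataset):
    """Random (4, 9, 9) samples with binary labels, each returned as ((view1, view2), label)."""

    def __init__(self, n: int = 16):
        self.x = torch.randn(n, 4, 9, 9)
        self.y = torch.randint(0, 2, (n,))
        self.views = TwoViews(add_noise, num_views=2)

    def __len__(self) -> int:
        return len(self.x)

    def __getitem__(self, i: int):
        return self.views(self.x[i]), self.y[i]


loader = DataLoader(ToyGridDataset(), batch_size=8)
# The default collate function turns per-sample (view1, view2) tuples into a pair of batched tensors.
views, labels = next(iter(loader))
print(len(views), views[0].shape)  # 2 torch.Size([8, 4, 9, 9])
```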
- Parameters:
extractor (nn.Module) – The feature extraction model. It learns representations of the EEG signal by training an online network to predict a target network's projection of a different augmented view of the same sample, where the target network is a moving average of the online network.
extract_channels (int) – The feature dimension of the output of the feature extraction model.
proj_channels (int) – The feature dimension of the output of the projection head. (default: 256)
proj_hid_channels (int) – The feature dimension of the hidden layer of the projection head. (default: 512)
lr (float) – The learning rate. (default: 0.0001)
weight_decay (float) – The weight decay. (default: 0.0)
moving_average_decay (float) – The decay rate of the exponential moving average used to update the target network. (default: 0.99)
devices (int) – The number of GPUs to use. (default: 1)
accelerator (str) – The accelerator to use. Available options are: 'cpu', 'gpu'. (default: 'cpu')
metrics (List[str]) – The metrics to use. Available options are: 'acc_top1', 'acc_top5', 'acc_mean_pos'. (default: ['acc_top1'])
- fit(train_loader: DataLoader, val_loader: DataLoader, max_epochs: int = 300, *args, **kwargs) → Any
NOTE: The first element of each batch in train_loader and val_loader should be a two-tuple, representing two random transformations (views) of the data. You can use Contrastive to achieve this functionality.
- Parameters:
train_loader (DataLoader) – Iterable DataLoader for traversing batches of the training data (torch.utils.data.dataloader.DataLoader, torch_geometric.loader.DataLoader, etc.).
val_loader (DataLoader) – Iterable DataLoader for traversing batches of the validation data (torch.utils.data.dataloader.DataLoader, torch_geometric.loader.DataLoader, etc.).
max_epochs (int) – Maximum number of epochs to train the model. (default: 300)
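Because fit expects two views in the first element of every batch, it can help to check a loader before starting a long run. The sketch below is a hypothetical helper (`check_two_view_loader` is not part of TorchEEG) that inspects one batch and raises if the views are missing or mismatched; here it is applied to a plain single-view loader, which fails the check.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def check_two_view_loader(loader) -> None:
    """Raise ValueError unless batch[0] of the first batch is a pair of same-shaped tensors."""
    batch = next(iter(loader))
    views = batch[0]
    if not isinstance(views, (list, tuple)) or len(views) != 2:
        raise ValueError(f"expected batch[0] to be a two-tuple of views, got {type(views).__name__}")
    if views[0].shape != views[1].shape:
        raise ValueError(f"view shapes differ: {tuple(views[0].shape)} vs {tuple(views[1].shape)}")


# A single-view loader (samples without a two-view transform) fails the check.
single_view = DataLoader(TensorDataset(torch.randn(16, 4, 9, 9), torch.zeros(16)), batch_size=8)
try:
    check_two_view_loader(single_view)
except ValueError as err:
    print(err)  # expected batch[0] to be a two-tuple of views, got Tensor
```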