BYOLTrainer¶
- class torcheeg.trainers.BYOLTrainer(extractor: Module, extract_channels: int, proj_channels: int = 256, proj_hid_channels: int = 512, lr: float = 0.0001, weight_decay: float = 0.0, moving_average_decay=0.99, devices: int = 1, accelerator: str = 'cpu', metrics: List[str] = ['acc_top1'])[source]¶
This class supports the implementation of Bootstrap Your Own Latent (BYOL) for self-supervised pre-training.
Paper: Grill J B, Strub F, Altché F, et al. Bootstrap your own latent-a new approach to self-supervised learning[J]. Advances in neural information processing systems, 2020, 33: 21271-21284.
URL: https://proceedings.neurips.cc/paper/2020/hash/f3ada80d5c4ee70142b17b1048576b2958e-Abstract.html
Related Project: https://github.com/lucidrains/byol-pytorch
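As background for what the trainer optimizes, the core of BYOL can be sketched as follows. This is a minimal, hypothetical re-implementation (not torcheeg's actual code) of the two ingredients the constructor arguments hint at: the exponential-moving-average update of the target network, controlled by `moving_average_decay`, and the normalised-MSE loss between the online prediction and the target projection:

    import torch
    import torch.nn.functional as F

    def ema_update(target_params, online_params, decay=0.99):
        # The target network's parameters track the online network via an
        # exponential moving average; `decay` plays the role of
        # moving_average_decay in BYOLTrainer.
        with torch.no_grad():
            for t, o in zip(target_params, online_params):
                t.mul_(decay).add_(o, alpha=1.0 - decay)

    def byol_loss(online_pred, target_proj):
        # BYOL regresses the L2-normalised online prediction onto the
        # (stop-gradient) target projection; this equals 2 - 2 * cosine
        # similarity, so identical inputs give zero loss.
        p = F.normalize(online_pred, dim=-1)
        z = F.normalize(target_proj.detach(), dim=-1)
        return (2 - 2 * (p * z).sum(dim=-1)).mean()

In the symmetrised BYOL objective, this loss is computed for both orderings of the two views and averaged; only the online network receives gradients.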
    from torcheeg.models import CCNN
    from torcheeg.trainers import BYOLTrainer

    class Extractor(CCNN):
        def forward(self, x):
            x = self.conv1(x)
            x = self.conv2(x)
            x = self.conv3(x)
            x = self.conv4(x)
            x = x.flatten(start_dim=1)
            return x

    extractor = Extractor(in_channels=5, num_classes=3)
    trainer = BYOLTrainer(extractor,
                          extract_channels=256,
                          devices=1,
                          accelerator='gpu')
NOTE: The first element of each batch in train_loader and val_loader should be a two-tuple, representing two random transformations (views) of the data. You can use transforms.Contrastive to achieve this functionality.

    from torcheeg.datasets import DEAPDataset
    from torcheeg import transforms
    from torcheeg.datasets.constants import DEAP_CHANNEL_LOCATION_DICT

    contras_dataset = DEAPDataset(
        io_path='./io/deap',
        root_path='./data_preprocessed_python',
        offline_transform=transforms.Compose([
            transforms.BandDifferentialEntropy(sampling_rate=128,
                                               apply_to_baseline=True),
            transforms.BaselineRemoval(),
            transforms.ToGrid(DEAP_CHANNEL_LOCATION_DICT)
        ]),
        online_transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Contrastive(transforms.Compose(
                [transforms.RandomMask(p=0.5),
                 transforms.RandomNoise(p=0.5)]),
                num_views=2)
        ]),
        chunk_size=128,
        baseline_chunk_size=128,
        num_baseline=3)

    trainer.fit(train_loader, val_loader)
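To make the expected batch structure concrete, here is a hypothetical stand-in dataset (the class name, toy augmentation, and shapes are illustrative, not part of torcheeg) that yields two-view items the way a Contrastive-wrapped dataset does, and what the default DataLoader collation produces from them:

    import torch
    from torch.utils.data import DataLoader, Dataset

    class TwoViewDataset(Dataset):
        # Hypothetical stand-in for a dataset wrapped with
        # transforms.Contrastive(..., num_views=2): each item yields two
        # randomly augmented views of the same EEG sample.
        def __init__(self, n=16):
            # (trials, frequency bands, grid height, grid width)
            self.data = torch.randn(n, 5, 9, 9)

        def __len__(self):
            return len(self.data)

        def __getitem__(self, i):
            x = self.data[i]
            augment = lambda t: t + 0.1 * torch.randn_like(t)  # toy "random noise"
            return (augment(x), augment(x)), 0  # label is ignored during pre-training

    loader = DataLoader(TwoViewDataset(), batch_size=4)
    (x1, x2), _ = next(iter(loader))
    # x1 and x2 are the two batched views, each of shape (4, 5, 9, 9)

The first element of each batch is the two-tuple of views; BYOLTrainer feeds one view to the online network and the other to the target network.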
- Parameters:
  - extractor (nn.Module) – The feature extraction model that learns the feature representation of the EEG signal. In BYOL, an online network is trained to predict the target network's representation of another augmented view of the same sample.
  - extract_channels (int) – The feature dimensions of the output of the feature extraction model.
  - proj_channels (int) – The feature dimensions of the output of the projection head. (default: 256)
  - proj_hid_channels (int) – The feature dimensions of the hidden layer of the projection head. (default: 512)
  - lr (float) – The learning rate. (default: 0.0001)
  - weight_decay (float) – The weight decay. (default: 0.0)
  - moving_average_decay (float) – The decay rate of the exponential moving average used to update the target network's parameters from the online network. (default: 0.99)
  - devices (int) – The number of devices to use. (default: 1)
  - accelerator (str) – The accelerator to use. Available options are: 'cpu', 'gpu'. (default: "cpu")
  - metrics (List[str]) – The metrics to use. Available options are: 'acc_top1', 'acc_top5', 'acc_mean_pos'. (default: ["acc_top1"])
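The relationship between extract_channels, proj_hid_channels, and proj_channels can be illustrated with a BYOL-style MLP projection head. This is a sketch of the architecture implied by the constructor arguments; the exact layers in torcheeg's implementation may differ:

    import torch
    import torch.nn as nn

    def make_projection_head(extract_channels=256,
                             proj_hid_channels=512,
                             proj_channels=256):
        # Maps the extractor's output (extract_channels) through one hidden
        # layer (proj_hid_channels) to the projection space (proj_channels),
        # following the MLP design used in the BYOL paper.
        return nn.Sequential(
            nn.Linear(extract_channels, proj_hid_channels),
            nn.BatchNorm1d(proj_hid_channels),
            nn.ReLU(inplace=True),
            nn.Linear(proj_hid_channels, proj_channels),
        )

    head = make_projection_head()
    out = head(torch.randn(8, 256))  # projected features of shape (8, 256)

extract_channels must therefore match the flattened output size of the extractor's forward pass, e.g. the 256-channel feature map above.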
- fit(train_loader: DataLoader, val_loader: DataLoader, max_epochs: int = 300, *args, **kwargs) → Any [source]¶
NOTE: The first element of each batch in train_loader and val_loader should be a two-tuple, representing two random transformations (views) of the data. You can use transforms.Contrastive to achieve this functionality.
  - Parameters:
    - train_loader (DataLoader) – Iterable DataLoader for traversing the training data batches (torch.utils.data.dataloader.DataLoader, torch_geometric.loader.DataLoader, etc.).
    - val_loader (DataLoader) – Iterable DataLoader for traversing the validation data batches (torch.utils.data.dataloader.DataLoader, torch_geometric.loader.DataLoader, etc.).
    - max_epochs (int) – Maximum number of epochs to train the model. (default: 300)