SimCLRTrainer
- class torcheeg.trainers.SimCLRTrainer(extractor: Module, extract_channels: int, proj_channels: int = 256, proj_hid_channels: int = 512, lr: float = 0.0001, weight_decay: float = 0.0, temperature: float = 0.1, devices: int = 1, accelerator: str = 'cpu', metrics: List[str] = ['acc_top1'])
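This class implements SimCLR (A Simple Framework for Contrastive Learning of Visual Representations) for self-supervised pre-training of an EEG feature extractor. SimCLR pulls together the representations of two augmented views of the same sample and pushes apart the representations of different samples.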
Paper: Chen T, Kornblith S, Norouzi M, et al. A simple framework for contrastive learning of visual representations[C]//International conference on machine learning. PMLR, 2020: 1597-1607.
Related Project: https://github.com/sthalles/SimCLR
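The contrastive objective from the SimCLR paper is the NT-Xent (normalized temperature-scaled cross-entropy) loss. The following is a minimal NumPy sketch of that loss for a batch of N positive pairs, not torcheeg's own PyTorch code; the function name `nt_xent_loss` is hypothetical, and the `temperature` argument plays the role of the trainer parameter of the same name.

```python
import numpy as np


def nt_xent_loss(z1: np.ndarray, z2: np.ndarray, temperature: float = 0.1) -> float:
    """NT-Xent loss for N positive pairs (z1[i], z2[i]), each of shape (N, D).

    Every one of the 2N projected features is an anchor whose positive is its
    partner view; the other 2N - 2 features act as negatives.
    """
    z = np.concatenate([z1, z2], axis=0)                  # (2N, D)
    z = z / np.linalg.norm(z, axis=1, keepdims=True)      # L2-normalise each row
    sim = z @ z.T / temperature                           # cosine similarity / tau
    np.fill_diagonal(sim, -np.inf)                        # an anchor never matches itself
    n2 = z.shape[0]
    n = n2 // 2
    # Positive partner of anchor i is i + N (and vice versa).
    pos_idx = np.concatenate([np.arange(n, n2), np.arange(0, n)])
    # Numerically stable log-softmax over each row.
    row_max = sim.max(axis=1, keepdims=True)
    log_denom = row_max[:, 0] + np.log(np.exp(sim - row_max).sum(axis=1))
    log_prob_pos = sim[np.arange(n2), pos_idx] - log_denom
    return float(-log_prob_pos.mean())


rng = np.random.default_rng(0)
z1 = rng.normal(size=(8, 256))
aligned = nt_xent_loss(z1, z1 + 0.01 * rng.normal(size=z1.shape))
mismatched = nt_xent_loss(z1, rng.normal(size=(8, 256)))
print(aligned < mismatched)  # True: agreeing views give a lower loss
```

A lower temperature makes the softmax over candidates sharper, so hard negatives (those most similar to the anchor) dominate the gradient.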
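    # extract_channels is required and must match the output dimension of extractor; 256 is illustrative
    trainer = SimCLRTrainer(extractor, extract_channels=256, devices=1, accelerator='gpu')
    trainer.fit(train_loader, val_loader)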
NOTE: The first element of each batch in train_loader and val_loader should be a two-tuple representing two random transformations (views) of the data. You can use Contrastive to achieve this, for example:
    from torcheeg import transforms
    from torcheeg.datasets import DEAPDataset
    from torcheeg.datasets.constants import DEAP_CHANNEL_LOCATION_DICT
    from torcheeg.trainers import SimCLRTrainer

    contras_dataset = DEAPDataset(
        io_path='./io/deap',
        root_path='./data_preprocessed_python',
        offline_transform=transforms.Compose([
            transforms.BandDifferentialEntropy(sampling_rate=128, apply_to_baseline=True),
            transforms.BaselineRemoval(),
            transforms.ToGrid(DEAP_CHANNEL_LOCATION_DICT)
        ]),
        online_transform=transforms.Compose([
            transforms.ToTensor(),
            # Contrastive applies the random augmentation num_views times, yielding two views
            transforms.Contrastive(transforms.Compose([
                transforms.RandomMask(p=0.5),
                transforms.RandomNoise(p=0.5)
            ]), num_views=2)
        ]),
        chunk_size=128,
        baseline_chunk_size=128,
        num_baseline=3)

    # extractor is your nn.Module; train_loader and val_loader are DataLoaders over splits of contras_dataset
    trainer = SimCLRTrainer(extractor, extract_channels=256, devices=1, accelerator='gpu')
    trainer.fit(train_loader, val_loader)
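The idea behind a Contrastive-style transform is that calling the same random augmentation several times on one sample produces several different views. The following plain NumPy sketch only illustrates that idea; it is not torcheeg's implementation. `ContrastiveSketch` and `random_mask_and_noise` are hypothetical, and the 10% mask ratio and noise scale are assumptions chosen for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)


class ContrastiveSketch:
    """Hypothetical stand-in for a Contrastive-style transform.

    Calls the same (random) augmentation num_views times on one sample and
    returns the results as a list, so each call yields several views.
    """

    def __init__(self, transform, num_views: int = 2):
        self.transform = transform
        self.num_views = num_views

    def __call__(self, eeg: np.ndarray) -> list:
        return [self.transform(eeg) for _ in range(self.num_views)]


def random_mask_and_noise(eeg: np.ndarray, p: float = 0.5) -> np.ndarray:
    """Hypothetical augmentation loosely mirroring RandomMask + RandomNoise."""
    out = eeg.copy()  # never modify the original sample
    if rng.random() < p:
        out[rng.random(eeg.shape) < 0.1] = 0.0  # zero out roughly 10% of the entries
    if rng.random() < p:
        out = out + rng.normal(scale=0.1, size=eeg.shape)  # add Gaussian noise
    return out


sample = rng.normal(size=(4, 9, 9))  # illustrative shape: 4 bands on a 9x9 electrode grid
views = ContrastiveSketch(random_mask_and_noise, num_views=2)(sample)
print(len(views), views[0].shape)  # 2 (4, 9, 9)
```

Because each augmentation fires with probability p, two views can occasionally coincide; with more augmentations in the composition this becomes rare.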
- Parameters:
extractor (nn.Module) – The feature extraction model to pre-train; it learns the feature representation of the EEG signal.
extract_channels (int) – The feature dimension of the output of the feature extraction model.
proj_channels (int) – The feature dimension of the output of the projection head. (default: 256)
proj_hid_channels (int) – The feature dimension of the hidden layer of the projection head. (default: 512)
lr (float) – The learning rate. (default: 0.0001)
weight_decay (float) – The weight decay. (default: 0.0)
temperature (float) – The temperature that scales the similarities in the contrastive loss; smaller values make the distribution over candidates sharper. (default: 0.1)
devices (int) – The number of GPUs to use. (default: 1)
accelerator (str) – The accelerator to use. Available options are: ‘cpu’, ‘gpu’. (default: "cpu")
metrics (List[str]) – The metrics to use. Available options are: ‘acc_top1’, ‘acc_top5’, ‘acc_mean_pos’. (default: ["acc_top1"])
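The projection-head sizes and metric names above can be illustrated with a NumPy sketch. It assumes, without confirmation from this page, that the head is a one-hidden-layer MLP (extract_channels → proj_hid_channels → proj_channels with a ReLU), as described in the SimCLR paper, and that ‘acc_top1’, ‘acc_top5’ and ‘acc_mean_pos’ follow the common SimCLR tutorial definitions: how often the positive view ranks first or within the top five among all 2N - 1 candidates, and its mean 1-based rank. `projection_head` and `retrieval_metrics` are hypothetical names.

```python
import numpy as np

rng = np.random.default_rng(0)


def projection_head(h, w1, b1, w2, b2):
    """Hypothetical MLP head: Linear -> ReLU -> Linear."""
    hidden = np.maximum(h @ w1 + b1, 0.0)  # (N, proj_hid_channels)
    return hidden @ w2 + b2                # (N, proj_channels)


def retrieval_metrics(z1, z2):
    """Rank each anchor's positive view among all 2N - 1 other candidates."""
    z = np.concatenate([z1, z2], axis=0)
    z = z / np.linalg.norm(z, axis=1, keepdims=True)
    sim = z @ z.T                          # temperature omitted: it does not change rankings
    np.fill_diagonal(sim, -np.inf)         # an anchor is not its own candidate
    n2 = z.shape[0]
    n = n2 // 2
    pos_idx = np.concatenate([np.arange(n, n2), np.arange(0, n)])
    pos_sim = sim[np.arange(n2), pos_idx]
    # 0-based rank = number of candidates strictly more similar than the positive
    rank = (sim > pos_sim[:, None]).sum(axis=1)
    return {
        "acc_top1": float((rank == 0).mean()),
        "acc_top5": float((rank < 5).mean()),
        "acc_mean_pos": float((rank + 1).mean()),
    }


extract_channels, proj_hid_channels, proj_channels = 64, 512, 256  # illustrative sizes
w1 = rng.normal(scale=0.05, size=(extract_channels, proj_hid_channels))
b1 = np.zeros(proj_hid_channels)
w2 = rng.normal(scale=0.05, size=(proj_hid_channels, proj_channels))
b2 = np.zeros(proj_channels)

h1 = rng.normal(size=(16, extract_channels))  # stand-in extractor output for view 1
h2 = h1 + 0.05 * rng.normal(size=h1.shape)    # slightly perturbed view 2
z1 = projection_head(h1, w1, b1, w2, b2)
z2 = projection_head(h2, w1, b1, w2, b2)
print(z1.shape)  # (16, 256)
print(retrieval_metrics(z1, z2))
```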
- fit(train_loader: DataLoader, val_loader: DataLoader, max_epochs: int = 300, *args, **kwargs) → Any
NOTE: The first element of each batch in train_loader and val_loader should be a two-tuple representing two random transformations (views) of the data. You can use Contrastive to achieve this.
- Parameters:
train_loader (DataLoader) – Iterable DataLoader over batches of training data (torch.utils.data.dataloader.DataLoader, torch_geometric.loader.DataLoader, etc.).
val_loader (DataLoader) – Iterable DataLoader over batches of validation data (torch.utils.data.dataloader.DataLoader, torch_geometric.loader.DataLoader, etc.).
max_epochs (int) – Maximum number of epochs to train the model. (default: 300)
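PyTorch's default collate function transposes a list of (views, info) samples, so when each sample's first element is a list of two views, batch[0] becomes a two-element sequence of batched views, which is the layout the NOTE for fit asks for. The NumPy sketch below only mimics that transposition to make the expected structure explicit; `collate_two_views` and the `sample_id` key are hypothetical.

```python
import numpy as np


def collate_two_views(samples):
    """Hypothetical collate mimicking how a DataLoader batches (views, info) samples.

    Each sample is ([view_a, view_b], info). The result is
    ([batched_view_a, batched_view_b], [info, ...]), so batch[0] holds two views.
    """
    views_per_sample, infos = zip(*samples)
    batched_views = [np.stack(v) for v in zip(*views_per_sample)]  # one array per view
    return batched_views, list(infos)


rng = np.random.default_rng(0)
samples = [([rng.normal(size=(4, 9, 9)), rng.normal(size=(4, 9, 9))], {"sample_id": i})
           for i in range(8)]
batch = collate_two_views(samples)
x_a, x_b = batch[0]  # the two batched views that are contrasted against each other
print(x_a.shape, x_b.shape)  # (8, 4, 9, 9) (8, 4, 9, 9)
```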