
BYOLTrainer

class torcheeg.trainers.BYOLTrainer(extractor: Module, extract_channels: int, proj_channels: int = 256, proj_hid_channels: int = 512, lr: float = 0.0001, weight_decay: float = 0.0, moving_average_decay=0.99, devices: int = 1, accelerator: str = 'cpu', metrics: List[str] = ['acc_top1'])

This class supports the implementation of Bootstrap Your Own Latent (BYOL) for self-supervised pre-training.
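In BYOL, an online network (encoder, projector, predictor) learns to predict a target network's projection of a different view of the same sample. The target's weights are an exponential moving average of the online encoder and projector. The sketch below illustrates that objective and update in plain PyTorch. It is not torcheeg's internal code: `byol_loss`, `ema_update` and the toy networks are hypothetical, and it assumes that `moving_average_decay` plays the role of `decay`.

```python
import copy

import torch
import torch.nn.functional as F
from torch import nn


def byol_loss(p: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: normalized MSE between online predictions p and
    target projections z, i.e. 2 - 2 * cos(p, z), averaged over the batch.
    The target branch is detached so no gradient flows into it."""
    p = F.normalize(p, dim=-1)
    z = F.normalize(z.detach(), dim=-1)
    return (2 - 2 * (p * z).sum(dim=-1)).mean()


@torch.no_grad()
def ema_update(online: nn.Module, target: nn.Module, decay: float = 0.99) -> None:
    """Hypothetical helper: target <- decay * target + (1 - decay) * online."""
    for p_o, p_t in zip(online.parameters(), target.parameters()):
        p_t.mul_(decay).add_(p_o, alpha=1 - decay)


# Toy online network (encoder + projector) and a predictor on top of it
online = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))
predictor = nn.Linear(4, 4)
target = copy.deepcopy(online)  # the target starts as a copy of the online network
for param in target.parameters():
    param.requires_grad_(False)

# Random tensors stand in for two augmented views of the same batch
view1, view2 = torch.randn(32, 8), torch.randn(32, 8)
# Symmetric loss: each view's prediction regresses the other view's target projection
loss = byol_loss(predictor(online(view1)), target(view2)) + \
       byol_loss(predictor(online(view2)), target(view1))
loss.backward()
ema_update(online, target, decay=0.99)
```

Only the online branch receives gradients. The target network changes solely through the moving-average update.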

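from torcheeg.trainers import BYOLTrainer

# extractor: any nn.Module mapping one view to a (batch, 256) tensor;
# train_loader/val_loader must yield two-view batches (see the NOTE below)
trainer = BYOLTrainer(extractor,
                      extract_channels=256,
                      devices=1,
                      accelerator='gpu')
trainer.fit(train_loader, val_loader)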
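`extractor` can be any `nn.Module` that maps one view to a `(batch, extract_channels)` tensor. Below is a minimal sketch. It assumes the `(4, 9, 9)` input produced by the DEAP pipeline in the example that follows, which places four frequency bands on a 9×9 electrode grid. `GridExtractor` is a hypothetical name, and its layers are an arbitrary choice rather than a recommended architecture.

```python
import torch
from torch import nn


class GridExtractor(nn.Module):
    """Hypothetical extractor: maps a (batch, 4, 9, 9) band/grid tensor
    to a (batch, out_channels) feature vector."""

    def __init__(self, in_channels: int = 4, out_channels: int = 256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),  # (batch, 128, 1, 1)
            nn.Flatten(),             # (batch, 128)
            nn.Linear(128, out_channels),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


extractor = GridExtractor(out_channels=256)
print(extractor(torch.randn(8, 4, 9, 9)).shape)  # torch.Size([8, 256])
```

`out_channels` here must equal the `extract_channels` passed to `BYOLTrainer`.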

NOTE: The first element of each batch in train_loader and val_loader should be a two-tuple containing two random transformations (views) of the same data. You can use transforms.Contrastive to produce these views.

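from torch.utils.data import DataLoader, random_split
from torcheeg import transforms
from torcheeg.datasets import DEAPDataset
from torcheeg.datasets.constants import DEAP_CHANNEL_LOCATION_DICT
from torcheeg.trainers import BYOLTrainer

contras_dataset = DEAPDataset(
    io_path='./io/deap',
    root_path='./data_preprocessed_python',
    offline_transform=transforms.Compose([
        transforms.BandDifferentialEntropy(sampling_rate=128,
                                           apply_to_baseline=True),
        transforms.BaselineRemoval(),
        transforms.ToGrid(DEAP_CHANNEL_LOCATION_DICT)
    ]),
    online_transform=transforms.Compose([
        transforms.ToTensor(),
        # Contrastive applies the wrapped augmentation num_views times,
        # so each sample yields two randomly augmented views
        transforms.Contrastive(transforms.Compose([
            transforms.RandomMask(p=0.5),
            transforms.RandomNoise(p=0.5)
        ]), num_views=2)
    ]),
    chunk_size=128,
    baseline_chunk_size=128,
    num_baseline=3)

# Illustrative 80/20 split and batch size; any split strategy works
n_train = int(0.8 * len(contras_dataset))
train_dataset, val_dataset = random_split(
    contras_dataset, [n_train, len(contras_dataset) - n_train])
train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=256)

# extractor: any nn.Module mapping one view to a (batch, 256) tensor
trainer = BYOLTrainer(extractor,
                      extract_channels=256,
                      devices=1,
                      accelerator='gpu')
trainer.fit(train_loader, val_loader)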
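With PyTorch's default collate function, a dataset whose items are `((view1, view2), label)` produces batches whose first element is a list of two stacked tensors, one per view. The sketch below checks this. `TwoViewDataset` is a hypothetical stand-in for a dataset using `transforms.Contrastive(..., num_views=2)`, and additive noise substitutes for the real augmentations.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class TwoViewDataset(Dataset):
    """Hypothetical stand-in: each item is ((view1, view2), label)."""

    def __init__(self, n: int = 10):
        self.data = torch.randn(n, 4, 9, 9)

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, idx: int):
        x = self.data[idx]
        # Two independent noisy "augmentations" of the same sample
        view1 = x + 0.1 * torch.randn_like(x)
        view2 = x + 0.1 * torch.randn_like(x)
        return (view1, view2), 0


loader = DataLoader(TwoViewDataset(), batch_size=4)
views, labels = next(iter(loader))
view1, view2 = views  # the per-sample tuples are collated into a list of 2 stacked tensors
print(len(views), view1.shape, labels.shape)
# 2 torch.Size([4, 4, 9, 9]) torch.Size([4])
```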
Parameters:
  • extractor (nn.Module) – The feature extraction model (encoder). BYOL trains it by having an online network predict, from one augmented view of an EEG sample, the target network's representation of another view of the same sample.

  • extract_channels (int) – The feature dimension of the feature extraction model's output.

  • proj_channels (int) – The feature dimension of the projection head's output. (default: 256)

  • proj_hid_channels (int) – The feature dimension of the projection head's hidden layer. (default: 512)

  • lr (float) – The learning rate. (default: 0.0001)

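  • weight_decay (float) – The weight decay. (default: 0.0)

  • moving_average_decay (float) – The decay rate of the exponential moving average used to update the target network from the online network. (default: 0.99)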

  • devices (int) – The number of GPUs to use. (default: 1)

  • accelerator (str) – The accelerator to use. Available options are: ‘cpu’, ‘gpu’. (default: "cpu")

  • metrics (List[str]) – The metrics to use. Available options are: ‘acc_top1’, ‘acc_top5’, ‘acc_mean_pos’. (default: ["acc_top1"])
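These metric names match the ones commonly logged for SimCLR-style training. In that setup, each view's positive partner is ranked by cosine similarity against every other view in the batch. The sketch below assumes torcheeg follows the same definition, which this page does not confirm. `ranking_metrics` is a hypothetical helper. `acc_top1` is the fraction of views whose partner ranks first, `acc_top5` the fraction ranked in the top five, and `acc_mean_pos` the mean 1-based rank.

```python
import torch
import torch.nn.functional as F


def ranking_metrics(z1: torch.Tensor, z2: torch.Tensor) -> dict:
    """Hypothetical helper: rank each view's positive partner among all other views."""
    feats = F.normalize(torch.cat([z1, z2], dim=0), dim=-1)  # (2N, D)
    n2 = feats.shape[0]
    sim = feats @ feats.T                        # pairwise cosine similarities
    sim.fill_diagonal_(float('-inf'))            # ignore self-similarity
    rows = torch.arange(n2)
    pos_idx = (rows + n2 // 2) % n2              # the partner of view i is i +/- N
    pos_sim = sim[rows, pos_idx]
    # 0-based rank = number of other views scoring strictly higher than the positive
    rank = (sim > pos_sim[:, None]).sum(dim=-1)
    return {
        'acc_top1': (rank == 0).float().mean().item(),
        'acc_top5': (rank < 5).float().mean().item(),
        'acc_mean_pos': 1 + rank.float().mean().item(),
    }


print(ranking_metrics(torch.eye(4), torch.eye(4)))
# {'acc_top1': 1.0, 'acc_top5': 1.0, 'acc_mean_pos': 1.0}
```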

fit(train_loader: DataLoader, val_loader: DataLoader, max_epochs: int = 300, *args, **kwargs) → Any

NOTE: The first element of each batch in train_loader and val_loader should be a two-tuple containing two random transformations (views) of the same data. You can use transforms.Contrastive to produce these views.

Parameters:
  • train_loader (DataLoader) – Iterable DataLoader over batches of training data (torch.utils.data.dataloader.DataLoader, torch_geometric.loader.DataLoader, etc.).

  • val_loader (DataLoader) – Iterable DataLoader over batches of validation data (torch.utils.data.dataloader.DataLoader, torch_geometric.loader.DataLoader, etc.).

  • max_epochs (int) – Maximum number of epochs to train the model. (default: 300)
