SimCLRTrainer

class torcheeg.trainers.SimCLRTrainer(extractor: Module, extract_channels: int, proj_channels: int = 256, proj_hid_channels: int = 512, lr: float = 0.0001, weight_decay: float = 0.0, temperature: float = 0.1, devices: int = 1, accelerator: str = 'cpu', metrics: List[str] = ['acc_top1'])

This class implements A Simple Framework for Contrastive Learning of Visual Representations (SimCLR) for self-supervised pre-training of EEG feature extractors.

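import torch
from torcheeg.models import CCNN
from torcheeg.trainers import SimCLRTrainer

class Extractor(CCNN):
    def forward(self, x):
        # Reuse CCNN's convolutional layers and drop its classification head.
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = x.flatten(start_dim=1)
        return x

extractor = Extractor(in_channels=5, num_classes=3)

# extract_channels (required) must equal the size of the flattened extractor output.
# Infer it from a dummy input shaped (batch, bands, height, width); the 9x9 grid
# matches the output of ToGrid(DEAP_CHANNEL_LOCATION_DICT).
with torch.no_grad():
    extract_channels = extractor(torch.randn(1, 5, 9, 9)).shape[1]

trainer = SimCLRTrainer(extractor,
                        extract_channels=extract_channels,
                        devices=1,
                        accelerator='gpu')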

NOTE: The first element of each batch in train_loader and val_loader should be a two-tuple holding two random transformations (views) of the same data. You can use transforms.Contrastive to produce this structure.
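The batch structure this note asks for can be illustrated with plain PyTorch. The sketch below builds a toy dataset whose samples are two randomly augmented views of the same tensor, roughly what transforms.Contrastive with num_views=2 produces. TwoViewDataset and add_noise are hypothetical stand-ins for illustration only, not part of torcheeg.

```python
import torch
from torch.utils.data import DataLoader, Dataset


def add_noise(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical augmentation: add small Gaussian noise.
    return x + 0.1 * torch.randn_like(x)


class TwoViewDataset(Dataset):
    """Hypothetical dataset returning ([view_1, view_2], label) per sample."""

    def __init__(self, data: torch.Tensor, transform):
        self.data = data
        self.transform = transform

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, index: int):
        x = self.data[index]
        # Apply the random transform twice to obtain two independent views.
        views = [self.transform(x), self.transform(x)]
        return views, 0  # the label is unused during contrastive pre-training


# 32 toy samples shaped like 5-band features on a 9x9 electrode grid.
dataset = TwoViewDataset(torch.randn(32, 5, 9, 9), add_noise)
loader = DataLoader(dataset, batch_size=8, shuffle=True)

views, _ = next(iter(loader))
# The default collate function turns the per-sample view lists into a list of
# two batched tensors, one per view.
print(len(views), views[0].shape, views[1].shape)
# 2 torch.Size([8, 5, 9, 9]) torch.Size([8, 5, 9, 9])
```

Each batch's first element is therefore a pair of tensors with identical shapes, where row i of both tensors comes from the same underlying sample.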

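from torch.utils.data import DataLoader, random_split
from torcheeg.datasets import DEAPDataset
from torcheeg import transforms
from torcheeg.datasets.constants import DEAP_CHANNEL_LOCATION_DICT

contras_dataset = DEAPDataset(
    io_path='./io/deap',
    root_path='./data_preprocessed_python',
    offline_transform=transforms.Compose([
        transforms.BandDifferentialEntropy(sampling_rate=128,
                                           apply_to_baseline=True),
        transforms.BaselineRemoval(),
        transforms.ToGrid(DEAP_CHANNEL_LOCATION_DICT)
    ]),
    online_transform=transforms.Compose([
        transforms.ToTensor(),
        # Contrastive applies the random augmentation twice, yielding two views per sample.
        transforms.Contrastive(transforms.Compose([
            transforms.RandomMask(p=0.5),
            transforms.RandomNoise(p=0.5)
        ]), num_views=2)
    ]),
    chunk_size=128,
    baseline_chunk_size=128,
    num_baseline=3)

# Illustrative 80/20 random split; the ratio and batch size are arbitrary choices.
num_train = int(0.8 * len(contras_dataset))
train_dataset, val_dataset = random_split(
    contras_dataset, [num_train, len(contras_dataset) - num_train])
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)

trainer.fit(train_loader, val_loader)

Parameters: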
  • extractor (nn.Module) – The feature extraction model. It learns representations of EEG signals by maximizing the agreement between two augmented views of the same sample while pushing apart views of different samples (SimCLR's contrastive objective).

  • extract_channels (int) – The dimensionality of the features output by the feature extraction model, i.e., the size of its flattened output.

  • proj_channels (int) – The dimensionality of the output of the projection head. (default: 256)

  • proj_hid_channels (int) – The dimensionality of the hidden layer of the projection head. (default: 512)

  • lr (float) – The learning rate. (default: 0.0001)

  • weight_decay (float) – The weight decay. (default: 0.0)

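  • temperature (float) – The temperature that scales the similarities in the contrastive loss; lower values sharpen the softmax over candidate pairs. (default: 0.1)

  • devices (int) – The number of devices (e.g., GPUs) to use. (default: 1)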

  • accelerator (str) – The accelerator to use. Available options are: ‘cpu’, ‘gpu’. (default: "cpu")

  • metrics (List[str]) – The metrics to use. Available options are: ‘acc_top1’, ‘acc_top5’, ‘acc_mean_pos’. (default: ["acc_top1"])
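A minimal sketch of how these parameters fit together, written from the SimCLR paper rather than from torcheeg's source: a projection head maps extractor features of size extract_channels through a hidden layer of size proj_hid_channels to proj_channels, and temperature divides the cosine similarities fed to the NT-Xent contrastive loss. The names ProjectionHead and nt_xent are hypothetical, and the exact layers and loss arrangement inside SimCLRTrainer may differ. The returned accuracy roughly corresponds to the 'acc_top1' metric: the fraction of samples whose most similar candidate is their own second view.

```python
import torch
import torch.nn.functional as F
from torch import nn


class ProjectionHead(nn.Module):
    """Hypothetical MLP head: extract_channels -> proj_hid_channels -> proj_channels."""

    def __init__(self, extract_channels: int, proj_channels: int = 256,
                 proj_hid_channels: int = 512):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(extract_channels, proj_hid_channels),
            nn.ReLU(inplace=True),
            nn.Linear(proj_hid_channels, proj_channels),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


def nt_xent(z1: torch.Tensor, z2: torch.Tensor, temperature: float = 0.1):
    """NT-Xent loss for two batches of projections; also returns top-1 accuracy."""
    batch_size = z1.shape[0]
    z = F.normalize(torch.cat([z1, z2], dim=0), dim=1)  # (2N, D), unit length
    sim = z @ z.T / temperature                         # scaled cosine similarities
    # Exclude each sample's similarity with itself from the candidates.
    self_mask = torch.eye(2 * batch_size, dtype=torch.bool, device=z.device)
    sim = sim.masked_fill(self_mask, float('-inf'))
    # The positive for row i is its other view: i + N in the first half, i - N in the second.
    targets = torch.cat([torch.arange(batch_size, 2 * batch_size),
                         torch.arange(0, batch_size)]).to(z.device)
    loss = F.cross_entropy(sim, targets)
    # Fraction of rows whose most similar candidate is the positive view.
    acc_top1 = (sim.argmax(dim=1) == targets).float().mean()
    return loss, acc_top1


head = ProjectionHead(extract_channels=128)
feats1, feats2 = torch.randn(16, 128), torch.randn(16, 128)  # stand-in extractor outputs
loss, acc = nt_xent(head(feats1), head(feats2), temperature=0.1)
print(loss.item(), acc.item())
```

Treating each row as a (2N - 1)-way classification problem, with the other view as the correct class and every other sample in the batch as a negative, is the standard NT-Xent formulation; a smaller temperature makes that classification sharper.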

fit(train_loader: DataLoader, val_loader: DataLoader, max_epochs: int = 300, *args, **kwargs) → Any

NOTE: The first element of each batch in train_loader and val_loader should be a two-tuple holding two random transformations (views) of the same data. You can use transforms.Contrastive to produce this structure.

Parameters:
  • train_loader (DataLoader) – Iterable DataLoader for traversing the training data batches (torch.utils.data.dataloader.DataLoader, torch_geometric.loader.DataLoader, etc.).

  • val_loader (DataLoader) – Iterable DataLoader for traversing the validation data batches (torch.utils.data.dataloader.DataLoader, torch_geometric.loader.DataLoader, etc.).

  • max_epochs (int) – Maximum number of epochs to train the model. (default: 300)
