{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "# Training models with vanilla PyTorch\nThis tutorial shows how to combine TorchEEG with a custom training loop written in vanilla PyTorch to train a Continuous Convolutional Neural Network (CCNN) for emotion classification on the DEAP dataset.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import logging\nimport os\nimport random\nimport time\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nfrom torch.utils.data.dataloader import DataLoader\nfrom torcheeg import transforms\nfrom torcheeg.datasets import DEAPDataset\nfrom torcheeg.datasets.constants.emotion_recognition.deap import \\\n    DEAP_CHANNEL_LOCATION_DICT\nfrom torcheeg.model_selection import KFoldPerSubject, train_test_split\nfrom torcheeg.models import CCNN"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Pre-experiment Preparation to Ensure Reproducibility\nUse the `logging` module to print training progress to the screen and, at the same time, save it to a timestamped log file for later reference.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "os.makedirs('./tmp_out/examples_torch/log', exist_ok=True)\nlogger = logging.getLogger('Training models with vanilla PyTorch')\nlogger.setLevel(logging.DEBUG)\nconsole_handler = logging.StreamHandler()\ntimeticks = time.strftime(\"%Y-%m-%d-%H-%M-%S\", time.localtime())\nfile_handler = logging.FileHandler(\n    os.path.join('./tmp_out/examples_torch/log', f'{timeticks}.log'))\nlogger.addHandler(console_handler)\nlogger.addHandler(file_handler)"
      ]
    },
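    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Running the cell above a second time attaches another pair of handlers to the same logger, so every message is then printed twice. The following is a minimal sketch, using only the standard library, of one way to avoid this: add a handler only when the logger has none, and give it a formatter. The helper name `get_logger`, the logger name and the format string are illustrative assumptions, not part of TorchEEG.\n\n```python\nimport io\nimport logging\n\n\ndef get_logger(name, stream):\n    logger = logging.getLogger(name)\n    logger.setLevel(logging.DEBUG)\n    # Attach a handler only once, so re-running the cell does not duplicate output.\n    if not logger.handlers:\n        handler = logging.StreamHandler(stream)\n        handler.setFormatter(logging.Formatter('%(levelname)s %(message)s'))\n        logger.addHandler(handler)\n    logger.propagate = False  # do not also forward records to the root logger\n    return logger\n\n\nbuffer = io.StringIO()\nlog = get_logger('demo_guarded_logger', buffer)\nlog = get_logger('demo_guarded_logger', buffer)  # simulate re-running the cell\nlog.info('epoch 1 done')\nprint(buffer.getvalue(), end='')  # INFO epoch 1 done\n```\n\nIn the notebook, the same check could wrap the two `addHandler` calls; a `logging.FileHandler` accepts the same formatter.\n"
      ]
    },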
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
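        "Seed the random number generators of every library involved so that repeated runs produce the same results as far as possible. Some GPU operations remain nondeterministic even with fixed seeds, so bitwise-identical results are not guaranteed.\n\n"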
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
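        "def seed_everything(seed):\n    random.seed(seed)\n    np.random.seed(seed)\n    # Only affects Python processes started after this point (e.g. spawned workers);\n    # the hash seed of the current interpreter is fixed at startup.\n    os.environ[\"PYTHONHASHSEED\"] = str(seed)\n    torch.manual_seed(seed)\n    torch.cuda.manual_seed_all(seed)  # seed every visible GPU, not only the current one\n    # Prefer deterministic cuDNN kernels over autotuned, possibly nondeterministic ones.\n    torch.backends.cudnn.deterministic = True\n    torch.backends.cudnn.benchmark = False\n\n\nseed_everything(42)"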
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Customize the training process\nTorchEEG provides a range of trainers for classification models, but you can also write your own functions to train and evaluate a model. Here is a simple example:\n\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# training process\ndef train(dataloader, model, loss_fn, optimizer):\n    size = len(dataloader.dataset)\n    device = next(model.parameters()).device  # run on the model's device\n    model.train()\n    for batch_idx, batch in enumerate(dataloader):\n        X = batch[0].to(device)\n        y = batch[1].to(device)\n\n        # Compute prediction error\n        pred = model(X)\n        loss = loss_fn(pred, y)\n\n        # Backpropagation\n        optimizer.zero_grad()\n        loss.backward()\n        optimizer.step()\n\n        if batch_idx % 100 == 0:\n            current = batch_idx * len(X)\n            logger.info(f\"Loss: {loss.item():>7f}  [{current:>5d}/{size:>5d}]\")\n\n    # return the loss of the last batch as a Python float\n    return loss.item()\n\n\n# validation process\ndef valid(dataloader, model, loss_fn):\n    size = len(dataloader.dataset)\n    num_batches = len(dataloader)\n    device = next(model.parameters()).device\n    model.eval()\n    loss, correct = 0, 0\n    with torch.no_grad():\n        for batch in dataloader:\n            X = batch[0].to(device)\n            y = batch[1].to(device)\n\n            pred = model(X)\n            loss += loss_fn(pred, y).item()\n            # count samples whose highest-scoring class equals the label\n            correct += (pred.argmax(1) == y).type(torch.float).sum().item()\n    loss /= num_batches  # mean of the per-batch losses\n    correct /= size  # accuracy over all samples\n    logger.info(\n        f\"Valid Error: \\n Accuracy: {(100*correct):>0.1f}%, Avg loss: {loss:>8f} \\n\"\n    )\n\n    return correct, loss"
      ]
    },
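    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The accuracy and loss that `valid` reports can be reproduced by hand. The sketch below uses NumPy and a hypothetical batch of logits for 4 samples and 2 classes; the cross-entropy mirrors the default mean reduction of `nn.CrossEntropyLoss` (log-softmax followed by the negative log-likelihood of the true class).\n\n```python\nimport numpy as np\n\n# Hypothetical model outputs (logits) for 4 samples and 2 classes.\nlogits = np.array([[2.0, -1.0], [0.1, 0.3], [-0.5, 1.5], [1.2, 0.4]])\nlabels = np.array([0, 1, 0, 0])\n\n# Accuracy: fraction of samples whose highest-scoring class equals the label.\npred = logits.argmax(axis=1)  # [0, 1, 1, 0]\naccuracy = float((pred == labels).mean())\nprint(accuracy)  # 0.75\n\n# Cross-entropy: numerically stable log-softmax, then mean negative log-likelihood.\nshifted = logits - logits.max(axis=1, keepdims=True)\nlog_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))\nloss = float(-log_probs[np.arange(len(labels)), labels].mean())\nprint(round(loss, 4))\n```\n\nNote that `valid` averages per-batch mean losses, which differs slightly from this per-sample mean when the last batch is smaller than the others.\n"
      ]
    },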
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Building Deep Learning Pipelines Using TorchEEG\nStep 1: Initialize the Dataset\n\nWe use the DEAP dataset supported by TorchEEG. Each EEG sample is a 1-second window containing 128 data points (the preprocessed DEAP recordings are sampled at 128 Hz). Each trial is preceded by a 3-second baseline recording, which is cut into three 1-second segments that are averaged to obtain the baseline signal of the trial. In offline preprocessing, the signal of every electrode is decomposed into 4 sub-bands and the differential entropy of each sub-band is computed as a feature; the baseline features are then subtracted (baseline removal), and the electrodes are mapped onto a 9 × 9 grid. The preprocessed samples are cached on disk under `io_path`, so later runs can skip this step. In online processing, each sample is converted into a tensor before it is fed to the network. For the labels, the valence rating is selected and binarized with a threshold of 5.0.\n\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "dataset = DEAPDataset(\n    io_path=f'./tmp_out/examples_torch/deap',\n    root_path='./tmp_in/data_preprocessed_python',\n    offline_transform=transforms.Compose([\n        transforms.BandDifferentialEntropy(apply_to_baseline=True),\n        transforms.BaselineRemoval(),\n        transforms.ToGrid(DEAP_CHANNEL_LOCATION_DICT)\n    ]),\n    online_transform=transforms.ToTensor(),\n    label_transform=transforms.Compose([\n        transforms.Select('valence'),\n        transforms.Binary(5.0),\n    ]),\n    num_worker=8)"
      ]
    },
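    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The offline pipeline can be illustrated on synthetic data. The sketch below is an approximation under stated assumptions, not TorchEEG's implementation: it skips band-pass filtering (the random arrays stand in for already band-filtered signals), and it uses the closed-form differential entropy of a Gaussian, 0.5 * ln(2 * pi * e * variance), whose log base and constants may differ from `BandDifferentialEntropy`. The baseline handling follows the description above.\n\n```python\nimport numpy as np\n\nSAMPLING_RATE = 128  # preprocessed DEAP signals are sampled at 128 Hz\nN_BANDS, N_ELECTRODES = 4, 32\n\n\ndef differential_entropy(x):\n    # Gaussian differential entropy along the time axis: 0.5 * ln(2 * pi * e * variance)\n    return 0.5 * np.log(2 * np.pi * np.e * np.var(x, axis=-1))\n\n\nrng = np.random.default_rng(0)\n# Synthetic stand-ins for band-filtered signals, shaped (bands, electrodes, samples).\nbaseline = rng.normal(scale=1.0, size=(N_BANDS, N_ELECTRODES, 3 * SAMPLING_RATE))\nsample = rng.normal(scale=2.0, size=(N_BANDS, N_ELECTRODES, SAMPLING_RATE))\n\n# Cut the 3-s baseline into three 1-s segments and average them point by point.\nbaseline_1s = baseline.reshape(N_BANDS, N_ELECTRODES, 3, SAMPLING_RATE).mean(axis=2)\n\n# Baseline removal: subtract the baseline features from the sample features.\nfeatures = differential_entropy(sample) - differential_entropy(baseline_1s)\nprint(features.shape)  # (4, 32)\n```\n\n`ToGrid` then places the 32 electrode values of each band at their positions on the 9 × 9 grid, which yields the 4 × 9 × 9 input expected by the CCNN in Step 3.\n"
      ]
    },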
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "<div class=\"alert alert-danger\"><h4>Warning</h4><p>If you use TorchEEG on Windows with multiple processes (for example, <code>num_worker</code> in a dataset or <code>num_workers</code> in a DataLoader), wrap your code in an <code>if __name__ == \"__main__\":</code> block. Windows starts each worker process by re-importing the main module, so without this guard the top-level code, including dataset creation, would run again in every worker, which can cause errors.</p></div>\n\nThat is, on Windows you need to write:\n```\nif __name__ == \"__main__\":\n    dataset = DEAPDataset(io_path='./tmp_out/examples_torch/deap',\n                          root_path='./tmp_in/data_preprocessed_python',\n                          offline_transform=transforms.Compose([\n                              transforms.BandDifferentialEntropy(apply_to_baseline=True),\n                              transforms.BaselineRemoval(),\n                              transforms.ToGrid(DEAP_CHANNEL_LOCATION_DICT)\n                          ]),\n                          online_transform=transforms.ToTensor(),\n                          label_transform=transforms.Compose([\n                              transforms.Select('valence'),\n                              transforms.Binary(5.0),\n                          ]),\n                          io_mode='pickle',\n                          num_worker=8)\n    # the rest of the code goes here\n```\n<div class=\"alert alert-info\"><h4>Note</h4><p>LMDB may not perform well on some Windows systems or storage devices. If data preprocessing is slow, consider setting <code>io_mode</code> to <code>'pickle'</code>, as in the example above; this is an alternative storage backend that TorchEEG implements on top of pickle.</p></div>\n\n"
      ]
    },
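    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The guard matters because, with the `spawn` start method (the default on Windows), every worker process imports the main module again. The following standard-library sketch (the function names are illustrative) shows the pattern: only definitions live at module level, and anything that starts worker processes runs under the guard.\n\n```python\nimport multiprocessing as mp\n\n\ndef square(x):\n    # Module-level function, so spawned workers can import it.\n    return x * x\n\n\ndef main():\n    # Creating the pool starts worker processes; under 'spawn' each one\n    # re-imports this module, which is safe because main() is not called on import.\n    with mp.Pool(processes=2) as pool:\n        print(pool.map(square, range(5)))  # [0, 1, 4, 9, 16]\n\n\nif __name__ == '__main__':\n    main()\n```\n"
      ]
    },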
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Step 2: Divide the Training and Test Samples in the Dataset\n\nHere, the dataset is divided using per-subject 5-fold cross-validation: the EEG samples of each subject are split into 5 folds separately from those of the other subjects. In each round, 4 folds serve as training samples and the remaining fold serves as test samples.\n\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "k_fold = KFoldPerSubject(n_splits=5,\n                         split_path='./tmp_out/examples_ccnn/split',\n                         shuffle=True)"
      ]
    },
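    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make the per-subject split concrete, here is a small NumPy sketch of the idea: the sample indices of each subject are shuffled and cut into folds on their own, so in this sketch each round trains and tests on samples of a single subject. The function `per_subject_kfold` and the toy subject IDs are hypothetical; TorchEEG's `KFoldPerSubject` also caches its splits under `split_path`, and its exact splitting logic may differ.\n\n```python\nimport numpy as np\n\n\ndef per_subject_kfold(subject_ids, n_splits=5, seed=42):\n    # Yield (subject, fold, train_idx, test_idx); folds are formed within each subject.\n    subject_ids = np.asarray(subject_ids)\n    rng = np.random.default_rng(seed)\n    for subject in np.unique(subject_ids):\n        idx = np.flatnonzero(subject_ids == subject)\n        rng.shuffle(idx)\n        folds = np.array_split(idx, n_splits)\n        for k in range(n_splits):\n            train_idx = np.concatenate([f for j, f in enumerate(folds) if j != k])\n            yield subject, k, np.sort(train_idx), np.sort(folds[k])\n\n\n# Toy example: 2 subjects with 10 samples each.\nids = np.repeat(['s01', 's02'], 10)\nsplits = list(per_subject_kfold(ids, n_splits=5))\nprint(len(splits))  # 10 rounds = 2 subjects x 5 folds\nsubject, fold, train_idx, test_idx = splits[0]\nprint(subject, len(train_idx), len(test_idx))  # s01 8 2\n```\n"
      ]
    },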
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Step 3: Define the Model and Start Training\n\nWe loop over the cross-validation splits. For each split, we initialize a new CCNN model and set its hyperparameters: each EEG sample has 4 input channels (the differential-entropy features of the 4 sub-bands) on a 9 × 9 grid, and there are 2 output classes (low and high valence). A further 20% of the training samples is held out as a validation set.\n\nNext, we train the model for 50 epochs using the training function defined above, evaluate it on the validation set after every epoch with the validation function, and keep the checkpoint with the highest validation accuracy. Finally, this best checkpoint is evaluated on the test set of the split, and the test results are averaged over all splits.\n\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\nloss_fn = nn.CrossEntropyLoss()\nbatch_size = 64\nos.makedirs('./tmp_out/examples_ccnn', exist_ok=True)  # checkpoints are saved here\n\ntest_accs = []\ntest_losses = []\n\nfor i, (train_dataset, test_dataset) in enumerate(k_fold.split(dataset)):\n    # initialize model\n    model = CCNN(num_classes=2, in_channels=4, grid_size=(9, 9)).to(device)\n    # initialize optimizer\n    optimizer = torch.optim.Adam(model.parameters(),\n                                 lr=1e-4)  # official: weight_decay=5e-1\n    # hold out 20% of the training samples for validation\n    train_dataset, val_dataset = train_test_split(\n        train_dataset,\n        test_size=0.2,\n        split_path=f'./tmp_out/examples_ccnn/split{i}',\n        shuffle=True)\n    train_loader = DataLoader(train_dataset,\n                              batch_size=batch_size,\n                              shuffle=True)\n    # evaluation does not need shuffling\n    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)\n\n    epochs = 50\n    best_val_acc = -1.0  # ensures the first epoch is always checkpointed\n    for t in range(epochs):\n        train_loss = train(train_loader, model, loss_fn, optimizer)\n        val_acc, val_loss = valid(val_loader, model, loss_fn)\n        # save the best model based on val_acc\n        if val_acc > best_val_acc:\n            best_val_acc = val_acc\n            torch.save(model.state_dict(),\n                       f'./tmp_out/examples_ccnn/model{i}.pt')\n\n    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)\n\n    # load the best model to test on test set\n    model.load_state_dict(\n        torch.load(f'./tmp_out/examples_ccnn/model{i}.pt', map_location=device))\n    test_acc, test_loss = valid(test_loader, model, loss_fn)\n\n    # log the test result\n    logger.info(\n        f\"Test Error {i}: \\n Accuracy: {(100*test_acc):>0.1f}%, Avg loss: {test_loss:>8f}\"\n    )\n\n    test_accs.append(test_acc)\n    test_losses.append(test_loss)\n\n# log the average test result on cross-validation datasets\nlogger.info(\n    f\"Test Error: \\n Accuracy: {100*np.mean(test_accs):>0.1f}%, Avg loss: {np.mean(test_losses):>8f}\"\n)"
      ]
    }
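    ,
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The final cell reports only the mean test accuracy across splits. A common addition is the standard deviation, which shows how much the result varies between splits. The sketch below uses NumPy; the helper `summarize` is hypothetical, and the accuracies passed to it are placeholder values, not experimental results.\n\n```python\nimport numpy as np\n\n\ndef summarize(accs):\n    # Mean and population standard deviation (np.std default, ddof=0), as percentages.\n    accs = np.asarray(accs, dtype=float)\n    return 100 * accs.mean(), 100 * accs.std()\n\n\nmean_acc, std_acc = summarize([0.80, 0.90, 0.85, 0.75, 0.95])  # placeholder values\nprint(f'Accuracy: {mean_acc:.1f}% +/- {std_acc:.1f}%')  # Accuracy: 85.0% +/- 7.1%\n```\n\nIn the notebook, `summarize(test_accs)` could be logged next to the existing mean.\n"
      ]
    }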
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.8.6"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}