{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "# Training models with PyTorch Geometric\nThis example shows how to use TorchEEG together with a custom training loop written in vanilla PyTorch to train a graph neural network built with PyTorch Geometric (here, a graph attention network) for emotion classification on the SEED dataset.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import logging\nimport os\nimport random\nimport time\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom torch.nn import Linear\nfrom torch_geometric.loader import DataLoader\nfrom torch_geometric.nn import GATConv, global_mean_pool\n\nfrom torcheeg import transforms\nfrom torcheeg.datasets import SEEDFeatureDataset\nfrom torcheeg.datasets.constants.emotion_recognition.seed import \\\n    SEED_ADJACENCY_MATRIX\nfrom torcheeg.model_selection import KFoldPerSubjectGroupbyTrial\nfrom torcheeg.transforms.pyg import ToG"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Pre-experiment Preparation to Ensure Reproducibility\nUse the logging module to print output to the screen and, at the same time, save it to a log file for later reference.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "os.makedirs('./tmp_out/examples_torch_geometric/log', exist_ok=True)\nlogger = logging.getLogger('Training models with PyTorch Geometric')\nlogger.setLevel(logging.DEBUG)\n# print log messages to the console...\nconsole_handler = logging.StreamHandler()\ntimeticks = time.strftime(\"%Y-%m-%d-%H-%M-%S\", time.localtime())\n# ...and also write them to a timestamped log file\nfile_handler = logging.FileHandler(\n    os.path.join('./tmp_out/examples_torch_geometric/log', f'{timeticks}.log'))\nlogger.addHandler(console_handler)\nlogger.addHandler(file_handler)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
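        "Set the random seed in every module that uses randomness so that repeated runs produce the same results as far as possible. Some GPU operations are nondeterministic, so exact reproducibility on CUDA is not guaranteed.\n\n"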
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "def seed_everything(seed):\n    random.seed(seed)\n    np.random.seed(seed)\n    os.environ[\"PYTHONHASHSEED\"] = str(seed)\n    torch.manual_seed(seed)\n    # seed every visible GPU, not only the current one\n    torch.cuda.manual_seed_all(seed)\n    # ask cuDNN for deterministic algorithms\n    torch.backends.cudnn.deterministic = True\n    torch.backends.cudnn.benchmark = False\n\n\nseed_everything(42)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
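        "## Defining a graph neural network\nUse the API provided by PyTorch Geometric to define the graph neural network. Each electrode is a node: its EEG signal or features form the node features :obj:`data.x`, and the connections between electrodes form the edges :obj:`data.edge_index`. What an edge means depends on how the adjacency matrix is defined; for example, it may connect electrodes that are spatially adjacent on the scalp. The model below stacks :obj:`GATConv` layers, averages the node embeddings of each graph with :obj:`global_mean_pool`, and classifies the pooled vector with two linear layers.\n\n\n"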
      ]
    },
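    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make :obj:`data.x` and :obj:`data.edge_index` concrete, the following NumPy sketch converts a small adjacency matrix into the two-row (source, target) edge list that PyTorch Geometric uses. The 4-electrode matrix and the helper :obj:`adjacency_to_edge_index` are hypothetical illustrations; the sketch does not reproduce TorchEEG's :obj:`ToG`, which may additionally handle self-loops, edge weights, or thresholds.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import numpy as np\n\n# Hypothetical 4-electrode adjacency matrix (1 = connected); the real\n# SEED_ADJACENCY_MATRIX covers all 62 SEED electrodes.\nadj = np.array([\n    [0, 1, 0, 1],\n    [1, 0, 1, 0],\n    [0, 1, 0, 0],\n    [1, 0, 0, 0],\n])\n\n\ndef adjacency_to_edge_index(adj):\n    # Every nonzero entry adj[i, j] becomes an edge i -> j; stacking the\n    # row and column indices gives the (2, num_edges) layout of data.edge_index.\n    src, dst = np.nonzero(adj)\n    return np.stack([src, dst], axis=0)\n\n\n# Node features: one row per electrode, one column per frequency band\n# (5 DE features per electrode, as in the SEED feature dataset).\nx = np.random.rand(adj.shape[0], 5)\nedge_index = adjacency_to_edge_index(adj)\nprint(x.shape)     # (4, 5)\nprint(edge_index)  # [[0 0 1 1 2 3]\n                   #  [1 3 0 2 1 0]]"
      ]
    },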
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "class GNN(torch.nn.Module):\n    def __init__(self,\n                 in_channels=5,\n                 num_layers=3,\n                 hid_channels=64,\n                 num_classes=3):\n        super().__init__()\n        self.conv1 = GATConv(in_channels, hid_channels)\n        self.convs = torch.nn.ModuleList()\n        for _ in range(num_layers - 1):\n            self.convs.append(GATConv(hid_channels, hid_channels))\n        self.lin1 = Linear(hid_channels, hid_channels)\n        self.lin2 = Linear(hid_channels, num_classes)\n\n    def reset_parameters(self):\n        self.conv1.reset_parameters()\n        for conv in self.convs:\n            conv.reset_parameters()\n        self.lin1.reset_parameters()\n        self.lin2.reset_parameters()\n\n    def forward(self, data):\n        x, edge_index, batch = data.x, data.edge_index, data.batch\n        x = F.relu(self.conv1(x, edge_index))\n        for conv in self.convs:\n            x = F.relu(conv(x, edge_index))\n        x = global_mean_pool(x, batch)\n        x = F.relu(self.lin1(x))\n        x = F.dropout(x, p=0.5, training=self.training)\n        x = self.lin2(x)\n        return x"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Customize the training process\nTorchEEG provides many trainers for training classification models. However, you can also write your own training and validation functions. Here is a simple example:\n\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# training process\ndef train(dataloader, model, loss_fn, optimizer):\n    size = len(dataloader.dataset)\n    model.train()\n    for batch_idx, batch in enumerate(dataloader):\n        # batch[0] is a torch_geometric Batch that merges the graphs of the\n        # mini-batch; batch[1] holds the corresponding labels\n        X = batch[0].to(device)\n        y = batch[1].to(device)\n\n        # Compute prediction error\n        pred = model(X)\n        loss = loss_fn(pred, y)\n\n        # Backpropagation\n        optimizer.zero_grad()\n        loss.backward()\n        optimizer.step()\n\n        if batch_idx % 20 == 0:\n            loss, current = loss.item(), batch_idx * len(y)\n            logger.info(f\"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]\")\n\n\n# validation process\ndef valid(dataloader, model, loss_fn):\n    size = len(dataloader.dataset)\n    num_batches = len(dataloader)\n    model.eval()\n    val_loss, correct = 0, 0\n    with torch.no_grad():\n        for batch in dataloader:\n            X = batch[0].to(device)\n            y = batch[1].to(device)\n\n            pred = model(X)\n            val_loss += loss_fn(pred, y).item()\n            correct += (pred.argmax(1) == y).type(torch.float).sum().item()\n    # average the loss over batches and the accuracy over samples\n    val_loss /= num_batches\n    correct /= size\n    logger.info(\n        f\"Validation: \\n Accuracy: {(100*correct):>0.1f}%, Avg loss: {val_loss:>8f} \\n\"\n    )"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Building Deep Learning Pipelines Using TorchEEG\nStep 1: Initialize the Dataset\n\nWe use the SEED dataset supported by TorchEEG, in its precomputed-feature form. In this feature dataset, each one-second segment of EEG (200 samples) is converted into differential entropy (DE) features in five frequency sub-bands, which are then smoothed with a linear dynamical system. The label transform selects the :obj:`emotion` label and shifts SEED's labels from -1, 0, 1 to 0, 1, 2, the class indices expected by the cross-entropy loss.\n<div class=\"alert alert-info\"><h4>Note</h4><p>During online processing, each sample is first min-max normalized along its last axis and then converted into a graph that a PyTorch Geometric model can process, i.e., a :obj:`torch_geometric.data.Data` object, built from the adjacency matrix and the signals or features of the electrodes. Electrodes are the nodes, and the adjacency matrix defines the connections between them.</p></div>\n\n\n"
      ]
    },
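    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As an illustration of the DE features described above, the sketch below computes one DE value per electrode and band for a synthetic one-second window. It relies on the closed-form entropy of a Gaussian signal, $h = \\frac{1}{2}\\log(2\\pi e\\sigma^2)$, and band-limits the signal by zeroing FFT bins. The band edges (1-3, 4-7, 8-13, 14-30 and 31-50 Hz) are assumptions based on the usual SEED description, the FFT masking is a simplification of proper band-pass filtering, and the linear-dynamical-system smoothing is omitted, so the values will not match the released features.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import numpy as np\n\nFS = 200  # sampling rate (Hz) of the preprocessed SEED signals\n# Assumed delta, theta, alpha, beta and gamma band edges in Hz (inclusive)\nBANDS = [(1, 3), (4, 7), (8, 13), (14, 30), (31, 50)]\n\n\ndef band_filter(signal, low, high, fs=FS):\n    # Zero every FFT bin outside [low, high] Hz, then transform back.\n    spectrum = np.fft.rfft(signal, axis=-1)\n    freqs = np.fft.rfftfreq(signal.shape[-1], d=1.0 / fs)\n    spectrum[..., (freqs < low) | (freqs > high)] = 0\n    return np.fft.irfft(spectrum, n=signal.shape[-1], axis=-1)\n\n\ndef differential_entropy(signal):\n    # DE of a Gaussian with variance s^2 is 0.5 * log(2 * pi * e * s^2).\n    return 0.5 * np.log(2 * np.pi * np.e * np.var(signal, axis=-1))\n\n\ndef de_features(window):\n    # window: (num_electrodes, FS) array holding one second of EEG.\n    # Returns a (num_electrodes, 5) array with one DE value per band.\n    return np.stack(\n        [differential_entropy(band_filter(window, lo, hi)) for lo, hi in BANDS],\n        axis=-1)\n\n\nrng = np.random.default_rng(0)\nwindow = rng.standard_normal((62, FS))  # synthetic 62-electrode window\nfeatures = de_features(window)\nprint(features.shape)  # (62, 5)"
      ]
    },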
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "dataset = SEEDFeatureDataset(io_path='./tmp_out/examples_torch_geometric/seed',\n                             root_path='./tmp_in/ExtractedFeatures',\n                             feature=['de_movingAve'],\n                             online_transform=transforms.Compose([\n                                 transforms.MinMaxNormalize(axis=-1),\n                                 ToG(SEED_ADJACENCY_MATRIX)\n                             ]),\n                             label_transform=transforms.Compose([\n                                 transforms.Select('emotion'),\n                                 transforms.Lambda(lambda x: int(x) + 1),\n                             ]),\n                             num_worker=8)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "<div class=\"alert alert-danger\"><h4>Warning</h4><p>If you use TorchEEG on :obj:`Windows` and want to use multiple processes (for example, in the dataset or the dataloader), guard the entry point with :obj:`if __name__ == '__main__':`. Windows starts worker processes by re-importing the main script, so unguarded top-level code would run again in every worker and cause errors.</p></div>\n\nThat is, on :obj:`Windows`, you need to write:\n```\nif __name__ == \"__main__\":\n    dataset = SEEDFeatureDataset(io_path='./tmp_out/examples_torch_geometric/seed',\n                                 root_path='./tmp_in/ExtractedFeatures',\n                                 feature=['de_movingAve'],\n                                 online_transform=transforms.Compose([\n                                     transforms.MinMaxNormalize(axis=-1),\n                                     ToG(SEED_ADJACENCY_MATRIX)\n                                 ]),\n                                 label_transform=transforms.Compose([\n                                     transforms.Select('emotion'),\n                                     transforms.Lambda(lambda x: int(x) + 1),\n                                 ]),\n                                 io_mode='pickle',\n                                 num_worker=8)\n    # ... the rest of the code goes here\n```\n<div class=\"alert alert-info\"><h4>Note</h4><p>LMDB may perform poorly on some Windows systems or storage devices. If data preprocessing is slow, consider setting :obj:`io_mode` to :obj:`pickle`, as in the code above; this is an alternative storage backend implemented by TorchEEG on top of pickle.</p></div>\n\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Step 2: Divide the Training and Test samples in the Dataset\n\nHere, the dataset is divided using per-subject 10-fold cross-validation (:obj:`n_splits=10`). The training and test sets are split separately within each subject's EEG samples: in each round, 9 folds serve as training samples and the remaining fold as test samples.\n\n\n"
      ]
    },
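    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The sketch below illustrates one way such a per-subject split can be organized: within each subject, every trial's samples are cut into :obj:`n_splits` contiguous chunks (no shuffling, as with :obj:`shuffle=False`), chunk :obj:`k` of every trial is held out for testing, and the rest is used for training. The helper :obj:`per_subject_kfold_by_trial` and the toy subject and trial IDs are hypothetical; this is an assumption about the splitting scheme, not TorchEEG's implementation, whose details may differ.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import numpy as np\n\n\ndef per_subject_kfold_by_trial(subject_ids, trial_ids, n_splits=10):\n    # Hypothetical helper: yields (subject, fold, train_idx, test_idx).\n    subject_ids = np.asarray(subject_ids)\n    trial_ids = np.asarray(trial_ids)\n    for subject in np.unique(subject_ids):\n        for k in range(n_splits):\n            train_idx, test_idx = [], []\n            for trial in np.unique(trial_ids[subject_ids == subject]):\n                # indices of this subject's samples that belong to this trial\n                idx = np.flatnonzero((subject_ids == subject) & (trial_ids == trial))\n                chunks = np.array_split(idx, n_splits)\n                test_idx.extend(chunks[k])\n                for j, chunk in enumerate(chunks):\n                    if j != k:\n                        train_idx.extend(chunk)\n            yield subject, k, np.array(train_idx), np.array(test_idx)\n\n\n# Toy example: 2 subjects x 2 trials x 10 samples per trial\nsubjects = np.repeat([0, 1], 20)\ntrials = np.tile(np.repeat([0, 1], 10), 2)\nsplits = list(per_subject_kfold_by_trial(subjects, trials, n_splits=10))\nprint(len(splits))  # 20 (2 subjects x 10 folds)\nsubject, fold, train_idx, test_idx = splits[0]\nprint(len(train_idx), len(test_idx))  # 18 2"
      ]
    },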
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "k_fold = KFoldPerSubjectGroupbyTrial(\n    n_splits=10,\n    split_path='./tmp_out/examples_torch_geometric/split',\n    shuffle=False)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Step 3: Define the Model and Start Training\n\nWe loop over the splits produced by :obj:`k_fold.split(dataset)`. For each split, we initialize a new GNN model and optimizer so that no weights carry over between folds.\n\nNext, we train the model for 50 epochs using the training function defined above and, after each epoch, report its performance on the validation set with the validation function defined above.\n\n\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "<div class=\"alert alert-info\"><h4>Note</h4><p>Because ToG returns each EEG sample as a :obj:`torch_geometric.data.Data` object, which represents a graph, PyTorch's default DataLoader cannot collate these samples into a batch. Use :obj:`torch_geometric.loader.DataLoader` instead: it merges the graphs of a mini-batch into one large disconnected graph by concatenating their node features, shifting their edge indices, and recording in :obj:`data.batch` which graph each node belongs to.</p></div>\n\n\n"
      ]
    },
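    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The following NumPy sketch shows the idea behind this batching for two toy graphs, together with the per-graph mean pooling that :obj:`global_mean_pool` performs in :obj:`GNN.forward`. The helpers :obj:`collate_graphs` and :obj:`mean_pool` are hypothetical simplifications; PyTorch Geometric's own batching also handles edge attributes and other graph properties.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import numpy as np\n\n\ndef collate_graphs(graphs):\n    # Merge (x, edge_index) pairs into one disconnected graph: stack the node\n    # features, shift each graph's edge indices by the number of nodes that\n    # precede it, and record in batch which graph every node came from.\n    xs, edges, batch = [], [], []\n    offset = 0\n    for graph_id, (x, edge_index) in enumerate(graphs):\n        xs.append(x)\n        edges.append(edge_index + offset)\n        batch.append(np.full(x.shape[0], graph_id))\n        offset += x.shape[0]\n    return np.concatenate(xs), np.concatenate(edges, axis=1), np.concatenate(batch)\n\n\ndef mean_pool(x, batch):\n    # Average the node features of each graph (one output row per graph).\n    num_graphs = batch.max() + 1\n    return np.stack([x[batch == g].mean(axis=0) for g in range(num_graphs)])\n\n\ng1 = (np.ones((3, 5)), np.array([[0, 1], [1, 2]]))  # 3 nodes, edges 0->1, 1->2\ng2 = (np.full((2, 5), 3.0), np.array([[0], [1]]))   # 2 nodes, edge 0->1\nx, edge_index, batch = collate_graphs([g1, g2])\nprint(edge_index)  # [[0 1 3]\n                   #  [1 2 4]]\nprint(batch)       # [0 0 0 1 1]\nprint(mean_pool(x, batch)[:, 0])  # [1. 3.]"
      ]
    },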
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\nloss_fn = nn.CrossEntropyLoss()\nbatch_size = 64\n\nfor i, (train_dataset, val_dataset) in enumerate(k_fold.split(dataset)):\n    # initialize model\n    model = GNN().to(device)\n    # initialize optimizer\n    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)\n\n    train_loader = DataLoader(train_dataset,\n                              batch_size=batch_size,\n                              shuffle=True)\n    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)\n\n    epochs = 50\n    for t in range(epochs):\n        logger.info(f\"Epoch {t+1}\\n-------------------------------\")\n        train(train_loader, model, loss_fn, optimizer)\n        valid(val_loader, model, loss_fn)"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.8.6"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}