{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "# GAN with the DEAP Dataset\nIn this example, we show how to use TorchEEG to train Generative Adversarial Networks (GAN) on the DEAP dataset for controllable EEG augmentation (which generates EEG signals conforming to the given labels).\n\nHere, a zero-sum game is played to optimize a generator and a discriminator. The generator is used to generate EEG signal samples according to the given labels, and the discriminator is used to distinguish generated samples from real samples. By confusing the discriminator, the generator will be able to produce samples that approximate the real distribution.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import logging\nimport os\nimport random\nimport time\n\nimport numpy as np\nimport torch\nimport torch.autograd as autograd\nfrom tensorboardX import SummaryWriter\nfrom torch.utils.data.dataloader import DataLoader\nfrom torcheeg import transforms\nfrom torcheeg.datasets import DEAPDataset\nfrom torcheeg.datasets.constants.emotion_recognition.deap import (\n    DEAP_CHANNEL_LIST, DEAP_CHANNEL_LOCATION_DICT)\nfrom torcheeg.model_selection import KFoldGroupbyTrial\nfrom torcheeg.models import BCDiscriminator, BCGenerator\nfrom torcheeg.trainers import CGANTrainer\nfrom torcheeg.utils import plot_feature_topomap"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Pre-experiment Preparation to Ensure Reproducibility\nUse the logging module to store output in a log file for easy reference while printing it to the screen.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "os.makedirs('./tmp_out/examples_deap_gan/log', exist_ok=True)\nlogger = logging.getLogger('GAN with the DEAP Dataset')\nlogger.setLevel(logging.DEBUG)\nconsole_handler = logging.StreamHandler()\ntimeticks = time.strftime(\"%Y-%m-%d-%H-%M-%S\", time.localtime())\nfile_handler = logging.FileHandler(\n    os.path.join('./tmp_out/examples_deap_gan/log', f'{timeticks}.log'))\nlogger.addHandler(console_handler)\nlogger.addHandler(file_handler)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Set the random number seed in all modules to guarantee the same result when running again.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "def seed_everything(seed):\n    random.seed(seed)\n    np.random.seed(seed)\n    os.environ[\"PYTHONHASHSEED\"] = str(seed)\n    torch.manual_seed(seed)\n    torch.cuda.manual_seed(seed)\n    torch.backends.cudnn.deterministic = True\n    torch.backends.cudnn.benchmark = False\n\n\nseed_everything(42)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Customize Trainer\nTorchEEG provides many trainers to help complete the training of classification models, generative models and cross-domain methods. Here we choose the conditional generative adversarial network trainer (CGANTrainer), inherit from it, and override the :obj:`log` function so that log messages are saved with the logger defined above; the :obj:`before_validation_epoch` hook function is overridden to record the current epoch index; the :obj:`on_validation_step` hook function is overridden to visualize the EEG signals generated during validation using TensorBoard and the utility function :obj:`plot_feature_topomap` in TorchEEG; other hook functions can also be overridden to meet special needs.\n\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "writer = SummaryWriter(log_dir='./tmp_out/examples_deap_gan/vis',\n                       comment='examples_deap_gan')\n\n\ndef gradient_penalty(model, real, fake, label=None):\n    device = real.device\n    real = real.data\n    fake = fake.data\n    alpha = torch.rand(real.size(0), *([1] * (len(real.shape) - 1))).to(device)\n    inputs = alpha * real + ((1 - alpha) * fake)\n    inputs.requires_grad_()\n\n    if label is None:\n        outputs = model(inputs)\n    else:\n        outputs = model(inputs, label)\n\n    gradient = autograd.grad(outputs=outputs,\n                             inputs=inputs,\n                             grad_outputs=torch.ones_like(outputs).to(device),\n                             create_graph=True,\n                             retain_graph=True,\n                             only_inputs=True)[0]\n\n    gradient = gradient.flatten(1)\n    return ((gradient.norm(2, dim=1) - 1)**2).mean()\n\n\nclass MyCGANTrainer(CGANTrainer):\n    def log(self, *args, **kwargs):\n        if self.is_main:\n            logger.info(*args, **kwargs)\n\n    def before_validation_epoch(self, epoch_id, num_epochs, **kwargs):\n        # record the current epoch\n        self.cur_epoch = epoch_id\n        self.val_g_loss.reset()\n        self.val_d_loss.reset()\n\n    def on_validation_step(self, val_batch, batch_id, num_batches, **kwargs):\n        X = val_batch[0].to(self.device)\n        y = val_batch[1].to(self.device)\n\n        # for g_loss\n        z = torch.normal(mean=0,\n                         std=1,\n                         size=(X.shape[0],\n                               self.modules['generator'].in_channels)).to(\n                                   self.device)\n        gen_X = self.modules['generator'](z, y)\n        g_loss = -torch.mean(self.modules['discriminator'](gen_X, y))\n\n        # for d_loss\n        real_loss = self.modules['discriminator'](X, y)\n        fake_loss = self.modules['discriminator'](gen_X.detach(), y)\n        
gp_term = gradient_penalty(self.modules['discriminator'], X, gen_X, y)\n        d_loss = -torch.mean(real_loss) + torch.mean(\n            fake_loss) + self.lambd * gp_term\n\n        self.val_g_loss.update(g_loss)\n        self.val_d_loss.update(d_loss)\n\n        vis_batch = num_batches // 20\n        if batch_id % vis_batch == 0:\n            t = transforms.ToInterpolatedGrid(DEAP_CHANNEL_LOCATION_DICT)\n            # center should be 0.0\n            signal = t.reverse(eeg=gen_X[y == 0][0].detach().cpu().numpy() -\n                               0.5)['eeg']\n            top_img = plot_feature_topomap(\n                torch.tensor(signal),\n                channel_list=DEAP_CHANNEL_LIST,\n                feature_list=[\"theta\", \"alpha\", \"beta\", \"gamma\"])\n            # generate the visualization results and record them for the current epoch\n            writer.add_image(f'top{batch_id}/eeg-0',\n                             top_img,\n                             self.cur_epoch,\n                             dataformats='HWC')\n\n            signal = t.reverse(eeg=gen_X[y == 1][0].detach().cpu().numpy() -\n                               0.5)['eeg']\n            top_img = plot_feature_topomap(\n                torch.tensor(signal),\n                channel_list=DEAP_CHANNEL_LIST,\n                feature_list=[\"theta\", \"alpha\", \"beta\", \"gamma\"])\n            writer.add_image(f'top{batch_id}/eeg-1',\n                             top_img,\n                             self.cur_epoch,\n                             dataformats='HWC')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Building Deep Learning Pipelines Using TorchEEG\nStep 1: Initialize the Dataset\n\nWe use the DEAP dataset supported by TorchEEG. Here, each EEG sample is 1 second long and contains 128 data points. The baseline signal is 3 seconds long; it is cut into three 1-second segments, which are averaged to obtain the baseline signal for the trial. In offline preprocessing, we divide the EEG signal of every electrode into 4 sub-bands and calculate the differential entropy on each sub-band as a feature, followed by baseline removal and mapping onto the grid. Finally, the preprocessed EEG signals are stored locally. In online processing, all EEG signals are min-max normalized (in GANs, normalization helps with convergence) and converted into Tensors for input into neural networks.\n\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "dataset = DEAPDataset(io_path=f'./tmp_out/examples_deap_gan/deap',\n                      root_path='./tmp_in/data_preprocessed_python',\n                      offline_transform=transforms.Compose([\n                          transforms.BandDifferentialEntropy(\n                              sampling_rate=128, apply_to_baseline=True),\n                          transforms.BaselineRemoval(),\n                          transforms.ToGrid(DEAP_CHANNEL_LOCATION_DICT)\n                      ]),\n                      online_transform=transforms.Compose([\n                          transforms.MinMaxNormalize(),\n                          transforms.ToTensor(),\n                      ]),\n                      label_transform=transforms.Compose([\n                          transforms.Select('valence'),\n                          transforms.Binary(5.0),\n                      ]),\n                      chunk_size=128,\n                      baseline_chunk_size=128,\n                      num_baseline=3,\n                      num_worker=4)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "<div class=\"alert alert-danger\"><h4>Warning</h4><p>If you use TorchEEG under the `Windows` system and want to use multiple processes (such as in the dataset or the dataloader), you should check whether :obj:`__name__` is :obj:`__main__` to avoid errors caused by repeated :obj:`import`.</p></div>\n\nThat is, under the :obj:`Windows` system, you need to:\n```\nif __name__ == \"__main__\":\n    dataset = DEAPDataset(\n         io_path=\n         f'./tmp_out/examples_deap_gan/deap',\n         root_path='./tmp_in/data_preprocessed_python',\n         offline_transform=transforms.Compose([\n             transforms.BandDifferentialEntropy(sampling_rate=128,\n                                                apply_to_baseline=True),\n             transforms.BaselineRemoval(),\n             transforms.ToGrid(DEAP_CHANNEL_LOCATION_DICT)\n         ]),\n         online_transform=transforms.Compose([\n             transforms.MinMaxNormalize(),\n             transforms.ToTensor(),\n         ]),\n         label_transform=transforms.Compose([\n             transforms.Select('valence'),\n             transforms.Binary(5.0),\n         ]),\n         io_mode='pickle',\n         chunk_size=128,\n         baseline_chunk_size=128,\n         num_baseline=3,\n         num_worker=4)\n    # the rest of the code goes here\n```\n<div class=\"alert alert-info\"><h4>Note</h4><p>LMDB may not be optimized for some Windows systems or storage devices. If you find that the data preprocessing speed is slow, you can consider setting :obj:`io_mode` to :obj:`pickle`, which is an alternative implemented by TorchEEG based on pickle.</p></div>\n\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Step 2: Divide the Training and Test Samples in the Dataset\n\nHere, the dataset is divided using 5-fold cross-validation. During the split, samples are grouped by trial index, and within every trial 4 folds are used as training samples and 1 fold as test samples. Samples across trials are then aggregated to obtain the training set and the test set.\n\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "k_fold = KFoldGroupbyTrial(n_splits=5,\n                           split_path='./tmp_out/examples_deap_gan/split')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Step 3: Define the Model and Start Training\n\nWe first use a loop to get the training and validation datasets for each cross-validation fold. In each fold, we initialize the generator and discriminator models and define their hyperparameters. For example, we want to generate the differential entropy features of the 4 sub-bands of the simulated EEG signal, so the generator outputs 4 channels. The generated samples are obtained by sampling from a 128-dimensional latent space and transforming the sampled vectors.\n\nWe then initialize the trainer and set the training hyperparameters, such as the learning rates and the device used. The :obj:`fit` method receives the training and validation data loaders and starts training the model. The :obj:`test` method receives a test data loader and reports the test results. The :obj:`save_state_dict` method saves the trained model.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "for i, (train_dataset, val_dataset) in enumerate(k_fold.split(dataset)):\n    # Initialize the model\n    generator = BCGenerator(in_channels=128, out_channels=4, num_classes=2)\n    discriminator = BCDiscriminator(hid_channels=128,\n                                    in_channels=4,\n                                    num_classes=2)\n\n    # Initialize the trainer and use the 0-th GPU for training, or set device_ids=[] to use CPU\n    trainer = MyCGANTrainer(generator=generator,\n                            discriminator=discriminator,\n                            generator_lr=0.0001,\n                            discriminator_lr=0.00001,\n                            weight_decay=0,\n                            device_ids=[0])\n\n    # Initialize several batches of training samples and test samples\n    train_loader = DataLoader(train_dataset,\n                              batch_size=256,\n                              shuffle=True,\n                              num_workers=4)\n    val_loader = DataLoader(val_dataset,\n                            batch_size=256,\n                            shuffle=False,\n                            num_workers=4)\n\n    # Do 50 rounds of training\n    trainer.fit(train_loader, val_loader, num_epochs=50)\n    trainer.test(val_loader)\n    trainer.save_state_dict(f'./tmp_out/examples_deap_gan/weight/{i}.pth')"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.8.6"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}