{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "# Cross-domain Emotion Recognition with ADA\nThis example shows how to use TorchEEG and a simple CCNN model for cross-subject emotion recognition. EEG signals from different subjects follow different distributions, so a model trained on some subjects usually performs worse on unseen subjects. To reduce this gap, we use a domain adaptation algorithm, Associative Domain Adaptation (ADA).\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import logging\nimport os\nimport random\nimport time\n\nimport numpy as np\nimport pandas as pd\nimport torch\nfrom torch.utils.data.dataloader import DataLoader\n\nfrom torcheeg import transforms\nfrom torcheeg.datasets import SEEDFeatureDataset\nfrom torcheeg.datasets.constants.emotion_recognition.seed import \\\n    SEED_CHANNEL_LOCATION_DICT\nfrom torcheeg.model_selection import LeaveOneSubjectOut, Subcategory\nfrom torcheeg.models import CCNN\nfrom torcheeg.trainers import ADATrainer"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Pre-experiment Preparation to Ensure Reproducibility\nUse the `logging` module to print output to the screen and, at the same time, save it to a log file for later reference.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "os.makedirs('./tmp_out/examples_seed_feature_ada/log', exist_ok=True)\nlogger = logging.getLogger('Cross-domain Emotion Recognition with ADA')\nlogger.setLevel(logging.DEBUG)\nconsole_handler = logging.StreamHandler()\ntimeticks = time.strftime(\"%Y-%m-%d-%H-%M-%S\", time.localtime())\nfile_handler = logging.FileHandler(\n    os.path.join('./tmp_out/examples_seed_feature_ada/log', f'{timeticks}.log'))\nlogger.addHandler(console_handler)\nlogger.addHandler(file_handler)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Set the random seed in all relevant modules so that repeated runs produce the same results.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "def seed_everything(seed):\n    random.seed(seed)\n    np.random.seed(seed)\n    os.environ[\"PYTHONHASHSEED\"] = str(seed)\n    torch.manual_seed(seed)\n    torch.cuda.manual_seed(seed)\n    torch.backends.cudnn.deterministic = True\n    torch.backends.cudnn.benchmark = False\n\n\nseed_everything(42)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Define the Feature Extractor and Classifier\nUnlike a vanilla classification model, cross-domain algorithms usually impose constraints on the features extracted by the model (for example, requiring them to be domain-invariant). We therefore split the classification model into a feature extraction part and a classification part.\n\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
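        "class Extractor(CCNN):\n    # reuse the convolutional layers of CCNN as the feature extractor\n    def forward(self, x):\n        x = self.conv1(x)\n        x = self.conv2(x)\n        x = self.conv3(x)\n        x = self.conv4(x)\n        # flatten the feature maps into one feature vector per sample\n        x = x.flatten(start_dim=1)\n        return x\n\n\nclass Classifier(CCNN):\n    # reuse the fully connected layers of CCNN as the classifier\n    def forward(self, x):\n        x = self.lin1(x)\n        x = self.lin2(x)\n        return x"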
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Customize Trainer\nTorchEEG provides a number of domain adaptation trainers for training cross-subject and cross-session emotion recognition models. Here we choose the associative domain adaptation trainer, `ADATrainer`. We subclass it and override the `log` method so that logs are recorded with the logger defined above. Other hook methods can be overridden in the same way to meet special needs.\n\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "class MyClassificationTrainer(ADATrainer):\n    def log(self, *args, **kwargs):\n        if self.is_main:\n            logger.info(*args, **kwargs)"
      ]
    },
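    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The association loss at the core of ADA can be sketched in a few lines of NumPy. It follows the description in the ADA paper (Haeusser et al., 2017). A *walker* loss rewards round trips that start at a source sample, step into the target domain and return to a source sample of the same class. A *visit* loss encourages every target sample to be visited. This is a minimal, illustrative sketch that assumes dot-product similarities between embeddings. The helper `association_loss` and its `visit_weight` parameter are hypothetical, and the implementation inside `ADATrainer` may differ in details, for example in how the loss terms are weighted.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "\n",
        "def softmax(x, axis=-1):\n",
        "    # numerically stable softmax along the given axis\n",
        "    x = x - x.max(axis=axis, keepdims=True)\n",
        "    e = np.exp(x)\n",
        "    return e / e.sum(axis=axis, keepdims=True)\n",
        "\n",
        "\n",
        "def association_loss(source_feat, target_feat, source_labels,\n",
        "                     visit_weight=1.0):\n",
        "    # hypothetical helper: walker loss + weighted visit loss (ADA sketch)\n",
        "    eps = 1e-8  # avoids log(0)\n",
        "    # dot-product similarity between every source and every target embedding\n",
        "    sim = source_feat @ target_feat.T\n",
        "    p_st = softmax(sim, axis=1)    # transition probabilities source -> target\n",
        "    p_ts = softmax(sim.T, axis=1)  # transition probabilities target -> source\n",
        "    p_sts = p_st @ p_ts            # round trip source -> target -> source\n",
        "    # walker target: uniform over the source samples that share the label\n",
        "    same_class = (source_labels[:, None] == source_labels[None, :]).astype(float)\n",
        "    walker_target = same_class / same_class.sum(axis=1, keepdims=True)\n",
        "    walker_loss = -np.mean(np.sum(walker_target * np.log(p_sts + eps), axis=1))\n",
        "    # visit loss: target samples should be visited with uniform probability\n",
        "    p_visit = p_st.mean(axis=0)\n",
        "    visit_target = np.full_like(p_visit, 1.0 / len(p_visit))\n",
        "    visit_loss = -np.sum(visit_target * np.log(p_visit + eps))\n",
        "    return walker_loss + visit_weight * visit_loss\n",
        "\n",
        "\n",
        "rng = np.random.default_rng(0)\n",
        "src = rng.normal(size=(6, 4))  # 6 source embeddings of dimension 4\n",
        "tgt = rng.normal(size=(5, 4))  # 5 target embeddings of dimension 4\n",
        "labels = np.array([0, 0, 1, 1, 2, 2])\n",
        "print(association_loss(src, tgt, labels))"
      ]
    },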
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
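        "## Building Deep Learning Pipelines Using TorchEEG\nStep 1: Initialize the Dataset\n\nWe use the SEED dataset supported by TorchEEG, in the form of its pre-extracted EEG features. In this feature dataset, the EEG signal (200 data points per second) is divided into 1-second windows. The differential entropy (DE) of each window is pre-computed in five frequency bands. The `de_movingAve` features used here are smoothed with a moving average. SEED also provides DE features smoothed with a linear dynamical system (`de_LDS`). We then normalize the features, map them onto a 2D grid according to the electrode locations, and convert them into tensors for input into the neural network. The emotion labels (-1, 0 and 1 for negative, neutral and positive) are shifted by one, so they become the class indices 0, 1 and 2.\n\n\n"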
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "dataset = SEEDFeatureDataset(io_path='./tmp_out/examples_seed_feature_ada/seed',\n                             root_path='./tmp_in/ExtractedFeatures',\n                             feature=['de_movingAve'],\n                             online_transform=transforms.Compose([\n                                 transforms.MinMaxNormalize(axis=-1),\n                                 transforms.ToGrid(SEED_CHANNEL_LOCATION_DICT),\n                                 transforms.ToTensor()\n                             ]),\n                             label_transform=transforms.Compose([\n                                 transforms.Select('emotion'),\n                                 transforms.Lambda(lambda x: int(x) + 1),\n                             ]),\n                             num_worker=8)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "<div class=\"alert alert-danger\"><h4>Warning</h4><p>If you use TorchEEG on <code>Windows</code> and want to use multiple processes (for example, in the dataset or the data loader), you should check whether <code>__name__</code> is <code>\"__main__\"</code>. Otherwise, each worker process re-imports the main script and tries to construct the dataset again, which causes errors.</p></div>\n\nThat is, on Windows you need to write:\n```\nif __name__ == \"__main__\":\n    dataset = SEEDFeatureDataset(io_path='./tmp_out/examples_seed_feature_ada/seed',\n                                 root_path='./tmp_in/ExtractedFeatures',\n                                 feature=['de_movingAve'],\n                                 online_transform=transforms.Compose([\n                                     transforms.MinMaxNormalize(axis=-1),\n                                     transforms.ToGrid(SEED_CHANNEL_LOCATION_DICT),\n                                     transforms.ToTensor()\n                                 ]),\n                                 label_transform=transforms.Compose([\n                                     transforms.Select('emotion'),\n                                     transforms.Lambda(lambda x: int(x) + 1),\n                                 ]),\n                                 io_mode='pickle',\n                                 num_worker=8)\n    # the rest of the code goes here\n```\n<div class=\"alert alert-info\"><h4>Note</h4><p>LMDB may perform poorly on some Windows systems or storage devices. If data preprocessing is slow, consider setting <code>io_mode</code> to <code>pickle</code>. This is an alternative storage backend that TorchEEG implements on top of pickle.</p></div>\n\n"
      ]
    },
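    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The online transforms above can be illustrated with a small NumPy sketch. It assumes a toy layout of four hypothetical electrodes on a 3x3 grid, whereas the real `SEED_CHANNEL_LOCATION_DICT` covers all 62 SEED electrodes. It mirrors our understanding of what `MinMaxNormalize(axis=-1)` and `ToGrid` do. First, each electrode's values are scaled to [0, 1] along the last axis. Then the per-electrode feature vectors are scattered into a grid, so the frequency bands become input channels. The helpers `min_max_normalize` and `to_grid` are hypothetical illustrations, not TorchEEG's API.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "# hypothetical 3x3 layout for four electrodes: name -> (row, col)\n",
        "TOY_LOCATION_DICT = {'FP1': (0, 0), 'FP2': (0, 2), 'CZ': (1, 1), 'OZ': (2, 1)}\n",
        "\n",
        "\n",
        "def min_max_normalize(x, axis=-1, eps=1e-8):\n",
        "    # scale values to [0, 1] along `axis`\n",
        "    x_min = x.min(axis=axis, keepdims=True)\n",
        "    x_max = x.max(axis=axis, keepdims=True)\n",
        "    return (x - x_min) / (x_max - x_min + eps)\n",
        "\n",
        "\n",
        "def to_grid(x, location_dict, grid_size=(3, 3)):\n",
        "    # x: [num_electrodes, num_features] -> [num_features, height, width]\n",
        "    grid = np.zeros((x.shape[1],) + grid_size, dtype=x.dtype)\n",
        "    for idx, (row, col) in enumerate(location_dict.values()):\n",
        "        grid[:, row, col] = x[idx]\n",
        "    return grid\n",
        "\n",
        "\n",
        "# toy differential-entropy features: 4 electrodes x 5 frequency bands\n",
        "de = np.arange(20, dtype=float).reshape(4, 5)\n",
        "grid = to_grid(min_max_normalize(de), TOY_LOCATION_DICT)\n",
        "print(grid.shape)  # (5, 3, 3)\n",
        "\n",
        "# SEED labels -1/0/1 (negative/neutral/positive) shifted to class indices 0/1/2\n",
        "print([int(y) + 1 for y in (-1, 0, 1)])  # [0, 1, 2]"
      ]
    },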
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Step 2: Split the Dataset into Training and Test Samples\n\nHere we do not consider cross-session effects, so each session is evaluated separately. We first assign each sample a session index according to its recording date. We then use `Subcategory` to split the dataset into three sub-datasets, one for each of the first, second and third sessions.\n\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "subject_info_list = []\nfor subject_id in dataset.info['subject_id'].unique().tolist():\n    # copy() avoids pandas' SettingWithCopyWarning when adding a column\n    subject_info = dataset.info[dataset.info['subject_id'] == subject_id].copy()\n    # sort the recording dates so that session indices follow chronological order\n    session_id_set = sorted(subject_info['date'].unique().tolist())\n    session_id_map = {\n        session_id: i\n        for i, session_id in enumerate(session_id_set)\n    }\n    subject_info['session_id'] = subject_info['date'].apply(\n        lambda x: session_id_map[x])\n    subject_info_list.append(subject_info)\ndataset.info = pd.concat(subject_info_list)\n\nsubset = Subcategory(criteria='session_id',\n                     split_path='./tmp_out/examples_seed_feature_ada/split')"
      ]
    },
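    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The session indexing above can also be written without an explicit loop. The sketch below uses a small, hypothetical `info` table rather than the real SEED metadata. It applies pandas' grouped `rank(method='dense')`, so within each subject the earliest recording date becomes session 0, the next date session 1, and so on. Note that this numbers sessions in chronological order of the `date` values.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import pandas as pd\n",
        "\n",
        "# hypothetical metadata: two subjects, each recorded on two dates\n",
        "info = pd.DataFrame({\n",
        "    'subject_id': [1, 1, 1, 2, 2, 2],\n",
        "    'date': [20140413, 20140404, 20140404, 20140419, 20140419, 20140510],\n",
        "})\n",
        "# dense rank of each date within its subject, shifted to start at 0\n",
        "info['session_id'] = (info.groupby('subject_id')['date']\n",
        "                      .rank(method='dense')\n",
        "                      .astype(int) - 1)\n",
        "print(info['session_id'].tolist())  # [1, 0, 0, 0, 0, 1]"
      ]
    },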
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Step 3: Define the Model and Start Training\n\nWithin the dataset of each session, we use leave-one-subject-out cross-validation. Each subject in turn serves as the test set, and the remaining subjects form the training set.\n\nWe treat the training set as the source domain and the test set as the target domain. The goal is for a model trained on the source domain to generalize to the target domain. We then initialize the trainer and set its hyperparameters, such as the learning rate and the devices to use. The `fit` method takes three inputs: the labeled source-domain samples, the target-domain samples (treated as unlabeled, since they come from the test set), and a loader used for validation during training (here, the test loader). It then trains the model. The `test` method receives a test data loader and reports the test results. The `save_state_dict` method saves the trained model.\n\n"
      ]
    },
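    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Conceptually, leave-one-subject-out splitting within one session works as follows. The sketch below uses plain Python on hypothetical sample records rather than TorchEEG's `LeaveOneSubjectOut`, and it does not reproduce that class's caching or split files. Each fold yields the held-out subject, the training (source-domain) samples and the held-out subject's (target-domain) samples.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "def leave_one_subject_out(samples):\n",
        "    # samples: list of dicts with at least a 'subject_id' key (hypothetical format)\n",
        "    subjects = sorted({s['subject_id'] for s in samples})\n",
        "    for test_subject in subjects:\n",
        "        # source domain: all other subjects; target domain: the held-out subject\n",
        "        source = [s for s in samples if s['subject_id'] != test_subject]\n",
        "        target = [s for s in samples if s['subject_id'] == test_subject]\n",
        "        yield test_subject, source, target\n",
        "\n",
        "\n",
        "samples = [{'subject_id': sid, 'clip_id': c} for sid in (1, 2, 3) for c in range(2)]\n",
        "for test_subject, source, target in leave_one_subject_out(samples):\n",
        "    print(test_subject, len(source), len(target))\n",
        "# 1 4 2\n",
        "# 2 4 2\n",
        "# 3 4 2"
      ]
    },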
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# make sure the directory for the saved weights exists\nos.makedirs('./tmp_out/examples_seed_feature_ada/weight', exist_ok=True)\n\nfor j, sub_dataset in enumerate(subset.split(dataset)):\n    # leave-one-subject-out cross-validation within session j\n    k_fold = LeaveOneSubjectOut(\n        split_path=f'./tmp_out/examples_seed_feature_ada/split_{j}')\n    for i, (train_dataset,\n            test_dataset) in enumerate(k_fold.split(sub_dataset)):\n        extractor = Extractor(in_channels=5, num_classes=3)\n        classifier = Classifier(in_channels=5, num_classes=3)\n        trainer = MyClassificationTrainer(extractor=extractor,\n                                          classifier=classifier,\n                                          lr=1e-4,\n                                          lambd=1.0,\n                                          weight_decay=0.0,\n                                          device_ids=[0])\n\n        # labeled source-domain samples (training subjects)\n        source_loader = DataLoader(train_dataset,\n                                   batch_size=128,\n                                   shuffle=True,\n                                   num_workers=4,\n                                   drop_last=True)\n\n        # target-domain samples (held-out subject) used for adaptation\n        target_loader = DataLoader(test_dataset,\n                                   batch_size=128,\n                                   shuffle=True,\n                                   num_workers=4,\n                                   drop_last=True)\n\n        # evaluation does not need shuffling\n        test_loader = DataLoader(test_dataset,\n                                 batch_size=128,\n                                 shuffle=False,\n                                 num_workers=4)\n        trainer.fit(source_loader, target_loader, test_loader, num_epochs=50)\n        trainer.test(test_loader)\n        trainer.save_state_dict(\n            f'./tmp_out/examples_seed_feature_ada/weight/{j}-{i}.pth')"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.8.6"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}