{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "# Examples of MNEDataset\nIn this example, we show how to combine TorchEEG with MNE and use deep learning to analyze existing data in :obj:`mne.Epochs` format.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import logging\nimport os\nimport random\nimport time\n\nimport mne\nimport numpy as np\nimport torch\nfrom torch.utils.data.dataloader import DataLoader\n\nfrom torcheeg import transforms\nfrom torcheeg.datasets import MNEDataset\nfrom torcheeg.model_selection import KFold\nfrom torcheeg.models import TSCeption\nfrom torcheeg.trainers import ClassificationTrainer"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Pre-experiment Preparation to Ensure Reproducibility\nWe use the :obj:`logging` module to write the output to a log file for later reference while also printing it to the screen.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "os.makedirs('./tmp_out/examples_mne_dataset/log', exist_ok=True)\nlogger = logging.getLogger('Examples of MNEDataset')\nlogger.setLevel(logging.DEBUG)\nconsole_handler = logging.StreamHandler()\ntimeticks = time.strftime(\"%Y-%m-%d-%H-%M-%S\", time.localtime())\nfile_handler = logging.FileHandler(\n    os.path.join('./tmp_out/examples_mne_dataset/log', f'{timeticks}.log'))\nlogger.addHandler(console_handler)\nlogger.addHandler(file_handler)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Set the random seed in every module that draws random numbers so that repeated runs give the same results. Note that some GPU operations can remain nondeterministic even with fixed seeds.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "def seed_everything(seed):\n    random.seed(seed)\n    np.random.seed(seed)\n    os.environ[\"PYTHONHASHSEED\"] = str(seed)\n    torch.manual_seed(seed)\n    torch.cuda.manual_seed(seed)\n    torch.backends.cudnn.deterministic = True\n    torch.backends.cudnn.benchmark = False\n\n\nseed_everything(42)"
      ]
    },
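    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a quick check of what seeding buys us, the sketch below reseeds Python's :obj:`random` module and NumPy and confirms that the draws repeat. It uses a hypothetical CPU-only helper, :obj:`seed_cpu_only`, so that it runs without PyTorch or a GPU; :obj:`seed_everything` above additionally seeds PyTorch and configures cuDNN.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import random\n\nimport numpy as np\n\n\ndef seed_cpu_only(seed):\n    # Hypothetical helper: seeds only Python's random module and NumPy\n    random.seed(seed)\n    np.random.seed(seed)\n\n\nseed_cpu_only(42)\nfirst_draw = (random.random(), np.random.rand(3))\nseed_cpu_only(42)\nsecond_draw = (random.random(), np.random.rand(3))\n\n# Both checks print True because the generators restart from the same state\nprint(first_draw[0] == second_draw[0], np.array_equal(first_draw[1], second_draw[1]))"
      ]
    },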
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Customize Trainer\nTorchEEG provides many trainers for training classification models, generative models, and cross-domain methods. Here we use the simplest one, the classification trainer: we subclass it and override its :obj:`log` method so that log messages go through the logger defined above. Other hook methods can be overridden in the same way to meet special needs.\n\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "class MyClassificationTrainer(ClassificationTrainer):\n    def log(self, *args, **kwargs):\n        if self.is_main:\n            logger.info(*args, **kwargs)"
      ]
    },
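    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The idea behind overriding :obj:`log` is that any hook which reports through :obj:`self.log` is redirected automatically. The self-contained sketch below illustrates the pattern with a hypothetical stand-in base class, :obj:`StandInTrainer`, and a hypothetical hook, :obj:`on_epoch_end`; neither is TorchEEG's API. The messages are collected in a list so the effect is easy to inspect.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import logging\n\nsketch_messages = []\n\n\nclass ListHandler(logging.Handler):\n    # Collects log messages in a list instead of printing them\n    def emit(self, record):\n        sketch_messages.append(record.getMessage())\n\n\nsketch_logger = logging.getLogger('override-sketch')\nsketch_logger.setLevel(logging.INFO)\nsketch_logger.propagate = False\nsketch_logger.handlers.clear()  # avoid duplicate handlers if the cell is re-run\nsketch_logger.addHandler(ListHandler())\n\n\nclass StandInTrainer:\n    # Hypothetical base class, NOT TorchEEG's ClassificationTrainer\n    is_main = True\n\n    def log(self, *args, **kwargs):\n        print(*args)\n\n    def on_epoch_end(self, epoch):\n        # Hypothetical hook: reports progress through self.log\n        self.log(f'epoch {epoch} finished')\n\n\nclass MyStandInTrainer(StandInTrainer):\n    # Same pattern as MyClassificationTrainer: only the main process logs\n    def log(self, *args, **kwargs):\n        if self.is_main:\n            sketch_logger.info(*args, **kwargs)\n\n\nMyStandInTrainer().on_epoch_end(1)\nprint(sketch_messages)  # ['epoch 1 finished']"
      ]
    },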
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Read Data Using MNE and Format It as :obj:`mne.Epochs`\nWe use MNE's API to automatically download the PhysioNet motor imagery dataset. The EEG signals of subjects 1-21 in runs 6, 10, and 14 (motor imagery of both hands vs. both feet) are downloaded and band-pass filtered to 7-30 Hz. We store the resulting :obj:`mne.Epochs` objects in a list and keep a parallel list, :obj:`metadata_list`, whose i-th entry records the subject and run of the i-th Epochs object.\n\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "metadata_list = [{\n    'subject': subject_id,\n    'run': run_id\n} for subject_id in range(1, 22)\n                 for run_id in [6, 10, 14]]  # motor imagery: hands vs feet\n\nepochs_list = []\nfor metadata in metadata_list:\n    physionet_path = mne.datasets.eegbci.load_data(metadata['subject'],\n                                                   metadata['run'],\n                                                   update_path=False)[0]\n    raw = mne.io.read_raw_edf(physionet_path, preload=True, stim_channel='auto')\n    mne.datasets.eegbci.standardize(raw)\n\n    montage = mne.channels.make_standard_montage('standard_1005')\n    raw.set_montage(montage)\n\n    # Band-pass filter to the 7-30 Hz range\n    raw.filter(7., 30., fir_design='firwin', skip_by_annotation='edge')\n    # In these runs, T1 marks imagined movement of both fists (hands) and T2 of both feet\n    events, _ = mne.events_from_annotations(raw, event_id=dict(T1=2, T2=3))\n    picks = mne.pick_types(raw.info,\n                           meg=False,\n                           eeg=True,\n                           stim=False,\n                           eog=False,\n                           exclude='bads')\n    # Create Epochs from the raw EEG signals and the event annotations. Each epoch spans\n    # tmin=-1.0 s to tmax=4.0 s relative to cue onset, i.e. it starts 1 s before the cue.\n    epochs_list.append(\n        mne.Epochs(raw,\n                   events,\n                   dict(hands=2, feet=3),\n                   tmin=-1.,\n                   tmax=4.0,\n                   proj=True,\n                   picks=picks))"
      ]
    },
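    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The bookkeeping above can be checked without downloading anything. The sketch below rebuilds the metadata list under a new name, :obj:`metadata_preview`, and computes how many samples each epoch contains, assuming the 160 Hz sampling rate of these recordings. Because MNE includes both :obj:`tmin` and :obj:`tmax`, an epoch from -1 s to 4 s has 5 * 160 + 1 samples.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "sfreq = 160  # assumed sampling rate (Hz) of the PhysioNet EEGBCI recordings used here\ntmin, tmax = -1.0, 4.0\n\n# One entry per (subject, run) pair: 21 subjects x 3 runs\nmetadata_preview = [{'subject': s, 'run': r} for s in range(1, 22) for r in [6, 10, 14]]\n\n# MNE epochs include both endpoints, so both tmin and tmax contribute a sample\nn_times = int(round(tmax * sfreq)) - int(round(tmin * sfreq)) + 1\n\nprint(len(metadata_preview), n_times)  # 63 801\nprint(metadata_preview[:3])"
      ]
    },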
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
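        "## Convert :obj:`mne.Epochs` into MNEDataset\nWe use MNEDataset to split each epoch into windows. Here, the window size is 160 samples (1 s at the 160 Hz sampling rate) and the overlap is 80 samples, so consecutive windows start 80 samples apart. Each window inherits the entry of :obj:`metadata_list` that describes its Epochs object. In addition, each window records its start position :obj:`start_at`, its end position :obj:`end_at`, the epoch index :obj:`trial_id`, and the event type :obj:`event`, any of which can be selected and transformed into a label. Here, :obj:`label_transform` selects :obj:`event` and subtracts 2, mapping hands (2) to 0 and feet (3) to 1. The :obj:`offline_transform` is applied once during preprocessing and its result is cached under :obj:`io_path`, while the :obj:`online_transform` is applied each time a sample is loaded.\n\n\n"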
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "dataset = MNEDataset(epochs_list=epochs_list,\n                     metadata_list=metadata_list,\n                     chunk_size=160,\n                     overlap=80,\n                     io_path='./tmp_out/examples_mne_dataset/physionet',\n                     offline_transform=transforms.Compose(\n                         [transforms.MeanStdNormalize(),\n                          transforms.To2d()]),\n                     online_transform=transforms.ToTensor(),\n                     label_transform=transforms.Compose([\n                         transforms.Select('event'),\n                         transforms.Lambda(lambda x: x - 2)\n                     ]),\n                     num_worker=2)"
      ]
    },
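    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make the windowing concrete, the sketch below cuts one epoch-shaped array into overlapping windows with NumPy. The helper :obj:`sliding_windows` is hypothetical: it assumes windows start every :obj:`chunk_size - overlap` samples and that a trailing partial window is dropped, and we have not verified that this matches MNEDataset's implementation exactly. The epoch length of 801 samples assumes 64 electrodes recorded at 160 Hz from -1 s to 4 s. The last lines mimic :obj:`To2d` (adding a leading channel axis) and the label transform.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import numpy as np\n\n\ndef sliding_windows(signal, chunk_size=160, overlap=80):\n    # Hypothetical helper: cut a [channels, time] array into overlapping windows\n    step = chunk_size - overlap\n    starts = range(0, signal.shape[-1] - chunk_size + 1, step)\n    windows = np.stack([signal[:, s:s + chunk_size] for s in starts])\n    spans = [(s, s + chunk_size) for s in starts]  # (start_at, end_at) of each window\n    return windows, spans\n\n\nepoch = np.random.randn(64, 801)  # stand-in for one epoch: 64 electrodes x 801 samples\nwindows, spans = sliding_windows(epoch)\nwindows_2d = windows[:, np.newaxis]  # add a leading channel axis, as To2d does per window\n\nprint(windows.shape, windows_2d.shape)  # (9, 64, 160) (9, 1, 64, 160)\nprint(spans[:2])  # [(0, 160), (80, 240)]\n\n# Select('event') followed by x - 2 maps hands (2) to 0 and feet (3) to 1\nprint([event - 2 for event in (2, 3)])  # [0, 1]"
      ]
    },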
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "<div class=\"alert alert-danger\"><h4>Warning</h4><p>If you use TorchEEG on Windows with multiple processes (for example, :obj:`num_worker` in the dataset or :obj:`num_workers` in the DataLoader), put the code under an :obj:`if __name__ == \"__main__\":` guard. Windows starts worker processes by re-importing the main module, and without the guard each worker would re-run the top-level code.</p></div>\n\nThat is, on Windows you need to write:\n```\nif __name__ == \"__main__\":\n    dataset = MNEDataset(epochs_list=epochs_list,\n                         metadata_list=metadata_list,\n                         chunk_size=160,\n                         overlap=80,\n                         io_path='./tmp_out/examples_mne_dataset/physionet',\n                         offline_transform=transforms.Compose(\n                             [transforms.MeanStdNormalize(),\n                              transforms.To2d()]),\n                         online_transform=transforms.ToTensor(),\n                         label_transform=transforms.Compose([\n                             transforms.Select('event'),\n                             transforms.Lambda(lambda x: x - 2)\n                         ]),\n                         io_mode='pickle',\n                         num_worker=2)\n    # the rest of the code\n```\n<div class=\"alert alert-info\"><h4>Note</h4><p>LMDB may perform poorly on some Windows systems or storage devices. If data preprocessing is slow, consider setting :obj:`io_mode` to :obj:`pickle`, a pickle-based storage alternative implemented by TorchEEG.</p></div>\n\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Divide the Dataset into Training and Test Sets\nHere, the dataset is divided using 5-fold cross-validation: in each round, 4 folds serve as training samples and the remaining fold as test samples. The split is saved to :obj:`split_path`, so later runs can reuse it.\n\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "k_fold = KFold(n_splits=5, split_path='./tmp_out/examples_mne_dataset/split')"
      ]
    },
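    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The idea behind k-fold splitting can be sketched with plain index arithmetic. The hypothetical generator :obj:`k_fold_indices` below shuffles the sample indices once, cuts them into 5 folds, and uses each fold as the test set exactly once. It is an illustration of the technique, not TorchEEG's :obj:`KFold`, whose fold assignment may differ.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import numpy as np\n\n\ndef k_fold_indices(n_samples, n_splits=5, seed=42):\n    # Hypothetical sketch: shuffle once, split into n_splits folds, rotate the test fold\n    rng = np.random.default_rng(seed)\n    folds = np.array_split(rng.permutation(n_samples), n_splits)\n    for i in range(n_splits):\n        test_idx = folds[i]\n        train_idx = np.concatenate([fold for j, fold in enumerate(folds) if j != i])\n        yield train_idx, test_idx\n\n\nsplits = list(k_fold_indices(100))\nprint([(len(train_idx), len(test_idx)) for train_idx, test_idx in splits])  # five (80, 20) pairs\n\n# Every sample lands in exactly one test fold, and train/test never overlap\nall_test = np.concatenate([test_idx for _, test_idx in splits])\nprint(sorted(all_test.tolist()) == list(range(100)))  # True"
      ]
    },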
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Define the Model and Start Training\nWe loop over the folds to get the training and test sets of each cross-validation round.\nIn each round, we initialize the TSCeption model and define its hyperparameters: for example, the sampling rate of the EEG signals is 160 Hz, and there are 15 temporal kernels and 15 spatial kernels. In this example, each EEG sample has shape :obj:`[1, 64, 160]` (one input channel, 64 electrodes, and 160 time points), and there are 2 classes.\n\nWe then initialize the trainer and set its training hyperparameters, such as the learning rate and the devices to use. The :obj:`fit` method receives the training and validation data loaders and trains the model. The :obj:`test` method receives a test data loader and reports the test results. The :obj:`save_state_dict` method saves the trained model.\n\n\n"
      ]
    },
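    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Before running the real loop, it may help to see roughly what :obj:`fit` and :obj:`test` do. The sketch below is a minimal plain-PyTorch training loop that assumes Adam with the same learning rate and weight decay plus a cross-entropy loss. These choices, the helper :obj:`fit_sketch`, and the toy linear model are our assumptions for illustration, not TorchEEG's trainer implementation. It trains on random tensors shaped like the samples in this example, so the printed accuracy only confirms that the loop runs.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import torch\nfrom torch import nn\nfrom torch.utils.data import DataLoader, TensorDataset\n\n\ndef fit_sketch(model, train_loader, val_loader, num_epochs=1, lr=1e-4, weight_decay=1e-4):\n    # Hypothetical stand-in for fit + test: Adam + cross-entropy, then accuracy on val_loader\n    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)\n    loss_fn = nn.CrossEntropyLoss()\n    for _ in range(num_epochs):\n        model.train()\n        for x, y in train_loader:\n            optimizer.zero_grad()\n            loss = loss_fn(model(x), y)\n            loss.backward()\n            optimizer.step()\n    model.eval()\n    correct, total = 0, 0\n    with torch.no_grad():\n        for x, y in val_loader:\n            correct += (model(x).argmax(dim=1) == y).sum().item()\n            total += y.numel()\n    return correct / total\n\n\n# Random stand-in data shaped like this example's samples: [N, 1, 64, 160], labels in {0, 1}\nx_toy = torch.randn(32, 1, 64, 160)\ny_toy = torch.randint(0, 2, (32,))\ntoy_loader = DataLoader(TensorDataset(x_toy, y_toy), batch_size=8, shuffle=True)\ntoy_model = nn.Sequential(nn.Flatten(), nn.Linear(64 * 160, 2))  # a toy model, not TSCeption\ntoy_accuracy = fit_sketch(toy_model, toy_loader, toy_loader, num_epochs=2)\nprint(f'accuracy on random data: {toy_accuracy:.2f}')"
      ]
    },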
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Make sure the output directory for the model weights exists\nos.makedirs('./tmp_out/examples_mne_dataset/weight', exist_ok=True)\n\nfor i, (train_dataset, val_dataset) in enumerate(k_fold.split(dataset)):\n    # Initialize the model\n    model = TSCeption(num_electrodes=64,\n                      num_classes=2,\n                      num_T=15,\n                      num_S=15,\n                      in_channels=1,\n                      hid_channels=32,\n                      sampling_rate=160,\n                      dropout=0.5)\n\n    # Initialize the trainer and use the 0-th GPU for training, or set device_ids=[] to use the CPU\n    trainer = MyClassificationTrainer(model=model,\n                                      lr=1e-4,\n                                      weight_decay=1e-4,\n                                      device_ids=[0])\n\n    # Wrap the training and test samples in data loaders that yield batches\n    train_loader = DataLoader(train_dataset,\n                              batch_size=256,\n                              shuffle=True,\n                              num_workers=4)\n    val_loader = DataLoader(val_dataset,\n                            batch_size=256,\n                            shuffle=False,\n                            num_workers=4)\n\n    # Train for 50 epochs, evaluate on the held-out fold, and save the weights\n    trainer.fit(train_loader, val_loader, num_epochs=50)\n    trainer.test(val_loader)\n    trainer.save_state_dict(f'./tmp_out/examples_mne_dataset/weight/{i}.pth')"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.8.6"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}