Shortcuts

Source code for torcheeg.io.eeg_signal

import pickle
import os
from typing import Union

import torch
import lmdb


class _EEGSignalIO:
    r'''
    Common interface of the EEG signal storage backends.

    Concrete backends provide ``__len__`` and override :meth:`keys`,
    :meth:`eegs`, :meth:`read_eeg` and :meth:`write_eeg`. Every method defined
    here only raises :obj:`NotImplementedError`.
    '''

    @property
    def write_pointer(self):
        r'''
        Key used for the next auto-keyed write: the number of stored samples
        as reported by ``len(self)``. The subclass must define ``__len__``.
        '''
        stored = len(self)
        return stored

    def keys(self):
        r'''Return all stored keys. Must be overridden by a backend.'''
        raise NotImplementedError

    def eegs(self):
        r'''Return all stored EEG signals. Must be overridden by a backend.'''
        raise NotImplementedError

    def read_eeg(self, key: str) -> any:
        r'''Return the EEG signal stored under ``key``. Must be overridden by a backend.'''
        raise NotImplementedError

    def write_eeg(self,
                  eeg: Union[any, torch.Tensor],
                  key: Union[str, None] = None) -> str:
        r'''Store ``eeg`` (under ``key`` if given). Must be overridden by a backend.'''
        raise NotImplementedError


class MemoryEEGSignalIO(_EEGSignalIO):
    r'''
    EEG signal storage that keeps every sample in an in-process dictionary.

    Nothing is persisted to disk. Shallow copies made with :func:`copy.copy`
    share the same underlying dictionary, so writes through one copy are
    visible through the other.
    '''

    def __init__(self):
        # Maps the str key -> the EEG sample object exactly as passed to write_eeg.
        self._memory = {}

    def __len__(self):
        return len(self._memory)

    def keys(self):
        r'''
        Get all keys in the MemoryEEGSignalIO.

        Returns:
            list: The list of keys in the MemoryEEGSignalIO, in insertion order.
        '''
        return list(self._memory.keys())

    def eegs(self):
        r'''
        Get all EEG signals in the MemoryEEGSignalIO.

        Returns:
            list: The stored EEG signals, in the same order as :meth:`keys`.
        '''
        return list(self._memory.values())

    def read_eeg(self, key: str) -> any:
        r'''
        Return the EEG signal stored in memory under the given :obj:`key`.

        Args:
            key (str): The index of the EEG signal to be queried.

        Returns:
            any: The EEG signal sample.

        Raises:
            RuntimeError: If no sample is stored under :obj:`key`.
        '''
        # EAFP: a single dictionary lookup instead of a membership test plus an index.
        try:
            return self._memory[key]
        except KeyError:
            raise RuntimeError(
                f'Unable to index the EEG signal sample with key {key}!'
            ) from None

    def write_eeg(self,
                  eeg: Union[any, torch.Tensor],
                  key: Union[str, None] = None) -> str:
        r'''
        Write EEG signal to memory.

        Args:
            eeg (any): EEG signal samples to be written into memory.
            key (str, optional): The key of the EEG signal to be inserted. If not
                specified, the current number of stored samples (as a string) is
                used as an auto-incrementing key.

        Returns:
            str: The key under which the sample was stored.

        Raises:
            RuntimeError: If :obj:`eeg` is None.
        '''
        if key is None:
            # NOTE(review): the auto key is the current sample count, so it can
            # collide with (and overwrite) a sample stored earlier under an
            # explicit key of the same value.
            key = str(self.write_pointer)

        if eeg is None:
            raise RuntimeError(f'Save None to the memory with the key {key}!')

        self._memory[key] = eeg

        return key

    def __copy__(self):
        # Shallow copy: the new instance shares the same ``_memory`` dict.
        cls = self.__class__
        result = cls.__new__(cls)
        result.__dict__.update(self.__dict__)
        return result


class LMDBEEGSignalIO(_EEGSignalIO):
    r'''
    EEG signal storage backed by an LMDB database.

    Each sample is serialized with :mod:`pickle` and stored under its UTF-8
    encoded key. The map size starts at :obj:`io_size` bytes and is doubled
    whenever a write fails with :obj:`lmdb.MapFullError`.

    Args:
        io_path (str): Directory of the LMDB database; created if it does not exist.
        io_size (int): Initial map size in bytes. (default: :obj:`1048576`)
    '''

    def __init__(self, io_path: str, io_size: int = 1048576) -> None:
        self.io_path = io_path
        self.io_size = io_size

        os.makedirs(self.io_path, exist_ok=True)
        self._env = self._open_env()

    def _open_env(self):
        # Open the environment from the current io_path/io_size; shared by
        # __init__, __setstate__ and __copy__ so all three use the same options.
        # NOTE(review): lock=False disables LMDB's own locking, so concurrent
        # writers to the same io_path are not coordinated by LMDB -- confirm
        # callers serialize writes.
        return lmdb.open(path=self.io_path,
                         map_size=self.io_size,
                         lock=False)

    def __del__(self):
        # ``_env`` is absent if __init__ failed before opening it; guard so that
        # garbage collection does not raise AttributeError.
        env = getattr(self, '_env', None)
        if env is not None:
            env.close()

    def __len__(self):
        with self._env.begin(write=False) as transaction:
            return transaction.stat()['entries']

    def write_eeg(self,
                  eeg: Union[any, torch.Tensor],
                  key: Union[str, None] = None) -> str:
        r'''
        Write EEG signal to database.

        Args:
            eeg (any): EEG signal samples to be written into the database.
            key (str, optional): The key of the EEG signal to be inserted. If not
                specified, the current number of entries (as a string) is used as
                an auto-incrementing key.

        Returns:
            str: The key under which the sample was stored.

        Raises:
            RuntimeError: If :obj:`eeg` is None.
        '''

        if key is None:
            key = str(self.write_pointer)

        if eeg is None:
            raise RuntimeError(f'Save None to the LMDB with the key {key}!')

        # Serialize once; retries after a map resize reuse the same bytes.
        encoded_key = key.encode()
        payload = pickle.dumps(eeg)
        while True:
            try:
                with self._env.begin(write=True) as transaction:
                    transaction.put(encoded_key, payload)
                return key
            except lmdb.MapFullError:
                # The failed write transaction was aborted when leaving the
                # ``with`` block, so the map can be grown before retrying.
                self.io_size = self.io_size * 2
                self._env.set_mapsize(self.io_size)

    def read_eeg(self, key: str) -> any:
        r'''
        Query the corresponding EEG signal in the database according to the index.

        Args:
            key (str): The index of the EEG signal to be queried.

        Returns:
            any: The EEG signal sample.

        Raises:
            RuntimeError: If no entry exists for :obj:`key`.
        '''
        with self._env.begin(write=False) as transaction:
            eeg = transaction.get(key.encode())

        if eeg is None:
            raise RuntimeError(
                f'Unable to index the EEG signal sample with key {key}!')

        # Security: unpickling runs arbitrary code; only open trusted databases.
        return pickle.loads(eeg)

    def keys(self):
        r'''
        Get all keys in the LMDBEEGSignalIO.

        Returns:
            list: The list of keys in the LMDBEEGSignalIO, in cursor order.
        '''
        with self._env.begin(write=False) as transaction:
            return [
                key.decode()
                for key in transaction.cursor().iternext(keys=True,
                                                         values=False)
            ]

    def eegs(self):
        r'''
        Get all EEG signals in the LMDBEEGSignalIO.

        Returns:
            list: The list of EEG signals, in the same order as :meth:`keys`.
        '''
        # One read transaction and one cursor pass instead of a lookup per key.
        with self._env.begin(write=False) as transaction:
            return [
                pickle.loads(value)
                for value in transaction.cursor().iternext(keys=False,
                                                           values=True)
            ]

    def __getstate__(self):
        # The LMDB environment handle cannot be pickled (e.g. for multiprocessing
        # workers); drop it and reopen in __setstate__.
        state = self.__dict__.copy()
        state.pop('_env', None)
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        self._env = self._open_env()

    def __copy__(self):
        # The copy gets its own environment handle on the same io_path.
        cls = self.__class__
        result = cls.__new__(cls)
        result.__dict__.update({
            k: v
            for k, v in self.__dict__.items() if k != '_env'
        })
        result._env = result._open_env()
        return result


class PickleEEGSignalIO(_EEGSignalIO):
    r'''
    EEG signal storage that writes each sample to its own pickle file.

    A sample stored under ``key`` lives at ``os.path.join(io_path, key)``. The
    folder listing defines both the keys and the sample count.

    Args:
        io_path (str): Folder holding one file per sample; created if missing.
    '''

    def __init__(self, io_path: str) -> None:
        self.io_path = io_path

        os.makedirs(self.io_path, exist_ok=True)

    def __len__(self):
        # Counts every entry in the folder, including files not written here.
        return len(os.listdir(self.io_path))

    def write_eeg(self,
                  eeg: Union[any, torch.Tensor],
                  key: Union[str, None] = None) -> str:
        r'''
        Write EEG signal to folder.

        Args:
            eeg (any): EEG signal samples to be written into the folder.
            key (str, optional): The key (file name) of the EEG signal to be
                inserted. If not specified, the current number of entries in the
                folder (as a string) is used as an auto-incrementing key.

        Returns:
            str: The key under which the sample was stored.

        Raises:
            RuntimeError: If :obj:`eeg` is None.
        '''

        if key is None:
            key = str(self.write_pointer)

        if eeg is None:
            raise RuntimeError(f'Save None to the folder with the key {key}!')

        with open(os.path.join(self.io_path, key), 'wb') as f:
            pickle.dump(eeg, f)

        return key

    def read_eeg(self, key: str) -> any:
        r'''
        Query the corresponding EEG signal in the folder according to the index.

        Args:
            key (str): The index of the EEG signal to be queried.

        Returns:
            any: The EEG signal sample.

        Raises:
            FileNotFoundError: If no file exists for :obj:`key`.
        '''
        # Security: unpickling runs arbitrary code; only read trusted folders.
        with open(os.path.join(self.io_path, key), 'rb') as f:
            eeg = pickle.load(f)

        return eeg

    def keys(self):
        r'''
        Get all keys in the PickleEEGSignalIO.

        Returns:
            list: The file names in the folder, in :func:`os.listdir` order
            (which is arbitrary).
        '''
        return os.listdir(self.io_path)

    def eegs(self):
        r'''
        Get all EEG signals in the PickleEEGSignalIO.

        Returns:
            list: The list of EEG signals, in the same order as :meth:`keys`.
        '''
        return [self.read_eeg(key) for key in self.keys()]

    def __copy__(self):
        # Shallow copy: the new instance points at the same folder.
        cls = self.__class__
        result = cls.__new__(cls)
        result.__dict__.update(self.__dict__)
        return result


class EEGSignalIO:
    r'''
    A general-purpose, lightweight and efficient EEG signal IO API for
    converting various real-world EEG signal datasets into samples and storing
    them in the database. Here, we draw on the implementation ideas of the
    industrial-grade application Caffe, and encapsulate a set of EEG signal
    reading and writing methods based on Lightning Memory-Mapped Database
    (LMDB), which not only unifies the differences of data types in different
    databases, but also accelerates the reading of data during training and
    testing.

    .. code-block:: python

        eeg_io = EEGSignalIO('YOUR_PATH')
        key = eeg_io.write_eeg(np.random.randn(32, 128))
        eeg = eeg_io.read_eeg(key)
        eeg.shape
        >>> (32, 128)

    Args:
        io_path (str): Where the database is stored.
        io_size (int, optional): The initial capacity of the database in bytes.
            It will increase according to the size of the dataset.
            (default: :obj:`1048576`)
        io_mode (str): Storage mode of EEG signal. When io_mode is set to
            :obj:`lmdb`, TorchEEG provides an efficient database (LMDB) for
            storing EEG signals. LMDB may not perform well on limited operating
            systems. Here, a file system based and a memory based EEG signal
            storages are also provided. When io_mode is set to :obj:`pickle`,
            pickle-based persistence files are used. When io_mode is set to
            :obj:`memory`, memory is used. (default: :obj:`lmdb`)

    Raises:
        RuntimeError: If :obj:`io_mode` is not one of lmdb, pickle or memory.
    '''

    def __init__(self,
                 io_path: str,
                 io_size: int = 1048576,
                 io_mode: str = 'lmdb') -> None:
        self.io_path = io_path
        self.io_size = io_size
        self.io_mode = io_mode

        if self.io_mode == 'lmdb':
            self._io = LMDBEEGSignalIO(io_path=self.io_path,
                                       io_size=self.io_size)
        elif self.io_mode == 'pickle':
            self._io = PickleEEGSignalIO(io_path=self.io_path)
        elif self.io_mode == 'memory':
            self._io = MemoryEEGSignalIO()
        else:
            raise RuntimeError(
                f'Unsupported io_mode {self.io_mode}, please choose from lmdb, pickle and memory.'
            )

    def __del__(self):
        # ``_io`` is absent when __init__ raised (e.g. unsupported io_mode);
        # an unconditional ``del self._io`` would raise AttributeError here.
        self.__dict__.pop('_io', None)

    def __copy__(self):
        # Copy the metadata and ask the backend for its own copy.
        cls = self.__class__
        result = cls.__new__(cls)
        result.__dict__.update({
            k: v
            for k, v in self.__dict__.items() if k != '_io'
        })
        result._io = self._io.__copy__()
        return result

    def __len__(self):
        return len(self._io)

    def write_eeg(self,
                  eeg: Union[any, torch.Tensor],
                  key: Union[str, None] = None) -> str:
        r'''
        Write EEG signal to database.

        Args:
            eeg (any): EEG signal samples to be written into the database.
            key (str, optional): The key of the EEG signal to be inserted, if not
                specified, it will be an auto-incrementing integer (as a string).

        Returns:
            str: The key under which the sample was stored.
        '''
        return self._io.write_eeg(eeg=eeg, key=key)

    def read_eeg(self, key: str) -> any:
        r'''
        Query the corresponding EEG signal in the database according to the index.

        Args:
            key (str): The index of the EEG signal to be queried.

        Returns:
            any: The EEG signal sample.
        '''
        return self._io.read_eeg(key)

    def keys(self):
        r'''
        Get all keys in the EEGSignalIO.

        Returns:
            list: The list of keys in the EEGSignalIO.
        '''
        return self._io.keys()

    def eegs(self):
        r'''
        Get all EEG signals in the EEGSignalIO.

        Returns:
            list: The list of EEG signals in the EEGSignalIO.
        '''
        return self._io.eegs()

    def _copy_into(self, target_io) -> None:
        # Copy every (key, sample) pair from the current backend into target_io.
        for key in self.keys():
            target_io.write_eeg(self.read_eeg(key=key), key=key)

    def to_lmdb(self, io_path: str, io_size: int = 1048576):
        r'''
        Convert to the LMDBEEGSignalIO, where the index of each sample in the
        database corresponds to the key, and the EEG signal stored in the
        database corresponds to the value.

        Metadata (io_path, io_size, io_mode) is only updated after every sample
        has been copied, so a failed conversion leaves this object unchanged.
        '''
        _io = LMDBEEGSignalIO(io_path=io_path, io_size=io_size)
        self._copy_into(_io)
        self.io_path = io_path
        self.io_size = io_size
        self.io_mode = 'lmdb'
        self._io = _io

    def to_pickle(self, io_path: str):
        r'''
        Convert to the PickleEEGSignalIO, where the index of each sample in the
        database corresponds to the key, and the EEG signal stored in the
        database corresponds to the value.

        Metadata (io_path, io_mode) is only updated after every sample has been
        copied, so a failed conversion leaves this object unchanged.
        '''
        _io = PickleEEGSignalIO(io_path=io_path)
        self._copy_into(_io)
        self.io_path = io_path
        self.io_mode = 'pickle'
        self._io = _io

    def to_memory(self):
        r'''
        Convert to the MemoryEEGSignalIO, where the index of each sample in the
        database corresponds to the key, and the EEG signal stored in the
        database corresponds to the value.

        io_mode is only updated after every sample has been copied, so a failed
        conversion leaves this object unchanged.
        '''
        _io = MemoryEEGSignalIO()
        self._copy_into(_io)
        self.io_mode = 'memory'
        self._io = _io

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources