Shortcuts

Source code for torcheeg.io.eeg_signal

import pickle
import os
from typing import Union

import torch
import lmdb


class _EEGSignalIO:

    @property
    def write_pointer(self):
        return len(self)

    def keys(self):
        raise NotImplementedError

    def eegs(self):
        raise NotImplementedError

    def read_eeg(self, key: str) -> any:
        raise NotImplementedError

    def write_eeg(self,
                  eeg: Union[any, torch.Tensor],
                  key: Union[str, None] = None) -> str:
        raise NotImplementedError


class MemoryEEGSignalIO(_EEGSignalIO):
    r'''
    EEG signal storage backed by a plain in-process dictionary.

    Nothing is persisted to disk. Shallow copies made with :func:`copy.copy`
    share the same underlying dictionary, so writes through one are visible
    through the other.
    '''

    def __init__(self):
        # key (str) -> stored EEG signal object
        self._memory = {}

    def __len__(self):
        return len(self._memory)

    def keys(self):
        r'''
        Get all keys in the MemoryEEGSignalIO.

        Returns:
            list: The list of keys in the MemoryEEGSignalIO, in dictionary
            (first-insertion) order.
        '''
        return list(self._memory.keys())

    def eegs(self):
        r'''
        Get all EEG signals in the MemoryEEGSignalIO.

        Returns:
            list: The list of EEG signals, in the same order as :meth:`keys`.
        '''
        return list(self._memory.values())

    def read_eeg(self, key: str) -> any:
        r'''
        Look up the EEG signal stored in memory under the given :obj:`key`.

        Args:
            key (str): The index of the EEG signal to be queried.

        Returns:
            any: The EEG signal sample (the stored object itself, not a copy).

        Raises:
            RuntimeError: If no EEG signal is stored under :obj:`key`.
        '''
        try:
            return self._memory[key]
        except KeyError:
            # Same exception type/message as the other backends' missing-key path.
            raise RuntimeError(
                f'Unable to index the EEG signal sample with key {key}!'
            ) from None

    def write_eeg(self,
                  eeg: Union[any, torch.Tensor],
                  key: Union[str, None] = None) -> str:
        r'''
        Write EEG signal to memory.

        Args:
            eeg (any): EEG signal samples to be written into memory.
            key (str, optional): The key of the EEG signal to be inserted. If
                not specified, it will be an auto-incrementing integer
                (``len(self)``) converted to a string. An existing entry with
                the same key is overwritten.

        Returns:
            str: The key under which the EEG signal was stored.

        Raises:
            RuntimeError: If :obj:`eeg` is ``None``.
        '''
        if key is None:
            # NOTE: the auto key is derived from the current size, so it can
            # coincide with (and overwrite) an explicitly chosen numeric key.
            key = str(self.write_pointer)

        if eeg is None:
            raise RuntimeError(f'Save None to the memory with the key {key}!')

        self._memory[key] = eeg

        return key

    def __copy__(self):
        r'''
        Shallow copy: the new instance shares this instance's ``_memory``
        dictionary, so both see the same stored samples.
        '''
        cls = self.__class__
        result = cls.__new__(cls)
        result.__dict__.update(self.__dict__)
        return result


class LMDBEEGSignalIO(_EEGSignalIO):
    r'''
    EEG signal storage backed by an LMDB environment located at ``io_path``.

    Each sample is pickled and stored under its UTF-8 encoded key. The
    environment is opened with ``lock=False``; per the py-lmdb documentation
    that disables LMDB's own locking, so concurrent writers must be
    coordinated by the caller.

    Args:
        io_path (str): Directory of the LMDB environment (created if missing).
        io_size (int): Initial map size in bytes. It is doubled whenever a
            write raises :class:`lmdb.MapFullError`. (default: :obj:`1048576`)
    '''

    def __init__(self, io_path: str, io_size: int = 1048576) -> None:
        self.io_path = io_path
        self.io_size = io_size

        os.makedirs(self.io_path, exist_ok=True)
        self._env = self._open_env()

    def _open_env(self):
        # Open (or re-open) the environment using the current path and map size.
        return lmdb.open(path=self.io_path,
                         map_size=self.io_size,
                         lock=False)

    def __del__(self):
        # `_env` is absent if __init__ failed before opening the environment;
        # there is nothing to close in that case.
        env = getattr(self, '_env', None)
        if env is not None:
            env.close()

    def __len__(self):
        with self._env.begin(write=False) as transaction:
            return transaction.stat()['entries']

    def write_eeg(self,
                  eeg: Union[any, torch.Tensor],
                  key: Union[str, None] = None) -> str:
        r'''
        Write EEG signal to database.

        If the write does not fit into the current map size, the map size is
        doubled and the write retried until it succeeds.

        Args:
            eeg (any): EEG signal samples to be written into the database.
            key (str, optional): The key of the EEG signal to be inserted, if
                not specified, it will be an auto-incrementing integer
                (``len(self)``) converted to a string.

        Returns:
            str: The key under which the EEG signal was stored.

        Raises:
            RuntimeError: If :obj:`eeg` is ``None``.
        '''
        if key is None:
            key = str(self.write_pointer)

        if eeg is None:
            raise RuntimeError(f'Save None to the LMDB with the key {key}!')

        # Serialize once; the retry loop below reuses the same payload.
        payload = pickle.dumps(eeg)
        while True:
            try:
                with self._env.begin(write=True) as transaction:
                    transaction.put(key.encode(), payload)
            except lmdb.MapFullError:
                # Leaving the `with` aborted the failed transaction, so no
                # transaction is active and the map can be resized safely.
                self.io_size = self.io_size * 2
                self._env.set_mapsize(self.io_size)
            else:
                return key

    def read_eeg(self, key: str) -> any:
        r'''
        Query the corresponding EEG signal in the database according to the index.

        Args:
            key (str): The index of the EEG signal to be queried.

        Returns:
            any: The EEG signal sample.

        Raises:
            RuntimeError: If no EEG signal is stored under :obj:`key`.
        '''
        with self._env.begin(write=False) as transaction:
            eeg = transaction.get(key.encode())

        if eeg is None:
            raise RuntimeError(
                f'Unable to index the EEG signal sample with key {key}!')

        # SECURITY: unpickling can execute arbitrary code; only open databases
        # from trusted sources.
        return pickle.loads(eeg)

    def keys(self):
        r'''
        Get all keys in the LMDBEEGSignalIO.

        Returns:
            list: The list of keys in the LMDBEEGSignalIO, in the database's
            cursor order.
        '''
        with self._env.begin(write=False) as transaction:
            return [
                key.decode()
                for key in transaction.cursor().iternext(keys=True,
                                                         values=False)
            ]

    def eegs(self):
        r'''
        Get all EEG signals in the LMDBEEGSignalIO.

        Returns:
            list: The list of EEG signals, in the same order as :meth:`keys`.
        '''
        # One read transaction and one cursor pass instead of a lookup per key.
        with self._env.begin(write=False) as transaction:
            return [
                pickle.loads(value)
                for value in transaction.cursor().iternext(keys=False,
                                                           values=True)
            ]

    def __getstate__(self):
        # The LMDB environment handle cannot be pickled (e.g. when sending this
        # object to worker processes); drop it and re-open in __setstate__.
        state = self.__dict__.copy()
        del state['_env']
        return state

    def __setstate__(self, state):
        # Restore plain attributes, then re-open the environment in this process.
        self.__dict__.update(state)
        self._env = self._open_env()

    def __copy__(self):
        r'''
        Copy that opens its own environment handle on the same ``io_path``;
        both objects therefore read and write the same on-disk database.
        '''
        cls = self.__class__
        result = cls.__new__(cls)
        result.__dict__.update({
            k: v
            for k, v in self.__dict__.items() if k != '_env'
        })
        result._env = result._open_env()
        return result


class PickleEEGSignalIO(_EEGSignalIO):
    r'''
    EEG signal storage backed by a folder of pickle files, one file per key.

    Each sample is stored in ``os.path.join(io_path, key)``, i.e. the file name
    is the key itself.

    Args:
        io_path (str): The folder holding the pickle files (created if missing).
    '''

    def __init__(self, io_path: str) -> None:
        self.io_path = io_path

        os.makedirs(self.io_path, exist_ok=True)

    def _path(self, key: str) -> str:
        # File that holds the sample for `key`.
        # NOTE(review): `key` is joined verbatim, so keys containing path
        # separators (e.g. '../x') would resolve outside `io_path`.
        return os.path.join(self.io_path, key)

    def __len__(self):
        # Counts every entry in the folder, including any non-sample files.
        return len(os.listdir(self.io_path))

    def write_eeg(self,
                  eeg: Union[any, torch.Tensor],
                  key: Union[str, None] = None) -> str:
        r'''
        Write EEG signal to folder.

        Args:
            eeg (any): EEG signal samples to be written into the folder.
            key (str, optional): The key of the EEG signal to be inserted, if
                not specified, it will be an auto-incrementing integer
                (``len(self)``) converted to a string. An existing file with
                the same key is overwritten.

        Returns:
            str: The key under which the EEG signal was stored.

        Raises:
            RuntimeError: If :obj:`eeg` is ``None``.
        '''
        if key is None:
            key = str(self.write_pointer)

        if eeg is None:
            raise RuntimeError(f'Save None to the folder with the key {key}!')

        with open(self._path(key), 'wb') as f:
            pickle.dump(eeg, f)

        return key

    def read_eeg(self, key: str) -> any:
        r'''
        Query the corresponding EEG signal in the folder according to the index.

        Args:
            key (str): The index of the EEG signal to be queried.

        Returns:
            any: The EEG signal sample.

        Raises:
            FileNotFoundError: If no file exists for :obj:`key` (note: the
                other backends raise :class:`RuntimeError` in this case).
        '''
        # SECURITY: unpickling can execute arbitrary code; only read folders
        # from trusted sources.
        with open(self._path(key), 'rb') as f:
            return pickle.load(f)

    def keys(self):
        r'''
        Get all keys in the PickleEEGSignalIO.

        Returns:
            list: The list of keys (file names) in the PickleEEGSignalIO,
            sorted so the order is deterministic (``os.listdir`` order is
            arbitrary).
        '''
        return sorted(os.listdir(self.io_path))

    def eegs(self):
        r'''
        Get all EEG signals in the PickleEEGSignalIO.

        Returns:
            list: The list of EEG signals, in the same order as :meth:`keys`.
        '''
        return [self.read_eeg(key) for key in self.keys()]

    def __copy__(self):
        r'''
        Shallow copy pointing at the same folder; both objects read and write
        the same files.
        '''
        cls = self.__class__
        result = cls.__new__(cls)
        result.__dict__.update(self.__dict__)
        return result


class EEGSignalIO:
    r'''
    A general-purpose, lightweight and efficient EEG signal IO API for converting various real-world EEG signal datasets into samples and storing them in the database. Here, we draw on the implementation ideas of the industrial-grade application Caffe, and encapsulate a set of EEG signal reading and writing methods based on the Lightning Memory-Mapped Database (LMDB), which not only unifies the differences of data types in different databases, but also accelerates the reading of data during training and testing.

    .. code-block:: python

        eeg_io = EEGSignalIO('YOUR_PATH')
        key = eeg_io.write_eeg(np.random.randn(32, 128))
        eeg = eeg_io.read_eeg(key)
        eeg.shape
        >>> (32, 128)

    Args:
        io_path (str): Where the database is stored. Not used when :obj:`io_mode` is :obj:`memory`.
        io_size (int, optional): The initial maximum capacity of the database in bytes, used by the :obj:`lmdb` mode. It will increase according to the size of the dataset. (default: :obj:`1048576`)
        io_mode (str): Storage mode of EEG signal. When io_mode is set to :obj:`lmdb`, TorchEEG provides an efficient database (LMDB) for storing EEG signals. LMDB may not perform well on limited operating systems. Here, a file system based and a memory based EEG signal storages are also provided. When io_mode is set to :obj:`pickle`, pickle-based persistence files are used. When io_mode is set to :obj:`memory`, memory are used. (default: :obj:`lmdb`)

    Raises:
        RuntimeError: If :obj:`io_mode` is not one of :obj:`lmdb`, :obj:`pickle` or :obj:`memory`.
    '''

    def __init__(self,
                 io_path: str,
                 io_size: int = 1048576,
                 io_mode: str = 'lmdb') -> None:
        self.io_path = io_path
        self.io_size = io_size
        self.io_mode = io_mode

        if self.io_mode == 'lmdb':
            self._io = LMDBEEGSignalIO(io_path=self.io_path,
                                       io_size=self.io_size)
        elif self.io_mode == 'pickle':
            self._io = PickleEEGSignalIO(io_path=self.io_path)
        elif self.io_mode == 'memory':
            self._io = MemoryEEGSignalIO()
        else:
            raise RuntimeError(
                f'Unsupported io_mode {self.io_mode}, please choose from lmdb, pickle and memory.'
            )

    def __del__(self):
        # Drop the backend reference so it can release its resources. `_io` is
        # missing when __init__ raised for an unsupported io_mode.
        self.__dict__.pop('_io', None)

    def __copy__(self):
        # Copy the plain attributes and delegate the backend copy to the
        # backend's own __copy__.
        cls = self.__class__
        result = cls.__new__(cls)
        result.__dict__.update({
            k: v
            for k, v in self.__dict__.items() if k != '_io'
        })
        result._io = self._io.__copy__()
        return result

    def __len__(self):
        return len(self._io)

    def write_eeg(self,
                  eeg: Union[any, torch.Tensor],
                  key: Union[str, None] = None) -> str:
        r'''
        Write EEG signal to database.

        Args:
            eeg (any): EEG signal samples to be written into the database.
            key (str, optional): The key of the EEG signal to be inserted, if not specified, it will be an auto-incrementing integer converted to a string.

        Returns:
            str: The key under which the EEG signal was stored.
        '''
        return self._io.write_eeg(eeg=eeg, key=key)

    def read_eeg(self, key: str) -> any:
        r'''
        Query the corresponding EEG signal in the database according to the index.

        Args:
            key (str): The index of the EEG signal to be queried.

        Returns:
            any: The EEG signal sample.
        '''
        return self._io.read_eeg(key)

    def keys(self):
        r'''
        Get all keys in the EEGSignalIO.

        Returns:
            list: The list of keys in the EEGSignalIO.
        '''
        return self._io.keys()

    def eegs(self):
        r'''
        Get all EEG signals in the EEGSignalIO.

        Returns:
            list: The list of EEG signals in the EEGSignalIO.
        '''
        return self._io.eegs()

    def to_lmdb(self, io_path: str, io_size: int = 1048576):
        r'''
        Convert to the LMDBEEGSignalIO, where the index of each sample in the database corresponds to the key, and the EEG signal stored in the database corresponds to the value.

        The wrapper's attributes are updated only after every sample has been copied, so a failed conversion leaves this object on its original backend.
        '''
        _io = LMDBEEGSignalIO(io_path=io_path, io_size=io_size)
        for key in self.keys():
            _io.write_eeg(self.read_eeg(key=key), key=key)
        self._io = _io
        self.io_path = io_path
        self.io_size = io_size
        self.io_mode = 'lmdb'

    def to_pickle(self, io_path: str):
        r'''
        Convert to the PickleEEGSignalIO, where the index of each sample in the database corresponds to the key, and the EEG signal stored in the database corresponds to the value.

        The wrapper's attributes are updated only after every sample has been copied, so a failed conversion leaves this object on its original backend.
        '''
        _io = PickleEEGSignalIO(io_path=io_path)
        for key in self.keys():
            _io.write_eeg(self.read_eeg(key=key), key=key)
        self._io = _io
        self.io_path = io_path
        self.io_mode = 'pickle'

    def to_memory(self):
        r'''
        Convert to the MemoryEEGSignalIO, where the index of each sample in the database corresponds to the key, and the EEG signal stored in the database corresponds to the value.

        The wrapper's attributes are updated only after every sample has been copied, so a failed conversion leaves this object on its original backend.
        '''
        _io = MemoryEEGSignalIO()
        for key in self.keys():
            _io.write_eeg(self.read_eeg(key=key), key=key)
        self._io = _io
        self.io_mode = 'memory'
Read the Docs v: latest
Versions
latest
stable
v1.1.1
v1.1.0
v1.0.11
v1.0.10
v1.0.9
v1.0.8.post1
v1.0.8
v1.0.7
v1.0.6
v1.0.4
v1.0.3
v1.0.2
v1.0.1
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources