from abc import ABCMeta, abstractmethod
from typing import Dict, Optional, Tuple

import numpy as np
from scipy.signal import butter, lfilter, welch
class BandTransform(metaclass=ABCMeta):
    r'''
    Base class for sub-band feature transforms.

    Each electrode's signal is band-pass filtered (Butterworth, applied with
    ``lfilter``) into every band of ``band_dict``, and each filtered signal is
    reduced by the subclass-defined :meth:`opt`.

    Args:
        frequency (int): The sample frequency in Hz. (default: :obj:`128`)
        order (int): The order of the Butterworth band-pass filter. (default: :obj:`5`)
        band_dict (dict): Band name mapped to its ``[low, high]`` critical
            frequencies in Hz. When omitted, the theta ``[4, 8]``, alpha
            ``[8, 14]``, beta ``[14, 31]`` and gamma ``[31, 49]`` bands are
            used. (default: :obj:`None`)
    '''
    def __init__(self,
                 frequency: int = 128,
                 order: int = 5,
                 band_dict: Optional[Dict[str, Tuple[int, int]]] = None):
        self.frequency = frequency
        self.order = order
        # Build the default per instance: a dict literal as the default value
        # would be one object shared by every instance, so mutating one
        # transform's bands would silently change all the others.
        if band_dict is None:
            band_dict = {
                "theta": [4, 8],
                "alpha": [8, 14],
                "beta": [14, 31],
                "gamma": [31, 49]
            }
        self.band_dict = band_dict

    def __call__(self, x: np.ndarray) -> np.ndarray:
        r'''
        Args:
            x (np.ndarray): The input EEG signals in shape of [number of electrodes, number of data points].

        Returns:
            np.ndarray: The :meth:`opt` result for every electrode, stacked
            with the bands on the last axis. When :meth:`opt` returns a scalar,
            the shape is [number of electrodes, number of subbands].
        '''
        band_list = []
        for low, high in self.band_dict.values():
            # The filter depends only on the band, not on the electrode, so
            # design it once per band instead of once per channel.
            b, a = butter(self.order, [low, high], fs=self.frequency, btype="band")
            band_list.append(np.array([self.opt(lfilter(b, a, c)) for c in x]))
        return np.stack(band_list, axis=-1)

    @abstractmethod
    def opt(self, x: np.ndarray) -> np.ndarray:
        r'''Reduce one band-pass filtered electrode signal to its feature value.'''
        ...
class BandDifferentialEntropy(BandTransform):
    r'''
    A transform method for calculating the differential entropy of EEG signals in several sub-bands with EEG signals as input.

    Each band-pass filtered signal is treated as Gaussian, whose differential
    entropy is :math:`\frac{1}{2}\log_2(2 \pi e \sigma^2)` bits.

    .. code-block:: python

        transform = BandDifferentialEntropy()
        transform(np.random.randn(32, 128)).shape
        >>> (32, 4)

    Args:
        frequency (int): The sample frequency in Hz. (default: :obj:`128`)
        order (int): The order of the filter. (default: :obj:`5`)
        band_dict (dict): Band name and the critical frequency or frequencies. By default, the differential entropy of the four subbands, theta, alpha, beta and gamma, is calculated. (default: :obj:`{...}`)

    .. automethod:: __call__
    '''
    def __call__(self, x: np.ndarray) -> np.ndarray:
        r'''
        Args:
            x (np.ndarray): The input EEG signals in shape of [number of electrodes, number of data points].

        Returns:
            np.ndarray[number of electrodes, number of subbands]: The differential entropy of several subbands for all electrodes.
        '''
        return super().__call__(x)

    def opt(self, x: np.ndarray) -> np.ndarray:
        # Gaussian differential entropy is 1/2 * log(2*pi*e*sigma^2): the
        # variance, not the standard deviation, belongs inside the logarithm.
        # log2 expresses the result in bits.
        return 1 / 2 * np.log2(2 * np.pi * np.e * np.var(x))

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}()"
class BandPowerSpectralDensity:
    r'''
    A transform method for calculating the power spectral density of EEG signals in several sub-bands with EEG signals as input.

    The spectrum of every electrode is estimated with Welch's method, and the
    density is averaged over the frequency bins of each band.

    .. code-block:: python

        transform = BandPowerSpectralDensity()
        transform(np.random.randn(32, 128)).shape
        >>> (32, 4)

    Args:
        frequency (int): The sample frequency in Hz. (default: :obj:`128`)
        window_size (int): Welch's method estimates the power spectral density by averaging over overlapping segments; ``window_size`` is the length of each segment in samples. (default: :obj:`128`)
        order (int): Accepted for interface parity with the filter-based band transforms; this transform applies no filter and does not use it. (default: :obj:`5`)
        band_dict (dict): Band name mapped to its ``[low, high]`` frequency range in Hz. When omitted, the power spectral density of the four subbands, theta ``[4, 8]``, alpha ``[8, 14]``, beta ``[14, 31]`` and gamma ``[31, 49]``, is calculated. (default: :obj:`None`)

    .. automethod:: __call__
    '''
    def __init__(self,
                 frequency: int = 128,
                 window_size: int = 128,
                 order: int = 5,
                 band_dict: Optional[Dict[str, Tuple[int, int]]] = None):
        self.frequency = frequency
        self.window_size = window_size
        self.order = order
        # Build the default per instance; a dict literal default would be a
        # single mutable object shared by every instance.
        if band_dict is None:
            band_dict = {
                "theta": [4, 8],
                "alpha": [8, 14],
                "beta": [14, 31],
                "gamma": [31, 49]
            }
        self.band_dict = band_dict

    def __call__(self, x: np.ndarray) -> np.ndarray:
        r'''
        Args:
            x (np.ndarray): The input EEG signals in shape of [number of electrodes, number of data points].

        Returns:
            np.ndarray[number of electrodes, number of subbands]: The mean power spectral density of each subband for all electrodes. A band that contains no frequency bin yields NaN.
        '''
        # The spectrum does not depend on the band, so estimate it once for
        # all electrodes (along the time axis) rather than once per band.
        freqs, psd = welch(x, self.frequency, nperseg=self.window_size,
                           scaling='density', axis=-1)
        band_list = []
        for low, high in self.band_dict.values():
            band = self._band_slice(freqs, low, high)
            band_list.append(psd[..., band].mean(axis=-1))
        return np.stack(band_list, axis=-1)

    @staticmethod
    def _band_slice(freqs: np.ndarray, low: float, high: float) -> slice:
        r'''Return the slice of ``freqs`` bins that make up the band ``[low, high]`` (frequencies rounded to whole Hz).'''
        rounded = np.round(freqs)
        above_low = rounded > low
        above_high = rounded > high
        # Start at the last bin whose rounded frequency is <= low (one before
        # the first bin above it). Clamp at 0 so a low edge below every bin
        # cannot produce -1, which would wrap to the end of the spectrum.
        start = max(int(np.argmax(above_low)) - 1, 0) if above_low.any() else len(freqs)
        # np.argmax returns 0 when no bin exceeds ``high`` (the band reaches
        # the Nyquist frequency or beyond). Treat that as "through the last
        # bin" instead of an empty slice.
        stop = int(np.argmax(above_high)) if above_high.any() else len(freqs)
        return slice(start, stop)
class BandMeanAbsoluteDeviation(BandTransform):
    r'''
    A transform method for calculating the mean absolute deviation of EEG signals in several sub-bands with EEG signals as input.

    For each band-pass filtered signal the feature is the average absolute
    distance of the samples from their mean.

    .. code-block:: python

        transform = BandMeanAbsoluteDeviation()
        transform(np.random.randn(32, 128)).shape
        >>> (32, 4)

    Args:
        frequency (int): The sample frequency in Hz. (default: :obj:`128`)
        order (int): The order of the filter. (default: :obj:`5`)
        band_dict (dict): Band name and the critical frequency or frequencies. By default, the mean absolute deviation of the four subbands, theta, alpha, beta and gamma, is calculated. (default: :obj:`{...}`)

    .. automethod:: __call__
    '''
    def __call__(self, x: np.ndarray) -> np.ndarray:
        r'''
        Args:
            x (np.ndarray): The input EEG signals in shape of [number of electrodes, number of data points].

        Returns:
            np.ndarray[number of electrodes, number of subbands]: The mean absolute deviation of several subbands for all electrodes.
        '''
        return super().__call__(x)

    def opt(self, x: np.ndarray) -> np.ndarray:
        centered = x - x.mean()
        return np.abs(centered).mean()

    def __repr__(self) -> str:
        return f"{type(self).__name__}()"
class BandKurtosis(BandTransform):
    r'''
    A transform method for calculating the kurtosis of EEG signals in several sub-bands with EEG signals as input.

    The feature is the (non-excess, Pearson) population kurtosis
    :math:`m_4 / m_2^2`, where :math:`m_k` is the k-th central moment of the
    band-pass filtered signal. A Gaussian signal scores about 3.

    .. code-block:: python

        transform = BandKurtosis()
        transform(np.random.randn(32, 128)).shape
        >>> (32, 4)

    Args:
        frequency (int): The sample frequency in Hz. (default: :obj:`128`)
        order (int): The order of the filter. (default: :obj:`5`)
        band_dict (dict): Band name and the critical frequency or frequencies. By default, the kurtosis of the four subbands, theta, alpha, beta and gamma, is calculated. (default: :obj:`{...}`)

    .. automethod:: __call__
    '''
    def __call__(self, x: np.ndarray) -> np.ndarray:
        r'''
        Args:
            x (np.ndarray): The input EEG signals in shape of [number of electrodes, number of data points].

        Returns:
            np.ndarray[number of electrodes, number of subbands]: The kurtosis of several subbands for all electrodes.
        '''
        return super().__call__(x)

    def opt(self, x: np.ndarray) -> np.ndarray:
        x = np.asarray(x, dtype=np.float64)
        # Kurtosis uses the *central* fourth moment E[(x - mean)^4]. Dividing
        # the raw moment E[x^4] by sigma^4 is only correct for zero-mean data.
        # Centering first (two-pass) also avoids the cancellation in
        # E[x^2] - mean^2.
        deviation = x - x.mean()
        variance = np.mean(deviation ** 2)
        # A constant signal (variance 0) yields inf/nan, as before.
        return np.mean(deviation ** 4) / variance ** 2

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}()"
class BandSkewness(BandTransform):
    r'''
    A transform method for calculating the skewness of EEG signals in several sub-bands with EEG signals as input.

    The feature is the population skewness :math:`m_3 / m_2^{3/2}`, where
    :math:`m_k` is the k-th central moment of the band-pass filtered signal.

    .. code-block:: python

        transform = BandSkewness()
        transform(np.random.randn(32, 128)).shape
        >>> (32, 4)

    Args:
        frequency (int): The sample frequency in Hz. (default: :obj:`128`)
        order (int): The order of the filter. (default: :obj:`5`)
        band_dict (dict): Band name and the critical frequency or frequencies. By default, the skewness of the four subbands, theta, alpha, beta and gamma, is calculated. (default: :obj:`{...}`)

    .. automethod:: __call__
    '''
    def __call__(self, x: np.ndarray) -> np.ndarray:
        r'''
        Args:
            x (np.ndarray): The input EEG signals in shape of [number of electrodes, number of data points].

        Returns:
            np.ndarray[number of electrodes, number of subbands]: The skewness of several subbands for all electrodes.
        '''
        return super().__call__(x)

    def opt(self, x: np.ndarray) -> np.ndarray:
        x = np.asarray(x, dtype=np.float64)
        # Two-pass central moments. This is the same quantity as
        # (E[x^3] - 3*mean*sigma^2 - mean^3) / sigma^3, but it avoids the
        # catastrophic cancellation of the one-pass E[x^2] - mean^2 variance.
        deviation = x - x.mean()
        variance = np.mean(deviation ** 2)
        # A constant signal (variance 0) yields inf/nan, as before.
        return np.mean(deviation ** 3) / variance ** 1.5

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}()"