# Source code for torcheeg.models.transformer.labram
import math
import warnings
from collections import OrderedDict
from functools import partial
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True):
    """Randomly zero whole samples of ``x`` (stochastic depth).

    Args:
        x: Input tensor; the first dimension is treated as the sample axis.
        drop_prob: Probability of zeroing a sample.
        training: The input is returned untouched unless this is truthy.
        scale_by_keep: If True, surviving samples are divided by the keep
            probability (skipped when that probability is 0).

    Returns:
        ``x`` itself when nothing is dropped, otherwise ``x`` multiplied by
        a per-sample mask.
    """
    if not training or drop_prob == 0.:
        return x

    survival = 1 - drop_prob
    # One Bernoulli draw per sample, broadcastable over all remaining dims.
    mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    mask = x.new_empty(mask_shape).bernoulli_(survival)
    if scale_by_keep and survival > 0.0:
        mask.div_(survival)
    return x * mask
def _trunc_normal_(tensor, mean, std, a, b):
    """Fill ``tensor`` in place with samples from a normal truncated to [a, b].

    Sampling draws uniformly between the CDF values of the two bounds,
    maps the draw through the inverse error function, then scales, shifts
    and clamps the result. This helper does not disable autograd itself;
    ``trunc_normal_`` wraps it in ``torch.no_grad()``.

    Returns:
        The same ``tensor`` object, modified in place.
    """
    def norm_cdf(value):
        # Standard normal cumulative distribution function.
        return (1. + math.erf(value / math.sqrt(2.))) / 2.

    mean_far_outside = (mean < a - 2 * std) or (mean > b + 2 * std)
    if mean_far_outside:
        warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
                      "The distribution of values may be incorrect.",
                      stacklevel=2)

    cdf_low = norm_cdf((a - mean) / std)
    cdf_high = norm_cdf((b - mean) / std)

    # Uniform in [2*cdf_low - 1, 2*cdf_high - 1], then erfinv gives a
    # truncated standard normal before scaling and shifting.
    tensor.uniform_(2 * cdf_low - 1, 2 * cdf_high - 1).erfinv_()
    tensor.mul_(std * math.sqrt(2.)).add_(mean)
    # Clamp to guard against values landing just outside the bounds.
    tensor.clamp_(min=a, max=b)
    return tensor


def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    """Initialise ``tensor`` in place from a truncated normal, without autograd.

    Args:
        tensor: Tensor to overwrite.
        mean: Mean of the underlying normal distribution.
        std: Standard deviation of the underlying normal distribution.
        a: Lower truncation bound.
        b: Upper truncation bound.

    Returns:
        The same ``tensor`` object.
    """
    with torch.no_grad():
        filled = _trunc_normal_(tensor, mean, std, a, b)
    return filled
# Channel-name vocabulary. `get_input_chans` maps a name to its position in
# this list plus one (slot 0 is reserved), and `LaBraM.forward_features` uses
# the result to index its `pos_embed` table, so the ORDER of this list is
# significant and must not be changed.
# The hyphenated entries on the last row look like bipolar montage
# derivations -- TODO confirm.
# NOTE(review): this list has more entries than the `128 + 1` position slots
# that `LaBraM` allocates for `pos_embed`; the names at the very end of the
# list therefore map to indices outside that table -- confirm the intended
# table size.
standard_1020 = [
    'FP1', 'FPZ', 'FP2',
    'AF9', 'AF7', 'AF5', 'AF3', 'AF1', 'AFZ', 'AF2', 'AF4', 'AF6', 'AF8', 'AF10',
    'F9', 'F7', 'F5', 'F3', 'F1', 'FZ', 'F2', 'F4', 'F6', 'F8', 'F10',
    'FT9', 'FT7', 'FC5', 'FC3', 'FC1', 'FCZ', 'FC2', 'FC4', 'FC6', 'FT8', 'FT10',
    'T9', 'T7', 'C5', 'C3', 'C1', 'CZ', 'C2', 'C4', 'C6', 'T8', 'T10',
    'TP9', 'TP7', 'CP5', 'CP3', 'CP1', 'CPZ', 'CP2', 'CP4', 'CP6', 'TP8', 'TP10',
    'P9', 'P7', 'P5', 'P3', 'P1', 'PZ', 'P2', 'P4', 'P6', 'P8', 'P10',
    'PO9', 'PO7', 'PO5', 'PO3', 'PO1', 'POZ', 'PO2', 'PO4', 'PO6', 'PO8', 'PO10',
    'O1', 'OZ', 'O2', 'O9', 'CB1', 'CB2',
    'IZ', 'O10', 'T3', 'T5', 'T4', 'T6', 'M1', 'M2', 'A1', 'A2',
    'CFC1', 'CFC2', 'CFC3', 'CFC4', 'CFC5', 'CFC6', 'CFC7', 'CFC8',
    'CCP1', 'CCP2', 'CCP3', 'CCP4', 'CCP5', 'CCP6', 'CCP7', 'CCP8',
    'T1', 'T2', 'FTT9h', 'TTP7h', 'TPP9h', 'FTT10h', 'TPP8h', 'TPP10h',
    "FP1-F7", "F7-T7", "T7-P7", "P7-O1", "FP2-F8", "F8-T8", "T8-P8", "P8-O2", "FP1-F3", "F3-C3", "C3-P3", "P3-O1", "FP2-F4", "F4-C4", "C4-P4", "P4-O2"
]
def get_input_chans(ch_names):
    """Translate channel names into indices for the position-embedding table.

    Args:
        ch_names: Iterable of channel names; each must appear in
            ``standard_1020`` exactly as spelled there.

    Returns:
        A list that starts with ``0`` followed by ``index_in_standard_1020 + 1``
        for every name, in the order given.

    Raises:
        ValueError: If a name is not present in ``standard_1020``.
    """
    # Slot 0 is always emitted first; real channels are shifted up by one.
    return [0] + [standard_1020.index(name) + 1 for name in ch_names]
class DropPath(nn.Module):
    """Module wrapper around :func:`drop_path` (per-sample stochastic depth)."""

    def __init__(self, drop_prob=None):
        super().__init__()
        # Probability handed to `drop_path` on every forward call.
        self.drop_prob = drop_prob

    def forward(self, x):
        """Apply `drop_path` using this module's current training flag."""
        return drop_path(x, self.drop_prob, self.training)

    def extra_repr(self) -> str:
        """Show the drop probability in the module's printed representation."""
        return 'p={}'.format(self.drop_prob)
class Mlp(nn.Module):
    """Two-layer feed-forward network: linear, activation, linear, dropout.

    Args:
        in_features: Size of the input's last dimension.
        hidden_features: Width of the hidden layer; falls back to
            ``in_features`` when falsy.
        out_features: Size of the output's last dimension; falls back to
            ``in_features`` when falsy.
        act_layer: Zero-argument callable that builds the activation module.
        drop: Dropout probability applied after the second linear layer.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_width = out_features if out_features else in_features
        hidden_width = hidden_features if hidden_features else in_features
        self.fc1 = nn.Linear(in_features, hidden_width)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_width, out_width)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        """Run ``x`` through fc1 -> act -> fc2 -> dropout."""
        hidden = self.act(self.fc1(x))
        return self.drop(self.fc2(hidden))
class Attention(nn.Module):
    """Multi-head self-attention with optional q/k normalisation and bias terms.

    Args:
        dim: Size of the input's (and output's) last dimension.
        num_heads: Number of attention heads.
        qkv_bias: If True, learn a bias for queries and values (the key bias
            is fixed at zero).
        qk_norm: Optional callable ``qk_norm(head_dim)`` building the
            normalisation applied to queries and keys per head.
        qk_scale: Query scale; ``head_dim ** -0.5`` is used when falsy.
        attn_drop: Dropout probability on the attention weights.
        proj_drop: Dropout probability after the output projection.
        window_size: Optional ``(h, w)`` pair; when truthy a learned relative
            position bias table for ``h * w + 1`` tokens is created.
        attn_head_dim: Optional override for the per-head width.
    """

    def __init__(
            self, dim, num_heads=8, qkv_bias=False, qk_norm=None, qk_scale=None, attn_drop=0.,
            proj_drop=0., window_size=None, attn_head_dim=None):
        super().__init__()
        self.num_heads = num_heads
        per_head = dim // num_heads
        if attn_head_dim is not None:
            per_head = attn_head_dim
        inner_dim = per_head * self.num_heads
        self.scale = qk_scale or per_head ** -0.5

        # The projection itself carries no bias; q/v biases are separate
        # parameters so that the key bias can stay fixed at zero.
        self.qkv = nn.Linear(dim, inner_dim * 3, bias=False)
        self.q_bias = nn.Parameter(torch.zeros(inner_dim)) if qkv_bias else None
        self.v_bias = nn.Parameter(torch.zeros(inner_dim)) if qkv_bias else None

        self.q_norm = qk_norm(per_head) if qk_norm is not None else None
        self.k_norm = qk_norm(per_head) if qk_norm is not None else None

        if window_size:
            self.window_size = window_size
            win_h, win_w = window_size[0], window_size[1]
            # Three extra entries: first-token->patch, patch->first-token and
            # first-token->first-token (see the index assignments below).
            self.num_relative_distance = (2 * win_h - 1) * (2 * win_w - 1) + 3
            self.relative_position_bias_table = nn.Parameter(
                torch.zeros(self.num_relative_distance, num_heads))

            # Pairwise (row, col) offsets between all positions in the window.
            grid = torch.stack(torch.meshgrid(
                [torch.arange(win_h), torch.arange(win_w)]))
            flat = torch.flatten(grid, 1)
            offsets = flat[:, :, None] - flat[:, None, :]
            offsets = offsets.permute(1, 2, 0).contiguous()
            # Shift offsets to start at 0, then fold (row, col) into one index.
            offsets[:, :, 0] += win_h - 1
            offsets[:, :, 1] += win_w - 1
            offsets[:, :, 0] *= 2 * win_w - 1

            side = win_h * win_w + 1
            index = torch.zeros(size=(side, ) * 2, dtype=offsets.dtype)
            index[1:, 1:] = offsets.sum(-1)
            index[0, 0:] = self.num_relative_distance - 3
            index[0:, 0] = self.num_relative_distance - 2
            index[0, 0] = self.num_relative_distance - 1
            self.register_buffer("relative_position_index", index)
        else:
            self.window_size = None
            self.relative_position_bias_table = None
            self.relative_position_index = None

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(inner_dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, rel_pos_bias=None, return_attention=False, return_qkv=False):
        """Attend over the token axis of ``x``.

        Args:
            x: Tensor of shape ``(B, N, C)``.
            rel_pos_bias: Optional tensor added to the attention logits.
            return_attention: If True, return only the attention weights
                (after softmax and dropout).
            return_qkv: If True, return ``(output, qkv)`` where ``qkv`` is the
                stacked projection of shape ``(3, B, heads, N, head_dim)``.

        Returns:
            The projected output of shape ``(B, N, dim)``, or one of the
            alternatives described above.
        """
        batch, tokens, _ = x.shape

        bias = None
        if self.q_bias is not None:
            key_zeros = torch.zeros_like(self.v_bias, requires_grad=False)
            bias = torch.cat((self.q_bias, key_zeros, self.v_bias))

        qkv = F.linear(input=x, weight=self.qkv.weight, bias=bias)
        qkv = qkv.reshape(batch, tokens, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        if self.q_norm is not None:
            q = self.q_norm(q).type_as(v)
        if self.k_norm is not None:
            k = self.k_norm(k).type_as(v)

        q = q * self.scale
        attn = q @ k.transpose(-2, -1)

        if self.relative_position_bias_table is not None:
            side = self.window_size[0] * self.window_size[1] + 1
            flat_index = self.relative_position_index.view(-1)
            table_bias = self.relative_position_bias_table[flat_index].view(side, side, -1)
            # (side, side, heads) -> (heads, side, side), broadcast over batch.
            table_bias = table_bias.permute(2, 0, 1).contiguous()
            attn = attn + table_bias.unsqueeze(0)

        if rel_pos_bias is not None:
            attn = attn + rel_pos_bias

        attn = self.attn_drop(attn.softmax(dim=-1))
        if return_attention:
            return attn

        out = (attn @ v).transpose(1, 2).reshape(batch, tokens, -1)
        out = self.proj_drop(self.proj(out))
        if return_qkv:
            return out, qkv
        return out
class Block(nn.Module):
    """Pre-norm transformer block: attention and MLP, each with a residual.

    Args:
        dim: Token embedding width.
        num_heads: Number of attention heads.
        mlp_ratio: Hidden width of the MLP as a multiple of ``dim``.
        qkv_bias, qk_norm, qk_scale, attn_drop, window_size, attn_head_dim:
            Forwarded to :class:`Attention`.
        drop: Dropout probability for the attention projection and the MLP.
        drop_path: Stochastic-depth probability; ``nn.Identity`` is used
            when it is not greater than 0.
        init_values: Initial value of the learnable per-channel residual
            scales ``gamma_1`` / ``gamma_2``. ``None`` or a value ``<= 0``
            disables the scaling.
        act_layer: Activation factory for the MLP.
        norm_layer: Callable ``norm_layer(dim)`` building the two norms.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_norm=None, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm,
                 window_size=None, attn_head_dim=None):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_norm=qk_norm, qk_scale=qk_scale,
            attn_drop=attn_drop, proj_drop=drop, window_size=window_size, attn_head_dim=attn_head_dim)
        self.drop_path = DropPath(
            drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim,
                       act_layer=act_layer, drop=drop)

        # `init_values` defaults to None, and `None > 0` raises TypeError on
        # Python 3, so the None case has to be tested explicitly.
        if init_values is not None and init_values > 0:
            self.gamma_1 = nn.Parameter(
                init_values * torch.ones((dim)), requires_grad=True)
            self.gamma_2 = nn.Parameter(
                init_values * torch.ones((dim)), requires_grad=True)
        else:
            self.gamma_1, self.gamma_2 = None, None

    def forward(self, x, rel_pos_bias=None, return_attention=False, return_qkv=False):
        """Apply the block to ``x``.

        Args:
            x: Token tensor passed to ``norm1`` and the attention layer.
            rel_pos_bias: Optional bias forwarded to the attention layer.
            return_attention: If True, return only the attention weights and
                skip the residual/MLP path.
            return_qkv: If True, return ``(x, qkv)`` with the attention
                layer's stacked q/k/v projection.

        Returns:
            The updated tokens, or one of the alternatives described above.
        """
        if return_attention:
            return self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias, return_attention=True)

        qkv = None
        if return_qkv:
            attn_out, qkv = self.attn(
                self.norm1(x), rel_pos_bias=rel_pos_bias, return_qkv=return_qkv)
        else:
            attn_out = self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias)

        # The residual scales are optional; the qkv path previously applied
        # them unconditionally and failed when they were disabled.
        if self.gamma_1 is not None:
            attn_out = self.gamma_1 * attn_out
        x = x + self.drop_path(attn_out)

        mlp_out = self.mlp(self.norm2(x))
        if self.gamma_2 is not None:
            mlp_out = self.gamma_2 * mlp_out
        x = x + self.drop_path(mlp_out)

        if return_qkv:
            return x, qkv
        return x
class TemporalConv(nn.Module):
    """Patch embedding built from three temporal convolutions.

    Each stage is ``Conv2d -> GroupNorm(4 groups) -> GELU``. The first
    convolution strides by 8 along the time axis; the other two keep the
    length. All kernels have height 1, so rows are processed independently.

    Args:
        in_chans: Input channels of the first convolution.
        out_chans: Output channels of every convolution.
    """

    def __init__(self, in_chans=1, out_chans=8):
        super().__init__()
        self.conv1 = nn.Conv2d(in_chans, out_chans, kernel_size=(1, 15),
                               stride=(1, 8), padding=(0, 7))
        self.gelu1 = nn.GELU()
        self.norm1 = nn.GroupNorm(4, out_chans)
        self.conv2 = nn.Conv2d(out_chans, out_chans, kernel_size=(1, 3),
                               padding=(0, 1))
        self.gelu2 = nn.GELU()
        self.norm2 = nn.GroupNorm(4, out_chans)
        self.conv3 = nn.Conv2d(out_chans, out_chans, kernel_size=(1, 3),
                               padding=(0, 1))
        self.norm3 = nn.GroupNorm(4, out_chans)
        self.gelu3 = nn.GELU()

    def forward(self, x, **kwargs):
        """Embed patches.

        Args:
            x: 4-D tensor unpacked as ``(B, N, A, T)`` by the rearrange
                pattern below.
            **kwargs: Accepted and ignored.

        Returns:
            Tensor of shape ``(B, N * A, T' * C)`` where ``C`` is
            ``out_chans`` and ``T'`` is the time length after ``conv1``.
        """
        # Merge the two middle axes into rows, then add a singleton channel
        # axis so that Conv2d sees (B, 1, rows, T).
        merged = rearrange(x, 'B N A T -> B (N A) T').unsqueeze(1)

        stages = (
            (self.conv1, self.norm1, self.gelu1),
            (self.conv2, self.norm2, self.gelu2),
            (self.conv3, self.norm3, self.gelu3),
        )
        feats = merged
        for conv, norm, act in stages:
            feats = act(norm(conv(feats)))

        # Fold the conv channels into the feature axis of each row.
        return rearrange(feats, 'B C NA T -> B NA (T C)')
class LaBraM(nn.Module):
    '''
    Implementation of Large Brain Model (LaBraM) for EEG signal processing.

    - Paper: Jiang W, Zhao L, Lu B. Large Brain Model for Learning Generic Representations with Tremendous EEG Data in BCI[C]//The Twelfth International Conference on Learning Representations.
    - URL: https://openreview.net/forum?id=QzTpTRVtrP
    - Related Project: https://github.com/935963004/LaBraM/

    Below is a quick start example:

    .. code-block:: python

        model = LaBraM.base_patch200_200(num_classes=4)
        # batch_size, num_electrodes, chunk_size // patch_size, patch_size
        x = torch.randn(2, 6, 8, 200)
        model(x, electrodes=['FP1', 'FPZ', 'FP2', 'AF9', 'AF7', 'AF5'])

    Args:
        chunk_size (int): The total length of the EEG signal segment to process. (default: :obj:`1600`)
        patch_size (int): The size of each temporal patch. (default: :obj:`200`)
        out_chans (int): Number of output channels from the temporal convolution. (default: :obj:`8`)
        num_classes (int): Number of classes for classification. (default: :obj:`1000`)
        embed_dim (int): Dimension of the embedding space. (default: :obj:`200`)
        depth (int): Number of transformer layers. (default: :obj:`12`)
        num_heads (int): Number of attention heads in each transformer layer. (default: :obj:`10`)
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. (default: :obj:`4.0`)
        qkv_bias (bool): If True, add a learnable bias to query, key, value. (default: :obj:`False`)
        qk_norm (callable): Normalization layer for query and key. (default: :obj:`nn.LayerNorm` with ``eps=1e-6``)
        qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. (default: :obj:`None`)
        drop_rate (float): Dropout rate. (default: :obj:`0.0`)
        attn_drop_rate (float): Attention dropout rate. (default: :obj:`0.0`)
        drop_path_rate (float): Stochastic depth rate. (default: :obj:`0.0`)
        norm_layer (callable): Normalization layer. (default: :obj:`nn.LayerNorm` with ``eps=1e-6``)
        init_values (float): Initial values for layer scale. (default: :obj:`0.0`)
        use_mean_pooling (bool): If True, use mean pooling for final feature vector. (default: :obj:`True`)
        init_scale (float): Initial scale for the head layer. (default: :obj:`0.001`)
        use_abs_pos_emb (bool): If True, use absolute positional embeddings. (default: :obj:`True`)
        **kwargs: Additional keyword arguments (accepted and ignored).
    '''

    def __init__(self, chunk_size=1600, patch_size=200, out_chans=8, num_classes=1000, embed_dim=200, depth=12,
                 num_heads=10, mlp_ratio=4., qkv_bias=False, qk_norm=partial(nn.LayerNorm, eps=1e-6), qk_scale=None, drop_rate=0., attn_drop_rate=0.,
                 drop_path_rate=0., norm_layer=partial(nn.LayerNorm, eps=1e-6), init_values=0.0,
                 use_mean_pooling=True, init_scale=0.001, use_abs_pos_emb=True, **kwargs):
        super().__init__()
        self.num_classes = num_classes
        self.embed_dim = embed_dim

        self.patch_embed = TemporalConv(out_chans=out_chans)
        self.time_window = chunk_size // patch_size
        self.patch_size = patch_size

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        if use_abs_pos_emb:
            # One slot for the class token plus 128 channel positions.
            self.pos_embed = nn.Parameter(torch.zeros(
                1, 128 + 1, embed_dim), requires_grad=True)
        else:
            self.pos_embed = None
        # Up to 16 temporal patch positions.
        self.time_embed = nn.Parameter(torch.zeros(
            1, 16, embed_dim), requires_grad=True)
        self.pos_drop = nn.Dropout(p=drop_rate)

        self.rel_pos_bias = None

        # Stochastic-depth rate grows linearly from 0 to drop_path_rate.
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
        self.blocks = nn.ModuleList([
            Block(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_norm=qk_norm, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
                init_values=init_values, window_size=None)
            for i in range(depth)])
        # With mean pooling the norm is applied after pooling (fc_norm)
        # instead of on the full token sequence.
        self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim)
        self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None
        self.head = nn.Linear(
            embed_dim, num_classes) if num_classes > 0 else nn.Identity()

        if self.pos_embed is not None:
            trunc_normal_(self.pos_embed, std=.02)
        if self.time_embed is not None:
            trunc_normal_(self.time_embed, std=.02)
        trunc_normal_(self.cls_token, std=.02)
        if isinstance(self.head, nn.Linear):
            trunc_normal_(self.head.weight, std=.02)
        self.apply(self._init_weights)
        self.fix_init_weight()

        if isinstance(self.head, nn.Linear):
            self.head.weight.data.mul_(init_scale)
            self.head.bias.data.mul_(init_scale)

    def fix_init_weight(self):
        """Divide each block's output projections by ``sqrt(2 * layer_number)``."""
        def rescale(param, layer_id):
            param.div_(math.sqrt(2.0 * layer_id))

        for layer_id, layer in enumerate(self.blocks):
            # Layer numbers are 1-based so the first divisor is sqrt(2).
            rescale(layer.attn.proj.weight.data, layer_id + 1)
            rescale(layer.mlp.fc2.weight.data, layer_id + 1)

    def _init_weights(self, m):
        """Initialise one submodule; used with ``self.apply``."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward_features(self, x, electrodes=None, return_patch_tokens=False, return_all_tokens=False, **kwargs):
        """Encode ``x`` into token features (everything except the head).

        Args:
            x: 4-D tensor; the class example uses
                ``(batch_size, num_electrodes, chunk_size // patch_size, patch_size)``.
            electrodes: Sequence of channel names from ``standard_1020``, one
                per entry of ``x.shape[1]``. ``None`` or an empty sequence
                means "no names given", in which case ``x.shape[1]`` must
                equal ``len(standard_1020)``.
            return_patch_tokens: If True, return every token except the
                first (class) token.
            return_all_tokens: If True, return every token.
            **kwargs: Accepted and ignored.

        Returns:
            Pooled features of shape ``(batch, embed_dim)`` by default, or
            the token sequence selected by the two flags.

        Raises:
            AssertionError: If an electrode name is unknown or the number of
                electrodes does not match ``x.shape[1]``.
            IndexError: If an electrode maps beyond the position table.
        """
        input_chans = None
        if electrodes is not None and len(electrodes):
            for electrode in electrodes:
                assert electrode in standard_1020, f"{electrode} not in standard_1020"
            assert len(electrodes) == x.shape[1], f"Number of electrodes {len(electrodes)} should match input ({x.shape[1]})."
            input_chans = get_input_chans(electrodes)
        else:
            # NOTE(review): standard_1020 has more entries than pos_embed has
            # slots, so this path looks usable only with
            # use_abs_pos_emb=False -- confirm the intended behaviour.
            assert len(standard_1020) == x.shape[1], f"You must provide electrodes for the input. Expected default channels {standard_1020} are used."

        batch_size, n, a, t = x.shape
        input_time_window = a if t == self.patch_size else t
        x = self.patch_embed(x)

        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # Index pos_embed only when it exists; doing so unconditionally
        # crashed for use_abs_pos_emb=False combined with electrode names.
        if self.pos_embed is not None:
            if input_chans is not None:
                highest = max(input_chans)
                if highest >= self.pos_embed.shape[1]:
                    raise IndexError(
                        f"Electrode index {highest} is outside the "
                        f"{self.pos_embed.shape[1]} available position embeddings.")
                pos_embed_used = self.pos_embed[:, input_chans]
            else:
                pos_embed_used = self.pos_embed
            # Repeat each channel's embedding once per temporal patch.
            pos_embed = pos_embed_used[:, 1:, :].unsqueeze(2).expand(
                batch_size, -1, input_time_window, -1).flatten(1, 2)
            pos_embed = torch.cat((pos_embed_used[:, 0:1, :].expand(
                batch_size, -1, -1), pos_embed), dim=1)
            x = x + pos_embed
        if self.time_embed is not None:
            nc = n if t == self.patch_size else a
            # Repeat each temporal embedding once per channel; the class
            # token (position 0) receives no temporal embedding.
            time_embed = self.time_embed[:, 0:input_time_window, :].unsqueeze(
                1).expand(batch_size, nc, -1, -1).flatten(1, 2)
            x[:, 1:, :] += time_embed

        x = self.pos_drop(x)

        for blk in self.blocks:
            x = blk(x, rel_pos_bias=None)

        x = self.norm(x)
        if self.fc_norm is not None:
            if return_all_tokens:
                return self.fc_norm(x)
            patch_tokens = x[:, 1:, :]
            if return_patch_tokens:
                return self.fc_norm(patch_tokens)
            return self.fc_norm(patch_tokens.mean(1))

        if return_all_tokens:
            return x
        if return_patch_tokens:
            return x[:, 1:]
        return x[:, 0]

    def forward(self, x, electrodes=None, return_patch_tokens=False, return_all_tokens=False, **kwargs):
        """Run ``forward_features`` and apply the classification head.

        Arguments are forwarded unchanged to :meth:`forward_features`.
        """
        x = self.forward_features(
            x, electrodes=electrodes, return_patch_tokens=return_patch_tokens, return_all_tokens=return_all_tokens, **kwargs)
        x = self.head(x)
        return x

    def load_pretrained(self, checkpoint_path):
        """Load weights whose keys are prefixed with ``student.``.

        The state dict is taken from the checkpoint's ``'model'`` or
        ``'module'`` entry when present, otherwise from the checkpoint
        itself. Only keys starting with ``student.`` are kept (prefix
        removed) and loaded with ``strict=False``.

        Args:
            checkpoint_path: Path handed to ``torch.load``.
        """
        # NOTE(security): torch.load unpickles the file; only load
        # checkpoints from a trusted source.
        checkpoint = torch.load(checkpoint_path, map_location='cpu')

        checkpoint_model = None
        for model_key in ('model', 'module'):
            if model_key in checkpoint:
                checkpoint_model = checkpoint[model_key]
                print("Load state_dict by model_key = %s" % model_key)
                break
        if checkpoint_model is None:
            checkpoint_model = checkpoint

        prefix = 'student.'
        new_dict = OrderedDict(
            (key[len(prefix):], value)
            for key, value in checkpoint_model.items()
            if key.startswith(prefix))
        if not new_dict:
            # Without this warning the call silently loads nothing.
            warnings.warn(
                "No keys starting with 'student.' were found in the "
                "checkpoint; no weights were loaded.",
                stacklevel=2)

        self.load_state_dict(new_dict, strict=False)

    @staticmethod
    def base_patch200_200(**kwargs):
        """Build the base configuration (embed_dim 200, depth 12, 10 heads)."""
        return LaBraM(
            patch_size=200, embed_dim=200, depth=12, num_heads=10, mlp_ratio=4, qk_norm=partial(nn.LayerNorm, eps=1e-6),
            norm_layer=partial(nn.LayerNorm, eps=1e-6), init_values=0.1, **kwargs)

    @staticmethod
    def large_patch200_200(**kwargs):
        """Build the large configuration (embed_dim 400, depth 24, 16 heads)."""
        return LaBraM(
            patch_size=200, embed_dim=400, depth=24, num_heads=16, mlp_ratio=4, out_chans=16, qk_norm=partial(nn.LayerNorm, eps=1e-6),
            norm_layer=partial(nn.LayerNorm, eps=1e-6), init_values=1e-5, **kwargs)

    @staticmethod
    def huge_patch200_200(**kwargs):
        """Build the huge configuration (embed_dim 800, depth 48, 16 heads)."""
        return LaBraM(
            patch_size=200, embed_dim=800, depth=48, num_heads=16, mlp_ratio=4, out_chans=32, qk_norm=partial(nn.LayerNorm, eps=1e-6),
            norm_layer=partial(nn.LayerNorm, eps=1e-6), init_values=1e-5, **kwargs)