LaBraM
- class torcheeg.models.LaBraM(chunk_size=1600, patch_size=200, out_chans=8, num_classes=1000, embed_dim=200, depth=12, num_heads=10, mlp_ratio=4.0, qkv_bias=False, qk_norm=functools.partial(<class 'torch.nn.modules.normalization.LayerNorm'>, eps=1e-06), qk_scale=None, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.0, norm_layer=functools.partial(<class 'torch.nn.modules.normalization.LayerNorm'>, eps=1e-06), init_values=0.0, use_mean_pooling=True, init_scale=0.001, use_abs_pos_emb=True, **kwargs)
Implementation of Large Brain Model (LaBraM) for EEG signal processing.
Paper: Jiang W, Zhao L, Lu B. Large Brain Model for Learning Generic Representations with Tremendous EEG Data in BCI[C]//The Twelfth International Conference on Learning Representations.
Related Project: https://github.com/935963004/LaBraM/
Below is a quick start example:
import torch
from torcheeg.models import LaBraM

model = LaBraM.base_patch200_200(num_classes=4)

# batch_size, num_electrodes, chunk_size // patch_size, patch_size
x = torch.randn(2, 6, 8, 200)
model(x, electrodes=['FP1', 'FPZ', 'FP2', 'AF9', 'AF7', 'AF5'])
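The quick start feeds a tensor that is already split into patches. If your EEG is stored as (batch_size, num_electrodes, chunk_size), a minimal sketch of that split is shown below, assuming non-overlapping patches of consecutive samples (the layout implied by the `chunk_size // patch_size, patch_size` comment). `to_patches` is a hypothetical helper for illustration, not a torcheeg function.

```python
import torch


def to_patches(eeg: torch.Tensor, patch_size: int = 200) -> torch.Tensor:
    """Hypothetical helper (not part of torcheeg).

    Splits raw EEG of shape (batch, electrodes, chunk_size) into
    non-overlapping patches of shape (batch, electrodes, n_patches, patch_size).
    """
    batch, electrodes, chunk_size = eeg.shape
    if chunk_size % patch_size != 0:
        raise ValueError(
            f"chunk_size ({chunk_size}) must be a multiple of patch_size ({patch_size})"
        )
    # Row-major reshape: patch k holds samples [k * patch_size, (k + 1) * patch_size).
    return eeg.reshape(batch, electrodes, chunk_size // patch_size, patch_size)


raw = torch.randn(2, 6, 1600)  # 2 trials, 6 electrodes, 1600 samples each
x = to_patches(raw)
print(tuple(x.shape))  # (2, 6, 8, 200)
```

The resulting `x` has the same shape as the quick start input and can be passed to `model` together with the matching `electrodes` list.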
- Parameters:
chunk_size (int) – The total length of the EEG signal segment to process, in samples. (default: 1600)
patch_size (int) – The number of samples in each temporal patch. (default: 200)
out_chans (int) – Number of output channels of the temporal convolution. (default: 8)
num_classes (int) – Number of classes for classification. (default: 1000)
embed_dim (int) – Dimension of the embedding space. (default: 200)
depth (int) – Number of Transformer layers. (default: 12)
num_heads (int) – Number of attention heads in each Transformer layer. (default: 10)
mlp_ratio (float) – Ratio of the MLP hidden dimension to the embedding dimension. (default: 4.0)
qkv_bias (bool) – If True, add a learnable bias to the query, key, and value projections. (default: False)
qk_norm (callable) – Normalization layer applied to queries and keys. (default: nn.LayerNorm with eps=1e-6)
qk_scale (float) – If set, overrides the default query-key scale of head_dim ** -0.5. (default: None)
drop_rate (float) – Dropout rate. (default: 0.0)
attn_drop_rate (float) – Dropout rate applied to the attention weights. (default: 0.0)
drop_path_rate (float) – Stochastic depth (drop path) rate. (default: 0.0)
norm_layer (callable) – Normalization layer. (default: nn.LayerNorm with eps=1e-6)
init_values (float) – Initial value for layer scale. (default: 0.0)
use_mean_pooling (bool) – If True, use mean pooling to form the final feature vector. (default: True)
init_scale (float) – Initial scale for the classification head. (default: 0.001)
use_abs_pos_emb (bool) – If True, use absolute positional embeddings. (default: True)
**kwargs – Additional keyword arguments.
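Several of these parameters determine sizes that are not listed explicitly. The sketch below computes them for the defaults. It is illustrative only: `derived_sizes` is a hypothetical helper, and the `head_dim` and MLP formulas assume the usual ViT-style layout (head_dim = embed_dim // num_heads, hidden width = int(embed_dim * mlp_ratio)) rather than being taken from the torcheeg source. The `qk_scale` value follows the `head_dim ** -0.5` default described above.

```python
def derived_sizes(chunk_size=1600, patch_size=200, embed_dim=200,
                  num_heads=10, mlp_ratio=4.0):
    """Hypothetical helper: sizes implied by the LaBraM constructor arguments.

    Assumes a ViT-style layout; check the torcheeg source before relying on it.
    """
    num_patches = chunk_size // patch_size    # temporal patches per electrode
    head_dim = embed_dim // num_heads         # width of each attention head (assumed)
    qk_scale = head_dim ** -0.5               # documented default when qk_scale=None
    mlp_hidden = int(embed_dim * mlp_ratio)   # hidden width of each MLP block (assumed)
    return {
        "num_patches": num_patches,
        "head_dim": head_dim,
        "qk_scale": qk_scale,
        "mlp_hidden": mlp_hidden,
    }


sizes = derived_sizes()
print(sizes["num_patches"], sizes["head_dim"], sizes["mlp_hidden"])  # 8 20 800
```

With the defaults, 8 patches per electrode matches the third dimension of the quick start input, and 10 heads of width 20 together span the 200-dimensional embedding.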
- forward(x, electrodes=[], return_patch_tokens=False, return_all_tokens=False, **kwargs)
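Defines the computation performed at every call. x is expected to have shape (batch_size, num_electrodes, chunk_size // patch_size, patch_size), and electrodes gives one electrode name per entry along the num_electrodes dimension, as in the quick start example above.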
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
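To see the difference the note describes, the sketch below registers a forward hook and records each time it fires. It uses a small `nn.Linear` as a stand-in rather than LaBraM so that it runs without torcheeg; the hook behavior is the same for any `nn.Module`. The names `record_call` and `calls` are illustrative.

```python
import torch
from torch import nn

# Stand-in module; hooks behave the same way on any nn.Module, including LaBraM.
model = nn.Linear(4, 2)
calls = []


def record_call(module, inputs, output):
    # Forward hooks receive (module, positional inputs, output).
    calls.append(tuple(output.shape))


handle = model.register_forward_hook(record_call)
x = torch.randn(3, 4)

model(x)          # calling the instance runs the registered hook
model.forward(x)  # calling forward directly skips the hook
print(calls)      # [(3, 2)]

handle.remove()   # detach the hook once it is no longer needed
```

Only one entry is recorded, because the direct `forward` call bypasses the hook machinery.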