
LaBraM

class torcheeg.models.LaBraM(chunk_size=1600, patch_size=200, out_chans=8, num_classes=1000, embed_dim=200, depth=12, num_heads=10, mlp_ratio=4.0, qkv_bias=False, qk_norm=functools.partial(<class 'torch.nn.modules.normalization.LayerNorm'>, eps=1e-06), qk_scale=None, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.0, norm_layer=functools.partial(<class 'torch.nn.modules.normalization.LayerNorm'>, eps=1e-06), init_values=0.0, use_mean_pooling=True, init_scale=0.001, use_abs_pos_emb=True, **kwargs)

Implementation of the Large Brain Model (LaBraM) for EEG signal processing.

Below is a quick start example:

import torch
from torcheeg.models import LaBraM

model = LaBraM.base_patch200_200(num_classes=4)
# input shape: (batch_size, num_electrodes, chunk_size // patch_size, patch_size)
x = torch.randn(2, 6, 8, 200)
model(x, electrodes=['FP1', 'FPZ', 'FP2', 'AF9', 'AF7', 'AF5'])

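The shape comment in the example implies that a raw segment of shape (batch_size, num_electrodes, chunk_size) has to be cut into consecutive patches before it reaches the model. The sketch below does that split with plain Python lists so it runs without torch. The helper names patch_signal and patch_batch are hypothetical, and the sketch assumes non-overlapping patches and a chunk_size that is an exact multiple of patch_size. For a contiguous torch tensor, x.reshape(batch_size, num_electrodes, chunk_size // patch_size, patch_size) gives the same layout.

```python
from typing import List

Signal = List[float]


def patch_signal(signal: Signal, patch_size: int) -> List[Signal]:
    """Split one electrode's samples into consecutive, non-overlapping patches."""
    if patch_size <= 0 or len(signal) % patch_size != 0:
        # Assumption: chunk_size must be an exact multiple of patch_size.
        raise ValueError("signal length must be a multiple of patch_size")
    return [signal[i:i + patch_size] for i in range(0, len(signal), patch_size)]


def patch_batch(batch: List[List[Signal]], patch_size: int) -> List[List[List[Signal]]]:
    """Map (batch, electrodes, chunk_size) to
    (batch, electrodes, chunk_size // patch_size, patch_size)."""
    return [[patch_signal(ch, patch_size) for ch in sample] for sample in batch]


# Two samples, six electrodes, 1600 samples each (the default chunk_size).
raw = [[[float(t) for t in range(1600)] for _ in range(6)] for _ in range(2)]
patched = patch_batch(raw, patch_size=200)
print(len(patched), len(patched[0]), len(patched[0][0]), len(patched[0][0][0]))  # 2 6 8 200
```

With the default sizes this yields the (2, 6, 8, 200) layout used for x in the quick start example.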
Parameters:
  • chunk_size (int) – The total length, in samples, of the EEG segment to process. (default: 1600)

  • patch_size (int) – The length, in samples, of each temporal patch. (default: 200)

  • out_chans (int) – Number of output channels from the temporal convolution. (default: 8)

  • num_classes (int) – Number of classes for classification. (default: 1000)

  • embed_dim (int) – Dimension of the embedding space. (default: 200)

  • depth (int) – Number of transformer layers. (default: 12)

  • num_heads (int) – Number of attention heads in each transformer layer. (default: 10)

  • mlp_ratio (float) – Ratio of the MLP hidden dimension to the embedding dimension. (default: 4.0)

  • qkv_bias (bool) – If True, add a learnable bias to the query, key, and value projections. (default: False)

  • qk_norm (callable) – Normalization layer applied to the query and key. (default: nn.LayerNorm with eps=1e-6)

  • qk_scale (float) – If set, overrides the default query-key scaling factor of head_dim ** -0.5. (default: None)

  • drop_rate (float) – Dropout rate. (default: 0.0)

  • attn_drop_rate (float) – Attention dropout rate. (default: 0.0)

  • drop_path_rate (float) – Stochastic depth rate. (default: 0.0)

  • norm_layer (callable) – Normalization layer. (default: nn.LayerNorm with eps=1e-6)

  • init_values (float) – Initial values for layer scale. (default: 0.0)

  • use_mean_pooling (bool) – If True, use mean pooling to form the final feature vector. (default: True)

  • init_scale (float) – Scale applied to the initial parameters of the classification head. (default: 0.001)

  • use_abs_pos_emb (bool) – If True, use absolute positional embeddings. (default: True)

  • **kwargs – Additional keyword arguments.
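Several of these parameters interact. Assuming the usual multi-head attention split (embed_dim divided evenly across num_heads) and an MLP hidden width of int(embed_dim * mlp_ratio), the defaults give 8 patches per segment, a per-head dimension of 20, a default query-key scale of head_dim ** -0.5, and an MLP hidden width of 800. The helper below is a hypothetical sketch of that arithmetic, not part of torcheeg.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class DerivedSizes:
    num_patches: int      # chunk_size // patch_size
    head_dim: int         # embed_dim // num_heads
    qk_scale: float       # factor applied to query-key dot products
    mlp_hidden_dim: int   # hidden width of each transformer MLP


def derived_sizes(chunk_size: int = 1600, patch_size: int = 200, embed_dim: int = 200,
                  num_heads: int = 10, mlp_ratio: float = 4.0,
                  qk_scale: Optional[float] = None) -> DerivedSizes:
    # Assumption: attention splits embed_dim evenly across the heads.
    if embed_dim % num_heads != 0:
        raise ValueError("embed_dim must be divisible by num_heads")
    head_dim = embed_dim // num_heads
    # qk_scale, when set, replaces the default head_dim ** -0.5.
    scale = qk_scale if qk_scale is not None else head_dim ** -0.5
    return DerivedSizes(
        num_patches=chunk_size // patch_size,
        head_dim=head_dim,
        qk_scale=scale,
        mlp_hidden_dim=int(embed_dim * mlp_ratio),  # assumed rounding rule
    )


sizes = derived_sizes()
print(sizes.num_patches, sizes.head_dim, sizes.mlp_hidden_dim)  # 8 20 800
```

Passing qk_scale explicitly, for example derived_sizes(qk_scale=0.125), bypasses the head_dim ** -0.5 default, mirroring the qk_scale description above.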

forward(x, electrodes=[], return_patch_tokens=False, return_all_tokens=False, **kwargs)

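Runs LaBraM on a batch of patched EEG signals. x has shape (batch_size, num_electrodes, chunk_size // patch_size, patch_size), and electrodes lists the electrode names for the second dimension of x, in the same order, as in the quick start example above.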
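Mismatches between x and electrodes are easy to introduce when reusing data from a different montage. The check below is a hypothetical pre-flight helper (not part of torcheeg) that assumes the layout from the quick start example: a 4-D input whose last dimension equals patch_size and one electrode name per entry along dimension 1. Because torch.Size is a tuple subclass, x.shape can be passed directly.

```python
from typing import Sequence, Tuple


def check_labram_input(shape: Tuple[int, ...], electrodes: Sequence[str],
                       patch_size: int = 200) -> int:
    """Hypothetical pre-flight check; returns the number of patches.

    Expects shape == (batch_size, num_electrodes, num_patches, patch_size)
    and one electrode name per entry along dimension 1.
    """
    if len(shape) != 4:
        raise ValueError(f"expected a 4-D input, got {len(shape)} dimensions")
    _, num_electrodes, num_patches, size = shape
    if size != patch_size:
        raise ValueError(f"last dimension is {size}, expected patch_size={patch_size}")
    if len(electrodes) != num_electrodes:
        raise ValueError(f"{len(electrodes)} electrode names for {num_electrodes} electrodes")
    return num_patches


names = ['FP1', 'FPZ', 'FP2', 'AF9', 'AF7', 'AF5']
print(check_labram_input((2, 6, 8, 200), names))  # 8
```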

Note

Although the recipe for the forward pass is defined in this function, you should call the module instance (for example model(x, electrodes=...)) rather than calling forward directly: calling the instance runs any registered hooks, while calling forward silently skips them.
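The note applies to every nn.Module, LaBraM included. The sketch below uses a small nn.Linear as a stand-in (chosen only to keep it lightweight; LaBraM itself is not constructed here) and registers a forward hook with register_forward_hook. The hook fires when the module instance is called but not when forward is called directly.

```python
import torch
from torch import nn

calls = []


def record_output_shape(module, args, output):
    # Forward hook: PyTorch runs it after forward() when the module is called.
    calls.append(tuple(output.shape))


layer = nn.Linear(4, 2)  # stand-in for any nn.Module, LaBraM included
handle = layer.register_forward_hook(record_output_shape)

x = torch.randn(3, 4)
layer(x)          # goes through nn.Module.__call__, so the hook fires
layer.forward(x)  # calls forward directly, so the hook is skipped
handle.remove()

print(calls)  # [(3, 2)]
```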
