Source code for lisbet.modeling.backbones.transformer

"""Transformer Backbone for Lisbet."""

from typing import Any

import torch
from torch import nn

from lisbet.modeling.backbones.base import BackboneInterface
from lisbet.modeling.modules_extra import PosEmbedding


class TransformerBackbone(BackboneInterface):
    """Self-attention backbone that maps frame sequences to embeddings.

    Each input frame is projected linearly to the embedding size, passed
    through the positional embedding module, processed by a stack of
    pre-norm transformer encoder layers (GELU activation, dropout
    disabled), and finally layer-normalized.

    Parameters
    ----------
    feature_dim : int
        Size of each input frame's feature vector.
    embedding_dim : int
        Size of the embeddings produced for each frame.
    hidden_dim : int
        Width of the feedforward sub-network in every encoder layer.
    num_heads : int
        Number of heads used by multi-head attention.
    num_layers : int
        How many encoder layers are stacked.
    max_length : int
        Maximum sequence length, forwarded to the positional embedding.

    Attributes
    ----------
    frame_embedder : nn.Linear
        Projection from ``feature_dim`` to ``embedding_dim``.
    pos_embedder : PosEmbedding
        Positional embedding module.
    transformer_encoder : nn.TransformerEncoder
        The stacked encoder layers.
    layer_norm : nn.LayerNorm
        Normalization applied to the encoder output.
    """

    def __init__(
        self,
        feature_dim: int,
        embedding_dim: int,
        hidden_dim: int,
        num_heads: int,
        num_layers: int,
        max_length: int,
    ) -> None:
        super().__init__()

        # Keep the raw hyper-parameters so get_config() can report them.
        self.feature_dim = feature_dim
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.max_length = max_length

        # Sub-modules are created in the order they are applied in forward().
        self.frame_embedder = nn.Linear(
            in_features=feature_dim, out_features=embedding_dim
        )
        self.pos_embedder = PosEmbedding(max_length, embedding_dim)

        # Single layer used as the template that TransformerEncoder stacks.
        layer_template = nn.TransformerEncoderLayer(
            d_model=embedding_dim,
            nhead=num_heads,
            dim_feedforward=hidden_dim,
            dropout=0.0,
            activation="gelu",
            batch_first=True,
            norm_first=True,
        )
        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layer=layer_template,
            num_layers=num_layers,
            enable_nested_tensor=False,
        )

        self.layer_norm = nn.LayerNorm(embedding_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode a batch of frame sequences.

        Parameters
        ----------
        x : torch.Tensor
            Input of shape (batch_size, sequence_length, feature_dim).

        Returns
        -------
        torch.Tensor
            Embeddings of shape (batch_size, sequence_length, embedding_dim).
        """
        frames = self.frame_embedder(x)
        positioned = self.pos_embedder(frames)
        encoded = self.transformer_encoder(positioned)
        return self.layer_norm(encoded)

    def get_config(self) -> dict[str, Any]:
        """Return the constructor arguments of this backbone.

        Returns
        -------
        dict[str, Any]
            Mapping from each constructor parameter name to the value this
            instance was built with; sufficient to rebuild the backbone
            via ``from_config``.
        """
        field_names = (
            "feature_dim",
            "embedding_dim",
            "hidden_dim",
            "num_heads",
            "num_layers",
            "max_length",
        )
        return {name: getattr(self, name) for name in field_names}

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "TransformerBackbone":
        """Build a backbone from a configuration dictionary.

        Parameters
        ----------
        config : dict[str, Any]
            Must provide ``feature_dim``, ``embedding_dim``, ``hidden_dim``,
            ``num_heads``, ``num_layers`` and ``max_length``. Any other
            keys are ignored.

        Returns
        -------
        TransformerBackbone
            A new instance constructed from the listed entries.
        """
        required_keys = (
            "feature_dim",
            "embedding_dim",
            "hidden_dim",
            "num_heads",
            "num_layers",
            "max_length",
        )
        # Only the known keys are looked up, so extra entries never reach
        # the constructor.
        init_kwargs = {key: config[key] for key in required_keys}
        return cls(**init_kwargs)