# Source code for lisbet.modeling.backbones.lstm

"""LSTM Backbone for LISBET."""

from typing import Any

import torch
from torch import nn

from lisbet.modeling.backbones.base import BackboneInterface


class LSTMBackbone(BackboneInterface):
    """
    LSTM backbone for sequence modeling.

    Input features are linearly projected to ``embedding_dim``, processed by a
    unidirectional, batch-first, multi-layer LSTM, and the per-timestep LSTM
    outputs are linearly projected back to ``embedding_dim``.

    Parameters
    ----------
    feature_dim : int
        Dimension of the input features.
    embedding_dim : int
        Dimension of the output embeddings. Also used as the LSTM input size.
    hidden_dim : int
        Dimension of the LSTM hidden state.
    num_layers : int
        Number of stacked LSTM layers.

    Raises
    ------
    ValueError
        If any of ``feature_dim``, ``embedding_dim``, ``hidden_dim`` or
        ``num_layers`` is smaller than 1.
    """

    def __init__(
        self,
        feature_dim: int,
        embedding_dim: int,
        hidden_dim: int,
        num_layers: int,
    ) -> None:
        super().__init__()

        # Fail fast with a clear message: e.g. nn.Linear(0, E) is accepted by
        # torch and would silently build a degenerate projection.
        for arg_name, arg_value in (
            ("feature_dim", feature_dim),
            ("embedding_dim", embedding_dim),
            ("hidden_dim", hidden_dim),
            ("num_layers", num_layers),
        ):
            if arg_value < 1:
                raise ValueError(
                    f"{arg_name} must be a positive integer, got {arg_value!r}"
                )

        self.feature_dim = feature_dim
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers

        # feature_dim -> embedding_dim
        self.input_proj = nn.Linear(feature_dim, embedding_dim)
        # batch_first=True: tensors are (batch, seq, features).
        self.lstm = nn.LSTM(
            input_size=embedding_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=False,
        )
        # hidden_dim -> embedding_dim. The LSTM is unidirectional, so its
        # output feature size is exactly hidden_dim.
        self.output_proj = nn.Linear(hidden_dim, embedding_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass through the LSTM backbone.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape (batch_size, sequence_length, feature_dim).

        Returns
        -------
        torch.Tensor
            Output tensor of shape (batch_size, sequence_length, embedding_dim).
        """
        x = self.input_proj(x)
        # Only the per-timestep outputs are used; the final (h_n, c_n) state
        # tuple is discarded.
        out, _ = self.lstm(x)
        out = self.output_proj(out)
        return out

    def get_config(self) -> dict[str, Any]:
        """
        Get the configuration dictionary for this backbone.

        Returns
        -------
        dict[str, Any]
            The constructor arguments, suitable for ``from_config``.
        """
        return {
            "feature_dim": self.feature_dim,
            "embedding_dim": self.embedding_dim,
            "hidden_dim": self.hidden_dim,
            "num_layers": self.num_layers,
        }

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "LSTMBackbone":
        """
        Create an LSTMBackbone instance from a configuration dictionary.

        Parameters
        ----------
        config : dict[str, Any]
            Must contain the keys ``feature_dim``, ``embedding_dim``,
            ``hidden_dim`` and ``num_layers``, as produced by ``get_config``.

        Returns
        -------
        LSTMBackbone
            A freshly initialized backbone (weights are not restored).

        Raises
        ------
        KeyError
            If a required key is missing from ``config``.
        """
        return cls(
            feature_dim=config["feature_dim"],
            embedding_dim=config["embedding_dim"],
            hidden_dim=config["hidden_dim"],
            num_layers=config["num_layers"],
        )