# Source code for lisbet.modeling.backbones.base
"""Abstract base class for backbone models."""
from abc import ABC, abstractmethod
from typing import Any
import torch
from torch import nn
class BackboneInterface(nn.Module, ABC):
    """Abstract base class for all backbone models.

    This interface defines the required methods that all backbone
    implementations must provide: a ``forward`` pass, plus a
    ``get_config`` / ``from_config`` pair used to export the parameters
    of an instance and to recreate an instance from them.

    Notes
    -----
    Mixing in ``ABC`` gives this class the ``ABCMeta`` metaclass, so a
    subclass that does not override every ``@abstractmethod`` below raises
    ``TypeError`` when it is instantiated.
    """

    @abstractmethod
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass through the backbone.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape (batch_size, sequence_length, input_dim).

        Returns
        -------
        torch.Tensor
            Output tensor of shape (batch_size, sequence_length, output_dim).
        """
        pass

    @abstractmethod
    def get_config(self) -> dict[str, Any]:
        """Get the configuration dictionary for this backbone.

        Returns
        -------
        dict[str, Any]
            Configuration dictionary containing all parameters needed
            to recreate this backbone instance.
        """
        pass

    # ``abstractmethod`` must be the innermost decorator when combined with
    # ``classmethod``; otherwise the abstract flag is not picked up.
    @classmethod
    @abstractmethod
    def from_config(cls, config: dict[str, Any]) -> "BackboneInterface":
        """Create a backbone instance from a configuration dictionary.

        Parameters
        ----------
        config : dict[str, Any]
            Configuration dictionary containing all parameters needed
            to create the backbone instance.

        Returns
        -------
        BackboneInterface
            New backbone instance created from the configuration.
        """
        pass