Source code for lisbet.modeling.factory
"""Model factory utilities for LISBET.
This module provides functions to create models from configuration Pydantic models,
including support for built-in transformer presets and parameter overrides.
All configuration objects must be Pydantic models with fields matching the expected
model constructor arguments.
Example
-------
from lisbet.config.presets import instantiate_backbone_preset
from lisbet.modeling.factory import create_model_from_config
cfg = instantiate_backbone_preset("transformer-base", feature_dim=32, embedding_dim=32)
model = create_model_from_config(cfg)
"""
from lisbet.config.schemas import ModelConfig
from lisbet.modeling import (
EmbeddingHead,
FrameClassificationHead,
MultiTaskModel,
TransformerBackbone,
WindowClassificationHead,
)
from lisbet.modeling.backbones.lstm import LSTMBackbone
from lisbet.modeling.backbones.tcn import TCNBackbone
# Lookup table from a backbone config's ``type`` discriminator to the backbone
# class that create_model_from_config instantiates. Registering a new entry
# here is what makes a new backbone type resolvable by the factory.
BACKBONE_REGISTRY = dict(
    transformer=TransformerBackbone,
    lstm=LSTMBackbone,
    tcn=TCNBackbone,
)
# Lookup table from head-type names to head classes, kept for future
# extensibility.
# NOTE(review): create_model_from_config in this module does not consult this
# mapping; it picks the head class directly from each out_heads task_id.
HEAD_REGISTRY = dict(
    frame_classification=FrameClassificationHead,
    window_classification=WindowClassificationHead,
    embedding=EmbeddingHead,
)
[docs]
def create_model_from_config(model_config: ModelConfig) -> MultiTaskModel:
    """
    Assemble a ``MultiTaskModel`` (one backbone plus per-task heads) from a config.

    Parameters
    ----------
    model_config : ModelConfig
        Pydantic model instance. ``model_config.backbone`` describes the
        backbone and is discriminated by its ``type`` field.
        ``model_config.out_heads`` maps each task_id to a dict-like collection
        of per-head settings.

    Returns
    -------
    MultiTaskModel
        The model built from the backbone and a dict of heads keyed by task_id.

    Raises
    ------
    ValueError
        If the backbone ``type`` is not in ``BACKBONE_REGISTRY``, or if an
        ``out_heads`` key is not a supported task_id. The task_id check happens
        after the backbone has been constructed.
    KeyError
        If a classification head is requested but the backbone config has no
        ``embedding_dim`` field, or if a ``"multiclass"``/``"multilabel"`` head
        has no ``num_classes`` setting.

    Notes
    -----
    - All backbone fields except ``type`` are forwarded unchanged as keyword
      arguments to the backbone constructor.
    - Supported task ids and their head settings:

      * ``"embedding"``: ``EmbeddingHead``; ``output_token_idx`` defaults
        to -1.
      * ``"multiclass"``, ``"multilabel"``: ``FrameClassificationHead``;
        ``num_classes`` is required, ``output_token_idx`` defaults to -1,
        ``hidden_dim`` defaults to None.
      * ``"cons"``, ``"order"``, ``"shift"``, ``"warp"``:
        ``WindowClassificationHead``; ``num_classes`` defaults to 1,
        ``hidden_dim`` defaults to None.
    """
    backbone_cfg = model_config.backbone
    kind = backbone_cfg.type
    backbone_builder = BACKBONE_REGISTRY.get(kind)
    if backbone_builder is None:
        raise ValueError(f"Unknown backbone type: {kind}")

    ctor_kwargs = backbone_cfg.model_dump(exclude={"type"})
    backbone = backbone_builder(**ctor_kwargs)

    def build_head(task_id, settings):
        # Guard-clause dispatch on task_id. Argument expressions keep their
        # left-to-right order so missing-key errors surface identically.
        if task_id == "embedding":
            return EmbeddingHead(
                output_token_idx=settings.get("output_token_idx", -1)
            )
        if task_id in ("multiclass", "multilabel"):
            return FrameClassificationHead(
                output_token_idx=settings.get("output_token_idx", -1),
                input_dim=ctor_kwargs["embedding_dim"],
                num_classes=settings["num_classes"],
                hidden_dim=settings.get("hidden_dim"),
            )
        if task_id in ("cons", "order", "shift", "warp"):
            return WindowClassificationHead(
                input_dim=ctor_kwargs["embedding_dim"],
                num_classes=settings.get("num_classes", 1),
                hidden_dim=settings.get("hidden_dim"),
            )
        raise ValueError(f"Unknown task_id: {task_id}")

    # Insertion order follows out_heads iteration order, one head per task.
    heads = {
        task_id: build_head(task_id, settings)
        for task_id, settings in model_config.out_heads.items()
    }
    return MultiTaskModel(backbone, heads)