"""Multi-task model for different tasks."""
from typing import Any
import torch
from torch import nn
from lisbet.modeling.backbones.base import BackboneInterface
from lisbet.modeling.backbones.lstm import LSTMBackbone
from lisbet.modeling.backbones.transformer import TransformerBackbone
from lisbet.modeling.heads.classification import (
FrameClassificationHead,
WindowClassificationHead,
)
from lisbet.modeling.heads.embedding import EmbeddingHead
class MultiTaskModel(nn.Module):
    """Multi-task model that combines a backbone with multiple task-specific heads.

    This model enables training and inference across multiple tasks using a shared
    backbone representation. Each task has its own dedicated head that processes
    the backbone output.

    Parameters
    ----------
    backbone : BackboneInterface
        The backbone model that processes input sequences and produces
        shared representations.
    task_heads : dict[str, nn.Module]
        Dictionary mapping task IDs to their corresponding task-specific heads.
    model_id : str, optional
        Identifier for the model instance, useful for logging or saving. Defaults
        to "lisbet_model".

    Attributes
    ----------
    backbone : BackboneInterface
        The shared backbone model.
    task_heads : nn.ModuleDict
        Dictionary of task-specific heads.
    model_id : str
        Identifier for the model instance.
    """

    def __init__(
        self,
        backbone: BackboneInterface,
        task_heads: dict[str, nn.Module],
        model_id: str = "lisbet_model",
    ) -> None:
        """Store the shared backbone and register every task head.

        Parameters
        ----------
        backbone : BackboneInterface
            The shared backbone model.
        task_heads : dict[str, nn.Module]
            Dictionary mapping task IDs to their task-specific heads.
        model_id : str, optional
            Identifier for the model instance. Defaults to "lisbet_model".
        """
        super().__init__()
        self.backbone = backbone
        # ModuleDict (rather than a plain dict) registers the heads as submodules,
        # so their parameters are visible to parameters(), state_dict(), to(), ...
        self.task_heads = nn.ModuleDict(task_heads)
        self.model_id = model_id

    def forward(self, x: torch.Tensor, task_id: str) -> torch.Tensor:
        """Forward pass through the model for a specific task.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape (batch_size, sequence_length, input_dim).
        task_id : str
            Identifier for the task to use. Must be a key in task_heads.

        Returns
        -------
        torch.Tensor
            Task-specific output tensor. Shape depends on the specific task head.

        Raises
        ------
        KeyError
            If task_id is not found in the available task heads.
        """
        # Validate before running the backbone so that a mistyped task ID fails
        # fast and reports which task IDs are actually available.
        if task_id not in self.task_heads:
            raise KeyError(
                f"Unknown task_id {task_id!r}; "
                f"available task IDs: {self.get_task_ids()}"
            )
        features = self.backbone(x)
        return self.task_heads[task_id](features)

    def get_task_ids(self) -> list[str]:
        """Get the list of available task IDs.

        Returns
        -------
        list[str]
            List of task IDs that can be used with this model.
        """
        return list(self.task_heads.keys())

    def get_config(self) -> dict[str, Any]:
        """Get the configuration dictionary for this model.

        Returns
        -------
        dict[str, Any]
            Dictionary with three keys: "backbone" and "task_heads", whose entries
            have the form ``{"type": <class name>, "config": <dict>}``, and
            "model_id".
        """
        task_head_configs = {
            task_id: {
                "type": type(head).__name__,
                # Heads without a get_config method are recorded with an empty
                # config instead of failing.
                "config": head.get_config() if hasattr(head, "get_config") else {},
            }
            for task_id, head in self.task_heads.items()
        }
        return {
            "backbone": {
                "type": type(self.backbone).__name__,
                "config": self.backbone.get_config(),
            },
            "task_heads": task_head_configs,
            "model_id": self.model_id,
        }

    @classmethod
    def from_config(
        cls,
        config: dict[str, Any],
        backbone_registry: dict[str, type] | None = None,
        head_registry: dict[str, type] | None = None,
    ) -> "MultiTaskModel":
        """Create a MultiTaskModel instance from a configuration dictionary.

        Parameters
        ----------
        config : dict[str, Any]
            Configuration dictionary containing backbone and task head configs,
            in the layout produced by ``get_config``.
        backbone_registry : dict[str, type] or None, optional
            Registry mapping backbone type names to their classes.
            If None, uses a default registry.
        head_registry : dict[str, type] or None, optional
            Registry mapping head type names to their classes.
            If None, uses a default registry.

        Returns
        -------
        MultiTaskModel
            New MultiTaskModel instance created from the configuration.

        Raises
        ------
        ValueError
            If backbone or head types are not found in the registries.
        KeyError
            If a required key ("backbone", "task_heads", "type" or "config") is
            missing from the configuration.
        """
        # Default registries
        if backbone_registry is None:
            backbone_registry = {
                "TransformerBackbone": TransformerBackbone,
                "LSTMBackbone": LSTMBackbone,
            }
        if head_registry is None:
            head_registry = {
                "FrameClassificationHead": FrameClassificationHead,
                "WindowClassificationHead": WindowClassificationHead,
                "EmbeddingHead": EmbeddingHead,
            }

        # Create backbone
        backbone_config = config["backbone"]
        backbone_type = backbone_config["type"]
        if backbone_type not in backbone_registry:
            raise ValueError(f"Unknown backbone type: {backbone_type}")
        backbone = backbone_registry[backbone_type].from_config(
            backbone_config["config"]
        )

        # Create task heads
        task_heads = {}
        for task_id, head_config in config["task_heads"].items():
            head_type = head_config["type"]
            if head_type not in head_registry:
                raise ValueError(f"Unknown head type: {head_type}")
            head_cls = head_registry[head_type]
            head_params = head_config["config"]
            builder = getattr(head_cls, "from_config", None)
            if builder is not None:
                task_heads[task_id] = builder(head_params)
            else:
                # Mirror get_config's fallback: a head class that does not
                # implement the config protocol is built from keyword arguments
                # (an empty config means a no-argument constructor call).
                task_heads[task_id] = head_cls(**head_params)

        return cls(
            backbone=backbone,
            task_heads=task_heads,
            model_id=config.get("model_id", "lisbet_model"),
        )