lisbet.modeling.models
Multi-task model that shares one backbone across several task-specific heads.
Classes
MultiTaskModel – Multi-task model that combines a backbone with multiple task-specific heads.
- class lisbet.modeling.models.MultiTaskModel(backbone, task_heads, model_id='lisbet_model')
Multi-task model that combines a backbone with multiple task-specific heads.
This model enables training and inference across multiple tasks using a shared backbone representation. Each task has its own dedicated head that processes the backbone output.
- Parameters:
backbone (BackboneInterface) – The backbone model that processes input sequences and produces shared representations.
task_heads (dict[str, Module]) – Dictionary mapping task IDs to their corresponding task-specific heads.
model_id (str) – Identifier for the model instance, useful for logging or saving. Defaults to 'lisbet_model'.
- backbone
The shared backbone model.
- Type:
BackboneInterface
- task_heads
Dictionary of task-specific heads.
- Type:
nn.ModuleDict
- model_id
Identifier for the model instance, useful for logging or saving. Defaults to 'lisbet_model'.
- Type:
str
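The composition described above can be illustrated without PyTorch. In this minimal sketch, plain Python callables stand in for the BackboneInterface backbone and the nn.Module heads, and nested lists stand in for tensors; the class `SketchMultiTaskModel`, the backbone `frame_means`, and the task names are all hypothetical. The backbone produces one shared representation, and the selected head turns it into a task-specific output whose shape depends on that head.

```python
from typing import Callable, Dict, List

# Hypothetical stand-ins: a batch is a list of sequences, a sequence is a list
# of frames, and a frame is a list of floats (batch, sequence, input_dim).
Frames = List[List[float]]
Backbone = Callable[[List[Frames]], List[List[float]]]
Head = Callable[[List[List[float]]], list]


class SketchMultiTaskModel:
    """Plain-Python sketch of a shared backbone feeding task-specific heads."""

    def __init__(self, backbone: Backbone, task_heads: Dict[str, Head],
                 model_id: str = "lisbet_model") -> None:
        self.backbone = backbone
        self.task_heads = dict(task_heads)  # the real class uses nn.ModuleDict
        self.model_id = model_id

    def forward(self, x: List[Frames], task_id: str) -> list:
        head = self.task_heads[task_id]  # unknown task_id raises KeyError
        shared = self.backbone(x)        # shared representation
        return head(shared)              # task-specific output


def frame_means(batch: List[Frames]) -> List[List[float]]:
    # Hypothetical backbone: one feature per frame (the mean of its values).
    return [[sum(frame) / len(frame) for frame in seq] for seq in batch]


heads: Dict[str, Head] = {
    # Per-frame output, shape (batch_size, sequence_length).
    "frame_labels": lambda h: [[int(v > 0) for v in seq] for seq in h],
    # Per-sequence output, shape (batch_size,).
    "sequence_score": lambda h: [sum(seq) for seq in h],
}
model = SketchMultiTaskModel(frame_means, heads)
x = [[[1.0, 3.0], [-2.0, 0.0]]]  # batch_size=1, sequence_length=2, input_dim=2
print(model.forward(x, "frame_labels"))    # [[1, 0]]
print(model.forward(x, "sequence_score"))  # [1.0]
```

Keeping the heads in a dictionary keyed by task ID is what makes adding a task a matter of registering one more head, while the backbone parameters stay shared.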
- __init__(backbone, task_heads, model_id='lisbet_model')
Initialize the model from a shared backbone, a dictionary of task-specific heads, and an optional model identifier.
- forward(x, task_id)
Forward pass through the model for a specific task.
- Parameters:
x (Tensor) – Input tensor of shape (batch_size, sequence_length, input_dim).
task_id (str) – Identifier for the task to use. Must be a key in task_heads.
- Returns:
Task-specific output tensor. Its shape depends on the specific task head.
- Return type:
Tensor
- Raises:
KeyError – If task_id is not found in the available task heads.
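Because an unknown task_id raises KeyError, callers may want to fail with a message that lists the valid tasks. The helper below is a hypothetical sketch of that dispatch contract, not part of lisbet: `forward_for_task` and the toy heads are illustrative, and plain floats stand in for tensors. A membership test like `task_id in task_heads` also works on an nn.ModuleDict.

```python
from typing import Callable, Dict


def forward_for_task(task_heads: Dict[str, Callable[[float], float]],
                     shared: float, task_id: str) -> float:
    """Hypothetical helper: dispatch to one head, KeyError for unknown tasks."""
    if task_id not in task_heads:
        available = ", ".join(sorted(task_heads))
        raise KeyError(f"Unknown task_id {task_id!r}; available tasks: {available}")
    return task_heads[task_id](shared)


heads = {"double": lambda h: 2 * h, "negate": lambda h: -h}
print(forward_for_task(heads, 3.0, "double"))  # 6.0

try:
    forward_for_task(heads, 3.0, "missing")
except KeyError as err:
    print(err)  # message names the tasks that do exist
```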
- get_task_ids()
Get the list of available task IDs.
- Returns:
List of task IDs that can be used with this model.
- Return type:
list[str]
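A typical use of get_task_ids is running the same input through every available head, for example when evaluating all tasks at once. The loop below is a sketch against a hypothetical `StubMultiTaskModel` that mimics only the two documented calls, with lists of floats in place of tensors. With the real class, `model(x, task_id)` reaches forward through nn.Module's `__call__`.

```python
from typing import Callable, Dict, List


class StubMultiTaskModel:
    """Hypothetical stand-in exposing get_task_ids() and a callable forward."""

    def __init__(self, task_heads: Dict[str, Callable[[List[float]], float]]) -> None:
        self.task_heads = task_heads

    def get_task_ids(self) -> List[str]:
        return list(self.task_heads.keys())

    def __call__(self, x: List[float], task_id: str) -> float:
        return self.task_heads[task_id](x)


model = StubMultiTaskModel({"total": sum, "peak": max})
x = [0.5, 2.0, 1.5]
# Query every available head with the same input.
outputs = {task_id: model(x, task_id) for task_id in model.get_task_ids()}
print(outputs)  # {'total': 4.0, 'peak': 2.0}
```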
- get_config()
Get the configuration dictionary for this model.
- Returns:
Configuration dictionary containing backbone config and task head configs.
- Return type:
dict[str,Any]
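One common way to build such a configuration is to nest each component's own configuration under the model's. The sketch below assumes that every component exposes a `get_config()` method; the component classes and the keys `model_id`, `backbone`, `task_heads`, and `type` are hypothetical and may not match lisbet's actual format. Keeping every value a plain type makes the result JSON-serializable.

```python
import json
from typing import Any, Dict


class SketchBackbone:
    """Hypothetical backbone that can describe itself."""

    def __init__(self, embedding_dim: int) -> None:
        self.embedding_dim = embedding_dim

    def get_config(self) -> Dict[str, Any]:
        return {"type": "sketch_backbone", "embedding_dim": self.embedding_dim}


class SketchHead:
    """Hypothetical task head that can describe itself."""

    def __init__(self, num_classes: int) -> None:
        self.num_classes = num_classes

    def get_config(self) -> Dict[str, Any]:
        return {"type": "sketch_head", "num_classes": self.num_classes}


def build_config(backbone: SketchBackbone, task_heads: Dict[str, SketchHead],
                 model_id: str) -> Dict[str, Any]:
    # Nest each component's config under hypothetical top-level keys.
    return {
        "model_id": model_id,
        "backbone": backbone.get_config(),
        "task_heads": {tid: head.get_config() for tid, head in task_heads.items()},
    }


config = build_config(SketchBackbone(64), {"behavior": SketchHead(4)}, "lisbet_model")
print(json.dumps(config, sort_keys=True))
```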
- classmethod from_config(config, backbone_registry=None, head_registry=None)
Create a MultiTaskModel instance from a configuration dictionary.
- Parameters:
config (dict[str, Any]) – Configuration dictionary containing backbone and task head configs.
backbone_registry (dict[str, type] | None) – Registry mapping backbone type names to their classes. If None, uses a default registry.
head_registry (dict[str, type] | None) – Registry mapping head type names to their classes. If None, uses a default registry.
- Returns:
New MultiTaskModel instance created from the configuration.
- Return type:
MultiTaskModel
- Raises:
ValueError – If the backbone or head types are not found in the registries.
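The registry arguments follow a common pattern: map a type name stored in the configuration to a class, fall back to defaults when no registry is passed, and raise ValueError for unknown names. The sketch below illustrates that pattern only. `from_config_sketch`, `resolve`, the default registries, and the config keys are hypothetical, and the function returns the resolved parts instead of a MultiTaskModel.

```python
from typing import Any, Dict, Optional, Tuple


class SketchBackbone:
    def __init__(self, embedding_dim: int) -> None:
        self.embedding_dim = embedding_dim


class SketchHead:
    def __init__(self, num_classes: int) -> None:
        self.num_classes = num_classes


# Hypothetical default registries: type name -> class.
DEFAULT_BACKBONES: Dict[str, type] = {"sketch_backbone": SketchBackbone}
DEFAULT_HEADS: Dict[str, type] = {"sketch_head": SketchHead}


def resolve(registry: Dict[str, type], spec: Dict[str, Any], kind: str) -> Any:
    params = dict(spec)  # copy so the caller's config is not mutated
    type_name = params.pop("type")
    if type_name not in registry:
        raise ValueError(f"Unknown {kind} type {type_name!r}; known: {sorted(registry)}")
    return registry[type_name](**params)  # remaining keys become kwargs


def from_config_sketch(
    config: Dict[str, Any],
    backbone_registry: Optional[Dict[str, type]] = None,
    head_registry: Optional[Dict[str, type]] = None,
) -> Tuple[Any, Dict[str, Any], str]:
    backbones = DEFAULT_BACKBONES if backbone_registry is None else backbone_registry
    heads = DEFAULT_HEADS if head_registry is None else head_registry
    backbone = resolve(backbones, config["backbone"], "backbone")
    task_heads = {tid: resolve(heads, spec, "head")
                  for tid, spec in config["task_heads"].items()}
    return backbone, task_heads, config.get("model_id", "lisbet_model")


config = {
    "model_id": "demo",
    "backbone": {"type": "sketch_backbone", "embedding_dim": 32},
    "task_heads": {"behavior": {"type": "sketch_head", "num_classes": 3}},
}
backbone, task_heads, model_id = from_config_sketch(config)
print(backbone.embedding_dim, task_heads["behavior"].num_classes, model_id)  # 32 3 demo
```

Passing a custom registry lets user-defined backbones or heads be rebuilt from a saved configuration without editing the defaults.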