lisbet.modeling.models

Multi-task model that shares a single backbone across several task-specific heads.

Classes

MultiTaskModel(backbone, task_heads[, model_id])

Multi-task model that combines a backbone with multiple task-specific heads.

class lisbet.modeling.models.MultiTaskModel(backbone, task_heads, model_id='lisbet_model')

Multi-task model that combines a backbone with multiple task-specific heads.

This model enables training and inference across multiple tasks using a shared backbone representation. Each task has its own dedicated head that processes the backbone output.

Parameters:
  • backbone (BackboneInterface) – The backbone model that processes input sequences and produces shared representations.

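  • task_heads (dict[str, Module]) – Dictionary mapping task IDs to their corresponding task-specific heads.

  • model_id (str) – Identifier for the model instance, useful for logging or saving. Defaults to 'lisbet_model'.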
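A framework-free sketch of how the constructor arguments map onto the documented attributes may help. It assumes plain callables in place of the real BackboneInterface backbone and nn.Module heads, and a plain dict in place of nn.ModuleDict. The class name MultiTaskSketch is hypothetical, and get_task_ids returning the head keys in insertion order is an assumption.

```python
from typing import Any, Callable


class MultiTaskSketch:
    """Hypothetical, framework-free stand-in for MultiTaskModel's constructor and attributes."""

    def __init__(
        self,
        backbone: Callable[[Any], Any],
        task_heads: dict[str, Callable[[Any], Any]],
        model_id: str = "lisbet_model",
    ) -> None:
        self.backbone = backbone
        # The real class stores heads in an nn.ModuleDict; a plain dict copy stands in here.
        self.task_heads = dict(task_heads)
        self.model_id = model_id

    def get_task_ids(self) -> list[str]:
        # Assumption: the available task IDs are the keys of the head mapping.
        return list(self.task_heads.keys())


def identity(x: Any) -> Any:
    return x


model = MultiTaskSketch(backbone=identity, task_heads={"classification": len, "regression": sum})
print(model.model_id, model.get_task_ids())  # lisbet_model ['classification', 'regression']
```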

backbone

The shared backbone model.

Type:

BackboneInterface

task_heads

Dictionary of task-specific heads.

Type:

nn.ModuleDict

model_id

Identifier for the model instance, useful for logging or saving. Defaults to 'lisbet_model'.

Type:

str

__init__(backbone, task_heads, model_id='lisbet_model')

Initialize the model from a backbone, a mapping of task IDs to task-specific heads, and an optional model identifier. The heads are stored in an nn.ModuleDict.

forward(x, task_id)

Forward pass through the model for a specific task.

Parameters:
  • x (Tensor) – Input tensor of shape (batch_size, sequence_length, input_dim).

  • task_id (str) – Identifier for the task to use. Must be a key in task_heads.

Returns:

Task-specific output tensor. Shape depends on the specific task head.

Return type:

Tensor

Raises:

KeyError – If task_id is not found in the available task heads.
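The documented forward behavior (run the shared backbone, then the head selected by task_id, raising KeyError for an unknown ID) can be sketched without PyTorch. The function multitask_forward, the toy backbone, and the two heads are hypothetical stand-ins operating on nested lists rather than tensors, and the input drops the input_dim axis for brevity.

```python
from typing import Any, Callable


def multitask_forward(
    backbone: Callable[[Any], Any],
    task_heads: dict[str, Callable[[Any], Any]],
    x: Any,
    task_id: str,
) -> Any:
    """Sketch: shared backbone first, then the head registered under task_id."""
    if task_id not in task_heads:
        raise KeyError(f"Unknown task_id {task_id!r}; available: {sorted(task_heads)}")
    shared = backbone(x)  # shared representation, identical whichever head is chosen
    return task_heads[task_id](shared)


def toy_backbone(batch: list[list[int]]) -> list[list[int]]:
    # Doubles every element of every sequence in the batch.
    return [[2 * v for v in seq] for seq in batch]


def sum_head(features: list[list[int]]) -> list[int]:
    return [sum(seq) for seq in features]


def last_head(features: list[list[int]]) -> list[int]:
    return [seq[-1] for seq in features]


heads = {"sum": sum_head, "last": last_head}
x = [[1, 2, 3], [4, 5, 6]]  # batch_size=2, sequence_length=3

print(multitask_forward(toy_backbone, heads, x, "sum"))   # [12, 30]
print(multitask_forward(toy_backbone, heads, x, "last"))  # [6, 12]
```

Both calls go through the same backbone; only the head that reads the shared representation changes.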

get_task_ids()

Get the list of available task IDs.

Returns:

List of task IDs that can be used with this model.

Return type:

list[str]

get_config()

Get the configuration dictionary for this model.

Returns:

Configuration dictionary containing backbone config and task head configs.

Return type:

dict[str, Any]
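The exact keys returned by get_config are not documented here. The sketch below assumes a nested layout with top-level "model_id", "backbone", and "task_heads" entries, where each component contributes its own dict tagged with a "type" name. ToyBackbone, ToyHead, and build_config are hypothetical.

```python
from typing import Any


class ToyBackbone:
    def __init__(self, hidden_dim: int) -> None:
        self.hidden_dim = hidden_dim

    def get_config(self) -> dict[str, Any]:
        # Assumed convention: a "type" name plus the constructor arguments.
        return {"type": "toy_backbone", "hidden_dim": self.hidden_dim}


class ToyHead:
    def __init__(self, num_classes: int) -> None:
        self.num_classes = num_classes

    def get_config(self) -> dict[str, Any]:
        return {"type": "toy_head", "num_classes": self.num_classes}


def build_config(
    backbone: ToyBackbone,
    task_heads: dict[str, ToyHead],
    model_id: str = "lisbet_model",
) -> dict[str, Any]:
    # Nest each component's own config under the assumed top-level keys.
    return {
        "model_id": model_id,
        "backbone": backbone.get_config(),
        "task_heads": {tid: head.get_config() for tid, head in task_heads.items()},
    }


config = build_config(ToyBackbone(hidden_dim=64), {"classification": ToyHead(num_classes=3)})
print(config)
```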

classmethod from_config(config, backbone_registry=None, head_registry=None)

Create a MultiTaskModel instance from a configuration dictionary.

Parameters:
  • config (dict[str, Any]) – Configuration dictionary containing backbone and task head configs.

  • backbone_registry (dict[str, type] | None) – Registry mapping backbone type names to their classes. If None, uses a default registry.

  • head_registry (dict[str, type] | None) – Registry mapping head type names to their classes. If None, uses a default registry.

Returns:

New MultiTaskModel instance created from the configuration.

Return type:

MultiTaskModel

Raises:

ValueError – If backbone or head types are not found in the registries.