Source code for lisbet.inference.embedding

"""
Embedding extraction for LISBET.
"""


import numpy as np
import torch

from lisbet.inference.common import predict
from lisbet.io import dump_embeddings


def _embedding_forward(model: torch.nn.Module, data: torch.Tensor) -> torch.Tensor:
    """Forward function for extracting embeddings from the model."""
    return model(data, "embedding").squeeze(dim=1)


[docs] def compute_embeddings( model_path: str, weights_path: str, data_format: str, data_path: str, data_scale: str | None = None, data_filter: str | None = None, window_size: int = 200, window_offset: int = 0, fps_scaling: float = 1.0, batch_size: int = 128, output_path: str | None = None, select_coords: str | None = None, rename_coords: str | None = None, ) -> list[tuple[str, np.ndarray]]: """ Compute LISBET embeddings for every record in a dataset. This function loads an embedding model and processes an entire dataset, computing embeddings for each sequence. Parameters ---------- model_path : str Path to the model config (JSON format). weights_path : str Path to the HDF5 file containing the model weights. data_format : str Format of the dataset to analyze. data_path : str Path to the directory containing the dataset files. data_scale : str or None Scaling string or None for auto-scaling. data_filter : str, optional Filter to apply when loading records. window_size : int, default=200 Size of the sliding window to apply on the input sequences. window_offset : int, default=0 Sliding window offset. fps_scaling : float, default=1.0 FPS scaling factor. batch_size : int, default=128 Batch size for inference. output_path : str or None, optional If given, embeddings will be saved as CSV files in this directory. select_coords : str, optional Optional subset string in the format 'INDIVIDUALS;AXES;KEYPOINTS', where each field is a comma-separated list or '*' for all. If None, all data is loaded. rename_coords : str, optional Optional coordinate names remapping in the format 'INDIVIDUALS;AXES;KEYPOINTS', where each field is a comma-separated list of maps 'old_id:new_id' or '*' for no remapping at that level. If None, original dataset names are used. Returns ------- list of tuple of (str, ndarray) A list of (sequence ID, embedding) tuples for each sequence. Raises ------ ValueError If the loaded model is not an embedding model. """ results = predict( model_path=model_path, weights_path=weights_path, forward_fn=_embedding_forward, data_format=data_format, data_path=data_path, data_scale=data_scale, window_size=window_size, window_offset=window_offset, fps_scaling=fps_scaling, batch_size=batch_size, data_filter=data_filter, select_coords=select_coords, rename_coords=rename_coords, ) # Store results on file, if requested if output_path is not None: dump_embeddings(results, output_path) return results