# Source code for lisbet.inference.common

"""
Core prediction logic for LISBET inference.
"""

from itertools import zip_longest

import numpy as np
import torch
import yaml
from lightning.fabric.utilities.data import suggested_max_num_workers
from rich.console import Console
from rich.table import Table
from torchvision import transforms
from tqdm.auto import tqdm

from lisbet.datasets import WindowDataset
from lisbet.io import load_model, load_records
from lisbet.transforms_extra import PoseToTensor


[docs] def select_device(device: str | None = None) -> torch.device: """ Select the appropriate torch device. Parameters ---------- device : str or None, optional Device string (e.g., 'cuda', 'cpu', 'mps'). If None, automatically selects the best available device. Returns ------- torch.device The selected torch device. """ if device is None: if torch.cuda.is_available(): return torch.device("cuda") elif torch.mps.is_available(): return torch.device("mps") else: return torch.device("cpu") else: return torch.device(device)
def load_model_and_config(model_path: str, weights_path: str, device: torch.device):
    """
    Read the model configuration and build the model, ready for inference.

    The configuration file is parsed first; the model is then created with
    ``load_model``, moved to ``device`` and switched to evaluation mode.

    Parameters
    ----------
    model_path : str
        Path to the model configuration YAML file.
    weights_path : str
        Path to the model weights file.
    device : torch.device
        Device to move the model to.

    Returns
    -------
    model : torch.nn.Module
        The loaded model, on ``device`` and in eval mode.
    config : dict
        The parsed content of the configuration file.
    """
    with open(model_path, encoding="utf-8") as config_file:
        parsed_config = yaml.safe_load(config_file)

    network = load_model(model_path, weights_path)

    # The return values of ``to`` and ``eval`` are discarded on purpose: the
    # object returned by ``load_model`` is the one handed back to the caller.
    network.to(device)
    network.eval()

    return network, parsed_config
def check_feature_compatibility(config, records):
    """
    Check if the input features of the model and dataset match.

    The model features are read from ``config["input_features"]`` and compared
    against the ``individuals``, ``keypoints`` and ``space`` coordinates of the
    first record only. A missing or null model declaration is treated as empty,
    so it is reported as a mismatch rather than failing with an unrelated error.

    Parameters
    ----------
    config : dict
        Model configuration dictionary containing input features.
    records : list
        List of dataset records to check for feature compatibility.

    Raises
    ------
    ValueError
        If ``records`` is empty, or if the input features of the model and
        dataset do not match. In the latter case the message embeds a
        side-by-side table of the two feature sets.
    """
    if not records:
        # Fail with an actionable message instead of a bare IndexError below.
        raise ValueError(
            "Cannot check input features compatibility: no records were provided."
        )

    # ``or {}`` / ``or []`` also cover keys that are present but null in the YAML
    model_features = (config or {}).get("input_features") or {}
    dataset_coords = records[0].posetracks.coords

    # Extract dataset features
    dataset_individuals = list(dataset_coords["individuals"].values)
    dataset_keypoints = list(dataset_coords["keypoints"].values)
    dataset_space = list(dataset_coords["space"].values)

    # Extract model features
    model_individuals = list(model_features.get("individuals") or [])
    model_keypoints = list(model_features.get("keypoints") or [])
    model_space = list(model_features.get("space") or [])

    # Compare features (order matters: lists are compared element by element)
    features_match = (
        dataset_individuals == model_individuals
        and dataset_keypoints == model_keypoints
        and dataset_space == model_space
    )
    if features_match:
        return

    # Single source of truth for both the table header and the table rows
    columns = [
        ("Model Individuals", model_individuals, "cyan"),
        ("Dataset Individuals", dataset_individuals, "magenta"),
        ("Model Keypoints", model_keypoints, "cyan"),
        ("Dataset Keypoints", dataset_keypoints, "magenta"),
        ("Model Space", model_space, "cyan"),
        ("Dataset Space", dataset_space, "magenta"),
    ]

    console = Console()
    table = Table(title="Input Features Compatibility Check")
    for name, _, style in columns:
        table.add_column(name, style=style)

    # Columns have different lengths; pad the shorter ones with empty cells
    for row in zip_longest(*(values for _, values, _ in columns), fillvalue=""):
        table.add_row(*(str(item) for item in row))

    # Render to a string so the report travels with the exception
    with console.capture() as capture:
        console.print(
            "[bold red]ERROR: Incompatible input features between model and "
            "dataset!\nPlease use 'select_coords' and 'rename_coords' to "
            "align model and dataset input features.[/bold red]"
        )
        console.print(table)
    table_str = capture.get()

    raise ValueError(f"Incompatible input features.\n{table_str}")
def predict_record(
    record,
    model,
    device,
    window_size,
    window_offset,
    fps_scaling,
    batch_size,
    forward_fn,
):
    """
    Compute the model output for one record.

    The record is wrapped in a ``WindowDataset``, iterated in batches through a
    ``DataLoader`` with gradients disabled, and the per-batch results are joined
    into a single array.

    Parameters
    ----------
    record : object
        The dataset record to predict on.
    model : torch.nn.Module
        The trained model for inference.
    device : torch.device
        Device to run inference on.
    window_size : int
        Size of the sliding window for the dataset.
    window_offset : int
        Offset for the sliding window.
    fps_scaling : float
        Scaling factor for frames per second.
    batch_size : int
        Batch size for inference.
    forward_fn : callable
        Function to perform the forward pass (model, data) -> prediction.

    Returns
    -------
    output : np.ndarray
        The concatenated prediction output for the record.
    """
    windows = WindowDataset(
        records=[record],
        window_size=window_size,
        window_offset=window_offset,
        fps_scaling=fps_scaling,
        transform=transforms.Compose([PoseToTensor()]),
    )

    # Worker count is capped both by the suggested maximum and by batch size
    worker_count = min(suggested_max_num_workers(1), batch_size // 8)
    loader_options = {
        "batch_size": batch_size,
        "num_workers": worker_count,
        # Prefetching is only configured when worker processes are used
        "prefetch_factor": 4 if worker_count > 0 else None,
        "pin_memory": device.type == "cuda",
    }
    loader = torch.utils.data.DataLoader(windows, **loader_options)

    with torch.no_grad():
        # Each batch result is copied to host memory before the next batch runs
        batch_outputs = [
            forward_fn(model, batch.to(device)).cpu().numpy() for batch in loader
        ]

    return np.concatenate(batch_outputs)
[docs] def predict( model_path: str, weights_path: str, data_format: str, data_path: str, *, data_scale: str | None = None, data_filter: str | None = None, window_size: int = 200, window_offset: int = 0, fps_scaling: float = 1.0, batch_size: int = 128, select_coords: str | None = None, rename_coords: str | None = None, device: str | None = None, forward_fn: callable = None, ) -> list[tuple[str, np.ndarray]]: """ Run model prediction on all records in a dataset. Handles model loading, feature compatibility, batching, and device management. Parameters ---------- model_path : str Path to the model configuration YAML file. weights_path : str Path to the model weights file. data_format : str Format of the input data (e.g., 'h5', 'csv'). data_path : str Path to the input data. data_scale : str or None, optional Scaling method for the data. data_filter : str or None, optional Filter to apply to the data. window_size : int, optional Size of the sliding window for the dataset. Default is 200. window_offset : int, optional Offset for the sliding window. Default is 0. fps_scaling : float, optional Scaling factor for frames per second. Default is 1.0. batch_size : int, optional Batch size for inference. Default is 128. select_coords : str or None, optional Coordinate selection string for filtering input features. rename_coords : str or None, optional Coordinate renaming string for aligning input features. device : str or None, optional Device string (e.g., 'cuda', 'cpu', 'mps'). If None, automatically selects device. forward_fn : callable, optional Function to perform the forward pass (model, data) -> prediction. Returns ------- results : list of tuple List of (record_id, prediction) tuples, where prediction is a numpy array. 
""" # Device selection device = select_device(device) # Load model config and model model, config = load_model_and_config(model_path, weights_path, device) # Load records records = load_records( data_format=data_format, data_path=data_path, data_filter=data_filter, data_scale=data_scale, select_coords=select_coords, rename_coords=rename_coords, ) # Input features compatibility check check_feature_compatibility(config, records) results = [] for record in tqdm(records, desc=f"Predicting {data_format} dataset"): output = predict_record( record, model, device, window_size, window_offset, fps_scaling, batch_size, forward_fn, ) results.append((record.id, output)) return results