lisbet.inference.common

Core prediction logic for LISBET inference.

Functions

  • check_feature_compatibility(config, records) – Check whether the input features of the model and the dataset match.

  • load_model_and_config(model_path, ...) – Load the model configuration and the model, move the model to the device, and set it to evaluation mode.

  • predict(model_path, weights_path, ...[, ...]) – Run model prediction on all records in a dataset.

  • predict_record(record, model, device, ...) – Run prediction on a single record and return the output.

  • select_device([device]) – Select the appropriate torch device.

lisbet.inference.common.select_device(device=None)

Select the appropriate torch device.

Parameters:

device (str | None) – Device string (e.g., ‘cuda’, ‘cpu’, ‘mps’). If None, automatically selects the best available device.

Returns:

The selected torch device.

Return type:

torch.device
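The automatic choice can be pictured as a small pure function. This is a minimal sketch, assuming a preference order of CUDA, then Apple MPS, then CPU; the order LISBET actually applies is not stated here. The two flags stand in for `torch.cuda.is_available()` and `torch.backends.mps.is_available()`, and the helper name `choose_device_string` is hypothetical.

```python
def choose_device_string(requested=None, cuda_available=False, mps_available=False):
    """Return a device string such as 'cuda', 'mps' or 'cpu'.

    Hypothetical helper: the preference order (CUDA, then MPS, then CPU)
    is an assumption, not taken from LISBET's source.
    """
    # An explicit request always wins, mirroring the documented `device` argument.
    if requested is not None:
        return requested
    # Otherwise fall back through accelerators in the assumed order of preference.
    if cuda_available:
        return "cuda"
    if mps_available:
        return "mps"
    return "cpu"


# With PyTorch installed, the flags would come from torch itself, e.g.
#   torch.device(choose_device_string(None,
#                                     torch.cuda.is_available(),
#                                     torch.backends.mps.is_available()))
print(choose_device_string())
print(choose_device_string(cuda_available=True))
print(choose_device_string("cpu", cuda_available=True))
```

Keeping the decision separate from the torch calls makes the fallback order easy to test on a machine without any accelerator.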

lisbet.inference.common.load_model_and_config(model_path, weights_path, device)

Load the model configuration and the model, move the model to the device, and set it to evaluation mode.

Parameters:
  • model_path (str) – Path to the model configuration YAML file.

  • weights_path (str) – Path to the model weights file.

  • device (torch.device) – Device to move the model to.

Returns:

  • model (torch.nn.Module) – The loaded model.

  • config (dict) – The loaded model configuration.
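The last two steps follow the usual PyTorch pattern: `torch.nn.Module.to(device)` and `torch.nn.Module.eval()` both return the module itself, so they can be chained. The sketch below covers only that part; how LISBET builds the model from the YAML configuration and loads the weights is not shown. `prepare_for_inference` and `StubModule` are hypothetical names, and the stub stands in for a real module so the example runs without PyTorch.

```python
class StubModule:
    """Hypothetical stand-in for torch.nn.Module that records device and mode."""

    def __init__(self):
        self.device = "cpu"
        self.training = True  # torch modules start in training mode

    def to(self, device):
        self.device = device
        return self  # like nn.Module.to, return the module itself

    def eval(self):
        self.training = False
        return self  # like nn.Module.eval, return the module itself


def prepare_for_inference(model, device):
    """Move `model` to `device` and switch it to evaluation mode."""
    # In real torch modules, eval() disables training-only behaviour
    # such as dropout and batch-norm statistic updates.
    return model.to(device).eval()


model = prepare_for_inference(StubModule(), "cuda")
print(model.device, model.training)
```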

lisbet.inference.common.check_feature_compatibility(config, records)

Check if the input features of the model and dataset match.

Parameters:
  • config (dict) – Model configuration dictionary containing input features.

  • records (list) – List of dataset records to check for feature compatibility.

Raises:

ValueError – If the input features of the model and dataset do not match.
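A compatibility check of this kind can be sketched as a comparison of feature-name lists. This is a minimal sketch under loudly stated assumptions: the configuration keeps its features under a hypothetical `"input_features"` key, each record is a dict with hypothetical `"id"` and `"features"` entries, and the comparison is order-sensitive because a model consumes features positionally. The real layout of LISBET configs and records may differ, and `check_features` is a hypothetical name.

```python
def check_features(config, records, key="input_features"):
    """Raise ValueError if any record's features differ from the model's.

    Hypothetical sketch: assumes `config[key]` is a list of feature names and
    each record is a dict with "id" and "features" entries.
    """
    expected = list(config[key])
    for record in records:
        found = list(record["features"])
        if found != expected:  # order-sensitive comparison
            missing = sorted(set(expected) - set(found))
            extra = sorted(set(found) - set(expected))
            raise ValueError(
                f"Record {record['id']!r}: feature mismatch "
                f"(missing={missing}, unexpected={extra})"
            )


config = {"input_features": ["nose_x", "nose_y", "tail_x", "tail_y"]}
good = [{"id": "rec0", "features": ["nose_x", "nose_y", "tail_x", "tail_y"]}]
bad = [{"id": "rec1", "features": ["nose_x", "nose_y"]}]

check_features(config, good)  # passes silently
try:
    check_features(config, bad)
except ValueError as err:
    print(err)
```

Reporting missing and unexpected names separately makes it easier to decide whether select_coords or rename_coords would resolve the mismatch.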

lisbet.inference.common.predict_record(record, model, device, window_size, window_offset, fps_scaling, batch_size, forward_fn)

Run prediction on a single record and return the output.

Parameters:
  • record (object) – The dataset record to predict on.

  • model (torch.nn.Module) – The trained model for inference.

  • device (torch.device) – Device to run inference on.

  • window_size (int) – Size of the sliding window for the dataset.

  • window_offset (int) – Offset for the sliding window.

  • fps_scaling (float) – Scaling factor for frames per second.

  • batch_size (int) – Batch size for inference.

  • forward_fn (callable) – Function that performs the forward pass, with signature (model, data) -> prediction.

Returns:

output – The concatenated prediction output for the record.

Return type:

np.ndarray
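The batching and concatenation steps can be sketched in plain Python. This is a minimal sketch, assuming the record has already been cut into a sequence of windows (the windowing itself, including how window_offset and fps_scaling are applied, is not shown) and that `forward_fn` returns one prediction per item in the batch. `predict_in_batches` and `toy_forward` are hypothetical names.

```python
def predict_in_batches(model, windows, batch_size, forward_fn):
    """Apply forward_fn(model, batch) to consecutive batches and concatenate.

    Sketch only: `windows` is any sequence of model inputs, and each call
    to `forward_fn` is assumed to return one prediction per batch item.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")
    output = []
    for start in range(0, len(windows), batch_size):
        batch = windows[start:start + batch_size]
        # One forward pass per batch; results are appended in input order.
        output.extend(forward_fn(model, batch))
    return output


calls = []


def toy_forward(model, batch):
    calls.append(len(batch))  # record batch sizes for illustration
    return [model * x for x in batch]


preds = predict_in_batches(10, [1, 2, 3, 4, 5], batch_size=2, forward_fn=toy_forward)
print(preds)
print(calls)
```

With real tensors one would typically run the loop under `torch.no_grad()` and join the per-batch arrays with `np.concatenate`, which would yield the single `np.ndarray` documented above.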

lisbet.inference.common.predict(model_path, weights_path, data_format, data_path, *, data_scale=None, data_filter=None, window_size=200, window_offset=0, fps_scaling=1.0, batch_size=128, select_coords=None, rename_coords=None, device=None, forward_fn=None)

Run model prediction on all records in a dataset.

Handles model loading, feature compatibility checks, batching, and device management.

Parameters:
  • model_path (str) – Path to the model configuration YAML file.

  • weights_path (str) – Path to the model weights file.

  • data_format (str) – Format of the input data (e.g., ‘h5’, ‘csv’).

  • data_path (str) – Path to the input data.

  • data_scale (str | None) – Scaling method for the data. Default is None.

  • data_filter (str | None) – Filter to apply to the data. Default is None.

  • window_size (int) – Size of the sliding window for the dataset. Default is 200.

  • window_offset (int) – Offset for the sliding window. Default is 0.

  • fps_scaling (float) – Scaling factor for frames per second. Default is 1.0.

  • batch_size (int) – Batch size for inference. Default is 128.

  • select_coords (str | None) – Coordinate selection string for filtering input features. Default is None.

  • rename_coords (str | None) – Coordinate renaming string for aligning input features. Default is None.

  • device (str | None) – Device string (e.g., ‘cuda’, ‘cpu’, ‘mps’). Default is None, which selects the best available device automatically.

  • forward_fn (callable | None) – Function that performs the forward pass, with signature (model, data) -> prediction. Default is None.

Returns:

results – List of (record_id, prediction) tuples, where prediction is a numpy array.

Return type:

list[tuple[str, ndarray]]
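`forward_fn` receives the model and the input data and returns the prediction. As a sketch, assume the model yields per-class logits for each window and the desired prediction is a probability vector; the hypothetical `softmax_forward` below applies a numerically stable softmax, and `toy_model` is a hypothetical stand-in for a trained network. The last lines mimic the documented list of (record_id, prediction) tuples, using plain lists instead of NumPy arrays and made-up record IDs.

```python
import math


def softmax_forward(model, data):
    """Hypothetical forward_fn: run `model` on each item and softmax its logits."""
    probs = []
    for logits in (model(x) for x in data):
        m = max(logits)  # subtract the max for numerical stability
        exps = [math.exp(v - m) for v in logits]
        total = sum(exps)
        probs.append([e / total for e in exps])
    return probs


def toy_model(x):
    # Stand-in for a trained network: two-class logits from a scalar input.
    return [x, -x]


# Shape of the result: one (record_id, prediction) pair per record.
results = [
    (record_id, softmax_forward(toy_model, frames))
    for record_id, frames in [("rec0", [0.0, 2.0]), ("rec1", [-1.0])]
]
for record_id, prediction in results:
    print(record_id, [round(p[0], 3) for p in prediction])
```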