"""
Behavior annotation (multiclass and multilabel classification) for LISBET.
"""
from functools import partial
import numpy as np
import torch
from torch.nn.functional import one_hot
from lisbet.inference.common import predict
from lisbet.io import dump_annotations
def _multiclass_forward(model: torch.nn.Module, data: torch.Tensor) -> torch.Tensor:
"""Forward function for multiclass classification."""
output = model(data, "multiclass")
labels = one_hot(torch.argmax(output, dim=1), num_classes=output.shape[1])
return labels
def _multilabel_forward(
model: torch.nn.Module, data: torch.Tensor, threshold: float = 0.5
) -> torch.Tensor:
"""Forward function for multilabel classification."""
output = model(data, "multilabel")
labels = (torch.sigmoid(output) > threshold).int()
return labels
[docs]
def annotate_behavior(
model_path: str,
weights_path: str,
data_format: str,
data_path: str,
data_scale: str | None = None,
data_filter: str | None = None,
mode: str = "multiclass",
threshold: float = 0.5,
window_size: int = 200,
window_offset: int = 0,
fps_scaling: float = 1.0,
batch_size: int = 128,
output_path: str | None = None,
select_coords: str | None = None,
rename_coords: str | None = None,
) -> list[tuple[str, np.ndarray]]:
"""
Run LISBET behavior classification for every record in a dataset.
This function loads a classification model and processes an entire dataset,
producing behavior annotations for each sequence.
Parameters
----------
model_path : str
Path to the model config (JSON format).
weights_path : str
Path to the HDF5 file containing the model weights.
data_format : str
Format of the dataset to analyze.
data_path : str
Path to the directory containing the dataset files.
data_scale : str or None
Scaling string or None for auto-scaling.
data_filter : str, optional
Filter to apply when loading records.
mode : str, default='multiclass'
Classification mode, either 'multiclass' or 'multilabel'.
threshold : float, default=0.5
Probability threshold for multilabel classification.
window_size : int, default=200
Size of the sliding window to apply on the input sequences.
window_offset : int, default=0
Sliding window offset.
fps_scaling : float, default=1.0
FPS scaling factor.
batch_size : int, default=128
Batch size for inference.
output_path : str or None, optional
If given, predictions will be saved as CSV files in this directory.
select_coords : str, optional
Optional subset string in the format 'INDIVIDUALS;AXES;KEYPOINTS', where each
field is a comma-separated list or '*' for all. If None, all data is loaded.
rename_coords : str, optional
Optional coordinate names remapping in the format 'INDIVIDUALS;AXES;KEYPOINTS',
where each field is a comma-separated list of maps 'old_id:new_id' or '*' for
no remapping at that level. If None, original dataset names are used.
Returns
-------
list of tuple of (str, ndarray)
A list of (sequence ID, predicted behavior) tuples for each sequence.
Raises
------
ValueError
If the loaded model is not a classification model.
"""
if mode == "multiclass":
forward_fn = _multiclass_forward
elif mode == "multilabel":
forward_fn = partial(_multilabel_forward, threshold=threshold)
else:
raise ValueError(f"Unknown mode: {mode}")
results = predict(
model_path=model_path,
weights_path=weights_path,
forward_fn=forward_fn,
data_format=data_format,
data_path=data_path,
data_scale=data_scale,
window_size=window_size,
window_offset=window_offset,
fps_scaling=fps_scaling,
batch_size=batch_size,
data_filter=data_filter,
select_coords=select_coords,
rename_coords=rename_coords,
)
if output_path is not None:
# Save predictions to output path
dump_annotations(results, output_path)
return results