"""IO utilities for LISBET."""
import inspect
import logging
import re
from dataclasses import dataclass
from functools import partial
from itertools import repeat
from pathlib import Path
import pandas as pd
import torch
import xarray as xr
import yaml
from movement.io import load_poses
from movement.transforms import scale
from torchinfo import summary
from tqdm.auto import tqdm
from lisbet.config.schemas import ModelConfig
from lisbet.modeling.factory import create_model_from_config
@dataclass
class Record:
    """
    Container for one pose-tracking sequence and its optional annotations.

    Parameters
    ----------
    id : str
        Identifier of the record. ``load_records`` sets it to the sequence
        directory path relative to the dataset root, and ``dump_records`` reuses it
        as the relative output directory.
    posetracks : xarray.Dataset
        Pose-tracking data of the sequence.
    annotations : xarray.Dataset or None, optional
        Behavior annotations of the sequence, or None (the default) when no
        annotation file is available.
    """

    # NOTE: 'id' shadows the builtin, but it is part of the public interface.
    id: str
    posetracks: xr.Dataset
    annotations: xr.Dataset | None = None
def _filter_kwargs(kwargs, handler):
"""Filter kwargs to match handler's signature."""
valid_args = [p.name for p in inspect.signature(handler).parameters.values()]
return {k: v for k, v in kwargs.items() if k in valid_args}
def _read_posetrack_files(seq_path, data_format):
    """
    Load every pose-tracking file of ``data_format`` found in ``seq_path``.

    Parameters
    ----------
    seq_path : Path
        Path to the sequence directory.
    data_format : str
        Format of the dataset ('DLC', 'SLEAP', 'movement').

    Returns
    -------
    xarray.Dataset or None
        All matching files concatenated along 'individuals', or None if no file in
        ``seq_path`` matches the file-name pattern of ``data_format``.

    Raises
    ------
    ValueError
        If ``data_format`` is not supported.
    """
    # Valid filenames and their corresponding loading functions
    # TODO: Test re matching for all supported formats.
    data_readers = {
        "DLC": (r"(?i)(DLC.*?shuffle\d+|tracking).*\.csv$", load_poses.from_dlc_file),
        "SLEAP": (r"(?i)SLEAP.*\.h5$", load_poses.from_sleap_file),
        "movement": (r"(?i)tracking.*\.nc$", partial(xr.open_dataset, engine="scipy")),
    }
    if data_format not in data_readers:
        raise ValueError(f"Unknown data format {data_format}")

    pattern, loader = data_readers[data_format]
    rx = re.compile(pattern)
    # Path.iterdir() yields entries in arbitrary order; sorting makes the order of
    # the concatenated individuals independent of the file system.
    dss = [loader(pth) for pth in sorted(seq_path.iterdir()) if rx.search(pth.name)]
    if not dss:
        return None

    # NOTE: There should be only one dataset per sequence, but we keep this for
    #       compatibility with multiple single-individual datasets
    ds = xr.concat(dss, dim="individuals")
    logging.debug("Individuals: %s", ds["individuals"].values)
    return ds


def _parse_select_coords(select_coords):
    """
    Parse an 'INDIVIDUALS;AXES;KEYPOINTS' selection string.

    Returns
    -------
    dict[str, list[str]]
        Mapping from dimension name ('individuals', 'space', 'keypoints') to the
        labels to keep. Fields that are empty or '*' are omitted, i.e. no selection
        is applied along that dimension.

    Raises
    ------
    ValueError
        If the string does not contain exactly three ';'-separated fields.
    """
    fields = select_coords.split(";")
    if len(fields) != 3:
        raise ValueError(
            "select_coords must have format 'INDIVIDUALS;AXES;KEYPOINTS', "
            "e.g. 'ind1,ind2;x,y;nose,neck,tail'"
        )
    sel_keys = ("individuals", "space", "keypoints")
    return {
        key: [item.strip() for item in field.split(",") if item.strip()]
        for key, field in zip(sel_keys, fields, strict=True)
        if field.strip() and field.strip() != "*"
    }


def _parse_rename_coords(rename_coords):
    """
    Parse an 'INDIVIDUALS;AXES;KEYPOINTS' renaming string.

    Returns
    -------
    dict[str, dict[str, str]]
        Mapping from dimension name ('individuals', 'space', 'keypoints') to an
        ``{old_label: new_label}`` dictionary. Fields equal to '*' are omitted.

    Raises
    ------
    ValueError
        If the string does not contain exactly three non-empty ';'-separated fields,
        or if an entry is not of the form 'old_id:new_id'.
    """
    fields = rename_coords.split(";")
    if len(fields) != 3 or any(f.strip() == "" for f in fields):
        raise ValueError(
            "rename_coords must have format 'INDIVIDUALS;AXES;KEYPOINTS', "
            "using '*' for no remapping at a level, e.g. "
            "'mouse1:resident,mouse2:intruder;*;nose:snout,tail:tailbase'"
        )
    rename_keys = ("individuals", "space", "keypoints")
    remap = {}
    for key, field in zip(rename_keys, fields, strict=True):
        if field.strip() == "*":
            continue
        mapping = {}
        for pair in field.split(","):
            old, sep, new = pair.partition(":")
            # Exactly one ':' is required; report the offending entry explicitly
            # instead of failing on tuple unpacking.
            if not sep or ":" in new:
                raise ValueError(
                    f"Invalid entry '{pair}' in rename_coords, "
                    "expected 'old_id:new_id'"
                )
            mapping[old.strip()] = new.strip()
        remap[key] = mapping
    return remap


def _rescale_position(ds, data_scale):
    """
    Rescale ``ds['position']`` to the [0, 1] range.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset with a 'position' variable.
    data_scale : str or None
        'WIDTHxHEIGHT[xDEPTH]' string for explicit scaling. If None, the
        'image_size_px' attribute is used when present, otherwise each spatial axis
        is min-max normalized over time, keypoints and individuals.

    Returns
    -------
    xarray.Dataset
        The dataset with rescaled positions.

    Raises
    ------
    ValueError
        If the rescaled positions are not within [0, 1].
    """
    if data_scale is not None:
        # Explicit scaling: divide each axis by the matching value of the string
        factor = [1 / float(val) for val in data_scale.split("x")]
        ds = ds.assign(position=scale(ds["position"], factor=factor))
        logging.debug("Rescaled coordinates by factor %s", factor)
    elif "image_size_px" in ds.attrs:
        # Rescale using image size
        factor = [1 / float(val) for val in ds.attrs["image_size_px"]]
        ds = ds.assign(position=scale(ds["position"], factor=factor))
        logging.debug(
            "Rescaled coordinates by image size %s", ds.attrs["image_size_px"]
        )
    else:
        # Auto-scaling: per-axis min-max over all frames, keypoints and individuals
        reduce_dims = ("time", "keypoints", "individuals")
        pos = ds["position"]
        min_val = pos.min(dim=reduce_dims, skipna=True)
        max_val = pos.max(dim=reduce_dims, skipna=True)
        # A constant axis (e.g. a single sample, or a flat third axis) has a zero
        # span. Dividing by it would yield NaN, which also slips through the range
        # check below. A unit span maps such an axis to 0 instead.
        span = max_val - min_val
        span = span.where(span > 0, 1.0)
        ds = ds.assign(position=(pos - min_val) / span)
        logging.debug(
            "Rescaled coordinates between min values %s and max values %s",
            min_val.values,
            max_val.values,
        )

    # After scaling, enforce [0, 1] range and raise if not satisfied
    pos_min = ds["position"].min()
    pos_max = ds["position"].max()
    if pos_min < 0.0 or pos_max > 1.0:
        raise ValueError(
            f"After applying data_scale={data_scale}, coordinates are not in [0, 1] "
            f"(min={pos_min.values}, max={pos_max.values}). Explicit scaling assumes "
            "that the video has already been cropped to the region of interest during "
            "pose estimation, its origin is at (0,0), and the maximum dimensions match "
            "the scale provided. If this is not the case, use auto mode "
            "(data_scale=None) for normalization."
        )
    return ds


def _load_posetracks(seq_path, data_format, data_scale, select_coords, rename_coords):
    """
    Load and preprocess a pose-tracking dataset from a sequence directory.

    NaN positions are replaced with 0.0, the 'confidence' variable is dropped, the
    optional coordinate selection and renaming are applied, and positions are
    rescaled to the [0, 1] range.

    Parameters
    ----------
    seq_path : Path
        Path to the sequence directory.
    data_format : str
        Format of the dataset ('DLC', 'SLEAP', 'movement').
    data_scale : str or None
        Scaling string 'WIDTHxHEIGHT[xDEPTH]', or None for image-size / auto scaling.
    select_coords : str or None
        Optional subset string in the format 'INDIVIDUALS;AXES;KEYPOINTS', where each
        field is a comma-separated list or '*' for all.
    rename_coords : str or None
        Optional coordinate names remapping in the format 'INDIVIDUALS;AXES;KEYPOINTS',
        where each field is a comma-separated list of maps 'old_id:new_id' or '*' for
        no remapping at that level.

    Returns
    -------
    xarray.Dataset or None
        The loaded and preprocessed dataset, or None if ``seq_path`` contains no
        pose-tracking file for ``data_format``.

    Raises
    ------
    ValueError
        If ``data_format`` is unknown, ``select_coords`` / ``rename_coords`` are
        malformed, or the rescaled positions fall outside [0, 1].
    """
    ds = _read_posetrack_files(seq_path, data_format)
    if ds is None:
        return None

    # Replace nan values with 0.0 in the 'position' variable
    # NOTE: This is a workaround for NaN values in 'position' causing issues during
    #       training. In the future we could interpolate them or use a more
    #       sophisticated imputation method in movement.
    # NOTE(review): because this runs before auto-scaling, the filled zeros take
    #       part in the min/max computed by _rescale_position.
    ds["position"] = ds["position"].fillna(0.0)
    logging.debug("Replaced NaN values in 'position' with 0.0")

    # Drop confidence variable, if present
    # NOTE: Currently unused in LISBET, but it may become useful to provide a
    #       measure of tracking quality to the model.
    if "confidence" in ds:
        ds = ds.drop_vars("confidence")

    # TODO: Move string parsing to CLI during refactoring and take the parsed
    #       dictionaries as arguments.
    if select_coords is not None:
        sel_dict = _parse_select_coords(select_coords)
        if sel_dict:
            ds = ds.sel(**sel_dict)
            logging.debug("Subset selection: %s", sel_dict)

    if rename_coords is not None:
        remap = _parse_rename_coords(rename_coords)
        if remap:
            # Labels missing from a mapping keep their original name
            renamed = {
                dim: (dim, [mapping.get(val, val) for val in ds.coords[dim].values])
                for dim, mapping in remap.items()
            }
            ds = ds.assign_coords(renamed)
            logging.debug("Coordinate renaming: %s", remap)

    # NOTE: We keep the whole Dataset object, rather than selecting the "position"
    #       variable, to allow for future extensions (e.g., adding more variables)
    #       and to keep the FPS information.
    return _rescale_position(ds, data_scale)
def _load_annotations(seq_path):
"""
Load annotations from a sequence directory, if present.
Returns
-------
xarray.Dataset or None
The loaded annotations, or None if not found.
"""
# Find all files matching the annotations regex and load them
rx = re.compile(r"(?i)(manual_scoring|annotations).*\.nc$")
annotations = [
xr.open_dataset(pth, engine="scipy")
for pth in seq_path.iterdir()
if rx.search(pth.name)
]
# Check if any annotations were found
if len(annotations) == 0:
return None
# Merge all annotations into a single one
annotations = xr.concat(annotations, dim="annotators")
logging.debug("Annotations: %s", annotations.coords["behaviors"].values)
return annotations
def load_records(
    data_format,
    data_path,
    data_scale=None,
    data_filter=None,
    select_coords=None,
    rename_coords=None,
):
    """
    Load pose-tracking records from a directory, with optional filtering, coordinate
    selection and renaming.

    Every sub-directory of ``data_path`` (at any depth) is a candidate sequence.
    Candidates without pose-tracking files matching ``data_format`` are skipped.

    Parameters
    ----------
    data_format : {'movement', 'DLC', 'SLEAP'}
        Dataset format to load.
    data_path : str or Path
        Root directory containing the sequence sub-directories. A leading '~' is
        expanded to the user's home directory.
    data_scale : str, optional
        If supplied as WIDTHxHEIGHT or WIDTHxHEIGHTxDEPTH, every input coordinate is
        assumed to be in data units and is divided by the given scale to obtain
        normalized coordinates in the range [0, 1]. Otherwise, the algorithm infers the
        active extent directly from the data.
    data_filter : str, optional
        Comma-separated substrings; a record is kept if any substring occurs in its
        relative path. By default, all records are kept.
    select_coords : str or None
        Optional subset string in the format 'INDIVIDUALS;AXES;KEYPOINTS', where each
        field is a comma-separated list or '*' for all. If None, all data is loaded.
        Example: 'mouse1,mouse2;x,y;nose,tail'.
    rename_coords : str or None
        Optional coordinate names remapping in the format 'INDIVIDUALS;AXES;KEYPOINTS',
        where each field is a comma-separated list of maps 'old_id:new_id' or '*' for
        no remapping at that level. If None, original dataset names are used.
        Example: 'mouse1:resident,mouse2:intruder;*;nose:snout,tail:tailbase'.

    Returns
    -------
    list[Record]
        One Record per loaded sequence, ordered by sequence path. Each record id is
        the sequence path relative to ``data_path``; annotations are None when the
        sequence has no annotation file.

    Raises
    ------
    ValueError
        If data_format is unsupported, if select_coords/rename_coords are invalid,
        if no valid record is found, or if the records do not share the same
        'individuals', 'keypoints' and 'space' coordinates.

    Examples
    --------
    >>> records = load_records(
    ...     data_format="movement",
    ...     data_path="~/datasets/mice",
    ...     select_coords="mouse1,mouse2;x,y;nose,tail",
    ...     rename_coords="mouse1:resident,mouse2:intruder;*;nose:snout,tail:tailbase",
    ... )
    >>> print(len(records))
    42
    >>> print(records[0].id)
    session1/seq001
    >>> print(records[0].posetracks)
    <xarray.Dataset ...>
    >>> print(records[0].annotations)
    <xarray.Dataset ...> or None
    """
    # Expand '~' so that paths such as "~/datasets/mice" actually resolve
    root = Path(data_path).expanduser()

    # Find all potential record paths. rglob yields entries in arbitrary order;
    # sorting keeps the order of the returned records deterministic.
    seq_paths = sorted(pth for pth in root.rglob("*") if pth.is_dir())

    # Filter data, if requested
    if data_filter is not None:
        filters = data_filter.split(",")
        seq_paths = [
            seq_path
            for seq_path in seq_paths
            if any(flt in str(seq_path.relative_to(root)) for flt in filters)
        ]
        logging.info("%d potential paths after filtering", len(seq_paths))
    logging.debug(seq_paths)

    # Load and preprocess raw data
    records = []
    for seq_path in tqdm(seq_paths, desc="Loading dataset"):
        posetracks = _load_posetracks(
            seq_path, data_format, data_scale, select_coords, rename_coords
        )
        if posetracks is None:
            logging.debug("Skipping %s, no tracking data found", str(seq_path))
            continue

        annotations = _load_annotations(seq_path)
        rec_id = str(seq_path.relative_to(root))
        records.append(
            Record(id=rec_id, posetracks=posetracks, annotations=annotations)
        )

    if not records:
        raise ValueError(
            "No valid records found in the specified directory. Please check the data "
            "path, format and filters to ensure they match the dataset structure.\n"
            f"Current values are: \n  data_path = {data_path}\n  "
            f"data_format = {data_format}\n  data_filter = {data_filter}\n"
        )

    # Sanity check: All posetracks must have the same individuals/keypoints/space
    # coordinates as the first record.
    check_dims = ("individuals", "keypoints", "space")
    ref_coords = {
        dim: records[0].posetracks.coords[dim].values.tolist() for dim in check_dims
    }
    for rec in records[1:]:
        for dim in check_dims:
            ds_coords = rec.posetracks.coords[dim].values.tolist()
            if ds_coords != ref_coords[dim]:
                raise ValueError(
                    f"Inconsistent posetracks coordinates in record '{rec.id}':\n"
                    f"Reference {dim}:\n{ref_coords[dim]}\n"
                    f"Record {dim}:\n{ds_coords}"
                )

    return records
def load_multi_records(data_config):
    """
    Load the records of one or more datasets described by a data configuration.

    ``data_config.data_format`` and ``data_config.data_path`` are comma-separated
    lists; whitespace around each item is ignored. Formats and paths are paired
    element-wise. If a single path is given, every format is loaded from that path.
    Each pair is loaded with ``load_records``, sharing the scale, filter, selection
    and renaming options of ``data_config``.

    Parameters
    ----------
    data_config : object
        Configuration exposing the ``data_format``, ``data_path``, ``data_scale``,
        ``data_filter``, ``select_coords`` and ``rename_coords`` attributes.

    Returns
    -------
    list[list[Record]]
        One list of records per (format, path) pair, in input order.

    Raises
    ------
    ValueError
        If formats and paths cannot be paired, if loading a dataset fails (see
        ``load_records``), or if the datasets do not share the same 'individuals',
        'keypoints' and 'space' coordinates.
    """
    datasets = [fmt.strip() for fmt in data_config.data_format.split(",")]
    datapaths = [pth.strip() for pth in data_config.data_path.split(",")]

    if len(datasets) == len(datapaths):
        datasources = list(zip(datasets, datapaths, strict=True))
    elif len(datapaths) == 1:
        # A single path is shared by all formats
        datasources = list(zip(datasets, repeat(datapaths[0])))
    else:
        raise ValueError(
            "Input arguments datasets and datapaths must have the same length, or"
            " datapath must be a single element."
        )
    logging.debug(datasources)

    # Load records
    multi_records = [
        load_records(
            dataset,
            datapath,
            data_scale=data_config.data_scale,
            data_filter=data_config.data_filter,
            select_coords=data_config.select_coords,
            rename_coords=data_config.rename_coords,
        )
        for dataset, datapath in datasources
    ]

    # Sanity check across datasets. Consistency within a dataset is already checked
    # by load_records, so comparing the first record of each dataset suffices.
    check_dims = ("individuals", "keypoints", "space")
    main_coords = [
        {dim: recs[0].posetracks.coords[dim].values.tolist() for dim in check_dims}
        for recs in multi_records
    ]
    ref_coords = main_coords[0]
    for i, coords in enumerate(main_coords[1:], start=1):
        for dim in check_dims:
            if coords[dim] != ref_coords[dim]:
                raise ValueError(
                    "Inconsistent posetracks coordinates in loaded records, dataset "
                    f"{i}:\n"
                    f"Reference {dim}:\n{ref_coords[dim]}\n"
                    f"Record {dim}:\n{coords[dim]}"
                )

    return multi_records
def load_model(config_path, weights_path):
    """
    Load a pretrained LISBET model from its configuration and weights.

    The model is built with the model factory from a YAML configuration file (or
    from an already validated ``ModelConfig``). Weights are then loaded on the CPU
    with ``strict=False``, so partially matching checkpoints are accepted; missing
    and unexpected keys are logged.

    Parameters
    ----------
    config_path : str or Path or ModelConfig
        Path to the model configuration YAML file, or a ``ModelConfig`` instance.
    weights_path : str or Path
        Path to the model weights file.

    Returns
    -------
    torch.nn.Module
        The loaded LISBET model.
    """
    if isinstance(config_path, ModelConfig):
        # Already-validated configuration object: use it as is
        model_config = config_path
    else:
        with open(config_path, encoding="utf-8") as f_yaml:
            model_config_dict = yaml.safe_load(f_yaml)
        model_config = ModelConfig.model_validate(model_config_dict)

    # Load model from configuration
    model = create_model_from_config(model_config)

    # Load weights (strict=False allows for partial loading)
    state_dict = torch.load(
        weights_path, weights_only=True, map_location=torch.device("cpu")
    )
    incompatible_layers = model.load_state_dict(state_dict, strict=False)
    logging.info(
        "Loaded weights from file.\nMissing keys: %s\nUnexpected keys: %s",
        incompatible_layers.missing_keys,
        incompatible_layers.unexpected_keys,
    )

    return model
def export_embedder(model_path, weights_path, output_path=Path(".")):
    """
    Export the behavior-embedding variant of a pretrained LISBET model.

    The configuration in ``model_path`` is reused with two changes. First,
    ``model_id`` gets an ``-embedder`` suffix. Second, ``out_heads`` is replaced by a
    single 'embedding' head that reads the token at index ``-(window_offset + 1)``.
    The pretrained weights are loaded non-strictly into the new model, and the new
    configuration and weights are written under ``output_path/models/<model_id>/``.

    Parameters
    ----------
    model_path : str or Path
        Path to the YAML configuration of the pretrained model.
    weights_path : str or Path
        Path to the pretrained weights file. Its file name is reused for the
        exported weights.
    output_path : str or Path, optional
        Root output directory (default: the current directory).
    """
    # Get config dictionary
    with open(model_path, encoding="utf-8") as f_yaml:
        model_config_dict = yaml.safe_load(f_yaml)
    model_id = model_config_dict["model_id"] + "-embedder"

    # Update config
    model_config_dict["model_id"] = model_id
    model_config_dict["out_heads"] = {
        # TODO: Remove this hack when we have a better solution
        "embedding": {"output_token_idx": -(model_config_dict["window_offset"] + 1)}
    }

    # Create behavior embedding model
    model_config = ModelConfig.model_validate(model_config_dict)
    embedding_model = create_model_from_config(model_config)
    summary(embedding_model)

    # strict=False skips keys that do not match the embedding model (presumably
    # the parameters of the replaced output heads). Log them, as load_model does,
    # so that unexpected mismatches remain visible.
    incompatible_layers = embedding_model.load_state_dict(
        torch.load(weights_path, weights_only=True, map_location=torch.device("cpu")),
        strict=False,
    )
    logging.info(
        "Loaded weights from file.\nMissing keys: %s\nUnexpected keys: %s",
        incompatible_layers.missing_keys,
        incompatible_layers.unexpected_keys,
    )

    # Store configuration and weights. Path() makes this work when weights_path
    # is given as a plain string.
    dump_model_config(output_path, model_id, model_config)
    dump_weights(embedding_model, output_path, model_id, Path(weights_path).name)
def dump_records(data_path, records):
    """
    Write records to disk as NetCDF files, one directory per record.

    For each record, ``<data_path>/<record.id>/`` is created if needed. The pose
    tracks are written to ``tracking.nc`` and, when present, the annotations to
    ``annotations.nc`` (scipy NetCDF engine). These names match the file patterns
    used when loading the 'movement' format and annotations.

    Parameters
    ----------
    data_path : str or Path
        Root output directory.
    records : list of Record
        Records to write.
    """
    for record in tqdm(records, desc="Dumping records to disk"):
        record_dir = Path(data_path) / record.id
        record_dir.mkdir(parents=True, exist_ok=True)

        record.posetracks.to_netcdf(record_dir / "tracking.nc", engine="scipy")

        if record.annotations is None:
            continue
        record.annotations.to_netcdf(record_dir / "annotations.nc", engine="scipy")
def dump_annotations(results, output_path):
    """
    Write each record's LISBET behavior predictions to its own CSV file.

    Parameters
    ----------
    results : iterable of (str, array-like)
        ``(record_id, predictions)`` pairs, e.g. the output of ``annotate_behavior``.
    output_path : str or Path
        Root output directory. The predictions of a record are written, including
        the DataFrame index column, to
        ``output_path/annotations/<record_id>/machineAnnotation_lisbet.csv``.
        Missing directories are created.
    """
    for record_id, predictions in tqdm(results, desc="Saving LISBET annotations"):
        csv_path = Path(output_path).joinpath(
            "annotations", record_id, "machineAnnotation_lisbet.csv"
        )
        csv_path.parent.mkdir(parents=True, exist_ok=True)
        table = pd.DataFrame(predictions)
        table.to_csv(csv_path, index=True)
def dump_embeddings(results, output_path):
    """
    Write each record's LISBET embeddings to its own CSV file.

    Parameters
    ----------
    results : iterable of (str, array-like)
        ``(record_id, embeddings)`` pairs, e.g. the output of ``compute_embeddings``.
    output_path : str or Path
        Root output directory. The embeddings of a record are written, including the
        DataFrame index column, to
        ``output_path/embeddings/<record_id>/features_lisbet_embedding.csv``.
        Missing directories are created.
    """
    for record_id, embeddings in tqdm(results, desc="Saving LISBET embeddings"):
        csv_path = Path(output_path).joinpath(
            "embeddings", record_id, "features_lisbet_embedding.csv"
        )
        csv_path.parent.mkdir(parents=True, exist_ok=True)
        table = pd.DataFrame(embeddings)
        table.to_csv(csv_path, index=True)
def dump_evaluation_results(report: dict, output_path: str, model_path: str):
    """
    Write an evaluation report to ``output_path/evaluations/<model_id>/``.

    The model id is read from the ``model_id`` key of the YAML configuration at
    ``model_path``, falling back to 'unknown_model' when the key is absent. The
    report is serialized with ``yaml.safe_dump`` to ``evaluation_report.yml``.

    Parameters
    ----------
    report : dict
        The evaluation report.
    output_path : str or Path
        Root output directory.
    model_path : str
        Path to the model configuration YAML file.
    """
    with open(model_path, encoding="utf-8") as config_file:
        config = yaml.safe_load(config_file)
    model_id = config.get("model_id", "unknown_model")

    report_dir = Path(output_path) / "evaluations" / model_id
    report_dir.mkdir(parents=True, exist_ok=True)
    with open(
        report_dir / "evaluation_report.yml", "w", encoding="utf-8"
    ) as report_file:
        yaml.safe_dump(report, report_file)
def dump_weights(model, output_path, run_id, filename):
    """
    Save the ``state_dict`` of ``model`` with ``torch.save``.

    The file is written to ``output_path/models/<run_id>/weights/<filename>``.
    Missing parent directories are created.
    """
    target = Path(output_path).joinpath("models", run_id, "weights", filename)
    target.parent.mkdir(parents=True, exist_ok=True)
    torch.save(model.state_dict(), target)
def dump_model_config(output_path, run_id, model_config):
    """
    Write ``model_config.model_dump()`` as YAML to
    ``output_path/models/<run_id>/model_config.yml``.

    Missing parent directories are created.
    """
    config_file = Path(output_path).joinpath("models", run_id, "model_config.yml")
    config_file.parent.mkdir(parents=True, exist_ok=True)
    with open(config_file, "w", encoding="utf-8") as stream:
        config_dict = model_config.model_dump()
        yaml.safe_dump(config_dict, stream)
def dump_profiling_results(output_path, run_id, prof):
    """
    Export profiler results to ``output_path/models/<run_id>/profiler/``.

    The following files are written:
    - a Chrome trace (``chrome_trace.json.gz``);
    - a memory timeline (``memory_trace.html``);
    - CPU and CUDA stack exports (``cpu_stacks.txt``, ``cuda_stacks.txt``);
    - a text summary of the top 10 ``key_averages()`` rows by self CPU time and by
      self CUDA time (``profiling_summary.txt``).
    """
    out_dir = Path(output_path) / "models" / run_id / "profiler"
    out_dir.mkdir(parents=True, exist_ok=True)

    prof.export_chrome_trace(str(out_dir / "chrome_trace.json.gz"))
    prof.export_memory_timeline(str(out_dir / "memory_trace.html"))
    stack_exports = (
        ("cpu_stacks.txt", "self_cpu_time_total"),
        ("cuda_stacks.txt", "self_cuda_time_total"),
    )
    for stacks_file, metric in stack_exports:
        prof.export_stacks(str(out_dir / stacks_file), metric)

    summary_sections = (
        ("CPU Profiling Summary:\n", "self_cpu_time_total"),
        ("\n\nCUDA Profiling Summary:\n", "self_cuda_time_total"),
    )
    with open(out_dir / "profiling_summary.txt", "w", encoding="utf-8") as f:
        for header, sort_key in summary_sections:
            f.write(header)
            f.write(prof.key_averages().table(sort_by=sort_key, row_limit=10))