Source code for lisbet.cli.commands.train

"""Train a model for keypoint classification and export the embedder."""

import argparse
import textwrap
from pathlib import Path

from lisbet.cli.common import add_data_io_args, add_keypoints_args, add_verbosity_args


def parse_data_augmentation(aug_string):
    """Parse a data augmentation string into a list of augmentation config dicts.

    Parameters
    ----------
    aug_string : str or None
        Comma-separated augmentation specifications, each with optional
        parameters.
        Format: name:p=value:frac=value
        Example: "all_perm_id:p=0.5,blk_perm_id:p=0.3:frac=0.2"

    Returns
    -------
    list[dict] or None
        List of dictionaries with augmentation configs, or None if None/empty.
        Every dictionary has a ``name`` (str) and a ``p`` (float, 1.0 unless
        overridden); all other parameters are stored as floats.

    Raises
    ------
    ValueError
        If a specification has no augmentation name, if a parameter is not of
        the form ``key=value``, if a parameter is called ``name`` (it would
        overwrite the augmentation name), or if a value is not a valid float.

    Examples
    --------
    >>> parse_data_augmentation("all_perm_id")
    [{'name': 'all_perm_id', 'p': 1.0}]

    >>> parse_data_augmentation("all_perm_id:p=0.5,blk_perm_id:frac=0.3")
    [{'name': 'all_perm_id', 'p': 0.5}, {'name': 'blk_perm_id', 'p': 1.0, 'frac': 0.3}]

    >>> parse_data_augmentation("rotation:p=0.5:max_angle=30")
    [{'name': 'rotation', 'p': 0.5, 'max_angle': 30.0}]
    """
    if not aug_string:
        return None

    augmentations = []
    for aug_spec in aug_string.split(","):
        aug_spec = aug_spec.strip()
        if not aug_spec:
            # Tolerate stray commas such as "a,,b" or a trailing ",".
            continue

        name, *params = aug_spec.split(":")
        name = name.strip()
        if not name:
            # ":p=0.5" would otherwise produce an augmentation named "".
            raise ValueError(f"Missing augmentation name in '{aug_spec}'")

        aug_config = {"name": name, "p": 1.0}

        # Parse parameters
        for param in params:
            key, _, value = param.partition("=")
            key = key.strip()
            value = value.strip()
            if not key or not value:
                raise ValueError(f"Invalid parameter format in '{aug_spec}': {param}")
            if key == "name":
                # "name" is reserved: accepting it would replace the
                # augmentation name with a float.
                raise ValueError(
                    f"Invalid parameter in '{aug_spec}': 'name' is reserved"
                )
            try:
                aug_config[key] = float(value)
            except ValueError as e:
                raise ValueError(
                    f"Invalid parameter value in '{aug_spec}': {key}={value}"
                ) from e

        augmentations.append(aug_config)

    return augmentations if augmentations else None
def configure_train_model_parser(parser: argparse.ArgumentParser) -> None:
    """Configure train_model command parser.

    Registers the shared verbosity, keypoints and data I/O arguments, followed
    by the training, data augmentation, sampling, model architecture, weights,
    performance and miscellaneous options of the ``train_model`` command.

    Parameters
    ----------
    parser : argparse.ArgumentParser
        Parser to which the arguments are added in place.
    """
    add_verbosity_args(parser)
    add_keypoints_args(parser)
    add_data_io_args(parser, "Keypoint data location")

    parser.add_argument("--epochs", default=10, type=int, help="Number of epochs")
    parser.add_argument("--batch_size", default=32, type=int, help="Batch size")
    parser.add_argument(
        "--learning_rate", default=1e-4, type=float, help="Learning rate"
    )
    parser.add_argument(
        "--task_ids",
        type=str,
        default="multiclass",
        help=textwrap.dedent(
            """\
            Task ID or comma-separated list of task IDS.
            Valid (supervised) tasks are:
              - multiclass: Multi-Class Frame Classification
              - multilabel: Multi-Label Frame Classification
            Valid (self-supervised) tasks are:
              - cons: Group Consistency Classification
              - order: Temporal Order Classification
              - shift: Temporal Shift Classification
              - warp: Temporal Warp Classification
            Example: order,cons
            """
        ),
    )
    parser.add_argument(
        "--task_data",
        type=str,
        help="i.e., multiclass:[0],order:[0,1]",
    )
    parser.add_argument("--seed", default=1991, type=int, help="Base RNG seed")
    parser.add_argument("--run_id", type=str, help="ID of the run")
    parser.add_argument(
        "--data_augmentation",
        type=str,
        help=textwrap.dedent(
            """\
            Data augmentation techniques to apply, comma-separated.
            Each augmentation can have optional parameters specified with colons.

            NOTE: Use the --data_augmentation= format (with equals sign) to clearly
                  separate the argument from its value. This makes the syntax
                  unambiguous and prevents shell interpretation of special characters.

            Valid options are:
              - all_perm_id: Randomly permute identities of individuals, applied
                             consistently across all frames in a window.
              - all_perm_ax: Randomly permute x, y (and z) axes, applied
                             consistently across all frames in a window.
              - blk_perm_id: Randomly permute identities of individuals, applied
                             to a contiguous block of frames within a window.
              - gauss_jitter: Randomly add N(0,sigma) noise applied consistently
                              in a window.
              - kp_ablation: Randomly set keypoint coordinates to NaN (missing data)
                             applied consistently across all frames in a window.
                             Use Bernoulli(pB) to select which keypoints to ablate.
                             Simulates sporadic occlusions or tracking failures.
              - rotation: Randomly rotate keypoint coordinates around the center
                          of the normalized [0,1] space. Supports 2D and 3D
                          (auto-detected from data).

            Parameters (optional):
              - p=<float>: Probability of applying the transformation (default: 1.0)
              - frac=<float>: For block-based augmentations (blk_perm_id,
                              blk_gauss_jitter). Block size fraction
                              (defaults: 0.5 / 0.05 / 0.1)
              - sigma=<float>: Jitter noise std for gauss_jitter and
                               blk_gauss_jitter (default 0.01).
              - max_angle=<float>: Maximum rotation angle in degrees for rotation
                                   (default 180.0). Angle sampled from
                                   [-max_angle, +max_angle].

            Examples:
              --data_augmentation="all_perm_id"
              --data_augmentation="all_perm_id:p=0.5"
              --data_augmentation="all_perm_id:p=0.5,blk_perm_id:p=0.3:frac=0.2"
              --data_augmentation="all_perm_ax:p=0.7,blk_perm_id:frac=0.3"
              --data_augmentation="gauss_jitter:p=0.02:sigma=0.01"
              --data_augmentation="kp_ablation:p=0.05:pB=0.01"
              --data_augmentation="all_perm_id:p=0.5,kp_ablation:p=0.03:pB=0.01"
              --data_augmentation="rotation:p=0.5:max_angle=30"
            """
        ),
    )
    parser.add_argument(
        "--train_sample", type=float, help="Fraction of samples from the train set"
    )
    parser.add_argument(
        "--dev_sample", type=float, help="Fraction of samples from the dev set"
    )
    parser.add_argument(
        "--dev_ratio",
        type=float,
        help="Fraction of the train set to be held out as dev set",
    )

    # Model architecture
    parser.add_argument(
        "--backbone_preset", default="transformer-base", type=str, help="Backbone type"
    )
    parser.add_argument(
        "--head_type",
        type=str,
        choices=["mlp", "linear"],
        default="mlp",
        help="Classification head type",
    )
    parser.add_argument(
        "--set",
        metavar="KEY=VALUE",
        action="append",
        help="Override config values, e.g. --set backbone.num_layers=4",
    )

    # Model weights and saving options
    parser.add_argument(
        "--load_backbone_weights",
        type=Path,
        help="Path to backbone weights from pretrained model",
    )
    parser.add_argument(
        "--freeze_backbone_weights",
        action="store_true",
        help="Freeze the backbone weights",
    )
    parser.add_argument(
        "--save_weights",
        default="last",
        choices=["all", "last"],
        # The help lists exactly the values accepted by `choices`; advertising
        # any other value would lead users straight into an argparse error.
        help="Save 'all' or 'last' model weights",
    )
    parser.add_argument(
        "--save_history", action="store_true", help="Save model's training history"
    )

    # Performance options
    parser.add_argument(
        "--mixed_precision",
        action="store_true",
        help="Run training in mixed precision mode",
    )

    # Miscellaneous
    parser.add_argument("--dry_run", action="store_true", help="Print config and exit")
def configure_export_embedder_parser(parser: argparse.ArgumentParser) -> None:
    """Configure export_embedder command parser.

    Adds the shared verbosity arguments, the two required positional paths
    (``model_path`` and ``weights_path``) and the optional ``--output_path``,
    which defaults to the current directory.

    Parameters
    ----------
    parser : argparse.ArgumentParser
        Parser to which the arguments are added in place.
    """
    add_verbosity_args(parser)

    # Positional arguments, registered in command-line order.
    positionals = (
        ("model_path", "Path to model config"),
        ("weights_path", "Path to model weights"),
    )
    for dest, description in positionals:
        parser.add_argument(dest, type=Path, help=description)

    current_dir = Path(".")
    parser.add_argument(
        "--output_path", default=current_dir, type=Path, help="Output path"
    )
def train_model(kwargs):
    """Train a model for keypoint classification.

    Builds the backbone, data, model, training and experiment configurations
    from the parsed command-line arguments, then either prints the experiment
    configuration (when ``dry_run`` is set) or starts training.

    Parameters
    ----------
    kwargs : dict
        Parsed command-line arguments. Keys accessed directly here are
        ``backbone_preset``, ``window_size``, ``window_offset``, ``set``,
        ``task_ids``, ``task_data``, ``run_id``, ``seed``,
        ``data_augmentation``, ``output_path`` and ``dry_run``. The whole
        mapping is also passed to ``DataConfig.model_validate`` and (with the
        parsed augmentations) to ``TrainingConfig.model_validate``.

    Raises
    ------
    ValueError
        If ``backbone_preset`` is not a key of ``BACKBONE_PRESETS``, or if the
        data augmentation string is malformed.
    """
    # Lazy imports to avoid unnecessary dependencies when not training
    from pydantic import TypeAdapter

    from lisbet.config.presets import BACKBONE_PRESETS
    from lisbet.config.schemas import (
        BackboneConfig,
        DataAugmentationPipeline,
        DataConfig,
        ExperimentConfig,
        ModelConfig,
        TrainingConfig,
    )
    from lisbet.training import train

    # Configure backbone
    preset_name = kwargs.get("backbone_preset", "transformer-base")
    if preset_name not in BACKBONE_PRESETS:
        raise ValueError(f"Unknown backbone preset: {preset_name}")

    # Work on a copy: the preset is shared module-level state, and the
    # assignments below would otherwise leak max_length and --set overrides
    # into every later use of the same preset within this process.
    backbone_config_dict = dict(BACKBONE_PRESETS[preset_name])

    # Set max_length for transformer backbones to window_size
    if backbone_config_dict.get("type") == "transformer":
        backbone_config_dict["max_length"] = kwargs.get("window_size")

    # Parse overrides from --set backbone.*=...
    # Entries without the "backbone." prefix or without "=" are ignored.
    prefix = "backbone."
    overrides = {}
    for override in kwargs.get("set", []) or []:
        if not override.startswith(prefix):
            continue
        key, sep, val = override[len(prefix) :].partition("=")
        if sep:
            # Values stay strings here; validation is left to BackboneConfig.
            overrides[key] = val
    backbone_config_dict.update(overrides)

    # Create backbone config
    adapter = TypeAdapter(BackboneConfig)
    backbone_config = adapter.validate_python(backbone_config_dict)

    # Configure data
    data_config = DataConfig.model_validate(kwargs)

    # Configure tasks
    task_ids_list = kwargs["task_ids"].split(",")

    # NOTE: For now we keep the task_data as a string, but it could be parsed into a
    #       dict to simplify `split_multi_records`. Or even better, use a TaskConfig
    #       class to handle task-specific configurations.
    task_data = kwargs["task_data"]

    # Configure model
    model_config = ModelConfig(
        model_id=kwargs["run_id"],
        backbone=backbone_config,
        out_heads={task_id: {} for task_id in task_ids_list},
        input_features={},
        window_size=kwargs["window_size"],
        window_offset=kwargs["window_offset"],
    )

    # Parse and configure data augmentation
    aug_string = kwargs.get("data_augmentation")
    parsed_augmentation = parse_data_augmentation(aug_string)

    if parsed_augmentation is not None:
        # Validate the raw dicts through the pipeline schema.
        pipeline = DataAugmentationPipeline(augmentations=parsed_augmentation)
        validated_augmentations = pipeline.augmentations
    else:
        validated_augmentations = parsed_augmentation

    # Update kwargs with parsed augmentation (without mutating the caller's dict)
    kwargs_for_training = {**kwargs, "data_augmentation": validated_augmentations}

    # Configure training
    training_config = TrainingConfig.model_validate(kwargs_for_training)

    # Create experiment configuration
    experiment_config = ExperimentConfig(
        run_id=kwargs["run_id"],
        seed=kwargs["seed"],
        model=model_config,
        training=training_config,
        data=data_config,
        task_ids_list=task_ids_list,
        task_data=task_data,
        output_path=kwargs["output_path"],
    )

    if kwargs.get("dry_run"):
        # If dry run, just print the configuration
        print(experiment_config)
    else:
        # Train the model
        train(experiment_config)