Source code for lisbet.training.core

"""Training and fitting functions for LISBET.

Notes
-----
[a] The dictionary of RNG seeds could be refactored into a Pydantic model in the future.

[b] The train/dev split is performed here and not in the input_pipeline module to
    emphasize that the test set is frozen and won't be used for hyperparameter tuning.

[c] When mixing datasets of different lengths, the training and evaluation loops will
    stop after exhausting the shortest one. Please consider using random sampling.

"""

import logging
import os
from contextlib import nullcontext
from datetime import datetime

import numpy as np
import torch
from lightning.fabric import Fabric
from lightning.fabric.loggers import CSVLogger
from torch.profiler import ProfilerActivity, profile, schedule
from torch.utils.data import DataLoader
from torchinfo import summary
from tqdm.auto import trange

from lisbet.config.schemas import ExperimentConfig
from lisbet.io import (
    dump_model_config,
    dump_profiling_results,
    dump_weights,
    load_multi_records,
)
from lisbet.modeling.factory import create_model_from_config
from lisbet.training.preprocessing import split_multi_records
from lisbet.training.tasks import configure_tasks
from lisbet.training.utils import estimate_num_workers, generate_seeds, worker_init_fn


def _configure_profiler(steps_multiplier):
    """Internal helper. Configures the profiler."""
    if os.environ.get("TORCH_PROFILER", "0") == "1":
        logging.info("Profiler is enabled.")

        profiler = profile(
            activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
            schedule=schedule(
                skip_first=4 * steps_multiplier,
                wait=steps_multiplier,
                warmup=steps_multiplier,
                active=8 * steps_multiplier,
                repeat=1,
            ),
            record_shapes=True,
            profile_memory=True,
            with_stack=True,
            # NOTE: ExperimentalConfig needed until bug in torch.profiler is fixed, see
            #       https://github.com/pytorch/pytorch/issues/100253
            experimental_config=torch._C._profiler._ExperimentalConfig(verbose=True),
        )

    else:
        logging.debug("Profiler is disabled.")
        profiler = nullcontext()

    return profiler


def _build_model(training_config, model_config):
    """Internal helper. Build the LISBET model using the config factory.

    Parameters
    ----------
    training_config : object
        Training configuration. Only ``load_backbone_weights`` (a checkpoint path,
        or a falsy value to skip loading) and ``freeze_backbone_weights`` are read.
    model_config : object
        Model configuration forwarded to ``create_model_from_config``.

    Returns
    -------
    model : torch.nn.Module
        The newly built model. It is optionally initialized from a checkpoint and
        optionally has its backbone parameters frozen (``requires_grad=False``).
    """
    model = create_model_from_config(model_config)

    if training_config.load_backbone_weights:
        # Map tensors to CPU: the model has not been moved to its target device yet
        # (``fabric.setup`` does that later in ``train``). Without ``map_location``,
        # a checkpoint saved from a GPU run fails to load on a CPU-only machine.
        state_dict = torch.load(
            training_config.load_backbone_weights,
            map_location="cpu",
            weights_only=True,
        )
        # strict=False tolerates partial checkpoints; presumably a backbone-only
        # checkpoint without the task heads. Mismatches are reported, not raised.
        incompatible_layers = model.load_state_dict(state_dict, strict=False)
        logging.info(
            "Loaded weights from file.\nMissing keys: %s\nUnexpected keys: %s",
            incompatible_layers.missing_keys,
            incompatible_layers.unexpected_keys,
        )

    if training_config.freeze_backbone_weights:
        # Frozen parameters are later excluded from the optimizer, which only
        # receives parameters with requires_grad=True.
        for param in model.backbone.parameters():
            param.requires_grad = False

    return model


def _configure_optimizer_and_scheduler(model, learning_rate):
    """Internal helper. Configures optimizer, scheduler, and scaler."""
    # Configure optimizer
    optimizer = torch.optim.AdamW(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=learning_rate,
    )

    # Configure LR (warmup scheduler)
    warmup_epochs = 5
    warmup_scheduler = torch.optim.lr_scheduler.LinearLR(
        optimizer,
        start_factor=1e-2,
        end_factor=1.0,
        total_iters=warmup_epochs,
    )

    # Configure LR (main scheduler)
    T_0 = 10
    T_mult = 2
    main_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer, T_0=T_0, T_mult=T_mult
    )

    # Configure final LR scheduler
    scheduler = torch.optim.lr_scheduler.SequentialLR(
        optimizer,
        schedulers=[warmup_scheduler, main_scheduler],
        milestones=[warmup_epochs],
    )

    return optimizer, scheduler


def _configure_dataloaders(tasks, group, batch_size, sample_ratio, pin_memory):
    """Internal helper. Configure one dataloader per task for a data group.

    Parameters
    ----------
    tasks : list
        Configured tasks. Each one must expose a ``<group>_dataset`` attribute with
        an ``n_frames`` attribute.
    group : str
        Data group name, e.g. ``"train"`` or ``"dev"``.
    batch_size : int
        Batch size used by every dataloader.
    sample_ratio : float or None
        Optional fraction of the batches to use per epoch. ``None`` uses them all.
    pin_memory : bool
        Forwarded to ``DataLoader``.

    Returns
    -------
    dataloaders : list of torch.utils.data.DataLoader
        One dataloader per task, in the same order as ``tasks``.
    n_batches : int
        Number of batches to draw per epoch. It is bounded by the task with the
        fewest frames, because all tasks are iterated in lockstep (see note [c] in
        the module docstring).
    """
    # Estimate number of batches from the shortest dataset
    min_frames = min(getattr(task, f"{group}_dataset").n_frames for task in tasks)
    n_batches = int(np.ceil(min_frames / batch_size))
    if sample_ratio is not None:
        n_batches = int(n_batches * sample_ratio)
    logging.info("Using %d samples from the %s group", n_batches * batch_size, group)

    # Estimate number of workers
    num_workers = estimate_num_workers(len(tasks), batch_size, batch_size_per_worker=4)

    # prefetch_factor and persistent_workers are only valid with worker processes.
    # DataLoader raises ValueError if either is set while num_workers == 0.
    worker_kwargs = (
        {"prefetch_factor": 4, "persistent_workers": True} if num_workers > 0 else {}
    )

    # Create a dataloader for each task
    dataloaders = [
        DataLoader(
            getattr(task, f"{group}_dataset"),
            batch_size=batch_size,
            num_workers=num_workers,
            pin_memory=pin_memory,
            worker_init_fn=worker_init_fn,
            **worker_kwargs,
        )
        for task in tasks
    ]

    return dataloaders, n_batches


def _train_one_epoch(
    model,
    dataloaders,
    n_batches,
    optimizer,
    scheduler,
    tasks,
    prof,
    fabric,
):
    """Internal helper. Run one training epoch over all tasks.

    For each of the ``n_batches`` steps, one batch is drawn from every task's
    dataloader. Each task's loss is back-propagated separately through
    ``fabric.backward``, so gradients from all tasks accumulate. The optimizer then
    takes a single step. The scheduler is stepped once, at the end of the epoch.

    Only every 10th batch updates the tasks' ``train_loss`` and ``train_score``
    metrics. ``prof.step()`` is called once per task per batch when a profiler is
    given. Each dataloader must yield at least ``n_batches`` batches; otherwise
    ``next`` raises ``StopIteration``.
    """
    model.train()

    batch_iterators = list(map(iter, dataloaders))

    for step in trange(n_batches, desc="Training batches", leave=False):
        optimizer.zero_grad(set_to_none=True)
        record_stats = step % 10 == 0

        # NOTE: strict=False to allow for different iterable lengths
        for task, batch_iterator in zip(tasks, batch_iterators, strict=False):
            inputs, labels = next(batch_iterator)

            predictions = model(inputs, task.task_id)
            task_loss = task.loss_function(predictions, labels)
            fabric.backward(task_loss)

            if record_stats:
                task.train_loss.update(task_loss)
                task.train_score.update(predictions, labels)

            if prof is not None:
                prof.step()

        # One optimizer update per batch, using the gradients summed over all tasks
        optimizer.step()

    # The LR schedule advances once per epoch
    scheduler.step()


def _evaluate(model, dataloaders, n_batches, tasks):
    """Internal helper. Evaluate the model on the dev group.

    Updates each task's ``dev_loss`` and ``dev_score`` metrics; it does not compute
    or reset them. Like the training loop, only every 10th batch
    (``batch_idx % 10 == 0``) contributes to the statistics. The batches in between
    are still drawn from the dataloaders, so the data stream advances as if every
    batch were evaluated. No forward pass is run on them, however, because their
    outputs would be discarded anyway.

    Parameters
    ----------
    model : torch.nn.Module
        Model called as ``model(data, task_id)``.
    dataloaders : list
        One dataloader per task, in the same order as ``tasks``. Each must yield at
        least ``n_batches`` batches; otherwise ``next`` raises ``StopIteration``.
    n_batches : int
        Number of batches to draw from every dataloader.
    tasks : list
        Tasks exposing ``task_id``, ``loss_function``, ``dev_loss`` and
        ``dev_score``.
    """
    model.eval()

    dl_iter = [iter(dl) for dl in dataloaders]

    with torch.no_grad():
        # Iterate over all batches
        for batch_idx in trange(n_batches, desc="Evaluation batches", leave=False):
            record_stats = batch_idx % 10 == 0

            # Iterate over all tasks
            # NOTE: strict=False to allow for different iterable lengths
            for task, dataloader in zip(tasks, dl_iter, strict=False):
                data, target = next(dataloader)

                if not record_stats:
                    # Output would be discarded: skip the forward pass
                    continue

                # Forward pass
                output = model(data, task.task_id)
                loss = task.loss_function(output, target)

                # Store loss value and metrics for stats
                task.dev_loss.update(loss)
                task.dev_score.update(output, target)


def _compute_epoch_logs(group_id, tasks):
    """Internal helper. Computes metrics and mean losses for an epoch."""
    epoch_log = {}
    for task in tasks:
        # Compute metrics
        metric_name = f"{task.task_id}_{group_id}_score"
        epoch_log[metric_name] = getattr(task, f"{group_id}_score").compute()
        getattr(task, f"{group_id}_score").reset()

        # Compute mean losses
        loss_name = f"{task.task_id}_{group_id}_loss"
        epoch_log[loss_name] = getattr(task, f"{group_id}_loss").compute()
        getattr(task, f"{group_id}_loss").reset()

    return epoch_log


def train(experiment_config: ExperimentConfig) -> torch.nn.Module:
    """
    Train a LISBET model.

    This function orchestrates the full training pipeline for LISBET, including
    data loading, model construction, training, evaluation, and saving artifacts.
    All parameters match the CLI arguments exactly.

    Parameters
    ----------
    experiment_config : ExperimentConfig
        Configuration object containing all parameters for the training run. It
        includes data paths, model architecture, training hyperparameters, and
        task definitions. Must be a Pydantic model.

    Returns
    -------
    model : torch.nn.Module
        The trained LISBET model instance, as returned by ``fabric.setup``.

    Raises
    ------
    RuntimeError
        If ``data.window_offset`` is negative or not smaller than
        ``data.window_size``. The check runs before any data is loaded.

    Notes
    -----
    All arguments are exposed for CLI and documentation. For advanced usage, see
    the LISBET documentation.

    The configuration is updated in place: ``model.backbone.feature_dim``,
    ``model.input_features`` and ``model.out_heads`` are filled from the data and
    the configured tasks. Artifacts are written through ``dump_model_config``,
    ``dump_weights`` and ``dump_profiling_results`` using
    ``experiment_config.output_path`` and the run ID. The per-epoch history goes to
    a CSV logger under ``<output_path>/models/<run_id>``.
    """
    # Create aliases for configuration parameters
    model_config = experiment_config.model
    backbone_config = model_config.backbone
    data_config = experiment_config.data
    training_config = experiment_config.training

    # Validate the window geometry up front, before creating loggers or loading
    # data, so that an invalid configuration fails fast.
    if not (data_config.window_size > data_config.window_offset >= 0):
        raise RuntimeError(
            "Window offset must be a positive integer smaller than the window size"
            f" or zero, got {data_config.window_offset}."
        )
    # Backbone output token index (negative, i.e. counted from the window end);
    # only logged here.
    output_token_idx = -(data_config.window_offset + 1)
    logging.debug("Output token IDX = %d", output_token_idx)

    # Configure base runtime arguments
    run_id = (
        datetime.now().strftime("%Y%m%d%H%M%S")
        if experiment_config.run_id is None
        else experiment_config.run_id
    )

    # Create Fabric instance
    precision = "16-mixed" if training_config.mixed_precision else "32-true"
    history_logger = CSVLogger(
        experiment_config.output_path / "models" / run_id,
        name="training_history",
        flush_logs_every_n_steps=1,
    )
    fabric = Fabric(precision=precision, loggers=history_logger)
    logging.info("Using %s for training model %s.", fabric.device.type, run_id)

    # Configure RNGs
    run_seeds = generate_seeds(experiment_config.seed, experiment_config.task_ids_list)
    torch.manual_seed(run_seeds["torch"])

    # Load records
    # TODO: Switch to the DataConfig object
    multi_records = load_multi_records(data_config)

    # Split records
    train_rec, dev_rec = split_multi_records(
        multi_records=multi_records,
        dev_ratio=data_config.dev_ratio,
        dev_seed=run_seeds.get("dev_split"),
        task_ids=experiment_config.task_ids_list,
        task_data=experiment_config.task_data,
    )

    # Determine data shape and input features from the first training record of
    # the first task.
    # NOTE(review): assumes every record shares this individuals/keypoints/space
    #               layout - confirm against the loaders.
    first_record = train_rec[experiment_config.task_ids_list[0]][0]
    coords = first_record.posetracks.coords
    cdim = coords.sizes
    feature_dim = cdim["individuals"] * cdim["keypoints"] * cdim["space"]
    input_features = {
        dim: coords[dim].values.tolist()
        for dim in ("individuals", "keypoints", "space")
    }

    if training_config.load_backbone_weights is not None:
        logging.warning(
            "Loading backbone weights from %s. If you are not experimenting with "
            "transfer learning, please verify that the input features of the "
            "pre-trained model match those of your data. In the future, this warning "
            "could become a requirement to load the backbone weights.",
            training_config.load_backbone_weights,
        )

    # Select head hidden dimension based on head type
    head_hidden_dim = (
        None if training_config.head_type == "linear" else backbone_config.hidden_dim
    )
    logging.debug("Head(s) hidden dimension = %s", head_hidden_dim)

    # Configure tasks
    tasks = configure_tasks(
        train_rec,
        dev_rec,
        experiment_config.task_ids_list,
        data_config.window_size,
        data_config.window_offset,
        backbone_config.embedding_dim,
        head_hidden_dim,
        training_config.data_augmentation,
        run_seeds,
        fabric.device,
    )
    n_tasks = len(tasks)

    # Set dynamic attributes for backbone
    backbone_config.feature_dim = feature_dim

    # Set dynamic attributes for model config
    model_config.input_features = input_features
    model_config.out_heads = {task.task_id: task.head.get_config() for task in tasks}

    # Build model
    model = _build_model(training_config, model_config)
    model_stats = summary(model, verbose=0)
    logging.info("Model summary\n%s", model_stats)

    # Optimizer and scheduler
    optimizer, scheduler = _configure_optimizer_and_scheduler(
        model, training_config.learning_rate
    )

    # Save model config
    dump_model_config(experiment_config.output_path, run_id, model_config)

    # Configure dataloaders (the dev group exists only when a dev split is requested)
    has_dev_split = data_config.dev_ratio is not None
    pin_memory = fabric.device.type == "cuda"
    train_dataloaders, train_n_batches = _configure_dataloaders(
        tasks,
        "train",
        training_config.batch_size,
        data_config.train_sample,
        pin_memory,
    )
    if has_dev_split:
        dev_dataloaders, dev_n_batches = _configure_dataloaders(
            tasks,
            "dev",
            training_config.batch_size,
            data_config.dev_sample,
            pin_memory,
        )

    # Configure Fabric
    model, optimizer = fabric.setup(model, optimizer)
    train_dataloaders = [fabric.setup_dataloaders(dl) for dl in train_dataloaders]
    if has_dev_split:
        dev_dataloaders = [fabric.setup_dataloaders(dl) for dl in dev_dataloaders]

    # Training loop. The profiler schedule is scaled by the number of tasks because
    # _train_one_epoch steps the profiler once per task per batch.
    with _configure_profiler(steps_multiplier=n_tasks) as prof:
        for epoch in range(training_config.epochs):
            history_entry = {"epoch": epoch}
            print(f"Epoch {epoch}")
            logging.info("Current LR = %f", scheduler.get_last_lr()[0])

            # Run training epoch
            _train_one_epoch(
                model,
                train_dataloaders,
                train_n_batches,
                optimizer,
                scheduler,
                tasks,
                prof,
                fabric,
            )

            # Update history entry for current epoch
            history_entry.update(_compute_epoch_logs("train", tasks))

            # Save weights, if requested
            if training_config.save_weights == "all":
                dump_weights(
                    model,
                    experiment_config.output_path,
                    run_id,
                    f"weights_epoch{epoch}.pt",
                )

            if has_dev_split:
                # Run dev epoch
                _evaluate(model, dev_dataloaders, dev_n_batches, tasks)

                # Update history entry for current epoch
                history_entry.update(_compute_epoch_logs("dev", tasks))

            # Update history. The epoch counter is an int; only metrics get the
            # fixed-point format.
            fabric.log_dict(history_entry, step=epoch)
            logging.info(
                ", ".join(
                    f"{key}: {value}" if key == "epoch" else f"{key}: {value:.3f}"
                    for key, value in history_entry.items()
                )
            )

    # Save profiling results, if requested. This is done after the profiler context
    # has exited, so profiling has been stopped. `prof` is None when profiling is
    # disabled.
    if prof is not None:
        dump_profiling_results(experiment_config.output_path, run_id, prof)

    # Save final weights, if requested
    if training_config.save_weights == "last":
        dump_weights(model, experiment_config.output_path, run_id, "weights_last.pt")

    return model