Source code for lisbet.training.utils

"""Utility functions for model training."""

import hashlib
import logging
import struct

import numpy as np
import torch
import torch.distributed as dist
from lightning.fabric.utilities.data import suggested_max_num_workers
from torch.utils.data import get_worker_info


def generate_seeds(seed, task_ids):
    """Derive a dictionary of named sub-seeds from a single base seed.

    A NumPy ``Generator`` is created from ``seed`` and one integer is drawn
    per key, in the order the keys are listed below, so the same ``seed`` and
    ``task_ids`` always produce the same mapping.

    Keys, in draw order:

    * ``"torch"``, ``"dev_split"``, ``"test_split"``
    * ``"{group}_shuffle_{task_id}"`` for every task, with ``group`` running
      over ``train``, ``dev``, ``test`` inside each task
    * ``"transform_{task_id}"`` for every task
    * ``"dataset_{task_id}"`` for every task

    Parameters
    ----------
    seed : int or None
        Base seed handed to ``np.random.default_rng``.
    task_ids : iterable
        Task identifiers interpolated into the key names. It is traversed
        three separate times.

    Returns
    -------
    run_seeds : dict
        Mapping from key name to an integer drawn with ``low=0`` and
        ``high=2**31 - 1`` (``high`` is exclusive under NumPy's default).
    """
    generator = np.random.default_rng(seed)

    # Build the key list step by step; the order fixes which draw each key gets.
    names = ["torch", "dev_split", "test_split"]
    for task_id in task_ids:
        for group in ("train", "dev", "test"):
            names.append(f"{group}_shuffle_{task_id}")
    names.extend(f"transform_{task_id}" for task_id in task_ids)
    names.extend(f"dataset_{task_id}" for task_id in task_ids)

    # One draw per listed name (a repeated name keeps its last draw).
    run_seeds = {}
    for name in names:
        run_seeds[name] = generator.integers(low=0, high=2**31 - 1, dtype=int)

    logging.debug("Generated seeds: %s", run_seeds)
    return run_seeds
def estimate_num_workers(n_tasks, batch_size, batch_size_per_worker=8):
    """
    Estimate how many DataLoader worker processes to use.

    The result is the smallest of three limits, never less than one:

    * ``max_workers``: ``suggested_max_num_workers(1) + 1``
    * ``fair_share``: ``max_workers`` floor-divided by ``max(1, n_tasks)``
    * ``batch_cap``: ``batch_size // batch_size_per_worker``, at least one

    Parameters
    ----------
    n_tasks : int
        The number of training tasks (e.g., datasets or splits) being
        processed.
    batch_size : int
        The total batch size used for loading data.
    batch_size_per_worker : int, optional
        The target batch size to be handled by each worker process
        (default: 8).

    Returns
    -------
    num_workers : int
        The estimated number of DataLoader worker processes to use.
    """
    # Fixed at one here; in distributed training this would come from
    # torch.distributed.get_world_size().
    local_world_size = 1
    max_workers = 1 + suggested_max_num_workers(local_world_size)

    # Split the worker budget across tasks (treating fewer than one task as one).
    task_divisor = max(1, n_tasks)
    fair_share = max_workers // task_divisor

    # Limit workers so each one handles roughly batch_size_per_worker samples.
    samples_ratio = batch_size // batch_size_per_worker
    batch_cap = max(1, samples_ratio)

    # Take the tightest limit, but always use at least one worker.
    tightest_limit = min(max_workers, fair_share, batch_cap)
    num_workers = max(1, tightest_limit)

    logging.debug(
        "Estimated num_workers: %d (max_workers: %d, fair_share: %d, batch_cap: %d)",
        num_workers,
        max_workers,
        fair_share,
        batch_cap,
    )
    return num_workers
def worker_init_fn(worker_id: int):
    """
    Worker initialization function for DataLoader.

    Derives a per-worker seed from the dataset's ``base_seed`` attribute, the
    distributed rank (0 when ``torch.distributed`` is not initialized) and
    ``worker_id``, then stores a ``torch.Generator`` seeded with it on the
    worker's dataset as ``g``.

    The three values are packed big-endian as one unsigned 32-bit and two
    unsigned 16-bit integers and hashed with BLAKE2b (8-byte digest). The
    digest, read as a big-endian integer, is the seed.

    Parameters
    ----------
    worker_id : int
        The ID of the worker being initialized. This is typically an integer
        in the range [0, num_workers - 1].
    """
    # Dataset object as seen by this worker; it must expose `base_seed`.
    dataset = get_worker_info().dataset

    rank = 0
    if dist.is_initialized():
        rank = dist.get_rank()

    # Hash (base_seed, rank, worker_id) into a 64-bit seed for this worker.
    packed = struct.pack(">IHH", dataset.base_seed, rank, worker_id)
    digest = hashlib.blake2b(packed, digest_size=8).digest()
    seed = int.from_bytes(digest, "big")

    generator = torch.Generator()
    generator.manual_seed(seed)
    dataset.g = generator

    # Only rank 0 writes the debug message.
    if rank != 0:
        return

    logging.debug(
        "Worker %d initialized with seed %d (base seed: %d, rank: %d)",
        worker_id,
        seed,
        dataset.base_seed,
        rank,
    )