lisbet.transforms_extra

Augmentation module for transforming samples in a dataset.

This module provides data augmentation and preprocessing transforms for pose tracking datasets stored as xarray.Dataset objects. The transforms can be used in training pipelines to improve model robustness and generalization.

Available Transforms

RandomPermutation

Randomly permutes both coordinate labels and their associated data together across the entire time window. Useful for making models invariant to coordinate ordering (e.g., individual identities, spatial axes).

RandomBlockPermutation

Randomly permutes data within a contiguous block of frames while keeping coordinate labels unchanged. This simulates identity swaps, such as those caused by tracking errors, within part of the window, and provides a harder augmentation than RandomPermutation.

RandomRotation

Applies a random rotation to keypoint coordinates in normalized [0, 1] space. Supports 2D and 3D keypoints with configurable maximum angle and post-rotation normalization modes (truncate, rescale, or none).

KeypointAblation

Randomly ablates (keypoint, individual) pairs with independent Bernoulli sampling, setting all of their spatial coordinates to NaN across the entire time window. Simulates missing or occluded keypoints to make models robust to missing data.

PoseToTensor

Converts pose tracking data from xarray.Dataset format to PyTorch tensors by stacking spatial dimensions into a single feature dimension.

PoseToVideo

Renders pose tracking data as video frames (RGB images) using OpenCV, with customizable body specifications for visualization.

VideoToTensor

Converts video frames from NumPy arrays to PyTorch tensors with optional normalization for video model inputs.

Usage Examples

>>> from lisbet.transforms_extra import RandomPermutation, PoseToTensor
>>> from torchvision import transforms
>>>
>>> # Simple augmentation pipeline
>>> transform = transforms.Compose([
...     RandomPermutation(seed=42, coordinate='individuals'),
...     PoseToTensor(),
... ])
>>>
>>> # Apply with probability using torchvision.transforms.RandomApply
>>> transform = transforms.Compose([
...     transforms.RandomApply([
...         RandomPermutation(seed=42, coordinate='individuals')
...     ], p=0.5),
...     PoseToTensor(),
... ])
>>>
>>> # Block permutation for temporal identity confusion
>>> from lisbet.transforms_extra import RandomBlockPermutation
>>> transform = transforms.Compose([
...     RandomBlockPermutation(
...         seed=42, coordinate='individuals', permute_fraction=0.3
...     ),
...     PoseToTensor(),
... ])
>>>
>>> # Keypoint ablation for robustness to missing data
>>> from lisbet.transforms_extra import KeypointAblation
>>> transform = transforms.Compose([
...     KeypointAblation(seed=42, pB=0.05),
...     PoseToTensor(),
... ])
>>>
>>> # Random rotation augmentation for spatial invariance
>>> from lisbet.transforms_extra import RandomRotation
>>> transform = transforms.Compose([
...     RandomRotation(seed=42, max_angle=30.0, mode='truncate'),
...     PoseToTensor(),
... ])

Notes

  • Choose augmentations based on dataset characteristics: an augmentation is only appropriate if the invariance it teaches actually holds for the data

  • Spatial axis permutation (coordinate='space') is only suitable for top-down view datasets where axes are symmetric

  • Identity permutations work best for datasets where individual labels are interchangeable

Classes

GaussianJitter(seed, sigma)

Apply Gaussian jitter across the full window.

KeypointAblation(seed, pB)

Apply keypoint ablation with per-(keypoint, individual) Bernoulli sampling.

PoseToTensor()

Convert the 'position' variable from a posetracks xarray.Dataset into a PyTorch tensor.

PoseToVideo(body_specs[, image_size, bg_color])

Fast OpenCV-based transformation: posetracks (xarray.Dataset) to a sequence of BGR images.

RandomBlockPermutation(seed[, coordinate, ...])

Randomly permutes the data (but not coordinate labels) of a specified coordinate within a random contiguous block of frames in an xarray.Dataset.

RandomPermutation(seed[, coordinate, ...])

Randomly permutes the order of a specified coordinate (e.g., 'individuals') in an xarray.Dataset, reordering both the coordinate labels and their associated data together.

RandomRotation(seed[, max_angle, mode])

Apply a random rotation to keypoint coordinates in normalized [0, 1] space.

VideoToTensor([normalize, mean, std, dtype])

Transform a video (NumPy RGB array) into a PyTorch tensor suitable for video models.

class lisbet.transforms_extra.GaussianJitter(seed, sigma)

Apply Gaussian jitter across the full window.

Gaussian noise drawn from N(0, sigma^2) is added to every coordinate, across all dimensions. Coordinates are assumed to be normalized to [0, 1] and are clamped to that range after perturbation.

Parameters:
  • seed (int) – RNG seed for reproducibility.

  • sigma (float) – Standard deviation of the Gaussian noise.
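The operation described above can be sketched in a few lines of NumPy. This is a minimal illustration, not the library's implementation: the function name gaussian_jitter is hypothetical, and it operates on a plain array of normalized coordinates rather than on an xarray.Dataset.

```python
import numpy as np


def gaussian_jitter(position: np.ndarray, seed: int, sigma: float) -> np.ndarray:
    """Hypothetical stand-in for GaussianJitter on a plain array.

    `position` holds coordinates normalized to [0, 1], e.g. shaped
    (time, individuals, keypoints, space).
    """
    rng = np.random.default_rng(seed)
    # Add N(0, sigma^2) noise to every coordinate
    noisy = position + rng.normal(loc=0.0, scale=sigma, size=position.shape)
    # Clamp back into the normalized coordinate range
    return np.clip(noisy, 0.0, 1.0)


pos = np.full((10, 2, 5, 2), 0.5)
jittered = gaussian_jitter(pos, seed=42, sigma=0.05)
print(jittered.shape, jittered.min() >= 0.0, jittered.max() <= 1.0)
```

Seeding the generator once per call makes the result reproducible for a given seed, mirroring the role of the seed parameter.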

__init__(seed, sigma)
class lisbet.transforms_extra.KeypointAblation(seed, pB)

Apply keypoint ablation with per-(keypoint, individual) Bernoulli sampling.

Probability pB is applied independently to each (keypoint, individual) pair. For every selected pair, all spatial coordinates (x, y, z, etc.) are set to NaN across the entire time window, simulating sustained missing or occluded keypoints.

This augmentation helps models become robust to missing data, which commonly occurs due to occlusions, tracking failures, or low-confidence detections.

Parameters:
  • seed (int) – RNG seed for reproducibility.

  • pB (float) – Bernoulli probability for each (keypoint, individual) pair across the full window.
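A minimal NumPy sketch of the per-(keypoint, individual) masking described above follows. The function keypoint_ablation is hypothetical, and it assumes the position array is ordered (time, individuals, keypoints, space); the actual class operates on an xarray.Dataset and may differ in detail.

```python
import numpy as np


def keypoint_ablation(position: np.ndarray, seed: int, pB: float) -> np.ndarray:
    """Hypothetical stand-in for KeypointAblation on a plain array.

    Assumes `position` has shape (time, individuals, keypoints, space).
    """
    rng = np.random.default_rng(seed)
    _, n_ind, n_kpt, _ = position.shape
    # One Bernoulli(pB) draw per (individual, keypoint) pair, shared by all frames
    drop = rng.random((n_ind, n_kpt)) < pB
    out = position.astype(float)  # float copy, so NaN is representable
    # The 2D mask indexes axes 1 and 2; time and space are covered entirely
    out[:, drop, :] = np.nan
    return out


pos = np.random.default_rng(1).random((50, 2, 7, 2))
ablated = keypoint_ablation(pos, seed=42, pB=0.3)
print(np.isnan(ablated).any(axis=(0, 3)))  # True where a pair was ablated
```

Because the mask is drawn once per pair rather than per frame, an ablated keypoint stays missing for the whole window, which is what distinguishes this from frame-wise dropout.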

Examples

>>> from lisbet.transforms_extra import KeypointAblation
>>> ablation = KeypointAblation(seed=42, pB=0.05)
>>> ablated_ds = ablation(posetracks)
__init__(seed, pB)
class lisbet.transforms_extra.RandomPermutation(seed, coordinate='individuals', exclude_identity=False)

Randomly permutes the order of a specified coordinate (e.g., 'individuals') in an xarray.Dataset, reordering both the coordinate labels and their associated data together.

This augmentation can be used to increase invariance to coordinate order (e.g., identity assignment, axis orientation). The same permutation is applied to every frame in the dataset.

Parameters:
  • seed (int) – Random seed for reproducibility.

  • coordinate (str) – Name of the coordinate to permute (e.g., 'individuals', 'keypoints', 'space').

  • exclude_identity (bool) – If True, the identity permutation (no change) is excluded. This guarantees that at least one element will be moved. Default is False.
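The key property of RandomPermutation is that labels and data move together. The sketch below illustrates this on a plain NumPy array with a separate label list; in xarray terms it corresponds roughly to ds.isel({coordinate: perm}). The helper random_permutation is hypothetical, and handling exclude_identity by rejection sampling (redrawing until the permutation is not the identity) is one possible approach, not necessarily the one LISBET uses.

```python
import numpy as np


def random_permutation(data, labels, axis, rng, exclude_identity=False):
    """Hypothetical sketch: permute `labels` and `data` along `axis` together."""
    n = data.shape[axis]
    if exclude_identity and n < 2:
        raise ValueError("exclude_identity needs at least two elements to move")
    perm = rng.permutation(n)
    # Rejection sampling: redraw until at least one element changes position
    while exclude_identity and np.array_equal(perm, np.arange(n)):
        perm = rng.permutation(n)
    # Reorder data and labels with the same index array, so each label keeps its data
    return np.take(data, perm, axis=axis), [labels[i] for i in perm]


rng = np.random.default_rng(42)
labels = ["mouse0", "mouse1", "mouse2"]
data = np.tile(np.arange(3.0), (5, 1))  # (time, individuals); column j holds the value j
new_data, new_labels = random_permutation(data, labels, axis=1, rng=rng)
print(new_labels, new_data[0])
```

Since the label of each column still matches its data after the call, the permutation only changes the order in which individuals appear, not which trajectory belongs to which label.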

__call__(posetracks)

Applies the random permutation to the specified coordinate of the input xarray.Dataset.

Examples

>>> permute = RandomPermutation(seed=42, coordinate='individuals')
>>> permuted_ds = permute(posetracks)
>>> # Guarantee a permutation occurs
>>> permute = RandomPermutation(seed=42, coordinate='space', exclude_identity=True)
>>> permuted_ds = permute(posetracks)
__init__(seed, coordinate='individuals', exclude_identity=False)
class lisbet.transforms_extra.RandomBlockPermutation(seed, coordinate='individuals', permute_fraction=0.5, exclude_identity=False)

Randomly permutes the data (but not coordinate labels) of a specified coordinate within a random contiguous block of frames in an xarray.Dataset.

This augmentation is useful to create identity swaps within a portion of the time series, mimicking the effects of a tracking error, while maintaining consistent coordinate labels throughout.

Parameters:
  • seed (int) – Random seed for reproducibility.

  • coordinate (str) – Name of the coordinate to permute (e.g., 'individuals', 'keypoints').

  • permute_fraction (float) – Fraction of the time window to which the permutation is applied. Must be in (0, 1]. A continuous block of frames of this relative size will be selected at random, and the permutation will be applied only to the data within this block, keeping coordinate labels unchanged.

  • exclude_identity (bool) – If True, the identity permutation (no change) is excluded. This guarantees that at least one element will be moved. Default is False.

__call__(posetracks)

Applies the random block permutation to the specified coordinate of the input xarray.Dataset.

Notes

This implementation samples the block so that every frame in the window has the same probability of being affected by the permutation, regardless of its position. To achieve this, the block's starting position is allowed to extend beyond the window boundaries, and the block is then clipped to the valid range.

As a consequence, the actual number of affected frames may be smaller than permute_fraction * window_size when the block overlaps a window boundary. The block start is drawn uniformly from the window_size + block_size - 1 positions at which the block overlaps the window, and block_size of those positions cover any given frame, so the probability that a frame is affected is:

block_size / (window_size + block_size - 1)

With block_size ≈ permute_fraction * window_size, this is approximately permute_fraction / (1 + permute_fraction) for large windows. For example, with permute_fraction=0.3, the expected probability per frame is approximately 0.23 (about 77% of the nominal fraction).

Note that permute_fraction specifies the nominal block size, not the expected fraction of affected frames. Even with permute_fraction=1.0, the expected probability per frame would be ~0.5, not 1.0, because the block can “hang off” either edge of the window. This is the expected tradeoff for achieving uniform frame probability.
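The sampling scheme and the probability formula above can be checked with a short simulation. In this sketch, the helpers sample_block and block_permute are hypothetical, the block size is assumed to be round(permute_fraction * window_size) (the library may round differently), and time is assumed to be axis 0. Only the data inside the clipped block is permuted; labels are left untouched.

```python
import numpy as np


def sample_block(window_size, permute_fraction, rng):
    """Hypothetical helper: return the clipped (start, stop) frame range of the block."""
    block_size = max(1, int(round(permute_fraction * window_size)))
    # The start may lie up to block_size - 1 frames before the window, giving
    # window_size + block_size - 1 equally likely positions that overlap the window
    start = int(rng.integers(-(block_size - 1), window_size))
    # Clip the block to the valid range of frames
    return max(start, 0), min(start + block_size, window_size)


def block_permute(data, rng, permute_fraction, axis=1):
    """Permute `data` along `axis` only inside the sampled block (time is axis 0)."""
    start, stop = sample_block(data.shape[0], permute_fraction, rng)
    perm = rng.permutation(data.shape[axis])
    out = data.copy()
    out[start:stop] = np.take(data[start:stop], perm, axis=axis)
    return out, (start, stop)


# Empirical per-frame probability of being inside the block
rng = np.random.default_rng(0)
window_size, permute_fraction, n_trials = 100, 0.3, 20_000
hits = np.zeros(window_size)
for _ in range(n_trials):
    start, stop = sample_block(window_size, permute_fraction, rng)
    hits[start:stop] += 1
empirical = hits / n_trials

block_size = max(1, int(round(permute_fraction * window_size)))  # 30
expected = block_size / (window_size + block_size - 1)  # 30 / 129
print(f"empirical mean {empirical.mean():.3f}, expected {expected:.3f}")
print(f"first frame {empirical[0]:.3f}, middle frame {empirical[50]:.3f}")
```

Frames at the edges of the window are hit about as often as frames in the middle, which is the property that letting the block extend past either edge is meant to provide.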

Examples

>>> permute = RandomBlockPermutation(seed=42, coordinate='individuals',
...                                   permute_fraction=0.3)
>>> permuted_ds = permute(posetracks)
>>> # Guarantee a permutation occurs within the block
>>> permute = RandomBlockPermutation(seed=42, coordinate='individuals',
...                                   permute_fraction=0.3, exclude_identity=True)
>>> permuted_ds = permute(posetracks)
__init__(seed, coordinate='individuals', permute_fraction=0.5, exclude_identity=False)
class lisbet.transforms_extra.RandomRotation(seed, max_angle=180.0, mode='truncate')

Apply a random rotation to keypoint coordinates in normalized [0, 1] space.

Samples a rotation angle uniformly from [-max_angle, +max_angle] and applies it consistently across all frames in the window. For 2D data, rotates around the center (0.5, 0.5). For 3D data, rotates around (0.5, 0.5, 0.5) about a randomly sampled unit axis using Rodrigues’ formula.

After rotation, coordinates can be normalized back to [0, 1] using one of three modes: "truncate" (clamp), "rescale" (min-max rescaling per spatial dimension), or "none" (no normalization).

Note: input data is assumed to be free of NaN values. NaN values are replaced with 0.0 at load time (see lisbet.io.core._load_posetracks).

Parameters:
  • seed (int) – RNG seed for reproducibility.

  • max_angle (float) – Maximum rotation angle in degrees. The angle is sampled uniformly from [-max_angle, +max_angle]. Default is 180.0.

  • mode (str) –

    Normalization mode after rotation. One of:

    • "truncate": Clamp coordinates to [0, 1].

    • "rescale": If any coordinate falls outside [0, 1] after rotation, rescale each spatial dimension independently so that the min maps to 0 and the max maps to 1 (across all keypoints, individuals, and time). If all coordinates are already within [0, 1], no rescaling is applied.

    • "none": No normalization is applied.

    Default is "truncate".
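The rotation and the three normalization modes can be sketched in NumPy as follows. The function names are hypothetical, space is assumed to be the last axis of the array, and the 3D axis is drawn by normalizing a standard normal vector (which gives a uniformly distributed unit axis); the library may sample it differently.

```python
import numpy as np


def rotation_matrix(angle_deg, n_space, axis=None):
    theta = np.deg2rad(angle_deg)
    c, s = np.cos(theta), np.sin(theta)
    if n_space == 2:
        return np.array([[c, -s], [s, c]])
    if n_space == 3:
        k = np.asarray(axis, dtype=float)
        k = k / np.linalg.norm(k)  # unit rotation axis
        # Cross-product matrix K, so that K @ v == np.cross(k, v)
        K = np.array([[0.0, -k[2], k[1]],
                      [k[2], 0.0, -k[0]],
                      [-k[1], k[0], 0.0]])
        # Rodrigues' formula: R = I + sin(theta) K + (1 - cos(theta)) K^2
        return np.eye(3) + s * K + (1.0 - c) * (K @ K)
    raise ValueError("only 2D or 3D keypoints are supported")


def normalize(points, mode):
    if mode == "truncate":
        return np.clip(points, 0.0, 1.0)
    if mode == "rescale":
        if points.min() >= 0.0 and points.max() <= 1.0:
            return points  # already in range: no rescaling
        # Min-max per spatial dimension, over all other axes
        other_axes = tuple(range(points.ndim - 1))
        lo = points.min(axis=other_axes)
        hi = points.max(axis=other_axes)
        span = np.where(hi > lo, hi - lo, 1.0)  # avoid division by zero
        return (points - lo) / span
    if mode == "none":
        return points
    raise ValueError(f"unknown mode: {mode!r}")


def random_rotation(points, seed, max_angle=180.0, mode="truncate"):
    """Hypothetical sketch; `points` is (..., space) with space as the last axis."""
    rng = np.random.default_rng(seed)
    n_space = points.shape[-1]
    angle = rng.uniform(-max_angle, max_angle)
    axis = rng.normal(size=3) if n_space == 3 else None
    R = rotation_matrix(angle, n_space, axis)
    center = np.full(n_space, 0.5)
    # One rotation about the center, shared by every frame, individual and keypoint
    rotated = (points - center) @ R.T + center
    return normalize(rotated, mode)


pts3 = np.random.default_rng(0).random((20, 2, 5, 3))
rotated = random_rotation(pts3, seed=42, max_angle=30.0)
print(rotated.shape)  # (20, 2, 5, 3)
```

Note that "truncate" can move keypoints (and so distort pose geometry) near the borders, while "rescale" preserves relative positions within each axis but may change the aspect ratio, since each spatial dimension is rescaled independently.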

Examples

>>> from lisbet.transforms_extra import RandomRotation
>>> rotation = RandomRotation(seed=42, max_angle=30.0)
>>> rotated_ds = rotation(posetracks)
>>> # Rescale mode for 3D data
>>> rotation = RandomRotation(seed=42, max_angle=45.0, mode='rescale')
>>> rotated_ds = rotation(posetracks)
__init__(seed, max_angle=180.0, mode='truncate')
class lisbet.transforms_extra.PoseToTensor

Convert the 'position' variable from a posetracks xarray.Dataset into a PyTorch tensor.

This transformation stacks the 'individuals', 'keypoints', and 'space' dimensions into a single 'features' dimension, resulting in a tensor of shape (time, features), where features = individuals * keypoints * space.
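The stacking amounts to flattening every non-time dimension into a single feature axis. The NumPy sketch below assumes the position array is ordered (time, individuals, keypoints, space), so that 'space' varies fastest within the feature axis; with xarray, a comparable result could be obtained with ds['position'].stack(features=('individuals', 'keypoints', 'space')) followed by torch.from_numpy on the underlying values. The feature ordering produced by PoseToTensor itself is not documented here and may differ.

```python
import numpy as np

# Hypothetical position array with dims (time, individuals, keypoints, space)
n_time, n_ind, n_kpt, n_space = 100, 2, 7, 2
position = np.random.default_rng(0).random((n_time, n_ind, n_kpt, n_space))

# Flatten all non-time dims into one feature axis (C order: 'space' varies fastest)
features = position.reshape(n_time, -1)
print(features.shape)  # (100, 28)


def feature_index(i, k, s):
    """Column of (individual i, keypoint k, spatial axis s) under this ordering."""
    return (i * n_kpt + k) * n_space + s
```

Knowing the column layout matters when inspecting model inputs or attributions, since each feature column corresponds to one (individual, keypoint, axis) triple.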

Parameters:

None

__call__(posetracks)

Stack the 'individuals', 'keypoints', and 'space' dimensions of the 'position' variable and return as a PyTorch tensor.

Examples

>>> tensor = PoseToTensor()(posetracks)
>>> tensor.shape
torch.Size([time, features])
class lisbet.transforms_extra.PoseToVideo(body_specs, image_size=(256, 256), bg_color='black')

Fast OpenCV-based transformation: posetracks (xarray.Dataset) to a sequence of BGR images.

__init__(body_specs, image_size=(256, 256), bg_color='black')

Fast OpenCV-based transformation using BodySpecs for each individual.

Parameters:
  • body_specs (dict[str, BodySpecs]) – Dictionary mapping individual_name (or species) to BodySpecs.

  • image_size (tuple of int, optional) – (width, height) of output frames. Default is (256, 256).

  • bg_color (tuple or str, optional) – BGR tuple or color name/hex for background color (default is black).

render_frame(posetracks, t_idx)

Render a single frame of pose tracks as a BGR image.

Parameters:
  • posetracks (xarray.Dataset) – The pose tracks dataset containing keypoints and individuals. Must have a "position" variable with dimensions ("time", "individuals", "keypoints", "space").

  • t_idx (int) – The time index of the frame to render.

Returns:

frame – The rendered frame as a (height, width, 3) uint8 RGB image.

Return type:

numpy.ndarray

class lisbet.transforms_extra.VideoToTensor(normalize=True, mean=None, std=None, dtype=torch.float32)

Transform a video (NumPy RGB array) into a PyTorch tensor suitable for video models.

Converts (frames, H, W, 3) RGB uint8/float arrays to (frames, 3, H, W) float tensors, with optional scaling to [0, 1] and optional per-channel mean/std normalization.

Parameters:
  • normalize (bool, optional) – If True, scale pixel values to [0, 1] (default: True).

  • mean (tuple or list or np.ndarray or torch.Tensor, optional) – Per-channel mean for normalization (applied after scaling to [0, 1]). If None, no mean subtraction is performed.

  • std (tuple or list or np.ndarray or torch.Tensor, optional) – Per-channel std for normalization (applied after mean subtraction). If None, no std division is performed.

  • dtype (torch.dtype, optional) – Output tensor dtype (default: torch.float32).
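The array manipulation behind VideoToTensor can be sketched with NumPy alone. This hypothetical video_to_array returns a NumPy array rather than a torch tensor (torch.from_numpy could convert it), and it assumes that integer inputs are in [0, 255] while float inputs are already in [0, 1]; the real class may treat inputs differently.

```python
import numpy as np


def video_to_array(video, normalize=True, mean=None, std=None):
    """Hypothetical NumPy stand-in for VideoToTensor (returns an array, not a tensor)."""
    arr = np.asarray(video)
    x = arr.astype(np.float32)
    # Assumption: integer frames are in [0, 255]; float frames are already in [0, 1]
    if normalize and np.issubdtype(arr.dtype, np.integer):
        x = x / 255.0
    # (frames, H, W, 3) -> (frames, 3, H, W)
    x = np.transpose(x, (0, 3, 1, 2))
    # Per-channel statistics broadcast over frames, height and width
    if mean is not None:
        x = x - np.asarray(mean, dtype=np.float32).reshape(1, 3, 1, 1)
    if std is not None:
        x = x / np.asarray(std, dtype=np.float32).reshape(1, 3, 1, 1)
    return x


video = np.zeros((4, 8, 6, 3), dtype=np.uint8)
video[..., 1] = 255  # channel 1 (G) saturated
video[..., 2] = 51   # channel 2 (B) at 51, i.e. 0.2 after scaling
out = video_to_array(video)
print(out.shape)  # (4, 3, 8, 6)
```

Mean subtraction happens after scaling, so mean and std must be expressed in [0, 1] units, consistent with the parameter descriptions above.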

__init__(normalize=True, mean=None, std=None, dtype=torch.float32)