"""Augmentation module for transforming samples in a dataset.
This module provides data augmentation and preprocessing transforms for pose tracking
datasets stored as xarray.Dataset objects. The transforms can be used in training
pipelines to improve model robustness and generalization.
Available Transforms
--------------------
RandomPermutation
Randomly permutes both coordinate labels and their associated data together across
the entire time window. Useful for making models invariant to coordinate ordering
(e.g., individual identities, spatial axes).
RandomBlockPermutation
Randomly permutes data within a contiguous block of frames while keeping coordinate
labels unchanged. Creates temporal identity confusion within part of the window.
Useful for more challenging augmentation scenarios.
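GaussianJitter
    Adds independent Gaussian noise N(0, sigma^2) to every coordinate across the
    entire time window and clamps the result to the normalized [0, 1] range.
RandomRotation
    Applies a random rotation to keypoint coordinates in normalized [0, 1] space.
    Supports 2D and 3D keypoints with configurable maximum angle and post-rotation
    normalization modes (truncate, rescale, or none).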
KeypointAblation
    Randomly selects (keypoint, individual) pairs with independent Bernoulli
    sampling and sets all their spatial coordinates to 0.0 for the entire time
    window. Simulates sustained missing or occluded keypoints for robustness testing.
PoseToTensor
Converts pose tracking data from xarray.Dataset format to PyTorch tensors by
stacking spatial dimensions into a single feature dimension.
PoseToVideo
Renders pose tracking data as video frames (RGB images) using OpenCV, with
customizable body specifications for visualization.
VideoToTensor
Converts video frames from NumPy arrays to PyTorch tensors with optional
normalization for video model inputs.
Usage Examples
--------------
>>> from lisbet.transforms_extra import RandomPermutation, PoseToTensor
>>> from torchvision import transforms
>>>
>>> # Simple augmentation pipeline
>>> transform = transforms.Compose([
... RandomPermutation(seed=42, coordinate='individuals'),
... PoseToTensor(),
... ])
>>>
>>> # Apply with probability using torchvision.transforms.RandomApply
>>> transform = transforms.Compose([
... transforms.RandomApply([
... RandomPermutation(seed=42, coordinate='individuals')
... ], p=0.5),
... PoseToTensor(),
... ])
>>>
>>> # Block permutation for temporal identity confusion
>>> from lisbet.transforms_extra import RandomBlockPermutation
>>> transform = transforms.Compose([
... RandomBlockPermutation(
... seed=42, coordinate='individuals', permute_fraction=0.3
... ),
... PoseToTensor(),
... ])
>>>
>>> # Keypoint ablation for robustness to missing data
>>> from lisbet.transforms_extra import KeypointAblation
>>> transform = transforms.Compose([
... transforms.RandomApply([
...         KeypointAblation(seed=42, pB=0.05)
... ], p=1.0),
... PoseToTensor(),
... ])
>>>
>>> # Random rotation augmentation for spatial invariance
>>> from lisbet.transforms_extra import RandomRotation
>>> transform = transforms.Compose([
... RandomRotation(seed=42, max_angle=30.0, mode='truncate'),
... PoseToTensor(),
... ])
Notes
-----
- Augmentations should be applied thoughtfully based on dataset characteristics
- Spatial axis permutation (coordinate='space') is only suitable for top-down view
datasets where axes are symmetric
- Identity permutations work best for datasets where individual labels are
interchangeable
"""
import cv2
import numpy as np
import torch
import xarray as xr
from lisbet.drawing import BodySpecs, body_specs_registry, color_to_bgr
def _random_permutation(n, generator, exclude_identity=False):
    """Draw a uniformly random ordering of the indices ``0 .. n - 1``.

    Parameters
    ----------
    n : int
        Number of elements to permute.
    generator : torch.Generator
        Source of randomness; its state advances with every draw.
    exclude_identity : bool
        When True, redraw (rejection sampling) until the result differs from
        ``[0, 1, ..., n - 1]``.

    Returns
    -------
    list of int
        The permuted indices.

    Raises
    ------
    ValueError
        If ``exclude_identity`` is True and ``n < 2`` (only the identity exists).
    """
    if exclude_identity and n < 2:
        raise ValueError("Cannot exclude identity permutation for n < 2")

    # Only needed for the rejection test.
    identity = list(range(n)) if exclude_identity else None

    while True:
        candidate = torch.randperm(n, generator=generator).tolist()
        if not exclude_identity or candidate != identity:
            return candidate
[docs]
class GaussianJitter:
    """Apply Gaussian jitter across the full window.

    Independent Gaussian noise drawn from N(0, sigma^2) is added to every element
    of the ``position`` variable (all time steps, keypoints, individuals and
    spatial axes). Coordinates are assumed to be normalized in [0, 1] and are
    clamped to that range after the perturbation.

    The input dataset is not modified; a new dataset is returned.

    Parameters
    ----------
    seed : int
        RNG seed for reproducibility.
    sigma : float
        Standard deviation of the Gaussian noise.
    """

    def __init__(self, seed: int, sigma: float):
        self.seed = seed
        self.sigma = float(sigma)
        self.g = torch.Generator().manual_seed(seed)

    def __call__(self, posetracks: xr.Dataset) -> xr.Dataset:
        """Add clamped Gaussian noise to the 'position' variable.

        Parameters
        ----------
        posetracks : xarray.Dataset
            Pose tracks dataset with a 'position' variable containing at least the
            'time', 'keypoints' and 'individuals' dimensions.

        Returns
        -------
        xarray.Dataset
            A new dataset whose 'position' holds the jittered coordinates, cast
            back to the original dtype. Other variables are shared with the input.

        Raises
        ------
        ValueError
            If 'position' lacks any of the required dimensions.
        """
        pos_var = posetracks["position"]

        # Validate dataset dimensions
        required_dims = {"time", "keypoints", "individuals"}
        missing_dims = required_dims - set(pos_var.dims)
        if missing_dims:
            raise ValueError(
                f"Position variable must contain {required_dims} dimensions. "
                f"Missing: {missing_dims}"
            )

        # One noise sample per element of the full array (space included).
        noise = torch.randn(pos_var.shape, generator=self.g) * self.sigma
        pos = torch.from_numpy(pos_var.values) + noise
        pos.clamp_(0.0, 1.0)

        # Build a new variable instead of writing into ``pos_var.values``: the
        # input may be a view into data the caller still uses (e.g. an ``isel``
        # slice of a longer recording).
        new_values = pos.numpy().astype(pos_var.dtype, copy=False)
        return posetracks.assign(position=pos_var.copy(deep=False, data=new_values))
[docs]
class KeypointAblation:
    """Apply keypoint ablation with per-(keypoint, individual) Bernoulli sampling.

    Probability ``pB`` is applied independently to each (keypoint, individual) pair.
    For every selected pair, all spatial coordinates (x, y, z, etc.) are set to 0.0
    across the entire time window, simulating sustained missing or occluded
    keypoints. 0.0 is the value missing coordinates take after loading (NaN values
    are replaced with 0.0 by ``lisbet.io.core._load_posetracks``).

    This augmentation helps models become robust to missing data, which commonly
    occurs due to occlusions, tracking failures, or low-confidence detections.

    The input dataset is not modified; a new dataset is returned.

    Parameters
    ----------
    seed : int
        RNG seed for reproducibility.
    pB : float
        Bernoulli probability, in [0, 1], for each (keypoint, individual) pair
        across the full window.

    Raises
    ------
    ValueError
        If ``pB`` is not in [0, 1].

    Examples
    --------
    >>> from lisbet.transforms_extra import KeypointAblation
    >>> ablation = KeypointAblation(seed=42, pB=0.05)
    >>> ablated_ds = ablation(posetracks)
    """

    def __init__(self, seed: int, pB: float):
        pB = float(pB)
        if not 0.0 <= pB <= 1.0:
            raise ValueError(f"pB must be in [0, 1], got {pB}")
        self.seed = seed
        self.pB = pB
        self.g = torch.Generator().manual_seed(seed)

    def __call__(self, posetracks: xr.Dataset) -> xr.Dataset:
        """Zero out randomly selected (keypoint, individual) pairs.

        Parameters
        ----------
        posetracks : xarray.Dataset
            Pose tracks dataset with a 'position' variable containing at least the
            'time', 'keypoints' and 'individuals' dimensions.

        Returns
        -------
        xarray.Dataset
            A new dataset whose 'position' has the selected pairs set to 0.0 at
            every time step, in the original dtype.

        Raises
        ------
        ValueError
            If 'position' lacks any of the required dimensions.
        """
        pos_var = posetracks["position"]
        dims = pos_var.dims

        # Validate dataset dimensions
        required_dims = {"time", "keypoints", "individuals"}
        missing_dims = required_dims - set(dims)
        if missing_dims:
            raise ValueError(
                f"Position variable must contain {required_dims} dimensions. "
                f"Missing: {missing_dims}"
            )

        # One draw per (keypoint, individual); size-1 axes elsewhere make the mask
        # broadcast over time, space and any other dimension.
        mask_shape = [
            size if dim in ("keypoints", "individuals") else 1
            for dim, size in zip(dims, pos_var.shape)
        ]
        ablate = torch.rand(mask_shape, generator=self.g) < self.pB

        pos = torch.where(ablate, torch.tensor(0.0), torch.from_numpy(pos_var.values))

        # New variable rather than an in-place write: the input may be a view into
        # data the caller still uses.
        new_values = pos.numpy().astype(pos_var.dtype, copy=False)
        return posetracks.assign(position=pos_var.copy(deep=False, data=new_values))
[docs]
class RandomPermutation:
    """
    Shuffle one coordinate (e.g. 'individuals') of an xarray.Dataset, moving the
    coordinate labels together with their data.

    Use this to make a model insensitive to the ordering of a coordinate (for
    example identity order or axis orientation). The same permutation applies to
    every variable in the dataset.

    Parameters
    ----------
    seed : int
        Seed of the internal random generator.
    coordinate : str
        Coordinate to shuffle ('individuals', 'keypoints', 'space', ...).
    exclude_identity : bool
        When True, the no-op ordering is never drawn, so at least one element
        always changes position. Default is False.

    Methods
    -------
    __call__(posetracks)
        Return the dataset with ``coordinate`` shuffled.

    Examples
    --------
    >>> permute = RandomPermutation(seed=42, coordinate='individuals')
    >>> permuted_ds = permute(posetracks)
    >>> # Guarantee a permutation occurs
    >>> permute = RandomPermutation(seed=42, coordinate='space', exclude_identity=True)
    >>> permuted_ds = permute(posetracks)
    """

    def __init__(self, seed, coordinate="individuals", exclude_identity=False):
        self.seed = seed
        self.coordinate = coordinate
        self.exclude_identity = exclude_identity
        self.g = torch.Generator().manual_seed(seed)

    def __call__(self, posetracks):
        """
        Shuffle ``self.coordinate`` in ``posetracks``.

        Parameters
        ----------
        posetracks : xarray.Dataset
            Pose tracks dataset with a 'position' variable.

        Returns
        -------
        xarray.Dataset
            Dataset reindexed along the coordinate, labels and data moved together.
        """
        n_labels = len(posetracks.coords[self.coordinate].values)
        order = _random_permutation(n_labels, self.g, self.exclude_identity)
        # Positional selection reorders labels and every variable in one step.
        return posetracks.isel({self.coordinate: order})
[docs]
class RandomBlockPermutation:
    """
    Randomly permutes the data (but not coordinate labels) of a specified coordinate
    within a random contiguous block of frames in an xarray.Dataset.

    This augmentation is useful to create identity swaps within a portion of the time
    series, mimicking the effects of a tracking error, while maintaining consistent
    coordinate labels throughout. Every data variable that has both a 'time'
    dimension and the permuted dimension (e.g. 'position') is permuted with the same
    ordering inside the block, so per-individual series stay consistent with each
    other. The input dataset is not modified.

    Parameters
    ----------
    seed : int
        Random seed for reproducibility.
    coordinate : str
        Name of the coordinate to permute (e.g., 'individuals', 'keypoints').
    permute_fraction : float
        Fraction of the time window to which the permutation is applied.
        Must be in (0, 1]. A continuous block of frames of this relative size will be
        selected at random, and the permutation will be applied only to the data
        within this block, keeping coordinate labels unchanged.
    exclude_identity : bool
        If True, the identity permutation (no change) is excluded. This guarantees
        that at least one element will be moved. Default is False.

    Raises
    ------
    ValueError
        If ``permute_fraction`` is not in (0, 1].

    Methods
    -------
    __call__(posetracks)
        Applies the random block permutation to the specified coordinate of the input
        xarray.Dataset.

    Notes
    -----
    This implementation uses uniform frame probability sampling to ensure that every
    frame in the window has an equal probability of being affected by the permutation,
    regardless of its position. This is achieved by allowing the block's starting
    position to extend beyond window boundaries, then clipping to the valid range.

    As a consequence, the actual number of affected frames may be smaller than
    ``permute_fraction * window_size`` when the block overlaps with window boundaries.
    On average, the expected probability for any given frame to be affected is::

        block_size / (window_size + block_size - 1)

    which simplifies to approximately ``permute_fraction / (1 + permute_fraction)``
    for large windows. For example, with ``permute_fraction=0.3``, the expected
    probability per frame is approximately 0.23 (about 77% of the nominal fraction).

    Note that ``permute_fraction`` specifies the *nominal* block size, not the
    expected fraction of affected frames. Even with ``permute_fraction=1.0``, the
    expected probability per frame would be ~0.5, not 1.0, because the block can
    "hang off" either edge of the window. This is the expected tradeoff for
    achieving uniform frame probability.

    Examples
    --------
    >>> permute = RandomBlockPermutation(seed=42, coordinate='individuals',
    ...                                  permute_fraction=0.3)
    >>> permuted_ds = permute(posetracks)
    >>> # Guarantee a permutation occurs within the block
    >>> permute = RandomBlockPermutation(seed=42, coordinate='individuals',
    ...                                  permute_fraction=0.3, exclude_identity=True)
    >>> permuted_ds = permute(posetracks)
    """

    def __init__(
        self,
        seed,
        coordinate="individuals",
        permute_fraction=0.5,
        exclude_identity=False,
    ):
        self.seed = seed
        self.coordinate = coordinate
        if not 0 < permute_fraction <= 1:
            raise ValueError("permute_fraction must be a float in (0, 1].")
        self.permute_fraction = permute_fraction
        self.exclude_identity = exclude_identity
        self.g = torch.Generator().manual_seed(seed)

    def __call__(self, posetracks):
        """
        Apply random block permutation to the specified coordinate.

        Parameters
        ----------
        posetracks : xarray.Dataset
            Pose tracks dataset with a 'position' variable.

        Returns
        -------
        xarray.Dataset
            Dataset with permuted data in a random block, coordinates unchanged.
            The input object itself is returned when the block would be empty
            (``int(permute_fraction * window_size) == 0``).
        """
        n_labels = len(posetracks.coords[self.coordinate].values)
        # The permutation is drawn before the block start; this order of RNG draws
        # determines the seeded output.
        perm = _random_permutation(n_labels, self.g, self.exclude_identity)

        window_size = posetracks.sizes["time"]
        block_size = int(self.permute_fraction * window_size)
        if block_size == 0:
            # No permutation needed
            return posetracks

        # Sample start_idx from extended range to ensure uniform frame probability.
        # Range: [1 - block_size, window_size - 1] gives each frame exactly
        # block_size chances to be included in the block.
        start_idx = torch.randint(
            1 - block_size, window_size, (1,), generator=self.g
        ).item()
        # Clip to the valid range of frame indices
        block = slice(max(0, start_idx), min(window_size, start_idx + block_size))

        # Rewrite the block directly on a copy of each affected variable instead of
        # splitting and re-concatenating the dataset: xr.concat would also touch
        # (and may broadcast along 'time') variables unrelated to the permutation.
        updates = {}
        for name, var in posetracks.data_vars.items():
            if "time" not in var.dims or self.coordinate not in var.dims:
                continue
            values = np.array(var.values)  # writable copy; the input is untouched
            index = [slice(None)] * values.ndim
            index[var.dims.index("time")] = block
            index = tuple(index)
            values[index] = np.take(
                values[index], perm, axis=var.dims.index(self.coordinate)
            )
            updates[name] = var.copy(deep=False, data=values)

        return posetracks.assign(updates)
[docs]
class RandomRotation:
    """Apply a random rotation to keypoint coordinates in normalized [0, 1] space.

    Samples a rotation angle uniformly from [-max_angle, +max_angle] and applies it
    consistently across all frames in the window. For 2D data, rotates around the
    center (0.5, 0.5). For 3D data, rotates around (0.5, 0.5, 0.5) about a randomly
    sampled unit axis using Rodrigues' formula.

    After rotation, coordinates can be normalized back to [0, 1] using one of three
    modes: ``"truncate"`` (clamp), ``"rescale"`` (min-max rescaling per spatial
    dimension), or ``"none"`` (no normalization).

    Note: input data is assumed to be free of NaN values. NaN values are replaced
    with 0.0 at load time (see ``lisbet.io.core._load_posetracks``).

    The input dataset is not modified; a new dataset is returned.

    Parameters
    ----------
    seed : int
        RNG seed for reproducibility.
    max_angle : float
        Maximum rotation angle in degrees. The angle is sampled uniformly from
        [-max_angle, +max_angle]. Default is 180.0.
    mode : str
        Normalization mode after rotation. One of:

        - ``"truncate"``: Clamp coordinates to [0, 1].
        - ``"rescale"``: If any coordinate falls outside [0, 1] after rotation,
          rescale each spatial dimension independently so that the min maps to 0
          and the max maps to 1 (across all keypoints, individuals, and time).
          If all coordinates are already within [0, 1], no rescaling is applied.
        - ``"none"``: No normalization is applied.

        Default is ``"truncate"``.

    Raises
    ------
    ValueError
        If ``mode`` is not one of the supported modes.

    Examples
    --------
    >>> from lisbet.transforms_extra import RandomRotation
    >>> rotation = RandomRotation(seed=42, max_angle=30.0)
    >>> rotated_ds = rotation(posetracks)
    >>> # Rescale mode for 3D data
    >>> rotation = RandomRotation(seed=42, max_angle=45.0, mode='rescale')
    >>> rotated_ds = rotation(posetracks)
    """

    def __init__(self, seed: int, max_angle: float = 180.0, mode: str = "truncate"):
        valid_modes = ("truncate", "rescale", "none")
        if mode not in valid_modes:
            raise ValueError(f"mode must be one of {valid_modes}, got '{mode}'")
        self.seed = seed
        self.max_angle = float(max_angle)
        self.mode = mode
        self.g = torch.Generator().manual_seed(seed)

    def __call__(self, posetracks: xr.Dataset) -> xr.Dataset:
        """
        Apply random rotation to keypoint coordinates.

        Parameters
        ----------
        posetracks : xarray.Dataset
            Pose tracks dataset with a 'position' variable containing a 'space'
            dimension (typically together with time, keypoints, individuals).

        Returns
        -------
        xarray.Dataset
            A new dataset with rotated 'position' coordinates, stored in the
            original dtype.

        Raises
        ------
        ValueError
            If 'position' has no 'space' dimension, or if its size is not 2 or 3.
        """
        pos_var = posetracks["position"]
        dims = list(pos_var.dims)
        if "space" not in dims:
            raise ValueError("Position variable must contain a 'space' dimension")
        space_idx = dims.index("space")
        n_space = pos_var.shape[space_idx]
        if n_space not in (2, 3):
            raise ValueError(f"'space' dimension must have size 2 or 3, got {n_space}")

        # Sample rotation angle uniformly from [-max_angle, +max_angle]
        angle_deg = (
            torch.rand(1, generator=self.g).item() * 2.0 - 1.0
        ) * self.max_angle
        angle_rad = angle_deg * (np.pi / 180.0)

        # Build rotation matrix
        c, s = np.cos(angle_rad), np.sin(angle_rad)
        if n_space == 2:
            R = np.array([[c, -s], [s, c]])
        else:
            # 3D: a normalized standard-normal vector is uniform on the unit sphere
            axis = torch.randn(3, generator=self.g).numpy()
            axis = axis / np.linalg.norm(axis)
            # Rodrigues' rotation formula: R = I + sin(θ)K + (1 - cos(θ))K²
            kx, ky, kz = axis
            K = np.array([[0.0, -kz, ky], [kz, 0.0, -kx], [-ky, kx, 0.0]])
            R = np.eye(3) + s * K + (1.0 - c) * (K @ K)

        # Rotate around the center of the [0, 1] space. With 'space' moved last each
        # point is a row vector v, and v @ R.T == (R @ v).T.
        # NOTE(review): keypoints stored as 0.0 (missing, per the load-time NaN
        # replacement) are rotated like real points; confirm this is intended.
        pos = np.moveaxis(pos_var.values - 0.5, space_idx, -1) @ R.T
        pos = np.moveaxis(pos, -1, space_idx) + 0.5

        # Apply normalization mode
        if self.mode == "truncate":
            np.clip(pos, 0.0, 1.0, out=pos)
        elif self.mode == "rescale" and (np.any(pos < 0.0) or np.any(pos > 1.0)):
            for s_i in range(n_space):
                index = [slice(None)] * pos.ndim
                index[space_idx] = s_i
                index = tuple(index)
                component = pos[index]
                vmin, vmax = component.min(), component.max()
                if vmin != vmax:  # a constant component cannot be rescaled
                    pos[index] = (component - vmin) / (vmax - vmin)

        # New variable rather than an in-place write: the input may be a view into
        # data the caller still uses.
        new_values = pos.astype(pos_var.dtype, copy=False)
        return posetracks.assign(position=pos_var.copy(deep=False, data=new_values))
[docs]
class PoseToTensor:
    """
    Convert the 'position' variable from a posetracks xarray.Dataset into a PyTorch
    tensor.

    The 'individuals', 'keypoints' and 'space' dimensions are flattened (in that
    order, C-style) into a single features axis, giving a float32 tensor of shape
    (time, features), where features = individuals * keypoints * space. Only the
    'position' variable is read.

    Parameters
    ----------
    None

    Methods
    -------
    __call__(posetracks)
        Flatten the 'individuals', 'keypoints', and 'space' dimensions of the
        'position' variable and return it as a PyTorch tensor.

    Examples
    --------
    >>> tensor = PoseToTensor()(posetracks)
    >>> tensor.shape
    torch.Size([time, features])
    """

    def __call__(self, posetracks):
        """
        Flatten the 'individuals', 'keypoints', and 'space' dimensions of the
        'position' variable in the input xarray.Dataset and return it as a PyTorch
        tensor.

        Parameters
        ----------
        posetracks : xarray.Dataset
            Pose tracks dataset with a 'position' variable having the dimensions
            'time', 'individuals', 'keypoints' and 'space' (in any order).

        Returns
        -------
        torch.Tensor
            float32 tensor of shape (time, features), where features =
            individuals * keypoints * space, containing the flattened position data.
        """
        # Move the feature dims last (other dims keep their relative order), then
        # flatten them. This matches ``.stack(features=(...))`` value-for-value but
        # skips building a pandas MultiIndex and does not touch other variables.
        position = posetracks["position"].transpose(
            ..., "individuals", "keypoints", "space"
        )
        values = position.values
        # Explicit size instead of -1 so empty feature axes reshape correctly.
        n_features = int(np.prod(values.shape[-3:]))
        flat = values.reshape(values.shape[:-3] + (n_features,))
        return torch.from_numpy(flat.astype(np.float32))
[docs]
class PoseToVideo:
    """
    Fast OpenCV-based transformation: posetracks (xarray.Dataset) to a sequence of
    RGB images.

    Each individual is drawn with the ``BodySpecs`` found under its name in
    ``body_specs`` (falling back to ``body_specs_registry``): alpha-blended filled
    polygons, then skeleton edges, then keypoint markers. Drawing happens in BGR
    (OpenCV's native order) and each frame is converted to RGB before it is
    returned. Coordinates are scaled by the image width/height, so they are
    expected in [0, 1]. Only the first two components of the 'space' dimension
    are drawn; 3D data is rendered as its projection onto that plane.
    """

    def __init__(
        self,
        body_specs: dict[str, BodySpecs],
        image_size=(256, 256),
        bg_color="black",
    ):
        """
        Fast OpenCV-based transformation using BodySpecs for each individual.

        Parameters
        ----------
        body_specs : dict of str to BodySpecs
            Dictionary mapping individual_name (or species) to BodySpecs.
        image_size : tuple of int, optional
            (width, height) of output frames. Default is (256, 256).
        bg_color : tuple or str, optional
            BGR tuple or color name/hex for background color (default is black).
        """
        self.body_specs = body_specs
        self.width, self.height = image_size
        self.bg_color = color_to_bgr(bg_color)

    def __call__(self, posetracks, show_progress=False):
        """
        Render every frame of ``posetracks``.

        Parameters
        ----------
        posetracks : xarray.Dataset
            Pose tracks dataset with a 'position' variable (see ``render_frame``).
        show_progress : bool, optional
            Currently unused; accepted for backward compatibility.

        Returns
        -------
        numpy.ndarray
            uint8 array of shape (time, height, width, 3) with RGB frames.
        """
        n_frames = posetracks.sizes["time"]
        if n_frames == 0:
            # np.stack cannot stack an empty list.
            return np.empty((0, self.height, self.width, 3), dtype=np.uint8)
        return np.stack(
            [self.render_frame(posetracks, t) for t in range(n_frames)], axis=0
        )

    def _to_pixel(self, point):
        """Map a normalized (x, y) point to integer pixel coordinates."""
        x, y = point
        return int(x * self.width), int(y * self.height)

    def render_frame(self, posetracks, t_idx):
        """
        Render a single frame of pose tracks as an RGB image.

        Parameters
        ----------
        posetracks : xarray.Dataset
            The pose tracks dataset containing keypoints and individuals.
            Must have a "position" variable with dimensions ("time", "individuals",
            "keypoints", "space"), in any order.
        t_idx : int
            The time index of the frame to render.

        Returns
        -------
        frame : numpy.ndarray
            The rendered frame as a (height, width, 3) uint8 RGB image.
        """
        frame = np.full((self.height, self.width, 3), self.bg_color, dtype=np.uint8)
        pos = (
            posetracks["position"]
            .isel(time=t_idx)
            .transpose("individuals", "keypoints", "space")
            .values
        )
        # (individuals, keypoints, 2): only the first two spatial components are
        # drawn, which also makes 3D data renderable.
        xy = pos[..., :2]

        keypoints = list(posetracks.keypoints.values)
        # Name -> index lookup; setdefault keeps the first occurrence, like
        # list.index would.
        kp_index = {}
        for k, kp in enumerate(keypoints):
            kp_index.setdefault(kp, k)

        for ind_idx, ind_name in enumerate(posetracks.individuals.values):
            spec = self.body_specs.get(ind_name, body_specs_registry.get(ind_name))
            if spec is None:
                continue
            ind_xy = xy[ind_idx]

            # Draw polygons (with alpha blending); NaN (missing) points are skipped
            for poly in spec.polygons:
                pts = []
                for kp in poly:
                    if kp not in kp_index:
                        continue
                    point = ind_xy[kp_index[kp]]
                    if not np.isnan(point).any():
                        pts.append(self._to_pixel(point))
                if len(pts) >= 3:
                    pts_np = np.array([pts], dtype=np.int32)
                    overlay = frame.copy()
                    color = color_to_bgr(spec.polygon_color)
                    cv2.fillPoly(overlay, pts_np, color)
                    frame = cv2.addWeighted(
                        overlay, spec.polygon_alpha, frame, 1 - spec.polygon_alpha, 0
                    )

            # Draw skeleton; edges with a missing endpoint or a NaN endpoint are
            # skipped
            for edge in spec.skeleton_edges:
                kp1, kp2 = edge[0], edge[1]
                if kp1 not in kp_index or kp2 not in kp_index:
                    continue
                p1 = ind_xy[kp_index[kp1]]
                p2 = ind_xy[kp_index[kp2]]
                if np.isnan(p1).any() or np.isnan(p2).any():
                    continue
                color = color_to_bgr(spec.skeleton_color)
                cv2.line(
                    frame,
                    self._to_pixel(p1),
                    self._to_pixel(p2),
                    color=color,
                    thickness=spec.skeleton_thickness,
                    lineType=cv2.LINE_AA,
                )

            # Draw keypoints, skipping NaN (missing) coordinates
            for k, kp in enumerate(keypoints):
                point = ind_xy[k]
                if np.isnan(point).any():
                    continue
                color = color_to_bgr(spec.get_keypoint_color(kp))
                cv2.circle(
                    frame,
                    self._to_pixel(point),
                    spec.keypoint_size,
                    color=color,
                    thickness=-1,
                    lineType=cv2.LINE_AA,
                )

        # Drawing was done in BGR (OpenCV); return RGB
        return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
[docs]
class VideoToTensor:
    """
    Transform a video (NumPy RGB array) into a PyTorch tensor suitable for video models.

    Converts (frames, H, W, 3) RGB uint8/float arrays to contiguous (frames, 3, H, W)
    tensors, with optional scaling to [0, 1] and optional per-channel mean/std
    normalization. The returned tensor never shares memory with a non-degenerate
    input array.

    Parameters
    ----------
    normalize : bool, optional
        If True, scale uint8 pixel values to [0, 1] (divide by 255) or clip float
        inputs to [0, 1] (default: True).
    mean : tuple or list or np.ndarray or torch.Tensor, optional
        Per-channel mean for normalization (applied after scaling to [0, 1]).
        If None, no mean subtraction is performed.
    std : tuple or list or np.ndarray or torch.Tensor, optional
        Per-channel std for normalization (applied after mean subtraction).
        If None, no std division is performed.
    dtype : torch.dtype, optional
        Output tensor dtype (default: torch.float32).
    """

    def __init__(self, normalize=True, mean=None, std=None, dtype=torch.float32):
        self.normalize = normalize
        self.mean = mean
        self.std = std
        self.dtype = dtype

    def __call__(self, video):
        """
        Parameters
        ----------
        video : np.ndarray
            Video as (frames, H, W, 3) RGB, dtype uint8 or float. Arbitrary strides
            (including negative ones, e.g. ``video[..., ::-1]``) are accepted.

        Returns
        -------
        torch.Tensor
            Contiguous video tensor of shape (frames, 3, H, W), dtype as specified.

        Raises
        ------
        TypeError
            If ``video`` is not a numpy ndarray.
        ValueError
            If ``video`` is not 4-D with 3 channels in the last axis.
        """
        if not isinstance(video, np.ndarray):
            raise TypeError("Input video must be a numpy ndarray.")
        if video.ndim != 4 or video.shape[-1] != 3:
            raise ValueError("Input video must have shape (frames, H, W, 3) [RGB].")

        # If uint8, convert to float32 for normalization
        if video.dtype == np.uint8:
            video = video.astype(np.float32)
            if self.normalize:
                video = video / 255.0
        elif self.normalize:
            # Assume already float, but ensure in [0, 1]
            video = np.clip(video, 0.0, 1.0)

        # Rearrange to (frames, 3, H, W). The contiguous copy is required because
        # torch.from_numpy rejects negative strides, and it keeps the output from
        # aliasing the caller's array.
        video = np.ascontiguousarray(np.transpose(video, (0, 3, 1, 2)))
        tensor = torch.from_numpy(video).type(self.dtype)

        # Optional mean/std normalization (per channel)
        if self.mean is not None:
            mean = torch.as_tensor(self.mean, dtype=self.dtype, device=tensor.device)
            if mean.ndim == 1:
                mean = mean.view(1, 3, 1, 1)
            tensor = tensor - mean
        if self.std is not None:
            std = torch.as_tensor(self.std, dtype=self.dtype, device=tensor.device)
            if std.ndim == 1:
                std = std.view(1, 3, 1, 1)
            tensor = tensor / std
        return tensor