from pathlib import Path
import pooch
from huggingface_hub import snapshot_download
from lisbet.io import dump_records
from lisbet.io.ext_sources import calms21, mabe22
[docs]
def fetch_dataset(dataset_id, download_path):
"""
Download and preprocess keypoints datasets from remote repositories.
Downloads the specified dataset, processes raw data (e.g., keypoints, annotations),
and stores them in a standardized format for analysis.
Parameters
----------
dataset_id : str
Identifier for the dataset to fetch. Currently supported datasets:
- "CalMS21_Task1": Mouse behavior classification dataset
- "CalMS21_Unlabeled": Unlabeled mouse behavior videos
- "SampleData": Sample dataset for testing
Additional datasets may be supported in future versions.
download_path : str
Base directory path where the dataset will be stored. The function creates
subdirectories for cache and processed data.
Returns
-------
None
Data is saved to disk in standardized format.
Raises
------
ValueError
If dataset_id is not one of the supported options.
Notes
-----
The function handles downloads with checksums and caching using pooch.
Downloaded data is temporarily stored in a cache directory before being
processed into the final standardized format.
"""
if dataset_id == "CalMS21_Task1":
# Get data from Caltech repo
fnames = pooch.retrieve(
url=(
"https://data.caltech.edu/records/s0vdx-0k302/files/"
"task1_classic_classification.zip?download=1"
),
known_hash="md5:8a02654fddae28614ee24a6a082261b8",
path=Path(download_path) / "datasets" / ".cache" / "lisbet",
processor=pooch.Unzip(
members=[
"task1_classic_classification/calms21_task1_train.json",
"task1_classic_classification/calms21_task1_test.json",
],
),
progressbar=True,
)
# Preprocess keypoints
rawdata_path = Path(fnames[0]).parents[1]
train_records, test_records = calms21.load_taskx(rawdata_path, taskid=1)
# Store data in LISBET-compatible format
data_path = (
Path(download_path)
/ "datasets"
/ "CalMS21"
/ "task1_classic_classification"
)
dump_records(data_path, train_records)
dump_records(data_path, test_records)
elif dataset_id == "CalMS21_Unlabeled":
# Get data from Caltech repo
fnames = pooch.retrieve(
url=(
"https://data.caltech.edu/records/s0vdx-0k302/files/"
"unlabeled_videos.zip?download=1"
),
known_hash="md5:35ab3acdeb231a3fe1536e38ad223b2e",
path=Path(download_path) / "datasets" / ".cache" / "lisbet",
processor=pooch.Unzip(
members=[
"unlabeled_videos/calms21_unlabeled_videos_part1.json",
"unlabeled_videos/calms21_unlabeled_videos_part2.json",
"unlabeled_videos/calms21_unlabeled_videos_part3.json",
"unlabeled_videos/calms21_unlabeled_videos_part4.json",
],
),
progressbar=True,
)
# Preprocess keypoints
rawdata_path = Path(fnames[0]).parents[1]
records = calms21.load_unlabeled(rawdata_path)
# Store data in LISBET-compatible format
data_path = Path(download_path) / "datasets" / "CalMS21" / "unlabeled_videos"
dump_records(data_path, records)
elif dataset_id == "MABe22_MouseTriplets":
# Get data from Caltech repo
train_path = pooch.retrieve(
url=(
"https://data.caltech.edu/records/rdsa8-rde65/files/"
"mouse_triplet_train.npy?download=1"
),
known_hash="md5:76a48f3a1679a219a0e7e8a87871cc74",
path=Path(download_path) / "datasets" / ".cache" / "lisbet",
progressbar=True,
)
test_seq_path = pooch.retrieve(
url=(
# TMP, bug in default Caltech repo
"https://data.caltech.edu/records/8kdn3-95j37/files/"
"mouse_triplet_test.npy?download=1"
),
known_hash="md5:20dc132300118a64aac665dd68153b20",
path=Path(download_path) / "datasets" / ".cache" / "lisbet",
progressbar=True,
)
test_labels_path = pooch.retrieve(
url=(
"https://data.caltech.edu/records/rdsa8-rde65/files/"
"mouse_triplets_test_labels.npy?download=1"
),
known_hash="md5:5a54f2d29a13a256aabbefc61a633176",
path=Path(download_path) / "datasets" / ".cache" / "lisbet",
progressbar=True,
)
# Preprocess keypoints
train_records, test_records = mabe22.load_mouse_triplets(
train_path, test_seq_path, test_labels_path
)
# Store records in LISBET-compatible format
data_path = Path(download_path) / "datasets" / "MABe22" / "mouse_triplets"
dump_records(data_path / "train", train_records)
dump_records(data_path / "test", test_records)
elif dataset_id == "SampleData":
# Fetch data from HuggingFace repo
# NOTE: This is a small sample dataset for testing purposes
data_path = snapshot_download(
repo_id="gchindemi/lisbet-examples",
allow_patterns="sample_keypoints/",
local_dir=Path(download_path) / "datasets",
repo_type="dataset",
)
else:
raise ValueError(f"Unknown dataset {dataset_id}")
[docs]
def fetch_model(model_id, download_path=Path(".")):
"""Fetch a model from the HF Hub."""
valid_model_ids = [
"lisbet32x4-calms21UftT1-classifier",
"lisbet32x4-calms21U-embedder",
]
assert model_id in valid_model_ids, (
f"Model ID '{model_id}' not found in the list of available models."
)
model_path = download_path / model_id
snapshot_download(
repo_id=f"gchindemi/{model_id}", repo_type="model", local_dir=model_path
)