"""LISBET"""
import logging
from itertools import groupby
from pathlib import Path
import numpy as np
import pandas as pd
from scipy.cluster import hierarchy
from scipy.signal import savgol_filter
from scipy.spatial.distance import pdist, squareform
from sklearn.metrics import silhouette_samples, silhouette_score
from tqdm.auto import tqdm
def load_annotations(annot_root, hmm_list):
    """
    Load machine annotations from the given root directory.

    A session is any directory below `annot_root` that contains at least one
    file matching ``machineAnnotation*.csv``. For every session, one file named
    ``machineAnnotation_hmm{n}.csv`` is read for each ``n`` in `hmm_list` and
    the resulting tables are concatenated column-wise.

    Parameters
    ----------
    annot_root : str
        The root directory containing the annotation files.
    hmm_list : list of int
        List of numbers of states for Hidden Markov Models (HMMs).

    Returns
    -------
    session_data : dict
        Dictionary where keys are session paths (relative to `annot_root`) and
        values are DataFrames with concatenated HMM annotations. Columns are
        named ``HMM_{n}_{original column}``. Keys are inserted in sorted path
        order.
    """
    # Find all sessions. The set removes duplicates (one entry per annotation
    # file); sorting makes the iteration order, and therefore the key order of
    # the returned dict, reproducible instead of depending on hash order.
    session_paths = sorted(
        {
            fname.parent
            for fname in Path(annot_root).glob("**/machineAnnotation*.csv")
        }
    )
    logging.debug("Available sessions: %d", len(session_paths))

    # Load data
    session_data = {}
    for session_path in tqdm(session_paths, desc="Loading data"):
        key = str(session_path.relative_to(annot_root))

        # Load one annotation file per requested HMM size
        annot_data = [
            pd.read_csv(
                session_path / f"machineAnnotation_hmm{hmm_id}.csv", index_col=0
            )
            for hmm_id in hmm_list
        ]

        # Stack the tables side by side, then flatten the (hmm_id, column)
        # MultiIndex into plain string labels
        annot_data = pd.concat(annot_data, axis=1, keys=hmm_list)
        annot_data.columns = [
            f"HMM_{x}_{y}" for x, y in annot_data.columns.to_flat_index()
        ]

        session_data[key] = annot_data

    return session_data
def _filter_by_frame(concat_data, frame_threshold):
    """
    Drop motifs whose column mean falls below a threshold.

    Parameters
    ----------
    concat_data : DataFrame
        Concatenated DataFrame containing motif data, one motif per column.
    frame_threshold : float
        Minimum column mean (inclusive) for a motif to be kept.

    Returns
    -------
    concat_data : DataFrame
        The input restricted to the columns whose mean is greater than or
        equal to `frame_threshold`, in their original order.
    """
    # Per-column mean over all rows
    column_means = concat_data.mean(axis=0)

    # Labels of the columns that reach the threshold
    passes = (column_means >= frame_threshold).to_numpy()
    kept_columns = column_means.index[passes]
    logging.debug("Filter by frame valid columns: %s", kept_columns.values)

    return concat_data[kept_columns]
def _filter_by_bout(concat_data, bout_threshold, fps):
    """
    Filter motifs based on bout threshold.

    A bout is a maximal run of consecutive rows equal to 1 within a column.
    Motifs are kept when the mean duration of their bouts reaches the
    threshold; motifs without any bout are dropped.

    Parameters
    ----------
    concat_data : DataFrame
        Concatenated DataFrame containing motif data, one motif per column.
    bout_threshold : float
        Minimum mean bout duration (inclusive), in seconds, for motifs to be
        kept.
    fps : int
        Frames per second, used to convert run lengths to seconds.

    Returns
    -------
    concat_data : DataFrame
        Filtered DataFrame with the columns whose mean bout duration is greater
        than or equal to the threshold, in their original column order.
    """
    bouts = []
    for column_name in tqdm(concat_data.columns, desc="Computing bouts duration"):
        # groupby yields one group per run of identical consecutive values;
        # only runs of 1 count as bouts. Run length / fps is the duration.
        bouts.extend(
            (column_name, sum(1 for _ in run) / fps)
            for label, run in groupby(concat_data[column_name])
            if label == 1
        )

    if bouts:
        events = pd.DataFrame(bouts, columns=["motif_id", "bout_duration (s)"])
        # sort=False keeps groups in order of first appearance, which is the
        # column order of concat_data. The default would sort the labels
        # lexicographically (e.g. "HMM_10_0" before "HMM_2_0") and silently
        # reorder the returned columns.
        mean_duration = events.groupby("motif_id", sort=False)[
            "bout_duration (s)"
        ].mean()
        valid_columns = mean_duration.index[
            (mean_duration >= bout_threshold).to_numpy()
        ]
    else:
        # No bout in any column: nothing can reach the threshold
        valid_columns = concat_data.columns[:0]

    logging.debug("Filter by bout valid columns: %s", valid_columns.values)

    return concat_data[valid_columns.tolist()]
def _filter_by_distance(concat_data, distance_threshold):
    """
    Filter motifs based on the distance from their closest neighbour.

    The Jaccard distance is computed between every pair of columns. A motif is
    kept only if at least one *other* motif lies strictly closer than the
    threshold.

    Parameters
    ----------
    concat_data : DataFrame
        Concatenated DataFrame containing motif data, one motif per column.
    distance_threshold : float
        Exclusive upper bound on the Jaccard distance to the closest other
        motif.

    Returns
    -------
    concat_data : DataFrame
        Filtered DataFrame with the columns that have at least one neighbour
        within the threshold, in their original order.
    """
    dist_matrix = squareform(pdist(concat_data.T, metric="jaccard"))

    # Pairs closer than the threshold. The diagonal (distance of a motif to
    # itself, always 0) is masked out explicitly so that the result does not
    # rely on the threshold being positive.
    is_close = dist_matrix < distance_threshold
    np.fill_diagonal(is_close, False)

    valid_columns = np.flatnonzero(is_close.any(axis=0))
    logging.debug(
        "Filter by distance valid columns: %s",
        concat_data.columns[valid_columns].values,
    )

    return concat_data.iloc[:, valid_columns]
def select_prototypes(
    data_path: str,
    min_n_components: int,
    max_n_components: int,
    method: str = "best",
    frame_threshold: float | None = 0.05,
    bout_threshold: float | None = 0.5,
    distance_threshold: float | None = 0.6,
    fps: int = 30,
    output_path: str | None = None,
) -> tuple[dict, list[tuple[str, pd.DataFrame]]]:
    """
    Select motifs from a set of Hidden Markov Models using a posteriori linkage.

    Motifs from all HMMs are pooled, optionally filtered, clustered by
    average-linkage hierarchical clustering on their pairwise Jaccard distance,
    and one prototype motif is picked per cluster. The number of clusters is
    the one maximizing the (smoothed) silhouette score.

    Parameters
    ----------
    data_path : str
        The root directory containing the annotation files.
    min_n_components : int
        Minimum number of states for the HMMs.
    max_n_components : int
        Maximum number of states for the HMMs (inclusive).
    method : str, default='best'
        Method for selecting prototypes. Valid options are 'min' (first motif,
        in column order, of each cluster) and 'best' (motif with the highest
        silhouette value in each cluster).
    frame_threshold : float or None, default=0.05
        Minimum fraction of allocated frames for motifs to be kept. If `None`,
        this filter is skipped.
    bout_threshold : float or None, default=0.5
        Minimum mean bout duration in seconds for motifs to be kept. If `None`,
        this filter is skipped.
    distance_threshold : float or None, default=0.6
        Maximum Jaccard distance from the closest motif (pairs only). If
        `None`, this filter is skipped.
    fps : int, default=30
        Frames per second, used to compute bout duration.
    output_path : str, optional
        Path to store the output predictions. If `None`, results are not saved.

    Returns
    -------
    hmm_info : dict
        Dictionary containing supporting information useful for plotting the results.
    predictions : list of tuples
        List of tuples, where each tuple contains a session key and the corresponding
        motifs DataFrame.

    Raises
    ------
    NotImplementedError
        If `method` is not one of 'min' or 'best'.
    ValueError
        If fewer than three motifs survive the filtering, since the silhouette
        score needs at least two clusters and fewer clusters than motifs.

    Notes
    -----
    [a] This method could be easily generalized to other clustering algorithms.
    """
    # Fail fast on an unsupported method, before loading or clustering anything
    if method not in ("min", "best"):
        raise NotImplementedError(
            f"Unknown method {method}. Valid options are 'min' and 'best'"
        )

    # Load session data
    hmm_list = list(range(min_n_components, max_n_components + 1))
    session_data = load_annotations(data_path, hmm_list)

    # Concatenate all sessions in a single dataset
    concat_data = pd.concat(session_data.values(), ignore_index=True)
    logging.debug("Annotation size: %s", concat_data.shape)

    # Filter motifs, if requested
    if frame_threshold is not None:
        concat_data = _filter_by_frame(concat_data, frame_threshold)
        logging.debug("Annotation size after frame threshold: %s", concat_data.shape)

    if bout_threshold is not None:
        concat_data = _filter_by_bout(concat_data, bout_threshold, fps)
        logging.debug("Annotation size after bout threshold: %s", concat_data.shape)

    if distance_threshold is not None:
        concat_data = _filter_by_distance(concat_data, distance_threshold)
        logging.debug("Annotation size after distance threshold: %s", concat_data.shape)

    # The silhouette score is only defined for 2 <= n_clusters <= n_motifs - 1
    min_clusters = 2
    max_clusters = concat_data.shape[1]
    if max_clusters <= min_clusters:
        raise ValueError(
            f"At least 3 motifs are required to select prototypes, "
            f"but only {max_clusters} left after filtering"
        )

    # Compute distance between motifs (condensed form for the linkage, square
    # form for the silhouette computations)
    cond_dist_matrix = pdist(concat_data.T, metric="jaccard")
    square_dist_matrix = squareform(cond_dist_matrix)

    # Compute linkage
    link_matrix = hierarchy.linkage(
        cond_dist_matrix, method="average", metric=None, optimal_ordering=True
    )

    # Scan candidate clusters
    candidates = []
    for n_clusters in range(min_clusters, max_clusters):
        # NOTE: We convert cluster labels to zero-based indexing
        labels = hierarchy.fcluster(link_matrix, n_clusters, criterion="maxclust") - 1
        score = silhouette_score(square_dist_matrix, labels, metric="precomputed")
        candidates.append((score, n_clusters, labels))

    # Find peak of smoothed silhouette score vector. With too few candidates to
    # fill the smoothing window, fall back to the raw scores.
    y_data = np.array([c[0] for c in candidates])
    window_length = max(3, max_clusters // 10)
    if y_data.size >= window_length:
        y_sg = savgol_filter(y_data, window_length=window_length, polyorder=2)
    else:
        y_sg = y_data

    # Get best candidate
    best_score, best_n_clusters, best_labels = candidates[np.argmax(y_sg)]
    logging.info("Best number of clusters: %d", best_n_clusters)

    # Identify prototypes. Iterate over the labels actually present: with the
    # "maxclust" criterion, fcluster may return fewer clusters than requested.
    cluster_ids = np.unique(best_labels)
    if method == "min":
        # First member of each cluster, in column order
        prototypes = [np.flatnonzero(best_labels == cid)[0] for cid in cluster_ids]
    else:
        # Member with the highest silhouette value in each cluster
        samples = silhouette_samples(
            square_dist_matrix, best_labels, metric="precomputed"
        )
        prototypes = []
        for cid in cluster_ids:
            members = np.flatnonzero(best_labels == cid)
            prototypes.append(members[np.argmax(samples[members])])

    logging.debug("Prototypes %s: %s", method, concat_data.columns[prototypes].values)

    # Assign predictions to match the corresponding sequences
    annot_fname = f"machineAnnotation_hmm{method}_{min(hmm_list)}_{max(hmm_list)}.csv"
    predictions = []
    for key, data in session_data.items():
        filtered_data = data[concat_data.columns]
        motifs = pd.DataFrame(filtered_data.iloc[:, prototypes])
        logging.debug(
            "Session %s prototypes %s: %s", key, method, motifs.columns.values
        )

        predictions.append((key, motifs))

        # Store predictions on file, if requested
        if output_path is not None:
            dst_path = Path(output_path) / "prototypes" / key
            dst_path.mkdir(parents=True, exist_ok=True)
            motifs.to_csv(dst_path / annot_fname)

    # Collect supporting information, useful for plotting the results
    hmm_info = {
        "cond_dist_matrix": cond_dist_matrix,
        "link_matrix": link_matrix,
        "all_score": y_data,
        "all_n_clusters": np.array([c[1] for c in candidates]),
        "all_labels": np.array([c[2] for c in candidates]),
        "best_n_clusters": best_n_clusters,
        "best_score": best_score,
        "best_labels": best_labels,
        "prototypes": prototypes,
    }

    # Store supporting information on file, if requested
    if output_path is not None:
        dst_path = Path(output_path)
        dst_path.mkdir(parents=True, exist_ok=True)
        np.savez(
            dst_path / f"info_hmm{method}_{min(hmm_list)}_{max(hmm_list)}.npz",
            **hmm_info,
        )

    return hmm_info, predictions