"""Plotting utilities for LISBET.
This module provides a variety of functions for visualizing data and analysis results
related to the LISBET project. The functions cover a range of tasks, including
plotting UMAP embeddings, F1 score matrices, transition graphs, silhouette profiles,
and dendrograms.
"""
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
import networkx as nx
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib.patches import Rectangle
from matplotlib.transforms import blended_transform_factory
from scipy.cluster import hierarchy
from scipy.optimize import direct
from sklearn.metrics import silhouette_samples, silhouette_score
from sklearn.utils.class_weight import compute_class_weight
[docs]
def mm2inch(x):
    """
    Express a length given in millimeters in inches.

    Parameters
    ----------
    x : float or array-like
        Length(s) in millimeters. Anything that supports division by a float
        (e.g. a number or a numpy array) is accepted.

    Returns
    -------
    float or array-like
        ``x`` divided by 25.4, i.e. the length(s) in inches.
    """
    mm_per_inch = 25.4
    return x / mm_per_inch
[docs]
def pval2star(pval):
    """
    Translate a p-value into its significance-star label (APA style).

    Parameters
    ----------
    pval : float
        The p-value to translate.

    Returns
    -------
    str
        ``"***"`` if ``pval <= 0.001``, ``"**"`` if ``pval <= 0.01``, ``"*"`` if
        ``pval <= 0.05``, and ``"ns"`` otherwise.
    """
    # Check the strictest level first so the strongest matching label wins.
    if pval <= 0.001:
        return "***"
    if pval <= 0.01:
        return "**"
    if pval <= 0.05:
        return "*"
    return "ns"
[docs]
def get_custom_cmap(n, palette="Set2", alpha=None, desat=None):
    """
    Generate a custom discrete colormap with n colors.

    Parameters
    ----------
    n : int
        Number of colors in the colormap.
    palette : str
        A valid seaborn/matplotlib palette. Default is "Set2".
    alpha : sequence of float, optional
        Alpha value for each color in the colormap; must contain exactly one
        entry per color. If None (or empty), no alpha channel is appended.
        Default is None.
    desat : float, optional
        Forwarded unchanged to ``seaborn.color_palette`` as its ``desat``
        argument. Default is None.

    Returns
    -------
    cmap : matplotlib.colors.ListedColormap
        A listed colormap with n colors.

    Raises
    ------
    ValueError
        If ``alpha`` is given and its length differs from the number of colors.
    """
    base_colors = sns.color_palette(palette, desat=desat)

    # Use the palette as is when it is large enough, otherwise interpolate
    # between its colors to reach the requested number.
    if n <= len(base_colors):
        colors = base_colors[:n]
    else:
        colors = sns.blend_palette(base_colors, n_colors=n)

    # NOTE: Test against None/length explicitly. A plain truthiness check raises
    #       "truth value of an array is ambiguous" when alpha is a numpy array.
    if alpha is not None and len(alpha) > 0:
        colors = [tuple(c) + (a,) for c, a in zip(colors, alpha, strict=True)]

    # Convert to a *discrete* colormap
    return mpl.colors.ListedColormap(colors)
[docs]
def plot_umap2d(
    data,
    labels,
    targets=None,
    sample_size=None,
    seed=None,
    marker_size="auto",
    cmap=None,
    cbar_loc="top",
    cbar_label="Motif ID",
    cbar_ticklabels=None,
    ax=None,
):
    """
    Plot a 2D UMAP embedding with class labels.

    The silhouette score of the plotted points is also computed and printed.

    Parameters
    ----------
    data : array-like, shape (n_samples, 2)
        The 2D coordinates of the points to plot.
    labels : array-like, shape (n_samples,)
        The class labels for each point.
    targets : array-like, shape (n_samples,), optional
        Reference labels. If given, plotted points whose target differs from
        their label are overdrawn with an "x" marker five times the point size.
        Default is None.
    sample_size : int, optional
        The number of points to randomly sample for plotting. If None, all points are
        plotted. Default is None.
    seed : int, optional
        Random seed for reproducibility when sampling. Default is None.
    marker_size : "auto", float or array-like, optional
        Marker size. With "auto", each point gets 0.5 times the "balanced"
        class weight of its label, so rare classes are drawn larger. Otherwise
        the given scalar (or per-point array) is used. Default is "auto".
    cmap : matplotlib.colors.Colormap, optional
        Colormap for coloring the points. If None, a custom colormap is generated.
        Default is None.
    cbar_loc : str, optional
        Location of the colorbar. Default is "top". Custom tick labels are only
        supported for "top" and "right".
    cbar_label : str, optional
        Label for the colorbar. Default is "Motif ID".
    cbar_ticklabels : list, optional
        Custom tick labels for the colorbar. If None, labels are based on the unique
        class labels. Default is None.
    ax : matplotlib.axes.Axes, optional
        Axes object to draw the plot onto. If None, a new figure and axes are created.
        Default is None.

    Returns
    -------
    None

    Raises
    ------
    NotImplementedError
        If ``cbar_ticklabels`` is given and ``cbar_loc`` is neither "top" nor
        "right".
    """
    data = np.asarray(data)
    labels = np.asarray(labels)

    if ax is None:
        _, ax = plt.subplots()

    # Find the states and, for every point, the position of its label among
    # them (used as the color index).
    unique_labels, colors = np.unique(labels, return_inverse=True)
    num_states = len(unique_labels)

    # Sample points, if requested. An explicit index array (rather than
    # Ellipsis) keeps the set operations below valid when nothing is sampled.
    if sample_size is not None:
        rng = np.random.default_rng(seed)
        sample_idx = rng.choice(data.shape[0], replace=False, size=sample_size)
    else:
        sample_idx = np.arange(data.shape[0])

    # Compute silhouette score
    sample_silhouette = silhouette_score(data[sample_idx], labels[sample_idx])
    print(f"silhouette_score = {sample_silhouette}")

    if isinstance(marker_size, str) and marker_size == "auto":
        # Scale marker size to compensate for unbalanced classes
        class_weight = compute_class_weight(
            "balanced", classes=unique_labels, y=labels
        )
        marker_size = 0.5 * class_weight[colors]
    else:
        # Broadcast to one size per point; independent of the dtype of labels.
        marker_size = np.ones(labels.shape) * marker_size

    if cmap is None:
        cmap = get_custom_cmap(num_states)

    # Center each integer color index in its own colormap bin.
    color_kwargs = {"vmin": -0.5, "vmax": num_states - 0.5, "cmap": cmap}

    # Plot data
    sc = ax.scatter(
        data[sample_idx, 0],
        data[sample_idx, 1],
        s=marker_size[sample_idx],
        c=colors[sample_idx],
        marker=".",
        **color_kwargs,
    )

    # Highlight the plotted points whose label disagrees with the target
    if targets is not None:
        mis_idx = np.flatnonzero(np.asarray(targets) != labels)
        mis_sample_idx = np.intersect1d(mis_idx, sample_idx)
        ax.scatter(
            data[mis_sample_idx, 0],
            data[mis_sample_idx, 1],
            s=5 * marker_size[mis_sample_idx],
            c=colors[mis_sample_idx],
            marker="x",
            **color_kwargs,
        )

    # Create colorbar
    cbar = plt.colorbar(
        sc,
        ax=ax,
        location=cbar_loc,
        ticks=mticker.FixedLocator(range(num_states)),
        label=cbar_label,
        shrink=0.95,
    )

    # Customize ticklabels
    if cbar_ticklabels is None:
        # Let the colorbar pick its long axis, so the labels land on the
        # ticked axis for vertical colorbars too.
        cbar.set_ticklabels(unique_labels)
    elif cbar_loc == "right":
        cbar.ax.set_yticklabels(
            cbar_ticklabels, rotation=90, verticalalignment="center"
        )
    elif cbar_loc == "top":
        cbar.ax.set_xticklabels(
            cbar_ticklabels, rotation=30, horizontalalignment="left"
        )
    else:
        raise NotImplementedError("Only 'top' and 'right' are valid cbar locations.")

    # Finalize plot
    ax.set_xlabel("UMAP$_1$")
    ax.set_ylabel("UMAP$_2$")
[docs]
def plot_f1_matrix(f1_matrix, xlabels=None, ylabels=None, ax=None):
    """
    Draw an F1 score matrix as an annotated heatmap.

    Parameters
    ----------
    f1_matrix : array-like, shape (n_classes, n_classes)
        The F1 score matrix.
    xlabels : list, optional
        Column labels shown on the x-axis. If None, column indices are used.
        Default is None.
    ylabels : list, optional
        Row labels shown on the y-axis. If None, row indices are used.
        Default is None.
    ax : matplotlib.axes.Axes, optional
        Axes to draw onto. If None, a new figure and axes are created.
        Default is None.

    Returns
    -------
    None
        The heatmap is drawn on the provided or created axes.
    """
    if ax is None:
        ax = plt.subplots()[1]

    # Fall back to positional indices for any missing set of labels
    columns = range(f1_matrix.shape[1]) if xlabels is None else xlabels
    index = range(f1_matrix.shape[0]) if ylabels is None else ylabels
    frame = pd.DataFrame(f1_matrix, columns=columns, index=index)

    heatmap_options = {
        "annot": True,
        "fmt": ".1f",
        "linewidth": 0.5,
        "cmap": "Blues",
        "cbar_kws": {"location": "top", "label": "F1 score"},
        "ax": ax,
    }
    sns.heatmap(frame, **heatmap_options)

    # Finalize plot
    ax.set_ylabel("Motif ID")
    ax.tick_params(left=False, bottom=False)
[docs]
def plot_transition_graph(
    trans_prob, node_sizes=150, edge_vmin=None, edge_vmax=None, cmap=None, ax=None
):
    """
    Plots a transition graph based on the provided transition probability matrix.

    States are laid out on a circle; every non-zero entry of the matrix becomes
    a directed edge colored by its value.

    Parameters
    ----------
    trans_prob : numpy.ndarray
        A square matrix where element [i, j] represents the probability of transitioning
        from state i to state j.
    node_sizes : int or list, optional
        The size of the nodes in the graph. If an int, all nodes will have the same
        size.
    edge_vmin : float, optional
        Minimum value for the edge colormap normalization.
    edge_vmax : float, optional
        Maximum value for the edge colormap normalization.
    cmap : matplotlib.colors.Colormap, optional
        Colormap for the edges of the graph.
    ax : matplotlib.axes.Axes, optional
        Axes object to draw the plot onto. If None, a new figure and axes are
        created.

    Returns
    -------
    nodes : matplotlib.collections.PathCollection
        The drawn nodes of the graph.
    labels : dict
        The labels of the nodes in the graph.
    cbar : matplotlib.colorbar.Colorbar
        Colorbar corresponding to the edge weights.

    Raises
    ------
    ValueError
        If the matrix yields a graph without any edge.
    """
    if ax is None:
        _, ax = plt.subplots()

    # Find number of states
    num_states = trans_prob.shape[0]

    # Make graph
    G = nx.from_numpy_array(trans_prob, create_using=nx.DiGraph)
    pos = nx.circular_layout(G)

    edge_weights = nx.get_edge_attributes(G, "weight")
    if not edge_weights:
        raise ValueError("The transition matrix does not define any edge to plot.")
    edgelist = tuple(edge_weights.keys())
    weights = tuple(edge_weights.values())

    nodes = nx.draw_networkx_nodes(
        G,
        pos,
        ax=ax,
        node_size=node_sizes,
        node_color=range(num_states),
        cmap=get_custom_cmap(num_states),
    )
    labels = nx.draw_networkx_labels(
        G,
        pos=pos,
        font_color="w",
        font_size=mpl.rcParams["font.size"],
        font_weight="bold",
        ax=ax,
    )
    edges = nx.draw_networkx_edges(
        G,
        pos,
        ax=ax,
        node_size=node_sizes,
        arrowstyle="->",
        arrowsize=10,
        connectionstyle="arc3,rad=0.3",
        edge_cmap=cmap,
        width=2,
        edgelist=edgelist,
        edge_color=weights,
        edge_vmin=edge_vmin,
        edge_vmax=edge_vmax,
    )

    # Make colorbar. The edges are individual patches, so a proxy collection
    # carries the weights for the colorbar.
    pc = mpl.collections.PatchCollection(edges, cmap=cmap)
    pc.set_array(np.asarray(weights))
    # Use the same limits as the drawn edges; None keeps the data-driven limit.
    pc.set_clim(edge_vmin, edge_vmax)
    cbar = plt.colorbar(
        pc,
        ax=ax,
        location="top",
        label="Transition probability",
        shrink=0.95,
    )

    # Remove axis
    ax.set_axis_off()

    return nodes, labels, cbar
[docs]
def plot_points(data, y, ax):
    """
    Plots boxplots and strip plots for the given data based on Motif ID.

    Each motif is drawn with the color found at index ``Motif ID`` in the
    custom colormap, also when some motif IDs are absent from ``data``.

    Parameters
    ----------
    data : pandas.DataFrame
        The data containing 'Motif ID' and the variable to plot on the y-axis.
    y : str
        The name of the column in `data` to be plotted on the y-axis.
    ax : matplotlib.axes.Axes
        Axes object to draw the plot onto.

    Returns
    -------
    None
    """
    # Select colors
    motif_ids = data["Motif ID"]
    num_states = int(motif_ids.max()) + 1
    cmap = get_custom_cmap(num_states)

    # NOTE: Key the palette by motif ID. A plain list is assigned positionally
    #       to the hue levels present, which shifts the colors as soon as one
    #       motif ID is missing from the data.
    palette = {
        motif_id: cmap(int(motif_id)) for motif_id in motif_ids.dropna().unique()
    }

    sns.boxplot(
        data=data,
        x="Motif ID",
        y=y,
        hue="Motif ID",
        fill=False,
        showfliers=False,
        legend=False,
        palette=palette,
        ax=ax,
    )
    sns.stripplot(
        data=data,
        x="Motif ID",
        y=y,
        hue="Motif ID",
        size=1.5,
        legend=False,
        palette=palette,
        ax=ax,
    )
[docs]
def plot_group_points(data, y, test_dict, ax):
    """
    Draw per-group boxplots and strip plots, annotated with test significance.

    Parameters
    ----------
    data : pandas.DataFrame
        The data containing 'Motif ID', 'Group label', and the variable to plot on the
        y-axis.
    y : str
        The name of the column in `data` to be plotted on the y-axis.
    test_dict : dict
        Maps Motif IDs to statistical test results exposing a ``pvalue``
        attribute (e.g., t-test results). The n-th item is annotated at
        x-position n, so the iteration order must match the order of the motifs
        on the x-axis.
    ax : matplotlib.axes.Axes
        Axes object to draw the plot onto.

    Returns
    -------
    None
    """
    shared = {"data": data, "x": "Motif ID", "y": y, "hue": "Group label", "ax": ax}
    sns.boxplot(fill=False, showfliers=False, **shared)
    sns.stripplot(size=1.5, dodge=True, legend=False, **shared)

    # x in data coordinates, y in axes coordinates: the stars sit at the top
    # edge of the axes, above their motif.
    star_transform = blended_transform_factory(ax.transData, ax.transAxes)

    # Plot pval
    for xpos, motif_id in enumerate(test_dict):
        pvalue = test_dict[motif_id].pvalue
        print(f"Motif {motif_id} pval = {pvalue:.5f}")
        ax.text(
            xpos,
            1,
            pval2star(pvalue),
            horizontalalignment="center",
            verticalalignment="baseline",
            transform=star_transform,
            color="k",
        )
[docs]
def plot_embedding_summary(
    embeddings,
    labels,
    predictions,
    fps=None,
    cmap_human=None,
    cmap_machine=None,
    axarr=None,
):
    """
    Draw an embedding over time with the expert and predicted label tracks below.

    Parameters
    ----------
    embeddings : numpy.ndarray
        The embedding matrix where each row corresponds to a time point.
    labels : numpy.ndarray
        Array of expert-provided labels corresponding to the embeddings.
    predictions : numpy.ndarray
        Array of model-predicted labels corresponding to the embeddings.
    fps : float, optional
        If given, the x-axis spans ``embeddings.shape[0] / fps`` and is labeled
        "Time (s)"; otherwise it is labeled "Frame". Default is None.
    cmap_human : matplotlib.colors.Colormap, optional
        Colormap of the expert label track. If None, a custom colormap with one
        color per class is generated. Default is None.
    cmap_machine : matplotlib.colors.Colormap, optional
        Colormap of the predicted label track. If None, a custom colormap with
        one color per state is generated. Default is None.
    axarr : numpy.ndarray of matplotlib.axes.Axes, optional
        Array of Axes with three rows of (plot axes, colorbar axes). If not
        provided, new subplots will be created.

    Returns
    -------
    None
    """
    if axarr is not None:
        fig = axarr[0][0].get_figure()
    else:
        fig, axarr = plt.subplots(
            3,
            2,
            sharex="col",
            height_ratios=(1, 0.2, 0.2),
            width_ratios=(1, 0.02),
            gridspec_kw={"hspace": 0.05},
        )

    n_classes = int(np.max(labels)) + 1
    n_states = int(np.max(predictions)) + 1

    extent = None
    if fps is not None:
        extent = [0, embeddings.shape[0] / fps, 0, embeddings.shape[1]]

    # Options common to the three image tracks
    track_kwargs = {"aspect": "auto", "interpolation": "none", "extent": extent}

    # Plot embedding over time
    emb_ax, emb_cax = axarr[0]
    emb_im = emb_ax.imshow(embeddings.T, cmap="Spectral_r", **track_kwargs)
    fig.colorbar(emb_im, cax=emb_cax, label="Activation")
    emb_ax.set_ylabel("Embedding")
    emb_ax.set_yticks([])

    # Manual annotations
    human_ax, human_cax = axarr[1]
    if cmap_human is None:
        cmap_human = get_custom_cmap(n_classes)
    human_im = human_ax.imshow(
        labels[np.newaxis],
        cmap=cmap_human,
        vmin=-0.5,
        vmax=n_classes - 0.5,
        **track_kwargs,
    )
    fig.colorbar(
        human_im, cax=human_cax, ticks=mticker.FixedLocator(range(n_classes))
    )
    human_ax.set_ylabel("Human")
    human_ax.set_yticks([])

    # HMM annotations
    machine_ax, machine_cax = axarr[2]
    if cmap_machine is None:
        cmap_machine = get_custom_cmap(n_states)
    machine_im = machine_ax.imshow(
        predictions[np.newaxis],
        cmap=cmap_machine,
        vmin=-0.5,
        vmax=n_states - 0.5,
        **track_kwargs,
    )
    fig.colorbar(machine_im, cax=machine_cax, label="Motif ID")
    machine_ax.set_ylabel("LISBET")
    machine_ax.set_yticks([])

    # Finalize plot
    fig.align_ylabels(axarr[[0, 2], 1])
    machine_ax.set_xlabel("Time (s)" if fps is not None else "Frame")
[docs]
def plot_slh_score(all_n_clusters, all_score, best_n_clusters, best_score, ax):
    """
    Draw the silhouette score against the number of clusters.

    The best configuration is marked with a red dot and a dashed horizontal
    line at its score.

    Parameters
    ----------
    all_n_clusters : list or numpy.ndarray
        The numbers of clusters that were evaluated.
    all_score : list or numpy.ndarray
        Silhouette score obtained for each entry of ``all_n_clusters``.
    best_n_clusters : int
        The number of clusters with the best silhouette score.
    best_score : float
        The best silhouette score obtained.
    ax : matplotlib.axes.Axes
        Axes object to draw the plot onto.

    Returns
    -------
    None
    """
    # Score curve, then the highlighted optimum
    ax.plot(all_n_clusters, all_score)
    ax.scatter(best_n_clusters, best_score, c="red", label="best")
    ax.axhline(best_score, ls="dashed", color="k")

    # Finalize plot
    ax.set_xlabel("No. clusters")
    ax.set_ylabel("Silhouette score")
    ax.set_title("Average")
    ax.legend()
[docs]
def plot_slh_profile(distance, link_matrix, cluster_labels, ax):
    """
    Draw the per-sample silhouette profile of a hierarchical clustering.

    Samples are ordered as the leaves of the linkage and drawn as bars colored
    by cluster. A strip below the bars shows the cluster of each sample with
    the cluster id printed on it.

    Parameters
    ----------
    distance : numpy.ndarray
        Precomputed distance matrix.
    link_matrix : numpy.ndarray
        Linkage matrix obtained from hierarchical clustering.
    cluster_labels : numpy.ndarray
        Cluster labels for each sample. They are used as indices into the
        colormap, so they must be integers.
    ax : matplotlib.axes.Axes
        Axes object to draw the plot onto.

    Returns
    -------
    None
    """
    n_clusters = len(np.unique(cluster_labels))
    leaf_order = hierarchy.leaves_list(link_matrix)
    ordered_labels = cluster_labels[leaf_order]

    # Plot average silhouette score
    average_slh = silhouette_score(distance, cluster_labels, metric="precomputed")
    ax.axhline(average_slh, ls="dashed", color="k", label="Average silhouette score")

    # Compute silhouette score profile
    per_sample_slh = silhouette_samples(distance, cluster_labels, metric="precomputed")

    # Compute colors
    cmap = get_custom_cmap(n_clusters)
    bar_colors = [cmap.colors[label] for label in ordered_labels]

    # Plot silhouette score profile
    positions = np.arange(len(per_sample_slh))
    ax.bar(positions, per_sample_slh[leaf_order], color=bar_colors)

    # Create axes for the cluster strip
    ax_cbar = ax.inset_axes([0, -0.2, 1, 0.15])
    ax_cbar.sharex(ax)

    # Plot the cluster strip
    label_row = np.array(cluster_labels, ndmin=2)
    ax_cbar.matshow(label_row[:, leaf_order], aspect="auto", cmap=cmap)

    # Plot cluster ids, visiting the clusters in order of first appearance
    # along the leaves
    seen_ids, first_pos = np.unique(ordered_labels, return_index=True)
    offset = 0
    for cluster_id in seen_ids[np.argsort(first_pos)]:
        size = np.sum(cluster_labels == cluster_id)
        # This guard is always true for integer labels.
        if cluster_id % 1 == 0:
            ax_cbar.text(
                offset + size / 2 - 0.5,
                0,
                f"{cluster_id}",
                horizontalalignment="center",
                verticalalignment="center",
                color="w",
                weight="bold",
            )
        offset += size

    # Finalize plot
    ax.set_title("Profile")
    ax.set_xticks(range(len(leaf_order)), labels=leaf_order)
    plt.setp(ax.get_yticklabels(), visible=False)
    ax_cbar.set_yticks([])
    ax_cbar.set_ylabel("Prototype ID")
    ax_cbar.yaxis.set_label_position("right")
    sns.despine(ax=ax_cbar, bottom=True, left=True)
[docs]
def plot_heatmap(distance, link_matrix, cluster_labels, prototypes, ax):
    """
    Plots a heatmap of the sorted distance matrix along with cluster and prototype
    information.

    Rows and columns are reordered as the leaves of the linkage. Each cluster is
    outlined by a square on the diagonal and each prototype is marked by a dot
    on the diagonal.

    Parameters
    ----------
    distance : numpy.ndarray
        Precomputed distance matrix.
    link_matrix : numpy.ndarray
        Linkage matrix obtained from hierarchical clustering.
    cluster_labels : numpy.ndarray
        Cluster labels for each sample.
    prototypes : numpy.ndarray
        Array of prototype indices corresponding to each cluster.
    ax : matplotlib.axes.Axes
        Axes object to draw the plot onto.

    Returns
    -------
    None
    """
    n_clusters = len(np.unique(cluster_labels))
    ax_cbar = ax.inset_axes([0.03, 0.13, 0.1, 0.02])

    # Sort distance matrix
    indices = hierarchy.leaves_list(link_matrix)
    sorted_distance = distance[indices, :][:, indices]

    # Plot heatmap
    im = ax.imshow(
        sorted_distance,
        cmap="Reds_r",
        aspect="auto",
    )

    # Add colorbar
    plt.colorbar(
        mappable=im,
        cax=ax_cbar,
        orientation="horizontal",
        label="distance",
    )

    # Highlight clusters, in order of first appearance along the leaves
    res, ind = np.unique(cluster_labels[indices], return_index=True)
    loc = 0
    for cid in res[np.argsort(ind)]:
        width = np.sum(cluster_labels == cid)
        ax.add_patch(
            Rectangle(
                (loc - 0.5, loc - 0.5),
                width,
                width,
                fill=False,
                edgecolor="k",
                lw=1,
            )
        )
        loc += width

    # Highlight prototypes
    for cid in range(n_clusters):
        proto_loc = np.where(indices == prototypes[cid])[0]
        ax.scatter(
            proto_loc,
            proto_loc,
            marker="o",
            s=15,
            fc="white",
            ec="blue",
            # NOTE: Label the first marker only, to get a single legend entry.
            #       Index 0 always exists, so the legend is never empty.
            label="HMM prototype" if cid == 0 else None,
        )

    # Finalize plot
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_xlabel("Motif ID")
    ax.set_ylabel("Motif ID")
    ax.yaxis.set_label_position("right")
    ax.legend()
    sns.despine(ax=ax, bottom=True, left=True)
[docs]
def plot_dendrogram(linkage_matrix, cluster_labels, ax):
    """
    Plots a dendrogram for hierarchical clustering with clusters colored according to
    labels.

    Parameters
    ----------
    linkage_matrix : numpy.ndarray
        Linkage matrix obtained from hierarchical clustering.
    cluster_labels : numpy.ndarray
        Cluster labels for each sample. They are used as indices into the
        colormap, so they must be integers.
    ax : matplotlib.axes.Axes
        Axes object to draw the plot onto.

    Returns
    -------
    None
    """
    # NOTE: Coloring a dendrogram is a nightmare! For this reason we use a workaround.
    #       In brief, we compute the threshold to get the right number of clusters and
    #       let the dendrogram function cycle over the colors as usual.
    n_clusters = len(np.unique(cluster_labels))
    indices = hierarchy.leaves_list(linkage_matrix)

    # Configure colormap: one color per cluster, listed in order of first
    # appearance along the leaves
    cmap = get_custom_cmap(n_clusters)
    seen_ids, first_pos = np.unique(cluster_labels[indices], return_index=True)
    link_palette = [
        mpl.colors.rgb2hex(cmap.colors[i][:3]) for i in seen_ids[np.argsort(first_pos)]
    ]

    # Find color threshold. Exactly n_clusters - 1 links must lie above it,
    # so it is placed halfway between the two merge heights that bracket the cut.
    heights = np.sort(linkage_matrix[:, 2])
    n_above = n_clusters - 1
    if n_above <= 0:
        # Single cluster: the threshold must exceed every merge height
        color_thr = np.nextafter(heights[-1], np.inf)
    elif n_above >= len(heights):
        # Every sample is its own cluster: all links count as "above"
        color_thr = heights[0]
    else:
        color_thr = 0.5 * (heights[-n_above - 1] + heights[-n_above])

    # NOTE: The link palette is global SciPy state. It is restored to the default
    #       afterwards, so that later dendrograms are unaffected.
    hierarchy.set_link_color_palette(link_palette)
    try:
        hierarchy.dendrogram(
            linkage_matrix,
            orientation="left",
            no_labels=True,
            color_threshold=color_thr,
            above_threshold_color="k",
            ax=ax,
        )
    finally:
        hierarchy.set_link_color_palette(None)

    # Finalize plot
    ax.set_ylabel("Motif hierarchy")
    ax.invert_yaxis()
    ax.set_xticks([])
    sns.despine(ax=ax, bottom=True, left=True)