import os
import cv2
import tqdm
import imageio
import warnings
import logging
import h5py
import numpy as np
import plotly
import matplotlib.pyplot as plt
from scipy.ndimage import gaussian_filter1d
from vidio.read import OpenCVReader
from scipy.spatial.distance import squareform, pdist
from scipy.cluster.hierarchy import linkage, dendrogram, leaves_list
from textwrap import fill
from PIL import Image
from keypoint_moseq.util import *
from keypoint_moseq.io import load_results, _get_path
from jax_moseq.models.keypoint_slds import center_embedding
from jax_moseq.utils import get_durations, get_frequencies
from plotly.subplots import make_subplots
import plotly.io as pio
# Use the "iframe" renderer as the default renderer for plotly figures.
pio.renderers.default = "iframe"
# set matplotlib defaults: figures are rendered at 100 dpi unless overridden
plt.rcParams["figure.dpi"] = 100
# Intended to suppress warnings from imageio. Note that this raises the level
# of the *root* logger to ERROR, so it is not limited to imageio's messages.
logging.getLogger().setLevel(logging.ERROR)
[docs]
def crop_image(image, centroid, crop_size):
    """Crop an image around a centroid, zero-padding outside the image.

    The crop window always has the requested size. Any part of the window
    that falls outside the image is filled with zeros, including the case
    where the window does not overlap the image at all.

    Parameters
    ----------
    image: ndarray of shape (height, width) or (height, width, channels)
        Image to crop.

    centroid: tuple of int
        (x,y) coordinates of the centroid.

    crop_size: int or tuple(int,int)
        Size of the crop around the centroid. Either a single int for a square
        crop, or a tuple of ints (w,h) for a rectangular crop.

    Returns
    -------
    image: ndarray of shape (h, w) or (h, w, channels)
        Cropped image, with the same dtype as the input.
    """
    if isinstance(crop_size, tuple):
        w, h = crop_size
    else:
        w, h = crop_size, crop_size
    x, y = int(centroid[0]), int(centroid[1])

    # Crop window in image coordinates. Deriving the far edge from the near
    # edge guarantees the window spans exactly w x h pixels for odd sizes too.
    x0, y0 = x - w // 2, y - h // 2
    x1, y1 = x0 + w, y0 + h

    # Clamp *both* edges into the image. Clamping only one side would let a
    # negative stop index wrap around when the centroid is far off-image.
    img_h, img_w = image.shape[:2]
    src_x0, src_x1 = min(max(x0, 0), img_w), min(max(x1, 0), img_w)
    src_y0, src_y1 = min(max(y0, 0), img_h), min(max(y1, 0), img_h)
    cropped = image[src_y0:src_y1, src_x0:src_x1]

    padded = np.zeros((h, w, *image.shape[2:]), dtype=image.dtype)
    if cropped.size > 0:
        # Offset of the visible region within the output window
        pad_x, pad_y = src_x0 - x0, src_y0 - y0
        padded[
            pad_y : pad_y + cropped.shape[0], pad_x : pad_x + cropped.shape[1]
        ] = cropped
    return padded
[docs]
def plot_scree(pca, savefig=True, project_dir=None, fig_size=(3, 2)):
    """Plot cumulative explained variance against the number of PCs.

    Parameters
    ----------
    pca : :py:func:`sklearn.decomposition.PCA`
        Fitted PCA model

    savefig : bool, True
        Whether to save the figure to a file. If true, the figure is saved to
        `{project_dir}/pca_scree.pdf`.

    project_dir : str, default=None
        Path to the project directory. Required if `savefig` is True.

    fig_size : tuple, (3,2)
        Size of the figure in inches.

    Returns
    -------
    fig : :py:class:`matplotlib.figure.Figure`
        Figure handle
    """
    fig = plt.figure()

    # x-axis: 1-based PC count; y-axis: cumulative explained variance ratio
    pc_counts = np.arange(len(pca.components_)) + 1
    cumulative_variance = np.cumsum(pca.explained_variance_ratio_)
    plt.plot(pc_counts, cumulative_variance)

    plt.xlabel("PCs")
    plt.ylabel("Explained variance")
    plt.gcf().set_size_inches(fig_size)
    plt.grid()
    plt.tight_layout()

    if savefig:
        assert project_dir is not None, fill(
            "The `savefig` option requires a `project_dir`"
        )
        save_path = os.path.join(project_dir, "pca_scree.pdf")
        plt.savefig(save_path)

    plt.show()
    return fig
[docs]
def plot_pcs(
    pca,
    *,
    use_bodyparts,
    skeleton,
    keypoint_colormap="autumn",
    keypoint_colors=None,
    savefig=True,
    project_dir=None,
    scale=1,
    plot_n_pcs=10,
    axis_size=(2, 1.5),
    ncols=5,
    node_size=30.0,
    line_width=2.0,
    interactive=True,
    **kwargs,
):
    """
    Visualize the components of a fitted PCA model.

    For each PC, a subplot shows the mean pose (semi-transparent) along with a
    perturbation of the mean pose in the direction of the PC.

    Parameters
    ----------
    pca : :py:func:`sklearn.decomposition.PCA`
        Fitted PCA model

    use_bodyparts : list of str
        List of bodyparts that are used in the model; used to index bodypart
        names in the skeleton.

    skeleton : list
        List of edges that define the skeleton, where each edge is a pair of
        bodypart names.

    keypoint_colormap : str
        Name of a matplotlib colormap to use for coloring the keypoints.

    keypoint_colors : array-like, shape=(num_keypoints,3), default=None
        Color for each keypoint. If None, `keypoint_colormap` is used. If the
        dtype is int, the values are assumed to be in the range 0-255,
        otherwise they are assumed to be in the range 0-1.

    savefig : bool, True
        Whether to save the figure to a file. If true, the figure is saved to
        `{project_dir}/pcs-{xy/xz}.pdf` (`xz` is only included for 3D data).

    project_dir : str, default=None
        Path to the project directory. Required if `savefig` is True.

    scale : float, default=1
        Scale factor for the perturbation of the mean pose.

    plot_n_pcs : int, default=10
        Number of PCs to plot.

    axis_size : tuple of float, default=(2,1.5)
        Size of each subplot in inches.

    ncols : int, default=5
        Number of columns in the figure.

    node_size : float, default=30.0
        Size of the keypoints in the figure.

    line_width: float, default=2.0
        Width of edges in skeleton

    interactive : bool, default=True
        For 3D data, whether to generate an interactive 3D plot.

    **kwargs
        Accepted for call-site compatibility; not used by this function.

    Raises
    ------
    ValueError
        If the keypoint dimensionality inferred from the PCA model is not
        2 or 3.
    """
    k = len(use_bodyparts)
    d = len(pca.mean_) // (k - 1)

    if keypoint_colors is None:
        # `plt.cm.get_cmap` has been removed from recent matplotlib releases
        cmap = plt.get_cmap(keypoint_colormap)
        keypoint_colors = cmap(np.linspace(0, 1, k))
    else:
        keypoint_colors = np.asarray(keypoint_colors)
        # matplotlib expects floats in [0, 1]; integer input is 0-255
        if np.issubdtype(keypoint_colors.dtype, np.integer):
            keypoint_colors = keypoint_colors / 255

    Gamma = np.array(center_embedding(k))
    edges = get_edges(use_bodyparts, skeleton)
    plot_n_pcs = min(plot_n_pcs, pca.components_.shape[0])

    # Perturb the mean pose along each PC by an amount proportional to the
    # RMS of the mean, then map back to keypoint space with Gamma.
    magnitude = np.sqrt((pca.mean_**2).mean()) * scale
    ymean = Gamma @ pca.mean_.reshape(k - 1, d)
    ypcs = (pca.mean_ + magnitude * pca.components_).reshape(-1, k - 1, d)
    ypcs = Gamma[np.newaxis] @ ypcs[:plot_n_pcs]

    if d == 2:
        dims_list, names = [[0, 1]], ["xy"]
    elif d == 3:
        dims_list, names = [[0, 1], [0, 2]], ["xy", "xz"]
    else:
        raise ValueError(
            f"Expected 2D or 3D keypoints but inferred dimensionality d={d}"
        )

    for dims, name in zip(dims_list, names):
        nrows = int(np.ceil(plot_n_pcs / ncols))
        # squeeze=False keeps `axs` a 2D array even for a 1x1 grid
        fig, axs = plt.subplots(
            nrows, ncols, sharex=True, sharey=True, squeeze=False
        )
        for i, ax in enumerate(axs.flat):
            if i >= plot_n_pcs:
                ax.axis("off")
                continue

            for e in edges:
                # mean pose (faint)
                ax.plot(
                    *ymean[:, dims][e].T,
                    color=keypoint_colors[e[0]],
                    zorder=0,
                    alpha=0.25,
                    linewidth=line_width,
                )
                # black outline underneath the colored perturbed edge
                ax.plot(
                    *ypcs[i][:, dims][e].T,
                    color="k",
                    zorder=2,
                    linewidth=line_width + 0.2,
                )
                ax.plot(
                    *ypcs[i][:, dims][e].T,
                    color=keypoint_colors[e[0]],
                    zorder=3,
                    linewidth=line_width,
                )

            ax.scatter(
                *ymean[:, dims].T,
                c=keypoint_colors,
                s=node_size,
                zorder=1,
                alpha=0.25,
                linewidth=0,
            )
            ax.scatter(
                *ypcs[i][:, dims].T,
                c=keypoint_colors,
                s=node_size,
                zorder=4,
                edgecolor="k",
                linewidth=0.2,
            )
            ax.set_title(f"PC {i+1}", fontsize=10)
            ax.set_aspect("equal")
            ax.axis("off")

        fig.set_size_inches((axis_size[0] * ncols, axis_size[1] * nrows))
        plt.tight_layout()

        if savefig:
            assert project_dir is not None, fill(
                "The `savefig` option requires a `project_dir`"
            )
            plt.savefig(os.path.join(project_dir, f"pcs-{name}.pdf"))
        plt.show()

    if interactive and d == 3:
        plot_pcs_3D(
            ymean,
            ypcs,
            edges,
            keypoint_colormap,
            project_dir if savefig else None,
            node_size / 3,
            line_width * 2,
        )
[docs]
def plot_syllable_frequencies(
    project_dir=None,
    model_name=None,
    results=None,
    path=None,
    minlength=10,
    min_frequency=0.005,
):
    """Plot a histogram showing the frequency of each syllable.

    Caller must provide a results dictionary, a path to a results .h5, or a
    project directory and model name, in which case the results are loaded
    from `{project_dir}/{model_name}/results.h5`.

    Parameters
    ----------
    results : dict, default=None
        Dictionary containing modeling results for a dataset (see
        :py:func:`keypoint_moseq.fitting.extract_results`)

    model_name: str, default=None
        Name of the model. Required to load results if `results` is None and
        `path` is None.

    project_dir: str, default=None
        Project directory. Required to load results if `results` is None and
        `path` is None.

    path: str, default=None
        Path to a results file. If None, results will be loaded from
        `{project_dir}/{model_name}/results.h5`.

    minlength: int, default=10
        Minimum x-axis length of the histogram.

    min_frequency: float, default=0.005
        Minimum frequency of syllables to include in the histogram.

    Returns
    -------
    fig : matplotlib.figure.Figure
        Figure containing the histogram.

    ax : matplotlib.axes.Axes
        Axes containing the histogram.
    """
    if results is None:
        results = load_results(project_dir, model_name, path)

    syllables = {k: res["syllable"] for k, res in results.items()}
    frequencies = get_frequencies(syllables)
    frequencies = frequencies[frequencies > min_frequency]

    # Every remaining entry exceeds `min_frequency`, so the rightmost bar sits
    # at index len(frequencies) - 1. Using the length directly also avoids
    # calling np.max on an empty array when no syllable passes the threshold.
    xmax = max(minlength, len(frequencies))

    fig, ax = plt.subplots()
    ax.bar(range(len(frequencies)), frequencies, width=1)
    ax.set_ylabel("probability")
    ax.set_xlabel("syllable rank")
    ax.set_xlim(-1, xmax + 1)
    ax.set_title("Frequency distribution")
    ax.set_yticks([])
    return fig, ax
[docs]
def plot_duration_distribution(
    project_dir=None,
    model_name=None,
    results=None,
    path=None,
    lim=None,
    num_bins=30,
    fps=None,
    show_median=True,
):
    """Plot a histogram showing the distribution of syllable durations.

    Caller must provide a results dictionary, a path to a results .h5, or a
    project directory and model name, in which case the results are loaded from
    `{project_dir}/{model_name}/results.h5`.

    Parameters
    ----------
    results : dict, default=None
        Dictionary containing modeling results for a dataset (see
        :py:func:`keypoint_moseq.fitting.extract_results`)

    model_name: str, default=None
        Name of the model. Required to load results if `results` is None and
        `path` is None.

    project_dir: str, default=None
        Project directory. Required to load results if `results` is None and
        `path` is None.

    path: str, default=None
        Path to a results file. If None, results will be loaded from
        `{project_dir}/{model_name}/results.h5`.

    lim: number or tuple, default=None
        x-axis limits (in units of frames). Either a single number, which is
        used as the upper limit with a lower limit of 0, or a pair
        `(lower, upper)`. If None, the limits are set to (0, 95th-percentile).

    num_bins: int, default=30
        Approximate number of bins in the histogram. Bin widths are a whole
        number of frames, so the actual bin count may differ.

    fps: int, default=None
        Frames per second. Used to convert x-axis from frames to seconds.

    show_median: bool, default=True
        Whether to show the median duration as a vertical line.

    Returns
    -------
    fig : matplotlib.figure.Figure
        Figure containing the histogram.

    ax : matplotlib.axes.Axes
        Axes containing the histogram.
    """
    if results is None:
        results = load_results(project_dir, model_name, path)

    syllables = {k: res["syllable"] for k, res in results.items()}
    durations = get_durations(syllables)

    # Resolve the x-axis limits (in frames)
    if lim is None:
        lower, upper = 0, int(np.percentile(durations, 95))
    elif np.ndim(lim) == 0:
        lower, upper = 0, lim
    else:
        lower, upper = lim

    # Choose the bin count while still in frame units, so that the result is
    # not affected by floating point error after dividing by `fps`.
    binsize = max(int(np.floor((upper - lower) / num_bins)), 1)
    n_bins = max(int((upper - lower) / binsize), 1)

    if fps is not None:
        durations = durations / fps
        lower, upper = lower / fps, upper / fps
        xlabel = "syllable duration (s)"
    else:
        xlabel = "syllable duration (frames)"

    fig, ax = plt.subplots()
    ax.hist(durations, range=(lower, upper), bins=n_bins, density=True)
    ax.set_xlim([lower, upper])
    ax.set_xlabel(xlabel)
    ax.set_ylabel("probability")
    ax.set_title("Duration distribution")
    ax.set_yticks([])
    if show_median:
        ax.axvline(np.median(durations), color="k", linestyle="--")
    return fig, ax
[docs]
def plot_kappa_scan(kappas, project_dir, prefix, figsize=(8, 2.5)):
    """Plot the results of a kappa scan.

    This function assumes that model results for each kappa value are stored
    in `{project_dir}/{prefix}-{kappa}/checkpoint.h5`. Two plots are generated:
    (1) a line plot showing the median syllable duration over the course of
    fitting for each kappa value; (2) and a plot showing the final median
    syllable duration as a function of kappa.

    Parameters
    ----------
    kappas : array-like of float
        Kappa values used in the scan.

    project_dir : str
        Path to the project directory.

    prefix : str
        Prefix for the kappa scan model names.

    figsize : tuple of float, default=(8,2.5)
        Size of the figure in inches.

    Returns
    -------
    fig : matplotlib.figure.Figure
        Figure containing the plot.

    final_median_durations : array of float
        Median syllable durations for each kappa value, derived using the
        final iteration of each model.
    """
    median_dur_histories = []
    final_median_durs = []

    for kappa in tqdm.tqdm(kappas, desc="Loading checkpoints"):
        model_dir = f"{project_dir}/{prefix}-{kappa}"
        with h5py.File(f"{model_dir}/checkpoint.h5", "r") as h5:
            mask = h5["data/mask"][()]
            iterations = np.sort([int(i) for i in h5["model_snapshots"]])

            # median duration at every saved iteration, keyed by iteration
            history = {}
            for itr in iterations:
                z = h5[f"model_snapshots/{itr}/states/z"][()]
                durs = get_durations(z, mask)
                history[itr] = np.median(durs)

        final_median_durs.append(history[iterations[-1]])
        median_dur_histories.append(history)

    fig, axs = plt.subplots(1, 2)

    # Spread colors over the colormap; guard the denominator so that a scan
    # with a single kappa value does not divide by zero.
    color_denom = max(len(kappas) - 1, 1)
    for i, (kappa, history) in enumerate(zip(kappas, median_dur_histories)):
        color = plt.cm.viridis(i / color_denom)
        label = "{:.1e}".format(kappa)
        axs[0].plot(*zip(*history.items()), color=color, label=label)

    axs[0].legend(loc="center left", bbox_to_anchor=(1, 0.5))
    axs[0].set_xlabel("iteration")
    axs[0].set_ylabel("median duration")

    axs[1].scatter(kappas, final_median_durs)
    axs[1].set_xlabel("kappa")
    axs[1].set_ylabel("final median duration")
    axs[1].set_xscale("log")

    fig.set_size_inches(figsize)
    plt.tight_layout()
    return fig, np.array(final_median_durs)
[docs]
def plot_progress(
    model,
    data,
    checkpoint_path,
    iteration,
    project_dir=None,
    model_name=None,
    path=None,
    savefig=True,
    fig_size=None,
    window_size=600,
    min_frequency=0.001,
    min_histogram_length=10,
):
    """Plot the progress of the model during fitting.

    The figure shows the following plots (left to right):
        - Frequency distribution:
            The distribution of state frequencies for the most recent iteration
            of the model.
        - Duration distribution:
            The distribution of state durations for the most recent iteration
            of the model.
        - Median duration:
            The median state duration across iterations.
        - State sequence history
            The state sequence across iterations in a random window (a new
            window is selected each time the progress is plotted).

    The last two plots are only included when the checkpoint contains more
    than one saved iteration.

    Parameters
    ----------
    model : dict
        Model dictionary containing `states`

    data : dict
        Data dictionary containing `mask`

    checkpoint_path : str
        Path to an HDF5 file containing model checkpoints.

    iteration : int
        Current iteration of model fitting

    project_dir : str, default=None
        Path to the project directory. Used to derive the save location when
        `savefig` is True and `path` is None.

    model_name : str, default=None
        Name of the model. Used in the figure title and to derive the save
        location when `savefig` is True and `path` is None.

    path : str, default=None
        Where to save the figure when `savefig` is True. If None, the location
        is derived from `project_dir` and `model_name` with the filename
        `fitting_progress.pdf`.

    savefig : bool, default=True
        Whether to save the figure to a file (see `path`).

    fig_size : tuple of float, default=None
        Size of the figure in inches. If None, (12, 2.5) is used for the
        four-panel figure and (4, 2.5) for the two-panel figure.

    window_size : int, default=600
        Window size for state sequence history plot.

    min_frequency : float, default=.001
        Minimum frequency for including a state in the frequency distribution
        plot.

    min_histogram_length : int, default=10
        Minimum x-axis length of the frequency distribution plot.

    Returns
    -------
    fig : matplotlib.figure.Figure
        Figure containing the plots.

    axs : list of matplotlib.axes.Axes
        Axes containing the plots.
    """
    z = np.array(model["states"]["z"])
    mask = np.array(data["mask"])
    durations = get_durations(z, mask)
    frequencies = get_frequencies(z, mask)

    with h5py.File(checkpoint_path, "r") as f:
        saved_iterations = np.sort([int(i) for i in f["model_snapshots"]])

    # History panels only make sense with more than one saved iteration
    has_history = len(saved_iterations) > 1
    if has_history:
        fig, axs = plt.subplots(1, 4, gridspec_kw={"width_ratios": [1, 1, 1, 3]})
        if fig_size is None:
            fig_size = (12, 2.5)
    else:
        fig, axs = plt.subplots(1, 2)
        if fig_size is None:
            fig_size = (4, 2.5)

    # Frequency distribution (sorted so that the x-axis is syllable rank)
    frequencies = np.sort(frequencies[frequencies > min_frequency])[::-1]
    xmax = max(len(frequencies), min_histogram_length)
    axs[0].bar(range(len(frequencies)), frequencies, width=1)
    axs[0].set_ylabel("probability")
    axs[0].set_xlabel("syllable rank")
    axs[0].set_xlim([-1, xmax + 1])
    axs[0].set_title("Frequency distribution")
    axs[0].set_yticks([])

    # Duration distribution, truncated at the 95th percentile
    lim = int(np.percentile(durations, 95))
    binsize = max(int(np.floor(lim / 30)), 1)
    axs[1].hist(durations, range=(1, lim), bins=(int(lim / binsize)), density=True)
    axs[1].set_xlim([1, lim])
    axs[1].set_xlabel("syllable duration (frames)")
    axs[1].set_ylabel("probability")
    axs[1].set_title("Duration distribution")
    axs[1].set_yticks([])

    if has_history:
        # Pick a random window. Candidate start indices come from the mask
        # shifted by `window_size`, so the frame `window_size` steps after the
        # chosen start is unmasked.
        window_size = int(min(window_size, mask.max(0).sum() - 1))
        nz = np.stack(np.array(mask[:, window_size:]).nonzero(), axis=1)
        batch_ix, start = nz[np.random.randint(nz.shape[0])]

        sample_state_history = []
        median_durations = []

        # Open the checkpoint once rather than once per saved iteration
        with h5py.File(checkpoint_path, "r") as f:
            for i in saved_iterations:
                z = np.array(f[f"model_snapshots/{i}/states/z"])
                sample_state_history.append(z[batch_ix, start : start + window_size])
                median_durations.append(np.median(get_durations(z, mask)))

        axs[2].scatter(saved_iterations, median_durations)
        axs[2].set_ylim([-1, np.max(median_durations) * 1.1])
        axs[2].set_xlabel("iteration")
        axs[2].set_ylabel("duration")
        axs[2].set_title("Median duration")

        axs[3].imshow(
            sample_state_history,
            cmap=plt.cm.jet,
            aspect="auto",
            interpolation="nearest",
        )
        axs[3].set_xlabel("Time (frames)")
        axs[3].set_ylabel("Iterations")
        axs[3].set_title("State sequence history")

        # Relabel the row-index ticks with the iteration numbers they represent
        yticks = [
            int(y) for y in axs[3].get_yticks() if y < len(saved_iterations) and y > 0
        ]
        yticklabels = saved_iterations[yticks]
        axs[3].set_yticks(yticks)
        axs[3].set_yticklabels(yticklabels)

    title = f"Iteration {iteration}"
    if model_name is not None:
        title = f"{model_name}: {title}"
    fig.suptitle(title)
    fig.set_size_inches(fig_size)
    plt.tight_layout()

    if savefig:
        path = _get_path(project_dir, model_name, path, "fitting_progress.pdf")
        plt.savefig(path)
    plt.show()
    return fig, axs
[docs]
def write_video_clip(frames, path, fps=30, quality=7):
    """Encode a sequence of frames as a video file.

    Parameters
    ----------
    frames : np.ndarray
        Video frames as a 4D array of shape `(num_frames, height, width, 3)`
        or a 3D array of shape `(num_frames, height, width)`.

    path : str
        Path to save the video clip.

    fps : int, default=30
        Framerate of video encoding.

    quality : int, default=7
        Quality of video encoding.
    """
    video_writer = imageio.get_writer(
        path, pixelformat="yuv420p", fps=fps, quality=quality
    )
    # The context manager closes the writer once all frames are appended
    with video_writer:
        for image in frames:
            video_writer.append_data(image)
def _grid_movie_tile(
    key,
    start,
    end,
    videos,
    centroids,
    headings,
    dot_color,
    window_size,
    scaled_window_size,
    pre,
    post,
    dot_radius,
    overlay_keypoints,
    edges,
    coordinates,
    plot_options,
    downsample_rate,
):
    """Render one tile of a grid movie for a single syllable instance.

    The tile covers frames `start - pre` to `start + post` of recording `key`,
    rotated by the heading at `start` and centered on the centroid at `start`.
    A dot is drawn at the centroid on frames from `start` to `end` (inclusive)
    when `dot_radius > 0`. See :py:func:`grid_movie` for parameter
    descriptions.

    Returns
    -------
    tile : ndarray
        Stack of rendered frames, one per frame in the window.
    """
    scale_factor = scaled_window_size / window_size
    window_centroids = centroids[key][start - pre : start + post]
    heading = headings[key][start]
    onset_centroid = window_centroids[pre]
    rotation = np.float32(
        [
            [np.cos(heading), np.sin(heading)],
            [-np.sin(heading), np.cos(heading)],
        ]
    )
    tile = []

    if videos is None:
        # first transform keypoints, then overlay on black background
        assert overlay_keypoints, fill(
            "If no videos are provided, then `overlay_keypoints` must "
            "be True. Otherwise there is nothing to show"
        )
        half_size = scaled_window_size // 2
        window_coords = coordinates[key][start - pre : start + post]
        window_coords = (
            (window_coords - onset_centroid) @ rotation.T * scale_factor + half_size
        )
        window_centroids = (
            (window_centroids - onset_centroid) @ rotation.T * scale_factor
            + half_size
        )
        background = np.zeros((scaled_window_size, scaled_window_size, 3))

        for ii, (uvs, centroid) in enumerate(zip(window_coords, window_centroids)):
            frame = overlay_keypoints_on_image(
                background.copy(), uvs, edges=edges, **plot_options
            )
            if 0 <= ii - pre <= end - start and dot_radius > 0:
                pos = (int(centroid[0]), int(centroid[1]))
                cv2.circle(frame, pos, dot_radius, dot_color, -1, cv2.LINE_AA)
            tile.append(frame)
    else:
        # Coordinates are indexed at the downsampled rate, so video frame
        # indices are scaled up and then strided back down.
        first_frame = (start - pre) * downsample_rate
        last_frame = (start + post) * downsample_rate
        frames = videos[key][first_frame:last_frame][::downsample_rate]

        # Affine map: rotate by the heading, then shift so that the onset
        # centroid lands at the middle of the window.
        offset = rotation @ onset_centroid - window_size // 2
        affine = [
            [np.cos(heading), np.sin(heading), -offset[0]],
            [-np.sin(heading), np.cos(heading), -offset[1]],
        ]

        for ii, (frame, centroid) in enumerate(zip(frames, window_centroids)):
            if overlay_keypoints:
                frame_coords = coordinates[key][start - pre + ii]
                frame = overlay_keypoints_on_image(
                    frame, frame_coords, edges=edges, **plot_options
                )
            frame = cv2.warpAffine(
                frame, np.float32(affine), (window_size, window_size)
            )
            frame = cv2.resize(frame, (scaled_window_size, scaled_window_size))
            if 0 <= ii - pre <= end - start and dot_radius > 0:
                pos = tuple(
                    [int(x) for x in affine @ np.append(centroid, 1) * scale_factor]
                )
                cv2.circle(frame, pos, dot_radius, dot_color, -1, cv2.LINE_AA)
            tile.append(frame)

    return np.stack(tile)
[docs]
def grid_movie(
    instances,
    rows,
    cols,
    videos,
    centroids,
    headings,
    window_size,
    dot_color=(255, 255, 255),
    dot_radius=4,
    pre=30,
    post=60,
    scaled_window_size=None,
    edges=[],
    overlay_keypoints=False,
    coordinates=None,
    plot_options={},
    downsample_rate=1,
):
    """Generate a grid movie and return it as an array of frames.

    Grid movies show many instances of a syllable. Each instance contains a snippet of
    video (and/or keypoint-overlay) centered on the animal and synchronized to the onset
    of the syllable. A dot appears at syllable onset and disappears at syllable offset.

    Parameters
    ----------
    instances: list of tuples `(key, start, end)`
        List of syllable instances to include in the grid movie, where each instance is
        specified as a tuple with the video name, start frame and end frame. The list
        must have length `rows*cols`. The video names must also be keys in `videos`.

    rows: int, cols : int
        Number of rows and columns in the grid movie grid

    videos: dict or None
        Dictionary mapping video names to video readers. Frames from each reader should
        be accessible via `__getitem__(int or slice)`. If None, the grid movie will
        not include video frames.

    centroids: dict
        Dictionary mapping video names to arrays of shape `(n_frames, 2)` with the x,y
        coordinates of animal centroid on each frame

    headings: dict
        Dictionary mapping video names to arrays of shape `(n_frames,)` with the heading
        of the animal on each frame (in radians)

    window_size: int
        Size of the window around the animal. This should be a multiple of 16 or imageio
        will complain.

    dot_color: tuple of ints, default=(255,255,255)
        RGB color of the dot indicating syllable onset and offset

    dot_radius: int, default=4
        Radius of the dot indicating syllable onset and offset

    pre: int, default=30
        Number of frames before syllable onset to include in the movie

    post: int, default=60
        Number of frames after syllable onset to include in the movie

    scaled_window_size: int, default=None
        Window size after scaling the video. If None, then no scaling is performed (i.e.
        `scaled_window_size = window_size`)

    overlay_keypoints: bool, default=False
        If True, overlay the pose skeleton on the video frames.

    edges: list of tuples, default=[]
        List of edges defining pose skeleton. Used when `overlay_keypoints=True`.

    coordinates: dict, default=None
        Dictionary mapping video names to keypoint coordinate arrays. Used when
        `overlay_keypoints=True`.

    plot_options: dict, default={}
        Dictionary of options to pass to `overlay_keypoints_on_image`. Used when
        `overlay_keypoints=True`.

    downsample_rate: int, default=1
        Downsampling rate for the video frames. Coordinates at index `i` will be
        plotted on the video frame at index `i*downsample_rate`.

    Returns
    -------
    frames: array of shape `(post+pre, height, width, 3)`
        Array of frames in the grid movie where::

            height = rows * scaled_window_size
            width = cols * scaled_window_size

    Raises
    ------
    ValueError
        If the number of instances is not `rows*cols`.
    """
    if videos is None:
        assert overlay_keypoints, fill(
            "If no videos are provided, then `overlay_keypoints` must "
            "be True. Otherwise there is nothing to show"
        )

    if scaled_window_size is None:
        scaled_window_size = window_size

    # Check the instance count up front: otherwise the mismatch only shows up
    # as an opaque reshape error after every tile has been rendered.
    instances = list(instances)
    if len(instances) != rows * cols:
        raise ValueError(
            f"`instances` must contain exactly rows*cols = {rows * cols} "
            f"entries, but {len(instances)} were provided"
        )

    tiles = [
        _grid_movie_tile(
            key,
            start,
            end,
            videos,
            centroids,
            headings,
            dot_color,
            window_size,
            scaled_window_size,
            pre,
            post,
            dot_radius,
            overlay_keypoints,
            edges,
            coordinates,
            plot_options,
            downsample_rate,
        )
        for key, start, end in instances
    ]

    # (rows, cols, time, h, w, 3) -> (time, rows*h, cols*w, 3)
    tiles = np.stack(tiles).reshape(
        rows, cols, post + pre, scaled_window_size, scaled_window_size, 3
    )
    frames = np.concatenate(np.concatenate(tiles, axis=2), axis=2)
    return frames
[docs]
def get_grid_movie_window_size(
    sampled_instances,
    centroids,
    headings,
    coordinates,
    pre,
    post,
    pctl=90,
    fudge_factor=1.1,
    blocksize=16,
):
    """Automatically determine the window size for a grid movie.

    The window size is set such that across all sampled instances,
    the animal is fully visible in at least `pctl` percent of frames.

    Parameters
    ----------
    sampled_instances: dict
        Dictionary mapping syllables to lists of instances, where each
        instance is specified as a tuple with the video name, start frame
        and end frame.

    centroids: dict
        Dictionary mapping video names to arrays of shape `(n_frames, 2)`
        with the x,y coordinates of animal centroid on each frame

    headings: dict
        Dictionary mapping video names to arrays of shape `(n_frames,)`
        with the heading of the animal on each frame (in radians)

    coordinates: dict
        Dictionary mapping recording names to keypoint coordinates as
        ndarrays of shape (n_frames, n_bodyparts, 2).

    pre, post: int
        Number of frames before/after syllable onset that are included
        in the grid movies.

    pctl: int, default=90
        Percentile of frames in which the animal should be fully visible.

    fudge_factor: float, default=1.1
        Factor by which to multiply the window size.

    blocksize: int, default=16
        Window size is rounded up to the nearest multiple of `blocksize`.

    Returns
    -------
    window_size: int
        Window size, a multiple of `blocksize`.
    """
    # Flatten the per-syllable instance lists in a single linear pass
    # (`sum(lists, [])` copies the accumulated list at every step).
    all_instances = [
        instance
        for syllable_instances in sampled_instances.values()
        for instance in syllable_instances
    ]

    all_trajectories = get_instance_trajectories(
        all_instances,
        coordinates,
        pre=pre,
        post=post,
        centroids=centroids,
        headings=headings,
    )

    # Pool frames from all instances and drop frames that are entirely NaN
    all_trajectories = np.concatenate(all_trajectories, axis=0)
    all_trajectories = all_trajectories[~np.isnan(all_trajectories).all((1, 2))]

    # Largest absolute keypoint coordinate per frame and dimension; the window
    # must extend that far on both sides, hence the factor of 2.
    max_distances = np.nanmax(np.abs(all_trajectories), axis=1)
    window_size = np.percentile(max_distances, pctl) * fudge_factor * 2
    window_size = int(np.ceil(window_size / blocksize) * blocksize)
    return window_size
[docs]
def generate_grid_movies(
results,
project_dir=None,
model_name=None,
output_dir=None,
video_dir=None,
video_paths=None,
rows=4,
cols=6,
filter_size=9,
pre=30,
post=60,
min_frequency=0.005,
min_duration=3,
dot_radius=4,
dot_color=(255, 255, 255),
quality=7,
window_size=None,
coordinates=None,
centroids=None,
headings=None,
bodyparts=None,
use_bodyparts=None,
sampling_options={},
video_extension=None,
max_video_size=1920,
skeleton=[],
overlay_keypoints=False,
keypoints_only=False,
keypoints_scale=1,
fps=None,
plot_options={},
use_dims=[0, 1],
keypoint_colormap="autumn",
downsample_rate=1,
**kwargs,
):
"""Generate grid movies for a modeled dataset.
Grid movies show many instances of a syllable and are useful in
figuring out what behavior the syllable captures
(see :py:func:`keypoint_moseq.viz.grid_movie`). This method
generates a grid movie for each syllable that is used sufficiently
often (i.e. has at least `rows*cols` instances with duration
of at least `min_duration` and an overall frequency of at least
`min_frequency`). The grid movies are saved to `output_dir` if
specified, or else to `{project_dir}/{model_name}/grid_movies`.
Parameters
----------
results: dict
Dictionary containing modeling results for a dataset (see
:py:func:`keypoint_moseq.fitting.extract_results`)
project_dir: str, default=None
Project directory. Required to save grid movies if `output_dir`
is None.
model_name: str, default=None
Name of the model. Required to save grid movies if
`output_dir` is None.
output_dir: str, default=None
Directory where grid movies should be saved. If None, grid
movies will be saved to `{project_dir}/{model_name}/grid_movies`.
video_dir: str, default=None
Directory containing videos of the modeled data (see
:py:func:`keypoint_moseq.io.find_matching_videos`). Either `video_dir`
or `video_paths` must be provided unless `keypoints_only=True`.
video_paths: dict, default=None
Dictionary mapping recording names to video paths. The recording
names must correspond to keys in the results dictionary. Either
`video_dir` or `video_paths` must be provided unless
`keypoints_only=True`.
filter_size: int, default=9
Size of the median filter applied to centroids and headings
min_frequency: float, default=0.005
Minimum frequency of a syllable to be included in the grid movies.
min_duration: int, default=3
Minimum duration of a syllable instance to be included in the
grid movie for that syllable.
sampling_options: dict, default={}
Dictionary of options for sampling syllable instances (see
:py:func:`keypoint_moseq.util.sample_instances`).
coordinates: dict, default=None
Dictionary mapping recording names to keypoint coordinates as
ndarrays of shape (n_frames, n_bodyparts, [2 or 3]). Required when
`window_size=None`, or `overlay_keypoints=True`, or if using
density-based sampling (i.e. when `sampling_options['mode']=='density'`;
see :py:func:`keypoint_moseq.util.sample_instances`).
centroids: dict, default=None
Dictionary mapping recording names to arrays of shape `(n_frames, 2)`.
Overrides the centroid information in `results`.
headings: dict, default=None
Dictionary mapping recording names to arrays of shape `(n_frames,)`.
Overrides the heading information in `results`.
bodyparts: list of str, default=None
List of bodypart names in `coordinates`. Required when `coordinates` is
provided and bodyparts were reindexed for modeling.
use_bodyparts: list of str, default=None
Ordered list of bodyparts used for modeling. Required when
`coordinates` is provided and bodyparts were reindexed
for modeling.
quality: int, default=7
Quality of the grid movies. Higher values result in higher
quality movies but larger file sizes.
rows, cols, pre, post, dot_radius, dot_color, window_size
See :py:func:`keypoint_moseq.viz.grid_movie`
video_extension: str, default=None
Preferred video extension (passed to
:py:func:`keypoint_moseq.util.find_matching_videos`)
window_size: int, default=None
Size of the window around the animal. If None, the window
size is determined automatically based on the size of the
animal. If provided explicitly, `window_size` should be a
multiple of 16 or imageio will complain.
max_video_size: int, default=4000
Maximum size of the grid movie in pixels. If the grid movie
is larger than this, it will be downsampled.
skeleton: list of tuples, default=[]
List of tuples specifying the skeleton. Used when
`overlay_keypoints=True`.
overlay_keypoints: bool, default=False
Whether to overlay the keypoints on the grid movie.
keypoints_only: bool, default=False
Whether to only show the keypoints (i.e. no video frames).
Overrides `overlay_keypoints`. When this option is used,
the framerate should be explicitly specified using `fps`.
keypoints_scale: float, default=1
Factor to scale keypoint coordinates before plotting. Only used when
`keypoints_only=True`. This is useful when the keypoints are 3D and
encoded in units that are larger than a pixel.
fps: int, default=30
Framerate of the grid movie. If None, the framerate is determined
from the videos.
plot_options: dict, default={}
Dictionary of options to pass to
:py:func:`keypoint_moseq.viz.overlay_keypoints_on_image`.
use_dims: pair of ints, default=[0,1]
Dimensions to use for plotting keypoints. Only used when
`overlay_keypoints=True` and the keypoints are 3D.
keypoint_colormap: str, default='autumn'
Colormap used to color keypoints. Used when
`overlay_keypoints=True`.
downsample_rate: int, default=1
Downsampling rate for the video frames. Coordinates at index `i` will be
plotted on the video frame at index `i*downsample_rate`.
See :py:func:`keypoint_moseq.viz.grid_movie` for the remaining parameters.
Returns
-------
sampled_instances: dict
Dictionary mapping syllables to lists of instances shown in each in
grid movie (in row-major order), where each instance is specified as a
tuple with the video name, start frame and end frame.
"""
# check inputs
if keypoints_only:
overlay_keypoints = True
else:
assert (video_dir is not None) or (video_paths is not None), fill(
"Either `video_dir` or `video_paths` is required unless `keypoints_only=True`"
)
if window_size is None or overlay_keypoints:
assert coordinates is not None, fill(
"`coordinates` must be provided if `window_size` is None "
"or `overlay_keypoints` is True"
)
# prepare output directory
output_dir = _get_path(
project_dir, model_name, output_dir, "grid_movies", "output_dir"
)
if not os.path.exists(output_dir):
os.makedirs(output_dir)
print(f"Writing grid movies to {output_dir}")
# reindex coordinates if necessary
if not (bodyparts is None or use_bodyparts is None or coordinates is None):
coordinates = reindex_by_bodyparts(coordinates, bodyparts, use_bodyparts)
# get edges for plotting skeleton
edges = []
if len(skeleton) > 0 and overlay_keypoints:
edges = get_edges(use_bodyparts, skeleton)
# load results
if results is None:
results = load_results(project_dir, model_name)
# extract syllables from results
syllables = {k: v["syllable"] for k, v in results.items()}
# extract and smooth centroids and headings
if centroids is None:
centroids = {k: v["centroid"] for k, v in results.items()}
if headings is None:
headings = {k: v["heading"] for k, v in results.items()}
centroids, headings = filter_centroids_headings(
centroids, headings, filter_size=filter_size
)
# scale keypoints if necessary
if keypoints_only:
for k, v in coordinates.items():
coordinates[k] = v * keypoints_scale
for k, v in centroids.items():
centroids[k] = v * keypoints_scale
# load video readers if necessary
if not keypoints_only:
if video_paths is None:
video_paths = find_matching_videos(
results.keys(),
video_dir,
as_dict=True,
video_extension=video_extension,
)
check_video_paths(video_paths, results.keys())
videos = {k: OpenCVReader(path) for k, path in video_paths.items()}
if fps is None:
fps = list(videos.values())[0].fps
else:
videos = None
# sample instances for each syllable
syllable_instances = get_syllable_instances(
syllables,
pre=pre,
post=post,
min_duration=min_duration,
min_frequency=min_frequency,
min_instances=rows * cols,
)
if len(syllable_instances) == 0:
warnings.warn(
fill(
"No syllables with sufficient instances to make a grid movie. "
"This usually occurs when all frames have the same syllable label "
"(use `plot_syllable_frequencies` to check if this is the case)"
)
)
return
sampled_instances = sample_instances(
syllable_instances,
rows * cols,
coordinates=coordinates,
centroids=centroids,
headings=headings,
**sampling_options,
)
# if the data is 3D, pick 2 dimensions to use for plotting
keypoint_dimension = next(iter(centroids.values())).shape[-1]
if keypoint_dimension == 3:
ds = np.array(use_dims)
centroids = {k: v[:, ds] for k, v in centroids.items()}
if coordinates is not None:
coordinates = {k: v[:, :, ds] for k, v in coordinates.items()}
# determine window size for grid movies
if window_size is None:
window_size = get_grid_movie_window_size(
sampled_instances, centroids, headings, coordinates, pre, post
)
print(f"Using window size of {window_size} pixels")
if keypoints_only:
if window_size < 64:
warnings.warn(
fill(
"The scale of the keypoints is very small. This may result in "
"poor quality grid movies. Try increasing `keypoints_scale`."
)
)
# possibly reduce window size to keep grid movies under max_video_size
scaled_window_size = max_video_size / max(rows, cols)
scaled_window_size = int(np.floor(scaled_window_size / 16) * 16)
scaled_window_size = min(scaled_window_size, window_size)
scale_factor = scaled_window_size / window_size
if scale_factor < 1:
warnings.warn(
"\n"
+ fill(
f"Videos will be downscaled by a factor of {scale_factor:.2f} "
f"so that the grid movies are under {max_video_size} pixels. "
"Use `max_video_size` to increase or decrease this size limit."
)
+ "\n\n"
)
# add colormap to plot options
plot_options.update({"keypoint_colormap": keypoint_colormap})
# generate grid movies
for syllable, instances in tqdm.tqdm(
sampled_instances.items(), desc="Generating grid movies", ncols=72
):
frames = grid_movie(
instances,
rows,
cols,
videos,
centroids,
headings,
edges=edges,
window_size=window_size,
scaled_window_size=scaled_window_size,
dot_color=dot_color,
pre=pre,
post=post,
dot_radius=dot_radius,
overlay_keypoints=overlay_keypoints,
coordinates=coordinates,
plot_options=plot_options,
downsample_rate=downsample_rate,
)
path = os.path.join(output_dir, f"syllable{syllable}.mp4")
write_video_clip(frames, path, fps=fps, quality=quality)
return sampled_instances
[docs]
def get_limits(
    coordinates,
    pctl=1,
    blocksize=None,
    left=0.2,
    right=0.2,
    top=0.2,
    bottom=0.2,
):
    """Compute padded integer axis limits that enclose a set of 2D keypoints.

    The unpadded range along each axis spans the `pctl` and `100-pctl`
    percentiles (NaNs ignored). Each side is then extended by a fraction of
    that range.

    Parameters
    ----------
    coordinates: ndarray or dict
        Coordinates as an ndarray of shape (..., 2), or a dict whose values
        are ndarrays of shape (..., 2). All points are pooled together.

    pctl: float, default=1
        Percentile used for the lower limit; `100-pctl` is used for the upper
        limit.

    blocksize: int, default=None
        If given, the limits are rounded and then padded symmetrically so
        that the width and height are multiples of `blocksize`. This is
        useful when the limits define cropped images for a video.

    left, right, top, bottom: float, default=0.2
        Fraction of the axis range added as padding on each side
        (`left`/`right` extend x, `bottom`/`top` extend y).

    Returns
    -------
    lims: ndarray of int, shape (2,2)
        Axis limits in the format `[[xmin,ymin],[xmax,ymax]]`.
    """
    if isinstance(coordinates, dict):
        points = np.concatenate(list(coordinates.values()))
    else:
        points = coordinates
    points = points.reshape(-1, 2)

    lower = np.nanpercentile(points, pctl, axis=0)
    upper = np.nanpercentile(points, 100 - pctl, axis=0)

    # pad every side by its own fraction of the unpadded extent
    extent = upper - lower
    lower = lower - extent * np.array([left, bottom])
    upper = upper + extent * np.array([right, top])
    lims = np.array([lower, upper])

    if blocksize is None:
        return lims.astype(int)

    lims = np.round(lims)
    # (min - max) mod blocksize is the amount missing to reach the next
    # multiple of blocksize; it is split evenly between the two sides
    extra = np.mod(lims[0] - lims[1], blocksize) / 2
    lims = np.ceil(np.stack([lims[0] - extra, lims[1] + extra]))
    return lims.astype(int)
def rasterize_figure(fig):
    """Draw a matplotlib figure and return its contents as an RGB array.

    Parameters
    ----------
    fig: :py:class:`matplotlib.figure.Figure`
        Figure to rasterize. Its canvas is (re)drawn before the pixels are
        read.

    Returns
    -------
    raster: ndarray of uint8, shape (height, width, 3)
        RGB image of the rendered figure.
    """
    canvas = fig.canvas
    canvas.draw()

    # `tostring_rgb` was deprecated in Matplotlib 3.8 and later removed, so
    # prefer `buffer_rgba` and drop the alpha channel. Reading the shape from
    # the buffer itself also avoids a mismatch with `get_width_height`.
    if hasattr(canvas, "buffer_rgba"):
        rgba = np.asarray(canvas.buffer_rgba())
        return np.ascontiguousarray(rgba[..., :3])

    # fallback for canvases that only expose the legacy API
    width, height = canvas.get_width_height()
    raster_flat = np.frombuffer(canvas.tostring_rgb(), dtype="uint8")
    return raster_flat.reshape((height, width, 3))
[docs]
def plot_trajectories(
    titles,
    Xs,
    lims,
    edges=[],
    n_cols=4,
    invert=False,
    keypoint_colormap="autumn",
    keypoint_colors=None,
    node_size=50.0,
    line_width=3.0,
    alpha=0.2,
    num_timesteps=10,
    plot_width=4,
    overlap=(0.2, 0),
    return_rasters=False,
):
    """Plot one or more pose trajectories on a common axis.

    (See :py:func:`keypoint_moseq.viz.generate_trajectory_plots`)

    Parameters
    ----------
    titles: list of str
        List of titles for each trajectory plot.

    Xs: list of ndarray
        List of pose trajectories as ndarrays of shape
        (n_frames, n_keypoints, 2).

    lims: ndarray
        Axis limits used for all the trajectory plots. The limits
        should be provided as an array of shape (2,2) with the format
        `[[xmin,ymin],[xmax,ymax]]`.

    edges: list of tuples, default=[]
        List of edges, where each edge is a tuple of two keypoint indexes.

    n_cols: int, default=4
        Number of columns in the figure (used when plotting multiple
        trajectories).

    invert: bool, default=False
        Determines the background color of the figure. If `True`,
        the background will be black.

    keypoint_colormap : str, default='autumn'
        Name of a matplotlib colormap. Used when `keypoint_colors` is None.

    keypoint_colors : array-like, shape=(num_keypoints,3), default=None
        Color for each keypoint. If None, the keypoint colormap is used.
        If the dtype is int, the values are assumed to be in the range 0-255,
        otherwise they are assumed to be in the range 0-1.

    node_size: int, default=50
        Size of each keypoint.

    line_width: int, default=3
        Width of the lines connecting keypoints.

    alpha: float, default=0.2
        Opacity of fade-out layers.

    num_timesteps: int, default=10
        Number of timesteps to plot for each trajectory. The pose
        at each timestep is determined by linearly interpolating
        between the keypoints.

    plot_width: int, default=4
        Width of each trajectory plot in inches. The height is
        determined by the aspect ratio of `lims`.

    overlap: tuple of float, default=(0.2,0)
        Amount of overlap between each trajectory plot as a tuple
        with the format `(x_overlap, y_overlap)`. The values should
        be between 0 and 1.

    return_rasters: bool, default=False
        Rasterize the matplotlib canvas after plotting each step of
        the trajectory. This is used to generate an animated video/gif
        of the trajectory.

    Returns
    -------
    fig : :py:class:`matplotlib.figure.Figure`
        Figure handle

    ax: matplotlib.axes.Axes
        Axis containing the trajectory plots.

    rasters: list of ndarray
        One rasterized image of the figure per plotted timestep if
        `return_rasters=True`, otherwise an empty list.
    """
    fill_color = "k" if invert else "w"

    if keypoint_colors is None:
        cmap = plt.colormaps[keypoint_colormap]
        colors = cmap(np.linspace(0, 1, Xs[0].shape[1]))
    else:
        # Inspect the array dtype rather than `isinstance(..., int)`: NumPy
        # integer scalars are not instances of the builtin `int`, so integer
        # arrays would otherwise never be rescaled to the 0-1 range.
        color_array = np.asarray(keypoint_colors)
        if np.issubdtype(color_array.dtype, np.integer):
            colors = list(color_array / 255)
        else:
            colors = list(keypoint_colors)

    # lay the trajectories out on a grid (row-major), shifting each one by a
    # multiple of the axis extent minus the requested overlap
    n_cols = min(n_cols, len(Xs))
    n_rows = np.ceil(len(Xs) / n_cols)
    offsets = np.stack(
        np.meshgrid(
            np.arange(n_cols) * np.diff(lims[:, 0]) * (1 - overlap[0]),
            np.arange(n_rows) * np.diff(lims[:, 1]) * (overlap[1] - 1),
        ),
        axis=-1,
    ).reshape(-1, 2)[: len(Xs)]

    # resample every trajectory to `num_timesteps` poses
    Xs = interpolate_along_axis(
        np.linspace(0, Xs[0].shape[0], num_timesteps),
        np.arange(Xs[0].shape[0]),
        np.array(Xs),
        axis=1,
    )

    Xs = Xs + offsets[:, None, None]
    xmin, ymin = lims[0] + offsets.min(0)
    xmax, ymax = lims[1] + offsets.max(0)

    # background rectangle
    fig, ax = plt.subplots(frameon=False)
    ax.fill_between(
        [xmin, xmax],
        y1=[ymax, ymax],
        y2=[ymin, ymin],
        facecolor=fill_color,
        zorder=0,
        clip_on=False,
    )

    # titles are centered horizontally and placed near the upper limit
    title_xy = (lims * np.array([[0.5, 0.1], [0.5, 0.9]])).sum(0)
    title_color = "w" if invert else "k"

    for xy, text in zip(offsets + title_xy, titles):
        ax.text(
            *xy,
            text,
            c=title_color,
            ha="center",
            va="top",
            zorder=Xs.shape[1] * 4 + 4,
        )

    # final extents in axis
    final_width = xmax - xmin
    final_height = title_xy[1] - ymin

    fig_width = plot_width * (n_cols - (n_cols - 1) * overlap[0])
    fig_height = final_height / final_width * fig_width
    fig.set_size_inches((fig_width, fig_height))

    ax.set_xlim(xmin, xmax)
    ax.set_ylim(ymin, ymax)
    ax.set_aspect("equal")
    ax.axis("off")
    plt.tight_layout()

    rasters = []  # for making a gif

    # Each timestep uses four z-levels: edge outlines, colored edges,
    # keypoints, and a translucent layer that fades the earlier timesteps.
    for i in range(Xs.shape[1]):
        for X in Xs:
            for ii, jj in edges:
                ax.plot(
                    *X[i, (ii, jj)].T,
                    c="k",
                    zorder=i * 4,
                    linewidth=line_width,
                    clip_on=False,
                )

            for ii, jj in edges:
                ax.plot(
                    *X[i, (ii, jj)].T,
                    c=colors[ii],
                    zorder=i * 4 + 1,
                    linewidth=line_width * 0.9,
                    clip_on=False,
                )

            ax.scatter(
                *X[i].T,
                c=colors,
                zorder=i * 4 + 2,
                edgecolor="k",
                linewidth=0.4,
                s=node_size,
                clip_on=False,
            )

        # the last timestep is left unfaded
        if i < Xs.shape[1] - 1:
            ax.fill_between(
                [xmin, xmax],
                y1=[ymax, ymax],
                y2=[ymin, ymin],
                facecolor=fill_color,
                alpha=alpha,
                zorder=i * 4 + 3,
                clip_on=False,
            )

        if return_rasters:
            rasters.append(rasterize_figure(fig))

    return fig, ax, rasters
def save_gif(image_list, gif_filename, duration=0.5):
    """Write a sequence of images to disk as a looping animated GIF.

    Parameters
    ----------
    image_list: list of ndarray
        Frames of the animation. Each frame is cast to uint8 before being
        converted to a PIL image.

    gif_filename: str
        Path of the GIF file to write.

    duration: float, default=0.5
        Display time of each frame in seconds (converted to whole
        milliseconds for PIL).
    """
    frames = []
    for img in image_list:
        frames.append(Image.fromarray(np.uint8(img)))

    # the first frame drives the save; the rest are appended to it
    first_frame, remaining_frames = frames[0], frames[1:]
    duration_ms = int(duration * 1000)
    first_frame.save(
        gif_filename,
        save_all=True,
        append_images=remaining_frames,
        duration=duration_ms,
        loop=0,
    )
[docs]
def generate_trajectory_plots(
    coordinates,
    results,
    project_dir=None,
    model_name=None,
    output_dir=None,
    pre=5,
    post=15,
    min_frequency=0.005,
    min_duration=3,
    skeleton=[],
    bodyparts=None,
    use_bodyparts=None,
    keypoint_colormap="autumn",
    plot_options={},
    get_limits_pctl=0,
    padding={"left": 0.1, "right": 0.1, "top": 0.2, "bottom": 0.2},
    lims=None,
    save_individually=True,
    save_gifs=True,
    save_mp4s=False,
    fps=30,
    projection_planes=["xy", "xz"],
    interactive=True,
    density_sample=True,
    sampling_options={"n_neighbors": 50},
    **kwargs,
):
    """
    Generate trajectory plots for a modeled dataset.

    Each trajectory plot shows a sequence of poses along the average
    trajectory through latent space associated with a given syllable.
    A separate figure (and gif, optionally) is saved for each syllable,
    along with a single figure showing all syllables in a grid. The
    plots are saved to `{output_dir}` if it is provided, otherwise
    they are saved to `{project_dir}/{model_name}/trajectory_plots`.

    Plot-related parameters are described below. For the remaining
    parameters see (:py:func:`keypoint_moseq.util.get_typical_trajectories`)

    Parameters
    ----------
    coordinates: dict
        Dictionary mapping recording names to keypoint coordinates as
        ndarrays of shape (n_frames, n_bodyparts, [2 or 3]).

    results: dict
        Dictionary containing modeling results for a dataset (see
        :py:func:`keypoint_moseq.fitting.extract_results`).

    project_dir: str, default=None
        Project directory. Required to save trajectory plots if
        `output_dir` is None.

    model_name: str, default=None
        Name of the model. Required to save trajectory plots if
        `output_dir` is None.

    output_dir: str, default=None
        Directory where trajectory plots should be saved. If None,
        plots will be saved to `{project_dir}/{model_name}/trajectory_plots`.

    skeleton : list, default=[]
        List of edges that define the skeleton, where each edge is a
        pair of bodypart names or a pair of indexes.

    keypoint_colormap : str
        Name of a matplotlib colormap to use for coloring the keypoints.

    plot_options: dict, default={}
        Dictionary of options for trajectory plots (see
        :py:func:`keypoint_moseq.viz.plot_trajectories`). The dictionary is
        not modified.

    get_limits_pctl: float, default=0
        Percentile to use for determining the axis limits. Higher values lead
        to tighter axis limits.

    padding: dict, default={'left':0.1, 'right':0.1, 'top':0.2, 'bottom':0.2}
        Padding around trajectory plots. Controls the distance
        between trajectories (when multiple are shown in one figure)
        as well as the title offset.

    lims: ndarray, default=None
        Axis limits used for all the trajectory plots. For 2D data the
        format is `[[xmin,ymin],[xmax,ymax]]`; for 3D data an array of
        shape (2,3) is required. If None, the limits are determined
        automatically based on the coordinates of the keypoints using
        :py:func:`keypoint_moseq.viz.get_limits`.

    save_individually: bool, default=True
        If True, a separate figure is saved for each syllable (in
        addition to the grid figure).

    save_gifs: bool, default=True
        Whether to save an animated gif of the trajectory plots.

    save_mp4s: bool, default=False
        Whether to save videos of the trajectory plots as .mp4 files

    fps: int, default=30
        Framerate of the videos from which keypoints were derived.
        Used to set the framerate of gifs/mp4s when `save_gifs=True` or
        `save_mp4s=True`.

    projection_planes: list (subset of ['xy', 'yz', 'xz']), default=['xy','xz']
        For 3D data, defines the 2D plane(s) on which to project keypoint
        coordinates. A separate plot will be saved for each plane with
        the name of the plane (e.g. 'xy') as a suffix. This argument is
        ignored for 2D data.

    interactive: bool, default=True
        For 3D data, whether to create a visualization that can be
        rotated and zoomed. This argument is ignored for 2D data.

    **kwargs
        Accepted for call compatibility; not used by this function.
    """
    # Build a new dict instead of calling `.update()` so that neither the
    # caller's dictionary nor the shared `{}` default is mutated.
    plot_options = {**plot_options, "keypoint_colormap": keypoint_colormap}
    edges = [] if len(skeleton) == 0 else get_edges(use_bodyparts, skeleton)

    output_dir = _get_path(
        project_dir, model_name, output_dir, "trajectory_plots", "output_dir"
    )
    os.makedirs(output_dir, exist_ok=True)
    print(f"Saving trajectory plots to {output_dir}")

    typical_trajectories = get_typical_trajectories(
        coordinates,
        results,
        pre,
        post,
        min_frequency,
        min_duration,
        bodyparts,
        use_bodyparts,
        density_sample,
        sampling_options,
    )

    # `np.stack` below cannot handle an empty list, so bail out with an
    # explanatory warning when no syllable passed the filters
    if len(typical_trajectories) == 0:
        warnings.warn(
            fill(
                "No syllables with sufficient instances to generate "
                "trajectory plots. Try lowering `min_frequency` or "
                "`min_duration`."
            )
        )
        return

    syllable_ixs = sorted(typical_trajectories.keys())
    titles = [f"Syllable{s}" for s in syllable_ixs]
    Xs = np.stack([typical_trajectories[s] for s in syllable_ixs])

    if Xs.shape[-1] == 3:
        # normalize plane names so that e.g. "ZX" and "xz" are equivalent
        projection_planes = [
            "".join(sorted(plane.lower())) for plane in projection_planes
        ]
        assert set(projection_planes) <= set(["xy", "yz", "xz"]), fill(
            "`projection_planes` must be a subset of `['xy','yz','xz']`"
        )
        if lims is not None:
            assert lims.shape == (2, 3), fill(
                "`lims` must be None or an ndarray of shape (2,3) when plotting 3D data"
            )
        all_Xs, all_lims, suffixes = [], [], []
        for plane in projection_planes:
            use_dims = {"xy": [0, 1], "yz": [1, 2], "xz": [0, 2]}[plane]
            all_Xs.append(Xs[..., use_dims])
            suffixes.append("." + plane)
            if lims is None:
                all_lims.append(get_limits(all_Xs[-1], pctl=get_limits_pctl, **padding))
            else:
                all_lims.append(lims[..., use_dims])
    else:
        all_Xs = [Xs * np.array([1, -1])]  # flip y-axis
        if lims is None:
            lims = get_limits(all_Xs[-1], pctl=get_limits_pctl, **padding)
        all_lims = [lims]
        suffixes = [""]

    make_rasters = save_gifs or save_mp4s

    for Xs_2D, plane_lims, suffix in zip(all_Xs, all_lims, suffixes):
        # individual plots
        if save_individually:
            desc = "Generating trajectory plots"
            for title, X in tqdm.tqdm(
                zip(titles, Xs_2D), desc=desc, total=len(titles), ncols=72
            ):
                fig, ax, rasters = plot_trajectories(
                    [title],
                    X[None],
                    plane_lims,
                    edges=edges,
                    return_rasters=make_rasters,
                    **plot_options,
                )
                plt.savefig(os.path.join(output_dir, f"{title}{suffix}.pdf"))
                plt.close(fig=fig)
                _save_trajectory_animations(
                    rasters,
                    output_dir,
                    f"{title}{suffix}",
                    pre,
                    post,
                    fps,
                    save_gifs,
                    save_mp4s,
                )

        # grid plot
        fig, ax, rasters = plot_trajectories(
            titles,
            Xs_2D,
            plane_lims,
            edges=edges,
            return_rasters=make_rasters,
            **plot_options,
        )
        plt.savefig(os.path.join(output_dir, f"all_trajectories{suffix}.pdf"))
        plt.show()
        _save_trajectory_animations(
            rasters,
            output_dir,
            f"all_trajectories{suffix}",
            pre,
            post,
            fps,
            save_gifs,
            save_mp4s,
        )

    if interactive and Xs.shape[-1] == 3:
        plot_trajectories_3D(Xs, titles, edges, output_dir, **plot_options)


def _save_trajectory_animations(
    rasters, output_dir, stem, pre, post, fps, save_gifs, save_mp4s
):
    """Save rasterized trajectory frames as `{output_dir}/{stem}.gif` and/or `.mp4`.

    The frame timing is chosen so that the animation lasts as long as the
    `pre + post` source frames do at `fps`.
    """
    if save_gifs:
        frame_duration = (pre + post) / len(rasters) / fps
        path = os.path.join(output_dir, f"{stem}.gif")
        save_gif(rasters, path, duration=frame_duration)
    if save_mp4s:
        use_fps = len(rasters) / (pre + post) * fps
        path = os.path.join(output_dir, f"{stem}.mp4")
        write_video_clip(rasters, path, fps=use_fps)
[docs]
def overlay_keypoints_on_image(
    image,
    coordinates,
    edges=[],
    keypoint_colormap="autumn",
    keypoint_colors=None,
    node_size=5,
    line_width=2,
    copy=False,
    opacity=1.0,
):
    """Overlay keypoints on an image.

    Parameters
    ----------
    image: ndarray of shape (height, width, 3)
        Image to overlay keypoints on.

    coordinates: ndarray of shape (num_keypoints, 2)
        Array of keypoint coordinates. Keypoints (and edges touching
        keypoints) with a NaN coordinate are not drawn.

    edges: list of tuples, default=[]
        List of edges that define the skeleton, where each edge is a
        pair of indexes.

    keypoint_colormap: str, default='autumn'
        Name of a matplotlib colormap to use for coloring the keypoints.

    keypoint_colors : array-like, shape=(num_keypoints,3), default=None
        Color for each keypoint. If None, the keypoint colormap is used.
        If the dtype is int, the values are assumed to be in the range 0-255,
        otherwise they are assumed to be in the range 0-1.

    node_size: int, default=5
        Size of the keypoints.

    line_width: int, default=2
        Width of the skeleton lines.

    copy: bool, default=False
        Whether to copy the image before overlaying keypoints. If False
        (and `opacity` is 1), the graphics are drawn onto `image` itself.

    opacity: float, default=1.0
        Opacity of the overlay graphics (0.0-1.0).

    Returns
    -------
    image: ndarray of shape (height, width, 3)
        Image with keypoints overlayed.
    """
    # drawing happens on a copy whenever the input must stay untouched,
    # either on request or because it is needed for blending below
    canvas = image.copy() if (copy or opacity < 1.0) else image

    if keypoint_colors is None:
        cmap = plt.colormaps[keypoint_colormap]
        colors = np.array(cmap(np.linspace(0, 1, coordinates.shape[0])))[:, :3]
    else:
        colors = np.array(keypoint_colors)

    # Check the array dtype so that every float width (not only float64) is
    # rescaled from 0-1 to 0-255, then hand OpenCV plain tuples of ints.
    if np.issubdtype(colors.dtype, np.floating):
        colors = colors * 255
    colors = [tuple(int(c) for c in rgb) for rgb in colors]

    # overlay skeleton
    for i, j in edges:
        # skip the edge if any coordinate of either endpoint is missing
        if np.isnan(coordinates[[i, j]]).any():
            continue
        pos1 = (int(coordinates[i, 0]), int(coordinates[i, 1]))
        pos2 = (int(coordinates[j, 0]), int(coordinates[j, 1]))
        canvas = cv2.line(canvas, pos1, pos2, colors[i], line_width, cv2.LINE_AA)

    # overlay keypoints
    for i, (x, y) in enumerate(coordinates):
        if np.isnan(x) or np.isnan(y):
            continue
        pos = (int(x), int(y))
        canvas = cv2.circle(canvas, pos, node_size, colors[i], -1, lineType=cv2.LINE_AA)

    if opacity < 1.0:
        return cv2.addWeighted(image, 1 - opacity, canvas, opacity, 0)

    # Return the canvas that was drawn on; with `copy=True` this is the
    # annotated copy rather than the unmodified input.
    return canvas
[docs]
def overlay_keypoints_on_video(
    video_path,
    coordinates,
    skeleton=[],
    bodyparts=None,
    use_bodyparts=None,
    output_path=None,
    show_frame_numbers=True,
    text_color=(255, 255, 255),
    crop_size=None,
    frames=None,
    quality=7,
    centroid_smoothing_filter=10,
    plot_options={},
    downsample_rate=1,
):
    """Overlay keypoints on a video.

    Parameters
    ----------
    video_path: str
        Path to a video file.

    coordinates: ndarray of shape (num_frames, num_keypoints, 2)
        Array of keypoint coordinates.

    skeleton: list of tuples, default=[]
        List of edges that define the skeleton, where each edge is a pair of bodypart
        names or a pair of indexes.

    bodyparts: list of str, default=None
        List of bodypart names in `coordinates`. Required if `skeleton` is defined using
        bodypart names.

    use_bodyparts: list of str, default=None
        Subset of bodyparts to plot. If None, all bodyparts are plotted.

    output_path: str, default=None
        Path to save the video. If None, the video is saved to `video_path` with the
        suffix `_keypoints`.

    show_frame_numbers: bool, default=True
        Whether to overlay the frame number in the video. The number shown is the
        index into `coordinates`.

    text_color: tuple of int, default=(255,255,255)
        Color for the frame number overlay.

    crop_size: int, default=None
        Size of the crop around the keypoints to overlay on the video. If None, the
        entire video is used.

    frames: iterable of int, default=None
        Indexes into `coordinates` to render. If None, one index is used for every
        `downsample_rate`-th frame of the video.

    quality: int, default=7
        Quality of the output video.

    centroid_smoothing_filter: int, default=10
        Amount of smoothing to determine cropping centroid.

    plot_options: dict, default={}
        Additional keyword arguments to pass to
        :py:func:`keypoint_moseq.viz.overlay_keypoints_on_image`.

    downsample_rate: int, default=1
        Downsampling rate for the video frames. Coordinates at index `i` will be
        overlayed on the frame at index `i*downsample_rate`. The output is written
        at the frame rate of the source video.
    """
    if output_path is None:
        output_path = os.path.splitext(video_path)[0] + "_keypoints.mp4"
    print(f"Saving video to {output_path}")

    if bodyparts is not None:
        if use_bodyparts is not None:
            coordinates = reindex_by_bodyparts(coordinates, bodyparts, use_bodyparts)
        else:
            use_bodyparts = bodyparts
        edges = get_edges(use_bodyparts, skeleton)
    else:
        # without bodypart names the skeleton must already be index pairs
        edges = skeleton

    if crop_size is not None:
        # crop around a smoothed per-frame median of the keypoints, filling
        # in missing detections first
        outliers = np.any(np.isnan(coordinates), axis=2)
        interpolated_coordinates = interpolate_keypoints(coordinates, outliers)
        crop_centroid = np.nanmedian(interpolated_coordinates, axis=1)
        crop_centroid = gaussian_filter1d(
            crop_centroid, centroid_smoothing_filter, axis=0
        )

    reader = OpenCVReader(video_path)
    fps = reader.fps
    if frames is None:
        # one coordinate index per video frame that is actually read
        frames = np.arange(len(range(0, len(reader), downsample_rate)))

    with imageio.get_writer(
        output_path, pixelformat="yuv420p", fps=fps, quality=quality
    ) as writer:
        for frame in tqdm.tqdm(frames, ncols=72):
            # `frame` indexes the (possibly downsampled) coordinates; scale
            # it back up to find the matching video frame
            image = overlay_keypoints_on_image(
                reader[frame * downsample_rate],
                coordinates[frame],
                edges=edges,
                **plot_options,
            )
            if crop_size is not None:
                image = crop_image(image, crop_centroid[frame], crop_size)

            if show_frame_numbers:
                image = cv2.putText(
                    image,
                    f"Frame {frame}",
                    (10, 20),
                    cv2.FONT_HERSHEY_SIMPLEX,
                    0.5,
                    text_color,
                    1,
                    cv2.LINE_AA,
                )

            writer.append_data(image)
[docs]
def add_3D_pose_to_plotly_fig(
    fig,
    coords,
    edges,
    keypoint_colors,
    node_size=50.0,
    line_width=3.0,
    visible=True,
    opacity=1,
):
    """
    Add a 3D pose to a plotly figure.

    One marker trace is added for the keypoints, followed by one line trace
    per edge.

    Parameters
    ----------
    fig: plotly figure
        Figure to which the pose should be added.

    coords: ndarray (N,3)
        3D coordinates of the pose.

    edges: list of index pairs
        Skeleton edges

    keypoint_colors : array-like with shape (num_keypoints,3)
        Color for each keypoint. If None, all keypoints are drawn in red.
        If the dtype is int, the values are assumed to be in the range 0-255
        and are rescaled to 0-1, otherwise they are passed on unchanged.

    node_size: float, default=50.0
        Size of keypoints. The plotly marker size is `node_size / 10`.

    line_width: float, default=3.0
        Width of skeleton edges.

    visible: bool, default=True
        Initial visibility state of the nodes and edges

    opacity: float, default=1
        Opacity of the nodes and edges (0-1)
    """
    if keypoint_colors is None:
        # fall back to a single default color for every keypoint
        keypoint_colors = ["red"] * len(coords)
    else:
        # Inspect the dtype instead of `isinstance(..., int)`, which is never
        # true for NumPy integer scalars and fails on plain nested lists.
        color_array = np.asarray(keypoint_colors)
        if np.issubdtype(color_array.dtype, np.integer):
            keypoint_colors = color_array / 255.0

    marker = {
        "size": node_size / 10,
        "color": keypoint_colors,
        "line": dict(color="black", width=0.5),
        "opacity": opacity,
    }

    line = {"width": line_width, "color": f"rgba(0,0,0,{opacity})"}

    fig.add_trace(
        plotly.graph_objs.Scatter3d(
            x=coords[:, 0],
            y=coords[:, 1],
            z=coords[:, 2],
            mode="markers",
            visible=visible,
            marker=marker,
        )
    )
    for e in edges:
        fig.add_trace(
            plotly.graph_objs.Scatter3d(
                x=coords[e, 0],
                y=coords[e, 1],
                z=coords[e, 2],
                mode="lines",
                visible=visible,
                line=line,
            )
        )
def plot_pcs_3D(
    ymean,
    ypcs,
    edges,
    keypoint_colors,
    savefig,
    project_dir=None,
    node_size=50,
    line_width=3,
    height=400,
    mean_pose_opacity=0.2,
):
    """
    Show the components of a fitted PCA model as interactive 3D poses.

    A slider selects one PC at a time. The selected PC's perturbed pose is
    shown together with the mean pose, which is drawn semi-transparently.

    Parameters
    ----------
    ymean : ndarray (num_bodyparts, 3)
        Mean pose.

    ypcs : ndarray (num_pcs, num_bodyparts, 3)
        Perturbations of the mean pose in the direction of each PC.

    edges : list of index pairs
        Skeleton edges.

    keypoint_colors : array-like, shape=(num_keypoints,3)
        Color for each keypoint (passed to
        :py:func:`keypoint_moseq.viz.add_3D_pose_to_plotly_fig`).

    savefig : bool
        Whether to save the figure to a file. If true, the figure is
        saved to `{project_dir}/pcs.html`

    project_dir : str, default=None
        Path to the project directory. Required if `savefig` is True.

    node_size : float, default=50
        Size of the keypoints in the figure.

    line_width: float, default=3
        Width of edges in skeleton

    height : int, default=400
        Height of the figure in pixels.

    mean_pose_opacity: float, default=0.2
        Opacity of the mean pose
    """
    from plotly.subplots import make_subplots

    fig = make_subplots(rows=1, cols=1, specs=[[{"type": "scatter3d"}]])

    # every pose adds one marker trace plus one line trace per edge
    traces_per_pose = len(edges) + 1
    total_traces = traces_per_pose * (len(ypcs) + 1)

    def visibility_mask(pc_index):
        mask = np.zeros(total_traces, dtype=bool)
        # the mean pose is added last and stays visible for every PC
        mask[-traces_per_pose:] = True
        first = traces_per_pose * pc_index
        mask[first : first + traces_per_pose] = True
        return mask

    pose_style = {"node_size": node_size, "line_width": line_width}

    steps = []
    for pc_index, pc_pose in enumerate(ypcs):
        # only the first PC is visible until the slider is moved
        add_3D_pose_to_plotly_fig(
            fig,
            pc_pose,
            edges,
            keypoint_colors,
            visible=(pc_index == 0),
            **pose_style,
        )
        step = {
            "method": "update",
            "label": f"PC {pc_index+1}",
            "args": [{"visible": visibility_mask(pc_index)}],
        }
        steps.append(step)

    add_3D_pose_to_plotly_fig(
        fig,
        ymean,
        edges,
        keypoint_colors,
        opacity=mean_pose_opacity,
        **pose_style,
    )

    plain_axis = {"showgrid": False, "showbackground": False}
    scene = {
        "xaxis": dict(plain_axis),
        "yaxis": dict(plain_axis),
        "zaxis": {"showgrid": False, "showline": True, "linecolor": "black"},
        "bgcolor": "white",
        "aspectmode": "data",
    }
    fig.update_layout(
        height=height,
        showlegend=False,
        sliders=[{"steps": steps}],
        scene=scene,
        margin={"l": 20, "r": 20, "b": 0, "t": 0, "pad": 10},
    )

    if savefig:
        assert project_dir is not None, fill(
            "The `savefig` option requires a `project_dir`"
        )
        save_path = os.path.join(project_dir, "pcs.html")
        fig.write_html(save_path)
        print(f"Saved interactive plot to {save_path}")

    fig.show()
def plot_trajectories_3D(
    Xs,
    titles,
    edges,
    output_dir,
    keypoint_colormap="autumn",
    keypoint_colors=None,
    node_size=50.0,
    line_width=3.0,
    height=500,
    skiprate=1,
):
    """
    Visualize a set of 3D trajectories.

    A slider selects one trajectory at a time. The poses of the selected
    trajectory are drawn together, with opacity increasing over time.

    Parameters
    ----------
    Xs : ndarray (num_syllables, num_frames, num_bodyparts, 3)
        Trajectories to visualize.

    titles : list of str
        Title for each trajectory.

    edges : list of index pairs
        Skeleton edges.

    output_dir : str
        Directory where the interactive plot is saved as
        `all_trajectories.html`. If None, the plot is not saved.

    keypoint_colormap : str, default='autumn'
        Name of a matplotlib colormap to use for coloring the keypoints.

    keypoint_colors : array-like, shape=(num_keypoints,3), default=None
        Color for each keypoint. If None, the keypoint colormap is used.
        If the dtype is int, the values are assumed to be in the range 0-255,
        otherwise they are assumed to be in the range 0-1.

    node_size : float, default=50.0
        Size of the keypoints in the figure.

    line_width: float, default=3.0
        Width of edges in skeleton

    height : int, default=500
        Height of the figure in pixels.

    skiprate : int, default=1
        Plot every `skiprate` frames.
    """
    from plotly.subplots import make_subplots

    # allow a list of equally-shaped trajectories as well as a stacked array
    Xs = np.asarray(Xs)

    if keypoint_colors is None:
        cmap = plt.colormaps[keypoint_colormap]
        keypoint_colors = np.array(cmap(np.linspace(0, 1, Xs.shape[2])))[:, :3]

    fig = make_subplots(rows=1, cols=1, specs=[[{"type": "scatter3d"}]])
    Xs = Xs[:, ::skiprate]

    def visibility_mask(i):
        # Traces per trajectory: (markers + one line per edge) for each pose.
        # The number of poses is read from the array shape; indexing `Xs[1]`
        # would fail when only a single trajectory is plotted.
        n = (len(edges) + 1) * Xs.shape[1]
        visible = np.zeros(n * len(Xs))
        visible[n * i : n * (i + 1)] = 1
        return visible > 0

    steps = []
    for i, X in enumerate(Xs):
        # later poses are more opaque than earlier ones
        opacities = np.linspace(0.3, 1, len(X) + 1)[1:] ** 2
        for coords, opacity in zip(X, opacities):
            add_3D_pose_to_plotly_fig(
                fig,
                coords,
                edges,
                keypoint_colors,
                visible=(i == 0),
                node_size=node_size,
                line_width=line_width,
                opacity=opacity,
            )
        steps.append(
            dict(
                method="update",
                label=titles[i],
                args=[{"visible": visibility_mask(i)}],
            )
        )

    fig.update_layout(
        height=height,
        showlegend=False,
        sliders=[dict(steps=steps)],
        scene=dict(
            xaxis=dict(showgrid=False, showbackground=False),
            yaxis=dict(showgrid=False, showbackground=False),
            zaxis=dict(showgrid=False, showline=True, linecolor="black"),
            bgcolor="white",
            aspectmode="data",
        ),
        margin=dict(l=20, r=20, b=0, t=0, pad=10),
    )

    if output_dir is not None:
        save_path = os.path.join(output_dir, "all_trajectories.html")
        fig.write_html(save_path)
        print(f"Saved interactive trajectories plot to {save_path}")

    fig.show()
[docs]
def plot_similarity_dendrogram(
    coordinates,
    results,
    project_dir=None,
    model_name=None,
    save_path=None,
    metric="cosine",
    pre=5,
    post=15,
    min_frequency=0.005,
    min_duration=3,
    bodyparts=None,
    use_bodyparts=None,
    density_sample=False,
    sampling_options={"n_neighbors": 50},
    figsize=(6, 3),
    **kwargs,
):
    """Plot a dendrogram of the similarity between syllable trajectories.

    Syllables are clustered with complete linkage on the pairwise distances
    returned by `syllable_similarity`. The figure is written as both a pdf
    and a png to `{save_path}` if it is provided, or else to
    `{project_dir}/{model_name}/similarity_dendrogram`. Plot-related
    parameters are described below. For the remaining parameters
    see (:py:func:`keypoint_moseq.util.get_typical_trajectories`)

    Parameters
    ----------
    coordinates: dict
        Dictionary mapping recording names to keypoint coordinates as
        ndarrays of shape (n_frames, n_bodyparts, [2 or 3]).

    results: dict
        Dictionary containing modeling results for a dataset (see
        :py:func:`keypoint_moseq.fitting.extract_results`).

    project_dir: str, default=None
        Project directory. Required to save figure if `save_path` is None.

    model_name: str, default=None
        Model name. Required to save figure if `save_path` is None.

    save_path: str, default=None
        Path to save the dendrogram plot (do not include an extension).
        If None, the plot will be saved to
        `{project_dir}/{model_name}/similarity_dendrogram.[pdf/png]`.

    metric: str, default='cosine'
        Distance metric to use. See :py:func:`scipy.spatial.pdist` for options.

    figsize: tuple of float, default=(6,3)
        Size of the dendrogram plot.
    """
    save_path = _get_path(project_dir, model_name, save_path, "similarity_dendrogram")

    similarity_args = (
        coordinates,
        results,
        metric,
        pre,
        post,
        min_frequency,
        min_duration,
        bodyparts,
        use_bodyparts,
        density_sample,
        sampling_options,
    )
    distances, syllable_ixs = syllable_similarity(*similarity_args)

    # hierarchical clustering on the squareform-converted distances
    condensed = squareform(distances)
    Z = linkage(condensed, "complete")

    fig, ax = plt.subplots(1, 1)
    leaf_labels = []
    for s in syllable_ixs:
        leaf_labels.append(f"Syllable {s}")
    dendrogram(Z, labels=leaf_labels, leaf_font_size=10, ax=ax, leaf_rotation=90)

    # de-emphasize the frame and hide the distance axis ticks
    ax.set_yticks([])
    for spine in ax.spines.values():
        spine.set_color("lightgray")
    ax.set_title("Syllable similarity")
    fig.set_size_inches(figsize)

    print(f"Saving dendrogram plot to {save_path}")
    for ext in ("pdf", "png"):
        plt.savefig(f"{save_path}.{ext}")
[docs]
def matplotlib_colormap_to_plotly(cmap):
    """Convert a matplotlib colormap to a plotly colorscale.

    Parameters
    ----------
    cmap: str or callable
        Name of a matplotlib colormap, or a colormap-like callable that maps
        a float in [0, 1] to a sequence whose first three entries are RGB
        values in [0, 1] (e.g. a matplotlib ``Colormap`` instance).

    Returns
    -------
    pl_colorscale: list
        Plotly colorscale given as 255 ``[position, "rgb(r, g, b)"]`` pairs,
        where the positions are evenly spaced multiples of ``1 / 254``.
    """
    if isinstance(cmap, str):
        cmap = plt.colormaps[cmap]

    pl_entries = 255
    h = 1.0 / (pl_entries - 1)
    pl_colorscale = []
    for k in range(pl_entries):
        rgb = (np.array(cmap(k * h)[:3]) * 255).astype(np.uint8)
        # Cast to built-in ints before formatting: with NumPy >= 2 the repr of
        # a NumPy scalar is e.g. `np.uint8(255)`, which would otherwise leak
        # into the color string and produce an invalid plotly color.
        r, g, b = (int(channel) for channel in rgb)
        pl_colorscale.append([k * h, f"rgb({r}, {g}, {b})"])
    return pl_colorscale
[docs]
def initialize_3D_plot(height=500):
    """Create an empty plotly figure holding a single 3D scene.

    The scene has a white background and uses the "data" aspect mode. Grid
    lines are disabled on all three axes; the x and y axes have no background
    and the z axis shows a black axis line. The legend is hidden.

    Parameters
    ----------
    height: int, default=500
        Height of the figure, passed to the plotly layout.

    Returns
    -------
    fig: plotly figure
        Figure with one `scatter3d` subplot and no traces.
    """
    scene = {
        "xaxis": {"showgrid": False, "showbackground": False},
        "yaxis": {"showgrid": False, "showbackground": False},
        "zaxis": {"showgrid": False, "showline": True, "linecolor": "black"},
        "bgcolor": "white",
        "aspectmode": "data",
    }
    margin = {"l": 20, "r": 20, "b": 0, "t": 0, "pad": 10}

    figure = make_subplots(rows=1, cols=1, specs=[[{"type": "scatter3d"}]])
    figure.update_layout(
        height=height, showlegend=False, scene=scene, margin=margin
    )
    return figure
[docs]
def add_3D_pose_to_fig(
    fig,
    coords,
    edges,
    keypoint_colormap="autumn",
    node_size=6.0,
    linewidth=3.0,
    visible=True,
    opacity=1,
):
    """Add a 3D pose (keypoints plus skeleton) to a plotly figure.

    One "markers" trace is added for the keypoints, followed by one "lines"
    trace per skeleton edge, so each call appends `len(edges) + 1` traces.

    Parameters
    ----------
    fig: plotly figure
        Figure to which the pose should be added.
    coords: ndarray (N,3)
        3D coordinates of the pose.
    edges: list of index pairs
        Skeleton edges, each a pair of row indices into `coords`.
    keypoint_colormap: str, default='autumn'
        Colormap to use for coloring keypoints. Keypoints are assigned
        evenly spaced colormap values in order of their index.
    node_size: float, default=6.0
        Size of keypoints.
    linewidth: float, default=3.0
        Width of skeleton edges.
    visible: bool, default=True
        Initial visibility state of the nodes and edges.
    opacity: float, default=1
        Opacity of the nodes and edges (0-1).
    """
    Scatter3d = plotly.graph_objs.Scatter3d

    node_style = dict(
        size=node_size,
        color=np.linspace(0, 1, len(coords)),
        colorscale=matplotlib_colormap_to_plotly(keypoint_colormap),
        line={"color": "black", "width": 0.5},
        opacity=opacity,
    )
    # Edge opacity is encoded in the alpha channel of the line color.
    edge_style = dict(width=linewidth, color=f"rgba(0,0,0,{opacity})")

    keypoint_trace = Scatter3d(
        x=coords[:, 0],
        y=coords[:, 1],
        z=coords[:, 2],
        mode="markers",
        visible=visible,
        marker=node_style,
    )
    fig.add_trace(keypoint_trace)

    for edge in edges:
        # `coords[edge, axis]` selects both endpoints of the edge at once.
        edge_trace = Scatter3d(
            x=coords[edge, 0],
            y=coords[edge, 1],
            z=coords[edge, 2],
            mode="lines",
            visible=visible,
            line=edge_style,
        )
        fig.add_trace(edge_trace)
[docs]
def plot_pcs_3D(
    ymean,
    ypcs,
    edges,
    keypoint_colormap,
    project_dir=None,
    node_size=6,
    linewidth=2,
    height=400,
    mean_pose_opacity=0.2,
):
    """Visualize the components of a fitted PCA model for 3D keypoints.

    The figure contains one pose per PC (a perturbation of the mean pose in
    the direction of that PC) plus the semi-transparent mean pose. A slider
    selects which PC pose is shown; the mean pose is visible at every slider
    position.

    Parameters
    ----------
    ymean : ndarray (num_bodyparts, 3)
        Mean pose.
    ypcs : ndarray (num_pcs, num_bodyparts, 3)
        Perturbations of the mean pose in the direction of each PC.
    edges : list of index pairs
        Skeleton edges.
    keypoint_colormap : str
        Name of a matplotlib colormap to use for coloring the keypoints.
    project_dir : str, default=None
        Path to the project directory. If given, the figure is written to
        `{project_dir}/pcs.html`.
    node_size : float, default=6
        Size of the keypoints in the figure.
    linewidth: float, default=2
        Width of edges in skeleton.
    height : int, default=400
        Height of the figure in pixels.
    mean_pose_opacity: float, default=0.2
        Opacity of the mean pose.
    """
    fig = initialize_3D_plot(height)

    # Every pose contributes one keypoint trace plus one trace per edge.
    traces_per_pose = len(edges) + 1
    total_traces = traces_per_pose * (len(ypcs) + 1)
    pose_style = dict(
        node_size=node_size,
        linewidth=linewidth,
        keypoint_colormap=keypoint_colormap,
    )

    def visibility_mask(pc_index):
        mask = np.zeros(total_traces, dtype=bool)
        # The mean pose is added last and is always shown.
        mask[-traces_per_pose:] = True
        start = traces_per_pose * pc_index
        mask[start : start + traces_per_pose] = True
        return mask

    for pc_index, pc_pose in enumerate(ypcs):
        # Only the first PC is visible until the slider is moved.
        add_3D_pose_to_fig(
            fig, pc_pose, edges, visible=(pc_index == 0), **pose_style
        )
    add_3D_pose_to_fig(
        fig, ymean, edges, opacity=mean_pose_opacity, **pose_style
    )

    slider_steps = [
        dict(
            method="update",
            label=f"PC {pc_index+1}",
            args=[{"visible": visibility_mask(pc_index)}],
        )
        for pc_index in range(len(ypcs))
    ]
    fig.update_layout(sliders=[dict(steps=slider_steps)])

    if project_dir is not None:
        save_path = os.path.join(project_dir, "pcs.html")
        fig.write_html(save_path)
        print(f"Saved interactive plot to {save_path}")
    fig.show()
[docs]
def plot_trajectories_3D(
    Xs,
    titles,
    edges,
    output_dir,
    keypoint_colormap="autumn",
    node_size=8,
    linewidth=3,
    height=500,
    skiprate=1,
):
    """Visualize a set of 3D trajectories.

    Each trajectory is drawn as a sequence of poses whose opacity increases
    over time. A slider selects which trajectory is shown.

    Parameters
    ----------
    Xs : ndarray of shape (num_syllables, num_frames, num_bodyparts, 3)
        Trajectories to visualize. A list of equally shaped arrays is also
        accepted.
    titles : list of str
        Title for each trajectory.
    edges : list of index pairs
        Skeleton edges.
    output_dir : str
        Directory in which to save the interactive plot as
        `all_trajectories.html`. If None, the plot is not saved.
    keypoint_colormap : str, default='autumn'
        Name of a matplotlib colormap to use for coloring the keypoints.
    node_size : float, default=8.0
        Size of the keypoints in the figure.
    linewidth: float, default=3.0
        Width of edges in skeleton
    height : int, default=500
        Height of the figure in pixels.
    skiprate : int, default=1
        Plot every `skiprate` frames.
    """
    fig = initialize_3D_plot(height)

    # Coerce to an array so that list input supports the 2D slice below.
    Xs = np.asarray(Xs)[:, ::skiprate]
    num_trajectories, num_frames = Xs.shape[:2]

    # Each plotted pose adds one keypoint trace plus one trace per edge.
    # The frame count is read from the array shape rather than from `Xs[1]`,
    # which does not exist when only a single trajectory is given.
    traces_per_trajectory = (len(edges) + 1) * num_frames

    def visibility_mask(i):
        visible = np.zeros(traces_per_trajectory * num_trajectories, dtype=bool)
        visible[traces_per_trajectory * i : traces_per_trajectory * (i + 1)] = True
        return visible

    # All trajectories share the same frame count, hence the same opacity ramp.
    opacities = np.linspace(0.3, 1, num_frames + 1)[1:] ** 2

    steps = []
    for i, X in enumerate(Xs):
        for coords, opacity in zip(X, opacities):
            add_3D_pose_to_fig(
                fig,
                coords,
                edges,
                visible=(i == 0),
                node_size=node_size,
                linewidth=linewidth,
                keypoint_colormap=keypoint_colormap,
                opacity=opacity,
            )
        steps.append(
            dict(
                method="update",
                label=titles[i],
                args=[{"visible": visibility_mask(i)}],
            )
        )
    fig.update_layout(sliders=[dict(steps=steps)])

    if output_dir is not None:
        save_path = os.path.join(output_dir, "all_trajectories.html")
        fig.write_html(save_path)
        print(f"Saved interactive trajectories plot to {save_path}")
    fig.show()
[docs]
def plot_poses_3D(
    poses,
    edges,
    keypoint_colormap="autumn",
    node_size=6.0,
    linewidth=3.0,
):
    """Show a sequence of 3D poses in an interactive figure.

    All poses are added to one figure and a slider selects which single pose
    is visible. The first pose is shown initially.

    Parameters
    ----------
    poses: array of shape (num_poses, num_bodyparts, 3)
        3D poses to plot.
    edges: list of index pairs
        Skeleton edges.
    keypoint_colormap: str, default='autumn'
        Colormap to use for coloring keypoints.
    node_size: float, default=6.0
        Size of keypoints.
    linewidth: float, default=3.0
        Width of skeleton edges.
    """
    fig = initialize_3D_plot()

    # Every pose contributes one keypoint trace plus one trace per edge.
    traces_per_pose = len(edges) + 1
    total_traces = traces_per_pose * len(poses)

    def visibility_mask(pose_index):
        mask = np.zeros(total_traces, dtype=bool)
        start = traces_per_pose * pose_index
        mask[start : start + traces_per_pose] = True
        return mask

    for pose_index, pose in enumerate(poses):
        add_3D_pose_to_fig(
            fig,
            pose,
            edges,
            visible=(pose_index == 0),
            keypoint_colormap=keypoint_colormap,
            node_size=node_size,
            linewidth=linewidth,
        )

    slider_steps = [
        dict(
            method="update",
            label=f"Pose {pose_index+1}",
            args=[{"visible": visibility_mask(pose_index)}],
        )
        for pose_index in range(len(poses))
    ]
    fig.update_layout(sliders=[dict(steps=slider_steps)])
    fig.show()
[docs]
def hierarchical_clustering_order(X, dist_metric="euclidean", linkage_method="ward"):
    """Derive a linear ordering of points from a hierarchical clustering.

    Pairwise distances between the rows of `X` are computed, the points are
    agglomeratively clustered, and the leaf order of the resulting tree is
    returned.

    Parameters
    ----------
    X: ndarray of shape (num_points, num_features)
        Points to order.
    dist_metric: str, default='euclidean'
        Distance metric passed to `pdist`.
    linkage_method: str, default='ward'
        Linkage method passed to `linkage`.

    Returns
    -------
    ordering: ndarray of shape (num_points,)
        Indices of the points in the order of the tree's leaves.
    """
    condensed_distances = pdist(X, dist_metric)
    return leaves_list(linkage(condensed_distances, linkage_method))
[docs]
def plot_confusion_matrix(
    results1, results2, min_frequency=0.005, sort=True, normalize=True
):
    """Plot a confusion matrix that compares syllables across two models.

    Entry (i, j) of the matrix counts the frames labelled syllable i by the
    first model and syllable j by the second model. Recordings are paired by
    concatenating each model's results in sorted key order.

    Parameters
    ----------
    results1: dict
        Dictionary containing modeling results for the first model (see
        :py:func:`keypoint_moseq.fitting.extract_results`).
    results2: dict
        Dictionary containing modeling results for the second model (see
        :py:func:`keypoint_moseq.fitting.extract_results`).
    min_frequency: float, default=0.005
        Minimum frequency of a syllable to include in the confusion matrix.
    sort: bool, default=True
        Whether to sort the syllables from each model to emphasize the diagonal.
    normalize: bool, default=True
        Whether to row-normalize the confusion matrix.

    Returns
    -------
    fig: matplotlib figure
        Figure containing the confusion matrix.
    ax: matplotlib axis
        Axis containing the confusion matrix.
    """
    syllables1 = np.concatenate(
        [results1[k]["syllable"] for k in sorted(results1)]
    )
    syllables2 = np.concatenate(
        [results2[k]["syllable"] for k in sorted(results2)]
    )

    # Joint frame counts over (model 1 syllable, model 2 syllable).
    C = np.zeros((np.max(syllables1) + 1, np.max(syllables2) + 1))
    np.add.at(C, (syllables1, syllables2), 1)
    if normalize:
        # Normalize before filtering so that each row is a distribution over
        # all of model 2's syllables, not only the ones that are kept.
        C = C / np.sum(C, axis=1, keepdims=True)

    # Honor the caller-supplied threshold rather than a hard-coded value.
    ix1 = (get_frequencies(syllables1) > min_frequency).nonzero()[0]
    ix2 = (get_frequencies(syllables2) > min_frequency).nonzero()[0]
    C = C[ix1, :][:, ix2]

    if sort and C.shape[0] > 0:
        # Hierarchical clustering needs at least two rows to define an order.
        if C.shape[0] > 1:
            row_order = hierarchical_clustering_order(C)
            C = C[row_order, :]
            ix1 = ix1[row_order]
        # Order columns by the row at which each column peaks.
        col_order = np.argsort(np.argmax(C, axis=0))
        C = C[:, col_order]
        ix2 = ix2[col_order]

    fig, ax = plt.subplots(1, 1)
    im = ax.imshow(C)
    ax.set_xticks(np.arange(len(ix2)))
    ax.set_xticklabels(ix2)
    ax.set_yticks(np.arange(len(ix1)))
    ax.set_yticklabels(ix1)
    ax.set_xlabel("Model 2")
    ax.set_ylabel("Model 1")
    ax.set_title("Confusion matrix")
    cbar = fig.colorbar(im, ax=ax)
    # Without normalization the entries are raw frame counts.
    cbar.set_label("Probability" if normalize else "Count")
    fig.tight_layout()
    return fig, ax
[docs]
def plot_eml_scores(eml_scores, eml_std_errs, model_names):
    """Plot expected marginal likelihood (EML) scores for a set of models.

    Models are arranged along the x-axis in order of increasing EML score.
    Each model is drawn as a point with a vertical bar spanning the score
    plus and minus its standard error.

    Parameters
    ----------
    eml_scores: ndarray of shape (num_models,)
        EML score for each model.
    eml_std_errs: ndarray of shape (num_models,)
        Standard error of the EML score for each model.
    model_names: list of str
        Name of each model.

    Returns
    -------
    fig: matplotlib figure
        Figure containing the plot.
    ax: matplotlib axis
        Axis containing the plot.
    """
    rank = np.argsort(eml_scores)
    scores = eml_scores[rank]
    std_errs = eml_std_errs[rank]
    names = [model_names[i] for i in rank]
    positions = range(len(scores))

    fig, ax = plt.subplots(1, 1, figsize=(4, 3.5))
    # One vertical bar per model spanning score -/+ standard error.
    for x, low, high in zip(positions, scores - std_errs, scores + std_errs):
        ax.plot([x, x], [low, high], c="k", linewidth=1)
    ax.scatter(positions, scores, c="k")
    ax.set_xticks(positions)
    ax.set_xticklabels(names, rotation=90)
    ax.set_ylabel("EML score")
    plt.tight_layout()
    return fig, ax