import os
import cv2
import tqdm
import imageio
import warnings
import logging
import h5py
import numpy as np
import plotly
import matplotlib.pyplot as plt
from scipy.ndimage import gaussian_filter1d
from vidio.read import OpenCVReader
from scipy.spatial.distance import squareform, pdist
from scipy.cluster.hierarchy import linkage, dendrogram, leaves_list
from textwrap import fill
from PIL import Image
from keypoint_moseq.util import *
from keypoint_moseq.io import load_results, _get_path
from jax_moseq.models.keypoint_slds import center_embedding
from jax_moseq.utils import get_durations, get_frequencies
from plotly.subplots import make_subplots
import plotly.io as pio
# use the "iframe" renderer as the default for plotly figures
pio.renderers.default = "iframe"
# set matplotlib defaults
plt.rcParams["figure.dpi"] = 100
# suppress warnings from imageio
# NOTE(review): `logging.getLogger()` with no name is the *root* logger, so
# this raises the level for every logger that defers to the root, not only
# imageio's.
logging.getLogger().setLevel(logging.ERROR)
[docs]
def crop_image(image, centroid, crop_size):
    """Crop an image around a centroid, zero-padding outside the image.

    Parameters
    ----------
    image: ndarray of shape (height, width, ...)
        Image to crop. Any trailing axes (e.g. color channels) are kept.

    centroid: tuple of int
        (x,y) coordinates of the centroid.

    crop_size: int or tuple(int,int)
        Size of the crop around the centroid. Either a single int for a square
        crop, or a pair of ints (w,h) for a rectangular crop.

    Returns
    -------
    image: ndarray of shape (h, w, ...)
        Cropped image with the same dtype as `image`. Parts of the crop
        window that fall outside `image` are filled with zeros; if the
        window does not overlap the image at all, the result is all zeros.
    """
    if isinstance(crop_size, (tuple, list)):
        w, h = crop_size
    else:
        w, h = crop_size, crop_size
    x, y = int(centroid[0]), int(centroid[1])
    img_h, img_w = image.shape[:2]

    # Top-left corner of the crop window in image coordinates (may be
    # negative). Using `left + w` for the far edge keeps the window exactly
    # `w` wide for odd sizes as well.
    left, top = x - w // 2, y - h // 2

    # Clamp *both* ends of the window into [0, size]. Clamping only one side
    # lets a negative upper bound through, which numpy would interpret as a
    # from-the-end index and return the wrong region.
    x_min, x_max = min(max(left, 0), img_w), min(max(left + w, 0), img_w)
    y_min, y_max = min(max(top, 0), img_h), min(max(top + h, 0), img_h)

    padded = np.zeros((h, w, *image.shape[2:]), dtype=image.dtype)
    if x_max > x_min and y_max > y_min:
        # Offset of the visible region inside the padded output
        pad_x, pad_y = x_min - left, y_min - top
        padded[
            pad_y : pad_y + (y_max - y_min),
            pad_x : pad_x + (x_max - x_min),
        ] = image[y_min:y_max, x_min:x_max]
    return padded
[docs]
def plot_scree(pca, savefig=True, project_dir=None, fig_size=(3, 2)):
    """Plot cumulative explained variance against the number of PCs.

    Parameters
    ----------
    pca : :py:func:`sklearn.decomposition.PCA`
        Fitted PCA model

    savefig : bool, True
        Whether to save the figure to a file. If true, the figure is saved to
        `{project_dir}/pca_scree.pdf`.

    project_dir : str, default=None
        Path to the project directory. Required if `savefig` is True.

    fig_size : tuple, (3,2)
        Size of the figure in inches.

    Returns
    -------
    fig : :py:class:`matplotlib.figure.Figure`
        Figure handle
    """
    fig = plt.figure()

    # x-axis counts PCs starting from 1; y-axis is the running total of the
    # explained variance ratio
    pc_counts = np.arange(len(pca.components_)) + 1
    cumulative_variance = np.cumsum(pca.explained_variance_ratio_)
    plt.plot(pc_counts, cumulative_variance)

    plt.xlabel("PCs")
    plt.ylabel("Explained variance")
    plt.gcf().set_size_inches(fig_size)
    plt.grid()
    plt.tight_layout()

    if savefig:
        assert project_dir is not None, fill(
            "The `savefig` option requires a `project_dir`"
        )
        output_path = os.path.join(project_dir, "pca_scree.pdf")
        plt.savefig(output_path)

    plt.show()
    return fig
[docs]
def plot_pcs(
    pca,
    *,
    use_bodyparts,
    skeleton,
    keypoint_colormap="autumn",
    keypoint_colors=None,
    savefig=True,
    project_dir=None,
    scale=1,
    plot_n_pcs=10,
    axis_size=(2, 1.5),
    ncols=5,
    node_size=30.0,
    line_width=2.0,
    interactive=True,
    **kwargs,
):
    """
    Visualize the components of a fitted PCA model.

    For each PC, a subplot shows the mean pose (semi-transparent) along with a
    perturbation of the mean pose in the direction of the PC.

    Parameters
    ----------
    pca : :py:func:`sklearn.decomposition.PCA`
        Fitted PCA model

    use_bodyparts : list of str
        List of bodyparts that are used in the model; used to index bodypart
        names in the skeleton.

    skeleton : list
        List of edges that define the skeleton, where each edge is a pair of
        bodypart names.

    keypoint_colormap : str
        Name of a matplotlib colormap to use for coloring the keypoints.

    keypoint_colors : array-like, shape=(num_keypoints,3), default=None
        Color for each keypoint. If None, colors are sampled from
        `keypoint_colormap`. The values are handed to matplotlib unchanged.

    savefig : bool, True
        Whether to save the figure to a file. If true, one figure per
        projection is saved to `{project_dir}/pcs-{xy/xz}.pdf` (`xz` is only
        produced for 3D data).

    project_dir : str, default=None
        Path to the project directory. Required if `savefig` is True.

    scale : float, default=1
        Scale factor for the perturbation of the mean pose.

    plot_n_pcs : int, default=10
        Number of PCs to plot.

    axis_size : tuple of float, default=(2,1.5)
        Size of each subplot in inches.

    ncols : int, default=5
        Number of columns in the figure.

    node_size : float, default=30.0
        Size of the keypoints in the figure.

    line_width: float, default=2.0
        Width of edges in skeleton

    interactive : bool, default=True
        For 3D data, whether to generate an interactive 3D plot.

    **kwargs
        Accepted for call-site compatibility; not used by this function.

    Raises
    ------
    ValueError
        If the keypoint dimensionality inferred from `pca.mean_` is neither
        2 nor 3.
    """
    k = len(use_bodyparts)
    d = len(pca.mean_) // (k - 1)

    if keypoint_colors is None:
        # `plt.cm.get_cmap` was removed in matplotlib 3.9; `plt.get_cmap`
        # is available in both older and newer releases.
        cmap = plt.get_cmap(keypoint_colormap)
        keypoint_colors = cmap(np.linspace(0, 1, k))

    Gamma = np.array(center_embedding(k))
    edges = get_edges(use_bodyparts, skeleton)
    plot_n_pcs = min(plot_n_pcs, pca.components_.shape[0])

    # Perturbation size is the RMS of the mean pose times `scale`
    magnitude = np.sqrt((pca.mean_**2).mean()) * scale
    ymean = Gamma @ pca.mean_.reshape(k - 1, d)
    ypcs = (pca.mean_ + magnitude * pca.components_).reshape(-1, k - 1, d)
    ypcs = Gamma[np.newaxis] @ ypcs[:plot_n_pcs]

    if d == 2:
        dims_list, names = [[0, 1]], ["xy"]
    elif d == 3:
        dims_list, names = [[0, 1], [0, 2]], ["xy", "xz"]
    else:
        raise ValueError(
            f"Expected 2D or 3D keypoints but inferred dimension {d} from the PCA model"
        )

    for dims, name in zip(dims_list, names):
        nrows = int(np.ceil(plot_n_pcs / ncols))
        # squeeze=False keeps `axs` a 2D array even for a 1x1 grid, so that
        # `axs.flat` below is always valid.
        fig, axs = plt.subplots(
            nrows, ncols, sharex=True, sharey=True, squeeze=False
        )
        for i, ax in enumerate(axs.flat):
            if i >= plot_n_pcs:
                ax.axis("off")
                continue

            for e in edges:
                # faint mean-pose edge
                ax.plot(
                    *ymean[:, dims][e].T,
                    color=keypoint_colors[e[0]],
                    zorder=0,
                    alpha=0.25,
                    linewidth=line_width,
                )
                # slightly wider black line underneath gives an outline
                ax.plot(
                    *ypcs[i][:, dims][e].T,
                    color="k",
                    zorder=2,
                    linewidth=line_width + 0.2,
                )
                ax.plot(
                    *ypcs[i][:, dims][e].T,
                    color=keypoint_colors[e[0]],
                    zorder=3,
                    linewidth=line_width,
                )

            ax.scatter(
                *ymean[:, dims].T,
                c=keypoint_colors,
                s=node_size,
                zorder=1,
                alpha=0.25,
                linewidth=0,
            )
            ax.scatter(
                *ypcs[i][:, dims].T,
                c=keypoint_colors,
                s=node_size,
                zorder=4,
                edgecolor="k",
                linewidth=0.2,
            )

            ax.set_title(f"PC {i+1}", fontsize=10)
            ax.set_aspect("equal")
            ax.axis("off")

        fig.set_size_inches((axis_size[0] * ncols, axis_size[1] * nrows))
        plt.tight_layout()

        if savefig:
            assert project_dir is not None, fill(
                "The `savefig` option requires a `project_dir`"
            )
            plt.savefig(os.path.join(project_dir, f"pcs-{name}.pdf"))
        plt.show()

    if interactive and d == 3:
        plot_pcs_3D(
            ymean,
            ypcs,
            edges,
            keypoint_colormap,
            project_dir if savefig else None,
            node_size / 3,
            line_width * 2,
        )
[docs]
def plot_syllable_frequencies(
    project_dir=None,
    model_name=None,
    results=None,
    path=None,
    minlength=10,
    min_frequency=0.005,
):
    """Plot a histogram showing the frequency of each syllable.

    Caller must provide a results dictionary, a path to a results .h5, or a
    project directory and model name, in which case the results are loaded
    from `{project_dir}/{model_name}/results.h5`.

    Parameters
    ----------
    results : dict, default=None
        Dictionary containing modeling results for a dataset (see
        :py:func:`keypoint_moseq.fitting.extract_results`)

    model_name: str, default=None
        Name of the model. Required to load results if `results` is None and
        `path` is None.

    project_dir: str, default=None
        Project directory. Required to load results if `results` is None and
        `path` is None.

    path: str, default=None
        Path to a results file. If None, results will be loaded from
        `{project_dir}/{model_name}/results.h5`.

    minlength: int, default=10
        Minimum x-axis length of the histogram.

    min_frequency: float, default=0.005
        Minimum frequency of syllables to include in the histogram. Only
        syllables with frequency strictly greater than this are shown.

    Returns
    -------
    fig : matplotlib.figure.Figure
        Figure containing the histogram.

    ax : matplotlib.axes.Axes
        Axes containing the histogram.
    """
    if results is None:
        results = load_results(project_dir, model_name, path)

    syllables = {k: res["syllable"] for k, res in results.items()}
    frequencies = get_frequencies(syllables)
    frequencies = frequencies[frequencies > min_frequency]

    # Every remaining entry exceeds `min_frequency`, so the x-extent is just
    # the number of remaining syllables. Using `len` also avoids taking the
    # max of an empty array when no syllable passes the threshold.
    xmax = max(minlength, len(frequencies))

    fig, ax = plt.subplots()
    ax.bar(range(len(frequencies)), frequencies, width=1)
    ax.set_ylabel("probability")
    ax.set_xlabel("syllable rank")
    ax.set_xlim(-1, xmax + 1)
    ax.set_title("Frequency distribution")
    ax.set_yticks([])
    return fig, ax
[docs]
def plot_duration_distribution(
    project_dir=None,
    model_name=None,
    results=None,
    path=None,
    lim=None,
    num_bins=30,
    fps=None,
    show_median=True,
):
    """Plot a histogram showing the distribution of syllable durations.

    Caller must provide a results dictionary, a path to a results .h5, or a
    project directory and model name, in which case the results are loaded from
    `{project_dir}/{model_name}/results.h5`.

    Parameters
    ----------
    results : dict, default=None
        Dictionary containing modeling results for a dataset (see
        :py:func:`keypoint_moseq.fitting.extract_results`)

    model_name: str, default=None
        Name of the model. Required to load results if `results` is None and
        `path` is None.

    project_dir: str, default=None
        Project directory. Required to load results if `results` is None and
        `path` is None.

    path: str, default=None
        Path to a results file. If None, results will be loaded from
        `{project_dir}/{model_name}/results.h5`.

    lim: number or pair of numbers, default=None
        x-axis limits in units of frames. A single number is interpreted as
        the upper limit (with a lower limit of 0); a pair is interpreted as
        `(lower, upper)`. If None, the limits are set to
        (0, 95th-percentile).

    num_bins: int, default=30
        Approximate number of bins in the histogram. Bin width is a whole
        number of frames (at least 1), so the actual count can differ.

    fps: int, default=None
        Frames per second. Used to convert x-axis from frames to seconds.

    show_median: bool, default=True
        Whether to show the median duration as a vertical line.

    Returns
    -------
    fig : matplotlib.figure.Figure
        Figure containing the histogram.

    ax : matplotlib.axes.Axes
        Axes containing the histogram.
    """
    if results is None:
        results = load_results(project_dir, model_name, path)

    syllables = {k: res["syllable"] for k, res in results.items()}
    durations = get_durations(syllables)

    if lim is None:
        lower, upper = 0, int(np.percentile(durations, 95))
    elif np.ndim(lim) == 0:
        lower, upper = 0, lim
    else:
        lower, upper = lim

    # Decide the binning in frame units *before* any conversion to seconds;
    # dividing two fps-scaled floats can land just below an integer and
    # silently drop a bin when truncated.
    span = upper - lower
    binsize = max(int(np.floor(span / num_bins)), 1)
    n_bins = max(int(span / binsize), 1)

    if fps is not None:
        durations = durations / fps
        lower, upper = lower / fps, upper / fps
        xlabel = "syllable duration (s)"
    else:
        xlabel = "syllable duration (frames)"

    fig, ax = plt.subplots()
    ax.hist(durations, range=(lower, upper), bins=n_bins, density=True)
    ax.set_xlim([lower, upper])
    ax.set_xlabel(xlabel)
    ax.set_ylabel("probability")
    ax.set_title("Duration distribution")
    ax.set_yticks([])
    if show_median:
        ax.axvline(np.median(durations), color="k", linestyle="--")
    return fig, ax
[docs]
def plot_kappa_scan(kappas, project_dir, prefix, figsize=(8, 2.5)):
    """Plot the results of a kappa scan.

    This function assumes that model results for each kappa value are stored
    in `{project_dir}/{prefix}-{kappa}/checkpoint.h5`. Two plots are generated:
    (1) a line plot showing the median syllable duration over the course of
    fitting for each kappa value; (2) and a plot showing the final median
    syllable duration as a function of kappa.

    Parameters
    ----------
    kappas : array-like of float
        Kappa values used in the scan.

    project_dir : str
        Path to the project directory.

    prefix : str
        Prefix for the kappa scan model names.

    figsize : tuple of float, default=(8,2.5)
        Size of the figure in inches.

    Returns
    -------
    fig : matplotlib.figure.Figure
        Figure containing the plot.

    final_median_durations : array of float
        Median syllable durations for each kappa value, derived using the
        final iteration of each model.
    """
    median_dur_histories = []
    final_median_durs = []

    for kappa in tqdm.tqdm(kappas, desc="Loading checkpoints"):
        model_dir = f"{project_dir}/{prefix}-{kappa}"
        with h5py.File(f"{model_dir}/checkpoint.h5", "r") as h5:
            mask = h5["data/mask"][()]
            iterations = np.sort([int(i) for i in h5["model_snapshots"]])

            # median duration at every saved iteration, keyed by iteration
            history = {}
            for itr in iterations:
                z = h5[f"model_snapshots/{itr}/states/z"][()]
                durs = get_durations(z, mask)
                history[itr] = np.median(durs)
            final_median_durs.append(history[iterations[-1]])
            median_dur_histories.append(history)

    # Spread line colors over the colormap; guard the denominator so that a
    # scan with a single kappa does not divide by zero.
    color_steps = max(len(kappas) - 1, 1)

    fig, axs = plt.subplots(1, 2)
    for i, (kappa, history) in enumerate(zip(kappas, median_dur_histories)):
        color = plt.cm.viridis(i / color_steps)
        label = "{:.1e}".format(kappa)
        axs[0].plot(*zip(*history.items()), color=color, label=label)
    axs[0].legend(loc="center left", bbox_to_anchor=(1, 0.5))
    axs[0].set_xlabel("iteration")
    axs[0].set_ylabel("median duration")

    axs[1].scatter(kappas, final_median_durs)
    axs[1].set_xlabel("kappa")
    axs[1].set_ylabel("final median duration")
    axs[1].set_xscale("log")

    fig.set_size_inches(figsize)
    plt.tight_layout()
    return fig, np.array(final_median_durs)
[docs]
def plot_progress(
    model,
    data,
    checkpoint_path,
    iteration,
    project_dir=None,
    model_name=None,
    path=None,
    savefig=True,
    fig_size=None,
    window_size=600,
    min_frequency=0.001,
    min_histogram_length=10,
):
    """Plot the progress of the model during fitting.

    The figure shows the following plots (left to right):

        - Frequency distribution:
            The distribution of state frequencies for the most recent
            iteration of the model.
        - Duration distribution:
            The distribution of state durations for the most recent iteration
            of the model.
        - Median duration:
            The median state duration across iterations.
        - State sequence history
            The state sequence across iterations in a random window (a new
            window is selected each time the progress is plotted).

    The last two plots are only drawn when the checkpoint file contains more
    than one saved iteration.

    Parameters
    ----------
    model : dict
        Model dictionary containing `states`

    data : dict
        Data dictionary containing `mask`

    checkpoint_path : str
        Path to an HDF5 file containing model checkpoints.

    iteration : int
        Current iteration of model fitting

    project_dir : str, default=None
        Path to the project directory. Passed to `_get_path` to resolve the
        output location when `savefig` is True.

    model_name : str, default=None
        Name of the model. Used in the figure title when given, and passed
        to `_get_path` to resolve the output location when `savefig` is True.

    path : str, default=None
        Output path for the figure. Passed to `_get_path` together with
        `project_dir`, `model_name` and the filename `fitting_progress.pdf`.

    savefig : bool, default=True
        Whether to save the figure to a file at the location resolved by
        `_get_path(project_dir, model_name, path, "fitting_progress.pdf")`.

    fig_size : tuple of float, default=None
        Size of the figure in inches. If None, (12, 2.5) is used for the
        four-panel layout and (4, 2.5) for the two-panel layout.

    window_size : int, default=600
        Window size for state sequence history plot.

    min_frequency : float, default=.001
        Minimum frequency for including a state in the frequency distribution
        plot.

    min_histogram_length : int, default=10
        Minimum x-axis length of the frequency distribution plot.

    Returns
    -------
    fig : matplotlib.figure.Figure
        Figure containing the plots.

    axs : list of matplotlib.axes.Axes
        Axes containing the plots.
    """
    z = np.array(model["states"]["z"])
    mask = np.array(data["mask"])
    durations = get_durations(z, mask)
    frequencies = get_frequencies(z, mask)

    with h5py.File(checkpoint_path, "r") as f:
        saved_iterations = np.sort([int(i) for i in f["model_snapshots"]])
    has_history = len(saved_iterations) > 1

    if has_history:
        fig, axs = plt.subplots(1, 4, gridspec_kw={"width_ratios": [1, 1, 1, 3]})
        if fig_size is None:
            fig_size = (12, 2.5)
    else:
        fig, axs = plt.subplots(1, 2)
        if fig_size is None:
            fig_size = (4, 2.5)

    # Frequency distribution, most frequent state first
    frequencies = np.sort(frequencies[frequencies > min_frequency])[::-1]
    xmax = max(len(frequencies), min_histogram_length)
    axs[0].bar(range(len(frequencies)), frequencies, width=1)
    axs[0].set_ylabel("probability")
    axs[0].set_xlabel("syllable rank")
    axs[0].set_xlim([-1, xmax + 1])
    axs[0].set_title("Frequency distribution")
    axs[0].set_yticks([])

    # Duration distribution, truncated at the 95th percentile
    lim = int(np.percentile(durations, 95))
    binsize = max(int(np.floor(lim / 30)), 1)
    axs[1].hist(durations, range=(1, lim), bins=(int(lim / binsize)), density=True)
    axs[1].set_xlim([1, lim])
    axs[1].set_xlabel("syllable duration (frames)")
    axs[1].set_ylabel("probability")
    axs[1].set_title("Duration distribution")
    axs[1].set_yticks([])

    if has_history:
        window_size = int(min(window_size, mask.max(0).sum() - 1))
        # Pick a random unmasked frame at column >= window_size. `start` is
        # relative to the sliced mask, so the displayed window covers the
        # `window_size` frames immediately preceding that frame.
        nz = np.stack(np.array(mask[:, window_size:]).nonzero(), axis=1)
        batch_ix, start = nz[np.random.randint(nz.shape[0])]

        sample_state_history = []
        median_durations = []

        # Open the checkpoint once for all snapshots rather than once per
        # saved iteration.
        with h5py.File(checkpoint_path, "r") as f:
            for i in saved_iterations:
                z = np.array(f[f"model_snapshots/{i}/states/z"])
                sample_state_history.append(z[batch_ix, start : start + window_size])
                median_durations.append(np.median(get_durations(z, mask)))

        axs[2].scatter(saved_iterations, median_durations)
        axs[2].set_ylim([-1, np.max(median_durations) * 1.1])
        axs[2].set_xlabel("iteration")
        axs[2].set_ylabel("duration")
        axs[2].set_title("Median duration")

        axs[3].imshow(
            sample_state_history,
            cmap=plt.cm.jet,
            aspect="auto",
            interpolation="nearest",
        )
        axs[3].set_xlabel("Time (frames)")
        axs[3].set_ylabel("Iterations")
        axs[3].set_title("State sequence history")

        # Relabel the image rows with the iteration number each row shows
        yticks = [
            int(y) for y in axs[3].get_yticks() if y < len(saved_iterations) and y > 0
        ]
        yticklabels = saved_iterations[yticks]
        axs[3].set_yticks(yticks)
        axs[3].set_yticklabels(yticklabels)

    title = f"Iteration {iteration}"
    if model_name is not None:
        title = f"{model_name}: {title}"
    fig.suptitle(title)
    fig.set_size_inches(fig_size)
    plt.tight_layout()

    if savefig:
        path = _get_path(project_dir, model_name, path, "fitting_progress.pdf")
        plt.savefig(path)
    plt.show()
    return fig, axs
[docs]
def write_video_clip(frames, path, fps=30, quality=7):
    """Encode a sequence of frames as a video file.

    Parameters
    ----------
    frames : np.ndarray
        Video frames as a 4D array of shape `(num_frames, height, width, 3)`
        or a 3D array of shape `(num_frames, height, width)`.

    path : str
        Path to save the video clip.

    fps : int, default=30
        Framerate of video encoding.

    quality : int, default=7
        Quality of video encoding.
    """
    writer_options = {"pixelformat": "yuv420p", "fps": fps, "quality": quality}
    with imageio.get_writer(path, **writer_options) as video_writer:
        # frames are appended one at a time, in order
        for image in frames:
            video_writer.append_data(image)
def _grid_movie_tile(
    key,
    start,
    end,
    video_paths,
    centroids,
    headings,
    dot_color,
    window_size,
    scaled_window_size,
    pre,
    post,
    dot_radius,
    overlay_keypoints,
    edges,
    coordinates,
    plot_options,
    video_frame_indexes,
    use_dims,
):
    """Render one tile of a grid movie for a single syllable instance.

    The tile spans frames `start - pre` to `start + post` of recording `key`
    and is aligned using the centroid and heading at frame `start`. Either
    video frames are warped into the window (when `video_paths` is given) or
    keypoints are drawn on a black background. A dot is drawn at the
    per-frame centroid for frames from syllable onset through `end`.

    Parameters mirror those of :py:func:`grid_movie`; `key`, `start` and
    `end` identify the instance.

    Returns
    -------
    tile : ndarray
        Stack of `pre + post` frames, each of size
        `(scaled_window_size, scaled_window_size)`.

    Raises
    ------
    AssertionError
        If `coordinates` is None although keypoints have to be drawn, if
        3D keypoints are combined with video paths, or if there are neither
        videos nor keypoints to show.
    """
    scale_factor = scaled_window_size / window_size
    cs = centroids[key][start - pre : start + post]
    h, c = headings[key][start], cs[pre]
    r = np.float32([[np.cos(h), np.sin(h)], [-np.sin(h), np.cos(h)]])

    # Keypoint coordinates are only needed when they are drawn. Plain video
    # tiles must work with `coordinates=None`, so only slice when available.
    if coordinates is None:
        assert video_paths is not None and not overlay_keypoints, fill(
            "`coordinates` must be provided when `overlay_keypoints` is "
            "True or when no videos are provided"
        )
        syllable_coordinates = None
    else:
        syllable_coordinates = coordinates[key][start - pre : start + post].copy()

    keypoint_dimension = next(iter(centroids.values())).shape[-1]
    assert not (
        keypoint_dimension == 3 and video_paths is not None
    ), "3D keypoints are not supported when video paths are provided"

    if keypoint_dimension == 3:
        ds = np.array(use_dims)

        if ds[1] == 2:
            # Projection includes the third axis: apply the heading rotation
            # to the first two axes up front and flip the third axis about
            # the reference centroid, then disable the later 2D rotation.
            syllable_coordinates[:, :, :2] = (
                (syllable_coordinates[:, :, :2] - c[:2]) @ r.T
            ) + c[:2]
            syllable_coordinates[:, :, 2] = (
                -(syllable_coordinates[:, :, 2] - c[2]) + c[2]
            )
            r = np.float32([[1, 0], [0, 1]])

        cs = cs[:, ds]
        c = c[ds]
        syllable_coordinates = syllable_coordinates[:, :, ds]

    tile = []

    if video_paths is not None:
        frame_ixs = video_frame_indexes[key][start - pre : start + post]
        reader = OpenCVReader(video_paths[key])
        try:
            frames = [reader[ix] for ix in frame_ixs]
        finally:
            # release the video even if reading a frame fails
            reader.close()

        # Affine map: rotate by the heading and translate so that the
        # reference centroid lands at the center of the window
        c = r @ c - window_size // 2
        M = [[np.cos(h), np.sin(h), -c[0]], [-np.sin(h), np.cos(h), -c[1]]]

        for ii, (frame, centroid) in enumerate(zip(frames, cs)):
            if overlay_keypoints:
                coords = syllable_coordinates[ii]
                frame = overlay_keypoints_on_image(
                    frame, coords, edges=edges, **plot_options
                )

            frame = cv2.warpAffine(frame, np.float32(M), (window_size, window_size))
            frame = cv2.resize(frame, (scaled_window_size, scaled_window_size))
            if 0 <= ii - pre <= end - start and dot_radius > 0:
                pos = tuple(
                    [int(x) for x in M @ np.append(centroid, 1) * scale_factor]
                )
                cv2.circle(frame, pos, dot_radius, dot_color, -1, cv2.LINE_AA)
            tile.append(frame)

    else:  # first transform keypoints, then overlay on black background
        assert overlay_keypoints, fill(
            "If no videos are provided, then `overlay_keypoints` must "
            "be True. Otherwise there is nothing to show"
        )
        syllable_coordinates = (
            syllable_coordinates - c
        ) @ r.T * scale_factor + scaled_window_size // 2
        cs = (cs - c) @ r.T * scale_factor + scaled_window_size // 2
        background = np.zeros((scaled_window_size, scaled_window_size, 3))

        for ii, (uvs, centroid) in enumerate(zip(syllable_coordinates, cs)):
            frame = overlay_keypoints_on_image(
                background.copy(), uvs, edges=edges, **plot_options
            )
            if 0 <= ii - pre <= end - start and dot_radius > 0:
                pos = (int(centroid[0]), int(centroid[1]))
                cv2.circle(frame, pos, dot_radius, dot_color, -1, cv2.LINE_AA)
            tile.append(frame)

    return np.stack(tile)
[docs]
def grid_movie(
    instances,
    rows,
    cols,
    video_paths,
    centroids,
    headings,
    window_size,
    video_frame_indexes,
    dot_color=(255, 255, 255),
    dot_radius=4,
    pre=30,
    post=60,
    scaled_window_size=None,
    edges=[],
    overlay_keypoints=False,
    coordinates=None,
    plot_options={},
    use_dims=None,
):
    """Generate a grid movie and return it as an array of frames.

    Grid movies show many instances of a syllable. Each instance contains a snippet of
    video (and/or keypoint-overlay) centered on the animal and synchronized to the onset
    of the syllable. A dot appears at syllable onset and disappears at syllable offset.

    Parameters
    ----------
    instances: list of tuples `(key, start, end)`
        List of syllable instances to include in the grid movie, where each instance is
        specified as a tuple with the video name, start frame and end frame. The list
        must have length `rows*cols`. The video names must also be keys in
        `video_paths` (when provided), `centroids` and `headings`.

    rows: int, cols : int
        Number of rows and columns in the grid movie grid

    video_paths: dict or None
        Dictionary mapping video names to video paths. If None, the grid movie will
        not include video frames and `overlay_keypoints` must be True.

    centroids: dict
        Dictionary mapping video names to arrays of shape `(n_frames, 2)` with the x,y
        coordinates of animal centroid on each frame

    headings: dict
        Dictionary mapping video names to arrays of shape `(n_frames,)` with the heading
        of the animal on each frame (in radians)

    window_size: int
        Size of the window around the animal. This should be a multiple of 16 or imageio
        will complain.

    video_frame_indexes: dict
        Dictionary mapping recording names to arrays of video frame indexes.
        This is useful when the original keypoint coordinates used for modeling
        corresponded to a subset of frames from each video (i.e. if videos were
        trimmed or coordinates were downsampled).

    dot_color: tuple of ints, default=(255,255,255)
        RGB color of the dot indicating syllable onset and offset

    dot_radius: int, default=4
        Radius of the dot indicating syllable onset and offset

    pre: int, default=30
        Number of frames before syllable onset to include in the movie

    post: int, default=60
        Number of frames after syllable onset to include in the movie

    scaled_window_size: int, default=None
        Window size after scaling the video. If None, then no scaling is performed (i.e.
        `scaled_window_size = window_size`)

    overlay_keypoints: bool, default=False
        If True, overlay the pose skeleton on the video frames.

    edges: list of tuples, default=[]
        List of edges defining pose skeleton. Used when `overlay_keypoints=True`.

    coordinates: dict, default=None
        Dictionary mapping video names to keypoint coordinate arrays. Used when
        `overlay_keypoints=True`.

    plot_options: dict, default={}
        Dictionary of options to pass to `overlay_keypoints_on_image`. Used when
        `overlay_keypoints=True`.

    use_dims: pair of ints, default=None
        Dimensions to use for plotting keypoints. Only used when the keypoints
        are 3D. If None, `[0, 1]` is used.

    Returns
    -------
    frames: array of shape `(post+pre, height, width, 3)`
        Array of frames in the grid movie where::

            height = rows * scaled_window_size
            width = cols * scaled_window_size
    """
    if video_paths is None:
        assert overlay_keypoints, fill(
            "If no videos are provided, then `overlay_keypoints` must "
            "be True. Otherwise there is nothing to show"
        )

    if scaled_window_size is None:
        scaled_window_size = window_size

    # Fall back to the first two dimensions so that 3D keypoints work
    # without an explicit `use_dims`.
    if use_dims is None:
        use_dims = [0, 1]

    tiles = [
        _grid_movie_tile(
            key,
            start,
            end,
            video_paths,
            centroids,
            headings,
            dot_color,
            window_size,
            scaled_window_size,
            pre,
            post,
            dot_radius,
            overlay_keypoints,
            edges,
            coordinates,
            plot_options,
            video_frame_indexes,
            use_dims,
        )
        for key, start, end in instances
    ]

    # Arrange tiles in row-major order, then merge the grid axes into the
    # image axes: rows are joined along frame height, columns along width.
    tiles = np.stack(tiles).reshape(
        rows, cols, post + pre, scaled_window_size, scaled_window_size, 3
    )
    frames = np.concatenate(np.concatenate(tiles, axis=2), axis=2)
    return frames
[docs]
def get_grid_movie_window_size(
    sampled_instances,
    centroids,
    headings,
    coordinates,
    pre,
    post,
    pctl=90,
    fudge_factor=1.1,
    blocksize=16,
):
    """Automatically determine the window size for a grid movie.

    The window size is set such that across all sampled instances,
    the animal is fully visible in at least `pctl` percent of frames.

    Parameters
    ----------
    sampled_instances: dict
        Dictionary mapping syllables to lists of instances, where each
        instance is specified as a tuple with the video name, start frame
        and end frame.

    centroids: dict
        Dictionary mapping video names to arrays of shape `(n_frames, 2)`
        with the x,y coordinates of animal centroid on each frame

    headings: dict
        Dictionary mapping video names to arrays of shape `(n_frames,)`
        with the heading of the animal on each frame (in radians)

    coordinates: dict
        Dictionary mapping recording names to keypoint coordinates as
        ndarrays of shape (n_frames, n_bodyparts, 2).

    pre, post: int
        Number of frames before/after syllable onset that are included
        in the grid movies.

    pctl: int, default=90
        Percentile of frames in which the animal should be fully visible.

    fudge_factor: float, default=1.1
        Factor by which to multiply the window size.

    blocksize: int, default=16
        Window size is rounded up to the nearest multiple of `blocksize`.

    Returns
    -------
    window_size: int
        Window size, a multiple of `blocksize`.
    """
    # Pool the instances of every syllable into a single flat list
    pooled_instances = [
        instance
        for instance_group in sampled_instances.values()
        for instance in instance_group
    ]
    trajectories = np.concatenate(
        get_instance_trajectories(
            pooled_instances,
            coordinates,
            pre=pre,
            post=post,
            centroids=centroids,
            headings=headings,
        ),
        axis=0,
    )

    # Discard frames in which every coordinate is NaN
    frame_has_data = ~np.isnan(trajectories).all((1, 2))
    trajectories = trajectories[frame_has_data]

    # Largest absolute coordinate over axis 1, ignoring NaNs
    extents = np.nanmax(np.abs(trajectories), axis=1)

    # Doubled to turn a half-extent into a full window width, then padded
    raw_size = np.percentile(extents, pctl) * fudge_factor * 2
    num_blocks = np.ceil(raw_size / blocksize)
    return int(num_blocks * blocksize)
[docs]
def generate_grid_movies(
results,
project_dir=None,
model_name=None,
output_dir=None,
video_dir=None,
video_paths=None,
video_frame_indexes=None,
pre=1.0,
post=2.0,
rows=4,
cols=6,
filter_size=9,
min_frequency=0.005,
min_duration=3,
dot_radius=4,
dot_color=(255, 255, 255),
quality=7,
window_size=None,
coordinates=None,
centroids=None,
headings=None,
bodyparts=None,
use_bodyparts=None,
sampling_options={},
video_extension=None,
max_video_size=1920,
skeleton=[],
overlay_keypoints=False,
keypoints_only=False,
keypoints_scale=1,
fps=None,
plot_options={},
use_dims=[0, 1],
keypoint_colormap="autumn",
**kwargs,
):
"""Generate grid movies for a modeled dataset.
Grid movies show many instances of a syllable and are useful in
figuring out what behavior the syllable captures
(see :py:func:`keypoint_moseq.viz.grid_movie`). This method
generates a grid movie for each syllable that is used sufficiently
often (i.e. has at least `rows*cols` instances with duration
of at least `min_duration` and an overall frequency of at least
`min_frequency`). The grid movies are saved to `output_dir` if
specified, or else to `{project_dir}/{model_name}/grid_movies`.
A subset of parameters are documented below. See
:py:func:`keypoint_moseq.viz.grid_movie` for the remaining parameters.
Parameters
----------
results: dict
Dictionary containing modeling results for a dataset (see
:py:func:`keypoint_moseq.fitting.extract_results`)
project_dir: str, default=None
Project directory. Required to save grid movies if `output_dir`
is None.
model_name: str, default=None
Name of the model. Required to save grid movies if
`output_dir` is None.
output_dir: str, default=None
Directory where grid movies should be saved. If None, grid
movies will be saved to `{project_dir}/{model_name}/grid_movies`.
video_dir: str, default=None
Directory containing videos of the modeled data (see
:py:func:`keypoint_moseq.io.find_matching_videos`). Either `video_dir`
or `video_paths` must be provided unless `keypoints_only=True`.
video_paths: dict, default=None
Dictionary mapping recording names to video paths. The recording
names must correspond to keys in the results dictionary. Either
`video_dir` or `video_paths` must be provided unless
`keypoints_only=True`.
video_frame_indexes: dict, default=None
Dictionary mapping recording names to arrays of video frame indexes.
This is useful when the original keypoint coordinates used for modeling
corresponded to a subset of frames from each video (i.e. if videos were
trimmed or coordinates were downsampled).
filter_size: int, default=9
Size of the median filter applied to centroids and headings
min_frequency: float, default=0.005
Minimum frequency of a syllable to be included in the grid movies.
min_duration: int, default=3
Minimum duration of a syllable instance to be included in the
grid movie for that syllable.
sampling_options: dict, default={}
Dictionary of options for sampling syllable instances (see
:py:func:`keypoint_moseq.util.sample_instances`).
coordinates: dict, default=None
Dictionary mapping recording names to keypoint coordinates as
ndarrays of shape (n_frames, n_bodyparts, [2 or 3]). Required when
`window_size=None`, or `overlay_keypoints=True`, or if using
density-based sampling (i.e. when `sampling_options['mode']=='density'`;
see :py:func:`keypoint_moseq.util.sample_instances`).
centroids: dict, default=None
Dictionary mapping recording names to arrays of shape `(n_frames, 2)`.
Overrides the centroid information in `results`.
headings: dict, default=None
Dictionary mapping recording names to arrays of shape `(n_frames,)`.
Overrides the heading information in `results`.
bodyparts: list of str, default=None
List of bodypart names in `coordinates`. Required when `coordinates` is
provided and bodyparts were reindexed for modeling.
use_bodyparts: list of str, default=None
Ordered list of bodyparts used for modeling. Required when
`coordinates` is provided and bodyparts were reindexed
for modeling.
quality: int, default=7
Quality of the grid movies. Higher values result in higher
quality movies but larger file sizes.
pre: float, default=1.0
Time in seconds before syllable onset to include in the grid movie.
This value will be converted to frames using the fps parameter.
post: float, default=2.0
Time in seconds after syllable onset to include in the grid movie.
This value will be converted to frames using the fps parameter.
rows, cols, dot_radius, dot_color: int
See :py:func:`keypoint_moseq.viz.grid_movie`
video_extension: str, default=None
Preferred video extension (passed to
:py:func:`keypoint_moseq.util.find_matching_videos`)
window_size: int, default=None
Size of the window around the animal. If None, the window
size is determined automatically based on the size of the
animal. If provided explicitly, `window_size` should be a
multiple of 16 or imageio will complain.
max_video_size: int, default=4000
Maximum size of the grid movie in pixels. If the grid movie
is larger than this, it will be downsampled.
skeleton: list of tuples, default=[]
List of tuples specifying the skeleton. Used when
`overlay_keypoints=True`.
overlay_keypoints: bool, default=False
Whether to overlay the keypoints on the grid movie.
keypoints_only: bool, default=False
Whether to only show the keypoints (i.e. no video frames).
Overrides `overlay_keypoints`.
keypoints_scale: float, default=1
Factor to scale keypoint coordinates before plotting. Only used when
`keypoints_only=True`. This is useful when the keypoints are 3D and
encoded in units that are larger than a pixel.
fps: int, default=None
Framerate of the videos from which keypoints were derived (required).
plot_options: dict, default={}
Dictionary of options to pass to
:py:func:`keypoint_moseq.viz.overlay_keypoints_on_image`.
use_dims: pair of ints, default=[0,1]
Dimensions to use for plotting keypoints. Only used when
`overlay_keypoints=True` and the keypoints are 3D.
keypoint_colormap: str, default='autumn'
Colormap used to color keypoints. Used when
`overlay_keypoints=True`.
Returns
-------
sampled_instances: dict
Dictionary mapping syllables to lists of instances shown in each in
grid movie (in row-major order), where each instance is specified as a
tuple with the video name, start frame and end frame.
"""
dimension_pairs = [
(0, 1),
(0, 2),
(1, 2),
]
assert (
tuple(use_dims) in dimension_pairs
), f"use_dims must be one of {[list(d) for d in dimension_pairs]}. Received {use_dims}."
# check inputs
assert fps is not None, "Passing None for fps is not supported."
if keypoints_only:
overlay_keypoints = True
else:
assert (video_dir is not None) or (video_paths is not None), fill(
"Either `video_dir` or `video_paths` is required unless `keypoints_only=True`"
)
if window_size is None or overlay_keypoints:
assert coordinates is not None, fill(
"`coordinates` must be provided if `window_size` is None "
"or `overlay_keypoints` is True"
)
pre = round(pre * fps)
post = round(post * fps)
# prepare output directory
output_dir = _get_path(
project_dir, model_name, output_dir, "grid_movies", "output_dir"
)
if not os.path.exists(output_dir):
os.makedirs(output_dir)
print(f"Writing grid movies to {output_dir}")
# reindex coordinates if necessary
if not (bodyparts is None or use_bodyparts is None or coordinates is None):
coordinates = reindex_by_bodyparts(coordinates, bodyparts, use_bodyparts)
# get edges for plotting skeleton
edges = []
if len(skeleton) > 0 and overlay_keypoints:
edges = get_edges(use_bodyparts, skeleton)
# load results
if results is None:
results = load_results(project_dir, model_name)
# extract syllables from results
syllables = {k: v["syllable"] for k, v in results.items()}
# extract and smooth centroids and headings
if centroids is None:
centroids = {k: v["centroid"] for k, v in results.items()}
if headings is None:
headings = {k: v["heading"] for k, v in results.items()}
centroids, headings = filter_centroids_headings(
centroids, headings, filter_size=filter_size
)
# scale keypoints if necessary
if keypoints_only:
for k, v in coordinates.items():
coordinates[k] = v * keypoints_scale
for k, v in centroids.items():
centroids[k] = v * keypoints_scale
# load video readers if necessary
if not keypoints_only:
if video_paths is None:
video_paths = find_matching_videos(
results.keys(),
video_dir,
as_dict=True,
video_extension=video_extension,
)
check_video_paths(video_paths, results.keys())
if video_frame_indexes is None:
video_frame_indexes = {k: np.arange(len(v)) for k, v in syllables.items()}
else:
assert set(video_frame_indexes.keys()) == set(
syllables.keys()
), "The keys of `video_frame_indexes` must match the keys of `results`"
for k, v in syllables.items():
assert len(v) == len(video_frame_indexes[k]), (
"There is a mismatch between the length of `video_frame_indexes` "
f"and the length of modeling results for key {k}."
f"\n\tLength of `video_frame_indexes` = {len(video_frame_indexes[k])}"
f"\n\tLength of modeling results = {len(v)}"
)
else:
video_paths = None
# sample instances for each syllable
syllable_instances = get_syllable_instances(
syllables,
pre=pre,
post=post,
min_duration=min_duration,
min_frequency=min_frequency,
min_instances=rows * cols,
)
if len(syllable_instances) == 0:
warnings.warn(
fill(
"No syllables with sufficient instances to make a grid movie. "
"This usually occurs when all frames have the same syllable label "
"(use `plot_syllable_frequencies` to check if this is the case)"
)
)
return
sampled_instances = sample_instances(
syllable_instances,
rows * cols,
coordinates=coordinates,
centroids=centroids,
headings=headings,
**sampling_options,
)
# determine window size for grid movies
if window_size is None:
window_size = get_grid_movie_window_size(
sampled_instances, centroids, headings, coordinates, pre, post
)
print(f"Using window size of {window_size} pixels")
if keypoints_only:
if window_size < 64:
warnings.warn(
fill(
"The scale of the keypoints is very small. This may result in "
"poor quality grid movies. Try increasing `keypoints_scale`."
)
)
# possibly reduce window size to keep grid movies under max_video_size
scaled_window_size = max_video_size / max(rows, cols)
scaled_window_size = int(np.floor(scaled_window_size / 16) * 16)
scaled_window_size = min(scaled_window_size, window_size)
scale_factor = scaled_window_size / window_size
if scale_factor < 1:
warnings.warn(
"\n"
+ fill(
f"Videos will be downscaled by a factor of {scale_factor:.2f} "
f"so that the grid movies are under {max_video_size} pixels. "
"Use `max_video_size` to increase or decrease this size limit."
)
+ "\n\n"
)
# add colormap to plot options
plot_options.update({"keypoint_colormap": keypoint_colormap})
# generate grid movies
for syllable, instances in tqdm.tqdm(
sampled_instances.items(), desc="Generating grid movies", ncols=72
):
frames = grid_movie(
instances,
rows,
cols,
video_paths,
centroids,
headings,
window_size,
video_frame_indexes,
edges=edges,
scaled_window_size=scaled_window_size,
dot_color=dot_color,
pre=pre,
post=post,
dot_radius=dot_radius,
overlay_keypoints=overlay_keypoints,
coordinates=coordinates,
plot_options=plot_options,
use_dims=use_dims,
)
path = os.path.join(output_dir, f"syllable{syllable}.mp4")
write_video_clip(frames, path, fps=fps, quality=quality)
return sampled_instances
[docs]
def get_limits(
    coordinates,
    pctl=1,
    blocksize=None,
    left=0.2,
    right=0.2,
    top=0.2,
    bottom=0.2,
):
    """Compute padded 2D axis limits that enclose a set of keypoints.

    Along each axis the lower and upper bounds are the `pctl` and
    `100 - pctl` percentiles of the coordinates (NaNs are ignored). Each
    bound is then pushed outward by a fraction of the axis range.

    Parameters
    ----------
    coordinates: ndarray or dict
        Coordinates as an ndarray of shape (..., 2), or a dict
        with values that are ndarrays of shape (..., 2).

    pctl: float, default=1
        Percentile to use for determining the axis limits.

    blocksize: int, default=None
        If given, the limits are rounded and widened symmetrically so
        that the width and height are multiples of `blocksize`. This is
        useful when the limits are used to crop images for a video.

    left, right, top, bottom: float, default=0.2
        Fraction of the axis range to pad on each side. `left`/`right`
        apply to the first coordinate; `bottom`/`top` extend the lower
        and upper bound of the second coordinate respectively.

    Returns
    -------
    lims: ndarray of shape (2, 2)
        Integer axis limits, in the format `[[xmin,ymin],[xmax,ymax]]`.
    """
    if isinstance(coordinates, dict):
        stacked = np.concatenate(list(coordinates.values()))
    else:
        stacked = coordinates
    points = stacked.reshape(-1, 2)

    x_lo, y_lo = np.nanpercentile(points, pctl, axis=0)
    x_hi, y_hi = np.nanpercentile(points, 100 - pctl, axis=0)
    x_range = x_hi - x_lo
    y_range = y_hi - y_lo

    lims = np.array(
        [
            [x_lo - x_range * left, y_lo - y_range * bottom],
            [x_hi + x_range * right, y_hi + y_range * top],
        ]
    )

    if blocksize is not None:
        lims = np.round(lims)
        # (low - high) is negative, so the modulus is the amount still
        # missing to reach the next multiple of `blocksize`; half of it
        # is added on each side.
        shortfall = np.mod(lims[0] - lims[1], blocksize) / 2
        lims = np.ceil(np.stack([lims[0] - shortfall, lims[1] + shortfall]))

    return lims.astype(int)
def rasterize_figure(fig):
    """Render a matplotlib figure and return its pixels as an RGB array.

    Parameters
    ----------
    fig: :py:class:`matplotlib.figure.Figure`
        Figure to rasterize. Its canvas must provide either
        `buffer_rgba()` or the legacy `tostring_rgb()`.

    Returns
    -------
    raster: ndarray of shape (height, width, 3), dtype uint8
        RGB image of the drawn figure. The array does not share memory
        with the canvas, so later redraws leave it unchanged.
    """
    canvas = fig.canvas
    canvas.draw()

    if hasattr(canvas, "buffer_rgba"):
        # `tostring_rgb` was deprecated in Matplotlib 3.8 and subsequently
        # removed, so read the RGBA buffer and drop the alpha channel.
        rgba = np.asarray(canvas.buffer_rgba())
        # Copy: the buffer belongs to the canvas and is overwritten by the
        # next draw, while callers accumulate one raster per frame.
        return rgba[..., :3].astype("uint8").copy()

    # Fallback for canvases that only expose the legacy API.
    width, height = canvas.get_width_height()
    raster_flat = np.frombuffer(canvas.tostring_rgb(), dtype="uint8")
    return raster_flat.reshape((height, width, 3))
[docs]
def plot_trajectories(
    titles,
    Xs,
    lims,
    edges=[],
    n_cols=4,
    invert=False,
    keypoint_colormap="autumn",
    keypoint_colors=None,
    node_size=50.0,
    line_width=3.0,
    alpha=0.2,
    num_timesteps=10,
    plot_width=4,
    overlap=(0.2, 0),
    return_rasters=False,
):
    """Plot one or more pose trajectories on a common axis.

    (See :py:func:`keypoint_moseq.viz.generate_trajectory_plots`)

    Parameters
    ----------
    titles: list of str
        List of titles for each trajectory plot.

    Xs: list of ndarray
        List of pose trajectories as ndarrays of shape
        (n_frames, n_keypoints, 2).

    lims: ndarray
        Axis limits used for all the trajectory plots. The limits
        should be provided as an array of shape (2,2) with the format
        `[[xmin,ymin],[xmax,ymax]]`.

    edges: list of tuples, default=[]
        List of edges, where each edge is a tuple of two integers

    n_cols: int, default=4
        Number of columns in the figure (used when plotting multiple
        trajectories).

    invert: bool, default=False
        Determines the background color of the figure. If `True`,
        the background will be black.

    keypoint_colormap : str
        Name of a matplotlib colormap. Used when `keypoint_colors` is
        None.

    keypoint_colors : array-like, shape=(num_keypoints,3), default=None
        Color for each keypoint. If None, the keypoint colormap is used.
        If the values are integers (Python or NumPy), they are assumed to
        be in the range 0-255, otherwise they are assumed to be in the
        range 0-1.

    node_size: int, default=50
        Size of each keypoint.

    line_width: int, default=3
        Width of the lines connecting keypoints.

    alpha: float, default=0.2
        Opacity of fade-out layers.

    num_timesteps: int, default=10
        Number of timesteps to plot for each trajectory. The pose
        at each timestep is obtained by resampling the trajectory with
        `interpolate_along_axis`.

    plot_width: int, default=4
        Width of each trajectory plot in inches. The height is
        determined by the aspect ratio of `lims`.

    overlap: tuple of float, default=(0.2,0)
        Amount of overlap between each trajectory plot as a tuple
        with the format `(x_overlap, y_overlap)`. The values should
        be between 0 and 1.

    return_rasters: bool, default=False
        Rasterize the matplotlib canvas after plotting each step of
        the trajectory. This is used to generate an animated video/gif
        of the trajectory.

    Returns
    -------
    fig : :py:class:`matplotlib.figure.Figure`
        Figure handle

    ax: matplotlib.axes.Axes
        Axis containing the trajectory plots.

    rasters: list of ndarray
        One rasterized snapshot of the figure per plotted timestep.
        Empty unless `return_rasters=True`.
    """
    fill_color = "k" if invert else "w"

    if keypoint_colors is None:
        cmap = plt.colormaps[keypoint_colormap]
        colors = cmap(np.linspace(0, 1, Xs[0].shape[1]))
    elif isinstance(keypoint_colors[0][0], (int, np.integer)):
        # NumPy integer scalars are not instances of `int`, so they are
        # checked explicitly; integer colors are rescaled from 0-255.
        colors = list(np.array(keypoint_colors) / 255)
    else:
        colors = list(keypoint_colors)

    # lay the trajectories out on a grid, one offset per trajectory
    n_cols = min(n_cols, len(Xs))
    n_rows = np.ceil(len(Xs) / n_cols)
    offsets = np.stack(
        np.meshgrid(
            np.arange(n_cols) * np.diff(lims[:, 0]) * (1 - overlap[0]),
            np.arange(n_rows) * np.diff(lims[:, 1]) * (overlap[1] - 1),
        ),
        axis=-1,
    ).reshape(-1, 2)[: len(Xs)]

    # resample each trajectory to `num_timesteps` poses
    Xs = interpolate_along_axis(
        np.linspace(0, Xs[0].shape[0], num_timesteps),
        np.arange(Xs[0].shape[0]),
        np.array(Xs),
        axis=1,
    )

    Xs = Xs + offsets[:, None, None]
    xmin, ymin = lims[0] + offsets.min(0)
    xmax, ymax = lims[1] + offsets.max(0)

    fig, ax = plt.subplots(frameon=False)
    # opaque background layer
    ax.fill_between(
        [xmin, xmax],
        y1=[ymax, ymax],
        y2=[ymin, ymin],
        facecolor=fill_color,
        zorder=0,
        clip_on=False,
    )

    title_xy = (lims * np.array([[0.5, 0.1], [0.5, 0.9]])).sum(0)
    title_color = "w" if invert else "k"

    for xy, text in zip(offsets + title_xy, titles):
        ax.text(
            *xy,
            text,
            c=title_color,
            ha="center",
            va="top",
            # keep titles above every pose and fade layer
            zorder=Xs.shape[1] * 4 + 4,
        )

    # final extents in axis
    final_width = xmax - xmin
    final_height = title_xy[1] - ymin

    fig_width = plot_width * (n_cols - (n_cols - 1) * overlap[0])
    fig_height = final_height / final_width * fig_width
    fig.set_size_inches((fig_width, fig_height))

    ax.set_xlim(xmin, xmax)
    ax.set_ylim(ymin, ymax)
    ax.set_aspect("equal")
    ax.axis("off")
    plt.tight_layout()

    rasters = []  # for making a gif

    # Each timestep uses four z-levels: edge outlines, colored edges,
    # keypoints, and a translucent fade layer over everything so far.
    for i in range(Xs.shape[1]):
        for X in Xs:
            for ii, jj in edges:
                ax.plot(
                    *X[i, (ii, jj)].T,
                    c="k",
                    zorder=i * 4,
                    linewidth=line_width,
                    clip_on=False,
                )

            for ii, jj in edges:
                ax.plot(
                    *X[i, (ii, jj)].T,
                    c=colors[ii],
                    zorder=i * 4 + 1,
                    linewidth=line_width * 0.9,
                    clip_on=False,
                )

            ax.scatter(
                *X[i].T,
                c=colors,
                zorder=i * 4 + 2,
                edgecolor="k",
                linewidth=0.4,
                s=node_size,
                clip_on=False,
            )

        # the last timestep is left unfaded
        if i < Xs.shape[1] - 1:
            ax.fill_between(
                [xmin, xmax],
                y1=[ymax, ymax],
                y2=[ymin, ymin],
                facecolor=fill_color,
                alpha=alpha,
                zorder=i * 4 + 3,
                clip_on=False,
            )

        if return_rasters:
            rasters.append(rasterize_figure(fig))

    return fig, ax, rasters
def save_gif(image_list, gif_filename, duration=0.5):
    """Write a sequence of images to disk as a looping animated GIF.

    Parameters
    ----------
    image_list: list of ndarray
        Frames of the animation. Each frame is cast to uint8 before
        being converted to a PIL image.

    gif_filename: str
        Path of the GIF file to write.

    duration: float, default=0.5
        Display time of each frame in seconds (passed to PIL in
        milliseconds).
    """
    frames = []
    for array in image_list:
        frames.append(Image.fromarray(np.uint8(array)))

    # PIL writes an animation by saving the first frame and appending
    # the remaining ones.
    first_frame = frames[0]
    first_frame.save(
        gif_filename,
        save_all=True,
        append_images=frames[1:],
        duration=int(duration * 1000),
        loop=0,
    )
[docs]
def generate_trajectory_plots(
    coordinates,
    results,
    project_dir=None,
    model_name=None,
    output_dir=None,
    pre=0.167,  # 5 frames at 30 fps
    post=0.5,  # 15 frames at 30 fps
    min_frequency=0.005,
    min_duration=3,
    skeleton=[],
    bodyparts=None,
    use_bodyparts=None,
    keypoint_colormap="autumn",
    plot_options={},
    get_limits_pctl=0,
    padding={"left": 0.1, "right": 0.1, "top": 0.2, "bottom": 0.2},
    lims=None,
    save_individually=True,
    save_gifs=True,
    save_mp4s=False,
    fps=None,
    projection_planes=["xy", "xz"],
    interactive=True,
    density_sample=True,
    sampling_options={"n_neighbors": 50},
    **kwargs,
):
    """
    Generate trajectory plots for a modeled dataset.

    Each trajectory plot shows a sequence of poses along the average
    trajectory through latent space associated with a given syllable.
    A separate figure (and gif, optionally) is saved for each syllable,
    along with a single figure showing all syllables in a grid. The
    plots are saved to `{output_dir}` if it is provided, otherwise
    they are saved to `{project_dir}/{model_name}/trajectory_plots`.

    Plot-related parameters are described below. For the remaining
    parameters see (:py:func:`keypoint_moseq.util.get_typical_trajectories`)

    Parameters
    ----------
    coordinates: dict
        Dictionary mapping recording names to keypoint coordinates as
        ndarrays of shape (n_frames, n_bodyparts, [2 or 3]).

    results: dict
        Dictionary containing modeling results for a dataset (see
        :py:func:`keypoint_moseq.fitting.extract_results`).

    project_dir: str, default=None
        Project directory. Required to save trajectory plots if
        `output_dir` is None.

    model_name: str, default=None
        Name of the model. Required to save trajectory plots if
        `output_dir` is None.

    output_dir: str, default=None
        Directory where trajectory plots should be saved. If None,
        plots will be saved to `{project_dir}/{model_name}/trajectory_plots`.

    pre: float, default=0.167
        Time in seconds before syllable onset to include in the trajectory
        plots. This value will be converted to frames using the fps
        parameter.

    post: float, default=0.5
        Time in seconds after syllable onset to include in the trajectory
        plots. This value will be converted to frames using the fps
        parameter.

    skeleton : list, default=[]
        List of edges that define the skeleton, where each edge is a
        pair of bodypart names or a pair of indexes.

    keypoint_colormap : str
        Name of a matplotlib colormap to use for coloring the keypoints.
        Overrides any `keypoint_colormap` entry in `plot_options`.

    plot_options: dict, default={}
        Dictionary of options for trajectory plots (see
        :py:func:`keypoint_moseq.viz.plot_trajectories`). The dictionary
        is not modified.

    get_limits_pctl: float, default=0
        Percentile to use for determining the axis limits. Higher values
        lead to tighter axis limits.

    padding: dict, default={'left':0.1, 'right':0.1, 'top':0.2, 'bottom':0.2}
        Padding around trajectory plots. Controls the distance
        between trajectories (when multiple are shown in one figure)
        as well as the title offset.

    lims: ndarray of shape (2,2), default=None
        Axis limits used for all the trajectory plots with format
        `[[xmin,ymin],[xmax,ymax]]` (shape (2,3) for 3D data). If None,
        the limits are determined automatically based on the coordinates
        of the keypoints using :py:func:`keypoint_moseq.viz.get_limits`.

    save_individually: bool, default=True
        If True, a separate figure is saved for each syllable (in
        addition to the grid figure).

    save_gifs: bool, default=True
        Whether to save an animated gif of the trajectory plots.

    save_mp4s: bool, default=False
        Whether to save videos of the trajectory plots as .mp4 files

    fps: int, default=None
        Framerate of the videos from which keypoints were derived
        (required).

    projection_planes: list (subset of ['xy', 'yz', 'xz']), default=['xy','xz']
        For 3D data, defines the 2D plane(s) on which to project keypoint
        coordinates. A separate plot will be saved for each plane with
        the name of the plane (e.g. 'xy') as a suffix. This argument is
        ignored for 2D data.

    interactive: bool, default=True
        For 3D data, whether to create a visualization that can be
        rotated and zoomed. This argument is ignored for 2D data.
    """
    assert fps is not None, "Passing None for fps is not supported."
    pre = round(pre * fps)
    post = round(post * fps)

    # Merge into a new dict instead of calling `plot_options.update`, which
    # would mutate the caller's dict and the shared default argument.
    plot_options = {**plot_options, "keypoint_colormap": keypoint_colormap}
    edges = [] if len(skeleton) == 0 else get_edges(use_bodyparts, skeleton)

    output_dir = _get_path(
        project_dir, model_name, output_dir, "trajectory_plots", "output_dir"
    )
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    print(f"Saving trajectory plots to {output_dir}")

    typical_trajectories = get_typical_trajectories(
        coordinates,
        results,
        pre,
        post,
        min_frequency,
        min_duration,
        bodyparts,
        use_bodyparts,
        density_sample,
        sampling_options,
    )

    syllable_ixs = sorted(typical_trajectories.keys())
    titles = [f"Syllable{s}" for s in syllable_ixs]
    Xs = np.stack([typical_trajectories[s] for s in syllable_ixs])

    if Xs.shape[-1] == 3:
        # 3D data: build one 2D projection (and limits) per requested plane
        projection_planes = [
            "".join(sorted(plane.lower())) for plane in projection_planes
        ]
        assert set(projection_planes) <= set(["xy", "yz", "xz"]), fill(
            "`projection_planes` must be a subset of `['xy','yz','xz']`"
        )
        if lims is not None:
            assert lims.shape == (2, 3), fill(
                "`lims` must be None or an ndarray of shape (2,3) when plotting 3D data"
            )
        all_Xs, all_lims, suffixes = [], [], []
        for plane in projection_planes:
            use_dims = {"xy": [0, 1], "yz": [1, 2], "xz": [0, 2]}[plane]
            all_Xs.append(Xs[..., use_dims])
            suffixes.append("." + plane)
            if lims is None:
                all_lims.append(
                    get_limits(all_Xs[-1], pctl=get_limits_pctl, **padding)
                )
            else:
                all_lims.append(lims[..., use_dims])
    else:
        all_Xs = [Xs * np.array([1, -1])]  # flip y-axis
        if lims is None:
            lims = get_limits(all_Xs[-1], pctl=get_limits_pctl, **padding)
        all_lims = [lims]
        suffixes = [""]

    make_rasters = save_gifs or save_mp4s

    for Xs_2D, plane_lims, suffix in zip(all_Xs, all_lims, suffixes):
        # individual plots
        if save_individually:
            desc = "Generating trajectory plots"
            for title, X in tqdm.tqdm(
                zip(titles, Xs_2D), desc=desc, total=len(titles), ncols=72
            ):
                fig, ax, rasters = plot_trajectories(
                    [title],
                    X[None],
                    plane_lims,
                    edges=edges,
                    return_rasters=make_rasters,
                    **plot_options,
                )
                plt.savefig(os.path.join(output_dir, f"{title}{suffix}.pdf"))
                plt.close(fig=fig)

                # spread the rasters evenly over the (pre + post) window
                if save_gifs:
                    frame_duration = (pre + post) / len(rasters) / fps
                    path = os.path.join(output_dir, f"{title}{suffix}.gif")
                    save_gif(rasters, path, duration=frame_duration)
                if save_mp4s:
                    use_fps = len(rasters) / (pre + post) * fps
                    path = os.path.join(output_dir, f"{title}{suffix}.mp4")
                    write_video_clip(rasters, path, fps=use_fps)

        # grid plot
        fig, ax, rasters = plot_trajectories(
            titles,
            Xs_2D,
            plane_lims,
            edges=edges,
            return_rasters=make_rasters,
            **plot_options,
        )
        plt.savefig(os.path.join(output_dir, f"all_trajectories{suffix}.pdf"))
        plt.show()

        if save_gifs:
            frame_duration = (pre + post) / len(rasters) / fps
            path = os.path.join(output_dir, f"all_trajectories{suffix}.gif")
            save_gif(rasters, path, duration=frame_duration)
        if save_mp4s:
            use_fps = len(rasters) / (pre + post) * fps
            path = os.path.join(output_dir, f"all_trajectories{suffix}.mp4")
            write_video_clip(rasters, path, fps=use_fps)

    if interactive and Xs.shape[-1] == 3:
        plot_trajectories_3D(Xs, titles, edges, output_dir, **plot_options)
[docs]
def overlay_keypoints_on_image(
    image,
    coordinates,
    edges=[],
    keypoint_colormap="autumn",
    keypoint_colors=None,
    node_size=5,
    line_width=2,
    copy=False,
    opacity=1.0,
):
    """Overlay keypoints on an image.

    Parameters
    ----------
    image: ndarray of shape (height, width, 3)
        Image to overlay keypoints on.

    coordinates: ndarray of shape (num_keypoints, 2)
        Array of keypoint coordinates. Keypoints with NaN coordinates
        (and edges touching them) are skipped.

    edges: list of tuples, default=[]
        List of edges that define the skeleton, where each edge is a
        pair of indexes.

    keypoint_colormap: str, default='autumn'
        Name of a matplotlib colormap to use for coloring the keypoints.

    keypoint_colors : array-like, shape=(num_keypoints,3), default=None
        Color for each keypoint. If None, the keypoint colormap is used.
        If the dtype is int, the values are assumed to be in the range
        0-255, otherwise they are assumed to be in the range 0-1.

    node_size: int, default=5
        Size of the keypoints.

    line_width: int, default=2
        Width of the skeleton lines.

    copy: bool, default=False
        Whether to copy the image before overlaying keypoints. If False
        and `opacity` is 1.0, the keypoints are drawn onto `image` itself.

    opacity: float, default=1.0
        Opacity of the overlay graphics (0.0-1.0).

    Returns
    -------
    image: ndarray of shape (height, width, 3)
        Image with keypoints overlayed.
    """
    if copy or opacity < 1.0:
        canvas = image.copy()
    else:
        canvas = image

    if keypoint_colors is None:
        cmap = plt.colormaps[keypoint_colormap]
        colors = np.array(cmap(np.linspace(0, 1, coordinates.shape[0])))[:, :3]
    else:
        colors = np.array(keypoint_colors)

    # Float colors (any float dtype) are in the range 0-1 and are rescaled.
    # Every color is then turned into a tuple of plain Python ints, which
    # is the form the cv2 drawing calls below are given.
    if np.issubdtype(colors.dtype, np.floating):
        colors = colors * 255
    colors = [tuple(int(c) for c in cs) for cs in colors]

    # overlay skeleton
    for i, j in edges:
        if np.isnan(coordinates[i, 0]) or np.isnan(coordinates[j, 0]):
            continue
        pos1 = (int(coordinates[i, 0]), int(coordinates[i, 1]))
        pos2 = (int(coordinates[j, 0]), int(coordinates[j, 1]))
        canvas = cv2.line(canvas, pos1, pos2, colors[i], line_width, cv2.LINE_AA)

    # overlay keypoints
    for i, (x, y) in enumerate(coordinates):
        if np.isnan(x) or np.isnan(y):
            continue
        pos = (int(x), int(y))
        canvas = cv2.circle(
            canvas, pos, node_size, colors[i], -1, lineType=cv2.LINE_AA
        )

    if opacity < 1.0:
        # blend the drawn copy with the untouched input
        return cv2.addWeighted(image, 1 - opacity, canvas, opacity, 0)

    # Return the array that was drawn on. Returning `image` here would
    # discard the overlay whenever `copy=True`.
    return canvas
[docs]
def overlay_keypoints_on_video(
    video_path,
    coordinates,
    skeleton=[],
    bodyparts=None,
    use_bodyparts=None,
    output_path=None,
    show_frame_numbers=True,
    text_color=(255, 255, 255),
    crop_size=None,
    frames=None,
    quality=7,
    centroid_smoothing_filter=10,
    plot_options={},
    video_frame_indexes=None,
):
    """Overlay keypoints on a video.

    Parameters
    ----------
    video_path: str
        Path to a video file.

    coordinates: ndarray of shape (num_frames, num_keypoints, 2)
        Array of keypoint coordinates.

    skeleton: list of tuples, default=[]
        List of edges that define the skeleton, where each edge is a pair
        of bodypart names or a pair of indexes.

    bodyparts: list of str, default=None
        List of bodypart names in `coordinates`. Required if `skeleton` is
        defined using bodypart names.

    use_bodyparts: list of str, default=None
        Subset of bodyparts to plot. If None, all bodyparts are plotted.

    output_path: str, default=None
        Path to save the video. If None, the video is saved to
        `video_path` with the suffix `_keypoints`.

    show_frame_numbers: bool, default=True
        Whether to overlay the frame number in the video.

    text_color: tuple of int, default=(255,255,255)
        Color for the frame number overlay.

    crop_size: int, default=None
        Size of the crop around the keypoints to overlay on the video. If
        None, the entire video is used.

    frames: iterable of int, default=None
        Frames to overlay keypoints on (in the numbering of
        `coordinates`). If None, all of `coordinates` is used (limited to
        the length of the video when `video_frame_indexes` is None). This
        option can be used in conjunction with `video_frame_indexes` when
        the entries of `coordinates` do not correspond one-to-one with
        frames of the video.

    quality: int, default=7
        Quality of the output video.

    centroid_smoothing_filter: int, default=10
        Amount of smoothing to determine cropping centroid.

    plot_options: dict, default={}
        Additional keyword arguments to pass to
        :py:func:`keypoint_moseq.viz.overlay_keypoints_on_image`.

    video_frame_indexes: array, default=None
        Video frames corresponding to the entries of `coordinates`. If
        None, it is assumed that the i'th entry of `coordinates`
        corresponds to the i'th video frame.
    """
    if output_path is None:
        output_path = os.path.splitext(video_path)[0] + "_keypoints.mp4"
    print(f"Saving video to {output_path}")

    if bodyparts is not None:
        if use_bodyparts is not None:
            coordinates = reindex_by_bodyparts(coordinates, bodyparts, use_bodyparts)
        else:
            use_bodyparts = bodyparts
        edges = get_edges(use_bodyparts, skeleton)
    else:
        edges = skeleton

    if crop_size is not None:
        # center the crop on the smoothed median keypoint position
        outliers = np.any(np.isnan(coordinates), axis=2)
        interpolated_coordinates = interpolate_keypoints(coordinates, outliers)
        crop_centroid = np.nanmedian(interpolated_coordinates, axis=1)
        crop_centroid = gaussian_filter1d(
            crop_centroid, centroid_smoothing_filter, axis=0
        )

    reader = OpenCVReader(video_path)
    fps = reader.fps

    if video_frame_indexes is None:
        # one-to-one mapping: never run past the shorter of the two
        video_frame_indexes = np.arange(len(coordinates))
        num_default_frames = min(len(reader), len(coordinates))
    else:
        num_default_frames = len(coordinates)

    if frames is None:
        # `frames` indexes `coordinates`, not the video
        frames = np.arange(num_default_frames)

    with imageio.get_writer(
        output_path, pixelformat="yuv420p", fps=fps, quality=quality
    ) as writer:
        for frame in tqdm.tqdm(frames, ncols=72):
            # translate the coordinate index into the matching video frame
            image = overlay_keypoints_on_image(
                reader[video_frame_indexes[frame]],
                coordinates[frame],
                edges=edges,
                **plot_options,
            )

            if crop_size is not None:
                image = crop_image(image, crop_centroid[frame], crop_size)

            if show_frame_numbers:
                image = cv2.putText(
                    image,
                    f"Frame {frame}",
                    (10, 20),
                    cv2.FONT_HERSHEY_SIMPLEX,
                    0.5,
                    text_color,
                    1,
                    cv2.LINE_AA,
                )

            writer.append_data(image)
[docs]
def add_3D_pose_to_plotly_fig(
    fig,
    coords,
    edges,
    keypoint_colors,
    node_size=50.0,
    line_width=3.0,
    visible=True,
    opacity=1,
):
    """
    Add a 3D pose to a plotly figure.

    One marker trace is added for the keypoints, followed by one line
    trace per edge.

    Parameters
    ----------
    fig: plotly figure
        Figure to which the pose should be added.

    coords: ndarray (N,3)
        3D coordinates of the pose.

    edges: list of index pairs
        Skeleton edges

    keypoint_colors : array-like with shape (num_keypoints,3)
        Color for each keypoint. If None, all keypoints are drawn in red.
        If the dtype is int, the values are assumed to be in the range
        0-255, otherwise they are assumed to be in the range 0-1.

    node_size: float, default=50.0
        Size of keypoints. The plotly marker size is `node_size / 10`.

    line_width: float, default=3.0
        Width of skeleton edges.

    visible: bool, default=True
        Initial visibility state of the nodes and edges

    opacity: float, default=1
        Opacity of the nodes and edges (0-1)
    """
    # Normalize the colors before any trace is built. Doing this after the
    # traces are added has no effect.
    if keypoint_colors is None:
        keypoint_colors = ["red"] * len(coords)
    else:
        keypoint_colors = np.asarray(keypoint_colors)
        # the dtype test also covers NumPy integer arrays and nested lists
        if np.issubdtype(keypoint_colors.dtype, np.integer):
            keypoint_colors = keypoint_colors / 255.0

    marker = {
        "size": node_size / 10,
        "color": keypoint_colors,
        "line": dict(color="black", width=0.5),
        "opacity": opacity,
    }

    line = {"width": line_width, "color": f"rgba(0,0,0,{opacity})"}

    fig.add_trace(
        plotly.graph_objs.Scatter3d(
            x=coords[:, 0],
            y=coords[:, 1],
            z=coords[:, 2],
            mode="markers",
            visible=visible,
            marker=marker,
        )
    )
    for e in edges:
        fig.add_trace(
            plotly.graph_objs.Scatter3d(
                x=coords[e, 0],
                y=coords[e, 1],
                z=coords[e, 2],
                mode="lines",
                visible=visible,
                line=line,
            )
        )
def plot_pcs_3D(
    ymean,
    ypcs,
    edges,
    keypoint_colors,
    savefig,
    project_dir=None,
    node_size=50,
    line_width=3,
    height=400,
    mean_pose_opacity=0.2,
):
    """
    Show the components of a fitted PCA model as interactive 3D poses.

    The figure contains the mean pose (semi-transparent) and, for each PC,
    a perturbation of the mean pose in the direction of that PC. A slider
    selects which PC is displayed alongside the mean pose.

    Parameters
    ----------
    ymean : ndarray (num_bodyparts, 3)
        Mean pose.

    ypcs : ndarray (num_pcs, num_bodyparts, 3)
        Perturbations of the mean pose in the direction of each PC.

    edges : list of index pairs
        Skeleton edges.

    keypoint_colors : array-like, shape=(num_keypoints,3)
        Color for each keypoint, forwarded to
        `add_3D_pose_to_plotly_fig`.

    savefig : bool
        Whether to save the figure to a file. If true, the figure is
        saved to `{project_dir}/pcs.html`

    project_dir : str, default=None
        Path to the project directory. Required if `savefig` is True.

    node_size : float, default=50
        Size of the keypoints in the figure.

    line_width: float, default=3
        Width of edges in skeleton

    height : int, default=400
        Height of the figure in pixels.

    mean_pose_opacity: float, default=0.2
        Opacity of the mean pose
    """
    from plotly.subplots import make_subplots

    fig = make_subplots(rows=1, cols=1, specs=[[{"type": "scatter3d"}]])

    # each pose contributes one marker trace plus one trace per edge
    traces_per_pose = len(edges) + 1
    num_poses = len(ypcs) + 1  # one pose per PC, then the mean pose

    def visibility_mask(i):
        mask = np.zeros(traces_per_pose * num_poses, dtype=bool)
        # the mean pose is added last and stays visible for every PC
        mask[-traces_per_pose:] = True
        mask[traces_per_pose * i : traces_per_pose * (i + 1)] = True
        return mask

    pose_style = dict(node_size=node_size, line_width=line_width)

    for pc_index, pc_pose in enumerate(ypcs):
        add_3D_pose_to_plotly_fig(
            fig,
            pc_pose,
            edges,
            keypoint_colors,
            visible=(pc_index == 0),
            **pose_style,
        )

    add_3D_pose_to_plotly_fig(
        fig,
        ymean,
        edges,
        keypoint_colors,
        opacity=mean_pose_opacity,
        **pose_style,
    )

    steps = [
        dict(
            method="update",
            label=f"PC {pc_index+1}",
            args=[{"visible": visibility_mask(pc_index)}],
        )
        for pc_index in range(len(ypcs))
    ]

    hidden_axis = dict(showgrid=False, showbackground=False)
    fig.update_layout(
        height=height,
        showlegend=False,
        sliders=[dict(steps=steps)],
        scene=dict(
            xaxis=dict(hidden_axis),
            yaxis=dict(hidden_axis),
            zaxis=dict(showgrid=False, showline=True, linecolor="black"),
            bgcolor="white",
            aspectmode="data",
        ),
        margin=dict(l=20, r=20, b=0, t=0, pad=10),
    )

    if savefig:
        assert project_dir is not None, fill(
            "The `savefig` option requires a `project_dir`"
        )
        save_path = os.path.join(project_dir, "pcs.html")
        fig.write_html(save_path)
        print(f"Saved interactive plot to {save_path}")

    fig.show()
def plot_trajectories_3D(
    Xs,
    titles,
    edges,
    output_dir,
    keypoint_colormap="autumn",
    keypoint_colors=None,
    node_size=50.0,
    line_width=3.0,
    height=500,
    skiprate=1,
):
    """
    Visualize a set of 3D trajectories.

    All poses of every trajectory are added to one interactive figure; a
    slider selects which trajectory is visible.

    Parameters
    ----------
    Xs : ndarray (num_syllables, num_frames, num_bodyparts, 3)
        Trajectories to visualize.

    titles : list of str
        Title for each trajectory.

    edges : list of index pairs
        Skeleton edges.

    output_dir : str
        Directory in which `all_trajectories.html` is saved. If None, the
        plot is shown but not saved.

    keypoint_colormap : str, default='autumn'
        Name of a matplotlib colormap to use for coloring the keypoints.

    keypoint_colors : array-like, shape=(num_keypoints,3), default=None
        Color for each keypoint. If None, the keypoint colormap is used.
        If the dtype is int, the values are assumed to be in the range
        0-255, otherwise they are assumed to be in the range 0-1.

    node_size : float, default=50.0
        Size of the keypoints in the figure.

    line_width: float, default=3.0
        Width of edges in skeleton

    height : int, default=500
        Height of the figure in pixels.

    skiprate : int, default=1
        Plot every `skiprate` frames.
    """
    from plotly.subplots import make_subplots

    if keypoint_colors is None:
        cmap = plt.colormaps[keypoint_colormap]
        keypoint_colors = np.array(cmap(np.linspace(0, 1, Xs.shape[2])))[:, :3]

    fig = make_subplots(rows=1, cols=1, specs=[[{"type": "scatter3d"}]])
    Xs = Xs[:, ::skiprate]

    # Each pose adds one marker trace plus one trace per edge. The frame
    # count is read from the array shape rather than from `Xs[1]`, which
    # does not exist when a single trajectory is plotted.
    traces_per_trajectory = (len(edges) + 1) * Xs.shape[1]

    def visibility_mask(i):
        visible = np.zeros(traces_per_trajectory * len(Xs), dtype=bool)
        visible[traces_per_trajectory * i : traces_per_trajectory * (i + 1)] = True
        return visible

    steps = []
    for i, X in enumerate(Xs):
        # later poses are drawn more opaque than earlier ones
        opacities = np.linspace(0.3, 1, len(X) + 1)[1:] ** 2
        for coords, opacity in zip(X, opacities):
            add_3D_pose_to_plotly_fig(
                fig,
                coords,
                edges,
                keypoint_colors,
                visible=(i == 0),
                node_size=node_size,
                line_width=line_width,
                opacity=opacity,
            )
        steps.append(
            dict(
                method="update",
                label=titles[i],
                args=[{"visible": visibility_mask(i)}],
            )
        )

    fig.update_layout(
        height=height,
        showlegend=False,
        sliders=[dict(steps=steps)],
        scene=dict(
            xaxis=dict(showgrid=False, showbackground=False),
            yaxis=dict(showgrid=False, showbackground=False),
            zaxis=dict(showgrid=False, showline=True, linecolor="black"),
            bgcolor="white",
            aspectmode="data",
        ),
        margin=dict(l=20, r=20, b=0, t=0, pad=10),
    )

    if output_dir is not None:
        save_path = os.path.join(output_dir, "all_trajectories.html")
        fig.write_html(save_path)
        print(f"Saved interactive trajectories plot to {save_path}")

    fig.show()
[docs]
def plot_similarity_dendrogram(
    coordinates,
    results,
    project_dir=None,
    model_name=None,
    save_path=None,
    metric="cosine",
    pre=0.167,
    post=0.5,
    min_frequency=0.005,
    min_duration=3,
    bodyparts=None,
    use_bodyparts=None,
    density_sample=False,
    sampling_options={"n_neighbors": 50},
    figsize=(6, 3),
    fps=None,
    **kwargs,
):
    """Plot a dendrogram showing the similarity between syllable trajectories.

    The dendrogram is saved to `{save_path}` if it is provided, or
    else to `{project_dir}/{model_name}/similarity_dendrogram.pdf`. Plot-
    related parameters are described below. For the remaining parameters
    see (:py:func:`keypoint_moseq.util.get_typical_trajectories`)

    Parameters
    ----------
    coordinates: dict
        Dictionary mapping recording names to keypoint coordinates as
        ndarrays of shape (n_frames, n_bodyparts, [2 or 3]).

    results: dict
        Dictionary containing modeling results for a dataset (see
        :py:func:`keypoint_moseq.fitting.extract_results`).

    project_dir: str, default=None
        Project directory. Required to save figure if `save_path` is None.

    model_name: str, default=None
        Model name. Required to save figure if `save_path` is None.

    save_path: str, default=None
        Path to save the dendrogram plot (do not include an extension).
        If None, the plot will be saved to
        `{project_dir}/{name}/similarity_dendrogram.[pdf/png]`.

    metric: str, default='cosine'
        Distance metric to use. See :py:func:`scipy.spatial.pdist` for options.

    pre, post: float, default=0.167, 0.5
        Window bounds; each is multiplied by `fps` and rounded to a whole
        number of frames before being passed to `syllable_similarity`.

    figsize: tuple of float, default=(6,3)
        Size of the dendrogram plot.

    fps: int, default=None
        Framerate of the videos from which keypoints were derived.
        Must be specified, typically using the project config.

    **kwargs
        Accepted but not used by this function.

    Raises
    ------
    ValueError
        If `fps` is None.
    """
    # `fps` defaults to None only so that it can be given by keyword; without
    # it the frame-window conversion below is impossible, so fail clearly.
    if fps is None:
        raise ValueError(
            "`fps` must be specified (typically taken from the project config)."
        )

    pre = round(pre * fps)
    post = round(post * fps)

    # Hand a copy downstream so the shared default dict (or the caller's dict)
    # can never be modified through this call.
    if sampling_options is not None:
        sampling_options = dict(sampling_options)

    save_path = _get_path(project_dir, model_name, save_path, "similarity_dendrogram")

    distances, syllable_ixs = syllable_similarity(
        coordinates,
        results,
        metric,
        pre,
        post,
        min_frequency,
        min_duration,
        bodyparts,
        use_bodyparts,
        density_sample,
        sampling_options,
    )

    # `linkage` expects a condensed distance vector, hence `squareform`.
    Z = linkage(squareform(distances), "complete")

    fig, ax = plt.subplots(1, 1)
    labels = [f"Syllable {s}" for s in syllable_ixs]
    dendrogram(Z, labels=labels, leaf_font_size=10, ax=ax, leaf_rotation=90)
    ax.set_yticks([])
    for spine in ax.spines.values():
        spine.set_color("lightgray")
    ax.set_title("Syllable similarity")
    fig.set_size_inches(figsize)

    print(f"Saving dendrogram plot to {save_path}")
    for ext in ["pdf", "png"]:
        # Save through the figure handle rather than pyplot's "current figure".
        fig.savefig(save_path + "." + ext)
[docs]
def matplotlib_colormap_to_plotly(cmap):
    """
    Convert a matplotlib colormap to a plotly colormap.

    Parameters
    ----------
    cmap: str
        Name of a matplotlib colormap.

    Returns
    -------
    pl_colorscale: list
        Plotly colormap: a list of `[position, "rgb(r, g, b)"]` entries with
        255 evenly spaced positions starting at 0 and 8-bit color components.
    """
    colormap = plt.colormaps[cmap]
    pl_entries = 255
    h = 1.0 / (pl_entries - 1)
    pl_colorscale = []
    for k in range(pl_entries):
        C = (np.array(colormap(k * h)[:3]) * 255).astype(np.uint8)
        # Cast to built-in ints before formatting: under NumPy >= 2 the repr
        # of a numpy scalar is e.g. "np.uint8(255)", so stringifying a tuple
        # of them would not produce a valid "rgb(...)" color string.
        r, g, b = (int(c) for c in C)
        pl_colorscale.append([k * h, f"rgb({r}, {g}, {b})"])
    return pl_colorscale
[docs]
def initialize_3D_plot(height=500):
    """Create an empty 3D plotly figure.

    Parameters
    ----------
    height: int, default=500
        Height of the figure in pixels.

    Returns
    -------
    fig: plotly figure
        Figure with a single `scatter3d` subplot, no legend, a white scene
        background and grid lines disabled on all three axes.
    """
    scene_layout = {
        "xaxis": {"showgrid": False, "showbackground": False},
        "yaxis": {"showgrid": False, "showbackground": False},
        "zaxis": {"showgrid": False, "showline": True, "linecolor": "black"},
        "bgcolor": "white",
        "aspectmode": "data",
    }
    margins = {"l": 20, "r": 20, "b": 0, "t": 0, "pad": 10}

    fig = make_subplots(rows=1, cols=1, specs=[[{"type": "scatter3d"}]])
    fig.update_layout(
        height=height,
        showlegend=False,
        scene=scene_layout,
        margin=margins,
    )
    return fig
[docs]
def add_3D_pose_to_fig(
    fig,
    coords,
    edges,
    keypoint_colormap="autumn",
    node_size=6.0,
    linewidth=3.0,
    visible=True,
    opacity=1,
):
    """Add a 3D pose to a plotly figure.

    One marker trace holding all keypoints is added first, followed by one
    line trace per skeleton edge, i.e. `len(edges) + 1` traces in total.

    Parameters
    ----------
    fig: plotly figure
        Figure to which the pose should be added.

    coords: ndarray (N,3)
        3D coordinates of the pose.

    edges: list of index pairs
        Skeleton edges

    keypoint_colormap: str, default='autumn'
        Colormap to use for coloring keypoints.

    node_size: float, default=6.0
        Size of keypoints.

    linewidth: float, default=3.0
        Width of skeleton edges.

    visible: bool, default=True
        Initial visibility state of the nodes and edges

    opacity: float, default=1
        Opacity of the nodes and edges (0-1)
    """
    Scatter3d = plotly.graph_objs.Scatter3d

    # Keypoints are colored by their index, spread evenly over the colormap.
    node_style = dict(
        size=node_size,
        color=np.linspace(0, 1, len(coords)),
        colorscale=matplotlib_colormap_to_plotly(keypoint_colormap),
        line={"color": "black", "width": 0.5},
        opacity=opacity,
    )
    # Edges are black; their opacity is carried by the alpha channel.
    edge_style = dict(width=linewidth, color=f"rgba(0,0,0,{opacity})")

    nodes = Scatter3d(
        x=coords[:, 0],
        y=coords[:, 1],
        z=coords[:, 2],
        mode="markers",
        visible=visible,
        marker=node_style,
    )
    fig.add_trace(nodes)

    for edge in edges:
        segment = Scatter3d(
            x=coords[edge, 0],
            y=coords[edge, 1],
            z=coords[edge, 2],
            mode="lines",
            visible=visible,
            line=edge_style,
        )
        fig.add_trace(segment)
[docs]
def plot_pcs_3D(
    ymean,
    ypcs,
    edges,
    keypoint_colormap,
    project_dir=None,
    node_size=6,
    linewidth=2,
    height=400,
    mean_pose_opacity=0.2,
):
    """
    Visualize the components of a fitted PCA model based on 3D components.

    A slider selects the PC to display. For the selected PC the figure shows
    the mean pose (semi-transparent) along with a perturbation of the mean
    pose in the direction of that PC.

    Parameters
    ----------
    ymean : ndarray (num_bodyparts, 3)
        Mean pose.

    ypcs : ndarray (num_pcs, num_bodyparts, 3)
        Perturbations of the mean pose in the direction of each PC.

    edges : list of index pairs
        Skeleton edges.

    keypoint_colormap : str
        Name of a matplotlib colormap to use for coloring the keypoints.

    project_dir : str, default=None
        Path to the project directory. If given, the interactive figure is
        written to `{project_dir}/pcs.html`.

    node_size : float, default=6
        Size of the keypoints in the figure.

    linewidth: float, default=2
        Width of edges in skeleton

    height : int, default=400
        Height of the figure in pixels.

    mean_pose_opacity: float, default=0.2
        Opacity of the mean pose
    """
    fig = initialize_3D_plot(height)

    num_pcs = len(ypcs)
    # Every pose contributes one marker trace plus one trace per edge.
    traces_per_pose = len(edges) + 1
    shared_style = dict(
        node_size=node_size,
        linewidth=linewidth,
        keypoint_colormap=keypoint_colormap,
    )

    # Traces are laid out as [PC 1, PC 2, ..., PC n, mean pose].
    for pc_index, pc_pose in enumerate(ypcs):
        add_3D_pose_to_fig(
            fig,
            pc_pose,
            edges,
            visible=(pc_index == 0),
            **shared_style,
        )
    add_3D_pose_to_fig(
        fig,
        ymean,
        edges,
        opacity=mean_pose_opacity,
        **shared_style,
    )

    # `pose_of_trace[t]` is the index of the pose that trace `t` belongs to;
    # the mean pose has index `num_pcs` and stays visible at every step.
    pose_of_trace = np.arange(traces_per_pose * (num_pcs + 1)) // traces_per_pose
    steps = [
        dict(
            method="update",
            label=f"PC {pc_index+1}",
            args=[
                {"visible": (pose_of_trace == pc_index) | (pose_of_trace == num_pcs)}
            ],
        )
        for pc_index in range(num_pcs)
    ]
    fig.update_layout(sliders=[dict(steps=steps)])

    if project_dir is not None:
        save_path = os.path.join(project_dir, "pcs.html")
        fig.write_html(save_path)
        print(f"Saved interactive plot to {save_path}")

    fig.show()
[docs]
def plot_trajectories_3D(
    Xs,
    titles,
    edges,
    output_dir,
    keypoint_colormap="autumn",
    node_size=8,
    linewidth=3,
    height=500,
    skiprate=1,
):
    """
    Visualize a set of 3D trajectories.

    A slider selects which trajectory is displayed. Within a trajectory the
    plotted frames are overlaid with increasing opacity (later frames are
    more opaque).

    Parameters
    ----------
    Xs : list of ndarrays (num_syllables, num_frames, num_bodyparts, 3)
        Trajectories to visualize. Converted with `np.asarray`, so all
        trajectories must share the same shape.

    titles : list of str
        Title for each trajectory.

    edges : list of index pairs
        Skeleton edges.

    output_dir : str
        Path to save the interactive plot. If None, the plot is not saved.

    keypoint_colormap : str, default='autumn'
        Name of a matplotlib colormap to use for coloring the keypoints.

    node_size : float, default=8.0
        Size of the keypoints in the figure.

    linewidth: float, default=3.0
        Width of edges in skeleton

    height : int, default=500
        Height of the figure in pixels.

    skiprate : int, default=1
        Plot every `skiprate` frames.
    """
    fig = initialize_3D_plot(height)

    # Coerce to an array so the documented list input supports the 2D slice.
    Xs = np.asarray(Xs)[:, ::skiprate]
    num_trajectories, num_frames = Xs.shape[:2]

    # Each plotted frame contributes one marker trace plus one trace per edge.
    # The frame count comes from the array shape (not from `Xs[1]`) so that a
    # single trajectory is handled too.
    traces_per_trajectory = (len(edges) + 1) * num_frames

    def visibility_mask(i):
        visible = np.zeros(traces_per_trajectory * num_trajectories, dtype=bool)
        visible[traces_per_trajectory * i : traces_per_trajectory * (i + 1)] = True
        return visible

    steps = []
    for i, X in enumerate(Xs):
        opacities = np.linspace(0.3, 1, len(X) + 1)[1:] ** 2
        for coords, opacity in zip(X, opacities):
            add_3D_pose_to_fig(
                fig,
                coords,
                edges,
                visible=(i == 0),
                node_size=node_size,
                linewidth=linewidth,
                keypoint_colormap=keypoint_colormap,
                opacity=opacity,
            )
        steps.append(
            dict(
                method="update",
                label=titles[i],
                args=[{"visible": visibility_mask(i)}],
            )
        )
    fig.update_layout(sliders=[dict(steps=steps)])

    if output_dir is not None:
        save_path = os.path.join(output_dir, "all_trajectories.html")
        fig.write_html(save_path)
        print(f"Saved interactive trajectories plot to {save_path}")

    fig.show()
[docs]
def plot_poses_3D(
    poses,
    edges,
    keypoint_colormap="autumn",
    node_size=6.0,
    linewidth=3.0,
):
    """Plot a sequence of 3D poses.

    The poses are shown one at a time in a single figure; a slider selects
    which pose is visible.

    Parameters
    ----------
    poses: array of shape (num_poses, num_bodyparts, 3)
        3D poses to plot.

    edges: list of index pairs
        Skeleton edges.

    keypoint_colormap: str, default='autumn'
        Colormap to use for coloring keypoints.

    node_size: float, default=6.0
        Size of keypoints.

    linewidth: float, default=3.0
        Width of skeleton edges.
    """
    fig = initialize_3D_plot()

    num_poses = len(poses)
    # Every pose contributes one marker trace plus one trace per edge.
    traces_per_pose = len(edges) + 1

    for pose_index, pose in enumerate(poses):
        add_3D_pose_to_fig(
            fig,
            pose,
            edges,
            visible=(pose_index == 0),
            keypoint_colormap=keypoint_colormap,
            node_size=node_size,
            linewidth=linewidth,
        )

    # `pose_of_trace[t]` is the index of the pose that trace `t` belongs to.
    pose_of_trace = np.arange(traces_per_pose * num_poses) // traces_per_pose
    steps = [
        dict(
            method="update",
            label=f"Pose {pose_index+1}",
            args=[{"visible": pose_of_trace == pose_index}],
        )
        for pose_index in range(num_poses)
    ]

    fig.update_layout(sliders=[dict(steps=steps)])
    fig.show()
[docs]
def hierarchical_clustering_order(X, dist_metric="euclidean", linkage_method="ward"):
    """Linearly order a set of points using hierarchical clustering.

    Parameters
    ----------
    X: ndarray of shape (num_points, num_features)
        Points to order.

    dist_metric: str, default='euclidean'
        Distance metric to use.

    linkage_method: str, default='ward'
        Linkage method to use.

    Returns
    -------
    ordering: ndarray of shape (num_points,)
        Linear ordering of the points (the leaf order of the cluster tree).
        With fewer than two points the trivial ordering is returned.
    """
    X = np.asarray(X)
    num_points = X.shape[0]

    # There are no pairwise distances to cluster with fewer than two points
    # and `linkage` would raise, so return the only possible ordering.
    if num_points < 2:
        return np.arange(num_points)

    D = pdist(X, dist_metric)
    Z = linkage(D, linkage_method)
    return leaves_list(Z)
[docs]
def plot_confusion_matrix(
    results1, results2, min_frequency=0.005, sort=True, normalize=True
):
    """Plot a confusion matrix that compares syllables across two models.

    Parameters
    ----------
    results1: dict
        Dictionary containing modeling results for the first model (see
        :py:func:`keypoint_moseq.fitting.extract_results`).

    results2: dict
        Dictionary containing modeling results for the second model (see
        :py:func:`keypoint_moseq.fitting.extract_results`).

    min_frequency: float, default=0.005
        Minimum frequency of a syllable to include in the confusion matrix.

    sort: bool, default=True
        Whether to sort the syllables from each model to emphasize the diagonal.

    normalize: bool, default=True
        Whether to row-normalize the confusion matrix.

    Returns
    -------
    fig: matplotlib figure
        Figure containing the confusion matrix.

    ax: matplotlib axis
        Axis containing the confusion matrix.
    """
    # Recordings are concatenated in sorted-key order for both models so that
    # the two label sequences line up frame by frame.
    syllables1 = np.concatenate(
        [results1[k]["syllable"] for k in sorted(results1.keys())]
    )
    syllables2 = np.concatenate(
        [results2[k]["syllable"] for k in sorted(results2.keys())]
    )

    # C[i, j] counts the frames labeled i by model 1 and j by model 2.
    C = np.zeros((np.max(syllables1) + 1, np.max(syllables2) + 1))
    np.add.at(C, (syllables1, syllables2), 1)

    if normalize:
        # Rows of syllable indices that never occur sum to zero; leave them
        # as zeros instead of dividing by zero (NaNs and a RuntimeWarning).
        row_sums = np.sum(C, axis=1, keepdims=True)
        C = np.divide(C, row_sums, out=np.zeros_like(C), where=row_sums > 0)

    ix1 = (get_frequencies(syllables1) > min_frequency).nonzero()[0]
    ix2 = (get_frequencies(syllables2) > min_frequency).nonzero()[0]
    C = C[ix1, :][:, ix2]

    if sort:
        # Hierarchical clustering needs at least two rows.
        if C.shape[0] > 1:
            row_order = hierarchical_clustering_order(C)
            C = C[row_order, :]
            ix1 = ix1[row_order]
        # Order columns by the row where each one peaks, which pulls the
        # mass toward the diagonal; `argmax` needs at least one row.
        if C.shape[0] > 0:
            col_order = np.argsort(np.argmax(C, axis=0))
            C = C[:, col_order]
            ix2 = ix2[col_order]

    fig, ax = plt.subplots(1, 1)
    im = ax.imshow(C)
    ax.set_xticks(np.arange(len(ix2)))
    ax.set_xticklabels(ix2)
    ax.set_yticks(np.arange(len(ix1)))
    ax.set_yticklabels(ix1)
    ax.set_xlabel("Model 2")
    ax.set_ylabel("Model 1")
    ax.set_title("Confusion matrix")
    cbar = fig.colorbar(im, ax=ax)
    # Without normalization the matrix holds raw frame counts.
    cbar.set_label("Probability" if normalize else "Count")
    fig.tight_layout()
    return fig, ax
[docs]
def plot_eml_scores(eml_scores, eml_std_errs, model_names):
    """Plot expected marginal likelihood scores for a set of models.

    Models are shown in order of increasing score, each with a vertical bar
    spanning one standard error above and below its score.

    Parameters
    ----------
    eml_scores: ndarray of shape (num_models,)
        EML score for each model.

    eml_std_errs: ndarray of shape (num_models,)
        Standard error of the EML score for each model.

    model_names: list of str
        Name of each model.

    Returns
    -------
    fig: matplotlib figure
        Figure containing the plot.

    ax: matplotlib axis
        Axis containing the plot.
    """
    # Coerce to arrays so plain lists/tuples support the fancy indexing below.
    eml_scores = np.asarray(eml_scores)
    eml_std_errs = np.asarray(eml_std_errs)

    num_models = len(eml_scores)
    ordering = np.argsort(eml_scores)
    eml_scores = eml_scores[ordering]
    eml_std_errs = eml_std_errs[ordering]
    model_names = [model_names[i] for i in ordering]

    err_low = eml_scores - eml_std_errs
    err_high = eml_scores + eml_std_errs

    fig, ax = plt.subplots(1, 1, figsize=(4, 3.5))
    for i in range(num_models):
        ax.plot([i, i], [err_low[i], err_high[i]], c="k", linewidth=1)
    ax.scatter(range(num_models), eml_scores, c="k")
    ax.set_xticks(range(num_models))
    ax.set_xticklabels(model_names, rotation=90)
    ax.set_ylabel("EML score")
    # Lay out the figure created here rather than pyplot's "current figure".
    fig.tight_layout()
    return fig, ax
[docs]
def plot_pose(
    coordinates, bodyparts, skeleton, cmap="autumn", node_size=6, linewidth=3, ax=None
):
    """
    Plot a single pose using matplotlib.

    Each keypoint is drawn as a scatter point colored by its position in
    `bodyparts`; each skeleton edge is drawn as a line segment in the color
    of its first keypoint. The axis aspect ratio is set to "equal".

    Parameters
    ----------
    coordinates: ndarray of shape (num_bodyparts, 2)
        2D coordinates of the pose.

    bodyparts: list of str
        Bodypart names.

    skeleton: list of tuples
        Skeleton edges as pairs of bodypart names.

    cmap: str, default='autumn'
        Colormap to use for coloring keypoints.

    node_size: float, default=6
        Size of keypoints.

    linewidth: float, default=3
        Width of skeleton edges.

    ax: matplotlib axis, default=None
        Axis to plot on. If None, a new axis is created.

    Returns
    -------
    ax: matplotlib axis
        Axis containing the plot.
    """
    if ax is None:
        _, ax = plt.subplots(1, 1)

    # One color per bodypart, spread evenly over the colormap.
    palette = plt.get_cmap(cmap)(np.linspace(0, 1, len(bodyparts)))
    edges = get_edges(bodyparts, skeleton)

    for node, (px, py) in enumerate(coordinates):
        ax.scatter(px, py, s=node_size, c=[palette[node]])

    for start, end in edges:
        xs = [coordinates[k, 0] for k in (start, end)]
        ys = [coordinates[k, 1] for k in (start, end)]
        ax.plot(xs, ys, c=palette[start], linewidth=linewidth)

    ax.set_aspect("equal")
    return ax