Visualization

Functions:

crop_image(image, centroid, crop_size)

Crop an image around a centroid.

plot_scree(pca[, savefig, project_dir, fig_size])

Plot explained variance as a function of the number of PCs.

plot_pcs(pca, *, use_bodyparts, skeleton[, ...])

Visualize the components of a fitted PCA model.

plot_syllable_frequencies([project_dir, ...])

Plot a histogram showing the frequency of each syllable.

plot_duration_distribution([project_dir, ...])

Plot a histogram showing the distribution of syllable durations.

plot_kappa_scan(kappas, project_dir, prefix)

Plot the results of a kappa scan.

plot_progress(model, data, checkpoint_path, ...)

Plot the progress of the model during fitting.

write_video_clip(frames, path[, fps, quality])

Write a video clip to a file.

grid_movie(instances, rows, cols, videos, ...)

Generate a grid movie and return it as an array of frames.

get_grid_movie_window_size(...[, pctl, ...])

Automatically determine the window size for a grid movie.

generate_grid_movies(results[, project_dir, ...])

Generate grid movies for a modeled dataset.

get_limits(coordinates[, pctl, blocksize, ...])

Get axis limits based on the coordinates of all keypoints.

plot_trajectories(titles, Xs, lims[, edges, ...])

Plot one or more pose trajectories on a common axis and return the axis.

generate_trajectory_plots(coordinates, results)

Generate trajectory plots for a modeled dataset.

overlay_keypoints_on_image(image, coordinates)

Overlay keypoints on an image.

overlay_keypoints_on_video(video_path, ...)

Overlay keypoints on a video.

add_3D_pose_to_plotly_fig(fig, coords, ...)

Add a 3D pose to a plotly figure.

plot_similarity_dendrogram(coordinates, results)

Plot a dendrogram showing the similarity between syllable trajectories.

matplotlib_colormap_to_plotly(cmap)

Convert a matplotlib colormap to a plotly colormap.

initialize_3D_plot([height])

Create an empty 3D plotly figure.

add_3D_pose_to_fig(fig, coords, edges[, ...])

Add a 3D pose to a plotly figure.

plot_pcs_3D(ymean, ypcs, edges, ...[, ...])

Visualize the components of a fitted PCA model based on 3D components.

plot_trajectories_3D(Xs, titles, edges, ...)

Visualize a set of 3D trajectories.

plot_poses_3D(poses, edges[, ...])

Plot a sequence of 3D poses.

hierarchical_clustering_order(X[, ...])

Linearly order a set of points using hierarchical clustering.

plot_confusion_matrix(results1, results2[, ...])

Plot a confusion matrix that compares syllables across two models.

plot_eml_scores(eml_scores, eml_std_errs, ...)

Plot expected marginal likelihood scores for a set of models.

keypoint_moseq.viz.crop_image(image, centroid, crop_size)

Crop an image around a centroid.

Parameters:
  • image (ndarray of shape (height, width, 3)) – Image to crop.

  • centroid (tuple of int) – (x,y) coordinates of the centroid.

  • crop_size (int or tuple(int,int)) – Size of the crop around the centroid. Either a single int for a square crop, or a tuple of ints (w,h) for a rectangular crop.

Returns:

image – Cropped image.

Return type:

ndarray of shape (crop_size, crop_size, 3), or (h, w, 3) for a rectangular crop
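The output shape follows directly from how the crop is taken. Below is a minimal numpy sketch of cropping around a centroid. crop_around_centroid is a hypothetical helper, and zero-padding at the image border is an assumption rather than documented behavior of crop_image.

```python
import numpy as np

def crop_around_centroid(image, centroid, crop_size):
    """Hypothetical sketch: crop `image` around `centroid`, zero-padding at the borders."""
    # A single int gives a square crop; a (w, h) tuple gives a rectangular crop.
    if isinstance(crop_size, int):
        w, h = crop_size, crop_size
    else:
        w, h = crop_size
    x, y = int(round(centroid[0])), int(round(centroid[1]))
    # Assumption: pad with zeros so crops near the border keep the requested size.
    pad = max(w, h)
    padded = np.pad(image, ((pad, pad), (pad, pad), (0, 0)))
    x0 = x + pad - w // 2
    y0 = y + pad - h // 2
    # Rows index y (height) and columns index x (width), so the result is (h, w, 3).
    return padded[y0:y0 + h, x0:x0 + w]
```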

keypoint_moseq.viz.plot_scree(pca, savefig=True, project_dir=None, fig_size=(3, 2))

Plot explained variance as a function of the number of PCs.

Parameters:
  • pca (sklearn.decomposition.PCA()) – Fitted PCA model

  • savefig (bool, default=True) – Whether to save the figure to a file. If True, the figure is saved to {project_dir}/pca_scree.pdf.

  • project_dir (str, default=None) – Path to the project directory. Required if savefig is True.

  • fig_size (tuple, default=(3,2)) – Size of the figure in inches.

Returns:

fig – Figure handle

Return type:

matplotlib.figure.Figure
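A scree curve is usually read as the cumulative fraction of variance explained by the first k PCs. The sketch below computes that curve with numpy from raw data and finds how many PCs reach a threshold. Both helper names are hypothetical, and it is an assumption that plot_scree plots the cumulative (rather than per-PC) ratio. With a fitted sklearn PCA, the equivalent curve is np.cumsum(pca.explained_variance_ratio_).

```python
import numpy as np

def cumulative_explained_variance(X):
    """Fraction of total variance captured by the first k PCs, for k = 1..n_features."""
    Xc = X - X.mean(axis=0)                  # PCA operates on mean-centered data
    s = np.linalg.svd(Xc, compute_uv=False)  # singular values, in descending order
    var = s ** 2                             # proportional to the variance along each PC
    return np.cumsum(var) / var.sum()

def n_pcs_for(cev, threshold=0.9):
    """Smallest number of PCs whose cumulative explained variance reaches threshold."""
    return int(np.searchsorted(cev, threshold) + 1)
```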

keypoint_moseq.viz.plot_pcs(pca, *, use_bodyparts, skeleton, keypoint_colormap='autumn', keypoint_colors=None, savefig=True, project_dir=None, scale=1, plot_n_pcs=10, axis_size=(2, 1.5), ncols=5, node_size=30.0, line_width=2.0, interactive=True, **kwargs)

Visualize the components of a fitted PCA model.

For each PC, a subplot shows the mean pose (semi-transparent) along with a perturbation of the mean pose in the direction of the PC.

Parameters:
  • pca (sklearn.decomposition.PCA()) – Fitted PCA model

  • use_bodyparts (list of str) – List of bodyparts that are used in the model; used to index bodypart names in the skeleton.

  • skeleton (list) – List of edges that define the skeleton, where each edge is a pair of bodypart names.

  • keypoint_colormap (str) – Name of a matplotlib colormap to use for coloring the keypoints.

  • keypoint_colors (array-like, shape=(num_keypoints,3), default=None) – Color for each keypoint. If None, keypoint_colormap is used. If the dtype is int, the values are assumed to be in the range 0-255, otherwise they are assumed to be in the range 0-1.

  • savefig (bool, default=True) – Whether to save the figure to a file. If True, the figure is saved to {project_dir}/pcs-{xy/xz/yz}.pdf (xz and yz are only included for 3D data).

  • project_dir (str, default=None) – Path to the project directory. Required if savefig is True.

  • scale (float, default=1) – Scale factor for the perturbation of the mean pose.

  • plot_n_pcs (int, default=10) – Number of PCs to plot.

  • axis_size (tuple of float, default=(2,1.5)) – Size of each subplot in inches.

  • ncols (int, default=5) – Number of columns in the figure.

  • node_size (float, default=30.0) – Size of the keypoints in the figure.

  • line_width (float, default=2.0) – Width of edges in skeleton

  • interactive (bool, default=True) – For 3D data, whether to generate an interactive 3D plot.
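The perturbation described above amounts to adding a scaled principal component to the mean pose. The sketch assumes the PCA was fitted on flattened pose vectors of length n_keypoints * n_dims, so that sklearn's pca.mean_ and pca.components_[i] can be reshaped into poses. perturbed_pose is a hypothetical helper, and any alignment or additional scaling that plot_pcs applies is not reproduced here.

```python
import numpy as np

def perturbed_pose(mean, component, scale=1.0, n_dims=2):
    """Mean pose shifted along one PC, reshaped to (n_keypoints, n_dims)."""
    mean = np.asarray(mean, dtype=float)
    component = np.asarray(component, dtype=float)
    # Assumption: the flattened vector is ordered keypoint by keypoint (x0, y0, x1, y1, ...).
    return (mean + scale * component).reshape(-1, n_dims)
```

For example, perturbed_pose(pca.mean_, pca.components_[0], scale=1) would give the pose drawn for the first PC under these assumptions.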

keypoint_moseq.viz.plot_syllable_frequencies(project_dir=None, model_name=None, results=None, path=None, minlength=10, min_frequency=0.005)

Plot a histogram showing the frequency of each syllable.

Caller must provide a results dictionary, a path to a results .h5, or a project directory and model name, in which case the results are loaded from {project_dir}/{model_name}/results.h5.

Parameters:
  • results (dict, default=None) – Dictionary containing modeling results for a dataset (see keypoint_moseq.fitting.extract_results())

  • model_name (str, default=None) – Name of the model. Required to load results if results is None and path is None.

  • project_dir (str, default=None) – Project directory. Required to load results if results is None and path is None.

  • path (str, default=None) – Path to a results file. If None, results will be loaded from {project_dir}/{model_name}/results.h5.

  • minlength (int, default=10) – Minimum x-axis length of the histogram.

  • min_frequency (float, default=0.005) – Minimum frequency of syllables to include in the histogram.

Returns:

  • fig (matplotlib.figure.Figure) – Figure containing the histogram.

  • ax (matplotlib.axes.Axes) – Axes containing the histogram.
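The frequency filter described by min_frequency can be sketched as follows. syllable_frequencies is a hypothetical helper that takes one integer label array per recording. Measuring frequency as the fraction of frames is an assumption; counting syllable instances instead would first collapse runs of repeated labels.

```python
import numpy as np

def syllable_frequencies(syllable_seqs, min_frequency=0.005):
    """Frequency of each syllable (fraction of frames), keeping those >= min_frequency."""
    labels = np.concatenate([np.asarray(s) for s in syllable_seqs])
    counts = np.bincount(labels)
    freqs = counts / counts.sum()
    # Keep only syllables at or above the frequency threshold.
    return {int(k): float(freqs[k]) for k in np.nonzero(freqs >= min_frequency)[0]}
```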

keypoint_moseq.viz.plot_duration_distribution(project_dir=None, model_name=None, results=None, path=None, lim=None, num_bins=30, fps=None, show_median=True)

Plot a histogram showing the distribution of syllable durations.

Caller must provide a results dictionary, a path to a results .h5, or a project directory and model name, in which case the results are loaded from {project_dir}/{model_name}/results.h5.

Parameters:
  • results (dict, default=None) – Dictionary containing modeling results for a dataset (see keypoint_moseq.fitting.extract_results())

  • model_name (str, default=None) – Name of the model. Required to load results if results is None and path is None.

  • project_dir (str, default=None) – Project directory. Required to load results if results is None and path is None.

  • path (str, default=None) – Path to a results file. If None, results will be loaded from {project_dir}/{model_name}/results.h5.

  • lim (tuple, default=None) – x-axis limits as a pair of ints (in units of frames). If None, the limits are set to (0, 95th-percentile).

  • num_bins (int, default=30) – Number of bins in the histogram.

  • fps (int, default=None) – Frames per second. Used to convert x-axis from frames to seconds.

  • show_median (bool, default=True) – Whether to show the median duration as a vertical line.

Returns:

  • fig (matplotlib.figure.Figure) – Figure containing the histogram.

  • ax (matplotlib.axes.Axes) – Axes containing the histogram.
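Syllable durations are the lengths of runs of identical labels in a syllable sequence. A minimal numpy sketch (both helpers are hypothetical) of computing them, along with the median and the default upper x-axis limit (95th percentile) described above, all in frames; dividing by fps converts them to seconds.

```python
import numpy as np

def syllable_durations(seq):
    """Lengths, in frames, of consecutive runs of the same syllable label."""
    seq = np.asarray(seq)
    # A label change marks the onset of a new syllable instance.
    onsets = np.concatenate([[0], np.nonzero(np.diff(seq) != 0)[0] + 1])
    return np.diff(np.concatenate([onsets, [len(seq)]]))

def duration_summary(seqs):
    """Median duration and 95th-percentile duration (in frames) across recordings."""
    durations = np.concatenate([syllable_durations(s) for s in seqs])
    return float(np.median(durations)), float(np.percentile(durations, 95))
```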

keypoint_moseq.viz.plot_kappa_scan(kappas, project_dir, prefix, figsize=(8, 2.5))

Plot the results of a kappa scan.

This function assumes that model results for each kappa value are stored in {project_dir}/{prefix}-{kappa}/checkpoint.h5. Two plots are generated: (1) a line plot showing the median syllable duration over the course of fitting for each kappa value; (2) a plot showing the final median syllable duration as a function of kappa.

Parameters:
  • kappas (array-like of float) – Kappa values used in the scan.

  • project_dir (str) – Path to the project directory.

  • prefix (str) – Prefix for the kappa scan model names.

Returns:

  • fig (matplotlib.figure.Figure) – Figure containing the plot.

  • final_median_durations (array of float) – Median syllable durations for each kappa value, derived using the final iteration of each model.

keypoint_moseq.viz.plot_progress(model, data, checkpoint_path, iteration, project_dir=None, model_name=None, path=None, savefig=True, fig_size=None, window_size=600, min_frequency=0.001, min_histogram_length=10)

Plot the progress of the model during fitting.

The figure shows the following plots:
  • Duration distribution:

    The distribution of state durations for the most recent iteration of the model.

  • Frequency distribution:

    The distribution of state frequencies for the most recent iteration of the model.

  • Median duration:

    The median state duration across iterations.

  • State sequence history:

    The state sequence across iterations in a random window (a new window is selected each time the progress is plotted).

Parameters:
  • model (dict) – Model dictionary containing states

  • data (dict) – Data dictionary containing mask

  • checkpoint_path (str) – Path to an HDF5 file containing model checkpoints.

  • iteration (int) – Current iteration of model fitting

  • project_dir (str, default=None) – Path to the project directory. Required if savefig is True.

  • model_name (str, default=None) – Name of the model. Required if savefig is True.

  • savefig (bool, default=True) – Whether to save the figure to a file. If True, the figure is saved to path or, if path is None, to {project_dir}/{model_name}-progress.pdf.

  • fig_size (tuple of float, default=None) – Size of the figure in inches.

  • window_size (int, default=600) – Window size for state sequence history plot.

  • min_frequency (float, default=.001) – Minimum frequency for including a state in the frequency distribution plot.

  • min_histogram_length (int, default=10) – Minimum x-axis length of the frequency distribution plot.

Returns:

  • fig (matplotlib.figure.Figure) – Figure containing the plots.

  • axs (list of matplotlib.axes.Axes) – Axes containing the plots.
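The state sequence history panel uses a random window of window_size frames that changes on every call. One way to draw such a window is sketched below; random_window is a hypothetical helper, and clamping to the whole sequence when it is shorter than window_size is an assumption.

```python
import numpy as np

def random_window(n_frames, window_size=600, rng=None):
    """Pick a random [start, end) window of at most window_size frames."""
    rng = np.random.default_rng() if rng is None else rng
    if n_frames <= window_size:
        # Assumption: short sequences are shown in full.
        return 0, n_frames
    # Uniform start such that the window fits entirely inside the sequence.
    start = int(rng.integers(0, n_frames - window_size + 1))
    return start, start + window_size
```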

keypoint_moseq.viz.write_video_clip(frames, path, fps=30, quality=7)

Write a video clip to a file.

Parameters:
  • frames (np.ndarray) – Video frames as a 4D array of shape (num_frames, height, width, 3) or a 3D array of shape (num_frames, height, width).

  • path (str) – Path to save the video clip.

  • fps (int, default=30) – Framerate of video encoding.

  • quality (int, default=7) – Quality of video encoding.
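Because frames may be either grayscale (3D) or RGB (4D), a writer has to bring them to a common layout. A numpy sketch of that normalization step follows; as_rgb_frames is a hypothetical helper, and it is an assumption that write_video_clip converts grayscale input this way.

```python
import numpy as np

def as_rgb_frames(frames):
    """Return (T, H, W, 3) uint8 frames from (T, H, W) grayscale or (T, H, W, 3) RGB input."""
    frames = np.asarray(frames)
    if frames.ndim == 3:
        # Grayscale: repeat the single channel three times.
        frames = np.repeat(frames[..., None], 3, axis=-1)
    if frames.ndim != 4 or frames.shape[-1] != 3:
        raise ValueError("expected shape (T, H, W) or (T, H, W, 3)")
    return frames.astype(np.uint8)
```

The resulting array could then be passed to an encoder such as imageio's mimwrite(path, frames, fps=fps, quality=quality) with the ffmpeg plugin; whether write_video_clip uses exactly that call is not stated here.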

keypoint_moseq.viz.grid_movie(instances, rows, cols, videos, centroids, headings, window_size, dot_color=(255, 255, 255), dot_radius=4, pre=30, post=60, scaled_window_size=None, edges=[], overlay_keypoints=False, coordinates=None, plot_options={}, downsample_rate=1)

Generate a grid movie and return it as an array of frames.

Grid movies show many instances of a syllable. Each instance contains a snippet of video (and/or keypoint-overlay) centered on the animal and synchronized to the onset of the syllable. A dot appears at syllable onset and disappears at syllable offset.

Parameters:
  • instances (list of tuples (key, start, end)) – List of syllable instances to include in the grid movie, where each instance is specified as a tuple with the video name, start frame and end frame. The list must have length rows*cols. The video names must also be keys in videos.

  • rows, cols (int) – Number of rows and columns in the grid movie.

  • videos (dict or None) – Dictionary mapping video names to video readers. Frames from each reader should be accessible via __getitem__(int or slice). If None, the grid movie will not include video frames.

  • centroids (dict) – Dictionary mapping video names to arrays of shape (n_frames, 2) with the x,y coordinates of animal centroid on each frame

  • headings (dict) – Dictionary mapping video names to arrays of shape (n_frames,) with the heading of the animal on each frame (in radians)

  • window_size (int) – Size of the window around the animal. This should be a multiple of 16; otherwise imageio will issue a warning.

  • dot_color (tuple of ints, default=(255,255,255)) – RGB color of the dot indicating syllable onset and offset

  • dot_radius (int, default=4) – Radius of the dot indicating syllable onset and offset

  • pre (int, default=30) – Number of frames before syllable onset to include in the movie

  • post (int, default=60) – Number of frames after syllable onset to include in the movie

  • scaled_window_size (int, default=None) – Window size after scaling the video. If None, no scaling is performed (i.e. scaled_window_size = window_size).

  • overlay_keypoints (bool, default=False) – If True, overlay the pose skeleton on the video frames.

  • edges (list of tuples, default=[]) – List of edges defining pose skeleton. Used when overlay_keypoints=True.

  • coordinates (dict, default=None) – Dictionary mapping video names to keypoint coordinates as arrays of shape (n_frames, n_keypoints, 2). Used when overlay_keypoints=True.

  • plot_options (dict, default={}) – Dictionary of options to pass to overlay_keypoints_on_image. Used when overlay_keypoints=True.

  • downsample_rate (int, default=1) – Downsampling rate for the video frames. Coordinates at index i will be plotted on the video frame at index i*downsample_rate.

Returns:

frames

Array of frames in the grid movie, where:

height = rows * scaled_window_size
width = cols * scaled_window_size

Return type:

array of shape (pre+post, height, width, 3)
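The frame shape follows from tiling in row-major order: each row of the grid places cols clips side by side, and the rows are stacked vertically. A numpy sketch, assuming each clip has already been cropped, rotated and rescaled to scaled_window_size (tile_grid is a hypothetical helper, not the library's code):

```python
import numpy as np

def tile_grid(clips, rows, cols):
    """Arrange rows*cols clips of shape (T, s, s, 3) into one grid movie, row-major."""
    if len(clips) != rows * cols:
        raise ValueError("need exactly rows * cols clips")
    # Concatenate each row's clips along the width axis (axis 2) ...
    row_frames = [np.concatenate(clips[r * cols:(r + 1) * cols], axis=2) for r in range(rows)]
    # ... then stack the rows along the height axis (axis 1).
    return np.concatenate(row_frames, axis=1)
```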

keypoint_moseq.viz.get_grid_movie_window_size(sampled_instances, centroids, headings, coordinates, pre, post, pctl=90, fudge_factor=1.1, blocksize=16)

Automatically determine the window size for a grid movie.

The window size is set such that across all sampled instances, the animal is fully visible in at least pctl percent of frames.

Parameters:
  • sampled_instances (dict) – Dictionary mapping syllables to lists of instances, where each instance is specified as a tuple with the video name, start frame and end frame.

  • centroids (dict) – Dictionary mapping video names to arrays of shape (n_frames, 2) with the x,y coordinates of animal centroid on each frame

  • headings (dict) – Dictionary mapping video names to arrays of shape (n_frames,) with the heading of the animal on each frame (in radians)

  • coordinates (dict) – Dictionary mapping recording names to keypoint coordinates as ndarrays of shape (n_frames, n_bodyparts, 2).

  • pre (int) – Number of frames before syllable onset that are included in the grid movies.

  • post (int) – Number of frames after syllable onset that are included in the grid movies.

  • pctl (int, default=90) – Percentile of frames in which the animal should be fully visible.

  • fudge_factor (float, default=1.1) – Factor by which to multiply the window size.

  • blocksize (int, default=16) – Window size is rounded up to the nearest multiple of blocksize.
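The pctl / fudge_factor / blocksize rule above can be sketched as follows. This is an assumption, not the library's algorithm. It uses each frame's largest keypoint distance from the centroid, which ignores heading, whereas grid movies are aligned to the animal's heading, so the real computation may give smaller windows. window_size_from_coords is a hypothetical helper.

```python
import numpy as np

def window_size_from_coords(coords, centroids, pctl=90, fudge_factor=1.1, blocksize=16):
    """Square window side that contains every keypoint in pctl percent of frames.

    coords: (n_frames, n_keypoints, 2); centroids: (n_frames, 2).
    """
    # Largest keypoint distance from the centroid on each frame (rotation-invariant).
    radii = np.linalg.norm(coords - centroids[:, None, :], axis=-1).max(axis=1)
    size = 2 * np.percentile(radii, pctl) * fudge_factor
    # Round up to the nearest multiple of blocksize.
    return int(np.ceil(size / blocksize) * blocksize)
```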

keypoint_moseq.viz.generate_grid_movies(results, project_dir=None, model_name=None, output_dir=None, video_dir=None, video_paths=None, rows=4, cols=6, filter_size=9, pre=30, post=60, min_frequency=0.005, min_duration=3, dot_radius=4, dot_color=(255, 255, 255), quality=7, window_size=None, coordinates=None, centroids=None, headings=None, bodyparts=None, use_bodyparts=None, sampling_options={}, video_extension=None, max_video_size=1920, skeleton=[], overlay_keypoints=False, keypoints_only=False, keypoints_scale=1, fps=None, plot_options={}, use_dims=[0, 1], keypoint_colormap='autumn', downsample_rate=1, **kwargs)

Generate grid movies for a modeled dataset.

Grid movies show many instances of a syllable and are useful in figuring out what behavior the syllable captures (see keypoint_moseq.viz.grid_movie()). This method generates a grid movie for each syllable that is used sufficiently often (i.e. has at least rows*cols instances with duration of at least min_duration and an overall frequency of at least min_frequency). The grid movies are saved to output_dir if specified, or else to {project_dir}/{model_name}/grid_movies.

Parameters:
  • results (dict) – Dictionary containing modeling results for a dataset (see keypoint_moseq.fitting.extract_results())

  • project_dir (str, default=None) – Project directory. Required to save grid movies if output_dir is None.

  • model_name (str, default=None) – Name of the model. Required to save grid movies if output_dir is None.

  • output_dir (str, default=None) – Directory where grid movies should be saved. If None, grid movies will be saved to {project_dir}/{model_name}/grid_movies.

  • video_dir (str, default=None) – Directory containing videos of the modeled data (see keypoint_moseq.io.find_matching_videos()). Either video_dir or video_paths must be provided unless keypoints_only=True.

  • video_paths (dict, default=None) – Dictionary mapping recording names to video paths. The recording names must correspond to keys in the results dictionary. Either video_dir or video_paths must be provided unless keypoints_only=True.

  • filter_size (int, default=9) – Size of the median filter applied to centroids and headings

  • min_frequency (float, default=0.005) – Minimum frequency of a syllable to be included in the grid movies.

  • min_duration (int, default=3) – Minimum duration of a syllable instance to be included in the grid movie for that syllable.

  • sampling_options (dict, default={}) – Dictionary of options for sampling syllable instances (see keypoint_moseq.util.sample_instances()).

  • coordinates (dict, default=None) – Dictionary mapping recording names to keypoint coordinates as ndarrays of shape (n_frames, n_bodyparts, [2 or 3]). Required when window_size=None, when overlay_keypoints=True, or when using density-based sampling (i.e. when sampling_options['mode']=='density'; see keypoint_moseq.util.sample_instances()).

  • centroids (dict, default=None) – Dictionary mapping recording names to arrays of shape (n_frames, 2). Overrides the centroid information in results.

  • headings (dict, default=None) – Dictionary mapping recording names to arrays of shape (n_frames,). Overrides the heading information in results.

  • bodyparts (list of str, default=None) – List of bodypart names in coordinates. Required when coordinates is provided and bodyparts were reindexed for modeling.

  • use_bodyparts (list of str, default=None) – Ordered list of bodyparts used for modeling. Required when coordinates is provided and bodyparts were reindexed for modeling.

  • quality (int, default=7) – Quality of the grid movies. Higher values result in higher quality movies but larger file sizes.

  • rows – See keypoint_moseq.viz.grid_movie()

  • cols – See keypoint_moseq.viz.grid_movie()

  • pre – See keypoint_moseq.viz.grid_movie()

  • post – See keypoint_moseq.viz.grid_movie()

  • dot_radius – See keypoint_moseq.viz.grid_movie()

  • dot_color – See keypoint_moseq.viz.grid_movie()

  • window_size (int, default=None) – Size of the window around the animal (see keypoint_moseq.viz.grid_movie()). If None, the window size is determined automatically based on the size of the animal. If provided explicitly, window_size should be a multiple of 16; otherwise imageio will issue a warning.

  • video_extension (str, default=None) – Preferred video extension (passed to keypoint_moseq.util.find_matching_videos())

  • max_video_size (int, default=1920) – Maximum size of the grid movie in pixels. If the grid movie is larger than this, it will be downsampled.

  • skeleton (list of tuples, default=[]) – List of tuples specifying the skeleton. Used when overlay_keypoints=True.

  • overlay_keypoints (bool, default=False) – Whether to overlay the keypoints on the grid movie.

  • keypoints_only (bool, default=False) – Whether to show only the keypoints (i.e. no video frames). Overrides overlay_keypoints. When this option is used, the framerate should be specified explicitly using fps.

  • keypoints_scale (float, default=1) – Factor to scale keypoint coordinates before plotting. Only used when keypoints_only=True. This is useful when the keypoints are 3D and encoded in units that are larger than a pixel.

  • fps (int, default=None) – Framerate of the grid movie. If None, the framerate is determined from the videos.

  • plot_options (dict, default={}) – Dictionary of options to pass to keypoint_moseq.viz.overlay_keypoints_on_image().

  • use_dims (pair of ints, default=[0,1]) – Dimensions to use for plotting keypoints. Only used when overlay_keypoints=True and the keypoints are 3D.

  • keypoint_colormap (str, default='autumn') – Colormap used to color keypoints. Used when overlay_keypoints=True.

  • downsample_rate (int, default=1) – Downsampling rate for the video frames. Coordinates at index i will be plotted on the video frame at index i*downsample_rate.

See keypoint_moseq.viz.grid_movie() for the remaining parameters.

Returns:

sampled_instances – Dictionary mapping syllables to lists of instances shown in each grid movie (in row-major order), where each instance is specified as a tuple with the video name, start frame and end frame.

Return type:

dict
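The eligibility rule described above (at least rows*cols instances lasting min_duration frames or more, and an overall frequency of at least min_frequency) can be sketched as follows. eligible_syllables is a hypothetical helper; measuring frequency as the fraction of frames is an assumption, and the library may count syllable instances instead.

```python
import numpy as np

def eligible_syllables(syllable_seqs, rows=4, cols=6, min_frequency=0.005, min_duration=3):
    """Syllables that would get a grid movie under the rule described above."""
    n_instances, n_frames, total = {}, {}, 0
    for seq in syllable_seqs:
        seq = np.asarray(seq)
        total += len(seq)
        # Split the sequence into runs of identical labels (syllable instances).
        bounds = np.concatenate([[0], np.nonzero(np.diff(seq) != 0)[0] + 1, [len(seq)]])
        for start, end in zip(bounds[:-1], bounds[1:]):
            label = int(seq[start])
            n_frames[label] = n_frames.get(label, 0) + (end - start)
            if end - start >= min_duration:
                n_instances[label] = n_instances.get(label, 0) + 1
    return sorted(
        s for s in n_frames
        if n_frames[s] / total >= min_frequency and n_instances.get(s, 0) >= rows * cols
    )
```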

keypoint_moseq.viz.get_limits(coordinates, pctl=1, blocksize=None, left=0.2, right=0.2, top=0.2, bottom=0.2)

Get axis limits based on the coordinates of all keypoints.

For each axis, limits are determined using the percentiles pctl and 100-pctl and then padded by the fractions left, right, top and bottom.

Parameters:
  • coordinates (ndarray or dict) – Coordinates as an ndarray of shape (…, 2), or a dict with values that are ndarrays of shape (…, 2).

  • pctl (float, default=1) – Percentile to use for determining the axis limits.

  • blocksize (int, default=None) – Axis limits are cast to integers and padded so that the width and height are multiples of blocksize. This is useful when they are used for generating cropped images for a video.

  • left (float, default=0.2) – Fraction of the x-axis range to pad on the left.

  • right (float, default=0.2) – Fraction of the x-axis range to pad on the right.

  • top (float, default=0.2) – Fraction of the y-axis range to pad at the top.

  • bottom (float, default=0.2) – Fraction of the y-axis range to pad at the bottom.

Returns:

lims – Axis limits, in the format [[xmin,ymin,…],[xmax,ymax,…]].

Return type:

ndarray of shape (2,dim)
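The percentile-and-padding procedure above can be sketched for 2D coordinates as follows. limits_sketch is a hypothetical helper; mapping top/bottom to the lower/upper y limit (image coordinates, y pointing down) and splitting the blocksize padding evenly between both sides are assumptions.

```python
import numpy as np

def limits_sketch(coords, pctl=1, left=0.2, right=0.2, top=0.2, bottom=0.2, blocksize=None):
    """Percentile-based 2D axis limits [[xmin, ymin], [xmax, ymax]], padded per side."""
    xy = np.asarray(coords, dtype=float).reshape(-1, 2)
    lo = np.nanpercentile(xy, pctl, axis=0)
    hi = np.nanpercentile(xy, 100 - pctl, axis=0)
    span = hi - lo
    # Assumed mapping: left/right pad x; top/bottom pad y (y increases downward).
    lo = lo - span * np.array([left, top])
    hi = hi + span * np.array([right, bottom])
    if blocksize is not None:
        # Cast to integers and grow each axis to a multiple of blocksize.
        lo, hi = np.floor(lo), np.ceil(hi)
        extra = (-(hi - lo)) % blocksize
        lo, hi = lo - np.floor(extra / 2), hi + np.ceil(extra / 2)
    return np.stack([lo, hi])
```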

keypoint_moseq.viz.plot_trajectories(titles, Xs, lims, edges=[], n_cols=4, invert=False, keypoint_colormap='autumn', keypoint_colors=None, node_size=50.0, line_width=3.0, alpha=0.2, num_timesteps=10, plot_width=4, overlap=(0.2, 0), return_rasters=False)

Plot one or more pose trajectories on a common axis and return the axis.

(See keypoint_moseq.viz.generate_trajectory_plots())

Parameters:
  • titles (list of str) – List of titles for each trajectory plot.

  • Xs (list of ndarray) – List of pose trajectories as ndarrays of shape (n_frames, n_keypoints, 2).

  • lims (ndarray) – Axis limits used for all the trajectory plots. The limits should be provided as an array of shape (2,2) with the format [[xmin,ymin],[xmax,ymax]].

  • edges (list of tuples, default=[]) – List of edges, where each edge is a tuple of two integers

  • n_cols (int, default=4) – Number of columns in the figure (used when plotting multiple trajectories).

  • invert (bool, default=False) – Determines the background color of the figure. If True, the background will be black.

  • keypoint_colormap (str or list) – Name of a matplotlib colormap or a list of colors as (r,g,b) tuples in the same order as the keypoints.

  • keypoint_colors (array-like, shape=(num_keypoints,3), default=None) – Color for each keypoint. If None, the keypoint colormap is used. If the dtype is int, the values are assumed to be in the range 0-255, otherwise they are assumed to be in the range 0-1.

  • node_size (float, default=50.0) – Size of each keypoint.

  • line_width (float, default=3.0) – Width of the lines connecting keypoints.

  • alpha (float, default=0.2) – Opacity of fade-out layers.

  • num_timesteps (int, default=10) – Number of timesteps to plot for each trajectory. The pose at each timestep is determined by linearly interpolating between the keypoints.

  • plot_width (int, default=4) – Width of each trajectory plot in inches. The height is determined by the aspect ratio of lims. The final figure width is plot_width * min(n_cols, len(Xs)).

  • overlap (tuple of float, default=(0.2,0)) – Amount of overlap between each trajectory plot as a tuple with the format (x_overlap, y_overlap). The values should be between 0 and 1.

  • return_rasters (bool, default=False) – Rasterize the matplotlib canvas after plotting each step of the trajectory. This is used to generate an animated video/gif of the trajectory.

Returns:

  • fig (matplotlib.figure.Figure) – Figure handle

  • ax (matplotlib.axes.Axes) – Axis containing the trajectory plots.

keypoint_moseq.viz.generate_trajectory_plots(coordinates, results, project_dir=None, model_name=None, output_dir=None, pre=5, post=15, min_frequency=0.005, min_duration=3, skeleton=[], bodyparts=None, use_bodyparts=None, keypoint_colormap='autumn', plot_options={}, get_limits_pctl=0, padding={'bottom': 0.2, 'left': 0.1, 'right': 0.1, 'top': 0.2}, lims=None, save_individually=True, save_gifs=True, save_mp4s=False, fps=30, projection_planes=['xy', 'xz'], interactive=True, density_sample=True, sampling_options={'n_neighbors': 50}, **kwargs)

Generate trajectory plots for a modeled dataset.

Each trajectory plot shows a sequence of poses along the average trajectory through latent space associated with a given syllable. A separate figure (and gif, optionally) is saved for each syllable, along with a single figure showing all syllables in a grid. The plots are saved to {output_dir} if it is provided, otherwise they are saved to {project_dir}/{model_name}/trajectory_plots.

Plot-related parameters are described below. For the remaining parameters, see keypoint_moseq.util.get_typical_trajectories().

Parameters:
  • coordinates (dict) – Dictionary mapping recording names to keypoint coordinates as ndarrays of shape (n_frames, n_bodyparts, [2 or 3]).

  • results (dict) – Dictionary containing modeling results for a dataset (see keypoint_moseq.fitting.extract_results()).

  • project_dir (str, default=None) – Project directory. Required to save trajectory plots if output_dir is None.

  • model_name (str, default=None) – Name of the model. Required to save trajectory plots if output_dir is None.

  • output_dir (str, default=None) – Directory where trajectory plots should be saved. If None, plots will be saved to {project_dir}/{model_name}/trajectory_plots.

  • skeleton (list, default=[]) – List of edges that define the skeleton, where each edge is a pair of bodypart names or a pair of indexes.

  • keypoint_colormap (str) – Name of a matplotlib colormap to use for coloring the keypoints.

  • plot_options (dict, default={}) – Dictionary of options for trajectory plots (see keypoint_moseq.viz.plot_trajectories()).

  • get_limits_pctl (float, default=0) – Percentile to use for determining the axis limits. Higher values lead to tighter axis limits.

  • padding (dict, default={'left':0.1, 'right':0.1, 'top':0.2, 'bottom':0.2}) – Padding around trajectory plots. Controls the distance between trajectories (when multiple are shown in one figure) as well as the title offset.

  • lims (ndarray of shape (2,2), default=None) – Axis limits used for all the trajectory plots with format [[xmin,ymin],[xmax,ymax]]. If None, the limits are determined automatically based on the coordinates of the keypoints using keypoint_moseq.viz.get_limits().

  • save_individually (bool, default=True) – If True, a separate figure is saved for each syllable (in addition to the grid figure).

  • save_gifs (bool, default=True) – Whether to save an animated gif of the trajectory plots.

  • save_mp4s (bool, default=False) – Whether to save videos of the trajectory plots as .mp4 files

  • fps (int, default=30) – Framerate of the videos from which keypoints were derived. Used to set the framerate of gifs when save_gifs=True.

  • projection_planes (list (subset of ['xy', 'yz', 'xz']), default=['xy','xz']) – For 3D data, defines the 2D plane(s) on which to project keypoint coordinates. A separate plot will be saved for each plane with the name of the plane (e.g. ‘xy’) as a suffix. This argument is ignored for 2D data.

  • interactive (bool, default=True) – For 3D data, whether to create a visualization that can be rotated and zoomed. This argument is ignored for 2D data.

keypoint_moseq.viz.overlay_keypoints_on_image(image, coordinates, edges=[], keypoint_colormap='autumn', keypoint_colors=None, node_size=5, line_width=2, copy=False, opacity=1.0)

Overlay keypoints on an image.

Parameters:
  • image (ndarray of shape (height, width, 3)) – Image to overlay keypoints on.

  • coordinates (ndarray of shape (num_keypoints, 2)) – Array of keypoint coordinates.

  • edges (list of tuples, default=[]) – List of edges that define the skeleton, where each edge is a pair of indexes.

  • keypoint_colormap (str, default='autumn') – Name of a matplotlib colormap to use for coloring the keypoints.

  • keypoint_colors (array-like, shape=(num_keypoints,3), default=None) – Color for each keypoint. If None, the keypoint colormap is used. If the dtype is int, the values are assumed to be in the range 0-255, otherwise they are assumed to be in the range 0-1.

  • node_size (int, default=5) – Size of the keypoints.

  • line_width (int, default=2) – Width of the skeleton lines.

  • copy (bool, default=False) – Whether to copy the image before overlaying keypoints.

  • opacity (float, default=1.0) – Opacity of the overlay graphics (0.0-1.0).

Returns:

image – Image with keypoints overlaid.

Return type:

ndarray of shape (height, width, 3)
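The keypoint_colors convention (integer dtype means 0-255, any other dtype means 0-1) can be normalized to 8-bit values for drawing on an image. A sketch with a hypothetical helper; converting to uint8 in the 0-255 range as the common target is an assumption about how the colors are consumed.

```python
import numpy as np

def normalize_keypoint_colors(colors):
    """Map keypoint colors to uint8 in 0-255 following the documented dtype convention."""
    colors = np.asarray(colors)
    if np.issubdtype(colors.dtype, np.integer):
        # Integer colors are already in 0-255.
        return np.clip(colors, 0, 255).astype(np.uint8)
    # Float colors are in 0-1; scale to 0-255 and round.
    return np.round(np.clip(colors, 0.0, 1.0) * 255).astype(np.uint8)
```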

keypoint_moseq.viz.overlay_keypoints_on_video(video_path, coordinates, skeleton=[], bodyparts=None, use_bodyparts=None, output_path=None, show_frame_numbers=True, text_color=(255, 255, 255), crop_size=None, frames=None, quality=7, centroid_smoothing_filter=10, plot_options={}, downsample_rate=1)

Overlay keypoints on a video.

Parameters:
  • video_path (str) – Path to a video file.

  • coordinates (ndarray of shape (num_frames, num_keypoints, 2)) – Array of keypoint coordinates.

  • skeleton (list of tuples, default=[]) – List of edges that define the skeleton, where each edge is a pair of bodypart names or a pair of indexes.

  • bodyparts (list of str, default=None) – List of bodypart names in coordinates. Required if skeleton is defined using bodypart names.

  • use_bodyparts (list of str, default=None) – Subset of bodyparts to plot. If None, all bodyparts are plotted.

  • output_path (str, default=None) – Path to save the video. If None, the video is saved to video_path with the suffix _keypoints.

  • show_frame_numbers (bool, default=True) – Whether to overlay the frame number in the video.

  • text_color (tuple of int, default=(255,255,255)) – Color for the frame number overlay.

  • crop_size (int, default=None) – Size of the crop around the keypoints to overlay on the video. If None, the entire video is used.

  • frames (iterable of int, default=None) – Frames to overlay keypoints on. If None, all frames are used.

  • quality (int, default=7) – Quality of the output video.

  • centroid_smoothing_filter (int, default=10) – Amount of smoothing applied when computing the cropping centroid.

  • plot_options (dict, default={}) – Additional keyword arguments to pass to keypoint_moseq.viz.overlay_keypoints_on_image().

  • downsample_rate (int, default=1) – Downsampling rate for the video frames. Coordinates at index i are overlaid on the frame at index i*downsample_rate.
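The frame mapping for downsample_rate and the idea of a smoothed cropping centroid can be sketched as follows. Both functions are hypothetical. The sketch assumes the raw centroid is the NaN-ignoring mean of all keypoints and that smoothing is a centered moving average of width centroid_smoothing_filter; keypoint_moseq may use a different centroid definition or filter.

```python
import numpy as np


def video_frame_index(i, downsample_rate=1):
    """Video frame on which coordinates[i] are drawn."""
    return i * downsample_rate


def smoothed_crop_centroids(coordinates, filter_size=10):
    """Per-frame crop centroids smoothed with a centered moving average.

    Assumptions (not taken from keypoint_moseq): the raw centroid is the
    NaN-ignoring mean of all keypoints, and smoothing is a moving average.
    """
    centroids = np.nanmean(coordinates, axis=1)  # shape (num_frames, 2)
    if filter_size <= 1:
        return centroids
    kernel = np.ones(filter_size) / filter_size
    # Pad with edge values so the output keeps one centroid per frame
    left = filter_size // 2
    right = filter_size - 1 - left
    padded = np.pad(centroids, ((left, right), (0, 0)), mode="edge")
    return np.stack(
        [np.convolve(padded[:, d], kernel, mode="valid") for d in range(centroids.shape[1])],
        axis=1,
    )


# One keypoint moving linearly along x
coords = np.zeros((10, 1, 2))
coords[:, 0, 0] = np.arange(10.0)
smoothed = smoothed_crop_centroids(coords, filter_size=3)
```

A moving average leaves a linear trajectory unchanged away from the edges while damping frame-to-frame jitter, which keeps the crop window from shaking.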

keypoint_moseq.viz.add_3D_pose_to_plotly_fig(fig, coords, edges, keypoint_colors, node_size=50.0, line_width=3.0, visible=True, opacity=1)

Add a 3D pose to a plotly figure.

Parameters:
  • fig (plotly figure) – Figure to which the pose should be added.

  • coords (ndarray (N,3)) – 3D coordinates of the pose.

  • edges (list of index pairs) – Skeleton edges.

  • keypoint_colors (array-like with shape (num_keypoints,3)) – Color for each keypoint. If the dtype is int, the values are assumed to be in the range 0-255, otherwise they are assumed to be in the range 0-1.

  • node_size (float, default=50.0) – Size of keypoints.

  • line_width (float, default=3.0) – Width of skeleton edges.

  • visible (bool, default=True) – Initial visibility state of the nodes and edges.

  • opacity (float, default=1) – Opacity of the nodes and edges (0-1).
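The dtype rule for keypoint_colors (integers read as 0-255, floats as 0-1) can be made explicit with a small conversion to the "rgb(r,g,b)" strings that plotly accepts as colors. Both helpers are hypothetical and written for illustration; keypoint_moseq may perform this conversion differently.

```python
import numpy as np


def normalize_keypoint_colors(keypoint_colors):
    """Return colors as floats in [0, 1], following the documented dtype rule."""
    colors = np.asarray(keypoint_colors)
    if np.issubdtype(colors.dtype, np.integer):
        return colors.astype(float) / 255.0  # ints are interpreted as 0-255
    return colors.astype(float)  # floats are interpreted as 0-1


def to_plotly_rgb_strings(keypoint_colors):
    """Convert (num_keypoints, 3) colors to plotly 'rgb(r,g,b)' strings."""
    colors = normalize_keypoint_colors(keypoint_colors)
    return [
        "rgb({},{},{})".format(*(int(round(float(c) * 255)) for c in color))
        for color in colors
    ]


int_colors = to_plotly_rgb_strings(np.array([[255, 0, 0], [0, 128, 255]]))
float_colors = to_plotly_rgb_strings(np.array([[1.0, 0.5, 0.0]]))
```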

keypoint_moseq.viz.plot_similarity_dendrogram(coordinates, results, project_dir=None, model_name=None, save_path=None, metric='cosine', pre=5, post=15, min_frequency=0.005, min_duration=3, bodyparts=None, use_bodyparts=None, density_sample=False, sampling_options={'n_neighbors': 50}, figsize=(6, 3), **kwargs)

Plot a dendrogram showing the similarity between syllable trajectories.

The dendrogram is saved to {save_path}.[pdf/png] if save_path is provided, or else to {project_dir}/{model_name}/similarity_dendrogram.[pdf/png]. Plot-related parameters are described below. For the remaining parameters, see keypoint_moseq.util.get_typical_trajectories().

Parameters:
  • coordinates (dict) – Dictionary mapping recording names to keypoint coordinates as ndarrays of shape (n_frames, n_bodyparts, [2 or 3]).

  • results (dict) – Dictionary containing modeling results for a dataset (see keypoint_moseq.fitting.extract_results()).

  • project_dir (str, default=None) – Project directory. Required to save figure if save_path is None.

  • model_name (str, default=None) – Model name. Required to save figure if save_path is None.

  • save_path (str, default=None) – Path to save the dendrogram plot (do not include an extension). If None, the plot is saved to {project_dir}/{model_name}/similarity_dendrogram.[pdf/png].

  • metric (str, default='cosine') – Distance metric to use. See scipy.spatial.distance.pdist() for options.

  • figsize (tuple of float, default=(6,3)) – Size of the dendrogram plot.
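With metric='cosine', the distance between two syllables is one minus the cosine similarity of their typical trajectories, the same definition scipy.spatial.distance.pdist uses for that metric. The NumPy sketch below assumes each trajectory is flattened into a single vector before comparison; cosine_distance_matrix is a hypothetical name, and the library's internal preprocessing may differ.

```python
import numpy as np


def cosine_distance_matrix(trajectories):
    """Square matrix of cosine distances between flattened trajectories.

    trajectories: array of shape (num_syllables, num_frames, num_bodyparts, dim).
    Assumes no trajectory is all zeros (its norm would be zero).
    """
    X = np.asarray(trajectories, dtype=float).reshape(len(trajectories), -1)
    unit = X / np.linalg.norm(X, axis=1, keepdims=True)  # unit-length rows
    D = 1.0 - unit @ unit.T  # 1 - cosine similarity
    np.fill_diagonal(D, 0.0)  # remove floating-point residue on the diagonal
    return np.clip(D, 0.0, 2.0)


A = np.zeros((3, 20, 4, 2))
A[0] = 1.0          # syllable 0
A[1] = 2.0          # same direction as syllable 0, different magnitude
A[2, ..., 0] = 1.0  # orthogonal to syllable 0
A[2, ..., 1] = -1.0
D = cosine_distance_matrix(A)
```

Cosine distance ignores overall magnitude, so syllables 0 and 1 are at distance 0 while the orthogonal syllable 2 is at distance 1 from both.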

keypoint_moseq.viz.matplotlib_colormap_to_plotly(cmap)

Convert a matplotlib colormap to a plotly colormap.

Parameters:

cmap (str) – Name of a matplotlib colormap.

Returns:

pl_colorscale – Plotly colormap.

Return type:

list
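A colormap-to-colorscale conversion can sample the colormap at evenly spaced positions and pair each position with an "rgb(r,g,b)" string, which is the [[position, color], ...] format plotly expects for colorscales. In this sketch the number of samples (n_colors) is an assumption, colormap_to_plotly is a hypothetical name, and cmap may be any callable that maps a float in [0, 1] to an RGB(A) tuple of floats, as matplotlib colormap objects do.

```python
def colormap_to_plotly(cmap, n_colors=11):
    """Build a plotly colorscale by sampling `cmap` at evenly spaced points.

    `cmap` maps a float in [0, 1] to an (r, g, b[, a]) tuple of floats in [0, 1].
    """
    if n_colors < 2:
        raise ValueError("n_colors must be at least 2")
    scale = []
    for k in range(n_colors):
        x = k / (n_colors - 1)  # position in [0, 1]
        r, g, b = cmap(x)[:3]   # drop alpha if present
        scale.append([x, f"rgb({round(r * 255)},{round(g * 255)},{round(b * 255)})"])
    return scale


def gray(x):
    """Stand-in colormap: black at 0, white at 1."""
    return (x, x, x, 1.0)


gray_scale = colormap_to_plotly(gray, n_colors=3)
```

With matplotlib installed, a real colormap can be passed directly, e.g. colormap_to_plotly(matplotlib.colormaps["autumn"]).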

keypoint_moseq.viz.initialize_3D_plot(height=500)

Create an empty 3D plotly figure.

keypoint_moseq.viz.add_3D_pose_to_fig(fig, coords, edges, keypoint_colormap='autumn', node_size=6.0, linewidth=3.0, visible=True, opacity=1)

Add a 3D pose to a plotly figure.

Parameters:
  • fig (plotly figure) – Figure to which the pose should be added.

  • coords (ndarray (N,3)) – 3D coordinates of the pose.

  • edges (list of index pairs) – Skeleton edges.

  • keypoint_colormap (str, default='autumn') – Colormap to use for coloring keypoints.

  • node_size (float, default=6.0) – Size of keypoints.

  • linewidth (float, default=3.0) – Width of skeleton edges.

  • visible (bool, default=True) – Initial visibility state of the nodes and edges.

  • opacity (float, default=1) – Opacity of the nodes and edges (0-1).
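One common way to draw every skeleton edge in a single plotly trace is to concatenate the endpoint coordinates of all edges and insert None between edges, which plotly renders as a gap in the line. The helper below is a hypothetical sketch of that technique; it is not necessarily how add_3D_pose_to_fig builds its traces.

```python
import numpy as np


def edge_line_coords(coords, edges):
    """Build x, y, z lists that draw all skeleton edges as one line trace.

    Each edge contributes its two endpoints followed by a None separator,
    so consecutive edges are not joined to each other.
    """
    coords = np.asarray(coords, dtype=float)  # shape (num_keypoints, 3)
    xs, ys, zs = [], [], []
    for i, j in edges:
        for values, dim in ((xs, 0), (ys, 1), (zs, 2)):
            values.extend([float(coords[i, dim]), float(coords[j, dim]), None])
    return xs, ys, zs


pose = [[0, 0, 0], [1, 2, 3], [4, 5, 6]]
xs, ys, zs = edge_line_coords(pose, [(0, 1), (1, 2)])
```

The returned lists can be passed as x, y and z to plotly.graph_objects.Scatter3d(..., mode="lines").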

keypoint_moseq.viz.plot_pcs_3D(ymean, ypcs, edges, keypoint_colormap, project_dir=None, node_size=6, linewidth=2, height=400, mean_pose_opacity=0.2)

Visualize the components of a fitted PCA model based on 3D components.

For each PC, a subplot shows the mean pose (semi-transparent) along with a perturbation of the mean pose in the direction of the PC.

Parameters:
  • ymean (ndarray (num_bodyparts, 3)) – Mean pose.

  • ypcs (ndarray (num_pcs, num_bodyparts, 3)) – Perturbations of the mean pose in the direction of each PC.

  • edges (list of index pairs) – Skeleton edges.

  • keypoint_colormap (str) – Name of a matplotlib colormap to use for coloring the keypoints.

  • project_dir (str, default=None) – Path to the project directory. Required to save the figure.

  • node_size (float, default=6) – Size of the keypoints in the figure.

  • linewidth (float, default=2.0) – Width of skeleton edges.

  • height (int, default=400) – Height of the figure in pixels.

  • mean_pose_opacity (float, default=0.2) – Opacity of the mean pose.
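Arrays shaped like ymean and ypcs can be produced from raw 3D poses with a generic PCA, perturbing the mean pose along each principal direction by a multiple of that component's standard deviation. This sketch is an assumption on several counts: it ignores any alignment or centering keypoint_moseq applies before PCA, the scale factor is hypothetical, and it returns absolute perturbed poses (mean plus offset); the description above does not say whether plot_pcs_3D expects absolute poses or offsets, so check before passing these in.

```python
import numpy as np


def pca_mean_and_perturbations(poses, num_pcs=2, scale=2.0):
    """Generic PCA sketch returning a mean pose and one perturbed pose per PC.

    poses: array of shape (num_samples, num_bodyparts, 3).
    Returns ymean (num_bodyparts, 3) and perturbed (num_pcs, num_bodyparts, 3),
    where perturbed[i] = ymean + scale * std_i * pc_i.
    """
    poses = np.asarray(poses, dtype=float)
    n, k, d = poses.shape
    flat = poses.reshape(n, k * d)
    mean = flat.mean(axis=0)
    # Rows of vt are the unit-length principal directions
    _, s, vt = np.linalg.svd(flat - mean, full_matrices=False)
    std = s[:num_pcs] / np.sqrt(n - 1)  # standard deviation along each PC
    perturbed = mean + scale * std[:, None] * vt[:num_pcs]
    return mean.reshape(k, d), perturbed.reshape(num_pcs, k, d)


# Toy data: only the x-coordinate of bodypart 0 varies
poses = np.zeros((10, 2, 3))
poses[:, 0, 0] = np.linspace(-1, 1, 10)
ymean, ypcs = pca_mean_and_perturbations(poses, num_pcs=1)
```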

keypoint_moseq.viz.plot_trajectories_3D(Xs, titles, edges, output_dir, keypoint_colormap='autumn', node_size=8, linewidth=3, height=500, skiprate=1)

Visualize a set of 3D trajectories.

Parameters:
  • Xs (list of ndarrays (num_syllables, num_frames, num_bodyparts, 3)) – Trajectories to visualize.

  • titles (list of str) – Title for each trajectory.

  • edges (list of index pairs) – Skeleton edges.

  • output_dir (str) – Path to save the interactive plot.

  • keypoint_colormap (str, default='autumn') – Name of a matplotlib colormap to use for coloring the keypoints.

  • node_size (float, default=8.0) – Size of the keypoints in the figure.

  • linewidth (float, default=3.0) – Width of skeleton edges.

  • height (int, default=500) – Height of the figure in pixels.

  • skiprate (int, default=1) – Plot every skiprate-th frame.

keypoint_moseq.viz.plot_poses_3D(poses, edges, keypoint_colormap='autumn', node_size=6.0, linewidth=3.0)

Plot a sequence of 3D poses.

Parameters:
  • poses (array of shape (num_poses, num_bodyparts, 3)) – 3D poses to plot.

  • edges (list of index pairs) – Skeleton edges.

  • keypoint_colormap (str, default='autumn') – Colormap to use for coloring keypoints.

  • node_size (float, default=6.0) – Size of keypoints.

  • linewidth (float, default=3.0) – Width of skeleton edges.

keypoint_moseq.viz.hierarchical_clustering_order(X, dist_metric='euclidean', linkage_method='ward')

Linearly order a set of points using hierarchical clustering.

Parameters:
  • X (ndarray of shape (num_points, num_features)) – Points to order.

  • dist_metric (str, default='euclidean') – Distance metric to use.

  • linkage_method (str, default='ward') – Linkage method to use.

Returns:

ordering – Linear ordering of the points.

Return type:

ndarray of shape (num_points,)
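A standard way to obtain such an ordering with SciPy is to build a linkage tree from pairwise distances and read off the left-to-right order of its leaves. The sketch below assumes that approach (clustering_order is a hypothetical name, and keypoint_moseq's implementation may differ in details). Note that SciPy's 'ward' linkage is only well defined for Euclidean distances.

```python
import numpy as np
from scipy.cluster.hierarchy import leaves_list, linkage
from scipy.spatial.distance import pdist


def clustering_order(X, dist_metric="euclidean", linkage_method="ward"):
    """Order points by the leaf order of a hierarchical-clustering dendrogram."""
    distances = pdist(np.asarray(X, dtype=float), metric=dist_metric)  # condensed form
    Z = linkage(distances, method=linkage_method)
    return leaves_list(Z)  # indices of X in dendrogram leaf order


# Two well-separated groups: {0, 2} near 0 and {1, 3} near 10
X = np.array([[0.0], [10.0], [0.1], [10.1]])
order = clustering_order(X)
```

Points that merge early in the tree end up adjacent in the ordering, which is why it is useful for arranging rows and columns of a similarity matrix.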

keypoint_moseq.viz.plot_confusion_matrix(results1, results2, min_frequency=0.005, sort=True, normalize=True)

Plot a confusion matrix that compares syllables across two models.

Parameters:
  • results1 (dict) – Dictionary containing modeling results for the first model (see keypoint_moseq.fitting.extract_results()).

  • results2 (dict) – Dictionary containing modeling results for the second model (see keypoint_moseq.fitting.extract_results()).

  • min_frequency (float, default=0.005) – Minimum frequency of a syllable to include in the confusion matrix.

  • sort (bool, default=True) – Whether to sort the syllables from each model to emphasize the diagonal.

  • normalize (bool, default=True) – Whether to row-normalize the confusion matrix.

Returns:

  • fig (matplotlib figure) – Figure containing the confusion matrix.

  • ax (matplotlib axis) – Axis containing the confusion matrix.
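The core computation behind a row-normalized confusion matrix can be sketched on two frame-aligned syllable sequences. This is a hypothetical helper, not the library's code. It assumes both models labeled the same frames, and it defines syllable frequency as the fraction of frames, whereas keypoint_moseq may count syllable instances instead. Sorting to emphasize the diagonal is omitted.

```python
import numpy as np


def syllable_confusion(labels1, labels2, min_frequency=0.005, normalize=True):
    """Frame-wise confusion matrix between two models' syllable labels.

    Rows index model 1 syllables, columns index model 2 syllables. Syllables
    below min_frequency (fraction of frames) are dropped from their axis.
    """
    labels1, labels2 = np.asarray(labels1), np.asarray(labels2)

    def frequent(labels):
        values, counts = np.unique(labels, return_counts=True)
        return values[counts / len(labels) >= min_frequency]

    syls1, syls2 = frequent(labels1), frequent(labels2)
    index1 = {s: i for i, s in enumerate(syls1)}
    index2 = {s: j for j, s in enumerate(syls2)}
    C = np.zeros((len(syls1), len(syls2)))
    for a, b in zip(labels1, labels2):
        if a in index1 and b in index2:
            C[index1[a], index2[b]] += 1
    if normalize:
        # Row-normalize, leaving rows with no counts as zeros
        row_sums = C.sum(axis=1, keepdims=True)
        C = np.divide(C, row_sums, out=np.zeros_like(C), where=row_sums > 0)
    return C, syls1, syls2


C, syls1, syls2 = syllable_confusion([0, 0, 1, 1], [5, 5, 5, 7], min_frequency=0.0)
```

Each row then reads as the distribution of model 2 syllables over the frames that model 1 assigned to a given syllable.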

keypoint_moseq.viz.plot_eml_scores(eml_scores, eml_std_errs, model_names)

Plot expected marginal likelihood scores for a set of models.

Parameters:
  • eml_scores (ndarray of shape (num_models,)) – EML score for each model.

  • eml_std_errs (ndarray of shape (num_models,)) – Standard error of the EML score for each model.

  • model_names (list of str) – Name of each model.
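Because a higher expected marginal likelihood indicates a better-supported model, it can help to sort models by score before plotting them with their standard errors. The sketch below is a hypothetical helper that only prepares the inputs; it makes no claim about how plot_eml_scores orders or draws them.

```python
import numpy as np


def rank_models_by_eml(eml_scores, eml_std_errs, model_names):
    """Return (name, score, std_err) tuples sorted from highest to lowest EML score."""
    scores = np.asarray(eml_scores, dtype=float)
    errs = np.asarray(eml_std_errs, dtype=float)
    order = np.argsort(scores)[::-1]  # descending by score
    return [(model_names[i], float(scores[i]), float(errs[i])) for i in order]


ranked = rank_models_by_eml([1.0, 3.0, 2.0], [0.1, 0.2, 0.3], ["a", "b", "c"])
```

The sorted scores and errors can then be drawn with matplotlib's Axes.errorbar, using the standard errors as yerr.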