Utilities

Functions:

np_io(fn)

Converts a function involving jax arrays to one that inputs and outputs numpy arrays.

print_dims_to_explain_variance(pca, f)

Print the number of principal components required to explain a given fraction of variance.

list_files_with_exts(filepath_pattern, ext_list)

List all files that match a pattern and have an extension from a given list of extensions.

find_matching_videos(keys, video_dir[, ...])

Find video files for a set of recording names. The filename of each video is assumed to be a prefix within the recording name, i.e. the recording name has the form {video_name}{more_text}. If more than one video matches a recording name, the longest match will be used.

pad_along_axis(arr, pad_widths[, axis, value])

Pad an array along a single axis.

filter_angle(angles[, size, axis, method])

Perform median filtering on time-series of angles by transforming to a (cos,sin) representation, filtering in R^2, and then transforming back into angle space.

get_centroids_headings(coordinates, ...[, ...])

Compute centroids and headings from keypoint coordinates.

filter_centroids_headings(centroids, headings)

Perform median filtering on centroids and headings.

get_syllable_instances(stateseqs[, ...])

Map each syllable to a list of instances when it occurred.

get_edges(use_bodyparts, skeleton)

Represent the skeleton as a list of index-pairs.

reindex_by_bodyparts(data, bodyparts, ...[, ...])

Use an ordered list of bodyparts to reindex keypoint coordinates.

get_instance_trajectories(...[, pre, post, ...])

Extract keypoint trajectories for a collection of syllable instances.

sample_instances(syllable_instances, num_samples)

Sample a fixed number of instances for each syllable.

interpolate_along_axis(x, xp, fp[, axis])

Linearly interpolate along a given axis.

interpolate_keypoints(coordinates, outliers)

Use linear interpolation to impute the coordinates of outliers.

filtered_derivative(Y_flat, ksize[, axis])

Compute the filtered derivative of a signal along a given axis.

permute_cyclic(arr[, mask, axis])

Cyclically permute an array along a given axis.

check_nan_proportions(coordinates, bodyparts)

Check if any bodyparts have a high proportion of NaNs.

format_data(coordinates[, confidences, ...])

Format keypoint coordinates and confidences for inference.

get_typical_trajectories(coordinates, results)

Generate representative keypoint trajectories for each syllable.

syllable_similarity(coordinates, results[, ...])

Generate a distance matrix over syllable trajectories.

downsample_timepoints(data, downsample_rate)

Downsample timepoints, e.g. of coordinates or confidences.

check_video_paths(video_paths, keys)

Check if video paths are valid and match the keys.

keypoint_moseq.util.np_io(fn)

Converts a function involving jax arrays to one that inputs and outputs numpy arrays.

keypoint_moseq.util.print_dims_to_explain_variance(pca, f)

Print the number of principal components required to explain a given fraction of variance.
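
As an illustration of the computation, the following minimal sketch finds the number of components from a list of explained-variance ratios. It assumes the fitted model exposes per-component ratios sorted in decreasing order (as scikit-learn's PCA does via explained_variance_ratio_). The helper dims_to_explain_variance and the example ratios are hypothetical and not part of keypoint_moseq.

```python
from itertools import accumulate


def dims_to_explain_variance(explained_variance_ratio, f):
    """Hypothetical helper: smallest number of leading components whose
    cumulative explained-variance ratio reaches f (None if never reached)."""
    cumulative = accumulate(explained_variance_ratio)
    for n_dims, total in enumerate(cumulative, start=1):
        if total >= f:
            return n_dims
    return None


# Made-up ratios for illustration, sorted in decreasing order.
ratios = [0.5, 0.25, 0.125, 0.0625]
n_dims = dims_to_explain_variance(ratios, 0.8)
print(f"{n_dims} components explain at least 80% of the variance")
```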

Parameters:
  • pca – A fitted PCA model.

  • f (float) – Target fraction of variance to explain.

keypoint_moseq.util.list_files_with_exts(filepath_pattern, ext_list, recursive=True)

List all files that match a pattern and have an extension from a given list of extensions.

Parameters:
  • filepath_pattern (str or list) – A filepath pattern or a list thereof. A filepath pattern can be a single file, a directory, or a path with wildcards (e.g., ‘/path/to/dir/prefix*’).

  • ext_list (list of str) – A list of file extensions to search for.

  • recursive (bool, default=True) – Whether to search for files recursively.

Returns:

A list of file paths.

Return type:

list
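
The behavior described above can be approximated with the standard library. This is a minimal sketch, not the library's implementation: list_files_with_exts_sketch is a hypothetical name, and the case-insensitive extension comparison and the sorted output are assumptions.

```python
import glob
import os
import tempfile


def list_files_with_exts_sketch(filepath_pattern, ext_list, recursive=True):
    """Hypothetical re-implementation for illustration only."""
    if isinstance(filepath_pattern, str):
        patterns = [filepath_pattern]
    else:
        patterns = list(filepath_pattern)
    # Normalize extensions so "mp4" and ".MP4" are treated alike (assumption).
    exts = {"." + ext.lower().lstrip(".") for ext in ext_list}
    candidates = set()
    for pattern in patterns:
        for path in glob.glob(pattern):
            if os.path.isdir(path):
                for root, _dirs, files in os.walk(path):
                    candidates.update(os.path.join(root, f) for f in files)
                    if not recursive:
                        break  # os.walk yields the top-level directory first
            else:
                candidates.add(path)
    return sorted(p for p in candidates
                  if os.path.splitext(p)[1].lower() in exts)


# Build a throwaway directory tree to exercise the sketch.
with tempfile.TemporaryDirectory() as demo_dir:
    os.makedirs(os.path.join(demo_dir, "sub"))
    for name in ["a.mp4", "b.txt", os.path.join("sub", "c.avi")]:
        open(os.path.join(demo_dir, name), "w").close()
    found = [os.path.relpath(p, demo_dir)
             for p in list_files_with_exts_sketch(demo_dir, ["mp4", ".avi"])]
    found_flat = [os.path.relpath(p, demo_dir)
                  for p in list_files_with_exts_sketch(
                      demo_dir, ["mp4", "avi"], recursive=False)]

print(found)       # a.mp4 and the .avi file inside "sub"
print(found_flat)  # only a.mp4
```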

keypoint_moseq.util.find_matching_videos(keys, video_dir, as_dict=False, recursive=True, recording_name_suffix='', video_extension=None)

Find video files for a set of recording names. The filename of each video is assumed to be a prefix within the recording name, i.e. the recording name has the form {video_name}{more_text}. If more than one video matches a recording name, the longest match will be used. For example, given the following video directory:

video_dir
├─ videoname1.avi
└─ videoname2.avi

the videos would be matched to recording names as follows:

>>> keys = ['videoname1blahblah','videoname2yadayada']
>>> find_matching_videos(keys, video_dir, as_dict=True)

{'videoname1blahblah': 'video_dir/videoname1.avi',
 'videoname2yadayada': 'video_dir/videoname2.avi'}

A suffix can also be specified, in which case the recording name is assumed to have the form {video_name}{suffix}{more_text}.

Parameters:
  • keys (iterable) – Recording names (as strings)

  • video_dir (str) – Path to the video directory.

  • video_extension (str, default=None) – Extension of the video files. If None, videos are assumed to have one of the following extensions: “mp4”, “avi”, “mov”

  • recursive (bool, default=True) – If True, search recursively for videos in subdirectories of video_dir.

  • as_dict (bool, default=False) – Determines whether to return a dict mapping recording names to video paths, or a list of paths in the same order as keys.

  • recording_name_suffix (str, default='') – Suffix to append to the video name when searching for a match.

Returns:

video_paths

Return type:

list or dict (depending on as_dict)
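
The longest-prefix rule can be sketched as follows. This is an illustration under stated assumptions rather than the library's code: match_videos_by_prefix is a hypothetical helper that takes an explicit list of video paths instead of scanning video_dir, and raising ValueError when nothing matches is an assumption.

```python
import os


def match_videos_by_prefix(keys, video_paths, recording_name_suffix=""):
    """Hypothetical helper: map each recording name to the video whose
    name (plus optional suffix) is the longest prefix of that name."""
    # Video name = file name without directory or extension.
    names = {os.path.splitext(os.path.basename(p))[0]: p for p in video_paths}
    matches = {}
    for key in keys:
        candidates = [name for name in names
                      if key.startswith(name + recording_name_suffix)]
        if not candidates:
            # Assumption: a missing match is treated as an error.
            raise ValueError(f"No matching video found for {key!r}")
        matches[key] = names[max(candidates, key=len)]
    return matches


video_paths = [
    "video_dir/videoname.avi",   # shorter prefix, loses to the longer matches
    "video_dir/videoname1.avi",
    "video_dir/videoname2.avi",
]
keys = ["videoname1blahblah", "videoname2yadayada"]
matched = match_videos_by_prefix(keys, video_paths)
print(matched)
```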

keypoint_moseq.util.pad_along_axis(arr, pad_widths, axis=0, value=0)

Pad an array along a single axis.

Parameters:
  • arr (ndarray) – Array to be padded

  • pad_widths (tuple (int, int)) – Amount of padding on either end

  • axis (int, default=0) – Axis along which to add padding

  • value (float, default=0) – Value of padded array elements

Returns:

padded_arr

Return type:

ndarray

keypoint_moseq.util.filter_angle(angles, size=9, axis=0, method='median')

Perform median filtering on time-series of angles by transforming to a (cos,sin) representation, filtering in R^2, and then transforming back into angle space.

Parameters:
  • angles (ndarray) – Array of angles (in radians)

  • size (int, default=9) – Size of the filtering kernel

  • axis (int, default=0) – Axis along which to filter

  • method (str, default='median') – Method for filtering. Options are ‘median’ and ‘gaussian’

Returns:

filtered_angles

Return type:

ndarray
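
To show why the (cos, sin) representation helps, here is a minimal pure-Python sketch of the median variant for a 1D series. The names filter_angle_sketch and median_filter_1d are hypothetical, and truncating the window at the edges is an assumption that may differ from the library's boundary handling.

```python
import math
from statistics import median


def median_filter_1d(values, size):
    """Sliding-window median; windows are truncated at the edges (assumption)."""
    half = size // 2
    return [median(values[max(0, i - half): i + half + 1])
            for i in range(len(values))]


def filter_angle_sketch(angles, size=9):
    # Filter the cosine and sine components separately in R^2 ...
    cos_filtered = median_filter_1d([math.cos(a) for a in angles], size)
    sin_filtered = median_filter_1d([math.sin(a) for a in angles], size)
    # ... then map back to angle space.
    return [math.atan2(s, c) for s, c in zip(sin_filtered, cos_filtered)]


# Headings that hover around +/-pi, with one spurious reading of 0.0.
angles = [3.1, -3.1, 3.1, 0.0, -3.1, 3.1, -3.1]
filtered = filter_angle_sketch(angles, size=3)
print(filtered)  # every value stays close to +pi or -pi
```

A median taken directly on the raw angles would return 0.0 for the window [3.1, 0.0, -3.1], because 3.1 and -3.1 look far apart numerically even though they are nearly the same direction.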

keypoint_moseq.util.get_centroids_headings(coordinates, anterior_idxs, posterior_idxs, bodyparts=None, use_bodyparts=None, **kwargs)

Compute centroids and headings from keypoint coordinates.

Parameters:
  • coordinates (dict) – Dictionary mapping recording names to keypoint coordinates as ndarrays of shape (n_frames, n_bodyparts, [2 or 3]).

  • anterior_idxs (array-like of int) – Indices of anterior bodyparts (after reindexing by use_bodyparts when the latter is specified).

  • posterior_idxs (array-like of int) – Indices of posterior bodyparts (after reindexing by use_bodyparts when the latter is specified).

  • bodyparts (list of str, default=None) – List of bodypart names in coordinates. Used to reindex coordinates when use_bodyparts is specified.

  • use_bodyparts (list of str, default=None) – Ordered list of bodyparts used to reindex coordinates.

Returns:

  • centroids (dict) – Dictionary mapping recording names to centroid coordinates as ndarrays of shape (n_frames, [2 or 3]).

  • headings (dict) – Dictionary mapping recording names to heading angles (in radians) as 1d arrays of shape (n_frames,).

keypoint_moseq.util.filter_centroids_headings(centroids, headings, filter_size=9)

Perform median filtering on centroids and headings.

Parameters:
  • centroids (dict) – Centroids stored as a dictionary mapping recording names to ndarrays, of shape (n_frames, [2 or 3]).

  • headings (dict) – Dictionary mapping recording names to heading angles (in radians) as 1d arrays of shape (n_frames,).

  • filter_size (int, default=9) – Kernel size for median filtering

Returns:

  • filtered_centroids (dict)

  • filtered_headings (dict)

keypoint_moseq.util.get_syllable_instances(stateseqs, min_duration=3, pre=30, post=60, min_frequency=0, min_instances=0)

Map each syllable to a list of instances when it occurred. Only include instances that meet the criteria specified by pre, post, and min_duration. Only include syllables that meet the criteria specified by min_frequency and min_instances.

Parameters:
  • stateseqs (dict {str : 1d array}) – Dictionary mapping names to syllable sequences

  • min_duration (int, default=3) – Minimum duration for inclusion of a syllable instance

  • pre (int, default=30) – Syllable instances that start before this location in the state sequence will be excluded

  • post (int, default=60) – Syllable instances that end after this location in the state sequence will be excluded

  • min_frequency (int, default=0) – Minimum allowed frequency (across all state sequences) for inclusion of a syllable

  • min_instances (int, default=0) – Minimum number of instances (across all state sequences) for inclusion of a syllable

Returns:

syllable_instances – Dictionary mapping each syllable to a list of instances. Each instance is a tuple (name,start,end) representing subsequence stateseqs[name][start:end].

Return type:

dict
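
The (name, start, end) format can be illustrated with a simplified sketch that applies only the min_duration criterion. The pre, post, min_frequency and min_instances filters are deliberately omitted, and syllable_instances_sketch is a hypothetical name.

```python
from collections import defaultdict
from itertools import groupby


def syllable_instances_sketch(stateseqs, min_duration=3):
    """Hypothetical helper: collect runs of identical states as
    (name, start, end) tuples, keeping runs of at least min_duration."""
    instances = defaultdict(list)
    for name, seq in stateseqs.items():
        start = 0
        for syllable, run in groupby(seq):
            end = start + len(list(run))
            if end - start >= min_duration:
                instances[syllable].append((name, start, end))
            start = end
    return dict(instances)


stateseqs = {"rec1": [0, 0, 0, 1, 1, 2, 2, 2, 2, 0, 0, 0]}
instances = syllable_instances_sketch(stateseqs)
print(instances)
# Syllable 1 lasts only two frames, so it is dropped.

# Each tuple indexes back into the original sequence.
name, start, end = instances[2][0]
print(stateseqs[name][start:end])
```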

keypoint_moseq.util.get_edges(use_bodyparts, skeleton)

Represent the skeleton as a list of index-pairs.

Parameters:
  • use_bodyparts (list) – Bodypart names

  • skeleton (list) – Pairs of bodypart names as tuples (bodypart1,bodypart2)

Returns:

edges – Pairs of indexes representing the entries of skeleton

Return type:

list
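
A minimal sketch of the name-to-index conversion is shown below. The helper get_edges_sketch is hypothetical, and silently skipping pairs that mention a bodypart outside use_bodyparts is an assumption about the intended behavior.

```python
def get_edges_sketch(use_bodyparts, skeleton):
    """Hypothetical helper: convert pairs of bodypart names to index pairs."""
    index = {name: i for i, name in enumerate(use_bodyparts)}
    # Assumption: pairs with a bodypart missing from use_bodyparts are skipped.
    return [(index[a], index[b]) for a, b in skeleton
            if a in index and b in index]


use_bodyparts = ["nose", "neck", "tail"]
skeleton = [("nose", "neck"), ("neck", "tail"), ("tail", "paw")]
edges = get_edges_sketch(use_bodyparts, skeleton)
print(edges)
```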

keypoint_moseq.util.reindex_by_bodyparts(data, bodyparts, use_bodyparts, axis=1)

Use an ordered list of bodyparts to reindex keypoint coordinates.

Parameters:
  • data (dict or ndarray) – A single array of keypoint coordinates or a dict mapping from names to arrays of keypoint coordinates

  • bodyparts (list) – Label for each keypoint represented in data

  • use_bodyparts (list) – Ordered subset of keypoint labels

  • axis (int, default=1) – The axis in data that represents keypoints. It is required that data.shape[axis]==len(bodyparts).

Returns:

reindexed_data – Keypoint coordinates in the same form as data with reindexing applied.

Return type:

ndarray or dict
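
The reindexing can be sketched with nested lists standing in for arrays. This hypothetical reindex_sketch only handles the default case where keypoints lie on axis 1 (frames on axis 0); it is an illustration, not the library's implementation.

```python
def reindex_sketch(data, bodyparts, use_bodyparts):
    """Hypothetical helper: select and reorder keypoints (axis 1) so that
    they follow the order given by use_bodyparts."""
    order = [bodyparts.index(bp) for bp in use_bodyparts]

    def reindex_frames(frames):
        return [[frame[i] for i in order] for frame in frames]

    if isinstance(data, dict):
        return {name: reindex_frames(frames) for name, frames in data.items()}
    return reindex_frames(data)


bodyparts = ["tail", "nose", "neck"]
use_bodyparts = ["nose", "tail"]
# Two frames, three keypoints, 2D coordinates.
frames = [[(0, 0), (1, 1), (2, 2)],
          [(3, 3), (4, 4), (5, 5)]]
reindexed = reindex_sketch(frames, bodyparts, use_bodyparts)
print(reindexed)
```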

keypoint_moseq.util.get_instance_trajectories(syllable_instances, coordinates, pre=0, post=None, centroids=None, headings=None, filter_size=9)

Extract keypoint trajectories for a collection of syllable instances.

If centroids and headings are provided, each trajectory is transformed into the ego-centric reference frame from the moment of syllable onset. When post is not None, trajectories will all terminate a fixed number of frames after syllable onset.

Parameters:
  • syllable_instances (list) – List of syllable instances, where each instance is a tuple of the form (name,start,end)

  • coordinates (dict) – Dictionary mapping names to coordinates, formatted as ndarrays with shape (num_frames, num_keypoints, d)

  • pre (int, default=0) – Number of frames to include before syllable onset

  • post (int, default=None) – Determines the length of the trajectory. When post=None, the trajectory terminates at the end of the syllable instance. Otherwise the trajectory terminates a fixed number of frames after syllable onset (where the number is determined by post).

  • centroids (dict, default=None) – Dictionary with the same keys as coordinates mapping each name to an ndarray with shape (num_frames, d)

  • headings (dict, default=None) – Dictionary with the same keys as coordinates mapping each name to a 1d array of heading angles in radians

  • filter_size (int, default=9) – Size of median filter applied to centroids and headings

Returns:

trajectories – List or array of trajectories (a list is used when post=None, otherwise an array). Each trajectory is an array of shape (n_frames, n_bodyparts, [2 or 3]).

Return type:

list or ndarray

keypoint_moseq.util.sample_instances(syllable_instances, num_samples, mode='random', pca_samples=50000, pca_dim=4, n_neighbors=50, coordinates=None, pre=5, post=15, centroids=None, headings=None, filter_size=9)

Sample a fixed number of instances for each syllable.

Parameters:
  • syllable_instances (dict) – Mapping from each syllable to a list of instances, where each instance is a tuple of the form (name,start,end)

  • num_samples (int) – Number of samples to return for each syllable

  • mode (str, {'random', 'density'}, default='random') –

    Sampling method to use. Options are:

    • ’random’: Instances are chosen randomly (without replacement)

    • ’density’: For each syllable, a syllable-specific density function is computed in trajectory space and compared to the overall density across all syllables. An exemplar instance that maximizes this ratio is chosen for each syllable, and its nearest neighbors are randomly sampled.

  • pca_samples (int, default=50000) – Number of trajectories to sample when fitting a PCA model for density estimation (used when mode=’density’)

  • pca_dim (int, default=4) – Number of principal components to use for density estimation (used when mode=’density’)

  • n_neighbors (int, default=50) – Number of neighbors to use for density estimation and for sampling the neighbors of the exemplar syllable instance (used when mode=’density’)

  • coordinates – Passed to keypoint_moseq.util.get_instance_trajectories()

  • pre – Passed to keypoint_moseq.util.get_instance_trajectories()

  • post – Passed to keypoint_moseq.util.get_instance_trajectories()

  • centroids – Passed to keypoint_moseq.util.get_instance_trajectories()

  • headings – Passed to keypoint_moseq.util.get_instance_trajectories()

  • filter_size – Passed to keypoint_moseq.util.get_instance_trajectories()

Returns:

sampled_instances – Dictionary in the same format as syllable_instances mapping each syllable to a list of sampled instances.

Return type:

dict
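
The 'random' mode amounts to sampling without replacement per syllable, which can be sketched as below. The helper sample_instances_random and its seed argument are hypothetical, and skipping syllables with fewer than num_samples instances is an assumption. The 'density' mode is not covered.

```python
import random


def sample_instances_random(syllable_instances, num_samples, seed=None):
    """Hypothetical helper: draw num_samples instances per syllable,
    without replacement."""
    rng = random.Random(seed)
    return {
        syllable: rng.sample(instances, num_samples)
        for syllable, instances in syllable_instances.items()
        # Assumption: syllables with too few instances are skipped.
        if len(instances) >= num_samples
    }


syllable_instances = {
    0: [("rec1", 0, 5), ("rec1", 20, 24), ("rec2", 3, 9), ("rec2", 40, 47)],
    1: [("rec1", 5, 12), ("rec2", 9, 15)],
    2: [("rec1", 12, 20)],
}
sampled = sample_instances_random(syllable_instances, num_samples=2, seed=0)
print(sampled)
```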

keypoint_moseq.util.interpolate_along_axis(x, xp, fp, axis=0)

Linearly interpolate along a given axis.

Parameters:
  • x (1D array) – The x-coordinates of the interpolated values

  • xp (1D array) – The x-coordinates of the data points

  • fp (ndarray) – The y-coordinates of the data points. fp.shape[axis] must be equal to the length of xp.

  • axis (int, default=0) – The axis along which to interpolate.

Returns:

x_interp – The interpolated values, with the same shape as fp except along the interpolation axis.

Return type:

ndarray

keypoint_moseq.util.interpolate_keypoints(coordinates, outliers)

Use linear interpolation to impute the coordinates of outliers.

Parameters:
  • coordinates (ndarray of shape (num_frames, num_keypoints, dim)) – Keypoint observations.

  • outliers (ndarray of shape (num_frames, num_keypoints)) – Binary indicator whose true entries are outlier points.

Returns:

interpolated_coordinates – Keypoint observations with outliers imputed.

Return type:

ndarray with same shape as coordinates
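
For a single coordinate of a single keypoint, the imputation can be sketched in pure Python as follows. The helper interpolate_series_sketch is hypothetical, and holding the nearest valid value at the ends of the series is an assumption.

```python
def interpolate_series_sketch(values, outliers):
    """Hypothetical helper: replace flagged entries of a 1D series by linear
    interpolation between the nearest unflagged neighbors."""
    good = [i for i, bad in enumerate(outliers) if not bad]
    if not good:
        return list(values)  # nothing to interpolate from
    result = list(values)
    for i, bad in enumerate(outliers):
        if not bad:
            continue
        # Simple linear scans keep the sketch short; not optimized.
        left = max((g for g in good if g < i), default=None)
        right = min((g for g in good if g > i), default=None)
        if left is None:
            result[i] = values[right]   # assumption: hold first valid value
        elif right is None:
            result[i] = values[left]    # assumption: hold last valid value
        else:
            w = (i - left) / (right - left)
            result[i] = (1 - w) * values[left] + w * values[right]
    return result


values = [0.0, 100.0, 2.0, 3.0, -50.0]
outliers = [False, True, False, False, True]
imputed = interpolate_series_sketch(values, outliers)
print(imputed)
```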

keypoint_moseq.util.filtered_derivative(Y_flat, ksize, axis=0)

Compute the filtered derivative of a signal along a given axis.

When ksize=3, for example, the filtered derivative is

\[\dot{y}_t = \frac{1}{3}\left(y_{t+3}+y_{t+2}+y_{t+1}-y_{t-1}-y_{t-2}-y_{t-3}\right)\]

Parameters:
  • Y_flat (ndarray) – The signal to differentiate

  • ksize (int) – The size of the filter. Must be odd.

  • axis (int, default=0) – The axis along which to differentiate

Returns:

dY – The filtered derivative of the signal

Return type:

ndarray

keypoint_moseq.util.permute_cyclic(arr, mask=None, axis=0)

Cyclically permute an array along a given axis.

Parameters:
  • arr (ndarray) – The array to permute

  • mask (ndarray, optional) – A boolean mask indicating which elements to permute. If None, all elements are permuted.

  • axis (int, default=0) – The axis along which to permute

Returns:

arr_permuted – The permuted array

Return type:

ndarray

keypoint_moseq.util.check_nan_proportions(coordinates, bodyparts, warning_threshold=0.5, breakdown=False, **kwargs)

Check if any bodyparts have a high proportion of NaNs.

Parameters:
  • coordinates (dict) – Dictionary mapping filenames to keypoint coordinates as ndarrays of shape (n_frames, n_bodyparts, 2)

  • bodyparts (list of str) – Name of each bodypart. The order of the names should match the order of the bodyparts in coordinates.

  • warning_threshold (float, default=0.5) – If the proportion of NaNs for a bodypart is greater than warning_threshold, then a warning is printed.

  • breakdown (bool, default=False) – Whether to print a table detailing the proportion of NaNs for each bodypart in each array of coordinates.
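
The check can be sketched with nested lists standing in for arrays. The helper nan_proportions_sketch is hypothetical, and counting a keypoint as missing when any of its coordinates is NaN is an assumption about how the proportion is defined.

```python
import math


def nan_proportions_sketch(coordinates, bodyparts):
    """Hypothetical helper: fraction of frames (pooled over recordings)
    in which each bodypart is missing."""
    nan_counts = [0] * len(bodyparts)
    n_frames = 0
    for frames in coordinates.values():
        n_frames += len(frames)
        for frame in frames:
            for k, point in enumerate(frame):
                # Assumption: missing if any coordinate is NaN.
                if any(math.isnan(c) for c in point):
                    nan_counts[k] += 1
    return {bp: count / n_frames for bp, count in zip(bodyparts, nan_counts)}


nan = float("nan")
coordinates = {
    "rec1": [[(0.0, 0.0), (nan, nan)], [(1.0, 1.0), (nan, 2.0)]],
    "rec2": [[(2.0, 2.0), (3.0, 3.0)], [(nan, nan), (nan, nan)]],
}
proportions = nan_proportions_sketch(coordinates, ["nose", "tail"])
warning_threshold = 0.5
flagged = [bp for bp, p in proportions.items() if p > warning_threshold]
print(proportions)
print("High NaN proportion:", flagged)
```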

keypoint_moseq.util.format_data(coordinates, confidences=None, keys=None, seg_length=None, bodyparts=None, use_bodyparts=None, conf_pseudocount=0.001, added_noise_level=0.1, **kwargs)

Format keypoint coordinates and confidences for inference.

Data are transformed as follows:
  1. Coordinates and confidences are each merged into a single array using keypoint_moseq.util.batch(). Each row of the merged arrays is a segment from one recording.

  2. The keypoints axis is reindexed according to the order of elements in use_bodyparts with respect to their initial order in bodyparts.

  3. Uniform noise proportional to added_noise_level is added to the keypoint coordinates to prevent degenerate solutions during fitting.

  4. Keypoint confidences are augmented by conf_pseudocount.

  5. Wherever NaNs occur in the coordinates, they are replaced by values imputed using linear interpolation, and the corresponding confidences are set to conf_pseudocount.

Parameters:
  • coordinates (dict) – Keypoint coordinates for a collection of recordings. Values must be numpy arrays of shape (T,K,D) where K is the number of keypoints and D={2 or 3}.

  • confidences (dict, default=None) – Nonnegative confidence values for the keypoints in coordinates as numpy arrays of shape (T,K).

  • keys (list of str, default=None) – (See keypoint_moseq.util.batch())

  • bodyparts (list, default=None) – Label for each keypoint represented in coordinates. Required to reindex coordinates and confidences according to use_bodyparts.

  • use_bodyparts (list, default=None) – Ordered subset of keypoint labels to be used for modeling. If use_bodyparts=None, then all keypoints are used.

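  • conf_pseudocount (float, default=1e-3) – Pseudocount used to augment keypoint confidences.

  • added_noise_level (float, default=0.1) – Scale of the uniform noise added to the keypoint coordinates (see step 3 above).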

  • seg_length (int, default=None) – Length of each segment. If seg_length=None, a length is chosen so that no time-series are broken into multiple segments. If all time-series are shorter than seg_length, then seg_length is set to the length of the shortest time-series.

Returns:

  • data (dict with the following items) –

    Y: jax array with shape (n_segs, seg_length, K, D)

    Keypoint coordinates from all recordings broken into fixed-length segments.

    conf: jax array with shape (n_segs, seg_length, K)

    Confidences from all recordings broken into fixed-length segments. If no input is provided for confidences, then data[“conf”]=None.

    mask: jax array with shape (n_segs, seg_length)

    Binary array where 0 indicates areas of padding (see keypoint_moseq.util.batch()).

  • metadata (tuple (keys, bounds)) – Metadata for the rows of Y, conf and mask, as a tuple with an array of recording names and an array of (start,end) times. See jax_moseq.utils.batch() for details.

keypoint_moseq.util.get_typical_trajectories(coordinates, results, pre=5, post=15, min_frequency=0.005, min_duration=3, bodyparts=None, use_bodyparts=None, density_sample=True, sampling_options={'n_neighbors': 50})

Generate representative keypoint trajectories for each syllable.

Parameters:
  • coordinates (dict) – Dictionary mapping recording names to keypoint coordinates as ndarrays of shape (n_frames, n_bodyparts, 2).

  • results (dict) – Dictionary containing modeling results for a dataset (see keypoint_moseq.fitting.extract_results()).

  • pre (int, default=5), post (int, default=15) – Define the temporal window around syllable onset for computing the average trajectory. Note that the window is independent of the actual duration of the syllable.

  • min_frequency (float, default=0.005) – Minimum frequency required for a syllable to be included.

  • min_duration (float, default=3) – Minimum duration of a syllable instance to be included in the trajectory average.

  • bodyparts (list of str, default=None) – List of bodypart names in coordinates.

  • use_bodyparts (list of str, default=None) – Ordered list of bodyparts to include in each trajectory. If None, all bodyparts will be included.

  • density_sample (bool, default=True) – Whether to use density sampling when generating trajectories. If True, the trajectory is based on the most exemplary syllable instances, rather than being averaged across all instances.

  • sampling_options (dict, default={'n_neighbors':50}) – Dictionary of options for sampling syllable instances (see keypoint_moseq.util.sample_instances()). Only used when density_sample is True.

Returns:

representative_trajectories – Dictionary mapping syllable indexes to representative trajectories as arrays of shape (pre+post, n_bodyparts, [2 or 3]).

Return type:

dict

keypoint_moseq.util.syllable_similarity(coordinates, results, metric='cosine', pre=5, post=15, min_frequency=0.005, min_duration=3, bodyparts=None, use_bodyparts=None, density_sample=False, sampling_options={'n_neighbors': 50}, **kwargs)

Generate a distance matrix over syllable trajectories.

See keypoint_moseq.util.get_typical_trajectories() for a description of the parameters not listed below.

Parameters:

metric (str, default='cosine') – Distance metric to use. See scipy.spatial.distance.pdist() for options.

Returns:

  • distances (ndarray of shape (n_syllables, n_syllables)) – Pairwise distances between the typical trajectories associated with each syllable. Only syllables with sufficient frequency of occurrence are included.

  • syllable_ixs (array of int) – Syllable indexes corresponding to the rows and columns of distances.
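
To make the two return values concrete, the sketch below computes a cosine distance matrix over a few made-up trajectories in pure Python. It assumes each trajectory is flattened into a single vector before comparison. The helper names and the toy trajectories are hypothetical, and the library itself relies on SciPy rather than this code.

```python
import math


def cosine_distance(u, v):
    """1 minus the cosine similarity of two vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return 1.0 - dot / (norm_u * norm_v)


def distance_matrix_sketch(trajectories):
    """Hypothetical helper: pairwise cosine distances between flattened
    trajectories, plus the syllable indexes labelling rows and columns."""
    syllable_ixs = sorted(trajectories)
    # Assumption: flatten (frames, keypoints, dims) into one vector.
    flat = [[c for frame in trajectories[ix] for point in frame for c in point]
            for ix in syllable_ixs]
    distances = [[cosine_distance(u, v) for v in flat] for u in flat]
    return distances, syllable_ixs


# Toy trajectories: 2 frames, 1 keypoint, 2D coordinates.
trajectories = {
    0: [[(1.0, 0.0)], [(1.0, 0.0)]],
    3: [[(0.0, 1.0)], [(0.0, 1.0)]],
    5: [[(2.0, 0.0)], [(2.0, 0.0)]],  # same direction as syllable 0
}
distances, syllable_ixs = distance_matrix_sketch(trajectories)
print(syllable_ixs)
for row in distances:
    print([round(d, 3) for d in row])
```

Row i and column j of distances refer to syllable_ixs[i] and syllable_ixs[j], so excluded syllables leave gaps in the index list rather than in the matrix.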

keypoint_moseq.util.downsample_timepoints(data, downsample_rate)

Downsample timepoints, e.g. of coordinates or confidences.

Parameters:
  • data (ndarray or dict) – Array of shape (n_frames, …) or a dictionary with such arrays as values.

  • downsample_rate (int) – The downsampling rate (e.g., downsample_rate=2 keeps every other frame).

Returns:

downsampled_data – Downsampled array or dictionary of arrays.

Return type:

ndarray or dict
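
Keeping every downsample_rate-th frame corresponds to strided slicing along the first axis. The sketch below uses plain lists and a hypothetical helper name, and assumes the first frame is always kept.

```python
def downsample_sketch(data, downsample_rate):
    """Hypothetical helper: keep every downsample_rate-th frame."""
    if isinstance(data, dict):
        return {name: frames[::downsample_rate] for name, frames in data.items()}
    return data[::downsample_rate]


frames = list(range(7))
downsampled = downsample_sketch(frames, 2)
downsampled_dict = downsample_sketch({"rec1": frames, "rec2": frames[:4]}, 3)
print(downsampled)
print(downsampled_dict)
```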

keypoint_moseq.util.check_video_paths(video_paths, keys)

Check if video paths are valid and match the keys.

Parameters:
  • video_paths (dict) – Dictionary mapping keys to video paths.

  • keys (list) – List of keys that require a video path.

Raises:

ValueError – If any of the following are true:

  • a video path is not provided for a key in keys

  • a video isn’t readable

  • a video path does not exist
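
Two of these checks can be sketched with the standard library alone. The helper check_video_paths_sketch and the example path are hypothetical; the readability check is omitted because it would require a video-decoding library.

```python
import os


def check_video_paths_sketch(video_paths, keys):
    """Hypothetical helper: raise ValueError if a key lacks a video path
    or a listed path does not exist (readability is not checked here)."""
    problems = []
    missing_keys = [key for key in keys if key not in video_paths]
    if missing_keys:
        problems.append(f"No video path provided for: {missing_keys}")
    nonexistent = [p for p in video_paths.values() if not os.path.exists(p)]
    if nonexistent:
        problems.append(f"Video paths do not exist: {nonexistent}")
    if problems:
        raise ValueError("\n".join(problems))


message = None
try:
    check_video_paths_sketch(
        {"rec1": "hypothetical_dir/missing_video.mp4"}, ["rec1", "rec2"]
    )
except ValueError as err:
    message = str(err)
    print(message)
```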