Model fitting

Exceptions:

StopResampling

Functions:

init_model(*args[, location_aware, ...])

Initialize a model.

fit_model(model, data, metadata[, ...])

Fit a model to data.

apply_model(model, data, metadata[, ...])

Apply a model to new data.

estimate_syllable_marginals(model, data, ...)

Estimate marginal distributions over syllables.

update_hypparams(model_dict, **kwargs)

Edit the hyperparameters of a model.

expected_marginal_likelihoods([project_dir, ...])

Calculate the expected marginal likelihood score for each model.

exception keypoint_moseq.fitting.StopResampling

keypoint_moseq.fitting.init_model(*args, location_aware=False, allo_hypparams=None, trans_hypparams=None, **kwargs)

Initialize a model. Wrapper for jax_moseq.models.keypoint_slds.init_model and jax_moseq.models.allo_keypoint_slds.init_model.

Parameters:
  • location_aware (bool, default=False) – If True, the model will be initialized using the location-aware version of the keypoint-SLDS model (jax_moseq.models.allo_keypoint_slds).

  • allo_hypparams (dict, default=None) – Hyperparameters for the allo_keypoint_slds model. If None, default hyperparameters will be used.

Returns:

model – Model dictionary containing states, parameters, hyperparameters, noise prior, and random seed.

Return type:

dict

keypoint_moseq.fitting.fit_model(model, data, metadata, project_dir=None, model_name=None, num_iters=50, start_iter=0, verbose=False, ar_only=False, parallel_message_passing=None, jitter=0.001, generate_progress_plots=True, save_every_n_iters=25, location_aware=False, **kwargs)

Fit a model to data.

This method optionally:
  • saves checkpoints of the model and data at regular intervals

  • generates plots of the model’s progress during fitting (see jax_moseq.viz.plot_progress())

Note that if a checkpoint file already exists, all model snapshots after start_iter will be deleted.

Parameters:
  • model (dict) – Model dictionary containing states, parameters, hyperparameters, noise prior, and random seed.

  • data (dict) – Data for model fitting (see keypoint_moseq.io.format_data()).

  • metadata (tuple (keys, bounds)) – Recordings and start/end frames for the data (see keypoint_moseq.io.format_data()).

  • project_dir (str, default=None) – Project directory; required if save_every_n_iters>0.

  • model_name (str, default=None) – Name of the model. If None, the model is named using the current date and time.

  • num_iters (int, default=50) – Number of Gibbs sampling iterations to run.

  • start_iter (int, default=0) – Index of the starting iteration, which is non-zero when continuing a previous fit.

  • verbose (bool, default=False) – If True, print the model’s progress during fitting.

  • ar_only (bool, default=False) – If True, fit an AR-HMM model using the latent trajectory defined by model['states']['x'] (see jax_moseq.models.arhmm.resample_model()). Otherwise fit a full keypoint-SLDS model (see jax_moseq.models.keypoint_slds.resample_model()).

  • save_every_n_iters (int, default=25) – Save the current model every save_every_n_iters iterations. To save only the final model, set save_every_n_iters=-1. To save nothing, set save_every_n_iters=None.

  • generate_progress_plots (bool, default=True) – If True, generate plots of the model’s progress during fitting. Plots are saved to {project_dir}/{model_name}/plots/.

  • parallel_message_passing (bool | string, default=None) – Use the parallel implementation of Kalman sampling, which can be faster but has a significantly longer JIT compilation time. If None, the value is set automatically based on the backend (True for GPU, False for CPU). A warning is raised if parallel_message_passing=True and JAX is CPU-bound. Set to 'force' to skip this check.

  • jitter (float, default=0.001) – Amount to boost the diagonal of the dynamics covariance matrix when resampling pose trajectories. Increasing this value can help prevent NaNs during fitting.

  • location_aware (bool, default=False) – If True, the model will be fit using the location-aware version of the keypoint-SLDS model (jax_moseq.models.allo_keypoint_slds).

Returns:

  • model (dict) – Model dictionary containing states, parameters, hyperparameters, noise prior, and random seed.

  • model_name (str) – Name of the model.
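
The three modes of save_every_n_iters described above can be sketched as a small predicate. This is a minimal illustration, not the library's implementation: the helper name should_save is hypothetical, and it assumes that periodic saves fall on multiples of save_every_n_iters and that the final iteration is always saved unless saving is disabled.

```python
def should_save(iteration, last_iter, save_every_n_iters):
    """Hypothetical helper: decide whether to write a checkpoint at `iteration`."""
    # None disables saving entirely.
    if save_every_n_iters is None:
        return False
    # Assumption: the final model is saved in every mode except None.
    if iteration == last_iter:
        return True
    # A non-positive value such as -1 means "final model only".
    if save_every_n_iters <= 0:
        return False
    # Assumption: periodic saves fall on multiples of the interval.
    return iteration % save_every_n_iters == 0


# With the defaults (num_iters=50, save_every_n_iters=25):
saved = [i for i in range(1, 51) if should_save(i, 50, 25)]
print(saved)  # [25, 50]
```

Under these assumptions, save_every_n_iters=-1 would write a single checkpoint at iteration 50, and None would write none.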

keypoint_moseq.fitting.apply_model(model, data, metadata, project_dir=None, model_name=None, num_iters=500, ar_only=False, save_results=True, verbose=False, results_path=None, parallel_message_passing=None, return_model=False, location_aware=False, **kwargs)

Apply a model to new data.

Parameters:
  • model (dict) – Model dictionary containing states, parameters, hyperparameters, noise prior, and random seed.

  • data (dict) – Data for model fitting (see keypoint_moseq.io.format_data()).

  • metadata (tuple (keys, bounds)) – Recordings and start/end frames for the data (see keypoint_moseq.io.format_data()).

  • project_dir (str, default=None) – Path to the project directory. Required if save_results=True and results_path=None.

  • model_name (str, default=None) – Name of the model. Required if save_results=True and results_path=None.

  • num_iters (int, default=500) – Number of iterations to run the model.

  • ar_only (bool, default=False) – See keypoint_moseq.fitting.fit_model().

  • save_results (bool, default=True) – If True, the model outputs will be saved to disk (see keypoint_moseq.io.extract_results() for the output format).

  • verbose (bool, default=False) – Whether to print progress updates.

  • results_path (str, default=None) – Optional path for saving model outputs.

  • parallel_message_passing (bool | string, default=None) – Use the parallel implementation of Kalman sampling, which can be faster but has a significantly longer JIT compilation time. If None, the value is set automatically based on the backend (True for GPU, False for CPU). A warning is raised if parallel_message_passing=True and JAX is CPU-bound. Set to 'force' to skip this check.

  • return_model (bool, default=False) – Whether to return the model after fitting.

  • location_aware (bool, default=False) – If True, the model will be fit using the location-aware version of the keypoint-SLDS model (jax_moseq.models.allo_keypoint_slds).

Returns:

  • results (dict) – Dictionary of model outputs (for results format, see keypoint_moseq.io.extract_results()).

  • model (dict) – Model dictionary containing states, parameters, hyperparameters, noise prior, and random seed. Only returned if return_model=True.
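
The rule documented for parallel_message_passing (automatic choice when None, a warning when True on CPU, no check when set to 'force') can be sketched as follows. The helper name resolve_parallel_message_passing is hypothetical, the backend is passed in as a plain string (in practice it could come from jax.default_backend()), and both the warning text and the mapping of 'force' to True are assumptions rather than the library's actual behavior.

```python
import warnings


def resolve_parallel_message_passing(value, backend):
    """Hypothetical helper mirroring the documented rule.

    `backend` is a platform name such as "cpu" or "gpu".
    """
    if value is None:
        # Automatic choice: True for GPU, False for CPU.
        return backend == "gpu"
    if value == "force":
        # Assumption: "force" enables the parallel sampler and skips the CPU check.
        return True
    if value and backend == "cpu":
        warnings.warn(
            "parallel_message_passing=True on a CPU backend may be slow; "
            "pass 'force' to skip this check."
        )
    return bool(value)


print(resolve_parallel_message_passing(None, "cpu"))  # False
```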

keypoint_moseq.fitting.estimate_syllable_marginals(model, data, metadata, burn_in_iters=200, num_samples=100, steps_per_sample=10, return_samples=False, verbose=False, parallel_message_passing=None, location_aware=False, **kwargs)

Estimate marginal distributions over syllables.

Parameters:
  • model (dict) – Model dictionary containing states, parameters, hyperparameters, noise prior, and random seed.

  • data (dict) – Data for model fitting (see keypoint_moseq.io.format_data()).

  • metadata (tuple (keys, bounds)) – Recordings and start/end frames for the data (see keypoint_moseq.io.format_data()).

  • burn_in_iters (int, default=200) – Number of resampling iterations to run before collecting samples.

  • num_samples (int, default=100) – Number of samples to collect for marginalization.

  • steps_per_sample (int, default=10) – Number of resampling iterations to run between collecting samples.

  • return_samples (bool, default=False) – Whether to store and return sampled syllable sequences. May require significant RAM.

  • verbose (bool, default=False) – Whether to print progress updates.

  • parallel_message_passing (bool | string, default=None) – Use the parallel implementation of Kalman sampling, which can be faster but has a significantly longer JIT compilation time. If None, the value is set automatically based on the backend (True for GPU, False for CPU). A warning is raised if parallel_message_passing=True and JAX is CPU-bound. Set to 'force' to skip this check.

  • location_aware (bool, default=False) – If True, the model will be fit using the location-aware version of the keypoint-SLDS model (jax_moseq.models.allo_keypoint_slds).

Returns:

  • marginal_estimates (dict) – Estimated marginal distributions over syllables in the form of a dictionary mapping recording names to arrays of shape (num_timepoints, num_syllables).

  • samples (dict) – Sampled syllable sequences in the form of a dictionary mapping recording names to arrays of shape (num_samples, num_timepoints). Only returned if return_samples=True.
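
To make the relationship between the two return values concrete, the sketch below turns sampled syllable sequences of shape (num_samples, num_timepoints) into marginals of shape (num_timepoints, num_syllables) for a single recording. It assumes the marginals are empirical frequencies across samples; the function name marginals_from_samples is hypothetical, and the library may accumulate its estimates differently.

```python
def marginals_from_samples(samples, num_syllables):
    """Estimate per-timepoint syllable marginals from sampled sequences.

    `samples` has shape (num_samples, num_timepoints); the result has
    shape (num_timepoints, num_syllables).
    """
    num_samples = len(samples)
    num_timepoints = len(samples[0])
    counts = [[0] * num_syllables for _ in range(num_timepoints)]
    for sequence in samples:
        for t, syllable in enumerate(sequence):
            # Count how often each syllable was sampled at timepoint t.
            counts[t][syllable] += 1
    # Normalize the counts so that each row sums to 1.
    return [[c / num_samples for c in row] for row in counts]


# Toy data: 4 samples, 2 timepoints, 3 syllables.
samples = [[0, 1], [0, 2], [1, 2], [0, 2]]
marginals = marginals_from_samples(samples, num_syllables=3)
print(marginals)  # [[0.75, 0.25, 0.0], [0.0, 0.25, 0.75]]
```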

keypoint_moseq.fitting.update_hypparams(model_dict, **kwargs)

Edit the hyperparameters of a model.

Hyperparameters are stored as a nested dictionary in the hypparams key of the model dictionary. This function allows the user to update the hyperparameters of a model by passing in keyword arguments with the same name as the hyperparameter. The hyperparameter will be updated if it is a scalar value.

Parameters:
  • model_dict (dict) – Model dictionary.

  • kwargs (dict) – Keyword arguments mapping hyperparameter names to new values.

Returns:

model_dict – Model dictionary with updated hyperparameters.

Return type:

dict
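
The update rule described above (look up each keyword in the nested hypparams dictionary and replace it only if the stored value is a scalar) can be sketched in plain Python. This is an illustrative re-implementation under stated assumptions, not the library's code: the function name update_scalar_hypparams is hypothetical, and the group names, hyperparameter names, and values in the toy dictionary are placeholders.

```python
from numbers import Number


def update_scalar_hypparams(model_dict, **kwargs):
    """Hypothetical sketch: update scalar entries of model_dict["hypparams"]."""
    for name, new_value in kwargs.items():
        for group in model_dict["hypparams"].values():
            # Only touch hyperparameters that exist and currently hold a scalar.
            if name in group and isinstance(group[name], Number):
                group[name] = new_value
    return model_dict


# Toy model dictionary with placeholder groups and values.
model = {
    "hypparams": {
        "trans_hypparams": {"kappa": 1e6, "num_states": 100},
        "ar_hypparams": {"latent_dim": 4, "S_0": [[1.0, 0.0], [0.0, 1.0]]},
    }
}
model = update_scalar_hypparams(model, kappa=1e4, S_0=0.5)
print(model["hypparams"]["trans_hypparams"]["kappa"])  # 10000.0
```

In this sketch the request to change S_0 is silently ignored because the stored value is a matrix rather than a scalar.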

keypoint_moseq.fitting.expected_marginal_likelihoods(project_dir=None, model_names=None, checkpoint_paths=None)

Calculate the expected marginal likelihood score for each model.

The score is calculated as follows, where $\theta^{(i)}$ denotes the autoregressive parameters and transition matrix for the i’th model, $x^{(i)}$ denotes the latent trajectories for the i’th model, and the number of models is $N$:

\[\text{score}(\theta^{(i)}) = \frac{1}{(N-1)} \sum_{j \neq i} P(x^{(j)} | \theta^{(i)})\]

Parameters:
  • project_dir (str, default=None) – Path to the project directory. Required if checkpoint_paths is None.

  • model_names (list of str, default=None) – Names of the models to compare. Required if checkpoint_paths is None.

  • checkpoint_paths (list of str, default=None) – Paths to the checkpoints to compare. Required if model_names and project_dir are None.

Returns:

  • scores (numpy array) – Expected marginal likelihood score for each model.

  • standard_errors (numpy array) – Standard error of the expected marginal likelihood score for each model.
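
As a worked illustration of the score, the sketch below averages each row of a square matrix over its off-diagonal entries, where entry (i, j) holds $P(x^{(j)} | \theta^{(i)})$. The function name expected_scores and the toy matrix are hypothetical, and the standard error shown (sample standard deviation over the other N-1 models, divided by the square root of N-1) is an assumption about how standard_errors might be defined, not a statement of the library's method.

```python
from math import sqrt
from statistics import mean, stdev


def expected_scores(lik):
    """Score each model from a square matrix with lik[i][j] = P(x^(j) | theta^(i)).

    Requires at least two models.
    """
    n = len(lik)
    scores, standard_errors = [], []
    for i in range(n):
        # Exclude the diagonal: a model is not scored on its own latent trajectories.
        others = [lik[i][j] for j in range(n) if j != i]
        scores.append(mean(others))
        # Assumption: standard error of the mean over the N-1 other models.
        se = stdev(others) / sqrt(len(others)) if len(others) > 1 else float("nan")
        standard_errors.append(se)
    return scores, standard_errors


# Toy values only; diagonal entries are ignored.
lik = [[0.0, 2.0, 4.0], [1.0, 0.0, 3.0], [6.0, 2.0, 0.0]]
scores, standard_errors = expected_scores(lik)
print(scores)  # [3.0, 2.0, 4.0]
```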