Model fitting
Functions:
| init_model | Initialize a model. |
| fit_model | Fit a model to data. |
| apply_model | Apply a model to new data. |
| estimate_syllable_marginals | Estimate marginal distributions over syllables. |
| update_hypparams | Edit the hyperparameters of a model. |
| expected_marginal_likelihoods | Calculate the expected marginal likelihood score for each model. |
- keypoint_moseq.fitting.init_model(*args, location_aware=False, allo_hypparams=None, trans_hypparams=None, **kwargs)[source]
Initialize a model. Wrapper for jax_moseq.models.keypoint_slds.init_model and jax_moseq.models.allo_keypoint_slds.init_model.
- Parameters:
location_aware (bool, default=False) – If True, the model will be initialized using the location-aware version of the keypoint-SLDS model (jax_moseq.models.allo_keypoint_slds).
allo_hypparams (dict, default=None) – Hyperparameters for the allo_keypoint_slds model. If None, default hyperparameters will be used.
- Returns:
model – Model dictionary containing states, parameters, hyperparameters, noise prior, and random seed.
- Return type:
dict
- keypoint_moseq.fitting.fit_model(model, data, metadata, project_dir=None, model_name=None, num_iters=50, start_iter=0, verbose=False, ar_only=False, parallel_message_passing=None, jitter=0.001, generate_progress_plots=True, save_every_n_iters=25, location_aware=False, **kwargs)[source]
Fit a model to data.
- This method optionally:
saves checkpoints of the model and data at regular intervals
generates plots of the model’s progress during fitting (see jax_moseq.viz.plot_progress())
Note that if a checkpoint file already exists, all model snapshots after start_iter will be deleted.
- Parameters:
model (dict) – Model dictionary containing states, parameters, hyperparameters, noise prior, and random seed.
data (dict) – Data for model fitting (see keypoint_moseq.io.format_data()).
metadata (tuple (keys, bounds)) – Recordings and start/end frames for the data (see keypoint_moseq.io.format_data()).
project_dir (str, default=None) – Project directory; required if save_every_n_iters>0.
model_name (str, default=None) – Name of the model. If None, the model is named using the current date and time.
num_iters (int, default=50) – Number of Gibbs sampling iterations to run.
start_iter (int, default=0) – Index of the starting iteration, which is non-zero when continuing a previous fit.
verbose (bool, default=False) – If True, print the model’s progress during fitting.
ar_only (bool, default=False) – If True, fit an AR-HMM model using the latent trajectory defined by model['states']['x'] (see jax_moseq.models.arhmm.resample_model()). Otherwise fit a full keypoint-SLDS model (see jax_moseq.models.keypoint_slds.resample_model()).
save_every_n_iters (int, default=25) – Save the current model every save_every_n_iters iterations. To save only the final model, set save_every_n_iters=-1. To save nothing, set save_every_n_iters=None.
generate_progress_plots (bool, default=True) – If True, generate plots of the model’s progress during fitting. Plots are saved to {project_dir}/{model_name}/plots/.
parallel_message_passing (bool | str, default=None) – Use the parallel implementation of Kalman sampling, which can be faster but has a significantly longer JIT compilation time. If None, the value is set automatically based on the backend (True for GPU, False for CPU). A warning is raised if parallel_message_passing=True and JAX is CPU-bound. Set to 'force' to skip this check.
jitter (float, default=0.001) – Amount to boost the diagonal of the dynamics covariance matrix when resampling pose trajectories. Increasing this value can help prevent NaNs during fitting.
location_aware (bool, default=False) – If True, the model will be fit using the location-aware version of the keypoint-SLDS model (jax_moseq.models.allo_keypoint_slds).
- Returns:
model (dict) – Model dictionary containing states, parameters, hyperparameters, noise prior, and random seed.
model_name (str) – Name of the model.
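The checkpointing rules described for save_every_n_iters and start_iter can be sketched in plain Python. This is only an illustration of the documented behaviour, not the library's code: the helper names checkpoint_iterations and truncate_snapshots are hypothetical, and the sketch assumes iterations are numbered up to num_iters (which coincides with running num_iters iterations when start_iter=0).

```python
def checkpoint_iterations(start_iter, num_iters, save_every_n_iters):
    """Hypothetical helper: list the iterations at which a snapshot is saved.

    Assumption: iterations are numbered start_iter+1 .. num_iters, and the
    final iteration is always saved unless saving is disabled with None.
    """
    if save_every_n_iters is None:
        return []  # documented: None means nothing is saved
    saved = []
    for iteration in range(start_iter + 1, num_iters + 1):
        is_final = iteration == num_iters
        is_periodic = save_every_n_iters > 0 and iteration % save_every_n_iters == 0
        if is_final or is_periodic:
            saved.append(iteration)
    return saved


def truncate_snapshots(snapshots, start_iter):
    """Hypothetical helper: drop snapshots recorded after start_iter.

    Mirrors the documented note that resuming from an existing checkpoint
    deletes all model snapshots after start_iter.
    """
    return {it: snap for it, snap in snapshots.items() if it <= start_iter}


print(checkpoint_iterations(0, 50, 25))
print(checkpoint_iterations(0, 50, -1))
print(truncate_snapshots({0: "s0", 25: "s25", 50: "s50"}, start_iter=25))
```

Under these assumptions, the default settings would save snapshots at iterations 25 and 50, while save_every_n_iters=-1 would save only iteration 50.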
- keypoint_moseq.fitting.apply_model(model, data, metadata, project_dir=None, model_name=None, num_iters=500, ar_only=False, save_results=True, verbose=False, results_path=None, parallel_message_passing=None, return_model=False, location_aware=False, **kwargs)[source]
Apply a model to new data.
- Parameters:
model (dict) – Model dictionary containing states, parameters, hyperparameters, noise prior, and random seed.
data (dict) – Data for model fitting (see keypoint_moseq.io.format_data()).
metadata (tuple (keys, bounds)) – Recordings and start/end frames for the data (see keypoint_moseq.io.format_data()).
project_dir (str, default=None) – Path to the project directory. Required if save_results=True and results_path=None.
model_name (str, default=None) – Name of the model. Required if save_results=True and results_path=None.
num_iters (int, default=500) – Number of iterations to run the model.
ar_only (bool, default=False) – See keypoint_moseq.fitting.fit_model().
save_results (bool, default=True) – If True, the model outputs will be saved to disk (see keypoint_moseq.io.extract_results() for the output format).
verbose (bool, default=False) – Whether to print progress updates.
results_path (str, default=None) – Optional path for saving model outputs.
parallel_message_passing (bool | str, default=None) – Use the parallel implementation of Kalman sampling, which can be faster but has a significantly longer JIT compilation time. If None, the value is set automatically based on the backend (True for GPU, False for CPU). A warning is raised if parallel_message_passing=True and JAX is CPU-bound. Set to 'force' to skip this check.
return_model (bool, default=False) – Whether to return the model after fitting.
location_aware (bool, default=False) – If True, the model will be fit using the location-aware version of the keypoint-SLDS model (jax_moseq.models.allo_keypoint_slds).
- Returns:
results (dict) – Dictionary of model outputs (for results format, see keypoint_moseq.io.extract_results()).
model (dict) – Model dictionary containing states, parameters, hyperparameters, noise prior, and random seed. Only returned if return_model=True.
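The rules for parallel_message_passing can be summarised as a small decision function. The sketch below is a hypothetical helper written for illustration. It assumes the backend is passed in as the string "gpu" or "cpu" and that 'force' means "use the parallel implementation without the CPU check". The library's actual resolution logic may differ.

```python
import warnings


def resolve_parallel_message_passing(parallel_message_passing, backend):
    """Hypothetical helper mirroring the documented rules.

    backend is assumed to be "gpu" or "cpu", supplied by the caller.
    Returns True if the parallel Kalman sampler should be used.
    """
    if parallel_message_passing is None:
        # Documented default: True for GPU, False for CPU.
        return backend == "gpu"
    if parallel_message_passing == "force":
        # Assumption: 'force' enables the parallel sampler and skips the check.
        return True
    if parallel_message_passing and backend == "cpu":
        warnings.warn(
            "parallel_message_passing=True on a CPU backend may be slow; "
            "pass 'force' to silence this warning."
        )
    return bool(parallel_message_passing)


print(resolve_parallel_message_passing(None, "gpu"))
print(resolve_parallel_message_passing(None, "cpu"))
print(resolve_parallel_message_passing("force", "cpu"))
```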
- keypoint_moseq.fitting.estimate_syllable_marginals(model, data, metadata, burn_in_iters=200, num_samples=100, steps_per_sample=10, return_samples=False, verbose=False, parallel_message_passing=None, location_aware=False, **kwargs)[source]
Estimate marginal distributions over syllables.
- Parameters:
model (dict) – Model dictionary containing states, parameters, hyperparameters, noise prior, and random seed.
data (dict) – Data for model fitting (see keypoint_moseq.io.format_data()).
metadata (tuple (keys, bounds)) – Recordings and start/end frames for the data (see keypoint_moseq.io.format_data()).
burn_in_iters (int, default=200) – Number of resampling iterations to run before collecting samples.
num_samples (int, default=100) – Number of samples to collect for marginalization.
steps_per_sample (int, default=10) – Number of resampling iterations to run between collecting samples.
return_samples (bool, default=False) – Whether to store and return sampled syllable sequences. May require significant RAM.
verbose (bool, default=False) – Whether to print progress updates.
parallel_message_passing (bool | str, default=None) – Use the parallel implementation of Kalman sampling, which can be faster but has a significantly longer JIT compilation time. If None, the value is set automatically based on the backend (True for GPU, False for CPU). A warning is raised if parallel_message_passing=True and JAX is CPU-bound. Set to 'force' to skip this check.
location_aware (bool, default=False) – If True, the model will be fit using the location-aware version of the keypoint-SLDS model (jax_moseq.models.allo_keypoint_slds).
- Returns:
marginal_estimates (dict) – Estimated marginal distributions over syllables in the form of a dictionary mapping recording names to arrays of shape (num_timepoints, num_syllables).
samples (dict) – Sampled syllable sequences in the form of a dictionary mapping recording names to arrays of shape (num_samples, num_timepoints). Only returned if return_samples=True.
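To make the relationship between the two return shapes concrete, the sketch below shows one simple way to turn sampled syllable sequences of shape (num_samples, num_timepoints) into per-timepoint marginals of shape (num_timepoints, num_syllables), by counting how often each syllable appears at each timepoint. This is an illustration only: the function names are hypothetical, the library may estimate marginals differently, and the iteration count assumes steps_per_sample resampling iterations are run for each collected sample.

```python
def total_resampling_iters(burn_in_iters=200, num_samples=100, steps_per_sample=10):
    """Rough iteration budget, assuming steps_per_sample iterations per sample."""
    return burn_in_iters + num_samples * steps_per_sample


def empirical_marginals(samples, num_syllables):
    """Hypothetical helper: per-timepoint syllable frequencies across samples.

    samples: list of num_samples sequences, each of length num_timepoints,
             holding integer syllable labels in range(num_syllables).
    Returns a nested list of shape (num_timepoints, num_syllables).
    """
    num_samples = len(samples)
    num_timepoints = len(samples[0])
    counts = [[0] * num_syllables for _ in range(num_timepoints)]
    for sequence in samples:
        for t, syllable in enumerate(sequence):
            counts[t][syllable] += 1
    # Normalise counts so each timepoint's row sums to 1.
    return [[c / num_samples for c in row] for row in counts]


# Four made-up sampled sequences over three timepoints and three syllables.
toy_samples = [[0, 1, 1], [0, 1, 2], [0, 0, 2], [1, 1, 2]]
print(total_resampling_iters())
print(empirical_marginals(toy_samples, num_syllables=3))
```

Because every sampled sequence must be kept in memory for this kind of calculation, the example also hints at why return_samples=True can require significant RAM for long recordings.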
- keypoint_moseq.fitting.update_hypparams(model_dict, **kwargs)[source]
Edit the hyperparameters of a model.
Hyperparameters are stored as a nested dictionary in the hypparams key of the model dictionary. This function allows the user to update the hyperparameters of a model by passing in keyword arguments with the same name as the hyperparameter. The hyperparameter will be updated if it is a scalar value.
- Parameters:
model_dict (dict) – Model dictionary.
kwargs (dict) – Keyword arguments mapping hyperparameter names to new values.
- Returns:
model_dict – Model dictionary with updated hyperparameters.
- Return type:
dict
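The update rule described above (look up each keyword in the nested hypparams dictionary and replace the value only if it is a scalar) can be sketched without the library as follows. The function name update_hypparams_sketch, the choice to return a copy, and the example group and key names with their values are all assumptions made for illustration. The real function may, for example, modify the dictionary in place.

```python
import copy
from numbers import Number


def update_hypparams_sketch(model_dict, **kwargs):
    """Illustrative stand-in for the documented behaviour (not the library code)."""
    updated = copy.deepcopy(model_dict)  # assumption: leave the input untouched
    for name, new_value in kwargs.items():
        found = False
        # hypparams is a nested dict: group name -> {hyperparameter: value}
        for group in updated["hypparams"].values():
            if name in group:
                found = True
                if isinstance(group[name], Number):
                    group[name] = new_value
                else:
                    print(f"'{name}' is not a scalar and was left unchanged")
        if not found:
            print(f"'{name}' was not found among the hyperparameters")
    return updated


# Made-up model dictionary; group names, keys and values are illustrative only.
model = {
    "hypparams": {
        "trans_hypparams": {"kappa": 1e6, "num_states": 100},
        "ar_hypparams": {"nlags": 3, "S_0": [[1.0, 0.0], [0.0, 1.0]]},
    }
}
new_model = update_hypparams_sketch(model, kappa=1e4, S_0=5.0)
print(new_model["hypparams"]["trans_hypparams"]["kappa"])
```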
- keypoint_moseq.fitting.expected_marginal_likelihoods(project_dir=None, model_names=None, checkpoint_paths=None)[source]
Calculate the expected marginal likelihood score for each model.
The score is calculated as follows, where $\theta^{(i)}$ denotes the autoregressive parameters and transition matrix for the i’th model, $x^{(i)}$ denotes the latent trajectories for the i’th model, and the number of models is $N$:
\[\text{score}(\theta^{(i)}) = \frac{1}{N-1} \sum_{j \neq i} P(x^{(j)} | \theta^{(i)})\]
- Parameters:
project_dir (str, default=None) – Path to the project directory. Required if checkpoint_paths is None.
model_names (list of str, default=None) – Names of the models to compare. Required if checkpoint_paths is None.
checkpoint_paths (list of str, default=None) – Paths to the checkpoints to compare. Required if model_names and project_dir are None.
- Returns:
scores (numpy array) – Expected marginal likelihood score for each model.
standard_errors (numpy array) – Standard error of the expected marginal likelihood score for each model.
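As a worked illustration of the score, suppose the cross-model values P(x^(j) | θ^(i)) have already been collected in a square table with one row per model i and one column per model j. The score for model i is then the mean of row i with the diagonal entry left out. The sketch below uses made-up numbers and a hypothetical function name. The standard-error formula (standard deviation of the off-diagonal entries divided by sqrt(N-1)) is an assumption, and the library may work with log-likelihoods rather than raw probabilities.

```python
import math


def expected_marginal_likelihood_scores(likelihoods):
    """Illustrative calculation of the score defined above.

    likelihoods[i][j] is assumed to hold P(x^(j) | theta^(i)).
    Returns (scores, standard_errors), one entry per model.
    """
    n = len(likelihoods)
    scores, standard_errors = [], []
    for i in range(n):
        # Exclude j == i, as in the sum over j != i.
        off_diagonal = [likelihoods[i][j] for j in range(n) if j != i]
        mean = sum(off_diagonal) / (n - 1)
        # Assumption: population standard deviation over the N-1 entries.
        variance = sum((v - mean) ** 2 for v in off_diagonal) / (n - 1)
        scores.append(mean)
        standard_errors.append(math.sqrt(variance) / math.sqrt(n - 1))
    return scores, standard_errors


# Made-up 3x3 table; diagonal entries are ignored by the score.
table = [
    [9.0, 1.0, 3.0],
    [2.0, 9.0, 4.0],
    [6.0, 2.0, 9.0],
]
scores, standard_errors = expected_marginal_likelihood_scores(table)
print(scores)
print(standard_errors)
```

With these numbers the three scores are the means of (1, 3), (2, 4) and (6, 2), so the third model would rank highest.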