import numpy as np
import tqdm
import os
from textwrap import fill
from vidio.read import OpenCVReader
from keypoint_moseq.io import update_config
from keypoint_moseq.util import find_matching_videos, get_edges
def sample_error_frames(
    confidences,
    bodyparts,
    use_bodyparts,
    num_bins=10,
    num_samples=100,
    conf_pseudocount=1e-3,
):
    """Randomly sample frames, enriching for those with low confidence keypoint
    detections.

    Parameters
    ----------
    confidences: dict
        Keypoint detection confidences for a collection of recordings
    bodyparts: list
        Label for each keypoint represented in `confidences`
    use_bodyparts: list
        Ordered subset of keypoint labels to be used for modeling
    num_bins: int, default=10
        Number of bins to use for enriching low-confidence keypoint
        detections. Confidence values for all used keypoints are
        divided into log-spaced bins and an equal number of instances
        are sampled from each bin.
    num_samples: int, default=100
        Total number of frames to sample
    conf_pseudocount: float, default=1e-3
        Pseudocount used to augment keypoint confidences.

    Returns
    -------
    sample_keys: list of tuples
        List of sampled frames as tuples with format
        (key, frame_number, bodypart)
    """
    # Augment confidences so log-spaced binning is well defined even when
    # some detections have confidence exactly 0.
    confidences = {k: v + conf_pseudocount for k, v in confidences.items()}
    all_confs = np.concatenate([v.flatten() for v in confidences.values()])
    min_conf, max_conf = np.nanmin(all_confs), np.nanmax(all_confs)

    # Bug fix: `num_bins` bins require `num_bins + 1` edges. The previous
    # code generated `num_bins` edges, which yields only `num_bins - 1`
    # bins after the pairwise zip below and under-samples by roughly
    # num_samples / num_bins frames.
    thresholds = np.logspace(
        np.log10(min_conf), np.log10(max_conf), num_bins + 1
    )

    # Boolean row-mask selecting the keypoint columns used for modeling.
    mask = np.array([bp in use_bodyparts for bp in bodyparts])[None, :]

    sample_keys = []
    for low, high in zip(thresholds[:-1], thresholds[1:]):
        # Collect every (recording, frame, bodypart) whose confidence
        # falls inside the current bin.
        samples_in_bin = []
        for key, confs in confidences.items():
            in_bin = (confs >= low) * (confs < high) * mask
            for t, k in zip(*np.nonzero(in_bin)):
                samples_in_bin.append((key, t, bodyparts[k]))

        # Draw (without replacement) an equal share of samples per bin.
        if len(samples_in_bin) > 0:
            n = min(num_samples // num_bins, len(samples_in_bin))
            for i in np.random.choice(len(samples_in_bin), n, replace=False):
                sample_keys.append(samples_in_bin[i])

    # Shuffle so frames are presented in random (not bin) order.
    sample_keys = [sample_keys[i] for i in np.random.permutation(len(sample_keys))]
    return sample_keys
def load_sampled_frames(
    sample_keys, video_dir, video_extension=None, downsample_rate=1
):
    """Load sampled frames from a directory of videos.

    Parameters
    ----------
    sample_keys: list of tuples
        List of sampled frames as tuples with format
        (key, frame_number, bodypart)
    video_dir: str
        Path to directory containing videos
    video_extension: str, default=None
        Preferred video extension (passed to
        :py:func:`keypoint_moseq.util.find_matching_videos`)
    downsample_rate: int, default=1
        Downsampling rate for the video frames. Only change if keypoint
        detections were also downsampled.

    Returns
    -------
    sample_images: dict
        Dictionary mapping elements from `sample_keys` to the
        corresponding video frames.
    """
    # One reader per recording; a sorted set avoids opening duplicates.
    keys = sorted(set(k[0] for k in sample_keys))

    # Bug fix: `video_extension` was documented but never forwarded to
    # `find_matching_videos`; also dropped the unused `key_to_video` dict.
    videos = find_matching_videos(keys, video_dir, video_extension=video_extension)
    readers = {key: OpenCVReader(video) for key, video in zip(keys, videos)}

    pbar = tqdm.tqdm(
        sample_keys,
        desc="Loading sample frames",
        position=0,
        leave=True,
        ncols=72,
    )
    # Multiply by `downsample_rate` to map detection frame indices back to
    # frame indices in the original (non-downsampled) video.
    return {
        (key, frame, bodypart): readers[key][frame * downsample_rate]
        for key, frame, bodypart in pbar
    }
def load_annotations(project_dir):
    """Reload saved calibration annotations.

    Parameters
    ----------
    project_dir: str
        Load annotations from `{project_dir}/error_annotations.csv`

    Returns
    -------
    annotations: dict
        Dictionary mapping sample keys to annotated keypoint
        coordinates. (See :py:func:`keypoint_moseq.calibration.sample_error_frames`
        for format of sample keys)
    """
    annotations = {}
    annotations_path = os.path.join(project_dir, "error_annotations.csv")
    if os.path.exists(annotations_path):
        # Use a context manager so the file handle is closed, and skip
        # blank lines: a trailing newline previously produced an empty
        # final element that crashed the 5-way unpack below.
        with open(annotations_path, "r") as f:
            for line in f.read().split("\n")[1:]:  # [1:] skips the header
                if not line.strip():
                    continue
                key, frame, bodypart, x, y = line.split(",")
                sample_key = (key, int(frame), bodypart)
                annotations[sample_key] = (float(x), float(y))
    return annotations
def save_annotations(project_dir, annotations):
    """Save calibration annotations to a csv file.

    Parameters
    ----------
    project_dir: str
        Save annotations to `{project_dir}/error_annotations.csv`
    annotations: dict
        Dictionary mapping sample keys to annotated keypoint
        coordinates. (See :py:func:`keypoint_moseq.calibration.sample_error_frames`
        for format of sample keys)
    """
    output = ["key,frame,bodypart,x,y"]
    for (key, frame, bodypart), (x, y) in annotations.items():
        output.append(f"{key},{frame},{bodypart},{x},{y}")
    path = os.path.join(project_dir, "error_annotations.csv")
    # Context manager ensures the file is flushed and closed (previously
    # written via a dangling `open(...).write(...)`).
    with open(path, "w") as f:
        f.write("\n".join(output))
    print(fill(f"Annotations saved to {path}"))
def save_params(project_dir, estimator):
    """Save config parameters learned via calibration.

    Parameters
    ----------
    project_dir: str
        Save parameters to `{project_dir}/config.yml`
    estimator: :py:func:`holoviews.streams.Stream`
        Stream object with fields `conf_threshold`, `slope`, `intercept`
    """
    # Collect the calibrated parameters, coercing to plain floats so they
    # serialize cleanly into the yaml config.
    calibration_params = {
        "conf_threshold": float(estimator.conf_threshold),
        "slope": float(estimator.slope),
        "intercept": float(estimator.intercept),
    }
    update_config(project_dir, **calibration_params)
def _confs_and_dists_from_annotations(coordinates, confidences, annotations, bodyparts):
confs, dists = [], []
for (key, frame, bodypart), xy in annotations.items():
if key in coordinates and key in confidences:
k = bodyparts.index(bodypart)
confs.append(confidences[key][frame][k])
dists.append(
np.sqrt(((coordinates[key][frame][k] - np.array(xy)) ** 2).sum())
)
return confs, dists
def _noise_calibration_widget(
    project_dir,
    coordinates,
    confidences,
    sample_keys,
    sample_images,
    annotations,
    *,
    keypoint_colormap,
    bodyparts,
    skeleton,
    error_estimator,
    conf_threshold,
    **kwargs,
):
    """Build the interactive calibration widget used by `noise_calibration`.

    The left panel shows one sampled frame at a time for click-annotation;
    the right panel shows log10(confidence) vs log10(error) for all
    annotations, with a fitted regression line and a movable
    confidence-threshold line. Returns a `panel.Column` of controls and
    plots.

    Parameters
    ----------
    project_dir: str
        Project directory; annotations and fitted parameters are saved here.
    coordinates: dict
        Keypoint coordinates per recording as (T,K,2) arrays.
    confidences: dict
        Keypoint confidences per recording as (T,K) arrays.
    sample_keys: list of tuples
        Frames to annotate, as (key, frame_number, bodypart) tuples.
    sample_images: dict
        Mapping from sample keys to the corresponding video frames.
    annotations: dict
        Previously saved annotations (sample key -> (x, y)).
    keypoint_colormap: str
        Colormap for skeleton nodes and edges.
    bodyparts: list
        Label for each keypoint represented in `coordinates`.
    skeleton: list
        Skeleton edges (passed to `get_edges` together with `bodyparts`).
    error_estimator: dict
        Initial regression parameters with keys `slope` and `intercept`.
    conf_threshold: float
        Initial confidence threshold.
    """
    # Plotting dependencies are imported lazily so the rest of the module
    # can be used without holoviews/panel/bokeh installed.
    from scipy.stats import linregress
    from holoviews.streams import Tap, Stream
    import holoviews as hv
    import panel as pn
    from bokeh.models import GlyphRenderer, ImageRGBA, Scatter, GraphRenderer

    hv.extension("bokeh")

    # Largest sampled frame dimensions; used below to bound the error axis
    # (the image diagonal is the largest possible annotation error).
    max_height = np.max([sample_images[k].shape[0] for k in sample_keys])
    max_width = np.max([sample_images[k].shape[1] for k in sample_keys])
    edges = np.array(get_edges(bodyparts, skeleton))

    conf_vals = np.hstack([v.flatten() for v in confidences.values()])
    # 0.01th percentile (rather than the minimum) so a few extreme
    # low-confidence outliers do not stretch the x-axis.
    min_conf, max_conf = np.nanpercentile(conf_vals, 0.01), np.nanmax(conf_vals)

    # Mutable holoviews streams carrying the widget state.
    annotations_stream = Stream.define("Annotations", annotations=annotations)()
    current_sample = Stream.define("Current sample", sample_ix=0)()
    estimator = Stream.define(
        "Estimator",
        slope=float(error_estimator["slope"]),
        intercept=float(error_estimator["intercept"]),
        conf_threshold=float(conf_threshold),
    )()

    # Transient click streams: one for annotating the image panel, one for
    # moving the confidence-threshold line on the scatter panel.
    img_tap = Tap(transient=True)
    vline_tap = Tap(transient=True)

    def update_scatter(x, y, annotations):
        """Redraw the scatter/regression panel; (x, y) is the last tap."""
        confs, dists = _confs_and_dists_from_annotations(
            coordinates, confidences, annotations, bodyparts
        )
        # +1 inside the log keeps zero-distance annotations finite.
        log_dists = np.log10(np.array(dists) + 1)
        log_confs = np.log10(np.maximum(confs, min_conf))
        max_dist = np.log10(np.sqrt(max_height**2 + max_width**2) + 1)
        xspan = np.log10(max_conf) - np.log10(min_conf)
        xlim = (
            np.log10(min_conf) - xspan / 10,
            np.log10(max_conf) + xspan / 10,
        )
        ylim = (-max_dist / 50, max_dist)

        # Refit the regression once there are at least two annotations;
        # otherwise keep the current estimate.
        if len(log_dists) > 1:
            m, b = linregress(log_confs, log_dists)[:2]
            estimator.event(slope=m, intercept=b)
        else:
            m, b = estimator.slope, estimator.intercept

        # x is None when the redraw was not triggered by a tap; a tap moves
        # the confidence threshold (tap coordinates are in log10 space).
        if x is None:
            x = np.log10(conf_threshold)
        else:
            estimator.event(conf_threshold=10**x)

        # Percent of all confidence values above the current threshold.
        passing_percent = (conf_vals > 10**x).mean() * 100

        scatter = hv.Scatter(zip(log_confs, log_dists)).opts(
            color="k",
            size=6,
            xlim=xlim,
            ylim=ylim,
            axiswise=True,
            frame_width=250,
            default_tools=[],
        )
        curve = hv.Curve([(xlim[0], xlim[0] * m + b), (xlim[1], xlim[1] * m + b)]).opts(
            xlim=xlim, ylim=ylim, axiswise=True, default_tools=[]
        )
        vline_label = hv.Text(
            x - (xlim[1] - xlim[0]) / 50,
            ylim[1] - (ylim[1] - ylim[0]) / 100,
            f"confidence\nthreshold\n{10**x:.5f}\n({passing_percent:.1f}%)",
        ).opts(
            axiswise=True,
            text_align="right",
            text_baseline="top",
            text_font_size="8pt",
            default_tools=[],
        )
        vline = hv.VLine(x).opts(
            axiswise=True,
            line_dash="dashed",
            color="lightgray",
            default_tools=[],
        )
        return (scatter * curve * vline * vline_label).opts(
            toolbar=None,
            default_tools=[],
            xlabel="log10(confidence)",
            ylabel="log10(error)",
        )

    def enforce_z_order_hook(plot, element):
        """Bokeh hook: draw image below skeleton graph below the marker."""
        bokeh_figure = plot.state
        graph, scatter, rgb = None, None, None
        for r in bokeh_figure.renderers:
            if isinstance(r, GlyphRenderer):
                if isinstance(r.glyph, ImageRGBA):
                    rgb = r
                if isinstance(r.glyph, Scatter):
                    scatter = r
            if isinstance(r, GraphRenderer):
                graph = r
        bokeh_figure.renderers = [rgb, graph, scatter]

    def update_img(sample_ix, x, y):
        """Redraw the frame panel; (x, y) is the last tap (None if absent)."""
        key, frame, bodypart = sample_key = sample_keys[sample_ix]
        image = sample_images[sample_key]
        h, w = image.shape[:2]
        keypoint_ix = bodyparts.index(bodypart)

        xys = coordinates[key][frame].copy()
        # Crop window scaled by the target keypoint's distance to its
        # farthest neighbor, so context around the animal stays visible.
        crop_size = np.sqrt(((xys - xys[keypoint_ix]) ** 2).sum(1)).max() * 2.5
        # Flip y: image rows count downward, plot coordinates count upward.
        xys[:, 1] = h - xys[:, 1]
        masked_nodes = np.nonzero(~np.isnan(xys).any(1))[0]
        confs = confidences[key][frame]

        # NOTE(review): taps at exactly x==0 or y==0 are ignored by this
        # truthiness check — presumably acceptable at the image border.
        if x and y:
            annotations_stream.annotations.update({sample_key: (x, h - y)})
            annotations_stream.event()

        # Marker position: the user annotation if present, else the
        # detected keypoint location.
        if sample_key in annotations_stream.annotations:
            point = np.array(annotations_stream.annotations[sample_key])
            point[1] = h - point[1]
        else:
            point = xys[keypoint_ix]

        colorvals = np.linspace(0, 1, len(bodyparts))
        pt_data = np.append(point, colorvals[keypoint_ix])[None]
        hv_point = hv.Points(pt_data, vdims=["bodypart"]).opts(
            color="bodypart",
            cmap="autumn",
            size=15,
            framewise=True,
            marker="x",
            line_width=3,
        )

        label = f"{bodypart}, confidence = {confs[keypoint_ix]:.5f}"
        rgb = hv.RGB(image, bounds=(0, 0, w, h), label=label).opts(
            framewise=True, xaxis="bare", yaxis="bare", frame_width=250
        )

        # Square crop centered on the target keypoint.
        xlim = (
            xys[keypoint_ix, 0] - crop_size / 2,
            xys[keypoint_ix, 0] + crop_size / 2,
        )
        ylim = (
            xys[keypoint_ix, 1] - crop_size / 2,
            xys[keypoint_ix, 1] + crop_size / 2,
        )

        # Skeleton edges restricted to keypoints with non-NaN coordinates.
        edge_data = ((), (), ())
        if len(edges) > 0:
            masked_edges = edges[np.isin(edges, masked_nodes).all(1)]
            if len(masked_edges) > 0:
                edge_data = (*masked_edges.T, colorvals[masked_edges[:, 0]])

        # Enlarge the node for the bodypart currently being annotated.
        sizes = np.where(np.arange(len(xys)) == keypoint_ix, 10, 6)[masked_nodes]
        masked_bodyparts = [bodyparts[i] for i in masked_nodes]
        nodes = hv.Nodes(
            (*xys[masked_nodes].T, masked_nodes, masked_bodyparts, sizes),
            vdims=["name", "size"],
        )
        graph = hv.Graph((edge_data, nodes), vdims="ecolor").opts(
            node_color="name",
            node_cmap=keypoint_colormap,
            tools=[],
            edge_color="ecolor",
            edge_cmap=keypoint_colormap,
            node_size="size",
        )
        return (rgb * graph * hv_point).opts(
            data_aspect=1,
            xlim=xlim,
            ylim=ylim,
            toolbar=None,
            hooks=[enforce_z_order_hook],
        )

    def update_estimator_text(*, slope, intercept, conf_threshold):
        """Mirror the current estimator parameters into the text box."""
        lines = [
            f"slope: {slope:.6f}",
            f"intercept: {intercept:.6f}",
            f"conf_threshold: {conf_threshold:.6f}",
        ]
        estimator_textbox.value = "<br>".join(lines)

    # Navigation / save controls.
    prev_button = pn.widgets.Button(name="\u25c0", width=50, align="center")
    next_button = pn.widgets.Button(name="\u25b6", width=50, align="center")
    save_button = pn.widgets.Button(name="Save", width=100, align="center")
    estimator_textbox = pn.widgets.StaticText(align="center")

    def next_sample(event):
        """Advance to the next sampled frame (no wrap-around)."""
        if current_sample.sample_ix < len(sample_keys) - 1:
            current_sample.event(sample_ix=int(current_sample.sample_ix) + 1)

    def prev_sample(event):
        """Go back to the previous sampled frame (no wrap-around)."""
        if current_sample.sample_ix > 0:
            current_sample.event(sample_ix=int(current_sample.sample_ix) - 1)

    def save_all(event):
        """Persist annotations and fitted parameters to the project."""
        save_annotations(project_dir, annotations_stream.annotations)
        save_params(project_dir, estimator)

    prev_button.on_click(prev_sample)
    next_button.on_click(next_sample)
    save_button.on_click(save_all)

    estimator.add_subscriber(update_estimator_text)
    estimator.event()  # populate the text box with initial values

    # Dynamic maps: redraw the image on navigation/taps, the scatter on
    # new annotations/threshold taps.
    img_dmap = hv.DynamicMap(
        update_img,
        streams=[current_sample, img_tap],
    ).opts(framewise=True)
    scatter_dmap = hv.DynamicMap(
        update_scatter,
        streams=[annotations_stream, vline_tap],
    ).opts(framewise=True, axiswise=True)

    controls = pn.Row(
        prev_button,
        next_button,
        pn.Spacer(width=50),
        save_button,
        pn.Spacer(width=50),
        estimator_textbox,
    )
    plots = pn.Row(img_dmap, scatter_dmap)
    return pn.Column(controls, plots)
def noise_calibration(
    project_dir,
    coordinates,
    confidences,
    *,
    bodyparts,
    use_bodyparts,
    video_dir,
    video_extension=None,
    conf_pseudocount=0.001,
    downsample_rate=1,
    **kwargs,
):
    """Perform manual annotation to calibrate the relationship between keypoint
    error and neural network confidence.

    This function creates a widget for interactive annotation in jupyter lab.
    Users mark correct keypoint locations for a sequence of frames, and a
    regression line is fit to the `log(confidence), log(error)` pairs obtained
    through annotation. The regression coefficients are used during modeling to
    set a prior on the noise level for each keypoint on each frame.

    Follow these steps to use the widget:

    - After executing this function, a widget should appear with a
      video frame in the center.
    - Annotate the labeled bodypart in each frame by left-clicking
      at the correct location. An "X" should appear there.
    - Use the arrow buttons to annotate additional frames.
    - Each annotation adds a point to the right-hand scatter plot.
      Continue until the regression line stabilizes.
    - At any point, adjust the confidence threshold by clicking on
      the scatter plot. The confidence threshold is used to define
      outlier keypoints for PCA and model initialization.
    - Use the "save" button to store your annotations to disk and
      save `slope`, `intercept`, and `confidence_threshold`
      to the config.

    Parameters
    ----------
    project_dir: str
        Project directory. Must contain a `config.yml` file.
    coordinates: dict
        Keypoint coordinates for a collection of recordings. Values
        must be numpy arrays of shape (T,K,2) where K is the number
        of keypoints. Keys can be any unique str, but must start with
        the name of a videofile in `video_dir`.
    confidences: dict
        Nonnegative confidence values for the keypoints in
        `coordinates` as numpy arrays of shape (T,K).
    bodyparts: list
        Label for each keypoint represented in `coordinates`
    use_bodyparts: list
        Ordered subset of keypoint labels to be used for modeling
    video_dir: str
        Path to directory containing videos. Each video should
        correspond to a key in `coordinates`. The key must
        contain the videoname as a prefix.
    video_extension: str, default=None
        Preferred video extension (used in :py:func:`keypoint_moseq.util.find_matching_videos`)
    conf_pseudocount: float, default=0.001
        Pseudocount added to confidence values to avoid log(0) errors.
    downsample_rate: int, default=1
        Downsampling rate for the video frames. Only change if keypoint detections were
        also downsampled.
    """
    dim = next(iter(coordinates.values())).shape[-1]
    assert dim == 2, "Calibration is only supported for 2D keypoints."

    # Augment confidences once, here. Pass conf_pseudocount=0 below since
    # the previous code let sample_error_frames apply its own default
    # pseudocount a second time on top of this one.
    confidences = {k: v + conf_pseudocount for k, v in confidences.items()}
    sample_keys = sample_error_frames(
        confidences, bodyparts, use_bodyparts, conf_pseudocount=0
    )

    # Include previously annotated frames, skipping any that were already
    # sampled so the widget does not present duplicate frames.
    annotations = load_annotations(project_dir)
    already_sampled = set(sample_keys)
    sample_keys.extend(k for k in annotations if k not in already_sampled)

    sample_images = load_sampled_frames(
        sample_keys, video_dir, video_extension, downsample_rate
    )

    return _noise_calibration_widget(
        project_dir,
        coordinates,
        confidences,
        sample_keys,
        sample_images,
        annotations,
        bodyparts=bodyparts,
        **kwargs,
    )