diff --git a/fmpose3d/inference_api/fmpose3d.py b/fmpose3d/inference_api/fmpose3d.py index ff4e23c..7a680fb 100644 --- a/fmpose3d/inference_api/fmpose3d.py +++ b/fmpose3d/inference_api/fmpose3d.py @@ -253,17 +253,18 @@ def predict( cv2.imwrite(p, frames[idx]) paths.append(p) - # Run DeepLabCut on each frame individually. + # Run DeepLabCut once for all frames. + predictions = superanimal_analyze_images( + superanimal_name=cfg.superanimal_name, + model_name=cfg.sa_model_name, + detector_name=cfg.detector_name, + images=paths, + max_individuals=cfg.max_individuals, + out_folder=tmpdir, + ) + # predictions: {image_path: {"bodyparts": (N_ind, K, 3), ...}} + # Iterate in input order to keep frame alignment stable. for img_path in paths: - predictions = superanimal_analyze_images( - superanimal_name=cfg.superanimal_name, - model_name=cfg.sa_model_name, - detector_name=cfg.detector_name, - images=img_path, - max_individuals=cfg.max_individuals, - out_folder=tmpdir, - ) - # predictions: {image_path: {"bodyparts": (N_ind, K, 3), ...}} payload = predictions.get(img_path) if isinstance(predictions, dict) else None if payload is None and isinstance(predictions, dict) and len(predictions) == 1: payload = next(iter(predictions.values()))