diff --git a/pytorch3d/implicitron/models/visualization/render_flyaround.py b/pytorch3d/implicitron/models/visualization/render_flyaround.py index a1634616..4d7c23fe 100644 --- a/pytorch3d/implicitron/models/visualization/render_flyaround.py +++ b/pytorch3d/implicitron/models/visualization/render_flyaround.py @@ -190,6 +190,7 @@ def render_flyaround( sequence_name=batch.sequence_name[0], viz=viz, viz_env=visdom_environment, + predicted_keys=visualize_preds_keys, ) logger.info(f"Exporting videos for sequence {sequence_name} ...") @@ -202,6 +203,7 @@ def render_flyaround( video_path=output_video_path, resize=video_resize, video_frames_dir=output_video_frames_dir, + predicted_keys=visualize_preds_keys, ) @@ -338,10 +340,15 @@ def _generate_prediction_videos( # init a video writer for each predicted key vws = {} for k in predicted_keys: + cache_dir = ( + None + if video_frames_dir is None + else os.path.join(video_frames_dir, f"{sequence_name}_{k}") + ) vws[k] = VideoWriter( fps=fps, out_path=f"{video_path}_{sequence_name}_{k}.mp4", - cache_dir=os.path.join(video_frames_dir, f"{sequence_name}_{k}"), + cache_dir=cache_dir, ) for rendered_pred in tqdm(preds): diff --git a/tests/implicitron/test_model_visualize.py b/tests/implicitron/test_model_visualize.py index 1815d678..2c9b1f9a 100644 --- a/tests/implicitron/test_model_visualize.py +++ b/tests/implicitron/test_model_visualize.py @@ -80,35 +80,36 @@ class TestModelVisualize(unittest.TestCase): os.makedirs(output_dir, exist_ok=True) - render_flyaround( - train_dataset, - show_sequence_name, - model, - video_path, - n_flyaround_poses=40, - fps=20, - max_angle=2 * math.pi, - trajectory_type="circular_lsq_fit", - trajectory_scale=1.1, - scene_center=(0.0, 0.0, 0.0), - up=(0.0, 1.0, 0.0), - traj_offset=1.0, - n_source_views=1, - visdom_show_preds=visdom_show_preds, - visdom_environment="test_model_visalize", - visdom_server="http://127.0.0.1", - visdom_port=8097, - num_workers=10, - seed=None, - video_resize=None, - visualize_preds_keys=[ - "images_render", - "depths_render", - "masks_render", - "_all_source_images", - ], - output_video_frames_dir=video_path, - ) + for output_video_frames_dir in [None, video_path]: + render_flyaround( + train_dataset, + show_sequence_name, + model, + video_path, + n_flyaround_poses=10, + fps=5, + max_angle=2 * math.pi, + trajectory_type="circular_lsq_fit", + trajectory_scale=1.1, + scene_center=(0.0, 0.0, 0.0), + up=(0.0, 1.0, 0.0), + traj_offset=1.0, + n_source_views=1, + visdom_show_preds=visdom_show_preds, + visdom_environment="test_model_visalize", + visdom_server="http://127.0.0.1", + visdom_port=8097, + num_workers=10, + seed=None, + video_resize=None, + visualize_preds_keys=[ + "images_render", + "depths_render", + "masks_render", + "_all_source_images", + ], + output_video_frames_dir=output_video_frames_dir, + ) class _PointcloudRenderingModel(torch.nn.Module):