Bugfixes in render_flyaround

Summary: Fixes bugs in render_flyaround

Reviewed By: bottler

Differential Revision: D39271932

fbshipit-source-id: 07e6c9ee07ba91feb437b725af0a8942fd98db0b
This commit is contained in:
David Novotny 2022-09-06 05:45:31 -07:00 committed by Facebook GitHub Bot
parent c79c954dea
commit f6d43eaa62
2 changed files with 38 additions and 30 deletions

View File

@ -190,6 +190,7 @@ def render_flyaround(
sequence_name=batch.sequence_name[0], sequence_name=batch.sequence_name[0],
viz=viz, viz=viz,
viz_env=visdom_environment, viz_env=visdom_environment,
predicted_keys=visualize_preds_keys,
) )
logger.info(f"Exporting videos for sequence {sequence_name} ...") logger.info(f"Exporting videos for sequence {sequence_name} ...")
@ -202,6 +203,7 @@ def render_flyaround(
video_path=output_video_path, video_path=output_video_path,
resize=video_resize, resize=video_resize,
video_frames_dir=output_video_frames_dir, video_frames_dir=output_video_frames_dir,
predicted_keys=visualize_preds_keys,
) )
@ -338,10 +340,15 @@ def _generate_prediction_videos(
# init a video writer for each predicted key # init a video writer for each predicted key
vws = {} vws = {}
for k in predicted_keys: for k in predicted_keys:
cache_dir = (
None
if video_frames_dir is None
else os.path.join(video_frames_dir, f"{sequence_name}_{k}")
)
vws[k] = VideoWriter( vws[k] = VideoWriter(
fps=fps, fps=fps,
out_path=f"{video_path}_{sequence_name}_{k}.mp4", out_path=f"{video_path}_{sequence_name}_{k}.mp4",
cache_dir=os.path.join(video_frames_dir, f"{sequence_name}_{k}"), cache_dir=cache_dir,
) )
for rendered_pred in tqdm(preds): for rendered_pred in tqdm(preds):

View File

@ -80,35 +80,36 @@ class TestModelVisualize(unittest.TestCase):
os.makedirs(output_dir, exist_ok=True) os.makedirs(output_dir, exist_ok=True)
render_flyaround( for output_video_frames_dir in [None, video_path]:
train_dataset, render_flyaround(
show_sequence_name, train_dataset,
model, show_sequence_name,
video_path, model,
n_flyaround_poses=40, video_path,
fps=20, n_flyaround_poses=10,
max_angle=2 * math.pi, fps=5,
trajectory_type="circular_lsq_fit", max_angle=2 * math.pi,
trajectory_scale=1.1, trajectory_type="circular_lsq_fit",
scene_center=(0.0, 0.0, 0.0), trajectory_scale=1.1,
up=(0.0, 1.0, 0.0), scene_center=(0.0, 0.0, 0.0),
traj_offset=1.0, up=(0.0, 1.0, 0.0),
n_source_views=1, traj_offset=1.0,
visdom_show_preds=visdom_show_preds, n_source_views=1,
visdom_environment="test_model_visalize", visdom_show_preds=visdom_show_preds,
visdom_server="http://127.0.0.1", visdom_environment="test_model_visalize",
visdom_port=8097, visdom_server="http://127.0.0.1",
num_workers=10, visdom_port=8097,
seed=None, num_workers=10,
video_resize=None, seed=None,
visualize_preds_keys=[ video_resize=None,
"images_render", visualize_preds_keys=[
"depths_render", "images_render",
"masks_render", "depths_render",
"_all_source_images", "masks_render",
], "_all_source_images",
output_video_frames_dir=video_path, ],
) output_video_frames_dir=output_video_frames_dir,
)
class _PointcloudRenderingModel(torch.nn.Module): class _PointcloudRenderingModel(torch.nn.Module):