Mirror of https://github.com/facebookresearch/pytorch3d.git (synced 2025-08-01 03:12:49 +08:00)
render_flyaround bugfix
Summary: Fixes a bug that would crash render_flyaround whenever visualize_preds_keys is adjusted.

Reviewed By: shapovalov

Differential Revision: D41124462

fbshipit-source-id: 127045a91a055909f8bd56c8af81afac02c00f60
This commit is contained in:
parent 35f8cb9430
commit 94f321fa3d
@@ -10,7 +10,17 @@ import logging
 import math
 import os
 import random
-from typing import Any, Dict, List, Optional, Sequence, Tuple, TYPE_CHECKING, Union
+from typing import (
+    Any,
+    Dict,
+    Iterable,
+    List,
+    Optional,
+    Sequence,
+    Tuple,
+    TYPE_CHECKING,
+    Union,
+)
 
 import numpy as np
 import torch
@@ -180,7 +190,7 @@ def render_flyaround(
         preds.update(net_input)  # merge everything into one big dict
 
         # Render the predictions to images
-        rendered_pred = _images_from_preds(preds)
+        rendered_pred = _images_from_preds(preds, extract_keys=visualize_preds_keys)
         preds_total.append(rendered_pred)
 
         # show the preds every 5% of the export iterations
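The one-line change above forwards visualize_preds_keys into _images_from_preds. The likely failure mode before the fix: images were rendered for a fixed, hard-coded key set regardless of the caller's selection, so any adjusted key raised a KeyError when the video stage later looked it up. A toy reproduction of that pattern, with hypothetical names (render_fixed, render_requested) that are not part of pytorch3d:

preds = {"image_rgb": "rgb-frame", "fg_probability": "mask"}

def render_fixed(preds):
    # pre-fix behaviour: renders a hard-coded key set, ignoring the caller's request
    return {k: preds[k] for k in ("image_rgb",) if k in preds}

def render_requested(preds, extract_keys):
    # post-fix behaviour: renders exactly the requested keys that the model produced
    return {k: preds[k] for k in extract_keys if k in preds}

requested = ("image_rgb", "fg_probability")  # analogous to an adjusted visualize_preds_keys
try:
    frames = [render_fixed(preds)[k] for k in requested]
except KeyError as err:
    print("pre-fix crash:", err)  # KeyError: 'fg_probability'
frames = [render_requested(preds, requested)[k] for k in requested]
print(frames)  # ['rgb-frame', 'mask']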
@@ -223,9 +233,9 @@ def _load_whole_dataset(
     return next(iter(load_all_dataloader))
 
 
-def _images_from_preds(preds: Dict[str, Any]) -> Dict[str, torch.Tensor]:
-    imout = {}
-    for k in (
+def _images_from_preds(
+    preds: Dict[str, Any],
+    extract_keys: Iterable[str] = (
         "image_rgb",
         "images_render",
         "fg_probability",
@@ -233,7 +243,10 @@ def _images_from_preds(preds: Dict[str, Any]) -> Dict[str, torch.Tensor]:
         "depths_render",
         "depth_map",
         "_all_source_images",
-    ):
+    ),
+) -> Dict[str, torch.Tensor]:
+    imout = {}
+    for k in extract_keys:
         if k == "_all_source_images" and "image_rgb" in preds:
             src_ims = preds["image_rgb"][1:].cpu().detach().clone()
             v = _stack_images(src_ims, None)[None]
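Taken together with the previous hunk, the refactor above moves the hard-coded key tuple into the default value of the new extract_keys parameter, so callers can override the selection without touching the function body. A minimal sketch of the same signature pattern, assuming only standard Python; a tuple is used for the default because it is immutable, unlike a list, and therefore safe to share across calls:

from typing import Any, Dict, Iterable

def images_from_preds_sketch(
    preds: Dict[str, Any],
    extract_keys: Iterable[str] = ("image_rgb", "depth_map"),  # immutable default
) -> Dict[str, Any]:
    out = {}
    for k in extract_keys:
        if k in preds:  # keep only keys the predictions actually contain
            out[k] = preds[k]
    return out

print(images_from_preds_sketch({"image_rgb": 1, "extra": 2}))  # {'image_rgb': 1}
print(images_from_preds_sketch({"extra": 2}, extract_keys=("extra",)))  # {'extra': 2}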
@@ -343,6 +356,9 @@ def _generate_prediction_videos(
     # init a video writer for each predicted key
     vws = {}
     for k in predicted_keys:
+        if k not in preds[0]:
+            logger.warn(f"Cannot generate video for prediction key '{k}'")
+            continue
         cache_dir = (
             None
             if video_frames_dir is None
@@ -355,13 +371,15 @@ def _generate_prediction_videos(
         )
 
     for rendered_pred in tqdm(preds):
-        for k in predicted_keys:
+        for k in vws:
             vws[k].write_frame(
                 rendered_pred[k][0].clip(0.0, 1.0).detach().cpu().numpy(),
                 resize=resize,
             )
 
     for k in predicted_keys:
+        if k not in vws:
+            continue
         vws[k].get_video()
         logger.info(f"Generated {vws[k].out_path}.")
         if viz is not None:
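The last two hunks apply one defensive pattern to _generate_prediction_videos: a writer is created only for keys actually present in the first prediction (with a warning otherwise), and both the frame-writing and finalization loops then skip keys that have no writer. A stripped-down sketch of that pattern; FakeWriter and generate_videos_sketch are stand-ins, not the pytorch3d video-writer API:

import logging

logger = logging.getLogger(__name__)

class FakeWriter:  # stand-in for a real video writer
    def __init__(self, key):
        self.frames, self.key = [], key
    def write_frame(self, frame):
        self.frames.append(frame)
    def get_video(self):
        return f"{self.key}.mp4 ({len(self.frames)} frames)"

def generate_videos_sketch(preds, predicted_keys):
    # create a writer only for keys the predictions actually contain
    vws = {}
    for k in predicted_keys:
        if k not in preds[0]:
            logger.warning("Cannot generate video for prediction key '%s'", k)
            continue
        vws[k] = FakeWriter(k)
    # write frames only for the writers that were created
    for rendered_pred in preds:
        for k in vws:
            vws[k].write_frame(rendered_pred[k])
    # finalize, again skipping keys without a writer
    return [vws[k].get_video() for k in predicted_keys if k in vws]

preds = [{"image_rgb": f} for f in ("f0", "f1")]
print(generate_videos_sketch(preds, ("image_rgb", "depth_map")))  # skips 'depth_map'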