mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
render_flyaround bugfix
Summary: Fixes a bug which would crash render_flyaround anytime visualize_preds_keys is adjusted Reviewed By: shapovalov Differential Revision: D41124462 fbshipit-source-id: 127045a91a055909f8bd56c8af81afac02c00f60
This commit is contained in:
parent
35f8cb9430
commit
94f321fa3d
@ -10,7 +10,17 @@ import logging
|
|||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
from typing import Any, Dict, List, Optional, Sequence, Tuple, TYPE_CHECKING, Union
|
from typing import (
|
||||||
|
Any,
|
||||||
|
Dict,
|
||||||
|
Iterable,
|
||||||
|
List,
|
||||||
|
Optional,
|
||||||
|
Sequence,
|
||||||
|
Tuple,
|
||||||
|
TYPE_CHECKING,
|
||||||
|
Union,
|
||||||
|
)
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@ -180,7 +190,7 @@ def render_flyaround(
|
|||||||
preds.update(net_input) # merge everything into one big dict
|
preds.update(net_input) # merge everything into one big dict
|
||||||
|
|
||||||
# Render the predictions to images
|
# Render the predictions to images
|
||||||
rendered_pred = _images_from_preds(preds)
|
rendered_pred = _images_from_preds(preds, extract_keys=visualize_preds_keys)
|
||||||
preds_total.append(rendered_pred)
|
preds_total.append(rendered_pred)
|
||||||
|
|
||||||
# show the preds every 5% of the export iterations
|
# show the preds every 5% of the export iterations
|
||||||
@ -223,9 +233,9 @@ def _load_whole_dataset(
|
|||||||
return next(iter(load_all_dataloader))
|
return next(iter(load_all_dataloader))
|
||||||
|
|
||||||
|
|
||||||
def _images_from_preds(preds: Dict[str, Any]) -> Dict[str, torch.Tensor]:
|
def _images_from_preds(
|
||||||
imout = {}
|
preds: Dict[str, Any],
|
||||||
for k in (
|
extract_keys: Iterable[str] = (
|
||||||
"image_rgb",
|
"image_rgb",
|
||||||
"images_render",
|
"images_render",
|
||||||
"fg_probability",
|
"fg_probability",
|
||||||
@ -233,7 +243,10 @@ def _images_from_preds(preds: Dict[str, Any]) -> Dict[str, torch.Tensor]:
|
|||||||
"depths_render",
|
"depths_render",
|
||||||
"depth_map",
|
"depth_map",
|
||||||
"_all_source_images",
|
"_all_source_images",
|
||||||
):
|
),
|
||||||
|
) -> Dict[str, torch.Tensor]:
|
||||||
|
imout = {}
|
||||||
|
for k in extract_keys:
|
||||||
if k == "_all_source_images" and "image_rgb" in preds:
|
if k == "_all_source_images" and "image_rgb" in preds:
|
||||||
src_ims = preds["image_rgb"][1:].cpu().detach().clone()
|
src_ims = preds["image_rgb"][1:].cpu().detach().clone()
|
||||||
v = _stack_images(src_ims, None)[None]
|
v = _stack_images(src_ims, None)[None]
|
||||||
@ -343,6 +356,9 @@ def _generate_prediction_videos(
|
|||||||
# init a video writer for each predicted key
|
# init a video writer for each predicted key
|
||||||
vws = {}
|
vws = {}
|
||||||
for k in predicted_keys:
|
for k in predicted_keys:
|
||||||
|
if k not in preds[0]:
|
||||||
|
logger.warn(f"Cannot generate video for prediction key '{k}'")
|
||||||
|
continue
|
||||||
cache_dir = (
|
cache_dir = (
|
||||||
None
|
None
|
||||||
if video_frames_dir is None
|
if video_frames_dir is None
|
||||||
@ -355,13 +371,15 @@ def _generate_prediction_videos(
|
|||||||
)
|
)
|
||||||
|
|
||||||
for rendered_pred in tqdm(preds):
|
for rendered_pred in tqdm(preds):
|
||||||
for k in predicted_keys:
|
for k in vws:
|
||||||
vws[k].write_frame(
|
vws[k].write_frame(
|
||||||
rendered_pred[k][0].clip(0.0, 1.0).detach().cpu().numpy(),
|
rendered_pred[k][0].clip(0.0, 1.0).detach().cpu().numpy(),
|
||||||
resize=resize,
|
resize=resize,
|
||||||
)
|
)
|
||||||
|
|
||||||
for k in predicted_keys:
|
for k in predicted_keys:
|
||||||
|
if k not in vws:
|
||||||
|
continue
|
||||||
vws[k].get_video()
|
vws[k].get_video()
|
||||||
logger.info(f"Generated {vws[k].out_path}.")
|
logger.info(f"Generated {vws[k].out_path}.")
|
||||||
if viz is not None:
|
if viz is not None:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user