Optional ground-truth depth maps in visualiser

Summary: The code does not crash if depth map/mask are not given.

Reviewed By: bottler

Differential Revision: D45082985

fbshipit-source-id: 3610d8beb4ac897fbbe52f56a6dd012a6365b89b
This commit is contained in:
Roman Shapovalov 2023-04-18 07:00:17 -07:00 committed by Facebook GitHub Bot
parent 1af6bf4768
commit 0e3138eca8
3 changed files with 42 additions and 36 deletions

View File

@ -38,8 +38,8 @@ class _Visualizer:
image_render: torch.Tensor image_render: torch.Tensor
image_rgb_masked: torch.Tensor image_rgb_masked: torch.Tensor
depth_render: torch.Tensor depth_render: torch.Tensor
depth_map: torch.Tensor depth_map: Optional[torch.Tensor]
depth_mask: torch.Tensor depth_mask: Optional[torch.Tensor]
visdom_env: str = "eval_debug" visdom_env: str = "eval_debug"
@ -75,9 +75,11 @@ class _Visualizer:
viz = self._viz viz = self._viz
viz.images( viz.images(
torch.cat( torch.cat(
( (make_depth_image(self.depth_render, loss_mask_now),)
make_depth_image(self.depth_render, loss_mask_now), + (
make_depth_image(self.depth_map, loss_mask_now), (make_depth_image(self.depth_map, loss_mask_now),)
if self.depth_map is not None
else ()
), ),
dim=3, dim=3,
), ),
@ -91,12 +93,13 @@ class _Visualizer:
win="depth_abs" + name_postfix + "_mask", win="depth_abs" + name_postfix + "_mask",
opts={"title": f"depth_abs_{name_postfix}_{depth_loss:1.2f}_mask"}, opts={"title": f"depth_abs_{name_postfix}_{depth_loss:1.2f}_mask"},
) )
viz.images( if self.depth_mask is not None:
self.depth_mask, viz.images(
env=self.visdom_env, self.depth_mask,
win="depth_abs" + name_postfix + "_maskd", env=self.visdom_env,
opts={"title": f"depth_abs_{name_postfix}_{depth_loss:1.2f}_maskd"}, win="depth_abs" + name_postfix + "_maskd",
) opts={"title": f"depth_abs_{name_postfix}_{depth_loss:1.2f}_maskd"},
)
# show the 3D plot # show the 3D plot
# pyre-fixme[9]: viewpoint_trivial has type `PerspectiveCameras`; used as # pyre-fixme[9]: viewpoint_trivial has type `PerspectiveCameras`; used as
@ -104,29 +107,30 @@ class _Visualizer:
viewpoint_trivial: PerspectiveCameras = PerspectiveCameras().to( viewpoint_trivial: PerspectiveCameras = PerspectiveCameras().to(
loss_mask_now.device loss_mask_now.device
) )
pcl_pred = get_rgbd_point_cloud(
viewpoint_trivial,
self.image_render,
self.depth_render,
# mask_crop,
torch.ones_like(self.depth_render),
# loss_mask_now,
)
pcl_gt = get_rgbd_point_cloud(
viewpoint_trivial,
self.image_rgb_masked,
self.depth_map,
# mask_crop,
torch.ones_like(self.depth_map),
# loss_mask_now,
)
_pcls = { _pcls = {
pn: p "pred_depth": get_rgbd_point_cloud(
for pn, p in zip(("pred_depth", "gt_depth"), (pcl_pred, pcl_gt)) viewpoint_trivial,
if int(p.num_points_per_cloud()) > 0 self.image_render,
self.depth_render,
# mask_crop,
torch.ones_like(self.depth_render),
# loss_mask_now,
)
} }
if self.depth_map is not None:
_pcls["gt_depth"] = get_rgbd_point_cloud(
viewpoint_trivial,
self.image_rgb_masked,
self.depth_map,
# mask_crop,
torch.ones_like(self.depth_map),
# loss_mask_now,
)
_pcls = {pn: p for pn, p in _pcls.items() if int(p.num_points_per_cloud()) > 0}
plotlyplot = plot_scene( plotlyplot = plot_scene(
{f"pcl{name_postfix}": _pcls}, {f"pcl{name_postfix}": _pcls}, # pyre-ignore
camera_scale=1.0, camera_scale=1.0,
pointcloud_max_points=10000, pointcloud_max_points=10000,
pointcloud_marker_size=1, pointcloud_marker_size=1,
@ -277,10 +281,10 @@ def eval_batch(
image_render=image_render, image_render=image_render,
image_rgb_masked=image_rgb_masked, image_rgb_masked=image_rgb_masked,
depth_render=cloned_render["depth_render"], depth_render=cloned_render["depth_render"],
# pyre-fixme[6]: Expected `Tensor` for 4th param but got
# `Optional[torch.Tensor]`.
depth_map=frame_data.depth_map, depth_map=frame_data.depth_map,
depth_mask=frame_data.depth_mask[:1], depth_mask=frame_data.depth_mask[:1]
if frame_data.depth_mask is not None
else None,
visdom_env=visualize_visdom_env, visdom_env=visualize_visdom_env,
) )

View File

@ -360,7 +360,7 @@ class GenericModel(ImplicitronModelBase): # pyre-ignore: 13
and source images, which will be used for intersecting with target rays. and source images, which will be used for intersecting with target rays.
fg_probability: A tensor of shape `(B, 1, H, W)` containing a batch of fg_probability: A tensor of shape `(B, 1, H, W)` containing a batch of
foreground masks. foreground masks.
mask_crop: A binary tensor of shape `(B, 1, H, W)` deonting valid mask_crop: A binary tensor of shape `(B, 1, H, W)` denoting valid
regions in the input images (i.e. regions that do not correspond regions in the input images (i.e. regions that do not correspond
to, e.g., zero-padding). When the `RaySampler`'s sampling mode is set to to, e.g., zero-padding). When the `RaySampler`'s sampling mode is set to
"mask_sample", rays will be sampled in the non zero regions. "mask_sample", rays will be sampled in the non zero regions.

View File

@ -26,6 +26,8 @@ from tests.common_testing import interactive_testing_requested
from .common_resources import get_skateboard_data from .common_resources import get_skateboard_data
VISDOM_PORT = int(os.environ.get("VISDOM_PORT", 8097))
class TestDatasetVisualize(unittest.TestCase): class TestDatasetVisualize(unittest.TestCase):
def setUp(self): def setUp(self):
@ -77,7 +79,7 @@ class TestDatasetVisualize(unittest.TestCase):
for k, dataset in self.datasets.items() for k, dataset in self.datasets.items()
} }
) )
self.visdom = Visdom() self.visdom = Visdom(port=VISDOM_PORT)
if not self.visdom.check_connection(): if not self.visdom.check_connection():
print("Visdom server not running! Disabling visdom visualizations.") print("Visdom server not running! Disabling visdom visualizations.")
self.visdom = None self.visdom = None