Optional ground-truth depth maps in visualiser

Summary: The code does not crash if depth map/mask are not given.

Reviewed By: bottler

Differential Revision: D45082985

fbshipit-source-id: 3610d8beb4ac897fbbe52f56a6dd012a6365b89b
This commit is contained in:
Roman Shapovalov 2023-04-18 07:00:17 -07:00 committed by Facebook GitHub Bot
parent 1af6bf4768
commit 0e3138eca8
3 changed files with 42 additions and 36 deletions

View File

@ -38,8 +38,8 @@ class _Visualizer:
image_render: torch.Tensor
image_rgb_masked: torch.Tensor
depth_render: torch.Tensor
depth_map: torch.Tensor
depth_mask: torch.Tensor
depth_map: Optional[torch.Tensor]
depth_mask: Optional[torch.Tensor]
visdom_env: str = "eval_debug"
@ -75,9 +75,11 @@ class _Visualizer:
viz = self._viz
viz.images(
torch.cat(
(
make_depth_image(self.depth_render, loss_mask_now),
make_depth_image(self.depth_map, loss_mask_now),
(make_depth_image(self.depth_render, loss_mask_now),)
+ (
(make_depth_image(self.depth_map, loss_mask_now),)
if self.depth_map is not None
else ()
),
dim=3,
),
@ -91,6 +93,7 @@ class _Visualizer:
win="depth_abs" + name_postfix + "_mask",
opts={"title": f"depth_abs_{name_postfix}_{depth_loss:1.2f}_mask"},
)
if self.depth_mask is not None:
viz.images(
self.depth_mask,
env=self.visdom_env,
@ -104,7 +107,8 @@ class _Visualizer:
viewpoint_trivial: PerspectiveCameras = PerspectiveCameras().to(
loss_mask_now.device
)
pcl_pred = get_rgbd_point_cloud(
_pcls = {
"pred_depth": get_rgbd_point_cloud(
viewpoint_trivial,
self.image_render,
self.depth_render,
@ -112,7 +116,9 @@ class _Visualizer:
torch.ones_like(self.depth_render),
# loss_mask_now,
)
pcl_gt = get_rgbd_point_cloud(
}
if self.depth_map is not None:
_pcls["gt_depth"] = get_rgbd_point_cloud(
viewpoint_trivial,
self.image_rgb_masked,
self.depth_map,
@ -120,13 +126,11 @@ class _Visualizer:
torch.ones_like(self.depth_map),
# loss_mask_now,
)
_pcls = {
pn: p
for pn, p in zip(("pred_depth", "gt_depth"), (pcl_pred, pcl_gt))
if int(p.num_points_per_cloud()) > 0
}
_pcls = {pn: p for pn, p in _pcls.items() if int(p.num_points_per_cloud()) > 0}
plotlyplot = plot_scene(
{f"pcl{name_postfix}": _pcls},
{f"pcl{name_postfix}": _pcls}, # pyre-ignore
camera_scale=1.0,
pointcloud_max_points=10000,
pointcloud_marker_size=1,
@ -277,10 +281,10 @@ def eval_batch(
image_render=image_render,
image_rgb_masked=image_rgb_masked,
depth_render=cloned_render["depth_render"],
# pyre-fixme[6]: Expected `Tensor` for 4th param but got
# `Optional[torch.Tensor]`.
depth_map=frame_data.depth_map,
depth_mask=frame_data.depth_mask[:1],
depth_mask=frame_data.depth_mask[:1]
if frame_data.depth_mask is not None
else None,
visdom_env=visualize_visdom_env,
)

View File

@ -360,7 +360,7 @@ class GenericModel(ImplicitronModelBase): # pyre-ignore: 13
and source images, which will be used for intersecting with target rays.
fg_probability: A tensor of shape `(B, 1, H, W)` containing a batch of
foreground masks.
mask_crop: A binary tensor of shape `(B, 1, H, W)` deonting valid
mask_crop: A binary tensor of shape `(B, 1, H, W)` denoting valid
regions in the input images (i.e. regions that do not correspond
to, e.g., zero-padding). When the `RaySampler`'s sampling mode is set to
"mask_sample", rays will be sampled in the non zero regions.

View File

@ -26,6 +26,8 @@ from tests.common_testing import interactive_testing_requested
from .common_resources import get_skateboard_data
VISDOM_PORT = int(os.environ.get("VISDOM_PORT", 8097))
class TestDatasetVisualize(unittest.TestCase):
def setUp(self):
@ -77,7 +79,7 @@ class TestDatasetVisualize(unittest.TestCase):
for k, dataset in self.datasets.items()
}
)
self.visdom = Visdom()
self.visdom = Visdom(port=VISDOM_PORT)
if not self.visdom.check_connection():
print("Visdom server not running! Disabling visdom visualizations.")
self.visdom = None