mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
Optional ground-truth depth maps in visualiser
Summary: The code does not crash if depth map/mask are not given. Reviewed By: bottler Differential Revision: D45082985 fbshipit-source-id: 3610d8beb4ac897fbbe52f56a6dd012a6365b89b
This commit is contained in:
parent
1af6bf4768
commit
0e3138eca8
@ -38,8 +38,8 @@ class _Visualizer:
|
||||
image_render: torch.Tensor
|
||||
image_rgb_masked: torch.Tensor
|
||||
depth_render: torch.Tensor
|
||||
depth_map: torch.Tensor
|
||||
depth_mask: torch.Tensor
|
||||
depth_map: Optional[torch.Tensor]
|
||||
depth_mask: Optional[torch.Tensor]
|
||||
|
||||
visdom_env: str = "eval_debug"
|
||||
|
||||
@ -75,9 +75,11 @@ class _Visualizer:
|
||||
viz = self._viz
|
||||
viz.images(
|
||||
torch.cat(
|
||||
(
|
||||
make_depth_image(self.depth_render, loss_mask_now),
|
||||
make_depth_image(self.depth_map, loss_mask_now),
|
||||
(make_depth_image(self.depth_render, loss_mask_now),)
|
||||
+ (
|
||||
(make_depth_image(self.depth_map, loss_mask_now),)
|
||||
if self.depth_map is not None
|
||||
else ()
|
||||
),
|
||||
dim=3,
|
||||
),
|
||||
@ -91,6 +93,7 @@ class _Visualizer:
|
||||
win="depth_abs" + name_postfix + "_mask",
|
||||
opts={"title": f"depth_abs_{name_postfix}_{depth_loss:1.2f}_mask"},
|
||||
)
|
||||
if self.depth_mask is not None:
|
||||
viz.images(
|
||||
self.depth_mask,
|
||||
env=self.visdom_env,
|
||||
@ -104,7 +107,8 @@ class _Visualizer:
|
||||
viewpoint_trivial: PerspectiveCameras = PerspectiveCameras().to(
|
||||
loss_mask_now.device
|
||||
)
|
||||
pcl_pred = get_rgbd_point_cloud(
|
||||
_pcls = {
|
||||
"pred_depth": get_rgbd_point_cloud(
|
||||
viewpoint_trivial,
|
||||
self.image_render,
|
||||
self.depth_render,
|
||||
@ -112,7 +116,9 @@ class _Visualizer:
|
||||
torch.ones_like(self.depth_render),
|
||||
# loss_mask_now,
|
||||
)
|
||||
pcl_gt = get_rgbd_point_cloud(
|
||||
}
|
||||
if self.depth_map is not None:
|
||||
_pcls["gt_depth"] = get_rgbd_point_cloud(
|
||||
viewpoint_trivial,
|
||||
self.image_rgb_masked,
|
||||
self.depth_map,
|
||||
@ -120,13 +126,11 @@ class _Visualizer:
|
||||
torch.ones_like(self.depth_map),
|
||||
# loss_mask_now,
|
||||
)
|
||||
_pcls = {
|
||||
pn: p
|
||||
for pn, p in zip(("pred_depth", "gt_depth"), (pcl_pred, pcl_gt))
|
||||
if int(p.num_points_per_cloud()) > 0
|
||||
}
|
||||
|
||||
_pcls = {pn: p for pn, p in _pcls.items() if int(p.num_points_per_cloud()) > 0}
|
||||
|
||||
plotlyplot = plot_scene(
|
||||
{f"pcl{name_postfix}": _pcls},
|
||||
{f"pcl{name_postfix}": _pcls}, # pyre-ignore
|
||||
camera_scale=1.0,
|
||||
pointcloud_max_points=10000,
|
||||
pointcloud_marker_size=1,
|
||||
@ -277,10 +281,10 @@ def eval_batch(
|
||||
image_render=image_render,
|
||||
image_rgb_masked=image_rgb_masked,
|
||||
depth_render=cloned_render["depth_render"],
|
||||
# pyre-fixme[6]: Expected `Tensor` for 4th param but got
|
||||
# `Optional[torch.Tensor]`.
|
||||
depth_map=frame_data.depth_map,
|
||||
depth_mask=frame_data.depth_mask[:1],
|
||||
depth_mask=frame_data.depth_mask[:1]
|
||||
if frame_data.depth_mask is not None
|
||||
else None,
|
||||
visdom_env=visualize_visdom_env,
|
||||
)
|
||||
|
||||
|
@ -360,7 +360,7 @@ class GenericModel(ImplicitronModelBase): # pyre-ignore: 13
|
||||
and source images, which will be used for intersecting with target rays.
|
||||
fg_probability: A tensor of shape `(B, 1, H, W)` containing a batch of
|
||||
foreground masks.
|
||||
mask_crop: A binary tensor of shape `(B, 1, H, W)` deonting valid
|
||||
mask_crop: A binary tensor of shape `(B, 1, H, W)` denoting valid
|
||||
regions in the input images (i.e. regions that do not correspond
|
||||
to, e.g., zero-padding). When the `RaySampler`'s sampling mode is set to
|
||||
"mask_sample", rays will be sampled in the non zero regions.
|
||||
|
@ -26,6 +26,8 @@ from tests.common_testing import interactive_testing_requested
|
||||
|
||||
from .common_resources import get_skateboard_data
|
||||
|
||||
VISDOM_PORT = int(os.environ.get("VISDOM_PORT", 8097))
|
||||
|
||||
|
||||
class TestDatasetVisualize(unittest.TestCase):
|
||||
def setUp(self):
|
||||
@ -77,7 +79,7 @@ class TestDatasetVisualize(unittest.TestCase):
|
||||
for k, dataset in self.datasets.items()
|
||||
}
|
||||
)
|
||||
self.visdom = Visdom()
|
||||
self.visdom = Visdom(port=VISDOM_PORT)
|
||||
if not self.visdom.check_connection():
|
||||
print("Visdom server not running! Disabling visdom visualizations.")
|
||||
self.visdom = None
|
||||
|
Loading…
x
Reference in New Issue
Block a user