mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
Optional ground-truth depth maps in visualiser
Summary: The code does not crash if depth map/mask are not given. Reviewed By: bottler Differential Revision: D45082985 fbshipit-source-id: 3610d8beb4ac897fbbe52f56a6dd012a6365b89b
This commit is contained in:
parent
1af6bf4768
commit
0e3138eca8
@ -38,8 +38,8 @@ class _Visualizer:
|
|||||||
image_render: torch.Tensor
|
image_render: torch.Tensor
|
||||||
image_rgb_masked: torch.Tensor
|
image_rgb_masked: torch.Tensor
|
||||||
depth_render: torch.Tensor
|
depth_render: torch.Tensor
|
||||||
depth_map: torch.Tensor
|
depth_map: Optional[torch.Tensor]
|
||||||
depth_mask: torch.Tensor
|
depth_mask: Optional[torch.Tensor]
|
||||||
|
|
||||||
visdom_env: str = "eval_debug"
|
visdom_env: str = "eval_debug"
|
||||||
|
|
||||||
@ -75,9 +75,11 @@ class _Visualizer:
|
|||||||
viz = self._viz
|
viz = self._viz
|
||||||
viz.images(
|
viz.images(
|
||||||
torch.cat(
|
torch.cat(
|
||||||
(
|
(make_depth_image(self.depth_render, loss_mask_now),)
|
||||||
make_depth_image(self.depth_render, loss_mask_now),
|
+ (
|
||||||
make_depth_image(self.depth_map, loss_mask_now),
|
(make_depth_image(self.depth_map, loss_mask_now),)
|
||||||
|
if self.depth_map is not None
|
||||||
|
else ()
|
||||||
),
|
),
|
||||||
dim=3,
|
dim=3,
|
||||||
),
|
),
|
||||||
@ -91,12 +93,13 @@ class _Visualizer:
|
|||||||
win="depth_abs" + name_postfix + "_mask",
|
win="depth_abs" + name_postfix + "_mask",
|
||||||
opts={"title": f"depth_abs_{name_postfix}_{depth_loss:1.2f}_mask"},
|
opts={"title": f"depth_abs_{name_postfix}_{depth_loss:1.2f}_mask"},
|
||||||
)
|
)
|
||||||
viz.images(
|
if self.depth_mask is not None:
|
||||||
self.depth_mask,
|
viz.images(
|
||||||
env=self.visdom_env,
|
self.depth_mask,
|
||||||
win="depth_abs" + name_postfix + "_maskd",
|
env=self.visdom_env,
|
||||||
opts={"title": f"depth_abs_{name_postfix}_{depth_loss:1.2f}_maskd"},
|
win="depth_abs" + name_postfix + "_maskd",
|
||||||
)
|
opts={"title": f"depth_abs_{name_postfix}_{depth_loss:1.2f}_maskd"},
|
||||||
|
)
|
||||||
|
|
||||||
# show the 3D plot
|
# show the 3D plot
|
||||||
# pyre-fixme[9]: viewpoint_trivial has type `PerspectiveCameras`; used as
|
# pyre-fixme[9]: viewpoint_trivial has type `PerspectiveCameras`; used as
|
||||||
@ -104,29 +107,30 @@ class _Visualizer:
|
|||||||
viewpoint_trivial: PerspectiveCameras = PerspectiveCameras().to(
|
viewpoint_trivial: PerspectiveCameras = PerspectiveCameras().to(
|
||||||
loss_mask_now.device
|
loss_mask_now.device
|
||||||
)
|
)
|
||||||
pcl_pred = get_rgbd_point_cloud(
|
|
||||||
viewpoint_trivial,
|
|
||||||
self.image_render,
|
|
||||||
self.depth_render,
|
|
||||||
# mask_crop,
|
|
||||||
torch.ones_like(self.depth_render),
|
|
||||||
# loss_mask_now,
|
|
||||||
)
|
|
||||||
pcl_gt = get_rgbd_point_cloud(
|
|
||||||
viewpoint_trivial,
|
|
||||||
self.image_rgb_masked,
|
|
||||||
self.depth_map,
|
|
||||||
# mask_crop,
|
|
||||||
torch.ones_like(self.depth_map),
|
|
||||||
# loss_mask_now,
|
|
||||||
)
|
|
||||||
_pcls = {
|
_pcls = {
|
||||||
pn: p
|
"pred_depth": get_rgbd_point_cloud(
|
||||||
for pn, p in zip(("pred_depth", "gt_depth"), (pcl_pred, pcl_gt))
|
viewpoint_trivial,
|
||||||
if int(p.num_points_per_cloud()) > 0
|
self.image_render,
|
||||||
|
self.depth_render,
|
||||||
|
# mask_crop,
|
||||||
|
torch.ones_like(self.depth_render),
|
||||||
|
# loss_mask_now,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
if self.depth_map is not None:
|
||||||
|
_pcls["gt_depth"] = get_rgbd_point_cloud(
|
||||||
|
viewpoint_trivial,
|
||||||
|
self.image_rgb_masked,
|
||||||
|
self.depth_map,
|
||||||
|
# mask_crop,
|
||||||
|
torch.ones_like(self.depth_map),
|
||||||
|
# loss_mask_now,
|
||||||
|
)
|
||||||
|
|
||||||
|
_pcls = {pn: p for pn, p in _pcls.items() if int(p.num_points_per_cloud()) > 0}
|
||||||
|
|
||||||
plotlyplot = plot_scene(
|
plotlyplot = plot_scene(
|
||||||
{f"pcl{name_postfix}": _pcls},
|
{f"pcl{name_postfix}": _pcls}, # pyre-ignore
|
||||||
camera_scale=1.0,
|
camera_scale=1.0,
|
||||||
pointcloud_max_points=10000,
|
pointcloud_max_points=10000,
|
||||||
pointcloud_marker_size=1,
|
pointcloud_marker_size=1,
|
||||||
@ -277,10 +281,10 @@ def eval_batch(
|
|||||||
image_render=image_render,
|
image_render=image_render,
|
||||||
image_rgb_masked=image_rgb_masked,
|
image_rgb_masked=image_rgb_masked,
|
||||||
depth_render=cloned_render["depth_render"],
|
depth_render=cloned_render["depth_render"],
|
||||||
# pyre-fixme[6]: Expected `Tensor` for 4th param but got
|
|
||||||
# `Optional[torch.Tensor]`.
|
|
||||||
depth_map=frame_data.depth_map,
|
depth_map=frame_data.depth_map,
|
||||||
depth_mask=frame_data.depth_mask[:1],
|
depth_mask=frame_data.depth_mask[:1]
|
||||||
|
if frame_data.depth_mask is not None
|
||||||
|
else None,
|
||||||
visdom_env=visualize_visdom_env,
|
visdom_env=visualize_visdom_env,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -360,7 +360,7 @@ class GenericModel(ImplicitronModelBase): # pyre-ignore: 13
|
|||||||
and source images, which will be used for intersecting with target rays.
|
and source images, which will be used for intersecting with target rays.
|
||||||
fg_probability: A tensor of shape `(B, 1, H, W)` containing a batch of
|
fg_probability: A tensor of shape `(B, 1, H, W)` containing a batch of
|
||||||
foreground masks.
|
foreground masks.
|
||||||
mask_crop: A binary tensor of shape `(B, 1, H, W)` deonting valid
|
mask_crop: A binary tensor of shape `(B, 1, H, W)` denoting valid
|
||||||
regions in the input images (i.e. regions that do not correspond
|
regions in the input images (i.e. regions that do not correspond
|
||||||
to, e.g., zero-padding). When the `RaySampler`'s sampling mode is set to
|
to, e.g., zero-padding). When the `RaySampler`'s sampling mode is set to
|
||||||
"mask_sample", rays will be sampled in the non zero regions.
|
"mask_sample", rays will be sampled in the non zero regions.
|
||||||
|
@ -26,6 +26,8 @@ from tests.common_testing import interactive_testing_requested
|
|||||||
|
|
||||||
from .common_resources import get_skateboard_data
|
from .common_resources import get_skateboard_data
|
||||||
|
|
||||||
|
VISDOM_PORT = int(os.environ.get("VISDOM_PORT", 8097))
|
||||||
|
|
||||||
|
|
||||||
class TestDatasetVisualize(unittest.TestCase):
|
class TestDatasetVisualize(unittest.TestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
@ -77,7 +79,7 @@ class TestDatasetVisualize(unittest.TestCase):
|
|||||||
for k, dataset in self.datasets.items()
|
for k, dataset in self.datasets.items()
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
self.visdom = Visdom()
|
self.visdom = Visdom(port=VISDOM_PORT)
|
||||||
if not self.visdom.check_connection():
|
if not self.visdom.check_connection():
|
||||||
print("Visdom server not running! Disabling visdom visualizations.")
|
print("Visdom server not running! Disabling visdom visualizations.")
|
||||||
self.visdom = None
|
self.visdom = None
|
||||||
|
Loading…
x
Reference in New Issue
Block a user