mirror of
				https://github.com/facebookresearch/pytorch3d.git
				synced 2025-11-04 18:02:14 +08:00 
			
		
		
		
	Optional ground-truth depth maps in visualiser
Summary: The code does not crash if depth map/mask are not given. Reviewed By: bottler Differential Revision: D45082985 fbshipit-source-id: 3610d8beb4ac897fbbe52f56a6dd012a6365b89b
This commit is contained in:
		
							parent
							
								
									1af6bf4768
								
							
						
					
					
						commit
						0e3138eca8
					
				@ -38,8 +38,8 @@ class _Visualizer:
 | 
			
		||||
    image_render: torch.Tensor
 | 
			
		||||
    image_rgb_masked: torch.Tensor
 | 
			
		||||
    depth_render: torch.Tensor
 | 
			
		||||
    depth_map: torch.Tensor
 | 
			
		||||
    depth_mask: torch.Tensor
 | 
			
		||||
    depth_map: Optional[torch.Tensor]
 | 
			
		||||
    depth_mask: Optional[torch.Tensor]
 | 
			
		||||
 | 
			
		||||
    visdom_env: str = "eval_debug"
 | 
			
		||||
 | 
			
		||||
@ -75,9 +75,11 @@ class _Visualizer:
 | 
			
		||||
        viz = self._viz
 | 
			
		||||
        viz.images(
 | 
			
		||||
            torch.cat(
 | 
			
		||||
                (
 | 
			
		||||
                    make_depth_image(self.depth_render, loss_mask_now),
 | 
			
		||||
                    make_depth_image(self.depth_map, loss_mask_now),
 | 
			
		||||
                (make_depth_image(self.depth_render, loss_mask_now),)
 | 
			
		||||
                + (
 | 
			
		||||
                    (make_depth_image(self.depth_map, loss_mask_now),)
 | 
			
		||||
                    if self.depth_map is not None
 | 
			
		||||
                    else ()
 | 
			
		||||
                ),
 | 
			
		||||
                dim=3,
 | 
			
		||||
            ),
 | 
			
		||||
@ -91,12 +93,13 @@ class _Visualizer:
 | 
			
		||||
            win="depth_abs" + name_postfix + "_mask",
 | 
			
		||||
            opts={"title": f"depth_abs_{name_postfix}_{depth_loss:1.2f}_mask"},
 | 
			
		||||
        )
 | 
			
		||||
        viz.images(
 | 
			
		||||
            self.depth_mask,
 | 
			
		||||
            env=self.visdom_env,
 | 
			
		||||
            win="depth_abs" + name_postfix + "_maskd",
 | 
			
		||||
            opts={"title": f"depth_abs_{name_postfix}_{depth_loss:1.2f}_maskd"},
 | 
			
		||||
        )
 | 
			
		||||
        if self.depth_mask is not None:
 | 
			
		||||
            viz.images(
 | 
			
		||||
                self.depth_mask,
 | 
			
		||||
                env=self.visdom_env,
 | 
			
		||||
                win="depth_abs" + name_postfix + "_maskd",
 | 
			
		||||
                opts={"title": f"depth_abs_{name_postfix}_{depth_loss:1.2f}_maskd"},
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        # show the 3D plot
 | 
			
		||||
        # pyre-fixme[9]: viewpoint_trivial has type `PerspectiveCameras`; used as
 | 
			
		||||
@ -104,29 +107,30 @@ class _Visualizer:
 | 
			
		||||
        viewpoint_trivial: PerspectiveCameras = PerspectiveCameras().to(
 | 
			
		||||
            loss_mask_now.device
 | 
			
		||||
        )
 | 
			
		||||
        pcl_pred = get_rgbd_point_cloud(
 | 
			
		||||
            viewpoint_trivial,
 | 
			
		||||
            self.image_render,
 | 
			
		||||
            self.depth_render,
 | 
			
		||||
            # mask_crop,
 | 
			
		||||
            torch.ones_like(self.depth_render),
 | 
			
		||||
            # loss_mask_now,
 | 
			
		||||
        )
 | 
			
		||||
        pcl_gt = get_rgbd_point_cloud(
 | 
			
		||||
            viewpoint_trivial,
 | 
			
		||||
            self.image_rgb_masked,
 | 
			
		||||
            self.depth_map,
 | 
			
		||||
            # mask_crop,
 | 
			
		||||
            torch.ones_like(self.depth_map),
 | 
			
		||||
            # loss_mask_now,
 | 
			
		||||
        )
 | 
			
		||||
        _pcls = {
 | 
			
		||||
            pn: p
 | 
			
		||||
            for pn, p in zip(("pred_depth", "gt_depth"), (pcl_pred, pcl_gt))
 | 
			
		||||
            if int(p.num_points_per_cloud()) > 0
 | 
			
		||||
            "pred_depth": get_rgbd_point_cloud(
 | 
			
		||||
                viewpoint_trivial,
 | 
			
		||||
                self.image_render,
 | 
			
		||||
                self.depth_render,
 | 
			
		||||
                # mask_crop,
 | 
			
		||||
                torch.ones_like(self.depth_render),
 | 
			
		||||
                # loss_mask_now,
 | 
			
		||||
            )
 | 
			
		||||
        }
 | 
			
		||||
        if self.depth_map is not None:
 | 
			
		||||
            _pcls["gt_depth"] = get_rgbd_point_cloud(
 | 
			
		||||
                viewpoint_trivial,
 | 
			
		||||
                self.image_rgb_masked,
 | 
			
		||||
                self.depth_map,
 | 
			
		||||
                # mask_crop,
 | 
			
		||||
                torch.ones_like(self.depth_map),
 | 
			
		||||
                # loss_mask_now,
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        _pcls = {pn: p for pn, p in _pcls.items() if int(p.num_points_per_cloud()) > 0}
 | 
			
		||||
 | 
			
		||||
        plotlyplot = plot_scene(
 | 
			
		||||
            {f"pcl{name_postfix}": _pcls},
 | 
			
		||||
            {f"pcl{name_postfix}": _pcls},  # pyre-ignore
 | 
			
		||||
            camera_scale=1.0,
 | 
			
		||||
            pointcloud_max_points=10000,
 | 
			
		||||
            pointcloud_marker_size=1,
 | 
			
		||||
@ -277,10 +281,10 @@ def eval_batch(
 | 
			
		||||
            image_render=image_render,
 | 
			
		||||
            image_rgb_masked=image_rgb_masked,
 | 
			
		||||
            depth_render=cloned_render["depth_render"],
 | 
			
		||||
            # pyre-fixme[6]: Expected `Tensor` for 4th param but got
 | 
			
		||||
            #  `Optional[torch.Tensor]`.
 | 
			
		||||
            depth_map=frame_data.depth_map,
 | 
			
		||||
            depth_mask=frame_data.depth_mask[:1],
 | 
			
		||||
            depth_mask=frame_data.depth_mask[:1]
 | 
			
		||||
            if frame_data.depth_mask is not None
 | 
			
		||||
            else None,
 | 
			
		||||
            visdom_env=visualize_visdom_env,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -360,7 +360,7 @@ class GenericModel(ImplicitronModelBase):  # pyre-ignore: 13
 | 
			
		||||
                and source images, which will be used for intersecting with target rays.
 | 
			
		||||
            fg_probability: A tensor of shape `(B, 1, H, W)` containing a batch of
 | 
			
		||||
                foreground masks.
 | 
			
		||||
            mask_crop: A binary tensor of shape `(B, 1, H, W)` deonting valid
 | 
			
		||||
            mask_crop: A binary tensor of shape `(B, 1, H, W)` denoting valid
 | 
			
		||||
                regions in the input images (i.e. regions that do not correspond
 | 
			
		||||
                to, e.g., zero-padding). When the `RaySampler`'s sampling mode is set to
 | 
			
		||||
                "mask_sample", rays  will be sampled in the non zero regions.
 | 
			
		||||
 | 
			
		||||
@ -26,6 +26,8 @@ from tests.common_testing import interactive_testing_requested
 | 
			
		||||
 | 
			
		||||
from .common_resources import get_skateboard_data
 | 
			
		||||
 | 
			
		||||
VISDOM_PORT = int(os.environ.get("VISDOM_PORT", 8097))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TestDatasetVisualize(unittest.TestCase):
 | 
			
		||||
    def setUp(self):
 | 
			
		||||
@ -77,7 +79,7 @@ class TestDatasetVisualize(unittest.TestCase):
 | 
			
		||||
                for k, dataset in self.datasets.items()
 | 
			
		||||
            }
 | 
			
		||||
        )
 | 
			
		||||
        self.visdom = Visdom()
 | 
			
		||||
        self.visdom = Visdom(port=VISDOM_PORT)
 | 
			
		||||
        if not self.visdom.check_connection():
 | 
			
		||||
            print("Visdom server not running! Disabling visdom visualizations.")
 | 
			
		||||
            self.visdom = None
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user