mirror of
				https://github.com/facebookresearch/pytorch3d.git
				synced 2025-11-04 18:02:14 +08:00 
			
		
		
		
	doc rgbd point cloud
Summary: docstring and shape fix Reviewed By: shapovalov Differential Revision: D42609661 fbshipit-source-id: fd50234872ad61b5452821eeb89d51344f70c957
This commit is contained in:
		
							parent
							
								
									d561f1913e
								
							
						
					
					
						commit
						a12612a48f
					
				@ -27,13 +27,33 @@ def get_rgbd_point_cloud(
 | 
			
		||||
    mask: Optional[torch.Tensor] = None,
 | 
			
		||||
    mask_thr: float = 0.5,
 | 
			
		||||
    mask_points: bool = True,
 | 
			
		||||
    euclidean: bool = False,
 | 
			
		||||
) -> Pointclouds:
 | 
			
		||||
    """
 | 
			
		||||
    Given a batch of images, depths, masks and cameras, generate a colored
 | 
			
		||||
    point cloud by unprojecting depth maps to the  and coloring with the source
 | 
			
		||||
    Given a batch of images, depths, masks and cameras, generate a single colored
 | 
			
		||||
    point cloud by unprojecting depth maps and coloring with the source
 | 
			
		||||
    pixel colors.
 | 
			
		||||
 | 
			
		||||
    Arguments:
 | 
			
		||||
        camera: Batch of N cameras
 | 
			
		||||
        image_rgb: Batch of N images of shape (N, C, H, W).
 | 
			
		||||
            For RGB images C=3.
 | 
			
		||||
        depth_map: Batch of N depth maps of shape (N, 1, H', W').
 | 
			
		||||
            Only positive values here are used to generate points.
 | 
			
		||||
            If euclidean=False (default) this contains perpendicular distances
 | 
			
		||||
            from each point to the camera plane (z-values).
 | 
			
		||||
            If euclidean=True, this contains distances from each point to
 | 
			
		||||
            the camera center.
 | 
			
		||||
        mask: If provided, batch of N masks of the same shape as depth_map.
 | 
			
		||||
            If provided, values in depth_map are ignored if the corresponding
 | 
			
		||||
            element of mask is smaller than mask_thr.
 | 
			
		||||
        mask_thr: used in interpreting mask
 | 
			
		||||
        euclidean: used in interpreting depth_map.
 | 
			
		||||
 | 
			
		||||
    Returns:
 | 
			
		||||
        Pointclouds object containing one point cloud.
 | 
			
		||||
    """
 | 
			
		||||
    imh, imw = image_rgb.shape[2:]
 | 
			
		||||
    imh, imw = depth_map.shape[2:]
 | 
			
		||||
 | 
			
		||||
    # convert the depth maps to point clouds using the grid ray sampler
 | 
			
		||||
    pts_3d = ray_bundle_to_ray_points(
 | 
			
		||||
@ -43,6 +63,7 @@ def get_rgbd_point_cloud(
 | 
			
		||||
            n_pts_per_ray=1,
 | 
			
		||||
            min_depth=1.0,
 | 
			
		||||
            max_depth=1.0,
 | 
			
		||||
            unit_directions=euclidean,
 | 
			
		||||
        )(camera)._replace(lengths=depth_map[:, 0, ..., None])
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										64
									
								
								tests/implicitron/test_pointcloud_utils.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										64
									
								
								tests/implicitron/test_pointcloud_utils.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,64 @@
 | 
			
		||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
 | 
			
		||||
# All rights reserved.
 | 
			
		||||
#
 | 
			
		||||
# This source code is licensed under the BSD-style license found in the
 | 
			
		||||
# LICENSE file in the root directory of this source tree.
 | 
			
		||||
 | 
			
		||||
import unittest
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
from pytorch3d.implicitron.tools.point_cloud_utils import get_rgbd_point_cloud
 | 
			
		||||
 | 
			
		||||
from pytorch3d.renderer.cameras import PerspectiveCameras
 | 
			
		||||
from tests.common_testing import TestCaseMixin
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TestPointCloudUtils(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
    def setUp(self):
 | 
			
		||||
        torch.manual_seed(42)
 | 
			
		||||
 | 
			
		||||
    def test_unproject(self):
 | 
			
		||||
        H, W = 50, 100
 | 
			
		||||
 | 
			
		||||
        # Random RGBD image with depth 3
 | 
			
		||||
        # (depth 0 = at the camera)
 | 
			
		||||
        # and purple in the upper right corner
 | 
			
		||||
 | 
			
		||||
        image = torch.rand(4, H, W)
 | 
			
		||||
        depth = 3
 | 
			
		||||
        image[3] = depth
 | 
			
		||||
        image[1, H // 2 :, W // 2 :] *= 0.4
 | 
			
		||||
 | 
			
		||||
        # two ways to define the same camera:
 | 
			
		||||
        # at the origin facing the positive z axis
 | 
			
		||||
        ndc_camera = PerspectiveCameras(focal_length=1.0)
 | 
			
		||||
        screen_camera = PerspectiveCameras(
 | 
			
		||||
            focal_length=H // 2,
 | 
			
		||||
            in_ndc=False,
 | 
			
		||||
            image_size=((H, W),),
 | 
			
		||||
            principal_point=((W / 2, H / 2),),
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        for camera in (ndc_camera, screen_camera):
 | 
			
		||||
            # 1. z-depth
 | 
			
		||||
            cloud = get_rgbd_point_cloud(
 | 
			
		||||
                camera,
 | 
			
		||||
                image_rgb=image[:3][None],
 | 
			
		||||
                depth_map=image[3:][None],
 | 
			
		||||
                euclidean=False,
 | 
			
		||||
            )
 | 
			
		||||
            [points] = cloud.points_list()
 | 
			
		||||
            self.assertConstant(points[:, 2], depth)  # constant depth
 | 
			
		||||
            extremes = depth * torch.tensor([W / H - 1 / H, 1 - 1 / H])
 | 
			
		||||
            self.assertClose(points[:, :2].min(0).values, -extremes)
 | 
			
		||||
            self.assertClose(points[:, :2].max(0).values, extremes)
 | 
			
		||||
 | 
			
		||||
            # 2. euclidean
 | 
			
		||||
            cloud = get_rgbd_point_cloud(
 | 
			
		||||
                camera,
 | 
			
		||||
                image_rgb=image[:3][None],
 | 
			
		||||
                depth_map=image[3:][None],
 | 
			
		||||
                euclidean=True,
 | 
			
		||||
            )
 | 
			
		||||
            [points] = cloud.points_list()
 | 
			
		||||
            self.assertConstant(torch.norm(points, dim=1), depth, atol=1e-5)
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user