doc rgbd point cloud

Summary: docstring and shape fix

Reviewed By: shapovalov

Differential Revision: D42609661

fbshipit-source-id: fd50234872ad61b5452821eeb89d51344f70c957
Jeremy Reizenstein, 2023-01-24 15:26:52 -08:00 (committed by Facebook GitHub Bot)
commit a12612a48f (parent d561f1913e)
2 changed files with 88 additions and 3 deletions

File: pytorch3d/implicitron/tools/point_cloud_utils.py

@@ -27,13 +27,33 @@ def get_rgbd_point_cloud(
     mask: Optional[torch.Tensor] = None,
     mask_thr: float = 0.5,
     mask_points: bool = True,
+    euclidean: bool = False,
 ) -> Pointclouds:
     """
-    Given a batch of images, depths, masks and cameras, generate a colored
-    point cloud by unprojecting depth maps to the and coloring with the source
+    Given a batch of images, depths, masks and cameras, generate a single colored
+    point cloud by unprojecting depth maps and coloring with the source
     pixel colors.
+
+    Arguments:
+        camera: Batch of N cameras
+        image_rgb: Batch of N images of shape (N, C, H, W).
+            For RGB images C=3.
+        depth_map: Batch of N depth maps of shape (N, 1, H', W').
+            Only positive values here are used to generate points.
+            If euclidean=False (default) this contains perpendicular distances
+            from each point to the camera plane (z-values).
+            If euclidean=True, this contains distances from each point to
+            the camera center.
+        mask: If provided, batch of N masks of the same shape as depth_map.
+            If provided, values in depth_map are ignored if the corresponding
+            element of mask is smaller than mask_thr.
+        mask_thr: used in interpreting mask
+        euclidean: used in interpreting depth_map.
+
+    Returns:
+        Pointclouds object containing one point cloud.
     """
-    imh, imw = image_rgb.shape[2:]
+    imh, imw = depth_map.shape[2:]

     # convert the depth maps to point clouds using the grid ray sampler
     pts_3d = ray_bundle_to_ray_points(
@@ -43,6 +63,7 @@ def get_rgbd_point_cloud(
             n_pts_per_ray=1,
             min_depth=1.0,
             max_depth=1.0,
+            unit_directions=euclidean,
         )(camera)._replace(lengths=depth_map[:, 0, ..., None])
     )
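
For orientation, here is a minimal usage sketch of the changed function. This is editorial, not part of the commit; shapes follow the new docstring, and the constant-depth setup mirrors the new test below.

import torch
from pytorch3d.implicitron.tools.point_cloud_utils import get_rgbd_point_cloud
from pytorch3d.renderer.cameras import PerspectiveCameras

# One camera at the origin looking down the +z axis (NDC convention).
camera = PerspectiveCameras(focal_length=1.0)

N, C, H, W = 1, 3, 50, 100
image_rgb = torch.rand(N, C, H, W)         # source pixel colors, (N, C, H, W)
depth_map = torch.full((N, 1, H, W), 3.0)  # constant depth of 3, (N, 1, H', W')

# Default z-depth interpretation: depth_map holds perpendicular distances
# to the camera plane, so every point lands on the plane z = 3.
cloud_z = get_rgbd_point_cloud(camera, image_rgb=image_rgb, depth_map=depth_map)

# Euclidean interpretation: depth_map holds distances to the camera
# center, so every point lands on the sphere of radius 3.
cloud_e = get_rgbd_point_cloud(
    camera, image_rgb=image_rgb, depth_map=depth_map, euclidean=True
)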

New file: unit test for get_rgbd_point_cloud (added under tests/)

@@ -0,0 +1,64 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch

from pytorch3d.implicitron.tools.point_cloud_utils import get_rgbd_point_cloud
from pytorch3d.renderer.cameras import PerspectiveCameras
from tests.common_testing import TestCaseMixin


class TestPointCloudUtils(TestCaseMixin, unittest.TestCase):
    def setUp(self):
        torch.manual_seed(42)

    def test_unproject(self):
        H, W = 50, 100

        # Random RGBD image with depth 3
        # (depth 0 = at the camera)
        # and purple in the upper right corner
        image = torch.rand(4, H, W)
        depth = 3
        image[3] = depth
        image[1, H // 2 :, W // 2 :] *= 0.4

        # two ways to define the same camera:
        # at the origin facing the positive z axis
        ndc_camera = PerspectiveCameras(focal_length=1.0)
        screen_camera = PerspectiveCameras(
            focal_length=H // 2,
            in_ndc=False,
            image_size=((H, W),),
            principal_point=((W / 2, H / 2),),
        )

        for camera in (ndc_camera, screen_camera):
            # 1. z-depth
            cloud = get_rgbd_point_cloud(
                camera,
                image_rgb=image[:3][None],
                depth_map=image[3:][None],
                euclidean=False,
            )
            [points] = cloud.points_list()
            self.assertConstant(points[:, 2], depth)  # constant depth
            extremes = depth * torch.tensor([W / H - 1 / H, 1 - 1 / H])
            self.assertClose(points[:, :2].min(0).values, -extremes)
            self.assertClose(points[:, :2].max(0).values, extremes)

            # 2. euclidean
            cloud = get_rgbd_point_cloud(
                camera,
                image_rgb=image[:3][None],
                depth_map=image[3:][None],
                euclidean=True,
            )
            [points] = cloud.points_list()
            self.assertConstant(torch.norm(points, dim=1), depth, atol=1e-5)
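
A note on the test's camera setup (editorial, not part of the commit): in PyTorch3D's non-square NDC convention the shorter image side spans [-1, 1], so screen intrinsics convert to NDC by a factor of 2 / min(H, W). The screen camera's focal length of H // 2 = 25 therefore becomes 25 * 2 / 50 = 1.0 in NDC, and its principal point (W / 2, H / 2) maps to the NDC origin, matching the NDC camera exactly. The extremes tensor follows the same convention: pixel centers sit half a pixel (1 / H in NDC) inside the image border, and NDC coordinates scale by the depth of 3 at unit focal length. A quick equivalence check, as a sketch:

import torch
from pytorch3d.renderer.cameras import PerspectiveCameras

H, W = 50, 100
ndc_camera = PerspectiveCameras(focal_length=1.0)
screen_camera = PerspectiveCameras(
    focal_length=H // 2,
    in_ndc=False,
    image_size=((H, W),),
    principal_point=((W / 2, H / 2),),
)

# Both cameras should send the same world point to the same NDC location.
point = torch.tensor([[[0.5, -0.2, 3.0]]])  # arbitrary point in front of the camera
print(ndc_camera.transform_points_ndc(point))
print(screen_camera.transform_points_ndc(point))  # expected: the same values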