doc rgbd point cloud

Summary: docstring and shape fix

Reviewed By: shapovalov

Differential Revision: D42609661

fbshipit-source-id: fd50234872ad61b5452821eeb89d51344f70c957
This commit is contained in:
Jeremy Reizenstein 2023-01-24 15:26:52 -08:00 committed by Facebook GitHub Bot
parent d561f1913e
commit a12612a48f
2 changed files with 88 additions and 3 deletions

View File

@ -27,13 +27,33 @@ def get_rgbd_point_cloud(
mask: Optional[torch.Tensor] = None,
mask_thr: float = 0.5,
mask_points: bool = True,
euclidean: bool = False,
) -> Pointclouds:
"""
Given a batch of images, depths, masks and cameras, generate a colored
point cloud by unprojecting depth maps to the and coloring with the source
Given a batch of images, depths, masks and cameras, generate a single colored
point cloud by unprojecting depth maps and coloring with the source
pixel colors.
Arguments:
camera: Batch of N cameras
image_rgb: Batch of N images of shape (N, C, H, W).
For RGB images C=3.
depth_map: Batch of N depth maps of shape (N, 1, H', W').
Only positive values here are used to generate points.
If euclidean=False (default) this contains perpendicular distances
from each point to the camera plane (z-values).
If euclidean=True, this contains distances from each point to
the camera center.
mask: If provided, batch of N masks of the same shape as depth_map.
If provided, values in depth_map are ignored if the corresponding
element of mask is smaller than mask_thr.
mask_thr: used in interpreting mask
euclidean: used in interpreting depth_map.
Returns:
Pointclouds object containing one point cloud.
"""
imh, imw = image_rgb.shape[2:]
imh, imw = depth_map.shape[2:]
# convert the depth maps to point clouds using the grid ray sampler
pts_3d = ray_bundle_to_ray_points(
@ -43,6 +63,7 @@ def get_rgbd_point_cloud(
n_pts_per_ray=1,
min_depth=1.0,
max_depth=1.0,
unit_directions=euclidean,
)(camera)._replace(lengths=depth_map[:, 0, ..., None])
)

View File

@ -0,0 +1,64 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import unittest
import torch
from pytorch3d.implicitron.tools.point_cloud_utils import get_rgbd_point_cloud
from pytorch3d.renderer.cameras import PerspectiveCameras
from tests.common_testing import TestCaseMixin
class TestPointCloudUtils(TestCaseMixin, unittest.TestCase):
    """Tests for unprojecting an RGBD image into a point cloud."""

    def setUp(self):
        """Seed the torch RNG so the random test image is reproducible."""
        torch.manual_seed(42)

    @staticmethod
    def _unproject_single(camera, rgbd, euclidean):
        """
        Unproject one RGBD image and return the points of the resulting cloud.

        Args:
            camera: camera handed to get_rgbd_point_cloud.
            rgbd: unbatched tensor; channels [:3] are passed as image_rgb and
                channels [3:] as depth_map, each with a leading batch
                dimension of size one added.
            euclidean: forwarded unchanged to get_rgbd_point_cloud.

        Returns:
            The single entry of the returned cloud's points_list().
        """
        cloud = get_rgbd_point_cloud(
            camera,
            image_rgb=rgbd[:3][None],
            depth_map=rgbd[3:][None],
            euclidean=euclidean,
        )
        # Unpacking raises unless exactly one point cloud was produced.
        (points,) = cloud.points_list()
        return points

    def test_unproject(self):
        """
        Unproject a constant-depth RGBD image through two cameras and check
        the geometry of the resulting points for both depth interpretations.
        """
        height, width = 50, 100
        depth_value = 3

        # Random 4-channel image: channels 0-2 are colors, channel 3 is
        # overwritten with the constant depth (depth 0 = at the camera).
        # Channel 1 is attenuated for rows >= H // 2 and columns >= W // 2,
        # tinting that quadrant purple (the original author calls it the
        # upper right corner -- presumably in the camera's axis convention).
        rgbd = torch.rand(4, height, width)
        rgbd[3] = depth_value
        rgbd[1, height // 2 :, width // 2 :] *= 0.4

        # Two ways to define the same camera: at the origin, facing the
        # positive z axis. First in NDC units, then in screen (pixel) units.
        cameras = (
            PerspectiveCameras(focal_length=1.0),
            PerspectiveCameras(
                focal_length=height // 2,
                in_ndc=False,
                image_size=((height, width),),
                principal_point=((width / 2, height / 2),),
            ),
        )

        # Expected largest |x| and |y| over all points at the given z-depth.
        # NOTE(review): the 1 / H offsets presumably reflect pixel centers
        # sitting half a pixel inside the image border -- confirm against the
        # ray sampler's grid convention.
        xy_extremes = depth_value * torch.tensor(
            [width / height - 1 / height, 1 - 1 / height]
        )

        for camera in cameras:
            # 1. z-depth: every point has the same z, and x / y are
            # symmetric about zero with the expected extent.
            planar_points = self._unproject_single(camera, rgbd, euclidean=False)
            self.assertConstant(planar_points[:, 2], depth_value)
            xy = planar_points[:, :2]
            self.assertClose(xy.min(0).values, -xy_extremes)
            self.assertClose(xy.max(0).values, xy_extremes)

            # 2. euclidean: every point is the same distance from the origin.
            radial_points = self._unproject_single(camera, rgbd, euclidean=True)
            self.assertConstant(
                torch.norm(radial_points, dim=1), depth_value, atol=1e-5
            )