mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
doc rgbd point cloud
Summary: docstring and shape fix Reviewed By: shapovalov Differential Revision: D42609661 fbshipit-source-id: fd50234872ad61b5452821eeb89d51344f70c957
This commit is contained in:
parent
d561f1913e
commit
a12612a48f
@ -27,13 +27,33 @@ def get_rgbd_point_cloud(
|
||||
mask: Optional[torch.Tensor] = None,
|
||||
mask_thr: float = 0.5,
|
||||
mask_points: bool = True,
|
||||
euclidean: bool = False,
|
||||
) -> Pointclouds:
|
||||
"""
|
||||
Given a batch of images, depths, masks and cameras, generate a colored
|
||||
point cloud by unprojecting depth maps to the and coloring with the source
|
||||
Given a batch of images, depths, masks and cameras, generate a single colored
|
||||
point cloud by unprojecting depth maps and coloring with the source
|
||||
pixel colors.
|
||||
|
||||
Arguments:
|
||||
camera: Batch of N cameras
|
||||
image_rgb: Batch of N images of shape (N, C, H, W).
|
||||
For RGB images C=3.
|
||||
depth_map: Batch of N depth maps of shape (N, 1, H', W').
|
||||
Only positive values here are used to generate points.
|
||||
If euclidean=False (default) this contains perpendicular distances
|
||||
from each point to the camera plane (z-values).
|
||||
If euclidean=True, this contains distances from each point to
|
||||
the camera center.
|
||||
mask: If provided, batch of N masks of the same shape as depth_map.
|
||||
If provided, values in depth_map are ignored if the corresponding
|
||||
element of mask is smaller than mask_thr.
|
||||
mask_thr: used in interpreting mask
|
||||
euclidean: used in interpreting depth_map.
|
||||
|
||||
Returns:
|
||||
Pointclouds object containing one point cloud.
|
||||
"""
|
||||
imh, imw = image_rgb.shape[2:]
|
||||
imh, imw = depth_map.shape[2:]
|
||||
|
||||
# convert the depth maps to point clouds using the grid ray sampler
|
||||
pts_3d = ray_bundle_to_ray_points(
|
||||
@ -43,6 +63,7 @@ def get_rgbd_point_cloud(
|
||||
n_pts_per_ray=1,
|
||||
min_depth=1.0,
|
||||
max_depth=1.0,
|
||||
unit_directions=euclidean,
|
||||
)(camera)._replace(lengths=depth_map[:, 0, ..., None])
|
||||
)
|
||||
|
||||
|
64
tests/implicitron/test_pointcloud_utils.py
Normal file
64
tests/implicitron/test_pointcloud_utils.py
Normal file
@ -0,0 +1,64 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
from pytorch3d.implicitron.tools.point_cloud_utils import get_rgbd_point_cloud
|
||||
|
||||
from pytorch3d.renderer.cameras import PerspectiveCameras
|
||||
from tests.common_testing import TestCaseMixin
|
||||
|
||||
|
||||
class TestPointCloudUtils(TestCaseMixin, unittest.TestCase):
|
||||
def setUp(self):
|
||||
torch.manual_seed(42)
|
||||
|
||||
def test_unproject(self):
|
||||
H, W = 50, 100
|
||||
|
||||
# Random RGBD image with depth 3
|
||||
# (depth 0 = at the camera)
|
||||
# and purple in the upper right corner
|
||||
|
||||
image = torch.rand(4, H, W)
|
||||
depth = 3
|
||||
image[3] = depth
|
||||
image[1, H // 2 :, W // 2 :] *= 0.4
|
||||
|
||||
# two ways to define the same camera:
|
||||
# at the origin facing the positive z axis
|
||||
ndc_camera = PerspectiveCameras(focal_length=1.0)
|
||||
screen_camera = PerspectiveCameras(
|
||||
focal_length=H // 2,
|
||||
in_ndc=False,
|
||||
image_size=((H, W),),
|
||||
principal_point=((W / 2, H / 2),),
|
||||
)
|
||||
|
||||
for camera in (ndc_camera, screen_camera):
|
||||
# 1. z-depth
|
||||
cloud = get_rgbd_point_cloud(
|
||||
camera,
|
||||
image_rgb=image[:3][None],
|
||||
depth_map=image[3:][None],
|
||||
euclidean=False,
|
||||
)
|
||||
[points] = cloud.points_list()
|
||||
self.assertConstant(points[:, 2], depth) # constant depth
|
||||
extremes = depth * torch.tensor([W / H - 1 / H, 1 - 1 / H])
|
||||
self.assertClose(points[:, :2].min(0).values, -extremes)
|
||||
self.assertClose(points[:, :2].max(0).values, extremes)
|
||||
|
||||
# 2. euclidean
|
||||
cloud = get_rgbd_point_cloud(
|
||||
camera,
|
||||
image_rgb=image[:3][None],
|
||||
depth_map=image[3:][None],
|
||||
euclidean=True,
|
||||
)
|
||||
[points] = cloud.points_list()
|
||||
self.assertConstant(torch.norm(points, dim=1), depth, atol=1e-5)
|
Loading…
x
Reference in New Issue
Block a user