mirror of https://github.com/facebookresearch/pytorch3d.git
commit a12612a48f (parent d561f1913e)

doc rgbd point cloud

Summary: docstring and shape fix

Reviewed By: shapovalov

Differential Revision: D42609661

fbshipit-source-id: fd50234872ad61b5452821eeb89d51344f70c957
pytorch3d/implicitron/tools/point_cloud_utils.py:

@@ -27,13 +27,33 @@ def get_rgbd_point_cloud(
     mask: Optional[torch.Tensor] = None,
     mask_thr: float = 0.5,
     mask_points: bool = True,
+    euclidean: bool = False,
 ) -> Pointclouds:
     """
-    Given a batch of images, depths, masks and cameras, generate a colored
-    point cloud by unprojecting depth maps to the and coloring with the source
+    Given a batch of images, depths, masks and cameras, generate a single colored
+    point cloud by unprojecting depth maps and coloring with the source
     pixel colors.
+
+    Arguments:
+        camera: Batch of N cameras
+        image_rgb: Batch of N images of shape (N, C, H, W).
+            For RGB images C=3.
+        depth_map: Batch of N depth maps of shape (N, 1, H', W').
+            Only positive values here are used to generate points.
+            If euclidean=False (default) this contains perpendicular distances
+            from each point to the camera plane (z-values).
+            If euclidean=True, this contains distances from each point to
+            the camera center.
+        mask: If provided, batch of N masks of the same shape as depth_map.
+            If provided, values in depth_map are ignored if the corresponding
+            element of mask is smaller than mask_thr.
+        mask_thr: used in interpreting mask
+        euclidean: used in interpreting depth_map.
+
+    Returns:
+        Pointclouds object containing one point cloud.
     """
-    imh, imw = image_rgb.shape[2:]
+    imh, imw = depth_map.shape[2:]
 
     # convert the depth maps to point clouds using the grid ray sampler
     pts_3d = ray_bundle_to_ray_points(
@@ -43,6 +63,7 @@ def get_rgbd_point_cloud(
             n_pts_per_ray=1,
             min_depth=1.0,
             max_depth=1.0,
+            unit_directions=euclidean,
         )(camera)._replace(lengths=depth_map[:, 0, ..., None])
     )
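The new euclidean flag only changes how the values in depth_map are interpreted: with the default euclidean=False they are z-depths (distances to the camera plane), while euclidean=True treats them as distances to the camera center, which the function realizes by sampling unit-length ray directions (unit_directions=euclidean) before substituting the depth map as the ray lengths. A minimal usage sketch, mirroring the kind of inputs used in the test below; the camera, tensor shapes and values here are illustrative, not part of the commit:

```python
import torch

from pytorch3d.implicitron.tools.point_cloud_utils import get_rgbd_point_cloud
from pytorch3d.renderer.cameras import PerspectiveCameras

# One camera at the origin looking down +z, one RGB image and one depth map.
camera = PerspectiveCameras(focal_length=1.0)   # NDC camera
image_rgb = torch.rand(1, 3, 50, 100)           # (N, C, H, W), C=3 for RGB
depth_map = torch.full((1, 1, 50, 100), 3.0)    # (N, 1, H, W), everything at depth 3

# Default: depth values are perpendicular distances to the camera plane (z-values).
cloud_z = get_rgbd_point_cloud(camera, image_rgb=image_rgb, depth_map=depth_map)

# euclidean=True: depth values are distances to the camera center,
# so the constant-depth map above unprojects onto a sphere of radius 3.
cloud_sphere = get_rgbd_point_cloud(
    camera,
    image_rgb=image_rgb,
    depth_map=depth_map,
    euclidean=True,
)
```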
tests/implicitron/test_pointcloud_utils.py (new file, 64 lines):

@@ -0,0 +1,64 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import unittest
+
+import torch
+from pytorch3d.implicitron.tools.point_cloud_utils import get_rgbd_point_cloud
+
+from pytorch3d.renderer.cameras import PerspectiveCameras
+from tests.common_testing import TestCaseMixin
+
+
+class TestPointCloudUtils(TestCaseMixin, unittest.TestCase):
+    def setUp(self):
+        torch.manual_seed(42)
+
+    def test_unproject(self):
+        H, W = 50, 100
+
+        # Random RGBD image with depth 3
+        # (depth 0 = at the camera)
+        # and purple in the upper right corner
+
+        image = torch.rand(4, H, W)
+        depth = 3
+        image[3] = depth
+        image[1, H // 2 :, W // 2 :] *= 0.4
+
+        # two ways to define the same camera:
+        # at the origin facing the positive z axis
+        ndc_camera = PerspectiveCameras(focal_length=1.0)
+        screen_camera = PerspectiveCameras(
+            focal_length=H // 2,
+            in_ndc=False,
+            image_size=((H, W),),
+            principal_point=((W / 2, H / 2),),
+        )
+
+        for camera in (ndc_camera, screen_camera):
+            # 1. z-depth
+            cloud = get_rgbd_point_cloud(
+                camera,
+                image_rgb=image[:3][None],
+                depth_map=image[3:][None],
+                euclidean=False,
+            )
+            [points] = cloud.points_list()
+            self.assertConstant(points[:, 2], depth)  # constant depth
+            extremes = depth * torch.tensor([W / H - 1 / H, 1 - 1 / H])
+            self.assertClose(points[:, :2].min(0).values, -extremes)
+            self.assertClose(points[:, :2].max(0).values, extremes)
+
+            # 2. euclidean
+            cloud = get_rgbd_point_cloud(
+                camera,
+                image_rgb=image[:3][None],
+                depth_map=image[3:][None],
+                euclidean=True,
+            )
+            [points] = cloud.points_list()
+            self.assertConstant(torch.norm(points, dim=1), depth, atol=1e-5)
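The expected extremes in the z-depth branch follow from PyTorch3D's NDC convention for non-square images: the shorter side (here H) spans [-1, 1], the longer side spans [-W/H, W/H], and the ray grid samples pixel centers, which sit half a pixel (1/H in NDC units, since a pixel is 2/H wide) inside the image border. With focal_length=1, unprojecting at a z-depth of depth simply scales those NDC coordinates by depth. A standalone sketch of the same arithmetic, not part of the commit:

```python
H, W = 50, 100
depth = 3.0

half_pixel = 1.0 / H            # pixel size in NDC is 2 / H, so half a pixel is 1 / H
x_ndc_max = W / H - half_pixel  # longer side spans [-W/H, W/H], sampled at pixel centers
y_ndc_max = 1.0 - half_pixel    # shorter side spans [-1, 1]

# With focal_length=1, a point at z-depth `depth` projects to NDC (X / depth, Y / depth),
# so unprojection multiplies the NDC coordinates by the depth.
print(depth * x_ndc_max, depth * y_ndc_max)  # 5.94, 2.94 -- the `extremes` asserted above
```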