mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-19 22:00:35 +08:00
Support variable size radius for points in rasterizer
Summary: Support variable size pointclouds in the renderer API to allow compatibility with Pulsar rasterizer. If radius is provided as a float, it is converted to a tensor of shape (P). Otherwise radius is expected to be an (N, P_padded) dimensional tensor where P_padded is the max number of points in the batch (following the convention from pulsar: https://our.intern.facebook.com/intern/diffusion/FBS/browse/master/fbcode/frl/gemini/pulsar/pulsar/renderer.py?commit=ee0342850210e5df441e14fd97162675c70d147c&lines=50) Reviewed By: jcjohnson, gkioxari Differential Revision: D21429400 fbshipit-source-id: 65de7d9cd2472b27fc29f96160c33687e88098a2
This commit is contained in:
committed by
Facebook GitHub Bot
parent
e40c2167ae
commit
ebe2693b11
@@ -1,6 +1,7 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
import itertools
|
||||
|
||||
from fvcore.common.benchmark import benchmark
|
||||
from test_cameras_alignment import TestCamerasAlignment
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
from itertools import product
|
||||
|
||||
import torch
|
||||
from fvcore.common.benchmark import benchmark
|
||||
@@ -18,44 +19,64 @@ def _bm_python_with_init(N, P, img_size=32, radius=0.1, pts_per_pxl=3):
|
||||
return lambda: rasterize_points_python(*args)
|
||||
|
||||
|
||||
def _bm_cpu_with_init(N, P, img_size=32, radius=0.1, pts_per_pxl=3):
|
||||
def _bm_rasterize_points_with_init(
|
||||
N, P, img_size=32, radius=0.1, pts_per_pxl=3, device="cpu", expand_radius=False
|
||||
):
|
||||
torch.manual_seed(231)
|
||||
points = torch.randn(N, P, 3)
|
||||
pointclouds = Pointclouds(points=points)
|
||||
args = (pointclouds, img_size, radius, pts_per_pxl)
|
||||
return lambda: rasterize_points(*args)
|
||||
|
||||
|
||||
def _bm_cuda_with_init(N, P, img_size=32, radius=0.1, pts_per_pxl=3):
|
||||
torch.manual_seed(231)
|
||||
device = torch.device("cuda:0")
|
||||
device = torch.device(device)
|
||||
points = torch.randn(N, P, 3, device=device)
|
||||
pointclouds = Pointclouds(points=points)
|
||||
|
||||
if expand_radius:
|
||||
points_padded = pointclouds.points_padded()
|
||||
radius = torch.full((N, P), fill_value=radius).type_as(points_padded)
|
||||
|
||||
args = (pointclouds, img_size, radius, pts_per_pxl)
|
||||
torch.cuda.synchronize(device)
|
||||
if device == "cuda":
|
||||
torch.cuda.synchronize(device)
|
||||
|
||||
def fn():
|
||||
rasterize_points(*args)
|
||||
torch.cuda.synchronize(device)
|
||||
if device == "cuda":
|
||||
torch.cuda.synchronize(device)
|
||||
|
||||
return fn
|
||||
|
||||
|
||||
def bm_python_vs_cpu() -> None:
|
||||
kwargs_list = [
|
||||
{"N": 1, "P": 32, "img_size": 32, "radius": 0.1, "pts_per_pxl": 3},
|
||||
{"N": 2, "P": 32, "img_size": 32, "radius": 0.1, "pts_per_pxl": 3},
|
||||
]
|
||||
benchmark(_bm_python_with_init, "RASTERIZE_PYTHON", kwargs_list, warmup_iters=1)
|
||||
benchmark(_bm_cpu_with_init, "RASTERIZE_CPU", kwargs_list, warmup_iters=1)
|
||||
kwargs_list = [
|
||||
{"N": 2, "P": 32, "img_size": 32, "radius": 0.1, "pts_per_pxl": 3},
|
||||
{"N": 4, "P": 1024, "img_size": 128, "radius": 0.05, "pts_per_pxl": 5},
|
||||
]
|
||||
benchmark(_bm_cpu_with_init, "RASTERIZE_CPU", kwargs_list, warmup_iters=1)
|
||||
def bm_python_vs_cpu_vs_cuda() -> None:
|
||||
kwargs_list = []
|
||||
num_meshes = [1]
|
||||
num_points = [10000, 2000]
|
||||
image_size = [128, 256]
|
||||
radius = [1e-3, 0.01]
|
||||
pts_per_pxl = [50, 100]
|
||||
expand = [True, False]
|
||||
test_cases = product(
|
||||
num_meshes, num_points, image_size, radius, pts_per_pxl, expand
|
||||
)
|
||||
for case in test_cases:
|
||||
n, p, im, r, pts, e = case
|
||||
kwargs_list.append(
|
||||
{
|
||||
"N": n,
|
||||
"P": p,
|
||||
"img_size": im,
|
||||
"radius": r,
|
||||
"pts_per_pxl": pts,
|
||||
"device": "cpu",
|
||||
"expand_radius": e,
|
||||
}
|
||||
)
|
||||
|
||||
benchmark(
|
||||
_bm_rasterize_points_with_init, "RASTERIZE_CPU", kwargs_list, warmup_iters=1
|
||||
)
|
||||
kwargs_list += [
|
||||
{"N": 32, "P": 10000, "img_size": 128, "radius": 0.01, "pts_per_pxl": 50},
|
||||
{"N": 32, "P": 100000, "img_size": 128, "radius": 0.01, "pts_per_pxl": 50},
|
||||
{"N": 8, "P": 200000, "img_size": 512, "radius": 0.01, "pts_per_pxl": 50},
|
||||
]
|
||||
benchmark(_bm_cuda_with_init, "RASTERIZE_CUDA", kwargs_list, warmup_iters=1)
|
||||
for k in kwargs_list:
|
||||
k["device"] = "cuda"
|
||||
benchmark(
|
||||
_bm_rasterize_points_with_init, "RASTERIZE_CUDA", kwargs_list, warmup_iters=1
|
||||
)
|
||||
|
||||
@@ -8,6 +8,7 @@ import torch
|
||||
from common_testing import TestCaseMixin, get_random_cuda_device
|
||||
from pytorch3d import _C
|
||||
from pytorch3d.renderer.points.rasterize_points import (
|
||||
_format_radius,
|
||||
rasterize_points,
|
||||
rasterize_points_python,
|
||||
)
|
||||
@@ -40,6 +41,21 @@ class TestRasterizePoints(TestCaseMixin, unittest.TestCase):
|
||||
device = get_random_cuda_device()
|
||||
self._test_behind_camera(rasterize_points, device, bin_size=0)
|
||||
|
||||
def test_python_variable_radius(self):
|
||||
self._test_variable_size_radius(
|
||||
rasterize_points_python, torch.device("cpu"), bin_size=-1
|
||||
)
|
||||
|
||||
def test_cpu_variable_radius(self):
|
||||
self._test_variable_size_radius(rasterize_points, torch.device("cpu"))
|
||||
|
||||
def test_cuda_variable_radius(self):
|
||||
device = get_random_cuda_device()
|
||||
# Naive
|
||||
self._test_variable_size_radius(rasterize_points, device, bin_size=0)
|
||||
# Coarse to fine
|
||||
self._test_variable_size_radius(rasterize_points, device, bin_size=None)
|
||||
|
||||
def test_cpp_vs_naive_vs_binned(self):
|
||||
# Make sure that the backward pass runs for all pathways
|
||||
N = 2
|
||||
@@ -403,6 +419,8 @@ class TestRasterizePoints(TestCaseMixin, unittest.TestCase):
|
||||
points_packed = pointclouds.points_packed()
|
||||
cloud_to_packed_first_idx = pointclouds.cloud_to_packed_first_idx()
|
||||
num_points_per_cloud = pointclouds.num_points_per_cloud()
|
||||
|
||||
radius = torch.full((points_packed.shape[0],), fill_value=radius)
|
||||
args = (
|
||||
points_packed,
|
||||
cloud_to_packed_first_idx,
|
||||
@@ -419,6 +437,7 @@ class TestRasterizePoints(TestCaseMixin, unittest.TestCase):
|
||||
points_packed = pointclouds_cuda.points_packed()
|
||||
cloud_to_packed_first_idx = pointclouds_cuda.cloud_to_packed_first_idx()
|
||||
num_points_per_cloud = pointclouds_cuda.num_points_per_cloud()
|
||||
radius = radius.to(device)
|
||||
args = (
|
||||
points_packed,
|
||||
cloud_to_packed_first_idx,
|
||||
@@ -499,6 +518,7 @@ class TestRasterizePoints(TestCaseMixin, unittest.TestCase):
|
||||
bin_points_expected[0, 1, 1, :2] = torch.tensor([0, 1])
|
||||
|
||||
pointclouds = Pointclouds(points=[points])
|
||||
radius = torch.full((points.shape[0],), fill_value=radius, device=device)
|
||||
args = (
|
||||
pointclouds.points_packed(),
|
||||
pointclouds.cloud_to_packed_first_idx(),
|
||||
@@ -512,3 +532,115 @@ class TestRasterizePoints(TestCaseMixin, unittest.TestCase):
|
||||
bin_points_same = (bin_points == bin_points_expected).all()
|
||||
|
||||
self.assertTrue(bin_points_same.item() == 1)
|
||||
|
||||
def _test_variable_size_radius(self, rasterize_points_fn, device, bin_size=0):
|
||||
# Two points
|
||||
points = torch.tensor(
|
||||
[[0.5, 0.5, 0.3], [0.5, -0.5, -0.1], [0.0, 0.0, 0.3]],
|
||||
dtype=torch.float32,
|
||||
device=device,
|
||||
)
|
||||
image_size = 16
|
||||
points_per_pixel = 1
|
||||
radius = torch.tensor([0.1, 0.0, 0.2], dtype=torch.float32, device=device)
|
||||
pointclouds = Pointclouds(points=[points])
|
||||
if bin_size == -1:
|
||||
# simple python case with no binning
|
||||
idx, zbuf, dists = rasterize_points_fn(
|
||||
pointclouds, image_size, radius, points_per_pixel
|
||||
)
|
||||
else:
|
||||
idx, zbuf, dists = rasterize_points_fn(
|
||||
pointclouds, image_size, radius, points_per_pixel, bin_size
|
||||
)
|
||||
|
||||
idx_expected = torch.zeros(
|
||||
(1, image_size, image_size, 1), dtype=torch.int64, device=device
|
||||
)
|
||||
# fmt: off
|
||||
idx_expected[0, ..., 0] = torch.tensor(
|
||||
[
|
||||
[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241 E201
|
||||
[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241 E201
|
||||
[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241 E201
|
||||
[-1, -1, -1, 0, 0, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241 E201
|
||||
[-1, -1, -1, 0, 0, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241 E201
|
||||
[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241 E201
|
||||
[-1, -1, -1, -1, -1, -1, -1, 2, 2, -1, -1, -1, -1, -1, -1, -1], # noqa: E241 E201
|
||||
[-1, -1, -1, -1, -1, -1, 2, 2, 2, 2, -1, -1, -1, -1, -1, -1], # noqa: E241 E201
|
||||
[-1, -1, -1, -1, -1, -1, 2, 2, 2, 2, -1, -1, -1, -1, -1, -1], # noqa: E241 E201
|
||||
[-1, -1, -1, -1, -1, -1, -1, 2, 2, -1, -1, -1, -1, -1, -1, -1], # noqa: E241 E201
|
||||
[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241 E201
|
||||
[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241 E201
|
||||
[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241 E201
|
||||
[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241 E201
|
||||
[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241 E201
|
||||
[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1] # noqa: E241 E201
|
||||
],
|
||||
dtype=torch.int64,
|
||||
device=device
|
||||
)
|
||||
# fmt: on
|
||||
zbuf_expected = torch.full(
|
||||
idx_expected.shape, fill_value=-1, dtype=torch.float32, device=device
|
||||
)
|
||||
zbuf_expected[idx_expected == 0] = 0.3
|
||||
zbuf_expected[idx_expected == 2] = 0.3
|
||||
|
||||
dists_expected = torch.full(
|
||||
idx_expected.shape, fill_value=-1, dtype=torch.float32, device=device
|
||||
)
|
||||
|
||||
# fmt: off
|
||||
dists_expected[0, ..., 0] = torch.Tensor(
|
||||
[
|
||||
[-1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.], # noqa: E241 E201
|
||||
[-1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.], # noqa: E241 E201
|
||||
[-1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.], # noqa: E241 E201
|
||||
[-1., -1., -1., 0.0078, 0.0078, -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.], # noqa: E241 E201
|
||||
[-1., -1., -1., 0.0078, 0.0078, -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.], # noqa: E241 E201
|
||||
[-1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.], # noqa: E241 E201
|
||||
[-1., -1., -1., -1., -1., -1., -1., 0.0391, 0.0391, -1., -1., -1., -1., -1., -1., -1.], # noqa: E241 E201
|
||||
[-1., -1., -1., -1., -1., -1., 0.0391, 0.0078, 0.0078, 0.0391, -1., -1., -1., -1., -1., -1.], # noqa: E241 E201
|
||||
[-1., -1., -1., -1., -1., -1., 0.0391, 0.0078, 0.0078, 0.0391, -1., -1., -1., -1., -1., -1.], # noqa: E241 E201
|
||||
[-1., -1., -1., -1., -1., -1., -1., 0.0391, 0.0391, -1., -1., -1., -1., -1., -1., -1.], # noqa: E241 E201
|
||||
[-1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.], # noqa: E241 E201
|
||||
[-1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.], # noqa: E241 E201
|
||||
[-1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.], # noqa: E241 E201
|
||||
[-1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.], # noqa: E241 E201
|
||||
[-1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.], # noqa: E241 E201
|
||||
[-1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.] # noqa: E241 E201
|
||||
]
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
# Check the distances for a point are less than the squared radius
|
||||
# for that point.
|
||||
self.assertTrue((dists[idx == 0] < radius[0] ** 2).all())
|
||||
self.assertTrue((dists[idx == 2] < radius[2] ** 2).all())
|
||||
|
||||
# Check all values are correct.
|
||||
idx_same = (idx == idx_expected).all().item() == 1
|
||||
zbuf_same = (zbuf == zbuf_expected).all().item() == 1
|
||||
|
||||
self.assertTrue(idx_same)
|
||||
self.assertTrue(zbuf_same)
|
||||
self.assertClose(dists, dists_expected, atol=4e-5)
|
||||
|
||||
def test_radius_format_failure(self):
|
||||
N = 20
|
||||
P_max = 15
|
||||
points_list = []
|
||||
for _ in range(N):
|
||||
p = torch.randint(low=1, high=P_max, size=(1,))[0]
|
||||
points_list.append(torch.randn((p, 3)))
|
||||
|
||||
points = Pointclouds(points=points_list)
|
||||
|
||||
# Incorrect shape
|
||||
with self.assertRaisesRegex(ValueError, "radius must be of shape"):
|
||||
_format_radius([0, 1, 2], points)
|
||||
|
||||
# Incorrect type
|
||||
with self.assertRaisesRegex(ValueError, "float, list, tuple or tensor"):
|
||||
_format_radius({0: [0, 1, 2]}, points)
|
||||
|
||||
Reference in New Issue
Block a user