Support variable size radius for points in rasterizer

Summary:
Support variable size pointclouds in the renderer API to allow compatibility with Pulsar rasterizer.

If radius is provided as a float, it is converted to a tensor of shape (P). Otherwise radius is expected to be an (N, P_padded) dimensional tensor where P_padded is the max number of points in the batch (following the convention from pulsar: https://our.intern.facebook.com/intern/diffusion/FBS/browse/master/fbcode/frl/gemini/pulsar/pulsar/renderer.py?commit=ee0342850210e5df441e14fd97162675c70d147c&lines=50)

Reviewed By: jcjohnson, gkioxari

Differential Revision: D21429400

fbshipit-source-id: 65de7d9cd2472b27fc29f96160c33687e88098a2
This commit is contained in:
Nikhila Ravi
2020-09-18 18:46:45 -07:00
committed by Facebook GitHub Bot
parent e40c2167ae
commit ebe2693b11
8 changed files with 291 additions and 73 deletions

View File

@@ -1,6 +1,7 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import itertools
from fvcore.common.benchmark import benchmark
from test_cameras_alignment import TestCamerasAlignment

View File

@@ -1,5 +1,6 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from itertools import product
import torch
from fvcore.common.benchmark import benchmark
@@ -18,44 +19,64 @@ def _bm_python_with_init(N, P, img_size=32, radius=0.1, pts_per_pxl=3):
return lambda: rasterize_points_python(*args)
def _bm_cpu_with_init(N, P, img_size=32, radius=0.1, pts_per_pxl=3):
def _bm_rasterize_points_with_init(
N, P, img_size=32, radius=0.1, pts_per_pxl=3, device="cpu", expand_radius=False
):
torch.manual_seed(231)
points = torch.randn(N, P, 3)
pointclouds = Pointclouds(points=points)
args = (pointclouds, img_size, radius, pts_per_pxl)
return lambda: rasterize_points(*args)
def _bm_cuda_with_init(N, P, img_size=32, radius=0.1, pts_per_pxl=3):
torch.manual_seed(231)
device = torch.device("cuda:0")
device = torch.device(device)
points = torch.randn(N, P, 3, device=device)
pointclouds = Pointclouds(points=points)
if expand_radius:
points_padded = pointclouds.points_padded()
radius = torch.full((N, P), fill_value=radius).type_as(points_padded)
args = (pointclouds, img_size, radius, pts_per_pxl)
torch.cuda.synchronize(device)
if device == "cuda":
torch.cuda.synchronize(device)
def fn():
rasterize_points(*args)
torch.cuda.synchronize(device)
if device == "cuda":
torch.cuda.synchronize(device)
return fn
def bm_python_vs_cpu() -> None:
kwargs_list = [
{"N": 1, "P": 32, "img_size": 32, "radius": 0.1, "pts_per_pxl": 3},
{"N": 2, "P": 32, "img_size": 32, "radius": 0.1, "pts_per_pxl": 3},
]
benchmark(_bm_python_with_init, "RASTERIZE_PYTHON", kwargs_list, warmup_iters=1)
benchmark(_bm_cpu_with_init, "RASTERIZE_CPU", kwargs_list, warmup_iters=1)
kwargs_list = [
{"N": 2, "P": 32, "img_size": 32, "radius": 0.1, "pts_per_pxl": 3},
{"N": 4, "P": 1024, "img_size": 128, "radius": 0.05, "pts_per_pxl": 5},
]
benchmark(_bm_cpu_with_init, "RASTERIZE_CPU", kwargs_list, warmup_iters=1)
def bm_python_vs_cpu_vs_cuda() -> None:
kwargs_list = []
num_meshes = [1]
num_points = [10000, 2000]
image_size = [128, 256]
radius = [1e-3, 0.01]
pts_per_pxl = [50, 100]
expand = [True, False]
test_cases = product(
num_meshes, num_points, image_size, radius, pts_per_pxl, expand
)
for case in test_cases:
n, p, im, r, pts, e = case
kwargs_list.append(
{
"N": n,
"P": p,
"img_size": im,
"radius": r,
"pts_per_pxl": pts,
"device": "cpu",
"expand_radius": e,
}
)
benchmark(
_bm_rasterize_points_with_init, "RASTERIZE_CPU", kwargs_list, warmup_iters=1
)
kwargs_list += [
{"N": 32, "P": 10000, "img_size": 128, "radius": 0.01, "pts_per_pxl": 50},
{"N": 32, "P": 100000, "img_size": 128, "radius": 0.01, "pts_per_pxl": 50},
{"N": 8, "P": 200000, "img_size": 512, "radius": 0.01, "pts_per_pxl": 50},
]
benchmark(_bm_cuda_with_init, "RASTERIZE_CUDA", kwargs_list, warmup_iters=1)
for k in kwargs_list:
k["device"] = "cuda"
benchmark(
_bm_rasterize_points_with_init, "RASTERIZE_CUDA", kwargs_list, warmup_iters=1
)

View File

@@ -8,6 +8,7 @@ import torch
from common_testing import TestCaseMixin, get_random_cuda_device
from pytorch3d import _C
from pytorch3d.renderer.points.rasterize_points import (
_format_radius,
rasterize_points,
rasterize_points_python,
)
@@ -40,6 +41,21 @@ class TestRasterizePoints(TestCaseMixin, unittest.TestCase):
device = get_random_cuda_device()
self._test_behind_camera(rasterize_points, device, bin_size=0)
def test_python_variable_radius(self):
self._test_variable_size_radius(
rasterize_points_python, torch.device("cpu"), bin_size=-1
)
def test_cpu_variable_radius(self):
self._test_variable_size_radius(rasterize_points, torch.device("cpu"))
def test_cuda_variable_radius(self):
device = get_random_cuda_device()
# Naive
self._test_variable_size_radius(rasterize_points, device, bin_size=0)
# Coarse to fine
self._test_variable_size_radius(rasterize_points, device, bin_size=None)
def test_cpp_vs_naive_vs_binned(self):
# Make sure that the backward pass runs for all pathways
N = 2
@@ -403,6 +419,8 @@ class TestRasterizePoints(TestCaseMixin, unittest.TestCase):
points_packed = pointclouds.points_packed()
cloud_to_packed_first_idx = pointclouds.cloud_to_packed_first_idx()
num_points_per_cloud = pointclouds.num_points_per_cloud()
radius = torch.full((points_packed.shape[0],), fill_value=radius)
args = (
points_packed,
cloud_to_packed_first_idx,
@@ -419,6 +437,7 @@ class TestRasterizePoints(TestCaseMixin, unittest.TestCase):
points_packed = pointclouds_cuda.points_packed()
cloud_to_packed_first_idx = pointclouds_cuda.cloud_to_packed_first_idx()
num_points_per_cloud = pointclouds_cuda.num_points_per_cloud()
radius = radius.to(device)
args = (
points_packed,
cloud_to_packed_first_idx,
@@ -499,6 +518,7 @@ class TestRasterizePoints(TestCaseMixin, unittest.TestCase):
bin_points_expected[0, 1, 1, :2] = torch.tensor([0, 1])
pointclouds = Pointclouds(points=[points])
radius = torch.full((points.shape[0],), fill_value=radius, device=device)
args = (
pointclouds.points_packed(),
pointclouds.cloud_to_packed_first_idx(),
@@ -512,3 +532,115 @@ class TestRasterizePoints(TestCaseMixin, unittest.TestCase):
bin_points_same = (bin_points == bin_points_expected).all()
self.assertTrue(bin_points_same.item() == 1)
def _test_variable_size_radius(self, rasterize_points_fn, device, bin_size=0):
# Two points
points = torch.tensor(
[[0.5, 0.5, 0.3], [0.5, -0.5, -0.1], [0.0, 0.0, 0.3]],
dtype=torch.float32,
device=device,
)
image_size = 16
points_per_pixel = 1
radius = torch.tensor([0.1, 0.0, 0.2], dtype=torch.float32, device=device)
pointclouds = Pointclouds(points=[points])
if bin_size == -1:
# simple python case with no binning
idx, zbuf, dists = rasterize_points_fn(
pointclouds, image_size, radius, points_per_pixel
)
else:
idx, zbuf, dists = rasterize_points_fn(
pointclouds, image_size, radius, points_per_pixel, bin_size
)
idx_expected = torch.zeros(
(1, image_size, image_size, 1), dtype=torch.int64, device=device
)
# fmt: off
idx_expected[0, ..., 0] = torch.tensor(
[
[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241 E201
[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241 E201
[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241 E201
[-1, -1, -1, 0, 0, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241 E201
[-1, -1, -1, 0, 0, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241 E201
[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241 E201
[-1, -1, -1, -1, -1, -1, -1, 2, 2, -1, -1, -1, -1, -1, -1, -1], # noqa: E241 E201
[-1, -1, -1, -1, -1, -1, 2, 2, 2, 2, -1, -1, -1, -1, -1, -1], # noqa: E241 E201
[-1, -1, -1, -1, -1, -1, 2, 2, 2, 2, -1, -1, -1, -1, -1, -1], # noqa: E241 E201
[-1, -1, -1, -1, -1, -1, -1, 2, 2, -1, -1, -1, -1, -1, -1, -1], # noqa: E241 E201
[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241 E201
[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241 E201
[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241 E201
[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241 E201
[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241 E201
[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1] # noqa: E241 E201
],
dtype=torch.int64,
device=device
)
# fmt: on
zbuf_expected = torch.full(
idx_expected.shape, fill_value=-1, dtype=torch.float32, device=device
)
zbuf_expected[idx_expected == 0] = 0.3
zbuf_expected[idx_expected == 2] = 0.3
dists_expected = torch.full(
idx_expected.shape, fill_value=-1, dtype=torch.float32, device=device
)
# fmt: off
dists_expected[0, ..., 0] = torch.Tensor(
[
[-1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.], # noqa: E241 E201
[-1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.], # noqa: E241 E201
[-1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.], # noqa: E241 E201
[-1., -1., -1., 0.0078, 0.0078, -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.], # noqa: E241 E201
[-1., -1., -1., 0.0078, 0.0078, -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.], # noqa: E241 E201
[-1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.], # noqa: E241 E201
[-1., -1., -1., -1., -1., -1., -1., 0.0391, 0.0391, -1., -1., -1., -1., -1., -1., -1.], # noqa: E241 E201
[-1., -1., -1., -1., -1., -1., 0.0391, 0.0078, 0.0078, 0.0391, -1., -1., -1., -1., -1., -1.], # noqa: E241 E201
[-1., -1., -1., -1., -1., -1., 0.0391, 0.0078, 0.0078, 0.0391, -1., -1., -1., -1., -1., -1.], # noqa: E241 E201
[-1., -1., -1., -1., -1., -1., -1., 0.0391, 0.0391, -1., -1., -1., -1., -1., -1., -1.], # noqa: E241 E201
[-1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.], # noqa: E241 E201
[-1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.], # noqa: E241 E201
[-1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.], # noqa: E241 E201
[-1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.], # noqa: E241 E201
[-1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.], # noqa: E241 E201
[-1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.] # noqa: E241 E201
]
)
# fmt: on
# Check the distances for a point are less than the squared radius
# for that point.
self.assertTrue((dists[idx == 0] < radius[0] ** 2).all())
self.assertTrue((dists[idx == 2] < radius[2] ** 2).all())
# Check all values are correct.
idx_same = (idx == idx_expected).all().item() == 1
zbuf_same = (zbuf == zbuf_expected).all().item() == 1
self.assertTrue(idx_same)
self.assertTrue(zbuf_same)
self.assertClose(dists, dists_expected, atol=4e-5)
def test_radius_format_failure(self):
N = 20
P_max = 15
points_list = []
for _ in range(N):
p = torch.randint(low=1, high=P_max, size=(1,))[0]
points_list.append(torch.randn((p, 3)))
points = Pointclouds(points=points_list)
# Incorrect shape
with self.assertRaisesRegex(ValueError, "radius must be of shape"):
_format_radius([0, 1, 2], points)
# Incorrect type
with self.assertRaisesRegex(ValueError, "float, list, tuple or tensor"):
_format_radius({0: [0, 1, 2]}, points)