Support variable size radius for points in rasterizer

Summary:
Support variable size pointclouds in the renderer API to allow compatibility with Pulsar rasterizer.

If radius is provided as a float, it is converted to a tensor of shape (P). Otherwise radius is expected to be an (N, P_padded) dimensional tensor where P_padded is the max number of points in the batch (following the convention from pulsar: https://our.intern.facebook.com/intern/diffusion/FBS/browse/master/fbcode/frl/gemini/pulsar/pulsar/renderer.py?commit=ee0342850210e5df441e14fd97162675c70d147c&lines=50)

Reviewed By: jcjohnson, gkioxari

Differential Revision: D21429400

fbshipit-source-id: 65de7d9cd2472b27fc29f96160c33687e88098a2
This commit is contained in:
Nikhila Ravi 2020-09-18 18:46:45 -07:00 committed by Facebook GitHub Bot
parent e40c2167ae
commit ebe2693b11
8 changed files with 291 additions and 73 deletions

View File

@ -38,13 +38,15 @@ __device__ void CheckPixelInsidePoint(
float& q_max_z,
int& q_max_idx,
PointQ& q,
const float radius2,
const float* radius,
const float xf,
const float yf,
const int K) {
const float px = points[p_idx * 3 + 0];
const float py = points[p_idx * 3 + 1];
const float pz = points[p_idx * 3 + 2];
const float p_radius = radius[p_idx];
const float radius2 = p_radius * p_radius;
if (pz < 0)
return; // Don't render points behind the camera
const float dx = xf - px;
@ -81,7 +83,7 @@ __global__ void RasterizePointsNaiveCudaKernel(
const float* points, // (P, 3)
const int64_t* cloud_to_packed_first_idx, // (N)
const int64_t* num_points_per_cloud, // (N)
const float radius,
const float* radius,
const int N,
const int S,
const int K,
@ -91,7 +93,6 @@ __global__ void RasterizePointsNaiveCudaKernel(
// Simple version: One thread per output pixel
const int num_threads = gridDim.x * blockDim.x;
const int tid = blockDim.x * blockIdx.x + threadIdx.x;
const float radius2 = radius * radius;
for (int i = tid; i < N * S * S; i += num_threads) {
// Convert linear index to 3D index
const int n = i / (S * S); // Batch index
@ -128,7 +129,7 @@ __global__ void RasterizePointsNaiveCudaKernel(
for (int p_idx = point_start_idx; p_idx < point_stop_idx; ++p_idx) {
CheckPixelInsidePoint(
points, p_idx, q_size, q_max_z, q_max_idx, q, radius2, xf, yf, K);
points, p_idx, q_size, q_max_z, q_max_idx, q, radius, xf, yf, K);
}
BubbleSort(q, q_size);
int idx = n * S * S * K + pix_idx * K;
@ -145,7 +146,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> RasterizePointsNaiveCuda(
const at::Tensor& cloud_to_packed_first_idx, // (N)
const at::Tensor& num_points_per_cloud, // (N)
const int image_size,
const float radius,
const at::Tensor& radius,
const int points_per_pixel) {
// Check inputs are on the same device
at::TensorArg points_t{points, "points", 1},
@ -194,7 +195,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> RasterizePointsNaiveCuda(
points.contiguous().data_ptr<float>(),
cloud_to_packed_first_idx.contiguous().data_ptr<int64_t>(),
num_points_per_cloud.contiguous().data_ptr<int64_t>(),
radius,
radius.contiguous().data_ptr<float>(),
N,
S,
K,
@ -214,7 +215,7 @@ __global__ void RasterizePointsCoarseCudaKernel(
const float* points, // (P, 3)
const int64_t* cloud_to_packed_first_idx, // (N)
const int64_t* num_points_per_cloud, // (N)
const float radius,
const float* radius,
const int N,
const int P,
const int S,
@ -266,12 +267,13 @@ __global__ void RasterizePointsCoarseCudaKernel(
const float px = points[p_idx * 3 + 0];
const float py = points[p_idx * 3 + 1];
const float pz = points[p_idx * 3 + 2];
const float p_radius = radius[p_idx];
if (pz < 0)
continue; // Don't render points behind the camera.
const float px0 = px - radius;
const float px1 = px + radius;
const float py0 = py - radius;
const float py1 = py + radius;
const float px0 = px - p_radius;
const float px1 = px + p_radius;
const float py0 = py - p_radius;
const float py1 = py + p_radius;
// Brute-force search over all bins; TODO something smarter?
// For example we could compute the exact bin where the point falls,
@ -341,7 +343,7 @@ at::Tensor RasterizePointsCoarseCuda(
const at::Tensor& cloud_to_packed_first_idx, // (N)
const at::Tensor& num_points_per_cloud, // (N)
const int image_size,
const float radius,
const at::Tensor& radius,
const int bin_size,
const int max_points_per_bin) {
TORCH_CHECK(
@ -390,7 +392,7 @@ at::Tensor RasterizePointsCoarseCuda(
points.contiguous().data_ptr<float>(),
cloud_to_packed_first_idx.contiguous().data_ptr<int64_t>(),
num_points_per_cloud.contiguous().data_ptr<int64_t>(),
radius,
radius.contiguous().data_ptr<float>(),
N,
P,
image_size,
@ -411,7 +413,7 @@ at::Tensor RasterizePointsCoarseCuda(
__global__ void RasterizePointsFineCudaKernel(
const float* points, // (P, 3)
const int32_t* bin_points, // (N, B, B, T)
const float radius,
const float* radius,
const int bin_size,
const int N,
const int B, // num_bins
@ -425,7 +427,6 @@ __global__ void RasterizePointsFineCudaKernel(
const int num_pixels = N * B * B * bin_size * bin_size;
const int num_threads = gridDim.x * blockDim.x;
const int tid = blockIdx.x * blockDim.x + threadIdx.x;
const float radius2 = radius * radius;
for (int pid = tid; pid < num_pixels; pid += num_threads) {
// Convert linear index into bin and pixel indices. We make the within
@ -464,7 +465,7 @@ __global__ void RasterizePointsFineCudaKernel(
continue;
}
CheckPixelInsidePoint(
points, p, q_size, q_max_z, q_max_idx, q, radius2, xf, yf, K);
points, p, q_size, q_max_z, q_max_idx, q, radius, xf, yf, K);
}
// Now we've looked at all the points for this bin, so we can write
// output for the current pixel.
@ -488,7 +489,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> RasterizePointsFineCuda(
const at::Tensor& points, // (P, 3)
const at::Tensor& bin_points,
const int image_size,
const float radius,
const at::Tensor& radius,
const int bin_size,
const int points_per_pixel) {
// Check inputs are on the same device
@ -525,7 +526,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> RasterizePointsFineCuda(
RasterizePointsFineCudaKernel<<<blocks, threads, 0, stream>>>(
points.contiguous().data_ptr<float>(),
bin_points.contiguous().data_ptr<int32_t>(),
radius,
radius.contiguous().data_ptr<float>(),
bin_size,
N,
B,

View File

@ -15,7 +15,7 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsNaiveCpu(
const torch::Tensor& cloud_to_packed_first_idx,
const torch::Tensor& num_points_per_cloud,
const int image_size,
const float radius,
const torch::Tensor& radius,
const int points_per_pixel);
#ifdef WITH_CUDA
@ -25,7 +25,7 @@ RasterizePointsNaiveCuda(
const torch::Tensor& cloud_to_packed_first_idx,
const torch::Tensor& num_points_per_cloud,
const int image_size,
const float radius,
const torch::Tensor& radius,
const int points_per_pixel);
#endif
// Naive (forward) pointcloud rasterization: For each pixel, for each point,
@ -41,7 +41,8 @@ RasterizePointsNaiveCuda(
// in the batch where N is the batch size.
// num_points_per_cloud: LongTensor of shape (N) giving the number of points
// for each pointcloud in the batch.
// radius: Radius of each point (in NDC units)
// radius: FloatTensor of shape (P) giving the radius (in NDC units) of
// each point in points.
// image_size: (S) Size of the image to return (in pixels)
// points_per_pixel: (K) The number closest of points to return for each pixel
//
@ -62,7 +63,7 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsNaive(
const torch::Tensor& cloud_to_packed_first_idx,
const torch::Tensor& num_points_per_cloud,
const int image_size,
const float radius,
const torch::Tensor& radius,
const int points_per_pixel) {
if (points.is_cuda() && cloud_to_packed_first_idx.is_cuda() &&
num_points_per_cloud.is_cuda()) {
@ -70,6 +71,7 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsNaive(
CHECK_CUDA(points);
CHECK_CUDA(cloud_to_packed_first_idx);
CHECK_CUDA(num_points_per_cloud);
CHECK_CUDA(radius);
return RasterizePointsNaiveCuda(
points,
cloud_to_packed_first_idx,
@ -100,7 +102,7 @@ torch::Tensor RasterizePointsCoarseCpu(
const torch::Tensor& cloud_to_packed_first_idx,
const torch::Tensor& num_points_per_cloud,
const int image_size,
const float radius,
const torch::Tensor& radius,
const int bin_size,
const int max_points_per_bin);
@ -110,7 +112,7 @@ torch::Tensor RasterizePointsCoarseCuda(
const torch::Tensor& cloud_to_packed_first_idx,
const torch::Tensor& num_points_per_cloud,
const int image_size,
const float radius,
const torch::Tensor& radius,
const int bin_size,
const int max_points_per_bin);
#endif
@ -124,7 +126,8 @@ torch::Tensor RasterizePointsCoarseCuda(
// in the batch where N is the batch size.
// num_points_per_cloud: LongTensor of shape (N) giving the number of points
// for each pointcloud in the batch.
// radius: Radius of points to rasterize (in NDC units)
// radius: FloatTensor of shape (P) giving the radius (in NDC units) of
// each point in points.
// image_size: Size of the image to generate (in pixels)
// bin_size: Size of each bin within the image (in pixels)
//
@ -138,7 +141,7 @@ torch::Tensor RasterizePointsCoarse(
const torch::Tensor& cloud_to_packed_first_idx,
const torch::Tensor& num_points_per_cloud,
const int image_size,
const float radius,
const torch::Tensor& radius,
const int bin_size,
const int max_points_per_bin) {
if (points.is_cuda() && cloud_to_packed_first_idx.is_cuda() &&
@ -147,6 +150,7 @@ torch::Tensor RasterizePointsCoarse(
CHECK_CUDA(points);
CHECK_CUDA(cloud_to_packed_first_idx);
CHECK_CUDA(num_points_per_cloud);
CHECK_CUDA(radius);
return RasterizePointsCoarseCuda(
points,
cloud_to_packed_first_idx,
@ -179,7 +183,7 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsFineCuda(
const torch::Tensor& points,
const torch::Tensor& bin_points,
const int image_size,
const float radius,
const torch::Tensor& radius,
const int bin_size,
const int points_per_pixel);
#endif
@ -191,7 +195,8 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsFineCuda(
// bin_points: int32 Tensor of shape (N, B, B, M) giving the indices of points
// that fall into each bin (output from coarse rasterization)
// image_size: Size of image to generate (in pixels)
// radius: Radius of points to rasterize (NDC units)
// radius: FloatTensor of shape (P) giving the radius (in NDC units) of
// each point in points.
// bin_size: Size of each bin (in pixels)
// points_per_pixel: How many points to rasterize for each pixel
//
@ -210,7 +215,7 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsFine(
const torch::Tensor& points,
const torch::Tensor& bin_points,
const int image_size,
const float radius,
const torch::Tensor& radius,
const int bin_size,
const int points_per_pixel) {
if (points.is_cuda()) {
@ -296,7 +301,8 @@ torch::Tensor RasterizePointsBackward(
// in the batch where N is the batch size.
// num_points_per_cloud: LongTensor of shape (N) giving the number of points
// for each pointcloud in the batch.
// radius: Radius of each point (in NDC units)
// radius: FloatTensor of shape (P) giving the radius (in NDC units) of
// each point in points.
// image_size: (S) Size of the image to return (in pixels)
// points_per_pixel: (K) The number of points to return for each pixel
// bin_size: Bin size (in pixels) for coarse-to-fine rasterization. Setting
@ -320,7 +326,7 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePoints(
const torch::Tensor& cloud_to_packed_first_idx,
const torch::Tensor& num_points_per_cloud,
const int image_size,
const float radius,
const torch::Tensor& radius,
const int points_per_pixel,
const int bin_size,
const int max_points_per_bin) {

View File

@ -17,7 +17,7 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsNaiveCpu(
const torch::Tensor& cloud_to_packed_first_idx, // (N)
const torch::Tensor& num_points_per_cloud, // (N)
const int image_size,
const float radius,
const torch::Tensor& radius,
const int points_per_pixel) {
const int32_t N = cloud_to_packed_first_idx.size(0); // batch_size.
@ -35,8 +35,8 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsNaiveCpu(
auto point_idxs_a = point_idxs.accessor<int32_t, 4>();
auto zbuf_a = zbuf.accessor<float, 4>();
auto pix_dists_a = pix_dists.accessor<float, 4>();
auto radius_a = radius.accessor<float, 1>();
const float radius2 = radius * radius;
for (int n = 0; n < N; ++n) {
// Loop through each pointcloud in the batch.
// Get the start index of the points in points_packed and the num points
@ -63,6 +63,8 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsNaiveCpu(
const float px = points_a[p][0];
const float py = points_a[p][1];
const float pz = points_a[p][2];
const float p_radius = radius_a[p];
const float radius2 = p_radius * p_radius;
if (pz < 0) {
continue;
}
@ -98,7 +100,7 @@ torch::Tensor RasterizePointsCoarseCpu(
const torch::Tensor& cloud_to_packed_first_idx, // (N)
const torch::Tensor& num_points_per_cloud, // (N)
const int image_size,
const float radius,
const torch::Tensor& radius,
const int bin_size,
const int max_points_per_bin) {
const int32_t N = cloud_to_packed_first_idx.size(0); // batch_size.
@ -112,6 +114,7 @@ torch::Tensor RasterizePointsCoarseCpu(
auto points_a = points.accessor<float, 2>();
auto points_per_bin_a = points_per_bin.accessor<int32_t, 3>();
auto bin_points_a = bin_points.accessor<int32_t, 4>();
auto radius_a = radius.accessor<float, 1>();
const float pixel_width = 2.0f / image_size;
const float bin_width = pixel_width * bin_size;
@ -140,13 +143,14 @@ torch::Tensor RasterizePointsCoarseCpu(
float px = points_a[p][0];
float py = points_a[p][1];
float pz = points_a[p][2];
const float p_radius = radius_a[p];
if (pz < 0) {
continue;
}
float point_x_min = px - radius;
float point_x_max = px + radius;
float point_y_min = py - radius;
float point_y_max = py + radius;
float point_x_min = px - p_radius;
float point_x_max = px + p_radius;
float point_y_min = py - p_radius;
float point_y_max = py + p_radius;
// Use a half-open interval so that points exactly on the
// boundary between bins will fall into exactly one bin.

View File

@ -1,6 +1,6 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from typing import Optional
from typing import List, Optional, Tuple, Union
import torch
@ -18,7 +18,7 @@ kMaxPointsPerBin = 22
def rasterize_points(
pointclouds,
image_size: int = 256,
radius: float = 0.01,
radius: Union[float, List, Tuple, torch.Tensor] = 0.01,
points_per_pixel: int = 8,
bin_size: Optional[int] = None,
max_points_per_bin: Optional[int] = None,
@ -35,8 +35,10 @@ def rasterize_points(
(0, 0, 0); In the camera coordinate frame the x-axis goes from right-to-left,
the y-axis goes from bottom-to-top, and the z-axis goes from back-to-front.
image_size: Integer giving the resolution of the rasterized image
radius (Optional): Float giving the radius (in NDC units) of the disk to
be rasterized for each point.
radius (Optional): The radius (in NDC units) of the disk to
be rasterized. This can either be a float in which case the same radius is used
for each point, or a torch.Tensor of shape (N, P) giving a radius per point
in the batch.
points_per_pixel (Optional): We will keep track of this many points per
pixel, returning the nearest points_per_pixel points along the z-axis
bin_size: Size of bins to use for coarse-to-fine rasterization. Setting
@ -74,6 +76,8 @@ def rasterize_points(
cloud_to_packed_first_idx = pointclouds.cloud_to_packed_first_idx()
num_points_per_cloud = pointclouds.num_points_per_cloud()
radius = _format_radius(radius, pointclouds)
if bin_size is None:
if not points_packed.is_cuda:
# Binned CPU rasterization not fully implemented
@ -117,6 +121,48 @@ def rasterize_points(
)
def _format_radius(
radius: Union[float, List, Tuple, torch.Tensor], pointclouds
) -> torch.Tensor:
"""
Format the radius as a torch tensor of shape (P_packed,)
where P_packed is the total number of points in the
batch (i.e. pointclouds.points_packed().shape[0]).
This will enable support for a different size radius
for each point in the batch.
Args:
radius: can be a float, List, Tuple or tensor of
shape (N, P_padded) where P_padded is the
maximum number of points for each pointcloud
in the batch.
Returns:
radius: torch.Tensor of shape (P_packed)
"""
N, P_padded = pointclouds._N, pointclouds._P
points_packed = pointclouds.points_packed()
P_packed = points_packed.shape[0]
if isinstance(radius, (list, tuple)):
radius = torch.tensor(radius).type_as(points_packed)
if isinstance(radius, torch.Tensor):
if N == 1 and radius.ndim == 1:
radius = radius[None, ...]
if radius.shape != (N, P_padded):
msg = "radius must be of shape (N, P): got %s"
raise ValueError(msg % (repr(radius.shape)))
else:
padded_to_packed_idx = pointclouds.padded_to_packed_idx()
radius = radius.view(-1)[padded_to_packed_idx]
elif isinstance(radius, float):
radius = torch.full((P_packed,), fill_value=radius).type_as(points_packed)
else:
msg = "radius must be a float, list, tuple or tensor; got %s"
raise ValueError(msg % type(radius))
return radius
class _RasterizePoints(torch.autograd.Function):
@staticmethod
def forward(
@ -125,7 +171,7 @@ class _RasterizePoints(torch.autograd.Function):
cloud_to_packed_first_idx,
num_points_per_cloud,
image_size: int = 256,
radius: float = 0.01,
radius: Union[float, torch.Tensor] = 0.01,
points_per_pixel: int = 8,
bin_size: int = 0,
max_points_per_bin: int = 0,
@ -175,7 +221,10 @@ class _RasterizePoints(torch.autograd.Function):
def rasterize_points_python(
pointclouds, image_size: int = 256, radius: float = 0.01, points_per_pixel: int = 8
pointclouds,
image_size: int = 256,
radius: Union[float, torch.Tensor] = 0.01,
points_per_pixel: int = 8,
):
"""
Naive pure PyTorch implementation of pointcloud rasterization.
@ -190,6 +239,9 @@ def rasterize_points_python(
cloud_to_packed_first_idx = pointclouds.cloud_to_packed_first_idx()
num_points_per_cloud = pointclouds.num_points_per_cloud()
# Support variable size radius for each point in the batch
radius = _format_radius(radius, pointclouds)
# Intialize output tensors.
point_idxs = torch.full(
(N, S, S, K), fill_value=-1, dtype=torch.int32, device=device
@ -225,12 +277,13 @@ def rasterize_points_python(
# Check whether each point in the batch affects this pixel.
for p in range(point_start_idx, point_stop_idx):
px, py, pz = points_packed[p, :]
r = radius2[p]
if pz < 0:
continue
dx = px - xf
dy = py - yf
dist2 = dx * dx + dy * dy
if dist2 < radius2:
if dist2 < r:
top_k_points.append((pz, p, dist2))
top_k_points.sort()
if len(top_k_points) > K:

View File

@ -2,7 +2,7 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from typing import NamedTuple, Optional
from typing import NamedTuple, Optional, Union
import torch
import torch.nn as nn
@ -30,7 +30,7 @@ class PointsRasterizationSettings:
def __init__(
self,
image_size: int = 256,
radius: float = 0.01,
radius: Union[float, torch.Tensor] = 0.01,
points_per_pixel: int = 8,
bin_size: Optional[int] = None,
max_points_per_bin: Optional[int] = None,

View File

@ -1,6 +1,7 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import itertools
from fvcore.common.benchmark import benchmark
from test_cameras_alignment import TestCamerasAlignment

View File

@ -1,5 +1,6 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from itertools import product
import torch
from fvcore.common.benchmark import benchmark
@ -18,44 +19,64 @@ def _bm_python_with_init(N, P, img_size=32, radius=0.1, pts_per_pxl=3):
return lambda: rasterize_points_python(*args)
def _bm_cpu_with_init(N, P, img_size=32, radius=0.1, pts_per_pxl=3):
def _bm_rasterize_points_with_init(
N, P, img_size=32, radius=0.1, pts_per_pxl=3, device="cpu", expand_radius=False
):
torch.manual_seed(231)
points = torch.randn(N, P, 3)
pointclouds = Pointclouds(points=points)
args = (pointclouds, img_size, radius, pts_per_pxl)
return lambda: rasterize_points(*args)
def _bm_cuda_with_init(N, P, img_size=32, radius=0.1, pts_per_pxl=3):
torch.manual_seed(231)
device = torch.device("cuda:0")
device = torch.device(device)
points = torch.randn(N, P, 3, device=device)
pointclouds = Pointclouds(points=points)
if expand_radius:
points_padded = pointclouds.points_padded()
radius = torch.full((N, P), fill_value=radius).type_as(points_padded)
args = (pointclouds, img_size, radius, pts_per_pxl)
torch.cuda.synchronize(device)
if device == "cuda":
torch.cuda.synchronize(device)
def fn():
rasterize_points(*args)
torch.cuda.synchronize(device)
if device == "cuda":
torch.cuda.synchronize(device)
return fn
def bm_python_vs_cpu() -> None:
kwargs_list = [
{"N": 1, "P": 32, "img_size": 32, "radius": 0.1, "pts_per_pxl": 3},
{"N": 2, "P": 32, "img_size": 32, "radius": 0.1, "pts_per_pxl": 3},
]
benchmark(_bm_python_with_init, "RASTERIZE_PYTHON", kwargs_list, warmup_iters=1)
benchmark(_bm_cpu_with_init, "RASTERIZE_CPU", kwargs_list, warmup_iters=1)
kwargs_list = [
{"N": 2, "P": 32, "img_size": 32, "radius": 0.1, "pts_per_pxl": 3},
{"N": 4, "P": 1024, "img_size": 128, "radius": 0.05, "pts_per_pxl": 5},
]
benchmark(_bm_cpu_with_init, "RASTERIZE_CPU", kwargs_list, warmup_iters=1)
def bm_python_vs_cpu_vs_cuda() -> None:
kwargs_list = []
num_meshes = [1]
num_points = [10000, 2000]
image_size = [128, 256]
radius = [1e-3, 0.01]
pts_per_pxl = [50, 100]
expand = [True, False]
test_cases = product(
num_meshes, num_points, image_size, radius, pts_per_pxl, expand
)
for case in test_cases:
n, p, im, r, pts, e = case
kwargs_list.append(
{
"N": n,
"P": p,
"img_size": im,
"radius": r,
"pts_per_pxl": pts,
"device": "cpu",
"expand_radius": e,
}
)
benchmark(
_bm_rasterize_points_with_init, "RASTERIZE_CPU", kwargs_list, warmup_iters=1
)
kwargs_list += [
{"N": 32, "P": 10000, "img_size": 128, "radius": 0.01, "pts_per_pxl": 50},
{"N": 32, "P": 100000, "img_size": 128, "radius": 0.01, "pts_per_pxl": 50},
{"N": 8, "P": 200000, "img_size": 512, "radius": 0.01, "pts_per_pxl": 50},
]
benchmark(_bm_cuda_with_init, "RASTERIZE_CUDA", kwargs_list, warmup_iters=1)
for k in kwargs_list:
k["device"] = "cuda"
benchmark(
_bm_rasterize_points_with_init, "RASTERIZE_CUDA", kwargs_list, warmup_iters=1
)

View File

@ -8,6 +8,7 @@ import torch
from common_testing import TestCaseMixin, get_random_cuda_device
from pytorch3d import _C
from pytorch3d.renderer.points.rasterize_points import (
_format_radius,
rasterize_points,
rasterize_points_python,
)
@ -40,6 +41,21 @@ class TestRasterizePoints(TestCaseMixin, unittest.TestCase):
device = get_random_cuda_device()
self._test_behind_camera(rasterize_points, device, bin_size=0)
def test_python_variable_radius(self):
self._test_variable_size_radius(
rasterize_points_python, torch.device("cpu"), bin_size=-1
)
def test_cpu_variable_radius(self):
self._test_variable_size_radius(rasterize_points, torch.device("cpu"))
def test_cuda_variable_radius(self):
device = get_random_cuda_device()
# Naive
self._test_variable_size_radius(rasterize_points, device, bin_size=0)
# Coarse to fine
self._test_variable_size_radius(rasterize_points, device, bin_size=None)
def test_cpp_vs_naive_vs_binned(self):
# Make sure that the backward pass runs for all pathways
N = 2
@ -403,6 +419,8 @@ class TestRasterizePoints(TestCaseMixin, unittest.TestCase):
points_packed = pointclouds.points_packed()
cloud_to_packed_first_idx = pointclouds.cloud_to_packed_first_idx()
num_points_per_cloud = pointclouds.num_points_per_cloud()
radius = torch.full((points_packed.shape[0],), fill_value=radius)
args = (
points_packed,
cloud_to_packed_first_idx,
@ -419,6 +437,7 @@ class TestRasterizePoints(TestCaseMixin, unittest.TestCase):
points_packed = pointclouds_cuda.points_packed()
cloud_to_packed_first_idx = pointclouds_cuda.cloud_to_packed_first_idx()
num_points_per_cloud = pointclouds_cuda.num_points_per_cloud()
radius = radius.to(device)
args = (
points_packed,
cloud_to_packed_first_idx,
@ -499,6 +518,7 @@ class TestRasterizePoints(TestCaseMixin, unittest.TestCase):
bin_points_expected[0, 1, 1, :2] = torch.tensor([0, 1])
pointclouds = Pointclouds(points=[points])
radius = torch.full((points.shape[0],), fill_value=radius, device=device)
args = (
pointclouds.points_packed(),
pointclouds.cloud_to_packed_first_idx(),
@ -512,3 +532,115 @@ class TestRasterizePoints(TestCaseMixin, unittest.TestCase):
bin_points_same = (bin_points == bin_points_expected).all()
self.assertTrue(bin_points_same.item() == 1)
def _test_variable_size_radius(self, rasterize_points_fn, device, bin_size=0):
# Two points
points = torch.tensor(
[[0.5, 0.5, 0.3], [0.5, -0.5, -0.1], [0.0, 0.0, 0.3]],
dtype=torch.float32,
device=device,
)
image_size = 16
points_per_pixel = 1
radius = torch.tensor([0.1, 0.0, 0.2], dtype=torch.float32, device=device)
pointclouds = Pointclouds(points=[points])
if bin_size == -1:
# simple python case with no binning
idx, zbuf, dists = rasterize_points_fn(
pointclouds, image_size, radius, points_per_pixel
)
else:
idx, zbuf, dists = rasterize_points_fn(
pointclouds, image_size, radius, points_per_pixel, bin_size
)
idx_expected = torch.zeros(
(1, image_size, image_size, 1), dtype=torch.int64, device=device
)
# fmt: off
idx_expected[0, ..., 0] = torch.tensor(
[
[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241 E201
[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241 E201
[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241 E201
[-1, -1, -1, 0, 0, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241 E201
[-1, -1, -1, 0, 0, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241 E201
[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241 E201
[-1, -1, -1, -1, -1, -1, -1, 2, 2, -1, -1, -1, -1, -1, -1, -1], # noqa: E241 E201
[-1, -1, -1, -1, -1, -1, 2, 2, 2, 2, -1, -1, -1, -1, -1, -1], # noqa: E241 E201
[-1, -1, -1, -1, -1, -1, 2, 2, 2, 2, -1, -1, -1, -1, -1, -1], # noqa: E241 E201
[-1, -1, -1, -1, -1, -1, -1, 2, 2, -1, -1, -1, -1, -1, -1, -1], # noqa: E241 E201
[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241 E201
[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241 E201
[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241 E201
[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241 E201
[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241 E201
[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1] # noqa: E241 E201
],
dtype=torch.int64,
device=device
)
# fmt: on
zbuf_expected = torch.full(
idx_expected.shape, fill_value=-1, dtype=torch.float32, device=device
)
zbuf_expected[idx_expected == 0] = 0.3
zbuf_expected[idx_expected == 2] = 0.3
dists_expected = torch.full(
idx_expected.shape, fill_value=-1, dtype=torch.float32, device=device
)
# fmt: off
dists_expected[0, ..., 0] = torch.Tensor(
[
[-1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.], # noqa: E241 E201
[-1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.], # noqa: E241 E201
[-1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.], # noqa: E241 E201
[-1., -1., -1., 0.0078, 0.0078, -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.], # noqa: E241 E201
[-1., -1., -1., 0.0078, 0.0078, -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.], # noqa: E241 E201
[-1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.], # noqa: E241 E201
[-1., -1., -1., -1., -1., -1., -1., 0.0391, 0.0391, -1., -1., -1., -1., -1., -1., -1.], # noqa: E241 E201
[-1., -1., -1., -1., -1., -1., 0.0391, 0.0078, 0.0078, 0.0391, -1., -1., -1., -1., -1., -1.], # noqa: E241 E201
[-1., -1., -1., -1., -1., -1., 0.0391, 0.0078, 0.0078, 0.0391, -1., -1., -1., -1., -1., -1.], # noqa: E241 E201
[-1., -1., -1., -1., -1., -1., -1., 0.0391, 0.0391, -1., -1., -1., -1., -1., -1., -1.], # noqa: E241 E201
[-1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.], # noqa: E241 E201
[-1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.], # noqa: E241 E201
[-1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.], # noqa: E241 E201
[-1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.], # noqa: E241 E201
[-1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.], # noqa: E241 E201
[-1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.] # noqa: E241 E201
]
)
# fmt: on
# Check the distances for a point are less than the squared radius
# for that point.
self.assertTrue((dists[idx == 0] < radius[0] ** 2).all())
self.assertTrue((dists[idx == 2] < radius[2] ** 2).all())
# Check all values are correct.
idx_same = (idx == idx_expected).all().item() == 1
zbuf_same = (zbuf == zbuf_expected).all().item() == 1
self.assertTrue(idx_same)
self.assertTrue(zbuf_same)
self.assertClose(dists, dists_expected, atol=4e-5)
def test_radius_format_failure(self):
N = 20
P_max = 15
points_list = []
for _ in range(N):
p = torch.randint(low=1, high=P_max, size=(1,))[0]
points_list.append(torch.randn((p, 3)))
points = Pointclouds(points=points_list)
# Incorrect shape
with self.assertRaisesRegex(ValueError, "radius must be of shape"):
_format_radius([0, 1, 2], points)
# Incorrect type
with self.assertRaisesRegex(ValueError, "float, list, tuple or tensor"):
_format_radius({0: [0, 1, 2]}, points)