Support variable size radius for points in rasterizer

Summary:
Support a variable-size (per-point) radius for pointclouds in the renderer API, to allow compatibility with the Pulsar rasterizer.

If radius is provided as a float, it is converted to a tensor of shape (P), where P is the total number of points in the batch. Otherwise, radius is expected to be a tensor of shape (N, P_padded), where P_padded is the maximum number of points in any pointcloud in the batch (following the convention from pulsar: https://our.intern.facebook.com/intern/diffusion/FBS/browse/master/fbcode/frl/gemini/pulsar/pulsar/renderer.py?commit=ee0342850210e5df441e14fd97162675c70d147c&lines=50)

Reviewed By: jcjohnson, gkioxari

Differential Revision: D21429400

fbshipit-source-id: 65de7d9cd2472b27fc29f96160c33687e88098a2
This commit is contained in:
Nikhila Ravi 2020-09-18 18:46:45 -07:00 committed by Facebook GitHub Bot
parent e40c2167ae
commit ebe2693b11
8 changed files with 291 additions and 73 deletions

View File

@ -38,13 +38,15 @@ __device__ void CheckPixelInsidePoint(
float& q_max_z, float& q_max_z,
int& q_max_idx, int& q_max_idx,
PointQ& q, PointQ& q,
const float radius2, const float* radius,
const float xf, const float xf,
const float yf, const float yf,
const int K) { const int K) {
const float px = points[p_idx * 3 + 0]; const float px = points[p_idx * 3 + 0];
const float py = points[p_idx * 3 + 1]; const float py = points[p_idx * 3 + 1];
const float pz = points[p_idx * 3 + 2]; const float pz = points[p_idx * 3 + 2];
const float p_radius = radius[p_idx];
const float radius2 = p_radius * p_radius;
if (pz < 0) if (pz < 0)
return; // Don't render points behind the camera return; // Don't render points behind the camera
const float dx = xf - px; const float dx = xf - px;
@ -81,7 +83,7 @@ __global__ void RasterizePointsNaiveCudaKernel(
const float* points, // (P, 3) const float* points, // (P, 3)
const int64_t* cloud_to_packed_first_idx, // (N) const int64_t* cloud_to_packed_first_idx, // (N)
const int64_t* num_points_per_cloud, // (N) const int64_t* num_points_per_cloud, // (N)
const float radius, const float* radius,
const int N, const int N,
const int S, const int S,
const int K, const int K,
@ -91,7 +93,6 @@ __global__ void RasterizePointsNaiveCudaKernel(
// Simple version: One thread per output pixel // Simple version: One thread per output pixel
const int num_threads = gridDim.x * blockDim.x; const int num_threads = gridDim.x * blockDim.x;
const int tid = blockDim.x * blockIdx.x + threadIdx.x; const int tid = blockDim.x * blockIdx.x + threadIdx.x;
const float radius2 = radius * radius;
for (int i = tid; i < N * S * S; i += num_threads) { for (int i = tid; i < N * S * S; i += num_threads) {
// Convert linear index to 3D index // Convert linear index to 3D index
const int n = i / (S * S); // Batch index const int n = i / (S * S); // Batch index
@ -128,7 +129,7 @@ __global__ void RasterizePointsNaiveCudaKernel(
for (int p_idx = point_start_idx; p_idx < point_stop_idx; ++p_idx) { for (int p_idx = point_start_idx; p_idx < point_stop_idx; ++p_idx) {
CheckPixelInsidePoint( CheckPixelInsidePoint(
points, p_idx, q_size, q_max_z, q_max_idx, q, radius2, xf, yf, K); points, p_idx, q_size, q_max_z, q_max_idx, q, radius, xf, yf, K);
} }
BubbleSort(q, q_size); BubbleSort(q, q_size);
int idx = n * S * S * K + pix_idx * K; int idx = n * S * S * K + pix_idx * K;
@ -145,7 +146,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> RasterizePointsNaiveCuda(
const at::Tensor& cloud_to_packed_first_idx, // (N) const at::Tensor& cloud_to_packed_first_idx, // (N)
const at::Tensor& num_points_per_cloud, // (N) const at::Tensor& num_points_per_cloud, // (N)
const int image_size, const int image_size,
const float radius, const at::Tensor& radius,
const int points_per_pixel) { const int points_per_pixel) {
// Check inputs are on the same device // Check inputs are on the same device
at::TensorArg points_t{points, "points", 1}, at::TensorArg points_t{points, "points", 1},
@ -194,7 +195,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> RasterizePointsNaiveCuda(
points.contiguous().data_ptr<float>(), points.contiguous().data_ptr<float>(),
cloud_to_packed_first_idx.contiguous().data_ptr<int64_t>(), cloud_to_packed_first_idx.contiguous().data_ptr<int64_t>(),
num_points_per_cloud.contiguous().data_ptr<int64_t>(), num_points_per_cloud.contiguous().data_ptr<int64_t>(),
radius, radius.contiguous().data_ptr<float>(),
N, N,
S, S,
K, K,
@ -214,7 +215,7 @@ __global__ void RasterizePointsCoarseCudaKernel(
const float* points, // (P, 3) const float* points, // (P, 3)
const int64_t* cloud_to_packed_first_idx, // (N) const int64_t* cloud_to_packed_first_idx, // (N)
const int64_t* num_points_per_cloud, // (N) const int64_t* num_points_per_cloud, // (N)
const float radius, const float* radius,
const int N, const int N,
const int P, const int P,
const int S, const int S,
@ -266,12 +267,13 @@ __global__ void RasterizePointsCoarseCudaKernel(
const float px = points[p_idx * 3 + 0]; const float px = points[p_idx * 3 + 0];
const float py = points[p_idx * 3 + 1]; const float py = points[p_idx * 3 + 1];
const float pz = points[p_idx * 3 + 2]; const float pz = points[p_idx * 3 + 2];
const float p_radius = radius[p_idx];
if (pz < 0) if (pz < 0)
continue; // Don't render points behind the camera. continue; // Don't render points behind the camera.
const float px0 = px - radius; const float px0 = px - p_radius;
const float px1 = px + radius; const float px1 = px + p_radius;
const float py0 = py - radius; const float py0 = py - p_radius;
const float py1 = py + radius; const float py1 = py + p_radius;
// Brute-force search over all bins; TODO something smarter? // Brute-force search over all bins; TODO something smarter?
// For example we could compute the exact bin where the point falls, // For example we could compute the exact bin where the point falls,
@ -341,7 +343,7 @@ at::Tensor RasterizePointsCoarseCuda(
const at::Tensor& cloud_to_packed_first_idx, // (N) const at::Tensor& cloud_to_packed_first_idx, // (N)
const at::Tensor& num_points_per_cloud, // (N) const at::Tensor& num_points_per_cloud, // (N)
const int image_size, const int image_size,
const float radius, const at::Tensor& radius,
const int bin_size, const int bin_size,
const int max_points_per_bin) { const int max_points_per_bin) {
TORCH_CHECK( TORCH_CHECK(
@ -390,7 +392,7 @@ at::Tensor RasterizePointsCoarseCuda(
points.contiguous().data_ptr<float>(), points.contiguous().data_ptr<float>(),
cloud_to_packed_first_idx.contiguous().data_ptr<int64_t>(), cloud_to_packed_first_idx.contiguous().data_ptr<int64_t>(),
num_points_per_cloud.contiguous().data_ptr<int64_t>(), num_points_per_cloud.contiguous().data_ptr<int64_t>(),
radius, radius.contiguous().data_ptr<float>(),
N, N,
P, P,
image_size, image_size,
@ -411,7 +413,7 @@ at::Tensor RasterizePointsCoarseCuda(
__global__ void RasterizePointsFineCudaKernel( __global__ void RasterizePointsFineCudaKernel(
const float* points, // (P, 3) const float* points, // (P, 3)
const int32_t* bin_points, // (N, B, B, T) const int32_t* bin_points, // (N, B, B, T)
const float radius, const float* radius,
const int bin_size, const int bin_size,
const int N, const int N,
const int B, // num_bins const int B, // num_bins
@ -425,7 +427,6 @@ __global__ void RasterizePointsFineCudaKernel(
const int num_pixels = N * B * B * bin_size * bin_size; const int num_pixels = N * B * B * bin_size * bin_size;
const int num_threads = gridDim.x * blockDim.x; const int num_threads = gridDim.x * blockDim.x;
const int tid = blockIdx.x * blockDim.x + threadIdx.x; const int tid = blockIdx.x * blockDim.x + threadIdx.x;
const float radius2 = radius * radius;
for (int pid = tid; pid < num_pixels; pid += num_threads) { for (int pid = tid; pid < num_pixels; pid += num_threads) {
// Convert linear index into bin and pixel indices. We make the within // Convert linear index into bin and pixel indices. We make the within
@ -464,7 +465,7 @@ __global__ void RasterizePointsFineCudaKernel(
continue; continue;
} }
CheckPixelInsidePoint( CheckPixelInsidePoint(
points, p, q_size, q_max_z, q_max_idx, q, radius2, xf, yf, K); points, p, q_size, q_max_z, q_max_idx, q, radius, xf, yf, K);
} }
// Now we've looked at all the points for this bin, so we can write // Now we've looked at all the points for this bin, so we can write
// output for the current pixel. // output for the current pixel.
@ -488,7 +489,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> RasterizePointsFineCuda(
const at::Tensor& points, // (P, 3) const at::Tensor& points, // (P, 3)
const at::Tensor& bin_points, const at::Tensor& bin_points,
const int image_size, const int image_size,
const float radius, const at::Tensor& radius,
const int bin_size, const int bin_size,
const int points_per_pixel) { const int points_per_pixel) {
// Check inputs are on the same device // Check inputs are on the same device
@ -525,7 +526,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> RasterizePointsFineCuda(
RasterizePointsFineCudaKernel<<<blocks, threads, 0, stream>>>( RasterizePointsFineCudaKernel<<<blocks, threads, 0, stream>>>(
points.contiguous().data_ptr<float>(), points.contiguous().data_ptr<float>(),
bin_points.contiguous().data_ptr<int32_t>(), bin_points.contiguous().data_ptr<int32_t>(),
radius, radius.contiguous().data_ptr<float>(),
bin_size, bin_size,
N, N,
B, B,

View File

@ -15,7 +15,7 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsNaiveCpu(
const torch::Tensor& cloud_to_packed_first_idx, const torch::Tensor& cloud_to_packed_first_idx,
const torch::Tensor& num_points_per_cloud, const torch::Tensor& num_points_per_cloud,
const int image_size, const int image_size,
const float radius, const torch::Tensor& radius,
const int points_per_pixel); const int points_per_pixel);
#ifdef WITH_CUDA #ifdef WITH_CUDA
@ -25,7 +25,7 @@ RasterizePointsNaiveCuda(
const torch::Tensor& cloud_to_packed_first_idx, const torch::Tensor& cloud_to_packed_first_idx,
const torch::Tensor& num_points_per_cloud, const torch::Tensor& num_points_per_cloud,
const int image_size, const int image_size,
const float radius, const torch::Tensor& radius,
const int points_per_pixel); const int points_per_pixel);
#endif #endif
// Naive (forward) pointcloud rasterization: For each pixel, for each point, // Naive (forward) pointcloud rasterization: For each pixel, for each point,
@ -41,7 +41,8 @@ RasterizePointsNaiveCuda(
// in the batch where N is the batch size. // in the batch where N is the batch size.
// num_points_per_cloud: LongTensor of shape (N) giving the number of points // num_points_per_cloud: LongTensor of shape (N) giving the number of points
// for each pointcloud in the batch. // for each pointcloud in the batch.
// radius: Radius of each point (in NDC units) // radius: FloatTensor of shape (P) giving the radius (in NDC units) of
// each point in points.
// image_size: (S) Size of the image to return (in pixels) // image_size: (S) Size of the image to return (in pixels)
// points_per_pixel: (K) The number closest of points to return for each pixel // points_per_pixel: (K) The number closest of points to return for each pixel
// //
@ -62,7 +63,7 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsNaive(
const torch::Tensor& cloud_to_packed_first_idx, const torch::Tensor& cloud_to_packed_first_idx,
const torch::Tensor& num_points_per_cloud, const torch::Tensor& num_points_per_cloud,
const int image_size, const int image_size,
const float radius, const torch::Tensor& radius,
const int points_per_pixel) { const int points_per_pixel) {
if (points.is_cuda() && cloud_to_packed_first_idx.is_cuda() && if (points.is_cuda() && cloud_to_packed_first_idx.is_cuda() &&
num_points_per_cloud.is_cuda()) { num_points_per_cloud.is_cuda()) {
@ -70,6 +71,7 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsNaive(
CHECK_CUDA(points); CHECK_CUDA(points);
CHECK_CUDA(cloud_to_packed_first_idx); CHECK_CUDA(cloud_to_packed_first_idx);
CHECK_CUDA(num_points_per_cloud); CHECK_CUDA(num_points_per_cloud);
CHECK_CUDA(radius);
return RasterizePointsNaiveCuda( return RasterizePointsNaiveCuda(
points, points,
cloud_to_packed_first_idx, cloud_to_packed_first_idx,
@ -100,7 +102,7 @@ torch::Tensor RasterizePointsCoarseCpu(
const torch::Tensor& cloud_to_packed_first_idx, const torch::Tensor& cloud_to_packed_first_idx,
const torch::Tensor& num_points_per_cloud, const torch::Tensor& num_points_per_cloud,
const int image_size, const int image_size,
const float radius, const torch::Tensor& radius,
const int bin_size, const int bin_size,
const int max_points_per_bin); const int max_points_per_bin);
@ -110,7 +112,7 @@ torch::Tensor RasterizePointsCoarseCuda(
const torch::Tensor& cloud_to_packed_first_idx, const torch::Tensor& cloud_to_packed_first_idx,
const torch::Tensor& num_points_per_cloud, const torch::Tensor& num_points_per_cloud,
const int image_size, const int image_size,
const float radius, const torch::Tensor& radius,
const int bin_size, const int bin_size,
const int max_points_per_bin); const int max_points_per_bin);
#endif #endif
@ -124,7 +126,8 @@ torch::Tensor RasterizePointsCoarseCuda(
// in the batch where N is the batch size. // in the batch where N is the batch size.
// num_points_per_cloud: LongTensor of shape (N) giving the number of points // num_points_per_cloud: LongTensor of shape (N) giving the number of points
// for each pointcloud in the batch. // for each pointcloud in the batch.
// radius: Radius of points to rasterize (in NDC units) // radius: FloatTensor of shape (P) giving the radius (in NDC units) of
// each point in points.
// image_size: Size of the image to generate (in pixels) // image_size: Size of the image to generate (in pixels)
// bin_size: Size of each bin within the image (in pixels) // bin_size: Size of each bin within the image (in pixels)
// //
@ -138,7 +141,7 @@ torch::Tensor RasterizePointsCoarse(
const torch::Tensor& cloud_to_packed_first_idx, const torch::Tensor& cloud_to_packed_first_idx,
const torch::Tensor& num_points_per_cloud, const torch::Tensor& num_points_per_cloud,
const int image_size, const int image_size,
const float radius, const torch::Tensor& radius,
const int bin_size, const int bin_size,
const int max_points_per_bin) { const int max_points_per_bin) {
if (points.is_cuda() && cloud_to_packed_first_idx.is_cuda() && if (points.is_cuda() && cloud_to_packed_first_idx.is_cuda() &&
@ -147,6 +150,7 @@ torch::Tensor RasterizePointsCoarse(
CHECK_CUDA(points); CHECK_CUDA(points);
CHECK_CUDA(cloud_to_packed_first_idx); CHECK_CUDA(cloud_to_packed_first_idx);
CHECK_CUDA(num_points_per_cloud); CHECK_CUDA(num_points_per_cloud);
CHECK_CUDA(radius);
return RasterizePointsCoarseCuda( return RasterizePointsCoarseCuda(
points, points,
cloud_to_packed_first_idx, cloud_to_packed_first_idx,
@ -179,7 +183,7 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsFineCuda(
const torch::Tensor& points, const torch::Tensor& points,
const torch::Tensor& bin_points, const torch::Tensor& bin_points,
const int image_size, const int image_size,
const float radius, const torch::Tensor& radius,
const int bin_size, const int bin_size,
const int points_per_pixel); const int points_per_pixel);
#endif #endif
@ -191,7 +195,8 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsFineCuda(
// bin_points: int32 Tensor of shape (N, B, B, M) giving the indices of points // bin_points: int32 Tensor of shape (N, B, B, M) giving the indices of points
// that fall into each bin (output from coarse rasterization) // that fall into each bin (output from coarse rasterization)
// image_size: Size of image to generate (in pixels) // image_size: Size of image to generate (in pixels)
// radius: Radius of points to rasterize (NDC units) // radius: FloatTensor of shape (P) giving the radius (in NDC units) of
// each point in points.
// bin_size: Size of each bin (in pixels) // bin_size: Size of each bin (in pixels)
// points_per_pixel: How many points to rasterize for each pixel // points_per_pixel: How many points to rasterize for each pixel
// //
@ -210,7 +215,7 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsFine(
const torch::Tensor& points, const torch::Tensor& points,
const torch::Tensor& bin_points, const torch::Tensor& bin_points,
const int image_size, const int image_size,
const float radius, const torch::Tensor& radius,
const int bin_size, const int bin_size,
const int points_per_pixel) { const int points_per_pixel) {
if (points.is_cuda()) { if (points.is_cuda()) {
@ -296,7 +301,8 @@ torch::Tensor RasterizePointsBackward(
// in the batch where N is the batch size. // in the batch where N is the batch size.
// num_points_per_cloud: LongTensor of shape (N) giving the number of points // num_points_per_cloud: LongTensor of shape (N) giving the number of points
// for each pointcloud in the batch. // for each pointcloud in the batch.
// radius: Radius of each point (in NDC units) // radius: FloatTensor of shape (P) giving the radius (in NDC units) of
// each point in points.
// image_size: (S) Size of the image to return (in pixels) // image_size: (S) Size of the image to return (in pixels)
// points_per_pixel: (K) The number of points to return for each pixel // points_per_pixel: (K) The number of points to return for each pixel
// bin_size: Bin size (in pixels) for coarse-to-fine rasterization. Setting // bin_size: Bin size (in pixels) for coarse-to-fine rasterization. Setting
@ -320,7 +326,7 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePoints(
const torch::Tensor& cloud_to_packed_first_idx, const torch::Tensor& cloud_to_packed_first_idx,
const torch::Tensor& num_points_per_cloud, const torch::Tensor& num_points_per_cloud,
const int image_size, const int image_size,
const float radius, const torch::Tensor& radius,
const int points_per_pixel, const int points_per_pixel,
const int bin_size, const int bin_size,
const int max_points_per_bin) { const int max_points_per_bin) {

View File

@ -17,7 +17,7 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsNaiveCpu(
const torch::Tensor& cloud_to_packed_first_idx, // (N) const torch::Tensor& cloud_to_packed_first_idx, // (N)
const torch::Tensor& num_points_per_cloud, // (N) const torch::Tensor& num_points_per_cloud, // (N)
const int image_size, const int image_size,
const float radius, const torch::Tensor& radius,
const int points_per_pixel) { const int points_per_pixel) {
const int32_t N = cloud_to_packed_first_idx.size(0); // batch_size. const int32_t N = cloud_to_packed_first_idx.size(0); // batch_size.
@ -35,8 +35,8 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsNaiveCpu(
auto point_idxs_a = point_idxs.accessor<int32_t, 4>(); auto point_idxs_a = point_idxs.accessor<int32_t, 4>();
auto zbuf_a = zbuf.accessor<float, 4>(); auto zbuf_a = zbuf.accessor<float, 4>();
auto pix_dists_a = pix_dists.accessor<float, 4>(); auto pix_dists_a = pix_dists.accessor<float, 4>();
auto radius_a = radius.accessor<float, 1>();
const float radius2 = radius * radius;
for (int n = 0; n < N; ++n) { for (int n = 0; n < N; ++n) {
// Loop through each pointcloud in the batch. // Loop through each pointcloud in the batch.
// Get the start index of the points in points_packed and the num points // Get the start index of the points in points_packed and the num points
@ -63,6 +63,8 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsNaiveCpu(
const float px = points_a[p][0]; const float px = points_a[p][0];
const float py = points_a[p][1]; const float py = points_a[p][1];
const float pz = points_a[p][2]; const float pz = points_a[p][2];
const float p_radius = radius_a[p];
const float radius2 = p_radius * p_radius;
if (pz < 0) { if (pz < 0) {
continue; continue;
} }
@ -98,7 +100,7 @@ torch::Tensor RasterizePointsCoarseCpu(
const torch::Tensor& cloud_to_packed_first_idx, // (N) const torch::Tensor& cloud_to_packed_first_idx, // (N)
const torch::Tensor& num_points_per_cloud, // (N) const torch::Tensor& num_points_per_cloud, // (N)
const int image_size, const int image_size,
const float radius, const torch::Tensor& radius,
const int bin_size, const int bin_size,
const int max_points_per_bin) { const int max_points_per_bin) {
const int32_t N = cloud_to_packed_first_idx.size(0); // batch_size. const int32_t N = cloud_to_packed_first_idx.size(0); // batch_size.
@ -112,6 +114,7 @@ torch::Tensor RasterizePointsCoarseCpu(
auto points_a = points.accessor<float, 2>(); auto points_a = points.accessor<float, 2>();
auto points_per_bin_a = points_per_bin.accessor<int32_t, 3>(); auto points_per_bin_a = points_per_bin.accessor<int32_t, 3>();
auto bin_points_a = bin_points.accessor<int32_t, 4>(); auto bin_points_a = bin_points.accessor<int32_t, 4>();
auto radius_a = radius.accessor<float, 1>();
const float pixel_width = 2.0f / image_size; const float pixel_width = 2.0f / image_size;
const float bin_width = pixel_width * bin_size; const float bin_width = pixel_width * bin_size;
@ -140,13 +143,14 @@ torch::Tensor RasterizePointsCoarseCpu(
float px = points_a[p][0]; float px = points_a[p][0];
float py = points_a[p][1]; float py = points_a[p][1];
float pz = points_a[p][2]; float pz = points_a[p][2];
const float p_radius = radius_a[p];
if (pz < 0) { if (pz < 0) {
continue; continue;
} }
float point_x_min = px - radius; float point_x_min = px - p_radius;
float point_x_max = px + radius; float point_x_max = px + p_radius;
float point_y_min = py - radius; float point_y_min = py - p_radius;
float point_y_max = py + radius; float point_y_max = py + p_radius;
// Use a half-open interval so that points exactly on the // Use a half-open interval so that points exactly on the
// boundary between bins will fall into exactly one bin. // boundary between bins will fall into exactly one bin.

View File

@ -1,6 +1,6 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from typing import Optional from typing import List, Optional, Tuple, Union
import torch import torch
@ -18,7 +18,7 @@ kMaxPointsPerBin = 22
def rasterize_points( def rasterize_points(
pointclouds, pointclouds,
image_size: int = 256, image_size: int = 256,
radius: float = 0.01, radius: Union[float, List, Tuple, torch.Tensor] = 0.01,
points_per_pixel: int = 8, points_per_pixel: int = 8,
bin_size: Optional[int] = None, bin_size: Optional[int] = None,
max_points_per_bin: Optional[int] = None, max_points_per_bin: Optional[int] = None,
@ -35,8 +35,10 @@ def rasterize_points(
(0, 0, 0); In the camera coordinate frame the x-axis goes from right-to-left, (0, 0, 0); In the camera coordinate frame the x-axis goes from right-to-left,
the y-axis goes from bottom-to-top, and the z-axis goes from back-to-front. the y-axis goes from bottom-to-top, and the z-axis goes from back-to-front.
image_size: Integer giving the resolution of the rasterized image image_size: Integer giving the resolution of the rasterized image
radius (Optional): Float giving the radius (in NDC units) of the disk to radius (Optional): The radius (in NDC units) of the disk to
be rasterized for each point. be rasterized. This can either be a float in which case the same radius is used
for each point, or a torch.Tensor of shape (N, P) giving a radius per point
in the batch.
points_per_pixel (Optional): We will keep track of this many points per points_per_pixel (Optional): We will keep track of this many points per
pixel, returning the nearest points_per_pixel points along the z-axis pixel, returning the nearest points_per_pixel points along the z-axis
bin_size: Size of bins to use for coarse-to-fine rasterization. Setting bin_size: Size of bins to use for coarse-to-fine rasterization. Setting
@ -74,6 +76,8 @@ def rasterize_points(
cloud_to_packed_first_idx = pointclouds.cloud_to_packed_first_idx() cloud_to_packed_first_idx = pointclouds.cloud_to_packed_first_idx()
num_points_per_cloud = pointclouds.num_points_per_cloud() num_points_per_cloud = pointclouds.num_points_per_cloud()
radius = _format_radius(radius, pointclouds)
if bin_size is None: if bin_size is None:
if not points_packed.is_cuda: if not points_packed.is_cuda:
# Binned CPU rasterization not fully implemented # Binned CPU rasterization not fully implemented
@ -117,6 +121,48 @@ def rasterize_points(
) )
def _format_radius(
radius: Union[float, List, Tuple, torch.Tensor], pointclouds
) -> torch.Tensor:
"""
Format the radius as a torch tensor of shape (P_packed,)
where P_packed is the total number of points in the
batch (i.e. pointclouds.points_packed().shape[0]).
This will enable support for a different size radius
for each point in the batch.
Args:
radius: can be a float, List, Tuple or tensor of
shape (N, P_padded) where P_padded is the
maximum number of points for each pointcloud
in the batch.
Returns:
radius: torch.Tensor of shape (P_packed)
"""
N, P_padded = pointclouds._N, pointclouds._P
points_packed = pointclouds.points_packed()
P_packed = points_packed.shape[0]
if isinstance(radius, (list, tuple)):
radius = torch.tensor(radius).type_as(points_packed)
if isinstance(radius, torch.Tensor):
if N == 1 and radius.ndim == 1:
radius = radius[None, ...]
if radius.shape != (N, P_padded):
msg = "radius must be of shape (N, P): got %s"
raise ValueError(msg % (repr(radius.shape)))
else:
padded_to_packed_idx = pointclouds.padded_to_packed_idx()
radius = radius.view(-1)[padded_to_packed_idx]
elif isinstance(radius, float):
radius = torch.full((P_packed,), fill_value=radius).type_as(points_packed)
else:
msg = "radius must be a float, list, tuple or tensor; got %s"
raise ValueError(msg % type(radius))
return radius
class _RasterizePoints(torch.autograd.Function): class _RasterizePoints(torch.autograd.Function):
@staticmethod @staticmethod
def forward( def forward(
@ -125,7 +171,7 @@ class _RasterizePoints(torch.autograd.Function):
cloud_to_packed_first_idx, cloud_to_packed_first_idx,
num_points_per_cloud, num_points_per_cloud,
image_size: int = 256, image_size: int = 256,
radius: float = 0.01, radius: Union[float, torch.Tensor] = 0.01,
points_per_pixel: int = 8, points_per_pixel: int = 8,
bin_size: int = 0, bin_size: int = 0,
max_points_per_bin: int = 0, max_points_per_bin: int = 0,
@ -175,7 +221,10 @@ class _RasterizePoints(torch.autograd.Function):
def rasterize_points_python( def rasterize_points_python(
pointclouds, image_size: int = 256, radius: float = 0.01, points_per_pixel: int = 8 pointclouds,
image_size: int = 256,
radius: Union[float, torch.Tensor] = 0.01,
points_per_pixel: int = 8,
): ):
""" """
Naive pure PyTorch implementation of pointcloud rasterization. Naive pure PyTorch implementation of pointcloud rasterization.
@ -190,6 +239,9 @@ def rasterize_points_python(
cloud_to_packed_first_idx = pointclouds.cloud_to_packed_first_idx() cloud_to_packed_first_idx = pointclouds.cloud_to_packed_first_idx()
num_points_per_cloud = pointclouds.num_points_per_cloud() num_points_per_cloud = pointclouds.num_points_per_cloud()
# Support variable size radius for each point in the batch
radius = _format_radius(radius, pointclouds)
# Intialize output tensors. # Intialize output tensors.
point_idxs = torch.full( point_idxs = torch.full(
(N, S, S, K), fill_value=-1, dtype=torch.int32, device=device (N, S, S, K), fill_value=-1, dtype=torch.int32, device=device
@ -225,12 +277,13 @@ def rasterize_points_python(
# Check whether each point in the batch affects this pixel. # Check whether each point in the batch affects this pixel.
for p in range(point_start_idx, point_stop_idx): for p in range(point_start_idx, point_stop_idx):
px, py, pz = points_packed[p, :] px, py, pz = points_packed[p, :]
r = radius2[p]
if pz < 0: if pz < 0:
continue continue
dx = px - xf dx = px - xf
dy = py - yf dy = py - yf
dist2 = dx * dx + dy * dy dist2 = dx * dx + dy * dy
if dist2 < radius2: if dist2 < r:
top_k_points.append((pz, p, dist2)) top_k_points.append((pz, p, dist2))
top_k_points.sort() top_k_points.sort()
if len(top_k_points) > K: if len(top_k_points) > K:

View File

@ -2,7 +2,7 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from typing import NamedTuple, Optional from typing import NamedTuple, Optional, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
@ -30,7 +30,7 @@ class PointsRasterizationSettings:
def __init__( def __init__(
self, self,
image_size: int = 256, image_size: int = 256,
radius: float = 0.01, radius: Union[float, torch.Tensor] = 0.01,
points_per_pixel: int = 8, points_per_pixel: int = 8,
bin_size: Optional[int] = None, bin_size: Optional[int] = None,
max_points_per_bin: Optional[int] = None, max_points_per_bin: Optional[int] = None,

View File

@ -1,6 +1,7 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import itertools import itertools
from fvcore.common.benchmark import benchmark from fvcore.common.benchmark import benchmark
from test_cameras_alignment import TestCamerasAlignment from test_cameras_alignment import TestCamerasAlignment

View File

@ -1,5 +1,6 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from itertools import product
import torch import torch
from fvcore.common.benchmark import benchmark from fvcore.common.benchmark import benchmark
@ -18,44 +19,64 @@ def _bm_python_with_init(N, P, img_size=32, radius=0.1, pts_per_pxl=3):
return lambda: rasterize_points_python(*args) return lambda: rasterize_points_python(*args)
def _bm_cpu_with_init(N, P, img_size=32, radius=0.1, pts_per_pxl=3): def _bm_rasterize_points_with_init(
N, P, img_size=32, radius=0.1, pts_per_pxl=3, device="cpu", expand_radius=False
):
torch.manual_seed(231) torch.manual_seed(231)
points = torch.randn(N, P, 3) device = torch.device(device)
pointclouds = Pointclouds(points=points)
args = (pointclouds, img_size, radius, pts_per_pxl)
return lambda: rasterize_points(*args)
def _bm_cuda_with_init(N, P, img_size=32, radius=0.1, pts_per_pxl=3):
torch.manual_seed(231)
device = torch.device("cuda:0")
points = torch.randn(N, P, 3, device=device) points = torch.randn(N, P, 3, device=device)
pointclouds = Pointclouds(points=points) pointclouds = Pointclouds(points=points)
if expand_radius:
points_padded = pointclouds.points_padded()
radius = torch.full((N, P), fill_value=radius).type_as(points_padded)
args = (pointclouds, img_size, radius, pts_per_pxl) args = (pointclouds, img_size, radius, pts_per_pxl)
torch.cuda.synchronize(device) if device == "cuda":
torch.cuda.synchronize(device)
def fn(): def fn():
rasterize_points(*args) rasterize_points(*args)
torch.cuda.synchronize(device) if device == "cuda":
torch.cuda.synchronize(device)
return fn return fn
def bm_python_vs_cpu() -> None: def bm_python_vs_cpu_vs_cuda() -> None:
kwargs_list = [ kwargs_list = []
{"N": 1, "P": 32, "img_size": 32, "radius": 0.1, "pts_per_pxl": 3}, num_meshes = [1]
{"N": 2, "P": 32, "img_size": 32, "radius": 0.1, "pts_per_pxl": 3}, num_points = [10000, 2000]
] image_size = [128, 256]
benchmark(_bm_python_with_init, "RASTERIZE_PYTHON", kwargs_list, warmup_iters=1) radius = [1e-3, 0.01]
benchmark(_bm_cpu_with_init, "RASTERIZE_CPU", kwargs_list, warmup_iters=1) pts_per_pxl = [50, 100]
kwargs_list = [ expand = [True, False]
{"N": 2, "P": 32, "img_size": 32, "radius": 0.1, "pts_per_pxl": 3}, test_cases = product(
{"N": 4, "P": 1024, "img_size": 128, "radius": 0.05, "pts_per_pxl": 5}, num_meshes, num_points, image_size, radius, pts_per_pxl, expand
] )
benchmark(_bm_cpu_with_init, "RASTERIZE_CPU", kwargs_list, warmup_iters=1) for case in test_cases:
n, p, im, r, pts, e = case
kwargs_list.append(
{
"N": n,
"P": p,
"img_size": im,
"radius": r,
"pts_per_pxl": pts,
"device": "cpu",
"expand_radius": e,
}
)
benchmark(
_bm_rasterize_points_with_init, "RASTERIZE_CPU", kwargs_list, warmup_iters=1
)
kwargs_list += [ kwargs_list += [
{"N": 32, "P": 10000, "img_size": 128, "radius": 0.01, "pts_per_pxl": 50},
{"N": 32, "P": 100000, "img_size": 128, "radius": 0.01, "pts_per_pxl": 50}, {"N": 32, "P": 100000, "img_size": 128, "radius": 0.01, "pts_per_pxl": 50},
{"N": 8, "P": 200000, "img_size": 512, "radius": 0.01, "pts_per_pxl": 50}, {"N": 8, "P": 200000, "img_size": 512, "radius": 0.01, "pts_per_pxl": 50},
] ]
benchmark(_bm_cuda_with_init, "RASTERIZE_CUDA", kwargs_list, warmup_iters=1) for k in kwargs_list:
k["device"] = "cuda"
benchmark(
_bm_rasterize_points_with_init, "RASTERIZE_CUDA", kwargs_list, warmup_iters=1
)

View File

@ -8,6 +8,7 @@ import torch
from common_testing import TestCaseMixin, get_random_cuda_device from common_testing import TestCaseMixin, get_random_cuda_device
from pytorch3d import _C from pytorch3d import _C
from pytorch3d.renderer.points.rasterize_points import ( from pytorch3d.renderer.points.rasterize_points import (
_format_radius,
rasterize_points, rasterize_points,
rasterize_points_python, rasterize_points_python,
) )
@ -40,6 +41,21 @@ class TestRasterizePoints(TestCaseMixin, unittest.TestCase):
device = get_random_cuda_device() device = get_random_cuda_device()
self._test_behind_camera(rasterize_points, device, bin_size=0) self._test_behind_camera(rasterize_points, device, bin_size=0)
def test_python_variable_radius(self):
    """Run the per-point radius check against the pure Python rasterizer on CPU."""
    cpu = torch.device("cpu")
    # bin_size=-1 tells the shared helper to call the function without a
    # bin_size argument.
    self._test_variable_size_radius(rasterize_points_python, cpu, bin_size=-1)
def test_cpu_variable_radius(self):
    """Run the per-point radius check through `rasterize_points` on CPU."""
    cpu = torch.device("cpu")
    self._test_variable_size_radius(rasterize_points, cpu)
def test_cuda_variable_radius(self):
    """Run the per-point radius check on a CUDA device for both bin_size modes."""
    device = get_random_cuda_device()
    # bin_size=0 is the naive path; bin_size=None is the coarse-to-fine path.
    for bin_size in (0, None):
        self._test_variable_size_radius(rasterize_points, device, bin_size=bin_size)
def test_cpp_vs_naive_vs_binned(self): def test_cpp_vs_naive_vs_binned(self):
# Make sure that the backward pass runs for all pathways # Make sure that the backward pass runs for all pathways
N = 2 N = 2
@ -403,6 +419,8 @@ class TestRasterizePoints(TestCaseMixin, unittest.TestCase):
points_packed = pointclouds.points_packed() points_packed = pointclouds.points_packed()
cloud_to_packed_first_idx = pointclouds.cloud_to_packed_first_idx() cloud_to_packed_first_idx = pointclouds.cloud_to_packed_first_idx()
num_points_per_cloud = pointclouds.num_points_per_cloud() num_points_per_cloud = pointclouds.num_points_per_cloud()
radius = torch.full((points_packed.shape[0],), fill_value=radius)
args = ( args = (
points_packed, points_packed,
cloud_to_packed_first_idx, cloud_to_packed_first_idx,
@ -419,6 +437,7 @@ class TestRasterizePoints(TestCaseMixin, unittest.TestCase):
points_packed = pointclouds_cuda.points_packed() points_packed = pointclouds_cuda.points_packed()
cloud_to_packed_first_idx = pointclouds_cuda.cloud_to_packed_first_idx() cloud_to_packed_first_idx = pointclouds_cuda.cloud_to_packed_first_idx()
num_points_per_cloud = pointclouds_cuda.num_points_per_cloud() num_points_per_cloud = pointclouds_cuda.num_points_per_cloud()
radius = radius.to(device)
args = ( args = (
points_packed, points_packed,
cloud_to_packed_first_idx, cloud_to_packed_first_idx,
@ -499,6 +518,7 @@ class TestRasterizePoints(TestCaseMixin, unittest.TestCase):
bin_points_expected[0, 1, 1, :2] = torch.tensor([0, 1]) bin_points_expected[0, 1, 1, :2] = torch.tensor([0, 1])
pointclouds = Pointclouds(points=[points]) pointclouds = Pointclouds(points=[points])
radius = torch.full((points.shape[0],), fill_value=radius, device=device)
args = ( args = (
pointclouds.points_packed(), pointclouds.points_packed(),
pointclouds.cloud_to_packed_first_idx(), pointclouds.cloud_to_packed_first_idx(),
@ -512,3 +532,115 @@ class TestRasterizePoints(TestCaseMixin, unittest.TestCase):
bin_points_same = (bin_points == bin_points_expected).all() bin_points_same = (bin_points == bin_points_expected).all()
self.assertTrue(bin_points_same.item() == 1) self.assertTrue(bin_points_same.item() == 1)
def _test_variable_size_radius(self, rasterize_points_fn, device, bin_size=0):
    """
    Rasterize a small point cloud with a different radius per point and
    compare idx, zbuf and dists against hand-computed expected values.

    Args:
        rasterize_points_fn: rasterization function under test. It is called
            as fn(pointclouds, image_size, radius, points_per_pixel[, bin_size]).
        device: torch device on which all inputs and expected outputs are built.
        bin_size: forwarded as the fifth positional argument. The sentinel
            value -1 means the function is called without a bin_size argument.
    """
    # Three points. Point 1 has negative z and zero radius, and it does not
    # appear anywhere in the expected output below.
    points = torch.tensor(
        [[0.5, 0.5, 0.3], [0.5, -0.5, -0.1], [0.0, 0.0, 0.3]],
        dtype=torch.float32,
        device=device,
    )
    image_size = 16
    points_per_pixel = 1
    radius = torch.tensor([0.1, 0.0, 0.2], dtype=torch.float32, device=device)
    pointclouds = Pointclouds(points=[points])

    if bin_size == -1:
        # simple python case with no binning
        idx, zbuf, dists = rasterize_points_fn(
            pointclouds, image_size, radius, points_per_pixel
        )
    else:
        idx, zbuf, dists = rasterize_points_fn(
            pointclouds, image_size, radius, points_per_pixel, bin_size
        )

    out_shape = (1, image_size, image_size, 1)

    # Expected point index per pixel; -1 marks pixels covered by no point.
    idx_expected = torch.full(
        out_shape, fill_value=-1, dtype=torch.int64, device=device
    )
    # Point 0 (radius 0.1) covers a 2x2 block of pixels.
    idx_expected[0, 3:5, 3:5, 0] = 0
    # Point 2 (radius 0.2) covers a plus-shaped region: a 4x2 block
    # overlapped with a 2x4 block.
    idx_expected[0, 6:10, 7:9, 0] = 2
    idx_expected[0, 7:9, 6:10, 0] = 2

    # Both rendered points have z = 0.3.
    zbuf_expected = torch.full(
        out_shape, fill_value=-1, dtype=torch.float32, device=device
    )
    zbuf_expected[idx_expected == 0] = 0.3
    zbuf_expected[idx_expected == 2] = 0.3

    # Expected squared distances from each pixel to the point covering it.
    dists_expected = torch.full(
        out_shape, fill_value=-1, dtype=torch.float32, device=device
    )
    dists_expected[0, 3:5, 3:5, 0] = 0.0078
    # All pixels of point 2 first get the outer-ring value, then the four
    # innermost pixels are overwritten with the smaller distance.
    dists_expected[idx_expected == 2] = 0.0391
    dists_expected[0, 7:9, 7:9, 0] = 0.0078

    # Check the distances for a point are less than the squared radius
    # for that point.
    self.assertTrue((dists[idx == 0] < radius[0] ** 2).all())
    self.assertTrue((dists[idx == 2] < radius[2] ** 2).all())

    # Check all values are correct.
    idx_same = (idx == idx_expected).all().item() == 1
    zbuf_same = (zbuf == zbuf_expected).all().item() == 1
    self.assertTrue(idx_same)
    self.assertTrue(zbuf_same)
    self.assertClose(dists, dists_expected, atol=4e-5)
def test_radius_format_failure(self):
    """Check that `_format_radius` raises ValueError for bad radius inputs."""
    num_clouds = 20
    max_points = 15
    # Each cloud gets a random point count drawn from [1, max_points).
    clouds = [
        torch.randn((torch.randint(low=1, high=max_points, size=(1,))[0], 3))
        for _ in range(num_clouds)
    ]
    points = Pointclouds(points=clouds)

    # Incorrect shape
    shape_error = self.assertRaisesRegex(ValueError, "radius must be of shape")
    with shape_error:
        _format_radius([0, 1, 2], points)

    # Incorrect type
    type_error = self.assertRaisesRegex(ValueError, "float, list, tuple or tensor")
    with type_error:
        _format_radius({0: [0, 1, 2]}, points)