Update point cloud rasterizer to support heterogeneous point clouds

Summary:
Update the point cloud rasterizer to:
- use the pointcloud datastructure (rebased on top of D19791851.)
- support rasterization of heterogeneous point clouds in the same way as with Meshes.

The main changes to the API will be as follows:
- The input to `rasterize_points` will be a `Pointclouds` object instead of a tensor. This will be easy to update e.g.
```
points = torch.randn(N, P, 3)
idx2, zbuf2, dists2 = rasterize_points(points, image_size, radius, points_per_pixel)

points = torch.randn(N, P, 3)
pointclouds = Pointclouds(points=points)
idx2, zbuf2, dists2 = rasterize_points(pointclouds, image_size, radius, points_per_pixel)
```

- The indices output from rasterization will now refer to points in `poinclouds.points_packed()`.
This may require some changes to the functions which consume the outputs of rasterization if they were previously
assuming that the indices ranged from 0 to P where P is the number of points in each pointcloud.

Making this change now so that Olivia can update her PR accordingly.

Reviewed By: gkioxari

Differential Revision: D20088651

fbshipit-source-id: 833ed659909712bcbbb6a50e2ec0189839f0413a
This commit is contained in:
Nikhila Ravi 2020-03-12 07:46:15 -07:00 committed by Facebook GitHub Bot
parent cae325718e
commit 32ad869dea
4 changed files with 241 additions and 102 deletions

View File

@ -102,16 +102,16 @@ __device__ bool CheckPointOutsideBoundingBox(
// RasterizeMeshesFineCudaKernel.
template <typename FaceQ>
__device__ void CheckPixelInsideFace(
const float* face_verts, // (N, P, 3)
int face_idx,
const float* face_verts, // (F, 3, 3)
const int face_idx,
int& q_size,
float& q_max_z,
int& q_max_idx,
FaceQ& q,
float blur_radius,
float2 pxy, // Coordinates of the pixel
int K,
bool perspective_correct) {
const float blur_radius,
const float2 pxy, // Coordinates of the pixel
const int K,
const bool perspective_correct) {
const auto v012 = GetSingleFaceVerts(face_verts, face_idx);
const float3 v0 = thrust::get<0>(v012);
const float3 v1 = thrust::get<1>(v012);
@ -335,7 +335,6 @@ __global__ void RasterizeMeshesBackwardCudaKernel(
const int64_t* pix_to_face, // (N, H, W, K)
const bool perspective_correct,
const int N,
const int F,
const int H,
const int W,
const int K,
@ -472,7 +471,6 @@ torch::Tensor RasterizeMeshesBackwardCuda(
pix_to_face.contiguous().data<int64_t>(),
perspective_correct,
N,
F,
H,
W,
K,
@ -671,7 +669,6 @@ __global__ void RasterizeMeshesFineCudaKernel(
const int bin_size,
const bool perspective_correct,
const int N,
const int F,
const int B,
const int M,
const int H,
@ -774,7 +771,6 @@ RasterizeMeshesFineCuda(
if (bin_faces.ndimension() != 4) {
AT_ERROR("bin_faces must have 4 dimensions");
}
const int F = face_verts.size(0);
const int N = bin_faces.size(0);
const int B = bin_faces.size(1);
const int M = bin_faces.size(3);
@ -803,7 +799,6 @@ RasterizeMeshesFineCuda(
bin_size,
perspective_correct,
N,
F,
B,
M,
H,

View File

@ -30,8 +30,8 @@ __device__ inline bool operator<(const Pix& a, const Pix& b) {
// RasterizePointsFineCudaKernel.
template <typename PointQ>
__device__ void CheckPixelInsidePoint(
const float* points, // (N, P, 3)
const int p,
const float* points, // (P, 3)
const int p_idx,
int& q_size,
float& q_max_z,
int& q_max_idx,
@ -39,12 +39,10 @@ __device__ void CheckPixelInsidePoint(
const float radius2,
const float xf,
const float yf,
const int n,
const int P,
const int K) {
const float px = points[n * P * 3 + p * 3 + 0];
const float py = points[n * P * 3 + p * 3 + 1];
const float pz = points[n * P * 3 + p * 3 + 2];
const float px = points[p_idx * 3 + 0];
const float py = points[p_idx * 3 + 1];
const float pz = points[p_idx * 3 + 2];
if (pz < 0)
return; // Don't render points behind the camera
const float dx = xf - px;
@ -53,7 +51,7 @@ __device__ void CheckPixelInsidePoint(
if (dist2 < radius2) {
if (q_size < K) {
// Just insert it
q[q_size] = {pz, p, dist2};
q[q_size] = {pz, p_idx, dist2};
if (pz > q_max_z) {
q_max_z = pz;
q_max_idx = q_size;
@ -61,7 +59,7 @@ __device__ void CheckPixelInsidePoint(
q_size++;
} else if (pz < q_max_z) {
// Overwrite the old max, and find the new max
q[q_max_idx] = {pz, p, dist2};
q[q_max_idx] = {pz, p_idx, dist2};
q_max_z = pz;
for (int i = 0; i < K; i++) {
if (q[i].z > q_max_z) {
@ -78,10 +76,11 @@ __device__ void CheckPixelInsidePoint(
// ****************************************************************************
__global__ void RasterizePointsNaiveCudaKernel(
const float* points, // (N, P, 3)
const float* points, // (P, 3)
const int64_t* cloud_to_packed_first_idx, // (N)
const int64_t* num_points_per_cloud, // (N)
const float radius,
const int N,
const int P,
const int S,
const int K,
int32_t* point_idxs, // (N, S, S, K)
@ -116,9 +115,15 @@ __global__ void RasterizePointsNaiveCudaKernel(
int q_size = 0;
float q_max_z = -1000;
int q_max_idx = -1;
for (int p = 0; p < P; ++p) {
// Using the batch index of the thread get the start and stop
// indices for the points.
const int64_t point_start_idx = cloud_to_packed_first_idx[n];
const int64_t point_stop_idx = point_start_idx + num_points_per_cloud[n];
for (int p_idx = point_start_idx; p_idx < point_stop_idx; ++p_idx) {
CheckPixelInsidePoint(
points, p, q_size, q_max_z, q_max_idx, q, radius2, xf, yf, n, P, K);
points, p_idx, q_size, q_max_z, q_max_idx, q, radius2, xf, yf, K);
}
BubbleSort(q, q_size);
int idx = n * S * S * K + yi * S * K + xi * K;
@ -132,14 +137,24 @@ __global__ void RasterizePointsNaiveCudaKernel(
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>
RasterizePointsNaiveCuda(
const torch::Tensor& points,
const torch::Tensor& points, // (P. 3)
const torch::Tensor& cloud_to_packed_first_idx, // (N)
const torch::Tensor& num_points_per_cloud, // (N)
const int image_size,
const float radius,
const int points_per_pixel) {
const int N = points.size(0);
const int P = points.size(1);
if (points.ndimension() != 2 || points.size(1) != 3) {
AT_ERROR("points must have dimensions (num_points, 3)");
}
if (num_points_per_cloud.size(0) != cloud_to_packed_first_idx.size(0)) {
AT_ERROR(
"num_points_per_cloud must have same size first dimension as cloud_to_packed_first_idx");
}
const int N = num_points_per_cloud.size(0); // batch size.
const int S = image_size;
const int K = points_per_pixel;
if (K > kMaxPointsPerPixel) {
std::stringstream ss;
ss << "Must have points_per_pixel <= " << kMaxPointsPerPixel;
@ -156,9 +171,10 @@ RasterizePointsNaiveCuda(
const size_t threads = 64;
RasterizePointsNaiveCudaKernel<<<blocks, threads>>>(
points.contiguous().data<float>(),
cloud_to_packed_first_idx.contiguous().data<int64_t>(),
num_points_per_cloud.contiguous().data<int64_t>(),
radius,
N,
P,
S,
K,
point_idxs.contiguous().data<int32_t>(),
@ -172,7 +188,9 @@ RasterizePointsNaiveCuda(
// ****************************************************************************
__global__ void RasterizePointsCoarseCudaKernel(
const float* points,
const float* points, // (P, 3)
const int64_t* cloud_to_packed_first_idx, // (N)
const int64_t* num_points_per_cloud, // (N)
const float radius,
const int N,
const int P,
@ -206,16 +224,27 @@ __global__ void RasterizePointsCoarseCudaKernel(
binmask.block_clear();
// Using the batch index of the thread get the start and stop
// indices for the points.
const int64_t cloud_point_start_idx = cloud_to_packed_first_idx[batch_idx];
const int64_t cloud_point_stop_idx =
cloud_point_start_idx + num_points_per_cloud[batch_idx];
// Have each thread handle a different point within the chunk
for (int p = threadIdx.x; p < chunk_size; p += blockDim.x) {
const int p_idx = point_start_idx + p;
if (p_idx >= P)
break;
const float px = points[batch_idx * P * 3 + p_idx * 3 + 0];
const float py = points[batch_idx * P * 3 + p_idx * 3 + 1];
const float pz = points[batch_idx * P * 3 + p_idx * 3 + 2];
// Check if point index corresponds to the cloud in the batch given by
// batch_idx.
if (p_idx >= cloud_point_stop_idx || p_idx < cloud_point_start_idx) {
continue;
}
const float px = points[p_idx * 3 + 0];
const float py = points[p_idx * 3 + 1];
const float pz = points[p_idx * 3 + 2];
if (pz < 0)
continue; // Don't render points behind the camera
continue; // Don't render points behind the camera.
const float px0 = px - radius;
const float px1 = px + radius;
const float py0 = py - radius;
@ -283,15 +312,20 @@ __global__ void RasterizePointsCoarseCudaKernel(
}
torch::Tensor RasterizePointsCoarseCuda(
const torch::Tensor& points,
const torch::Tensor& points, // (P, 3)
const torch::Tensor& cloud_to_packed_first_idx, // (N)
const torch::Tensor& num_points_per_cloud, // (N)
const int image_size,
const float radius,
const int bin_size,
const int max_points_per_bin) {
const int N = points.size(0);
const int P = points.size(1);
const int P = points.size(0);
const int N = num_points_per_cloud.size(0);
const int num_bins = 1 + (image_size - 1) / bin_size; // divide round up
const int M = max_points_per_bin;
if (points.ndimension() != 2 || points.size(1) != 3) {
AT_ERROR("points must have dimensions (num_points, 3)");
}
if (num_bins >= 22) {
// Make sure we do not use too much shared memory.
std::stringstream ss;
@ -307,6 +341,8 @@ torch::Tensor RasterizePointsCoarseCuda(
const size_t threads = 512;
RasterizePointsCoarseCudaKernel<<<blocks, threads, shared_size>>>(
points.contiguous().data<float>(),
cloud_to_packed_first_idx.contiguous().data<int64_t>(),
num_points_per_cloud.contiguous().data<int64_t>(),
radius,
N,
P,
@ -324,12 +360,11 @@ torch::Tensor RasterizePointsCoarseCuda(
// ****************************************************************************
__global__ void RasterizePointsFineCudaKernel(
const float* points, // (N, P, 3)
const float* points, // (P, 3)
const int32_t* bin_points, // (N, B, B, T)
const float radius,
const int bin_size,
const int N,
const int P,
const int B,
const int M,
const int S,
@ -342,6 +377,7 @@ __global__ void RasterizePointsFineCudaKernel(
const int num_threads = gridDim.x * blockDim.x;
const int tid = blockIdx.x * blockDim.x + threadIdx.x;
const float radius2 = radius * radius;
for (int pid = tid; pid < num_pixels; pid += num_threads) {
// Convert linear index into bin and pixel indices. We make the within
// block pixel ids move the fastest, so that adjacent threads will fall
@ -377,7 +413,7 @@ __global__ void RasterizePointsFineCudaKernel(
continue;
}
CheckPixelInsidePoint(
points, p, q_size, q_max_z, q_max_idx, q, radius2, xf, yf, n, P, K);
points, p, q_size, q_max_z, q_max_idx, q, radius2, xf, yf, K);
}
// Now we've looked at all the points for this bin, so we can write
// output for the current pixel.
@ -392,14 +428,13 @@ __global__ void RasterizePointsFineCudaKernel(
}
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsFineCuda(
const torch::Tensor& points,
const torch::Tensor& points, // (P, 3)
const torch::Tensor& bin_points,
const int image_size,
const float radius,
const int bin_size,
const int points_per_pixel) {
const int N = points.size(0);
const int P = points.size(1);
const int N = bin_points.size(0);
const int B = bin_points.size(1);
const int M = bin_points.size(3);
const int S = image_size;
@ -421,7 +456,6 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsFineCuda(
radius,
bin_size,
N,
P,
B,
M,
S,
@ -438,7 +472,7 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsFineCuda(
// ****************************************************************************
// TODO(T55115174) Add more documentation for backward kernel.
__global__ void RasterizePointsBackwardCudaKernel(
const float* points, // (N, P, 3)
const float* points, // (P, 3)
const int32_t* idxs, // (N, H, W, K)
const int N,
const int P,
@ -447,13 +481,13 @@ __global__ void RasterizePointsBackwardCudaKernel(
const int K,
const float* grad_zbuf, // (N, H, W, K)
const float* grad_dists, // (N, H, W, K)
float* grad_points) { // (N, P, 3)
float* grad_points) { // (P, 3)
// Parallelized over each of K points per pixel, for each pixel in images of
// size H * W, for each image in the batch of size N.
int num_threads = gridDim.x * blockDim.x;
int tid = blockIdx.x * blockDim.x + threadIdx.x;
for (int i = tid; i < N * H * W * K; i += num_threads) {
const int n = i / (H * W * K);
// const int n = i / (H * W * K); // batch index (not needed).
const int yxk = i % (H * W * K);
const int yi = yxk / (W * K);
const int xk = yxk % (W * K);
@ -466,15 +500,15 @@ __global__ void RasterizePointsBackwardCudaKernel(
if (p < 0)
continue;
const float grad_dist2 = grad_dists[i];
const int p_ind = n * P * 3 + p * 3;
const float px = points[p_ind];
const int p_ind = p * 3; // index into packed points tensor
const float px = points[p_ind + 0];
const float py = points[p_ind + 1];
const float dx = px - xf;
const float dy = py - yf;
const float grad_px = 2.0f * grad_dist2 * dx;
const float grad_py = 2.0f * grad_dist2 * dy;
const float grad_pz = grad_zbuf[i];
atomicAdd(grad_points + p_ind, grad_px);
atomicAdd(grad_points + p_ind + 0, grad_px);
atomicAdd(grad_points + p_ind + 1, grad_py);
atomicAdd(grad_points + p_ind + 2, grad_pz);
}
@ -485,13 +519,13 @@ torch::Tensor RasterizePointsBackwardCuda(
const torch::Tensor& idxs, // (N, H, W, K)
const torch::Tensor& grad_zbuf, // (N, H, W, K)
const torch::Tensor& grad_dists) { // (N, H, W, K)
const int N = points.size(0);
const int P = points.size(1);
const int P = points.size(0);
const int N = idxs.size(0);
const int H = idxs.size(1);
const int W = idxs.size(2);
const int K = idxs.size(3);
torch::Tensor grad_points = torch::zeros({N, P, 3}, points.options());
torch::Tensor grad_points = torch::zeros({P, 3}, points.options());
const size_t blocks = 1024;
const size_t threads = 64;

View File

@ -11,6 +11,8 @@
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsNaiveCpu(
const torch::Tensor& points,
const torch::Tensor& cloud_to_packed_first_idx,
const torch::Tensor& num_points_per_cloud,
const int image_size,
const float radius,
const int points_per_pixel);
@ -19,6 +21,8 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsNaiveCpu(
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>
RasterizePointsNaiveCuda(
const torch::Tensor& points,
const torch::Tensor& cloud_to_packed_first_idx,
const torch::Tensor& num_points_per_cloud,
const int image_size,
const float radius,
const int points_per_pixel);
@ -27,16 +31,26 @@ RasterizePointsNaiveCuda(
// check whether that point hits the pixel.
//
// Args:
// points: Tensor of shape (N, P, 3) (in NDC)
// points: Tensor of shape (P, 3) giving (packed) positions for
// points in all N pointclouds in the batch where P is the total
// number of points in the batch across all pointclouds. These points
// are expected to be in NDC coordinates in the range [-1, 1].
// cloud_to_packed_first_idx: LongTensor of shape (N) giving the index in
// points_packed of the first point in each pointcloud
// in the batch where N is the batch size.
// num_points_per_cloud: LongTensor of shape (N) giving the number of points
// for each pointcloud in the batch.
// radius: Radius of each point (in NDC units)
// image_size: (S) Size of the image to return (in pixels)
// image_size: (S) Size of the image to return (in pixels)
// points_per_pixel: (K) The number closest of points to return for each pixel
//
// Returns:
// A 4 element tuple of:
// idxs: int32 Tensor of shape (N, S, S, K) giving the indices of the
// closest K points along the z-axis for each pixel, padded with -1 for
// pixels
// hit by fewer than K points.
// pixels hit by fewer than K points. The indices refer to points in
// points packed i.e a tensor of shape (P, 3) representing the flattened
// points for all pointclouds in the batch.
// zbuf: float32 Tensor of shape (N, S, S, K) giving the depth of each
// closest point for each pixel.
// dists: float32 Tensor of shape (N, S, S, K) giving squared Euclidean
@ -44,19 +58,32 @@ RasterizePointsNaiveCuda(
// points along the z axis.
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsNaive(
const torch::Tensor& points,
const torch::Tensor& cloud_to_packed_first_idx,
const torch::Tensor& num_points_per_cloud,
const int image_size,
const float radius,
const int points_per_pixel) {
if (points.type().is_cuda()) {
if (points.type().is_cuda() && cloud_to_packed_first_idx.type().is_cuda() &&
num_points_per_cloud.type().is_cuda()) {
#ifdef WITH_CUDA
return RasterizePointsNaiveCuda(
points, image_size, radius, points_per_pixel);
points,
cloud_to_packed_first_idx,
num_points_per_cloud,
image_size,
radius,
points_per_pixel);
#else
AT_ERROR("Not compiled with GPU support");
#endif
} else {
return RasterizePointsNaiveCpu(
points, image_size, radius, points_per_pixel);
points,
cloud_to_packed_first_idx,
num_points_per_cloud,
image_size,
radius,
points_per_pixel);
}
}
@ -66,6 +93,8 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsNaive(
torch::Tensor RasterizePointsCoarseCpu(
const torch::Tensor& points,
const torch::Tensor& cloud_to_packed_first_idx,
const torch::Tensor& num_points_per_cloud,
const int image_size,
const float radius,
const int bin_size,
@ -74,13 +103,23 @@ torch::Tensor RasterizePointsCoarseCpu(
#ifdef WITH_CUDA
torch::Tensor RasterizePointsCoarseCuda(
const torch::Tensor& points,
const torch::Tensor& cloud_to_packed_first_idx,
const torch::Tensor& num_points_per_cloud,
const int image_size,
const float radius,
const int bin_size,
const int max_points_per_bin);
#endif
// Args:
// points: Tensor of shape (N, P, 3)
// points: Tensor of shape (P, 3) giving (packed) positions for
// points in all N pointclouds in the batch where P is the total
// number of points in the batch across all pointclouds. These points
// are expected to be in NDC coordinates in the range [-1, 1].
// cloud_to_packed_first_idx: LongTensor of shape (N) giving the index in
// points_packed of the first point in each pointcloud
// in the batch where N is the batch size.
// num_points_per_cloud: LongTensor of shape (N) giving the number of points
// for each pointcloud in the batch.
// radius: Radius of points to rasterize (in NDC units)
// image_size: Size of the image to generate (in pixels)
// bin_size: Size of each bin within the image (in pixels)
@ -92,20 +131,35 @@ torch::Tensor RasterizePointsCoarseCuda(
// of points that fall into each bin.
torch::Tensor RasterizePointsCoarse(
const torch::Tensor& points,
const torch::Tensor& cloud_to_packed_first_idx,
const torch::Tensor& num_points_per_cloud,
const int image_size,
const float radius,
const int bin_size,
const int max_points_per_bin) {
if (points.type().is_cuda()) {
if (points.type().is_cuda() && cloud_to_packed_first_idx.type().is_cuda() &&
num_points_per_cloud.type().is_cuda()) {
#ifdef WITH_CUDA
return RasterizePointsCoarseCuda(
points, image_size, radius, bin_size, max_points_per_bin);
points,
cloud_to_packed_first_idx,
num_points_per_cloud,
image_size,
radius,
bin_size,
max_points_per_bin);
#else
AT_ERROR("Not compiled with GPU support");
#endif
} else {
return RasterizePointsCoarseCpu(
points, image_size, radius, bin_size, max_points_per_bin);
points,
cloud_to_packed_first_idx,
num_points_per_cloud,
image_size,
radius,
bin_size,
max_points_per_bin);
}
}
@ -123,7 +177,10 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsFineCuda(
const int points_per_pixel);
#endif
// Args:
// points: float32 Tensor of shape (N, P, 3)
// points: Tensor of shape (P, 3) giving (packed) positions for
// points in all N pointclouds in the batch where P is the total
// number of points in the batch across all pointclouds. These points
// are expected to be in NDC coordinates in the range [-1, 1].
// bin_points: int32 Tensor of shape (N, B, B, M) giving the indices of points
// that fall into each bin (output from coarse rasterization)
// image_size: Size of image to generate (in pixels)
@ -132,9 +189,11 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsFineCuda(
// points_per_pixel: How many points to rasterize for each pixel
//
// Returns (same as rasterize_points):
// idxs: int32 Tensor of shape (N, S, S, K) giving the indices of the closest
// points_per_pixel points along the z-axis for each pixel, padded with
// -1 for pixels hit by fewer than points_per_pixel points
// idxs: int32 Tensor of shape (N, S, S, K) giving the indices of the
// closest K points along the z-axis for each pixel, padded with -1 for
// pixels hit by fewer than K points. The indices refer to points in
// points packed i.e a tensor of shape (P, 3) representing the flattened
// points for all pointclouds in the batch.
// zbuf: float32 Tensor of shape (N, S, S, K) giving the depth of each of each
// closest point for each pixel
// dists: float32 Tensor of shape (N, S, S, K) giving squared Euclidean
@ -177,7 +236,10 @@ torch::Tensor RasterizePointsBackwardCuda(
const torch::Tensor& grad_dists);
#endif
// Args:
// points: float32 Tensor of shape (N, P, 3)
// points: Tensor of shape (P, 3) giving (packed) positions for
// points in all N pointclouds in the batch where P is the total
// number of points in the batch across all pointclouds. These points
// are expected to be in NDC coordinates in the range [-1, 1].
// idxs: int32 Tensor of shape (N, H, W, K) (from forward pass)
// grad_zbuf: float32 Tensor of shape (N, H, W, K) giving upstream gradient
// d(loss)/d(zbuf) of the distances from each pixel to its nearest
@ -212,7 +274,15 @@ torch::Tensor RasterizePointsBackward(
// it uses either naive or coarse-to-fine rasterization based on bin_size.
//
// Args:
// points: Tensor of shape (N, P, 3) (in NDC)
// points: Tensor of shape (P, 3) giving (packed) positions for
// points in all N pointclouds in the batch where P is the total
// number of points in the batch across all pointclouds. These points
// are expected to be in NDC coordinates in the range [-1, 1].
// cloud_to_packed_first_idx: LongTensor of shape (N) giving the index in
// points_packed of the first point in each pointcloud
// in the batch where N is the batch size.
// num_points_per_cloud: LongTensor of shape (N) giving the number of points
// for each pointcloud in the batch.
// radius: Radius of each point (in NDC units)
// image_size: (S) Size of the image to return (in pixels)
// points_per_pixel: (K) The number of points to return for each pixel
@ -223,8 +293,10 @@ torch::Tensor RasterizePointsBackward(
//
// Returns:
// idxs: int32 Tensor of shape (N, S, S, K) giving the indices of the
// closest points_per_pixel points along the z-axis for each pixel,
// padded with -1 for pixels hit by fewer than points_per_pixel points
// closest K points along the z-axis for each pixel, padded with -1 for
// pixels hit by fewer than K points. The indices refer to points in
// points packed i.e a tensor of shape (P, 3) representing the flattened
// points for all pointclouds in the batch.
// zbuf: float32 Tensor of shape (N, S, S, K) giving the depth of each of each
// closest point for each pixel
// dists: float32 Tensor of shape (N, S, S, K) giving squared Euclidean
@ -232,6 +304,8 @@ torch::Tensor RasterizePointsBackward(
// points along the z axis.
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePoints(
const torch::Tensor& points,
const torch::Tensor& cloud_to_packed_first_idx,
const torch::Tensor& num_points_per_cloud,
const int image_size,
const float radius,
const int points_per_pixel,
@ -239,11 +313,23 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePoints(
const int max_points_per_bin) {
if (bin_size == 0) {
// Use the naive per-pixel implementation
return RasterizePointsNaive(points, image_size, radius, points_per_pixel);
return RasterizePointsNaive(
points,
cloud_to_packed_first_idx,
num_points_per_cloud,
image_size,
radius,
points_per_pixel);
} else {
// Use coarse-to-fine rasterization
const auto bin_points = RasterizePointsCoarse(
points, image_size, radius, bin_size, max_points_per_bin);
points,
cloud_to_packed_first_idx,
num_points_per_cloud,
image_size,
radius,
bin_size,
max_points_per_bin);
return RasterizePointsFine(
points, bin_points, image_size, radius, bin_size, points_per_pixel);
}

View File

@ -13,37 +13,49 @@ static float PixToNdc(const int i, const int S) {
}
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsNaiveCpu(
const torch::Tensor& points,
const torch::Tensor& points, // (P, 3)
const torch::Tensor& cloud_to_packed_first_idx, // (N)
const torch::Tensor& num_points_per_cloud, // (N)
const int image_size,
const float radius,
const int points_per_pixel) {
const int N = points.size(0);
const int P = points.size(1);
const int32_t N = cloud_to_packed_first_idx.size(0); // batch_size.
const int S = image_size;
const int K = points_per_pixel;
// Initialize output tensors.
auto int_opts = points.options().dtype(torch::kInt32);
auto float_opts = points.options().dtype(torch::kFloat32);
torch::Tensor point_idxs = torch::full({N, S, S, K}, -1, int_opts);
torch::Tensor zbuf = torch::full({N, S, S, K}, -1, float_opts);
torch::Tensor pix_dists = torch::full({N, S, S, K}, -1, float_opts);
auto points_a = points.accessor<float, 3>();
auto points_a = points.accessor<float, 2>();
auto point_idxs_a = point_idxs.accessor<int32_t, 4>();
auto zbuf_a = zbuf.accessor<float, 4>();
auto pix_dists_a = pix_dists.accessor<float, 4>();
const float radius2 = radius * radius;
for (int n = 0; n < N; ++n) {
// Loop through each pointcloud in the batch.
// Get the start index of the points in points_packed and the num points
// in the point cloud.
const int point_start_idx =
cloud_to_packed_first_idx[n].item().to<int32_t>();
const int point_stop_idx =
(point_start_idx + num_points_per_cloud[n].item().to<int32_t>());
for (int yi = 0; yi < S; ++yi) {
float yf = PixToNdc(yi, S);
for (int xi = 0; xi < S; ++xi) {
float xf = PixToNdc(xi, S);
// Use a priority queue to hold (z, idx, r)
std::priority_queue<std::tuple<float, int, float>> q;
for (int p = 0; p < P; ++p) {
const float px = points_a[n][p][0];
const float py = points_a[n][p][1];
const float pz = points_a[n][p][2];
for (int p = point_start_idx; p < point_stop_idx; ++p) {
const float px = points_a[p][0];
const float py = points_a[p][1];
const float pz = points_a[p][2];
if (pz < 0) {
continue;
}
@ -75,26 +87,37 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsNaiveCpu(
}
torch::Tensor RasterizePointsCoarseCpu(
const torch::Tensor& points,
const torch::Tensor& points, // (P, 3)
const torch::Tensor& cloud_to_packed_first_idx, // (N)
const torch::Tensor& num_points_per_cloud, // (N)
const int image_size,
const float radius,
const int bin_size,
const int max_points_per_bin) {
const int N = points.size(0);
const int P = points.size(1);
const int32_t N = cloud_to_packed_first_idx.size(0); // batch_size.
const int B = 1 + (image_size - 1) / bin_size; // Integer division round up
const int M = max_points_per_bin;
auto opts = points.options().dtype(torch::kInt32);
torch::Tensor points_per_bin = torch::zeros({N, B, B}, opts);
torch::Tensor bin_points = torch::full({N, B, B, M}, -1, opts);
auto points_a = points.accessor<float, 3>();
auto points_a = points.accessor<float, 2>();
auto points_per_bin_a = points_per_bin.accessor<int32_t, 3>();
auto bin_points_a = bin_points.accessor<int32_t, 4>();
const float pixel_width = 2.0f / image_size;
const float bin_width = pixel_width * bin_size;
for (int n = 0; n < N; ++n) {
// Loop through each pointcloud in the batch.
// Get the start index of the points in points_packed and the num points
// in the point cloud.
const int point_start_idx =
cloud_to_packed_first_idx[n].item().to<int32_t>();
const int point_stop_idx =
(point_start_idx + num_points_per_cloud[n].item().to<int32_t>());
float bin_y_min = -1.0f;
float bin_y_max = bin_y_min + bin_width;
for (int by = 0; by < B; by++) {
@ -102,10 +125,10 @@ torch::Tensor RasterizePointsCoarseCpu(
float bin_x_max = bin_x_min + bin_width;
for (int bx = 0; bx < B; bx++) {
int32_t points_hit = 0;
for (int32_t p = 0; p < P; p++) {
float px = points_a[n][p][0];
float py = points_a[n][p][1];
float pz = points_a[n][p][2];
for (int p = point_start_idx; p < point_stop_idx; ++p) {
float px = points_a[p][0];
float py = points_a[p][1];
float pz = points_a[p][2];
if (pz < 0) {
continue;
}
@ -144,12 +167,13 @@ torch::Tensor RasterizePointsCoarseCpu(
}
torch::Tensor RasterizePointsBackwardCpu(
const torch::Tensor& points, // (N, P, 3)
const torch::Tensor& points, // (P, 3)
const torch::Tensor& idxs, // (N, H, W, K)
const torch::Tensor& grad_zbuf, // (N, H, W, K)
const torch::Tensor& grad_dists) { // (N, H, W, K)
const int N = points.size(0);
const int P = points.size(1);
const int N = idxs.size(0);
const int P = points.size(0);
const int H = idxs.size(1);
const int W = idxs.size(2);
const int K = idxs.size(3);
@ -159,13 +183,13 @@ torch::Tensor RasterizePointsBackwardCpu(
if (H != W) {
AT_ERROR("RasterizePointsBackwardCpu only supports square images");
}
torch::Tensor grad_points = torch::zeros({N, P, 3}, points.options());
torch::Tensor grad_points = torch::zeros({P, 3}, points.options());
auto points_a = points.accessor<float, 3>();
auto points_a = points.accessor<float, 2>();
auto idxs_a = idxs.accessor<int32_t, 4>();
auto grad_dists_a = grad_dists.accessor<float, 4>();
auto grad_zbuf_a = grad_zbuf.accessor<float, 4>();
auto grad_points_a = grad_points.accessor<float, 3>();
auto grad_points_a = grad_points.accessor<float, 2>();
for (int n = 0; n < N; ++n) { // Loop over images in the batch
for (int y = 0; y < H; ++y) { // Loop over rows in the image
@ -178,16 +202,16 @@ torch::Tensor RasterizePointsBackwardCpu(
break;
}
const float grad_dist2 = grad_dists_a[n][y][x][k];
const float px = points_a[n][p][0];
const float py = points_a[n][p][1];
const float px = points_a[p][0];
const float py = points_a[p][1];
const float dx = px - xf;
const float dy = py - yf;
// Remember: dists[n][y][x][k] = dx * dx + dy * dy;
const float grad_px = 2.0f * grad_dist2 * dx;
const float grad_py = 2.0f * grad_dist2 * dy;
grad_points_a[n][p][0] += grad_px;
grad_points_a[n][p][1] += grad_py;
grad_points_a[n][p][2] += grad_zbuf_a[n][y][x][k];
grad_points_a[p][0] += grad_px;
grad_points_a[p][1] += grad_py;
grad_points_a[p][2] += grad_zbuf_a[n][y][x][k];
}
}
}