From 32ad869deae90622e39be53496e57e779f8f07b9 Mon Sep 17 00:00:00 2001 From: Nikhila Ravi Date: Thu, 12 Mar 2020 07:46:15 -0700 Subject: [PATCH] Update point cloud rasterizer to support heterogeneous point clouds Summary: Update the point cloud rasterizer to: - use the pointcloud datastructure (rebased on top of D19791851.) - support rasterization of heterogeneous point clouds in the same way as with Meshes. The main changes to the API will be as follows: - The input to `rasterize_points` will be a `Pointclouds` object instead of a tensor. This will be easy to update e.g. ``` points = torch.randn(N, P, 3) idx2, zbuf2, dists2 = rasterize_points(points, image_size, radius, points_per_pixel) points = torch.randn(N, P, 3) pointclouds = Pointclouds(points=points) idx2, zbuf2, dists2 = rasterize_points(pointclouds, image_size, radius, points_per_pixel) ``` - The indices output from rasterization will now refer to points in `pointclouds.points_packed()`. This may require some changes to the functions which consume the outputs of rasterization if they were previously assuming that the indices ranged from 0 to P where P is the number of points in each pointcloud. Making this change now so that Olivia can update her PR accordingly. 
Reviewed By: gkioxari Differential Revision: D20088651 fbshipit-source-id: 833ed659909712bcbbb6a50e2ec0189839f0413a --- .../csrc/rasterize_meshes/rasterize_meshes.cu | 17 +-- .../csrc/rasterize_points/rasterize_points.cu | 120 ++++++++++------ .../csrc/rasterize_points/rasterize_points.h | 128 +++++++++++++++--- .../rasterize_points/rasterize_points_cpu.cpp | 78 +++++++---- 4 files changed, 241 insertions(+), 102 deletions(-) diff --git a/pytorch3d/csrc/rasterize_meshes/rasterize_meshes.cu b/pytorch3d/csrc/rasterize_meshes/rasterize_meshes.cu index ba130d9d..315c7130 100644 --- a/pytorch3d/csrc/rasterize_meshes/rasterize_meshes.cu +++ b/pytorch3d/csrc/rasterize_meshes/rasterize_meshes.cu @@ -102,16 +102,16 @@ __device__ bool CheckPointOutsideBoundingBox( // RasterizeMeshesFineCudaKernel. template __device__ void CheckPixelInsideFace( - const float* face_verts, // (N, P, 3) - int face_idx, + const float* face_verts, // (F, 3, 3) + const int face_idx, int& q_size, float& q_max_z, int& q_max_idx, FaceQ& q, - float blur_radius, - float2 pxy, // Coordinates of the pixel - int K, - bool perspective_correct) { + const float blur_radius, + const float2 pxy, // Coordinates of the pixel + const int K, + const bool perspective_correct) { const auto v012 = GetSingleFaceVerts(face_verts, face_idx); const float3 v0 = thrust::get<0>(v012); const float3 v1 = thrust::get<1>(v012); @@ -335,7 +335,6 @@ __global__ void RasterizeMeshesBackwardCudaKernel( const int64_t* pix_to_face, // (N, H, W, K) const bool perspective_correct, const int N, - const int F, const int H, const int W, const int K, @@ -472,7 +471,6 @@ torch::Tensor RasterizeMeshesBackwardCuda( pix_to_face.contiguous().data(), perspective_correct, N, - F, H, W, K, @@ -671,7 +669,6 @@ __global__ void RasterizeMeshesFineCudaKernel( const int bin_size, const bool perspective_correct, const int N, - const int F, const int B, const int M, const int H, @@ -774,7 +771,6 @@ RasterizeMeshesFineCuda( if (bin_faces.ndimension() != 4) 
{ AT_ERROR("bin_faces must have 4 dimensions"); } - const int F = face_verts.size(0); const int N = bin_faces.size(0); const int B = bin_faces.size(1); const int M = bin_faces.size(3); @@ -803,7 +799,6 @@ RasterizeMeshesFineCuda( bin_size, perspective_correct, N, - F, B, M, H, diff --git a/pytorch3d/csrc/rasterize_points/rasterize_points.cu b/pytorch3d/csrc/rasterize_points/rasterize_points.cu index 4dde4699..d04be892 100644 --- a/pytorch3d/csrc/rasterize_points/rasterize_points.cu +++ b/pytorch3d/csrc/rasterize_points/rasterize_points.cu @@ -30,8 +30,8 @@ __device__ inline bool operator<(const Pix& a, const Pix& b) { // RasterizePointsFineCudaKernel. template __device__ void CheckPixelInsidePoint( - const float* points, // (N, P, 3) - const int p, + const float* points, // (P, 3) + const int p_idx, int& q_size, float& q_max_z, int& q_max_idx, @@ -39,12 +39,10 @@ __device__ void CheckPixelInsidePoint( const float radius2, const float xf, const float yf, - const int n, - const int P, const int K) { - const float px = points[n * P * 3 + p * 3 + 0]; - const float py = points[n * P * 3 + p * 3 + 1]; - const float pz = points[n * P * 3 + p * 3 + 2]; + const float px = points[p_idx * 3 + 0]; + const float py = points[p_idx * 3 + 1]; + const float pz = points[p_idx * 3 + 2]; if (pz < 0) return; // Don't render points behind the camera const float dx = xf - px; @@ -53,7 +51,7 @@ __device__ void CheckPixelInsidePoint( if (dist2 < radius2) { if (q_size < K) { // Just insert it - q[q_size] = {pz, p, dist2}; + q[q_size] = {pz, p_idx, dist2}; if (pz > q_max_z) { q_max_z = pz; q_max_idx = q_size; @@ -61,7 +59,7 @@ __device__ void CheckPixelInsidePoint( q_size++; } else if (pz < q_max_z) { // Overwrite the old max, and find the new max - q[q_max_idx] = {pz, p, dist2}; + q[q_max_idx] = {pz, p_idx, dist2}; q_max_z = pz; for (int i = 0; i < K; i++) { if (q[i].z > q_max_z) { @@ -78,10 +76,11 @@ __device__ void CheckPixelInsidePoint( // 
**************************************************************************** __global__ void RasterizePointsNaiveCudaKernel( - const float* points, // (N, P, 3) + const float* points, // (P, 3) + const int64_t* cloud_to_packed_first_idx, // (N) + const int64_t* num_points_per_cloud, // (N) const float radius, const int N, - const int P, const int S, const int K, int32_t* point_idxs, // (N, S, S, K) @@ -116,9 +115,15 @@ __global__ void RasterizePointsNaiveCudaKernel( int q_size = 0; float q_max_z = -1000; int q_max_idx = -1; - for (int p = 0; p < P; ++p) { + + // Using the batch index of the thread get the start and stop + // indices for the points. + const int64_t point_start_idx = cloud_to_packed_first_idx[n]; + const int64_t point_stop_idx = point_start_idx + num_points_per_cloud[n]; + + for (int p_idx = point_start_idx; p_idx < point_stop_idx; ++p_idx) { CheckPixelInsidePoint( - points, p, q_size, q_max_z, q_max_idx, q, radius2, xf, yf, n, P, K); + points, p_idx, q_size, q_max_z, q_max_idx, q, radius2, xf, yf, K); } BubbleSort(q, q_size); int idx = n * S * S * K + yi * S * K + xi * K; @@ -132,14 +137,24 @@ __global__ void RasterizePointsNaiveCudaKernel( std::tuple RasterizePointsNaiveCuda( - const torch::Tensor& points, + const torch::Tensor& points, // (P. 3) + const torch::Tensor& cloud_to_packed_first_idx, // (N) + const torch::Tensor& num_points_per_cloud, // (N) const int image_size, const float radius, const int points_per_pixel) { - const int N = points.size(0); - const int P = points.size(1); + if (points.ndimension() != 2 || points.size(1) != 3) { + AT_ERROR("points must have dimensions (num_points, 3)"); + } + if (num_points_per_cloud.size(0) != cloud_to_packed_first_idx.size(0)) { + AT_ERROR( + "num_points_per_cloud must have same size first dimension as cloud_to_packed_first_idx"); + } + + const int N = num_points_per_cloud.size(0); // batch size. 
const int S = image_size; const int K = points_per_pixel; + if (K > kMaxPointsPerPixel) { std::stringstream ss; ss << "Must have points_per_pixel <= " << kMaxPointsPerPixel; @@ -156,9 +171,10 @@ RasterizePointsNaiveCuda( const size_t threads = 64; RasterizePointsNaiveCudaKernel<<>>( points.contiguous().data(), + cloud_to_packed_first_idx.contiguous().data(), + num_points_per_cloud.contiguous().data(), radius, N, - P, S, K, point_idxs.contiguous().data(), @@ -172,7 +188,9 @@ RasterizePointsNaiveCuda( // **************************************************************************** __global__ void RasterizePointsCoarseCudaKernel( - const float* points, + const float* points, // (P, 3) + const int64_t* cloud_to_packed_first_idx, // (N) + const int64_t* num_points_per_cloud, // (N) const float radius, const int N, const int P, @@ -206,16 +224,27 @@ __global__ void RasterizePointsCoarseCudaKernel( binmask.block_clear(); + // Using the batch index of the thread get the start and stop + // indices for the points. + const int64_t cloud_point_start_idx = cloud_to_packed_first_idx[batch_idx]; + const int64_t cloud_point_stop_idx = + cloud_point_start_idx + num_points_per_cloud[batch_idx]; + // Have each thread handle a different point within the chunk for (int p = threadIdx.x; p < chunk_size; p += blockDim.x) { const int p_idx = point_start_idx + p; - if (p_idx >= P) - break; - const float px = points[batch_idx * P * 3 + p_idx * 3 + 0]; - const float py = points[batch_idx * P * 3 + p_idx * 3 + 1]; - const float pz = points[batch_idx * P * 3 + p_idx * 3 + 2]; + + // Check if point index corresponds to the cloud in the batch given by + // batch_idx. + if (p_idx >= cloud_point_stop_idx || p_idx < cloud_point_start_idx) { + continue; + } + + const float px = points[p_idx * 3 + 0]; + const float py = points[p_idx * 3 + 1]; + const float pz = points[p_idx * 3 + 2]; if (pz < 0) - continue; // Don't render points behind the camera + continue; // Don't render points behind the camera. 
const float px0 = px - radius; const float px1 = px + radius; const float py0 = py - radius; @@ -283,15 +312,20 @@ __global__ void RasterizePointsCoarseCudaKernel( } torch::Tensor RasterizePointsCoarseCuda( - const torch::Tensor& points, + const torch::Tensor& points, // (P, 3) + const torch::Tensor& cloud_to_packed_first_idx, // (N) + const torch::Tensor& num_points_per_cloud, // (N) const int image_size, const float radius, const int bin_size, const int max_points_per_bin) { - const int N = points.size(0); - const int P = points.size(1); + const int P = points.size(0); + const int N = num_points_per_cloud.size(0); const int num_bins = 1 + (image_size - 1) / bin_size; // divide round up const int M = max_points_per_bin; + if (points.ndimension() != 2 || points.size(1) != 3) { + AT_ERROR("points must have dimensions (num_points, 3)"); + } if (num_bins >= 22) { // Make sure we do not use too much shared memory. std::stringstream ss; @@ -307,6 +341,8 @@ torch::Tensor RasterizePointsCoarseCuda( const size_t threads = 512; RasterizePointsCoarseCudaKernel<<>>( points.contiguous().data(), + cloud_to_packed_first_idx.contiguous().data(), + num_points_per_cloud.contiguous().data(), radius, N, P, @@ -324,12 +360,11 @@ torch::Tensor RasterizePointsCoarseCuda( // **************************************************************************** __global__ void RasterizePointsFineCudaKernel( - const float* points, // (N, P, 3) + const float* points, // (P, 3) const int32_t* bin_points, // (N, B, B, T) const float radius, const int bin_size, const int N, - const int P, const int B, const int M, const int S, @@ -342,6 +377,7 @@ __global__ void RasterizePointsFineCudaKernel( const int num_threads = gridDim.x * blockDim.x; const int tid = blockIdx.x * blockDim.x + threadIdx.x; const float radius2 = radius * radius; + for (int pid = tid; pid < num_pixels; pid += num_threads) { // Convert linear index into bin and pixel indices. 
We make the within // block pixel ids move the fastest, so that adjacent threads will fall @@ -377,7 +413,7 @@ __global__ void RasterizePointsFineCudaKernel( continue; } CheckPixelInsidePoint( - points, p, q_size, q_max_z, q_max_idx, q, radius2, xf, yf, n, P, K); + points, p, q_size, q_max_z, q_max_idx, q, radius2, xf, yf, K); } // Now we've looked at all the points for this bin, so we can write // output for the current pixel. @@ -392,14 +428,13 @@ __global__ void RasterizePointsFineCudaKernel( } std::tuple RasterizePointsFineCuda( - const torch::Tensor& points, + const torch::Tensor& points, // (P, 3) const torch::Tensor& bin_points, const int image_size, const float radius, const int bin_size, const int points_per_pixel) { - const int N = points.size(0); - const int P = points.size(1); + const int N = bin_points.size(0); const int B = bin_points.size(1); const int M = bin_points.size(3); const int S = image_size; @@ -421,7 +456,6 @@ std::tuple RasterizePointsFineCuda( radius, bin_size, N, - P, B, M, S, @@ -438,7 +472,7 @@ std::tuple RasterizePointsFineCuda( // **************************************************************************** // TODO(T55115174) Add more documentation for backward kernel. __global__ void RasterizePointsBackwardCudaKernel( - const float* points, // (N, P, 3) + const float* points, // (P, 3) const int32_t* idxs, // (N, H, W, K) const int N, const int P, @@ -447,13 +481,13 @@ __global__ void RasterizePointsBackwardCudaKernel( const int K, const float* grad_zbuf, // (N, H, W, K) const float* grad_dists, // (N, H, W, K) - float* grad_points) { // (N, P, 3) + float* grad_points) { // (P, 3) // Parallelized over each of K points per pixel, for each pixel in images of // size H * W, for each image in the batch of size N. 
int num_threads = gridDim.x * blockDim.x; int tid = blockIdx.x * blockDim.x + threadIdx.x; for (int i = tid; i < N * H * W * K; i += num_threads) { - const int n = i / (H * W * K); + // const int n = i / (H * W * K); // batch index (not needed). const int yxk = i % (H * W * K); const int yi = yxk / (W * K); const int xk = yxk % (W * K); @@ -466,15 +500,15 @@ __global__ void RasterizePointsBackwardCudaKernel( if (p < 0) continue; const float grad_dist2 = grad_dists[i]; - const int p_ind = n * P * 3 + p * 3; - const float px = points[p_ind]; + const int p_ind = p * 3; // index into packed points tensor + const float px = points[p_ind + 0]; const float py = points[p_ind + 1]; const float dx = px - xf; const float dy = py - yf; const float grad_px = 2.0f * grad_dist2 * dx; const float grad_py = 2.0f * grad_dist2 * dy; const float grad_pz = grad_zbuf[i]; - atomicAdd(grad_points + p_ind, grad_px); + atomicAdd(grad_points + p_ind + 0, grad_px); atomicAdd(grad_points + p_ind + 1, grad_py); atomicAdd(grad_points + p_ind + 2, grad_pz); } @@ -485,13 +519,13 @@ torch::Tensor RasterizePointsBackwardCuda( const torch::Tensor& idxs, // (N, H, W, K) const torch::Tensor& grad_zbuf, // (N, H, W, K) const torch::Tensor& grad_dists) { // (N, H, W, K) - const int N = points.size(0); - const int P = points.size(1); + const int P = points.size(0); + const int N = idxs.size(0); const int H = idxs.size(1); const int W = idxs.size(2); const int K = idxs.size(3); - torch::Tensor grad_points = torch::zeros({N, P, 3}, points.options()); + torch::Tensor grad_points = torch::zeros({P, 3}, points.options()); const size_t blocks = 1024; const size_t threads = 64; diff --git a/pytorch3d/csrc/rasterize_points/rasterize_points.h b/pytorch3d/csrc/rasterize_points/rasterize_points.h index a80dfcb1..e171db4b 100644 --- a/pytorch3d/csrc/rasterize_points/rasterize_points.h +++ b/pytorch3d/csrc/rasterize_points/rasterize_points.h @@ -11,6 +11,8 @@ std::tuple RasterizePointsNaiveCpu( const torch::Tensor& 
points, + const torch::Tensor& cloud_to_packed_first_idx, + const torch::Tensor& num_points_per_cloud, const int image_size, const float radius, const int points_per_pixel); @@ -19,6 +21,8 @@ std::tuple RasterizePointsNaiveCpu( std::tuple RasterizePointsNaiveCuda( const torch::Tensor& points, + const torch::Tensor& cloud_to_packed_first_idx, + const torch::Tensor& num_points_per_cloud, const int image_size, const float radius, const int points_per_pixel); @@ -27,16 +31,26 @@ RasterizePointsNaiveCuda( // check whether that point hits the pixel. // // Args: -// points: Tensor of shape (N, P, 3) (in NDC) +// points: Tensor of shape (P, 3) giving (packed) positions for +// points in all N pointclouds in the batch where P is the total +// number of points in the batch across all pointclouds. These points +// are expected to be in NDC coordinates in the range [-1, 1]. +// cloud_to_packed_first_idx: LongTensor of shape (N) giving the index in +// points_packed of the first point in each pointcloud +// in the batch where N is the batch size. +// num_points_per_cloud: LongTensor of shape (N) giving the number of points +// for each pointcloud in the batch. // radius: Radius of each point (in NDC units) -// image_size: (S) Size of the image to return (in pixels) +// image_size: (S) Size of the image to return (in pixels) // points_per_pixel: (K) The number closest of points to return for each pixel // // Returns: +// A 4 element tuple of: // idxs: int32 Tensor of shape (N, S, S, K) giving the indices of the // closest K points along the z-axis for each pixel, padded with -1 for -// pixels -// hit by fewer than K points. +// pixels hit by fewer than K points. The indices refer to points in +// points packed i.e a tensor of shape (P, 3) representing the flattened +// points for all pointclouds in the batch. // zbuf: float32 Tensor of shape (N, S, S, K) giving the depth of each // closest point for each pixel. 
// dists: float32 Tensor of shape (N, S, S, K) giving squared Euclidean @@ -44,19 +58,32 @@ RasterizePointsNaiveCuda( // points along the z axis. std::tuple RasterizePointsNaive( const torch::Tensor& points, + const torch::Tensor& cloud_to_packed_first_idx, + const torch::Tensor& num_points_per_cloud, const int image_size, const float radius, const int points_per_pixel) { - if (points.type().is_cuda()) { + if (points.type().is_cuda() && cloud_to_packed_first_idx.type().is_cuda() && + num_points_per_cloud.type().is_cuda()) { #ifdef WITH_CUDA return RasterizePointsNaiveCuda( - points, image_size, radius, points_per_pixel); + points, + cloud_to_packed_first_idx, + num_points_per_cloud, + image_size, + radius, + points_per_pixel); #else AT_ERROR("Not compiled with GPU support"); #endif } else { return RasterizePointsNaiveCpu( - points, image_size, radius, points_per_pixel); + points, + cloud_to_packed_first_idx, + num_points_per_cloud, + image_size, + radius, + points_per_pixel); } } @@ -66,6 +93,8 @@ std::tuple RasterizePointsNaive( torch::Tensor RasterizePointsCoarseCpu( const torch::Tensor& points, + const torch::Tensor& cloud_to_packed_first_idx, + const torch::Tensor& num_points_per_cloud, const int image_size, const float radius, const int bin_size, @@ -74,13 +103,23 @@ torch::Tensor RasterizePointsCoarseCpu( #ifdef WITH_CUDA torch::Tensor RasterizePointsCoarseCuda( const torch::Tensor& points, + const torch::Tensor& cloud_to_packed_first_idx, + const torch::Tensor& num_points_per_cloud, const int image_size, const float radius, const int bin_size, const int max_points_per_bin); #endif // Args: -// points: Tensor of shape (N, P, 3) +// points: Tensor of shape (P, 3) giving (packed) positions for +// points in all N pointclouds in the batch where P is the total +// number of points in the batch across all pointclouds. These points +// are expected to be in NDC coordinates in the range [-1, 1]. 
+// cloud_to_packed_first_idx: LongTensor of shape (N) giving the index in +// points_packed of the first point in each pointcloud +// in the batch where N is the batch size. +// num_points_per_cloud: LongTensor of shape (N) giving the number of points +// for each pointcloud in the batch. // radius: Radius of points to rasterize (in NDC units) // image_size: Size of the image to generate (in pixels) // bin_size: Size of each bin within the image (in pixels) @@ -92,20 +131,35 @@ torch::Tensor RasterizePointsCoarseCuda( // of points that fall into each bin. torch::Tensor RasterizePointsCoarse( const torch::Tensor& points, + const torch::Tensor& cloud_to_packed_first_idx, + const torch::Tensor& num_points_per_cloud, const int image_size, const float radius, const int bin_size, const int max_points_per_bin) { - if (points.type().is_cuda()) { + if (points.type().is_cuda() && cloud_to_packed_first_idx.type().is_cuda() && + num_points_per_cloud.type().is_cuda()) { #ifdef WITH_CUDA return RasterizePointsCoarseCuda( - points, image_size, radius, bin_size, max_points_per_bin); + points, + cloud_to_packed_first_idx, + num_points_per_cloud, + image_size, + radius, + bin_size, + max_points_per_bin); #else AT_ERROR("Not compiled with GPU support"); #endif } else { return RasterizePointsCoarseCpu( - points, image_size, radius, bin_size, max_points_per_bin); + points, + cloud_to_packed_first_idx, + num_points_per_cloud, + image_size, + radius, + bin_size, + max_points_per_bin); } } @@ -123,7 +177,10 @@ std::tuple RasterizePointsFineCuda( const int points_per_pixel); #endif // Args: -// points: float32 Tensor of shape (N, P, 3) +// points: Tensor of shape (P, 3) giving (packed) positions for +// points in all N pointclouds in the batch where P is the total +// number of points in the batch across all pointclouds. These points +// are expected to be in NDC coordinates in the range [-1, 1]. 
// bin_points: int32 Tensor of shape (N, B, B, M) giving the indices of points // that fall into each bin (output from coarse rasterization) // image_size: Size of image to generate (in pixels) @@ -132,9 +189,11 @@ std::tuple RasterizePointsFineCuda( // points_per_pixel: How many points to rasterize for each pixel // // Returns (same as rasterize_points): -// idxs: int32 Tensor of shape (N, S, S, K) giving the indices of the closest -// points_per_pixel points along the z-axis for each pixel, padded with -// -1 for pixels hit by fewer than points_per_pixel points +// idxs: int32 Tensor of shape (N, S, S, K) giving the indices of the +// closest K points along the z-axis for each pixel, padded with -1 for +// pixels hit by fewer than K points. The indices refer to points in +// points packed i.e a tensor of shape (P, 3) representing the flattened +// points for all pointclouds in the batch. // zbuf: float32 Tensor of shape (N, S, S, K) giving the depth of each of each // closest point for each pixel // dists: float32 Tensor of shape (N, S, S, K) giving squared Euclidean @@ -177,7 +236,10 @@ torch::Tensor RasterizePointsBackwardCuda( const torch::Tensor& grad_dists); #endif // Args: -// points: float32 Tensor of shape (N, P, 3) +// points: Tensor of shape (P, 3) giving (packed) positions for +// points in all N pointclouds in the batch where P is the total +// number of points in the batch across all pointclouds. These points +// are expected to be in NDC coordinates in the range [-1, 1]. // idxs: int32 Tensor of shape (N, H, W, K) (from forward pass) // grad_zbuf: float32 Tensor of shape (N, H, W, K) giving upstream gradient // d(loss)/d(zbuf) of the distances from each pixel to its nearest @@ -212,7 +274,15 @@ torch::Tensor RasterizePointsBackward( // it uses either naive or coarse-to-fine rasterization based on bin_size. 
// // Args: -// points: Tensor of shape (N, P, 3) (in NDC) +// points: Tensor of shape (P, 3) giving (packed) positions for +// points in all N pointclouds in the batch where P is the total +// number of points in the batch across all pointclouds. These points +// are expected to be in NDC coordinates in the range [-1, 1]. +// cloud_to_packed_first_idx: LongTensor of shape (N) giving the index in +// points_packed of the first point in each pointcloud +// in the batch where N is the batch size. +// num_points_per_cloud: LongTensor of shape (N) giving the number of points +// for each pointcloud in the batch. // radius: Radius of each point (in NDC units) // image_size: (S) Size of the image to return (in pixels) // points_per_pixel: (K) The number of points to return for each pixel @@ -223,8 +293,10 @@ torch::Tensor RasterizePointsBackward( // // Returns: // idxs: int32 Tensor of shape (N, S, S, K) giving the indices of the -// closest points_per_pixel points along the z-axis for each pixel, -// padded with -1 for pixels hit by fewer than points_per_pixel points +// closest K points along the z-axis for each pixel, padded with -1 for +// pixels hit by fewer than K points. The indices refer to points in +// points packed i.e a tensor of shape (P, 3) representing the flattened +// points for all pointclouds in the batch. // zbuf: float32 Tensor of shape (N, S, S, K) giving the depth of each of each // closest point for each pixel // dists: float32 Tensor of shape (N, S, S, K) giving squared Euclidean @@ -232,6 +304,8 @@ torch::Tensor RasterizePointsBackward( // points along the z axis. 
std::tuple RasterizePoints( const torch::Tensor& points, + const torch::Tensor& cloud_to_packed_first_idx, + const torch::Tensor& num_points_per_cloud, const int image_size, const float radius, const int points_per_pixel, @@ -239,11 +313,23 @@ std::tuple RasterizePoints( const int max_points_per_bin) { if (bin_size == 0) { // Use the naive per-pixel implementation - return RasterizePointsNaive(points, image_size, radius, points_per_pixel); + return RasterizePointsNaive( + points, + cloud_to_packed_first_idx, + num_points_per_cloud, + image_size, + radius, + points_per_pixel); } else { // Use coarse-to-fine rasterization const auto bin_points = RasterizePointsCoarse( - points, image_size, radius, bin_size, max_points_per_bin); + points, + cloud_to_packed_first_idx, + num_points_per_cloud, + image_size, + radius, + bin_size, + max_points_per_bin); return RasterizePointsFine( points, bin_points, image_size, radius, bin_size, points_per_pixel); } diff --git a/pytorch3d/csrc/rasterize_points/rasterize_points_cpu.cpp b/pytorch3d/csrc/rasterize_points/rasterize_points_cpu.cpp index 3ea2a213..893c17cc 100644 --- a/pytorch3d/csrc/rasterize_points/rasterize_points_cpu.cpp +++ b/pytorch3d/csrc/rasterize_points/rasterize_points_cpu.cpp @@ -13,37 +13,49 @@ static float PixToNdc(const int i, const int S) { } std::tuple RasterizePointsNaiveCpu( - const torch::Tensor& points, + const torch::Tensor& points, // (P, 3) + const torch::Tensor& cloud_to_packed_first_idx, // (N) + const torch::Tensor& num_points_per_cloud, // (N) const int image_size, const float radius, const int points_per_pixel) { - const int N = points.size(0); - const int P = points.size(1); + const int32_t N = cloud_to_packed_first_idx.size(0); // batch_size. + const int S = image_size; const int K = points_per_pixel; + + // Initialize output tensors. 
auto int_opts = points.options().dtype(torch::kInt32); auto float_opts = points.options().dtype(torch::kFloat32); torch::Tensor point_idxs = torch::full({N, S, S, K}, -1, int_opts); torch::Tensor zbuf = torch::full({N, S, S, K}, -1, float_opts); torch::Tensor pix_dists = torch::full({N, S, S, K}, -1, float_opts); - auto points_a = points.accessor(); + auto points_a = points.accessor(); auto point_idxs_a = point_idxs.accessor(); auto zbuf_a = zbuf.accessor(); auto pix_dists_a = pix_dists.accessor(); const float radius2 = radius * radius; for (int n = 0; n < N; ++n) { + // Loop through each pointcloud in the batch. + // Get the start index of the points in points_packed and the num points + // in the point cloud. + const int point_start_idx = + cloud_to_packed_first_idx[n].item().to(); + const int point_stop_idx = + (point_start_idx + num_points_per_cloud[n].item().to()); + for (int yi = 0; yi < S; ++yi) { float yf = PixToNdc(yi, S); for (int xi = 0; xi < S; ++xi) { float xf = PixToNdc(xi, S); // Use a priority queue to hold (z, idx, r) std::priority_queue> q; - for (int p = 0; p < P; ++p) { - const float px = points_a[n][p][0]; - const float py = points_a[n][p][1]; - const float pz = points_a[n][p][2]; + for (int p = point_start_idx; p < point_stop_idx; ++p) { + const float px = points_a[p][0]; + const float py = points_a[p][1]; + const float pz = points_a[p][2]; if (pz < 0) { continue; } @@ -75,26 +87,37 @@ std::tuple RasterizePointsNaiveCpu( } torch::Tensor RasterizePointsCoarseCpu( - const torch::Tensor& points, + const torch::Tensor& points, // (P, 3) + const torch::Tensor& cloud_to_packed_first_idx, // (N) + const torch::Tensor& num_points_per_cloud, // (N) const int image_size, const float radius, const int bin_size, const int max_points_per_bin) { - const int N = points.size(0); - const int P = points.size(1); + const int32_t N = cloud_to_packed_first_idx.size(0); // batch_size. 
+ const int B = 1 + (image_size - 1) / bin_size; // Integer division round up const int M = max_points_per_bin; auto opts = points.options().dtype(torch::kInt32); torch::Tensor points_per_bin = torch::zeros({N, B, B}, opts); torch::Tensor bin_points = torch::full({N, B, B, M}, -1, opts); - auto points_a = points.accessor(); + auto points_a = points.accessor(); auto points_per_bin_a = points_per_bin.accessor(); auto bin_points_a = bin_points.accessor(); const float pixel_width = 2.0f / image_size; const float bin_width = pixel_width * bin_size; + for (int n = 0; n < N; ++n) { + // Loop through each pointcloud in the batch. + // Get the start index of the points in points_packed and the num points + // in the point cloud. + const int point_start_idx = + cloud_to_packed_first_idx[n].item().to(); + const int point_stop_idx = + (point_start_idx + num_points_per_cloud[n].item().to()); + float bin_y_min = -1.0f; float bin_y_max = bin_y_min + bin_width; for (int by = 0; by < B; by++) { @@ -102,10 +125,10 @@ torch::Tensor RasterizePointsCoarseCpu( float bin_x_max = bin_x_min + bin_width; for (int bx = 0; bx < B; bx++) { int32_t points_hit = 0; - for (int32_t p = 0; p < P; p++) { - float px = points_a[n][p][0]; - float py = points_a[n][p][1]; - float pz = points_a[n][p][2]; + for (int p = point_start_idx; p < point_stop_idx; ++p) { + float px = points_a[p][0]; + float py = points_a[p][1]; + float pz = points_a[p][2]; if (pz < 0) { continue; } @@ -144,12 +167,13 @@ torch::Tensor RasterizePointsCoarseCpu( } torch::Tensor RasterizePointsBackwardCpu( - const torch::Tensor& points, // (N, P, 3) + const torch::Tensor& points, // (P, 3) const torch::Tensor& idxs, // (N, H, W, K) const torch::Tensor& grad_zbuf, // (N, H, W, K) const torch::Tensor& grad_dists) { // (N, H, W, K) - const int N = points.size(0); - const int P = points.size(1); + + const int N = idxs.size(0); + const int P = points.size(0); const int H = idxs.size(1); const int W = idxs.size(2); const int K = 
idxs.size(3); @@ -159,13 +183,13 @@ torch::Tensor RasterizePointsBackwardCpu( if (H != W) { AT_ERROR("RasterizePointsBackwardCpu only supports square images"); } - torch::Tensor grad_points = torch::zeros({N, P, 3}, points.options()); + torch::Tensor grad_points = torch::zeros({P, 3}, points.options()); - auto points_a = points.accessor(); + auto points_a = points.accessor(); auto idxs_a = idxs.accessor(); auto grad_dists_a = grad_dists.accessor(); auto grad_zbuf_a = grad_zbuf.accessor(); - auto grad_points_a = grad_points.accessor(); + auto grad_points_a = grad_points.accessor(); for (int n = 0; n < N; ++n) { // Loop over images in the batch for (int y = 0; y < H; ++y) { // Loop over rows in the image @@ -178,16 +202,16 @@ torch::Tensor RasterizePointsBackwardCpu( break; } const float grad_dist2 = grad_dists_a[n][y][x][k]; - const float px = points_a[n][p][0]; - const float py = points_a[n][p][1]; + const float px = points_a[p][0]; + const float py = points_a[p][1]; const float dx = px - xf; const float dy = py - yf; // Remember: dists[n][y][x][k] = dx * dx + dy * dy; const float grad_px = 2.0f * grad_dist2 * dx; const float grad_py = 2.0f * grad_dist2 * dy; - grad_points_a[n][p][0] += grad_px; - grad_points_a[n][p][1] += grad_py; - grad_points_a[n][p][2] += grad_zbuf_a[n][y][x][k]; + grad_points_a[p][0] += grad_px; + grad_points_a[p][1] += grad_py; + grad_points_a[p][2] += grad_zbuf_a[n][y][x][k]; } } }