mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
Update point cloud rasterizer to support heterogeneous point clouds
Summary: Update the point cloud rasterizer to: - use the pointcloud datastructure (rebased on top of D19791851.) - support rasterization of heterogeneous point clouds in the same way as with Meshes. The main changes to the API will be as follows: - The input to `rasterize_points` will be a `Pointclouds` object instead of a tensor. This will be easy to update e.g. ``` points = torch.randn(N, P, 3) idx2, zbuf2, dists2 = rasterize_points(points, image_size, radius, points_per_pixel) points = torch.randn(N, P, 3) pointclouds = Pointclouds(points=points) idx2, zbuf2, dists2 = rasterize_points(pointclouds, image_size, radius, points_per_pixel) ``` - The indices output from rasterization will now refer to points in `poinclouds.points_packed()`. This may require some changes to the functions which consume the outputs of rasterization if they were previously assuming that the indices ranged from 0 to P where P is the number of points in each pointcloud. Making this change now so that Olivia can update her PR accordingly. Reviewed By: gkioxari Differential Revision: D20088651 fbshipit-source-id: 833ed659909712bcbbb6a50e2ec0189839f0413a
This commit is contained in:
parent
cae325718e
commit
32ad869dea
@ -102,16 +102,16 @@ __device__ bool CheckPointOutsideBoundingBox(
|
||||
// RasterizeMeshesFineCudaKernel.
|
||||
template <typename FaceQ>
|
||||
__device__ void CheckPixelInsideFace(
|
||||
const float* face_verts, // (N, P, 3)
|
||||
int face_idx,
|
||||
const float* face_verts, // (F, 3, 3)
|
||||
const int face_idx,
|
||||
int& q_size,
|
||||
float& q_max_z,
|
||||
int& q_max_idx,
|
||||
FaceQ& q,
|
||||
float blur_radius,
|
||||
float2 pxy, // Coordinates of the pixel
|
||||
int K,
|
||||
bool perspective_correct) {
|
||||
const float blur_radius,
|
||||
const float2 pxy, // Coordinates of the pixel
|
||||
const int K,
|
||||
const bool perspective_correct) {
|
||||
const auto v012 = GetSingleFaceVerts(face_verts, face_idx);
|
||||
const float3 v0 = thrust::get<0>(v012);
|
||||
const float3 v1 = thrust::get<1>(v012);
|
||||
@ -335,7 +335,6 @@ __global__ void RasterizeMeshesBackwardCudaKernel(
|
||||
const int64_t* pix_to_face, // (N, H, W, K)
|
||||
const bool perspective_correct,
|
||||
const int N,
|
||||
const int F,
|
||||
const int H,
|
||||
const int W,
|
||||
const int K,
|
||||
@ -472,7 +471,6 @@ torch::Tensor RasterizeMeshesBackwardCuda(
|
||||
pix_to_face.contiguous().data<int64_t>(),
|
||||
perspective_correct,
|
||||
N,
|
||||
F,
|
||||
H,
|
||||
W,
|
||||
K,
|
||||
@ -671,7 +669,6 @@ __global__ void RasterizeMeshesFineCudaKernel(
|
||||
const int bin_size,
|
||||
const bool perspective_correct,
|
||||
const int N,
|
||||
const int F,
|
||||
const int B,
|
||||
const int M,
|
||||
const int H,
|
||||
@ -774,7 +771,6 @@ RasterizeMeshesFineCuda(
|
||||
if (bin_faces.ndimension() != 4) {
|
||||
AT_ERROR("bin_faces must have 4 dimensions");
|
||||
}
|
||||
const int F = face_verts.size(0);
|
||||
const int N = bin_faces.size(0);
|
||||
const int B = bin_faces.size(1);
|
||||
const int M = bin_faces.size(3);
|
||||
@ -803,7 +799,6 @@ RasterizeMeshesFineCuda(
|
||||
bin_size,
|
||||
perspective_correct,
|
||||
N,
|
||||
F,
|
||||
B,
|
||||
M,
|
||||
H,
|
||||
|
@ -30,8 +30,8 @@ __device__ inline bool operator<(const Pix& a, const Pix& b) {
|
||||
// RasterizePointsFineCudaKernel.
|
||||
template <typename PointQ>
|
||||
__device__ void CheckPixelInsidePoint(
|
||||
const float* points, // (N, P, 3)
|
||||
const int p,
|
||||
const float* points, // (P, 3)
|
||||
const int p_idx,
|
||||
int& q_size,
|
||||
float& q_max_z,
|
||||
int& q_max_idx,
|
||||
@ -39,12 +39,10 @@ __device__ void CheckPixelInsidePoint(
|
||||
const float radius2,
|
||||
const float xf,
|
||||
const float yf,
|
||||
const int n,
|
||||
const int P,
|
||||
const int K) {
|
||||
const float px = points[n * P * 3 + p * 3 + 0];
|
||||
const float py = points[n * P * 3 + p * 3 + 1];
|
||||
const float pz = points[n * P * 3 + p * 3 + 2];
|
||||
const float px = points[p_idx * 3 + 0];
|
||||
const float py = points[p_idx * 3 + 1];
|
||||
const float pz = points[p_idx * 3 + 2];
|
||||
if (pz < 0)
|
||||
return; // Don't render points behind the camera
|
||||
const float dx = xf - px;
|
||||
@ -53,7 +51,7 @@ __device__ void CheckPixelInsidePoint(
|
||||
if (dist2 < radius2) {
|
||||
if (q_size < K) {
|
||||
// Just insert it
|
||||
q[q_size] = {pz, p, dist2};
|
||||
q[q_size] = {pz, p_idx, dist2};
|
||||
if (pz > q_max_z) {
|
||||
q_max_z = pz;
|
||||
q_max_idx = q_size;
|
||||
@ -61,7 +59,7 @@ __device__ void CheckPixelInsidePoint(
|
||||
q_size++;
|
||||
} else if (pz < q_max_z) {
|
||||
// Overwrite the old max, and find the new max
|
||||
q[q_max_idx] = {pz, p, dist2};
|
||||
q[q_max_idx] = {pz, p_idx, dist2};
|
||||
q_max_z = pz;
|
||||
for (int i = 0; i < K; i++) {
|
||||
if (q[i].z > q_max_z) {
|
||||
@ -78,10 +76,11 @@ __device__ void CheckPixelInsidePoint(
|
||||
// ****************************************************************************
|
||||
|
||||
__global__ void RasterizePointsNaiveCudaKernel(
|
||||
const float* points, // (N, P, 3)
|
||||
const float* points, // (P, 3)
|
||||
const int64_t* cloud_to_packed_first_idx, // (N)
|
||||
const int64_t* num_points_per_cloud, // (N)
|
||||
const float radius,
|
||||
const int N,
|
||||
const int P,
|
||||
const int S,
|
||||
const int K,
|
||||
int32_t* point_idxs, // (N, S, S, K)
|
||||
@ -116,9 +115,15 @@ __global__ void RasterizePointsNaiveCudaKernel(
|
||||
int q_size = 0;
|
||||
float q_max_z = -1000;
|
||||
int q_max_idx = -1;
|
||||
for (int p = 0; p < P; ++p) {
|
||||
|
||||
// Using the batch index of the thread get the start and stop
|
||||
// indices for the points.
|
||||
const int64_t point_start_idx = cloud_to_packed_first_idx[n];
|
||||
const int64_t point_stop_idx = point_start_idx + num_points_per_cloud[n];
|
||||
|
||||
for (int p_idx = point_start_idx; p_idx < point_stop_idx; ++p_idx) {
|
||||
CheckPixelInsidePoint(
|
||||
points, p, q_size, q_max_z, q_max_idx, q, radius2, xf, yf, n, P, K);
|
||||
points, p_idx, q_size, q_max_z, q_max_idx, q, radius2, xf, yf, K);
|
||||
}
|
||||
BubbleSort(q, q_size);
|
||||
int idx = n * S * S * K + yi * S * K + xi * K;
|
||||
@ -132,14 +137,24 @@ __global__ void RasterizePointsNaiveCudaKernel(
|
||||
|
||||
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>
|
||||
RasterizePointsNaiveCuda(
|
||||
const torch::Tensor& points,
|
||||
const torch::Tensor& points, // (P. 3)
|
||||
const torch::Tensor& cloud_to_packed_first_idx, // (N)
|
||||
const torch::Tensor& num_points_per_cloud, // (N)
|
||||
const int image_size,
|
||||
const float radius,
|
||||
const int points_per_pixel) {
|
||||
const int N = points.size(0);
|
||||
const int P = points.size(1);
|
||||
if (points.ndimension() != 2 || points.size(1) != 3) {
|
||||
AT_ERROR("points must have dimensions (num_points, 3)");
|
||||
}
|
||||
if (num_points_per_cloud.size(0) != cloud_to_packed_first_idx.size(0)) {
|
||||
AT_ERROR(
|
||||
"num_points_per_cloud must have same size first dimension as cloud_to_packed_first_idx");
|
||||
}
|
||||
|
||||
const int N = num_points_per_cloud.size(0); // batch size.
|
||||
const int S = image_size;
|
||||
const int K = points_per_pixel;
|
||||
|
||||
if (K > kMaxPointsPerPixel) {
|
||||
std::stringstream ss;
|
||||
ss << "Must have points_per_pixel <= " << kMaxPointsPerPixel;
|
||||
@ -156,9 +171,10 @@ RasterizePointsNaiveCuda(
|
||||
const size_t threads = 64;
|
||||
RasterizePointsNaiveCudaKernel<<<blocks, threads>>>(
|
||||
points.contiguous().data<float>(),
|
||||
cloud_to_packed_first_idx.contiguous().data<int64_t>(),
|
||||
num_points_per_cloud.contiguous().data<int64_t>(),
|
||||
radius,
|
||||
N,
|
||||
P,
|
||||
S,
|
||||
K,
|
||||
point_idxs.contiguous().data<int32_t>(),
|
||||
@ -172,7 +188,9 @@ RasterizePointsNaiveCuda(
|
||||
// ****************************************************************************
|
||||
|
||||
__global__ void RasterizePointsCoarseCudaKernel(
|
||||
const float* points,
|
||||
const float* points, // (P, 3)
|
||||
const int64_t* cloud_to_packed_first_idx, // (N)
|
||||
const int64_t* num_points_per_cloud, // (N)
|
||||
const float radius,
|
||||
const int N,
|
||||
const int P,
|
||||
@ -206,16 +224,27 @@ __global__ void RasterizePointsCoarseCudaKernel(
|
||||
|
||||
binmask.block_clear();
|
||||
|
||||
// Using the batch index of the thread get the start and stop
|
||||
// indices for the points.
|
||||
const int64_t cloud_point_start_idx = cloud_to_packed_first_idx[batch_idx];
|
||||
const int64_t cloud_point_stop_idx =
|
||||
cloud_point_start_idx + num_points_per_cloud[batch_idx];
|
||||
|
||||
// Have each thread handle a different point within the chunk
|
||||
for (int p = threadIdx.x; p < chunk_size; p += blockDim.x) {
|
||||
const int p_idx = point_start_idx + p;
|
||||
if (p_idx >= P)
|
||||
break;
|
||||
const float px = points[batch_idx * P * 3 + p_idx * 3 + 0];
|
||||
const float py = points[batch_idx * P * 3 + p_idx * 3 + 1];
|
||||
const float pz = points[batch_idx * P * 3 + p_idx * 3 + 2];
|
||||
|
||||
// Check if point index corresponds to the cloud in the batch given by
|
||||
// batch_idx.
|
||||
if (p_idx >= cloud_point_stop_idx || p_idx < cloud_point_start_idx) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const float px = points[p_idx * 3 + 0];
|
||||
const float py = points[p_idx * 3 + 1];
|
||||
const float pz = points[p_idx * 3 + 2];
|
||||
if (pz < 0)
|
||||
continue; // Don't render points behind the camera
|
||||
continue; // Don't render points behind the camera.
|
||||
const float px0 = px - radius;
|
||||
const float px1 = px + radius;
|
||||
const float py0 = py - radius;
|
||||
@ -283,15 +312,20 @@ __global__ void RasterizePointsCoarseCudaKernel(
|
||||
}
|
||||
|
||||
torch::Tensor RasterizePointsCoarseCuda(
|
||||
const torch::Tensor& points,
|
||||
const torch::Tensor& points, // (P, 3)
|
||||
const torch::Tensor& cloud_to_packed_first_idx, // (N)
|
||||
const torch::Tensor& num_points_per_cloud, // (N)
|
||||
const int image_size,
|
||||
const float radius,
|
||||
const int bin_size,
|
||||
const int max_points_per_bin) {
|
||||
const int N = points.size(0);
|
||||
const int P = points.size(1);
|
||||
const int P = points.size(0);
|
||||
const int N = num_points_per_cloud.size(0);
|
||||
const int num_bins = 1 + (image_size - 1) / bin_size; // divide round up
|
||||
const int M = max_points_per_bin;
|
||||
if (points.ndimension() != 2 || points.size(1) != 3) {
|
||||
AT_ERROR("points must have dimensions (num_points, 3)");
|
||||
}
|
||||
if (num_bins >= 22) {
|
||||
// Make sure we do not use too much shared memory.
|
||||
std::stringstream ss;
|
||||
@ -307,6 +341,8 @@ torch::Tensor RasterizePointsCoarseCuda(
|
||||
const size_t threads = 512;
|
||||
RasterizePointsCoarseCudaKernel<<<blocks, threads, shared_size>>>(
|
||||
points.contiguous().data<float>(),
|
||||
cloud_to_packed_first_idx.contiguous().data<int64_t>(),
|
||||
num_points_per_cloud.contiguous().data<int64_t>(),
|
||||
radius,
|
||||
N,
|
||||
P,
|
||||
@ -324,12 +360,11 @@ torch::Tensor RasterizePointsCoarseCuda(
|
||||
// ****************************************************************************
|
||||
|
||||
__global__ void RasterizePointsFineCudaKernel(
|
||||
const float* points, // (N, P, 3)
|
||||
const float* points, // (P, 3)
|
||||
const int32_t* bin_points, // (N, B, B, T)
|
||||
const float radius,
|
||||
const int bin_size,
|
||||
const int N,
|
||||
const int P,
|
||||
const int B,
|
||||
const int M,
|
||||
const int S,
|
||||
@ -342,6 +377,7 @@ __global__ void RasterizePointsFineCudaKernel(
|
||||
const int num_threads = gridDim.x * blockDim.x;
|
||||
const int tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
const float radius2 = radius * radius;
|
||||
|
||||
for (int pid = tid; pid < num_pixels; pid += num_threads) {
|
||||
// Convert linear index into bin and pixel indices. We make the within
|
||||
// block pixel ids move the fastest, so that adjacent threads will fall
|
||||
@ -377,7 +413,7 @@ __global__ void RasterizePointsFineCudaKernel(
|
||||
continue;
|
||||
}
|
||||
CheckPixelInsidePoint(
|
||||
points, p, q_size, q_max_z, q_max_idx, q, radius2, xf, yf, n, P, K);
|
||||
points, p, q_size, q_max_z, q_max_idx, q, radius2, xf, yf, K);
|
||||
}
|
||||
// Now we've looked at all the points for this bin, so we can write
|
||||
// output for the current pixel.
|
||||
@ -392,14 +428,13 @@ __global__ void RasterizePointsFineCudaKernel(
|
||||
}
|
||||
|
||||
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsFineCuda(
|
||||
const torch::Tensor& points,
|
||||
const torch::Tensor& points, // (P, 3)
|
||||
const torch::Tensor& bin_points,
|
||||
const int image_size,
|
||||
const float radius,
|
||||
const int bin_size,
|
||||
const int points_per_pixel) {
|
||||
const int N = points.size(0);
|
||||
const int P = points.size(1);
|
||||
const int N = bin_points.size(0);
|
||||
const int B = bin_points.size(1);
|
||||
const int M = bin_points.size(3);
|
||||
const int S = image_size;
|
||||
@ -421,7 +456,6 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsFineCuda(
|
||||
radius,
|
||||
bin_size,
|
||||
N,
|
||||
P,
|
||||
B,
|
||||
M,
|
||||
S,
|
||||
@ -438,7 +472,7 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsFineCuda(
|
||||
// ****************************************************************************
|
||||
// TODO(T55115174) Add more documentation for backward kernel.
|
||||
__global__ void RasterizePointsBackwardCudaKernel(
|
||||
const float* points, // (N, P, 3)
|
||||
const float* points, // (P, 3)
|
||||
const int32_t* idxs, // (N, H, W, K)
|
||||
const int N,
|
||||
const int P,
|
||||
@ -447,13 +481,13 @@ __global__ void RasterizePointsBackwardCudaKernel(
|
||||
const int K,
|
||||
const float* grad_zbuf, // (N, H, W, K)
|
||||
const float* grad_dists, // (N, H, W, K)
|
||||
float* grad_points) { // (N, P, 3)
|
||||
float* grad_points) { // (P, 3)
|
||||
// Parallelized over each of K points per pixel, for each pixel in images of
|
||||
// size H * W, for each image in the batch of size N.
|
||||
int num_threads = gridDim.x * blockDim.x;
|
||||
int tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
for (int i = tid; i < N * H * W * K; i += num_threads) {
|
||||
const int n = i / (H * W * K);
|
||||
// const int n = i / (H * W * K); // batch index (not needed).
|
||||
const int yxk = i % (H * W * K);
|
||||
const int yi = yxk / (W * K);
|
||||
const int xk = yxk % (W * K);
|
||||
@ -466,15 +500,15 @@ __global__ void RasterizePointsBackwardCudaKernel(
|
||||
if (p < 0)
|
||||
continue;
|
||||
const float grad_dist2 = grad_dists[i];
|
||||
const int p_ind = n * P * 3 + p * 3;
|
||||
const float px = points[p_ind];
|
||||
const int p_ind = p * 3; // index into packed points tensor
|
||||
const float px = points[p_ind + 0];
|
||||
const float py = points[p_ind + 1];
|
||||
const float dx = px - xf;
|
||||
const float dy = py - yf;
|
||||
const float grad_px = 2.0f * grad_dist2 * dx;
|
||||
const float grad_py = 2.0f * grad_dist2 * dy;
|
||||
const float grad_pz = grad_zbuf[i];
|
||||
atomicAdd(grad_points + p_ind, grad_px);
|
||||
atomicAdd(grad_points + p_ind + 0, grad_px);
|
||||
atomicAdd(grad_points + p_ind + 1, grad_py);
|
||||
atomicAdd(grad_points + p_ind + 2, grad_pz);
|
||||
}
|
||||
@ -485,13 +519,13 @@ torch::Tensor RasterizePointsBackwardCuda(
|
||||
const torch::Tensor& idxs, // (N, H, W, K)
|
||||
const torch::Tensor& grad_zbuf, // (N, H, W, K)
|
||||
const torch::Tensor& grad_dists) { // (N, H, W, K)
|
||||
const int N = points.size(0);
|
||||
const int P = points.size(1);
|
||||
const int P = points.size(0);
|
||||
const int N = idxs.size(0);
|
||||
const int H = idxs.size(1);
|
||||
const int W = idxs.size(2);
|
||||
const int K = idxs.size(3);
|
||||
|
||||
torch::Tensor grad_points = torch::zeros({N, P, 3}, points.options());
|
||||
torch::Tensor grad_points = torch::zeros({P, 3}, points.options());
|
||||
const size_t blocks = 1024;
|
||||
const size_t threads = 64;
|
||||
|
||||
|
@ -11,6 +11,8 @@
|
||||
|
||||
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsNaiveCpu(
|
||||
const torch::Tensor& points,
|
||||
const torch::Tensor& cloud_to_packed_first_idx,
|
||||
const torch::Tensor& num_points_per_cloud,
|
||||
const int image_size,
|
||||
const float radius,
|
||||
const int points_per_pixel);
|
||||
@ -19,6 +21,8 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsNaiveCpu(
|
||||
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>
|
||||
RasterizePointsNaiveCuda(
|
||||
const torch::Tensor& points,
|
||||
const torch::Tensor& cloud_to_packed_first_idx,
|
||||
const torch::Tensor& num_points_per_cloud,
|
||||
const int image_size,
|
||||
const float radius,
|
||||
const int points_per_pixel);
|
||||
@ -27,16 +31,26 @@ RasterizePointsNaiveCuda(
|
||||
// check whether that point hits the pixel.
|
||||
//
|
||||
// Args:
|
||||
// points: Tensor of shape (N, P, 3) (in NDC)
|
||||
// points: Tensor of shape (P, 3) giving (packed) positions for
|
||||
// points in all N pointclouds in the batch where P is the total
|
||||
// number of points in the batch across all pointclouds. These points
|
||||
// are expected to be in NDC coordinates in the range [-1, 1].
|
||||
// cloud_to_packed_first_idx: LongTensor of shape (N) giving the index in
|
||||
// points_packed of the first point in each pointcloud
|
||||
// in the batch where N is the batch size.
|
||||
// num_points_per_cloud: LongTensor of shape (N) giving the number of points
|
||||
// for each pointcloud in the batch.
|
||||
// radius: Radius of each point (in NDC units)
|
||||
// image_size: (S) Size of the image to return (in pixels)
|
||||
// image_size: (S) Size of the image to return (in pixels)
|
||||
// points_per_pixel: (K) The number closest of points to return for each pixel
|
||||
//
|
||||
// Returns:
|
||||
// A 4 element tuple of:
|
||||
// idxs: int32 Tensor of shape (N, S, S, K) giving the indices of the
|
||||
// closest K points along the z-axis for each pixel, padded with -1 for
|
||||
// pixels
|
||||
// hit by fewer than K points.
|
||||
// pixels hit by fewer than K points. The indices refer to points in
|
||||
// points packed i.e a tensor of shape (P, 3) representing the flattened
|
||||
// points for all pointclouds in the batch.
|
||||
// zbuf: float32 Tensor of shape (N, S, S, K) giving the depth of each
|
||||
// closest point for each pixel.
|
||||
// dists: float32 Tensor of shape (N, S, S, K) giving squared Euclidean
|
||||
@ -44,19 +58,32 @@ RasterizePointsNaiveCuda(
|
||||
// points along the z axis.
|
||||
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsNaive(
|
||||
const torch::Tensor& points,
|
||||
const torch::Tensor& cloud_to_packed_first_idx,
|
||||
const torch::Tensor& num_points_per_cloud,
|
||||
const int image_size,
|
||||
const float radius,
|
||||
const int points_per_pixel) {
|
||||
if (points.type().is_cuda()) {
|
||||
if (points.type().is_cuda() && cloud_to_packed_first_idx.type().is_cuda() &&
|
||||
num_points_per_cloud.type().is_cuda()) {
|
||||
#ifdef WITH_CUDA
|
||||
return RasterizePointsNaiveCuda(
|
||||
points, image_size, radius, points_per_pixel);
|
||||
points,
|
||||
cloud_to_packed_first_idx,
|
||||
num_points_per_cloud,
|
||||
image_size,
|
||||
radius,
|
||||
points_per_pixel);
|
||||
#else
|
||||
AT_ERROR("Not compiled with GPU support");
|
||||
#endif
|
||||
} else {
|
||||
return RasterizePointsNaiveCpu(
|
||||
points, image_size, radius, points_per_pixel);
|
||||
points,
|
||||
cloud_to_packed_first_idx,
|
||||
num_points_per_cloud,
|
||||
image_size,
|
||||
radius,
|
||||
points_per_pixel);
|
||||
}
|
||||
}
|
||||
|
||||
@ -66,6 +93,8 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsNaive(
|
||||
|
||||
torch::Tensor RasterizePointsCoarseCpu(
|
||||
const torch::Tensor& points,
|
||||
const torch::Tensor& cloud_to_packed_first_idx,
|
||||
const torch::Tensor& num_points_per_cloud,
|
||||
const int image_size,
|
||||
const float radius,
|
||||
const int bin_size,
|
||||
@ -74,13 +103,23 @@ torch::Tensor RasterizePointsCoarseCpu(
|
||||
#ifdef WITH_CUDA
|
||||
torch::Tensor RasterizePointsCoarseCuda(
|
||||
const torch::Tensor& points,
|
||||
const torch::Tensor& cloud_to_packed_first_idx,
|
||||
const torch::Tensor& num_points_per_cloud,
|
||||
const int image_size,
|
||||
const float radius,
|
||||
const int bin_size,
|
||||
const int max_points_per_bin);
|
||||
#endif
|
||||
// Args:
|
||||
// points: Tensor of shape (N, P, 3)
|
||||
// points: Tensor of shape (P, 3) giving (packed) positions for
|
||||
// points in all N pointclouds in the batch where P is the total
|
||||
// number of points in the batch across all pointclouds. These points
|
||||
// are expected to be in NDC coordinates in the range [-1, 1].
|
||||
// cloud_to_packed_first_idx: LongTensor of shape (N) giving the index in
|
||||
// points_packed of the first point in each pointcloud
|
||||
// in the batch where N is the batch size.
|
||||
// num_points_per_cloud: LongTensor of shape (N) giving the number of points
|
||||
// for each pointcloud in the batch.
|
||||
// radius: Radius of points to rasterize (in NDC units)
|
||||
// image_size: Size of the image to generate (in pixels)
|
||||
// bin_size: Size of each bin within the image (in pixels)
|
||||
@ -92,20 +131,35 @@ torch::Tensor RasterizePointsCoarseCuda(
|
||||
// of points that fall into each bin.
|
||||
torch::Tensor RasterizePointsCoarse(
|
||||
const torch::Tensor& points,
|
||||
const torch::Tensor& cloud_to_packed_first_idx,
|
||||
const torch::Tensor& num_points_per_cloud,
|
||||
const int image_size,
|
||||
const float radius,
|
||||
const int bin_size,
|
||||
const int max_points_per_bin) {
|
||||
if (points.type().is_cuda()) {
|
||||
if (points.type().is_cuda() && cloud_to_packed_first_idx.type().is_cuda() &&
|
||||
num_points_per_cloud.type().is_cuda()) {
|
||||
#ifdef WITH_CUDA
|
||||
return RasterizePointsCoarseCuda(
|
||||
points, image_size, radius, bin_size, max_points_per_bin);
|
||||
points,
|
||||
cloud_to_packed_first_idx,
|
||||
num_points_per_cloud,
|
||||
image_size,
|
||||
radius,
|
||||
bin_size,
|
||||
max_points_per_bin);
|
||||
#else
|
||||
AT_ERROR("Not compiled with GPU support");
|
||||
#endif
|
||||
} else {
|
||||
return RasterizePointsCoarseCpu(
|
||||
points, image_size, radius, bin_size, max_points_per_bin);
|
||||
points,
|
||||
cloud_to_packed_first_idx,
|
||||
num_points_per_cloud,
|
||||
image_size,
|
||||
radius,
|
||||
bin_size,
|
||||
max_points_per_bin);
|
||||
}
|
||||
}
|
||||
|
||||
@ -123,7 +177,10 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsFineCuda(
|
||||
const int points_per_pixel);
|
||||
#endif
|
||||
// Args:
|
||||
// points: float32 Tensor of shape (N, P, 3)
|
||||
// points: Tensor of shape (P, 3) giving (packed) positions for
|
||||
// points in all N pointclouds in the batch where P is the total
|
||||
// number of points in the batch across all pointclouds. These points
|
||||
// are expected to be in NDC coordinates in the range [-1, 1].
|
||||
// bin_points: int32 Tensor of shape (N, B, B, M) giving the indices of points
|
||||
// that fall into each bin (output from coarse rasterization)
|
||||
// image_size: Size of image to generate (in pixels)
|
||||
@ -132,9 +189,11 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsFineCuda(
|
||||
// points_per_pixel: How many points to rasterize for each pixel
|
||||
//
|
||||
// Returns (same as rasterize_points):
|
||||
// idxs: int32 Tensor of shape (N, S, S, K) giving the indices of the closest
|
||||
// points_per_pixel points along the z-axis for each pixel, padded with
|
||||
// -1 for pixels hit by fewer than points_per_pixel points
|
||||
// idxs: int32 Tensor of shape (N, S, S, K) giving the indices of the
|
||||
// closest K points along the z-axis for each pixel, padded with -1 for
|
||||
// pixels hit by fewer than K points. The indices refer to points in
|
||||
// points packed i.e a tensor of shape (P, 3) representing the flattened
|
||||
// points for all pointclouds in the batch.
|
||||
// zbuf: float32 Tensor of shape (N, S, S, K) giving the depth of each of each
|
||||
// closest point for each pixel
|
||||
// dists: float32 Tensor of shape (N, S, S, K) giving squared Euclidean
|
||||
@ -177,7 +236,10 @@ torch::Tensor RasterizePointsBackwardCuda(
|
||||
const torch::Tensor& grad_dists);
|
||||
#endif
|
||||
// Args:
|
||||
// points: float32 Tensor of shape (N, P, 3)
|
||||
// points: Tensor of shape (P, 3) giving (packed) positions for
|
||||
// points in all N pointclouds in the batch where P is the total
|
||||
// number of points in the batch across all pointclouds. These points
|
||||
// are expected to be in NDC coordinates in the range [-1, 1].
|
||||
// idxs: int32 Tensor of shape (N, H, W, K) (from forward pass)
|
||||
// grad_zbuf: float32 Tensor of shape (N, H, W, K) giving upstream gradient
|
||||
// d(loss)/d(zbuf) of the distances from each pixel to its nearest
|
||||
@ -212,7 +274,15 @@ torch::Tensor RasterizePointsBackward(
|
||||
// it uses either naive or coarse-to-fine rasterization based on bin_size.
|
||||
//
|
||||
// Args:
|
||||
// points: Tensor of shape (N, P, 3) (in NDC)
|
||||
// points: Tensor of shape (P, 3) giving (packed) positions for
|
||||
// points in all N pointclouds in the batch where P is the total
|
||||
// number of points in the batch across all pointclouds. These points
|
||||
// are expected to be in NDC coordinates in the range [-1, 1].
|
||||
// cloud_to_packed_first_idx: LongTensor of shape (N) giving the index in
|
||||
// points_packed of the first point in each pointcloud
|
||||
// in the batch where N is the batch size.
|
||||
// num_points_per_cloud: LongTensor of shape (N) giving the number of points
|
||||
// for each pointcloud in the batch.
|
||||
// radius: Radius of each point (in NDC units)
|
||||
// image_size: (S) Size of the image to return (in pixels)
|
||||
// points_per_pixel: (K) The number of points to return for each pixel
|
||||
@ -223,8 +293,10 @@ torch::Tensor RasterizePointsBackward(
|
||||
//
|
||||
// Returns:
|
||||
// idxs: int32 Tensor of shape (N, S, S, K) giving the indices of the
|
||||
// closest points_per_pixel points along the z-axis for each pixel,
|
||||
// padded with -1 for pixels hit by fewer than points_per_pixel points
|
||||
// closest K points along the z-axis for each pixel, padded with -1 for
|
||||
// pixels hit by fewer than K points. The indices refer to points in
|
||||
// points packed i.e a tensor of shape (P, 3) representing the flattened
|
||||
// points for all pointclouds in the batch.
|
||||
// zbuf: float32 Tensor of shape (N, S, S, K) giving the depth of each of each
|
||||
// closest point for each pixel
|
||||
// dists: float32 Tensor of shape (N, S, S, K) giving squared Euclidean
|
||||
@ -232,6 +304,8 @@ torch::Tensor RasterizePointsBackward(
|
||||
// points along the z axis.
|
||||
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePoints(
|
||||
const torch::Tensor& points,
|
||||
const torch::Tensor& cloud_to_packed_first_idx,
|
||||
const torch::Tensor& num_points_per_cloud,
|
||||
const int image_size,
|
||||
const float radius,
|
||||
const int points_per_pixel,
|
||||
@ -239,11 +313,23 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePoints(
|
||||
const int max_points_per_bin) {
|
||||
if (bin_size == 0) {
|
||||
// Use the naive per-pixel implementation
|
||||
return RasterizePointsNaive(points, image_size, radius, points_per_pixel);
|
||||
return RasterizePointsNaive(
|
||||
points,
|
||||
cloud_to_packed_first_idx,
|
||||
num_points_per_cloud,
|
||||
image_size,
|
||||
radius,
|
||||
points_per_pixel);
|
||||
} else {
|
||||
// Use coarse-to-fine rasterization
|
||||
const auto bin_points = RasterizePointsCoarse(
|
||||
points, image_size, radius, bin_size, max_points_per_bin);
|
||||
points,
|
||||
cloud_to_packed_first_idx,
|
||||
num_points_per_cloud,
|
||||
image_size,
|
||||
radius,
|
||||
bin_size,
|
||||
max_points_per_bin);
|
||||
return RasterizePointsFine(
|
||||
points, bin_points, image_size, radius, bin_size, points_per_pixel);
|
||||
}
|
||||
|
@ -13,37 +13,49 @@ static float PixToNdc(const int i, const int S) {
|
||||
}
|
||||
|
||||
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsNaiveCpu(
|
||||
const torch::Tensor& points,
|
||||
const torch::Tensor& points, // (P, 3)
|
||||
const torch::Tensor& cloud_to_packed_first_idx, // (N)
|
||||
const torch::Tensor& num_points_per_cloud, // (N)
|
||||
const int image_size,
|
||||
const float radius,
|
||||
const int points_per_pixel) {
|
||||
const int N = points.size(0);
|
||||
const int P = points.size(1);
|
||||
const int32_t N = cloud_to_packed_first_idx.size(0); // batch_size.
|
||||
|
||||
const int S = image_size;
|
||||
const int K = points_per_pixel;
|
||||
|
||||
// Initialize output tensors.
|
||||
auto int_opts = points.options().dtype(torch::kInt32);
|
||||
auto float_opts = points.options().dtype(torch::kFloat32);
|
||||
torch::Tensor point_idxs = torch::full({N, S, S, K}, -1, int_opts);
|
||||
torch::Tensor zbuf = torch::full({N, S, S, K}, -1, float_opts);
|
||||
torch::Tensor pix_dists = torch::full({N, S, S, K}, -1, float_opts);
|
||||
|
||||
auto points_a = points.accessor<float, 3>();
|
||||
auto points_a = points.accessor<float, 2>();
|
||||
auto point_idxs_a = point_idxs.accessor<int32_t, 4>();
|
||||
auto zbuf_a = zbuf.accessor<float, 4>();
|
||||
auto pix_dists_a = pix_dists.accessor<float, 4>();
|
||||
|
||||
const float radius2 = radius * radius;
|
||||
for (int n = 0; n < N; ++n) {
|
||||
// Loop through each pointcloud in the batch.
|
||||
// Get the start index of the points in points_packed and the num points
|
||||
// in the point cloud.
|
||||
const int point_start_idx =
|
||||
cloud_to_packed_first_idx[n].item().to<int32_t>();
|
||||
const int point_stop_idx =
|
||||
(point_start_idx + num_points_per_cloud[n].item().to<int32_t>());
|
||||
|
||||
for (int yi = 0; yi < S; ++yi) {
|
||||
float yf = PixToNdc(yi, S);
|
||||
for (int xi = 0; xi < S; ++xi) {
|
||||
float xf = PixToNdc(xi, S);
|
||||
// Use a priority queue to hold (z, idx, r)
|
||||
std::priority_queue<std::tuple<float, int, float>> q;
|
||||
for (int p = 0; p < P; ++p) {
|
||||
const float px = points_a[n][p][0];
|
||||
const float py = points_a[n][p][1];
|
||||
const float pz = points_a[n][p][2];
|
||||
for (int p = point_start_idx; p < point_stop_idx; ++p) {
|
||||
const float px = points_a[p][0];
|
||||
const float py = points_a[p][1];
|
||||
const float pz = points_a[p][2];
|
||||
if (pz < 0) {
|
||||
continue;
|
||||
}
|
||||
@ -75,26 +87,37 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsNaiveCpu(
|
||||
}
|
||||
|
||||
torch::Tensor RasterizePointsCoarseCpu(
|
||||
const torch::Tensor& points,
|
||||
const torch::Tensor& points, // (P, 3)
|
||||
const torch::Tensor& cloud_to_packed_first_idx, // (N)
|
||||
const torch::Tensor& num_points_per_cloud, // (N)
|
||||
const int image_size,
|
||||
const float radius,
|
||||
const int bin_size,
|
||||
const int max_points_per_bin) {
|
||||
const int N = points.size(0);
|
||||
const int P = points.size(1);
|
||||
const int32_t N = cloud_to_packed_first_idx.size(0); // batch_size.
|
||||
|
||||
const int B = 1 + (image_size - 1) / bin_size; // Integer division round up
|
||||
const int M = max_points_per_bin;
|
||||
auto opts = points.options().dtype(torch::kInt32);
|
||||
torch::Tensor points_per_bin = torch::zeros({N, B, B}, opts);
|
||||
torch::Tensor bin_points = torch::full({N, B, B, M}, -1, opts);
|
||||
|
||||
auto points_a = points.accessor<float, 3>();
|
||||
auto points_a = points.accessor<float, 2>();
|
||||
auto points_per_bin_a = points_per_bin.accessor<int32_t, 3>();
|
||||
auto bin_points_a = bin_points.accessor<int32_t, 4>();
|
||||
|
||||
const float pixel_width = 2.0f / image_size;
|
||||
const float bin_width = pixel_width * bin_size;
|
||||
|
||||
for (int n = 0; n < N; ++n) {
|
||||
// Loop through each pointcloud in the batch.
|
||||
// Get the start index of the points in points_packed and the num points
|
||||
// in the point cloud.
|
||||
const int point_start_idx =
|
||||
cloud_to_packed_first_idx[n].item().to<int32_t>();
|
||||
const int point_stop_idx =
|
||||
(point_start_idx + num_points_per_cloud[n].item().to<int32_t>());
|
||||
|
||||
float bin_y_min = -1.0f;
|
||||
float bin_y_max = bin_y_min + bin_width;
|
||||
for (int by = 0; by < B; by++) {
|
||||
@ -102,10 +125,10 @@ torch::Tensor RasterizePointsCoarseCpu(
|
||||
float bin_x_max = bin_x_min + bin_width;
|
||||
for (int bx = 0; bx < B; bx++) {
|
||||
int32_t points_hit = 0;
|
||||
for (int32_t p = 0; p < P; p++) {
|
||||
float px = points_a[n][p][0];
|
||||
float py = points_a[n][p][1];
|
||||
float pz = points_a[n][p][2];
|
||||
for (int p = point_start_idx; p < point_stop_idx; ++p) {
|
||||
float px = points_a[p][0];
|
||||
float py = points_a[p][1];
|
||||
float pz = points_a[p][2];
|
||||
if (pz < 0) {
|
||||
continue;
|
||||
}
|
||||
@ -144,12 +167,13 @@ torch::Tensor RasterizePointsCoarseCpu(
|
||||
}
|
||||
|
||||
torch::Tensor RasterizePointsBackwardCpu(
|
||||
const torch::Tensor& points, // (N, P, 3)
|
||||
const torch::Tensor& points, // (P, 3)
|
||||
const torch::Tensor& idxs, // (N, H, W, K)
|
||||
const torch::Tensor& grad_zbuf, // (N, H, W, K)
|
||||
const torch::Tensor& grad_dists) { // (N, H, W, K)
|
||||
const int N = points.size(0);
|
||||
const int P = points.size(1);
|
||||
|
||||
const int N = idxs.size(0);
|
||||
const int P = points.size(0);
|
||||
const int H = idxs.size(1);
|
||||
const int W = idxs.size(2);
|
||||
const int K = idxs.size(3);
|
||||
@ -159,13 +183,13 @@ torch::Tensor RasterizePointsBackwardCpu(
|
||||
if (H != W) {
|
||||
AT_ERROR("RasterizePointsBackwardCpu only supports square images");
|
||||
}
|
||||
torch::Tensor grad_points = torch::zeros({N, P, 3}, points.options());
|
||||
torch::Tensor grad_points = torch::zeros({P, 3}, points.options());
|
||||
|
||||
auto points_a = points.accessor<float, 3>();
|
||||
auto points_a = points.accessor<float, 2>();
|
||||
auto idxs_a = idxs.accessor<int32_t, 4>();
|
||||
auto grad_dists_a = grad_dists.accessor<float, 4>();
|
||||
auto grad_zbuf_a = grad_zbuf.accessor<float, 4>();
|
||||
auto grad_points_a = grad_points.accessor<float, 3>();
|
||||
auto grad_points_a = grad_points.accessor<float, 2>();
|
||||
|
||||
for (int n = 0; n < N; ++n) { // Loop over images in the batch
|
||||
for (int y = 0; y < H; ++y) { // Loop over rows in the image
|
||||
@ -178,16 +202,16 @@ torch::Tensor RasterizePointsBackwardCpu(
|
||||
break;
|
||||
}
|
||||
const float grad_dist2 = grad_dists_a[n][y][x][k];
|
||||
const float px = points_a[n][p][0];
|
||||
const float py = points_a[n][p][1];
|
||||
const float px = points_a[p][0];
|
||||
const float py = points_a[p][1];
|
||||
const float dx = px - xf;
|
||||
const float dy = py - yf;
|
||||
// Remember: dists[n][y][x][k] = dx * dx + dy * dy;
|
||||
const float grad_px = 2.0f * grad_dist2 * dx;
|
||||
const float grad_py = 2.0f * grad_dist2 * dy;
|
||||
grad_points_a[n][p][0] += grad_px;
|
||||
grad_points_a[n][p][1] += grad_py;
|
||||
grad_points_a[n][p][2] += grad_zbuf_a[n][y][x][k];
|
||||
grad_points_a[p][0] += grad_px;
|
||||
grad_points_a[p][1] += grad_py;
|
||||
grad_points_a[p][2] += grad_zbuf_a[n][y][x][k];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user