Non Square image rasterization for pointclouds

Summary: Similar to non square image rasterization for meshes, apply the same updates to the pointcloud rasterizer.

Main API Change:
- PointRasterizationSettings now accepts a tuple/list of (H, W) for the image size.

Reviewed By: jcjohnson
Differential Revision: D25465206
fbshipit-source-id: 7370d83c431af1b972158cecae19d82364623380

Committed by: Facebook GitHub Bot
Parent: 569e5229a9
Commit: 3d769a66cb
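On the C++ side, every rasterize_points entry point that previously took a single square size (const int image_size) now takes a (height, width) pair as const std::tuple<int, int>. A minimal caller sketch follows; the tensor setup and include path are illustrative assumptions, not part of the commit:

#include <tuple>
#include <torch/torch.h>
#include "rasterize_points/rasterize_points.h" // assumed include path

int main() {
  // Hypothetical input: one cloud of 1000 points already in NDC coordinates.
  torch::Tensor points = torch::rand({1000, 3}) * 2 - 1;
  torch::Tensor first_idx = torch::zeros({1}, torch::kInt64);
  torch::Tensor num_points = torch::full({1}, 1000, torch::kInt64);
  torch::Tensor radius = torch::full({1000}, 0.01f);

  // New API: image_size is an (H, W) tuple instead of a single int S.
  const std::tuple<int, int> image_size = std::make_tuple(512, 256);

  auto out = RasterizePointsNaive(
      points, first_idx, num_points, image_size, radius, /*points_per_pixel=*/8);
  return 0;
}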
@@ -30,15 +30,15 @@ __global__ void alphaCompositeCudaForwardKernel(
   // Get the batch and index
   const int batch = blockIdx.x;

-  const int num_pixels = C * W * H;
+  const int num_pixels = C * H * W;
   const int num_threads = gridDim.y * blockDim.x;
   const int tid = blockIdx.y * blockDim.x + threadIdx.x;

   // Iterate over each feature in each pixel
   for (int pid = tid; pid < num_pixels; pid += num_threads) {
-    int ch = pid / (W * H);
-    int j = (pid % (W * H)) / H;
-    int i = (pid % (W * H)) % H;
+    int ch = pid / (H * W);
+    int j = (pid % (H * W)) / W;
+    int i = (pid % (H * W)) % W;

     // alphacomposite the different values
     float cum_alpha = 1.;
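The indexing change above switches the flattened pixel layout from (C, W, H) to (C, H, W). A small standalone sketch (not part of the commit) of the pid -> (ch, j, i) decomposition and its inverse:

#include <cassert>

int main() {
  const int C = 3, H = 2, W = 4;
  for (int pid = 0; pid < C * H * W; ++pid) {
    const int ch = pid / (H * W);          // channel
    const int j = (pid % (H * W)) / W;     // row, in [0, H)
    const int i = (pid % (H * W)) % W;     // column, in [0, W)
    assert(pid == ch * H * W + j * W + i); // round-trips to the flat index
  }
  return 0;
}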
@@ -81,16 +81,16 @@ __global__ void alphaCompositeCudaBackwardKernel(
   // Get the batch and index
   const int batch = blockIdx.x;

-  const int num_pixels = C * W * H;
+  const int num_pixels = C * H * W;
   const int num_threads = gridDim.y * blockDim.x;
   const int tid = blockIdx.y * blockDim.x + threadIdx.x;

   // Parallelize over each feature in each pixel in images of size H * W,
   // for each image in the batch of size batch_size
   for (int pid = tid; pid < num_pixels; pid += num_threads) {
-    int ch = pid / (W * H);
-    int j = (pid % (W * H)) / H;
-    int i = (pid % (W * H)) % H;
+    int ch = pid / (H * W);
+    int j = (pid % (H * W)) / W;
+    int i = (pid % (H * W)) % W;

     // alphacomposite the different values
     float cum_alpha = 1.;

@@ -11,13 +11,13 @@
 //     features: FloatTensor of shape (C, P) which gives the features
 //         of each point where C is the size of the feature and
 //         P the number of points.
-//     alphas: FloatTensor of shape (N, points_per_pixel, W, W) where
+//     alphas: FloatTensor of shape (N, points_per_pixel, H, W) where
 //         points_per_pixel is the number of points in the z-buffer
-//         sorted in z-order, and W is the image size.
-//     points_idx: IntTensor of shape (N, points_per_pixel, W, W) giving the
+//         sorted in z-order, and (H, W) is the image size.
+//     points_idx: IntTensor of shape (N, points_per_pixel, H, W) giving the
 //         indices of the nearest points at each pixel, sorted in z-order.
 // Returns:
-//     weighted_fs: FloatTensor of shape (N, C, W, W) giving the accumulated
+//     weighted_fs: FloatTensor of shape (N, C, H, W) giving the accumulated
 //         feature for each point. Concretely, it gives:
 //         weighted_fs[b,c,i,j] = sum_k cum_alpha_k *
 //             features[c,points_idx[b,k,i,j]]
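For reference, a tiny sketch (not from the commit) of the compositing rule quoted above for one pixel, assuming cum_alpha_k = alphas[b,k,i,j] * prod over l < k of (1 - alphas[b,l,i,j]):

#include <cstdio>

int main() {
  const int K = 3;
  const float alphas[K] = {0.5f, 0.25f, 1.0f}; // hypothetical z-sorted alphas
  const float feats[K] = {1.0f, 2.0f, 4.0f};   // features of the K nearest points
  float cum_alpha = 1.0f;                      // transmittance so far
  float weighted_f = 0.0f;
  for (int k = 0; k < K; ++k) {
    weighted_f += cum_alpha * alphas[k] * feats[k];
    cum_alpha *= (1.0f - alphas[k]);
  }
  std::printf("weighted_f = %f\n", weighted_f); // 0.5*1 + 0.5*0.25*2 + 0.375*1*4 = 2.25
  return 0;
}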
@@ -30,16 +30,16 @@ __global__ void weightedSumNormCudaForwardKernel(
   // Get the batch and index
   const int batch = blockIdx.x;

-  const int num_pixels = C * W * H;
+  const int num_pixels = C * H * W;
   const int num_threads = gridDim.y * blockDim.x;
   const int tid = blockIdx.y * blockDim.x + threadIdx.x;

   // Parallelize over each feature in each pixel in images of size H * W,
   // for each image in the batch of size batch_size
   for (int pid = tid; pid < num_pixels; pid += num_threads) {
-    int ch = pid / (W * H);
-    int j = (pid % (W * H)) / H;
-    int i = (pid % (W * H)) % H;
+    int ch = pid / (H * W);
+    int j = (pid % (H * W)) / W;
+    int i = (pid % (H * W)) % W;

     // Store the accumulated alpha value
     float cum_alpha = 0.;

@@ -101,9 +101,9 @@ __global__ void weightedSumNormCudaBackwardKernel(
   // Parallelize over each feature in each pixel in images of size H * W,
   // for each image in the batch of size batch_size
   for (int pid = tid; pid < num_pixels; pid += num_threads) {
-    int ch = pid / (W * H);
-    int j = (pid % (W * H)) / H;
-    int i = (pid % (W * H)) % H;
+    int ch = pid / (H * W);
+    int j = (pid % (H * W)) / W;
+    int i = (pid % (H * W)) % W;

     float sum_alpha = 0.;
     float sum_alphafs = 0.;

@@ -11,13 +11,13 @@
 //     features: FloatTensor of shape (C, P) which gives the features
 //         of each point where C is the size of the feature and
 //         P the number of points.
-//     alphas: FloatTensor of shape (N, points_per_pixel, W, W) where
+//     alphas: FloatTensor of shape (N, points_per_pixel, H, W) where
 //         points_per_pixel is the number of points in the z-buffer
-//         sorted in z-order, and W is the image size.
-//     points_idx: IntTensor of shape (N, points_per_pixel, W, W) giving the
+//         sorted in z-order, and (H, W) is the image size.
+//     points_idx: IntTensor of shape (N, points_per_pixel, H, W) giving the
 //         indices of the nearest points at each pixel, sorted in z-order.
 // Returns:
-//     weighted_fs: FloatTensor of shape (N, C, W, W) giving the accumulated
+//     weighted_fs: FloatTensor of shape (N, C, H, W) giving the accumulated
 //         feature in each point. Concretely, it gives:
 //         weighted_fs[b,c,i,j] = sum_k alphas[b,k,i,j] *
 //             features[c,points_idx[b,k,i,j]] / sum_k alphas[b,k,i,j]
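A corresponding sketch (not from the commit) of the normalized weighted sum quoted above, evaluated for one pixel:

#include <cstdio>

int main() {
  const int K = 3;
  const float alphas[K] = {0.5f, 0.25f, 1.0f}; // hypothetical per-point weights
  const float feats[K] = {1.0f, 2.0f, 4.0f};
  float sum_alpha = 0.0f, sum_alphafs = 0.0f;
  for (int k = 0; k < K; ++k) {
    sum_alphafs += alphas[k] * feats[k];
    sum_alpha += alphas[k];
  }
  // Guard the division here; the kernel's exact epsilon handling is not shown in this diff.
  const float weighted_f = sum_alpha > 0.0f ? sum_alphafs / sum_alpha : 0.0f;
  std::printf("weighted_f = %f\n", weighted_f); // (0.5 + 0.5 + 4.0) / 1.75 = 2.857...
  return 0;
}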
@@ -28,16 +28,16 @@ __global__ void weightedSumCudaForwardKernel(
   // Get the batch and index
   const int batch = blockIdx.x;

-  const int num_pixels = C * W * H;
+  const int num_pixels = C * H * W;
   const int num_threads = gridDim.y * blockDim.x;
   const int tid = blockIdx.y * blockDim.x + threadIdx.x;

   // Parallelize over each feature in each pixel in images of size H * W,
   // for each image in the batch of size batch_size
   for (int pid = tid; pid < num_pixels; pid += num_threads) {
-    int ch = pid / (W * H);
-    int j = (pid % (W * H)) / H;
-    int i = (pid % (W * H)) % H;
+    int ch = pid / (H * W);
+    int j = (pid % (H * W)) / W;
+    int i = (pid % (H * W)) % W;

     // Iterate through the closest K points for this pixel
     for (int k = 0; k < points_idx.size(1); ++k) {

@@ -76,16 +76,16 @@ __global__ void weightedSumCudaBackwardKernel(
   // Get the batch and index
   const int batch = blockIdx.x;

-  const int num_pixels = C * W * H;
+  const int num_pixels = C * H * W;
   const int num_threads = gridDim.y * blockDim.x;
   const int tid = blockIdx.y * blockDim.x + threadIdx.x;

   // Iterate over each pixel to compute the contribution to the
   // gradient for the features and weights
   for (int pid = tid; pid < num_pixels; pid += num_threads) {
-    int ch = pid / (W * H);
-    int j = (pid % (W * H)) / H;
-    int i = (pid % (W * H)) % H;
+    int ch = pid / (H * W);
+    int j = (pid % (H * W)) / W;
+    int i = (pid % (H * W)) % W;

     // Iterate through the closest K points for this pixel
     for (int k = 0; k < points_idx.size(1); ++k) {

@@ -11,13 +11,13 @@
 //     features: FloatTensor of shape (C, P) which gives the features
 //         of each point where C is the size of the feature and
 //         P the number of points.
-//     alphas: FloatTensor of shape (N, points_per_pixel, W, W) where
+//     alphas: FloatTensor of shape (N, points_per_pixel, H, W) where
 //         points_per_pixel is the number of points in the z-buffer
-//         sorted in z-order, and W is the image size.
+//         sorted in z-order, and (H, W) is the image size.
 //     points_idx: IntTensor of shape (N, points_per_pixel, W, W) giving the
 //         indices of the nearest points at each pixel, sorted in z-order.
 // Returns:
-//     weighted_fs: FloatTensor of shape (N, C, W, W) giving the accumulated
+//     weighted_fs: FloatTensor of shape (N, C, H, W) giving the accumulated
 //         feature in each point. Concretely, it gives:
 //         weighted_fs[b,c,i,j] = sum_k alphas[b,k,i,j] *
 //             features[c,points_idx[b,k,i,j]]
@@ -452,7 +452,6 @@ __global__ void RasterizeMeshesBackwardCudaKernel(
   const bool inside = b_pp.x > 0.0f && b_pp.y > 0.0f && b_pp.z > 0.0f;
   const float sign = inside ? -1.0f : 1.0f;

-  // TODO(T52813608) Add support for non-square images.
   auto grad_dist_f = PointTriangleDistanceBackward(
       pxy, v0xy, v1xy, v2xy, sign * grad_dist_upstream);
   const float2 ddist_d_v0 = thrust::get<1>(grad_dist_f);

@@ -606,7 +605,7 @@ __global__ void RasterizeMeshesCoarseCudaKernel(
   const float half_pix_x = NDC_x_half_range / W;
   const float half_pix_y = NDC_y_half_range / H;

-  // This is a boolean array of shape (num_bins, num_bins, chunk_size)
+  // This is a boolean array of shape (num_bins_y, num_bins_x, chunk_size)
   // stored in shared memory that will track whether each point in the chunk
   // falls into each bin of the image.
   BitMask binmask((unsigned int*)sbuf, num_bins_y, num_bins_x, chunk_size);

@@ -755,7 +754,7 @@ at::Tensor RasterizeMeshesCoarseCuda(
   const int num_bins_y = 1 + (H - 1) / bin_size;
   const int num_bins_x = 1 + (W - 1) / bin_size;

-  if (num_bins_y >= kMaxFacesPerBin || num_bins_x >= kMaxFacesPerBin) {
+  if (num_bins_y >= kMaxItemsPerBin || num_bins_x >= kMaxItemsPerBin) {
     std::stringstream ss;
     ss << "In Coarse Rasterizer got num_bins_y: " << num_bins_y
        << ", num_bins_x: " << num_bins_x << ", "

@@ -800,7 +799,7 @@ at::Tensor RasterizeMeshesCoarseCuda(
 // ****************************************************************************
 __global__ void RasterizeMeshesFineCudaKernel(
     const float* face_verts, // (F, 3, 3)
-    const int32_t* bin_faces, // (N, B, B, T)
+    const int32_t* bin_faces, // (N, BH, BW, T)
     const float blur_radius,
     const int bin_size,
     const bool perspective_correct,

@@ -813,12 +812,12 @@ __global__ void RasterizeMeshesFineCudaKernel(
     const int H,
     const int W,
     const int K,
-    int64_t* face_idxs, // (N, S, S, K)
-    float* zbuf, // (N, S, S, K)
-    float* pix_dists, // (N, S, S, K)
-    float* bary // (N, S, S, K, 3)
+    int64_t* face_idxs, // (N, H, W, K)
+    float* zbuf, // (N, H, W, K)
+    float* pix_dists, // (N, H, W, K)
+    float* bary // (N, H, W, K, 3)
 ) {
-  // This can be more than S^2 if S % bin_size != 0
+  // This can be more than H * W if H or W are not divisible by bin_size.
   int num_pixels = N * BH * BW * bin_size * bin_size;
   int num_threads = gridDim.x * blockDim.x;
   int tid = blockIdx.x * blockDim.x + threadIdx.x;
@@ -5,41 +5,11 @@
 #include <list>
 #include <queue>
 #include <tuple>
+#include "rasterize_points/rasterization_utils.h"
 #include "utils/geometry_utils.h"
 #include "utils/vec2.h"
 #include "utils/vec3.h"

-// The default value of the NDC range is [-1, 1], however in the case that
-// H != W, the NDC range is set such that the shorter side has range [-1, 1] and
-// the longer side is scaled by the ratio of H:W. S1 is the dimension for which
-// the NDC range is calculated and S2 is the other image dimension.
-// e.g. to get the NDC x range S1 = W and S2 = H
-float NonSquareNdcRange(int S1, int S2) {
-  float range = 2.0f;
-  if (S1 > S2) {
-    range = ((S1 / S2) * range);
-  }
-  return range;
-}
-
-// Given a pixel coordinate 0 <= i < S1, convert it to a normalized device
-// coordinates. We divide the NDC range into S1 evenly-sized
-// pixels, and assume that each pixel falls in the *center* of its range.
-// The default value of the NDC range is [-1, 1], however in the case that
-// H != W, the NDC range is set such that the shorter side has range [-1, 1] and
-// the longer side is scaled by the ratio of H:W. The dimension of i should be
-// S1 and the other image dimension is S2 For example, to get the x and y NDC
-// coordinates or a given pixel i:
-// x = PixToNonSquareNdc(i, W, H)
-// y = PixToNonSquareNdc(i, H, W)
-float PixToNonSquareNdc(int i, int S1, int S2) {
-  float range = NonSquareNdcRange(S1, S2);
-  // NDC: offset + (i * pixel_width + half_pixel_width)
-  // The NDC range is [-range/2, range/2].
-  const float offset = (range / 2.0f);
-  return -offset + (range * i + offset) / S1;
-}
-
 // Get (x, y, z) values for vertex from (3, 3) tensor face.
 template <typename Face>
 auto ExtractVerts(const Face& face, const int vertex_index) {

@@ -2,16 +2,6 @@

 #pragma once

-// Given a pixel coordinate 0 <= i < S, convert it to a normalized device
-// coordinates in the range [-1, 1]. We divide the NDC range into S evenly-sized
-// pixels, and assume that each pixel falls in the *center* of its range.
-// TODO: delete this function after updating the pointcloud rasterizer to
-// support non square images.
-__device__ inline float PixToNdc(int i, int S) {
-  // NDC: x-offset + (i * pixel_width + half_pixel_width)
-  return -1.0 + (2 * i + 1.0) / S;
-}
-
 // The default value of the NDC range is [-1, 1], however in the case that
 // H != W, the NDC range is set such that the shorter side has range [-1, 1] and
 // the longer side is scaled by the ratio of H:W. S1 is the dimension for which

@@ -50,7 +40,7 @@ __device__ inline float PixToNonSquareNdc(int i, int S1, int S2) {
 // TODO: is 8 enough? Would increasing have performance considerations?
 const int32_t kMaxPointsPerPixel = 150;

-const int32_t kMaxFacesPerBin = 22;
+const int32_t kMaxItemsPerBin = 22;

 template <typename T>
 __device__ inline void BubbleSort(T* arr, int n) {

pytorch3d/csrc/rasterize_points/rasterization_utils.h (new file, 34 lines)
@@ -0,0 +1,34 @@
+// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+
+#pragma once
+
+// The default value of the NDC range is [-1, 1], however in the case that
+// H != W, the NDC range is set such that the shorter side has range [-1, 1] and
+// the longer side is scaled by the ratio of H:W. S1 is the dimension for which
+// the NDC range is calculated and S2 is the other image dimension.
+// e.g. to get the NDC x range S1 = W and S2 = H
+inline float NonSquareNdcRange(int S1, int S2) {
+  float range = 2.0f;
+  if (S1 > S2) {
+    range = ((S1 / S2) * range);
+  }
+  return range;
+}
+
+// Given a pixel coordinate 0 <= i < S1, convert it to a normalized device
+// coordinates. We divide the NDC range into S1 evenly-sized
+// pixels, and assume that each pixel falls in the *center* of its range.
+// The default value of the NDC range is [-1, 1], however in the case that
+// H != W, the NDC range is set such that the shorter side has range [-1, 1] and
+// the longer side is scaled by the ratio of H:W. The dimension of i should be
+// S1 and the other image dimension is S2 For example, to get the x and y NDC
+// coordinates or a given pixel i:
+// x = PixToNonSquareNdc(i, W, H)
+// y = PixToNonSquareNdc(i, H, W)
+inline float PixToNonSquareNdc(int i, int S1, int S2) {
+  float range = NonSquareNdcRange(S1, S2);
+  // NDC: offset + (i * pixel_width + half_pixel_width)
+  // The NDC range is [-range/2, range/2].
+  const float offset = (range / 2.0f);
+  return -offset + (range * i + offset) / S1;
+}
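A small host-side driver (a sketch, not part of the commit) showing what these helpers return for a 64 x 128 image; since W > H, the x range stretches beyond [-1, 1] while y stays [-1, 1]:

#include <cstdio>
#include "rasterize_points/rasterization_utils.h" // the new header above

int main() {
  const int H = 64, W = 128;
  std::printf("x range: %.3f\n", NonSquareNdcRange(W, H)); // 4.0, i.e. [-2, 2]
  std::printf("y range: %.3f\n", NonSquareNdcRange(H, W)); // 2.0, i.e. [-1, 1]
  // Centers of the first and last pixels along each axis.
  std::printf("x: %.6f .. %.6f\n",
              PixToNonSquareNdc(0, W, H), PixToNonSquareNdc(W - 1, W, H));
  std::printf("y: %.6f .. %.6f\n",
              PixToNonSquareNdc(0, H, W), PixToNonSquareNdc(H - 1, H, W));
  return 0;
}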
@@ -85,26 +85,28 @@ __global__ void RasterizePointsNaiveCudaKernel(
     const int64_t* num_points_per_cloud, // (N)
     const float* radius,
     const int N,
-    const int S,
+    const int H,
+    const int W,
     const int K,
-    int32_t* point_idxs, // (N, S, S, K)
-    float* zbuf, // (N, S, S, K)
-    float* pix_dists) { // (N, S, S, K)
+    int32_t* point_idxs, // (N, H, W, K)
+    float* zbuf, // (N, H, W, K)
+    float* pix_dists) { // (N, H, W, K)
   // Simple version: One thread per output pixel
   const int num_threads = gridDim.x * blockDim.x;
   const int tid = blockDim.x * blockIdx.x + threadIdx.x;
-  for (int i = tid; i < N * S * S; i += num_threads) {
+  for (int i = tid; i < N * H * W; i += num_threads) {
     // Convert linear index to 3D index
-    const int n = i / (S * S); // Batch index
-    const int pix_idx = i % (S * S);
+    const int n = i / (H * W); // Batch index
+    const int pix_idx = i % (H * W);

     // Reverse ordering of the X and Y axis as the camera coordinates
     // assume that +Y is pointing up and +X is pointing left.
-    const int yi = S - 1 - pix_idx / S;
-    const int xi = S - 1 - pix_idx % S;
+    const int yi = H - 1 - pix_idx / W;
+    const int xi = W - 1 - pix_idx % W;

-    const float xf = PixToNdc(xi, S);
-    const float yf = PixToNdc(yi, S);
+    // screen coordinates to ndc coordiantes of pixel.
+    const float xf = PixToNonSquareNdc(xi, W, H);
+    const float yf = PixToNonSquareNdc(yi, H, W);

     // For keeping track of the K closest points we want a data structure
     // that (1) gives O(1) access to the closest point for easy comparisons,

@@ -132,7 +134,7 @@ __global__ void RasterizePointsNaiveCudaKernel(
         points, p_idx, q_size, q_max_z, q_max_idx, q, radius, xf, yf, K);
     }
     BubbleSort(q, q_size);
-    int idx = n * S * S * K + pix_idx * K;
+    int idx = n * H * W * K + pix_idx * K;
     for (int k = 0; k < q_size; ++k) {
       point_idxs[idx + k] = q[k].idx;
       zbuf[idx + k] = q[k].z;

@@ -145,7 +147,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> RasterizePointsNaiveCuda(
     const at::Tensor& points, // (P. 3)
     const at::Tensor& cloud_to_packed_first_idx, // (N)
     const at::Tensor& num_points_per_cloud, // (N)
-    const int image_size,
+    const std::tuple<int, int> image_size,
     const at::Tensor& radius,
     const int points_per_pixel) {
   // Check inputs are on the same device

@@ -169,7 +171,8 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> RasterizePointsNaiveCuda(
       "num_points_per_cloud must have same size first dimension as cloud_to_packed_first_idx");

   const int N = num_points_per_cloud.size(0); // batch size.
-  const int S = image_size;
+  const int H = std::get<0>(image_size);
+  const int W = std::get<1>(image_size);
   const int K = points_per_pixel;

   if (K > kMaxPointsPerPixel) {

@@ -180,9 +183,9 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> RasterizePointsNaiveCuda(

   auto int_opts = num_points_per_cloud.options().dtype(at::kInt);
   auto float_opts = points.options().dtype(at::kFloat);
-  at::Tensor point_idxs = at::full({N, S, S, K}, -1, int_opts);
-  at::Tensor zbuf = at::full({N, S, S, K}, -1, float_opts);
-  at::Tensor pix_dists = at::full({N, S, S, K}, -1, float_opts);
+  at::Tensor point_idxs = at::full({N, H, W, K}, -1, int_opts);
+  at::Tensor zbuf = at::full({N, H, W, K}, -1, float_opts);
+  at::Tensor pix_dists = at::full({N, H, W, K}, -1, float_opts);

   if (point_idxs.numel() == 0) {
     AT_CUDA_CHECK(cudaGetLastError());

@@ -197,7 +200,8 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> RasterizePointsNaiveCuda(
       num_points_per_cloud.contiguous().data_ptr<int64_t>(),
       radius.contiguous().data_ptr<float>(),
       N,
-      S,
+      H,
+      W,
       K,
       point_idxs.contiguous().data_ptr<int32_t>(),
       zbuf.contiguous().data_ptr<float>(),

@@ -218,7 +222,8 @@ __global__ void RasterizePointsCoarseCudaKernel(
     const float* radius,
     const int N,
     const int P,
-    const int S,
+    const int H,
+    const int W,
     const int bin_size,
     const int chunk_size,
     const int max_points_per_bin,

@@ -226,13 +231,26 @@ __global__ void RasterizePointsCoarseCudaKernel(
     int* bin_points) {
   extern __shared__ char sbuf[];
   const int M = max_points_per_bin;
-  const int num_bins = 1 + (S - 1) / bin_size; // Integer divide round up
-  const float half_pix = 1.0f / S; // Size of half a pixel in NDC units

-  // This is a boolean array of shape (num_bins, num_bins, chunk_size)
+  // Integer divide round up
+  const int num_bins_x = 1 + (W - 1) / bin_size;
+  const int num_bins_y = 1 + (H - 1) / bin_size;
+
+  // NDC range depends on the ratio of W/H
+  // The shorter side from (H, W) is given an NDC range of 2.0 and
+  // the other side is scaled by the ratio of H:W.
+  const float NDC_x_half_range = NonSquareNdcRange(W, H) / 2.0f;
+  const float NDC_y_half_range = NonSquareNdcRange(H, W) / 2.0f;
+
+  // Size of half a pixel in NDC units is the NDC half range
+  // divided by the corresponding image dimension
+  const float half_pix_x = NDC_x_half_range / W;
+  const float half_pix_y = NDC_y_half_range / H;
+
+  // This is a boolean array of shape (num_bins_y, num_bins_x, chunk_size)
   // stored in shared memory that will track whether each point in the chunk
   // falls into each bin of the image.
-  BitMask binmask((unsigned int*)sbuf, num_bins, num_bins, chunk_size);
+  BitMask binmask((unsigned int*)sbuf, num_bins_y, num_bins_x, chunk_size);

   // Have each block handle a chunk of points and build a 3D bitmask in
   // shared memory to mark which points hit which bins. In this first phase,
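A sketch (not part of the commit) of how the coarse rasterizer's bin grid is sized for a non-square image, using the same round-up arithmetic as the kernel above:

#include <cstdio>

int main() {
  const int H = 100, W = 250, bin_size = 64;
  const int num_bins_y = 1 + (H - 1) / bin_size; // ceil(100 / 64) = 2
  const int num_bins_x = 1 + (W - 1) / bin_size; // ceil(250 / 64) = 4
  // The grid can cover more pixels than H * W when the sizes are not
  // divisible by bin_size; the fine kernel skips pixels with yi >= H or xi >= W.
  std::printf("bins: %d x %d, covering %d x %d pixels\n",
              num_bins_y, num_bins_x,
              num_bins_y * bin_size, num_bins_x * bin_size);
  return 0;
}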
@@ -279,22 +297,24 @@ __global__ void RasterizePointsCoarseCudaKernel(
       // For example we could compute the exact bin where the point falls,
       // then check neighboring bins. This way we wouldn't have to check
       // all bins (however then we might have more warp divergence?)
-      for (int by = 0; by < num_bins; ++by) {
-        // Get y extent for the bin. PixToNdc gives us the location of
+      for (int by = 0; by < num_bins_y; ++by) {
+        // Get y extent for the bin. PixToNonSquareNdc gives us the location of
         // the center of each pixel, so we need to add/subtract a half
         // pixel to get the true extent of the bin.
-        const float by0 = PixToNdc(by * bin_size, S) - half_pix;
-        const float by1 = PixToNdc((by + 1) * bin_size - 1, S) + half_pix;
+        const float by0 = PixToNonSquareNdc(by * bin_size, H, W) - half_pix_y;
+        const float by1 =
+            PixToNonSquareNdc((by + 1) * bin_size - 1, H, W) + half_pix_y;
         const bool y_overlap = (py0 <= by1) && (by0 <= py1);

         if (!y_overlap) {
           continue;
         }
-        for (int bx = 0; bx < num_bins; ++bx) {
+        for (int bx = 0; bx < num_bins_x; ++bx) {
           // Get x extent for the bin; again we need to adjust the
-          // output of PixToNdc by half a pixel.
-          const float bx0 = PixToNdc(bx * bin_size, S) - half_pix;
-          const float bx1 = PixToNdc((bx + 1) * bin_size - 1, S) + half_pix;
+          // output of PixToNonSquareNdc by half a pixel.
+          const float bx0 = PixToNonSquareNdc(bx * bin_size, W, H) - half_pix_x;
+          const float bx1 =
+              PixToNonSquareNdc((bx + 1) * bin_size - 1, W, H) + half_pix_x;
           const bool x_overlap = (px0 <= bx1) && (bx0 <= px1);

           if (x_overlap) {

@@ -307,12 +327,13 @@ __global__ void RasterizePointsCoarseCudaKernel(
     // Now we have processed every point in the current chunk. We need to
     // count the number of points in each bin so we can write the indices
     // out to global memory. We have each thread handle a different bin.
-    for (int byx = threadIdx.x; byx < num_bins * num_bins; byx += blockDim.x) {
-      const int by = byx / num_bins;
-      const int bx = byx % num_bins;
+    for (int byx = threadIdx.x; byx < num_bins_y * num_bins_x;
+         byx += blockDim.x) {
+      const int by = byx / num_bins_x;
+      const int bx = byx % num_bins_x;
       const int count = binmask.count(by, bx);
       const int points_per_bin_idx =
-          batch_idx * num_bins * num_bins + by * num_bins + bx;
+          batch_idx * num_bins_y * num_bins_x + by * num_bins_x + bx;

       // This atomically increments the (global) number of points found
       // in the current bin, and gets the previous value of the counter;

@@ -322,8 +343,8 @@ __global__ void RasterizePointsCoarseCudaKernel(

       // Now loop over the binmask and write the active bits for this bin
       // out to bin_points.
-      int next_idx = batch_idx * num_bins * num_bins * M + by * num_bins * M +
-          bx * M + start;
+      int next_idx = batch_idx * num_bins_y * num_bins_x * M +
+          by * num_bins_x * M + bx * M + start;
       for (int p = 0; p < chunk_size; ++p) {
         if (binmask.get(by, bx, p)) {
           // TODO: Throw an error if next_idx >= M -- this means that

@@ -342,7 +363,7 @@ at::Tensor RasterizePointsCoarseCuda(
     const at::Tensor& points, // (P, 3)
     const at::Tensor& cloud_to_packed_first_idx, // (N)
     const at::Tensor& num_points_per_cloud, // (N)
-    const int image_size,
+    const std::tuple<int, int> image_size,
     const at::Tensor& radius,
     const int bin_size,
     const int max_points_per_bin) {

@@ -363,20 +384,28 @@ at::Tensor RasterizePointsCoarseCuda(
   at::cuda::CUDAGuard device_guard(points.device());
   cudaStream_t stream = at::cuda::getCurrentCUDAStream();

+  const int H = std::get<0>(image_size);
+  const int W = std::get<1>(image_size);
+
   const int P = points.size(0);
   const int N = num_points_per_cloud.size(0);
-  const int num_bins = 1 + (image_size - 1) / bin_size; // divide round up
   const int M = max_points_per_bin;

-  if (num_bins >= 22) {
+  // Integer divide round up.
+  const int num_bins_y = 1 + (H - 1) / bin_size;
+  const int num_bins_x = 1 + (W - 1) / bin_size;
+
+  if (num_bins_y >= kMaxItemsPerBin || num_bins_x >= kMaxItemsPerBin) {
     // Make sure we do not use too much shared memory.
     std::stringstream ss;
-    ss << "Got " << num_bins << "; that's too many!";
+    ss << "In Coarse Rasterizer got num_bins_y: " << num_bins_y
+       << ", num_bins_x: " << num_bins_x << ", "
+       << "; that's too many!";
     AT_ERROR(ss.str());
   }
   auto opts = num_points_per_cloud.options().dtype(at::kInt);
-  at::Tensor points_per_bin = at::zeros({N, num_bins, num_bins}, opts);
-  at::Tensor bin_points = at::full({N, num_bins, num_bins, M}, -1, opts);
+  at::Tensor points_per_bin = at::zeros({N, num_bins_y, num_bins_x}, opts);
+  at::Tensor bin_points = at::full({N, num_bins_y, num_bins_x, M}, -1, opts);

   if (bin_points.numel() == 0) {
     AT_CUDA_CHECK(cudaGetLastError());

@@ -384,7 +413,7 @@ at::Tensor RasterizePointsCoarseCuda(
   }

   const int chunk_size = 512;
-  const size_t shared_size = num_bins * num_bins * chunk_size / 8;
+  const size_t shared_size = num_bins_y * num_bins_x * chunk_size / 8;
   const size_t blocks = 64;
   const size_t threads = 512;

@@ -395,7 +424,8 @@ at::Tensor RasterizePointsCoarseCuda(
       radius.contiguous().data_ptr<float>(),
       N,
       P,
-      image_size,
+      H,
+      W,
       bin_size,
       chunk_size,
       M,

@@ -412,19 +442,21 @@ at::Tensor RasterizePointsCoarseCuda(

 __global__ void RasterizePointsFineCudaKernel(
     const float* points, // (P, 3)
-    const int32_t* bin_points, // (N, B, B, T)
+    const int32_t* bin_points, // (N, BH, BW, T)
     const float* radius,
     const int bin_size,
     const int N,
-    const int B, // num_bins
+    const int BH, // num_bins y
+    const int BW, // num_bins x
     const int M,
-    const int S,
+    const int H,
+    const int W,
     const int K,
-    int32_t* point_idxs, // (N, S, S, K)
-    float* zbuf, // (N, S, S, K)
-    float* pix_dists) { // (N, S, S, K)
-  // This can be more than S^2 if S is not dividable by bin_size.
-  const int num_pixels = N * B * B * bin_size * bin_size;
+    int32_t* point_idxs, // (N, H, W, K)
+    float* zbuf, // (N, H, W, K)
+    float* pix_dists) { // (N, H, W, K)
+  // This can be more than H * W if H or W are not divisible by bin_size.
+  const int num_pixels = N * BH * BW * bin_size * bin_size;
   const int num_threads = gridDim.x * blockDim.x;
   const int tid = blockIdx.x * blockDim.x + threadIdx.x;

@@ -434,21 +466,21 @@ __global__ void RasterizePointsFineCudaKernel(
     // into the same bin; this should give them coalesced memory reads when
     // they read from points and bin_points.
     int i = pid;
-    const int n = i / (B * B * bin_size * bin_size);
-    i %= B * B * bin_size * bin_size;
-    const int by = i / (B * bin_size * bin_size);
-    i %= B * bin_size * bin_size;
+    const int n = i / (BH * BW * bin_size * bin_size);
+    i %= BH * BW * bin_size * bin_size;
+    const int by = i / (BW * bin_size * bin_size);
+    i %= BW * bin_size * bin_size;
     const int bx = i / (bin_size * bin_size);
     i %= bin_size * bin_size;

     const int yi = i / bin_size + by * bin_size;
     const int xi = i % bin_size + bx * bin_size;

-    if (yi >= S || xi >= S)
+    if (yi >= H || xi >= W)
       continue;

-    const float xf = PixToNdc(xi, S);
-    const float yf = PixToNdc(yi, S);
+    const float xf = PixToNonSquareNdc(xi, W, H);
+    const float yf = PixToNonSquareNdc(yi, H, W);

     // This part looks like the naive rasterization kernel, except we use
     // bin_points to only look at a subset of points already known to fall

@@ -459,7 +491,7 @@ __global__ void RasterizePointsFineCudaKernel(
     float q_max_z = -1000;
     int q_max_idx = -1;
     for (int m = 0; m < M; ++m) {
-      const int p = bin_points[n * B * B * M + by * B * M + bx * M + m];
+      const int p = bin_points[n * BH * BW * M + by * BW * M + bx * M + m];
       if (p < 0) {
         // bin_points uses -1 as a sentinal value
         continue;

@@ -473,10 +505,10 @@ __global__ void RasterizePointsFineCudaKernel(

     // Reverse ordering of the X and Y axis as the camera coordinates
     // assume that +Y is pointing up and +X is pointing left.
-    const int yidx = S - 1 - yi;
-    const int xidx = S - 1 - xi;
+    const int yidx = H - 1 - yi;
+    const int xidx = W - 1 - xi;

-    const int pix_idx = n * S * S * K + yidx * S * K + xidx * K;
+    const int pix_idx = n * H * W * K + yidx * W * K + xidx * K;
     for (int k = 0; k < q_size; ++k) {
       point_idxs[pix_idx + k] = q[k].idx;
       zbuf[pix_idx + k] = q[k].z;

@@ -488,7 +520,7 @@ __global__ void RasterizePointsFineCudaKernel(
 std::tuple<at::Tensor, at::Tensor, at::Tensor> RasterizePointsFineCuda(
     const at::Tensor& points, // (P, 3)
     const at::Tensor& bin_points,
-    const int image_size,
+    const std::tuple<int, int> image_size,
     const at::Tensor& radius,
     const int bin_size,
     const int points_per_pixel) {

@@ -503,18 +535,22 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> RasterizePointsFineCuda(
   cudaStream_t stream = at::cuda::getCurrentCUDAStream();

   const int N = bin_points.size(0);
-  const int B = bin_points.size(1); // num_bins
+  const int BH = bin_points.size(1);
+  const int BW = bin_points.size(2);
   const int M = bin_points.size(3);
-  const int S = image_size;
   const int K = points_per_pixel;
+
+  const int H = std::get<0>(image_size);
+  const int W = std::get<1>(image_size);
+
   if (K > kMaxPointsPerPixel) {
     AT_ERROR("Must have num_closest <= 150");
   }
   auto int_opts = bin_points.options().dtype(at::kInt);
   auto float_opts = points.options().dtype(at::kFloat);
-  at::Tensor point_idxs = at::full({N, S, S, K}, -1, int_opts);
-  at::Tensor zbuf = at::full({N, S, S, K}, -1, float_opts);
-  at::Tensor pix_dists = at::full({N, S, S, K}, -1, float_opts);
+  at::Tensor point_idxs = at::full({N, H, W, K}, -1, int_opts);
+  at::Tensor zbuf = at::full({N, H, W, K}, -1, float_opts);
+  at::Tensor pix_dists = at::full({N, H, W, K}, -1, float_opts);

   if (point_idxs.numel() == 0) {
     AT_CUDA_CHECK(cudaGetLastError());

@@ -529,9 +565,11 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> RasterizePointsFineCuda(
       radius.contiguous().data_ptr<float>(),
       bin_size,
       N,
-      B,
+      BH,
+      BW,
      M,
-      S,
+      H,
+      W,
       K,
       point_idxs.contiguous().data_ptr<int32_t>(),
       zbuf.contiguous().data_ptr<float>(),

@@ -571,8 +609,8 @@ __global__ void RasterizePointsBackwardCudaKernel(
     const int yidx = H - 1 - yi;
     const int xidx = W - 1 - xi;

-    const float xf = PixToNdc(xidx, W);
-    const float yf = PixToNdc(yidx, H);
+    const float xf = PixToNonSquareNdc(xidx, W, H);
+    const float yf = PixToNonSquareNdc(yidx, H, W);

     const int p = idxs[i];
     if (p < 0)
@@ -14,7 +14,7 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsNaiveCpu(
     const torch::Tensor& points,
     const torch::Tensor& cloud_to_packed_first_idx,
     const torch::Tensor& num_points_per_cloud,
-    const int image_size,
+    const std::tuple<int, int> image_size,
     const torch::Tensor& radius,
     const int points_per_pixel);

@@ -24,7 +24,7 @@ RasterizePointsNaiveCuda(
     const torch::Tensor& points,
     const torch::Tensor& cloud_to_packed_first_idx,
     const torch::Tensor& num_points_per_cloud,
-    const int image_size,
+    const std::tuple<int, int> image_size,
     const torch::Tensor& radius,
     const int points_per_pixel);
 #endif

@@ -43,7 +43,8 @@ RasterizePointsNaiveCuda(
 //    for each pointcloud in the batch.
 //  radius: FloatTensor of shape (P) giving the radius (in NDC units) of
 //    each point in points.
-//  image_size: (S) Size of the image to return (in pixels)
+//  image_size: Tuple (H, W) giving the size in pixels of the output
+//    image to be rasterized.
 //  points_per_pixel: (K) The number closest of points to return for each pixel
 //
 // Returns:

@@ -62,7 +63,7 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsNaive(
     const torch::Tensor& points,
     const torch::Tensor& cloud_to_packed_first_idx,
     const torch::Tensor& num_points_per_cloud,
-    const int image_size,
+    const std::tuple<int, int> image_size,
     const torch::Tensor& radius,
     const int points_per_pixel) {
   if (points.is_cuda() && cloud_to_packed_first_idx.is_cuda() &&

@@ -101,7 +102,7 @@ torch::Tensor RasterizePointsCoarseCpu(
     const torch::Tensor& points,
     const torch::Tensor& cloud_to_packed_first_idx,
     const torch::Tensor& num_points_per_cloud,
-    const int image_size,
+    const std::tuple<int, int> image_size,
     const torch::Tensor& radius,
     const int bin_size,
     const int max_points_per_bin);

@@ -111,7 +112,7 @@ torch::Tensor RasterizePointsCoarseCuda(
     const torch::Tensor& points,
     const torch::Tensor& cloud_to_packed_first_idx,
     const torch::Tensor& num_points_per_cloud,
-    const int image_size,
+    const std::tuple<int, int> image_size,
     const torch::Tensor& radius,
     const int bin_size,
     const int max_points_per_bin);

@@ -128,7 +129,8 @@ torch::Tensor RasterizePointsCoarseCuda(
 //    for each pointcloud in the batch.
 //  radius: FloatTensor of shape (P) giving the radius (in NDC units) of
 //    each point in points.
-//  image_size: Size of the image to generate (in pixels)
+//  image_size: Tuple (H, W) giving the size in pixels of the output
+//    image to be rasterized.
 //  bin_size: Size of each bin within the image (in pixels)
 //
 // Returns:

@@ -140,7 +142,7 @@ torch::Tensor RasterizePointsCoarse(
     const torch::Tensor& points,
     const torch::Tensor& cloud_to_packed_first_idx,
     const torch::Tensor& num_points_per_cloud,
-    const int image_size,
+    const std::tuple<int, int> image_size,
     const torch::Tensor& radius,
     const int bin_size,
     const int max_points_per_bin) {

@@ -182,7 +184,7 @@ torch::Tensor RasterizePointsCoarse(
 std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsFineCuda(
     const torch::Tensor& points,
     const torch::Tensor& bin_points,
-    const int image_size,
+    const std::tuple<int, int> image_size,
     const torch::Tensor& radius,
     const int bin_size,
     const int points_per_pixel);

@@ -194,7 +196,8 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsFineCuda(
 //    are expected to be in NDC coordinates in the range [-1, 1].
 //  bin_points: int32 Tensor of shape (N, B, B, M) giving the indices of points
 //    that fall into each bin (output from coarse rasterization)
-//  image_size: Size of image to generate (in pixels)
+//  image_size: Tuple (H, W) giving the size in pixels of the output
+//    image to be rasterized.
 //  radius: FloatTensor of shape (P) giving the radius (in NDC units) of
 //    each point in points.
 //  bin_size: Size of each bin (in pixels)

@@ -214,7 +217,7 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsFineCuda(
 std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsFine(
     const torch::Tensor& points,
     const torch::Tensor& bin_points,
-    const int image_size,
+    const std::tuple<int, int> image_size,
     const torch::Tensor& radius,
     const int bin_size,
     const int points_per_pixel) {

@@ -303,7 +306,8 @@ torch::Tensor RasterizePointsBackward(
 //    for each pointcloud in the batch.
 //  radius: FloatTensor of shape (P) giving the radius (in NDC units) of
 //    each point in points.
-//  image_size: (S) Size of the image to return (in pixels)
+//  image_size: Tuple (H, W) giving the size in pixels of the output
+//    image to be rasterized.
 //  points_per_pixel: (K) The number of points to return for each pixel
 //  bin_size: Bin size (in pixels) for coarse-to-fine rasterization. Setting
 //    bin_size=0 uses naive rasterization instead.

@@ -325,7 +329,7 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePoints(
     const torch::Tensor& points,
     const torch::Tensor& cloud_to_packed_first_idx,
     const torch::Tensor& num_points_per_cloud,
-    const int image_size,
+    const std::tuple<int, int> image_size,
     const torch::Tensor& radius,
     const int points_per_pixel,
     const int bin_size,
@@ -3,33 +3,27 @@
 #include <torch/extension.h>
 #include <queue>
 #include <tuple>
-
-// Given a pixel coordinate 0 <= i < S, convert it to a normalized device
-// coordinate in the range [-1, 1]. The NDC range is divided into S evenly-sized
-// pixels, and assume that each pixel falls in the *center* of its range.
-static float PixToNdc(const int i, const int S) {
-  // NDC x-offset + (i * pixel_width + half_pixel_width)
-  return -1 + (2 * i + 1.0f) / S;
-}
+#include "rasterization_utils.h"

 std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsNaiveCpu(
     const torch::Tensor& points, // (P, 3)
     const torch::Tensor& cloud_to_packed_first_idx, // (N)
     const torch::Tensor& num_points_per_cloud, // (N)
-    const int image_size,
+    const std::tuple<int, int> image_size,
     const torch::Tensor& radius,
     const int points_per_pixel) {
   const int32_t N = cloud_to_packed_first_idx.size(0); // batch_size.

-  const int S = image_size;
+  const int H = std::get<0>(image_size);
+  const int W = std::get<1>(image_size);
   const int K = points_per_pixel;

   // Initialize output tensors.
   auto int_opts = num_points_per_cloud.options().dtype(torch::kInt32);
   auto float_opts = points.options().dtype(torch::kFloat32);
-  torch::Tensor point_idxs = torch::full({N, S, S, K}, -1, int_opts);
-  torch::Tensor zbuf = torch::full({N, S, S, K}, -1, float_opts);
-  torch::Tensor pix_dists = torch::full({N, S, S, K}, -1, float_opts);
+  torch::Tensor point_idxs = torch::full({N, H, W, K}, -1, int_opts);
+  torch::Tensor zbuf = torch::full({N, H, W, K}, -1, float_opts);
+  torch::Tensor pix_dists = torch::full({N, H, W, K}, -1, float_opts);

   auto points_a = points.accessor<float, 2>();
   auto point_idxs_a = point_idxs.accessor<int32_t, 4>();

@@ -46,16 +40,16 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsNaiveCpu(
     const int point_stop_idx =
         (point_start_idx + num_points_per_cloud[n].item().to<int32_t>());

-    for (int yi = 0; yi < S; ++yi) {
+    for (int yi = 0; yi < H; ++yi) {
       // Reverse the order of yi so that +Y is pointing upwards in the image.
-      const int yidx = S - 1 - yi;
-      const float yf = PixToNdc(yidx, S);
+      const int yidx = H - 1 - yi;
+      const float yf = PixToNonSquareNdc(yidx, H, W);

-      for (int xi = 0; xi < S; ++xi) {
+      for (int xi = 0; xi < W; ++xi) {
         // Reverse the order of xi so that +X is pointing to the left in the
         // image.
-        const int xidx = S - 1 - xi;
-        const float xf = PixToNdc(xidx, S);
+        const int xidx = W - 1 - xi;
+        const float xf = PixToNonSquareNdc(xidx, W, H);

         // Use a priority queue to hold (z, idx, r)
         std::priority_queue<std::tuple<float, int, float>> q;

@@ -99,25 +93,36 @@ torch::Tensor RasterizePointsCoarseCpu(
     const torch::Tensor& points, // (P, 3)
     const torch::Tensor& cloud_to_packed_first_idx, // (N)
     const torch::Tensor& num_points_per_cloud, // (N)
-    const int image_size,
+    const std::tuple<int, int> image_size,
     const torch::Tensor& radius,
     const int bin_size,
     const int max_points_per_bin) {
   const int32_t N = cloud_to_packed_first_idx.size(0); // batch_size.
-
-  const int B = 1 + (image_size - 1) / bin_size; // Integer division round up
   const int M = max_points_per_bin;
+
+  const float H = std::get<0>(image_size);
+  const float W = std::get<1>(image_size);
+
+  // Integer division round up.
+  const int BH = 1 + (H - 1) / bin_size;
+  const int BW = 1 + (W - 1) / bin_size;
+
   auto opts = num_points_per_cloud.options().dtype(torch::kInt32);
-  torch::Tensor points_per_bin = torch::zeros({N, B, B}, opts);
-  torch::Tensor bin_points = torch::full({N, B, B, M}, -1, opts);
+  torch::Tensor points_per_bin = torch::zeros({N, BH, BW}, opts);
+  torch::Tensor bin_points = torch::full({N, BH, BW, M}, -1, opts);

   auto points_a = points.accessor<float, 2>();
   auto points_per_bin_a = points_per_bin.accessor<int32_t, 3>();
   auto bin_points_a = bin_points.accessor<int32_t, 4>();
   auto radius_a = radius.accessor<float, 1>();

-  const float pixel_width = 2.0f / image_size;
-  const float bin_width = pixel_width * bin_size;
+  const float ndc_x_range = NonSquareNdcRange(W, H);
+  const float pixel_width_x = ndc_x_range / W;
+  const float bin_width_x = pixel_width_x * bin_size;
+
+  const float ndc_y_range = NonSquareNdcRange(H, W);
+  const float pixel_width_y = ndc_y_range / H;
+  const float bin_width_y = pixel_width_y * bin_size;

   for (int n = 0; n < N; ++n) {
     // Loop through each pointcloud in the batch.
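A sketch (not part of the commit) of the per-axis NDC pixel and bin widths computed above; because the longer side's NDC range is scaled by the aspect ratio, pixels keep the same NDC width in x and y:

#include <cstdio>
#include "rasterize_points/rasterization_utils.h" // assumed include path

int main() {
  const int H = 64, W = 128, bin_size = 32;
  const float pixel_width_x = NonSquareNdcRange(W, H) / W; // 4.0 / 128 = 0.03125
  const float pixel_width_y = NonSquareNdcRange(H, W) / H; // 2.0 /  64 = 0.03125
  std::printf("pixel width  x: %f  y: %f\n", pixel_width_x, pixel_width_y);
  std::printf("bin width    x: %f  y: %f\n",
              pixel_width_x * bin_size, pixel_width_y * bin_size);
  return 0;
}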
@@ -129,15 +134,15 @@ torch::Tensor RasterizePointsCoarseCpu(
         (point_start_idx + num_points_per_cloud[n].item().to<int32_t>());

     float bin_y_min = -1.0f;
-    float bin_y_max = bin_y_min + bin_width;
+    float bin_y_max = bin_y_min + bin_width_y;

     // Iterate through the horizontal bins from top to bottom.
-    for (int by = 0; by < B; by++) {
+    for (int by = 0; by < BH; by++) {
       float bin_x_min = -1.0f;
-      float bin_x_max = bin_x_min + bin_width;
+      float bin_x_max = bin_x_min + bin_width_x;

       // Iterate through bins on this horizontal line, left to right.
-      for (int bx = 0; bx < B; bx++) {
+      for (int bx = 0; bx < BW; bx++) {
         int32_t points_hit = 0;
         for (int p = point_start_idx; p < point_stop_idx; ++p) {
           float px = points_a[p][0];

@@ -172,11 +177,11 @@ torch::Tensor RasterizePointsCoarseCpu(

         // Shift the bin to the right for the next loop iteration
         bin_x_min = bin_x_max;
-        bin_x_max = bin_x_min + bin_width;
+        bin_x_max = bin_x_min + bin_width_x;
       }
       // Shift the bin down for the next loop iteration
       bin_y_min = bin_y_max;
-      bin_y_max = bin_y_min + bin_width;
+      bin_y_max = bin_y_min + bin_width_y;
     }
   }
   return bin_points;

@@ -194,11 +199,6 @@ torch::Tensor RasterizePointsBackwardCpu(
   const int W = idxs.size(2);
   const int K = idxs.size(3);

-  // For now only support square images.
-  // TODO(jcjohns): Extend to non-square images.
-  if (H != W) {
-    AT_ERROR("RasterizePointsBackwardCpu only supports square images");
-  }
   torch::Tensor grad_points = torch::zeros({P, 3}, points.options());

   auto points_a = points.accessor<float, 2>();

@@ -212,7 +212,7 @@ torch::Tensor RasterizePointsBackwardCpu(
       // Reverse the order of yi so that +Y is pointing upwards in the image.
       const int yidx = H - 1 - y;
       // Y coordinate of the top of the pixel.
-      const float yf = PixToNdc(yidx, H);
+      const float yf = PixToNonSquareNdc(yidx, H, W);

       // Iterate through pixels on this horizontal line, left to right.
       for (int x = 0; x < W; ++x) { // Loop over pixels in the row

@@ -220,7 +220,7 @@ torch::Tensor RasterizePointsBackwardCpu(
         // Reverse the order of xi so that +X is pointing to the left in the
         // image.
         const int xidx = W - 1 - x;
-        const float xf = PixToNdc(xidx, W);
+        const float xf = PixToNonSquareNdc(xidx, W, H);
         for (int k = 0; k < K; ++k) { // Loop over points for the pixel
           const int p = idxs_a[n][y][x][k];
           if (p < 0) {