Non Square image rasterization for pointclouds

Summary:
Similar to non square image rasterization for meshes, apply the same updates to the pointcloud rasterizer.

Main API Change:
- PointRasterizationSettings now accepts a tuple/list of (H, W) for the image size.

Reviewed By: jcjohnson

Differential Revision: D25465206

fbshipit-source-id: 7370d83c431af1b972158cecae19d82364623380
This commit is contained in:
Nikhila Ravi 2020-12-15 14:14:27 -08:00 committed by Facebook GitHub Bot
parent 569e5229a9
commit 3d769a66cb
22 changed files with 712 additions and 263 deletions

View File

@ -30,15 +30,15 @@ __global__ void alphaCompositeCudaForwardKernel(
// Get the batch and index // Get the batch and index
const int batch = blockIdx.x; const int batch = blockIdx.x;
const int num_pixels = C * W * H; const int num_pixels = C * H * W;
const int num_threads = gridDim.y * blockDim.x; const int num_threads = gridDim.y * blockDim.x;
const int tid = blockIdx.y * blockDim.x + threadIdx.x; const int tid = blockIdx.y * blockDim.x + threadIdx.x;
// Iterate over each feature in each pixel // Iterate over each feature in each pixel
for (int pid = tid; pid < num_pixels; pid += num_threads) { for (int pid = tid; pid < num_pixels; pid += num_threads) {
int ch = pid / (W * H); int ch = pid / (H * W);
int j = (pid % (W * H)) / H; int j = (pid % (H * W)) / W;
int i = (pid % (W * H)) % H; int i = (pid % (H * W)) % W;
// alphacomposite the different values // alphacomposite the different values
float cum_alpha = 1.; float cum_alpha = 1.;
@ -81,16 +81,16 @@ __global__ void alphaCompositeCudaBackwardKernel(
// Get the batch and index // Get the batch and index
const int batch = blockIdx.x; const int batch = blockIdx.x;
const int num_pixels = C * W * H; const int num_pixels = C * H * W;
const int num_threads = gridDim.y * blockDim.x; const int num_threads = gridDim.y * blockDim.x;
const int tid = blockIdx.y * blockDim.x + threadIdx.x; const int tid = blockIdx.y * blockDim.x + threadIdx.x;
// Parallelize over each feature in each pixel in images of size H * W, // Parallelize over each feature in each pixel in images of size H * W,
// for each image in the batch of size batch_size // for each image in the batch of size batch_size
for (int pid = tid; pid < num_pixels; pid += num_threads) { for (int pid = tid; pid < num_pixels; pid += num_threads) {
int ch = pid / (W * H); int ch = pid / (H * W);
int j = (pid % (W * H)) / H; int j = (pid % (H * W)) / W;
int i = (pid % (W * H)) % H; int i = (pid % (H * W)) % W;
// alphacomposite the different values // alphacomposite the different values
float cum_alpha = 1.; float cum_alpha = 1.;

View File

@ -11,13 +11,13 @@
// features: FloatTensor of shape (C, P) which gives the features // features: FloatTensor of shape (C, P) which gives the features
// of each point where C is the size of the feature and // of each point where C is the size of the feature and
// P the number of points. // P the number of points.
// alphas: FloatTensor of shape (N, points_per_pixel, W, W) where // alphas: FloatTensor of shape (N, points_per_pixel, H, W) where
// points_per_pixel is the number of points in the z-buffer // points_per_pixel is the number of points in the z-buffer
// sorted in z-order, and W is the image size. // sorted in z-order, and (H, W) is the image size.
// points_idx: IntTensor of shape (N, points_per_pixel, W, W) giving the // points_idx: IntTensor of shape (N, points_per_pixel, H, W) giving the
// indices of the nearest points at each pixel, sorted in z-order. // indices of the nearest points at each pixel, sorted in z-order.
// Returns: // Returns:
// weighted_fs: FloatTensor of shape (N, C, W, W) giving the accumulated // weighted_fs: FloatTensor of shape (N, C, H, W) giving the accumulated
// feature for each point. Concretely, it gives: // feature for each point. Concretely, it gives:
// weighted_fs[b,c,i,j] = sum_k cum_alpha_k * // weighted_fs[b,c,i,j] = sum_k cum_alpha_k *
// features[c,points_idx[b,k,i,j]] // features[c,points_idx[b,k,i,j]]

View File

@ -30,16 +30,16 @@ __global__ void weightedSumNormCudaForwardKernel(
// Get the batch and index // Get the batch and index
const int batch = blockIdx.x; const int batch = blockIdx.x;
const int num_pixels = C * W * H; const int num_pixels = C * H * W;
const int num_threads = gridDim.y * blockDim.x; const int num_threads = gridDim.y * blockDim.x;
const int tid = blockIdx.y * blockDim.x + threadIdx.x; const int tid = blockIdx.y * blockDim.x + threadIdx.x;
// Parallelize over each feature in each pixel in images of size H * W, // Parallelize over each feature in each pixel in images of size H * W,
// for each image in the batch of size batch_size // for each image in the batch of size batch_size
for (int pid = tid; pid < num_pixels; pid += num_threads) { for (int pid = tid; pid < num_pixels; pid += num_threads) {
int ch = pid / (W * H); int ch = pid / (H * W);
int j = (pid % (W * H)) / H; int j = (pid % (H * W)) / W;
int i = (pid % (W * H)) % H; int i = (pid % (H * W)) % W;
// Store the accumulated alpha value // Store the accumulated alpha value
float cum_alpha = 0.; float cum_alpha = 0.;
@ -101,9 +101,9 @@ __global__ void weightedSumNormCudaBackwardKernel(
// Parallelize over each feature in each pixel in images of size H * W, // Parallelize over each feature in each pixel in images of size H * W,
// for each image in the batch of size batch_size // for each image in the batch of size batch_size
for (int pid = tid; pid < num_pixels; pid += num_threads) { for (int pid = tid; pid < num_pixels; pid += num_threads) {
int ch = pid / (W * H); int ch = pid / (H * W);
int j = (pid % (W * H)) / H; int j = (pid % (H * W)) / W;
int i = (pid % (W * H)) % H; int i = (pid % (H * W)) % W;
float sum_alpha = 0.; float sum_alpha = 0.;
float sum_alphafs = 0.; float sum_alphafs = 0.;

View File

@ -11,13 +11,13 @@
// features: FloatTensor of shape (C, P) which gives the features // features: FloatTensor of shape (C, P) which gives the features
// of each point where C is the size of the feature and // of each point where C is the size of the feature and
// P the number of points. // P the number of points.
// alphas: FloatTensor of shape (N, points_per_pixel, W, W) where // alphas: FloatTensor of shape (N, points_per_pixel, H, W) where
// points_per_pixel is the number of points in the z-buffer // points_per_pixel is the number of points in the z-buffer
// sorted in z-order, and W is the image size. // sorted in z-order, and (H, W) is the image size.
// points_idx: IntTensor of shape (N, points_per_pixel, W, W) giving the // points_idx: IntTensor of shape (N, points_per_pixel, H, W) giving the
// indices of the nearest points at each pixel, sorted in z-order. // indices of the nearest points at each pixel, sorted in z-order.
// Returns: // Returns:
// weighted_fs: FloatTensor of shape (N, C, W, W) giving the accumulated // weighted_fs: FloatTensor of shape (N, C, H, W) giving the accumulated
// feature in each point. Concretely, it gives: // feature in each point. Concretely, it gives:
// weighted_fs[b,c,i,j] = sum_k alphas[b,k,i,j] * // weighted_fs[b,c,i,j] = sum_k alphas[b,k,i,j] *
// features[c,points_idx[b,k,i,j]] / sum_k alphas[b,k,i,j] // features[c,points_idx[b,k,i,j]] / sum_k alphas[b,k,i,j]

View File

@ -28,16 +28,16 @@ __global__ void weightedSumCudaForwardKernel(
// Get the batch and index // Get the batch and index
const int batch = blockIdx.x; const int batch = blockIdx.x;
const int num_pixels = C * W * H; const int num_pixels = C * H * W;
const int num_threads = gridDim.y * blockDim.x; const int num_threads = gridDim.y * blockDim.x;
const int tid = blockIdx.y * blockDim.x + threadIdx.x; const int tid = blockIdx.y * blockDim.x + threadIdx.x;
// Parallelize over each feature in each pixel in images of size H * W, // Parallelize over each feature in each pixel in images of size H * W,
// for each image in the batch of size batch_size // for each image in the batch of size batch_size
for (int pid = tid; pid < num_pixels; pid += num_threads) { for (int pid = tid; pid < num_pixels; pid += num_threads) {
int ch = pid / (W * H); int ch = pid / (H * W);
int j = (pid % (W * H)) / H; int j = (pid % (H * W)) / W;
int i = (pid % (W * H)) % H; int i = (pid % (H * W)) % W;
// Iterate through the closest K points for this pixel // Iterate through the closest K points for this pixel
for (int k = 0; k < points_idx.size(1); ++k) { for (int k = 0; k < points_idx.size(1); ++k) {
@ -76,16 +76,16 @@ __global__ void weightedSumCudaBackwardKernel(
// Get the batch and index // Get the batch and index
const int batch = blockIdx.x; const int batch = blockIdx.x;
const int num_pixels = C * W * H; const int num_pixels = C * H * W;
const int num_threads = gridDim.y * blockDim.x; const int num_threads = gridDim.y * blockDim.x;
const int tid = blockIdx.y * blockDim.x + threadIdx.x; const int tid = blockIdx.y * blockDim.x + threadIdx.x;
// Iterate over each pixel to compute the contribution to the // Iterate over each pixel to compute the contribution to the
// gradient for the features and weights // gradient for the features and weights
for (int pid = tid; pid < num_pixels; pid += num_threads) { for (int pid = tid; pid < num_pixels; pid += num_threads) {
int ch = pid / (W * H); int ch = pid / (H * W);
int j = (pid % (W * H)) / H; int j = (pid % (H * W)) / W;
int i = (pid % (W * H)) % H; int i = (pid % (H * W)) % W;
// Iterate through the closest K points for this pixel // Iterate through the closest K points for this pixel
for (int k = 0; k < points_idx.size(1); ++k) { for (int k = 0; k < points_idx.size(1); ++k) {

View File

@ -11,13 +11,13 @@
// features: FloatTensor of shape (C, P) which gives the features // features: FloatTensor of shape (C, P) which gives the features
// of each point where C is the size of the feature and // of each point where C is the size of the feature and
// P the number of points. // P the number of points.
// alphas: FloatTensor of shape (N, points_per_pixel, W, W) where // alphas: FloatTensor of shape (N, points_per_pixel, H, W) where
// points_per_pixel is the number of points in the z-buffer // points_per_pixel is the number of points in the z-buffer
// sorted in z-order, and W is the image size. // sorted in z-order, and (H, W) is the image size.
// points_idx: IntTensor of shape (N, points_per_pixel, W, W) giving the // points_idx: IntTensor of shape (N, points_per_pixel, H, W) giving the
// indices of the nearest points at each pixel, sorted in z-order. // indices of the nearest points at each pixel, sorted in z-order.
// Returns: // Returns:
// weighted_fs: FloatTensor of shape (N, C, W, W) giving the accumulated // weighted_fs: FloatTensor of shape (N, C, H, W) giving the accumulated
// feature in each point. Concretely, it gives: // feature in each point. Concretely, it gives:
// weighted_fs[b,c,i,j] = sum_k alphas[b,k,i,j] * // weighted_fs[b,c,i,j] = sum_k alphas[b,k,i,j] *
// features[c,points_idx[b,k,i,j]] // features[c,points_idx[b,k,i,j]]

View File

@ -452,7 +452,6 @@ __global__ void RasterizeMeshesBackwardCudaKernel(
const bool inside = b_pp.x > 0.0f && b_pp.y > 0.0f && b_pp.z > 0.0f; const bool inside = b_pp.x > 0.0f && b_pp.y > 0.0f && b_pp.z > 0.0f;
const float sign = inside ? -1.0f : 1.0f; const float sign = inside ? -1.0f : 1.0f;
// TODO(T52813608) Add support for non-square images.
auto grad_dist_f = PointTriangleDistanceBackward( auto grad_dist_f = PointTriangleDistanceBackward(
pxy, v0xy, v1xy, v2xy, sign * grad_dist_upstream); pxy, v0xy, v1xy, v2xy, sign * grad_dist_upstream);
const float2 ddist_d_v0 = thrust::get<1>(grad_dist_f); const float2 ddist_d_v0 = thrust::get<1>(grad_dist_f);
@ -606,7 +605,7 @@ __global__ void RasterizeMeshesCoarseCudaKernel(
const float half_pix_x = NDC_x_half_range / W; const float half_pix_x = NDC_x_half_range / W;
const float half_pix_y = NDC_y_half_range / H; const float half_pix_y = NDC_y_half_range / H;
// This is a boolean array of shape (num_bins, num_bins, chunk_size) // This is a boolean array of shape (num_bins_y, num_bins_x, chunk_size)
// stored in shared memory that will track whether each point in the chunk // stored in shared memory that will track whether each point in the chunk
// falls into each bin of the image. // falls into each bin of the image.
BitMask binmask((unsigned int*)sbuf, num_bins_y, num_bins_x, chunk_size); BitMask binmask((unsigned int*)sbuf, num_bins_y, num_bins_x, chunk_size);
@ -755,7 +754,7 @@ at::Tensor RasterizeMeshesCoarseCuda(
const int num_bins_y = 1 + (H - 1) / bin_size; const int num_bins_y = 1 + (H - 1) / bin_size;
const int num_bins_x = 1 + (W - 1) / bin_size; const int num_bins_x = 1 + (W - 1) / bin_size;
if (num_bins_y >= kMaxFacesPerBin || num_bins_x >= kMaxFacesPerBin) { if (num_bins_y >= kMaxItemsPerBin || num_bins_x >= kMaxItemsPerBin) {
std::stringstream ss; std::stringstream ss;
ss << "In Coarse Rasterizer got num_bins_y: " << num_bins_y ss << "In Coarse Rasterizer got num_bins_y: " << num_bins_y
<< ", num_bins_x: " << num_bins_x << ", " << ", num_bins_x: " << num_bins_x << ", "
@ -800,7 +799,7 @@ at::Tensor RasterizeMeshesCoarseCuda(
// **************************************************************************** // ****************************************************************************
__global__ void RasterizeMeshesFineCudaKernel( __global__ void RasterizeMeshesFineCudaKernel(
const float* face_verts, // (F, 3, 3) const float* face_verts, // (F, 3, 3)
const int32_t* bin_faces, // (N, B, B, T) const int32_t* bin_faces, // (N, BH, BW, T)
const float blur_radius, const float blur_radius,
const int bin_size, const int bin_size,
const bool perspective_correct, const bool perspective_correct,
@ -813,12 +812,12 @@ __global__ void RasterizeMeshesFineCudaKernel(
const int H, const int H,
const int W, const int W,
const int K, const int K,
int64_t* face_idxs, // (N, S, S, K) int64_t* face_idxs, // (N, H, W, K)
float* zbuf, // (N, S, S, K) float* zbuf, // (N, H, W, K)
float* pix_dists, // (N, S, S, K) float* pix_dists, // (N, H, W, K)
float* bary // (N, S, S, K, 3) float* bary // (N, H, W, K, 3)
) { ) {
// This can be more than S^2 if S % bin_size != 0 // This can be more than H * W if H or W are not divisible by bin_size.
int num_pixels = N * BH * BW * bin_size * bin_size; int num_pixels = N * BH * BW * bin_size * bin_size;
int num_threads = gridDim.x * blockDim.x; int num_threads = gridDim.x * blockDim.x;
int tid = blockIdx.x * blockDim.x + threadIdx.x; int tid = blockIdx.x * blockDim.x + threadIdx.x;

View File

@ -5,41 +5,11 @@
#include <list> #include <list>
#include <queue> #include <queue>
#include <tuple> #include <tuple>
#include "rasterize_points/rasterization_utils.h"
#include "utils/geometry_utils.h" #include "utils/geometry_utils.h"
#include "utils/vec2.h" #include "utils/vec2.h"
#include "utils/vec3.h" #include "utils/vec3.h"
// Returns the full width of the NDC range along the image dimension of size
// S1, where S2 is the size of the other image dimension.
// The default NDC range is [-1, 1] (width 2.0). When H != W the shorter side
// keeps that range and the longer side is scaled by the aspect ratio S1 / S2.
// e.g. to get the NDC x range use S1 = W and S2 = H.
//
// Args:
//   S1: size of the dimension whose NDC range is requested.
//   S2: size of the other image dimension.
// Returns:
//   2.0 when S1 <= S2, otherwise 2.0 * S1 / S2 evaluated in floating point.
float NonSquareNdcRange(int S1, int S2) {
  float range = 2.0f;
  if (S1 > S2) {
    // Multiply before dividing so the arithmetic is done in float. The integer
    // quotient S1 / S2 would truncate the aspect ratio (e.g. 3 / 2 == 1) and
    // leave the range of the longer side unscaled.
    range = (S1 * range) / S2;
  }
  return range;
}
// Converts a pixel index 0 <= i < S1 into a normalized device coordinate.
// The NDC range returned by NonSquareNdcRange(S1, S2) is split into S1
// equally sized pixels and the coordinate of the *center* of pixel i is
// returned. The range is symmetric about zero: [-range / 2, range / 2].
// S1 is the size of the dimension that i indexes and S2 is the size of the
// other image dimension. For example, to get the x and y NDC coordinates of
// a given pixel i:
//   x = PixToNonSquareNdc(i, W, H)
//   y = PixToNonSquareNdc(i, H, W)
float PixToNonSquareNdc(int i, int S1, int S2) {
  const float ndc_range = NonSquareNdcRange(S1, S2);
  const float half_range = ndc_range / 2.0f;
  // Distance of the pixel center from the lower edge of the range:
  // i whole pixels plus half a pixel, each pixel being ndc_range / S1 wide.
  const float from_lower_edge = (ndc_range * i + half_range) / S1;
  // Shift so that the range is centered on zero.
  return from_lower_edge - half_range;
}
// Get (x, y, z) values for vertex from (3, 3) tensor face. // Get (x, y, z) values for vertex from (3, 3) tensor face.
template <typename Face> template <typename Face>
auto ExtractVerts(const Face& face, const int vertex_index) { auto ExtractVerts(const Face& face, const int vertex_index) {

View File

@ -2,16 +2,6 @@
#pragma once #pragma once
// Converts a pixel index 0 <= i < S into a normalized device coordinate in
// the range [-1, 1]. The range is split into S equally sized pixels and the
// coordinate of the *center* of pixel i is returned.
// The 1.0 literals make the expression evaluate in double precision before
// the result is converted to the float return type.
// TODO: delete this function after updating the pointcloud rasterizer to
// support non square images.
__device__ inline float PixToNdc(int i, int S) {
  // (2 * i + 1) half-pixels from the lower edge; each half-pixel is 1 / S wide.
  const double half_pixel_count = 2 * i + 1.0;
  // NDC: x-offset + (i * pixel_width + half_pixel_width)
  return -1.0 + half_pixel_count / S;
}
// The default value of the NDC range is [-1, 1], however in the case that // The default value of the NDC range is [-1, 1], however in the case that
// H != W, the NDC range is set such that the shorter side has range [-1, 1] and // H != W, the NDC range is set such that the shorter side has range [-1, 1] and
// the longer side is scaled by the ratio of H:W. S1 is the dimension for which // the longer side is scaled by the ratio of H:W. S1 is the dimension for which
@ -50,7 +40,7 @@ __device__ inline float PixToNonSquareNdc(int i, int S1, int S2) {
// TODO: is 8 enough? Would increasing have performance considerations? // TODO: is 8 enough? Would increasing have performance considerations?
const int32_t kMaxPointsPerPixel = 150; const int32_t kMaxPointsPerPixel = 150;
const int32_t kMaxFacesPerBin = 22; const int32_t kMaxItemsPerBin = 22;
template <typename T> template <typename T>
__device__ inline void BubbleSort(T* arr, int n) { __device__ inline void BubbleSort(T* arr, int n) {

View File

@ -0,0 +1,34 @@
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#pragma once
// Returns the full width of the NDC range along the image dimension of size
// S1, where S2 is the size of the other image dimension.
// The default NDC range is [-1, 1] (width 2.0). When H != W the shorter side
// keeps that range and the longer side is scaled by the aspect ratio S1 / S2.
// e.g. to get the NDC x range use S1 = W and S2 = H.
//
// Args:
//   S1: size of the dimension whose NDC range is requested.
//   S2: size of the other image dimension.
// Returns:
//   2.0 when S1 <= S2, otherwise 2.0 * S1 / S2 evaluated in floating point.
inline float NonSquareNdcRange(int S1, int S2) {
  float range = 2.0f;
  if (S1 > S2) {
    // Multiply before dividing so the arithmetic is done in float. The integer
    // quotient S1 / S2 would truncate the aspect ratio (e.g. 3 / 2 == 1) and
    // leave the range of the longer side unscaled.
    range = (S1 * range) / S2;
  }
  return range;
}
// Converts a pixel index 0 <= i < S1 into a normalized device coordinate.
// The NDC range returned by NonSquareNdcRange(S1, S2) is split into S1
// equally sized pixels and the coordinate of the *center* of pixel i is
// returned. The range is symmetric about zero: [-range / 2, range / 2].
// S1 is the size of the dimension that i indexes and S2 is the size of the
// other image dimension. For example, to get the x and y NDC coordinates of
// a given pixel i:
//   x = PixToNonSquareNdc(i, W, H)
//   y = PixToNonSquareNdc(i, H, W)
inline float PixToNonSquareNdc(int i, int S1, int S2) {
  const float ndc_range = NonSquareNdcRange(S1, S2);
  const float half_range = ndc_range / 2.0f;
  // Distance of the pixel center from the lower edge of the range:
  // i whole pixels plus half a pixel, each pixel being ndc_range / S1 wide.
  const float from_lower_edge = (ndc_range * i + half_range) / S1;
  // Shift so that the range is centered on zero.
  return from_lower_edge - half_range;
}

View File

@ -85,26 +85,28 @@ __global__ void RasterizePointsNaiveCudaKernel(
const int64_t* num_points_per_cloud, // (N) const int64_t* num_points_per_cloud, // (N)
const float* radius, const float* radius,
const int N, const int N,
const int S, const int H,
const int W,
const int K, const int K,
int32_t* point_idxs, // (N, S, S, K) int32_t* point_idxs, // (N, H, W, K)
float* zbuf, // (N, S, S, K) float* zbuf, // (N, H, W, K)
float* pix_dists) { // (N, S, S, K) float* pix_dists) { // (N, H, W, K)
// Simple version: One thread per output pixel // Simple version: One thread per output pixel
const int num_threads = gridDim.x * blockDim.x; const int num_threads = gridDim.x * blockDim.x;
const int tid = blockDim.x * blockIdx.x + threadIdx.x; const int tid = blockDim.x * blockIdx.x + threadIdx.x;
for (int i = tid; i < N * S * S; i += num_threads) { for (int i = tid; i < N * H * W; i += num_threads) {
// Convert linear index to 3D index // Convert linear index to 3D index
const int n = i / (S * S); // Batch index const int n = i / (H * W); // Batch index
const int pix_idx = i % (S * S); const int pix_idx = i % (H * W);
// Reverse ordering of the X and Y axis as the camera coordinates // Reverse ordering of the X and Y axis as the camera coordinates
// assume that +Y is pointing up and +X is pointing left. // assume that +Y is pointing up and +X is pointing left.
const int yi = S - 1 - pix_idx / S; const int yi = H - 1 - pix_idx / W;
const int xi = S - 1 - pix_idx % S; const int xi = W - 1 - pix_idx % W;
const float xf = PixToNdc(xi, S); // screen coordinates to ndc coordinates of pixel.
const float yf = PixToNdc(yi, S); const float xf = PixToNonSquareNdc(xi, W, H);
const float yf = PixToNonSquareNdc(yi, H, W);
// For keeping track of the K closest points we want a data structure // For keeping track of the K closest points we want a data structure
// that (1) gives O(1) access to the closest point for easy comparisons, // that (1) gives O(1) access to the closest point for easy comparisons,
@ -132,7 +134,7 @@ __global__ void RasterizePointsNaiveCudaKernel(
points, p_idx, q_size, q_max_z, q_max_idx, q, radius, xf, yf, K); points, p_idx, q_size, q_max_z, q_max_idx, q, radius, xf, yf, K);
} }
BubbleSort(q, q_size); BubbleSort(q, q_size);
int idx = n * S * S * K + pix_idx * K; int idx = n * H * W * K + pix_idx * K;
for (int k = 0; k < q_size; ++k) { for (int k = 0; k < q_size; ++k) {
point_idxs[idx + k] = q[k].idx; point_idxs[idx + k] = q[k].idx;
zbuf[idx + k] = q[k].z; zbuf[idx + k] = q[k].z;
@ -145,7 +147,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> RasterizePointsNaiveCuda(
const at::Tensor& points, // (P. 3) const at::Tensor& points, // (P. 3)
const at::Tensor& cloud_to_packed_first_idx, // (N) const at::Tensor& cloud_to_packed_first_idx, // (N)
const at::Tensor& num_points_per_cloud, // (N) const at::Tensor& num_points_per_cloud, // (N)
const int image_size, const std::tuple<int, int> image_size,
const at::Tensor& radius, const at::Tensor& radius,
const int points_per_pixel) { const int points_per_pixel) {
// Check inputs are on the same device // Check inputs are on the same device
@ -169,7 +171,8 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> RasterizePointsNaiveCuda(
"num_points_per_cloud must have same size first dimension as cloud_to_packed_first_idx"); "num_points_per_cloud must have same size first dimension as cloud_to_packed_first_idx");
const int N = num_points_per_cloud.size(0); // batch size. const int N = num_points_per_cloud.size(0); // batch size.
const int S = image_size; const int H = std::get<0>(image_size);
const int W = std::get<1>(image_size);
const int K = points_per_pixel; const int K = points_per_pixel;
if (K > kMaxPointsPerPixel) { if (K > kMaxPointsPerPixel) {
@ -180,9 +183,9 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> RasterizePointsNaiveCuda(
auto int_opts = num_points_per_cloud.options().dtype(at::kInt); auto int_opts = num_points_per_cloud.options().dtype(at::kInt);
auto float_opts = points.options().dtype(at::kFloat); auto float_opts = points.options().dtype(at::kFloat);
at::Tensor point_idxs = at::full({N, S, S, K}, -1, int_opts); at::Tensor point_idxs = at::full({N, H, W, K}, -1, int_opts);
at::Tensor zbuf = at::full({N, S, S, K}, -1, float_opts); at::Tensor zbuf = at::full({N, H, W, K}, -1, float_opts);
at::Tensor pix_dists = at::full({N, S, S, K}, -1, float_opts); at::Tensor pix_dists = at::full({N, H, W, K}, -1, float_opts);
if (point_idxs.numel() == 0) { if (point_idxs.numel() == 0) {
AT_CUDA_CHECK(cudaGetLastError()); AT_CUDA_CHECK(cudaGetLastError());
@ -197,7 +200,8 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> RasterizePointsNaiveCuda(
num_points_per_cloud.contiguous().data_ptr<int64_t>(), num_points_per_cloud.contiguous().data_ptr<int64_t>(),
radius.contiguous().data_ptr<float>(), radius.contiguous().data_ptr<float>(),
N, N,
S, H,
W,
K, K,
point_idxs.contiguous().data_ptr<int32_t>(), point_idxs.contiguous().data_ptr<int32_t>(),
zbuf.contiguous().data_ptr<float>(), zbuf.contiguous().data_ptr<float>(),
@ -218,7 +222,8 @@ __global__ void RasterizePointsCoarseCudaKernel(
const float* radius, const float* radius,
const int N, const int N,
const int P, const int P,
const int S, const int H,
const int W,
const int bin_size, const int bin_size,
const int chunk_size, const int chunk_size,
const int max_points_per_bin, const int max_points_per_bin,
@ -226,13 +231,26 @@ __global__ void RasterizePointsCoarseCudaKernel(
int* bin_points) { int* bin_points) {
extern __shared__ char sbuf[]; extern __shared__ char sbuf[];
const int M = max_points_per_bin; const int M = max_points_per_bin;
const int num_bins = 1 + (S - 1) / bin_size; // Integer divide round up
const float half_pix = 1.0f / S; // Size of half a pixel in NDC units
// This is a boolean array of shape (num_bins, num_bins, chunk_size) // Integer divide round up
const int num_bins_x = 1 + (W - 1) / bin_size;
const int num_bins_y = 1 + (H - 1) / bin_size;
// NDC range depends on the ratio of W/H
// The shorter side from (H, W) is given an NDC range of 2.0 and
// the other side is scaled by the ratio of H:W.
const float NDC_x_half_range = NonSquareNdcRange(W, H) / 2.0f;
const float NDC_y_half_range = NonSquareNdcRange(H, W) / 2.0f;
// Size of half a pixel in NDC units is the NDC half range
// divided by the corresponding image dimension
const float half_pix_x = NDC_x_half_range / W;
const float half_pix_y = NDC_y_half_range / H;
// This is a boolean array of shape (num_bins_y, num_bins_x, chunk_size)
// stored in shared memory that will track whether each point in the chunk // stored in shared memory that will track whether each point in the chunk
// falls into each bin of the image. // falls into each bin of the image.
BitMask binmask((unsigned int*)sbuf, num_bins, num_bins, chunk_size); BitMask binmask((unsigned int*)sbuf, num_bins_y, num_bins_x, chunk_size);
// Have each block handle a chunk of points and build a 3D bitmask in // Have each block handle a chunk of points and build a 3D bitmask in
// shared memory to mark which points hit which bins. In this first phase, // shared memory to mark which points hit which bins. In this first phase,
@ -279,22 +297,24 @@ __global__ void RasterizePointsCoarseCudaKernel(
// For example we could compute the exact bin where the point falls, // For example we could compute the exact bin where the point falls,
// then check neighboring bins. This way we wouldn't have to check // then check neighboring bins. This way we wouldn't have to check
// all bins (however then we might have more warp divergence?) // all bins (however then we might have more warp divergence?)
for (int by = 0; by < num_bins; ++by) { for (int by = 0; by < num_bins_y; ++by) {
// Get y extent for the bin. PixToNdc gives us the location of // Get y extent for the bin. PixToNonSquareNdc gives us the location of
// the center of each pixel, so we need to add/subtract a half // the center of each pixel, so we need to add/subtract a half
// pixel to get the true extent of the bin. // pixel to get the true extent of the bin.
const float by0 = PixToNdc(by * bin_size, S) - half_pix; const float by0 = PixToNonSquareNdc(by * bin_size, H, W) - half_pix_y;
const float by1 = PixToNdc((by + 1) * bin_size - 1, S) + half_pix; const float by1 =
PixToNonSquareNdc((by + 1) * bin_size - 1, H, W) + half_pix_y;
const bool y_overlap = (py0 <= by1) && (by0 <= py1); const bool y_overlap = (py0 <= by1) && (by0 <= py1);
if (!y_overlap) { if (!y_overlap) {
continue; continue;
} }
for (int bx = 0; bx < num_bins; ++bx) { for (int bx = 0; bx < num_bins_x; ++bx) {
// Get x extent for the bin; again we need to adjust the // Get x extent for the bin; again we need to adjust the
// output of PixToNdc by half a pixel. // output of PixToNonSquareNdc by half a pixel.
const float bx0 = PixToNdc(bx * bin_size, S) - half_pix; const float bx0 = PixToNonSquareNdc(bx * bin_size, W, H) - half_pix_x;
const float bx1 = PixToNdc((bx + 1) * bin_size - 1, S) + half_pix; const float bx1 =
PixToNonSquareNdc((bx + 1) * bin_size - 1, W, H) + half_pix_x;
const bool x_overlap = (px0 <= bx1) && (bx0 <= px1); const bool x_overlap = (px0 <= bx1) && (bx0 <= px1);
if (x_overlap) { if (x_overlap) {
@ -307,12 +327,13 @@ __global__ void RasterizePointsCoarseCudaKernel(
// Now we have processed every point in the current chunk. We need to // Now we have processed every point in the current chunk. We need to
// count the number of points in each bin so we can write the indices // count the number of points in each bin so we can write the indices
// out to global memory. We have each thread handle a different bin. // out to global memory. We have each thread handle a different bin.
for (int byx = threadIdx.x; byx < num_bins * num_bins; byx += blockDim.x) { for (int byx = threadIdx.x; byx < num_bins_y * num_bins_x;
const int by = byx / num_bins; byx += blockDim.x) {
const int bx = byx % num_bins; const int by = byx / num_bins_x;
const int bx = byx % num_bins_x;
const int count = binmask.count(by, bx); const int count = binmask.count(by, bx);
const int points_per_bin_idx = const int points_per_bin_idx =
batch_idx * num_bins * num_bins + by * num_bins + bx; batch_idx * num_bins_y * num_bins_x + by * num_bins_x + bx;
// This atomically increments the (global) number of points found // This atomically increments the (global) number of points found
// in the current bin, and gets the previous value of the counter; // in the current bin, and gets the previous value of the counter;
@ -322,8 +343,8 @@ __global__ void RasterizePointsCoarseCudaKernel(
// Now loop over the binmask and write the active bits for this bin // Now loop over the binmask and write the active bits for this bin
// out to bin_points. // out to bin_points.
int next_idx = batch_idx * num_bins * num_bins * M + by * num_bins * M + int next_idx = batch_idx * num_bins_y * num_bins_x * M +
bx * M + start; by * num_bins_x * M + bx * M + start;
for (int p = 0; p < chunk_size; ++p) { for (int p = 0; p < chunk_size; ++p) {
if (binmask.get(by, bx, p)) { if (binmask.get(by, bx, p)) {
// TODO: Throw an error if next_idx >= M -- this means that // TODO: Throw an error if next_idx >= M -- this means that
@ -342,7 +363,7 @@ at::Tensor RasterizePointsCoarseCuda(
const at::Tensor& points, // (P, 3) const at::Tensor& points, // (P, 3)
const at::Tensor& cloud_to_packed_first_idx, // (N) const at::Tensor& cloud_to_packed_first_idx, // (N)
const at::Tensor& num_points_per_cloud, // (N) const at::Tensor& num_points_per_cloud, // (N)
const int image_size, const std::tuple<int, int> image_size,
const at::Tensor& radius, const at::Tensor& radius,
const int bin_size, const int bin_size,
const int max_points_per_bin) { const int max_points_per_bin) {
@ -363,20 +384,28 @@ at::Tensor RasterizePointsCoarseCuda(
at::cuda::CUDAGuard device_guard(points.device()); at::cuda::CUDAGuard device_guard(points.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream(); cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const int H = std::get<0>(image_size);
const int W = std::get<1>(image_size);
const int P = points.size(0); const int P = points.size(0);
const int N = num_points_per_cloud.size(0); const int N = num_points_per_cloud.size(0);
const int num_bins = 1 + (image_size - 1) / bin_size; // divide round up
const int M = max_points_per_bin; const int M = max_points_per_bin;
if (num_bins >= 22) { // Integer divide round up.
const int num_bins_y = 1 + (H - 1) / bin_size;
const int num_bins_x = 1 + (W - 1) / bin_size;
if (num_bins_y >= kMaxItemsPerBin || num_bins_x >= kMaxItemsPerBin) {
// Make sure we do not use too much shared memory. // Make sure we do not use too much shared memory.
std::stringstream ss; std::stringstream ss;
ss << "Got " << num_bins << "; that's too many!"; ss << "In Coarse Rasterizer got num_bins_y: " << num_bins_y
<< ", num_bins_x: " << num_bins_x << ", "
<< "; that's too many!";
AT_ERROR(ss.str()); AT_ERROR(ss.str());
} }
auto opts = num_points_per_cloud.options().dtype(at::kInt); auto opts = num_points_per_cloud.options().dtype(at::kInt);
at::Tensor points_per_bin = at::zeros({N, num_bins, num_bins}, opts); at::Tensor points_per_bin = at::zeros({N, num_bins_y, num_bins_x}, opts);
at::Tensor bin_points = at::full({N, num_bins, num_bins, M}, -1, opts); at::Tensor bin_points = at::full({N, num_bins_y, num_bins_x, M}, -1, opts);
if (bin_points.numel() == 0) { if (bin_points.numel() == 0) {
AT_CUDA_CHECK(cudaGetLastError()); AT_CUDA_CHECK(cudaGetLastError());
@ -384,7 +413,7 @@ at::Tensor RasterizePointsCoarseCuda(
} }
const int chunk_size = 512; const int chunk_size = 512;
const size_t shared_size = num_bins * num_bins * chunk_size / 8; const size_t shared_size = num_bins_y * num_bins_x * chunk_size / 8;
const size_t blocks = 64; const size_t blocks = 64;
const size_t threads = 512; const size_t threads = 512;
@ -395,7 +424,8 @@ at::Tensor RasterizePointsCoarseCuda(
radius.contiguous().data_ptr<float>(), radius.contiguous().data_ptr<float>(),
N, N,
P, P,
image_size, H,
W,
bin_size, bin_size,
chunk_size, chunk_size,
M, M,
@ -412,19 +442,21 @@ at::Tensor RasterizePointsCoarseCuda(
__global__ void RasterizePointsFineCudaKernel( __global__ void RasterizePointsFineCudaKernel(
const float* points, // (P, 3) const float* points, // (P, 3)
const int32_t* bin_points, // (N, B, B, T) const int32_t* bin_points, // (N, BH, BW, T)
const float* radius, const float* radius,
const int bin_size, const int bin_size,
const int N, const int N,
const int B, // num_bins const int BH, // num_bins y
const int BW, // num_bins x
const int M, const int M,
const int S, const int H,
const int W,
const int K, const int K,
int32_t* point_idxs, // (N, S, S, K) int32_t* point_idxs, // (N, H, W, K)
float* zbuf, // (N, S, S, K) float* zbuf, // (N, H, W, K)
float* pix_dists) { // (N, S, S, K) float* pix_dists) { // (N, H, W, K)
// This can be more than S^2 if S is not dividable by bin_size. // This can be more than H * W if H or W are not divisible by bin_size.
const int num_pixels = N * B * B * bin_size * bin_size; const int num_pixels = N * BH * BW * bin_size * bin_size;
const int num_threads = gridDim.x * blockDim.x; const int num_threads = gridDim.x * blockDim.x;
const int tid = blockIdx.x * blockDim.x + threadIdx.x; const int tid = blockIdx.x * blockDim.x + threadIdx.x;
@ -434,21 +466,21 @@ __global__ void RasterizePointsFineCudaKernel(
// into the same bin; this should give them coalesced memory reads when // into the same bin; this should give them coalesced memory reads when
// they read from points and bin_points. // they read from points and bin_points.
int i = pid; int i = pid;
const int n = i / (B * B * bin_size * bin_size); const int n = i / (BH * BW * bin_size * bin_size);
i %= B * B * bin_size * bin_size; i %= BH * BW * bin_size * bin_size;
const int by = i / (B * bin_size * bin_size); const int by = i / (BW * bin_size * bin_size);
i %= B * bin_size * bin_size; i %= BW * bin_size * bin_size;
const int bx = i / (bin_size * bin_size); const int bx = i / (bin_size * bin_size);
i %= bin_size * bin_size; i %= bin_size * bin_size;
const int yi = i / bin_size + by * bin_size; const int yi = i / bin_size + by * bin_size;
const int xi = i % bin_size + bx * bin_size; const int xi = i % bin_size + bx * bin_size;
if (yi >= S || xi >= S) if (yi >= H || xi >= W)
continue; continue;
const float xf = PixToNdc(xi, S); const float xf = PixToNonSquareNdc(xi, W, H);
const float yf = PixToNdc(yi, S); const float yf = PixToNonSquareNdc(yi, H, W);
// This part looks like the naive rasterization kernel, except we use // This part looks like the naive rasterization kernel, except we use
// bin_points to only look at a subset of points already known to fall // bin_points to only look at a subset of points already known to fall
@ -459,7 +491,7 @@ __global__ void RasterizePointsFineCudaKernel(
float q_max_z = -1000; float q_max_z = -1000;
int q_max_idx = -1; int q_max_idx = -1;
for (int m = 0; m < M; ++m) { for (int m = 0; m < M; ++m) {
const int p = bin_points[n * B * B * M + by * B * M + bx * M + m]; const int p = bin_points[n * BH * BW * M + by * BW * M + bx * M + m];
if (p < 0) { if (p < 0) {
// bin_points uses -1 as a sentinel value // bin_points uses -1 as a sentinel value
continue; continue;
@ -473,10 +505,10 @@ __global__ void RasterizePointsFineCudaKernel(
// Reverse ordering of the X and Y axis as the camera coordinates // Reverse ordering of the X and Y axis as the camera coordinates
// assume that +Y is pointing up and +X is pointing left. // assume that +Y is pointing up and +X is pointing left.
const int yidx = S - 1 - yi; const int yidx = H - 1 - yi;
const int xidx = S - 1 - xi; const int xidx = W - 1 - xi;
const int pix_idx = n * S * S * K + yidx * S * K + xidx * K; const int pix_idx = n * H * W * K + yidx * W * K + xidx * K;
for (int k = 0; k < q_size; ++k) { for (int k = 0; k < q_size; ++k) {
point_idxs[pix_idx + k] = q[k].idx; point_idxs[pix_idx + k] = q[k].idx;
zbuf[pix_idx + k] = q[k].z; zbuf[pix_idx + k] = q[k].z;
@ -488,7 +520,7 @@ __global__ void RasterizePointsFineCudaKernel(
std::tuple<at::Tensor, at::Tensor, at::Tensor> RasterizePointsFineCuda( std::tuple<at::Tensor, at::Tensor, at::Tensor> RasterizePointsFineCuda(
const at::Tensor& points, // (P, 3) const at::Tensor& points, // (P, 3)
const at::Tensor& bin_points, const at::Tensor& bin_points,
const int image_size, const std::tuple<int, int> image_size,
const at::Tensor& radius, const at::Tensor& radius,
const int bin_size, const int bin_size,
const int points_per_pixel) { const int points_per_pixel) {
@ -503,18 +535,22 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> RasterizePointsFineCuda(
cudaStream_t stream = at::cuda::getCurrentCUDAStream(); cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const int N = bin_points.size(0); const int N = bin_points.size(0);
const int B = bin_points.size(1); // num_bins const int BH = bin_points.size(1);
const int BW = bin_points.size(2);
const int M = bin_points.size(3); const int M = bin_points.size(3);
const int S = image_size;
const int K = points_per_pixel; const int K = points_per_pixel;
const int H = std::get<0>(image_size);
const int W = std::get<1>(image_size);
if (K > kMaxPointsPerPixel) { if (K > kMaxPointsPerPixel) {
AT_ERROR("Must have num_closest <= 150"); AT_ERROR("Must have num_closest <= 150");
} }
auto int_opts = bin_points.options().dtype(at::kInt); auto int_opts = bin_points.options().dtype(at::kInt);
auto float_opts = points.options().dtype(at::kFloat); auto float_opts = points.options().dtype(at::kFloat);
at::Tensor point_idxs = at::full({N, S, S, K}, -1, int_opts); at::Tensor point_idxs = at::full({N, H, W, K}, -1, int_opts);
at::Tensor zbuf = at::full({N, S, S, K}, -1, float_opts); at::Tensor zbuf = at::full({N, H, W, K}, -1, float_opts);
at::Tensor pix_dists = at::full({N, S, S, K}, -1, float_opts); at::Tensor pix_dists = at::full({N, H, W, K}, -1, float_opts);
if (point_idxs.numel() == 0) { if (point_idxs.numel() == 0) {
AT_CUDA_CHECK(cudaGetLastError()); AT_CUDA_CHECK(cudaGetLastError());
@ -529,9 +565,11 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> RasterizePointsFineCuda(
radius.contiguous().data_ptr<float>(), radius.contiguous().data_ptr<float>(),
bin_size, bin_size,
N, N,
B, BH,
BW,
M, M,
S, H,
W,
K, K,
point_idxs.contiguous().data_ptr<int32_t>(), point_idxs.contiguous().data_ptr<int32_t>(),
zbuf.contiguous().data_ptr<float>(), zbuf.contiguous().data_ptr<float>(),
@ -571,8 +609,8 @@ __global__ void RasterizePointsBackwardCudaKernel(
const int yidx = H - 1 - yi; const int yidx = H - 1 - yi;
const int xidx = W - 1 - xi; const int xidx = W - 1 - xi;
const float xf = PixToNdc(xidx, W); const float xf = PixToNonSquareNdc(xidx, W, H);
const float yf = PixToNdc(yidx, H); const float yf = PixToNonSquareNdc(yidx, H, W);
const int p = idxs[i]; const int p = idxs[i];
if (p < 0) if (p < 0)

View File

@ -14,7 +14,7 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsNaiveCpu(
const torch::Tensor& points, const torch::Tensor& points,
const torch::Tensor& cloud_to_packed_first_idx, const torch::Tensor& cloud_to_packed_first_idx,
const torch::Tensor& num_points_per_cloud, const torch::Tensor& num_points_per_cloud,
const int image_size, const std::tuple<int, int> image_size,
const torch::Tensor& radius, const torch::Tensor& radius,
const int points_per_pixel); const int points_per_pixel);
@ -24,7 +24,7 @@ RasterizePointsNaiveCuda(
const torch::Tensor& points, const torch::Tensor& points,
const torch::Tensor& cloud_to_packed_first_idx, const torch::Tensor& cloud_to_packed_first_idx,
const torch::Tensor& num_points_per_cloud, const torch::Tensor& num_points_per_cloud,
const int image_size, const std::tuple<int, int> image_size,
const torch::Tensor& radius, const torch::Tensor& radius,
const int points_per_pixel); const int points_per_pixel);
#endif #endif
@ -43,7 +43,8 @@ RasterizePointsNaiveCuda(
// for each pointcloud in the batch. // for each pointcloud in the batch.
// radius: FloatTensor of shape (P) giving the radius (in NDC units) of // radius: FloatTensor of shape (P) giving the radius (in NDC units) of
// each point in points. // each point in points.
// image_size: (S) Size of the image to return (in pixels) // image_size: Tuple (H, W) giving the size in pixels of the output
// image to be rasterized.
// points_per_pixel: (K) The number of closest points to return for each pixel // points_per_pixel: (K) The number of closest points to return for each pixel
// //
// Returns: // Returns:
@ -62,7 +63,7 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsNaive(
const torch::Tensor& points, const torch::Tensor& points,
const torch::Tensor& cloud_to_packed_first_idx, const torch::Tensor& cloud_to_packed_first_idx,
const torch::Tensor& num_points_per_cloud, const torch::Tensor& num_points_per_cloud,
const int image_size, const std::tuple<int, int> image_size,
const torch::Tensor& radius, const torch::Tensor& radius,
const int points_per_pixel) { const int points_per_pixel) {
if (points.is_cuda() && cloud_to_packed_first_idx.is_cuda() && if (points.is_cuda() && cloud_to_packed_first_idx.is_cuda() &&
@ -101,7 +102,7 @@ torch::Tensor RasterizePointsCoarseCpu(
const torch::Tensor& points, const torch::Tensor& points,
const torch::Tensor& cloud_to_packed_first_idx, const torch::Tensor& cloud_to_packed_first_idx,
const torch::Tensor& num_points_per_cloud, const torch::Tensor& num_points_per_cloud,
const int image_size, const std::tuple<int, int> image_size,
const torch::Tensor& radius, const torch::Tensor& radius,
const int bin_size, const int bin_size,
const int max_points_per_bin); const int max_points_per_bin);
@ -111,7 +112,7 @@ torch::Tensor RasterizePointsCoarseCuda(
const torch::Tensor& points, const torch::Tensor& points,
const torch::Tensor& cloud_to_packed_first_idx, const torch::Tensor& cloud_to_packed_first_idx,
const torch::Tensor& num_points_per_cloud, const torch::Tensor& num_points_per_cloud,
const int image_size, const std::tuple<int, int> image_size,
const torch::Tensor& radius, const torch::Tensor& radius,
const int bin_size, const int bin_size,
const int max_points_per_bin); const int max_points_per_bin);
@ -128,7 +129,8 @@ torch::Tensor RasterizePointsCoarseCuda(
// for each pointcloud in the batch. // for each pointcloud in the batch.
// radius: FloatTensor of shape (P) giving the radius (in NDC units) of // radius: FloatTensor of shape (P) giving the radius (in NDC units) of
// each point in points. // each point in points.
// image_size: Size of the image to generate (in pixels) // image_size: Tuple (H, W) giving the size in pixels of the output
// image to be rasterized.
// bin_size: Size of each bin within the image (in pixels) // bin_size: Size of each bin within the image (in pixels)
// //
// Returns: // Returns:
@ -140,7 +142,7 @@ torch::Tensor RasterizePointsCoarse(
const torch::Tensor& points, const torch::Tensor& points,
const torch::Tensor& cloud_to_packed_first_idx, const torch::Tensor& cloud_to_packed_first_idx,
const torch::Tensor& num_points_per_cloud, const torch::Tensor& num_points_per_cloud,
const int image_size, const std::tuple<int, int> image_size,
const torch::Tensor& radius, const torch::Tensor& radius,
const int bin_size, const int bin_size,
const int max_points_per_bin) { const int max_points_per_bin) {
@ -182,7 +184,7 @@ torch::Tensor RasterizePointsCoarse(
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsFineCuda( std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsFineCuda(
const torch::Tensor& points, const torch::Tensor& points,
const torch::Tensor& bin_points, const torch::Tensor& bin_points,
const int image_size, const std::tuple<int, int> image_size,
const torch::Tensor& radius, const torch::Tensor& radius,
const int bin_size, const int bin_size,
const int points_per_pixel); const int points_per_pixel);
@ -194,7 +196,8 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsFineCuda(
// are expected to be in NDC coordinates in the range [-1, 1]. // are expected to be in NDC coordinates in the range [-1, 1].
// bin_points: int32 Tensor of shape (N, BH, BW, M) giving the indices of points // bin_points: int32 Tensor of shape (N, BH, BW, M) giving the indices of points
// that fall into each bin (output from coarse rasterization) // that fall into each bin (output from coarse rasterization)
// image_size: Size of image to generate (in pixels) // image_size: Tuple (H, W) giving the size in pixels of the output
// image to be rasterized.
// radius: FloatTensor of shape (P) giving the radius (in NDC units) of // radius: FloatTensor of shape (P) giving the radius (in NDC units) of
// each point in points. // each point in points.
// bin_size: Size of each bin (in pixels) // bin_size: Size of each bin (in pixels)
@ -214,7 +217,7 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsFineCuda(
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsFine( std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsFine(
const torch::Tensor& points, const torch::Tensor& points,
const torch::Tensor& bin_points, const torch::Tensor& bin_points,
const int image_size, const std::tuple<int, int> image_size,
const torch::Tensor& radius, const torch::Tensor& radius,
const int bin_size, const int bin_size,
const int points_per_pixel) { const int points_per_pixel) {
@ -303,7 +306,8 @@ torch::Tensor RasterizePointsBackward(
// for each pointcloud in the batch. // for each pointcloud in the batch.
// radius: FloatTensor of shape (P) giving the radius (in NDC units) of // radius: FloatTensor of shape (P) giving the radius (in NDC units) of
// each point in points. // each point in points.
// image_size: (S) Size of the image to return (in pixels) // image_size: Tuple (H, W) giving the size in pixels of the output
// image to be rasterized.
// points_per_pixel: (K) The number of points to return for each pixel // points_per_pixel: (K) The number of points to return for each pixel
// bin_size: Bin size (in pixels) for coarse-to-fine rasterization. Setting // bin_size: Bin size (in pixels) for coarse-to-fine rasterization. Setting
// bin_size=0 uses naive rasterization instead. // bin_size=0 uses naive rasterization instead.
@ -325,7 +329,7 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePoints(
const torch::Tensor& points, const torch::Tensor& points,
const torch::Tensor& cloud_to_packed_first_idx, const torch::Tensor& cloud_to_packed_first_idx,
const torch::Tensor& num_points_per_cloud, const torch::Tensor& num_points_per_cloud,
const int image_size, const std::tuple<int, int> image_size,
const torch::Tensor& radius, const torch::Tensor& radius,
const int points_per_pixel, const int points_per_pixel,
const int bin_size, const int bin_size,

View File

@ -3,33 +3,27 @@
#include <torch/extension.h> #include <torch/extension.h>
#include <queue> #include <queue>
#include <tuple> #include <tuple>
#include "rasterization_utils.h"
// Given a pixel coordinate 0 <= i < S, convert it to a normalized device
// coordinate in the range [-1, 1]. The NDC range is divided into S evenly-sized
// pixels, and assume that each pixel falls in the *center* of its range.
static float PixToNdc(const int i, const int S) {
// NDC x-offset + (i * pixel_width + half_pixel_width)
return -1 + (2 * i + 1.0f) / S;
}
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsNaiveCpu( std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsNaiveCpu(
const torch::Tensor& points, // (P, 3) const torch::Tensor& points, // (P, 3)
const torch::Tensor& cloud_to_packed_first_idx, // (N) const torch::Tensor& cloud_to_packed_first_idx, // (N)
const torch::Tensor& num_points_per_cloud, // (N) const torch::Tensor& num_points_per_cloud, // (N)
const int image_size, const std::tuple<int, int> image_size,
const torch::Tensor& radius, const torch::Tensor& radius,
const int points_per_pixel) { const int points_per_pixel) {
const int32_t N = cloud_to_packed_first_idx.size(0); // batch_size. const int32_t N = cloud_to_packed_first_idx.size(0); // batch_size.
const int S = image_size; const int H = std::get<0>(image_size);
const int W = std::get<1>(image_size);
const int K = points_per_pixel; const int K = points_per_pixel;
// Initialize output tensors. // Initialize output tensors.
auto int_opts = num_points_per_cloud.options().dtype(torch::kInt32); auto int_opts = num_points_per_cloud.options().dtype(torch::kInt32);
auto float_opts = points.options().dtype(torch::kFloat32); auto float_opts = points.options().dtype(torch::kFloat32);
torch::Tensor point_idxs = torch::full({N, S, S, K}, -1, int_opts); torch::Tensor point_idxs = torch::full({N, H, W, K}, -1, int_opts);
torch::Tensor zbuf = torch::full({N, S, S, K}, -1, float_opts); torch::Tensor zbuf = torch::full({N, H, W, K}, -1, float_opts);
torch::Tensor pix_dists = torch::full({N, S, S, K}, -1, float_opts); torch::Tensor pix_dists = torch::full({N, H, W, K}, -1, float_opts);
auto points_a = points.accessor<float, 2>(); auto points_a = points.accessor<float, 2>();
auto point_idxs_a = point_idxs.accessor<int32_t, 4>(); auto point_idxs_a = point_idxs.accessor<int32_t, 4>();
@ -46,16 +40,16 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsNaiveCpu(
const int point_stop_idx = const int point_stop_idx =
(point_start_idx + num_points_per_cloud[n].item().to<int32_t>()); (point_start_idx + num_points_per_cloud[n].item().to<int32_t>());
for (int yi = 0; yi < S; ++yi) { for (int yi = 0; yi < H; ++yi) {
// Reverse the order of yi so that +Y is pointing upwards in the image. // Reverse the order of yi so that +Y is pointing upwards in the image.
const int yidx = S - 1 - yi; const int yidx = H - 1 - yi;
const float yf = PixToNdc(yidx, S); const float yf = PixToNonSquareNdc(yidx, H, W);
for (int xi = 0; xi < S; ++xi) { for (int xi = 0; xi < W; ++xi) {
// Reverse the order of xi so that +X is pointing to the left in the // Reverse the order of xi so that +X is pointing to the left in the
// image. // image.
const int xidx = S - 1 - xi; const int xidx = W - 1 - xi;
const float xf = PixToNdc(xidx, S); const float xf = PixToNonSquareNdc(xidx, W, H);
// Use a priority queue to hold (z, idx, r) // Use a priority queue to hold (z, idx, r)
std::priority_queue<std::tuple<float, int, float>> q; std::priority_queue<std::tuple<float, int, float>> q;
@ -99,25 +93,36 @@ torch::Tensor RasterizePointsCoarseCpu(
const torch::Tensor& points, // (P, 3) const torch::Tensor& points, // (P, 3)
const torch::Tensor& cloud_to_packed_first_idx, // (N) const torch::Tensor& cloud_to_packed_first_idx, // (N)
const torch::Tensor& num_points_per_cloud, // (N) const torch::Tensor& num_points_per_cloud, // (N)
const int image_size, const std::tuple<int, int> image_size,
const torch::Tensor& radius, const torch::Tensor& radius,
const int bin_size, const int bin_size,
const int max_points_per_bin) { const int max_points_per_bin) {
const int32_t N = cloud_to_packed_first_idx.size(0); // batch_size. const int32_t N = cloud_to_packed_first_idx.size(0); // batch_size.
const int B = 1 + (image_size - 1) / bin_size; // Integer division round up
const int M = max_points_per_bin; const int M = max_points_per_bin;
const float H = std::get<0>(image_size);
const float W = std::get<1>(image_size);
// Integer division round up.
const int BH = 1 + (H - 1) / bin_size;
const int BW = 1 + (W - 1) / bin_size;
auto opts = num_points_per_cloud.options().dtype(torch::kInt32); auto opts = num_points_per_cloud.options().dtype(torch::kInt32);
torch::Tensor points_per_bin = torch::zeros({N, B, B}, opts); torch::Tensor points_per_bin = torch::zeros({N, BH, BW}, opts);
torch::Tensor bin_points = torch::full({N, B, B, M}, -1, opts); torch::Tensor bin_points = torch::full({N, BH, BW, M}, -1, opts);
auto points_a = points.accessor<float, 2>(); auto points_a = points.accessor<float, 2>();
auto points_per_bin_a = points_per_bin.accessor<int32_t, 3>(); auto points_per_bin_a = points_per_bin.accessor<int32_t, 3>();
auto bin_points_a = bin_points.accessor<int32_t, 4>(); auto bin_points_a = bin_points.accessor<int32_t, 4>();
auto radius_a = radius.accessor<float, 1>(); auto radius_a = radius.accessor<float, 1>();
const float pixel_width = 2.0f / image_size; const float ndc_x_range = NonSquareNdcRange(W, H);
const float bin_width = pixel_width * bin_size; const float pixel_width_x = ndc_x_range / W;
const float bin_width_x = pixel_width_x * bin_size;
const float ndc_y_range = NonSquareNdcRange(H, W);
const float pixel_width_y = ndc_y_range / H;
const float bin_width_y = pixel_width_y * bin_size;
for (int n = 0; n < N; ++n) { for (int n = 0; n < N; ++n) {
// Loop through each pointcloud in the batch. // Loop through each pointcloud in the batch.
@ -129,15 +134,15 @@ torch::Tensor RasterizePointsCoarseCpu(
(point_start_idx + num_points_per_cloud[n].item().to<int32_t>()); (point_start_idx + num_points_per_cloud[n].item().to<int32_t>());
float bin_y_min = -1.0f; float bin_y_min = -1.0f;
float bin_y_max = bin_y_min + bin_width; float bin_y_max = bin_y_min + bin_width_y;
// Iterate through the horizontal bins from top to bottom. // Iterate through the horizontal bins from top to bottom.
for (int by = 0; by < B; by++) { for (int by = 0; by < BH; by++) {
float bin_x_min = -1.0f; float bin_x_min = -1.0f;
float bin_x_max = bin_x_min + bin_width; float bin_x_max = bin_x_min + bin_width_x;
// Iterate through bins on this horizontal line, left to right. // Iterate through bins on this horizontal line, left to right.
for (int bx = 0; bx < B; bx++) { for (int bx = 0; bx < BW; bx++) {
int32_t points_hit = 0; int32_t points_hit = 0;
for (int p = point_start_idx; p < point_stop_idx; ++p) { for (int p = point_start_idx; p < point_stop_idx; ++p) {
float px = points_a[p][0]; float px = points_a[p][0];
@ -172,11 +177,11 @@ torch::Tensor RasterizePointsCoarseCpu(
// Shift the bin to the right for the next loop iteration // Shift the bin to the right for the next loop iteration
bin_x_min = bin_x_max; bin_x_min = bin_x_max;
bin_x_max = bin_x_min + bin_width; bin_x_max = bin_x_min + bin_width_x;
} }
// Shift the bin down for the next loop iteration // Shift the bin down for the next loop iteration
bin_y_min = bin_y_max; bin_y_min = bin_y_max;
bin_y_max = bin_y_min + bin_width; bin_y_max = bin_y_min + bin_width_y;
} }
} }
return bin_points; return bin_points;
@ -194,11 +199,6 @@ torch::Tensor RasterizePointsBackwardCpu(
const int W = idxs.size(2); const int W = idxs.size(2);
const int K = idxs.size(3); const int K = idxs.size(3);
// For now only support square images.
// TODO(jcjohns): Extend to non-square images.
if (H != W) {
AT_ERROR("RasterizePointsBackwardCpu only supports square images");
}
torch::Tensor grad_points = torch::zeros({P, 3}, points.options()); torch::Tensor grad_points = torch::zeros({P, 3}, points.options());
auto points_a = points.accessor<float, 2>(); auto points_a = points.accessor<float, 2>();
@ -212,7 +212,7 @@ torch::Tensor RasterizePointsBackwardCpu(
// Reverse the order of yi so that +Y is pointing upwards in the image. // Reverse the order of yi so that +Y is pointing upwards in the image.
const int yidx = H - 1 - y; const int yidx = H - 1 - y;
// Y coordinate of the top of the pixel. // Y coordinate of the top of the pixel.
const float yf = PixToNdc(yidx, H); const float yf = PixToNonSquareNdc(yidx, H, W);
// Iterate through pixels on this horizontal line, left to right. // Iterate through pixels on this horizontal line, left to right.
for (int x = 0; x < W; ++x) { // Loop over pixels in the row for (int x = 0; x < W; ++x) { // Loop over pixels in the row
@ -220,7 +220,7 @@ torch::Tensor RasterizePointsBackwardCpu(
// Reverse the order of xi so that +X is pointing to the left in the // Reverse the order of xi so that +X is pointing to the left in the
// image. // image.
const int xidx = W - 1 - x; const int xidx = W - 1 - x;
const float xf = PixToNdc(xidx, W); const float xf = PixToNonSquareNdc(xidx, W, H);
for (int k = 0; k < K; ++k) { // Loop over points for the pixel for (int k = 0; k < K; ++k) { // Loop over points for the pixel
const int p = idxs_a[n][y][x][k]; const int p = idxs_a[n][y][x][k];
if (p < 0) { if (p < 0) {

View File

@ -6,6 +6,7 @@ from .rasterizer import MeshRasterizer, RasterizationSettings
from .renderer import MeshRenderer from .renderer import MeshRenderer
from .shader import TexturedSoftPhongShader # DEPRECATED from .shader import TexturedSoftPhongShader # DEPRECATED
from .shader import ( from .shader import (
BlendParams,
HardFlatShader, HardFlatShader,
HardGouraudShader, HardGouraudShader,
HardPhongShader, HardPhongShader,

View File

@ -1,7 +1,7 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from typing import Optional, Tuple, Union from typing import List, Optional, Tuple, Union
import numpy as np import numpy as np
import torch import torch
@ -20,7 +20,7 @@ kMaxFacesPerBin = 22
def rasterize_meshes( def rasterize_meshes(
meshes, meshes,
image_size: Union[int, Tuple[int, int]] = 256, image_size: Union[int, List[int], Tuple[int, int]] = 256,
blur_radius: float = 0.0, blur_radius: float = 0.0,
faces_per_pixel: int = 8, faces_per_pixel: int = 8,
bin_size: Optional[int] = None, bin_size: Optional[int] = None,
@ -219,7 +219,7 @@ class _RasterizeFaceVerts(torch.autograd.Function):
face_verts, face_verts,
mesh_to_face_first_idx, mesh_to_face_first_idx,
num_faces_per_mesh, num_faces_per_mesh,
image_size: Tuple[int, int] = (256, 256), image_size: Union[List[int], Tuple[int, int]] = (256, 256),
blur_radius: float = 0.01, blur_radius: float = 0.01,
faces_per_pixel: int = 0, faces_per_pixel: int = 0,
bin_size: int = 0, bin_size: int = 0,
@ -287,11 +287,6 @@ class _RasterizeFaceVerts(torch.autograd.Function):
return grads return grads
def pix_to_ndc(i, S):
# NDC x-offset + (i * pixel_width + half_pixel_width)
return -1 + (2 * i + 1.0) / S
def non_square_ndc_range(S1, S2): def non_square_ndc_range(S1, S2):
""" """
In the case of non square images, we scale the NDC range In the case of non square images, we scale the NDC range

View File

@ -75,7 +75,7 @@ def _add_background_color_to_images(pix_idxs, images, background_color):
pixels with accumulated features have unchanged values. pixels with accumulated features have unchanged values.
""" """
# Initialize background mask # Initialize background mask
background_mask = pix_idxs[:, 0] < 0 # (N, image_size, image_size) background_mask = pix_idxs[:, 0] < 0 # (N, H, W)
# Convert background_color to an appropriate tensor and check shape # Convert background_color to an appropriate tensor and check shape
if not torch.is_tensor(background_color): if not torch.is_tensor(background_color):

View File

@ -6,7 +6,7 @@ import torch
# pyre-fixme[21]: Could not find name `_C` in `pytorch3d`. # pyre-fixme[21]: Could not find name `_C` in `pytorch3d`.
from pytorch3d import _C from pytorch3d import _C
from pytorch3d.renderer.mesh.rasterize_meshes import pix_to_ndc from pytorch3d.renderer.mesh.rasterize_meshes import pix_to_non_square_ndc
# Maximum number of points per bin for # Maximum number of points per bin for
@ -14,17 +14,30 @@ from pytorch3d.renderer.mesh.rasterize_meshes import pix_to_ndc
kMaxPointsPerBin = 22 kMaxPointsPerBin = 22
# TODO(jcjohns): Support non-square images
def rasterize_points( def rasterize_points(
pointclouds, pointclouds,
image_size: int = 256, image_size: Union[int, List[int], Tuple[int, int]] = 256,
radius: Union[float, List, Tuple, torch.Tensor] = 0.01, radius: Union[float, List, Tuple, torch.Tensor] = 0.01,
points_per_pixel: int = 8, points_per_pixel: int = 8,
bin_size: Optional[int] = None, bin_size: Optional[int] = None,
max_points_per_bin: Optional[int] = None, max_points_per_bin: Optional[int] = None,
): ):
""" """
Pointcloud rasterization Each pointcloud is rasterized onto a separate image of shape
(H, W) if `image_size` is a tuple or (image_size, image_size) if it
is an int.
If the desired image size is non square (i.e. a tuple of (H, W) where H != W)
the aspect ratio needs special consideration. There are two aspect ratios
to be aware of:
- the aspect ratio of each pixel
- the aspect ratio of the output image
The camera can be used to set the pixel aspect ratio. In the rasterizer,
we assume square pixels, but variable image aspect ratio (i.e rectangle images).
In most cases you will want to set the camera aspect ratio to
1.0 (i.e. square pixels) and only vary the
`image_size` (i.e. the output image dimensions in pixels).
Args: Args:
pointclouds: A Pointclouds object representing a batch of point clouds to be pointclouds: A Pointclouds object representing a batch of point clouds to be
@ -34,7 +47,8 @@ def rasterize_points(
be in normalized device coordinates (NDC): [-1, 1]^3 with the camera at be in normalized device coordinates (NDC): [-1, 1]^3 with the camera at
(0, 0, 0); In the camera coordinate frame the x-axis goes from right-to-left, (0, 0, 0); In the camera coordinate frame the x-axis goes from right-to-left,
the y-axis goes from bottom-to-top, and the z-axis goes from back-to-front. the y-axis goes from bottom-to-top, and the z-axis goes from back-to-front.
image_size: Integer giving the resolution of the rasterized image image_size: Size in pixels of the output image to be rasterized.
Can optionally be a tuple of (H, W) in the case of non square images.
radius (Optional): The radius (in NDC units) of the disk to radius (Optional): The radius (in NDC units) of the disk to
be rasterized. This can either be a float in which case the same radius is used be rasterized. This can either be a float in which case the same radius is used
for each point, or a torch.Tensor of shape (N, P) giving a radius per point for each point, or a torch.Tensor of shape (N, P) giving a radius per point
@ -71,6 +85,9 @@ def rasterize_points(
then `dists[n, y, x, k]` is the squared distance between the pixel (y, x) then `dists[n, y, x, k]` is the squared distance between the pixel (y, x)
and the point `(points[n, p, 0], points[n, p, 1])`. Pixels hit with fewer and the point `(points[n, p, 0], points[n, p, 1])`. Pixels hit with fewer
than points_per_pixel are padded with -1. than points_per_pixel are padded with -1.
In the case that image_size is a tuple of (H, W) then the outputs
will be of shape `(N, H, W, ...)`.
""" """
points_packed = pointclouds.points_packed() points_packed = pointclouds.points_packed()
cloud_to_packed_first_idx = pointclouds.cloud_to_packed_first_idx() cloud_to_packed_first_idx = pointclouds.cloud_to_packed_first_idx()
@ -78,26 +95,46 @@ def rasterize_points(
radius = _format_radius(radius, pointclouds) radius = _format_radius(radius, pointclouds)
# In the case that H != W use the max image size to set the bin_size
# to accommodate the num bins constraint in the coarse rasterizer.
# If the ratio of H:W is large this might cause issues as the smaller
# dimension will have fewer bins.
# TODO: consider a better way of setting the bin size.
if isinstance(image_size, (tuple, list)):
if len(image_size) != 2:
raise ValueError("Image size can only be a tuple/list of (H, W)")
if not all(i > 0 for i in image_size):
raise ValueError(
"Image sizes must be greater than 0; got %d, %d" % image_size
)
if not all(type(i) == int for i in image_size):
raise ValueError("Image sizes must be integers; got %f, %f" % image_size)
max_image_size = max(*image_size)
im_size = image_size
else:
im_size = (image_size, image_size)
max_image_size = image_size
if bin_size is None: if bin_size is None:
if not points_packed.is_cuda: if not points_packed.is_cuda:
# Binned CPU rasterization not fully implemented # Binned CPU rasterization not fully implemented
bin_size = 0 bin_size = 0
else: else:
# TODO: These heuristics are not well-thought out! # TODO: These heuristics are not well-thought out!
if image_size <= 64: if max_image_size <= 64:
bin_size = 8 bin_size = 8
elif image_size <= 256: elif max_image_size <= 256:
bin_size = 16 bin_size = 16
elif image_size <= 512: elif max_image_size <= 512:
bin_size = 32 bin_size = 32
elif image_size <= 1024: elif max_image_size <= 1024:
bin_size = 64 bin_size = 64
if bin_size != 0: if bin_size != 0:
# There is a limit on the number of points per bin in the cuda kernel. # There is a limit on the number of points per bin in the cuda kernel.
# pyre-fixme[58]: `//` is not supported for operand types `int` and # pyre-fixme[58]: `//` is not supported for operand types `int` and
# `Union[int, None, int]`. # `Union[int, None, int]`.
points_per_bin = 1 + (image_size - 1) // bin_size points_per_bin = 1 + (max_image_size - 1) // bin_size
if points_per_bin >= kMaxPointsPerBin: if points_per_bin >= kMaxPointsPerBin:
raise ValueError( raise ValueError(
"bin_size too small, number of points per bin must be less than %d; got %d" "bin_size too small, number of points per bin must be less than %d; got %d"
@ -114,7 +151,7 @@ def rasterize_points(
points_packed, points_packed,
cloud_to_packed_first_idx, cloud_to_packed_first_idx,
num_points_per_cloud, num_points_per_cloud,
image_size, im_size,
radius, radius,
points_per_pixel, points_per_pixel,
bin_size, bin_size,
@ -173,7 +210,7 @@ class _RasterizePoints(torch.autograd.Function):
points, # (P, 3) points, # (P, 3)
cloud_to_packed_first_idx, cloud_to_packed_first_idx,
num_points_per_cloud, num_points_per_cloud,
image_size: int = 256, image_size: Union[List[int], Tuple[int, int]] = (256, 256),
radius: Union[float, torch.Tensor] = 0.01, radius: Union[float, torch.Tensor] = 0.01,
points_per_pixel: int = 8, points_per_pixel: int = 8,
bin_size: int = 0, bin_size: int = 0,
@ -225,7 +262,7 @@ class _RasterizePoints(torch.autograd.Function):
def rasterize_points_python( def rasterize_points_python(
pointclouds, pointclouds,
image_size: int = 256, image_size: Union[int, Tuple[int, int]] = 256,
radius: Union[float, torch.Tensor] = 0.01, radius: Union[float, torch.Tensor] = 0.01,
points_per_pixel: int = 8, points_per_pixel: int = 8,
): ):
@ -235,7 +272,12 @@ def rasterize_points_python(
Inputs / Outputs: Same as above Inputs / Outputs: Same as above
""" """
N = len(pointclouds) N = len(pointclouds)
S, K = image_size, points_per_pixel H, W = (
image_size
if isinstance(image_size, (tuple, list))
else (image_size, image_size)
)
K = points_per_pixel
device = pointclouds.device device = pointclouds.device
points_packed = pointclouds.points_packed() points_packed = pointclouds.points_packed()
@ -247,11 +289,11 @@ def rasterize_points_python(
# Initialize output tensors. # Initialize output tensors.
point_idxs = torch.full( point_idxs = torch.full(
(N, S, S, K), fill_value=-1, dtype=torch.int32, device=device (N, H, W, K), fill_value=-1, dtype=torch.int32, device=device
) )
zbuf = torch.full((N, S, S, K), fill_value=-1, dtype=torch.float32, device=device) zbuf = torch.full((N, H, W, K), fill_value=-1, dtype=torch.float32, device=device)
pix_dists = torch.full( pix_dists = torch.full(
(N, S, S, K), fill_value=-1, dtype=torch.float32, device=device (N, H, W, K), fill_value=-1, dtype=torch.float32, device=device
) )
# NDC is from [-1, 1]. Get pixel size using specified image size. # NDC is from [-1, 1]. Get pixel size using specified image size.
@ -263,18 +305,18 @@ def rasterize_points_python(
point_stop_idx = point_start_idx + num_points_per_cloud[n] point_stop_idx = point_start_idx + num_points_per_cloud[n]
# Iterate through the horizontal lines of the image from top to bottom. # Iterate through the horizontal lines of the image from top to bottom.
for yi in range(S): for yi in range(H):
# Y coordinate of one end of the image. Reverse the ordering # Y coordinate of one end of the image. Reverse the ordering
# of yi so that +Y is pointing up in the image. # of yi so that +Y is pointing up in the image.
yfix = S - 1 - yi yfix = H - 1 - yi
yf = pix_to_ndc(yfix, S) yf = pix_to_non_square_ndc(yfix, H, W)
# Iterate through pixels on this horizontal line, left to right. # Iterate through pixels on this horizontal line, left to right.
for xi in range(S): for xi in range(W):
# X coordinate of one end of the image. Reverse the ordering # X coordinate of one end of the image. Reverse the ordering
# of xi so that +X is pointing to the left in the image. # of xi so that +X is pointing to the left in the image.
xfix = S - 1 - xi xfix = W - 1 - xi
xf = pix_to_ndc(xfix, S) xf = pix_to_non_square_ndc(xfix, W, H)
top_k_points = [] top_k_points = []
# Check whether each point in the batch affects this pixel. # Check whether each point in the batch affects this pixel.

View File

@ -2,7 +2,7 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from typing import NamedTuple, Optional, Union from typing import NamedTuple, Optional, Tuple, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
@ -29,7 +29,7 @@ class PointsRasterizationSettings:
def __init__( def __init__(
self, self,
image_size: int = 256, image_size: Union[int, Tuple[int, int]] = 256,
radius: Union[float, torch.Tensor] = 0.01, radius: Union[float, torch.Tensor] = 0.01,
points_per_pixel: int = 8, points_per_pixel: int = 8,
bin_size: Optional[int] = None, bin_size: Optional[int] = None,

View File

@ -74,6 +74,21 @@ def bm_python_vs_cpu_vs_cuda() -> None:
kwargs_list += [ kwargs_list += [
{"N": 32, "P": 100000, "img_size": 128, "radius": 0.01, "pts_per_pxl": 50}, {"N": 32, "P": 100000, "img_size": 128, "radius": 0.01, "pts_per_pxl": 50},
{"N": 8, "P": 200000, "img_size": 512, "radius": 0.01, "pts_per_pxl": 50}, {"N": 8, "P": 200000, "img_size": 512, "radius": 0.01, "pts_per_pxl": 50},
{"N": 8, "P": 200000, "img_size": 256, "radius": 0.01, "pts_per_pxl": 50},
{
"N": 8,
"P": 200000,
"img_size": (512, 256),
"radius": 0.01,
"pts_per_pxl": 50,
},
{
"N": 8,
"P": 200000,
"img_size": (256, 512),
"radius": 0.01,
"pts_per_pxl": 50,
},
] ]
for k in kwargs_list: for k in kwargs_list:
k["device"] = "cuda" k["device"] = "cuda"

Binary file not shown.

After

Width:  |  Height:  |  Size: 20 KiB

View File

@ -404,7 +404,7 @@ class TestRasterizePoints(TestCaseMixin, unittest.TestCase):
torch.manual_seed(231) torch.manual_seed(231)
N = 3 N = 3
max_P = 1000 max_P = 1000
image_size = 64 image_size = (64, 64)
radius = 0.1 radius = 0.1
bin_size = 16 bin_size = 16
max_points_per_bin = 500 max_points_per_bin = 500
@ -501,7 +501,7 @@ class TestRasterizePoints(TestCaseMixin, unittest.TestCase):
device=device device=device
) )
# fmt: on # fmt: on
image_size = 16 image_size = (16, 16)
radius = 0.2 radius = 0.2
bin_size = 8 bin_size = 8
max_points_per_bin = 5 max_points_per_bin = 5

View File

@ -12,19 +12,33 @@ from pytorch3d.io import load_obj
from pytorch3d.renderer.cameras import FoVPerspectiveCameras, look_at_view_transform from pytorch3d.renderer.cameras import FoVPerspectiveCameras, look_at_view_transform
from pytorch3d.renderer.lighting import PointLights from pytorch3d.renderer.lighting import PointLights
from pytorch3d.renderer.materials import Materials from pytorch3d.renderer.materials import Materials
from pytorch3d.renderer.mesh import TexturesUV from pytorch3d.renderer.mesh import (
BlendParams,
MeshRasterizer,
MeshRenderer,
RasterizationSettings,
SoftPhongShader,
TexturesUV,
)
from pytorch3d.renderer.mesh.rasterize_meshes import ( from pytorch3d.renderer.mesh.rasterize_meshes import (
rasterize_meshes, rasterize_meshes,
rasterize_meshes_python, rasterize_meshes_python,
) )
from pytorch3d.renderer.mesh.rasterizer import ( from pytorch3d.renderer.mesh.rasterizer import Fragments
Fragments, from pytorch3d.renderer.points import (
MeshRasterizer, AlphaCompositor,
RasterizationSettings, PointsRasterizationSettings,
PointsRasterizer,
PointsRenderer,
) )
from pytorch3d.renderer.mesh.renderer import MeshRenderer from pytorch3d.renderer.points.rasterize_points import (
from pytorch3d.renderer.mesh.shader import BlendParams, SoftPhongShader rasterize_points,
from pytorch3d.structures import Meshes rasterize_points_python,
)
from pytorch3d.renderer.points.rasterizer import PointFragments
from pytorch3d.structures import Meshes, Pointclouds
from pytorch3d.transforms.transform3d import Transform3d
from pytorch3d.utils import torus
DEBUG = False DEBUG = False
@ -44,9 +58,36 @@ verts0 = torch.tensor(
) )
faces0 = torch.tensor([[1, 0, 2], [4, 3, 5]], dtype=torch.int64) faces0 = torch.tensor([[1, 0, 2], [4, 3, 5]], dtype=torch.int64)
# Points for a simple pointcloud. Get the vertices from a
# torus and apply rotations such that the points are no longer
# symmetrical in X/Y.
torus_mesh = torus(r=0.25, R=1.0, sides=5, rings=2 * 5)
t = (
Transform3d()
.rotate_axis_angle(angle=90, axis="Y")
.rotate_axis_angle(angle=45, axis="Z")
.scale(0.3)
)
torus_points = t.transform_points(torus_mesh.verts_padded()).squeeze()
class TestRasterizeRectanglesErrors(TestCaseMixin, unittest.TestCase):
def _save_debug_image(idx, image_size, bin_size, blur):
    """
    Write a black/white hit mask for the last image of a rasterized batch.

    Args:
        idx: index tensor returned by a rasterizer; entries greater than -1
            are treated as covered pixels.
        image_size: (H, W) pair, used only to build the file name.
        bin_size: value embedded in the file name via str(); callers in this
            file pass 0, None or "python".
        blur: float embedded in the file name with three decimals.

    NOTE(review): indentation is not visible in this view; the save call is
    assumed to be guarded by the module-level DEBUG flag, which matches the
    "Save debug images if DEBUG is set to true" comments at the call sites.
    """
    height, width = image_size
    # Boolean mask over the first (up to) three slots of the last batch element.
    mask = (idx[-1, ..., :3].cpu() > -1).squeeze()
    shape_tag = "square" if height == width else "non_square"
    out_name = "%s_bin_size_%s_blur_%.3f_%dx%d.png" % (
        shape_tag,
        str(bin_size),
        blur,
        height,
        width,
    )
    if not DEBUG:
        return
    out_name = "DEBUG_%s" % out_name
    pixels = (mask.numpy() * 255).astype(np.uint8)
    Image.fromarray(pixels).save(DATA_DIR / out_name)
class TestRasterizeRectangleImagesErrors(TestCaseMixin, unittest.TestCase):
def test_mesh_image_size_arg(self):
meshes = Meshes(verts=[verts0], faces=[faces0]) meshes = Meshes(verts=[verts0], faces=[faces0])
with self.assertRaises(ValueError) as cm: with self.assertRaises(ValueError) as cm:
@ -76,8 +117,38 @@ class TestRasterizeRectanglesErrors(TestCaseMixin, unittest.TestCase):
) )
self.assertTrue("sizes must be integers" in cm.msg) self.assertTrue("sizes must be integers" in cm.msg)
def test_points_image_size_arg(self):
    """
    rasterize_points must reject malformed `image_size` values with a
    ValueError whose message names the problem.

    The expected patterns mirror the messages raised by the validation in
    rasterize_points: wrong tuple length, non-positive sizes and
    non-integer sizes. assertRaisesRegex is used so the message is really
    checked; the context manager's `msg` attribute is unittest's custom
    failure message, not the text of the raised exception.
    """
    points = Pointclouds([verts0])

    # (invalid image_size, regex that must match the ValueError message)
    invalid_cases = [
        ((100, 200, 3), r"tuple/list of \(H, W\)"),
        ((0, 10), r"sizes must be greater than 0"),
        ((100.5, 120.5), r"sizes must be integers"),
    ]
    for image_size, expected_msg in invalid_cases:
        with self.assertRaisesRegex(ValueError, expected_msg):
            rasterize_points(
                points,
                image_size,
                0.0001,
                points_per_pixel=1,
            )
class TestRasterizeRectangleImagesMeshes(TestCaseMixin, unittest.TestCase):
@staticmethod @staticmethod
def _clone_mesh(verts0, faces0, device, batch_size): def _clone_mesh(verts0, faces0, device, batch_size):
""" """
@ -164,7 +235,7 @@ class TestRasterizeRectangles(TestCaseMixin, unittest.TestCase):
meshes_sq, image_size=(S, S), bin_size=0, blur=blur meshes_sq, image_size=(S, S), bin_size=0, blur=blur
) )
# Save debug image # Save debug image
self._save_debug_image(square_fragments, (S, S), 0, blur) _save_debug_image(square_fragments.pix_to_face, (S, S), 0, blur)
# Extract the values in the square image which are non zero. # Extract the values in the square image which are non zero.
square_mask = square_fragments.pix_to_face > -1 square_mask = square_fragments.pix_to_face > -1
@ -284,8 +355,8 @@ class TestRasterizeRectangles(TestCaseMixin, unittest.TestCase):
) )
# Save out debug images if needed # Save out debug images if needed
self._save_debug_image(fragments_naive, image_size, 0, blur) _save_debug_image(fragments_naive.pix_to_face, image_size, 0, blur)
self._save_debug_image(fragments_binned, image_size, None, blur) _save_debug_image(fragments_binned.pix_to_face, image_size, None, blur)
# Check naive and binned fragments give the same outputs # Check naive and binned fragments give the same outputs
self._check_fragments(fragments_naive, fragments_binned) self._check_fragments(fragments_naive, fragments_binned)
@ -354,8 +425,8 @@ class TestRasterizeRectangles(TestCaseMixin, unittest.TestCase):
) )
# Save debug images if DEBUG is set to true at the top of the file. # Save debug images if DEBUG is set to true at the top of the file.
self._save_debug_image(fragments_naive, image_size, 0, blur) _save_debug_image(fragments_naive.pix_to_face, image_size, 0, blur)
self._save_debug_image(fragments_python, image_size, "python", blur) _save_debug_image(fragments_python.pix_to_face, image_size, "python", blur)
# List of non square outputs to compare with the square output # List of non square outputs to compare with the square output
nonsq_fragment_gradtensor_list = [ nonsq_fragment_gradtensor_list = [
@ -437,3 +508,293 @@ class TestRasterizeRectangles(TestCaseMixin, unittest.TestCase):
# NOTE some pixels can be flaky # NOTE some pixels can be flaky
cond1 = torch.allclose(rgb, image_ref, atol=0.05) cond1 = torch.allclose(rgb, image_ref, atol=0.05)
self.assertTrue(cond1) self.assertTrue(cond1)
class TestRasterizeRectangleImagesPointclouds(TestCaseMixin, unittest.TestCase):
    """
    Point cloud rasterization into non square (H, W) images, checked against
    rasterization of the same points into a square (S, S) image with
    S = min(H, W).
    """

    @staticmethod
    def _clone_pointcloud(verts0, device, batch_size):
        """
        Make an independent, differentiable copy of `verts0` and wrap it in a
        batch of point clouds.

        Returns:
            (leaf, clouds): the cloned tensor with requires_grad set, and the
            Pointclouds object built from it, moved to `device` and extended
            to `batch_size`.
        """
        leaf = verts0.detach().clone()
        leaf.requires_grad = True
        clouds = Pointclouds(points=[leaf]).to(device).extend(batch_size)
        return leaf, clouds

    def _rasterize(self, pointclouds, image_size, bin_size, blur):
        """
        Call rasterize_points with one point per pixel and `blur` as the
        radius, and package the three outputs as PointFragments.
        """
        idx, zbuf, dists = rasterize_points(
            pointclouds,
            image_size,
            blur,
            points_per_pixel=1,
            bin_size=bin_size,
        )
        return PointFragments(idx=idx, zbuf=zbuf, dists=dists)

    def _check_fragments(self, frag_1, frag_2):
        """
        Assert that two PointFragments hold matching idx, dists and zbuf.
        """
        for field in ("idx", "dists", "zbuf"):
            self.assertClose(getattr(frag_1, field), getattr(frag_2, field))

    def _compare_square_with_nonsq(
        self,
        image_size,
        blur,
        device,
        points,
        nonsq_fragment_gradtensor_list,
        batch_size=1,
    ):
        """
        Rasterize `points` into a square image of side min(H, W) and compare
        the covered pixels with those of already-rasterized non square
        images.

        Each entry of `nonsq_fragment_gradtensor_list` is a tuple of
        (fragments, verts tensor whose .grad is compared, name). Forward
        values (dists, zbuf), the gradients retained on them and the
        gradients on the input points must all agree with the square case.
        """
        side = min(*image_size)
        sq_size = (side, side)
        sq_points, sq_clouds = self._clone_pointcloud(points, device, batch_size)
        sq_frags = self._rasterize(sq_clouds, image_size=sq_size, bin_size=0, blur=blur)
        _save_debug_image(sq_frags.idx, sq_size, 0, blur)

        # Only pixels that were hit by a point take part in the comparison.
        sq_hit = sq_frags.idx > -1
        sq_dists = sq_frags.dists[sq_hit]
        sq_zbuf = sq_frags.zbuf[sq_hit]

        # Keep gradients on the rasterizer outputs so that they can be
        # compared with the non square ones below.
        sq_frags.dists.retain_grad()
        sq_frags.zbuf.retain_grad()

        # Random upstream gradients; the draw order (zbuf first, then dists)
        # after seeding is kept fixed.
        torch.manual_seed(231)
        upstream_zbuf = torch.randn_like(sq_zbuf)
        upstream_dists = torch.randn_like(sq_dists)
        sq_loss = (upstream_dists * sq_dists).sum() + (upstream_zbuf * sq_zbuf).sum()
        sq_loss.backward()

        for frags, verts_with_grad, _name in nonsq_fragment_gradtensor_list:
            # Same number of covered pixels in both images.
            hit = frags.idx > -1
            self.assertEqual(hit.sum().item(), sq_hit.sum().item())

            # Forward values on the covered pixels match.
            nonsq_dists = frags.dists[hit]
            nonsq_zbuf = frags.zbuf[hit]
            self.assertClose(sq_dists, nonsq_dists)
            self.assertClose(sq_zbuf, nonsq_zbuf)

            # Backward pass with the same upstream gradients.
            frags.dists.retain_grad()
            frags.zbuf.retain_grad()
            nonsq_loss = (upstream_dists * nonsq_dists).sum() + (
                upstream_zbuf * nonsq_zbuf
            ).sum()
            nonsq_loss.backward()

            # Intermediate gradients on the covered pixels match.
            self.assertClose(frags.dists.grad[hit], sq_frags.dists.grad[sq_hit])
            self.assertClose(frags.zbuf.grad[hit], sq_frags.zbuf.grad[sq_hit])

            # Gradients on the input points match.
            self.assertClose(sq_points.grad, verts_with_grad.grad, rtol=2e-4)

    def test_gpu(self):
        """
        On CUDA, naive and coarse-to-fine rasterization of non square images
        must agree with each other and with the square image covering the
        shared region.
        """
        # Both orientations are covered: (W > H) and (H > W).
        image_sizes = [(64, 128), (128, 64), (128, 256), (256, 128)]
        devices = ["cuda:0"]
        blurs = [5e-2]
        batch_sizes = [1, 4]

        for image_size, blur, device, batch_size in product(
            image_sizes, blurs, devices, batch_sizes
        ):
            # Separate leaf tensors so each rasterization has its own grads.
            naive_verts, naive_clouds = self._clone_pointcloud(
                torus_points, device, batch_size
            )
            binned_verts, binned_clouds = self._clone_pointcloud(
                torus_points, device, batch_size
            )

            # bin_size=0 is the naive path, bin_size=None lets the
            # rasterizer pick a bin size (coarse-to-fine).
            naive = self._rasterize(naive_clouds, image_size, bin_size=0, blur=blur)
            binned = self._rasterize(
                binned_clouds, image_size, bin_size=None, blur=blur
            )

            _save_debug_image(naive.idx, image_size, 0, blur)
            _save_debug_image(binned.idx, image_size, None, blur)

            self._check_fragments(naive, binned)

            self._compare_square_with_nonsq(
                image_size,
                blur,
                device,
                torus_points,
                [
                    (naive, naive_verts, "naive"),
                    (binned, binned_verts, "coarse-to-fine"),
                ],
                batch_size,
            )

    def test_cpu(self):
        """
        On CPU, the naive C++ rasterizer and the pure Python rasterizer must
        both agree with the square image covering the shared region.
        Coarse-to-fine rasterization is not compared here because it is not
        fully implemented for CPU.
        """
        # Both orientations; sizes are small because the Python rasterizer
        # is really slow.
        image_sizes = [(32, 64), (64, 32)]
        devices = ["cpu"]
        blurs = [5e-2]
        batch_sizes = [1]

        for image_size, blur, device, batch_size in product(
            image_sizes, blurs, devices, batch_sizes
        ):
            naive_verts, naive_clouds = self._clone_pointcloud(
                torus_points, device, batch_size
            )
            python_verts, python_clouds = self._clone_pointcloud(
                torus_points, device, batch_size
            )

            naive = self._rasterize(naive_clouds, image_size, bin_size=0, blur=blur)
            py_idx, py_zbuf, py_dists = rasterize_points_python(
                python_clouds,
                image_size,
                blur,
                points_per_pixel=1,
            )
            python = PointFragments(idx=py_idx, zbuf=py_zbuf, dists=py_dists)

            # Debug images are saved only if DEBUG is set at the top of the file.
            _save_debug_image(naive.idx, image_size, 0, blur)
            _save_debug_image(python.idx, image_size, "python", blur)

            self._compare_square_with_nonsq(
                image_size,
                blur,
                device,
                torus_points,
                [
                    (naive, naive_verts, "naive"),
                    (python, python_verts, "python"),
                ],
                batch_size,
            )

    def test_render_pointcloud(self):
        """
        Render a featured pointcloud into a (512, 1024) image with both the
        naive and the coarse-to-fine rasterizer and compare each result with
        the stored reference image.
        """
        device = torch.device("cuda:0")
        pointclouds = Pointclouds(
            points=[torus_points * 2.0],
            features=torch.ones_like(torus_points[None, ...]),
        ).to(device)
        R, T = look_at_view_transform(2.7, 0.0, 0.0)
        cameras = FoVPerspectiveCameras(device=device, R=R, T=T)
        settings = PointsRasterizationSettings(
            image_size=(512, 1024), radius=5e-2, points_per_pixel=1
        )
        renderer = PointsRenderer(
            rasterizer=PointsRasterizer(cameras=cameras, raster_settings=settings),
            compositor=AlphaCompositor(),
        )

        reference = load_rgb_image("test_pointcloud_rectangle_image.png", DATA_DIR)

        for bin_size in [0, None]:
            # Switch between naive (0) and coarse-to-fine (None) in place.
            renderer.rasterizer.raster_settings.bin_size = bin_size
            rendered = renderer(pointclouds)[0, ..., :3].squeeze().cpu()

            if DEBUG:
                debug_pixels = (rendered.numpy() * 255).astype(np.uint8)
                Image.fromarray(debug_pixels).save(
                    DATA_DIR / "DEBUG_pointcloud_rectangle_image.png"
                )

            # NOTE some pixels can be flaky
            self.assertTrue(torch.allclose(rendered, reference, atol=0.05))