mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
Non Square image rasterization for pointclouds
Summary: Similar to non square image rasterization for meshes, apply the same updates to the pointcloud rasterizer. Main API Change: - PointRasterizationSettings now accepts a tuple/list of (H, W) for the image size. Reviewed By: jcjohnson Differential Revision: D25465206 fbshipit-source-id: 7370d83c431af1b972158cecae19d82364623380
This commit is contained in:
parent
569e5229a9
commit
3d769a66cb
@ -30,15 +30,15 @@ __global__ void alphaCompositeCudaForwardKernel(
|
|||||||
// Get the batch and index
|
// Get the batch and index
|
||||||
const int batch = blockIdx.x;
|
const int batch = blockIdx.x;
|
||||||
|
|
||||||
const int num_pixels = C * W * H;
|
const int num_pixels = C * H * W;
|
||||||
const int num_threads = gridDim.y * blockDim.x;
|
const int num_threads = gridDim.y * blockDim.x;
|
||||||
const int tid = blockIdx.y * blockDim.x + threadIdx.x;
|
const int tid = blockIdx.y * blockDim.x + threadIdx.x;
|
||||||
|
|
||||||
// Iterate over each feature in each pixel
|
// Iterate over each feature in each pixel
|
||||||
for (int pid = tid; pid < num_pixels; pid += num_threads) {
|
for (int pid = tid; pid < num_pixels; pid += num_threads) {
|
||||||
int ch = pid / (W * H);
|
int ch = pid / (H * W);
|
||||||
int j = (pid % (W * H)) / H;
|
int j = (pid % (H * W)) / W;
|
||||||
int i = (pid % (W * H)) % H;
|
int i = (pid % (H * W)) % W;
|
||||||
|
|
||||||
// alphacomposite the different values
|
// alphacomposite the different values
|
||||||
float cum_alpha = 1.;
|
float cum_alpha = 1.;
|
||||||
@ -81,16 +81,16 @@ __global__ void alphaCompositeCudaBackwardKernel(
|
|||||||
// Get the batch and index
|
// Get the batch and index
|
||||||
const int batch = blockIdx.x;
|
const int batch = blockIdx.x;
|
||||||
|
|
||||||
const int num_pixels = C * W * H;
|
const int num_pixels = C * H * W;
|
||||||
const int num_threads = gridDim.y * blockDim.x;
|
const int num_threads = gridDim.y * blockDim.x;
|
||||||
const int tid = blockIdx.y * blockDim.x + threadIdx.x;
|
const int tid = blockIdx.y * blockDim.x + threadIdx.x;
|
||||||
|
|
||||||
// Parallelize over each feature in each pixel in images of size H * W,
|
// Parallelize over each feature in each pixel in images of size H * W,
|
||||||
// for each image in the batch of size batch_size
|
// for each image in the batch of size batch_size
|
||||||
for (int pid = tid; pid < num_pixels; pid += num_threads) {
|
for (int pid = tid; pid < num_pixels; pid += num_threads) {
|
||||||
int ch = pid / (W * H);
|
int ch = pid / (H * W);
|
||||||
int j = (pid % (W * H)) / H;
|
int j = (pid % (H * W)) / W;
|
||||||
int i = (pid % (W * H)) % H;
|
int i = (pid % (H * W)) % W;
|
||||||
|
|
||||||
// alphacomposite the different values
|
// alphacomposite the different values
|
||||||
float cum_alpha = 1.;
|
float cum_alpha = 1.;
|
||||||
|
@ -11,13 +11,13 @@
|
|||||||
// features: FloatTensor of shape (C, P) which gives the features
|
// features: FloatTensor of shape (C, P) which gives the features
|
||||||
// of each point where C is the size of the feature and
|
// of each point where C is the size of the feature and
|
||||||
// P the number of points.
|
// P the number of points.
|
||||||
// alphas: FloatTensor of shape (N, points_per_pixel, W, W) where
|
// alphas: FloatTensor of shape (N, points_per_pixel, H, W) where
|
||||||
// points_per_pixel is the number of points in the z-buffer
|
// points_per_pixel is the number of points in the z-buffer
|
||||||
// sorted in z-order, and W is the image size.
|
// sorted in z-order, and (H, W) is the image size.
|
||||||
// points_idx: IntTensor of shape (N, points_per_pixel, W, W) giving the
|
// points_idx: IntTensor of shape (N, points_per_pixel, H, W) giving the
|
||||||
// indices of the nearest points at each pixel, sorted in z-order.
|
// indices of the nearest points at each pixel, sorted in z-order.
|
||||||
// Returns:
|
// Returns:
|
||||||
// weighted_fs: FloatTensor of shape (N, C, W, W) giving the accumulated
|
// weighted_fs: FloatTensor of shape (N, C, H, W) giving the accumulated
|
||||||
// feature for each point. Concretely, it gives:
|
// feature for each point. Concretely, it gives:
|
||||||
// weighted_fs[b,c,i,j] = sum_k cum_alpha_k *
|
// weighted_fs[b,c,i,j] = sum_k cum_alpha_k *
|
||||||
// features[c,points_idx[b,k,i,j]]
|
// features[c,points_idx[b,k,i,j]]
|
||||||
|
@ -30,16 +30,16 @@ __global__ void weightedSumNormCudaForwardKernel(
|
|||||||
// Get the batch and index
|
// Get the batch and index
|
||||||
const int batch = blockIdx.x;
|
const int batch = blockIdx.x;
|
||||||
|
|
||||||
const int num_pixels = C * W * H;
|
const int num_pixels = C * H * W;
|
||||||
const int num_threads = gridDim.y * blockDim.x;
|
const int num_threads = gridDim.y * blockDim.x;
|
||||||
const int tid = blockIdx.y * blockDim.x + threadIdx.x;
|
const int tid = blockIdx.y * blockDim.x + threadIdx.x;
|
||||||
|
|
||||||
// Parallelize over each feature in each pixel in images of size H * W,
|
// Parallelize over each feature in each pixel in images of size H * W,
|
||||||
// for each image in the batch of size batch_size
|
// for each image in the batch of size batch_size
|
||||||
for (int pid = tid; pid < num_pixels; pid += num_threads) {
|
for (int pid = tid; pid < num_pixels; pid += num_threads) {
|
||||||
int ch = pid / (W * H);
|
int ch = pid / (H * W);
|
||||||
int j = (pid % (W * H)) / H;
|
int j = (pid % (H * W)) / W;
|
||||||
int i = (pid % (W * H)) % H;
|
int i = (pid % (H * W)) % W;
|
||||||
|
|
||||||
// Store the accumulated alpha value
|
// Store the accumulated alpha value
|
||||||
float cum_alpha = 0.;
|
float cum_alpha = 0.;
|
||||||
@ -101,9 +101,9 @@ __global__ void weightedSumNormCudaBackwardKernel(
|
|||||||
// Parallelize over each feature in each pixel in images of size H * W,
|
// Parallelize over each feature in each pixel in images of size H * W,
|
||||||
// for each image in the batch of size batch_size
|
// for each image in the batch of size batch_size
|
||||||
for (int pid = tid; pid < num_pixels; pid += num_threads) {
|
for (int pid = tid; pid < num_pixels; pid += num_threads) {
|
||||||
int ch = pid / (W * H);
|
int ch = pid / (H * W);
|
||||||
int j = (pid % (W * H)) / H;
|
int j = (pid % (H * W)) / W;
|
||||||
int i = (pid % (W * H)) % H;
|
int i = (pid % (H * W)) % W;
|
||||||
|
|
||||||
float sum_alpha = 0.;
|
float sum_alpha = 0.;
|
||||||
float sum_alphafs = 0.;
|
float sum_alphafs = 0.;
|
||||||
|
@ -11,13 +11,13 @@
|
|||||||
// features: FloatTensor of shape (C, P) which gives the features
|
// features: FloatTensor of shape (C, P) which gives the features
|
||||||
// of each point where C is the size of the feature and
|
// of each point where C is the size of the feature and
|
||||||
// P the number of points.
|
// P the number of points.
|
||||||
// alphas: FloatTensor of shape (N, points_per_pixel, W, W) where
|
// alphas: FloatTensor of shape (N, points_per_pixel, H, W) where
|
||||||
// points_per_pixel is the number of points in the z-buffer
|
// points_per_pixel is the number of points in the z-buffer
|
||||||
// sorted in z-order, and W is the image size.
|
// sorted in z-order, and (H, W) is the image size.
|
||||||
// points_idx: IntTensor of shape (N, points_per_pixel, W, W) giving the
|
// points_idx: IntTensor of shape (N, points_per_pixel, H, W) giving the
|
||||||
// indices of the nearest points at each pixel, sorted in z-order.
|
// indices of the nearest points at each pixel, sorted in z-order.
|
||||||
// Returns:
|
// Returns:
|
||||||
// weighted_fs: FloatTensor of shape (N, C, W, W) giving the accumulated
|
// weighted_fs: FloatTensor of shape (N, C, H, W) giving the accumulated
|
||||||
// feature in each point. Concretely, it gives:
|
// feature in each point. Concretely, it gives:
|
||||||
// weighted_fs[b,c,i,j] = sum_k alphas[b,k,i,j] *
|
// weighted_fs[b,c,i,j] = sum_k alphas[b,k,i,j] *
|
||||||
// features[c,points_idx[b,k,i,j]] / sum_k alphas[b,k,i,j]
|
// features[c,points_idx[b,k,i,j]] / sum_k alphas[b,k,i,j]
|
||||||
|
@ -28,16 +28,16 @@ __global__ void weightedSumCudaForwardKernel(
|
|||||||
// Get the batch and index
|
// Get the batch and index
|
||||||
const int batch = blockIdx.x;
|
const int batch = blockIdx.x;
|
||||||
|
|
||||||
const int num_pixels = C * W * H;
|
const int num_pixels = C * H * W;
|
||||||
const int num_threads = gridDim.y * blockDim.x;
|
const int num_threads = gridDim.y * blockDim.x;
|
||||||
const int tid = blockIdx.y * blockDim.x + threadIdx.x;
|
const int tid = blockIdx.y * blockDim.x + threadIdx.x;
|
||||||
|
|
||||||
// Parallelize over each feature in each pixel in images of size H * W,
|
// Parallelize over each feature in each pixel in images of size H * W,
|
||||||
// for each image in the batch of size batch_size
|
// for each image in the batch of size batch_size
|
||||||
for (int pid = tid; pid < num_pixels; pid += num_threads) {
|
for (int pid = tid; pid < num_pixels; pid += num_threads) {
|
||||||
int ch = pid / (W * H);
|
int ch = pid / (H * W);
|
||||||
int j = (pid % (W * H)) / H;
|
int j = (pid % (H * W)) / W;
|
||||||
int i = (pid % (W * H)) % H;
|
int i = (pid % (H * W)) % W;
|
||||||
|
|
||||||
// Iterate through the closest K points for this pixel
|
// Iterate through the closest K points for this pixel
|
||||||
for (int k = 0; k < points_idx.size(1); ++k) {
|
for (int k = 0; k < points_idx.size(1); ++k) {
|
||||||
@ -76,16 +76,16 @@ __global__ void weightedSumCudaBackwardKernel(
|
|||||||
// Get the batch and index
|
// Get the batch and index
|
||||||
const int batch = blockIdx.x;
|
const int batch = blockIdx.x;
|
||||||
|
|
||||||
const int num_pixels = C * W * H;
|
const int num_pixels = C * H * W;
|
||||||
const int num_threads = gridDim.y * blockDim.x;
|
const int num_threads = gridDim.y * blockDim.x;
|
||||||
const int tid = blockIdx.y * blockDim.x + threadIdx.x;
|
const int tid = blockIdx.y * blockDim.x + threadIdx.x;
|
||||||
|
|
||||||
// Iterate over each pixel to compute the contribution to the
|
// Iterate over each pixel to compute the contribution to the
|
||||||
// gradient for the features and weights
|
// gradient for the features and weights
|
||||||
for (int pid = tid; pid < num_pixels; pid += num_threads) {
|
for (int pid = tid; pid < num_pixels; pid += num_threads) {
|
||||||
int ch = pid / (W * H);
|
int ch = pid / (H * W);
|
||||||
int j = (pid % (W * H)) / H;
|
int j = (pid % (H * W)) / W;
|
||||||
int i = (pid % (W * H)) % H;
|
int i = (pid % (H * W)) % W;
|
||||||
|
|
||||||
// Iterate through the closest K points for this pixel
|
// Iterate through the closest K points for this pixel
|
||||||
for (int k = 0; k < points_idx.size(1); ++k) {
|
for (int k = 0; k < points_idx.size(1); ++k) {
|
||||||
|
@ -11,13 +11,13 @@
|
|||||||
// features: FloatTensor of shape (C, P) which gives the features
|
// features: FloatTensor of shape (C, P) which gives the features
|
||||||
// of each point where C is the size of the feature and
|
// of each point where C is the size of the feature and
|
||||||
// P the number of points.
|
// P the number of points.
|
||||||
// alphas: FloatTensor of shape (N, points_per_pixel, W, W) where
|
// alphas: FloatTensor of shape (N, points_per_pixel, H, W) where
|
||||||
// points_per_pixel is the number of points in the z-buffer
|
// points_per_pixel is the number of points in the z-buffer
|
||||||
// sorted in z-order, and W is the image size.
|
// sorted in z-order, and (H, W) is the image size.
|
||||||
// points_idx: IntTensor of shape (N, points_per_pixel, W, W) giving the
|
// points_idx: IntTensor of shape (N, points_per_pixel, W, W) giving the
|
||||||
// indices of the nearest points at each pixel, sorted in z-order.
|
// indices of the nearest points at each pixel, sorted in z-order.
|
||||||
// Returns:
|
// Returns:
|
||||||
// weighted_fs: FloatTensor of shape (N, C, W, W) giving the accumulated
|
// weighted_fs: FloatTensor of shape (N, C, H, W) giving the accumulated
|
||||||
// feature in each point. Concretely, it gives:
|
// feature in each point. Concretely, it gives:
|
||||||
// weighted_fs[b,c,i,j] = sum_k alphas[b,k,i,j] *
|
// weighted_fs[b,c,i,j] = sum_k alphas[b,k,i,j] *
|
||||||
// features[c,points_idx[b,k,i,j]]
|
// features[c,points_idx[b,k,i,j]]
|
||||||
|
@ -452,7 +452,6 @@ __global__ void RasterizeMeshesBackwardCudaKernel(
|
|||||||
const bool inside = b_pp.x > 0.0f && b_pp.y > 0.0f && b_pp.z > 0.0f;
|
const bool inside = b_pp.x > 0.0f && b_pp.y > 0.0f && b_pp.z > 0.0f;
|
||||||
const float sign = inside ? -1.0f : 1.0f;
|
const float sign = inside ? -1.0f : 1.0f;
|
||||||
|
|
||||||
// TODO(T52813608) Add support for non-square images.
|
|
||||||
auto grad_dist_f = PointTriangleDistanceBackward(
|
auto grad_dist_f = PointTriangleDistanceBackward(
|
||||||
pxy, v0xy, v1xy, v2xy, sign * grad_dist_upstream);
|
pxy, v0xy, v1xy, v2xy, sign * grad_dist_upstream);
|
||||||
const float2 ddist_d_v0 = thrust::get<1>(grad_dist_f);
|
const float2 ddist_d_v0 = thrust::get<1>(grad_dist_f);
|
||||||
@ -606,7 +605,7 @@ __global__ void RasterizeMeshesCoarseCudaKernel(
|
|||||||
const float half_pix_x = NDC_x_half_range / W;
|
const float half_pix_x = NDC_x_half_range / W;
|
||||||
const float half_pix_y = NDC_y_half_range / H;
|
const float half_pix_y = NDC_y_half_range / H;
|
||||||
|
|
||||||
// This is a boolean array of shape (num_bins, num_bins, chunk_size)
|
// This is a boolean array of shape (num_bins_y, num_bins_x, chunk_size)
|
||||||
// stored in shared memory that will track whether each point in the chunk
|
// stored in shared memory that will track whether each point in the chunk
|
||||||
// falls into each bin of the image.
|
// falls into each bin of the image.
|
||||||
BitMask binmask((unsigned int*)sbuf, num_bins_y, num_bins_x, chunk_size);
|
BitMask binmask((unsigned int*)sbuf, num_bins_y, num_bins_x, chunk_size);
|
||||||
@ -755,7 +754,7 @@ at::Tensor RasterizeMeshesCoarseCuda(
|
|||||||
const int num_bins_y = 1 + (H - 1) / bin_size;
|
const int num_bins_y = 1 + (H - 1) / bin_size;
|
||||||
const int num_bins_x = 1 + (W - 1) / bin_size;
|
const int num_bins_x = 1 + (W - 1) / bin_size;
|
||||||
|
|
||||||
if (num_bins_y >= kMaxFacesPerBin || num_bins_x >= kMaxFacesPerBin) {
|
if (num_bins_y >= kMaxItemsPerBin || num_bins_x >= kMaxItemsPerBin) {
|
||||||
std::stringstream ss;
|
std::stringstream ss;
|
||||||
ss << "In Coarse Rasterizer got num_bins_y: " << num_bins_y
|
ss << "In Coarse Rasterizer got num_bins_y: " << num_bins_y
|
||||||
<< ", num_bins_x: " << num_bins_x << ", "
|
<< ", num_bins_x: " << num_bins_x << ", "
|
||||||
@ -800,7 +799,7 @@ at::Tensor RasterizeMeshesCoarseCuda(
|
|||||||
// ****************************************************************************
|
// ****************************************************************************
|
||||||
__global__ void RasterizeMeshesFineCudaKernel(
|
__global__ void RasterizeMeshesFineCudaKernel(
|
||||||
const float* face_verts, // (F, 3, 3)
|
const float* face_verts, // (F, 3, 3)
|
||||||
const int32_t* bin_faces, // (N, B, B, T)
|
const int32_t* bin_faces, // (N, BH, BW, T)
|
||||||
const float blur_radius,
|
const float blur_radius,
|
||||||
const int bin_size,
|
const int bin_size,
|
||||||
const bool perspective_correct,
|
const bool perspective_correct,
|
||||||
@ -813,12 +812,12 @@ __global__ void RasterizeMeshesFineCudaKernel(
|
|||||||
const int H,
|
const int H,
|
||||||
const int W,
|
const int W,
|
||||||
const int K,
|
const int K,
|
||||||
int64_t* face_idxs, // (N, S, S, K)
|
int64_t* face_idxs, // (N, H, W, K)
|
||||||
float* zbuf, // (N, S, S, K)
|
float* zbuf, // (N, H, W, K)
|
||||||
float* pix_dists, // (N, S, S, K)
|
float* pix_dists, // (N, H, W, K)
|
||||||
float* bary // (N, S, S, K, 3)
|
float* bary // (N, H, W, K, 3)
|
||||||
) {
|
) {
|
||||||
// This can be more than S^2 if S % bin_size != 0
|
// This can be more than H * W if H or W are not divisible by bin_size.
|
||||||
int num_pixels = N * BH * BW * bin_size * bin_size;
|
int num_pixels = N * BH * BW * bin_size * bin_size;
|
||||||
int num_threads = gridDim.x * blockDim.x;
|
int num_threads = gridDim.x * blockDim.x;
|
||||||
int tid = blockIdx.x * blockDim.x + threadIdx.x;
|
int tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
|
@ -5,41 +5,11 @@
|
|||||||
#include <list>
|
#include <list>
|
||||||
#include <queue>
|
#include <queue>
|
||||||
#include <tuple>
|
#include <tuple>
|
||||||
|
#include "rasterize_points/rasterization_utils.h"
|
||||||
#include "utils/geometry_utils.h"
|
#include "utils/geometry_utils.h"
|
||||||
#include "utils/vec2.h"
|
#include "utils/vec2.h"
|
||||||
#include "utils/vec3.h"
|
#include "utils/vec3.h"
|
||||||
|
|
||||||
// The default value of the NDC range is [-1, 1], however in the case that
|
|
||||||
// H != W, the NDC range is set such that the shorter side has range [-1, 1] and
|
|
||||||
// the longer side is scaled by the ratio of H:W. S1 is the dimension for which
|
|
||||||
// the NDC range is calculated and S2 is the other image dimension.
|
|
||||||
// e.g. to get the NDC x range S1 = W and S2 = H
|
|
||||||
float NonSquareNdcRange(int S1, int S2) {
|
|
||||||
float range = 2.0f;
|
|
||||||
if (S1 > S2) {
|
|
||||||
range = ((S1 / S2) * range);
|
|
||||||
}
|
|
||||||
return range;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Given a pixel coordinate 0 <= i < S1, convert it to a normalized device
|
|
||||||
// coordinates. We divide the NDC range into S1 evenly-sized
|
|
||||||
// pixels, and assume that each pixel falls in the *center* of its range.
|
|
||||||
// The default value of the NDC range is [-1, 1], however in the case that
|
|
||||||
// H != W, the NDC range is set such that the shorter side has range [-1, 1] and
|
|
||||||
// the longer side is scaled by the ratio of H:W. The dimension of i should be
|
|
||||||
// S1 and the other image dimension is S2 For example, to get the x and y NDC
|
|
||||||
// coordinates or a given pixel i:
|
|
||||||
// x = PixToNonSquareNdc(i, W, H)
|
|
||||||
// y = PixToNonSquareNdc(i, H, W)
|
|
||||||
float PixToNonSquareNdc(int i, int S1, int S2) {
|
|
||||||
float range = NonSquareNdcRange(S1, S2);
|
|
||||||
// NDC: offset + (i * pixel_width + half_pixel_width)
|
|
||||||
// The NDC range is [-range/2, range/2].
|
|
||||||
const float offset = (range / 2.0f);
|
|
||||||
return -offset + (range * i + offset) / S1;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get (x, y, z) values for vertex from (3, 3) tensor face.
|
// Get (x, y, z) values for vertex from (3, 3) tensor face.
|
||||||
template <typename Face>
|
template <typename Face>
|
||||||
auto ExtractVerts(const Face& face, const int vertex_index) {
|
auto ExtractVerts(const Face& face, const int vertex_index) {
|
||||||
|
@ -2,16 +2,6 @@
|
|||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
// Given a pixel coordinate 0 <= i < S, convert it to a normalized device
|
|
||||||
// coordinates in the range [-1, 1]. We divide the NDC range into S evenly-sized
|
|
||||||
// pixels, and assume that each pixel falls in the *center* of its range.
|
|
||||||
// TODO: delete this function after updating the pointcloud rasterizer to
|
|
||||||
// support non square images.
|
|
||||||
__device__ inline float PixToNdc(int i, int S) {
|
|
||||||
// NDC: x-offset + (i * pixel_width + half_pixel_width)
|
|
||||||
return -1.0 + (2 * i + 1.0) / S;
|
|
||||||
}
|
|
||||||
|
|
||||||
// The default value of the NDC range is [-1, 1], however in the case that
|
// The default value of the NDC range is [-1, 1], however in the case that
|
||||||
// H != W, the NDC range is set such that the shorter side has range [-1, 1] and
|
// H != W, the NDC range is set such that the shorter side has range [-1, 1] and
|
||||||
// the longer side is scaled by the ratio of H:W. S1 is the dimension for which
|
// the longer side is scaled by the ratio of H:W. S1 is the dimension for which
|
||||||
@ -50,7 +40,7 @@ __device__ inline float PixToNonSquareNdc(int i, int S1, int S2) {
|
|||||||
// TODO: is 8 enough? Would increasing have performance considerations?
|
// TODO: is 8 enough? Would increasing have performance considerations?
|
||||||
const int32_t kMaxPointsPerPixel = 150;
|
const int32_t kMaxPointsPerPixel = 150;
|
||||||
|
|
||||||
const int32_t kMaxFacesPerBin = 22;
|
const int32_t kMaxItemsPerBin = 22;
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
__device__ inline void BubbleSort(T* arr, int n) {
|
__device__ inline void BubbleSort(T* arr, int n) {
|
||||||
|
34
pytorch3d/csrc/rasterize_points/rasterization_utils.h
Normal file
34
pytorch3d/csrc/rasterize_points/rasterization_utils.h
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
// The default value of the NDC range is [-1, 1], however in the case that
|
||||||
|
// H != W, the NDC range is set such that the shorter side has range [-1, 1] and
|
||||||
|
// the longer side is scaled by the ratio of H:W. S1 is the dimension for which
|
||||||
|
// the NDC range is calculated and S2 is the other image dimension.
|
||||||
|
// e.g. to get the NDC x range S1 = W and S2 = H
|
||||||
|
inline float NonSquareNdcRange(int S1, int S2) {
|
||||||
|
float range = 2.0f;
|
||||||
|
if (S1 > S2) {
|
||||||
|
range = ((S1 / S2) * range);
|
||||||
|
}
|
||||||
|
return range;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Given a pixel coordinate 0 <= i < S1, convert it to a normalized device
|
||||||
|
// coordinates. We divide the NDC range into S1 evenly-sized
|
||||||
|
// pixels, and assume that each pixel falls in the *center* of its range.
|
||||||
|
// The default value of the NDC range is [-1, 1], however in the case that
|
||||||
|
// H != W, the NDC range is set such that the shorter side has range [-1, 1] and
|
||||||
|
// the longer side is scaled by the ratio of H:W. The dimension of i should be
|
||||||
|
// S1 and the other image dimension is S2 For example, to get the x and y NDC
|
||||||
|
// coordinates or a given pixel i:
|
||||||
|
// x = PixToNonSquareNdc(i, W, H)
|
||||||
|
// y = PixToNonSquareNdc(i, H, W)
|
||||||
|
inline float PixToNonSquareNdc(int i, int S1, int S2) {
|
||||||
|
float range = NonSquareNdcRange(S1, S2);
|
||||||
|
// NDC: offset + (i * pixel_width + half_pixel_width)
|
||||||
|
// The NDC range is [-range/2, range/2].
|
||||||
|
const float offset = (range / 2.0f);
|
||||||
|
return -offset + (range * i + offset) / S1;
|
||||||
|
}
|
@ -85,26 +85,28 @@ __global__ void RasterizePointsNaiveCudaKernel(
|
|||||||
const int64_t* num_points_per_cloud, // (N)
|
const int64_t* num_points_per_cloud, // (N)
|
||||||
const float* radius,
|
const float* radius,
|
||||||
const int N,
|
const int N,
|
||||||
const int S,
|
const int H,
|
||||||
|
const int W,
|
||||||
const int K,
|
const int K,
|
||||||
int32_t* point_idxs, // (N, S, S, K)
|
int32_t* point_idxs, // (N, H, W, K)
|
||||||
float* zbuf, // (N, S, S, K)
|
float* zbuf, // (N, H, W, K)
|
||||||
float* pix_dists) { // (N, S, S, K)
|
float* pix_dists) { // (N, H, W, K)
|
||||||
// Simple version: One thread per output pixel
|
// Simple version: One thread per output pixel
|
||||||
const int num_threads = gridDim.x * blockDim.x;
|
const int num_threads = gridDim.x * blockDim.x;
|
||||||
const int tid = blockDim.x * blockIdx.x + threadIdx.x;
|
const int tid = blockDim.x * blockIdx.x + threadIdx.x;
|
||||||
for (int i = tid; i < N * S * S; i += num_threads) {
|
for (int i = tid; i < N * H * W; i += num_threads) {
|
||||||
// Convert linear index to 3D index
|
// Convert linear index to 3D index
|
||||||
const int n = i / (S * S); // Batch index
|
const int n = i / (H * W); // Batch index
|
||||||
const int pix_idx = i % (S * S);
|
const int pix_idx = i % (H * W);
|
||||||
|
|
||||||
// Reverse ordering of the X and Y axis as the camera coordinates
|
// Reverse ordering of the X and Y axis as the camera coordinates
|
||||||
// assume that +Y is pointing up and +X is pointing left.
|
// assume that +Y is pointing up and +X is pointing left.
|
||||||
const int yi = S - 1 - pix_idx / S;
|
const int yi = H - 1 - pix_idx / W;
|
||||||
const int xi = S - 1 - pix_idx % S;
|
const int xi = W - 1 - pix_idx % W;
|
||||||
|
|
||||||
const float xf = PixToNdc(xi, S);
|
// screen coordinates to ndc coordiantes of pixel.
|
||||||
const float yf = PixToNdc(yi, S);
|
const float xf = PixToNonSquareNdc(xi, W, H);
|
||||||
|
const float yf = PixToNonSquareNdc(yi, H, W);
|
||||||
|
|
||||||
// For keeping track of the K closest points we want a data structure
|
// For keeping track of the K closest points we want a data structure
|
||||||
// that (1) gives O(1) access to the closest point for easy comparisons,
|
// that (1) gives O(1) access to the closest point for easy comparisons,
|
||||||
@ -132,7 +134,7 @@ __global__ void RasterizePointsNaiveCudaKernel(
|
|||||||
points, p_idx, q_size, q_max_z, q_max_idx, q, radius, xf, yf, K);
|
points, p_idx, q_size, q_max_z, q_max_idx, q, radius, xf, yf, K);
|
||||||
}
|
}
|
||||||
BubbleSort(q, q_size);
|
BubbleSort(q, q_size);
|
||||||
int idx = n * S * S * K + pix_idx * K;
|
int idx = n * H * W * K + pix_idx * K;
|
||||||
for (int k = 0; k < q_size; ++k) {
|
for (int k = 0; k < q_size; ++k) {
|
||||||
point_idxs[idx + k] = q[k].idx;
|
point_idxs[idx + k] = q[k].idx;
|
||||||
zbuf[idx + k] = q[k].z;
|
zbuf[idx + k] = q[k].z;
|
||||||
@ -145,7 +147,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> RasterizePointsNaiveCuda(
|
|||||||
const at::Tensor& points, // (P. 3)
|
const at::Tensor& points, // (P. 3)
|
||||||
const at::Tensor& cloud_to_packed_first_idx, // (N)
|
const at::Tensor& cloud_to_packed_first_idx, // (N)
|
||||||
const at::Tensor& num_points_per_cloud, // (N)
|
const at::Tensor& num_points_per_cloud, // (N)
|
||||||
const int image_size,
|
const std::tuple<int, int> image_size,
|
||||||
const at::Tensor& radius,
|
const at::Tensor& radius,
|
||||||
const int points_per_pixel) {
|
const int points_per_pixel) {
|
||||||
// Check inputs are on the same device
|
// Check inputs are on the same device
|
||||||
@ -169,7 +171,8 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> RasterizePointsNaiveCuda(
|
|||||||
"num_points_per_cloud must have same size first dimension as cloud_to_packed_first_idx");
|
"num_points_per_cloud must have same size first dimension as cloud_to_packed_first_idx");
|
||||||
|
|
||||||
const int N = num_points_per_cloud.size(0); // batch size.
|
const int N = num_points_per_cloud.size(0); // batch size.
|
||||||
const int S = image_size;
|
const int H = std::get<0>(image_size);
|
||||||
|
const int W = std::get<1>(image_size);
|
||||||
const int K = points_per_pixel;
|
const int K = points_per_pixel;
|
||||||
|
|
||||||
if (K > kMaxPointsPerPixel) {
|
if (K > kMaxPointsPerPixel) {
|
||||||
@ -180,9 +183,9 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> RasterizePointsNaiveCuda(
|
|||||||
|
|
||||||
auto int_opts = num_points_per_cloud.options().dtype(at::kInt);
|
auto int_opts = num_points_per_cloud.options().dtype(at::kInt);
|
||||||
auto float_opts = points.options().dtype(at::kFloat);
|
auto float_opts = points.options().dtype(at::kFloat);
|
||||||
at::Tensor point_idxs = at::full({N, S, S, K}, -1, int_opts);
|
at::Tensor point_idxs = at::full({N, H, W, K}, -1, int_opts);
|
||||||
at::Tensor zbuf = at::full({N, S, S, K}, -1, float_opts);
|
at::Tensor zbuf = at::full({N, H, W, K}, -1, float_opts);
|
||||||
at::Tensor pix_dists = at::full({N, S, S, K}, -1, float_opts);
|
at::Tensor pix_dists = at::full({N, H, W, K}, -1, float_opts);
|
||||||
|
|
||||||
if (point_idxs.numel() == 0) {
|
if (point_idxs.numel() == 0) {
|
||||||
AT_CUDA_CHECK(cudaGetLastError());
|
AT_CUDA_CHECK(cudaGetLastError());
|
||||||
@ -197,7 +200,8 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> RasterizePointsNaiveCuda(
|
|||||||
num_points_per_cloud.contiguous().data_ptr<int64_t>(),
|
num_points_per_cloud.contiguous().data_ptr<int64_t>(),
|
||||||
radius.contiguous().data_ptr<float>(),
|
radius.contiguous().data_ptr<float>(),
|
||||||
N,
|
N,
|
||||||
S,
|
H,
|
||||||
|
W,
|
||||||
K,
|
K,
|
||||||
point_idxs.contiguous().data_ptr<int32_t>(),
|
point_idxs.contiguous().data_ptr<int32_t>(),
|
||||||
zbuf.contiguous().data_ptr<float>(),
|
zbuf.contiguous().data_ptr<float>(),
|
||||||
@ -218,7 +222,8 @@ __global__ void RasterizePointsCoarseCudaKernel(
|
|||||||
const float* radius,
|
const float* radius,
|
||||||
const int N,
|
const int N,
|
||||||
const int P,
|
const int P,
|
||||||
const int S,
|
const int H,
|
||||||
|
const int W,
|
||||||
const int bin_size,
|
const int bin_size,
|
||||||
const int chunk_size,
|
const int chunk_size,
|
||||||
const int max_points_per_bin,
|
const int max_points_per_bin,
|
||||||
@ -226,13 +231,26 @@ __global__ void RasterizePointsCoarseCudaKernel(
|
|||||||
int* bin_points) {
|
int* bin_points) {
|
||||||
extern __shared__ char sbuf[];
|
extern __shared__ char sbuf[];
|
||||||
const int M = max_points_per_bin;
|
const int M = max_points_per_bin;
|
||||||
const int num_bins = 1 + (S - 1) / bin_size; // Integer divide round up
|
|
||||||
const float half_pix = 1.0f / S; // Size of half a pixel in NDC units
|
|
||||||
|
|
||||||
// This is a boolean array of shape (num_bins, num_bins, chunk_size)
|
// Integer divide round up
|
||||||
|
const int num_bins_x = 1 + (W - 1) / bin_size;
|
||||||
|
const int num_bins_y = 1 + (H - 1) / bin_size;
|
||||||
|
|
||||||
|
// NDC range depends on the ratio of W/H
|
||||||
|
// The shorter side from (H, W) is given an NDC range of 2.0 and
|
||||||
|
// the other side is scaled by the ratio of H:W.
|
||||||
|
const float NDC_x_half_range = NonSquareNdcRange(W, H) / 2.0f;
|
||||||
|
const float NDC_y_half_range = NonSquareNdcRange(H, W) / 2.0f;
|
||||||
|
|
||||||
|
// Size of half a pixel in NDC units is the NDC half range
|
||||||
|
// divided by the corresponding image dimension
|
||||||
|
const float half_pix_x = NDC_x_half_range / W;
|
||||||
|
const float half_pix_y = NDC_y_half_range / H;
|
||||||
|
|
||||||
|
// This is a boolean array of shape (num_bins_y, num_bins_x, chunk_size)
|
||||||
// stored in shared memory that will track whether each point in the chunk
|
// stored in shared memory that will track whether each point in the chunk
|
||||||
// falls into each bin of the image.
|
// falls into each bin of the image.
|
||||||
BitMask binmask((unsigned int*)sbuf, num_bins, num_bins, chunk_size);
|
BitMask binmask((unsigned int*)sbuf, num_bins_y, num_bins_x, chunk_size);
|
||||||
|
|
||||||
// Have each block handle a chunk of points and build a 3D bitmask in
|
// Have each block handle a chunk of points and build a 3D bitmask in
|
||||||
// shared memory to mark which points hit which bins. In this first phase,
|
// shared memory to mark which points hit which bins. In this first phase,
|
||||||
@ -279,22 +297,24 @@ __global__ void RasterizePointsCoarseCudaKernel(
|
|||||||
// For example we could compute the exact bin where the point falls,
|
// For example we could compute the exact bin where the point falls,
|
||||||
// then check neighboring bins. This way we wouldn't have to check
|
// then check neighboring bins. This way we wouldn't have to check
|
||||||
// all bins (however then we might have more warp divergence?)
|
// all bins (however then we might have more warp divergence?)
|
||||||
for (int by = 0; by < num_bins; ++by) {
|
for (int by = 0; by < num_bins_y; ++by) {
|
||||||
// Get y extent for the bin. PixToNdc gives us the location of
|
// Get y extent for the bin. PixToNonSquareNdc gives us the location of
|
||||||
// the center of each pixel, so we need to add/subtract a half
|
// the center of each pixel, so we need to add/subtract a half
|
||||||
// pixel to get the true extent of the bin.
|
// pixel to get the true extent of the bin.
|
||||||
const float by0 = PixToNdc(by * bin_size, S) - half_pix;
|
const float by0 = PixToNonSquareNdc(by * bin_size, H, W) - half_pix_y;
|
||||||
const float by1 = PixToNdc((by + 1) * bin_size - 1, S) + half_pix;
|
const float by1 =
|
||||||
|
PixToNonSquareNdc((by + 1) * bin_size - 1, H, W) + half_pix_y;
|
||||||
const bool y_overlap = (py0 <= by1) && (by0 <= py1);
|
const bool y_overlap = (py0 <= by1) && (by0 <= py1);
|
||||||
|
|
||||||
if (!y_overlap) {
|
if (!y_overlap) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
for (int bx = 0; bx < num_bins; ++bx) {
|
for (int bx = 0; bx < num_bins_x; ++bx) {
|
||||||
// Get x extent for the bin; again we need to adjust the
|
// Get x extent for the bin; again we need to adjust the
|
||||||
// output of PixToNdc by half a pixel.
|
// output of PixToNonSquareNdc by half a pixel.
|
||||||
const float bx0 = PixToNdc(bx * bin_size, S) - half_pix;
|
const float bx0 = PixToNonSquareNdc(bx * bin_size, W, H) - half_pix_x;
|
||||||
const float bx1 = PixToNdc((bx + 1) * bin_size - 1, S) + half_pix;
|
const float bx1 =
|
||||||
|
PixToNonSquareNdc((bx + 1) * bin_size - 1, W, H) + half_pix_x;
|
||||||
const bool x_overlap = (px0 <= bx1) && (bx0 <= px1);
|
const bool x_overlap = (px0 <= bx1) && (bx0 <= px1);
|
||||||
|
|
||||||
if (x_overlap) {
|
if (x_overlap) {
|
||||||
@ -307,12 +327,13 @@ __global__ void RasterizePointsCoarseCudaKernel(
|
|||||||
// Now we have processed every point in the current chunk. We need to
|
// Now we have processed every point in the current chunk. We need to
|
||||||
// count the number of points in each bin so we can write the indices
|
// count the number of points in each bin so we can write the indices
|
||||||
// out to global memory. We have each thread handle a different bin.
|
// out to global memory. We have each thread handle a different bin.
|
||||||
for (int byx = threadIdx.x; byx < num_bins * num_bins; byx += blockDim.x) {
|
for (int byx = threadIdx.x; byx < num_bins_y * num_bins_x;
|
||||||
const int by = byx / num_bins;
|
byx += blockDim.x) {
|
||||||
const int bx = byx % num_bins;
|
const int by = byx / num_bins_x;
|
||||||
|
const int bx = byx % num_bins_x;
|
||||||
const int count = binmask.count(by, bx);
|
const int count = binmask.count(by, bx);
|
||||||
const int points_per_bin_idx =
|
const int points_per_bin_idx =
|
||||||
batch_idx * num_bins * num_bins + by * num_bins + bx;
|
batch_idx * num_bins_y * num_bins_x + by * num_bins_x + bx;
|
||||||
|
|
||||||
// This atomically increments the (global) number of points found
|
// This atomically increments the (global) number of points found
|
||||||
// in the current bin, and gets the previous value of the counter;
|
// in the current bin, and gets the previous value of the counter;
|
||||||
@ -322,8 +343,8 @@ __global__ void RasterizePointsCoarseCudaKernel(
|
|||||||
|
|
||||||
// Now loop over the binmask and write the active bits for this bin
|
// Now loop over the binmask and write the active bits for this bin
|
||||||
// out to bin_points.
|
// out to bin_points.
|
||||||
int next_idx = batch_idx * num_bins * num_bins * M + by * num_bins * M +
|
int next_idx = batch_idx * num_bins_y * num_bins_x * M +
|
||||||
bx * M + start;
|
by * num_bins_x * M + bx * M + start;
|
||||||
for (int p = 0; p < chunk_size; ++p) {
|
for (int p = 0; p < chunk_size; ++p) {
|
||||||
if (binmask.get(by, bx, p)) {
|
if (binmask.get(by, bx, p)) {
|
||||||
// TODO: Throw an error if next_idx >= M -- this means that
|
// TODO: Throw an error if next_idx >= M -- this means that
|
||||||
@ -342,7 +363,7 @@ at::Tensor RasterizePointsCoarseCuda(
|
|||||||
const at::Tensor& points, // (P, 3)
|
const at::Tensor& points, // (P, 3)
|
||||||
const at::Tensor& cloud_to_packed_first_idx, // (N)
|
const at::Tensor& cloud_to_packed_first_idx, // (N)
|
||||||
const at::Tensor& num_points_per_cloud, // (N)
|
const at::Tensor& num_points_per_cloud, // (N)
|
||||||
const int image_size,
|
const std::tuple<int, int> image_size,
|
||||||
const at::Tensor& radius,
|
const at::Tensor& radius,
|
||||||
const int bin_size,
|
const int bin_size,
|
||||||
const int max_points_per_bin) {
|
const int max_points_per_bin) {
|
||||||
@ -363,20 +384,28 @@ at::Tensor RasterizePointsCoarseCuda(
|
|||||||
at::cuda::CUDAGuard device_guard(points.device());
|
at::cuda::CUDAGuard device_guard(points.device());
|
||||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||||
|
|
||||||
|
const int H = std::get<0>(image_size);
|
||||||
|
const int W = std::get<1>(image_size);
|
||||||
|
|
||||||
const int P = points.size(0);
|
const int P = points.size(0);
|
||||||
const int N = num_points_per_cloud.size(0);
|
const int N = num_points_per_cloud.size(0);
|
||||||
const int num_bins = 1 + (image_size - 1) / bin_size; // divide round up
|
|
||||||
const int M = max_points_per_bin;
|
const int M = max_points_per_bin;
|
||||||
|
|
||||||
if (num_bins >= 22) {
|
// Integer divide round up.
|
||||||
|
const int num_bins_y = 1 + (H - 1) / bin_size;
|
||||||
|
const int num_bins_x = 1 + (W - 1) / bin_size;
|
||||||
|
|
||||||
|
if (num_bins_y >= kMaxItemsPerBin || num_bins_x >= kMaxItemsPerBin) {
|
||||||
// Make sure we do not use too much shared memory.
|
// Make sure we do not use too much shared memory.
|
||||||
std::stringstream ss;
|
std::stringstream ss;
|
||||||
ss << "Got " << num_bins << "; that's too many!";
|
ss << "In Coarse Rasterizer got num_bins_y: " << num_bins_y
|
||||||
|
<< ", num_bins_x: " << num_bins_x << ", "
|
||||||
|
<< "; that's too many!";
|
||||||
AT_ERROR(ss.str());
|
AT_ERROR(ss.str());
|
||||||
}
|
}
|
||||||
auto opts = num_points_per_cloud.options().dtype(at::kInt);
|
auto opts = num_points_per_cloud.options().dtype(at::kInt);
|
||||||
at::Tensor points_per_bin = at::zeros({N, num_bins, num_bins}, opts);
|
at::Tensor points_per_bin = at::zeros({N, num_bins_y, num_bins_x}, opts);
|
||||||
at::Tensor bin_points = at::full({N, num_bins, num_bins, M}, -1, opts);
|
at::Tensor bin_points = at::full({N, num_bins_y, num_bins_x, M}, -1, opts);
|
||||||
|
|
||||||
if (bin_points.numel() == 0) {
|
if (bin_points.numel() == 0) {
|
||||||
AT_CUDA_CHECK(cudaGetLastError());
|
AT_CUDA_CHECK(cudaGetLastError());
|
||||||
@ -384,7 +413,7 @@ at::Tensor RasterizePointsCoarseCuda(
|
|||||||
}
|
}
|
||||||
|
|
||||||
const int chunk_size = 512;
|
const int chunk_size = 512;
|
||||||
const size_t shared_size = num_bins * num_bins * chunk_size / 8;
|
const size_t shared_size = num_bins_y * num_bins_x * chunk_size / 8;
|
||||||
const size_t blocks = 64;
|
const size_t blocks = 64;
|
||||||
const size_t threads = 512;
|
const size_t threads = 512;
|
||||||
|
|
||||||
@ -395,7 +424,8 @@ at::Tensor RasterizePointsCoarseCuda(
|
|||||||
radius.contiguous().data_ptr<float>(),
|
radius.contiguous().data_ptr<float>(),
|
||||||
N,
|
N,
|
||||||
P,
|
P,
|
||||||
image_size,
|
H,
|
||||||
|
W,
|
||||||
bin_size,
|
bin_size,
|
||||||
chunk_size,
|
chunk_size,
|
||||||
M,
|
M,
|
||||||
@ -412,19 +442,21 @@ at::Tensor RasterizePointsCoarseCuda(
|
|||||||
|
|
||||||
__global__ void RasterizePointsFineCudaKernel(
|
__global__ void RasterizePointsFineCudaKernel(
|
||||||
const float* points, // (P, 3)
|
const float* points, // (P, 3)
|
||||||
const int32_t* bin_points, // (N, B, B, T)
|
const int32_t* bin_points, // (N, BH, BW, T)
|
||||||
const float* radius,
|
const float* radius,
|
||||||
const int bin_size,
|
const int bin_size,
|
||||||
const int N,
|
const int N,
|
||||||
const int B, // num_bins
|
const int BH, // num_bins y
|
||||||
|
const int BW, // num_bins x
|
||||||
const int M,
|
const int M,
|
||||||
const int S,
|
const int H,
|
||||||
|
const int W,
|
||||||
const int K,
|
const int K,
|
||||||
int32_t* point_idxs, // (N, S, S, K)
|
int32_t* point_idxs, // (N, H, W, K)
|
||||||
float* zbuf, // (N, S, S, K)
|
float* zbuf, // (N, H, W, K)
|
||||||
float* pix_dists) { // (N, S, S, K)
|
float* pix_dists) { // (N, H, W, K)
|
||||||
// This can be more than S^2 if S is not dividable by bin_size.
|
// This can be more than H * W if H or W are not divisible by bin_size.
|
||||||
const int num_pixels = N * B * B * bin_size * bin_size;
|
const int num_pixels = N * BH * BW * bin_size * bin_size;
|
||||||
const int num_threads = gridDim.x * blockDim.x;
|
const int num_threads = gridDim.x * blockDim.x;
|
||||||
const int tid = blockIdx.x * blockDim.x + threadIdx.x;
|
const int tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
|
|
||||||
@ -434,21 +466,21 @@ __global__ void RasterizePointsFineCudaKernel(
|
|||||||
// into the same bin; this should give them coalesced memory reads when
|
// into the same bin; this should give them coalesced memory reads when
|
||||||
// they read from points and bin_points.
|
// they read from points and bin_points.
|
||||||
int i = pid;
|
int i = pid;
|
||||||
const int n = i / (B * B * bin_size * bin_size);
|
const int n = i / (BH * BW * bin_size * bin_size);
|
||||||
i %= B * B * bin_size * bin_size;
|
i %= BH * BW * bin_size * bin_size;
|
||||||
const int by = i / (B * bin_size * bin_size);
|
const int by = i / (BW * bin_size * bin_size);
|
||||||
i %= B * bin_size * bin_size;
|
i %= BW * bin_size * bin_size;
|
||||||
const int bx = i / (bin_size * bin_size);
|
const int bx = i / (bin_size * bin_size);
|
||||||
i %= bin_size * bin_size;
|
i %= bin_size * bin_size;
|
||||||
|
|
||||||
const int yi = i / bin_size + by * bin_size;
|
const int yi = i / bin_size + by * bin_size;
|
||||||
const int xi = i % bin_size + bx * bin_size;
|
const int xi = i % bin_size + bx * bin_size;
|
||||||
|
|
||||||
if (yi >= S || xi >= S)
|
if (yi >= H || xi >= W)
|
||||||
continue;
|
continue;
|
||||||
|
|
||||||
const float xf = PixToNdc(xi, S);
|
const float xf = PixToNonSquareNdc(xi, W, H);
|
||||||
const float yf = PixToNdc(yi, S);
|
const float yf = PixToNonSquareNdc(yi, H, W);
|
||||||
|
|
||||||
// This part looks like the naive rasterization kernel, except we use
|
// This part looks like the naive rasterization kernel, except we use
|
||||||
// bin_points to only look at a subset of points already known to fall
|
// bin_points to only look at a subset of points already known to fall
|
||||||
@ -459,7 +491,7 @@ __global__ void RasterizePointsFineCudaKernel(
|
|||||||
float q_max_z = -1000;
|
float q_max_z = -1000;
|
||||||
int q_max_idx = -1;
|
int q_max_idx = -1;
|
||||||
for (int m = 0; m < M; ++m) {
|
for (int m = 0; m < M; ++m) {
|
||||||
const int p = bin_points[n * B * B * M + by * B * M + bx * M + m];
|
const int p = bin_points[n * BH * BW * M + by * BW * M + bx * M + m];
|
||||||
if (p < 0) {
|
if (p < 0) {
|
||||||
// bin_points uses -1 as a sentinal value
|
// bin_points uses -1 as a sentinal value
|
||||||
continue;
|
continue;
|
||||||
@ -473,10 +505,10 @@ __global__ void RasterizePointsFineCudaKernel(
|
|||||||
|
|
||||||
// Reverse ordering of the X and Y axis as the camera coordinates
|
// Reverse ordering of the X and Y axis as the camera coordinates
|
||||||
// assume that +Y is pointing up and +X is pointing left.
|
// assume that +Y is pointing up and +X is pointing left.
|
||||||
const int yidx = S - 1 - yi;
|
const int yidx = H - 1 - yi;
|
||||||
const int xidx = S - 1 - xi;
|
const int xidx = W - 1 - xi;
|
||||||
|
|
||||||
const int pix_idx = n * S * S * K + yidx * S * K + xidx * K;
|
const int pix_idx = n * H * W * K + yidx * W * K + xidx * K;
|
||||||
for (int k = 0; k < q_size; ++k) {
|
for (int k = 0; k < q_size; ++k) {
|
||||||
point_idxs[pix_idx + k] = q[k].idx;
|
point_idxs[pix_idx + k] = q[k].idx;
|
||||||
zbuf[pix_idx + k] = q[k].z;
|
zbuf[pix_idx + k] = q[k].z;
|
||||||
@ -488,7 +520,7 @@ __global__ void RasterizePointsFineCudaKernel(
|
|||||||
std::tuple<at::Tensor, at::Tensor, at::Tensor> RasterizePointsFineCuda(
|
std::tuple<at::Tensor, at::Tensor, at::Tensor> RasterizePointsFineCuda(
|
||||||
const at::Tensor& points, // (P, 3)
|
const at::Tensor& points, // (P, 3)
|
||||||
const at::Tensor& bin_points,
|
const at::Tensor& bin_points,
|
||||||
const int image_size,
|
const std::tuple<int, int> image_size,
|
||||||
const at::Tensor& radius,
|
const at::Tensor& radius,
|
||||||
const int bin_size,
|
const int bin_size,
|
||||||
const int points_per_pixel) {
|
const int points_per_pixel) {
|
||||||
@ -503,18 +535,22 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> RasterizePointsFineCuda(
|
|||||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||||
|
|
||||||
const int N = bin_points.size(0);
|
const int N = bin_points.size(0);
|
||||||
const int B = bin_points.size(1); // num_bins
|
const int BH = bin_points.size(1);
|
||||||
|
const int BW = bin_points.size(2);
|
||||||
const int M = bin_points.size(3);
|
const int M = bin_points.size(3);
|
||||||
const int S = image_size;
|
|
||||||
const int K = points_per_pixel;
|
const int K = points_per_pixel;
|
||||||
|
|
||||||
|
const int H = std::get<0>(image_size);
|
||||||
|
const int W = std::get<1>(image_size);
|
||||||
|
|
||||||
if (K > kMaxPointsPerPixel) {
|
if (K > kMaxPointsPerPixel) {
|
||||||
AT_ERROR("Must have num_closest <= 150");
|
AT_ERROR("Must have num_closest <= 150");
|
||||||
}
|
}
|
||||||
auto int_opts = bin_points.options().dtype(at::kInt);
|
auto int_opts = bin_points.options().dtype(at::kInt);
|
||||||
auto float_opts = points.options().dtype(at::kFloat);
|
auto float_opts = points.options().dtype(at::kFloat);
|
||||||
at::Tensor point_idxs = at::full({N, S, S, K}, -1, int_opts);
|
at::Tensor point_idxs = at::full({N, H, W, K}, -1, int_opts);
|
||||||
at::Tensor zbuf = at::full({N, S, S, K}, -1, float_opts);
|
at::Tensor zbuf = at::full({N, H, W, K}, -1, float_opts);
|
||||||
at::Tensor pix_dists = at::full({N, S, S, K}, -1, float_opts);
|
at::Tensor pix_dists = at::full({N, H, W, K}, -1, float_opts);
|
||||||
|
|
||||||
if (point_idxs.numel() == 0) {
|
if (point_idxs.numel() == 0) {
|
||||||
AT_CUDA_CHECK(cudaGetLastError());
|
AT_CUDA_CHECK(cudaGetLastError());
|
||||||
@ -529,9 +565,11 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> RasterizePointsFineCuda(
|
|||||||
radius.contiguous().data_ptr<float>(),
|
radius.contiguous().data_ptr<float>(),
|
||||||
bin_size,
|
bin_size,
|
||||||
N,
|
N,
|
||||||
B,
|
BH,
|
||||||
|
BW,
|
||||||
M,
|
M,
|
||||||
S,
|
H,
|
||||||
|
W,
|
||||||
K,
|
K,
|
||||||
point_idxs.contiguous().data_ptr<int32_t>(),
|
point_idxs.contiguous().data_ptr<int32_t>(),
|
||||||
zbuf.contiguous().data_ptr<float>(),
|
zbuf.contiguous().data_ptr<float>(),
|
||||||
@ -571,8 +609,8 @@ __global__ void RasterizePointsBackwardCudaKernel(
|
|||||||
const int yidx = H - 1 - yi;
|
const int yidx = H - 1 - yi;
|
||||||
const int xidx = W - 1 - xi;
|
const int xidx = W - 1 - xi;
|
||||||
|
|
||||||
const float xf = PixToNdc(xidx, W);
|
const float xf = PixToNonSquareNdc(xidx, W, H);
|
||||||
const float yf = PixToNdc(yidx, H);
|
const float yf = PixToNonSquareNdc(yidx, H, W);
|
||||||
|
|
||||||
const int p = idxs[i];
|
const int p = idxs[i];
|
||||||
if (p < 0)
|
if (p < 0)
|
||||||
|
@ -14,7 +14,7 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsNaiveCpu(
|
|||||||
const torch::Tensor& points,
|
const torch::Tensor& points,
|
||||||
const torch::Tensor& cloud_to_packed_first_idx,
|
const torch::Tensor& cloud_to_packed_first_idx,
|
||||||
const torch::Tensor& num_points_per_cloud,
|
const torch::Tensor& num_points_per_cloud,
|
||||||
const int image_size,
|
const std::tuple<int, int> image_size,
|
||||||
const torch::Tensor& radius,
|
const torch::Tensor& radius,
|
||||||
const int points_per_pixel);
|
const int points_per_pixel);
|
||||||
|
|
||||||
@ -24,7 +24,7 @@ RasterizePointsNaiveCuda(
|
|||||||
const torch::Tensor& points,
|
const torch::Tensor& points,
|
||||||
const torch::Tensor& cloud_to_packed_first_idx,
|
const torch::Tensor& cloud_to_packed_first_idx,
|
||||||
const torch::Tensor& num_points_per_cloud,
|
const torch::Tensor& num_points_per_cloud,
|
||||||
const int image_size,
|
const std::tuple<int, int> image_size,
|
||||||
const torch::Tensor& radius,
|
const torch::Tensor& radius,
|
||||||
const int points_per_pixel);
|
const int points_per_pixel);
|
||||||
#endif
|
#endif
|
||||||
@ -43,7 +43,8 @@ RasterizePointsNaiveCuda(
|
|||||||
// for each pointcloud in the batch.
|
// for each pointcloud in the batch.
|
||||||
// radius: FloatTensor of shape (P) giving the radius (in NDC units) of
|
// radius: FloatTensor of shape (P) giving the radius (in NDC units) of
|
||||||
// each point in points.
|
// each point in points.
|
||||||
// image_size: (S) Size of the image to return (in pixels)
|
// image_size: Tuple (H, W) giving the size in pixels of the output
|
||||||
|
// image to be rasterized.
|
||||||
// points_per_pixel: (K) The number closest of points to return for each pixel
|
// points_per_pixel: (K) The number closest of points to return for each pixel
|
||||||
//
|
//
|
||||||
// Returns:
|
// Returns:
|
||||||
@ -62,7 +63,7 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsNaive(
|
|||||||
const torch::Tensor& points,
|
const torch::Tensor& points,
|
||||||
const torch::Tensor& cloud_to_packed_first_idx,
|
const torch::Tensor& cloud_to_packed_first_idx,
|
||||||
const torch::Tensor& num_points_per_cloud,
|
const torch::Tensor& num_points_per_cloud,
|
||||||
const int image_size,
|
const std::tuple<int, int> image_size,
|
||||||
const torch::Tensor& radius,
|
const torch::Tensor& radius,
|
||||||
const int points_per_pixel) {
|
const int points_per_pixel) {
|
||||||
if (points.is_cuda() && cloud_to_packed_first_idx.is_cuda() &&
|
if (points.is_cuda() && cloud_to_packed_first_idx.is_cuda() &&
|
||||||
@ -101,7 +102,7 @@ torch::Tensor RasterizePointsCoarseCpu(
|
|||||||
const torch::Tensor& points,
|
const torch::Tensor& points,
|
||||||
const torch::Tensor& cloud_to_packed_first_idx,
|
const torch::Tensor& cloud_to_packed_first_idx,
|
||||||
const torch::Tensor& num_points_per_cloud,
|
const torch::Tensor& num_points_per_cloud,
|
||||||
const int image_size,
|
const std::tuple<int, int> image_size,
|
||||||
const torch::Tensor& radius,
|
const torch::Tensor& radius,
|
||||||
const int bin_size,
|
const int bin_size,
|
||||||
const int max_points_per_bin);
|
const int max_points_per_bin);
|
||||||
@ -111,7 +112,7 @@ torch::Tensor RasterizePointsCoarseCuda(
|
|||||||
const torch::Tensor& points,
|
const torch::Tensor& points,
|
||||||
const torch::Tensor& cloud_to_packed_first_idx,
|
const torch::Tensor& cloud_to_packed_first_idx,
|
||||||
const torch::Tensor& num_points_per_cloud,
|
const torch::Tensor& num_points_per_cloud,
|
||||||
const int image_size,
|
const std::tuple<int, int> image_size,
|
||||||
const torch::Tensor& radius,
|
const torch::Tensor& radius,
|
||||||
const int bin_size,
|
const int bin_size,
|
||||||
const int max_points_per_bin);
|
const int max_points_per_bin);
|
||||||
@ -128,7 +129,8 @@ torch::Tensor RasterizePointsCoarseCuda(
|
|||||||
// for each pointcloud in the batch.
|
// for each pointcloud in the batch.
|
||||||
// radius: FloatTensor of shape (P) giving the radius (in NDC units) of
|
// radius: FloatTensor of shape (P) giving the radius (in NDC units) of
|
||||||
// each point in points.
|
// each point in points.
|
||||||
// image_size: Size of the image to generate (in pixels)
|
// image_size: Tuple (H, W) giving the size in pixels of the output
|
||||||
|
// image to be rasterized.
|
||||||
// bin_size: Size of each bin within the image (in pixels)
|
// bin_size: Size of each bin within the image (in pixels)
|
||||||
//
|
//
|
||||||
// Returns:
|
// Returns:
|
||||||
@ -140,7 +142,7 @@ torch::Tensor RasterizePointsCoarse(
|
|||||||
const torch::Tensor& points,
|
const torch::Tensor& points,
|
||||||
const torch::Tensor& cloud_to_packed_first_idx,
|
const torch::Tensor& cloud_to_packed_first_idx,
|
||||||
const torch::Tensor& num_points_per_cloud,
|
const torch::Tensor& num_points_per_cloud,
|
||||||
const int image_size,
|
const std::tuple<int, int> image_size,
|
||||||
const torch::Tensor& radius,
|
const torch::Tensor& radius,
|
||||||
const int bin_size,
|
const int bin_size,
|
||||||
const int max_points_per_bin) {
|
const int max_points_per_bin) {
|
||||||
@ -182,7 +184,7 @@ torch::Tensor RasterizePointsCoarse(
|
|||||||
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsFineCuda(
|
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsFineCuda(
|
||||||
const torch::Tensor& points,
|
const torch::Tensor& points,
|
||||||
const torch::Tensor& bin_points,
|
const torch::Tensor& bin_points,
|
||||||
const int image_size,
|
const std::tuple<int, int> image_size,
|
||||||
const torch::Tensor& radius,
|
const torch::Tensor& radius,
|
||||||
const int bin_size,
|
const int bin_size,
|
||||||
const int points_per_pixel);
|
const int points_per_pixel);
|
||||||
@ -194,7 +196,8 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsFineCuda(
|
|||||||
// are expected to be in NDC coordinates in the range [-1, 1].
|
// are expected to be in NDC coordinates in the range [-1, 1].
|
||||||
// bin_points: int32 Tensor of shape (N, B, B, M) giving the indices of points
|
// bin_points: int32 Tensor of shape (N, B, B, M) giving the indices of points
|
||||||
// that fall into each bin (output from coarse rasterization)
|
// that fall into each bin (output from coarse rasterization)
|
||||||
// image_size: Size of image to generate (in pixels)
|
// image_size: Tuple (H, W) giving the size in pixels of the output
|
||||||
|
// image to be rasterized.
|
||||||
// radius: FloatTensor of shape (P) giving the radius (in NDC units) of
|
// radius: FloatTensor of shape (P) giving the radius (in NDC units) of
|
||||||
// each point in points.
|
// each point in points.
|
||||||
// bin_size: Size of each bin (in pixels)
|
// bin_size: Size of each bin (in pixels)
|
||||||
@ -214,7 +217,7 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsFineCuda(
|
|||||||
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsFine(
|
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsFine(
|
||||||
const torch::Tensor& points,
|
const torch::Tensor& points,
|
||||||
const torch::Tensor& bin_points,
|
const torch::Tensor& bin_points,
|
||||||
const int image_size,
|
const std::tuple<int, int> image_size,
|
||||||
const torch::Tensor& radius,
|
const torch::Tensor& radius,
|
||||||
const int bin_size,
|
const int bin_size,
|
||||||
const int points_per_pixel) {
|
const int points_per_pixel) {
|
||||||
@ -303,7 +306,8 @@ torch::Tensor RasterizePointsBackward(
|
|||||||
// for each pointcloud in the batch.
|
// for each pointcloud in the batch.
|
||||||
// radius: FloatTensor of shape (P) giving the radius (in NDC units) of
|
// radius: FloatTensor of shape (P) giving the radius (in NDC units) of
|
||||||
// each point in points.
|
// each point in points.
|
||||||
// image_size: (S) Size of the image to return (in pixels)
|
// image_size: Tuple (H, W) giving the size in pixels of the output
|
||||||
|
// image to be rasterized.
|
||||||
// points_per_pixel: (K) The number of points to return for each pixel
|
// points_per_pixel: (K) The number of points to return for each pixel
|
||||||
// bin_size: Bin size (in pixels) for coarse-to-fine rasterization. Setting
|
// bin_size: Bin size (in pixels) for coarse-to-fine rasterization. Setting
|
||||||
// bin_size=0 uses naive rasterization instead.
|
// bin_size=0 uses naive rasterization instead.
|
||||||
@ -325,7 +329,7 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePoints(
|
|||||||
const torch::Tensor& points,
|
const torch::Tensor& points,
|
||||||
const torch::Tensor& cloud_to_packed_first_idx,
|
const torch::Tensor& cloud_to_packed_first_idx,
|
||||||
const torch::Tensor& num_points_per_cloud,
|
const torch::Tensor& num_points_per_cloud,
|
||||||
const int image_size,
|
const std::tuple<int, int> image_size,
|
||||||
const torch::Tensor& radius,
|
const torch::Tensor& radius,
|
||||||
const int points_per_pixel,
|
const int points_per_pixel,
|
||||||
const int bin_size,
|
const int bin_size,
|
||||||
|
@ -3,33 +3,27 @@
|
|||||||
#include <torch/extension.h>
|
#include <torch/extension.h>
|
||||||
#include <queue>
|
#include <queue>
|
||||||
#include <tuple>
|
#include <tuple>
|
||||||
|
#include "rasterization_utils.h"
|
||||||
// Given a pixel coordinate 0 <= i < S, convert it to a normalized device
|
|
||||||
// coordinate in the range [-1, 1]. The NDC range is divided into S evenly-sized
|
|
||||||
// pixels, and assume that each pixel falls in the *center* of its range.
|
|
||||||
static float PixToNdc(const int i, const int S) {
|
|
||||||
// NDC x-offset + (i * pixel_width + half_pixel_width)
|
|
||||||
return -1 + (2 * i + 1.0f) / S;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsNaiveCpu(
|
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsNaiveCpu(
|
||||||
const torch::Tensor& points, // (P, 3)
|
const torch::Tensor& points, // (P, 3)
|
||||||
const torch::Tensor& cloud_to_packed_first_idx, // (N)
|
const torch::Tensor& cloud_to_packed_first_idx, // (N)
|
||||||
const torch::Tensor& num_points_per_cloud, // (N)
|
const torch::Tensor& num_points_per_cloud, // (N)
|
||||||
const int image_size,
|
const std::tuple<int, int> image_size,
|
||||||
const torch::Tensor& radius,
|
const torch::Tensor& radius,
|
||||||
const int points_per_pixel) {
|
const int points_per_pixel) {
|
||||||
const int32_t N = cloud_to_packed_first_idx.size(0); // batch_size.
|
const int32_t N = cloud_to_packed_first_idx.size(0); // batch_size.
|
||||||
|
|
||||||
const int S = image_size;
|
const int H = std::get<0>(image_size);
|
||||||
|
const int W = std::get<1>(image_size);
|
||||||
const int K = points_per_pixel;
|
const int K = points_per_pixel;
|
||||||
|
|
||||||
// Initialize output tensors.
|
// Initialize output tensors.
|
||||||
auto int_opts = num_points_per_cloud.options().dtype(torch::kInt32);
|
auto int_opts = num_points_per_cloud.options().dtype(torch::kInt32);
|
||||||
auto float_opts = points.options().dtype(torch::kFloat32);
|
auto float_opts = points.options().dtype(torch::kFloat32);
|
||||||
torch::Tensor point_idxs = torch::full({N, S, S, K}, -1, int_opts);
|
torch::Tensor point_idxs = torch::full({N, H, W, K}, -1, int_opts);
|
||||||
torch::Tensor zbuf = torch::full({N, S, S, K}, -1, float_opts);
|
torch::Tensor zbuf = torch::full({N, H, W, K}, -1, float_opts);
|
||||||
torch::Tensor pix_dists = torch::full({N, S, S, K}, -1, float_opts);
|
torch::Tensor pix_dists = torch::full({N, H, W, K}, -1, float_opts);
|
||||||
|
|
||||||
auto points_a = points.accessor<float, 2>();
|
auto points_a = points.accessor<float, 2>();
|
||||||
auto point_idxs_a = point_idxs.accessor<int32_t, 4>();
|
auto point_idxs_a = point_idxs.accessor<int32_t, 4>();
|
||||||
@ -46,16 +40,16 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsNaiveCpu(
|
|||||||
const int point_stop_idx =
|
const int point_stop_idx =
|
||||||
(point_start_idx + num_points_per_cloud[n].item().to<int32_t>());
|
(point_start_idx + num_points_per_cloud[n].item().to<int32_t>());
|
||||||
|
|
||||||
for (int yi = 0; yi < S; ++yi) {
|
for (int yi = 0; yi < H; ++yi) {
|
||||||
// Reverse the order of yi so that +Y is pointing upwards in the image.
|
// Reverse the order of yi so that +Y is pointing upwards in the image.
|
||||||
const int yidx = S - 1 - yi;
|
const int yidx = H - 1 - yi;
|
||||||
const float yf = PixToNdc(yidx, S);
|
const float yf = PixToNonSquareNdc(yidx, H, W);
|
||||||
|
|
||||||
for (int xi = 0; xi < S; ++xi) {
|
for (int xi = 0; xi < W; ++xi) {
|
||||||
// Reverse the order of xi so that +X is pointing to the left in the
|
// Reverse the order of xi so that +X is pointing to the left in the
|
||||||
// image.
|
// image.
|
||||||
const int xidx = S - 1 - xi;
|
const int xidx = W - 1 - xi;
|
||||||
const float xf = PixToNdc(xidx, S);
|
const float xf = PixToNonSquareNdc(xidx, W, H);
|
||||||
|
|
||||||
// Use a priority queue to hold (z, idx, r)
|
// Use a priority queue to hold (z, idx, r)
|
||||||
std::priority_queue<std::tuple<float, int, float>> q;
|
std::priority_queue<std::tuple<float, int, float>> q;
|
||||||
@ -99,25 +93,36 @@ torch::Tensor RasterizePointsCoarseCpu(
|
|||||||
const torch::Tensor& points, // (P, 3)
|
const torch::Tensor& points, // (P, 3)
|
||||||
const torch::Tensor& cloud_to_packed_first_idx, // (N)
|
const torch::Tensor& cloud_to_packed_first_idx, // (N)
|
||||||
const torch::Tensor& num_points_per_cloud, // (N)
|
const torch::Tensor& num_points_per_cloud, // (N)
|
||||||
const int image_size,
|
const std::tuple<int, int> image_size,
|
||||||
const torch::Tensor& radius,
|
const torch::Tensor& radius,
|
||||||
const int bin_size,
|
const int bin_size,
|
||||||
const int max_points_per_bin) {
|
const int max_points_per_bin) {
|
||||||
const int32_t N = cloud_to_packed_first_idx.size(0); // batch_size.
|
const int32_t N = cloud_to_packed_first_idx.size(0); // batch_size.
|
||||||
|
|
||||||
const int B = 1 + (image_size - 1) / bin_size; // Integer division round up
|
|
||||||
const int M = max_points_per_bin;
|
const int M = max_points_per_bin;
|
||||||
|
|
||||||
|
const float H = std::get<0>(image_size);
|
||||||
|
const float W = std::get<1>(image_size);
|
||||||
|
|
||||||
|
// Integer division round up.
|
||||||
|
const int BH = 1 + (H - 1) / bin_size;
|
||||||
|
const int BW = 1 + (W - 1) / bin_size;
|
||||||
|
|
||||||
auto opts = num_points_per_cloud.options().dtype(torch::kInt32);
|
auto opts = num_points_per_cloud.options().dtype(torch::kInt32);
|
||||||
torch::Tensor points_per_bin = torch::zeros({N, B, B}, opts);
|
torch::Tensor points_per_bin = torch::zeros({N, BH, BW}, opts);
|
||||||
torch::Tensor bin_points = torch::full({N, B, B, M}, -1, opts);
|
torch::Tensor bin_points = torch::full({N, BH, BW, M}, -1, opts);
|
||||||
|
|
||||||
auto points_a = points.accessor<float, 2>();
|
auto points_a = points.accessor<float, 2>();
|
||||||
auto points_per_bin_a = points_per_bin.accessor<int32_t, 3>();
|
auto points_per_bin_a = points_per_bin.accessor<int32_t, 3>();
|
||||||
auto bin_points_a = bin_points.accessor<int32_t, 4>();
|
auto bin_points_a = bin_points.accessor<int32_t, 4>();
|
||||||
auto radius_a = radius.accessor<float, 1>();
|
auto radius_a = radius.accessor<float, 1>();
|
||||||
|
|
||||||
const float pixel_width = 2.0f / image_size;
|
const float ndc_x_range = NonSquareNdcRange(W, H);
|
||||||
const float bin_width = pixel_width * bin_size;
|
const float pixel_width_x = ndc_x_range / W;
|
||||||
|
const float bin_width_x = pixel_width_x * bin_size;
|
||||||
|
|
||||||
|
const float ndc_y_range = NonSquareNdcRange(H, W);
|
||||||
|
const float pixel_width_y = ndc_y_range / H;
|
||||||
|
const float bin_width_y = pixel_width_y * bin_size;
|
||||||
|
|
||||||
for (int n = 0; n < N; ++n) {
|
for (int n = 0; n < N; ++n) {
|
||||||
// Loop through each pointcloud in the batch.
|
// Loop through each pointcloud in the batch.
|
||||||
@ -129,15 +134,15 @@ torch::Tensor RasterizePointsCoarseCpu(
|
|||||||
(point_start_idx + num_points_per_cloud[n].item().to<int32_t>());
|
(point_start_idx + num_points_per_cloud[n].item().to<int32_t>());
|
||||||
|
|
||||||
float bin_y_min = -1.0f;
|
float bin_y_min = -1.0f;
|
||||||
float bin_y_max = bin_y_min + bin_width;
|
float bin_y_max = bin_y_min + bin_width_y;
|
||||||
|
|
||||||
// Iterate through the horizontal bins from top to bottom.
|
// Iterate through the horizontal bins from top to bottom.
|
||||||
for (int by = 0; by < B; by++) {
|
for (int by = 0; by < BH; by++) {
|
||||||
float bin_x_min = -1.0f;
|
float bin_x_min = -1.0f;
|
||||||
float bin_x_max = bin_x_min + bin_width;
|
float bin_x_max = bin_x_min + bin_width_x;
|
||||||
|
|
||||||
// Iterate through bins on this horizontal line, left to right.
|
// Iterate through bins on this horizontal line, left to right.
|
||||||
for (int bx = 0; bx < B; bx++) {
|
for (int bx = 0; bx < BW; bx++) {
|
||||||
int32_t points_hit = 0;
|
int32_t points_hit = 0;
|
||||||
for (int p = point_start_idx; p < point_stop_idx; ++p) {
|
for (int p = point_start_idx; p < point_stop_idx; ++p) {
|
||||||
float px = points_a[p][0];
|
float px = points_a[p][0];
|
||||||
@ -172,11 +177,11 @@ torch::Tensor RasterizePointsCoarseCpu(
|
|||||||
|
|
||||||
// Shift the bin to the right for the next loop iteration
|
// Shift the bin to the right for the next loop iteration
|
||||||
bin_x_min = bin_x_max;
|
bin_x_min = bin_x_max;
|
||||||
bin_x_max = bin_x_min + bin_width;
|
bin_x_max = bin_x_min + bin_width_x;
|
||||||
}
|
}
|
||||||
// Shift the bin down for the next loop iteration
|
// Shift the bin down for the next loop iteration
|
||||||
bin_y_min = bin_y_max;
|
bin_y_min = bin_y_max;
|
||||||
bin_y_max = bin_y_min + bin_width;
|
bin_y_max = bin_y_min + bin_width_y;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return bin_points;
|
return bin_points;
|
||||||
@ -194,11 +199,6 @@ torch::Tensor RasterizePointsBackwardCpu(
|
|||||||
const int W = idxs.size(2);
|
const int W = idxs.size(2);
|
||||||
const int K = idxs.size(3);
|
const int K = idxs.size(3);
|
||||||
|
|
||||||
// For now only support square images.
|
|
||||||
// TODO(jcjohns): Extend to non-square images.
|
|
||||||
if (H != W) {
|
|
||||||
AT_ERROR("RasterizePointsBackwardCpu only supports square images");
|
|
||||||
}
|
|
||||||
torch::Tensor grad_points = torch::zeros({P, 3}, points.options());
|
torch::Tensor grad_points = torch::zeros({P, 3}, points.options());
|
||||||
|
|
||||||
auto points_a = points.accessor<float, 2>();
|
auto points_a = points.accessor<float, 2>();
|
||||||
@ -212,7 +212,7 @@ torch::Tensor RasterizePointsBackwardCpu(
|
|||||||
// Reverse the order of yi so that +Y is pointing upwards in the image.
|
// Reverse the order of yi so that +Y is pointing upwards in the image.
|
||||||
const int yidx = H - 1 - y;
|
const int yidx = H - 1 - y;
|
||||||
// Y coordinate of the top of the pixel.
|
// Y coordinate of the top of the pixel.
|
||||||
const float yf = PixToNdc(yidx, H);
|
const float yf = PixToNonSquareNdc(yidx, H, W);
|
||||||
|
|
||||||
// Iterate through pixels on this horizontal line, left to right.
|
// Iterate through pixels on this horizontal line, left to right.
|
||||||
for (int x = 0; x < W; ++x) { // Loop over pixels in the row
|
for (int x = 0; x < W; ++x) { // Loop over pixels in the row
|
||||||
@ -220,7 +220,7 @@ torch::Tensor RasterizePointsBackwardCpu(
|
|||||||
// Reverse the order of xi so that +X is pointing to the left in the
|
// Reverse the order of xi so that +X is pointing to the left in the
|
||||||
// image.
|
// image.
|
||||||
const int xidx = W - 1 - x;
|
const int xidx = W - 1 - x;
|
||||||
const float xf = PixToNdc(xidx, W);
|
const float xf = PixToNonSquareNdc(xidx, W, H);
|
||||||
for (int k = 0; k < K; ++k) { // Loop over points for the pixel
|
for (int k = 0; k < K; ++k) { // Loop over points for the pixel
|
||||||
const int p = idxs_a[n][y][x][k];
|
const int p = idxs_a[n][y][x][k];
|
||||||
if (p < 0) {
|
if (p < 0) {
|
||||||
|
@ -6,6 +6,7 @@ from .rasterizer import MeshRasterizer, RasterizationSettings
|
|||||||
from .renderer import MeshRenderer
|
from .renderer import MeshRenderer
|
||||||
from .shader import TexturedSoftPhongShader # DEPRECATED
|
from .shader import TexturedSoftPhongShader # DEPRECATED
|
||||||
from .shader import (
|
from .shader import (
|
||||||
|
BlendParams,
|
||||||
HardFlatShader,
|
HardFlatShader,
|
||||||
HardGouraudShader,
|
HardGouraudShader,
|
||||||
HardPhongShader,
|
HardPhongShader,
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||||
|
|
||||||
|
|
||||||
from typing import Optional, Tuple, Union
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@ -20,7 +20,7 @@ kMaxFacesPerBin = 22
|
|||||||
|
|
||||||
def rasterize_meshes(
|
def rasterize_meshes(
|
||||||
meshes,
|
meshes,
|
||||||
image_size: Union[int, Tuple[int, int]] = 256,
|
image_size: Union[int, List[int], Tuple[int, int]] = 256,
|
||||||
blur_radius: float = 0.0,
|
blur_radius: float = 0.0,
|
||||||
faces_per_pixel: int = 8,
|
faces_per_pixel: int = 8,
|
||||||
bin_size: Optional[int] = None,
|
bin_size: Optional[int] = None,
|
||||||
@ -219,7 +219,7 @@ class _RasterizeFaceVerts(torch.autograd.Function):
|
|||||||
face_verts,
|
face_verts,
|
||||||
mesh_to_face_first_idx,
|
mesh_to_face_first_idx,
|
||||||
num_faces_per_mesh,
|
num_faces_per_mesh,
|
||||||
image_size: Tuple[int, int] = (256, 256),
|
image_size: Union[List[int], Tuple[int, int]] = (256, 256),
|
||||||
blur_radius: float = 0.01,
|
blur_radius: float = 0.01,
|
||||||
faces_per_pixel: int = 0,
|
faces_per_pixel: int = 0,
|
||||||
bin_size: int = 0,
|
bin_size: int = 0,
|
||||||
@ -287,11 +287,6 @@ class _RasterizeFaceVerts(torch.autograd.Function):
|
|||||||
return grads
|
return grads
|
||||||
|
|
||||||
|
|
||||||
def pix_to_ndc(i, S):
|
|
||||||
# NDC x-offset + (i * pixel_width + half_pixel_width)
|
|
||||||
return -1 + (2 * i + 1.0) / S
|
|
||||||
|
|
||||||
|
|
||||||
def non_square_ndc_range(S1, S2):
|
def non_square_ndc_range(S1, S2):
|
||||||
"""
|
"""
|
||||||
In the case of non square images, we scale the NDC range
|
In the case of non square images, we scale the NDC range
|
||||||
|
@ -75,7 +75,7 @@ def _add_background_color_to_images(pix_idxs, images, background_color):
|
|||||||
pixels with accumulated features have unchanged values.
|
pixels with accumulated features have unchanged values.
|
||||||
"""
|
"""
|
||||||
# Initialize background mask
|
# Initialize background mask
|
||||||
background_mask = pix_idxs[:, 0] < 0 # (N, image_size, image_size)
|
background_mask = pix_idxs[:, 0] < 0 # (N, H, W)
|
||||||
|
|
||||||
# Convert background_color to an appropriate tensor and check shape
|
# Convert background_color to an appropriate tensor and check shape
|
||||||
if not torch.is_tensor(background_color):
|
if not torch.is_tensor(background_color):
|
||||||
|
@ -6,7 +6,7 @@ import torch
|
|||||||
|
|
||||||
# pyre-fixme[21]: Could not find name `_C` in `pytorch3d`.
|
# pyre-fixme[21]: Could not find name `_C` in `pytorch3d`.
|
||||||
from pytorch3d import _C
|
from pytorch3d import _C
|
||||||
from pytorch3d.renderer.mesh.rasterize_meshes import pix_to_ndc
|
from pytorch3d.renderer.mesh.rasterize_meshes import pix_to_non_square_ndc
|
||||||
|
|
||||||
|
|
||||||
# Maxinum number of faces per bins for
|
# Maxinum number of faces per bins for
|
||||||
@ -14,17 +14,30 @@ from pytorch3d.renderer.mesh.rasterize_meshes import pix_to_ndc
|
|||||||
kMaxPointsPerBin = 22
|
kMaxPointsPerBin = 22
|
||||||
|
|
||||||
|
|
||||||
# TODO(jcjohns): Support non-square images
|
|
||||||
def rasterize_points(
|
def rasterize_points(
|
||||||
pointclouds,
|
pointclouds,
|
||||||
image_size: int = 256,
|
image_size: Union[int, List[int], Tuple[int, int]] = 256,
|
||||||
radius: Union[float, List, Tuple, torch.Tensor] = 0.01,
|
radius: Union[float, List, Tuple, torch.Tensor] = 0.01,
|
||||||
points_per_pixel: int = 8,
|
points_per_pixel: int = 8,
|
||||||
bin_size: Optional[int] = None,
|
bin_size: Optional[int] = None,
|
||||||
max_points_per_bin: Optional[int] = None,
|
max_points_per_bin: Optional[int] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Pointcloud rasterization
|
Each pointcloud is rasterized onto a separate image of shape
|
||||||
|
(H, W) if `image_size` is a tuple or (image_size, image_size) if it
|
||||||
|
is an int.
|
||||||
|
|
||||||
|
If the desired image size is non square (i.e. a tuple of (H, W) where H != W)
|
||||||
|
the aspect ratio needs special consideration. There are two aspect ratios
|
||||||
|
to be aware of:
|
||||||
|
- the aspect ratio of each pixel
|
||||||
|
- the aspect ratio of the output image
|
||||||
|
The camera can be used to set the pixel aspect ratio. In the rasterizer,
|
||||||
|
we assume square pixels, but variable image aspect ratio (i.e rectangle images).
|
||||||
|
|
||||||
|
In most cases you will want to set the camera aspect ratio to
|
||||||
|
1.0 (i.e. square pixels) and only vary the
|
||||||
|
`image_size` (i.e. the output image dimensions in pix
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
pointclouds: A Pointclouds object representing a batch of point clouds to be
|
pointclouds: A Pointclouds object representing a batch of point clouds to be
|
||||||
@ -34,7 +47,8 @@ def rasterize_points(
|
|||||||
be in normalized device coordinates (NDC): [-1, 1]^3 with the camera at
|
be in normalized device coordinates (NDC): [-1, 1]^3 with the camera at
|
||||||
(0, 0, 0); In the camera coordinate frame the x-axis goes from right-to-left,
|
(0, 0, 0); In the camera coordinate frame the x-axis goes from right-to-left,
|
||||||
the y-axis goes from bottom-to-top, and the z-axis goes from back-to-front.
|
the y-axis goes from bottom-to-top, and the z-axis goes from back-to-front.
|
||||||
image_size: Integer giving the resolution of the rasterized image
|
image_size: Size in pixels of the output image to be rasterized.
|
||||||
|
Can optionally be a tuple of (H, W) in the case of non square images.
|
||||||
radius (Optional): The radius (in NDC units) of the disk to
|
radius (Optional): The radius (in NDC units) of the disk to
|
||||||
be rasterized. This can either be a float in which case the same radius is used
|
be rasterized. This can either be a float in which case the same radius is used
|
||||||
for each point, or a torch.Tensor of shape (N, P) giving a radius per point
|
for each point, or a torch.Tensor of shape (N, P) giving a radius per point
|
||||||
@ -71,6 +85,9 @@ def rasterize_points(
|
|||||||
then `dists[n, y, x, k]` is the squared distance between the pixel (y, x)
|
then `dists[n, y, x, k]` is the squared distance between the pixel (y, x)
|
||||||
and the point `(points[n, p, 0], points[n, p, 1])`. Pixels hit with fewer
|
and the point `(points[n, p, 0], points[n, p, 1])`. Pixels hit with fewer
|
||||||
than points_per_pixel are padded with -1.
|
than points_per_pixel are padded with -1.
|
||||||
|
|
||||||
|
In the case that image_size is a tuple of (H, W) then the outputs
|
||||||
|
will be of shape `(N, H, W, ...)`.
|
||||||
"""
|
"""
|
||||||
points_packed = pointclouds.points_packed()
|
points_packed = pointclouds.points_packed()
|
||||||
cloud_to_packed_first_idx = pointclouds.cloud_to_packed_first_idx()
|
cloud_to_packed_first_idx = pointclouds.cloud_to_packed_first_idx()
|
||||||
@ -78,26 +95,46 @@ def rasterize_points(
|
|||||||
|
|
||||||
radius = _format_radius(radius, pointclouds)
|
radius = _format_radius(radius, pointclouds)
|
||||||
|
|
||||||
|
# In the case that H != W use the max image size to set the bin_size
|
||||||
|
# to accommodate the num bins constraint in the coarse rasteizer.
|
||||||
|
# If the ratio of H:W is large this might cause issues as the smaller
|
||||||
|
# dimension will have fewer bins.
|
||||||
|
# TODO: consider a better way of setting the bin size.
|
||||||
|
if isinstance(image_size, (tuple, list)):
|
||||||
|
if len(image_size) != 2:
|
||||||
|
raise ValueError("Image size can only be a tuple/list of (H, W)")
|
||||||
|
if not all(i > 0 for i in image_size):
|
||||||
|
raise ValueError(
|
||||||
|
"Image sizes must be greater than 0; got %d, %d" % image_size
|
||||||
|
)
|
||||||
|
if not all(type(i) == int for i in image_size):
|
||||||
|
raise ValueError("Image sizes must be integers; got %f, %f" % image_size)
|
||||||
|
max_image_size = max(*image_size)
|
||||||
|
im_size = image_size
|
||||||
|
else:
|
||||||
|
im_size = (image_size, image_size)
|
||||||
|
max_image_size = image_size
|
||||||
|
|
||||||
if bin_size is None:
|
if bin_size is None:
|
||||||
if not points_packed.is_cuda:
|
if not points_packed.is_cuda:
|
||||||
# Binned CPU rasterization not fully implemented
|
# Binned CPU rasterization not fully implemented
|
||||||
bin_size = 0
|
bin_size = 0
|
||||||
else:
|
else:
|
||||||
# TODO: These heuristics are not well-thought out!
|
# TODO: These heuristics are not well-thought out!
|
||||||
if image_size <= 64:
|
if max_image_size <= 64:
|
||||||
bin_size = 8
|
bin_size = 8
|
||||||
elif image_size <= 256:
|
elif max_image_size <= 256:
|
||||||
bin_size = 16
|
bin_size = 16
|
||||||
elif image_size <= 512:
|
elif max_image_size <= 512:
|
||||||
bin_size = 32
|
bin_size = 32
|
||||||
elif image_size <= 1024:
|
elif max_image_size <= 1024:
|
||||||
bin_size = 64
|
bin_size = 64
|
||||||
|
|
||||||
if bin_size != 0:
|
if bin_size != 0:
|
||||||
# There is a limit on the number of points per bin in the cuda kernel.
|
# There is a limit on the number of points per bin in the cuda kernel.
|
||||||
# pyre-fixme[58]: `//` is not supported for operand types `int` and
|
# pyre-fixme[58]: `//` is not supported for operand types `int` and
|
||||||
# `Union[int, None, int]`.
|
# `Union[int, None, int]`.
|
||||||
points_per_bin = 1 + (image_size - 1) // bin_size
|
points_per_bin = 1 + (max_image_size - 1) // bin_size
|
||||||
if points_per_bin >= kMaxPointsPerBin:
|
if points_per_bin >= kMaxPointsPerBin:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"bin_size too small, number of points per bin must be less than %d; got %d"
|
"bin_size too small, number of points per bin must be less than %d; got %d"
|
||||||
@ -114,7 +151,7 @@ def rasterize_points(
|
|||||||
points_packed,
|
points_packed,
|
||||||
cloud_to_packed_first_idx,
|
cloud_to_packed_first_idx,
|
||||||
num_points_per_cloud,
|
num_points_per_cloud,
|
||||||
image_size,
|
im_size,
|
||||||
radius,
|
radius,
|
||||||
points_per_pixel,
|
points_per_pixel,
|
||||||
bin_size,
|
bin_size,
|
||||||
@ -173,7 +210,7 @@ class _RasterizePoints(torch.autograd.Function):
|
|||||||
points, # (P, 3)
|
points, # (P, 3)
|
||||||
cloud_to_packed_first_idx,
|
cloud_to_packed_first_idx,
|
||||||
num_points_per_cloud,
|
num_points_per_cloud,
|
||||||
image_size: int = 256,
|
image_size: Union[List[int], Tuple[int, int]] = (256, 256),
|
||||||
radius: Union[float, torch.Tensor] = 0.01,
|
radius: Union[float, torch.Tensor] = 0.01,
|
||||||
points_per_pixel: int = 8,
|
points_per_pixel: int = 8,
|
||||||
bin_size: int = 0,
|
bin_size: int = 0,
|
||||||
@ -225,7 +262,7 @@ class _RasterizePoints(torch.autograd.Function):
|
|||||||
|
|
||||||
def rasterize_points_python(
|
def rasterize_points_python(
|
||||||
pointclouds,
|
pointclouds,
|
||||||
image_size: int = 256,
|
image_size: Union[int, Tuple[int, int]] = 256,
|
||||||
radius: Union[float, torch.Tensor] = 0.01,
|
radius: Union[float, torch.Tensor] = 0.01,
|
||||||
points_per_pixel: int = 8,
|
points_per_pixel: int = 8,
|
||||||
):
|
):
|
||||||
@ -235,7 +272,12 @@ def rasterize_points_python(
|
|||||||
Inputs / Outputs: Same as above
|
Inputs / Outputs: Same as above
|
||||||
"""
|
"""
|
||||||
N = len(pointclouds)
|
N = len(pointclouds)
|
||||||
S, K = image_size, points_per_pixel
|
H, W = (
|
||||||
|
image_size
|
||||||
|
if isinstance(image_size, (tuple, list))
|
||||||
|
else (image_size, image_size)
|
||||||
|
)
|
||||||
|
K = points_per_pixel
|
||||||
device = pointclouds.device
|
device = pointclouds.device
|
||||||
|
|
||||||
points_packed = pointclouds.points_packed()
|
points_packed = pointclouds.points_packed()
|
||||||
@ -247,11 +289,11 @@ def rasterize_points_python(
|
|||||||
|
|
||||||
# Intialize output tensors.
|
# Intialize output tensors.
|
||||||
point_idxs = torch.full(
|
point_idxs = torch.full(
|
||||||
(N, S, S, K), fill_value=-1, dtype=torch.int32, device=device
|
(N, H, W, K), fill_value=-1, dtype=torch.int32, device=device
|
||||||
)
|
)
|
||||||
zbuf = torch.full((N, S, S, K), fill_value=-1, dtype=torch.float32, device=device)
|
zbuf = torch.full((N, H, W, K), fill_value=-1, dtype=torch.float32, device=device)
|
||||||
pix_dists = torch.full(
|
pix_dists = torch.full(
|
||||||
(N, S, S, K), fill_value=-1, dtype=torch.float32, device=device
|
(N, H, W, K), fill_value=-1, dtype=torch.float32, device=device
|
||||||
)
|
)
|
||||||
|
|
||||||
# NDC is from [-1, 1]. Get pixel size using specified image size.
|
# NDC is from [-1, 1]. Get pixel size using specified image size.
|
||||||
@ -263,18 +305,18 @@ def rasterize_points_python(
|
|||||||
point_stop_idx = point_start_idx + num_points_per_cloud[n]
|
point_stop_idx = point_start_idx + num_points_per_cloud[n]
|
||||||
|
|
||||||
# Iterate through the horizontal lines of the image from top to bottom.
|
# Iterate through the horizontal lines of the image from top to bottom.
|
||||||
for yi in range(S):
|
for yi in range(H):
|
||||||
# Y coordinate of one end of the image. Reverse the ordering
|
# Y coordinate of one end of the image. Reverse the ordering
|
||||||
# of yi so that +Y is pointing up in the image.
|
# of yi so that +Y is pointing up in the image.
|
||||||
yfix = S - 1 - yi
|
yfix = H - 1 - yi
|
||||||
yf = pix_to_ndc(yfix, S)
|
yf = pix_to_non_square_ndc(yfix, H, W)
|
||||||
|
|
||||||
# Iterate through pixels on this horizontal line, left to right.
|
# Iterate through pixels on this horizontal line, left to right.
|
||||||
for xi in range(S):
|
for xi in range(W):
|
||||||
# X coordinate of one end of the image. Reverse the ordering
|
# X coordinate of one end of the image. Reverse the ordering
|
||||||
# of xi so that +X is pointing to the left in the image.
|
# of xi so that +X is pointing to the left in the image.
|
||||||
xfix = S - 1 - xi
|
xfix = W - 1 - xi
|
||||||
xf = pix_to_ndc(xfix, S)
|
xf = pix_to_non_square_ndc(xfix, W, H)
|
||||||
|
|
||||||
top_k_points = []
|
top_k_points = []
|
||||||
# Check whether each point in the batch affects this pixel.
|
# Check whether each point in the batch affects this pixel.
|
||||||
|
@ -2,7 +2,7 @@
|
|||||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||||
|
|
||||||
|
|
||||||
from typing import NamedTuple, Optional, Union
|
from typing import NamedTuple, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
@ -29,7 +29,7 @@ class PointsRasterizationSettings:
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
image_size: int = 256,
|
image_size: Union[int, Tuple[int, int]] = 256,
|
||||||
radius: Union[float, torch.Tensor] = 0.01,
|
radius: Union[float, torch.Tensor] = 0.01,
|
||||||
points_per_pixel: int = 8,
|
points_per_pixel: int = 8,
|
||||||
bin_size: Optional[int] = None,
|
bin_size: Optional[int] = None,
|
||||||
|
@ -74,6 +74,21 @@ def bm_python_vs_cpu_vs_cuda() -> None:
|
|||||||
kwargs_list += [
|
kwargs_list += [
|
||||||
{"N": 32, "P": 100000, "img_size": 128, "radius": 0.01, "pts_per_pxl": 50},
|
{"N": 32, "P": 100000, "img_size": 128, "radius": 0.01, "pts_per_pxl": 50},
|
||||||
{"N": 8, "P": 200000, "img_size": 512, "radius": 0.01, "pts_per_pxl": 50},
|
{"N": 8, "P": 200000, "img_size": 512, "radius": 0.01, "pts_per_pxl": 50},
|
||||||
|
{"N": 8, "P": 200000, "img_size": 256, "radius": 0.01, "pts_per_pxl": 50},
|
||||||
|
{
|
||||||
|
"N": 8,
|
||||||
|
"P": 200000,
|
||||||
|
"img_size": (512, 256),
|
||||||
|
"radius": 0.01,
|
||||||
|
"pts_per_pxl": 50,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"N": 8,
|
||||||
|
"P": 200000,
|
||||||
|
"img_size": (256, 512),
|
||||||
|
"radius": 0.01,
|
||||||
|
"pts_per_pxl": 50,
|
||||||
|
},
|
||||||
]
|
]
|
||||||
for k in kwargs_list:
|
for k in kwargs_list:
|
||||||
k["device"] = "cuda"
|
k["device"] = "cuda"
|
||||||
|
BIN
tests/data/test_pointcloud_rectangle_image.png
Normal file
BIN
tests/data/test_pointcloud_rectangle_image.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 20 KiB |
@ -404,7 +404,7 @@ class TestRasterizePoints(TestCaseMixin, unittest.TestCase):
|
|||||||
torch.manual_seed(231)
|
torch.manual_seed(231)
|
||||||
N = 3
|
N = 3
|
||||||
max_P = 1000
|
max_P = 1000
|
||||||
image_size = 64
|
image_size = (64, 64)
|
||||||
radius = 0.1
|
radius = 0.1
|
||||||
bin_size = 16
|
bin_size = 16
|
||||||
max_points_per_bin = 500
|
max_points_per_bin = 500
|
||||||
@ -501,7 +501,7 @@ class TestRasterizePoints(TestCaseMixin, unittest.TestCase):
|
|||||||
device=device
|
device=device
|
||||||
)
|
)
|
||||||
# fmt: on
|
# fmt: on
|
||||||
image_size = 16
|
image_size = (16, 16)
|
||||||
radius = 0.2
|
radius = 0.2
|
||||||
bin_size = 8
|
bin_size = 8
|
||||||
max_points_per_bin = 5
|
max_points_per_bin = 5
|
||||||
|
@ -12,19 +12,33 @@ from pytorch3d.io import load_obj
|
|||||||
from pytorch3d.renderer.cameras import FoVPerspectiveCameras, look_at_view_transform
|
from pytorch3d.renderer.cameras import FoVPerspectiveCameras, look_at_view_transform
|
||||||
from pytorch3d.renderer.lighting import PointLights
|
from pytorch3d.renderer.lighting import PointLights
|
||||||
from pytorch3d.renderer.materials import Materials
|
from pytorch3d.renderer.materials import Materials
|
||||||
from pytorch3d.renderer.mesh import TexturesUV
|
from pytorch3d.renderer.mesh import (
|
||||||
|
BlendParams,
|
||||||
|
MeshRasterizer,
|
||||||
|
MeshRenderer,
|
||||||
|
RasterizationSettings,
|
||||||
|
SoftPhongShader,
|
||||||
|
TexturesUV,
|
||||||
|
)
|
||||||
from pytorch3d.renderer.mesh.rasterize_meshes import (
|
from pytorch3d.renderer.mesh.rasterize_meshes import (
|
||||||
rasterize_meshes,
|
rasterize_meshes,
|
||||||
rasterize_meshes_python,
|
rasterize_meshes_python,
|
||||||
)
|
)
|
||||||
from pytorch3d.renderer.mesh.rasterizer import (
|
from pytorch3d.renderer.mesh.rasterizer import Fragments
|
||||||
Fragments,
|
from pytorch3d.renderer.points import (
|
||||||
MeshRasterizer,
|
AlphaCompositor,
|
||||||
RasterizationSettings,
|
PointsRasterizationSettings,
|
||||||
|
PointsRasterizer,
|
||||||
|
PointsRenderer,
|
||||||
)
|
)
|
||||||
from pytorch3d.renderer.mesh.renderer import MeshRenderer
|
from pytorch3d.renderer.points.rasterize_points import (
|
||||||
from pytorch3d.renderer.mesh.shader import BlendParams, SoftPhongShader
|
rasterize_points,
|
||||||
from pytorch3d.structures import Meshes
|
rasterize_points_python,
|
||||||
|
)
|
||||||
|
from pytorch3d.renderer.points.rasterizer import PointFragments
|
||||||
|
from pytorch3d.structures import Meshes, Pointclouds
|
||||||
|
from pytorch3d.transforms.transform3d import Transform3d
|
||||||
|
from pytorch3d.utils import torus
|
||||||
|
|
||||||
|
|
||||||
DEBUG = False
|
DEBUG = False
|
||||||
@ -44,9 +58,36 @@ verts0 = torch.tensor(
|
|||||||
)
|
)
|
||||||
faces0 = torch.tensor([[1, 0, 2], [4, 3, 5]], dtype=torch.int64)
|
faces0 = torch.tensor([[1, 0, 2], [4, 3, 5]], dtype=torch.int64)
|
||||||
|
|
||||||
|
# Points for a simple pointcloud. Get the vertices from a
|
||||||
|
# torus and apply rotations such that the points are no longer
|
||||||
|
# symmerical in X/Y.
|
||||||
|
torus_mesh = torus(r=0.25, R=1.0, sides=5, rings=2 * 5)
|
||||||
|
t = (
|
||||||
|
Transform3d()
|
||||||
|
.rotate_axis_angle(angle=90, axis="Y")
|
||||||
|
.rotate_axis_angle(angle=45, axis="Z")
|
||||||
|
.scale(0.3)
|
||||||
|
)
|
||||||
|
torus_points = t.transform_points(torus_mesh.verts_padded()).squeeze()
|
||||||
|
|
||||||
class TestRasterizeRectanglesErrors(TestCaseMixin, unittest.TestCase):
|
|
||||||
def test_image_size_arg(self):
|
def _save_debug_image(idx, image_size, bin_size, blur):
|
||||||
|
"""
|
||||||
|
Save a mask image from the rasterization output for debugging.
|
||||||
|
"""
|
||||||
|
H, W = image_size
|
||||||
|
# Save out the last image for debugging
|
||||||
|
rgb = (idx[-1, ..., :3].cpu() > -1).squeeze()
|
||||||
|
suffix = "square" if H == W else "non_square"
|
||||||
|
filename = "%s_bin_size_%s_blur_%.3f_%dx%d.png"
|
||||||
|
filename = filename % (suffix, str(bin_size), blur, H, W)
|
||||||
|
if DEBUG:
|
||||||
|
filename = "DEBUG_%s" % filename
|
||||||
|
Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(DATA_DIR / filename)
|
||||||
|
|
||||||
|
|
||||||
|
class TestRasterizeRectangleImagesErrors(TestCaseMixin, unittest.TestCase):
|
||||||
|
def test_mesh_image_size_arg(self):
|
||||||
meshes = Meshes(verts=[verts0], faces=[faces0])
|
meshes = Meshes(verts=[verts0], faces=[faces0])
|
||||||
|
|
||||||
with self.assertRaises(ValueError) as cm:
|
with self.assertRaises(ValueError) as cm:
|
||||||
@ -76,8 +117,38 @@ class TestRasterizeRectanglesErrors(TestCaseMixin, unittest.TestCase):
|
|||||||
)
|
)
|
||||||
self.assertTrue("sizes must be integers" in cm.msg)
|
self.assertTrue("sizes must be integers" in cm.msg)
|
||||||
|
|
||||||
|
def test_points_image_size_arg(self):
|
||||||
|
points = Pointclouds([verts0])
|
||||||
|
|
||||||
class TestRasterizeRectangles(TestCaseMixin, unittest.TestCase):
|
with self.assertRaises(ValueError) as cm:
|
||||||
|
rasterize_points(
|
||||||
|
points,
|
||||||
|
(100, 200, 3),
|
||||||
|
0.0001,
|
||||||
|
points_per_pixel=1,
|
||||||
|
)
|
||||||
|
self.assertTrue("tuple/list of (H, W)" in cm.msg)
|
||||||
|
|
||||||
|
with self.assertRaises(ValueError) as cm:
|
||||||
|
rasterize_points(
|
||||||
|
points,
|
||||||
|
(0, 10),
|
||||||
|
0.0001,
|
||||||
|
points_per_pixel=1,
|
||||||
|
)
|
||||||
|
self.assertTrue("sizes must be positive" in cm.msg)
|
||||||
|
|
||||||
|
with self.assertRaises(ValueError) as cm:
|
||||||
|
rasterize_points(
|
||||||
|
points,
|
||||||
|
(100.5, 120.5),
|
||||||
|
0.0001,
|
||||||
|
points_per_pixel=1,
|
||||||
|
)
|
||||||
|
self.assertTrue("sizes must be integers" in cm.msg)
|
||||||
|
|
||||||
|
|
||||||
|
class TestRasterizeRectangleImagesMeshes(TestCaseMixin, unittest.TestCase):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _clone_mesh(verts0, faces0, device, batch_size):
|
def _clone_mesh(verts0, faces0, device, batch_size):
|
||||||
"""
|
"""
|
||||||
@ -164,7 +235,7 @@ class TestRasterizeRectangles(TestCaseMixin, unittest.TestCase):
|
|||||||
meshes_sq, image_size=(S, S), bin_size=0, blur=blur
|
meshes_sq, image_size=(S, S), bin_size=0, blur=blur
|
||||||
)
|
)
|
||||||
# Save debug image
|
# Save debug image
|
||||||
self._save_debug_image(square_fragments, (S, S), 0, blur)
|
_save_debug_image(square_fragments.pix_to_face, (S, S), 0, blur)
|
||||||
|
|
||||||
# Extract the values in the square image which are non zero.
|
# Extract the values in the square image which are non zero.
|
||||||
square_mask = square_fragments.pix_to_face > -1
|
square_mask = square_fragments.pix_to_face > -1
|
||||||
@ -284,8 +355,8 @@ class TestRasterizeRectangles(TestCaseMixin, unittest.TestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Save out debug images if needed
|
# Save out debug images if needed
|
||||||
self._save_debug_image(fragments_naive, image_size, 0, blur)
|
_save_debug_image(fragments_naive.pix_to_face, image_size, 0, blur)
|
||||||
self._save_debug_image(fragments_binned, image_size, None, blur)
|
_save_debug_image(fragments_binned.pix_to_face, image_size, None, blur)
|
||||||
|
|
||||||
# Check naive and binned fragments give the same outputs
|
# Check naive and binned fragments give the same outputs
|
||||||
self._check_fragments(fragments_naive, fragments_binned)
|
self._check_fragments(fragments_naive, fragments_binned)
|
||||||
@ -354,8 +425,8 @@ class TestRasterizeRectangles(TestCaseMixin, unittest.TestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Save debug images if DEBUG is set to true at the top of the file.
|
# Save debug images if DEBUG is set to true at the top of the file.
|
||||||
self._save_debug_image(fragments_naive, image_size, 0, blur)
|
_save_debug_image(fragments_naive.pix_to_face, image_size, 0, blur)
|
||||||
self._save_debug_image(fragments_python, image_size, "python", blur)
|
_save_debug_image(fragments_python.pix_to_face, image_size, "python", blur)
|
||||||
|
|
||||||
# List of non square outputs to compare with the square output
|
# List of non square outputs to compare with the square output
|
||||||
nonsq_fragment_gradtensor_list = [
|
nonsq_fragment_gradtensor_list = [
|
||||||
@ -437,3 +508,293 @@ class TestRasterizeRectangles(TestCaseMixin, unittest.TestCase):
|
|||||||
# NOTE some pixels can be flaky
|
# NOTE some pixels can be flaky
|
||||||
cond1 = torch.allclose(rgb, image_ref, atol=0.05)
|
cond1 = torch.allclose(rgb, image_ref, atol=0.05)
|
||||||
self.assertTrue(cond1)
|
self.assertTrue(cond1)
|
||||||
|
|
||||||
|
|
||||||
|
class TestRasterizeRectangleImagesPointclouds(TestCaseMixin, unittest.TestCase):
|
||||||
|
@staticmethod
|
||||||
|
def _clone_pointcloud(verts0, device, batch_size):
|
||||||
|
"""
|
||||||
|
Helper function to detach and clone the verts.
|
||||||
|
This is needed in order to set up the tensors for
|
||||||
|
gradient computation in different tests.
|
||||||
|
"""
|
||||||
|
verts = verts0.detach().clone()
|
||||||
|
verts.requires_grad = True
|
||||||
|
pointclouds = Pointclouds(points=[verts])
|
||||||
|
pointclouds = pointclouds.to(device).extend(batch_size)
|
||||||
|
return verts, pointclouds
|
||||||
|
|
||||||
|
def _rasterize(self, meshes, image_size, bin_size, blur):
|
||||||
|
"""
|
||||||
|
Simple wrapper around the rasterize function to return
|
||||||
|
the fragment data.
|
||||||
|
"""
|
||||||
|
idxs, zbuf, dists = rasterize_points(
|
||||||
|
meshes,
|
||||||
|
image_size,
|
||||||
|
blur,
|
||||||
|
points_per_pixel=1,
|
||||||
|
bin_size=bin_size,
|
||||||
|
)
|
||||||
|
return PointFragments(
|
||||||
|
idx=idxs,
|
||||||
|
zbuf=zbuf,
|
||||||
|
dists=dists,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _check_fragments(self, frag_1, frag_2):
|
||||||
|
"""
|
||||||
|
Helper function to check that the tensors in
|
||||||
|
the Fragments frag_1 and frag_2 are the same.
|
||||||
|
"""
|
||||||
|
self.assertClose(frag_1.idx, frag_2.idx)
|
||||||
|
self.assertClose(frag_1.dists, frag_2.dists)
|
||||||
|
self.assertClose(frag_1.zbuf, frag_2.zbuf)
|
||||||
|
|
||||||
|
def _compare_square_with_nonsq(
|
||||||
|
self,
|
||||||
|
image_size,
|
||||||
|
blur,
|
||||||
|
device,
|
||||||
|
points,
|
||||||
|
nonsq_fragment_gradtensor_list,
|
||||||
|
batch_size=1,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Calculate the output from rasterizing a square image with the minimum of (H, W).
|
||||||
|
Then compare this with the same square region in the non square image.
|
||||||
|
The input points are contained within the [-1, 1] range of the image
|
||||||
|
so all the relevant pixels will be within the square region.
|
||||||
|
|
||||||
|
`nonsq_fragment_gradtensor_list` is a list of fragments and verts grad tensors
|
||||||
|
from rasterizing non square images.
|
||||||
|
"""
|
||||||
|
# Rasterize the square version of the image
|
||||||
|
H, W = image_size
|
||||||
|
S = min(H, W)
|
||||||
|
points_square, pointclouds_sq = self._clone_pointcloud(
|
||||||
|
points, device, batch_size
|
||||||
|
)
|
||||||
|
square_fragments = self._rasterize(
|
||||||
|
pointclouds_sq, image_size=(S, S), bin_size=0, blur=blur
|
||||||
|
)
|
||||||
|
# Save debug image
|
||||||
|
_save_debug_image(square_fragments.idx, (S, S), 0, blur)
|
||||||
|
|
||||||
|
# Extract the values in the square image which are non zero.
|
||||||
|
square_mask = square_fragments.idx > -1
|
||||||
|
square_dists = square_fragments.dists[square_mask]
|
||||||
|
square_zbuf = square_fragments.zbuf[square_mask]
|
||||||
|
|
||||||
|
# Retain gradients on the output of fragments to check
|
||||||
|
# intermediate values with the non square outputs.
|
||||||
|
square_fragments.dists.retain_grad()
|
||||||
|
square_fragments.zbuf.retain_grad()
|
||||||
|
|
||||||
|
# Calculate gradient for the square image
|
||||||
|
torch.manual_seed(231)
|
||||||
|
grad_zbuf = torch.randn_like(square_zbuf)
|
||||||
|
grad_dist = torch.randn_like(square_dists)
|
||||||
|
loss0 = (grad_dist * square_dists).sum() + (grad_zbuf * square_zbuf).sum()
|
||||||
|
loss0.backward()
|
||||||
|
|
||||||
|
# Now compare against the non square outputs provided
|
||||||
|
# in the nonsq_fragment_gradtensor_list list
|
||||||
|
for fragments, grad_tensor, _name in nonsq_fragment_gradtensor_list:
|
||||||
|
# Check that there are the same number of non zero pixels
|
||||||
|
# in both the square and non square images.
|
||||||
|
non_square_mask = fragments.idx > -1
|
||||||
|
self.assertEqual(non_square_mask.sum().item(), square_mask.sum().item())
|
||||||
|
|
||||||
|
# Check dists, zbuf and bary match the square image
|
||||||
|
non_square_dists = fragments.dists[non_square_mask]
|
||||||
|
non_square_zbuf = fragments.zbuf[non_square_mask]
|
||||||
|
self.assertClose(square_dists, non_square_dists)
|
||||||
|
self.assertClose(square_zbuf, non_square_zbuf)
|
||||||
|
|
||||||
|
# Retain gradients to compare values with outputs from
|
||||||
|
# square image
|
||||||
|
fragments.dists.retain_grad()
|
||||||
|
fragments.zbuf.retain_grad()
|
||||||
|
loss1 = (grad_dist * non_square_dists).sum() + (
|
||||||
|
grad_zbuf * non_square_zbuf
|
||||||
|
).sum()
|
||||||
|
loss1.sum().backward()
|
||||||
|
|
||||||
|
# Get the non zero values in the intermediate gradients
|
||||||
|
# and compare with the values from the square image
|
||||||
|
non_square_grad_dists = fragments.dists.grad[non_square_mask]
|
||||||
|
non_square_grad_zbuf = fragments.zbuf.grad[non_square_mask]
|
||||||
|
|
||||||
|
self.assertClose(
|
||||||
|
non_square_grad_dists,
|
||||||
|
square_fragments.dists.grad[square_mask],
|
||||||
|
)
|
||||||
|
self.assertClose(
|
||||||
|
non_square_grad_zbuf,
|
||||||
|
square_fragments.zbuf.grad[square_mask],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Finally check the gradients of the input vertices for
|
||||||
|
# the square and non square case
|
||||||
|
self.assertClose(points_square.grad, grad_tensor.grad, rtol=2e-4)
|
||||||
|
|
||||||
|
def test_gpu(self):
|
||||||
|
"""
|
||||||
|
Test that the output of rendering non square images
|
||||||
|
gives the same result as square images. i.e. the
|
||||||
|
dists, zbuf, idx are all the same for the square
|
||||||
|
region which is present in both images.
|
||||||
|
"""
|
||||||
|
# Test both cases: (W > H), (H > W)
|
||||||
|
image_sizes = [(64, 128), (128, 64), (128, 256), (256, 128)]
|
||||||
|
|
||||||
|
devices = ["cuda:0"]
|
||||||
|
blurs = [5e-2]
|
||||||
|
batch_sizes = [1, 4]
|
||||||
|
test_cases = product(image_sizes, blurs, devices, batch_sizes)
|
||||||
|
|
||||||
|
for image_size, blur, device, batch_size in test_cases:
|
||||||
|
# Initialize the verts grad tensor and the meshes objects
|
||||||
|
verts_nonsq_naive, pointcloud_nonsq_naive = self._clone_pointcloud(
|
||||||
|
torus_points, device, batch_size
|
||||||
|
)
|
||||||
|
verts_nonsq_binned, pointcloud_nonsq_binned = self._clone_pointcloud(
|
||||||
|
torus_points, device, batch_size
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get the outputs for both naive and coarse to fine rasterization
|
||||||
|
fragments_naive = self._rasterize(
|
||||||
|
pointcloud_nonsq_naive,
|
||||||
|
image_size,
|
||||||
|
blur=blur,
|
||||||
|
bin_size=0,
|
||||||
|
)
|
||||||
|
fragments_binned = self._rasterize(
|
||||||
|
pointcloud_nonsq_binned,
|
||||||
|
image_size,
|
||||||
|
blur=blur,
|
||||||
|
bin_size=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Save out debug images if needed
|
||||||
|
_save_debug_image(fragments_naive.idx, image_size, 0, blur)
|
||||||
|
_save_debug_image(fragments_binned.idx, image_size, None, blur)
|
||||||
|
|
||||||
|
# Check naive and binned fragments give the same outputs
|
||||||
|
self._check_fragments(fragments_naive, fragments_binned)
|
||||||
|
|
||||||
|
# Here we want to compare the square image with the naive and the
|
||||||
|
# coarse to fine methods outputs
|
||||||
|
nonsq_fragment_gradtensor_list = [
|
||||||
|
(fragments_naive, verts_nonsq_naive, "naive"),
|
||||||
|
(fragments_binned, verts_nonsq_binned, "coarse-to-fine"),
|
||||||
|
]
|
||||||
|
|
||||||
|
self._compare_square_with_nonsq(
|
||||||
|
image_size,
|
||||||
|
blur,
|
||||||
|
device,
|
||||||
|
torus_points,
|
||||||
|
nonsq_fragment_gradtensor_list,
|
||||||
|
batch_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_cpu(self):
|
||||||
|
"""
|
||||||
|
Test that the output of rendering non square images
|
||||||
|
gives the same result as square images. i.e. the
|
||||||
|
dists, zbuf, idx are all the same for the square
|
||||||
|
region which is present in both images.
|
||||||
|
|
||||||
|
In this test we compare between the naive C++ implementation
|
||||||
|
and the naive python implementation as the Coarse/Fine
|
||||||
|
method is not fully implemented in C++
|
||||||
|
"""
|
||||||
|
# Test both when (W > H) and (H > W).
|
||||||
|
# Using smaller image sizes here as the Python rasterizer is really slow.
|
||||||
|
image_sizes = [(32, 64), (64, 32)]
|
||||||
|
devices = ["cpu"]
|
||||||
|
blurs = [5e-2]
|
||||||
|
batch_sizes = [1]
|
||||||
|
test_cases = product(image_sizes, blurs, devices, batch_sizes)
|
||||||
|
|
||||||
|
for image_size, blur, device, batch_size in test_cases:
|
||||||
|
# Initialize the verts grad tensor and the meshes objects
|
||||||
|
verts_nonsq_naive, pointcloud_nonsq_naive = self._clone_pointcloud(
|
||||||
|
torus_points, device, batch_size
|
||||||
|
)
|
||||||
|
verts_nonsq_python, pointcloud_nonsq_python = self._clone_pointcloud(
|
||||||
|
torus_points, device, batch_size
|
||||||
|
)
|
||||||
|
|
||||||
|
# Compare Naive CPU with Python as Coarse/Fine rasteriztation
|
||||||
|
# is not implemented for CPU
|
||||||
|
fragments_naive = self._rasterize(
|
||||||
|
pointcloud_nonsq_naive, image_size, bin_size=0, blur=blur
|
||||||
|
)
|
||||||
|
idxs, zbuf, pix_dists = rasterize_points_python(
|
||||||
|
pointcloud_nonsq_python,
|
||||||
|
image_size,
|
||||||
|
blur,
|
||||||
|
points_per_pixel=1,
|
||||||
|
)
|
||||||
|
fragments_python = PointFragments(
|
||||||
|
idx=idxs,
|
||||||
|
zbuf=zbuf,
|
||||||
|
dists=pix_dists,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Save debug images if DEBUG is set to true at the top of the file.
|
||||||
|
_save_debug_image(fragments_naive.idx, image_size, 0, blur)
|
||||||
|
_save_debug_image(fragments_python.idx, image_size, "python", blur)
|
||||||
|
|
||||||
|
# List of non square outputs to compare with the square output
|
||||||
|
nonsq_fragment_gradtensor_list = [
|
||||||
|
(fragments_naive, verts_nonsq_naive, "naive"),
|
||||||
|
(fragments_python, verts_nonsq_python, "python"),
|
||||||
|
]
|
||||||
|
self._compare_square_with_nonsq(
|
||||||
|
image_size,
|
||||||
|
blur,
|
||||||
|
device,
|
||||||
|
torus_points,
|
||||||
|
nonsq_fragment_gradtensor_list,
|
||||||
|
batch_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_render_pointcloud(self):
|
||||||
|
"""
|
||||||
|
Test a textured poincloud is rendered correctly in a non square image.
|
||||||
|
"""
|
||||||
|
device = torch.device("cuda:0")
|
||||||
|
pointclouds = Pointclouds(
|
||||||
|
points=[torus_points * 2.0],
|
||||||
|
features=torch.ones_like(torus_points[None, ...]),
|
||||||
|
).to(device)
|
||||||
|
R, T = look_at_view_transform(2.7, 0.0, 0.0)
|
||||||
|
cameras = FoVPerspectiveCameras(device=device, R=R, T=T)
|
||||||
|
raster_settings = PointsRasterizationSettings(
|
||||||
|
image_size=(512, 1024), radius=5e-2, points_per_pixel=1
|
||||||
|
)
|
||||||
|
rasterizer = PointsRasterizer(cameras=cameras, raster_settings=raster_settings)
|
||||||
|
compositor = AlphaCompositor()
|
||||||
|
renderer = PointsRenderer(rasterizer=rasterizer, compositor=compositor)
|
||||||
|
|
||||||
|
# Load reference image
|
||||||
|
image_ref = load_rgb_image("test_pointcloud_rectangle_image.png", DATA_DIR)
|
||||||
|
|
||||||
|
for bin_size in [0, None]:
|
||||||
|
# Check both naive and coarse to fine produce the same output.
|
||||||
|
renderer.rasterizer.raster_settings.bin_size = bin_size
|
||||||
|
images = renderer(pointclouds)
|
||||||
|
rgb = images[0, ..., :3].squeeze().cpu()
|
||||||
|
|
||||||
|
if DEBUG:
|
||||||
|
Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
|
||||||
|
DATA_DIR / "DEBUG_pointcloud_rectangle_image.png"
|
||||||
|
)
|
||||||
|
|
||||||
|
# NOTE some pixels can be flaky
|
||||||
|
cond1 = torch.allclose(rgb, image_ref, atol=0.05)
|
||||||
|
self.assertTrue(cond1)
|
Loading…
x
Reference in New Issue
Block a user