From 3d769a66cb184d75126600abeb4ad953cd56cb8d Mon Sep 17 00:00:00 2001 From: Nikhila Ravi Date: Tue, 15 Dec 2020 14:14:27 -0800 Subject: [PATCH] Non Square image rasterization for pointclouds Summary: Similar to non square image rasterization for meshes, apply the same updates to the pointcloud rasterizer. Main API Change: - PointRasterizationSettings now accepts a tuple/list of (H, W) for the image size. Reviewed By: jcjohnson Differential Revision: D25465206 fbshipit-source-id: 7370d83c431af1b972158cecae19d82364623380 --- pytorch3d/csrc/compositing/alpha_composite.cu | 16 +- pytorch3d/csrc/compositing/alpha_composite.h | 8 +- .../csrc/compositing/norm_weighted_sum.cu | 14 +- .../csrc/compositing/norm_weighted_sum.h | 8 +- pytorch3d/csrc/compositing/weighted_sum.cu | 16 +- pytorch3d/csrc/compositing/weighted_sum.h | 6 +- .../csrc/rasterize_meshes/rasterize_meshes.cu | 17 +- .../rasterize_meshes/rasterize_meshes_cpu.cpp | 32 +- .../rasterize_points/rasterization_utils.cuh | 12 +- .../rasterize_points/rasterization_utils.h | 34 ++ .../csrc/rasterize_points/rasterize_points.cu | 186 +++++---- .../csrc/rasterize_points/rasterize_points.h | 30 +- .../rasterize_points/rasterize_points_cpu.cpp | 78 ++-- pytorch3d/renderer/mesh/__init__.py | 1 + pytorch3d/renderer/mesh/rasterize_meshes.py | 11 +- pytorch3d/renderer/points/compositor.py | 2 +- pytorch3d/renderer/points/rasterize_points.py | 88 +++- pytorch3d/renderer/points/rasterizer.py | 4 +- tests/bm_rasterize_points.py | 15 + .../data/test_pointcloud_rectangle_image.png | Bin 0 -> 20251 bytes tests/test_rasterize_points.py | 4 +- ....py => test_rasterize_rectangle_images.py} | 393 +++++++++++++++++- 22 files changed, 712 insertions(+), 263 deletions(-) create mode 100644 pytorch3d/csrc/rasterize_points/rasterization_utils.h create mode 100644 tests/data/test_pointcloud_rectangle_image.png rename tests/{test_rasterize_rectangles.py => test_rasterize_rectangle_images.py} (51%) diff --git 
a/pytorch3d/csrc/compositing/alpha_composite.cu b/pytorch3d/csrc/compositing/alpha_composite.cu index 389b95f8..16fadb02 100644 --- a/pytorch3d/csrc/compositing/alpha_composite.cu +++ b/pytorch3d/csrc/compositing/alpha_composite.cu @@ -30,15 +30,15 @@ __global__ void alphaCompositeCudaForwardKernel( // Get the batch and index const int batch = blockIdx.x; - const int num_pixels = C * W * H; + const int num_pixels = C * H * W; const int num_threads = gridDim.y * blockDim.x; const int tid = blockIdx.y * blockDim.x + threadIdx.x; // Iterate over each feature in each pixel for (int pid = tid; pid < num_pixels; pid += num_threads) { - int ch = pid / (W * H); - int j = (pid % (W * H)) / H; - int i = (pid % (W * H)) % H; + int ch = pid / (H * W); + int j = (pid % (H * W)) / W; + int i = (pid % (H * W)) % W; // alphacomposite the different values float cum_alpha = 1.; @@ -81,16 +81,16 @@ __global__ void alphaCompositeCudaBackwardKernel( // Get the batch and index const int batch = blockIdx.x; - const int num_pixels = C * W * H; + const int num_pixels = C * H * W; const int num_threads = gridDim.y * blockDim.x; const int tid = blockIdx.y * blockDim.x + threadIdx.x; // Parallelize over each feature in each pixel in images of size H * W, // for each image in the batch of size batch_size for (int pid = tid; pid < num_pixels; pid += num_threads) { - int ch = pid / (W * H); - int j = (pid % (W * H)) / H; - int i = (pid % (W * H)) % H; + int ch = pid / (H * W); + int j = (pid % (H * W)) / W; + int i = (pid % (H * W)) % W; // alphacomposite the different values float cum_alpha = 1.; diff --git a/pytorch3d/csrc/compositing/alpha_composite.h b/pytorch3d/csrc/compositing/alpha_composite.h index 735d87e1..c910c32d 100644 --- a/pytorch3d/csrc/compositing/alpha_composite.h +++ b/pytorch3d/csrc/compositing/alpha_composite.h @@ -11,13 +11,13 @@ // features: FloatTensor of shape (C, P) which gives the features // of each point where C is the size of the feature and // P the number of 
points. -// alphas: FloatTensor of shape (N, points_per_pixel, W, W) where +// alphas: FloatTensor of shape (N, points_per_pixel, H, W) where // points_per_pixel is the number of points in the z-buffer -// sorted in z-order, and W is the image size. -// points_idx: IntTensor of shape (N, points_per_pixel, W, W) giving the +// sorted in z-order, and (H, W) is the image size. +// points_idx: IntTensor of shape (N, points_per_pixel, H, W) giving the // indices of the nearest points at each pixel, sorted in z-order. // Returns: -// weighted_fs: FloatTensor of shape (N, C, W, W) giving the accumulated +// weighted_fs: FloatTensor of shape (N, C, H, W) giving the accumulated // feature for each point. Concretely, it gives: // weighted_fs[b,c,i,j] = sum_k cum_alpha_k * // features[c,points_idx[b,k,i,j]] diff --git a/pytorch3d/csrc/compositing/norm_weighted_sum.cu b/pytorch3d/csrc/compositing/norm_weighted_sum.cu index a787e1fa..1885bec6 100644 --- a/pytorch3d/csrc/compositing/norm_weighted_sum.cu +++ b/pytorch3d/csrc/compositing/norm_weighted_sum.cu @@ -30,16 +30,16 @@ __global__ void weightedSumNormCudaForwardKernel( // Get the batch and index const int batch = blockIdx.x; - const int num_pixels = C * W * H; + const int num_pixels = C * H * W; const int num_threads = gridDim.y * blockDim.x; const int tid = blockIdx.y * blockDim.x + threadIdx.x; // Parallelize over each feature in each pixel in images of size H * W, // for each image in the batch of size batch_size for (int pid = tid; pid < num_pixels; pid += num_threads) { - int ch = pid / (W * H); - int j = (pid % (W * H)) / H; - int i = (pid % (W * H)) % H; + int ch = pid / (H * W); + int j = (pid % (H * W)) / W; + int i = (pid % (H * W)) % W; // Store the accumulated alpha value float cum_alpha = 0.; @@ -101,9 +101,9 @@ __global__ void weightedSumNormCudaBackwardKernel( // Parallelize over each feature in each pixel in images of size H * W, // for each image in the batch of size batch_size for (int pid = tid; pid < 
num_pixels; pid += num_threads) { - int ch = pid / (W * H); - int j = (pid % (W * H)) / H; - int i = (pid % (W * H)) % H; + int ch = pid / (H * W); + int j = (pid % (H * W)) / W; + int i = (pid % (H * W)) % W; float sum_alpha = 0.; float sum_alphafs = 0.; diff --git a/pytorch3d/csrc/compositing/norm_weighted_sum.h b/pytorch3d/csrc/compositing/norm_weighted_sum.h index 34c271bc..c2878503 100644 --- a/pytorch3d/csrc/compositing/norm_weighted_sum.h +++ b/pytorch3d/csrc/compositing/norm_weighted_sum.h @@ -11,13 +11,13 @@ // features: FloatTensor of shape (C, P) which gives the features // of each point where C is the size of the feature and // P the number of points. -// alphas: FloatTensor of shape (N, points_per_pixel, W, W) where +// alphas: FloatTensor of shape (N, points_per_pixel, H, W) where // points_per_pixel is the number of points in the z-buffer -// sorted in z-order, and W is the image size. -// points_idx: IntTensor of shape (N, points_per_pixel, W, W) giving the +// sorted in z-order, and (H, W) is the image size. +// points_idx: IntTensor of shape (N, points_per_pixel, H, W) giving the // indices of the nearest points at each pixel, sorted in z-order. // Returns: -// weighted_fs: FloatTensor of shape (N, C, W, W) giving the accumulated +// weighted_fs: FloatTensor of shape (N, C, H, W) giving the accumulated // feature in each point. 
Concretely, it gives: // weighted_fs[b,c,i,j] = sum_k alphas[b,k,i,j] * // features[c,points_idx[b,k,i,j]] / sum_k alphas[b,k,i,j] diff --git a/pytorch3d/csrc/compositing/weighted_sum.cu b/pytorch3d/csrc/compositing/weighted_sum.cu index 68ec351e..cee8e75a 100644 --- a/pytorch3d/csrc/compositing/weighted_sum.cu +++ b/pytorch3d/csrc/compositing/weighted_sum.cu @@ -28,16 +28,16 @@ __global__ void weightedSumCudaForwardKernel( // Get the batch and index const int batch = blockIdx.x; - const int num_pixels = C * W * H; + const int num_pixels = C * H * W; const int num_threads = gridDim.y * blockDim.x; const int tid = blockIdx.y * blockDim.x + threadIdx.x; // Parallelize over each feature in each pixel in images of size H * W, // for each image in the batch of size batch_size for (int pid = tid; pid < num_pixels; pid += num_threads) { - int ch = pid / (W * H); - int j = (pid % (W * H)) / H; - int i = (pid % (W * H)) % H; + int ch = pid / (H * W); + int j = (pid % (H * W)) / W; + int i = (pid % (H * W)) % W; // Iterate through the closest K points for this pixel for (int k = 0; k < points_idx.size(1); ++k) { @@ -76,16 +76,16 @@ __global__ void weightedSumCudaBackwardKernel( // Get the batch and index const int batch = blockIdx.x; - const int num_pixels = C * W * H; + const int num_pixels = C * H * W; const int num_threads = gridDim.y * blockDim.x; const int tid = blockIdx.y * blockDim.x + threadIdx.x; // Iterate over each pixel to compute the contribution to the // gradient for the features and weights for (int pid = tid; pid < num_pixels; pid += num_threads) { - int ch = pid / (W * H); - int j = (pid % (W * H)) / H; - int i = (pid % (W * H)) % H; + int ch = pid / (H * W); + int j = (pid % (H * W)) / W; + int i = (pid % (H * W)) % W; // Iterate through the closest K points for this pixel for (int k = 0; k < points_idx.size(1); ++k) { diff --git a/pytorch3d/csrc/compositing/weighted_sum.h b/pytorch3d/csrc/compositing/weighted_sum.h index 4928a252..aa4154ed 100644 --- 
a/pytorch3d/csrc/compositing/weighted_sum.h +++ b/pytorch3d/csrc/compositing/weighted_sum.h @@ -11,13 +11,13 @@ // features: FloatTensor of shape (C, P) which gives the features // of each point where C is the size of the feature and // P the number of points. -// alphas: FloatTensor of shape (N, points_per_pixel, W, W) where +// alphas: FloatTensor of shape (N, points_per_pixel, H, W) where // points_per_pixel is the number of points in the z-buffer -// sorted in z-order, and W is the image size. +// sorted in z-order, and (H, W) is the image size. // points_idx: IntTensor of shape (N, points_per_pixel, W, W) giving the // indices of the nearest points at each pixel, sorted in z-order. // Returns: -// weighted_fs: FloatTensor of shape (N, C, W, W) giving the accumulated +// weighted_fs: FloatTensor of shape (N, C, H, W) giving the accumulated // feature in each point. Concretely, it gives: // weighted_fs[b,c,i,j] = sum_k alphas[b,k,i,j] * // features[c,points_idx[b,k,i,j]] diff --git a/pytorch3d/csrc/rasterize_meshes/rasterize_meshes.cu b/pytorch3d/csrc/rasterize_meshes/rasterize_meshes.cu index af973f38..a92a64e5 100644 --- a/pytorch3d/csrc/rasterize_meshes/rasterize_meshes.cu +++ b/pytorch3d/csrc/rasterize_meshes/rasterize_meshes.cu @@ -452,7 +452,6 @@ __global__ void RasterizeMeshesBackwardCudaKernel( const bool inside = b_pp.x > 0.0f && b_pp.y > 0.0f && b_pp.z > 0.0f; const float sign = inside ? -1.0f : 1.0f; - // TODO(T52813608) Add support for non-square images. 
auto grad_dist_f = PointTriangleDistanceBackward( pxy, v0xy, v1xy, v2xy, sign * grad_dist_upstream); const float2 ddist_d_v0 = thrust::get<1>(grad_dist_f); @@ -606,7 +605,7 @@ __global__ void RasterizeMeshesCoarseCudaKernel( const float half_pix_x = NDC_x_half_range / W; const float half_pix_y = NDC_y_half_range / H; - // This is a boolean array of shape (num_bins, num_bins, chunk_size) + // This is a boolean array of shape (num_bins_y, num_bins_x, chunk_size) // stored in shared memory that will track whether each point in the chunk // falls into each bin of the image. BitMask binmask((unsigned int*)sbuf, num_bins_y, num_bins_x, chunk_size); @@ -755,7 +754,7 @@ at::Tensor RasterizeMeshesCoarseCuda( const int num_bins_y = 1 + (H - 1) / bin_size; const int num_bins_x = 1 + (W - 1) / bin_size; - if (num_bins_y >= kMaxFacesPerBin || num_bins_x >= kMaxFacesPerBin) { + if (num_bins_y >= kMaxItemsPerBin || num_bins_x >= kMaxItemsPerBin) { std::stringstream ss; ss << "In Coarse Rasterizer got num_bins_y: " << num_bins_y << ", num_bins_x: " << num_bins_x << ", " @@ -800,7 +799,7 @@ at::Tensor RasterizeMeshesCoarseCuda( // **************************************************************************** __global__ void RasterizeMeshesFineCudaKernel( const float* face_verts, // (F, 3, 3) - const int32_t* bin_faces, // (N, B, B, T) + const int32_t* bin_faces, // (N, BH, BW, T) const float blur_radius, const int bin_size, const bool perspective_correct, @@ -813,12 +812,12 @@ __global__ void RasterizeMeshesFineCudaKernel( const int H, const int W, const int K, - int64_t* face_idxs, // (N, S, S, K) - float* zbuf, // (N, S, S, K) - float* pix_dists, // (N, S, S, K) - float* bary // (N, S, S, K, 3) + int64_t* face_idxs, // (N, H, W, K) + float* zbuf, // (N, H, W, K) + float* pix_dists, // (N, H, W, K) + float* bary // (N, H, W, K, 3) ) { - // This can be more than S^2 if S % bin_size != 0 + // This can be more than H * W if H or W are not divisible by bin_size. 
int num_pixels = N * BH * BW * bin_size * bin_size; int num_threads = gridDim.x * blockDim.x; int tid = blockIdx.x * blockDim.x + threadIdx.x; diff --git a/pytorch3d/csrc/rasterize_meshes/rasterize_meshes_cpu.cpp b/pytorch3d/csrc/rasterize_meshes/rasterize_meshes_cpu.cpp index b8a73e20..3160e685 100644 --- a/pytorch3d/csrc/rasterize_meshes/rasterize_meshes_cpu.cpp +++ b/pytorch3d/csrc/rasterize_meshes/rasterize_meshes_cpu.cpp @@ -5,41 +5,11 @@ #include #include #include +#include "rasterize_points/rasterization_utils.h" #include "utils/geometry_utils.h" #include "utils/vec2.h" #include "utils/vec3.h" -// The default value of the NDC range is [-1, 1], however in the case that -// H != W, the NDC range is set such that the shorter side has range [-1, 1] and -// the longer side is scaled by the ratio of H:W. S1 is the dimension for which -// the NDC range is calculated and S2 is the other image dimension. -// e.g. to get the NDC x range S1 = W and S2 = H -float NonSquareNdcRange(int S1, int S2) { - float range = 2.0f; - if (S1 > S2) { - range = ((S1 / S2) * range); - } - return range; -} - -// Given a pixel coordinate 0 <= i < S1, convert it to a normalized device -// coordinates. We divide the NDC range into S1 evenly-sized -// pixels, and assume that each pixel falls in the *center* of its range. -// The default value of the NDC range is [-1, 1], however in the case that -// H != W, the NDC range is set such that the shorter side has range [-1, 1] and -// the longer side is scaled by the ratio of H:W. The dimension of i should be -// S1 and the other image dimension is S2 For example, to get the x and y NDC -// coordinates or a given pixel i: -// x = PixToNonSquareNdc(i, W, H) -// y = PixToNonSquareNdc(i, H, W) -float PixToNonSquareNdc(int i, int S1, int S2) { - float range = NonSquareNdcRange(S1, S2); - // NDC: offset + (i * pixel_width + half_pixel_width) - // The NDC range is [-range/2, range/2]. 
- const float offset = (range / 2.0f); - return -offset + (range * i + offset) / S1; -} - // Get (x, y, z) values for vertex from (3, 3) tensor face. template auto ExtractVerts(const Face& face, const int vertex_index) { diff --git a/pytorch3d/csrc/rasterize_points/rasterization_utils.cuh b/pytorch3d/csrc/rasterize_points/rasterization_utils.cuh index 8492bad1..18272ab7 100644 --- a/pytorch3d/csrc/rasterize_points/rasterization_utils.cuh +++ b/pytorch3d/csrc/rasterize_points/rasterization_utils.cuh @@ -2,16 +2,6 @@ #pragma once -// Given a pixel coordinate 0 <= i < S, convert it to a normalized device -// coordinates in the range [-1, 1]. We divide the NDC range into S evenly-sized -// pixels, and assume that each pixel falls in the *center* of its range. -// TODO: delete this function after updating the pointcloud rasterizer to -// support non square images. -__device__ inline float PixToNdc(int i, int S) { - // NDC: x-offset + (i * pixel_width + half_pixel_width) - return -1.0 + (2 * i + 1.0) / S; -} - // The default value of the NDC range is [-1, 1], however in the case that // H != W, the NDC range is set such that the shorter side has range [-1, 1] and // the longer side is scaled by the ratio of H:W. S1 is the dimension for which @@ -50,7 +40,7 @@ __device__ inline float PixToNonSquareNdc(int i, int S1, int S2) { // TODO: is 8 enough? Would increasing have performance considerations? const int32_t kMaxPointsPerPixel = 150; -const int32_t kMaxFacesPerBin = 22; +const int32_t kMaxItemsPerBin = 22; template __device__ inline void BubbleSort(T* arr, int n) { diff --git a/pytorch3d/csrc/rasterize_points/rasterization_utils.h b/pytorch3d/csrc/rasterize_points/rasterization_utils.h new file mode 100644 index 00000000..06b6bc5c --- /dev/null +++ b/pytorch3d/csrc/rasterize_points/rasterization_utils.h @@ -0,0 +1,34 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 
+ +#pragma once + +// The default value of the NDC range is [-1, 1], however in the case that +// H != W, the NDC range is set such that the shorter side has range [-1, 1] and +// the longer side is scaled by the ratio of the two sides, computed in floating +// point. S1 is the dimension for which the NDC range is calculated and S2 is +// the other image dimension, e.g. to get the NDC x range S1 = W and S2 = H. +inline float NonSquareNdcRange(int S1, int S2) { + float range = 2.0f; + if (S1 > S2) { + range = (S1 * range) / S2; + } + return range; +} + +// Given a pixel coordinate 0 <= i < S1, convert it to a normalized device +// coordinates. We divide the NDC range into S1 evenly-sized +// pixels, and assume that each pixel falls in the *center* of its range. +// The default value of the NDC range is [-1, 1], however in the case that +// H != W, the NDC range is set such that the shorter side has range [-1, 1] and +// the longer side is scaled by the ratio of H:W. The dimension of i should be +// S1 and the other image dimension is S2. For example, to get the x and y NDC +// coordinates of a given pixel i: +// x = PixToNonSquareNdc(i, W, H) +// y = PixToNonSquareNdc(i, H, W) +inline float PixToNonSquareNdc(int i, int S1, int S2) { + float range = NonSquareNdcRange(S1, S2); + // NDC: offset + (i * pixel_width + half_pixel_width) + // The NDC range is [-range/2, range/2]. 
+ const float offset = (range / 2.0f); + return -offset + (range * i + offset) / S1; +} diff --git a/pytorch3d/csrc/rasterize_points/rasterize_points.cu b/pytorch3d/csrc/rasterize_points/rasterize_points.cu index d02a5680..8b5ea133 100644 --- a/pytorch3d/csrc/rasterize_points/rasterize_points.cu +++ b/pytorch3d/csrc/rasterize_points/rasterize_points.cu @@ -85,26 +85,28 @@ __global__ void RasterizePointsNaiveCudaKernel( const int64_t* num_points_per_cloud, // (N) const float* radius, const int N, - const int S, + const int H, + const int W, const int K, - int32_t* point_idxs, // (N, S, S, K) - float* zbuf, // (N, S, S, K) - float* pix_dists) { // (N, S, S, K) + int32_t* point_idxs, // (N, H, W, K) + float* zbuf, // (N, H, W, K) + float* pix_dists) { // (N, H, W, K) // Simple version: One thread per output pixel const int num_threads = gridDim.x * blockDim.x; const int tid = blockDim.x * blockIdx.x + threadIdx.x; - for (int i = tid; i < N * S * S; i += num_threads) { + for (int i = tid; i < N * H * W; i += num_threads) { // Convert linear index to 3D index - const int n = i / (S * S); // Batch index - const int pix_idx = i % (S * S); + const int n = i / (H * W); // Batch index + const int pix_idx = i % (H * W); // Reverse ordering of the X and Y axis as the camera coordinates // assume that +Y is pointing up and +X is pointing left. - const int yi = S - 1 - pix_idx / S; - const int xi = S - 1 - pix_idx % S; + const int yi = H - 1 - pix_idx / W; + const int xi = W - 1 - pix_idx % W; - const float xf = PixToNdc(xi, S); - const float yf = PixToNdc(yi, S); + // screen coordinates to ndc coordinates of pixel. 
+ const float xf = PixToNonSquareNdc(xi, W, H); + const float yf = PixToNonSquareNdc(yi, H, W); // For keeping track of the K closest points we want a data structure // that (1) gives O(1) access to the closest point for easy comparisons, @@ -132,7 +134,7 @@ __global__ void RasterizePointsNaiveCudaKernel( points, p_idx, q_size, q_max_z, q_max_idx, q, radius, xf, yf, K); } BubbleSort(q, q_size); - int idx = n * S * S * K + pix_idx * K; + int idx = n * H * W * K + pix_idx * K; for (int k = 0; k < q_size; ++k) { point_idxs[idx + k] = q[k].idx; zbuf[idx + k] = q[k].z; @@ -145,7 +147,7 @@ std::tuple RasterizePointsNaiveCuda( const at::Tensor& points, // (P. 3) const at::Tensor& cloud_to_packed_first_idx, // (N) const at::Tensor& num_points_per_cloud, // (N) - const int image_size, + const std::tuple image_size, const at::Tensor& radius, const int points_per_pixel) { // Check inputs are on the same device @@ -169,7 +171,8 @@ std::tuple RasterizePointsNaiveCuda( "num_points_per_cloud must have same size first dimension as cloud_to_packed_first_idx"); const int N = num_points_per_cloud.size(0); // batch size. 
- const int S = image_size; + const int H = std::get<0>(image_size); + const int W = std::get<1>(image_size); const int K = points_per_pixel; if (K > kMaxPointsPerPixel) { @@ -180,9 +183,9 @@ std::tuple RasterizePointsNaiveCuda( auto int_opts = num_points_per_cloud.options().dtype(at::kInt); auto float_opts = points.options().dtype(at::kFloat); - at::Tensor point_idxs = at::full({N, S, S, K}, -1, int_opts); - at::Tensor zbuf = at::full({N, S, S, K}, -1, float_opts); - at::Tensor pix_dists = at::full({N, S, S, K}, -1, float_opts); + at::Tensor point_idxs = at::full({N, H, W, K}, -1, int_opts); + at::Tensor zbuf = at::full({N, H, W, K}, -1, float_opts); + at::Tensor pix_dists = at::full({N, H, W, K}, -1, float_opts); if (point_idxs.numel() == 0) { AT_CUDA_CHECK(cudaGetLastError()); @@ -197,7 +200,8 @@ std::tuple RasterizePointsNaiveCuda( num_points_per_cloud.contiguous().data_ptr(), radius.contiguous().data_ptr(), N, - S, + H, + W, K, point_idxs.contiguous().data_ptr(), zbuf.contiguous().data_ptr(), @@ -218,7 +222,8 @@ __global__ void RasterizePointsCoarseCudaKernel( const float* radius, const int N, const int P, - const int S, + const int H, + const int W, const int bin_size, const int chunk_size, const int max_points_per_bin, @@ -226,13 +231,26 @@ __global__ void RasterizePointsCoarseCudaKernel( int* bin_points) { extern __shared__ char sbuf[]; const int M = max_points_per_bin; - const int num_bins = 1 + (S - 1) / bin_size; // Integer divide round up - const float half_pix = 1.0f / S; // Size of half a pixel in NDC units - // This is a boolean array of shape (num_bins, num_bins, chunk_size) + // Integer divide round up + const int num_bins_x = 1 + (W - 1) / bin_size; + const int num_bins_y = 1 + (H - 1) / bin_size; + + // NDC range depends on the ratio of W/H + // The shorter side from (H, W) is given an NDC range of 2.0 and + // the other side is scaled by the ratio of H:W. 
+ const float NDC_x_half_range = NonSquareNdcRange(W, H) / 2.0f; + const float NDC_y_half_range = NonSquareNdcRange(H, W) / 2.0f; + + // Size of half a pixel in NDC units is the NDC half range + // divided by the corresponding image dimension + const float half_pix_x = NDC_x_half_range / W; + const float half_pix_y = NDC_y_half_range / H; + + // This is a boolean array of shape (num_bins_y, num_bins_x, chunk_size) // stored in shared memory that will track whether each point in the chunk // falls into each bin of the image. - BitMask binmask((unsigned int*)sbuf, num_bins, num_bins, chunk_size); + BitMask binmask((unsigned int*)sbuf, num_bins_y, num_bins_x, chunk_size); // Have each block handle a chunk of points and build a 3D bitmask in // shared memory to mark which points hit which bins. In this first phase, @@ -279,22 +297,24 @@ __global__ void RasterizePointsCoarseCudaKernel( // For example we could compute the exact bin where the point falls, // then check neighboring bins. This way we wouldn't have to check // all bins (however then we might have more warp divergence?) - for (int by = 0; by < num_bins; ++by) { - // Get y extent for the bin. PixToNdc gives us the location of + for (int by = 0; by < num_bins_y; ++by) { + // Get y extent for the bin. PixToNonSquareNdc gives us the location of // the center of each pixel, so we need to add/subtract a half // pixel to get the true extent of the bin. - const float by0 = PixToNdc(by * bin_size, S) - half_pix; - const float by1 = PixToNdc((by + 1) * bin_size - 1, S) + half_pix; + const float by0 = PixToNonSquareNdc(by * bin_size, H, W) - half_pix_y; + const float by1 = + PixToNonSquareNdc((by + 1) * bin_size - 1, H, W) + half_pix_y; const bool y_overlap = (py0 <= by1) && (by0 <= py1); if (!y_overlap) { continue; } - for (int bx = 0; bx < num_bins; ++bx) { + for (int bx = 0; bx < num_bins_x; ++bx) { // Get x extent for the bin; again we need to adjust the - // output of PixToNdc by half a pixel. 
- const float bx0 = PixToNdc(bx * bin_size, S) - half_pix; - const float bx1 = PixToNdc((bx + 1) * bin_size - 1, S) + half_pix; + // output of PixToNonSquareNdc by half a pixel. + const float bx0 = PixToNonSquareNdc(bx * bin_size, W, H) - half_pix_x; + const float bx1 = + PixToNonSquareNdc((bx + 1) * bin_size - 1, W, H) + half_pix_x; const bool x_overlap = (px0 <= bx1) && (bx0 <= px1); if (x_overlap) { @@ -307,12 +327,13 @@ __global__ void RasterizePointsCoarseCudaKernel( // Now we have processed every point in the current chunk. We need to // count the number of points in each bin so we can write the indices // out to global memory. We have each thread handle a different bin. - for (int byx = threadIdx.x; byx < num_bins * num_bins; byx += blockDim.x) { - const int by = byx / num_bins; - const int bx = byx % num_bins; + for (int byx = threadIdx.x; byx < num_bins_y * num_bins_x; + byx += blockDim.x) { + const int by = byx / num_bins_x; + const int bx = byx % num_bins_x; const int count = binmask.count(by, bx); const int points_per_bin_idx = - batch_idx * num_bins * num_bins + by * num_bins + bx; + batch_idx * num_bins_y * num_bins_x + by * num_bins_x + bx; // This atomically increments the (global) number of points found // in the current bin, and gets the previous value of the counter; @@ -322,8 +343,8 @@ __global__ void RasterizePointsCoarseCudaKernel( // Now loop over the binmask and write the active bits for this bin // out to bin_points. 
- int next_idx = batch_idx * num_bins * num_bins * M + by * num_bins * M + - bx * M + start; + int next_idx = batch_idx * num_bins_y * num_bins_x * M + + by * num_bins_x * M + bx * M + start; for (int p = 0; p < chunk_size; ++p) { if (binmask.get(by, bx, p)) { // TODO: Throw an error if next_idx >= M -- this means that @@ -342,7 +363,7 @@ at::Tensor RasterizePointsCoarseCuda( const at::Tensor& points, // (P, 3) const at::Tensor& cloud_to_packed_first_idx, // (N) const at::Tensor& num_points_per_cloud, // (N) - const int image_size, + const std::tuple image_size, const at::Tensor& radius, const int bin_size, const int max_points_per_bin) { @@ -363,20 +384,28 @@ at::Tensor RasterizePointsCoarseCuda( at::cuda::CUDAGuard device_guard(points.device()); cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const int H = std::get<0>(image_size); + const int W = std::get<1>(image_size); + const int P = points.size(0); const int N = num_points_per_cloud.size(0); - const int num_bins = 1 + (image_size - 1) / bin_size; // divide round up const int M = max_points_per_bin; - if (num_bins >= 22) { + // Integer divide round up. + const int num_bins_y = 1 + (H - 1) / bin_size; + const int num_bins_x = 1 + (W - 1) / bin_size; + + if (num_bins_y >= kMaxItemsPerBin || num_bins_x >= kMaxItemsPerBin) { // Make sure we do not use too much shared memory. 
std::stringstream ss; - ss << "Got " << num_bins << "; that's too many!"; + ss << "In Coarse Rasterizer got num_bins_y: " << num_bins_y + << ", num_bins_x: " << num_bins_x << ", " + << "; that's too many!"; AT_ERROR(ss.str()); } auto opts = num_points_per_cloud.options().dtype(at::kInt); - at::Tensor points_per_bin = at::zeros({N, num_bins, num_bins}, opts); - at::Tensor bin_points = at::full({N, num_bins, num_bins, M}, -1, opts); + at::Tensor points_per_bin = at::zeros({N, num_bins_y, num_bins_x}, opts); + at::Tensor bin_points = at::full({N, num_bins_y, num_bins_x, M}, -1, opts); if (bin_points.numel() == 0) { AT_CUDA_CHECK(cudaGetLastError()); @@ -384,7 +413,7 @@ at::Tensor RasterizePointsCoarseCuda( } const int chunk_size = 512; - const size_t shared_size = num_bins * num_bins * chunk_size / 8; + const size_t shared_size = num_bins_y * num_bins_x * chunk_size / 8; const size_t blocks = 64; const size_t threads = 512; @@ -395,7 +424,8 @@ at::Tensor RasterizePointsCoarseCuda( radius.contiguous().data_ptr(), N, P, - image_size, + H, + W, bin_size, chunk_size, M, @@ -412,19 +442,21 @@ at::Tensor RasterizePointsCoarseCuda( __global__ void RasterizePointsFineCudaKernel( const float* points, // (P, 3) - const int32_t* bin_points, // (N, B, B, T) + const int32_t* bin_points, // (N, BH, BW, T) const float* radius, const int bin_size, const int N, - const int B, // num_bins + const int BH, // num_bins y + const int BW, // num_bins x const int M, - const int S, + const int H, + const int W, const int K, - int32_t* point_idxs, // (N, S, S, K) - float* zbuf, // (N, S, S, K) - float* pix_dists) { // (N, S, S, K) - // This can be more than S^2 if S is not dividable by bin_size. - const int num_pixels = N * B * B * bin_size * bin_size; + int32_t* point_idxs, // (N, H, W, K) + float* zbuf, // (N, H, W, K) + float* pix_dists) { // (N, H, W, K) + // This can be more than H * W if H or W are not divisible by bin_size. 
+ const int num_pixels = N * BH * BW * bin_size * bin_size; const int num_threads = gridDim.x * blockDim.x; const int tid = blockIdx.x * blockDim.x + threadIdx.x; @@ -434,21 +466,21 @@ __global__ void RasterizePointsFineCudaKernel( // into the same bin; this should give them coalesced memory reads when // they read from points and bin_points. int i = pid; - const int n = i / (B * B * bin_size * bin_size); - i %= B * B * bin_size * bin_size; - const int by = i / (B * bin_size * bin_size); - i %= B * bin_size * bin_size; + const int n = i / (BH * BW * bin_size * bin_size); + i %= BH * BW * bin_size * bin_size; + const int by = i / (BW * bin_size * bin_size); + i %= BW * bin_size * bin_size; const int bx = i / (bin_size * bin_size); i %= bin_size * bin_size; const int yi = i / bin_size + by * bin_size; const int xi = i % bin_size + bx * bin_size; - if (yi >= S || xi >= S) + if (yi >= H || xi >= W) continue; - const float xf = PixToNdc(xi, S); - const float yf = PixToNdc(yi, S); + const float xf = PixToNonSquareNdc(xi, W, H); + const float yf = PixToNonSquareNdc(yi, H, W); // This part looks like the naive rasterization kernel, except we use // bin_points to only look at a subset of points already known to fall @@ -459,7 +491,7 @@ __global__ void RasterizePointsFineCudaKernel( float q_max_z = -1000; int q_max_idx = -1; for (int m = 0; m < M; ++m) { - const int p = bin_points[n * B * B * M + by * B * M + bx * M + m]; + const int p = bin_points[n * BH * BW * M + by * BW * M + bx * M + m]; if (p < 0) { // bin_points uses -1 as a sentinal value continue; @@ -473,10 +505,10 @@ __global__ void RasterizePointsFineCudaKernel( // Reverse ordering of the X and Y axis as the camera coordinates // assume that +Y is pointing up and +X is pointing left. 
- const int yidx = S - 1 - yi; - const int xidx = S - 1 - xi; + const int yidx = H - 1 - yi; + const int xidx = W - 1 - xi; - const int pix_idx = n * S * S * K + yidx * S * K + xidx * K; + const int pix_idx = n * H * W * K + yidx * W * K + xidx * K; for (int k = 0; k < q_size; ++k) { point_idxs[pix_idx + k] = q[k].idx; zbuf[pix_idx + k] = q[k].z; @@ -488,7 +520,7 @@ __global__ void RasterizePointsFineCudaKernel( std::tuple RasterizePointsFineCuda( const at::Tensor& points, // (P, 3) const at::Tensor& bin_points, - const int image_size, + const std::tuple image_size, const at::Tensor& radius, const int bin_size, const int points_per_pixel) { @@ -503,18 +535,22 @@ std::tuple RasterizePointsFineCuda( cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const int N = bin_points.size(0); - const int B = bin_points.size(1); // num_bins + const int BH = bin_points.size(1); + const int BW = bin_points.size(2); const int M = bin_points.size(3); - const int S = image_size; const int K = points_per_pixel; + + const int H = std::get<0>(image_size); + const int W = std::get<1>(image_size); + if (K > kMaxPointsPerPixel) { AT_ERROR("Must have num_closest <= 150"); } auto int_opts = bin_points.options().dtype(at::kInt); auto float_opts = points.options().dtype(at::kFloat); - at::Tensor point_idxs = at::full({N, S, S, K}, -1, int_opts); - at::Tensor zbuf = at::full({N, S, S, K}, -1, float_opts); - at::Tensor pix_dists = at::full({N, S, S, K}, -1, float_opts); + at::Tensor point_idxs = at::full({N, H, W, K}, -1, int_opts); + at::Tensor zbuf = at::full({N, H, W, K}, -1, float_opts); + at::Tensor pix_dists = at::full({N, H, W, K}, -1, float_opts); if (point_idxs.numel() == 0) { AT_CUDA_CHECK(cudaGetLastError()); @@ -529,9 +565,11 @@ std::tuple RasterizePointsFineCuda( radius.contiguous().data_ptr(), bin_size, N, - B, + BH, + BW, M, - S, + H, + W, K, point_idxs.contiguous().data_ptr(), zbuf.contiguous().data_ptr(), @@ -571,8 +609,8 @@ __global__ void 
RasterizePointsBackwardCudaKernel( const int yidx = H - 1 - yi; const int xidx = W - 1 - xi; - const float xf = PixToNdc(xidx, W); - const float yf = PixToNdc(yidx, H); + const float xf = PixToNonSquareNdc(xidx, W, H); + const float yf = PixToNonSquareNdc(yidx, H, W); const int p = idxs[i]; if (p < 0) diff --git a/pytorch3d/csrc/rasterize_points/rasterize_points.h b/pytorch3d/csrc/rasterize_points/rasterize_points.h index f1ec1aaf..a13d9773 100644 --- a/pytorch3d/csrc/rasterize_points/rasterize_points.h +++ b/pytorch3d/csrc/rasterize_points/rasterize_points.h @@ -14,7 +14,7 @@ std::tuple RasterizePointsNaiveCpu( const torch::Tensor& points, const torch::Tensor& cloud_to_packed_first_idx, const torch::Tensor& num_points_per_cloud, - const int image_size, + const std::tuple image_size, const torch::Tensor& radius, const int points_per_pixel); @@ -24,7 +24,7 @@ RasterizePointsNaiveCuda( const torch::Tensor& points, const torch::Tensor& cloud_to_packed_first_idx, const torch::Tensor& num_points_per_cloud, - const int image_size, + const std::tuple image_size, const torch::Tensor& radius, const int points_per_pixel); #endif @@ -43,7 +43,8 @@ RasterizePointsNaiveCuda( // for each pointcloud in the batch. // radius: FloatTensor of shape (P) giving the radius (in NDC units) of // each point in points. -// image_size: (S) Size of the image to return (in pixels) +// image_size: Tuple (H, W) giving the size in pixels of the output +// image to be rasterized. 
// points_per_pixel: (K) The number closest of points to return for each pixel // // Returns: @@ -62,7 +63,7 @@ std::tuple RasterizePointsNaive( const torch::Tensor& points, const torch::Tensor& cloud_to_packed_first_idx, const torch::Tensor& num_points_per_cloud, - const int image_size, + const std::tuple image_size, const torch::Tensor& radius, const int points_per_pixel) { if (points.is_cuda() && cloud_to_packed_first_idx.is_cuda() && @@ -101,7 +102,7 @@ torch::Tensor RasterizePointsCoarseCpu( const torch::Tensor& points, const torch::Tensor& cloud_to_packed_first_idx, const torch::Tensor& num_points_per_cloud, - const int image_size, + const std::tuple image_size, const torch::Tensor& radius, const int bin_size, const int max_points_per_bin); @@ -111,7 +112,7 @@ torch::Tensor RasterizePointsCoarseCuda( const torch::Tensor& points, const torch::Tensor& cloud_to_packed_first_idx, const torch::Tensor& num_points_per_cloud, - const int image_size, + const std::tuple image_size, const torch::Tensor& radius, const int bin_size, const int max_points_per_bin); @@ -128,7 +129,8 @@ torch::Tensor RasterizePointsCoarseCuda( // for each pointcloud in the batch. // radius: FloatTensor of shape (P) giving the radius (in NDC units) of // each point in points. -// image_size: Size of the image to generate (in pixels) +// image_size: Tuple (H, W) giving the size in pixels of the output +// image to be rasterized. 
// bin_size: Size of each bin within the image (in pixels) // // Returns: @@ -140,7 +142,7 @@ torch::Tensor RasterizePointsCoarse( const torch::Tensor& points, const torch::Tensor& cloud_to_packed_first_idx, const torch::Tensor& num_points_per_cloud, - const int image_size, + const std::tuple image_size, const torch::Tensor& radius, const int bin_size, const int max_points_per_bin) { @@ -182,7 +184,7 @@ torch::Tensor RasterizePointsCoarse( std::tuple RasterizePointsFineCuda( const torch::Tensor& points, const torch::Tensor& bin_points, - const int image_size, + const std::tuple image_size, const torch::Tensor& radius, const int bin_size, const int points_per_pixel); @@ -194,7 +196,8 @@ std::tuple RasterizePointsFineCuda( // are expected to be in NDC coordinates in the range [-1, 1]. // bin_points: int32 Tensor of shape (N, B, B, M) giving the indices of points // that fall into each bin (output from coarse rasterization) -// image_size: Size of image to generate (in pixels) +// image_size: Tuple (H, W) giving the size in pixels of the output +// image to be rasterized. // radius: FloatTensor of shape (P) giving the radius (in NDC units) of // each point in points. // bin_size: Size of each bin (in pixels) @@ -214,7 +217,7 @@ std::tuple RasterizePointsFineCuda( std::tuple RasterizePointsFine( const torch::Tensor& points, const torch::Tensor& bin_points, - const int image_size, + const std::tuple image_size, const torch::Tensor& radius, const int bin_size, const int points_per_pixel) { @@ -303,7 +306,8 @@ torch::Tensor RasterizePointsBackward( // for each pointcloud in the batch. // radius: FloatTensor of shape (P) giving the radius (in NDC units) of // each point in points. -// image_size: (S) Size of the image to return (in pixels) +// image_size: Tuple (H, W) giving the size in pixels of the output +// image to be rasterized. // points_per_pixel: (K) The number of points to return for each pixel // bin_size: Bin size (in pixels) for coarse-to-fine rasterization. 
Setting // bin_size=0 uses naive rasterization instead. @@ -325,7 +329,7 @@ std::tuple RasterizePoints( const torch::Tensor& points, const torch::Tensor& cloud_to_packed_first_idx, const torch::Tensor& num_points_per_cloud, - const int image_size, + const std::tuple image_size, const torch::Tensor& radius, const int points_per_pixel, const int bin_size, diff --git a/pytorch3d/csrc/rasterize_points/rasterize_points_cpu.cpp b/pytorch3d/csrc/rasterize_points/rasterize_points_cpu.cpp index d7913f65..53cd6cba 100644 --- a/pytorch3d/csrc/rasterize_points/rasterize_points_cpu.cpp +++ b/pytorch3d/csrc/rasterize_points/rasterize_points_cpu.cpp @@ -3,33 +3,27 @@ #include #include #include - -// Given a pixel coordinate 0 <= i < S, convert it to a normalized device -// coordinate in the range [-1, 1]. The NDC range is divided into S evenly-sized -// pixels, and assume that each pixel falls in the *center* of its range. -static float PixToNdc(const int i, const int S) { - // NDC x-offset + (i * pixel_width + half_pixel_width) - return -1 + (2 * i + 1.0f) / S; -} +#include "rasterization_utils.h" std::tuple RasterizePointsNaiveCpu( const torch::Tensor& points, // (P, 3) const torch::Tensor& cloud_to_packed_first_idx, // (N) const torch::Tensor& num_points_per_cloud, // (N) - const int image_size, + const std::tuple image_size, const torch::Tensor& radius, const int points_per_pixel) { const int32_t N = cloud_to_packed_first_idx.size(0); // batch_size. - const int S = image_size; + const int H = std::get<0>(image_size); + const int W = std::get<1>(image_size); const int K = points_per_pixel; // Initialize output tensors. 
auto int_opts = num_points_per_cloud.options().dtype(torch::kInt32); auto float_opts = points.options().dtype(torch::kFloat32); - torch::Tensor point_idxs = torch::full({N, S, S, K}, -1, int_opts); - torch::Tensor zbuf = torch::full({N, S, S, K}, -1, float_opts); - torch::Tensor pix_dists = torch::full({N, S, S, K}, -1, float_opts); + torch::Tensor point_idxs = torch::full({N, H, W, K}, -1, int_opts); + torch::Tensor zbuf = torch::full({N, H, W, K}, -1, float_opts); + torch::Tensor pix_dists = torch::full({N, H, W, K}, -1, float_opts); auto points_a = points.accessor(); auto point_idxs_a = point_idxs.accessor(); @@ -46,16 +40,16 @@ std::tuple RasterizePointsNaiveCpu( const int point_stop_idx = (point_start_idx + num_points_per_cloud[n].item().to()); - for (int yi = 0; yi < S; ++yi) { + for (int yi = 0; yi < H; ++yi) { // Reverse the order of yi so that +Y is pointing upwards in the image. - const int yidx = S - 1 - yi; - const float yf = PixToNdc(yidx, S); + const int yidx = H - 1 - yi; + const float yf = PixToNonSquareNdc(yidx, H, W); - for (int xi = 0; xi < S; ++xi) { + for (int xi = 0; xi < W; ++xi) { // Reverse the order of xi so that +X is pointing to the left in the // image. - const int xidx = S - 1 - xi; - const float xf = PixToNdc(xidx, S); + const int xidx = W - 1 - xi; + const float xf = PixToNonSquareNdc(xidx, W, H); // Use a priority queue to hold (z, idx, r) std::priority_queue> q; @@ -99,25 +93,36 @@ torch::Tensor RasterizePointsCoarseCpu( const torch::Tensor& points, // (P, 3) const torch::Tensor& cloud_to_packed_first_idx, // (N) const torch::Tensor& num_points_per_cloud, // (N) - const int image_size, + const std::tuple image_size, const torch::Tensor& radius, const int bin_size, const int max_points_per_bin) { const int32_t N = cloud_to_packed_first_idx.size(0); // batch_size. 
- - const int B = 1 + (image_size - 1) / bin_size; // Integer division round up const int M = max_points_per_bin; + + const float H = std::get<0>(image_size); + const float W = std::get<1>(image_size); + + // Integer division round up. + const int BH = 1 + (H - 1) / bin_size; + const int BW = 1 + (W - 1) / bin_size; + auto opts = num_points_per_cloud.options().dtype(torch::kInt32); - torch::Tensor points_per_bin = torch::zeros({N, B, B}, opts); - torch::Tensor bin_points = torch::full({N, B, B, M}, -1, opts); + torch::Tensor points_per_bin = torch::zeros({N, BH, BW}, opts); + torch::Tensor bin_points = torch::full({N, BH, BW, M}, -1, opts); auto points_a = points.accessor(); auto points_per_bin_a = points_per_bin.accessor(); auto bin_points_a = bin_points.accessor(); auto radius_a = radius.accessor(); - const float pixel_width = 2.0f / image_size; - const float bin_width = pixel_width * bin_size; + const float ndc_x_range = NonSquareNdcRange(W, H); + const float pixel_width_x = ndc_x_range / W; + const float bin_width_x = pixel_width_x * bin_size; + + const float ndc_y_range = NonSquareNdcRange(H, W); + const float pixel_width_y = ndc_y_range / H; + const float bin_width_y = pixel_width_y * bin_size; for (int n = 0; n < N; ++n) { // Loop through each pointcloud in the batch. @@ -129,15 +134,15 @@ torch::Tensor RasterizePointsCoarseCpu( (point_start_idx + num_points_per_cloud[n].item().to()); float bin_y_min = -1.0f; - float bin_y_max = bin_y_min + bin_width; + float bin_y_max = bin_y_min + bin_width_y; // Iterate through the horizontal bins from top to bottom. - for (int by = 0; by < B; by++) { + for (int by = 0; by < BH; by++) { float bin_x_min = -1.0f; - float bin_x_max = bin_x_min + bin_width; + float bin_x_max = bin_x_min + bin_width_x; // Iterate through bins on this horizontal line, left to right. 
- for (int bx = 0; bx < B; bx++) { + for (int bx = 0; bx < BW; bx++) { int32_t points_hit = 0; for (int p = point_start_idx; p < point_stop_idx; ++p) { float px = points_a[p][0]; @@ -172,11 +177,11 @@ torch::Tensor RasterizePointsCoarseCpu( // Shift the bin to the right for the next loop iteration bin_x_min = bin_x_max; - bin_x_max = bin_x_min + bin_width; + bin_x_max = bin_x_min + bin_width_x; } // Shift the bin down for the next loop iteration bin_y_min = bin_y_max; - bin_y_max = bin_y_min + bin_width; + bin_y_max = bin_y_min + bin_width_y; } } return bin_points; @@ -194,11 +199,6 @@ torch::Tensor RasterizePointsBackwardCpu( const int W = idxs.size(2); const int K = idxs.size(3); - // For now only support square images. - // TODO(jcjohns): Extend to non-square images. - if (H != W) { - AT_ERROR("RasterizePointsBackwardCpu only supports square images"); - } torch::Tensor grad_points = torch::zeros({P, 3}, points.options()); auto points_a = points.accessor(); @@ -212,7 +212,7 @@ torch::Tensor RasterizePointsBackwardCpu( // Reverse the order of yi so that +Y is pointing upwards in the image. const int yidx = H - 1 - y; // Y coordinate of the top of the pixel. - const float yf = PixToNdc(yidx, H); + const float yf = PixToNonSquareNdc(yidx, H, W); // Iterate through pixels on this horizontal line, left to right. for (int x = 0; x < W; ++x) { // Loop over pixels in the row @@ -220,7 +220,7 @@ torch::Tensor RasterizePointsBackwardCpu( // Reverse the order of xi so that +X is pointing to the left in the // image. 
const int xidx = W - 1 - x; - const float xf = PixToNdc(xidx, W); + const float xf = PixToNonSquareNdc(xidx, W, H); for (int k = 0; k < K; ++k) { // Loop over points for the pixel const int p = idxs_a[n][y][x][k]; if (p < 0) { diff --git a/pytorch3d/renderer/mesh/__init__.py b/pytorch3d/renderer/mesh/__init__.py index a0a01086..d8b6b13b 100644 --- a/pytorch3d/renderer/mesh/__init__.py +++ b/pytorch3d/renderer/mesh/__init__.py @@ -6,6 +6,7 @@ from .rasterizer import MeshRasterizer, RasterizationSettings from .renderer import MeshRenderer from .shader import TexturedSoftPhongShader # DEPRECATED from .shader import ( + BlendParams, HardFlatShader, HardGouraudShader, HardPhongShader, diff --git a/pytorch3d/renderer/mesh/rasterize_meshes.py b/pytorch3d/renderer/mesh/rasterize_meshes.py index c0702631..d4cb24ab 100644 --- a/pytorch3d/renderer/mesh/rasterize_meshes.py +++ b/pytorch3d/renderer/mesh/rasterize_meshes.py @@ -1,7 +1,7 @@ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 
-from typing import Optional, Tuple, Union +from typing import List, Optional, Tuple, Union import numpy as np import torch @@ -20,7 +20,7 @@ kMaxFacesPerBin = 22 def rasterize_meshes( meshes, - image_size: Union[int, Tuple[int, int]] = 256, + image_size: Union[int, List[int], Tuple[int, int]] = 256, blur_radius: float = 0.0, faces_per_pixel: int = 8, bin_size: Optional[int] = None, @@ -219,7 +219,7 @@ class _RasterizeFaceVerts(torch.autograd.Function): face_verts, mesh_to_face_first_idx, num_faces_per_mesh, - image_size: Tuple[int, int] = (256, 256), + image_size: Union[List[int], Tuple[int, int]] = (256, 256), blur_radius: float = 0.01, faces_per_pixel: int = 0, bin_size: int = 0, @@ -287,11 +287,6 @@ class _RasterizeFaceVerts(torch.autograd.Function): return grads -def pix_to_ndc(i, S): - # NDC x-offset + (i * pixel_width + half_pixel_width) - return -1 + (2 * i + 1.0) / S - - def non_square_ndc_range(S1, S2): """ In the case of non square images, we scale the NDC range diff --git a/pytorch3d/renderer/points/compositor.py b/pytorch3d/renderer/points/compositor.py index 6f6c274c..650f0a3a 100644 --- a/pytorch3d/renderer/points/compositor.py +++ b/pytorch3d/renderer/points/compositor.py @@ -75,7 +75,7 @@ def _add_background_color_to_images(pix_idxs, images, background_color): pixels with accumulated features have unchanged values. """ # Initialize background mask - background_mask = pix_idxs[:, 0] < 0 # (N, image_size, image_size) + background_mask = pix_idxs[:, 0] < 0 # (N, H, W) # Convert background_color to an appropriate tensor and check shape if not torch.is_tensor(background_color): diff --git a/pytorch3d/renderer/points/rasterize_points.py b/pytorch3d/renderer/points/rasterize_points.py index 2942f708..80a0e99b 100644 --- a/pytorch3d/renderer/points/rasterize_points.py +++ b/pytorch3d/renderer/points/rasterize_points.py @@ -6,7 +6,7 @@ import torch # pyre-fixme[21]: Could not find name `_C` in `pytorch3d`. 
from pytorch3d import _C -from pytorch3d.renderer.mesh.rasterize_meshes import pix_to_ndc +from pytorch3d.renderer.mesh.rasterize_meshes import pix_to_non_square_ndc # Maxinum number of faces per bins for @@ -14,17 +14,30 @@ from pytorch3d.renderer.mesh.rasterize_meshes import pix_to_ndc kMaxPointsPerBin = 22 -# TODO(jcjohns): Support non-square images def rasterize_points( pointclouds, - image_size: int = 256, + image_size: Union[int, List[int], Tuple[int, int]] = 256, radius: Union[float, List, Tuple, torch.Tensor] = 0.01, points_per_pixel: int = 8, bin_size: Optional[int] = None, max_points_per_bin: Optional[int] = None, ): """ - Pointcloud rasterization + Each pointcloud is rasterized onto a separate image of shape + (H, W) if `image_size` is a tuple or (image_size, image_size) if it + is an int. + + If the desired image size is non square (i.e. a tuple of (H, W) where H != W) + the aspect ratio needs special consideration. There are two aspect ratios + to be aware of: + - the aspect ratio of each pixel + - the aspect ratio of the output image + The camera can be used to set the pixel aspect ratio. In the rasterizer, + we assume square pixels, but variable image aspect ratio (i.e. rectangular images). + + In most cases you will want to set the camera aspect ratio to + 1.0 (i.e. square pixels) and only vary the + `image_size` (i.e. the output image dimensions in pixels). Args: pointclouds: A Pointclouds object representing a batch of point clouds to be @@ -34,7 +47,8 @@ def rasterize_points( be in normalized device coordinates (NDC): [-1, 1]^3 with the camera at (0, 0, 0); In the camera coordinate frame the x-axis goes from right-to-left, the y-axis goes from bottom-to-top, and the z-axis goes from back-to-front. - image_size: Integer giving the resolution of the rasterized image + image_size: Size in pixels of the output image to be rasterized. + Can optionally be a tuple of (H, W) in the case of non square images.
radius (Optional): The radius (in NDC units) of the disk to be rasterized. This can either be a float in which case the same radius is used for each point, or a torch.Tensor of shape (N, P) giving a radius per point @@ -71,6 +85,9 @@ def rasterize_points( then `dists[n, y, x, k]` is the squared distance between the pixel (y, x) and the point `(points[n, p, 0], points[n, p, 1])`. Pixels hit with fewer than points_per_pixel are padded with -1. + + In the case that image_size is a tuple of (H, W) then the outputs + will be of shape `(N, H, W, ...)`. """ points_packed = pointclouds.points_packed() cloud_to_packed_first_idx = pointclouds.cloud_to_packed_first_idx() @@ -78,26 +95,46 @@ def rasterize_points( radius = _format_radius(radius, pointclouds) + # In the case that H != W use the max image size to set the bin_size + # to accommodate the num bins constraint in the coarse rasterizer. + # If the ratio of H:W is large this might cause issues as the smaller + # dimension will have fewer bins. + # TODO: consider a better way of setting the bin size. + if isinstance(image_size, (tuple, list)): + if len(image_size) != 2: + raise ValueError("Image size can only be a tuple/list of (H, W)") + if not all(i > 0 for i in image_size): + raise ValueError( + "Image sizes must be greater than 0; got %d, %d" % image_size + ) + if not all(type(i) == int for i in image_size): + raise ValueError("Image sizes must be integers; got %f, %f" % image_size) + max_image_size = max(*image_size) + im_size = image_size + else: + im_size = (image_size, image_size) + max_image_size = image_size + if bin_size is None: if not points_packed.is_cuda: # Binned CPU rasterization not fully implemented bin_size = 0 else: # TODO: These heuristics are not well-thought out!
- if image_size <= 64: + if max_image_size <= 64: bin_size = 8 - elif image_size <= 256: + elif max_image_size <= 256: bin_size = 16 - elif image_size <= 512: + elif max_image_size <= 512: bin_size = 32 - elif image_size <= 1024: + elif max_image_size <= 1024: bin_size = 64 if bin_size != 0: # There is a limit on the number of points per bin in the cuda kernel. # pyre-fixme[58]: `//` is not supported for operand types `int` and # `Union[int, None, int]`. - points_per_bin = 1 + (image_size - 1) // bin_size + points_per_bin = 1 + (max_image_size - 1) // bin_size if points_per_bin >= kMaxPointsPerBin: raise ValueError( "bin_size too small, number of points per bin must be less than %d; got %d" @@ -114,7 +151,7 @@ def rasterize_points( points_packed, cloud_to_packed_first_idx, num_points_per_cloud, - image_size, + im_size, radius, points_per_pixel, bin_size, @@ -173,7 +210,7 @@ class _RasterizePoints(torch.autograd.Function): points, # (P, 3) cloud_to_packed_first_idx, num_points_per_cloud, - image_size: int = 256, + image_size: Union[List[int], Tuple[int, int]] = (256, 256), radius: Union[float, torch.Tensor] = 0.01, points_per_pixel: int = 8, bin_size: int = 0, @@ -225,7 +262,7 @@ class _RasterizePoints(torch.autograd.Function): def rasterize_points_python( pointclouds, - image_size: int = 256, + image_size: Union[int, Tuple[int, int]] = 256, radius: Union[float, torch.Tensor] = 0.01, points_per_pixel: int = 8, ): @@ -235,7 +272,12 @@ def rasterize_points_python( Inputs / Outputs: Same as above """ N = len(pointclouds) - S, K = image_size, points_per_pixel + H, W = ( + image_size + if isinstance(image_size, (tuple, list)) + else (image_size, image_size) + ) + K = points_per_pixel device = pointclouds.device points_packed = pointclouds.points_packed() @@ -247,11 +289,11 @@ def rasterize_points_python( # Intialize output tensors. 
point_idxs = torch.full( - (N, S, S, K), fill_value=-1, dtype=torch.int32, device=device + (N, H, W, K), fill_value=-1, dtype=torch.int32, device=device ) - zbuf = torch.full((N, S, S, K), fill_value=-1, dtype=torch.float32, device=device) + zbuf = torch.full((N, H, W, K), fill_value=-1, dtype=torch.float32, device=device) pix_dists = torch.full( - (N, S, S, K), fill_value=-1, dtype=torch.float32, device=device + (N, H, W, K), fill_value=-1, dtype=torch.float32, device=device ) # NDC is from [-1, 1]. Get pixel size using specified image size. @@ -263,18 +305,18 @@ def rasterize_points_python( point_stop_idx = point_start_idx + num_points_per_cloud[n] # Iterate through the horizontal lines of the image from top to bottom. - for yi in range(S): + for yi in range(H): # Y coordinate of one end of the image. Reverse the ordering # of yi so that +Y is pointing up in the image. - yfix = S - 1 - yi - yf = pix_to_ndc(yfix, S) + yfix = H - 1 - yi + yf = pix_to_non_square_ndc(yfix, H, W) # Iterate through pixels on this horizontal line, left to right. - for xi in range(S): + for xi in range(W): # X coordinate of one end of the image. Reverse the ordering # of xi so that +X is pointing to the left in the image. - xfix = S - 1 - xi - xf = pix_to_ndc(xfix, S) + xfix = W - 1 - xi + xf = pix_to_non_square_ndc(xfix, W, H) top_k_points = [] # Check whether each point in the batch affects this pixel. diff --git a/pytorch3d/renderer/points/rasterizer.py b/pytorch3d/renderer/points/rasterizer.py index 85e93e42..e8794f4c 100644 --- a/pytorch3d/renderer/points/rasterizer.py +++ b/pytorch3d/renderer/points/rasterizer.py @@ -2,7 +2,7 @@ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 
-from typing import NamedTuple, Optional, Union +from typing import NamedTuple, Optional, Tuple, Union import torch import torch.nn as nn @@ -29,7 +29,7 @@ class PointsRasterizationSettings: def __init__( self, - image_size: int = 256, + image_size: Union[int, Tuple[int, int]] = 256, radius: Union[float, torch.Tensor] = 0.01, points_per_pixel: int = 8, bin_size: Optional[int] = None, diff --git a/tests/bm_rasterize_points.py b/tests/bm_rasterize_points.py index d00e45ac..70e4a778 100644 --- a/tests/bm_rasterize_points.py +++ b/tests/bm_rasterize_points.py @@ -74,6 +74,21 @@ def bm_python_vs_cpu_vs_cuda() -> None: kwargs_list += [ {"N": 32, "P": 100000, "img_size": 128, "radius": 0.01, "pts_per_pxl": 50}, {"N": 8, "P": 200000, "img_size": 512, "radius": 0.01, "pts_per_pxl": 50}, + {"N": 8, "P": 200000, "img_size": 256, "radius": 0.01, "pts_per_pxl": 50}, + { + "N": 8, + "P": 200000, + "img_size": (512, 256), + "radius": 0.01, + "pts_per_pxl": 50, + }, + { + "N": 8, + "P": 200000, + "img_size": (256, 512), + "radius": 0.01, + "pts_per_pxl": 50, + }, ] for k in kwargs_list: k["device"] = "cuda" diff --git a/tests/data/test_pointcloud_rectangle_image.png b/tests/data/test_pointcloud_rectangle_image.png new file mode 100644 index 0000000000000000000000000000000000000000..8bf30329477ed9d2a3945c63451fb05c32e33a93 GIT binary patch literal 20251 zcmeHv_aoK)`~T}8BGgSv2$d0vlATp4d(W~9Wkd;atkdGAGLyZsLL4K>Rw^smJDmy{ z$H+Lw;T+%V<^BGA{)X=lpYua^xp|G}^Lbs@<9a+EmoNi8brvQrCImrPG&R(W5QGl? 
zN`vgBf&awf!xj)k@3*Gf1ry(t`7!@zDUQP%OR2-DD|c@?b&lSy88uTC;XK2}A>hkD z5!Yw;>Tw@l-Fj?PJ-?6= z*_jgApi+F_V^O*D>pP#y)Uo;XmF9`z;hVXgUxn^icLt6PN=u>-@_(QITY>+tR$%Vb zJ9}FMar5x08JqDMnO}>Jj)uRw$%F=8Mvyn55Jca;3=LVFJh;3?;f-wmb>~?`#2U5^ zzUP*dOU1Q(6G`S9LxXz}RXza(`OEk8^Yv(rPf3WELh6+)!y+@5rjrMm5_NK|#P@Gq z#~^Z_he{+63yLueY~1Ha515=I6>Ly&yvVt$I9~bFG3F`5UZ(Aaf4N>77JU_1Hpb>DlCtfq5Ux zBC|Rhyoo3YOHsj!C3nc2LqGR~7;)Z?{IRVm<@$_BV=#^mY5aDOHMF_8nf3v$bLw7| z6>DhI3*_L_&jHH8f27H=Nnv4(f%|cEWHK+VUf-;+|K4^cyn_-Of}DTe^eZ7LDg5*G z@$vBwc6Yj%RP#-ko?J}S`J*#?@4NebDX!;g$GfG8ZZr$r6z-4<5dA|;=+mBbuk`ov z@xfvhxSn5Q$JyWqT`F8Ek^~T=#_>CHW@cvaA_^Ydtf7}Oqp$~KkoRNm#*9%-ze=v` z{JJBT+z~^IAR!Uxv%^1~Zr+^w^T%NiKMyOzqYEM!E^hAf(Rr6W8CK%NQh)s5dB5ce zg~hq`EqFt0U*f|0`P6US-QCgA|0t2o@y9U;!hDnw!KL1CadE+7u`VvNZy9kTwT1M) z(#ahrqW$Z)m3SiuiT!-uxg1C?-8@myaH(;|%`=u#XQcg(LHdg`&=)PBtVYmYyR?uoHj7y6|zKvleN3kbjs-x2lmpPCcne< ziv^pkN=)S$iMZn2%ZCSc2KD`6jV&3_*J&h@+P+H2cSRNns$wT26^zd^H?96j?RBmk zby52DHe{6r^FLQ~uNj?c!pG zo({VWknsp|CnyT`lJU!zFFdTx2~QM7B4dzMtP&CLH6V-chnOh3R#c##sTKA{`ZC#> zTqs^MS9US;=lYg$eqToJB?L(^hEVaLN&W)&hQq2t+}4`eqK4SN7~~Ae^f&=u&AnjP z?_#QQ^(be^SrwJV^({EgUP9K^)=G=E9LUd$>ah62)5S9+>?2kOR#k#+Yw7mTIt7BS z)Z#_*`T&NefaUe}Yb+B7(;mOL+U*Anyl0p%y$Y5-bt+LOZQWxZa_7Pc3{t}Wc0ka3 z$K}8f`Fd%Bvog*E<5+$*jCzzcH0Yr-TgeX-*fFQ_VJmTA28ci2NcBi6e4)rpEKaM5 zN=1;v2hdQs5U5*ZX3Ri-E78k;rnlAel!^^@CI@>VIL9dZmU5hy=Bfu1K{pns%c$zo zQQoO6HcIc1`~d4$W%ZKI&dp6{l8tQ1vQ@YTnUg{n=3zKwi>(0V-nSBg2Q1*tncuyC ze=GHX)yadBi8_kBvX`fV8VGOs8k3Mu<4Y>krO4-3E-g?--DH%=5(k++*qu4Zbd{im z7n(i)iW9+=tD=kbdXNZf2zKQgR0JV3-dcg@yZ>aL$B>S}Fo$b1_aaF|j6UE(RIpM3>AFKB=>*7Z@9Zch-X;b5VTiZIrVa z2(uL-_APxSkwL3jHTM_O($muu6Qf5rB@=Y4W0E9t(^9%VT`#Q{LqstK>(rVt_nNU{ z$WAWO>7Dx*&(lL0Ll4jk$gs;3nx=~E%=9F>#%c!0`HJM}tXyS9ER9{!Q8)xfF9@a6 zJ~gw3Ev7S5R`)ujpFVvWZQROzP5oK^v~96fg`gnh#uRCE{S8FpVt%CN=_2)8V~gYf z(oh_Y~lRc5{px; zTdweQ+UV2eV(XxeqPc)A13i0iLg3%yOOWR)M^7`5nQ+x4@dom8e8ZbX;yejk5wv=> z!Ly{Kgnx;$yu3`MY*A>Zk6u2X8o^5`q#&Z^$IvG|iazPpyk1CjXP8Mkrw-Mu_GlkI 
zeE29AnuZl7Rf`afk3a9oi8pL7-;slYu)auu-8Q)t+wKYg85L#(3oqY@vwK`=u}3>vO~GVQPKtwn3YYah}svp|8x4 z2`>7alP-7dhHQlqyDfN?JN7fFiBO-QcAYy{?!WG?^w4}#sZ7(H4QFuj;9k0*zmNDJ z>^*{3lG!N1zxLSFVe-1QB>HD0z}c{CL*PS&^YH=*2zs@{GCM(^{EuYCCK`>56_m z2Psgn5-W3>iOX5ah1x&NOsN+71}YInM~)sfZT6{Aif#Sc*;#0BBRU|fOY^ebj8PS8 zS&mVwL)ow!>y{-0;&ua)!@J5E=@{fq1B2?TVA3{}4AWIx(YOF*!gOMEaCPs@o8Tnf zEOCnn$QHa5b5V}Cl-1Lin6{k${{HKwa(y1=slbPJu z!XB&bWE!+gt+bA&wPv5DR1RcXdv?wHoKMwTH|cSEYlD|f?x1X~YomAU<9&2N5CCoX zG#^rO7k#ioz}Ee_##ewntZ16m9i>iT=^st{_98Y8TW3=UuqsK4+tl<@j7ke*!wzL}-;g1;N8~YUw1!G8e`DipflK`{{zxQy@5fInV?I&pQY!U zhGgdh2|-EJkLMTOXIY82`#5k=h`Nva1}3y}mRn9lYJ}|o7D)H~9cF`HFY*X<_4fAm z^V=+fbhbGe50x*ikw@;1DvR7X=0{a_kRVV@W(m0=@Vs~Rn7dzJ4KkBizHOHJ$EPNe zM}E0GGRQ;_!PhQ_AKVky?iluy&XVsfXV$4mgoER=Cr)&RqsWY^A(HM%x zg@68CAGj64t2FCb#fzOjcC~A$A-KW$-W@x91G7yk}Bl~;_;!&dDlV`jYA;Ezc)BJeR&hxBJTI3;8*nCde{PK$x5;Kc1>Xe zzoNI?J^lys=~VRwhTY!zFI3&Ud)LtVOU8mKzHV&);Iy@zX@g!^-zZma3Mpv(1HC)UOIbg|-)OGB%wY7cx`0?x4X9582GCUZgenTa8RC!@#c^m>6WyIeF_J}S&luWtth+&Rb0vLe`wO?ni**6C z*SiR3w8n@!jFIzUDLuW+4)WKU_;OjWt~qY==&LFjt~1s=u+AznVvkH zVX`cG**>Y@Gne44U0P3LaO#k4AgTWm!~TLyr&+r$CLC6bgJNFTZ>5n(sR<& zPb2sU()bc>5o-W!6|CfXhre8<%94mu=E8KGr7}2iW+P zFc#OYL0KNyTKIf@>D5#5YZey1%dNs&#mF`Z9}HON^L0ZO$$g;LX!rbA$TW5hZeE#Kn6Gsp9e4%H}m=yt4!bq=W;F$}jNr zl8GkvZ)c;j!bt#tNAUN&y-lC*GhnA>SBhRk@^&yCfI!a?C(YH%N zY+u5f85H(giJL)dqaRrmDQuU{eS>Ym;vap1axRv$PnXIb4Q#FADZ&Y=)fGZDmK zx4eg`U5qSptSolkqv}H};J~46V5&5mc>59rTu)ETNJt?oq^r)U*mV;|Rl>Bb4fI+3 z^XJ*aKEXZa=H@2jZKyMrfnSMi11S9ZcE)@71^lUhs zt8}U%FIrP1z4`kV>;tK|SM(|Q<4>Uo!>jsKSv64~K^;Gdg2wP{j3>aqc9fZD7f()4 zmt}4mX}8@!86O|tJF^PV@y+6}60~_E*9+N@eo2jYvq>_G3aQ^vK5WoKj)#|z&wP}T zN?mU#{B-Hi0p-nB{Uo6mGe;hJvz2t^W$uz7KHZlO`zfY5NXpNt>_P0F0@+k@BmejY zP!uN*IykB>PMq-Gj$uZduS|nmJ5&mJ$U5wu$NOI)M)@Wu4?@8obC0J~5Enl|B?FR; z{OxfKye0a+&!%JAyDHhAJ%0|b>$kcYkX!m;psK+oDv^4z=?rjU{jG(Cg-z5G@WMRk z3#S+CD_f)YuCjtOk(7{NVP(Zk|47obvCsUiQGQETX?}jbj!-$(-CB^7xj;kdAk&^l z7w}i~5rJdT9~{b5EYjV&+$*6!OG8@HoT>z30WeAc=a_#{CRr(xAcfd*^TXS%90`fm 
z%6Ze>&^*4`6JP)Ke;RuS2fW%Sx#V=%JIASo?d@Fn#4+BA08Fa~?p%V;vz=S$t2r00L%mf~6nGRf&08cGm=Kr9+0wnjEk<~>xF zqa#;Sk9@zsN0kkc9-N!WOBX-%bhgZqa1Z)e0k$S88G~5vg_F7OGf)=1yDZ4Qefywr zLi;U-3>>ghlDLgFk~_bjN(>N51yAPMiuzq5g52_A{-N{qTSm#6)5_Bv7sx_(j0=L8 zf1AfURdxb$jJlhT#kWt|o#AArk8PbG_iMpD&=AG%_MsN;M)PAxVy02P#?xcgr51BVGr$>?!UFy)E<%m2S?<}R@y91@F?eSyOzk(2 zY*AxP1)|Odj0Qrn&XB&%1GOU|TkS^oKEv1P7Z`EjR7e=%)JvINPzp9++kkDh63<)w z1Vz>S>F3|xgVuR8_lN!Yy*Jq?g@)H-kUk=5giD^f(X2wPTMX0}L;}r`@XxPm<7yk4 zesOYh$EwR8Sl*98_%|A8aQnW&Y7ssH%7>!VwMfzKZlD&Jhzc^QLKmNWcR#NpC)1O) z`O%xRlLR*z6^L6cOphzRF1H$Kt{dcVsMv2r?Y2}SeZ(_q0Gaox0e)swi!q8?tMHO<68MJyy=>{|d{$PLk?^^K?%)?) z1h1qZ{xpF7fm>@p47*e)sHmu%J9pyf<)&YM6~qv-#)JLA8%=~A>N@0OEhIbXWU8Jl z*K;!05o!Kr7Zu!yUYAS7EK%XGZB~Hz%fpD>DEjcni~^VZxhU$r&a)OfMccls(}u|% z?*N%g8#P{|$zDUI11YAu!N;S4W9u~#?>id);k61y|6gNy4K zRWl8u(z276%xRDmF=zzE<)DH?lb6Nw`ld-IU;jPh)Qt4pH^s%ut6i6XL~}{N?R|3b zfvTzpL4*-k(R?e%=q&TNKVj>~@87Q;FT)S_&a~wGS=qc$^|*a9ru!ong_J9s0npe_ zdTvT{Z&yQ|_qJHh_~1}*uE%K`D*hDO;q23^;0qjy=XpK-&a99Ws9AB%+Qw$#ypdg( zkR9>)y(1zAn4Y|R`O={eR~-+;t9Fi&P;BXyogR7<7U7rX&~Z6yzG!uTrW+PCub2B& zJ8U~pZ!R8}P98K8)LFL9k7EhEwKP3_OFVwGwt*9|I|pFu{edG~f!dH_OzU=ZD=W9K z#G=YKv2+J}c@g61ti*85eSCg`@IfX=$wal|nda1%eZW_b&cAz!v@$`+=|S(};5lx% z0f1v}?v3+q4o4!=+=O^+e}TL<@CQ6+8xfW5Na4V#9seh?S$5+z(A$71p_0L;w-`xy z6+FE3H3;zFLr6hZ|IU)IGRgf}!6FSambrMQ@8x-bw1G$eb?0>};_?iqY^lm8(se%p z$3wVJ2#i!(OR%~Gdv}`ul!Qc~7$g0oJZm)AxVOVo^lkz>_4zu8TXUE*93p^W65UFc z1n7{G8T3{ekVDjVu*|2EF+gd<&L(%jGN9CZ&cu&*-f$}i)##ASTtixN6vi(lDanWt zH@lJl(9M|z>b_nw0#|}J`TCMNP$9eM=qtqX@-Rfv1CC{f;p6xJ!B)o_0t0q zaauj&l5tu^#l;1)K15DT?<(S*)f*!)wmaGs|AfT2+2IUL#uM6MX3y`Vq z6zmosP{pJ%Agp=j4sBecL+tLMTYd)=9Cl&7gti6!8k&q(!8Kz)?!!Z$RQ;j;+n2Ig z|E0Nk^16;vaP~q2Y@CL~j14xTp1pDu@Bv6xwKn+9sg*SdR-HJN+*qI>?F!NOyg>9) z01~|5oz2(}v-4It;u^YDJtr2syA69eF{xPgK(Ba(VTFf30X+6RYEd|^Q+%3}86ela zYlh=XzHrY+AM6x)?^jvD>6ON5A*-O6B~C54V33=rNPES>?xoA!-b%WbYZe3|iIqpr zAm2o%3A1|e;uwECon4nACSN(`I&a|8g@q5q`@cI4Y3wRTZ-9DtD3*>tRWCACOXhS! 
zTJ-1VdU;tY`Mgh7MHCJXG9z;5;gJ3fhJy2Az2&7<*vCP6+wkDkQe=SXc;E@V0P^Sx z(|fC8_e$5g%mtsCNyC(<`^9_fb8~Y|X#N5$Xk)hXlRC00Mmf@Z7O^8jxkbu>h82=6 zg4PCTook0=r>g~uZVdTlh3j;h86`KL|7NNeOR zjr4FN8i!+NQw#C+WZ)zu(N6EP6tDI>_80ArnutRhFnM1cB%wWWnbdc^hbK&}=<4Wv zl7hF&~B7?UtxyNeLbSb2ERD0%cEt0U~u4p-6i>CXjA589TR$eOw&;l&yq-T!N61EOX zGw>{0!kTUqI;4{_#D^Fp`3W2h8G+;udKye*^GfFa!0E)BX@=I1pi>~wX5bOUo$VD8 z*1Cswn3n$@GEjOoUg`Y#WUU;?quC#BL0jfhYvXbEIp`hOx?E>lm#RxeZlLbI+QSIh z>jErv$uVeQ;9i55ma7qs%&uHZwm+gj5=w9s^c!LZtpdQIv_d8rn;Cc5VrLc|#Bv^x!6M7XQ=KGK# zEA(E5i8@ZeHaV4hE2MIJ)Sl&G-T51;^7bW?Pxli*3IsdQ1Fv?R%Ja#S#2aBxKfj{F z>6XpOifEq}t*CSSguBI^nepcy&_rTd9Ta)>1=x~3E8r$T6C%kcVq05SsCVRe_R={F z&cS^Ix}e>cXxJ{R7s_y=fCB&;PRN|Gt=YZKes3d=E1#e7s-~sT-c@q68oV~xGTtzLr{Y)8zPbx-NVnA;a`{b z>^ty3x*+n?0zL+GAQY=96bp%I$v5DUxnt!hDV32NVbf{}^y5mW3ceV~g#1*-2dZ}0 zuOA1jaPb(iw%n=wYQ9Myld4ecaJB>+az_t1=f)~Y8JXNQ`q!*!E`N$2j{s8b&NBwpLeHCng-ExL87)0&B<`wbO55K6(1 zm9%w>bp%vvT|OgvL{1&*k=ji&A#*uv3q;fZxDw{IS>eD7 zFm^O71}^%ev>1?q_l#x&6y`dakd01s4_FU^1trq}Qe3~9Uytv;|B*uaR$ z$iq)B?+TFCyP0Y&OAFNJ!C*&hf9Z{8NV-|w0Ci&!00+B2ISi%*-0e>H%66Q+CKACT zFR7p)dhg#s`KZCRIy85zKp7tR2G6R&a%JVs2m}{exCU>8I|M+D)*=_UIaK zMNbF&_?QVI(1Hs>x889E1YL+hP#T$=`pCUfd*7H9wzlS)9|<`-x1R|a3P5*}0;+hv zG!%2TuxT;0$A`rC^z*Li1e&t;$sUi*WcmlDg{DU+g*G`6p_B_y9U|#~L^u-J%nq4z zS1#hQV8>0R2U8oU#0CMf?31mEYaQfg_MD%upY<1|r_eM;fC!3=hKX7ruUgDvL+Mu$pG0l@|(xCItXA&hTUdAV(ZczgkWVxV%&Rx z1Lw5g59im}rzQ~mW}cpshn}YT>``s=f7N{py1Ky;DShS5?8hT;B=I7cp*QX--M74U zO<)C!zLL`+kF1VF)^e}3M^5u`Xkv|Ib|YndBAeG(``=Wt#9%9`63lG zcHzm6yT-GlD&SU=T7nI!Gxn|-vvjYT94ftm*7wCLH=qGRnoriGm3|P3A4vfQ-i^mS^g=7oF+!2BpVl|B(m8U3T151eiEKwLCvv^nG zVcmc)bhX~?Alp4Ze9@Hc(GYmWJ9rh-b^n+ObfrcKUOyW#ame=X!wpfN{+vtSe3mxq zrd3RX97d7YQ24_SKObLV!`YoV`JkSz&4>onZ`kXCSwm>|9!YaO2XNI}FO&1n6f}nb zik@5q$g0%#z?;2fr0H+IexNAVbMe!sL6|9{AyqLq`I2Yy?p1vUgm$(lPNmU(1&KApy^9U{IVs54EeKt6xs z%s~DgI3t|Q%KsIPMvhn&naNI;VSau@BhoU0u(+RuV*j^UmDfP7ag=W`5!GzcH9U2w ze`#W30%RyB>yj-m6UL})adR8L$?fghBt8CgZRLOK#h{`Trc4bsP)X8*RwvqN&@=xJ 
z`p4wPAmFqo(kPE)UT8{rvgp_8Pn`;;s&l#j@vN$Kx|cjF_YJoj6VUmCGbkxBvAa1I z8tGoJNcrvOQx2GnGjBU|(-ZKn>+U6e!5*?p6}hh7sAGA}dp23`WwVY@mcqhQFNM09 z)6(-d8zHgIuUVXWce69CRT>;o^O!Dk1Y#qcn)71@1X=(AFQC^?{1jkA&G`IL8S(O33Lgihz3|I`G$e^g5R^`HO?MXZO0U!XI|2|zz4DW)*AFPBbr)^dWr*4GWm*qR09s5o zv!IndT3m2GsMO*s)X4vMQxe@|#Ny<+&q7sDQabwdGxUjPI9InQq9HHaUV#fc--Nlz zz3ZVqV4z$4Z5_VAU+}+*5M66?mtZsgS z+F@Xx6F)XOzm}A1gKxi4XJgLSEVWlh#aQJ%w0brHOP{N%s!A>c6VFD`{_*P7&h=7G z)=<#sgeaeZINhx<1s120UIlZEv0lSy|!Y>;u;J_uFd*^#$|&Q}S-33j6U@>NXkD2_#-y6A3Jy2)7HL0oAvA{Sz| zh(_xYvK!#o`wu1sFeHcpIRw<~!KVExZr~Fez*F6^G*KN+y-AiHbs)Xf&)@dwW!N*# z5I$mVcl1QY=!-=bNEyck$R8lngHi-% z5>mL&KLTuw9I1s$%mt(Q)PP+}R>4~l-bgT`X$UNXLE-g28C^C5{n>`2Qdb%ke1I6G zLlY-Z(U0-<-}8VRDzG8U`)GSVSLm4Z1cC*cfHohRI0V3b+ba9vBtV+{dMa593=ZF+@nl~!6`)swN8F}L^Im@Fe=c?k0TPeQin7vMl?Gx@Y> z__N-xPU#-c146Rt*9NEuJw%f7-v|3O{*}nl1=+=fv8&j+rtkX2Gn@f4Iq`3r=q83t zC>{#UE~uQ4f|^fp!Qj6)vURY{iXhQ?=K=}8HA*Pf!a~sTdONjT6hwC6-{O;Qz`*0SXdqVwaA3TZNsfP$es9{e+nZI^~ zn;Z3`!k4)FHE90!8S9UEwx4zxNO;Bt_Nr2oes z;S_sST`(a;P|V!w?&VB_zgrYIhlx6@TQv;0mB(~Iqtb8|B> z@rWh@^ez~R!S})Yr0_J|YJ_{az`26G`;A~smX=<=5fR5)jVZVzLNji>yGzOU&7uM zt^ZVkQ(gP_SRao2vfnz(cRWQC5p16!?&I|6s8y$;5 z>oq86$?+TlW=SqRB>`WI3{rLbNIPXgKhao7vyc3a+S|T3m<;8Y5 z2|ktc{hJ|;*%&2IbK}WU$-0|{Y-}8q!%R=W(HzIYycPK$H4Dj5N$&9Z{fSW|uKWJ) zB%jFwY%&OZF3rv(@BBds_E+SkJVwp^G04Qq=Eck?y*$#^dWUiAH#<_+;~AH19F_{k zFkL`FLa|G}+n>VNJ?cmU-C63OXoPPFE*Tis3y7k5(Q<0WHI-t>$&GK88{Tfq_Ip z>_)KP!b%w!i`^MG#w!=@8sq5q~RCWptsIg;ijbtKL9C!}3U6deW~;f%C$vDtQ39yYVe}BF#TZ&?q|lkKDH`d3zcDCy>NL zV8I$fjV`KuqIcrMYL~N}iA9prenh;_;*TOc@_IlCJe}|=Y01e?o5w*?N&jJb@}Pi% zY{lFQK0%Cu_clx+XhXf`fo>1bI_a)MCyocTXv9}JL0`ecnoMN`_YWa4N<-iqWpF<- zbpVcM#(wE)?X{8O26pq`{*;spQ=J_hSD|XS)s>aSv`=2Sas^D$)4|Y*$G2b2+PRXI zX+{gI$1Z^kd93ovj+N-KD#(*lQ&TWP;rQl(lr&c#r^#*as3~X=|57D^1V?L(vO;Ra z*Lm{yDdlyr?Z1-8jgSgKz<0Oq{bTKY=G?g(VA%4Re{m(H&PDNZGlK^@g1~wOOcy zXT}P4zTK}s!HuBW3IJr7MgyP0po=s`D*y0Er3V3oXyraI83c3*{jBc}F128$qj|n8l~1pLU4NZAAq=(;9Ojp 
z%7->}yHu<`1uDCAVW5A7;cC7CMPada?I$Fl>U6TWH&E7~G#CUGzWb+!+|7_O!PdL5G~Q6R6gtmI=KMYZmtg&TbGxYgSY6j z4BVsQHz@Sb#lcj7QwNz5FiXWtdhiKu#bF#-8U7>4fUi7IU%bg=?1M=PFvvDK` z$s_%)J4vp^@&Owq;maiM&4QB}5*RJCbx=I?wBzg7*AZKyn648WW`%;%im(q;<->E? zP82r&UVq?Qn|b}`{v>BboemoWom(Ko;Um7j8_?eTsGjcW{cnFeROHt zC)6pLAd4UFvd-cwJm&~TD>qbV#r1+`w8v3}ht<}zRZyb#??cf^z88q%_TMJ(H@kj% zs#P@&o|Cfym1;MDSgmL0$NOu`psM|1im|Ams^JcfI*@s|2R!w-Wfr* zSsr#X*cWu#5x0ZO#uzb3$f71jrYE!y8c|~+m~V~U|29B_!F}Y2yyrv<8sn7{>DXBj z#^1WIIS=Ef8_Ws=$GHh|cYflezMGir({pXm<0vEMwJrgsg~3mgyV1?#z&dHpM?ovJ zZ(rHD3%gc>S3x{$=`_eUHuSySTHe`FjzF`qviaxF9~0*&@9=@smER#LgVRpBawEn!SqP>K9K95d_BZ!po%j#DHW8JqMkgEW%IcF7K#;l&MJndeufTN> zK+%`e1@DQ@Z7Td9@U2?|B7{d1=h$KM(_WSMP~G$^D#J>eU^S&m^`05b-)Zxzg}JH@ z&3jcT_M9-=SN-6^^L-`ovuQdk(je#NPB1`!g#y(pmyiDgE^>wBM=&G3JD?0GaH*{= zypNTHRjL6s)MNx#Af6JQe|SnEK3jH)=g5)a@|%$YAgZ59)f0}@bPYGc4-RS~Fl7Q$ zY?)3lp8P!^`o9oSM&KY5Wd3hozv}4fUS*&p{h`%ZRgsZ#svHIO2tU|MD9D>hiKu1> z9ytf@321!)4Q2E7xS!;WOiqewvpT1Qi5lArwbH5_JZASgLRNKY8~>mhe~Jy^96CbJt%iP=ej-UL z=l7Ct+3*|QvZq^%4Lgto_25vN6)sy8nd$b{>rfOJHitWdgoyjGs|5E-B8d!BHvNC- zh`tyi%qzCST-6D5>8}F2550ImfRz#xu$cv=hBw{`Y`nbt@zF)J3MF&ysV$m%7EoKmmfB82!GVa${LlOD*q8V)Z)IYKfhBZF!xHRvEUM)6MmO1(cjujX2jRBvgX zHmHp=sY&xyT=1*{j4yT6rE>I9&wkjO!w1m~8a9Z1r^7&IB;$2-I$;)L*j=m%n=?*D z8A;Eo?XFAv$AY%z$+v?{5BA;PkIT2s)8Ur8{SrujR%#&wBF7DsvA=@R_6;TD;4Adr zsd9He9{ivJ00nTZ`-;lO&qp&}b5&?eZZtuq}+?OA# zApK|NDU0VMU~!4D?iq=LfW%+&e&7RWMc$q^rOJZb0o)2u#YHSzBqao8#fG1l@D-)M zYLPk%cB9?JiCO)cKS*3Y$UMNH=Zja{;1_*4+z1Q|*{cNHFfc|>nWt&2 z0#`dCrwtUb(CyX-Z)f8}lFO7cIx$fjKMIWUx92^AWA%2hI{3(OlEskQJ5aLTzI_Ws zymgHaWC|^XO#Q7M|HxSpts=9Um2c-yM7O|<5^(d){Kxy4Ifx$s3X>bBp!|w8f`Q<+zfBzjAK0nfz!-g<^$Z3s48BT7c zxY}|a$Tg%}$GIO0QlhkY!9HjcmQ5uY<$I0Zt~`C})Tirb_jY5M1fN6Up^-(^^W z`$IBh%h1a4gll)ec!|D#x(p_2+P1DWj9oZ{EKS?!1Y{%Prg1#_%pL;lz(OifP!+l- zd0TYaS_2&*fHOJf-kw!Wkx*F}t%P570G$}&1@wS^kw#`Oi=ccJFaBG#Ly|Ouiw7G7e0T!+?(o4ynazKQ8D%}<=gqi?=|;5P*wRl zh~Y4#?7bi7y?Jroj(07d1xdz1PA-98_u=a2clA{;?8L>tFx)*~Nf2l8gVD|W%}#Wj 
z3+Nz_%pnjMao@kiSH=*&j4yc?+!w(7dkZzy-Gyl)?;Tu#IK=mPTBzH0_p*&doYS)K zh=s?fgt)j3AQhN!r36WLnUl7_bxies0%Tf?0TP?4dr^7J%nTLtsWZ_haGX$f=h3$W>sL1Q(cx`3Hz$$mk4 z)`E3a>C>$3+0+x?wm@LrRp{^709kAzniIMR`(#!KBTKM^Fh%`gW8-~AlE(Ew*I2%M z2^cIMSiOP8j@(s`oQTtE?by1Q$uCWYX~e9h;%<*k zBAC)eN!N(SRkuuv%(Ru5s+uKktLipL*12fFueWd$Li2kUXNW!LmzI_mpf6JLpc<1) zUA~=&{D!d-kC(>0PA}-b^;*56(t?jps3C@yl5j}%KP*FI)bq3>`Qy+S5YwFkcV_e7 zWIZe@f83u)UdpJyf|S&u>Gs|8UbA28N}0VgF9udqPdRbXnZbh`3b6<1cyFy0`MC|| z9fbL6UOGS(zQ||KG-`vc1@43eD0`%eX2`HjKWf~a$^gp_%+%uL7lGwZVMh9mMw^uM`r3 zhlOc(Y&tVx$T6$&AZnG0cHLXHhIrBVcH)j0s$$mdcvj7=Yz~NEK(>d53tT?^E~$7A zfXVXcn*Rm40VepSP%RNU^bUkeBX5Y1I6MUyjRT__+nem!y?(8nO@oKV*-LT ziUWBSjGSjiV@Vqf!UNN`Ox|xgHE5#RC*i%I&3trjt~XsnU}6%M3e3wg82}iCe%{-u z!u$+CDMUBMT@&He8-W!^z@^?w@z+F*lFsc^X8VA(e*ZO0#<(!?D;{u+Hc&C7af#TV@!{y%L{m{b4& literal 0 HcmV?d00001 diff --git a/tests/test_rasterize_points.py b/tests/test_rasterize_points.py index eef3b85e..dc59e9b2 100644 --- a/tests/test_rasterize_points.py +++ b/tests/test_rasterize_points.py @@ -404,7 +404,7 @@ class TestRasterizePoints(TestCaseMixin, unittest.TestCase): torch.manual_seed(231) N = 3 max_P = 1000 - image_size = 64 + image_size = (64, 64) radius = 0.1 bin_size = 16 max_points_per_bin = 500 @@ -501,7 +501,7 @@ class TestRasterizePoints(TestCaseMixin, unittest.TestCase): device=device ) # fmt: on - image_size = 16 + image_size = (16, 16) radius = 0.2 bin_size = 8 max_points_per_bin = 5 diff --git a/tests/test_rasterize_rectangles.py b/tests/test_rasterize_rectangle_images.py similarity index 51% rename from tests/test_rasterize_rectangles.py rename to tests/test_rasterize_rectangle_images.py index 72bb6fc5..98424938 100644 --- a/tests/test_rasterize_rectangles.py +++ b/tests/test_rasterize_rectangle_images.py @@ -12,19 +12,33 @@ from pytorch3d.io import load_obj from pytorch3d.renderer.cameras import FoVPerspectiveCameras, look_at_view_transform from pytorch3d.renderer.lighting import PointLights from pytorch3d.renderer.materials import Materials 
-from pytorch3d.renderer.mesh import TexturesUV +from pytorch3d.renderer.mesh import ( + BlendParams, + MeshRasterizer, + MeshRenderer, + RasterizationSettings, + SoftPhongShader, + TexturesUV, +) from pytorch3d.renderer.mesh.rasterize_meshes import ( rasterize_meshes, rasterize_meshes_python, ) -from pytorch3d.renderer.mesh.rasterizer import ( - Fragments, - MeshRasterizer, - RasterizationSettings, +from pytorch3d.renderer.mesh.rasterizer import Fragments +from pytorch3d.renderer.points import ( + AlphaCompositor, + PointsRasterizationSettings, + PointsRasterizer, + PointsRenderer, ) -from pytorch3d.renderer.mesh.renderer import MeshRenderer -from pytorch3d.renderer.mesh.shader import BlendParams, SoftPhongShader -from pytorch3d.structures import Meshes +from pytorch3d.renderer.points.rasterize_points import ( + rasterize_points, + rasterize_points_python, +) +from pytorch3d.renderer.points.rasterizer import PointFragments +from pytorch3d.structures import Meshes, Pointclouds +from pytorch3d.transforms.transform3d import Transform3d +from pytorch3d.utils import torus DEBUG = False @@ -44,9 +58,36 @@ verts0 = torch.tensor( ) faces0 = torch.tensor([[1, 0, 2], [4, 3, 5]], dtype=torch.int64) +# Points for a simple pointcloud. Get the vertices from a +# torus and apply rotations such that the points are no longer +# symmetrical in X/Y. +torus_mesh = torus(r=0.25, R=1.0, sides=5, rings=2 * 5) +t = ( + Transform3d() + .rotate_axis_angle(angle=90, axis="Y") + .rotate_axis_angle(angle=45, axis="Z") + .scale(0.3) +) +torus_points = t.transform_points(torus_mesh.verts_padded()).squeeze() -class TestRasterizeRectanglesErrors(TestCaseMixin, unittest.TestCase): - def test_image_size_arg(self): + +def _save_debug_image(idx, image_size, bin_size, blur): + """ + Save a mask image from the rasterization output for debugging.
+ """ + H, W = image_size + # Save out the last image for debugging + rgb = (idx[-1, ..., :3].cpu() > -1).squeeze() + suffix = "square" if H == W else "non_square" + filename = "%s_bin_size_%s_blur_%.3f_%dx%d.png" + filename = filename % (suffix, str(bin_size), blur, H, W) + if DEBUG: + filename = "DEBUG_%s" % filename + Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(DATA_DIR / filename) + + +class TestRasterizeRectangleImagesErrors(TestCaseMixin, unittest.TestCase): + def test_mesh_image_size_arg(self): meshes = Meshes(verts=[verts0], faces=[faces0]) with self.assertRaises(ValueError) as cm: @@ -76,8 +117,38 @@ class TestRasterizeRectanglesErrors(TestCaseMixin, unittest.TestCase): ) self.assertTrue("sizes must be integers" in cm.msg) + def test_points_image_size_arg(self): + points = Pointclouds([verts0]) -class TestRasterizeRectangles(TestCaseMixin, unittest.TestCase): + with self.assertRaises(ValueError) as cm: + rasterize_points( + points, + (100, 200, 3), + 0.0001, + points_per_pixel=1, + ) + self.assertTrue("tuple/list of (H, W)" in cm.msg) + + with self.assertRaises(ValueError) as cm: + rasterize_points( + points, + (0, 10), + 0.0001, + points_per_pixel=1, + ) + self.assertTrue("sizes must be positive" in cm.msg) + + with self.assertRaises(ValueError) as cm: + rasterize_points( + points, + (100.5, 120.5), + 0.0001, + points_per_pixel=1, + ) + self.assertTrue("sizes must be integers" in cm.msg) + + +class TestRasterizeRectangleImagesMeshes(TestCaseMixin, unittest.TestCase): @staticmethod def _clone_mesh(verts0, faces0, device, batch_size): """ @@ -164,7 +235,7 @@ class TestRasterizeRectangles(TestCaseMixin, unittest.TestCase): meshes_sq, image_size=(S, S), bin_size=0, blur=blur ) # Save debug image - self._save_debug_image(square_fragments, (S, S), 0, blur) + _save_debug_image(square_fragments.pix_to_face, (S, S), 0, blur) # Extract the values in the square image which are non zero. 
square_mask = square_fragments.pix_to_face > -1 @@ -284,8 +355,8 @@ class TestRasterizeRectangles(TestCaseMixin, unittest.TestCase): ) # Save out debug images if needed - self._save_debug_image(fragments_naive, image_size, 0, blur) - self._save_debug_image(fragments_binned, image_size, None, blur) + _save_debug_image(fragments_naive.pix_to_face, image_size, 0, blur) + _save_debug_image(fragments_binned.pix_to_face, image_size, None, blur) # Check naive and binned fragments give the same outputs self._check_fragments(fragments_naive, fragments_binned) @@ -354,8 +425,8 @@ class TestRasterizeRectangles(TestCaseMixin, unittest.TestCase): ) # Save debug images if DEBUG is set to true at the top of the file. - self._save_debug_image(fragments_naive, image_size, 0, blur) - self._save_debug_image(fragments_python, image_size, "python", blur) + _save_debug_image(fragments_naive.pix_to_face, image_size, 0, blur) + _save_debug_image(fragments_python.pix_to_face, image_size, "python", blur) # List of non square outputs to compare with the square output nonsq_fragment_gradtensor_list = [ @@ -437,3 +508,293 @@ class TestRasterizeRectangles(TestCaseMixin, unittest.TestCase): # NOTE some pixels can be flaky cond1 = torch.allclose(rgb, image_ref, atol=0.05) self.assertTrue(cond1) + + +class TestRasterizeRectangleImagesPointclouds(TestCaseMixin, unittest.TestCase): + @staticmethod + def _clone_pointcloud(verts0, device, batch_size): + """ + Helper function to detach and clone the verts. + This is needed in order to set up the tensors for + gradient computation in different tests. + """ + verts = verts0.detach().clone() + verts.requires_grad = True + pointclouds = Pointclouds(points=[verts]) + pointclouds = pointclouds.to(device).extend(batch_size) + return verts, pointclouds + + def _rasterize(self, meshes, image_size, bin_size, blur): + """ + Simple wrapper around the rasterize function to return + the fragment data. 
+ """ + idxs, zbuf, dists = rasterize_points( + meshes, + image_size, + blur, + points_per_pixel=1, + bin_size=bin_size, + ) + return PointFragments( + idx=idxs, + zbuf=zbuf, + dists=dists, + ) + + def _check_fragments(self, frag_1, frag_2): + """ + Helper function to check that the tensors in + the Fragments frag_1 and frag_2 are the same. + """ + self.assertClose(frag_1.idx, frag_2.idx) + self.assertClose(frag_1.dists, frag_2.dists) + self.assertClose(frag_1.zbuf, frag_2.zbuf) + + def _compare_square_with_nonsq( + self, + image_size, + blur, + device, + points, + nonsq_fragment_gradtensor_list, + batch_size=1, + ): + """ + Calculate the output from rasterizing a square image with the minimum of (H, W). + Then compare this with the same square region in the non square image. + The input points are contained within the [-1, 1] range of the image + so all the relevant pixels will be within the square region. + + `nonsq_fragment_gradtensor_list` is a list of fragments and verts grad tensors + from rasterizing non square images. + """ + # Rasterize the square version of the image + H, W = image_size + S = min(H, W) + points_square, pointclouds_sq = self._clone_pointcloud( + points, device, batch_size + ) + square_fragments = self._rasterize( + pointclouds_sq, image_size=(S, S), bin_size=0, blur=blur + ) + # Save debug image + _save_debug_image(square_fragments.idx, (S, S), 0, blur) + + # Extract the values in the square image which are non zero. + square_mask = square_fragments.idx > -1 + square_dists = square_fragments.dists[square_mask] + square_zbuf = square_fragments.zbuf[square_mask] + + # Retain gradients on the output of fragments to check + # intermediate values with the non square outputs. 
+ square_fragments.dists.retain_grad() + square_fragments.zbuf.retain_grad() + + # Calculate gradient for the square image + torch.manual_seed(231) + grad_zbuf = torch.randn_like(square_zbuf) + grad_dist = torch.randn_like(square_dists) + loss0 = (grad_dist * square_dists).sum() + (grad_zbuf * square_zbuf).sum() + loss0.backward() + + # Now compare against the non square outputs provided + # in the nonsq_fragment_gradtensor_list list + for fragments, grad_tensor, _name in nonsq_fragment_gradtensor_list: + # Check that there are the same number of non zero pixels + # in both the square and non square images. + non_square_mask = fragments.idx > -1 + self.assertEqual(non_square_mask.sum().item(), square_mask.sum().item()) + + # Check dists and zbuf match the square image + non_square_dists = fragments.dists[non_square_mask] + non_square_zbuf = fragments.zbuf[non_square_mask] + self.assertClose(square_dists, non_square_dists) + self.assertClose(square_zbuf, non_square_zbuf) + + # Retain gradients to compare values with outputs from + # square image + fragments.dists.retain_grad() + fragments.zbuf.retain_grad() + loss1 = (grad_dist * non_square_dists).sum() + ( + grad_zbuf * non_square_zbuf + ).sum() + loss1.sum().backward() + + # Get the non zero values in the intermediate gradients + # and compare with the values from the square image + non_square_grad_dists = fragments.dists.grad[non_square_mask] + non_square_grad_zbuf = fragments.zbuf.grad[non_square_mask] + + self.assertClose( + non_square_grad_dists, + square_fragments.dists.grad[square_mask], + ) + self.assertClose( + non_square_grad_zbuf, + square_fragments.zbuf.grad[square_mask], + ) + + # Finally check the gradients of the input vertices for + # the square and non square case + self.assertClose(points_square.grad, grad_tensor.grad, rtol=2e-4) + + def test_gpu(self): + """ + Test that the output of rendering non square images + gives the same result as square images. i.e.
the + dists, zbuf, idx are all the same for the square + region which is present in both images. + """ + # Test both cases: (W > H), (H > W) + image_sizes = [(64, 128), (128, 64), (128, 256), (256, 128)] + + devices = ["cuda:0"] + blurs = [5e-2] + batch_sizes = [1, 4] + test_cases = product(image_sizes, blurs, devices, batch_sizes) + + for image_size, blur, device, batch_size in test_cases: + # Initialize the verts grad tensor and the meshes objects + verts_nonsq_naive, pointcloud_nonsq_naive = self._clone_pointcloud( + torus_points, device, batch_size + ) + verts_nonsq_binned, pointcloud_nonsq_binned = self._clone_pointcloud( + torus_points, device, batch_size + ) + + # Get the outputs for both naive and coarse to fine rasterization + fragments_naive = self._rasterize( + pointcloud_nonsq_naive, + image_size, + blur=blur, + bin_size=0, + ) + fragments_binned = self._rasterize( + pointcloud_nonsq_binned, + image_size, + blur=blur, + bin_size=None, + ) + + # Save out debug images if needed + _save_debug_image(fragments_naive.idx, image_size, 0, blur) + _save_debug_image(fragments_binned.idx, image_size, None, blur) + + # Check naive and binned fragments give the same outputs + self._check_fragments(fragments_naive, fragments_binned) + + # Here we want to compare the square image with the naive and the + # coarse to fine methods outputs + nonsq_fragment_gradtensor_list = [ + (fragments_naive, verts_nonsq_naive, "naive"), + (fragments_binned, verts_nonsq_binned, "coarse-to-fine"), + ] + + self._compare_square_with_nonsq( + image_size, + blur, + device, + torus_points, + nonsq_fragment_gradtensor_list, + batch_size, + ) + + def test_cpu(self): + """ + Test that the output of rendering non square images + gives the same result as square images. i.e. the + dists, zbuf, idx are all the same for the square + region which is present in both images. 
+ + In this test we compare between the naive C++ implementation + and the naive python implementation as the Coarse/Fine + method is not fully implemented in C++ + """ + # Test both when (W > H) and (H > W). + # Using smaller image sizes here as the Python rasterizer is really slow. + image_sizes = [(32, 64), (64, 32)] + devices = ["cpu"] + blurs = [5e-2] + batch_sizes = [1] + test_cases = product(image_sizes, blurs, devices, batch_sizes) + + for image_size, blur, device, batch_size in test_cases: + # Initialize the verts grad tensor and the meshes objects + verts_nonsq_naive, pointcloud_nonsq_naive = self._clone_pointcloud( + torus_points, device, batch_size + ) + verts_nonsq_python, pointcloud_nonsq_python = self._clone_pointcloud( + torus_points, device, batch_size + ) + + # Compare Naive CPU with Python as Coarse/Fine rasterization + # is not implemented for CPU + fragments_naive = self._rasterize( + pointcloud_nonsq_naive, image_size, bin_size=0, blur=blur + ) + idxs, zbuf, pix_dists = rasterize_points_python( + pointcloud_nonsq_python, + image_size, + blur, + points_per_pixel=1, + ) + fragments_python = PointFragments( + idx=idxs, + zbuf=zbuf, + dists=pix_dists, + ) + + # Save debug images if DEBUG is set to true at the top of the file. + _save_debug_image(fragments_naive.idx, image_size, 0, blur) + _save_debug_image(fragments_python.idx, image_size, "python", blur) + + # List of non square outputs to compare with the square output + nonsq_fragment_gradtensor_list = [ + (fragments_naive, verts_nonsq_naive, "naive"), + (fragments_python, verts_nonsq_python, "python"), + ] + self._compare_square_with_nonsq( + image_size, + blur, + device, + torus_points, + nonsq_fragment_gradtensor_list, + batch_size, + ) + + def test_render_pointcloud(self): + """ + Test a textured pointcloud is rendered correctly in a non square image.
+ """ + device = torch.device("cuda:0") + pointclouds = Pointclouds( + points=[torus_points * 2.0], + features=torch.ones_like(torus_points[None, ...]), + ).to(device) + R, T = look_at_view_transform(2.7, 0.0, 0.0) + cameras = FoVPerspectiveCameras(device=device, R=R, T=T) + raster_settings = PointsRasterizationSettings( + image_size=(512, 1024), radius=5e-2, points_per_pixel=1 + ) + rasterizer = PointsRasterizer(cameras=cameras, raster_settings=raster_settings) + compositor = AlphaCompositor() + renderer = PointsRenderer(rasterizer=rasterizer, compositor=compositor) + + # Load reference image + image_ref = load_rgb_image("test_pointcloud_rectangle_image.png", DATA_DIR) + + for bin_size in [0, None]: + # Check both naive and coarse to fine produce the same output. + renderer.rasterizer.raster_settings.bin_size = bin_size + images = renderer(pointclouds) + rgb = images[0, ..., :3].squeeze().cpu() + + if DEBUG: + Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save( + DATA_DIR / "DEBUG_pointcloud_rectangle_image.png" + ) + + # NOTE some pixels can be flaky + cond1 = torch.allclose(rgb, image_ref, atol=0.05) + self.assertTrue(cond1)