diff --git a/pytorch3d/csrc/compositing/alpha_composite.cu b/pytorch3d/csrc/compositing/alpha_composite.cu index 389b95f8..16fadb02 100644 --- a/pytorch3d/csrc/compositing/alpha_composite.cu +++ b/pytorch3d/csrc/compositing/alpha_composite.cu @@ -30,15 +30,15 @@ __global__ void alphaCompositeCudaForwardKernel( // Get the batch and index const int batch = blockIdx.x; - const int num_pixels = C * W * H; + const int num_pixels = C * H * W; const int num_threads = gridDim.y * blockDim.x; const int tid = blockIdx.y * blockDim.x + threadIdx.x; // Iterate over each feature in each pixel for (int pid = tid; pid < num_pixels; pid += num_threads) { - int ch = pid / (W * H); - int j = (pid % (W * H)) / H; - int i = (pid % (W * H)) % H; + int ch = pid / (H * W); + int j = (pid % (H * W)) / W; + int i = (pid % (H * W)) % W; // alphacomposite the different values float cum_alpha = 1.; @@ -81,16 +81,16 @@ __global__ void alphaCompositeCudaBackwardKernel( // Get the batch and index const int batch = blockIdx.x; - const int num_pixels = C * W * H; + const int num_pixels = C * H * W; const int num_threads = gridDim.y * blockDim.x; const int tid = blockIdx.y * blockDim.x + threadIdx.x; // Parallelize over each feature in each pixel in images of size H * W, // for each image in the batch of size batch_size for (int pid = tid; pid < num_pixels; pid += num_threads) { - int ch = pid / (W * H); - int j = (pid % (W * H)) / H; - int i = (pid % (W * H)) % H; + int ch = pid / (H * W); + int j = (pid % (H * W)) / W; + int i = (pid % (H * W)) % W; // alphacomposite the different values float cum_alpha = 1.; diff --git a/pytorch3d/csrc/compositing/alpha_composite.h b/pytorch3d/csrc/compositing/alpha_composite.h index 735d87e1..c910c32d 100644 --- a/pytorch3d/csrc/compositing/alpha_composite.h +++ b/pytorch3d/csrc/compositing/alpha_composite.h @@ -11,13 +11,13 @@ // features: FloatTensor of shape (C, P) which gives the features // of each point where C is the size of the feature and // P the 
number of points. -// alphas: FloatTensor of shape (N, points_per_pixel, W, W) where +// alphas: FloatTensor of shape (N, points_per_pixel, H, W) where // points_per_pixel is the number of points in the z-buffer -// sorted in z-order, and W is the image size. -// points_idx: IntTensor of shape (N, points_per_pixel, W, W) giving the +// sorted in z-order, and (H, W) is the image size. +// points_idx: IntTensor of shape (N, points_per_pixel, H, W) giving the // indices of the nearest points at each pixel, sorted in z-order. // Returns: -// weighted_fs: FloatTensor of shape (N, C, W, W) giving the accumulated +// weighted_fs: FloatTensor of shape (N, C, H, W) giving the accumulated // feature for each point. Concretely, it gives: // weighted_fs[b,c,i,j] = sum_k cum_alpha_k * // features[c,points_idx[b,k,i,j]] diff --git a/pytorch3d/csrc/compositing/norm_weighted_sum.cu b/pytorch3d/csrc/compositing/norm_weighted_sum.cu index a787e1fa..1885bec6 100644 --- a/pytorch3d/csrc/compositing/norm_weighted_sum.cu +++ b/pytorch3d/csrc/compositing/norm_weighted_sum.cu @@ -30,16 +30,16 @@ __global__ void weightedSumNormCudaForwardKernel( // Get the batch and index const int batch = blockIdx.x; - const int num_pixels = C * W * H; + const int num_pixels = C * H * W; const int num_threads = gridDim.y * blockDim.x; const int tid = blockIdx.y * blockDim.x + threadIdx.x; // Parallelize over each feature in each pixel in images of size H * W, // for each image in the batch of size batch_size for (int pid = tid; pid < num_pixels; pid += num_threads) { - int ch = pid / (W * H); - int j = (pid % (W * H)) / H; - int i = (pid % (W * H)) % H; + int ch = pid / (H * W); + int j = (pid % (H * W)) / W; + int i = (pid % (H * W)) % W; // Store the accumulated alpha value float cum_alpha = 0.; @@ -101,9 +101,9 @@ __global__ void weightedSumNormCudaBackwardKernel( // Parallelize over each feature in each pixel in images of size H * W, // for each image in the batch of size batch_size for (int pid = 
tid; pid < num_pixels; pid += num_threads) { - int ch = pid / (W * H); - int j = (pid % (W * H)) / H; - int i = (pid % (W * H)) % H; + int ch = pid / (H * W); + int j = (pid % (H * W)) / W; + int i = (pid % (H * W)) % W; float sum_alpha = 0.; float sum_alphafs = 0.; diff --git a/pytorch3d/csrc/compositing/norm_weighted_sum.h b/pytorch3d/csrc/compositing/norm_weighted_sum.h index 34c271bc..c2878503 100644 --- a/pytorch3d/csrc/compositing/norm_weighted_sum.h +++ b/pytorch3d/csrc/compositing/norm_weighted_sum.h @@ -11,13 +11,13 @@ // features: FloatTensor of shape (C, P) which gives the features // of each point where C is the size of the feature and // P the number of points. -// alphas: FloatTensor of shape (N, points_per_pixel, W, W) where +// alphas: FloatTensor of shape (N, points_per_pixel, H, W) where // points_per_pixel is the number of points in the z-buffer -// sorted in z-order, and W is the image size. -// points_idx: IntTensor of shape (N, points_per_pixel, W, W) giving the +// sorted in z-order, and (H, W) is the image size. +// points_idx: IntTensor of shape (N, points_per_pixel, H, W) giving the // indices of the nearest points at each pixel, sorted in z-order. // Returns: -// weighted_fs: FloatTensor of shape (N, C, W, W) giving the accumulated +// weighted_fs: FloatTensor of shape (N, C, H, W) giving the accumulated // feature in each point. 
Concretely, it gives: // weighted_fs[b,c,i,j] = sum_k alphas[b,k,i,j] * // features[c,points_idx[b,k,i,j]] / sum_k alphas[b,k,i,j] diff --git a/pytorch3d/csrc/compositing/weighted_sum.cu b/pytorch3d/csrc/compositing/weighted_sum.cu index 68ec351e..cee8e75a 100644 --- a/pytorch3d/csrc/compositing/weighted_sum.cu +++ b/pytorch3d/csrc/compositing/weighted_sum.cu @@ -28,16 +28,16 @@ __global__ void weightedSumCudaForwardKernel( // Get the batch and index const int batch = blockIdx.x; - const int num_pixels = C * W * H; + const int num_pixels = C * H * W; const int num_threads = gridDim.y * blockDim.x; const int tid = blockIdx.y * blockDim.x + threadIdx.x; // Parallelize over each feature in each pixel in images of size H * W, // for each image in the batch of size batch_size for (int pid = tid; pid < num_pixels; pid += num_threads) { - int ch = pid / (W * H); - int j = (pid % (W * H)) / H; - int i = (pid % (W * H)) % H; + int ch = pid / (H * W); + int j = (pid % (H * W)) / W; + int i = (pid % (H * W)) % W; // Iterate through the closest K points for this pixel for (int k = 0; k < points_idx.size(1); ++k) { @@ -76,16 +76,16 @@ __global__ void weightedSumCudaBackwardKernel( // Get the batch and index const int batch = blockIdx.x; - const int num_pixels = C * W * H; + const int num_pixels = C * H * W; const int num_threads = gridDim.y * blockDim.x; const int tid = blockIdx.y * blockDim.x + threadIdx.x; // Iterate over each pixel to compute the contribution to the // gradient for the features and weights for (int pid = tid; pid < num_pixels; pid += num_threads) { - int ch = pid / (W * H); - int j = (pid % (W * H)) / H; - int i = (pid % (W * H)) % H; + int ch = pid / (H * W); + int j = (pid % (H * W)) / W; + int i = (pid % (H * W)) % W; // Iterate through the closest K points for this pixel for (int k = 0; k < points_idx.size(1); ++k) { diff --git a/pytorch3d/csrc/compositing/weighted_sum.h b/pytorch3d/csrc/compositing/weighted_sum.h index 4928a252..aa4154ed 100644 --- 
a/pytorch3d/csrc/compositing/weighted_sum.h +++ b/pytorch3d/csrc/compositing/weighted_sum.h @@ -11,13 +11,13 @@ // features: FloatTensor of shape (C, P) which gives the features // of each point where C is the size of the feature and // P the number of points. -// alphas: FloatTensor of shape (N, points_per_pixel, W, W) where +// alphas: FloatTensor of shape (N, points_per_pixel, H, W) where // points_per_pixel is the number of points in the z-buffer -// sorted in z-order, and W is the image size. +// sorted in z-order, and (H, W) is the image size. -// points_idx: IntTensor of shape (N, points_per_pixel, W, W) giving the +// points_idx: IntTensor of shape (N, points_per_pixel, H, W) giving the // indices of the nearest points at each pixel, sorted in z-order. // Returns: -// weighted_fs: FloatTensor of shape (N, C, W, W) giving the accumulated +// weighted_fs: FloatTensor of shape (N, C, H, W) giving the accumulated // feature in each point. Concretely, it gives: // weighted_fs[b,c,i,j] = sum_k alphas[b,k,i,j] * // features[c,points_idx[b,k,i,j]] diff --git a/pytorch3d/csrc/rasterize_meshes/rasterize_meshes.cu b/pytorch3d/csrc/rasterize_meshes/rasterize_meshes.cu index af973f38..a92a64e5 100644 --- a/pytorch3d/csrc/rasterize_meshes/rasterize_meshes.cu +++ b/pytorch3d/csrc/rasterize_meshes/rasterize_meshes.cu @@ -452,7 +452,6 @@ __global__ void RasterizeMeshesBackwardCudaKernel( const bool inside = b_pp.x > 0.0f && b_pp.y > 0.0f && b_pp.z > 0.0f; const float sign = inside ? -1.0f : 1.0f; - // TODO(T52813608) Add support for non-square images. 
auto grad_dist_f = PointTriangleDistanceBackward( pxy, v0xy, v1xy, v2xy, sign * grad_dist_upstream); const float2 ddist_d_v0 = thrust::get<1>(grad_dist_f); @@ -606,7 +605,7 @@ __global__ void RasterizeMeshesCoarseCudaKernel( const float half_pix_x = NDC_x_half_range / W; const float half_pix_y = NDC_y_half_range / H; - // This is a boolean array of shape (num_bins, num_bins, chunk_size) + // This is a boolean array of shape (num_bins_y, num_bins_x, chunk_size) // stored in shared memory that will track whether each point in the chunk // falls into each bin of the image. BitMask binmask((unsigned int*)sbuf, num_bins_y, num_bins_x, chunk_size); @@ -755,7 +754,7 @@ at::Tensor RasterizeMeshesCoarseCuda( const int num_bins_y = 1 + (H - 1) / bin_size; const int num_bins_x = 1 + (W - 1) / bin_size; - if (num_bins_y >= kMaxFacesPerBin || num_bins_x >= kMaxFacesPerBin) { + if (num_bins_y >= kMaxItemsPerBin || num_bins_x >= kMaxItemsPerBin) { std::stringstream ss; ss << "In Coarse Rasterizer got num_bins_y: " << num_bins_y << ", num_bins_x: " << num_bins_x << ", " @@ -800,7 +799,7 @@ at::Tensor RasterizeMeshesCoarseCuda( // **************************************************************************** __global__ void RasterizeMeshesFineCudaKernel( const float* face_verts, // (F, 3, 3) - const int32_t* bin_faces, // (N, B, B, T) + const int32_t* bin_faces, // (N, BH, BW, T) const float blur_radius, const int bin_size, const bool perspective_correct, @@ -813,12 +812,12 @@ __global__ void RasterizeMeshesFineCudaKernel( const int H, const int W, const int K, - int64_t* face_idxs, // (N, S, S, K) - float* zbuf, // (N, S, S, K) - float* pix_dists, // (N, S, S, K) - float* bary // (N, S, S, K, 3) + int64_t* face_idxs, // (N, H, W, K) + float* zbuf, // (N, H, W, K) + float* pix_dists, // (N, H, W, K) + float* bary // (N, H, W, K, 3) ) { - // This can be more than S^2 if S % bin_size != 0 + // This can be more than H * W if H or W are not divisible by bin_size. 
int num_pixels = N * BH * BW * bin_size * bin_size; int num_threads = gridDim.x * blockDim.x; int tid = blockIdx.x * blockDim.x + threadIdx.x; diff --git a/pytorch3d/csrc/rasterize_meshes/rasterize_meshes_cpu.cpp b/pytorch3d/csrc/rasterize_meshes/rasterize_meshes_cpu.cpp index b8a73e20..3160e685 100644 --- a/pytorch3d/csrc/rasterize_meshes/rasterize_meshes_cpu.cpp +++ b/pytorch3d/csrc/rasterize_meshes/rasterize_meshes_cpu.cpp @@ -5,41 +5,11 @@ #include #include #include +#include "rasterize_points/rasterization_utils.h" #include "utils/geometry_utils.h" #include "utils/vec2.h" #include "utils/vec3.h" -// The default value of the NDC range is [-1, 1], however in the case that -// H != W, the NDC range is set such that the shorter side has range [-1, 1] and -// the longer side is scaled by the ratio of H:W. S1 is the dimension for which -// the NDC range is calculated and S2 is the other image dimension. -// e.g. to get the NDC x range S1 = W and S2 = H -float NonSquareNdcRange(int S1, int S2) { - float range = 2.0f; - if (S1 > S2) { - range = ((S1 / S2) * range); - } - return range; -} - -// Given a pixel coordinate 0 <= i < S1, convert it to a normalized device -// coordinates. We divide the NDC range into S1 evenly-sized -// pixels, and assume that each pixel falls in the *center* of its range. -// The default value of the NDC range is [-1, 1], however in the case that -// H != W, the NDC range is set such that the shorter side has range [-1, 1] and -// the longer side is scaled by the ratio of H:W. The dimension of i should be -// S1 and the other image dimension is S2 For example, to get the x and y NDC -// coordinates or a given pixel i: -// x = PixToNonSquareNdc(i, W, H) -// y = PixToNonSquareNdc(i, H, W) -float PixToNonSquareNdc(int i, int S1, int S2) { - float range = NonSquareNdcRange(S1, S2); - // NDC: offset + (i * pixel_width + half_pixel_width) - // The NDC range is [-range/2, range/2]. 
- const float offset = (range / 2.0f); - return -offset + (range * i + offset) / S1; -} - // Get (x, y, z) values for vertex from (3, 3) tensor face. template auto ExtractVerts(const Face& face, const int vertex_index) { diff --git a/pytorch3d/csrc/rasterize_points/rasterization_utils.cuh b/pytorch3d/csrc/rasterize_points/rasterization_utils.cuh index 8492bad1..18272ab7 100644 --- a/pytorch3d/csrc/rasterize_points/rasterization_utils.cuh +++ b/pytorch3d/csrc/rasterize_points/rasterization_utils.cuh @@ -2,16 +2,6 @@ #pragma once -// Given a pixel coordinate 0 <= i < S, convert it to a normalized device -// coordinates in the range [-1, 1]. We divide the NDC range into S evenly-sized -// pixels, and assume that each pixel falls in the *center* of its range. -// TODO: delete this function after updating the pointcloud rasterizer to -// support non square images. -__device__ inline float PixToNdc(int i, int S) { - // NDC: x-offset + (i * pixel_width + half_pixel_width) - return -1.0 + (2 * i + 1.0) / S; -} - // The default value of the NDC range is [-1, 1], however in the case that // H != W, the NDC range is set such that the shorter side has range [-1, 1] and // the longer side is scaled by the ratio of H:W. S1 is the dimension for which @@ -50,7 +40,7 @@ __device__ inline float PixToNonSquareNdc(int i, int S1, int S2) { // TODO: is 8 enough? Would increasing have performance considerations? const int32_t kMaxPointsPerPixel = 150; -const int32_t kMaxFacesPerBin = 22; +const int32_t kMaxItemsPerBin = 22; template __device__ inline void BubbleSort(T* arr, int n) { diff --git a/pytorch3d/csrc/rasterize_points/rasterization_utils.h b/pytorch3d/csrc/rasterize_points/rasterization_utils.h new file mode 100644 index 00000000..06b6bc5c --- /dev/null +++ b/pytorch3d/csrc/rasterize_points/rasterization_utils.h @@ -0,0 +1,34 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 
+ +#pragma once + +// The default value of the NDC range is [-1, 1], however in the case that +// H != W, the NDC range is set such that the shorter side has range [-1, 1] and +// the longer side is scaled by the ratio of H:W. S1 is the dimension for which +// the NDC range is calculated and S2 is the other image dimension. +// e.g. to get the NDC x range S1 = W and S2 = H +inline float NonSquareNdcRange(int S1, int S2) { + float range = 2.0f; + if (S1 > S2) { + range = (S1 * range) / S2; // multiply by the float first: int S1 / S2 would truncate the ratio + } + return range; +} + +// Given a pixel coordinate 0 <= i < S1, convert it to normalized device +// coordinates. We divide the NDC range into S1 evenly-sized +// pixels, and assume that each pixel falls in the *center* of its range. +// The default value of the NDC range is [-1, 1], however in the case that +// H != W, the NDC range is set such that the shorter side has range [-1, 1] and +// the longer side is scaled by the ratio of H:W. The dimension of i should be +// S1 and the other image dimension is S2. For example, to get the x and y NDC +// coordinates of a given pixel i: +// x = PixToNonSquareNdc(i, W, H) +// y = PixToNonSquareNdc(i, H, W) +inline float PixToNonSquareNdc(int i, int S1, int S2) { + float range = NonSquareNdcRange(S1, S2); + // NDC: offset + (i * pixel_width + half_pixel_width) + // The NDC range is [-range/2, range/2]. 
+ const float offset = (range / 2.0f); + return -offset + (range * i + offset) / S1; +} diff --git a/pytorch3d/csrc/rasterize_points/rasterize_points.cu b/pytorch3d/csrc/rasterize_points/rasterize_points.cu index d02a5680..8b5ea133 100644 --- a/pytorch3d/csrc/rasterize_points/rasterize_points.cu +++ b/pytorch3d/csrc/rasterize_points/rasterize_points.cu @@ -85,26 +85,28 @@ __global__ void RasterizePointsNaiveCudaKernel( const int64_t* num_points_per_cloud, // (N) const float* radius, const int N, - const int S, + const int H, + const int W, const int K, - int32_t* point_idxs, // (N, S, S, K) - float* zbuf, // (N, S, S, K) - float* pix_dists) { // (N, S, S, K) + int32_t* point_idxs, // (N, H, W, K) + float* zbuf, // (N, H, W, K) + float* pix_dists) { // (N, H, W, K) // Simple version: One thread per output pixel const int num_threads = gridDim.x * blockDim.x; const int tid = blockDim.x * blockIdx.x + threadIdx.x; - for (int i = tid; i < N * S * S; i += num_threads) { + for (int i = tid; i < N * H * W; i += num_threads) { // Convert linear index to 3D index - const int n = i / (S * S); // Batch index - const int pix_idx = i % (S * S); + const int n = i / (H * W); // Batch index + const int pix_idx = i % (H * W); // Reverse ordering of the X and Y axis as the camera coordinates // assume that +Y is pointing up and +X is pointing left. - const int yi = S - 1 - pix_idx / S; - const int xi = S - 1 - pix_idx % S; + const int yi = H - 1 - pix_idx / W; + const int xi = W - 1 - pix_idx % W; - const float xf = PixToNdc(xi, S); - const float yf = PixToNdc(yi, S); + // Convert the screen coordinates of the pixel to NDC coordinates. 
+ const float xf = PixToNonSquareNdc(xi, W, H); + const float yf = PixToNonSquareNdc(yi, H, W); // For keeping track of the K closest points we want a data structure // that (1) gives O(1) access to the closest point for easy comparisons, @@ -132,7 +134,7 @@ __global__ void RasterizePointsNaiveCudaKernel( points, p_idx, q_size, q_max_z, q_max_idx, q, radius, xf, yf, K); } BubbleSort(q, q_size); - int idx = n * S * S * K + pix_idx * K; + int idx = n * H * W * K + pix_idx * K; for (int k = 0; k < q_size; ++k) { point_idxs[idx + k] = q[k].idx; zbuf[idx + k] = q[k].z; @@ -145,7 +147,7 @@ std::tuple RasterizePointsNaiveCuda( const at::Tensor& points, // (P. 3) const at::Tensor& cloud_to_packed_first_idx, // (N) const at::Tensor& num_points_per_cloud, // (N) - const int image_size, + const std::tuple image_size, const at::Tensor& radius, const int points_per_pixel) { // Check inputs are on the same device @@ -169,7 +171,8 @@ std::tuple RasterizePointsNaiveCuda( "num_points_per_cloud must have same size first dimension as cloud_to_packed_first_idx"); const int N = num_points_per_cloud.size(0); // batch size. 
- const int S = image_size; + const int H = std::get<0>(image_size); + const int W = std::get<1>(image_size); const int K = points_per_pixel; if (K > kMaxPointsPerPixel) { @@ -180,9 +183,9 @@ std::tuple RasterizePointsNaiveCuda( auto int_opts = num_points_per_cloud.options().dtype(at::kInt); auto float_opts = points.options().dtype(at::kFloat); - at::Tensor point_idxs = at::full({N, S, S, K}, -1, int_opts); - at::Tensor zbuf = at::full({N, S, S, K}, -1, float_opts); - at::Tensor pix_dists = at::full({N, S, S, K}, -1, float_opts); + at::Tensor point_idxs = at::full({N, H, W, K}, -1, int_opts); + at::Tensor zbuf = at::full({N, H, W, K}, -1, float_opts); + at::Tensor pix_dists = at::full({N, H, W, K}, -1, float_opts); if (point_idxs.numel() == 0) { AT_CUDA_CHECK(cudaGetLastError()); @@ -197,7 +200,8 @@ std::tuple RasterizePointsNaiveCuda( num_points_per_cloud.contiguous().data_ptr(), radius.contiguous().data_ptr(), N, - S, + H, + W, K, point_idxs.contiguous().data_ptr(), zbuf.contiguous().data_ptr(), @@ -218,7 +222,8 @@ __global__ void RasterizePointsCoarseCudaKernel( const float* radius, const int N, const int P, - const int S, + const int H, + const int W, const int bin_size, const int chunk_size, const int max_points_per_bin, @@ -226,13 +231,26 @@ __global__ void RasterizePointsCoarseCudaKernel( int* bin_points) { extern __shared__ char sbuf[]; const int M = max_points_per_bin; - const int num_bins = 1 + (S - 1) / bin_size; // Integer divide round up - const float half_pix = 1.0f / S; // Size of half a pixel in NDC units - // This is a boolean array of shape (num_bins, num_bins, chunk_size) + // Integer divide round up + const int num_bins_x = 1 + (W - 1) / bin_size; + const int num_bins_y = 1 + (H - 1) / bin_size; + + // NDC range depends on the ratio of W/H + // The shorter side from (H, W) is given an NDC range of 2.0 and + // the other side is scaled by the ratio of H:W. 
+ const float NDC_x_half_range = NonSquareNdcRange(W, H) / 2.0f; + const float NDC_y_half_range = NonSquareNdcRange(H, W) / 2.0f; + + // Size of half a pixel in NDC units is the NDC half range + // divided by the corresponding image dimension + const float half_pix_x = NDC_x_half_range / W; + const float half_pix_y = NDC_y_half_range / H; + + // This is a boolean array of shape (num_bins_y, num_bins_x, chunk_size) // stored in shared memory that will track whether each point in the chunk // falls into each bin of the image. - BitMask binmask((unsigned int*)sbuf, num_bins, num_bins, chunk_size); + BitMask binmask((unsigned int*)sbuf, num_bins_y, num_bins_x, chunk_size); // Have each block handle a chunk of points and build a 3D bitmask in // shared memory to mark which points hit which bins. In this first phase, @@ -279,22 +297,24 @@ __global__ void RasterizePointsCoarseCudaKernel( // For example we could compute the exact bin where the point falls, // then check neighboring bins. This way we wouldn't have to check // all bins (however then we might have more warp divergence?) - for (int by = 0; by < num_bins; ++by) { - // Get y extent for the bin. PixToNdc gives us the location of + for (int by = 0; by < num_bins_y; ++by) { + // Get y extent for the bin. PixToNonSquareNdc gives us the location of // the center of each pixel, so we need to add/subtract a half // pixel to get the true extent of the bin. - const float by0 = PixToNdc(by * bin_size, S) - half_pix; - const float by1 = PixToNdc((by + 1) * bin_size - 1, S) + half_pix; + const float by0 = PixToNonSquareNdc(by * bin_size, H, W) - half_pix_y; + const float by1 = + PixToNonSquareNdc((by + 1) * bin_size - 1, H, W) + half_pix_y; const bool y_overlap = (py0 <= by1) && (by0 <= py1); if (!y_overlap) { continue; } - for (int bx = 0; bx < num_bins; ++bx) { + for (int bx = 0; bx < num_bins_x; ++bx) { // Get x extent for the bin; again we need to adjust the - // output of PixToNdc by half a pixel. 
- const float bx0 = PixToNdc(bx * bin_size, S) - half_pix; - const float bx1 = PixToNdc((bx + 1) * bin_size - 1, S) + half_pix; + // output of PixToNonSquareNdc by half a pixel. + const float bx0 = PixToNonSquareNdc(bx * bin_size, W, H) - half_pix_x; + const float bx1 = + PixToNonSquareNdc((bx + 1) * bin_size - 1, W, H) + half_pix_x; const bool x_overlap = (px0 <= bx1) && (bx0 <= px1); if (x_overlap) { @@ -307,12 +327,13 @@ __global__ void RasterizePointsCoarseCudaKernel( // Now we have processed every point in the current chunk. We need to // count the number of points in each bin so we can write the indices // out to global memory. We have each thread handle a different bin. - for (int byx = threadIdx.x; byx < num_bins * num_bins; byx += blockDim.x) { - const int by = byx / num_bins; - const int bx = byx % num_bins; + for (int byx = threadIdx.x; byx < num_bins_y * num_bins_x; + byx += blockDim.x) { + const int by = byx / num_bins_x; + const int bx = byx % num_bins_x; const int count = binmask.count(by, bx); const int points_per_bin_idx = - batch_idx * num_bins * num_bins + by * num_bins + bx; + batch_idx * num_bins_y * num_bins_x + by * num_bins_x + bx; // This atomically increments the (global) number of points found // in the current bin, and gets the previous value of the counter; @@ -322,8 +343,8 @@ __global__ void RasterizePointsCoarseCudaKernel( // Now loop over the binmask and write the active bits for this bin // out to bin_points. 
- int next_idx = batch_idx * num_bins * num_bins * M + by * num_bins * M + - bx * M + start; + int next_idx = batch_idx * num_bins_y * num_bins_x * M + + by * num_bins_x * M + bx * M + start; for (int p = 0; p < chunk_size; ++p) { if (binmask.get(by, bx, p)) { // TODO: Throw an error if next_idx >= M -- this means that @@ -342,7 +363,7 @@ at::Tensor RasterizePointsCoarseCuda( const at::Tensor& points, // (P, 3) const at::Tensor& cloud_to_packed_first_idx, // (N) const at::Tensor& num_points_per_cloud, // (N) - const int image_size, + const std::tuple image_size, const at::Tensor& radius, const int bin_size, const int max_points_per_bin) { @@ -363,20 +384,28 @@ at::Tensor RasterizePointsCoarseCuda( at::cuda::CUDAGuard device_guard(points.device()); cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const int H = std::get<0>(image_size); + const int W = std::get<1>(image_size); + const int P = points.size(0); const int N = num_points_per_cloud.size(0); - const int num_bins = 1 + (image_size - 1) / bin_size; // divide round up const int M = max_points_per_bin; - if (num_bins >= 22) { + // Integer divide round up. + const int num_bins_y = 1 + (H - 1) / bin_size; + const int num_bins_x = 1 + (W - 1) / bin_size; + + if (num_bins_y >= kMaxItemsPerBin || num_bins_x >= kMaxItemsPerBin) { // Make sure we do not use too much shared memory. 
std::stringstream ss; - ss << "Got " << num_bins << "; that's too many!"; + ss << "In Coarse Rasterizer got num_bins_y: " << num_bins_y + << ", num_bins_x: " << num_bins_x << ", " + << "; that's too many!"; AT_ERROR(ss.str()); } auto opts = num_points_per_cloud.options().dtype(at::kInt); - at::Tensor points_per_bin = at::zeros({N, num_bins, num_bins}, opts); - at::Tensor bin_points = at::full({N, num_bins, num_bins, M}, -1, opts); + at::Tensor points_per_bin = at::zeros({N, num_bins_y, num_bins_x}, opts); + at::Tensor bin_points = at::full({N, num_bins_y, num_bins_x, M}, -1, opts); if (bin_points.numel() == 0) { AT_CUDA_CHECK(cudaGetLastError()); @@ -384,7 +413,7 @@ at::Tensor RasterizePointsCoarseCuda( } const int chunk_size = 512; - const size_t shared_size = num_bins * num_bins * chunk_size / 8; + const size_t shared_size = num_bins_y * num_bins_x * chunk_size / 8; const size_t blocks = 64; const size_t threads = 512; @@ -395,7 +424,8 @@ at::Tensor RasterizePointsCoarseCuda( radius.contiguous().data_ptr(), N, P, - image_size, + H, + W, bin_size, chunk_size, M, @@ -412,19 +442,21 @@ at::Tensor RasterizePointsCoarseCuda( __global__ void RasterizePointsFineCudaKernel( const float* points, // (P, 3) - const int32_t* bin_points, // (N, B, B, T) + const int32_t* bin_points, // (N, BH, BW, T) const float* radius, const int bin_size, const int N, - const int B, // num_bins + const int BH, // num_bins y + const int BW, // num_bins x const int M, - const int S, + const int H, + const int W, const int K, - int32_t* point_idxs, // (N, S, S, K) - float* zbuf, // (N, S, S, K) - float* pix_dists) { // (N, S, S, K) - // This can be more than S^2 if S is not dividable by bin_size. - const int num_pixels = N * B * B * bin_size * bin_size; + int32_t* point_idxs, // (N, H, W, K) + float* zbuf, // (N, H, W, K) + float* pix_dists) { // (N, H, W, K) + // This can be more than H * W if H or W are not divisible by bin_size. 
+ const int num_pixels = N * BH * BW * bin_size * bin_size; const int num_threads = gridDim.x * blockDim.x; const int tid = blockIdx.x * blockDim.x + threadIdx.x; @@ -434,21 +466,21 @@ __global__ void RasterizePointsFineCudaKernel( // into the same bin; this should give them coalesced memory reads when // they read from points and bin_points. int i = pid; - const int n = i / (B * B * bin_size * bin_size); - i %= B * B * bin_size * bin_size; - const int by = i / (B * bin_size * bin_size); - i %= B * bin_size * bin_size; + const int n = i / (BH * BW * bin_size * bin_size); + i %= BH * BW * bin_size * bin_size; + const int by = i / (BW * bin_size * bin_size); + i %= BW * bin_size * bin_size; const int bx = i / (bin_size * bin_size); i %= bin_size * bin_size; const int yi = i / bin_size + by * bin_size; const int xi = i % bin_size + bx * bin_size; - if (yi >= S || xi >= S) + if (yi >= H || xi >= W) continue; - const float xf = PixToNdc(xi, S); - const float yf = PixToNdc(yi, S); + const float xf = PixToNonSquareNdc(xi, W, H); + const float yf = PixToNonSquareNdc(yi, H, W); // This part looks like the naive rasterization kernel, except we use // bin_points to only look at a subset of points already known to fall @@ -459,7 +491,7 @@ __global__ void RasterizePointsFineCudaKernel( float q_max_z = -1000; int q_max_idx = -1; for (int m = 0; m < M; ++m) { - const int p = bin_points[n * B * B * M + by * B * M + bx * M + m]; + const int p = bin_points[n * BH * BW * M + by * BW * M + bx * M + m]; if (p < 0) { // bin_points uses -1 as a sentinal value continue; @@ -473,10 +505,10 @@ __global__ void RasterizePointsFineCudaKernel( // Reverse ordering of the X and Y axis as the camera coordinates // assume that +Y is pointing up and +X is pointing left. 
- const int yidx = S - 1 - yi; - const int xidx = S - 1 - xi; + const int yidx = H - 1 - yi; + const int xidx = W - 1 - xi; - const int pix_idx = n * S * S * K + yidx * S * K + xidx * K; + const int pix_idx = n * H * W * K + yidx * W * K + xidx * K; for (int k = 0; k < q_size; ++k) { point_idxs[pix_idx + k] = q[k].idx; zbuf[pix_idx + k] = q[k].z; @@ -488,7 +520,7 @@ __global__ void RasterizePointsFineCudaKernel( std::tuple RasterizePointsFineCuda( const at::Tensor& points, // (P, 3) const at::Tensor& bin_points, - const int image_size, + const std::tuple image_size, const at::Tensor& radius, const int bin_size, const int points_per_pixel) { @@ -503,18 +535,22 @@ std::tuple RasterizePointsFineCuda( cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const int N = bin_points.size(0); - const int B = bin_points.size(1); // num_bins + const int BH = bin_points.size(1); + const int BW = bin_points.size(2); const int M = bin_points.size(3); - const int S = image_size; const int K = points_per_pixel; + + const int H = std::get<0>(image_size); + const int W = std::get<1>(image_size); + if (K > kMaxPointsPerPixel) { AT_ERROR("Must have num_closest <= 150"); } auto int_opts = bin_points.options().dtype(at::kInt); auto float_opts = points.options().dtype(at::kFloat); - at::Tensor point_idxs = at::full({N, S, S, K}, -1, int_opts); - at::Tensor zbuf = at::full({N, S, S, K}, -1, float_opts); - at::Tensor pix_dists = at::full({N, S, S, K}, -1, float_opts); + at::Tensor point_idxs = at::full({N, H, W, K}, -1, int_opts); + at::Tensor zbuf = at::full({N, H, W, K}, -1, float_opts); + at::Tensor pix_dists = at::full({N, H, W, K}, -1, float_opts); if (point_idxs.numel() == 0) { AT_CUDA_CHECK(cudaGetLastError()); @@ -529,9 +565,11 @@ std::tuple RasterizePointsFineCuda( radius.contiguous().data_ptr(), bin_size, N, - B, + BH, + BW, M, - S, + H, + W, K, point_idxs.contiguous().data_ptr(), zbuf.contiguous().data_ptr(), @@ -571,8 +609,8 @@ __global__ void 
RasterizePointsBackwardCudaKernel( const int yidx = H - 1 - yi; const int xidx = W - 1 - xi; - const float xf = PixToNdc(xidx, W); - const float yf = PixToNdc(yidx, H); + const float xf = PixToNonSquareNdc(xidx, W, H); + const float yf = PixToNonSquareNdc(yidx, H, W); const int p = idxs[i]; if (p < 0) diff --git a/pytorch3d/csrc/rasterize_points/rasterize_points.h b/pytorch3d/csrc/rasterize_points/rasterize_points.h index f1ec1aaf..a13d9773 100644 --- a/pytorch3d/csrc/rasterize_points/rasterize_points.h +++ b/pytorch3d/csrc/rasterize_points/rasterize_points.h @@ -14,7 +14,7 @@ std::tuple RasterizePointsNaiveCpu( const torch::Tensor& points, const torch::Tensor& cloud_to_packed_first_idx, const torch::Tensor& num_points_per_cloud, - const int image_size, + const std::tuple image_size, const torch::Tensor& radius, const int points_per_pixel); @@ -24,7 +24,7 @@ RasterizePointsNaiveCuda( const torch::Tensor& points, const torch::Tensor& cloud_to_packed_first_idx, const torch::Tensor& num_points_per_cloud, - const int image_size, + const std::tuple image_size, const torch::Tensor& radius, const int points_per_pixel); #endif @@ -43,7 +43,8 @@ RasterizePointsNaiveCuda( // for each pointcloud in the batch. // radius: FloatTensor of shape (P) giving the radius (in NDC units) of // each point in points. -// image_size: (S) Size of the image to return (in pixels) +// image_size: Tuple (H, W) giving the size in pixels of the output +// image to be rasterized. 
// points_per_pixel: (K) The number closest of points to return for each pixel // // Returns: @@ -62,7 +63,7 @@ std::tuple RasterizePointsNaive( const torch::Tensor& points, const torch::Tensor& cloud_to_packed_first_idx, const torch::Tensor& num_points_per_cloud, - const int image_size, + const std::tuple image_size, const torch::Tensor& radius, const int points_per_pixel) { if (points.is_cuda() && cloud_to_packed_first_idx.is_cuda() && @@ -101,7 +102,7 @@ torch::Tensor RasterizePointsCoarseCpu( const torch::Tensor& points, const torch::Tensor& cloud_to_packed_first_idx, const torch::Tensor& num_points_per_cloud, - const int image_size, + const std::tuple image_size, const torch::Tensor& radius, const int bin_size, const int max_points_per_bin); @@ -111,7 +112,7 @@ torch::Tensor RasterizePointsCoarseCuda( const torch::Tensor& points, const torch::Tensor& cloud_to_packed_first_idx, const torch::Tensor& num_points_per_cloud, - const int image_size, + const std::tuple image_size, const torch::Tensor& radius, const int bin_size, const int max_points_per_bin); @@ -128,7 +129,8 @@ torch::Tensor RasterizePointsCoarseCuda( // for each pointcloud in the batch. // radius: FloatTensor of shape (P) giving the radius (in NDC units) of // each point in points. -// image_size: Size of the image to generate (in pixels) +// image_size: Tuple (H, W) giving the size in pixels of the output +// image to be rasterized. 
// bin_size: Size of each bin within the image (in pixels) // // Returns: @@ -140,7 +142,7 @@ torch::Tensor RasterizePointsCoarse( const torch::Tensor& points, const torch::Tensor& cloud_to_packed_first_idx, const torch::Tensor& num_points_per_cloud, - const int image_size, + const std::tuple image_size, const torch::Tensor& radius, const int bin_size, const int max_points_per_bin) { @@ -182,7 +184,7 @@ torch::Tensor RasterizePointsCoarse( std::tuple RasterizePointsFineCuda( const torch::Tensor& points, const torch::Tensor& bin_points, - const int image_size, + const std::tuple image_size, const torch::Tensor& radius, const int bin_size, const int points_per_pixel); @@ -194,7 +196,8 @@ std::tuple RasterizePointsFineCuda( // are expected to be in NDC coordinates in the range [-1, 1]. // bin_points: int32 Tensor of shape (N, B, B, M) giving the indices of points // that fall into each bin (output from coarse rasterization) -// image_size: Size of image to generate (in pixels) +// image_size: Tuple (H, W) giving the size in pixels of the output +// image to be rasterized. // radius: FloatTensor of shape (P) giving the radius (in NDC units) of // each point in points. // bin_size: Size of each bin (in pixels) @@ -214,7 +217,7 @@ std::tuple RasterizePointsFineCuda( std::tuple RasterizePointsFine( const torch::Tensor& points, const torch::Tensor& bin_points, - const int image_size, + const std::tuple image_size, const torch::Tensor& radius, const int bin_size, const int points_per_pixel) { @@ -303,7 +306,8 @@ torch::Tensor RasterizePointsBackward( // for each pointcloud in the batch. // radius: FloatTensor of shape (P) giving the radius (in NDC units) of // each point in points. -// image_size: (S) Size of the image to return (in pixels) +// image_size: Tuple (H, W) giving the size in pixels of the output +// image to be rasterized. // points_per_pixel: (K) The number of points to return for each pixel // bin_size: Bin size (in pixels) for coarse-to-fine rasterization. 
Setting // bin_size=0 uses naive rasterization instead. @@ -325,7 +329,7 @@ std::tuple RasterizePoints( const torch::Tensor& points, const torch::Tensor& cloud_to_packed_first_idx, const torch::Tensor& num_points_per_cloud, - const int image_size, + const std::tuple image_size, const torch::Tensor& radius, const int points_per_pixel, const int bin_size, diff --git a/pytorch3d/csrc/rasterize_points/rasterize_points_cpu.cpp b/pytorch3d/csrc/rasterize_points/rasterize_points_cpu.cpp index d7913f65..53cd6cba 100644 --- a/pytorch3d/csrc/rasterize_points/rasterize_points_cpu.cpp +++ b/pytorch3d/csrc/rasterize_points/rasterize_points_cpu.cpp @@ -3,33 +3,27 @@ #include #include #include - -// Given a pixel coordinate 0 <= i < S, convert it to a normalized device -// coordinate in the range [-1, 1]. The NDC range is divided into S evenly-sized -// pixels, and assume that each pixel falls in the *center* of its range. -static float PixToNdc(const int i, const int S) { - // NDC x-offset + (i * pixel_width + half_pixel_width) - return -1 + (2 * i + 1.0f) / S; -} +#include "rasterization_utils.h" std::tuple RasterizePointsNaiveCpu( const torch::Tensor& points, // (P, 3) const torch::Tensor& cloud_to_packed_first_idx, // (N) const torch::Tensor& num_points_per_cloud, // (N) - const int image_size, + const std::tuple image_size, const torch::Tensor& radius, const int points_per_pixel) { const int32_t N = cloud_to_packed_first_idx.size(0); // batch_size. - const int S = image_size; + const int H = std::get<0>(image_size); + const int W = std::get<1>(image_size); const int K = points_per_pixel; // Initialize output tensors. 
auto int_opts = num_points_per_cloud.options().dtype(torch::kInt32); auto float_opts = points.options().dtype(torch::kFloat32); - torch::Tensor point_idxs = torch::full({N, S, S, K}, -1, int_opts); - torch::Tensor zbuf = torch::full({N, S, S, K}, -1, float_opts); - torch::Tensor pix_dists = torch::full({N, S, S, K}, -1, float_opts); + torch::Tensor point_idxs = torch::full({N, H, W, K}, -1, int_opts); + torch::Tensor zbuf = torch::full({N, H, W, K}, -1, float_opts); + torch::Tensor pix_dists = torch::full({N, H, W, K}, -1, float_opts); auto points_a = points.accessor(); auto point_idxs_a = point_idxs.accessor(); @@ -46,16 +40,16 @@ std::tuple RasterizePointsNaiveCpu( const int point_stop_idx = (point_start_idx + num_points_per_cloud[n].item().to()); - for (int yi = 0; yi < S; ++yi) { + for (int yi = 0; yi < H; ++yi) { // Reverse the order of yi so that +Y is pointing upwards in the image. - const int yidx = S - 1 - yi; - const float yf = PixToNdc(yidx, S); + const int yidx = H - 1 - yi; + const float yf = PixToNonSquareNdc(yidx, H, W); - for (int xi = 0; xi < S; ++xi) { + for (int xi = 0; xi < W; ++xi) { // Reverse the order of xi so that +X is pointing to the left in the // image. - const int xidx = S - 1 - xi; - const float xf = PixToNdc(xidx, S); + const int xidx = W - 1 - xi; + const float xf = PixToNonSquareNdc(xidx, W, H); // Use a priority queue to hold (z, idx, r) std::priority_queue> q; @@ -99,25 +93,36 @@ torch::Tensor RasterizePointsCoarseCpu( const torch::Tensor& points, // (P, 3) const torch::Tensor& cloud_to_packed_first_idx, // (N) const torch::Tensor& num_points_per_cloud, // (N) - const int image_size, + const std::tuple image_size, const torch::Tensor& radius, const int bin_size, const int max_points_per_bin) { const int32_t N = cloud_to_packed_first_idx.size(0); // batch_size. 
- - const int B = 1 + (image_size - 1) / bin_size; // Integer division round up const int M = max_points_per_bin; + + const float H = std::get<0>(image_size); + const float W = std::get<1>(image_size); + + // Integer division round up. + const int BH = 1 + (H - 1) / bin_size; + const int BW = 1 + (W - 1) / bin_size; + auto opts = num_points_per_cloud.options().dtype(torch::kInt32); - torch::Tensor points_per_bin = torch::zeros({N, B, B}, opts); - torch::Tensor bin_points = torch::full({N, B, B, M}, -1, opts); + torch::Tensor points_per_bin = torch::zeros({N, BH, BW}, opts); + torch::Tensor bin_points = torch::full({N, BH, BW, M}, -1, opts); auto points_a = points.accessor(); auto points_per_bin_a = points_per_bin.accessor(); auto bin_points_a = bin_points.accessor(); auto radius_a = radius.accessor(); - const float pixel_width = 2.0f / image_size; - const float bin_width = pixel_width * bin_size; + const float ndc_x_range = NonSquareNdcRange(W, H); + const float pixel_width_x = ndc_x_range / W; + const float bin_width_x = pixel_width_x * bin_size; + + const float ndc_y_range = NonSquareNdcRange(H, W); + const float pixel_width_y = ndc_y_range / H; + const float bin_width_y = pixel_width_y * bin_size; for (int n = 0; n < N; ++n) { // Loop through each pointcloud in the batch. @@ -129,15 +134,15 @@ torch::Tensor RasterizePointsCoarseCpu( (point_start_idx + num_points_per_cloud[n].item().to()); float bin_y_min = -1.0f; - float bin_y_max = bin_y_min + bin_width; + float bin_y_max = bin_y_min + bin_width_y; // Iterate through the horizontal bins from top to bottom. - for (int by = 0; by < B; by++) { + for (int by = 0; by < BH; by++) { float bin_x_min = -1.0f; - float bin_x_max = bin_x_min + bin_width; + float bin_x_max = bin_x_min + bin_width_x; // Iterate through bins on this horizontal line, left to right. 
- for (int bx = 0; bx < B; bx++) { + for (int bx = 0; bx < BW; bx++) { int32_t points_hit = 0; for (int p = point_start_idx; p < point_stop_idx; ++p) { float px = points_a[p][0]; @@ -172,11 +177,11 @@ torch::Tensor RasterizePointsCoarseCpu( // Shift the bin to the right for the next loop iteration bin_x_min = bin_x_max; - bin_x_max = bin_x_min + bin_width; + bin_x_max = bin_x_min + bin_width_x; } // Shift the bin down for the next loop iteration bin_y_min = bin_y_max; - bin_y_max = bin_y_min + bin_width; + bin_y_max = bin_y_min + bin_width_y; } } return bin_points; @@ -194,11 +199,6 @@ torch::Tensor RasterizePointsBackwardCpu( const int W = idxs.size(2); const int K = idxs.size(3); - // For now only support square images. - // TODO(jcjohns): Extend to non-square images. - if (H != W) { - AT_ERROR("RasterizePointsBackwardCpu only supports square images"); - } torch::Tensor grad_points = torch::zeros({P, 3}, points.options()); auto points_a = points.accessor(); @@ -212,7 +212,7 @@ torch::Tensor RasterizePointsBackwardCpu( // Reverse the order of yi so that +Y is pointing upwards in the image. const int yidx = H - 1 - y; // Y coordinate of the top of the pixel. - const float yf = PixToNdc(yidx, H); + const float yf = PixToNonSquareNdc(yidx, H, W); // Iterate through pixels on this horizontal line, left to right. for (int x = 0; x < W; ++x) { // Loop over pixels in the row @@ -220,7 +220,7 @@ torch::Tensor RasterizePointsBackwardCpu( // Reverse the order of xi so that +X is pointing to the left in the // image. 
const int xidx = W - 1 - x; - const float xf = PixToNdc(xidx, W); + const float xf = PixToNonSquareNdc(xidx, W, H); for (int k = 0; k < K; ++k) { // Loop over points for the pixel const int p = idxs_a[n][y][x][k]; if (p < 0) { diff --git a/pytorch3d/renderer/mesh/__init__.py b/pytorch3d/renderer/mesh/__init__.py index a0a01086..d8b6b13b 100644 --- a/pytorch3d/renderer/mesh/__init__.py +++ b/pytorch3d/renderer/mesh/__init__.py @@ -6,6 +6,7 @@ from .rasterizer import MeshRasterizer, RasterizationSettings from .renderer import MeshRenderer from .shader import TexturedSoftPhongShader # DEPRECATED from .shader import ( + BlendParams, HardFlatShader, HardGouraudShader, HardPhongShader, diff --git a/pytorch3d/renderer/mesh/rasterize_meshes.py b/pytorch3d/renderer/mesh/rasterize_meshes.py index c0702631..d4cb24ab 100644 --- a/pytorch3d/renderer/mesh/rasterize_meshes.py +++ b/pytorch3d/renderer/mesh/rasterize_meshes.py @@ -1,7 +1,7 @@ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 
-from typing import Optional, Tuple, Union +from typing import List, Optional, Tuple, Union import numpy as np import torch @@ -20,7 +20,7 @@ kMaxFacesPerBin = 22 def rasterize_meshes( meshes, - image_size: Union[int, Tuple[int, int]] = 256, + image_size: Union[int, List[int], Tuple[int, int]] = 256, blur_radius: float = 0.0, faces_per_pixel: int = 8, bin_size: Optional[int] = None, @@ -219,7 +219,7 @@ class _RasterizeFaceVerts(torch.autograd.Function): face_verts, mesh_to_face_first_idx, num_faces_per_mesh, - image_size: Tuple[int, int] = (256, 256), + image_size: Union[List[int], Tuple[int, int]] = (256, 256), blur_radius: float = 0.01, faces_per_pixel: int = 0, bin_size: int = 0, @@ -287,11 +287,6 @@ class _RasterizeFaceVerts(torch.autograd.Function): return grads -def pix_to_ndc(i, S): - # NDC x-offset + (i * pixel_width + half_pixel_width) - return -1 + (2 * i + 1.0) / S - - def non_square_ndc_range(S1, S2): """ In the case of non square images, we scale the NDC range diff --git a/pytorch3d/renderer/points/compositor.py b/pytorch3d/renderer/points/compositor.py index 6f6c274c..650f0a3a 100644 --- a/pytorch3d/renderer/points/compositor.py +++ b/pytorch3d/renderer/points/compositor.py @@ -75,7 +75,7 @@ def _add_background_color_to_images(pix_idxs, images, background_color): pixels with accumulated features have unchanged values. """ # Initialize background mask - background_mask = pix_idxs[:, 0] < 0 # (N, image_size, image_size) + background_mask = pix_idxs[:, 0] < 0 # (N, H, W) # Convert background_color to an appropriate tensor and check shape if not torch.is_tensor(background_color): diff --git a/pytorch3d/renderer/points/rasterize_points.py b/pytorch3d/renderer/points/rasterize_points.py index 2942f708..80a0e99b 100644 --- a/pytorch3d/renderer/points/rasterize_points.py +++ b/pytorch3d/renderer/points/rasterize_points.py @@ -6,7 +6,7 @@ import torch # pyre-fixme[21]: Could not find name `_C` in `pytorch3d`. 
from pytorch3d import _C -from pytorch3d.renderer.mesh.rasterize_meshes import pix_to_ndc +from pytorch3d.renderer.mesh.rasterize_meshes import pix_to_non_square_ndc # Maxinum number of faces per bins for @@ -14,17 +14,30 @@ from pytorch3d.renderer.mesh.rasterize_meshes import pix_to_ndc kMaxPointsPerBin = 22 -# TODO(jcjohns): Support non-square images def rasterize_points( pointclouds, - image_size: int = 256, + image_size: Union[int, List[int], Tuple[int, int]] = 256, radius: Union[float, List, Tuple, torch.Tensor] = 0.01, points_per_pixel: int = 8, bin_size: Optional[int] = None, max_points_per_bin: Optional[int] = None, ): """ - Pointcloud rasterization + Each pointcloud is rasterized onto a separate image of shape + (H, W) if `image_size` is a tuple or (image_size, image_size) if it + is an int. + + If the desired image size is non square (i.e. a tuple of (H, W) where H != W) + the aspect ratio needs special consideration. There are two aspect ratios + to be aware of: + - the aspect ratio of each pixel + - the aspect ratio of the output image + The camera can be used to set the pixel aspect ratio. In the rasterizer, + we assume square pixels, but variable image aspect ratio (i.e rectangle images). + + In most cases you will want to set the camera aspect ratio to + 1.0 (i.e. square pixels) and only vary the + `image_size` (i.e. the output image dimensions in pix Args: pointclouds: A Pointclouds object representing a batch of point clouds to be @@ -34,7 +47,8 @@ def rasterize_points( be in normalized device coordinates (NDC): [-1, 1]^3 with the camera at (0, 0, 0); In the camera coordinate frame the x-axis goes from right-to-left, the y-axis goes from bottom-to-top, and the z-axis goes from back-to-front. - image_size: Integer giving the resolution of the rasterized image + image_size: Size in pixels of the output image to be rasterized. + Can optionally be a tuple of (H, W) in the case of non square images. 
radius (Optional): The radius (in NDC units) of the disk to be rasterized. This can either be a float in which case the same radius is used for each point, or a torch.Tensor of shape (N, P) giving a radius per point @@ -71,6 +85,9 @@ def rasterize_points( then `dists[n, y, x, k]` is the squared distance between the pixel (y, x) and the point `(points[n, p, 0], points[n, p, 1])`. Pixels hit with fewer than points_per_pixel are padded with -1. + + In the case that image_size is a tuple of (H, W) then the outputs + will be of shape `(N, H, W, ...)`. """ points_packed = pointclouds.points_packed() cloud_to_packed_first_idx = pointclouds.cloud_to_packed_first_idx() @@ -78,26 +95,46 @@ def rasterize_points( radius = _format_radius(radius, pointclouds) + # In the case that H != W use the max image size to set the bin_size + # to accommodate the num bins constraint in the coarse rasteizer. + # If the ratio of H:W is large this might cause issues as the smaller + # dimension will have fewer bins. + # TODO: consider a better way of setting the bin size. + if isinstance(image_size, (tuple, list)): + if len(image_size) != 2: + raise ValueError("Image size can only be a tuple/list of (H, W)") + if not all(i > 0 for i in image_size): + raise ValueError( + "Image sizes must be greater than 0; got %d, %d" % image_size + ) + if not all(type(i) == int for i in image_size): + raise ValueError("Image sizes must be integers; got %f, %f" % image_size) + max_image_size = max(*image_size) + im_size = image_size + else: + im_size = (image_size, image_size) + max_image_size = image_size + if bin_size is None: if not points_packed.is_cuda: # Binned CPU rasterization not fully implemented bin_size = 0 else: # TODO: These heuristics are not well-thought out! 
- if image_size <= 64: + if max_image_size <= 64: bin_size = 8 - elif image_size <= 256: + elif max_image_size <= 256: bin_size = 16 - elif image_size <= 512: + elif max_image_size <= 512: bin_size = 32 - elif image_size <= 1024: + elif max_image_size <= 1024: bin_size = 64 if bin_size != 0: # There is a limit on the number of points per bin in the cuda kernel. # pyre-fixme[58]: `//` is not supported for operand types `int` and # `Union[int, None, int]`. - points_per_bin = 1 + (image_size - 1) // bin_size + points_per_bin = 1 + (max_image_size - 1) // bin_size if points_per_bin >= kMaxPointsPerBin: raise ValueError( "bin_size too small, number of points per bin must be less than %d; got %d" @@ -114,7 +151,7 @@ def rasterize_points( points_packed, cloud_to_packed_first_idx, num_points_per_cloud, - image_size, + im_size, radius, points_per_pixel, bin_size, @@ -173,7 +210,7 @@ class _RasterizePoints(torch.autograd.Function): points, # (P, 3) cloud_to_packed_first_idx, num_points_per_cloud, - image_size: int = 256, + image_size: Union[List[int], Tuple[int, int]] = (256, 256), radius: Union[float, torch.Tensor] = 0.01, points_per_pixel: int = 8, bin_size: int = 0, @@ -225,7 +262,7 @@ class _RasterizePoints(torch.autograd.Function): def rasterize_points_python( pointclouds, - image_size: int = 256, + image_size: Union[int, Tuple[int, int]] = 256, radius: Union[float, torch.Tensor] = 0.01, points_per_pixel: int = 8, ): @@ -235,7 +272,12 @@ def rasterize_points_python( Inputs / Outputs: Same as above """ N = len(pointclouds) - S, K = image_size, points_per_pixel + H, W = ( + image_size + if isinstance(image_size, (tuple, list)) + else (image_size, image_size) + ) + K = points_per_pixel device = pointclouds.device points_packed = pointclouds.points_packed() @@ -247,11 +289,11 @@ def rasterize_points_python( # Intialize output tensors. 
point_idxs = torch.full( - (N, S, S, K), fill_value=-1, dtype=torch.int32, device=device + (N, H, W, K), fill_value=-1, dtype=torch.int32, device=device ) - zbuf = torch.full((N, S, S, K), fill_value=-1, dtype=torch.float32, device=device) + zbuf = torch.full((N, H, W, K), fill_value=-1, dtype=torch.float32, device=device) pix_dists = torch.full( - (N, S, S, K), fill_value=-1, dtype=torch.float32, device=device + (N, H, W, K), fill_value=-1, dtype=torch.float32, device=device ) # NDC is from [-1, 1]. Get pixel size using specified image size. @@ -263,18 +305,18 @@ def rasterize_points_python( point_stop_idx = point_start_idx + num_points_per_cloud[n] # Iterate through the horizontal lines of the image from top to bottom. - for yi in range(S): + for yi in range(H): # Y coordinate of one end of the image. Reverse the ordering # of yi so that +Y is pointing up in the image. - yfix = S - 1 - yi - yf = pix_to_ndc(yfix, S) + yfix = H - 1 - yi + yf = pix_to_non_square_ndc(yfix, H, W) # Iterate through pixels on this horizontal line, left to right. - for xi in range(S): + for xi in range(W): # X coordinate of one end of the image. Reverse the ordering # of xi so that +X is pointing to the left in the image. - xfix = S - 1 - xi - xf = pix_to_ndc(xfix, S) + xfix = W - 1 - xi + xf = pix_to_non_square_ndc(xfix, W, H) top_k_points = [] # Check whether each point in the batch affects this pixel. diff --git a/pytorch3d/renderer/points/rasterizer.py b/pytorch3d/renderer/points/rasterizer.py index 85e93e42..e8794f4c 100644 --- a/pytorch3d/renderer/points/rasterizer.py +++ b/pytorch3d/renderer/points/rasterizer.py @@ -2,7 +2,7 @@ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 
-from typing import NamedTuple, Optional, Union +from typing import NamedTuple, Optional, Tuple, Union import torch import torch.nn as nn @@ -29,7 +29,7 @@ class PointsRasterizationSettings: def __init__( self, - image_size: int = 256, + image_size: Union[int, Tuple[int, int]] = 256, radius: Union[float, torch.Tensor] = 0.01, points_per_pixel: int = 8, bin_size: Optional[int] = None, diff --git a/tests/bm_rasterize_points.py b/tests/bm_rasterize_points.py index d00e45ac..70e4a778 100644 --- a/tests/bm_rasterize_points.py +++ b/tests/bm_rasterize_points.py @@ -74,6 +74,21 @@ def bm_python_vs_cpu_vs_cuda() -> None: kwargs_list += [ {"N": 32, "P": 100000, "img_size": 128, "radius": 0.01, "pts_per_pxl": 50}, {"N": 8, "P": 200000, "img_size": 512, "radius": 0.01, "pts_per_pxl": 50}, + {"N": 8, "P": 200000, "img_size": 256, "radius": 0.01, "pts_per_pxl": 50}, + { + "N": 8, + "P": 200000, + "img_size": (512, 256), + "radius": 0.01, + "pts_per_pxl": 50, + }, + { + "N": 8, + "P": 200000, + "img_size": (256, 512), + "radius": 0.01, + "pts_per_pxl": 50, + }, ] for k in kwargs_list: k["device"] = "cuda" diff --git a/tests/data/test_pointcloud_rectangle_image.png b/tests/data/test_pointcloud_rectangle_image.png new file mode 100644 index 00000000..8bf30329 Binary files /dev/null and b/tests/data/test_pointcloud_rectangle_image.png differ diff --git a/tests/test_rasterize_points.py b/tests/test_rasterize_points.py index eef3b85e..dc59e9b2 100644 --- a/tests/test_rasterize_points.py +++ b/tests/test_rasterize_points.py @@ -404,7 +404,7 @@ class TestRasterizePoints(TestCaseMixin, unittest.TestCase): torch.manual_seed(231) N = 3 max_P = 1000 - image_size = 64 + image_size = (64, 64) radius = 0.1 bin_size = 16 max_points_per_bin = 500 @@ -501,7 +501,7 @@ class TestRasterizePoints(TestCaseMixin, unittest.TestCase): device=device ) # fmt: on - image_size = 16 + image_size = (16, 16) radius = 0.2 bin_size = 8 max_points_per_bin = 5 diff --git a/tests/test_rasterize_rectangles.py 
b/tests/test_rasterize_rectangle_images.py similarity index 51% rename from tests/test_rasterize_rectangles.py rename to tests/test_rasterize_rectangle_images.py index 72bb6fc5..98424938 100644 --- a/tests/test_rasterize_rectangles.py +++ b/tests/test_rasterize_rectangle_images.py @@ -12,19 +12,33 @@ from pytorch3d.io import load_obj from pytorch3d.renderer.cameras import FoVPerspectiveCameras, look_at_view_transform from pytorch3d.renderer.lighting import PointLights from pytorch3d.renderer.materials import Materials -from pytorch3d.renderer.mesh import TexturesUV +from pytorch3d.renderer.mesh import ( + BlendParams, + MeshRasterizer, + MeshRenderer, + RasterizationSettings, + SoftPhongShader, + TexturesUV, +) from pytorch3d.renderer.mesh.rasterize_meshes import ( rasterize_meshes, rasterize_meshes_python, ) -from pytorch3d.renderer.mesh.rasterizer import ( - Fragments, - MeshRasterizer, - RasterizationSettings, +from pytorch3d.renderer.mesh.rasterizer import Fragments +from pytorch3d.renderer.points import ( + AlphaCompositor, + PointsRasterizationSettings, + PointsRasterizer, + PointsRenderer, ) -from pytorch3d.renderer.mesh.renderer import MeshRenderer -from pytorch3d.renderer.mesh.shader import BlendParams, SoftPhongShader -from pytorch3d.structures import Meshes +from pytorch3d.renderer.points.rasterize_points import ( + rasterize_points, + rasterize_points_python, +) +from pytorch3d.renderer.points.rasterizer import PointFragments +from pytorch3d.structures import Meshes, Pointclouds +from pytorch3d.transforms.transform3d import Transform3d +from pytorch3d.utils import torus DEBUG = False @@ -44,9 +58,36 @@ verts0 = torch.tensor( ) faces0 = torch.tensor([[1, 0, 2], [4, 3, 5]], dtype=torch.int64) +# Points for a simple pointcloud. Get the vertices from a +# torus and apply rotations such that the points are no longer +# symmerical in X/Y. 
+torus_mesh = torus(r=0.25, R=1.0, sides=5, rings=2 * 5) +t = ( + Transform3d() + .rotate_axis_angle(angle=90, axis="Y") + .rotate_axis_angle(angle=45, axis="Z") + .scale(0.3) +) +torus_points = t.transform_points(torus_mesh.verts_padded()).squeeze() -class TestRasterizeRectanglesErrors(TestCaseMixin, unittest.TestCase): - def test_image_size_arg(self): + +def _save_debug_image(idx, image_size, bin_size, blur): + """ + Save a mask image from the rasterization output for debugging. + """ + H, W = image_size + # Save out the last image for debugging + rgb = (idx[-1, ..., :3].cpu() > -1).squeeze() + suffix = "square" if H == W else "non_square" + filename = "%s_bin_size_%s_blur_%.3f_%dx%d.png" + filename = filename % (suffix, str(bin_size), blur, H, W) + if DEBUG: + filename = "DEBUG_%s" % filename + Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(DATA_DIR / filename) + + +class TestRasterizeRectangleImagesErrors(TestCaseMixin, unittest.TestCase): + def test_mesh_image_size_arg(self): meshes = Meshes(verts=[verts0], faces=[faces0]) with self.assertRaises(ValueError) as cm: @@ -76,8 +117,38 @@ class TestRasterizeRectanglesErrors(TestCaseMixin, unittest.TestCase): ) self.assertTrue("sizes must be integers" in cm.msg) + def test_points_image_size_arg(self): + points = Pointclouds([verts0]) -class TestRasterizeRectangles(TestCaseMixin, unittest.TestCase): + with self.assertRaises(ValueError) as cm: + rasterize_points( + points, + (100, 200, 3), + 0.0001, + points_per_pixel=1, + ) + self.assertTrue("tuple/list of (H, W)" in cm.msg) + + with self.assertRaises(ValueError) as cm: + rasterize_points( + points, + (0, 10), + 0.0001, + points_per_pixel=1, + ) + self.assertTrue("sizes must be positive" in cm.msg) + + with self.assertRaises(ValueError) as cm: + rasterize_points( + points, + (100.5, 120.5), + 0.0001, + points_per_pixel=1, + ) + self.assertTrue("sizes must be integers" in cm.msg) + + +class TestRasterizeRectangleImagesMeshes(TestCaseMixin, 
unittest.TestCase): @staticmethod def _clone_mesh(verts0, faces0, device, batch_size): """ @@ -164,7 +235,7 @@ class TestRasterizeRectangles(TestCaseMixin, unittest.TestCase): meshes_sq, image_size=(S, S), bin_size=0, blur=blur ) # Save debug image - self._save_debug_image(square_fragments, (S, S), 0, blur) + _save_debug_image(square_fragments.pix_to_face, (S, S), 0, blur) # Extract the values in the square image which are non zero. square_mask = square_fragments.pix_to_face > -1 @@ -284,8 +355,8 @@ class TestRasterizeRectangles(TestCaseMixin, unittest.TestCase): ) # Save out debug images if needed - self._save_debug_image(fragments_naive, image_size, 0, blur) - self._save_debug_image(fragments_binned, image_size, None, blur) + _save_debug_image(fragments_naive.pix_to_face, image_size, 0, blur) + _save_debug_image(fragments_binned.pix_to_face, image_size, None, blur) # Check naive and binned fragments give the same outputs self._check_fragments(fragments_naive, fragments_binned) @@ -354,8 +425,8 @@ class TestRasterizeRectangles(TestCaseMixin, unittest.TestCase): ) # Save debug images if DEBUG is set to true at the top of the file. - self._save_debug_image(fragments_naive, image_size, 0, blur) - self._save_debug_image(fragments_python, image_size, "python", blur) + _save_debug_image(fragments_naive.pix_to_face, image_size, 0, blur) + _save_debug_image(fragments_python.pix_to_face, image_size, "python", blur) # List of non square outputs to compare with the square output nonsq_fragment_gradtensor_list = [ @@ -437,3 +508,293 @@ class TestRasterizeRectangles(TestCaseMixin, unittest.TestCase): # NOTE some pixels can be flaky cond1 = torch.allclose(rgb, image_ref, atol=0.05) self.assertTrue(cond1) + + +class TestRasterizeRectangleImagesPointclouds(TestCaseMixin, unittest.TestCase): + @staticmethod + def _clone_pointcloud(verts0, device, batch_size): + """ + Helper function to detach and clone the verts. 
+        This is needed in order to set up the tensors for
+        gradient computation in different tests.
+        """
+        verts = verts0.detach().clone()
+        verts.requires_grad = True
+        pointclouds = Pointclouds(points=[verts])
+        pointclouds = pointclouds.to(device).extend(batch_size)
+        return verts, pointclouds
+
+    def _rasterize(self, meshes, image_size, bin_size, blur):
+        """
+        Thin wrapper around `rasterize_points` which packs its three
+        outputs into a `PointFragments`.
+
+        Args:
+            meshes: the object to rasterize. Despite the name, the callers
+                visible in these tests pass the point clouds built by
+                `_clone_pointcloud`.
+            image_size: forwarded unchanged; the tests pass (H, W) tuples.
+            bin_size: forwarded unchanged; the tests use 0 for the naive
+                rasterizer and None for the coarse-to-fine one.
+            blur: forwarded as the third positional argument of
+                `rasterize_points` (presumably the point radius - confirm
+                against that function's signature).
+
+        Returns:
+            PointFragments with fields idx, zbuf and dists, computed with
+            points_per_pixel=1.
+        """
+        idxs, zbuf, dists = rasterize_points(
+            meshes,
+            image_size,
+            blur,
+            points_per_pixel=1,
+            bin_size=bin_size,
+        )
+        return PointFragments(
+            idx=idxs,
+            zbuf=zbuf,
+            dists=dists,
+        )
+
+    def _check_fragments(self, frag_1, frag_2):
+        """
+        Assert that the idx, dists and zbuf tensors of the
+        fragments frag_1 and frag_2 are close to each other.
+        """
+        self.assertClose(frag_1.idx, frag_2.idx)
+        self.assertClose(frag_1.dists, frag_2.dists)
+        self.assertClose(frag_1.zbuf, frag_2.zbuf)
+
+    def _compare_square_with_nonsq(
+        self,
+        image_size,
+        blur,
+        device,
+        points,
+        nonsq_fragment_gradtensor_list,
+        batch_size=1,
+    ):
+        """
+        Calculate the output from rasterizing a square image with the minimum of (H, W).
+        Then compare this with the same square region in the non square image.
+        The input points are contained within the [-1, 1] range of the image
+        so all the relevant pixels will be within the square region.
+
+        Args:
+            image_size: (H, W) of the non square image.
+            blur: value passed as `blur` to `self._rasterize`.
+            device: device on which the square reference is rasterized.
+            points: points tensor used to build the square reference.
+            nonsq_fragment_gradtensor_list: list of
+                (fragments, verts_grad_tensor, name) tuples obtained from
+                rasterizing non square images. `name` only labels the
+                sub-test so that a failure identifies the rasterizer.
+            batch_size: forwarded to `self._clone_pointcloud`.
+        """
+        # Rasterize the square version of the image
+        H, W = image_size
+        S = min(H, W)
+        points_square, pointclouds_sq = self._clone_pointcloud(
+            points, device, batch_size
+        )
+        square_fragments = self._rasterize(
+            pointclouds_sq, image_size=(S, S), bin_size=0, blur=blur
+        )
+        # Save debug image
+        _save_debug_image(square_fragments.idx, (S, S), 0, blur)
+
+        # Extract the values of the square image at the pixels
+        # which hold a valid point index (idx > -1).
+        square_mask = square_fragments.idx > -1
+        square_dists = square_fragments.dists[square_mask]
+        square_zbuf = square_fragments.zbuf[square_mask]
+
+        # Retain gradients on the output of fragments to check
+        # intermediate values with the non square outputs.
+        square_fragments.dists.retain_grad()
+        square_fragments.zbuf.retain_grad()
+
+        # Calculate gradient for the square image. The same random
+        # upstream gradients are reused for every non square output below.
+        torch.manual_seed(231)
+        grad_zbuf = torch.randn_like(square_zbuf)
+        grad_dist = torch.randn_like(square_dists)
+        loss0 = (grad_dist * square_dists).sum() + (grad_zbuf * square_zbuf).sum()
+        loss0.backward()
+
+        # Now compare against the non square outputs provided
+        # in the nonsq_fragment_gradtensor_list list
+        for fragments, grad_tensor, name in nonsq_fragment_gradtensor_list:
+            with self.subTest(rasterizer=name):
+                # Check that there are the same number of valid pixels
+                # (idx > -1) in both the square and non square images.
+                non_square_mask = fragments.idx > -1
+                self.assertEqual(
+                    non_square_mask.sum().item(), square_mask.sum().item()
+                )
+
+                # Check dists and zbuf match the square image
+                non_square_dists = fragments.dists[non_square_mask]
+                non_square_zbuf = fragments.zbuf[non_square_mask]
+                self.assertClose(square_dists, non_square_dists)
+                self.assertClose(square_zbuf, non_square_zbuf)
+
+                # Retain gradients to compare values with outputs from
+                # square image
+                fragments.dists.retain_grad()
+                fragments.zbuf.retain_grad()
+                loss1 = (grad_dist * non_square_dists).sum() + (
+                    grad_zbuf * non_square_zbuf
+                ).sum()
+                # loss1 is already a scalar, no further reduction is needed.
+                loss1.backward()
+
+                # Get the values of the intermediate gradients at the valid
+                # pixels and compare with the values from the square image
+                non_square_grad_dists = fragments.dists.grad[non_square_mask]
+                non_square_grad_zbuf = fragments.zbuf.grad[non_square_mask]
+
+                self.assertClose(
+                    non_square_grad_dists,
+                    square_fragments.dists.grad[square_mask],
+                )
+                self.assertClose(
+                    non_square_grad_zbuf,
+                    square_fragments.zbuf.grad[square_mask],
+                )
+
+                # Finally check the gradients of the input vertices for
+                # the square and non square case
+                self.assertClose(points_square.grad, grad_tensor.grad, rtol=2e-4)
+
+    def test_gpu(self):
+        """
+        Test that the output of rendering non square images
+        gives the same result as square images. i.e. the
+        dists, zbuf, idx are all the same for the square
+        region which is present in both images.
+        """
+        # Test both cases: (W > H), (H > W)
+        image_sizes = [(64, 128), (128, 64), (128, 256), (256, 128)]
+
+        devices = ["cuda:0"]
+        blurs = [5e-2]
+        batch_sizes = [1, 4]
+        test_cases = product(image_sizes, blurs, devices, batch_sizes)
+
+        for image_size, blur, device, batch_size in test_cases:
+            # Run every combination as its own sub-test so that a failure
+            # reports the image size and batch size which triggered it.
+            with self.subTest(image_size=image_size, batch_size=batch_size):
+                # Initialize the verts grad tensor and the pointcloud objects
+                verts_nonsq_naive, pointcloud_nonsq_naive = self._clone_pointcloud(
+                    torus_points, device, batch_size
+                )
+                verts_nonsq_binned, pointcloud_nonsq_binned = self._clone_pointcloud(
+                    torus_points, device, batch_size
+                )
+
+                # Get the outputs for both naive and coarse to fine rasterization
+                fragments_naive = self._rasterize(
+                    pointcloud_nonsq_naive,
+                    image_size,
+                    blur=blur,
+                    bin_size=0,
+                )
+                fragments_binned = self._rasterize(
+                    pointcloud_nonsq_binned,
+                    image_size,
+                    blur=blur,
+                    bin_size=None,
+                )
+
+                # Save out debug images if needed
+                _save_debug_image(fragments_naive.idx, image_size, 0, blur)
+                _save_debug_image(fragments_binned.idx, image_size, None, blur)
+
+                # Check naive and binned fragments give the same outputs
+                self._check_fragments(fragments_naive, fragments_binned)
+
+                # Here we want to compare the square image with the naive and the
+                # coarse to fine methods outputs
+                nonsq_fragment_gradtensor_list = [
+                    (fragments_naive, verts_nonsq_naive, "naive"),
+                    (fragments_binned, verts_nonsq_binned, "coarse-to-fine"),
+                ]
+
+                self._compare_square_with_nonsq(
+                    image_size,
+                    blur,
+                    device,
+                    torus_points,
+                    nonsq_fragment_gradtensor_list,
+                    batch_size,
+                )
+
+    def test_cpu(self):
+        """
+        Test that the output of rendering non square images
+        gives the same result as square images. i.e. the
+        dists, zbuf, idx are all the same for the square
+        region which is present in both images.
+
+        In this test we compare between the naive C++ implementation
+        and the naive python implementation as the Coarse/Fine
+        method is not fully implemented in C++
+        """
+        # Test both when (W > H) and (H > W).
+        # Using smaller image sizes here as the Python rasterizer is really slow.
+        image_sizes = [(32, 64), (64, 32)]
+        devices = ["cpu"]
+        blurs = [5e-2]
+        batch_sizes = [1]
+        test_cases = product(image_sizes, blurs, devices, batch_sizes)
+
+        for image_size, blur, device, batch_size in test_cases:
+            # Run every combination as its own sub-test so that a failure
+            # reports the image size and batch size which triggered it.
+            with self.subTest(image_size=image_size, batch_size=batch_size):
+                # Initialize the verts grad tensor and the pointcloud objects
+                verts_nonsq_naive, pointcloud_nonsq_naive = self._clone_pointcloud(
+                    torus_points, device, batch_size
+                )
+                verts_nonsq_python, pointcloud_nonsq_python = self._clone_pointcloud(
+                    torus_points, device, batch_size
+                )
+
+                # Compare Naive CPU with Python as Coarse/Fine rasterization
+                # is not implemented for CPU
+                fragments_naive = self._rasterize(
+                    pointcloud_nonsq_naive, image_size, bin_size=0, blur=blur
+                )
+                idxs, zbuf, pix_dists = rasterize_points_python(
+                    pointcloud_nonsq_python,
+                    image_size,
+                    blur,
+                    points_per_pixel=1,
+                )
+                fragments_python = PointFragments(
+                    idx=idxs,
+                    zbuf=zbuf,
+                    dists=pix_dists,
+                )
+
+                # Save debug images if DEBUG is set to true at the top of the file.
+                _save_debug_image(fragments_naive.idx, image_size, 0, blur)
+                _save_debug_image(fragments_python.idx, image_size, "python", blur)
+
+                # List of non square outputs to compare with the square output
+                nonsq_fragment_gradtensor_list = [
+                    (fragments_naive, verts_nonsq_naive, "naive"),
+                    (fragments_python, verts_nonsq_python, "python"),
+                ]
+                self._compare_square_with_nonsq(
+                    image_size,
+                    blur,
+                    device,
+                    torus_points,
+                    nonsq_fragment_gradtensor_list,
+                    batch_size,
+                )
+
+    def test_render_pointcloud(self):
+        """
+        Test a textured pointcloud is rendered correctly in a non square image.
+        The image is rendered on "cuda:0" with both the naive (bin_size=0) and
+        the coarse to fine (bin_size=None) rasterizer and each result is
+        compared against the same reference image.
+        """
+        device = torch.device("cuda:0")
+        pointclouds = Pointclouds(
+            points=[torus_points * 2.0],
+            features=torch.ones_like(torus_points[None, ...]),
+        ).to(device)
+        R, T = look_at_view_transform(2.7, 0.0, 0.0)
+        cameras = FoVPerspectiveCameras(device=device, R=R, T=T)
+        raster_settings = PointsRasterizationSettings(
+            image_size=(512, 1024), radius=5e-2, points_per_pixel=1
+        )
+        rasterizer = PointsRasterizer(cameras=cameras, raster_settings=raster_settings)
+        compositor = AlphaCompositor()
+        renderer = PointsRenderer(rasterizer=rasterizer, compositor=compositor)
+
+        # Load reference image
+        image_ref = load_rgb_image("test_pointcloud_rectangle_image.png", DATA_DIR)
+
+        for bin_size in [0, None]:
+            with self.subTest(bin_size=bin_size):
+                # Check both naive and coarse to fine produce the same output.
+                renderer.rasterizer.raster_settings.bin_size = bin_size
+                images = renderer(pointclouds)
+                rgb = images[0, ..., :3].squeeze().cpu()
+
+                if DEBUG:
+                    # Use one file per bin_size, otherwise the second
+                    # iteration overwrites the image saved by the first.
+                    debug_name = (
+                        "DEBUG_pointcloud_rectangle_image_bin_size_%s.png" % bin_size
+                    )
+                    Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
+                        DATA_DIR / debug_name
+                    )
+
+                # NOTE some pixels can be flaky
+                self.assertTrue(torch.allclose(rgb, image_ref, atol=0.05))