diff --git a/pytorch3d/csrc/rasterize_points/rasterize_points.cu b/pytorch3d/csrc/rasterize_points/rasterize_points.cu index 8872f370..d02a5680 100644 --- a/pytorch3d/csrc/rasterize_points/rasterize_points.cu +++ b/pytorch3d/csrc/rasterize_points/rasterize_points.cu @@ -38,13 +38,15 @@ __device__ void CheckPixelInsidePoint( float& q_max_z, int& q_max_idx, PointQ& q, - const float radius2, + const float* radius, const float xf, const float yf, const int K) { const float px = points[p_idx * 3 + 0]; const float py = points[p_idx * 3 + 1]; const float pz = points[p_idx * 3 + 2]; + const float p_radius = radius[p_idx]; + const float radius2 = p_radius * p_radius; if (pz < 0) return; // Don't render points behind the camera const float dx = xf - px; @@ -81,7 +83,7 @@ __global__ void RasterizePointsNaiveCudaKernel( const float* points, // (P, 3) const int64_t* cloud_to_packed_first_idx, // (N) const int64_t* num_points_per_cloud, // (N) - const float radius, + const float* radius, const int N, const int S, const int K, @@ -91,7 +93,6 @@ __global__ void RasterizePointsNaiveCudaKernel( // Simple version: One thread per output pixel const int num_threads = gridDim.x * blockDim.x; const int tid = blockDim.x * blockIdx.x + threadIdx.x; - const float radius2 = radius * radius; for (int i = tid; i < N * S * S; i += num_threads) { // Convert linear index to 3D index const int n = i / (S * S); // Batch index @@ -128,7 +129,7 @@ __global__ void RasterizePointsNaiveCudaKernel( for (int p_idx = point_start_idx; p_idx < point_stop_idx; ++p_idx) { CheckPixelInsidePoint( - points, p_idx, q_size, q_max_z, q_max_idx, q, radius2, xf, yf, K); + points, p_idx, q_size, q_max_z, q_max_idx, q, radius, xf, yf, K); } BubbleSort(q, q_size); int idx = n * S * S * K + pix_idx * K; @@ -145,7 +146,7 @@ std::tuple RasterizePointsNaiveCuda( const at::Tensor& cloud_to_packed_first_idx, // (N) const at::Tensor& num_points_per_cloud, // (N) const int image_size, - const float radius, + const at::Tensor& radius, const int points_per_pixel) { // Check inputs are on the same device at::TensorArg points_t{points, "points", 1}, @@ -194,7 +195,7 @@ std::tuple RasterizePointsNaiveCuda( points.contiguous().data_ptr(), cloud_to_packed_first_idx.contiguous().data_ptr(), num_points_per_cloud.contiguous().data_ptr(), - radius, + radius.contiguous().data_ptr(), N, S, K, @@ -214,7 +215,7 @@ __global__ void RasterizePointsCoarseCudaKernel( const float* points, // (P, 3) const int64_t* cloud_to_packed_first_idx, // (N) const int64_t* num_points_per_cloud, // (N) - const float radius, + const float* radius, const int N, const int P, const int S, @@ -266,12 +267,13 @@ __global__ void RasterizePointsCoarseCudaKernel( const float px = points[p_idx * 3 + 0]; const float py = points[p_idx * 3 + 1]; const float pz = points[p_idx * 3 + 2]; + const float p_radius = radius[p_idx]; if (pz < 0) continue; // Don't render points behind the camera. - const float px0 = px - radius; - const float px1 = px + radius; - const float py0 = py - radius; - const float py1 = py + radius; + const float px0 = px - p_radius; + const float px1 = px + p_radius; + const float py0 = py - p_radius; + const float py1 = py + p_radius; // Brute-force search over all bins; TODO something smarter? // For example we could compute the exact bin where the point falls, @@ -341,7 +343,7 @@ at::Tensor RasterizePointsCoarseCuda( const at::Tensor& cloud_to_packed_first_idx, // (N) const at::Tensor& num_points_per_cloud, // (N) const int image_size, - const float radius, + const at::Tensor& radius, const int bin_size, const int max_points_per_bin) { TORCH_CHECK( @@ -390,7 +392,7 @@ at::Tensor RasterizePointsCoarseCuda( points.contiguous().data_ptr(), cloud_to_packed_first_idx.contiguous().data_ptr(), num_points_per_cloud.contiguous().data_ptr(), - radius, + radius.contiguous().data_ptr(), N, P, image_size, @@ -411,7 +413,7 @@ at::Tensor RasterizePointsCoarseCuda( __global__ void RasterizePointsFineCudaKernel( const float* points, // (P, 3) const int32_t* bin_points, // (N, B, B, T) - const float radius, + const float* radius, const int bin_size, const int N, const int B, // num_bins @@ -425,7 +427,6 @@ __global__ void RasterizePointsFineCudaKernel( const int num_pixels = N * B * B * bin_size * bin_size; const int num_threads = gridDim.x * blockDim.x; const int tid = blockIdx.x * blockDim.x + threadIdx.x; - const float radius2 = radius * radius; for (int pid = tid; pid < num_pixels; pid += num_threads) { // Convert linear index into bin and pixel indices. We make the within @@ -464,7 +465,7 @@ __global__ void RasterizePointsFineCudaKernel( continue; } CheckPixelInsidePoint( - points, p, q_size, q_max_z, q_max_idx, q, radius2, xf, yf, K); + points, p, q_size, q_max_z, q_max_idx, q, radius, xf, yf, K); } // Now we've looked at all the points for this bin, so we can write // output for the current pixel. @@ -488,7 +489,7 @@ std::tuple RasterizePointsFineCuda( const at::Tensor& points, // (P, 3) const at::Tensor& bin_points, const int image_size, - const float radius, + const at::Tensor& radius, const int bin_size, const int points_per_pixel) { // Check inputs are on the same device @@ -525,7 +526,7 @@ std::tuple RasterizePointsFineCuda( RasterizePointsFineCudaKernel<<>>( points.contiguous().data_ptr(), bin_points.contiguous().data_ptr(), - radius, + radius.contiguous().data_ptr(), bin_size, N, B, diff --git a/pytorch3d/csrc/rasterize_points/rasterize_points.h b/pytorch3d/csrc/rasterize_points/rasterize_points.h index 6f557e05..f1ec1aaf 100644 --- a/pytorch3d/csrc/rasterize_points/rasterize_points.h +++ b/pytorch3d/csrc/rasterize_points/rasterize_points.h @@ -15,7 +15,7 @@ std::tuple RasterizePointsNaiveCpu( const torch::Tensor& cloud_to_packed_first_idx, const torch::Tensor& num_points_per_cloud, const int image_size, - const float radius, + const torch::Tensor& radius, const int points_per_pixel); #ifdef WITH_CUDA @@ -25,7 +25,7 @@ RasterizePointsNaiveCuda( const torch::Tensor& cloud_to_packed_first_idx, const torch::Tensor& num_points_per_cloud, const int image_size, - const float radius, + const torch::Tensor& radius, const int points_per_pixel); #endif // Naive (forward) pointcloud rasterization: For each pixel, for each point, @@ -41,7 +41,8 @@ RasterizePointsNaiveCuda( // in the batch where N is the batch size. // num_points_per_cloud: LongTensor of shape (N) giving the number of points // for each pointcloud in the batch. -// radius: Radius of each point (in NDC units) +// radius: FloatTensor of shape (P) giving the radius (in NDC units) of +// each point in points. // image_size: (S) Size of the image to return (in pixels) // points_per_pixel: (K) The number closest of points to return for each pixel // @@ -62,7 +63,7 @@ std::tuple RasterizePointsNaive( const torch::Tensor& cloud_to_packed_first_idx, const torch::Tensor& num_points_per_cloud, const int image_size, - const float radius, + const torch::Tensor& radius, const int points_per_pixel) { if (points.is_cuda() && cloud_to_packed_first_idx.is_cuda() && num_points_per_cloud.is_cuda()) { @@ -70,6 +71,7 @@ std::tuple RasterizePointsNaive( CHECK_CUDA(points); CHECK_CUDA(cloud_to_packed_first_idx); CHECK_CUDA(num_points_per_cloud); + CHECK_CUDA(radius); return RasterizePointsNaiveCuda( points, cloud_to_packed_first_idx, @@ -100,7 +102,7 @@ torch::Tensor RasterizePointsCoarseCpu( const torch::Tensor& cloud_to_packed_first_idx, const torch::Tensor& num_points_per_cloud, const int image_size, - const float radius, + const torch::Tensor& radius, const int bin_size, const int max_points_per_bin); @@ -110,7 +112,7 @@ torch::Tensor RasterizePointsCoarseCuda( const torch::Tensor& cloud_to_packed_first_idx, const torch::Tensor& num_points_per_cloud, const int image_size, - const float radius, + const torch::Tensor& radius, const int bin_size, const int max_points_per_bin); #endif @@ -124,7 +126,8 @@ torch::Tensor RasterizePointsCoarseCuda( // in the batch where N is the batch size. // num_points_per_cloud: LongTensor of shape (N) giving the number of points // for each pointcloud in the batch. -// radius: Radius of points to rasterize (in NDC units) +// radius: FloatTensor of shape (P) giving the radius (in NDC units) of +// each point in points. // image_size: Size of the image to generate (in pixels) // bin_size: Size of each bin within the image (in pixels) // @@ -138,7 +141,7 @@ torch::Tensor RasterizePointsCoarse( const torch::Tensor& cloud_to_packed_first_idx, const torch::Tensor& num_points_per_cloud, const int image_size, - const float radius, + const torch::Tensor& radius, const int bin_size, const int max_points_per_bin) { if (points.is_cuda() && cloud_to_packed_first_idx.is_cuda() && @@ -147,6 +150,7 @@ torch::Tensor RasterizePointsCoarse( CHECK_CUDA(points); CHECK_CUDA(cloud_to_packed_first_idx); CHECK_CUDA(num_points_per_cloud); + CHECK_CUDA(radius); return RasterizePointsCoarseCuda( points, cloud_to_packed_first_idx, @@ -179,7 +183,7 @@ std::tuple RasterizePointsFineCuda( const torch::Tensor& points, const torch::Tensor& bin_points, const int image_size, - const float radius, + const torch::Tensor& radius, const int bin_size, const int points_per_pixel); #endif @@ -191,7 +195,8 @@ std::tuple RasterizePointsFineCuda( // bin_points: int32 Tensor of shape (N, B, B, M) giving the indices of points // that fall into each bin (output from coarse rasterization) // image_size: Size of image to generate (in pixels) -// radius: Radius of points to rasterize (NDC units) +// radius: FloatTensor of shape (P) giving the radius (in NDC units) of +// each point in points. // bin_size: Size of each bin (in pixels) // points_per_pixel: How many points to rasterize for each pixel // @@ -210,7 +215,7 @@ std::tuple RasterizePointsFine( const torch::Tensor& points, const torch::Tensor& bin_points, const int image_size, - const float radius, + const torch::Tensor& radius, const int bin_size, const int points_per_pixel) { if (points.is_cuda()) { @@ -296,7 +301,8 @@ torch::Tensor RasterizePointsBackward( // in the batch where N is the batch size. // num_points_per_cloud: LongTensor of shape (N) giving the number of points // for each pointcloud in the batch. -// radius: Radius of each point (in NDC units) +// radius: FloatTensor of shape (P) giving the radius (in NDC units) of +// each point in points. // image_size: (S) Size of the image to return (in pixels) // points_per_pixel: (K) The number of points to return for each pixel // bin_size: Bin size (in pixels) for coarse-to-fine rasterization. Setting @@ -320,7 +326,7 @@ std::tuple RasterizePoints( const torch::Tensor& cloud_to_packed_first_idx, const torch::Tensor& num_points_per_cloud, const int image_size, - const float radius, + const torch::Tensor& radius, const int points_per_pixel, const int bin_size, const int max_points_per_bin) { diff --git a/pytorch3d/csrc/rasterize_points/rasterize_points_cpu.cpp b/pytorch3d/csrc/rasterize_points/rasterize_points_cpu.cpp index ab687680..d7913f65 100644 --- a/pytorch3d/csrc/rasterize_points/rasterize_points_cpu.cpp +++ b/pytorch3d/csrc/rasterize_points/rasterize_points_cpu.cpp @@ -17,7 +17,7 @@ std::tuple RasterizePointsNaiveCpu( const torch::Tensor& cloud_to_packed_first_idx, // (N) const torch::Tensor& num_points_per_cloud, // (N) const int image_size, - const float radius, + const torch::Tensor& radius, const int points_per_pixel) { const int32_t N = cloud_to_packed_first_idx.size(0); // batch_size. @@ -35,8 +35,8 @@ std::tuple RasterizePointsNaiveCpu( auto point_idxs_a = point_idxs.accessor(); auto zbuf_a = zbuf.accessor(); auto pix_dists_a = pix_dists.accessor(); + auto radius_a = radius.accessor(); - const float radius2 = radius * radius; for (int n = 0; n < N; ++n) { // Loop through each pointcloud in the batch. // Get the start index of the points in points_packed and the num points @@ -63,6 +63,8 @@ std::tuple RasterizePointsNaiveCpu( const float px = points_a[p][0]; const float py = points_a[p][1]; const float pz = points_a[p][2]; + const float p_radius = radius_a[p]; + const float radius2 = p_radius * p_radius; if (pz < 0) { continue; } @@ -98,7 +100,7 @@ torch::Tensor RasterizePointsCoarseCpu( const torch::Tensor& cloud_to_packed_first_idx, // (N) const torch::Tensor& num_points_per_cloud, // (N) const int image_size, - const float radius, + const torch::Tensor& radius, const int bin_size, const int max_points_per_bin) { const int32_t N = cloud_to_packed_first_idx.size(0); // batch_size. @@ -112,6 +114,7 @@ torch::Tensor RasterizePointsCoarseCpu( auto points_a = points.accessor(); auto points_per_bin_a = points_per_bin.accessor(); auto bin_points_a = bin_points.accessor(); + auto radius_a = radius.accessor(); const float pixel_width = 2.0f / image_size; const float bin_width = pixel_width * bin_size; @@ -140,13 +143,14 @@ torch::Tensor RasterizePointsCoarseCpu( float px = points_a[p][0]; float py = points_a[p][1]; float pz = points_a[p][2]; + const float p_radius = radius_a[p]; if (pz < 0) { continue; } - float point_x_min = px - radius; - float point_x_max = px + radius; - float point_y_min = py - radius; - float point_y_max = py + radius; + float point_x_min = px - p_radius; + float point_x_max = px + p_radius; + float point_y_min = py - p_radius; + float point_y_max = py + p_radius; // Use a half-open interval so that points exactly on the // boundary between bins will fall into exactly one bin. diff --git a/pytorch3d/renderer/points/rasterize_points.py b/pytorch3d/renderer/points/rasterize_points.py index 400ae3c9..fc3cff19 100644 --- a/pytorch3d/renderer/points/rasterize_points.py +++ b/pytorch3d/renderer/points/rasterize_points.py @@ -1,6 +1,6 @@ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -from typing import Optional +from typing import List, Optional, Tuple, Union import torch @@ -18,7 +18,7 @@ kMaxPointsPerBin = 22 def rasterize_points( pointclouds, image_size: int = 256, - radius: float = 0.01, + radius: Union[float, List, Tuple, torch.Tensor] = 0.01, points_per_pixel: int = 8, bin_size: Optional[int] = None, max_points_per_bin: Optional[int] = None, @@ -35,8 +35,10 @@ def rasterize_points( (0, 0, 0); In the camera coordinate frame the x-axis goes from right-to-left, the y-axis goes from bottom-to-top, and the z-axis goes from back-to-front. image_size: Integer giving the resolution of the rasterized image - radius (Optional): Float giving the radius (in NDC units) of the disk to - be rasterized for each point. + radius (Optional): The radius (in NDC units) of the disk to + be rasterized. This can either be a float in which case the same radius is used + for each point, or a torch.Tensor of shape (N, P) giving a radius per point + in the batch. points_per_pixel (Optional): We will keep track of this many points per pixel, returning the nearest points_per_pixel points along the z-axis bin_size: Size of bins to use for coarse-to-fine rasterization. Setting @@ -74,6 +76,8 @@ def rasterize_points( cloud_to_packed_first_idx = pointclouds.cloud_to_packed_first_idx() num_points_per_cloud = pointclouds.num_points_per_cloud() + radius = _format_radius(radius, pointclouds) + if bin_size is None: if not points_packed.is_cuda: # Binned CPU rasterization not fully implemented @@ -117,6 +121,48 @@ def rasterize_points( ) +def _format_radius( + radius: Union[float, List, Tuple, torch.Tensor], pointclouds +) -> torch.Tensor: + """ + Format the radius as a torch tensor of shape (P_packed,) + where P_packed is the total number of points in the + batch (i.e. pointclouds.points_packed().shape[0]). + + This will enable support for a different size radius + for each point in the batch. + + Args: + radius: can be a float, List, Tuple or tensor of + shape (N, P_padded) where P_padded is the + maximum number of points for each pointcloud + in the batch. + + Returns: + radius: torch.Tensor of shape (P_packed) + """ + N, P_padded = pointclouds._N, pointclouds._P + points_packed = pointclouds.points_packed() + P_packed = points_packed.shape[0] + if isinstance(radius, (list, tuple)): + radius = torch.tensor(radius).type_as(points_packed) + if isinstance(radius, torch.Tensor): + if N == 1 and radius.ndim == 1: + radius = radius[None, ...] + if radius.shape != (N, P_padded): + msg = "radius must be of shape (N, P): got %s" + raise ValueError(msg % (repr(radius.shape))) + else: + padded_to_packed_idx = pointclouds.padded_to_packed_idx() + radius = radius.view(-1)[padded_to_packed_idx] + elif isinstance(radius, float): + radius = torch.full((P_packed,), fill_value=radius).type_as(points_packed) + else: + msg = "radius must be a float, list, tuple or tensor; got %s" + raise ValueError(msg % type(radius)) + return radius + + class _RasterizePoints(torch.autograd.Function): @staticmethod def forward( @@ -125,7 +171,7 @@ class _RasterizePoints(torch.autograd.Function): cloud_to_packed_first_idx, num_points_per_cloud, image_size: int = 256, - radius: float = 0.01, + radius: Union[float, torch.Tensor] = 0.01, points_per_pixel: int = 8, bin_size: int = 0, max_points_per_bin: int = 0, @@ -175,7 +221,10 @@ class _RasterizePoints(torch.autograd.Function): def rasterize_points_python( - pointclouds, image_size: int = 256, radius: float = 0.01, points_per_pixel: int = 8 + pointclouds, + image_size: int = 256, + radius: Union[float, torch.Tensor] = 0.01, + points_per_pixel: int = 8, ): """ Naive pure PyTorch implementation of pointcloud rasterization. @@ -190,6 +239,9 @@ def rasterize_points_python( cloud_to_packed_first_idx = pointclouds.cloud_to_packed_first_idx() num_points_per_cloud = pointclouds.num_points_per_cloud() + # Support variable size radius for each point in the batch + radius = _format_radius(radius, pointclouds) + # Intialize output tensors. point_idxs = torch.full( (N, S, S, K), fill_value=-1, dtype=torch.int32, device=device @@ -225,12 +277,13 @@ def rasterize_points_python( # Check whether each point in the batch affects this pixel. for p in range(point_start_idx, point_stop_idx): px, py, pz = points_packed[p, :] + r = radius2[p] if pz < 0: continue dx = px - xf dy = py - yf dist2 = dx * dx + dy * dy - if dist2 < radius2: + if dist2 < r: top_k_points.append((pz, p, dist2)) top_k_points.sort() if len(top_k_points) > K: diff --git a/pytorch3d/renderer/points/rasterizer.py b/pytorch3d/renderer/points/rasterizer.py index 8b6d8f1f..7a7ebcda 100644 --- a/pytorch3d/renderer/points/rasterizer.py +++ b/pytorch3d/renderer/points/rasterizer.py @@ -2,7 +2,7 @@ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -from typing import NamedTuple, Optional +from typing import NamedTuple, Optional, Union import torch import torch.nn as nn @@ -30,7 +30,7 @@ class PointsRasterizationSettings: def __init__( self, image_size: int = 256, - radius: float = 0.01, + radius: Union[float, torch.Tensor] = 0.01, points_per_pixel: int = 8, bin_size: Optional[int] = None, max_points_per_bin: Optional[int] = None, diff --git a/tests/bm_cameras_alignment.py b/tests/bm_cameras_alignment.py index 6c65e3d2..4d0c0397 100644 --- a/tests/bm_cameras_alignment.py +++ b/tests/bm_cameras_alignment.py @@ -1,6 +1,7 @@ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. import itertools + from fvcore.common.benchmark import benchmark from test_cameras_alignment import TestCamerasAlignment diff --git a/tests/bm_rasterize_points.py b/tests/bm_rasterize_points.py index b281fe1e..0deb3ef2 100644 --- a/tests/bm_rasterize_points.py +++ b/tests/bm_rasterize_points.py @@ -1,5 +1,6 @@ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +from itertools import product import torch from fvcore.common.benchmark import benchmark @@ -18,44 +19,64 @@ def _bm_python_with_init(N, P, img_size=32, radius=0.1, pts_per_pxl=3): return lambda: rasterize_points_python(*args) -def _bm_cpu_with_init(N, P, img_size=32, radius=0.1, pts_per_pxl=3): +def _bm_rasterize_points_with_init( + N, P, img_size=32, radius=0.1, pts_per_pxl=3, device="cpu", expand_radius=False +): torch.manual_seed(231) - points = torch.randn(N, P, 3) - pointclouds = Pointclouds(points=points) - args = (pointclouds, img_size, radius, pts_per_pxl) - return lambda: rasterize_points(*args) - - -def _bm_cuda_with_init(N, P, img_size=32, radius=0.1, pts_per_pxl=3): - torch.manual_seed(231) - device = torch.device("cuda:0") + device = torch.device(device) points = torch.randn(N, P, 3, device=device) pointclouds = Pointclouds(points=points) + + if expand_radius: + points_padded = pointclouds.points_padded() + radius = torch.full((N, P), fill_value=radius).type_as(points_padded) + args = (pointclouds, img_size, radius, pts_per_pxl) - torch.cuda.synchronize(device) + if device == "cuda": + torch.cuda.synchronize(device) def fn(): rasterize_points(*args) - torch.cuda.synchronize(device) + if device == "cuda": + torch.cuda.synchronize(device) return fn -def bm_python_vs_cpu() -> None: - kwargs_list = [ - {"N": 1, "P": 32, "img_size": 32, "radius": 0.1, "pts_per_pxl": 3}, - {"N": 2, "P": 32, "img_size": 32, "radius": 0.1, "pts_per_pxl": 3}, - ] - benchmark(_bm_python_with_init, "RASTERIZE_PYTHON", kwargs_list, warmup_iters=1) - benchmark(_bm_cpu_with_init, "RASTERIZE_CPU", kwargs_list, warmup_iters=1) - kwargs_list = [ - {"N": 2, "P": 32, "img_size": 32, "radius": 0.1, "pts_per_pxl": 3}, - {"N": 4, "P": 1024, "img_size": 128, "radius": 0.05, "pts_per_pxl": 5}, - ] - benchmark(_bm_cpu_with_init, "RASTERIZE_CPU", kwargs_list, warmup_iters=1) +def bm_python_vs_cpu_vs_cuda() -> None: + kwargs_list = [] + num_meshes = [1] + num_points = [10000, 2000] + image_size = [128, 256] + radius = [1e-3, 0.01] + pts_per_pxl = [50, 100] + expand = [True, False] + test_cases = product( + num_meshes, num_points, image_size, radius, pts_per_pxl, expand + ) + for case in test_cases: + n, p, im, r, pts, e = case + kwargs_list.append( + { + "N": n, + "P": p, + "img_size": im, + "radius": r, + "pts_per_pxl": pts, + "device": "cpu", + "expand_radius": e, + } + ) + + benchmark( + _bm_rasterize_points_with_init, "RASTERIZE_CPU", kwargs_list, warmup_iters=1 + ) kwargs_list += [ - {"N": 32, "P": 10000, "img_size": 128, "radius": 0.01, "pts_per_pxl": 50}, {"N": 32, "P": 100000, "img_size": 128, "radius": 0.01, "pts_per_pxl": 50}, {"N": 8, "P": 200000, "img_size": 512, "radius": 0.01, "pts_per_pxl": 50}, ] - benchmark(_bm_cuda_with_init, "RASTERIZE_CUDA", kwargs_list, warmup_iters=1) + for k in kwargs_list: + k["device"] = "cuda" + benchmark( + _bm_rasterize_points_with_init, "RASTERIZE_CUDA", kwargs_list, warmup_iters=1 + ) diff --git a/tests/test_rasterize_points.py b/tests/test_rasterize_points.py index a7591c52..eef3b85e 100644 --- a/tests/test_rasterize_points.py +++ b/tests/test_rasterize_points.py @@ -8,6 +8,7 @@ import torch from common_testing import TestCaseMixin, get_random_cuda_device from pytorch3d import _C from pytorch3d.renderer.points.rasterize_points import ( + _format_radius, rasterize_points, rasterize_points_python, ) @@ -40,6 +41,21 @@ class TestRasterizePoints(TestCaseMixin, unittest.TestCase): device = get_random_cuda_device() self._test_behind_camera(rasterize_points, device, bin_size=0) + def test_python_variable_radius(self): + self._test_variable_size_radius( + rasterize_points_python, torch.device("cpu"), bin_size=-1 + ) + + def test_cpu_variable_radius(self): + self._test_variable_size_radius(rasterize_points, torch.device("cpu")) + + def test_cuda_variable_radius(self): + device = get_random_cuda_device() + # Naive + self._test_variable_size_radius(rasterize_points, device, bin_size=0) + # Coarse to fine + self._test_variable_size_radius(rasterize_points, device, bin_size=None) + def test_cpp_vs_naive_vs_binned(self): # Make sure that the backward pass runs for all pathways N = 2 @@ -403,6 +419,8 @@ class TestRasterizePoints(TestCaseMixin, unittest.TestCase): points_packed = pointclouds.points_packed() cloud_to_packed_first_idx = pointclouds.cloud_to_packed_first_idx() num_points_per_cloud = pointclouds.num_points_per_cloud() + + radius = torch.full((points_packed.shape[0],), fill_value=radius) args = ( points_packed, cloud_to_packed_first_idx, @@ -419,6 +437,7 @@ class TestRasterizePoints(TestCaseMixin, unittest.TestCase): points_packed = pointclouds_cuda.points_packed() cloud_to_packed_first_idx = pointclouds_cuda.cloud_to_packed_first_idx() num_points_per_cloud = pointclouds_cuda.num_points_per_cloud() + radius = radius.to(device) args = ( points_packed, cloud_to_packed_first_idx, @@ -499,6 +518,7 @@ class TestRasterizePoints(TestCaseMixin, unittest.TestCase): bin_points_expected[0, 1, 1, :2] = torch.tensor([0, 1]) pointclouds = Pointclouds(points=[points]) + radius = torch.full((points.shape[0],), fill_value=radius, device=device) args = ( pointclouds.points_packed(), pointclouds.cloud_to_packed_first_idx(), @@ -512,3 +532,115 @@ class TestRasterizePoints(TestCaseMixin, unittest.TestCase): bin_points_same = (bin_points == bin_points_expected).all() self.assertTrue(bin_points_same.item() == 1) + + def _test_variable_size_radius(self, rasterize_points_fn, device, bin_size=0): + # Two points + points = torch.tensor( + [[0.5, 0.5, 0.3], [0.5, -0.5, -0.1], [0.0, 0.0, 0.3]], + dtype=torch.float32, + device=device, + ) + image_size = 16 + points_per_pixel = 1 + radius = torch.tensor([0.1, 0.0, 0.2], dtype=torch.float32, device=device) + pointclouds = Pointclouds(points=[points]) + if bin_size == -1: + # simple python case with no binning + idx, zbuf, dists = rasterize_points_fn( + pointclouds, image_size, radius, points_per_pixel + ) + else: + idx, zbuf, dists = rasterize_points_fn( + pointclouds, image_size, radius, points_per_pixel, bin_size + ) + + idx_expected = torch.zeros( + (1, image_size, image_size, 1), dtype=torch.int64, device=device + ) + # fmt: off + idx_expected[0, ..., 0] = torch.tensor( + [ + [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241 E201 + [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241 E201 + [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241 E201 + [-1, -1, -1, 0, 0, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241 E201 + [-1, -1, -1, 0, 0, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241 E201 + [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241 E201 + [-1, -1, -1, -1, -1, -1, -1, 2, 2, -1, -1, -1, -1, -1, -1, -1], # noqa: E241 E201 + [-1, -1, -1, -1, -1, -1, 2, 2, 2, 2, -1, -1, -1, -1, -1, -1], # noqa: E241 E201 + [-1, -1, -1, -1, -1, -1, 2, 2, 2, 2, -1, -1, -1, -1, -1, -1], # noqa: E241 E201 + [-1, -1, -1, -1, -1, -1, -1, 2, 2, -1, -1, -1, -1, -1, -1, -1], # noqa: E241 E201 + [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241 E201 + [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241 E201 + [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241 E201 + [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241 E201 + [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241 E201 + [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1] # noqa: E241 E201 + ], + dtype=torch.int64, + device=device + ) + # fmt: on + zbuf_expected = torch.full( + idx_expected.shape, fill_value=-1, dtype=torch.float32, device=device + ) + zbuf_expected[idx_expected == 0] = 0.3 + zbuf_expected[idx_expected == 2] = 0.3 + + dists_expected = torch.full( + idx_expected.shape, fill_value=-1, dtype=torch.float32, device=device + ) + + # fmt: off + dists_expected[0, ..., 0] = torch.Tensor( + [ + [-1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.], # noqa: E241 E201 + [-1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.], # noqa: E241 E201 + [-1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.], # noqa: E241 E201 + [-1., -1., -1., 0.0078, 0.0078, -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.], # noqa: E241 E201 + [-1., -1., -1., 0.0078, 0.0078, -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.], # noqa: E241 E201 + [-1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.], # noqa: E241 E201 + [-1., -1., -1., -1., -1., -1., -1., 0.0391, 0.0391, -1., -1., -1., -1., -1., -1., -1.], # noqa: E241 E201 + [-1., -1., -1., -1., -1., -1., 0.0391, 0.0078, 0.0078, 0.0391, -1., -1., -1., -1., -1., -1.], # noqa: E241 E201 + [-1., -1., -1., -1., -1., -1., 0.0391, 0.0078, 0.0078, 0.0391, -1., -1., -1., -1., -1., -1.], # noqa: E241 E201 + [-1., -1., -1., -1., -1., -1., -1., 0.0391, 0.0391, -1., -1., -1., -1., -1., -1., -1.], # noqa: E241 E201 + [-1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.], # noqa: E241 E201 + [-1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.], # noqa: E241 E201 + [-1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.], # noqa: E241 E201 + [-1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.], # noqa: E241 E201 + [-1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.], # noqa: E241 E201 + [-1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.] # noqa: E241 E201 + ] + ) + # fmt: on + + # Check the distances for a point are less than the squared radius + # for that point. + self.assertTrue((dists[idx == 0] < radius[0] ** 2).all()) + self.assertTrue((dists[idx == 2] < radius[2] ** 2).all()) + + # Check all values are correct. + idx_same = (idx == idx_expected).all().item() == 1 + zbuf_same = (zbuf == zbuf_expected).all().item() == 1 + + self.assertTrue(idx_same) + self.assertTrue(zbuf_same) + self.assertClose(dists, dists_expected, atol=4e-5) + + def test_radius_format_failure(self): + N = 20 + P_max = 15 + points_list = [] + for _ in range(N): + p = torch.randint(low=1, high=P_max, size=(1,))[0] + points_list.append(torch.randn((p, 3))) + + points = Pointclouds(points=points_list) + + # Incorrect shape + with self.assertRaisesRegex(ValueError, "radius must be of shape"): + _format_radius([0, 1, 2], points) + + # Incorrect type + with self.assertRaisesRegex(ValueError, "float, list, tuple or tensor"): + _format_radius({0: [0, 1, 2]}, points)