From 103da63393d6bbb697835ddbfc86b07572ea4d0c Mon Sep 17 00:00:00 2001
From: Nikhila Ravi
Date: Thu, 12 Aug 2021 14:05:23 -0700
Subject: [PATCH] Ball Query

Summary:
Implementation of ball query from PointNet++. This function is similar to KNN (find the neighbors in p2 for all points in p1). These are the key differences:
- It returns the **first** K neighbors within a specified radius, as opposed to the **closest** K neighbors.
- Since not all points in p2 need to be considered (the search stops once K neighbors are found), the algorithm is much faster than KNN when p2 has a large number of points.
- The neighbors are not sorted.
- Due to the radius threshold, it is not guaranteed that there will be K neighbors, even if there are more than K points in p2.
- The padding value for `idx` is -1 instead of 0.

# Note:
- Some of the code is very similar to KNN, so it would be possible to modify the KNN forward kernels to support ball query.
- Some users might want to combine kNN with ball query. As an alternative, we could provide a wrapper around the current `knn_points` which applies the radius threshold afterwards; this could be called `ball_query_knn`.

Reviewed By: jcjohnson

Differential Revision: D30261362

fbshipit-source-id: 66b6a7e0114beff7164daf7eba21546ff41ec450
---
 pytorch3d/csrc/ball_query/ball_query.cu      | 130 +++++++++++
 pytorch3d/csrc/ball_query/ball_query.h       |  91 ++++++++
 pytorch3d/csrc/ball_query/ball_query_cpu.cpp |  55 +++++
 pytorch3d/csrc/ext.cpp                       |   4 +
 pytorch3d/csrc/knn/knn.cu                    |   4 +
 pytorch3d/csrc/knn/knn_cpu.cpp               |   4 +
 pytorch3d/ops/__init__.py                    |   2 +-
 pytorch3d/ops/ball_query.py                  | 150 ++++++++++++
 tests/bm_ball_query.py                       |  40 ++++
 tests/test_ball_query.py                     | 230 +++++++++++++++++++
 10 files changed, 709 insertions(+), 1 deletion(-)
 create mode 100644 pytorch3d/csrc/ball_query/ball_query.cu
 create mode 100644 pytorch3d/csrc/ball_query/ball_query.h
 create mode 100644 pytorch3d/csrc/ball_query/ball_query_cpu.cpp
 create mode 100644 pytorch3d/ops/ball_query.py
 create mode 100644 tests/bm_ball_query.py
 create mode 100644 tests/test_ball_query.py

diff --git a/pytorch3d/csrc/ball_query/ball_query.cu b/pytorch3d/csrc/ball_query/ball_query.cu
new file mode 100644
index 00000000..bababec5
--- /dev/null
+++ b/pytorch3d/csrc/ball_query/ball_query.cu
@@ -0,0 +1,130 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <ATen/ATen.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
+#include <float.h>
+#include <iostream>
+#include <tuple>
+#include "utils/pytorch3d_cutils.h"
+
+// A chunk of work is blocksize-many points of P1.
+// The number of potential chunks to do is N*(1+(P1-1)/blocksize)
+// call (1+(P1-1)/blocksize) chunks_per_cloud
+// These chunks are divided among the gridSize-many blocks.
+// In block b, we work on chunks b, b+gridSize, b+2*gridSize etc.
+// In chunk i, we work on cloud i/chunks_per_cloud on points starting from
+// blocksize*(i%chunks_per_cloud).
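+//
+// Worked example (illustrative numbers): with blocksize = 256 and P1 = 1000,
+// chunks_per_cloud = 1 + (1000 - 1) / 256 = 4, so chunk i = 5 belongs to
+// cloud 5 / 4 = 1 and covers its points 256 * (5 % 4) = 256 through 511
+// (one point per thread in the block).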
+
+template <typename scalar_t>
+__global__ void BallQueryKernel(
+    const at::PackedTensorAccessor64<scalar_t, 3, at::RestrictPtrTraits> p1,
+    const at::PackedTensorAccessor64<scalar_t, 3, at::RestrictPtrTraits> p2,
+    const at::PackedTensorAccessor64<int64_t, 1, at::RestrictPtrTraits>
+        lengths1,
+    const at::PackedTensorAccessor64<int64_t, 1, at::RestrictPtrTraits>
+        lengths2,
+    at::PackedTensorAccessor64<int64_t, 3, at::RestrictPtrTraits> idxs,
+    at::PackedTensorAccessor64<scalar_t, 3, at::RestrictPtrTraits> dists,
+    const int64_t K,
+    const float radius2) {
+  const int64_t N = p1.size(0);
+  const int64_t chunks_per_cloud = (1 + (p1.size(1) - 1) / blockDim.x);
+  const int64_t chunks_to_do = N * chunks_per_cloud;
+  const int D = p1.size(2);
+
+  for (int64_t chunk = blockIdx.x; chunk < chunks_to_do; chunk += gridDim.x) {
+    const int64_t n = chunk / chunks_per_cloud; // batch_index
+    const int64_t start_point = blockDim.x * (chunk % chunks_per_cloud);
+    int64_t i = start_point + threadIdx.x;
+
+    // Check if point is valid in heterogeneous tensor
+    if (i >= lengths1[n]) {
+      continue;
+    }
+
+    // Iterate over points in p2 until desired count is reached or
+    // all points have been considered
+    for (int64_t j = 0, count = 0; j < lengths2[n] && count < K; ++j) {
+      // Calculate the distance between the points
+      scalar_t dist2 = 0.0;
+      for (int d = 0; d < D; ++d) {
+        scalar_t diff = p1[n][i][d] - p2[n][j][d];
+        dist2 += (diff * diff);
+      }
+
+      if (dist2 < radius2) {
+        // If the point is within the radius
+        // Set the value of the index to the point index
+        idxs[n][i][count] = j;
+        dists[n][i][count] = dist2;
+
+        // increment the number of selected samples for the point i
+        ++count;
+      }
+    }
+  }
+}
+
+std::tuple<at::Tensor, at::Tensor> BallQueryCuda(
+    const at::Tensor& p1, // (N, P1, 3)
+    const at::Tensor& p2, // (N, P2, 3)
+    const at::Tensor& lengths1, // (N,)
+    const at::Tensor& lengths2, // (N,)
+    int K,
+    float radius) {
+  // Check inputs are on the same device
+  at::TensorArg p1_t{p1, "p1", 1}, p2_t{p2, "p2", 2},
+      lengths1_t{lengths1, "lengths1", 3}, lengths2_t{lengths2, "lengths2", 4};
+  at::CheckedFrom c = "BallQueryCuda";
+  at::checkAllSameGPU(c, {p1_t, p2_t, lengths1_t, lengths2_t});
+  at::checkAllSameType(c, {p1_t, p2_t});
+
+  // Set the device for the kernel launch based on the device of p1
+  at::cuda::CUDAGuard device_guard(p1.device());
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+  TORCH_CHECK(
+      p2.size(2) == p1.size(2), "Point sets must have the same last dimension");
+
+  const int N = p1.size(0);
+  const int P1 = p1.size(1);
+  const int64_t K_64 = K;
+  const float radius2 = radius * radius;
+
+  // Output tensor with indices of neighbors for each point in p1
+  auto long_dtype = lengths1.options().dtype(at::kLong);
+  auto idxs = at::full({N, P1, K}, -1, long_dtype);
+  auto dists = at::zeros({N, P1, K}, p1.options());
+
+  if (idxs.numel() == 0) {
+    AT_CUDA_CHECK(cudaGetLastError());
+    return std::make_tuple(idxs, dists);
+  }
+
+  const size_t blocks = 256;
+  const size_t threads = 256;
+
+  AT_DISPATCH_FLOATING_TYPES(
+      p1.scalar_type(), "ball_query_kernel_cuda", ([&] {
+        BallQueryKernel<scalar_t><<<blocks, threads, 0, stream>>>(
+            p1.packed_accessor64<scalar_t, 3, at::RestrictPtrTraits>(),
+            p2.packed_accessor64<scalar_t, 3, at::RestrictPtrTraits>(),
+            lengths1.packed_accessor64<int64_t, 1, at::RestrictPtrTraits>(),
+            lengths2.packed_accessor64<int64_t, 1, at::RestrictPtrTraits>(),
+            idxs.packed_accessor64<int64_t, 3, at::RestrictPtrTraits>(),
+            dists.packed_accessor64<scalar_t, 3, at::RestrictPtrTraits>(),
+            K_64,
+            radius2);
+      }));
+
+  AT_CUDA_CHECK(cudaGetLastError());
+
+  return std::make_tuple(idxs, dists);
+}
diff --git a/pytorch3d/csrc/ball_query/ball_query.h b/pytorch3d/csrc/ball_query/ball_query.h
new file mode 100644
index 00000000..c8a1cd76
--- /dev/null
+++ b/pytorch3d/csrc/ball_query/ball_query.h
@@ -0,0 +1,91 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+#include <torch/extension.h>
+#include <tuple>
+#include "utils/pytorch3d_cutils.h"
+
+// Compute indices of K neighbors in pointcloud p2 to points
+// in pointcloud p1 which fall within a specified radius
+//
+// Args:
+//    p1: FloatTensor of shape (N, P1, D) giving a batch of pointclouds each
+//        containing P1 points of dimension D.
+//    p2: FloatTensor of shape (N, P2, D) giving a batch of pointclouds each
+//        containing P2 points of dimension D.
+//    lengths1: LongTensor, shape (N,), giving actual length of each P1 cloud.
+//    lengths2: LongTensor, shape (N,), giving actual length of each P2 cloud.
+//    K: Integer giving the upper bound on the number of samples to take
+//        within the radius
+//    radius: the radius around each point within which the neighbors need to
+//        be located
+//
+// Returns:
+//    p1_neighbor_idx: LongTensor of shape (N, P1, K), where
+//        p1_neighbor_idx[n, i, k] = j means that the kth
+//        neighbor to p1[n, i] in the cloud p2[n] is p2[n, j].
+//        This is padded with -1s both where a cloud in p2 has fewer than K
+//        points and where a cloud in p1 has fewer than P1 points and
+//        also if there are fewer than K points which satisfy the radius
+//        threshold.
+//
+//    p1_neighbor_dists: FloatTensor of shape (N, P1, K) containing the squared
+//        distance from each point p1[n, p, :] to its K neighbors
+//        p2[n, p1_neighbor_idx[n, p, k], :].
+
+// CPU implementation
+std::tuple<at::Tensor, at::Tensor> BallQueryCpu(
+    const at::Tensor& p1,
+    const at::Tensor& p2,
+    const at::Tensor& lengths1,
+    const at::Tensor& lengths2,
+    const int K,
+    const float radius);
+
+// CUDA implementation
+std::tuple<at::Tensor, at::Tensor> BallQueryCuda(
+    const at::Tensor& p1,
+    const at::Tensor& p2,
+    const at::Tensor& lengths1,
+    const at::Tensor& lengths2,
+    const int K,
+    const float radius);
+
+// Implementation which is exposed
+// Note: the backward pass reuses the KNearestNeighborBackward kernel
+inline std::tuple<at::Tensor, at::Tensor> BallQuery(
+    const at::Tensor& p1,
+    const at::Tensor& p2,
+    const at::Tensor& lengths1,
+    const at::Tensor& lengths2,
+    int K,
+    float radius) {
+  if (p1.is_cuda() || p2.is_cuda()) {
+#ifdef WITH_CUDA
+    CHECK_CUDA(p1);
+    CHECK_CUDA(p2);
+    return BallQueryCuda(
+        p1.contiguous(),
+        p2.contiguous(),
+        lengths1.contiguous(),
+        lengths2.contiguous(),
+        K,
+        radius);
+#else
+    AT_ERROR("Not compiled with GPU support.");
+#endif
+  }
+  return BallQueryCpu(
+      p1.contiguous(),
+      p2.contiguous(),
+      lengths1.contiguous(),
+      lengths2.contiguous(),
+      K,
+      radius);
+}
diff --git a/pytorch3d/csrc/ball_query/ball_query_cpu.cpp b/pytorch3d/csrc/ball_query/ball_query_cpu.cpp
new file mode 100644
index 00000000..f7c59e0b
--- /dev/null
+++ b/pytorch3d/csrc/ball_query/ball_query_cpu.cpp
@@ -0,0 +1,55 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <torch/extension.h>
+#include <queue>
+#include <tuple>
+
+std::tuple<at::Tensor, at::Tensor> BallQueryCpu(
+    const at::Tensor& p1,
+    const at::Tensor& p2,
+    const at::Tensor& lengths1,
+    const at::Tensor& lengths2,
+    int K,
+    float radius) {
+  const int N = p1.size(0);
+  const int P1 = p1.size(1);
+  const int D = p1.size(2);
+
+  auto long_opts = lengths1.options().dtype(torch::kInt64);
+  torch::Tensor idxs = torch::full({N, P1, K}, -1, long_opts);
+  torch::Tensor dists = torch::full({N, P1, K}, 0, p1.options());
+  const float radius2 = radius * radius;
+
+  auto p1_a = p1.accessor<float, 3>();
+  auto p2_a = p2.accessor<float, 3>();
+  auto lengths1_a = lengths1.accessor<int64_t, 1>();
+  auto lengths2_a = lengths2.accessor<int64_t, 1>();
+  auto idxs_a = idxs.accessor<int64_t, 3>();
+  auto dists_a = dists.accessor<float, 3>();
+
+  for (int n = 0; n < N; ++n) {
+    const int64_t length1 = lengths1_a[n];
+    const int64_t length2 = lengths2_a[n];
+    for (int64_t i = 0; i < length1; ++i) {
+      for (int64_t j = 0, count = 0; j < length2 && count < K; ++j) {
+        float dist2 = 0;
+        for (int d = 0; d < D; ++d) {
+          float diff = p1_a[n][i][d] - p2_a[n][j][d];
+          dist2 += diff * diff;
+        }
+        if (dist2 < radius2) {
+          dists_a[n][i][count] = dist2;
+          idxs_a[n][i][count] = j;
+          ++count;
+        }
+      }
+    }
+  }
+  return std::make_tuple(idxs, dists);
+}
diff --git a/pytorch3d/csrc/ext.cpp b/pytorch3d/csrc/ext.cpp
index e3022a16..8a5a3d54 100644
--- a/pytorch3d/csrc/ext.cpp
+++ b/pytorch3d/csrc/ext.cpp
@@ -12,6 +12,7 @@
 // clang-format on
 #include "./pulsar/pytorch/renderer.h"
 #include "./pulsar/pytorch/tensor_util.h"
+#include "ball_query/ball_query.h"
 #include "blending/sigmoid_alpha_blend.h"
 #include "compositing/alpha_composite.h"
 #include "compositing/norm_weighted_sum.h"
@@ -38,6 +39,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
 #endif
   m.def("knn_points_idx", &KNearestNeighborIdx);
   m.def("knn_points_backward", &KNearestNeighborBackward);
+
+  // Ball Query
+  m.def("ball_query", &BallQuery);
   m.def(
       "mesh_normal_consistency_find_verts",
       &MeshNormalConsistencyFindVertices);
   m.def("gather_scatter", &GatherScatter);
diff --git a/pytorch3d/csrc/knn/knn.cu b/pytorch3d/csrc/knn/knn.cu
index 48eaff54..eaa25203 100644
--- a/pytorch3d/csrc/knn/knn.cu
+++ b/pytorch3d/csrc/knn/knn.cu
@@ -477,6 +477,10 @@ __global__ void KNearestNeighborBackwardKernel(
       const float grad_dist = grad_dists[n * P1 * K + p1_idx * K + k];
       // index of point in p2 corresponding to the k-th nearest neighbor
       const size_t p2_idx = idxs[n * P1 * K + p1_idx * K + k];
+      // If the index is the pad value of -1 then ignore it
+      if (p2_idx == -1) {
+        continue;
+      }
       const float diff = 2.0 * grad_dist *
           (p1[n * P1 * D + p1_idx * D + d] - p2[n * P2 * D + p2_idx * D + d]);
       atomicAdd(grad_p1 + n * P1 * D + p1_idx * D + d, diff);
diff --git a/pytorch3d/csrc/knn/knn_cpu.cpp b/pytorch3d/csrc/knn/knn_cpu.cpp
index 0f8b9967..b7a5cb6c 100644
--- a/pytorch3d/csrc/knn/knn_cpu.cpp
+++ b/pytorch3d/csrc/knn/knn_cpu.cpp
@@ -99,6 +99,10 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborBackwardCpu(
     for (int64_t i1 = 0; i1 < length1; ++i1) {
       for (int64_t k = 0; k < length2; ++k) {
         const int64_t i2 = idxs_a[n][i1][k];
+        // If the index is the pad value of -1 then ignore it
+        if (i2 == -1) {
+          continue;
+        }
         for (int64_t d = 0; d < D; ++d) {
           const float diff = 2.0f * grad_dists_a[n][i1][k] *
               (p1_a[n][i1][d] - p2_a[n][i2][d]);
diff --git a/pytorch3d/ops/__init__.py b/pytorch3d/ops/__init__.py
index 22ab32c8..9428578e 100644
--- a/pytorch3d/ops/__init__.py
+++ b/pytorch3d/ops/__init__.py
@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+from .ball_query import ball_query
 from .cameras_alignment import corresponding_cameras_alignment
 from .cubify import cubify
 from .graph_conv import GraphConv
@@ -34,5 +35,4 @@ from .utils import (
 )
 from .vert_align import vert_align
 
-
 __all__ = [k for k in globals().keys() if not k.startswith("_")]
diff --git a/pytorch3d/ops/ball_query.py b/pytorch3d/ops/ball_query.py
new file mode 100644
index 00000000..5105400a
--- /dev/null
+++ b/pytorch3d/ops/ball_query.py
@@ -0,0 +1,150 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Union
+
+import torch
+from pytorch3d import _C
+from torch.autograd import Function
+from torch.autograd.function import once_differentiable
+
+from .knn import _KNN
+
+
+class _ball_query(Function):
+    """
+    Torch autograd Function wrapper for Ball Query C++/CUDA implementations.
+    """
+
+    @staticmethod
+    def forward(ctx, p1, p2, lengths1, lengths2, K, radius):
+        """
+        Argument definitions are the same as in the ball_query function.
+        """
+        idx, dists = _C.ball_query(p1, p2, lengths1, lengths2, K, radius)
+        ctx.save_for_backward(p1, p2, lengths1, lengths2, idx)
+        ctx.mark_non_differentiable(idx)
+        return dists, idx
+
+    @staticmethod
+    @once_differentiable
+    def backward(ctx, grad_dists, grad_idx):
+        p1, p2, lengths1, lengths2, idx = ctx.saved_tensors
+        # TODO(gkioxari) Change cast to floats once we add support for doubles.
+        if not (grad_dists.dtype == torch.float32):
+            grad_dists = grad_dists.float()
+        if not (p1.dtype == torch.float32):
+            p1 = p1.float()
+        if not (p2.dtype == torch.float32):
+            p2 = p2.float()
+
+        # Reuse the KNN backward function
+        grad_p1, grad_p2 = _C.knn_points_backward(
+            p1, p2, lengths1, lengths2, idx, grad_dists
+        )
+        return grad_p1, grad_p2, None, None, None, None
+
+
+def ball_query(
+    p1: torch.Tensor,
+    p2: torch.Tensor,
+    lengths1: Union[torch.Tensor, None] = None,
+    lengths2: Union[torch.Tensor, None] = None,
+    K: int = 500,
+    radius: float = 0.2,
+    return_nn: bool = True,
+):
+    """
+    Ball Query is an alternative to KNN. It can be used to find all points in
+    p2 that are within a specified radius of each query point in p1 (with an
+    upper limit of K neighbors).
+
+    The neighbors returned are not necessarily the *nearest* to the
+    point in p1, just the first K values in p2 which are within the
+    specified radius.
+
+    This method is faster than kNN when there are large numbers of points
+    in p2 and the ordering of neighbors is not important, only that the
+    distance is within the radius threshold.
+
+    "Ball query’s local neighborhood guarantees a fixed region scale thus
+    making local region features more generalizable across space, which is
+    preferred for tasks requiring local pattern recognition
+    (e.g. semantic point labeling)" [1].
+
+    [1] Charles R. Qi et al, "PointNet++: Deep Hierarchical Feature Learning
+    on Point Sets in a Metric Space", NeurIPS 2017.
+
+    Args:
+        p1: Tensor of shape (N, P1, D) giving a batch of N point clouds, each
+            containing up to P1 points of dimension D. These represent the
+            centers of the ball queries.
+        p2: Tensor of shape (N, P2, D) giving a batch of N point clouds, each
+            containing up to P2 points of dimension D.
+        lengths1: LongTensor of shape (N,) of values in the range [0, P1],
+            giving the length of each pointcloud in p1. Or None to indicate
+            that every cloud has length P1.
+ lengths2: LongTensor of shape (N,) of values in the range [0, P2], giving the + length of each pointcloud in p2. Or None to indicate that every cloud has + length P2. + K: Integer giving the upper bound on the number of samples to take + within the radius + radius: the radius around each point within which the neighbors need to be located + return_nn: If set to True returns the K neighbor points in p2 for each point in p1. + + Returns: + dists: Tensor of shape (N, P1, K) giving the squared distances to + the neighbors. This is padded with zeros both where a cloud in p2 + has fewer than S points and where a cloud in p1 has fewer than P1 points + and also if there are fewer than K points which satisfy the radius threshold. + + idx: LongTensor of shape (N, P1, K) giving the indices of the + S neighbors in p2 for points in p1. + Concretely, if `p1_idx[n, i, k] = j` then `p2[n, j]` is the k-th + neighbor to `p1[n, i]` in `p2[n]`. This is padded with -1 both where a cloud + in p2 has fewer than S points and where a cloud in p1 has fewer than P1 + points and also if there are fewer than K points which satisfy the radius threshold. + + nn: Tensor of shape (N, P1, K, D) giving the K neighbors in p2 for + each point in p1. Concretely, `p2_nn[n, i, k]` gives the k-th neighbor + for `p1[n, i]`. Returned if `return_nn` is True. The output is a tensor + of shape (N, P1, K, U). + + """ + if p1.shape[0] != p2.shape[0]: + raise ValueError("pts1 and pts2 must have the same batch dimension.") + if p1.shape[2] != p2.shape[2]: + raise ValueError("pts1 and pts2 must have the same point dimension.") + + p1 = p1.contiguous() + p2 = p2.contiguous() + P1 = p1.shape[1] + P2 = p2.shape[1] + D = p2.shape[2] + N = p1.shape[0] + + if lengths1 is None: + lengths1 = torch.full((N,), P1, dtype=torch.int64, device=p1.device) + if lengths2 is None: + lengths2 = torch.full((N,), P2, dtype=torch.int64, device=p1.device) + + # pyre-fixme[16]: `_ball_query` has no attribute `apply`. + dists, idx = _ball_query.apply(p1, p2, lengths1, lengths2, K, radius) + + # Gather the neighbors if needed + points_nn = None + if return_nn: + idx_expanded = idx[:, :, :, None].expand(-1, -1, -1, D) + idx_mask = idx_expanded.eq(-1) + idx_new = idx_expanded.clone() + # Replace -1 values with 0 for gather + idx_new[idx_mask] = 0 + # Gather points from p2 + points_nn = p2[:, :, None].expand(-1, -1, K, -1).gather(1, idx_new) + # Replace padded values + points_nn[idx_mask] = 0.0 + + return _KNN(dists=dists, idx=idx, knn=points_nn) diff --git a/tests/bm_ball_query.py b/tests/bm_ball_query.py new file mode 100644 index 00000000..ac6a368d --- /dev/null +++ b/tests/bm_ball_query.py @@ -0,0 +1,40 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +from itertools import product + +from fvcore.common.benchmark import benchmark +from test_ball_query import TestBallQuery + + +def bm_ball_query() -> None: + + backends = ["cpu", "cuda:0"] + + kwargs_list = [] + Ns = [32] + P1s = [256] + P2s = [128, 512] + Ds = [3, 10] + Ks = [3, 24, 100] + Rs = [0.1, 0.2, 5] + test_cases = product(Ns, P1s, P2s, Ds, Ks, Rs, backends) + for case in test_cases: + N, P1, P2, D, K, R, b = case + kwargs_list.append( + {"N": N, "P1": P1, "P2": P2, "D": D, "K": K, "radius": R, "device": b} + ) + + benchmark( + TestBallQuery.ball_query_square, "BALLQUERY_SQUARE", kwargs_list, warmup_iters=1 + ) + benchmark( + TestBallQuery.ball_query_ragged, "BALLQUERY_RAGGED", kwargs_list, warmup_iters=1 + ) + + +if __name__ == "__main__": + bm_ball_query() diff --git a/tests/test_ball_query.py b/tests/test_ball_query.py new file mode 100644 index 00000000..7bbc9fa7 --- /dev/null +++ b/tests/test_ball_query.py @@ -0,0 +1,230 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import unittest +from itertools import product + +import torch +from common_testing import TestCaseMixin, get_random_cuda_device +from pytorch3d.ops import sample_points_from_meshes +from pytorch3d.ops.ball_query import ball_query +from pytorch3d.ops.knn import _KNN +from pytorch3d.utils import ico_sphere + + +class TestBallQuery(TestCaseMixin, unittest.TestCase): + def setUp(self) -> None: + super().setUp() + torch.manual_seed(1) + + @staticmethod + def _ball_query_naive( + p1, p2, lengths1, lengths2, K: int, radius: float + ) -> torch.Tensor: + """ + Naive PyTorch implementation of ball query. + """ + N, P1, D = p1.shape + _N, P2, _D = p2.shape + + assert N == _N and D == _D + + if lengths1 is None: + lengths1 = torch.full((N,), P1, dtype=torch.int64, device=p1.device) + if lengths2 is None: + lengths2 = torch.full((N,), P2, dtype=torch.int64, device=p1.device) + + radius2 = radius * radius + dists = torch.zeros((N, P1, K), dtype=torch.float32, device=p1.device) + idx = torch.full((N, P1, K), fill_value=-1, dtype=torch.int64, device=p1.device) + + # Iterate through the batches + for n in range(N): + num1 = lengths1[n].item() + num2 = lengths2[n].item() + + # Iterate through the points in the p1 + for i in range(num1): + # Iterate through the points in the p2 + count = 0 + for j in range(num2): + dist = p2[n, j] - p1[n, i] + dist2 = (dist * dist).sum() + if dist2 < radius2 and count < K: + dists[n, i, count] = dist2 + idx[n, i, count] = j + count += 1 + + return _KNN(dists=dists, idx=idx, knn=None) + + def _ball_query_vs_python_square_helper(self, device): + Ns = [1, 4] + Ds = [3, 5, 8] + P1s = [8, 24] + P2s = [8, 16, 32] + Ks = [1, 5] + Rs = [3, 5] + factors = [Ns, Ds, P1s, P2s, Ks, Rs] + for N, D, P1, P2, K, R in product(*factors): + x = torch.randn(N, P1, D, device=device, requires_grad=True) + x_cuda = x.clone().detach() + x_cuda.requires_grad_(True) + y = torch.randn(N, P2, D, device=device, requires_grad=True) + y_cuda = y.clone().detach() + y_cuda.requires_grad_(True) + + # forward + out1 = self._ball_query_naive( + x, y, lengths1=None, lengths2=None, K=K, radius=R + ) + out2 = ball_query(x_cuda, y_cuda, K=K, radius=R) + + # Check dists + self.assertClose(out1.dists, out2.dists) + # Check idx + self.assertTrue(torch.all(out1.idx == out2.idx)) + + # backward + grad_dist = torch.ones((N, P1, K), dtype=torch.float32, device=device) + 
loss1 = (out1.dists * grad_dist).sum() + loss1.backward() + loss2 = (out2.dists * grad_dist).sum() + loss2.backward() + + self.assertClose(x_cuda.grad, x.grad, atol=5e-6) + self.assertClose(y_cuda.grad, y.grad, atol=5e-6) + + def test_ball_query_vs_python_square_cpu(self): + device = torch.device("cpu") + self._ball_query_vs_python_square_helper(device) + + def test_ball_query_vs_python_square_cuda(self): + device = get_random_cuda_device() + self._ball_query_vs_python_square_helper(device) + + def _ball_query_vs_python_ragged_helper(self, device): + Ns = [1, 4] + Ds = [3, 5, 8] + P1s = [8, 24] + P2s = [8, 16, 32] + Ks = [2, 3, 10] + Rs = [1.4, 5] # radius + factors = [Ns, Ds, P1s, P2s, Ks, Rs] + for N, D, P1, P2, K, R in product(*factors): + x = torch.rand((N, P1, D), device=device, requires_grad=True) + y = torch.rand((N, P2, D), device=device, requires_grad=True) + lengths1 = torch.randint(low=1, high=P1, size=(N,), device=device) + lengths2 = torch.randint(low=1, high=P2, size=(N,), device=device) + + x_csrc = x.clone().detach() + x_csrc.requires_grad_(True) + y_csrc = y.clone().detach() + y_csrc.requires_grad_(True) + + # forward + out1 = self._ball_query_naive( + x, y, lengths1=lengths1, lengths2=lengths2, K=K, radius=R + ) + out2 = ball_query( + x_csrc, + y_csrc, + lengths1=lengths1, + lengths2=lengths2, + K=K, + radius=R, + ) + + self.assertClose(out1.idx, out2.idx) + self.assertClose(out1.dists, out2.dists) + + # backward + grad_dist = torch.ones((N, P1, K), dtype=torch.float32, device=device) + loss1 = (out1.dists * grad_dist).sum() + loss1.backward() + loss2 = (out2.dists * grad_dist).sum() + loss2.backward() + + self.assertClose(x_csrc.grad, x.grad, atol=5e-6) + self.assertClose(y_csrc.grad, y.grad, atol=5e-6) + + def test_ball_query_vs_python_ragged_cpu(self): + device = torch.device("cpu") + self._ball_query_vs_python_ragged_helper(device) + + def test_ball_query_vs_python_ragged_cuda(self): + device = get_random_cuda_device() + self._ball_query_vs_python_ragged_helper(device) + + def test_ball_query_output_simple(self): + device = get_random_cuda_device() + N, P1, P2, K = 5, 8, 16, 4 + sphere = ico_sphere(level=2, device=device).extend(N) + points_1 = sample_points_from_meshes(sphere, P1) + points_2 = sample_points_from_meshes(sphere, P2) * 5.0 + radius = 6.0 + + naive_out = self._ball_query_naive( + points_1, points_2, lengths1=None, lengths2=None, K=K, radius=radius + ) + cuda_out = ball_query(points_1, points_2, K=K, radius=radius) + + # All points should have N sample neighbors as radius is large + # Zero is a valid index but can only be present once (i.e. 
no zero padding) + naive_out_zeros = (naive_out.idx == 0).sum(dim=-1).max() + cuda_out_zeros = (cuda_out.idx == 0).sum(dim=-1).max() + self.assertTrue(naive_out_zeros == 0 or naive_out_zeros == 1) + self.assertTrue(cuda_out_zeros == 0 or cuda_out_zeros == 1) + + # All points should now have zero sample neighbors as radius is small + radius = 0.5 + naive_out = self._ball_query_naive( + points_1, points_2, lengths1=None, lengths2=None, K=K, radius=radius + ) + cuda_out = ball_query(points_1, points_2, K=K, radius=radius) + naive_out_allzeros = (naive_out.idx == -1).all() + cuda_out_allzeros = (cuda_out.idx == -1).sum() + self.assertTrue(naive_out_allzeros) + self.assertTrue(cuda_out_allzeros) + + @staticmethod + def ball_query_square( + N: int, P1: int, P2: int, D: int, K: int, radius: float, device: str + ): + device = torch.device(device) + pts1 = torch.randn(N, P1, D, device=device, requires_grad=True) + pts2 = torch.randn(N, P2, D, device=device, requires_grad=True) + grad_dists = torch.randn(N, P1, K, device=device) + torch.cuda.synchronize() + + def output(): + out = ball_query(pts1, pts2, K=K, radius=radius) + loss = (out.dists * grad_dists).sum() + loss.backward() + torch.cuda.synchronize() + + return output + + @staticmethod + def ball_query_ragged( + N: int, P1: int, P2: int, D: int, K: int, radius: float, device: str + ): + device = torch.device(device) + pts1 = torch.rand((N, P1, D), device=device, requires_grad=True) + pts2 = torch.rand((N, P2, D), device=device, requires_grad=True) + lengths1 = torch.randint(low=1, high=P1, size=(N,), device=device) + lengths2 = torch.randint(low=1, high=P2, size=(N,), device=device) + grad_dists = torch.randn(N, P1, K, device=device) + torch.cuda.synchronize() + + def output(): + out = ball_query( + pts1, pts2, lengths1=lengths1, lengths2=lengths2, K=K, radius=radius + ) + loss = (out.dists * grad_dists).sum() + loss.backward() + torch.cuda.synchronize() + + return output
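Example usage of the new op (a minimal sketch based on the `ball_query` docstring above; the cloud sizes and the `K`/`radius` values are illustrative, not taken from the tests):

```python
import torch

from pytorch3d.ops import ball_query

# Two heterogeneous batches: up to P1=128 query points and up to P2=1024
# candidate points per cloud, in 3D.
N, P1, P2, D = 2, 128, 1024, 3
p1 = torch.rand(N, P1, D)
p2 = torch.rand(N, P2, D)
lengths1 = torch.tensor([128, 100])   # valid points in each p1 cloud
lengths2 = torch.tensor([1024, 512])  # valid points in each p2 cloud

# For each point in p1, take the first K=16 points of p2 that lie within
# radius 0.2. The result is the same _KNN named tuple returned by knn_points.
out = ball_query(p1, p2, lengths1=lengths1, lengths2=lengths2, K=16, radius=0.2)

dists = out.dists  # (N, P1, K) squared distances, zero-padded
idx = out.idx      # (N, P1, K) indices into p2, padded with -1
nn = out.knn       # (N, P1, K, D) neighbor coordinates, zero-padded
```

Entries where `idx` is -1 correspond to missing neighbors: either fewer than K points of p2 fell within the radius, or the query/candidate point was padding in a ragged cloud.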