From 2d4d345b6fd2720580bff5f63dcbd3b230b43996 Mon Sep 17 00:00:00 2001 From: Eugene Park Date: Thu, 30 Oct 2025 05:01:32 -0700 Subject: [PATCH] Improve `ball_query()` runtime for large-scale cases (#2006) Summary: ### Overview The current C++ code for `pytorch3d.ops.ball_query()` performs floating point multiplication for every coordinate of every pair of points (up until the maximum number of neighbor points is reached). This PR modifies the code (for both CPU and CUDA versions) to implement idea presented [here](https://stackoverflow.com/a/3939525): a `D`-cube around the `D`-ball is first constructed, and any point pairs falling outside the cube are skipped, without explicitly computing the squared distances. This change is especially useful for when the dimension `D` and the number of points `P2` are large and the radius is much smaller than the overall volume of space occupied by the point clouds; as much as **~2.5x speedup** (CPU case; ~1.8x speedup in CUDA case) is observed when `D = 10` and `radius = 0.01`. In all benchmark cases, points were uniform randomly distributed inside a unit `D`-cube. The benchmark code used was different from `tests/benchmarks/bm_ball_query.py` (only the forward part is benchmarked, larger input sizes were used) and is stored in `tests/benchmarks/bm_ball_query_large.py`. ### Average time comparisons cpu-03-0 01-avg cuda-03-0 01-avg cpu-03-0 10-avg cuda-03-0 10-avg cpu-10-0 01-avg cuda-10-0 01-avg cpu-10-0 10-avg cuda-10-0 10-avg ### Peak time comparisons cpu-03-0 01-peak cuda-03-0 01-peak cpu-03-0 10-peak cuda-03-0 10-peak cpu-10-0 01-peak cuda-10-0 01-peak cpu-10-0 10-peak cuda-10-0 10-peak ### Full benchmark logs [benchmark-before-change.txt](https://github.com/user-attachments/files/22978300/benchmark-before-change.txt) [benchmark-after-change.txt](https://github.com/user-attachments/files/22978299/benchmark-after-change.txt) Pull Request resolved: https://github.com/facebookresearch/pytorch3d/pull/2006 Reviewed By: shapovalov Differential Revision: D85356394 Pulled By: bottler fbshipit-source-id: 9b3ce5fc87bb73d4323cc5b4190fc38ae42f41b2 --- pytorch3d/csrc/ball_query/ball_query.cu | 25 +++++++-- pytorch3d/csrc/ball_query/ball_query.h | 18 +++++-- pytorch3d/csrc/ball_query/ball_query_cpu.cpp | 14 ++++- pytorch3d/ops/ball_query.py | 16 ++++-- tests/benchmarks/bm_ball_query_large.py | 56 ++++++++++++++++++++ 5 files changed, 115 insertions(+), 14 deletions(-) create mode 100644 tests/benchmarks/bm_ball_query_large.py diff --git a/pytorch3d/csrc/ball_query/ball_query.cu b/pytorch3d/csrc/ball_query/ball_query.cu index 586701c1..2314d7af 100644 --- a/pytorch3d/csrc/ball_query/ball_query.cu +++ b/pytorch3d/csrc/ball_query/ball_query.cu @@ -32,7 +32,9 @@ __global__ void BallQueryKernel( at::PackedTensorAccessor64 idxs, at::PackedTensorAccessor64 dists, const int64_t K, - const float radius2) { + const float radius, + const float radius2, + const bool skip_points_outside_cube) { const int64_t N = p1.size(0); const int64_t chunks_per_cloud = (1 + (p1.size(1) - 1) / blockDim.x); const int64_t chunks_to_do = N * chunks_per_cloud; @@ -51,7 +53,19 @@ __global__ void BallQueryKernel( // Iterate over points in p2 until desired count is reached or // all points have been considered for (int64_t j = 0, count = 0; j < lengths2[n] && count < K; ++j) { - // Calculate the distance between the points + if (skip_points_outside_cube) { + bool is_within_radius = true; + // Filter when any one coordinate is already outside the radius + for (int d = 0; is_within_radius && d < D; ++d) { + scalar_t abs_diff = fabs(p1[n][i][d] - p2[n][j][d]); + is_within_radius = (abs_diff <= radius); + } + if (!is_within_radius) { + continue; + } + } + + // Else, calculate the distance between the points and compare scalar_t dist2 = 0.0; for (int d = 0; d < D; ++d) { scalar_t diff = p1[n][i][d] - p2[n][j][d]; @@ -77,7 +91,8 @@ std::tuple BallQueryCuda( const at::Tensor& lengths1, // (N,) const at::Tensor& lengths2, // (N,) int K, - float radius) { + float radius, + bool skip_points_outside_cube) { // Check inputs are on the same device at::TensorArg p1_t{p1, "p1", 1}, p2_t{p2, "p2", 2}, lengths1_t{lengths1, "lengths1", 3}, lengths2_t{lengths2, "lengths2", 4}; @@ -120,7 +135,9 @@ std::tuple BallQueryCuda( idxs.packed_accessor64(), dists.packed_accessor64(), K_64, - radius2); + radius, + radius2, + skip_points_outside_cube); })); AT_CUDA_CHECK(cudaGetLastError()); diff --git a/pytorch3d/csrc/ball_query/ball_query.h b/pytorch3d/csrc/ball_query/ball_query.h index eb8f54da..dc7a7851 100644 --- a/pytorch3d/csrc/ball_query/ball_query.h +++ b/pytorch3d/csrc/ball_query/ball_query.h @@ -25,6 +25,9 @@ // within the radius // radius: the radius around each point within which the neighbors need to be // located +// skip_points_outside_cube: If true, reduce multiplications of float values +// by not explicitly calculating distances to points that fall outside the +// D-cube with side length (2*radius) centered at each point in p1. // // Returns: // p1_neighbor_idx: LongTensor of shape (N, P1, K), where @@ -46,7 +49,8 @@ std::tuple BallQueryCpu( const at::Tensor& lengths1, const at::Tensor& lengths2, const int K, - const float radius); + const float radius, + const bool skip_points_outside_cube); // CUDA implementation std::tuple BallQueryCuda( @@ -55,7 +59,8 @@ std::tuple BallQueryCuda( const at::Tensor& lengths1, const at::Tensor& lengths2, const int K, - const float radius); + const float radius, + const bool skip_points_outside_cube); // Implementation which is exposed // Note: the backward pass reuses the KNearestNeighborBackward kernel @@ -65,7 +70,8 @@ inline std::tuple BallQuery( const at::Tensor& lengths1, const at::Tensor& lengths2, int K, - float radius) { + float radius, + bool skip_points_outside_cube) { if (p1.is_cuda() || p2.is_cuda()) { #ifdef WITH_CUDA CHECK_CUDA(p1); @@ -76,7 +82,8 @@ inline std::tuple BallQuery( lengths1.contiguous(), lengths2.contiguous(), K, - radius); + radius, + skip_points_outside_cube); #else AT_ERROR("Not compiled with GPU support."); #endif @@ -89,5 +96,6 @@ inline std::tuple BallQuery( lengths1.contiguous(), lengths2.contiguous(), K, - radius); + radius, + skip_points_outside_cube); } diff --git a/pytorch3d/csrc/ball_query/ball_query_cpu.cpp b/pytorch3d/csrc/ball_query/ball_query_cpu.cpp index 24cdf388..e9f431e4 100644 --- a/pytorch3d/csrc/ball_query/ball_query_cpu.cpp +++ b/pytorch3d/csrc/ball_query/ball_query_cpu.cpp @@ -6,6 +6,7 @@ * LICENSE file in the root directory of this source tree. */ +#include #include #include @@ -15,7 +16,8 @@ std::tuple BallQueryCpu( const at::Tensor& lengths1, const at::Tensor& lengths2, int K, - float radius) { + float radius, + bool skip_points_outside_cube) { const int N = p1.size(0); const int P1 = p1.size(1); const int D = p1.size(2); @@ -37,6 +39,16 @@ std::tuple BallQueryCpu( const int64_t length2 = lengths2_a[n]; for (int64_t i = 0; i < length1; ++i) { for (int64_t j = 0, count = 0; j < length2 && count < K; ++j) { + if (skip_points_outside_cube) { + bool is_within_radius = true; + for (int d = 0; is_within_radius && d < D; ++d) { + float abs_diff = fabs(p1_a[n][i][d] - p2_a[n][j][d]); + is_within_radius = (abs_diff <= radius); + } + if (!is_within_radius) { + continue; + } + } float dist2 = 0; for (int d = 0; d < D; ++d) { float diff = p1_a[n][i][d] - p2_a[n][j][d]; diff --git a/pytorch3d/ops/ball_query.py b/pytorch3d/ops/ball_query.py index 31266c4d..698d816c 100644 --- a/pytorch3d/ops/ball_query.py +++ b/pytorch3d/ops/ball_query.py @@ -23,11 +23,13 @@ class _ball_query(Function): """ @staticmethod - def forward(ctx, p1, p2, lengths1, lengths2, K, radius): + def forward(ctx, p1, p2, lengths1, lengths2, K, radius, skip_points_outside_cube): """ Arguments defintions the same as in the ball_query function """ - idx, dists = _C.ball_query(p1, p2, lengths1, lengths2, K, radius) + idx, dists = _C.ball_query( + p1, p2, lengths1, lengths2, K, radius, skip_points_outside_cube + ) ctx.save_for_backward(p1, p2, lengths1, lengths2, idx) ctx.mark_non_differentiable(idx) return dists, idx @@ -49,7 +51,7 @@ class _ball_query(Function): grad_p1, grad_p2 = _C.knn_points_backward( p1, p2, lengths1, lengths2, idx, 2, grad_dists ) - return grad_p1, grad_p2, None, None, None, None + return grad_p1, grad_p2, None, None, None, None, None def ball_query( @@ -60,6 +62,7 @@ def ball_query( K: int = 500, radius: float = 0.2, return_nn: bool = True, + skip_points_outside_cube: bool = False, ): """ Ball Query is an alternative to KNN. It can be @@ -98,6 +101,9 @@ def ball_query( within the radius radius: the radius around each point within which the neighbors need to be located return_nn: If set to True returns the K neighbor points in p2 for each point in p1. + skip_points_outside_cube: If set to True, reduce multiplications of float values + by not explicitly calculating distances to points that fall outside the + D-cube with side length (2*radius) centered at each point in p1. Returns: dists: Tensor of shape (N, P1, K) giving the squared distances to @@ -134,7 +140,9 @@ def ball_query( if lengths2 is None: lengths2 = torch.full((N,), P2, dtype=torch.int64, device=p1.device) - dists, idx = _ball_query.apply(p1, p2, lengths1, lengths2, K, radius) + dists, idx = _ball_query.apply( + p1, p2, lengths1, lengths2, K, radius, skip_points_outside_cube + ) # Gather the neighbors if needed points_nn = masked_gather(p2, idx) if return_nn else None diff --git a/tests/benchmarks/bm_ball_query_large.py b/tests/benchmarks/bm_ball_query_large.py new file mode 100644 index 00000000..c9b69fa5 --- /dev/null +++ b/tests/benchmarks/bm_ball_query_large.py @@ -0,0 +1,56 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from itertools import product + +import torch +from fvcore.common.benchmark import benchmark + +from pytorch3d.ops.ball_query import ball_query + + +def ball_query_square( + N: int, P1: int, P2: int, D: int, K: int, radius: float, device: str +): + device = torch.device(device) + pts1 = torch.rand(N, P1, D, device=device) + pts2 = torch.rand(N, P2, D, device=device) + torch.cuda.synchronize() + + def output(): + ball_query(pts1, pts2, K=K, radius=radius, skip_points_outside_cube=True) + torch.cuda.synchronize() + + return output + + +def bm_ball_query() -> None: + backends = ["cpu", "cuda:0"] + + kwargs_list = [] + Ns = [32] + P1s = [256] + P2s = [2**p for p in range(9, 20, 2)] + Ds = [3, 10] + Ks = [500] + Rs = [0.01, 0.1] + test_cases = product(Ns, P1s, P2s, Ds, Ks, Rs, backends) + for case in test_cases: + N, P1, P2, D, K, R, b = case + kwargs_list.append( + {"N": N, "P1": P1, "P2": P2, "D": D, "K": K, "radius": R, "device": b} + ) + benchmark( + ball_query_square, + "BALLQUERY_SQUARE", + kwargs_list, + num_iters=30, + warmup_iters=1, + ) + + +if __name__ == "__main__": + bm_ball_query()