mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-11-04 01:42:11 +08:00
Improve ball_query() runtime for large-scale cases (#2006)
Summary: ### Overview The current C++ code for `pytorch3d.ops.ball_query()` performs floating point multiplication for every coordinate of every pair of points (up until the maximum number of neighbor points is reached). This PR modifies the code (for both CPU and CUDA versions) to implement idea presented [here](https://stackoverflow.com/a/3939525): a `D`-cube around the `D`-ball is first constructed, and any point pairs falling outside the cube are skipped, without explicitly computing the squared distances. This change is especially useful for when the dimension `D` and the number of points `P2` are large and the radius is much smaller than the overall volume of space occupied by the point clouds; as much as **~2.5x speedup** (CPU case; ~1.8x speedup in CUDA case) is observed when `D = 10` and `radius = 0.01`. In all benchmark cases, points were uniform randomly distributed inside a unit `D`-cube. The benchmark code used was different from `tests/benchmarks/bm_ball_query.py` (only the forward part is benchmarked, larger input sizes were used) and is stored in `tests/benchmarks/bm_ball_query_large.py`. ### Average time comparisons <img width="360" height="270" alt="cpu-03-0 01-avg" src="https://github.com/user-attachments/assets/6cc79893-7921-44af-9366-1766c3caf142" /> <img width="360" height="270" alt="cuda-03-0 01-avg" src="https://github.com/user-attachments/assets/5151647d-0273-40a3-aac6-8b9399ede18a" /> <img width="360" height="270" alt="cpu-03-0 10-avg" src="https://github.com/user-attachments/assets/a87bc150-a5eb-47cd-a4ba-83c2ec81edaf" /> <img width="360" height="270" alt="cuda-03-0 10-avg" src="https://github.com/user-attachments/assets/e3699a9f-dfd3-4dd3-b3c9-619296186d43" /> <img width="360" height="270" alt="cpu-10-0 01-avg" src="https://github.com/user-attachments/assets/5ec8c32d-8e4d-4ced-a94e-1b816b1cb0f8" /> <img width="360" height="270" alt="cuda-10-0 01-avg" src="https://github.com/user-attachments/assets/168a3dfc-777a-4fb3-8023-1ac8c13985b8" /> <img width="360" height="270" alt="cpu-10-0 10-avg" src="https://github.com/user-attachments/assets/43a57fd6-1e01-4c5e-87a9-8ef604ef5fa0" /> <img width="360" height="270" alt="cuda-10-0 10-avg" src="https://github.com/user-attachments/assets/a7c7cc69-f273-493e-95b8-3ba2bb2e32da" /> ### Peak time comparisons <img width="360" height="270" alt="cpu-03-0 01-peak" src="https://github.com/user-attachments/assets/5bbbea3f-ef9b-490d-ab0d-ce551711d74f" /> <img width="360" height="270" alt="cuda-03-0 01-peak" src="https://github.com/user-attachments/assets/30b5ab9b-45cb-4057-b69f-bda6e76bd1dc" /> <img width="360" height="270" alt="cpu-03-0 10-peak" src="https://github.com/user-attachments/assets/db69c333-e5ac-4305-8a86-a26a8a9fe80d" /> <img width="360" height="270" alt="cuda-03-0 10-peak" src="https://github.com/user-attachments/assets/82549656-1f12-409e-8160-dd4c4c9d14f7" /> <img width="360" height="270" alt="cpu-10-0 01-peak" src="https://github.com/user-attachments/assets/d0be8ef1-535e-47bc-b773-b87fad625bf0" /> <img width="360" height="270" alt="cuda-10-0 01-peak" src="https://github.com/user-attachments/assets/e308e66e-ae30-400f-8ad2-015517f6e1af" /> <img width="360" height="270" alt="cpu-10-0 10-peak" src="https://github.com/user-attachments/assets/c9b5bf59-9cc2-465c-ad5d-d4e23bdd138a" /> <img width="360" height="270" alt="cuda-10-0 10-peak" src="https://github.com/user-attachments/assets/311354d4-b488-400c-a1dc-c85a21917aa9" /> ### Full benchmark logs [benchmark-before-change.txt](https://github.com/user-attachments/files/22978300/benchmark-before-change.txt) [benchmark-after-change.txt](https://github.com/user-attachments/files/22978299/benchmark-after-change.txt) Pull Request resolved: https://github.com/facebookresearch/pytorch3d/pull/2006 Reviewed By: shapovalov Differential Revision: D85356394 Pulled By: bottler fbshipit-source-id: 9b3ce5fc87bb73d4323cc5b4190fc38ae42f41b2
This commit is contained in:
parent
45df20e9e2
commit
2d4d345b6f
@ -32,7 +32,9 @@ __global__ void BallQueryKernel(
|
||||
at::PackedTensorAccessor64<int64_t, 3, at::RestrictPtrTraits> idxs,
|
||||
at::PackedTensorAccessor64<scalar_t, 3, at::RestrictPtrTraits> dists,
|
||||
const int64_t K,
|
||||
const float radius2) {
|
||||
const float radius,
|
||||
const float radius2,
|
||||
const bool skip_points_outside_cube) {
|
||||
const int64_t N = p1.size(0);
|
||||
const int64_t chunks_per_cloud = (1 + (p1.size(1) - 1) / blockDim.x);
|
||||
const int64_t chunks_to_do = N * chunks_per_cloud;
|
||||
@ -51,7 +53,19 @@ __global__ void BallQueryKernel(
|
||||
// Iterate over points in p2 until desired count is reached or
|
||||
// all points have been considered
|
||||
for (int64_t j = 0, count = 0; j < lengths2[n] && count < K; ++j) {
|
||||
// Calculate the distance between the points
|
||||
if (skip_points_outside_cube) {
|
||||
bool is_within_radius = true;
|
||||
// Filter when any one coordinate is already outside the radius
|
||||
for (int d = 0; is_within_radius && d < D; ++d) {
|
||||
scalar_t abs_diff = fabs(p1[n][i][d] - p2[n][j][d]);
|
||||
is_within_radius = (abs_diff <= radius);
|
||||
}
|
||||
if (!is_within_radius) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
// Else, calculate the distance between the points and compare
|
||||
scalar_t dist2 = 0.0;
|
||||
for (int d = 0; d < D; ++d) {
|
||||
scalar_t diff = p1[n][i][d] - p2[n][j][d];
|
||||
@ -77,7 +91,8 @@ std::tuple<at::Tensor, at::Tensor> BallQueryCuda(
|
||||
const at::Tensor& lengths1, // (N,)
|
||||
const at::Tensor& lengths2, // (N,)
|
||||
int K,
|
||||
float radius) {
|
||||
float radius,
|
||||
bool skip_points_outside_cube) {
|
||||
// Check inputs are on the same device
|
||||
at::TensorArg p1_t{p1, "p1", 1}, p2_t{p2, "p2", 2},
|
||||
lengths1_t{lengths1, "lengths1", 3}, lengths2_t{lengths2, "lengths2", 4};
|
||||
@ -120,7 +135,9 @@ std::tuple<at::Tensor, at::Tensor> BallQueryCuda(
|
||||
idxs.packed_accessor64<int64_t, 3, at::RestrictPtrTraits>(),
|
||||
dists.packed_accessor64<float, 3, at::RestrictPtrTraits>(),
|
||||
K_64,
|
||||
radius2);
|
||||
radius,
|
||||
radius2,
|
||||
skip_points_outside_cube);
|
||||
}));
|
||||
|
||||
AT_CUDA_CHECK(cudaGetLastError());
|
||||
|
||||
@ -25,6 +25,9 @@
|
||||
// within the radius
|
||||
// radius: the radius around each point within which the neighbors need to be
|
||||
// located
|
||||
// skip_points_outside_cube: If true, reduce multiplications of float values
|
||||
// by not explicitly calculating distances to points that fall outside the
|
||||
// D-cube with side length (2*radius) centered at each point in p1.
|
||||
//
|
||||
// Returns:
|
||||
// p1_neighbor_idx: LongTensor of shape (N, P1, K), where
|
||||
@ -46,7 +49,8 @@ std::tuple<at::Tensor, at::Tensor> BallQueryCpu(
|
||||
const at::Tensor& lengths1,
|
||||
const at::Tensor& lengths2,
|
||||
const int K,
|
||||
const float radius);
|
||||
const float radius,
|
||||
const bool skip_points_outside_cube);
|
||||
|
||||
// CUDA implementation
|
||||
std::tuple<at::Tensor, at::Tensor> BallQueryCuda(
|
||||
@ -55,7 +59,8 @@ std::tuple<at::Tensor, at::Tensor> BallQueryCuda(
|
||||
const at::Tensor& lengths1,
|
||||
const at::Tensor& lengths2,
|
||||
const int K,
|
||||
const float radius);
|
||||
const float radius,
|
||||
const bool skip_points_outside_cube);
|
||||
|
||||
// Implementation which is exposed
|
||||
// Note: the backward pass reuses the KNearestNeighborBackward kernel
|
||||
@ -65,7 +70,8 @@ inline std::tuple<at::Tensor, at::Tensor> BallQuery(
|
||||
const at::Tensor& lengths1,
|
||||
const at::Tensor& lengths2,
|
||||
int K,
|
||||
float radius) {
|
||||
float radius,
|
||||
bool skip_points_outside_cube) {
|
||||
if (p1.is_cuda() || p2.is_cuda()) {
|
||||
#ifdef WITH_CUDA
|
||||
CHECK_CUDA(p1);
|
||||
@ -76,7 +82,8 @@ inline std::tuple<at::Tensor, at::Tensor> BallQuery(
|
||||
lengths1.contiguous(),
|
||||
lengths2.contiguous(),
|
||||
K,
|
||||
radius);
|
||||
radius,
|
||||
skip_points_outside_cube);
|
||||
#else
|
||||
AT_ERROR("Not compiled with GPU support.");
|
||||
#endif
|
||||
@ -89,5 +96,6 @@ inline std::tuple<at::Tensor, at::Tensor> BallQuery(
|
||||
lengths1.contiguous(),
|
||||
lengths2.contiguous(),
|
||||
K,
|
||||
radius);
|
||||
radius,
|
||||
skip_points_outside_cube);
|
||||
}
|
||||
|
||||
@ -6,6 +6,7 @@
|
||||
* LICENSE file in the root directory of this source tree.
|
||||
*/
|
||||
|
||||
#include <math.h>
|
||||
#include <torch/extension.h>
|
||||
#include <tuple>
|
||||
|
||||
@ -15,7 +16,8 @@ std::tuple<at::Tensor, at::Tensor> BallQueryCpu(
|
||||
const at::Tensor& lengths1,
|
||||
const at::Tensor& lengths2,
|
||||
int K,
|
||||
float radius) {
|
||||
float radius,
|
||||
bool skip_points_outside_cube) {
|
||||
const int N = p1.size(0);
|
||||
const int P1 = p1.size(1);
|
||||
const int D = p1.size(2);
|
||||
@ -37,6 +39,16 @@ std::tuple<at::Tensor, at::Tensor> BallQueryCpu(
|
||||
const int64_t length2 = lengths2_a[n];
|
||||
for (int64_t i = 0; i < length1; ++i) {
|
||||
for (int64_t j = 0, count = 0; j < length2 && count < K; ++j) {
|
||||
if (skip_points_outside_cube) {
|
||||
bool is_within_radius = true;
|
||||
for (int d = 0; is_within_radius && d < D; ++d) {
|
||||
float abs_diff = fabs(p1_a[n][i][d] - p2_a[n][j][d]);
|
||||
is_within_radius = (abs_diff <= radius);
|
||||
}
|
||||
if (!is_within_radius) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
float dist2 = 0;
|
||||
for (int d = 0; d < D; ++d) {
|
||||
float diff = p1_a[n][i][d] - p2_a[n][j][d];
|
||||
|
||||
@ -23,11 +23,13 @@ class _ball_query(Function):
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, p1, p2, lengths1, lengths2, K, radius):
|
||||
def forward(ctx, p1, p2, lengths1, lengths2, K, radius, skip_points_outside_cube):
|
||||
"""
|
||||
Arguments defintions the same as in the ball_query function
|
||||
"""
|
||||
idx, dists = _C.ball_query(p1, p2, lengths1, lengths2, K, radius)
|
||||
idx, dists = _C.ball_query(
|
||||
p1, p2, lengths1, lengths2, K, radius, skip_points_outside_cube
|
||||
)
|
||||
ctx.save_for_backward(p1, p2, lengths1, lengths2, idx)
|
||||
ctx.mark_non_differentiable(idx)
|
||||
return dists, idx
|
||||
@ -49,7 +51,7 @@ class _ball_query(Function):
|
||||
grad_p1, grad_p2 = _C.knn_points_backward(
|
||||
p1, p2, lengths1, lengths2, idx, 2, grad_dists
|
||||
)
|
||||
return grad_p1, grad_p2, None, None, None, None
|
||||
return grad_p1, grad_p2, None, None, None, None, None
|
||||
|
||||
|
||||
def ball_query(
|
||||
@ -60,6 +62,7 @@ def ball_query(
|
||||
K: int = 500,
|
||||
radius: float = 0.2,
|
||||
return_nn: bool = True,
|
||||
skip_points_outside_cube: bool = False,
|
||||
):
|
||||
"""
|
||||
Ball Query is an alternative to KNN. It can be
|
||||
@ -98,6 +101,9 @@ def ball_query(
|
||||
within the radius
|
||||
radius: the radius around each point within which the neighbors need to be located
|
||||
return_nn: If set to True returns the K neighbor points in p2 for each point in p1.
|
||||
skip_points_outside_cube: If set to True, reduce multiplications of float values
|
||||
by not explicitly calculating distances to points that fall outside the
|
||||
D-cube with side length (2*radius) centered at each point in p1.
|
||||
|
||||
Returns:
|
||||
dists: Tensor of shape (N, P1, K) giving the squared distances to
|
||||
@ -134,7 +140,9 @@ def ball_query(
|
||||
if lengths2 is None:
|
||||
lengths2 = torch.full((N,), P2, dtype=torch.int64, device=p1.device)
|
||||
|
||||
dists, idx = _ball_query.apply(p1, p2, lengths1, lengths2, K, radius)
|
||||
dists, idx = _ball_query.apply(
|
||||
p1, p2, lengths1, lengths2, K, radius, skip_points_outside_cube
|
||||
)
|
||||
|
||||
# Gather the neighbors if needed
|
||||
points_nn = masked_gather(p2, idx) if return_nn else None
|
||||
|
||||
56
tests/benchmarks/bm_ball_query_large.py
Normal file
56
tests/benchmarks/bm_ball_query_large.py
Normal file
@ -0,0 +1,56 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from itertools import product
|
||||
|
||||
import torch
|
||||
from fvcore.common.benchmark import benchmark
|
||||
|
||||
from pytorch3d.ops.ball_query import ball_query
|
||||
|
||||
|
||||
def ball_query_square(
|
||||
N: int, P1: int, P2: int, D: int, K: int, radius: float, device: str
|
||||
):
|
||||
device = torch.device(device)
|
||||
pts1 = torch.rand(N, P1, D, device=device)
|
||||
pts2 = torch.rand(N, P2, D, device=device)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
def output():
|
||||
ball_query(pts1, pts2, K=K, radius=radius, skip_points_outside_cube=True)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def bm_ball_query() -> None:
|
||||
backends = ["cpu", "cuda:0"]
|
||||
|
||||
kwargs_list = []
|
||||
Ns = [32]
|
||||
P1s = [256]
|
||||
P2s = [2**p for p in range(9, 20, 2)]
|
||||
Ds = [3, 10]
|
||||
Ks = [500]
|
||||
Rs = [0.01, 0.1]
|
||||
test_cases = product(Ns, P1s, P2s, Ds, Ks, Rs, backends)
|
||||
for case in test_cases:
|
||||
N, P1, P2, D, K, R, b = case
|
||||
kwargs_list.append(
|
||||
{"N": N, "P1": P1, "P2": P2, "D": D, "K": K, "radius": R, "device": b}
|
||||
)
|
||||
benchmark(
|
||||
ball_query_square,
|
||||
"BALLQUERY_SQUARE",
|
||||
kwargs_list,
|
||||
num_iters=30,
|
||||
warmup_iters=1,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
bm_ball_query()
|
||||
Loading…
x
Reference in New Issue
Block a user