mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2026-03-01 09:45:58 +08:00
Summary:

### Overview

The current C++ code for `pytorch3d.ops.ball_query()` performs floating point multiplication for every coordinate of every pair of points (up until the maximum number of neighbor points is reached). This PR modifies the code (for both CPU and CUDA versions) to implement the idea presented [here](https://stackoverflow.com/a/3939525): a `D`-cube around the `D`-ball is first constructed, and any point pairs falling outside the cube are skipped without explicitly computing the squared distances.

This change is especially useful when the dimension `D` and the number of points `P2` are large and the radius is much smaller than the extent of the space occupied by the point clouds; as much as **~2.5x speedup** (CPU case; ~1.8x speedup in CUDA case) is observed when `D = 10` and `radius = 0.01`. In all benchmark cases, points were distributed uniformly at random inside a unit `D`-cube.

The benchmark code used was different from `tests/benchmarks/bm_ball_query.py` (only the forward part is benchmarked, and larger input sizes were used) and is stored in `tests/benchmarks/bm_ball_query_large.py`.
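The cube pre-check described in the overview can be sketched in plain C++ as follows. This is only an illustration under stated assumptions, not the PyTorch3D kernel: the function name `BallQuerySketch`, the flat row-major `std::vector<float>` layout, and the strict `<` comparison against the squared radius are all choices made for this sketch.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical helper, not part of PyTorch3D: returns the indices of up to K
// points in p2 whose squared distance to `query` is below radius * radius.
// Assumed layout: p2 holds P2 points of dimension D, row-major (p2[j * D + d]).
std::vector<long> BallQuerySketch(
    const std::vector<float>& query,
    const std::vector<float>& p2,
    std::size_t D,
    std::size_t K,
    float radius,
    bool skip_points_outside_cube) {
  assert(D > 0 && query.size() == D && p2.size() % D == 0);
  const float radius2 = radius * radius;
  const std::size_t P2 = p2.size() / D;
  std::vector<long> neighbors;
  for (std::size_t j = 0; j < P2 && neighbors.size() < K; ++j) {
    const float* point = p2.data() + j * D;
    if (skip_points_outside_cube) {
      // Cheap rejection: if any single coordinate differs by more than the
      // radius, the point lies outside the D-cube and hence outside the ball.
      bool outside = false;
      for (std::size_t d = 0; d < D; ++d) {
        if (std::fabs(point[d] - query[d]) > radius) {
          outside = true;
          break;
        }
      }
      if (outside) {
        continue; // no multiplications were spent on this point
      }
    }
    // Full squared-distance test for points that survive the cube check.
    float dist2 = 0.0f;
    for (std::size_t d = 0; d < D; ++d) {
      const float diff = point[d] - query[d];
      dist2 += diff * diff;
    }
    if (dist2 < radius2) {
      neighbors.push_back(static_cast<long>(j));
    }
  }
  return neighbors;
}
```

The cube check only rejects points that the distance test would reject anyway, so both settings of the flag return the same indices; the flag only changes how much arithmetic is spent on far-away points.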
### Average time comparisons

<img width="360" height="270" alt="cpu-03-0 01-avg" src="https://github.com/user-attachments/assets/6cc79893-7921-44af-9366-1766c3caf142" />
<img width="360" height="270" alt="cuda-03-0 01-avg" src="https://github.com/user-attachments/assets/5151647d-0273-40a3-aac6-8b9399ede18a" />
<img width="360" height="270" alt="cpu-03-0 10-avg" src="https://github.com/user-attachments/assets/a87bc150-a5eb-47cd-a4ba-83c2ec81edaf" />
<img width="360" height="270" alt="cuda-03-0 10-avg" src="https://github.com/user-attachments/assets/e3699a9f-dfd3-4dd3-b3c9-619296186d43" />
<img width="360" height="270" alt="cpu-10-0 01-avg" src="https://github.com/user-attachments/assets/5ec8c32d-8e4d-4ced-a94e-1b816b1cb0f8" />
<img width="360" height="270" alt="cuda-10-0 01-avg" src="https://github.com/user-attachments/assets/168a3dfc-777a-4fb3-8023-1ac8c13985b8" />
<img width="360" height="270" alt="cpu-10-0 10-avg" src="https://github.com/user-attachments/assets/43a57fd6-1e01-4c5e-87a9-8ef604ef5fa0" />
<img width="360" height="270" alt="cuda-10-0 10-avg" src="https://github.com/user-attachments/assets/a7c7cc69-f273-493e-95b8-3ba2bb2e32da" />

### Peak time comparisons

<img width="360" height="270" alt="cpu-03-0 01-peak" src="https://github.com/user-attachments/assets/5bbbea3f-ef9b-490d-ab0d-ce551711d74f" />
<img width="360" height="270" alt="cuda-03-0 01-peak" src="https://github.com/user-attachments/assets/30b5ab9b-45cb-4057-b69f-bda6e76bd1dc" />
<img width="360" height="270" alt="cpu-03-0 10-peak" src="https://github.com/user-attachments/assets/db69c333-e5ac-4305-8a86-a26a8a9fe80d" />
<img width="360" height="270" alt="cuda-03-0 10-peak" src="https://github.com/user-attachments/assets/82549656-1f12-409e-8160-dd4c4c9d14f7" />
<img width="360" height="270" alt="cpu-10-0 01-peak" src="https://github.com/user-attachments/assets/d0be8ef1-535e-47bc-b773-b87fad625bf0" />
<img width="360" height="270" alt="cuda-10-0 01-peak" src="https://github.com/user-attachments/assets/e308e66e-ae30-400f-8ad2-015517f6e1af" />
<img width="360" height="270" alt="cpu-10-0 10-peak" src="https://github.com/user-attachments/assets/c9b5bf59-9cc2-465c-ad5d-d4e23bdd138a" />
<img width="360" height="270" alt="cuda-10-0 10-peak" src="https://github.com/user-attachments/assets/311354d4-b488-400c-a1dc-c85a21917aa9" />

### Full benchmark logs

[benchmark-before-change.txt](https://github.com/user-attachments/files/22978300/benchmark-before-change.txt)
[benchmark-after-change.txt](https://github.com/user-attachments/files/22978299/benchmark-after-change.txt)

Pull Request resolved: https://github.com/facebookresearch/pytorch3d/pull/2006

Reviewed By: shapovalov

Differential Revision: D85356394

Pulled By: bottler

fbshipit-source-id: 9b3ce5fc87bb73d4323cc5b4190fc38ae42f41b2
102 lines
3.3 KiB
C++
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once
#include <torch/extension.h>
#include <tuple>
#include "utils/pytorch3d_cutils.h"

// Compute indices of K neighbors in pointcloud p2 to points
// in pointcloud p1 which fall within a specified radius
//
// Args:
//    p1: FloatTensor of shape (N, P1, D) giving a batch of pointclouds each
//        containing P1 points of dimension D.
//    p2: FloatTensor of shape (N, P2, D) giving a batch of pointclouds each
//        containing P2 points of dimension D.
//    lengths1: LongTensor, shape (N,), giving actual length of each P1 cloud.
//    lengths2: LongTensor, shape (N,), giving actual length of each P2 cloud.
//    K: Integer giving the upper bound on the number of samples to take
//        within the radius
//    radius: the radius around each point within which the neighbors need to be
//        located
//    skip_points_outside_cube: If true, reduce multiplications of float values
//        by not explicitly calculating distances to points that fall outside the
//        D-cube with side length (2*radius) centered at each point in p1.
//
// Returns:
//    p1_neighbor_idx: LongTensor of shape (N, P1, K), where
//        p1_neighbor_idx[n, i, k] = j means that the kth
//        neighbor to p1[n, i] in the cloud p2[n] is p2[n, j].
//        This is padded with -1s both where a cloud in p2 has fewer than
//        K points and where a cloud in p1 has fewer than P1 points and
//        also if there are fewer than K points which satisfy the radius
//        threshold.
//
//    p1_neighbor_dists: FloatTensor of shape (N, P1, K) containing the squared
//        distance from each point p1[n, p, :] to its K neighbors
//        p2[n, p1_neighbor_idx[n, p, k], :].

// CPU implementation
std::tuple<at::Tensor, at::Tensor> BallQueryCpu(
    const at::Tensor& p1,
    const at::Tensor& p2,
    const at::Tensor& lengths1,
    const at::Tensor& lengths2,
    const int K,
    const float radius,
    const bool skip_points_outside_cube);

// CUDA implementation
std::tuple<at::Tensor, at::Tensor> BallQueryCuda(
    const at::Tensor& p1,
    const at::Tensor& p2,
    const at::Tensor& lengths1,
    const at::Tensor& lengths2,
    const int K,
    const float radius,
    const bool skip_points_outside_cube);

// Implementation which is exposed
// Note: the backward pass reuses the KNearestNeighborBackward kernel
inline std::tuple<at::Tensor, at::Tensor> BallQuery(
    const at::Tensor& p1,
    const at::Tensor& p2,
    const at::Tensor& lengths1,
    const at::Tensor& lengths2,
    int K,
    float radius,
    bool skip_points_outside_cube) {
  if (p1.is_cuda() || p2.is_cuda()) {
#ifdef WITH_CUDA
    CHECK_CUDA(p1);
    CHECK_CUDA(p2);
    return BallQueryCuda(
        p1.contiguous(),
        p2.contiguous(),
        lengths1.contiguous(),
        lengths2.contiguous(),
        K,
        radius,
        skip_points_outside_cube);
#else
    AT_ERROR("Not compiled with GPU support.");
#endif
  }
  CHECK_CPU(p1);
  CHECK_CPU(p2);
  return BallQueryCpu(
      p1.contiguous(),
      p2.contiguous(),
      lengths1.contiguous(),
      lengths2.contiguous(),
      K,
      radius,
      skip_points_outside_cube);
}
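
The header comment states that `p1_neighbor_idx` is padded with -1 wherever fewer than K neighbors are available. A caller consuming one row of that output might skip the padded slots as in the sketch below. It is a torch-free illustration only: the helper name `GatherNeighbors` and the flat row-major `std::vector<float>` layout for one `p2` cloud are assumptions made for this example, not part of the PyTorch3D API.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper, not part of PyTorch3D: given one row of K neighbor
// indices (padded with -1) and one p2 cloud stored row-major with dimension D,
// returns the coordinates of the valid neighbors only.
std::vector<std::vector<float>> GatherNeighbors(
    const std::vector<long>& idx_row,
    const std::vector<float>& p2,
    std::size_t D) {
  std::vector<std::vector<float>> neighbors;
  for (long j : idx_row) {
    if (j < 0) {
      continue; // -1 marks a padded slot with no neighbor
    }
    const std::size_t start = static_cast<std::size_t>(j) * D;
    assert(start + D <= p2.size()); // index must refer to a point in p2
    const float* first = p2.data() + start;
    neighbors.emplace_back(first, first + D);
  }
  return neighbors;
}
```

For example, with the row `{2, 0, -1, -1}` and three 2-D points in `p2`, the helper returns the coordinates of points 2 and 0 and ignores the two padded slots.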