mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-20 22:30:35 +08:00
Summary: ### Overview The current C++ code for `pytorch3d.ops.ball_query()` performs floating point multiplication for every coordinate of every pair of points (up until the maximum number of neighbor points is reached). This PR modifies the code (for both CPU and CUDA versions) to implement the idea presented [here](https://stackoverflow.com/a/3939525): a `D`-cube around the `D`-ball is first constructed, and any point pairs falling outside the cube are skipped, without explicitly computing the squared distances. This change is especially useful when the dimension `D` and the number of points `P2` are large and the radius is much smaller than the overall extent of the space occupied by the point clouds; as much as **~2.5x speedup** (CPU case; ~1.8x speedup in CUDA case) is observed when `D = 10` and `radius = 0.01`. In all benchmark cases, points were distributed uniformly at random inside a unit `D`-cube. The benchmark code used was different from `tests/benchmarks/bm_ball_query.py` (only the forward part was benchmarked, and larger input sizes were used) and is stored in `tests/benchmarks/bm_ball_query_large.py`. 
### Average time comparisons <img width="360" height="270" alt="cpu-03-0 01-avg" src="https://github.com/user-attachments/assets/6cc79893-7921-44af-9366-1766c3caf142" /> <img width="360" height="270" alt="cuda-03-0 01-avg" src="https://github.com/user-attachments/assets/5151647d-0273-40a3-aac6-8b9399ede18a" /> <img width="360" height="270" alt="cpu-03-0 10-avg" src="https://github.com/user-attachments/assets/a87bc150-a5eb-47cd-a4ba-83c2ec81edaf" /> <img width="360" height="270" alt="cuda-03-0 10-avg" src="https://github.com/user-attachments/assets/e3699a9f-dfd3-4dd3-b3c9-619296186d43" /> <img width="360" height="270" alt="cpu-10-0 01-avg" src="https://github.com/user-attachments/assets/5ec8c32d-8e4d-4ced-a94e-1b816b1cb0f8" /> <img width="360" height="270" alt="cuda-10-0 01-avg" src="https://github.com/user-attachments/assets/168a3dfc-777a-4fb3-8023-1ac8c13985b8" /> <img width="360" height="270" alt="cpu-10-0 10-avg" src="https://github.com/user-attachments/assets/43a57fd6-1e01-4c5e-87a9-8ef604ef5fa0" /> <img width="360" height="270" alt="cuda-10-0 10-avg" src="https://github.com/user-attachments/assets/a7c7cc69-f273-493e-95b8-3ba2bb2e32da" /> ### Peak time comparisons <img width="360" height="270" alt="cpu-03-0 01-peak" src="https://github.com/user-attachments/assets/5bbbea3f-ef9b-490d-ab0d-ce551711d74f" /> <img width="360" height="270" alt="cuda-03-0 01-peak" src="https://github.com/user-attachments/assets/30b5ab9b-45cb-4057-b69f-bda6e76bd1dc" /> <img width="360" height="270" alt="cpu-03-0 10-peak" src="https://github.com/user-attachments/assets/db69c333-e5ac-4305-8a86-a26a8a9fe80d" /> <img width="360" height="270" alt="cuda-03-0 10-peak" src="https://github.com/user-attachments/assets/82549656-1f12-409e-8160-dd4c4c9d14f7" /> <img width="360" height="270" alt="cpu-10-0 01-peak" src="https://github.com/user-attachments/assets/d0be8ef1-535e-47bc-b773-b87fad625bf0" /> <img width="360" height="270" alt="cuda-10-0 01-peak" 
src="https://github.com/user-attachments/assets/e308e66e-ae30-400f-8ad2-015517f6e1af" /> <img width="360" height="270" alt="cpu-10-0 10-peak" src="https://github.com/user-attachments/assets/c9b5bf59-9cc2-465c-ad5d-d4e23bdd138a" /> <img width="360" height="270" alt="cuda-10-0 10-peak" src="https://github.com/user-attachments/assets/311354d4-b488-400c-a1dc-c85a21917aa9" /> ### Full benchmark logs [benchmark-before-change.txt](https://github.com/user-attachments/files/22978300/benchmark-before-change.txt) [benchmark-after-change.txt](https://github.com/user-attachments/files/22978299/benchmark-after-change.txt) Pull Request resolved: https://github.com/facebookresearch/pytorch3d/pull/2006 Reviewed By: shapovalov Differential Revision: D85356394 Pulled By: bottler fbshipit-source-id: 9b3ce5fc87bb73d4323cc5b4190fc38ae42f41b2
151 lines
6.0 KiB
Python
151 lines
6.0 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||
# All rights reserved.
|
||
#
|
||
# This source code is licensed under the BSD-style license found in the
|
||
# LICENSE file in the root directory of this source tree.
|
||
|
||
# pyre-unsafe
|
||
|
||
from typing import Union
|
||
|
||
import torch
|
||
from pytorch3d import _C
|
||
from torch.autograd import Function
|
||
from torch.autograd.function import once_differentiable
|
||
|
||
from .knn import _KNN
|
||
from .utils import masked_gather
|
||
|
||
|
||
class _ball_query(Function):
    """
    Autograd Function that exposes the compiled (C++/CUDA) Ball Query
    implementation to PyTorch.
    """

    @staticmethod
    def forward(ctx, p1, p2, lengths1, lengths2, K, radius, skip_points_outside_cube):
        """
        Run the compiled ball query.

        The arguments have the same definitions as in the ball_query function.
        Returns the pair (dists, idx); idx is marked as non-differentiable.
        """
        kernel_out = _C.ball_query(
            p1, p2, lengths1, lengths2, K, radius, skip_points_outside_cube
        )
        # The compiled op yields the indices first and the distances second.
        idx, dists = kernel_out
        ctx.save_for_backward(p1, p2, lengths1, lengths2, idx)
        ctx.mark_non_differentiable(idx)
        return dists, idx

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_dists, grad_idx):
        """
        Compute gradients with respect to p1 and p2.

        grad_idx is unused. None is returned for each of the five remaining
        forward inputs.
        """

        def _as_float32(tensor):
            # Leave float32 tensors untouched; cast everything else.
            if tensor.dtype != torch.float32:
                return tensor.float()
            return tensor

        p1, p2, lengths1, lengths2, idx = ctx.saved_tensors

        # TODO(gkioxari) Change cast to floats once we add support for doubles.
        grad_dists = _as_float32(grad_dists)
        p1 = _as_float32(p1)
        p2 = _as_float32(p2)

        # The KNN backward function is reused here; by default, norm is 2.
        norm = 2
        point_grads = _C.knn_points_backward(
            p1, p2, lengths1, lengths2, idx, norm, grad_dists
        )
        grad_p1, grad_p2 = point_grads

        # One entry per forward input: p1, p2, lengths1, lengths2, K, radius,
        # skip_points_outside_cube.
        return grad_p1, grad_p2, None, None, None, None, None
|
||
|
||
|
||
def ball_query(
    p1: torch.Tensor,
    p2: torch.Tensor,
    lengths1: Union[torch.Tensor, None] = None,
    lengths2: Union[torch.Tensor, None] = None,
    K: int = 500,
    radius: float = 0.2,
    return_nn: bool = True,
    skip_points_outside_cube: bool = False,
):
    """
    Find, for every query point in p1, up to K points of p2 lying within
    a given radius of it.

    Ball Query is an alternative to KNN. The neighbors returned are not
    necessarily the *nearest* ones to the point in p1; they are just the
    first K values in p2 which are within the specified radius.

    This method is faster than kNN when there are large numbers of points
    in p2 and the ordering of neighbors is not important compared to the
    distance being within the radius threshold.

    "Ball query's local neighborhood guarantees a fixed region scale thus
    making local region features more generalizable across space, which is
    preferred for tasks requiring local pattern recognition
    (e.g. semantic point labeling)" [1].

    [1] Charles R. Qi et al, "PointNet++: Deep Hierarchical Feature Learning
        on Point Sets in a Metric Space", NeurIPS 2017.

    Args:
        p1: Tensor of shape (N, P1, D) giving a batch of N point clouds, each
            containing up to P1 points of dimension D. These represent the
            centers of the ball queries.
        p2: Tensor of shape (N, P2, D) giving a batch of N point clouds, each
            containing up to P2 points of dimension D.
        lengths1: LongTensor of shape (N,) of values in the range [0, P1],
            giving the length of each pointcloud in p1. Or None to indicate
            that every cloud has length P1.
        lengths2: LongTensor of shape (N,) of values in the range [0, P2],
            giving the length of each pointcloud in p2. Or None to indicate
            that every cloud has length P2.
        K: Integer giving the upper bound on the number of samples to take
            within the radius.
        radius: The radius around each point within which the neighbors need
            to be located.
        return_nn: If set to True, also gathers the K neighbor points in p2
            for each point in p1.
        skip_points_outside_cube: If set to True, reduce multiplications of
            float values by not explicitly calculating distances to points
            that fall outside the D-cube with side length (2*radius) centered
            at each point in p1.

    Returns:
        A _KNN built from the following three values.

        dists: Tensor of shape (N, P1, K) giving the squared distances to
            the neighbors. This is padded with zeros both where a cloud in p2
            has fewer than K points and where a cloud in p1 has fewer than P1
            points, and also if there are fewer than K points which satisfy
            the radius threshold.

        idx: LongTensor of shape (N, P1, K) giving the indices of the
            K neighbors in p2 for points in p1.
            Concretely, if `idx[n, i, k] = j` then `p2[n, j]` is the k-th
            neighbor to `p1[n, i]` in `p2[n]`. This is padded with -1 both
            where a cloud in p2 has fewer than K points and where a cloud in
            p1 has fewer than P1 points, and also if there are fewer than K
            points which satisfy the radius threshold.

        knn: Tensor of shape (N, P1, K, D) giving the K neighbors in p2 for
            each point in p1. Concretely, `knn[n, i, k]` gives the k-th
            neighbor for `p1[n, i]`. Gathered only if `return_nn` is True;
            otherwise None.

    Raises:
        ValueError: if p1 and p2 differ in their batch dimension or in their
            point dimension.
    """
    # Validate the batch size first, then the per-point dimensionality.
    batch_matches = p1.shape[0] == p2.shape[0]
    if not batch_matches:
        raise ValueError("pts1 and pts2 must have the same batch dimension.")
    point_dim_matches = p1.shape[2] == p2.shape[2]
    if not point_dim_matches:
        raise ValueError("pts1 and pts2 must have the same point dimension.")

    p1 = p1.contiguous()
    p2 = p2.contiguous()
    N = p1.shape[0]
    P1 = p1.shape[1]
    P2 = p2.shape[1]

    # Missing lengths mean every cloud in the batch is full-sized.
    length_opts = {"dtype": torch.int64, "device": p1.device}
    if lengths1 is None:
        lengths1 = torch.full((N,), P1, **length_opts)
    if lengths2 is None:
        lengths2 = torch.full((N,), P2, **length_opts)

    query_out = _ball_query.apply(
        p1, p2, lengths1, lengths2, K, radius, skip_points_outside_cube
    )
    dists, idx = query_out

    # Gather the neighbors if needed
    if return_nn:
        points_nn = masked_gather(p2, idx)
    else:
        points_nn = None

    return _KNN(dists=dists, idx=idx, knn=points_nn)
|