Mirror of https://github.com/facebookresearch/pytorch3d.git
Ball Query
Summary: Implementation of ball query from PointNet++. This function is similar to KNN (find the neighbors in p2 for all points in p1). These are the key differences:
- It returns the **first** K neighbors within a specified radius, as opposed to the **closest** K neighbors.
- Because not all points in p2 need to be considered to fill the K slots, the algorithm is much faster than KNN when p2 has a large number of points.
- The neighbors are not sorted.
- Due to the radius threshold it is not guaranteed that there will be K neighbors, even if there are more than K points in p2.
- The padding value for `idx` is -1 instead of 0.

# Note:
- Some of the code is very similar to KNN, so it could be possible to modify the KNN forward kernels to support ball query.
- Some users might want to combine kNN with ball query. For this we could provide a wrapper function around the current `knn_points` which applies the radius threshold afterwards as an alternative; it could be called `ball_query_knn`.

Reviewed By: jcjohnson

Differential Revision: D30261362

fbshipit-source-id: 66b6a7e0114beff7164daf7eba21546ff41ec450
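For illustration, a minimal usage sketch of the new op based on the tests added in this commit (the tensor sizes here are arbitrary):

```python
import torch

from pytorch3d.ops.ball_query import ball_query

# Two batched point clouds: for each point in p1, search within p2.
p1 = torch.randn(2, 128, 3)  # (N, P1, D)
p2 = torch.randn(2, 512, 3)  # (N, P2, D)

# Return the *first* (up to) K points of p2 within `radius` of each p1 point.
out = ball_query(p1, p2, K=16, radius=0.2)
dists, idx = out.dists, out.idx  # both (N, P1, K); empty slots in idx are padded with -1
```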
Committed by: Facebook GitHub Bot
Parent: e5c58a8a8b
Commit: 103da63393
tests/bm_ball_query.py (new file, 40 lines)
@@ -0,0 +1,40 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from itertools import product

from fvcore.common.benchmark import benchmark
from test_ball_query import TestBallQuery


def bm_ball_query() -> None:

    backends = ["cpu", "cuda:0"]

    kwargs_list = []
    Ns = [32]
    P1s = [256]
    P2s = [128, 512]
    Ds = [3, 10]
    Ks = [3, 24, 100]
    Rs = [0.1, 0.2, 5]
    test_cases = product(Ns, P1s, P2s, Ds, Ks, Rs, backends)
    for case in test_cases:
        N, P1, P2, D, K, R, b = case
        kwargs_list.append(
            {"N": N, "P1": P1, "P2": P2, "D": D, "K": K, "radius": R, "device": b}
        )

    benchmark(
        TestBallQuery.ball_query_square, "BALLQUERY_SQUARE", kwargs_list, warmup_iters=1
    )
    benchmark(
        TestBallQuery.ball_query_ragged, "BALLQUERY_RAGGED", kwargs_list, warmup_iters=1
    )


if __name__ == "__main__":
    bm_ball_query()
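As a hedged illustration of how the benchmark harness uses these cases: each kwargs entry built above is passed to a `TestBallQuery` static method, which returns a closure that `benchmark` then times. Something like the following (a single hypothetical case; it assumes a CUDA device is available):

```python
from test_ball_query import TestBallQuery

# One case, equivalent to a single kwargs_list entry above.
fn = TestBallQuery.ball_query_square(
    N=32, P1=256, P2=128, D=3, K=24, radius=0.2, device="cuda:0"
)
fn()  # runs one forward + backward pass of ball_query; `benchmark` times repeated calls like this
```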
tests/test_ball_query.py (new file, 230 lines)
@@ -0,0 +1,230 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest
from itertools import product

import torch
from common_testing import TestCaseMixin, get_random_cuda_device
from pytorch3d.ops import sample_points_from_meshes
from pytorch3d.ops.ball_query import ball_query
from pytorch3d.ops.knn import _KNN
from pytorch3d.utils import ico_sphere


class TestBallQuery(TestCaseMixin, unittest.TestCase):
    def setUp(self) -> None:
        super().setUp()
        torch.manual_seed(1)

    @staticmethod
    def _ball_query_naive(
        p1, p2, lengths1, lengths2, K: int, radius: float
    ) -> _KNN:
        """
        Naive PyTorch implementation of ball query.
        """
        N, P1, D = p1.shape
        _N, P2, _D = p2.shape

        assert N == _N and D == _D

        if lengths1 is None:
            lengths1 = torch.full((N,), P1, dtype=torch.int64, device=p1.device)
        if lengths2 is None:
            lengths2 = torch.full((N,), P2, dtype=torch.int64, device=p1.device)

        radius2 = radius * radius
        dists = torch.zeros((N, P1, K), dtype=torch.float32, device=p1.device)
        idx = torch.full((N, P1, K), fill_value=-1, dtype=torch.int64, device=p1.device)

        # Iterate through the batches
        for n in range(N):
            num1 = lengths1[n].item()
            num2 = lengths2[n].item()

            # Iterate through the points in p1
            for i in range(num1):
                # Iterate through the points in p2
                count = 0
                for j in range(num2):
                    dist = p2[n, j] - p1[n, i]
                    dist2 = (dist * dist).sum()
                    if dist2 < radius2 and count < K:
                        dists[n, i, count] = dist2
                        idx[n, i, count] = j
                        count += 1

        return _KNN(dists=dists, idx=idx, knn=None)

    def _ball_query_vs_python_square_helper(self, device):
        Ns = [1, 4]
        Ds = [3, 5, 8]
        P1s = [8, 24]
        P2s = [8, 16, 32]
        Ks = [1, 5]
        Rs = [3, 5]
        factors = [Ns, Ds, P1s, P2s, Ks, Rs]
        for N, D, P1, P2, K, R in product(*factors):
            x = torch.randn(N, P1, D, device=device, requires_grad=True)
            x_cuda = x.clone().detach()
            x_cuda.requires_grad_(True)
            y = torch.randn(N, P2, D, device=device, requires_grad=True)
            y_cuda = y.clone().detach()
            y_cuda.requires_grad_(True)

            # forward
            out1 = self._ball_query_naive(
                x, y, lengths1=None, lengths2=None, K=K, radius=R
            )
            out2 = ball_query(x_cuda, y_cuda, K=K, radius=R)

            # Check dists
            self.assertClose(out1.dists, out2.dists)
            # Check idx
            self.assertTrue(torch.all(out1.idx == out2.idx))

            # backward
            grad_dist = torch.ones((N, P1, K), dtype=torch.float32, device=device)
            loss1 = (out1.dists * grad_dist).sum()
            loss1.backward()
            loss2 = (out2.dists * grad_dist).sum()
            loss2.backward()

            self.assertClose(x_cuda.grad, x.grad, atol=5e-6)
            self.assertClose(y_cuda.grad, y.grad, atol=5e-6)

    def test_ball_query_vs_python_square_cpu(self):
        device = torch.device("cpu")
        self._ball_query_vs_python_square_helper(device)

    def test_ball_query_vs_python_square_cuda(self):
        device = get_random_cuda_device()
        self._ball_query_vs_python_square_helper(device)

    def _ball_query_vs_python_ragged_helper(self, device):
        Ns = [1, 4]
        Ds = [3, 5, 8]
        P1s = [8, 24]
        P2s = [8, 16, 32]
        Ks = [2, 3, 10]
        Rs = [1.4, 5]  # radius
        factors = [Ns, Ds, P1s, P2s, Ks, Rs]
        for N, D, P1, P2, K, R in product(*factors):
            x = torch.rand((N, P1, D), device=device, requires_grad=True)
            y = torch.rand((N, P2, D), device=device, requires_grad=True)
            lengths1 = torch.randint(low=1, high=P1, size=(N,), device=device)
            lengths2 = torch.randint(low=1, high=P2, size=(N,), device=device)

            x_csrc = x.clone().detach()
            x_csrc.requires_grad_(True)
            y_csrc = y.clone().detach()
            y_csrc.requires_grad_(True)

            # forward
            out1 = self._ball_query_naive(
                x, y, lengths1=lengths1, lengths2=lengths2, K=K, radius=R
            )
            out2 = ball_query(
                x_csrc,
                y_csrc,
                lengths1=lengths1,
                lengths2=lengths2,
                K=K,
                radius=R,
            )

            self.assertClose(out1.idx, out2.idx)
            self.assertClose(out1.dists, out2.dists)

            # backward
            grad_dist = torch.ones((N, P1, K), dtype=torch.float32, device=device)
            loss1 = (out1.dists * grad_dist).sum()
            loss1.backward()
            loss2 = (out2.dists * grad_dist).sum()
            loss2.backward()

            self.assertClose(x_csrc.grad, x.grad, atol=5e-6)
            self.assertClose(y_csrc.grad, y.grad, atol=5e-6)

    def test_ball_query_vs_python_ragged_cpu(self):
        device = torch.device("cpu")
        self._ball_query_vs_python_ragged_helper(device)

    def test_ball_query_vs_python_ragged_cuda(self):
        device = get_random_cuda_device()
        self._ball_query_vs_python_ragged_helper(device)

    def test_ball_query_output_simple(self):
        device = get_random_cuda_device()
        N, P1, P2, K = 5, 8, 16, 4
        sphere = ico_sphere(level=2, device=device).extend(N)
        points_1 = sample_points_from_meshes(sphere, P1)
        points_2 = sample_points_from_meshes(sphere, P2) * 5.0
        radius = 6.0

        naive_out = self._ball_query_naive(
            points_1, points_2, lengths1=None, lengths2=None, K=K, radius=radius
        )
        cuda_out = ball_query(points_1, points_2, K=K, radius=radius)

        # All points should have K sample neighbors as the radius is large
        # Zero is a valid index but can only be present once (i.e. no zero padding)
        naive_out_zeros = (naive_out.idx == 0).sum(dim=-1).max()
        cuda_out_zeros = (cuda_out.idx == 0).sum(dim=-1).max()
        self.assertTrue(naive_out_zeros == 0 or naive_out_zeros == 1)
        self.assertTrue(cuda_out_zeros == 0 or cuda_out_zeros == 1)

        # All points should now have zero sample neighbors as the radius is small
        radius = 0.5
        naive_out = self._ball_query_naive(
            points_1, points_2, lengths1=None, lengths2=None, K=K, radius=radius
        )
        cuda_out = ball_query(points_1, points_2, K=K, radius=radius)
        naive_out_allzeros = (naive_out.idx == -1).all()
        cuda_out_allzeros = (cuda_out.idx == -1).all()
        self.assertTrue(naive_out_allzeros)
        self.assertTrue(cuda_out_allzeros)

    @staticmethod
    def ball_query_square(
        N: int, P1: int, P2: int, D: int, K: int, radius: float, device: str
    ):
        device = torch.device(device)
        pts1 = torch.randn(N, P1, D, device=device, requires_grad=True)
        pts2 = torch.randn(N, P2, D, device=device, requires_grad=True)
        grad_dists = torch.randn(N, P1, K, device=device)
        torch.cuda.synchronize()

        def output():
            out = ball_query(pts1, pts2, K=K, radius=radius)
            loss = (out.dists * grad_dists).sum()
            loss.backward()
            torch.cuda.synchronize()

        return output

    @staticmethod
    def ball_query_ragged(
        N: int, P1: int, P2: int, D: int, K: int, radius: float, device: str
    ):
        device = torch.device(device)
        pts1 = torch.rand((N, P1, D), device=device, requires_grad=True)
        pts2 = torch.rand((N, P2, D), device=device, requires_grad=True)
        lengths1 = torch.randint(low=1, high=P1, size=(N,), device=device)
        lengths2 = torch.randint(low=1, high=P2, size=(N,), device=device)
        grad_dists = torch.randn(N, P1, K, device=device)
        torch.cuda.synchronize()

        def output():
            out = ball_query(
                pts1, pts2, lengths1=lengths1, lengths2=lengths2, K=K, radius=radius
            )
            loss = (out.dists * grad_dists).sum()
            loss.backward()
            torch.cuda.synchronize()

        return output
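Since empty neighbor slots are padded with -1 (as the simple-output test above checks), downstream code would typically mask them before using the indices; a small sketch under that assumption:

```python
import torch

from pytorch3d.ops.ball_query import ball_query

p1 = torch.randn(1, 64, 3)
p2 = torch.randn(1, 256, 3)
out = ball_query(p1, p2, K=8, radius=0.3)

valid = out.idx >= 0              # (N, P1, K); False where the slot is padding (-1)
num_found = valid.sum(dim=-1)     # neighbors actually found per query point
safe_idx = out.idx.clamp(min=0)   # make padded entries usable as gather indices (mask results with `valid`)
```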