mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-19 22:00:35 +08:00
heterogenous KNN
Summary: Interface and working implementation of ragged KNN. Benchmarks (which aren't ragged) haven't slowed. New benchmark shows that ragged is faster than non-ragged of the same shape. Reviewed By: jcjohnson Differential Revision: D20696507 fbshipit-source-id: 21b80f71343a3475c8d3ee0ce2680f92f0fae4de
This commit is contained in:
committed by
Facebook GitHub Bot
parent
29b9c44c0a
commit
01b5f7b228
@@ -13,6 +13,7 @@ def bm_knn() -> None:
|
||||
benchmark_knn_cpu()
|
||||
benchmark_knn_cuda_vs_naive()
|
||||
benchmark_knn_cuda_versions()
|
||||
benchmark_knn_cuda_versions_ragged()
|
||||
|
||||
|
||||
def benchmark_knn_cuda_versions() -> None:
|
||||
@@ -36,6 +37,25 @@ def benchmark_knn_cuda_versions() -> None:
|
||||
benchmark(nn_cuda_with_init, "NN_CUDA", nn_kwargs, warmup_iters=1)
|
||||
|
||||
|
||||
def benchmark_knn_cuda_versions_ragged() -> None:
|
||||
# Compare our different KNN implementations,
|
||||
# and also compare against our existing 1-NN
|
||||
Ns = [8]
|
||||
Ps = [4096, 16384]
|
||||
Ds = [3]
|
||||
Ks = [1, 4, 16, 64]
|
||||
versions = [0, 1, 2, 3]
|
||||
knn_kwargs = []
|
||||
for N, P, D, K, version in product(Ns, Ps, Ds, Ks, versions):
|
||||
if version == 2 and K > 32:
|
||||
continue
|
||||
if version == 3 and K > 4:
|
||||
continue
|
||||
knn_kwargs.append({"N": N, "D": D, "P": P, "K": K, "v": version})
|
||||
benchmark(knn_cuda_with_init, "KNN_CUDA_COMPARISON", knn_kwargs, warmup_iters=1)
|
||||
benchmark(knn_cuda_ragged, "KNN_CUDA_RAGGED", knn_kwargs, warmup_iters=1)
|
||||
|
||||
|
||||
def benchmark_knn_cuda_vs_naive() -> None:
|
||||
# Compare against naive pytorch version of KNN
|
||||
Ns = [1, 2, 4]
|
||||
@@ -72,10 +92,27 @@ def knn_cuda_with_init(N, D, P, K, v=-1):
|
||||
device = torch.device("cuda:0")
|
||||
x = torch.randn(N, P, D, device=device)
|
||||
y = torch.randn(N, P, D, device=device)
|
||||
lengths = torch.full((N,), P, dtype=torch.int64, device=device)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
|
||||
def knn():
|
||||
_C.knn_points_idx(x, y, K, v)
|
||||
_C.knn_points_idx(x, y, lengths, lengths, K, v)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
return knn
|
||||
|
||||
|
||||
def knn_cuda_ragged(N, D, P, K, v=-1):
|
||||
device = torch.device("cuda:0")
|
||||
x = torch.randn(N, P, D, device=device)
|
||||
y = torch.randn(N, P, D, device=device)
|
||||
lengths1 = torch.randint(P, size=(N,), device=device, dtype=torch.int64)
|
||||
lengths2 = torch.randint(P, size=(N,), device=device, dtype=torch.int64)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
def knn():
|
||||
_C.knn_points_idx(x, y, lengths1, lengths2, K, v)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
return knn
|
||||
@@ -85,9 +122,10 @@ def knn_cpu_with_init(N, D, P, K):
|
||||
device = torch.device("cpu")
|
||||
x = torch.randn(N, P, D, device=device)
|
||||
y = torch.randn(N, P, D, device=device)
|
||||
lengths = torch.full((N,), P, dtype=torch.int64, device=device)
|
||||
|
||||
def knn():
|
||||
_C.knn_points_idx(x, y, K, 0)
|
||||
_C.knn_points_idx(x, y, lengths, lengths, K, -1)
|
||||
|
||||
return knn
|
||||
|
||||
@@ -96,10 +134,12 @@ def knn_python_cuda_with_init(N, D, P, K):
|
||||
device = torch.device("cuda")
|
||||
x = torch.randn(N, P, D, device=device)
|
||||
y = torch.randn(N, P, D, device=device)
|
||||
lengths = torch.full((N,), P, dtype=torch.int64, device=device)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
|
||||
def knn():
|
||||
_knn_points_idx_naive(x, y, K)
|
||||
_knn_points_idx_naive(x, y, K=K, lengths1=lengths, lengths2=lengths)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
return knn
|
||||
@@ -109,9 +149,10 @@ def knn_python_cpu_with_init(N, D, P, K):
|
||||
device = torch.device("cpu")
|
||||
x = torch.randn(N, P, D, device=device)
|
||||
y = torch.randn(N, P, D, device=device)
|
||||
lengths = torch.full((N,), P, dtype=torch.int64, device=device)
|
||||
|
||||
def knn():
|
||||
_knn_points_idx_naive(x, y, K)
|
||||
_knn_points_idx_naive(x, y, K=K, lengths1=lengths, lengths2=lengths)
|
||||
|
||||
return knn
|
||||
|
||||
|
||||
@@ -8,6 +8,10 @@ from pytorch3d.ops.knn import _knn_points_idx_naive, knn_points_idx
|
||||
|
||||
|
||||
class TestKNN(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
super().setUp()
|
||||
torch.manual_seed(1)
|
||||
|
||||
def _check_knn_result(self, out1, out2, sorted):
|
||||
# When sorted=True, points should be sorted by distance and should
|
||||
# match between implementations. When sorted=False we we only want to
|
||||
@@ -26,7 +30,7 @@ class TestKNN(unittest.TestCase):
|
||||
self.assertTrue(torch.all(idx1 == idx2))
|
||||
self.assertTrue(torch.allclose(dist1, dist2))
|
||||
|
||||
def test_knn_vs_python_cpu(self):
|
||||
def test_knn_vs_python_cpu_square(self):
|
||||
""" Test CPU output vs PyTorch implementation """
|
||||
device = torch.device("cpu")
|
||||
Ns = [1, 4]
|
||||
@@ -37,13 +41,19 @@ class TestKNN(unittest.TestCase):
|
||||
sorts = [True, False]
|
||||
factors = [Ns, Ds, P1s, P2s, Ks, sorts]
|
||||
for N, D, P1, P2, K, sort in product(*factors):
|
||||
lengths1 = torch.full((N,), P1, dtype=torch.int64, device=device)
|
||||
lengths2 = torch.full((N,), P2, dtype=torch.int64, device=device)
|
||||
x = torch.randn(N, P1, D, device=device)
|
||||
y = torch.randn(N, P2, D, device=device)
|
||||
out1 = _knn_points_idx_naive(x, y, K, sort)
|
||||
out2 = knn_points_idx(x, y, K, sort)
|
||||
out1 = _knn_points_idx_naive(
|
||||
x, y, lengths1=lengths1, lengths2=lengths2, K=K
|
||||
)
|
||||
out2 = knn_points_idx(
|
||||
x, y, K=K, lengths1=lengths1, lengths2=lengths2, sorted=sort
|
||||
)
|
||||
self._check_knn_result(out1, out2, sort)
|
||||
|
||||
def test_knn_vs_python_cuda(self):
|
||||
def test_knn_vs_python_cuda_square(self):
|
||||
""" Test CUDA output vs PyTorch implementation """
|
||||
device = torch.device("cuda")
|
||||
Ns = [1, 4]
|
||||
@@ -57,9 +67,53 @@ class TestKNN(unittest.TestCase):
|
||||
for N, D, P1, P2, K, sort in product(*factors):
|
||||
x = torch.randn(N, P1, D, device=device)
|
||||
y = torch.randn(N, P2, D, device=device)
|
||||
out1 = _knn_points_idx_naive(x, y, K, sorted=sort)
|
||||
out1 = _knn_points_idx_naive(x, y, lengths1=None, lengths2=None, K=K)
|
||||
for version in versions:
|
||||
if version == 3 and K > 4:
|
||||
continue
|
||||
out2 = knn_points_idx(x, y, K, sort, version)
|
||||
out2 = knn_points_idx(x, y, K=K, sorted=sort, version=version)
|
||||
self._check_knn_result(out1, out2, sort)
|
||||
|
||||
def test_knn_vs_python_cpu_ragged(self):
|
||||
device = torch.device("cpu")
|
||||
lengths1 = torch.tensor([10, 100, 10, 100], device=device, dtype=torch.int64)
|
||||
lengths2 = torch.tensor([10, 10, 100, 100], device=device, dtype=torch.int64)
|
||||
N = 4
|
||||
D = 3
|
||||
Ks = [1, 9, 10, 11, 101]
|
||||
sorts = [False, True]
|
||||
factors = [Ks, sorts]
|
||||
for K, sort in product(*factors):
|
||||
x = torch.randn(N, lengths1.max(), D, device=device)
|
||||
y = torch.randn(N, lengths2.max(), D, device=device)
|
||||
out1 = _knn_points_idx_naive(
|
||||
x, y, lengths1=lengths1, lengths2=lengths2, K=K
|
||||
)
|
||||
out2 = knn_points_idx(
|
||||
x, y, lengths1=lengths1, lengths2=lengths2, K=K, sorted=sort
|
||||
)
|
||||
self._check_knn_result(out1, out2, sort)
|
||||
|
||||
def test_knn_vs_python_cuda_ragged(self):
|
||||
device = torch.device("cuda")
|
||||
lengths1 = torch.tensor([10, 100, 10, 100], device=device, dtype=torch.int64)
|
||||
lengths2 = torch.tensor([10, 10, 100, 100], device=device, dtype=torch.int64)
|
||||
N = 4
|
||||
D = 3
|
||||
Ks = [1, 9, 10, 11, 101]
|
||||
sorts = [True, False]
|
||||
versions = [0, 1, 2, 3]
|
||||
factors = [Ks, sorts]
|
||||
for K, sort in product(*factors):
|
||||
x = torch.randn(N, lengths1.max(), D, device=device)
|
||||
y = torch.randn(N, lengths2.max(), D, device=device)
|
||||
out1 = _knn_points_idx_naive(
|
||||
x, y, lengths1=lengths1, lengths2=lengths2, K=K
|
||||
)
|
||||
for version in versions:
|
||||
if version == 3 and K > 4:
|
||||
continue
|
||||
out2 = knn_points_idx(
|
||||
x, y, lengths1=lengths1, lengths2=lengths2, K=K, sorted=sort
|
||||
)
|
||||
self._check_knn_result(out1, out2, sort)
|
||||
|
||||
Reference in New Issue
Block a user