Mirror of https://github.com/facebookresearch/pytorch3d.git (synced 2025-12-19 22:00:35 +08:00)
Implement K-Nearest Neighbors
Summary:
Implements K-Nearest Neighbors with C++ and CUDA versions.

KNN in CUDA is highly nontrivial. I've implemented a few different versions of the kernel, and we heuristically dispatch to different kernels based on the problem size. Some of the kernels rely on template specialization on either D or K, so we use template metaprogramming to compile specialized versions for ranges of D and K.

These kernels are up to 3x faster than our existing 1-nearest-neighbor kernels, so we should also consider swapping out `nn_points_idx` to use these kernels in the backend.

I've been working mostly on the CUDA kernels and haven't yet converged on the correct Python API. I still want to benchmark against FAISS to see how far away we are from their performance.

Reviewed By: bottler

Differential Revision: D19729286

fbshipit-source-id: 608ffbb7030c21fe4008f330522f4890f0c3c21a
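For context, the benchmarks and tests below compare the new C++/CUDA kernels against a brute-force PyTorch baseline (`_knn_points_idx_naive`). A minimal sketch of what such a baseline computes is shown here; it is an illustrative assumption, not the committed implementation, and the function name `naive_knn_points_idx` is hypothetical. The returned `(idx, dists)` order follows the unpacking used in tests/test_knn.py.

    import torch

    def naive_knn_points_idx(x, y, K):
        # Hypothetical brute-force KNN sketch (not the committed kernel).
        # x: (N, P1, D) query points; y: (N, P2, D) reference points.
        # Pairwise squared distances, shape (N, P1, P2).
        dists2 = ((x.unsqueeze(2) - y.unsqueeze(1)) ** 2).sum(dim=3)
        # K smallest distances per query point; largest=False returns them
        # in ascending order of distance.
        dists, idx = dists2.topk(K, dim=2, largest=False)
        return idx, dists

The (N, P1, P2) distance matrix is the obvious memory bottleneck of this brute-force approach, which is presumably why the naive path in the benchmarks below is only run for P <= 4096.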
Committed by: Facebook GitHub Bot
Parent: 02d4968ee0
Commit: 870290df34
tests/bm_knn.py (new file, 174 lines)
@@ -0,0 +1,174 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

from itertools import product

import torch
from fvcore.common.benchmark import benchmark

from pytorch3d import _C
from pytorch3d.ops.knn import _knn_points_idx_naive


def bm_knn() -> None:
    """ Entry point for the benchmark """
    benchmark_knn_cpu()
    benchmark_knn_cuda_vs_naive()
    benchmark_knn_cuda_versions()


def benchmark_knn_cuda_versions() -> None:
    # Compare our different KNN implementations,
    # and also compare against our existing 1-NN
    Ns = [1, 2]
    Ps = [4096, 16384]
    Ds = [3]
    Ks = [1, 4, 16, 64]
    versions = [0, 1, 2, 3]
    knn_kwargs, nn_kwargs = [], []
    for N, P, D, K, version in product(Ns, Ps, Ds, Ks, versions):
        if version == 2 and K > 32:
            continue
        if version == 3 and K > 4:
            continue
        knn_kwargs.append({'N': N, 'D': D, 'P': P, 'K': K, 'v': version})
    for N, P, D in product(Ns, Ps, Ds):
        nn_kwargs.append({'N': N, 'D': D, 'P': P})
    benchmark(
        knn_cuda_with_init,
        'KNN_CUDA_VERSIONS',
        knn_kwargs,
        warmup_iters=1,
    )
    benchmark(
        nn_cuda_with_init,
        'NN_CUDA',
        nn_kwargs,
        warmup_iters=1,
    )


def benchmark_knn_cuda_vs_naive() -> None:
    # Compare against naive pytorch version of KNN
    Ns = [1, 2, 4]
    Ps = [1024, 4096, 16384, 65536]
    Ds = [3]
    Ks = [1, 2, 4, 8, 16]
    knn_kwargs, naive_kwargs = [], []
    for N, P, D, K in product(Ns, Ps, Ds, Ks):
        knn_kwargs.append({'N': N, 'D': D, 'P': P, 'K': K})
        if P <= 4096:
            naive_kwargs.append({'N': N, 'D': D, 'P': P, 'K': K})
    benchmark(
        knn_python_cuda_with_init,
        'KNN_CUDA_PYTHON',
        naive_kwargs,
        warmup_iters=1,
    )
    benchmark(
        knn_cuda_with_init,
        'KNN_CUDA',
        knn_kwargs,
        warmup_iters=1,
    )


def benchmark_knn_cpu() -> None:
    Ns = [1, 2]
    Ps = [256, 512]
    Ds = [3]
    Ks = [1, 2, 4]
    knn_kwargs, nn_kwargs = [], []
    for N, P, D, K in product(Ns, Ps, Ds, Ks):
        knn_kwargs.append({'N': N, 'D': D, 'P': P, 'K': K})
    for N, P, D in product(Ns, Ps, Ds):
        nn_kwargs.append({'N': N, 'D': D, 'P': P})
    benchmark(
        knn_python_cpu_with_init,
        'KNN_CPU_PYTHON',
        knn_kwargs,
        warmup_iters=1,
    )
    benchmark(
        knn_cpu_with_init,
        'KNN_CPU_CPP',
        knn_kwargs,
        warmup_iters=1,
    )
    benchmark(
        nn_cpu_with_init,
        'NN_CPU_CPP',
        nn_kwargs,
        warmup_iters=1,
    )


def knn_cuda_with_init(N, D, P, K, v=-1):
    device = torch.device('cuda:0')
    x = torch.randn(N, P, D, device=device)
    y = torch.randn(N, P, D, device=device)
    torch.cuda.synchronize()

    def knn():
        _C.knn_points_idx(x, y, K, v)
        torch.cuda.synchronize()

    return knn


def knn_cpu_with_init(N, D, P, K):
    device = torch.device('cpu')
    x = torch.randn(N, P, D, device=device)
    y = torch.randn(N, P, D, device=device)

    def knn():
        _C.knn_points_idx(x, y, K, 0)

    return knn


def knn_python_cuda_with_init(N, D, P, K):
    device = torch.device('cuda')
    x = torch.randn(N, P, D, device=device)
    y = torch.randn(N, P, D, device=device)
    torch.cuda.synchronize()

    def knn():
        _knn_points_idx_naive(x, y, K)
        torch.cuda.synchronize()

    return knn


def knn_python_cpu_with_init(N, D, P, K):
    device = torch.device('cpu')
    x = torch.randn(N, P, D, device=device)
    y = torch.randn(N, P, D, device=device)

    def knn():
        _knn_points_idx_naive(x, y, K)

    return knn


def nn_cuda_with_init(N, D, P):
    device = torch.device('cuda')
    x = torch.randn(N, P, D, device=device)
    y = torch.randn(N, P, D, device=device)
    torch.cuda.synchronize()

    def knn():
        _C.nn_points_idx(x, y)
        torch.cuda.synchronize()

    return knn


def nn_cpu_with_init(N, D, P):
    device = torch.device('cpu')
    x = torch.randn(N, P, D, device=device)
    y = torch.randn(N, P, D, device=device)

    def knn():
        _C.nn_points_idx(x, y)

    return knn
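Note that the benchmark above skips version 2 whenever K > 32 and version 3 whenever K > 4, so those kernel versions appear to be applicable only for bounded K. Below is a hedged sketch of a size-based dispatch rule consistent with those limits; the actual heuristic lives in the C++/CUDA extension and is not part of this file, and the helper name `choose_knn_version` is hypothetical.

    def choose_knn_version(D, K):
        # Illustrative guess at a dispatch rule, based only on the K limits
        # visible in bm_knn.py above; not the committed C++/CUDA logic.
        if K <= 4:
            return 3  # kernel specialized for very small K
        if K <= 32:
            return 2  # kernel specialized for moderate K
        return 1      # more general kernel for larger K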
tests/test_knn.py (new file, 65 lines)
@@ -0,0 +1,65 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

import unittest
from itertools import product
import torch

from pytorch3d.ops.knn import _knn_points_idx_naive, knn_points_idx


class TestKNN(unittest.TestCase):
    def _check_knn_result(self, out1, out2, sorted):
        # When sorted=True, points should be sorted by distance and should
        # match between implementations. When sorted=False we only want to
        # check that we got the same set of indices, so we sort the indices by
        # index value.
        idx1, dist1 = out1
        idx2, dist2 = out2
        if not sorted:
            idx1 = idx1.sort(dim=2).values
            idx2 = idx2.sort(dim=2).values
            dist1 = dist1.sort(dim=2).values
            dist2 = dist2.sort(dim=2).values
        if not torch.all(idx1 == idx2):
            print(idx1)
            print(idx2)
        self.assertTrue(torch.all(idx1 == idx2))
        self.assertTrue(torch.allclose(dist1, dist2))

    def test_knn_vs_python_cpu(self):
        """ Test CPU output vs PyTorch implementation """
        device = torch.device('cpu')
        Ns = [1, 4]
        Ds = [2, 3]
        P1s = [1, 10, 101]
        P2s = [10, 101]
        Ks = [1, 3, 10]
        sorts = [True, False]
        factors = [Ns, Ds, P1s, P2s, Ks, sorts]
        for N, D, P1, P2, K, sort in product(*factors):
            x = torch.randn(N, P1, D, device=device)
            y = torch.randn(N, P2, D, device=device)
            out1 = _knn_points_idx_naive(x, y, K, sort)
            out2 = knn_points_idx(x, y, K, sort)
            self._check_knn_result(out1, out2, sort)

    def test_knn_vs_python_cuda(self):
        """ Test CUDA output vs PyTorch implementation """
        device = torch.device('cuda')
        Ns = [1, 4]
        Ds = [2, 3, 8]
        P1s = [1, 8, 64, 128, 1001]
        P2s = [32, 128, 513]
        Ks = [1, 3, 10]
        sorts = [True, False]
        versions = [0, 1, 2, 3]
        factors = [Ns, Ds, P1s, P2s, Ks, sorts]
        for N, D, P1, P2, K, sort in product(*factors):
            x = torch.randn(N, P1, D, device=device)
            y = torch.randn(N, P2, D, device=device)
            out1 = _knn_points_idx_naive(x, y, K, sorted=sort)
            for version in versions:
                if version == 3 and K > 4:
                    continue
                out2 = knn_points_idx(x, y, K, sort, version)
                self._check_knn_result(out1, out2, sort)
@@ -1,6 +1,7 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

import unittest
from itertools import product
import torch

from pytorch3d import _C