Implement K-Nearest Neighbors

Summary:
Implements K-Nearest Neighbors with C++ and CUDA versions.

KNN in CUDA is highly nontrivial. I've implemented a few different versions of the kernel, and we heuristically dispatch to different kernels based on the problem size. Some of the kernels rely on template specialization on either D or K, so we use template metaprogramming to compile specialized versions for ranges of D and K.

These kernels are up to 3x faster than our existing 1-nearest-neighbor kernels, so we should also consider swapping out `nn_points_idx` to use these kernels in the backend.

I've been working mostly on the CUDA kernels, and haven't converged on the correct Python API.

I still want to benchmark against FAISS to see how far away we are from their performance.

Reviewed By: bottler

Differential Revision: D19729286

fbshipit-source-id: 608ffbb7030c21fe4008f330522f4890f0c3c21a
This commit is contained in:
Justin Johnson
2020-03-26 13:37:32 -07:00
committed by Facebook GitHub Bot
parent 02d4968ee0
commit 870290df34
12 changed files with 1328 additions and 1 deletions

174
tests/bm_knn.py Normal file
View File

@@ -0,0 +1,174 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from itertools import product
import torch
from fvcore.common.benchmark import benchmark
from pytorch3d import _C
from pytorch3d.ops.knn import _knn_points_idx_naive
def bm_knn() -> None:
""" Entry point for the benchmark """
benchmark_knn_cpu()
benchmark_knn_cuda_vs_naive()
benchmark_knn_cuda_versions()
def benchmark_knn_cuda_versions() -> None:
# Compare our different KNN implementations,
# and also compare against our existing 1-NN
Ns = [1, 2]
Ps = [4096, 16384]
Ds = [3]
Ks = [1, 4, 16, 64]
versions = [0, 1, 2, 3]
knn_kwargs, nn_kwargs = [], []
for N, P, D, K, version in product(Ns, Ps, Ds, Ks, versions):
if version == 2 and K > 32:
continue
if version == 3 and K > 4:
continue
knn_kwargs.append({'N': N, 'D': D, 'P': P, 'K': K, 'v': version})
for N, P, D in product(Ns, Ps, Ds):
nn_kwargs.append({'N': N, 'D': D, 'P': P})
benchmark(
knn_cuda_with_init,
'KNN_CUDA_VERSIONS',
knn_kwargs,
warmup_iters=1,
)
benchmark(
nn_cuda_with_init,
'NN_CUDA',
nn_kwargs,
warmup_iters=1,
)
def benchmark_knn_cuda_vs_naive() -> None:
# Compare against naive pytorch version of KNN
Ns = [1, 2, 4]
Ps = [1024, 4096, 16384, 65536]
Ds = [3]
Ks = [1, 2, 4, 8, 16]
knn_kwargs, naive_kwargs = [], []
for N, P, D, K in product(Ns, Ps, Ds, Ks):
knn_kwargs.append({'N': N, 'D': D, 'P': P, 'K': K})
if P <= 4096:
naive_kwargs.append({'N': N, 'D': D, 'P': P, 'K': K})
benchmark(
knn_python_cuda_with_init,
'KNN_CUDA_PYTHON',
naive_kwargs,
warmup_iters=1,
)
benchmark(
knn_cuda_with_init,
'KNN_CUDA',
knn_kwargs,
warmup_iters=1,
)
def benchmark_knn_cpu() -> None:
Ns = [1, 2]
Ps = [256, 512]
Ds = [3]
Ks = [1, 2, 4]
knn_kwargs, nn_kwargs = [], []
for N, P, D, K in product(Ns, Ps, Ds, Ks):
knn_kwargs.append({'N': N, 'D': D, 'P': P, 'K': K})
for N, P, D in product(Ns, Ps, Ds):
nn_kwargs.append({'N': N, 'D': D, 'P': P})
benchmark(
knn_python_cpu_with_init,
'KNN_CPU_PYTHON',
knn_kwargs,
warmup_iters=1,
)
benchmark(
knn_cpu_with_init,
'KNN_CPU_CPP',
knn_kwargs,
warmup_iters=1,
)
benchmark(
nn_cpu_with_init,
'NN_CPU_CPP',
nn_kwargs,
warmup_iters=1,
)
def knn_cuda_with_init(N, D, P, K, v=-1):
device = torch.device('cuda:0')
x = torch.randn(N, P, D, device=device)
y = torch.randn(N, P, D, device=device)
torch.cuda.synchronize()
def knn():
_C.knn_points_idx(x, y, K, v)
torch.cuda.synchronize()
return knn
def knn_cpu_with_init(N, D, P, K):
device = torch.device('cpu')
x = torch.randn(N, P, D, device=device)
y = torch.randn(N, P, D, device=device)
def knn():
_C.knn_points_idx(x, y, K, 0)
return knn
def knn_python_cuda_with_init(N, D, P, K):
device = torch.device('cuda')
x = torch.randn(N, P, D, device=device)
y = torch.randn(N, P, D, device=device)
torch.cuda.synchronize()
def knn():
_knn_points_idx_naive(x, y, K)
torch.cuda.synchronize()
return knn
def knn_python_cpu_with_init(N, D, P, K):
device = torch.device('cpu')
x = torch.randn(N, P, D, device=device)
y = torch.randn(N, P, D, device=device)
def knn():
_knn_points_idx_naive(x, y, K)
return knn
def nn_cuda_with_init(N, D, P):
device = torch.device('cuda')
x = torch.randn(N, P, D, device=device)
y = torch.randn(N, P, D, device=device)
torch.cuda.synchronize()
def knn():
_C.nn_points_idx(x, y)
torch.cuda.synchronize()
return knn
def nn_cpu_with_init(N, D, P):
device = torch.device('cpu')
x = torch.randn(N, P, D, device=device)
y = torch.randn(N, P, D, device=device)
def knn():
_C.nn_points_idx(x, y)
return knn

65
tests/test_knn.py Normal file
View File

@@ -0,0 +1,65 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import unittest
from itertools import product
import torch
from pytorch3d.ops.knn import _knn_points_idx_naive, knn_points_idx
class TestKNN(unittest.TestCase):
def _check_knn_result(self, out1, out2, sorted):
# When sorted=True, points should be sorted by distance and should
# match between implementations. When sorted=False we we only want to
# check that we got the same set of indices, so we sort the indices by
# index value.
idx1, dist1 = out1
idx2, dist2 = out2
if not sorted:
idx1 = idx1.sort(dim=2).values
idx2 = idx2.sort(dim=2).values
dist1 = dist1.sort(dim=2).values
dist2 = dist2.sort(dim=2).values
if not torch.all(idx1 == idx2):
print(idx1)
print(idx2)
self.assertTrue(torch.all(idx1 == idx2))
self.assertTrue(torch.allclose(dist1, dist2))
def test_knn_vs_python_cpu(self):
""" Test CPU output vs PyTorch implementation """
device = torch.device('cpu')
Ns = [1, 4]
Ds = [2, 3]
P1s = [1, 10, 101]
P2s = [10, 101]
Ks = [1, 3, 10]
sorts = [True, False]
factors = [Ns, Ds, P1s, P2s, Ks, sorts]
for N, D, P1, P2, K, sort in product(*factors):
x = torch.randn(N, P1, D, device=device)
y = torch.randn(N, P2, D, device=device)
out1 = _knn_points_idx_naive(x, y, K, sort)
out2 = knn_points_idx(x, y, K, sort)
self._check_knn_result(out1, out2, sort)
def test_knn_vs_python_cuda(self):
""" Test CUDA output vs PyTorch implementation """
device = torch.device('cuda')
Ns = [1, 4]
Ds = [2, 3, 8]
P1s = [1, 8, 64, 128, 1001]
P2s = [32, 128, 513]
Ks = [1, 3, 10]
sorts = [True, False]
versions = [0, 1, 2, 3]
factors = [Ns, Ds, P1s, P2s, Ks, sorts]
for N, D, P1, P2, K, sort in product(*factors):
x = torch.randn(N, P1, D, device=device)
y = torch.randn(N, P2, D, device=device)
out1 = _knn_points_idx_naive(x, y, K, sorted=sort)
for version in versions:
if version == 3 and K > 4:
continue
out2 = knn_points_idx(x, y, K, sort, version)
self._check_knn_result(out1, out2, sort)

View File

@@ -1,6 +1,7 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import unittest
from itertools import product
import torch
from pytorch3d import _C