pytorch3d/tests/test_nearest_neighbor_points.py
Justin Johnson 870290df34 Implement K-Nearest Neighbors
Summary:
Implements K-Nearest Neighbors with C++ and CUDA versions.

KNN in CUDA is highly nontrivial. I've implemented a few different versions of the kernel, and we heuristically dispatch to different kernels based on the problem size. Some of the kernels rely on template specialization on either D (the point dimension) or K (the number of neighbors), so we use template metaprogramming to compile specialized versions for ranges of D and K.

These kernels are up to 3x faster than our existing 1-nearest-neighbor kernels, so we should also consider swapping out `nn_points_idx` to use these kernels in the backend.

I've been working mostly on the CUDA kernels, and haven't converged on the correct Python API.

I still want to benchmark against FAISS to see how far away we are from their performance.

Reviewed By: bottler

Differential Revision: D19729286

fbshipit-source-id: 608ffbb7030c21fe4008f330522f4890f0c3c21a
2020-03-26 13:40:26 -07:00

95 lines
2.8 KiB
Python

# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import unittest
from itertools import product
import torch
from pytorch3d import _C
class TestNearestNeighborPoints(unittest.TestCase):
    """
    Tests for ``_C.nn_points_idx`` against a pure PyTorch reference, plus
    ``bm_*`` factories that build random inputs and return zero-argument
    closures which call the op once.
    """

    @staticmethod
    def nn_points_idx_naive(x, y):
        """
        PyTorch implementation of nn_points_idx function.

        For each point ``x[n, i]`` find the index of the point in ``y[n]``
        with the smallest squared Euclidean distance.

        Args:
            x: Tensor of shape (N, P1, D).
            y: Tensor of shape (N, P2, D), with the same N and D as ``x``.

        Returns:
            Tensor of shape (N, P1) holding indices into the P2 axis of ``y``.
        """
        N, P1, D = x.shape
        _N, P2, _D = y.shape
        assert N == _N and D == _D
        # (N, P1, 1, D) - (N, 1, P2, D) broadcasts to (N, P1, P2, D): every
        # pairwise difference is materialized, which is fine for the small
        # sizes used in these tests.
        diffs = x.view(N, P1, 1, D) - y.view(N, 1, P2, D)
        dists2 = (diffs * diffs).sum(3)
        idx = dists2.argmin(2)
        return idx

    def _test_nn_helper(self, device):
        """
        Check ``_C.nn_points_idx`` against ``nn_points_idx_naive`` on random
        inputs created on ``device``, over a grid of (D, N, P1, P2).
        """
        grid = product([3, 4], [1, 4], [1, 8, 64, 128], [32, 128])
        for D, N, P1, P2 in grid:
            # subTest reports which combination failed and keeps going.
            with self.subTest(D=D, N=N, P1=P1, P2=P2):
                x = torch.randn(N, P1, D, device=device)
                y = torch.randn(N, P2, D, device=device)
                # _C.nn_points_idx should dispatch to the cpp or cuda
                # version of the function depending on the input device.
                idx1 = _C.nn_points_idx(x, y)
                idx2 = TestNearestNeighborPoints.nn_points_idx_naive(x, y)
                # Compare full shapes first: the elementwise comparison below
                # broadcasts, so on its own it could hide a shape mismatch.
                self.assertEqual(idx1.shape, idx2.shape)
                self.assertTrue(torch.all(idx1 == idx2))

    @unittest.skipUnless(torch.cuda.is_available(), "CUDA is not available")
    def test_nn_cuda(self):
        """
        Test cuda output vs naive python implementation.
        """
        device = torch.device("cuda:0")
        self._test_nn_helper(device)

    def test_nn_cpu(self):
        """
        Test cpu output vs naive python implementation
        """
        device = torch.device("cpu")
        self._test_nn_helper(device)

    @staticmethod
    def bm_nn_points_cpu_with_init(
        N: int = 4, D: int = 4, P1: int = 128, P2: int = 128
    ):
        """
        Create random CPU inputs and return a closure that runs
        ``_C.nn_points_idx`` on them once.
        """
        device = torch.device("cpu")
        x = torch.randn(N, P1, D, device=device)
        y = torch.randn(N, P2, D, device=device)

        def nn_cpu():
            _C.nn_points_idx(x.contiguous(), y.contiguous())

        return nn_cpu

    @staticmethod
    def bm_nn_points_cuda_with_init(
        N: int = 4, D: int = 4, P1: int = 128, P2: int = 128
    ):
        """
        Create random inputs on ``cuda:0`` and return a closure that runs
        ``_C.nn_points_idx`` on them once.
        """
        device = torch.device("cuda:0")
        x = torch.randn(N, P1, D, device=device)
        y = torch.randn(N, P2, D, device=device)
        # Make sure input creation has finished before any timing starts.
        torch.cuda.synchronize()

        def nn_cuda():
            _C.nn_points_idx(x.contiguous(), y.contiguous())
            # CUDA work is launched asynchronously; wait for it so each call
            # covers the actual kernel execution.
            torch.cuda.synchronize()

        return nn_cuda

    @staticmethod
    def bm_nn_points_python_with_init(
        N: int = 4, D: int = 4, P1: int = 128, P2: int = 128
    ):
        """
        Create random CPU inputs and return a closure that runs the pure
        PyTorch reference ``nn_points_idx_naive`` on them once.
        """
        x = torch.randn(N, P1, D)
        y = torch.randn(N, P2, D)

        def nn_python():
            TestNearestNeighborPoints.nn_points_idx_naive(x, y)

        return nn_python