pytorch3d/tests/test_nearest_neighbor_points.py
Justin Johnson e290f87ca9 Add CPU implementation for nearest neighbor
Summary:
Adds a CPU implementation for `pytorch3d.ops.nn_points_idx`.

Also renames the associated C++ and CUDA functions to use `AllCaps` names used in other C++ / CUDA code.

Reviewed By: gkioxari

Differential Revision: D19670491

fbshipit-source-id: 1b6409404025bf05e6a93f5d847e35afc9062f05
2020-02-03 10:06:10 -08:00

95 lines
2.8 KiB
Python

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import unittest
import torch
from pytorch3d import _C
class TestNearestNeighborPoints(unittest.TestCase):
    """
    Compare the compiled ``_C.nn_points_idx`` against a pure PyTorch
    reference, and provide benchmark factories for the CPU, CUDA and
    pure-Python variants.
    """

    @staticmethod
    def nn_points_idx_naive(x, y):
        """
        PyTorch reference implementation of the nn_points_idx function.

        Args:
            x: Tensor of shape (N, P1, D).
            y: Tensor of shape (N, P2, D).

        Returns:
            Tensor of shape (N, P1) where entry [n, i] is the index j of the
            point y[n, j] with the smallest squared Euclidean distance to
            x[n, i].
        """
        N, P1, D = x.shape
        _N, _, _D = y.shape
        assert N == _N and D == _D, "x and y must share batch size and dim"
        # Broadcast to all pairwise differences, shape (N, P1, P2, D).
        diffs = x.unsqueeze(2) - y.unsqueeze(1)
        dists2 = (diffs * diffs).sum(3)
        return dists2.argmin(2)

    def _test_nn_helper(self, device):
        """
        Check _C.nn_points_idx against nn_points_idx_naive on ``device``
        over a grid of dimensions, batch sizes and point counts.
        """
        cases = [
            (D, N, P1, P2)
            for D in (3, 4)
            for N in (1, 4)
            for P1 in (1, 8, 64, 128)
            for P2 in (32, 128)
        ]
        for D, N, P1, P2 in cases:
            # subTest reports exactly which configuration failed.
            with self.subTest(D=D, N=N, P1=P1, P2=P2):
                x = torch.randn(N, P1, D, device=device)
                y = torch.randn(N, P2, D, device=device)
                # _C.nn_points_idx should dispatch to the cpp or cuda
                # version of the function depending on the input type.
                idx1 = _C.nn_points_idx(x, y)
                idx2 = TestNearestNeighborPoints.nn_points_idx_naive(x, y)
                # Check the full shape so that broadcasting in the
                # comparison below cannot mask a wrongly-shaped result.
                self.assertEqual(idx1.shape, (N, P1))
                self.assertTrue(torch.all(idx1 == idx2))

    @unittest.skipIf(not torch.cuda.is_available(), "CUDA is not available")
    def test_nn_cuda(self):
        """
        Test cuda output vs naive python implementation.
        """
        device = torch.device("cuda:0")
        self._test_nn_helper(device)

    def test_nn_cpu(self):
        """
        Test cpu output vs naive python implementation.
        """
        device = torch.device("cpu")
        self._test_nn_helper(device)

    @staticmethod
    def bm_nn_points_cpu_with_init(
        N: int = 4, D: int = 4, P1: int = 128, P2: int = 128
    ):
        """
        Build random CPU inputs once and return a zero-argument closure
        that runs _C.nn_points_idx on them.
        """
        device = torch.device("cpu")
        x = torch.randn(N, P1, D, device=device)
        y = torch.randn(N, P2, D, device=device)

        def nn_cpu():
            _C.nn_points_idx(x.contiguous(), y.contiguous())

        return nn_cpu

    @staticmethod
    def bm_nn_points_cuda_with_init(
        N: int = 4, D: int = 4, P1: int = 128, P2: int = 128
    ):
        """
        Build random CUDA inputs once and return a zero-argument closure
        that runs _C.nn_points_idx on them.
        """
        device = torch.device("cuda:0")
        x = torch.randn(N, P1, D, device=device)
        y = torch.randn(N, P2, D, device=device)
        # Finish input creation before any timing starts.
        torch.cuda.synchronize()

        def nn_cuda():
            _C.nn_points_idx(x.contiguous(), y.contiguous())
            # CUDA kernels launch asynchronously; wait so a timed call
            # includes the kernel's actual execution.
            torch.cuda.synchronize()

        return nn_cuda

    @staticmethod
    def bm_nn_points_python_with_init(
        N: int = 4, D: int = 4, P1: int = 128, P2: int = 128
    ):
        """
        Build random inputs once (on torch's default device) and return a
        zero-argument closure that runs the naive PyTorch implementation.
        """
        x = torch.randn(N, P1, D)
        y = torch.randn(N, P2, D)

        def nn_python():
            TestNearestNeighborPoints.nn_points_idx_naive(x, y)

        return nn_python