mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-19 05:40:34 +08:00
Add CPU implementation for nearest neighbor
Summary: Adds a CPU implementation for `pytorch3d.ops.nn_points_idx`. Also renames the associated C++ and CUDA functions to use `AllCaps` names used in other C++ / CUDA code. Reviewed By: gkioxari Differential Revision: D19670491 fbshipit-source-id: 1b6409404025bf05e6a93f5d847e35afc9062f05
This commit is contained in:
committed by
Facebook Github Bot
parent
25c2f34096
commit
e290f87ca9
@@ -27,6 +27,13 @@ def bm_nn_points() -> None:
|
||||
warmup_iters=1,
|
||||
)
|
||||
|
||||
benchmark(
|
||||
TestNearestNeighborPoints.bm_nn_points_cpu_with_init,
|
||||
"NN_CPU",
|
||||
kwargs_list,
|
||||
warmup_iters=1,
|
||||
)
|
||||
|
||||
if torch.cuda.is_available():
|
||||
benchmark(
|
||||
TestNearestNeighborPoints.bm_nn_points_cuda_with_init,
|
||||
|
||||
@@ -21,11 +21,7 @@ class TestNearestNeighborPoints(unittest.TestCase):
|
||||
idx = dists2.argmin(2)
|
||||
return idx
|
||||
|
||||
def test_nn_cuda(self):
|
||||
"""
|
||||
Test cuda output vs naive python implementation.
|
||||
"""
|
||||
device = torch.device("cuda:0")
|
||||
def _test_nn_helper(self, device):
|
||||
for D in [3, 4]:
|
||||
for N in [1, 4]:
|
||||
for P1 in [1, 8, 64, 128]:
|
||||
@@ -43,16 +39,32 @@ class TestNearestNeighborPoints(unittest.TestCase):
|
||||
self.assertTrue(idx1.size(1) == P1)
|
||||
self.assertTrue(torch.all(idx1 == idx2))
|
||||
|
||||
def test_nn_cuda_error(self):
|
||||
def test_nn_cuda(self):
|
||||
"""
|
||||
Check that nn_points_idx throws an error if cpu tensors
|
||||
are given as input.
|
||||
Test cuda output vs naive python implementation.
|
||||
"""
|
||||
x = torch.randn(1, 1, 3)
|
||||
y = torch.randn(1, 1, 3)
|
||||
with self.assertRaises(Exception) as err:
|
||||
_C.nn_points_idx(x, y)
|
||||
self.assertTrue("Not implemented on the CPU" in str(err.exception))
|
||||
device = torch.device('cuda:0')
|
||||
self._test_nn_helper(device)
|
||||
|
||||
def test_nn_cpu(self):
|
||||
"""
|
||||
Test cpu output vs naive python implementation
|
||||
"""
|
||||
device = torch.device('cpu')
|
||||
self._test_nn_helper(device)
|
||||
|
||||
@staticmethod
|
||||
def bm_nn_points_cpu_with_init(
|
||||
N: int = 4, D: int = 4, P1: int = 128, P2: int = 128
|
||||
):
|
||||
device = torch.device('cpu')
|
||||
x = torch.randn(N, P1, D, device=device)
|
||||
y = torch.randn(N, P2, D, device=device)
|
||||
|
||||
def nn_cpu():
|
||||
_C.nn_points_idx(x.contiguous(), y.contiguous())
|
||||
|
||||
return nn_cpu
|
||||
|
||||
@staticmethod
|
||||
def bm_nn_points_cuda_with_init(
|
||||
|
||||
Reference in New Issue
Block a user