From e290f87ca949c077803d2da02c48173607ce70e4 Mon Sep 17 00:00:00 2001 From: Justin Johnson Date: Mon, 3 Feb 2020 10:04:10 -0800 Subject: [PATCH] Add CPU implementation for nearest neighbor Summary: Adds a CPU implementation for `pytorch3d.ops.nn_points_idx`. Also renames the associated C++ and CUDA functions to use `AllCaps` names used in other C++ / CUDA code. Reviewed By: gkioxari Differential Revision: D19670491 fbshipit-source-id: 1b6409404025bf05e6a93f5d847e35afc9062f05 --- pytorch3d/csrc/ext.cpp | 2 +- .../nearest_neighbor_points.cu | 16 ++++---- .../nearest_neighbor_points.h | 11 ++++-- .../nearest_neighbors_points_cpu.cpp | 38 +++++++++++++++++++ tests/bm_nearest_neighbor_points.py | 7 ++++ tests/test_nearest_neighbor_points.py | 38 ++++++++++++------- 6 files changed, 86 insertions(+), 26 deletions(-) create mode 100644 pytorch3d/csrc/nearest_neighbor_points/nearest_neighbors_points_cpu.cpp diff --git a/pytorch3d/csrc/ext.cpp b/pytorch3d/csrc/ext.cpp index ad71c2cf..6c5b51ab 100644 --- a/pytorch3d/csrc/ext.cpp +++ b/pytorch3d/csrc/ext.cpp @@ -11,7 +11,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("face_areas_normals", &face_areas_normals); m.def("packed_to_padded_tensor", &packed_to_padded_tensor); - m.def("nn_points_idx", &nn_points_idx); + m.def("nn_points_idx", &NearestNeighborIdx); m.def("gather_scatter", &gather_scatter); m.def("rasterize_points", &RasterizePoints); m.def("rasterize_points_backward", &RasterizePointsBackward); diff --git a/pytorch3d/csrc/nearest_neighbor_points/nearest_neighbor_points.cu b/pytorch3d/csrc/nearest_neighbor_points/nearest_neighbor_points.cu index 07d2bbb3..e1f119d0 100644 --- a/pytorch3d/csrc/nearest_neighbor_points/nearest_neighbor_points.cu +++ b/pytorch3d/csrc/nearest_neighbor_points/nearest_neighbor_points.cu @@ -4,7 +4,7 @@ #include template -__device__ void warp_reduce( +__device__ void WarpReduce( volatile scalar_t* min_dists, volatile long* min_idxs, const size_t tid) { @@ -54,7 +54,7 @@ __device__ void warp_reduce( // is aligned. // template -__global__ void nearest_neighbor_kernel( +__global__ void NearestNeighborKernel( const scalar_t* __restrict__ points1, const scalar_t* __restrict__ points2, long* __restrict__ idx, @@ -123,7 +123,7 @@ __global__ void nearest_neighbor_kernel( // Unroll the last 6 iterations of the loop since they will happen // synchronized within a single warp. if (tid < 32) - warp_reduce(min_dists, min_idxs, tid); + WarpReduce(min_dists, min_idxs, tid); // Finally thread 0 writes the result to the output buffer. if (tid == 0) { @@ -144,7 +144,7 @@ __global__ void nearest_neighbor_kernel( // P2: Number of points in points2. // template -__global__ void nearest_neighbor_kernel_D3( +__global__ void NearestNeighborKernelD3( const scalar_t* __restrict__ points1, const scalar_t* __restrict__ points2, long* __restrict__ idx, @@ -204,7 +204,7 @@ __global__ void nearest_neighbor_kernel_D3( // Unroll the last 6 iterations of the loop since they will happen // synchronized within a single warp. if (tid < 32) - warp_reduce(min_dists, min_idxs, tid); + WarpReduce(min_dists, min_idxs, tid); // Finally thread 0 writes the result to the output buffer. if (tid == 0) { @@ -212,7 +212,7 @@ __global__ void nearest_neighbor_kernel_D3( } } -at::Tensor nn_points_idx_cuda(at::Tensor p1, at::Tensor p2) { +at::Tensor NearestNeighborIdxCuda(at::Tensor p1, at::Tensor p2) { const auto N = p1.size(0); const auto P1 = p1.size(1); const auto P2 = p2.size(1); @@ -231,7 +231,7 @@ at::Tensor nn_points_idx_cuda(at::Tensor p1, at::Tensor p2) { AT_DISPATCH_FLOATING_TYPES(p1.type(), "nearest_neighbor_v3_cuda", ([&] { size_t shared_size = threads * sizeof(size_t) + threads * sizeof(long); - nearest_neighbor_kernel_D3 + NearestNeighborKernelD3 <<>>( p1.data_ptr(), p2.data_ptr(), @@ -249,7 +249,7 @@ at::Tensor nn_points_idx_cuda(at::Tensor p1, at::Tensor p2) { size_t D_2 = D + (D % 2); size_t shared_size = (D_2 + threads) * sizeof(size_t); shared_size += threads * sizeof(long); - nearest_neighbor_kernel<<>>( + NearestNeighborKernel<<>>( p1.data_ptr(), p2.data_ptr(), idx.data_ptr(), diff --git a/pytorch3d/csrc/nearest_neighbor_points/nearest_neighbor_points.h b/pytorch3d/csrc/nearest_neighbor_points/nearest_neighbor_points.h index 7d87f1b4..51c7e72e 100644 --- a/pytorch3d/csrc/nearest_neighbor_points/nearest_neighbor_points.h +++ b/pytorch3d/csrc/nearest_neighbor_points/nearest_neighbor_points.h @@ -19,19 +19,22 @@ // to p1[n, i] in the cloud p2[n] is p2[n, j]. // +// CPU implementation. +at::Tensor NearestNeighborIdxCpu(at::Tensor p1, at::Tensor p2); + // Cuda implementation. -at::Tensor nn_points_idx_cuda(at::Tensor p1, at::Tensor p2); +at::Tensor NearestNeighborIdxCuda(at::Tensor p1, at::Tensor p2); // Implementation which is exposed. -at::Tensor nn_points_idx(at::Tensor p1, at::Tensor p2) { +at::Tensor NearestNeighborIdx(at::Tensor p1, at::Tensor p2) { if (p1.type().is_cuda() && p2.type().is_cuda()) { #ifdef WITH_CUDA CHECK_CONTIGUOUS_CUDA(p1); CHECK_CONTIGUOUS_CUDA(p2); - return nn_points_idx_cuda(p1, p2); + return NearestNeighborIdxCuda(p1, p2); #else AT_ERROR("Not compiled with GPU support."); #endif } - AT_ERROR("Not implemented on the CPU."); + return NearestNeighborIdxCpu(p1, p2); }; diff --git a/pytorch3d/csrc/nearest_neighbor_points/nearest_neighbors_points_cpu.cpp b/pytorch3d/csrc/nearest_neighbor_points/nearest_neighbors_points_cpu.cpp new file mode 100644 index 00000000..0fa80488 --- /dev/null +++ b/pytorch3d/csrc/nearest_neighbor_points/nearest_neighbors_points_cpu.cpp @@ -0,0 +1,38 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. + +#include + +at::Tensor NearestNeighborIdxCpu(at::Tensor p1, at::Tensor p2) { + const int N = p1.size(0); + const int P1 = p1.size(1); + const int D = p1.size(2); + const int P2 = p2.size(1); + + auto long_opts = p1.options().dtype(torch::kInt64); + torch::Tensor out = torch::empty({N, P1}, long_opts); + + auto p1_a = p1.accessor(); + auto p2_a = p2.accessor(); + auto out_a = out.accessor(); + + for (int n = 0; n < N; ++n) { + for (int i1 = 0; i1 < P1; ++i1) { + // TODO: support other floating-point types? + float min_dist = -1; + int64_t min_idx = -1; + for (int i2 = 0; i2 < P2; ++i2) { + float dist = 0; + for (int d = 0; d < D; ++d) { + float diff = p1_a[n][i1][d] - p2_a[n][i2][d]; + dist += diff * diff; + } + if (min_dist == -1 || dist < min_dist) { + min_dist = dist; + min_idx = i2; + } + } + out_a[n][i1] = min_idx; + } + } + return out; +} diff --git a/tests/bm_nearest_neighbor_points.py b/tests/bm_nearest_neighbor_points.py index ca488eee..c5d87e6a 100644 --- a/tests/bm_nearest_neighbor_points.py +++ b/tests/bm_nearest_neighbor_points.py @@ -27,6 +27,13 @@ def bm_nn_points() -> None: warmup_iters=1, ) + benchmark( + TestNearestNeighborPoints.bm_nn_points_cpu_with_init, + "NN_CPU", + kwargs_list, + warmup_iters=1, + ) + if torch.cuda.is_available(): benchmark( TestNearestNeighborPoints.bm_nn_points_cuda_with_init, diff --git a/tests/test_nearest_neighbor_points.py b/tests/test_nearest_neighbor_points.py index 697b5477..964b5a9a 100644 --- a/tests/test_nearest_neighbor_points.py +++ b/tests/test_nearest_neighbor_points.py @@ -21,11 +21,7 @@ class TestNearestNeighborPoints(unittest.TestCase): idx = dists2.argmin(2) return idx - def test_nn_cuda(self): - """ - Test cuda output vs naive python implementation. - """ - device = torch.device("cuda:0") + def _test_nn_helper(self, device): for D in [3, 4]: for N in [1, 4]: for P1 in [1, 8, 64, 128]: @@ -43,16 +39,32 @@ class TestNearestNeighborPoints(unittest.TestCase): self.assertTrue(idx1.size(1) == P1) self.assertTrue(torch.all(idx1 == idx2)) - def test_nn_cuda_error(self): + def test_nn_cuda(self): """ - Check that nn_points_idx throws an error if cpu tensors - are given as input. + Test cuda output vs naive python implementation. """ - x = torch.randn(1, 1, 3) - y = torch.randn(1, 1, 3) - with self.assertRaises(Exception) as err: - _C.nn_points_idx(x, y) - self.assertTrue("Not implemented on the CPU" in str(err.exception)) + device = torch.device('cuda:0') + self._test_nn_helper(device) + + def test_nn_cpu(self): + """ + Test cpu output vs naive python implementation + """ + device = torch.device('cpu') + self._test_nn_helper(device) + + @staticmethod + def bm_nn_points_cpu_with_init( + N: int = 4, D: int = 4, P1: int = 128, P2: int = 128 + ): + device = torch.device('cpu') + x = torch.randn(N, P1, D, device=device) + y = torch.randn(N, P2, D, device=device) + + def nn_cpu(): + _C.nn_points_idx(x.contiguous(), y.contiguous()) + + return nn_cpu @staticmethod def bm_nn_points_cuda_with_init(