diff --git a/pytorch3d/csrc/ext.cpp b/pytorch3d/csrc/ext.cpp index b162bc59..2502eff8 100644 --- a/pytorch3d/csrc/ext.cpp +++ b/pytorch3d/csrc/ext.cpp @@ -7,7 +7,6 @@ #include "face_areas_normals/face_areas_normals.h" #include "gather_scatter/gather_scatter.h" #include "knn/knn.h" -#include "nearest_neighbor_points/nearest_neighbor_points.h" #include "packed_to_padded_tensor/packed_to_padded_tensor.h" #include "point_mesh/point_mesh_edge.h" #include "point_mesh/point_mesh_face.h" @@ -21,7 +20,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("padded_to_packed", &PaddedToPacked); m.def("knn_points_idx", &KNearestNeighborIdx); m.def("knn_points_backward", &KNearestNeighborBackward); - m.def("nn_points_idx", &NearestNeighborIdx); m.def("gather_scatter", &gather_scatter); m.def("rasterize_points", &RasterizePoints); m.def("rasterize_points_backward", &RasterizePointsBackward); diff --git a/pytorch3d/csrc/nearest_neighbor_points/nearest_neighbors_points_cpu.cpp b/pytorch3d/csrc/nearest_neighbor_points/nearest_neighbors_points_cpu.cpp deleted file mode 100644 index 3dd373b9..00000000 --- a/pytorch3d/csrc/nearest_neighbor_points/nearest_neighbors_points_cpu.cpp +++ /dev/null @@ -1,38 +0,0 @@ -// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. - -#include - -at::Tensor NearestNeighborIdxCpu(at::Tensor p1, at::Tensor p2) { - const int N = p1.size(0); - const int P1 = p1.size(1); - const int D = p1.size(2); - const int P2 = p2.size(1); - - auto long_opts = p1.options().dtype(torch::kInt64); - torch::Tensor out = torch::empty({N, P1}, long_opts); - - auto p1_a = p1.accessor(); - auto p2_a = p2.accessor(); - auto out_a = out.accessor(); - - for (int n = 0; n < N; ++n) { - for (int i1 = 0; i1 < P1; ++i1) { - // TODO: support other floating-point types? - float min_dist = -1; - int64_t min_idx = -1; - for (int i2 = 0; i2 < P2; ++i2) { - float dist = 0; - for (int d = 0; d < D; ++d) { - float diff = p1_a[n][i1][d] - p2_a[n][i2][d]; - dist += diff * diff; - } - if (min_dist == -1 || dist < min_dist) { - min_dist = dist; - min_idx = i2; - } - } - out_a[n][i1] = min_idx; - } - } - return out; -} diff --git a/pytorch3d/ops/__init__.py b/pytorch3d/ops/__init__.py index 0ca9eb6d..48703e2a 100644 --- a/pytorch3d/ops/__init__.py +++ b/pytorch3d/ops/__init__.py @@ -5,7 +5,6 @@ from .cubify import cubify from .graph_conv import GraphConv from .knn import knn_gather, knn_points from .mesh_face_areas_normals import mesh_face_areas_normals -from .nearest_neighbor_points import nn_points_idx from .packed_to_padded import packed_to_padded, padded_to_packed from .points_alignment import corresponding_points_alignment from .sample_points_from_meshes import sample_points_from_meshes diff --git a/pytorch3d/ops/nearest_neighbor_points.py b/pytorch3d/ops/nearest_neighbor_points.py deleted file mode 100644 index ffb40a8e..00000000 --- a/pytorch3d/ops/nearest_neighbor_points.py +++ /dev/null @@ -1,43 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. - - -import torch -from pytorch3d import _C - - -def nn_points_idx(p1, p2, p2_normals=None) -> torch.Tensor: - """ - Compute the coordinates of nearest neighbors in pointcloud p2 to points in p1. - Args: - p1: FloatTensor of shape (N, P1, D) giving a batch of pointclouds each - containing P1 points of dimension D. - p2: FloatTensor of shape (N, P2, D) giving a batch of pointclouds each - containing P2 points of dimension D. - p2_normals: [optional] FloatTensor of shape (N, P2, D) giving - normals for p2. Default: None. - - Returns: - 3-element tuple containing - - - **p1_nn_points**: FloatTensor of shape (N, P1, D) where - p1_neighbors[n, i] is the point in p2[n] which is - the nearest neighbor to p1[n, i]. - - **p1_nn_idx**: LongTensor of shape (N, P1) giving the indices of - the neighbors. - - **p1_nn_normals**: Normal vectors for each point in p1_neighbors; - only returned if p2_normals is passed - else return []. - """ - N, P1, D = p1.shape - with torch.no_grad(): - p1_nn_idx = _C.nn_points_idx(p1.contiguous(), p2.contiguous()) # (N, P1) - p1_nn_idx_expanded = p1_nn_idx.view(N, P1, 1).expand(N, P1, D) - p1_nn_points = p2.gather(1, p1_nn_idx_expanded) - if p2_normals is None: - p1_nn_normals = [] - else: - if p2_normals.shape != p2.shape: - raise ValueError("p2_normals has incorrect shape.") - p1_nn_normals = p2_normals.gather(1, p1_nn_idx_expanded) - - return p1_nn_points, p1_nn_idx, p1_nn_normals diff --git a/tests/bm_nearest_neighbor_points.py b/tests/bm_nearest_neighbor_points.py deleted file mode 100644 index f98ae17e..00000000 --- a/tests/bm_nearest_neighbor_points.py +++ /dev/null @@ -1,42 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. - -from itertools import product - -import torch -from fvcore.common.benchmark import benchmark -from test_nearest_neighbor_points import TestNearestNeighborPoints - - -def bm_nn_points() -> None: - kwargs_list = [] - - N = [1, 4, 32] - D = [3, 4] - P1 = [1, 128] - P2 = [32, 128] - test_cases = product(N, D, P1, P2) - for case in test_cases: - n, d, p1, p2 = case - kwargs_list.append({"N": n, "D": d, "P1": p1, "P2": p2}) - - benchmark( - TestNearestNeighborPoints.bm_nn_points_python_with_init, - "NN_PYTHON", - kwargs_list, - warmup_iters=1, - ) - - benchmark( - TestNearestNeighborPoints.bm_nn_points_cpu_with_init, - "NN_CPU", - kwargs_list, - warmup_iters=1, - ) - - if torch.cuda.is_available(): - benchmark( - TestNearestNeighborPoints.bm_nn_points_cuda_with_init, - "NN_CUDA", - kwargs_list, - warmup_iters=1, - ) diff --git a/tests/test_nearest_neighbor_points.py b/tests/test_nearest_neighbor_points.py deleted file mode 100644 index aabfe29b..00000000 --- a/tests/test_nearest_neighbor_points.py +++ /dev/null @@ -1,91 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. - -import unittest - -import torch -from pytorch3d import _C - - -class TestNearestNeighborPoints(unittest.TestCase): - @staticmethod - def nn_points_idx_naive(x, y): - """ - PyTorch implementation of nn_points_idx function. - """ - N, P1, D = x.shape - _N, P2, _D = y.shape - assert N == _N and D == _D - diffs = x.view(N, P1, 1, D) - y.view(N, 1, P2, D) - dists2 = (diffs * diffs).sum(3) - idx = dists2.argmin(2) - return idx - - def _test_nn_helper(self, device): - for D in [3, 4]: - for N in [1, 4]: - for P1 in [1, 8, 64, 128]: - for P2 in [32, 128]: - x = torch.randn(N, P1, D, device=device) - y = torch.randn(N, P2, D, device=device) - - # _C.nn_points_idx should dispatch - # to the cpp or cuda versions of the function - # depending on the input type. - idx1 = _C.nn_points_idx(x, y) - idx2 = TestNearestNeighborPoints.nn_points_idx_naive(x, y) - self.assertTrue(idx1.size(1) == P1) - self.assertTrue(torch.all(idx1 == idx2)) - - def test_nn_cuda(self): - """ - Test cuda output vs naive python implementation. - """ - device = torch.device("cuda:0") - self._test_nn_helper(device) - - def test_nn_cpu(self): - """ - Test cpu output vs naive python implementation - """ - device = torch.device("cpu") - self._test_nn_helper(device) - - @staticmethod - def bm_nn_points_cpu_with_init( - N: int = 4, D: int = 4, P1: int = 128, P2: int = 128 - ): - device = torch.device("cpu") - x = torch.randn(N, P1, D, device=device) - y = torch.randn(N, P2, D, device=device) - - def nn_cpu(): - _C.nn_points_idx(x.contiguous(), y.contiguous()) - - return nn_cpu - - @staticmethod - def bm_nn_points_cuda_with_init( - N: int = 4, D: int = 4, P1: int = 128, P2: int = 128 - ): - device = torch.device("cuda:0") - x = torch.randn(N, P1, D, device=device) - y = torch.randn(N, P2, D, device=device) - torch.cuda.synchronize() - - def nn_cpp(): - _C.nn_points_idx(x.contiguous(), y.contiguous()) - torch.cuda.synchronize() - - return nn_cpp - - @staticmethod - def bm_nn_points_python_with_init( - N: int = 4, D: int = 4, P1: int = 128, P2: int = 128 - ): - x = torch.randn(N, P1, D) - y = torch.randn(N, P2, D) - - def nn_python(): - TestNearestNeighborPoints.nn_points_idx_naive(x, y) - - return nn_python