diff --git a/pytorch3d/ops/knn.py b/pytorch3d/ops/knn.py index a1969c2b..ad68f1da 100644 --- a/pytorch3d/ops/knn.py +++ b/pytorch3d/ops/knn.py @@ -18,7 +18,9 @@ class _knn_points(Function): """ @staticmethod - def forward(ctx, p1, p2, lengths1, lengths2, K, version): + def forward( + ctx, p1, p2, lengths1, lengths2, K, version, return_sorted: bool = True + ): """ K-Nearest neighbors on point clouds. @@ -36,6 +38,8 @@ class _knn_points(Function): K: Integer giving the number of nearest neighbors to return. version: Which KNN implementation to use in the backend. If version=-1, the correct implementation is selected based on the shapes of the inputs. + return_sorted: (bool) whether to return the nearest neighbors sorted in + ascending order of distance. Returns: p1_dists: Tensor of shape (N, P1, K) giving the squared distances to @@ -52,7 +56,7 @@ class _knn_points(Function): idx, dists = _C.knn_points_idx(p1, p2, lengths1, lengths2, K, version) # sort KNN in ascending order if K > 1 - if K > 1: + if K > 1 and return_sorted: if lengths2.min() < K: P1 = p1.shape[1] mask = lengths2[:, None] <= torch.arange(K, device=dists.device)[None] @@ -84,7 +88,7 @@ class _knn_points(Function): grad_p1, grad_p2 = _C.knn_points_backward( p1, p2, lengths1, lengths2, idx, grad_dists ) - return grad_p1, grad_p2, None, None, None, None + return grad_p1, grad_p2, None, None, None, None, None def knn_points( @@ -95,6 +99,7 @@ def knn_points( K: int = 1, version: int = -1, return_nn: bool = False, + return_sorted: bool = True, ): """ K-Nearest neighbors on point clouds. @@ -113,7 +118,9 @@ def knn_points( K: Integer giving the number of nearest neighbors to return. version: Which KNN implementation to use in the backend. If version=-1, the correct implementation is selected based on the shapes of the inputs. - return_nn: If set to True returns the K nearest neighors in p2 for each point in p1. + return_nn: If set to True returns the K nearest neighbors in p2 for each point in p1. + return_sorted: (bool) whether to return the nearest neighbors sorted in + ascending order of distance. Returns: dists: Tensor of shape (N, P1, K) giving the squared distances to @@ -158,7 +165,9 @@ def knn_points( lengths2 = torch.full((p1.shape[0],), P2, dtype=torch.int64, device=p1.device) # pyre-fixme[16]: `_knn_points` has no attribute `apply`. - p1_dists, p1_idx = _knn_points.apply(p1, p2, lengths1, lengths2, K, version) + p1_dists, p1_idx = _knn_points.apply( + p1, p2, lengths1, lengths2, K, version, return_sorted + ) p2_nn = None if return_nn: diff --git a/tests/test_knn.py b/tests/test_knn.py index 18244095..364a5889 100644 --- a/tests/test_knn.py +++ b/tests/test_knn.py @@ -49,7 +49,7 @@ class TestKNN(TestCaseMixin, unittest.TestCase): return _KNN(dists=dists, idx=idx, knn=None) - def _knn_vs_python_square_helper(self, device): + def _knn_vs_python_square_helper(self, device, return_sorted): Ns = [1, 4] Ds = [3, 5, 8] P1s = [8, 24] @@ -70,7 +70,24 @@ class TestKNN(TestCaseMixin, unittest.TestCase): # forward out1 = self._knn_points_naive(x, y, lengths1=None, lengths2=None, K=K) - out2 = knn_points(x_cuda, y_cuda, K=K, version=version) + out2 = knn_points( + x_cuda, y_cuda, K=K, version=version, return_sorted=return_sorted + ) + if K > 1 and not return_sorted: + # check out2 is not sorted + self.assertFalse(torch.allclose(out1[0], out2[0])) + self.assertFalse(torch.allclose(out1[1], out2[1])) + # now sort out2 + dists, idx, _ = out2 + if P2 < K: + dists[..., P2:] = float("inf") + dists, sort_idx = dists.sort(dim=2) + dists[..., P2:] = 0 + else: + dists, sort_idx = dists.sort(dim=2) + idx = idx.gather(2, sort_idx) + out2 = _KNN(dists, idx, None) + self.assertClose(out1[0], out2[0]) self.assertTrue(torch.all(out1[1] == out2[1])) @@ -86,11 +103,13 @@ class TestKNN(TestCaseMixin, unittest.TestCase): def test_knn_vs_python_square_cpu(self): device = torch.device("cpu") - self._knn_vs_python_square_helper(device) + self._knn_vs_python_square_helper(device, return_sorted=True) def test_knn_vs_python_square_cuda(self): device = get_random_cuda_device() - self._knn_vs_python_square_helper(device) + # Check both cases where the output is sorted and unsorted + self._knn_vs_python_square_helper(device, return_sorted=True) + self._knn_vs_python_square_helper(device, return_sorted=False) def _knn_vs_python_ragged_helper(self, device): Ns = [1, 4]