mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 20:02:49 +08:00
making sorting for K >1 optional in KNN points function
Summary: Added `sorted` argument to the `knn_points` function. This came up during the benchmarking against Faiss - sorting added extra memory usage. Match the memory usage of Faiss by making sorting optional. Reviewed By: bottler, gkioxari Differential Revision: D22329070 fbshipit-source-id: 0828ff9b48eefce99ce1f60089389f6885d03139
This commit is contained in:
parent
dd4a35cf9f
commit
806ca361c0
@ -18,7 +18,9 @@ class _knn_points(Function):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def forward(ctx, p1, p2, lengths1, lengths2, K, version):
|
def forward(
|
||||||
|
ctx, p1, p2, lengths1, lengths2, K, version, return_sorted: bool = True
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
K-Nearest neighbors on point clouds.
|
K-Nearest neighbors on point clouds.
|
||||||
|
|
||||||
@ -36,6 +38,8 @@ class _knn_points(Function):
|
|||||||
K: Integer giving the number of nearest neighbors to return.
|
K: Integer giving the number of nearest neighbors to return.
|
||||||
version: Which KNN implementation to use in the backend. If version=-1,
|
version: Which KNN implementation to use in the backend. If version=-1,
|
||||||
the correct implementation is selected based on the shapes of the inputs.
|
the correct implementation is selected based on the shapes of the inputs.
|
||||||
|
return_sorted: (bool) whether to return the nearest neighbors sorted in
|
||||||
|
ascending order of distance.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
p1_dists: Tensor of shape (N, P1, K) giving the squared distances to
|
p1_dists: Tensor of shape (N, P1, K) giving the squared distances to
|
||||||
@ -52,7 +56,7 @@ class _knn_points(Function):
|
|||||||
idx, dists = _C.knn_points_idx(p1, p2, lengths1, lengths2, K, version)
|
idx, dists = _C.knn_points_idx(p1, p2, lengths1, lengths2, K, version)
|
||||||
|
|
||||||
# sort KNN in ascending order if K > 1
|
# sort KNN in ascending order if K > 1
|
||||||
if K > 1:
|
if K > 1 and return_sorted:
|
||||||
if lengths2.min() < K:
|
if lengths2.min() < K:
|
||||||
P1 = p1.shape[1]
|
P1 = p1.shape[1]
|
||||||
mask = lengths2[:, None] <= torch.arange(K, device=dists.device)[None]
|
mask = lengths2[:, None] <= torch.arange(K, device=dists.device)[None]
|
||||||
@ -84,7 +88,7 @@ class _knn_points(Function):
|
|||||||
grad_p1, grad_p2 = _C.knn_points_backward(
|
grad_p1, grad_p2 = _C.knn_points_backward(
|
||||||
p1, p2, lengths1, lengths2, idx, grad_dists
|
p1, p2, lengths1, lengths2, idx, grad_dists
|
||||||
)
|
)
|
||||||
return grad_p1, grad_p2, None, None, None, None
|
return grad_p1, grad_p2, None, None, None, None, None
|
||||||
|
|
||||||
|
|
||||||
def knn_points(
|
def knn_points(
|
||||||
@ -95,6 +99,7 @@ def knn_points(
|
|||||||
K: int = 1,
|
K: int = 1,
|
||||||
version: int = -1,
|
version: int = -1,
|
||||||
return_nn: bool = False,
|
return_nn: bool = False,
|
||||||
|
return_sorted: bool = True,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
K-Nearest neighbors on point clouds.
|
K-Nearest neighbors on point clouds.
|
||||||
@ -113,7 +118,9 @@ def knn_points(
|
|||||||
K: Integer giving the number of nearest neighbors to return.
|
K: Integer giving the number of nearest neighbors to return.
|
||||||
version: Which KNN implementation to use in the backend. If version=-1,
|
version: Which KNN implementation to use in the backend. If version=-1,
|
||||||
the correct implementation is selected based on the shapes of the inputs.
|
the correct implementation is selected based on the shapes of the inputs.
|
||||||
return_nn: If set to True returns the K nearest neighors in p2 for each point in p1.
|
return_nn: If set to True returns the K nearest neighbors in p2 for each point in p1.
|
||||||
|
return_sorted: (bool) whether to return the nearest neighbors sorted in
|
||||||
|
ascending order of distance.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
dists: Tensor of shape (N, P1, K) giving the squared distances to
|
dists: Tensor of shape (N, P1, K) giving the squared distances to
|
||||||
@ -158,7 +165,9 @@ def knn_points(
|
|||||||
lengths2 = torch.full((p1.shape[0],), P2, dtype=torch.int64, device=p1.device)
|
lengths2 = torch.full((p1.shape[0],), P2, dtype=torch.int64, device=p1.device)
|
||||||
|
|
||||||
# pyre-fixme[16]: `_knn_points` has no attribute `apply`.
|
# pyre-fixme[16]: `_knn_points` has no attribute `apply`.
|
||||||
p1_dists, p1_idx = _knn_points.apply(p1, p2, lengths1, lengths2, K, version)
|
p1_dists, p1_idx = _knn_points.apply(
|
||||||
|
p1, p2, lengths1, lengths2, K, version, return_sorted
|
||||||
|
)
|
||||||
|
|
||||||
p2_nn = None
|
p2_nn = None
|
||||||
if return_nn:
|
if return_nn:
|
||||||
|
@ -49,7 +49,7 @@ class TestKNN(TestCaseMixin, unittest.TestCase):
|
|||||||
|
|
||||||
return _KNN(dists=dists, idx=idx, knn=None)
|
return _KNN(dists=dists, idx=idx, knn=None)
|
||||||
|
|
||||||
def _knn_vs_python_square_helper(self, device):
|
def _knn_vs_python_square_helper(self, device, return_sorted):
|
||||||
Ns = [1, 4]
|
Ns = [1, 4]
|
||||||
Ds = [3, 5, 8]
|
Ds = [3, 5, 8]
|
||||||
P1s = [8, 24]
|
P1s = [8, 24]
|
||||||
@ -70,7 +70,24 @@ class TestKNN(TestCaseMixin, unittest.TestCase):
|
|||||||
|
|
||||||
# forward
|
# forward
|
||||||
out1 = self._knn_points_naive(x, y, lengths1=None, lengths2=None, K=K)
|
out1 = self._knn_points_naive(x, y, lengths1=None, lengths2=None, K=K)
|
||||||
out2 = knn_points(x_cuda, y_cuda, K=K, version=version)
|
out2 = knn_points(
|
||||||
|
x_cuda, y_cuda, K=K, version=version, return_sorted=return_sorted
|
||||||
|
)
|
||||||
|
if K > 1 and not return_sorted:
|
||||||
|
# check out2 is not sorted
|
||||||
|
self.assertFalse(torch.allclose(out1[0], out2[0]))
|
||||||
|
self.assertFalse(torch.allclose(out1[1], out2[1]))
|
||||||
|
# now sort out2
|
||||||
|
dists, idx, _ = out2
|
||||||
|
if P2 < K:
|
||||||
|
dists[..., P2:] = float("inf")
|
||||||
|
dists, sort_idx = dists.sort(dim=2)
|
||||||
|
dists[..., P2:] = 0
|
||||||
|
else:
|
||||||
|
dists, sort_idx = dists.sort(dim=2)
|
||||||
|
idx = idx.gather(2, sort_idx)
|
||||||
|
out2 = _KNN(dists, idx, None)
|
||||||
|
|
||||||
self.assertClose(out1[0], out2[0])
|
self.assertClose(out1[0], out2[0])
|
||||||
self.assertTrue(torch.all(out1[1] == out2[1]))
|
self.assertTrue(torch.all(out1[1] == out2[1]))
|
||||||
|
|
||||||
@ -86,11 +103,13 @@ class TestKNN(TestCaseMixin, unittest.TestCase):
|
|||||||
|
|
||||||
def test_knn_vs_python_square_cpu(self):
|
def test_knn_vs_python_square_cpu(self):
|
||||||
device = torch.device("cpu")
|
device = torch.device("cpu")
|
||||||
self._knn_vs_python_square_helper(device)
|
self._knn_vs_python_square_helper(device, return_sorted=True)
|
||||||
|
|
||||||
def test_knn_vs_python_square_cuda(self):
|
def test_knn_vs_python_square_cuda(self):
|
||||||
device = get_random_cuda_device()
|
device = get_random_cuda_device()
|
||||||
self._knn_vs_python_square_helper(device)
|
# Check both cases where the output is sorted and unsorted
|
||||||
|
self._knn_vs_python_square_helper(device, return_sorted=True)
|
||||||
|
self._knn_vs_python_square_helper(device, return_sorted=False)
|
||||||
|
|
||||||
def _knn_vs_python_ragged_helper(self, device):
|
def _knn_vs_python_ragged_helper(self, device):
|
||||||
Ns = [1, 4]
|
Ns = [1, 4]
|
||||||
|
Loading…
x
Reference in New Issue
Block a user