making sorting for K >1 optional in KNN points function

Summary: Added `sorted` argument to the `knn_points` function. This came up during the benchmarking against Faiss - sorting added extra memory usage. Match the memory usage of Faiss by making sorting optional.

Reviewed By: bottler, gkioxari

Differential Revision: D22329070

fbshipit-source-id: 0828ff9b48eefce99ce1f60089389f6885d03139
This commit is contained in:
Nikhila Ravi
2020-07-02 16:06:49 -07:00
committed by Facebook GitHub Bot
parent dd4a35cf9f
commit 806ca361c0
2 changed files with 37 additions and 9 deletions

View File

@@ -18,7 +18,9 @@ class _knn_points(Function):
"""
@staticmethod
def forward(ctx, p1, p2, lengths1, lengths2, K, version):
def forward(
ctx, p1, p2, lengths1, lengths2, K, version, return_sorted: bool = True
):
"""
K-Nearest neighbors on point clouds.
@@ -36,6 +38,8 @@ class _knn_points(Function):
K: Integer giving the number of nearest neighbors to return.
version: Which KNN implementation to use in the backend. If version=-1,
the correct implementation is selected based on the shapes of the inputs.
return_sorted: (bool) whether to return the nearest neighbors sorted in
ascending order of distance.
Returns:
p1_dists: Tensor of shape (N, P1, K) giving the squared distances to
@@ -52,7 +56,7 @@ class _knn_points(Function):
idx, dists = _C.knn_points_idx(p1, p2, lengths1, lengths2, K, version)
# sort KNN in ascending order if K > 1
if K > 1:
if K > 1 and return_sorted:
if lengths2.min() < K:
P1 = p1.shape[1]
mask = lengths2[:, None] <= torch.arange(K, device=dists.device)[None]
@@ -84,7 +88,7 @@ class _knn_points(Function):
grad_p1, grad_p2 = _C.knn_points_backward(
p1, p2, lengths1, lengths2, idx, grad_dists
)
return grad_p1, grad_p2, None, None, None, None
return grad_p1, grad_p2, None, None, None, None, None
def knn_points(
@@ -95,6 +99,7 @@ def knn_points(
K: int = 1,
version: int = -1,
return_nn: bool = False,
return_sorted: bool = True,
):
"""
K-Nearest neighbors on point clouds.
@@ -113,7 +118,9 @@ def knn_points(
K: Integer giving the number of nearest neighbors to return.
version: Which KNN implementation to use in the backend. If version=-1,
the correct implementation is selected based on the shapes of the inputs.
return_nn: If set to True returns the K nearest neighors in p2 for each point in p1.
return_nn: If set to True returns the K nearest neighbors in p2 for each point in p1.
return_sorted: (bool) whether to return the nearest neighbors sorted in
ascending order of distance.
Returns:
dists: Tensor of shape (N, P1, K) giving the squared distances to
@@ -158,7 +165,9 @@ def knn_points(
lengths2 = torch.full((p1.shape[0],), P2, dtype=torch.int64, device=p1.device)
# pyre-fixme[16]: `_knn_points` has no attribute `apply`.
p1_dists, p1_idx = _knn_points.apply(p1, p2, lengths1, lengths2, K, version)
p1_dists, p1_idx = _knn_points.apply(
p1, p2, lengths1, lengths2, K, version, return_sorted
)
p2_nn = None
if return_nn: