mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-19 22:00:35 +08:00
making sorting for K >1 optional in KNN points function
Summary: Added `sorted` argument to the `knn_points` function. This came up during the benchmarking against Faiss - sorting added extra memory usage. Match the memory usage of Faiss by making sorting optional. Reviewed By: bottler, gkioxari Differential Revision: D22329070 fbshipit-source-id: 0828ff9b48eefce99ce1f60089389f6885d03139
This commit is contained in:
committed by
Facebook GitHub Bot
parent
dd4a35cf9f
commit
806ca361c0
@@ -49,7 +49,7 @@ class TestKNN(TestCaseMixin, unittest.TestCase):
|
||||
|
||||
return _KNN(dists=dists, idx=idx, knn=None)
|
||||
|
||||
def _knn_vs_python_square_helper(self, device):
|
||||
def _knn_vs_python_square_helper(self, device, return_sorted):
|
||||
Ns = [1, 4]
|
||||
Ds = [3, 5, 8]
|
||||
P1s = [8, 24]
|
||||
@@ -70,7 +70,24 @@ class TestKNN(TestCaseMixin, unittest.TestCase):
|
||||
|
||||
# forward
|
||||
out1 = self._knn_points_naive(x, y, lengths1=None, lengths2=None, K=K)
|
||||
out2 = knn_points(x_cuda, y_cuda, K=K, version=version)
|
||||
out2 = knn_points(
|
||||
x_cuda, y_cuda, K=K, version=version, return_sorted=return_sorted
|
||||
)
|
||||
if K > 1 and not return_sorted:
|
||||
# check out2 is not sorted
|
||||
self.assertFalse(torch.allclose(out1[0], out2[0]))
|
||||
self.assertFalse(torch.allclose(out1[1], out2[1]))
|
||||
# now sort out2
|
||||
dists, idx, _ = out2
|
||||
if P2 < K:
|
||||
dists[..., P2:] = float("inf")
|
||||
dists, sort_idx = dists.sort(dim=2)
|
||||
dists[..., P2:] = 0
|
||||
else:
|
||||
dists, sort_idx = dists.sort(dim=2)
|
||||
idx = idx.gather(2, sort_idx)
|
||||
out2 = _KNN(dists, idx, None)
|
||||
|
||||
self.assertClose(out1[0], out2[0])
|
||||
self.assertTrue(torch.all(out1[1] == out2[1]))
|
||||
|
||||
@@ -86,11 +103,13 @@ class TestKNN(TestCaseMixin, unittest.TestCase):
|
||||
|
||||
def test_knn_vs_python_square_cpu(self):
|
||||
device = torch.device("cpu")
|
||||
self._knn_vs_python_square_helper(device)
|
||||
self._knn_vs_python_square_helper(device, return_sorted=True)
|
||||
|
||||
def test_knn_vs_python_square_cuda(self):
|
||||
device = get_random_cuda_device()
|
||||
self._knn_vs_python_square_helper(device)
|
||||
# Check both cases where the output is sorted and unsorted
|
||||
self._knn_vs_python_square_helper(device, return_sorted=True)
|
||||
self._knn_vs_python_square_helper(device, return_sorted=False)
|
||||
|
||||
def _knn_vs_python_ragged_helper(self, device):
|
||||
Ns = [1, 4]
|
||||
|
||||
Reference in New Issue
Block a user