add L1 support for KNN & Chamfer

Summary:
Added L1 norm for KNN and chamfer op
* The norm is now specified with a variable `norm` which can only be 1 or 2

Reviewed By: bottler

Differential Revision: D35419637

fbshipit-source-id: 77813fec650b30c28342af90d5ed02c89133e136
This commit is contained in:
Georgia Gkioxari
2022-04-10 10:27:20 -07:00
committed by Facebook GitHub Bot
parent 4b94649f7b
commit 67fff956a2
8 changed files with 265 additions and 129 deletions

View File

@@ -18,7 +18,9 @@ class TestKNN(TestCaseMixin, unittest.TestCase):
torch.manual_seed(1)
@staticmethod
def _knn_points_naive(p1, p2, lengths1, lengths2, K: int) -> torch.Tensor:
def _knn_points_naive(
p1, p2, lengths1, lengths2, K: int, norm: int = 2
) -> torch.Tensor:
"""
Naive PyTorch implementation of K-Nearest Neighbors.
Returns always sorted results
@@ -42,7 +44,12 @@ class TestKNN(TestCaseMixin, unittest.TestCase):
pp1 = p1[n, :num1].view(num1, 1, D)
pp2 = p2[n, :num2].view(1, num2, D)
diff = pp1 - pp2
diff = (diff * diff).sum(2)
if norm == 2:
diff = (diff * diff).sum(2)
elif norm == 1:
diff = diff.abs().sum(2)
else:
raise ValueError("No support for norm %d" % (norm))
num2 = min(num2, K)
for i in range(num1):
dd = diff[i]
@@ -59,9 +66,10 @@ class TestKNN(TestCaseMixin, unittest.TestCase):
P1s = [8, 24]
P2s = [8, 16, 32]
Ks = [1, 3, 10]
norms = [1, 2]
versions = [0, 1, 2, 3]
factors = [Ns, Ds, P1s, P2s, Ks]
for N, D, P1, P2, K in product(*factors):
factors = [Ns, Ds, P1s, P2s, Ks, norms]
for N, D, P1, P2, K, norm in product(*factors):
for version in versions:
if version == 3 and K > 4:
continue
@@ -73,9 +81,16 @@ class TestKNN(TestCaseMixin, unittest.TestCase):
y_cuda.requires_grad_(True)
# forward
out1 = self._knn_points_naive(x, y, lengths1=None, lengths2=None, K=K)
out1 = self._knn_points_naive(
x, y, lengths1=None, lengths2=None, K=K, norm=norm
)
out2 = knn_points(
x_cuda, y_cuda, K=K, version=version, return_sorted=return_sorted
x_cuda,
y_cuda,
K=K,
norm=norm,
version=version,
return_sorted=return_sorted,
)
if K > 1 and not return_sorted:
# check out2 is not sorted
@@ -121,8 +136,9 @@ class TestKNN(TestCaseMixin, unittest.TestCase):
P1s = [8, 24]
P2s = [8, 16, 32]
Ks = [1, 3, 10]
factors = [Ns, Ds, P1s, P2s, Ks]
for N, D, P1, P2, K in product(*factors):
norms = [1, 2]
factors = [Ns, Ds, P1s, P2s, Ks, norms]
for N, D, P1, P2, K, norm in product(*factors):
x = torch.rand((N, P1, D), device=device, requires_grad=True)
y = torch.rand((N, P2, D), device=device, requires_grad=True)
lengths1 = torch.randint(low=1, high=P1, size=(N,), device=device)
@@ -135,9 +151,11 @@ class TestKNN(TestCaseMixin, unittest.TestCase):
# forward
out1 = self._knn_points_naive(
x, y, lengths1=lengths1, lengths2=lengths2, K=K
x, y, lengths1=lengths1, lengths2=lengths2, K=K, norm=norm
)
out2 = knn_points(
x_csrc, y_csrc, lengths1=lengths1, lengths2=lengths2, K=K, norm=norm
)
out2 = knn_points(x_csrc, y_csrc, lengths1=lengths1, lengths2=lengths2, K=K)
self.assertClose(out1[0], out2[0])
self.assertTrue(torch.all(out1[1] == out2[1]))
@@ -198,6 +216,17 @@ class TestKNN(TestCaseMixin, unittest.TestCase):
expected = all_expected[version]
self.assertEqual(actual, expected)
def test_invalid_norm(self):
device = get_random_cuda_device()
N, P1, P2, K, D = 4, 16, 12, 8, 3
x = torch.rand((N, P1, D), device=device)
y = torch.rand((N, P2, D), device=device)
with self.assertRaisesRegex(ValueError, "Support for 1 or 2 norm."):
knn_points(x, y, K=K, norm=3)
with self.assertRaisesRegex(ValueError, "Support for 1 or 2 norm."):
knn_points(x, y, K=K, norm=0)
@staticmethod
def knn_square(N: int, P1: int, P2: int, D: int, K: int, device: str):
device = torch.device(device)