add L1 support for KNN & Chamfer

Summary:
Added L1 norm for KNN and chamfer op
* The norm is now specified with a variable `norm` which can only be 1 or 2

Reviewed By: bottler

Differential Revision: D35419637

fbshipit-source-id: 77813fec650b30c28342af90d5ed02c89133e136
This commit is contained in:
Georgia Gkioxari
2022-04-10 10:27:20 -07:00
committed by Facebook GitHub Bot
parent 4b94649f7b
commit 67fff956a2
8 changed files with 265 additions and 129 deletions

View File

@@ -87,7 +87,7 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
)
@staticmethod
def chamfer_distance_naive_pointclouds(p1, p2, device="cpu"):
def chamfer_distance_naive_pointclouds(p1, p2, norm: int = 2, device="cpu"):
"""
Naive iterative implementation of nearest neighbor and chamfer distance.
x and y are assumed to be pointclouds objects with points and optionally normals.
@@ -121,7 +121,14 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
for n in range(N):
for i1 in range(x_lengths[n]):
for i2 in range(y_lengths[n]):
dist[n, i1, i2] = torch.sum((x[n, i1, :] - y[n, i2, :]) ** 2)
if norm == 2:
dist[n, i1, i2] = torch.sum((x[n, i1, :] - y[n, i2, :]) ** 2)
elif norm == 1:
dist[n, i1, i2] = torch.sum(
torch.abs(x[n, i1, :] - y[n, i2, :])
)
else:
raise ValueError("No support for norm %d" % (norm))
x_dist = torch.min(dist, dim=2)[0] # (N, P1)
y_dist = torch.min(dist, dim=1)[0] # (N, P2)
@@ -159,7 +166,7 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
return loss, lnorm
@staticmethod
def chamfer_distance_naive(x, y, x_normals=None, y_normals=None):
def chamfer_distance_naive(x, y, x_normals=None, y_normals=None, norm: int = 2):
"""
Naive iterative implementation of nearest neighbor and chamfer distance.
Returns lists of the unreduced loss and loss_normals. This naive
@@ -174,7 +181,14 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
for n in range(N):
for i1 in range(P1):
for i2 in range(P2):
dist[n, i1, i2] = torch.sum((x[n, i1, :] - y[n, i2, :]) ** 2)
if norm == 2:
dist[n, i1, i2] = torch.sum((x[n, i1, :] - y[n, i2, :]) ** 2)
elif norm == 1:
dist[n, i1, i2] = torch.sum(
torch.abs(x[n, i1, :] - y[n, i2, :])
)
else:
raise ValueError("No support for norm %d" % (norm))
loss = [
torch.min(dist, dim=2)[0], # (N, P1)
@@ -208,30 +222,34 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
"""
N, max_P1, max_P2 = 7, 10, 18
device = get_random_cuda_device()
points_normals = TestChamfer.init_pointclouds(N, max_P1, max_P2, device)
p1 = points_normals.p1
p2 = points_normals.p2
weights = points_normals.weights
p11 = p1.detach().clone()
p22 = p2.detach().clone()
p11.requires_grad = True
p22.requires_grad = True
P1 = p1.shape[1]
P2 = p2.shape[1]
pred_loss, pred_loss_norm = TestChamfer.chamfer_distance_naive(p1, p2)
for norm in [1, 2]:
points_normals = TestChamfer.init_pointclouds(N, max_P1, max_P2, device)
p1 = points_normals.p1
p2 = points_normals.p2
weights = points_normals.weights
p11 = p1.detach().clone()
p22 = p2.detach().clone()
p11.requires_grad = True
p22.requires_grad = True
P1 = p1.shape[1]
P2 = p2.shape[1]
# point_reduction = "mean".
loss, loss_norm = chamfer_distance(p11, p22, weights=weights)
pred_loss = pred_loss[0].sum(1) / P1 + pred_loss[1].sum(1) / P2
pred_loss *= weights
pred_loss = pred_loss.sum() / weights.sum()
pred_loss, pred_loss_norm = TestChamfer.chamfer_distance_naive(
p1, p2, norm=norm
)
self.assertClose(loss, pred_loss)
self.assertTrue(loss_norm is None)
# point_reduction = "mean".
loss, loss_norm = chamfer_distance(p11, p22, weights=weights, norm=norm)
pred_loss = pred_loss[0].sum(1) / P1 + pred_loss[1].sum(1) / P2
pred_loss *= weights
pred_loss = pred_loss.sum() / weights.sum()
# Check gradients
self._check_gradients(loss, None, pred_loss, None, p1, p11, p2, p22)
self.assertClose(loss, pred_loss)
self.assertTrue(loss_norm is None)
# Check gradients
self._check_gradients(loss, None, pred_loss, None, p1, p11, p2, p22)
def test_chamfer_vs_naive_pointcloud(self):
"""
@@ -242,63 +260,67 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
"""
N, max_P1, max_P2 = 3, 70, 70
device = get_random_cuda_device()
points_normals = TestChamfer.init_pointclouds(N, max_P1, max_P2, device)
weights = points_normals.weights
x_lengths = points_normals.p1_lengths
y_lengths = points_normals.p2_lengths
# Chamfer with tensors as input for heterogeneous pointclouds.
cham_tensor, norm_tensor = chamfer_distance(
points_normals.p1,
points_normals.p2,
x_normals=points_normals.n1,
y_normals=points_normals.n2,
x_lengths=points_normals.p1_lengths,
y_lengths=points_normals.p2_lengths,
weights=weights,
)
for norm in [1, 2]:
points_normals = TestChamfer.init_pointclouds(N, max_P1, max_P2, device)
weights = points_normals.weights
x_lengths = points_normals.p1_lengths
y_lengths = points_normals.p2_lengths
# Chamfer with pointclouds as input.
pred_loss, pred_norm_loss = TestChamfer.chamfer_distance_naive_pointclouds(
points_normals.cloud1, points_normals.cloud2, device=device
)
# Chamfer with tensors as input for heterogeneous pointclouds.
cham_tensor, norm_tensor = chamfer_distance(
points_normals.p1,
points_normals.p2,
x_normals=points_normals.n1,
y_normals=points_normals.n2,
x_lengths=points_normals.p1_lengths,
y_lengths=points_normals.p2_lengths,
weights=weights,
norm=norm,
)
# Mean reduction point loss.
pred_loss[0] *= weights.view(N, 1)
pred_loss[1] *= weights.view(N, 1)
pred_loss_mean = (
pred_loss[0].sum(1) / x_lengths + pred_loss[1].sum(1) / y_lengths
)
pred_loss_mean = pred_loss_mean.sum()
pred_loss_mean /= weights.sum()
# Chamfer with pointclouds as input.
pred_loss, pred_norm_loss = TestChamfer.chamfer_distance_naive_pointclouds(
points_normals.cloud1, points_normals.cloud2, norm=norm, device=device
)
# Mean reduction norm loss.
pred_norm_loss[0] *= weights.view(N, 1)
pred_norm_loss[1] *= weights.view(N, 1)
pred_norm_loss_mean = (
pred_norm_loss[0].sum(1) / x_lengths + pred_norm_loss[1].sum(1) / y_lengths
)
pred_norm_loss_mean = pred_norm_loss_mean.sum() / weights.sum()
# Mean reduction point loss.
pred_loss[0] *= weights.view(N, 1)
pred_loss[1] *= weights.view(N, 1)
pred_loss_mean = (
pred_loss[0].sum(1) / x_lengths + pred_loss[1].sum(1) / y_lengths
)
pred_loss_mean = pred_loss_mean.sum()
pred_loss_mean /= weights.sum()
self.assertClose(pred_loss_mean, cham_tensor)
self.assertClose(pred_norm_loss_mean, norm_tensor)
# Mean reduction norm loss.
pred_norm_loss[0] *= weights.view(N, 1)
pred_norm_loss[1] *= weights.view(N, 1)
pred_norm_loss_mean = (
pred_norm_loss[0].sum(1) / x_lengths
+ pred_norm_loss[1].sum(1) / y_lengths
)
pred_norm_loss_mean = pred_norm_loss_mean.sum() / weights.sum()
self._check_gradients(
cham_tensor,
norm_tensor,
pred_loss_mean,
pred_norm_loss_mean,
points_normals.cloud1.points_list(),
points_normals.p1,
points_normals.cloud2.points_list(),
points_normals.p2,
points_normals.cloud1.normals_list(),
points_normals.n1,
points_normals.cloud2.normals_list(),
points_normals.n2,
x_lengths,
y_lengths,
)
self.assertClose(pred_loss_mean, cham_tensor)
self.assertClose(pred_norm_loss_mean, norm_tensor)
self._check_gradients(
cham_tensor,
norm_tensor,
pred_loss_mean,
pred_norm_loss_mean,
points_normals.cloud1.points_list(),
points_normals.p1,
points_normals.cloud2.points_list(),
points_normals.p2,
points_normals.cloud1.normals_list(),
points_normals.n1,
points_normals.cloud2.normals_list(),
points_normals.n2,
x_lengths,
y_lengths,
)
def test_chamfer_pointcloud_object_withnormals(self):
N = 5
@@ -742,6 +764,19 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
with self.assertRaisesRegex(ValueError, "Pointclouds objects or torch.Tensor"):
chamfer_distance(x=[1, 1, 1], y=[1, 1, 1])
def test_invalid_norm(self):
N, P1, P2 = 7, 10, 18
device = get_random_cuda_device()
points_normals = TestChamfer.init_pointclouds(N, P1, P2, device)
p1 = points_normals.p1
p2 = points_normals.p2
with self.assertRaisesRegex(ValueError, "Support for 1 or 2 norm."):
chamfer_distance(p1, p2, norm=0)
with self.assertRaisesRegex(ValueError, "Support for 1 or 2 norm."):
chamfer_distance(p1, p2, norm=3)
@staticmethod
def chamfer_with_init(
batch_size: int,

View File

@@ -18,7 +18,9 @@ class TestKNN(TestCaseMixin, unittest.TestCase):
torch.manual_seed(1)
@staticmethod
def _knn_points_naive(p1, p2, lengths1, lengths2, K: int) -> torch.Tensor:
def _knn_points_naive(
p1, p2, lengths1, lengths2, K: int, norm: int = 2
) -> torch.Tensor:
"""
Naive PyTorch implementation of K-Nearest Neighbors.
Returns always sorted results
@@ -42,7 +44,12 @@ class TestKNN(TestCaseMixin, unittest.TestCase):
pp1 = p1[n, :num1].view(num1, 1, D)
pp2 = p2[n, :num2].view(1, num2, D)
diff = pp1 - pp2
diff = (diff * diff).sum(2)
if norm == 2:
diff = (diff * diff).sum(2)
elif norm == 1:
diff = diff.abs().sum(2)
else:
raise ValueError("No support for norm %d" % (norm))
num2 = min(num2, K)
for i in range(num1):
dd = diff[i]
@@ -59,9 +66,10 @@ class TestKNN(TestCaseMixin, unittest.TestCase):
P1s = [8, 24]
P2s = [8, 16, 32]
Ks = [1, 3, 10]
norms = [1, 2]
versions = [0, 1, 2, 3]
factors = [Ns, Ds, P1s, P2s, Ks]
for N, D, P1, P2, K in product(*factors):
factors = [Ns, Ds, P1s, P2s, Ks, norms]
for N, D, P1, P2, K, norm in product(*factors):
for version in versions:
if version == 3 and K > 4:
continue
@@ -73,9 +81,16 @@ class TestKNN(TestCaseMixin, unittest.TestCase):
y_cuda.requires_grad_(True)
# forward
out1 = self._knn_points_naive(x, y, lengths1=None, lengths2=None, K=K)
out1 = self._knn_points_naive(
x, y, lengths1=None, lengths2=None, K=K, norm=norm
)
out2 = knn_points(
x_cuda, y_cuda, K=K, version=version, return_sorted=return_sorted
x_cuda,
y_cuda,
K=K,
norm=norm,
version=version,
return_sorted=return_sorted,
)
if K > 1 and not return_sorted:
# check out2 is not sorted
@@ -121,8 +136,9 @@ class TestKNN(TestCaseMixin, unittest.TestCase):
P1s = [8, 24]
P2s = [8, 16, 32]
Ks = [1, 3, 10]
factors = [Ns, Ds, P1s, P2s, Ks]
for N, D, P1, P2, K in product(*factors):
norms = [1, 2]
factors = [Ns, Ds, P1s, P2s, Ks, norms]
for N, D, P1, P2, K, norm in product(*factors):
x = torch.rand((N, P1, D), device=device, requires_grad=True)
y = torch.rand((N, P2, D), device=device, requires_grad=True)
lengths1 = torch.randint(low=1, high=P1, size=(N,), device=device)
@@ -135,9 +151,11 @@ class TestKNN(TestCaseMixin, unittest.TestCase):
# forward
out1 = self._knn_points_naive(
x, y, lengths1=lengths1, lengths2=lengths2, K=K
x, y, lengths1=lengths1, lengths2=lengths2, K=K, norm=norm
)
out2 = knn_points(
x_csrc, y_csrc, lengths1=lengths1, lengths2=lengths2, K=K, norm=norm
)
out2 = knn_points(x_csrc, y_csrc, lengths1=lengths1, lengths2=lengths2, K=K)
self.assertClose(out1[0], out2[0])
self.assertTrue(torch.all(out1[1] == out2[1]))
@@ -198,6 +216,17 @@ class TestKNN(TestCaseMixin, unittest.TestCase):
expected = all_expected[version]
self.assertEqual(actual, expected)
def test_invalid_norm(self):
device = get_random_cuda_device()
N, P1, P2, K, D = 4, 16, 12, 8, 3
x = torch.rand((N, P1, D), device=device)
y = torch.rand((N, P2, D), device=device)
with self.assertRaisesRegex(ValueError, "Support for 1 or 2 norm."):
knn_points(x, y, K=K, norm=3)
with self.assertRaisesRegex(ValueError, "Support for 1 or 2 norm."):
knn_points(x, y, K=K, norm=0)
@staticmethod
def knn_square(N: int, P1: int, P2: int, D: int, K: int, device: str):
device = torch.device(device)