add L1 support for KNN & Chamfer

Summary:
Added L1 norm support to the KNN and chamfer ops.
* The norm is specified with a new argument `norm`, which can only be 1 or 2
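
For example (a minimal usage sketch, not part of this diff; `chamfer_distance` lives in `pytorch3d.loss`, and per the commit title the same flag applies to the KNN op):

import torch
from pytorch3d.loss import chamfer_distance

# Two batches of pointclouds with shapes (N, P1, D) and (N, P2, D).
x = torch.rand(4, 100, 3)
y = torch.rand(4, 120, 3)

# Default behavior is unchanged: norm=2, i.e. squared Euclidean distances.
loss_l2, _ = chamfer_distance(x, y)

# New: norm=1 uses L1 distances between nearest-neighbor pairs.
loss_l1, _ = chamfer_distance(x, y, norm=1)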

Reviewed By: bottler

Differential Revision: D35419637

fbshipit-source-id: 77813fec650b30c28342af90d5ed02c89133e136
Author: Georgia Gkioxari
Date: 2022-04-10 10:27:20 -07:00
Committed by: Facebook GitHub Bot
Parent: 4b94649f7b
Commit: 67fff956a2

8 changed files with 265 additions and 129 deletions


@@ -87,7 +87,7 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
         )

     @staticmethod
-    def chamfer_distance_naive_pointclouds(p1, p2, device="cpu"):
+    def chamfer_distance_naive_pointclouds(p1, p2, norm: int = 2, device="cpu"):
         """
         Naive iterative implementation of nearest neighbor and chamfer distance.
         x and y are assumed to be pointclouds objects with points and optionally normals.
@@ -121,7 +121,14 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
         for n in range(N):
             for i1 in range(x_lengths[n]):
                 for i2 in range(y_lengths[n]):
-                    dist[n, i1, i2] = torch.sum((x[n, i1, :] - y[n, i2, :]) ** 2)
+                    if norm == 2:
+                        dist[n, i1, i2] = torch.sum((x[n, i1, :] - y[n, i2, :]) ** 2)
+                    elif norm == 1:
+                        dist[n, i1, i2] = torch.sum(
+                            torch.abs(x[n, i1, :] - y[n, i2, :])
+                        )
+                    else:
+                        raise ValueError("No support for norm %d" % (norm))

         x_dist = torch.min(dist, dim=2)[0]  # (N, P1)
         y_dist = torch.min(dist, dim=1)[0]  # (N, P2)
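
For reference, the triple loop above agrees with a vectorized computation via `torch.cdist` (a sketch for checking, not part of the diff). Note the norm=2 branch sums *squared* differences, so cdist's Euclidean output must be squared to match:

import torch

def pairwise_dist(x, y, norm: int = 2):
    # x: (N, P1, D), y: (N, P2, D) -> distances of shape (N, P1, P2).
    if norm == 2:
        # The naive loop computes squared L2 distances.
        return torch.cdist(x, y, p=2) ** 2
    elif norm == 1:
        return torch.cdist(x, y, p=1)
    raise ValueError("No support for norm %d" % norm)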
@@ -159,7 +166,7 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
         return loss, lnorm

     @staticmethod
-    def chamfer_distance_naive(x, y, x_normals=None, y_normals=None):
+    def chamfer_distance_naive(x, y, x_normals=None, y_normals=None, norm: int = 2):
         """
         Naive iterative implementation of nearest neighbor and chamfer distance.
         Returns lists of the unreduced loss and loss_normals. This naive
@@ -174,7 +181,14 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
         for n in range(N):
             for i1 in range(P1):
                 for i2 in range(P2):
-                    dist[n, i1, i2] = torch.sum((x[n, i1, :] - y[n, i2, :]) ** 2)
+                    if norm == 2:
+                        dist[n, i1, i2] = torch.sum((x[n, i1, :] - y[n, i2, :]) ** 2)
+                    elif norm == 1:
+                        dist[n, i1, i2] = torch.sum(
+                            torch.abs(x[n, i1, :] - y[n, i2, :])
+                        )
+                    else:
+                        raise ValueError("No support for norm %d" % (norm))

         loss = [
             torch.min(dist, dim=2)[0],  # (N, P1)
@@ -208,30 +222,34 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
         """
         N, max_P1, max_P2 = 7, 10, 18
         device = get_random_cuda_device()
-        points_normals = TestChamfer.init_pointclouds(N, max_P1, max_P2, device)
-        p1 = points_normals.p1
-        p2 = points_normals.p2
-        weights = points_normals.weights
-        p11 = p1.detach().clone()
-        p22 = p2.detach().clone()
-        p11.requires_grad = True
-        p22.requires_grad = True
-        P1 = p1.shape[1]
-        P2 = p2.shape[1]
-        pred_loss, pred_loss_norm = TestChamfer.chamfer_distance_naive(p1, p2)
-
-        # point_reduction = "mean".
-        loss, loss_norm = chamfer_distance(p11, p22, weights=weights)
-        pred_loss = pred_loss[0].sum(1) / P1 + pred_loss[1].sum(1) / P2
-        pred_loss *= weights
-        pred_loss = pred_loss.sum() / weights.sum()
-
-        self.assertClose(loss, pred_loss)
-        self.assertTrue(loss_norm is None)
-
-        # Check gradients
-        self._check_gradients(loss, None, pred_loss, None, p1, p11, p2, p22)
+        for norm in [1, 2]:
+            points_normals = TestChamfer.init_pointclouds(N, max_P1, max_P2, device)
+            p1 = points_normals.p1
+            p2 = points_normals.p2
+            weights = points_normals.weights
+            p11 = p1.detach().clone()
+            p22 = p2.detach().clone()
+            p11.requires_grad = True
+            p22.requires_grad = True
+            P1 = p1.shape[1]
+            P2 = p2.shape[1]
+
+            pred_loss, pred_loss_norm = TestChamfer.chamfer_distance_naive(
+                p1, p2, norm=norm
+            )
+
+            # point_reduction = "mean".
+            loss, loss_norm = chamfer_distance(p11, p22, weights=weights, norm=norm)
+            pred_loss = pred_loss[0].sum(1) / P1 + pred_loss[1].sum(1) / P2
+            pred_loss *= weights
+            pred_loss = pred_loss.sum() / weights.sum()
+
+            self.assertClose(loss, pred_loss)
+            self.assertTrue(loss_norm is None)
+
+            # Check gradients
+            self._check_gradients(loss, None, pred_loss, None, p1, p11, p2, p22)

     def test_chamfer_vs_naive_pointcloud(self):
         """
@@ -242,63 +260,67 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
         """
         N, max_P1, max_P2 = 3, 70, 70
         device = get_random_cuda_device()
-        points_normals = TestChamfer.init_pointclouds(N, max_P1, max_P2, device)
-        weights = points_normals.weights
-        x_lengths = points_normals.p1_lengths
-        y_lengths = points_normals.p2_lengths
-
-        # Chamfer with tensors as input for heterogeneous pointclouds.
-        cham_tensor, norm_tensor = chamfer_distance(
-            points_normals.p1,
-            points_normals.p2,
-            x_normals=points_normals.n1,
-            y_normals=points_normals.n2,
-            x_lengths=points_normals.p1_lengths,
-            y_lengths=points_normals.p2_lengths,
-            weights=weights,
-        )
-
-        # Chamfer with pointclouds as input.
-        pred_loss, pred_norm_loss = TestChamfer.chamfer_distance_naive_pointclouds(
-            points_normals.cloud1, points_normals.cloud2, device=device
-        )
-
-        # Mean reduction point loss.
-        pred_loss[0] *= weights.view(N, 1)
-        pred_loss[1] *= weights.view(N, 1)
-        pred_loss_mean = (
-            pred_loss[0].sum(1) / x_lengths + pred_loss[1].sum(1) / y_lengths
-        )
-        pred_loss_mean = pred_loss_mean.sum()
-        pred_loss_mean /= weights.sum()
-
-        # Mean reduction norm loss.
-        pred_norm_loss[0] *= weights.view(N, 1)
-        pred_norm_loss[1] *= weights.view(N, 1)
-        pred_norm_loss_mean = (
-            pred_norm_loss[0].sum(1) / x_lengths + pred_norm_loss[1].sum(1) / y_lengths
-        )
-        pred_norm_loss_mean = pred_norm_loss_mean.sum() / weights.sum()
-
-        self.assertClose(pred_loss_mean, cham_tensor)
-        self.assertClose(pred_norm_loss_mean, norm_tensor)
-
-        self._check_gradients(
-            cham_tensor,
-            norm_tensor,
-            pred_loss_mean,
-            pred_norm_loss_mean,
-            points_normals.cloud1.points_list(),
-            points_normals.p1,
-            points_normals.cloud2.points_list(),
-            points_normals.p2,
-            points_normals.cloud1.normals_list(),
-            points_normals.n1,
-            points_normals.cloud2.normals_list(),
-            points_normals.n2,
-            x_lengths,
-            y_lengths,
-        )
+        for norm in [1, 2]:
+            points_normals = TestChamfer.init_pointclouds(N, max_P1, max_P2, device)
+            weights = points_normals.weights
+            x_lengths = points_normals.p1_lengths
+            y_lengths = points_normals.p2_lengths
+
+            # Chamfer with tensors as input for heterogeneous pointclouds.
+            cham_tensor, norm_tensor = chamfer_distance(
+                points_normals.p1,
+                points_normals.p2,
+                x_normals=points_normals.n1,
+                y_normals=points_normals.n2,
+                x_lengths=points_normals.p1_lengths,
+                y_lengths=points_normals.p2_lengths,
+                weights=weights,
+                norm=norm,
+            )
+
+            # Chamfer with pointclouds as input.
+            pred_loss, pred_norm_loss = TestChamfer.chamfer_distance_naive_pointclouds(
+                points_normals.cloud1, points_normals.cloud2, norm=norm, device=device
+            )
+
+            # Mean reduction point loss.
+            pred_loss[0] *= weights.view(N, 1)
+            pred_loss[1] *= weights.view(N, 1)
+            pred_loss_mean = (
+                pred_loss[0].sum(1) / x_lengths + pred_loss[1].sum(1) / y_lengths
+            )
+            pred_loss_mean = pred_loss_mean.sum()
+            pred_loss_mean /= weights.sum()
+
+            # Mean reduction norm loss.
+            pred_norm_loss[0] *= weights.view(N, 1)
+            pred_norm_loss[1] *= weights.view(N, 1)
+            pred_norm_loss_mean = (
+                pred_norm_loss[0].sum(1) / x_lengths
+                + pred_norm_loss[1].sum(1) / y_lengths
+            )
+            pred_norm_loss_mean = pred_norm_loss_mean.sum() / weights.sum()
+
+            self.assertClose(pred_loss_mean, cham_tensor)
+            self.assertClose(pred_norm_loss_mean, norm_tensor)
+
+            self._check_gradients(
+                cham_tensor,
+                norm_tensor,
+                pred_loss_mean,
+                pred_norm_loss_mean,
+                points_normals.cloud1.points_list(),
+                points_normals.p1,
+                points_normals.cloud2.points_list(),
+                points_normals.p2,
+                points_normals.cloud1.normals_list(),
+                points_normals.n1,
+                points_normals.cloud2.normals_list(),
+                points_normals.n2,
+                x_lengths,
+                y_lengths,
+            )

     def test_chamfer_pointcloud_object_withnormals(self):
         N = 5
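
Note: relative to the homogeneous test above, the only change in the heterogeneous mean reduction is that the fixed sizes P1 and P2 are replaced by the per-cloud lengths (same notation as before, with L^x_n = x_lengths[n]):

\[
\mathrm{cham}(x^{(n)}, y^{(n)}) = \frac{1}{L^x_n}\sum_{i}\min_{j} d(x^{(n)}_i, y^{(n)}_j) + \frac{1}{L^y_n}\sum_{j}\min_{i} d(y^{(n)}_j, x^{(n)}_i)
\]

which is exactly what `pred_loss[0].sum(1) / x_lengths + pred_loss[1].sum(1) / y_lengths` computes.
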
@@ -742,6 +764,19 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
         with self.assertRaisesRegex(ValueError, "Pointclouds objects or torch.Tensor"):
             chamfer_distance(x=[1, 1, 1], y=[1, 1, 1])

+    def test_invalid_norm(self):
+        N, P1, P2 = 7, 10, 18
+        device = get_random_cuda_device()
+        points_normals = TestChamfer.init_pointclouds(N, P1, P2, device)
+        p1 = points_normals.p1
+        p2 = points_normals.p2
+        with self.assertRaisesRegex(ValueError, "Support for 1 or 2 norm."):
+            chamfer_distance(p1, p2, norm=0)
+
+        with self.assertRaisesRegex(ValueError, "Support for 1 or 2 norm."):
+            chamfer_distance(p1, p2, norm=3)
+
     @staticmethod
     def chamfer_with_init(
         batch_size: int,
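
The regex asserted in `test_invalid_norm` pins down the user-facing error. A minimal sketch of the validation it exercises (the condition and its placement inside `chamfer_distance` are assumptions; only the message is confirmed by this test):

if norm not in [1, 2]:
    raise ValueError("Support for 1 or 2 norm.")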