Single directional chamfer distance and non-absolute cosine similarity

Summary: Single directional chamfer distance and option to use non-absolute cosine similarity

Reviewed By: bottler

Differential Revision: D46593980

fbshipit-source-id: b2e591706a0cdde1c2d361614cecebb84a581433
This commit is contained in:
Norman Mueller
2023-06-13 09:09:15 -07:00
committed by Facebook GitHub Bot
parent 573a42cd5f
commit 5ffeb4d580
2 changed files with 326 additions and 121 deletions

View File

@@ -88,7 +88,9 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
)
@staticmethod
def chamfer_distance_naive_pointclouds(p1, p2, norm: int = 2, device="cpu"):
def chamfer_distance_naive_pointclouds(
p1, p2, norm: int = 2, device="cpu", abs_cosine=True
):
"""
Naive iterative implementation of nearest neighbor and chamfer distance.
x and y are assumed to be pointclouds objects with points and optionally normals.
@@ -146,17 +148,20 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
if return_normals:
x_index = dist.argmin(2).view(N, P1, 1).expand(N, P1, 3)
y_index = dist.argmin(1).view(N, P2, 1).expand(N, P2, 3)
lnorm1 = 1 - torch.abs(
F.cosine_similarity(
x_normals, y_normals.gather(1, x_index), dim=2, eps=1e-6
)
cosine_sim1 = F.cosine_similarity(
x_normals, y_normals.gather(1, x_index), dim=2, eps=1e-6
)
lnorm2 = 1 - torch.abs(
F.cosine_similarity(
y_normals, x_normals.gather(1, y_index), dim=2, eps=1e-6
)
cosine_sim2 = F.cosine_similarity(
y_normals, x_normals.gather(1, y_index), dim=2, eps=1e-6
)
if abs_cosine:
lnorm1 = 1 - torch.abs(cosine_sim1)
lnorm2 = 1 - torch.abs(cosine_sim2)
else:
lnorm1 = 1 - cosine_sim1
lnorm2 = 1 - cosine_sim2
if is_x_heterogeneous:
lnorm1[x_mask] = 0.0
if is_y_heterogeneous:
@@ -167,7 +172,9 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
return loss, lnorm
@staticmethod
def chamfer_distance_naive(x, y, x_normals=None, y_normals=None, norm: int = 2):
def chamfer_distance_naive(
x, y, x_normals=None, y_normals=None, norm: int = 2, abs_cosine=True
):
"""
Naive iterative implementation of nearest neighbor and chamfer distance.
Returns lists of the unreduced loss and loss_normals. This naive
@@ -200,16 +207,21 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
if return_normals:
x_index = dist.argmin(2).view(N, P1, 1).expand(N, P1, 3)
y_index = dist.argmin(1).view(N, P2, 1).expand(N, P2, 3)
lnorm1 = 1 - torch.abs(
F.cosine_similarity(
x_normals, y_normals.gather(1, x_index), dim=2, eps=1e-6
)
cosine_sim1 = F.cosine_similarity(
x_normals, y_normals.gather(1, x_index), dim=2, eps=1e-6
)
lnorm2 = 1 - torch.abs(
F.cosine_similarity(
y_normals, x_normals.gather(1, y_index), dim=2, eps=1e-6
)
cosine_sim2 = F.cosine_similarity(
y_normals, x_normals.gather(1, y_index), dim=2, eps=1e-6
)
if abs_cosine:
lnorm1 = 1 - torch.abs(cosine_sim1)
lnorm2 = 1 - torch.abs(cosine_sim2)
else:
lnorm1 = 1 - cosine_sim1
lnorm2 = 1 - cosine_sim2
lnorm = [lnorm1, lnorm2] # [(N, P1), (N, P2)]
return loss, lnorm
@@ -323,6 +335,80 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
y_lengths,
)
def test_single_directional_chamfer_vs_naive_pointcloud(self):
"""
Test the single directional settings for chamfer_distance
(point reduction = "mean" and batch_reduction="mean") but with heterogeneous
pointclouds as input. Compare with the naive implementation of chamfer
which supports heterogeneous pointcloud objects.
"""
N, max_P1, max_P2 = 3, 70, 70
device = get_random_cuda_device()
for norm in [1, 2]:
for abs_cosine in [True, False]:
points_normals = TestChamfer.init_pointclouds(N, max_P1, max_P2, device)
weights = points_normals.weights
x_lengths = points_normals.p1_lengths
y_lengths = points_normals.p2_lengths
# Chamfer with tensors as input for heterogeneous pointclouds.
cham_tensor, norm_tensor = chamfer_distance(
points_normals.p1,
points_normals.p2,
x_normals=points_normals.n1,
y_normals=points_normals.n2,
x_lengths=points_normals.p1_lengths,
y_lengths=points_normals.p2_lengths,
weights=weights,
norm=norm,
single_directional=True,
abs_cosine=abs_cosine,
)
# Chamfer with pointclouds as input.
(
pred_loss,
pred_norm_loss,
) = TestChamfer.chamfer_distance_naive_pointclouds(
points_normals.cloud1,
points_normals.cloud2,
norm=norm,
device=device,
abs_cosine=abs_cosine,
)
# Mean reduction point loss.
pred_loss[0] *= weights.view(N, 1)
pred_loss_mean = pred_loss[0].sum(1) / x_lengths
pred_loss_mean = pred_loss_mean.sum()
pred_loss_mean /= weights.sum()
# Mean reduction norm loss.
pred_norm_loss[0] *= weights.view(N, 1)
pred_norm_loss_mean = pred_norm_loss[0].sum(1) / x_lengths
pred_norm_loss_mean = pred_norm_loss_mean.sum() / weights.sum()
self.assertClose(pred_loss_mean, cham_tensor)
self.assertClose(pred_norm_loss_mean, norm_tensor)
self._check_gradients(
cham_tensor,
norm_tensor,
pred_loss_mean,
pred_norm_loss_mean,
points_normals.cloud1.points_list(),
points_normals.p1,
points_normals.cloud2.points_list(),
points_normals.p2,
points_normals.cloud1.normals_list(),
points_normals.n1,
points_normals.cloud2.normals_list(),
points_normals.n2,
x_lengths,
y_lengths,
)
def test_chamfer_pointcloud_object_withnormals(self):
N = 5
P1, P2 = 100, 100
@@ -485,6 +571,53 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
loss, loss_norm, pred_loss_mean, pred_loss_norm_mean, p1, p11, p2, p22
)
def test_single_direction_chamfer_point_reduction_mean(self):
"""
Compare output of vectorized chamfer loss with naive implementation
for point_reduction = "mean" and batch_reduction = None.
"""
N, max_P1, max_P2 = 7, 10, 18
device = get_random_cuda_device()
points_normals = TestChamfer.init_pointclouds(N, max_P1, max_P2, device)
p1 = points_normals.p1
p2 = points_normals.p2
p1_normals = points_normals.n1
p2_normals = points_normals.n2
weights = points_normals.weights
p11 = p1.detach().clone()
p22 = p2.detach().clone()
p11.requires_grad = True
p22.requires_grad = True
P1 = p1.shape[1]
pred_loss, pred_loss_norm = TestChamfer.chamfer_distance_naive(
p1, p2, x_normals=p1_normals, y_normals=p2_normals
)
# point_reduction = "mean".
loss, loss_norm = chamfer_distance(
p11,
p22,
x_normals=p1_normals,
y_normals=p2_normals,
weights=weights,
batch_reduction=None,
point_reduction="mean",
single_directional=True,
)
pred_loss_mean = pred_loss[0].sum(1) / P1
pred_loss_mean *= weights
self.assertClose(loss, pred_loss_mean)
pred_loss_norm_mean = pred_loss_norm[0].sum(1) / P1
pred_loss_norm_mean *= weights
self.assertClose(loss_norm, pred_loss_norm_mean)
# Check gradients
self._check_gradients(
loss, loss_norm, pred_loss_mean, pred_loss_norm_mean, p1, p11, p2, p22
)
def test_chamfer_point_reduction_sum(self):
"""
Compare output of vectorized chamfer loss with naive implementation
@@ -529,6 +662,51 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
loss, loss_norm, pred_loss_sum, pred_loss_norm_sum, p1, p11, p2, p22
)
def test_single_directional_chamfer_point_reduction_sum(self):
"""
Compare output of vectorized single directional chamfer loss with naive implementation
for point_reduction = "sum" and batch_reduction = None.
"""
N, P1, P2 = 7, 10, 18
device = get_random_cuda_device()
points_normals = TestChamfer.init_pointclouds(N, P1, P2, device)
p1 = points_normals.p1
p2 = points_normals.p2
p1_normals = points_normals.n1
p2_normals = points_normals.n2
weights = points_normals.weights
p11 = p1.detach().clone()
p22 = p2.detach().clone()
p11.requires_grad = True
p22.requires_grad = True
pred_loss, pred_loss_norm = TestChamfer.chamfer_distance_naive(
p1, p2, x_normals=p1_normals, y_normals=p2_normals
)
loss, loss_norm = chamfer_distance(
p11,
p22,
x_normals=p1_normals,
y_normals=p2_normals,
weights=weights,
batch_reduction=None,
point_reduction="sum",
single_directional=True,
)
pred_loss_sum = pred_loss[0].sum(1)
pred_loss_sum *= weights
self.assertClose(loss, pred_loss_sum)
pred_loss_norm_sum = pred_loss_norm[0].sum(1)
pred_loss_norm_sum *= weights
self.assertClose(loss_norm, pred_loss_norm_sum)
# Check gradients
self._check_gradients(
loss, loss_norm, pred_loss_sum, pred_loss_norm_sum, p1, p11, p2, p22
)
def _check_gradients(
self,
loss,