Single directional chamfer distance and non-absolute cosine similarity

Summary: Single directional chamfer distance and option to use non-absolute cosine similarity

Reviewed By: bottler

Differential Revision: D46593980

fbshipit-source-id: b2e591706a0cdde1c2d361614cecebb84a581433
This commit is contained in:
Norman Mueller 2023-06-13 09:09:15 -07:00 committed by Facebook GitHub Bot
parent 573a42cd5f
commit 5ffeb4d580
2 changed files with 326 additions and 121 deletions

View File

@ -68,6 +68,94 @@ def _handle_pointcloud_input(
return X, lengths, normals
def _chamfer_distance_single_direction(
x,
y,
x_lengths,
y_lengths,
x_normals,
y_normals,
weights,
batch_reduction: Union[str, None],
point_reduction: str,
norm: int,
abs_cosine: bool,
):
return_normals = x_normals is not None and y_normals is not None
N, P1, D = x.shape
# Check if inputs are heterogeneous and create a lengths mask.
is_x_heterogeneous = (x_lengths != P1).any()
x_mask = (
torch.arange(P1, device=x.device)[None] >= x_lengths[:, None]
) # shape [N, P1]
if y.shape[0] != N or y.shape[2] != D:
raise ValueError("y does not have the correct shape.")
if weights is not None:
if weights.size(0) != N:
raise ValueError("weights must be of shape (N,).")
if not (weights >= 0).all():
raise ValueError("weights cannot be negative.")
if weights.sum() == 0.0:
weights = weights.view(N, 1)
if batch_reduction in ["mean", "sum"]:
return (
(x.sum((1, 2)) * weights).sum() * 0.0,
(x.sum((1, 2)) * weights).sum() * 0.0,
)
return ((x.sum((1, 2)) * weights) * 0.0, (x.sum((1, 2)) * weights) * 0.0)
cham_norm_x = x.new_zeros(())
x_nn = knn_points(x, y, lengths1=x_lengths, lengths2=y_lengths, norm=norm, K=1)
cham_x = x_nn.dists[..., 0] # (N, P1)
if is_x_heterogeneous:
cham_x[x_mask] = 0.0
if weights is not None:
cham_x *= weights.view(N, 1)
if return_normals:
# Gather the normals using the indices and keep only value for k=0
x_normals_near = knn_gather(y_normals, x_nn.idx, y_lengths)[..., 0, :]
cosine_sim = F.cosine_similarity(x_normals, x_normals_near, dim=2, eps=1e-6)
# If abs_cosine, ignore orientation and take the absolute value of the cosine sim.
cham_norm_x = 1 - (torch.abs(cosine_sim) if abs_cosine else cosine_sim)
if is_x_heterogeneous:
cham_norm_x[x_mask] = 0.0
if weights is not None:
cham_norm_x *= weights.view(N, 1)
cham_norm_x = cham_norm_x.sum(1) # (N,)
# Apply point reduction
cham_x = cham_x.sum(1) # (N,)
if point_reduction == "mean":
x_lengths_clamped = x_lengths.clamp(min=1)
cham_x /= x_lengths_clamped
if return_normals:
cham_norm_x /= x_lengths_clamped
if batch_reduction is not None:
# batch_reduction == "sum"
cham_x = cham_x.sum()
if return_normals:
cham_norm_x = cham_norm_x.sum()
if batch_reduction == "mean":
div = weights.sum() if weights is not None else max(N, 1)
cham_x /= div
if return_normals:
cham_norm_x /= div
cham_dist = cham_x
cham_normals = cham_norm_x if return_normals else None
return cham_dist, cham_normals
def chamfer_distance(
x,
y,
@ -79,6 +167,8 @@ def chamfer_distance(
batch_reduction: Union[str, None] = "mean",
point_reduction: str = "mean",
norm: int = 2,
single_directional: bool = False,
abs_cosine: bool = True,
):
"""
Chamfer distance between two pointclouds x and y.
@ -103,6 +193,14 @@ def chamfer_distance(
point_reduction: Reduction operation to apply for the loss across the
points, can be one of ["mean", "sum"].
norm: int indicates the norm used for the distance. Supports 1 for L1 and 2 for L2.
single_directional: If False (default), loss comes from both the distance between
each point in x and its nearest neighbor in y and each point in y and its nearest
neighbor in x. If True, loss is the distance between each point in x and its
nearest neighbor in y.
abs_cosine: If False, loss_normals is from one minus the cosine similarity.
If True (default), loss_normals is from one minus the absolute value of the
cosine similarity, which means that exactly opposite normals are considered
equivalent to exactly matching normals, i.e. sign does not matter.
Returns:
2-element tuple containing
@ -112,116 +210,45 @@ def chamfer_distance(
- **loss_normals**: Tensor giving the reduced cosine distance of normals
between pointclouds in x and pointclouds in y. Returns None if
x_normals and y_normals are None.
"""
_validate_chamfer_reduction_inputs(batch_reduction, point_reduction)
if not ((norm == 1) or (norm == 2)):
raise ValueError("Support for 1 or 2 norm.")
x, x_lengths, x_normals = _handle_pointcloud_input(x, x_lengths, x_normals)
y, y_lengths, y_normals = _handle_pointcloud_input(y, y_lengths, y_normals)
return_normals = x_normals is not None and y_normals is not None
N, P1, D = x.shape
P2 = y.shape[1]
# Check if inputs are heterogeneous and create a lengths mask.
is_x_heterogeneous = (x_lengths != P1).any()
is_y_heterogeneous = (y_lengths != P2).any()
x_mask = (
torch.arange(P1, device=x.device)[None] >= x_lengths[:, None]
) # shape [N, P1]
y_mask = (
torch.arange(P2, device=y.device)[None] >= y_lengths[:, None]
) # shape [N, P2]
if y.shape[0] != N or y.shape[2] != D:
raise ValueError("y does not have the correct shape.")
if weights is not None:
if weights.size(0) != N:
raise ValueError("weights must be of shape (N,).")
if not (weights >= 0).all():
raise ValueError("weights cannot be negative.")
if weights.sum() == 0.0:
weights = weights.view(N, 1)
if batch_reduction in ["mean", "sum"]:
return (
(x.sum((1, 2)) * weights).sum() * 0.0,
(x.sum((1, 2)) * weights).sum() * 0.0,
)
return ((x.sum((1, 2)) * weights) * 0.0, (x.sum((1, 2)) * weights) * 0.0)
cham_norm_x = x.new_zeros(())
cham_norm_y = x.new_zeros(())
x_nn = knn_points(x, y, lengths1=x_lengths, lengths2=y_lengths, norm=norm, K=1)
y_nn = knn_points(y, x, lengths1=y_lengths, lengths2=x_lengths, norm=norm, K=1)
cham_x = x_nn.dists[..., 0] # (N, P1)
cham_y = y_nn.dists[..., 0] # (N, P2)
if is_x_heterogeneous:
cham_x[x_mask] = 0.0
if is_y_heterogeneous:
cham_y[y_mask] = 0.0
if weights is not None:
cham_x *= weights.view(N, 1)
cham_y *= weights.view(N, 1)
if return_normals:
# Gather the normals using the indices and keep only value for k=0
x_normals_near = knn_gather(y_normals, x_nn.idx, y_lengths)[..., 0, :]
y_normals_near = knn_gather(x_normals, y_nn.idx, x_lengths)[..., 0, :]
cham_norm_x = 1 - torch.abs(
F.cosine_similarity(x_normals, x_normals_near, dim=2, eps=1e-6)
cham_x, cham_norm_x = _chamfer_distance_single_direction(
x,
y,
x_lengths,
y_lengths,
x_normals,
y_normals,
weights,
batch_reduction,
point_reduction,
norm,
abs_cosine,
)
if single_directional:
return cham_x, cham_norm_x
else:
cham_y, cham_norm_y = _chamfer_distance_single_direction(
y,
x,
y_lengths,
x_lengths,
y_normals,
x_normals,
weights,
batch_reduction,
point_reduction,
norm,
abs_cosine,
)
cham_norm_y = 1 - torch.abs(
F.cosine_similarity(y_normals, y_normals_near, dim=2, eps=1e-6)
return (
cham_x + cham_y,
(cham_norm_x + cham_norm_y) if cham_norm_x is not None else None,
)
if is_x_heterogeneous:
cham_norm_x[x_mask] = 0.0
if is_y_heterogeneous:
cham_norm_y[y_mask] = 0.0
if weights is not None:
cham_norm_x *= weights.view(N, 1)
cham_norm_y *= weights.view(N, 1)
# Apply point reduction
cham_x = cham_x.sum(1) # (N,)
cham_y = cham_y.sum(1) # (N,)
if return_normals:
cham_norm_x = cham_norm_x.sum(1) # (N,)
cham_norm_y = cham_norm_y.sum(1) # (N,)
if point_reduction == "mean":
x_lengths_clamped = x_lengths.clamp(min=1)
y_lengths_clamped = y_lengths.clamp(min=1)
cham_x /= x_lengths_clamped
cham_y /= y_lengths_clamped
if return_normals:
cham_norm_x /= x_lengths_clamped
cham_norm_y /= y_lengths_clamped
if batch_reduction is not None:
# batch_reduction == "sum"
cham_x = cham_x.sum()
cham_y = cham_y.sum()
if return_normals:
cham_norm_x = cham_norm_x.sum()
cham_norm_y = cham_norm_y.sum()
if batch_reduction == "mean":
div = weights.sum() if weights is not None else max(N, 1)
cham_x /= div
cham_y /= div
if return_normals:
cham_norm_x /= div
cham_norm_y /= div
cham_dist = cham_x + cham_y
cham_normals = cham_norm_x + cham_norm_y if return_normals else None
return cham_dist, cham_normals

View File

@ -88,7 +88,9 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
)
@staticmethod
def chamfer_distance_naive_pointclouds(p1, p2, norm: int = 2, device="cpu"):
def chamfer_distance_naive_pointclouds(
p1, p2, norm: int = 2, device="cpu", abs_cosine=True
):
"""
Naive iterative implementation of nearest neighbor and chamfer distance.
x and y are assumed to be pointclouds objects with points and optionally normals.
@ -146,17 +148,20 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
if return_normals:
x_index = dist.argmin(2).view(N, P1, 1).expand(N, P1, 3)
y_index = dist.argmin(1).view(N, P2, 1).expand(N, P2, 3)
lnorm1 = 1 - torch.abs(
F.cosine_similarity(
x_normals, y_normals.gather(1, x_index), dim=2, eps=1e-6
)
cosine_sim1 = F.cosine_similarity(
x_normals, y_normals.gather(1, x_index), dim=2, eps=1e-6
)
lnorm2 = 1 - torch.abs(
F.cosine_similarity(
y_normals, x_normals.gather(1, y_index), dim=2, eps=1e-6
)
cosine_sim2 = F.cosine_similarity(
y_normals, x_normals.gather(1, y_index), dim=2, eps=1e-6
)
if abs_cosine:
lnorm1 = 1 - torch.abs(cosine_sim1)
lnorm2 = 1 - torch.abs(cosine_sim2)
else:
lnorm1 = 1 - cosine_sim1
lnorm2 = 1 - cosine_sim2
if is_x_heterogeneous:
lnorm1[x_mask] = 0.0
if is_y_heterogeneous:
@ -167,7 +172,9 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
return loss, lnorm
@staticmethod
def chamfer_distance_naive(x, y, x_normals=None, y_normals=None, norm: int = 2):
def chamfer_distance_naive(
x, y, x_normals=None, y_normals=None, norm: int = 2, abs_cosine=True
):
"""
Naive iterative implementation of nearest neighbor and chamfer distance.
Returns lists of the unreduced loss and loss_normals. This naive
@ -200,16 +207,21 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
if return_normals:
x_index = dist.argmin(2).view(N, P1, 1).expand(N, P1, 3)
y_index = dist.argmin(1).view(N, P2, 1).expand(N, P2, 3)
lnorm1 = 1 - torch.abs(
F.cosine_similarity(
x_normals, y_normals.gather(1, x_index), dim=2, eps=1e-6
)
cosine_sim1 = F.cosine_similarity(
x_normals, y_normals.gather(1, x_index), dim=2, eps=1e-6
)
lnorm2 = 1 - torch.abs(
F.cosine_similarity(
y_normals, x_normals.gather(1, y_index), dim=2, eps=1e-6
)
cosine_sim2 = F.cosine_similarity(
y_normals, x_normals.gather(1, y_index), dim=2, eps=1e-6
)
if abs_cosine:
lnorm1 = 1 - torch.abs(cosine_sim1)
lnorm2 = 1 - torch.abs(cosine_sim2)
else:
lnorm1 = 1 - cosine_sim1
lnorm2 = 1 - cosine_sim2
lnorm = [lnorm1, lnorm2] # [(N, P1), (N, P2)]
return loss, lnorm
@ -323,6 +335,80 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
y_lengths,
)
def test_single_directional_chamfer_vs_naive_pointcloud(self):
"""
Test the single directional settings for chamfer_distance
(point reduction = "mean" and batch_reduction="mean") but with heterogeneous
pointclouds as input. Compare with the naive implementation of chamfer
which supports heterogeneous pointcloud objects.
"""
N, max_P1, max_P2 = 3, 70, 70
device = get_random_cuda_device()
for norm in [1, 2]:
for abs_cosine in [True, False]:
points_normals = TestChamfer.init_pointclouds(N, max_P1, max_P2, device)
weights = points_normals.weights
x_lengths = points_normals.p1_lengths
y_lengths = points_normals.p2_lengths
# Chamfer with tensors as input for heterogeneous pointclouds.
cham_tensor, norm_tensor = chamfer_distance(
points_normals.p1,
points_normals.p2,
x_normals=points_normals.n1,
y_normals=points_normals.n2,
x_lengths=points_normals.p1_lengths,
y_lengths=points_normals.p2_lengths,
weights=weights,
norm=norm,
single_directional=True,
abs_cosine=abs_cosine,
)
# Chamfer with pointclouds as input.
(
pred_loss,
pred_norm_loss,
) = TestChamfer.chamfer_distance_naive_pointclouds(
points_normals.cloud1,
points_normals.cloud2,
norm=norm,
device=device,
abs_cosine=abs_cosine,
)
# Mean reduction point loss.
pred_loss[0] *= weights.view(N, 1)
pred_loss_mean = pred_loss[0].sum(1) / x_lengths
pred_loss_mean = pred_loss_mean.sum()
pred_loss_mean /= weights.sum()
# Mean reduction norm loss.
pred_norm_loss[0] *= weights.view(N, 1)
pred_norm_loss_mean = pred_norm_loss[0].sum(1) / x_lengths
pred_norm_loss_mean = pred_norm_loss_mean.sum() / weights.sum()
self.assertClose(pred_loss_mean, cham_tensor)
self.assertClose(pred_norm_loss_mean, norm_tensor)
self._check_gradients(
cham_tensor,
norm_tensor,
pred_loss_mean,
pred_norm_loss_mean,
points_normals.cloud1.points_list(),
points_normals.p1,
points_normals.cloud2.points_list(),
points_normals.p2,
points_normals.cloud1.normals_list(),
points_normals.n1,
points_normals.cloud2.normals_list(),
points_normals.n2,
x_lengths,
y_lengths,
)
def test_chamfer_pointcloud_object_withnormals(self):
N = 5
P1, P2 = 100, 100
@ -485,6 +571,53 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
loss, loss_norm, pred_loss_mean, pred_loss_norm_mean, p1, p11, p2, p22
)
def test_single_direction_chamfer_point_reduction_mean(self):
"""
Compare output of vectorized chamfer loss with naive implementation
for point_reduction = "mean" and batch_reduction = None.
"""
N, max_P1, max_P2 = 7, 10, 18
device = get_random_cuda_device()
points_normals = TestChamfer.init_pointclouds(N, max_P1, max_P2, device)
p1 = points_normals.p1
p2 = points_normals.p2
p1_normals = points_normals.n1
p2_normals = points_normals.n2
weights = points_normals.weights
p11 = p1.detach().clone()
p22 = p2.detach().clone()
p11.requires_grad = True
p22.requires_grad = True
P1 = p1.shape[1]
pred_loss, pred_loss_norm = TestChamfer.chamfer_distance_naive(
p1, p2, x_normals=p1_normals, y_normals=p2_normals
)
# point_reduction = "mean".
loss, loss_norm = chamfer_distance(
p11,
p22,
x_normals=p1_normals,
y_normals=p2_normals,
weights=weights,
batch_reduction=None,
point_reduction="mean",
single_directional=True,
)
pred_loss_mean = pred_loss[0].sum(1) / P1
pred_loss_mean *= weights
self.assertClose(loss, pred_loss_mean)
pred_loss_norm_mean = pred_loss_norm[0].sum(1) / P1
pred_loss_norm_mean *= weights
self.assertClose(loss_norm, pred_loss_norm_mean)
# Check gradients
self._check_gradients(
loss, loss_norm, pred_loss_mean, pred_loss_norm_mean, p1, p11, p2, p22
)
def test_chamfer_point_reduction_sum(self):
"""
Compare output of vectorized chamfer loss with naive implementation
@ -529,6 +662,51 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
loss, loss_norm, pred_loss_sum, pred_loss_norm_sum, p1, p11, p2, p22
)
def test_single_directional_chamfer_point_reduction_sum(self):
"""
Compare output of vectorized single directional chamfer loss with naive implementation
for point_reduction = "sum" and batch_reduction = None.
"""
N, P1, P2 = 7, 10, 18
device = get_random_cuda_device()
points_normals = TestChamfer.init_pointclouds(N, P1, P2, device)
p1 = points_normals.p1
p2 = points_normals.p2
p1_normals = points_normals.n1
p2_normals = points_normals.n2
weights = points_normals.weights
p11 = p1.detach().clone()
p22 = p2.detach().clone()
p11.requires_grad = True
p22.requires_grad = True
pred_loss, pred_loss_norm = TestChamfer.chamfer_distance_naive(
p1, p2, x_normals=p1_normals, y_normals=p2_normals
)
loss, loss_norm = chamfer_distance(
p11,
p22,
x_normals=p1_normals,
y_normals=p2_normals,
weights=weights,
batch_reduction=None,
point_reduction="sum",
single_directional=True,
)
pred_loss_sum = pred_loss[0].sum(1)
pred_loss_sum *= weights
self.assertClose(loss, pred_loss_sum)
pred_loss_norm_sum = pred_loss_norm[0].sum(1)
pred_loss_norm_sum *= weights
self.assertClose(loss_norm, pred_loss_norm_sum)
# Check gradients
self._check_gradients(
loss, loss_norm, pred_loss_sum, pred_loss_norm_sum, p1, p11, p2, p22
)
def _check_gradients(
self,
loss,