diff --git a/pytorch3d/loss/chamfer.py b/pytorch3d/loss/chamfer.py
index 49690ec3..8ec828ec 100644
--- a/pytorch3d/loss/chamfer.py
+++ b/pytorch3d/loss/chamfer.py
@@ -68,6 +68,94 @@ def _handle_pointcloud_input(
     return X, lengths, normals
 
 
+def _chamfer_distance_single_direction(
+    x,
+    y,
+    x_lengths,
+    y_lengths,
+    x_normals,
+    y_normals,
+    weights,
+    batch_reduction: Union[str, None],
+    point_reduction: str,
+    norm: int,
+    abs_cosine: bool,
+):
+    return_normals = x_normals is not None and y_normals is not None
+
+    N, P1, D = x.shape
+
+    # Check if inputs are heterogeneous and create a lengths mask.
+    is_x_heterogeneous = (x_lengths != P1).any()
+    x_mask = (
+        torch.arange(P1, device=x.device)[None] >= x_lengths[:, None]
+    )  # shape [N, P1]
+    if y.shape[0] != N or y.shape[2] != D:
+        raise ValueError("y does not have the correct shape.")
+    if weights is not None:
+        if weights.size(0) != N:
+            raise ValueError("weights must be of shape (N,).")
+        if not (weights >= 0).all():
+            raise ValueError("weights cannot be negative.")
+        if weights.sum() == 0.0:
+            weights = weights.view(N, 1)
+            if batch_reduction in ["mean", "sum"]:
+                return (
+                    (x.sum((1, 2)) * weights).sum() * 0.0,
+                    (x.sum((1, 2)) * weights).sum() * 0.0,
+                )
+            return ((x.sum((1, 2)) * weights) * 0.0, (x.sum((1, 2)) * weights) * 0.0)
+
+    cham_norm_x = x.new_zeros(())
+
+    x_nn = knn_points(x, y, lengths1=x_lengths, lengths2=y_lengths, norm=norm, K=1)
+    cham_x = x_nn.dists[..., 0]  # (N, P1)
+
+    if is_x_heterogeneous:
+        cham_x[x_mask] = 0.0
+
+    if weights is not None:
+        cham_x *= weights.view(N, 1)
+
+    if return_normals:
+        # Gather the normals using the indices and keep only value for k=0
+        x_normals_near = knn_gather(y_normals, x_nn.idx, y_lengths)[..., 0, :]
+
+        cosine_sim = F.cosine_similarity(x_normals, x_normals_near, dim=2, eps=1e-6)
+        # If abs_cosine, ignore orientation and take the absolute value of the cosine sim.
+        cham_norm_x = 1 - (torch.abs(cosine_sim) if abs_cosine else cosine_sim)
+
+        if is_x_heterogeneous:
+            cham_norm_x[x_mask] = 0.0
+
+        if weights is not None:
+            cham_norm_x *= weights.view(N, 1)
+        cham_norm_x = cham_norm_x.sum(1)  # (N,)
+
+    # Apply point reduction
+    cham_x = cham_x.sum(1)  # (N,)
+    if point_reduction == "mean":
+        x_lengths_clamped = x_lengths.clamp(min=1)
+        cham_x /= x_lengths_clamped
+        if return_normals:
+            cham_norm_x /= x_lengths_clamped
+
+    if batch_reduction is not None:
+        # batch_reduction == "sum"
+        cham_x = cham_x.sum()
+        if return_normals:
+            cham_norm_x = cham_norm_x.sum()
+        if batch_reduction == "mean":
+            div = weights.sum() if weights is not None else max(N, 1)
+            cham_x /= div
+            if return_normals:
+                cham_norm_x /= div
+
+    cham_dist = cham_x
+    cham_normals = cham_norm_x if return_normals else None
+    return cham_dist, cham_normals
+
+
 def chamfer_distance(
     x,
     y,
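
As a mental model for review: with `point_reduction="mean"`, `batch_reduction=None`, no weights, and no normals, the helper above reduces to the following brute-force sketch (`naive_one_directional_chamfer` is a hypothetical name for illustration, not library code):

```python
import torch

def naive_one_directional_chamfer(x, y, norm: int = 2):
    # x: (N, P1, D), y: (N, P2, D); homogeneous clouds for simplicity.
    # For every point in x, take the distance to its nearest neighbor in y
    # and average over the points of x -- the x->y term only.
    diff = x[:, :, None] - y[:, None]          # (N, P1, P2, D)
    if norm == 2:
        dist = (diff ** 2).sum(-1)             # squared L2, matching knn_points
    else:
        dist = diff.abs().sum(-1)              # L1
    return dist.min(dim=2).values.mean(dim=1)  # (N,)
```
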
@@ -79,6 +167,8 @@ def chamfer_distance(
     batch_reduction: Union[str, None] = "mean",
     point_reduction: str = "mean",
     norm: int = 2,
+    single_directional: bool = False,
+    abs_cosine: bool = True,
 ):
     """
     Chamfer distance between two pointclouds x and y.
@@ -103,6 +193,14 @@ def chamfer_distance(
         point_reduction: Reduction operation to apply for the loss across the
             points, can be one of ["mean", "sum"].
         norm: int indicates the norm used for the distance. Supports 1 for L1 and 2 for L2.
+        single_directional: If False (default), loss comes from both the distance between
+            each point in x and its nearest neighbor in y and each point in y and its nearest
+            neighbor in x. If True, loss is the distance between each point in x and its
+            nearest neighbor in y.
+        abs_cosine: If False, loss_normals is from one minus the cosine similarity.
+            If True (default), loss_normals is from one minus the absolute value of the
+            cosine similarity, which means that exactly opposite normals are considered
+            equivalent to exactly matching normals, i.e. sign does not matter.
 
     Returns:
         2-element tuple containing
@@ -112,116 +210,45 @@ def chamfer_distance(
         - **loss_normals**: Tensor giving the reduced cosine distance of normals
           between pointclouds in x and pointclouds in y. Returns None if
           x_normals and y_normals are None.
+
     """
     _validate_chamfer_reduction_inputs(batch_reduction, point_reduction)
 
     if not ((norm == 1) or (norm == 2)):
         raise ValueError("Support for 1 or 2 norm.")
-
     x, x_lengths, x_normals = _handle_pointcloud_input(x, x_lengths, x_normals)
     y, y_lengths, y_normals = _handle_pointcloud_input(y, y_lengths, y_normals)
 
-    return_normals = x_normals is not None and y_normals is not None
-
-    N, P1, D = x.shape
-    P2 = y.shape[1]
-
-    # Check if inputs are heterogeneous and create a lengths mask.
-    is_x_heterogeneous = (x_lengths != P1).any()
-    is_y_heterogeneous = (y_lengths != P2).any()
-    x_mask = (
-        torch.arange(P1, device=x.device)[None] >= x_lengths[:, None]
-    )  # shape [N, P1]
-    y_mask = (
-        torch.arange(P2, device=y.device)[None] >= y_lengths[:, None]
-    )  # shape [N, P2]
-
-    if y.shape[0] != N or y.shape[2] != D:
-        raise ValueError("y does not have the correct shape.")
-    if weights is not None:
-        if weights.size(0) != N:
-            raise ValueError("weights must be of shape (N,).")
-        if not (weights >= 0).all():
-            raise ValueError("weights cannot be negative.")
-        if weights.sum() == 0.0:
-            weights = weights.view(N, 1)
-            if batch_reduction in ["mean", "sum"]:
-                return (
-                    (x.sum((1, 2)) * weights).sum() * 0.0,
-                    (x.sum((1, 2)) * weights).sum() * 0.0,
-                )
-            return ((x.sum((1, 2)) * weights) * 0.0, (x.sum((1, 2)) * weights) * 0.0)
-
-    cham_norm_x = x.new_zeros(())
-    cham_norm_y = x.new_zeros(())
-
-    x_nn = knn_points(x, y, lengths1=x_lengths, lengths2=y_lengths, norm=norm, K=1)
-    y_nn = knn_points(y, x, lengths1=y_lengths, lengths2=x_lengths, norm=norm, K=1)
-
-    cham_x = x_nn.dists[..., 0]  # (N, P1)
-    cham_y = y_nn.dists[..., 0]  # (N, P2)
-
-    if is_x_heterogeneous:
-        cham_x[x_mask] = 0.0
-    if is_y_heterogeneous:
-        cham_y[y_mask] = 0.0
-
-    if weights is not None:
-        cham_x *= weights.view(N, 1)
-        cham_y *= weights.view(N, 1)
-
-    if return_normals:
-        # Gather the normals using the indices and keep only value for k=0
-        x_normals_near = knn_gather(y_normals, x_nn.idx, y_lengths)[..., 0, :]
-        y_normals_near = knn_gather(x_normals, y_nn.idx, x_lengths)[..., 0, :]
-
-        cham_norm_x = 1 - torch.abs(
-            F.cosine_similarity(x_normals, x_normals_near, dim=2, eps=1e-6)
+    cham_x, cham_norm_x = _chamfer_distance_single_direction(
+        x,
+        y,
+        x_lengths,
+        y_lengths,
+        x_normals,
+        y_normals,
+        weights,
+        batch_reduction,
+        point_reduction,
+        norm,
+        abs_cosine,
+    )
+    if single_directional:
+        return cham_x, cham_norm_x
+    else:
+        cham_y, cham_norm_y = _chamfer_distance_single_direction(
+            y,
+            x,
+            y_lengths,
+            x_lengths,
+            y_normals,
+            x_normals,
+            weights,
+            batch_reduction,
+            point_reduction,
+            norm,
+            abs_cosine,
         )
-        cham_norm_y = 1 - torch.abs(
-            F.cosine_similarity(y_normals, y_normals_near, dim=2, eps=1e-6)
+        return (
+            cham_x + cham_y,
+            (cham_norm_x + cham_norm_y) if cham_norm_x is not None else None,
         )
-
-        if is_x_heterogeneous:
-            cham_norm_x[x_mask] = 0.0
-        if is_y_heterogeneous:
-            cham_norm_y[y_mask] = 0.0
-
-        if weights is not None:
-            cham_norm_x *= weights.view(N, 1)
-            cham_norm_y *= weights.view(N, 1)
-
-    # Apply point reduction
-    cham_x = cham_x.sum(1)  # (N,)
-    cham_y = cham_y.sum(1)  # (N,)
-    if return_normals:
-        cham_norm_x = cham_norm_x.sum(1)  # (N,)
-        cham_norm_y = cham_norm_y.sum(1)  # (N,)
-    if point_reduction == "mean":
-        x_lengths_clamped = x_lengths.clamp(min=1)
-        y_lengths_clamped = y_lengths.clamp(min=1)
-        cham_x /= x_lengths_clamped
-        cham_y /= y_lengths_clamped
-        if return_normals:
-            cham_norm_x /= x_lengths_clamped
-            cham_norm_y /= y_lengths_clamped
-
-    if batch_reduction is not None:
-        # batch_reduction == "sum"
-        cham_x = cham_x.sum()
-        cham_y = cham_y.sum()
-        if return_normals:
-            cham_norm_x = cham_norm_x.sum()
-            cham_norm_y = cham_norm_y.sum()
-        if batch_reduction == "mean":
-            div = weights.sum() if weights is not None else max(N, 1)
-            cham_x /= div
-            cham_y /= div
-            if return_normals:
-                cham_norm_x /= div
-                cham_norm_y /= div
-
-    cham_dist = cham_x + cham_y
-    cham_normals = cham_norm_x + cham_norm_y if return_normals else None
-
-    return cham_dist, cham_normals
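
A short usage sketch of the two new arguments (illustrative shapes and values only; `chamfer_distance` is the public entry point touched above):

```python
import torch
import torch.nn.functional as F
from pytorch3d.loss import chamfer_distance

x = torch.rand(4, 100, 3)  # predicted clouds
y = torch.rand(4, 80, 3)   # target clouds

# Default behaviour is unchanged: symmetric x->y + y->x loss.
loss_sym, _ = chamfer_distance(x, y)

# New: keep only the x->y term, e.g. when y is a dense ground-truth
# scan and every predicted point should lie on it.
loss_dir, _ = chamfer_distance(x, y, single_directional=True)

# New: with normals, abs_cosine=False keeps the sign of the cosine
# similarity, so opposite orientations are penalized (loss up to 2).
nx = F.normalize(torch.rand(4, 100, 3), dim=2)
ny = F.normalize(torch.rand(4, 80, 3), dim=2)
loss, loss_normals = chamfer_distance(
    x, y, x_normals=nx, y_normals=ny, abs_cosine=False
)
```
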
diff --git a/tests/test_chamfer.py b/tests/test_chamfer.py
index 964a9fab..e6a7897d 100644
--- a/tests/test_chamfer.py
+++ b/tests/test_chamfer.py
@@ -88,7 +88,9 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
         )
 
     @staticmethod
-    def chamfer_distance_naive_pointclouds(p1, p2, norm: int = 2, device="cpu"):
+    def chamfer_distance_naive_pointclouds(
+        p1, p2, norm: int = 2, device="cpu", abs_cosine=True
+    ):
         """
         Naive iterative implementation of nearest neighbor and chamfer distance.
         x and y are assumed to be pointclouds objects with points and optionally normals.
@@ -146,17 +148,20 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
         if return_normals:
             x_index = dist.argmin(2).view(N, P1, 1).expand(N, P1, 3)
             y_index = dist.argmin(1).view(N, P2, 1).expand(N, P2, 3)
-            lnorm1 = 1 - torch.abs(
-                F.cosine_similarity(
-                    x_normals, y_normals.gather(1, x_index), dim=2, eps=1e-6
-                )
+            cosine_sim1 = F.cosine_similarity(
+                x_normals, y_normals.gather(1, x_index), dim=2, eps=1e-6
             )
-            lnorm2 = 1 - torch.abs(
-                F.cosine_similarity(
-                    y_normals, x_normals.gather(1, y_index), dim=2, eps=1e-6
-                )
+            cosine_sim2 = F.cosine_similarity(
+                y_normals, x_normals.gather(1, y_index), dim=2, eps=1e-6
             )
+
+            if abs_cosine:
+                lnorm1 = 1 - torch.abs(cosine_sim1)
+                lnorm2 = 1 - torch.abs(cosine_sim2)
+            else:
+                lnorm1 = 1 - cosine_sim1
+                lnorm2 = 1 - cosine_sim2
+
             if is_x_heterogeneous:
                 lnorm1[x_mask] = 0.0
             if is_y_heterogeneous:
@@ -167,7 +172,9 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
         return loss, lnorm
 
     @staticmethod
-    def chamfer_distance_naive(x, y, x_normals=None, y_normals=None, norm: int = 2):
+    def chamfer_distance_naive(
+        x, y, x_normals=None, y_normals=None, norm: int = 2, abs_cosine=True
+    ):
         """
         Naive iterative implementation of nearest neighbor and chamfer distance.
         Returns lists of the unreduced loss and loss_normals. This naive
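
The `abs_cosine` branch threaded through these naive helpers changes only how a cosine similarity becomes a loss; a hand-picked example makes the difference concrete:

```python
import torch
import torch.nn.functional as F

n = torch.tensor([[[0.0, 0.0, 1.0]]])  # a single normal, shape (1, 1, 3)
n_flip = -n                            # exactly opposite orientation

cos = F.cosine_similarity(n, n_flip, dim=2, eps=1e-6)  # tensor([[-1.]])
print(1 - torch.abs(cos))  # tensor([[0.]]): abs_cosine=True, a flip counts as a match
print(1 - cos)             # tensor([[2.]]): abs_cosine=False, maximal penalty
```
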
@@ -200,16 +207,21 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
         if return_normals:
             x_index = dist.argmin(2).view(N, P1, 1).expand(N, P1, 3)
             y_index = dist.argmin(1).view(N, P2, 1).expand(N, P2, 3)
-            lnorm1 = 1 - torch.abs(
-                F.cosine_similarity(
-                    x_normals, y_normals.gather(1, x_index), dim=2, eps=1e-6
-                )
+
+            cosine_sim1 = F.cosine_similarity(
+                x_normals, y_normals.gather(1, x_index), dim=2, eps=1e-6
             )
-            lnorm2 = 1 - torch.abs(
-                F.cosine_similarity(
-                    y_normals, x_normals.gather(1, y_index), dim=2, eps=1e-6
-                )
+            cosine_sim2 = F.cosine_similarity(
+                y_normals, x_normals.gather(1, y_index), dim=2, eps=1e-6
             )
+
+            if abs_cosine:
+                lnorm1 = 1 - torch.abs(cosine_sim1)
+                lnorm2 = 1 - torch.abs(cosine_sim2)
+            else:
+                lnorm1 = 1 - cosine_sim1
+                lnorm2 = 1 - cosine_sim2
+
             lnorm = [lnorm1, lnorm2]  # [(N, P1), (N, P2)]
 
         return loss, lnorm
@@ -323,6 +335,80 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
             y_lengths,
         )
 
+    def test_single_directional_chamfer_vs_naive_pointcloud(self):
+        """
+        Test the single directional settings for chamfer_distance
+        (point reduction = "mean" and batch_reduction="mean") but with heterogeneous
+        pointclouds as input. Compare with the naive implementation of chamfer
+        which supports heterogeneous pointcloud objects.
+        """
+        N, max_P1, max_P2 = 3, 70, 70
+        device = get_random_cuda_device()
+
+        for norm in [1, 2]:
+            for abs_cosine in [True, False]:
+                points_normals = TestChamfer.init_pointclouds(N, max_P1, max_P2, device)
+                weights = points_normals.weights
+                x_lengths = points_normals.p1_lengths
+                y_lengths = points_normals.p2_lengths
+
+                # Chamfer with tensors as input for heterogeneous pointclouds.
+                cham_tensor, norm_tensor = chamfer_distance(
+                    points_normals.p1,
+                    points_normals.p2,
+                    x_normals=points_normals.n1,
+                    y_normals=points_normals.n2,
+                    x_lengths=points_normals.p1_lengths,
+                    y_lengths=points_normals.p2_lengths,
+                    weights=weights,
+                    norm=norm,
+                    single_directional=True,
+                    abs_cosine=abs_cosine,
+                )
+
+                # Chamfer with pointclouds as input.
+                (
+                    pred_loss,
+                    pred_norm_loss,
+                ) = TestChamfer.chamfer_distance_naive_pointclouds(
+                    points_normals.cloud1,
+                    points_normals.cloud2,
+                    norm=norm,
+                    device=device,
+                    abs_cosine=abs_cosine,
+                )
+
+                # Mean reduction point loss.
+                pred_loss[0] *= weights.view(N, 1)
+                pred_loss_mean = pred_loss[0].sum(1) / x_lengths
+                pred_loss_mean = pred_loss_mean.sum()
+                pred_loss_mean /= weights.sum()
+
+                # Mean reduction norm loss.
+                pred_norm_loss[0] *= weights.view(N, 1)
+                pred_norm_loss_mean = pred_norm_loss[0].sum(1) / x_lengths
+                pred_norm_loss_mean = pred_norm_loss_mean.sum() / weights.sum()
+
+                self.assertClose(pred_loss_mean, cham_tensor)
+                self.assertClose(pred_norm_loss_mean, norm_tensor)
+
+                self._check_gradients(
+                    cham_tensor,
+                    norm_tensor,
+                    pred_loss_mean,
+                    pred_norm_loss_mean,
+                    points_normals.cloud1.points_list(),
+                    points_normals.p1,
+                    points_normals.cloud2.points_list(),
+                    points_normals.p2,
+                    points_normals.cloud1.normals_list(),
+                    points_normals.n1,
+                    points_normals.cloud2.normals_list(),
+                    points_normals.n2,
+                    x_lengths,
+                    y_lengths,
+                )
+
     def test_chamfer_pointcloud_object_withnormals(self):
         N = 5
         P1, P2 = 100, 100
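
The hand-rolled reduction in the test above can be read as one formula; a sketch with hypothetical names, assuming per-point losses are already zeroed past each cloud's valid length:

```python
import torch

def expected_loss(per_point, lengths, weights):
    # per_point: (N, P) one-directional losses, zeroed past each cloud's length.
    # lengths: (N,) valid point counts; weights: (N,) per-cloud weights.
    # point_reduction="mean": average over the valid points of each cloud.
    per_cloud = per_point.sum(dim=1) / lengths.clamp(min=1)
    # batch_reduction="mean" with weights: weighted sum over the batch,
    # normalized by the total weight rather than by N.
    return (per_cloud * weights).sum() / weights.sum()
```
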
+ """ + N, max_P1, max_P2 = 7, 10, 18 + device = get_random_cuda_device() + points_normals = TestChamfer.init_pointclouds(N, max_P1, max_P2, device) + p1 = points_normals.p1 + p2 = points_normals.p2 + p1_normals = points_normals.n1 + p2_normals = points_normals.n2 + weights = points_normals.weights + p11 = p1.detach().clone() + p22 = p2.detach().clone() + p11.requires_grad = True + p22.requires_grad = True + P1 = p1.shape[1] + + pred_loss, pred_loss_norm = TestChamfer.chamfer_distance_naive( + p1, p2, x_normals=p1_normals, y_normals=p2_normals + ) + + # point_reduction = "mean". + loss, loss_norm = chamfer_distance( + p11, + p22, + x_normals=p1_normals, + y_normals=p2_normals, + weights=weights, + batch_reduction=None, + point_reduction="mean", + single_directional=True, + ) + pred_loss_mean = pred_loss[0].sum(1) / P1 + pred_loss_mean *= weights + self.assertClose(loss, pred_loss_mean) + + pred_loss_norm_mean = pred_loss_norm[0].sum(1) / P1 + pred_loss_norm_mean *= weights + self.assertClose(loss_norm, pred_loss_norm_mean) + + # Check gradients + self._check_gradients( + loss, loss_norm, pred_loss_mean, pred_loss_norm_mean, p1, p11, p2, p22 + ) + def test_chamfer_point_reduction_sum(self): """ Compare output of vectorized chamfer loss with naive implementation @@ -529,6 +662,51 @@ class TestChamfer(TestCaseMixin, unittest.TestCase): loss, loss_norm, pred_loss_sum, pred_loss_norm_sum, p1, p11, p2, p22 ) + def test_single_directional_chamfer_point_reduction_sum(self): + """ + Compare output of vectorized single directional chamfer loss with naive implementation + for point_reduction = "sum" and batch_reduction = None. + """ + N, P1, P2 = 7, 10, 18 + device = get_random_cuda_device() + points_normals = TestChamfer.init_pointclouds(N, P1, P2, device) + p1 = points_normals.p1 + p2 = points_normals.p2 + p1_normals = points_normals.n1 + p2_normals = points_normals.n2 + weights = points_normals.weights + p11 = p1.detach().clone() + p22 = p2.detach().clone() + p11.requires_grad = True + p22.requires_grad = True + + pred_loss, pred_loss_norm = TestChamfer.chamfer_distance_naive( + p1, p2, x_normals=p1_normals, y_normals=p2_normals + ) + + loss, loss_norm = chamfer_distance( + p11, + p22, + x_normals=p1_normals, + y_normals=p2_normals, + weights=weights, + batch_reduction=None, + point_reduction="sum", + single_directional=True, + ) + pred_loss_sum = pred_loss[0].sum(1) + pred_loss_sum *= weights + self.assertClose(loss, pred_loss_sum) + + pred_loss_norm_sum = pred_loss_norm[0].sum(1) + pred_loss_norm_sum *= weights + self.assertClose(loss_norm, pred_loss_norm_sum) + + # Check gradients + self._check_gradients( + loss, loss_norm, pred_loss_sum, pred_loss_norm_sum, p1, p11, p2, p22 + ) + def _check_gradients( self, loss,