mirror of
				https://github.com/facebookresearch/pytorch3d.git
				synced 2025-11-04 18:02:14 +08:00 
			
		
		
		
	Single directional chamfer distance and non-absolute cosine similarity
Summary: Single directional chamfer distance and option to use non-absolute cosine similarity Reviewed By: bottler Differential Revision: D46593980 fbshipit-source-id: b2e591706a0cdde1c2d361614cecebb84a581433
This commit is contained in:
		
							parent
							
								
									573a42cd5f
								
							
						
					
					
						commit
						5ffeb4d580
					
				@ -68,6 +68,94 @@ def _handle_pointcloud_input(
 | 
			
		||||
    return X, lengths, normals
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _chamfer_distance_single_direction(
 | 
			
		||||
    x,
 | 
			
		||||
    y,
 | 
			
		||||
    x_lengths,
 | 
			
		||||
    y_lengths,
 | 
			
		||||
    x_normals,
 | 
			
		||||
    y_normals,
 | 
			
		||||
    weights,
 | 
			
		||||
    batch_reduction: Union[str, None],
 | 
			
		||||
    point_reduction: str,
 | 
			
		||||
    norm: int,
 | 
			
		||||
    abs_cosine: bool,
 | 
			
		||||
):
 | 
			
		||||
    return_normals = x_normals is not None and y_normals is not None
 | 
			
		||||
 | 
			
		||||
    N, P1, D = x.shape
 | 
			
		||||
 | 
			
		||||
    # Check if inputs are heterogeneous and create a lengths mask.
 | 
			
		||||
    is_x_heterogeneous = (x_lengths != P1).any()
 | 
			
		||||
    x_mask = (
 | 
			
		||||
        torch.arange(P1, device=x.device)[None] >= x_lengths[:, None]
 | 
			
		||||
    )  # shape [N, P1]
 | 
			
		||||
    if y.shape[0] != N or y.shape[2] != D:
 | 
			
		||||
        raise ValueError("y does not have the correct shape.")
 | 
			
		||||
    if weights is not None:
 | 
			
		||||
        if weights.size(0) != N:
 | 
			
		||||
            raise ValueError("weights must be of shape (N,).")
 | 
			
		||||
        if not (weights >= 0).all():
 | 
			
		||||
            raise ValueError("weights cannot be negative.")
 | 
			
		||||
        if weights.sum() == 0.0:
 | 
			
		||||
            weights = weights.view(N, 1)
 | 
			
		||||
            if batch_reduction in ["mean", "sum"]:
 | 
			
		||||
                return (
 | 
			
		||||
                    (x.sum((1, 2)) * weights).sum() * 0.0,
 | 
			
		||||
                    (x.sum((1, 2)) * weights).sum() * 0.0,
 | 
			
		||||
                )
 | 
			
		||||
            return ((x.sum((1, 2)) * weights) * 0.0, (x.sum((1, 2)) * weights) * 0.0)
 | 
			
		||||
 | 
			
		||||
    cham_norm_x = x.new_zeros(())
 | 
			
		||||
 | 
			
		||||
    x_nn = knn_points(x, y, lengths1=x_lengths, lengths2=y_lengths, norm=norm, K=1)
 | 
			
		||||
    cham_x = x_nn.dists[..., 0]  # (N, P1)
 | 
			
		||||
 | 
			
		||||
    if is_x_heterogeneous:
 | 
			
		||||
        cham_x[x_mask] = 0.0
 | 
			
		||||
 | 
			
		||||
    if weights is not None:
 | 
			
		||||
        cham_x *= weights.view(N, 1)
 | 
			
		||||
 | 
			
		||||
    if return_normals:
 | 
			
		||||
        # Gather the normals using the indices and keep only value for k=0
 | 
			
		||||
        x_normals_near = knn_gather(y_normals, x_nn.idx, y_lengths)[..., 0, :]
 | 
			
		||||
 | 
			
		||||
        cosine_sim = F.cosine_similarity(x_normals, x_normals_near, dim=2, eps=1e-6)
 | 
			
		||||
        # If abs_cosine, ignore orientation and take the absolute value of the cosine sim.
 | 
			
		||||
        cham_norm_x = 1 - (torch.abs(cosine_sim) if abs_cosine else cosine_sim)
 | 
			
		||||
 | 
			
		||||
        if is_x_heterogeneous:
 | 
			
		||||
            cham_norm_x[x_mask] = 0.0
 | 
			
		||||
 | 
			
		||||
        if weights is not None:
 | 
			
		||||
            cham_norm_x *= weights.view(N, 1)
 | 
			
		||||
        cham_norm_x = cham_norm_x.sum(1)  # (N,)
 | 
			
		||||
 | 
			
		||||
    # Apply point reduction
 | 
			
		||||
    cham_x = cham_x.sum(1)  # (N,)
 | 
			
		||||
    if point_reduction == "mean":
 | 
			
		||||
        x_lengths_clamped = x_lengths.clamp(min=1)
 | 
			
		||||
        cham_x /= x_lengths_clamped
 | 
			
		||||
        if return_normals:
 | 
			
		||||
            cham_norm_x /= x_lengths_clamped
 | 
			
		||||
 | 
			
		||||
    if batch_reduction is not None:
 | 
			
		||||
        # batch_reduction == "sum"
 | 
			
		||||
        cham_x = cham_x.sum()
 | 
			
		||||
        if return_normals:
 | 
			
		||||
            cham_norm_x = cham_norm_x.sum()
 | 
			
		||||
        if batch_reduction == "mean":
 | 
			
		||||
            div = weights.sum() if weights is not None else max(N, 1)
 | 
			
		||||
            cham_x /= div
 | 
			
		||||
            if return_normals:
 | 
			
		||||
                cham_norm_x /= div
 | 
			
		||||
 | 
			
		||||
    cham_dist = cham_x
 | 
			
		||||
    cham_normals = cham_norm_x if return_normals else None
 | 
			
		||||
    return cham_dist, cham_normals
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def chamfer_distance(
 | 
			
		||||
    x,
 | 
			
		||||
    y,
 | 
			
		||||
@ -79,6 +167,8 @@ def chamfer_distance(
 | 
			
		||||
    batch_reduction: Union[str, None] = "mean",
 | 
			
		||||
    point_reduction: str = "mean",
 | 
			
		||||
    norm: int = 2,
 | 
			
		||||
    single_directional: bool = False,
 | 
			
		||||
    abs_cosine: bool = True,
 | 
			
		||||
):
 | 
			
		||||
    """
 | 
			
		||||
    Chamfer distance between two pointclouds x and y.
 | 
			
		||||
@ -103,6 +193,14 @@ def chamfer_distance(
 | 
			
		||||
        point_reduction: Reduction operation to apply for the loss across the
 | 
			
		||||
            points, can be one of ["mean", "sum"].
 | 
			
		||||
        norm: int indicates the norm used for the distance. Supports 1 for L1 and 2 for L2.
 | 
			
		||||
        single_directional: If False (default), loss comes from both the distance between
 | 
			
		||||
            each point in x and its nearest neighbor in y and each point in y and its nearest
 | 
			
		||||
            neighbor in x. If True, loss is the distance between each point in x and its
 | 
			
		||||
            nearest neighbor in y.
 | 
			
		||||
        abs_cosine: If False, loss_normals is from one minus the cosine similarity.
 | 
			
		||||
            If True (default), loss_normals is from one minus the absolute value of the
 | 
			
		||||
            cosine similarity, which means that exactly opposite normals are considered
 | 
			
		||||
            equivalent to exactly matching normals, i.e. sign does not matter.
 | 
			
		||||
 | 
			
		||||
    Returns:
 | 
			
		||||
        2-element tuple containing
 | 
			
		||||
@ -112,116 +210,45 @@ def chamfer_distance(
 | 
			
		||||
        - **loss_normals**: Tensor giving the reduced cosine distance of normals
 | 
			
		||||
          between pointclouds in x and pointclouds in y. Returns None if
 | 
			
		||||
          x_normals and y_normals are None.
 | 
			
		||||
 | 
			
		||||
    """
 | 
			
		||||
    _validate_chamfer_reduction_inputs(batch_reduction, point_reduction)
 | 
			
		||||
 | 
			
		||||
    if not ((norm == 1) or (norm == 2)):
 | 
			
		||||
        raise ValueError("Support for 1 or 2 norm.")
 | 
			
		||||
 | 
			
		||||
    x, x_lengths, x_normals = _handle_pointcloud_input(x, x_lengths, x_normals)
 | 
			
		||||
    y, y_lengths, y_normals = _handle_pointcloud_input(y, y_lengths, y_normals)
 | 
			
		||||
 | 
			
		||||
    return_normals = x_normals is not None and y_normals is not None
 | 
			
		||||
 | 
			
		||||
    N, P1, D = x.shape
 | 
			
		||||
    P2 = y.shape[1]
 | 
			
		||||
 | 
			
		||||
    # Check if inputs are heterogeneous and create a lengths mask.
 | 
			
		||||
    is_x_heterogeneous = (x_lengths != P1).any()
 | 
			
		||||
    is_y_heterogeneous = (y_lengths != P2).any()
 | 
			
		||||
    x_mask = (
 | 
			
		||||
        torch.arange(P1, device=x.device)[None] >= x_lengths[:, None]
 | 
			
		||||
    )  # shape [N, P1]
 | 
			
		||||
    y_mask = (
 | 
			
		||||
        torch.arange(P2, device=y.device)[None] >= y_lengths[:, None]
 | 
			
		||||
    )  # shape [N, P2]
 | 
			
		||||
 | 
			
		||||
    if y.shape[0] != N or y.shape[2] != D:
 | 
			
		||||
        raise ValueError("y does not have the correct shape.")
 | 
			
		||||
    if weights is not None:
 | 
			
		||||
        if weights.size(0) != N:
 | 
			
		||||
            raise ValueError("weights must be of shape (N,).")
 | 
			
		||||
        if not (weights >= 0).all():
 | 
			
		||||
            raise ValueError("weights cannot be negative.")
 | 
			
		||||
        if weights.sum() == 0.0:
 | 
			
		||||
            weights = weights.view(N, 1)
 | 
			
		||||
            if batch_reduction in ["mean", "sum"]:
 | 
			
		||||
                return (
 | 
			
		||||
                    (x.sum((1, 2)) * weights).sum() * 0.0,
 | 
			
		||||
                    (x.sum((1, 2)) * weights).sum() * 0.0,
 | 
			
		||||
                )
 | 
			
		||||
            return ((x.sum((1, 2)) * weights) * 0.0, (x.sum((1, 2)) * weights) * 0.0)
 | 
			
		||||
 | 
			
		||||
    cham_norm_x = x.new_zeros(())
 | 
			
		||||
    cham_norm_y = x.new_zeros(())
 | 
			
		||||
 | 
			
		||||
    x_nn = knn_points(x, y, lengths1=x_lengths, lengths2=y_lengths, norm=norm, K=1)
 | 
			
		||||
    y_nn = knn_points(y, x, lengths1=y_lengths, lengths2=x_lengths, norm=norm, K=1)
 | 
			
		||||
 | 
			
		||||
    cham_x = x_nn.dists[..., 0]  # (N, P1)
 | 
			
		||||
    cham_y = y_nn.dists[..., 0]  # (N, P2)
 | 
			
		||||
 | 
			
		||||
    if is_x_heterogeneous:
 | 
			
		||||
        cham_x[x_mask] = 0.0
 | 
			
		||||
    if is_y_heterogeneous:
 | 
			
		||||
        cham_y[y_mask] = 0.0
 | 
			
		||||
 | 
			
		||||
    if weights is not None:
 | 
			
		||||
        cham_x *= weights.view(N, 1)
 | 
			
		||||
        cham_y *= weights.view(N, 1)
 | 
			
		||||
 | 
			
		||||
    if return_normals:
 | 
			
		||||
        # Gather the normals using the indices and keep only value for k=0
 | 
			
		||||
        x_normals_near = knn_gather(y_normals, x_nn.idx, y_lengths)[..., 0, :]
 | 
			
		||||
        y_normals_near = knn_gather(x_normals, y_nn.idx, x_lengths)[..., 0, :]
 | 
			
		||||
 | 
			
		||||
        cham_norm_x = 1 - torch.abs(
 | 
			
		||||
            F.cosine_similarity(x_normals, x_normals_near, dim=2, eps=1e-6)
 | 
			
		||||
    cham_x, cham_norm_x = _chamfer_distance_single_direction(
 | 
			
		||||
        x,
 | 
			
		||||
        y,
 | 
			
		||||
        x_lengths,
 | 
			
		||||
        y_lengths,
 | 
			
		||||
        x_normals,
 | 
			
		||||
        y_normals,
 | 
			
		||||
        weights,
 | 
			
		||||
        batch_reduction,
 | 
			
		||||
        point_reduction,
 | 
			
		||||
        norm,
 | 
			
		||||
        abs_cosine,
 | 
			
		||||
    )
 | 
			
		||||
    if single_directional:
 | 
			
		||||
        return cham_x, cham_norm_x
 | 
			
		||||
    else:
 | 
			
		||||
        cham_y, cham_norm_y = _chamfer_distance_single_direction(
 | 
			
		||||
            y,
 | 
			
		||||
            x,
 | 
			
		||||
            y_lengths,
 | 
			
		||||
            x_lengths,
 | 
			
		||||
            y_normals,
 | 
			
		||||
            x_normals,
 | 
			
		||||
            weights,
 | 
			
		||||
            batch_reduction,
 | 
			
		||||
            point_reduction,
 | 
			
		||||
            norm,
 | 
			
		||||
            abs_cosine,
 | 
			
		||||
        )
 | 
			
		||||
        cham_norm_y = 1 - torch.abs(
 | 
			
		||||
            F.cosine_similarity(y_normals, y_normals_near, dim=2, eps=1e-6)
 | 
			
		||||
        return (
 | 
			
		||||
            cham_x + cham_y,
 | 
			
		||||
            (cham_norm_x + cham_norm_y) if cham_norm_x is not None else None,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        if is_x_heterogeneous:
 | 
			
		||||
            cham_norm_x[x_mask] = 0.0
 | 
			
		||||
        if is_y_heterogeneous:
 | 
			
		||||
            cham_norm_y[y_mask] = 0.0
 | 
			
		||||
 | 
			
		||||
        if weights is not None:
 | 
			
		||||
            cham_norm_x *= weights.view(N, 1)
 | 
			
		||||
            cham_norm_y *= weights.view(N, 1)
 | 
			
		||||
 | 
			
		||||
    # Apply point reduction
 | 
			
		||||
    cham_x = cham_x.sum(1)  # (N,)
 | 
			
		||||
    cham_y = cham_y.sum(1)  # (N,)
 | 
			
		||||
    if return_normals:
 | 
			
		||||
        cham_norm_x = cham_norm_x.sum(1)  # (N,)
 | 
			
		||||
        cham_norm_y = cham_norm_y.sum(1)  # (N,)
 | 
			
		||||
    if point_reduction == "mean":
 | 
			
		||||
        x_lengths_clamped = x_lengths.clamp(min=1)
 | 
			
		||||
        y_lengths_clamped = y_lengths.clamp(min=1)
 | 
			
		||||
        cham_x /= x_lengths_clamped
 | 
			
		||||
        cham_y /= y_lengths_clamped
 | 
			
		||||
        if return_normals:
 | 
			
		||||
            cham_norm_x /= x_lengths_clamped
 | 
			
		||||
            cham_norm_y /= y_lengths_clamped
 | 
			
		||||
 | 
			
		||||
    if batch_reduction is not None:
 | 
			
		||||
        # batch_reduction == "sum"
 | 
			
		||||
        cham_x = cham_x.sum()
 | 
			
		||||
        cham_y = cham_y.sum()
 | 
			
		||||
        if return_normals:
 | 
			
		||||
            cham_norm_x = cham_norm_x.sum()
 | 
			
		||||
            cham_norm_y = cham_norm_y.sum()
 | 
			
		||||
        if batch_reduction == "mean":
 | 
			
		||||
            div = weights.sum() if weights is not None else max(N, 1)
 | 
			
		||||
            cham_x /= div
 | 
			
		||||
            cham_y /= div
 | 
			
		||||
            if return_normals:
 | 
			
		||||
                cham_norm_x /= div
 | 
			
		||||
                cham_norm_y /= div
 | 
			
		||||
 | 
			
		||||
    cham_dist = cham_x + cham_y
 | 
			
		||||
    cham_normals = cham_norm_x + cham_norm_y if return_normals else None
 | 
			
		||||
 | 
			
		||||
    return cham_dist, cham_normals
 | 
			
		||||
 | 
			
		||||
@ -88,7 +88,9 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def chamfer_distance_naive_pointclouds(p1, p2, norm: int = 2, device="cpu"):
 | 
			
		||||
    def chamfer_distance_naive_pointclouds(
 | 
			
		||||
        p1, p2, norm: int = 2, device="cpu", abs_cosine=True
 | 
			
		||||
    ):
 | 
			
		||||
        """
 | 
			
		||||
        Naive iterative implementation of nearest neighbor and chamfer distance.
 | 
			
		||||
        x and y are assumed to be pointclouds objects with points and optionally normals.
 | 
			
		||||
@ -146,17 +148,20 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
        if return_normals:
 | 
			
		||||
            x_index = dist.argmin(2).view(N, P1, 1).expand(N, P1, 3)
 | 
			
		||||
            y_index = dist.argmin(1).view(N, P2, 1).expand(N, P2, 3)
 | 
			
		||||
            lnorm1 = 1 - torch.abs(
 | 
			
		||||
                F.cosine_similarity(
 | 
			
		||||
                    x_normals, y_normals.gather(1, x_index), dim=2, eps=1e-6
 | 
			
		||||
                )
 | 
			
		||||
            cosine_sim1 = F.cosine_similarity(
 | 
			
		||||
                x_normals, y_normals.gather(1, x_index), dim=2, eps=1e-6
 | 
			
		||||
            )
 | 
			
		||||
            lnorm2 = 1 - torch.abs(
 | 
			
		||||
                F.cosine_similarity(
 | 
			
		||||
                    y_normals, x_normals.gather(1, y_index), dim=2, eps=1e-6
 | 
			
		||||
                )
 | 
			
		||||
            cosine_sim2 = F.cosine_similarity(
 | 
			
		||||
                y_normals, x_normals.gather(1, y_index), dim=2, eps=1e-6
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
            if abs_cosine:
 | 
			
		||||
                lnorm1 = 1 - torch.abs(cosine_sim1)
 | 
			
		||||
                lnorm2 = 1 - torch.abs(cosine_sim2)
 | 
			
		||||
            else:
 | 
			
		||||
                lnorm1 = 1 - cosine_sim1
 | 
			
		||||
                lnorm2 = 1 - cosine_sim2
 | 
			
		||||
 | 
			
		||||
            if is_x_heterogeneous:
 | 
			
		||||
                lnorm1[x_mask] = 0.0
 | 
			
		||||
            if is_y_heterogeneous:
 | 
			
		||||
@ -167,7 +172,9 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
        return loss, lnorm
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def chamfer_distance_naive(x, y, x_normals=None, y_normals=None, norm: int = 2):
 | 
			
		||||
    def chamfer_distance_naive(
 | 
			
		||||
        x, y, x_normals=None, y_normals=None, norm: int = 2, abs_cosine=True
 | 
			
		||||
    ):
 | 
			
		||||
        """
 | 
			
		||||
        Naive iterative implementation of nearest neighbor and chamfer distance.
 | 
			
		||||
        Returns lists of the unreduced loss and loss_normals. This naive
 | 
			
		||||
@ -200,16 +207,21 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
        if return_normals:
 | 
			
		||||
            x_index = dist.argmin(2).view(N, P1, 1).expand(N, P1, 3)
 | 
			
		||||
            y_index = dist.argmin(1).view(N, P2, 1).expand(N, P2, 3)
 | 
			
		||||
            lnorm1 = 1 - torch.abs(
 | 
			
		||||
                F.cosine_similarity(
 | 
			
		||||
                    x_normals, y_normals.gather(1, x_index), dim=2, eps=1e-6
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
            cosine_sim1 = F.cosine_similarity(
 | 
			
		||||
                x_normals, y_normals.gather(1, x_index), dim=2, eps=1e-6
 | 
			
		||||
            )
 | 
			
		||||
            lnorm2 = 1 - torch.abs(
 | 
			
		||||
                F.cosine_similarity(
 | 
			
		||||
                    y_normals, x_normals.gather(1, y_index), dim=2, eps=1e-6
 | 
			
		||||
                )
 | 
			
		||||
            cosine_sim2 = F.cosine_similarity(
 | 
			
		||||
                y_normals, x_normals.gather(1, y_index), dim=2, eps=1e-6
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
            if abs_cosine:
 | 
			
		||||
                lnorm1 = 1 - torch.abs(cosine_sim1)
 | 
			
		||||
                lnorm2 = 1 - torch.abs(cosine_sim2)
 | 
			
		||||
            else:
 | 
			
		||||
                lnorm1 = 1 - cosine_sim1
 | 
			
		||||
                lnorm2 = 1 - cosine_sim2
 | 
			
		||||
 | 
			
		||||
            lnorm = [lnorm1, lnorm2]  # [(N, P1), (N, P2)]
 | 
			
		||||
 | 
			
		||||
        return loss, lnorm
 | 
			
		||||
@ -323,6 +335,80 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
                y_lengths,
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
    def test_single_directional_chamfer_vs_naive_pointcloud(self):
 | 
			
		||||
        """
 | 
			
		||||
        Test the single directional settings for chamfer_distance
 | 
			
		||||
        (point reduction = "mean" and batch_reduction="mean") but with heterogeneous
 | 
			
		||||
        pointclouds as input. Compare with the naive implementation of chamfer
 | 
			
		||||
        which supports heterogeneous pointcloud objects.
 | 
			
		||||
        """
 | 
			
		||||
        N, max_P1, max_P2 = 3, 70, 70
 | 
			
		||||
        device = get_random_cuda_device()
 | 
			
		||||
 | 
			
		||||
        for norm in [1, 2]:
 | 
			
		||||
            for abs_cosine in [True, False]:
 | 
			
		||||
                points_normals = TestChamfer.init_pointclouds(N, max_P1, max_P2, device)
 | 
			
		||||
                weights = points_normals.weights
 | 
			
		||||
                x_lengths = points_normals.p1_lengths
 | 
			
		||||
                y_lengths = points_normals.p2_lengths
 | 
			
		||||
 | 
			
		||||
                # Chamfer with tensors as input for heterogeneous pointclouds.
 | 
			
		||||
                cham_tensor, norm_tensor = chamfer_distance(
 | 
			
		||||
                    points_normals.p1,
 | 
			
		||||
                    points_normals.p2,
 | 
			
		||||
                    x_normals=points_normals.n1,
 | 
			
		||||
                    y_normals=points_normals.n2,
 | 
			
		||||
                    x_lengths=points_normals.p1_lengths,
 | 
			
		||||
                    y_lengths=points_normals.p2_lengths,
 | 
			
		||||
                    weights=weights,
 | 
			
		||||
                    norm=norm,
 | 
			
		||||
                    single_directional=True,
 | 
			
		||||
                    abs_cosine=abs_cosine,
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
                # Chamfer with pointclouds as input.
 | 
			
		||||
                (
 | 
			
		||||
                    pred_loss,
 | 
			
		||||
                    pred_norm_loss,
 | 
			
		||||
                ) = TestChamfer.chamfer_distance_naive_pointclouds(
 | 
			
		||||
                    points_normals.cloud1,
 | 
			
		||||
                    points_normals.cloud2,
 | 
			
		||||
                    norm=norm,
 | 
			
		||||
                    device=device,
 | 
			
		||||
                    abs_cosine=abs_cosine,
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
                # Mean reduction point loss.
 | 
			
		||||
                pred_loss[0] *= weights.view(N, 1)
 | 
			
		||||
                pred_loss_mean = pred_loss[0].sum(1) / x_lengths
 | 
			
		||||
                pred_loss_mean = pred_loss_mean.sum()
 | 
			
		||||
                pred_loss_mean /= weights.sum()
 | 
			
		||||
 | 
			
		||||
                # Mean reduction norm loss.
 | 
			
		||||
                pred_norm_loss[0] *= weights.view(N, 1)
 | 
			
		||||
                pred_norm_loss_mean = pred_norm_loss[0].sum(1) / x_lengths
 | 
			
		||||
                pred_norm_loss_mean = pred_norm_loss_mean.sum() / weights.sum()
 | 
			
		||||
 | 
			
		||||
                self.assertClose(pred_loss_mean, cham_tensor)
 | 
			
		||||
                self.assertClose(pred_norm_loss_mean, norm_tensor)
 | 
			
		||||
 | 
			
		||||
                self._check_gradients(
 | 
			
		||||
                    cham_tensor,
 | 
			
		||||
                    norm_tensor,
 | 
			
		||||
                    pred_loss_mean,
 | 
			
		||||
                    pred_norm_loss_mean,
 | 
			
		||||
                    points_normals.cloud1.points_list(),
 | 
			
		||||
                    points_normals.p1,
 | 
			
		||||
                    points_normals.cloud2.points_list(),
 | 
			
		||||
                    points_normals.p2,
 | 
			
		||||
                    points_normals.cloud1.normals_list(),
 | 
			
		||||
                    points_normals.n1,
 | 
			
		||||
                    points_normals.cloud2.normals_list(),
 | 
			
		||||
                    points_normals.n2,
 | 
			
		||||
                    x_lengths,
 | 
			
		||||
                    y_lengths,
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
    def test_chamfer_pointcloud_object_withnormals(self):
 | 
			
		||||
        N = 5
 | 
			
		||||
        P1, P2 = 100, 100
 | 
			
		||||
@ -485,6 +571,53 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
            loss, loss_norm, pred_loss_mean, pred_loss_norm_mean, p1, p11, p2, p22
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def test_single_direction_chamfer_point_reduction_mean(self):
 | 
			
		||||
        """
 | 
			
		||||
        Compare output of vectorized chamfer loss with naive implementation
 | 
			
		||||
        for point_reduction = "mean" and batch_reduction = None.
 | 
			
		||||
        """
 | 
			
		||||
        N, max_P1, max_P2 = 7, 10, 18
 | 
			
		||||
        device = get_random_cuda_device()
 | 
			
		||||
        points_normals = TestChamfer.init_pointclouds(N, max_P1, max_P2, device)
 | 
			
		||||
        p1 = points_normals.p1
 | 
			
		||||
        p2 = points_normals.p2
 | 
			
		||||
        p1_normals = points_normals.n1
 | 
			
		||||
        p2_normals = points_normals.n2
 | 
			
		||||
        weights = points_normals.weights
 | 
			
		||||
        p11 = p1.detach().clone()
 | 
			
		||||
        p22 = p2.detach().clone()
 | 
			
		||||
        p11.requires_grad = True
 | 
			
		||||
        p22.requires_grad = True
 | 
			
		||||
        P1 = p1.shape[1]
 | 
			
		||||
 | 
			
		||||
        pred_loss, pred_loss_norm = TestChamfer.chamfer_distance_naive(
 | 
			
		||||
            p1, p2, x_normals=p1_normals, y_normals=p2_normals
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        # point_reduction = "mean".
 | 
			
		||||
        loss, loss_norm = chamfer_distance(
 | 
			
		||||
            p11,
 | 
			
		||||
            p22,
 | 
			
		||||
            x_normals=p1_normals,
 | 
			
		||||
            y_normals=p2_normals,
 | 
			
		||||
            weights=weights,
 | 
			
		||||
            batch_reduction=None,
 | 
			
		||||
            point_reduction="mean",
 | 
			
		||||
            single_directional=True,
 | 
			
		||||
        )
 | 
			
		||||
        pred_loss_mean = pred_loss[0].sum(1) / P1
 | 
			
		||||
        pred_loss_mean *= weights
 | 
			
		||||
        self.assertClose(loss, pred_loss_mean)
 | 
			
		||||
 | 
			
		||||
        pred_loss_norm_mean = pred_loss_norm[0].sum(1) / P1
 | 
			
		||||
        pred_loss_norm_mean *= weights
 | 
			
		||||
        self.assertClose(loss_norm, pred_loss_norm_mean)
 | 
			
		||||
 | 
			
		||||
        # Check gradients
 | 
			
		||||
        self._check_gradients(
 | 
			
		||||
            loss, loss_norm, pred_loss_mean, pred_loss_norm_mean, p1, p11, p2, p22
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def test_chamfer_point_reduction_sum(self):
 | 
			
		||||
        """
 | 
			
		||||
        Compare output of vectorized chamfer loss with naive implementation
 | 
			
		||||
@ -529,6 +662,51 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
            loss, loss_norm, pred_loss_sum, pred_loss_norm_sum, p1, p11, p2, p22
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def test_single_directional_chamfer_point_reduction_sum(self):
 | 
			
		||||
        """
 | 
			
		||||
        Compare output of vectorized single directional chamfer loss with naive implementation
 | 
			
		||||
        for point_reduction = "sum" and batch_reduction = None.
 | 
			
		||||
        """
 | 
			
		||||
        N, P1, P2 = 7, 10, 18
 | 
			
		||||
        device = get_random_cuda_device()
 | 
			
		||||
        points_normals = TestChamfer.init_pointclouds(N, P1, P2, device)
 | 
			
		||||
        p1 = points_normals.p1
 | 
			
		||||
        p2 = points_normals.p2
 | 
			
		||||
        p1_normals = points_normals.n1
 | 
			
		||||
        p2_normals = points_normals.n2
 | 
			
		||||
        weights = points_normals.weights
 | 
			
		||||
        p11 = p1.detach().clone()
 | 
			
		||||
        p22 = p2.detach().clone()
 | 
			
		||||
        p11.requires_grad = True
 | 
			
		||||
        p22.requires_grad = True
 | 
			
		||||
 | 
			
		||||
        pred_loss, pred_loss_norm = TestChamfer.chamfer_distance_naive(
 | 
			
		||||
            p1, p2, x_normals=p1_normals, y_normals=p2_normals
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        loss, loss_norm = chamfer_distance(
 | 
			
		||||
            p11,
 | 
			
		||||
            p22,
 | 
			
		||||
            x_normals=p1_normals,
 | 
			
		||||
            y_normals=p2_normals,
 | 
			
		||||
            weights=weights,
 | 
			
		||||
            batch_reduction=None,
 | 
			
		||||
            point_reduction="sum",
 | 
			
		||||
            single_directional=True,
 | 
			
		||||
        )
 | 
			
		||||
        pred_loss_sum = pred_loss[0].sum(1)
 | 
			
		||||
        pred_loss_sum *= weights
 | 
			
		||||
        self.assertClose(loss, pred_loss_sum)
 | 
			
		||||
 | 
			
		||||
        pred_loss_norm_sum = pred_loss_norm[0].sum(1)
 | 
			
		||||
        pred_loss_norm_sum *= weights
 | 
			
		||||
        self.assertClose(loss_norm, pred_loss_norm_sum)
 | 
			
		||||
 | 
			
		||||
        # Check gradients
 | 
			
		||||
        self._check_gradients(
 | 
			
		||||
            loss, loss_norm, pred_loss_sum, pred_loss_norm_sum, p1, p11, p2, p22
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def _check_gradients(
 | 
			
		||||
        self,
 | 
			
		||||
        loss,
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user