mirror of
				https://github.com/facebookresearch/pytorch3d.git
				synced 2025-11-04 09:52:11 +08:00 
			
		
		
		
	Add "max" point reduction for chamfer distance
Summary: * Adds a "max" option for the point_reduction input to the chamfer_distance function. * When combining the x and y directions, maxes the losses instead of summing them when point_reduction="max". * Moves batch reduction to happen after the directions are combined. * Adds test_chamfer_point_reduction_max and test_single_directional_chamfer_point_reduction_max tests. Fixes https://github.com/facebookresearch/pytorch3d/issues/1838 Reviewed By: bottler Differential Revision: D60614661 fbshipit-source-id: 7879816acfda03e945bada951b931d2c522756eb
This commit is contained in:
		
							parent
							
								
									7edaee71a9
								
							
						
					
					
						commit
						44702fdb4b
					
				@ -27,8 +27,10 @@ def _validate_chamfer_reduction_inputs(
 | 
			
		||||
    """
 | 
			
		||||
    if batch_reduction is not None and batch_reduction not in ["mean", "sum"]:
 | 
			
		||||
        raise ValueError('batch_reduction must be one of ["mean", "sum"] or None')
 | 
			
		||||
    if point_reduction is not None and point_reduction not in ["mean", "sum"]:
 | 
			
		||||
        raise ValueError('point_reduction must be one of ["mean", "sum"] or None')
 | 
			
		||||
    if point_reduction is not None and point_reduction not in ["mean", "sum", "max"]:
 | 
			
		||||
        raise ValueError(
 | 
			
		||||
            'point_reduction must be one of ["mean", "sum", "max"] or None'
 | 
			
		||||
        )
 | 
			
		||||
    if point_reduction is None and batch_reduction is not None:
 | 
			
		||||
        raise ValueError("Batch reduction must be None if point_reduction is None")
 | 
			
		||||
 | 
			
		||||
@ -80,7 +82,6 @@ def _chamfer_distance_single_direction(
 | 
			
		||||
    x_normals,
 | 
			
		||||
    y_normals,
 | 
			
		||||
    weights,
 | 
			
		||||
    batch_reduction: Union[str, None],
 | 
			
		||||
    point_reduction: Union[str, None],
 | 
			
		||||
    norm: int,
 | 
			
		||||
    abs_cosine: bool,
 | 
			
		||||
@ -103,11 +104,6 @@ def _chamfer_distance_single_direction(
 | 
			
		||||
            raise ValueError("weights cannot be negative.")
 | 
			
		||||
        if weights.sum() == 0.0:
 | 
			
		||||
            weights = weights.view(N, 1)
 | 
			
		||||
            if batch_reduction in ["mean", "sum"]:
 | 
			
		||||
                return (
 | 
			
		||||
                    (x.sum((1, 2)) * weights).sum() * 0.0,
 | 
			
		||||
                    (x.sum((1, 2)) * weights).sum() * 0.0,
 | 
			
		||||
                )
 | 
			
		||||
            return ((x.sum((1, 2)) * weights) * 0.0, (x.sum((1, 2)) * weights) * 0.0)
 | 
			
		||||
 | 
			
		||||
    cham_norm_x = x.new_zeros(())
 | 
			
		||||
@ -135,7 +131,10 @@ def _chamfer_distance_single_direction(
 | 
			
		||||
        if weights is not None:
 | 
			
		||||
            cham_norm_x *= weights.view(N, 1)
 | 
			
		||||
 | 
			
		||||
    if point_reduction is not None:
 | 
			
		||||
    if point_reduction == "max":
 | 
			
		||||
        assert not return_normals
 | 
			
		||||
        cham_x = cham_x.max(1).values  # (N,)
 | 
			
		||||
    elif point_reduction is not None:
 | 
			
		||||
        # Apply point reduction
 | 
			
		||||
        cham_x = cham_x.sum(1)  # (N,)
 | 
			
		||||
        if return_normals:
 | 
			
		||||
@ -146,22 +145,34 @@ def _chamfer_distance_single_direction(
 | 
			
		||||
            if return_normals:
 | 
			
		||||
                cham_norm_x /= x_lengths_clamped
 | 
			
		||||
 | 
			
		||||
        if batch_reduction is not None:
 | 
			
		||||
            # batch_reduction == "sum"
 | 
			
		||||
            cham_x = cham_x.sum()
 | 
			
		||||
            if return_normals:
 | 
			
		||||
                cham_norm_x = cham_norm_x.sum()
 | 
			
		||||
            if batch_reduction == "mean":
 | 
			
		||||
                div = weights.sum() if weights is not None else max(N, 1)
 | 
			
		||||
                cham_x /= div
 | 
			
		||||
                if return_normals:
 | 
			
		||||
                    cham_norm_x /= div
 | 
			
		||||
 | 
			
		||||
    cham_dist = cham_x
 | 
			
		||||
    cham_normals = cham_norm_x if return_normals else None
 | 
			
		||||
    return cham_dist, cham_normals
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _apply_batch_reduction(
 | 
			
		||||
    cham_x, cham_norm_x, weights, batch_reduction: Union[str, None]
 | 
			
		||||
):
 | 
			
		||||
    if batch_reduction is None:
 | 
			
		||||
        return (cham_x, cham_norm_x)
 | 
			
		||||
    # batch_reduction == "sum"
 | 
			
		||||
    N = cham_x.shape[0]
 | 
			
		||||
    cham_x = cham_x.sum()
 | 
			
		||||
    if cham_norm_x is not None:
 | 
			
		||||
        cham_norm_x = cham_norm_x.sum()
 | 
			
		||||
    if batch_reduction == "mean":
 | 
			
		||||
        if weights is None:
 | 
			
		||||
            div = max(N, 1)
 | 
			
		||||
        elif weights.sum() == 0.0:
 | 
			
		||||
            div = 1
 | 
			
		||||
        else:
 | 
			
		||||
            div = weights.sum()
 | 
			
		||||
        cham_x /= div
 | 
			
		||||
        if cham_norm_x is not None:
 | 
			
		||||
            cham_norm_x /= div
 | 
			
		||||
    return (cham_x, cham_norm_x)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def chamfer_distance(
 | 
			
		||||
    x,
 | 
			
		||||
    y,
 | 
			
		||||
@ -197,7 +208,8 @@ def chamfer_distance(
 | 
			
		||||
        batch_reduction: Reduction operation to apply for the loss across the
 | 
			
		||||
            batch, can be one of ["mean", "sum"] or None.
 | 
			
		||||
        point_reduction: Reduction operation to apply for the loss across the
 | 
			
		||||
            points, can be one of ["mean", "sum"] or None.
 | 
			
		||||
            points, can be one of ["mean", "sum", "max"] or None. Using "max" leads to the
 | 
			
		||||
            Hausdorff distance.
 | 
			
		||||
        norm: int indicates the norm used for the distance. Supports 1 for L1 and 2 for L2.
 | 
			
		||||
        single_directional: If False (default), loss comes from both the distance between
 | 
			
		||||
            each point in x and its nearest neighbor in y and each point in y and its nearest
 | 
			
		||||
@ -227,6 +239,10 @@ def chamfer_distance(
 | 
			
		||||
 | 
			
		||||
    if not ((norm == 1) or (norm == 2)):
 | 
			
		||||
        raise ValueError("Support for 1 or 2 norm.")
 | 
			
		||||
 | 
			
		||||
    if point_reduction == "max" and (x_normals is not None or y_normals is not None):
 | 
			
		||||
        raise ValueError('Normals must be None if point_reduction is "max"')
 | 
			
		||||
 | 
			
		||||
    x, x_lengths, x_normals = _handle_pointcloud_input(x, x_lengths, x_normals)
 | 
			
		||||
    y, y_lengths, y_normals = _handle_pointcloud_input(y, y_lengths, y_normals)
 | 
			
		||||
 | 
			
		||||
@ -238,13 +254,13 @@ def chamfer_distance(
 | 
			
		||||
        x_normals,
 | 
			
		||||
        y_normals,
 | 
			
		||||
        weights,
 | 
			
		||||
        batch_reduction,
 | 
			
		||||
        point_reduction,
 | 
			
		||||
        norm,
 | 
			
		||||
        abs_cosine,
 | 
			
		||||
    )
 | 
			
		||||
    if single_directional:
 | 
			
		||||
        return cham_x, cham_norm_x
 | 
			
		||||
        loss = cham_x
 | 
			
		||||
        loss_normals = cham_norm_x
 | 
			
		||||
    else:
 | 
			
		||||
        cham_y, cham_norm_y = _chamfer_distance_single_direction(
 | 
			
		||||
            y,
 | 
			
		||||
@ -254,17 +270,23 @@ def chamfer_distance(
 | 
			
		||||
            y_normals,
 | 
			
		||||
            x_normals,
 | 
			
		||||
            weights,
 | 
			
		||||
            batch_reduction,
 | 
			
		||||
            point_reduction,
 | 
			
		||||
            norm,
 | 
			
		||||
            abs_cosine,
 | 
			
		||||
        )
 | 
			
		||||
        if point_reduction is not None:
 | 
			
		||||
            return (
 | 
			
		||||
                cham_x + cham_y,
 | 
			
		||||
                (cham_norm_x + cham_norm_y) if cham_norm_x is not None else None,
 | 
			
		||||
            )
 | 
			
		||||
        return (
 | 
			
		||||
            (cham_x, cham_y),
 | 
			
		||||
            (cham_norm_x, cham_norm_y) if cham_norm_x is not None else None,
 | 
			
		||||
        )
 | 
			
		||||
        if point_reduction == "max":
 | 
			
		||||
            loss = torch.maximum(cham_x, cham_y)
 | 
			
		||||
            loss_normals = None
 | 
			
		||||
        elif point_reduction is not None:
 | 
			
		||||
            loss = cham_x + cham_y
 | 
			
		||||
            if cham_norm_x is not None:
 | 
			
		||||
                loss_normals = cham_norm_x + cham_norm_y
 | 
			
		||||
            else:
 | 
			
		||||
                loss_normals = None
 | 
			
		||||
        else:
 | 
			
		||||
            loss = (cham_x, cham_y)
 | 
			
		||||
            if cham_norm_x is not None:
 | 
			
		||||
                loss_normals = (cham_norm_x, cham_norm_y)
 | 
			
		||||
            else:
 | 
			
		||||
                loss_normals = None
 | 
			
		||||
    return _apply_batch_reduction(loss, loss_normals, weights, batch_reduction)
 | 
			
		||||
 | 
			
		||||
@ -847,6 +847,85 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
            loss, loss_norm, pred_loss[0], pred_loss_norm[0], p1, p11, p2, p22
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def test_chamfer_point_reduction_max(self):
 | 
			
		||||
        """
 | 
			
		||||
        Compare output of vectorized chamfer loss with naive implementation
 | 
			
		||||
        for point_reduction = "max" and batch_reduction = None.
 | 
			
		||||
        """
 | 
			
		||||
        N, P1, P2 = 7, 10, 18
 | 
			
		||||
        device = get_random_cuda_device()
 | 
			
		||||
        points_normals = TestChamfer.init_pointclouds(N, P1, P2, device)
 | 
			
		||||
        p1 = points_normals.p1
 | 
			
		||||
        p2 = points_normals.p2
 | 
			
		||||
        weights = points_normals.weights
 | 
			
		||||
        p11 = p1.detach().clone()
 | 
			
		||||
        p22 = p2.detach().clone()
 | 
			
		||||
        p11.requires_grad = True
 | 
			
		||||
        p22.requires_grad = True
 | 
			
		||||
 | 
			
		||||
        pred_loss, unused_pred_loss_norm = TestChamfer.chamfer_distance_naive(
 | 
			
		||||
            p1, p2, x_normals=None, y_normals=None
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        loss, loss_norm = chamfer_distance(
 | 
			
		||||
            p11,
 | 
			
		||||
            p22,
 | 
			
		||||
            x_normals=None,
 | 
			
		||||
            y_normals=None,
 | 
			
		||||
            weights=weights,
 | 
			
		||||
            batch_reduction=None,
 | 
			
		||||
            point_reduction="max",
 | 
			
		||||
        )
 | 
			
		||||
        pred_loss_max = torch.maximum(
 | 
			
		||||
            pred_loss[0].max(1).values, pred_loss[1].max(1).values
 | 
			
		||||
        )
 | 
			
		||||
        pred_loss_max *= weights
 | 
			
		||||
        self.assertClose(loss, pred_loss_max)
 | 
			
		||||
 | 
			
		||||
        self.assertIsNone(loss_norm)
 | 
			
		||||
 | 
			
		||||
        # Check gradients
 | 
			
		||||
        self._check_gradients(loss, loss_norm, pred_loss_max, None, p1, p11, p2, p22)
 | 
			
		||||
 | 
			
		||||
    def test_single_directional_chamfer_point_reduction_max(self):
 | 
			
		||||
        """
 | 
			
		||||
        Compare output of vectorized single directional chamfer loss with naive implementation
 | 
			
		||||
        for point_reduction = "max" and batch_reduction = None.
 | 
			
		||||
        """
 | 
			
		||||
        N, P1, P2 = 7, 10, 18
 | 
			
		||||
        device = get_random_cuda_device()
 | 
			
		||||
        points_normals = TestChamfer.init_pointclouds(N, P1, P2, device)
 | 
			
		||||
        p1 = points_normals.p1
 | 
			
		||||
        p2 = points_normals.p2
 | 
			
		||||
        weights = points_normals.weights
 | 
			
		||||
        p11 = p1.detach().clone()
 | 
			
		||||
        p22 = p2.detach().clone()
 | 
			
		||||
        p11.requires_grad = True
 | 
			
		||||
        p22.requires_grad = True
 | 
			
		||||
 | 
			
		||||
        pred_loss, unused_pred_loss_norm = TestChamfer.chamfer_distance_naive(
 | 
			
		||||
            p1, p2, x_normals=None, y_normals=None
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        loss, loss_norm = chamfer_distance(
 | 
			
		||||
            p11,
 | 
			
		||||
            p22,
 | 
			
		||||
            x_normals=None,
 | 
			
		||||
            y_normals=None,
 | 
			
		||||
            weights=weights,
 | 
			
		||||
            batch_reduction=None,
 | 
			
		||||
            point_reduction="max",
 | 
			
		||||
            single_directional=True,
 | 
			
		||||
        )
 | 
			
		||||
        pred_loss_max = pred_loss[0].max(1).values
 | 
			
		||||
        pred_loss_max *= weights
 | 
			
		||||
        self.assertClose(loss, pred_loss_max)
 | 
			
		||||
 | 
			
		||||
        self.assertIsNone(loss_norm)
 | 
			
		||||
 | 
			
		||||
        # Check gradients
 | 
			
		||||
        self._check_gradients(loss, loss_norm, pred_loss_max, None, p1, p11, p2, p22)
 | 
			
		||||
 | 
			
		||||
    def _check_gradients(
 | 
			
		||||
        self,
 | 
			
		||||
        loss,
 | 
			
		||||
@ -1020,9 +1099,9 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
        with self.assertRaisesRegex(ValueError, "batch_reduction must be one of"):
 | 
			
		||||
            chamfer_distance(p1, p2, weights=weights, batch_reduction="max")
 | 
			
		||||
 | 
			
		||||
        # Error when point_reduction is not in ["mean", "sum"] or None.
 | 
			
		||||
        # Error when point_reduction is not in ["mean", "sum", "max"] or None.
 | 
			
		||||
        with self.assertRaisesRegex(ValueError, "point_reduction must be one of"):
 | 
			
		||||
            chamfer_distance(p1, p2, weights=weights, point_reduction="max")
 | 
			
		||||
            chamfer_distance(p1, p2, weights=weights, point_reduction="min")
 | 
			
		||||
 | 
			
		||||
    def test_incorrect_weights(self):
 | 
			
		||||
        N, P1, P2 = 16, 64, 128
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user