mirror of
				https://github.com/facebookresearch/pytorch3d.git
				synced 2025-11-04 18:02:14 +08:00 
			
		
		
		
	add None option for chamfer distance point reduction (#1605)
Summary: The `chamfer_distance` function currently allows `"sum"` or `"mean"` reduction, but does not support returning unreduced (per-point) loss terms. Unreduced losses could be useful if the user wishes to inspect individual losses, or perform additional modifications to loss terms before reduction. One example would be implementing a robust kernel over the loss. This PR adds a `None` option to the `point_reduction` parameter, similar to `batch_reduction`. In case of bi-directional chamfer loss, both the forward and backward distances are returned (a tuple of Tensors of shape `[D, N]` is returned). If normals are provided, similar logic applies to normals as well. This PR addresses issue https://github.com/facebookresearch/pytorch3d/issues/622. Pull Request resolved: https://github.com/facebookresearch/pytorch3d/pull/1605 Reviewed By: jcjohnson Differential Revision: D48313857 Pulled By: bottler fbshipit-source-id: 35c824827a143649b04166c4817449e1341b7fd9
This commit is contained in:
		
							parent
							
								
									099fc069fb
								
							
						
					
					
						commit
						d84f274a08
					
				@ -13,7 +13,7 @@ from pytorch3d.structures.pointclouds import Pointclouds
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _validate_chamfer_reduction_inputs(
 | 
			
		||||
    batch_reduction: Union[str, None], point_reduction: str
 | 
			
		||||
    batch_reduction: Union[str, None], point_reduction: Union[str, None]
 | 
			
		||||
) -> None:
 | 
			
		||||
    """Check the requested reductions are valid.
 | 
			
		||||
 | 
			
		||||
@ -21,12 +21,14 @@ def _validate_chamfer_reduction_inputs(
 | 
			
		||||
        batch_reduction: Reduction operation to apply for the loss across the
 | 
			
		||||
            batch, can be one of ["mean", "sum"] or None.
 | 
			
		||||
        point_reduction: Reduction operation to apply for the loss across the
 | 
			
		||||
            points, can be one of ["mean", "sum"].
 | 
			
		||||
            points, can be one of ["mean", "sum"] or None.
 | 
			
		||||
    """
 | 
			
		||||
    if batch_reduction is not None and batch_reduction not in ["mean", "sum"]:
 | 
			
		||||
        raise ValueError('batch_reduction must be one of ["mean", "sum"] or None')
 | 
			
		||||
    if point_reduction not in ["mean", "sum"]:
 | 
			
		||||
        raise ValueError('point_reduction must be one of ["mean", "sum"]')
 | 
			
		||||
    if point_reduction is not None and point_reduction not in ["mean", "sum"]:
 | 
			
		||||
        raise ValueError('point_reduction must be one of ["mean", "sum"] or None')
 | 
			
		||||
    if point_reduction is None and batch_reduction is not None:
 | 
			
		||||
        raise ValueError("Batch reduction must be None if point_reduction is None")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _handle_pointcloud_input(
 | 
			
		||||
@ -77,7 +79,7 @@ def _chamfer_distance_single_direction(
 | 
			
		||||
    y_normals,
 | 
			
		||||
    weights,
 | 
			
		||||
    batch_reduction: Union[str, None],
 | 
			
		||||
    point_reduction: str,
 | 
			
		||||
    point_reduction: Union[str, None],
 | 
			
		||||
    norm: int,
 | 
			
		||||
    abs_cosine: bool,
 | 
			
		||||
):
 | 
			
		||||
@ -130,26 +132,28 @@ def _chamfer_distance_single_direction(
 | 
			
		||||
 | 
			
		||||
        if weights is not None:
 | 
			
		||||
            cham_norm_x *= weights.view(N, 1)
 | 
			
		||||
        cham_norm_x = cham_norm_x.sum(1)  # (N,)
 | 
			
		||||
 | 
			
		||||
    # Apply point reduction
 | 
			
		||||
    cham_x = cham_x.sum(1)  # (N,)
 | 
			
		||||
    if point_reduction == "mean":
 | 
			
		||||
        x_lengths_clamped = x_lengths.clamp(min=1)
 | 
			
		||||
        cham_x /= x_lengths_clamped
 | 
			
		||||
    if point_reduction is not None:
 | 
			
		||||
        # Apply point reduction
 | 
			
		||||
        cham_x = cham_x.sum(1)  # (N,)
 | 
			
		||||
        if return_normals:
 | 
			
		||||
            cham_norm_x /= x_lengths_clamped
 | 
			
		||||
 | 
			
		||||
    if batch_reduction is not None:
 | 
			
		||||
        # batch_reduction == "sum"
 | 
			
		||||
        cham_x = cham_x.sum()
 | 
			
		||||
        if return_normals:
 | 
			
		||||
            cham_norm_x = cham_norm_x.sum()
 | 
			
		||||
        if batch_reduction == "mean":
 | 
			
		||||
            div = weights.sum() if weights is not None else max(N, 1)
 | 
			
		||||
            cham_x /= div
 | 
			
		||||
            cham_norm_x = cham_norm_x.sum(1)  # (N,)
 | 
			
		||||
        if point_reduction == "mean":
 | 
			
		||||
            x_lengths_clamped = x_lengths.clamp(min=1)
 | 
			
		||||
            cham_x /= x_lengths_clamped
 | 
			
		||||
            if return_normals:
 | 
			
		||||
                cham_norm_x /= div
 | 
			
		||||
                cham_norm_x /= x_lengths_clamped
 | 
			
		||||
 | 
			
		||||
        if batch_reduction is not None:
 | 
			
		||||
            # batch_reduction == "sum"
 | 
			
		||||
            cham_x = cham_x.sum()
 | 
			
		||||
            if return_normals:
 | 
			
		||||
                cham_norm_x = cham_norm_x.sum()
 | 
			
		||||
            if batch_reduction == "mean":
 | 
			
		||||
                div = weights.sum() if weights is not None else max(N, 1)
 | 
			
		||||
                cham_x /= div
 | 
			
		||||
                if return_normals:
 | 
			
		||||
                    cham_norm_x /= div
 | 
			
		||||
 | 
			
		||||
    cham_dist = cham_x
 | 
			
		||||
    cham_normals = cham_norm_x if return_normals else None
 | 
			
		||||
@ -165,7 +169,7 @@ def chamfer_distance(
 | 
			
		||||
    y_normals=None,
 | 
			
		||||
    weights=None,
 | 
			
		||||
    batch_reduction: Union[str, None] = "mean",
 | 
			
		||||
    point_reduction: str = "mean",
 | 
			
		||||
    point_reduction: Union[str, None] = "mean",
 | 
			
		||||
    norm: int = 2,
 | 
			
		||||
    single_directional: bool = False,
 | 
			
		||||
    abs_cosine: bool = True,
 | 
			
		||||
@ -191,7 +195,7 @@ def chamfer_distance(
 | 
			
		||||
        batch_reduction: Reduction operation to apply for the loss across the
 | 
			
		||||
            batch, can be one of ["mean", "sum"] or None.
 | 
			
		||||
        point_reduction: Reduction operation to apply for the loss across the
 | 
			
		||||
            points, can be one of ["mean", "sum"].
 | 
			
		||||
            points, can be one of ["mean", "sum"] or None.
 | 
			
		||||
        norm: int indicates the norm used for the distance. Supports 1 for L1 and 2 for L2.
 | 
			
		||||
        single_directional: If False (default), loss comes from both the distance between
 | 
			
		||||
            each point in x and its nearest neighbor in y and each point in y and its nearest
 | 
			
		||||
@ -206,11 +210,16 @@ def chamfer_distance(
 | 
			
		||||
        2-element tuple containing
 | 
			
		||||
 | 
			
		||||
        - **loss**: Tensor giving the reduced distance between the pointclouds
 | 
			
		||||
          in x and the pointclouds in y.
 | 
			
		||||
          in x and the pointclouds in y. If point_reduction is None, a 2-element
 | 
			
		||||
          tuple of Tensors containing forward and backward loss terms shaped (N, P1)
 | 
			
		||||
          and (N, P2) (if single_directional is False) or a Tensor containing loss
 | 
			
		||||
          terms shaped (N, P1) (if single_directional is True) is returned.
 | 
			
		||||
        - **loss_normals**: Tensor giving the reduced cosine distance of normals
 | 
			
		||||
          between pointclouds in x and pointclouds in y. Returns None if
 | 
			
		||||
          x_normals and y_normals are None.
 | 
			
		||||
 | 
			
		||||
          x_normals and y_normals are None. If point_reduction is None, a 2-element
 | 
			
		||||
          tuple of Tensors containing forward and backward loss terms shaped (N, P1)
 | 
			
		||||
          and (N, P2) (if single_directional is False) or a Tensor containing loss
 | 
			
		||||
          terms shaped (N, P1) (if single_directional is True) is returned.
 | 
			
		||||
    """
 | 
			
		||||
    _validate_chamfer_reduction_inputs(batch_reduction, point_reduction)
 | 
			
		||||
 | 
			
		||||
@ -248,7 +257,12 @@ def chamfer_distance(
 | 
			
		||||
            norm,
 | 
			
		||||
            abs_cosine,
 | 
			
		||||
        )
 | 
			
		||||
        if point_reduction is not None:
 | 
			
		||||
            return (
 | 
			
		||||
                cham_x + cham_y,
 | 
			
		||||
                (cham_norm_x + cham_norm_y) if cham_norm_x is not None else None,
 | 
			
		||||
            )
 | 
			
		||||
        return (
 | 
			
		||||
            cham_x + cham_y,
 | 
			
		||||
            (cham_norm_x + cham_norm_y) if cham_norm_x is not None else None,
 | 
			
		||||
            (cham_x, cham_y),
 | 
			
		||||
            (cham_norm_x, cham_norm_y) if cham_norm_x is not None else None,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
@ -421,9 +421,9 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
            ("mean", "mean"),
 | 
			
		||||
            ("sum", None),
 | 
			
		||||
            ("mean", None),
 | 
			
		||||
            (None, None),
 | 
			
		||||
        ]
 | 
			
		||||
        for (point_reduction, batch_reduction) in reductions:
 | 
			
		||||
 | 
			
		||||
        for point_reduction, batch_reduction in reductions:
 | 
			
		||||
            # Reinitialize all the tensors so that the
 | 
			
		||||
            # backward pass can be computed.
 | 
			
		||||
            points_normals = TestChamfer.init_pointclouds(
 | 
			
		||||
@ -450,24 +450,52 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
                batch_reduction=batch_reduction,
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
            self.assertClose(cham_cloud, cham_tensor)
 | 
			
		||||
            self.assertClose(norm_cloud, norm_tensor)
 | 
			
		||||
            self._check_gradients(
 | 
			
		||||
                cham_tensor,
 | 
			
		||||
                norm_tensor,
 | 
			
		||||
                cham_cloud,
 | 
			
		||||
                norm_cloud,
 | 
			
		||||
                points_normals.cloud1.points_list(),
 | 
			
		||||
                points_normals.p1,
 | 
			
		||||
                points_normals.cloud2.points_list(),
 | 
			
		||||
                points_normals.p2,
 | 
			
		||||
                points_normals.cloud1.normals_list(),
 | 
			
		||||
                points_normals.n1,
 | 
			
		||||
                points_normals.cloud2.normals_list(),
 | 
			
		||||
                points_normals.n2,
 | 
			
		||||
                points_normals.p1_lengths,
 | 
			
		||||
                points_normals.p2_lengths,
 | 
			
		||||
            )
 | 
			
		||||
            if point_reduction is None:
 | 
			
		||||
                cham_tensor_bidirectional = torch.hstack(
 | 
			
		||||
                    [cham_tensor[0], cham_tensor[1]]
 | 
			
		||||
                )
 | 
			
		||||
                norm_tensor_bidirectional = torch.hstack(
 | 
			
		||||
                    [norm_tensor[0], norm_tensor[1]]
 | 
			
		||||
                )
 | 
			
		||||
                cham_cloud_bidirectional = torch.hstack([cham_cloud[0], cham_cloud[1]])
 | 
			
		||||
                norm_cloud_bidirectional = torch.hstack([norm_cloud[0], norm_cloud[1]])
 | 
			
		||||
                self.assertClose(cham_cloud_bidirectional, cham_tensor_bidirectional)
 | 
			
		||||
                self.assertClose(norm_cloud_bidirectional, norm_tensor_bidirectional)
 | 
			
		||||
                self._check_gradients(
 | 
			
		||||
                    cham_tensor_bidirectional,
 | 
			
		||||
                    norm_tensor_bidirectional,
 | 
			
		||||
                    cham_cloud_bidirectional,
 | 
			
		||||
                    norm_cloud_bidirectional,
 | 
			
		||||
                    points_normals.cloud1.points_list(),
 | 
			
		||||
                    points_normals.p1,
 | 
			
		||||
                    points_normals.cloud2.points_list(),
 | 
			
		||||
                    points_normals.p2,
 | 
			
		||||
                    points_normals.cloud1.normals_list(),
 | 
			
		||||
                    points_normals.n1,
 | 
			
		||||
                    points_normals.cloud2.normals_list(),
 | 
			
		||||
                    points_normals.n2,
 | 
			
		||||
                    points_normals.p1_lengths,
 | 
			
		||||
                    points_normals.p2_lengths,
 | 
			
		||||
                )
 | 
			
		||||
            else:
 | 
			
		||||
                self.assertClose(cham_cloud, cham_tensor)
 | 
			
		||||
                self.assertClose(norm_cloud, norm_tensor)
 | 
			
		||||
                self._check_gradients(
 | 
			
		||||
                    cham_tensor,
 | 
			
		||||
                    norm_tensor,
 | 
			
		||||
                    cham_cloud,
 | 
			
		||||
                    norm_cloud,
 | 
			
		||||
                    points_normals.cloud1.points_list(),
 | 
			
		||||
                    points_normals.p1,
 | 
			
		||||
                    points_normals.cloud2.points_list(),
 | 
			
		||||
                    points_normals.p2,
 | 
			
		||||
                    points_normals.cloud1.normals_list(),
 | 
			
		||||
                    points_normals.n1,
 | 
			
		||||
                    points_normals.cloud2.normals_list(),
 | 
			
		||||
                    points_normals.n2,
 | 
			
		||||
                    points_normals.p1_lengths,
 | 
			
		||||
                    points_normals.p2_lengths,
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
    def test_chamfer_pointcloud_object_nonormals(self):
 | 
			
		||||
        N = 5
 | 
			
		||||
@ -481,9 +509,9 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
            ("mean", "mean"),
 | 
			
		||||
            ("sum", None),
 | 
			
		||||
            ("mean", None),
 | 
			
		||||
            (None, None),
 | 
			
		||||
        ]
 | 
			
		||||
        for (point_reduction, batch_reduction) in reductions:
 | 
			
		||||
 | 
			
		||||
        for point_reduction, batch_reduction in reductions:
 | 
			
		||||
            # Reinitialize all the tensors so that the
 | 
			
		||||
            # backward pass can be computed.
 | 
			
		||||
            points_normals = TestChamfer.init_pointclouds(
 | 
			
		||||
@ -508,19 +536,38 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
                batch_reduction=batch_reduction,
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
            self.assertClose(cham_cloud, cham_tensor)
 | 
			
		||||
            self._check_gradients(
 | 
			
		||||
                cham_tensor,
 | 
			
		||||
                None,
 | 
			
		||||
                cham_cloud,
 | 
			
		||||
                None,
 | 
			
		||||
                points_normals.cloud1.points_list(),
 | 
			
		||||
                points_normals.p1,
 | 
			
		||||
                points_normals.cloud2.points_list(),
 | 
			
		||||
                points_normals.p2,
 | 
			
		||||
                lengths1=points_normals.p1_lengths,
 | 
			
		||||
                lengths2=points_normals.p2_lengths,
 | 
			
		||||
            )
 | 
			
		||||
            if point_reduction is None:
 | 
			
		||||
                cham_tensor_bidirectional = torch.hstack(
 | 
			
		||||
                    [cham_tensor[0], cham_tensor[1]]
 | 
			
		||||
                )
 | 
			
		||||
                cham_cloud_bidirectional = torch.hstack([cham_cloud[0], cham_cloud[1]])
 | 
			
		||||
                self.assertClose(cham_cloud_bidirectional, cham_tensor_bidirectional)
 | 
			
		||||
                self._check_gradients(
 | 
			
		||||
                    cham_tensor_bidirectional,
 | 
			
		||||
                    None,
 | 
			
		||||
                    cham_cloud_bidirectional,
 | 
			
		||||
                    None,
 | 
			
		||||
                    points_normals.cloud1.points_list(),
 | 
			
		||||
                    points_normals.p1,
 | 
			
		||||
                    points_normals.cloud2.points_list(),
 | 
			
		||||
                    points_normals.p2,
 | 
			
		||||
                    lengths1=points_normals.p1_lengths,
 | 
			
		||||
                    lengths2=points_normals.p2_lengths,
 | 
			
		||||
                )
 | 
			
		||||
            else:
 | 
			
		||||
                self.assertClose(cham_cloud, cham_tensor)
 | 
			
		||||
                self._check_gradients(
 | 
			
		||||
                    cham_tensor,
 | 
			
		||||
                    None,
 | 
			
		||||
                    cham_cloud,
 | 
			
		||||
                    None,
 | 
			
		||||
                    points_normals.cloud1.points_list(),
 | 
			
		||||
                    points_normals.p1,
 | 
			
		||||
                    points_normals.cloud2.points_list(),
 | 
			
		||||
                    points_normals.p2,
 | 
			
		||||
                    lengths1=points_normals.p1_lengths,
 | 
			
		||||
                    lengths2=points_normals.p2_lengths,
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
    def test_chamfer_point_reduction_mean(self):
 | 
			
		||||
        """
 | 
			
		||||
@ -707,6 +754,99 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
            loss, loss_norm, pred_loss_sum, pred_loss_norm_sum, p1, p11, p2, p22
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def test_chamfer_point_reduction_none(self):
 | 
			
		||||
        """
 | 
			
		||||
        Compare output of vectorized chamfer loss with naive implementation
 | 
			
		||||
        for point_reduction = None and batch_reduction = None.
 | 
			
		||||
        """
 | 
			
		||||
        N, max_P1, max_P2 = 7, 10, 18
 | 
			
		||||
        device = get_random_cuda_device()
 | 
			
		||||
        points_normals = TestChamfer.init_pointclouds(N, max_P1, max_P2, device)
 | 
			
		||||
        p1 = points_normals.p1
 | 
			
		||||
        p2 = points_normals.p2
 | 
			
		||||
        p1_normals = points_normals.n1
 | 
			
		||||
        p2_normals = points_normals.n2
 | 
			
		||||
        p11 = p1.detach().clone()
 | 
			
		||||
        p22 = p2.detach().clone()
 | 
			
		||||
        p11.requires_grad = True
 | 
			
		||||
        p22.requires_grad = True
 | 
			
		||||
 | 
			
		||||
        pred_loss, pred_loss_norm = TestChamfer.chamfer_distance_naive(
 | 
			
		||||
            p1, p2, x_normals=p1_normals, y_normals=p2_normals
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        # point_reduction = None
 | 
			
		||||
        loss, loss_norm = chamfer_distance(
 | 
			
		||||
            p11,
 | 
			
		||||
            p22,
 | 
			
		||||
            x_normals=p1_normals,
 | 
			
		||||
            y_normals=p2_normals,
 | 
			
		||||
            batch_reduction=None,
 | 
			
		||||
            point_reduction=None,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        loss_bidirectional = torch.hstack([loss[0], loss[1]])
 | 
			
		||||
        pred_loss_bidirectional = torch.hstack([pred_loss[0], pred_loss[1]])
 | 
			
		||||
        loss_norm_bidirectional = torch.hstack([loss_norm[0], loss_norm[1]])
 | 
			
		||||
        pred_loss_norm_bidirectional = torch.hstack(
 | 
			
		||||
            [pred_loss_norm[0], pred_loss_norm[1]]
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        self.assertClose(loss_bidirectional, pred_loss_bidirectional)
 | 
			
		||||
        self.assertClose(loss_norm_bidirectional, pred_loss_norm_bidirectional)
 | 
			
		||||
 | 
			
		||||
        # Check gradients
 | 
			
		||||
        self._check_gradients(
 | 
			
		||||
            loss_bidirectional,
 | 
			
		||||
            loss_norm_bidirectional,
 | 
			
		||||
            pred_loss_bidirectional,
 | 
			
		||||
            pred_loss_norm_bidirectional,
 | 
			
		||||
            p1,
 | 
			
		||||
            p11,
 | 
			
		||||
            p2,
 | 
			
		||||
            p22,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def test_single_direction_chamfer_point_reduction_none(self):
 | 
			
		||||
        """
 | 
			
		||||
        Compare output of vectorized chamfer loss with naive implementation
 | 
			
		||||
        for point_reduction = None and batch_reduction = None.
 | 
			
		||||
        """
 | 
			
		||||
        N, max_P1, max_P2 = 7, 10, 18
 | 
			
		||||
        device = get_random_cuda_device()
 | 
			
		||||
        points_normals = TestChamfer.init_pointclouds(N, max_P1, max_P2, device)
 | 
			
		||||
        p1 = points_normals.p1
 | 
			
		||||
        p2 = points_normals.p2
 | 
			
		||||
        p1_normals = points_normals.n1
 | 
			
		||||
        p2_normals = points_normals.n2
 | 
			
		||||
        p11 = p1.detach().clone()
 | 
			
		||||
        p22 = p2.detach().clone()
 | 
			
		||||
        p11.requires_grad = True
 | 
			
		||||
        p22.requires_grad = True
 | 
			
		||||
 | 
			
		||||
        pred_loss, pred_loss_norm = TestChamfer.chamfer_distance_naive(
 | 
			
		||||
            p1, p2, x_normals=p1_normals, y_normals=p2_normals
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        # point_reduction = None
 | 
			
		||||
        loss, loss_norm = chamfer_distance(
 | 
			
		||||
            p11,
 | 
			
		||||
            p22,
 | 
			
		||||
            x_normals=p1_normals,
 | 
			
		||||
            y_normals=p2_normals,
 | 
			
		||||
            batch_reduction=None,
 | 
			
		||||
            point_reduction=None,
 | 
			
		||||
            single_directional=True,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        self.assertClose(loss, pred_loss[0])
 | 
			
		||||
        self.assertClose(loss_norm, pred_loss_norm[0])
 | 
			
		||||
 | 
			
		||||
        # Check gradients
 | 
			
		||||
        self._check_gradients(
 | 
			
		||||
            loss, loss_norm, pred_loss[0], pred_loss_norm[0], p1, p11, p2, p22
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def _check_gradients(
 | 
			
		||||
        self,
 | 
			
		||||
        loss,
 | 
			
		||||
@ -880,9 +1020,9 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
        with self.assertRaisesRegex(ValueError, "batch_reduction must be one of"):
 | 
			
		||||
            chamfer_distance(p1, p2, weights=weights, batch_reduction="max")
 | 
			
		||||
 | 
			
		||||
        # Error when point_reduction is not in ["mean", "sum"].
 | 
			
		||||
        # Error when point_reduction is not in ["mean", "sum"] or None.
 | 
			
		||||
        with self.assertRaisesRegex(ValueError, "point_reduction must be one of"):
 | 
			
		||||
            chamfer_distance(p1, p2, weights=weights, point_reduction=None)
 | 
			
		||||
            chamfer_distance(p1, p2, weights=weights, point_reduction="max")
 | 
			
		||||
 | 
			
		||||
    def test_incorrect_weights(self):
 | 
			
		||||
        N, P1, P2 = 16, 64, 128
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user