mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-20 22:30:35 +08:00
add None option for chamfer distance point reduction (#1605)
Summary: The `chamfer_distance` function currently allows `"sum"` or `"mean"` reduction, but does not support returning unreduced (per-point) loss terms. Unreduced losses could be useful if the user wishes to inspect individual losses, or perform additional modifications to loss terms before reduction. One example would be implementing a robust kernel over the loss. This PR adds a `None` option to the `point_reduction` parameter, similar to `batch_reduction`. In case of bi-directional chamfer loss, both the forward and backward distances are returned (a tuple of Tensors of shape `[D, N]` is returned). If normals are provided, similar logic applies to normals as well. This PR addresses issue https://github.com/facebookresearch/pytorch3d/issues/622. Pull Request resolved: https://github.com/facebookresearch/pytorch3d/pull/1605 Reviewed By: jcjohnson Differential Revision: D48313857 Pulled By: bottler fbshipit-source-id: 35c824827a143649b04166c4817449e1341b7fd9
This commit is contained in:
committed by
Facebook GitHub Bot
parent
099fc069fb
commit
d84f274a08
@@ -13,7 +13,7 @@ from pytorch3d.structures.pointclouds import Pointclouds
|
||||
|
||||
|
||||
def _validate_chamfer_reduction_inputs(
|
||||
batch_reduction: Union[str, None], point_reduction: str
|
||||
batch_reduction: Union[str, None], point_reduction: Union[str, None]
|
||||
) -> None:
|
||||
"""Check the requested reductions are valid.
|
||||
|
||||
@@ -21,12 +21,14 @@ def _validate_chamfer_reduction_inputs(
|
||||
batch_reduction: Reduction operation to apply for the loss across the
|
||||
batch, can be one of ["mean", "sum"] or None.
|
||||
point_reduction: Reduction operation to apply for the loss across the
|
||||
points, can be one of ["mean", "sum"].
|
||||
points, can be one of ["mean", "sum"] or None.
|
||||
"""
|
||||
if batch_reduction is not None and batch_reduction not in ["mean", "sum"]:
|
||||
raise ValueError('batch_reduction must be one of ["mean", "sum"] or None')
|
||||
if point_reduction not in ["mean", "sum"]:
|
||||
raise ValueError('point_reduction must be one of ["mean", "sum"]')
|
||||
if point_reduction is not None and point_reduction not in ["mean", "sum"]:
|
||||
raise ValueError('point_reduction must be one of ["mean", "sum"] or None')
|
||||
if point_reduction is None and batch_reduction is not None:
|
||||
raise ValueError("Batch reduction must be None if point_reduction is None")
|
||||
|
||||
|
||||
def _handle_pointcloud_input(
|
||||
@@ -77,7 +79,7 @@ def _chamfer_distance_single_direction(
|
||||
y_normals,
|
||||
weights,
|
||||
batch_reduction: Union[str, None],
|
||||
point_reduction: str,
|
||||
point_reduction: Union[str, None],
|
||||
norm: int,
|
||||
abs_cosine: bool,
|
||||
):
|
||||
@@ -130,26 +132,28 @@ def _chamfer_distance_single_direction(
|
||||
|
||||
if weights is not None:
|
||||
cham_norm_x *= weights.view(N, 1)
|
||||
cham_norm_x = cham_norm_x.sum(1) # (N,)
|
||||
|
||||
# Apply point reduction
|
||||
cham_x = cham_x.sum(1) # (N,)
|
||||
if point_reduction == "mean":
|
||||
x_lengths_clamped = x_lengths.clamp(min=1)
|
||||
cham_x /= x_lengths_clamped
|
||||
if point_reduction is not None:
|
||||
# Apply point reduction
|
||||
cham_x = cham_x.sum(1) # (N,)
|
||||
if return_normals:
|
||||
cham_norm_x /= x_lengths_clamped
|
||||
|
||||
if batch_reduction is not None:
|
||||
# batch_reduction == "sum"
|
||||
cham_x = cham_x.sum()
|
||||
if return_normals:
|
||||
cham_norm_x = cham_norm_x.sum()
|
||||
if batch_reduction == "mean":
|
||||
div = weights.sum() if weights is not None else max(N, 1)
|
||||
cham_x /= div
|
||||
cham_norm_x = cham_norm_x.sum(1) # (N,)
|
||||
if point_reduction == "mean":
|
||||
x_lengths_clamped = x_lengths.clamp(min=1)
|
||||
cham_x /= x_lengths_clamped
|
||||
if return_normals:
|
||||
cham_norm_x /= div
|
||||
cham_norm_x /= x_lengths_clamped
|
||||
|
||||
if batch_reduction is not None:
|
||||
# batch_reduction == "sum"
|
||||
cham_x = cham_x.sum()
|
||||
if return_normals:
|
||||
cham_norm_x = cham_norm_x.sum()
|
||||
if batch_reduction == "mean":
|
||||
div = weights.sum() if weights is not None else max(N, 1)
|
||||
cham_x /= div
|
||||
if return_normals:
|
||||
cham_norm_x /= div
|
||||
|
||||
cham_dist = cham_x
|
||||
cham_normals = cham_norm_x if return_normals else None
|
||||
@@ -165,7 +169,7 @@ def chamfer_distance(
|
||||
y_normals=None,
|
||||
weights=None,
|
||||
batch_reduction: Union[str, None] = "mean",
|
||||
point_reduction: str = "mean",
|
||||
point_reduction: Union[str, None] = "mean",
|
||||
norm: int = 2,
|
||||
single_directional: bool = False,
|
||||
abs_cosine: bool = True,
|
||||
@@ -191,7 +195,7 @@ def chamfer_distance(
|
||||
batch_reduction: Reduction operation to apply for the loss across the
|
||||
batch, can be one of ["mean", "sum"] or None.
|
||||
point_reduction: Reduction operation to apply for the loss across the
|
||||
points, can be one of ["mean", "sum"].
|
||||
points, can be one of ["mean", "sum"] or None.
|
||||
norm: int indicates the norm used for the distance. Supports 1 for L1 and 2 for L2.
|
||||
single_directional: If False (default), loss comes from both the distance between
|
||||
each point in x and its nearest neighbor in y and each point in y and its nearest
|
||||
@@ -206,11 +210,16 @@ def chamfer_distance(
|
||||
2-element tuple containing
|
||||
|
||||
- **loss**: Tensor giving the reduced distance between the pointclouds
|
||||
in x and the pointclouds in y.
|
||||
in x and the pointclouds in y. If point_reduction is None, a 2-element
|
||||
tuple of Tensors containing forward and backward loss terms shaped (N, P1)
|
||||
and (N, P2) (if single_directional is False) or a Tensor containing loss
|
||||
terms shaped (N, P1) (if single_directional is True) is returned.
|
||||
- **loss_normals**: Tensor giving the reduced cosine distance of normals
|
||||
between pointclouds in x and pointclouds in y. Returns None if
|
||||
x_normals and y_normals are None.
|
||||
|
||||
x_normals and y_normals are None. If point_reduction is None, a 2-element
|
||||
tuple of Tensors containing forward and backward loss terms shaped (N, P1)
|
||||
and (N, P2) (if single_directional is False) or a Tensor containing loss
|
||||
terms shaped (N, P1) (if single_directional is True) is returned.
|
||||
"""
|
||||
_validate_chamfer_reduction_inputs(batch_reduction, point_reduction)
|
||||
|
||||
@@ -248,7 +257,12 @@ def chamfer_distance(
|
||||
norm,
|
||||
abs_cosine,
|
||||
)
|
||||
if point_reduction is not None:
|
||||
return (
|
||||
cham_x + cham_y,
|
||||
(cham_norm_x + cham_norm_y) if cham_norm_x is not None else None,
|
||||
)
|
||||
return (
|
||||
cham_x + cham_y,
|
||||
(cham_norm_x + cham_norm_y) if cham_norm_x is not None else None,
|
||||
(cham_x, cham_y),
|
||||
(cham_norm_x, cham_norm_y) if cham_norm_x is not None else None,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user