add None option for chamfer distance point reduction (#1605)

Summary:
The `chamfer_distance` function currently allows `"sum"` or `"mean"` point reduction, but does not support returning the unreduced (per-point) loss terms. Unreduced losses are useful when the user wants to inspect the individual losses, or to modify the loss terms before reduction. One example would be implementing a robust kernel over the loss.
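As a usage sketch of that robust-kernel case (the inputs, the Huber-style kernel, and its `delta` are illustrative choices, not part of this PR):

```python
import torch
from pytorch3d.loss import chamfer_distance

x = torch.randn(4, 128, 3, requires_grad=True)  # (N, P1, 3)
y = torch.randn(4, 256, 3)  # (N, P2, 3)

# With point_reduction=None (and batch_reduction=None), the loss is a tuple
# of unreduced per-point terms of shapes (N, P1) and (N, P2).
(cham_x, cham_y), _ = chamfer_distance(
    x, y, batch_reduction=None, point_reduction=None
)

def huber(d, delta=0.1):
    # Robust kernel: quadratic near zero, linear in the tails.
    return torch.where(d < delta, 0.5 * d**2 / delta, d - 0.5 * delta)

# Apply the kernel per point, then reduce manually.
loss = huber(cham_x).mean() + huber(cham_y).mean()
loss.backward()
```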

This PR adds a `None` option to the `point_reduction` parameter, mirroring `batch_reduction`. In the case of the bi-directional chamfer loss, both the forward and backward distances are returned: a tuple of Tensors of shapes `(N, P1)` and `(N, P2)`. With `single_directional=True`, a single `(N, P1)` Tensor is returned. If normals are provided, the same applies to the normals loss. Note that `batch_reduction` must be `None` when `point_reduction` is `None`.

This PR addresses issue https://github.com/facebookresearch/pytorch3d/issues/622.

Pull Request resolved: https://github.com/facebookresearch/pytorch3d/pull/1605

Reviewed By: jcjohnson

Differential Revision: D48313857

Pulled By: bottler

fbshipit-source-id: 35c824827a143649b04166c4817449e1341b7fd9
Author: Haritha Jayasinghe, 2023-08-15 10:36:06 -07:00 (committed by Facebook GitHub Bot)
parent 099fc069fb
commit d84f274a08
2 changed files with 220 additions and 66 deletions

pytorch3d/loss/chamfer.py

@@ -13,7 +13,7 @@ from pytorch3d.structures.pointclouds import Pointclouds


 def _validate_chamfer_reduction_inputs(
-    batch_reduction: Union[str, None], point_reduction: str
+    batch_reduction: Union[str, None], point_reduction: Union[str, None]
 ) -> None:
     """Check the requested reductions are valid.
@@ -21,12 +21,14 @@ def _validate_chamfer_reduction_inputs(
         batch_reduction: Reduction operation to apply for the loss across the
             batch, can be one of ["mean", "sum"] or None.
         point_reduction: Reduction operation to apply for the loss across the
-            points, can be one of ["mean", "sum"].
+            points, can be one of ["mean", "sum"] or None.
     """
     if batch_reduction is not None and batch_reduction not in ["mean", "sum"]:
         raise ValueError('batch_reduction must be one of ["mean", "sum"] or None')
-    if point_reduction not in ["mean", "sum"]:
-        raise ValueError('point_reduction must be one of ["mean", "sum"]')
+    if point_reduction is not None and point_reduction not in ["mean", "sum"]:
+        raise ValueError('point_reduction must be one of ["mean", "sum"] or None')
+    if point_reduction is None and batch_reduction is not None:
+        raise ValueError("Batch reduction must be None if point_reduction is None")


 def _handle_pointcloud_input(
@@ -77,7 +79,7 @@ def _chamfer_distance_single_direction(
     y_normals,
     weights,
     batch_reduction: Union[str, None],
-    point_reduction: str,
+    point_reduction: Union[str, None],
     norm: int,
     abs_cosine: bool,
 ):
@@ -130,26 +132,28 @@ def _chamfer_distance_single_direction(
         if weights is not None:
             cham_norm_x *= weights.view(N, 1)
-        cham_norm_x = cham_norm_x.sum(1)  # (N,)

-    # Apply point reduction
-    cham_x = cham_x.sum(1)  # (N,)
-    if point_reduction == "mean":
-        x_lengths_clamped = x_lengths.clamp(min=1)
-        cham_x /= x_lengths_clamped
-        if return_normals:
-            cham_norm_x /= x_lengths_clamped
+    if point_reduction is not None:
+        # Apply point reduction
+        cham_x = cham_x.sum(1)  # (N,)
+        if return_normals:
+            cham_norm_x = cham_norm_x.sum(1)  # (N,)
+        if point_reduction == "mean":
+            x_lengths_clamped = x_lengths.clamp(min=1)
+            cham_x /= x_lengths_clamped
+            if return_normals:
+                cham_norm_x /= x_lengths_clamped

-    if batch_reduction is not None:
-        # batch_reduction == "sum"
-        cham_x = cham_x.sum()
-        if return_normals:
-            cham_norm_x = cham_norm_x.sum()
-        if batch_reduction == "mean":
-            div = weights.sum() if weights is not None else max(N, 1)
-            cham_x /= div
-            if return_normals:
-                cham_norm_x /= div
+        if batch_reduction is not None:
+            # batch_reduction == "sum"
+            cham_x = cham_x.sum()
+            if return_normals:
+                cham_norm_x = cham_norm_x.sum()
+            if batch_reduction == "mean":
+                div = weights.sum() if weights is not None else max(N, 1)
+                cham_x /= div
+                if return_normals:
+                    cham_norm_x /= div

     cham_dist = cham_x
     cham_normals = cham_norm_x if return_normals else None
@@ -165,7 +169,7 @@ def chamfer_distance(
     y_normals=None,
     weights=None,
     batch_reduction: Union[str, None] = "mean",
-    point_reduction: str = "mean",
+    point_reduction: Union[str, None] = "mean",
     norm: int = 2,
     single_directional: bool = False,
     abs_cosine: bool = True,
@@ -191,7 +195,7 @@ def chamfer_distance(
         batch_reduction: Reduction operation to apply for the loss across the
             batch, can be one of ["mean", "sum"] or None.
         point_reduction: Reduction operation to apply for the loss across the
-            points, can be one of ["mean", "sum"].
+            points, can be one of ["mean", "sum"] or None.
         norm: int indicates the norm used for the distance. Supports 1 for L1 and 2 for L2.
         single_directional: If False (default), loss comes from both the distance between
             each point in x and its nearest neighbor in y and each point in y and its nearest
@@ -206,11 +210,16 @@ def chamfer_distance(
         2-element tuple containing

         - **loss**: Tensor giving the reduced distance between the pointclouds
-          in x and the pointclouds in y.
+          in x and the pointclouds in y. If point_reduction is None, a 2-element
+          tuple of Tensors containing forward and backward loss terms shaped (N, P1)
+          and (N, P2) (if single_directional is False) or a Tensor containing loss
+          terms shaped (N, P1) (if single_directional is True) is returned.
         - **loss_normals**: Tensor giving the reduced cosine distance of normals
           between pointclouds in x and pointclouds in y. Returns None if
-          x_normals and y_normals are None.
+          x_normals and y_normals are None. If point_reduction is None, a 2-element
+          tuple of Tensors containing forward and backward loss terms shaped (N, P1)
+          and (N, P2) (if single_directional is False) or a Tensor containing loss
+          terms shaped (N, P1) (if single_directional is True) is returned.
     """

     _validate_chamfer_reduction_inputs(batch_reduction, point_reduction)
@@ -248,7 +257,12 @@
         norm,
         abs_cosine,
     )
+    if point_reduction is not None:
+        return (
+            cham_x + cham_y,
+            (cham_norm_x + cham_norm_y) if cham_norm_x is not None else None,
+        )
     return (
-        cham_x + cham_y,
-        (cham_norm_x + cham_norm_y) if cham_norm_x is not None else None,
+        (cham_x, cham_y),
+        (cham_norm_x, cham_norm_y) if cham_norm_x is not None else None,
     )
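For reference, a small sketch of the resulting return contract and validation behavior (random dense inputs for illustration):

```python
import torch
from pytorch3d.loss import chamfer_distance

x, y = torch.randn(4, 128, 3), torch.randn(4, 256, 3)

# Bi-directional: loss is a tuple of per-point terms.
(cham_x, cham_y), _ = chamfer_distance(
    x, y, batch_reduction=None, point_reduction=None
)
assert cham_x.shape == (4, 128) and cham_y.shape == (4, 256)

# Single-directional: loss is a single (N, P1) Tensor.
cham_x, _ = chamfer_distance(
    x, y, batch_reduction=None, point_reduction=None, single_directional=True
)
assert cham_x.shape == (4, 128)

# Any non-None batch_reduction is rejected together with point_reduction=None.
try:
    chamfer_distance(x, y, batch_reduction="mean", point_reduction=None)
except ValueError as err:
    print(err)  # Batch reduction must be None if point_reduction is None
```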

tests/test_chamfer.py

@@ -421,9 +421,9 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
             ("mean", "mean"),
             ("sum", None),
             ("mean", None),
+            (None, None),
         ]
-        for (point_reduction, batch_reduction) in reductions:
+        for point_reduction, batch_reduction in reductions:
             # Reinitialize all the tensors so that the
             # backward pass can be computed.
             points_normals = TestChamfer.init_pointclouds(
@@ -450,24 +450,52 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
                 batch_reduction=batch_reduction,
             )

-            self.assertClose(cham_cloud, cham_tensor)
-            self.assertClose(norm_cloud, norm_tensor)
-            self._check_gradients(
-                cham_tensor,
-                norm_tensor,
-                cham_cloud,
-                norm_cloud,
-                points_normals.cloud1.points_list(),
-                points_normals.p1,
-                points_normals.cloud2.points_list(),
-                points_normals.p2,
-                points_normals.cloud1.normals_list(),
-                points_normals.n1,
-                points_normals.cloud2.normals_list(),
-                points_normals.n2,
-                points_normals.p1_lengths,
-                points_normals.p2_lengths,
-            )
+            if point_reduction is None:
+                cham_tensor_bidirectional = torch.hstack(
+                    [cham_tensor[0], cham_tensor[1]]
+                )
+                norm_tensor_bidirectional = torch.hstack(
+                    [norm_tensor[0], norm_tensor[1]]
+                )
+                cham_cloud_bidirectional = torch.hstack([cham_cloud[0], cham_cloud[1]])
+                norm_cloud_bidirectional = torch.hstack([norm_cloud[0], norm_cloud[1]])
+                self.assertClose(cham_cloud_bidirectional, cham_tensor_bidirectional)
+                self.assertClose(norm_cloud_bidirectional, norm_tensor_bidirectional)
+                self._check_gradients(
+                    cham_tensor_bidirectional,
+                    norm_tensor_bidirectional,
+                    cham_cloud_bidirectional,
+                    norm_cloud_bidirectional,
+                    points_normals.cloud1.points_list(),
+                    points_normals.p1,
+                    points_normals.cloud2.points_list(),
+                    points_normals.p2,
+                    points_normals.cloud1.normals_list(),
+                    points_normals.n1,
+                    points_normals.cloud2.normals_list(),
+                    points_normals.n2,
+                    points_normals.p1_lengths,
+                    points_normals.p2_lengths,
+                )
+            else:
+                self.assertClose(cham_cloud, cham_tensor)
+                self.assertClose(norm_cloud, norm_tensor)
+                self._check_gradients(
+                    cham_tensor,
+                    norm_tensor,
+                    cham_cloud,
+                    norm_cloud,
+                    points_normals.cloud1.points_list(),
+                    points_normals.p1,
+                    points_normals.cloud2.points_list(),
+                    points_normals.p2,
+                    points_normals.cloud1.normals_list(),
+                    points_normals.n1,
+                    points_normals.cloud2.normals_list(),
+                    points_normals.n2,
+                    points_normals.p1_lengths,
+                    points_normals.p2_lengths,
+                )

     def test_chamfer_pointcloud_object_nonormals(self):
         N = 5
@@ -481,9 +509,9 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
             ("mean", "mean"),
             ("sum", None),
             ("mean", None),
+            (None, None),
         ]
-        for (point_reduction, batch_reduction) in reductions:
+        for point_reduction, batch_reduction in reductions:
             # Reinitialize all the tensors so that the
             # backward pass can be computed.
             points_normals = TestChamfer.init_pointclouds(
@@ -508,19 +536,38 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
                 batch_reduction=batch_reduction,
             )

-            self.assertClose(cham_cloud, cham_tensor)
-            self._check_gradients(
-                cham_tensor,
-                None,
-                cham_cloud,
-                None,
-                points_normals.cloud1.points_list(),
-                points_normals.p1,
-                points_normals.cloud2.points_list(),
-                points_normals.p2,
-                lengths1=points_normals.p1_lengths,
-                lengths2=points_normals.p2_lengths,
-            )
+            if point_reduction is None:
+                cham_tensor_bidirectional = torch.hstack(
+                    [cham_tensor[0], cham_tensor[1]]
+                )
+                cham_cloud_bidirectional = torch.hstack([cham_cloud[0], cham_cloud[1]])
+                self.assertClose(cham_cloud_bidirectional, cham_tensor_bidirectional)
+                self._check_gradients(
+                    cham_tensor_bidirectional,
+                    None,
+                    cham_cloud_bidirectional,
+                    None,
+                    points_normals.cloud1.points_list(),
+                    points_normals.p1,
+                    points_normals.cloud2.points_list(),
+                    points_normals.p2,
+                    lengths1=points_normals.p1_lengths,
+                    lengths2=points_normals.p2_lengths,
+                )
+            else:
+                self.assertClose(cham_cloud, cham_tensor)
+                self._check_gradients(
+                    cham_tensor,
+                    None,
+                    cham_cloud,
+                    None,
+                    points_normals.cloud1.points_list(),
+                    points_normals.p1,
+                    points_normals.cloud2.points_list(),
+                    points_normals.p2,
+                    lengths1=points_normals.p1_lengths,
+                    lengths2=points_normals.p2_lengths,
+                )

     def test_chamfer_point_reduction_mean(self):
         """
@@ -707,6 +754,99 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
             loss, loss_norm, pred_loss_sum, pred_loss_norm_sum, p1, p11, p2, p22
         )

+    def test_chamfer_point_reduction_none(self):
+        """
+        Compare output of vectorized chamfer loss with naive implementation
+        for point_reduction = None and batch_reduction = None.
+        """
+        N, max_P1, max_P2 = 7, 10, 18
+        device = get_random_cuda_device()
+        points_normals = TestChamfer.init_pointclouds(N, max_P1, max_P2, device)
+        p1 = points_normals.p1
+        p2 = points_normals.p2
+        p1_normals = points_normals.n1
+        p2_normals = points_normals.n2
+        p11 = p1.detach().clone()
+        p22 = p2.detach().clone()
+        p11.requires_grad = True
+        p22.requires_grad = True
+
+        pred_loss, pred_loss_norm = TestChamfer.chamfer_distance_naive(
+            p1, p2, x_normals=p1_normals, y_normals=p2_normals
+        )
+
+        # point_reduction = None
+        loss, loss_norm = chamfer_distance(
+            p11,
+            p22,
+            x_normals=p1_normals,
+            y_normals=p2_normals,
+            batch_reduction=None,
+            point_reduction=None,
+        )
+        loss_bidirectional = torch.hstack([loss[0], loss[1]])
+        pred_loss_bidirectional = torch.hstack([pred_loss[0], pred_loss[1]])
+        loss_norm_bidirectional = torch.hstack([loss_norm[0], loss_norm[1]])
+        pred_loss_norm_bidirectional = torch.hstack(
+            [pred_loss_norm[0], pred_loss_norm[1]]
+        )
+
+        self.assertClose(loss_bidirectional, pred_loss_bidirectional)
+        self.assertClose(loss_norm_bidirectional, pred_loss_norm_bidirectional)
+
+        # Check gradients
+        self._check_gradients(
+            loss_bidirectional,
+            loss_norm_bidirectional,
+            pred_loss_bidirectional,
+            pred_loss_norm_bidirectional,
+            p1,
+            p11,
+            p2,
+            p22,
+        )
+
+    def test_single_direction_chamfer_point_reduction_none(self):
+        """
+        Compare output of vectorized single directional chamfer loss with naive
+        implementation for point_reduction = None and batch_reduction = None.
+        """
+        N, max_P1, max_P2 = 7, 10, 18
+        device = get_random_cuda_device()
+        points_normals = TestChamfer.init_pointclouds(N, max_P1, max_P2, device)
+        p1 = points_normals.p1
+        p2 = points_normals.p2
+        p1_normals = points_normals.n1
+        p2_normals = points_normals.n2
+        p11 = p1.detach().clone()
+        p22 = p2.detach().clone()
+        p11.requires_grad = True
+        p22.requires_grad = True
+
+        pred_loss, pred_loss_norm = TestChamfer.chamfer_distance_naive(
+            p1, p2, x_normals=p1_normals, y_normals=p2_normals
+        )
+
+        # point_reduction = None
+        loss, loss_norm = chamfer_distance(
+            p11,
+            p22,
+            x_normals=p1_normals,
+            y_normals=p2_normals,
+            batch_reduction=None,
+            point_reduction=None,
+            single_directional=True,
+        )
+        self.assertClose(loss, pred_loss[0])
+        self.assertClose(loss_norm, pred_loss_norm[0])
+
+        # Check gradients
+        self._check_gradients(
+            loss, loss_norm, pred_loss[0], pred_loss_norm[0], p1, p11, p2, p22
+        )
+
     def _check_gradients(
         self,
         loss,
@@ -880,9 +1020,9 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
         with self.assertRaisesRegex(ValueError, "batch_reduction must be one of"):
             chamfer_distance(p1, p2, weights=weights, batch_reduction="max")

-        # Error when point_reduction is not in ["mean", "sum"].
+        # Error when point_reduction is not in ["mean", "sum"] or None.
         with self.assertRaisesRegex(ValueError, "point_reduction must be one of"):
-            chamfer_distance(p1, p2, weights=weights, point_reduction=None)
+            chamfer_distance(p1, p2, weights=weights, point_reduction="max")

     def test_incorrect_weights(self):
         N, P1, P2 = 16, 64, 128