Add "max" point reduction for chamfer distance
Summary:
* Adds a "max" option for the point_reduction input to the chamfer_distance function.
* When combining the x and y directions, maxes the losses instead of summing them when point_reduction="max".
* Moves batch reduction to happen after the directions are combined.
* Adds test_chamfer_point_reduction_max and test_single_directional_chamfer_point_reduction_max tests.

Fixes https://github.com/facebookresearch/pytorch3d/issues/1838

Reviewed By: bottler

Differential Revision: D60614661

fbshipit-source-id: 7879816acfda03e945bada951b931d2c522756eb
parent 7edaee71a9
commit 44702fdb4b
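For orientation, a minimal usage sketch of the new option (the tensor shapes and values are illustrative, not taken from the commit):

import torch
from pytorch3d.loss import chamfer_distance

# Two batches of N=4 point clouds with 100 and 120 points each: (N, P, 3).
x = torch.rand(4, 100, 3)
y = torch.rand(4, 120, 3)

# Default behavior: per-point distances are mean-reduced within each cloud,
# the two directions are summed, then the batch is mean-reduced.
loss_mean, _ = chamfer_distance(x, y)

# New in this commit: "max" point reduction. Per-point distances are
# max-reduced within each cloud, and the two directions are combined with an
# elementwise maximum instead of a sum. Normals must be None in this mode.
loss_max, _ = chamfer_distance(x, y, point_reduction="max", batch_reduction=None)
print(loss_max.shape)  # torch.Size([4]) -- one value per batch element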
pytorch3d/loss/chamfer.py

@@ -27,8 +27,10 @@ def _validate_chamfer_reduction_inputs(
     """
     if batch_reduction is not None and batch_reduction not in ["mean", "sum"]:
         raise ValueError('batch_reduction must be one of ["mean", "sum"] or None')
-    if point_reduction is not None and point_reduction not in ["mean", "sum"]:
-        raise ValueError('point_reduction must be one of ["mean", "sum"] or None')
+    if point_reduction is not None and point_reduction not in ["mean", "sum", "max"]:
+        raise ValueError(
+            'point_reduction must be one of ["mean", "sum", "max"] or None'
+        )
     if point_reduction is None and batch_reduction is not None:
         raise ValueError("Batch reduction must be None if point_reduction is None")

@@ -80,7 +82,6 @@ def _chamfer_distance_single_direction(
     x_normals,
     y_normals,
     weights,
-    batch_reduction: Union[str, None],
    point_reduction: Union[str, None],
    norm: int,
    abs_cosine: bool,
@@ -103,11 +104,6 @@ def _chamfer_distance_single_direction(
             raise ValueError("weights cannot be negative.")
         if weights.sum() == 0.0:
             weights = weights.view(N, 1)
-            if batch_reduction in ["mean", "sum"]:
-                return (
-                    (x.sum((1, 2)) * weights).sum() * 0.0,
-                    (x.sum((1, 2)) * weights).sum() * 0.0,
-                )
             return ((x.sum((1, 2)) * weights) * 0.0, (x.sum((1, 2)) * weights) * 0.0)

     cham_norm_x = x.new_zeros(())
@@ -135,7 +131,10 @@ def _chamfer_distance_single_direction(
         if weights is not None:
             cham_norm_x *= weights.view(N, 1)

-    if point_reduction is not None:
+    if point_reduction == "max":
+        assert not return_normals
+        cham_x = cham_x.max(1).values  # (N,)
+    elif point_reduction is not None:
         # Apply point reduction
         cham_x = cham_x.sum(1)  # (N,)
         if return_normals:
@@ -146,22 +145,34 @@ def _chamfer_distance_single_direction(
             if return_normals:
                 cham_norm_x /= x_lengths_clamped

-        if batch_reduction is not None:
-            # batch_reduction == "sum"
-            cham_x = cham_x.sum()
-            if return_normals:
-                cham_norm_x = cham_norm_x.sum()
-            if batch_reduction == "mean":
-                div = weights.sum() if weights is not None else max(N, 1)
-                cham_x /= div
-                if return_normals:
-                    cham_norm_x /= div
-
     cham_dist = cham_x
     cham_normals = cham_norm_x if return_normals else None
     return cham_dist, cham_normals


+def _apply_batch_reduction(
+    cham_x, cham_norm_x, weights, batch_reduction: Union[str, None]
+):
+    if batch_reduction is None:
+        return (cham_x, cham_norm_x)
+    # batch_reduction == "sum"
+    N = cham_x.shape[0]
+    cham_x = cham_x.sum()
+    if cham_norm_x is not None:
+        cham_norm_x = cham_norm_x.sum()
+    if batch_reduction == "mean":
+        if weights is None:
+            div = max(N, 1)
+        elif weights.sum() == 0.0:
+            div = 1
+        else:
+            div = weights.sum()
+        cham_x /= div
+        if cham_norm_x is not None:
+            cham_norm_x /= div
+    return (cham_x, cham_norm_x)
+
+
 def chamfer_distance(
     x,
     y,
@@ -197,7 +208,8 @@ def chamfer_distance(
         batch_reduction: Reduction operation to apply for the loss across the
             batch, can be one of ["mean", "sum"] or None.
         point_reduction: Reduction operation to apply for the loss across the
-            points, can be one of ["mean", "sum"] or None.
+            points, can be one of ["mean", "sum", "max"] or None. Using "max" leads to the
+            Hausdorff distance.
         norm: int indicates the norm used for the distance. Supports 1 for L1 and 2 for L2.
         single_directional: If False (default), loss comes from both the distance between
             each point in x and its nearest neighbor in y and each point in y and its nearest
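The Hausdorff claim in the docstring can be sanity-checked against a naive formulation. The sketch below is not part of the diff; naive_hausdorff is a hypothetical helper, and it uses squared L2 distances to match chamfer_distance's default norm=2:

import torch

def naive_hausdorff(x, y):
    # x: (P1, D), y: (P2, D) -- a single unweighted pair of point clouds.
    d = torch.cdist(x, y) ** 2        # pairwise squared L2, matching norm=2
    h_xy = d.min(dim=1).values.max()  # max over x of the nearest distance in y
    h_yx = d.min(dim=0).values.max()  # max over y of the nearest distance in x
    return torch.maximum(h_xy, h_yx)  # symmetric Hausdorff distance

This should agree with chamfer_distance(x[None], y[None], point_reduction="max", batch_reduction=None) up to numerical precision.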
@@ -227,6 +239,10 @@ def chamfer_distance(

     if not ((norm == 1) or (norm == 2)):
         raise ValueError("Support for 1 or 2 norm.")
+
+    if point_reduction == "max" and (x_normals is not None or y_normals is not None):
+        raise ValueError('Normals must be None if point_reduction is "max"')
+
     x, x_lengths, x_normals = _handle_pointcloud_input(x, x_lengths, x_normals)
     y, y_lengths, y_normals = _handle_pointcloud_input(y, y_lengths, y_normals)

@@ -238,13 +254,13 @@ def chamfer_distance(
         x_normals,
         y_normals,
         weights,
-        batch_reduction,
         point_reduction,
         norm,
         abs_cosine,
     )
     if single_directional:
-        return cham_x, cham_norm_x
+        loss = cham_x
+        loss_normals = cham_norm_x
     else:
         cham_y, cham_norm_y = _chamfer_distance_single_direction(
             y,
@@ -254,17 +270,23 @@ def chamfer_distance(
             y_normals,
             x_normals,
             weights,
-            batch_reduction,
             point_reduction,
             norm,
             abs_cosine,
         )
-        if point_reduction is not None:
-            return (
-                cham_x + cham_y,
-                (cham_norm_x + cham_norm_y) if cham_norm_x is not None else None,
-            )
-        return (
-            (cham_x, cham_y),
-            (cham_norm_x, cham_norm_y) if cham_norm_x is not None else None,
-        )
+        if point_reduction == "max":
+            loss = torch.maximum(cham_x, cham_y)
+            loss_normals = None
+        elif point_reduction is not None:
+            loss = cham_x + cham_y
+            if cham_norm_x is not None:
+                loss_normals = cham_norm_x + cham_norm_y
+            else:
+                loss_normals = None
+        else:
+            loss = (cham_x, cham_y)
+            if cham_norm_x is not None:
+                loss_normals = (cham_norm_x, cham_norm_y)
+            else:
+                loss_normals = None
+    return _apply_batch_reduction(loss, loss_normals, weights, batch_reduction)
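A small numeric illustration (not from the commit) of why the refactor moves batch reduction after the two directions are combined: the elementwise maximum used for point_reduction="max" does not commute with summing over the batch.

import torch

cham_x = torch.tensor([1.0, 4.0])  # forward loss per batch element
cham_y = torch.tensor([3.0, 2.0])  # backward loss per batch element

# Order enforced by this refactor: combine directions, then reduce the batch.
combined_then_summed = torch.maximum(cham_x, cham_y).sum()  # max -> [3, 4], sum -> 7

# Reducing each direction over the batch first gives a different answer.
summed_then_combined = torch.maximum(cham_x.sum(), cham_y.sum())  # max(5, 5) = 5

This is also why _chamfer_distance_single_direction no longer takes batch_reduction: _apply_batch_reduction now runs exactly once, on the combined loss.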
tests/test_chamfer.py

@@ -847,6 +847,85 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
             loss, loss_norm, pred_loss[0], pred_loss_norm[0], p1, p11, p2, p22
         )

+    def test_chamfer_point_reduction_max(self):
+        """
+        Compare output of vectorized chamfer loss with naive implementation
+        for point_reduction = "max" and batch_reduction = None.
+        """
+        N, P1, P2 = 7, 10, 18
+        device = get_random_cuda_device()
+        points_normals = TestChamfer.init_pointclouds(N, P1, P2, device)
+        p1 = points_normals.p1
+        p2 = points_normals.p2
+        weights = points_normals.weights
+        p11 = p1.detach().clone()
+        p22 = p2.detach().clone()
+        p11.requires_grad = True
+        p22.requires_grad = True
+
+        pred_loss, unused_pred_loss_norm = TestChamfer.chamfer_distance_naive(
+            p1, p2, x_normals=None, y_normals=None
+        )
+
+        loss, loss_norm = chamfer_distance(
+            p11,
+            p22,
+            x_normals=None,
+            y_normals=None,
+            weights=weights,
+            batch_reduction=None,
+            point_reduction="max",
+        )
+        pred_loss_max = torch.maximum(
+            pred_loss[0].max(1).values, pred_loss[1].max(1).values
+        )
+        pred_loss_max *= weights
+        self.assertClose(loss, pred_loss_max)
+
+        self.assertIsNone(loss_norm)
+
+        # Check gradients
+        self._check_gradients(loss, loss_norm, pred_loss_max, None, p1, p11, p2, p22)
+
+    def test_single_directional_chamfer_point_reduction_max(self):
+        """
+        Compare output of vectorized single directional chamfer loss with naive implementation
+        for point_reduction = "max" and batch_reduction = None.
+        """
+        N, P1, P2 = 7, 10, 18
+        device = get_random_cuda_device()
+        points_normals = TestChamfer.init_pointclouds(N, P1, P2, device)
+        p1 = points_normals.p1
+        p2 = points_normals.p2
+        weights = points_normals.weights
+        p11 = p1.detach().clone()
+        p22 = p2.detach().clone()
+        p11.requires_grad = True
+        p22.requires_grad = True
+
+        pred_loss, unused_pred_loss_norm = TestChamfer.chamfer_distance_naive(
+            p1, p2, x_normals=None, y_normals=None
+        )
+
+        loss, loss_norm = chamfer_distance(
+            p11,
+            p22,
+            x_normals=None,
+            y_normals=None,
+            weights=weights,
+            batch_reduction=None,
+            point_reduction="max",
+            single_directional=True,
+        )
+        pred_loss_max = pred_loss[0].max(1).values
+        pred_loss_max *= weights
+        self.assertClose(loss, pred_loss_max)
+
+        self.assertIsNone(loss_norm)
+
+        # Check gradients
+        self._check_gradients(loss, loss_norm, pred_loss_max, None, p1, p11, p2, p22)
+
     def _check_gradients(
         self,
         loss,
@@ -1020,9 +1099,9 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
         with self.assertRaisesRegex(ValueError, "batch_reduction must be one of"):
             chamfer_distance(p1, p2, weights=weights, batch_reduction="max")

-        # Error when point_reduction is not in ["mean", "sum"] or None.
+        # Error when point_reduction is not in ["mean", "sum", "max"] or None.
         with self.assertRaisesRegex(ValueError, "point_reduction must be one of"):
-            chamfer_distance(p1, p2, weights=weights, point_reduction="max")
+            chamfer_distance(p1, p2, weights=weights, point_reduction="min")

     def test_incorrect_weights(self):
         N, P1, P2 = 16, 64, 128