mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-20 14:20:38 +08:00
Add "max" point reduction for chamfer distance
Summary: * Adds a "max" option for the point_reduction input to the chamfer_distance function. * When combining the x and y directions, maxes the losses instead of summing them when point_reduction="max". * Moves batch reduction to happen after the directions are combined. * Adds test_chamfer_point_reduction_max and test_single_directional_chamfer_point_reduction_max tests. Fixes https://github.com/facebookresearch/pytorch3d/issues/1838 Reviewed By: bottler Differential Revision: D60614661 fbshipit-source-id: 7879816acfda03e945bada951b931d2c522756eb
This commit is contained in:
committed by
Facebook GitHub Bot
parent
7edaee71a9
commit
44702fdb4b
@@ -847,6 +847,85 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
|
||||
loss, loss_norm, pred_loss[0], pred_loss_norm[0], p1, p11, p2, p22
|
||||
)
|
||||
|
||||
def test_chamfer_point_reduction_max(self):
|
||||
"""
|
||||
Compare output of vectorized chamfer loss with naive implementation
|
||||
for point_reduction = "max" and batch_reduction = None.
|
||||
"""
|
||||
N, P1, P2 = 7, 10, 18
|
||||
device = get_random_cuda_device()
|
||||
points_normals = TestChamfer.init_pointclouds(N, P1, P2, device)
|
||||
p1 = points_normals.p1
|
||||
p2 = points_normals.p2
|
||||
weights = points_normals.weights
|
||||
p11 = p1.detach().clone()
|
||||
p22 = p2.detach().clone()
|
||||
p11.requires_grad = True
|
||||
p22.requires_grad = True
|
||||
|
||||
pred_loss, unused_pred_loss_norm = TestChamfer.chamfer_distance_naive(
|
||||
p1, p2, x_normals=None, y_normals=None
|
||||
)
|
||||
|
||||
loss, loss_norm = chamfer_distance(
|
||||
p11,
|
||||
p22,
|
||||
x_normals=None,
|
||||
y_normals=None,
|
||||
weights=weights,
|
||||
batch_reduction=None,
|
||||
point_reduction="max",
|
||||
)
|
||||
pred_loss_max = torch.maximum(
|
||||
pred_loss[0].max(1).values, pred_loss[1].max(1).values
|
||||
)
|
||||
pred_loss_max *= weights
|
||||
self.assertClose(loss, pred_loss_max)
|
||||
|
||||
self.assertIsNone(loss_norm)
|
||||
|
||||
# Check gradients
|
||||
self._check_gradients(loss, loss_norm, pred_loss_max, None, p1, p11, p2, p22)
|
||||
|
||||
def test_single_directional_chamfer_point_reduction_max(self):
|
||||
"""
|
||||
Compare output of vectorized single directional chamfer loss with naive implementation
|
||||
for point_reduction = "max" and batch_reduction = None.
|
||||
"""
|
||||
N, P1, P2 = 7, 10, 18
|
||||
device = get_random_cuda_device()
|
||||
points_normals = TestChamfer.init_pointclouds(N, P1, P2, device)
|
||||
p1 = points_normals.p1
|
||||
p2 = points_normals.p2
|
||||
weights = points_normals.weights
|
||||
p11 = p1.detach().clone()
|
||||
p22 = p2.detach().clone()
|
||||
p11.requires_grad = True
|
||||
p22.requires_grad = True
|
||||
|
||||
pred_loss, unused_pred_loss_norm = TestChamfer.chamfer_distance_naive(
|
||||
p1, p2, x_normals=None, y_normals=None
|
||||
)
|
||||
|
||||
loss, loss_norm = chamfer_distance(
|
||||
p11,
|
||||
p22,
|
||||
x_normals=None,
|
||||
y_normals=None,
|
||||
weights=weights,
|
||||
batch_reduction=None,
|
||||
point_reduction="max",
|
||||
single_directional=True,
|
||||
)
|
||||
pred_loss_max = pred_loss[0].max(1).values
|
||||
pred_loss_max *= weights
|
||||
self.assertClose(loss, pred_loss_max)
|
||||
|
||||
self.assertIsNone(loss_norm)
|
||||
|
||||
# Check gradients
|
||||
self._check_gradients(loss, loss_norm, pred_loss_max, None, p1, p11, p2, p22)
|
||||
|
||||
def _check_gradients(
|
||||
self,
|
||||
loss,
|
||||
@@ -1020,9 +1099,9 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
|
||||
with self.assertRaisesRegex(ValueError, "batch_reduction must be one of"):
|
||||
chamfer_distance(p1, p2, weights=weights, batch_reduction="max")
|
||||
|
||||
# Error when point_reduction is not in ["mean", "sum"] or None.
|
||||
# Error when point_reduction is not in ["mean", "sum", "max"] or None.
|
||||
with self.assertRaisesRegex(ValueError, "point_reduction must be one of"):
|
||||
chamfer_distance(p1, p2, weights=weights, point_reduction="max")
|
||||
chamfer_distance(p1, p2, weights=weights, point_reduction="min")
|
||||
|
||||
def test_incorrect_weights(self):
|
||||
N, P1, P2 = 16, 64, 128
|
||||
|
||||
Reference in New Issue
Block a user