From c6519f29f0512e209906f8265e0d049085670304 Mon Sep 17 00:00:00 2001 From: Jeremy Reizenstein Date: Thu, 26 May 2022 14:56:22 -0700 Subject: [PATCH] chamfer for empty pointclouds #1174 Summary: Fix divide by zero for empty pointcloud in chamfer. Also for empty batches. In process, needed to regularize num_points_per_cloud for empty batches. Reviewed By: kjchalup Differential Revision: D36311330 fbshipit-source-id: 3378ab738bee77ecc286f2110a5c8dc445960340 --- pytorch3d/implicitron/models/model_dbir.py | 2 +- pytorch3d/loss/chamfer.py | 12 +++++++----- pytorch3d/loss/point_mesh_distance.py | 2 ++ pytorch3d/structures/pointclouds.py | 5 +++-- tests/test_chamfer.py | 11 +++++++++++ 5 files changed, 24 insertions(+), 8 deletions(-) diff --git a/pytorch3d/implicitron/models/model_dbir.py b/pytorch3d/implicitron/models/model_dbir.py index 780d5995..47eab20e 100644 --- a/pytorch3d/implicitron/models/model_dbir.py +++ b/pytorch3d/implicitron/models/model_dbir.py @@ -99,7 +99,7 @@ class ModelDBIR(ImplicitronModelBase, torch.nn.Module): mask_fg[is_known_idx], ) - pcl_size = int(point_cloud.num_points_per_cloud()) + pcl_size = point_cloud.num_points_per_cloud().item() if (self.max_points > 0) and (pcl_size > self.max_points): prm = torch.randperm(pcl_size)[: self.max_points] point_cloud = Pointclouds( diff --git a/pytorch3d/loss/chamfer.py b/pytorch3d/loss/chamfer.py index 4d65989e..dc2158fa 100644 --- a/pytorch3d/loss/chamfer.py +++ b/pytorch3d/loss/chamfer.py @@ -197,11 +197,13 @@ def chamfer_distance( cham_norm_x = cham_norm_x.sum(1) # (N,) cham_norm_y = cham_norm_y.sum(1) # (N,) if point_reduction == "mean": - cham_x /= x_lengths - cham_y /= y_lengths + x_lengths_clamped = x_lengths.clamp(min=1) + y_lengths_clamped = y_lengths.clamp(min=1) + cham_x /= x_lengths_clamped + cham_y /= y_lengths_clamped if return_normals: - cham_norm_x /= x_lengths - cham_norm_y /= y_lengths + cham_norm_x /= x_lengths_clamped + cham_norm_y /= y_lengths_clamped if batch_reduction is not None: # batch_reduction == "sum" @@ -211,7 +213,7 @@ def chamfer_distance( cham_norm_x = cham_norm_x.sum() cham_norm_y = cham_norm_y.sum() if batch_reduction == "mean": - div = weights.sum() if weights is not None else N + div = weights.sum() if weights is not None else max(N, 1) cham_x /= div cham_y /= div if return_normals: diff --git a/pytorch3d/loss/point_mesh_distance.py b/pytorch3d/loss/point_mesh_distance.py index e901f9da..40497c94 100644 --- a/pytorch3d/loss/point_mesh_distance.py +++ b/pytorch3d/loss/point_mesh_distance.py @@ -303,6 +303,7 @@ def point_mesh_edge_distance(meshes: Meshes, pcls: Pointclouds): # weight each example by the inverse of number of points in the example point_to_cloud_idx = pcls.packed_to_cloud_idx() # (sum(P_i), ) num_points_per_cloud = pcls.num_points_per_cloud() # (N,) + # pyre-ignore[16]: `torch.Tensor` has no attribute `gather` weights_p = num_points_per_cloud.gather(0, point_to_cloud_idx) weights_p = 1.0 / weights_p.float() point_to_edge = point_to_edge * weights_p @@ -377,6 +378,7 @@ def point_mesh_face_distance( # weight each example by the inverse of number of points in the example point_to_cloud_idx = pcls.packed_to_cloud_idx() # (sum(P_i),) num_points_per_cloud = pcls.num_points_per_cloud() # (N,) + # pyre-ignore[16]: `torch.Tensor` has no attribute `gather` weights_p = num_points_per_cloud.gather(0, point_to_cloud_idx) weights_p = 1.0 / weights_p.float() point_to_face = point_to_face * weights_p diff --git a/pytorch3d/structures/pointclouds.py b/pytorch3d/structures/pointclouds.py index bc990c11..3dd2d126 100644 --- a/pytorch3d/structures/pointclouds.py +++ b/pytorch3d/structures/pointclouds.py @@ -185,7 +185,6 @@ class Pointclouds: self._points_list = points self._N = len(self._points_list) self.valid = torch.zeros((self._N,), dtype=torch.bool, device=self.device) - self._num_points_per_cloud = [] if self._N > 0: self.device = self._points_list[0].device @@ -208,6 +207,8 @@ class Pointclouds: if len(num_points_per_cloud.unique()) == 1: self.equisized = True self._num_points_per_cloud = num_points_per_cloud + else: + self._num_points_per_cloud = torch.tensor([], dtype=torch.int64) elif torch.is_tensor(points): if points.dim() != 3 or points.shape[2] != 3: @@ -525,7 +526,7 @@ class Pointclouds: self._compute_packed() return self._cloud_to_packed_first_idx - def num_points_per_cloud(self): + def num_points_per_cloud(self) -> torch.Tensor: """ Return a 1D tensor x with length equal to the number of clouds giving the number of points in each cloud. diff --git a/tests/test_chamfer.py b/tests/test_chamfer.py index 8113d199..964a9fab 100644 --- a/tests/test_chamfer.py +++ b/tests/test_chamfer.py @@ -778,6 +778,17 @@ class TestChamfer(TestCaseMixin, unittest.TestCase): with self.assertRaisesRegex(ValueError, "Support for 1 or 2 norm."): chamfer_distance(p1, p2, norm=3) + def test_empty_clouds(self): + # Check that point_reduction doesn't divide by zero + points1 = Pointclouds(points=[torch.zeros(0, 3), torch.zeros(10, 3)]) + points2 = Pointclouds(points=torch.ones(2, 40, 3)) + loss, _ = chamfer_distance(points1, points2, batch_reduction=None) + self.assertClose(loss, torch.tensor([0.0, 6.0])) + + # Check that batch_reduction doesn't divide by zero + loss2, _ = chamfer_distance(Pointclouds([]), Pointclouds([])) + self.assertClose(loss2, torch.tensor(0.0)) + @staticmethod def chamfer_with_init( batch_size: int,