mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 11:52:50 +08:00
chamfer for empty pointclouds #1174
Summary: Fix divide by zero for empty pointcloud in chamfer. Also for empty batches. In process, needed to regularize num_points_per_cloud for empty batches. Reviewed By: kjchalup Differential Revision: D36311330 fbshipit-source-id: 3378ab738bee77ecc286f2110a5c8dc445960340
This commit is contained in:
parent
a42a89a5ba
commit
c6519f29f0
@ -99,7 +99,7 @@ class ModelDBIR(ImplicitronModelBase, torch.nn.Module):
|
|||||||
mask_fg[is_known_idx],
|
mask_fg[is_known_idx],
|
||||||
)
|
)
|
||||||
|
|
||||||
pcl_size = int(point_cloud.num_points_per_cloud())
|
pcl_size = point_cloud.num_points_per_cloud().item()
|
||||||
if (self.max_points > 0) and (pcl_size > self.max_points):
|
if (self.max_points > 0) and (pcl_size > self.max_points):
|
||||||
prm = torch.randperm(pcl_size)[: self.max_points]
|
prm = torch.randperm(pcl_size)[: self.max_points]
|
||||||
point_cloud = Pointclouds(
|
point_cloud = Pointclouds(
|
||||||
|
@ -197,11 +197,13 @@ def chamfer_distance(
|
|||||||
cham_norm_x = cham_norm_x.sum(1) # (N,)
|
cham_norm_x = cham_norm_x.sum(1) # (N,)
|
||||||
cham_norm_y = cham_norm_y.sum(1) # (N,)
|
cham_norm_y = cham_norm_y.sum(1) # (N,)
|
||||||
if point_reduction == "mean":
|
if point_reduction == "mean":
|
||||||
cham_x /= x_lengths
|
x_lengths_clamped = x_lengths.clamp(min=1)
|
||||||
cham_y /= y_lengths
|
y_lengths_clamped = y_lengths.clamp(min=1)
|
||||||
|
cham_x /= x_lengths_clamped
|
||||||
|
cham_y /= y_lengths_clamped
|
||||||
if return_normals:
|
if return_normals:
|
||||||
cham_norm_x /= x_lengths
|
cham_norm_x /= x_lengths_clamped
|
||||||
cham_norm_y /= y_lengths
|
cham_norm_y /= y_lengths_clamped
|
||||||
|
|
||||||
if batch_reduction is not None:
|
if batch_reduction is not None:
|
||||||
# batch_reduction == "sum"
|
# batch_reduction == "sum"
|
||||||
@ -211,7 +213,7 @@ def chamfer_distance(
|
|||||||
cham_norm_x = cham_norm_x.sum()
|
cham_norm_x = cham_norm_x.sum()
|
||||||
cham_norm_y = cham_norm_y.sum()
|
cham_norm_y = cham_norm_y.sum()
|
||||||
if batch_reduction == "mean":
|
if batch_reduction == "mean":
|
||||||
div = weights.sum() if weights is not None else N
|
div = weights.sum() if weights is not None else max(N, 1)
|
||||||
cham_x /= div
|
cham_x /= div
|
||||||
cham_y /= div
|
cham_y /= div
|
||||||
if return_normals:
|
if return_normals:
|
||||||
|
@ -303,6 +303,7 @@ def point_mesh_edge_distance(meshes: Meshes, pcls: Pointclouds):
|
|||||||
# weight each example by the inverse of number of points in the example
|
# weight each example by the inverse of number of points in the example
|
||||||
point_to_cloud_idx = pcls.packed_to_cloud_idx() # (sum(P_i), )
|
point_to_cloud_idx = pcls.packed_to_cloud_idx() # (sum(P_i), )
|
||||||
num_points_per_cloud = pcls.num_points_per_cloud() # (N,)
|
num_points_per_cloud = pcls.num_points_per_cloud() # (N,)
|
||||||
|
# pyre-ignore[16]: `torch.Tensor` has no attribute `gather`
|
||||||
weights_p = num_points_per_cloud.gather(0, point_to_cloud_idx)
|
weights_p = num_points_per_cloud.gather(0, point_to_cloud_idx)
|
||||||
weights_p = 1.0 / weights_p.float()
|
weights_p = 1.0 / weights_p.float()
|
||||||
point_to_edge = point_to_edge * weights_p
|
point_to_edge = point_to_edge * weights_p
|
||||||
@ -377,6 +378,7 @@ def point_mesh_face_distance(
|
|||||||
# weight each example by the inverse of number of points in the example
|
# weight each example by the inverse of number of points in the example
|
||||||
point_to_cloud_idx = pcls.packed_to_cloud_idx() # (sum(P_i),)
|
point_to_cloud_idx = pcls.packed_to_cloud_idx() # (sum(P_i),)
|
||||||
num_points_per_cloud = pcls.num_points_per_cloud() # (N,)
|
num_points_per_cloud = pcls.num_points_per_cloud() # (N,)
|
||||||
|
# pyre-ignore[16]: `torch.Tensor` has no attribute `gather`
|
||||||
weights_p = num_points_per_cloud.gather(0, point_to_cloud_idx)
|
weights_p = num_points_per_cloud.gather(0, point_to_cloud_idx)
|
||||||
weights_p = 1.0 / weights_p.float()
|
weights_p = 1.0 / weights_p.float()
|
||||||
point_to_face = point_to_face * weights_p
|
point_to_face = point_to_face * weights_p
|
||||||
|
@ -185,7 +185,6 @@ class Pointclouds:
|
|||||||
self._points_list = points
|
self._points_list = points
|
||||||
self._N = len(self._points_list)
|
self._N = len(self._points_list)
|
||||||
self.valid = torch.zeros((self._N,), dtype=torch.bool, device=self.device)
|
self.valid = torch.zeros((self._N,), dtype=torch.bool, device=self.device)
|
||||||
self._num_points_per_cloud = []
|
|
||||||
|
|
||||||
if self._N > 0:
|
if self._N > 0:
|
||||||
self.device = self._points_list[0].device
|
self.device = self._points_list[0].device
|
||||||
@ -208,6 +207,8 @@ class Pointclouds:
|
|||||||
if len(num_points_per_cloud.unique()) == 1:
|
if len(num_points_per_cloud.unique()) == 1:
|
||||||
self.equisized = True
|
self.equisized = True
|
||||||
self._num_points_per_cloud = num_points_per_cloud
|
self._num_points_per_cloud = num_points_per_cloud
|
||||||
|
else:
|
||||||
|
self._num_points_per_cloud = torch.tensor([], dtype=torch.int64)
|
||||||
|
|
||||||
elif torch.is_tensor(points):
|
elif torch.is_tensor(points):
|
||||||
if points.dim() != 3 or points.shape[2] != 3:
|
if points.dim() != 3 or points.shape[2] != 3:
|
||||||
@ -525,7 +526,7 @@ class Pointclouds:
|
|||||||
self._compute_packed()
|
self._compute_packed()
|
||||||
return self._cloud_to_packed_first_idx
|
return self._cloud_to_packed_first_idx
|
||||||
|
|
||||||
def num_points_per_cloud(self):
|
def num_points_per_cloud(self) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Return a 1D tensor x with length equal to the number of clouds giving
|
Return a 1D tensor x with length equal to the number of clouds giving
|
||||||
the number of points in each cloud.
|
the number of points in each cloud.
|
||||||
|
@ -778,6 +778,17 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
|
|||||||
with self.assertRaisesRegex(ValueError, "Support for 1 or 2 norm."):
|
with self.assertRaisesRegex(ValueError, "Support for 1 or 2 norm."):
|
||||||
chamfer_distance(p1, p2, norm=3)
|
chamfer_distance(p1, p2, norm=3)
|
||||||
|
|
||||||
|
def test_empty_clouds(self):
|
||||||
|
# Check that point_reduction doesn't divide by zero
|
||||||
|
points1 = Pointclouds(points=[torch.zeros(0, 3), torch.zeros(10, 3)])
|
||||||
|
points2 = Pointclouds(points=torch.ones(2, 40, 3))
|
||||||
|
loss, _ = chamfer_distance(points1, points2, batch_reduction=None)
|
||||||
|
self.assertClose(loss, torch.tensor([0.0, 6.0]))
|
||||||
|
|
||||||
|
# Check that batch_reduction doesn't divide by zero
|
||||||
|
loss2, _ = chamfer_distance(Pointclouds([]), Pointclouds([]))
|
||||||
|
self.assertClose(loss2, torch.tensor(0.0))
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def chamfer_with_init(
|
def chamfer_with_init(
|
||||||
batch_size: int,
|
batch_size: int,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user