Mirror of https://github.com/facebookresearch/pytorch3d.git (synced 2025-08-02 03:42:50 +08:00)
chamfer test consistency
Summary: Modify test_chamfer for more robustness. Avoid empty pointclouds, including when point_reduction is "mean", for which we currently return nan (*), and so that we aren't looking at an empty gradient. Make sure we aren't using padding values as points in the homogeneous cases in the tests, which would lead to a tie between closest points and therefore a potential instability in the gradient - see https://github.com/pytorch/pytorch/issues/35699.

(*) This doesn't attempt to fix the nan.

Reviewed By: nikhilaravi, gkioxari

Differential Revision: D21157322

fbshipit-source-id: a609e84e25a24379c8928ff645d587552526e4af
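For illustration only (not part of this change), a minimal sketch of the nan case the summary describes, assuming a pytorch3d build where chamfer_distance accepts Pointclouds inputs and a point_reduction argument, as the tests below do; the variable names are illustrative and the exact behaviour may differ by version and device:

    # Illustrative sketch: a cloud with zero points makes the "mean" point
    # reduction average over nothing, so the per-cloud term comes out as nan.
    # The change above only avoids this case in the tests; it does not fix it.
    import torch
    from pytorch3d.loss import chamfer_distance
    from pytorch3d.structures.pointclouds import Pointclouds

    device = torch.device("cpu")
    empty = Pointclouds([torch.rand((0, 3), device=device)])  # no points
    full = Pointclouds([torch.rand((10, 3), device=device)])  # ordinary cloud

    loss, _ = chamfer_distance(empty, full, point_reduction="mean")
    print(loss)  # expected: tensor(nan), per the summary above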
commit 9e4bd2e5e0
parent faf0885811
@@ -123,4 +123,8 @@ class TestCaseMixin(unittest.TestCase):
             input, other, rtol=rtol, atol=atol, equal_nan=equal_nan
         )
 
+        if not close and msg is None:
+            max_diff = backend.abs(input - other).max()
+            self.fail(f"Not close. max diff {max_diff}.")
+
         self.assertTrue(close, msg)
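As a usage note (not part of the diff), a hedged sketch of what the added branch changes for callers, assuming common_testing is importable as it is in the test imports below; the class and test names here are hypothetical. When no custom msg is passed, a failing comparison now reports the maximum elementwise difference instead of a bare assertTrue failure:

    # Illustrative sketch: the failure message now carries the max difference.
    import unittest

    import torch

    from common_testing import TestCaseMixin


    class TestReporting(TestCaseMixin, unittest.TestCase):
        def test_failure_message(self):
            a = torch.zeros(3)
            b = torch.tensor([0.0, 0.0, 0.5])
            # Expected to fail with something like "Not close. max diff 0.5"
            # rather than a plain "False is not true".
            with self.assertRaises(AssertionError):
                self.assertClose(a, b)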
@@ -9,7 +9,6 @@ import torch.nn.functional as F
 from common_testing import TestCaseMixin
 from pytorch3d.loss import chamfer_distance
 from pytorch3d.structures.pointclouds import Pointclouds
-from pytorch3d.structures.utils import list_to_padded
 
 
 # Output of init_pointclouds
@@ -24,18 +23,28 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
         torch.manual_seed(1)
 
     @staticmethod
-    def init_pointclouds(N, P1, P2, device, requires_grad: bool = True):
+    def init_pointclouds(
+        N, P1, P2, device, requires_grad: bool = True, allow_empty: bool = True
+    ):
         """
         Create 2 pointclouds object and associated padded points/normals tensors by
         starting from lists. The clouds and tensors have the same data. The
         leaf nodes for the clouds are a list of tensors. The padded tensor can be
         used directly as a leaf node.
         """
-        p1_lengths = torch.randint(P1, size=(N,), dtype=torch.int64, device=device)
-        p2_lengths = torch.randint(P2, size=(N,), dtype=torch.int64, device=device)
+        low = 0 if allow_empty else 1
+        p1_lengths = torch.randint(low, P1, size=(N,), dtype=torch.int64, device=device)
+        p2_lengths = torch.randint(low, P2, size=(N,), dtype=torch.int64, device=device)
         weights = torch.rand((N,), dtype=torch.float32, device=device)
 
         # list of points and normals tensors
+        p1 = torch.rand((N, P1, 3), dtype=torch.float32, device=device)
+        p2 = torch.rand((N, P2, 3), dtype=torch.float32, device=device)
+        n1 = torch.rand((N, P1, 3), dtype=torch.float32, device=device)
+        n2 = torch.rand((N, P2, 3), dtype=torch.float32, device=device)
+        n1 /= n1.norm(dim=-1, p=2, keepdim=True)
+        n2 /= n2.norm(dim=-1, p=2, keepdim=True)
+
         p1_list = []
         p2_list = []
         n1_list = []
@@ -43,19 +52,10 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
         for i in range(N):
             l1 = p1_lengths[i]
             l2 = p2_lengths[i]
-            p1_list.append(torch.rand((l1, 3), dtype=torch.float32, device=device))
-            p2_list.append(torch.rand((l2, 3), dtype=torch.float32, device=device))
-            n1_list.append(torch.rand((l1, 3), dtype=torch.float32, device=device))
-            n2_list.append(torch.rand((l2, 3), dtype=torch.float32, device=device))
-
-        n1_list = [n / n.norm(dim=-1, p=2, keepdim=True) for n in n1_list]
-        n2_list = [n / n.norm(dim=-1, p=2, keepdim=True) for n in n2_list]
-
-        # Clone the lists and initialize padded tensors.
-        p1 = list_to_padded([p.clone() for p in p1_list])
-        p2 = list_to_padded([p.clone() for p in p2_list])
-        n1 = list_to_padded([p.clone() for p in n1_list])
-        n2 = list_to_padded([p.clone() for p in n2_list])
+            p1_list.append(p1[i, :l1].clone())
+            p2_list.append(p2[i, :l2].clone())
+            n1_list.append(n1[i, :l1].clone())
+            n2_list.append(n2[i, :l2].clone())
 
         # Set requires_grad for all tensors in the lists and
         # padded tensors.
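As an aside on why padding values treated as real points are a problem (the pytorch issue linked in the summary), here is a small hedged sketch in plain PyTorch, not the chamfer kernel, with illustrative names: duplicated zero padding points are exactly equidistant from any query, and the minimum over tied distances has no unique subgradient, so which point receives gradient is implementation-dependent.

    # Illustrative sketch: a tie between two equally-close candidates.
    import torch

    query = torch.zeros(3)
    candidates = torch.tensor(
        [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]], requires_grad=True
    )  # both candidates are at squared distance 1 from the query

    dists = ((candidates - query) ** 2).sum(dim=1)  # tensor([1., 1.]): a tie
    min_dist, idx = dists.min(dim=0)
    min_dist.backward()

    # Only the selected candidate gets a nonzero gradient; with a tie the
    # selection (and hence the gradient) can differ between implementations,
    # which is why the tests now keep padding values out of the point sets.
    print(idx.item(), candidates.grad)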
@@ -313,7 +313,9 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
 
         # Reinitialize all the tensors so that the
         # backward pass can be computed.
-        points_normals = TestChamfer.init_pointclouds(N, P1, P2, device)
+        points_normals = TestChamfer.init_pointclouds(
+            N, P1, P2, device, allow_empty=False
+        )
 
         # Chamfer with pointclouds as input.
         cham_cloud, norm_cloud = chamfer_distance(
@@ -371,7 +373,9 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
 
         # Reinitialize all the tensors so that the
         # backward pass can be computed.
-        points_normals = TestChamfer.init_pointclouds(N, P1, P2, device)
+        points_normals = TestChamfer.init_pointclouds(
+            N, P1, P2, device, allow_empty=False
+        )
 
         # Chamfer with pointclouds as input.
         cham_cloud, _ = chamfer_distance(
@@ -560,7 +564,7 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
 
             # List of tensors vs padded tensor
             for i in range(len(x1)):
-                self.assertClose(x1[i].grad, x2.grad[i, : lengths[i]])
+                self.assertClose(x1[i].grad, x2.grad[i, : lengths[i]], atol=1e-7)
                 self.assertTrue(x2.grad[i, lengths[i] :].sum().item() == 0.0)
         elif all(torch.is_tensor(p) for p in [x1, x2]):
             # Two tensors