chamfer test consistency

Summary:
Modify test_chamfer for more robustness. Avoid empty pointclouds, including where point_reduction is mean, for which we currently return nan (*), and so that we aren't looking at an empty gradient. Make sure we aren't using padding as points in the homogenous cases in the tests, which will lead to a tie between closest points and therefore a potential instability in the gradient - see https://github.com/pytorch/pytorch/issues/35699.

(*) This doesn't attempt to fix the nan.

Reviewed By: nikhilaravi, gkioxari

Differential Revision: D21157322

fbshipit-source-id: a609e84e25a24379c8928ff645d587552526e4af
This commit is contained in:
Jeremy Reizenstein 2020-04-22 09:26:05 -07:00 committed by Facebook GitHub Bot
parent faf0885811
commit 9e4bd2e5e0
2 changed files with 28 additions and 20 deletions

View File

@ -123,4 +123,8 @@ class TestCaseMixin(unittest.TestCase):
input, other, rtol=rtol, atol=atol, equal_nan=equal_nan input, other, rtol=rtol, atol=atol, equal_nan=equal_nan
) )
if not close and msg is None:
max_diff = backend.abs(input - other).max()
self.fail(f"Not close. max diff {max_diff}.")
self.assertTrue(close, msg) self.assertTrue(close, msg)

View File

@ -9,7 +9,6 @@ import torch.nn.functional as F
from common_testing import TestCaseMixin from common_testing import TestCaseMixin
from pytorch3d.loss import chamfer_distance from pytorch3d.loss import chamfer_distance
from pytorch3d.structures.pointclouds import Pointclouds from pytorch3d.structures.pointclouds import Pointclouds
from pytorch3d.structures.utils import list_to_padded
# Output of init_pointclouds # Output of init_pointclouds
@ -24,18 +23,28 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
torch.manual_seed(1) torch.manual_seed(1)
@staticmethod @staticmethod
def init_pointclouds(N, P1, P2, device, requires_grad: bool = True): def init_pointclouds(
N, P1, P2, device, requires_grad: bool = True, allow_empty: bool = True
):
""" """
Create 2 pointclouds object and associated padded points/normals tensors by Create 2 pointclouds object and associated padded points/normals tensors by
starting from lists. The clouds and tensors have the same data. The starting from lists. The clouds and tensors have the same data. The
leaf nodes for the clouds are a list of tensors. The padded tensor can be leaf nodes for the clouds are a list of tensors. The padded tensor can be
used directly as a leaf node. used directly as a leaf node.
""" """
p1_lengths = torch.randint(P1, size=(N,), dtype=torch.int64, device=device) low = 0 if allow_empty else 1
p2_lengths = torch.randint(P2, size=(N,), dtype=torch.int64, device=device) p1_lengths = torch.randint(low, P1, size=(N,), dtype=torch.int64, device=device)
p2_lengths = torch.randint(low, P2, size=(N,), dtype=torch.int64, device=device)
weights = torch.rand((N,), dtype=torch.float32, device=device) weights = torch.rand((N,), dtype=torch.float32, device=device)
# list of points and normals tensors # list of points and normals tensors
p1 = torch.rand((N, P1, 3), dtype=torch.float32, device=device)
p2 = torch.rand((N, P2, 3), dtype=torch.float32, device=device)
n1 = torch.rand((N, P1, 3), dtype=torch.float32, device=device)
n2 = torch.rand((N, P2, 3), dtype=torch.float32, device=device)
n1 /= n1.norm(dim=-1, p=2, keepdim=True)
n2 /= n2.norm(dim=-1, p=2, keepdim=True)
p1_list = [] p1_list = []
p2_list = [] p2_list = []
n1_list = [] n1_list = []
@ -43,19 +52,10 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
for i in range(N): for i in range(N):
l1 = p1_lengths[i] l1 = p1_lengths[i]
l2 = p2_lengths[i] l2 = p2_lengths[i]
p1_list.append(torch.rand((l1, 3), dtype=torch.float32, device=device)) p1_list.append(p1[i, :l1].clone())
p2_list.append(torch.rand((l2, 3), dtype=torch.float32, device=device)) p2_list.append(p2[i, :l2].clone())
n1_list.append(torch.rand((l1, 3), dtype=torch.float32, device=device)) n1_list.append(n1[i, :l1].clone())
n2_list.append(torch.rand((l2, 3), dtype=torch.float32, device=device)) n2_list.append(n2[i, :l2].clone())
n1_list = [n / n.norm(dim=-1, p=2, keepdim=True) for n in n1_list]
n2_list = [n / n.norm(dim=-1, p=2, keepdim=True) for n in n2_list]
# Clone the lists and initialize padded tensors.
p1 = list_to_padded([p.clone() for p in p1_list])
p2 = list_to_padded([p.clone() for p in p2_list])
n1 = list_to_padded([p.clone() for p in n1_list])
n2 = list_to_padded([p.clone() for p in n2_list])
# Set requires_grad for all tensors in the lists and # Set requires_grad for all tensors in the lists and
# padded tensors. # padded tensors.
@ -313,7 +313,9 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
# Reinitialize all the tensors so that the # Reinitialize all the tensors so that the
# backward pass can be computed. # backward pass can be computed.
points_normals = TestChamfer.init_pointclouds(N, P1, P2, device) points_normals = TestChamfer.init_pointclouds(
N, P1, P2, device, allow_empty=False
)
# Chamfer with pointclouds as input. # Chamfer with pointclouds as input.
cham_cloud, norm_cloud = chamfer_distance( cham_cloud, norm_cloud = chamfer_distance(
@ -371,7 +373,9 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
# Reinitialize all the tensors so that the # Reinitialize all the tensors so that the
# backward pass can be computed. # backward pass can be computed.
points_normals = TestChamfer.init_pointclouds(N, P1, P2, device) points_normals = TestChamfer.init_pointclouds(
N, P1, P2, device, allow_empty=False
)
# Chamfer with pointclouds as input. # Chamfer with pointclouds as input.
cham_cloud, _ = chamfer_distance( cham_cloud, _ = chamfer_distance(
@ -560,7 +564,7 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
# List of tensors vs padded tensor # List of tensors vs padded tensor
for i in range(len(x1)): for i in range(len(x1)):
self.assertClose(x1[i].grad, x2.grad[i, : lengths[i]]) self.assertClose(x1[i].grad, x2.grad[i, : lengths[i]], atol=1e-7)
self.assertTrue(x2.grad[i, lengths[i] :].sum().item() == 0.0) self.assertTrue(x2.grad[i, lengths[i] :].sum().item() == 0.0)
elif all(torch.is_tensor(p) for p in [x1, x2]): elif all(torch.is_tensor(p) for p in [x1, x2]):
# Two tensors # Two tensors