use assertClose

Summary: use assertClose in some tests, which enforces shape equality. Fixes some small problems, including graph_conv on an empty graph.

Reviewed By: nikhilaravi

Differential Revision: D20556912

fbshipit-source-id: 60a61eafe3c03ce0f6c9c1a842685708fb10ac5b
This commit is contained in:
Jeremy Reizenstein
2020-03-23 11:33:10 -07:00
committed by Facebook GitHub Bot
parent 744ef0c2c8
commit 595aca27ea
13 changed files with 216 additions and 241 deletions

View File

@@ -9,8 +9,10 @@ from pytorch3d.ops import sample_points_from_meshes
from pytorch3d.structures.meshes import Meshes
from pytorch3d.utils.ico_sphere import ico_sphere
from common_testing import TestCaseMixin
class TestSamplePoints(unittest.TestCase):
class TestSamplePoints(TestCaseMixin, unittest.TestCase):
def setUp(self) -> None:
super().setUp()
torch.manual_seed(1)
@@ -98,32 +100,24 @@ class TestSamplePoints(unittest.TestCase):
self.assertEqual(normals.shape, (3, num_samples, 3))
# Empty meshes: should have all zeros for samples and normals.
self.assertTrue(
torch.allclose(samples[0, :], torch.zeros((1, num_samples, 3)))
)
self.assertTrue(
torch.allclose(normals[0, :], torch.zeros((1, num_samples, 3)))
)
self.assertClose(samples[0, :], torch.zeros((num_samples, 3)))
self.assertClose(normals[0, :], torch.zeros((num_samples, 3)))
# Sphere: points should have radius 1.
x, y, z = samples[1, :].unbind(1)
radius = torch.sqrt(x ** 2 + y ** 2 + z ** 2)
self.assertTrue(torch.allclose(radius, torch.ones((num_samples))))
self.assertClose(radius, torch.ones((num_samples)))
# Pyramid: points shoudl lie on one of the faces.
pyramid_verts = samples[2, :]
pyramid_normals = normals[2, :]
self.assertTrue(
torch.allclose(
pyramid_verts.lt(1).float(), torch.ones_like(pyramid_verts)
)
self.assertClose(
pyramid_verts.lt(1).float(), torch.ones_like(pyramid_verts)
)
self.assertTrue(
torch.allclose(
(pyramid_verts >= 0).float(), torch.ones_like(pyramid_verts)
)
self.assertClose(
(pyramid_verts >= 0).float(), torch.ones_like(pyramid_verts)
)
# Face 1: z = 0, x + y <= 1, normals = (0, 0, 1).
@@ -135,13 +129,11 @@ class TestSamplePoints(unittest.TestCase):
self.assertTrue(
torch.all((face_1_verts[:, 0] + face_1_verts[:, 1]) <= 1)
)
self.assertTrue(
torch.allclose(
face_1_normals,
torch.tensor([0, 0, 1], dtype=torch.float32).expand(
face_1_normals.size()
),
)
self.assertClose(
face_1_normals,
torch.tensor([0, 0, 1], dtype=torch.float32).expand(
face_1_normals.size()
),
)
# Face 2: x = 0, z + y <= 1, normals = (1, 0, 0).
@@ -153,13 +145,11 @@ class TestSamplePoints(unittest.TestCase):
self.assertTrue(
torch.all((face_2_verts[:, 1] + face_2_verts[:, 2]) <= 1)
)
self.assertTrue(
torch.allclose(
face_2_normals,
torch.tensor([1, 0, 0], dtype=torch.float32).expand(
face_2_normals.size()
),
)
self.assertClose(
face_2_normals,
torch.tensor([1, 0, 0], dtype=torch.float32).expand(
face_2_normals.size()
),
)
# Face 3: y = 0, x + z <= 1, normals = (0, -1, 0).
@@ -171,13 +161,11 @@ class TestSamplePoints(unittest.TestCase):
self.assertTrue(
torch.all((face_3_verts[:, 0] + face_3_verts[:, 2]) <= 1)
)
self.assertTrue(
torch.allclose(
face_3_normals,
torch.tensor([0, -1, 0], dtype=torch.float32).expand(
face_3_normals.size()
),
)
self.assertClose(
face_3_normals,
torch.tensor([0, -1, 0], dtype=torch.float32).expand(
face_3_normals.size()
),
)
# Face 4: x + y + z = 1, normals = (1, 1, 1)/sqrt(3).
@@ -186,22 +174,16 @@ class TestSamplePoints(unittest.TestCase):
pyramid_verts[face_4_idxs, :],
pyramid_normals[face_4_idxs, :],
)
self.assertTrue(
torch.allclose(
face_4_verts.sum(1), torch.ones(face_4_verts.size(0))
)
)
self.assertTrue(
torch.allclose(
face_4_normals,
(
torch.tensor([1, 1, 1], dtype=torch.float32)
/ torch.sqrt(torch.tensor(3, dtype=torch.float32))
).expand(face_4_normals.size()),
)
self.assertClose(face_4_verts.sum(1), torch.ones(face_4_verts.size(0)))
self.assertClose(
face_4_normals,
(
torch.tensor([1, 1, 1], dtype=torch.float32)
/ torch.sqrt(torch.tensor(3, dtype=torch.float32))
).expand(face_4_normals.size()),
)
def test_mutinomial(self):
def test_multinomial(self):
"""
Confirm that torch.multinomial does not sample elements which have
zero probability.