use assertClose

Summary: use assertClose in some tests, which enforces shape equality. Fixes some small problems, including graph_conv on an empty graph.

Reviewed By: nikhilaravi

Differential Revision: D20556912

fbshipit-source-id: 60a61eafe3c03ce0f6c9c1a842685708fb10ac5b
This commit is contained in:
Jeremy Reizenstein
2020-03-23 11:33:10 -07:00
committed by Facebook GitHub Bot
parent 744ef0c2c8
commit 595aca27ea
13 changed files with 216 additions and 241 deletions

View File

@@ -8,8 +8,10 @@ from pytorch3d.ops.subdivide_meshes import SubdivideMeshes
from pytorch3d.structures.meshes import Meshes
from pytorch3d.utils.ico_sphere import ico_sphere
from common_testing import TestCaseMixin
class TestSubdivideMeshes(unittest.TestCase):
class TestSubdivideMeshes(TestCaseMixin, unittest.TestCase):
def test_simple_subdivide(self):
# Create a mesh with one face and check the subdivided mesh has
# 4 faces with the correct vertex coordinates.
@@ -56,8 +58,8 @@ class TestSubdivideMeshes(unittest.TestCase):
device=device,
)
new_verts, new_faces = new_mesh.get_mesh_verts_faces(0)
self.assertTrue(torch.allclose(new_verts, gt_subdivide_verts))
self.assertTrue(torch.allclose(new_faces, gt_subdivide_faces))
self.assertClose(new_verts, gt_subdivide_verts)
self.assertClose(new_faces, gt_subdivide_faces)
self.assertTrue(new_verts.requires_grad == verts.requires_grad)
def test_heterogeneous_meshes(self):
@@ -185,12 +187,12 @@ class TestSubdivideMeshes(unittest.TestCase):
new_mesh_verts1, new_mesh_faces1 = new_mesh.get_mesh_verts_faces(0)
new_mesh_verts2, new_mesh_faces2 = new_mesh.get_mesh_verts_faces(1)
new_mesh_verts3, new_mesh_faces3 = new_mesh.get_mesh_verts_faces(2)
self.assertTrue(torch.allclose(new_mesh_verts1, gt_subdivided_verts1))
self.assertTrue(torch.allclose(new_mesh_faces1, gt_subdivided_faces1))
self.assertTrue(torch.allclose(new_mesh_verts2, gt_subdivided_verts2))
self.assertTrue(torch.allclose(new_mesh_faces2, gt_subdivided_faces2))
self.assertTrue(torch.allclose(new_mesh_verts3, gt_subdivided_verts3))
self.assertTrue(torch.allclose(new_mesh_faces3, gt_subdivided_faces3))
self.assertClose(new_mesh_verts1, gt_subdivided_verts1)
self.assertClose(new_mesh_faces1, gt_subdivided_faces1)
self.assertClose(new_mesh_verts2, gt_subdivided_verts2)
self.assertClose(new_mesh_faces2, gt_subdivided_faces2)
self.assertClose(new_mesh_verts3, gt_subdivided_verts3)
self.assertClose(new_mesh_faces3, gt_subdivided_faces3)
self.assertTrue(new_mesh_verts1.requires_grad == verts1.requires_grad)
self.assertTrue(new_mesh_verts2.requires_grad == verts2.requires_grad)
self.assertTrue(new_mesh_verts3.requires_grad == verts2.requires_grad)
@@ -212,7 +214,7 @@ class TestSubdivideMeshes(unittest.TestCase):
gt_feats = torch.cat(
(feats.view(N, V, D), app_feats.view(N, -1, D)), dim=1
).view(-1, D)
self.assertTrue(torch.allclose(new_feats, gt_feats))
self.assertClose(new_feats, gt_feats)
self.assertTrue(new_feats.requires_grad == gt_feats.requires_grad)
@staticmethod