diff --git a/tests/test_meshes.py b/tests/test_meshes.py index a0e940df..afe6e941 100644 --- a/tests/test_meshes.py +++ b/tests/test_meshes.py @@ -127,6 +127,45 @@ def init_simple_mesh(device: str = "cpu"): return Meshes(verts=verts, faces=faces) +def mesh_structures_equal(mesh1, mesh2) -> bool: + """ + Two meshes are equal if they have identical verts_list and faces_list. + + Use to_sorted() before passing into this function to obtain meshes invariant to + vertex permutations. Note that this operator treats two geometrically identical + meshes as different if their vertices are in different coordinate frames. + """ + if mesh1.__class__ != mesh1.__class__: + return False + + if mesh1.textures is not None or mesh2.textures is not None: + raise NotImplementedError( + "mesh equality is not implemented for textured meshes." + ) + + if len(mesh1.verts_list()) != len(mesh2.verts_list()) or not all( + torch.equal(verts_mesh1, verts_mesh2) + for (verts_mesh1, verts_mesh2) in zip(mesh1.verts_list(), mesh2.verts_list()) + ): + return False + + if len(mesh1.faces_list()) != len(mesh2.faces_list()) or not all( + torch.equal(faces_mesh1, faces_mesh2) + for (faces_mesh1, faces_mesh2) in zip(mesh1.faces_list(), mesh2.faces_list()) + ): + return False + + if len(mesh1.verts_normals_list()) != len(mesh2.verts_normals_list()) or not all( + torch.equal(normals_mesh1, normals_mesh2) + for (normals_mesh1, normals_mesh2) in zip( + mesh1.verts_normals_list(), mesh2.verts_normals_list() + ) + ): + return False + + return True + + class TestMeshes(TestCaseMixin, unittest.TestCase): def setUp(self) -> None: np.random.seed(42) @@ -1172,6 +1211,18 @@ class TestMeshes(TestCaseMixin, unittest.TestCase): face_normals_cpu[nonzero], face_normals_cuda.cpu()[nonzero], atol=1e-6 ) + def test_equality(self): + meshes1 = init_mesh(num_meshes=2) + meshes2 = init_mesh(num_meshes=2) + meshes3 = init_mesh(num_meshes=3) + empty_mesh = Meshes([], []) + self.assertTrue(mesh_structures_equal(empty_mesh, Meshes([], []))) + self.assertTrue(mesh_structures_equal(meshes1, meshes1)) + self.assertTrue(mesh_structures_equal(meshes1, meshes1.clone())) + self.assertFalse(mesh_structures_equal(empty_mesh, meshes1)) + self.assertFalse(mesh_structures_equal(meshes1, meshes2)) + self.assertFalse(mesh_structures_equal(meshes1, meshes3)) + @staticmethod def compute_packed_with_init( num_meshes: int = 10, max_v: int = 100, max_f: int = 300, device: str = "cpu"