diff --git a/pytorch3d/loss/mesh_normal_consistency.py b/pytorch3d/loss/mesh_normal_consistency.py index cd672625..1433da52 100644 --- a/pytorch3d/loss/mesh_normal_consistency.py +++ b/pytorch3d/loss/mesh_normal_consistency.py @@ -114,6 +114,11 @@ def mesh_normal_consistency(meshes): vert_edge_pair_idx, device=meshes.device, dtype=torch.int64 ) + if vert_edge_pair_idx.shape[0] == 0: + return torch.tensor( + [0.0], dtype=torch.float32, device=meshes.device, requires_grad=True + ) + v0_idx = edges_packed[edge_idx, 0] v0 = verts_packed[v0_idx] v1_idx = edges_packed[edge_idx, 1] diff --git a/tests/test_mesh_normal_consistency.py b/tests/test_mesh_normal_consistency.py index 92077cc9..d597facd 100644 --- a/tests/test_mesh_normal_consistency.py +++ b/tests/test_mesh_normal_consistency.py @@ -218,6 +218,17 @@ class TestMeshNormalConsistency(unittest.TestCase): self.assertTrue(torch.allclose(out1, out2)) + def test_no_intersection(self): + """ + Test Mesh Normal Consistency for a mesh known to have no + intersecting faces. + """ + verts = torch.rand(1, 6, 2) + faces = torch.arange(6).reshape(1, 2, 3) + meshes = Meshes(verts=verts, faces=faces) + out = mesh_normal_consistency(meshes) + self.assertEqual(out.item(), 0) + @staticmethod def mesh_normal_consistency_with_ico( num_meshes: int, level: int = 3, device: str = "cpu"