Mesh normal consistency when many faces intersect

Summary: We were double counting some pairs in some cases. Specifically if four or more faces share an edge, then some of them were getting double counted. This is a minimal tweak to avoid that.

Reviewed By: nikhilaravi

Differential Revision: D26073477

fbshipit-source-id: a40032acf3044bb98dd91cb29904614ef64d5599
This commit is contained in:
Jeremy Reizenstein 2021-01-27 17:30:25 -08:00 committed by Facebook GitHub Bot
parent 00acda7ab0
commit 7f62eacdb2
2 changed files with 12 additions and 10 deletions

View File

@ -108,7 +108,7 @@ def mesh_normal_consistency(meshes):
for e in vert_edge_pair_idx for e in vert_edge_pair_idx
for i in range(len(e) - 1) for i in range(len(e) - 1)
for j in range(1, len(e)) for j in range(1, len(e))
if i != j if i < j
] ]
vert_edge_pair_idx = torch.tensor( vert_edge_pair_idx = torch.tensor(
vert_edge_pair_idx, device=meshes.device, dtype=torch.int64 vert_edge_pair_idx, device=meshes.device, dtype=torch.int64

View File

@ -10,6 +10,9 @@ from pytorch3d.utils.ico_sphere import ico_sphere
class TestMeshNormalConsistency(unittest.TestCase): class TestMeshNormalConsistency(unittest.TestCase):
def setUp(self) -> None:
torch.manual_seed(42)
@staticmethod @staticmethod
def init_faces(num_verts: int = 1000): def init_faces(num_verts: int = 1000):
faces = [] faces = []
@ -95,17 +98,16 @@ class TestMeshNormalConsistency(unittest.TestCase):
v2 = verts_packed[v2] v2 = verts_packed[v2]
normals.append((v1 - v0).view(-1).cross((v2 - v0).view(-1))) normals.append((v1 - v0).view(-1).cross((v2 - v0).view(-1)))
for i in range(len(normals) - 1): for i in range(len(normals) - 1):
for j in range(1, len(normals)): for j in range(i + 1, len(normals)):
if i != j: mesh_idx.append(edges_packed_to_mesh_idx[e])
mesh_idx.append(edges_packed_to_mesh_idx[e]) loss.append(
loss.append( (
( 1
1 - torch.cosine_similarity(
- torch.cosine_similarity( normals[i].view(1, 3), -normals[j].view(1, 3)
normals[i].view(1, 3), -normals[j].view(1, 3)
)
) )
) )
)
mesh_idx = torch.tensor(mesh_idx, device=meshes.device) mesh_idx = torch.tensor(mesh_idx, device=meshes.device)
num = mesh_idx.bincount(minlength=N) num = mesh_idx.bincount(minlength=N)