fix small face issue for ptmeshdist

Summary:
Fix small face issue for point_mesh distance computation.

The issue lies in the computation of `IsInsideTriangle` which is unstable and non-symmetrical when faces with small areas are given as input. This diff fixes the issue by returning `False` for `IsInsideTriangle` when small faces are given as input.

Reviewed By: bottler

Differential Revision: D29163052

fbshipit-source-id: be297002f26b5e6eded9394fde00553a37406bee
This commit is contained in:
Georgia Gkioxari
2021-06-18 09:29:01 -07:00
committed by Facebook GitHub Bot
parent a343cf534c
commit 88f5d79088
3 changed files with 93 additions and 11 deletions

View File

@@ -96,7 +96,7 @@ class TestPointMeshDistance(TestCaseMixin, unittest.TestCase):
d20 = v2.dot(v0)
d21 = v2.dot(v1)
denom = d00 * d11 - d01 * d01
denom = d00 * d11 - d01 * d01 + TestPointMeshDistance.eps()
s2 = (d11 * d20 - d01 * d21) / denom
s3 = (d00 * d21 - d01 * d20) / denom
s1 = 1.0 - s2 - s3
@@ -117,6 +117,13 @@ class TestPointMeshDistance(TestCaseMixin, unittest.TestCase):
Returns:
inside: BoolTensor of shape (1)
"""
v0 = tri[1] - tri[0]
v1 = tri[2] - tri[0]
area = torch.cross(v0, v1).norm() / 2.0
# check if triangle is a line or a point. In that case, return False
if area < 1e-5:
return False
bary = TestPointMeshDistance._point_to_bary(point, tri)
inside = ((bary >= 0.0) * (bary <= 1.0)).all()
return inside
@@ -836,6 +843,28 @@ class TestPointMeshDistance(TestCaseMixin, unittest.TestCase):
)
self.assertClose(pcls.points_list()[i].grad, pcls_op.points_list()[i].grad)
def test_small_faces_case(self):
for device in [torch.device("cpu"), torch.device("cuda:0")]:
mesh_vertices = torch.tensor(
[
[-0.0021, -0.3769, 0.7146],
[-0.0161, -0.3771, 0.7146],
[-0.0021, -0.3771, 0.7147],
],
dtype=torch.float32,
device=device,
)
mesh1_faces = torch.tensor([[0, 2, 1]], device=device)
mesh2_faces = torch.tensor([[2, 0, 1]], device=device)
pcd_points = torch.tensor([[-0.3623, -0.5340, 0.7727]], device=device)
mesh1 = Meshes(verts=[mesh_vertices], faces=[mesh1_faces])
mesh2 = Meshes(verts=[mesh_vertices], faces=[mesh2_faces])
pcd = Pointclouds(points=[pcd_points])
loss1 = point_mesh_face_distance(mesh1, pcd)
loss2 = point_mesh_face_distance(mesh2, pcd)
self.assertClose(loss1, loss2)
@staticmethod
def point_mesh_edge(N: int, V: int, F: int, P: int, device: str):
device = torch.device(device)