eps fix for iou3d

Summary: Fix EPS issue that causes numerical instabilities when boxes are very close

Reviewed By: kjchalup

Differential Revision: D38661465

fbshipit-source-id: d2b6753cba9dc2f0072ace5289c9aa815a1a29f6
This commit is contained in:
Georgia Gkioxari
2022-08-22 04:26:19 -07:00
committed by Facebook GitHub Bot
parent 06cbba2628
commit 1bfe6bf20a
5 changed files with 923 additions and 188 deletions

View File

@@ -21,7 +21,8 @@ from .common_testing import get_random_cuda_device, get_tests_dir, TestCaseMixin
OBJECTRON_TO_PYTORCH3D_FACE_IDX = [0, 4, 6, 2, 1, 5, 7, 3]
DATA_DIR = get_tests_dir() / "data"
DEBUG = False
EPS = 1e-5
DOT_EPS = 1e-3
AREA_EPS = 1e-4
UNIT_BOX = [
[0, 0, 0],
@@ -457,6 +458,207 @@ class TestIoU3D(TestCaseMixin, unittest.TestCase):
self.assertClose(
iou, torch.tensor([[0.91]], device=vol.device, dtype=vol.dtype), atol=1e-2
)
# symmetry
vol, iou = overlap_fn(box15b[None], box15a[None])
self.assertClose(
iou, torch.tensor([[0.91]], device=vol.device, dtype=vol.dtype), atol=1e-2
)
# 16th test: From GH issue 1287
box16a = torch.tensor(
[
[-167.5847, -70.6167, -2.7927],
[-166.7333, -72.4264, -2.7927],
[-166.7333, -72.4264, -4.5927],
[-167.5847, -70.6167, -4.5927],
[-163.0605, -68.4880, -2.7927],
[-162.2090, -70.2977, -2.7927],
[-162.2090, -70.2977, -4.5927],
[-163.0605, -68.4880, -4.5927],
],
device=device,
dtype=torch.float32,
)
box16b = torch.tensor(
[
[-167.5847, -70.6167, -2.7927],
[-166.7333, -72.4264, -2.7927],
[-166.7333, -72.4264, -4.5927],
[-167.5847, -70.6167, -4.5927],
[-163.0605, -68.4880, -2.7927],
[-162.2090, -70.2977, -2.7927],
[-162.2090, -70.2977, -4.5927],
[-163.0605, -68.4880, -4.5927],
],
device=device,
dtype=torch.float32,
)
vol, iou = overlap_fn(box16a[None], box16b[None])
self.assertClose(
iou, torch.tensor([[1.0]], device=vol.device, dtype=vol.dtype), atol=1e-2
)
# symmetry
vol, iou = overlap_fn(box16b[None], box16a[None])
self.assertClose(
iou, torch.tensor([[1.0]], device=vol.device, dtype=vol.dtype), atol=1e-2
)
# 17th test: From GH issue 1287
box17a = torch.tensor(
[
[-33.94158, -4.51639, 0.96941],
[-34.67156, -2.65437, 0.96941],
[-34.67156, -2.65437, -0.95367],
[-33.94158, -4.51639, -0.95367],
[-38.75954, -6.40521, 0.96941],
[-39.48952, -4.54319, 0.96941],
[-39.48952, -4.54319, -0.95367],
[-38.75954, -6.40521, -0.95367],
],
device=device,
dtype=torch.float32,
)
box17b = torch.tensor(
[
[-33.94159, -4.51638, 0.96939],
[-34.67158, -2.65437, 0.96939],
[-34.67158, -2.65437, -0.95368],
[-33.94159, -4.51638, -0.95368],
[-38.75954, -6.40523, 0.96939],
[-39.48953, -4.54321, 0.96939],
[-39.48953, -4.54321, -0.95368],
[-38.75954, -6.40523, -0.95368],
],
device=device,
dtype=torch.float32,
)
vol, iou = overlap_fn(box17a[None], box17b[None])
self.assertClose(
iou, torch.tensor([[1.0]], device=vol.device, dtype=vol.dtype), atol=1e-2
)
# symmetry
vol, iou = overlap_fn(box17b[None], box17a[None])
self.assertClose(
iou, torch.tensor([[1.0]], device=vol.device, dtype=vol.dtype), atol=1e-2
)
# 18th test: From GH issue 1287
box18a = torch.tensor(
[
[-105.6248, -32.7026, -1.2279],
[-106.4690, -30.8895, -1.2279],
[-106.4690, -30.8895, -3.0279],
[-105.6248, -32.7026, -3.0279],
[-110.1575, -34.8132, -1.2279],
[-111.0017, -33.0001, -1.2279],
[-111.0017, -33.0001, -3.0279],
[-110.1575, -34.8132, -3.0279],
],
device=device,
dtype=torch.float32,
)
box18b = torch.tensor(
[
[-105.5094, -32.9504, -1.0641],
[-106.4272, -30.9793, -1.0641],
[-106.4272, -30.9793, -3.1916],
[-105.5094, -32.9504, -3.1916],
[-110.0421, -35.0609, -1.0641],
[-110.9599, -33.0899, -1.0641],
[-110.9599, -33.0899, -3.1916],
[-110.0421, -35.0609, -3.1916],
],
device=device,
dtype=torch.float32,
)
# from Meshlab
vol_inters = 17.108501
vol_box1 = 18.000067
vol_box2 = 23.128527
iou_mesh = vol_inters / (vol_box1 + vol_box2 - vol_inters)
vol, iou = overlap_fn(box18a[None], box18b[None])
self.assertClose(
iou,
torch.tensor([[iou_mesh]], device=vol.device, dtype=vol.dtype),
atol=1e-2,
)
self.assertClose(
vol,
torch.tensor([[vol_inters]], device=vol.device, dtype=vol.dtype),
atol=1e-2,
)
# symmetry
vol, iou = overlap_fn(box18b[None], box18a[None])
self.assertClose(
iou,
torch.tensor([[iou_mesh]], device=vol.device, dtype=vol.dtype),
atol=1e-2,
)
self.assertClose(
vol,
torch.tensor([[vol_inters]], device=vol.device, dtype=vol.dtype),
atol=1e-2,
)
# 19th example: From GH issue 1287
box19a = torch.tensor(
[
[-59.4785, -15.6003, 0.4398],
[-60.2263, -13.6928, 0.4398],
[-60.2263, -13.6928, -1.3909],
[-59.4785, -15.6003, -1.3909],
[-64.1743, -17.4412, 0.4398],
[-64.9221, -15.5337, 0.4398],
[-64.9221, -15.5337, -1.3909],
[-64.1743, -17.4412, -1.3909],
],
device=device,
dtype=torch.float32,
)
box19b = torch.tensor(
[
[-59.4874, -15.5775, -0.1512],
[-60.2174, -13.7155, -0.1512],
[-60.2174, -13.7155, -1.9820],
[-59.4874, -15.5775, -1.9820],
[-64.1832, -17.4185, -0.1512],
[-64.9132, -15.5564, -0.1512],
[-64.9132, -15.5564, -1.9820],
[-64.1832, -17.4185, -1.9820],
],
device=device,
dtype=torch.float32,
)
# from Meshlab
vol_inters = 12.505723
vol_box1 = 18.918238
vol_box2 = 18.468531
iou_mesh = vol_inters / (vol_box1 + vol_box2 - vol_inters)
vol, iou = overlap_fn(box19a[None], box19b[None])
self.assertClose(
iou,
torch.tensor([[iou_mesh]], device=vol.device, dtype=vol.dtype),
atol=1e-2,
)
self.assertClose(
vol,
torch.tensor([[vol_inters]], device=vol.device, dtype=vol.dtype),
atol=1e-2,
)
# symmetry
vol, iou = overlap_fn(box19b[None], box19a[None])
self.assertClose(
iou,
torch.tensor([[iou_mesh]], device=vol.device, dtype=vol.dtype),
atol=1e-2,
)
self.assertClose(
vol,
torch.tensor([[vol_inters]], device=vol.device, dtype=vol.dtype),
atol=1e-2,
)
def _test_real_boxes(self, overlap_fn, device):
data_filename = "./real_boxes.pkl"
@@ -715,7 +917,86 @@ def get_plane_verts(box: torch.Tensor) -> torch.Tensor:
return plane_verts
def box_planar_dir(box: torch.Tensor, eps: float = 1e-4) -> torch.Tensor:
def get_tri_center_normal(tris: torch.Tensor) -> torch.Tensor:
"""
Returns the center and normal of triangles
Args:
tris: tensor of shape (T, 3, 3)
Returns:
center: tensor of shape (T, 3)
normal: tensor of shape (T, 3)
"""
add_dim0 = False
if tris.ndim == 2:
tris = tris.unsqueeze(0)
add_dim0 = True
ctr = tris.mean(1) # (T, 3)
normals = torch.zeros_like(ctr)
v0, v1, v2 = tris.unbind(1) # 3 x (T, 3)
# unvectorized solution
T = tris.shape[0]
for t in range(T):
ns = torch.zeros((3, 3), device=tris.device)
ns[0] = torch.cross(v0[t] - ctr[t], v1[t] - ctr[t], dim=-1)
ns[1] = torch.cross(v0[t] - ctr[t], v2[t] - ctr[t], dim=-1)
ns[2] = torch.cross(v1[t] - ctr[t], v2[t] - ctr[t], dim=-1)
i = torch.norm(ns, dim=-1).argmax()
normals[t] = ns[i]
if add_dim0:
ctr = ctr[0]
normals = normals[0]
normals = F.normalize(normals, dim=-1)
return ctr, normals
def get_plane_center_normal(planes: torch.Tensor) -> torch.Tensor:
"""
Returns the center and normal of planes
Args:
planes: tensor of shape (P, 4, 3)
Returns:
center: tensor of shape (P, 3)
normal: tensor of shape (P, 3)
"""
add_dim0 = False
if planes.ndim == 2:
planes = planes.unsqueeze(0)
add_dim0 = True
ctr = planes.mean(1) # (P, 3)
normals = torch.zeros_like(ctr)
v0, v1, v2, v3 = planes.unbind(1) # 4 x (P, 3)
# unvectorized solution
P = planes.shape[0]
for t in range(P):
ns = torch.zeros((6, 3), device=planes.device)
ns[0] = torch.cross(v0[t] - ctr[t], v1[t] - ctr[t], dim=-1)
ns[1] = torch.cross(v0[t] - ctr[t], v2[t] - ctr[t], dim=-1)
ns[2] = torch.cross(v0[t] - ctr[t], v3[t] - ctr[t], dim=-1)
ns[3] = torch.cross(v1[t] - ctr[t], v2[t] - ctr[t], dim=-1)
ns[4] = torch.cross(v1[t] - ctr[t], v3[t] - ctr[t], dim=-1)
ns[5] = torch.cross(v2[t] - ctr[t], v3[t] - ctr[t], dim=-1)
i = torch.norm(ns, dim=-1).argmax()
normals[t] = ns[i]
if add_dim0:
ctr = ctr[0]
normals = normals[0]
normals = F.normalize(normals, dim=-1)
return ctr, normals
def box_planar_dir(
box: torch.Tensor, dot_eps: float = DOT_EPS, area_eps: float = AREA_EPS
) -> torch.Tensor:
"""
Finds the unit vector n which is perpendicular to each plane in the box
and points towards the inside of the box.
@@ -731,33 +1012,33 @@ def box_planar_dir(box: torch.Tensor, eps: float = 1e-4) -> torch.Tensor:
assert box.shape[0] == 8 and box.shape[1] == 3
# center point of each box
ctr = box.mean(0).view(1, 3)
box_ctr = box.mean(0).view(1, 3)
verts = get_plane_verts(box) # (6, 4, 3)
v0, v1, v2, v3 = verts.unbind(1) # each v of shape (6, 3)
# We project the ctr on the plane defined by (v0, v1, v2, v3)
# We define e0 to be the edge connecting (v1, v0)
# We define e1 to be the edge connecting (v2, v0)
# And n is the cross product of e0, e1, either pointing "inside" or not.
e0 = F.normalize(v1 - v0, dim=-1)
e1 = F.normalize(v2 - v0, dim=-1)
n = F.normalize(torch.cross(e0, e1, dim=-1), dim=-1)
# box planes
plane_verts = get_plane_verts(box) # (6, 4, 3)
v0, v1, v2, v3 = plane_verts.unbind(1)
plane_ctr, n = get_plane_center_normal(plane_verts)
# Check all verts are coplanar
if not ((v3 - v0).unsqueeze(1).bmm(n.unsqueeze(2)).abs() < eps).all().item():
if (
not (
F.normalize(v3 - v0, dim=-1).unsqueeze(1).bmm(n.unsqueeze(2)).abs()
< dot_eps
)
.all()
.item()
):
msg = "Plane vertices are not coplanar"
raise ValueError(msg)
# Check all faces have non zero area
area1 = torch.cross(v1 - v0, v2 - v0, dim=-1).norm(dim=-1) / 2
area2 = torch.cross(v3 - v0, v2 - v0, dim=-1).norm(dim=-1) / 2
if (area1 < eps).any().item() or (area2 < eps).any().item():
if (area1 < area_eps).any().item() or (area2 < area_eps).any().item():
msg = "Planes have zero areas"
raise ValueError(msg)
# We can write: `ctr = v0 + a * e0 + b * e1 + c * n`, (1).
# We can write: `box_ctr = plane_ctr + a * e0 + b * e1 + c * n`, (1).
# With <e0, n> = 0 and <e1, n> = 0, where <.,.> refers to the dot product,
# since that e0 is orthogonal to n. Same for e1.
"""
@@ -768,16 +1049,17 @@ def box_planar_dir(box: torch.Tensor, eps: float = 1e-4) -> torch.Tensor:
B = torch.ones((numF, 2), dtype=torch.float32, device=device)
A[:, 0, 1] = (e0 * e1).sum(-1)
A[:, 1, 0] = (e0 * e1).sum(-1)
B[:, 0] = ((ctr - v0) * e0).sum(-1)
B[:, 1] = ((ctr - v1) * e1).sum(-1)
B[:, 0] = ((box_ctr - plane_ctr) * e0).sum(-1)
B[:, 1] = ((box_ctr - plane_ctr) * e1).sum(-1)
ab = torch.linalg.solve(A, B) # (numF, 2)
a, b = ab.unbind(1)
# solving for c
c = ((ctr - v0 - a.view(numF, 1) * e0 - b.view(numF, 1) * e1) * n).sum(-1)
c = ((box_ctr - plane_ctr - a.view(numF, 1) * e0 - b.view(numF, 1) * e1) * n).sum(-1)
"""
# Since we know that <e0, n> = 0 and <e1, n> = 0 (e0 and e1 are orthogonal to n),
# the above solution is equivalent to
c = ((ctr - v0) * n).sum(-1)
direc = F.normalize(box_ctr - plane_ctr, dim=-1) # (6, 3)
c = (direc * n).sum(-1)
# If c is negative, then we revert the direction of n such that n points "inside"
negc = c < 0.0
n[negc] *= -1.0
@@ -848,7 +1130,7 @@ def box_volume(box: torch.Tensor) -> torch.Tensor:
return vols
def coplanar_tri_faces(tri1: torch.Tensor, tri2: torch.Tensor, eps: float = EPS):
def coplanar_tri_faces(tri1: torch.Tensor, tri2: torch.Tensor, eps: float = DOT_EPS):
"""
Determines whether two triangle faces in 3D are coplanar
Args:
@@ -857,17 +1139,49 @@ def coplanar_tri_faces(tri1: torch.Tensor, tri2: torch.Tensor, eps: float = EPS)
Returns:
is_coplanar: bool
"""
v0, v1, v2 = tri1.unbind(0)
e0 = F.normalize(v1 - v0, dim=0)
e1 = F.normalize(v2 - v0, dim=0)
e2 = F.normalize(torch.cross(e0, e1), dim=0)
tri1_ctr, tri1_n = get_tri_center_normal(tri1)
tri2_ctr, tri2_n = get_tri_center_normal(tri2)
coplanar2 = torch.zeros((3,), dtype=torch.bool, device=tri1.device)
for i in range(3):
if (tri2[i] - v0).dot(e2).abs() < eps:
coplanar2[i] = 1
coplanar2 = coplanar2.all()
return coplanar2
check1 = tri1_n.dot(tri2_n).abs() > 1 - eps # checks if parallel
dist12 = torch.norm(tri1.unsqueeze(1) - tri2.unsqueeze(0), dim=-1)
dist12_argmax = dist12.argmax()
i1 = dist12_argmax // 3
i2 = dist12_argmax % 3
assert dist12[i1, i2] == dist12.max()
check2 = (
F.normalize(tri1[i1] - tri2[i2], dim=0).dot(tri1_n).abs() < eps
) or F.normalize(tri1[i1] - tri2[i2], dim=0).dot(tri2_n).abs() < eps
return check1 and check2
def coplanar_tri_plane(
tri: torch.Tensor, plane: torch.Tensor, n: torch.Tensor, eps: float = DOT_EPS
):
"""
Determines whether two triangle faces in 3D are coplanar
Args:
tri: tensor of shape (3, 3) of the vertices of the triangle
plane: tensor of shape (4, 3) of the vertices of the plane
n: tensor of shape (3,) of the unit "inside" direction on the plane
Returns:
is_coplanar: bool
"""
tri_ctr, tri_n = get_tri_center_normal(tri)
check1 = tri_n.dot(n).abs() > 1 - eps # checks if parallel
dist12 = torch.norm(tri.unsqueeze(1) - plane.unsqueeze(0), dim=-1)
dist12_argmax = dist12.argmax()
i1 = dist12_argmax // 4
i2 = dist12_argmax % 4
assert dist12[i1, i2] == dist12.max()
check2 = F.normalize(tri[i1] - plane[i2], dim=0).dot(n).abs() < eps
return check1 and check2
def is_inside(
@@ -875,7 +1189,6 @@ def is_inside(
n: torch.Tensor,
points: torch.Tensor,
return_proj: bool = True,
eps: float = EPS,
):
"""
Computes whether point is "inside" the plane.
@@ -900,12 +1213,13 @@ def is_inside(
p_proj: tensor of shape (P, 3) of the projected point on plane
"""
device = plane.device
v0, v1, v2, v3 = plane
e0 = F.normalize(v1 - v0, dim=0)
e1 = F.normalize(v2 - v0, dim=0)
if not torch.allclose(e0.dot(n), torch.zeros((1,), device=device), atol=1e-6):
v0, v1, v2, v3 = plane.unbind(0)
plane_ctr = plane.mean(0)
e0 = F.normalize(v0 - plane_ctr, dim=0)
e1 = F.normalize(v1 - plane_ctr, dim=0)
if not torch.allclose(e0.dot(n), torch.zeros((1,), device=device), atol=1e-2):
raise ValueError("Input n is not perpendicular to the plane")
if not torch.allclose(e1.dot(n), torch.zeros((1,), device=device), atol=1e-6):
if not torch.allclose(e1.dot(n), torch.zeros((1,), device=device), atol=1e-2):
raise ValueError("Input n is not perpendicular to the plane")
add_dim = False
@@ -914,7 +1228,7 @@ def is_inside(
add_dim = True
assert points.shape[1] == 3
# Every point p can be written as p = v0 + a e0 + b e1 + c n
# Every point p can be written as p = ctr + a e0 + b e1 + c n
# If return_proj is True, we need to solve for (a, b)
p_proj = None
@@ -924,16 +1238,17 @@ def is_inside(
[[1.0, e0.dot(e1)], [e0.dot(e1), 1.0]], dtype=torch.float32, device=device
)
B = torch.zeros((2, points.shape[0]), dtype=torch.float32, device=device)
B[0, :] = torch.sum((points - v0.view(1, 3)) * e0.view(1, 3), dim=-1)
B[1, :] = torch.sum((points - v0.view(1, 3)) * e1.view(1, 3), dim=-1)
B[0, :] = torch.sum((points - plane_ctr.view(1, 3)) * e0.view(1, 3), dim=-1)
B[1, :] = torch.sum((points - plane_ctr.view(1, 3)) * e1.view(1, 3), dim=-1)
ab = A.inverse() @ B # (2, P)
p_proj = v0.view(1, 3) + ab.transpose(0, 1) @ torch.stack((e0, e1), dim=0)
p_proj = plane_ctr.view(1, 3) + ab.transpose(0, 1) @ torch.stack(
(e0, e1), dim=0
)
# solving for c
# c = (point - v0 - a * e0 - b * e1).dot(n)
c = torch.sum((points - v0.view(1, 3)) * n.view(1, 3), dim=-1)
ins = c > -eps
# c = (point - ctr - a * e0 - b * e1).dot(n)
direc = torch.sum((points - plane_ctr.view(1, 3)) * n.view(1, 3), dim=-1)
ins = direc >= 0.0
if add_dim:
assert p_proj.shape[0] == 1
@@ -942,7 +1257,7 @@ def is_inside(
return ins, p_proj
def plane_edge_point_of_intersection(plane, n, p0, p1):
def plane_edge_point_of_intersection(plane, n, p0, p1, eps: float = DOT_EPS):
"""
Finds the point of intersection between a box plane and
a line segment connecting (p0, p1).
@@ -961,11 +1276,17 @@ def plane_edge_point_of_intersection(plane, n, p0, p1):
# The point of intersection can be parametrized
# p = p0 + a (p1 - p0) where a in [0, 1]
# We want to find a such that p is on plane
# <p - v0, n> = 0
v0, v1, v2, v3 = plane
a = -(p0 - v0).dot(n) / (p1 - p0).dot(n)
p = p0 + a * (p1 - p0)
return p, a
# <p - ctr, n> = 0
# if segment (p0, p1) is parallel to plane (it can only be on it)
direc = F.normalize(p1 - p0, dim=0)
if direc.dot(n).abs() < eps:
return (p1 + p0) / 2.0, 0.5
else:
ctr = plane.mean(0)
a = -(p0 - ctr).dot(n) / ((p1 - p0).dot(n))
p = p0 + a * (p1 - p0)
return p, a
"""
@@ -983,7 +1304,6 @@ def clip_tri_by_plane_oneout(
vout: torch.Tensor,
vin1: torch.Tensor,
vin2: torch.Tensor,
eps: float = EPS,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Case (a).
@@ -1004,10 +1324,10 @@ def clip_tri_by_plane_oneout(
device = plane.device
# point of intersection between plane and (vin1, vout)
pint1, a1 = plane_edge_point_of_intersection(plane, n, vin1, vout)
assert a1 >= -eps and a1 <= 1.0 + eps, a1
assert a1 >= -0.0001 and a1 <= 1.0001, a1
# point of intersection between plane and (vin2, vout)
pint2, a2 = plane_edge_point_of_intersection(plane, n, vin2, vout)
assert a2 >= -eps and a2 <= 1.0 + eps, a2
assert a2 >= -0.0001 and a2 <= 1.0001, a2
verts = torch.stack((vin1, pint1, pint2, vin2), dim=0) # 4x3
faces = torch.tensor(
@@ -1022,7 +1342,6 @@ def clip_tri_by_plane_twoout(
vout1: torch.Tensor,
vout2: torch.Tensor,
vin: torch.Tensor,
eps: float = EPS,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Case (b).
@@ -1042,10 +1361,10 @@ def clip_tri_by_plane_twoout(
device = plane.device
# point of intersection between plane and (vin, vout1)
pint1, a1 = plane_edge_point_of_intersection(plane, n, vin, vout1)
assert a1 >= -eps and a1 <= 1.0 + eps, a1
assert a1 >= -0.0001 and a1 <= 1.0001, a1
# point of intersection between plane and (vin, vout2)
pint2, a2 = plane_edge_point_of_intersection(plane, n, vin, vout2)
assert a2 >= -eps and a2 <= 1.0 + eps, a2
assert a2 >= -0.0001 and a2 <= 1.0001, a2
verts = torch.stack((vin, pint1, pint2), dim=0) # 3x3
faces = torch.tensor(
@@ -1071,6 +1390,9 @@ def clip_tri_by_plane(plane, n, tri_verts) -> Union[List, torch.Tensor]:
tri_verts: tensor of shape (K, 3, 3) of the vertex coordinates of the triangles formed
after clipping. All K triangles are now "inside" the plane.
"""
if coplanar_tri_plane(tri_verts, plane, n):
return tri_verts.view(1, 3, 3)
v0, v1, v2 = tri_verts.unbind(0)
isin0, _ = is_inside(plane, n, v0)
isin1, _ = is_inside(plane, n, v1)
@@ -1191,7 +1513,7 @@ def box3d_overlap_naive(box1: torch.Tensor, box2: torch.Tensor):
for i2 in range(tri_verts2.shape[0]):
if (
coplanar_tri_faces(tri_verts1[i1], tri_verts2[i2])
and tri_verts_area(tri_verts1[i1]) > 1e-4
and tri_verts_area(tri_verts1[i1]) > AREA_EPS
):
keep2[i2] = 0
keep2 = keep2.nonzero()[:, 0]