mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-20 22:30:35 +08:00
eps fix for iou3d
Summary: Fix EPS issue that causes numerical instabilities when boxes are very close Reviewed By: kjchalup Differential Revision: D38661465 fbshipit-source-id: d2b6753cba9dc2f0072ace5289c9aa815a1a29f6
This commit is contained in:
committed by
Facebook GitHub Bot
parent
06cbba2628
commit
1bfe6bf20a
@@ -21,7 +21,8 @@ from .common_testing import get_random_cuda_device, get_tests_dir, TestCaseMixin
|
||||
OBJECTRON_TO_PYTORCH3D_FACE_IDX = [0, 4, 6, 2, 1, 5, 7, 3]
|
||||
DATA_DIR = get_tests_dir() / "data"
|
||||
DEBUG = False
|
||||
EPS = 1e-5
|
||||
DOT_EPS = 1e-3
|
||||
AREA_EPS = 1e-4
|
||||
|
||||
UNIT_BOX = [
|
||||
[0, 0, 0],
|
||||
@@ -457,6 +458,207 @@ class TestIoU3D(TestCaseMixin, unittest.TestCase):
|
||||
self.assertClose(
|
||||
iou, torch.tensor([[0.91]], device=vol.device, dtype=vol.dtype), atol=1e-2
|
||||
)
|
||||
# symmetry
|
||||
vol, iou = overlap_fn(box15b[None], box15a[None])
|
||||
self.assertClose(
|
||||
iou, torch.tensor([[0.91]], device=vol.device, dtype=vol.dtype), atol=1e-2
|
||||
)
|
||||
|
||||
# 16th test: From GH issue 1287
|
||||
box16a = torch.tensor(
|
||||
[
|
||||
[-167.5847, -70.6167, -2.7927],
|
||||
[-166.7333, -72.4264, -2.7927],
|
||||
[-166.7333, -72.4264, -4.5927],
|
||||
[-167.5847, -70.6167, -4.5927],
|
||||
[-163.0605, -68.4880, -2.7927],
|
||||
[-162.2090, -70.2977, -2.7927],
|
||||
[-162.2090, -70.2977, -4.5927],
|
||||
[-163.0605, -68.4880, -4.5927],
|
||||
],
|
||||
device=device,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
|
||||
box16b = torch.tensor(
|
||||
[
|
||||
[-167.5847, -70.6167, -2.7927],
|
||||
[-166.7333, -72.4264, -2.7927],
|
||||
[-166.7333, -72.4264, -4.5927],
|
||||
[-167.5847, -70.6167, -4.5927],
|
||||
[-163.0605, -68.4880, -2.7927],
|
||||
[-162.2090, -70.2977, -2.7927],
|
||||
[-162.2090, -70.2977, -4.5927],
|
||||
[-163.0605, -68.4880, -4.5927],
|
||||
],
|
||||
device=device,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
vol, iou = overlap_fn(box16a[None], box16b[None])
|
||||
self.assertClose(
|
||||
iou, torch.tensor([[1.0]], device=vol.device, dtype=vol.dtype), atol=1e-2
|
||||
)
|
||||
# symmetry
|
||||
vol, iou = overlap_fn(box16b[None], box16a[None])
|
||||
self.assertClose(
|
||||
iou, torch.tensor([[1.0]], device=vol.device, dtype=vol.dtype), atol=1e-2
|
||||
)
|
||||
|
||||
# 17th test: From GH issue 1287
|
||||
box17a = torch.tensor(
|
||||
[
|
||||
[-33.94158, -4.51639, 0.96941],
|
||||
[-34.67156, -2.65437, 0.96941],
|
||||
[-34.67156, -2.65437, -0.95367],
|
||||
[-33.94158, -4.51639, -0.95367],
|
||||
[-38.75954, -6.40521, 0.96941],
|
||||
[-39.48952, -4.54319, 0.96941],
|
||||
[-39.48952, -4.54319, -0.95367],
|
||||
[-38.75954, -6.40521, -0.95367],
|
||||
],
|
||||
device=device,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
|
||||
box17b = torch.tensor(
|
||||
[
|
||||
[-33.94159, -4.51638, 0.96939],
|
||||
[-34.67158, -2.65437, 0.96939],
|
||||
[-34.67158, -2.65437, -0.95368],
|
||||
[-33.94159, -4.51638, -0.95368],
|
||||
[-38.75954, -6.40523, 0.96939],
|
||||
[-39.48953, -4.54321, 0.96939],
|
||||
[-39.48953, -4.54321, -0.95368],
|
||||
[-38.75954, -6.40523, -0.95368],
|
||||
],
|
||||
device=device,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
vol, iou = overlap_fn(box17a[None], box17b[None])
|
||||
self.assertClose(
|
||||
iou, torch.tensor([[1.0]], device=vol.device, dtype=vol.dtype), atol=1e-2
|
||||
)
|
||||
# symmetry
|
||||
vol, iou = overlap_fn(box17b[None], box17a[None])
|
||||
self.assertClose(
|
||||
iou, torch.tensor([[1.0]], device=vol.device, dtype=vol.dtype), atol=1e-2
|
||||
)
|
||||
|
||||
# 18th test: From GH issue 1287
|
||||
box18a = torch.tensor(
|
||||
[
|
||||
[-105.6248, -32.7026, -1.2279],
|
||||
[-106.4690, -30.8895, -1.2279],
|
||||
[-106.4690, -30.8895, -3.0279],
|
||||
[-105.6248, -32.7026, -3.0279],
|
||||
[-110.1575, -34.8132, -1.2279],
|
||||
[-111.0017, -33.0001, -1.2279],
|
||||
[-111.0017, -33.0001, -3.0279],
|
||||
[-110.1575, -34.8132, -3.0279],
|
||||
],
|
||||
device=device,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
box18b = torch.tensor(
|
||||
[
|
||||
[-105.5094, -32.9504, -1.0641],
|
||||
[-106.4272, -30.9793, -1.0641],
|
||||
[-106.4272, -30.9793, -3.1916],
|
||||
[-105.5094, -32.9504, -3.1916],
|
||||
[-110.0421, -35.0609, -1.0641],
|
||||
[-110.9599, -33.0899, -1.0641],
|
||||
[-110.9599, -33.0899, -3.1916],
|
||||
[-110.0421, -35.0609, -3.1916],
|
||||
],
|
||||
device=device,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
# from Meshlab
|
||||
vol_inters = 17.108501
|
||||
vol_box1 = 18.000067
|
||||
vol_box2 = 23.128527
|
||||
iou_mesh = vol_inters / (vol_box1 + vol_box2 - vol_inters)
|
||||
vol, iou = overlap_fn(box18a[None], box18b[None])
|
||||
self.assertClose(
|
||||
iou,
|
||||
torch.tensor([[iou_mesh]], device=vol.device, dtype=vol.dtype),
|
||||
atol=1e-2,
|
||||
)
|
||||
self.assertClose(
|
||||
vol,
|
||||
torch.tensor([[vol_inters]], device=vol.device, dtype=vol.dtype),
|
||||
atol=1e-2,
|
||||
)
|
||||
# symmetry
|
||||
vol, iou = overlap_fn(box18b[None], box18a[None])
|
||||
self.assertClose(
|
||||
iou,
|
||||
torch.tensor([[iou_mesh]], device=vol.device, dtype=vol.dtype),
|
||||
atol=1e-2,
|
||||
)
|
||||
self.assertClose(
|
||||
vol,
|
||||
torch.tensor([[vol_inters]], device=vol.device, dtype=vol.dtype),
|
||||
atol=1e-2,
|
||||
)
|
||||
|
||||
# 19th example: From GH issue 1287
|
||||
box19a = torch.tensor(
|
||||
[
|
||||
[-59.4785, -15.6003, 0.4398],
|
||||
[-60.2263, -13.6928, 0.4398],
|
||||
[-60.2263, -13.6928, -1.3909],
|
||||
[-59.4785, -15.6003, -1.3909],
|
||||
[-64.1743, -17.4412, 0.4398],
|
||||
[-64.9221, -15.5337, 0.4398],
|
||||
[-64.9221, -15.5337, -1.3909],
|
||||
[-64.1743, -17.4412, -1.3909],
|
||||
],
|
||||
device=device,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
box19b = torch.tensor(
|
||||
[
|
||||
[-59.4874, -15.5775, -0.1512],
|
||||
[-60.2174, -13.7155, -0.1512],
|
||||
[-60.2174, -13.7155, -1.9820],
|
||||
[-59.4874, -15.5775, -1.9820],
|
||||
[-64.1832, -17.4185, -0.1512],
|
||||
[-64.9132, -15.5564, -0.1512],
|
||||
[-64.9132, -15.5564, -1.9820],
|
||||
[-64.1832, -17.4185, -1.9820],
|
||||
],
|
||||
device=device,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
# from Meshlab
|
||||
vol_inters = 12.505723
|
||||
vol_box1 = 18.918238
|
||||
vol_box2 = 18.468531
|
||||
iou_mesh = vol_inters / (vol_box1 + vol_box2 - vol_inters)
|
||||
vol, iou = overlap_fn(box19a[None], box19b[None])
|
||||
self.assertClose(
|
||||
iou,
|
||||
torch.tensor([[iou_mesh]], device=vol.device, dtype=vol.dtype),
|
||||
atol=1e-2,
|
||||
)
|
||||
self.assertClose(
|
||||
vol,
|
||||
torch.tensor([[vol_inters]], device=vol.device, dtype=vol.dtype),
|
||||
atol=1e-2,
|
||||
)
|
||||
# symmetry
|
||||
vol, iou = overlap_fn(box19b[None], box19a[None])
|
||||
self.assertClose(
|
||||
iou,
|
||||
torch.tensor([[iou_mesh]], device=vol.device, dtype=vol.dtype),
|
||||
atol=1e-2,
|
||||
)
|
||||
self.assertClose(
|
||||
vol,
|
||||
torch.tensor([[vol_inters]], device=vol.device, dtype=vol.dtype),
|
||||
atol=1e-2,
|
||||
)
|
||||
|
||||
def _test_real_boxes(self, overlap_fn, device):
|
||||
data_filename = "./real_boxes.pkl"
|
||||
@@ -715,7 +917,86 @@ def get_plane_verts(box: torch.Tensor) -> torch.Tensor:
|
||||
return plane_verts
|
||||
|
||||
|
||||
def box_planar_dir(box: torch.Tensor, eps: float = 1e-4) -> torch.Tensor:
|
||||
def get_tri_center_normal(tris: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Returns the center and normal of triangles
|
||||
Args:
|
||||
tris: tensor of shape (T, 3, 3)
|
||||
Returns:
|
||||
center: tensor of shape (T, 3)
|
||||
normal: tensor of shape (T, 3)
|
||||
"""
|
||||
add_dim0 = False
|
||||
if tris.ndim == 2:
|
||||
tris = tris.unsqueeze(0)
|
||||
add_dim0 = True
|
||||
|
||||
ctr = tris.mean(1) # (T, 3)
|
||||
normals = torch.zeros_like(ctr)
|
||||
|
||||
v0, v1, v2 = tris.unbind(1) # 3 x (T, 3)
|
||||
|
||||
# unvectorized solution
|
||||
T = tris.shape[0]
|
||||
for t in range(T):
|
||||
ns = torch.zeros((3, 3), device=tris.device)
|
||||
ns[0] = torch.cross(v0[t] - ctr[t], v1[t] - ctr[t], dim=-1)
|
||||
ns[1] = torch.cross(v0[t] - ctr[t], v2[t] - ctr[t], dim=-1)
|
||||
ns[2] = torch.cross(v1[t] - ctr[t], v2[t] - ctr[t], dim=-1)
|
||||
|
||||
i = torch.norm(ns, dim=-1).argmax()
|
||||
normals[t] = ns[i]
|
||||
|
||||
if add_dim0:
|
||||
ctr = ctr[0]
|
||||
normals = normals[0]
|
||||
normals = F.normalize(normals, dim=-1)
|
||||
return ctr, normals
|
||||
|
||||
|
||||
def get_plane_center_normal(planes: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Returns the center and normal of planes
|
||||
Args:
|
||||
planes: tensor of shape (P, 4, 3)
|
||||
Returns:
|
||||
center: tensor of shape (P, 3)
|
||||
normal: tensor of shape (P, 3)
|
||||
"""
|
||||
add_dim0 = False
|
||||
if planes.ndim == 2:
|
||||
planes = planes.unsqueeze(0)
|
||||
add_dim0 = True
|
||||
|
||||
ctr = planes.mean(1) # (P, 3)
|
||||
normals = torch.zeros_like(ctr)
|
||||
|
||||
v0, v1, v2, v3 = planes.unbind(1) # 4 x (P, 3)
|
||||
|
||||
# unvectorized solution
|
||||
P = planes.shape[0]
|
||||
for t in range(P):
|
||||
ns = torch.zeros((6, 3), device=planes.device)
|
||||
ns[0] = torch.cross(v0[t] - ctr[t], v1[t] - ctr[t], dim=-1)
|
||||
ns[1] = torch.cross(v0[t] - ctr[t], v2[t] - ctr[t], dim=-1)
|
||||
ns[2] = torch.cross(v0[t] - ctr[t], v3[t] - ctr[t], dim=-1)
|
||||
ns[3] = torch.cross(v1[t] - ctr[t], v2[t] - ctr[t], dim=-1)
|
||||
ns[4] = torch.cross(v1[t] - ctr[t], v3[t] - ctr[t], dim=-1)
|
||||
ns[5] = torch.cross(v2[t] - ctr[t], v3[t] - ctr[t], dim=-1)
|
||||
|
||||
i = torch.norm(ns, dim=-1).argmax()
|
||||
normals[t] = ns[i]
|
||||
|
||||
if add_dim0:
|
||||
ctr = ctr[0]
|
||||
normals = normals[0]
|
||||
normals = F.normalize(normals, dim=-1)
|
||||
return ctr, normals
|
||||
|
||||
|
||||
def box_planar_dir(
|
||||
box: torch.Tensor, dot_eps: float = DOT_EPS, area_eps: float = AREA_EPS
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Finds the unit vector n which is perpendicular to each plane in the box
|
||||
and points towards the inside of the box.
|
||||
@@ -731,33 +1012,33 @@ def box_planar_dir(box: torch.Tensor, eps: float = 1e-4) -> torch.Tensor:
|
||||
assert box.shape[0] == 8 and box.shape[1] == 3
|
||||
|
||||
# center point of each box
|
||||
ctr = box.mean(0).view(1, 3)
|
||||
box_ctr = box.mean(0).view(1, 3)
|
||||
|
||||
verts = get_plane_verts(box) # (6, 4, 3)
|
||||
|
||||
v0, v1, v2, v3 = verts.unbind(1) # each v of shape (6, 3)
|
||||
|
||||
# We project the ctr on the plane defined by (v0, v1, v2, v3)
|
||||
# We define e0 to be the edge connecting (v1, v0)
|
||||
# We define e1 to be the edge connecting (v2, v0)
|
||||
# And n is the cross product of e0, e1, either pointing "inside" or not.
|
||||
e0 = F.normalize(v1 - v0, dim=-1)
|
||||
e1 = F.normalize(v2 - v0, dim=-1)
|
||||
n = F.normalize(torch.cross(e0, e1, dim=-1), dim=-1)
|
||||
# box planes
|
||||
plane_verts = get_plane_verts(box) # (6, 4, 3)
|
||||
v0, v1, v2, v3 = plane_verts.unbind(1)
|
||||
plane_ctr, n = get_plane_center_normal(plane_verts)
|
||||
|
||||
# Check all verts are coplanar
|
||||
if not ((v3 - v0).unsqueeze(1).bmm(n.unsqueeze(2)).abs() < eps).all().item():
|
||||
if (
|
||||
not (
|
||||
F.normalize(v3 - v0, dim=-1).unsqueeze(1).bmm(n.unsqueeze(2)).abs()
|
||||
< dot_eps
|
||||
)
|
||||
.all()
|
||||
.item()
|
||||
):
|
||||
msg = "Plane vertices are not coplanar"
|
||||
raise ValueError(msg)
|
||||
|
||||
# Check all faces have non zero area
|
||||
area1 = torch.cross(v1 - v0, v2 - v0, dim=-1).norm(dim=-1) / 2
|
||||
area2 = torch.cross(v3 - v0, v2 - v0, dim=-1).norm(dim=-1) / 2
|
||||
if (area1 < eps).any().item() or (area2 < eps).any().item():
|
||||
if (area1 < area_eps).any().item() or (area2 < area_eps).any().item():
|
||||
msg = "Planes have zero areas"
|
||||
raise ValueError(msg)
|
||||
|
||||
# We can write: `ctr = v0 + a * e0 + b * e1 + c * n`, (1).
|
||||
# We can write: `box_ctr = plane_ctr + a * e0 + b * e1 + c * n`, (1).
|
||||
# With <e0, n> = 0 and <e1, n> = 0, where <.,.> refers to the dot product,
|
||||
# since that e0 is orthogonal to n. Same for e1.
|
||||
"""
|
||||
@@ -768,16 +1049,17 @@ def box_planar_dir(box: torch.Tensor, eps: float = 1e-4) -> torch.Tensor:
|
||||
B = torch.ones((numF, 2), dtype=torch.float32, device=device)
|
||||
A[:, 0, 1] = (e0 * e1).sum(-1)
|
||||
A[:, 1, 0] = (e0 * e1).sum(-1)
|
||||
B[:, 0] = ((ctr - v0) * e0).sum(-1)
|
||||
B[:, 1] = ((ctr - v1) * e1).sum(-1)
|
||||
B[:, 0] = ((box_ctr - plane_ctr) * e0).sum(-1)
|
||||
B[:, 1] = ((box_ctr - plane_ctr) * e1).sum(-1)
|
||||
ab = torch.linalg.solve(A, B) # (numF, 2)
|
||||
a, b = ab.unbind(1)
|
||||
# solving for c
|
||||
c = ((ctr - v0 - a.view(numF, 1) * e0 - b.view(numF, 1) * e1) * n).sum(-1)
|
||||
c = ((box_ctr - plane_ctr - a.view(numF, 1) * e0 - b.view(numF, 1) * e1) * n).sum(-1)
|
||||
"""
|
||||
# Since we know that <e0, n> = 0 and <e1, n> = 0 (e0 and e1 are orthogonal to n),
|
||||
# the above solution is equivalent to
|
||||
c = ((ctr - v0) * n).sum(-1)
|
||||
direc = F.normalize(box_ctr - plane_ctr, dim=-1) # (6, 3)
|
||||
c = (direc * n).sum(-1)
|
||||
# If c is negative, then we revert the direction of n such that n points "inside"
|
||||
negc = c < 0.0
|
||||
n[negc] *= -1.0
|
||||
@@ -848,7 +1130,7 @@ def box_volume(box: torch.Tensor) -> torch.Tensor:
|
||||
return vols
|
||||
|
||||
|
||||
def coplanar_tri_faces(tri1: torch.Tensor, tri2: torch.Tensor, eps: float = EPS):
|
||||
def coplanar_tri_faces(tri1: torch.Tensor, tri2: torch.Tensor, eps: float = DOT_EPS):
|
||||
"""
|
||||
Determines whether two triangle faces in 3D are coplanar
|
||||
Args:
|
||||
@@ -857,17 +1139,49 @@ def coplanar_tri_faces(tri1: torch.Tensor, tri2: torch.Tensor, eps: float = EPS)
|
||||
Returns:
|
||||
is_coplanar: bool
|
||||
"""
|
||||
v0, v1, v2 = tri1.unbind(0)
|
||||
e0 = F.normalize(v1 - v0, dim=0)
|
||||
e1 = F.normalize(v2 - v0, dim=0)
|
||||
e2 = F.normalize(torch.cross(e0, e1), dim=0)
|
||||
tri1_ctr, tri1_n = get_tri_center_normal(tri1)
|
||||
tri2_ctr, tri2_n = get_tri_center_normal(tri2)
|
||||
|
||||
coplanar2 = torch.zeros((3,), dtype=torch.bool, device=tri1.device)
|
||||
for i in range(3):
|
||||
if (tri2[i] - v0).dot(e2).abs() < eps:
|
||||
coplanar2[i] = 1
|
||||
coplanar2 = coplanar2.all()
|
||||
return coplanar2
|
||||
check1 = tri1_n.dot(tri2_n).abs() > 1 - eps # checks if parallel
|
||||
|
||||
dist12 = torch.norm(tri1.unsqueeze(1) - tri2.unsqueeze(0), dim=-1)
|
||||
dist12_argmax = dist12.argmax()
|
||||
i1 = dist12_argmax // 3
|
||||
i2 = dist12_argmax % 3
|
||||
assert dist12[i1, i2] == dist12.max()
|
||||
|
||||
check2 = (
|
||||
F.normalize(tri1[i1] - tri2[i2], dim=0).dot(tri1_n).abs() < eps
|
||||
) or F.normalize(tri1[i1] - tri2[i2], dim=0).dot(tri2_n).abs() < eps
|
||||
|
||||
return check1 and check2
|
||||
|
||||
|
||||
def coplanar_tri_plane(
|
||||
tri: torch.Tensor, plane: torch.Tensor, n: torch.Tensor, eps: float = DOT_EPS
|
||||
):
|
||||
"""
|
||||
Determines whether two triangle faces in 3D are coplanar
|
||||
Args:
|
||||
tri: tensor of shape (3, 3) of the vertices of the triangle
|
||||
plane: tensor of shape (4, 3) of the vertices of the plane
|
||||
n: tensor of shape (3,) of the unit "inside" direction on the plane
|
||||
Returns:
|
||||
is_coplanar: bool
|
||||
"""
|
||||
tri_ctr, tri_n = get_tri_center_normal(tri)
|
||||
|
||||
check1 = tri_n.dot(n).abs() > 1 - eps # checks if parallel
|
||||
|
||||
dist12 = torch.norm(tri.unsqueeze(1) - plane.unsqueeze(0), dim=-1)
|
||||
dist12_argmax = dist12.argmax()
|
||||
i1 = dist12_argmax // 4
|
||||
i2 = dist12_argmax % 4
|
||||
assert dist12[i1, i2] == dist12.max()
|
||||
|
||||
check2 = F.normalize(tri[i1] - plane[i2], dim=0).dot(n).abs() < eps
|
||||
|
||||
return check1 and check2
|
||||
|
||||
|
||||
def is_inside(
|
||||
@@ -875,7 +1189,6 @@ def is_inside(
|
||||
n: torch.Tensor,
|
||||
points: torch.Tensor,
|
||||
return_proj: bool = True,
|
||||
eps: float = EPS,
|
||||
):
|
||||
"""
|
||||
Computes whether point is "inside" the plane.
|
||||
@@ -900,12 +1213,13 @@ def is_inside(
|
||||
p_proj: tensor of shape (P, 3) of the projected point on plane
|
||||
"""
|
||||
device = plane.device
|
||||
v0, v1, v2, v3 = plane
|
||||
e0 = F.normalize(v1 - v0, dim=0)
|
||||
e1 = F.normalize(v2 - v0, dim=0)
|
||||
if not torch.allclose(e0.dot(n), torch.zeros((1,), device=device), atol=1e-6):
|
||||
v0, v1, v2, v3 = plane.unbind(0)
|
||||
plane_ctr = plane.mean(0)
|
||||
e0 = F.normalize(v0 - plane_ctr, dim=0)
|
||||
e1 = F.normalize(v1 - plane_ctr, dim=0)
|
||||
if not torch.allclose(e0.dot(n), torch.zeros((1,), device=device), atol=1e-2):
|
||||
raise ValueError("Input n is not perpendicular to the plane")
|
||||
if not torch.allclose(e1.dot(n), torch.zeros((1,), device=device), atol=1e-6):
|
||||
if not torch.allclose(e1.dot(n), torch.zeros((1,), device=device), atol=1e-2):
|
||||
raise ValueError("Input n is not perpendicular to the plane")
|
||||
|
||||
add_dim = False
|
||||
@@ -914,7 +1228,7 @@ def is_inside(
|
||||
add_dim = True
|
||||
|
||||
assert points.shape[1] == 3
|
||||
# Every point p can be written as p = v0 + a e0 + b e1 + c n
|
||||
# Every point p can be written as p = ctr + a e0 + b e1 + c n
|
||||
|
||||
# If return_proj is True, we need to solve for (a, b)
|
||||
p_proj = None
|
||||
@@ -924,16 +1238,17 @@ def is_inside(
|
||||
[[1.0, e0.dot(e1)], [e0.dot(e1), 1.0]], dtype=torch.float32, device=device
|
||||
)
|
||||
B = torch.zeros((2, points.shape[0]), dtype=torch.float32, device=device)
|
||||
B[0, :] = torch.sum((points - v0.view(1, 3)) * e0.view(1, 3), dim=-1)
|
||||
B[1, :] = torch.sum((points - v0.view(1, 3)) * e1.view(1, 3), dim=-1)
|
||||
|
||||
B[0, :] = torch.sum((points - plane_ctr.view(1, 3)) * e0.view(1, 3), dim=-1)
|
||||
B[1, :] = torch.sum((points - plane_ctr.view(1, 3)) * e1.view(1, 3), dim=-1)
|
||||
ab = A.inverse() @ B # (2, P)
|
||||
p_proj = v0.view(1, 3) + ab.transpose(0, 1) @ torch.stack((e0, e1), dim=0)
|
||||
p_proj = plane_ctr.view(1, 3) + ab.transpose(0, 1) @ torch.stack(
|
||||
(e0, e1), dim=0
|
||||
)
|
||||
|
||||
# solving for c
|
||||
# c = (point - v0 - a * e0 - b * e1).dot(n)
|
||||
c = torch.sum((points - v0.view(1, 3)) * n.view(1, 3), dim=-1)
|
||||
ins = c > -eps
|
||||
# c = (point - ctr - a * e0 - b * e1).dot(n)
|
||||
direc = torch.sum((points - plane_ctr.view(1, 3)) * n.view(1, 3), dim=-1)
|
||||
ins = direc >= 0.0
|
||||
|
||||
if add_dim:
|
||||
assert p_proj.shape[0] == 1
|
||||
@@ -942,7 +1257,7 @@ def is_inside(
|
||||
return ins, p_proj
|
||||
|
||||
|
||||
def plane_edge_point_of_intersection(plane, n, p0, p1):
|
||||
def plane_edge_point_of_intersection(plane, n, p0, p1, eps: float = DOT_EPS):
|
||||
"""
|
||||
Finds the point of intersection between a box plane and
|
||||
a line segment connecting (p0, p1).
|
||||
@@ -961,11 +1276,17 @@ def plane_edge_point_of_intersection(plane, n, p0, p1):
|
||||
# The point of intersection can be parametrized
|
||||
# p = p0 + a (p1 - p0) where a in [0, 1]
|
||||
# We want to find a such that p is on plane
|
||||
# <p - v0, n> = 0
|
||||
v0, v1, v2, v3 = plane
|
||||
a = -(p0 - v0).dot(n) / (p1 - p0).dot(n)
|
||||
p = p0 + a * (p1 - p0)
|
||||
return p, a
|
||||
# <p - ctr, n> = 0
|
||||
|
||||
# if segment (p0, p1) is parallel to plane (it can only be on it)
|
||||
direc = F.normalize(p1 - p0, dim=0)
|
||||
if direc.dot(n).abs() < eps:
|
||||
return (p1 + p0) / 2.0, 0.5
|
||||
else:
|
||||
ctr = plane.mean(0)
|
||||
a = -(p0 - ctr).dot(n) / ((p1 - p0).dot(n))
|
||||
p = p0 + a * (p1 - p0)
|
||||
return p, a
|
||||
|
||||
|
||||
"""
|
||||
@@ -983,7 +1304,6 @@ def clip_tri_by_plane_oneout(
|
||||
vout: torch.Tensor,
|
||||
vin1: torch.Tensor,
|
||||
vin2: torch.Tensor,
|
||||
eps: float = EPS,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Case (a).
|
||||
@@ -1004,10 +1324,10 @@ def clip_tri_by_plane_oneout(
|
||||
device = plane.device
|
||||
# point of intersection between plane and (vin1, vout)
|
||||
pint1, a1 = plane_edge_point_of_intersection(plane, n, vin1, vout)
|
||||
assert a1 >= -eps and a1 <= 1.0 + eps, a1
|
||||
assert a1 >= -0.0001 and a1 <= 1.0001, a1
|
||||
# point of intersection between plane and (vin2, vout)
|
||||
pint2, a2 = plane_edge_point_of_intersection(plane, n, vin2, vout)
|
||||
assert a2 >= -eps and a2 <= 1.0 + eps, a2
|
||||
assert a2 >= -0.0001 and a2 <= 1.0001, a2
|
||||
|
||||
verts = torch.stack((vin1, pint1, pint2, vin2), dim=0) # 4x3
|
||||
faces = torch.tensor(
|
||||
@@ -1022,7 +1342,6 @@ def clip_tri_by_plane_twoout(
|
||||
vout1: torch.Tensor,
|
||||
vout2: torch.Tensor,
|
||||
vin: torch.Tensor,
|
||||
eps: float = EPS,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Case (b).
|
||||
@@ -1042,10 +1361,10 @@ def clip_tri_by_plane_twoout(
|
||||
device = plane.device
|
||||
# point of intersection between plane and (vin, vout1)
|
||||
pint1, a1 = plane_edge_point_of_intersection(plane, n, vin, vout1)
|
||||
assert a1 >= -eps and a1 <= 1.0 + eps, a1
|
||||
assert a1 >= -0.0001 and a1 <= 1.0001, a1
|
||||
# point of intersection between plane and (vin, vout2)
|
||||
pint2, a2 = plane_edge_point_of_intersection(plane, n, vin, vout2)
|
||||
assert a2 >= -eps and a2 <= 1.0 + eps, a2
|
||||
assert a2 >= -0.0001 and a2 <= 1.0001, a2
|
||||
|
||||
verts = torch.stack((vin, pint1, pint2), dim=0) # 3x3
|
||||
faces = torch.tensor(
|
||||
@@ -1071,6 +1390,9 @@ def clip_tri_by_plane(plane, n, tri_verts) -> Union[List, torch.Tensor]:
|
||||
tri_verts: tensor of shape (K, 3, 3) of the vertex coordinates of the triangles formed
|
||||
after clipping. All K triangles are now "inside" the plane.
|
||||
"""
|
||||
if coplanar_tri_plane(tri_verts, plane, n):
|
||||
return tri_verts.view(1, 3, 3)
|
||||
|
||||
v0, v1, v2 = tri_verts.unbind(0)
|
||||
isin0, _ = is_inside(plane, n, v0)
|
||||
isin1, _ = is_inside(plane, n, v1)
|
||||
@@ -1191,7 +1513,7 @@ def box3d_overlap_naive(box1: torch.Tensor, box2: torch.Tensor):
|
||||
for i2 in range(tri_verts2.shape[0]):
|
||||
if (
|
||||
coplanar_tri_faces(tri_verts1[i1], tri_verts2[i2])
|
||||
and tri_verts_area(tri_verts1[i1]) > 1e-4
|
||||
and tri_verts_area(tri_verts1[i1]) > AREA_EPS
|
||||
):
|
||||
keep2[i2] = 0
|
||||
keep2 = keep2.nonzero()[:, 0]
|
||||
|
||||
Reference in New Issue
Block a user