small fix for iou3d

Summary:
A small numerical fix for IoU for 3D boxes, fixes GH #992

* Adds a check for boxes with zero side areas (invalid boxes)
* Fixes numerical issue when two boxes have coplanar sides

Reviewed By: nikhilaravi

Differential Revision: D33195691

fbshipit-source-id: 8a34b4d1f1e5ec2edb6d54143930da44bdde0906
This commit is contained in:
Georgia Gkioxari
2021-12-17 16:12:51 -08:00
committed by Facebook GitHub Bot
parent 069c9fd759
commit ccfb72cc50
6 changed files with 202 additions and 4 deletions

View File

@@ -111,6 +111,11 @@ class TestIoU3D(TestCaseMixin, unittest.TestCase):
self.assertClose(
vol, torch.tensor([[1 - dd]], device=vol.device, dtype=vol.dtype)
)
# symmetry
vol, iou = overlap_fn(box2[None], box1[None])
self.assertClose(
vol, torch.tensor([[1 - dd]], device=vol.device, dtype=vol.dtype)
)
# 3rd test
dd = random.random()
@@ -119,6 +124,11 @@ class TestIoU3D(TestCaseMixin, unittest.TestCase):
self.assertClose(
vol, torch.tensor([[1 - dd]], device=vol.device, dtype=vol.dtype)
)
# symmetry
vol, _ = overlap_fn(box2[None], box1[None])
self.assertClose(
vol, torch.tensor([[1 - dd]], device=vol.device, dtype=vol.dtype)
)
# 4th test
ddx, ddy, ddz = random.random(), random.random(), random.random()
@@ -132,6 +142,16 @@ class TestIoU3D(TestCaseMixin, unittest.TestCase):
dtype=vol.dtype,
),
)
# symmetry
vol, _ = overlap_fn(box2[None], box1[None])
self.assertClose(
vol,
torch.tensor(
[[(1 - ddx) * (1 - ddy) * (1 - ddz)]],
device=vol.device,
dtype=vol.dtype,
),
)
# Also check IoU is 1 when computing overlap with the same shifted box
vol, iou = overlap_fn(box2[None], box2[None])
@@ -152,6 +172,16 @@ class TestIoU3D(TestCaseMixin, unittest.TestCase):
dtype=vol.dtype,
),
)
# symmetry
vol, _ = overlap_fn(box2r[None], box1r[None])
self.assertClose(
vol,
torch.tensor(
[[(1 - ddx) * (1 - ddy) * (1 - ddz)]],
device=vol.device,
dtype=vol.dtype,
),
)
# 6th test
ddx, ddy, ddz = random.random(), random.random(), random.random()
@@ -170,6 +200,17 @@ class TestIoU3D(TestCaseMixin, unittest.TestCase):
),
atol=1e-7,
)
# symmetry
vol, _ = overlap_fn(box2r[None], box1r[None])
self.assertClose(
vol,
torch.tensor(
[[(1 - ddx) * (1 - ddy) * (1 - ddz)]],
device=vol.device,
dtype=vol.dtype,
),
atol=1e-7,
)
# 7th test: hand coded example and test with meshlab output
@@ -214,6 +255,10 @@ class TestIoU3D(TestCaseMixin, unittest.TestCase):
vol, iou = overlap_fn(box1r[None], box2r[None])
self.assertClose(vol, torch.tensor([[vol_inters]], device=device), atol=1e-1)
self.assertClose(iou, torch.tensor([[iou_mesh]], device=device), atol=1e-1)
# symmetry
vol, iou = overlap_fn(box2r[None], box1r[None])
self.assertClose(vol, torch.tensor([[vol_inters]], device=device), atol=1e-1)
self.assertClose(iou, torch.tensor([[iou_mesh]], device=device), atol=1e-1)
# 8th test: compare with sampling
# create box1
@@ -232,7 +277,9 @@ class TestIoU3D(TestCaseMixin, unittest.TestCase):
iou_sampling = self._box3d_overlap_sampling_batched(
box1r[None], box2r[None], num_samples=10000
)
self.assertClose(iou, iou_sampling, atol=1e-2)
# symmetry
vol, iou = overlap_fn(box2r[None], box1r[None])
self.assertClose(iou, iou_sampling, atol=1e-2)
# 9th test: non overlapping boxes, iou = 0.0
@@ -240,6 +287,10 @@ class TestIoU3D(TestCaseMixin, unittest.TestCase):
vol, iou = overlap_fn(box1[None], box2[None])
self.assertClose(vol, torch.tensor([[0.0]], device=vol.device, dtype=vol.dtype))
self.assertClose(iou, torch.tensor([[0.0]], device=vol.device, dtype=vol.dtype))
# symmetry
vol, iou = overlap_fn(box2[None], box1[None])
self.assertClose(vol, torch.tensor([[0.0]], device=vol.device, dtype=vol.dtype))
self.assertClose(iou, torch.tensor([[0.0]], device=vol.device, dtype=vol.dtype))
# 10th test: Non coplanar verts in a plane
box10 = box1 + torch.rand((8, 3), dtype=torch.float32, device=device)
@@ -284,6 +335,56 @@ class TestIoU3D(TestCaseMixin, unittest.TestCase):
vols, ious = overlap_fn(box_skew_1[None], box_skew_2[None])
self.assertClose(vols, torch.tensor([[vol_inters]], device=device), atol=1e-1)
self.assertClose(ious, torch.tensor([[iou]], device=device), atol=1e-1)
# symmetry
vols, ious = overlap_fn(box_skew_2[None], box_skew_1[None])
self.assertClose(vols, torch.tensor([[vol_inters]], device=device), atol=1e-1)
self.assertClose(ious, torch.tensor([[iou]], device=device), atol=1e-1)
# 12th test: Zero area bounding box (from GH issue #992)
box12a = torch.tensor(
[
[-1.0000, -1.0000, -0.5000],
[1.0000, -1.0000, -0.5000],
[1.0000, 1.0000, -0.5000],
[-1.0000, 1.0000, -0.5000],
[-1.0000, -1.0000, 0.5000],
[1.0000, -1.0000, 0.5000],
[1.0000, 1.0000, 0.5000],
[-1.0000, 1.0000, 0.5000],
],
device=device,
dtype=torch.float32,
)
box12b = torch.tensor(
[
[0.0, 0.0, 0.0],
[0.0, 0.0, 0.0],
[0.0, 0.0, 0.0],
[0.0, 0.0, 0.0],
[0.0, 0.0, 0.0],
[0.0, 0.0, 0.0],
[0.0, 0.0, 0.0],
[0.0, 0.0, 0.0],
],
device=device,
dtype=torch.float32,
)
msg = "Planes have zero areas"
with self.assertRaisesRegex(ValueError, msg):
overlap_fn(box12a[None], box12b[None])
# symmetry
with self.assertRaisesRegex(ValueError, msg):
overlap_fn(box12b[None], box12a[None])
# 13th test: From GH issue #992
# Zero area coplanar face after intersection
ctrs = torch.tensor([[0.0, 0.0, 0.0], [-1.0, 1.0, 0.0]])
whl = torch.tensor([[2.0, 2.0, 2.0], [2.0, 2, 2]])
box13a = TestIoU3D.create_box(ctrs[0], whl[0])
box13b = TestIoU3D.create_box(ctrs[1], whl[1])
vol, iou = overlap_fn(box13a[None], box13b[None])
self.assertClose(vol, torch.tensor([[2.0]], device=vol.device, dtype=vol.dtype))
def _test_real_boxes(self, overlap_fn, device):
data_filename = "./real_boxes.pkl"
@@ -577,6 +678,13 @@ def box_planar_dir(box: torch.Tensor, eps=1e-4) -> torch.Tensor:
msg = "Plane vertices are not coplanar"
raise ValueError(msg)
# Check all faces have non zero area
area1 = torch.cross(v1 - v0, v2 - v0, dim=-1).norm(dim=-1) / 2
area2 = torch.cross(v3 - v0, v2 - v0, dim=-1).norm(dim=-1) / 2
if (area1 < eps).any().item() or (area2 < eps).any().item():
msg = "Planes have zero areas"
raise ValueError(msg)
# We can write: `ctr = v0 + a * e0 + b * e1 + c * n`, (1).
# With <e0, n> = 0 and <e1, n> = 0, where <.,.> refers to the dot product,
# since that e0 is orthogonal to n. Same for e1.
@@ -607,6 +715,27 @@ def box_planar_dir(box: torch.Tensor, eps=1e-4) -> torch.Tensor:
return n
def tri_verts_area(tri_verts: torch.Tensor) -> torch.Tensor:
"""
Computes the area of the triangle faces in tri_verts
Args:
tri_verts: tensor of shape (T, 3, 3)
Returns:
areas: the area of the triangles (T, 1)
"""
add_dim = False
if tri_verts.ndim == 2:
tri_verts = tri_verts.unsqueeze(0)
add_dim = True
v0, v1, v2 = tri_verts.unbind(1)
areas = torch.cross(v1 - v0, v2 - v0, dim=-1).norm(dim=-1) / 2.0
if add_dim:
areas = areas[0]
return areas
def box_volume(box: torch.Tensor) -> torch.Tensor:
"""
Computes the volume of each box in boxes.
@@ -988,7 +1117,10 @@ def box3d_overlap_naive(box1: torch.Tensor, box2: torch.Tensor):
keep2 = torch.ones((tri_verts2.shape[0],), device=device, dtype=torch.bool)
for i1 in range(tri_verts1.shape[0]):
for i2 in range(tri_verts2.shape[0]):
if coplanar_tri_faces(tri_verts1[i1], tri_verts2[i2]):
if (
coplanar_tri_faces(tri_verts1[i1], tri_verts2[i2])
and tri_verts_area(tri_verts1[i1]) > 1e-4
):
keep2[i2] = 0
keep2 = keep2.nonzero()[:, 0]
tri_verts2 = tri_verts2[keep2]