mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-19 05:40:34 +08:00
C++ IoU for 3D Boxes
Summary: C++ Implementation of algorithm to compute 3D bounding boxes for batches of bboxes of shape (N, 8, 3) and (M, 8, 3). Reviewed By: gkioxari Differential Revision: D30905190 fbshipit-source-id: 02e2cf025cd4fa3ff706ce5cf9b82c0fb5443f96
This commit is contained in:
committed by
Facebook GitHub Bot
parent
2293f1fed0
commit
53266ec9ff
37
tests/bm_iou_box3d.py
Normal file
37
tests/bm_iou_box3d.py
Normal file
@@ -0,0 +1,37 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from itertools import product
|
||||
|
||||
from fvcore.common.benchmark import benchmark
|
||||
from test_iou_box3d import TestIoU3D
|
||||
|
||||
|
||||
def bm_iou_box3d() -> None:
|
||||
N = [1, 4, 8, 16]
|
||||
num_samples = [2000, 5000, 10000, 20000]
|
||||
|
||||
kwargs_list = []
|
||||
test_cases = product(N, N)
|
||||
for case in test_cases:
|
||||
n, m = case
|
||||
kwargs_list.append({"N": n, "M": m, "device": "cuda:0"})
|
||||
|
||||
benchmark(TestIoU3D.iou_naive, "3D_IOU_NAIVE", kwargs_list, warmup_iters=1)
|
||||
|
||||
[k.update({"device": "cpu"}) for k in kwargs_list]
|
||||
benchmark(TestIoU3D.iou, "3D_IOU", kwargs_list, warmup_iters=1)
|
||||
|
||||
kwargs_list = []
|
||||
test_cases = product([1, 4], [1, 4], num_samples)
|
||||
for case in test_cases:
|
||||
n, m, s = case
|
||||
kwargs_list.append({"N": n, "M": m, "num_samples": s})
|
||||
benchmark(TestIoU3D.iou_sampling, "3D_IOU_SAMPLING", kwargs_list, warmup_iters=1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
bm_iou_box3d()
|
||||
@@ -4,7 +4,6 @@
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
import random
|
||||
import unittest
|
||||
from typing import List, Tuple, Union
|
||||
@@ -13,6 +12,8 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
from common_testing import TestCaseMixin
|
||||
from pytorch3d.io import save_obj
|
||||
|
||||
from pytorch3d.ops.iou_box3d import box3d_overlap
|
||||
from pytorch3d.transforms.rotation_conversions import random_rotation
|
||||
|
||||
|
||||
@@ -21,7 +22,8 @@ class TestIoU3D(TestCaseMixin, unittest.TestCase):
|
||||
super().setUp()
|
||||
torch.manual_seed(1)
|
||||
|
||||
def create_box(self, xyz, whl):
|
||||
@staticmethod
|
||||
def create_box(xyz, whl):
|
||||
x, y, z = xyz
|
||||
w, h, le = whl
|
||||
|
||||
@@ -41,8 +43,39 @@ class TestIoU3D(TestCaseMixin, unittest.TestCase):
|
||||
)
|
||||
return verts
|
||||
|
||||
def test_iou(self):
|
||||
device = torch.device("cuda:0")
|
||||
@staticmethod
|
||||
def _box3d_overlap_naive_batched(boxes1, boxes2):
|
||||
"""
|
||||
Wrapper around box3d_overlap_naive to support
|
||||
batched input
|
||||
"""
|
||||
N = boxes1.shape[0]
|
||||
M = boxes2.shape[0]
|
||||
vols = torch.zeros((N, M), dtype=torch.float32, device=boxes1.device)
|
||||
ious = torch.zeros((N, M), dtype=torch.float32, device=boxes1.device)
|
||||
for n in range(N):
|
||||
for m in range(M):
|
||||
vol, iou = box3d_overlap_naive(boxes1[n], boxes2[m])
|
||||
vols[n, m] = vol
|
||||
ious[n, m] = iou
|
||||
return vols, ious
|
||||
|
||||
@staticmethod
|
||||
def _box3d_overlap_sampling_batched(boxes1, boxes2, num_samples: int):
|
||||
"""
|
||||
Wrapper around box3d_overlap_sampling to support
|
||||
batched input
|
||||
"""
|
||||
N = boxes1.shape[0]
|
||||
M = boxes2.shape[0]
|
||||
ious = torch.zeros((N, M), dtype=torch.float32, device=boxes1.device)
|
||||
for n in range(N):
|
||||
for m in range(M):
|
||||
iou = box3d_overlap_sampling(boxes1[n], boxes2[m])
|
||||
ious[n, m] = iou
|
||||
return ious
|
||||
|
||||
def _test_iou(self, overlap_fn, device):
|
||||
|
||||
box1 = torch.tensor(
|
||||
[
|
||||
@@ -60,30 +93,36 @@ class TestIoU3D(TestCaseMixin, unittest.TestCase):
|
||||
)
|
||||
|
||||
# 1st test: same box, iou = 1.0
|
||||
vol, iou = box3d_overlap(box1, box1)
|
||||
self.assertClose(vol, torch.tensor(1.0, device=vol.device, dtype=vol.dtype))
|
||||
self.assertClose(iou, torch.tensor(1.0, device=vol.device, dtype=vol.dtype))
|
||||
vol, iou = overlap_fn(box1[None], box1[None])
|
||||
self.assertClose(vol, torch.tensor([[1.0]], device=vol.device, dtype=vol.dtype))
|
||||
self.assertClose(iou, torch.tensor([[1.0]], device=vol.device, dtype=vol.dtype))
|
||||
|
||||
# 2nd test
|
||||
dd = random.random()
|
||||
box2 = box1 + torch.tensor([[0.0, dd, 0.0]], device=device)
|
||||
vol, iou = box3d_overlap(box1, box2)
|
||||
self.assertClose(vol, torch.tensor(1 - dd, device=vol.device, dtype=vol.dtype))
|
||||
vol, iou = overlap_fn(box1[None], box2[None])
|
||||
self.assertClose(
|
||||
vol, torch.tensor([[1 - dd]], device=vol.device, dtype=vol.dtype)
|
||||
)
|
||||
|
||||
# 3rd test
|
||||
dd = random.random()
|
||||
box2 = box1 + torch.tensor([[dd, 0.0, 0.0]], device=device)
|
||||
vol, _ = box3d_overlap(box1, box2)
|
||||
self.assertClose(vol, torch.tensor(1 - dd, device=vol.device, dtype=vol.dtype))
|
||||
vol, _ = overlap_fn(box1[None], box2[None])
|
||||
self.assertClose(
|
||||
vol, torch.tensor([[1 - dd]], device=vol.device, dtype=vol.dtype)
|
||||
)
|
||||
|
||||
# 4th test
|
||||
ddx, ddy, ddz = random.random(), random.random(), random.random()
|
||||
box2 = box1 + torch.tensor([[ddx, ddy, ddz]], device=device)
|
||||
vol, _ = box3d_overlap(box1, box2)
|
||||
vol, _ = overlap_fn(box1[None], box2[None])
|
||||
self.assertClose(
|
||||
vol,
|
||||
torch.tensor(
|
||||
(1 - ddx) * (1 - ddy) * (1 - ddz), device=vol.device, dtype=vol.dtype
|
||||
[[(1 - ddx) * (1 - ddy) * (1 - ddz)]],
|
||||
device=vol.device,
|
||||
dtype=vol.dtype,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -93,11 +132,13 @@ class TestIoU3D(TestCaseMixin, unittest.TestCase):
|
||||
RR = random_rotation(dtype=torch.float32, device=device)
|
||||
box1r = box1 @ RR.transpose(0, 1)
|
||||
box2r = box2 @ RR.transpose(0, 1)
|
||||
vol, _ = box3d_overlap(box1r, box2r)
|
||||
vol, _ = overlap_fn(box1r[None], box2r[None])
|
||||
self.assertClose(
|
||||
vol,
|
||||
torch.tensor(
|
||||
(1 - ddx) * (1 - ddy) * (1 - ddz), device=vol.device, dtype=vol.dtype
|
||||
[[(1 - ddx) * (1 - ddy) * (1 - ddz)]],
|
||||
device=vol.device,
|
||||
dtype=vol.dtype,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -108,11 +149,13 @@ class TestIoU3D(TestCaseMixin, unittest.TestCase):
|
||||
TT = torch.rand((1, 3), dtype=torch.float32, device=device)
|
||||
box1r = box1 @ RR.transpose(0, 1) + TT
|
||||
box2r = box2 @ RR.transpose(0, 1) + TT
|
||||
vol, _ = box3d_overlap(box1r, box2r)
|
||||
vol, _ = overlap_fn(box1r[None], box2r[None])
|
||||
self.assertClose(
|
||||
vol,
|
||||
torch.tensor(
|
||||
(1 - ddx) * (1 - ddy) * (1 - ddz), device=vol.device, dtype=vol.dtype
|
||||
[[(1 - ddx) * (1 - ddy) * (1 - ddz)]],
|
||||
device=vol.device,
|
||||
dtype=vol.dtype,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -135,7 +178,7 @@ class TestIoU3D(TestCaseMixin, unittest.TestCase):
|
||||
[-2.8789, 6.0142, 0.7549],
|
||||
[-4.3586, 3.5345, -1.1831],
|
||||
],
|
||||
device="cuda:0",
|
||||
device=device,
|
||||
)
|
||||
box2r = torch.tensor(
|
||||
[
|
||||
@@ -148,7 +191,7 @@ class TestIoU3D(TestCaseMixin, unittest.TestCase):
|
||||
[0.4328, -5.3761, -3.5436],
|
||||
[-2.3633, -5.6305, -1.2893],
|
||||
],
|
||||
device="cuda:0",
|
||||
device=device,
|
||||
)
|
||||
# from Meshlab:
|
||||
vol_inters = 33.558529
|
||||
@@ -156,9 +199,9 @@ class TestIoU3D(TestCaseMixin, unittest.TestCase):
|
||||
vol_box2 = 156.386719
|
||||
iou_mesh = vol_inters / (vol_box1 + vol_box2 - vol_inters)
|
||||
|
||||
vol, iou = box3d_overlap(box1r, box2r)
|
||||
self.assertClose(vol, torch.tensor(vol_inters, device=device), atol=1e-1)
|
||||
self.assertClose(iou, torch.tensor(iou_mesh, device=device), atol=1e-1)
|
||||
vol, iou = overlap_fn(box1r[None], box2r[None])
|
||||
self.assertClose(vol, torch.tensor([[vol_inters]], device=device), atol=1e-1)
|
||||
self.assertClose(iou, torch.tensor([[iou_mesh]], device=device), atol=1e-1)
|
||||
|
||||
# 8th test: compare with sampling
|
||||
# create box1
|
||||
@@ -173,16 +216,47 @@ class TestIoU3D(TestCaseMixin, unittest.TestCase):
|
||||
TT2 = torch.rand((1, 3), dtype=torch.float32, device=device)
|
||||
box1r = box1 @ RR1.transpose(0, 1) + TT1
|
||||
box2r = box2 @ RR2.transpose(0, 1) + TT2
|
||||
vol, iou = box3d_overlap(box1r, box2r)
|
||||
iou_sampling = box3d_overlap_sampling(box1r, box2r, num_samples=10000)
|
||||
vol, iou = overlap_fn(box1r[None], box2r[None])
|
||||
iou_sampling = self._box3d_overlap_sampling_batched(
|
||||
box1r[None], box2r[None], num_samples=10000
|
||||
)
|
||||
|
||||
self.assertClose(iou, iou_sampling, atol=1e-2)
|
||||
|
||||
# 9th test: non overlapping boxes, iou = 0.0
|
||||
box2 = box1 + torch.tensor([[0.0, 100.0, 0.0]], device=device)
|
||||
vol, iou = box3d_overlap(box1, box2)
|
||||
self.assertClose(vol, torch.tensor(0.0, device=vol.device, dtype=vol.dtype))
|
||||
self.assertClose(iou, torch.tensor(0.0, device=vol.device, dtype=vol.dtype))
|
||||
vol, iou = overlap_fn(box1[None], box2[None])
|
||||
self.assertClose(vol, torch.tensor([[0.0]], device=vol.device, dtype=vol.dtype))
|
||||
self.assertClose(iou, torch.tensor([[0.0]], device=vol.device, dtype=vol.dtype))
|
||||
|
||||
def test_iou_naive(self):
|
||||
device = torch.device("cuda:0")
|
||||
self._test_iou(self._box3d_overlap_naive_batched, device)
|
||||
|
||||
def test_iou_cpu(self):
|
||||
device = torch.device("cpu")
|
||||
self._test_iou(box3d_overlap, device)
|
||||
|
||||
def test_cpu_vs_naive_batched(self):
|
||||
N, M = 3, 6
|
||||
device = "cpu"
|
||||
boxes1 = torch.randn((N, 8, 3), device=device)
|
||||
boxes2 = torch.randn((M, 8, 3), device=device)
|
||||
vol1, iou1 = self._box3d_overlap_naive_batched(boxes1, boxes2)
|
||||
vol2, iou2 = box3d_overlap(boxes1, boxes2)
|
||||
# check shape
|
||||
for val in [vol1, vol2, iou1, iou2]:
|
||||
self.assertClose(val.shape, (N, M))
|
||||
# check values
|
||||
self.assertClose(vol1, vol2)
|
||||
self.assertClose(iou1, iou2)
|
||||
|
||||
def test_batched_errors(self):
|
||||
N, M = 5, 10
|
||||
boxes1 = torch.randn((N, 8, 3))
|
||||
boxes2 = torch.randn((M, 10, 3))
|
||||
with self.assertRaisesRegex(ValueError, "(8, 3)"):
|
||||
box3d_overlap(boxes1, boxes2)
|
||||
|
||||
def test_box_volume(self):
|
||||
device = torch.device("cuda:0")
|
||||
@@ -277,6 +351,36 @@ class TestIoU3D(TestCaseMixin, unittest.TestCase):
|
||||
self.assertClose(box_planar_dir(box1), n1)
|
||||
self.assertClose(box_planar_dir(box2), n2)
|
||||
|
||||
@staticmethod
|
||||
def iou_naive(N: int, M: int, device="cpu"):
|
||||
boxes1 = torch.randn((N, 8, 3))
|
||||
boxes2 = torch.randn((M, 8, 3))
|
||||
|
||||
def output():
|
||||
vol, iou = TestIoU3D._box3d_overlap_naive_batched(boxes1, boxes2)
|
||||
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def iou(N: int, M: int, device="cpu"):
|
||||
boxes1 = torch.randn((N, 8, 3), device=device)
|
||||
boxes2 = torch.randn((M, 8, 3), device=device)
|
||||
|
||||
def output():
|
||||
vol, iou = box3d_overlap(boxes1, boxes2)
|
||||
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def iou_sampling(N: int, M: int, num_samples: int):
|
||||
boxes1 = torch.randn((N, 8, 3))
|
||||
boxes2 = torch.randn((M, 8, 3))
|
||||
|
||||
def output():
|
||||
_ = TestIoU3D._box3d_overlap_sampling_batched(boxes1, boxes2, num_samples)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
# -------------------------------------------------- #
|
||||
# NAIVE IMPLEMENTATION #
|
||||
@@ -284,7 +388,7 @@ class TestIoU3D(TestCaseMixin, unittest.TestCase):
|
||||
|
||||
"""
|
||||
The main functions below are:
|
||||
* box3d_overlap: which computes the exact IoU of box1 and box2
|
||||
* box3d_overlap_naive: which computes the exact IoU of box1 and box2
|
||||
* box3d_overlap_sampling: which computes an approximate IoU of box1 and box2
|
||||
by sampling points within the boxes
|
||||
|
||||
@@ -738,7 +842,7 @@ def clip_tri_by_plane(plane, n, tri_verts) -> Union[List, torch.Tensor]:
|
||||
# -------------------------------------------------- #
|
||||
|
||||
|
||||
def box3d_overlap(box1: torch.Tensor, box2: torch.Tensor):
|
||||
def box3d_overlap_naive(box1: torch.Tensor, box2: torch.Tensor):
|
||||
"""
|
||||
Computes the intersection of 3D boxes1 and boxes2.
|
||||
Inputs boxes1, boxes2 are tensors of shape (8, 3) containing
|
||||
|
||||
Reference in New Issue
Block a user