(new) CUDA IoU for 3D boxes

Summary: CUDA implementation of 3D bounding box overlap calculation.

Reviewed By: gkioxari

Differential Revision: D31157919

fbshipit-source-id: 5dc89805d01fef2d6779f00a33226131e39c43ed
This commit is contained in:
Nikhila Ravi
2021-09-29 18:48:11 -07:00
committed by Facebook GitHub Bot
parent 53266ec9ff
commit ff8d4762f4
9 changed files with 1019 additions and 97 deletions

View File

@@ -11,25 +11,42 @@ from test_iou_box3d import TestIoU3D
def bm_iou_box3d() -> None:
N = [1, 4, 8, 16]
num_samples = [2000, 5000, 10000, 20000]
# Realistic use cases
N = [30, 100]
M = [5, 10, 100]
kwargs_list = []
test_cases = product(N, M)
for case in test_cases:
n, m = case
kwargs_list.append({"N": n, "M": m, "device": "cuda:0"})
benchmark(TestIoU3D.iou, "3D_IOU", kwargs_list, warmup_iters=1)
# Comparison of C++/CUDA
kwargs_list = []
N = [1, 4, 8, 16]
devices = ["cpu", "cuda:0"]
test_cases = product(N, N, devices)
for case in test_cases:
n, m, d = case
kwargs_list.append({"N": n, "M": m, "device": d})
benchmark(TestIoU3D.iou, "3D_IOU", kwargs_list, warmup_iters=1)
# Naive PyTorch
N = [1, 4]
kwargs_list = []
test_cases = product(N, N)
for case in test_cases:
n, m = case
kwargs_list.append({"N": n, "M": m, "device": "cuda:0"})
benchmark(TestIoU3D.iou_naive, "3D_IOU_NAIVE", kwargs_list, warmup_iters=1)
[k.update({"device": "cpu"}) for k in kwargs_list]
benchmark(TestIoU3D.iou, "3D_IOU", kwargs_list, warmup_iters=1)
# Sampling based method
num_samples = [2000, 5000]
kwargs_list = []
test_cases = product([1, 4], [1, 4], num_samples)
test_cases = product(N, N, num_samples)
for case in test_cases:
n, m, s = case
kwargs_list.append({"N": n, "M": m, "num_samples": s})
kwargs_list.append({"N": n, "M": m, "num_samples": s, "device": "cuda:0"})
benchmark(TestIoU3D.iou_sampling, "3D_IOU_SAMPLING", kwargs_list, warmup_iters=1)

Binary file not shown.

View File

@@ -10,13 +10,28 @@ from typing import List, Tuple, Union
import torch
import torch.nn.functional as F
from common_testing import TestCaseMixin
from common_testing import TestCaseMixin, get_random_cuda_device, get_tests_dir
from pytorch3d.io import save_obj
from pytorch3d.ops.iou_box3d import box3d_overlap
from pytorch3d.ops.iou_box3d import _box_planes, _box_triangles, box3d_overlap
from pytorch3d.transforms.rotation_conversions import random_rotation
OBJECTRON_TO_PYTORCH3D_FACE_IDX = [0, 4, 6, 2, 1, 5, 7, 3]
DATA_DIR = get_tests_dir() / "data"
DEBUG = False
UNIT_BOX = [
[0, 0, 0],
[1, 0, 0],
[1, 1, 0],
[0, 1, 0],
[0, 0, 1],
[1, 0, 1],
[1, 1, 1],
[0, 1, 1],
]
class TestIoU3D(TestCaseMixin, unittest.TestCase):
def setUp(self) -> None:
super().setUp()
@@ -78,16 +93,7 @@ class TestIoU3D(TestCaseMixin, unittest.TestCase):
def _test_iou(self, overlap_fn, device):
box1 = torch.tensor(
[
[0, 0, 0],
[1, 0, 0],
[1, 1, 0],
[0, 1, 0],
[0, 0, 1],
[1, 0, 1],
[1, 1, 1],
[0, 1, 1],
],
UNIT_BOX,
dtype=torch.float32,
device=device,
)
@@ -126,6 +132,10 @@ class TestIoU3D(TestCaseMixin, unittest.TestCase):
),
)
# Also check IoU is 1 when computing overlap with the same shifted box
vol, iou = overlap_fn(box2[None], box2[None])
self.assertClose(iou, torch.tensor([[1.0]], device=vol.device, dtype=vol.dtype))
# 5th test
ddx, ddy, ddz = random.random(), random.random(), random.random()
box2 = box1 + torch.tensor([[ddx, ddy, ddz]], device=device)
@@ -207,15 +217,15 @@ class TestIoU3D(TestCaseMixin, unittest.TestCase):
# create box1
ctrs = torch.rand((2, 3), device=device)
whl = torch.rand((2, 3), device=device) * 10.0 + 1.0
# box1 & box2
box1 = self.create_box(ctrs[0], whl[0])
box2 = self.create_box(ctrs[1], whl[1])
# box8a & box8b
box8a = self.create_box(ctrs[0], whl[0])
box8b = self.create_box(ctrs[1], whl[1])
RR1 = random_rotation(dtype=torch.float32, device=device)
TT1 = torch.rand((1, 3), dtype=torch.float32, device=device)
RR2 = random_rotation(dtype=torch.float32, device=device)
TT2 = torch.rand((1, 3), dtype=torch.float32, device=device)
box1r = box1 @ RR1.transpose(0, 1) + TT1
box2r = box2 @ RR2.transpose(0, 1) + TT2
box1r = box8a @ RR1.transpose(0, 1) + TT1
box2r = box8b @ RR2.transpose(0, 1) + TT2
vol, iou = overlap_fn(box1r[None], box2r[None])
iou_sampling = self._box3d_overlap_sampling_batched(
box1r[None], box2r[None], num_samples=10000
@@ -229,27 +239,90 @@ class TestIoU3D(TestCaseMixin, unittest.TestCase):
self.assertClose(vol, torch.tensor([[0.0]], device=vol.device, dtype=vol.dtype))
self.assertClose(iou, torch.tensor([[0.0]], device=vol.device, dtype=vol.dtype))
# 10th test: Non coplanar verts in a plane
box10 = box1 + torch.rand((8, 3), dtype=torch.float32, device=device)
msg = "Plane vertices are not coplanar"
with self.assertRaisesRegex(ValueError, msg):
overlap_fn(box10[None], box10[None])
# 11th test: Skewed bounding boxes but all verts are coplanar
box_skew_1 = torch.tensor(
[
[0, 0, 0],
[1, 0, 0],
[1, 1, 0],
[0, 1, 0],
[-2, -2, 2],
[2, -2, 2],
[2, 2, 2],
[-2, 2, 2],
],
dtype=torch.float32,
device=device,
)
box_skew_2 = torch.tensor(
[
[2.015995, 0.695233, 2.152806],
[2.832533, 0.663448, 1.576389],
[2.675445, -0.309592, 1.407520],
[1.858907, -0.277806, 1.983936],
[-0.413922, 3.161758, 2.044343],
[2.852230, 3.034615, -0.261321],
[2.223878, -0.857545, -0.936800],
[-1.042273, -0.730402, 1.368864],
],
dtype=torch.float32,
device=device,
)
vol1 = 14.000
vol2 = 14.000005
vol_inters = 5.431122
iou = vol_inters / (vol1 + vol2 - vol_inters)
vols, ious = overlap_fn(box_skew_1[None], box_skew_2[None])
self.assertClose(vols, torch.tensor([[vol_inters]], device=device), atol=1e-1)
self.assertClose(ious, torch.tensor([[iou]], device=device), atol=1e-1)
def test_iou_naive(self):
device = torch.device("cuda:0")
device = get_random_cuda_device()
self._test_iou(self._box3d_overlap_naive_batched, device)
self._test_compare_objectron(self._box3d_overlap_naive_batched, device)
def test_iou_cpu(self):
device = torch.device("cpu")
self._test_iou(box3d_overlap, device)
self._test_compare_objectron(box3d_overlap, device)
def test_cpu_vs_naive_batched(self):
N, M = 3, 6
device = "cpu"
boxes1 = torch.randn((N, 8, 3), device=device)
boxes2 = torch.randn((M, 8, 3), device=device)
vol1, iou1 = self._box3d_overlap_naive_batched(boxes1, boxes2)
vol2, iou2 = box3d_overlap(boxes1, boxes2)
# check shape
for val in [vol1, vol2, iou1, iou2]:
self.assertClose(val.shape, (N, M))
# check values
self.assertClose(vol1, vol2)
self.assertClose(iou1, iou2)
def test_iou_cuda(self):
device = torch.device("cuda:0")
self._test_iou(box3d_overlap, device)
self._test_compare_objectron(box3d_overlap, device)
def _test_compare_objectron(self, overlap_fn, device):
# Load saved objectron data
data_filename = "./objectron_vols_ious.pt"
objectron_vals = torch.load(DATA_DIR / data_filename)
boxes1 = objectron_vals["boxes1"]
boxes2 = objectron_vals["boxes2"]
vols_objectron = objectron_vals["vols"]
ious_objectron = objectron_vals["ious"]
boxes1 = boxes1.to(device=device, dtype=torch.float32)
boxes2 = boxes2.to(device=device, dtype=torch.float32)
# Convert vertex orderings from Objectron to PyTorch3D convention
idx = torch.tensor(
OBJECTRON_TO_PYTORCH3D_FACE_IDX, dtype=torch.int64, device=device
)
boxes1 = boxes1.index_select(index=idx, dim=1)
boxes2 = boxes2.index_select(index=idx, dim=1)
# Run PyTorch3D version
vols, ious = overlap_fn(boxes1, boxes2)
# Check values match
self.assertClose(vols_objectron, vols.cpu())
self.assertClose(ious_objectron, ious.cpu())
def test_batched_errors(self):
N, M = 5, 10
@@ -316,16 +389,7 @@ class TestIoU3D(TestCaseMixin, unittest.TestCase):
def test_box_planar_dir(self):
device = torch.device("cuda:0")
box1 = torch.tensor(
[
[0, 0, 0],
[1, 0, 0],
[1, 1, 0],
[0, 1, 0],
[0, 0, 1],
[1, 0, 1],
[1, 1, 1],
[0, 1, 1],
],
UNIT_BOX,
dtype=torch.float32,
device=device,
)
@@ -353,8 +417,13 @@ class TestIoU3D(TestCaseMixin, unittest.TestCase):
@staticmethod
def iou_naive(N: int, M: int, device="cpu"):
boxes1 = torch.randn((N, 8, 3))
boxes2 = torch.randn((M, 8, 3))
box = torch.tensor(
[UNIT_BOX],
dtype=torch.float32,
device=device,
)
boxes1 = box + torch.randn((N, 1, 3), device=device)
boxes2 = box + torch.randn((M, 1, 3), device=device)
def output():
vol, iou = TestIoU3D._box3d_overlap_naive_batched(boxes1, boxes2)
@@ -363,8 +432,13 @@ class TestIoU3D(TestCaseMixin, unittest.TestCase):
@staticmethod
def iou(N: int, M: int, device="cpu"):
boxes1 = torch.randn((N, 8, 3), device=device)
boxes2 = torch.randn((M, 8, 3), device=device)
box = torch.tensor(
[UNIT_BOX],
dtype=torch.float32,
device=device,
)
boxes1 = box + torch.randn((N, 1, 3), device=device)
boxes2 = box + torch.randn((M, 1, 3), device=device)
def output():
vol, iou = box3d_overlap(boxes1, boxes2)
@@ -372,9 +446,14 @@ class TestIoU3D(TestCaseMixin, unittest.TestCase):
return output
@staticmethod
def iou_sampling(N: int, M: int, num_samples: int):
boxes1 = torch.randn((N, 8, 3))
boxes2 = torch.randn((M, 8, 3))
def iou_sampling(N: int, M: int, num_samples: int, device="cpu"):
box = torch.tensor(
[UNIT_BOX],
dtype=torch.float32,
device=device,
)
boxes1 = box + torch.randn((N, 1, 3), device=device)
boxes2 = box + torch.randn((M, 1, 3), device=device)
def output():
_ = TestIoU3D._box3d_overlap_sampling_batched(boxes1, boxes2, num_samples)
@@ -408,38 +487,6 @@ Note that both implementations currently do not support batching.
#
# -------------------------------------------------- #
# -------------------------------------------------- #
# CONSTANTS #
# -------------------------------------------------- #
"""
_box_planes and _box_triangles define the 4- and 3-connectivity
of the 8 box corners.
_box_planes gives the quad faces of the 3D box
_box_triangles gives the triangle faces of the 3D box
"""
_box_planes = [
[0, 1, 2, 3],
[3, 2, 6, 7],
[0, 1, 5, 4],
[0, 3, 7, 4],
[1, 5, 6, 2],
[4, 5, 6, 7],
]
_box_triangles = [
[0, 1, 2],
[0, 3, 2],
[4, 5, 6],
[4, 6, 7],
[1, 5, 6],
[1, 6, 2],
[0, 4, 7],
[0, 7, 3],
[3, 2, 6],
[3, 6, 7],
[0, 1, 5],
[0, 4, 5],
]
# -------------------------------------------------- #
# HELPER FUNCTIONS FOR EXACT SOLUTION #
# -------------------------------------------------- #
@@ -477,7 +524,7 @@ def get_plane_verts(box: torch.Tensor) -> torch.Tensor:
return plane_verts
def box_planar_dir(box: torch.Tensor) -> torch.Tensor:
def box_planar_dir(box: torch.Tensor, eps=1e-4) -> torch.Tensor:
"""
Finds the unit vector n which is perpendicular to each plane in the box
and points towards the inside of the box.
@@ -507,6 +554,11 @@ def box_planar_dir(box: torch.Tensor) -> torch.Tensor:
e1 = F.normalize(v2 - v0, dim=-1)
n = F.normalize(torch.cross(e0, e1, dim=-1), dim=-1)
# Check all verts are coplanar
if not ((v3 - v0).unsqueeze(1).bmm(n.unsqueeze(2)).abs() < eps).all().item():
msg = "Plane vertices are not coplanar"
raise ValueError(msg)
# We can write: `ctr = v0 + a * e0 + b * e1 + c * n`, (1).
# With <e0, n> = 0 and <e1, n> = 0, where <.,.> refers to the dot product,
# since that e0 is orthogonal to n. Same for e1.
@@ -733,10 +785,10 @@ def clip_tri_by_plane_oneout(
device = plane.device
# point of intersection between plane and (vin1, vout)
pint1, a1 = plane_edge_point_of_intersection(plane, n, vin1, vout)
assert a1 >= eps and a1 <= 1.0
assert a1 >= eps and a1 <= 1.0, a1
# point of intersection between plane and (vin2, vout)
pint2, a2 = plane_edge_point_of_intersection(plane, n, vin2, vout)
assert a2 >= 0.0 and a2 <= 1.0
assert a2 >= 0.0 and a2 <= 1.0, a2
verts = torch.stack((vin1, pint1, pint2, vin2), dim=0) # 4x3
faces = torch.tensor(
@@ -771,10 +823,10 @@ def clip_tri_by_plane_twoout(
device = plane.device
# point of intersection between plane and (vin, vout1)
pint1, a1 = plane_edge_point_of_intersection(plane, n, vin, vout1)
assert a1 >= eps and a1 <= 1.0
assert a1 >= eps and a1 <= 1.0, a1
# point of intersection between plane and (vin, vout2)
pint2, a2 = plane_edge_point_of_intersection(plane, n, vin, vout2)
assert a2 >= eps and a2 <= 1.0
assert a2 >= eps and a2 <= 1.0, a2
verts = torch.stack((vin, pint1, pint2), dim=0) # 3x3
faces = torch.tensor(
@@ -945,7 +997,7 @@ def box3d_overlap_naive(box1: torch.Tensor, box2: torch.Tensor):
iou = vol / (vol1 + vol2 - vol)
if 0:
if DEBUG:
# save shapes
tri_faces = torch.tensor(_box_triangles, device=device, dtype=torch.int64)
save_obj("/tmp/output/shape1.obj", box1, tri_faces)