C++ IoU for 3D Boxes

Summary: C++ Implementation of algorithm to compute 3D bounding boxes for batches of bboxes of shape (N, 8, 3) and (M, 8, 3).

Reviewed By: gkioxari

Differential Revision: D30905190

fbshipit-source-id: 02e2cf025cd4fa3ff706ce5cf9b82c0fb5443f96
This commit is contained in:
Nikhila Ravi
2021-09-29 17:02:37 -07:00
committed by Facebook GitHub Bot
parent 2293f1fed0
commit 53266ec9ff
7 changed files with 927 additions and 29 deletions

View File

@@ -0,0 +1,64 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Tuple
import torch
from pytorch3d import _C
from torch.autograd import Function
class _box3d_overlap(Function):
"""
Torch autograd Function wrapper for box3d_overlap C++/CUDA implementations.
Backward is not supported.
"""
@staticmethod
def forward(ctx, boxes1, boxes2):
"""
Arguments defintions the same as in the box3d_overlap function
"""
vol, iou = _C.iou_box3d(boxes1, boxes2)
return vol, iou
@staticmethod
def backward(ctx, grad_vol, grad_iou):
raise ValueError("box3d_overlap backward is not supported")
def box3d_overlap(
boxes1: torch.Tensor, boxes2: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Computes the intersection of 3D boxes1 and boxes2.
Inputs boxes1, boxes2 are tensors of shape (B, 8, 3)
(where B doesn't have to be the same for boxes1 and boxes1),
containing the 8 corners of the boxes, as follows:
(4) +---------+. (5)
| ` . | ` .
| (0) +---+-----+ (1)
| | | |
(7) +-----+---+. (6)|
` . | ` . |
(3) ` +---------+ (2)
Args:
boxes1: tensor of shape (N, 8, 3) of the coordinates of the 1st boxes
boxes2: tensor of shape (M, 8, 3) of the coordinates of the 2nd boxes
Returns:
vol: (N, M) tensor of the volume of the intersecting convex shapes
iou: (N, M) tensor of the intersection over union which is
defined as: `iou = vol / (vol1 + vol2 - vol)`
"""
if not all((8, 3) == box.shape[1:] for box in [boxes1, boxes2]):
raise ValueError("Each box in the batch must be of shape (8, 3)")
# pyre-fixme[16]: `_box3d_overlap` has no attribute `apply`.
vol, iou = _box3d_overlap.apply(boxes1, boxes2)
return vol, iou