mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2026-02-07 23:02:16 +08:00
C++ IoU for 3D Boxes
Summary: C++ Implementation of algorithm to compute 3D bounding boxes for batches of bboxes of shape (N, 8, 3) and (M, 8, 3). Reviewed By: gkioxari Differential Revision: D30905190 fbshipit-source-id: 02e2cf025cd4fa3ff706ce5cf9b82c0fb5443f96
This commit is contained in:
committed by
Facebook GitHub Bot
parent
2293f1fed0
commit
53266ec9ff
64
pytorch3d/ops/iou_box3d.py
Normal file
64
pytorch3d/ops/iou_box3d.py
Normal file
@@ -0,0 +1,64 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
from pytorch3d import _C
|
||||
from torch.autograd import Function
|
||||
|
||||
|
||||
class _box3d_overlap(Function):
|
||||
"""
|
||||
Torch autograd Function wrapper for box3d_overlap C++/CUDA implementations.
|
||||
Backward is not supported.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, boxes1, boxes2):
|
||||
"""
|
||||
Arguments defintions the same as in the box3d_overlap function
|
||||
"""
|
||||
vol, iou = _C.iou_box3d(boxes1, boxes2)
|
||||
return vol, iou
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_vol, grad_iou):
|
||||
raise ValueError("box3d_overlap backward is not supported")
|
||||
|
||||
|
||||
def box3d_overlap(
|
||||
boxes1: torch.Tensor, boxes2: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Computes the intersection of 3D boxes1 and boxes2.
|
||||
Inputs boxes1, boxes2 are tensors of shape (B, 8, 3)
|
||||
(where B doesn't have to be the same for boxes1 and boxes1),
|
||||
containing the 8 corners of the boxes, as follows:
|
||||
|
||||
(4) +---------+. (5)
|
||||
| ` . | ` .
|
||||
| (0) +---+-----+ (1)
|
||||
| | | |
|
||||
(7) +-----+---+. (6)|
|
||||
` . | ` . |
|
||||
(3) ` +---------+ (2)
|
||||
|
||||
Args:
|
||||
boxes1: tensor of shape (N, 8, 3) of the coordinates of the 1st boxes
|
||||
boxes2: tensor of shape (M, 8, 3) of the coordinates of the 2nd boxes
|
||||
Returns:
|
||||
vol: (N, M) tensor of the volume of the intersecting convex shapes
|
||||
iou: (N, M) tensor of the intersection over union which is
|
||||
defined as: `iou = vol / (vol1 + vol2 - vol)`
|
||||
"""
|
||||
if not all((8, 3) == box.shape[1:] for box in [boxes1, boxes2]):
|
||||
raise ValueError("Each box in the batch must be of shape (8, 3)")
|
||||
|
||||
# pyre-fixme[16]: `_box3d_overlap` has no attribute `apply`.
|
||||
vol, iou = _box3d_overlap.apply(boxes1, boxes2)
|
||||
|
||||
return vol, iou
|
||||
Reference in New Issue
Block a user