mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-22 23:30:35 +08:00
(new) CUDA IoU for 3D boxes
Summary: CUDA implementation of 3D bounding box overlap calculation. Reviewed By: gkioxari Differential Revision: D31157919 fbshipit-source-id: 5dc89805d01fef2d6779f00a33226131e39c43ed
This commit is contained in:
committed by
Facebook GitHub Bot
parent
53266ec9ff
commit
ff8d4762f4
@@ -9,6 +9,7 @@ from .cameras_alignment import corresponding_cameras_alignment
|
||||
from .cubify import cubify
|
||||
from .graph_conv import GraphConv
|
||||
from .interp_face_attrs import interpolate_face_attributes
|
||||
from .iou_box3d import box3d_overlap
|
||||
from .knn import knn_gather, knn_points
|
||||
from .laplacian_matrices import cot_laplacian, laplacian, norm_laplacian
|
||||
from .mesh_face_areas_normals import mesh_face_areas_normals
|
||||
|
||||
@@ -7,10 +7,68 @@
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from pytorch3d import _C
|
||||
from torch.autograd import Function
|
||||
|
||||
|
||||
# -------------------------------------------------- #
|
||||
# CONSTANTS #
|
||||
# -------------------------------------------------- #
|
||||
"""
|
||||
_box_planes and _box_triangles define the 4- and 3-connectivity
|
||||
of the 8 box corners.
|
||||
_box_planes gives the quad faces of the 3D box
|
||||
_box_triangles gives the triangle faces of the 3D box
|
||||
"""
|
||||
_box_planes = [
|
||||
[0, 1, 2, 3],
|
||||
[3, 2, 6, 7],
|
||||
[0, 1, 5, 4],
|
||||
[0, 3, 7, 4],
|
||||
[1, 2, 6, 5],
|
||||
[4, 5, 6, 7],
|
||||
]
|
||||
_box_triangles = [
|
||||
[0, 1, 2],
|
||||
[0, 3, 2],
|
||||
[4, 5, 6],
|
||||
[4, 6, 7],
|
||||
[1, 5, 6],
|
||||
[1, 6, 2],
|
||||
[0, 4, 7],
|
||||
[0, 7, 3],
|
||||
[3, 2, 6],
|
||||
[3, 6, 7],
|
||||
[0, 1, 5],
|
||||
[0, 4, 5],
|
||||
]
|
||||
|
||||
|
||||
def _check_coplanar(boxes: torch.Tensor, eps: float = 1e-5) -> None:
|
||||
faces = torch.tensor(_box_planes, dtype=torch.int64, device=boxes.device)
|
||||
# pyre-fixme[16]: `boxes` has no attribute `index_select`.
|
||||
verts = boxes.index_select(index=faces.view(-1), dim=1)
|
||||
B = boxes.shape[0]
|
||||
P, V = faces.shape
|
||||
# (B, P, 4, 3) -> (B, P, 3)
|
||||
v0, v1, v2, v3 = verts.reshape(B, P, V, 3).unbind(2)
|
||||
|
||||
# Compute the normal
|
||||
e0 = F.normalize(v1 - v0, dim=-1)
|
||||
e1 = F.normalize(v2 - v0, dim=-1)
|
||||
normal = F.normalize(torch.cross(e0, e1, dim=-1), dim=-1)
|
||||
|
||||
# Check the fourth vertex is also on the same plane
|
||||
mat1 = (v3 - v0).view(B, 1, -1) # (B, 1, P*3)
|
||||
mat2 = normal.view(B, -1, 1) # (B, P*3, 1)
|
||||
if not (mat1.bmm(mat2).abs() < eps).all().item():
|
||||
msg = "Plane vertices are not coplanar"
|
||||
raise ValueError(msg)
|
||||
|
||||
return
|
||||
|
||||
|
||||
class _box3d_overlap(Function):
|
||||
"""
|
||||
Torch autograd Function wrapper for box3d_overlap C++/CUDA implementations.
|
||||
@@ -35,6 +93,7 @@ def box3d_overlap(
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Computes the intersection of 3D boxes1 and boxes2.
|
||||
|
||||
Inputs boxes1, boxes2 are tensors of shape (B, 8, 3)
|
||||
(where B doesn't have to be the same for boxes1 and boxes1),
|
||||
containing the 8 corners of the boxes, as follows:
|
||||
@@ -47,6 +106,25 @@ def box3d_overlap(
|
||||
` . | ` . |
|
||||
(3) ` +---------+ (2)
|
||||
|
||||
|
||||
NOTE: Throughout this implementation, we assume that boxes
|
||||
are defined by their 8 corners exactly in the order specified in the
|
||||
diagram above for the function to give correct results. In addition
|
||||
the vertices on each plane must be coplanar.
|
||||
As an alternative to the diagram, this is a unit bounding
|
||||
box which has the correct vertex ordering:
|
||||
|
||||
box_corner_vertices = [
|
||||
[0, 0, 0],
|
||||
[1, 0, 0],
|
||||
[1, 1, 0],
|
||||
[0, 1, 0],
|
||||
[0, 0, 1],
|
||||
[1, 0, 1],
|
||||
[1, 1, 1],
|
||||
[0, 1, 1],
|
||||
]
|
||||
|
||||
Args:
|
||||
boxes1: tensor of shape (N, 8, 3) of the coordinates of the 1st boxes
|
||||
boxes2: tensor of shape (M, 8, 3) of the coordinates of the 2nd boxes
|
||||
@@ -58,6 +136,9 @@ def box3d_overlap(
|
||||
if not all((8, 3) == box.shape[1:] for box in [boxes1, boxes2]):
|
||||
raise ValueError("Each box in the batch must be of shape (8, 3)")
|
||||
|
||||
_check_coplanar(boxes1)
|
||||
_check_coplanar(boxes2)
|
||||
|
||||
# pyre-fixme[16]: `_box3d_overlap` has no attribute `apply`.
|
||||
vol, iou = _box3d_overlap.apply(boxes1, boxes2)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user