C++ IoU for 3D Boxes

Summary: C++ implementation of the algorithm to compute the intersection volume and IoU of 3D bounding boxes, for batches of boxes of shape (N, 8, 3) and (M, 8, 3).

Reviewed By: gkioxari

Differential Revision: D30905190

fbshipit-source-id: 02e2cf025cd4fa3ff706ce5cf9b82c0fb5443f96
This commit is contained in:
Nikhila Ravi 2021-09-29 17:02:37 -07:00 committed by Facebook GitHub Bot
parent 2293f1fed0
commit 53266ec9ff
7 changed files with 927 additions and 29 deletions

View File

@ -20,6 +20,7 @@
#include "face_areas_normals/face_areas_normals.h" #include "face_areas_normals/face_areas_normals.h"
#include "gather_scatter/gather_scatter.h" #include "gather_scatter/gather_scatter.h"
#include "interp_face_attrs/interp_face_attrs.h" #include "interp_face_attrs/interp_face_attrs.h"
#include "iou_box3d/iou_box3d.h"
#include "knn/knn.h" #include "knn/knn.h"
#include "mesh_normal_consistency/mesh_normal_consistency.h" #include "mesh_normal_consistency/mesh_normal_consistency.h"
#include "packed_to_padded_tensor/packed_to_padded_tensor.h" #include "packed_to_padded_tensor/packed_to_padded_tensor.h"
@ -87,6 +88,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
// Sample PDF // Sample PDF
m.def("sample_pdf", &SamplePdf); m.def("sample_pdf", &SamplePdf);
// 3D IoU
m.def("iou_box3d", &IoUBox3D);
// Pulsar. // Pulsar.
#ifdef PULSAR_LOGGING_ENABLED #ifdef PULSAR_LOGGING_ENABLED
c10::ShowLogInfoToStderr(); c10::ShowLogInfoToStderr();

View File

@ -0,0 +1,37 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#pragma once
#include <torch/extension.h>
#include <tuple>
#include "utils/pytorch3d_cutils.h"
// Calculate the intersection volume and IoU metric for two batches of boxes.
// Every box in the first batch is paired with every box in the second batch.
//
// Args:
// boxes1: tensor of shape (N, 8, 3) of the coordinates of the 1st boxes
// boxes2: tensor of shape (M, 8, 3) of the coordinates of the 2nd boxes
// Returns:
// vol: (N, M) tensor of the volume of the intersecting convex shapes
// iou: (N, M) tensor of the intersection over union which is
// defined as: `iou = vol / (vol1 + vol2 - vol)`
//
// CPU implementation. Only declared here; the definition lives in a
// separate translation unit.
std::tuple<at::Tensor, at::Tensor> IoUBox3DCpu(
const at::Tensor& boxes1,
const at::Tensor& boxes2);
// Entry point that is exposed to Python.
//
// Rejects CUDA tensors (there is no GPU kernel) and otherwise forwards
// contiguous copies of both inputs to the CPU implementation.
//
// Args:
//    boxes1: tensor of shape (N, 8, 3) of the coordinates of the 1st boxes
//    boxes2: tensor of shape (M, 8, 3) of the coordinates of the 2nd boxes
// Returns:
//    (vol, iou): the pair of (N, M) tensors produced by IoUBox3DCpu
inline std::tuple<at::Tensor, at::Tensor> IoUBox3D(
    const at::Tensor& boxes1,
    const at::Tensor& boxes2) {
  const bool any_on_gpu = boxes1.is_cuda() || boxes2.is_cuda();
  if (any_on_gpu) {
    AT_ERROR("GPU support not implemented");
  }
  const at::Tensor boxes1_contig = boxes1.contiguous();
  const at::Tensor boxes2_contig = boxes2.contiguous();
  return IoUBox3DCpu(boxes1_contig, boxes2_contig);
}

View File

@ -0,0 +1,121 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <torch/extension.h>
#include <torch/torch.h>
#include <list>
#include <numeric>
#include <queue>
#include <tuple>
#include "iou_box3d/iou_utils.h"
// CPU implementation of the pairwise intersection volume and IoU of two
// batches of 3D boxes.
//
// Args:
//    boxes1: float tensor of shape (N, 8, 3) with the corners of the 1st boxes
//    boxes2: float tensor of shape (M, 8, 3) with the corners of the 2nd boxes
//
// Returns:
//    vols: (N, M) float tensor, volume of the intersection of each box pair
//    ious: (N, M) float tensor, `vol / (vol1 + vol2 - vol)` of each box pair
//
// Raises (via TORCH_CHECK) if either input is not of shape (*, 8, 3).
std::tuple<at::Tensor, at::Tensor> IoUBox3DCpu(
    const at::Tensor& boxes1,
    const at::Tensor& boxes2) {
  // The corner lookups below index rows 0..7 and columns 0..2 without any
  // bounds checking, so reject malformed inputs up front.
  TORCH_CHECK(
      boxes1.dim() == 3 && boxes1.size(1) == 8 && boxes1.size(2) == 3,
      "boxes1 must be of shape (N, 8, 3)");
  TORCH_CHECK(
      boxes2.dim() == 3 && boxes2.size(1) == 8 && boxes2.size(2) == 3,
      "boxes2 must be of shape (M, 8, 3)");

  const int N = boxes1.size(0);
  const int M = boxes2.size(0);
  auto float_opts = boxes1.options().dtype(torch::kFloat32);
  torch::Tensor vols = torch::zeros({N, M}, float_opts);
  torch::Tensor ious = torch::zeros({N, M}, float_opts);

  // Create tensor accessors
  auto boxes1_a = boxes1.accessor<float, 3>();
  auto boxes2_a = boxes2.accessor<float, 3>();
  auto vols_a = vols.accessor<float, 2>();
  auto ious_a = ious.accessor<float, 2>();

  // The triangles, planes, center and volume of a box in boxes2 do not
  // depend on which box of boxes1 it is compared with, so compute them once
  // instead of once per (n, m) pair.
  std::vector<face_verts> boxes2_tris;
  std::vector<face_verts> boxes2_planes;
  std::vector<vec3<float>> boxes2_centers;
  std::vector<float> boxes2_vols;
  boxes2_tris.reserve(M);
  boxes2_planes.reserve(M);
  boxes2_centers.reserve(M);
  boxes2_vols.reserve(M);
  for (int m = 0; m < M; ++m) {
    const auto& box2 = boxes2_a[m];
    boxes2_tris.push_back(GetBoxTris(box2));
    // BoxCenter requires a tensor as input.
    boxes2_centers.push_back(BoxCenter(boxes2[m]));
    boxes2_planes.push_back(GetBoxPlanes(box2));
    boxes2_vols.push_back(
        BoxVolume(boxes2_tris.back(), boxes2_centers.back()));
  }

  // Iterate through the N boxes in boxes1
  for (int n = 0; n < N; ++n) {
    const auto& box1 = boxes1_a[n];
    // Triangle faces of the box, effectively (F, 3, 3)
    const face_verts box1_tris = GetBoxTris(box1);
    // Center of the box; BoxCenter requires a tensor as input.
    const vec3<float> box1_center = BoxCenter(boxes1[n]);
    // Quad faces of the box, effectively (P, 4, 3)
    const face_verts box1_planes = GetBoxPlanes(box1);
    const float box1_vol = BoxVolume(box1_tris, box1_center);

    // Iterate through the M boxes in boxes2
    for (int m = 0; m < M; ++m) {
      // Every triangle in one box is clipped against each plane of the
      // other box. There are 3 possible outcomes per plane:
      //   1. The triangle is fully inside: it is kept as is.
      //   2. The triangle is fully outside: it is removed.
      //   3. The triangle crosses the (infinite) plane: it is split into
      //      sub-triangles that lie fully inside the plane.

      // Tris in Box1 -> Planes in Box2
      face_verts box1_intersect = BoxIntersections(
          box1_tris, boxes2_planes[m], boxes2_centers[m]);
      // Tris in Box2 -> Planes in Box1
      const face_verts box2_intersect =
          BoxIntersections(boxes2_tris[m], box1_planes, box1_center);

      // Append the clipped Box2 triangles to the Box1 triangles, skipping
      // any that are coplanar with one of the clipped Box1 triangles. Only
      // the first `num_box1_tris` entries are Box1 triangles; everything
      // appended after them must not take part in the coplanarity test.
      const std::size_t num_box1_tris = box1_intersect.size();
      for (const auto& tri2 : box2_intersect) {
        bool is_coplanar = false;
        for (std::size_t b1 = 0; b1 < num_box1_tris && !is_coplanar; ++b1) {
          is_coplanar = IsCoplanarFace(box1_intersect[b1], tri2);
        }
        if (!is_coplanar) {
          box1_intersect.push_back(tri2);
        }
      }

      // vol and iou stay 0.0 when there are no triangles in the
      // intersecting shape.
      float vol = 0.0;
      float iou = 0.0;

      if (box1_intersect.size() > 0) {
        // The intersecting shape is a polyhedron made up of the
        // triangular faces that are all now in box1_intersect.
        const vec3<float> polyhedron_center = PolyhedronCenter(box1_intersect);
        vol = BoxVolume(box1_intersect, polyhedron_center);
        iou = vol / (box1_vol + boxes2_vols[m] - vol);
      }

      // Save out volume and IoU
      vols_a[n][m] = vol;
      ious_a[n][m] = iou;
    }
  }
  return std::make_tuple(vols, ious);
}

View File

@ -0,0 +1,531 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <ATen/ATen.h>
#include <assert.h>
#include <torch/extension.h>
#include <torch/torch.h>
#include <algorithm>
#include <list>
#include <numeric>
#include <queue>
#include <tuple>
#include <type_traits>
#include "utils/geometry_utils.h"
#include "utils/vec3.h"
/*
_PLANES and _TRIS define the 4- and 3-connectivity
of the 8 box corners.
_PLANES gives the quad faces of the 3D box
_TRIS gives the triangle faces of the 3D box

Every entry is an index in [0, 7] into the 8 corner rows of a box.
NUM_PLANES and NUM_TRIS are the number of rows of each table.
*/
const int NUM_PLANES = 6;
const int NUM_TRIS = 12;
// Each row lists the 4 corner indices of one quad face.
const int _PLANES[6][4] = {
{0, 1, 2, 3},
{3, 2, 6, 7},
{0, 1, 5, 4},
{0, 3, 7, 4},
{1, 5, 6, 2},
{4, 5, 6, 7},
};
// Each row lists the 3 corner indices of one triangle face.
const int _TRIS[12][3] = {
{0, 1, 2},
{0, 3, 2},
{4, 5, 6},
{4, 6, 7},
{1, 5, 6},
{1, 6, 2},
{0, 4, 7},
{0, 7, 3},
{3, 2, 6},
{3, 6, 7},
{0, 1, 5},
{0, 4, 5},
};
// Create a new data type for representing the
// verts for each face which can be triangle or plane.
// This helps make the code more readable.
// The outer vector has one entry per face; each inner vector holds the
// vertices of that face (3 for a triangle, 4 for a box plane).
using face_verts = std::vector<std::vector<vec3<float>>>;
// Look up one corner of one quad face of a box.
//
// Args
//    box: (8, 3) tensor accessor for the box vertices
//    plane_idx: index of the plane in the box (row of _PLANES)
//    vert_idx: index of the vertex in the plane (column of _PLANES)
//
// Returns
//    vec3<float> (x, y, z) vertex coordinates
//
template <typename Box>
inline vec3<float>
ExtractVertsPlane(const Box& box, const int plane_idx, const int vert_idx) {
  // Row of `box` that holds the requested corner
  const int corner = _PLANES[plane_idx][vert_idx];
  return vec3<float>(box[corner][0], box[corner][1], box[corner][2]);
}
// Look up one corner of one triangle face of a box.
//
// Args
//    box: (8, 3) tensor accessor for the box vertices
//    tri_idx: index of the triangle face in the box (row of _TRIS)
//    vert_idx: index of the vertex in the triangle (column of _TRIS)
//
// Returns
//    vec3<float> (x, y, z) vertex coordinates
//
template <typename Box>
inline vec3<float>
ExtractVertsTri(const Box& box, const int tri_idx, const int vert_idx) {
  // Row of `box` that holds the requested corner
  const int corner = _TRIS[tri_idx][vert_idx];
  return vec3<float>(box[corner][0], box[corner][1], box[corner][2]);
}
// Collect the vertex coordinates of every triangle face of a box.
//
// Args
//    box: (8, 3) tensor accessor for the box vertices
//
// Returns
//    std::vector<std::vector<vec3<float>>> effectively (F, 3, 3)
//    coordinates of the verts for each of the NUM_TRIS triangle faces,
//    in the order given by _TRIS
//
template <typename Box>
inline face_verts GetBoxTris(const Box& box) {
  face_verts box_tris;
  // The number of faces is known up front: allocate once.
  box_tris.reserve(NUM_TRIS);
  for (int t = 0; t < NUM_TRIS; ++t) {
    box_tris.push_back(
        {ExtractVertsTri(box, t, 0),
         ExtractVertsTri(box, t, 1),
         ExtractVertsTri(box, t, 2)});
  }
  return box_tris;
}
// Collect the vertex coordinates of every quad face (plane) of a box.
//
// Args
//    box: (8, 3) tensor accessor for the box vertices
//
// Returns
//    std::vector<std::vector<vec3<float>>> effectively (P, 4, 3)
//    coordinates of the 4 verts for each of the NUM_PLANES planes,
//    in the order given by _PLANES
//
template <typename Box>
inline face_verts GetBoxPlanes(const Box& box) {
  face_verts box_planes;
  // The number of planes is known up front: allocate once.
  box_planes.reserve(NUM_PLANES);
  for (int p = 0; p < NUM_PLANES; ++p) {
    box_planes.push_back(
        {ExtractVertsPlane(box, p, 0),
         ExtractVertsPlane(box, p, 1),
         ExtractVertsPlane(box, p, 2),
         ExtractVertsPlane(box, p, 3)});
  }
  return box_planes;
}
// Unit normal of the face spanned by the vertices (v0, v1, v2).
// With e0 = v1 - v0 and e1 = v2 - v0, the normal is cross(e0, e1)
// divided by its length. The length is clamped from below by kEpsilon
// so that a degenerate face does not cause a division by zero.
//
// Args
//    v0, v1, v2: vec3 coordinates of the vertices of the face
//
// Returns
//    vec3: normal for the face
//
inline vec3<float> FaceNormal(vec3<float> v0, vec3<float> v1, vec3<float> v2) {
  const vec3<float> e0 = v1 - v0;
  const vec3<float> e1 = v2 - v0;
  const vec3<float> unnormalized = cross(e0, e1);
  const float length = std::fmaxf(norm(unnormalized), kEpsilon);
  return unnormalized / length;
}
// Normal of a box plane, oriented relative to the centroid of the box
// the plane belongs to.
//
// Args
//    plane: vec3 coordinates of the vertices of the plane; only the
//        first 3 are used
//    center: vec3 coordinates of the center of the box from
//        which the plane originated
//
// Returns
//    vec3: normal for the plane such that it points towards
//        the center of the box
//
inline vec3<float> PlaneNormalDirection(
    const std::vector<vec3<float>>& plane,
    const vec3<float>& center) {
  const vec3<float> n = FaceNormal(plane[0], plane[1], plane[2]);

  // Writing center = v0 + a * e0 + b * e1 + c * n with <e0, n> = 0 and
  // <e1, n> = 0, the component of the center along the normal is
  //    c = <center - v0, n>
  const float c = dot(center - plane[0], n);

  // When c is below kEpsilon the normal is flipped so that it points
  // "inside" the box.
  return (c < kEpsilon) ? (-1.0f * n) : n;
}
// Calculate the volume of the box by summing the volume of
// each of the tetrahedrons formed with a triangle face and
// the box centroid.
//
// Args
// box_tris: vector of vec3 coordinates of the vertices of each
// of the triangles in the box
// box_center: vec3 coordinates of the center of the box
//
// Returns
// float: volume of the box
//
inline float BoxVolume(
const face_verts& box_tris,
const vec3<float>& box_center) {
float box_vol = 0.0;
// Iterate through each triange, calculate the area of the
// tetrahedron formed with the box_center and sum them
for (int t = 0; t < box_tris.size(); ++t) {
// Subtract the center:
const vec3<float> v0 = box_tris[t][0] - box_center;
const vec3<float> v1 = box_tris[t][1] - box_center;
const vec3<float> v2 = box_tris[t][2] - box_center;
// Compute the area
const float area = dot(v0, cross(v1, v2));
const float vol = std::abs(area) / 6.0;
box_vol = box_vol + vol;
}
return box_vol;
}
// Center of a box, taken as the mean of its corner vertices.
//
// Args
//    box_verts: (8, 3) tensor of the corner vertices of the box
//
// Returns
//    vec3: coordinates of the center of the box
//
inline vec3<float> BoxCenter(const at::Tensor& box_verts) {
  // Mean over the vertex dimension leaves one value per coordinate
  const at::Tensor mean_xyz = at::mean(box_verts, 0);
  const float cx = mean_xyz[0].item<float>();
  const float cy = mean_xyz[1].item<float>();
  const float cz = mean_xyz[2].item<float>();
  return vec3<float>(cx, cy, cz);
}
// Center of a polyhedron, taken as the mean of the centroids of its
// triangle faces.
//
// Args
//    tris: vector of vec3 coordinates of the
//        vertices of each of the triangles in the polyhedron
//
// Returns
//    vec3: coordinates of the center of the polyhedron
//
inline vec3<float> PolyhedronCenter(const face_verts& tris) {
  const int num_tris = tris.size();
  float sum_x = 0.0;
  float sum_y = 0.0;
  float sum_z = 0.0;

  // Accumulate the centroid of every face
  for (const auto& tri : tris) {
    const float face_x = (tri[0].x + tri[1].x + tri[2].x) / 3.0;
    const float face_y = (tri[0].y + tri[1].y + tri[2].y) / 3.0;
    const float face_z = (tri[0].z + tri[1].z + tri[2].z) / 3.0;
    sum_x = sum_x + face_x;
    sum_y = sum_y + face_y;
    sum_z = sum_z + face_z;
  }

  // Mean over all faces
  return vec3<float>(sum_x / num_tris, sum_y / num_tris, sum_z / num_tris);
}
// Test whether a point is inside a plane, where "inside" means that the
// offset of the point from the plane has a component along the plane
// normal that is greater than -kEpsilon.
//
// Args
//    plane: vector of vec3 coordinates of the vertices of the plane;
//        only the first vertex is used
//    normal: vec3 of the direction of the plane normal
//    point: vec3 of the position of the point of interest
//
// Returns
//    bool: whether or not the point is inside the plane
//
inline bool IsInside(
    const std::vector<vec3<float>>& plane,
    const vec3<float>& normal,
    const vec3<float>& point) {
  // Any point can be written as p = v0 + a * e0 + b * e1 + c * n.
  // Because <e0, n> = 0 and <e1, n> = 0, the normal component is
  //    c = <point - v0, n>
  const float c = dot(point - plane[0], normal);
  return c > -1.0f * kEpsilon;
}
// Point where the line through (p0, p1) crosses a plane.
//
// Args
//    plane: vector of vec3 coordinates of the vertices of the plane;
//        only the first vertex is used
//    normal: vec3 of the direction of the plane normal
//    p0, p1: vec3 of the start and end point of the line
//
// Returns
//    vec3: position of the intersection point
//
inline vec3<float> PlaneEdgeIntersection(
    const std::vector<vec3<float>>& plane,
    const vec3<float>& normal,
    const vec3<float>& p0,
    const vec3<float>& p1) {
  // Parametrize the line as p = p0 + a * (p1 - p0) and solve
  //    <p - v0, n> = 0
  // for a, where v0 is one vertex of the plane.
  const vec3<float> edge = p1 - p0;
  const float numer = dot(-1.0f * (p0 - plane[0]), normal);
  const float denom = dot(edge, normal);
  const float a = numer / denom;
  return p0 + a * edge;
}
// Clip a triangle that has exactly one vertex outside the plane.
// Cutting off the outside corner leaves a quadrilateral, which is
// returned as two triangles.
//
// Args
//    plane: vector of vec3 coordinates of the vertices of the plane
//    normal: vec3 of the direction of the plane normal
//    vout: vec3 of the point in the triangle which is outside
//        the plane
//    vin1, vin2: vec3 of the points in the triangle which are
//        inside the plane
//
// Returns
//    std::vector<std::vector<vec3>>: vector of vertex coordinates
//        of the two new triangle faces
//
inline face_verts ClipTriByPlaneOneOut(
    const std::vector<vec3<float>>& plane,
    const vec3<float>& normal,
    const vec3<float>& vout,
    const vec3<float>& vin1,
    const vec3<float>& vin2) {
  // Where the edges (vin1, vout) and (vin2, vout) cross the plane
  const vec3<float> cut1 = PlaneEdgeIntersection(plane, normal, vin1, vout);
  const vec3<float> cut2 = PlaneEdgeIntersection(plane, normal, vin2, vout);

  // Split the quad (vin1, cut1, cut2, vin2) along the diagonal (vin1, cut2)
  face_verts clipped;
  clipped.push_back({vin1, cut1, cut2});
  clipped.push_back({vin1, cut2, vin2});
  return clipped;
}
// Clip a triangle that has exactly two vertices outside the plane.
// Only the corner around the inside vertex survives, as one smaller
// triangle.
//
// Args
//    plane: vector of vec3 coordinates of the vertices of the plane
//    normal: vec3 of the direction of the plane normal
//    vout1, vout2: vec3 of the points in the triangle which are
//        outside the plane
//    vin: vec3 of the point in the triangle which is inside
//        the plane
// Returns
//    std::vector<std::vector<vec3>>: vector of vertex coordinates
//        of the new triangle face
//
inline face_verts ClipTriByPlaneTwoOut(
    const std::vector<vec3<float>>& plane,
    const vec3<float>& normal,
    const vec3<float>& vout1,
    const vec3<float>& vout2,
    const vec3<float>& vin) {
  // Where the edges (vin, vout1) and (vin, vout2) cross the plane
  const vec3<float> cut1 = PlaneEdgeIntersection(plane, normal, vin, vout1);
  const vec3<float> cut2 = PlaneEdgeIntersection(plane, normal, vin, vout2);

  face_verts clipped;
  clipped.push_back({vin, cut1, cut2});
  return clipped;
}
// Clip a triangle face so that it lies within the plane, creating new
// triangle faces where necessary.
//
// Args
//    plane: vector of vec3 coordinates of the vertices of the plane
//    tri: std::vector<vec3> of the vertex coordinates of the
//        triangle face
//    normal: vec3 of the direction of the plane normal
//
// Returns
//    std::vector<std::vector<vec3>>: vector of vertex coordinates
//        of the new triangle faces formed after clipping (0, 1 or 2
//        triangles). All triangles are now "inside" the plane.
//
inline face_verts ClipTriByPlane(
    const std::vector<vec3<float>>& plane,
    const std::vector<vec3<float>>& tri,
    const vec3<float>& normal) {
  // Get Triangle vertices
  const vec3<float>& v0 = tri[0];
  const vec3<float>& v1 = tri[1];
  const vec3<float>& v2 = tri[2];

  // Encode which vertices are inside the plane as a bitmask:
  // bit 0 <-> v0, bit 1 <-> v1, bit 2 <-> v2.
  const int inside_mask = (IsInside(plane, normal, v0) ? 1 : 0) |
      (IsInside(plane, normal, v1) ? 2 : 0) |
      (IsInside(plane, normal, v2) ? 4 : 0);

  switch (inside_mask) {
    case 7: {
      // All in: return the input vertices
      face_verts unchanged;
      unchanged.push_back({v0, v1, v2});
      return unchanged;
    }
    // One vert out
    case 3: // v2 out
      return ClipTriByPlaneOneOut(plane, normal, v2, v0, v1);
    case 5: // v1 out
      return ClipTriByPlaneOneOut(plane, normal, v1, v0, v2);
    case 6: // v0 out
      return ClipTriByPlaneOneOut(plane, normal, v0, v1, v2);
    // Two verts out
    case 1: // only v0 in
      return ClipTriByPlaneTwoOut(plane, normal, v1, v2, v0);
    case 4: // only v2 in
      return ClipTriByPlaneTwoOut(plane, normal, v0, v1, v2);
    case 2: // only v1 in
      return ClipTriByPlaneTwoOut(plane, normal, v0, v2, v1);
    default:
      // All out (mask == 0): nothing is left of the triangle
      return face_verts{};
  }
}
// Test whether two triangle faces are coplanar, i.e. whether all three
// vertices of the second face lie (within kEpsilon) on the plane of the
// first face.
//
// Args
//    tri1, tri2: std::vector<vec3> of the vertex coordinates of
//        triangle faces
//
// Returns
//    bool: whether or not the two faces are coplanar
//
inline bool IsCoplanarFace(
    const std::vector<vec3<float>>& tri1,
    const std::vector<vec3<float>>& tri2) {
  // Plane of face 1: passes through tri1[0] with normal n1
  const vec3<float> n1 = FaceNormal(tri1[0], tri1[1], tri1[2]);

  for (int i = 0; i < 3; ++i) {
    // Distance of the i-th vertex of face 2 from the plane of face 1
    const float d = std::abs(dot(tri2[i] - tri1[0], n1));
    if (!(d < kEpsilon)) {
      return false;
    }
  }
  return true;
}
// Get the triangles from one box which are part of the
// intersecting polyhedron by clipping them against each of the
// planes of the other box in turn.
//
// Args
//    tris: vertex coordinates of all the triangle faces
//        in the box
//    planes: vertex coordinates of all the planes in the other box;
//        every entry of the vector is used
//    center: vec3 coordinates of the center of the box from which
//        the planes originate
//
// Returns
//    std::vector<std::vector<vec3>>> vector of vertex coordinates
//        of the new triangle faces formed after clipping.
//        All triangles are now "inside" the planes.
//
inline face_verts BoxIntersections(
    const face_verts& tris,
    const face_verts& planes,
    const vec3<float>& center) {
  // Work on a copy to avoid modifying the input in place
  face_verts out_tris = tris;

  for (const auto& plane : planes) {
    // Once every triangle has been clipped away, the remaining planes
    // cannot bring any back.
    if (out_tris.empty()) {
      break;
    }

    // Normal of the plane, pointing towards the box center
    const vec3<float> n2 = PlaneNormalDirection(plane, center);

    // Triangles that survive this plane. Clipping one triangle yields
    // at most two triangles.
    face_verts clipped;
    clipped.reserve(2 * out_tris.size());
    for (const auto& tri : out_tris) {
      const face_verts pieces = ClipTriByPlane(plane, tri, n2);
      clipped.insert(clipped.end(), pieces.begin(), pieces.end());
    }

    // Hand the survivors to the next plane without copying them
    out_tris.swap(clipped);
  }
  return out_tris;
}

View File

@ -0,0 +1,64 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Tuple
import torch
from pytorch3d import _C
from torch.autograd import Function
class _box3d_overlap(Function):
    """
    Torch autograd Function that forwards to the compiled `iou_box3d`
    operator from `pytorch3d._C`.
    Backward is not supported.
    """

    @staticmethod
    def forward(ctx, boxes1, boxes2):
        """
        Argument definitions are the same as in the box3d_overlap function.
        """
        outputs = _C.iou_box3d(boxes1, boxes2)
        vol, iou = outputs
        return vol, iou

    @staticmethod
    def backward(ctx, grad_vol, grad_iou):
        # No gradient is implemented for the intersection computation.
        raise ValueError("box3d_overlap backward is not supported")
def box3d_overlap(
    boxes1: torch.Tensor, boxes2: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Computes the intersection of 3D boxes1 and boxes2.

    Inputs boxes1, boxes2 are tensors of shape (B, 8, 3)
    (where B doesn't have to be the same for boxes1 and boxes2),
    containing the 8 corners of the boxes, as follows:

        (4) +---------+. (5)
            | ` .     |  ` .
            | (0) +---+-----+ (1)
            |     |   |     |
        (7) +-----+---+. (6)|
            ` .   |     ` . |
            (3) ` +---------+ (2)

    Args:
        boxes1: tensor of shape (N, 8, 3) of the coordinates of the 1st boxes
        boxes2: tensor of shape (M, 8, 3) of the coordinates of the 2nd boxes
    Returns:
        vol: (N, M) tensor of the volume of the intersecting convex shapes
        iou: (N, M) tensor of the intersection over union which is
            defined as: `iou = vol / (vol1 + vol2 - vol)`
    Raises:
        ValueError: if the trailing dimensions of either input are not (8, 3)
    """
    for boxes in (boxes1, boxes2):
        if boxes.shape[1:] != (8, 3):
            raise ValueError("Each box in the batch must be of shape (8, 3)")

    # pyre-fixme[16]: `_box3d_overlap` has no attribute `apply`.
    result = _box3d_overlap.apply(boxes1, boxes2)
    vol, iou = result
    return vol, iou

37
tests/bm_iou_box3d.py Normal file
View File

@ -0,0 +1,37 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from itertools import product
from fvcore.common.benchmark import benchmark
from test_iou_box3d import TestIoU3D
def bm_iou_box3d() -> None:
    """
    Benchmark the three 3D IoU variants provided by TestIoU3D: the naive
    implementation, the compiled implementation (on CPU) and the
    sampling-based approximation.
    """
    N = [1, 4, 8, 16]
    num_samples = [2000, 5000, 10000, 20000]

    # Naive implementation, over every (N, M) batch size combination
    kwargs_list = []
    test_cases = product(N, N)
    for case in test_cases:
        n, m = case
        kwargs_list.append({"N": n, "M": m, "device": "cuda:0"})
    benchmark(TestIoU3D.iou_naive, "3D_IOU_NAIVE", kwargs_list, warmup_iters=1)

    # Compiled implementation: same batch sizes, but on the CPU
    for kwargs in kwargs_list:
        kwargs["device"] = "cpu"
    benchmark(TestIoU3D.iou, "3D_IOU", kwargs_list, warmup_iters=1)

    # Sampling approximation, for a few batch sizes and sample counts
    kwargs_list = []
    test_cases = product([1, 4], [1, 4], num_samples)
    for case in test_cases:
        n, m, s = case
        kwargs_list.append({"N": n, "M": m, "num_samples": s})
    benchmark(TestIoU3D.iou_sampling, "3D_IOU_SAMPLING", kwargs_list, warmup_iters=1)
if __name__ == "__main__":
bm_iou_box3d()

View File

@ -4,7 +4,6 @@
# This source code is licensed under the BSD-style license found in the # This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
import random import random
import unittest import unittest
from typing import List, Tuple, Union from typing import List, Tuple, Union
@ -13,6 +12,8 @@ import torch
import torch.nn.functional as F import torch.nn.functional as F
from common_testing import TestCaseMixin from common_testing import TestCaseMixin
from pytorch3d.io import save_obj from pytorch3d.io import save_obj
from pytorch3d.ops.iou_box3d import box3d_overlap
from pytorch3d.transforms.rotation_conversions import random_rotation from pytorch3d.transforms.rotation_conversions import random_rotation
@ -21,7 +22,8 @@ class TestIoU3D(TestCaseMixin, unittest.TestCase):
super().setUp() super().setUp()
torch.manual_seed(1) torch.manual_seed(1)
def create_box(self, xyz, whl): @staticmethod
def create_box(xyz, whl):
x, y, z = xyz x, y, z = xyz
w, h, le = whl w, h, le = whl
@ -41,8 +43,39 @@ class TestIoU3D(TestCaseMixin, unittest.TestCase):
) )
return verts return verts
def test_iou(self): @staticmethod
device = torch.device("cuda:0") def _box3d_overlap_naive_batched(boxes1, boxes2):
"""
Wrapper around box3d_overlap_naive to support
batched input
"""
N = boxes1.shape[0]
M = boxes2.shape[0]
vols = torch.zeros((N, M), dtype=torch.float32, device=boxes1.device)
ious = torch.zeros((N, M), dtype=torch.float32, device=boxes1.device)
for n in range(N):
for m in range(M):
vol, iou = box3d_overlap_naive(boxes1[n], boxes2[m])
vols[n, m] = vol
ious[n, m] = iou
return vols, ious
@staticmethod
def _box3d_overlap_sampling_batched(boxes1, boxes2, num_samples: int):
    """
    Wrapper around box3d_overlap_sampling to support
    batched input

    Args:
        boxes1: tensor whose first dimension indexes the N 1st boxes
        boxes2: tensor whose first dimension indexes the M 2nd boxes
        num_samples: number of samples forwarded to box3d_overlap_sampling
            for every pair of boxes
    Returns:
        ious: (N, M) float32 tensor on the device of boxes1
    """
    N = boxes1.shape[0]
    M = boxes2.shape[0]
    ious = torch.zeros((N, M), dtype=torch.float32, device=boxes1.device)
    for n in range(N):
        for m in range(M):
            # Forward num_samples; otherwise the callee silently falls
            # back to its own default and the argument has no effect.
            iou = box3d_overlap_sampling(
                boxes1[n], boxes2[m], num_samples=num_samples
            )
            ious[n, m] = iou
    return ious
def _test_iou(self, overlap_fn, device):
box1 = torch.tensor( box1 = torch.tensor(
[ [
@ -60,30 +93,36 @@ class TestIoU3D(TestCaseMixin, unittest.TestCase):
) )
# 1st test: same box, iou = 1.0 # 1st test: same box, iou = 1.0
vol, iou = box3d_overlap(box1, box1) vol, iou = overlap_fn(box1[None], box1[None])
self.assertClose(vol, torch.tensor(1.0, device=vol.device, dtype=vol.dtype)) self.assertClose(vol, torch.tensor([[1.0]], device=vol.device, dtype=vol.dtype))
self.assertClose(iou, torch.tensor(1.0, device=vol.device, dtype=vol.dtype)) self.assertClose(iou, torch.tensor([[1.0]], device=vol.device, dtype=vol.dtype))
# 2nd test # 2nd test
dd = random.random() dd = random.random()
box2 = box1 + torch.tensor([[0.0, dd, 0.0]], device=device) box2 = box1 + torch.tensor([[0.0, dd, 0.0]], device=device)
vol, iou = box3d_overlap(box1, box2) vol, iou = overlap_fn(box1[None], box2[None])
self.assertClose(vol, torch.tensor(1 - dd, device=vol.device, dtype=vol.dtype)) self.assertClose(
vol, torch.tensor([[1 - dd]], device=vol.device, dtype=vol.dtype)
)
# 3rd test # 3rd test
dd = random.random() dd = random.random()
box2 = box1 + torch.tensor([[dd, 0.0, 0.0]], device=device) box2 = box1 + torch.tensor([[dd, 0.0, 0.0]], device=device)
vol, _ = box3d_overlap(box1, box2) vol, _ = overlap_fn(box1[None], box2[None])
self.assertClose(vol, torch.tensor(1 - dd, device=vol.device, dtype=vol.dtype)) self.assertClose(
vol, torch.tensor([[1 - dd]], device=vol.device, dtype=vol.dtype)
)
# 4th test # 4th test
ddx, ddy, ddz = random.random(), random.random(), random.random() ddx, ddy, ddz = random.random(), random.random(), random.random()
box2 = box1 + torch.tensor([[ddx, ddy, ddz]], device=device) box2 = box1 + torch.tensor([[ddx, ddy, ddz]], device=device)
vol, _ = box3d_overlap(box1, box2) vol, _ = overlap_fn(box1[None], box2[None])
self.assertClose( self.assertClose(
vol, vol,
torch.tensor( torch.tensor(
(1 - ddx) * (1 - ddy) * (1 - ddz), device=vol.device, dtype=vol.dtype [[(1 - ddx) * (1 - ddy) * (1 - ddz)]],
device=vol.device,
dtype=vol.dtype,
), ),
) )
@ -93,11 +132,13 @@ class TestIoU3D(TestCaseMixin, unittest.TestCase):
RR = random_rotation(dtype=torch.float32, device=device) RR = random_rotation(dtype=torch.float32, device=device)
box1r = box1 @ RR.transpose(0, 1) box1r = box1 @ RR.transpose(0, 1)
box2r = box2 @ RR.transpose(0, 1) box2r = box2 @ RR.transpose(0, 1)
vol, _ = box3d_overlap(box1r, box2r) vol, _ = overlap_fn(box1r[None], box2r[None])
self.assertClose( self.assertClose(
vol, vol,
torch.tensor( torch.tensor(
(1 - ddx) * (1 - ddy) * (1 - ddz), device=vol.device, dtype=vol.dtype [[(1 - ddx) * (1 - ddy) * (1 - ddz)]],
device=vol.device,
dtype=vol.dtype,
), ),
) )
@ -108,11 +149,13 @@ class TestIoU3D(TestCaseMixin, unittest.TestCase):
TT = torch.rand((1, 3), dtype=torch.float32, device=device) TT = torch.rand((1, 3), dtype=torch.float32, device=device)
box1r = box1 @ RR.transpose(0, 1) + TT box1r = box1 @ RR.transpose(0, 1) + TT
box2r = box2 @ RR.transpose(0, 1) + TT box2r = box2 @ RR.transpose(0, 1) + TT
vol, _ = box3d_overlap(box1r, box2r) vol, _ = overlap_fn(box1r[None], box2r[None])
self.assertClose( self.assertClose(
vol, vol,
torch.tensor( torch.tensor(
(1 - ddx) * (1 - ddy) * (1 - ddz), device=vol.device, dtype=vol.dtype [[(1 - ddx) * (1 - ddy) * (1 - ddz)]],
device=vol.device,
dtype=vol.dtype,
), ),
) )
@ -135,7 +178,7 @@ class TestIoU3D(TestCaseMixin, unittest.TestCase):
[-2.8789, 6.0142, 0.7549], [-2.8789, 6.0142, 0.7549],
[-4.3586, 3.5345, -1.1831], [-4.3586, 3.5345, -1.1831],
], ],
device="cuda:0", device=device,
) )
box2r = torch.tensor( box2r = torch.tensor(
[ [
@ -148,7 +191,7 @@ class TestIoU3D(TestCaseMixin, unittest.TestCase):
[0.4328, -5.3761, -3.5436], [0.4328, -5.3761, -3.5436],
[-2.3633, -5.6305, -1.2893], [-2.3633, -5.6305, -1.2893],
], ],
device="cuda:0", device=device,
) )
# from Meshlab: # from Meshlab:
vol_inters = 33.558529 vol_inters = 33.558529
@ -156,9 +199,9 @@ class TestIoU3D(TestCaseMixin, unittest.TestCase):
vol_box2 = 156.386719 vol_box2 = 156.386719
iou_mesh = vol_inters / (vol_box1 + vol_box2 - vol_inters) iou_mesh = vol_inters / (vol_box1 + vol_box2 - vol_inters)
vol, iou = box3d_overlap(box1r, box2r) vol, iou = overlap_fn(box1r[None], box2r[None])
self.assertClose(vol, torch.tensor(vol_inters, device=device), atol=1e-1) self.assertClose(vol, torch.tensor([[vol_inters]], device=device), atol=1e-1)
self.assertClose(iou, torch.tensor(iou_mesh, device=device), atol=1e-1) self.assertClose(iou, torch.tensor([[iou_mesh]], device=device), atol=1e-1)
# 8th test: compare with sampling # 8th test: compare with sampling
# create box1 # create box1
@ -173,16 +216,47 @@ class TestIoU3D(TestCaseMixin, unittest.TestCase):
TT2 = torch.rand((1, 3), dtype=torch.float32, device=device) TT2 = torch.rand((1, 3), dtype=torch.float32, device=device)
box1r = box1 @ RR1.transpose(0, 1) + TT1 box1r = box1 @ RR1.transpose(0, 1) + TT1
box2r = box2 @ RR2.transpose(0, 1) + TT2 box2r = box2 @ RR2.transpose(0, 1) + TT2
vol, iou = box3d_overlap(box1r, box2r) vol, iou = overlap_fn(box1r[None], box2r[None])
iou_sampling = box3d_overlap_sampling(box1r, box2r, num_samples=10000) iou_sampling = self._box3d_overlap_sampling_batched(
box1r[None], box2r[None], num_samples=10000
)
self.assertClose(iou, iou_sampling, atol=1e-2) self.assertClose(iou, iou_sampling, atol=1e-2)
# 9th test: non overlapping boxes, iou = 0.0 # 9th test: non overlapping boxes, iou = 0.0
box2 = box1 + torch.tensor([[0.0, 100.0, 0.0]], device=device) box2 = box1 + torch.tensor([[0.0, 100.0, 0.0]], device=device)
vol, iou = box3d_overlap(box1, box2) vol, iou = overlap_fn(box1[None], box2[None])
self.assertClose(vol, torch.tensor(0.0, device=vol.device, dtype=vol.dtype)) self.assertClose(vol, torch.tensor([[0.0]], device=vol.device, dtype=vol.dtype))
self.assertClose(iou, torch.tensor(0.0, device=vol.device, dtype=vol.dtype)) self.assertClose(iou, torch.tensor([[0.0]], device=vol.device, dtype=vol.dtype))
def test_iou_naive(self):
    """Run the shared IoU checks against the naive implementation on cuda:0."""
    self._test_iou(self._box3d_overlap_naive_batched, torch.device("cuda:0"))
def test_iou_cpu(self):
    """Run the shared IoU checks against box3d_overlap on the CPU."""
    self._test_iou(box3d_overlap, torch.device("cpu"))
def test_cpu_vs_naive_batched(self):
    """
    box3d_overlap and the batched naive implementation must agree, in
    output shape and in values, on random CPU inputs.
    """
    N, M = 3, 6
    device = "cpu"
    boxes1 = torch.randn((N, 8, 3), device=device)
    boxes2 = torch.randn((M, 8, 3), device=device)
    vol_naive, iou_naive = self._box3d_overlap_naive_batched(boxes1, boxes2)
    vol_cpp, iou_cpp = box3d_overlap(boxes1, boxes2)

    # check shape
    expected_shape = (N, M)
    for result in (vol_naive, vol_cpp, iou_naive, iou_cpp):
        self.assertClose(result.shape, expected_shape)

    # check values
    self.assertClose(vol_naive, vol_cpp)
    self.assertClose(iou_naive, iou_cpp)
def test_batched_errors(self):
    """
    box3d_overlap must raise a ValueError mentioning the expected
    (8, 3) box shape when an input has the wrong number of corners.
    """
    N, M = 5, 10
    boxes1 = torch.randn((N, 8, 3))
    boxes2 = torch.randn((M, 10, 3))
    # The parentheses are escaped so that they are matched literally
    # instead of being interpreted as a regex group.
    with self.assertRaisesRegex(ValueError, r"\(8, 3\)"):
        box3d_overlap(boxes1, boxes2)
def test_box_volume(self): def test_box_volume(self):
device = torch.device("cuda:0") device = torch.device("cuda:0")
@ -277,6 +351,36 @@ class TestIoU3D(TestCaseMixin, unittest.TestCase):
self.assertClose(box_planar_dir(box1), n1) self.assertClose(box_planar_dir(box1), n1)
self.assertClose(box_planar_dir(box2), n2) self.assertClose(box_planar_dir(box2), n2)
@staticmethod
def iou_naive(N: int, M: int, device="cpu"):
    """
    Benchmark factory for the batched naive implementation.

    Args:
        N: number of boxes in the 1st batch
        M: number of boxes in the 2nd batch
        device: device on which the random input boxes are created
    Returns:
        A zero-argument callable that runs the naive implementation once
        on the prepared inputs.
    """
    # Honor the requested device; it was previously accepted but ignored.
    boxes1 = torch.randn((N, 8, 3), device=device)
    boxes2 = torch.randn((M, 8, 3), device=device)

    def output():
        vol, iou = TestIoU3D._box3d_overlap_naive_batched(boxes1, boxes2)

    return output
@staticmethod
def iou(N: int, M: int, device="cpu"):
    """
    Benchmark factory for box3d_overlap.

    Creates random (N, 8, 3) and (M, 8, 3) inputs on `device` and returns
    a zero-argument callable that runs box3d_overlap once on them.
    """
    first_batch = torch.randn((N, 8, 3), device=device)
    second_batch = torch.randn((M, 8, 3), device=device)

    def output():
        _vol, _iou = box3d_overlap(first_batch, second_batch)

    return output
@staticmethod
def iou_sampling(N: int, M: int, num_samples: int):
    """
    Benchmark factory for the batched sampling-based IoU.

    Creates random (N, 8, 3) and (M, 8, 3) inputs and returns a
    zero-argument callable that runs the sampling wrapper once on them
    with the given `num_samples`.
    """
    first_batch = torch.randn((N, 8, 3))
    second_batch = torch.randn((M, 8, 3))

    def output():
        _ = TestIoU3D._box3d_overlap_sampling_batched(
            first_batch, second_batch, num_samples
        )

    return output
# -------------------------------------------------- # # -------------------------------------------------- #
# NAIVE IMPLEMENTATION # # NAIVE IMPLEMENTATION #
@ -284,7 +388,7 @@ class TestIoU3D(TestCaseMixin, unittest.TestCase):
""" """
The main functions below are: The main functions below are:
* box3d_overlap: which computes the exact IoU of box1 and box2 * box3d_overlap_naive: which computes the exact IoU of box1 and box2
* box3d_overlap_sampling: which computes an approximate IoU of box1 and box2 * box3d_overlap_sampling: which computes an approximate IoU of box1 and box2
by sampling points within the boxes by sampling points within the boxes
@ -738,7 +842,7 @@ def clip_tri_by_plane(plane, n, tri_verts) -> Union[List, torch.Tensor]:
# -------------------------------------------------- # # -------------------------------------------------- #
def box3d_overlap(box1: torch.Tensor, box2: torch.Tensor): def box3d_overlap_naive(box1: torch.Tensor, box2: torch.Tensor):
""" """
Computes the intersection of 3D boxes1 and boxes2. Computes the intersection of 3D boxes1 and boxes2.
Inputs boxes1, boxes2 are tensors of shape (8, 3) containing Inputs boxes1, boxes2 are tensors of shape (8, 3) containing