mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
C++ IoU for 3D Boxes
Summary: C++ Implementation of algorithm to compute 3D bounding boxes for batches of bboxes of shape (N, 8, 3) and (M, 8, 3). Reviewed By: gkioxari Differential Revision: D30905190 fbshipit-source-id: 02e2cf025cd4fa3ff706ce5cf9b82c0fb5443f96
This commit is contained in:
parent
2293f1fed0
commit
53266ec9ff
@ -20,6 +20,7 @@
|
||||
#include "face_areas_normals/face_areas_normals.h"
|
||||
#include "gather_scatter/gather_scatter.h"
|
||||
#include "interp_face_attrs/interp_face_attrs.h"
|
||||
#include "iou_box3d/iou_box3d.h"
|
||||
#include "knn/knn.h"
|
||||
#include "mesh_normal_consistency/mesh_normal_consistency.h"
|
||||
#include "packed_to_padded_tensor/packed_to_padded_tensor.h"
|
||||
@ -87,6 +88,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
// Sample PDF
|
||||
m.def("sample_pdf", &SamplePdf);
|
||||
|
||||
// 3D IoU
|
||||
m.def("iou_box3d", &IoUBox3D);
|
||||
|
||||
// Pulsar.
|
||||
#ifdef PULSAR_LOGGING_ENABLED
|
||||
c10::ShowLogInfoToStderr();
|
||||
|
37
pytorch3d/csrc/iou_box3d/iou_box3d.h
Normal file
37
pytorch3d/csrc/iou_box3d/iou_box3d.h
Normal file
@ -0,0 +1,37 @@
|
||||
/*
|
||||
* Copyright (c) Facebook, Inc. and its affiliates.
|
||||
* All rights reserved.
|
||||
*
|
||||
* This source code is licensed under the BSD-style license found in the
|
||||
* LICENSE file in the root directory of this source tree.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
#include <torch/extension.h>
|
||||
#include <tuple>
|
||||
#include "utils/pytorch3d_cutils.h"
|
||||
|
||||
// Calculate the intersection volume and IoU metric for two batches of boxes
|
||||
//
|
||||
// Args:
|
||||
// boxes1: tensor of shape (N, 8, 3) of the coordinates of the 1st boxes
|
||||
// boxes2: tensor of shape (M, 8, 3) of the coordinates of the 2nd boxes
|
||||
// Returns:
|
||||
// vol: (N, M) tensor of the volume of the intersecting convex shapes
|
||||
// iou: (N, M) tensor of the intersection over union which is
|
||||
// defined as: `iou = vol / (vol1 + vol2 - vol)`
|
||||
|
||||
// CPU implementation
|
||||
std::tuple<at::Tensor, at::Tensor> IoUBox3DCpu(
|
||||
const at::Tensor& boxes1,
|
||||
const at::Tensor& boxes2);
|
||||
|
||||
// Implementation which is exposed
|
||||
inline std::tuple<at::Tensor, at::Tensor> IoUBox3D(
|
||||
const at::Tensor& boxes1,
|
||||
const at::Tensor& boxes2) {
|
||||
if (boxes1.is_cuda() || boxes2.is_cuda()) {
|
||||
AT_ERROR("GPU support not implemented");
|
||||
}
|
||||
return IoUBox3DCpu(boxes1.contiguous(), boxes2.contiguous());
|
||||
}
|
121
pytorch3d/csrc/iou_box3d/iou_box3d_cpu.cpp
Normal file
121
pytorch3d/csrc/iou_box3d/iou_box3d_cpu.cpp
Normal file
@ -0,0 +1,121 @@
|
||||
/*
|
||||
* Copyright (c) Facebook, Inc. and its affiliates.
|
||||
* All rights reserved.
|
||||
*
|
||||
* This source code is licensed under the BSD-style license found in the
|
||||
* LICENSE file in the root directory of this source tree.
|
||||
*/
|
||||
|
||||
#include <torch/extension.h>
|
||||
#include <torch/torch.h>
|
||||
#include <list>
|
||||
#include <numeric>
|
||||
#include <queue>
|
||||
#include <tuple>
|
||||
#include "iou_box3d/iou_utils.h"
|
||||
|
||||
std::tuple<at::Tensor, at::Tensor> IoUBox3DCpu(
|
||||
const at::Tensor& boxes1,
|
||||
const at::Tensor& boxes2) {
|
||||
const int N = boxes1.size(0);
|
||||
const int M = boxes2.size(0);
|
||||
auto float_opts = boxes1.options().dtype(torch::kFloat32);
|
||||
torch::Tensor vols = torch::zeros({N, M}, float_opts);
|
||||
torch::Tensor ious = torch::zeros({N, M}, float_opts);
|
||||
|
||||
// Create tensor accessors
|
||||
auto boxes1_a = boxes1.accessor<float, 3>();
|
||||
auto boxes2_a = boxes2.accessor<float, 3>();
|
||||
auto vols_a = vols.accessor<float, 2>();
|
||||
auto ious_a = ious.accessor<float, 2>();
|
||||
|
||||
// Iterate through the N boxes in boxes1
|
||||
for (int n = 0; n < N; ++n) {
|
||||
const auto& box1 = boxes1_a[n];
|
||||
// Convert to vector of face vertices i.e. effectively (F, 3, 3)
|
||||
// face_verts is a data type defined in iou_utils.h
|
||||
const face_verts box1_tris = GetBoxTris(box1);
|
||||
|
||||
// Calculate the position of the center of the box which is used in
|
||||
// several calculations. This requires a tensor as input.
|
||||
const vec3<float> box1_center = BoxCenter(boxes1[n]);
|
||||
|
||||
// Convert to vector of face vertices i.e. effectively (P, 4, 3)
|
||||
const face_verts box1_planes = GetBoxPlanes(box1);
|
||||
|
||||
// Get Box Volumes
|
||||
const float box1_vol = BoxVolume(box1_tris, box1_center);
|
||||
|
||||
// Iterate through the M boxes in boxes2
|
||||
for (int m = 0; m < M; ++m) {
|
||||
// Repeat above steps for box2
|
||||
// TODO: check if caching these value helps performance.
|
||||
const auto& box2 = boxes2_a[m];
|
||||
const face_verts box2_tris = GetBoxTris(box2);
|
||||
const vec3<float> box2_center = BoxCenter(boxes2[m]);
|
||||
const face_verts box2_planes = GetBoxPlanes(box2);
|
||||
const float box2_vol = BoxVolume(box2_tris, box2_center);
|
||||
|
||||
// Every triangle in one box will be compared to each plane in the other
|
||||
// box. There are 3 possible outcomes:
|
||||
// 1. If the triangle is fully inside, then it will
|
||||
// remain as is.
|
||||
// 2. If the triagnle it is fully outside, it will be removed.
|
||||
// 3. If the triangle intersects with the (infinite) plane, it
|
||||
// will be broken into subtriangles such that each subtriangle is full
|
||||
// inside the plane and part of the intersecting tetrahedron.
|
||||
|
||||
// Tris in Box1 -> Planes in Box2
|
||||
face_verts box1_intersect =
|
||||
BoxIntersections(box1_tris, box2_planes, box2_center);
|
||||
// Tris in Box2 -> Planes in Box1
|
||||
face_verts box2_intersect =
|
||||
BoxIntersections(box2_tris, box1_planes, box1_center);
|
||||
|
||||
// If there are overlapping regions in Box2, remove any coplanar faces
|
||||
if (box2_intersect.size() > 0) {
|
||||
// Identify if any triangles in Box2 are coplanar with Box1
|
||||
std::vector<int> tri2_keep(box2_intersect.size());
|
||||
std::fill(tri2_keep.begin(), tri2_keep.end(), 1);
|
||||
for (int b1 = 0; b1 < box1_intersect.size(); ++b1) {
|
||||
for (int b2 = 0; b2 < box2_intersect.size(); ++b2) {
|
||||
bool is_coplanar =
|
||||
IsCoplanarFace(box1_intersect[b1], box2_intersect[b2]);
|
||||
if (is_coplanar) {
|
||||
tri2_keep[b2] = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Keep only the non coplanar triangles in Box2 - add them to the
|
||||
// Box1 triangles.
|
||||
for (int b2 = 0; b2 < box2_intersect.size(); ++b2) {
|
||||
if (tri2_keep[b2] == 1) {
|
||||
box1_intersect.push_back((box2_intersect[b2]));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Initialize the vol and iou to 0.0 in case there are no triangles
|
||||
// in the intersecting shape.
|
||||
float vol = 0.0;
|
||||
float iou = 0.0;
|
||||
|
||||
// If there are triangles in the intersecting shape
|
||||
if (box1_intersect.size() > 0) {
|
||||
// The intersecting shape is a polyhedron made up of the
|
||||
// triangular faces that are all now in box1_intersect.
|
||||
// Calculate the polyhedron center
|
||||
const vec3<float> polyhedron_center = PolyhedronCenter(box1_intersect);
|
||||
// Compute intersecting polyhedron volume
|
||||
vol = BoxVolume(box1_intersect, polyhedron_center);
|
||||
// Compute IoU
|
||||
iou = vol / (box1_vol + box2_vol - vol);
|
||||
}
|
||||
// Save out volume and IoU
|
||||
vols_a[n][m] = vol;
|
||||
ious_a[n][m] = iou;
|
||||
}
|
||||
}
|
||||
return std::make_tuple(vols, ious);
|
||||
}
|
531
pytorch3d/csrc/iou_box3d/iou_utils.h
Normal file
531
pytorch3d/csrc/iou_box3d/iou_utils.h
Normal file
@ -0,0 +1,531 @@
|
||||
/*
|
||||
* Copyright (c) Facebook, Inc. and its affiliates.
|
||||
* All rights reserved.
|
||||
*
|
||||
* This source code is licensed under the BSD-style license found in the
|
||||
* LICENSE file in the root directory of this source tree.
|
||||
*/
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <assert.h>
|
||||
#include <torch/extension.h>
|
||||
#include <torch/torch.h>
|
||||
#include <algorithm>
|
||||
#include <list>
|
||||
#include <numeric>
|
||||
#include <queue>
|
||||
#include <tuple>
|
||||
#include <type_traits>
|
||||
#include "utils/geometry_utils.h"
|
||||
#include "utils/vec3.h"
|
||||
|
||||
/*
|
||||
_PLANES and _TRIS define the 4- and 3-connectivity
|
||||
of the 8 box corners.
|
||||
_PLANES gives the quad faces of the 3D box
|
||||
_TRIS gives the triangle faces of the 3D box
|
||||
*/
|
||||
const int NUM_PLANES = 6;
|
||||
const int NUM_TRIS = 12;
|
||||
const int _PLANES[6][4] = {
|
||||
{0, 1, 2, 3},
|
||||
{3, 2, 6, 7},
|
||||
{0, 1, 5, 4},
|
||||
{0, 3, 7, 4},
|
||||
{1, 5, 6, 2},
|
||||
{4, 5, 6, 7},
|
||||
};
|
||||
const int _TRIS[12][3] = {
|
||||
{0, 1, 2},
|
||||
{0, 3, 2},
|
||||
{4, 5, 6},
|
||||
{4, 6, 7},
|
||||
{1, 5, 6},
|
||||
{1, 6, 2},
|
||||
{0, 4, 7},
|
||||
{0, 7, 3},
|
||||
{3, 2, 6},
|
||||
{3, 6, 7},
|
||||
{0, 1, 5},
|
||||
{0, 4, 5},
|
||||
};
|
||||
|
||||
// Create a new data type for representing the
|
||||
// verts for each face which can be triangle or plane.
|
||||
// This helps make the code more readable.
|
||||
using face_verts = std::vector<std::vector<vec3<float>>>;
|
||||
|
||||
// Args
|
||||
// box: (8, 3) tensor accessor for the box vertices
|
||||
// plane_idx: index of the plane in the box
|
||||
// vert_idx: index of the vertex in the plane
|
||||
//
|
||||
// Returns
|
||||
// vec3<T> (x, y, x) vertex coordinates
|
||||
//
|
||||
template <typename Box>
|
||||
inline vec3<float>
|
||||
ExtractVertsPlane(const Box& box, const int plane_idx, const int vert_idx) {
|
||||
return vec3<float>(
|
||||
box[_PLANES[plane_idx][vert_idx]][0],
|
||||
box[_PLANES[plane_idx][vert_idx]][1],
|
||||
box[_PLANES[plane_idx][vert_idx]][2]);
|
||||
}
|
||||
|
||||
// Args
|
||||
// box: (8, 3) tensor accessor for the box vertices
|
||||
// tri_idx: index of the triangle face in the box
|
||||
// vert_idx: index of the vertex in the triangle
|
||||
//
|
||||
// Returns
|
||||
// vec3<T> (x, y, x) vertex coordinates
|
||||
//
|
||||
template <typename Box>
|
||||
inline vec3<float>
|
||||
ExtractVertsTri(const Box& box, const int tri_idx, const int vert_idx) {
|
||||
return vec3<float>(
|
||||
box[_TRIS[tri_idx][vert_idx]][0],
|
||||
box[_TRIS[tri_idx][vert_idx]][1],
|
||||
box[_TRIS[tri_idx][vert_idx]][2]);
|
||||
}
|
||||
|
||||
// Args
|
||||
// box: (8, 3) tensor accessor for the box vertices
|
||||
//
|
||||
// Returns
|
||||
// std::vector<std::vector<vec3<T>>> effectively (F, 3, 3)
|
||||
// coordinates of the verts for each face
|
||||
//
|
||||
template <typename Box>
|
||||
inline face_verts GetBoxTris(const Box& box) {
|
||||
face_verts box_tris;
|
||||
for (int t = 0; t < NUM_TRIS; ++t) {
|
||||
vec3<float> v0 = ExtractVertsTri(box, t, 0);
|
||||
vec3<float> v1 = ExtractVertsTri(box, t, 1);
|
||||
vec3<float> v2 = ExtractVertsTri(box, t, 2);
|
||||
box_tris.push_back({v0, v1, v2});
|
||||
}
|
||||
return box_tris;
|
||||
}
|
||||
|
||||
// Args
|
||||
// box: (8, 3) tensor accessor for the box vertices
|
||||
//
|
||||
// Returns
|
||||
// std::vector<std::vector<vec3<T>>> effectively (P, 3, 3)
|
||||
// coordinates of the 4 verts for each plane
|
||||
//
|
||||
template <typename Box>
|
||||
inline face_verts GetBoxPlanes(const Box& box) {
|
||||
face_verts box_planes;
|
||||
for (int t = 0; t < NUM_PLANES; ++t) {
|
||||
vec3<float> v0 = ExtractVertsPlane(box, t, 0);
|
||||
vec3<float> v1 = ExtractVertsPlane(box, t, 1);
|
||||
vec3<float> v2 = ExtractVertsPlane(box, t, 2);
|
||||
vec3<float> v3 = ExtractVertsPlane(box, t, 3);
|
||||
box_planes.push_back({v0, v1, v2, v3});
|
||||
}
|
||||
return box_planes;
|
||||
}
|
||||
|
||||
// The normal of the face defined by vertices (v0, v1, v2)
|
||||
// Define e0 to be the edge connecting (v1, v0)
|
||||
// Define e1 to be the edge connecting (v2, v0)
|
||||
// normal is the cross product of e0, e1
|
||||
//
|
||||
// Args
|
||||
// v0, v1, v2: vec3 coordinates of the vertices of the face
|
||||
//
|
||||
// Returns
|
||||
// vec3: normal for the face
|
||||
//
|
||||
inline vec3<float> FaceNormal(vec3<float> v0, vec3<float> v1, vec3<float> v2) {
|
||||
vec3<float> n = cross(v1 - v0, v2 - v0);
|
||||
n = n / std::fmaxf(norm(n), kEpsilon);
|
||||
return n;
|
||||
}
|
||||
|
||||
// The normal of a box plane defined by the verts in `plane` with
|
||||
// the centroid of the box given by `center`.
|
||||
// Args
|
||||
// plane: vec3 coordinates of the vertices of the plane
|
||||
// center: vec3 coordinates of the center of the box from
|
||||
// which the plane originated
|
||||
//
|
||||
// Returns
|
||||
// vec3: normal for the plane such that it points towards
|
||||
// the center of the box
|
||||
//
|
||||
inline vec3<float> PlaneNormalDirection(
|
||||
const std::vector<vec3<float>>& plane,
|
||||
const vec3<float>& center) {
|
||||
// Only need the first 3 verts of the plane
|
||||
const vec3<float> v0 = plane[0];
|
||||
const vec3<float> v1 = plane[1];
|
||||
const vec3<float> v2 = plane[2];
|
||||
|
||||
// We project the center on the plane defined by (v0, v1, v2)
|
||||
// We can write center = v0 + a * e0 + b * e1 + c * n
|
||||
// We know that <e0, n> = 0 and <e1, n> = 0 and
|
||||
// <a, b> is the dot product between a and b.
|
||||
// This means we can solve for c as:
|
||||
// c = <center - v0 - a * e0 - b * e1, n> = <center - v0, n>
|
||||
vec3<float> n = FaceNormal(v0, v1, v2);
|
||||
const float c = dot((center - v0), n);
|
||||
|
||||
// If c is negative, then we revert the direction of n such that n
|
||||
// points "inside"
|
||||
if (c < kEpsilon) {
|
||||
n = -1.0f * n;
|
||||
}
|
||||
|
||||
return n;
|
||||
}
|
||||
|
||||
// Calculate the volume of the box by summing the volume of
|
||||
// each of the tetrahedrons formed with a triangle face and
|
||||
// the box centroid.
|
||||
//
|
||||
// Args
|
||||
// box_tris: vector of vec3 coordinates of the vertices of each
|
||||
// of the triangles in the box
|
||||
// box_center: vec3 coordinates of the center of the box
|
||||
//
|
||||
// Returns
|
||||
// float: volume of the box
|
||||
//
|
||||
inline float BoxVolume(
|
||||
const face_verts& box_tris,
|
||||
const vec3<float>& box_center) {
|
||||
float box_vol = 0.0;
|
||||
// Iterate through each triange, calculate the area of the
|
||||
// tetrahedron formed with the box_center and sum them
|
||||
for (int t = 0; t < box_tris.size(); ++t) {
|
||||
// Subtract the center:
|
||||
const vec3<float> v0 = box_tris[t][0] - box_center;
|
||||
const vec3<float> v1 = box_tris[t][1] - box_center;
|
||||
const vec3<float> v2 = box_tris[t][2] - box_center;
|
||||
|
||||
// Compute the area
|
||||
const float area = dot(v0, cross(v1, v2));
|
||||
const float vol = std::abs(area) / 6.0;
|
||||
box_vol = box_vol + vol;
|
||||
}
|
||||
return box_vol;
|
||||
}
|
||||
|
||||
// Compute the box center as the mean of the verts
|
||||
//
|
||||
// Args
|
||||
// box_verts: (8, 3) tensor of the corner vertices of the box
|
||||
//
|
||||
// Returns
|
||||
// vec3: coordinates of the center of the box
|
||||
//
|
||||
inline vec3<float> BoxCenter(const at::Tensor& box_verts) {
|
||||
const auto& box_center_t = at::mean(box_verts, 0);
|
||||
const vec3<float> box_center(
|
||||
box_center_t[0].item<float>(),
|
||||
box_center_t[1].item<float>(),
|
||||
box_center_t[2].item<float>());
|
||||
return box_center;
|
||||
}
|
||||
|
||||
// Compute the polyhedron center as the mean of the face centers
|
||||
// of the triangle faces
|
||||
//
|
||||
// Args
|
||||
// tris: vector of vec3 coordinates of the
|
||||
// vertices of each of the triangles in the polyhedron
|
||||
//
|
||||
// Returns
|
||||
// vec3: coordinates of the center of the polyhedron
|
||||
//
|
||||
inline vec3<float> PolyhedronCenter(const face_verts& tris) {
|
||||
float x = 0.0;
|
||||
float y = 0.0;
|
||||
float z = 0.0;
|
||||
const int num_tris = tris.size();
|
||||
|
||||
// Find the center point of each face
|
||||
for (int t = 0; t < num_tris; ++t) {
|
||||
const vec3<float> v0 = tris[t][0];
|
||||
const vec3<float> v1 = tris[t][1];
|
||||
const vec3<float> v2 = tris[t][2];
|
||||
const float x_face = (v0.x + v1.x + v2.x) / 3.0;
|
||||
const float y_face = (v0.y + v1.y + v2.y) / 3.0;
|
||||
const float z_face = (v0.z + v1.z + v2.z) / 3.0;
|
||||
x = x + x_face;
|
||||
y = y + y_face;
|
||||
z = z + z_face;
|
||||
}
|
||||
|
||||
// Take the mean of the centers of all faces
|
||||
x = x / num_tris;
|
||||
y = y / num_tris;
|
||||
z = z / num_tris;
|
||||
|
||||
const vec3<float> center(x, y, z);
|
||||
return center;
|
||||
}
|
||||
|
||||
// Compute a boolean indicator for whether a point
|
||||
// is inside a plane, where inside refers to whether
|
||||
// or not the point has a component in the
|
||||
// normal direction of the plane.
|
||||
//
|
||||
// Args
|
||||
// plane: vector of vec3 coordinates of the
|
||||
// vertices of each of the triangles in the box
|
||||
// normal: vec3 of the direction of the plane normal
|
||||
// point: vec3 of the position of the point of interest
|
||||
//
|
||||
// Returns
|
||||
// bool: whether or not the point is inside the plane
|
||||
//
|
||||
inline bool IsInside(
|
||||
const std::vector<vec3<float>>& plane,
|
||||
const vec3<float>& normal,
|
||||
const vec3<float>& point) {
|
||||
// Get one vert of the plane
|
||||
const vec3<float> v0 = plane[0];
|
||||
|
||||
// Every point p can be written as p = v0 + a e0 + b e1 + c n
|
||||
// Solving for c:
|
||||
// c = (point - v0 - a * e0 - b * e1).dot(n)
|
||||
// We know that <e0, n> = 0 and <e1, n> = 0
|
||||
// So the calculation can be simplified as:
|
||||
const float c = dot((point - v0), normal);
|
||||
const bool inside = c > -1.0f * kEpsilon;
|
||||
return inside;
|
||||
}
|
||||
|
||||
// Find the point of intersection between a plane
|
||||
// and a line given by the end points (p0, p1)
|
||||
//
|
||||
// Args
|
||||
// plane: vector of vec3 coordinates of the
|
||||
// vertices of each of the triangles in the box
|
||||
// normal: vec3 of the direction of the plane normal
|
||||
// p0, p1: vec3 of the start and end point of the line
|
||||
//
|
||||
// Returns
|
||||
// vec3: position of the intersection point
|
||||
//
|
||||
inline vec3<float> PlaneEdgeIntersection(
|
||||
const std::vector<vec3<float>>& plane,
|
||||
const vec3<float>& normal,
|
||||
const vec3<float>& p0,
|
||||
const vec3<float>& p1) {
|
||||
// Get one vert of the plane
|
||||
const vec3<float> v0 = plane[0];
|
||||
|
||||
// The point of intersection can be parametrized
|
||||
// p = p0 + a (p1 - p0) where a in [0, 1]
|
||||
// We want to find a such that p is on plane
|
||||
// <p - v0, n> = 0
|
||||
const float top = dot(-1.0f * (p0 - v0), normal);
|
||||
const float bot = dot(p1 - p0, normal);
|
||||
const float a = top / bot;
|
||||
const vec3<float> p = p0 + a * (p1 - p0);
|
||||
return p;
|
||||
}
|
||||
|
||||
// Triangle is clipped into a quadrilateral
|
||||
// based on the intersection points with the plane.
|
||||
// Then the quadrilateral is divided into two triangles.
|
||||
//
|
||||
// Args
|
||||
// plane: vector of vec3 coordinates of the
|
||||
// vertices of each of the triangles in the box
|
||||
// normal: vec3 of the direction of the plane normal
|
||||
// vout: vec3 of the point in the triangle which is outside
|
||||
// the plane
|
||||
// vin1, vin2: vec3 of the points in the triangle which are
|
||||
// inside the plane
|
||||
//
|
||||
// Returns
|
||||
// std::vector<std::vector<vec3>>: vector of vertex coordinates
|
||||
// of the new triangle faces
|
||||
//
|
||||
inline face_verts ClipTriByPlaneOneOut(
|
||||
const std::vector<vec3<float>>& plane,
|
||||
const vec3<float>& normal,
|
||||
const vec3<float>& vout,
|
||||
const vec3<float>& vin1,
|
||||
const vec3<float>& vin2) {
|
||||
// point of intersection between plane and (vin1, vout)
|
||||
const vec3<float> pint1 = PlaneEdgeIntersection(plane, normal, vin1, vout);
|
||||
// point of intersection between plane and (vin2, vout)
|
||||
const vec3<float> pint2 = PlaneEdgeIntersection(plane, normal, vin2, vout);
|
||||
const face_verts face_verts = {{vin1, pint1, pint2}, {vin1, pint2, vin2}};
|
||||
return face_verts;
|
||||
}
|
||||
|
||||
// Triangle is clipped into a smaller triangle based
|
||||
// on the intersection points with the plane.
|
||||
//
|
||||
// Args
|
||||
// plane: vector of vec3 coordinates of the
|
||||
// vertices of each of the triangles in the box
|
||||
// normal: vec3 of the direction of the plane normal
|
||||
// vout1, vout2: vec3 of the points in the triangle which are
|
||||
// outside the plane
|
||||
// vin: vec3 of the point in the triangle which is inside
|
||||
// the plane
|
||||
// Returns
|
||||
// std::vector<std::vector<vec3>>: vector of vertex coordinates
|
||||
// of the new triangle face
|
||||
//
|
||||
inline face_verts ClipTriByPlaneTwoOut(
|
||||
const std::vector<vec3<float>>& plane,
|
||||
const vec3<float>& normal,
|
||||
const vec3<float>& vout1,
|
||||
const vec3<float>& vout2,
|
||||
const vec3<float>& vin) {
|
||||
// point of intersection between plane and (vin, vout1)
|
||||
const vec3<float> pint1 = PlaneEdgeIntersection(plane, normal, vin, vout1);
|
||||
// point of intersection between plane and (vin, vout2)
|
||||
const vec3<float> pint2 = PlaneEdgeIntersection(plane, normal, vin, vout2);
|
||||
const face_verts face_verts = {{vin, pint1, pint2}};
|
||||
return face_verts;
|
||||
}
|
||||
|
||||
// Clip the triangle faces so that they lie within the
|
||||
// plane, creating new triangle faces where necessary.
|
||||
//
|
||||
// Args
|
||||
// plane: vector of vec3 coordinates of the
|
||||
// vertices of each of the triangles in the box
|
||||
// tri: std:vector<vec3> of the vertex coordinates of the
|
||||
// triangle faces
|
||||
// normal: vec3 of the direction of the plane normal
|
||||
//
|
||||
// Returns
|
||||
// std::vector<std::vector<vec3>>: vector of vertex coordinates
|
||||
// of the new triangle faces formed after clipping.
|
||||
// All triangles are now "inside" the plane.
|
||||
//
|
||||
inline face_verts ClipTriByPlane(
|
||||
const std::vector<vec3<float>>& plane,
|
||||
const std::vector<vec3<float>>& tri,
|
||||
const vec3<float>& normal) {
|
||||
// Get Triangle vertices
|
||||
const vec3<float> v0 = tri[0];
|
||||
const vec3<float> v1 = tri[1];
|
||||
const vec3<float> v2 = tri[2];
|
||||
|
||||
// Check each of the triangle vertices to see if it is inside the plane
|
||||
const bool isin0 = IsInside(plane, normal, v0);
|
||||
const bool isin1 = IsInside(plane, normal, v1);
|
||||
const bool isin2 = IsInside(plane, normal, v2);
|
||||
|
||||
// All in
|
||||
if (isin0 && isin1 && isin2) {
|
||||
// Return input vertices
|
||||
face_verts tris = {{v0, v1, v2}};
|
||||
return tris;
|
||||
}
|
||||
|
||||
face_verts empty_tris = {};
|
||||
// All out
|
||||
if (!isin0 && !isin1 && !isin2) {
|
||||
return empty_tris;
|
||||
}
|
||||
|
||||
// One vert out
|
||||
if (isin0 && isin1 && !isin2) {
|
||||
return ClipTriByPlaneOneOut(plane, normal, v2, v0, v1);
|
||||
}
|
||||
if (isin0 && not isin1 && isin2) {
|
||||
return ClipTriByPlaneOneOut(plane, normal, v1, v0, v2);
|
||||
}
|
||||
if (not isin0 && isin1 && isin2) {
|
||||
return ClipTriByPlaneOneOut(plane, normal, v0, v1, v2);
|
||||
}
|
||||
|
||||
// Two verts out
|
||||
if (isin0 && !isin1 && !isin2) {
|
||||
return ClipTriByPlaneTwoOut(plane, normal, v1, v2, v0);
|
||||
}
|
||||
if (!isin0 && !isin1 && isin2) {
|
||||
return ClipTriByPlaneTwoOut(plane, normal, v0, v1, v2);
|
||||
}
|
||||
if (!isin0 && isin1 && !isin2) {
|
||||
return ClipTriByPlaneTwoOut(plane, normal, v0, v2, v1);
|
||||
}
|
||||
|
||||
// Else return empty (should not be reached)
|
||||
return empty_tris;
|
||||
}
|
||||
|
||||
// Compute a boolean indicator for whether or not two faces
|
||||
// are coplanar
|
||||
//
|
||||
// Args
|
||||
// tri1, tri2: std:vector<vec3> of the vertex coordinates of
|
||||
// triangle faces
|
||||
//
|
||||
// Returns
|
||||
// bool: whether or not the two faces are coplanar
|
||||
//
|
||||
inline bool IsCoplanarFace(
|
||||
const std::vector<vec3<float>>& tri1,
|
||||
const std::vector<vec3<float>>& tri2) {
|
||||
// Get verts for face 1
|
||||
const vec3<float> v0 = tri1[0];
|
||||
const vec3<float> v1 = tri1[1];
|
||||
const vec3<float> v2 = tri1[2];
|
||||
|
||||
const vec3<float> n1 = FaceNormal(v0, v1, v2);
|
||||
int coplanar_count = 0;
|
||||
for (int i = 0; i < 3; ++i) {
|
||||
float d = std::abs(dot(tri2[i] - v0, n1));
|
||||
if (d < kEpsilon) {
|
||||
coplanar_count = coplanar_count + 1;
|
||||
}
|
||||
}
|
||||
return (coplanar_count == 3);
|
||||
}
|
||||
|
||||
// Get the triangles from each box which are part of the
|
||||
// intersecting polyhedron by computing the intersection
|
||||
// points with each of the planes.
|
||||
//
|
||||
// Args
|
||||
// tris: vertex coordinates of all the triangle faces
|
||||
// in the box
|
||||
// planes: vertex coordinates of all the planes in the box
|
||||
// center: vec3 coordinates of the center of the box from which
|
||||
// the planes originate
|
||||
//
|
||||
// Returns
|
||||
// std::vector<std::vector<vec3>>> vector of vertex coordinates
|
||||
// of the new triangle faces formed after clipping.
|
||||
// All triangles are now "inside" the planes.
|
||||
//
|
||||
inline face_verts BoxIntersections(
|
||||
const face_verts& tris,
|
||||
const face_verts& planes,
|
||||
const vec3<float>& center) {
|
||||
// Create a new vector to avoid modifying in place
|
||||
face_verts out_tris = tris;
|
||||
for (int p = 0; p < NUM_PLANES; ++p) {
|
||||
// Get plane normal direction
|
||||
const vec3<float> n2 = PlaneNormalDirection(planes[p], center);
|
||||
// Iterate through triangles in tris
|
||||
// Create intermediate vector to store the updated tris
|
||||
face_verts tri_verts_updated;
|
||||
for (int t = 0; t < out_tris.size(); ++t) {
|
||||
// Clip tri by plane
|
||||
const face_verts tri_updated = ClipTriByPlane(planes[p], out_tris[t], n2);
|
||||
// Add to the tri_verts_updated output if not empty
|
||||
for (int v = 0; v < tri_updated.size(); ++v) {
|
||||
tri_verts_updated.push_back(tri_updated[v]);
|
||||
}
|
||||
}
|
||||
// Update the tris
|
||||
out_tris = tri_verts_updated;
|
||||
}
|
||||
return out_tris;
|
||||
}
|
64
pytorch3d/ops/iou_box3d.py
Normal file
64
pytorch3d/ops/iou_box3d.py
Normal file
@ -0,0 +1,64 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
from pytorch3d import _C
|
||||
from torch.autograd import Function
|
||||
|
||||
|
||||
class _box3d_overlap(Function):
|
||||
"""
|
||||
Torch autograd Function wrapper for box3d_overlap C++/CUDA implementations.
|
||||
Backward is not supported.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, boxes1, boxes2):
|
||||
"""
|
||||
Arguments defintions the same as in the box3d_overlap function
|
||||
"""
|
||||
vol, iou = _C.iou_box3d(boxes1, boxes2)
|
||||
return vol, iou
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_vol, grad_iou):
|
||||
raise ValueError("box3d_overlap backward is not supported")
|
||||
|
||||
|
||||
def box3d_overlap(
|
||||
boxes1: torch.Tensor, boxes2: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Computes the intersection of 3D boxes1 and boxes2.
|
||||
Inputs boxes1, boxes2 are tensors of shape (B, 8, 3)
|
||||
(where B doesn't have to be the same for boxes1 and boxes1),
|
||||
containing the 8 corners of the boxes, as follows:
|
||||
|
||||
(4) +---------+. (5)
|
||||
| ` . | ` .
|
||||
| (0) +---+-----+ (1)
|
||||
| | | |
|
||||
(7) +-----+---+. (6)|
|
||||
` . | ` . |
|
||||
(3) ` +---------+ (2)
|
||||
|
||||
Args:
|
||||
boxes1: tensor of shape (N, 8, 3) of the coordinates of the 1st boxes
|
||||
boxes2: tensor of shape (M, 8, 3) of the coordinates of the 2nd boxes
|
||||
Returns:
|
||||
vol: (N, M) tensor of the volume of the intersecting convex shapes
|
||||
iou: (N, M) tensor of the intersection over union which is
|
||||
defined as: `iou = vol / (vol1 + vol2 - vol)`
|
||||
"""
|
||||
if not all((8, 3) == box.shape[1:] for box in [boxes1, boxes2]):
|
||||
raise ValueError("Each box in the batch must be of shape (8, 3)")
|
||||
|
||||
# pyre-fixme[16]: `_box3d_overlap` has no attribute `apply`.
|
||||
vol, iou = _box3d_overlap.apply(boxes1, boxes2)
|
||||
|
||||
return vol, iou
|
37
tests/bm_iou_box3d.py
Normal file
37
tests/bm_iou_box3d.py
Normal file
@ -0,0 +1,37 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from itertools import product
|
||||
|
||||
from fvcore.common.benchmark import benchmark
|
||||
from test_iou_box3d import TestIoU3D
|
||||
|
||||
|
||||
def bm_iou_box3d() -> None:
|
||||
N = [1, 4, 8, 16]
|
||||
num_samples = [2000, 5000, 10000, 20000]
|
||||
|
||||
kwargs_list = []
|
||||
test_cases = product(N, N)
|
||||
for case in test_cases:
|
||||
n, m = case
|
||||
kwargs_list.append({"N": n, "M": m, "device": "cuda:0"})
|
||||
|
||||
benchmark(TestIoU3D.iou_naive, "3D_IOU_NAIVE", kwargs_list, warmup_iters=1)
|
||||
|
||||
[k.update({"device": "cpu"}) for k in kwargs_list]
|
||||
benchmark(TestIoU3D.iou, "3D_IOU", kwargs_list, warmup_iters=1)
|
||||
|
||||
kwargs_list = []
|
||||
test_cases = product([1, 4], [1, 4], num_samples)
|
||||
for case in test_cases:
|
||||
n, m, s = case
|
||||
kwargs_list.append({"N": n, "M": m, "num_samples": s})
|
||||
benchmark(TestIoU3D.iou_sampling, "3D_IOU_SAMPLING", kwargs_list, warmup_iters=1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
bm_iou_box3d()
|
@ -4,7 +4,6 @@
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
import random
|
||||
import unittest
|
||||
from typing import List, Tuple, Union
|
||||
@ -13,6 +12,8 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
from common_testing import TestCaseMixin
|
||||
from pytorch3d.io import save_obj
|
||||
|
||||
from pytorch3d.ops.iou_box3d import box3d_overlap
|
||||
from pytorch3d.transforms.rotation_conversions import random_rotation
|
||||
|
||||
|
||||
@ -21,7 +22,8 @@ class TestIoU3D(TestCaseMixin, unittest.TestCase):
|
||||
super().setUp()
|
||||
torch.manual_seed(1)
|
||||
|
||||
def create_box(self, xyz, whl):
|
||||
@staticmethod
|
||||
def create_box(xyz, whl):
|
||||
x, y, z = xyz
|
||||
w, h, le = whl
|
||||
|
||||
@ -41,8 +43,39 @@ class TestIoU3D(TestCaseMixin, unittest.TestCase):
|
||||
)
|
||||
return verts
|
||||
|
||||
def test_iou(self):
|
||||
device = torch.device("cuda:0")
|
||||
@staticmethod
|
||||
def _box3d_overlap_naive_batched(boxes1, boxes2):
|
||||
"""
|
||||
Wrapper around box3d_overlap_naive to support
|
||||
batched input
|
||||
"""
|
||||
N = boxes1.shape[0]
|
||||
M = boxes2.shape[0]
|
||||
vols = torch.zeros((N, M), dtype=torch.float32, device=boxes1.device)
|
||||
ious = torch.zeros((N, M), dtype=torch.float32, device=boxes1.device)
|
||||
for n in range(N):
|
||||
for m in range(M):
|
||||
vol, iou = box3d_overlap_naive(boxes1[n], boxes2[m])
|
||||
vols[n, m] = vol
|
||||
ious[n, m] = iou
|
||||
return vols, ious
|
||||
|
||||
@staticmethod
|
||||
def _box3d_overlap_sampling_batched(boxes1, boxes2, num_samples: int):
|
||||
"""
|
||||
Wrapper around box3d_overlap_sampling to support
|
||||
batched input
|
||||
"""
|
||||
N = boxes1.shape[0]
|
||||
M = boxes2.shape[0]
|
||||
ious = torch.zeros((N, M), dtype=torch.float32, device=boxes1.device)
|
||||
for n in range(N):
|
||||
for m in range(M):
|
||||
iou = box3d_overlap_sampling(boxes1[n], boxes2[m])
|
||||
ious[n, m] = iou
|
||||
return ious
|
||||
|
||||
def _test_iou(self, overlap_fn, device):
|
||||
|
||||
box1 = torch.tensor(
|
||||
[
|
||||
@ -60,30 +93,36 @@ class TestIoU3D(TestCaseMixin, unittest.TestCase):
|
||||
)
|
||||
|
||||
# 1st test: same box, iou = 1.0
|
||||
vol, iou = box3d_overlap(box1, box1)
|
||||
self.assertClose(vol, torch.tensor(1.0, device=vol.device, dtype=vol.dtype))
|
||||
self.assertClose(iou, torch.tensor(1.0, device=vol.device, dtype=vol.dtype))
|
||||
vol, iou = overlap_fn(box1[None], box1[None])
|
||||
self.assertClose(vol, torch.tensor([[1.0]], device=vol.device, dtype=vol.dtype))
|
||||
self.assertClose(iou, torch.tensor([[1.0]], device=vol.device, dtype=vol.dtype))
|
||||
|
||||
# 2nd test
|
||||
dd = random.random()
|
||||
box2 = box1 + torch.tensor([[0.0, dd, 0.0]], device=device)
|
||||
vol, iou = box3d_overlap(box1, box2)
|
||||
self.assertClose(vol, torch.tensor(1 - dd, device=vol.device, dtype=vol.dtype))
|
||||
vol, iou = overlap_fn(box1[None], box2[None])
|
||||
self.assertClose(
|
||||
vol, torch.tensor([[1 - dd]], device=vol.device, dtype=vol.dtype)
|
||||
)
|
||||
|
||||
# 3rd test
|
||||
dd = random.random()
|
||||
box2 = box1 + torch.tensor([[dd, 0.0, 0.0]], device=device)
|
||||
vol, _ = box3d_overlap(box1, box2)
|
||||
self.assertClose(vol, torch.tensor(1 - dd, device=vol.device, dtype=vol.dtype))
|
||||
vol, _ = overlap_fn(box1[None], box2[None])
|
||||
self.assertClose(
|
||||
vol, torch.tensor([[1 - dd]], device=vol.device, dtype=vol.dtype)
|
||||
)
|
||||
|
||||
# 4th test
|
||||
ddx, ddy, ddz = random.random(), random.random(), random.random()
|
||||
box2 = box1 + torch.tensor([[ddx, ddy, ddz]], device=device)
|
||||
vol, _ = box3d_overlap(box1, box2)
|
||||
vol, _ = overlap_fn(box1[None], box2[None])
|
||||
self.assertClose(
|
||||
vol,
|
||||
torch.tensor(
|
||||
(1 - ddx) * (1 - ddy) * (1 - ddz), device=vol.device, dtype=vol.dtype
|
||||
[[(1 - ddx) * (1 - ddy) * (1 - ddz)]],
|
||||
device=vol.device,
|
||||
dtype=vol.dtype,
|
||||
),
|
||||
)
|
||||
|
||||
@ -93,11 +132,13 @@ class TestIoU3D(TestCaseMixin, unittest.TestCase):
|
||||
RR = random_rotation(dtype=torch.float32, device=device)
|
||||
box1r = box1 @ RR.transpose(0, 1)
|
||||
box2r = box2 @ RR.transpose(0, 1)
|
||||
vol, _ = box3d_overlap(box1r, box2r)
|
||||
vol, _ = overlap_fn(box1r[None], box2r[None])
|
||||
self.assertClose(
|
||||
vol,
|
||||
torch.tensor(
|
||||
(1 - ddx) * (1 - ddy) * (1 - ddz), device=vol.device, dtype=vol.dtype
|
||||
[[(1 - ddx) * (1 - ddy) * (1 - ddz)]],
|
||||
device=vol.device,
|
||||
dtype=vol.dtype,
|
||||
),
|
||||
)
|
||||
|
||||
@ -108,11 +149,13 @@ class TestIoU3D(TestCaseMixin, unittest.TestCase):
|
||||
TT = torch.rand((1, 3), dtype=torch.float32, device=device)
|
||||
box1r = box1 @ RR.transpose(0, 1) + TT
|
||||
box2r = box2 @ RR.transpose(0, 1) + TT
|
||||
vol, _ = box3d_overlap(box1r, box2r)
|
||||
vol, _ = overlap_fn(box1r[None], box2r[None])
|
||||
self.assertClose(
|
||||
vol,
|
||||
torch.tensor(
|
||||
(1 - ddx) * (1 - ddy) * (1 - ddz), device=vol.device, dtype=vol.dtype
|
||||
[[(1 - ddx) * (1 - ddy) * (1 - ddz)]],
|
||||
device=vol.device,
|
||||
dtype=vol.dtype,
|
||||
),
|
||||
)
|
||||
|
||||
@ -135,7 +178,7 @@ class TestIoU3D(TestCaseMixin, unittest.TestCase):
|
||||
[-2.8789, 6.0142, 0.7549],
|
||||
[-4.3586, 3.5345, -1.1831],
|
||||
],
|
||||
device="cuda:0",
|
||||
device=device,
|
||||
)
|
||||
box2r = torch.tensor(
|
||||
[
|
||||
@ -148,7 +191,7 @@ class TestIoU3D(TestCaseMixin, unittest.TestCase):
|
||||
[0.4328, -5.3761, -3.5436],
|
||||
[-2.3633, -5.6305, -1.2893],
|
||||
],
|
||||
device="cuda:0",
|
||||
device=device,
|
||||
)
|
||||
# from Meshlab:
|
||||
vol_inters = 33.558529
|
||||
@ -156,9 +199,9 @@ class TestIoU3D(TestCaseMixin, unittest.TestCase):
|
||||
vol_box2 = 156.386719
|
||||
iou_mesh = vol_inters / (vol_box1 + vol_box2 - vol_inters)
|
||||
|
||||
vol, iou = box3d_overlap(box1r, box2r)
|
||||
self.assertClose(vol, torch.tensor(vol_inters, device=device), atol=1e-1)
|
||||
self.assertClose(iou, torch.tensor(iou_mesh, device=device), atol=1e-1)
|
||||
vol, iou = overlap_fn(box1r[None], box2r[None])
|
||||
self.assertClose(vol, torch.tensor([[vol_inters]], device=device), atol=1e-1)
|
||||
self.assertClose(iou, torch.tensor([[iou_mesh]], device=device), atol=1e-1)
|
||||
|
||||
# 8th test: compare with sampling
|
||||
# create box1
|
||||
@ -173,16 +216,47 @@ class TestIoU3D(TestCaseMixin, unittest.TestCase):
|
||||
TT2 = torch.rand((1, 3), dtype=torch.float32, device=device)
|
||||
box1r = box1 @ RR1.transpose(0, 1) + TT1
|
||||
box2r = box2 @ RR2.transpose(0, 1) + TT2
|
||||
vol, iou = box3d_overlap(box1r, box2r)
|
||||
iou_sampling = box3d_overlap_sampling(box1r, box2r, num_samples=10000)
|
||||
vol, iou = overlap_fn(box1r[None], box2r[None])
|
||||
iou_sampling = self._box3d_overlap_sampling_batched(
|
||||
box1r[None], box2r[None], num_samples=10000
|
||||
)
|
||||
|
||||
self.assertClose(iou, iou_sampling, atol=1e-2)
|
||||
|
||||
# 9th test: non overlapping boxes, iou = 0.0
|
||||
box2 = box1 + torch.tensor([[0.0, 100.0, 0.0]], device=device)
|
||||
vol, iou = box3d_overlap(box1, box2)
|
||||
self.assertClose(vol, torch.tensor(0.0, device=vol.device, dtype=vol.dtype))
|
||||
self.assertClose(iou, torch.tensor(0.0, device=vol.device, dtype=vol.dtype))
|
||||
vol, iou = overlap_fn(box1[None], box2[None])
|
||||
self.assertClose(vol, torch.tensor([[0.0]], device=vol.device, dtype=vol.dtype))
|
||||
self.assertClose(iou, torch.tensor([[0.0]], device=vol.device, dtype=vol.dtype))
|
||||
|
||||
def test_iou_naive(self):
|
||||
device = torch.device("cuda:0")
|
||||
self._test_iou(self._box3d_overlap_naive_batched, device)
|
||||
|
||||
def test_iou_cpu(self):
|
||||
device = torch.device("cpu")
|
||||
self._test_iou(box3d_overlap, device)
|
||||
|
||||
def test_cpu_vs_naive_batched(self):
|
||||
N, M = 3, 6
|
||||
device = "cpu"
|
||||
boxes1 = torch.randn((N, 8, 3), device=device)
|
||||
boxes2 = torch.randn((M, 8, 3), device=device)
|
||||
vol1, iou1 = self._box3d_overlap_naive_batched(boxes1, boxes2)
|
||||
vol2, iou2 = box3d_overlap(boxes1, boxes2)
|
||||
# check shape
|
||||
for val in [vol1, vol2, iou1, iou2]:
|
||||
self.assertClose(val.shape, (N, M))
|
||||
# check values
|
||||
self.assertClose(vol1, vol2)
|
||||
self.assertClose(iou1, iou2)
|
||||
|
||||
def test_batched_errors(self):
|
||||
N, M = 5, 10
|
||||
boxes1 = torch.randn((N, 8, 3))
|
||||
boxes2 = torch.randn((M, 10, 3))
|
||||
with self.assertRaisesRegex(ValueError, "(8, 3)"):
|
||||
box3d_overlap(boxes1, boxes2)
|
||||
|
||||
def test_box_volume(self):
|
||||
device = torch.device("cuda:0")
|
||||
@ -277,6 +351,36 @@ class TestIoU3D(TestCaseMixin, unittest.TestCase):
|
||||
self.assertClose(box_planar_dir(box1), n1)
|
||||
self.assertClose(box_planar_dir(box2), n2)
|
||||
|
||||
@staticmethod
|
||||
def iou_naive(N: int, M: int, device="cpu"):
|
||||
boxes1 = torch.randn((N, 8, 3))
|
||||
boxes2 = torch.randn((M, 8, 3))
|
||||
|
||||
def output():
|
||||
vol, iou = TestIoU3D._box3d_overlap_naive_batched(boxes1, boxes2)
|
||||
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def iou(N: int, M: int, device="cpu"):
|
||||
boxes1 = torch.randn((N, 8, 3), device=device)
|
||||
boxes2 = torch.randn((M, 8, 3), device=device)
|
||||
|
||||
def output():
|
||||
vol, iou = box3d_overlap(boxes1, boxes2)
|
||||
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def iou_sampling(N: int, M: int, num_samples: int):
|
||||
boxes1 = torch.randn((N, 8, 3))
|
||||
boxes2 = torch.randn((M, 8, 3))
|
||||
|
||||
def output():
|
||||
_ = TestIoU3D._box3d_overlap_sampling_batched(boxes1, boxes2, num_samples)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
# -------------------------------------------------- #
|
||||
# NAIVE IMPLEMENTATION #
|
||||
@ -284,7 +388,7 @@ class TestIoU3D(TestCaseMixin, unittest.TestCase):
|
||||
|
||||
"""
|
||||
The main functions below are:
|
||||
* box3d_overlap: which computes the exact IoU of box1 and box2
|
||||
* box3d_overlap_naive: which computes the exact IoU of box1 and box2
|
||||
* box3d_overlap_sampling: which computes an approximate IoU of box1 and box2
|
||||
by sampling points within the boxes
|
||||
|
||||
@ -738,7 +842,7 @@ def clip_tri_by_plane(plane, n, tri_verts) -> Union[List, torch.Tensor]:
|
||||
# -------------------------------------------------- #
|
||||
|
||||
|
||||
def box3d_overlap(box1: torch.Tensor, box2: torch.Tensor):
|
||||
def box3d_overlap_naive(box1: torch.Tensor, box2: torch.Tensor):
|
||||
"""
|
||||
Computes the intersection of 3D boxes1 and boxes2.
|
||||
Inputs boxes1, boxes2 are tensors of shape (8, 3) containing
|
||||
|
Loading…
x
Reference in New Issue
Block a user