(new) CUDA IoU for 3D boxes

Summary: CUDA implementation of 3D bounding box overlap calculation.

Reviewed By: gkioxari

Differential Revision: D31157919

fbshipit-source-id: 5dc89805d01fef2d6779f00a33226131e39c43ed
This commit is contained in:
Nikhila Ravi 2021-09-29 18:48:11 -07:00 committed by Facebook GitHub Bot
parent 53266ec9ff
commit ff8d4762f4
9 changed files with 1019 additions and 97 deletions

View File

@ -0,0 +1,176 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <thrust/device_vector.h>
#include <thrust/tuple.h>
#include "iou_box3d/iou_utils.cuh"
#include "utils/pytorch3d_cutils.h"
// Parallelize over N*M computations which can each be done
// independently
__global__ void IoUBox3DKernel(
const at::PackedTensorAccessor64<float, 3, at::RestrictPtrTraits> boxes1,
const at::PackedTensorAccessor64<float, 3, at::RestrictPtrTraits> boxes2,
at::PackedTensorAccessor64<float, 2, at::RestrictPtrTraits> vols,
at::PackedTensorAccessor64<float, 2, at::RestrictPtrTraits> ious) {
const size_t N = boxes1.size(0);
const size_t M = boxes2.size(0);
const size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
const size_t stride = gridDim.x * blockDim.x;
for (size_t i = tid; i < N * M; i += stride) {
const size_t n = i / M; // box1 index
const size_t m = i % M; // box2 index
// Convert to array of structs of face vertices i.e. effectively (F, 3, 3)
// FaceVerts is a data type defined in iou_utils.cuh
FaceVerts box1_tris[NUM_TRIS];
FaceVerts box2_tris[NUM_TRIS];
GetBoxTris(boxes1[n], box1_tris);
GetBoxTris(boxes2[m], box2_tris);
// Calculate the position of the center of the box which is used in
// several calculations. This requires a tensor as input.
const float3 box1_center = BoxCenter(boxes1[n]);
const float3 box2_center = BoxCenter(boxes2[m]);
// Convert to an array of face vertices
FaceVerts box1_planes[NUM_PLANES];
GetBoxPlanes(boxes1[n], box1_planes);
FaceVerts box2_planes[NUM_PLANES];
GetBoxPlanes(boxes2[m], box2_planes);
// Get Box Volumes
const float box1_vol = BoxVolume(box1_tris, box1_center, NUM_TRIS);
const float box2_vol = BoxVolume(box2_tris, box2_center, NUM_TRIS);
// Tris in Box1 intersection with Planes in Box2
// Initialize box1 intersecting faces. MAX_TRIS is the
// max faces possible in the intersecting shape.
// TODO: determine if the value of MAX_TRIS is sufficient or
// if we should store the max tris for each NxM computation
// and throw an error if any exceeds the max.
FaceVerts box1_intersect[MAX_TRIS];
for (int j = 0; j < NUM_TRIS; ++j) {
// Initialize the faces from the box
box1_intersect[j] = box1_tris[j];
}
// Get the count of the actual number of faces in the intersecting shape
int box1_count = BoxIntersections(box2_planes, box2_center, box1_intersect);
// Tris in Box2 intersection with Planes in Box1
FaceVerts box2_intersect[MAX_TRIS];
for (int j = 0; j < NUM_TRIS; ++j) {
box2_intersect[j] = box2_tris[j];
}
const int box2_count =
BoxIntersections(box1_planes, box1_center, box2_intersect);
// If there are overlapping regions in Box2, remove any coplanar faces
if (box2_count > 0) {
// Identify if any triangles in Box2 are coplanar with Box1
Keep tri2_keep[MAX_TRIS];
for (int j = 0; j < MAX_TRIS; ++j) {
// Initialize the valid faces to be true
tri2_keep[j].keep = j < box2_count ? true : false;
}
for (int b1 = 0; b1 < box1_count; ++b1) {
for (int b2 = 0; b2 < box2_count; ++b2) {
const bool is_coplanar =
IsCoplanarFace(box1_intersect[b1], box2_intersect[b2]);
if (is_coplanar) {
tri2_keep[b2].keep = false;
}
}
}
// Keep only the non coplanar triangles in Box2 - add them to the
// Box1 triangles.
for (int b2 = 0; b2 < box2_count; ++b2) {
if (tri2_keep[b2].keep) {
box1_intersect[box1_count] = box2_intersect[b2];
// box1_count will determine the total faces in the
// intersecting shape
box1_count++;
}
}
}
// Initialize the vol and iou to 0.0 in case there are no triangles
// in the intersecting shape.
float vol = 0.0;
float iou = 0.0;
// If there are triangles in the intersecting shape
if (box1_count > 0) {
// The intersecting shape is a polyhedron made up of the
// triangular faces that are all now in box1_intersect.
// Calculate the polyhedron center
const float3 poly_center = PolyhedronCenter(box1_intersect, box1_count);
// Compute intersecting polyhedron volume
vol = BoxVolume(box1_intersect, poly_center, box1_count);
// Compute IoU
iou = vol / (box1_vol + box2_vol - vol);
}
// Write the volume and IoU to global memory
vols[n][m] = vol;
ious[n][m] = iou;
}
}
std::tuple<at::Tensor, at::Tensor> IoUBox3DCuda(
const at::Tensor& boxes1, // (N, 8, 3)
const at::Tensor& boxes2) { // (M, 8, 3)
// Check inputs are on the same device
at::TensorArg boxes1_t{boxes1, "boxes1", 1}, boxes2_t{boxes2, "boxes2", 2};
at::CheckedFrom c = "IoUBox3DCuda";
at::checkAllSameGPU(c, {boxes1_t, boxes2_t});
at::checkAllSameType(c, {boxes1_t, boxes2_t});
// Set the device for the kernel launch based on the device of boxes1
at::cuda::CUDAGuard device_guard(boxes1.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
TORCH_CHECK(boxes2.size(2) == boxes1.size(2), "Boxes must have shape (8, 3)");
TORCH_CHECK(
(boxes2.size(1) == 8) && (boxes1.size(1) == 8),
"Boxes must have shape (8, 3)");
const int64_t N = boxes1.size(0);
const int64_t M = boxes2.size(0);
auto vols = at::zeros({N, M}, boxes1.options());
auto ious = at::zeros({N, M}, boxes1.options());
if (vols.numel() == 0) {
AT_CUDA_CHECK(cudaGetLastError());
return std::make_tuple(vols, ious);
}
const size_t blocks = 512;
const size_t threads = 256;
IoUBox3DKernel<<<blocks, threads, 0, stream>>>(
boxes1.packed_accessor64<float, 3, at::RestrictPtrTraits>(),
boxes2.packed_accessor64<float, 3, at::RestrictPtrTraits>(),
vols.packed_accessor64<float, 2, at::RestrictPtrTraits>(),
ious.packed_accessor64<float, 2, at::RestrictPtrTraits>());
AT_CUDA_CHECK(cudaGetLastError());
return std::make_tuple(vols, ious);
}

View File

@ -26,12 +26,23 @@ std::tuple<at::Tensor, at::Tensor> IoUBox3DCpu(
const at::Tensor& boxes1,
const at::Tensor& boxes2);
// CUDA implementation
std::tuple<at::Tensor, at::Tensor> IoUBox3DCuda(
const at::Tensor& boxes1,
const at::Tensor& boxes2);
// Implementation which is exposed
inline std::tuple<at::Tensor, at::Tensor> IoUBox3D(
const at::Tensor& boxes1,
const at::Tensor& boxes2) {
if (boxes1.is_cuda() || boxes2.is_cuda()) {
AT_ERROR("GPU support not implemented");
#ifdef WITH_CUDA
CHECK_CUDA(boxes1);
CHECK_CUDA(boxes2);
return IoUBox3DCuda(boxes1.contiguous(), boxes2.contiguous());
#else
AT_ERROR("Not compiled with GPU support.");
#endif
}
return IoUBox3DCpu(boxes1.contiguous(), boxes2.contiguous());
}

View File

@ -79,7 +79,7 @@ std::tuple<at::Tensor, at::Tensor> IoUBox3DCpu(
std::fill(tri2_keep.begin(), tri2_keep.end(), 1);
for (int b1 = 0; b1 < box1_intersect.size(); ++b1) {
for (int b2 = 0; b2 < box2_intersect.size(); ++b2) {
bool is_coplanar =
const bool is_coplanar =
IsCoplanarFace(box1_intersect[b1], box2_intersect[b2]);
if (is_coplanar) {
tri2_keep[b2] = 0;

View File

@ -0,0 +1,584 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <float.h>
#include <math.h>
#include <thrust/device_vector.h>
#include <cstdio>
#include "utils/float_math.cuh"
#include "utils/geometry_utils.cuh"
/*
_PLANES and _TRIS define the 4- and 3-connectivity
of the 8 box corners.
_PLANES gives the quad faces of the 3D box
_TRIS gives the triangle faces of the 3D box
*/
const int NUM_PLANES = 6;
const int NUM_TRIS = 12;
// This is required for iniitalizing the faces
// in the intersecting shape
const int MAX_TRIS = 100;
// Create data types for representing the
// verts for each face and the indices.
// We will use struct arrays for representing
// the data for each box and intersecting
// triangles
typedef struct {
float3 v0;
float3 v1;
float3 v2;
float3 v3; // Can be empty for triangles
} FaceVerts;
typedef struct {
int v0;
int v1;
int v2;
int v3; // Can be empty for triangles
} FaceVertsIdx;
// This is used when deciding which faces to
// keep that are not coplanar
typedef struct {
bool keep;
} Keep;
__device__ FaceVertsIdx _PLANES[] = {
{0, 1, 2, 3},
{3, 2, 6, 7},
{0, 1, 5, 4},
{0, 3, 7, 4},
{1, 5, 6, 2},
{4, 5, 6, 7},
};
__device__ FaceVertsIdx _TRIS[] = {
{0, 1, 2},
{0, 3, 2},
{4, 5, 6},
{4, 6, 7},
{1, 5, 6},
{1, 6, 2},
{0, 4, 7},
{0, 7, 3},
{3, 2, 6},
{3, 6, 7},
{0, 1, 5},
{0, 4, 5},
};
// Args
// box: (8, 3) tensor accessor for the box vertices
// box_tris: Array of structs of type FaceVerts,
// effectively (F, 3, 3) where the coordinates of the
// verts for each face will be saved to.
//
// Returns: None (output saved to box_tris)
//
template <typename Box, typename BoxTris>
__device__ inline void GetBoxTris(const Box& box, BoxTris& box_tris) {
for (int t = 0; t < NUM_TRIS; ++t) {
const float3 v0 = make_float3(
box[_TRIS[t].v0][0], box[_TRIS[t].v0][1], box[_TRIS[t].v0][2]);
const float3 v1 = make_float3(
box[_TRIS[t].v1][0], box[_TRIS[t].v1][1], box[_TRIS[t].v1][2]);
const float3 v2 = make_float3(
box[_TRIS[t].v2][0], box[_TRIS[t].v2][1], box[_TRIS[t].v2][2]);
box_tris[t] = {v0, v1, v2};
}
}
// Args
// box: (8, 3) tensor accessor for the box vertices
// box_planes: Array of structs of type FaceVerts, effectively (P, 4, 3)
// where the coordinates of the verts for the four corners of each plane
// will be saved to
//
// Returns: None (output saved to box_planes)
//
template <typename Box, typename FaceVertsBoxPlanes>
__device__ inline void GetBoxPlanes(
const Box& box,
FaceVertsBoxPlanes& box_planes) {
for (int t = 0; t < NUM_PLANES; ++t) {
const float3 v0 = make_float3(
box[_PLANES[t].v0][0], box[_PLANES[t].v0][1], box[_PLANES[t].v0][2]);
const float3 v1 = make_float3(
box[_PLANES[t].v1][0], box[_PLANES[t].v1][1], box[_PLANES[t].v1][2]);
const float3 v2 = make_float3(
box[_PLANES[t].v2][0], box[_PLANES[t].v2][1], box[_PLANES[t].v2][2]);
const float3 v3 = make_float3(
box[_PLANES[t].v3][0], box[_PLANES[t].v3][1], box[_PLANES[t].v3][2]);
box_planes[t] = {v0, v1, v2, v3};
}
}
// The normal of the face defined by vertices (v0, v1, v2)
// Define e0 to be the edge connecting (v1, v0)
// Define e1 to be the edge connecting (v2, v0)
// normal is the cross product of e0, e1
//
// Args
// v0, v1, v2: float3 coordinates of the vertices of the face
//
// Returns
// float3: normal for the face
//
__device__ inline float3
FaceNormal(const float3 v0, const float3 v1, const float3 v2) {
float3 n = cross(v1 - v0, v2 - v0);
n = n / fmaxf(norm(n), kEpsilon);
return n;
}
// The normal of a box plane defined by the verts in `plane` with
// the centroid of the box given by `center`.
// Args
// plane: float3 coordinates of the vertices of the plane
// center: float3 coordinates of the center of the box from
// which the plane originated
//
// Returns
// float3: normal for the plane such that it points towards
// the center of the box
//
template <typename FaceVertsPlane>
__device__ inline float3 PlaneNormalDirection(
const FaceVertsPlane& plane,
const float3& center) {
// Only need the first 3 verts of the plane
const float3 v0 = plane.v0;
const float3 v1 = plane.v1;
const float3 v2 = plane.v2;
// We project the center on the plane defined by (v0, v1, v2)
// We can write center = v0 + a * e0 + b * e1 + c * n
// We know that <e0, n> = 0 and <e1, n> = 0 and
// <a, b> is the dot product between a and b.
// This means we can solve for c as:
// c = <center - v0 - a * e0 - b * e1, n> = <center - v0, n>
float3 n = FaceNormal(v0, v1, v2);
const float c = dot((center - v0), n);
// If c is negative, then we revert the direction of n such that n
// points "inside"
if (c < kEpsilon) {
n = -1.0f * n;
}
return n;
}
// Calculate the volume of the box by summing the volume of
// each of the tetrahedrons formed with a triangle face and
// the box centroid.
//
// Args
// box_tris: vector of float3 coordinates of the vertices of each
// of the triangles in the box
// box_center: float3 coordinates of the center of the box
//
// Returns
// float: volume of the box
//
template <typename BoxTris>
__device__ inline float BoxVolume(
const BoxTris& box_tris,
const float3& box_center,
const int num_tris) {
float box_vol = 0.0;
// Iterate through each triange, calculate the area of the
// tetrahedron formed with the box_center and sum them
for (int t = 0; t < num_tris; ++t) {
// Subtract the center:
float3 v0 = box_tris[t].v0;
float3 v1 = box_tris[t].v1;
float3 v2 = box_tris[t].v2;
v0 = v0 - box_center;
v1 = v1 - box_center;
v2 = v2 - box_center;
// Compute the area
const float area = dot(v0, cross(v1, v2));
const float vol = abs(area) / 6.0;
box_vol = box_vol + vol;
}
return box_vol;
}
// Compute the box center as the mean of the verts
//
// Args
// box_verts: (8, 3) tensor of the corner vertices of the box
//
// Returns
// float3: coordinates of the center of the box
//
template <typename Box>
__device__ inline float3 BoxCenter(const Box box_verts) {
float x = 0.0;
float y = 0.0;
float z = 0.0;
const int num_verts = box_verts.size(0); // Should be 8
// Sum all x, y, z, and take the mean
for (int t = 0; t < num_verts; ++t) {
x = x + box_verts[t][0];
y = y + box_verts[t][1];
z = z + box_verts[t][2];
}
// Take the mean of all the vertex positions
x = x / num_verts;
y = y / num_verts;
z = z / num_verts;
const float3 center = make_float3(x, y, z);
return center;
}
// Compute the polyhedron center as the mean of the face centers
// of the triangle faces
//
// Args
// tris: vector of float3 coordinates of the
// vertices of each of the triangles in the polyhedron
//
// Returns
// float3: coordinates of the center of the polyhedron
//
template <typename Tris>
__device__ inline float3 PolyhedronCenter(
const Tris& tris,
const int num_tris) {
float x = 0.0;
float y = 0.0;
float z = 0.0;
// Find the center point of each face
for (int t = 0; t < num_tris; ++t) {
const float3 v0 = tris[t].v0;
const float3 v1 = tris[t].v1;
const float3 v2 = tris[t].v2;
const float x_face = (v0.x + v1.x + v2.x) / 3.0;
const float y_face = (v0.y + v1.y + v2.y) / 3.0;
const float z_face = (v0.z + v1.z + v2.z) / 3.0;
x = x + x_face;
y = y + y_face;
z = z + z_face;
}
// Take the mean of the centers of all faces
x = x / num_tris;
y = y / num_tris;
z = z / num_tris;
const float3 center = make_float3(x, y, z);
return center;
}
// Compute a boolean indicator for whether a point
// is inside a plane, where inside refers to whether
// or not the point has a component in the
// normal direction of the plane.
//
// Args
// plane: vector of float3 coordinates of the
// vertices of each of the triangles in the box
// normal: float3 of the direction of the plane normal
// point: float3 of the position of the point of interest
//
// Returns
// bool: whether or not the point is inside the plane
//
__device__ inline bool
IsInside(const FaceVerts& plane, const float3& normal, const float3& point) {
// Get one vert of the plane
const float3 v0 = plane.v0;
// Every point p can be written as p = v0 + a e0 + b e1 + c n
// Solving for c:
// c = (point - v0 - a * e0 - b * e1).dot(n)
// We know that <e0, n> = 0 and <e1, n> = 0
// So the calculation can be simplified as:
const float c = dot((point - v0), normal);
const bool inside = c > -1.0f * kEpsilon;
return inside;
}
// Find the point of intersection between a plane
// and a line given by the end points (p0, p1)
//
// Args
// plane: vector of float3 coordinates of the
// vertices of each of the triangles in the box
// normal: float3 of the direction of the plane normal
// p0, p1: float3 of the start and end point of the line
//
// Returns
// float3: position of the intersection point
//
__device__ inline float3 PlaneEdgeIntersection(
const FaceVerts& plane,
const float3& normal,
const float3& p0,
const float3& p1) {
// Get one vert of the plane
const float3 v0 = plane.v0;
// The point of intersection can be parametrized
// p = p0 + a (p1 - p0) where a in [0, 1]
// We want to find a such that p is on plane
// <p - v0, n> = 0
const float top = dot(-1.0f * (p0 - v0), normal);
const float bot = dot(p1 - p0, normal);
const float a = top / bot;
const float3 p = p0 + a * (p1 - p0);
return p;
}
// Triangle is clipped into a quadrilateral
// based on the intersection points with the plane.
// Then the quadrilateral is divided into two triangles.
//
// Args
// plane: vector of float3 coordinates of the
// vertices of each of the triangles in the box
// normal: float3 of the direction of the plane normal
// vout: float3 of the point in the triangle which is outside
// the plane
// vin1, vin2: float3 of the points in the triangle which are
// inside the plane
// face_verts_out: Array of structs of type FaceVerts,
// with the coordinates of the new triangle faces
// formed after clipping.
// All triangles are now "inside" the plane.
//
// Returns:
// count: (int) number of new faces after clipping the triangle
// i.e. the valid faces which have been saved
// to face_verts_out
//
template <typename FaceVertsBox>
__device__ inline int ClipTriByPlaneOneOut(
const FaceVerts& plane,
const float3& normal,
const float3& vout,
const float3& vin1,
const float3& vin2,
FaceVertsBox& face_verts_out) {
// point of intersection between plane and (vin1, vout)
const float3 pint1 = PlaneEdgeIntersection(plane, normal, vin1, vout);
// point of intersection between plane and (vin2, vout)
const float3 pint2 = PlaneEdgeIntersection(plane, normal, vin2, vout);
face_verts_out[0] = {vin1, pint1, pint2};
face_verts_out[1] = {vin1, pint2, vin2};
return 2;
}
// Triangle is clipped into a smaller triangle based
// on the intersection points with the plane.
//
// Args
// plane: vector of float3 coordinates of the
// vertices of each of the triangles in the box
// normal: float3 of the direction of the plane normal
// vout1, vout2: float3 of the points in the triangle which are
// outside the plane
// vin: float3 of the point in the triangle which is inside
// the plane
// face_verts_out: Array of structs of type FaceVerts,
// with the coordinates of the new triangle faces
// formed after clipping.
// All triangles are now "inside" the plane.
//
// Returns:
// count: (int) number of new faces after clipping the triangle
// i.e. the valid faces which have been saved
// to face_verts_out
//
template <typename FaceVertsBox>
__device__ inline int ClipTriByPlaneTwoOut(
const FaceVerts& plane,
const float3& normal,
const float3& vout1,
const float3& vout2,
const float3& vin,
FaceVertsBox& face_verts_out) {
// point of intersection between plane and (vin, vout1)
const float3 pint1 = PlaneEdgeIntersection(plane, normal, vin, vout1);
// point of intersection between plane and (vin, vout2)
const float3 pint2 = PlaneEdgeIntersection(plane, normal, vin, vout2);
face_verts_out[0] = {vin, pint1, pint2};
return 1;
}
// Clip the triangle faces so that they lie within the
// plane, creating new triangle faces where necessary.
//
// Args
// plane: Array of structs of type FaceVerts with the coordinates
// of the vertices of each of the triangles in the box
// tri: Array of structs of type FaceVerts with the vertex
// coordinates of the triangle faces
// normal: float3 of the direction of the plane normal
// face_verts_out: Array of structs of type FaceVerts,
// with the coordinates of the new triangle faces
// formed after clipping.
// All triangles are now "inside" the plane.
//
// Returns:
// count: (int) number of new faces after clipping the triangle
// i.e. the valid faces which have been saved
// to face_verts_out
//
template <typename FaceVertsBox>
__device__ inline int ClipTriByPlane(
const FaceVerts& plane,
const FaceVerts& tri,
const float3& normal,
FaceVertsBox& face_verts_out) {
// Get Triangle vertices
const float3 v0 = tri.v0;
const float3 v1 = tri.v1;
const float3 v2 = tri.v2;
// Check each of the triangle vertices to see if it is inside the plane
const bool isin0 = IsInside(plane, normal, v0);
const bool isin1 = IsInside(plane, normal, v1);
const bool isin2 = IsInside(plane, normal, v2);
// All in
if (isin0 && isin1 && isin2) {
// Return input vertices
face_verts_out[0] = {v0, v1, v2};
return 1;
}
// All out
if (!isin0 && !isin1 && !isin2) {
return 0;
}
// One vert out
if (isin0 && isin1 && !isin2) {
return ClipTriByPlaneOneOut(plane, normal, v2, v0, v1, face_verts_out);
}
if (isin0 && not isin1 && isin2) {
return ClipTriByPlaneOneOut(plane, normal, v1, v0, v2, face_verts_out);
}
if (not isin0 && isin1 && isin2) {
return ClipTriByPlaneOneOut(plane, normal, v0, v1, v2, face_verts_out);
}
// Two verts out
if (isin0 && !isin1 && !isin2) {
return ClipTriByPlaneTwoOut(plane, normal, v1, v2, v0, face_verts_out);
}
if (!isin0 && !isin1 && isin2) {
return ClipTriByPlaneTwoOut(plane, normal, v0, v1, v2, face_verts_out);
}
if (!isin0 && isin1 && !isin2) {
return ClipTriByPlaneTwoOut(plane, normal, v0, v2, v1, face_verts_out);
}
// Else return empty (should not be reached)
return 0;
}
// Compute a boolean indicator for whether or not two faces
// are coplanar
//
// Args
// tri1, tri2: FaceVerts struct of the vertex coordinates of
// the triangle face
//
// Returns
// bool: whether or not the two faces are coplanar
//
__device__ inline bool IsCoplanarFace(
const FaceVerts& tri1,
const FaceVerts& tri2) {
// Get verts for face 1
const float3 v0 = tri1.v0;
const float3 v1 = tri1.v1;
const float3 v2 = tri1.v2;
const float3 n1 = FaceNormal(v0, v1, v2);
int coplanar_count = 0;
// Check v0, v1, v2
if (abs(dot(tri2.v0 - v0, n1)) < kEpsilon) {
coplanar_count++;
}
if (abs(dot(tri2.v1 - v0, n1)) < kEpsilon) {
coplanar_count++;
}
if (abs(dot(tri2.v2 - v0, n1)) < kEpsilon) {
coplanar_count++;
}
return (coplanar_count == 3);
}
// Get the triangles from each box which are part of the
// intersecting polyhedron by computing the intersection
// points with each of the planes.
//
// Args
// planes: Array of structs of type FaceVerts with the coordinates
// of the vertices of each of the triangles in the box
// center: float3 coordinates of the center of the box from which
// the planes originate
// face_verts_out: Array of structs of type FaceVerts,
// where the coordinates of the new triangle faces
// formed after clipping will be saved to.
// All triangles are now "inside" the plane.
//
// Returns:
// count: (int) number of faces in the intersecting shape
// i.e. the valid faces which have been saved
// to face_verts_out
//
template <typename FaceVertsPlane, typename FaceVertsBox>
__device__ inline int BoxIntersections(
const FaceVertsPlane& planes,
const float3& center,
FaceVertsBox& face_verts_out) {
// Initialize num tris to 12
int num_tris = NUM_TRIS;
for (int p = 0; p < NUM_PLANES; ++p) {
// Get plane normal direction
const float3 n2 = PlaneNormalDirection(planes[p], center);
// Create intermediate vector to store the updated tris
FaceVerts tri_verts_updated[MAX_TRIS];
int offset = 0;
// Iterate through triangles in face_verts_out
// for the valid tris given by num_tris
for (int t = 0; t < num_tris; ++t) {
// Clip tri by plane, can max be split into 2 triangles
FaceVerts tri_updated[2];
const int count =
ClipTriByPlane(planes[p], face_verts_out[t], n2, tri_updated);
// Add to the tri_verts_updated output if not empty
for (int v = 0; v < count; ++v) {
tri_verts_updated[offset] = tri_updated[v];
offset++;
}
}
// Update the face_verts_out tris
num_tris = offset;
for (int j = 0; j < num_tris; ++j) {
face_verts_out[j] = tri_verts_updated[j];
}
}
return num_tris;
}

View File

@ -9,6 +9,7 @@ from .cameras_alignment import corresponding_cameras_alignment
from .cubify import cubify
from .graph_conv import GraphConv
from .interp_face_attrs import interpolate_face_attributes
from .iou_box3d import box3d_overlap
from .knn import knn_gather, knn_points
from .laplacian_matrices import cot_laplacian, laplacian, norm_laplacian
from .mesh_face_areas_normals import mesh_face_areas_normals

View File

@ -7,10 +7,68 @@
from typing import Tuple
import torch
import torch.nn.functional as F
from pytorch3d import _C
from torch.autograd import Function
# -------------------------------------------------- #
# CONSTANTS #
# -------------------------------------------------- #
"""
_box_planes and _box_triangles define the 4- and 3-connectivity
of the 8 box corners.
_box_planes gives the quad faces of the 3D box
_box_triangles gives the triangle faces of the 3D box
"""
_box_planes = [
[0, 1, 2, 3],
[3, 2, 6, 7],
[0, 1, 5, 4],
[0, 3, 7, 4],
[1, 2, 6, 5],
[4, 5, 6, 7],
]
_box_triangles = [
[0, 1, 2],
[0, 3, 2],
[4, 5, 6],
[4, 6, 7],
[1, 5, 6],
[1, 6, 2],
[0, 4, 7],
[0, 7, 3],
[3, 2, 6],
[3, 6, 7],
[0, 1, 5],
[0, 4, 5],
]
def _check_coplanar(boxes: torch.Tensor, eps: float = 1e-5) -> None:
faces = torch.tensor(_box_planes, dtype=torch.int64, device=boxes.device)
# pyre-fixme[16]: `boxes` has no attribute `index_select`.
verts = boxes.index_select(index=faces.view(-1), dim=1)
B = boxes.shape[0]
P, V = faces.shape
# (B, P, 4, 3) -> (B, P, 3)
v0, v1, v2, v3 = verts.reshape(B, P, V, 3).unbind(2)
# Compute the normal
e0 = F.normalize(v1 - v0, dim=-1)
e1 = F.normalize(v2 - v0, dim=-1)
normal = F.normalize(torch.cross(e0, e1, dim=-1), dim=-1)
# Check the fourth vertex is also on the same plane
mat1 = (v3 - v0).view(B, 1, -1) # (B, 1, P*3)
mat2 = normal.view(B, -1, 1) # (B, P*3, 1)
if not (mat1.bmm(mat2).abs() < eps).all().item():
msg = "Plane vertices are not coplanar"
raise ValueError(msg)
return
class _box3d_overlap(Function):
"""
Torch autograd Function wrapper for box3d_overlap C++/CUDA implementations.
@ -35,6 +93,7 @@ def box3d_overlap(
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Computes the intersection of 3D boxes1 and boxes2.
Inputs boxes1, boxes2 are tensors of shape (B, 8, 3)
(where B doesn't have to be the same for boxes1 and boxes1),
containing the 8 corners of the boxes, as follows:
@ -47,6 +106,25 @@ def box3d_overlap(
` . | ` . |
(3) ` +---------+ (2)
NOTE: Throughout this implementation, we assume that boxes
are defined by their 8 corners exactly in the order specified in the
diagram above for the function to give correct results. In addition
the vertices on each plane must be coplanar.
As an alternative to the diagram, this is a unit bounding
box which has the correct vertex ordering:
box_corner_vertices = [
[0, 0, 0],
[1, 0, 0],
[1, 1, 0],
[0, 1, 0],
[0, 0, 1],
[1, 0, 1],
[1, 1, 1],
[0, 1, 1],
]
Args:
boxes1: tensor of shape (N, 8, 3) of the coordinates of the 1st boxes
boxes2: tensor of shape (M, 8, 3) of the coordinates of the 2nd boxes
@ -58,6 +136,9 @@ def box3d_overlap(
if not all((8, 3) == box.shape[1:] for box in [boxes1, boxes2]):
raise ValueError("Each box in the batch must be of shape (8, 3)")
_check_coplanar(boxes1)
_check_coplanar(boxes2)
# pyre-fixme[16]: `_box3d_overlap` has no attribute `apply`.
vol, iou = _box3d_overlap.apply(boxes1, boxes2)

View File

@ -11,25 +11,42 @@ from test_iou_box3d import TestIoU3D
def bm_iou_box3d() -> None:
N = [1, 4, 8, 16]
num_samples = [2000, 5000, 10000, 20000]
# Realistic use cases
N = [30, 100]
M = [5, 10, 100]
kwargs_list = []
test_cases = product(N, M)
for case in test_cases:
n, m = case
kwargs_list.append({"N": n, "M": m, "device": "cuda:0"})
benchmark(TestIoU3D.iou, "3D_IOU", kwargs_list, warmup_iters=1)
# Comparison of C++/CUDA
kwargs_list = []
N = [1, 4, 8, 16]
devices = ["cpu", "cuda:0"]
test_cases = product(N, N, devices)
for case in test_cases:
n, m, d = case
kwargs_list.append({"N": n, "M": m, "device": d})
benchmark(TestIoU3D.iou, "3D_IOU", kwargs_list, warmup_iters=1)
# Naive PyTorch
N = [1, 4]
kwargs_list = []
test_cases = product(N, N)
for case in test_cases:
n, m = case
kwargs_list.append({"N": n, "M": m, "device": "cuda:0"})
benchmark(TestIoU3D.iou_naive, "3D_IOU_NAIVE", kwargs_list, warmup_iters=1)
[k.update({"device": "cpu"}) for k in kwargs_list]
benchmark(TestIoU3D.iou, "3D_IOU", kwargs_list, warmup_iters=1)
# Sampling based method
num_samples = [2000, 5000]
kwargs_list = []
test_cases = product([1, 4], [1, 4], num_samples)
test_cases = product(N, N, num_samples)
for case in test_cases:
n, m, s = case
kwargs_list.append({"N": n, "M": m, "num_samples": s})
kwargs_list.append({"N": n, "M": m, "num_samples": s, "device": "cuda:0"})
benchmark(TestIoU3D.iou_sampling, "3D_IOU_SAMPLING", kwargs_list, warmup_iters=1)

Binary file not shown.

View File

@ -10,13 +10,28 @@ from typing import List, Tuple, Union
import torch
import torch.nn.functional as F
from common_testing import TestCaseMixin
from common_testing import TestCaseMixin, get_random_cuda_device, get_tests_dir
from pytorch3d.io import save_obj
from pytorch3d.ops.iou_box3d import box3d_overlap
from pytorch3d.ops.iou_box3d import _box_planes, _box_triangles, box3d_overlap
from pytorch3d.transforms.rotation_conversions import random_rotation
OBJECTRON_TO_PYTORCH3D_FACE_IDX = [0, 4, 6, 2, 1, 5, 7, 3]
DATA_DIR = get_tests_dir() / "data"
DEBUG = False
UNIT_BOX = [
[0, 0, 0],
[1, 0, 0],
[1, 1, 0],
[0, 1, 0],
[0, 0, 1],
[1, 0, 1],
[1, 1, 1],
[0, 1, 1],
]
class TestIoU3D(TestCaseMixin, unittest.TestCase):
def setUp(self) -> None:
super().setUp()
@ -78,16 +93,7 @@ class TestIoU3D(TestCaseMixin, unittest.TestCase):
def _test_iou(self, overlap_fn, device):
box1 = torch.tensor(
[
[0, 0, 0],
[1, 0, 0],
[1, 1, 0],
[0, 1, 0],
[0, 0, 1],
[1, 0, 1],
[1, 1, 1],
[0, 1, 1],
],
UNIT_BOX,
dtype=torch.float32,
device=device,
)
@ -126,6 +132,10 @@ class TestIoU3D(TestCaseMixin, unittest.TestCase):
),
)
# Also check IoU is 1 when computing overlap with the same shifted box
vol, iou = overlap_fn(box2[None], box2[None])
self.assertClose(iou, torch.tensor([[1.0]], device=vol.device, dtype=vol.dtype))
# 5th test
ddx, ddy, ddz = random.random(), random.random(), random.random()
box2 = box1 + torch.tensor([[ddx, ddy, ddz]], device=device)
@ -207,15 +217,15 @@ class TestIoU3D(TestCaseMixin, unittest.TestCase):
# create box1
ctrs = torch.rand((2, 3), device=device)
whl = torch.rand((2, 3), device=device) * 10.0 + 1.0
# box1 & box2
box1 = self.create_box(ctrs[0], whl[0])
box2 = self.create_box(ctrs[1], whl[1])
# box8a & box8b
box8a = self.create_box(ctrs[0], whl[0])
box8b = self.create_box(ctrs[1], whl[1])
RR1 = random_rotation(dtype=torch.float32, device=device)
TT1 = torch.rand((1, 3), dtype=torch.float32, device=device)
RR2 = random_rotation(dtype=torch.float32, device=device)
TT2 = torch.rand((1, 3), dtype=torch.float32, device=device)
box1r = box1 @ RR1.transpose(0, 1) + TT1
box2r = box2 @ RR2.transpose(0, 1) + TT2
box1r = box8a @ RR1.transpose(0, 1) + TT1
box2r = box8b @ RR2.transpose(0, 1) + TT2
vol, iou = overlap_fn(box1r[None], box2r[None])
iou_sampling = self._box3d_overlap_sampling_batched(
box1r[None], box2r[None], num_samples=10000
@ -229,27 +239,90 @@ class TestIoU3D(TestCaseMixin, unittest.TestCase):
self.assertClose(vol, torch.tensor([[0.0]], device=vol.device, dtype=vol.dtype))
self.assertClose(iou, torch.tensor([[0.0]], device=vol.device, dtype=vol.dtype))
# 10th test: Non coplanar verts in a plane
box10 = box1 + torch.rand((8, 3), dtype=torch.float32, device=device)
msg = "Plane vertices are not coplanar"
with self.assertRaisesRegex(ValueError, msg):
overlap_fn(box10[None], box10[None])
# 11th test: Skewed bounding boxes but all verts are coplanar
box_skew_1 = torch.tensor(
[
[0, 0, 0],
[1, 0, 0],
[1, 1, 0],
[0, 1, 0],
[-2, -2, 2],
[2, -2, 2],
[2, 2, 2],
[-2, 2, 2],
],
dtype=torch.float32,
device=device,
)
box_skew_2 = torch.tensor(
[
[2.015995, 0.695233, 2.152806],
[2.832533, 0.663448, 1.576389],
[2.675445, -0.309592, 1.407520],
[1.858907, -0.277806, 1.983936],
[-0.413922, 3.161758, 2.044343],
[2.852230, 3.034615, -0.261321],
[2.223878, -0.857545, -0.936800],
[-1.042273, -0.730402, 1.368864],
],
dtype=torch.float32,
device=device,
)
vol1 = 14.000
vol2 = 14.000005
vol_inters = 5.431122
iou = vol_inters / (vol1 + vol2 - vol_inters)
vols, ious = overlap_fn(box_skew_1[None], box_skew_2[None])
self.assertClose(vols, torch.tensor([[vol_inters]], device=device), atol=1e-1)
self.assertClose(ious, torch.tensor([[iou]], device=device), atol=1e-1)
def test_iou_naive(self):
device = torch.device("cuda:0")
device = get_random_cuda_device()
self._test_iou(self._box3d_overlap_naive_batched, device)
self._test_compare_objectron(self._box3d_overlap_naive_batched, device)
def test_iou_cpu(self):
device = torch.device("cpu")
self._test_iou(box3d_overlap, device)
self._test_compare_objectron(box3d_overlap, device)
def test_cpu_vs_naive_batched(self):
N, M = 3, 6
device = "cpu"
boxes1 = torch.randn((N, 8, 3), device=device)
boxes2 = torch.randn((M, 8, 3), device=device)
vol1, iou1 = self._box3d_overlap_naive_batched(boxes1, boxes2)
vol2, iou2 = box3d_overlap(boxes1, boxes2)
# check shape
for val in [vol1, vol2, iou1, iou2]:
self.assertClose(val.shape, (N, M))
# check values
self.assertClose(vol1, vol2)
self.assertClose(iou1, iou2)
def test_iou_cuda(self):
device = torch.device("cuda:0")
self._test_iou(box3d_overlap, device)
self._test_compare_objectron(box3d_overlap, device)
def _test_compare_objectron(self, overlap_fn, device):
# Load saved objectron data
data_filename = "./objectron_vols_ious.pt"
objectron_vals = torch.load(DATA_DIR / data_filename)
boxes1 = objectron_vals["boxes1"]
boxes2 = objectron_vals["boxes2"]
vols_objectron = objectron_vals["vols"]
ious_objectron = objectron_vals["ious"]
boxes1 = boxes1.to(device=device, dtype=torch.float32)
boxes2 = boxes2.to(device=device, dtype=torch.float32)
# Convert vertex orderings from Objectron to PyTorch3D convention
idx = torch.tensor(
OBJECTRON_TO_PYTORCH3D_FACE_IDX, dtype=torch.int64, device=device
)
boxes1 = boxes1.index_select(index=idx, dim=1)
boxes2 = boxes2.index_select(index=idx, dim=1)
# Run PyTorch3D version
vols, ious = overlap_fn(boxes1, boxes2)
# Check values match
self.assertClose(vols_objectron, vols.cpu())
self.assertClose(ious_objectron, ious.cpu())
def test_batched_errors(self):
N, M = 5, 10
@ -316,16 +389,7 @@ class TestIoU3D(TestCaseMixin, unittest.TestCase):
def test_box_planar_dir(self):
device = torch.device("cuda:0")
box1 = torch.tensor(
[
[0, 0, 0],
[1, 0, 0],
[1, 1, 0],
[0, 1, 0],
[0, 0, 1],
[1, 0, 1],
[1, 1, 1],
[0, 1, 1],
],
UNIT_BOX,
dtype=torch.float32,
device=device,
)
@ -353,8 +417,13 @@ class TestIoU3D(TestCaseMixin, unittest.TestCase):
@staticmethod
def iou_naive(N: int, M: int, device="cpu"):
boxes1 = torch.randn((N, 8, 3))
boxes2 = torch.randn((M, 8, 3))
box = torch.tensor(
[UNIT_BOX],
dtype=torch.float32,
device=device,
)
boxes1 = box + torch.randn((N, 1, 3), device=device)
boxes2 = box + torch.randn((M, 1, 3), device=device)
def output():
vol, iou = TestIoU3D._box3d_overlap_naive_batched(boxes1, boxes2)
@ -363,8 +432,13 @@ class TestIoU3D(TestCaseMixin, unittest.TestCase):
@staticmethod
def iou(N: int, M: int, device="cpu"):
boxes1 = torch.randn((N, 8, 3), device=device)
boxes2 = torch.randn((M, 8, 3), device=device)
box = torch.tensor(
[UNIT_BOX],
dtype=torch.float32,
device=device,
)
boxes1 = box + torch.randn((N, 1, 3), device=device)
boxes2 = box + torch.randn((M, 1, 3), device=device)
def output():
vol, iou = box3d_overlap(boxes1, boxes2)
@ -372,9 +446,14 @@ class TestIoU3D(TestCaseMixin, unittest.TestCase):
return output
@staticmethod
def iou_sampling(N: int, M: int, num_samples: int):
boxes1 = torch.randn((N, 8, 3))
boxes2 = torch.randn((M, 8, 3))
def iou_sampling(N: int, M: int, num_samples: int, device="cpu"):
box = torch.tensor(
[UNIT_BOX],
dtype=torch.float32,
device=device,
)
boxes1 = box + torch.randn((N, 1, 3), device=device)
boxes2 = box + torch.randn((M, 1, 3), device=device)
def output():
_ = TestIoU3D._box3d_overlap_sampling_batched(boxes1, boxes2, num_samples)
@ -408,38 +487,6 @@ Note that both implementations currently do not support batching.
#
# -------------------------------------------------- #
# -------------------------------------------------- #
# CONSTANTS #
# -------------------------------------------------- #
"""
_box_planes and _box_triangles define the 4- and 3-connectivity
of the 8 box corners.
_box_planes gives the quad faces of the 3D box
_box_triangles gives the triangle faces of the 3D box
"""
_box_planes = [
[0, 1, 2, 3],
[3, 2, 6, 7],
[0, 1, 5, 4],
[0, 3, 7, 4],
[1, 5, 6, 2],
[4, 5, 6, 7],
]
_box_triangles = [
[0, 1, 2],
[0, 3, 2],
[4, 5, 6],
[4, 6, 7],
[1, 5, 6],
[1, 6, 2],
[0, 4, 7],
[0, 7, 3],
[3, 2, 6],
[3, 6, 7],
[0, 1, 5],
[0, 4, 5],
]
# -------------------------------------------------- #
# HELPER FUNCTIONS FOR EXACT SOLUTION #
# -------------------------------------------------- #
@ -477,7 +524,7 @@ def get_plane_verts(box: torch.Tensor) -> torch.Tensor:
return plane_verts
def box_planar_dir(box: torch.Tensor) -> torch.Tensor:
def box_planar_dir(box: torch.Tensor, eps=1e-4) -> torch.Tensor:
"""
Finds the unit vector n which is perpendicular to each plane in the box
and points towards the inside of the box.
@ -507,6 +554,11 @@ def box_planar_dir(box: torch.Tensor) -> torch.Tensor:
e1 = F.normalize(v2 - v0, dim=-1)
n = F.normalize(torch.cross(e0, e1, dim=-1), dim=-1)
# Check all verts are coplanar
if not ((v3 - v0).unsqueeze(1).bmm(n.unsqueeze(2)).abs() < eps).all().item():
msg = "Plane vertices are not coplanar"
raise ValueError(msg)
# We can write: `ctr = v0 + a * e0 + b * e1 + c * n`, (1).
# With <e0, n> = 0 and <e1, n> = 0, where <.,.> refers to the dot product,
# since that e0 is orthogonal to n. Same for e1.
@ -733,10 +785,10 @@ def clip_tri_by_plane_oneout(
device = plane.device
# point of intersection between plane and (vin1, vout)
pint1, a1 = plane_edge_point_of_intersection(plane, n, vin1, vout)
assert a1 >= eps and a1 <= 1.0
assert a1 >= eps and a1 <= 1.0, a1
# point of intersection between plane and (vin2, vout)
pint2, a2 = plane_edge_point_of_intersection(plane, n, vin2, vout)
assert a2 >= 0.0 and a2 <= 1.0
assert a2 >= 0.0 and a2 <= 1.0, a2
verts = torch.stack((vin1, pint1, pint2, vin2), dim=0) # 4x3
faces = torch.tensor(
@ -771,10 +823,10 @@ def clip_tri_by_plane_twoout(
device = plane.device
# point of intersection between plane and (vin, vout1)
pint1, a1 = plane_edge_point_of_intersection(plane, n, vin, vout1)
assert a1 >= eps and a1 <= 1.0
assert a1 >= eps and a1 <= 1.0, a1
# point of intersection between plane and (vin, vout2)
pint2, a2 = plane_edge_point_of_intersection(plane, n, vin, vout2)
assert a2 >= eps and a2 <= 1.0
assert a2 >= eps and a2 <= 1.0, a2
verts = torch.stack((vin, pint1, pint2), dim=0) # 3x3
faces = torch.tensor(
@ -945,7 +997,7 @@ def box3d_overlap_naive(box1: torch.Tensor, box2: torch.Tensor):
iou = vol / (vol1 + vol2 - vol)
if 0:
if DEBUG:
# save shapes
tri_faces = torch.tensor(_box_triangles, device=device, dtype=torch.int64)
save_obj("/tmp/output/shape1.obj", box1, tri_faces)