mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-01 03:12:49 +08:00
(new) CUDA IoU for 3D boxes
Summary: CUDA implementation of 3D bounding box overlap calculation. Reviewed By: gkioxari Differential Revision: D31157919 fbshipit-source-id: 5dc89805d01fef2d6779f00a33226131e39c43ed
This commit is contained in:
parent
53266ec9ff
commit
ff8d4762f4
176
pytorch3d/csrc/iou_box3d/iou_box3d.cu
Normal file
176
pytorch3d/csrc/iou_box3d/iou_box3d.cu
Normal file
@ -0,0 +1,176 @@
|
||||
/*
|
||||
* Copyright (c) Facebook, Inc. and its affiliates.
|
||||
* All rights reserved.
|
||||
*
|
||||
* This source code is licensed under the BSD-style license found in the
|
||||
* LICENSE file in the root directory of this source tree.
|
||||
*/
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <math.h>
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include <thrust/device_vector.h>
|
||||
#include <thrust/tuple.h>
|
||||
#include "iou_box3d/iou_utils.cuh"
|
||||
#include "utils/pytorch3d_cutils.h"
|
||||
|
||||
// Parallelize over N*M computations which can each be done
|
||||
// independently
|
||||
__global__ void IoUBox3DKernel(
|
||||
const at::PackedTensorAccessor64<float, 3, at::RestrictPtrTraits> boxes1,
|
||||
const at::PackedTensorAccessor64<float, 3, at::RestrictPtrTraits> boxes2,
|
||||
at::PackedTensorAccessor64<float, 2, at::RestrictPtrTraits> vols,
|
||||
at::PackedTensorAccessor64<float, 2, at::RestrictPtrTraits> ious) {
|
||||
const size_t N = boxes1.size(0);
|
||||
const size_t M = boxes2.size(0);
|
||||
|
||||
const size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
const size_t stride = gridDim.x * blockDim.x;
|
||||
|
||||
for (size_t i = tid; i < N * M; i += stride) {
|
||||
const size_t n = i / M; // box1 index
|
||||
const size_t m = i % M; // box2 index
|
||||
|
||||
// Convert to array of structs of face vertices i.e. effectively (F, 3, 3)
|
||||
// FaceVerts is a data type defined in iou_utils.cuh
|
||||
FaceVerts box1_tris[NUM_TRIS];
|
||||
FaceVerts box2_tris[NUM_TRIS];
|
||||
GetBoxTris(boxes1[n], box1_tris);
|
||||
GetBoxTris(boxes2[m], box2_tris);
|
||||
|
||||
// Calculate the position of the center of the box which is used in
|
||||
// several calculations. This requires a tensor as input.
|
||||
const float3 box1_center = BoxCenter(boxes1[n]);
|
||||
const float3 box2_center = BoxCenter(boxes2[m]);
|
||||
|
||||
// Convert to an array of face vertices
|
||||
FaceVerts box1_planes[NUM_PLANES];
|
||||
GetBoxPlanes(boxes1[n], box1_planes);
|
||||
FaceVerts box2_planes[NUM_PLANES];
|
||||
GetBoxPlanes(boxes2[m], box2_planes);
|
||||
|
||||
// Get Box Volumes
|
||||
const float box1_vol = BoxVolume(box1_tris, box1_center, NUM_TRIS);
|
||||
const float box2_vol = BoxVolume(box2_tris, box2_center, NUM_TRIS);
|
||||
|
||||
// Tris in Box1 intersection with Planes in Box2
|
||||
// Initialize box1 intersecting faces. MAX_TRIS is the
|
||||
// max faces possible in the intersecting shape.
|
||||
// TODO: determine if the value of MAX_TRIS is sufficient or
|
||||
// if we should store the max tris for each NxM computation
|
||||
// and throw an error if any exceeds the max.
|
||||
FaceVerts box1_intersect[MAX_TRIS];
|
||||
for (int j = 0; j < NUM_TRIS; ++j) {
|
||||
// Initialize the faces from the box
|
||||
box1_intersect[j] = box1_tris[j];
|
||||
}
|
||||
// Get the count of the actual number of faces in the intersecting shape
|
||||
int box1_count = BoxIntersections(box2_planes, box2_center, box1_intersect);
|
||||
|
||||
// Tris in Box2 intersection with Planes in Box1
|
||||
FaceVerts box2_intersect[MAX_TRIS];
|
||||
for (int j = 0; j < NUM_TRIS; ++j) {
|
||||
box2_intersect[j] = box2_tris[j];
|
||||
}
|
||||
const int box2_count =
|
||||
BoxIntersections(box1_planes, box1_center, box2_intersect);
|
||||
|
||||
// If there are overlapping regions in Box2, remove any coplanar faces
|
||||
if (box2_count > 0) {
|
||||
// Identify if any triangles in Box2 are coplanar with Box1
|
||||
Keep tri2_keep[MAX_TRIS];
|
||||
for (int j = 0; j < MAX_TRIS; ++j) {
|
||||
// Initialize the valid faces to be true
|
||||
tri2_keep[j].keep = j < box2_count ? true : false;
|
||||
}
|
||||
for (int b1 = 0; b1 < box1_count; ++b1) {
|
||||
for (int b2 = 0; b2 < box2_count; ++b2) {
|
||||
const bool is_coplanar =
|
||||
IsCoplanarFace(box1_intersect[b1], box2_intersect[b2]);
|
||||
if (is_coplanar) {
|
||||
tri2_keep[b2].keep = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Keep only the non coplanar triangles in Box2 - add them to the
|
||||
// Box1 triangles.
|
||||
for (int b2 = 0; b2 < box2_count; ++b2) {
|
||||
if (tri2_keep[b2].keep) {
|
||||
box1_intersect[box1_count] = box2_intersect[b2];
|
||||
// box1_count will determine the total faces in the
|
||||
// intersecting shape
|
||||
box1_count++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Initialize the vol and iou to 0.0 in case there are no triangles
|
||||
// in the intersecting shape.
|
||||
float vol = 0.0;
|
||||
float iou = 0.0;
|
||||
|
||||
// If there are triangles in the intersecting shape
|
||||
if (box1_count > 0) {
|
||||
// The intersecting shape is a polyhedron made up of the
|
||||
// triangular faces that are all now in box1_intersect.
|
||||
// Calculate the polyhedron center
|
||||
const float3 poly_center = PolyhedronCenter(box1_intersect, box1_count);
|
||||
// Compute intersecting polyhedron volume
|
||||
vol = BoxVolume(box1_intersect, poly_center, box1_count);
|
||||
// Compute IoU
|
||||
iou = vol / (box1_vol + box2_vol - vol);
|
||||
}
|
||||
|
||||
// Write the volume and IoU to global memory
|
||||
vols[n][m] = vol;
|
||||
ious[n][m] = iou;
|
||||
}
|
||||
}
|
||||
|
||||
std::tuple<at::Tensor, at::Tensor> IoUBox3DCuda(
|
||||
const at::Tensor& boxes1, // (N, 8, 3)
|
||||
const at::Tensor& boxes2) { // (M, 8, 3)
|
||||
// Check inputs are on the same device
|
||||
at::TensorArg boxes1_t{boxes1, "boxes1", 1}, boxes2_t{boxes2, "boxes2", 2};
|
||||
at::CheckedFrom c = "IoUBox3DCuda";
|
||||
at::checkAllSameGPU(c, {boxes1_t, boxes2_t});
|
||||
at::checkAllSameType(c, {boxes1_t, boxes2_t});
|
||||
|
||||
// Set the device for the kernel launch based on the device of boxes1
|
||||
at::cuda::CUDAGuard device_guard(boxes1.device());
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
TORCH_CHECK(boxes2.size(2) == boxes1.size(2), "Boxes must have shape (8, 3)");
|
||||
|
||||
TORCH_CHECK(
|
||||
(boxes2.size(1) == 8) && (boxes1.size(1) == 8),
|
||||
"Boxes must have shape (8, 3)");
|
||||
|
||||
const int64_t N = boxes1.size(0);
|
||||
const int64_t M = boxes2.size(0);
|
||||
|
||||
auto vols = at::zeros({N, M}, boxes1.options());
|
||||
auto ious = at::zeros({N, M}, boxes1.options());
|
||||
|
||||
if (vols.numel() == 0) {
|
||||
AT_CUDA_CHECK(cudaGetLastError());
|
||||
return std::make_tuple(vols, ious);
|
||||
}
|
||||
|
||||
const size_t blocks = 512;
|
||||
const size_t threads = 256;
|
||||
|
||||
IoUBox3DKernel<<<blocks, threads, 0, stream>>>(
|
||||
boxes1.packed_accessor64<float, 3, at::RestrictPtrTraits>(),
|
||||
boxes2.packed_accessor64<float, 3, at::RestrictPtrTraits>(),
|
||||
vols.packed_accessor64<float, 2, at::RestrictPtrTraits>(),
|
||||
ious.packed_accessor64<float, 2, at::RestrictPtrTraits>());
|
||||
|
||||
AT_CUDA_CHECK(cudaGetLastError());
|
||||
|
||||
return std::make_tuple(vols, ious);
|
||||
}
|
@ -26,12 +26,23 @@ std::tuple<at::Tensor, at::Tensor> IoUBox3DCpu(
|
||||
const at::Tensor& boxes1,
|
||||
const at::Tensor& boxes2);
|
||||
|
||||
// CUDA implementation
|
||||
std::tuple<at::Tensor, at::Tensor> IoUBox3DCuda(
|
||||
const at::Tensor& boxes1,
|
||||
const at::Tensor& boxes2);
|
||||
|
||||
// Implementation which is exposed
|
||||
inline std::tuple<at::Tensor, at::Tensor> IoUBox3D(
|
||||
const at::Tensor& boxes1,
|
||||
const at::Tensor& boxes2) {
|
||||
if (boxes1.is_cuda() || boxes2.is_cuda()) {
|
||||
AT_ERROR("GPU support not implemented");
|
||||
#ifdef WITH_CUDA
|
||||
CHECK_CUDA(boxes1);
|
||||
CHECK_CUDA(boxes2);
|
||||
return IoUBox3DCuda(boxes1.contiguous(), boxes2.contiguous());
|
||||
#else
|
||||
AT_ERROR("Not compiled with GPU support.");
|
||||
#endif
|
||||
}
|
||||
return IoUBox3DCpu(boxes1.contiguous(), boxes2.contiguous());
|
||||
}
|
||||
|
@ -79,7 +79,7 @@ std::tuple<at::Tensor, at::Tensor> IoUBox3DCpu(
|
||||
std::fill(tri2_keep.begin(), tri2_keep.end(), 1);
|
||||
for (int b1 = 0; b1 < box1_intersect.size(); ++b1) {
|
||||
for (int b2 = 0; b2 < box2_intersect.size(); ++b2) {
|
||||
bool is_coplanar =
|
||||
const bool is_coplanar =
|
||||
IsCoplanarFace(box1_intersect[b1], box2_intersect[b2]);
|
||||
if (is_coplanar) {
|
||||
tri2_keep[b2] = 0;
|
||||
|
584
pytorch3d/csrc/iou_box3d/iou_utils.cuh
Normal file
584
pytorch3d/csrc/iou_box3d/iou_utils.cuh
Normal file
@ -0,0 +1,584 @@
|
||||
/*
|
||||
* Copyright (c) Facebook, Inc. and its affiliates.
|
||||
* All rights reserved.
|
||||
*
|
||||
* This source code is licensed under the BSD-style license found in the
|
||||
* LICENSE file in the root directory of this source tree.
|
||||
*/
|
||||
|
||||
#include <float.h>
|
||||
#include <math.h>
|
||||
#include <thrust/device_vector.h>
|
||||
#include <cstdio>
|
||||
#include "utils/float_math.cuh"
|
||||
#include "utils/geometry_utils.cuh"
|
||||
|
||||
/*
|
||||
_PLANES and _TRIS define the 4- and 3-connectivity
|
||||
of the 8 box corners.
|
||||
_PLANES gives the quad faces of the 3D box
|
||||
_TRIS gives the triangle faces of the 3D box
|
||||
*/
|
||||
const int NUM_PLANES = 6;
|
||||
const int NUM_TRIS = 12;
|
||||
// This is required for iniitalizing the faces
|
||||
// in the intersecting shape
|
||||
const int MAX_TRIS = 100;
|
||||
|
||||
// Create data types for representing the
|
||||
// verts for each face and the indices.
|
||||
// We will use struct arrays for representing
|
||||
// the data for each box and intersecting
|
||||
// triangles
|
||||
typedef struct {
|
||||
float3 v0;
|
||||
float3 v1;
|
||||
float3 v2;
|
||||
float3 v3; // Can be empty for triangles
|
||||
} FaceVerts;
|
||||
|
||||
typedef struct {
|
||||
int v0;
|
||||
int v1;
|
||||
int v2;
|
||||
int v3; // Can be empty for triangles
|
||||
} FaceVertsIdx;
|
||||
|
||||
// This is used when deciding which faces to
|
||||
// keep that are not coplanar
|
||||
typedef struct {
|
||||
bool keep;
|
||||
} Keep;
|
||||
|
||||
__device__ FaceVertsIdx _PLANES[] = {
|
||||
{0, 1, 2, 3},
|
||||
{3, 2, 6, 7},
|
||||
{0, 1, 5, 4},
|
||||
{0, 3, 7, 4},
|
||||
{1, 5, 6, 2},
|
||||
{4, 5, 6, 7},
|
||||
};
|
||||
__device__ FaceVertsIdx _TRIS[] = {
|
||||
{0, 1, 2},
|
||||
{0, 3, 2},
|
||||
{4, 5, 6},
|
||||
{4, 6, 7},
|
||||
{1, 5, 6},
|
||||
{1, 6, 2},
|
||||
{0, 4, 7},
|
||||
{0, 7, 3},
|
||||
{3, 2, 6},
|
||||
{3, 6, 7},
|
||||
{0, 1, 5},
|
||||
{0, 4, 5},
|
||||
};
|
||||
|
||||
// Args
|
||||
// box: (8, 3) tensor accessor for the box vertices
|
||||
// box_tris: Array of structs of type FaceVerts,
|
||||
// effectively (F, 3, 3) where the coordinates of the
|
||||
// verts for each face will be saved to.
|
||||
//
|
||||
// Returns: None (output saved to box_tris)
|
||||
//
|
||||
template <typename Box, typename BoxTris>
|
||||
__device__ inline void GetBoxTris(const Box& box, BoxTris& box_tris) {
|
||||
for (int t = 0; t < NUM_TRIS; ++t) {
|
||||
const float3 v0 = make_float3(
|
||||
box[_TRIS[t].v0][0], box[_TRIS[t].v0][1], box[_TRIS[t].v0][2]);
|
||||
const float3 v1 = make_float3(
|
||||
box[_TRIS[t].v1][0], box[_TRIS[t].v1][1], box[_TRIS[t].v1][2]);
|
||||
const float3 v2 = make_float3(
|
||||
box[_TRIS[t].v2][0], box[_TRIS[t].v2][1], box[_TRIS[t].v2][2]);
|
||||
box_tris[t] = {v0, v1, v2};
|
||||
}
|
||||
}
|
||||
|
||||
// Args
|
||||
// box: (8, 3) tensor accessor for the box vertices
|
||||
// box_planes: Array of structs of type FaceVerts, effectively (P, 4, 3)
|
||||
// where the coordinates of the verts for the four corners of each plane
|
||||
// will be saved to
|
||||
//
|
||||
// Returns: None (output saved to box_planes)
|
||||
//
|
||||
template <typename Box, typename FaceVertsBoxPlanes>
|
||||
__device__ inline void GetBoxPlanes(
|
||||
const Box& box,
|
||||
FaceVertsBoxPlanes& box_planes) {
|
||||
for (int t = 0; t < NUM_PLANES; ++t) {
|
||||
const float3 v0 = make_float3(
|
||||
box[_PLANES[t].v0][0], box[_PLANES[t].v0][1], box[_PLANES[t].v0][2]);
|
||||
const float3 v1 = make_float3(
|
||||
box[_PLANES[t].v1][0], box[_PLANES[t].v1][1], box[_PLANES[t].v1][2]);
|
||||
const float3 v2 = make_float3(
|
||||
box[_PLANES[t].v2][0], box[_PLANES[t].v2][1], box[_PLANES[t].v2][2]);
|
||||
const float3 v3 = make_float3(
|
||||
box[_PLANES[t].v3][0], box[_PLANES[t].v3][1], box[_PLANES[t].v3][2]);
|
||||
box_planes[t] = {v0, v1, v2, v3};
|
||||
}
|
||||
}
|
||||
|
||||
// The normal of the face defined by vertices (v0, v1, v2)
|
||||
// Define e0 to be the edge connecting (v1, v0)
|
||||
// Define e1 to be the edge connecting (v2, v0)
|
||||
// normal is the cross product of e0, e1
|
||||
//
|
||||
// Args
|
||||
// v0, v1, v2: float3 coordinates of the vertices of the face
|
||||
//
|
||||
// Returns
|
||||
// float3: normal for the face
|
||||
//
|
||||
__device__ inline float3
|
||||
FaceNormal(const float3 v0, const float3 v1, const float3 v2) {
|
||||
float3 n = cross(v1 - v0, v2 - v0);
|
||||
n = n / fmaxf(norm(n), kEpsilon);
|
||||
return n;
|
||||
}
|
||||
|
||||
// The normal of a box plane defined by the verts in `plane` with
|
||||
// the centroid of the box given by `center`.
|
||||
// Args
|
||||
// plane: float3 coordinates of the vertices of the plane
|
||||
// center: float3 coordinates of the center of the box from
|
||||
// which the plane originated
|
||||
//
|
||||
// Returns
|
||||
// float3: normal for the plane such that it points towards
|
||||
// the center of the box
|
||||
//
|
||||
template <typename FaceVertsPlane>
|
||||
__device__ inline float3 PlaneNormalDirection(
|
||||
const FaceVertsPlane& plane,
|
||||
const float3& center) {
|
||||
// Only need the first 3 verts of the plane
|
||||
const float3 v0 = plane.v0;
|
||||
const float3 v1 = plane.v1;
|
||||
const float3 v2 = plane.v2;
|
||||
|
||||
// We project the center on the plane defined by (v0, v1, v2)
|
||||
// We can write center = v0 + a * e0 + b * e1 + c * n
|
||||
// We know that <e0, n> = 0 and <e1, n> = 0 and
|
||||
// <a, b> is the dot product between a and b.
|
||||
// This means we can solve for c as:
|
||||
// c = <center - v0 - a * e0 - b * e1, n> = <center - v0, n>
|
||||
float3 n = FaceNormal(v0, v1, v2);
|
||||
const float c = dot((center - v0), n);
|
||||
|
||||
// If c is negative, then we revert the direction of n such that n
|
||||
// points "inside"
|
||||
if (c < kEpsilon) {
|
||||
n = -1.0f * n;
|
||||
}
|
||||
|
||||
return n;
|
||||
}
|
||||
|
||||
// Calculate the volume of the box by summing the volume of
|
||||
// each of the tetrahedrons formed with a triangle face and
|
||||
// the box centroid.
|
||||
//
|
||||
// Args
|
||||
// box_tris: vector of float3 coordinates of the vertices of each
|
||||
// of the triangles in the box
|
||||
// box_center: float3 coordinates of the center of the box
|
||||
//
|
||||
// Returns
|
||||
// float: volume of the box
|
||||
//
|
||||
template <typename BoxTris>
|
||||
__device__ inline float BoxVolume(
|
||||
const BoxTris& box_tris,
|
||||
const float3& box_center,
|
||||
const int num_tris) {
|
||||
float box_vol = 0.0;
|
||||
// Iterate through each triange, calculate the area of the
|
||||
// tetrahedron formed with the box_center and sum them
|
||||
for (int t = 0; t < num_tris; ++t) {
|
||||
// Subtract the center:
|
||||
float3 v0 = box_tris[t].v0;
|
||||
float3 v1 = box_tris[t].v1;
|
||||
float3 v2 = box_tris[t].v2;
|
||||
|
||||
v0 = v0 - box_center;
|
||||
v1 = v1 - box_center;
|
||||
v2 = v2 - box_center;
|
||||
|
||||
// Compute the area
|
||||
const float area = dot(v0, cross(v1, v2));
|
||||
const float vol = abs(area) / 6.0;
|
||||
box_vol = box_vol + vol;
|
||||
}
|
||||
return box_vol;
|
||||
}
|
||||
|
||||
// Compute the box center as the mean of the verts
|
||||
//
|
||||
// Args
|
||||
// box_verts: (8, 3) tensor of the corner vertices of the box
|
||||
//
|
||||
// Returns
|
||||
// float3: coordinates of the center of the box
|
||||
//
|
||||
template <typename Box>
|
||||
__device__ inline float3 BoxCenter(const Box box_verts) {
|
||||
float x = 0.0;
|
||||
float y = 0.0;
|
||||
float z = 0.0;
|
||||
const int num_verts = box_verts.size(0); // Should be 8
|
||||
// Sum all x, y, z, and take the mean
|
||||
for (int t = 0; t < num_verts; ++t) {
|
||||
x = x + box_verts[t][0];
|
||||
y = y + box_verts[t][1];
|
||||
z = z + box_verts[t][2];
|
||||
}
|
||||
// Take the mean of all the vertex positions
|
||||
x = x / num_verts;
|
||||
y = y / num_verts;
|
||||
z = z / num_verts;
|
||||
const float3 center = make_float3(x, y, z);
|
||||
return center;
|
||||
}
|
||||
|
||||
// Compute the polyhedron center as the mean of the face centers
|
||||
// of the triangle faces
|
||||
//
|
||||
// Args
|
||||
// tris: vector of float3 coordinates of the
|
||||
// vertices of each of the triangles in the polyhedron
|
||||
//
|
||||
// Returns
|
||||
// float3: coordinates of the center of the polyhedron
|
||||
//
|
||||
template <typename Tris>
|
||||
__device__ inline float3 PolyhedronCenter(
|
||||
const Tris& tris,
|
||||
const int num_tris) {
|
||||
float x = 0.0;
|
||||
float y = 0.0;
|
||||
float z = 0.0;
|
||||
|
||||
// Find the center point of each face
|
||||
for (int t = 0; t < num_tris; ++t) {
|
||||
const float3 v0 = tris[t].v0;
|
||||
const float3 v1 = tris[t].v1;
|
||||
const float3 v2 = tris[t].v2;
|
||||
const float x_face = (v0.x + v1.x + v2.x) / 3.0;
|
||||
const float y_face = (v0.y + v1.y + v2.y) / 3.0;
|
||||
const float z_face = (v0.z + v1.z + v2.z) / 3.0;
|
||||
x = x + x_face;
|
||||
y = y + y_face;
|
||||
z = z + z_face;
|
||||
}
|
||||
|
||||
// Take the mean of the centers of all faces
|
||||
x = x / num_tris;
|
||||
y = y / num_tris;
|
||||
z = z / num_tris;
|
||||
|
||||
const float3 center = make_float3(x, y, z);
|
||||
return center;
|
||||
}
|
||||
|
||||
// Compute a boolean indicator for whether a point
|
||||
// is inside a plane, where inside refers to whether
|
||||
// or not the point has a component in the
|
||||
// normal direction of the plane.
|
||||
//
|
||||
// Args
|
||||
// plane: vector of float3 coordinates of the
|
||||
// vertices of each of the triangles in the box
|
||||
// normal: float3 of the direction of the plane normal
|
||||
// point: float3 of the position of the point of interest
|
||||
//
|
||||
// Returns
|
||||
// bool: whether or not the point is inside the plane
|
||||
//
|
||||
__device__ inline bool
|
||||
IsInside(const FaceVerts& plane, const float3& normal, const float3& point) {
|
||||
// Get one vert of the plane
|
||||
const float3 v0 = plane.v0;
|
||||
|
||||
// Every point p can be written as p = v0 + a e0 + b e1 + c n
|
||||
// Solving for c:
|
||||
// c = (point - v0 - a * e0 - b * e1).dot(n)
|
||||
// We know that <e0, n> = 0 and <e1, n> = 0
|
||||
// So the calculation can be simplified as:
|
||||
const float c = dot((point - v0), normal);
|
||||
const bool inside = c > -1.0f * kEpsilon;
|
||||
return inside;
|
||||
}
|
||||
|
||||
// Find the point of intersection between a plane
|
||||
// and a line given by the end points (p0, p1)
|
||||
//
|
||||
// Args
|
||||
// plane: vector of float3 coordinates of the
|
||||
// vertices of each of the triangles in the box
|
||||
// normal: float3 of the direction of the plane normal
|
||||
// p0, p1: float3 of the start and end point of the line
|
||||
//
|
||||
// Returns
|
||||
// float3: position of the intersection point
|
||||
//
|
||||
__device__ inline float3 PlaneEdgeIntersection(
|
||||
const FaceVerts& plane,
|
||||
const float3& normal,
|
||||
const float3& p0,
|
||||
const float3& p1) {
|
||||
// Get one vert of the plane
|
||||
const float3 v0 = plane.v0;
|
||||
|
||||
// The point of intersection can be parametrized
|
||||
// p = p0 + a (p1 - p0) where a in [0, 1]
|
||||
// We want to find a such that p is on plane
|
||||
// <p - v0, n> = 0
|
||||
const float top = dot(-1.0f * (p0 - v0), normal);
|
||||
const float bot = dot(p1 - p0, normal);
|
||||
const float a = top / bot;
|
||||
const float3 p = p0 + a * (p1 - p0);
|
||||
return p;
|
||||
}
|
||||
|
||||
// Triangle is clipped into a quadrilateral
|
||||
// based on the intersection points with the plane.
|
||||
// Then the quadrilateral is divided into two triangles.
|
||||
//
|
||||
// Args
|
||||
// plane: vector of float3 coordinates of the
|
||||
// vertices of each of the triangles in the box
|
||||
// normal: float3 of the direction of the plane normal
|
||||
// vout: float3 of the point in the triangle which is outside
|
||||
// the plane
|
||||
// vin1, vin2: float3 of the points in the triangle which are
|
||||
// inside the plane
|
||||
// face_verts_out: Array of structs of type FaceVerts,
|
||||
// with the coordinates of the new triangle faces
|
||||
// formed after clipping.
|
||||
// All triangles are now "inside" the plane.
|
||||
//
|
||||
// Returns:
|
||||
// count: (int) number of new faces after clipping the triangle
|
||||
// i.e. the valid faces which have been saved
|
||||
// to face_verts_out
|
||||
//
|
||||
template <typename FaceVertsBox>
|
||||
__device__ inline int ClipTriByPlaneOneOut(
|
||||
const FaceVerts& plane,
|
||||
const float3& normal,
|
||||
const float3& vout,
|
||||
const float3& vin1,
|
||||
const float3& vin2,
|
||||
FaceVertsBox& face_verts_out) {
|
||||
// point of intersection between plane and (vin1, vout)
|
||||
const float3 pint1 = PlaneEdgeIntersection(plane, normal, vin1, vout);
|
||||
// point of intersection between plane and (vin2, vout)
|
||||
const float3 pint2 = PlaneEdgeIntersection(plane, normal, vin2, vout);
|
||||
|
||||
face_verts_out[0] = {vin1, pint1, pint2};
|
||||
face_verts_out[1] = {vin1, pint2, vin2};
|
||||
|
||||
return 2;
|
||||
}
|
||||
|
||||
// Triangle is clipped into a smaller triangle based
|
||||
// on the intersection points with the plane.
|
||||
//
|
||||
// Args
|
||||
// plane: vector of float3 coordinates of the
|
||||
// vertices of each of the triangles in the box
|
||||
// normal: float3 of the direction of the plane normal
|
||||
// vout1, vout2: float3 of the points in the triangle which are
|
||||
// outside the plane
|
||||
// vin: float3 of the point in the triangle which is inside
|
||||
// the plane
|
||||
// face_verts_out: Array of structs of type FaceVerts,
|
||||
// with the coordinates of the new triangle faces
|
||||
// formed after clipping.
|
||||
// All triangles are now "inside" the plane.
|
||||
//
|
||||
// Returns:
|
||||
// count: (int) number of new faces after clipping the triangle
|
||||
// i.e. the valid faces which have been saved
|
||||
// to face_verts_out
|
||||
//
|
||||
template <typename FaceVertsBox>
|
||||
__device__ inline int ClipTriByPlaneTwoOut(
|
||||
const FaceVerts& plane,
|
||||
const float3& normal,
|
||||
const float3& vout1,
|
||||
const float3& vout2,
|
||||
const float3& vin,
|
||||
FaceVertsBox& face_verts_out) {
|
||||
// point of intersection between plane and (vin, vout1)
|
||||
const float3 pint1 = PlaneEdgeIntersection(plane, normal, vin, vout1);
|
||||
// point of intersection between plane and (vin, vout2)
|
||||
const float3 pint2 = PlaneEdgeIntersection(plane, normal, vin, vout2);
|
||||
|
||||
face_verts_out[0] = {vin, pint1, pint2};
|
||||
|
||||
return 1;
|
||||
}
|
||||
|
||||
// Clip the triangle faces so that they lie within the
|
||||
// plane, creating new triangle faces where necessary.
|
||||
//
|
||||
// Args
|
||||
// plane: Array of structs of type FaceVerts with the coordinates
|
||||
// of the vertices of each of the triangles in the box
|
||||
// tri: Array of structs of type FaceVerts with the vertex
|
||||
// coordinates of the triangle faces
|
||||
// normal: float3 of the direction of the plane normal
|
||||
// face_verts_out: Array of structs of type FaceVerts,
|
||||
// with the coordinates of the new triangle faces
|
||||
// formed after clipping.
|
||||
// All triangles are now "inside" the plane.
|
||||
//
|
||||
// Returns:
|
||||
// count: (int) number of new faces after clipping the triangle
|
||||
// i.e. the valid faces which have been saved
|
||||
// to face_verts_out
|
||||
//
|
||||
template <typename FaceVertsBox>
|
||||
__device__ inline int ClipTriByPlane(
|
||||
const FaceVerts& plane,
|
||||
const FaceVerts& tri,
|
||||
const float3& normal,
|
||||
FaceVertsBox& face_verts_out) {
|
||||
// Get Triangle vertices
|
||||
const float3 v0 = tri.v0;
|
||||
const float3 v1 = tri.v1;
|
||||
const float3 v2 = tri.v2;
|
||||
|
||||
// Check each of the triangle vertices to see if it is inside the plane
|
||||
const bool isin0 = IsInside(plane, normal, v0);
|
||||
const bool isin1 = IsInside(plane, normal, v1);
|
||||
const bool isin2 = IsInside(plane, normal, v2);
|
||||
|
||||
// All in
|
||||
if (isin0 && isin1 && isin2) {
|
||||
// Return input vertices
|
||||
face_verts_out[0] = {v0, v1, v2};
|
||||
return 1;
|
||||
}
|
||||
|
||||
// All out
|
||||
if (!isin0 && !isin1 && !isin2) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
// One vert out
|
||||
if (isin0 && isin1 && !isin2) {
|
||||
return ClipTriByPlaneOneOut(plane, normal, v2, v0, v1, face_verts_out);
|
||||
}
|
||||
if (isin0 && not isin1 && isin2) {
|
||||
return ClipTriByPlaneOneOut(plane, normal, v1, v0, v2, face_verts_out);
|
||||
}
|
||||
if (not isin0 && isin1 && isin2) {
|
||||
return ClipTriByPlaneOneOut(plane, normal, v0, v1, v2, face_verts_out);
|
||||
}
|
||||
|
||||
// Two verts out
|
||||
if (isin0 && !isin1 && !isin2) {
|
||||
return ClipTriByPlaneTwoOut(plane, normal, v1, v2, v0, face_verts_out);
|
||||
}
|
||||
if (!isin0 && !isin1 && isin2) {
|
||||
return ClipTriByPlaneTwoOut(plane, normal, v0, v1, v2, face_verts_out);
|
||||
}
|
||||
if (!isin0 && isin1 && !isin2) {
|
||||
return ClipTriByPlaneTwoOut(plane, normal, v0, v2, v1, face_verts_out);
|
||||
}
|
||||
|
||||
// Else return empty (should not be reached)
|
||||
return 0;
|
||||
}
|
||||
|
||||
// Compute a boolean indicator for whether or not two faces
|
||||
// are coplanar
|
||||
//
|
||||
// Args
|
||||
// tri1, tri2: FaceVerts struct of the vertex coordinates of
|
||||
// the triangle face
|
||||
//
|
||||
// Returns
|
||||
// bool: whether or not the two faces are coplanar
|
||||
//
|
||||
__device__ inline bool IsCoplanarFace(
|
||||
const FaceVerts& tri1,
|
||||
const FaceVerts& tri2) {
|
||||
// Get verts for face 1
|
||||
const float3 v0 = tri1.v0;
|
||||
const float3 v1 = tri1.v1;
|
||||
const float3 v2 = tri1.v2;
|
||||
|
||||
const float3 n1 = FaceNormal(v0, v1, v2);
|
||||
int coplanar_count = 0;
|
||||
|
||||
// Check v0, v1, v2
|
||||
if (abs(dot(tri2.v0 - v0, n1)) < kEpsilon) {
|
||||
coplanar_count++;
|
||||
}
|
||||
if (abs(dot(tri2.v1 - v0, n1)) < kEpsilon) {
|
||||
coplanar_count++;
|
||||
}
|
||||
if (abs(dot(tri2.v2 - v0, n1)) < kEpsilon) {
|
||||
coplanar_count++;
|
||||
}
|
||||
return (coplanar_count == 3);
|
||||
}
|
||||
|
||||
// Get the triangles from each box which are part of the
|
||||
// intersecting polyhedron by computing the intersection
|
||||
// points with each of the planes.
|
||||
//
|
||||
// Args
|
||||
// planes: Array of structs of type FaceVerts with the coordinates
|
||||
// of the vertices of each of the triangles in the box
|
||||
// center: float3 coordinates of the center of the box from which
|
||||
// the planes originate
|
||||
// face_verts_out: Array of structs of type FaceVerts,
|
||||
// where the coordinates of the new triangle faces
|
||||
// formed after clipping will be saved to.
|
||||
// All triangles are now "inside" the plane.
|
||||
//
|
||||
// Returns:
|
||||
// count: (int) number of faces in the intersecting shape
|
||||
// i.e. the valid faces which have been saved
|
||||
// to face_verts_out
|
||||
//
|
||||
template <typename FaceVertsPlane, typename FaceVertsBox>
|
||||
__device__ inline int BoxIntersections(
|
||||
const FaceVertsPlane& planes,
|
||||
const float3& center,
|
||||
FaceVertsBox& face_verts_out) {
|
||||
// Initialize num tris to 12
|
||||
int num_tris = NUM_TRIS;
|
||||
for (int p = 0; p < NUM_PLANES; ++p) {
|
||||
// Get plane normal direction
|
||||
const float3 n2 = PlaneNormalDirection(planes[p], center);
|
||||
// Create intermediate vector to store the updated tris
|
||||
FaceVerts tri_verts_updated[MAX_TRIS];
|
||||
int offset = 0;
|
||||
|
||||
// Iterate through triangles in face_verts_out
|
||||
// for the valid tris given by num_tris
|
||||
for (int t = 0; t < num_tris; ++t) {
|
||||
// Clip tri by plane, can max be split into 2 triangles
|
||||
FaceVerts tri_updated[2];
|
||||
const int count =
|
||||
ClipTriByPlane(planes[p], face_verts_out[t], n2, tri_updated);
|
||||
// Add to the tri_verts_updated output if not empty
|
||||
for (int v = 0; v < count; ++v) {
|
||||
tri_verts_updated[offset] = tri_updated[v];
|
||||
offset++;
|
||||
}
|
||||
}
|
||||
// Update the face_verts_out tris
|
||||
num_tris = offset;
|
||||
for (int j = 0; j < num_tris; ++j) {
|
||||
face_verts_out[j] = tri_verts_updated[j];
|
||||
}
|
||||
}
|
||||
return num_tris;
|
||||
}
|
@ -9,6 +9,7 @@ from .cameras_alignment import corresponding_cameras_alignment
|
||||
from .cubify import cubify
|
||||
from .graph_conv import GraphConv
|
||||
from .interp_face_attrs import interpolate_face_attributes
|
||||
from .iou_box3d import box3d_overlap
|
||||
from .knn import knn_gather, knn_points
|
||||
from .laplacian_matrices import cot_laplacian, laplacian, norm_laplacian
|
||||
from .mesh_face_areas_normals import mesh_face_areas_normals
|
||||
|
@ -7,10 +7,68 @@
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from pytorch3d import _C
|
||||
from torch.autograd import Function
|
||||
|
||||
|
||||
# -------------------------------------------------- #
|
||||
# CONSTANTS #
|
||||
# -------------------------------------------------- #
|
||||
"""
|
||||
_box_planes and _box_triangles define the 4- and 3-connectivity
|
||||
of the 8 box corners.
|
||||
_box_planes gives the quad faces of the 3D box
|
||||
_box_triangles gives the triangle faces of the 3D box
|
||||
"""
|
||||
_box_planes = [
|
||||
[0, 1, 2, 3],
|
||||
[3, 2, 6, 7],
|
||||
[0, 1, 5, 4],
|
||||
[0, 3, 7, 4],
|
||||
[1, 2, 6, 5],
|
||||
[4, 5, 6, 7],
|
||||
]
|
||||
_box_triangles = [
|
||||
[0, 1, 2],
|
||||
[0, 3, 2],
|
||||
[4, 5, 6],
|
||||
[4, 6, 7],
|
||||
[1, 5, 6],
|
||||
[1, 6, 2],
|
||||
[0, 4, 7],
|
||||
[0, 7, 3],
|
||||
[3, 2, 6],
|
||||
[3, 6, 7],
|
||||
[0, 1, 5],
|
||||
[0, 4, 5],
|
||||
]
|
||||
|
||||
|
||||
def _check_coplanar(boxes: torch.Tensor, eps: float = 1e-5) -> None:
|
||||
faces = torch.tensor(_box_planes, dtype=torch.int64, device=boxes.device)
|
||||
# pyre-fixme[16]: `boxes` has no attribute `index_select`.
|
||||
verts = boxes.index_select(index=faces.view(-1), dim=1)
|
||||
B = boxes.shape[0]
|
||||
P, V = faces.shape
|
||||
# (B, P, 4, 3) -> (B, P, 3)
|
||||
v0, v1, v2, v3 = verts.reshape(B, P, V, 3).unbind(2)
|
||||
|
||||
# Compute the normal
|
||||
e0 = F.normalize(v1 - v0, dim=-1)
|
||||
e1 = F.normalize(v2 - v0, dim=-1)
|
||||
normal = F.normalize(torch.cross(e0, e1, dim=-1), dim=-1)
|
||||
|
||||
# Check the fourth vertex is also on the same plane
|
||||
mat1 = (v3 - v0).view(B, 1, -1) # (B, 1, P*3)
|
||||
mat2 = normal.view(B, -1, 1) # (B, P*3, 1)
|
||||
if not (mat1.bmm(mat2).abs() < eps).all().item():
|
||||
msg = "Plane vertices are not coplanar"
|
||||
raise ValueError(msg)
|
||||
|
||||
return
|
||||
|
||||
|
||||
class _box3d_overlap(Function):
|
||||
"""
|
||||
Torch autograd Function wrapper for box3d_overlap C++/CUDA implementations.
|
||||
@ -35,6 +93,7 @@ def box3d_overlap(
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Computes the intersection of 3D boxes1 and boxes2.
|
||||
|
||||
Inputs boxes1, boxes2 are tensors of shape (B, 8, 3)
|
||||
(where B doesn't have to be the same for boxes1 and boxes1),
|
||||
containing the 8 corners of the boxes, as follows:
|
||||
@ -47,6 +106,25 @@ def box3d_overlap(
|
||||
` . | ` . |
|
||||
(3) ` +---------+ (2)
|
||||
|
||||
|
||||
NOTE: Throughout this implementation, we assume that boxes
|
||||
are defined by their 8 corners exactly in the order specified in the
|
||||
diagram above for the function to give correct results. In addition
|
||||
the vertices on each plane must be coplanar.
|
||||
As an alternative to the diagram, this is a unit bounding
|
||||
box which has the correct vertex ordering:
|
||||
|
||||
box_corner_vertices = [
|
||||
[0, 0, 0],
|
||||
[1, 0, 0],
|
||||
[1, 1, 0],
|
||||
[0, 1, 0],
|
||||
[0, 0, 1],
|
||||
[1, 0, 1],
|
||||
[1, 1, 1],
|
||||
[0, 1, 1],
|
||||
]
|
||||
|
||||
Args:
|
||||
boxes1: tensor of shape (N, 8, 3) of the coordinates of the 1st boxes
|
||||
boxes2: tensor of shape (M, 8, 3) of the coordinates of the 2nd boxes
|
||||
@ -58,6 +136,9 @@ def box3d_overlap(
|
||||
if not all((8, 3) == box.shape[1:] for box in [boxes1, boxes2]):
|
||||
raise ValueError("Each box in the batch must be of shape (8, 3)")
|
||||
|
||||
_check_coplanar(boxes1)
|
||||
_check_coplanar(boxes2)
|
||||
|
||||
# pyre-fixme[16]: `_box3d_overlap` has no attribute `apply`.
|
||||
vol, iou = _box3d_overlap.apply(boxes1, boxes2)
|
||||
|
||||
|
@ -11,25 +11,42 @@ from test_iou_box3d import TestIoU3D
|
||||
|
||||
|
||||
def bm_iou_box3d() -> None:
|
||||
N = [1, 4, 8, 16]
|
||||
num_samples = [2000, 5000, 10000, 20000]
|
||||
# Realistic use cases
|
||||
N = [30, 100]
|
||||
M = [5, 10, 100]
|
||||
kwargs_list = []
|
||||
test_cases = product(N, M)
|
||||
for case in test_cases:
|
||||
n, m = case
|
||||
kwargs_list.append({"N": n, "M": m, "device": "cuda:0"})
|
||||
benchmark(TestIoU3D.iou, "3D_IOU", kwargs_list, warmup_iters=1)
|
||||
|
||||
# Comparison of C++/CUDA
|
||||
kwargs_list = []
|
||||
N = [1, 4, 8, 16]
|
||||
devices = ["cpu", "cuda:0"]
|
||||
test_cases = product(N, N, devices)
|
||||
for case in test_cases:
|
||||
n, m, d = case
|
||||
kwargs_list.append({"N": n, "M": m, "device": d})
|
||||
benchmark(TestIoU3D.iou, "3D_IOU", kwargs_list, warmup_iters=1)
|
||||
|
||||
# Naive PyTorch
|
||||
N = [1, 4]
|
||||
kwargs_list = []
|
||||
test_cases = product(N, N)
|
||||
for case in test_cases:
|
||||
n, m = case
|
||||
kwargs_list.append({"N": n, "M": m, "device": "cuda:0"})
|
||||
|
||||
benchmark(TestIoU3D.iou_naive, "3D_IOU_NAIVE", kwargs_list, warmup_iters=1)
|
||||
|
||||
[k.update({"device": "cpu"}) for k in kwargs_list]
|
||||
benchmark(TestIoU3D.iou, "3D_IOU", kwargs_list, warmup_iters=1)
|
||||
|
||||
# Sampling based method
|
||||
num_samples = [2000, 5000]
|
||||
kwargs_list = []
|
||||
test_cases = product([1, 4], [1, 4], num_samples)
|
||||
test_cases = product(N, N, num_samples)
|
||||
for case in test_cases:
|
||||
n, m, s = case
|
||||
kwargs_list.append({"N": n, "M": m, "num_samples": s})
|
||||
kwargs_list.append({"N": n, "M": m, "num_samples": s, "device": "cuda:0"})
|
||||
benchmark(TestIoU3D.iou_sampling, "3D_IOU_SAMPLING", kwargs_list, warmup_iters=1)
|
||||
|
||||
|
||||
|
BIN
tests/data/objectron_vols_ious.pt
Normal file
BIN
tests/data/objectron_vols_ious.pt
Normal file
Binary file not shown.
@ -10,13 +10,28 @@ from typing import List, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from common_testing import TestCaseMixin
|
||||
from common_testing import TestCaseMixin, get_random_cuda_device, get_tests_dir
|
||||
from pytorch3d.io import save_obj
|
||||
|
||||
from pytorch3d.ops.iou_box3d import box3d_overlap
|
||||
from pytorch3d.ops.iou_box3d import _box_planes, _box_triangles, box3d_overlap
|
||||
from pytorch3d.transforms.rotation_conversions import random_rotation
|
||||
|
||||
|
||||
OBJECTRON_TO_PYTORCH3D_FACE_IDX = [0, 4, 6, 2, 1, 5, 7, 3]
|
||||
DATA_DIR = get_tests_dir() / "data"
|
||||
DEBUG = False
|
||||
|
||||
UNIT_BOX = [
|
||||
[0, 0, 0],
|
||||
[1, 0, 0],
|
||||
[1, 1, 0],
|
||||
[0, 1, 0],
|
||||
[0, 0, 1],
|
||||
[1, 0, 1],
|
||||
[1, 1, 1],
|
||||
[0, 1, 1],
|
||||
]
|
||||
|
||||
|
||||
class TestIoU3D(TestCaseMixin, unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
super().setUp()
|
||||
@ -78,16 +93,7 @@ class TestIoU3D(TestCaseMixin, unittest.TestCase):
|
||||
def _test_iou(self, overlap_fn, device):
|
||||
|
||||
box1 = torch.tensor(
|
||||
[
|
||||
[0, 0, 0],
|
||||
[1, 0, 0],
|
||||
[1, 1, 0],
|
||||
[0, 1, 0],
|
||||
[0, 0, 1],
|
||||
[1, 0, 1],
|
||||
[1, 1, 1],
|
||||
[0, 1, 1],
|
||||
],
|
||||
UNIT_BOX,
|
||||
dtype=torch.float32,
|
||||
device=device,
|
||||
)
|
||||
@ -126,6 +132,10 @@ class TestIoU3D(TestCaseMixin, unittest.TestCase):
|
||||
),
|
||||
)
|
||||
|
||||
# Also check IoU is 1 when computing overlap with the same shifted box
|
||||
vol, iou = overlap_fn(box2[None], box2[None])
|
||||
self.assertClose(iou, torch.tensor([[1.0]], device=vol.device, dtype=vol.dtype))
|
||||
|
||||
# 5th test
|
||||
ddx, ddy, ddz = random.random(), random.random(), random.random()
|
||||
box2 = box1 + torch.tensor([[ddx, ddy, ddz]], device=device)
|
||||
@ -207,15 +217,15 @@ class TestIoU3D(TestCaseMixin, unittest.TestCase):
|
||||
# create box1
|
||||
ctrs = torch.rand((2, 3), device=device)
|
||||
whl = torch.rand((2, 3), device=device) * 10.0 + 1.0
|
||||
# box1 & box2
|
||||
box1 = self.create_box(ctrs[0], whl[0])
|
||||
box2 = self.create_box(ctrs[1], whl[1])
|
||||
# box8a & box8b
|
||||
box8a = self.create_box(ctrs[0], whl[0])
|
||||
box8b = self.create_box(ctrs[1], whl[1])
|
||||
RR1 = random_rotation(dtype=torch.float32, device=device)
|
||||
TT1 = torch.rand((1, 3), dtype=torch.float32, device=device)
|
||||
RR2 = random_rotation(dtype=torch.float32, device=device)
|
||||
TT2 = torch.rand((1, 3), dtype=torch.float32, device=device)
|
||||
box1r = box1 @ RR1.transpose(0, 1) + TT1
|
||||
box2r = box2 @ RR2.transpose(0, 1) + TT2
|
||||
box1r = box8a @ RR1.transpose(0, 1) + TT1
|
||||
box2r = box8b @ RR2.transpose(0, 1) + TT2
|
||||
vol, iou = overlap_fn(box1r[None], box2r[None])
|
||||
iou_sampling = self._box3d_overlap_sampling_batched(
|
||||
box1r[None], box2r[None], num_samples=10000
|
||||
@ -229,27 +239,90 @@ class TestIoU3D(TestCaseMixin, unittest.TestCase):
|
||||
self.assertClose(vol, torch.tensor([[0.0]], device=vol.device, dtype=vol.dtype))
|
||||
self.assertClose(iou, torch.tensor([[0.0]], device=vol.device, dtype=vol.dtype))
|
||||
|
||||
# 10th test: Non coplanar verts in a plane
|
||||
box10 = box1 + torch.rand((8, 3), dtype=torch.float32, device=device)
|
||||
msg = "Plane vertices are not coplanar"
|
||||
with self.assertRaisesRegex(ValueError, msg):
|
||||
overlap_fn(box10[None], box10[None])
|
||||
|
||||
# 11th test: Skewed bounding boxes but all verts are coplanar
|
||||
box_skew_1 = torch.tensor(
|
||||
[
|
||||
[0, 0, 0],
|
||||
[1, 0, 0],
|
||||
[1, 1, 0],
|
||||
[0, 1, 0],
|
||||
[-2, -2, 2],
|
||||
[2, -2, 2],
|
||||
[2, 2, 2],
|
||||
[-2, 2, 2],
|
||||
],
|
||||
dtype=torch.float32,
|
||||
device=device,
|
||||
)
|
||||
box_skew_2 = torch.tensor(
|
||||
[
|
||||
[2.015995, 0.695233, 2.152806],
|
||||
[2.832533, 0.663448, 1.576389],
|
||||
[2.675445, -0.309592, 1.407520],
|
||||
[1.858907, -0.277806, 1.983936],
|
||||
[-0.413922, 3.161758, 2.044343],
|
||||
[2.852230, 3.034615, -0.261321],
|
||||
[2.223878, -0.857545, -0.936800],
|
||||
[-1.042273, -0.730402, 1.368864],
|
||||
],
|
||||
dtype=torch.float32,
|
||||
device=device,
|
||||
)
|
||||
vol1 = 14.000
|
||||
vol2 = 14.000005
|
||||
vol_inters = 5.431122
|
||||
iou = vol_inters / (vol1 + vol2 - vol_inters)
|
||||
|
||||
vols, ious = overlap_fn(box_skew_1[None], box_skew_2[None])
|
||||
self.assertClose(vols, torch.tensor([[vol_inters]], device=device), atol=1e-1)
|
||||
self.assertClose(ious, torch.tensor([[iou]], device=device), atol=1e-1)
|
||||
|
||||
def test_iou_naive(self):
|
||||
device = torch.device("cuda:0")
|
||||
device = get_random_cuda_device()
|
||||
self._test_iou(self._box3d_overlap_naive_batched, device)
|
||||
self._test_compare_objectron(self._box3d_overlap_naive_batched, device)
|
||||
|
||||
def test_iou_cpu(self):
|
||||
device = torch.device("cpu")
|
||||
self._test_iou(box3d_overlap, device)
|
||||
self._test_compare_objectron(box3d_overlap, device)
|
||||
|
||||
def test_cpu_vs_naive_batched(self):
|
||||
N, M = 3, 6
|
||||
device = "cpu"
|
||||
boxes1 = torch.randn((N, 8, 3), device=device)
|
||||
boxes2 = torch.randn((M, 8, 3), device=device)
|
||||
vol1, iou1 = self._box3d_overlap_naive_batched(boxes1, boxes2)
|
||||
vol2, iou2 = box3d_overlap(boxes1, boxes2)
|
||||
# check shape
|
||||
for val in [vol1, vol2, iou1, iou2]:
|
||||
self.assertClose(val.shape, (N, M))
|
||||
# check values
|
||||
self.assertClose(vol1, vol2)
|
||||
self.assertClose(iou1, iou2)
|
||||
def test_iou_cuda(self):
|
||||
device = torch.device("cuda:0")
|
||||
self._test_iou(box3d_overlap, device)
|
||||
self._test_compare_objectron(box3d_overlap, device)
|
||||
|
||||
def _test_compare_objectron(self, overlap_fn, device):
|
||||
# Load saved objectron data
|
||||
data_filename = "./objectron_vols_ious.pt"
|
||||
objectron_vals = torch.load(DATA_DIR / data_filename)
|
||||
boxes1 = objectron_vals["boxes1"]
|
||||
boxes2 = objectron_vals["boxes2"]
|
||||
vols_objectron = objectron_vals["vols"]
|
||||
ious_objectron = objectron_vals["ious"]
|
||||
|
||||
boxes1 = boxes1.to(device=device, dtype=torch.float32)
|
||||
boxes2 = boxes2.to(device=device, dtype=torch.float32)
|
||||
|
||||
# Convert vertex orderings from Objectron to PyTorch3D convention
|
||||
idx = torch.tensor(
|
||||
OBJECTRON_TO_PYTORCH3D_FACE_IDX, dtype=torch.int64, device=device
|
||||
)
|
||||
boxes1 = boxes1.index_select(index=idx, dim=1)
|
||||
boxes2 = boxes2.index_select(index=idx, dim=1)
|
||||
|
||||
# Run PyTorch3D version
|
||||
vols, ious = overlap_fn(boxes1, boxes2)
|
||||
|
||||
# Check values match
|
||||
self.assertClose(vols_objectron, vols.cpu())
|
||||
self.assertClose(ious_objectron, ious.cpu())
|
||||
|
||||
def test_batched_errors(self):
|
||||
N, M = 5, 10
|
||||
@ -316,16 +389,7 @@ class TestIoU3D(TestCaseMixin, unittest.TestCase):
|
||||
def test_box_planar_dir(self):
|
||||
device = torch.device("cuda:0")
|
||||
box1 = torch.tensor(
|
||||
[
|
||||
[0, 0, 0],
|
||||
[1, 0, 0],
|
||||
[1, 1, 0],
|
||||
[0, 1, 0],
|
||||
[0, 0, 1],
|
||||
[1, 0, 1],
|
||||
[1, 1, 1],
|
||||
[0, 1, 1],
|
||||
],
|
||||
UNIT_BOX,
|
||||
dtype=torch.float32,
|
||||
device=device,
|
||||
)
|
||||
@ -353,8 +417,13 @@ class TestIoU3D(TestCaseMixin, unittest.TestCase):
|
||||
|
||||
@staticmethod
|
||||
def iou_naive(N: int, M: int, device="cpu"):
|
||||
boxes1 = torch.randn((N, 8, 3))
|
||||
boxes2 = torch.randn((M, 8, 3))
|
||||
box = torch.tensor(
|
||||
[UNIT_BOX],
|
||||
dtype=torch.float32,
|
||||
device=device,
|
||||
)
|
||||
boxes1 = box + torch.randn((N, 1, 3), device=device)
|
||||
boxes2 = box + torch.randn((M, 1, 3), device=device)
|
||||
|
||||
def output():
|
||||
vol, iou = TestIoU3D._box3d_overlap_naive_batched(boxes1, boxes2)
|
||||
@ -363,8 +432,13 @@ class TestIoU3D(TestCaseMixin, unittest.TestCase):
|
||||
|
||||
@staticmethod
|
||||
def iou(N: int, M: int, device="cpu"):
|
||||
boxes1 = torch.randn((N, 8, 3), device=device)
|
||||
boxes2 = torch.randn((M, 8, 3), device=device)
|
||||
box = torch.tensor(
|
||||
[UNIT_BOX],
|
||||
dtype=torch.float32,
|
||||
device=device,
|
||||
)
|
||||
boxes1 = box + torch.randn((N, 1, 3), device=device)
|
||||
boxes2 = box + torch.randn((M, 1, 3), device=device)
|
||||
|
||||
def output():
|
||||
vol, iou = box3d_overlap(boxes1, boxes2)
|
||||
@ -372,9 +446,14 @@ class TestIoU3D(TestCaseMixin, unittest.TestCase):
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def iou_sampling(N: int, M: int, num_samples: int):
|
||||
boxes1 = torch.randn((N, 8, 3))
|
||||
boxes2 = torch.randn((M, 8, 3))
|
||||
def iou_sampling(N: int, M: int, num_samples: int, device="cpu"):
|
||||
box = torch.tensor(
|
||||
[UNIT_BOX],
|
||||
dtype=torch.float32,
|
||||
device=device,
|
||||
)
|
||||
boxes1 = box + torch.randn((N, 1, 3), device=device)
|
||||
boxes2 = box + torch.randn((M, 1, 3), device=device)
|
||||
|
||||
def output():
|
||||
_ = TestIoU3D._box3d_overlap_sampling_batched(boxes1, boxes2, num_samples)
|
||||
@ -408,38 +487,6 @@ Note that both implementations currently do not support batching.
|
||||
#
|
||||
# -------------------------------------------------- #
|
||||
|
||||
# -------------------------------------------------- #
|
||||
# CONSTANTS #
|
||||
# -------------------------------------------------- #
|
||||
"""
|
||||
_box_planes and _box_triangles define the 4- and 3-connectivity
|
||||
of the 8 box corners.
|
||||
_box_planes gives the quad faces of the 3D box
|
||||
_box_triangles gives the triangle faces of the 3D box
|
||||
"""
|
||||
_box_planes = [
|
||||
[0, 1, 2, 3],
|
||||
[3, 2, 6, 7],
|
||||
[0, 1, 5, 4],
|
||||
[0, 3, 7, 4],
|
||||
[1, 5, 6, 2],
|
||||
[4, 5, 6, 7],
|
||||
]
|
||||
_box_triangles = [
|
||||
[0, 1, 2],
|
||||
[0, 3, 2],
|
||||
[4, 5, 6],
|
||||
[4, 6, 7],
|
||||
[1, 5, 6],
|
||||
[1, 6, 2],
|
||||
[0, 4, 7],
|
||||
[0, 7, 3],
|
||||
[3, 2, 6],
|
||||
[3, 6, 7],
|
||||
[0, 1, 5],
|
||||
[0, 4, 5],
|
||||
]
|
||||
|
||||
# -------------------------------------------------- #
|
||||
# HELPER FUNCTIONS FOR EXACT SOLUTION #
|
||||
# -------------------------------------------------- #
|
||||
@ -477,7 +524,7 @@ def get_plane_verts(box: torch.Tensor) -> torch.Tensor:
|
||||
return plane_verts
|
||||
|
||||
|
||||
def box_planar_dir(box: torch.Tensor) -> torch.Tensor:
|
||||
def box_planar_dir(box: torch.Tensor, eps=1e-4) -> torch.Tensor:
|
||||
"""
|
||||
Finds the unit vector n which is perpendicular to each plane in the box
|
||||
and points towards the inside of the box.
|
||||
@ -507,6 +554,11 @@ def box_planar_dir(box: torch.Tensor) -> torch.Tensor:
|
||||
e1 = F.normalize(v2 - v0, dim=-1)
|
||||
n = F.normalize(torch.cross(e0, e1, dim=-1), dim=-1)
|
||||
|
||||
# Check all verts are coplanar
|
||||
if not ((v3 - v0).unsqueeze(1).bmm(n.unsqueeze(2)).abs() < eps).all().item():
|
||||
msg = "Plane vertices are not coplanar"
|
||||
raise ValueError(msg)
|
||||
|
||||
# We can write: `ctr = v0 + a * e0 + b * e1 + c * n`, (1).
|
||||
# With <e0, n> = 0 and <e1, n> = 0, where <.,.> refers to the dot product,
|
||||
# since that e0 is orthogonal to n. Same for e1.
|
||||
@ -733,10 +785,10 @@ def clip_tri_by_plane_oneout(
|
||||
device = plane.device
|
||||
# point of intersection between plane and (vin1, vout)
|
||||
pint1, a1 = plane_edge_point_of_intersection(plane, n, vin1, vout)
|
||||
assert a1 >= eps and a1 <= 1.0
|
||||
assert a1 >= eps and a1 <= 1.0, a1
|
||||
# point of intersection between plane and (vin2, vout)
|
||||
pint2, a2 = plane_edge_point_of_intersection(plane, n, vin2, vout)
|
||||
assert a2 >= 0.0 and a2 <= 1.0
|
||||
assert a2 >= 0.0 and a2 <= 1.0, a2
|
||||
|
||||
verts = torch.stack((vin1, pint1, pint2, vin2), dim=0) # 4x3
|
||||
faces = torch.tensor(
|
||||
@ -771,10 +823,10 @@ def clip_tri_by_plane_twoout(
|
||||
device = plane.device
|
||||
# point of intersection between plane and (vin, vout1)
|
||||
pint1, a1 = plane_edge_point_of_intersection(plane, n, vin, vout1)
|
||||
assert a1 >= eps and a1 <= 1.0
|
||||
assert a1 >= eps and a1 <= 1.0, a1
|
||||
# point of intersection between plane and (vin, vout2)
|
||||
pint2, a2 = plane_edge_point_of_intersection(plane, n, vin, vout2)
|
||||
assert a2 >= eps and a2 <= 1.0
|
||||
assert a2 >= eps and a2 <= 1.0, a2
|
||||
|
||||
verts = torch.stack((vin, pint1, pint2), dim=0) # 3x3
|
||||
faces = torch.tensor(
|
||||
@ -945,7 +997,7 @@ def box3d_overlap_naive(box1: torch.Tensor, box2: torch.Tensor):
|
||||
|
||||
iou = vol / (vol1 + vol2 - vol)
|
||||
|
||||
if 0:
|
||||
if DEBUG:
|
||||
# save shapes
|
||||
tri_faces = torch.tensor(_box_triangles, device=device, dtype=torch.int64)
|
||||
save_obj("/tmp/output/shape1.obj", box1, tri_faces)
|
||||
|
Loading…
x
Reference in New Issue
Block a user