diff --git a/pytorch3d/csrc/iou_box3d/iou_box3d.cu b/pytorch3d/csrc/iou_box3d/iou_box3d.cu
new file mode 100644
index 00000000..345b1914
--- /dev/null
+++ b/pytorch3d/csrc/iou_box3d/iou_box3d.cu
@@ -0,0 +1,176 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include "iou_box3d/iou_utils.cuh"
+#include "utils/pytorch3d_cutils.h"
+
+// Parallelize over N*M computations which can each be done
+// independently. Every thread walks the flattened (n, m) index space
+// starting at its global thread id and advancing by the total number of
+// launched threads, so any 1D launch configuration covers all pairs.
+//
+// Args
+//    boxes1: (N, 8, 3) accessor of the corner vertices of the first boxes
+//    boxes2: (M, 8, 3) accessor of the corner vertices of the second boxes
+//    vols: (N, M) accessor, receives the intersection volumes
+//    ious: (N, M) accessor, receives the intersection over union values
+//
+__global__ void IoUBox3DKernel(
+    const at::PackedTensorAccessor64<float, 3, at::RestrictPtrTraits> boxes1,
+    const at::PackedTensorAccessor64<float, 3, at::RestrictPtrTraits> boxes2,
+    at::PackedTensorAccessor64<float, 2, at::RestrictPtrTraits> vols,
+    at::PackedTensorAccessor64<float, 2, at::RestrictPtrTraits> ious) {
+  const size_t N = boxes1.size(0);
+  const size_t M = boxes2.size(0);
+
+  // Widen before multiplying so the index math is done in size_t
+  const size_t tid =
+      static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
+  const size_t stride = static_cast<size_t>(gridDim.x) * blockDim.x;
+
+  for (size_t i = tid; i < N * M; i += stride) {
+    const size_t n = i / M; // box1 index
+    const size_t m = i % M; // box2 index
+
+    // Convert to array of structs of face vertices i.e. effectively (F, 3, 3)
+    // FaceVerts is a data type defined in iou_utils.cuh
+    FaceVerts box1_tris[NUM_TRIS];
+    FaceVerts box2_tris[NUM_TRIS];
+    GetBoxTris(boxes1[n], box1_tris);
+    GetBoxTris(boxes2[m], box2_tris);
+
+    // Calculate the position of the center of the box which is used in
+    // several calculations. This requires a tensor as input.
+    const float3 box1_center = BoxCenter(boxes1[n]);
+    const float3 box2_center = BoxCenter(boxes2[m]);
+
+    // Convert to an array of face vertices
+    FaceVerts box1_planes[NUM_PLANES];
+    GetBoxPlanes(boxes1[n], box1_planes);
+    FaceVerts box2_planes[NUM_PLANES];
+    GetBoxPlanes(boxes2[m], box2_planes);
+
+    // Get Box Volumes
+    const float box1_vol = BoxVolume(box1_tris, box1_center, NUM_TRIS);
+    const float box2_vol = BoxVolume(box2_tris, box2_center, NUM_TRIS);
+
+    // Tris in Box1 intersection with Planes in Box2
+    // Initialize box1 intersecting faces. MAX_TRIS is the
+    // max faces possible in the intersecting shape.
+    // TODO: determine if the value of MAX_TRIS is sufficient or
+    // if we should store the max tris for each NxM computation
+    // and throw an error if any exceeds the max.
+    FaceVerts box1_intersect[MAX_TRIS];
+    for (int j = 0; j < NUM_TRIS; ++j) {
+      // Initialize the faces from the box
+      box1_intersect[j] = box1_tris[j];
+    }
+    // Get the count of the actual number of faces in the intersecting shape
+    int box1_count = BoxIntersections(box2_planes, box2_center, box1_intersect);
+
+    // Tris in Box2 intersection with Planes in Box1
+    FaceVerts box2_intersect[MAX_TRIS];
+    for (int j = 0; j < NUM_TRIS; ++j) {
+      box2_intersect[j] = box2_tris[j];
+    }
+    const int box2_count =
+        BoxIntersections(box1_planes, box1_center, box2_intersect);
+
+    // If there are overlapping regions in Box2, remove any coplanar faces
+    if (box2_count > 0) {
+      // Identify if any triangles in Box2 are coplanar with Box1
+      Keep tri2_keep[MAX_TRIS];
+      for (int j = 0; j < MAX_TRIS; ++j) {
+        // Only the first box2_count entries describe valid faces
+        tri2_keep[j].keep = j < box2_count;
+      }
+      for (int b1 = 0; b1 < box1_count; ++b1) {
+        for (int b2 = 0; b2 < box2_count; ++b2) {
+          const bool is_coplanar =
+              IsCoplanarFace(box1_intersect[b1], box2_intersect[b2]);
+          if (is_coplanar) {
+            tri2_keep[b2].keep = false;
+          }
+        }
+      }
+
+      // Keep only the non coplanar triangles in Box2 - add them to the
+      // Box1 triangles.
+      for (int b2 = 0; b2 < box2_count; ++b2) {
+        // box1_intersect only has room for MAX_TRIS faces, so stop
+        // appending once it is full instead of writing out of bounds
+        if (tri2_keep[b2].keep && box1_count < MAX_TRIS) {
+          box1_intersect[box1_count] = box2_intersect[b2];
+          // box1_count will determine the total faces in the
+          // intersecting shape
+          box1_count++;
+        }
+      }
+    }
+
+    // Initialize the vol and iou to 0.0 in case there are no triangles
+    // in the intersecting shape.
+    float vol = 0.0f;
+    float iou = 0.0f;
+
+    // If there are triangles in the intersecting shape
+    if (box1_count > 0) {
+      // The intersecting shape is a polyhedron made up of the
+      // triangular faces that are all now in box1_intersect.
+      // Calculate the polyhedron center
+      const float3 poly_center = PolyhedronCenter(box1_intersect, box1_count);
+      // Compute intersecting polyhedron volume
+      vol = BoxVolume(box1_intersect, poly_center, box1_count);
+      // Compute IoU. Two zero-volume boxes give a zero union, report an
+      // IoU of 0 for them rather than dividing by zero
+      const float union_vol = box1_vol + box2_vol - vol;
+      iou = union_vol > 0.0f ? vol / union_vol : 0.0f;
+    }
+
+    // Write the volume and IoU to global memory
+    vols[n][m] = vol;
+    ious[n][m] = iou;
+  }
+}
+
+// Host entry point: validates the inputs, allocates the (N, M) outputs on
+// the device of boxes1 and launches IoUBox3DKernel on the current stream.
+//
+// Args
+//    boxes1: (N, 8, 3) float tensor of box corner vertices
+//    boxes2: (M, 8, 3) float tensor of box corner vertices
+//
+// Returns
+//    tuple (vols, ious), each an (N, M) tensor
+//
+std::tuple<at::Tensor, at::Tensor> IoUBox3DCuda(
+    const at::Tensor& boxes1, // (N, 8, 3)
+    const at::Tensor& boxes2) { // (M, 8, 3)
+  // Check inputs are on the same device
+  at::TensorArg boxes1_t{boxes1, "boxes1", 1}, boxes2_t{boxes2, "boxes2", 2};
+  at::CheckedFrom c = "IoUBox3DCuda";
+  at::checkAllSameGPU(c, {boxes1_t, boxes2_t});
+  at::checkAllSameType(c, {boxes1_t, boxes2_t});
+
+  // Set the device for the kernel launch based on the device of boxes1
+  at::cuda::CUDAGuard device_guard(boxes1.device());
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+  // The kernel indexes [box][8 verts][3 coords]; reject anything else
+  TORCH_CHECK(
+      (boxes1.dim() == 3) && (boxes2.dim() == 3),
+      "Boxes must have shape (8, 3)");
+
+  TORCH_CHECK(
+      (boxes2.size(2) == 3) && (boxes1.size(2) == 3),
+      "Boxes must have shape (8, 3)");
+
+  TORCH_CHECK(
+      (boxes2.size(1) == 8) && (boxes1.size(1) == 8),
+      "Boxes must have shape (8, 3)");
+
+  const int64_t N = boxes1.size(0);
+  const int64_t M = boxes2.size(0);
+
+  auto vols = at::zeros({N, M}, boxes1.options());
+  auto ious = at::zeros({N, M}, boxes1.options());
+
+  // Nothing to compute for an empty batch
+  if (vols.numel() == 0) {
+    AT_CUDA_CHECK(cudaGetLastError());
+    return std::make_tuple(vols, ious);
+  }
+
+  const size_t blocks = 512;
+  const size_t threads = 256;
+
+  IoUBox3DKernel<<<blocks, threads, 0, stream>>>(
+      boxes1.packed_accessor64<float, 3, at::RestrictPtrTraits>(),
+      boxes2.packed_accessor64<float, 3, at::RestrictPtrTraits>(),
+      vols.packed_accessor64<float, 2, at::RestrictPtrTraits>(),
+      ious.packed_accessor64<float, 2, at::RestrictPtrTraits>());
+
+  AT_CUDA_CHECK(cudaGetLastError());
+
+  return std::make_tuple(vols, ious);
+}
diff --git a/pytorch3d/csrc/iou_box3d/iou_box3d.h b/pytorch3d/csrc/iou_box3d/iou_box3d.h
index 50703c6e..1fedfd10 100644
--- a/pytorch3d/csrc/iou_box3d/iou_box3d.h
+++ b/pytorch3d/csrc/iou_box3d/iou_box3d.h
@@ -26,12 +26,23 @@ std::tuple<at::Tensor, at::Tensor> IoUBox3DCpu(
     const at::Tensor& boxes1,
     const at::Tensor& boxes2);
 
+// CUDA implementation
+std::tuple<at::Tensor, at::Tensor> IoUBox3DCuda(
+    const at::Tensor& boxes1,
+    const at::Tensor& boxes2);
+
 // Implementation which is exposed
 inline std::tuple<at::Tensor, at::Tensor> IoUBox3D(
     const at::Tensor& boxes1,
     const at::Tensor& boxes2) {
   if (boxes1.is_cuda() || boxes2.is_cuda()) {
-    AT_ERROR("GPU support not implemented");
+#ifdef WITH_CUDA
+    CHECK_CUDA(boxes1);
+    CHECK_CUDA(boxes2);
+    return IoUBox3DCuda(boxes1.contiguous(), boxes2.contiguous());
+#else
+    AT_ERROR("Not compiled with GPU support.");
+#endif
   }
   return IoUBox3DCpu(boxes1.contiguous(), boxes2.contiguous());
 }
diff --git a/pytorch3d/csrc/iou_box3d/iou_box3d_cpu.cpp b/pytorch3d/csrc/iou_box3d/iou_box3d_cpu.cpp
index a9d60a19..754e7770 100644
--- a/pytorch3d/csrc/iou_box3d/iou_box3d_cpu.cpp
+++ b/pytorch3d/csrc/iou_box3d/iou_box3d_cpu.cpp
@@ -79,7 +79,7 @@ std::tuple<at::Tensor, at::Tensor> IoUBox3DCpu(
         std::fill(tri2_keep.begin(), tri2_keep.end(), 1);
         for (int b1 = 0; b1 < box1_intersect.size(); ++b1) {
           for (int b2 = 0; b2 < box2_intersect.size(); ++b2) {
-            bool is_coplanar =
+            const bool is_coplanar =
                 IsCoplanarFace(box1_intersect[b1], box2_intersect[b2]);
             if (is_coplanar) {
               tri2_keep[b2] = 0;
diff --git a/pytorch3d/csrc/iou_box3d/iou_utils.cuh b/pytorch3d/csrc/iou_box3d/iou_utils.cuh
new file mode 100644
index 00000000..452cf7dc
--- /dev/null
+++ b/pytorch3d/csrc/iou_box3d/iou_utils.cuh
@@ -0,0 +1,584 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include
+#include
+#include
+#include
+#include "utils/float_math.cuh"
+#include "utils/geometry_utils.cuh"
+
+/*
+_PLANES and _TRIS define the 4- and 3-connectivity
+of the 8 box corners.
+_PLANES gives the quad faces of the 3D box
+_TRIS gives the triangle faces of the 3D box
+*/
+const int NUM_PLANES = 6;
+const int NUM_TRIS = 12;
+// This is required for initializing the faces
+// in the intersecting shape
+const int MAX_TRIS = 100;
+
+// Create data types for representing the
+// verts for each face and the indices.
+// We will use struct arrays for representing
+// the data for each box and intersecting
+// triangles
+typedef struct {
+  float3 v0;
+  float3 v1;
+  float3 v2;
+  float3 v3; // Can be empty for triangles
+} FaceVerts;
+
+typedef struct {
+  int v0;
+  int v1;
+  int v2;
+  int v3; // Can be empty for triangles
+} FaceVertsIdx;
+
+// This is used when deciding which faces to
+// keep that are not coplanar
+typedef struct {
+  bool keep;
+} Keep;
+
+// Corner indices of the 6 quad faces of a box
+__device__ FaceVertsIdx _PLANES[] = {
+    {0, 1, 2, 3},
+    {3, 2, 6, 7},
+    {0, 1, 5, 4},
+    {0, 3, 7, 4},
+    {1, 5, 6, 2},
+    {4, 5, 6, 7},
+};
+// Corner indices of the 12 triangle faces of a box (v3 is unused)
+__device__ FaceVertsIdx _TRIS[] = {
+    {0, 1, 2},
+    {0, 3, 2},
+    {4, 5, 6},
+    {4, 6, 7},
+    {1, 5, 6},
+    {1, 6, 2},
+    {0, 4, 7},
+    {0, 7, 3},
+    {3, 2, 6},
+    {3, 6, 7},
+    {0, 1, 5},
+    {0, 4, 5},
+};
+
+// Gather the vertex coordinates of the 12 triangle faces of a box.
+//
+// Args
+//    box: (8, 3) tensor accessor for the box vertices
+//    box_tris: Array of structs of type FaceVerts,
+//        effectively (F, 3, 3) where the coordinates of the
+//        verts for each face will be saved to.
+//
+// Returns: None (output saved to box_tris)
+//
+template <typename Box, typename BoxTris>
+__device__ inline void GetBoxTris(const Box& box, BoxTris& box_tris) {
+  for (int t = 0; t < NUM_TRIS; ++t) {
+    const float3 v0 = make_float3(
+        box[_TRIS[t].v0][0], box[_TRIS[t].v0][1], box[_TRIS[t].v0][2]);
+    const float3 v1 = make_float3(
+        box[_TRIS[t].v1][0], box[_TRIS[t].v1][1], box[_TRIS[t].v1][2]);
+    const float3 v2 = make_float3(
+        box[_TRIS[t].v2][0], box[_TRIS[t].v2][1], box[_TRIS[t].v2][2]);
+    box_tris[t] = {v0, v1, v2};
+  }
+}
+
+// Gather the vertex coordinates of the 6 quad faces of a box.
+//
+// Args
+//    box: (8, 3) tensor accessor for the box vertices
+//    box_planes: Array of structs of type FaceVerts, effectively (P, 4, 3)
+//        where the coordinates of the verts for the four corners of each plane
+//        will be saved to
+//
+// Returns: None (output saved to box_planes)
+//
+template <typename Box, typename FaceVertsBoxPlanes>
+__device__ inline void GetBoxPlanes(
+    const Box& box,
+    FaceVertsBoxPlanes& box_planes) {
+  for (int t = 0; t < NUM_PLANES; ++t) {
+    const float3 v0 = make_float3(
+        box[_PLANES[t].v0][0], box[_PLANES[t].v0][1], box[_PLANES[t].v0][2]);
+    const float3 v1 = make_float3(
+        box[_PLANES[t].v1][0], box[_PLANES[t].v1][1], box[_PLANES[t].v1][2]);
+    const float3 v2 = make_float3(
+        box[_PLANES[t].v2][0], box[_PLANES[t].v2][1], box[_PLANES[t].v2][2]);
+    const float3 v3 = make_float3(
+        box[_PLANES[t].v3][0], box[_PLANES[t].v3][1], box[_PLANES[t].v3][2]);
+    box_planes[t] = {v0, v1, v2, v3};
+  }
+}
+
+// The normal of the face defined by vertices (v0, v1, v2)
+// Define e0 to be the edge connecting (v1, v0)
+// Define e1 to be the edge connecting (v2, v0)
+// normal is the cross product of e0, e1
+//
+// Args
+//    v0, v1, v2: float3 coordinates of the vertices of the face
+//
+// Returns
+//    float3: normal for the face
+//
+__device__ inline float3
+FaceNormal(const float3 v0, const float3 v1, const float3 v2) {
+  float3 n = cross(v1 - v0, v2 - v0);
+  // Clamp the length so a degenerate face does not divide by zero
+  n = n / fmaxf(norm(n), kEpsilon);
+  return n;
+}
+
+// The normal of a box plane defined by the verts in `plane` with
+// the centroid of the box given by `center`.
+// Args
+//    plane: float3 coordinates of the vertices of the plane
+//    center: float3 coordinates of the center of the box from
+//        which the plane originated
+//
+// Returns
+//    float3: normal for the plane such that it points towards
+//        the center of the box
+//
+template <typename FaceVertsPlane>
+__device__ inline float3 PlaneNormalDirection(
+    const FaceVertsPlane& plane,
+    const float3& center) {
+  // Only need the first 3 verts of the plane
+  const float3 v0 = plane.v0;
+  const float3 v1 = plane.v1;
+  const float3 v2 = plane.v2;
+
+  // We project the center on the plane defined by (v0, v1, v2)
+  // We can write center = v0 + a * e0 + b * e1 + c * n
+  // We know that <e0, n> = 0 and <e1, n> = 0 and
+  // <a, b> is the dot product between a and b.
+  // This means we can solve for c as:
+  // c = <center - v0 - a * e0 - b * e1, n>
+  //   = <center - v0, n>
+  float3 n = FaceNormal(v0, v1, v2);
+  const float c = dot((center - v0), n);
+
+  // If c is negative, then we revert the direction of n such that n
+  // points "inside"
+  if (c < kEpsilon) {
+    n = -1.0f * n;
+  }
+
+  return n;
+}
+
+// Calculate the volume of the box by summing the volume of
+// each of the tetrahedrons formed with a triangle face and
+// the box centroid.
+//
+// Args
+//    box_tris: vector of float3 coordinates of the vertices of each
+//        of the triangles in the box
+//    box_center: float3 coordinates of the center of the box
+//    num_tris: number of valid triangles at the start of box_tris
+//
+// Returns
+//    float: volume of the box
+//
+template <typename BoxTris>
+__device__ inline float BoxVolume(
+    const BoxTris& box_tris,
+    const float3& box_center,
+    const int num_tris) {
+  float box_vol = 0.0f;
+  // Iterate through each triangle, calculate the volume of the
+  // tetrahedron formed with the box_center and sum them
+  for (int t = 0; t < num_tris; ++t) {
+    // Subtract the center:
+    float3 v0 = box_tris[t].v0;
+    float3 v1 = box_tris[t].v1;
+    float3 v2 = box_tris[t].v2;
+
+    v0 = v0 - box_center;
+    v1 = v1 - box_center;
+    v2 = v2 - box_center;
+
+    // Scalar triple product: six times the signed tetrahedron volume.
+    // Float literal and fabsf keep the arithmetic in single precision.
+    const float area = dot(v0, cross(v1, v2));
+    const float vol = fabsf(area) / 6.0f;
+    box_vol = box_vol + vol;
+  }
+  return box_vol;
+}
+
+// Compute the box center as the mean of the verts
+//
+// Args
+//    box_verts: (8, 3) tensor of the corner vertices of the box
+//
+// Returns
+//    float3: coordinates of the center of the box
+//
+template <typename Box>
+__device__ inline float3 BoxCenter(const Box box_verts) {
+  float x = 0.0f;
+  float y = 0.0f;
+  float z = 0.0f;
+  const int num_verts = box_verts.size(0); // Should be 8
+  // Sum all x, y, z, and take the mean
+  for (int t = 0; t < num_verts; ++t) {
+    x = x + box_verts[t][0];
+    y = y + box_verts[t][1];
+    z = z + box_verts[t][2];
+  }
+  // Take the mean of all the vertex positions
+  x = x / num_verts;
+  y = y / num_verts;
+  z = z / num_verts;
+  const float3 center = make_float3(x, y, z);
+  return center;
+}
+
+// Compute the polyhedron
center as the mean of the face centers
+// of the triangle faces
+//
+// Args
+//    tris: vector of float3 coordinates of the
+//        vertices of each of the triangles in the polyhedron
+//    num_tris: number of valid triangles at the start of tris,
+//        the result is a division by this count so it must be > 0
+//
+// Returns
+//    float3: coordinates of the center of the polyhedron
+//
+template <typename Tris>
+__device__ inline float3 PolyhedronCenter(
+    const Tris& tris,
+    const int num_tris) {
+  float x = 0.0f;
+  float y = 0.0f;
+  float z = 0.0f;
+
+  // Find the center point of each face
+  for (int t = 0; t < num_tris; ++t) {
+    const float3 v0 = tris[t].v0;
+    const float3 v1 = tris[t].v1;
+    const float3 v2 = tris[t].v2;
+    // Float literals keep the division in single precision
+    const float x_face = (v0.x + v1.x + v2.x) / 3.0f;
+    const float y_face = (v0.y + v1.y + v2.y) / 3.0f;
+    const float z_face = (v0.z + v1.z + v2.z) / 3.0f;
+    x = x + x_face;
+    y = y + y_face;
+    z = z + z_face;
+  }
+
+  // Take the mean of the centers of all faces
+  x = x / num_tris;
+  y = y / num_tris;
+  z = z / num_tris;
+
+  const float3 center = make_float3(x, y, z);
+  return center;
+}
+
+// Compute a boolean indicator for whether a point
+// is inside a plane, where inside refers to whether
+// or not the point has a component in the
+// normal direction of the plane.
+//
+// Args
+//    plane: vector of float3 coordinates of the
+//        vertices of each of the triangles in the box
+//    normal: float3 of the direction of the plane normal
+//    point: float3 of the position of the point of interest
+//
+// Returns
+//    bool: whether or not the point is inside the plane
+//
+__device__ inline bool
+IsInside(const FaceVerts& plane, const float3& normal, const float3& point) {
+  // Get one vert of the plane
+  const float3 v0 = plane.v0;
+
+  // Every point p can be written as p = v0 + a e0 + b e1 + c n
+  // Solving for c:
+  // c = (point - v0 - a * e0 - b * e1).dot(n)
+  // We know that <e0, n> = 0 and <e1, n> = 0
+  // So the calculation can be simplified as:
+  const float c = dot((point - v0), normal);
+  // Points within kEpsilon of the plane count as inside
+  const bool inside = c > -1.0f * kEpsilon;
+  return inside;
+}
+
+// Find the point of intersection between a plane
+// and a line given by the end points (p0, p1)
+//
+// Args
+//    plane: vector of float3 coordinates of the
+//        vertices of each of the triangles in the box
+//    normal: float3 of the direction of the plane normal
+//    p0, p1: float3 of the start and end point of the line
+//
+// Returns
+//    float3: position of the intersection point
+//
+__device__ inline float3 PlaneEdgeIntersection(
+    const FaceVerts& plane,
+    const float3& normal,
+    const float3& p0,
+    const float3& p1) {
+  // Get one vert of the plane
+  const float3 v0 = plane.v0;
+
+  // The point of intersection can be parametrized
+  // p = p0 + a (p1 - p0) where a in [0, 1]
+  // We want to find a such that p is on plane
+  // <p - v0, n> = 0
+  const float top = dot(-1.0f * (p0 - v0), normal);
+  const float bot = dot(p1 - p0, normal);
+  // bot is zero when the edge is parallel to the plane (or rounds to it);
+  // fall back to p0 rather than producing inf / NaN coordinates
+  const float a = (bot != 0.0f) ? top / bot : 0.0f;
+  const float3 p = p0 + a * (p1 - p0);
+  return p;
+}
+
+// Triangle is clipped into a quadrilateral
+// based on the intersection points with the plane.
+// Then the quadrilateral is divided into two triangles.
+//
+// Args
+//    plane: vector of float3 coordinates of the
+//        vertices of each of the triangles in the box
+//    normal: float3 of the direction of the plane normal
+//    vout: float3 of the point in the triangle which is outside
+//        the plane
+//    vin1, vin2: float3 of the points in the triangle which are
+//        inside the plane
+//    face_verts_out: Array of structs of type FaceVerts,
+//        with the coordinates of the new triangle faces
+//        formed after clipping.
+//        All triangles are now "inside" the plane.
+//
+// Returns:
+//    count: (int) number of new faces after clipping the triangle
+//        i.e. the valid faces which have been saved
+//        to face_verts_out
+//
+template <typename FaceVertsBox>
+__device__ inline int ClipTriByPlaneOneOut(
+    const FaceVerts& plane,
+    const float3& normal,
+    const float3& vout,
+    const float3& vin1,
+    const float3& vin2,
+    FaceVertsBox& face_verts_out) {
+  // point of intersection between plane and (vin1, vout)
+  const float3 pint1 = PlaneEdgeIntersection(plane, normal, vin1, vout);
+  // point of intersection between plane and (vin2, vout)
+  const float3 pint2 = PlaneEdgeIntersection(plane, normal, vin2, vout);
+
+  // Split the quad (vin1, pint1, pint2, vin2) along the vin1-pint2 diagonal
+  face_verts_out[0] = {vin1, pint1, pint2};
+  face_verts_out[1] = {vin1, pint2, vin2};
+
+  return 2;
+}
+
+// Triangle is clipped into a smaller triangle based
+// on the intersection points with the plane.
+//
+// Args
+//    plane: vector of float3 coordinates of the
+//        vertices of each of the triangles in the box
+//    normal: float3 of the direction of the plane normal
+//    vout1, vout2: float3 of the points in the triangle which are
+//        outside the plane
+//    vin: float3 of the point in the triangle which is inside
+//        the plane
+//    face_verts_out: Array of structs of type FaceVerts,
+//        with the coordinates of the new triangle faces
+//        formed after clipping.
+//        All triangles are now "inside" the plane.
+//
+// Returns:
+//    count: (int) number of new faces after clipping the triangle
+//        i.e. the valid faces which have been saved
+//        to face_verts_out
+//
+template <typename FaceVertsBox>
+__device__ inline int ClipTriByPlaneTwoOut(
+    const FaceVerts& plane,
+    const float3& normal,
+    const float3& vout1,
+    const float3& vout2,
+    const float3& vin,
+    FaceVertsBox& face_verts_out) {
+  // point of intersection between plane and (vin, vout1)
+  const float3 pint1 = PlaneEdgeIntersection(plane, normal, vin, vout1);
+  // point of intersection between plane and (vin, vout2)
+  const float3 pint2 = PlaneEdgeIntersection(plane, normal, vin, vout2);
+
+  // Only the corner around vin survives the clip
+  face_verts_out[0] = {vin, pint1, pint2};
+
+  return 1;
+}
+
+// Clip the triangle faces so that they lie within the
+// plane, creating new triangle faces where necessary.
+//
+// Args
+//    plane: Array of structs of type FaceVerts with the coordinates
+//        of the vertices of each of the triangles in the box
+//    tri: Array of structs of type FaceVerts with the vertex
+//        coordinates of the triangle faces
+//    normal: float3 of the direction of the plane normal
+//    face_verts_out: Array of structs of type FaceVerts,
+//        with the coordinates of the new triangle faces
+//        formed after clipping.
+//        All triangles are now "inside" the plane.
+//
+// Returns:
+//    count: (int) number of new faces after clipping the triangle
+//        i.e.
the valid faces which have been saved
+//        to face_verts_out
+//
+template <typename FaceVertsBox>
+__device__ inline int ClipTriByPlane(
+    const FaceVerts& plane,
+    const FaceVerts& tri,
+    const float3& normal,
+    FaceVertsBox& face_verts_out) {
+  // Get Triangle vertices
+  const float3 v0 = tri.v0;
+  const float3 v1 = tri.v1;
+  const float3 v2 = tri.v2;
+
+  // Check each of the triangle vertices to see if it is inside the plane
+  const bool isin0 = IsInside(plane, normal, v0);
+  const bool isin1 = IsInside(plane, normal, v1);
+  const bool isin2 = IsInside(plane, normal, v2);
+
+  // All in
+  if (isin0 && isin1 && isin2) {
+    // Return input vertices
+    face_verts_out[0] = {v0, v1, v2};
+    return 1;
+  }
+
+  // All out
+  if (!isin0 && !isin1 && !isin2) {
+    return 0;
+  }
+
+  // One vert out: the outside vertex is passed first, then the two
+  // inside vertices
+  if (isin0 && isin1 && !isin2) {
+    return ClipTriByPlaneOneOut(plane, normal, v2, v0, v1, face_verts_out);
+  }
+  if (isin0 && !isin1 && isin2) {
+    return ClipTriByPlaneOneOut(plane, normal, v1, v0, v2, face_verts_out);
+  }
+  if (!isin0 && isin1 && isin2) {
+    return ClipTriByPlaneOneOut(plane, normal, v0, v1, v2, face_verts_out);
+  }
+
+  // Two verts out: the two outside vertices are passed first, then the
+  // inside vertex
+  if (isin0 && !isin1 && !isin2) {
+    return ClipTriByPlaneTwoOut(plane, normal, v1, v2, v0, face_verts_out);
+  }
+  if (!isin0 && !isin1 && isin2) {
+    return ClipTriByPlaneTwoOut(plane, normal, v0, v1, v2, face_verts_out);
+  }
+  if (!isin0 && isin1 && !isin2) {
+    return ClipTriByPlaneTwoOut(plane, normal, v0, v2, v1, face_verts_out);
+  }
+
+  // Else return empty (should not be reached)
+  return 0;
+}
+
+// Compute a boolean indicator for whether or not two faces
+// are coplanar
+//
+// Args
+//    tri1, tri2: FaceVerts struct of the vertex coordinates of
+//        the triangle face
+//
+// Returns
+//    bool: whether or not the two faces are coplanar
+//
+__device__ inline bool IsCoplanarFace(
+    const FaceVerts& tri1,
+    const FaceVerts& tri2) {
+  // Get verts for face 1
+  const float3 v0 = tri1.v0;
+  const float3 v1 = tri1.v1;
+  const float3 v2 = tri1.v2;
+
+  const float3 n1 = FaceNormal(v0, v1, v2);
+
+  // The faces are coplanar when all three verts of tri2 lie within
+  // kEpsilon of the plane through tri1
+  const bool on_plane0 = fabsf(dot(tri2.v0 - v0, n1)) < kEpsilon;
+  const bool on_plane1 = fabsf(dot(tri2.v1 - v0, n1)) < kEpsilon;
+  const bool on_plane2 = fabsf(dot(tri2.v2 - v0, n1)) < kEpsilon;
+  return on_plane0 && on_plane1 && on_plane2;
+}
+
+// Get the triangles from each box which are part of the
+// intersecting polyhedron by computing the intersection
+// points with each of the planes.
+//
+// Args
+//    planes: Array of structs of type FaceVerts with the coordinates
+//        of the vertices of each of the triangles in the box
+//    center: float3 coordinates of the center of the box from which
+//        the planes originate
+//    face_verts_out: Array of structs of type FaceVerts,
+//        where the coordinates of the new triangle faces
+//        formed after clipping will be saved to.
+//        All triangles are now "inside" the plane.
+//
+// Returns:
+//    count: (int) number of faces in the intersecting shape
+//        i.e.
the valid faces which have been saved
+//        to face_verts_out
+//
+template <typename FaceVertsPlane, typename FaceVertsBox>
+__device__ inline int BoxIntersections(
+    const FaceVertsPlane& planes,
+    const float3& center,
+    FaceVertsBox& face_verts_out) {
+  // Initialize num tris to 12
+  int num_tris = NUM_TRIS;
+  for (int p = 0; p < NUM_PLANES; ++p) {
+    // Get plane normal direction
+    const float3 n2 = PlaneNormalDirection(planes[p], center);
+    // Create intermediate vector to store the updated tris
+    FaceVerts tri_verts_updated[MAX_TRIS];
+    int offset = 0;
+
+    // Iterate through triangles in face_verts_out
+    // for the valid tris given by num_tris
+    for (int t = 0; t < num_tris; ++t) {
+      // Clip tri by plane, can max be split into 2 triangles
+      FaceVerts tri_updated[2];
+      const int count =
+          ClipTriByPlane(planes[p], face_verts_out[t], n2, tri_updated);
+      // Add to the tri_verts_updated output if not empty. Each clip can
+      // double the face count, so stop at MAX_TRIS rather than writing
+      // past the end of the fixed-size buffer
+      for (int v = 0; v < count && offset < MAX_TRIS; ++v) {
+        tri_verts_updated[offset] = tri_updated[v];
+        offset++;
+      }
+    }
+    // Update the face_verts_out tris
+    num_tris = offset;
+    for (int j = 0; j < num_tris; ++j) {
+      face_verts_out[j] = tri_verts_updated[j];
+    }
+  }
+  return num_tris;
+}
diff --git a/pytorch3d/ops/__init__.py b/pytorch3d/ops/__init__.py
index aeba4495..2a9b0282 100644
--- a/pytorch3d/ops/__init__.py
+++ b/pytorch3d/ops/__init__.py
@@ -9,6 +9,7 @@ from .cameras_alignment import corresponding_cameras_alignment
 from .cubify import cubify
 from .graph_conv import GraphConv
 from .interp_face_attrs import interpolate_face_attributes
+from .iou_box3d import box3d_overlap
 from .knn import knn_gather, knn_points
 from .laplacian_matrices import cot_laplacian, laplacian, norm_laplacian
 from .mesh_face_areas_normals import mesh_face_areas_normals
diff --git a/pytorch3d/ops/iou_box3d.py b/pytorch3d/ops/iou_box3d.py
index 2d248ec8..2c1979b8 100644
--- a/pytorch3d/ops/iou_box3d.py
+++ b/pytorch3d/ops/iou_box3d.py
@@ -7,10 +7,68 @@ from typing import Tuple
 import torch
+import torch.nn.functional as F
 from pytorch3d import _C
from torch.autograd import Function +# -------------------------------------------------- # +# CONSTANTS # +# -------------------------------------------------- # +""" +_box_planes and _box_triangles define the 4- and 3-connectivity +of the 8 box corners. +_box_planes gives the quad faces of the 3D box +_box_triangles gives the triangle faces of the 3D box +""" +_box_planes = [ + [0, 1, 2, 3], + [3, 2, 6, 7], + [0, 1, 5, 4], + [0, 3, 7, 4], + [1, 2, 6, 5], + [4, 5, 6, 7], +] +_box_triangles = [ + [0, 1, 2], + [0, 3, 2], + [4, 5, 6], + [4, 6, 7], + [1, 5, 6], + [1, 6, 2], + [0, 4, 7], + [0, 7, 3], + [3, 2, 6], + [3, 6, 7], + [0, 1, 5], + [0, 4, 5], +] + + +def _check_coplanar(boxes: torch.Tensor, eps: float = 1e-5) -> None: + faces = torch.tensor(_box_planes, dtype=torch.int64, device=boxes.device) + # pyre-fixme[16]: `boxes` has no attribute `index_select`. + verts = boxes.index_select(index=faces.view(-1), dim=1) + B = boxes.shape[0] + P, V = faces.shape + # (B, P, 4, 3) -> (B, P, 3) + v0, v1, v2, v3 = verts.reshape(B, P, V, 3).unbind(2) + + # Compute the normal + e0 = F.normalize(v1 - v0, dim=-1) + e1 = F.normalize(v2 - v0, dim=-1) + normal = F.normalize(torch.cross(e0, e1, dim=-1), dim=-1) + + # Check the fourth vertex is also on the same plane + mat1 = (v3 - v0).view(B, 1, -1) # (B, 1, P*3) + mat2 = normal.view(B, -1, 1) # (B, P*3, 1) + if not (mat1.bmm(mat2).abs() < eps).all().item(): + msg = "Plane vertices are not coplanar" + raise ValueError(msg) + + return + + class _box3d_overlap(Function): """ Torch autograd Function wrapper for box3d_overlap C++/CUDA implementations. @@ -35,6 +93,7 @@ def box3d_overlap( ) -> Tuple[torch.Tensor, torch.Tensor]: """ Computes the intersection of 3D boxes1 and boxes2. + Inputs boxes1, boxes2 are tensors of shape (B, 8, 3) (where B doesn't have to be the same for boxes1 and boxes1), containing the 8 corners of the boxes, as follows: @@ -47,6 +106,25 @@ def box3d_overlap( ` . | ` . 
| (3) ` +---------+ (2) + + NOTE: Throughout this implementation, we assume that boxes + are defined by their 8 corners exactly in the order specified in the + diagram above for the function to give correct results. In addition + the vertices on each plane must be coplanar. + As an alternative to the diagram, this is a unit bounding + box which has the correct vertex ordering: + + box_corner_vertices = [ + [0, 0, 0], + [1, 0, 0], + [1, 1, 0], + [0, 1, 0], + [0, 0, 1], + [1, 0, 1], + [1, 1, 1], + [0, 1, 1], + ] + Args: boxes1: tensor of shape (N, 8, 3) of the coordinates of the 1st boxes boxes2: tensor of shape (M, 8, 3) of the coordinates of the 2nd boxes @@ -58,6 +136,9 @@ def box3d_overlap( if not all((8, 3) == box.shape[1:] for box in [boxes1, boxes2]): raise ValueError("Each box in the batch must be of shape (8, 3)") + _check_coplanar(boxes1) + _check_coplanar(boxes2) + # pyre-fixme[16]: `_box3d_overlap` has no attribute `apply`. vol, iou = _box3d_overlap.apply(boxes1, boxes2) diff --git a/tests/bm_iou_box3d.py b/tests/bm_iou_box3d.py index b8b0d2d4..0794fe55 100644 --- a/tests/bm_iou_box3d.py +++ b/tests/bm_iou_box3d.py @@ -11,25 +11,42 @@ from test_iou_box3d import TestIoU3D def bm_iou_box3d() -> None: - N = [1, 4, 8, 16] - num_samples = [2000, 5000, 10000, 20000] + # Realistic use cases + N = [30, 100] + M = [5, 10, 100] + kwargs_list = [] + test_cases = product(N, M) + for case in test_cases: + n, m = case + kwargs_list.append({"N": n, "M": m, "device": "cuda:0"}) + benchmark(TestIoU3D.iou, "3D_IOU", kwargs_list, warmup_iters=1) + # Comparison of C++/CUDA + kwargs_list = [] + N = [1, 4, 8, 16] + devices = ["cpu", "cuda:0"] + test_cases = product(N, N, devices) + for case in test_cases: + n, m, d = case + kwargs_list.append({"N": n, "M": m, "device": d}) + benchmark(TestIoU3D.iou, "3D_IOU", kwargs_list, warmup_iters=1) + + # Naive PyTorch + N = [1, 4] kwargs_list = [] test_cases = product(N, N) for case in test_cases: n, m = case kwargs_list.append({"N": n, 
"M": m, "device": "cuda:0"}) - benchmark(TestIoU3D.iou_naive, "3D_IOU_NAIVE", kwargs_list, warmup_iters=1) - [k.update({"device": "cpu"}) for k in kwargs_list] - benchmark(TestIoU3D.iou, "3D_IOU", kwargs_list, warmup_iters=1) - + # Sampling based method + num_samples = [2000, 5000] kwargs_list = [] - test_cases = product([1, 4], [1, 4], num_samples) + test_cases = product(N, N, num_samples) for case in test_cases: n, m, s = case - kwargs_list.append({"N": n, "M": m, "num_samples": s}) + kwargs_list.append({"N": n, "M": m, "num_samples": s, "device": "cuda:0"}) benchmark(TestIoU3D.iou_sampling, "3D_IOU_SAMPLING", kwargs_list, warmup_iters=1) diff --git a/tests/data/objectron_vols_ious.pt b/tests/data/objectron_vols_ious.pt new file mode 100644 index 00000000..b15c78ef Binary files /dev/null and b/tests/data/objectron_vols_ious.pt differ diff --git a/tests/test_iou_box3d.py b/tests/test_iou_box3d.py index b8c0eb6d..f98e6582 100644 --- a/tests/test_iou_box3d.py +++ b/tests/test_iou_box3d.py @@ -10,13 +10,28 @@ from typing import List, Tuple, Union import torch import torch.nn.functional as F -from common_testing import TestCaseMixin +from common_testing import TestCaseMixin, get_random_cuda_device, get_tests_dir from pytorch3d.io import save_obj - -from pytorch3d.ops.iou_box3d import box3d_overlap +from pytorch3d.ops.iou_box3d import _box_planes, _box_triangles, box3d_overlap from pytorch3d.transforms.rotation_conversions import random_rotation +OBJECTRON_TO_PYTORCH3D_FACE_IDX = [0, 4, 6, 2, 1, 5, 7, 3] +DATA_DIR = get_tests_dir() / "data" +DEBUG = False + +UNIT_BOX = [ + [0, 0, 0], + [1, 0, 0], + [1, 1, 0], + [0, 1, 0], + [0, 0, 1], + [1, 0, 1], + [1, 1, 1], + [0, 1, 1], +] + + class TestIoU3D(TestCaseMixin, unittest.TestCase): def setUp(self) -> None: super().setUp() @@ -78,16 +93,7 @@ class TestIoU3D(TestCaseMixin, unittest.TestCase): def _test_iou(self, overlap_fn, device): box1 = torch.tensor( - [ - [0, 0, 0], - [1, 0, 0], - [1, 1, 0], - [0, 1, 0], - [0, 0, 1], - 
[1, 0, 1], - [1, 1, 1], - [0, 1, 1], - ], + UNIT_BOX, dtype=torch.float32, device=device, ) @@ -126,6 +132,10 @@ class TestIoU3D(TestCaseMixin, unittest.TestCase): ), ) + # Also check IoU is 1 when computing overlap with the same shifted box + vol, iou = overlap_fn(box2[None], box2[None]) + self.assertClose(iou, torch.tensor([[1.0]], device=vol.device, dtype=vol.dtype)) + # 5th test ddx, ddy, ddz = random.random(), random.random(), random.random() box2 = box1 + torch.tensor([[ddx, ddy, ddz]], device=device) @@ -207,15 +217,15 @@ class TestIoU3D(TestCaseMixin, unittest.TestCase): # create box1 ctrs = torch.rand((2, 3), device=device) whl = torch.rand((2, 3), device=device) * 10.0 + 1.0 - # box1 & box2 - box1 = self.create_box(ctrs[0], whl[0]) - box2 = self.create_box(ctrs[1], whl[1]) + # box8a & box8b + box8a = self.create_box(ctrs[0], whl[0]) + box8b = self.create_box(ctrs[1], whl[1]) RR1 = random_rotation(dtype=torch.float32, device=device) TT1 = torch.rand((1, 3), dtype=torch.float32, device=device) RR2 = random_rotation(dtype=torch.float32, device=device) TT2 = torch.rand((1, 3), dtype=torch.float32, device=device) - box1r = box1 @ RR1.transpose(0, 1) + TT1 - box2r = box2 @ RR2.transpose(0, 1) + TT2 + box1r = box8a @ RR1.transpose(0, 1) + TT1 + box2r = box8b @ RR2.transpose(0, 1) + TT2 vol, iou = overlap_fn(box1r[None], box2r[None]) iou_sampling = self._box3d_overlap_sampling_batched( box1r[None], box2r[None], num_samples=10000 @@ -229,27 +239,90 @@ class TestIoU3D(TestCaseMixin, unittest.TestCase): self.assertClose(vol, torch.tensor([[0.0]], device=vol.device, dtype=vol.dtype)) self.assertClose(iou, torch.tensor([[0.0]], device=vol.device, dtype=vol.dtype)) + # 10th test: Non coplanar verts in a plane + box10 = box1 + torch.rand((8, 3), dtype=torch.float32, device=device) + msg = "Plane vertices are not coplanar" + with self.assertRaisesRegex(ValueError, msg): + overlap_fn(box10[None], box10[None]) + + # 11th test: Skewed bounding boxes but all verts are 
coplanar + box_skew_1 = torch.tensor( + [ + [0, 0, 0], + [1, 0, 0], + [1, 1, 0], + [0, 1, 0], + [-2, -2, 2], + [2, -2, 2], + [2, 2, 2], + [-2, 2, 2], + ], + dtype=torch.float32, + device=device, + ) + box_skew_2 = torch.tensor( + [ + [2.015995, 0.695233, 2.152806], + [2.832533, 0.663448, 1.576389], + [2.675445, -0.309592, 1.407520], + [1.858907, -0.277806, 1.983936], + [-0.413922, 3.161758, 2.044343], + [2.852230, 3.034615, -0.261321], + [2.223878, -0.857545, -0.936800], + [-1.042273, -0.730402, 1.368864], + ], + dtype=torch.float32, + device=device, + ) + vol1 = 14.000 + vol2 = 14.000005 + vol_inters = 5.431122 + iou = vol_inters / (vol1 + vol2 - vol_inters) + + vols, ious = overlap_fn(box_skew_1[None], box_skew_2[None]) + self.assertClose(vols, torch.tensor([[vol_inters]], device=device), atol=1e-1) + self.assertClose(ious, torch.tensor([[iou]], device=device), atol=1e-1) + def test_iou_naive(self): - device = torch.device("cuda:0") + device = get_random_cuda_device() self._test_iou(self._box3d_overlap_naive_batched, device) + self._test_compare_objectron(self._box3d_overlap_naive_batched, device) def test_iou_cpu(self): device = torch.device("cpu") self._test_iou(box3d_overlap, device) + self._test_compare_objectron(box3d_overlap, device) - def test_cpu_vs_naive_batched(self): - N, M = 3, 6 - device = "cpu" - boxes1 = torch.randn((N, 8, 3), device=device) - boxes2 = torch.randn((M, 8, 3), device=device) - vol1, iou1 = self._box3d_overlap_naive_batched(boxes1, boxes2) - vol2, iou2 = box3d_overlap(boxes1, boxes2) - # check shape - for val in [vol1, vol2, iou1, iou2]: - self.assertClose(val.shape, (N, M)) - # check values - self.assertClose(vol1, vol2) - self.assertClose(iou1, iou2) + def test_iou_cuda(self): + device = torch.device("cuda:0") + self._test_iou(box3d_overlap, device) + self._test_compare_objectron(box3d_overlap, device) + + def _test_compare_objectron(self, overlap_fn, device): + # Load saved objectron data + data_filename = 
"./objectron_vols_ious.pt" + objectron_vals = torch.load(DATA_DIR / data_filename) + boxes1 = objectron_vals["boxes1"] + boxes2 = objectron_vals["boxes2"] + vols_objectron = objectron_vals["vols"] + ious_objectron = objectron_vals["ious"] + + boxes1 = boxes1.to(device=device, dtype=torch.float32) + boxes2 = boxes2.to(device=device, dtype=torch.float32) + + # Convert vertex orderings from Objectron to PyTorch3D convention + idx = torch.tensor( + OBJECTRON_TO_PYTORCH3D_FACE_IDX, dtype=torch.int64, device=device + ) + boxes1 = boxes1.index_select(index=idx, dim=1) + boxes2 = boxes2.index_select(index=idx, dim=1) + + # Run PyTorch3D version + vols, ious = overlap_fn(boxes1, boxes2) + + # Check values match + self.assertClose(vols_objectron, vols.cpu()) + self.assertClose(ious_objectron, ious.cpu()) def test_batched_errors(self): N, M = 5, 10 @@ -316,16 +389,7 @@ class TestIoU3D(TestCaseMixin, unittest.TestCase): def test_box_planar_dir(self): device = torch.device("cuda:0") box1 = torch.tensor( - [ - [0, 0, 0], - [1, 0, 0], - [1, 1, 0], - [0, 1, 0], - [0, 0, 1], - [1, 0, 1], - [1, 1, 1], - [0, 1, 1], - ], + UNIT_BOX, dtype=torch.float32, device=device, ) @@ -353,8 +417,13 @@ class TestIoU3D(TestCaseMixin, unittest.TestCase): @staticmethod def iou_naive(N: int, M: int, device="cpu"): - boxes1 = torch.randn((N, 8, 3)) - boxes2 = torch.randn((M, 8, 3)) + box = torch.tensor( + [UNIT_BOX], + dtype=torch.float32, + device=device, + ) + boxes1 = box + torch.randn((N, 1, 3), device=device) + boxes2 = box + torch.randn((M, 1, 3), device=device) def output(): vol, iou = TestIoU3D._box3d_overlap_naive_batched(boxes1, boxes2) @@ -363,8 +432,13 @@ class TestIoU3D(TestCaseMixin, unittest.TestCase): @staticmethod def iou(N: int, M: int, device="cpu"): - boxes1 = torch.randn((N, 8, 3), device=device) - boxes2 = torch.randn((M, 8, 3), device=device) + box = torch.tensor( + [UNIT_BOX], + dtype=torch.float32, + device=device, + ) + boxes1 = box + torch.randn((N, 1, 3), device=device) + 
boxes2 = box + torch.randn((M, 1, 3), device=device) def output(): vol, iou = box3d_overlap(boxes1, boxes2) @@ -372,9 +446,14 @@ class TestIoU3D(TestCaseMixin, unittest.TestCase): return output @staticmethod - def iou_sampling(N: int, M: int, num_samples: int): - boxes1 = torch.randn((N, 8, 3)) - boxes2 = torch.randn((M, 8, 3)) + def iou_sampling(N: int, M: int, num_samples: int, device="cpu"): + box = torch.tensor( + [UNIT_BOX], + dtype=torch.float32, + device=device, + ) + boxes1 = box + torch.randn((N, 1, 3), device=device) + boxes2 = box + torch.randn((M, 1, 3), device=device) def output(): _ = TestIoU3D._box3d_overlap_sampling_batched(boxes1, boxes2, num_samples) @@ -408,38 +487,6 @@ Note that both implementations currently do not support batching. # # -------------------------------------------------- # -# -------------------------------------------------- # -# CONSTANTS # -# -------------------------------------------------- # -""" -_box_planes and _box_triangles define the 4- and 3-connectivity -of the 8 box corners. -_box_planes gives the quad faces of the 3D box -_box_triangles gives the triangle faces of the 3D box -""" -_box_planes = [ - [0, 1, 2, 3], - [3, 2, 6, 7], - [0, 1, 5, 4], - [0, 3, 7, 4], - [1, 5, 6, 2], - [4, 5, 6, 7], -] -_box_triangles = [ - [0, 1, 2], - [0, 3, 2], - [4, 5, 6], - [4, 6, 7], - [1, 5, 6], - [1, 6, 2], - [0, 4, 7], - [0, 7, 3], - [3, 2, 6], - [3, 6, 7], - [0, 1, 5], - [0, 4, 5], -] - # -------------------------------------------------- # # HELPER FUNCTIONS FOR EXACT SOLUTION # # -------------------------------------------------- # @@ -477,7 +524,7 @@ def get_plane_verts(box: torch.Tensor) -> torch.Tensor: return plane_verts -def box_planar_dir(box: torch.Tensor) -> torch.Tensor: +def box_planar_dir(box: torch.Tensor, eps=1e-4) -> torch.Tensor: """ Finds the unit vector n which is perpendicular to each plane in the box and points towards the inside of the box. 
@@ -507,6 +554,11 @@ def box_planar_dir(box: torch.Tensor) -> torch.Tensor: e1 = F.normalize(v2 - v0, dim=-1) n = F.normalize(torch.cross(e0, e1, dim=-1), dim=-1) + # Check all verts are coplanar + if not ((v3 - v0).unsqueeze(1).bmm(n.unsqueeze(2)).abs() < eps).all().item(): + msg = "Plane vertices are not coplanar" + raise ValueError(msg) + # We can write: `ctr = v0 + a * e0 + b * e1 + c * n`, (1). # With = 0 and = 0, where <.,.> refers to the dot product, # since that e0 is orthogonal to n. Same for e1. @@ -733,10 +785,10 @@ def clip_tri_by_plane_oneout( device = plane.device # point of intersection between plane and (vin1, vout) pint1, a1 = plane_edge_point_of_intersection(plane, n, vin1, vout) - assert a1 >= eps and a1 <= 1.0 + assert a1 >= eps and a1 <= 1.0, a1 # point of intersection between plane and (vin2, vout) pint2, a2 = plane_edge_point_of_intersection(plane, n, vin2, vout) - assert a2 >= 0.0 and a2 <= 1.0 + assert a2 >= 0.0 and a2 <= 1.0, a2 verts = torch.stack((vin1, pint1, pint2, vin2), dim=0) # 4x3 faces = torch.tensor( @@ -771,10 +823,10 @@ def clip_tri_by_plane_twoout( device = plane.device # point of intersection between plane and (vin, vout1) pint1, a1 = plane_edge_point_of_intersection(plane, n, vin, vout1) - assert a1 >= eps and a1 <= 1.0 + assert a1 >= eps and a1 <= 1.0, a1 # point of intersection between plane and (vin, vout2) pint2, a2 = plane_edge_point_of_intersection(plane, n, vin, vout2) - assert a2 >= eps and a2 <= 1.0 + assert a2 >= eps and a2 <= 1.0, a2 verts = torch.stack((vin, pint1, pint2), dim=0) # 3x3 faces = torch.tensor( @@ -945,7 +997,7 @@ def box3d_overlap_naive(box1: torch.Tensor, box2: torch.Tensor): iou = vol / (vol1 + vol2 - vol) - if 0: + if DEBUG: # save shapes tri_faces = torch.tensor(_box_triangles, device=device, dtype=torch.int64) save_obj("/tmp/output/shape1.obj", box1, tri_faces)