diff --git a/pytorch3d/csrc/iou_box3d/iou_box3d.cu b/pytorch3d/csrc/iou_box3d/iou_box3d.cu index 4da2f1e8..32f50efb 100644 --- a/pytorch3d/csrc/iou_box3d/iou_box3d.cu +++ b/pytorch3d/csrc/iou_box3d/iou_box3d.cu @@ -29,14 +29,17 @@ __global__ void IoUBox3DKernel( const size_t tid = blockIdx.x * blockDim.x + threadIdx.x; const size_t stride = gridDim.x * blockDim.x; + std::array box1_tris{}; + std::array box2_tris{}; + std::array box1_planes{}; + std::array box2_planes{}; + for (size_t i = tid; i < N * M; i += stride) { const size_t n = i / M; // box1 index const size_t m = i % M; // box2 index // Convert to array of structs of face vertices i.e. effectively (F, 3, 3) // FaceVerts is a data type defined in iou_utils.cuh - FaceVerts box1_tris[NUM_TRIS]; - FaceVerts box2_tris[NUM_TRIS]; GetBoxTris(boxes1[n], box1_tris); GetBoxTris(boxes2[m], box2_tris); @@ -46,9 +49,7 @@ __global__ void IoUBox3DKernel( const float3 box2_center = BoxCenter(boxes2[m]); // Convert to an array of face vertices - FaceVerts box1_planes[NUM_PLANES]; GetBoxPlanes(boxes1[n], box1_planes); - FaceVerts box2_planes[NUM_PLANES]; GetBoxPlanes(boxes2[m], box2_planes); // Get Box Volumes diff --git a/pytorch3d/csrc/iou_box3d/iou_utils.cuh b/pytorch3d/csrc/iou_box3d/iou_utils.cuh index b4d08563..7f2f287a 100644 --- a/pytorch3d/csrc/iou_box3d/iou_utils.cuh +++ b/pytorch3d/csrc/iou_box3d/iou_utils.cuh @@ -39,25 +39,25 @@ const int MAX_TRIS = 100; // We will use struct arrays for representing // the data for each box and intersecting // triangles -typedef struct { +struct FaceVerts { float3 v0; float3 v1; float3 v2; float3 v3; // Can be empty for triangles -} FaceVerts; +}; -typedef struct { +struct FaceVertsIdx { int v0; int v1; int v2; int v3; // Can be empty for triangles -} FaceVertsIdx; +}; // This is used when deciding which faces to // keep that are not coplanar -typedef struct { +struct Keep { bool keep; -} Keep; +}; __device__ FaceVertsIdx _PLANES[] = { {0, 1, 2, 3}, @@ -128,6 +128,23 @@ __device__ inline void GetBoxPlanes( } } +// The geometric center of a list of vertices. +// +// Args +// vertices: A list of float3 vertices {v0, ..., vN}. +// +// Returns +// float3: Geometric center of the vertices. +// +__device__ inline float3 FaceCenter( + std::initializer_list vertices) { + auto sumVertices = float3{}; + for (const auto& vertex : vertices) { + sumVertices = sumVertices + vertex; + } + return sumVertices / vertices.size(); +} + // The normal of a plane spanned by vectors e0 and e1 // // Args @@ -142,50 +159,33 @@ __device__ inline float3 GetNormal(const float3 e0, const float3 e1) { return n; } -// The center of a triangle defined by vertices (v0, v1, v2) -// -// Args -// v0, v1, v2: float3 coordinates of the vertices of the triangle -// -// Returns -// float3: center of the triangle -// -__device__ inline float3 -TriCenter(const float3 v0, const float3 v1, const float3 v2) { - float3 ctr = (v0 + v1 + v2) / 3.0f; - return ctr; -} - -// The normal of the triangle defined by vertices (v0, v1, v2) +// The normal of a face with vertices (v0, v1, v2) or (v0, ..., v3). // We find the "best" edges connecting the face center to the vertices, // such that the cross product between the edges is maximized. // // Args -// v0, v1, v2: float3 coordinates of the vertices of the face +// vertices: a list of float3 coordinates of the vertices. // // Returns -// float3: normal for the face +// float3: center of the plane // -__device__ inline float3 -TriNormal(const float3 v0, const float3 v1, const float3 v2) { - const float3 ctr = TriCenter(v0, v1, v2); - - const float d01 = norm(cross(v0 - ctr, v1 - ctr)); - const float d02 = norm(cross(v0 - ctr, v2 - ctr)); - const float d12 = norm(cross(v1 - ctr, v2 - ctr)); - - float3 n = GetNormal(v0 - ctr, v1 - ctr); - float max_dist = d01; - - if (d02 > max_dist) { - max_dist = d02; - n = GetNormal(v0 - ctr, v2 - ctr); +__device__ inline float3 FaceNormal( + std::initializer_list vertices) { + const auto faceCenter = FaceCenter(vertices); + auto normal = float3(); + auto maxDist = -1; + for (auto v1 = vertices.begin(); v1 != vertices.end() - 1; ++v1) { + for (auto v2 = std::next(v1); v2 != vertices.end(); ++v2) { + const auto v1ToCenter = *v1 - faceCenter; + const auto v2ToCenter = *v2 - faceCenter; + const auto dist = norm(cross(v1ToCenter, v2ToCenter)); + if (dist > maxDist) { + normal = GetNormal(v1ToCenter, v2ToCenter); + maxDist = dist; + } + } } - if (d12 > max_dist) { - n = GetNormal(v1 - ctr, v2 - ctr); - } - - return n; + return normal; } // The area of the face defined by vertices (v0, v1, v2) @@ -201,79 +201,10 @@ TriNormal(const float3 v0, const float3 v1, const float3 v2) { // __device__ inline float FaceArea(const FaceVerts& tri) { // Get verts for face 1 - const float3 v0 = tri.v0; - const float3 v1 = tri.v1; - const float3 v2 = tri.v2; - const float3 n = cross(v1 - v0, v2 - v0); + const float3 n = cross(tri.v1 - tri.v0, tri.v2 - tri.v0); return norm(n) / 2.0; } -// The center of a plane defined by vertices (v0, v1, v2, v3) -// -// Args -// v0, v1, v2, v3: float3 coordinates of the vertices of the plane -// -// Returns -// float3: center of the plane -// -__device__ inline float3 PlaneCenter( - const float3 v0, - const float3 v1, - const float3 v2, - const float3 v3) { - float3 ctr = (v0 + v1 + v2 + v3) / 4.0f; - return ctr; -} - -// The normal of a planar face with vertices (v0, v1, v2, v3) -// We find the "best" edges connecting the face center to the vertices, -// such that the cross product between the edges is maximized. -// -// Args -// e0, e1: float3 coordinates of the vertices of the plane -// -// Returns -// float3: center of the plane -// -__device__ inline float3 PlaneNormal( - const float3 v0, - const float3 v1, - const float3 v2, - const float3 v3) { - const float3 ctr = PlaneCenter(v0, v1, v2, v3); - - const float d01 = norm(cross(v0 - ctr, v1 - ctr)); - const float d02 = norm(cross(v0 - ctr, v2 - ctr)); - const float d03 = norm(cross(v0 - ctr, v3 - ctr)); - const float d12 = norm(cross(v1 - ctr, v2 - ctr)); - const float d13 = norm(cross(v1 - ctr, v3 - ctr)); - const float d23 = norm(cross(v2 - ctr, v3 - ctr)); - - float max_dist = d01; - float3 n = GetNormal(v0 - ctr, v1 - ctr); - - if (d02 > max_dist) { - max_dist = d02; - n = GetNormal(v0 - ctr, v2 - ctr); - } - if (d03 > max_dist) { - max_dist = d03; - n = GetNormal(v0 - ctr, v3 - ctr); - } - if (d12 > max_dist) { - max_dist = d12; - n = GetNormal(v1 - ctr, v2 - ctr); - } - if (d13 > max_dist) { - max_dist = d13; - n = GetNormal(v1 - ctr, v3 - ctr); - } - if (d23 > max_dist) { - n = GetNormal(v2 - ctr, v3 - ctr); - } - return n; -} - // The normal of a box plane defined by the verts in `plane` such that it // points toward the centroid of the box given by `center`. // @@ -290,17 +221,12 @@ template __device__ inline float3 PlaneNormalDirection( const FaceVertsPlane& plane, const float3& center) { - // The plane's vertices - const float3 v0 = plane.v0; - const float3 v1 = plane.v1; - const float3 v2 = plane.v2; - const float3 v3 = plane.v3; - // The plane's center - const float3 plane_center = PlaneCenter(v0, v1, v2, v3); + const float3 plane_center = + FaceCenter({plane.v0, plane.v1, plane.v2, plane.v3}); // The plane's normal - float3 n = PlaneNormal(v0, v1, v2, v3); + float3 n = FaceNormal({plane.v0, plane.v1, plane.v2, plane.v3}); // We project the center on the plane defined by (v0, v1, v2, v3) // We can write center = plane_center + a * e0 + b * e1 + c * n @@ -442,14 +368,8 @@ __device__ inline float3 PolyhedronCenter( // __device__ inline bool IsInside(const FaceVerts& plane, const float3& normal, const float3& point) { - // Vertices of the plane - const float3 v0 = plane.v0; - const float3 v1 = plane.v1; - const float3 v2 = plane.v2; - const float3 v3 = plane.v3; - // The center of the plane - const float3 plane_ctr = PlaneCenter(v0, v1, v2, v3); + const float3 plane_ctr = FaceCenter({plane.v0, plane.v1, plane.v2, plane.v3}); // Every point p can be written as p = plane_ctr + a e0 + b e1 + c n // Solving for c: @@ -478,14 +398,8 @@ __device__ inline float3 PlaneEdgeIntersection( const float3& normal, const float3& p0, const float3& p1) { - // Vertices of the plane - const float3 v0 = plane.v0; - const float3 v1 = plane.v1; - const float3 v2 = plane.v2; - const float3 v3 = plane.v3; - // The center of the plane - const float3 plane_ctr = PlaneCenter(v0, v1, v2, v3); + const float3 plane_ctr = FaceCenter({plane.v0, plane.v1, plane.v2, plane.v3}); // The point of intersection can be parametrized // p = p0 + a (p1 - p0) where a in [0, 1] @@ -548,30 +462,18 @@ __device__ inline std::tuple ArgMaxVerts( __device__ inline bool IsCoplanarTriTri( const FaceVerts& tri1, const FaceVerts& tri2) { - // Get verts for face 1 - const float3 v0_1 = tri1.v0; - const float3 v1_1 = tri1.v1; - const float3 v2_1 = tri1.v2; + const float3 tri1_ctr = FaceCenter({tri1.v0, tri1.v1, tri1.v2}); + const float3 tri1_n = FaceNormal({tri1.v0, tri1.v1, tri1.v2}); - const float3 tri1_ctr = TriCenter(v0_1, v1_1, v2_1); - const float3 tri1_n = TriNormal(v0_1, v1_1, v2_1); - - // Get verts for face 2 - const float3 v0_2 = tri2.v0; - const float3 v1_2 = tri2.v1; - const float3 v2_2 = tri2.v2; - - const float3 tri2_ctr = TriCenter(v0_2, v1_2, v2_2); - const float3 tri2_n = TriNormal(v0_2, v1_2, v2_2); + const float3 tri2_ctr = FaceCenter({tri2.v0, tri2.v1, tri2.v2}); + const float3 tri2_n = FaceNormal({tri2.v0, tri2.v1, tri2.v2}); // Check if parallel const bool check1 = abs(dot(tri1_n, tri2_n)) > 1 - dEpsilon; // Compute most distant points - auto argvs = + const auto [v1m, v2m] = ArgMaxVerts({tri1.v0, tri1.v1, tri1.v2}, {tri2.v0, tri2.v1, tri2.v2}); - const float3 v1m = std::get<0>(argvs); - const float3 v2m = std::get<1>(argvs); float3 n12m = v1m - v2m; n12m = n12m / fmaxf(norm(n12m), kEpsilon); @@ -597,22 +499,15 @@ __device__ inline bool IsCoplanarTriPlane( const FaceVerts& tri, const FaceVerts& plane, const float3& normal) { - // Get verts for tri - const float3 v0t = tri.v0; - const float3 v1t = tri.v1; - const float3 v2t = tri.v2; - - const float3 tri_ctr = TriCenter(v0t, v1t, v2t); - const float3 nt = TriNormal(v0t, v1t, v2t); + const float3 tri_ctr = FaceCenter({tri.v0, tri.v1, tri.v2}); + const float3 nt = FaceNormal({tri.v0, tri.v1, tri.v2}); // check if parallel const bool check1 = abs(dot(nt, normal)) > 1 - dEpsilon; // Compute most distant points - auto argvs = ArgMaxVerts( + const auto [v1m, v2m] = ArgMaxVerts( {tri.v0, tri.v1, tri.v2}, {plane.v0, plane.v1, plane.v2, plane.v3}); - const float3 v1m = std::get<0>(argvs); - const float3 v2m = std::get<1>(argvs); float3 n12m = v1m - v2m; n12m = n12m / fmaxf(norm(n12m), kEpsilon);