mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
Add structured bindings to iou to prove that we're C++17-friendly. Also other minor improvements to bbox iou
Summary: Recently we removed C++14-only compilation, should work. Reviewed By: bottler Differential Revision: D38919607 fbshipit-source-id: 6a26fa7713f7ba2163364ccc673ad774aa3a5adb
This commit is contained in:
parent
5e7707b157
commit
c4545a7cbc
@ -29,14 +29,17 @@ __global__ void IoUBox3DKernel(
|
||||
const size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
const size_t stride = gridDim.x * blockDim.x;
|
||||
|
||||
std::array<FaceVerts, NUM_TRIS> box1_tris{};
|
||||
std::array<FaceVerts, NUM_TRIS> box2_tris{};
|
||||
std::array<FaceVerts, NUM_PLANES> box1_planes{};
|
||||
std::array<FaceVerts, NUM_PLANES> box2_planes{};
|
||||
|
||||
for (size_t i = tid; i < N * M; i += stride) {
|
||||
const size_t n = i / M; // box1 index
|
||||
const size_t m = i % M; // box2 index
|
||||
|
||||
// Convert to array of structs of face vertices i.e. effectively (F, 3, 3)
|
||||
// FaceVerts is a data type defined in iou_utils.cuh
|
||||
FaceVerts box1_tris[NUM_TRIS];
|
||||
FaceVerts box2_tris[NUM_TRIS];
|
||||
GetBoxTris(boxes1[n], box1_tris);
|
||||
GetBoxTris(boxes2[m], box2_tris);
|
||||
|
||||
@ -46,9 +49,7 @@ __global__ void IoUBox3DKernel(
|
||||
const float3 box2_center = BoxCenter(boxes2[m]);
|
||||
|
||||
// Convert to an array of face vertices
|
||||
FaceVerts box1_planes[NUM_PLANES];
|
||||
GetBoxPlanes(boxes1[n], box1_planes);
|
||||
FaceVerts box2_planes[NUM_PLANES];
|
||||
GetBoxPlanes(boxes2[m], box2_planes);
|
||||
|
||||
// Get Box Volumes
|
||||
|
@ -39,25 +39,25 @@ const int MAX_TRIS = 100;
|
||||
// We will use struct arrays for representing
|
||||
// the data for each box and intersecting
|
||||
// triangles
|
||||
typedef struct {
|
||||
struct FaceVerts {
|
||||
float3 v0;
|
||||
float3 v1;
|
||||
float3 v2;
|
||||
float3 v3; // Can be empty for triangles
|
||||
} FaceVerts;
|
||||
};
|
||||
|
||||
typedef struct {
|
||||
struct FaceVertsIdx {
|
||||
int v0;
|
||||
int v1;
|
||||
int v2;
|
||||
int v3; // Can be empty for triangles
|
||||
} FaceVertsIdx;
|
||||
};
|
||||
|
||||
// This is used when deciding which faces to
|
||||
// keep that are not coplanar
|
||||
typedef struct {
|
||||
struct Keep {
|
||||
bool keep;
|
||||
} Keep;
|
||||
};
|
||||
|
||||
__device__ FaceVertsIdx _PLANES[] = {
|
||||
{0, 1, 2, 3},
|
||||
@ -128,6 +128,23 @@ __device__ inline void GetBoxPlanes(
|
||||
}
|
||||
}
|
||||
|
||||
// The geometric center of a list of vertices.
|
||||
//
|
||||
// Args
|
||||
// vertices: A list of float3 vertices {v0, ..., vN}.
|
||||
//
|
||||
// Returns
|
||||
// float3: Geometric center of the vertices.
|
||||
//
|
||||
__device__ inline float3 FaceCenter(
|
||||
std::initializer_list<const float3> vertices) {
|
||||
auto sumVertices = float3{};
|
||||
for (const auto& vertex : vertices) {
|
||||
sumVertices = sumVertices + vertex;
|
||||
}
|
||||
return sumVertices / vertices.size();
|
||||
}
|
||||
|
||||
// The normal of a plane spanned by vectors e0 and e1
|
||||
//
|
||||
// Args
|
||||
@ -142,50 +159,33 @@ __device__ inline float3 GetNormal(const float3 e0, const float3 e1) {
|
||||
return n;
|
||||
}
|
||||
|
||||
// The center of a triangle defined by vertices (v0, v1, v2)
|
||||
//
|
||||
// Args
|
||||
// v0, v1, v2: float3 coordinates of the vertices of the triangle
|
||||
//
|
||||
// Returns
|
||||
// float3: center of the triangle
|
||||
//
|
||||
__device__ inline float3
|
||||
TriCenter(const float3 v0, const float3 v1, const float3 v2) {
|
||||
float3 ctr = (v0 + v1 + v2) / 3.0f;
|
||||
return ctr;
|
||||
}
|
||||
|
||||
// The normal of the triangle defined by vertices (v0, v1, v2)
|
||||
// The normal of a face with vertices (v0, v1, v2) or (v0, ..., v3).
|
||||
// We find the "best" edges connecting the face center to the vertices,
|
||||
// such that the cross product between the edges is maximized.
|
||||
//
|
||||
// Args
|
||||
// v0, v1, v2: float3 coordinates of the vertices of the face
|
||||
// vertices: a list of float3 coordinates of the vertices.
|
||||
//
|
||||
// Returns
|
||||
// float3: normal for the face
|
||||
// float3: center of the plane
|
||||
//
|
||||
__device__ inline float3
|
||||
TriNormal(const float3 v0, const float3 v1, const float3 v2) {
|
||||
const float3 ctr = TriCenter(v0, v1, v2);
|
||||
|
||||
const float d01 = norm(cross(v0 - ctr, v1 - ctr));
|
||||
const float d02 = norm(cross(v0 - ctr, v2 - ctr));
|
||||
const float d12 = norm(cross(v1 - ctr, v2 - ctr));
|
||||
|
||||
float3 n = GetNormal(v0 - ctr, v1 - ctr);
|
||||
float max_dist = d01;
|
||||
|
||||
if (d02 > max_dist) {
|
||||
max_dist = d02;
|
||||
n = GetNormal(v0 - ctr, v2 - ctr);
|
||||
__device__ inline float3 FaceNormal(
|
||||
std::initializer_list<const float3> vertices) {
|
||||
const auto faceCenter = FaceCenter(vertices);
|
||||
auto normal = float3();
|
||||
auto maxDist = -1;
|
||||
for (auto v1 = vertices.begin(); v1 != vertices.end() - 1; ++v1) {
|
||||
for (auto v2 = std::next(v1); v2 != vertices.end(); ++v2) {
|
||||
const auto v1ToCenter = *v1 - faceCenter;
|
||||
const auto v2ToCenter = *v2 - faceCenter;
|
||||
const auto dist = norm(cross(v1ToCenter, v2ToCenter));
|
||||
if (dist > maxDist) {
|
||||
normal = GetNormal(v1ToCenter, v2ToCenter);
|
||||
maxDist = dist;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (d12 > max_dist) {
|
||||
n = GetNormal(v1 - ctr, v2 - ctr);
|
||||
}
|
||||
|
||||
return n;
|
||||
return normal;
|
||||
}
|
||||
|
||||
// The area of the face defined by vertices (v0, v1, v2)
|
||||
@ -201,79 +201,10 @@ TriNormal(const float3 v0, const float3 v1, const float3 v2) {
|
||||
//
|
||||
__device__ inline float FaceArea(const FaceVerts& tri) {
|
||||
// Get verts for face 1
|
||||
const float3 v0 = tri.v0;
|
||||
const float3 v1 = tri.v1;
|
||||
const float3 v2 = tri.v2;
|
||||
const float3 n = cross(v1 - v0, v2 - v0);
|
||||
const float3 n = cross(tri.v1 - tri.v0, tri.v2 - tri.v0);
|
||||
return norm(n) / 2.0;
|
||||
}
|
||||
|
||||
// The center of a plane defined by vertices (v0, v1, v2, v3)
|
||||
//
|
||||
// Args
|
||||
// v0, v1, v2, v3: float3 coordinates of the vertices of the plane
|
||||
//
|
||||
// Returns
|
||||
// float3: center of the plane
|
||||
//
|
||||
__device__ inline float3 PlaneCenter(
|
||||
const float3 v0,
|
||||
const float3 v1,
|
||||
const float3 v2,
|
||||
const float3 v3) {
|
||||
float3 ctr = (v0 + v1 + v2 + v3) / 4.0f;
|
||||
return ctr;
|
||||
}
|
||||
|
||||
// The normal of a planar face with vertices (v0, v1, v2, v3)
|
||||
// We find the "best" edges connecting the face center to the vertices,
|
||||
// such that the cross product between the edges is maximized.
|
||||
//
|
||||
// Args
|
||||
// e0, e1: float3 coordinates of the vertices of the plane
|
||||
//
|
||||
// Returns
|
||||
// float3: center of the plane
|
||||
//
|
||||
__device__ inline float3 PlaneNormal(
|
||||
const float3 v0,
|
||||
const float3 v1,
|
||||
const float3 v2,
|
||||
const float3 v3) {
|
||||
const float3 ctr = PlaneCenter(v0, v1, v2, v3);
|
||||
|
||||
const float d01 = norm(cross(v0 - ctr, v1 - ctr));
|
||||
const float d02 = norm(cross(v0 - ctr, v2 - ctr));
|
||||
const float d03 = norm(cross(v0 - ctr, v3 - ctr));
|
||||
const float d12 = norm(cross(v1 - ctr, v2 - ctr));
|
||||
const float d13 = norm(cross(v1 - ctr, v3 - ctr));
|
||||
const float d23 = norm(cross(v2 - ctr, v3 - ctr));
|
||||
|
||||
float max_dist = d01;
|
||||
float3 n = GetNormal(v0 - ctr, v1 - ctr);
|
||||
|
||||
if (d02 > max_dist) {
|
||||
max_dist = d02;
|
||||
n = GetNormal(v0 - ctr, v2 - ctr);
|
||||
}
|
||||
if (d03 > max_dist) {
|
||||
max_dist = d03;
|
||||
n = GetNormal(v0 - ctr, v3 - ctr);
|
||||
}
|
||||
if (d12 > max_dist) {
|
||||
max_dist = d12;
|
||||
n = GetNormal(v1 - ctr, v2 - ctr);
|
||||
}
|
||||
if (d13 > max_dist) {
|
||||
max_dist = d13;
|
||||
n = GetNormal(v1 - ctr, v3 - ctr);
|
||||
}
|
||||
if (d23 > max_dist) {
|
||||
n = GetNormal(v2 - ctr, v3 - ctr);
|
||||
}
|
||||
return n;
|
||||
}
|
||||
|
||||
// The normal of a box plane defined by the verts in `plane` such that it
|
||||
// points toward the centroid of the box given by `center`.
|
||||
//
|
||||
@ -290,17 +221,12 @@ template <typename FaceVertsPlane>
|
||||
__device__ inline float3 PlaneNormalDirection(
|
||||
const FaceVertsPlane& plane,
|
||||
const float3& center) {
|
||||
// The plane's vertices
|
||||
const float3 v0 = plane.v0;
|
||||
const float3 v1 = plane.v1;
|
||||
const float3 v2 = plane.v2;
|
||||
const float3 v3 = plane.v3;
|
||||
|
||||
// The plane's center
|
||||
const float3 plane_center = PlaneCenter(v0, v1, v2, v3);
|
||||
const float3 plane_center =
|
||||
FaceCenter({plane.v0, plane.v1, plane.v2, plane.v3});
|
||||
|
||||
// The plane's normal
|
||||
float3 n = PlaneNormal(v0, v1, v2, v3);
|
||||
float3 n = FaceNormal({plane.v0, plane.v1, plane.v2, plane.v3});
|
||||
|
||||
// We project the center on the plane defined by (v0, v1, v2, v3)
|
||||
// We can write center = plane_center + a * e0 + b * e1 + c * n
|
||||
@ -442,14 +368,8 @@ __device__ inline float3 PolyhedronCenter(
|
||||
//
|
||||
__device__ inline bool
|
||||
IsInside(const FaceVerts& plane, const float3& normal, const float3& point) {
|
||||
// Vertices of the plane
|
||||
const float3 v0 = plane.v0;
|
||||
const float3 v1 = plane.v1;
|
||||
const float3 v2 = plane.v2;
|
||||
const float3 v3 = plane.v3;
|
||||
|
||||
// The center of the plane
|
||||
const float3 plane_ctr = PlaneCenter(v0, v1, v2, v3);
|
||||
const float3 plane_ctr = FaceCenter({plane.v0, plane.v1, plane.v2, plane.v3});
|
||||
|
||||
// Every point p can be written as p = plane_ctr + a e0 + b e1 + c n
|
||||
// Solving for c:
|
||||
@ -478,14 +398,8 @@ __device__ inline float3 PlaneEdgeIntersection(
|
||||
const float3& normal,
|
||||
const float3& p0,
|
||||
const float3& p1) {
|
||||
// Vertices of the plane
|
||||
const float3 v0 = plane.v0;
|
||||
const float3 v1 = plane.v1;
|
||||
const float3 v2 = plane.v2;
|
||||
const float3 v3 = plane.v3;
|
||||
|
||||
// The center of the plane
|
||||
const float3 plane_ctr = PlaneCenter(v0, v1, v2, v3);
|
||||
const float3 plane_ctr = FaceCenter({plane.v0, plane.v1, plane.v2, plane.v3});
|
||||
|
||||
// The point of intersection can be parametrized
|
||||
// p = p0 + a (p1 - p0) where a in [0, 1]
|
||||
@ -548,30 +462,18 @@ __device__ inline std::tuple<float3, float3> ArgMaxVerts(
|
||||
__device__ inline bool IsCoplanarTriTri(
|
||||
const FaceVerts& tri1,
|
||||
const FaceVerts& tri2) {
|
||||
// Get verts for face 1
|
||||
const float3 v0_1 = tri1.v0;
|
||||
const float3 v1_1 = tri1.v1;
|
||||
const float3 v2_1 = tri1.v2;
|
||||
const float3 tri1_ctr = FaceCenter({tri1.v0, tri1.v1, tri1.v2});
|
||||
const float3 tri1_n = FaceNormal({tri1.v0, tri1.v1, tri1.v2});
|
||||
|
||||
const float3 tri1_ctr = TriCenter(v0_1, v1_1, v2_1);
|
||||
const float3 tri1_n = TriNormal(v0_1, v1_1, v2_1);
|
||||
|
||||
// Get verts for face 2
|
||||
const float3 v0_2 = tri2.v0;
|
||||
const float3 v1_2 = tri2.v1;
|
||||
const float3 v2_2 = tri2.v2;
|
||||
|
||||
const float3 tri2_ctr = TriCenter(v0_2, v1_2, v2_2);
|
||||
const float3 tri2_n = TriNormal(v0_2, v1_2, v2_2);
|
||||
const float3 tri2_ctr = FaceCenter({tri2.v0, tri2.v1, tri2.v2});
|
||||
const float3 tri2_n = FaceNormal({tri2.v0, tri2.v1, tri2.v2});
|
||||
|
||||
// Check if parallel
|
||||
const bool check1 = abs(dot(tri1_n, tri2_n)) > 1 - dEpsilon;
|
||||
|
||||
// Compute most distant points
|
||||
auto argvs =
|
||||
const auto [v1m, v2m] =
|
||||
ArgMaxVerts({tri1.v0, tri1.v1, tri1.v2}, {tri2.v0, tri2.v1, tri2.v2});
|
||||
const float3 v1m = std::get<0>(argvs);
|
||||
const float3 v2m = std::get<1>(argvs);
|
||||
|
||||
float3 n12m = v1m - v2m;
|
||||
n12m = n12m / fmaxf(norm(n12m), kEpsilon);
|
||||
@ -597,22 +499,15 @@ __device__ inline bool IsCoplanarTriPlane(
|
||||
const FaceVerts& tri,
|
||||
const FaceVerts& plane,
|
||||
const float3& normal) {
|
||||
// Get verts for tri
|
||||
const float3 v0t = tri.v0;
|
||||
const float3 v1t = tri.v1;
|
||||
const float3 v2t = tri.v2;
|
||||
|
||||
const float3 tri_ctr = TriCenter(v0t, v1t, v2t);
|
||||
const float3 nt = TriNormal(v0t, v1t, v2t);
|
||||
const float3 tri_ctr = FaceCenter({tri.v0, tri.v1, tri.v2});
|
||||
const float3 nt = FaceNormal({tri.v0, tri.v1, tri.v2});
|
||||
|
||||
// check if parallel
|
||||
const bool check1 = abs(dot(nt, normal)) > 1 - dEpsilon;
|
||||
|
||||
// Compute most distant points
|
||||
auto argvs = ArgMaxVerts(
|
||||
const auto [v1m, v2m] = ArgMaxVerts(
|
||||
{tri.v0, tri.v1, tri.v2}, {plane.v0, plane.v1, plane.v2, plane.v3});
|
||||
const float3 v1m = std::get<0>(argvs);
|
||||
const float3 v2m = std::get<1>(argvs);
|
||||
|
||||
float3 n12m = v1m - v2m;
|
||||
n12m = n12m / fmaxf(norm(n12m), kEpsilon);
|
||||
|
Loading…
x
Reference in New Issue
Block a user