mirror of
				https://github.com/facebookresearch/pytorch3d.git
				synced 2025-11-04 18:02:14 +08:00 
			
		
		
		
	(new) CUDA IoU for 3D boxes
Summary: CUDA implementation of 3D bounding box overlap calculation. Reviewed By: gkioxari Differential Revision: D31157919 fbshipit-source-id: 5dc89805d01fef2d6779f00a33226131e39c43ed
This commit is contained in:
		
							parent
							
								
									53266ec9ff
								
							
						
					
					
						commit
						ff8d4762f4
					
				
							
								
								
									
										176
									
								
								pytorch3d/csrc/iou_box3d/iou_box3d.cu
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										176
									
								
								pytorch3d/csrc/iou_box3d/iou_box3d.cu
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,176 @@
 | 
			
		||||
/*
 | 
			
		||||
 * Copyright (c) Facebook, Inc. and its affiliates.
 | 
			
		||||
 * All rights reserved.
 | 
			
		||||
 *
 | 
			
		||||
 * This source code is licensed under the BSD-style license found in the
 | 
			
		||||
 * LICENSE file in the root directory of this source tree.
 | 
			
		||||
 */
 | 
			
		||||
 | 
			
		||||
#include <ATen/ATen.h>
 | 
			
		||||
#include <ATen/cuda/CUDAContext.h>
 | 
			
		||||
#include <c10/cuda/CUDAGuard.h>
 | 
			
		||||
#include <math.h>
 | 
			
		||||
#include <stdio.h>
 | 
			
		||||
#include <stdlib.h>
 | 
			
		||||
#include <thrust/device_vector.h>
 | 
			
		||||
#include <thrust/tuple.h>
 | 
			
		||||
#include "iou_box3d/iou_utils.cuh"
 | 
			
		||||
#include "utils/pytorch3d_cutils.h"
 | 
			
		||||
 | 
			
		||||
// Computes, for every pair (n, m), the volume of the intersection of
// boxes1[n] with boxes2[m] and the corresponding IoU.
//
// Launch layout: 1D grid of 1D blocks. The N * M pairs are independent and
// are distributed over the threads with a grid-stride loop, so the result
// does not depend on the launch configuration. No shared memory is used.
//
// Args
//    boxes1: accessor indexed as boxes1[n][vertex][coord]; the helpers read
//      vertices 0..7 and coords 0..2 of each box
//    boxes2: same layout as boxes1, for the second set of boxes
//    vols: output accessor, vols[n][m] receives the intersection volume
//    ious: output accessor, ious[n][m] receives the IoU
//
__global__ void IoUBox3DKernel(
    const at::PackedTensorAccessor64<float, 3, at::RestrictPtrTraits> boxes1,
    const at::PackedTensorAccessor64<float, 3, at::RestrictPtrTraits> boxes2,
    at::PackedTensorAccessor64<float, 2, at::RestrictPtrTraits> vols,
    at::PackedTensorAccessor64<float, 2, at::RestrictPtrTraits> ious) {
  const size_t N = boxes1.size(0);
  const size_t M = boxes2.size(0);

  // Widen before multiplying so the products are computed in size_t rather
  // than in 32-bit unsigned arithmetic.
  const size_t tid =
      static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  const size_t stride = static_cast<size_t>(gridDim.x) * blockDim.x;

  for (size_t i = tid; i < N * M; i += stride) {
    const size_t n = i / M; // box1 index
    const size_t m = i % M; // box2 index

    // Convert to array of structs of face vertices i.e. effectively (F, 3, 3)
    // FaceVerts is a data type defined in iou_utils.cuh
    FaceVerts box1_tris[NUM_TRIS];
    FaceVerts box2_tris[NUM_TRIS];
    GetBoxTris(boxes1[n], box1_tris);
    GetBoxTris(boxes2[m], box2_tris);

    // Calculate the position of the center of the box which is used in
    // several calculations. This requires a tensor as input.
    const float3 box1_center = BoxCenter(boxes1[n]);
    const float3 box2_center = BoxCenter(boxes2[m]);

    // Convert to an array of face vertices
    FaceVerts box1_planes[NUM_PLANES];
    GetBoxPlanes(boxes1[n], box1_planes);
    FaceVerts box2_planes[NUM_PLANES];
    GetBoxPlanes(boxes2[m], box2_planes);

    // Get Box Volumes
    const float box1_vol = BoxVolume(box1_tris, box1_center, NUM_TRIS);
    const float box2_vol = BoxVolume(box2_tris, box2_center, NUM_TRIS);

    // Tris in Box1 intersection with Planes in Box2
    // Initialize box1 intersecting faces. MAX_TRIS is the
    // max faces possible in the intersecting shape.
    // TODO: determine if the value of MAX_TRIS is sufficient or
    // if we should store the max tris for each NxM computation
    // and throw an error if any exceeds the max.
    FaceVerts box1_intersect[MAX_TRIS];
    for (int j = 0; j < NUM_TRIS; ++j) {
      // Initialize the faces from the box
      box1_intersect[j] = box1_tris[j];
    }
    // Get the count of the actual number of faces in the intersecting shape
    int box1_count = BoxIntersections(box2_planes, box2_center, box1_intersect);

    // Tris in Box2 intersection with Planes in Box1
    FaceVerts box2_intersect[MAX_TRIS];
    for (int j = 0; j < NUM_TRIS; ++j) {
      box2_intersect[j] = box2_tris[j];
    }
    const int box2_count =
        BoxIntersections(box1_planes, box1_center, box2_intersect);

    // If there are overlapping regions in Box2, remove any coplanar faces
    if (box2_count > 0) {
      // Identify if any triangles in Box2 are coplanar with Box1
      Keep tri2_keep[MAX_TRIS];
      for (int j = 0; j < MAX_TRIS; ++j) {
        // Only the first box2_count entries describe valid faces
        tri2_keep[j].keep = j < box2_count;
      }
      for (int b1 = 0; b1 < box1_count; ++b1) {
        for (int b2 = 0; b2 < box2_count; ++b2) {
          if (IsCoplanarFace(box1_intersect[b1], box2_intersect[b2])) {
            tri2_keep[b2].keep = false;
          }
        }
      }

      // Keep only the non coplanar triangles in Box2 - add them to the
      // Box1 triangles. box1_intersect holds MAX_TRIS entries, so stop
      // appending once it is full instead of writing past its end.
      for (int b2 = 0; b2 < box2_count && box1_count < MAX_TRIS; ++b2) {
        if (tri2_keep[b2].keep) {
          box1_intersect[box1_count] = box2_intersect[b2];
          // box1_count will determine the total faces in the
          // intersecting shape
          box1_count++;
        }
      }
    }

    // Initialize the vol and iou to 0.0 in case there are no triangles
    // in the intersecting shape.
    float vol = 0.0f;
    float iou = 0.0f;

    // If there are triangles in the intersecting shape
    if (box1_count > 0) {
      // The intersecting shape is a polyhedron made up of the
      // triangular faces that are all now in box1_intersect.
      // Calculate the polyhedron center
      const float3 poly_center = PolyhedronCenter(box1_intersect, box1_count);
      // Compute intersecting polyhedron volume
      vol = BoxVolume(box1_intersect, poly_center, box1_count);
      // Compute IoU
      iou = vol / (box1_vol + box2_vol - vol);
    }

    // Write the volume and IoU to global memory
    vols[n][m] = vol;
    ious[n][m] = iou;
  }
}
 | 
			
		||||
 | 
			
		||||
// Host entry point that launches IoUBox3DKernel on the current CUDA stream.
//
// Args
//    boxes1: (N, 8, 3) tensor of box corner coordinates
//    boxes2: (M, 8, 3) tensor of box corner coordinates
//
// Returns
//    tuple (vols, ious) of (N, M) tensors created with boxes1's options.
//    Both are all zeros when N * M == 0.
//
// Raises (via TORCH_CHECK / at::check*) if the inputs are not on the same
// GPU, do not share a dtype, or are not shaped (*, 8, 3).
//
std::tuple<at::Tensor, at::Tensor> IoUBox3DCuda(
    const at::Tensor& boxes1, // (N, 8, 3)
    const at::Tensor& boxes2) { // (M, 8, 3)
  // Check inputs are on the same device
  at::TensorArg boxes1_t{boxes1, "boxes1", 1}, boxes2_t{boxes2, "boxes2", 2};
  at::CheckedFrom c = "IoUBox3DCuda";
  at::checkAllSameGPU(c, {boxes1_t, boxes2_t});
  at::checkAllSameType(c, {boxes1_t, boxes2_t});

  // Set the device for the kernel launch based on the device of boxes1
  at::cuda::CUDAGuard device_guard(boxes1.device());
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  // The kernel reads coordinates 0..2 of every vertex, so the last dimension
  // must be exactly 3; merely matching between the two inputs is not enough.
  TORCH_CHECK(
      (boxes1.size(2) == 3) && (boxes2.size(2) == 3),
      "Boxes must have shape (8, 3)");

  TORCH_CHECK(
      (boxes2.size(1) == 8) && (boxes1.size(1) == 8),
      "Boxes must have shape (8, 3)");

  const int64_t N = boxes1.size(0);
  const int64_t M = boxes2.size(0);

  auto vols = at::zeros({N, M}, boxes1.options());
  auto ious = at::zeros({N, M}, boxes1.options());

  // Nothing to compute for an empty set of pairs
  if (vols.numel() == 0) {
    AT_CUDA_CHECK(cudaGetLastError());
    return std::make_tuple(vols, ious);
  }

  // Fixed launch configuration; the kernel's grid-stride loop covers all
  // N * M pairs regardless of these values.
  const size_t blocks = 512;
  const size_t threads = 256;

  IoUBox3DKernel<<<blocks, threads, 0, stream>>>(
      boxes1.packed_accessor64<float, 3, at::RestrictPtrTraits>(),
      boxes2.packed_accessor64<float, 3, at::RestrictPtrTraits>(),
      vols.packed_accessor64<float, 2, at::RestrictPtrTraits>(),
      ious.packed_accessor64<float, 2, at::RestrictPtrTraits>());

  // Surface launch-configuration errors
  AT_CUDA_CHECK(cudaGetLastError());

  return std::make_tuple(vols, ious);
}
 | 
			
		||||
@ -26,12 +26,23 @@ std::tuple<at::Tensor, at::Tensor> IoUBox3DCpu(
 | 
			
		||||
    const at::Tensor& boxes1,
 | 
			
		||||
    const at::Tensor& boxes2);
 | 
			
		||||
 | 
			
		||||
// CUDA implementation
 | 
			
		||||
std::tuple<at::Tensor, at::Tensor> IoUBox3DCuda(
 | 
			
		||||
    const at::Tensor& boxes1,
 | 
			
		||||
    const at::Tensor& boxes2);
 | 
			
		||||
 | 
			
		||||
// Implementation which is exposed
 | 
			
		||||
inline std::tuple<at::Tensor, at::Tensor> IoUBox3D(
 | 
			
		||||
    const at::Tensor& boxes1,
 | 
			
		||||
    const at::Tensor& boxes2) {
 | 
			
		||||
  if (boxes1.is_cuda() || boxes2.is_cuda()) {
 | 
			
		||||
    AT_ERROR("GPU support not implemented");
 | 
			
		||||
#ifdef WITH_CUDA
 | 
			
		||||
    CHECK_CUDA(boxes1);
 | 
			
		||||
    CHECK_CUDA(boxes2);
 | 
			
		||||
    return IoUBox3DCuda(boxes1.contiguous(), boxes2.contiguous());
 | 
			
		||||
#else
 | 
			
		||||
    AT_ERROR("Not compiled with GPU support.");
 | 
			
		||||
#endif
 | 
			
		||||
  }
 | 
			
		||||
  return IoUBox3DCpu(boxes1.contiguous(), boxes2.contiguous());
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -79,7 +79,7 @@ std::tuple<at::Tensor, at::Tensor> IoUBox3DCpu(
 | 
			
		||||
        std::fill(tri2_keep.begin(), tri2_keep.end(), 1);
 | 
			
		||||
        for (int b1 = 0; b1 < box1_intersect.size(); ++b1) {
 | 
			
		||||
          for (int b2 = 0; b2 < box2_intersect.size(); ++b2) {
 | 
			
		||||
            bool is_coplanar =
 | 
			
		||||
            const bool is_coplanar =
 | 
			
		||||
                IsCoplanarFace(box1_intersect[b1], box2_intersect[b2]);
 | 
			
		||||
            if (is_coplanar) {
 | 
			
		||||
              tri2_keep[b2] = 0;
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										584
									
								
								pytorch3d/csrc/iou_box3d/iou_utils.cuh
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										584
									
								
								pytorch3d/csrc/iou_box3d/iou_utils.cuh
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,584 @@
 | 
			
		||||
/*
 | 
			
		||||
 * Copyright (c) Facebook, Inc. and its affiliates.
 | 
			
		||||
 * All rights reserved.
 | 
			
		||||
 *
 | 
			
		||||
 * This source code is licensed under the BSD-style license found in the
 | 
			
		||||
 * LICENSE file in the root directory of this source tree.
 | 
			
		||||
 */
 | 
			
		||||
 | 
			
		||||
#include <float.h>
 | 
			
		||||
#include <math.h>
 | 
			
		||||
#include <thrust/device_vector.h>
 | 
			
		||||
#include <cstdio>
 | 
			
		||||
#include "utils/float_math.cuh"
 | 
			
		||||
#include "utils/geometry_utils.cuh"
 | 
			
		||||
 | 
			
		||||
/*
 | 
			
		||||
_PLANES and _TRIS define the 4- and 3-connectivity
 | 
			
		||||
of the 8 box corners.
 | 
			
		||||
_PLANES gives the quad faces of the 3D box
 | 
			
		||||
_TRIS gives the triangle faces of the 3D box
 | 
			
		||||
*/
 | 
			
		||||
// Number of quad faces of a box
constexpr int NUM_PLANES = 6;
// Number of triangle faces of a box
constexpr int NUM_TRIS = 12;
// Capacity of the face arrays used when initializing the faces
// in the intersecting shape
constexpr int MAX_TRIS = 100;
 | 
			
		||||
 | 
			
		||||
// Data types for the face vertices and the face vertex indices.
// Arrays of these structs hold the data for each box and for the
// intersecting triangles.

// Coordinates of the vertices of one face (triangle or quad).
struct FaceVerts {
  float3 v0;
  float3 v1;
  float3 v2;
  float3 v3; // Can be empty for triangles
};
 | 
			
		||||
 | 
			
		||||
// Indices of the box corners that make up one face (triangle or quad).
struct FaceVertsIdx {
  int v0;
  int v1;
  int v2;
  int v3; // Can be empty for triangles
};
 | 
			
		||||
 | 
			
		||||
// Per-face flag used when deciding which faces to
// keep, i.e. those that are not coplanar
struct Keep {
  bool keep;
};
 | 
			
		||||
 | 
			
		||||
// Corner indices of the NUM_PLANES quad faces of a box
__device__ FaceVertsIdx _PLANES[NUM_PLANES] = {
    {0, 1, 2, 3},
    {3, 2, 6, 7},
    {0, 1, 5, 4},
    {0, 3, 7, 4},
    {1, 5, 6, 2},
    {4, 5, 6, 7},
};
 | 
			
		||||
// Corner indices of the NUM_TRIS triangle faces of a box, two per quad
// face. The fourth index is unused for triangles and left at 0.
__device__ FaceVertsIdx _TRIS[NUM_TRIS] = {
    // corners {0, 1, 2, 3}
    {0, 1, 2, 0},
    {0, 3, 2, 0},
    // corners {4, 5, 6, 7}
    {4, 5, 6, 0},
    {4, 6, 7, 0},
    // corners {1, 5, 6, 2}
    {1, 5, 6, 0},
    {1, 6, 2, 0},
    // corners {0, 3, 7, 4}
    {0, 4, 7, 0},
    {0, 7, 3, 0},
    // corners {3, 2, 6, 7}
    {3, 2, 6, 0},
    {3, 6, 7, 0},
    // corners {0, 1, 5, 4}
    {0, 1, 5, 0},
    {0, 4, 5, 0},
};
 | 
			
		||||
 | 
			
		||||
// Gather the vertex coordinates of the NUM_TRIS triangle faces of a box,
// following the corner indices listed in _TRIS.
//
// Args
//    box: (8, 3) tensor accessor for the box vertices
//    box_tris: Array of structs of type FaceVerts,
//      effectively (F, 3, 3), that receives the coordinates of the
//      verts of each face.
//
// Returns: None (output saved to box_tris)
//
template <typename Box, typename BoxTris>
__device__ inline void GetBoxTris(const Box& box, BoxTris& box_tris) {
  for (int t = 0; t < NUM_TRIS; ++t) {
    // Corner indices of the t-th triangle
    const int i0 = _TRIS[t].v0;
    const int i1 = _TRIS[t].v1;
    const int i2 = _TRIS[t].v2;

    const float3 a = make_float3(box[i0][0], box[i0][1], box[i0][2]);
    const float3 b = make_float3(box[i1][0], box[i1][1], box[i1][2]);
    const float3 c = make_float3(box[i2][0], box[i2][1], box[i2][2]);
    box_tris[t] = {a, b, c};
  }
}
 | 
			
		||||
 | 
			
		||||
// Gather the vertex coordinates of the NUM_PLANES quad faces of a box,
// following the corner indices listed in _PLANES.
//
// Args
//    box: (8, 3) tensor accessor for the box vertices
//    box_planes: Array of structs of type FaceVerts, effectively (P, 4, 3),
//      that receives the coordinates of the four corners of each plane
//
// Returns: None (output saved to box_planes)
//
template <typename Box, typename FaceVertsBoxPlanes>
__device__ inline void GetBoxPlanes(
    const Box& box,
    FaceVertsBoxPlanes& box_planes) {
  for (int t = 0; t < NUM_PLANES; ++t) {
    // Corner indices of the t-th quad
    const int i0 = _PLANES[t].v0;
    const int i1 = _PLANES[t].v1;
    const int i2 = _PLANES[t].v2;
    const int i3 = _PLANES[t].v3;

    const float3 a = make_float3(box[i0][0], box[i0][1], box[i0][2]);
    const float3 b = make_float3(box[i1][0], box[i1][1], box[i1][2]);
    const float3 c = make_float3(box[i2][0], box[i2][1], box[i2][2]);
    const float3 d = make_float3(box[i3][0], box[i3][1], box[i3][2]);
    box_planes[t] = {a, b, c, d};
  }
}
 | 
			
		||||
 | 
			
		||||
// Normal of the face defined by the vertices (v0, v1, v2).
// With e0 = v1 - v0 and e1 = v2 - v0, the normal is the cross product
// of e0 and e1, divided by its norm (clamped below by kEpsilon).
//
// Args
//    v0, v1, v2: float3 coordinates of the vertices of the face
//
// Returns
//    float3: normal for the face
//
__device__ inline float3
FaceNormal(const float3 v0, const float3 v1, const float3 v2) {
  const float3 e0 = v1 - v0;
  const float3 e1 = v2 - v0;
  const float3 raw = cross(e0, e1);
  // Clamp the length so a degenerate face does not divide by zero
  return raw / fmaxf(norm(raw), kEpsilon);
}
 | 
			
		||||
 | 
			
		||||
// Normal of a box plane, oriented relative to the center of the box the
// plane belongs to.
//
// Args
//    plane: float3 coordinates of the vertices of the plane
//    center: float3 coordinates of the center of the box from
//        which the plane originated
//
// Returns
//    float3: normal for the plane such that it points towards
//      the center of the box
//
template <typename FaceVertsPlane>
__device__ inline float3 PlaneNormalDirection(
    const FaceVertsPlane& plane,
    const float3& center) {
  // Three verts are enough to define the plane; v3 is not needed.
  const float3 n = FaceNormal(plane.v0, plane.v1, plane.v2);

  // Writing center = v0 + a * e0 + b * e1 + c * n, with <e0, n> = 0 and
  // <e1, n> = 0, the component along the normal is
  // c = <center - v0, n>.
  const float c = dot((center - plane.v0), n);

  // When c is below kEpsilon the normal is flipped so that it
  // points "inside".
  return (c < kEpsilon) ? (-1.0f * n) : n;
}
 | 
			
		||||
 | 
			
		||||
// Calculate the volume of the box by summing the volume of
// each of the tetrahedrons formed with a triangle face and
// the box centroid.
//
// Args
//    box_tris: array of FaceVerts with the float3 coordinates of the
//       vertices of each of the triangles in the box
//    box_center: float3 coordinates of the center of the box
//    num_tris: number of leading entries of box_tris to use
//
// Returns
//    float: volume of the box (0 when num_tris <= 0)
//
template <typename BoxTris>
__device__ inline float BoxVolume(
    const BoxTris& box_tris,
    const float3& box_center,
    const int num_tris) {
  float box_vol = 0.0f;
  // Iterate through each triangle, calculate the volume of the
  // tetrahedron formed with the box_center and sum them
  for (int t = 0; t < num_tris; ++t) {
    // Express the triangle relative to the center
    const float3 v0 = box_tris[t].v0 - box_center;
    const float3 v1 = box_tris[t].v1 - box_center;
    const float3 v2 = box_tris[t].v2 - box_center;

    // Scalar triple product: six times the signed tetrahedron volume
    const float triple = dot(v0, cross(v1, v2));
    // Single-precision abs and divisor keep the arithmetic in float
    // instead of promoting to double.
    const float vol = fabsf(triple) / 6.0f;
    box_vol = box_vol + vol;
  }
  return box_vol;
}
 | 
			
		||||
 | 
			
		||||
// Compute the box center as the mean of the verts
//
// Args
//    box_verts: (8, 3) tensor of the corner vertices of the box
//
// Returns
//    float3: coordinates of the center of the box
//
template <typename Box>
__device__ inline float3 BoxCenter(const Box box_verts) {
  const int num_verts = box_verts.size(0); // Should be 8

  // Accumulate each coordinate over all vertices
  float sum_x = 0.0f;
  float sum_y = 0.0f;
  float sum_z = 0.0f;
  for (int v = 0; v < num_verts; ++v) {
    sum_x = sum_x + box_verts[v][0];
    sum_y = sum_y + box_verts[v][1];
    sum_z = sum_z + box_verts[v][2];
  }

  // The center is the mean of all the vertex positions
  return make_float3(
      sum_x / num_verts, sum_y / num_verts, sum_z / num_verts);
}
 | 
			
		||||
 | 
			
		||||
// Compute the polyhedron center as the mean of the face centers
// of the triangle faces
//
// Args
//    tris: array of FaceVerts with the float3 coordinates of the
//       vertices of each of the triangles in the polyhedron
//    num_tris: number of leading entries of tris to use; the result is
//       divided by this value, so it is expected to be > 0
//
// Returns
//    float3: coordinates of the center of the polyhedron
//
template <typename Tris>
__device__ inline float3 PolyhedronCenter(
    const Tris& tris,
    const int num_tris) {
  float x = 0.0f;
  float y = 0.0f;
  float z = 0.0f;

  // Find the center point of each face and accumulate it
  for (int t = 0; t < num_tris; ++t) {
    const float3 v0 = tris[t].v0;
    const float3 v1 = tris[t].v1;
    const float3 v2 = tris[t].v2;
    // Float divisor keeps the arithmetic in single precision
    x = x + (v0.x + v1.x + v2.x) / 3.0f;
    y = y + (v0.y + v1.y + v2.y) / 3.0f;
    z = z + (v0.z + v1.z + v2.z) / 3.0f;
  }

  // Take the mean of the centers of all faces
  x = x / num_tris;
  y = y / num_tris;
  z = z / num_tris;

  return make_float3(x, y, z);
}
 | 
			
		||||
 | 
			
		||||
// Compute a boolean indicator for whether a point
// is inside a plane, where inside refers to whether
// or not the point has a component in the
// normal direction of the plane.
//
// Args
//    plane: FaceVerts with the float3 coordinates of the
//       vertices of the plane; only v0 is read
//    normal: float3 of the direction of the plane normal
//    point: float3 of the position of the point of interest
//
// Returns
//    bool: whether or not the point is inside the plane
//
__device__ inline bool
IsInside(const FaceVerts& plane, const float3& normal, const float3& point) {
  // Every point p can be written as p = v0 + a e0 + b e1 + c n.
  // Since <e0, n> = 0 and <e1, n> = 0, the component along the normal
  // reduces to c = <point - v0, n>.
  const float along_normal = dot((point - plane.v0), normal);

  // Points within kEpsilon below the plane still count as inside
  return along_normal > -1.0f * kEpsilon;
}
 | 
			
		||||
 | 
			
		||||
// Find the point of intersection between a plane
// and a line given by the end points (p0, p1)
//
// Args
//    plane: FaceVerts with the float3 coordinates of the
//       vertices of the plane; only v0 is read
//    normal: float3 of the direction of the plane normal
//    p0, p1: float3 of the start and end point of the line
//
// Returns
//    float3: position of the intersection point
//
__device__ inline float3 PlaneEdgeIntersection(
    const FaceVerts& plane,
    const float3& normal,
    const float3& p0,
    const float3& p1) {
  // Direction of the line from p0 to p1
  const float3 edge = p1 - p0;

  // The intersection is parametrized as p = p0 + a * edge, and a is
  // chosen such that p lies on the plane: <p - v0, n> = 0
  const float numer = dot(-1.0f * (p0 - plane.v0), normal);
  const float denom = dot(edge, normal);
  const float a = numer / denom;

  return p0 + a * edge;
}
 | 
			
		||||
 | 
			
		||||
// Clip a triangle that has exactly one vertex outside the plane.
// The part inside the plane is a quadrilateral bounded by the two
// intersection points with the plane; it is divided into two triangles.
//
// Args
//    plane: FaceVerts with the float3 coordinates of the
//        vertices of the plane
//    normal: float3 of the direction of the plane normal
//    vout: float3 of the point in the triangle which is outside
//       the plane
//    vin1, vin2: float3 of the points in the triangle which are
//        inside the plane
//    face_verts_out: Array of structs of type FaceVerts that receives
//       the coordinates of the new triangle faces formed after clipping.
//       All triangles are now "inside" the plane.
//
// Returns:
//    count: (int) number of new faces after clipping the triangle
//      i.e. the valid faces which have been saved
//      to face_verts_out (always 2)
//
template <typename FaceVertsBox>
__device__ inline int ClipTriByPlaneOneOut(
    const FaceVerts& plane,
    const float3& normal,
    const float3& vout,
    const float3& vin1,
    const float3& vin2,
    FaceVertsBox& face_verts_out) {
  // Where the edges (vin1, vout) and (vin2, vout) cross the plane
  const float3 cut1 = PlaneEdgeIntersection(plane, normal, vin1, vout);
  const float3 cut2 = PlaneEdgeIntersection(plane, normal, vin2, vout);

  // Split the quad (vin1, cut1, cut2, vin2) along the vin1-cut2 diagonal
  face_verts_out[0] = {vin1, cut1, cut2};
  face_verts_out[1] = {vin1, cut2, vin2};

  const int num_faces = 2;
  return num_faces;
}
 | 
			
		||||
 | 
			
		||||
// Clip a triangle that has exactly two vertices outside the plane.
// The part inside the plane is a smaller triangle formed by the inside
// vertex and the two intersection points with the plane.
//
// Args
//    plane: FaceVerts with the float3 coordinates of the
//       vertices of the plane
//    normal: float3 of the direction of the plane normal
//    vout1, vout2: float3 of the points in the triangle which are
//       outside the plane
//    vin: float3 of the point in the triangle which is inside
//        the plane
//    face_verts_out: Array of structs of type FaceVerts that receives
//       the coordinates of the new triangle face formed after clipping.
//       All triangles are now "inside" the plane.
//
// Returns:
//    count: (int) number of new faces after clipping the triangle
//      i.e. the valid faces which have been saved
//      to face_verts_out (always 1)
//
template <typename FaceVertsBox>
__device__ inline int ClipTriByPlaneTwoOut(
    const FaceVerts& plane,
    const float3& normal,
    const float3& vout1,
    const float3& vout2,
    const float3& vin,
    FaceVertsBox& face_verts_out) {
  // Where the edges (vin, vout1) and (vin, vout2) cross the plane
  const float3 cut1 = PlaneEdgeIntersection(plane, normal, vin, vout1);
  const float3 cut2 = PlaneEdgeIntersection(plane, normal, vin, vout2);

  face_verts_out[0] = {vin, cut1, cut2};

  const int num_faces = 1;
  return num_faces;
}
 | 
			
		||||
 | 
			
		||||
// Clip a single triangle face against a plane, creating new
// triangle faces where necessary.
//
// Args
//    plane: FaceVerts struct passed, together with `normal`, to
//       IsInside and to the ClipTriByPlaneOneOut / ClipTriByPlaneTwoOut
//       helpers to describe the clipping plane
//    tri: FaceVerts struct with the vertex coordinates (v0, v1, v2)
//       of the triangle face to clip
//    normal: float3 of the direction of the plane normal
//    face_verts_out: indexable container of FaceVerts which receives
//       the triangle faces formed after clipping, starting at index 0.
//       It is left untouched when the triangle is entirely outside.
//
// Returns:
//    count: (int) number of valid faces which have been saved
//      to face_verts_out: 1 if all vertices are inside, 0 if all
//      vertices are outside, otherwise the count reported by
//      ClipTriByPlaneOneOut / ClipTriByPlaneTwoOut.
//
template <typename FaceVertsBox>
__device__ inline int ClipTriByPlane(
    const FaceVerts& plane,
    const FaceVerts& tri,
    const float3& normal,
    FaceVertsBox& face_verts_out) {
  // Get Triangle vertices
  const float3 v0 = tri.v0;
  const float3 v1 = tri.v1;
  const float3 v2 = tri.v2;

  // Check each of the triangle vertices to see if it is inside the plane
  const bool isin0 = IsInside(plane, normal, v0);
  const bool isin1 = IsInside(plane, normal, v1);
  const bool isin2 = IsInside(plane, normal, v2);

  // All in
  if (isin0 && isin1 && isin2) {
    // Return input vertices
    face_verts_out[0] = {v0, v1, v2};
    return 1;
  }

  // All out
  if (!isin0 && !isin1 && !isin2) {
    return 0;
  }

  // One vert out: the outside vertex is passed first, followed by the
  // two inside vertices.
  if (isin0 && isin1 && !isin2) {
    return ClipTriByPlaneOneOut(plane, normal, v2, v0, v1, face_verts_out);
  }
  if (isin0 && !isin1 && isin2) {
    return ClipTriByPlaneOneOut(plane, normal, v1, v0, v2, face_verts_out);
  }
  if (!isin0 && isin1 && isin2) {
    return ClipTriByPlaneOneOut(plane, normal, v0, v1, v2, face_verts_out);
  }

  // Two verts out: the two outside vertices are passed first, followed
  // by the single inside vertex.
  if (isin0 && !isin1 && !isin2) {
    return ClipTriByPlaneTwoOut(plane, normal, v1, v2, v0, face_verts_out);
  }
  if (!isin0 && !isin1 && isin2) {
    return ClipTriByPlaneTwoOut(plane, normal, v0, v1, v2, face_verts_out);
  }
  if (!isin0 && isin1 && !isin2) {
    return ClipTriByPlaneTwoOut(plane, normal, v0, v2, v1, face_verts_out);
  }

  // All eight inside/outside combinations are handled above, so this
  // is not expected to be reached; return empty.
  return 0;
}
 | 
			
		||||
 | 
			
		||||
// Decide whether two triangle faces are coplanar.
//
// Every vertex of tri2 is tested against the plane that passes through
// tri1.v0 and whose normal is FaceNormal(tri1.v0, tri1.v1, tri1.v2):
// a vertex counts as on the plane when the absolute value of its offset
// from tri1.v0 dotted with that normal is below kEpsilon.
//
// Args
//    tri1: FaceVerts struct of the face which defines the plane
//    tri2: FaceVerts struct of the face whose vertices are tested
//
// Returns
//    bool: true only if all three vertices of tri2 pass the test
//
__device__ inline bool IsCoplanarFace(
    const FaceVerts& tri1,
    const FaceVerts& tri2) {
  // Plane of face 1: anchor point and normal
  const float3 anchor = tri1.v0;
  const float3 n1 = FaceNormal(tri1.v0, tri1.v1, tri1.v2);

  // Test each vertex of face 2 against that plane
  const bool on_plane0 = abs(dot(tri2.v0 - anchor, n1)) < kEpsilon;
  const bool on_plane1 = abs(dot(tri2.v1 - anchor, n1)) < kEpsilon;
  const bool on_plane2 = abs(dot(tri2.v2 - anchor, n1)) < kEpsilon;

  return on_plane0 && on_plane1 && on_plane2;
}
 | 
			
		||||
 | 
			
		||||
// Clip the triangles stored in face_verts_out successively against
// each of the NUM_PLANES planes, keeping only the pieces that
// ClipTriByPlane reports as valid.
//
// Args
//    planes: indexable container of FaceVerts structs, one per plane
//       of the clipping box
//    center: float3 passed to PlaneNormalDirection together with each
//       plane to obtain the normal direction used for clipping
//    face_verts_out: indexable container of FaceVerts structs. On entry
//       the first NUM_TRIS entries are read as the triangles to clip;
//       on exit the first `count` entries hold the clipped triangles.
//
// Returns:
//    count: (int) number of valid faces which have been saved
//      to face_verts_out after clipping against all planes
//
// NOTE(review): nothing here checks that the number of kept triangles
// stays within MAX_TRIS or within the capacity of face_verts_out;
// confirm that this bound holds for all inputs.
//
template <typename FaceVertsPlane, typename FaceVertsBox>
__device__ inline int BoxIntersections(
    const FaceVertsPlane& planes,
    const float3& center,
    FaceVertsBox& face_verts_out) {
  // Number of valid triangles currently held in face_verts_out
  int num_valid = NUM_TRIS;

  for (int plane_idx = 0; plane_idx < NUM_PLANES; ++plane_idx) {
    // Normal direction of the current plane
    const float3 plane_dir = PlaneNormalDirection(planes[plane_idx], center);

    // Scratch buffer collecting the triangles which survive this plane
    FaceVerts kept[MAX_TRIS];
    int num_kept = 0;

    for (int tri_idx = 0; tri_idx < num_valid; ++tri_idx) {
      // One clipped triangle yields at most two triangles
      FaceVerts pieces[2];
      const int num_pieces = ClipTriByPlane(
          planes[plane_idx], face_verts_out[tri_idx], plane_dir, pieces);
      for (int k = 0; k < num_pieces; ++k) {
        kept[num_kept] = pieces[k];
        ++num_kept;
      }
    }

    // Write the surviving triangles back for the next plane
    for (int k = 0; k < num_kept; ++k) {
      face_verts_out[k] = kept[k];
    }
    num_valid = num_kept;
  }

  return num_valid;
}
 | 
			
		||||
@ -9,6 +9,7 @@ from .cameras_alignment import corresponding_cameras_alignment
 | 
			
		||||
from .cubify import cubify
 | 
			
		||||
from .graph_conv import GraphConv
 | 
			
		||||
from .interp_face_attrs import interpolate_face_attributes
 | 
			
		||||
from .iou_box3d import box3d_overlap
 | 
			
		||||
from .knn import knn_gather, knn_points
 | 
			
		||||
from .laplacian_matrices import cot_laplacian, laplacian, norm_laplacian
 | 
			
		||||
from .mesh_face_areas_normals import mesh_face_areas_normals
 | 
			
		||||
 | 
			
		||||
@ -7,10 +7,68 @@
 | 
			
		||||
from typing import Tuple
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
import torch.nn.functional as F
 | 
			
		||||
from pytorch3d import _C
 | 
			
		||||
from torch.autograd import Function
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# -------------------------------------------------- #
 | 
			
		||||
#                  CONSTANTS                         #
 | 
			
		||||
# -------------------------------------------------- #
 | 
			
		||||
"""
 | 
			
		||||
_box_planes and _box_triangles define the 4- and 3-connectivity
 | 
			
		||||
of the 8 box corners.
 | 
			
		||||
_box_planes gives the quad faces of the 3D box
 | 
			
		||||
_box_triangles gives the triangle faces of the 3D box
 | 
			
		||||
"""
 | 
			
		||||
_box_planes = [
 | 
			
		||||
    [0, 1, 2, 3],
 | 
			
		||||
    [3, 2, 6, 7],
 | 
			
		||||
    [0, 1, 5, 4],
 | 
			
		||||
    [0, 3, 7, 4],
 | 
			
		||||
    [1, 2, 6, 5],
 | 
			
		||||
    [4, 5, 6, 7],
 | 
			
		||||
]
 | 
			
		||||
_box_triangles = [
 | 
			
		||||
    [0, 1, 2],
 | 
			
		||||
    [0, 3, 2],
 | 
			
		||||
    [4, 5, 6],
 | 
			
		||||
    [4, 6, 7],
 | 
			
		||||
    [1, 5, 6],
 | 
			
		||||
    [1, 6, 2],
 | 
			
		||||
    [0, 4, 7],
 | 
			
		||||
    [0, 7, 3],
 | 
			
		||||
    [3, 2, 6],
 | 
			
		||||
    [3, 6, 7],
 | 
			
		||||
    [0, 1, 5],
 | 
			
		||||
    [0, 4, 5],
 | 
			
		||||
]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _check_coplanar(boxes: torch.Tensor, eps: float = 1e-5) -> None:
 | 
			
		||||
    faces = torch.tensor(_box_planes, dtype=torch.int64, device=boxes.device)
 | 
			
		||||
    # pyre-fixme[16]: `boxes` has no attribute `index_select`.
 | 
			
		||||
    verts = boxes.index_select(index=faces.view(-1), dim=1)
 | 
			
		||||
    B = boxes.shape[0]
 | 
			
		||||
    P, V = faces.shape
 | 
			
		||||
    # (B, P, 4, 3) -> (B, P, 3)
 | 
			
		||||
    v0, v1, v2, v3 = verts.reshape(B, P, V, 3).unbind(2)
 | 
			
		||||
 | 
			
		||||
    # Compute the normal
 | 
			
		||||
    e0 = F.normalize(v1 - v0, dim=-1)
 | 
			
		||||
    e1 = F.normalize(v2 - v0, dim=-1)
 | 
			
		||||
    normal = F.normalize(torch.cross(e0, e1, dim=-1), dim=-1)
 | 
			
		||||
 | 
			
		||||
    # Check the fourth vertex is also on the same plane
 | 
			
		||||
    mat1 = (v3 - v0).view(B, 1, -1)  # (B, 1, P*3)
 | 
			
		||||
    mat2 = normal.view(B, -1, 1)  # (B, P*3, 1)
 | 
			
		||||
    if not (mat1.bmm(mat2).abs() < eps).all().item():
 | 
			
		||||
        msg = "Plane vertices are not coplanar"
 | 
			
		||||
        raise ValueError(msg)
 | 
			
		||||
 | 
			
		||||
    return
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class _box3d_overlap(Function):
 | 
			
		||||
    """
 | 
			
		||||
    Torch autograd Function wrapper for box3d_overlap C++/CUDA implementations.
 | 
			
		||||
@ -35,6 +93,7 @@ def box3d_overlap(
 | 
			
		||||
) -> Tuple[torch.Tensor, torch.Tensor]:
 | 
			
		||||
    """
 | 
			
		||||
    Computes the intersection of 3D boxes1 and boxes2.
 | 
			
		||||
 | 
			
		||||
    Inputs boxes1, boxes2 are tensors of shape (B, 8, 3)
 | 
			
		||||
    (where B doesn't have to be the same for boxes1 and boxes2),
 | 
			
		||||
    containing the 8 corners of the boxes, as follows:
 | 
			
		||||
@ -47,6 +106,25 @@ def box3d_overlap(
 | 
			
		||||
            ` .   |     ` . |
 | 
			
		||||
            (3) ` +---------+ (2)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    NOTE: Throughout this implementation, we assume that boxes
 | 
			
		||||
    are defined by their 8 corners exactly in the order specified in the
 | 
			
		||||
    diagram above for the function to give correct results. In addition
 | 
			
		||||
    the vertices on each plane must be coplanar.
 | 
			
		||||
    As an alternative to the diagram, this is a unit bounding
 | 
			
		||||
    box which has the correct vertex ordering:
 | 
			
		||||
 | 
			
		||||
    box_corner_vertices = [
 | 
			
		||||
        [0, 0, 0],
 | 
			
		||||
        [1, 0, 0],
 | 
			
		||||
        [1, 1, 0],
 | 
			
		||||
        [0, 1, 0],
 | 
			
		||||
        [0, 0, 1],
 | 
			
		||||
        [1, 0, 1],
 | 
			
		||||
        [1, 1, 1],
 | 
			
		||||
        [0, 1, 1],
 | 
			
		||||
    ]
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
        boxes1: tensor of shape (N, 8, 3) of the coordinates of the 1st boxes
 | 
			
		||||
        boxes2: tensor of shape (M, 8, 3) of the coordinates of the 2nd boxes
 | 
			
		||||
@ -58,6 +136,9 @@ def box3d_overlap(
 | 
			
		||||
    if not all((8, 3) == box.shape[1:] for box in [boxes1, boxes2]):
 | 
			
		||||
        raise ValueError("Each box in the batch must be of shape (8, 3)")
 | 
			
		||||
 | 
			
		||||
    _check_coplanar(boxes1)
 | 
			
		||||
    _check_coplanar(boxes2)
 | 
			
		||||
 | 
			
		||||
    # pyre-fixme[16]: `_box3d_overlap` has no attribute `apply`.
 | 
			
		||||
    vol, iou = _box3d_overlap.apply(boxes1, boxes2)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -11,25 +11,42 @@ from test_iou_box3d import TestIoU3D
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def bm_iou_box3d() -> None:
 | 
			
		||||
    N = [1, 4, 8, 16]
 | 
			
		||||
    num_samples = [2000, 5000, 10000, 20000]
 | 
			
		||||
    # Realistic use cases
 | 
			
		||||
    N = [30, 100]
 | 
			
		||||
    M = [5, 10, 100]
 | 
			
		||||
    kwargs_list = []
 | 
			
		||||
    test_cases = product(N, M)
 | 
			
		||||
    for case in test_cases:
 | 
			
		||||
        n, m = case
 | 
			
		||||
        kwargs_list.append({"N": n, "M": m, "device": "cuda:0"})
 | 
			
		||||
    benchmark(TestIoU3D.iou, "3D_IOU", kwargs_list, warmup_iters=1)
 | 
			
		||||
 | 
			
		||||
    # Comparison of C++/CUDA
 | 
			
		||||
    kwargs_list = []
 | 
			
		||||
    N = [1, 4, 8, 16]
 | 
			
		||||
    devices = ["cpu", "cuda:0"]
 | 
			
		||||
    test_cases = product(N, N, devices)
 | 
			
		||||
    for case in test_cases:
 | 
			
		||||
        n, m, d = case
 | 
			
		||||
        kwargs_list.append({"N": n, "M": m, "device": d})
 | 
			
		||||
    benchmark(TestIoU3D.iou, "3D_IOU", kwargs_list, warmup_iters=1)
 | 
			
		||||
 | 
			
		||||
    # Naive PyTorch
 | 
			
		||||
    N = [1, 4]
 | 
			
		||||
    kwargs_list = []
 | 
			
		||||
    test_cases = product(N, N)
 | 
			
		||||
    for case in test_cases:
 | 
			
		||||
        n, m = case
 | 
			
		||||
        kwargs_list.append({"N": n, "M": m, "device": "cuda:0"})
 | 
			
		||||
 | 
			
		||||
    benchmark(TestIoU3D.iou_naive, "3D_IOU_NAIVE", kwargs_list, warmup_iters=1)
 | 
			
		||||
 | 
			
		||||
    [k.update({"device": "cpu"}) for k in kwargs_list]
 | 
			
		||||
    benchmark(TestIoU3D.iou, "3D_IOU", kwargs_list, warmup_iters=1)
 | 
			
		||||
 | 
			
		||||
    # Sampling based method
 | 
			
		||||
    num_samples = [2000, 5000]
 | 
			
		||||
    kwargs_list = []
 | 
			
		||||
    test_cases = product([1, 4], [1, 4], num_samples)
 | 
			
		||||
    test_cases = product(N, N, num_samples)
 | 
			
		||||
    for case in test_cases:
 | 
			
		||||
        n, m, s = case
 | 
			
		||||
        kwargs_list.append({"N": n, "M": m, "num_samples": s})
 | 
			
		||||
        kwargs_list.append({"N": n, "M": m, "num_samples": s, "device": "cuda:0"})
 | 
			
		||||
    benchmark(TestIoU3D.iou_sampling, "3D_IOU_SAMPLING", kwargs_list, warmup_iters=1)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										
											BIN
										
									
								
								tests/data/objectron_vols_ious.pt
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										
											BIN
										
									
								
								tests/data/objectron_vols_ious.pt
									
									
									
									
									
										Normal file
									
								
							
										
											Binary file not shown.
										
									
								
							@ -10,13 +10,28 @@ from typing import List, Tuple, Union
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
import torch.nn.functional as F
 | 
			
		||||
from common_testing import TestCaseMixin
 | 
			
		||||
from common_testing import TestCaseMixin, get_random_cuda_device, get_tests_dir
 | 
			
		||||
from pytorch3d.io import save_obj
 | 
			
		||||
 | 
			
		||||
from pytorch3d.ops.iou_box3d import box3d_overlap
 | 
			
		||||
from pytorch3d.ops.iou_box3d import _box_planes, _box_triangles, box3d_overlap
 | 
			
		||||
from pytorch3d.transforms.rotation_conversions import random_rotation
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
OBJECTRON_TO_PYTORCH3D_FACE_IDX = [0, 4, 6, 2, 1, 5, 7, 3]
 | 
			
		||||
DATA_DIR = get_tests_dir() / "data"
 | 
			
		||||
DEBUG = False
 | 
			
		||||
 | 
			
		||||
UNIT_BOX = [
 | 
			
		||||
    [0, 0, 0],
 | 
			
		||||
    [1, 0, 0],
 | 
			
		||||
    [1, 1, 0],
 | 
			
		||||
    [0, 1, 0],
 | 
			
		||||
    [0, 0, 1],
 | 
			
		||||
    [1, 0, 1],
 | 
			
		||||
    [1, 1, 1],
 | 
			
		||||
    [0, 1, 1],
 | 
			
		||||
]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TestIoU3D(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
    def setUp(self) -> None:
 | 
			
		||||
        super().setUp()
 | 
			
		||||
@ -78,16 +93,7 @@ class TestIoU3D(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
    def _test_iou(self, overlap_fn, device):
 | 
			
		||||
 | 
			
		||||
        box1 = torch.tensor(
 | 
			
		||||
            [
 | 
			
		||||
                [0, 0, 0],
 | 
			
		||||
                [1, 0, 0],
 | 
			
		||||
                [1, 1, 0],
 | 
			
		||||
                [0, 1, 0],
 | 
			
		||||
                [0, 0, 1],
 | 
			
		||||
                [1, 0, 1],
 | 
			
		||||
                [1, 1, 1],
 | 
			
		||||
                [0, 1, 1],
 | 
			
		||||
            ],
 | 
			
		||||
            UNIT_BOX,
 | 
			
		||||
            dtype=torch.float32,
 | 
			
		||||
            device=device,
 | 
			
		||||
        )
 | 
			
		||||
@ -126,6 +132,10 @@ class TestIoU3D(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
            ),
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        # Also check IoU is 1 when computing overlap with the same shifted box
 | 
			
		||||
        vol, iou = overlap_fn(box2[None], box2[None])
 | 
			
		||||
        self.assertClose(iou, torch.tensor([[1.0]], device=vol.device, dtype=vol.dtype))
 | 
			
		||||
 | 
			
		||||
        # 5th test
 | 
			
		||||
        ddx, ddy, ddz = random.random(), random.random(), random.random()
 | 
			
		||||
        box2 = box1 + torch.tensor([[ddx, ddy, ddz]], device=device)
 | 
			
		||||
@ -207,15 +217,15 @@ class TestIoU3D(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
        # create box1
 | 
			
		||||
        ctrs = torch.rand((2, 3), device=device)
 | 
			
		||||
        whl = torch.rand((2, 3), device=device) * 10.0 + 1.0
 | 
			
		||||
        # box1 & box2
 | 
			
		||||
        box1 = self.create_box(ctrs[0], whl[0])
 | 
			
		||||
        box2 = self.create_box(ctrs[1], whl[1])
 | 
			
		||||
        # box8a & box8b
 | 
			
		||||
        box8a = self.create_box(ctrs[0], whl[0])
 | 
			
		||||
        box8b = self.create_box(ctrs[1], whl[1])
 | 
			
		||||
        RR1 = random_rotation(dtype=torch.float32, device=device)
 | 
			
		||||
        TT1 = torch.rand((1, 3), dtype=torch.float32, device=device)
 | 
			
		||||
        RR2 = random_rotation(dtype=torch.float32, device=device)
 | 
			
		||||
        TT2 = torch.rand((1, 3), dtype=torch.float32, device=device)
 | 
			
		||||
        box1r = box1 @ RR1.transpose(0, 1) + TT1
 | 
			
		||||
        box2r = box2 @ RR2.transpose(0, 1) + TT2
 | 
			
		||||
        box1r = box8a @ RR1.transpose(0, 1) + TT1
 | 
			
		||||
        box2r = box8b @ RR2.transpose(0, 1) + TT2
 | 
			
		||||
        vol, iou = overlap_fn(box1r[None], box2r[None])
 | 
			
		||||
        iou_sampling = self._box3d_overlap_sampling_batched(
 | 
			
		||||
            box1r[None], box2r[None], num_samples=10000
 | 
			
		||||
@ -229,27 +239,90 @@ class TestIoU3D(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
        self.assertClose(vol, torch.tensor([[0.0]], device=vol.device, dtype=vol.dtype))
 | 
			
		||||
        self.assertClose(iou, torch.tensor([[0.0]], device=vol.device, dtype=vol.dtype))
 | 
			
		||||
 | 
			
		||||
        # 10th test: Non coplanar verts in a plane
 | 
			
		||||
        box10 = box1 + torch.rand((8, 3), dtype=torch.float32, device=device)
 | 
			
		||||
        msg = "Plane vertices are not coplanar"
 | 
			
		||||
        with self.assertRaisesRegex(ValueError, msg):
 | 
			
		||||
            overlap_fn(box10[None], box10[None])
 | 
			
		||||
 | 
			
		||||
        # 11th test: Skewed bounding boxes but all verts are coplanar
 | 
			
		||||
        box_skew_1 = torch.tensor(
 | 
			
		||||
            [
 | 
			
		||||
                [0, 0, 0],
 | 
			
		||||
                [1, 0, 0],
 | 
			
		||||
                [1, 1, 0],
 | 
			
		||||
                [0, 1, 0],
 | 
			
		||||
                [-2, -2, 2],
 | 
			
		||||
                [2, -2, 2],
 | 
			
		||||
                [2, 2, 2],
 | 
			
		||||
                [-2, 2, 2],
 | 
			
		||||
            ],
 | 
			
		||||
            dtype=torch.float32,
 | 
			
		||||
            device=device,
 | 
			
		||||
        )
 | 
			
		||||
        box_skew_2 = torch.tensor(
 | 
			
		||||
            [
 | 
			
		||||
                [2.015995, 0.695233, 2.152806],
 | 
			
		||||
                [2.832533, 0.663448, 1.576389],
 | 
			
		||||
                [2.675445, -0.309592, 1.407520],
 | 
			
		||||
                [1.858907, -0.277806, 1.983936],
 | 
			
		||||
                [-0.413922, 3.161758, 2.044343],
 | 
			
		||||
                [2.852230, 3.034615, -0.261321],
 | 
			
		||||
                [2.223878, -0.857545, -0.936800],
 | 
			
		||||
                [-1.042273, -0.730402, 1.368864],
 | 
			
		||||
            ],
 | 
			
		||||
            dtype=torch.float32,
 | 
			
		||||
            device=device,
 | 
			
		||||
        )
 | 
			
		||||
        vol1 = 14.000
 | 
			
		||||
        vol2 = 14.000005
 | 
			
		||||
        vol_inters = 5.431122
 | 
			
		||||
        iou = vol_inters / (vol1 + vol2 - vol_inters)
 | 
			
		||||
 | 
			
		||||
        vols, ious = overlap_fn(box_skew_1[None], box_skew_2[None])
 | 
			
		||||
        self.assertClose(vols, torch.tensor([[vol_inters]], device=device), atol=1e-1)
 | 
			
		||||
        self.assertClose(ious, torch.tensor([[iou]], device=device), atol=1e-1)
 | 
			
		||||
 | 
			
		||||
    def test_iou_naive(self):
 | 
			
		||||
        device = torch.device("cuda:0")
 | 
			
		||||
        device = get_random_cuda_device()
 | 
			
		||||
        self._test_iou(self._box3d_overlap_naive_batched, device)
 | 
			
		||||
        self._test_compare_objectron(self._box3d_overlap_naive_batched, device)
 | 
			
		||||
 | 
			
		||||
    def test_iou_cpu(self):
 | 
			
		||||
        device = torch.device("cpu")
 | 
			
		||||
        self._test_iou(box3d_overlap, device)
 | 
			
		||||
        self._test_compare_objectron(box3d_overlap, device)
 | 
			
		||||
 | 
			
		||||
    def test_cpu_vs_naive_batched(self):
 | 
			
		||||
        N, M = 3, 6
 | 
			
		||||
        device = "cpu"
 | 
			
		||||
        boxes1 = torch.randn((N, 8, 3), device=device)
 | 
			
		||||
        boxes2 = torch.randn((M, 8, 3), device=device)
 | 
			
		||||
        vol1, iou1 = self._box3d_overlap_naive_batched(boxes1, boxes2)
 | 
			
		||||
        vol2, iou2 = box3d_overlap(boxes1, boxes2)
 | 
			
		||||
        # check shape
 | 
			
		||||
        for val in [vol1, vol2, iou1, iou2]:
 | 
			
		||||
            self.assertClose(val.shape, (N, M))
 | 
			
		||||
        # check values
 | 
			
		||||
        self.assertClose(vol1, vol2)
 | 
			
		||||
        self.assertClose(iou1, iou2)
 | 
			
		||||
    def test_iou_cuda(self):
 | 
			
		||||
        device = torch.device("cuda:0")
 | 
			
		||||
        self._test_iou(box3d_overlap, device)
 | 
			
		||||
        self._test_compare_objectron(box3d_overlap, device)
 | 
			
		||||
 | 
			
		||||
    def _test_compare_objectron(self, overlap_fn, device):
        """
        Compare `overlap_fn` against the reference volumes and IoUs stored
        in the Objectron data file under DATA_DIR.

        Args:
            overlap_fn: callable taking two box tensors and returning
                a (vols, ious) pair
            device: device on which the boxes are placed before the call
        """
        # Load saved objectron data
        reference = torch.load(DATA_DIR / "./objectron_vols_ious.pt")
        raw_boxes1 = reference["boxes1"]
        raw_boxes2 = reference["boxes2"]
        expected_vols = reference["vols"]
        expected_ious = reference["ious"]

        boxes1 = raw_boxes1.to(device=device, dtype=torch.float32)
        boxes2 = raw_boxes2.to(device=device, dtype=torch.float32)

        # Permute the corners from the Objectron ordering to the
        # PyTorch3D ordering
        corner_order = torch.tensor(
            OBJECTRON_TO_PYTORCH3D_FACE_IDX, dtype=torch.int64, device=device
        )
        boxes1 = boxes1.index_select(index=corner_order, dim=1)
        boxes2 = boxes2.index_select(index=corner_order, dim=1)

        # Run the implementation under test and compare on the CPU
        vols, ious = overlap_fn(boxes1, boxes2)
        self.assertClose(expected_vols, vols.cpu())
        self.assertClose(expected_ious, ious.cpu())
 | 
			
		||||
 | 
			
		||||
    def test_batched_errors(self):
 | 
			
		||||
        N, M = 5, 10
 | 
			
		||||
@ -316,16 +389,7 @@ class TestIoU3D(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
    def test_box_planar_dir(self):
 | 
			
		||||
        device = torch.device("cuda:0")
 | 
			
		||||
        box1 = torch.tensor(
 | 
			
		||||
            [
 | 
			
		||||
                [0, 0, 0],
 | 
			
		||||
                [1, 0, 0],
 | 
			
		||||
                [1, 1, 0],
 | 
			
		||||
                [0, 1, 0],
 | 
			
		||||
                [0, 0, 1],
 | 
			
		||||
                [1, 0, 1],
 | 
			
		||||
                [1, 1, 1],
 | 
			
		||||
                [0, 1, 1],
 | 
			
		||||
            ],
 | 
			
		||||
            UNIT_BOX,
 | 
			
		||||
            dtype=torch.float32,
 | 
			
		||||
            device=device,
 | 
			
		||||
        )
 | 
			
		||||
@ -353,8 +417,13 @@ class TestIoU3D(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def iou_naive(N: int, M: int, device="cpu"):
 | 
			
		||||
        boxes1 = torch.randn((N, 8, 3))
 | 
			
		||||
        boxes2 = torch.randn((M, 8, 3))
 | 
			
		||||
        box = torch.tensor(
 | 
			
		||||
            [UNIT_BOX],
 | 
			
		||||
            dtype=torch.float32,
 | 
			
		||||
            device=device,
 | 
			
		||||
        )
 | 
			
		||||
        boxes1 = box + torch.randn((N, 1, 3), device=device)
 | 
			
		||||
        boxes2 = box + torch.randn((M, 1, 3), device=device)
 | 
			
		||||
 | 
			
		||||
        def output():
 | 
			
		||||
            vol, iou = TestIoU3D._box3d_overlap_naive_batched(boxes1, boxes2)
 | 
			
		||||
@ -363,8 +432,13 @@ class TestIoU3D(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def iou(N: int, M: int, device="cpu"):
 | 
			
		||||
        boxes1 = torch.randn((N, 8, 3), device=device)
 | 
			
		||||
        boxes2 = torch.randn((M, 8, 3), device=device)
 | 
			
		||||
        box = torch.tensor(
 | 
			
		||||
            [UNIT_BOX],
 | 
			
		||||
            dtype=torch.float32,
 | 
			
		||||
            device=device,
 | 
			
		||||
        )
 | 
			
		||||
        boxes1 = box + torch.randn((N, 1, 3), device=device)
 | 
			
		||||
        boxes2 = box + torch.randn((M, 1, 3), device=device)
 | 
			
		||||
 | 
			
		||||
        def output():
 | 
			
		||||
            vol, iou = box3d_overlap(boxes1, boxes2)
 | 
			
		||||
@ -372,9 +446,14 @@ class TestIoU3D(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
        return output
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def iou_sampling(N: int, M: int, num_samples: int):
 | 
			
		||||
        boxes1 = torch.randn((N, 8, 3))
 | 
			
		||||
        boxes2 = torch.randn((M, 8, 3))
 | 
			
		||||
    def iou_sampling(N: int, M: int, num_samples: int, device="cpu"):
 | 
			
		||||
        box = torch.tensor(
 | 
			
		||||
            [UNIT_BOX],
 | 
			
		||||
            dtype=torch.float32,
 | 
			
		||||
            device=device,
 | 
			
		||||
        )
 | 
			
		||||
        boxes1 = box + torch.randn((N, 1, 3), device=device)
 | 
			
		||||
        boxes2 = box + torch.randn((M, 1, 3), device=device)
 | 
			
		||||
 | 
			
		||||
        def output():
 | 
			
		||||
            _ = TestIoU3D._box3d_overlap_sampling_batched(boxes1, boxes2, num_samples)
 | 
			
		||||
@ -408,38 +487,6 @@ Note that both implementations currently do not support batching.
 | 
			
		||||
#
 | 
			
		||||
# -------------------------------------------------- #
 | 
			
		||||
 | 
			
		||||
# -------------------------------------------------- #
 | 
			
		||||
#                  CONSTANTS                         #
 | 
			
		||||
# -------------------------------------------------- #
 | 
			
		||||
"""
 | 
			
		||||
_box_planes and _box_triangles define the 4- and 3-connectivity
 | 
			
		||||
of the 8 box corners.
 | 
			
		||||
_box_planes gives the quad faces of the 3D box
 | 
			
		||||
_box_triangles gives the triangle faces of the 3D box
 | 
			
		||||
"""
 | 
			
		||||
_box_planes = [
 | 
			
		||||
    [0, 1, 2, 3],
 | 
			
		||||
    [3, 2, 6, 7],
 | 
			
		||||
    [0, 1, 5, 4],
 | 
			
		||||
    [0, 3, 7, 4],
 | 
			
		||||
    [1, 5, 6, 2],
 | 
			
		||||
    [4, 5, 6, 7],
 | 
			
		||||
]
 | 
			
		||||
_box_triangles = [
 | 
			
		||||
    [0, 1, 2],
 | 
			
		||||
    [0, 3, 2],
 | 
			
		||||
    [4, 5, 6],
 | 
			
		||||
    [4, 6, 7],
 | 
			
		||||
    [1, 5, 6],
 | 
			
		||||
    [1, 6, 2],
 | 
			
		||||
    [0, 4, 7],
 | 
			
		||||
    [0, 7, 3],
 | 
			
		||||
    [3, 2, 6],
 | 
			
		||||
    [3, 6, 7],
 | 
			
		||||
    [0, 1, 5],
 | 
			
		||||
    [0, 4, 5],
 | 
			
		||||
]
 | 
			
		||||
 | 
			
		||||
# -------------------------------------------------- #
 | 
			
		||||
#       HELPER FUNCTIONS FOR EXACT SOLUTION          #
 | 
			
		||||
# -------------------------------------------------- #
 | 
			
		||||
@ -477,7 +524,7 @@ def get_plane_verts(box: torch.Tensor) -> torch.Tensor:
 | 
			
		||||
    return plane_verts
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def box_planar_dir(box: torch.Tensor) -> torch.Tensor:
 | 
			
		||||
def box_planar_dir(box: torch.Tensor, eps=1e-4) -> torch.Tensor:
 | 
			
		||||
    """
 | 
			
		||||
    Finds the unit vector n which is perpendicular to each plane in the box
 | 
			
		||||
    and points towards the inside of the box.
 | 
			
		||||
@ -507,6 +554,11 @@ def box_planar_dir(box: torch.Tensor) -> torch.Tensor:
 | 
			
		||||
    e1 = F.normalize(v2 - v0, dim=-1)
 | 
			
		||||
    n = F.normalize(torch.cross(e0, e1, dim=-1), dim=-1)
 | 
			
		||||
 | 
			
		||||
    # Check all verts are coplanar
 | 
			
		||||
    if not ((v3 - v0).unsqueeze(1).bmm(n.unsqueeze(2)).abs() < eps).all().item():
 | 
			
		||||
        msg = "Plane vertices are not coplanar"
 | 
			
		||||
        raise ValueError(msg)
 | 
			
		||||
 | 
			
		||||
    # We can write:  `ctr = v0 + a * e0 + b * e1 + c * n`, (1).
 | 
			
		||||
    # With <e0, n> = 0 and <e1, n> = 0, where <.,.> refers to the dot product,
 | 
			
		||||
    # since that e0 is orthogonal to n. Same for e1.
 | 
			
		||||
@ -733,10 +785,10 @@ def clip_tri_by_plane_oneout(
 | 
			
		||||
    device = plane.device
 | 
			
		||||
    # point of intersection between plane and (vin1, vout)
 | 
			
		||||
    pint1, a1 = plane_edge_point_of_intersection(plane, n, vin1, vout)
 | 
			
		||||
    assert a1 >= eps and a1 <= 1.0
 | 
			
		||||
    assert a1 >= eps and a1 <= 1.0, a1
 | 
			
		||||
    # point of intersection between plane and (vin2, vout)
 | 
			
		||||
    pint2, a2 = plane_edge_point_of_intersection(plane, n, vin2, vout)
 | 
			
		||||
    assert a2 >= 0.0 and a2 <= 1.0
 | 
			
		||||
    assert a2 >= 0.0 and a2 <= 1.0, a2
 | 
			
		||||
 | 
			
		||||
    verts = torch.stack((vin1, pint1, pint2, vin2), dim=0)  # 4x3
 | 
			
		||||
    faces = torch.tensor(
 | 
			
		||||
@ -771,10 +823,10 @@ def clip_tri_by_plane_twoout(
 | 
			
		||||
    device = plane.device
 | 
			
		||||
    # point of intersection between plane and (vin, vout1)
 | 
			
		||||
    pint1, a1 = plane_edge_point_of_intersection(plane, n, vin, vout1)
 | 
			
		||||
    assert a1 >= eps and a1 <= 1.0
 | 
			
		||||
    assert a1 >= eps and a1 <= 1.0, a1
 | 
			
		||||
    # point of intersection between plane and (vin, vout2)
 | 
			
		||||
    pint2, a2 = plane_edge_point_of_intersection(plane, n, vin, vout2)
 | 
			
		||||
    assert a2 >= eps and a2 <= 1.0
 | 
			
		||||
    assert a2 >= eps and a2 <= 1.0, a2
 | 
			
		||||
 | 
			
		||||
    verts = torch.stack((vin, pint1, pint2), dim=0)  # 3x3
 | 
			
		||||
    faces = torch.tensor(
 | 
			
		||||
@ -945,7 +997,7 @@ def box3d_overlap_naive(box1: torch.Tensor, box2: torch.Tensor):
 | 
			
		||||
 | 
			
		||||
    iou = vol / (vol1 + vol2 - vol)
 | 
			
		||||
 | 
			
		||||
    if 0:
 | 
			
		||||
    if DEBUG:
 | 
			
		||||
        # save shapes
 | 
			
		||||
        tri_faces = torch.tensor(_box_triangles, device=device, dtype=torch.int64)
 | 
			
		||||
        save_obj("/tmp/output/shape1.obj", box1, tri_faces)
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user