Mirror of https://github.com/facebookresearch/pytorch3d.git

CUDA/C++ Rasterizer updates to handle clipped faces

Summary:
- Updated the C++/CUDA mesh rasterization kernels to handle clipped faces. In particular, this required careful handling of the distance calculation for faces which are cut into a quadrilateral by the image plane and then split into two sub-triangles, i.e. both sub-triangles cannot be part of the top K faces.
- Updated `rasterize_meshes.py` to use the utils functions to clip the meshes and to convert the fragments back to be in terms of the unclipped mesh.
- Added end-to-end tests.

Reviewed By: jcjohnson
Differential Revision: D26169685
fbshipit-source-id: d64cd0d656109b965f44a35c301b7c81f451cfa0

Parent: 838b73d3b6
Commit: 340662e98e
				@ -17,8 +17,8 @@ namespace {
 | 
			
		||||
// A structure for holding details about a pixel.
 | 
			
		||||
struct Pixel {
 | 
			
		||||
  float z;
 | 
			
		||||
  int64_t idx;
 | 
			
		||||
  float dist;
 | 
			
		||||
  int64_t idx; // idx of face
 | 
			
		||||
  float dist; // abs distance of pixel to face
 | 
			
		||||
  float3 bary;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
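For orientation (not part of the diff): the `Pixel` struct above is the per-pixel record the kernels keep K copies of, and its fields map one-to-one onto the four output tensors of rasterization. A minimal, purely illustrative Python mirror:

from dataclasses import dataclass
from typing import Tuple

@dataclass
class PixelRecord:
    # Illustrative mirror of the CUDA `Pixel` struct: one of the K entries
    # kept per pixel, feeding zbuf, pix_to_face, dists and bary_coords.
    z: float                          # NDC depth of the face at this pixel
    idx: int                          # index of the face in the packed faces tensor
    dist: float                       # signed 2D distance of the pixel to the face
    bary: Tuple[float, float, float]  # (possibly clipped) barycentric coordinates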
@ -111,6 +111,7 @@ __device__ bool CheckPointOutsideBoundingBox(
 | 
			
		||||
template <typename FaceQ>
 | 
			
		||||
__device__ void CheckPixelInsideFace(
 | 
			
		||||
    const float* face_verts, // (F, 3, 3)
 | 
			
		||||
    const int64_t* clipped_faces_neighbor_idx, // (F,)
 | 
			
		||||
    const int face_idx,
 | 
			
		||||
    int& q_size,
 | 
			
		||||
    float& q_max_z,
 | 
			
		||||
@ -173,32 +174,72 @@ __device__ void CheckPixelInsideFace(
 | 
			
		||||
  // face.
 | 
			
		||||
  const bool inside = p_bary.x > 0.0f && p_bary.y > 0.0f && p_bary.z > 0.0f;
 | 
			
		||||
  const float signed_dist = inside ? -dist : dist;
 | 
			
		||||
 | 
			
		||||
  // Check if pixel is outside blur region
 | 
			
		||||
  if (!inside && dist >= blur_radius) {
 | 
			
		||||
    return;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  if (q_size < K) {
 | 
			
		||||
    // Just insert it.
 | 
			
		||||
    q[q_size] = {pz, face_idx, signed_dist, p_bary_clip};
 | 
			
		||||
    if (pz > q_max_z) {
 | 
			
		||||
      q_max_z = pz;
 | 
			
		||||
      q_max_idx = q_size;
 | 
			
		||||
  // Handle the case where a face (f) partially behind the image plane is
 | 
			
		||||
  // clipped to a quadrilateral and then split into two faces (t1, t2). In this
 | 
			
		||||
  // case we:
 | 
			
		||||
  // 1. Find the index of the neighboring face (e.g. for t1 need index of t2)
 | 
			
		||||
  // 2. Check if the neighboring face (t2) is already in the top K faces
 | 
			
		||||
  // 3. If yes, compare the distance of the pixel to t1 with the distance to t2.
 | 
			
		||||
  // 4. If dist_t1 < dist_t2, overwrite the values for t2 in the top K faces.
 | 
			
		||||
  const int neighbor_idx = clipped_faces_neighbor_idx[face_idx];
 | 
			
		||||
  int neighbor_idx_top_k = -1;
 | 
			
		||||
 | 
			
		||||
  // Check if neighboring face is already in the top K.
 | 
			
		||||
  // -1 is the fill value in clipped_faces_neighbor_idx
 | 
			
		||||
  if (neighbor_idx != -1) {
 | 
			
		||||
    // Only need to loop until q_size.
 | 
			
		||||
    for (int i = 0; i < q_size; i++) {
 | 
			
		||||
      if (q[i].idx == neighbor_idx) {
 | 
			
		||||
        neighbor_idx_top_k = i;
 | 
			
		||||
        break;
 | 
			
		||||
      }
 | 
			
		||||
    }
 | 
			
		||||
    q_size++;
 | 
			
		||||
  } else if (pz < q_max_z) {
 | 
			
		||||
    // Overwrite the old max, and find the new max.
 | 
			
		||||
    q[q_max_idx] = {pz, face_idx, signed_dist, p_bary_clip};
 | 
			
		||||
    q_max_z = pz;
 | 
			
		||||
    for (int i = 0; i < K; i++) {
 | 
			
		||||
      if (q[i].z > q_max_z) {
 | 
			
		||||
        q_max_z = q[i].z;
 | 
			
		||||
        q_max_idx = i;
 | 
			
		||||
  }
 | 
			
		||||
  // If neighbor idx is not -1 then it is in the top K struct.
 | 
			
		||||
  if (neighbor_idx_top_k != -1) {
 | 
			
		||||
    // If dist of current face is less than neighbor then overwrite the
 | 
			
		||||
    // neighbor face values in the top K struct.
 | 
			
		||||
    float neighbor_dist = abs(q[neighbor_idx_top_k].dist);
 | 
			
		||||
    if (dist < neighbor_dist) {
 | 
			
		||||
      // Overwrite the neighbor face values
 | 
			
		||||
      q[neighbor_idx_top_k] = {pz, face_idx, signed_dist, p_bary_clip};
 | 
			
		||||
 | 
			
		||||
      // If pz > q_max then overwrite the max values and index of the max.
 | 
			
		||||
      // q_size stays the same.
 | 
			
		||||
      if (pz > q_max_z) {
 | 
			
		||||
        q_max_z = pz;
 | 
			
		||||
        q_max_idx = neighbor_idx_top_k;
 | 
			
		||||
      }
 | 
			
		||||
    }
 | 
			
		||||
  } else {
 | 
			
		||||
    // Handle as a normal face
 | 
			
		||||
    if (q_size < K) {
 | 
			
		||||
      // Just insert it.
 | 
			
		||||
      q[q_size] = {pz, face_idx, signed_dist, p_bary_clip};
 | 
			
		||||
      if (pz > q_max_z) {
 | 
			
		||||
        q_max_z = pz;
 | 
			
		||||
        q_max_idx = q_size;
 | 
			
		||||
      }
 | 
			
		||||
      q_size++;
 | 
			
		||||
    } else if (pz < q_max_z) {
 | 
			
		||||
      // Overwrite the old max, and find the new max.
 | 
			
		||||
      q[q_max_idx] = {pz, face_idx, signed_dist, p_bary_clip};
 | 
			
		||||
      q_max_z = pz;
 | 
			
		||||
      for (int i = 0; i < K; i++) {
 | 
			
		||||
        if (q[i].z > q_max_z) {
 | 
			
		||||
          q_max_z = q[i].z;
 | 
			
		||||
          q_max_idx = i;
 | 
			
		||||
        }
 | 
			
		||||
      }
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
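The numbered comments above describe the extra bookkeeping for faces that were clipped to a quadrilateral and split in two. Below is a minimal Python sketch (not part of the diff) of the same insertion policy, assuming `q` is a plain list of `(z, face_idx, signed_dist, bary)` tuples holding at most K entries; the CUDA kernel avoids the rescans by carrying `q_max_z`/`q_max_idx` alongside a fixed-size array.

def insert_face(q, K, pz, face_idx, dist, signed_dist, bary, neighbor_idx):
    # Steps 1-2: is the other half of the split quadrilateral already in the top K?
    sibling = -1
    if neighbor_idx != -1:
        sibling = next((i for i, e in enumerate(q) if e[1] == neighbor_idx), -1)
    if sibling != -1:
        # Steps 3-4: keep only the closer of the two sub-triangles.
        if dist < abs(q[sibling][2]):
            q[sibling] = (pz, face_idx, signed_dist, bary)
        return
    # Otherwise handle as a normal face: fill up to K entries, then replace
    # the entry with the largest z if this face is closer to the camera.
    if len(q) < K:
        q.append((pz, face_idx, signed_dist, bary))
    else:
        far = max(range(K), key=lambda i: q[i][0])
        if pz < q[far][0]:
            q[far] = (pz, face_idx, signed_dist, bary)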
 | 
			
		||||
} // namespace
 | 
			
		||||
 | 
			
		||||
// ****************************************************************************
 | 
			
		||||
@ -208,6 +249,7 @@ __global__ void RasterizeMeshesNaiveCudaKernel(
 | 
			
		||||
    const float* face_verts,
 | 
			
		||||
    const int64_t* mesh_to_face_first_idx,
 | 
			
		||||
    const int64_t* num_faces_per_mesh,
 | 
			
		||||
    const int64_t* clipped_faces_neighbor_idx,
 | 
			
		||||
    const float blur_radius,
 | 
			
		||||
    const bool perspective_correct,
 | 
			
		||||
    const bool clip_barycentric_coords,
 | 
			
		||||
@ -265,6 +307,7 @@ __global__ void RasterizeMeshesNaiveCudaKernel(
 | 
			
		||||
 | 
			
		||||
      CheckPixelInsideFace(
 | 
			
		||||
          face_verts,
 | 
			
		||||
          clipped_faces_neighbor_idx,
 | 
			
		||||
          f,
 | 
			
		||||
          q_size,
 | 
			
		||||
          q_max_z,
 | 
			
		||||
@ -298,6 +341,7 @@ RasterizeMeshesNaiveCuda(
 | 
			
		||||
    const at::Tensor& face_verts,
 | 
			
		||||
    const at::Tensor& mesh_to_faces_packed_first_idx,
 | 
			
		||||
    const at::Tensor& num_faces_per_mesh,
 | 
			
		||||
    const at::Tensor& clipped_faces_neighbor_idx,
 | 
			
		||||
    const std::tuple<int, int> image_size,
 | 
			
		||||
    const float blur_radius,
 | 
			
		||||
    const int num_closest,
 | 
			
		||||
@ -313,6 +357,10 @@ RasterizeMeshesNaiveCuda(
 | 
			
		||||
      num_faces_per_mesh.size(0) == mesh_to_faces_packed_first_idx.size(0),
      "num_faces_per_mesh must have the same first dimension size as mesh_to_faces_packed_first_idx");
 | 
			
		||||
 | 
			
		||||
  TORCH_CHECK(
 | 
			
		||||
      clipped_faces_neighbor_idx.size(0) == face_verts.size(0),
      "clipped_faces_neighbor_idx must have the same first dimension size as face_verts");
 | 
			
		||||
 | 
			
		||||
  if (num_closest > kMaxPointsPerPixel) {
 | 
			
		||||
    std::stringstream ss;
 | 
			
		||||
    ss << "Must have points_per_pixel <= " << kMaxPointsPerPixel;
 | 
			
		||||
@ -323,11 +371,16 @@ RasterizeMeshesNaiveCuda(
 | 
			
		||||
  at::TensorArg face_verts_t{face_verts, "face_verts", 1},
 | 
			
		||||
      mesh_to_faces_packed_first_idx_t{
 | 
			
		||||
          mesh_to_faces_packed_first_idx, "mesh_to_faces_packed_first_idx", 2},
 | 
			
		||||
      num_faces_per_mesh_t{num_faces_per_mesh, "num_faces_per_mesh", 3};
 | 
			
		||||
      num_faces_per_mesh_t{num_faces_per_mesh, "num_faces_per_mesh", 3},
 | 
			
		||||
      clipped_faces_neighbor_idx_t{
 | 
			
		||||
          clipped_faces_neighbor_idx, "clipped_faces_neighbor_idx", 4};
 | 
			
		||||
  at::CheckedFrom c = "RasterizeMeshesNaiveCuda";
 | 
			
		||||
  at::checkAllSameGPU(
 | 
			
		||||
      c,
 | 
			
		||||
      {face_verts_t, mesh_to_faces_packed_first_idx_t, num_faces_per_mesh_t});
 | 
			
		||||
      {face_verts_t,
 | 
			
		||||
       mesh_to_faces_packed_first_idx_t,
 | 
			
		||||
       num_faces_per_mesh_t,
 | 
			
		||||
       clipped_faces_neighbor_idx_t});
 | 
			
		||||
 | 
			
		||||
  // Set the device for the kernel launch based on the device of the input
 | 
			
		||||
  at::cuda::CUDAGuard device_guard(face_verts.device());
 | 
			
		||||
@ -358,6 +411,7 @@ RasterizeMeshesNaiveCuda(
 | 
			
		||||
      face_verts.contiguous().data_ptr<float>(),
 | 
			
		||||
      mesh_to_faces_packed_first_idx.contiguous().data_ptr<int64_t>(),
 | 
			
		||||
      num_faces_per_mesh.contiguous().data_ptr<int64_t>(),
 | 
			
		||||
      clipped_faces_neighbor_idx.contiguous().data_ptr<int64_t>(),
 | 
			
		||||
      blur_radius,
 | 
			
		||||
      perspective_correct,
 | 
			
		||||
      clip_barycentric_coords,
 | 
			
		||||
@ -800,6 +854,7 @@ at::Tensor RasterizeMeshesCoarseCuda(
 | 
			
		||||
__global__ void RasterizeMeshesFineCudaKernel(
 | 
			
		||||
    const float* face_verts, // (F, 3, 3)
 | 
			
		||||
    const int32_t* bin_faces, // (N, BH, BW, T)
 | 
			
		||||
    const int64_t* clipped_faces_neighbor_idx, // (F,)
 | 
			
		||||
    const float blur_radius,
 | 
			
		||||
    const int bin_size,
 | 
			
		||||
    const bool perspective_correct,
 | 
			
		||||
@ -858,6 +913,7 @@ __global__ void RasterizeMeshesFineCudaKernel(
 | 
			
		||||
    int q_size = 0;
 | 
			
		||||
    float q_max_z = -1000;
 | 
			
		||||
    int q_max_idx = -1;
 | 
			
		||||
 | 
			
		||||
    for (int m = 0; m < M; m++) {
 | 
			
		||||
      const int f = bin_faces[n * BH * BW * M + by * BW * M + bx * M + m];
 | 
			
		||||
      if (f < 0) {
 | 
			
		||||
@ -867,6 +923,7 @@ __global__ void RasterizeMeshesFineCudaKernel(
 | 
			
		||||
      // update q, q_size, q_max_z and q_max_idx in place.
 | 
			
		||||
      CheckPixelInsideFace(
 | 
			
		||||
          face_verts,
 | 
			
		||||
          clipped_faces_neighbor_idx,
 | 
			
		||||
          f,
 | 
			
		||||
          q_size,
 | 
			
		||||
          q_max_z,
 | 
			
		||||
@ -906,6 +963,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
 | 
			
		||||
RasterizeMeshesFineCuda(
 | 
			
		||||
    const at::Tensor& face_verts,
 | 
			
		||||
    const at::Tensor& bin_faces,
 | 
			
		||||
    const at::Tensor& clipped_faces_neighbor_idx,
 | 
			
		||||
    const std::tuple<int, int> image_size,
 | 
			
		||||
    const float blur_radius,
 | 
			
		||||
    const int bin_size,
 | 
			
		||||
@ -918,12 +976,18 @@ RasterizeMeshesFineCuda(
 | 
			
		||||
          face_verts.size(2) == 3,
 | 
			
		||||
      "face_verts must have dimensions (num_faces, 3, 3)");
 | 
			
		||||
  TORCH_CHECK(bin_faces.ndimension() == 4, "bin_faces must have 4 dimensions");
 | 
			
		||||
  TORCH_CHECK(
 | 
			
		||||
      clipped_faces_neighbor_idx.size(0) == face_verts.size(0),
 | 
			
		||||
      "clipped_faces_neighbor_idx must have the same first dimension as face_verts");
 | 
			
		||||
 | 
			
		||||
  // Check inputs are on the same device
 | 
			
		||||
  at::TensorArg face_verts_t{face_verts, "face_verts", 1},
 | 
			
		||||
      bin_faces_t{bin_faces, "bin_faces", 2};
 | 
			
		||||
      bin_faces_t{bin_faces, "bin_faces", 2},
 | 
			
		||||
      clipped_faces_neighbor_idx_t{
 | 
			
		||||
          clipped_faces_neighbor_idx, "clipped_faces_neighbor_idx", 3};
 | 
			
		||||
  at::CheckedFrom c = "RasterizeMeshesFineCuda";
 | 
			
		||||
  at::checkAllSameGPU(c, {face_verts_t, bin_faces_t});
 | 
			
		||||
  at::checkAllSameGPU(
 | 
			
		||||
      c, {face_verts_t, bin_faces_t, clipped_faces_neighbor_idx_t});
 | 
			
		||||
 | 
			
		||||
  // Set the device for the kernel launch based on the device of the input
 | 
			
		||||
  at::cuda::CUDAGuard device_guard(face_verts.device());
 | 
			
		||||
@ -961,6 +1025,7 @@ RasterizeMeshesFineCuda(
 | 
			
		||||
  RasterizeMeshesFineCudaKernel<<<blocks, threads, 0, stream>>>(
 | 
			
		||||
      face_verts.contiguous().data_ptr<float>(),
 | 
			
		||||
      bin_faces.contiguous().data_ptr<int32_t>(),
 | 
			
		||||
      clipped_faces_neighbor_idx.contiguous().data_ptr<int64_t>(),
 | 
			
		||||
      blur_radius,
 | 
			
		||||
      bin_size,
 | 
			
		||||
      perspective_correct,
 | 
			
		||||
 | 
			
		||||
@ -15,6 +15,7 @@ RasterizeMeshesNaiveCpu(
 | 
			
		||||
    const torch::Tensor& face_verts,
 | 
			
		||||
    const torch::Tensor& mesh_to_face_first_idx,
 | 
			
		||||
    const torch::Tensor& num_faces_per_mesh,
 | 
			
		||||
    const torch::Tensor& clipped_faces_neighbor_idx,
 | 
			
		||||
    const std::tuple<int, int> image_size,
 | 
			
		||||
    const float blur_radius,
 | 
			
		||||
    const int faces_per_pixel,
 | 
			
		||||
@ -28,6 +29,7 @@ RasterizeMeshesNaiveCuda(
 | 
			
		||||
    const at::Tensor& face_verts,
 | 
			
		||||
    const at::Tensor& mesh_to_face_first_idx,
 | 
			
		||||
    const at::Tensor& num_faces_per_mesh,
 | 
			
		||||
    const torch::Tensor& clipped_faces_neighbor_idx,
 | 
			
		||||
    const std::tuple<int, int> image_size,
 | 
			
		||||
    const float blur_radius,
 | 
			
		||||
    const int num_closest,
 | 
			
		||||
@ -48,6 +50,12 @@ RasterizeMeshesNaiveCuda(
 | 
			
		||||
//                            the batch where N is the batch size.
 | 
			
		||||
//    num_faces_per_mesh: LongTensor of shape (N) giving the number of faces
 | 
			
		||||
//                        for each mesh in the batch.
 | 
			
		||||
//    clipped_faces_neighbor_idx: LongTensor of shape (F,) giving the
 | 
			
		||||
//        index of the neighboring face for each face which was clipped to a
 | 
			
		||||
//        quadrilateral and then divided into two triangles.
 | 
			
		||||
//        e.g. for a face f partially behind the image plane which is split into
 | 
			
		||||
//        two triangles (t1, t2): clipped_faces_neighbor_idx[t1_idx] = t2_idx
 | 
			
		||||
//        Faces which are not clipped and subdivided are set to -1.
 | 
			
		||||
//    image_size: Tuple (H, W) giving the size in pixels of the output
 | 
			
		||||
//                image to be rasterized.
 | 
			
		||||
//    blur_radius: float distance in NDC coordinates used to expand the face
 | 
			
		||||
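The clipped_faces_neighbor_idx mapping documented above is symmetric between the two halves of a split face and -1 everywhere else. A small illustrative example (indices are hypothetical, not taken from this diff):

import torch

# Packed mesh with 9 faces after clipping; faces 7 and 8 are the two halves
# of one clipped quadrilateral, all other faces are unclipped.
clipped_faces_neighbor_idx = torch.full((9,), -1, dtype=torch.int64)
clipped_faces_neighbor_idx[7] = 8  # t1 -> t2
clipped_faces_neighbor_idx[8] = 7  # t2 -> t1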
@ -90,6 +98,7 @@ RasterizeMeshesNaive(
 | 
			
		||||
    const torch::Tensor& face_verts,
 | 
			
		||||
    const torch::Tensor& mesh_to_face_first_idx,
 | 
			
		||||
    const torch::Tensor& num_faces_per_mesh,
 | 
			
		||||
    const torch::Tensor& clipped_faces_neighbor_idx,
 | 
			
		||||
    const std::tuple<int, int> image_size,
 | 
			
		||||
    const float blur_radius,
 | 
			
		||||
    const int faces_per_pixel,
 | 
			
		||||
@ -106,6 +115,7 @@ RasterizeMeshesNaive(
 | 
			
		||||
        face_verts,
 | 
			
		||||
        mesh_to_face_first_idx,
 | 
			
		||||
        num_faces_per_mesh,
 | 
			
		||||
        clipped_faces_neighbor_idx,
 | 
			
		||||
        image_size,
 | 
			
		||||
        blur_radius,
 | 
			
		||||
        faces_per_pixel,
 | 
			
		||||
@ -120,6 +130,7 @@ RasterizeMeshesNaive(
 | 
			
		||||
        face_verts,
 | 
			
		||||
        mesh_to_face_first_idx,
 | 
			
		||||
        num_faces_per_mesh,
 | 
			
		||||
        clipped_faces_neighbor_idx,
 | 
			
		||||
        image_size,
 | 
			
		||||
        blur_radius,
 | 
			
		||||
        faces_per_pixel,
 | 
			
		||||
@ -306,6 +317,7 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
 | 
			
		||||
RasterizeMeshesFineCuda(
 | 
			
		||||
    const torch::Tensor& face_verts,
 | 
			
		||||
    const torch::Tensor& bin_faces,
 | 
			
		||||
    const torch::Tensor& clipped_faces_neighbor_idx,
 | 
			
		||||
    const std::tuple<int, int> image_size,
 | 
			
		||||
    const float blur_radius,
 | 
			
		||||
    const int bin_size,
 | 
			
		||||
@ -322,6 +334,12 @@ RasterizeMeshesFineCuda(
 | 
			
		||||
//                in NDC coordinates in the range [-1, 1].
 | 
			
		||||
//    bin_faces: int32 Tensor of shape (N, B, B, M) giving the indices of faces
 | 
			
		||||
//               that fall into each bin (output from coarse rasterization).
 | 
			
		||||
//    clipped_faces_neighbor_idx: LongTensor of shape (F,) giving the
 | 
			
		||||
//        index of the neighboring face for each face which was clipped to a
 | 
			
		||||
//        quadrilateral and then divided into two triangles.
 | 
			
		||||
//        e.g. for a face f partially behind the image plane which is split into
 | 
			
		||||
//        two triangles (t1, t2): clipped_faces_neighbor_idx[t1_idx] = t2_idx
 | 
			
		||||
//        Faces which are not clipped and subdivided are set to -1.
 | 
			
		||||
//    image_size: Tuple (H, W) giving the size in pixels of the output
 | 
			
		||||
//                image to be rasterized.
 | 
			
		||||
//    blur_radius: float distance in NDC coordinates used to expand the face
 | 
			
		||||
@ -364,6 +382,7 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
 | 
			
		||||
RasterizeMeshesFine(
 | 
			
		||||
    const torch::Tensor& face_verts,
 | 
			
		||||
    const torch::Tensor& bin_faces,
 | 
			
		||||
    const torch::Tensor& clipped_faces_neighbor_idx,
 | 
			
		||||
    const std::tuple<int, int> image_size,
 | 
			
		||||
    const float blur_radius,
 | 
			
		||||
    const int bin_size,
 | 
			
		||||
@ -378,6 +397,7 @@ RasterizeMeshesFine(
 | 
			
		||||
    return RasterizeMeshesFineCuda(
 | 
			
		||||
        face_verts,
 | 
			
		||||
        bin_faces,
 | 
			
		||||
        clipped_faces_neighbor_idx,
 | 
			
		||||
        image_size,
 | 
			
		||||
        blur_radius,
 | 
			
		||||
        bin_size,
 | 
			
		||||
@ -411,6 +431,12 @@ RasterizeMeshesFine(
 | 
			
		||||
//                            the batch where N is the batch size.
 | 
			
		||||
//    num_faces_per_mesh: LongTensor of shape (N) giving the number of faces
 | 
			
		||||
//                        for each mesh in the batch.
 | 
			
		||||
//    clipped_faces_neighbor_idx: LongTensor of shape (F,) giving the
 | 
			
		||||
//        index of the neighboring face for each face which was clipped to a
 | 
			
		||||
//        quadrilateral and then divided into two triangles.
 | 
			
		||||
//        e.g. for a face f partially behind the image plane which is split into
 | 
			
		||||
//        two triangles (t1, t2): clipped_faces_neighbor_idx[t1_idx] = t2_idx
 | 
			
		||||
//        Faces which are not clipped and subdivided are set to -1.
 | 
			
		||||
//    image_size: Tuple (H, W) giving the size in pixels of the output
 | 
			
		||||
//                image to be rasterized.
 | 
			
		||||
//    blur_radius: float distance in NDC coordinates used to expand the face
 | 
			
		||||
@ -456,6 +482,7 @@ RasterizeMeshes(
 | 
			
		||||
    const torch::Tensor& face_verts,
 | 
			
		||||
    const torch::Tensor& mesh_to_face_first_idx,
 | 
			
		||||
    const torch::Tensor& num_faces_per_mesh,
 | 
			
		||||
    const torch::Tensor& clipped_faces_neighbor_idx,
 | 
			
		||||
    const std::tuple<int, int> image_size,
 | 
			
		||||
    const float blur_radius,
 | 
			
		||||
    const int faces_per_pixel,
 | 
			
		||||
@ -477,6 +504,7 @@ RasterizeMeshes(
 | 
			
		||||
    return RasterizeMeshesFine(
 | 
			
		||||
        face_verts,
 | 
			
		||||
        bin_faces,
 | 
			
		||||
        clipped_faces_neighbor_idx,
 | 
			
		||||
        image_size,
 | 
			
		||||
        blur_radius,
 | 
			
		||||
        bin_size,
 | 
			
		||||
@ -490,6 +518,7 @@ RasterizeMeshes(
 | 
			
		||||
        face_verts,
 | 
			
		||||
        mesh_to_face_first_idx,
 | 
			
		||||
        num_faces_per_mesh,
 | 
			
		||||
        clipped_faces_neighbor_idx,
 | 
			
		||||
        image_size,
 | 
			
		||||
        blur_radius,
 | 
			
		||||
        faces_per_pixel,
 | 
			
		||||
 | 
			
		||||
@ -99,11 +99,31 @@ auto ComputeFaceAreas(const torch::Tensor& face_verts) {
 | 
			
		||||
  return face_areas;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Helper function to use with std::find_if to find the index of any
 | 
			
		||||
// values in the top k struct which match a given idx.
 | 
			
		||||
struct IsNeighbor {
 | 
			
		||||
  IsNeighbor(int neighbor_idx) {
 | 
			
		||||
    this->neighbor_idx = neighbor_idx;
 | 
			
		||||
  }
 | 
			
		||||
  bool operator()(std::tuple<float, int, float, float, float, float> elem) {
 | 
			
		||||
    return (std::get<1>(elem) == neighbor_idx);
 | 
			
		||||
  }
 | 
			
		||||
  int neighbor_idx;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
// Function to sort based on the z distance in the top K queue
 | 
			
		||||
bool SortTopKByZdist(
 | 
			
		||||
    std::tuple<float, int, float, float, float, float> a,
 | 
			
		||||
    std::tuple<float, int, float, float, float, float> b) {
 | 
			
		||||
  return std::get<0>(a) < std::get<0>(b);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
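The CPU path swaps std::priority_queue for a std::deque because the clipped-face bookkeeping needs to search the current top-K entries by face index, which a priority_queue does not expose; re-sorting the small deque after each insertion recovers the queue behaviour. A rough Python equivalent of the two helpers (a sketch only, assuming K is small so the re-sort is cheap):

def push_top_k(q, K, entry):
    # `entry` is (z, face_idx, signed_dist, bx, by, bz); keep q sorted by z
    # ascending (SortTopKByZdist) and drop the farthest face past K entries.
    q.append(entry)
    q.sort(key=lambda e: e[0])
    if len(q) > K:
        q.pop()

def find_neighbor(q, neighbor_idx):
    # std::find_if(q.begin(), q.end(), IsNeighbor(neighbor_idx)), or -1.
    return next((i for i, e in enumerate(q) if e[1] == neighbor_idx), -1)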
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
 | 
			
		||||
RasterizeMeshesNaiveCpu(
 | 
			
		||||
    const torch::Tensor& face_verts,
 | 
			
		||||
    const torch::Tensor& mesh_to_face_first_idx,
 | 
			
		||||
    const torch::Tensor& num_faces_per_mesh,
 | 
			
		||||
    const torch::Tensor& clipped_faces_neighbor_idx,
 | 
			
		||||
    const std::tuple<int, int> image_size,
 | 
			
		||||
    const float blur_radius,
 | 
			
		||||
    const int faces_per_pixel,
 | 
			
		||||
@ -139,6 +159,7 @@ RasterizeMeshesNaiveCpu(
 | 
			
		||||
  auto zbuf_a = zbuf.accessor<float, 4>();
 | 
			
		||||
  auto pix_dists_a = pix_dists.accessor<float, 4>();
 | 
			
		||||
  auto barycentric_coords_a = barycentric_coords.accessor<float, 5>();
 | 
			
		||||
  auto neighbor_idx_a = clipped_faces_neighbor_idx.accessor<int64_t, 1>();
 | 
			
		||||
 | 
			
		||||
  auto face_bboxes = ComputeFaceBoundingBoxes(face_verts);
 | 
			
		||||
  auto face_bboxes_a = face_bboxes.accessor<float, 2>();
 | 
			
		||||
@ -168,10 +189,11 @@ RasterizeMeshesNaiveCpu(
 | 
			
		||||
 | 
			
		||||
        // X coordinate of the left of the pixel.
 | 
			
		||||
        const float xf = PixToNonSquareNdc(xidx, W, H);
 | 
			
		||||
        // Use a priority queue to hold values:
 | 
			
		||||
 | 
			
		||||
        // Use a deque to hold values:
 | 
			
		||||
        // (z, idx, r, bary.x, bary.y, bary.z)
 | 
			
		||||
        std::priority_queue<std::tuple<float, int, float, float, float, float>>
 | 
			
		||||
            q;
 | 
			
		||||
        // Sort the deque as needed to mimic a priority queue.
 | 
			
		||||
        std::deque<std::tuple<float, int, float, float, float, float>> q;
 | 
			
		||||
 | 
			
		||||
        // Loop through the faces in the mesh.
 | 
			
		||||
        for (int f = face_start_idx; f < face_stop_idx; ++f) {
 | 
			
		||||
@ -240,15 +262,58 @@ RasterizeMeshesNaiveCpu(
 | 
			
		||||
          if (!inside && dist >= blur_radius) {
 | 
			
		||||
            continue;
 | 
			
		||||
          }
 | 
			
		||||
          // The current pixel lies inside the current face.
 | 
			
		||||
          q.emplace(pz, f, signed_dist, bary_clip.x, bary_clip.y, bary_clip.z);
 | 
			
		||||
 | 
			
		||||
          // Handle the case where a face (f) partially behind the image plane
 | 
			
		||||
          // is clipped to a quadrilateral and then split into two faces (t1,
 | 
			
		||||
          // t2). In this case we:
 | 
			
		||||
          // 1. Find the index of the neighbor (e.g. for t1 need index of t2)
 | 
			
		||||
          // 2. Check if the neighbor (t2) is already in the top K faces
 | 
			
		||||
          // 3. If yes, compare the distance of the pixel to t1 with the
 | 
			
		||||
          // distance to t2.
 | 
			
		||||
          // 4. If dist_t1 < dist_t2, overwrite the values for t2 in the top K
 | 
			
		||||
          // faces.
 | 
			
		||||
          const int neighbor_idx = neighbor_idx_a[f];
 | 
			
		||||
          int idx_top_k = -1;
 | 
			
		||||
 | 
			
		||||
          // Check if neighboring face is already in the top K.
 | 
			
		||||
          if (neighbor_idx != -1) {
 | 
			
		||||
            const auto it =
 | 
			
		||||
                std::find_if(q.begin(), q.end(), IsNeighbor(neighbor_idx));
 | 
			
		||||
            // Get the index of the element from the iterator
 | 
			
		||||
            idx_top_k = (it != q.end()) ? it - q.begin() : idx_top_k;
 | 
			
		||||
          }
 | 
			
		||||
 | 
			
		||||
          // If idx_top_k idx is not -1 then it is in the top K struct.
 | 
			
		||||
          if (idx_top_k != -1) {
 | 
			
		||||
            // If dist of current face is less than neighbor, overwrite
 | 
			
		||||
            // the neighbor face values in the top K struct.
 | 
			
		||||
            const auto neighbor = q[idx_top_k];
 | 
			
		||||
            const float dist_neighbor = std::abs(std::get<2>(neighbor));
 | 
			
		||||
            if (dist < dist_neighbor) {
 | 
			
		||||
              // Overwrite the neighbor face values.
 | 
			
		||||
              q[idx_top_k] = {
 | 
			
		||||
                  pz, f, signed_dist, bary_clip.x, bary_clip.y, bary_clip.z};
 | 
			
		||||
            }
 | 
			
		||||
          } else {
 | 
			
		||||
            // Handle as a normal face.
 | 
			
		||||
            // The current pixel lies inside the current face.
 | 
			
		||||
            // Add at the end of the deque.
 | 
			
		||||
            q.emplace_back(
 | 
			
		||||
                pz, f, signed_dist, bary_clip.x, bary_clip.y, bary_clip.z);
 | 
			
		||||
          }
 | 
			
		||||
 | 
			
		||||
          // Sort the deque in place based on the z distance
 | 
			
		||||
          // to mimic using a priority queue.
 | 
			
		||||
          std::sort(q.begin(), q.end(), SortTopKByZdist);
 | 
			
		||||
          if (static_cast<int>(q.size()) > K) {
 | 
			
		||||
            q.pop();
 | 
			
		||||
            // remove the last value
 | 
			
		||||
            q.pop_back();
 | 
			
		||||
          }
 | 
			
		||||
        }
 | 
			
		||||
        while (!q.empty()) {
 | 
			
		||||
          auto t = q.top();
 | 
			
		||||
          q.pop();
 | 
			
		||||
          // Loop through and add values to the output tensors
 | 
			
		||||
          auto t = q.back();
 | 
			
		||||
          q.pop_back();
 | 
			
		||||
          const int i = q.size();
 | 
			
		||||
          zbuf_a[n][yi][xi][i] = std::get<0>(t);
 | 
			
		||||
          face_idxs_a[n][yi][xi][i] = std::get<1>(t);
 | 
			
		||||
 | 
			
		||||
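The drain loop above pops from the back of the z-sorted deque and reuses the post-pop size as the output slot, so slot 0 ends up holding the closest face. A tiny self-contained illustration (values made up):

q = [(0.1, 7), (0.4, 2), (0.9, 5)]  # (z, face_idx), already sorted by z
out = [None] * len(q)
while q:
    t = q.pop()        # farthest remaining face (largest z)
    out[len(q)] = t    # post-pop size == its slot index
print(out)             # [(0.1, 7), (0.4, 2), (0.9, 5)] -- nearest first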
@ -9,6 +9,12 @@ import torch
 | 
			
		||||
# pyre-fixme[21]: Could not find name `_C` in `pytorch3d`.
 | 
			
		||||
from pytorch3d import _C
 | 
			
		||||
 | 
			
		||||
from .clip import (
 | 
			
		||||
    ClipFrustum,
 | 
			
		||||
    clip_faces,
 | 
			
		||||
    convert_clipped_rasterization_to_original_faces,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# TODO make the epsilon user configurable
 | 
			
		||||
kEpsilon = 1e-8
 | 
			
		||||
@ -28,6 +34,8 @@ def rasterize_meshes(
 | 
			
		||||
    perspective_correct: bool = False,
 | 
			
		||||
    clip_barycentric_coords: bool = False,
 | 
			
		||||
    cull_backfaces: bool = False,
 | 
			
		||||
    z_clip_value: Optional[float] = None,
 | 
			
		||||
    cull_to_frustum: bool = False,
 | 
			
		||||
):
 | 
			
		||||
    """
 | 
			
		||||
    Rasterize a batch of meshes given the shape of the desired output image.
 | 
			
		||||
@ -67,7 +75,8 @@ def rasterize_meshes(
 | 
			
		||||
            will be raised. This should not affect the output values, but can affect
 | 
			
		||||
            the memory usage in the forward pass.
 | 
			
		||||
        perspective_correct: Bool, Whether to apply perspective correction when computing
 | 
			
		||||
            barycentric coordinates for pixels.
 | 
			
		||||
            barycentric coordinates for pixels. This should be set to True if a perspective
 | 
			
		||||
            camera is used.
 | 
			
		||||
        cull_backfaces: Bool, Whether to only rasterize mesh faces which are
 | 
			
		||||
            visible to the camera.  This assumes that vertices of
 | 
			
		||||
            front-facing triangles are ordered in an anti-clockwise
 | 
			
		||||
@ -76,6 +85,15 @@ def rasterize_meshes(
 | 
			
		||||
            direction. NOTE: This will only work if the mesh faces are
 | 
			
		||||
            consistently defined with counter-clockwise ordering when
 | 
			
		||||
            viewed from the outside.
 | 
			
		||||
        z_clip_value: if not None, then triangles will be clipped (and possibly
 | 
			
		||||
            subdivided into smaller triangles) such that z >= z_clip_value.
 | 
			
		||||
            This avoids camera projections that go to infinity as z->0.
 | 
			
		||||
            Default is None as clipping affects rasterization speed and
 | 
			
		||||
            should only be turned on if explicitly needed.
 | 
			
		||||
            See clip.py for all the extra computation that is required.
 | 
			
		||||
        cull_to_frustum: if True, triangles outside the view frustum will be culled.
 | 
			
		||||
            Culling involves removing all faces which fall outside view frustum.
 | 
			
		||||
            Default is False so that it is turned on only when needed.
 | 
			
		||||
 | 
			
		||||
    Returns:
 | 
			
		||||
        4-element tuple containing
 | 
			
		||||
@ -141,6 +159,42 @@ def rasterize_meshes(
 | 
			
		||||
        im_size = (image_size, image_size)
 | 
			
		||||
        max_image_size = image_size
 | 
			
		||||
 | 
			
		||||
    clipped_faces_neighbor_idx = None
 | 
			
		||||
 | 
			
		||||
    if z_clip_value is not None or cull_to_frustum:
 | 
			
		||||
        # Cull faces outside the view frustum, and clip faces that are partially
 | 
			
		||||
        # behind the camera into the portion of the triangle in front of the
 | 
			
		||||
        # camera.  This may change the number of faces
 | 
			
		||||
        frustum = ClipFrustum(
 | 
			
		||||
            left=-1,
 | 
			
		||||
            right=1,
 | 
			
		||||
            top=-1,
 | 
			
		||||
            bottom=1,
 | 
			
		||||
            perspective_correct=perspective_correct,
 | 
			
		||||
            z_clip_value=z_clip_value,
 | 
			
		||||
            cull=cull_to_frustum,
 | 
			
		||||
        )
 | 
			
		||||
        clipped_faces = clip_faces(
 | 
			
		||||
            face_verts, mesh_to_face_first_idx, num_faces_per_mesh, frustum=frustum
 | 
			
		||||
        )
 | 
			
		||||
        face_verts = clipped_faces.face_verts
 | 
			
		||||
        mesh_to_face_first_idx = clipped_faces.mesh_to_face_first_idx
 | 
			
		||||
        num_faces_per_mesh = clipped_faces.num_faces_per_mesh
 | 
			
		||||
 | 
			
		||||
        # For case 4 clipped triangles (where a big triangle is split into two
        # smaller triangles), we need the index of the neighboring clipped
        # triangle, as only one of the two can be in the top K closest faces in
        # the rasterization step.
 | 
			
		||||
        clipped_faces_neighbor_idx = clipped_faces.clipped_faces_neighbor_idx
 | 
			
		||||
 | 
			
		||||
    if clipped_faces_neighbor_idx is None:
 | 
			
		||||
        # Set to the default which is all -1s.
 | 
			
		||||
        clipped_faces_neighbor_idx = torch.full(
 | 
			
		||||
            size=(face_verts.shape[0],),
 | 
			
		||||
            fill_value=-1,
 | 
			
		||||
            device=meshes.device,
 | 
			
		||||
            dtype=torch.int64,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    # TODO: Choose naive vs coarse-to-fine based on mesh size and image size.
 | 
			
		||||
    if bin_size is None:
 | 
			
		||||
        if not verts_packed.is_cuda:
 | 
			
		||||
@ -172,10 +226,11 @@ def rasterize_meshes(
 | 
			
		||||
        max_faces_per_bin = int(max(10000, meshes._F / 5))
 | 
			
		||||
 | 
			
		||||
    # pyre-fixme[16]: `_RasterizeFaceVerts` has no attribute `apply`.
 | 
			
		||||
    return _RasterizeFaceVerts.apply(
 | 
			
		||||
    pix_to_face, zbuf, barycentric_coords, dists = _RasterizeFaceVerts.apply(
 | 
			
		||||
        face_verts,
 | 
			
		||||
        mesh_to_face_first_idx,
 | 
			
		||||
        num_faces_per_mesh,
 | 
			
		||||
        clipped_faces_neighbor_idx,
 | 
			
		||||
        im_size,
 | 
			
		||||
        blur_radius,
 | 
			
		||||
        faces_per_pixel,
 | 
			
		||||
@ -186,6 +241,17 @@ def rasterize_meshes(
 | 
			
		||||
        cull_backfaces,
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    if z_clip_value is not None or cull_to_frustum:
 | 
			
		||||
        # If faces were clipped, map the rasterization result to be in terms of the
 | 
			
		||||
        # original unclipped faces.  This may involve converting barycentric
 | 
			
		||||
        # coordinates
 | 
			
		||||
        outputs = convert_clipped_rasterization_to_original_faces(
 | 
			
		||||
            pix_to_face, barycentric_coords, clipped_faces
 | 
			
		||||
        )
 | 
			
		||||
        pix_to_face, barycentric_coords = outputs
 | 
			
		||||
 | 
			
		||||
    return pix_to_face, zbuf, barycentric_coords, dists
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
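For context, a hedged usage sketch of the updated entry point with clipping enabled. It assumes an ico_sphere whose vertices are already in a rasterizable (NDC-like) space; in a real pipeline the mesh would first be transformed via MeshRasterizer.transform, and the image size and clip value below are illustrative:

import torch
from pytorch3d.utils import ico_sphere
from pytorch3d.renderer.mesh.rasterize_meshes import rasterize_meshes

meshes = ico_sphere(level=2)
pix_to_face, zbuf, bary_coords, dists = rasterize_meshes(
    meshes,
    image_size=128,
    blur_radius=0.0,
    faces_per_pixel=8,
    z_clip_value=1e-2,       # clip/subdivide faces that cross z = 0.01
    perspective_correct=True,
    cull_to_frustum=False,
)
# All four outputs are (N, H, W, K) and indexed w.r.t. the unclipped faces.
print(pix_to_face.shape)     # torch.Size([1, 128, 128, 8])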
class _RasterizeFaceVerts(torch.autograd.Function):
 | 
			
		||||
    """
 | 
			
		||||
@ -216,9 +282,10 @@ class _RasterizeFaceVerts(torch.autograd.Function):
 | 
			
		||||
    # pyre-fixme[14]: `forward` overrides method defined in `Function` inconsistently.
 | 
			
		||||
    def forward(
 | 
			
		||||
        ctx,
 | 
			
		||||
        face_verts,
 | 
			
		||||
        mesh_to_face_first_idx,
 | 
			
		||||
        num_faces_per_mesh,
 | 
			
		||||
        face_verts: torch.Tensor,
 | 
			
		||||
        mesh_to_face_first_idx: torch.Tensor,
 | 
			
		||||
        num_faces_per_mesh: torch.Tensor,
 | 
			
		||||
        clipped_faces_neighbor_idx: torch.Tensor,
 | 
			
		||||
        image_size: Union[List[int], Tuple[int, int]] = (256, 256),
 | 
			
		||||
        blur_radius: float = 0.01,
 | 
			
		||||
        faces_per_pixel: int = 0,
 | 
			
		||||
@ -227,12 +294,15 @@ class _RasterizeFaceVerts(torch.autograd.Function):
 | 
			
		||||
        perspective_correct: bool = False,
 | 
			
		||||
        clip_barycentric_coords: bool = False,
 | 
			
		||||
        cull_backfaces: bool = False,
 | 
			
		||||
        z_clip_value: Optional[float] = None,
 | 
			
		||||
        cull_to_frustum: bool = True,
 | 
			
		||||
    ):
 | 
			
		||||
        # pyre-fixme[16]: Module `pytorch3d` has no attribute `_C`.
 | 
			
		||||
        pix_to_face, zbuf, barycentric_coords, dists = _C.rasterize_meshes(
 | 
			
		||||
            face_verts,
 | 
			
		||||
            mesh_to_face_first_idx,
 | 
			
		||||
            num_faces_per_mesh,
 | 
			
		||||
            clipped_faces_neighbor_idx,
 | 
			
		||||
            image_size,
 | 
			
		||||
            blur_radius,
 | 
			
		||||
            faces_per_pixel,
 | 
			
		||||
@ -242,6 +312,7 @@ class _RasterizeFaceVerts(torch.autograd.Function):
 | 
			
		||||
            clip_barycentric_coords,
 | 
			
		||||
            cull_backfaces,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        ctx.save_for_backward(face_verts, pix_to_face)
 | 
			
		||||
        ctx.mark_non_differentiable(pix_to_face)
 | 
			
		||||
        ctx.perspective_correct = perspective_correct
 | 
			
		||||
@ -253,6 +324,7 @@ class _RasterizeFaceVerts(torch.autograd.Function):
 | 
			
		||||
        grad_face_verts = None
 | 
			
		||||
        grad_mesh_to_face_first_idx = None
 | 
			
		||||
        grad_num_faces_per_mesh = None
 | 
			
		||||
        grad_clipped_faces_neighbor_idx = None
 | 
			
		||||
        grad_image_size = None
 | 
			
		||||
        grad_radius = None
 | 
			
		||||
        grad_faces_per_pixel = None
 | 
			
		||||
@ -275,6 +347,7 @@ class _RasterizeFaceVerts(torch.autograd.Function):
 | 
			
		||||
            grad_face_verts,
 | 
			
		||||
            grad_mesh_to_face_first_idx,
 | 
			
		||||
            grad_num_faces_per_mesh,
 | 
			
		||||
            grad_clipped_faces_neighbor_idx,
 | 
			
		||||
            grad_image_size,
 | 
			
		||||
            grad_radius,
 | 
			
		||||
            grad_faces_per_pixel,
 | 
			
		||||
@ -339,6 +412,9 @@ def rasterize_meshes_python(
 | 
			
		||||
    perspective_correct: bool = False,
 | 
			
		||||
    clip_barycentric_coords: bool = False,
 | 
			
		||||
    cull_backfaces: bool = False,
 | 
			
		||||
    z_clip_value: Optional[float] = None,
 | 
			
		||||
    cull_to_frustum: bool = True,
 | 
			
		||||
    clipped_faces_neighbor_idx: Optional[torch.Tensor] = None,
 | 
			
		||||
):
 | 
			
		||||
    """
 | 
			
		||||
    Naive PyTorch implementation of mesh rasterization with the same inputs and
 | 
			
		||||
@ -359,6 +435,26 @@ def rasterize_meshes_python(
 | 
			
		||||
    mesh_to_face_first_idx = meshes.mesh_to_faces_packed_first_idx()
 | 
			
		||||
    num_faces_per_mesh = meshes.num_faces_per_mesh()
 | 
			
		||||
 | 
			
		||||
    if z_clip_value is not None or cull_to_frustum:
 | 
			
		||||
        # Cull faces outside the view frustum, and clip faces that are partially
 | 
			
		||||
        # behind the camera into the portion of the triangle in front of the
 | 
			
		||||
        # camera.  This may change the number of faces
 | 
			
		||||
        frustum = ClipFrustum(
 | 
			
		||||
            left=-1,
 | 
			
		||||
            right=1,
 | 
			
		||||
            top=-1,
 | 
			
		||||
            bottom=1,
 | 
			
		||||
            perspective_correct=perspective_correct,
 | 
			
		||||
            z_clip_value=z_clip_value,
 | 
			
		||||
            cull=cull_to_frustum,
 | 
			
		||||
        )
 | 
			
		||||
        clipped_faces = clip_faces(
 | 
			
		||||
            faces_verts, mesh_to_face_first_idx, num_faces_per_mesh, frustum=frustum
 | 
			
		||||
        )
 | 
			
		||||
        faces_verts = clipped_faces.face_verts
 | 
			
		||||
        mesh_to_face_first_idx = clipped_faces.mesh_to_face_first_idx
 | 
			
		||||
        num_faces_per_mesh = clipped_faces.num_faces_per_mesh
 | 
			
		||||
 | 
			
		||||
    # Initialize output tensors.
 | 
			
		||||
    face_idxs = torch.full(
 | 
			
		||||
        (N, H, W, K), fill_value=-1, dtype=torch.int64, device=device
 | 
			
		||||
@ -468,26 +564,55 @@ def rasterize_meshes_python(
 | 
			
		||||
                    # Points inside the triangle have negative distance.
 | 
			
		||||
                    dist = point_triangle_distance(pxy, v0[:2], v1[:2], v2[:2])
 | 
			
		||||
 | 
			
		||||
                    signed_dist = dist * -1.0 if inside else dist
 | 
			
		||||
 | 
			
		||||
                    # Add an epsilon to prevent errors when comparing distance
 | 
			
		||||
                    # to blur radius.
 | 
			
		||||
                    if not inside and dist >= blur_radius:
 | 
			
		||||
                        continue
 | 
			
		||||
 | 
			
		||||
                    top_k_points.append((pz, f, bary, signed_dist))
 | 
			
		||||
                    # Handle the case where a face (f) partially behind the image plane is
 | 
			
		||||
                    # clipped to a quadrilateral and then split into two faces (t1, t2).
 | 
			
		||||
                    top_k_idx = -1
 | 
			
		||||
                    if (
 | 
			
		||||
                        clipped_faces_neighbor_idx is not None
 | 
			
		||||
                        and clipped_faces_neighbor_idx[f] != -1
 | 
			
		||||
                    ):
 | 
			
		||||
                        neighbor_idx = clipped_faces_neighbor_idx[f]
 | 
			
		||||
                        # See if neighbor_idx is in top_k and find index
 | 
			
		||||
                        top_k_idx = [
 | 
			
		||||
                            i
 | 
			
		||||
                            for i, val in enumerate(top_k_points)
 | 
			
		||||
                            if val[1] == neighbor_idx
 | 
			
		||||
                        ]
 | 
			
		||||
                        top_k_idx = top_k_idx[0] if len(top_k_idx) > 0 else -1
 | 
			
		||||
 | 
			
		||||
                    if top_k_idx != -1 and dist < top_k_points[top_k_idx][3]:
 | 
			
		||||
                        # Overwrite the neighbor with current face info
 | 
			
		||||
                        top_k_points[top_k_idx] = (pz, f, bary, dist, inside)
 | 
			
		||||
                    else:
 | 
			
		||||
                        # Handle as a normal face
 | 
			
		||||
                        top_k_points.append((pz, f, bary, dist, inside))
 | 
			
		||||
 | 
			
		||||
                    top_k_points.sort()
 | 
			
		||||
                    if len(top_k_points) > K:
 | 
			
		||||
                        top_k_points = top_k_points[:K]
 | 
			
		||||
 | 
			
		||||
                # Save to output tensors.
 | 
			
		||||
                for k, (pz, f, bary, dist) in enumerate(top_k_points):
 | 
			
		||||
                for k, (pz, f, bary, dist, inside) in enumerate(top_k_points):
 | 
			
		||||
                    zbuf[n, yi, xi, k] = pz
 | 
			
		||||
                    face_idxs[n, yi, xi, k] = f
 | 
			
		||||
                    bary_coords[n, yi, xi, k, 0] = bary[0]
 | 
			
		||||
                    bary_coords[n, yi, xi, k, 1] = bary[1]
 | 
			
		||||
                    bary_coords[n, yi, xi, k, 2] = bary[2]
 | 
			
		||||
                    pix_dists[n, yi, xi, k] = dist
 | 
			
		||||
                    # Write the signed distance
 | 
			
		||||
                    pix_dists[n, yi, xi, k] = -dist if inside else dist
 | 
			
		||||
 | 
			
		||||
    if z_clip_value is not None or cull_to_frustum:
 | 
			
		||||
        # If faces were clipped, map the rasterization result to be in terms of the
 | 
			
		||||
        # original unclipped faces.  This may involve converting barycentric
 | 
			
		||||
        # coordinates
 | 
			
		||||
        (face_idxs, bary_coords,) = convert_clipped_rasterization_to_original_faces(
 | 
			
		||||
            face_idxs, bary_coords, clipped_faces
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    return face_idxs, zbuf, bary_coords, pix_dists
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
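The sign and blur-radius conventions used in both the C++ and Python paths (negative distance inside the face, positive outside, and faces outside the blur region skipped) can be condensed into a few lines; the numbers are made up for illustration:

blur_radius = 0.01

def keep_and_sign(dist, inside):
    # Mirrors the checks above: skip outside faces beyond the blur region,
    # otherwise store a signed distance (negative inside the face).
    if not inside and dist >= blur_radius:
        return False, None
    return True, (-dist if inside else dist)

print(keep_and_sign(0.002, inside=True))    # (True, -0.002)
print(keep_and_sign(0.005, inside=False))   # (True, 0.005)
print(keep_and_sign(0.020, inside=False))   # (False, None)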
@ -27,6 +27,8 @@ class RasterizationSettings:
 | 
			
		||||
        "perspective_correct",
 | 
			
		||||
        "clip_barycentric_coords",
 | 
			
		||||
        "cull_backfaces",
 | 
			
		||||
        "z_clip_value",
 | 
			
		||||
        "cull_to_frustum",
 | 
			
		||||
    ]
 | 
			
		||||
 | 
			
		||||
    def __init__(
 | 
			
		||||
@ -36,9 +38,13 @@ class RasterizationSettings:
 | 
			
		||||
        faces_per_pixel: int = 1,
 | 
			
		||||
        bin_size: Optional[int] = None,
 | 
			
		||||
        max_faces_per_bin: Optional[int] = None,
 | 
			
		||||
        perspective_correct: bool = False,
 | 
			
		||||
        # set perspective_correct = None so that the
 | 
			
		||||
        # value can be inferred correctly from the Camera type
 | 
			
		||||
        perspective_correct: Optional[bool] = None,
 | 
			
		||||
        clip_barycentric_coords: Optional[bool] = None,
 | 
			
		||||
        cull_backfaces: bool = False,
 | 
			
		||||
        z_clip_value: Optional[float] = None,
 | 
			
		||||
        cull_to_frustum: bool = False,
 | 
			
		||||
    ):
 | 
			
		||||
        self.image_size = image_size
 | 
			
		||||
        self.blur_radius = blur_radius
 | 
			
		||||
@ -48,6 +54,8 @@ class RasterizationSettings:
 | 
			
		||||
        self.perspective_correct = perspective_correct
 | 
			
		||||
        self.clip_barycentric_coords = clip_barycentric_coords
 | 
			
		||||
        self.cull_backfaces = cull_backfaces
 | 
			
		||||
        self.z_clip_value = z_clip_value
 | 
			
		||||
        self.cull_to_frustum = cull_to_frustum
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
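With perspective_correct now defaulting to None, MeshRasterizer infers it from the camera and, when z_clip_value is unset, derives a clip value from the camera's near plane (znear / 2), as the forward method below shows. A sketch of that inference rule in isolation (an approximation, not the class itself):

import torch

def infer_raster_flags(cameras, perspective_correct=None, z_clip_value=None):
    if perspective_correct is None:
        perspective_correct = cameras.is_perspective()
    if z_clip_value is None:
        znear = cameras.get_znear()
        if isinstance(znear, torch.Tensor):
            znear = znear.min().item()
        z_clip_value = None if not perspective_correct or znear is None else znear / 2
    return perspective_correct, z_clip_value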
class MeshRasterizer(nn.Module):
 | 
			
		||||
@ -139,12 +147,19 @@ class MeshRasterizer(nn.Module):
 | 
			
		||||
        if clip_barycentric_coords is None:
 | 
			
		||||
            clip_barycentric_coords = raster_settings.blur_radius > 0.0
 | 
			
		||||
 | 
			
		||||
        # If not specified, infer perspective_correct from the camera
 | 
			
		||||
        # If not specified, infer perspective_correct and z_clip_value from the camera
 | 
			
		||||
        cameras = kwargs.get("cameras", self.cameras)
 | 
			
		||||
        if raster_settings.perspective_correct is not None:
 | 
			
		||||
            perspective_correct = raster_settings.perspective_correct
 | 
			
		||||
        else:
 | 
			
		||||
            perspective_correct = cameras.is_perspective()
 | 
			
		||||
        if raster_settings.z_clip_value is not None:
 | 
			
		||||
            z_clip = raster_settings.z_clip_value
 | 
			
		||||
        else:
 | 
			
		||||
            znear = cameras.get_znear()
 | 
			
		||||
            if isinstance(znear, torch.Tensor):
 | 
			
		||||
                znear = znear.min().item()
 | 
			
		||||
            z_clip = None if not perspective_correct or znear is None else znear / 2
 | 
			
		||||
 | 
			
		||||
        pix_to_face, zbuf, bary_coords, dists = rasterize_meshes(
 | 
			
		||||
            meshes_screen,
 | 
			
		||||
@ -153,9 +168,11 @@ class MeshRasterizer(nn.Module):
 | 
			
		||||
            faces_per_pixel=raster_settings.faces_per_pixel,
 | 
			
		||||
            bin_size=raster_settings.bin_size,
 | 
			
		||||
            max_faces_per_bin=raster_settings.max_faces_per_bin,
 | 
			
		||||
            perspective_correct=perspective_correct,
 | 
			
		||||
            clip_barycentric_coords=clip_barycentric_coords,
 | 
			
		||||
            perspective_correct=perspective_correct,
 | 
			
		||||
            cull_backfaces=raster_settings.cull_backfaces,
 | 
			
		||||
            z_clip_value=z_clip,
 | 
			
		||||
            cull_to_frustum=raster_settings.cull_to_frustum,
 | 
			
		||||
        )
 | 
			
		||||
        return Fragments(
 | 
			
		||||
            pix_to_face=pix_to_face, zbuf=zbuf, bary_coords=bary_coords, dists=dists
 | 
			
		||||
 | 
			
		||||
@ -66,7 +66,7 @@ def bm_rasterize_meshes() -> None:
 | 
			
		||||
        # Square and non square cases
 | 
			
		||||
        image_size = [64, 128, 512, (512, 256), (256, 512)]
 | 
			
		||||
        blur = [1e-6]
 | 
			
		||||
        faces_per_pixel = [50]
 | 
			
		||||
        faces_per_pixel = [40]
 | 
			
		||||
        test_cases = product(num_meshes, ico_level, image_size, blur, faces_per_pixel)
 | 
			
		||||
 | 
			
		||||
        for case in test_cases:
 | 
			
		||||
@ -87,6 +87,35 @@ def bm_rasterize_meshes() -> None:
 | 
			
		||||
            warmup_iters=1,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        # Test a subset of the cases with the
 | 
			
		||||
        # image plane intersecting the mesh.
 | 
			
		||||
        kwargs_list = []
 | 
			
		||||
        num_meshes = [8, 16]
 | 
			
		||||
        # Square and non square cases
 | 
			
		||||
        image_size = [64, 128, 512, (512, 256), (256, 512)]
 | 
			
		||||
        dist = [3, 0.8, 0.5]
 | 
			
		||||
        test_cases = product(num_meshes, dist, image_size)
 | 
			
		||||
 | 
			
		||||
        for case in test_cases:
 | 
			
		||||
            n, d, im = case
 | 
			
		||||
            kwargs_list.append(
 | 
			
		||||
                {
 | 
			
		||||
                    "num_meshes": n,
 | 
			
		||||
                    "ico_level": 4,
 | 
			
		||||
                    "image_size": im,
 | 
			
		||||
                    "blur_radius": 1e-6,
 | 
			
		||||
                    "faces_per_pixel": 40,
 | 
			
		||||
                    "dist": d,
 | 
			
		||||
                }
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        benchmark(
 | 
			
		||||
            TestRasterizeMeshes.bm_rasterize_meshes_with_clipping,
 | 
			
		||||
            "RASTERIZE_MESHES_CUDA_CLIPPING",
 | 
			
		||||
            kwargs_list,
 | 
			
		||||
            warmup_iters=1,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    bm_rasterize_meshes()
BIN  tests/data/room.jpg  (new file; 7.8 KiB)
BIN  tests/data/test_render_mesh_clipped_cam_dist=0.5.jpg  (new file; 7.2 KiB)
@ -6,6 +6,8 @@ import unittest
 | 
			
		||||
import torch
 | 
			
		||||
from common_testing import TestCaseMixin, get_random_cuda_device
 | 
			
		||||
from pytorch3d import _C
 | 
			
		||||
from pytorch3d.renderer import FoVPerspectiveCameras, look_at_view_transform
 | 
			
		||||
from pytorch3d.renderer.mesh import MeshRasterizer, RasterizationSettings
 | 
			
		||||
from pytorch3d.renderer.mesh.rasterize_meshes import (
 | 
			
		||||
    rasterize_meshes,
 | 
			
		||||
    rasterize_meshes_python,
 | 
			
		||||
@ -1204,3 +1206,50 @@ class TestRasterizeMeshes(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
            torch.cuda.synchronize(device)
 | 
			
		||||
 | 
			
		||||
        return rasterize
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def bm_rasterize_meshes_with_clipping(
 | 
			
		||||
        num_meshes: int,
 | 
			
		||||
        ico_level: int,
 | 
			
		||||
        image_size: int,
 | 
			
		||||
        blur_radius: float,
 | 
			
		||||
        faces_per_pixel: int,
 | 
			
		||||
        dist: float,
 | 
			
		||||
    ):
 | 
			
		||||
        device = get_random_cuda_device()
 | 
			
		||||
        meshes = ico_sphere(ico_level, device)
 | 
			
		||||
        meshes_batch = meshes.extend(num_meshes)
 | 
			
		||||
 | 
			
		||||
        settings = RasterizationSettings(
 | 
			
		||||
            image_size=image_size,
 | 
			
		||||
            blur_radius=blur_radius,
 | 
			
		||||
            faces_per_pixel=faces_per_pixel,
 | 
			
		||||
            z_clip_value=1e-2,
 | 
			
		||||
            perspective_correct=True,
 | 
			
		||||
            cull_to_frustum=True,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        # The camera is positioned so that the image plane intersects
 | 
			
		||||
        # the mesh and some faces are partially behind the image plane.
 | 
			
		||||
        R, T = look_at_view_transform(dist, 0, 0)
 | 
			
		||||
        cameras = FoVPerspectiveCameras(device=device, R=R, T=T, fov=90)
 | 
			
		||||
        rasterizer = MeshRasterizer(raster_settings=settings, cameras=cameras)
 | 
			
		||||
 | 
			
		||||
        # Transform the meshes to project them onto the image plane
 | 
			
		||||
        meshes_screen = rasterizer.transform(meshes_batch)
 | 
			
		||||
        torch.cuda.synchronize(device)
 | 
			
		||||
 | 
			
		||||
        def rasterize():
 | 
			
		||||
            # Only measure rasterization speed (including clipping)
 | 
			
		||||
            rasterize_meshes(
 | 
			
		||||
                meshes_screen,
 | 
			
		||||
                image_size,
 | 
			
		||||
                blur_radius,
 | 
			
		||||
                faces_per_pixel,
 | 
			
		||||
                z_clip_value=1e-2,
 | 
			
		||||
                perspective_correct=True,
 | 
			
		||||
                cull_to_frustum=True,
 | 
			
		||||
            )
 | 
			
		||||
            torch.cuda.synchronize(device)
 | 
			
		||||
 | 
			
		||||
        return rasterize
 | 
			
		||||
 | 
			
		||||
@ -245,6 +245,7 @@ class TestRenderImplicit(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
                    image_size=image_size,
 | 
			
		||||
                    blur_radius=1e-3,
 | 
			
		||||
                    faces_per_pixel=10,
 | 
			
		||||
                    z_clip_value=None,
 | 
			
		||||
                    perspective_correct=False,
 | 
			
		||||
                ),
 | 
			
		||||
            ),
 | 
			
		||||
 | 
			
		||||
@ -994,6 +994,7 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
            blur_radius=0.0,
 | 
			
		||||
            faces_per_pixel=1,
 | 
			
		||||
            cull_backfaces=True,
 | 
			
		||||
            perspective_correct=False,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        # Init shader settings
 | 
			
		||||
 | 
			
		||||
@ -9,14 +9,161 @@ See pytorch3d/renderer/mesh/clip.py for more details about the
clipping process.
"""
import unittest
from pathlib import Path

import imageio
import numpy as np
import torch
from common_testing import TestCaseMixin
from pytorch3d.renderer.mesh import ClipFrustum, clip_faces
from common_testing import TestCaseMixin, load_rgb_image
from pytorch3d.io import save_obj
from pytorch3d.renderer.cameras import FoVPerspectiveCameras, look_at_view_transform
from pytorch3d.renderer.lighting import PointLights
from pytorch3d.renderer.mesh import (
    ClipFrustum,
    TexturesUV,
    clip_faces,
    convert_clipped_rasterization_to_original_faces,
)
from pytorch3d.renderer.mesh.rasterize_meshes import _RasterizeFaceVerts
from pytorch3d.renderer.mesh.rasterizer import MeshRasterizer, RasterizationSettings
from pytorch3d.renderer.mesh.renderer import MeshRenderer
from pytorch3d.renderer.mesh.shader import SoftPhongShader
from pytorch3d.structures.meshes import Meshes


# If DEBUG=True, save out images generated in the tests for debugging.
# All saved images have prefix DEBUG_
DEBUG = False
DATA_DIR = Path(__file__).resolve().parent / "data"


class TestRenderMeshesClipping(TestCaseMixin, unittest.TestCase):
    def load_cube_mesh_with_texture(self, device="cpu", with_grad: bool = False):
        verts = torch.tensor(
            [
                [-1, 1, 1],
                [1, 1, 1],
                [1, -1, 1],
                [-1, -1, 1],
                [-1, 1, -1],
                [1, 1, -1],
                [1, -1, -1],
                [-1, -1, -1],
            ],
            device=device,
            dtype=torch.float32,
            requires_grad=with_grad,
        )

        # all faces correctly wound
        faces = torch.tensor(
            [
                [0, 1, 4],
                [4, 1, 5],
                [1, 2, 5],
                [5, 2, 6],
                [2, 7, 6],
                [2, 3, 7],
                [3, 4, 7],
                [0, 4, 3],
                [4, 5, 6],
                [4, 6, 7],
            ],
            device=device,
            dtype=torch.int64,
        )

        verts_uvs = torch.tensor(
            [
                [
                    [0, 1],
                    [1, 1],
                    [1, 0],
                    [0, 0],
                    [0.204, 0.743],
                    [0.781, 0.743],
                    [0.781, 0.154],
                    [0.204, 0.154],
                ]
            ],
            device=device,
            dtype=torch.float,
        )
        texture_map = load_rgb_image("room.jpg", DATA_DIR).to(device)
        textures = TexturesUV(
            maps=[texture_map], faces_uvs=faces.unsqueeze(0), verts_uvs=verts_uvs
        )
        mesh = Meshes([verts], [faces], textures=textures)
        if with_grad:
            return mesh, verts
        return mesh

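The fixture above notes that all the cube faces are correctly wound. As a minimal sketch of what that means, assuming the `mesh` returned by `load_cube_mesh_with_texture` and a cube centred at the origin (`faces_wound_outward` is a hypothetical helper, not something defined in this commit), the winding can be checked by confirming that each per-face normal points the same way as its face centroid:

import torch

def faces_wound_outward(mesh) -> bool:
    # Sketch only: assumes a closed mesh centred at the origin, such as the
    # cube built by load_cube_mesh_with_texture above.
    verts = mesh.verts_packed()   # (V, 3)
    faces = mesh.faces_packed()   # (F, 3)
    v0, v1, v2 = verts[faces[:, 0]], verts[faces[:, 1]], verts[faces[:, 2]]
    normals = torch.cross(v1 - v0, v2 - v0, dim=1)  # per-face normals
    centroids = (v0 + v1 + v2) / 3.0                # point away from the origin
    # Outward winding <=> each normal points the same way as its centroid.
    return bool((normals * centroids).sum(dim=1).gt(0).all())

With the vertex order used in the fixture, every one of the ten faces passes this check.
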
    def test_cube_mesh_render(self):
        """
        End-to-end test of rendering a cube mesh with texture
        from decreasing camera distances. The camera starts
        outside the cube and enters the inside of the cube.
        """
        device = torch.device("cuda:0")
        mesh = self.load_cube_mesh_with_texture(device)
        raster_settings = RasterizationSettings(
            image_size=512,
            blur_radius=1e-8,
            faces_per_pixel=5,
            z_clip_value=1e-2,
            perspective_correct=True,
            bin_size=0,
        )

        # Only ambient, no diffuse or specular
        lights = PointLights(
            device=device,
            ambient_color=((1.0, 1.0, 1.0),),
            diffuse_color=((0.0, 0.0, 0.0),),
            specular_color=((0.0, 0.0, 0.0),),
            location=[[0.0, 0.0, -3.0]],
        )

        renderer = MeshRenderer(
            rasterizer=MeshRasterizer(raster_settings=raster_settings),
            shader=SoftPhongShader(device=device, lights=lights),
        )

        # Render the cube by decreasing the distance from the camera until
        # the camera enters the cube. Check the output looks correct.
        images_list = []
        dists = np.linspace(0.1, 2.5, 20)[::-1]
        for d in dists:
            R, T = look_at_view_transform(d, 0, 0)
            T[0, 1] -= 0.1  # move down in the y axis
            cameras = FoVPerspectiveCameras(device=device, R=R, T=T, fov=90)
            images = renderer(mesh, cameras=cameras)
            rgb = images[0, ..., :3].cpu().detach()
            filename = "DEBUG_cube_dist=%.1f.jpg" % d
            im = (rgb.numpy() * 255).astype(np.uint8)
            images_list.append(im)

            # Check one of the images where the camera is inside the mesh
            if d == 0.5:
                filename = "test_render_mesh_clipped_cam_dist=0.5.jpg"
                image_ref = load_rgb_image(filename, DATA_DIR)
                self.assertClose(rgb, image_ref, atol=0.05)

        # Save a gif of the output - this should show
        # the camera moving inside the cube.
        if DEBUG:
            gif_filename = (
                "room_original.gif"
                if raster_settings.z_clip_value is None
                else "room_clipped.gif"
            )
            imageio.mimsave(DATA_DIR / gif_filename, images_list, fps=2)
            save_obj(
                f=DATA_DIR / "cube.obj",
                verts=mesh.verts_packed().cpu(),
                faces=mesh.faces_packed().cpu(),
            )

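As a rough companion to the end-to-end render test above, the cube can also be rasterized directly over the same camera sweep to see how much of the image stays covered once the camera is inside the mesh. This is a sketch that reuses only APIs imported in this file; `coverage_per_distance` and its 128-pixel image size are hypothetical, not anything defined in the commit.

import numpy as np
from pytorch3d.renderer.cameras import FoVPerspectiveCameras, look_at_view_transform
from pytorch3d.renderer.mesh.rasterizer import MeshRasterizer, RasterizationSettings

def coverage_per_distance(mesh, device, dists=np.linspace(0.1, 2.5, 20)[::-1]):
    # Fraction of pixels hit by at least one face at each camera distance.
    settings = RasterizationSettings(
        image_size=128, faces_per_pixel=1, z_clip_value=1e-2, bin_size=0
    )
    coverage = []
    for d in dists:
        R, T = look_at_view_transform(d, 0, 0)
        cameras = FoVPerspectiveCameras(device=device, R=R, T=T, fov=90)
        rasterizer = MeshRasterizer(cameras=cameras, raster_settings=settings)
        fragments = rasterizer(mesh)
        # pix_to_face is -1 wherever no face covers the pixel.
        hit = fragments.pix_to_face[..., 0] >= 0
        coverage.append(hit.float().mean().item())
    return coverage

The (hedged) expectation is that coverage stays high even at the smallest distances, because faces straddling the image plane are clipped and re-triangulated rather than discarded.
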
    @staticmethod
    def clip_faces(meshes):
        verts_packed = meshes.verts_packed()
@ -42,6 +189,34 @@ class TestRenderMeshesClipping(TestCaseMixin, unittest.TestCase):
        )
        return clipped_faces

    def test_grad(self):
        """
        Check that gradient flow is unaffected when the camera is inside the mesh
        """
        device = torch.device("cuda:0")
        mesh, verts = self.load_cube_mesh_with_texture(device=device, with_grad=True)
        raster_settings = RasterizationSettings(
            image_size=512,
            blur_radius=1e-5,
            faces_per_pixel=5,
            z_clip_value=1e-2,
            perspective_correct=True,
            bin_size=0,
        )

        renderer = MeshRenderer(
            rasterizer=MeshRasterizer(raster_settings=raster_settings),
            shader=SoftPhongShader(device=device),
        )
        dist = 0.4  # Camera is inside the cube
        R, T = look_at_view_transform(dist, 0, 0)
        cameras = FoVPerspectiveCameras(device=device, R=R, T=T, fov=90)
        images = renderer(mesh, cameras=cameras)
        images.sum().backward()

        # Check gradients exist
        self.assertIsNotNone(verts.grad)

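test_grad only asserts that verts.grad exists. A slightly stricter follow-up, sketched below with a hypothetical helper `assert_useful_grad`, would also confirm that the gradient is finite and not identically zero, which is closer to what "gradient flow is unaffected" means in practice.

import torch

def assert_useful_grad(grad: torch.Tensor) -> None:
    # Sketch only: a gradient that is all zeros or contains NaN/Inf would
    # suggest that clipping broke the backward pass.
    assert grad is not None
    assert torch.isfinite(grad).all(), "gradient contains NaN or Inf"
    assert grad.abs().sum() > 0, "gradient is identically zero"
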
    def test_case_1(self):
        """
        Case 1: Single triangle fully in front of the image plane (z=0)
@ -350,3 +525,134 @@ class TestRenderMeshesClipping(TestCaseMixin, unittest.TestCase):
        # barycentric conversion matrix.
        bary_idx = idx.new_tensor([1, 4, 2, 5, 0, -1, 3, 6])
        self.assertClose(clipped_faces.faces_clipped_to_conversion_idx, bary_idx)

    def test_convert_clipped_to_unclipped_case_4(self):
        """
        Test with a single case 4 triangle which is clipped into
        a quadrilateral and subdivided.
        """
        device = "cuda:0"
        # fmt: off
        verts = torch.tensor(
            [
                [-1.0,  0.0, -1.0],  # noqa: E241, E201
                [ 0.0,  1.0, -1.0],  # noqa: E241, E201
                [ 1.0,  0.0, -1.0],  # noqa: E241, E201
                [ 0.0, -1.0, -1.0],  # noqa: E241, E201
                [-1.0,  0.5,  0.5],  # noqa: E241, E201
                [ 1.0,  1.0,  1.0],  # noqa: E241, E201
                [ 0.0, -1.0,  1.0],  # noqa: E241, E201
                [-1.0,  0.5, -0.5],  # noqa: E241, E201
                [ 1.0,  1.0, -1.0],  # noqa: E241, E201
                [-1.0,  0.0,  1.0],  # noqa: E241, E201
                [ 0.0,  1.0,  1.0],  # noqa: E241, E201
                [ 1.0,  0.0,  1.0],  # noqa: E241, E201
            ],
            dtype=torch.float32,
            device=device,
        )
        faces = torch.tensor(
            [
                [0,  1,  2],  # noqa: E241, E201  Case 2 fully clipped
                [3,  4,  5],  # noqa: E241, E201  Case 4 clipped and subdivided
                [5,  4,  3],  # noqa: E241, E201  Repeat of Case 4
                [6,  7,  8],  # noqa: E241, E201  Case 3 clipped
                [9, 10, 11],  # noqa: E241, E201  Case 1 untouched
            ],
            dtype=torch.int64,
            device=device,
        )
        # fmt: on
        meshes = Meshes(verts=[verts], faces=[faces])

        # Clip meshes
        clipped_faces = self.clip_faces(meshes)

        # 2 Case 4 faces, each split into 2 sub triangles (4 faces)
        # + 1 (from Case 3) + 1 (from Case 1)
        self.assertEqual(clipped_faces.face_verts.shape[0], 6)

        image_size = (10, 10)
        blur_radius = 0.05
        faces_per_pixel = 2
        perspective_correct = True
        bin_size = 0
        max_faces_per_bin = 20
        clip_barycentric_coords = False
        cull_backfaces = False

        # Rasterize clipped mesh
        pix_to_face, zbuf, barycentric_coords, dists = _RasterizeFaceVerts.apply(
            clipped_faces.face_verts,
            clipped_faces.mesh_to_face_first_idx,
            clipped_faces.num_faces_per_mesh,
            clipped_faces.clipped_faces_neighbor_idx,
            image_size,
            blur_radius,
            faces_per_pixel,
            bin_size,
            max_faces_per_bin,
            perspective_correct,
            clip_barycentric_coords,
            cull_backfaces,
        )

        # Convert outputs so they are in terms of the unclipped mesh.
        outputs = convert_clipped_rasterization_to_original_faces(
            pix_to_face,
            barycentric_coords,
            clipped_faces,
        )
        pix_to_face_unclipped, barycentric_coords_unclipped = outputs

        # In the clipped mesh there are more faces than in the unclipped mesh
        self.assertTrue(pix_to_face.max() > pix_to_face_unclipped.max())
        # Unclipped pix_to_face indices must be less than the number
        # of faces in the unclipped mesh.
        self.assertTrue(pix_to_face_unclipped.max() < faces.shape[0])

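Conceptually, the conversion verified above is an index remap: every pix_to_face entry that points at a clipped or subdivided face is relabelled with the face it originated from in the unclipped mesh, while the -1 fill value for empty pixels is preserved. The sketch below illustrates that idea with a hypothetical mapping tensor `clipped_to_unclipped`; it is not the library's implementation of convert_clipped_rasterization_to_original_faces.

import torch

def remap_pix_to_face(pix_to_face: torch.Tensor,
                      clipped_to_unclipped: torch.Tensor) -> torch.Tensor:
    # pix_to_face: (N, H, W, K) indices into the *clipped* face list, -1 = empty.
    # clipped_to_unclipped: (F_clipped,) original-face index for each clipped
    # face (hypothetical name for the mapping used by the conversion step).
    out = pix_to_face.clone()
    valid = pix_to_face >= 0
    out[valid] = clipped_to_unclipped[pix_to_face[valid]]
    return out
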
    def test_case_4_no_duplicates(self):
        """
        In the case of a simple mesh with one face that is cut by the image
        plane into a quadrilateral, there shouldn't be duplicate indices of
        the face in the pix_to_face output of rasterization.
        """
        for (device, bin_size) in [("cpu", 0), ("cuda:0", 0), ("cuda:0", None)]:
            verts = torch.tensor(
                [[0.0, -10.0, 1.0], [-1.0, 2.0, -2.0], [1.0, 5.0, -10.0]],
                dtype=torch.float32,
                device=device,
            )
            faces = torch.tensor(
                [
                    [0, 1, 2],
                ],
                dtype=torch.int64,
                device=device,
            )
            meshes = Meshes(verts=[verts], faces=[faces])
            k = 3
            settings = RasterizationSettings(
                image_size=10,
                blur_radius=0.05,
                faces_per_pixel=k,
                z_clip_value=1e-2,
                perspective_correct=True,
                cull_to_frustum=True,
                bin_size=bin_size,
            )

            # The camera is positioned so that the image plane cuts
            # the mesh face into a quadrilateral.
            R, T = look_at_view_transform(0.2, 0, 0)
            cameras = FoVPerspectiveCameras(device=device, R=R, T=T, fov=90)
            rasterizer = MeshRasterizer(raster_settings=settings, cameras=cameras)
            fragments = rasterizer(meshes)

            p2f = fragments.pix_to_face.reshape(-1, k)
            unique_vals, idx_counts = p2f.unique(dim=0, return_counts=True)
            # There is only one face in this mesh, so if it hits a pixel
            # it can only be at position k = 0.
            # For any pixel, the values [0, 0, -1] for the top K faces should
            # not be possible.
            double_hit = torch.tensor([0, 0, -1], device=device)
            check_double_hit = any(torch.allclose(i, double_hit) for i in unique_vals)
            self.assertFalse(check_double_hit)

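The same no-duplicates invariant can be stated generically: after conversion back to the unclipped mesh, a pixel's top K entries should never contain the same face index twice, while the -1 fill value may repeat freely. A small self-contained sketch of such a check (`has_duplicate_faces_per_pixel` is a hypothetical name, not part of this commit):

import torch

def has_duplicate_faces_per_pixel(pix_to_face: torch.Tensor) -> bool:
    # pix_to_face: (N, H, W, K) with -1 marking empty slots.
    k = pix_to_face.shape[-1]
    p2f = pix_to_face.reshape(-1, k)
    for row in p2f:
        hits = row[row >= 0]
        if hits.numel() != hits.unique().numel():
            return True  # some face index appears twice for this pixel
    return False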