mirror of
				https://github.com/facebookresearch/pytorch3d.git
				synced 2025-11-04 18:02:14 +08:00 
			
		
		
		
	barycentric clipping in cuda/c++
Summary: Added support for barycentric clipping in the C++/CUDA rasterization kernels which can be switched on/off via a rasterization setting. Added tests and a benchmark to compare with the current implementation in PyTorch - for some cases of large image size/faces per pixel the cuda version is 10x faster. Reviewed By: gkioxari Differential Revision: D21705503 fbshipit-source-id: e835c0f927f1e5088ca89020aef5ff27ac3a8769
This commit is contained in:
		
							parent
							
								
									bce396df93
								
							
						
					
					
						commit
						cc70950f40
					
				@ -114,6 +114,7 @@ __device__ void CheckPixelInsideFace(
 | 
			
		||||
    const float2 pxy, // Coordinates of the pixel
 | 
			
		||||
    const int K,
 | 
			
		||||
    const bool perspective_correct,
 | 
			
		||||
    const bool clip_barycentric_coords,
 | 
			
		||||
    const bool cull_backfaces) {
 | 
			
		||||
  const auto v012 = GetSingleFaceVerts(face_verts, face_idx);
 | 
			
		||||
  const float3 v0 = thrust::get<0>(v012);
 | 
			
		||||
@ -149,8 +150,12 @@ __device__ void CheckPixelInsideFace(
 | 
			
		||||
  const float3 p_bary = !perspective_correct
 | 
			
		||||
      ? p_bary0
 | 
			
		||||
      : BarycentricPerspectiveCorrectionForward(p_bary0, v0.z, v1.z, v2.z);
 | 
			
		||||
  const float3 p_bary_clip =
 | 
			
		||||
      !clip_barycentric_coords ? p_bary : BarycentricClipForward(p_bary);
 | 
			
		||||
 | 
			
		||||
  const float pz =
 | 
			
		||||
      p_bary_clip.x * v0.z + p_bary_clip.y * v1.z + p_bary_clip.z * v2.z;
 | 
			
		||||
 | 
			
		||||
  const float pz = p_bary.x * v0.z + p_bary.y * v1.z + p_bary.z * v2.z;
 | 
			
		||||
  if (pz < 0) {
 | 
			
		||||
    return; // Face is behind the image plane.
 | 
			
		||||
  }
 | 
			
		||||
@ -158,7 +163,8 @@ __device__ void CheckPixelInsideFace(
 | 
			
		||||
  // Get abs squared distance
 | 
			
		||||
  const float dist = PointTriangleDistanceForward(pxy, v0xy, v1xy, v2xy);
 | 
			
		||||
 | 
			
		||||
  // Use the bary coordinates to determine if the point is inside the face.
 | 
			
		||||
  // Use the unclipped bary coordinates to determine if the point is inside the
 | 
			
		||||
  // face.
 | 
			
		||||
  const bool inside = p_bary.x > 0.0f && p_bary.y > 0.0f && p_bary.z > 0.0f;
 | 
			
		||||
  const float signed_dist = inside ? -dist : dist;
 | 
			
		||||
 | 
			
		||||
@ -169,7 +175,7 @@ __device__ void CheckPixelInsideFace(
 | 
			
		||||
 | 
			
		||||
  if (q_size < K) {
 | 
			
		||||
    // Just insert it.
 | 
			
		||||
    q[q_size] = {pz, face_idx, signed_dist, p_bary};
 | 
			
		||||
    q[q_size] = {pz, face_idx, signed_dist, p_bary_clip};
 | 
			
		||||
    if (pz > q_max_z) {
 | 
			
		||||
      q_max_z = pz;
 | 
			
		||||
      q_max_idx = q_size;
 | 
			
		||||
@ -177,7 +183,7 @@ __device__ void CheckPixelInsideFace(
 | 
			
		||||
    q_size++;
 | 
			
		||||
  } else if (pz < q_max_z) {
 | 
			
		||||
    // Overwrite the old max, and find the new max.
 | 
			
		||||
    q[q_max_idx] = {pz, face_idx, signed_dist, p_bary};
 | 
			
		||||
    q[q_max_idx] = {pz, face_idx, signed_dist, p_bary_clip};
 | 
			
		||||
    q_max_z = pz;
 | 
			
		||||
    for (int i = 0; i < K; i++) {
 | 
			
		||||
      if (q[i].z > q_max_z) {
 | 
			
		||||
@ -198,6 +204,7 @@ __global__ void RasterizeMeshesNaiveCudaKernel(
 | 
			
		||||
    const int64_t* num_faces_per_mesh,
 | 
			
		||||
    const float blur_radius,
 | 
			
		||||
    const bool perspective_correct,
 | 
			
		||||
    const bool clip_barycentric_coords,
 | 
			
		||||
    const bool cull_backfaces,
 | 
			
		||||
    const int N,
 | 
			
		||||
    const int H,
 | 
			
		||||
@ -260,6 +267,7 @@ __global__ void RasterizeMeshesNaiveCudaKernel(
 | 
			
		||||
          pxy,
 | 
			
		||||
          K,
 | 
			
		||||
          perspective_correct,
 | 
			
		||||
          clip_barycentric_coords,
 | 
			
		||||
          cull_backfaces);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
@ -286,6 +294,7 @@ RasterizeMeshesNaiveCuda(
 | 
			
		||||
    const float blur_radius,
 | 
			
		||||
    const int num_closest,
 | 
			
		||||
    const bool perspective_correct,
 | 
			
		||||
    const bool clip_barycentric_coords,
 | 
			
		||||
    const bool cull_backfaces) {
 | 
			
		||||
  TORCH_CHECK(
 | 
			
		||||
      face_verts.ndimension() == 3 && face_verts.size(1) == 3 &&
 | 
			
		||||
@ -343,6 +352,7 @@ RasterizeMeshesNaiveCuda(
 | 
			
		||||
      num_faces_per_mesh.contiguous().data_ptr<int64_t>(),
 | 
			
		||||
      blur_radius,
 | 
			
		||||
      perspective_correct,
 | 
			
		||||
      clip_barycentric_coords,
 | 
			
		||||
      cull_backfaces,
 | 
			
		||||
      N,
 | 
			
		||||
      H,
 | 
			
		||||
@ -365,6 +375,7 @@ __global__ void RasterizeMeshesBackwardCudaKernel(
 | 
			
		||||
    const float* face_verts, // (F, 3, 3)
 | 
			
		||||
    const int64_t* pix_to_face, // (N, H, W, K)
 | 
			
		||||
    const bool perspective_correct,
 | 
			
		||||
    const bool clip_barycentric_coords,
 | 
			
		||||
    const int N,
 | 
			
		||||
    const int H,
 | 
			
		||||
    const int W,
 | 
			
		||||
@ -422,11 +433,15 @@ __global__ void RasterizeMeshesBackwardCudaKernel(
 | 
			
		||||
      const float3 grad_bary_upstream = make_float3(
 | 
			
		||||
          grad_bary_upstream_w0, grad_bary_upstream_w1, grad_bary_upstream_w2);
 | 
			
		||||
 | 
			
		||||
      const float3 bary0 = BarycentricCoordsForward(pxy, v0xy, v1xy, v2xy);
 | 
			
		||||
      const float3 bary = !perspective_correct
 | 
			
		||||
          ? bary0
 | 
			
		||||
          : BarycentricPerspectiveCorrectionForward(bary0, v0.z, v1.z, v2.z);
 | 
			
		||||
      const bool inside = bary.x > 0.0f && bary.y > 0.0f && bary.z > 0.0f;
 | 
			
		||||
      const float3 b_w = BarycentricCoordsForward(pxy, v0xy, v1xy, v2xy);
 | 
			
		||||
      const float3 b_pp = !perspective_correct
 | 
			
		||||
          ? b_w
 | 
			
		||||
          : BarycentricPerspectiveCorrectionForward(b_w, v0.z, v1.z, v2.z);
 | 
			
		||||
 | 
			
		||||
      const float3 b_w_clip =
 | 
			
		||||
          !clip_barycentric_coords ? b_pp : BarycentricClipForward(b_pp);
 | 
			
		||||
 | 
			
		||||
      const bool inside = b_pp.x > 0.0f && b_pp.y > 0.0f && b_pp.z > 0.0f;
 | 
			
		||||
      const float sign = inside ? -1.0f : 1.0f;
 | 
			
		||||
 | 
			
		||||
      // TODO(T52813608) Add support for non-square images.
 | 
			
		||||
@ -442,22 +457,29 @@ __global__ void RasterizeMeshesBackwardCudaKernel(
 | 
			
		||||
      // d_zbuf/d_bary_w0 = z0
 | 
			
		||||
      // d_zbuf/d_bary_w1 = z1
 | 
			
		||||
      // d_zbuf/d_bary_w2 = z2
 | 
			
		||||
      const float3 d_zbuf_d_bary = make_float3(v0.z, v1.z, v2.z);
 | 
			
		||||
      const float3 d_zbuf_d_bwclip = make_float3(v0.z, v1.z, v2.z);
 | 
			
		||||
 | 
			
		||||
      // Total upstream barycentric gradients are the sum of
 | 
			
		||||
      // external upstream gradients and contribution from zbuf.
 | 
			
		||||
      const float3 grad_bary_f_sum =
 | 
			
		||||
          (grad_bary_upstream + grad_zbuf_upstream * d_zbuf_d_bary);
 | 
			
		||||
          (grad_bary_upstream + grad_zbuf_upstream * d_zbuf_d_bwclip);
 | 
			
		||||
 | 
			
		||||
      float3 grad_bary0 = grad_bary_f_sum;
 | 
			
		||||
 | 
			
		||||
      if (clip_barycentric_coords) {
 | 
			
		||||
        grad_bary0 = BarycentricClipBackward(b_w, grad_bary_f_sum);
 | 
			
		||||
      }
 | 
			
		||||
 | 
			
		||||
      float dz0_persp = 0.0f, dz1_persp = 0.0f, dz2_persp = 0.0f;
 | 
			
		||||
      if (perspective_correct) {
 | 
			
		||||
        auto perspective_grads = BarycentricPerspectiveCorrectionBackward(
 | 
			
		||||
            bary0, v0.z, v1.z, v2.z, grad_bary_f_sum);
 | 
			
		||||
            b_w, v0.z, v1.z, v2.z, grad_bary0);
 | 
			
		||||
        grad_bary0 = thrust::get<0>(perspective_grads);
 | 
			
		||||
        dz0_persp = thrust::get<1>(perspective_grads);
 | 
			
		||||
        dz1_persp = thrust::get<2>(perspective_grads);
 | 
			
		||||
        dz2_persp = thrust::get<3>(perspective_grads);
 | 
			
		||||
      }
 | 
			
		||||
 | 
			
		||||
      auto grad_bary_f =
 | 
			
		||||
          BarycentricCoordsBackward(pxy, v0xy, v1xy, v2xy, grad_bary0);
 | 
			
		||||
      const float2 dbary_d_v0 = thrust::get<1>(grad_bary_f);
 | 
			
		||||
@ -467,15 +489,18 @@ __global__ void RasterizeMeshesBackwardCudaKernel(
 | 
			
		||||
      atomicAdd(grad_face_verts + f * 9 + 0, dbary_d_v0.x + ddist_d_v0.x);
 | 
			
		||||
      atomicAdd(grad_face_verts + f * 9 + 1, dbary_d_v0.y + ddist_d_v0.y);
 | 
			
		||||
      atomicAdd(
 | 
			
		||||
          grad_face_verts + f * 9 + 2, grad_zbuf_upstream * bary.x + dz0_persp);
 | 
			
		||||
          grad_face_verts + f * 9 + 2,
 | 
			
		||||
          grad_zbuf_upstream * b_w_clip.x + dz0_persp);
 | 
			
		||||
      atomicAdd(grad_face_verts + f * 9 + 3, dbary_d_v1.x + ddist_d_v1.x);
 | 
			
		||||
      atomicAdd(grad_face_verts + f * 9 + 4, dbary_d_v1.y + ddist_d_v1.y);
 | 
			
		||||
      atomicAdd(
 | 
			
		||||
          grad_face_verts + f * 9 + 5, grad_zbuf_upstream * bary.y + dz1_persp);
 | 
			
		||||
          grad_face_verts + f * 9 + 5,
 | 
			
		||||
          grad_zbuf_upstream * b_w_clip.y + dz1_persp);
 | 
			
		||||
      atomicAdd(grad_face_verts + f * 9 + 6, dbary_d_v2.x + ddist_d_v2.x);
 | 
			
		||||
      atomicAdd(grad_face_verts + f * 9 + 7, dbary_d_v2.y + ddist_d_v2.y);
 | 
			
		||||
      atomicAdd(
 | 
			
		||||
          grad_face_verts + f * 9 + 8, grad_zbuf_upstream * bary.z + dz2_persp);
 | 
			
		||||
          grad_face_verts + f * 9 + 8,
 | 
			
		||||
          grad_zbuf_upstream * b_w_clip.z + dz2_persp);
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
@ -486,7 +511,8 @@ at::Tensor RasterizeMeshesBackwardCuda(
 | 
			
		||||
    const at::Tensor& grad_zbuf, // (N, H, W, K)
 | 
			
		||||
    const at::Tensor& grad_bary, // (N, H, W, K, 3)
 | 
			
		||||
    const at::Tensor& grad_dists, // (N, H, W, K)
 | 
			
		||||
    const bool perspective_correct) {
 | 
			
		||||
    const bool perspective_correct,
 | 
			
		||||
    const bool clip_barycentric_coords) {
 | 
			
		||||
  // Check inputs are on the same device
 | 
			
		||||
  at::TensorArg face_verts_t{face_verts, "face_verts", 1},
 | 
			
		||||
      pix_to_face_t{pix_to_face, "pix_to_face", 2},
 | 
			
		||||
@ -523,6 +549,7 @@ at::Tensor RasterizeMeshesBackwardCuda(
 | 
			
		||||
      face_verts.contiguous().data_ptr<float>(),
 | 
			
		||||
      pix_to_face.contiguous().data_ptr<int64_t>(),
 | 
			
		||||
      perspective_correct,
 | 
			
		||||
      clip_barycentric_coords,
 | 
			
		||||
      N,
 | 
			
		||||
      H,
 | 
			
		||||
      W,
 | 
			
		||||
@ -743,6 +770,7 @@ __global__ void RasterizeMeshesFineCudaKernel(
 | 
			
		||||
    const float blur_radius,
 | 
			
		||||
    const int bin_size,
 | 
			
		||||
    const bool perspective_correct,
 | 
			
		||||
    const bool clip_barycentric_coords,
 | 
			
		||||
    const bool cull_backfaces,
 | 
			
		||||
    const int N,
 | 
			
		||||
    const int B,
 | 
			
		||||
@ -808,6 +836,7 @@ __global__ void RasterizeMeshesFineCudaKernel(
 | 
			
		||||
          pxy,
 | 
			
		||||
          K,
 | 
			
		||||
          perspective_correct,
 | 
			
		||||
          clip_barycentric_coords,
 | 
			
		||||
          cull_backfaces);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
@ -841,6 +870,7 @@ RasterizeMeshesFineCuda(
 | 
			
		||||
    const int bin_size,
 | 
			
		||||
    const int faces_per_pixel,
 | 
			
		||||
    const bool perspective_correct,
 | 
			
		||||
    const bool clip_barycentric_coords,
 | 
			
		||||
    const bool cull_backfaces) {
 | 
			
		||||
  TORCH_CHECK(
 | 
			
		||||
      face_verts.ndimension() == 3 && face_verts.size(1) == 3 &&
 | 
			
		||||
@ -890,6 +920,7 @@ RasterizeMeshesFineCuda(
 | 
			
		||||
      blur_radius,
 | 
			
		||||
      bin_size,
 | 
			
		||||
      perspective_correct,
 | 
			
		||||
      clip_barycentric_coords,
 | 
			
		||||
      cull_backfaces,
 | 
			
		||||
      N,
 | 
			
		||||
      B,
 | 
			
		||||
 | 
			
		||||
@ -19,6 +19,7 @@ RasterizeMeshesNaiveCpu(
 | 
			
		||||
    const float blur_radius,
 | 
			
		||||
    const int faces_per_pixel,
 | 
			
		||||
    const bool perspective_correct,
 | 
			
		||||
    const bool clip_barycentric_coords,
 | 
			
		||||
    const bool cull_backfaces);
 | 
			
		||||
 | 
			
		||||
#ifdef WITH_CUDA
 | 
			
		||||
@ -31,6 +32,7 @@ RasterizeMeshesNaiveCuda(
 | 
			
		||||
    const float blur_radius,
 | 
			
		||||
    const int num_closest,
 | 
			
		||||
    const bool perspective_correct,
 | 
			
		||||
    const bool clip_barycentric_coords,
 | 
			
		||||
    const bool cull_backfaces);
 | 
			
		||||
#endif
 | 
			
		||||
// Forward pass for rasterizing a batch of meshes.
 | 
			
		||||
@ -92,6 +94,7 @@ RasterizeMeshesNaive(
 | 
			
		||||
    const float blur_radius,
 | 
			
		||||
    const int faces_per_pixel,
 | 
			
		||||
    const bool perspective_correct,
 | 
			
		||||
    const bool clip_barycentric_coords,
 | 
			
		||||
    const bool cull_backfaces) {
 | 
			
		||||
  // TODO: Better type checking.
 | 
			
		||||
  if (face_verts.is_cuda()) {
 | 
			
		||||
@ -107,6 +110,7 @@ RasterizeMeshesNaive(
 | 
			
		||||
        blur_radius,
 | 
			
		||||
        faces_per_pixel,
 | 
			
		||||
        perspective_correct,
 | 
			
		||||
        clip_barycentric_coords,
 | 
			
		||||
        cull_backfaces);
 | 
			
		||||
#else
 | 
			
		||||
    AT_ERROR("Not compiled with GPU support");
 | 
			
		||||
@ -120,6 +124,7 @@ RasterizeMeshesNaive(
 | 
			
		||||
        blur_radius,
 | 
			
		||||
        faces_per_pixel,
 | 
			
		||||
        perspective_correct,
 | 
			
		||||
        clip_barycentric_coords,
 | 
			
		||||
        cull_backfaces);
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
@ -134,7 +139,8 @@ torch::Tensor RasterizeMeshesBackwardCpu(
 | 
			
		||||
    const torch::Tensor& grad_bary,
 | 
			
		||||
    const torch::Tensor& grad_zbuf,
 | 
			
		||||
    const torch::Tensor& grad_dists,
 | 
			
		||||
    const bool perspective_correct);
 | 
			
		||||
    const bool perspective_correct,
 | 
			
		||||
    const bool clip_barycentric_coords);
 | 
			
		||||
 | 
			
		||||
#ifdef WITH_CUDA
 | 
			
		||||
torch::Tensor RasterizeMeshesBackwardCuda(
 | 
			
		||||
@ -143,7 +149,8 @@ torch::Tensor RasterizeMeshesBackwardCuda(
 | 
			
		||||
    const torch::Tensor& grad_bary,
 | 
			
		||||
    const torch::Tensor& grad_zbuf,
 | 
			
		||||
    const torch::Tensor& grad_dists,
 | 
			
		||||
    const bool perspective_correct);
 | 
			
		||||
    const bool perspective_correct,
 | 
			
		||||
    const bool clip_barycentric_coords);
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
// Args:
 | 
			
		||||
@ -176,7 +183,8 @@ torch::Tensor RasterizeMeshesBackward(
 | 
			
		||||
    const torch::Tensor& grad_zbuf,
 | 
			
		||||
    const torch::Tensor& grad_bary,
 | 
			
		||||
    const torch::Tensor& grad_dists,
 | 
			
		||||
    const bool perspective_correct) {
 | 
			
		||||
    const bool perspective_correct,
 | 
			
		||||
    const bool clip_barycentric_coords) {
 | 
			
		||||
  if (face_verts.is_cuda()) {
 | 
			
		||||
#ifdef WITH_CUDA
 | 
			
		||||
    CHECK_CUDA(face_verts);
 | 
			
		||||
@ -190,7 +198,8 @@ torch::Tensor RasterizeMeshesBackward(
 | 
			
		||||
        grad_zbuf,
 | 
			
		||||
        grad_bary,
 | 
			
		||||
        grad_dists,
 | 
			
		||||
        perspective_correct);
 | 
			
		||||
        perspective_correct,
 | 
			
		||||
        clip_barycentric_coords);
 | 
			
		||||
#else
 | 
			
		||||
    AT_ERROR("Not compiled with GPU support");
 | 
			
		||||
#endif
 | 
			
		||||
@ -201,7 +210,8 @@ torch::Tensor RasterizeMeshesBackward(
 | 
			
		||||
        grad_zbuf,
 | 
			
		||||
        grad_bary,
 | 
			
		||||
        grad_dists,
 | 
			
		||||
        perspective_correct);
 | 
			
		||||
        perspective_correct,
 | 
			
		||||
        clip_barycentric_coords);
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -300,6 +310,7 @@ RasterizeMeshesFineCuda(
 | 
			
		||||
    const int bin_size,
 | 
			
		||||
    const int faces_per_pixel,
 | 
			
		||||
    const bool perspective_correct,
 | 
			
		||||
    const bool clip_barycentric_coords,
 | 
			
		||||
    const bool cull_backfaces);
 | 
			
		||||
#endif
 | 
			
		||||
// Args:
 | 
			
		||||
@ -356,6 +367,7 @@ RasterizeMeshesFine(
 | 
			
		||||
    const int bin_size,
 | 
			
		||||
    const int faces_per_pixel,
 | 
			
		||||
    const bool perspective_correct,
 | 
			
		||||
    const bool clip_barycentric_coords,
 | 
			
		||||
    const bool cull_backfaces) {
 | 
			
		||||
  if (face_verts.is_cuda()) {
 | 
			
		||||
#ifdef WITH_CUDA
 | 
			
		||||
@ -369,6 +381,7 @@ RasterizeMeshesFine(
 | 
			
		||||
        bin_size,
 | 
			
		||||
        faces_per_pixel,
 | 
			
		||||
        perspective_correct,
 | 
			
		||||
        clip_barycentric_coords,
 | 
			
		||||
        cull_backfaces);
 | 
			
		||||
#else
 | 
			
		||||
    AT_ERROR("Not compiled with GPU support");
 | 
			
		||||
@ -446,6 +459,7 @@ RasterizeMeshes(
 | 
			
		||||
    const int bin_size,
 | 
			
		||||
    const int max_faces_per_bin,
 | 
			
		||||
    const bool perspective_correct,
 | 
			
		||||
    const bool clip_barycentric_coords,
 | 
			
		||||
    const bool cull_backfaces) {
 | 
			
		||||
  if (bin_size > 0 && max_faces_per_bin > 0) {
 | 
			
		||||
    // Use coarse-to-fine rasterization
 | 
			
		||||
@ -465,6 +479,7 @@ RasterizeMeshes(
 | 
			
		||||
        bin_size,
 | 
			
		||||
        faces_per_pixel,
 | 
			
		||||
        perspective_correct,
 | 
			
		||||
        clip_barycentric_coords,
 | 
			
		||||
        cull_backfaces);
 | 
			
		||||
  } else {
 | 
			
		||||
    // Use the naive per-pixel implementation
 | 
			
		||||
@ -476,6 +491,7 @@ RasterizeMeshes(
 | 
			
		||||
        blur_radius,
 | 
			
		||||
        faces_per_pixel,
 | 
			
		||||
        perspective_correct,
 | 
			
		||||
        clip_barycentric_coords,
 | 
			
		||||
        cull_backfaces);
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -108,6 +108,7 @@ RasterizeMeshesNaiveCpu(
 | 
			
		||||
    const float blur_radius,
 | 
			
		||||
    const int faces_per_pixel,
 | 
			
		||||
    const bool perspective_correct,
 | 
			
		||||
    const bool clip_barycentric_coords,
 | 
			
		||||
    const bool cull_backfaces) {
 | 
			
		||||
  if (face_verts.ndimension() != 3 || face_verts.size(1) != 3 ||
 | 
			
		||||
      face_verts.size(2) != 3) {
 | 
			
		||||
@ -213,8 +214,12 @@ RasterizeMeshesNaiveCpu(
 | 
			
		||||
              ? bary0
 | 
			
		||||
              : BarycentricPerspectiveCorrectionForward(bary0, z0, z1, z2);
 | 
			
		||||
 | 
			
		||||
          const vec3<float> bary_clip =
 | 
			
		||||
              !clip_barycentric_coords ? bary : BarycentricClipForward(bary);
 | 
			
		||||
 | 
			
		||||
          // Use barycentric coordinates to get the depth of the current pixel
 | 
			
		||||
          const float pz = (bary.x * z0 + bary.y * z1 + bary.z * z2);
 | 
			
		||||
          const float pz =
 | 
			
		||||
              (bary_clip.x * z0 + bary_clip.y * z1 + bary_clip.z * z2);
 | 
			
		||||
 | 
			
		||||
          if (pz < 0) {
 | 
			
		||||
            continue; // Point is behind the image plane so ignore.
 | 
			
		||||
@ -236,7 +241,7 @@ RasterizeMeshesNaiveCpu(
 | 
			
		||||
            continue;
 | 
			
		||||
          }
 | 
			
		||||
          // The current pixel lies inside the current face.
 | 
			
		||||
          q.emplace(pz, f, signed_dist, bary.x, bary.y, bary.z);
 | 
			
		||||
          q.emplace(pz, f, signed_dist, bary_clip.x, bary_clip.y, bary_clip.z);
 | 
			
		||||
          if (static_cast<int>(q.size()) > K) {
 | 
			
		||||
            q.pop();
 | 
			
		||||
          }
 | 
			
		||||
@ -264,7 +269,8 @@ torch::Tensor RasterizeMeshesBackwardCpu(
 | 
			
		||||
    const torch::Tensor& grad_zbuf, // (N, H, W, K)
 | 
			
		||||
    const torch::Tensor& grad_bary, // (N, H, W, K, 3)
 | 
			
		||||
    const torch::Tensor& grad_dists, // (N, H, W, K)
 | 
			
		||||
    const bool perspective_correct) {
 | 
			
		||||
    const bool perspective_correct,
 | 
			
		||||
    const bool clip_barycentric_coords) {
 | 
			
		||||
  const int F = face_verts.size(0);
 | 
			
		||||
  const int N = pix_to_face.size(0);
 | 
			
		||||
  const int H = pix_to_face.size(1);
 | 
			
		||||
@ -335,6 +341,8 @@ torch::Tensor RasterizeMeshesBackwardCpu(
 | 
			
		||||
          const vec3<float> bary = !perspective_correct
 | 
			
		||||
              ? bary0
 | 
			
		||||
              : BarycentricPerspectiveCorrectionForward(bary0, z0, z1, z2);
 | 
			
		||||
          const vec3<float> bary_clip =
 | 
			
		||||
              !clip_barycentric_coords ? bary : BarycentricClipForward(bary);
 | 
			
		||||
 | 
			
		||||
          // Distances inside the face are negative so get the
 | 
			
		||||
          // correct sign to apply to the upstream gradient.
 | 
			
		||||
@ -354,22 +362,28 @@ torch::Tensor RasterizeMeshesBackwardCpu(
 | 
			
		||||
          // d_zbuf/d_bary_w0 = z0
 | 
			
		||||
          // d_zbuf/d_bary_w1 = z1
 | 
			
		||||
          // d_zbuf/d_bary_w2 = z2
 | 
			
		||||
          const vec3<float> d_zbuf_d_bary(z0, z1, z2);
 | 
			
		||||
          const vec3<float> d_zbuf_d_baryclip(z0, z1, z2);
 | 
			
		||||
 | 
			
		||||
          // Total upstream barycentric gradients are the sum of
 | 
			
		||||
          // external upstream gradients and contribution from zbuf.
 | 
			
		||||
          vec3<float> grad_bary_f_sum =
 | 
			
		||||
              (grad_bary_upstream + grad_zbuf_upstream * d_zbuf_d_bary);
 | 
			
		||||
          const vec3<float> grad_bary_f_sum =
 | 
			
		||||
              (grad_bary_upstream + grad_zbuf_upstream * d_zbuf_d_baryclip);
 | 
			
		||||
 | 
			
		||||
          vec3<float> grad_bary0 = grad_bary_f_sum;
 | 
			
		||||
 | 
			
		||||
          if (clip_barycentric_coords) {
 | 
			
		||||
            grad_bary0 = BarycentricClipBackward(bary, grad_bary0);
 | 
			
		||||
          }
 | 
			
		||||
 | 
			
		||||
          if (perspective_correct) {
 | 
			
		||||
            auto perspective_grads = BarycentricPerspectiveCorrectionBackward(
 | 
			
		||||
                bary0, z0, z1, z2, grad_bary_f_sum);
 | 
			
		||||
                bary0, z0, z1, z2, grad_bary0);
 | 
			
		||||
            grad_bary0 = std::get<0>(perspective_grads);
 | 
			
		||||
            grad_face_verts[f][0][2] += std::get<1>(perspective_grads);
 | 
			
		||||
            grad_face_verts[f][1][2] += std::get<2>(perspective_grads);
 | 
			
		||||
            grad_face_verts[f][2][2] += std::get<3>(perspective_grads);
 | 
			
		||||
          }
 | 
			
		||||
 | 
			
		||||
          auto grad_bary_f =
 | 
			
		||||
              BarycentricCoordsBackward(pxy, v0xy, v1xy, v2xy, grad_bary0);
 | 
			
		||||
          const vec2<float> dbary_d_v0 = std::get<1>(grad_bary_f);
 | 
			
		||||
@ -379,13 +393,13 @@ torch::Tensor RasterizeMeshesBackwardCpu(
 | 
			
		||||
          // Update output gradient buffer.
 | 
			
		||||
          grad_face_verts[f][0][0] += dbary_d_v0.x + ddist_d_v0.x;
 | 
			
		||||
          grad_face_verts[f][0][1] += dbary_d_v0.y + ddist_d_v0.y;
 | 
			
		||||
          grad_face_verts[f][0][2] += grad_zbuf_upstream * bary.x;
 | 
			
		||||
          grad_face_verts[f][0][2] += grad_zbuf_upstream * bary_clip.x;
 | 
			
		||||
          grad_face_verts[f][1][0] += dbary_d_v1.x + ddist_d_v1.x;
 | 
			
		||||
          grad_face_verts[f][1][1] += dbary_d_v1.y + ddist_d_v1.y;
 | 
			
		||||
          grad_face_verts[f][1][2] += grad_zbuf_upstream * bary.y;
 | 
			
		||||
          grad_face_verts[f][1][2] += grad_zbuf_upstream * bary_clip.y;
 | 
			
		||||
          grad_face_verts[f][2][0] += dbary_d_v2.x + ddist_d_v2.x;
 | 
			
		||||
          grad_face_verts[f][2][1] += dbary_d_v2.y + ddist_d_v2.y;
 | 
			
		||||
          grad_face_verts[f][2][2] += grad_zbuf_upstream * bary.z;
 | 
			
		||||
          grad_face_verts[f][2][2] += grad_zbuf_upstream * bary_clip.z;
 | 
			
		||||
        }
 | 
			
		||||
      }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
@ -221,8 +221,108 @@ BarycentricPerspectiveCorrectionBackward(
 | 
			
		||||
  return thrust::make_tuple(grad_bary, grad_z0, grad_z1, grad_z2);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Calculate minimum squared distance between a line segment (v1 - v0) and a
 | 
			
		||||
// point p.
 | 
			
		||||
// Clip negative barycentric coordinates to 0.0 and renormalize so
 | 
			
		||||
// the barycentric coordinates for a point sum to 1. When the blur_radius
 | 
			
		||||
// is greater than 0, a face will still be recorded as overlapping a pixel
 | 
			
		||||
// if the pixel is outisde the face. In this case at least one of the
 | 
			
		||||
// barycentric coordinates for the pixel relative to the face will be negative.
 | 
			
		||||
// Clipping will ensure that the texture and z buffer are interpolated
 | 
			
		||||
// correctly.
 | 
			
		||||
//
 | 
			
		||||
//  Args
 | 
			
		||||
//     bary: (w0, w1, w2) barycentric coordinates which can be outside the
 | 
			
		||||
//            range [0, 1].
 | 
			
		||||
//
 | 
			
		||||
//  Returns
 | 
			
		||||
//     bary: (w0, w1, w2) barycentric coordinates in the range [0, 1] which
 | 
			
		||||
//           satisfy the condition: sum(w0, w1, w2) = 1.0.
 | 
			
		||||
//
 | 
			
		||||
__device__ inline float3 BarycentricClipForward(const float3 bary) {
 | 
			
		||||
  float3 w = make_float3(0.0f, 0.0f, 0.0f);
 | 
			
		||||
  // Clamp lower bound only
 | 
			
		||||
  w.x = max(bary.x, 0.0);
 | 
			
		||||
  w.y = max(bary.y, 0.0);
 | 
			
		||||
  w.z = max(bary.z, 0.0);
 | 
			
		||||
  float w_sum = w.x + w.y + w.z;
 | 
			
		||||
  w_sum = fmaxf(w_sum, 1e-5);
 | 
			
		||||
  w.x /= w_sum;
 | 
			
		||||
  w.y /= w_sum;
 | 
			
		||||
  w.z /= w_sum;
 | 
			
		||||
 | 
			
		||||
  return w;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Backward pass for barycentric coordinate clipping.
 | 
			
		||||
//
 | 
			
		||||
//  Args
 | 
			
		||||
//     bary: (w0, w1, w2) barycentric coordinates which can be outside the
 | 
			
		||||
//            range [0, 1].
 | 
			
		||||
//     grad_baryclip_upstream: vec3<T> Upstream gradient for each of the clipped
 | 
			
		||||
//                         barycentric coordinates [grad_w0, grad_w1, grad_w2].
 | 
			
		||||
//
 | 
			
		||||
// Returns
 | 
			
		||||
//    vec3<T> of gradients for the unclipped barycentric coordinates:
 | 
			
		||||
//    (grad_w0, grad_w1, grad_w2)
 | 
			
		||||
//
 | 
			
		||||
__device__ inline float3 BarycentricClipBackward(
 | 
			
		||||
    const float3 bary,
 | 
			
		||||
    const float3 grad_baryclip_upstream) {
 | 
			
		||||
  // Redo some of the forward pass calculations
 | 
			
		||||
  float3 w = make_float3(0.0f, 0.0f, 0.0f);
 | 
			
		||||
  // Clamp lower bound only
 | 
			
		||||
  w.x = max(bary.x, 0.0);
 | 
			
		||||
  w.y = max(bary.y, 0.0);
 | 
			
		||||
  w.z = max(bary.z, 0.0);
 | 
			
		||||
  float w_sum = w.x + w.y + w.z;
 | 
			
		||||
 | 
			
		||||
  float3 grad_bary = make_float3(1.0f, 1.0f, 1.0f);
 | 
			
		||||
  float3 grad_clip = make_float3(1.0f, 1.0f, 1.0f);
 | 
			
		||||
  float3 grad_sum = make_float3(1.0f, 1.0f, 1.0f);
 | 
			
		||||
 | 
			
		||||
  // Check if sum was clipped.
 | 
			
		||||
  float grad_sum_clip = 1.0f;
 | 
			
		||||
  if (w_sum < 1e-5) {
 | 
			
		||||
    grad_sum_clip = 0.0f;
 | 
			
		||||
    w_sum = 1e-5;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  // Check if any of bary values have been clipped.
 | 
			
		||||
  if (bary.x < 0.0f) {
 | 
			
		||||
    grad_clip.x = 0.0f;
 | 
			
		||||
  }
 | 
			
		||||
  if (bary.y < 0.0f) {
 | 
			
		||||
    grad_clip.y = 0.0f;
 | 
			
		||||
  }
 | 
			
		||||
  if (bary.z < 0.0f) {
 | 
			
		||||
    grad_clip.z = 0.0f;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  // Gradients of the sum.
 | 
			
		||||
  grad_sum.x = -w.x / (pow(w_sum, 2.0f)) * grad_sum_clip;
 | 
			
		||||
  grad_sum.y = -w.y / (pow(w_sum, 2.0f)) * grad_sum_clip;
 | 
			
		||||
  grad_sum.z = -w.z / (pow(w_sum, 2.0f)) * grad_sum_clip;
 | 
			
		||||
 | 
			
		||||
  // Gradients for each of the bary coordinates including the cross terms
 | 
			
		||||
  // from the sum.
 | 
			
		||||
  grad_bary.x = grad_clip.x *
 | 
			
		||||
      (grad_baryclip_upstream.x * (1.0f / w_sum + grad_sum.x) +
 | 
			
		||||
       grad_baryclip_upstream.y * (grad_sum.y) +
 | 
			
		||||
       grad_baryclip_upstream.z * (grad_sum.z));
 | 
			
		||||
 | 
			
		||||
  grad_bary.y = grad_clip.y *
 | 
			
		||||
      (grad_baryclip_upstream.y * (1.0f / w_sum + grad_sum.y) +
 | 
			
		||||
       grad_baryclip_upstream.x * (grad_sum.x) +
 | 
			
		||||
       grad_baryclip_upstream.z * (grad_sum.z));
 | 
			
		||||
 | 
			
		||||
  grad_bary.z = grad_clip.z *
 | 
			
		||||
      (grad_baryclip_upstream.z * (1.0f / w_sum + grad_sum.z) +
 | 
			
		||||
       grad_baryclip_upstream.x * (grad_sum.x) +
 | 
			
		||||
       grad_baryclip_upstream.y * (grad_sum.y));
 | 
			
		||||
 | 
			
		||||
  return grad_bary;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Return minimum distance between line segment (v1 - v0) and point p.
 | 
			
		||||
//
 | 
			
		||||
// Args:
 | 
			
		||||
//     p: Coordinates of a point.
 | 
			
		||||
 | 
			
		||||
@ -242,8 +242,108 @@ inline std::tuple<vec3<T>, T, T, T> BarycentricPerspectiveCorrectionBackward(
 | 
			
		||||
  return std::make_tuple(grad_bary, grad_z0, grad_z1, grad_z2);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Calculate minimum squared distance between a line segment (v1 - v0) and a
 | 
			
		||||
// point p.
 | 
			
		||||
// Clip negative barycentric coordinates to 0.0 and renormalize so
 | 
			
		||||
// the barycentric coordinates for a point sum to 1. When the blur_radius
 | 
			
		||||
// is greater than 0, a face will still be recorded as overlapping a pixel
 | 
			
		||||
// if the pixel is outisde the face. In this case at least one of the
 | 
			
		||||
// barycentric coordinates for the pixel relative to the face will be negative.
 | 
			
		||||
// Clipping will ensure that the texture and z buffer are interpolated
 | 
			
		||||
// correctly.
 | 
			
		||||
//
 | 
			
		||||
//  Args
 | 
			
		||||
//     bary: (w0, w1, w2) barycentric coordinates which can contain values < 0.
 | 
			
		||||
//
 | 
			
		||||
//  Returns
 | 
			
		||||
//     bary: (w0, w1, w2) barycentric coordinates in the range [0, 1] which
 | 
			
		||||
//           satisfy the condition: sum(w0, w1, w2) = 1.0.
 | 
			
		||||
//
 | 
			
		||||
template <typename T>
 | 
			
		||||
vec3<T> BarycentricClipForward(const vec3<T> bary) {
 | 
			
		||||
  vec3<T> w(0.0f, 0.0f, 0.0f);
 | 
			
		||||
  // Only clamp negative values to 0.0.
 | 
			
		||||
  // No need to clamp values > 1.0 as they will be renormalized.
 | 
			
		||||
  w.x = std::max(bary.x, 0.0f);
 | 
			
		||||
  w.y = std::max(bary.y, 0.0f);
 | 
			
		||||
  w.z = std::max(bary.z, 0.0f);
 | 
			
		||||
  float w_sum = w.x + w.y + w.z;
 | 
			
		||||
  w_sum = std::fmaxf(w_sum, 1e-5);
 | 
			
		||||
  w.x /= w_sum;
 | 
			
		||||
  w.y /= w_sum;
 | 
			
		||||
  w.z /= w_sum;
 | 
			
		||||
  return w;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Backward pass for barycentric coordinate clipping.
 | 
			
		||||
//
 | 
			
		||||
//  Args
 | 
			
		||||
//     bary: (w0, w1, w2) barycentric coordinates which can contain values < 0.
 | 
			
		||||
//     grad_baryclip_upstream: vec3<T> Upstream gradient for each of the clipped
 | 
			
		||||
//                         barycentric coordinates [grad_w0, grad_w1, grad_w2].
 | 
			
		||||
//
 | 
			
		||||
// Returns
 | 
			
		||||
//    vec3<T> of gradients for the unclipped barycentric coordinates:
 | 
			
		||||
//    (grad_w0, grad_w1, grad_w2)
 | 
			
		||||
//
 | 
			
		||||
template <typename T>
 | 
			
		||||
vec3<T> BarycentricClipBackward(
 | 
			
		||||
    const vec3<T> bary,
 | 
			
		||||
    const vec3<T> grad_baryclip_upstream) {
 | 
			
		||||
  // Redo some of the forward pass calculations
 | 
			
		||||
  vec3<T> w(0.0f, 0.0f, 0.0f);
 | 
			
		||||
  w.x = std::max(bary.x, 0.0f);
 | 
			
		||||
  w.y = std::max(bary.y, 0.0f);
 | 
			
		||||
  w.z = std::max(bary.z, 0.0f);
 | 
			
		||||
  float w_sum = w.x + w.y + w.z;
 | 
			
		||||
 | 
			
		||||
  vec3<T> grad_bary(1.0f, 1.0f, 1.0f);
 | 
			
		||||
  vec3<T> grad_clip(1.0f, 1.0f, 1.0f);
 | 
			
		||||
  vec3<T> grad_sum(1.0f, 1.0f, 1.0f);
 | 
			
		||||
 | 
			
		||||
  // Check if the sum was clipped.
 | 
			
		||||
  float grad_sum_clip = 1.0f;
 | 
			
		||||
  if (w_sum < 1e-5) {
 | 
			
		||||
    grad_sum_clip = 0.0f;
 | 
			
		||||
    w_sum = 1e-5;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  // Check if any of the bary coordinates have been clipped.
 | 
			
		||||
  // Only negative values are clamped to 0.0.
 | 
			
		||||
  if (bary.x < 0.0f) {
 | 
			
		||||
    grad_clip.x = 0.0f;
 | 
			
		||||
  }
 | 
			
		||||
  if (bary.y < 0.0f) {
 | 
			
		||||
    grad_clip.y = 0.0f;
 | 
			
		||||
  }
 | 
			
		||||
  if (bary.z < 0.0f) {
 | 
			
		||||
    grad_clip.z = 0.0f;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  // Gradients of the sum.
 | 
			
		||||
  grad_sum.x = -w.x / (pow(w_sum, 2.0f)) * grad_sum_clip;
 | 
			
		||||
  grad_sum.y = -w.y / (pow(w_sum, 2.0f)) * grad_sum_clip;
 | 
			
		||||
  grad_sum.z = -w.z / (pow(w_sum, 2.0f)) * grad_sum_clip;
 | 
			
		||||
 | 
			
		||||
  // Gradients for each of the bary coordinates including the cross terms
 | 
			
		||||
  // from the sum.
 | 
			
		||||
  grad_bary.x = grad_clip.x *
 | 
			
		||||
      (grad_baryclip_upstream.x * (1.0f / w_sum + grad_sum.x) +
 | 
			
		||||
       grad_baryclip_upstream.y * (grad_sum.y) +
 | 
			
		||||
       grad_baryclip_upstream.z * (grad_sum.z));
 | 
			
		||||
 | 
			
		||||
  grad_bary.y = grad_clip.y *
 | 
			
		||||
      (grad_baryclip_upstream.y * (1.0f / w_sum + grad_sum.y) +
 | 
			
		||||
       grad_baryclip_upstream.x * (grad_sum.x) +
 | 
			
		||||
       grad_baryclip_upstream.z * (grad_sum.z));
 | 
			
		||||
 | 
			
		||||
  grad_bary.z = grad_clip.z *
 | 
			
		||||
      (grad_baryclip_upstream.z * (1.0f / w_sum + grad_sum.z) +
 | 
			
		||||
       grad_baryclip_upstream.x * (grad_sum.x) +
 | 
			
		||||
       grad_baryclip_upstream.y * (grad_sum.y));
 | 
			
		||||
 | 
			
		||||
  return grad_bary;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Calculate minimum distance between a line segment (v1 - v0) and point p.
 | 
			
		||||
//
 | 
			
		||||
// Args:
 | 
			
		||||
//     p: Coordinates of a point.
 | 
			
		||||
 | 
			
		||||
@ -24,6 +24,7 @@ def rasterize_meshes(
 | 
			
		||||
    bin_size: Optional[int] = None,
 | 
			
		||||
    max_faces_per_bin: Optional[int] = None,
 | 
			
		||||
    perspective_correct: bool = False,
 | 
			
		||||
    clip_barycentric_coords: bool = False,
 | 
			
		||||
    cull_backfaces: bool = False,
 | 
			
		||||
):
 | 
			
		||||
    """
 | 
			
		||||
@ -143,6 +144,7 @@ def rasterize_meshes(
 | 
			
		||||
        bin_size,
 | 
			
		||||
        max_faces_per_bin,
 | 
			
		||||
        perspective_correct,
 | 
			
		||||
        clip_barycentric_coords,
 | 
			
		||||
        cull_backfaces,
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
@ -183,6 +185,7 @@ class _RasterizeFaceVerts(torch.autograd.Function):
 | 
			
		||||
        bin_size: int = 0,
 | 
			
		||||
        max_faces_per_bin: int = 0,
 | 
			
		||||
        perspective_correct: bool = False,
 | 
			
		||||
        clip_barycentric_coords: bool = False,
 | 
			
		||||
        cull_backfaces: bool = False,
 | 
			
		||||
    ):
 | 
			
		||||
        # pyre-fixme[16]: Module `pytorch3d` has no attribute `_C`.
 | 
			
		||||
@ -196,11 +199,13 @@ class _RasterizeFaceVerts(torch.autograd.Function):
 | 
			
		||||
            bin_size,
 | 
			
		||||
            max_faces_per_bin,
 | 
			
		||||
            perspective_correct,
 | 
			
		||||
            clip_barycentric_coords,
 | 
			
		||||
            cull_backfaces,
 | 
			
		||||
        )
 | 
			
		||||
        ctx.save_for_backward(face_verts, pix_to_face)
 | 
			
		||||
        ctx.mark_non_differentiable(pix_to_face)
 | 
			
		||||
        ctx.perspective_correct = perspective_correct
 | 
			
		||||
        ctx.clip_barycentric_coords = clip_barycentric_coords
 | 
			
		||||
        return pix_to_face, zbuf, barycentric_coords, dists
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
@ -214,6 +219,7 @@ class _RasterizeFaceVerts(torch.autograd.Function):
 | 
			
		||||
        grad_bin_size = None
 | 
			
		||||
        grad_max_faces_per_bin = None
 | 
			
		||||
        grad_perspective_correct = None
 | 
			
		||||
        grad_clip_barycentric_coords = None
 | 
			
		||||
        grad_cull_backfaces = None
 | 
			
		||||
        face_verts, pix_to_face = ctx.saved_tensors
 | 
			
		||||
        grad_face_verts = _C.rasterize_meshes_backward(
 | 
			
		||||
@ -223,6 +229,7 @@ class _RasterizeFaceVerts(torch.autograd.Function):
 | 
			
		||||
            grad_barycentric_coords,
 | 
			
		||||
            grad_dists,
 | 
			
		||||
            ctx.perspective_correct,
 | 
			
		||||
            ctx.clip_barycentric_coords,
 | 
			
		||||
        )
 | 
			
		||||
        grads = (
 | 
			
		||||
            grad_face_verts,
 | 
			
		||||
@ -234,6 +241,7 @@ class _RasterizeFaceVerts(torch.autograd.Function):
 | 
			
		||||
            grad_bin_size,
 | 
			
		||||
            grad_max_faces_per_bin,
 | 
			
		||||
            grad_perspective_correct,
 | 
			
		||||
            grad_clip_barycentric_coords,
 | 
			
		||||
            grad_cull_backfaces,
 | 
			
		||||
        )
 | 
			
		||||
        return grads
 | 
			
		||||
@ -250,6 +258,7 @@ def rasterize_meshes_python(
 | 
			
		||||
    blur_radius: float = 0.0,
 | 
			
		||||
    faces_per_pixel: int = 8,
 | 
			
		||||
    perspective_correct: bool = False,
 | 
			
		||||
    clip_barycentric_coords: bool = False,
 | 
			
		||||
    cull_backfaces: bool = False,
 | 
			
		||||
):
 | 
			
		||||
    """
 | 
			
		||||
@ -356,6 +365,14 @@ def rasterize_meshes_python(
 | 
			
		||||
                        top2 = z0 * z1 * l2
 | 
			
		||||
                        bot = top0 + top1 + top2
 | 
			
		||||
                        bary = torch.stack([top0 / bot, top1 / bot, top2 / bot])
 | 
			
		||||
 | 
			
		||||
                    # Check if inside before clipping
 | 
			
		||||
                    inside = all(x > 0.0 for x in bary)
 | 
			
		||||
 | 
			
		||||
                    # Barycentric clipping
 | 
			
		||||
                    if clip_barycentric_coords:
 | 
			
		||||
                        bary = barycentric_coordinates_clip(bary)
 | 
			
		||||
                    # use clipped barycentric coords to calculate the z value
 | 
			
		||||
                    pz = bary[0] * v0[2] + bary[1] * v1[2] + bary[2] * v2[2]
 | 
			
		||||
 | 
			
		||||
                    # Check if point is behind the image.
 | 
			
		||||
@ -365,7 +382,6 @@ def rasterize_meshes_python(
 | 
			
		||||
                    # Calculate signed 2D distance from point to face.
 | 
			
		||||
                    # Points inside the triangle have negative distance.
 | 
			
		||||
                    dist = point_triangle_distance(pxy, v0[:2], v1[:2], v2[:2])
 | 
			
		||||
                    inside = all(x > 0.0 for x in bary)
 | 
			
		||||
 | 
			
		||||
                    signed_dist = dist * -1.0 if inside else dist
 | 
			
		||||
 | 
			
		||||
@ -433,6 +449,33 @@ def edge_function(p, v0, v1):
 | 
			
		||||
    return (p[0] - v0[0]) * (v1[1] - v0[1]) - (p[1] - v0[1]) * (v1[0] - v0[0])
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def barycentric_coordinates_clip(bary):
 | 
			
		||||
    """
 | 
			
		||||
    Clip negative barycentric coordinates to 0.0 and renormalize so
 | 
			
		||||
    the barycentric coordinates for a point sum to 1. When the blur_radius
 | 
			
		||||
    is greater than 0, a face will still be recorded as overlapping a pixel
 | 
			
		||||
    if the pixel is outisde the face. In this case at least one of the
 | 
			
		||||
    barycentric coordinates for the pixel relative to the face will be negative.
 | 
			
		||||
    Clipping will ensure that the texture and z buffer are interpolated correctly.
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
        bary: tuple of barycentric coordinates
 | 
			
		||||
 | 
			
		||||
    Returns
 | 
			
		||||
        bary_clip: (w0, w1, w2) barycentric coordinates with no negative values.
 | 
			
		||||
    """
 | 
			
		||||
    # Only negative values are clamped to 0.0.
 | 
			
		||||
    w0_clip = torch.clamp(bary[0], min=0.0)
 | 
			
		||||
    w1_clip = torch.clamp(bary[1], min=0.0)
 | 
			
		||||
    w2_clip = torch.clamp(bary[2], min=0.0)
 | 
			
		||||
    bary_sum = torch.clamp(w0_clip + w1_clip + w2_clip, min=1e-5)
 | 
			
		||||
    w0_clip = w0_clip / bary_sum
 | 
			
		||||
    w1_clip = w1_clip / bary_sum
 | 
			
		||||
    w2_clip = w2_clip / bary_sum
 | 
			
		||||
 | 
			
		||||
    return (w0_clip, w1_clip, w2_clip)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def barycentric_coordinates(p, v0, v1, v2):
 | 
			
		||||
    """
 | 
			
		||||
    Compute the barycentric coordinates of a point relative to a triangle.
 | 
			
		||||
 | 
			
		||||
@ -26,6 +26,7 @@ class RasterizationSettings:
 | 
			
		||||
        "bin_size",
 | 
			
		||||
        "max_faces_per_bin",
 | 
			
		||||
        "perspective_correct",
 | 
			
		||||
        "clip_barycentric_coords",
 | 
			
		||||
        "cull_backfaces",
 | 
			
		||||
    ]
 | 
			
		||||
 | 
			
		||||
@ -37,6 +38,7 @@ class RasterizationSettings:
 | 
			
		||||
        bin_size: Optional[int] = None,
 | 
			
		||||
        max_faces_per_bin: Optional[int] = None,
 | 
			
		||||
        perspective_correct: bool = False,
 | 
			
		||||
        clip_barycentric_coords: bool = False,
 | 
			
		||||
        cull_backfaces: bool = False,
 | 
			
		||||
    ):
 | 
			
		||||
        self.image_size = image_size
 | 
			
		||||
@ -45,6 +47,7 @@ class RasterizationSettings:
 | 
			
		||||
        self.bin_size = bin_size
 | 
			
		||||
        self.max_faces_per_bin = max_faces_per_bin
 | 
			
		||||
        self.perspective_correct = perspective_correct
 | 
			
		||||
        self.clip_barycentric_coords = clip_barycentric_coords
 | 
			
		||||
        self.cull_backfaces = cull_backfaces
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -127,6 +130,7 @@ class MeshRasterizer(nn.Module):
 | 
			
		||||
            bin_size=raster_settings.bin_size,
 | 
			
		||||
            max_faces_per_bin=raster_settings.max_faces_per_bin,
 | 
			
		||||
            perspective_correct=raster_settings.perspective_correct,
 | 
			
		||||
            clip_barycentric_coords=raster_settings.clip_barycentric_coords,
 | 
			
		||||
            cull_backfaces=raster_settings.cull_backfaces,
 | 
			
		||||
        )
 | 
			
		||||
        return Fragments(
 | 
			
		||||
 | 
			
		||||
@ -49,21 +49,6 @@ class MeshRenderer(nn.Module):
 | 
			
		||||
        the range for the corresponding face.
 | 
			
		||||
        """
 | 
			
		||||
        fragments = self.rasterizer(meshes_world, **kwargs)
 | 
			
		||||
        raster_settings = kwargs.get("raster_settings", self.rasterizer.raster_settings)
 | 
			
		||||
        if raster_settings.blur_radius > 0.0:
 | 
			
		||||
            # TODO: potentially move barycentric clipping to the rasterizer
 | 
			
		||||
            # if no downstream functions requires unclipped values.
 | 
			
		||||
            # This will avoid unnecssary re-interpolation of the z buffer.
 | 
			
		||||
            clipped_bary_coords = _clip_barycentric_coordinates(fragments.bary_coords)
 | 
			
		||||
            clipped_zbuf = _interpolate_zbuf(
 | 
			
		||||
                fragments.pix_to_face, clipped_bary_coords, meshes_world
 | 
			
		||||
            )
 | 
			
		||||
            fragments = Fragments(
 | 
			
		||||
                bary_coords=clipped_bary_coords,
 | 
			
		||||
                zbuf=clipped_zbuf,
 | 
			
		||||
                dists=fragments.dists,
 | 
			
		||||
                pix_to_face=fragments.pix_to_face,
 | 
			
		||||
            )
 | 
			
		||||
        images = self.shader(fragments, meshes_world, **kwargs)
 | 
			
		||||
 | 
			
		||||
        return images
 | 
			
		||||
 | 
			
		||||
@ -20,9 +20,13 @@ def _clip_barycentric_coordinates(bary) -> torch.Tensor:
 | 
			
		||||
    if bary.shape[-1] != 3:
 | 
			
		||||
        msg = "Expected barycentric coords to have last dim = 3; got %r"
 | 
			
		||||
        raise ValueError(msg % (bary.shape,))
 | 
			
		||||
    ndims = bary.ndim - 1
 | 
			
		||||
    mask = bary.eq(-1).all(dim=-1, keepdim=True).expand(*((-1,) * ndims + (3,)))
 | 
			
		||||
    clipped = bary.clamp(min=0.0)
 | 
			
		||||
    clipped[mask] = 0.0
 | 
			
		||||
    clipped_sum = torch.clamp(clipped.sum(dim=-1, keepdim=True), min=1e-5)
 | 
			
		||||
    clipped = clipped / clipped_sum
 | 
			
		||||
    clipped[mask] = -1.0
 | 
			
		||||
    return clipped
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -49,6 +53,8 @@ def _interpolate_zbuf(
 | 
			
		||||
    verts = meshes.verts_packed()
 | 
			
		||||
    faces = meshes.faces_packed()
 | 
			
		||||
    faces_verts_z = verts[faces][..., 2][..., None]  # (F, 3, 1)
 | 
			
		||||
    return interpolate_face_attributes(pix_to_face, barycentric_coords, faces_verts_z)[
 | 
			
		||||
    zbuf = interpolate_face_attributes(pix_to_face, barycentric_coords, faces_verts_z)[
 | 
			
		||||
        ..., 0
 | 
			
		||||
    ]  # (1, H, W, K)
 | 
			
		||||
    zbuf[pix_to_face == -1] = -1
 | 
			
		||||
    return zbuf
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										112
									
								
								tests/bm_barycentric_clipping.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										112
									
								
								tests/bm_barycentric_clipping.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,112 @@
 | 
			
		||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
 | 
			
		||||
 | 
			
		||||
from itertools import product
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
from fvcore.common.benchmark import benchmark
 | 
			
		||||
from pytorch3d.renderer.cameras import OpenGLPerspectiveCameras, look_at_view_transform
 | 
			
		||||
from pytorch3d.renderer.mesh.rasterizer import (
 | 
			
		||||
    Fragments,
 | 
			
		||||
    MeshRasterizer,
 | 
			
		||||
    RasterizationSettings,
 | 
			
		||||
)
 | 
			
		||||
from pytorch3d.renderer.mesh.utils import (
 | 
			
		||||
    _clip_barycentric_coordinates,
 | 
			
		||||
    _interpolate_zbuf,
 | 
			
		||||
)
 | 
			
		||||
from pytorch3d.utils.ico_sphere import ico_sphere
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def baryclip_cuda(
 | 
			
		||||
    num_meshes: int = 8,
 | 
			
		||||
    ico_level: int = 5,
 | 
			
		||||
    image_size: int = 64,
 | 
			
		||||
    faces_per_pixel: int = 50,
 | 
			
		||||
    device="cuda",
 | 
			
		||||
):
 | 
			
		||||
    # Init meshes
 | 
			
		||||
    sphere_meshes = ico_sphere(ico_level, device).extend(num_meshes)
 | 
			
		||||
    # Init transform
 | 
			
		||||
    R, T = look_at_view_transform(1.0, 0.0, 0.0)
 | 
			
		||||
    cameras = OpenGLPerspectiveCameras(device=device, R=R, T=T)
 | 
			
		||||
    # Init rasterizer
 | 
			
		||||
    raster_settings = RasterizationSettings(
 | 
			
		||||
        image_size=image_size,
 | 
			
		||||
        blur_radius=1e-4,
 | 
			
		||||
        faces_per_pixel=faces_per_pixel,
 | 
			
		||||
        clip_barycentric_coords=True,
 | 
			
		||||
    )
 | 
			
		||||
    rasterizer = MeshRasterizer(cameras=cameras, raster_settings=raster_settings)
 | 
			
		||||
 | 
			
		||||
    torch.cuda.synchronize()
 | 
			
		||||
 | 
			
		||||
    def raster_fn():
 | 
			
		||||
        rasterizer(sphere_meshes)
 | 
			
		||||
        torch.cuda.synchronize()
 | 
			
		||||
 | 
			
		||||
    return raster_fn
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def baryclip_pytorch(
 | 
			
		||||
    num_meshes: int = 8,
 | 
			
		||||
    ico_level: int = 5,
 | 
			
		||||
    image_size: int = 64,
 | 
			
		||||
    faces_per_pixel: int = 50,
 | 
			
		||||
    device="cuda",
 | 
			
		||||
):
 | 
			
		||||
    # Init meshes
 | 
			
		||||
    sphere_meshes = ico_sphere(ico_level, device).extend(num_meshes)
 | 
			
		||||
    # Init transform
 | 
			
		||||
    R, T = look_at_view_transform(1.0, 0.0, 0.0)
 | 
			
		||||
    cameras = OpenGLPerspectiveCameras(device=device, R=R, T=T)
 | 
			
		||||
    # Init rasterizer
 | 
			
		||||
    raster_settings = RasterizationSettings(
 | 
			
		||||
        image_size=image_size,
 | 
			
		||||
        blur_radius=1e-4,
 | 
			
		||||
        faces_per_pixel=faces_per_pixel,
 | 
			
		||||
        clip_barycentric_coords=False,
 | 
			
		||||
    )
 | 
			
		||||
    rasterizer = MeshRasterizer(cameras=cameras, raster_settings=raster_settings)
 | 
			
		||||
 | 
			
		||||
    torch.cuda.synchronize()
 | 
			
		||||
 | 
			
		||||
    def raster_fn():
 | 
			
		||||
        fragments = rasterizer(sphere_meshes)
 | 
			
		||||
 | 
			
		||||
        # Clip bary and reinterpolate
 | 
			
		||||
        clipped_bary_coords = _clip_barycentric_coordinates(fragments.bary_coords)
 | 
			
		||||
        clipped_zbuf = _interpolate_zbuf(
 | 
			
		||||
            fragments.pix_to_face, clipped_bary_coords, sphere_meshes
 | 
			
		||||
        )
 | 
			
		||||
        fragments = Fragments(
 | 
			
		||||
            bary_coords=clipped_bary_coords,
 | 
			
		||||
            zbuf=clipped_zbuf,
 | 
			
		||||
            dists=fragments.dists,
 | 
			
		||||
            pix_to_face=fragments.pix_to_face,
 | 
			
		||||
        )
 | 
			
		||||
        torch.cuda.synchronize()
 | 
			
		||||
 | 
			
		||||
    return raster_fn
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def bm_barycentric_clip() -> None:
 | 
			
		||||
    if torch.cuda.is_available():
 | 
			
		||||
        kwargs_list = []
 | 
			
		||||
        num_meshes = [1, 8]
 | 
			
		||||
        ico_level = [0, 4]
 | 
			
		||||
        image_size = [64, 128, 256]
 | 
			
		||||
        faces_per_pixel = [10, 75, 100]
 | 
			
		||||
        test_cases = product(num_meshes, ico_level, image_size, faces_per_pixel)
 | 
			
		||||
        for case in test_cases:
 | 
			
		||||
            n, ic, im, nf = case
 | 
			
		||||
            kwargs_list.append(
 | 
			
		||||
                {
 | 
			
		||||
                    "num_meshes": n,
 | 
			
		||||
                    "ico_level": ic,
 | 
			
		||||
                    "image_size": im,
 | 
			
		||||
                    "faces_per_pixel": nf,
 | 
			
		||||
                }
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        benchmark(baryclip_cuda, "BARY_CLIP_CUDA", kwargs_list, warmup_iters=1)
 | 
			
		||||
        benchmark(baryclip_pytorch, "BARY_CLIP_PYTORCH", kwargs_list, warmup_iters=1)
 | 
			
		||||
										
											Binary file not shown.
										
									
								
							| 
		 Before Width: | Height: | Size: 43 KiB After Width: | Height: | Size: 43 KiB  | 
@ -10,6 +10,10 @@ from pytorch3d.renderer.mesh.rasterize_meshes import (
 | 
			
		||||
    rasterize_meshes,
 | 
			
		||||
    rasterize_meshes_python,
 | 
			
		||||
)
 | 
			
		||||
from pytorch3d.renderer.mesh.utils import (
 | 
			
		||||
    _clip_barycentric_coordinates,
 | 
			
		||||
    _interpolate_zbuf,
 | 
			
		||||
)
 | 
			
		||||
from pytorch3d.structures import Meshes
 | 
			
		||||
from pytorch3d.utils import ico_sphere
 | 
			
		||||
 | 
			
		||||
@ -21,6 +25,7 @@ class TestRasterizeMeshes(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
        self._simple_blurry_raster(rasterize_meshes_python, device, bin_size=-1)
 | 
			
		||||
        self._test_behind_camera(rasterize_meshes_python, device, bin_size=-1)
 | 
			
		||||
        self._test_perspective_correct(rasterize_meshes_python, device, bin_size=-1)
 | 
			
		||||
        self._test_barycentric_clipping(rasterize_meshes_python, device, bin_size=-1)
 | 
			
		||||
        self._test_back_face_culling(rasterize_meshes_python, device, bin_size=-1)
 | 
			
		||||
 | 
			
		||||
    def test_simple_cpu_naive(self):
 | 
			
		||||
@ -170,8 +175,29 @@ class TestRasterizeMeshes(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
        verts2.requires_grad = True
 | 
			
		||||
        meshes_cuda = Meshes(verts=[verts2], faces=[faces2])
 | 
			
		||||
 | 
			
		||||
        args_cpu = (meshes_cpu, image_size, radius, faces_per_pixel)
 | 
			
		||||
        args_cuda = (meshes_cuda, image_size, radius, faces_per_pixel, 0, 0)
 | 
			
		||||
        barycentric_clip = True
 | 
			
		||||
        args_cpu = (
 | 
			
		||||
            meshes_cpu,
 | 
			
		||||
            image_size,
 | 
			
		||||
            radius,
 | 
			
		||||
            faces_per_pixel,
 | 
			
		||||
            None,
 | 
			
		||||
            None,
 | 
			
		||||
            False,
 | 
			
		||||
            barycentric_clip,
 | 
			
		||||
            False,
 | 
			
		||||
        )
 | 
			
		||||
        args_cuda = (
 | 
			
		||||
            meshes_cuda,
 | 
			
		||||
            image_size,
 | 
			
		||||
            radius,
 | 
			
		||||
            faces_per_pixel,
 | 
			
		||||
            0,
 | 
			
		||||
            0,
 | 
			
		||||
            False,
 | 
			
		||||
            barycentric_clip,
 | 
			
		||||
            False,
 | 
			
		||||
        )
 | 
			
		||||
        self._compare_impls(
 | 
			
		||||
            rasterize_meshes,
 | 
			
		||||
            rasterize_meshes,
 | 
			
		||||
@ -333,6 +359,39 @@ class TestRasterizeMeshes(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
                    idxs_cuda[:K] = sorted(idxs_cuda[:K])
 | 
			
		||||
                    self.assertEqual(idxs_cpu, idxs_cuda)
 | 
			
		||||
 | 
			
		||||
    def test_python_vs_cpp_bary_clip(self):
 | 
			
		||||
        torch.manual_seed(232)
 | 
			
		||||
        N = 2
 | 
			
		||||
        V = 10
 | 
			
		||||
        F = 5
 | 
			
		||||
        verts1 = torch.randn(N, V, 3, requires_grad=True)
 | 
			
		||||
        verts2 = verts1.detach().clone().requires_grad_(True)
 | 
			
		||||
        faces = torch.randint(V, size=(N, F, 3))
 | 
			
		||||
        meshes1 = Meshes(verts1, faces)
 | 
			
		||||
        meshes2 = Meshes(verts2, faces)
 | 
			
		||||
 | 
			
		||||
        kwargs = {"image_size": 24, "clip_barycentric_coords": True}
 | 
			
		||||
        fn1 = functools.partial(rasterize_meshes, meshes1, **kwargs)
 | 
			
		||||
        fn2 = functools.partial(rasterize_meshes_python, meshes2, **kwargs)
 | 
			
		||||
        args = ()
 | 
			
		||||
        self._compare_impls(fn1, fn2, args, args, verts1, verts2, compare_grads=True)
 | 
			
		||||
 | 
			
		||||
    def test_cpp_vs_cuda_bary_clip(self):
 | 
			
		||||
        meshes = ico_sphere(2, device=torch.device("cpu"))
 | 
			
		||||
        verts1, faces1 = meshes.get_mesh_verts_faces(0)
 | 
			
		||||
        verts1.requires_grad = True
 | 
			
		||||
        meshes1 = Meshes(verts=[verts1], faces=[faces1])
 | 
			
		||||
        device = get_random_cuda_device()
 | 
			
		||||
        verts2 = verts1.detach().to(device).requires_grad_(True)
 | 
			
		||||
        faces2 = faces1.detach().clone().to(device)
 | 
			
		||||
        meshes2 = Meshes(verts=[verts2], faces=[faces2])
 | 
			
		||||
 | 
			
		||||
        kwargs = {"image_size": 64, "clip_barycentric_coords": True}
 | 
			
		||||
        fn1 = functools.partial(rasterize_meshes, meshes1, **kwargs)
 | 
			
		||||
        fn2 = functools.partial(rasterize_meshes, meshes2, bin_size=0, **kwargs)
 | 
			
		||||
        args = ()
 | 
			
		||||
        self._compare_impls(fn1, fn2, args, args, verts1, verts2, compare_grads=True)
 | 
			
		||||
 | 
			
		||||
    def test_python_vs_cpp_perspective_correct(self):
 | 
			
		||||
        torch.manual_seed(232)
 | 
			
		||||
        N = 2
 | 
			
		||||
@ -621,6 +680,82 @@ class TestRasterizeMeshes(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
        self.assertLess(zbuf_f_bary_diff, 1e-4)
 | 
			
		||||
        self.assertLess(zbuf_t_bary_diff, 1e-4)
 | 
			
		||||
 | 
			
		||||
    def _test_barycentric_clipping(self, rasterize_meshes_fn, device, bin_size=None):
 | 
			
		||||
        # fmt: off
 | 
			
		||||
        verts = torch.tensor([
 | 
			
		||||
            [-0.4, -0.4, 10],  # noqa: E241, E201
 | 
			
		||||
            [ 0.4, -0.4, 10],  # noqa: E241, E201
 | 
			
		||||
            [ 0.0,  0.4, 20],  # noqa: E241, E201
 | 
			
		||||
        ], dtype=torch.float32, device=device)
 | 
			
		||||
        # fmt: on
 | 
			
		||||
        faces = torch.tensor([[0, 1, 2]], device=device)
 | 
			
		||||
        meshes = Meshes(verts=[verts], faces=[faces])
 | 
			
		||||
        kwargs = {
 | 
			
		||||
            "meshes": meshes,
 | 
			
		||||
            "image_size": 5,
 | 
			
		||||
            "faces_per_pixel": 1,
 | 
			
		||||
            "blur_radius": 0.2,
 | 
			
		||||
            "perspective_correct": False,
 | 
			
		||||
            "clip_barycentric_coords": False,  # Initially set this to false
 | 
			
		||||
        }
 | 
			
		||||
        if bin_size != -1:
 | 
			
		||||
            kwargs["bin_size"] = bin_size
 | 
			
		||||
 | 
			
		||||
        # Run with and without perspective correction
 | 
			
		||||
        idx_f, zbuf_f, bary_f, dists_f = rasterize_meshes_fn(**kwargs)
 | 
			
		||||
 | 
			
		||||
        # fmt: off
 | 
			
		||||
        expected_bary = torch.tensor([
 | 
			
		||||
            [
 | 
			
		||||
                [-1.0000, -1.0000, -1.0000],  # noqa: E241, E201
 | 
			
		||||
                [-1.0000, -1.0000, -1.0000],  # noqa: E241, E201
 | 
			
		||||
                [-0.2500, -0.2500,  1.5000],  # noqa: E241, E201
 | 
			
		||||
                [-1.0000, -1.0000, -1.0000],  # noqa: E241, E201
 | 
			
		||||
                [-1.0000, -1.0000, -1.0000]   # noqa: E241, E201
 | 
			
		||||
            ],
 | 
			
		||||
            [
 | 
			
		||||
                [-1.0000, -1.0000, -1.0000],  # noqa: E241, E201
 | 
			
		||||
                [-0.5000,  0.5000,  1.0000],  # noqa: E241, E201
 | 
			
		||||
                [-0.0000, -0.0000,  1.0000],  # noqa: E241, E201
 | 
			
		||||
                [ 0.5000, -0.5000,  1.0000],  # noqa: E241, E201
 | 
			
		||||
                [-1.0000, -1.0000, -1.0000]   # noqa: E241, E201
 | 
			
		||||
            ],
 | 
			
		||||
            [
 | 
			
		||||
                [-1.0000, -1.0000, -1.0000],  # noqa: E241, E201
 | 
			
		||||
                [-0.2500,  0.7500,  0.5000],  # noqa: E241, E201
 | 
			
		||||
                [ 0.2500,  0.2500,  0.5000],  # noqa: E241, E201
 | 
			
		||||
                [ 0.7500, -0.2500,  0.5000],  # noqa: E241, E201
 | 
			
		||||
                [-1.0000, -1.0000, -1.0000]   # noqa: E241, E201
 | 
			
		||||
            ],
 | 
			
		||||
            [
 | 
			
		||||
                [-0.5000,  1.5000, -0.0000],  # noqa: E241, E201
 | 
			
		||||
                [-0.0000,  1.0000, -0.0000],  # noqa: E241, E201
 | 
			
		||||
                [ 0.5000,  0.5000, -0.0000],  # noqa: E241, E201
 | 
			
		||||
                [ 1.0000, -0.0000, -0.0000],  # noqa: E241, E201
 | 
			
		||||
                [ 1.5000, -0.5000,  0.0000]   # noqa: E241, E201
 | 
			
		||||
            ],
 | 
			
		||||
            [
 | 
			
		||||
                [-1.0000, -1.0000, -1.0000],  # noqa: E241, E201
 | 
			
		||||
                [ 0.2500,  1.2500, -0.5000],  # noqa: E241, E201
 | 
			
		||||
                [ 0.7500,  0.7500, -0.5000],  # noqa: E241, E201
 | 
			
		||||
                [ 1.2500,  0.2500, -0.5000],  # noqa: E241, E201
 | 
			
		||||
                [-1.0000, -1.0000, -1.0000]   # noqa: E241, E201
 | 
			
		||||
            ]
 | 
			
		||||
        ], dtype=torch.float32, device=device).view(1, 5, 5, 1, 3)
 | 
			
		||||
        # fmt: on
 | 
			
		||||
 | 
			
		||||
        self.assertClose(expected_bary, bary_f, atol=1e-4)
 | 
			
		||||
 | 
			
		||||
        # calculate the expected clipped barycentrics and zbuf
 | 
			
		||||
        expected_bary_clipped = _clip_barycentric_coordinates(expected_bary)
 | 
			
		||||
        expected_z_clipped = _interpolate_zbuf(idx_f, expected_bary_clipped, meshes)
 | 
			
		||||
 | 
			
		||||
        kwargs["clip_barycentric_coords"] = True
 | 
			
		||||
        idx_t, zbuf_t, bary_t, dists_t = rasterize_meshes_fn(**kwargs)
 | 
			
		||||
 | 
			
		||||
        self.assertClose(expected_bary_clipped, bary_t, atol=1e-4)
 | 
			
		||||
        self.assertClose(expected_z_clipped, zbuf_t, atol=1e-4)
 | 
			
		||||
 | 
			
		||||
    def _test_behind_camera(self, rasterize_meshes_fn, device, bin_size=None):
 | 
			
		||||
        """
 | 
			
		||||
        All verts are behind the camera so nothing should get rasterized.
 | 
			
		||||
 | 
			
		||||
@ -212,6 +212,7 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
            image_size=512,
 | 
			
		||||
            blur_radius=np.log(1.0 / 1e-4 - 1.0) * blend_params.sigma,
 | 
			
		||||
            faces_per_pixel=80,
 | 
			
		||||
            clip_barycentric_coords=True,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        # Init rasterizer settings
 | 
			
		||||
@ -269,11 +270,19 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
        # the cow is facing the -z direction.
 | 
			
		||||
        lights.location = torch.tensor([0.0, 0.0, 2.0], device=device)[None]
 | 
			
		||||
 | 
			
		||||
        blend_params = BlendParams(
 | 
			
		||||
            sigma=1e-1,
 | 
			
		||||
            gamma=1e-4,
 | 
			
		||||
            background_color=torch.tensor([1.0, 1.0, 1.0], device=device),
 | 
			
		||||
        )
 | 
			
		||||
        # Init renderer
 | 
			
		||||
        renderer = MeshRenderer(
 | 
			
		||||
            rasterizer=MeshRasterizer(cameras=cameras, raster_settings=raster_settings),
 | 
			
		||||
            shader=TexturedSoftPhongShader(
 | 
			
		||||
                lights=lights, cameras=cameras, materials=materials
 | 
			
		||||
                lights=lights,
 | 
			
		||||
                cameras=cameras,
 | 
			
		||||
                materials=materials,
 | 
			
		||||
                blend_params=blend_params,
 | 
			
		||||
            ),
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
@ -346,6 +355,7 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
            image_size=512,
 | 
			
		||||
            blur_radius=np.log(1.0 / 1e-4 - 1.0) * blend_params.sigma,
 | 
			
		||||
            faces_per_pixel=100,
 | 
			
		||||
            clip_barycentric_coords=True,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        # Load reference image
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user