diff --git a/pytorch3d/csrc/rasterize_meshes/rasterize_meshes.cu b/pytorch3d/csrc/rasterize_meshes/rasterize_meshes.cu index e280b3f7..aa069725 100644 --- a/pytorch3d/csrc/rasterize_meshes/rasterize_meshes.cu +++ b/pytorch3d/csrc/rasterize_meshes/rasterize_meshes.cu @@ -114,6 +114,7 @@ __device__ void CheckPixelInsideFace( const float2 pxy, // Coordinates of the pixel const int K, const bool perspective_correct, + const bool clip_barycentric_coords, const bool cull_backfaces) { const auto v012 = GetSingleFaceVerts(face_verts, face_idx); const float3 v0 = thrust::get<0>(v012); @@ -149,8 +150,12 @@ __device__ void CheckPixelInsideFace( const float3 p_bary = !perspective_correct ? p_bary0 : BarycentricPerspectiveCorrectionForward(p_bary0, v0.z, v1.z, v2.z); + const float3 p_bary_clip = + !clip_barycentric_coords ? p_bary : BarycentricClipForward(p_bary); + + const float pz = + p_bary_clip.x * v0.z + p_bary_clip.y * v1.z + p_bary_clip.z * v2.z; - const float pz = p_bary.x * v0.z + p_bary.y * v1.z + p_bary.z * v2.z; if (pz < 0) { return; // Face is behind the image plane. } @@ -158,7 +163,8 @@ __device__ void CheckPixelInsideFace( // Get abs squared distance const float dist = PointTriangleDistanceForward(pxy, v0xy, v1xy, v2xy); - // Use the bary coordinates to determine if the point is inside the face. + // Use the unclipped bary coordinates to determine if the point is inside the + // face. const bool inside = p_bary.x > 0.0f && p_bary.y > 0.0f && p_bary.z > 0.0f; const float signed_dist = inside ? -dist : dist; @@ -169,7 +175,7 @@ __device__ void CheckPixelInsideFace( if (q_size < K) { // Just insert it. - q[q_size] = {pz, face_idx, signed_dist, p_bary}; + q[q_size] = {pz, face_idx, signed_dist, p_bary_clip}; if (pz > q_max_z) { q_max_z = pz; q_max_idx = q_size; @@ -177,7 +183,7 @@ __device__ void CheckPixelInsideFace( q_size++; } else if (pz < q_max_z) { // Overwrite the old max, and find the new max. - q[q_max_idx] = {pz, face_idx, signed_dist, p_bary}; + q[q_max_idx] = {pz, face_idx, signed_dist, p_bary_clip}; q_max_z = pz; for (int i = 0; i < K; i++) { if (q[i].z > q_max_z) { @@ -198,6 +204,7 @@ __global__ void RasterizeMeshesNaiveCudaKernel( const int64_t* num_faces_per_mesh, const float blur_radius, const bool perspective_correct, + const bool clip_barycentric_coords, const bool cull_backfaces, const int N, const int H, @@ -260,6 +267,7 @@ __global__ void RasterizeMeshesNaiveCudaKernel( pxy, K, perspective_correct, + clip_barycentric_coords, cull_backfaces); } @@ -286,6 +294,7 @@ RasterizeMeshesNaiveCuda( const float blur_radius, const int num_closest, const bool perspective_correct, + const bool clip_barycentric_coords, const bool cull_backfaces) { TORCH_CHECK( face_verts.ndimension() == 3 && face_verts.size(1) == 3 && @@ -343,6 +352,7 @@ RasterizeMeshesNaiveCuda( num_faces_per_mesh.contiguous().data_ptr(), blur_radius, perspective_correct, + clip_barycentric_coords, cull_backfaces, N, H, @@ -365,6 +375,7 @@ __global__ void RasterizeMeshesBackwardCudaKernel( const float* face_verts, // (F, 3, 3) const int64_t* pix_to_face, // (N, H, W, K) const bool perspective_correct, + const bool clip_barycentric_coords, const int N, const int H, const int W, @@ -422,11 +433,15 @@ __global__ void RasterizeMeshesBackwardCudaKernel( const float3 grad_bary_upstream = make_float3( grad_bary_upstream_w0, grad_bary_upstream_w1, grad_bary_upstream_w2); - const float3 bary0 = BarycentricCoordsForward(pxy, v0xy, v1xy, v2xy); - const float3 bary = !perspective_correct - ? bary0 - : BarycentricPerspectiveCorrectionForward(bary0, v0.z, v1.z, v2.z); - const bool inside = bary.x > 0.0f && bary.y > 0.0f && bary.z > 0.0f; + const float3 b_w = BarycentricCoordsForward(pxy, v0xy, v1xy, v2xy); + const float3 b_pp = !perspective_correct + ? b_w + : BarycentricPerspectiveCorrectionForward(b_w, v0.z, v1.z, v2.z); + + const float3 b_w_clip = + !clip_barycentric_coords ? b_pp : BarycentricClipForward(b_pp); + + const bool inside = b_pp.x > 0.0f && b_pp.y > 0.0f && b_pp.z > 0.0f; const float sign = inside ? -1.0f : 1.0f; // TODO(T52813608) Add support for non-square images. @@ -442,22 +457,29 @@ __global__ void RasterizeMeshesBackwardCudaKernel( // d_zbuf/d_bary_w0 = z0 // d_zbuf/d_bary_w1 = z1 // d_zbuf/d_bary_w2 = z2 - const float3 d_zbuf_d_bary = make_float3(v0.z, v1.z, v2.z); + const float3 d_zbuf_d_bwclip = make_float3(v0.z, v1.z, v2.z); // Total upstream barycentric gradients are the sum of // external upstream gradients and contribution from zbuf. const float3 grad_bary_f_sum = - (grad_bary_upstream + grad_zbuf_upstream * d_zbuf_d_bary); + (grad_bary_upstream + grad_zbuf_upstream * d_zbuf_d_bwclip); + float3 grad_bary0 = grad_bary_f_sum; + + if (clip_barycentric_coords) { + grad_bary0 = BarycentricClipBackward(b_w, grad_bary_f_sum); + } + float dz0_persp = 0.0f, dz1_persp = 0.0f, dz2_persp = 0.0f; if (perspective_correct) { auto perspective_grads = BarycentricPerspectiveCorrectionBackward( - bary0, v0.z, v1.z, v2.z, grad_bary_f_sum); + b_w, v0.z, v1.z, v2.z, grad_bary0); grad_bary0 = thrust::get<0>(perspective_grads); dz0_persp = thrust::get<1>(perspective_grads); dz1_persp = thrust::get<2>(perspective_grads); dz2_persp = thrust::get<3>(perspective_grads); } + auto grad_bary_f = BarycentricCoordsBackward(pxy, v0xy, v1xy, v2xy, grad_bary0); const float2 dbary_d_v0 = thrust::get<1>(grad_bary_f); @@ -467,15 +489,18 @@ __global__ void RasterizeMeshesBackwardCudaKernel( atomicAdd(grad_face_verts + f * 9 + 0, dbary_d_v0.x + ddist_d_v0.x); atomicAdd(grad_face_verts + f * 9 + 1, dbary_d_v0.y + ddist_d_v0.y); atomicAdd( - grad_face_verts + f * 9 + 2, grad_zbuf_upstream * bary.x + dz0_persp); + grad_face_verts + f * 9 + 2, + grad_zbuf_upstream * b_w_clip.x + dz0_persp); atomicAdd(grad_face_verts + f * 9 + 3, dbary_d_v1.x + ddist_d_v1.x); atomicAdd(grad_face_verts + f * 9 + 4, dbary_d_v1.y + ddist_d_v1.y); atomicAdd( - grad_face_verts + f * 9 + 5, grad_zbuf_upstream * bary.y + dz1_persp); + grad_face_verts + f * 9 + 5, + grad_zbuf_upstream * b_w_clip.y + dz1_persp); atomicAdd(grad_face_verts + f * 9 + 6, dbary_d_v2.x + ddist_d_v2.x); atomicAdd(grad_face_verts + f * 9 + 7, dbary_d_v2.y + ddist_d_v2.y); atomicAdd( - grad_face_verts + f * 9 + 8, grad_zbuf_upstream * bary.z + dz2_persp); + grad_face_verts + f * 9 + 8, + grad_zbuf_upstream * b_w_clip.z + dz2_persp); } } } @@ -486,7 +511,8 @@ at::Tensor RasterizeMeshesBackwardCuda( const at::Tensor& grad_zbuf, // (N, H, W, K) const at::Tensor& grad_bary, // (N, H, W, K, 3) const at::Tensor& grad_dists, // (N, H, W, K) - const bool perspective_correct) { + const bool perspective_correct, + const bool clip_barycentric_coords) { // Check inputs are on the same device at::TensorArg face_verts_t{face_verts, "face_verts", 1}, pix_to_face_t{pix_to_face, "pix_to_face", 2}, @@ -523,6 +549,7 @@ at::Tensor RasterizeMeshesBackwardCuda( face_verts.contiguous().data_ptr(), pix_to_face.contiguous().data_ptr(), perspective_correct, + clip_barycentric_coords, N, H, W, @@ -743,6 +770,7 @@ __global__ void RasterizeMeshesFineCudaKernel( const float blur_radius, const int bin_size, const bool perspective_correct, + const bool clip_barycentric_coords, const bool cull_backfaces, const int N, const int B, @@ -808,6 +836,7 @@ __global__ void RasterizeMeshesFineCudaKernel( pxy, K, perspective_correct, + clip_barycentric_coords, cull_backfaces); } @@ -841,6 +870,7 @@ RasterizeMeshesFineCuda( const int bin_size, const int faces_per_pixel, const bool perspective_correct, + const bool clip_barycentric_coords, const bool cull_backfaces) { TORCH_CHECK( face_verts.ndimension() == 3 && face_verts.size(1) == 3 && @@ -890,6 +920,7 @@ RasterizeMeshesFineCuda( blur_radius, bin_size, perspective_correct, + clip_barycentric_coords, cull_backfaces, N, B, diff --git a/pytorch3d/csrc/rasterize_meshes/rasterize_meshes.h b/pytorch3d/csrc/rasterize_meshes/rasterize_meshes.h index 54031b17..8b9528db 100644 --- a/pytorch3d/csrc/rasterize_meshes/rasterize_meshes.h +++ b/pytorch3d/csrc/rasterize_meshes/rasterize_meshes.h @@ -19,6 +19,7 @@ RasterizeMeshesNaiveCpu( const float blur_radius, const int faces_per_pixel, const bool perspective_correct, + const bool clip_barycentric_coords, const bool cull_backfaces); #ifdef WITH_CUDA @@ -31,6 +32,7 @@ RasterizeMeshesNaiveCuda( const float blur_radius, const int num_closest, const bool perspective_correct, + const bool clip_barycentric_coords, const bool cull_backfaces); #endif // Forward pass for rasterizing a batch of meshes. @@ -92,6 +94,7 @@ RasterizeMeshesNaive( const float blur_radius, const int faces_per_pixel, const bool perspective_correct, + const bool clip_barycentric_coords, const bool cull_backfaces) { // TODO: Better type checking. if (face_verts.is_cuda()) { @@ -107,6 +110,7 @@ RasterizeMeshesNaive( blur_radius, faces_per_pixel, perspective_correct, + clip_barycentric_coords, cull_backfaces); #else AT_ERROR("Not compiled with GPU support"); @@ -120,6 +124,7 @@ RasterizeMeshesNaive( blur_radius, faces_per_pixel, perspective_correct, + clip_barycentric_coords, cull_backfaces); } } @@ -134,7 +139,8 @@ torch::Tensor RasterizeMeshesBackwardCpu( const torch::Tensor& grad_bary, const torch::Tensor& grad_zbuf, const torch::Tensor& grad_dists, - const bool perspective_correct); + const bool perspective_correct, + const bool clip_barycentric_coords); #ifdef WITH_CUDA torch::Tensor RasterizeMeshesBackwardCuda( @@ -143,7 +149,8 @@ torch::Tensor RasterizeMeshesBackwardCuda( const torch::Tensor& grad_bary, const torch::Tensor& grad_zbuf, const torch::Tensor& grad_dists, - const bool perspective_correct); + const bool perspective_correct, + const bool clip_barycentric_coords); #endif // Args: @@ -176,7 +183,8 @@ torch::Tensor RasterizeMeshesBackward( const torch::Tensor& grad_zbuf, const torch::Tensor& grad_bary, const torch::Tensor& grad_dists, - const bool perspective_correct) { + const bool perspective_correct, + const bool clip_barycentric_coords) { if (face_verts.is_cuda()) { #ifdef WITH_CUDA CHECK_CUDA(face_verts); @@ -190,7 +198,8 @@ torch::Tensor RasterizeMeshesBackward( grad_zbuf, grad_bary, grad_dists, - perspective_correct); + perspective_correct, + clip_barycentric_coords); #else AT_ERROR("Not compiled with GPU support"); #endif @@ -201,7 +210,8 @@ torch::Tensor RasterizeMeshesBackward( grad_zbuf, grad_bary, grad_dists, - perspective_correct); + perspective_correct, + clip_barycentric_coords); } } @@ -300,6 +310,7 @@ RasterizeMeshesFineCuda( const int bin_size, const int faces_per_pixel, const bool perspective_correct, + const bool clip_barycentric_coords, const bool cull_backfaces); #endif // Args: @@ -356,6 +367,7 @@ RasterizeMeshesFine( const int bin_size, const int faces_per_pixel, const bool perspective_correct, + const bool clip_barycentric_coords, const bool cull_backfaces) { if (face_verts.is_cuda()) { #ifdef WITH_CUDA @@ -369,6 +381,7 @@ RasterizeMeshesFine( bin_size, faces_per_pixel, perspective_correct, + clip_barycentric_coords, cull_backfaces); #else AT_ERROR("Not compiled with GPU support"); @@ -446,6 +459,7 @@ RasterizeMeshes( const int bin_size, const int max_faces_per_bin, const bool perspective_correct, + const bool clip_barycentric_coords, const bool cull_backfaces) { if (bin_size > 0 && max_faces_per_bin > 0) { // Use coarse-to-fine rasterization @@ -465,6 +479,7 @@ RasterizeMeshes( bin_size, faces_per_pixel, perspective_correct, + clip_barycentric_coords, cull_backfaces); } else { // Use the naive per-pixel implementation @@ -476,6 +491,7 @@ RasterizeMeshes( blur_radius, faces_per_pixel, perspective_correct, + clip_barycentric_coords, cull_backfaces); } } diff --git a/pytorch3d/csrc/rasterize_meshes/rasterize_meshes_cpu.cpp b/pytorch3d/csrc/rasterize_meshes/rasterize_meshes_cpu.cpp index b01bdbaa..89a7dc3a 100644 --- a/pytorch3d/csrc/rasterize_meshes/rasterize_meshes_cpu.cpp +++ b/pytorch3d/csrc/rasterize_meshes/rasterize_meshes_cpu.cpp @@ -108,6 +108,7 @@ RasterizeMeshesNaiveCpu( const float blur_radius, const int faces_per_pixel, const bool perspective_correct, + const bool clip_barycentric_coords, const bool cull_backfaces) { if (face_verts.ndimension() != 3 || face_verts.size(1) != 3 || face_verts.size(2) != 3) { @@ -213,8 +214,12 @@ RasterizeMeshesNaiveCpu( ? bary0 : BarycentricPerspectiveCorrectionForward(bary0, z0, z1, z2); + const vec3 bary_clip = + !clip_barycentric_coords ? bary : BarycentricClipForward(bary); + // Use barycentric coordinates to get the depth of the current pixel - const float pz = (bary.x * z0 + bary.y * z1 + bary.z * z2); + const float pz = + (bary_clip.x * z0 + bary_clip.y * z1 + bary_clip.z * z2); if (pz < 0) { continue; // Point is behind the image plane so ignore. @@ -236,7 +241,7 @@ RasterizeMeshesNaiveCpu( continue; } // The current pixel lies inside the current face. - q.emplace(pz, f, signed_dist, bary.x, bary.y, bary.z); + q.emplace(pz, f, signed_dist, bary_clip.x, bary_clip.y, bary_clip.z); if (static_cast(q.size()) > K) { q.pop(); } @@ -264,7 +269,8 @@ torch::Tensor RasterizeMeshesBackwardCpu( const torch::Tensor& grad_zbuf, // (N, H, W, K) const torch::Tensor& grad_bary, // (N, H, W, K, 3) const torch::Tensor& grad_dists, // (N, H, W, K) - const bool perspective_correct) { + const bool perspective_correct, + const bool clip_barycentric_coords) { const int F = face_verts.size(0); const int N = pix_to_face.size(0); const int H = pix_to_face.size(1); @@ -335,6 +341,8 @@ torch::Tensor RasterizeMeshesBackwardCpu( const vec3 bary = !perspective_correct ? bary0 : BarycentricPerspectiveCorrectionForward(bary0, z0, z1, z2); + const vec3 bary_clip = + !clip_barycentric_coords ? bary : BarycentricClipForward(bary); // Distances inside the face are negative so get the // correct sign to apply to the upstream gradient. @@ -354,22 +362,28 @@ torch::Tensor RasterizeMeshesBackwardCpu( // d_zbuf/d_bary_w0 = z0 // d_zbuf/d_bary_w1 = z1 // d_zbuf/d_bary_w2 = z2 - const vec3 d_zbuf_d_bary(z0, z1, z2); + const vec3 d_zbuf_d_baryclip(z0, z1, z2); // Total upstream barycentric gradients are the sum of // external upstream gradients and contribution from zbuf. - vec3 grad_bary_f_sum = - (grad_bary_upstream + grad_zbuf_upstream * d_zbuf_d_bary); + const vec3 grad_bary_f_sum = + (grad_bary_upstream + grad_zbuf_upstream * d_zbuf_d_baryclip); vec3 grad_bary0 = grad_bary_f_sum; + + if (clip_barycentric_coords) { + grad_bary0 = BarycentricClipBackward(bary, grad_bary0); + } + if (perspective_correct) { auto perspective_grads = BarycentricPerspectiveCorrectionBackward( - bary0, z0, z1, z2, grad_bary_f_sum); + bary0, z0, z1, z2, grad_bary0); grad_bary0 = std::get<0>(perspective_grads); grad_face_verts[f][0][2] += std::get<1>(perspective_grads); grad_face_verts[f][1][2] += std::get<2>(perspective_grads); grad_face_verts[f][2][2] += std::get<3>(perspective_grads); } + auto grad_bary_f = BarycentricCoordsBackward(pxy, v0xy, v1xy, v2xy, grad_bary0); const vec2 dbary_d_v0 = std::get<1>(grad_bary_f); @@ -379,13 +393,13 @@ torch::Tensor RasterizeMeshesBackwardCpu( // Update output gradient buffer. grad_face_verts[f][0][0] += dbary_d_v0.x + ddist_d_v0.x; grad_face_verts[f][0][1] += dbary_d_v0.y + ddist_d_v0.y; - grad_face_verts[f][0][2] += grad_zbuf_upstream * bary.x; + grad_face_verts[f][0][2] += grad_zbuf_upstream * bary_clip.x; grad_face_verts[f][1][0] += dbary_d_v1.x + ddist_d_v1.x; grad_face_verts[f][1][1] += dbary_d_v1.y + ddist_d_v1.y; - grad_face_verts[f][1][2] += grad_zbuf_upstream * bary.y; + grad_face_verts[f][1][2] += grad_zbuf_upstream * bary_clip.y; grad_face_verts[f][2][0] += dbary_d_v2.x + ddist_d_v2.x; grad_face_verts[f][2][1] += dbary_d_v2.y + ddist_d_v2.y; - grad_face_verts[f][2][2] += grad_zbuf_upstream * bary.z; + grad_face_verts[f][2][2] += grad_zbuf_upstream * bary_clip.z; } } } diff --git a/pytorch3d/csrc/utils/geometry_utils.cuh b/pytorch3d/csrc/utils/geometry_utils.cuh index 5fd4ed6b..f8d6b2a8 100644 --- a/pytorch3d/csrc/utils/geometry_utils.cuh +++ b/pytorch3d/csrc/utils/geometry_utils.cuh @@ -221,8 +221,108 @@ BarycentricPerspectiveCorrectionBackward( return thrust::make_tuple(grad_bary, grad_z0, grad_z1, grad_z2); } -// Calculate minimum squared distance between a line segment (v1 - v0) and a -// point p. +// Clip negative barycentric coordinates to 0.0 and renormalize so +// the barycentric coordinates for a point sum to 1. When the blur_radius +// is greater than 0, a face will still be recorded as overlapping a pixel +// if the pixel is outisde the face. In this case at least one of the +// barycentric coordinates for the pixel relative to the face will be negative. +// Clipping will ensure that the texture and z buffer are interpolated +// correctly. +// +// Args +// bary: (w0, w1, w2) barycentric coordinates which can be outside the +// range [0, 1]. +// +// Returns +// bary: (w0, w1, w2) barycentric coordinates in the range [0, 1] which +// satisfy the condition: sum(w0, w1, w2) = 1.0. +// +__device__ inline float3 BarycentricClipForward(const float3 bary) { + float3 w = make_float3(0.0f, 0.0f, 0.0f); + // Clamp lower bound only + w.x = max(bary.x, 0.0); + w.y = max(bary.y, 0.0); + w.z = max(bary.z, 0.0); + float w_sum = w.x + w.y + w.z; + w_sum = fmaxf(w_sum, 1e-5); + w.x /= w_sum; + w.y /= w_sum; + w.z /= w_sum; + + return w; +} + +// Backward pass for barycentric coordinate clipping. +// +// Args +// bary: (w0, w1, w2) barycentric coordinates which can be outside the +// range [0, 1]. +// grad_baryclip_upstream: vec3 Upstream gradient for each of the clipped +// barycentric coordinates [grad_w0, grad_w1, grad_w2]. +// +// Returns +// vec3 of gradients for the unclipped barycentric coordinates: +// (grad_w0, grad_w1, grad_w2) +// +__device__ inline float3 BarycentricClipBackward( + const float3 bary, + const float3 grad_baryclip_upstream) { + // Redo some of the forward pass calculations + float3 w = make_float3(0.0f, 0.0f, 0.0f); + // Clamp lower bound only + w.x = max(bary.x, 0.0); + w.y = max(bary.y, 0.0); + w.z = max(bary.z, 0.0); + float w_sum = w.x + w.y + w.z; + + float3 grad_bary = make_float3(1.0f, 1.0f, 1.0f); + float3 grad_clip = make_float3(1.0f, 1.0f, 1.0f); + float3 grad_sum = make_float3(1.0f, 1.0f, 1.0f); + + // Check if sum was clipped. + float grad_sum_clip = 1.0f; + if (w_sum < 1e-5) { + grad_sum_clip = 0.0f; + w_sum = 1e-5; + } + + // Check if any of bary values have been clipped. + if (bary.x < 0.0f) { + grad_clip.x = 0.0f; + } + if (bary.y < 0.0f) { + grad_clip.y = 0.0f; + } + if (bary.z < 0.0f) { + grad_clip.z = 0.0f; + } + + // Gradients of the sum. + grad_sum.x = -w.x / (pow(w_sum, 2.0f)) * grad_sum_clip; + grad_sum.y = -w.y / (pow(w_sum, 2.0f)) * grad_sum_clip; + grad_sum.z = -w.z / (pow(w_sum, 2.0f)) * grad_sum_clip; + + // Gradients for each of the bary coordinates including the cross terms + // from the sum. + grad_bary.x = grad_clip.x * + (grad_baryclip_upstream.x * (1.0f / w_sum + grad_sum.x) + + grad_baryclip_upstream.y * (grad_sum.y) + + grad_baryclip_upstream.z * (grad_sum.z)); + + grad_bary.y = grad_clip.y * + (grad_baryclip_upstream.y * (1.0f / w_sum + grad_sum.y) + + grad_baryclip_upstream.x * (grad_sum.x) + + grad_baryclip_upstream.z * (grad_sum.z)); + + grad_bary.z = grad_clip.z * + (grad_baryclip_upstream.z * (1.0f / w_sum + grad_sum.z) + + grad_baryclip_upstream.x * (grad_sum.x) + + grad_baryclip_upstream.y * (grad_sum.y)); + + return grad_bary; +} + +// Return minimum distance between line segment (v1 - v0) and point p. // // Args: // p: Coordinates of a point. diff --git a/pytorch3d/csrc/utils/geometry_utils.h b/pytorch3d/csrc/utils/geometry_utils.h index e283c8f4..f24f58ac 100644 --- a/pytorch3d/csrc/utils/geometry_utils.h +++ b/pytorch3d/csrc/utils/geometry_utils.h @@ -242,8 +242,108 @@ inline std::tuple, T, T, T> BarycentricPerspectiveCorrectionBackward( return std::make_tuple(grad_bary, grad_z0, grad_z1, grad_z2); } -// Calculate minimum squared distance between a line segment (v1 - v0) and a -// point p. +// Clip negative barycentric coordinates to 0.0 and renormalize so +// the barycentric coordinates for a point sum to 1. When the blur_radius +// is greater than 0, a face will still be recorded as overlapping a pixel +// if the pixel is outisde the face. In this case at least one of the +// barycentric coordinates for the pixel relative to the face will be negative. +// Clipping will ensure that the texture and z buffer are interpolated +// correctly. +// +// Args +// bary: (w0, w1, w2) barycentric coordinates which can contain values < 0. +// +// Returns +// bary: (w0, w1, w2) barycentric coordinates in the range [0, 1] which +// satisfy the condition: sum(w0, w1, w2) = 1.0. +// +template +vec3 BarycentricClipForward(const vec3 bary) { + vec3 w(0.0f, 0.0f, 0.0f); + // Only clamp negative values to 0.0. + // No need to clamp values > 1.0 as they will be renormalized. + w.x = std::max(bary.x, 0.0f); + w.y = std::max(bary.y, 0.0f); + w.z = std::max(bary.z, 0.0f); + float w_sum = w.x + w.y + w.z; + w_sum = std::fmaxf(w_sum, 1e-5); + w.x /= w_sum; + w.y /= w_sum; + w.z /= w_sum; + return w; +} + +// Backward pass for barycentric coordinate clipping. +// +// Args +// bary: (w0, w1, w2) barycentric coordinates which can contain values < 0. +// grad_baryclip_upstream: vec3 Upstream gradient for each of the clipped +// barycentric coordinates [grad_w0, grad_w1, grad_w2]. +// +// Returns +// vec3 of gradients for the unclipped barycentric coordinates: +// (grad_w0, grad_w1, grad_w2) +// +template +vec3 BarycentricClipBackward( + const vec3 bary, + const vec3 grad_baryclip_upstream) { + // Redo some of the forward pass calculations + vec3 w(0.0f, 0.0f, 0.0f); + w.x = std::max(bary.x, 0.0f); + w.y = std::max(bary.y, 0.0f); + w.z = std::max(bary.z, 0.0f); + float w_sum = w.x + w.y + w.z; + + vec3 grad_bary(1.0f, 1.0f, 1.0f); + vec3 grad_clip(1.0f, 1.0f, 1.0f); + vec3 grad_sum(1.0f, 1.0f, 1.0f); + + // Check if the sum was clipped. + float grad_sum_clip = 1.0f; + if (w_sum < 1e-5) { + grad_sum_clip = 0.0f; + w_sum = 1e-5; + } + + // Check if any of the bary coordinates have been clipped. + // Only negative values are clamped to 0.0. + if (bary.x < 0.0f) { + grad_clip.x = 0.0f; + } + if (bary.y < 0.0f) { + grad_clip.y = 0.0f; + } + if (bary.z < 0.0f) { + grad_clip.z = 0.0f; + } + + // Gradients of the sum. + grad_sum.x = -w.x / (pow(w_sum, 2.0f)) * grad_sum_clip; + grad_sum.y = -w.y / (pow(w_sum, 2.0f)) * grad_sum_clip; + grad_sum.z = -w.z / (pow(w_sum, 2.0f)) * grad_sum_clip; + + // Gradients for each of the bary coordinates including the cross terms + // from the sum. + grad_bary.x = grad_clip.x * + (grad_baryclip_upstream.x * (1.0f / w_sum + grad_sum.x) + + grad_baryclip_upstream.y * (grad_sum.y) + + grad_baryclip_upstream.z * (grad_sum.z)); + + grad_bary.y = grad_clip.y * + (grad_baryclip_upstream.y * (1.0f / w_sum + grad_sum.y) + + grad_baryclip_upstream.x * (grad_sum.x) + + grad_baryclip_upstream.z * (grad_sum.z)); + + grad_bary.z = grad_clip.z * + (grad_baryclip_upstream.z * (1.0f / w_sum + grad_sum.z) + + grad_baryclip_upstream.x * (grad_sum.x) + + grad_baryclip_upstream.y * (grad_sum.y)); + + return grad_bary; +} + +// Calculate minimum distance between a line segment (v1 - v0) and point p. // // Args: // p: Coordinates of a point. diff --git a/pytorch3d/renderer/mesh/rasterize_meshes.py b/pytorch3d/renderer/mesh/rasterize_meshes.py index 125eec20..984d3352 100644 --- a/pytorch3d/renderer/mesh/rasterize_meshes.py +++ b/pytorch3d/renderer/mesh/rasterize_meshes.py @@ -24,6 +24,7 @@ def rasterize_meshes( bin_size: Optional[int] = None, max_faces_per_bin: Optional[int] = None, perspective_correct: bool = False, + clip_barycentric_coords: bool = False, cull_backfaces: bool = False, ): """ @@ -143,6 +144,7 @@ def rasterize_meshes( bin_size, max_faces_per_bin, perspective_correct, + clip_barycentric_coords, cull_backfaces, ) @@ -183,6 +185,7 @@ class _RasterizeFaceVerts(torch.autograd.Function): bin_size: int = 0, max_faces_per_bin: int = 0, perspective_correct: bool = False, + clip_barycentric_coords: bool = False, cull_backfaces: bool = False, ): # pyre-fixme[16]: Module `pytorch3d` has no attribute `_C`. @@ -196,11 +199,13 @@ class _RasterizeFaceVerts(torch.autograd.Function): bin_size, max_faces_per_bin, perspective_correct, + clip_barycentric_coords, cull_backfaces, ) ctx.save_for_backward(face_verts, pix_to_face) ctx.mark_non_differentiable(pix_to_face) ctx.perspective_correct = perspective_correct + ctx.clip_barycentric_coords = clip_barycentric_coords return pix_to_face, zbuf, barycentric_coords, dists @staticmethod @@ -214,6 +219,7 @@ class _RasterizeFaceVerts(torch.autograd.Function): grad_bin_size = None grad_max_faces_per_bin = None grad_perspective_correct = None + grad_clip_barycentric_coords = None grad_cull_backfaces = None face_verts, pix_to_face = ctx.saved_tensors grad_face_verts = _C.rasterize_meshes_backward( @@ -223,6 +229,7 @@ class _RasterizeFaceVerts(torch.autograd.Function): grad_barycentric_coords, grad_dists, ctx.perspective_correct, + ctx.clip_barycentric_coords, ) grads = ( grad_face_verts, @@ -234,6 +241,7 @@ class _RasterizeFaceVerts(torch.autograd.Function): grad_bin_size, grad_max_faces_per_bin, grad_perspective_correct, + grad_clip_barycentric_coords, grad_cull_backfaces, ) return grads @@ -250,6 +258,7 @@ def rasterize_meshes_python( blur_radius: float = 0.0, faces_per_pixel: int = 8, perspective_correct: bool = False, + clip_barycentric_coords: bool = False, cull_backfaces: bool = False, ): """ @@ -356,6 +365,14 @@ def rasterize_meshes_python( top2 = z0 * z1 * l2 bot = top0 + top1 + top2 bary = torch.stack([top0 / bot, top1 / bot, top2 / bot]) + + # Check if inside before clipping + inside = all(x > 0.0 for x in bary) + + # Barycentric clipping + if clip_barycentric_coords: + bary = barycentric_coordinates_clip(bary) + # use clipped barycentric coords to calculate the z value pz = bary[0] * v0[2] + bary[1] * v1[2] + bary[2] * v2[2] # Check if point is behind the image. @@ -365,7 +382,6 @@ def rasterize_meshes_python( # Calculate signed 2D distance from point to face. # Points inside the triangle have negative distance. dist = point_triangle_distance(pxy, v0[:2], v1[:2], v2[:2]) - inside = all(x > 0.0 for x in bary) signed_dist = dist * -1.0 if inside else dist @@ -433,6 +449,33 @@ def edge_function(p, v0, v1): return (p[0] - v0[0]) * (v1[1] - v0[1]) - (p[1] - v0[1]) * (v1[0] - v0[0]) +def barycentric_coordinates_clip(bary): + """ + Clip negative barycentric coordinates to 0.0 and renormalize so + the barycentric coordinates for a point sum to 1. When the blur_radius + is greater than 0, a face will still be recorded as overlapping a pixel + if the pixel is outisde the face. In this case at least one of the + barycentric coordinates for the pixel relative to the face will be negative. + Clipping will ensure that the texture and z buffer are interpolated correctly. + + Args: + bary: tuple of barycentric coordinates + + Returns + bary_clip: (w0, w1, w2) barycentric coordinates with no negative values. + """ + # Only negative values are clamped to 0.0. + w0_clip = torch.clamp(bary[0], min=0.0) + w1_clip = torch.clamp(bary[1], min=0.0) + w2_clip = torch.clamp(bary[2], min=0.0) + bary_sum = torch.clamp(w0_clip + w1_clip + w2_clip, min=1e-5) + w0_clip = w0_clip / bary_sum + w1_clip = w1_clip / bary_sum + w2_clip = w2_clip / bary_sum + + return (w0_clip, w1_clip, w2_clip) + + def barycentric_coordinates(p, v0, v1, v2): """ Compute the barycentric coordinates of a point relative to a triangle. diff --git a/pytorch3d/renderer/mesh/rasterizer.py b/pytorch3d/renderer/mesh/rasterizer.py index 3ea2ae16..eefe1cf4 100644 --- a/pytorch3d/renderer/mesh/rasterizer.py +++ b/pytorch3d/renderer/mesh/rasterizer.py @@ -26,6 +26,7 @@ class RasterizationSettings: "bin_size", "max_faces_per_bin", "perspective_correct", + "clip_barycentric_coords", "cull_backfaces", ] @@ -37,6 +38,7 @@ class RasterizationSettings: bin_size: Optional[int] = None, max_faces_per_bin: Optional[int] = None, perspective_correct: bool = False, + clip_barycentric_coords: bool = False, cull_backfaces: bool = False, ): self.image_size = image_size @@ -45,6 +47,7 @@ class RasterizationSettings: self.bin_size = bin_size self.max_faces_per_bin = max_faces_per_bin self.perspective_correct = perspective_correct + self.clip_barycentric_coords = clip_barycentric_coords self.cull_backfaces = cull_backfaces @@ -127,6 +130,7 @@ class MeshRasterizer(nn.Module): bin_size=raster_settings.bin_size, max_faces_per_bin=raster_settings.max_faces_per_bin, perspective_correct=raster_settings.perspective_correct, + clip_barycentric_coords=raster_settings.clip_barycentric_coords, cull_backfaces=raster_settings.cull_backfaces, ) return Fragments( diff --git a/pytorch3d/renderer/mesh/renderer.py b/pytorch3d/renderer/mesh/renderer.py index 87835a70..46c0231f 100644 --- a/pytorch3d/renderer/mesh/renderer.py +++ b/pytorch3d/renderer/mesh/renderer.py @@ -49,21 +49,6 @@ class MeshRenderer(nn.Module): the range for the corresponding face. """ fragments = self.rasterizer(meshes_world, **kwargs) - raster_settings = kwargs.get("raster_settings", self.rasterizer.raster_settings) - if raster_settings.blur_radius > 0.0: - # TODO: potentially move barycentric clipping to the rasterizer - # if no downstream functions requires unclipped values. - # This will avoid unnecssary re-interpolation of the z buffer. - clipped_bary_coords = _clip_barycentric_coordinates(fragments.bary_coords) - clipped_zbuf = _interpolate_zbuf( - fragments.pix_to_face, clipped_bary_coords, meshes_world - ) - fragments = Fragments( - bary_coords=clipped_bary_coords, - zbuf=clipped_zbuf, - dists=fragments.dists, - pix_to_face=fragments.pix_to_face, - ) images = self.shader(fragments, meshes_world, **kwargs) return images diff --git a/pytorch3d/renderer/mesh/utils.py b/pytorch3d/renderer/mesh/utils.py index f61f4faf..749a746d 100644 --- a/pytorch3d/renderer/mesh/utils.py +++ b/pytorch3d/renderer/mesh/utils.py @@ -20,9 +20,13 @@ def _clip_barycentric_coordinates(bary) -> torch.Tensor: if bary.shape[-1] != 3: msg = "Expected barycentric coords to have last dim = 3; got %r" raise ValueError(msg % (bary.shape,)) + ndims = bary.ndim - 1 + mask = bary.eq(-1).all(dim=-1, keepdim=True).expand(*((-1,) * ndims + (3,))) clipped = bary.clamp(min=0.0) + clipped[mask] = 0.0 clipped_sum = torch.clamp(clipped.sum(dim=-1, keepdim=True), min=1e-5) clipped = clipped / clipped_sum + clipped[mask] = -1.0 return clipped @@ -49,6 +53,8 @@ def _interpolate_zbuf( verts = meshes.verts_packed() faces = meshes.faces_packed() faces_verts_z = verts[faces][..., 2][..., None] # (F, 3, 1) - return interpolate_face_attributes(pix_to_face, barycentric_coords, faces_verts_z)[ + zbuf = interpolate_face_attributes(pix_to_face, barycentric_coords, faces_verts_z)[ ..., 0 ] # (1, H, W, K) + zbuf[pix_to_face == -1] = -1 + return zbuf diff --git a/tests/bm_barycentric_clipping.py b/tests/bm_barycentric_clipping.py new file mode 100644 index 00000000..0941a97c --- /dev/null +++ b/tests/bm_barycentric_clipping.py @@ -0,0 +1,112 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. + +from itertools import product + +import torch +from fvcore.common.benchmark import benchmark +from pytorch3d.renderer.cameras import OpenGLPerspectiveCameras, look_at_view_transform +from pytorch3d.renderer.mesh.rasterizer import ( + Fragments, + MeshRasterizer, + RasterizationSettings, +) +from pytorch3d.renderer.mesh.utils import ( + _clip_barycentric_coordinates, + _interpolate_zbuf, +) +from pytorch3d.utils.ico_sphere import ico_sphere + + +def baryclip_cuda( + num_meshes: int = 8, + ico_level: int = 5, + image_size: int = 64, + faces_per_pixel: int = 50, + device="cuda", +): + # Init meshes + sphere_meshes = ico_sphere(ico_level, device).extend(num_meshes) + # Init transform + R, T = look_at_view_transform(1.0, 0.0, 0.0) + cameras = OpenGLPerspectiveCameras(device=device, R=R, T=T) + # Init rasterizer + raster_settings = RasterizationSettings( + image_size=image_size, + blur_radius=1e-4, + faces_per_pixel=faces_per_pixel, + clip_barycentric_coords=True, + ) + rasterizer = MeshRasterizer(cameras=cameras, raster_settings=raster_settings) + + torch.cuda.synchronize() + + def raster_fn(): + rasterizer(sphere_meshes) + torch.cuda.synchronize() + + return raster_fn + + +def baryclip_pytorch( + num_meshes: int = 8, + ico_level: int = 5, + image_size: int = 64, + faces_per_pixel: int = 50, + device="cuda", +): + # Init meshes + sphere_meshes = ico_sphere(ico_level, device).extend(num_meshes) + # Init transform + R, T = look_at_view_transform(1.0, 0.0, 0.0) + cameras = OpenGLPerspectiveCameras(device=device, R=R, T=T) + # Init rasterizer + raster_settings = RasterizationSettings( + image_size=image_size, + blur_radius=1e-4, + faces_per_pixel=faces_per_pixel, + clip_barycentric_coords=False, + ) + rasterizer = MeshRasterizer(cameras=cameras, raster_settings=raster_settings) + + torch.cuda.synchronize() + + def raster_fn(): + fragments = rasterizer(sphere_meshes) + + # Clip bary and reinterpolate + clipped_bary_coords = _clip_barycentric_coordinates(fragments.bary_coords) + clipped_zbuf = _interpolate_zbuf( + fragments.pix_to_face, clipped_bary_coords, sphere_meshes + ) + fragments = Fragments( + bary_coords=clipped_bary_coords, + zbuf=clipped_zbuf, + dists=fragments.dists, + pix_to_face=fragments.pix_to_face, + ) + torch.cuda.synchronize() + + return raster_fn + + +def bm_barycentric_clip() -> None: + if torch.cuda.is_available(): + kwargs_list = [] + num_meshes = [1, 8] + ico_level = [0, 4] + image_size = [64, 128, 256] + faces_per_pixel = [10, 75, 100] + test_cases = product(num_meshes, ico_level, image_size, faces_per_pixel) + for case in test_cases: + n, ic, im, nf = case + kwargs_list.append( + { + "num_meshes": n, + "ico_level": ic, + "image_size": im, + "faces_per_pixel": nf, + } + ) + + benchmark(baryclip_cuda, "BARY_CLIP_CUDA", kwargs_list, warmup_iters=1) + benchmark(baryclip_pytorch, "BARY_CLIP_PYTORCH", kwargs_list, warmup_iters=1) diff --git a/tests/data/test_blurry_textured_rendering.png b/tests/data/test_blurry_textured_rendering.png index ab3e8d99..30a870ad 100644 Binary files a/tests/data/test_blurry_textured_rendering.png and b/tests/data/test_blurry_textured_rendering.png differ diff --git a/tests/test_rasterize_meshes.py b/tests/test_rasterize_meshes.py index 5fde5ac7..c6746ccf 100644 --- a/tests/test_rasterize_meshes.py +++ b/tests/test_rasterize_meshes.py @@ -10,6 +10,10 @@ from pytorch3d.renderer.mesh.rasterize_meshes import ( rasterize_meshes, rasterize_meshes_python, ) +from pytorch3d.renderer.mesh.utils import ( + _clip_barycentric_coordinates, + _interpolate_zbuf, +) from pytorch3d.structures import Meshes from pytorch3d.utils import ico_sphere @@ -21,6 +25,7 @@ class TestRasterizeMeshes(TestCaseMixin, unittest.TestCase): self._simple_blurry_raster(rasterize_meshes_python, device, bin_size=-1) self._test_behind_camera(rasterize_meshes_python, device, bin_size=-1) self._test_perspective_correct(rasterize_meshes_python, device, bin_size=-1) + self._test_barycentric_clipping(rasterize_meshes_python, device, bin_size=-1) self._test_back_face_culling(rasterize_meshes_python, device, bin_size=-1) def test_simple_cpu_naive(self): @@ -170,8 +175,29 @@ class TestRasterizeMeshes(TestCaseMixin, unittest.TestCase): verts2.requires_grad = True meshes_cuda = Meshes(verts=[verts2], faces=[faces2]) - args_cpu = (meshes_cpu, image_size, radius, faces_per_pixel) - args_cuda = (meshes_cuda, image_size, radius, faces_per_pixel, 0, 0) + barycentric_clip = True + args_cpu = ( + meshes_cpu, + image_size, + radius, + faces_per_pixel, + None, + None, + False, + barycentric_clip, + False, + ) + args_cuda = ( + meshes_cuda, + image_size, + radius, + faces_per_pixel, + 0, + 0, + False, + barycentric_clip, + False, + ) self._compare_impls( rasterize_meshes, rasterize_meshes, @@ -333,6 +359,39 @@ class TestRasterizeMeshes(TestCaseMixin, unittest.TestCase): idxs_cuda[:K] = sorted(idxs_cuda[:K]) self.assertEqual(idxs_cpu, idxs_cuda) + def test_python_vs_cpp_bary_clip(self): + torch.manual_seed(232) + N = 2 + V = 10 + F = 5 + verts1 = torch.randn(N, V, 3, requires_grad=True) + verts2 = verts1.detach().clone().requires_grad_(True) + faces = torch.randint(V, size=(N, F, 3)) + meshes1 = Meshes(verts1, faces) + meshes2 = Meshes(verts2, faces) + + kwargs = {"image_size": 24, "clip_barycentric_coords": True} + fn1 = functools.partial(rasterize_meshes, meshes1, **kwargs) + fn2 = functools.partial(rasterize_meshes_python, meshes2, **kwargs) + args = () + self._compare_impls(fn1, fn2, args, args, verts1, verts2, compare_grads=True) + + def test_cpp_vs_cuda_bary_clip(self): + meshes = ico_sphere(2, device=torch.device("cpu")) + verts1, faces1 = meshes.get_mesh_verts_faces(0) + verts1.requires_grad = True + meshes1 = Meshes(verts=[verts1], faces=[faces1]) + device = get_random_cuda_device() + verts2 = verts1.detach().to(device).requires_grad_(True) + faces2 = faces1.detach().clone().to(device) + meshes2 = Meshes(verts=[verts2], faces=[faces2]) + + kwargs = {"image_size": 64, "clip_barycentric_coords": True} + fn1 = functools.partial(rasterize_meshes, meshes1, **kwargs) + fn2 = functools.partial(rasterize_meshes, meshes2, bin_size=0, **kwargs) + args = () + self._compare_impls(fn1, fn2, args, args, verts1, verts2, compare_grads=True) + def test_python_vs_cpp_perspective_correct(self): torch.manual_seed(232) N = 2 @@ -621,6 +680,82 @@ class TestRasterizeMeshes(TestCaseMixin, unittest.TestCase): self.assertLess(zbuf_f_bary_diff, 1e-4) self.assertLess(zbuf_t_bary_diff, 1e-4) + def _test_barycentric_clipping(self, rasterize_meshes_fn, device, bin_size=None): + # fmt: off + verts = torch.tensor([ + [-0.4, -0.4, 10], # noqa: E241, E201 + [ 0.4, -0.4, 10], # noqa: E241, E201 + [ 0.0, 0.4, 20], # noqa: E241, E201 + ], dtype=torch.float32, device=device) + # fmt: on + faces = torch.tensor([[0, 1, 2]], device=device) + meshes = Meshes(verts=[verts], faces=[faces]) + kwargs = { + "meshes": meshes, + "image_size": 5, + "faces_per_pixel": 1, + "blur_radius": 0.2, + "perspective_correct": False, + "clip_barycentric_coords": False, # Initially set this to false + } + if bin_size != -1: + kwargs["bin_size"] = bin_size + + # Run with and without perspective correction + idx_f, zbuf_f, bary_f, dists_f = rasterize_meshes_fn(**kwargs) + + # fmt: off + expected_bary = torch.tensor([ + [ + [-1.0000, -1.0000, -1.0000], # noqa: E241, E201 + [-1.0000, -1.0000, -1.0000], # noqa: E241, E201 + [-0.2500, -0.2500, 1.5000], # noqa: E241, E201 + [-1.0000, -1.0000, -1.0000], # noqa: E241, E201 + [-1.0000, -1.0000, -1.0000] # noqa: E241, E201 + ], + [ + [-1.0000, -1.0000, -1.0000], # noqa: E241, E201 + [-0.5000, 0.5000, 1.0000], # noqa: E241, E201 + [-0.0000, -0.0000, 1.0000], # noqa: E241, E201 + [ 0.5000, -0.5000, 1.0000], # noqa: E241, E201 + [-1.0000, -1.0000, -1.0000] # noqa: E241, E201 + ], + [ + [-1.0000, -1.0000, -1.0000], # noqa: E241, E201 + [-0.2500, 0.7500, 0.5000], # noqa: E241, E201 + [ 0.2500, 0.2500, 0.5000], # noqa: E241, E201 + [ 0.7500, -0.2500, 0.5000], # noqa: E241, E201 + [-1.0000, -1.0000, -1.0000] # noqa: E241, E201 + ], + [ + [-0.5000, 1.5000, -0.0000], # noqa: E241, E201 + [-0.0000, 1.0000, -0.0000], # noqa: E241, E201 + [ 0.5000, 0.5000, -0.0000], # noqa: E241, E201 + [ 1.0000, -0.0000, -0.0000], # noqa: E241, E201 + [ 1.5000, -0.5000, 0.0000] # noqa: E241, E201 + ], + [ + [-1.0000, -1.0000, -1.0000], # noqa: E241, E201 + [ 0.2500, 1.2500, -0.5000], # noqa: E241, E201 + [ 0.7500, 0.7500, -0.5000], # noqa: E241, E201 + [ 1.2500, 0.2500, -0.5000], # noqa: E241, E201 + [-1.0000, -1.0000, -1.0000] # noqa: E241, E201 + ] + ], dtype=torch.float32, device=device).view(1, 5, 5, 1, 3) + # fmt: on + + self.assertClose(expected_bary, bary_f, atol=1e-4) + + # calculate the expected clipped barycentrics and zbuf + expected_bary_clipped = _clip_barycentric_coordinates(expected_bary) + expected_z_clipped = _interpolate_zbuf(idx_f, expected_bary_clipped, meshes) + + kwargs["clip_barycentric_coords"] = True + idx_t, zbuf_t, bary_t, dists_t = rasterize_meshes_fn(**kwargs) + + self.assertClose(expected_bary_clipped, bary_t, atol=1e-4) + self.assertClose(expected_z_clipped, zbuf_t, atol=1e-4) + def _test_behind_camera(self, rasterize_meshes_fn, device, bin_size=None): """ All verts are behind the camera so nothing should get rasterized. diff --git a/tests/test_render_meshes.py b/tests/test_render_meshes.py index 0ae19471..c4325e30 100644 --- a/tests/test_render_meshes.py +++ b/tests/test_render_meshes.py @@ -212,6 +212,7 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase): image_size=512, blur_radius=np.log(1.0 / 1e-4 - 1.0) * blend_params.sigma, faces_per_pixel=80, + clip_barycentric_coords=True, ) # Init rasterizer settings @@ -269,11 +270,19 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase): # the cow is facing the -z direction. lights.location = torch.tensor([0.0, 0.0, 2.0], device=device)[None] + blend_params = BlendParams( + sigma=1e-1, + gamma=1e-4, + background_color=torch.tensor([1.0, 1.0, 1.0], device=device), + ) # Init renderer renderer = MeshRenderer( rasterizer=MeshRasterizer(cameras=cameras, raster_settings=raster_settings), shader=TexturedSoftPhongShader( - lights=lights, cameras=cameras, materials=materials + lights=lights, + cameras=cameras, + materials=materials, + blend_params=blend_params, ), ) @@ -346,6 +355,7 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase): image_size=512, blur_radius=np.log(1.0 / 1e-4 - 1.0) * blend_params.sigma, faces_per_pixel=100, + clip_barycentric_coords=True, ) # Load reference image