barycentric clipping in cuda/c++

Summary:
Added support for barycentric clipping in the C++/CUDA rasterization kernels which can be switched on/off via a rasterization setting.

Added tests and a benchmark to compare with the current implementation in PyTorch. For some cases with large image sizes or many faces per pixel, the CUDA version is 10x faster.

Reviewed By: gkioxari

Differential Revision: D21705503

fbshipit-source-id: e835c0f927f1e5088ca89020aef5ff27ac3a8769
This commit is contained in:
Nikhila Ravi
2020-07-16 10:15:30 -07:00
committed by Facebook GitHub Bot
parent bce396df93
commit cc70950f40
13 changed files with 611 additions and 55 deletions

View File

@@ -114,6 +114,7 @@ __device__ void CheckPixelInsideFace(
const float2 pxy, // Coordinates of the pixel
const int K,
const bool perspective_correct,
const bool clip_barycentric_coords,
const bool cull_backfaces) {
const auto v012 = GetSingleFaceVerts(face_verts, face_idx);
const float3 v0 = thrust::get<0>(v012);
@@ -149,8 +150,12 @@ __device__ void CheckPixelInsideFace(
const float3 p_bary = !perspective_correct
? p_bary0
: BarycentricPerspectiveCorrectionForward(p_bary0, v0.z, v1.z, v2.z);
const float3 p_bary_clip =
!clip_barycentric_coords ? p_bary : BarycentricClipForward(p_bary);
const float pz =
p_bary_clip.x * v0.z + p_bary_clip.y * v1.z + p_bary_clip.z * v2.z;
const float pz = p_bary.x * v0.z + p_bary.y * v1.z + p_bary.z * v2.z;
if (pz < 0) {
return; // Face is behind the image plane.
}
@@ -158,7 +163,8 @@ __device__ void CheckPixelInsideFace(
// Get abs squared distance
const float dist = PointTriangleDistanceForward(pxy, v0xy, v1xy, v2xy);
// Use the bary coordinates to determine if the point is inside the face.
// Use the unclipped bary coordinates to determine if the point is inside the
// face.
const bool inside = p_bary.x > 0.0f && p_bary.y > 0.0f && p_bary.z > 0.0f;
const float signed_dist = inside ? -dist : dist;
@@ -169,7 +175,7 @@ __device__ void CheckPixelInsideFace(
if (q_size < K) {
// Just insert it.
q[q_size] = {pz, face_idx, signed_dist, p_bary};
q[q_size] = {pz, face_idx, signed_dist, p_bary_clip};
if (pz > q_max_z) {
q_max_z = pz;
q_max_idx = q_size;
@@ -177,7 +183,7 @@ __device__ void CheckPixelInsideFace(
q_size++;
} else if (pz < q_max_z) {
// Overwrite the old max, and find the new max.
q[q_max_idx] = {pz, face_idx, signed_dist, p_bary};
q[q_max_idx] = {pz, face_idx, signed_dist, p_bary_clip};
q_max_z = pz;
for (int i = 0; i < K; i++) {
if (q[i].z > q_max_z) {
@@ -198,6 +204,7 @@ __global__ void RasterizeMeshesNaiveCudaKernel(
const int64_t* num_faces_per_mesh,
const float blur_radius,
const bool perspective_correct,
const bool clip_barycentric_coords,
const bool cull_backfaces,
const int N,
const int H,
@@ -260,6 +267,7 @@ __global__ void RasterizeMeshesNaiveCudaKernel(
pxy,
K,
perspective_correct,
clip_barycentric_coords,
cull_backfaces);
}
@@ -286,6 +294,7 @@ RasterizeMeshesNaiveCuda(
const float blur_radius,
const int num_closest,
const bool perspective_correct,
const bool clip_barycentric_coords,
const bool cull_backfaces) {
TORCH_CHECK(
face_verts.ndimension() == 3 && face_verts.size(1) == 3 &&
@@ -343,6 +352,7 @@ RasterizeMeshesNaiveCuda(
num_faces_per_mesh.contiguous().data_ptr<int64_t>(),
blur_radius,
perspective_correct,
clip_barycentric_coords,
cull_backfaces,
N,
H,
@@ -365,6 +375,7 @@ __global__ void RasterizeMeshesBackwardCudaKernel(
const float* face_verts, // (F, 3, 3)
const int64_t* pix_to_face, // (N, H, W, K)
const bool perspective_correct,
const bool clip_barycentric_coords,
const int N,
const int H,
const int W,
@@ -422,11 +433,15 @@ __global__ void RasterizeMeshesBackwardCudaKernel(
const float3 grad_bary_upstream = make_float3(
grad_bary_upstream_w0, grad_bary_upstream_w1, grad_bary_upstream_w2);
const float3 bary0 = BarycentricCoordsForward(pxy, v0xy, v1xy, v2xy);
const float3 bary = !perspective_correct
? bary0
: BarycentricPerspectiveCorrectionForward(bary0, v0.z, v1.z, v2.z);
const bool inside = bary.x > 0.0f && bary.y > 0.0f && bary.z > 0.0f;
const float3 b_w = BarycentricCoordsForward(pxy, v0xy, v1xy, v2xy);
const float3 b_pp = !perspective_correct
? b_w
: BarycentricPerspectiveCorrectionForward(b_w, v0.z, v1.z, v2.z);
const float3 b_w_clip =
!clip_barycentric_coords ? b_pp : BarycentricClipForward(b_pp);
const bool inside = b_pp.x > 0.0f && b_pp.y > 0.0f && b_pp.z > 0.0f;
const float sign = inside ? -1.0f : 1.0f;
// TODO(T52813608) Add support for non-square images.
@@ -442,22 +457,29 @@ __global__ void RasterizeMeshesBackwardCudaKernel(
// d_zbuf/d_bary_w0 = z0
// d_zbuf/d_bary_w1 = z1
// d_zbuf/d_bary_w2 = z2
const float3 d_zbuf_d_bary = make_float3(v0.z, v1.z, v2.z);
const float3 d_zbuf_d_bwclip = make_float3(v0.z, v1.z, v2.z);
// Total upstream barycentric gradients are the sum of
// external upstream gradients and contribution from zbuf.
const float3 grad_bary_f_sum =
(grad_bary_upstream + grad_zbuf_upstream * d_zbuf_d_bary);
(grad_bary_upstream + grad_zbuf_upstream * d_zbuf_d_bwclip);
float3 grad_bary0 = grad_bary_f_sum;
if (clip_barycentric_coords) {
grad_bary0 = BarycentricClipBackward(b_w, grad_bary_f_sum);
}
float dz0_persp = 0.0f, dz1_persp = 0.0f, dz2_persp = 0.0f;
if (perspective_correct) {
auto perspective_grads = BarycentricPerspectiveCorrectionBackward(
bary0, v0.z, v1.z, v2.z, grad_bary_f_sum);
b_w, v0.z, v1.z, v2.z, grad_bary0);
grad_bary0 = thrust::get<0>(perspective_grads);
dz0_persp = thrust::get<1>(perspective_grads);
dz1_persp = thrust::get<2>(perspective_grads);
dz2_persp = thrust::get<3>(perspective_grads);
}
auto grad_bary_f =
BarycentricCoordsBackward(pxy, v0xy, v1xy, v2xy, grad_bary0);
const float2 dbary_d_v0 = thrust::get<1>(grad_bary_f);
@@ -467,15 +489,18 @@ __global__ void RasterizeMeshesBackwardCudaKernel(
atomicAdd(grad_face_verts + f * 9 + 0, dbary_d_v0.x + ddist_d_v0.x);
atomicAdd(grad_face_verts + f * 9 + 1, dbary_d_v0.y + ddist_d_v0.y);
atomicAdd(
grad_face_verts + f * 9 + 2, grad_zbuf_upstream * bary.x + dz0_persp);
grad_face_verts + f * 9 + 2,
grad_zbuf_upstream * b_w_clip.x + dz0_persp);
atomicAdd(grad_face_verts + f * 9 + 3, dbary_d_v1.x + ddist_d_v1.x);
atomicAdd(grad_face_verts + f * 9 + 4, dbary_d_v1.y + ddist_d_v1.y);
atomicAdd(
grad_face_verts + f * 9 + 5, grad_zbuf_upstream * bary.y + dz1_persp);
grad_face_verts + f * 9 + 5,
grad_zbuf_upstream * b_w_clip.y + dz1_persp);
atomicAdd(grad_face_verts + f * 9 + 6, dbary_d_v2.x + ddist_d_v2.x);
atomicAdd(grad_face_verts + f * 9 + 7, dbary_d_v2.y + ddist_d_v2.y);
atomicAdd(
grad_face_verts + f * 9 + 8, grad_zbuf_upstream * bary.z + dz2_persp);
grad_face_verts + f * 9 + 8,
grad_zbuf_upstream * b_w_clip.z + dz2_persp);
}
}
}
@@ -486,7 +511,8 @@ at::Tensor RasterizeMeshesBackwardCuda(
const at::Tensor& grad_zbuf, // (N, H, W, K)
const at::Tensor& grad_bary, // (N, H, W, K, 3)
const at::Tensor& grad_dists, // (N, H, W, K)
const bool perspective_correct) {
const bool perspective_correct,
const bool clip_barycentric_coords) {
// Check inputs are on the same device
at::TensorArg face_verts_t{face_verts, "face_verts", 1},
pix_to_face_t{pix_to_face, "pix_to_face", 2},
@@ -523,6 +549,7 @@ at::Tensor RasterizeMeshesBackwardCuda(
face_verts.contiguous().data_ptr<float>(),
pix_to_face.contiguous().data_ptr<int64_t>(),
perspective_correct,
clip_barycentric_coords,
N,
H,
W,
@@ -743,6 +770,7 @@ __global__ void RasterizeMeshesFineCudaKernel(
const float blur_radius,
const int bin_size,
const bool perspective_correct,
const bool clip_barycentric_coords,
const bool cull_backfaces,
const int N,
const int B,
@@ -808,6 +836,7 @@ __global__ void RasterizeMeshesFineCudaKernel(
pxy,
K,
perspective_correct,
clip_barycentric_coords,
cull_backfaces);
}
@@ -841,6 +870,7 @@ RasterizeMeshesFineCuda(
const int bin_size,
const int faces_per_pixel,
const bool perspective_correct,
const bool clip_barycentric_coords,
const bool cull_backfaces) {
TORCH_CHECK(
face_verts.ndimension() == 3 && face_verts.size(1) == 3 &&
@@ -890,6 +920,7 @@ RasterizeMeshesFineCuda(
blur_radius,
bin_size,
perspective_correct,
clip_barycentric_coords,
cull_backfaces,
N,
B,

View File

@@ -19,6 +19,7 @@ RasterizeMeshesNaiveCpu(
const float blur_radius,
const int faces_per_pixel,
const bool perspective_correct,
const bool clip_barycentric_coords,
const bool cull_backfaces);
#ifdef WITH_CUDA
@@ -31,6 +32,7 @@ RasterizeMeshesNaiveCuda(
const float blur_radius,
const int num_closest,
const bool perspective_correct,
const bool clip_barycentric_coords,
const bool cull_backfaces);
#endif
// Forward pass for rasterizing a batch of meshes.
@@ -92,6 +94,7 @@ RasterizeMeshesNaive(
const float blur_radius,
const int faces_per_pixel,
const bool perspective_correct,
const bool clip_barycentric_coords,
const bool cull_backfaces) {
// TODO: Better type checking.
if (face_verts.is_cuda()) {
@@ -107,6 +110,7 @@ RasterizeMeshesNaive(
blur_radius,
faces_per_pixel,
perspective_correct,
clip_barycentric_coords,
cull_backfaces);
#else
AT_ERROR("Not compiled with GPU support");
@@ -120,6 +124,7 @@ RasterizeMeshesNaive(
blur_radius,
faces_per_pixel,
perspective_correct,
clip_barycentric_coords,
cull_backfaces);
}
}
@@ -134,7 +139,8 @@ torch::Tensor RasterizeMeshesBackwardCpu(
const torch::Tensor& grad_bary,
const torch::Tensor& grad_zbuf,
const torch::Tensor& grad_dists,
const bool perspective_correct);
const bool perspective_correct,
const bool clip_barycentric_coords);
#ifdef WITH_CUDA
torch::Tensor RasterizeMeshesBackwardCuda(
@@ -143,7 +149,8 @@ torch::Tensor RasterizeMeshesBackwardCuda(
const torch::Tensor& grad_bary,
const torch::Tensor& grad_zbuf,
const torch::Tensor& grad_dists,
const bool perspective_correct);
const bool perspective_correct,
const bool clip_barycentric_coords);
#endif
// Args:
@@ -176,7 +183,8 @@ torch::Tensor RasterizeMeshesBackward(
const torch::Tensor& grad_zbuf,
const torch::Tensor& grad_bary,
const torch::Tensor& grad_dists,
const bool perspective_correct) {
const bool perspective_correct,
const bool clip_barycentric_coords) {
if (face_verts.is_cuda()) {
#ifdef WITH_CUDA
CHECK_CUDA(face_verts);
@@ -190,7 +198,8 @@ torch::Tensor RasterizeMeshesBackward(
grad_zbuf,
grad_bary,
grad_dists,
perspective_correct);
perspective_correct,
clip_barycentric_coords);
#else
AT_ERROR("Not compiled with GPU support");
#endif
@@ -201,7 +210,8 @@ torch::Tensor RasterizeMeshesBackward(
grad_zbuf,
grad_bary,
grad_dists,
perspective_correct);
perspective_correct,
clip_barycentric_coords);
}
}
@@ -300,6 +310,7 @@ RasterizeMeshesFineCuda(
const int bin_size,
const int faces_per_pixel,
const bool perspective_correct,
const bool clip_barycentric_coords,
const bool cull_backfaces);
#endif
// Args:
@@ -356,6 +367,7 @@ RasterizeMeshesFine(
const int bin_size,
const int faces_per_pixel,
const bool perspective_correct,
const bool clip_barycentric_coords,
const bool cull_backfaces) {
if (face_verts.is_cuda()) {
#ifdef WITH_CUDA
@@ -369,6 +381,7 @@ RasterizeMeshesFine(
bin_size,
faces_per_pixel,
perspective_correct,
clip_barycentric_coords,
cull_backfaces);
#else
AT_ERROR("Not compiled with GPU support");
@@ -446,6 +459,7 @@ RasterizeMeshes(
const int bin_size,
const int max_faces_per_bin,
const bool perspective_correct,
const bool clip_barycentric_coords,
const bool cull_backfaces) {
if (bin_size > 0 && max_faces_per_bin > 0) {
// Use coarse-to-fine rasterization
@@ -465,6 +479,7 @@ RasterizeMeshes(
bin_size,
faces_per_pixel,
perspective_correct,
clip_barycentric_coords,
cull_backfaces);
} else {
// Use the naive per-pixel implementation
@@ -476,6 +491,7 @@ RasterizeMeshes(
blur_radius,
faces_per_pixel,
perspective_correct,
clip_barycentric_coords,
cull_backfaces);
}
}

View File

@@ -108,6 +108,7 @@ RasterizeMeshesNaiveCpu(
const float blur_radius,
const int faces_per_pixel,
const bool perspective_correct,
const bool clip_barycentric_coords,
const bool cull_backfaces) {
if (face_verts.ndimension() != 3 || face_verts.size(1) != 3 ||
face_verts.size(2) != 3) {
@@ -213,8 +214,12 @@ RasterizeMeshesNaiveCpu(
? bary0
: BarycentricPerspectiveCorrectionForward(bary0, z0, z1, z2);
const vec3<float> bary_clip =
!clip_barycentric_coords ? bary : BarycentricClipForward(bary);
// Use barycentric coordinates to get the depth of the current pixel
const float pz = (bary.x * z0 + bary.y * z1 + bary.z * z2);
const float pz =
(bary_clip.x * z0 + bary_clip.y * z1 + bary_clip.z * z2);
if (pz < 0) {
continue; // Point is behind the image plane so ignore.
@@ -236,7 +241,7 @@ RasterizeMeshesNaiveCpu(
continue;
}
// The current pixel lies inside the current face.
q.emplace(pz, f, signed_dist, bary.x, bary.y, bary.z);
q.emplace(pz, f, signed_dist, bary_clip.x, bary_clip.y, bary_clip.z);
if (static_cast<int>(q.size()) > K) {
q.pop();
}
@@ -264,7 +269,8 @@ torch::Tensor RasterizeMeshesBackwardCpu(
const torch::Tensor& grad_zbuf, // (N, H, W, K)
const torch::Tensor& grad_bary, // (N, H, W, K, 3)
const torch::Tensor& grad_dists, // (N, H, W, K)
const bool perspective_correct) {
const bool perspective_correct,
const bool clip_barycentric_coords) {
const int F = face_verts.size(0);
const int N = pix_to_face.size(0);
const int H = pix_to_face.size(1);
@@ -335,6 +341,8 @@ torch::Tensor RasterizeMeshesBackwardCpu(
const vec3<float> bary = !perspective_correct
? bary0
: BarycentricPerspectiveCorrectionForward(bary0, z0, z1, z2);
const vec3<float> bary_clip =
!clip_barycentric_coords ? bary : BarycentricClipForward(bary);
// Distances inside the face are negative so get the
// correct sign to apply to the upstream gradient.
@@ -354,22 +362,28 @@ torch::Tensor RasterizeMeshesBackwardCpu(
// d_zbuf/d_bary_w0 = z0
// d_zbuf/d_bary_w1 = z1
// d_zbuf/d_bary_w2 = z2
const vec3<float> d_zbuf_d_bary(z0, z1, z2);
const vec3<float> d_zbuf_d_baryclip(z0, z1, z2);
// Total upstream barycentric gradients are the sum of
// external upstream gradients and contribution from zbuf.
vec3<float> grad_bary_f_sum =
(grad_bary_upstream + grad_zbuf_upstream * d_zbuf_d_bary);
const vec3<float> grad_bary_f_sum =
(grad_bary_upstream + grad_zbuf_upstream * d_zbuf_d_baryclip);
vec3<float> grad_bary0 = grad_bary_f_sum;
if (clip_barycentric_coords) {
grad_bary0 = BarycentricClipBackward(bary, grad_bary0);
}
if (perspective_correct) {
auto perspective_grads = BarycentricPerspectiveCorrectionBackward(
bary0, z0, z1, z2, grad_bary_f_sum);
bary0, z0, z1, z2, grad_bary0);
grad_bary0 = std::get<0>(perspective_grads);
grad_face_verts[f][0][2] += std::get<1>(perspective_grads);
grad_face_verts[f][1][2] += std::get<2>(perspective_grads);
grad_face_verts[f][2][2] += std::get<3>(perspective_grads);
}
auto grad_bary_f =
BarycentricCoordsBackward(pxy, v0xy, v1xy, v2xy, grad_bary0);
const vec2<float> dbary_d_v0 = std::get<1>(grad_bary_f);
@@ -379,13 +393,13 @@ torch::Tensor RasterizeMeshesBackwardCpu(
// Update output gradient buffer.
grad_face_verts[f][0][0] += dbary_d_v0.x + ddist_d_v0.x;
grad_face_verts[f][0][1] += dbary_d_v0.y + ddist_d_v0.y;
grad_face_verts[f][0][2] += grad_zbuf_upstream * bary.x;
grad_face_verts[f][0][2] += grad_zbuf_upstream * bary_clip.x;
grad_face_verts[f][1][0] += dbary_d_v1.x + ddist_d_v1.x;
grad_face_verts[f][1][1] += dbary_d_v1.y + ddist_d_v1.y;
grad_face_verts[f][1][2] += grad_zbuf_upstream * bary.y;
grad_face_verts[f][1][2] += grad_zbuf_upstream * bary_clip.y;
grad_face_verts[f][2][0] += dbary_d_v2.x + ddist_d_v2.x;
grad_face_verts[f][2][1] += dbary_d_v2.y + ddist_d_v2.y;
grad_face_verts[f][2][2] += grad_zbuf_upstream * bary.z;
grad_face_verts[f][2][2] += grad_zbuf_upstream * bary_clip.z;
}
}
}

View File

@@ -221,8 +221,108 @@ BarycentricPerspectiveCorrectionBackward(
return thrust::make_tuple(grad_bary, grad_z0, grad_z1, grad_z2);
}
// Calculate minimum squared distance between a line segment (v1 - v0) and a
// point p.
// Clip negative barycentric coordinates to 0.0 and renormalize so
// the barycentric coordinates for a point sum to 1. When the blur_radius
// is greater than 0, a face will still be recorded as overlapping a pixel
// if the pixel is outside the face. In this case at least one of the
// barycentric coordinates for the pixel relative to the face will be negative.
// Clipping will ensure that the texture and z buffer are interpolated
// correctly.
//
// Args
//     bary: (w0, w1, w2) barycentric coordinates which can be outside the
//            range [0, 1].
//
// Returns
//     bary: (w0, w1, w2) barycentric coordinates in the range [0, 1] which
//           satisfy the condition: sum(w0, w1, w2) = 1.0.
//           Exception: if the clamped coordinates sum to less than 1e-5
//           (e.g. all inputs are <= 0), the normalizer is clamped to 1e-5
//           and the result does not sum to 1.
//
__device__ inline float3 BarycentricClipForward(const float3 bary) {
  // Lower bound on the normalizer; guards the divisions below against a
  // zero (or near-zero) sum. Kept as a float literal so that no arithmetic
  // in this function is promoted to double precision.
  const float kMinSum = 1e-5f;

  // Clamp the lower bound only. Values > 1 need no clamping because the
  // renormalization below rescales them.
  float3 w = make_float3(
      fmaxf(bary.x, 0.0f), fmaxf(bary.y, 0.0f), fmaxf(bary.z, 0.0f));

  const float w_sum = fmaxf(w.x + w.y + w.z, kMinSum);
  w.x /= w_sum;
  w.y /= w_sum;
  w.z /= w_sum;
  return w;
}
// Backward pass for barycentric coordinate clipping.
//
// Recomputes the forward clamp/renormalize step and applies the chain rule:
// with w_i = max(bary_i, 0), s = max(w0 + w1 + w2, 1e-5) and out_i = w_i / s,
//   d out_i / d bary_j = clip_j * (delta_ij / s - w_i / s^2 * sum_clip)
// where clip_j is 0 if bary_j was clamped (bary_j < 0) and 1 otherwise, and
// sum_clip is 0 if the sum was clamped to 1e-5 and 1 otherwise.
//
// Args
//     bary: (w0, w1, w2) barycentric coordinates which can be outside the
//            range [0, 1].
//     grad_baryclip_upstream: float3 upstream gradient for each of the clipped
//                         barycentric coordinates [grad_w0, grad_w1, grad_w2].
//
// Returns
//    float3 of gradients for the unclipped barycentric coordinates:
//    (grad_w0, grad_w1, grad_w2)
//
__device__ inline float3 BarycentricClipBackward(
    const float3 bary,
    const float3 grad_baryclip_upstream) {
  // Must be the same float constant as in BarycentricClipForward so that
  // both passes agree on whether the sum was clamped.
  const float kMinSum = 1e-5f;

  // Redo some of the forward pass calculations (clamp lower bound only).
  const float3 w = make_float3(
      fmaxf(bary.x, 0.0f), fmaxf(bary.y, 0.0f), fmaxf(bary.z, 0.0f));
  float w_sum = w.x + w.y + w.z;

  // Check if sum was clipped. If so, the normalizer is a constant and
  // contributes no gradient.
  float grad_sum_clip = 1.0f;
  if (w_sum < kMinSum) {
    grad_sum_clip = 0.0f;
    w_sum = kMinSum;
  }

  // Check if any of bary values have been clipped; a clipped coordinate
  // receives zero gradient.
  const float3 grad_clip = make_float3(
      bary.x < 0.0f ? 0.0f : 1.0f,
      bary.y < 0.0f ? 0.0f : 1.0f,
      bary.z < 0.0f ? 0.0f : 1.0f);

  // Gradients of the sum: d(w_i / s) / ds * ds/dw = -w_i / s^2.
  const float w_sum_sq = w_sum * w_sum;
  const float3 grad_sum = make_float3(
      -w.x / w_sum_sq * grad_sum_clip,
      -w.y / w_sum_sq * grad_sum_clip,
      -w.z / w_sum_sq * grad_sum_clip);

  // Gradients for each of the bary coordinates including the cross terms
  // from the sum.
  float3 grad_bary;
  grad_bary.x = grad_clip.x *
      (grad_baryclip_upstream.x * (1.0f / w_sum + grad_sum.x) +
       grad_baryclip_upstream.y * (grad_sum.y) +
       grad_baryclip_upstream.z * (grad_sum.z));

  grad_bary.y = grad_clip.y *
      (grad_baryclip_upstream.y * (1.0f / w_sum + grad_sum.y) +
       grad_baryclip_upstream.x * (grad_sum.x) +
       grad_baryclip_upstream.z * (grad_sum.z));

  grad_bary.z = grad_clip.z *
      (grad_baryclip_upstream.z * (1.0f / w_sum + grad_sum.z) +
       grad_baryclip_upstream.x * (grad_sum.x) +
       grad_baryclip_upstream.y * (grad_sum.y));

  return grad_bary;
}
// Return minimum distance between line segment (v1 - v0) and point p.
//
// Args:
// p: Coordinates of a point.

View File

@@ -242,8 +242,108 @@ inline std::tuple<vec3<T>, T, T, T> BarycentricPerspectiveCorrectionBackward(
return std::make_tuple(grad_bary, grad_z0, grad_z1, grad_z2);
}
// Calculate minimum squared distance between a line segment (v1 - v0) and a
// point p.
// Clip negative barycentric coordinates to 0.0 and renormalize so
// the barycentric coordinates for a point sum to 1. When the blur_radius
// is greater than 0, a face will still be recorded as overlapping a pixel
// if the pixel is outside the face. In this case at least one of the
// barycentric coordinates for the pixel relative to the face will be negative.
// Clipping will ensure that the texture and z buffer are interpolated
// correctly.
//
// Args
//     bary: (w0, w1, w2) barycentric coordinates which can contain values < 0.
//
// Returns
//     bary: (w0, w1, w2) barycentric coordinates in the range [0, 1] which
//           satisfy the condition: sum(w0, w1, w2) = 1.0.
//           Exception: if the clamped coordinates sum to less than 1e-5
//           (e.g. all inputs are <= 0), the normalizer is clamped to 1e-5
//           and the result does not sum to 1.
//
template <typename T>
vec3<T> BarycentricClipForward(const vec3<T> bary) {
  // Constants are expressed in T so the template is not tied to float.
  const T kZero = static_cast<T>(0);
  // Lower bound on the normalizer; guards the divisions below.
  const T kMinSum = static_cast<T>(1e-5);

  // Only clamp negative values to 0.0.
  // No need to clamp values > 1.0 as they will be renormalized.
  vec3<T> w(
      std::max(bary.x, kZero),
      std::max(bary.y, kZero),
      std::max(bary.z, kZero));

  const T w_sum = std::fmax(w.x + w.y + w.z, kMinSum);
  w.x /= w_sum;
  w.y /= w_sum;
  w.z /= w_sum;
  return w;
}
// Backward pass for barycentric coordinate clipping.
//
// Recomputes the forward clamp/renormalize step and applies the chain rule:
// with w_i = max(bary_i, 0), s = max(w0 + w1 + w2, 1e-5) and out_i = w_i / s,
//   d out_i / d bary_j = clip_j * (delta_ij / s - w_i / s^2 * sum_clip)
// where clip_j is 0 if bary_j was clamped (bary_j < 0) and 1 otherwise, and
// sum_clip is 0 if the sum was clamped to 1e-5 and 1 otherwise.
//
// Args
//     bary: (w0, w1, w2) barycentric coordinates which can contain values < 0.
//     grad_baryclip_upstream: vec3<T> Upstream gradient for each of the clipped
//                         barycentric coordinates [grad_w0, grad_w1, grad_w2].
//
// Returns
//    vec3<T> of gradients for the unclipped barycentric coordinates:
//    (grad_w0, grad_w1, grad_w2)
//
template <typename T>
vec3<T> BarycentricClipBackward(
    const vec3<T> bary,
    const vec3<T> grad_baryclip_upstream) {
  // Constants are expressed in T so the template is not tied to float, and
  // so the clamp threshold is the same value the forward pass clamps at.
  const T kZero = static_cast<T>(0);
  const T kOne = static_cast<T>(1);
  const T kMinSum = static_cast<T>(1e-5);

  // Redo some of the forward pass calculations.
  const vec3<T> w(
      std::max(bary.x, kZero),
      std::max(bary.y, kZero),
      std::max(bary.z, kZero));
  T w_sum = w.x + w.y + w.z;

  // Check if the sum was clipped. If so, the normalizer is a constant and
  // contributes no gradient.
  T grad_sum_clip = kOne;
  if (w_sum < kMinSum) {
    grad_sum_clip = kZero;
    w_sum = kMinSum;
  }

  // Check if any of the bary coordinates have been clipped.
  // Only negative values are clamped to 0.0; those receive zero gradient.
  const vec3<T> grad_clip(
      bary.x < kZero ? kZero : kOne,
      bary.y < kZero ? kZero : kOne,
      bary.z < kZero ? kZero : kOne);

  // Gradients of the sum: d(w_i / s) / ds * ds/dw = -w_i / s^2.
  const T w_sum_sq = w_sum * w_sum;
  const vec3<T> grad_sum(
      -w.x / w_sum_sq * grad_sum_clip,
      -w.y / w_sum_sq * grad_sum_clip,
      -w.z / w_sum_sq * grad_sum_clip);

  // Gradients for each of the bary coordinates including the cross terms
  // from the sum.
  vec3<T> grad_bary(kOne, kOne, kOne);
  grad_bary.x = grad_clip.x *
      (grad_baryclip_upstream.x * (kOne / w_sum + grad_sum.x) +
       grad_baryclip_upstream.y * (grad_sum.y) +
       grad_baryclip_upstream.z * (grad_sum.z));

  grad_bary.y = grad_clip.y *
      (grad_baryclip_upstream.y * (kOne / w_sum + grad_sum.y) +
       grad_baryclip_upstream.x * (grad_sum.x) +
       grad_baryclip_upstream.z * (grad_sum.z));

  grad_bary.z = grad_clip.z *
      (grad_baryclip_upstream.z * (kOne / w_sum + grad_sum.z) +
       grad_baryclip_upstream.x * (grad_sum.x) +
       grad_baryclip_upstream.y * (grad_sum.y));

  return grad_bary;
}
// Calculate minimum distance between a line segment (v1 - v0) and point p.
//
// Args:
// p: Coordinates of a point.