barycentric clipping in cuda/c++

Summary:
Added support for barycentric clipping in the C++/CUDA rasterization kernels which can be switched on/off via a rasterization setting.

Added tests and a benchmark to compare with the current implementation in PyTorch - for some cases of large image size/faces per pixel the cuda version is 10x faster.

Reviewed By: gkioxari

Differential Revision: D21705503

fbshipit-source-id: e835c0f927f1e5088ca89020aef5ff27ac3a8769
This commit is contained in:
Nikhila Ravi 2020-07-16 10:15:30 -07:00 committed by Facebook GitHub Bot
parent bce396df93
commit cc70950f40
13 changed files with 611 additions and 55 deletions

View File

@ -114,6 +114,7 @@ __device__ void CheckPixelInsideFace(
const float2 pxy, // Coordinates of the pixel
const int K,
const bool perspective_correct,
const bool clip_barycentric_coords,
const bool cull_backfaces) {
const auto v012 = GetSingleFaceVerts(face_verts, face_idx);
const float3 v0 = thrust::get<0>(v012);
@ -149,8 +150,12 @@ __device__ void CheckPixelInsideFace(
const float3 p_bary = !perspective_correct
? p_bary0
: BarycentricPerspectiveCorrectionForward(p_bary0, v0.z, v1.z, v2.z);
const float3 p_bary_clip =
!clip_barycentric_coords ? p_bary : BarycentricClipForward(p_bary);
const float pz =
p_bary_clip.x * v0.z + p_bary_clip.y * v1.z + p_bary_clip.z * v2.z;
const float pz = p_bary.x * v0.z + p_bary.y * v1.z + p_bary.z * v2.z;
if (pz < 0) {
return; // Face is behind the image plane.
}
@ -158,7 +163,8 @@ __device__ void CheckPixelInsideFace(
// Get abs squared distance
const float dist = PointTriangleDistanceForward(pxy, v0xy, v1xy, v2xy);
// Use the bary coordinates to determine if the point is inside the face.
// Use the unclipped bary coordinates to determine if the point is inside the
// face.
const bool inside = p_bary.x > 0.0f && p_bary.y > 0.0f && p_bary.z > 0.0f;
const float signed_dist = inside ? -dist : dist;
@ -169,7 +175,7 @@ __device__ void CheckPixelInsideFace(
if (q_size < K) {
// Just insert it.
q[q_size] = {pz, face_idx, signed_dist, p_bary};
q[q_size] = {pz, face_idx, signed_dist, p_bary_clip};
if (pz > q_max_z) {
q_max_z = pz;
q_max_idx = q_size;
@ -177,7 +183,7 @@ __device__ void CheckPixelInsideFace(
q_size++;
} else if (pz < q_max_z) {
// Overwrite the old max, and find the new max.
q[q_max_idx] = {pz, face_idx, signed_dist, p_bary};
q[q_max_idx] = {pz, face_idx, signed_dist, p_bary_clip};
q_max_z = pz;
for (int i = 0; i < K; i++) {
if (q[i].z > q_max_z) {
@ -198,6 +204,7 @@ __global__ void RasterizeMeshesNaiveCudaKernel(
const int64_t* num_faces_per_mesh,
const float blur_radius,
const bool perspective_correct,
const bool clip_barycentric_coords,
const bool cull_backfaces,
const int N,
const int H,
@ -260,6 +267,7 @@ __global__ void RasterizeMeshesNaiveCudaKernel(
pxy,
K,
perspective_correct,
clip_barycentric_coords,
cull_backfaces);
}
@ -286,6 +294,7 @@ RasterizeMeshesNaiveCuda(
const float blur_radius,
const int num_closest,
const bool perspective_correct,
const bool clip_barycentric_coords,
const bool cull_backfaces) {
TORCH_CHECK(
face_verts.ndimension() == 3 && face_verts.size(1) == 3 &&
@ -343,6 +352,7 @@ RasterizeMeshesNaiveCuda(
num_faces_per_mesh.contiguous().data_ptr<int64_t>(),
blur_radius,
perspective_correct,
clip_barycentric_coords,
cull_backfaces,
N,
H,
@ -365,6 +375,7 @@ __global__ void RasterizeMeshesBackwardCudaKernel(
const float* face_verts, // (F, 3, 3)
const int64_t* pix_to_face, // (N, H, W, K)
const bool perspective_correct,
const bool clip_barycentric_coords,
const int N,
const int H,
const int W,
@ -422,11 +433,15 @@ __global__ void RasterizeMeshesBackwardCudaKernel(
const float3 grad_bary_upstream = make_float3(
grad_bary_upstream_w0, grad_bary_upstream_w1, grad_bary_upstream_w2);
const float3 bary0 = BarycentricCoordsForward(pxy, v0xy, v1xy, v2xy);
const float3 bary = !perspective_correct
? bary0
: BarycentricPerspectiveCorrectionForward(bary0, v0.z, v1.z, v2.z);
const bool inside = bary.x > 0.0f && bary.y > 0.0f && bary.z > 0.0f;
const float3 b_w = BarycentricCoordsForward(pxy, v0xy, v1xy, v2xy);
const float3 b_pp = !perspective_correct
? b_w
: BarycentricPerspectiveCorrectionForward(b_w, v0.z, v1.z, v2.z);
const float3 b_w_clip =
!clip_barycentric_coords ? b_pp : BarycentricClipForward(b_pp);
const bool inside = b_pp.x > 0.0f && b_pp.y > 0.0f && b_pp.z > 0.0f;
const float sign = inside ? -1.0f : 1.0f;
// TODO(T52813608) Add support for non-square images.
@ -442,22 +457,29 @@ __global__ void RasterizeMeshesBackwardCudaKernel(
// d_zbuf/d_bary_w0 = z0
// d_zbuf/d_bary_w1 = z1
// d_zbuf/d_bary_w2 = z2
const float3 d_zbuf_d_bary = make_float3(v0.z, v1.z, v2.z);
const float3 d_zbuf_d_bwclip = make_float3(v0.z, v1.z, v2.z);
// Total upstream barycentric gradients are the sum of
// external upstream gradients and contribution from zbuf.
const float3 grad_bary_f_sum =
(grad_bary_upstream + grad_zbuf_upstream * d_zbuf_d_bary);
(grad_bary_upstream + grad_zbuf_upstream * d_zbuf_d_bwclip);
float3 grad_bary0 = grad_bary_f_sum;
if (clip_barycentric_coords) {
grad_bary0 = BarycentricClipBackward(b_w, grad_bary_f_sum);
}
float dz0_persp = 0.0f, dz1_persp = 0.0f, dz2_persp = 0.0f;
if (perspective_correct) {
auto perspective_grads = BarycentricPerspectiveCorrectionBackward(
bary0, v0.z, v1.z, v2.z, grad_bary_f_sum);
b_w, v0.z, v1.z, v2.z, grad_bary0);
grad_bary0 = thrust::get<0>(perspective_grads);
dz0_persp = thrust::get<1>(perspective_grads);
dz1_persp = thrust::get<2>(perspective_grads);
dz2_persp = thrust::get<3>(perspective_grads);
}
auto grad_bary_f =
BarycentricCoordsBackward(pxy, v0xy, v1xy, v2xy, grad_bary0);
const float2 dbary_d_v0 = thrust::get<1>(grad_bary_f);
@ -467,15 +489,18 @@ __global__ void RasterizeMeshesBackwardCudaKernel(
atomicAdd(grad_face_verts + f * 9 + 0, dbary_d_v0.x + ddist_d_v0.x);
atomicAdd(grad_face_verts + f * 9 + 1, dbary_d_v0.y + ddist_d_v0.y);
atomicAdd(
grad_face_verts + f * 9 + 2, grad_zbuf_upstream * bary.x + dz0_persp);
grad_face_verts + f * 9 + 2,
grad_zbuf_upstream * b_w_clip.x + dz0_persp);
atomicAdd(grad_face_verts + f * 9 + 3, dbary_d_v1.x + ddist_d_v1.x);
atomicAdd(grad_face_verts + f * 9 + 4, dbary_d_v1.y + ddist_d_v1.y);
atomicAdd(
grad_face_verts + f * 9 + 5, grad_zbuf_upstream * bary.y + dz1_persp);
grad_face_verts + f * 9 + 5,
grad_zbuf_upstream * b_w_clip.y + dz1_persp);
atomicAdd(grad_face_verts + f * 9 + 6, dbary_d_v2.x + ddist_d_v2.x);
atomicAdd(grad_face_verts + f * 9 + 7, dbary_d_v2.y + ddist_d_v2.y);
atomicAdd(
grad_face_verts + f * 9 + 8, grad_zbuf_upstream * bary.z + dz2_persp);
grad_face_verts + f * 9 + 8,
grad_zbuf_upstream * b_w_clip.z + dz2_persp);
}
}
}
@ -486,7 +511,8 @@ at::Tensor RasterizeMeshesBackwardCuda(
const at::Tensor& grad_zbuf, // (N, H, W, K)
const at::Tensor& grad_bary, // (N, H, W, K, 3)
const at::Tensor& grad_dists, // (N, H, W, K)
const bool perspective_correct) {
const bool perspective_correct,
const bool clip_barycentric_coords) {
// Check inputs are on the same device
at::TensorArg face_verts_t{face_verts, "face_verts", 1},
pix_to_face_t{pix_to_face, "pix_to_face", 2},
@ -523,6 +549,7 @@ at::Tensor RasterizeMeshesBackwardCuda(
face_verts.contiguous().data_ptr<float>(),
pix_to_face.contiguous().data_ptr<int64_t>(),
perspective_correct,
clip_barycentric_coords,
N,
H,
W,
@ -743,6 +770,7 @@ __global__ void RasterizeMeshesFineCudaKernel(
const float blur_radius,
const int bin_size,
const bool perspective_correct,
const bool clip_barycentric_coords,
const bool cull_backfaces,
const int N,
const int B,
@ -808,6 +836,7 @@ __global__ void RasterizeMeshesFineCudaKernel(
pxy,
K,
perspective_correct,
clip_barycentric_coords,
cull_backfaces);
}
@ -841,6 +870,7 @@ RasterizeMeshesFineCuda(
const int bin_size,
const int faces_per_pixel,
const bool perspective_correct,
const bool clip_barycentric_coords,
const bool cull_backfaces) {
TORCH_CHECK(
face_verts.ndimension() == 3 && face_verts.size(1) == 3 &&
@ -890,6 +920,7 @@ RasterizeMeshesFineCuda(
blur_radius,
bin_size,
perspective_correct,
clip_barycentric_coords,
cull_backfaces,
N,
B,

View File

@ -19,6 +19,7 @@ RasterizeMeshesNaiveCpu(
const float blur_radius,
const int faces_per_pixel,
const bool perspective_correct,
const bool clip_barycentric_coords,
const bool cull_backfaces);
#ifdef WITH_CUDA
@ -31,6 +32,7 @@ RasterizeMeshesNaiveCuda(
const float blur_radius,
const int num_closest,
const bool perspective_correct,
const bool clip_barycentric_coords,
const bool cull_backfaces);
#endif
// Forward pass for rasterizing a batch of meshes.
@ -92,6 +94,7 @@ RasterizeMeshesNaive(
const float blur_radius,
const int faces_per_pixel,
const bool perspective_correct,
const bool clip_barycentric_coords,
const bool cull_backfaces) {
// TODO: Better type checking.
if (face_verts.is_cuda()) {
@ -107,6 +110,7 @@ RasterizeMeshesNaive(
blur_radius,
faces_per_pixel,
perspective_correct,
clip_barycentric_coords,
cull_backfaces);
#else
AT_ERROR("Not compiled with GPU support");
@ -120,6 +124,7 @@ RasterizeMeshesNaive(
blur_radius,
faces_per_pixel,
perspective_correct,
clip_barycentric_coords,
cull_backfaces);
}
}
@ -134,7 +139,8 @@ torch::Tensor RasterizeMeshesBackwardCpu(
const torch::Tensor& grad_bary,
const torch::Tensor& grad_zbuf,
const torch::Tensor& grad_dists,
const bool perspective_correct);
const bool perspective_correct,
const bool clip_barycentric_coords);
#ifdef WITH_CUDA
torch::Tensor RasterizeMeshesBackwardCuda(
@ -143,7 +149,8 @@ torch::Tensor RasterizeMeshesBackwardCuda(
const torch::Tensor& grad_bary,
const torch::Tensor& grad_zbuf,
const torch::Tensor& grad_dists,
const bool perspective_correct);
const bool perspective_correct,
const bool clip_barycentric_coords);
#endif
// Args:
@ -176,7 +183,8 @@ torch::Tensor RasterizeMeshesBackward(
const torch::Tensor& grad_zbuf,
const torch::Tensor& grad_bary,
const torch::Tensor& grad_dists,
const bool perspective_correct) {
const bool perspective_correct,
const bool clip_barycentric_coords) {
if (face_verts.is_cuda()) {
#ifdef WITH_CUDA
CHECK_CUDA(face_verts);
@ -190,7 +198,8 @@ torch::Tensor RasterizeMeshesBackward(
grad_zbuf,
grad_bary,
grad_dists,
perspective_correct);
perspective_correct,
clip_barycentric_coords);
#else
AT_ERROR("Not compiled with GPU support");
#endif
@ -201,7 +210,8 @@ torch::Tensor RasterizeMeshesBackward(
grad_zbuf,
grad_bary,
grad_dists,
perspective_correct);
perspective_correct,
clip_barycentric_coords);
}
}
@ -300,6 +310,7 @@ RasterizeMeshesFineCuda(
const int bin_size,
const int faces_per_pixel,
const bool perspective_correct,
const bool clip_barycentric_coords,
const bool cull_backfaces);
#endif
// Args:
@ -356,6 +367,7 @@ RasterizeMeshesFine(
const int bin_size,
const int faces_per_pixel,
const bool perspective_correct,
const bool clip_barycentric_coords,
const bool cull_backfaces) {
if (face_verts.is_cuda()) {
#ifdef WITH_CUDA
@ -369,6 +381,7 @@ RasterizeMeshesFine(
bin_size,
faces_per_pixel,
perspective_correct,
clip_barycentric_coords,
cull_backfaces);
#else
AT_ERROR("Not compiled with GPU support");
@ -446,6 +459,7 @@ RasterizeMeshes(
const int bin_size,
const int max_faces_per_bin,
const bool perspective_correct,
const bool clip_barycentric_coords,
const bool cull_backfaces) {
if (bin_size > 0 && max_faces_per_bin > 0) {
// Use coarse-to-fine rasterization
@ -465,6 +479,7 @@ RasterizeMeshes(
bin_size,
faces_per_pixel,
perspective_correct,
clip_barycentric_coords,
cull_backfaces);
} else {
// Use the naive per-pixel implementation
@ -476,6 +491,7 @@ RasterizeMeshes(
blur_radius,
faces_per_pixel,
perspective_correct,
clip_barycentric_coords,
cull_backfaces);
}
}

View File

@ -108,6 +108,7 @@ RasterizeMeshesNaiveCpu(
const float blur_radius,
const int faces_per_pixel,
const bool perspective_correct,
const bool clip_barycentric_coords,
const bool cull_backfaces) {
if (face_verts.ndimension() != 3 || face_verts.size(1) != 3 ||
face_verts.size(2) != 3) {
@ -213,8 +214,12 @@ RasterizeMeshesNaiveCpu(
? bary0
: BarycentricPerspectiveCorrectionForward(bary0, z0, z1, z2);
const vec3<float> bary_clip =
!clip_barycentric_coords ? bary : BarycentricClipForward(bary);
// Use barycentric coordinates to get the depth of the current pixel
const float pz = (bary.x * z0 + bary.y * z1 + bary.z * z2);
const float pz =
(bary_clip.x * z0 + bary_clip.y * z1 + bary_clip.z * z2);
if (pz < 0) {
continue; // Point is behind the image plane so ignore.
@ -236,7 +241,7 @@ RasterizeMeshesNaiveCpu(
continue;
}
// The current pixel lies inside the current face.
q.emplace(pz, f, signed_dist, bary.x, bary.y, bary.z);
q.emplace(pz, f, signed_dist, bary_clip.x, bary_clip.y, bary_clip.z);
if (static_cast<int>(q.size()) > K) {
q.pop();
}
@ -264,7 +269,8 @@ torch::Tensor RasterizeMeshesBackwardCpu(
const torch::Tensor& grad_zbuf, // (N, H, W, K)
const torch::Tensor& grad_bary, // (N, H, W, K, 3)
const torch::Tensor& grad_dists, // (N, H, W, K)
const bool perspective_correct) {
const bool perspective_correct,
const bool clip_barycentric_coords) {
const int F = face_verts.size(0);
const int N = pix_to_face.size(0);
const int H = pix_to_face.size(1);
@ -335,6 +341,8 @@ torch::Tensor RasterizeMeshesBackwardCpu(
const vec3<float> bary = !perspective_correct
? bary0
: BarycentricPerspectiveCorrectionForward(bary0, z0, z1, z2);
const vec3<float> bary_clip =
!clip_barycentric_coords ? bary : BarycentricClipForward(bary);
// Distances inside the face are negative so get the
// correct sign to apply to the upstream gradient.
@ -354,22 +362,28 @@ torch::Tensor RasterizeMeshesBackwardCpu(
// d_zbuf/d_bary_w0 = z0
// d_zbuf/d_bary_w1 = z1
// d_zbuf/d_bary_w2 = z2
const vec3<float> d_zbuf_d_bary(z0, z1, z2);
const vec3<float> d_zbuf_d_baryclip(z0, z1, z2);
// Total upstream barycentric gradients are the sum of
// external upstream gradients and contribution from zbuf.
vec3<float> grad_bary_f_sum =
(grad_bary_upstream + grad_zbuf_upstream * d_zbuf_d_bary);
const vec3<float> grad_bary_f_sum =
(grad_bary_upstream + grad_zbuf_upstream * d_zbuf_d_baryclip);
vec3<float> grad_bary0 = grad_bary_f_sum;
if (clip_barycentric_coords) {
grad_bary0 = BarycentricClipBackward(bary, grad_bary0);
}
if (perspective_correct) {
auto perspective_grads = BarycentricPerspectiveCorrectionBackward(
bary0, z0, z1, z2, grad_bary_f_sum);
bary0, z0, z1, z2, grad_bary0);
grad_bary0 = std::get<0>(perspective_grads);
grad_face_verts[f][0][2] += std::get<1>(perspective_grads);
grad_face_verts[f][1][2] += std::get<2>(perspective_grads);
grad_face_verts[f][2][2] += std::get<3>(perspective_grads);
}
auto grad_bary_f =
BarycentricCoordsBackward(pxy, v0xy, v1xy, v2xy, grad_bary0);
const vec2<float> dbary_d_v0 = std::get<1>(grad_bary_f);
@ -379,13 +393,13 @@ torch::Tensor RasterizeMeshesBackwardCpu(
// Update output gradient buffer.
grad_face_verts[f][0][0] += dbary_d_v0.x + ddist_d_v0.x;
grad_face_verts[f][0][1] += dbary_d_v0.y + ddist_d_v0.y;
grad_face_verts[f][0][2] += grad_zbuf_upstream * bary.x;
grad_face_verts[f][0][2] += grad_zbuf_upstream * bary_clip.x;
grad_face_verts[f][1][0] += dbary_d_v1.x + ddist_d_v1.x;
grad_face_verts[f][1][1] += dbary_d_v1.y + ddist_d_v1.y;
grad_face_verts[f][1][2] += grad_zbuf_upstream * bary.y;
grad_face_verts[f][1][2] += grad_zbuf_upstream * bary_clip.y;
grad_face_verts[f][2][0] += dbary_d_v2.x + ddist_d_v2.x;
grad_face_verts[f][2][1] += dbary_d_v2.y + ddist_d_v2.y;
grad_face_verts[f][2][2] += grad_zbuf_upstream * bary.z;
grad_face_verts[f][2][2] += grad_zbuf_upstream * bary_clip.z;
}
}
}

View File

@ -221,8 +221,108 @@ BarycentricPerspectiveCorrectionBackward(
return thrust::make_tuple(grad_bary, grad_z0, grad_z1, grad_z2);
}
// Calculate minimum squared distance between a line segment (v1 - v0) and a
// point p.
// Clip negative barycentric coordinates to 0.0 and renormalize so
// the barycentric coordinates for a point sum to 1. When the blur_radius
// is greater than 0, a face will still be recorded as overlapping a pixel
// if the pixel is outisde the face. In this case at least one of the
// barycentric coordinates for the pixel relative to the face will be negative.
// Clipping will ensure that the texture and z buffer are interpolated
// correctly.
//
// Args
// bary: (w0, w1, w2) barycentric coordinates which can be outside the
// range [0, 1].
//
// Returns
// bary: (w0, w1, w2) barycentric coordinates in the range [0, 1] which
// satisfy the condition: sum(w0, w1, w2) = 1.0.
//
__device__ inline float3 BarycentricClipForward(const float3 bary) {
float3 w = make_float3(0.0f, 0.0f, 0.0f);
// Clamp lower bound only
w.x = max(bary.x, 0.0);
w.y = max(bary.y, 0.0);
w.z = max(bary.z, 0.0);
float w_sum = w.x + w.y + w.z;
w_sum = fmaxf(w_sum, 1e-5);
w.x /= w_sum;
w.y /= w_sum;
w.z /= w_sum;
return w;
}
// Backward pass for barycentric coordinate clipping.
//
// Args
// bary: (w0, w1, w2) barycentric coordinates which can be outside the
// range [0, 1].
// grad_baryclip_upstream: vec3<T> Upstream gradient for each of the clipped
// barycentric coordinates [grad_w0, grad_w1, grad_w2].
//
// Returns
// vec3<T> of gradients for the unclipped barycentric coordinates:
// (grad_w0, grad_w1, grad_w2)
//
__device__ inline float3 BarycentricClipBackward(
const float3 bary,
const float3 grad_baryclip_upstream) {
// Redo some of the forward pass calculations
float3 w = make_float3(0.0f, 0.0f, 0.0f);
// Clamp lower bound only
w.x = max(bary.x, 0.0);
w.y = max(bary.y, 0.0);
w.z = max(bary.z, 0.0);
float w_sum = w.x + w.y + w.z;
float3 grad_bary = make_float3(1.0f, 1.0f, 1.0f);
float3 grad_clip = make_float3(1.0f, 1.0f, 1.0f);
float3 grad_sum = make_float3(1.0f, 1.0f, 1.0f);
// Check if sum was clipped.
float grad_sum_clip = 1.0f;
if (w_sum < 1e-5) {
grad_sum_clip = 0.0f;
w_sum = 1e-5;
}
// Check if any of bary values have been clipped.
if (bary.x < 0.0f) {
grad_clip.x = 0.0f;
}
if (bary.y < 0.0f) {
grad_clip.y = 0.0f;
}
if (bary.z < 0.0f) {
grad_clip.z = 0.0f;
}
// Gradients of the sum.
grad_sum.x = -w.x / (pow(w_sum, 2.0f)) * grad_sum_clip;
grad_sum.y = -w.y / (pow(w_sum, 2.0f)) * grad_sum_clip;
grad_sum.z = -w.z / (pow(w_sum, 2.0f)) * grad_sum_clip;
// Gradients for each of the bary coordinates including the cross terms
// from the sum.
grad_bary.x = grad_clip.x *
(grad_baryclip_upstream.x * (1.0f / w_sum + grad_sum.x) +
grad_baryclip_upstream.y * (grad_sum.y) +
grad_baryclip_upstream.z * (grad_sum.z));
grad_bary.y = grad_clip.y *
(grad_baryclip_upstream.y * (1.0f / w_sum + grad_sum.y) +
grad_baryclip_upstream.x * (grad_sum.x) +
grad_baryclip_upstream.z * (grad_sum.z));
grad_bary.z = grad_clip.z *
(grad_baryclip_upstream.z * (1.0f / w_sum + grad_sum.z) +
grad_baryclip_upstream.x * (grad_sum.x) +
grad_baryclip_upstream.y * (grad_sum.y));
return grad_bary;
}
// Return minimum distance between line segment (v1 - v0) and point p.
//
// Args:
// p: Coordinates of a point.

View File

@ -242,8 +242,108 @@ inline std::tuple<vec3<T>, T, T, T> BarycentricPerspectiveCorrectionBackward(
return std::make_tuple(grad_bary, grad_z0, grad_z1, grad_z2);
}
// Calculate minimum squared distance between a line segment (v1 - v0) and a
// point p.
// Clip negative barycentric coordinates to 0.0 and renormalize so
// the barycentric coordinates for a point sum to 1. When the blur_radius
// is greater than 0, a face will still be recorded as overlapping a pixel
// if the pixel is outisde the face. In this case at least one of the
// barycentric coordinates for the pixel relative to the face will be negative.
// Clipping will ensure that the texture and z buffer are interpolated
// correctly.
//
// Args
// bary: (w0, w1, w2) barycentric coordinates which can contain values < 0.
//
// Returns
// bary: (w0, w1, w2) barycentric coordinates in the range [0, 1] which
// satisfy the condition: sum(w0, w1, w2) = 1.0.
//
template <typename T>
vec3<T> BarycentricClipForward(const vec3<T> bary) {
vec3<T> w(0.0f, 0.0f, 0.0f);
// Only clamp negative values to 0.0.
// No need to clamp values > 1.0 as they will be renormalized.
w.x = std::max(bary.x, 0.0f);
w.y = std::max(bary.y, 0.0f);
w.z = std::max(bary.z, 0.0f);
float w_sum = w.x + w.y + w.z;
w_sum = std::fmaxf(w_sum, 1e-5);
w.x /= w_sum;
w.y /= w_sum;
w.z /= w_sum;
return w;
}
// Backward pass for barycentric coordinate clipping.
//
// Args
// bary: (w0, w1, w2) barycentric coordinates which can contain values < 0.
// grad_baryclip_upstream: vec3<T> Upstream gradient for each of the clipped
// barycentric coordinates [grad_w0, grad_w1, grad_w2].
//
// Returns
// vec3<T> of gradients for the unclipped barycentric coordinates:
// (grad_w0, grad_w1, grad_w2)
//
template <typename T>
vec3<T> BarycentricClipBackward(
const vec3<T> bary,
const vec3<T> grad_baryclip_upstream) {
// Redo some of the forward pass calculations
vec3<T> w(0.0f, 0.0f, 0.0f);
w.x = std::max(bary.x, 0.0f);
w.y = std::max(bary.y, 0.0f);
w.z = std::max(bary.z, 0.0f);
float w_sum = w.x + w.y + w.z;
vec3<T> grad_bary(1.0f, 1.0f, 1.0f);
vec3<T> grad_clip(1.0f, 1.0f, 1.0f);
vec3<T> grad_sum(1.0f, 1.0f, 1.0f);
// Check if the sum was clipped.
float grad_sum_clip = 1.0f;
if (w_sum < 1e-5) {
grad_sum_clip = 0.0f;
w_sum = 1e-5;
}
// Check if any of the bary coordinates have been clipped.
// Only negative values are clamped to 0.0.
if (bary.x < 0.0f) {
grad_clip.x = 0.0f;
}
if (bary.y < 0.0f) {
grad_clip.y = 0.0f;
}
if (bary.z < 0.0f) {
grad_clip.z = 0.0f;
}
// Gradients of the sum.
grad_sum.x = -w.x / (pow(w_sum, 2.0f)) * grad_sum_clip;
grad_sum.y = -w.y / (pow(w_sum, 2.0f)) * grad_sum_clip;
grad_sum.z = -w.z / (pow(w_sum, 2.0f)) * grad_sum_clip;
// Gradients for each of the bary coordinates including the cross terms
// from the sum.
grad_bary.x = grad_clip.x *
(grad_baryclip_upstream.x * (1.0f / w_sum + grad_sum.x) +
grad_baryclip_upstream.y * (grad_sum.y) +
grad_baryclip_upstream.z * (grad_sum.z));
grad_bary.y = grad_clip.y *
(grad_baryclip_upstream.y * (1.0f / w_sum + grad_sum.y) +
grad_baryclip_upstream.x * (grad_sum.x) +
grad_baryclip_upstream.z * (grad_sum.z));
grad_bary.z = grad_clip.z *
(grad_baryclip_upstream.z * (1.0f / w_sum + grad_sum.z) +
grad_baryclip_upstream.x * (grad_sum.x) +
grad_baryclip_upstream.y * (grad_sum.y));
return grad_bary;
}
// Calculate minimum distance between a line segment (v1 - v0) and point p.
//
// Args:
// p: Coordinates of a point.

View File

@ -24,6 +24,7 @@ def rasterize_meshes(
bin_size: Optional[int] = None,
max_faces_per_bin: Optional[int] = None,
perspective_correct: bool = False,
clip_barycentric_coords: bool = False,
cull_backfaces: bool = False,
):
"""
@ -143,6 +144,7 @@ def rasterize_meshes(
bin_size,
max_faces_per_bin,
perspective_correct,
clip_barycentric_coords,
cull_backfaces,
)
@ -183,6 +185,7 @@ class _RasterizeFaceVerts(torch.autograd.Function):
bin_size: int = 0,
max_faces_per_bin: int = 0,
perspective_correct: bool = False,
clip_barycentric_coords: bool = False,
cull_backfaces: bool = False,
):
# pyre-fixme[16]: Module `pytorch3d` has no attribute `_C`.
@ -196,11 +199,13 @@ class _RasterizeFaceVerts(torch.autograd.Function):
bin_size,
max_faces_per_bin,
perspective_correct,
clip_barycentric_coords,
cull_backfaces,
)
ctx.save_for_backward(face_verts, pix_to_face)
ctx.mark_non_differentiable(pix_to_face)
ctx.perspective_correct = perspective_correct
ctx.clip_barycentric_coords = clip_barycentric_coords
return pix_to_face, zbuf, barycentric_coords, dists
@staticmethod
@ -214,6 +219,7 @@ class _RasterizeFaceVerts(torch.autograd.Function):
grad_bin_size = None
grad_max_faces_per_bin = None
grad_perspective_correct = None
grad_clip_barycentric_coords = None
grad_cull_backfaces = None
face_verts, pix_to_face = ctx.saved_tensors
grad_face_verts = _C.rasterize_meshes_backward(
@ -223,6 +229,7 @@ class _RasterizeFaceVerts(torch.autograd.Function):
grad_barycentric_coords,
grad_dists,
ctx.perspective_correct,
ctx.clip_barycentric_coords,
)
grads = (
grad_face_verts,
@ -234,6 +241,7 @@ class _RasterizeFaceVerts(torch.autograd.Function):
grad_bin_size,
grad_max_faces_per_bin,
grad_perspective_correct,
grad_clip_barycentric_coords,
grad_cull_backfaces,
)
return grads
@ -250,6 +258,7 @@ def rasterize_meshes_python(
blur_radius: float = 0.0,
faces_per_pixel: int = 8,
perspective_correct: bool = False,
clip_barycentric_coords: bool = False,
cull_backfaces: bool = False,
):
"""
@ -356,6 +365,14 @@ def rasterize_meshes_python(
top2 = z0 * z1 * l2
bot = top0 + top1 + top2
bary = torch.stack([top0 / bot, top1 / bot, top2 / bot])
# Check if inside before clipping
inside = all(x > 0.0 for x in bary)
# Barycentric clipping
if clip_barycentric_coords:
bary = barycentric_coordinates_clip(bary)
# use clipped barycentric coords to calculate the z value
pz = bary[0] * v0[2] + bary[1] * v1[2] + bary[2] * v2[2]
# Check if point is behind the image.
@ -365,7 +382,6 @@ def rasterize_meshes_python(
# Calculate signed 2D distance from point to face.
# Points inside the triangle have negative distance.
dist = point_triangle_distance(pxy, v0[:2], v1[:2], v2[:2])
inside = all(x > 0.0 for x in bary)
signed_dist = dist * -1.0 if inside else dist
@ -433,6 +449,33 @@ def edge_function(p, v0, v1):
return (p[0] - v0[0]) * (v1[1] - v0[1]) - (p[1] - v0[1]) * (v1[0] - v0[0])
def barycentric_coordinates_clip(bary):
"""
Clip negative barycentric coordinates to 0.0 and renormalize so
the barycentric coordinates for a point sum to 1. When the blur_radius
is greater than 0, a face will still be recorded as overlapping a pixel
if the pixel is outisde the face. In this case at least one of the
barycentric coordinates for the pixel relative to the face will be negative.
Clipping will ensure that the texture and z buffer are interpolated correctly.
Args:
bary: tuple of barycentric coordinates
Returns
bary_clip: (w0, w1, w2) barycentric coordinates with no negative values.
"""
# Only negative values are clamped to 0.0.
w0_clip = torch.clamp(bary[0], min=0.0)
w1_clip = torch.clamp(bary[1], min=0.0)
w2_clip = torch.clamp(bary[2], min=0.0)
bary_sum = torch.clamp(w0_clip + w1_clip + w2_clip, min=1e-5)
w0_clip = w0_clip / bary_sum
w1_clip = w1_clip / bary_sum
w2_clip = w2_clip / bary_sum
return (w0_clip, w1_clip, w2_clip)
def barycentric_coordinates(p, v0, v1, v2):
"""
Compute the barycentric coordinates of a point relative to a triangle.

View File

@ -26,6 +26,7 @@ class RasterizationSettings:
"bin_size",
"max_faces_per_bin",
"perspective_correct",
"clip_barycentric_coords",
"cull_backfaces",
]
@ -37,6 +38,7 @@ class RasterizationSettings:
bin_size: Optional[int] = None,
max_faces_per_bin: Optional[int] = None,
perspective_correct: bool = False,
clip_barycentric_coords: bool = False,
cull_backfaces: bool = False,
):
self.image_size = image_size
@ -45,6 +47,7 @@ class RasterizationSettings:
self.bin_size = bin_size
self.max_faces_per_bin = max_faces_per_bin
self.perspective_correct = perspective_correct
self.clip_barycentric_coords = clip_barycentric_coords
self.cull_backfaces = cull_backfaces
@ -127,6 +130,7 @@ class MeshRasterizer(nn.Module):
bin_size=raster_settings.bin_size,
max_faces_per_bin=raster_settings.max_faces_per_bin,
perspective_correct=raster_settings.perspective_correct,
clip_barycentric_coords=raster_settings.clip_barycentric_coords,
cull_backfaces=raster_settings.cull_backfaces,
)
return Fragments(

View File

@ -49,21 +49,6 @@ class MeshRenderer(nn.Module):
the range for the corresponding face.
"""
fragments = self.rasterizer(meshes_world, **kwargs)
raster_settings = kwargs.get("raster_settings", self.rasterizer.raster_settings)
if raster_settings.blur_radius > 0.0:
# TODO: potentially move barycentric clipping to the rasterizer
# if no downstream functions requires unclipped values.
# This will avoid unnecssary re-interpolation of the z buffer.
clipped_bary_coords = _clip_barycentric_coordinates(fragments.bary_coords)
clipped_zbuf = _interpolate_zbuf(
fragments.pix_to_face, clipped_bary_coords, meshes_world
)
fragments = Fragments(
bary_coords=clipped_bary_coords,
zbuf=clipped_zbuf,
dists=fragments.dists,
pix_to_face=fragments.pix_to_face,
)
images = self.shader(fragments, meshes_world, **kwargs)
return images

View File

@ -20,9 +20,13 @@ def _clip_barycentric_coordinates(bary) -> torch.Tensor:
if bary.shape[-1] != 3:
msg = "Expected barycentric coords to have last dim = 3; got %r"
raise ValueError(msg % (bary.shape,))
ndims = bary.ndim - 1
mask = bary.eq(-1).all(dim=-1, keepdim=True).expand(*((-1,) * ndims + (3,)))
clipped = bary.clamp(min=0.0)
clipped[mask] = 0.0
clipped_sum = torch.clamp(clipped.sum(dim=-1, keepdim=True), min=1e-5)
clipped = clipped / clipped_sum
clipped[mask] = -1.0
return clipped
@ -49,6 +53,8 @@ def _interpolate_zbuf(
verts = meshes.verts_packed()
faces = meshes.faces_packed()
faces_verts_z = verts[faces][..., 2][..., None] # (F, 3, 1)
return interpolate_face_attributes(pix_to_face, barycentric_coords, faces_verts_z)[
zbuf = interpolate_face_attributes(pix_to_face, barycentric_coords, faces_verts_z)[
..., 0
] # (1, H, W, K)
zbuf[pix_to_face == -1] = -1
return zbuf

View File

@ -0,0 +1,112 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from itertools import product
import torch
from fvcore.common.benchmark import benchmark
from pytorch3d.renderer.cameras import OpenGLPerspectiveCameras, look_at_view_transform
from pytorch3d.renderer.mesh.rasterizer import (
Fragments,
MeshRasterizer,
RasterizationSettings,
)
from pytorch3d.renderer.mesh.utils import (
_clip_barycentric_coordinates,
_interpolate_zbuf,
)
from pytorch3d.utils.ico_sphere import ico_sphere
def baryclip_cuda(
num_meshes: int = 8,
ico_level: int = 5,
image_size: int = 64,
faces_per_pixel: int = 50,
device="cuda",
):
# Init meshes
sphere_meshes = ico_sphere(ico_level, device).extend(num_meshes)
# Init transform
R, T = look_at_view_transform(1.0, 0.0, 0.0)
cameras = OpenGLPerspectiveCameras(device=device, R=R, T=T)
# Init rasterizer
raster_settings = RasterizationSettings(
image_size=image_size,
blur_radius=1e-4,
faces_per_pixel=faces_per_pixel,
clip_barycentric_coords=True,
)
rasterizer = MeshRasterizer(cameras=cameras, raster_settings=raster_settings)
torch.cuda.synchronize()
def raster_fn():
rasterizer(sphere_meshes)
torch.cuda.synchronize()
return raster_fn
def baryclip_pytorch(
num_meshes: int = 8,
ico_level: int = 5,
image_size: int = 64,
faces_per_pixel: int = 50,
device="cuda",
):
# Init meshes
sphere_meshes = ico_sphere(ico_level, device).extend(num_meshes)
# Init transform
R, T = look_at_view_transform(1.0, 0.0, 0.0)
cameras = OpenGLPerspectiveCameras(device=device, R=R, T=T)
# Init rasterizer
raster_settings = RasterizationSettings(
image_size=image_size,
blur_radius=1e-4,
faces_per_pixel=faces_per_pixel,
clip_barycentric_coords=False,
)
rasterizer = MeshRasterizer(cameras=cameras, raster_settings=raster_settings)
torch.cuda.synchronize()
def raster_fn():
fragments = rasterizer(sphere_meshes)
# Clip bary and reinterpolate
clipped_bary_coords = _clip_barycentric_coordinates(fragments.bary_coords)
clipped_zbuf = _interpolate_zbuf(
fragments.pix_to_face, clipped_bary_coords, sphere_meshes
)
fragments = Fragments(
bary_coords=clipped_bary_coords,
zbuf=clipped_zbuf,
dists=fragments.dists,
pix_to_face=fragments.pix_to_face,
)
torch.cuda.synchronize()
return raster_fn
def bm_barycentric_clip() -> None:
if torch.cuda.is_available():
kwargs_list = []
num_meshes = [1, 8]
ico_level = [0, 4]
image_size = [64, 128, 256]
faces_per_pixel = [10, 75, 100]
test_cases = product(num_meshes, ico_level, image_size, faces_per_pixel)
for case in test_cases:
n, ic, im, nf = case
kwargs_list.append(
{
"num_meshes": n,
"ico_level": ic,
"image_size": im,
"faces_per_pixel": nf,
}
)
benchmark(baryclip_cuda, "BARY_CLIP_CUDA", kwargs_list, warmup_iters=1)
benchmark(baryclip_pytorch, "BARY_CLIP_PYTORCH", kwargs_list, warmup_iters=1)

Binary file not shown.

Before

Width:  |  Height:  |  Size: 43 KiB

After

Width:  |  Height:  |  Size: 43 KiB

View File

@ -10,6 +10,10 @@ from pytorch3d.renderer.mesh.rasterize_meshes import (
rasterize_meshes,
rasterize_meshes_python,
)
from pytorch3d.renderer.mesh.utils import (
_clip_barycentric_coordinates,
_interpolate_zbuf,
)
from pytorch3d.structures import Meshes
from pytorch3d.utils import ico_sphere
@ -21,6 +25,7 @@ class TestRasterizeMeshes(TestCaseMixin, unittest.TestCase):
self._simple_blurry_raster(rasterize_meshes_python, device, bin_size=-1)
self._test_behind_camera(rasterize_meshes_python, device, bin_size=-1)
self._test_perspective_correct(rasterize_meshes_python, device, bin_size=-1)
self._test_barycentric_clipping(rasterize_meshes_python, device, bin_size=-1)
self._test_back_face_culling(rasterize_meshes_python, device, bin_size=-1)
def test_simple_cpu_naive(self):
@ -170,8 +175,29 @@ class TestRasterizeMeshes(TestCaseMixin, unittest.TestCase):
verts2.requires_grad = True
meshes_cuda = Meshes(verts=[verts2], faces=[faces2])
args_cpu = (meshes_cpu, image_size, radius, faces_per_pixel)
args_cuda = (meshes_cuda, image_size, radius, faces_per_pixel, 0, 0)
barycentric_clip = True
args_cpu = (
meshes_cpu,
image_size,
radius,
faces_per_pixel,
None,
None,
False,
barycentric_clip,
False,
)
args_cuda = (
meshes_cuda,
image_size,
radius,
faces_per_pixel,
0,
0,
False,
barycentric_clip,
False,
)
self._compare_impls(
rasterize_meshes,
rasterize_meshes,
@ -333,6 +359,39 @@ class TestRasterizeMeshes(TestCaseMixin, unittest.TestCase):
idxs_cuda[:K] = sorted(idxs_cuda[:K])
self.assertEqual(idxs_cpu, idxs_cuda)
def test_python_vs_cpp_bary_clip(self):
torch.manual_seed(232)
N = 2
V = 10
F = 5
verts1 = torch.randn(N, V, 3, requires_grad=True)
verts2 = verts1.detach().clone().requires_grad_(True)
faces = torch.randint(V, size=(N, F, 3))
meshes1 = Meshes(verts1, faces)
meshes2 = Meshes(verts2, faces)
kwargs = {"image_size": 24, "clip_barycentric_coords": True}
fn1 = functools.partial(rasterize_meshes, meshes1, **kwargs)
fn2 = functools.partial(rasterize_meshes_python, meshes2, **kwargs)
args = ()
self._compare_impls(fn1, fn2, args, args, verts1, verts2, compare_grads=True)
def test_cpp_vs_cuda_bary_clip(self):
meshes = ico_sphere(2, device=torch.device("cpu"))
verts1, faces1 = meshes.get_mesh_verts_faces(0)
verts1.requires_grad = True
meshes1 = Meshes(verts=[verts1], faces=[faces1])
device = get_random_cuda_device()
verts2 = verts1.detach().to(device).requires_grad_(True)
faces2 = faces1.detach().clone().to(device)
meshes2 = Meshes(verts=[verts2], faces=[faces2])
kwargs = {"image_size": 64, "clip_barycentric_coords": True}
fn1 = functools.partial(rasterize_meshes, meshes1, **kwargs)
fn2 = functools.partial(rasterize_meshes, meshes2, bin_size=0, **kwargs)
args = ()
self._compare_impls(fn1, fn2, args, args, verts1, verts2, compare_grads=True)
def test_python_vs_cpp_perspective_correct(self):
torch.manual_seed(232)
N = 2
@ -621,6 +680,82 @@ class TestRasterizeMeshes(TestCaseMixin, unittest.TestCase):
self.assertLess(zbuf_f_bary_diff, 1e-4)
self.assertLess(zbuf_t_bary_diff, 1e-4)
def _test_barycentric_clipping(self, rasterize_meshes_fn, device, bin_size=None):
# fmt: off
verts = torch.tensor([
[-0.4, -0.4, 10], # noqa: E241, E201
[ 0.4, -0.4, 10], # noqa: E241, E201
[ 0.0, 0.4, 20], # noqa: E241, E201
], dtype=torch.float32, device=device)
# fmt: on
faces = torch.tensor([[0, 1, 2]], device=device)
meshes = Meshes(verts=[verts], faces=[faces])
kwargs = {
"meshes": meshes,
"image_size": 5,
"faces_per_pixel": 1,
"blur_radius": 0.2,
"perspective_correct": False,
"clip_barycentric_coords": False, # Initially set this to false
}
if bin_size != -1:
kwargs["bin_size"] = bin_size
# Run with and without perspective correction
idx_f, zbuf_f, bary_f, dists_f = rasterize_meshes_fn(**kwargs)
# fmt: off
expected_bary = torch.tensor([
[
[-1.0000, -1.0000, -1.0000], # noqa: E241, E201
[-1.0000, -1.0000, -1.0000], # noqa: E241, E201
[-0.2500, -0.2500, 1.5000], # noqa: E241, E201
[-1.0000, -1.0000, -1.0000], # noqa: E241, E201
[-1.0000, -1.0000, -1.0000] # noqa: E241, E201
],
[
[-1.0000, -1.0000, -1.0000], # noqa: E241, E201
[-0.5000, 0.5000, 1.0000], # noqa: E241, E201
[-0.0000, -0.0000, 1.0000], # noqa: E241, E201
[ 0.5000, -0.5000, 1.0000], # noqa: E241, E201
[-1.0000, -1.0000, -1.0000] # noqa: E241, E201
],
[
[-1.0000, -1.0000, -1.0000], # noqa: E241, E201
[-0.2500, 0.7500, 0.5000], # noqa: E241, E201
[ 0.2500, 0.2500, 0.5000], # noqa: E241, E201
[ 0.7500, -0.2500, 0.5000], # noqa: E241, E201
[-1.0000, -1.0000, -1.0000] # noqa: E241, E201
],
[
[-0.5000, 1.5000, -0.0000], # noqa: E241, E201
[-0.0000, 1.0000, -0.0000], # noqa: E241, E201
[ 0.5000, 0.5000, -0.0000], # noqa: E241, E201
[ 1.0000, -0.0000, -0.0000], # noqa: E241, E201
[ 1.5000, -0.5000, 0.0000] # noqa: E241, E201
],
[
[-1.0000, -1.0000, -1.0000], # noqa: E241, E201
[ 0.2500, 1.2500, -0.5000], # noqa: E241, E201
[ 0.7500, 0.7500, -0.5000], # noqa: E241, E201
[ 1.2500, 0.2500, -0.5000], # noqa: E241, E201
[-1.0000, -1.0000, -1.0000] # noqa: E241, E201
]
], dtype=torch.float32, device=device).view(1, 5, 5, 1, 3)
# fmt: on
self.assertClose(expected_bary, bary_f, atol=1e-4)
# calculate the expected clipped barycentrics and zbuf
expected_bary_clipped = _clip_barycentric_coordinates(expected_bary)
expected_z_clipped = _interpolate_zbuf(idx_f, expected_bary_clipped, meshes)
kwargs["clip_barycentric_coords"] = True
idx_t, zbuf_t, bary_t, dists_t = rasterize_meshes_fn(**kwargs)
self.assertClose(expected_bary_clipped, bary_t, atol=1e-4)
self.assertClose(expected_z_clipped, zbuf_t, atol=1e-4)
def _test_behind_camera(self, rasterize_meshes_fn, device, bin_size=None):
"""
All verts are behind the camera so nothing should get rasterized.

View File

@ -212,6 +212,7 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
image_size=512,
blur_radius=np.log(1.0 / 1e-4 - 1.0) * blend_params.sigma,
faces_per_pixel=80,
clip_barycentric_coords=True,
)
# Init rasterizer settings
@ -269,11 +270,19 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
# the cow is facing the -z direction.
lights.location = torch.tensor([0.0, 0.0, 2.0], device=device)[None]
blend_params = BlendParams(
sigma=1e-1,
gamma=1e-4,
background_color=torch.tensor([1.0, 1.0, 1.0], device=device),
)
# Init renderer
renderer = MeshRenderer(
rasterizer=MeshRasterizer(cameras=cameras, raster_settings=raster_settings),
shader=TexturedSoftPhongShader(
lights=lights, cameras=cameras, materials=materials
lights=lights,
cameras=cameras,
materials=materials,
blend_params=blend_params,
),
)
@ -346,6 +355,7 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
image_size=512,
blur_radius=np.log(1.0 / 1e-4 - 1.0) * blend_params.sigma,
faces_per_pixel=100,
clip_barycentric_coords=True,
)
# Load reference image