mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-20 22:30:35 +08:00
barycentric clipping in cuda/c++
Summary: Added support for barycentric clipping in the C++/CUDA rasterization kernels which can be switched on/off via a rasterization setting. Added tests and a benchmark to compare with the current implementation in PyTorch - for some cases of large image size/faces per pixel the cuda version is 10x faster. Reviewed By: gkioxari Differential Revision: D21705503 fbshipit-source-id: e835c0f927f1e5088ca89020aef5ff27ac3a8769
This commit is contained in:
committed by
Facebook GitHub Bot
parent
bce396df93
commit
cc70950f40
@@ -114,6 +114,7 @@ __device__ void CheckPixelInsideFace(
|
||||
const float2 pxy, // Coordinates of the pixel
|
||||
const int K,
|
||||
const bool perspective_correct,
|
||||
const bool clip_barycentric_coords,
|
||||
const bool cull_backfaces) {
|
||||
const auto v012 = GetSingleFaceVerts(face_verts, face_idx);
|
||||
const float3 v0 = thrust::get<0>(v012);
|
||||
@@ -149,8 +150,12 @@ __device__ void CheckPixelInsideFace(
|
||||
const float3 p_bary = !perspective_correct
|
||||
? p_bary0
|
||||
: BarycentricPerspectiveCorrectionForward(p_bary0, v0.z, v1.z, v2.z);
|
||||
const float3 p_bary_clip =
|
||||
!clip_barycentric_coords ? p_bary : BarycentricClipForward(p_bary);
|
||||
|
||||
const float pz =
|
||||
p_bary_clip.x * v0.z + p_bary_clip.y * v1.z + p_bary_clip.z * v2.z;
|
||||
|
||||
const float pz = p_bary.x * v0.z + p_bary.y * v1.z + p_bary.z * v2.z;
|
||||
if (pz < 0) {
|
||||
return; // Face is behind the image plane.
|
||||
}
|
||||
@@ -158,7 +163,8 @@ __device__ void CheckPixelInsideFace(
|
||||
// Get abs squared distance
|
||||
const float dist = PointTriangleDistanceForward(pxy, v0xy, v1xy, v2xy);
|
||||
|
||||
// Use the bary coordinates to determine if the point is inside the face.
|
||||
// Use the unclipped bary coordinates to determine if the point is inside the
|
||||
// face.
|
||||
const bool inside = p_bary.x > 0.0f && p_bary.y > 0.0f && p_bary.z > 0.0f;
|
||||
const float signed_dist = inside ? -dist : dist;
|
||||
|
||||
@@ -169,7 +175,7 @@ __device__ void CheckPixelInsideFace(
|
||||
|
||||
if (q_size < K) {
|
||||
// Just insert it.
|
||||
q[q_size] = {pz, face_idx, signed_dist, p_bary};
|
||||
q[q_size] = {pz, face_idx, signed_dist, p_bary_clip};
|
||||
if (pz > q_max_z) {
|
||||
q_max_z = pz;
|
||||
q_max_idx = q_size;
|
||||
@@ -177,7 +183,7 @@ __device__ void CheckPixelInsideFace(
|
||||
q_size++;
|
||||
} else if (pz < q_max_z) {
|
||||
// Overwrite the old max, and find the new max.
|
||||
q[q_max_idx] = {pz, face_idx, signed_dist, p_bary};
|
||||
q[q_max_idx] = {pz, face_idx, signed_dist, p_bary_clip};
|
||||
q_max_z = pz;
|
||||
for (int i = 0; i < K; i++) {
|
||||
if (q[i].z > q_max_z) {
|
||||
@@ -198,6 +204,7 @@ __global__ void RasterizeMeshesNaiveCudaKernel(
|
||||
const int64_t* num_faces_per_mesh,
|
||||
const float blur_radius,
|
||||
const bool perspective_correct,
|
||||
const bool clip_barycentric_coords,
|
||||
const bool cull_backfaces,
|
||||
const int N,
|
||||
const int H,
|
||||
@@ -260,6 +267,7 @@ __global__ void RasterizeMeshesNaiveCudaKernel(
|
||||
pxy,
|
||||
K,
|
||||
perspective_correct,
|
||||
clip_barycentric_coords,
|
||||
cull_backfaces);
|
||||
}
|
||||
|
||||
@@ -286,6 +294,7 @@ RasterizeMeshesNaiveCuda(
|
||||
const float blur_radius,
|
||||
const int num_closest,
|
||||
const bool perspective_correct,
|
||||
const bool clip_barycentric_coords,
|
||||
const bool cull_backfaces) {
|
||||
TORCH_CHECK(
|
||||
face_verts.ndimension() == 3 && face_verts.size(1) == 3 &&
|
||||
@@ -343,6 +352,7 @@ RasterizeMeshesNaiveCuda(
|
||||
num_faces_per_mesh.contiguous().data_ptr<int64_t>(),
|
||||
blur_radius,
|
||||
perspective_correct,
|
||||
clip_barycentric_coords,
|
||||
cull_backfaces,
|
||||
N,
|
||||
H,
|
||||
@@ -365,6 +375,7 @@ __global__ void RasterizeMeshesBackwardCudaKernel(
|
||||
const float* face_verts, // (F, 3, 3)
|
||||
const int64_t* pix_to_face, // (N, H, W, K)
|
||||
const bool perspective_correct,
|
||||
const bool clip_barycentric_coords,
|
||||
const int N,
|
||||
const int H,
|
||||
const int W,
|
||||
@@ -422,11 +433,15 @@ __global__ void RasterizeMeshesBackwardCudaKernel(
|
||||
const float3 grad_bary_upstream = make_float3(
|
||||
grad_bary_upstream_w0, grad_bary_upstream_w1, grad_bary_upstream_w2);
|
||||
|
||||
const float3 bary0 = BarycentricCoordsForward(pxy, v0xy, v1xy, v2xy);
|
||||
const float3 bary = !perspective_correct
|
||||
? bary0
|
||||
: BarycentricPerspectiveCorrectionForward(bary0, v0.z, v1.z, v2.z);
|
||||
const bool inside = bary.x > 0.0f && bary.y > 0.0f && bary.z > 0.0f;
|
||||
const float3 b_w = BarycentricCoordsForward(pxy, v0xy, v1xy, v2xy);
|
||||
const float3 b_pp = !perspective_correct
|
||||
? b_w
|
||||
: BarycentricPerspectiveCorrectionForward(b_w, v0.z, v1.z, v2.z);
|
||||
|
||||
const float3 b_w_clip =
|
||||
!clip_barycentric_coords ? b_pp : BarycentricClipForward(b_pp);
|
||||
|
||||
const bool inside = b_pp.x > 0.0f && b_pp.y > 0.0f && b_pp.z > 0.0f;
|
||||
const float sign = inside ? -1.0f : 1.0f;
|
||||
|
||||
// TODO(T52813608) Add support for non-square images.
|
||||
@@ -442,22 +457,29 @@ __global__ void RasterizeMeshesBackwardCudaKernel(
|
||||
// d_zbuf/d_bary_w0 = z0
|
||||
// d_zbuf/d_bary_w1 = z1
|
||||
// d_zbuf/d_bary_w2 = z2
|
||||
const float3 d_zbuf_d_bary = make_float3(v0.z, v1.z, v2.z);
|
||||
const float3 d_zbuf_d_bwclip = make_float3(v0.z, v1.z, v2.z);
|
||||
|
||||
// Total upstream barycentric gradients are the sum of
|
||||
// external upstream gradients and contribution from zbuf.
|
||||
const float3 grad_bary_f_sum =
|
||||
(grad_bary_upstream + grad_zbuf_upstream * d_zbuf_d_bary);
|
||||
(grad_bary_upstream + grad_zbuf_upstream * d_zbuf_d_bwclip);
|
||||
|
||||
float3 grad_bary0 = grad_bary_f_sum;
|
||||
|
||||
if (clip_barycentric_coords) {
|
||||
grad_bary0 = BarycentricClipBackward(b_w, grad_bary_f_sum);
|
||||
}
|
||||
|
||||
float dz0_persp = 0.0f, dz1_persp = 0.0f, dz2_persp = 0.0f;
|
||||
if (perspective_correct) {
|
||||
auto perspective_grads = BarycentricPerspectiveCorrectionBackward(
|
||||
bary0, v0.z, v1.z, v2.z, grad_bary_f_sum);
|
||||
b_w, v0.z, v1.z, v2.z, grad_bary0);
|
||||
grad_bary0 = thrust::get<0>(perspective_grads);
|
||||
dz0_persp = thrust::get<1>(perspective_grads);
|
||||
dz1_persp = thrust::get<2>(perspective_grads);
|
||||
dz2_persp = thrust::get<3>(perspective_grads);
|
||||
}
|
||||
|
||||
auto grad_bary_f =
|
||||
BarycentricCoordsBackward(pxy, v0xy, v1xy, v2xy, grad_bary0);
|
||||
const float2 dbary_d_v0 = thrust::get<1>(grad_bary_f);
|
||||
@@ -467,15 +489,18 @@ __global__ void RasterizeMeshesBackwardCudaKernel(
|
||||
atomicAdd(grad_face_verts + f * 9 + 0, dbary_d_v0.x + ddist_d_v0.x);
|
||||
atomicAdd(grad_face_verts + f * 9 + 1, dbary_d_v0.y + ddist_d_v0.y);
|
||||
atomicAdd(
|
||||
grad_face_verts + f * 9 + 2, grad_zbuf_upstream * bary.x + dz0_persp);
|
||||
grad_face_verts + f * 9 + 2,
|
||||
grad_zbuf_upstream * b_w_clip.x + dz0_persp);
|
||||
atomicAdd(grad_face_verts + f * 9 + 3, dbary_d_v1.x + ddist_d_v1.x);
|
||||
atomicAdd(grad_face_verts + f * 9 + 4, dbary_d_v1.y + ddist_d_v1.y);
|
||||
atomicAdd(
|
||||
grad_face_verts + f * 9 + 5, grad_zbuf_upstream * bary.y + dz1_persp);
|
||||
grad_face_verts + f * 9 + 5,
|
||||
grad_zbuf_upstream * b_w_clip.y + dz1_persp);
|
||||
atomicAdd(grad_face_verts + f * 9 + 6, dbary_d_v2.x + ddist_d_v2.x);
|
||||
atomicAdd(grad_face_verts + f * 9 + 7, dbary_d_v2.y + ddist_d_v2.y);
|
||||
atomicAdd(
|
||||
grad_face_verts + f * 9 + 8, grad_zbuf_upstream * bary.z + dz2_persp);
|
||||
grad_face_verts + f * 9 + 8,
|
||||
grad_zbuf_upstream * b_w_clip.z + dz2_persp);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -486,7 +511,8 @@ at::Tensor RasterizeMeshesBackwardCuda(
|
||||
const at::Tensor& grad_zbuf, // (N, H, W, K)
|
||||
const at::Tensor& grad_bary, // (N, H, W, K, 3)
|
||||
const at::Tensor& grad_dists, // (N, H, W, K)
|
||||
const bool perspective_correct) {
|
||||
const bool perspective_correct,
|
||||
const bool clip_barycentric_coords) {
|
||||
// Check inputs are on the same device
|
||||
at::TensorArg face_verts_t{face_verts, "face_verts", 1},
|
||||
pix_to_face_t{pix_to_face, "pix_to_face", 2},
|
||||
@@ -523,6 +549,7 @@ at::Tensor RasterizeMeshesBackwardCuda(
|
||||
face_verts.contiguous().data_ptr<float>(),
|
||||
pix_to_face.contiguous().data_ptr<int64_t>(),
|
||||
perspective_correct,
|
||||
clip_barycentric_coords,
|
||||
N,
|
||||
H,
|
||||
W,
|
||||
@@ -743,6 +770,7 @@ __global__ void RasterizeMeshesFineCudaKernel(
|
||||
const float blur_radius,
|
||||
const int bin_size,
|
||||
const bool perspective_correct,
|
||||
const bool clip_barycentric_coords,
|
||||
const bool cull_backfaces,
|
||||
const int N,
|
||||
const int B,
|
||||
@@ -808,6 +836,7 @@ __global__ void RasterizeMeshesFineCudaKernel(
|
||||
pxy,
|
||||
K,
|
||||
perspective_correct,
|
||||
clip_barycentric_coords,
|
||||
cull_backfaces);
|
||||
}
|
||||
|
||||
@@ -841,6 +870,7 @@ RasterizeMeshesFineCuda(
|
||||
const int bin_size,
|
||||
const int faces_per_pixel,
|
||||
const bool perspective_correct,
|
||||
const bool clip_barycentric_coords,
|
||||
const bool cull_backfaces) {
|
||||
TORCH_CHECK(
|
||||
face_verts.ndimension() == 3 && face_verts.size(1) == 3 &&
|
||||
@@ -890,6 +920,7 @@ RasterizeMeshesFineCuda(
|
||||
blur_radius,
|
||||
bin_size,
|
||||
perspective_correct,
|
||||
clip_barycentric_coords,
|
||||
cull_backfaces,
|
||||
N,
|
||||
B,
|
||||
|
||||
@@ -19,6 +19,7 @@ RasterizeMeshesNaiveCpu(
|
||||
const float blur_radius,
|
||||
const int faces_per_pixel,
|
||||
const bool perspective_correct,
|
||||
const bool clip_barycentric_coords,
|
||||
const bool cull_backfaces);
|
||||
|
||||
#ifdef WITH_CUDA
|
||||
@@ -31,6 +32,7 @@ RasterizeMeshesNaiveCuda(
|
||||
const float blur_radius,
|
||||
const int num_closest,
|
||||
const bool perspective_correct,
|
||||
const bool clip_barycentric_coords,
|
||||
const bool cull_backfaces);
|
||||
#endif
|
||||
// Forward pass for rasterizing a batch of meshes.
|
||||
@@ -92,6 +94,7 @@ RasterizeMeshesNaive(
|
||||
const float blur_radius,
|
||||
const int faces_per_pixel,
|
||||
const bool perspective_correct,
|
||||
const bool clip_barycentric_coords,
|
||||
const bool cull_backfaces) {
|
||||
// TODO: Better type checking.
|
||||
if (face_verts.is_cuda()) {
|
||||
@@ -107,6 +110,7 @@ RasterizeMeshesNaive(
|
||||
blur_radius,
|
||||
faces_per_pixel,
|
||||
perspective_correct,
|
||||
clip_barycentric_coords,
|
||||
cull_backfaces);
|
||||
#else
|
||||
AT_ERROR("Not compiled with GPU support");
|
||||
@@ -120,6 +124,7 @@ RasterizeMeshesNaive(
|
||||
blur_radius,
|
||||
faces_per_pixel,
|
||||
perspective_correct,
|
||||
clip_barycentric_coords,
|
||||
cull_backfaces);
|
||||
}
|
||||
}
|
||||
@@ -134,7 +139,8 @@ torch::Tensor RasterizeMeshesBackwardCpu(
|
||||
const torch::Tensor& grad_bary,
|
||||
const torch::Tensor& grad_zbuf,
|
||||
const torch::Tensor& grad_dists,
|
||||
const bool perspective_correct);
|
||||
const bool perspective_correct,
|
||||
const bool clip_barycentric_coords);
|
||||
|
||||
#ifdef WITH_CUDA
|
||||
torch::Tensor RasterizeMeshesBackwardCuda(
|
||||
@@ -143,7 +149,8 @@ torch::Tensor RasterizeMeshesBackwardCuda(
|
||||
const torch::Tensor& grad_bary,
|
||||
const torch::Tensor& grad_zbuf,
|
||||
const torch::Tensor& grad_dists,
|
||||
const bool perspective_correct);
|
||||
const bool perspective_correct,
|
||||
const bool clip_barycentric_coords);
|
||||
#endif
|
||||
|
||||
// Args:
|
||||
@@ -176,7 +183,8 @@ torch::Tensor RasterizeMeshesBackward(
|
||||
const torch::Tensor& grad_zbuf,
|
||||
const torch::Tensor& grad_bary,
|
||||
const torch::Tensor& grad_dists,
|
||||
const bool perspective_correct) {
|
||||
const bool perspective_correct,
|
||||
const bool clip_barycentric_coords) {
|
||||
if (face_verts.is_cuda()) {
|
||||
#ifdef WITH_CUDA
|
||||
CHECK_CUDA(face_verts);
|
||||
@@ -190,7 +198,8 @@ torch::Tensor RasterizeMeshesBackward(
|
||||
grad_zbuf,
|
||||
grad_bary,
|
||||
grad_dists,
|
||||
perspective_correct);
|
||||
perspective_correct,
|
||||
clip_barycentric_coords);
|
||||
#else
|
||||
AT_ERROR("Not compiled with GPU support");
|
||||
#endif
|
||||
@@ -201,7 +210,8 @@ torch::Tensor RasterizeMeshesBackward(
|
||||
grad_zbuf,
|
||||
grad_bary,
|
||||
grad_dists,
|
||||
perspective_correct);
|
||||
perspective_correct,
|
||||
clip_barycentric_coords);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -300,6 +310,7 @@ RasterizeMeshesFineCuda(
|
||||
const int bin_size,
|
||||
const int faces_per_pixel,
|
||||
const bool perspective_correct,
|
||||
const bool clip_barycentric_coords,
|
||||
const bool cull_backfaces);
|
||||
#endif
|
||||
// Args:
|
||||
@@ -356,6 +367,7 @@ RasterizeMeshesFine(
|
||||
const int bin_size,
|
||||
const int faces_per_pixel,
|
||||
const bool perspective_correct,
|
||||
const bool clip_barycentric_coords,
|
||||
const bool cull_backfaces) {
|
||||
if (face_verts.is_cuda()) {
|
||||
#ifdef WITH_CUDA
|
||||
@@ -369,6 +381,7 @@ RasterizeMeshesFine(
|
||||
bin_size,
|
||||
faces_per_pixel,
|
||||
perspective_correct,
|
||||
clip_barycentric_coords,
|
||||
cull_backfaces);
|
||||
#else
|
||||
AT_ERROR("Not compiled with GPU support");
|
||||
@@ -446,6 +459,7 @@ RasterizeMeshes(
|
||||
const int bin_size,
|
||||
const int max_faces_per_bin,
|
||||
const bool perspective_correct,
|
||||
const bool clip_barycentric_coords,
|
||||
const bool cull_backfaces) {
|
||||
if (bin_size > 0 && max_faces_per_bin > 0) {
|
||||
// Use coarse-to-fine rasterization
|
||||
@@ -465,6 +479,7 @@ RasterizeMeshes(
|
||||
bin_size,
|
||||
faces_per_pixel,
|
||||
perspective_correct,
|
||||
clip_barycentric_coords,
|
||||
cull_backfaces);
|
||||
} else {
|
||||
// Use the naive per-pixel implementation
|
||||
@@ -476,6 +491,7 @@ RasterizeMeshes(
|
||||
blur_radius,
|
||||
faces_per_pixel,
|
||||
perspective_correct,
|
||||
clip_barycentric_coords,
|
||||
cull_backfaces);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -108,6 +108,7 @@ RasterizeMeshesNaiveCpu(
|
||||
const float blur_radius,
|
||||
const int faces_per_pixel,
|
||||
const bool perspective_correct,
|
||||
const bool clip_barycentric_coords,
|
||||
const bool cull_backfaces) {
|
||||
if (face_verts.ndimension() != 3 || face_verts.size(1) != 3 ||
|
||||
face_verts.size(2) != 3) {
|
||||
@@ -213,8 +214,12 @@ RasterizeMeshesNaiveCpu(
|
||||
? bary0
|
||||
: BarycentricPerspectiveCorrectionForward(bary0, z0, z1, z2);
|
||||
|
||||
const vec3<float> bary_clip =
|
||||
!clip_barycentric_coords ? bary : BarycentricClipForward(bary);
|
||||
|
||||
// Use barycentric coordinates to get the depth of the current pixel
|
||||
const float pz = (bary.x * z0 + bary.y * z1 + bary.z * z2);
|
||||
const float pz =
|
||||
(bary_clip.x * z0 + bary_clip.y * z1 + bary_clip.z * z2);
|
||||
|
||||
if (pz < 0) {
|
||||
continue; // Point is behind the image plane so ignore.
|
||||
@@ -236,7 +241,7 @@ RasterizeMeshesNaiveCpu(
|
||||
continue;
|
||||
}
|
||||
// The current pixel lies inside the current face.
|
||||
q.emplace(pz, f, signed_dist, bary.x, bary.y, bary.z);
|
||||
q.emplace(pz, f, signed_dist, bary_clip.x, bary_clip.y, bary_clip.z);
|
||||
if (static_cast<int>(q.size()) > K) {
|
||||
q.pop();
|
||||
}
|
||||
@@ -264,7 +269,8 @@ torch::Tensor RasterizeMeshesBackwardCpu(
|
||||
const torch::Tensor& grad_zbuf, // (N, H, W, K)
|
||||
const torch::Tensor& grad_bary, // (N, H, W, K, 3)
|
||||
const torch::Tensor& grad_dists, // (N, H, W, K)
|
||||
const bool perspective_correct) {
|
||||
const bool perspective_correct,
|
||||
const bool clip_barycentric_coords) {
|
||||
const int F = face_verts.size(0);
|
||||
const int N = pix_to_face.size(0);
|
||||
const int H = pix_to_face.size(1);
|
||||
@@ -335,6 +341,8 @@ torch::Tensor RasterizeMeshesBackwardCpu(
|
||||
const vec3<float> bary = !perspective_correct
|
||||
? bary0
|
||||
: BarycentricPerspectiveCorrectionForward(bary0, z0, z1, z2);
|
||||
const vec3<float> bary_clip =
|
||||
!clip_barycentric_coords ? bary : BarycentricClipForward(bary);
|
||||
|
||||
// Distances inside the face are negative so get the
|
||||
// correct sign to apply to the upstream gradient.
|
||||
@@ -354,22 +362,28 @@ torch::Tensor RasterizeMeshesBackwardCpu(
|
||||
// d_zbuf/d_bary_w0 = z0
|
||||
// d_zbuf/d_bary_w1 = z1
|
||||
// d_zbuf/d_bary_w2 = z2
|
||||
const vec3<float> d_zbuf_d_bary(z0, z1, z2);
|
||||
const vec3<float> d_zbuf_d_baryclip(z0, z1, z2);
|
||||
|
||||
// Total upstream barycentric gradients are the sum of
|
||||
// external upstream gradients and contribution from zbuf.
|
||||
vec3<float> grad_bary_f_sum =
|
||||
(grad_bary_upstream + grad_zbuf_upstream * d_zbuf_d_bary);
|
||||
const vec3<float> grad_bary_f_sum =
|
||||
(grad_bary_upstream + grad_zbuf_upstream * d_zbuf_d_baryclip);
|
||||
|
||||
vec3<float> grad_bary0 = grad_bary_f_sum;
|
||||
|
||||
if (clip_barycentric_coords) {
|
||||
grad_bary0 = BarycentricClipBackward(bary, grad_bary0);
|
||||
}
|
||||
|
||||
if (perspective_correct) {
|
||||
auto perspective_grads = BarycentricPerspectiveCorrectionBackward(
|
||||
bary0, z0, z1, z2, grad_bary_f_sum);
|
||||
bary0, z0, z1, z2, grad_bary0);
|
||||
grad_bary0 = std::get<0>(perspective_grads);
|
||||
grad_face_verts[f][0][2] += std::get<1>(perspective_grads);
|
||||
grad_face_verts[f][1][2] += std::get<2>(perspective_grads);
|
||||
grad_face_verts[f][2][2] += std::get<3>(perspective_grads);
|
||||
}
|
||||
|
||||
auto grad_bary_f =
|
||||
BarycentricCoordsBackward(pxy, v0xy, v1xy, v2xy, grad_bary0);
|
||||
const vec2<float> dbary_d_v0 = std::get<1>(grad_bary_f);
|
||||
@@ -379,13 +393,13 @@ torch::Tensor RasterizeMeshesBackwardCpu(
|
||||
// Update output gradient buffer.
|
||||
grad_face_verts[f][0][0] += dbary_d_v0.x + ddist_d_v0.x;
|
||||
grad_face_verts[f][0][1] += dbary_d_v0.y + ddist_d_v0.y;
|
||||
grad_face_verts[f][0][2] += grad_zbuf_upstream * bary.x;
|
||||
grad_face_verts[f][0][2] += grad_zbuf_upstream * bary_clip.x;
|
||||
grad_face_verts[f][1][0] += dbary_d_v1.x + ddist_d_v1.x;
|
||||
grad_face_verts[f][1][1] += dbary_d_v1.y + ddist_d_v1.y;
|
||||
grad_face_verts[f][1][2] += grad_zbuf_upstream * bary.y;
|
||||
grad_face_verts[f][1][2] += grad_zbuf_upstream * bary_clip.y;
|
||||
grad_face_verts[f][2][0] += dbary_d_v2.x + ddist_d_v2.x;
|
||||
grad_face_verts[f][2][1] += dbary_d_v2.y + ddist_d_v2.y;
|
||||
grad_face_verts[f][2][2] += grad_zbuf_upstream * bary.z;
|
||||
grad_face_verts[f][2][2] += grad_zbuf_upstream * bary_clip.z;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -221,8 +221,108 @@ BarycentricPerspectiveCorrectionBackward(
|
||||
return thrust::make_tuple(grad_bary, grad_z0, grad_z1, grad_z2);
|
||||
}
|
||||
|
||||
// Calculate minimum squared distance between a line segment (v1 - v0) and a
|
||||
// point p.
|
||||
// Clip negative barycentric coordinates to 0.0 and renormalize so
// the barycentric coordinates for a point sum to 1. When the blur_radius
// is greater than 0, a face will still be recorded as overlapping a pixel
// if the pixel is outside the face. In this case at least one of the
// barycentric coordinates for the pixel relative to the face will be negative.
// Clipping will ensure that the texture and z buffer are interpolated
// correctly.
//
// Args
//     bary: (w0, w1, w2) barycentric coordinates which can be outside the
//            range [0, 1].
//
// Returns
//     bary: (w0, w1, w2) barycentric coordinates in the range [0, 1] which
//           satisfy the condition: sum(w0, w1, w2) = 1.0. If every input
//           coordinate is clipped, the sum is clamped to 1e-5 before the
//           division and the result is (0, 0, 0) instead.
//
__device__ inline float3 BarycentricClipForward(const float3 bary) {
  // Clamp the lower bound only; values > 1 are handled by the
  // renormalization below. Float literals keep the math in single precision.
  float3 w = make_float3(
      fmaxf(bary.x, 0.0f), fmaxf(bary.y, 0.0f), fmaxf(bary.z, 0.0f));

  // Guard the division against a (near) zero sum.
  const float w_sum = fmaxf(w.x + w.y + w.z, 1e-5f);
  w.x /= w_sum;
  w.y /= w_sum;
  w.z /= w_sum;

  return w;
}
|
||||
|
||||
// Backward pass for barycentric coordinate clipping.
//
// Args
//     bary: (w0, w1, w2) barycentric coordinates which can be outside the
//            range [0, 1]. This must be the same value that was passed to
//            BarycentricClipForward.
//     grad_baryclip_upstream: float3 upstream gradient for each of the clipped
//                         barycentric coordinates [grad_w0, grad_w1, grad_w2].
//
// Returns
//    float3 of gradients for the unclipped barycentric coordinates:
//    (grad_w0, grad_w1, grad_w2)
//
__device__ inline float3 BarycentricClipBackward(
    const float3 bary,
    const float3 grad_baryclip_upstream) {
  // Same lower bound for the sum as in the forward pass. Kept as a float so
  // the "was the sum clamped" test agrees with the forward fmaxf.
  const float kMinSum = 1e-5f;

  // Redo some of the forward pass calculations.
  // Clamp lower bound only.
  const float3 w = make_float3(
      fmaxf(bary.x, 0.0f), fmaxf(bary.y, 0.0f), fmaxf(bary.z, 0.0f));
  float w_sum = w.x + w.y + w.z;

  // Check if the sum was clipped. If it was, the normalizer is a constant
  // and contributes no gradient.
  float grad_sum_clip = 1.0f;
  if (w_sum < kMinSum) {
    grad_sum_clip = 0.0f;
    w_sum = kMinSum;
  }

  // Coordinates that were clipped to zero receive no gradient.
  const float3 grad_clip = make_float3(
      bary.x < 0.0f ? 0.0f : 1.0f,
      bary.y < 0.0f ? 0.0f : 1.0f,
      bary.z < 0.0f ? 0.0f : 1.0f);

  // Gradients of the sum: d(w_i / w_sum) / d(w_sum) = -w_i / w_sum^2.
  const float w_sum_sq = w_sum * w_sum;
  const float3 grad_sum = make_float3(
      -w.x / w_sum_sq * grad_sum_clip,
      -w.y / w_sum_sq * grad_sum_clip,
      -w.z / w_sum_sq * grad_sum_clip);

  // Gradients for each of the bary coordinates including the cross terms
  // from the sum.
  const float inv_w_sum = 1.0f / w_sum;
  float3 grad_bary;
  grad_bary.x = grad_clip.x *
      (grad_baryclip_upstream.x * (inv_w_sum + grad_sum.x) +
       grad_baryclip_upstream.y * (grad_sum.y) +
       grad_baryclip_upstream.z * (grad_sum.z));

  grad_bary.y = grad_clip.y *
      (grad_baryclip_upstream.y * (inv_w_sum + grad_sum.y) +
       grad_baryclip_upstream.x * (grad_sum.x) +
       grad_baryclip_upstream.z * (grad_sum.z));

  grad_bary.z = grad_clip.z *
      (grad_baryclip_upstream.z * (inv_w_sum + grad_sum.z) +
       grad_baryclip_upstream.x * (grad_sum.x) +
       grad_baryclip_upstream.y * (grad_sum.y));

  return grad_bary;
}
|
||||
|
||||
// Return minimum distance between line segment (v1 - v0) and point p.
|
||||
//
|
||||
// Args:
|
||||
// p: Coordinates of a point.
|
||||
|
||||
@@ -242,8 +242,108 @@ inline std::tuple<vec3<T>, T, T, T> BarycentricPerspectiveCorrectionBackward(
|
||||
return std::make_tuple(grad_bary, grad_z0, grad_z1, grad_z2);
|
||||
}
|
||||
|
||||
// Calculate minimum squared distance between a line segment (v1 - v0) and a
|
||||
// point p.
|
||||
// Clip negative barycentric coordinates to 0.0 and renormalize so
|
||||
// the barycentric coordinates for a point sum to 1. When the blur_radius
|
||||
// is greater than 0, a face will still be recorded as overlapping a pixel
|
||||
// if the pixel is outisde the face. In this case at least one of the
|
||||
// barycentric coordinates for the pixel relative to the face will be negative.
|
||||
// Clipping will ensure that the texture and z buffer are interpolated
|
||||
// correctly.
|
||||
//
|
||||
// Args
|
||||
// bary: (w0, w1, w2) barycentric coordinates which can contain values < 0.
|
||||
//
|
||||
// Returns
|
||||
// bary: (w0, w1, w2) barycentric coordinates in the range [0, 1] which
|
||||
// satisfy the condition: sum(w0, w1, w2) = 1.0.
|
||||
//
|
||||
template <typename T>
|
||||
vec3<T> BarycentricClipForward(const vec3<T> bary) {
|
||||
vec3<T> w(0.0f, 0.0f, 0.0f);
|
||||
// Only clamp negative values to 0.0.
|
||||
// No need to clamp values > 1.0 as they will be renormalized.
|
||||
w.x = std::max(bary.x, 0.0f);
|
||||
w.y = std::max(bary.y, 0.0f);
|
||||
w.z = std::max(bary.z, 0.0f);
|
||||
float w_sum = w.x + w.y + w.z;
|
||||
w_sum = std::fmaxf(w_sum, 1e-5);
|
||||
w.x /= w_sum;
|
||||
w.y /= w_sum;
|
||||
w.z /= w_sum;
|
||||
return w;
|
||||
}
|
||||
|
||||
// Backward pass for barycentric coordinate clipping.
|
||||
//
|
||||
// Args
|
||||
// bary: (w0, w1, w2) barycentric coordinates which can contain values < 0.
|
||||
// grad_baryclip_upstream: vec3<T> Upstream gradient for each of the clipped
|
||||
// barycentric coordinates [grad_w0, grad_w1, grad_w2].
|
||||
//
|
||||
// Returns
|
||||
// vec3<T> of gradients for the unclipped barycentric coordinates:
|
||||
// (grad_w0, grad_w1, grad_w2)
|
||||
//
|
||||
template <typename T>
|
||||
vec3<T> BarycentricClipBackward(
|
||||
const vec3<T> bary,
|
||||
const vec3<T> grad_baryclip_upstream) {
|
||||
// Redo some of the forward pass calculations
|
||||
vec3<T> w(0.0f, 0.0f, 0.0f);
|
||||
w.x = std::max(bary.x, 0.0f);
|
||||
w.y = std::max(bary.y, 0.0f);
|
||||
w.z = std::max(bary.z, 0.0f);
|
||||
float w_sum = w.x + w.y + w.z;
|
||||
|
||||
vec3<T> grad_bary(1.0f, 1.0f, 1.0f);
|
||||
vec3<T> grad_clip(1.0f, 1.0f, 1.0f);
|
||||
vec3<T> grad_sum(1.0f, 1.0f, 1.0f);
|
||||
|
||||
// Check if the sum was clipped.
|
||||
float grad_sum_clip = 1.0f;
|
||||
if (w_sum < 1e-5) {
|
||||
grad_sum_clip = 0.0f;
|
||||
w_sum = 1e-5;
|
||||
}
|
||||
|
||||
// Check if any of the bary coordinates have been clipped.
|
||||
// Only negative values are clamped to 0.0.
|
||||
if (bary.x < 0.0f) {
|
||||
grad_clip.x = 0.0f;
|
||||
}
|
||||
if (bary.y < 0.0f) {
|
||||
grad_clip.y = 0.0f;
|
||||
}
|
||||
if (bary.z < 0.0f) {
|
||||
grad_clip.z = 0.0f;
|
||||
}
|
||||
|
||||
// Gradients of the sum.
|
||||
grad_sum.x = -w.x / (pow(w_sum, 2.0f)) * grad_sum_clip;
|
||||
grad_sum.y = -w.y / (pow(w_sum, 2.0f)) * grad_sum_clip;
|
||||
grad_sum.z = -w.z / (pow(w_sum, 2.0f)) * grad_sum_clip;
|
||||
|
||||
// Gradients for each of the bary coordinates including the cross terms
|
||||
// from the sum.
|
||||
grad_bary.x = grad_clip.x *
|
||||
(grad_baryclip_upstream.x * (1.0f / w_sum + grad_sum.x) +
|
||||
grad_baryclip_upstream.y * (grad_sum.y) +
|
||||
grad_baryclip_upstream.z * (grad_sum.z));
|
||||
|
||||
grad_bary.y = grad_clip.y *
|
||||
(grad_baryclip_upstream.y * (1.0f / w_sum + grad_sum.y) +
|
||||
grad_baryclip_upstream.x * (grad_sum.x) +
|
||||
grad_baryclip_upstream.z * (grad_sum.z));
|
||||
|
||||
grad_bary.z = grad_clip.z *
|
||||
(grad_baryclip_upstream.z * (1.0f / w_sum + grad_sum.z) +
|
||||
grad_baryclip_upstream.x * (grad_sum.x) +
|
||||
grad_baryclip_upstream.y * (grad_sum.y));
|
||||
|
||||
return grad_bary;
|
||||
}
|
||||
|
||||
// Calculate minimum distance between a line segment (v1 - v0) and point p.
|
||||
//
|
||||
// Args:
|
||||
// p: Coordinates of a point.
|
||||
|
||||
@@ -24,6 +24,7 @@ def rasterize_meshes(
|
||||
bin_size: Optional[int] = None,
|
||||
max_faces_per_bin: Optional[int] = None,
|
||||
perspective_correct: bool = False,
|
||||
clip_barycentric_coords: bool = False,
|
||||
cull_backfaces: bool = False,
|
||||
):
|
||||
"""
|
||||
@@ -143,6 +144,7 @@ def rasterize_meshes(
|
||||
bin_size,
|
||||
max_faces_per_bin,
|
||||
perspective_correct,
|
||||
clip_barycentric_coords,
|
||||
cull_backfaces,
|
||||
)
|
||||
|
||||
@@ -183,6 +185,7 @@ class _RasterizeFaceVerts(torch.autograd.Function):
|
||||
bin_size: int = 0,
|
||||
max_faces_per_bin: int = 0,
|
||||
perspective_correct: bool = False,
|
||||
clip_barycentric_coords: bool = False,
|
||||
cull_backfaces: bool = False,
|
||||
):
|
||||
# pyre-fixme[16]: Module `pytorch3d` has no attribute `_C`.
|
||||
@@ -196,11 +199,13 @@ class _RasterizeFaceVerts(torch.autograd.Function):
|
||||
bin_size,
|
||||
max_faces_per_bin,
|
||||
perspective_correct,
|
||||
clip_barycentric_coords,
|
||||
cull_backfaces,
|
||||
)
|
||||
ctx.save_for_backward(face_verts, pix_to_face)
|
||||
ctx.mark_non_differentiable(pix_to_face)
|
||||
ctx.perspective_correct = perspective_correct
|
||||
ctx.clip_barycentric_coords = clip_barycentric_coords
|
||||
return pix_to_face, zbuf, barycentric_coords, dists
|
||||
|
||||
@staticmethod
|
||||
@@ -214,6 +219,7 @@ class _RasterizeFaceVerts(torch.autograd.Function):
|
||||
grad_bin_size = None
|
||||
grad_max_faces_per_bin = None
|
||||
grad_perspective_correct = None
|
||||
grad_clip_barycentric_coords = None
|
||||
grad_cull_backfaces = None
|
||||
face_verts, pix_to_face = ctx.saved_tensors
|
||||
grad_face_verts = _C.rasterize_meshes_backward(
|
||||
@@ -223,6 +229,7 @@ class _RasterizeFaceVerts(torch.autograd.Function):
|
||||
grad_barycentric_coords,
|
||||
grad_dists,
|
||||
ctx.perspective_correct,
|
||||
ctx.clip_barycentric_coords,
|
||||
)
|
||||
grads = (
|
||||
grad_face_verts,
|
||||
@@ -234,6 +241,7 @@ class _RasterizeFaceVerts(torch.autograd.Function):
|
||||
grad_bin_size,
|
||||
grad_max_faces_per_bin,
|
||||
grad_perspective_correct,
|
||||
grad_clip_barycentric_coords,
|
||||
grad_cull_backfaces,
|
||||
)
|
||||
return grads
|
||||
@@ -250,6 +258,7 @@ def rasterize_meshes_python(
|
||||
blur_radius: float = 0.0,
|
||||
faces_per_pixel: int = 8,
|
||||
perspective_correct: bool = False,
|
||||
clip_barycentric_coords: bool = False,
|
||||
cull_backfaces: bool = False,
|
||||
):
|
||||
"""
|
||||
@@ -356,6 +365,14 @@ def rasterize_meshes_python(
|
||||
top2 = z0 * z1 * l2
|
||||
bot = top0 + top1 + top2
|
||||
bary = torch.stack([top0 / bot, top1 / bot, top2 / bot])
|
||||
|
||||
# Check if inside before clipping
|
||||
inside = all(x > 0.0 for x in bary)
|
||||
|
||||
# Barycentric clipping
|
||||
if clip_barycentric_coords:
|
||||
bary = barycentric_coordinates_clip(bary)
|
||||
# use clipped barycentric coords to calculate the z value
|
||||
pz = bary[0] * v0[2] + bary[1] * v1[2] + bary[2] * v2[2]
|
||||
|
||||
# Check if point is behind the image.
|
||||
@@ -365,7 +382,6 @@ def rasterize_meshes_python(
|
||||
# Calculate signed 2D distance from point to face.
|
||||
# Points inside the triangle have negative distance.
|
||||
dist = point_triangle_distance(pxy, v0[:2], v1[:2], v2[:2])
|
||||
inside = all(x > 0.0 for x in bary)
|
||||
|
||||
signed_dist = dist * -1.0 if inside else dist
|
||||
|
||||
@@ -433,6 +449,33 @@ def edge_function(p, v0, v1):
|
||||
return (p[0] - v0[0]) * (v1[1] - v0[1]) - (p[1] - v0[1]) * (v1[0] - v0[0])
|
||||
|
||||
|
||||
def barycentric_coordinates_clip(bary):
    """
    Zero out negative barycentric coordinates and rescale the remainder so
    that the three weights sum to 1.

    With blur_radius > 0 a face can be recorded for a pixel that lies outside
    the face, in which case at least one barycentric coordinate is negative.
    Clipping keeps the texture and z buffer interpolation well behaved.

    Args:
        bary: indexable holding the three barycentric coordinates as tensors.

    Returns:
        bary_clip: tuple (w0, w1, w2) of barycentric coordinates with no
        negative values.
    """
    # Only the lower bound is clamped; values above 1 are dealt with by the
    # renormalization.
    non_negative = [torch.clamp(bary[idx], min=0.0) for idx in range(3)]

    # Lower-bound the normalizer so the division stays finite when every
    # coordinate was clipped.
    total = torch.clamp(
        non_negative[0] + non_negative[1] + non_negative[2], min=1e-5
    )

    return tuple(weight / total for weight in non_negative)
|
||||
|
||||
|
||||
def barycentric_coordinates(p, v0, v1, v2):
|
||||
"""
|
||||
Compute the barycentric coordinates of a point relative to a triangle.
|
||||
|
||||
@@ -26,6 +26,7 @@ class RasterizationSettings:
|
||||
"bin_size",
|
||||
"max_faces_per_bin",
|
||||
"perspective_correct",
|
||||
"clip_barycentric_coords",
|
||||
"cull_backfaces",
|
||||
]
|
||||
|
||||
@@ -37,6 +38,7 @@ class RasterizationSettings:
|
||||
bin_size: Optional[int] = None,
|
||||
max_faces_per_bin: Optional[int] = None,
|
||||
perspective_correct: bool = False,
|
||||
clip_barycentric_coords: bool = False,
|
||||
cull_backfaces: bool = False,
|
||||
):
|
||||
self.image_size = image_size
|
||||
@@ -45,6 +47,7 @@ class RasterizationSettings:
|
||||
self.bin_size = bin_size
|
||||
self.max_faces_per_bin = max_faces_per_bin
|
||||
self.perspective_correct = perspective_correct
|
||||
self.clip_barycentric_coords = clip_barycentric_coords
|
||||
self.cull_backfaces = cull_backfaces
|
||||
|
||||
|
||||
@@ -127,6 +130,7 @@ class MeshRasterizer(nn.Module):
|
||||
bin_size=raster_settings.bin_size,
|
||||
max_faces_per_bin=raster_settings.max_faces_per_bin,
|
||||
perspective_correct=raster_settings.perspective_correct,
|
||||
clip_barycentric_coords=raster_settings.clip_barycentric_coords,
|
||||
cull_backfaces=raster_settings.cull_backfaces,
|
||||
)
|
||||
return Fragments(
|
||||
|
||||
@@ -49,21 +49,6 @@ class MeshRenderer(nn.Module):
|
||||
the range for the corresponding face.
|
||||
"""
|
||||
fragments = self.rasterizer(meshes_world, **kwargs)
|
||||
raster_settings = kwargs.get("raster_settings", self.rasterizer.raster_settings)
|
||||
if raster_settings.blur_radius > 0.0:
|
||||
# TODO: potentially move barycentric clipping to the rasterizer
|
||||
# if no downstream functions requires unclipped values.
|
||||
# This will avoid unnecessary re-interpolation of the z buffer.
|
||||
clipped_bary_coords = _clip_barycentric_coordinates(fragments.bary_coords)
|
||||
clipped_zbuf = _interpolate_zbuf(
|
||||
fragments.pix_to_face, clipped_bary_coords, meshes_world
|
||||
)
|
||||
fragments = Fragments(
|
||||
bary_coords=clipped_bary_coords,
|
||||
zbuf=clipped_zbuf,
|
||||
dists=fragments.dists,
|
||||
pix_to_face=fragments.pix_to_face,
|
||||
)
|
||||
images = self.shader(fragments, meshes_world, **kwargs)
|
||||
|
||||
return images
|
||||
|
||||
@@ -20,9 +20,13 @@ def _clip_barycentric_coordinates(bary) -> torch.Tensor:
|
||||
if bary.shape[-1] != 3:
|
||||
msg = "Expected barycentric coords to have last dim = 3; got %r"
|
||||
raise ValueError(msg % (bary.shape,))
|
||||
ndims = bary.ndim - 1
|
||||
mask = bary.eq(-1).all(dim=-1, keepdim=True).expand(*((-1,) * ndims + (3,)))
|
||||
clipped = bary.clamp(min=0.0)
|
||||
clipped[mask] = 0.0
|
||||
clipped_sum = torch.clamp(clipped.sum(dim=-1, keepdim=True), min=1e-5)
|
||||
clipped = clipped / clipped_sum
|
||||
clipped[mask] = -1.0
|
||||
return clipped
|
||||
|
||||
|
||||
@@ -49,6 +53,8 @@ def _interpolate_zbuf(
|
||||
verts = meshes.verts_packed()
|
||||
faces = meshes.faces_packed()
|
||||
faces_verts_z = verts[faces][..., 2][..., None] # (F, 3, 1)
|
||||
return interpolate_face_attributes(pix_to_face, barycentric_coords, faces_verts_z)[
|
||||
zbuf = interpolate_face_attributes(pix_to_face, barycentric_coords, faces_verts_z)[
|
||||
..., 0
|
||||
] # (1, H, W, K)
|
||||
zbuf[pix_to_face == -1] = -1
|
||||
return zbuf
|
||||
|
||||
Reference in New Issue
Block a user