mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2026-03-11 14:55:59 +08:00
barycentric clipping in cuda/c++
Summary: Added support for barycentric clipping in the C++/CUDA rasterization kernels, which can be switched on/off via a rasterization setting. Added tests and a benchmark to compare with the current implementation in PyTorch; for some cases with a large image size or many faces per pixel, the CUDA version is 10x faster. Reviewed By: gkioxari Differential Revision: D21705503 fbshipit-source-id: e835c0f927f1e5088ca89020aef5ff27ac3a8769
This commit is contained in:
committed by
Facebook GitHub Bot
parent
bce396df93
commit
cc70950f40
@@ -114,6 +114,7 @@ __device__ void CheckPixelInsideFace(
|
||||
const float2 pxy, // Coordinates of the pixel
|
||||
const int K,
|
||||
const bool perspective_correct,
|
||||
const bool clip_barycentric_coords,
|
||||
const bool cull_backfaces) {
|
||||
const auto v012 = GetSingleFaceVerts(face_verts, face_idx);
|
||||
const float3 v0 = thrust::get<0>(v012);
|
||||
@@ -149,8 +150,12 @@ __device__ void CheckPixelInsideFace(
|
||||
const float3 p_bary = !perspective_correct
|
||||
? p_bary0
|
||||
: BarycentricPerspectiveCorrectionForward(p_bary0, v0.z, v1.z, v2.z);
|
||||
const float3 p_bary_clip =
|
||||
!clip_barycentric_coords ? p_bary : BarycentricClipForward(p_bary);
|
||||
|
||||
const float pz =
|
||||
p_bary_clip.x * v0.z + p_bary_clip.y * v1.z + p_bary_clip.z * v2.z;
|
||||
|
||||
const float pz = p_bary.x * v0.z + p_bary.y * v1.z + p_bary.z * v2.z;
|
||||
if (pz < 0) {
|
||||
return; // Face is behind the image plane.
|
||||
}
|
||||
@@ -158,7 +163,8 @@ __device__ void CheckPixelInsideFace(
|
||||
// Get abs squared distance
|
||||
const float dist = PointTriangleDistanceForward(pxy, v0xy, v1xy, v2xy);
|
||||
|
||||
// Use the bary coordinates to determine if the point is inside the face.
|
||||
// Use the unclipped bary coordinates to determine if the point is inside the
|
||||
// face.
|
||||
const bool inside = p_bary.x > 0.0f && p_bary.y > 0.0f && p_bary.z > 0.0f;
|
||||
const float signed_dist = inside ? -dist : dist;
|
||||
|
||||
@@ -169,7 +175,7 @@ __device__ void CheckPixelInsideFace(
|
||||
|
||||
if (q_size < K) {
|
||||
// Just insert it.
|
||||
q[q_size] = {pz, face_idx, signed_dist, p_bary};
|
||||
q[q_size] = {pz, face_idx, signed_dist, p_bary_clip};
|
||||
if (pz > q_max_z) {
|
||||
q_max_z = pz;
|
||||
q_max_idx = q_size;
|
||||
@@ -177,7 +183,7 @@ __device__ void CheckPixelInsideFace(
|
||||
q_size++;
|
||||
} else if (pz < q_max_z) {
|
||||
// Overwrite the old max, and find the new max.
|
||||
q[q_max_idx] = {pz, face_idx, signed_dist, p_bary};
|
||||
q[q_max_idx] = {pz, face_idx, signed_dist, p_bary_clip};
|
||||
q_max_z = pz;
|
||||
for (int i = 0; i < K; i++) {
|
||||
if (q[i].z > q_max_z) {
|
||||
@@ -198,6 +204,7 @@ __global__ void RasterizeMeshesNaiveCudaKernel(
|
||||
const int64_t* num_faces_per_mesh,
|
||||
const float blur_radius,
|
||||
const bool perspective_correct,
|
||||
const bool clip_barycentric_coords,
|
||||
const bool cull_backfaces,
|
||||
const int N,
|
||||
const int H,
|
||||
@@ -260,6 +267,7 @@ __global__ void RasterizeMeshesNaiveCudaKernel(
|
||||
pxy,
|
||||
K,
|
||||
perspective_correct,
|
||||
clip_barycentric_coords,
|
||||
cull_backfaces);
|
||||
}
|
||||
|
||||
@@ -286,6 +294,7 @@ RasterizeMeshesNaiveCuda(
|
||||
const float blur_radius,
|
||||
const int num_closest,
|
||||
const bool perspective_correct,
|
||||
const bool clip_barycentric_coords,
|
||||
const bool cull_backfaces) {
|
||||
TORCH_CHECK(
|
||||
face_verts.ndimension() == 3 && face_verts.size(1) == 3 &&
|
||||
@@ -343,6 +352,7 @@ RasterizeMeshesNaiveCuda(
|
||||
num_faces_per_mesh.contiguous().data_ptr<int64_t>(),
|
||||
blur_radius,
|
||||
perspective_correct,
|
||||
clip_barycentric_coords,
|
||||
cull_backfaces,
|
||||
N,
|
||||
H,
|
||||
@@ -365,6 +375,7 @@ __global__ void RasterizeMeshesBackwardCudaKernel(
|
||||
const float* face_verts, // (F, 3, 3)
|
||||
const int64_t* pix_to_face, // (N, H, W, K)
|
||||
const bool perspective_correct,
|
||||
const bool clip_barycentric_coords,
|
||||
const int N,
|
||||
const int H,
|
||||
const int W,
|
||||
@@ -422,11 +433,15 @@ __global__ void RasterizeMeshesBackwardCudaKernel(
|
||||
const float3 grad_bary_upstream = make_float3(
|
||||
grad_bary_upstream_w0, grad_bary_upstream_w1, grad_bary_upstream_w2);
|
||||
|
||||
const float3 bary0 = BarycentricCoordsForward(pxy, v0xy, v1xy, v2xy);
|
||||
const float3 bary = !perspective_correct
|
||||
? bary0
|
||||
: BarycentricPerspectiveCorrectionForward(bary0, v0.z, v1.z, v2.z);
|
||||
const bool inside = bary.x > 0.0f && bary.y > 0.0f && bary.z > 0.0f;
|
||||
const float3 b_w = BarycentricCoordsForward(pxy, v0xy, v1xy, v2xy);
|
||||
const float3 b_pp = !perspective_correct
|
||||
? b_w
|
||||
: BarycentricPerspectiveCorrectionForward(b_w, v0.z, v1.z, v2.z);
|
||||
|
||||
const float3 b_w_clip =
|
||||
!clip_barycentric_coords ? b_pp : BarycentricClipForward(b_pp);
|
||||
|
||||
const bool inside = b_pp.x > 0.0f && b_pp.y > 0.0f && b_pp.z > 0.0f;
|
||||
const float sign = inside ? -1.0f : 1.0f;
|
||||
|
||||
// TODO(T52813608) Add support for non-square images.
|
||||
@@ -442,22 +457,29 @@ __global__ void RasterizeMeshesBackwardCudaKernel(
|
||||
// d_zbuf/d_bary_w0 = z0
|
||||
// d_zbuf/d_bary_w1 = z1
|
||||
// d_zbuf/d_bary_w2 = z2
|
||||
const float3 d_zbuf_d_bary = make_float3(v0.z, v1.z, v2.z);
|
||||
const float3 d_zbuf_d_bwclip = make_float3(v0.z, v1.z, v2.z);
|
||||
|
||||
// Total upstream barycentric gradients are the sum of
|
||||
// external upstream gradients and contribution from zbuf.
|
||||
const float3 grad_bary_f_sum =
|
||||
(grad_bary_upstream + grad_zbuf_upstream * d_zbuf_d_bary);
|
||||
(grad_bary_upstream + grad_zbuf_upstream * d_zbuf_d_bwclip);
|
||||
|
||||
float3 grad_bary0 = grad_bary_f_sum;
|
||||
|
||||
if (clip_barycentric_coords) {
|
||||
grad_bary0 = BarycentricClipBackward(b_w, grad_bary_f_sum);
|
||||
}
|
||||
|
||||
float dz0_persp = 0.0f, dz1_persp = 0.0f, dz2_persp = 0.0f;
|
||||
if (perspective_correct) {
|
||||
auto perspective_grads = BarycentricPerspectiveCorrectionBackward(
|
||||
bary0, v0.z, v1.z, v2.z, grad_bary_f_sum);
|
||||
b_w, v0.z, v1.z, v2.z, grad_bary0);
|
||||
grad_bary0 = thrust::get<0>(perspective_grads);
|
||||
dz0_persp = thrust::get<1>(perspective_grads);
|
||||
dz1_persp = thrust::get<2>(perspective_grads);
|
||||
dz2_persp = thrust::get<3>(perspective_grads);
|
||||
}
|
||||
|
||||
auto grad_bary_f =
|
||||
BarycentricCoordsBackward(pxy, v0xy, v1xy, v2xy, grad_bary0);
|
||||
const float2 dbary_d_v0 = thrust::get<1>(grad_bary_f);
|
||||
@@ -467,15 +489,18 @@ __global__ void RasterizeMeshesBackwardCudaKernel(
|
||||
atomicAdd(grad_face_verts + f * 9 + 0, dbary_d_v0.x + ddist_d_v0.x);
|
||||
atomicAdd(grad_face_verts + f * 9 + 1, dbary_d_v0.y + ddist_d_v0.y);
|
||||
atomicAdd(
|
||||
grad_face_verts + f * 9 + 2, grad_zbuf_upstream * bary.x + dz0_persp);
|
||||
grad_face_verts + f * 9 + 2,
|
||||
grad_zbuf_upstream * b_w_clip.x + dz0_persp);
|
||||
atomicAdd(grad_face_verts + f * 9 + 3, dbary_d_v1.x + ddist_d_v1.x);
|
||||
atomicAdd(grad_face_verts + f * 9 + 4, dbary_d_v1.y + ddist_d_v1.y);
|
||||
atomicAdd(
|
||||
grad_face_verts + f * 9 + 5, grad_zbuf_upstream * bary.y + dz1_persp);
|
||||
grad_face_verts + f * 9 + 5,
|
||||
grad_zbuf_upstream * b_w_clip.y + dz1_persp);
|
||||
atomicAdd(grad_face_verts + f * 9 + 6, dbary_d_v2.x + ddist_d_v2.x);
|
||||
atomicAdd(grad_face_verts + f * 9 + 7, dbary_d_v2.y + ddist_d_v2.y);
|
||||
atomicAdd(
|
||||
grad_face_verts + f * 9 + 8, grad_zbuf_upstream * bary.z + dz2_persp);
|
||||
grad_face_verts + f * 9 + 8,
|
||||
grad_zbuf_upstream * b_w_clip.z + dz2_persp);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -486,7 +511,8 @@ at::Tensor RasterizeMeshesBackwardCuda(
|
||||
const at::Tensor& grad_zbuf, // (N, H, W, K)
|
||||
const at::Tensor& grad_bary, // (N, H, W, K, 3)
|
||||
const at::Tensor& grad_dists, // (N, H, W, K)
|
||||
const bool perspective_correct) {
|
||||
const bool perspective_correct,
|
||||
const bool clip_barycentric_coords) {
|
||||
// Check inputs are on the same device
|
||||
at::TensorArg face_verts_t{face_verts, "face_verts", 1},
|
||||
pix_to_face_t{pix_to_face, "pix_to_face", 2},
|
||||
@@ -523,6 +549,7 @@ at::Tensor RasterizeMeshesBackwardCuda(
|
||||
face_verts.contiguous().data_ptr<float>(),
|
||||
pix_to_face.contiguous().data_ptr<int64_t>(),
|
||||
perspective_correct,
|
||||
clip_barycentric_coords,
|
||||
N,
|
||||
H,
|
||||
W,
|
||||
@@ -743,6 +770,7 @@ __global__ void RasterizeMeshesFineCudaKernel(
|
||||
const float blur_radius,
|
||||
const int bin_size,
|
||||
const bool perspective_correct,
|
||||
const bool clip_barycentric_coords,
|
||||
const bool cull_backfaces,
|
||||
const int N,
|
||||
const int B,
|
||||
@@ -808,6 +836,7 @@ __global__ void RasterizeMeshesFineCudaKernel(
|
||||
pxy,
|
||||
K,
|
||||
perspective_correct,
|
||||
clip_barycentric_coords,
|
||||
cull_backfaces);
|
||||
}
|
||||
|
||||
@@ -841,6 +870,7 @@ RasterizeMeshesFineCuda(
|
||||
const int bin_size,
|
||||
const int faces_per_pixel,
|
||||
const bool perspective_correct,
|
||||
const bool clip_barycentric_coords,
|
||||
const bool cull_backfaces) {
|
||||
TORCH_CHECK(
|
||||
face_verts.ndimension() == 3 && face_verts.size(1) == 3 &&
|
||||
@@ -890,6 +920,7 @@ RasterizeMeshesFineCuda(
|
||||
blur_radius,
|
||||
bin_size,
|
||||
perspective_correct,
|
||||
clip_barycentric_coords,
|
||||
cull_backfaces,
|
||||
N,
|
||||
B,
|
||||
|
||||
Reference in New Issue
Block a user