mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2026-03-11 14:55:59 +08:00
CUDA/C++ Rasterizer updates to handle clipped faces
Summary: - Updated the C++/CUDA mesh rasterization kernels to handle the clipped faces. In particular this required careful handling of the distance calculation for faces which are cut into a quadrilateral by the image plane and then split into two sub triangles i.e. both sub triangles can't be part of the top K faces. - Updated `rasterize_meshes.py` to use the utils functions to clip the meshes and convert the fragments back to in terms of the unclipped mesh - Added end to end tests Reviewed By: jcjohnson Differential Revision: D26169685 fbshipit-source-id: d64cd0d656109b965f44a35c301b7c81f451cfa0
This commit is contained in:
committed by
Facebook GitHub Bot
parent
838b73d3b6
commit
340662e98e
@@ -17,8 +17,8 @@ namespace {
|
||||
// A structure for holding details about a pixel.
|
||||
struct Pixel {
|
||||
float z;
|
||||
int64_t idx;
|
||||
float dist;
|
||||
int64_t idx; // idx of face
|
||||
float dist; // abs distance of pixel to face
|
||||
float3 bary;
|
||||
};
|
||||
|
||||
@@ -111,6 +111,7 @@ __device__ bool CheckPointOutsideBoundingBox(
|
||||
template <typename FaceQ>
|
||||
__device__ void CheckPixelInsideFace(
|
||||
const float* face_verts, // (F, 3, 3)
|
||||
const int64_t* clipped_faces_neighbor_idx, // (F,)
|
||||
const int face_idx,
|
||||
int& q_size,
|
||||
float& q_max_z,
|
||||
@@ -173,32 +174,72 @@ __device__ void CheckPixelInsideFace(
|
||||
// face.
|
||||
const bool inside = p_bary.x > 0.0f && p_bary.y > 0.0f && p_bary.z > 0.0f;
|
||||
const float signed_dist = inside ? -dist : dist;
|
||||
|
||||
// Check if pixel is outside blur region
|
||||
if (!inside && dist >= blur_radius) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (q_size < K) {
|
||||
// Just insert it.
|
||||
q[q_size] = {pz, face_idx, signed_dist, p_bary_clip};
|
||||
if (pz > q_max_z) {
|
||||
q_max_z = pz;
|
||||
q_max_idx = q_size;
|
||||
// Handle the case where a face (f) partially behind the image plane is
|
||||
// clipped to a quadrilateral and then split into two faces (t1, t2). In this
|
||||
// case we:
|
||||
// 1. Find the index of the neighboring face (e.g. for t1 need index of t2)
|
||||
// 2. Check if the neighboring face (t2) is already in the top K faces
|
||||
// 3. If yes, compare the distance of the pixel to t1 with the distance to t2.
|
||||
// 4. If dist_t1 < dist_t2, overwrite the values for t2 in the top K faces.
|
||||
const int neighbor_idx = clipped_faces_neighbor_idx[face_idx];
|
||||
int neighbor_idx_top_k = -1;
|
||||
|
||||
// Check if neighboring face is already in the top K.
|
||||
// -1 is the fill value in clipped_faces_neighbor_idx
|
||||
if (neighbor_idx != -1) {
|
||||
// Only need to loop until q_size.
|
||||
for (int i = 0; i < q_size; i++) {
|
||||
if (q[i].idx == neighbor_idx) {
|
||||
neighbor_idx_top_k = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
q_size++;
|
||||
} else if (pz < q_max_z) {
|
||||
// Overwrite the old max, and find the new max.
|
||||
q[q_max_idx] = {pz, face_idx, signed_dist, p_bary_clip};
|
||||
q_max_z = pz;
|
||||
for (int i = 0; i < K; i++) {
|
||||
if (q[i].z > q_max_z) {
|
||||
q_max_z = q[i].z;
|
||||
q_max_idx = i;
|
||||
}
|
||||
// If neighbor idx is not -1 then it is in the top K struct.
|
||||
if (neighbor_idx_top_k != -1) {
|
||||
// If dist of current face is less than neighbor then overwrite the
|
||||
// neighbor face values in the top K struct.
|
||||
float neighbor_dist = abs(q[neighbor_idx_top_k].dist);
|
||||
if (dist < neighbor_dist) {
|
||||
// Overwrite the neighbor face values
|
||||
q[neighbor_idx_top_k] = {pz, face_idx, signed_dist, p_bary_clip};
|
||||
|
||||
// If pz > q_max then overwrite the max values and index of the max.
|
||||
// q_size stays the same.
|
||||
if (pz > q_max_z) {
|
||||
q_max_z = pz;
|
||||
q_max_idx = neighbor_idx_top_k;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Handle as a normal face
|
||||
if (q_size < K) {
|
||||
// Just insert it.
|
||||
q[q_size] = {pz, face_idx, signed_dist, p_bary_clip};
|
||||
if (pz > q_max_z) {
|
||||
q_max_z = pz;
|
||||
q_max_idx = q_size;
|
||||
}
|
||||
q_size++;
|
||||
} else if (pz < q_max_z) {
|
||||
// Overwrite the old max, and find the new max.
|
||||
q[q_max_idx] = {pz, face_idx, signed_dist, p_bary_clip};
|
||||
q_max_z = pz;
|
||||
for (int i = 0; i < K; i++) {
|
||||
if (q[i].z > q_max_z) {
|
||||
q_max_z = q[i].z;
|
||||
q_max_idx = i;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// ****************************************************************************
|
||||
@@ -208,6 +249,7 @@ __global__ void RasterizeMeshesNaiveCudaKernel(
|
||||
const float* face_verts,
|
||||
const int64_t* mesh_to_face_first_idx,
|
||||
const int64_t* num_faces_per_mesh,
|
||||
const int64_t* clipped_faces_neighbor_idx,
|
||||
const float blur_radius,
|
||||
const bool perspective_correct,
|
||||
const bool clip_barycentric_coords,
|
||||
@@ -265,6 +307,7 @@ __global__ void RasterizeMeshesNaiveCudaKernel(
|
||||
|
||||
CheckPixelInsideFace(
|
||||
face_verts,
|
||||
clipped_faces_neighbor_idx,
|
||||
f,
|
||||
q_size,
|
||||
q_max_z,
|
||||
@@ -298,6 +341,7 @@ RasterizeMeshesNaiveCuda(
|
||||
const at::Tensor& face_verts,
|
||||
const at::Tensor& mesh_to_faces_packed_first_idx,
|
||||
const at::Tensor& num_faces_per_mesh,
|
||||
const at::Tensor& clipped_faces_neighbor_idx,
|
||||
const std::tuple<int, int> image_size,
|
||||
const float blur_radius,
|
||||
const int num_closest,
|
||||
@@ -313,6 +357,10 @@ RasterizeMeshesNaiveCuda(
|
||||
num_faces_per_mesh.size(0) == mesh_to_faces_packed_first_idx.size(0),
|
||||
"num_faces_per_mesh must have save size first dimension as mesh_to_faces_packed_first_idx");
|
||||
|
||||
TORCH_CHECK(
|
||||
clipped_faces_neighbor_idx.size(0) == face_verts.size(0),
|
||||
"clipped_faces_neighbor_idx must have save size first dimension as face_verts");
|
||||
|
||||
if (num_closest > kMaxPointsPerPixel) {
|
||||
std::stringstream ss;
|
||||
ss << "Must have points_per_pixel <= " << kMaxPointsPerPixel;
|
||||
@@ -323,11 +371,16 @@ RasterizeMeshesNaiveCuda(
|
||||
at::TensorArg face_verts_t{face_verts, "face_verts", 1},
|
||||
mesh_to_faces_packed_first_idx_t{
|
||||
mesh_to_faces_packed_first_idx, "mesh_to_faces_packed_first_idx", 2},
|
||||
num_faces_per_mesh_t{num_faces_per_mesh, "num_faces_per_mesh", 3};
|
||||
num_faces_per_mesh_t{num_faces_per_mesh, "num_faces_per_mesh", 3},
|
||||
clipped_faces_neighbor_idx_t{
|
||||
clipped_faces_neighbor_idx, "clipped_faces_neighbor_idx", 4};
|
||||
at::CheckedFrom c = "RasterizeMeshesNaiveCuda";
|
||||
at::checkAllSameGPU(
|
||||
c,
|
||||
{face_verts_t, mesh_to_faces_packed_first_idx_t, num_faces_per_mesh_t});
|
||||
{face_verts_t,
|
||||
mesh_to_faces_packed_first_idx_t,
|
||||
num_faces_per_mesh_t,
|
||||
clipped_faces_neighbor_idx_t});
|
||||
|
||||
// Set the device for the kernel launch based on the device of the input
|
||||
at::cuda::CUDAGuard device_guard(face_verts.device());
|
||||
@@ -358,6 +411,7 @@ RasterizeMeshesNaiveCuda(
|
||||
face_verts.contiguous().data_ptr<float>(),
|
||||
mesh_to_faces_packed_first_idx.contiguous().data_ptr<int64_t>(),
|
||||
num_faces_per_mesh.contiguous().data_ptr<int64_t>(),
|
||||
clipped_faces_neighbor_idx.contiguous().data_ptr<int64_t>(),
|
||||
blur_radius,
|
||||
perspective_correct,
|
||||
clip_barycentric_coords,
|
||||
@@ -800,6 +854,7 @@ at::Tensor RasterizeMeshesCoarseCuda(
|
||||
__global__ void RasterizeMeshesFineCudaKernel(
|
||||
const float* face_verts, // (F, 3, 3)
|
||||
const int32_t* bin_faces, // (N, BH, BW, T)
|
||||
const int64_t* clipped_faces_neighbor_idx, // (F,)
|
||||
const float blur_radius,
|
||||
const int bin_size,
|
||||
const bool perspective_correct,
|
||||
@@ -858,6 +913,7 @@ __global__ void RasterizeMeshesFineCudaKernel(
|
||||
int q_size = 0;
|
||||
float q_max_z = -1000;
|
||||
int q_max_idx = -1;
|
||||
|
||||
for (int m = 0; m < M; m++) {
|
||||
const int f = bin_faces[n * BH * BW * M + by * BW * M + bx * M + m];
|
||||
if (f < 0) {
|
||||
@@ -867,6 +923,7 @@ __global__ void RasterizeMeshesFineCudaKernel(
|
||||
// update q, q_size, q_max_z and q_max_idx in place.
|
||||
CheckPixelInsideFace(
|
||||
face_verts,
|
||||
clipped_faces_neighbor_idx,
|
||||
f,
|
||||
q_size,
|
||||
q_max_z,
|
||||
@@ -906,6 +963,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
|
||||
RasterizeMeshesFineCuda(
|
||||
const at::Tensor& face_verts,
|
||||
const at::Tensor& bin_faces,
|
||||
const at::Tensor& clipped_faces_neighbor_idx,
|
||||
const std::tuple<int, int> image_size,
|
||||
const float blur_radius,
|
||||
const int bin_size,
|
||||
@@ -918,12 +976,18 @@ RasterizeMeshesFineCuda(
|
||||
face_verts.size(2) == 3,
|
||||
"face_verts must have dimensions (num_faces, 3, 3)");
|
||||
TORCH_CHECK(bin_faces.ndimension() == 4, "bin_faces must have 4 dimensions");
|
||||
TORCH_CHECK(
|
||||
clipped_faces_neighbor_idx.size(0) == face_verts.size(0),
|
||||
"clipped_faces_neighbor_idx must have the same first dimension as face_verts");
|
||||
|
||||
// Check inputs are on the same device
|
||||
at::TensorArg face_verts_t{face_verts, "face_verts", 1},
|
||||
bin_faces_t{bin_faces, "bin_faces", 2};
|
||||
bin_faces_t{bin_faces, "bin_faces", 2},
|
||||
clipped_faces_neighbor_idx_t{
|
||||
clipped_faces_neighbor_idx, "clipped_faces_neighbor_idx", 3};
|
||||
at::CheckedFrom c = "RasterizeMeshesFineCuda";
|
||||
at::checkAllSameGPU(c, {face_verts_t, bin_faces_t});
|
||||
at::checkAllSameGPU(
|
||||
c, {face_verts_t, bin_faces_t, clipped_faces_neighbor_idx_t});
|
||||
|
||||
// Set the device for the kernel launch based on the device of the input
|
||||
at::cuda::CUDAGuard device_guard(face_verts.device());
|
||||
@@ -961,6 +1025,7 @@ RasterizeMeshesFineCuda(
|
||||
RasterizeMeshesFineCudaKernel<<<blocks, threads, 0, stream>>>(
|
||||
face_verts.contiguous().data_ptr<float>(),
|
||||
bin_faces.contiguous().data_ptr<int32_t>(),
|
||||
clipped_faces_neighbor_idx.contiguous().data_ptr<int64_t>(),
|
||||
blur_radius,
|
||||
bin_size,
|
||||
perspective_correct,
|
||||
|
||||
Reference in New Issue
Block a user