CUDA/C++ Rasterizer updates to handle clipped faces

Summary:
- Updated the C++/CUDA mesh rasterization kernels to handle the clipped faces. In particular this required careful handling of the distance calculation for faces which are cut into a quadrilateral by the image plane and then split into two sub triangles i.e. both sub triangles can't be part of the top K faces.
- Updated `rasterize_meshes.py` to use the utils functions to clip the meshes and convert the fragments back to in terms of the unclipped mesh
- Added end to end tests

Reviewed By: jcjohnson

Differential Revision: D26169685

fbshipit-source-id: d64cd0d656109b965f44a35c301b7c81f451cfa0
This commit is contained in:
Nikhila Ravi
2021-02-08 14:30:55 -08:00
committed by Facebook GitHub Bot
parent 838b73d3b6
commit 340662e98e
12 changed files with 733 additions and 46 deletions

View File

@@ -17,8 +17,8 @@ namespace {
// A structure for holding details about a pixel.
struct Pixel {
float z;
int64_t idx;
float dist;
int64_t idx; // idx of face
float dist; // abs distance of pixel to face
float3 bary;
};
@@ -111,6 +111,7 @@ __device__ bool CheckPointOutsideBoundingBox(
template <typename FaceQ>
__device__ void CheckPixelInsideFace(
const float* face_verts, // (F, 3, 3)
const int64_t* clipped_faces_neighbor_idx, // (F,)
const int face_idx,
int& q_size,
float& q_max_z,
@@ -173,32 +174,72 @@ __device__ void CheckPixelInsideFace(
// face.
const bool inside = p_bary.x > 0.0f && p_bary.y > 0.0f && p_bary.z > 0.0f;
const float signed_dist = inside ? -dist : dist;
// Check if pixel is outside blur region
if (!inside && dist >= blur_radius) {
return;
}
if (q_size < K) {
// Just insert it.
q[q_size] = {pz, face_idx, signed_dist, p_bary_clip};
if (pz > q_max_z) {
q_max_z = pz;
q_max_idx = q_size;
// Handle the case where a face (f) partially behind the image plane is
// clipped to a quadrilateral and then split into two faces (t1, t2). In this
// case we:
// 1. Find the index of the neighboring face (e.g. for t1 need index of t2)
// 2. Check if the neighboring face (t2) is already in the top K faces
// 3. If yes, compare the distance of the pixel to t1 with the distance to t2.
// 4. If dist_t1 < dist_t2, overwrite the values for t2 in the top K faces.
const int neighbor_idx = clipped_faces_neighbor_idx[face_idx];
int neighbor_idx_top_k = -1;
// Check if neighboring face is already in the top K.
// -1 is the fill value in clipped_faces_neighbor_idx
if (neighbor_idx != -1) {
// Only need to loop until q_size.
for (int i = 0; i < q_size; i++) {
if (q[i].idx == neighbor_idx) {
neighbor_idx_top_k = i;
break;
}
}
q_size++;
} else if (pz < q_max_z) {
// Overwrite the old max, and find the new max.
q[q_max_idx] = {pz, face_idx, signed_dist, p_bary_clip};
q_max_z = pz;
for (int i = 0; i < K; i++) {
if (q[i].z > q_max_z) {
q_max_z = q[i].z;
q_max_idx = i;
}
// If neighbor idx is not -1 then it is in the top K struct.
if (neighbor_idx_top_k != -1) {
// If dist of current face is less than neighbor then overwrite the
// neighbor face values in the top K struct.
float neighbor_dist = abs(q[neighbor_idx_top_k].dist);
if (dist < neighbor_dist) {
// Overwrite the neighbor face values
q[neighbor_idx_top_k] = {pz, face_idx, signed_dist, p_bary_clip};
// If pz > q_max then overwrite the max values and index of the max.
// q_size stays the same.
if (pz > q_max_z) {
q_max_z = pz;
q_max_idx = neighbor_idx_top_k;
}
}
} else {
// Handle as a normal face
if (q_size < K) {
// Just insert it.
q[q_size] = {pz, face_idx, signed_dist, p_bary_clip};
if (pz > q_max_z) {
q_max_z = pz;
q_max_idx = q_size;
}
q_size++;
} else if (pz < q_max_z) {
// Overwrite the old max, and find the new max.
q[q_max_idx] = {pz, face_idx, signed_dist, p_bary_clip};
q_max_z = pz;
for (int i = 0; i < K; i++) {
if (q[i].z > q_max_z) {
q_max_z = q[i].z;
q_max_idx = i;
}
}
}
}
}
} // namespace
// ****************************************************************************
@@ -208,6 +249,7 @@ __global__ void RasterizeMeshesNaiveCudaKernel(
const float* face_verts,
const int64_t* mesh_to_face_first_idx,
const int64_t* num_faces_per_mesh,
const int64_t* clipped_faces_neighbor_idx,
const float blur_radius,
const bool perspective_correct,
const bool clip_barycentric_coords,
@@ -265,6 +307,7 @@ __global__ void RasterizeMeshesNaiveCudaKernel(
CheckPixelInsideFace(
face_verts,
clipped_faces_neighbor_idx,
f,
q_size,
q_max_z,
@@ -298,6 +341,7 @@ RasterizeMeshesNaiveCuda(
const at::Tensor& face_verts,
const at::Tensor& mesh_to_faces_packed_first_idx,
const at::Tensor& num_faces_per_mesh,
const at::Tensor& clipped_faces_neighbor_idx,
const std::tuple<int, int> image_size,
const float blur_radius,
const int num_closest,
@@ -313,6 +357,10 @@ RasterizeMeshesNaiveCuda(
num_faces_per_mesh.size(0) == mesh_to_faces_packed_first_idx.size(0),
"num_faces_per_mesh must have save size first dimension as mesh_to_faces_packed_first_idx");
TORCH_CHECK(
clipped_faces_neighbor_idx.size(0) == face_verts.size(0),
"clipped_faces_neighbor_idx must have save size first dimension as face_verts");
if (num_closest > kMaxPointsPerPixel) {
std::stringstream ss;
ss << "Must have points_per_pixel <= " << kMaxPointsPerPixel;
@@ -323,11 +371,16 @@ RasterizeMeshesNaiveCuda(
at::TensorArg face_verts_t{face_verts, "face_verts", 1},
mesh_to_faces_packed_first_idx_t{
mesh_to_faces_packed_first_idx, "mesh_to_faces_packed_first_idx", 2},
num_faces_per_mesh_t{num_faces_per_mesh, "num_faces_per_mesh", 3};
num_faces_per_mesh_t{num_faces_per_mesh, "num_faces_per_mesh", 3},
clipped_faces_neighbor_idx_t{
clipped_faces_neighbor_idx, "clipped_faces_neighbor_idx", 4};
at::CheckedFrom c = "RasterizeMeshesNaiveCuda";
at::checkAllSameGPU(
c,
{face_verts_t, mesh_to_faces_packed_first_idx_t, num_faces_per_mesh_t});
{face_verts_t,
mesh_to_faces_packed_first_idx_t,
num_faces_per_mesh_t,
clipped_faces_neighbor_idx_t});
// Set the device for the kernel launch based on the device of the input
at::cuda::CUDAGuard device_guard(face_verts.device());
@@ -358,6 +411,7 @@ RasterizeMeshesNaiveCuda(
face_verts.contiguous().data_ptr<float>(),
mesh_to_faces_packed_first_idx.contiguous().data_ptr<int64_t>(),
num_faces_per_mesh.contiguous().data_ptr<int64_t>(),
clipped_faces_neighbor_idx.contiguous().data_ptr<int64_t>(),
blur_radius,
perspective_correct,
clip_barycentric_coords,
@@ -800,6 +854,7 @@ at::Tensor RasterizeMeshesCoarseCuda(
__global__ void RasterizeMeshesFineCudaKernel(
const float* face_verts, // (F, 3, 3)
const int32_t* bin_faces, // (N, BH, BW, T)
const int64_t* clipped_faces_neighbor_idx, // (F,)
const float blur_radius,
const int bin_size,
const bool perspective_correct,
@@ -858,6 +913,7 @@ __global__ void RasterizeMeshesFineCudaKernel(
int q_size = 0;
float q_max_z = -1000;
int q_max_idx = -1;
for (int m = 0; m < M; m++) {
const int f = bin_faces[n * BH * BW * M + by * BW * M + bx * M + m];
if (f < 0) {
@@ -867,6 +923,7 @@ __global__ void RasterizeMeshesFineCudaKernel(
// update q, q_size, q_max_z and q_max_idx in place.
CheckPixelInsideFace(
face_verts,
clipped_faces_neighbor_idx,
f,
q_size,
q_max_z,
@@ -906,6 +963,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
RasterizeMeshesFineCuda(
const at::Tensor& face_verts,
const at::Tensor& bin_faces,
const at::Tensor& clipped_faces_neighbor_idx,
const std::tuple<int, int> image_size,
const float blur_radius,
const int bin_size,
@@ -918,12 +976,18 @@ RasterizeMeshesFineCuda(
face_verts.size(2) == 3,
"face_verts must have dimensions (num_faces, 3, 3)");
TORCH_CHECK(bin_faces.ndimension() == 4, "bin_faces must have 4 dimensions");
TORCH_CHECK(
clipped_faces_neighbor_idx.size(0) == face_verts.size(0),
"clipped_faces_neighbor_idx must have the same first dimension as face_verts");
// Check inputs are on the same device
at::TensorArg face_verts_t{face_verts, "face_verts", 1},
bin_faces_t{bin_faces, "bin_faces", 2};
bin_faces_t{bin_faces, "bin_faces", 2},
clipped_faces_neighbor_idx_t{
clipped_faces_neighbor_idx, "clipped_faces_neighbor_idx", 3};
at::CheckedFrom c = "RasterizeMeshesFineCuda";
at::checkAllSameGPU(c, {face_verts_t, bin_faces_t});
at::checkAllSameGPU(
c, {face_verts_t, bin_faces_t, clipped_faces_neighbor_idx_t});
// Set the device for the kernel launch based on the device of the input
at::cuda::CUDAGuard device_guard(face_verts.device());
@@ -961,6 +1025,7 @@ RasterizeMeshesFineCuda(
RasterizeMeshesFineCudaKernel<<<blocks, threads, 0, stream>>>(
face_verts.contiguous().data_ptr<float>(),
bin_faces.contiguous().data_ptr<int32_t>(),
clipped_faces_neighbor_idx.contiguous().data_ptr<int64_t>(),
blur_radius,
bin_size,
perspective_correct,

View File

@@ -15,6 +15,7 @@ RasterizeMeshesNaiveCpu(
const torch::Tensor& face_verts,
const torch::Tensor& mesh_to_face_first_idx,
const torch::Tensor& num_faces_per_mesh,
const torch::Tensor& clipped_faces_neighbor_idx,
const std::tuple<int, int> image_size,
const float blur_radius,
const int faces_per_pixel,
@@ -28,6 +29,7 @@ RasterizeMeshesNaiveCuda(
const at::Tensor& face_verts,
const at::Tensor& mesh_to_face_first_idx,
const at::Tensor& num_faces_per_mesh,
const torch::Tensor& clipped_faces_neighbor_idx,
const std::tuple<int, int> image_size,
const float blur_radius,
const int num_closest,
@@ -48,6 +50,12 @@ RasterizeMeshesNaiveCuda(
// the batch where N is the batch size.
// num_faces_per_mesh: LongTensor of shape (N) giving the number of faces
// for each mesh in the batch.
// clipped_faces_neighbor_idx: LongTensor of shape (F,) giving the
// index of the neighboring face for each face which was clipped to a
// quadrilateral and then divided into two triangles.
// e.g. for a face f partially behind the image plane which is split into
// two triangles (t1, t2): clipped_faces_neighbor_idx[t1_idx] = t2_idx
// Faces which are not clipped and subdivided are set to -1.
// image_size: Tuple (H, W) giving the size in pixels of the output
// image to be rasterized.
// blur_radius: float distance in NDC coordinates uses to expand the face
@@ -90,6 +98,7 @@ RasterizeMeshesNaive(
const torch::Tensor& face_verts,
const torch::Tensor& mesh_to_face_first_idx,
const torch::Tensor& num_faces_per_mesh,
const torch::Tensor& clipped_faces_neighbor_idx,
const std::tuple<int, int> image_size,
const float blur_radius,
const int faces_per_pixel,
@@ -106,6 +115,7 @@ RasterizeMeshesNaive(
face_verts,
mesh_to_face_first_idx,
num_faces_per_mesh,
clipped_faces_neighbor_idx,
image_size,
blur_radius,
faces_per_pixel,
@@ -120,6 +130,7 @@ RasterizeMeshesNaive(
face_verts,
mesh_to_face_first_idx,
num_faces_per_mesh,
clipped_faces_neighbor_idx,
image_size,
blur_radius,
faces_per_pixel,
@@ -306,6 +317,7 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
RasterizeMeshesFineCuda(
const torch::Tensor& face_verts,
const torch::Tensor& bin_faces,
const torch::Tensor& clipped_faces_neighbor_idx,
const std::tuple<int, int> image_size,
const float blur_radius,
const int bin_size,
@@ -322,6 +334,12 @@ RasterizeMeshesFineCuda(
// in NDC coordinates in the range [-1, 1].
// bin_faces: int32 Tensor of shape (N, B, B, M) giving the indices of faces
// that fall into each bin (output from coarse rasterization).
// clipped_faces_neighbor_idx: LongTensor of shape (F,) giving the
// index of the neighboring face for each face which was clipped to a
// quadrilateral and then divided into two triangles.
// e.g. for a face f partially behind the image plane which is split into
// two triangles (t1, t2): clipped_faces_neighbor_idx[t1_idx] = t2_idx
// Faces which are not clipped and subdivided are set to -1.
// image_size: Tuple (H, W) giving the size in pixels of the output
// image to be rasterized.
// blur_radius: float distance in NDC coordinates uses to expand the face
@@ -364,6 +382,7 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
RasterizeMeshesFine(
const torch::Tensor& face_verts,
const torch::Tensor& bin_faces,
const torch::Tensor& clipped_faces_neighbor_idx,
const std::tuple<int, int> image_size,
const float blur_radius,
const int bin_size,
@@ -378,6 +397,7 @@ RasterizeMeshesFine(
return RasterizeMeshesFineCuda(
face_verts,
bin_faces,
clipped_faces_neighbor_idx,
image_size,
blur_radius,
bin_size,
@@ -411,6 +431,12 @@ RasterizeMeshesFine(
// the batch where N is the batch size.
// num_faces_per_mesh: LongTensor of shape (N) giving the number of faces
// for each mesh in the batch.
// clipped_faces_neighbor_idx: LongTensor of shape (F,) giving the
// index of the neighboring face for each face which was clipped to a
// quadrilateral and then divided into two triangles.
// e.g. for a face f partially behind the image plane which is split into
// two triangles (t1, t2): clipped_faces_neighbor_idx[t1_idx] = t2_idx
// Faces which are not clipped and subdivided are set to -1.
// image_size: Tuple (H, W) giving the size in pixels of the output
// image to be rasterized.
// blur_radius: float distance in NDC coordinates uses to expand the face
@@ -456,6 +482,7 @@ RasterizeMeshes(
const torch::Tensor& face_verts,
const torch::Tensor& mesh_to_face_first_idx,
const torch::Tensor& num_faces_per_mesh,
const torch::Tensor& clipped_faces_neighbor_idx,
const std::tuple<int, int> image_size,
const float blur_radius,
const int faces_per_pixel,
@@ -477,6 +504,7 @@ RasterizeMeshes(
return RasterizeMeshesFine(
face_verts,
bin_faces,
clipped_faces_neighbor_idx,
image_size,
blur_radius,
bin_size,
@@ -490,6 +518,7 @@ RasterizeMeshes(
face_verts,
mesh_to_face_first_idx,
num_faces_per_mesh,
clipped_faces_neighbor_idx,
image_size,
blur_radius,
faces_per_pixel,

View File

@@ -99,11 +99,31 @@ auto ComputeFaceAreas(const torch::Tensor& face_verts) {
return face_areas;
}
// Helper function to use with std::find_if to find the index of any
// values in the top k struct which match a given idx.
struct IsNeighbor {
IsNeighbor(int neighbor_idx) {
this->neighbor_idx = neighbor_idx;
}
bool operator()(std::tuple<float, int, float, float, float, float> elem) {
return (std::get<1>(elem) == neighbor_idx);
}
int neighbor_idx;
};
// Function to sort based on the z distance in the top K queue
bool SortTopKByZdist(
std::tuple<float, int, float, float, float, float> a,
std::tuple<float, int, float, float, float, float> b) {
return std::get<0>(a) < std::get<0>(b);
}
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
RasterizeMeshesNaiveCpu(
const torch::Tensor& face_verts,
const torch::Tensor& mesh_to_face_first_idx,
const torch::Tensor& num_faces_per_mesh,
const torch::Tensor& clipped_faces_neighbor_idx,
const std::tuple<int, int> image_size,
const float blur_radius,
const int faces_per_pixel,
@@ -139,6 +159,7 @@ RasterizeMeshesNaiveCpu(
auto zbuf_a = zbuf.accessor<float, 4>();
auto pix_dists_a = pix_dists.accessor<float, 4>();
auto barycentric_coords_a = barycentric_coords.accessor<float, 5>();
auto neighbor_idx_a = clipped_faces_neighbor_idx.accessor<int64_t, 1>();
auto face_bboxes = ComputeFaceBoundingBoxes(face_verts);
auto face_bboxes_a = face_bboxes.accessor<float, 2>();
@@ -168,10 +189,11 @@ RasterizeMeshesNaiveCpu(
// X coordinate of the left of the pixel.
const float xf = PixToNonSquareNdc(xidx, W, H);
// Use a priority queue to hold values:
// Use a deque to hold values:
// (z, idx, r, bary.x, bary.y. bary.z)
std::priority_queue<std::tuple<float, int, float, float, float, float>>
q;
// Sort the deque as needed to mimic a priority queue.
std::deque<std::tuple<float, int, float, float, float, float>> q;
// Loop through the faces in the mesh.
for (int f = face_start_idx; f < face_stop_idx; ++f) {
@@ -240,15 +262,58 @@ RasterizeMeshesNaiveCpu(
if (!inside && dist >= blur_radius) {
continue;
}
// The current pixel lies inside the current face.
q.emplace(pz, f, signed_dist, bary_clip.x, bary_clip.y, bary_clip.z);
// Handle the case where a face (f) partially behind the image plane
// is clipped to a quadrilateral and then split into two faces (t1,
// t2). In this case we:
// 1. Find the index of the neighbor (e.g. for t1 need index of t2)
// 2. Check if the neighbor (t2) is already in the top K faces
// 3. If yes, compare the distance of the pixel to t1 with the
// distance to t2.
// 4. If dist_t1 < dist_t2, overwrite the values for t2 in the top K
// faces.
const int neighbor_idx = neighbor_idx_a[f];
int idx_top_k = -1;
// Check if neighboring face is already in the top K.
if (neighbor_idx != -1) {
const auto it =
std::find_if(q.begin(), q.end(), IsNeighbor(neighbor_idx));
// Get the index of the element from the iterator
idx_top_k = (it != q.end()) ? it - q.begin() : idx_top_k;
}
// If idx_top_k idx is not -1 then it is in the top K struct.
if (idx_top_k != -1) {
// If dist of current face is less than neighbor, overwrite
// the neighbor face values in the top K struct.
const auto neighbor = q[idx_top_k];
const float dist_neighbor = std::abs(std::get<2>(neighbor));
if (dist < dist_neighbor) {
// Overwrite the neighbor face values.
q[idx_top_k] = {
pz, f, signed_dist, bary_clip.x, bary_clip.y, bary_clip.z};
}
} else {
// Handle as a normal face.
// The current pixel lies inside the current face.
// Add at the end of the deque.
q.emplace_back(
pz, f, signed_dist, bary_clip.x, bary_clip.y, bary_clip.z);
}
// Sort the deque inplace based on the z distance
// to mimic using a priority queue.
std::sort(q.begin(), q.end(), SortTopKByZdist);
if (static_cast<int>(q.size()) > K) {
q.pop();
// remove the last value
q.pop_back();
}
}
while (!q.empty()) {
auto t = q.top();
q.pop();
// Loop through and add values to the output tensors
auto t = q.back();
q.pop_back();
const int i = q.size();
zbuf_a[n][yi][xi][i] = std::get<0>(t);
face_idxs_a[n][yi][xi][i] = std::get<1>(t);