CUDA/C++ Rasterizer updates to handle clipped faces

Summary:
- Updated the C++/CUDA mesh rasterization kernels to handle clipped faces. In particular, this required careful handling of the distance calculation for faces which are cut into a quadrilateral by the image plane and then split into two sub-triangles, so that both sub-triangles cannot simultaneously be part of the top K faces.
- Updated `rasterize_meshes.py` to use the utility functions to clip the meshes and to convert the fragments back so they are in terms of the unclipped mesh.
- Added end-to-end tests.
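For context, the sketch below illustrates the de-duplication logic described above in plain Python. It is only an illustrative sketch of the approach, not the kernel code: the helper name `update_top_k` and its arguments are hypothetical, and `neighbor_idx` plays the role of `clipped_faces_neighbor_idx[face_idx]` (-1 for faces that were not split).

```python
from typing import List, Tuple

# Each top K entry is (z, face_idx, abs_dist_to_pixel).
TopKEntry = Tuple[float, int, float]

def update_top_k(
    top_k: List[TopKEntry],
    K: int,
    face_idx: int,
    z: float,
    dist: float,
    neighbor_idx: int,  # sibling sub-triangle index, or -1 if not a split face
) -> None:
    """Insert a candidate face so that two sub-triangles originating from the
    same clipped face never occupy two top K slots at once."""
    if neighbor_idx != -1:
        for i, (_, idx, sibling_dist) in enumerate(top_k):
            if idx == neighbor_idx:
                # The sibling sub-triangle is already in the top K: keep
                # whichever of the two is closer to the pixel, never both.
                if dist < sibling_dist:
                    top_k[i] = (z, face_idx, dist)
                    top_k.sort(key=lambda entry: entry[0])
                return
    # No sibling in the top K: treat the candidate as a normal face.
    top_k.append((z, face_idx, dist))
    top_k.sort(key=lambda entry: entry[0])  # keep the list ordered by z
    if len(top_k) > K:
        top_k.pop()  # drop the face farthest from the camera
```

In the diff, the same check appears in `CheckPixelInsideFace` on the CUDA side, and the CPU path mirrors it with a `std::deque` that is re-sorted after each insertion (a `std::priority_queue` cannot be searched for the sibling face, which is why the queue type changed).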

Reviewed By: jcjohnson

Differential Revision: D26169685

fbshipit-source-id: d64cd0d656109b965f44a35c301b7c81f451cfa0
Nikhila Ravi 2021-02-08 14:30:55 -08:00 committed by Facebook GitHub Bot
parent 838b73d3b6
commit 340662e98e
12 changed files with 733 additions and 46 deletions

View File

@ -17,8 +17,8 @@ namespace {
// A structure for holding details about a pixel.
struct Pixel {
float z;
int64_t idx;
float dist;
int64_t idx; // idx of face
float dist; // abs distance of pixel to face
float3 bary;
};
@ -111,6 +111,7 @@ __device__ bool CheckPointOutsideBoundingBox(
template <typename FaceQ>
__device__ void CheckPixelInsideFace(
const float* face_verts, // (F, 3, 3)
const int64_t* clipped_faces_neighbor_idx, // (F,)
const int face_idx,
int& q_size,
float& q_max_z,
@ -173,12 +174,50 @@ __device__ void CheckPixelInsideFace(
// face.
const bool inside = p_bary.x > 0.0f && p_bary.y > 0.0f && p_bary.z > 0.0f;
const float signed_dist = inside ? -dist : dist;
// Check if pixel is outside blur region
if (!inside && dist >= blur_radius) {
return;
}
// Handle the case where a face (f) partially behind the image plane is
// clipped to a quadrilateral and then split into two faces (t1, t2). In this
// case we:
// 1. Find the index of the neighboring face (e.g. for t1 need index of t2)
// 2. Check if the neighboring face (t2) is already in the top K faces
// 3. If yes, compare the distance of the pixel to t1 with the distance to t2.
// 4. If dist_t1 < dist_t2, overwrite the values for t2 in the top K faces.
const int neighbor_idx = clipped_faces_neighbor_idx[face_idx];
int neighbor_idx_top_k = -1;
// Check if neighboring face is already in the top K.
// -1 is the fill value in clipped_faces_neighbor_idx
if (neighbor_idx != -1) {
// Only need to loop until q_size.
for (int i = 0; i < q_size; i++) {
if (q[i].idx == neighbor_idx) {
neighbor_idx_top_k = i;
break;
}
}
}
// If neighbor_idx_top_k is not -1 then the neighbor face is in the top K struct.
if (neighbor_idx_top_k != -1) {
// If dist of current face is less than neighbor then overwrite the
// neighbor face values in the top K struct.
float neighbor_dist = abs(q[neighbor_idx_top_k].dist);
if (dist < neighbor_dist) {
// Overwrite the neighbor face values
q[neighbor_idx_top_k] = {pz, face_idx, signed_dist, p_bary_clip};
// If pz > q_max_z then overwrite the max value and index of the max.
// q_size stays the same.
if (pz > q_max_z) {
q_max_z = pz;
q_max_idx = neighbor_idx_top_k;
}
}
} else {
// Handle as a normal face
if (q_size < K) {
// Just insert it.
q[q_size] = {pz, face_idx, signed_dist, p_bary_clip};
@ -198,7 +237,9 @@ __device__ void CheckPixelInsideFace(
}
}
}
}
}
} // namespace
// ****************************************************************************
@ -208,6 +249,7 @@ __global__ void RasterizeMeshesNaiveCudaKernel(
const float* face_verts,
const int64_t* mesh_to_face_first_idx,
const int64_t* num_faces_per_mesh,
const int64_t* clipped_faces_neighbor_idx,
const float blur_radius,
const bool perspective_correct,
const bool clip_barycentric_coords,
@ -265,6 +307,7 @@ __global__ void RasterizeMeshesNaiveCudaKernel(
CheckPixelInsideFace(
face_verts,
clipped_faces_neighbor_idx,
f,
q_size,
q_max_z,
@ -298,6 +341,7 @@ RasterizeMeshesNaiveCuda(
const at::Tensor& face_verts,
const at::Tensor& mesh_to_faces_packed_first_idx,
const at::Tensor& num_faces_per_mesh,
const at::Tensor& clipped_faces_neighbor_idx,
const std::tuple<int, int> image_size,
const float blur_radius,
const int num_closest,
@ -313,6 +357,10 @@ RasterizeMeshesNaiveCuda(
num_faces_per_mesh.size(0) == mesh_to_faces_packed_first_idx.size(0),
"num_faces_per_mesh must have the same first dimension as mesh_to_faces_packed_first_idx");
TORCH_CHECK(
clipped_faces_neighbor_idx.size(0) == face_verts.size(0),
"clipped_faces_neighbor_idx must have the same first dimension as face_verts");
if (num_closest > kMaxPointsPerPixel) {
std::stringstream ss;
ss << "Must have points_per_pixel <= " << kMaxPointsPerPixel;
@ -323,11 +371,16 @@ RasterizeMeshesNaiveCuda(
at::TensorArg face_verts_t{face_verts, "face_verts", 1},
mesh_to_faces_packed_first_idx_t{
mesh_to_faces_packed_first_idx, "mesh_to_faces_packed_first_idx", 2},
num_faces_per_mesh_t{num_faces_per_mesh, "num_faces_per_mesh", 3};
num_faces_per_mesh_t{num_faces_per_mesh, "num_faces_per_mesh", 3},
clipped_faces_neighbor_idx_t{
clipped_faces_neighbor_idx, "clipped_faces_neighbor_idx", 4};
at::CheckedFrom c = "RasterizeMeshesNaiveCuda";
at::checkAllSameGPU(
c,
{face_verts_t, mesh_to_faces_packed_first_idx_t, num_faces_per_mesh_t});
{face_verts_t,
mesh_to_faces_packed_first_idx_t,
num_faces_per_mesh_t,
clipped_faces_neighbor_idx_t});
// Set the device for the kernel launch based on the device of the input
at::cuda::CUDAGuard device_guard(face_verts.device());
@ -358,6 +411,7 @@ RasterizeMeshesNaiveCuda(
face_verts.contiguous().data_ptr<float>(),
mesh_to_faces_packed_first_idx.contiguous().data_ptr<int64_t>(),
num_faces_per_mesh.contiguous().data_ptr<int64_t>(),
clipped_faces_neighbor_idx.contiguous().data_ptr<int64_t>(),
blur_radius,
perspective_correct,
clip_barycentric_coords,
@ -800,6 +854,7 @@ at::Tensor RasterizeMeshesCoarseCuda(
__global__ void RasterizeMeshesFineCudaKernel(
const float* face_verts, // (F, 3, 3)
const int32_t* bin_faces, // (N, BH, BW, T)
const int64_t* clipped_faces_neighbor_idx, // (F,)
const float blur_radius,
const int bin_size,
const bool perspective_correct,
@ -858,6 +913,7 @@ __global__ void RasterizeMeshesFineCudaKernel(
int q_size = 0;
float q_max_z = -1000;
int q_max_idx = -1;
for (int m = 0; m < M; m++) {
const int f = bin_faces[n * BH * BW * M + by * BW * M + bx * M + m];
if (f < 0) {
@ -867,6 +923,7 @@ __global__ void RasterizeMeshesFineCudaKernel(
// update q, q_size, q_max_z and q_max_idx in place.
CheckPixelInsideFace(
face_verts,
clipped_faces_neighbor_idx,
f,
q_size,
q_max_z,
@ -906,6 +963,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
RasterizeMeshesFineCuda(
const at::Tensor& face_verts,
const at::Tensor& bin_faces,
const at::Tensor& clipped_faces_neighbor_idx,
const std::tuple<int, int> image_size,
const float blur_radius,
const int bin_size,
@ -918,12 +976,18 @@ RasterizeMeshesFineCuda(
face_verts.size(2) == 3,
"face_verts must have dimensions (num_faces, 3, 3)");
TORCH_CHECK(bin_faces.ndimension() == 4, "bin_faces must have 4 dimensions");
TORCH_CHECK(
clipped_faces_neighbor_idx.size(0) == face_verts.size(0),
"clipped_faces_neighbor_idx must have the same first dimension as face_verts");
// Check inputs are on the same device
at::TensorArg face_verts_t{face_verts, "face_verts", 1},
bin_faces_t{bin_faces, "bin_faces", 2};
bin_faces_t{bin_faces, "bin_faces", 2},
clipped_faces_neighbor_idx_t{
clipped_faces_neighbor_idx, "clipped_faces_neighbor_idx", 3};
at::CheckedFrom c = "RasterizeMeshesFineCuda";
at::checkAllSameGPU(c, {face_verts_t, bin_faces_t});
at::checkAllSameGPU(
c, {face_verts_t, bin_faces_t, clipped_faces_neighbor_idx_t});
// Set the device for the kernel launch based on the device of the input
at::cuda::CUDAGuard device_guard(face_verts.device());
@ -961,6 +1025,7 @@ RasterizeMeshesFineCuda(
RasterizeMeshesFineCudaKernel<<<blocks, threads, 0, stream>>>(
face_verts.contiguous().data_ptr<float>(),
bin_faces.contiguous().data_ptr<int32_t>(),
clipped_faces_neighbor_idx.contiguous().data_ptr<int64_t>(),
blur_radius,
bin_size,
perspective_correct,

View File

@ -15,6 +15,7 @@ RasterizeMeshesNaiveCpu(
const torch::Tensor& face_verts,
const torch::Tensor& mesh_to_face_first_idx,
const torch::Tensor& num_faces_per_mesh,
const torch::Tensor& clipped_faces_neighbor_idx,
const std::tuple<int, int> image_size,
const float blur_radius,
const int faces_per_pixel,
@ -28,6 +29,7 @@ RasterizeMeshesNaiveCuda(
const at::Tensor& face_verts,
const at::Tensor& mesh_to_face_first_idx,
const at::Tensor& num_faces_per_mesh,
const torch::Tensor& clipped_faces_neighbor_idx,
const std::tuple<int, int> image_size,
const float blur_radius,
const int num_closest,
@ -48,6 +50,12 @@ RasterizeMeshesNaiveCuda(
// the batch where N is the batch size.
// num_faces_per_mesh: LongTensor of shape (N) giving the number of faces
// for each mesh in the batch.
// clipped_faces_neighbor_idx: LongTensor of shape (F,) giving the
// index of the neighboring face for each face which was clipped to a
// quadrilateral and then divided into two triangles.
// e.g. for a face f partially behind the image plane which is split into
// two triangles (t1, t2): clipped_faces_neighbor_idx[t1_idx] = t2_idx
// Faces which are not clipped and subdivided are set to -1.
// image_size: Tuple (H, W) giving the size in pixels of the output
// image to be rasterized.
// blur_radius: float distance in NDC coordinates used to expand the face
@ -90,6 +98,7 @@ RasterizeMeshesNaive(
const torch::Tensor& face_verts,
const torch::Tensor& mesh_to_face_first_idx,
const torch::Tensor& num_faces_per_mesh,
const torch::Tensor& clipped_faces_neighbor_idx,
const std::tuple<int, int> image_size,
const float blur_radius,
const int faces_per_pixel,
@ -106,6 +115,7 @@ RasterizeMeshesNaive(
face_verts,
mesh_to_face_first_idx,
num_faces_per_mesh,
clipped_faces_neighbor_idx,
image_size,
blur_radius,
faces_per_pixel,
@ -120,6 +130,7 @@ RasterizeMeshesNaive(
face_verts,
mesh_to_face_first_idx,
num_faces_per_mesh,
clipped_faces_neighbor_idx,
image_size,
blur_radius,
faces_per_pixel,
@ -306,6 +317,7 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
RasterizeMeshesFineCuda(
const torch::Tensor& face_verts,
const torch::Tensor& bin_faces,
const torch::Tensor& clipped_faces_neighbor_idx,
const std::tuple<int, int> image_size,
const float blur_radius,
const int bin_size,
@ -322,6 +334,12 @@ RasterizeMeshesFineCuda(
// in NDC coordinates in the range [-1, 1].
// bin_faces: int32 Tensor of shape (N, B, B, M) giving the indices of faces
// that fall into each bin (output from coarse rasterization).
// clipped_faces_neighbor_idx: LongTensor of shape (F,) giving the
// index of the neighboring face for each face which was clipped to a
// quadrilateral and then divided into two triangles.
// e.g. for a face f partially behind the image plane which is split into
// two triangles (t1, t2): clipped_faces_neighbor_idx[t1_idx] = t2_idx
// Faces which are not clipped and subdivided are set to -1.
// image_size: Tuple (H, W) giving the size in pixels of the output
// image to be rasterized.
// blur_radius: float distance in NDC coordinates used to expand the face
@ -364,6 +382,7 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
RasterizeMeshesFine(
const torch::Tensor& face_verts,
const torch::Tensor& bin_faces,
const torch::Tensor& clipped_faces_neighbor_idx,
const std::tuple<int, int> image_size,
const float blur_radius,
const int bin_size,
@ -378,6 +397,7 @@ RasterizeMeshesFine(
return RasterizeMeshesFineCuda(
face_verts,
bin_faces,
clipped_faces_neighbor_idx,
image_size,
blur_radius,
bin_size,
@ -411,6 +431,12 @@ RasterizeMeshesFine(
// the batch where N is the batch size.
// num_faces_per_mesh: LongTensor of shape (N) giving the number of faces
// for each mesh in the batch.
// clipped_faces_neighbor_idx: LongTensor of shape (F,) giving the
// index of the neighboring face for each face which was clipped to a
// quadrilateral and then divided into two triangles.
// e.g. for a face f partially behind the image plane which is split into
// two triangles (t1, t2): clipped_faces_neighbor_idx[t1_idx] = t2_idx
// Faces which are not clipped and subdivided are set to -1.
// image_size: Tuple (H, W) giving the size in pixels of the output
// image to be rasterized.
// blur_radius: float distance in NDC coordinates used to expand the face
@ -456,6 +482,7 @@ RasterizeMeshes(
const torch::Tensor& face_verts,
const torch::Tensor& mesh_to_face_first_idx,
const torch::Tensor& num_faces_per_mesh,
const torch::Tensor& clipped_faces_neighbor_idx,
const std::tuple<int, int> image_size,
const float blur_radius,
const int faces_per_pixel,
@ -477,6 +504,7 @@ RasterizeMeshes(
return RasterizeMeshesFine(
face_verts,
bin_faces,
clipped_faces_neighbor_idx,
image_size,
blur_radius,
bin_size,
@ -490,6 +518,7 @@ RasterizeMeshes(
face_verts,
mesh_to_face_first_idx,
num_faces_per_mesh,
clipped_faces_neighbor_idx,
image_size,
blur_radius,
faces_per_pixel,

View File

@ -99,11 +99,31 @@ auto ComputeFaceAreas(const torch::Tensor& face_verts) {
return face_areas;
}
// Helper functor to use with std::find_if to find any element in the
// top K queue whose face index matches a given idx.
struct IsNeighbor {
IsNeighbor(int neighbor_idx) {
this->neighbor_idx = neighbor_idx;
}
bool operator()(std::tuple<float, int, float, float, float, float> elem) {
return (std::get<1>(elem) == neighbor_idx);
}
int neighbor_idx;
};
// Function to sort based on the z distance in the top K queue
bool SortTopKByZdist(
std::tuple<float, int, float, float, float, float> a,
std::tuple<float, int, float, float, float, float> b) {
return std::get<0>(a) < std::get<0>(b);
}
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
RasterizeMeshesNaiveCpu(
const torch::Tensor& face_verts,
const torch::Tensor& mesh_to_face_first_idx,
const torch::Tensor& num_faces_per_mesh,
const torch::Tensor& clipped_faces_neighbor_idx,
const std::tuple<int, int> image_size,
const float blur_radius,
const int faces_per_pixel,
@ -139,6 +159,7 @@ RasterizeMeshesNaiveCpu(
auto zbuf_a = zbuf.accessor<float, 4>();
auto pix_dists_a = pix_dists.accessor<float, 4>();
auto barycentric_coords_a = barycentric_coords.accessor<float, 5>();
auto neighbor_idx_a = clipped_faces_neighbor_idx.accessor<int64_t, 1>();
auto face_bboxes = ComputeFaceBoundingBoxes(face_verts);
auto face_bboxes_a = face_bboxes.accessor<float, 2>();
@ -168,10 +189,11 @@ RasterizeMeshesNaiveCpu(
// X coordinate of the left of the pixel.
const float xf = PixToNonSquareNdc(xidx, W, H);
// Use a priority queue to hold values:
// Use a deque to hold values:
// (z, idx, r, bary.x, bary.y, bary.z)
std::priority_queue<std::tuple<float, int, float, float, float, float>>
q;
// Sort the deque as needed to mimic a priority queue.
std::deque<std::tuple<float, int, float, float, float, float>> q;
// Loop through the faces in the mesh.
for (int f = face_start_idx; f < face_stop_idx; ++f) {
@ -240,15 +262,58 @@ RasterizeMeshesNaiveCpu(
if (!inside && dist >= blur_radius) {
continue;
}
// Handle the case where a face (f) partially behind the image plane
// is clipped to a quadrilateral and then split into two faces (t1,
// t2). In this case we:
// 1. Find the index of the neighbor (e.g. for t1 need index of t2)
// 2. Check if the neighbor (t2) is already in the top K faces
// 3. If yes, compare the distance of the pixel to t1 with the
// distance to t2.
// 4. If dist_t1 < dist_t2, overwrite the values for t2 in the top K
// faces.
const int neighbor_idx = neighbor_idx_a[f];
int idx_top_k = -1;
// Check if neighboring face is already in the top K.
if (neighbor_idx != -1) {
const auto it =
std::find_if(q.begin(), q.end(), IsNeighbor(neighbor_idx));
// Get the index of the element from the iterator
idx_top_k = (it != q.end()) ? it - q.begin() : idx_top_k;
}
// If idx_top_k is not -1 then the neighbor face is in the top K queue.
if (idx_top_k != -1) {
// If dist of current face is less than neighbor, overwrite
// the neighbor face values in the top K struct.
const auto neighbor = q[idx_top_k];
const float dist_neighbor = std::abs(std::get<2>(neighbor));
if (dist < dist_neighbor) {
// Overwrite the neighbor face values.
q[idx_top_k] = {
pz, f, signed_dist, bary_clip.x, bary_clip.y, bary_clip.z};
}
} else {
// Handle as a normal face.
// The current pixel lies inside the current face.
q.emplace(pz, f, signed_dist, bary_clip.x, bary_clip.y, bary_clip.z);
// Add at the end of the deque.
q.emplace_back(
pz, f, signed_dist, bary_clip.x, bary_clip.y, bary_clip.z);
}
// Sort the deque in place based on the z distance
// to mimic using a priority queue.
std::sort(q.begin(), q.end(), SortTopKByZdist);
if (static_cast<int>(q.size()) > K) {
q.pop();
// remove the last value
q.pop_back();
}
}
while (!q.empty()) {
auto t = q.top();
q.pop();
// Loop through and add values to the output tensors
auto t = q.back();
q.pop_back();
const int i = q.size();
zbuf_a[n][yi][xi][i] = std::get<0>(t);
face_idxs_a[n][yi][xi][i] = std::get<1>(t);

View File

@ -9,6 +9,12 @@ import torch
# pyre-fixme[21]: Could not find name `_C` in `pytorch3d`.
from pytorch3d import _C
from .clip import (
ClipFrustum,
clip_faces,
convert_clipped_rasterization_to_original_faces,
)
# TODO make the epsilon user configurable
kEpsilon = 1e-8
@ -28,6 +34,8 @@ def rasterize_meshes(
perspective_correct: bool = False,
clip_barycentric_coords: bool = False,
cull_backfaces: bool = False,
z_clip_value: Optional[float] = None,
cull_to_frustum: bool = False,
):
"""
Rasterize a batch of meshes given the shape of the desired output image.
@ -67,7 +75,8 @@ def rasterize_meshes(
will be raised. This should not affect the output values, but can affect
the memory usage in the forward pass.
perspective_correct: Bool, Whether to apply perspective correction when computing
barycentric coordinates for pixels.
barycentric coordinates for pixels. This should be set to True if a perspective
camera is used.
cull_backfaces: Bool, Whether to only rasterize mesh faces which are
visible to the camera. This assumes that vertices of
front-facing triangles are ordered in an anti-clockwise
@ -76,6 +85,15 @@ def rasterize_meshes(
direction. NOTE: This will only work if the mesh faces are
consistently defined with counter-clockwise ordering when
viewed from the outside.
z_clip_value: if not None, then triangles will be clipped (and possibly
subdivided into smaller triangles) such that z >= z_clip_value.
This avoids camera projections that go to infinity as z->0.
Default is None as clipping affects rasterization speed and
should only be turned on if explicitly needed.
See clip.py for all the extra computation that is required.
cull_to_frustum: if True, triangles outside the view frustum will be culled.
Culling involves removing all faces which fall outside the view frustum.
Default is False so that it is only turned on when needed.
Returns:
4-element tuple containing
@ -141,6 +159,42 @@ def rasterize_meshes(
im_size = (image_size, image_size)
max_image_size = image_size
clipped_faces_neighbor_idx = None
if z_clip_value is not None or cull_to_frustum:
# Cull faces outside the view frustum, and clip faces that are partially
# behind the camera into the portion of the triangle in front of the
# camera. This may change the number of faces
frustum = ClipFrustum(
left=-1,
right=1,
top=-1,
bottom=1,
perspective_correct=perspective_correct,
z_clip_value=z_clip_value,
cull=cull_to_frustum,
)
clipped_faces = clip_faces(
face_verts, mesh_to_face_first_idx, num_faces_per_mesh, frustum=frustum
)
face_verts = clipped_faces.face_verts
mesh_to_face_first_idx = clipped_faces.mesh_to_face_first_idx
num_faces_per_mesh = clipped_faces.num_faces_per_mesh
# For case 4 clipped triangles (where a big triangle is split into two smaller
# triangles), we need the index of the neighboring clipped triangle as only
# one of the two can be in the top K closest faces in the rasterization step.
clipped_faces_neighbor_idx = clipped_faces.clipped_faces_neighbor_idx
if clipped_faces_neighbor_idx is None:
# Set to the default which is all -1s.
clipped_faces_neighbor_idx = torch.full(
size=(face_verts.shape[0],),
fill_value=-1,
device=meshes.device,
dtype=torch.int64,
)
# TODO: Choose naive vs coarse-to-fine based on mesh size and image size.
if bin_size is None:
if not verts_packed.is_cuda:
@ -172,10 +226,11 @@ def rasterize_meshes(
max_faces_per_bin = int(max(10000, meshes._F / 5))
# pyre-fixme[16]: `_RasterizeFaceVerts` has no attribute `apply`.
return _RasterizeFaceVerts.apply(
pix_to_face, zbuf, barycentric_coords, dists = _RasterizeFaceVerts.apply(
face_verts,
mesh_to_face_first_idx,
num_faces_per_mesh,
clipped_faces_neighbor_idx,
im_size,
blur_radius,
faces_per_pixel,
@ -186,6 +241,17 @@ def rasterize_meshes(
cull_backfaces,
)
if z_clip_value is not None or cull_to_frustum:
# If faces were clipped, map the rasterization result to be in terms of the
# original unclipped faces. This may involve converting barycentric
# coordinates
outputs = convert_clipped_rasterization_to_original_faces(
pix_to_face, barycentric_coords, clipped_faces
)
pix_to_face, barycentric_coords = outputs
return pix_to_face, zbuf, barycentric_coords, dists
class _RasterizeFaceVerts(torch.autograd.Function):
"""
@ -216,9 +282,10 @@ class _RasterizeFaceVerts(torch.autograd.Function):
# pyre-fixme[14]: `forward` overrides method defined in `Function` inconsistently.
def forward(
ctx,
face_verts,
mesh_to_face_first_idx,
num_faces_per_mesh,
face_verts: torch.Tensor,
mesh_to_face_first_idx: torch.Tensor,
num_faces_per_mesh: torch.Tensor,
clipped_faces_neighbor_idx: torch.Tensor,
image_size: Union[List[int], Tuple[int, int]] = (256, 256),
blur_radius: float = 0.01,
faces_per_pixel: int = 0,
@ -227,12 +294,15 @@ class _RasterizeFaceVerts(torch.autograd.Function):
perspective_correct: bool = False,
clip_barycentric_coords: bool = False,
cull_backfaces: bool = False,
z_clip_value: Optional[float] = None,
cull_to_frustum: bool = True,
):
# pyre-fixme[16]: Module `pytorch3d` has no attribute `_C`.
pix_to_face, zbuf, barycentric_coords, dists = _C.rasterize_meshes(
face_verts,
mesh_to_face_first_idx,
num_faces_per_mesh,
clipped_faces_neighbor_idx,
image_size,
blur_radius,
faces_per_pixel,
@ -242,6 +312,7 @@ class _RasterizeFaceVerts(torch.autograd.Function):
clip_barycentric_coords,
cull_backfaces,
)
ctx.save_for_backward(face_verts, pix_to_face)
ctx.mark_non_differentiable(pix_to_face)
ctx.perspective_correct = perspective_correct
@ -253,6 +324,7 @@ class _RasterizeFaceVerts(torch.autograd.Function):
grad_face_verts = None
grad_mesh_to_face_first_idx = None
grad_num_faces_per_mesh = None
grad_clipped_faces_neighbor_idx = None
grad_image_size = None
grad_radius = None
grad_faces_per_pixel = None
@ -275,6 +347,7 @@ class _RasterizeFaceVerts(torch.autograd.Function):
grad_face_verts,
grad_mesh_to_face_first_idx,
grad_num_faces_per_mesh,
grad_clipped_faces_neighbor_idx,
grad_image_size,
grad_radius,
grad_faces_per_pixel,
@ -339,6 +412,9 @@ def rasterize_meshes_python(
perspective_correct: bool = False,
clip_barycentric_coords: bool = False,
cull_backfaces: bool = False,
z_clip_value: Optional[float] = None,
cull_to_frustum: bool = True,
clipped_faces_neighbor_idx: Optional[torch.Tensor] = None,
):
"""
Naive PyTorch implementation of mesh rasterization with the same inputs and
@ -359,6 +435,26 @@ def rasterize_meshes_python(
mesh_to_face_first_idx = meshes.mesh_to_faces_packed_first_idx()
num_faces_per_mesh = meshes.num_faces_per_mesh()
if z_clip_value is not None or cull_to_frustum:
# Cull faces outside the view frustum, and clip faces that are partially
# behind the camera into the portion of the triangle in front of the
# camera. This may change the number of faces
frustum = ClipFrustum(
left=-1,
right=1,
top=-1,
bottom=1,
perspective_correct=perspective_correct,
z_clip_value=z_clip_value,
cull=cull_to_frustum,
)
clipped_faces = clip_faces(
faces_verts, mesh_to_face_first_idx, num_faces_per_mesh, frustum=frustum
)
faces_verts = clipped_faces.face_verts
mesh_to_face_first_idx = clipped_faces.mesh_to_face_first_idx
num_faces_per_mesh = clipped_faces.num_faces_per_mesh
# Initialize output tensors.
face_idxs = torch.full(
(N, H, W, K), fill_value=-1, dtype=torch.int64, device=device
@ -468,26 +564,55 @@ def rasterize_meshes_python(
# Points inside the triangle have negative distance.
dist = point_triangle_distance(pxy, v0[:2], v1[:2], v2[:2])
signed_dist = dist * -1.0 if inside else dist
# Add an epsilon to prevent errors when comparing distance
# to blur radius.
if not inside and dist >= blur_radius:
continue
top_k_points.append((pz, f, bary, signed_dist))
# Handle the case where a face (f) partially behind the image plane is
# clipped to a quadrilateral and then split into two faces (t1, t2).
top_k_idx = -1
if (
clipped_faces_neighbor_idx is not None
and clipped_faces_neighbor_idx[f] != -1
):
neighbor_idx = clipped_faces_neighbor_idx[f]
# See if neighbor_idx is in top_k and find index
top_k_idx = [
i
for i, val in enumerate(top_k_points)
if val[1] == neighbor_idx
]
top_k_idx = top_k_idx[0] if len(top_k_idx) > 0 else -1
if top_k_idx != -1 and dist < top_k_points[top_k_idx][3]:
# Overwrite the neighbor with current face info
top_k_points[top_k_idx] = (pz, f, bary, dist, inside)
else:
# Handle as a normal face
top_k_points.append((pz, f, bary, dist, inside))
top_k_points.sort()
if len(top_k_points) > K:
top_k_points = top_k_points[:K]
# Save to output tensors.
for k, (pz, f, bary, dist) in enumerate(top_k_points):
for k, (pz, f, bary, dist, inside) in enumerate(top_k_points):
zbuf[n, yi, xi, k] = pz
face_idxs[n, yi, xi, k] = f
bary_coords[n, yi, xi, k, 0] = bary[0]
bary_coords[n, yi, xi, k, 1] = bary[1]
bary_coords[n, yi, xi, k, 2] = bary[2]
pix_dists[n, yi, xi, k] = dist
# Write the signed distance
pix_dists[n, yi, xi, k] = -dist if inside else dist
if z_clip_value is not None or cull_to_frustum:
# If faces were clipped, map the rasterization result to be in terms of the
# original unclipped faces. This may involve converting barycentric
# coordinates
(face_idxs, bary_coords,) = convert_clipped_rasterization_to_original_faces(
face_idxs, bary_coords, clipped_faces
)
return face_idxs, zbuf, bary_coords, pix_dists
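As a usage note for the new arguments documented above, here is a minimal, illustrative call to `rasterize_meshes` with clipping enabled. The mesh and parameter values are assumptions for the sketch; what the diff guarantees is that when `z_clip_value`/`cull_to_frustum` are set, the returned `pix_to_face` and barycentric coordinates are already converted back to the indexing of the original unclipped mesh.

```python
import torch
from pytorch3d.renderer.mesh.rasterize_meshes import rasterize_meshes
from pytorch3d.utils import ico_sphere

device = torch.device("cuda:0")
# Assumes the mesh has already been transformed into NDC/screen space.
meshes_screen = ico_sphere(level=2, device=device)

pix_to_face, zbuf, bary_coords, dists = rasterize_meshes(
    meshes_screen,
    image_size=256,
    blur_radius=1e-6,
    faces_per_pixel=8,
    perspective_correct=True,
    # Clip triangles so that z >= 0.01; faces straddling the image plane are
    # split, rasterized, and then mapped back to the original face indices.
    z_clip_value=1e-2,
    cull_to_frustum=True,
)
```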

View File

@ -27,6 +27,8 @@ class RasterizationSettings:
"perspective_correct",
"clip_barycentric_coords",
"cull_backfaces",
"z_clip_value",
"cull_to_frustum",
]
def __init__(
@ -36,9 +38,13 @@ class RasterizationSettings:
faces_per_pixel: int = 1,
bin_size: Optional[int] = None,
max_faces_per_bin: Optional[int] = None,
perspective_correct: bool = False,
# set perspective_correct = None so that the
# value can be inferred correctly from the Camera type
perspective_correct: Optional[bool] = None,
clip_barycentric_coords: Optional[bool] = None,
cull_backfaces: bool = False,
z_clip_value: Optional[float] = None,
cull_to_frustum: bool = False,
):
self.image_size = image_size
self.blur_radius = blur_radius
@ -48,6 +54,8 @@ class RasterizationSettings:
self.perspective_correct = perspective_correct
self.clip_barycentric_coords = clip_barycentric_coords
self.cull_backfaces = cull_backfaces
self.z_clip_value = z_clip_value
self.cull_to_frustum = cull_to_frustum
class MeshRasterizer(nn.Module):
@ -139,12 +147,19 @@ class MeshRasterizer(nn.Module):
if clip_barycentric_coords is None:
clip_barycentric_coords = raster_settings.blur_radius > 0.0
# If not specified, infer perspective_correct from the camera
# If not specified, infer perspective_correct and z_clip_value from the camera
cameras = kwargs.get("cameras", self.cameras)
if raster_settings.perspective_correct is not None:
perspective_correct = raster_settings.perspective_correct
else:
perspective_correct = cameras.is_perspective()
if raster_settings.z_clip_value is not None:
z_clip = raster_settings.z_clip_value
else:
znear = cameras.get_znear()
if isinstance(znear, torch.Tensor):
znear = znear.min().item()
z_clip = None if not perspective_correct or znear is None else znear / 2
pix_to_face, zbuf, bary_coords, dists = rasterize_meshes(
meshes_screen,
@ -153,9 +168,11 @@ class MeshRasterizer(nn.Module):
faces_per_pixel=raster_settings.faces_per_pixel,
bin_size=raster_settings.bin_size,
max_faces_per_bin=raster_settings.max_faces_per_bin,
perspective_correct=perspective_correct,
clip_barycentric_coords=clip_barycentric_coords,
perspective_correct=perspective_correct,
cull_backfaces=raster_settings.cull_backfaces,
z_clip_value=z_clip,
cull_to_frustum=raster_settings.cull_to_frustum,
)
return Fragments(
pix_to_face=pix_to_face, zbuf=zbuf, bary_coords=bary_coords, dists=dists
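A short, illustrative example of the new `RasterizationSettings` fields: with `z_clip_value=None` (the default) and a perspective camera, the rasterizer above falls back to `znear / 2` as the clip value. The camera and mesh values below are assumptions for the sketch.

```python
import torch
from pytorch3d.renderer import FoVPerspectiveCameras
from pytorch3d.renderer.mesh import MeshRasterizer, RasterizationSettings
from pytorch3d.utils import ico_sphere

device = torch.device("cuda:0")
cameras = FoVPerspectiveCameras(device=device, znear=0.1, fov=90)
raster_settings = RasterizationSettings(
    image_size=256,
    blur_radius=0.0,
    faces_per_pixel=1,
    # z_clip_value is left as None: for this perspective camera the
    # rasterizer infers z_clip = znear / 2 = 0.05.
    cull_to_frustum=False,
)
rasterizer = MeshRasterizer(cameras=cameras, raster_settings=raster_settings)
fragments = rasterizer(ico_sphere(2, device))
```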

View File

@ -66,7 +66,7 @@ def bm_rasterize_meshes() -> None:
# Square and non square cases
image_size = [64, 128, 512, (512, 256), (256, 512)]
blur = [1e-6]
faces_per_pixel = [50]
faces_per_pixel = [40]
test_cases = product(num_meshes, ico_level, image_size, blur, faces_per_pixel)
for case in test_cases:
@ -87,6 +87,35 @@ def bm_rasterize_meshes() -> None:
warmup_iters=1,
)
# Test a subset of the cases with the
# image plane intersecting the mesh.
kwargs_list = []
num_meshes = [8, 16]
# Square and non square cases
image_size = [64, 128, 512, (512, 256), (256, 512)]
dist = [3, 0.8, 0.5]
test_cases = product(num_meshes, dist, image_size)
for case in test_cases:
n, d, im = case
kwargs_list.append(
{
"num_meshes": n,
"ico_level": 4,
"image_size": im,
"blur_radius": 1e-6,
"faces_per_pixel": 40,
"dist": d,
}
)
benchmark(
TestRasterizeMeshes.bm_rasterize_meshes_with_clipping,
"RASTERIZE_MESHES_CUDA_CLIPPING",
kwargs_list,
warmup_iters=1,
)
if __name__ == "__main__":
bm_rasterize_meshes()

BIN
tests/data/room.jpg (new binary file, not shown, 7.8 KiB)

BIN
(second new binary image file, name not shown, 7.2 KiB)

View File

@ -6,6 +6,8 @@ import unittest
import torch
from common_testing import TestCaseMixin, get_random_cuda_device
from pytorch3d import _C
from pytorch3d.renderer import FoVPerspectiveCameras, look_at_view_transform
from pytorch3d.renderer.mesh import MeshRasterizer, RasterizationSettings
from pytorch3d.renderer.mesh.rasterize_meshes import (
rasterize_meshes,
rasterize_meshes_python,
@ -1204,3 +1206,50 @@ class TestRasterizeMeshes(TestCaseMixin, unittest.TestCase):
torch.cuda.synchronize(device)
return rasterize
@staticmethod
def bm_rasterize_meshes_with_clipping(
num_meshes: int,
ico_level: int,
image_size: int,
blur_radius: float,
faces_per_pixel: int,
dist: float,
):
device = get_random_cuda_device()
meshes = ico_sphere(ico_level, device)
meshes_batch = meshes.extend(num_meshes)
settings = RasterizationSettings(
image_size=image_size,
blur_radius=blur_radius,
faces_per_pixel=faces_per_pixel,
z_clip_value=1e-2,
perspective_correct=True,
cull_to_frustum=True,
)
# The camera is positioned so that the image plane intersects
# the mesh and some faces are partially behind the image plane.
R, T = look_at_view_transform(dist, 0, 0)
cameras = FoVPerspectiveCameras(device=device, R=R, T=T, fov=90)
rasterizer = MeshRasterizer(raster_settings=settings, cameras=cameras)
# Transform the meshes to project them onto the image plane
meshes_screen = rasterizer.transform(meshes_batch)
torch.cuda.synchronize(device)
def rasterize():
# Only measure rasterization speed (including clipping)
rasterize_meshes(
meshes_screen,
image_size,
blur_radius,
faces_per_pixel,
z_clip_value=1e-2,
perspective_correct=True,
cull_to_frustum=True,
)
torch.cuda.synchronize(device)
return rasterize

View File

@ -245,6 +245,7 @@ class TestRenderImplicit(TestCaseMixin, unittest.TestCase):
image_size=image_size,
blur_radius=1e-3,
faces_per_pixel=10,
z_clip_value=None,
perspective_correct=False,
),
),

View File

@ -994,6 +994,7 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
blur_radius=0.0,
faces_per_pixel=1,
cull_backfaces=True,
perspective_correct=False,
)
# Init shader settings

View File

@ -9,14 +9,161 @@ See pytorch3d/renderer/mesh/clip.py for more details about the
clipping process.
"""
import unittest
from pathlib import Path
import imageio
import numpy as np
import torch
from common_testing import TestCaseMixin
from pytorch3d.renderer.mesh import ClipFrustum, clip_faces
from common_testing import TestCaseMixin, load_rgb_image
from pytorch3d.io import save_obj
from pytorch3d.renderer.cameras import FoVPerspectiveCameras, look_at_view_transform
from pytorch3d.renderer.lighting import PointLights
from pytorch3d.renderer.mesh import (
ClipFrustum,
TexturesUV,
clip_faces,
convert_clipped_rasterization_to_original_faces,
)
from pytorch3d.renderer.mesh.rasterize_meshes import _RasterizeFaceVerts
from pytorch3d.renderer.mesh.rasterizer import MeshRasterizer, RasterizationSettings
from pytorch3d.renderer.mesh.renderer import MeshRenderer
from pytorch3d.renderer.mesh.shader import SoftPhongShader
from pytorch3d.structures.meshes import Meshes
# If DEBUG=True, save out images generated in the tests for debugging.
# All saved images have prefix DEBUG_
DEBUG = False
DATA_DIR = Path(__file__).resolve().parent / "data"
class TestRenderMeshesClipping(TestCaseMixin, unittest.TestCase):
def load_cube_mesh_with_texture(self, device="cpu", with_grad: bool = False):
verts = torch.tensor(
[
[-1, 1, 1],
[1, 1, 1],
[1, -1, 1],
[-1, -1, 1],
[-1, 1, -1],
[1, 1, -1],
[1, -1, -1],
[-1, -1, -1],
],
device=device,
dtype=torch.float32,
requires_grad=with_grad,
)
# all faces correctly wound
faces = torch.tensor(
[
[0, 1, 4],
[4, 1, 5],
[1, 2, 5],
[5, 2, 6],
[2, 7, 6],
[2, 3, 7],
[3, 4, 7],
[0, 4, 3],
[4, 5, 6],
[4, 6, 7],
],
device=device,
dtype=torch.int64,
)
verts_uvs = torch.tensor(
[
[
[0, 1],
[1, 1],
[1, 0],
[0, 0],
[0.204, 0.743],
[0.781, 0.743],
[0.781, 0.154],
[0.204, 0.154],
]
],
device=device,
dtype=torch.float,
)
texture_map = load_rgb_image("room.jpg", DATA_DIR).to(device)
textures = TexturesUV(
maps=[texture_map], faces_uvs=faces.unsqueeze(0), verts_uvs=verts_uvs
)
mesh = Meshes([verts], [faces], textures=textures)
if with_grad:
return mesh, verts
return mesh
def test_cube_mesh_render(self):
"""
End-to-end test of rendering a cube mesh with texture
from decreasing camera distances. The camera starts
outside the cube and enters the inside of the cube.
"""
device = torch.device("cuda:0")
mesh = self.load_cube_mesh_with_texture(device)
raster_settings = RasterizationSettings(
image_size=512,
blur_radius=1e-8,
faces_per_pixel=5,
z_clip_value=1e-2,
perspective_correct=True,
bin_size=0,
)
# Only ambient, no diffuse or specular
lights = PointLights(
device=device,
ambient_color=((1.0, 1.0, 1.0),),
diffuse_color=((0.0, 0.0, 0.0),),
specular_color=((0.0, 0.0, 0.0),),
location=[[0.0, 0.0, -3.0]],
)
renderer = MeshRenderer(
rasterizer=MeshRasterizer(raster_settings=raster_settings),
shader=SoftPhongShader(device=device, lights=lights),
)
# Render the cube by decreasing the distance from the camera until
# the camera enters the cube. Check the output looks correct.
images_list = []
dists = np.linspace(0.1, 2.5, 20)[::-1]
for d in dists:
R, T = look_at_view_transform(d, 0, 0)
T[0, 1] -= 0.1 # move down in the y axis
cameras = FoVPerspectiveCameras(device=device, R=R, T=T, fov=90)
images = renderer(mesh, cameras=cameras)
rgb = images[0, ..., :3].cpu().detach()
filename = "DEBUG_cube_dist=%.1f.jpg" % d
im = (rgb.numpy() * 255).astype(np.uint8)
images_list.append(im)
# Check one of the images where the camera is inside the mesh
if d == 0.5:
filename = "test_render_mesh_clipped_cam_dist=0.5.jpg"
image_ref = load_rgb_image(filename, DATA_DIR)
self.assertClose(rgb, image_ref, atol=0.05)
# Save a gif of the output - this should show
# the camera moving inside the cube.
if DEBUG:
gif_filename = (
"room_original.gif"
if raster_settings.z_clip_value is None
else "room_clipped.gif"
)
imageio.mimsave(DATA_DIR / gif_filename, images_list, fps=2)
save_obj(
f=DATA_DIR / "cube.obj",
verts=mesh.verts_packed().cpu(),
faces=mesh.faces_packed().cpu(),
)
@staticmethod
def clip_faces(meshes):
verts_packed = meshes.verts_packed()
@ -42,6 +189,34 @@ class TestRenderMeshesClipping(TestCaseMixin, unittest.TestCase):
)
return clipped_faces
def test_grad(self):
"""
Check that gradient flow is unaffected when the camera is inside the mesh
"""
device = torch.device("cuda:0")
mesh, verts = self.load_cube_mesh_with_texture(device=device, with_grad=True)
raster_settings = RasterizationSettings(
image_size=512,
blur_radius=1e-5,
faces_per_pixel=5,
z_clip_value=1e-2,
perspective_correct=True,
bin_size=0,
)
renderer = MeshRenderer(
rasterizer=MeshRasterizer(raster_settings=raster_settings),
shader=SoftPhongShader(device=device),
)
dist = 0.4 # Camera is inside the cube
R, T = look_at_view_transform(dist, 0, 0)
cameras = FoVPerspectiveCameras(device=device, R=R, T=T, fov=90)
images = renderer(mesh, cameras=cameras)
images.sum().backward()
# Check gradients exist
self.assertIsNotNone(verts.grad)
def test_case_1(self):
"""
Case 1: Single triangle fully in front of the image plane (z=0)
@ -350,3 +525,134 @@ class TestRenderMeshesClipping(TestCaseMixin, unittest.TestCase):
# barycentric conversion matrix.
bary_idx = idx.new_tensor([1, 4, 2, 5, 0, -1, 3, 6])
self.assertClose(clipped_faces.faces_clipped_to_conversion_idx, bary_idx)
def test_convert_clipped_to_unclipped_case_4(self):
"""
Test with a single case 4 triangle which is clipped into
a quadrilateral and subdivided.
"""
device = "cuda:0"
# fmt: off
verts = torch.tensor(
[
[-1.0, 0.0, -1.0], # noqa: E241, E201
[ 0.0, 1.0, -1.0], # noqa: E241, E201
[ 1.0, 0.0, -1.0], # noqa: E241, E201
[ 0.0, -1.0, -1.0], # noqa: E241, E201
[-1.0, 0.5, 0.5], # noqa: E241, E201
[ 1.0, 1.0, 1.0], # noqa: E241, E201
[ 0.0, -1.0, 1.0], # noqa: E241, E201
[-1.0, 0.5, -0.5], # noqa: E241, E201
[ 1.0, 1.0, -1.0], # noqa: E241, E201
[-1.0, 0.0, 1.0], # noqa: E241, E201
[ 0.0, 1.0, 1.0], # noqa: E241, E201
[ 1.0, 0.0, 1.0], # noqa: E241, E201
],
dtype=torch.float32,
device=device,
)
faces = torch.tensor(
[
[0, 1, 2], # noqa: E241, E201 Case 2 fully clipped
[3, 4, 5], # noqa: E241, E201 Case 4 clipped and subdivided
[5, 4, 3], # noqa: E241, E201 Repeat of Case 4
[6, 7, 8], # noqa: E241, E201 Case 3 clipped
[9, 10, 11], # noqa: E241, E201 Case 1 untouched
],
dtype=torch.int64,
device=device,
)
# fmt: on
meshes = Meshes(verts=[verts], faces=[faces])
# Clip meshes
clipped_faces = self.clip_faces(meshes)
# 4x faces (from Case 4) + 1 (from Case 3) + 1 (from Case 1)
self.assertEqual(clipped_faces.face_verts.shape[0], 6)
image_size = (10, 10)
blur_radius = 0.05
faces_per_pixel = 2
perspective_correct = True
bin_size = 0
max_faces_per_bin = 20
clip_barycentric_coords = False
cull_backfaces = False
# Rasterize clipped mesh
pix_to_face, zbuf, barycentric_coords, dists = _RasterizeFaceVerts.apply(
clipped_faces.face_verts,
clipped_faces.mesh_to_face_first_idx,
clipped_faces.num_faces_per_mesh,
clipped_faces.clipped_faces_neighbor_idx,
image_size,
blur_radius,
faces_per_pixel,
bin_size,
max_faces_per_bin,
perspective_correct,
clip_barycentric_coords,
cull_backfaces,
)
# Convert outputs so they are in terms of the unclipped mesh.
outputs = convert_clipped_rasterization_to_original_faces(
pix_to_face,
barycentric_coords,
clipped_faces,
)
pix_to_face_unclipped, barycentric_coords_unclipped = outputs
# In the clipped mesh there are more faces than in the unclipped mesh
self.assertTrue(pix_to_face.max() > pix_to_face_unclipped.max())
# Unclipped pix_to_face indices must be less than the number
# of faces in the unclipped mesh.
self.assertTrue(pix_to_face_unclipped.max() < faces.shape[0])
def test_case_4_no_duplicates(self):
"""
In the case of a simple mesh with one face that is cut by the image
plane into a quadrilateral, there shouldn't be duplicate indices of
the face in the pix_to_face output of rasterization.
"""
for (device, bin_size) in [("cpu", 0), ("cuda:0", 0), ("cuda:0", None)]:
verts = torch.tensor(
[[0.0, -10.0, 1.0], [-1.0, 2.0, -2.0], [1.0, 5.0, -10.0]],
dtype=torch.float32,
device=device,
)
faces = torch.tensor(
[
[0, 1, 2],
],
dtype=torch.int64,
device=device,
)
meshes = Meshes(verts=[verts], faces=[faces])
k = 3
settings = RasterizationSettings(
image_size=10,
blur_radius=0.05,
faces_per_pixel=k,
z_clip_value=1e-2,
perspective_correct=True,
cull_to_frustum=True,
bin_size=bin_size,
)
# The camera is positioned so that the image plane cuts
# the mesh face into a quadrilateral.
R, T = look_at_view_transform(0.2, 0, 0)
cameras = FoVPerspectiveCameras(device=device, R=R, T=T, fov=90)
rasterizer = MeshRasterizer(raster_settings=settings, cameras=cameras)
fragments = rasterizer(meshes)
p2f = fragments.pix_to_face.reshape(-1, k)
unique_vals, idx_counts = p2f.unique(dim=0, return_counts=True)
# There is only one face in this mesh so if it hits a pixel
# it can only be at position k = 0
# For any pixel, the values [0, 0, -1] for the top K faces should not be possible
double_hit = torch.tensor([0, 0, -1], device=device)
check_double_hit = any(torch.allclose(i, double_hit) for i in unique_vals)
self.assertFalse(check_double_hit)