From 340662e98e97c5e105cf6570765d7bae3e6228bf Mon Sep 17 00:00:00 2001 From: Nikhila Ravi Date: Mon, 8 Feb 2021 14:30:55 -0800 Subject: [PATCH] CUDA/C++ Rasterizer updates to handle clipped faces Summary: - Updated the C++/CUDA mesh rasterization kernels to handle the clipped faces. In particular this required careful handling of the distance calculation for faces which are cut into a quadrilateral by the image plane and then split into two sub triangles i.e. both sub triangles can't be part of the top K faces. - Updated `rasterize_meshes.py` to use the utils functions to clip the meshes and convert the fragments back to in terms of the unclipped mesh - Added end to end tests Reviewed By: jcjohnson Differential Revision: D26169685 fbshipit-source-id: d64cd0d656109b965f44a35c301b7c81f451cfa0 --- .../csrc/rasterize_meshes/rasterize_meshes.cu | 109 ++++-- .../csrc/rasterize_meshes/rasterize_meshes.h | 29 ++ .../rasterize_meshes/rasterize_meshes_cpu.cpp | 81 ++++- pytorch3d/renderer/mesh/rasterize_meshes.py | 145 +++++++- pytorch3d/renderer/mesh/rasterizer.py | 23 +- tests/bm_rasterize_meshes.py | 31 +- tests/data/room.jpg | Bin 0 -> 7940 bytes .../test_render_mesh_clipped_cam_dist=0.5.jpg | Bin 0 -> 7393 bytes tests/test_rasterize_meshes.py | 49 +++ tests/test_render_implicit.py | 1 + tests/test_render_meshes.py | 1 + tests/test_render_meshes_clipped.py | 310 +++++++++++++++++- 12 files changed, 733 insertions(+), 46 deletions(-) create mode 100644 tests/data/room.jpg create mode 100644 tests/data/test_render_mesh_clipped_cam_dist=0.5.jpg diff --git a/pytorch3d/csrc/rasterize_meshes/rasterize_meshes.cu b/pytorch3d/csrc/rasterize_meshes/rasterize_meshes.cu index a92a64e5..714d6aca 100644 --- a/pytorch3d/csrc/rasterize_meshes/rasterize_meshes.cu +++ b/pytorch3d/csrc/rasterize_meshes/rasterize_meshes.cu @@ -17,8 +17,8 @@ namespace { // A structure for holding details about a pixel. struct Pixel { float z; - int64_t idx; - float dist; + int64_t idx; // idx of face + float dist; // abs distance of pixel to face float3 bary; }; @@ -111,6 +111,7 @@ __device__ bool CheckPointOutsideBoundingBox( template __device__ void CheckPixelInsideFace( const float* face_verts, // (F, 3, 3) + const int64_t* clipped_faces_neighbor_idx, // (F,) const int face_idx, int& q_size, float& q_max_z, @@ -173,32 +174,72 @@ __device__ void CheckPixelInsideFace( // face. const bool inside = p_bary.x > 0.0f && p_bary.y > 0.0f && p_bary.z > 0.0f; const float signed_dist = inside ? -dist : dist; - // Check if pixel is outside blur region if (!inside && dist >= blur_radius) { return; } - if (q_size < K) { - // Just insert it. - q[q_size] = {pz, face_idx, signed_dist, p_bary_clip}; - if (pz > q_max_z) { - q_max_z = pz; - q_max_idx = q_size; + // Handle the case where a face (f) partially behind the image plane is + // clipped to a quadrilateral and then split into two faces (t1, t2). In this + // case we: + // 1. Find the index of the neighboring face (e.g. for t1 need index of t2) + // 2. Check if the neighboring face (t2) is already in the top K faces + // 3. If yes, compare the distance of the pixel to t1 with the distance to t2. + // 4. If dist_t1 < dist_t2, overwrite the values for t2 in the top K faces. + const int neighbor_idx = clipped_faces_neighbor_idx[face_idx]; + int neighbor_idx_top_k = -1; + + // Check if neighboring face is already in the top K. + // -1 is the fill value in clipped_faces_neighbor_idx + if (neighbor_idx != -1) { + // Only need to loop until q_size. + for (int i = 0; i < q_size; i++) { + if (q[i].idx == neighbor_idx) { + neighbor_idx_top_k = i; + break; + } } - q_size++; - } else if (pz < q_max_z) { - // Overwrite the old max, and find the new max. - q[q_max_idx] = {pz, face_idx, signed_dist, p_bary_clip}; - q_max_z = pz; - for (int i = 0; i < K; i++) { - if (q[i].z > q_max_z) { - q_max_z = q[i].z; - q_max_idx = i; + } + // If neighbor idx is not -1 then it is in the top K struct. + if (neighbor_idx_top_k != -1) { + // If dist of current face is less than neighbor then overwrite the + // neighbor face values in the top K struct. + float neighbor_dist = abs(q[neighbor_idx_top_k].dist); + if (dist < neighbor_dist) { + // Overwrite the neighbor face values + q[neighbor_idx_top_k] = {pz, face_idx, signed_dist, p_bary_clip}; + + // If pz > q_max then overwrite the max values and index of the max. + // q_size stays the same. + if (pz > q_max_z) { + q_max_z = pz; + q_max_idx = neighbor_idx_top_k; + } + } + } else { + // Handle as a normal face + if (q_size < K) { + // Just insert it. + q[q_size] = {pz, face_idx, signed_dist, p_bary_clip}; + if (pz > q_max_z) { + q_max_z = pz; + q_max_idx = q_size; + } + q_size++; + } else if (pz < q_max_z) { + // Overwrite the old max, and find the new max. + q[q_max_idx] = {pz, face_idx, signed_dist, p_bary_clip}; + q_max_z = pz; + for (int i = 0; i < K; i++) { + if (q[i].z > q_max_z) { + q_max_z = q[i].z; + q_max_idx = i; + } } } } } + } // namespace // **************************************************************************** @@ -208,6 +249,7 @@ __global__ void RasterizeMeshesNaiveCudaKernel( const float* face_verts, const int64_t* mesh_to_face_first_idx, const int64_t* num_faces_per_mesh, + const int64_t* clipped_faces_neighbor_idx, const float blur_radius, const bool perspective_correct, const bool clip_barycentric_coords, @@ -265,6 +307,7 @@ __global__ void RasterizeMeshesNaiveCudaKernel( CheckPixelInsideFace( face_verts, + clipped_faces_neighbor_idx, f, q_size, q_max_z, @@ -298,6 +341,7 @@ RasterizeMeshesNaiveCuda( const at::Tensor& face_verts, const at::Tensor& mesh_to_faces_packed_first_idx, const at::Tensor& num_faces_per_mesh, + const at::Tensor& clipped_faces_neighbor_idx, const std::tuple image_size, const float blur_radius, const int num_closest, @@ -313,6 +357,10 @@ RasterizeMeshesNaiveCuda( num_faces_per_mesh.size(0) == mesh_to_faces_packed_first_idx.size(0), "num_faces_per_mesh must have save size first dimension as mesh_to_faces_packed_first_idx"); + TORCH_CHECK( + clipped_faces_neighbor_idx.size(0) == face_verts.size(0), + "clipped_faces_neighbor_idx must have save size first dimension as face_verts"); + if (num_closest > kMaxPointsPerPixel) { std::stringstream ss; ss << "Must have points_per_pixel <= " << kMaxPointsPerPixel; @@ -323,11 +371,16 @@ RasterizeMeshesNaiveCuda( at::TensorArg face_verts_t{face_verts, "face_verts", 1}, mesh_to_faces_packed_first_idx_t{ mesh_to_faces_packed_first_idx, "mesh_to_faces_packed_first_idx", 2}, - num_faces_per_mesh_t{num_faces_per_mesh, "num_faces_per_mesh", 3}; + num_faces_per_mesh_t{num_faces_per_mesh, "num_faces_per_mesh", 3}, + clipped_faces_neighbor_idx_t{ + clipped_faces_neighbor_idx, "clipped_faces_neighbor_idx", 4}; at::CheckedFrom c = "RasterizeMeshesNaiveCuda"; at::checkAllSameGPU( c, - {face_verts_t, mesh_to_faces_packed_first_idx_t, num_faces_per_mesh_t}); + {face_verts_t, + mesh_to_faces_packed_first_idx_t, + num_faces_per_mesh_t, + clipped_faces_neighbor_idx_t}); // Set the device for the kernel launch based on the device of the input at::cuda::CUDAGuard device_guard(face_verts.device()); @@ -358,6 +411,7 @@ RasterizeMeshesNaiveCuda( face_verts.contiguous().data_ptr(), mesh_to_faces_packed_first_idx.contiguous().data_ptr(), num_faces_per_mesh.contiguous().data_ptr(), + clipped_faces_neighbor_idx.contiguous().data_ptr(), blur_radius, perspective_correct, clip_barycentric_coords, @@ -800,6 +854,7 @@ at::Tensor RasterizeMeshesCoarseCuda( __global__ void RasterizeMeshesFineCudaKernel( const float* face_verts, // (F, 3, 3) const int32_t* bin_faces, // (N, BH, BW, T) + const int64_t* clipped_faces_neighbor_idx, // (F,) const float blur_radius, const int bin_size, const bool perspective_correct, @@ -858,6 +913,7 @@ __global__ void RasterizeMeshesFineCudaKernel( int q_size = 0; float q_max_z = -1000; int q_max_idx = -1; + for (int m = 0; m < M; m++) { const int f = bin_faces[n * BH * BW * M + by * BW * M + bx * M + m]; if (f < 0) { @@ -867,6 +923,7 @@ __global__ void RasterizeMeshesFineCudaKernel( // update q, q_size, q_max_z and q_max_idx in place. CheckPixelInsideFace( face_verts, + clipped_faces_neighbor_idx, f, q_size, q_max_z, @@ -906,6 +963,7 @@ std::tuple RasterizeMeshesFineCuda( const at::Tensor& face_verts, const at::Tensor& bin_faces, + const at::Tensor& clipped_faces_neighbor_idx, const std::tuple image_size, const float blur_radius, const int bin_size, @@ -918,12 +976,18 @@ RasterizeMeshesFineCuda( face_verts.size(2) == 3, "face_verts must have dimensions (num_faces, 3, 3)"); TORCH_CHECK(bin_faces.ndimension() == 4, "bin_faces must have 4 dimensions"); + TORCH_CHECK( + clipped_faces_neighbor_idx.size(0) == face_verts.size(0), + "clipped_faces_neighbor_idx must have the same first dimension as face_verts"); // Check inputs are on the same device at::TensorArg face_verts_t{face_verts, "face_verts", 1}, - bin_faces_t{bin_faces, "bin_faces", 2}; + bin_faces_t{bin_faces, "bin_faces", 2}, + clipped_faces_neighbor_idx_t{ + clipped_faces_neighbor_idx, "clipped_faces_neighbor_idx", 3}; at::CheckedFrom c = "RasterizeMeshesFineCuda"; - at::checkAllSameGPU(c, {face_verts_t, bin_faces_t}); + at::checkAllSameGPU( + c, {face_verts_t, bin_faces_t, clipped_faces_neighbor_idx_t}); // Set the device for the kernel launch based on the device of the input at::cuda::CUDAGuard device_guard(face_verts.device()); @@ -961,6 +1025,7 @@ RasterizeMeshesFineCuda( RasterizeMeshesFineCudaKernel<<>>( face_verts.contiguous().data_ptr(), bin_faces.contiguous().data_ptr(), + clipped_faces_neighbor_idx.contiguous().data_ptr(), blur_radius, bin_size, perspective_correct, diff --git a/pytorch3d/csrc/rasterize_meshes/rasterize_meshes.h b/pytorch3d/csrc/rasterize_meshes/rasterize_meshes.h index c722492a..0461e726 100644 --- a/pytorch3d/csrc/rasterize_meshes/rasterize_meshes.h +++ b/pytorch3d/csrc/rasterize_meshes/rasterize_meshes.h @@ -15,6 +15,7 @@ RasterizeMeshesNaiveCpu( const torch::Tensor& face_verts, const torch::Tensor& mesh_to_face_first_idx, const torch::Tensor& num_faces_per_mesh, + const torch::Tensor& clipped_faces_neighbor_idx, const std::tuple image_size, const float blur_radius, const int faces_per_pixel, @@ -28,6 +29,7 @@ RasterizeMeshesNaiveCuda( const at::Tensor& face_verts, const at::Tensor& mesh_to_face_first_idx, const at::Tensor& num_faces_per_mesh, + const torch::Tensor& clipped_faces_neighbor_idx, const std::tuple image_size, const float blur_radius, const int num_closest, @@ -48,6 +50,12 @@ RasterizeMeshesNaiveCuda( // the batch where N is the batch size. // num_faces_per_mesh: LongTensor of shape (N) giving the number of faces // for each mesh in the batch. +// clipped_faces_neighbor_idx: LongTensor of shape (F,) giving the +// index of the neighboring face for each face which was clipped to a +// quadrilateral and then divided into two triangles. +// e.g. for a face f partially behind the image plane which is split into +// two triangles (t1, t2): clipped_faces_neighbor_idx[t1_idx] = t2_idx +// Faces which are not clipped and subdivided are set to -1. // image_size: Tuple (H, W) giving the size in pixels of the output // image to be rasterized. // blur_radius: float distance in NDC coordinates uses to expand the face @@ -90,6 +98,7 @@ RasterizeMeshesNaive( const torch::Tensor& face_verts, const torch::Tensor& mesh_to_face_first_idx, const torch::Tensor& num_faces_per_mesh, + const torch::Tensor& clipped_faces_neighbor_idx, const std::tuple image_size, const float blur_radius, const int faces_per_pixel, @@ -106,6 +115,7 @@ RasterizeMeshesNaive( face_verts, mesh_to_face_first_idx, num_faces_per_mesh, + clipped_faces_neighbor_idx, image_size, blur_radius, faces_per_pixel, @@ -120,6 +130,7 @@ RasterizeMeshesNaive( face_verts, mesh_to_face_first_idx, num_faces_per_mesh, + clipped_faces_neighbor_idx, image_size, blur_radius, faces_per_pixel, @@ -306,6 +317,7 @@ std::tuple RasterizeMeshesFineCuda( const torch::Tensor& face_verts, const torch::Tensor& bin_faces, + const torch::Tensor& clipped_faces_neighbor_idx, const std::tuple image_size, const float blur_radius, const int bin_size, @@ -322,6 +334,12 @@ RasterizeMeshesFineCuda( // in NDC coordinates in the range [-1, 1]. // bin_faces: int32 Tensor of shape (N, B, B, M) giving the indices of faces // that fall into each bin (output from coarse rasterization). +// clipped_faces_neighbor_idx: LongTensor of shape (F,) giving the +// index of the neighboring face for each face which was clipped to a +// quadrilateral and then divided into two triangles. +// e.g. for a face f partially behind the image plane which is split into +// two triangles (t1, t2): clipped_faces_neighbor_idx[t1_idx] = t2_idx +// Faces which are not clipped and subdivided are set to -1. // image_size: Tuple (H, W) giving the size in pixels of the output // image to be rasterized. // blur_radius: float distance in NDC coordinates uses to expand the face @@ -364,6 +382,7 @@ std::tuple RasterizeMeshesFine( const torch::Tensor& face_verts, const torch::Tensor& bin_faces, + const torch::Tensor& clipped_faces_neighbor_idx, const std::tuple image_size, const float blur_radius, const int bin_size, @@ -378,6 +397,7 @@ RasterizeMeshesFine( return RasterizeMeshesFineCuda( face_verts, bin_faces, + clipped_faces_neighbor_idx, image_size, blur_radius, bin_size, @@ -411,6 +431,12 @@ RasterizeMeshesFine( // the batch where N is the batch size. // num_faces_per_mesh: LongTensor of shape (N) giving the number of faces // for each mesh in the batch. +// clipped_faces_neighbor_idx: LongTensor of shape (F,) giving the +// index of the neighboring face for each face which was clipped to a +// quadrilateral and then divided into two triangles. +// e.g. for a face f partially behind the image plane which is split into +// two triangles (t1, t2): clipped_faces_neighbor_idx[t1_idx] = t2_idx +// Faces which are not clipped and subdivided are set to -1. // image_size: Tuple (H, W) giving the size in pixels of the output // image to be rasterized. // blur_radius: float distance in NDC coordinates uses to expand the face @@ -456,6 +482,7 @@ RasterizeMeshes( const torch::Tensor& face_verts, const torch::Tensor& mesh_to_face_first_idx, const torch::Tensor& num_faces_per_mesh, + const torch::Tensor& clipped_faces_neighbor_idx, const std::tuple image_size, const float blur_radius, const int faces_per_pixel, @@ -477,6 +504,7 @@ RasterizeMeshes( return RasterizeMeshesFine( face_verts, bin_faces, + clipped_faces_neighbor_idx, image_size, blur_radius, bin_size, @@ -490,6 +518,7 @@ RasterizeMeshes( face_verts, mesh_to_face_first_idx, num_faces_per_mesh, + clipped_faces_neighbor_idx, image_size, blur_radius, faces_per_pixel, diff --git a/pytorch3d/csrc/rasterize_meshes/rasterize_meshes_cpu.cpp b/pytorch3d/csrc/rasterize_meshes/rasterize_meshes_cpu.cpp index 3160e685..4b925f02 100644 --- a/pytorch3d/csrc/rasterize_meshes/rasterize_meshes_cpu.cpp +++ b/pytorch3d/csrc/rasterize_meshes/rasterize_meshes_cpu.cpp @@ -99,11 +99,31 @@ auto ComputeFaceAreas(const torch::Tensor& face_verts) { return face_areas; } +// Helper function to use with std::find_if to find the index of any +// values in the top k struct which match a given idx. +struct IsNeighbor { + IsNeighbor(int neighbor_idx) { + this->neighbor_idx = neighbor_idx; + } + bool operator()(std::tuple elem) { + return (std::get<1>(elem) == neighbor_idx); + } + int neighbor_idx; +}; + +// Function to sort based on the z distance in the top K queue +bool SortTopKByZdist( + std::tuple a, + std::tuple b) { + return std::get<0>(a) < std::get<0>(b); +} + std::tuple RasterizeMeshesNaiveCpu( const torch::Tensor& face_verts, const torch::Tensor& mesh_to_face_first_idx, const torch::Tensor& num_faces_per_mesh, + const torch::Tensor& clipped_faces_neighbor_idx, const std::tuple image_size, const float blur_radius, const int faces_per_pixel, @@ -139,6 +159,7 @@ RasterizeMeshesNaiveCpu( auto zbuf_a = zbuf.accessor(); auto pix_dists_a = pix_dists.accessor(); auto barycentric_coords_a = barycentric_coords.accessor(); + auto neighbor_idx_a = clipped_faces_neighbor_idx.accessor(); auto face_bboxes = ComputeFaceBoundingBoxes(face_verts); auto face_bboxes_a = face_bboxes.accessor(); @@ -168,10 +189,11 @@ RasterizeMeshesNaiveCpu( // X coordinate of the left of the pixel. const float xf = PixToNonSquareNdc(xidx, W, H); - // Use a priority queue to hold values: + + // Use a deque to hold values: // (z, idx, r, bary.x, bary.y. bary.z) - std::priority_queue> - q; + // Sort the deque as needed to mimic a priority queue. + std::deque> q; // Loop through the faces in the mesh. for (int f = face_start_idx; f < face_stop_idx; ++f) { @@ -240,15 +262,58 @@ RasterizeMeshesNaiveCpu( if (!inside && dist >= blur_radius) { continue; } - // The current pixel lies inside the current face. - q.emplace(pz, f, signed_dist, bary_clip.x, bary_clip.y, bary_clip.z); + + // Handle the case where a face (f) partially behind the image plane + // is clipped to a quadrilateral and then split into two faces (t1, + // t2). In this case we: + // 1. Find the index of the neighbor (e.g. for t1 need index of t2) + // 2. Check if the neighbor (t2) is already in the top K faces + // 3. If yes, compare the distance of the pixel to t1 with the + // distance to t2. + // 4. If dist_t1 < dist_t2, overwrite the values for t2 in the top K + // faces. + const int neighbor_idx = neighbor_idx_a[f]; + int idx_top_k = -1; + + // Check if neighboring face is already in the top K. + if (neighbor_idx != -1) { + const auto it = + std::find_if(q.begin(), q.end(), IsNeighbor(neighbor_idx)); + // Get the index of the element from the iterator + idx_top_k = (it != q.end()) ? it - q.begin() : idx_top_k; + } + + // If idx_top_k idx is not -1 then it is in the top K struct. + if (idx_top_k != -1) { + // If dist of current face is less than neighbor, overwrite + // the neighbor face values in the top K struct. + const auto neighbor = q[idx_top_k]; + const float dist_neighbor = std::abs(std::get<2>(neighbor)); + if (dist < dist_neighbor) { + // Overwrite the neighbor face values. + q[idx_top_k] = { + pz, f, signed_dist, bary_clip.x, bary_clip.y, bary_clip.z}; + } + } else { + // Handle as a normal face. + // The current pixel lies inside the current face. + // Add at the end of the deque. + q.emplace_back( + pz, f, signed_dist, bary_clip.x, bary_clip.y, bary_clip.z); + } + + // Sort the deque inplace based on the z distance + // to mimic using a priority queue. + std::sort(q.begin(), q.end(), SortTopKByZdist); if (static_cast(q.size()) > K) { - q.pop(); + // remove the last value + q.pop_back(); } } while (!q.empty()) { - auto t = q.top(); - q.pop(); + // Loop through and add values to the output tensors + auto t = q.back(); + q.pop_back(); const int i = q.size(); zbuf_a[n][yi][xi][i] = std::get<0>(t); face_idxs_a[n][yi][xi][i] = std::get<1>(t); diff --git a/pytorch3d/renderer/mesh/rasterize_meshes.py b/pytorch3d/renderer/mesh/rasterize_meshes.py index d4cb24ab..34a2f612 100644 --- a/pytorch3d/renderer/mesh/rasterize_meshes.py +++ b/pytorch3d/renderer/mesh/rasterize_meshes.py @@ -9,6 +9,12 @@ import torch # pyre-fixme[21]: Could not find name `_C` in `pytorch3d`. from pytorch3d import _C +from .clip import ( + ClipFrustum, + clip_faces, + convert_clipped_rasterization_to_original_faces, +) + # TODO make the epsilon user configurable kEpsilon = 1e-8 @@ -28,6 +34,8 @@ def rasterize_meshes( perspective_correct: bool = False, clip_barycentric_coords: bool = False, cull_backfaces: bool = False, + z_clip_value: Optional[float] = None, + cull_to_frustum: bool = False, ): """ Rasterize a batch of meshes given the shape of the desired output image. @@ -67,7 +75,8 @@ def rasterize_meshes( will be raised. This should not affect the output values, but can affect the memory usage in the forward pass. perspective_correct: Bool, Whether to apply perspective correction when computing - barycentric coordinates for pixels. + barycentric coordinates for pixels. This should be set to True if a perspective + camera is used. cull_backfaces: Bool, Whether to only rasterize mesh faces which are visible to the camera. This assumes that vertices of front-facing triangles are ordered in an anti-clockwise @@ -76,6 +85,15 @@ def rasterize_meshes( direction. NOTE: This will only work if the mesh faces are consistently defined with counter-clockwise ordering when viewed from the outside. + z_clip_value: if not None, then triangles will be clipped (and possibly + subdivided into smaller triangles) such that z >= z_clip_value. + This avoids camera projections that go to infinity as z->0. + Default is None as clipping affects rasterization speed and + should only be turned on if explicitly needed. + See clip.py for all the extra computation that is required. + cull_to_frustum: if True, triangles outside the view frustum will be culled. + Culling involves removing all faces which fall outside view frustum. + Default is False so that it is turned on only when needed. Returns: 4-element tuple containing @@ -141,6 +159,42 @@ def rasterize_meshes( im_size = (image_size, image_size) max_image_size = image_size + clipped_faces_neighbor_idx = None + + if z_clip_value is not None or cull_to_frustum: + # Cull faces outside the view frustum, and clip faces that are partially + # behind the camera into the portion of the triangle in front of the + # camera. This may change the number of faces + frustum = ClipFrustum( + left=-1, + right=1, + top=-1, + bottom=1, + perspective_correct=perspective_correct, + z_clip_value=z_clip_value, + cull=cull_to_frustum, + ) + clipped_faces = clip_faces( + face_verts, mesh_to_face_first_idx, num_faces_per_mesh, frustum=frustum + ) + face_verts = clipped_faces.face_verts + mesh_to_face_first_idx = clipped_faces.mesh_to_face_first_idx + num_faces_per_mesh = clipped_faces.num_faces_per_mesh + + # For case 4 clipped triangles (where a big triangle is split in two smaller triangles), + # need the index of the neighboring clipped triangle as only one can be in + # in the top K closest faces in the rasterization step. + clipped_faces_neighbor_idx = clipped_faces.clipped_faces_neighbor_idx + + if clipped_faces_neighbor_idx is None: + # Set to the default which is all -1s. + clipped_faces_neighbor_idx = torch.full( + size=(face_verts.shape[0],), + fill_value=-1, + device=meshes.device, + dtype=torch.int64, + ) + # TODO: Choose naive vs coarse-to-fine based on mesh size and image size. if bin_size is None: if not verts_packed.is_cuda: @@ -172,10 +226,11 @@ def rasterize_meshes( max_faces_per_bin = int(max(10000, meshes._F / 5)) # pyre-fixme[16]: `_RasterizeFaceVerts` has no attribute `apply`. - return _RasterizeFaceVerts.apply( + pix_to_face, zbuf, barycentric_coords, dists = _RasterizeFaceVerts.apply( face_verts, mesh_to_face_first_idx, num_faces_per_mesh, + clipped_faces_neighbor_idx, im_size, blur_radius, faces_per_pixel, @@ -186,6 +241,17 @@ def rasterize_meshes( cull_backfaces, ) + if z_clip_value is not None or cull_to_frustum: + # If faces were clipped, map the rasterization result to be in terms of the + # original unclipped faces. This may involve converting barycentric + # coordinates + outputs = convert_clipped_rasterization_to_original_faces( + pix_to_face, barycentric_coords, clipped_faces + ) + pix_to_face, barycentric_coords = outputs + + return pix_to_face, zbuf, barycentric_coords, dists + class _RasterizeFaceVerts(torch.autograd.Function): """ @@ -216,9 +282,10 @@ class _RasterizeFaceVerts(torch.autograd.Function): # pyre-fixme[14]: `forward` overrides method defined in `Function` inconsistently. def forward( ctx, - face_verts, - mesh_to_face_first_idx, - num_faces_per_mesh, + face_verts: torch.Tensor, + mesh_to_face_first_idx: torch.Tensor, + num_faces_per_mesh: torch.Tensor, + clipped_faces_neighbor_idx: torch.Tensor, image_size: Union[List[int], Tuple[int, int]] = (256, 256), blur_radius: float = 0.01, faces_per_pixel: int = 0, @@ -227,12 +294,15 @@ class _RasterizeFaceVerts(torch.autograd.Function): perspective_correct: bool = False, clip_barycentric_coords: bool = False, cull_backfaces: bool = False, + z_clip_value: Optional[float] = None, + cull_to_frustum: bool = True, ): # pyre-fixme[16]: Module `pytorch3d` has no attribute `_C`. pix_to_face, zbuf, barycentric_coords, dists = _C.rasterize_meshes( face_verts, mesh_to_face_first_idx, num_faces_per_mesh, + clipped_faces_neighbor_idx, image_size, blur_radius, faces_per_pixel, @@ -242,6 +312,7 @@ class _RasterizeFaceVerts(torch.autograd.Function): clip_barycentric_coords, cull_backfaces, ) + ctx.save_for_backward(face_verts, pix_to_face) ctx.mark_non_differentiable(pix_to_face) ctx.perspective_correct = perspective_correct @@ -253,6 +324,7 @@ class _RasterizeFaceVerts(torch.autograd.Function): grad_face_verts = None grad_mesh_to_face_first_idx = None grad_num_faces_per_mesh = None + grad_clipped_faces_neighbor_idx = None grad_image_size = None grad_radius = None grad_faces_per_pixel = None @@ -275,6 +347,7 @@ class _RasterizeFaceVerts(torch.autograd.Function): grad_face_verts, grad_mesh_to_face_first_idx, grad_num_faces_per_mesh, + grad_clipped_faces_neighbor_idx, grad_image_size, grad_radius, grad_faces_per_pixel, @@ -339,6 +412,9 @@ def rasterize_meshes_python( perspective_correct: bool = False, clip_barycentric_coords: bool = False, cull_backfaces: bool = False, + z_clip_value: Optional[float] = None, + cull_to_frustum: bool = True, + clipped_faces_neighbor_idx: Optional[torch.Tensor] = None, ): """ Naive PyTorch implementation of mesh rasterization with the same inputs and @@ -359,6 +435,26 @@ def rasterize_meshes_python( mesh_to_face_first_idx = meshes.mesh_to_faces_packed_first_idx() num_faces_per_mesh = meshes.num_faces_per_mesh() + if z_clip_value is not None or cull_to_frustum: + # Cull faces outside the view frustum, and clip faces that are partially + # behind the camera into the portion of the triangle in front of the + # camera. This may change the number of faces + frustum = ClipFrustum( + left=-1, + right=1, + top=-1, + bottom=1, + perspective_correct=perspective_correct, + z_clip_value=z_clip_value, + cull=cull_to_frustum, + ) + clipped_faces = clip_faces( + faces_verts, mesh_to_face_first_idx, num_faces_per_mesh, frustum=frustum + ) + faces_verts = clipped_faces.face_verts + mesh_to_face_first_idx = clipped_faces.mesh_to_face_first_idx + num_faces_per_mesh = clipped_faces.num_faces_per_mesh + # Intialize output tensors. face_idxs = torch.full( (N, H, W, K), fill_value=-1, dtype=torch.int64, device=device @@ -468,26 +564,55 @@ def rasterize_meshes_python( # Points inside the triangle have negative distance. dist = point_triangle_distance(pxy, v0[:2], v1[:2], v2[:2]) - signed_dist = dist * -1.0 if inside else dist - # Add an epsilon to prevent errors when comparing distance # to blur radius. if not inside and dist >= blur_radius: continue - top_k_points.append((pz, f, bary, signed_dist)) + # Handle the case where a face (f) partially behind the image plane is + # clipped to a quadrilateral and then split into two faces (t1, t2). + top_k_idx = -1 + if ( + clipped_faces_neighbor_idx is not None + and clipped_faces_neighbor_idx[f] != -1 + ): + neighbor_idx = clipped_faces_neighbor_idx[f] + # See if neighbor_idx is in top_k and find index + top_k_idx = [ + i + for i, val in enumerate(top_k_points) + if val[1] == neighbor_idx + ] + top_k_idx = top_k_idx[0] if len(top_k_idx) > 0 else -1 + + if top_k_idx != -1 and dist < top_k_points[top_k_idx][3]: + # Overwrite the neighbor with current face info + top_k_points[top_k_idx] = (pz, f, bary, dist, inside) + else: + # Handle as a normal face + top_k_points.append((pz, f, bary, dist, inside)) + top_k_points.sort() if len(top_k_points) > K: top_k_points = top_k_points[:K] # Save to output tensors. - for k, (pz, f, bary, dist) in enumerate(top_k_points): + for k, (pz, f, bary, dist, inside) in enumerate(top_k_points): zbuf[n, yi, xi, k] = pz face_idxs[n, yi, xi, k] = f bary_coords[n, yi, xi, k, 0] = bary[0] bary_coords[n, yi, xi, k, 1] = bary[1] bary_coords[n, yi, xi, k, 2] = bary[2] - pix_dists[n, yi, xi, k] = dist + # Write the signed distance + pix_dists[n, yi, xi, k] = -dist if inside else dist + + if z_clip_value is not None or cull_to_frustum: + # If faces were clipped, map the rasterization result to be in terms of the + # original unclipped faces. This may involve converting barycentric + # coordinates + (face_idxs, bary_coords,) = convert_clipped_rasterization_to_original_faces( + face_idxs, bary_coords, clipped_faces + ) return face_idxs, zbuf, bary_coords, pix_dists diff --git a/pytorch3d/renderer/mesh/rasterizer.py b/pytorch3d/renderer/mesh/rasterizer.py index 3042b071..413a635e 100644 --- a/pytorch3d/renderer/mesh/rasterizer.py +++ b/pytorch3d/renderer/mesh/rasterizer.py @@ -27,6 +27,8 @@ class RasterizationSettings: "perspective_correct", "clip_barycentric_coords", "cull_backfaces", + "z_clip_value", + "cull_to_frustum", ] def __init__( @@ -36,9 +38,13 @@ class RasterizationSettings: faces_per_pixel: int = 1, bin_size: Optional[int] = None, max_faces_per_bin: Optional[int] = None, - perspective_correct: bool = False, + # set perspective_correct = None so that the + # value can be inferred correctly from the Camera type + perspective_correct: Optional[bool] = None, clip_barycentric_coords: Optional[bool] = None, cull_backfaces: bool = False, + z_clip_value: Optional[float] = None, + cull_to_frustum: bool = False, ): self.image_size = image_size self.blur_radius = blur_radius @@ -48,6 +54,8 @@ class RasterizationSettings: self.perspective_correct = perspective_correct self.clip_barycentric_coords = clip_barycentric_coords self.cull_backfaces = cull_backfaces + self.z_clip_value = z_clip_value + self.cull_to_frustum = cull_to_frustum class MeshRasterizer(nn.Module): @@ -139,12 +147,19 @@ class MeshRasterizer(nn.Module): if clip_barycentric_coords is None: clip_barycentric_coords = raster_settings.blur_radius > 0.0 - # If not specified, infer perspective_correct from the camera + # If not specified, infer perspective_correct and z_clip_value from the camera cameras = kwargs.get("cameras", self.cameras) if raster_settings.perspective_correct is not None: perspective_correct = raster_settings.perspective_correct else: perspective_correct = cameras.is_perspective() + if raster_settings.z_clip_value is not None: + z_clip = raster_settings.z_clip_value + else: + znear = cameras.get_znear() + if isinstance(znear, torch.Tensor): + znear = znear.min().item() + z_clip = None if not perspective_correct or znear is None else znear / 2 pix_to_face, zbuf, bary_coords, dists = rasterize_meshes( meshes_screen, @@ -153,9 +168,11 @@ class MeshRasterizer(nn.Module): faces_per_pixel=raster_settings.faces_per_pixel, bin_size=raster_settings.bin_size, max_faces_per_bin=raster_settings.max_faces_per_bin, - perspective_correct=perspective_correct, clip_barycentric_coords=clip_barycentric_coords, + perspective_correct=perspective_correct, cull_backfaces=raster_settings.cull_backfaces, + z_clip_value=z_clip, + cull_to_frustum=raster_settings.cull_to_frustum, ) return Fragments( pix_to_face=pix_to_face, zbuf=zbuf, bary_coords=bary_coords, dists=dists diff --git a/tests/bm_rasterize_meshes.py b/tests/bm_rasterize_meshes.py index 91addb17..a6948247 100644 --- a/tests/bm_rasterize_meshes.py +++ b/tests/bm_rasterize_meshes.py @@ -66,7 +66,7 @@ def bm_rasterize_meshes() -> None: # Square and non square cases image_size = [64, 128, 512, (512, 256), (256, 512)] blur = [1e-6] - faces_per_pixel = [50] + faces_per_pixel = [40] test_cases = product(num_meshes, ico_level, image_size, blur, faces_per_pixel) for case in test_cases: @@ -87,6 +87,35 @@ def bm_rasterize_meshes() -> None: warmup_iters=1, ) + # Test a subset of the cases with the + # image plane intersecting the mesh. + kwargs_list = [] + num_meshes = [8, 16] + # Square and non square cases + image_size = [64, 128, 512, (512, 256), (256, 512)] + dist = [3, 0.8, 0.5] + test_cases = product(num_meshes, dist, image_size) + + for case in test_cases: + n, d, im = case + kwargs_list.append( + { + "num_meshes": n, + "ico_level": 4, + "image_size": im, + "blur_radius": 1e-6, + "faces_per_pixel": 40, + "dist": d, + } + ) + + benchmark( + TestRasterizeMeshes.bm_rasterize_meshes_with_clipping, + "RASTERIZE_MESHES_CUDA_CLIPPING", + kwargs_list, + warmup_iters=1, + ) + if __name__ == "__main__": bm_rasterize_meshes() diff --git a/tests/data/room.jpg b/tests/data/room.jpg new file mode 100644 index 0000000000000000000000000000000000000000..29b9441fff2631a0ec7a4e74cbeab9e27e3c0631 GIT binary patch literal 7940 zcmeHM2UL?uyPgzc=t<~B48282C<00eF(OR`1SzYb7*K�dWO%i5L_`6r)HH1fqa| zBFN7g6_BRDT99Tz7g10YT~R=#bH7le{j%P3?mhRO|2Q+eZJwFDdB2%&-U(xm$JzkK z+QP~Lz~BJDpbEzR228uHJ9g~|+PY&WURR3<3@KJ?kVwESTV&yCz{`TNhDsz6i9}CB z!*NbD)FT=}I#J({xS7j|bf`i_mqhgDa!&)I50@M1lXSS8q(>UZ(}@#&vlq#eYbWZF z#&KZmDOd@3QH#)vP`qd~nvai{Ur0h&NKjBnR#Xftp&+NIC@+V{D{JU4RaVte!{fEg zv~>)OjE#+zmQpOtms{wsFgD^&0^{T36A}=T5f+v)Qo*Yj{kO|l1;C==mWa~`pa5W4 z7#s^5s{&F0HeSY~GC`0CIE;te1%m+u^f#lx5J(>QLR~FD00G4x3&%oqXhnJU`a1HA z@e7JURatw%VrNXr>9hrL-#QJKpE*Vb0M*am;28SNx`2SFU86eQ9oAOJQD3$Ia;i(W z4)J_a7W6%ErP^ewDy@P67bUlJA$xI>&} z)_5p*ggnbKOX0nz{Bf26fZsNn#_K+qWBjJV>O>ar{H9F_)m!;+d<}Ez$WfPnTV3B% zo-2y8=hd&hUp1#DoaPbzLet7@$)GYYf3^kE(Atc)dBzC?Vz0r^5BM3PIjiI6Dw+t$ z|NcfBFupHH+oEz#F~WcZHEq|4Aa6! z_mHH$eVwgx^qD=gWkmATZFx>ceZBOr_x?4*iHFBS{IZHm3xUp?BlJ=R$%>>+N9ceVubBr+xzCPj;O4!g|)ov(z*7RBAqp3KrQ)Nb%=ucXn?v( zNUu_$>xdy`{i~YJA?g4{h5n&?sOtQgkq*l~+aAjL_LPpO&J2-M$yG|-CWfh^t8;Ly zEvW+CN~=sZ8^@mNrx>Y@f#7c0Cm#eiN^|sg`nENf=N@oI<^9j7ni1covTlnjhg4)9 z??%=&xLQq?vr?XwfS1YP&^FI49yE@m>tl&Ll=yC*cl#q#_Wy?%F=FmgmybTW z?bojVDJiWPbwSx&gV)YF!V#Aihn%S>Wh+~0SUFg;99ie%rj&R`Rz|Z{cCby(CWo$g z$LkqQDOTTal>65rb#0gd4$b+Q+X95tpY*-gaA;qgX!z!Fi0YEo$~%TxTXr;6-ZQ)s zFnrgzFEcq@ATuG$|8GuNx1`g1~1V;@mt zmlk`xsxUO8=9bx0^D}Kk&-Jx0(!!R-Q*$eAWa&2i8#PL|me(8;z47wd`nQB72ckFi zpNrw@?JZYaL|zaAk6$-@rC{jL!@4DI!%yR+)S zq&6#Z z+Yzti#_1RgTw61I>P(012T9GiXO}ahos{-r{T_E~Ap^DDuWwqZ`2Gy7b7w1$6K%vX z$ka?sVPMJIc4fa|=g&6T$4qOB>cR!voG$12#TBPZ)EE`7U@f%`+;#KMO^yLy*5ri^ zwq4qn>{hF&@W^7cAM4lO>`TA?JV3Vlc!px)#+Ga-Pzo6CbA^pR+^Im3Jao+MwqUB?2 z=sHpYhx%4dN(cYF5T5`0vwfVQ(e>j8gbMe7;6Xw`pV~?=kS-6M4xh00o+rp3F+}qK z{L~R)COK8+EcZ@HYQCLq`# z+)WQU-kc@$)iim+}% zRD|8Hiw&v2QM-cmM9!uIPm_ynsPk+JUX2kOjJ6~Dm-Vhtl*>Gx!E3PoO1NnsJA8$_ z5c{@dm;IiYZ01sm@YkX!G}2nZ?c_RrP+sxs&9zW%yNwqh{yTMt9WAZIJT&5S0y+YU7YKATDj5ec|y)hn(|AUw#8$J!2(0bG$3Em;sP znIDZVh(<${+xS6&CZ>QN!=*O3d>d50AJl|D+Uisirdu^3aCY;$P4KEFaA_4D{~1X+ zyV4MENzj2MJ)tsh5pL#_Sv9az7{_hw(Q**$8mVyDf{D7)Pljod89w5F;twC!gL{8u zz&D(KV4HSRq!uI!Gf^z@TJ(WbQ8-<^OPYTdv)rGJPG#_s4-AQEGGlxW3?QvNzu?Fb zsxrxon2R1*VW{oF5gtdBGI|_A^7bBB!Y^vE3GUEppZ+lfv$D>(_s7}^p!9>tUhcaM z^F!eAv=4^-0UwUyC*GKF;eP_(pF~1<#6NsB!o-~86w|m6*pNK~3Xq};fE#BbA({94 z4oA&_Lg%Oeu@b@)h?S5a3kf?Rf$P{@B&2^Y!XM{jzZ-crrbFGO!{dk<1m-ioZ=tO#Kx`B=ep@XYQ2qE5S5;2e zn}GvBm@!M8t)N**;x`Ll9)2M#0F9E9ScC;}VI+i$ZL))?00T2QWkO&F$qy zFXJC{bH0dVlq-zhwbGy&iF)4>GT+Ck`1>jDxT;_I?n{FD>t41-zUs#pDsq2Im~b3{ z;FaS(BnQ+=hnN>@_;yO#H)v^k-VQoy3mYFDm8Ov7Z z@X=qqRqy*^-Jyo2t;X5B&n~xa;CP#~^5`_aQ!M#t)zG7V`fO8$hv?gp^o_M|l}kQa z+x3t#N}sj3s3nA$p-oM@9>!Owb;;SLZoESFd6ZVML`l5bnjVUzBv6s7^0tkniVrX)6o;<<%+Pre;jCeoiCTN7f3 z3OHTGe@MVQk@1>$1g=13SsA4gqgUZ|BFs+87=hH}C(rA4FCA73Lka2f$^DvA(l{XU z#xfQkYq52yKfT8LB^hT4ROK*T{7Wo1nA}JpQ<1W-t}4FQRV((4H8J(^sX1H6xoe|! zaVK-9xoXo<(Fk$0f43$U19aCO+wQZ^haa>)_dc45QY9cn2h{{XS?HJ5%XXW?Fy0uW zn-9=Td8L=TB+@_VsNVydlh%5s!S@86dc&aNk5+j5{idkGPfv*+1E^!WO9Oq-7!ZLr zszownW5L?x$ZLC2ltw#`tZ*JZELHPIrB2M%ipC}HM3?5DP(8heqJ}@21vf|H(5s}8 zYh4PKbE9-U`cD;Uly^DP5%LaCJZ~FjA1O z$C<^6tIE?4&b{QWG34BGy}2S!KPWen*7eyrr^Ko2a&wsGh;nnthl5iZVYPi`|U=d-hKban#vY%(Ib$zhS^Zi71V$YoZk^?>G= z)aReX#Fi|nbywb-ipW}?Q2B zioq(hX~PqMc2g!9g~t>qa$J})QG=|DlnvYVB7qOQ!llOLgSa3M$+o&-v6m=TBq+Ze zNPtS2jO95xWi4(niJ|ELAPqs*2Q zDS@@8YZ=H{2pk+xJGE9EBdEXq_P|94^ec>2)aHRQrh$ewjzd5`(@*BtYl)Qckr;dO EUjSk=0ssI2 literal 0 HcmV?d00001 diff --git a/tests/data/test_render_mesh_clipped_cam_dist=0.5.jpg b/tests/data/test_render_mesh_clipped_cam_dist=0.5.jpg new file mode 100644 index 0000000000000000000000000000000000000000..8ea2545edd28f9f8d6ebaa2de13dcffdd0a85d75 GIT binary patch literal 7393 zcmeG>4Lp?T_Rq{1gH%3V@;yt7L6oRn+KjTpB4rq=QCH-au~#LMn;2=;HWh_Fu42Y8 zzLnacR2!+*ZW1F(Xst@V8p_9rVc!2UBi;UcyKQ%W-QB&vdtV;!nRCv2=6pWqybrV) z{RybJIlDRo3MrJC{n4zq#ral|d*C86{>*(oy z^s$xcM@IAJ>gka_F`Ku*+Sb;VXll3g(}kZ|+1M_W3c)BUDk{%V)|xp}YoVc@;lh9U zLMwr)0v2Ea4r2tcsu-Lq2E74@P)|HYc7azH28)x!%M%n7m1aPOQWb#3;BZ(u93C$x z2f1S)9muKT)eNnd$*WTX2u5MZ!aYYb73MBKcTK}f@cd(&!0qvhN}6-Dv~`S4Oi5{Okno5dkvpUQx@+%O3H$bceSnpm zl6vg;iIaaz<7Q>^a&q(X3(kMTFD<)pvHVii^&2<8tN#9nTMc&`o9^9je$dkK=<$=z zr(NAWy)OobhDSuBFUQ8Ec445M)7O+?KeI~}+J%*q!^sh(c44qP;lioP;SH_j)s|5S z0b%M!3->4>%a3NByQVnT#_PF8;C6wM=Et@j#xJC%$t-(6!{UFHWv>nUyIn0n8Ha(v z!>Ix?7)~Q`{O4lPjl7Z=+VE;!l+kx zW>!(6pwhJfVXC=9qB&vk9Q z^G{BfcblD99r$pYU3mR2U#H7ow4QuYkFBw(HQdxUk)QfA&$`Oh;rg9M2$^}^oSEHL zY8{^#c3tc@r*}O*+Hu5QPt{NJzU#Kr+a4Fx<5~ai+YqlVzmn~iIQY`UYwl2l zWzTTYom`VRF{;KYx*~h5zIG?sBaDzCme|!8TTpZ!kBNLW!*!M z96udsBig6r@*>g}lS%QGG@eHRJ0VN#RU?VjjI*NBP*6IrqN^i#ijvBik?KPhy3ZdQ zb0e3|LqXT2di*W(Tv7Uoi3v+$i!XENc`9ZLt8g-FGYZIxu9H03-BGiXyIpAavL1?vvw(^;y^;yD5 z6bwBe+k*_&ei{6bZ0~S?t4uOSF)wclP8s5fX+#-c8szh2)qZpibRX~^!B}vAS7%p*4NSqudVd>`c`A$&uV-#FXFxX8VY9Wie zJ>|YFTs$bYOPpIBnb$wCqCC+-y?msA2B;L|oy$82+^9@kz-W9xZo1GAXIBKixM@Pj z?Q@HLF)naaObT!&@GKE1(!IMYTHw9|sd+2URn3H&fZPS-XqhZg5a1^EqSz1TWU7!_ohifd62L!|(1 zhg`|H?d7D=Q}CARiS4Yo!n9w(9?+3W6xWUk$}6t4rz4)&Cn$(&Dz@ACqt&BoP-f#} z-hToGMCNmKM_rny{m<(DZM*B@RU@Ipx!yr2@NIVkPm_Km(@Es0$>0y)Mb7jvUyPyv zv&d~4vc>SBiaiv7K$G9OWI;{$NSXADA>q>a!uwksC`1NUcSDEg5$lrjHPAn?GM9jHeJYBqe=Hz zv!!YIK?3|{l1#@~7Hn9sl^%c4Xpg-v8!7am%H{(^dAc`G$N1Modur?1A=`*Jvx22{ zy;mA@K;nMb^?Gll75sgdbmAefz%dl$-AlIh{9ZY!MACncss5Ddc-gpj{&{BMRvrj< z86XxT#4E9|v;C57kdF0&J?x!3@Ebd{#!hz>BvrgrEby6zI5RUFSOm9UF43GHB*j-# zd;h}jZS8T?6Nsz(CjyE)BH6U`M=1ZOVMOa>}8YT3M)Nv(5*9iCW))g2w|A zY{;=`C>R$U!QIpF%jIb%%SL-yYI+iErTQ(C#{9i=?bqgH-Al#aE+^lXes9CQPQ?G5 zl$0eWXhHtaBhVkYDgMe4=zmS|0}X16PypWt(><6M+nr#;>5a`v&QMMK!uLB5)8i<3 zxG1IAV~L$nlbuM7O+SZ%ON_;zjCx2>R<0slcjtS< z_g3U%0bqAd5}?BwV-m8n=IRHx zEe#eO>EB;tRV8q}xJz)f&*ihv5^AijQFQkgljhz_j2fF&S3-3?>P&Hn8Y5Udvqacm zXT3netG}Noo320Y?%HsMu?Ieeexd7caKyMGKn9X-2E@bXf|K6S;qfttW`cZGG& zy_lQ*;IU@0*Zo^XF-{F4Rr4uCirBtgT)KsIR;ZNu%!nqrb3r0Byla{Cyy+H6A}AoP zKHx64Y89vMcfTO%{8JWE!8aeB)AYzwQUe+1i+n0ZLq-r3@Iyvt7vA7RoZ2{nupel) zw-qayp@887k#&;lGD)YYXwijH3pW;-U$DQ4vq@hE%49c}K}LHzncoX>e?~!ID|35_ zSglojtdK$$+H=RRtYAM|!tD2gBtPc%QU;V6;yb(W^IJL^&t1=_Y4x8HEW6&T{SBy23 zBCR2^oTuU6QtAYsXOMY?5tVw57b&ljGA^s%NUXBhl9Kl2;)1W9msqddOB@aMQ2D&$vDSug zzX%@}yS(%VWZHq`%0iW0EVpmNwqDWvQ>lwqpa0quk2_eo(Y=4Y$eML(5T2latA6_J z!#S&R3$34!n!cyHSb1)%*2-4`J zi;C_PWF3}8ceDD4hon-s&0&3-4LQ`z7@Y80`+b#eX2HC$W~iihSdT~MVks@Xd5u(fNeP@)b*!qY8a&My>si-M`?#Ol-B}+)q4S8DXW9&` zq8Rxmn|)dNx;AUGu`8xl51fd%w;`W2*v*Ye>BjX zk5=Vn(#`?pG(o(oGE-=fbLI4Xo5WceZ1v{e>z+IOkDUAb90_+e+>1(y-AOO%t-Z5F znDNAjTNKYC6&?Hsfp31nIymiXWx?a3szb85-xVH$dxjKY;ZU^GElNaXC=?IH>q4M-VM`P{}77Ms6i|m>W zvWAe%aw_2~Cj)JhWWFV5@H-LFdeY*~?Qn$+J?%o>j*|YOD+%$lEu9Ch>=M(u(%dYD zZm%@SG0MuP_D0v|aEeu-H#_IOJ=I=~v)SA|2Pwt~;CoRWXHbkSH32**JPD_2N~JXV zaCWC|liDB0-^a3{+%O^WYLUfD68@16-LMXOJeegof(++0JUw#PonsuJjm!70zZW6U zN-AX}M&;NmBKc?iEN*Y!mv?5^&id0GRBal6%*gt5WX|etnne{K;YNk;-P;Mh?wid< GAN(Ix=};U1 literal 0 HcmV?d00001 diff --git a/tests/test_rasterize_meshes.py b/tests/test_rasterize_meshes.py index c8ba8298..7a0cebb9 100644 --- a/tests/test_rasterize_meshes.py +++ b/tests/test_rasterize_meshes.py @@ -6,6 +6,8 @@ import unittest import torch from common_testing import TestCaseMixin, get_random_cuda_device from pytorch3d import _C +from pytorch3d.renderer import FoVPerspectiveCameras, look_at_view_transform +from pytorch3d.renderer.mesh import MeshRasterizer, RasterizationSettings from pytorch3d.renderer.mesh.rasterize_meshes import ( rasterize_meshes, rasterize_meshes_python, @@ -1204,3 +1206,50 @@ class TestRasterizeMeshes(TestCaseMixin, unittest.TestCase): torch.cuda.synchronize(device) return rasterize + + @staticmethod + def bm_rasterize_meshes_with_clipping( + num_meshes: int, + ico_level: int, + image_size: int, + blur_radius: float, + faces_per_pixel: int, + dist: float, + ): + device = get_random_cuda_device() + meshes = ico_sphere(ico_level, device) + meshes_batch = meshes.extend(num_meshes) + + settings = RasterizationSettings( + image_size=image_size, + blur_radius=blur_radius, + faces_per_pixel=faces_per_pixel, + z_clip_value=1e-2, + perspective_correct=True, + cull_to_frustum=True, + ) + + # The camera is positioned so that the image plane intersects + # the mesh and some faces are partially behind the image plane. + R, T = look_at_view_transform(dist, 0, 0) + cameras = FoVPerspectiveCameras(device=device, R=R, T=T, fov=90) + rasterizer = MeshRasterizer(raster_settings=settings, cameras=cameras) + + # Transform the meshes to projec them onto the image plane + meshes_screen = rasterizer.transform(meshes_batch) + torch.cuda.synchronize(device) + + def rasterize(): + # Only measure rasterization speed (including clipping) + rasterize_meshes( + meshes_screen, + image_size, + blur_radius, + faces_per_pixel, + z_clip_value=1e-2, + perspective_correct=True, + cull_to_frustum=True, + ) + torch.cuda.synchronize(device) + + return rasterize diff --git a/tests/test_render_implicit.py b/tests/test_render_implicit.py index c4f0cb5e..4c067810 100644 --- a/tests/test_render_implicit.py +++ b/tests/test_render_implicit.py @@ -245,6 +245,7 @@ class TestRenderImplicit(TestCaseMixin, unittest.TestCase): image_size=image_size, blur_radius=1e-3, faces_per_pixel=10, + z_clip_value=None, perspective_correct=False, ), ), diff --git a/tests/test_render_meshes.py b/tests/test_render_meshes.py index 81505776..d97c0d4f 100644 --- a/tests/test_render_meshes.py +++ b/tests/test_render_meshes.py @@ -994,6 +994,7 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase): blur_radius=0.0, faces_per_pixel=1, cull_backfaces=True, + perspective_correct=False, ) # Init shader settings diff --git a/tests/test_render_meshes_clipped.py b/tests/test_render_meshes_clipped.py index 474d730f..805e936e 100644 --- a/tests/test_render_meshes_clipped.py +++ b/tests/test_render_meshes_clipped.py @@ -9,14 +9,161 @@ See pytorch3d/renderer/mesh/clip.py for more details about the clipping process. """ import unittest +from pathlib import Path +import imageio +import numpy as np import torch -from common_testing import TestCaseMixin -from pytorch3d.renderer.mesh import ClipFrustum, clip_faces +from common_testing import TestCaseMixin, load_rgb_image +from pytorch3d.io import save_obj +from pytorch3d.renderer.cameras import FoVPerspectiveCameras, look_at_view_transform +from pytorch3d.renderer.lighting import PointLights +from pytorch3d.renderer.mesh import ( + ClipFrustum, + TexturesUV, + clip_faces, + convert_clipped_rasterization_to_original_faces, +) +from pytorch3d.renderer.mesh.rasterize_meshes import _RasterizeFaceVerts +from pytorch3d.renderer.mesh.rasterizer import MeshRasterizer, RasterizationSettings +from pytorch3d.renderer.mesh.renderer import MeshRenderer +from pytorch3d.renderer.mesh.shader import SoftPhongShader from pytorch3d.structures.meshes import Meshes +# If DEBUG=True, save out images generated in the tests for debugging. +# All saved images have prefix DEBUG_ +DEBUG = False +DATA_DIR = Path(__file__).resolve().parent / "data" + + class TestRenderMeshesClipping(TestCaseMixin, unittest.TestCase): + def load_cube_mesh_with_texture(self, device="cpu", with_grad: bool = False): + verts = torch.tensor( + [ + [-1, 1, 1], + [1, 1, 1], + [1, -1, 1], + [-1, -1, 1], + [-1, 1, -1], + [1, 1, -1], + [1, -1, -1], + [-1, -1, -1], + ], + device=device, + dtype=torch.float32, + requires_grad=with_grad, + ) + + # all faces correctly wound + faces = torch.tensor( + [ + [0, 1, 4], + [4, 1, 5], + [1, 2, 5], + [5, 2, 6], + [2, 7, 6], + [2, 3, 7], + [3, 4, 7], + [0, 4, 3], + [4, 5, 6], + [4, 6, 7], + ], + device=device, + dtype=torch.int64, + ) + + verts_uvs = torch.tensor( + [ + [ + [0, 1], + [1, 1], + [1, 0], + [0, 0], + [0.204, 0.743], + [0.781, 0.743], + [0.781, 0.154], + [0.204, 0.154], + ] + ], + device=device, + dtype=torch.float, + ) + texture_map = load_rgb_image("room.jpg", DATA_DIR).to(device) + textures = TexturesUV( + maps=[texture_map], faces_uvs=faces.unsqueeze(0), verts_uvs=verts_uvs + ) + mesh = Meshes([verts], [faces], textures=textures) + if with_grad: + return mesh, verts + return mesh + + def test_cube_mesh_render(self): + """ + End-End test of rendering a cube mesh with texture + from decreasing camera distances. The camera starts + outside the cube and enters the inside of the cube. + """ + device = torch.device("cuda:0") + mesh = self.load_cube_mesh_with_texture(device) + raster_settings = RasterizationSettings( + image_size=512, + blur_radius=1e-8, + faces_per_pixel=5, + z_clip_value=1e-2, + perspective_correct=True, + bin_size=0, + ) + + # Only ambient, no diffuse or specular + lights = PointLights( + device=device, + ambient_color=((1.0, 1.0, 1.0),), + diffuse_color=((0.0, 0.0, 0.0),), + specular_color=((0.0, 0.0, 0.0),), + location=[[0.0, 0.0, -3.0]], + ) + + renderer = MeshRenderer( + rasterizer=MeshRasterizer(raster_settings=raster_settings), + shader=SoftPhongShader(device=device, lights=lights), + ) + + # Render the cube by decreasing the distance from the camera until + # the camera enters the cube. Check the output looks correct. + images_list = [] + dists = np.linspace(0.1, 2.5, 20)[::-1] + for d in dists: + R, T = look_at_view_transform(d, 0, 0) + T[0, 1] -= 0.1 # move down in the y axis + cameras = FoVPerspectiveCameras(device=device, R=R, T=T, fov=90) + images = renderer(mesh, cameras=cameras) + rgb = images[0, ..., :3].cpu().detach() + filename = "DEBUG_cube_dist=%.1f.jpg" % d + im = (rgb.numpy() * 255).astype(np.uint8) + images_list.append(im) + + # Check one of the images where the camera is inside the mesh + if d == 0.5: + filename = "test_render_mesh_clipped_cam_dist=0.5.jpg" + image_ref = load_rgb_image(filename, DATA_DIR) + self.assertClose(rgb, image_ref, atol=0.05) + + # Save a gif of the output - this should show + # the camera moving inside the cube. + if DEBUG: + gif_filename = ( + "room_original.gif" + if raster_settings.z_clip_value is None + else "room_clipped.gif" + ) + imageio.mimsave(DATA_DIR / gif_filename, images_list, fps=2) + save_obj( + f=DATA_DIR / "cube.obj", + verts=mesh.verts_packed().cpu(), + faces=mesh.faces_packed().cpu(), + ) + @staticmethod def clip_faces(meshes): verts_packed = meshes.verts_packed() @@ -42,6 +189,34 @@ class TestRenderMeshesClipping(TestCaseMixin, unittest.TestCase): ) return clipped_faces + def test_grad(self): + """ + Check that gradient flow is unaffected when the camera is inside the mesh + """ + device = torch.device("cuda:0") + mesh, verts = self.load_cube_mesh_with_texture(device=device, with_grad=True) + raster_settings = RasterizationSettings( + image_size=512, + blur_radius=1e-5, + faces_per_pixel=5, + z_clip_value=1e-2, + perspective_correct=True, + bin_size=0, + ) + + renderer = MeshRenderer( + rasterizer=MeshRasterizer(raster_settings=raster_settings), + shader=SoftPhongShader(device=device), + ) + dist = 0.4 # Camera is inside the cube + R, T = look_at_view_transform(dist, 0, 0) + cameras = FoVPerspectiveCameras(device=device, R=R, T=T, fov=90) + images = renderer(mesh, cameras=cameras) + images.sum().backward() + + # Check gradients exist + self.assertIsNotNone(verts.grad) + def test_case_1(self): """ Case 1: Single triangle fully in front of the image plane (z=0) @@ -350,3 +525,134 @@ class TestRenderMeshesClipping(TestCaseMixin, unittest.TestCase): # barycentric conversion matrix. bary_idx = idx.new_tensor([1, 4, 2, 5, 0, -1, 3, 6]) self.assertClose(clipped_faces.faces_clipped_to_conversion_idx, bary_idx) + + def test_convert_clipped_to_unclipped_case_4(self): + """ + Test with a single case 4 triangle which is clipped into + a quadrilateral and subdivided. + """ + device = "cuda:0" + # fmt: off + verts = torch.tensor( + [ + [-1.0, 0.0, -1.0], # noqa: E241, E201 + [ 0.0, 1.0, -1.0], # noqa: E241, E201 + [ 1.0, 0.0, -1.0], # noqa: E241, E201 + [ 0.0, -1.0, -1.0], # noqa: E241, E201 + [-1.0, 0.5, 0.5], # noqa: E241, E201 + [ 1.0, 1.0, 1.0], # noqa: E241, E201 + [ 0.0, -1.0, 1.0], # noqa: E241, E201 + [-1.0, 0.5, -0.5], # noqa: E241, E201 + [ 1.0, 1.0, -1.0], # noqa: E241, E201 + [-1.0, 0.0, 1.0], # noqa: E241, E201 + [ 0.0, 1.0, 1.0], # noqa: E241, E201 + [ 1.0, 0.0, 1.0], # noqa: E241, E201 + ], + dtype=torch.float32, + device=device, + ) + faces = torch.tensor( + [ + [0, 1, 2], # noqa: E241, E201 Case 2 fully clipped + [3, 4, 5], # noqa: E241, E201 Case 4 clipped and subdivided + [5, 4, 3], # noqa: E241, E201 Repeat of Case 4 + [6, 7, 8], # noqa: E241, E201 Case 3 clipped + [9, 10, 11], # noqa: E241, E201 Case 1 untouched + ], + dtype=torch.int64, + device=device, + ) + # fmt: on + meshes = Meshes(verts=[verts], faces=[faces]) + + # Clip meshes + clipped_faces = self.clip_faces(meshes) + + # 4x faces (from Case 4) + 1 (from Case 3) + 1 (from Case 1) + self.assertEqual(clipped_faces.face_verts.shape[0], 6) + + image_size = (10, 10) + blur_radius = 0.05 + faces_per_pixel = 2 + perspective_correct = True + bin_size = 0 + max_faces_per_bin = 20 + clip_barycentric_coords = False + cull_backfaces = False + + # Rasterize clipped mesh + pix_to_face, zbuf, barycentric_coords, dists = _RasterizeFaceVerts.apply( + clipped_faces.face_verts, + clipped_faces.mesh_to_face_first_idx, + clipped_faces.num_faces_per_mesh, + clipped_faces.clipped_faces_neighbor_idx, + image_size, + blur_radius, + faces_per_pixel, + bin_size, + max_faces_per_bin, + perspective_correct, + clip_barycentric_coords, + cull_backfaces, + ) + + # Convert outputs so they are in terms of the unclipped mesh. + outputs = convert_clipped_rasterization_to_original_faces( + pix_to_face, + barycentric_coords, + clipped_faces, + ) + pix_to_face_unclipped, barycentric_coords_unclipped = outputs + + # In the clipped mesh there are more faces than in the unclipped mesh + self.assertTrue(pix_to_face.max() > pix_to_face_unclipped.max()) + # Unclipped pix_to_face indices must be in the limit of the number + # of faces in the unclipped mesh. + self.assertTrue(pix_to_face_unclipped.max() < faces.shape[0]) + + def test_case_4_no_duplicates(self): + """ + In the case of an simple mesh with one face that is cut by the image + plane into a quadrilateral, there shouldn't be duplicates indices of + the face in the pix_to_face output of rasterization. + """ + for (device, bin_size) in [("cpu", 0), ("cuda:0", 0), ("cuda:0", None)]: + verts = torch.tensor( + [[0.0, -10.0, 1.0], [-1.0, 2.0, -2.0], [1.0, 5.0, -10.0]], + dtype=torch.float32, + device=device, + ) + faces = torch.tensor( + [ + [0, 1, 2], + ], + dtype=torch.int64, + device=device, + ) + meshes = Meshes(verts=[verts], faces=[faces]) + k = 3 + settings = RasterizationSettings( + image_size=10, + blur_radius=0.05, + faces_per_pixel=k, + z_clip_value=1e-2, + perspective_correct=True, + cull_to_frustum=True, + bin_size=bin_size, + ) + + # The camera is positioned so that the image plane cuts + # the mesh face into a quadrilateral. + R, T = look_at_view_transform(0.2, 0, 0) + cameras = FoVPerspectiveCameras(device=device, R=R, T=T, fov=90) + rasterizer = MeshRasterizer(raster_settings=settings, cameras=cameras) + fragments = rasterizer(meshes) + + p2f = fragments.pix_to_face.reshape(-1, k) + unique_vals, idx_counts = p2f.unique(dim=0, return_counts=True) + # There is only one face in this mesh so if it hits a pixel + # it can only be at position k = 0 + # For any pixel, the values [0, 0, 1] for the top K faces cannot be possible + double_hit = torch.tensor([0, 0, -1], device=device) + check_double_hit = any(torch.allclose(i, double_hit) for i in unique_vals) + self.assertFalse(check_double_hit)