From 9aaba0483c08c9a40c26db0858f8c0688f33e850 Mon Sep 17 00:00:00 2001 From: Steve Branson Date: Thu, 20 Aug 2020 22:23:06 -0700 Subject: [PATCH] Temporary fix for mesh rasterization bug for traingles partially behind the camera Summary: A triangle is culled if any vertex in a triangle is behind the camera. This fixes incorrect rendering of triangles that are partially behind the camera, where screen coordinate calculations are strange. It doesn't work for triangles that are partially behind the camera but still intersect with the view frustum. Reviewed By: nikhilaravi Differential Revision: D22856181 fbshipit-source-id: a9cbaa1327d89601b83d0dfd3e4a04f934a4a213 --- .../csrc/rasterize_meshes/rasterize_meshes.cu | 17 +++++++++++++---- .../rasterize_meshes/rasterize_meshes_cpu.cpp | 15 +++++++++++---- pytorch3d/renderer/mesh/rasterize_meshes.py | 7 +++++++ tests/test_rasterize_meshes.py | 8 ++++++++ 4 files changed, 39 insertions(+), 8 deletions(-) diff --git a/pytorch3d/csrc/rasterize_meshes/rasterize_meshes.cu b/pytorch3d/csrc/rasterize_meshes/rasterize_meshes.cu index aa069725..ba575280 100644 --- a/pytorch3d/csrc/rasterize_meshes/rasterize_meshes.cu +++ b/pytorch3d/csrc/rasterize_meshes/rasterize_meshes.cu @@ -90,8 +90,14 @@ __device__ bool CheckPointOutsideBoundingBox( const float x_max = xlims.y + blur_radius; const float y_max = ylims.y + blur_radius; + // Faces with at least one vertex behind the camera won't render correctly + // and should be removed or clipped before calling the rasterizer + const bool z_invalid = zlims.x < kEpsilon; + // Check if the current point is oustside the triangle bounding box. - return (pxy.x > x_max || pxy.x < x_min || pxy.y > y_max || pxy.y < y_min); + return ( + pxy.x > x_max || pxy.x < x_min || pxy.y > y_max || pxy.y < y_min || + z_invalid); } // This function checks if a pixel given by xy location pxy lies within the @@ -625,10 +631,13 @@ __global__ void RasterizeMeshesCoarseCudaKernel( float ymin = FloatMin3(v0.y, v1.y, v2.y) - sqrt(blur_radius); float xmax = FloatMax3(v0.x, v1.x, v2.x) + sqrt(blur_radius); float ymax = FloatMax3(v0.y, v1.y, v2.y) + sqrt(blur_radius); - float zmax = FloatMax3(v0.z, v1.z, v2.z); + float zmin = FloatMin3(v0.z, v1.z, v2.z); - if (zmax < 0) { - continue; // Face is behind the camera. + // Faces with at least one vertex behind the camera won't render + // correctly and should be removed or clipped before calling the + // rasterizer + if (zmin < kEpsilon) { + continue; } // Brute-force search over all bins; TODO(T54294966) something smarter. diff --git a/pytorch3d/csrc/rasterize_meshes/rasterize_meshes_cpu.cpp b/pytorch3d/csrc/rasterize_meshes/rasterize_meshes_cpu.cpp index 89a7dc3a..af6f09a7 100644 --- a/pytorch3d/csrc/rasterize_meshes/rasterize_meshes_cpu.cpp +++ b/pytorch3d/csrc/rasterize_meshes/rasterize_meshes_cpu.cpp @@ -68,8 +68,12 @@ bool CheckPointOutsideBoundingBox( float x_max = face_bbox[2] + blur_radius; float y_max = face_bbox[3] + blur_radius; + // Faces with at least one vertex behind the camera won't render correctly + // and should be removed or clipped before calling the rasterizer + const bool z_invalid = face_bbox[4] < kEpsilon; + // Check if the current point is within the triangle bounding box. - return (px > x_max || px < x_min || py > y_max || py < y_min); + return (px > x_max || px < x_min || py > y_max || py < y_min || z_invalid); } // Calculate areas of all faces. Returns a tensor of shape (total_faces, 1) @@ -468,10 +472,13 @@ torch::Tensor RasterizeMeshesCoarseCpu( float face_y_min = face_bboxes_a[f][1] - std::sqrt(blur_radius); float face_x_max = face_bboxes_a[f][2] + std::sqrt(blur_radius); float face_y_max = face_bboxes_a[f][3] + std::sqrt(blur_radius); - float face_z_max = face_bboxes_a[f][5]; + float face_z_min = face_bboxes_a[f][4]; - if (face_z_max < 0) { - continue; // Face is behind the camera. + // Faces with at least one vertex behind the camera won't render + // correctly and should be removed or clipped before calling the + // rasterizer + if (face_z_min < kEpsilon) { + continue; } // Use a half-open interval so that faces exactly on the diff --git a/pytorch3d/renderer/mesh/rasterize_meshes.py b/pytorch3d/renderer/mesh/rasterize_meshes.py index 0d81b540..24ba306b 100644 --- a/pytorch3d/renderer/mesh/rasterize_meshes.py +++ b/pytorch3d/renderer/mesh/rasterize_meshes.py @@ -301,6 +301,7 @@ def rasterize_meshes_python( x_maxs = torch.max(faces_verts[:, :, 0], dim=1, keepdim=True).values y_mins = torch.min(faces_verts[:, :, 1], dim=1, keepdim=True).values y_maxs = torch.max(faces_verts[:, :, 1], dim=1, keepdim=True).values + z_mins = torch.min(faces_verts[:, :, 2], dim=1, keepdim=True).values # Expand by blur radius. x_mins = x_mins - np.sqrt(blur_radius) - kEpsilon @@ -351,6 +352,12 @@ def rasterize_meshes_python( or yf > y_maxs[f] ) + # Faces with at least one vertex behind the camera won't + # render correctly and should be removed or clipped before + # calling the rasterizer + if z_mins[f] < kEpsilon: + continue + # Check if pixel is outside of face bbox. if outside_bbox: continue diff --git a/tests/test_rasterize_meshes.py b/tests/test_rasterize_meshes.py index e25b5c8c..5a47766b 100644 --- a/tests/test_rasterize_meshes.py +++ b/tests/test_rasterize_meshes.py @@ -552,6 +552,10 @@ class TestRasterizeMeshes(TestCaseMixin, unittest.TestCase): + (zbuf1 * grad_zbuf).sum() + (bary1 * grad_bary).sum() ) + + # avoid gradient error if rasterize_meshes_python() culls all triangles + loss1 += grad_var1.sum() * 0.0 + loss1.backward() grad_verts1 = grad_var1.grad.data.clone().cpu() @@ -563,6 +567,10 @@ class TestRasterizeMeshes(TestCaseMixin, unittest.TestCase): + (zbuf2 * grad_zbuf).sum() + (bary2 * grad_bary).sum() ) + + # avoid gradient error if rasterize_meshes_python() culls all triangles + loss2 += grad_var2.sum() * 0.0 + grad_var1.grad.data.zero_() loss2.backward() grad_verts2 = grad_var2.grad.data.clone().cpu()