Culling to frustrum bug fix

Summary:
When `z_clip_value = None` and faces are outside the view frustum the shape of one of the tensors in `clip.py` is incorrect.

`faces_num_clipped_verts` should be (F,) but it was (F,3).  Added a new test to ensure this case is handled.

Reviewed By: bottler

Differential Revision: D29051282

fbshipit-source-id: 5f4172ba4d4a75d928404dde9abf48aef18c68bd
This commit is contained in:
Nikhila Ravi 2021-06-11 14:33:01 -07:00 committed by Facebook GitHub Bot
parent ef16253953
commit a0f79318c5
3 changed files with 33 additions and 12 deletions

View File

@ -372,7 +372,7 @@ def clip_faces(
# (F) dim tensor containing the number of clipped vertices in each triangle
faces_num_clipped_verts = faces_clipped_verts.sum(1)
else:
faces_num_clipped_verts = torch.zeros([F, 3], device=device)
faces_num_clipped_verts = torch.zeros([F], device=device)
# If no triangles need to be clipped or culled, avoid unnecessary computation
# and return early

Binary file not shown.

Before

Width:  |  Height:  |  Size: 7.2 KiB

View File

@ -15,7 +15,11 @@ import numpy as np
import torch
from common_testing import TestCaseMixin, get_tests_dir, load_rgb_image
from pytorch3d.io import save_obj
from pytorch3d.renderer.cameras import FoVPerspectiveCameras, look_at_view_transform
from pytorch3d.renderer.cameras import (
FoVPerspectiveCameras,
look_at_view_transform,
PerspectiveCameras,
)
from pytorch3d.renderer.lighting import PointLights
from pytorch3d.renderer.mesh import (
ClipFrustum,
@ -27,8 +31,9 @@ from pytorch3d.renderer.mesh.rasterize_meshes import _RasterizeFaceVerts
from pytorch3d.renderer.mesh.rasterizer import MeshRasterizer, RasterizationSettings
from pytorch3d.renderer.mesh.renderer import MeshRenderer
from pytorch3d.renderer.mesh.shader import SoftPhongShader
from pytorch3d.renderer.mesh.textures import TexturesVertex
from pytorch3d.structures.meshes import Meshes
from pytorch3d.utils import torus
# If DEBUG=True, save out images generated in the tests for debugging.
# All saved images have prefix DEBUG_
@ -97,9 +102,9 @@ class TestRenderMeshesClipping(TestCaseMixin, unittest.TestCase):
return mesh, verts
return mesh
def test_cube_mesh_render(self):
def debug_cube_mesh_render(self):
"""
End-End test of rendering a cube mesh with texture
End-End debug run of rendering a cube mesh with texture
from decreasing camera distances. The camera starts
outside the cube and enters the inside of the cube.
"""
@ -132,22 +137,16 @@ class TestRenderMeshesClipping(TestCaseMixin, unittest.TestCase):
# the camera enters the cube. Check the output looks correct.
images_list = []
dists = np.linspace(0.1, 2.5, 20)[::-1]
for d in dists:
R, T = look_at_view_transform(d, 0, 0)
T[0, 1] -= 0.1 # move down in the y axis
cameras = FoVPerspectiveCameras(device=device, R=R, T=T, fov=90)
images = renderer(mesh, cameras=cameras)
rgb = images[0, ..., :3].cpu().detach()
filename = "DEBUG_cube_dist=%.1f.jpg" % d
im = (rgb.numpy() * 255).astype(np.uint8)
images_list.append(im)
# Check one of the images where the camera is inside the mesh
if d == 0.5:
filename = "test_render_mesh_clipped_cam_dist=0.5.jpg"
image_ref = load_rgb_image(filename, DATA_DIR)
self.assertClose(rgb, image_ref, atol=0.05)
# Save a gif of the output - this should show
# the camera moving inside the cube.
if DEBUG:
@ -655,3 +654,25 @@ class TestRenderMeshesClipping(TestCaseMixin, unittest.TestCase):
double_hit = torch.tensor([0, 0, -1], device=device)
check_double_hit = any(torch.allclose(i, double_hit) for i in unique_vals)
self.assertFalse(check_double_hit)
def test_mesh_outside_frustrum(self):
"""
Test the case where the mesh is completely outside the view
frustrum so all faces are culled and z_clip_value = None.
"""
device = "cuda:0"
mesh = torus(20.0, 85.0, 32, 16, device=device)
tex = TexturesVertex(verts_features=torch.rand_like(mesh.verts_padded()))
mesh.textures = tex
raster_settings = RasterizationSettings(image_size=512, cull_to_frustum=True)
R, T = look_at_view_transform(1.0, 0.0, 0.0)
cameras = PerspectiveCameras(device=device, R=R, T=T)
renderer = MeshRenderer(
rasterizer=MeshRasterizer(cameras=cameras, raster_settings=raster_settings),
shader=SoftPhongShader(cameras=cameras, device=device),
)
images = renderer(mesh)
# Mesh is completely outside the view frustrum
# The image should be white.
self.assertClose(images[0, ..., :3], torch.ones_like(images[0, ..., :3]))