Culling to frustrum bug fix

Summary:
When `z_clip_value = None` and faces are outside the view frustum the shape of one of the tensors in `clip.py` is incorrect.

`faces_num_clipped_verts` should be (F,) but it was (F,3).  Added a new test to ensure this case is handled.

Reviewed By: bottler

Differential Revision: D29051282

fbshipit-source-id: 5f4172ba4d4a75d928404dde9abf48aef18c68bd
This commit is contained in:
Nikhila Ravi 2021-06-11 14:33:01 -07:00 committed by Facebook GitHub Bot
parent ef16253953
commit a0f79318c5
3 changed files with 33 additions and 12 deletions

View File

@ -372,7 +372,7 @@ def clip_faces(
# (F) dim tensor containing the number of clipped vertices in each triangle # (F) dim tensor containing the number of clipped vertices in each triangle
faces_num_clipped_verts = faces_clipped_verts.sum(1) faces_num_clipped_verts = faces_clipped_verts.sum(1)
else: else:
faces_num_clipped_verts = torch.zeros([F, 3], device=device) faces_num_clipped_verts = torch.zeros([F], device=device)
# If no triangles need to be clipped or culled, avoid unnecessary computation # If no triangles need to be clipped or culled, avoid unnecessary computation
# and return early # and return early

Binary file not shown.

Before

Width:  |  Height:  |  Size: 7.2 KiB

View File

@ -15,7 +15,11 @@ import numpy as np
import torch import torch
from common_testing import TestCaseMixin, get_tests_dir, load_rgb_image from common_testing import TestCaseMixin, get_tests_dir, load_rgb_image
from pytorch3d.io import save_obj from pytorch3d.io import save_obj
from pytorch3d.renderer.cameras import FoVPerspectiveCameras, look_at_view_transform from pytorch3d.renderer.cameras import (
FoVPerspectiveCameras,
look_at_view_transform,
PerspectiveCameras,
)
from pytorch3d.renderer.lighting import PointLights from pytorch3d.renderer.lighting import PointLights
from pytorch3d.renderer.mesh import ( from pytorch3d.renderer.mesh import (
ClipFrustum, ClipFrustum,
@ -27,8 +31,9 @@ from pytorch3d.renderer.mesh.rasterize_meshes import _RasterizeFaceVerts
from pytorch3d.renderer.mesh.rasterizer import MeshRasterizer, RasterizationSettings from pytorch3d.renderer.mesh.rasterizer import MeshRasterizer, RasterizationSettings
from pytorch3d.renderer.mesh.renderer import MeshRenderer from pytorch3d.renderer.mesh.renderer import MeshRenderer
from pytorch3d.renderer.mesh.shader import SoftPhongShader from pytorch3d.renderer.mesh.shader import SoftPhongShader
from pytorch3d.renderer.mesh.textures import TexturesVertex
from pytorch3d.structures.meshes import Meshes from pytorch3d.structures.meshes import Meshes
from pytorch3d.utils import torus
# If DEBUG=True, save out images generated in the tests for debugging. # If DEBUG=True, save out images generated in the tests for debugging.
# All saved images have prefix DEBUG_ # All saved images have prefix DEBUG_
@ -97,9 +102,9 @@ class TestRenderMeshesClipping(TestCaseMixin, unittest.TestCase):
return mesh, verts return mesh, verts
return mesh return mesh
def test_cube_mesh_render(self): def debug_cube_mesh_render(self):
""" """
End-End test of rendering a cube mesh with texture End-End debug run of rendering a cube mesh with texture
from decreasing camera distances. The camera starts from decreasing camera distances. The camera starts
outside the cube and enters the inside of the cube. outside the cube and enters the inside of the cube.
""" """
@ -132,22 +137,16 @@ class TestRenderMeshesClipping(TestCaseMixin, unittest.TestCase):
# the camera enters the cube. Check the output looks correct. # the camera enters the cube. Check the output looks correct.
images_list = [] images_list = []
dists = np.linspace(0.1, 2.5, 20)[::-1] dists = np.linspace(0.1, 2.5, 20)[::-1]
for d in dists: for d in dists:
R, T = look_at_view_transform(d, 0, 0) R, T = look_at_view_transform(d, 0, 0)
T[0, 1] -= 0.1 # move down in the y axis T[0, 1] -= 0.1 # move down in the y axis
cameras = FoVPerspectiveCameras(device=device, R=R, T=T, fov=90) cameras = FoVPerspectiveCameras(device=device, R=R, T=T, fov=90)
images = renderer(mesh, cameras=cameras) images = renderer(mesh, cameras=cameras)
rgb = images[0, ..., :3].cpu().detach() rgb = images[0, ..., :3].cpu().detach()
filename = "DEBUG_cube_dist=%.1f.jpg" % d
im = (rgb.numpy() * 255).astype(np.uint8) im = (rgb.numpy() * 255).astype(np.uint8)
images_list.append(im) images_list.append(im)
# Check one of the images where the camera is inside the mesh
if d == 0.5:
filename = "test_render_mesh_clipped_cam_dist=0.5.jpg"
image_ref = load_rgb_image(filename, DATA_DIR)
self.assertClose(rgb, image_ref, atol=0.05)
# Save a gif of the output - this should show # Save a gif of the output - this should show
# the camera moving inside the cube. # the camera moving inside the cube.
if DEBUG: if DEBUG:
@ -655,3 +654,25 @@ class TestRenderMeshesClipping(TestCaseMixin, unittest.TestCase):
double_hit = torch.tensor([0, 0, -1], device=device) double_hit = torch.tensor([0, 0, -1], device=device)
check_double_hit = any(torch.allclose(i, double_hit) for i in unique_vals) check_double_hit = any(torch.allclose(i, double_hit) for i in unique_vals)
self.assertFalse(check_double_hit) self.assertFalse(check_double_hit)
def test_mesh_outside_frustrum(self):
"""
Test the case where the mesh is completely outside the view
frustrum so all faces are culled and z_clip_value = None.
"""
device = "cuda:0"
mesh = torus(20.0, 85.0, 32, 16, device=device)
tex = TexturesVertex(verts_features=torch.rand_like(mesh.verts_padded()))
mesh.textures = tex
raster_settings = RasterizationSettings(image_size=512, cull_to_frustum=True)
R, T = look_at_view_transform(1.0, 0.0, 0.0)
cameras = PerspectiveCameras(device=device, R=R, T=T)
renderer = MeshRenderer(
rasterizer=MeshRasterizer(cameras=cameras, raster_settings=raster_settings),
shader=SoftPhongShader(cameras=cameras, device=device),
)
images = renderer(mesh)
# Mesh is completely outside the view frustrum
# The image should be white.
self.assertClose(images[0, ..., :3], torch.ones_like(images[0, ..., :3]))