From 83fef0a57681b06ab03dd7302e76e7a7d21961dc Mon Sep 17 00:00:00 2001 From: Dave Schnizlein Date: Thu, 5 Nov 2020 09:18:36 -0800 Subject: [PATCH] Add MeshRendererWithFragments class to also return fragments after rendering Summary: Users want to be able to obtain the depth from the renderer. Current work-around requires running the rasterizer and extra time. This change creates a new renderer class that also returns the fragments from the rasterizer. Reviewed By: nikhilaravi Differential Revision: D24432381 fbshipit-source-id: 6552e8a6bfee646791afb34bdb7452fbc4094aed --- pytorch3d/renderer/mesh/renderer.py | 43 +++++++++++++++++++++++++++++ tests/test_render_meshes.py | 40 +++++++++++++++++++++++---- 2 files changed, 77 insertions(+), 6 deletions(-) diff --git a/pytorch3d/renderer/mesh/renderer.py b/pytorch3d/renderer/mesh/renderer.py index 9e366616..fe2ba5d1 100644 --- a/pytorch3d/renderer/mesh/renderer.py +++ b/pytorch3d/renderer/mesh/renderer.py @@ -55,3 +55,46 @@ class MeshRenderer(nn.Module): images = self.shader(fragments, meshes_world, **kwargs) return images + + +class MeshRendererWithFragments(nn.Module): + """ + A class for rendering a batch of heterogeneous meshes. The class should + be initialized with a rasterizer and shader class which each have a forward + function. + + In the forward pass this class returns the `fragments` from which intermediate + values such as the depth map can be easily extracted e.g. + + .. code-block:: python + images, fragments = renderer(meshes) + depth = fragments.zbuf + """ + + def __init__(self, rasterizer, shader): + super().__init__() + self.rasterizer = rasterizer + self.shader = shader + + def to(self, device): + # Rasterizer and shader have submodules which are not of type nn.Module + self.rasterizer.to(device) + self.shader.to(device) + + def forward(self, meshes_world, **kwargs): + """ + Render a batch of images from a batch of meshes by rasterizing and then + shading. + + NOTE: If the blur radius for rasterization is > 0.0, some pixels can + have one or more barycentric coordinates lying outside the range [0, 1]. + For a pixel with out of bounds barycentric coordinates with respect to a + face f, clipping is required before interpolating the texture uv + coordinates and z buffer so that the colors and depths are limited to + the range for the corresponding face. + For this set rasterizer.raster_settings.clip_barycentric_coords=True + """ + fragments = self.rasterizer(meshes_world, **kwargs) + images = self.shader(fragments, meshes_world, **kwargs) + + return images, fragments diff --git a/tests/test_render_meshes.py b/tests/test_render_meshes.py index 4f0c2ba2..344b32d3 100644 --- a/tests/test_render_meshes.py +++ b/tests/test_render_meshes.py @@ -24,7 +24,7 @@ from pytorch3d.renderer.lighting import PointLights from pytorch3d.renderer.materials import Materials from pytorch3d.renderer.mesh import TexturesAtlas, TexturesUV, TexturesVertex from pytorch3d.renderer.mesh.rasterizer import MeshRasterizer, RasterizationSettings -from pytorch3d.renderer.mesh.renderer import MeshRenderer +from pytorch3d.renderer.mesh.renderer import MeshRenderer, MeshRendererWithFragments from pytorch3d.renderer.mesh.shader import ( BlendParams, HardFlatShader, @@ -50,7 +50,7 @@ DATA_DIR = Path(__file__).resolve().parent / "data" class TestRenderMeshes(TestCaseMixin, unittest.TestCase): - def test_simple_sphere(self, elevated_camera=False): + def test_simple_sphere(self, elevated_camera=False, check_depth=False): """ Test output of phong and gouraud shading matches a reference image using the default values for the light sources. @@ -114,8 +114,16 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase): materials=materials, blend_params=blend_params, ) - renderer = MeshRenderer(rasterizer=rasterizer, shader=shader) - images = renderer(sphere_mesh) + if check_depth: + renderer = MeshRendererWithFragments( + rasterizer=rasterizer, shader=shader + ) + images, fragments = renderer(sphere_mesh) + self.assertClose(fragments.zbuf, rasterizer(sphere_mesh).zbuf) + else: + renderer = MeshRenderer(rasterizer=rasterizer, shader=shader) + images = renderer(sphere_mesh) + rgb = images[0, ..., :3].squeeze().cpu() filename = "simple_sphere_light_%s%s%s.png" % ( name, @@ -144,8 +152,19 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase): materials=materials, blend_params=blend_params, ) - phong_renderer = MeshRenderer(rasterizer=rasterizer, shader=phong_shader) - images = phong_renderer(sphere_mesh, lights=lights) + if check_depth: + phong_renderer = MeshRendererWithFragments( + rasterizer=rasterizer, shader=phong_shader + ) + images, fragments = phong_renderer(sphere_mesh, lights=lights) + self.assertClose( + fragments.zbuf, rasterizer(sphere_mesh, lights=lights).zbuf + ) + else: + phong_renderer = MeshRenderer( + rasterizer=rasterizer, shader=phong_shader + ) + images = phong_renderer(sphere_mesh, lights=lights) rgb = images[0, ..., :3].squeeze().cpu() if DEBUG: filename = "DEBUG_simple_sphere_dark%s%s.png" % ( @@ -171,6 +190,15 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase): """ self.test_simple_sphere(elevated_camera=True) + def test_simple_sphere_depth(self): + """ + Test output of phong and gouraud shading matches a reference image using + the default values for the light sources. + + The rendering is performed with a camera that has non-zero elevation. + """ + self.test_simple_sphere(check_depth=True) + def test_simple_sphere_screen(self): """