Add MeshRendererWithFragments class to also return fragments after rendering

Summary: Users want to be able to obtain the depth from the renderer. Current work-around requires running the rasterizer and extra time. This change creates a new renderer class that also returns the fragments from the rasterizer.

Reviewed By: nikhilaravi

Differential Revision: D24432381

fbshipit-source-id: 6552e8a6bfee646791afb34bdb7452fbc4094aed
This commit is contained in:
Dave Schnizlein 2020-11-05 09:18:36 -08:00 committed by Facebook GitHub Bot
parent b6be3b95fb
commit 83fef0a576
2 changed files with 77 additions and 6 deletions

View File

@ -55,3 +55,46 @@ class MeshRenderer(nn.Module):
images = self.shader(fragments, meshes_world, **kwargs)
return images
class MeshRendererWithFragments(nn.Module):
"""
A class for rendering a batch of heterogeneous meshes. The class should
be initialized with a rasterizer and shader class which each have a forward
function.
In the forward pass this class returns the `fragments` from which intermediate
values such as the depth map can be easily extracted e.g.
.. code-block:: python
images, fragments = renderer(meshes)
depth = fragments.zbuf
"""
def __init__(self, rasterizer, shader):
super().__init__()
self.rasterizer = rasterizer
self.shader = shader
def to(self, device):
# Rasterizer and shader have submodules which are not of type nn.Module
self.rasterizer.to(device)
self.shader.to(device)
def forward(self, meshes_world, **kwargs):
"""
Render a batch of images from a batch of meshes by rasterizing and then
shading.
NOTE: If the blur radius for rasterization is > 0.0, some pixels can
have one or more barycentric coordinates lying outside the range [0, 1].
For a pixel with out of bounds barycentric coordinates with respect to a
face f, clipping is required before interpolating the texture uv
coordinates and z buffer so that the colors and depths are limited to
the range for the corresponding face.
For this set rasterizer.raster_settings.clip_barycentric_coords=True
"""
fragments = self.rasterizer(meshes_world, **kwargs)
images = self.shader(fragments, meshes_world, **kwargs)
return images, fragments

View File

@ -24,7 +24,7 @@ from pytorch3d.renderer.lighting import PointLights
from pytorch3d.renderer.materials import Materials
from pytorch3d.renderer.mesh import TexturesAtlas, TexturesUV, TexturesVertex
from pytorch3d.renderer.mesh.rasterizer import MeshRasterizer, RasterizationSettings
from pytorch3d.renderer.mesh.renderer import MeshRenderer
from pytorch3d.renderer.mesh.renderer import MeshRenderer, MeshRendererWithFragments
from pytorch3d.renderer.mesh.shader import (
BlendParams,
HardFlatShader,
@ -50,7 +50,7 @@ DATA_DIR = Path(__file__).resolve().parent / "data"
class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
def test_simple_sphere(self, elevated_camera=False):
def test_simple_sphere(self, elevated_camera=False, check_depth=False):
"""
Test output of phong and gouraud shading matches a reference image using
the default values for the light sources.
@ -114,8 +114,16 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
materials=materials,
blend_params=blend_params,
)
renderer = MeshRenderer(rasterizer=rasterizer, shader=shader)
images = renderer(sphere_mesh)
if check_depth:
renderer = MeshRendererWithFragments(
rasterizer=rasterizer, shader=shader
)
images, fragments = renderer(sphere_mesh)
self.assertClose(fragments.zbuf, rasterizer(sphere_mesh).zbuf)
else:
renderer = MeshRenderer(rasterizer=rasterizer, shader=shader)
images = renderer(sphere_mesh)
rgb = images[0, ..., :3].squeeze().cpu()
filename = "simple_sphere_light_%s%s%s.png" % (
name,
@ -144,8 +152,19 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
materials=materials,
blend_params=blend_params,
)
phong_renderer = MeshRenderer(rasterizer=rasterizer, shader=phong_shader)
images = phong_renderer(sphere_mesh, lights=lights)
if check_depth:
phong_renderer = MeshRendererWithFragments(
rasterizer=rasterizer, shader=phong_shader
)
images, fragments = phong_renderer(sphere_mesh, lights=lights)
self.assertClose(
fragments.zbuf, rasterizer(sphere_mesh, lights=lights).zbuf
)
else:
phong_renderer = MeshRenderer(
rasterizer=rasterizer, shader=phong_shader
)
images = phong_renderer(sphere_mesh, lights=lights)
rgb = images[0, ..., :3].squeeze().cpu()
if DEBUG:
filename = "DEBUG_simple_sphere_dark%s%s.png" % (
@ -171,6 +190,15 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
"""
self.test_simple_sphere(elevated_camera=True)
def test_simple_sphere_depth(self):
"""
Test output of phong and gouraud shading matches a reference image using
the default values for the light sources.
The rendering is performed with a camera that has non-zero elevation.
"""
self.test_simple_sphere(check_depth=True)
def test_simple_sphere_screen(self):
"""