mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
Add MeshRendererWithFragments class to also return fragments after rendering
Summary: Users want to be able to obtain the depth from the renderer. Current work-around requires running the rasterizer and extra time. This change creates a new renderer class that also returns the fragments from the rasterizer. Reviewed By: nikhilaravi Differential Revision: D24432381 fbshipit-source-id: 6552e8a6bfee646791afb34bdb7452fbc4094aed
This commit is contained in:
parent
b6be3b95fb
commit
83fef0a576
@ -55,3 +55,46 @@ class MeshRenderer(nn.Module):
|
||||
images = self.shader(fragments, meshes_world, **kwargs)
|
||||
|
||||
return images
|
||||
|
||||
|
||||
class MeshRendererWithFragments(nn.Module):
|
||||
"""
|
||||
A class for rendering a batch of heterogeneous meshes. The class should
|
||||
be initialized with a rasterizer and shader class which each have a forward
|
||||
function.
|
||||
|
||||
In the forward pass this class returns the `fragments` from which intermediate
|
||||
values such as the depth map can be easily extracted e.g.
|
||||
|
||||
.. code-block:: python
|
||||
images, fragments = renderer(meshes)
|
||||
depth = fragments.zbuf
|
||||
"""
|
||||
|
||||
def __init__(self, rasterizer, shader):
|
||||
super().__init__()
|
||||
self.rasterizer = rasterizer
|
||||
self.shader = shader
|
||||
|
||||
def to(self, device):
|
||||
# Rasterizer and shader have submodules which are not of type nn.Module
|
||||
self.rasterizer.to(device)
|
||||
self.shader.to(device)
|
||||
|
||||
def forward(self, meshes_world, **kwargs):
|
||||
"""
|
||||
Render a batch of images from a batch of meshes by rasterizing and then
|
||||
shading.
|
||||
|
||||
NOTE: If the blur radius for rasterization is > 0.0, some pixels can
|
||||
have one or more barycentric coordinates lying outside the range [0, 1].
|
||||
For a pixel with out of bounds barycentric coordinates with respect to a
|
||||
face f, clipping is required before interpolating the texture uv
|
||||
coordinates and z buffer so that the colors and depths are limited to
|
||||
the range for the corresponding face.
|
||||
For this set rasterizer.raster_settings.clip_barycentric_coords=True
|
||||
"""
|
||||
fragments = self.rasterizer(meshes_world, **kwargs)
|
||||
images = self.shader(fragments, meshes_world, **kwargs)
|
||||
|
||||
return images, fragments
|
||||
|
@ -24,7 +24,7 @@ from pytorch3d.renderer.lighting import PointLights
|
||||
from pytorch3d.renderer.materials import Materials
|
||||
from pytorch3d.renderer.mesh import TexturesAtlas, TexturesUV, TexturesVertex
|
||||
from pytorch3d.renderer.mesh.rasterizer import MeshRasterizer, RasterizationSettings
|
||||
from pytorch3d.renderer.mesh.renderer import MeshRenderer
|
||||
from pytorch3d.renderer.mesh.renderer import MeshRenderer, MeshRendererWithFragments
|
||||
from pytorch3d.renderer.mesh.shader import (
|
||||
BlendParams,
|
||||
HardFlatShader,
|
||||
@ -50,7 +50,7 @@ DATA_DIR = Path(__file__).resolve().parent / "data"
|
||||
|
||||
|
||||
class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
|
||||
def test_simple_sphere(self, elevated_camera=False):
|
||||
def test_simple_sphere(self, elevated_camera=False, check_depth=False):
|
||||
"""
|
||||
Test output of phong and gouraud shading matches a reference image using
|
||||
the default values for the light sources.
|
||||
@ -114,8 +114,16 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
|
||||
materials=materials,
|
||||
blend_params=blend_params,
|
||||
)
|
||||
renderer = MeshRenderer(rasterizer=rasterizer, shader=shader)
|
||||
images = renderer(sphere_mesh)
|
||||
if check_depth:
|
||||
renderer = MeshRendererWithFragments(
|
||||
rasterizer=rasterizer, shader=shader
|
||||
)
|
||||
images, fragments = renderer(sphere_mesh)
|
||||
self.assertClose(fragments.zbuf, rasterizer(sphere_mesh).zbuf)
|
||||
else:
|
||||
renderer = MeshRenderer(rasterizer=rasterizer, shader=shader)
|
||||
images = renderer(sphere_mesh)
|
||||
|
||||
rgb = images[0, ..., :3].squeeze().cpu()
|
||||
filename = "simple_sphere_light_%s%s%s.png" % (
|
||||
name,
|
||||
@ -144,8 +152,19 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
|
||||
materials=materials,
|
||||
blend_params=blend_params,
|
||||
)
|
||||
phong_renderer = MeshRenderer(rasterizer=rasterizer, shader=phong_shader)
|
||||
images = phong_renderer(sphere_mesh, lights=lights)
|
||||
if check_depth:
|
||||
phong_renderer = MeshRendererWithFragments(
|
||||
rasterizer=rasterizer, shader=phong_shader
|
||||
)
|
||||
images, fragments = phong_renderer(sphere_mesh, lights=lights)
|
||||
self.assertClose(
|
||||
fragments.zbuf, rasterizer(sphere_mesh, lights=lights).zbuf
|
||||
)
|
||||
else:
|
||||
phong_renderer = MeshRenderer(
|
||||
rasterizer=rasterizer, shader=phong_shader
|
||||
)
|
||||
images = phong_renderer(sphere_mesh, lights=lights)
|
||||
rgb = images[0, ..., :3].squeeze().cpu()
|
||||
if DEBUG:
|
||||
filename = "DEBUG_simple_sphere_dark%s%s.png" % (
|
||||
@ -171,6 +190,15 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
|
||||
"""
|
||||
self.test_simple_sphere(elevated_camera=True)
|
||||
|
||||
def test_simple_sphere_depth(self):
|
||||
"""
|
||||
Test output of phong and gouraud shading matches a reference image using
|
||||
the default values for the light sources.
|
||||
|
||||
The rendering is performed with a camera that has non-zero elevation.
|
||||
"""
|
||||
self.test_simple_sphere(check_depth=True)
|
||||
|
||||
def test_simple_sphere_screen(self):
|
||||
|
||||
"""
|
||||
|
Loading…
x
Reference in New Issue
Block a user