Shader API more consistent naming

Summary:
Renamed shaders to be prefixed with Hard/Soft depending on if they use a probabalistic blending (Soft) or use the closest face (Hard).

There is some code duplication but I thought it would be cleaner to have separate shaders for each task rather than:
- inheritance (which we discussed previously that we want to avoid)
- boolean (hard/soft) or a string (hard/soft) - new blending functions other than the ones provided would need if statements in the current shaders which might get messy.

Also added a `flat_shading` function and a `FlatShader` - I could make this into a tutorial as it was really easy to add a new shader and it might be a nice showcase.

NOTE: There are a few more places where the naming will need to change (e.g the tutorials) but I wanted to reach a consensus on this before changing it everywhere.

Reviewed By: jcjohnson

Differential Revision: D19761036

fbshipit-source-id: f972f6530c7f66dc5550b0284c191abc4a7f6fc4
This commit is contained in:
Nikhila Ravi 2020-02-19 23:15:12 -08:00 committed by Facebook Github Bot
parent 60f3c4e7d2
commit f0dc65110a
9 changed files with 293 additions and 82 deletions

4
.gitignore vendored
View File

@ -2,6 +2,10 @@ build/
dist/ dist/
*.egg-info/ *.egg-info/
**/__pycache__/ **/__pycache__/
*-checkpoint.ipynb
**/.ipynb_checkpoints
**/.ipynb_checkpoints/**
# Docusaurus site # Docusaurus site
website/yarn.lock website/yarn.lock

View File

@ -84,3 +84,25 @@ renderer = MeshRenderer(
shader=PhongShader(device=device, cameras=cameras) shader=PhongShader(device=device, cameras=cameras)
) )
``` ```
### A custom shader
Shaders are the most flexible part of the PyTorch3D rendering API. We have created some examples of shaders in `shaders.py` but this is a non exhaustive set.
A shader can incorporate several steps:
- **texturing** (e.g interpolation of vertex RGB colors or interpolation of vertex UV coordinates followed by sampling from a texture map (interpolation uses barycentric coordinates output from rasterization))
- **lighting/shading** (e.g. ambient, diffuse, specular lighting, Phong, Gourad, Flat)
- **blending** (e.g. hard blending using only the closest face for each pixel, or soft blending using a weighted sum of the top K faces per pixel)
We have examples of several combinations of these functions based on the texturing/shading/blending support we have currently. These are summarised in this table below. Many other combinations are possible and we plan to expand the options available for texturing, shading and blending.
|Example Shaders | Vertex Textures| Texture Map| Flat Shading| Gourad Shading| Phong Shading | Hard blending | Soft Blending |
| ------------- |:-------------: | :--------------:| :--------------:| :--------------:| :--------------:|:--------------:|:--------------:|
| HardPhongShader | :heavy_check_mark: |||| :heavy_check_mark: | :heavy_check_mark:||
| SoftPhongShader | :heavy_check_mark: |||| :heavy_check_mark: | | :heavy_check_mark:|
| HardGouradShader | :heavy_check_mark: ||| :heavy_check_mark: || :heavy_check_mark:||
| SoftGouradShader | :heavy_check_mark: ||| :heavy_check_mark: ||| :heavy_check_mark:|
| TexturedSoftPhongShader || :heavy_check_mark: ||| :heavy_check_mark: || :heavy_check_mark:|
| HardFlatShader | :heavy_check_mark: || :heavy_check_mark: ||| :heavy_check_mark:||
| SoftSilhouetteShader ||||||| :heavy_check_mark:|

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@ -17,13 +17,16 @@ from .cameras import (
from .lighting import DirectionalLights, PointLights, diffuse, specular from .lighting import DirectionalLights, PointLights, diffuse, specular
from .materials import Materials from .materials import Materials
from .mesh import ( from .mesh import (
GouradShader, HardFlatShader,
HardGouradShader,
HardPhongShader,
MeshRasterizer, MeshRasterizer,
MeshRenderer, MeshRenderer,
PhongShader,
RasterizationSettings, RasterizationSettings,
SilhouetteShader, SoftGouradShader,
TexturedPhongShader, SoftPhongShader,
SoftSilhouetteShader,
TexturedSoftPhongShader,
gourad_shading, gourad_shading,
interpolate_face_attributes, interpolate_face_attributes,
interpolate_texture_map, interpolate_texture_map,

View File

@ -4,10 +4,13 @@ from .rasterize_meshes import rasterize_meshes
from .rasterizer import MeshRasterizer, RasterizationSettings from .rasterizer import MeshRasterizer, RasterizationSettings
from .renderer import MeshRenderer from .renderer import MeshRenderer
from .shader import ( from .shader import (
GouradShader, HardFlatShader,
PhongShader, HardGouradShader,
SilhouetteShader, HardPhongShader,
TexturedPhongShader, SoftGouradShader,
SoftPhongShader,
SoftSilhouetteShader,
TexturedSoftPhongShader,
) )
from .shading import gourad_shading, phong_shading from .shading import gourad_shading, phong_shading
from .texturing import ( # isort: skip from .texturing import ( # isort: skip

View File

@ -14,7 +14,7 @@ from ..blending import (
from ..cameras import OpenGLPerspectiveCameras from ..cameras import OpenGLPerspectiveCameras
from ..lighting import PointLights from ..lighting import PointLights
from ..materials import Materials from ..materials import Materials
from .shading import gourad_shading, phong_shading from .shading import flat_shading, gourad_shading, phong_shading
from .texturing import interpolate_texture_map, interpolate_vertex_colors from .texturing import interpolate_texture_map, interpolate_vertex_colors
# A Shader should take as input fragments from the output of rasterization # A Shader should take as input fragments from the output of rasterization
@ -26,17 +26,18 @@ from .texturing import interpolate_texture_map, interpolate_vertex_colors
# - blend colors across top K faces per pixel. # - blend colors across top K faces per pixel.
class PhongShader(nn.Module): class HardPhongShader(nn.Module):
""" """
Per pixel lighting. Apply the lighting model using the interpolated coords Per pixel lighting - the lighting model is applied using the interpolated
and normals for each pixel. coordinates and normals for each pixel. The blending function hard assigns
the color of the closest face for each pixel.
To use the default values, simply initialize the shader with the desired To use the default values, simply initialize the shader with the desired
device e.g. device e.g.
.. code-block:: .. code-block::
shader = PhongShader(device=torch.device("cuda:0")) shader = HardPhongShader(device=torch.device("cuda:0"))
""" """
def __init__(self, device="cpu", cameras=None, lights=None, materials=None): def __init__(self, device="cpu", cameras=None, lights=None, materials=None):
@ -70,17 +71,74 @@ class PhongShader(nn.Module):
return images return images
class GouradShader(nn.Module): class SoftPhongShader(nn.Module):
""" """
Per vertex lighting. Apply the lighting model to the vertex colors and then Per pixel lighting - the lighting model is applied using the interpolated
interpolate using the barycentric coordinates to get colors for each pixel. coordinates and normals for each pixel. The blending function returns the
soft aggregated color using all the faces per pixel.
To use the default values, simply initialize the shader with the desired To use the default values, simply initialize the shader with the desired
device e.g. device e.g.
.. code-block:: .. code-block::
shader = GouradShader(device=torch.device("cuda:0")) shader = SoftPhongShader(device=torch.device("cuda:0"))
"""
def __init__(
self,
device="cpu",
cameras=None,
lights=None,
materials=None,
blend_params=None,
):
super().__init__()
self.lights = (
lights if lights is not None else PointLights(device=device)
)
self.materials = (
materials if materials is not None else Materials(device=device)
)
self.cameras = (
cameras
if cameras is not None
else OpenGLPerspectiveCameras(device=device)
)
self.blend_params = (
blend_params if blend_params is not None else BlendParams()
)
def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
texels = interpolate_vertex_colors(fragments, meshes)
cameras = kwargs.get("cameras", self.cameras)
lights = kwargs.get("lights", self.lights)
materials = kwargs.get("materials", self.materials)
colors = phong_shading(
meshes=meshes,
fragments=fragments,
texels=texels,
lights=lights,
cameras=cameras,
materials=materials,
)
images = softmax_rgb_blend(colors, fragments, self.blend_params)
return images
class HardGouradShader(nn.Module):
"""
Per vertex lighting - the lighting model is applied to the vertex colors and
the colors are then interpolated using the barycentric coordinates to
obtain the colors for each pixel. The blending function hard assigns
the color of the closest face for each pixel.
To use the default values, simply initialize the shader with the desired
device e.g.
.. code-block::
shader = HardGouradShader(device=torch.device("cuda:0"))
""" """
def __init__(self, device="cpu", cameras=None, lights=None, materials=None): def __init__(self, device="cpu", cameras=None, lights=None, materials=None):
@ -112,12 +170,69 @@ class GouradShader(nn.Module):
return images return images
class TexturedPhongShader(nn.Module): class SoftGouradShader(nn.Module):
"""
Per vertex lighting - the lighting model is applied to the vertex colors and
the colors are then interpolated using the barycentric coordinates to
obtain the colors for each pixel. The blending function returns the
soft aggregated color using all the faces per pixel.
To use the default values, simply initialize the shader with the desired
device e.g.
.. code-block::
shader = SoftGouradShader(device=torch.device("cuda:0"))
"""
def __init__(
self,
device="cpu",
cameras=None,
lights=None,
materials=None,
blend_params=None,
):
super().__init__()
self.lights = (
lights if lights is not None else PointLights(device=device)
)
self.materials = (
materials if materials is not None else Materials(device=device)
)
self.cameras = (
cameras
if cameras is not None
else OpenGLPerspectiveCameras(device=device)
)
self.blend_params = (
blend_params if blend_params is not None else BlendParams()
)
def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
cameras = kwargs.get("cameras", self.cameras)
lights = kwargs.get("lights", self.lights)
materials = kwargs.get("materials", self.materials)
pixel_colors = gourad_shading(
meshes=meshes,
fragments=fragments,
lights=lights,
cameras=cameras,
materials=materials,
)
images = softmax_rgb_blend(pixel_colors, fragments, self.blend_params)
return images
class TexturedSoftPhongShader(nn.Module):
""" """
Per pixel lighting applied to a texture map. First interpolate the vertex Per pixel lighting applied to a texture map. First interpolate the vertex
uv coordinates and sample from a texture map. Then apply the lighting model uv coordinates and sample from a texture map. Then apply the lighting model
using the interpolated coords and normals for each pixel. using the interpolated coords and normals for each pixel.
The blending function returns the soft aggregated color using all
the faces per pixel.
To use the default values, simply initialize the shader with the desired To use the default values, simply initialize the shader with the desired
device e.g. device e.g.
@ -167,7 +282,52 @@ class TexturedPhongShader(nn.Module):
return images return images
class SilhouetteShader(nn.Module): class HardFlatShader(nn.Module):
"""
Per face lighting - the lighting model is applied using the average face
position and the face normal. The blending function hard assigns
the color of the closest face for each pixel.
To use the default values, simply initialize the shader with the desired
device e.g.
.. code-block::
shader = HardFlatShader(device=torch.device("cuda:0"))
"""
def __init__(self, device="cpu", cameras=None, lights=None, materials=None):
super().__init__()
self.lights = (
lights if lights is not None else PointLights(device=device)
)
self.materials = (
materials if materials is not None else Materials(device=device)
)
self.cameras = (
cameras
if cameras is not None
else OpenGLPerspectiveCameras(device=device)
)
def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
texels = interpolate_vertex_colors(fragments, meshes)
cameras = kwargs.get("cameras", self.cameras)
lights = kwargs.get("lights", self.lights)
materials = kwargs.get("materials", self.materials)
colors = flat_shading(
meshes=meshes,
fragments=fragments,
texels=texels,
lights=lights,
cameras=cameras,
materials=materials,
)
images = hard_rgb_blend(colors, fragments)
return images
class SoftSilhouetteShader(nn.Module):
""" """
Calculate the silhouette by blending the top K faces for each pixel based Calculate the silhouette by blending the top K faces for each pixel based
on the 2d euclidean distance of the centre of the pixel to the mesh face. on the 2d euclidean distance of the centre of the pixel to the mesh face.

View File

@ -124,3 +124,39 @@ def gourad_shading(
face_colors = verts_colors_shaded[faces] face_colors = verts_colors_shaded[faces]
colors = interpolate_face_attributes(fragments, face_colors) colors = interpolate_face_attributes(fragments, face_colors)
return colors return colors
def flat_shading(
meshes, fragments, lights, cameras, materials, texels
) -> torch.Tensor:
"""
Apply per face shading. Use the average face position and the face normals
to compute the ambient, diffuse and specular lighting. Apply the ambient
and diffuse color to the pixel color and add the specular component to
determine the final pixel color.
Args:
meshes: Batch of meshes
fragments: Fragments named tuple with the outputs of rasterization
lights: Lights class containing a batch of lights parameters
cameras: Cameras class containing a batch of cameras parameters
materials: Materials class containing a batch of material properties
texels: texture per pixel of shape (N, H, W, K, 3)
Returns:
colors: (N, H, W, K, 3)
"""
verts = meshes.verts_packed() # (V, 3)
faces = meshes.faces_packed() # (F, 3)
face_normals = meshes.faces_normals_packed() # (V, 3)
faces_verts = verts[faces]
face_coords = faces_verts.mean(dim=-2) # (F, 3, XYZ) mean xyz across verts
pixel_coords = face_coords[fragments.pix_to_face]
pixel_normals = face_normals[fragments.pix_to_face]
# Calculate the illumination at each face
ambient, diffuse, specular = _apply_lighting(
pixel_coords, pixel_normals, lights, cameras, materials
)
colors = (ambient + diffuse) * texels + specular
return colors

View File

@ -25,10 +25,10 @@ from pytorch3d.renderer.mesh.rasterizer import (
from pytorch3d.renderer.mesh.renderer import MeshRenderer from pytorch3d.renderer.mesh.renderer import MeshRenderer
from pytorch3d.renderer.mesh.shader import ( from pytorch3d.renderer.mesh.shader import (
BlendParams, BlendParams,
GouradShader, HardGouradShader,
PhongShader, HardPhongShader,
SilhouetteShader, SoftSilhouetteShader,
TexturedPhongShader, TexturedSoftPhongShader,
) )
from pytorch3d.renderer.mesh.texturing import Textures from pytorch3d.renderer.mesh.texturing import Textures
from pytorch3d.structures.meshes import Meshes from pytorch3d.structures.meshes import Meshes
@ -92,7 +92,7 @@ class TestRenderingMeshes(unittest.TestCase):
) )
renderer = MeshRenderer( renderer = MeshRenderer(
rasterizer=rasterizer, rasterizer=rasterizer,
shader=PhongShader( shader=HardPhongShader(
lights=lights, cameras=cameras, materials=materials lights=lights, cameras=cameras, materials=materials
), ),
) )
@ -133,7 +133,7 @@ class TestRenderingMeshes(unittest.TestCase):
lights.location = torch.tensor([0.0, 0.0, -2.0], device=device)[None] lights.location = torch.tensor([0.0, 0.0, -2.0], device=device)[None]
renderer = MeshRenderer( renderer = MeshRenderer(
rasterizer=rasterizer, rasterizer=rasterizer,
shader=GouradShader( shader=HardGouradShader(
lights=lights, cameras=cameras, materials=materials lights=lights, cameras=cameras, materials=materials
), ),
) )
@ -197,7 +197,7 @@ class TestRenderingMeshes(unittest.TestCase):
rasterizer=MeshRasterizer( rasterizer=MeshRasterizer(
cameras=cameras, raster_settings=raster_settings cameras=cameras, raster_settings=raster_settings
), ),
shader=PhongShader( shader=HardPhongShader(
lights=lights, cameras=cameras, materials=materials lights=lights, cameras=cameras, materials=materials
), ),
) )
@ -242,7 +242,7 @@ class TestRenderingMeshes(unittest.TestCase):
rasterizer=MeshRasterizer( rasterizer=MeshRasterizer(
cameras=cameras, raster_settings=raster_settings cameras=cameras, raster_settings=raster_settings
), ),
shader=SilhouetteShader(blend_params=blend_params), shader=SoftSilhouetteShader(blend_params=blend_params),
) )
images = renderer(sphere_mesh) images = renderer(sphere_mesh)
alpha = images[0, ..., 3].squeeze().cpu() alpha = images[0, ..., 3].squeeze().cpu()
@ -296,7 +296,7 @@ class TestRenderingMeshes(unittest.TestCase):
rasterizer=MeshRasterizer( rasterizer=MeshRasterizer(
cameras=cameras, raster_settings=raster_settings cameras=cameras, raster_settings=raster_settings
), ),
shader=TexturedPhongShader( shader=TexturedSoftPhongShader(
lights=lights, cameras=cameras, materials=materials lights=lights, cameras=cameras, materials=materials
), ),
) )