mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-20 21:02:48 +08:00
Shader API more consistent naming
Summary: Renamed shaders to be prefixed with Hard/Soft depending on if they use a probabalistic blending (Soft) or use the closest face (Hard). There is some code duplication but I thought it would be cleaner to have separate shaders for each task rather than: - inheritance (which we discussed previously that we want to avoid) - boolean (hard/soft) or a string (hard/soft) - new blending functions other than the ones provided would need if statements in the current shaders which might get messy. Also added a `flat_shading` function and a `FlatShader` - I could make this into a tutorial as it was really easy to add a new shader and it might be a nice showcase. NOTE: There are a few more places where the naming will need to change (e.g the tutorials) but I wanted to reach a consensus on this before changing it everywhere. Reviewed By: jcjohnson Differential Revision: D19761036 fbshipit-source-id: f972f6530c7f66dc5550b0284c191abc4a7f6fc4
This commit is contained in:
parent
60f3c4e7d2
commit
f0dc65110a
4
.gitignore
vendored
4
.gitignore
vendored
@ -2,6 +2,10 @@ build/
|
||||
dist/
|
||||
*.egg-info/
|
||||
**/__pycache__/
|
||||
*-checkpoint.ipynb
|
||||
**/.ipynb_checkpoints
|
||||
**/.ipynb_checkpoints/**
|
||||
|
||||
|
||||
# Docusaurus site
|
||||
website/yarn.lock
|
||||
|
@ -84,3 +84,25 @@ renderer = MeshRenderer(
|
||||
shader=PhongShader(device=device, cameras=cameras)
|
||||
)
|
||||
```
|
||||
|
||||
### A custom shader
|
||||
|
||||
Shaders are the most flexible part of the PyTorch3D rendering API. We have created some examples of shaders in `shaders.py` but this is a non exhaustive set.
|
||||
|
||||
A shader can incorporate several steps:
|
||||
- **texturing** (e.g interpolation of vertex RGB colors or interpolation of vertex UV coordinates followed by sampling from a texture map (interpolation uses barycentric coordinates output from rasterization))
|
||||
- **lighting/shading** (e.g. ambient, diffuse, specular lighting, Phong, Gourad, Flat)
|
||||
- **blending** (e.g. hard blending using only the closest face for each pixel, or soft blending using a weighted sum of the top K faces per pixel)
|
||||
|
||||
We have examples of several combinations of these functions based on the texturing/shading/blending support we have currently. These are summarised in this table below. Many other combinations are possible and we plan to expand the options available for texturing, shading and blending.
|
||||
|
||||
|
||||
|Example Shaders | Vertex Textures| Texture Map| Flat Shading| Gourad Shading| Phong Shading | Hard blending | Soft Blending |
|
||||
| ------------- |:-------------: | :--------------:| :--------------:| :--------------:| :--------------:|:--------------:|:--------------:|
|
||||
| HardPhongShader | :heavy_check_mark: |||| :heavy_check_mark: | :heavy_check_mark:||
|
||||
| SoftPhongShader | :heavy_check_mark: |||| :heavy_check_mark: | | :heavy_check_mark:|
|
||||
| HardGouradShader | :heavy_check_mark: ||| :heavy_check_mark: || :heavy_check_mark:||
|
||||
| SoftGouradShader | :heavy_check_mark: ||| :heavy_check_mark: ||| :heavy_check_mark:|
|
||||
| TexturedSoftPhongShader || :heavy_check_mark: ||| :heavy_check_mark: || :heavy_check_mark:|
|
||||
| HardFlatShader | :heavy_check_mark: || :heavy_check_mark: ||| :heavy_check_mark:||
|
||||
| SoftSilhouetteShader ||||||| :heavy_check_mark:|
|
||||
|
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
@ -17,13 +17,16 @@ from .cameras import (
|
||||
from .lighting import DirectionalLights, PointLights, diffuse, specular
|
||||
from .materials import Materials
|
||||
from .mesh import (
|
||||
GouradShader,
|
||||
HardFlatShader,
|
||||
HardGouradShader,
|
||||
HardPhongShader,
|
||||
MeshRasterizer,
|
||||
MeshRenderer,
|
||||
PhongShader,
|
||||
RasterizationSettings,
|
||||
SilhouetteShader,
|
||||
TexturedPhongShader,
|
||||
SoftGouradShader,
|
||||
SoftPhongShader,
|
||||
SoftSilhouetteShader,
|
||||
TexturedSoftPhongShader,
|
||||
gourad_shading,
|
||||
interpolate_face_attributes,
|
||||
interpolate_texture_map,
|
||||
|
@ -4,10 +4,13 @@ from .rasterize_meshes import rasterize_meshes
|
||||
from .rasterizer import MeshRasterizer, RasterizationSettings
|
||||
from .renderer import MeshRenderer
|
||||
from .shader import (
|
||||
GouradShader,
|
||||
PhongShader,
|
||||
SilhouetteShader,
|
||||
TexturedPhongShader,
|
||||
HardFlatShader,
|
||||
HardGouradShader,
|
||||
HardPhongShader,
|
||||
SoftGouradShader,
|
||||
SoftPhongShader,
|
||||
SoftSilhouetteShader,
|
||||
TexturedSoftPhongShader,
|
||||
)
|
||||
from .shading import gourad_shading, phong_shading
|
||||
from .texturing import ( # isort: skip
|
||||
|
@ -14,7 +14,7 @@ from ..blending import (
|
||||
from ..cameras import OpenGLPerspectiveCameras
|
||||
from ..lighting import PointLights
|
||||
from ..materials import Materials
|
||||
from .shading import gourad_shading, phong_shading
|
||||
from .shading import flat_shading, gourad_shading, phong_shading
|
||||
from .texturing import interpolate_texture_map, interpolate_vertex_colors
|
||||
|
||||
# A Shader should take as input fragments from the output of rasterization
|
||||
@ -26,17 +26,18 @@ from .texturing import interpolate_texture_map, interpolate_vertex_colors
|
||||
# - blend colors across top K faces per pixel.
|
||||
|
||||
|
||||
class PhongShader(nn.Module):
|
||||
class HardPhongShader(nn.Module):
|
||||
"""
|
||||
Per pixel lighting. Apply the lighting model using the interpolated coords
|
||||
and normals for each pixel.
|
||||
Per pixel lighting - the lighting model is applied using the interpolated
|
||||
coordinates and normals for each pixel. The blending function hard assigns
|
||||
the color of the closest face for each pixel.
|
||||
|
||||
To use the default values, simply initialize the shader with the desired
|
||||
device e.g.
|
||||
|
||||
.. code-block::
|
||||
|
||||
shader = PhongShader(device=torch.device("cuda:0"))
|
||||
shader = HardPhongShader(device=torch.device("cuda:0"))
|
||||
"""
|
||||
|
||||
def __init__(self, device="cpu", cameras=None, lights=None, materials=None):
|
||||
@ -70,17 +71,74 @@ class PhongShader(nn.Module):
|
||||
return images
|
||||
|
||||
|
||||
class GouradShader(nn.Module):
|
||||
class SoftPhongShader(nn.Module):
|
||||
"""
|
||||
Per vertex lighting. Apply the lighting model to the vertex colors and then
|
||||
interpolate using the barycentric coordinates to get colors for each pixel.
|
||||
Per pixel lighting - the lighting model is applied using the interpolated
|
||||
coordinates and normals for each pixel. The blending function returns the
|
||||
soft aggregated color using all the faces per pixel.
|
||||
|
||||
To use the default values, simply initialize the shader with the desired
|
||||
device e.g.
|
||||
|
||||
.. code-block::
|
||||
|
||||
shader = GouradShader(device=torch.device("cuda:0"))
|
||||
shader = SoftPhongShader(device=torch.device("cuda:0"))
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
device="cpu",
|
||||
cameras=None,
|
||||
lights=None,
|
||||
materials=None,
|
||||
blend_params=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.lights = (
|
||||
lights if lights is not None else PointLights(device=device)
|
||||
)
|
||||
self.materials = (
|
||||
materials if materials is not None else Materials(device=device)
|
||||
)
|
||||
self.cameras = (
|
||||
cameras
|
||||
if cameras is not None
|
||||
else OpenGLPerspectiveCameras(device=device)
|
||||
)
|
||||
self.blend_params = (
|
||||
blend_params if blend_params is not None else BlendParams()
|
||||
)
|
||||
|
||||
def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
|
||||
texels = interpolate_vertex_colors(fragments, meshes)
|
||||
cameras = kwargs.get("cameras", self.cameras)
|
||||
lights = kwargs.get("lights", self.lights)
|
||||
materials = kwargs.get("materials", self.materials)
|
||||
colors = phong_shading(
|
||||
meshes=meshes,
|
||||
fragments=fragments,
|
||||
texels=texels,
|
||||
lights=lights,
|
||||
cameras=cameras,
|
||||
materials=materials,
|
||||
)
|
||||
images = softmax_rgb_blend(colors, fragments, self.blend_params)
|
||||
return images
|
||||
|
||||
|
||||
class HardGouradShader(nn.Module):
|
||||
"""
|
||||
Per vertex lighting - the lighting model is applied to the vertex colors and
|
||||
the colors are then interpolated using the barycentric coordinates to
|
||||
obtain the colors for each pixel. The blending function hard assigns
|
||||
the color of the closest face for each pixel.
|
||||
|
||||
To use the default values, simply initialize the shader with the desired
|
||||
device e.g.
|
||||
|
||||
.. code-block::
|
||||
|
||||
shader = HardGouradShader(device=torch.device("cuda:0"))
|
||||
"""
|
||||
|
||||
def __init__(self, device="cpu", cameras=None, lights=None, materials=None):
|
||||
@ -112,12 +170,69 @@ class GouradShader(nn.Module):
|
||||
return images
|
||||
|
||||
|
||||
class TexturedPhongShader(nn.Module):
|
||||
class SoftGouradShader(nn.Module):
|
||||
"""
|
||||
Per vertex lighting - the lighting model is applied to the vertex colors and
|
||||
the colors are then interpolated using the barycentric coordinates to
|
||||
obtain the colors for each pixel. The blending function returns the
|
||||
soft aggregated color using all the faces per pixel.
|
||||
|
||||
To use the default values, simply initialize the shader with the desired
|
||||
device e.g.
|
||||
|
||||
.. code-block::
|
||||
|
||||
shader = SoftGouradShader(device=torch.device("cuda:0"))
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
device="cpu",
|
||||
cameras=None,
|
||||
lights=None,
|
||||
materials=None,
|
||||
blend_params=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.lights = (
|
||||
lights if lights is not None else PointLights(device=device)
|
||||
)
|
||||
self.materials = (
|
||||
materials if materials is not None else Materials(device=device)
|
||||
)
|
||||
self.cameras = (
|
||||
cameras
|
||||
if cameras is not None
|
||||
else OpenGLPerspectiveCameras(device=device)
|
||||
)
|
||||
self.blend_params = (
|
||||
blend_params if blend_params is not None else BlendParams()
|
||||
)
|
||||
|
||||
def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
|
||||
cameras = kwargs.get("cameras", self.cameras)
|
||||
lights = kwargs.get("lights", self.lights)
|
||||
materials = kwargs.get("materials", self.materials)
|
||||
pixel_colors = gourad_shading(
|
||||
meshes=meshes,
|
||||
fragments=fragments,
|
||||
lights=lights,
|
||||
cameras=cameras,
|
||||
materials=materials,
|
||||
)
|
||||
images = softmax_rgb_blend(pixel_colors, fragments, self.blend_params)
|
||||
return images
|
||||
|
||||
|
||||
class TexturedSoftPhongShader(nn.Module):
|
||||
"""
|
||||
Per pixel lighting applied to a texture map. First interpolate the vertex
|
||||
uv coordinates and sample from a texture map. Then apply the lighting model
|
||||
using the interpolated coords and normals for each pixel.
|
||||
|
||||
The blending function returns the soft aggregated color using all
|
||||
the faces per pixel.
|
||||
|
||||
To use the default values, simply initialize the shader with the desired
|
||||
device e.g.
|
||||
|
||||
@ -167,7 +282,52 @@ class TexturedPhongShader(nn.Module):
|
||||
return images
|
||||
|
||||
|
||||
class SilhouetteShader(nn.Module):
|
||||
class HardFlatShader(nn.Module):
|
||||
"""
|
||||
Per face lighting - the lighting model is applied using the average face
|
||||
position and the face normal. The blending function hard assigns
|
||||
the color of the closest face for each pixel.
|
||||
|
||||
To use the default values, simply initialize the shader with the desired
|
||||
device e.g.
|
||||
|
||||
.. code-block::
|
||||
|
||||
shader = HardFlatShader(device=torch.device("cuda:0"))
|
||||
"""
|
||||
|
||||
def __init__(self, device="cpu", cameras=None, lights=None, materials=None):
|
||||
super().__init__()
|
||||
self.lights = (
|
||||
lights if lights is not None else PointLights(device=device)
|
||||
)
|
||||
self.materials = (
|
||||
materials if materials is not None else Materials(device=device)
|
||||
)
|
||||
self.cameras = (
|
||||
cameras
|
||||
if cameras is not None
|
||||
else OpenGLPerspectiveCameras(device=device)
|
||||
)
|
||||
|
||||
def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
|
||||
texels = interpolate_vertex_colors(fragments, meshes)
|
||||
cameras = kwargs.get("cameras", self.cameras)
|
||||
lights = kwargs.get("lights", self.lights)
|
||||
materials = kwargs.get("materials", self.materials)
|
||||
colors = flat_shading(
|
||||
meshes=meshes,
|
||||
fragments=fragments,
|
||||
texels=texels,
|
||||
lights=lights,
|
||||
cameras=cameras,
|
||||
materials=materials,
|
||||
)
|
||||
images = hard_rgb_blend(colors, fragments)
|
||||
return images
|
||||
|
||||
|
||||
class SoftSilhouetteShader(nn.Module):
|
||||
"""
|
||||
Calculate the silhouette by blending the top K faces for each pixel based
|
||||
on the 2d euclidean distance of the centre of the pixel to the mesh face.
|
||||
|
@ -124,3 +124,39 @@ def gourad_shading(
|
||||
face_colors = verts_colors_shaded[faces]
|
||||
colors = interpolate_face_attributes(fragments, face_colors)
|
||||
return colors
|
||||
|
||||
|
||||
def flat_shading(
|
||||
meshes, fragments, lights, cameras, materials, texels
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Apply per face shading. Use the average face position and the face normals
|
||||
to compute the ambient, diffuse and specular lighting. Apply the ambient
|
||||
and diffuse color to the pixel color and add the specular component to
|
||||
determine the final pixel color.
|
||||
|
||||
Args:
|
||||
meshes: Batch of meshes
|
||||
fragments: Fragments named tuple with the outputs of rasterization
|
||||
lights: Lights class containing a batch of lights parameters
|
||||
cameras: Cameras class containing a batch of cameras parameters
|
||||
materials: Materials class containing a batch of material properties
|
||||
texels: texture per pixel of shape (N, H, W, K, 3)
|
||||
|
||||
Returns:
|
||||
colors: (N, H, W, K, 3)
|
||||
"""
|
||||
verts = meshes.verts_packed() # (V, 3)
|
||||
faces = meshes.faces_packed() # (F, 3)
|
||||
face_normals = meshes.faces_normals_packed() # (V, 3)
|
||||
faces_verts = verts[faces]
|
||||
face_coords = faces_verts.mean(dim=-2) # (F, 3, XYZ) mean xyz across verts
|
||||
pixel_coords = face_coords[fragments.pix_to_face]
|
||||
pixel_normals = face_normals[fragments.pix_to_face]
|
||||
|
||||
# Calculate the illumination at each face
|
||||
ambient, diffuse, specular = _apply_lighting(
|
||||
pixel_coords, pixel_normals, lights, cameras, materials
|
||||
)
|
||||
colors = (ambient + diffuse) * texels + specular
|
||||
return colors
|
||||
|
@ -25,10 +25,10 @@ from pytorch3d.renderer.mesh.rasterizer import (
|
||||
from pytorch3d.renderer.mesh.renderer import MeshRenderer
|
||||
from pytorch3d.renderer.mesh.shader import (
|
||||
BlendParams,
|
||||
GouradShader,
|
||||
PhongShader,
|
||||
SilhouetteShader,
|
||||
TexturedPhongShader,
|
||||
HardGouradShader,
|
||||
HardPhongShader,
|
||||
SoftSilhouetteShader,
|
||||
TexturedSoftPhongShader,
|
||||
)
|
||||
from pytorch3d.renderer.mesh.texturing import Textures
|
||||
from pytorch3d.structures.meshes import Meshes
|
||||
@ -92,7 +92,7 @@ class TestRenderingMeshes(unittest.TestCase):
|
||||
)
|
||||
renderer = MeshRenderer(
|
||||
rasterizer=rasterizer,
|
||||
shader=PhongShader(
|
||||
shader=HardPhongShader(
|
||||
lights=lights, cameras=cameras, materials=materials
|
||||
),
|
||||
)
|
||||
@ -133,7 +133,7 @@ class TestRenderingMeshes(unittest.TestCase):
|
||||
lights.location = torch.tensor([0.0, 0.0, -2.0], device=device)[None]
|
||||
renderer = MeshRenderer(
|
||||
rasterizer=rasterizer,
|
||||
shader=GouradShader(
|
||||
shader=HardGouradShader(
|
||||
lights=lights, cameras=cameras, materials=materials
|
||||
),
|
||||
)
|
||||
@ -197,7 +197,7 @@ class TestRenderingMeshes(unittest.TestCase):
|
||||
rasterizer=MeshRasterizer(
|
||||
cameras=cameras, raster_settings=raster_settings
|
||||
),
|
||||
shader=PhongShader(
|
||||
shader=HardPhongShader(
|
||||
lights=lights, cameras=cameras, materials=materials
|
||||
),
|
||||
)
|
||||
@ -242,7 +242,7 @@ class TestRenderingMeshes(unittest.TestCase):
|
||||
rasterizer=MeshRasterizer(
|
||||
cameras=cameras, raster_settings=raster_settings
|
||||
),
|
||||
shader=SilhouetteShader(blend_params=blend_params),
|
||||
shader=SoftSilhouetteShader(blend_params=blend_params),
|
||||
)
|
||||
images = renderer(sphere_mesh)
|
||||
alpha = images[0, ..., 3].squeeze().cpu()
|
||||
@ -296,7 +296,7 @@ class TestRenderingMeshes(unittest.TestCase):
|
||||
rasterizer=MeshRasterizer(
|
||||
cameras=cameras, raster_settings=raster_settings
|
||||
),
|
||||
shader=TexturedPhongShader(
|
||||
shader=TexturedSoftPhongShader(
|
||||
lights=lights, cameras=cameras, materials=materials
|
||||
),
|
||||
)
|
||||
|
Loading…
x
Reference in New Issue
Block a user