mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-21 05:12:48 +08:00
Shader API more consistent naming
Summary: Renamed shaders to be prefixed with Hard/Soft depending on if they use a probabalistic blending (Soft) or use the closest face (Hard). There is some code duplication but I thought it would be cleaner to have separate shaders for each task rather than: - inheritance (which we discussed previously that we want to avoid) - boolean (hard/soft) or a string (hard/soft) - new blending functions other than the ones provided would need if statements in the current shaders which might get messy. Also added a `flat_shading` function and a `FlatShader` - I could make this into a tutorial as it was really easy to add a new shader and it might be a nice showcase. NOTE: There are a few more places where the naming will need to change (e.g the tutorials) but I wanted to reach a consensus on this before changing it everywhere. Reviewed By: jcjohnson Differential Revision: D19761036 fbshipit-source-id: f972f6530c7f66dc5550b0284c191abc4a7f6fc4
This commit is contained in:
parent
60f3c4e7d2
commit
f0dc65110a
4
.gitignore
vendored
4
.gitignore
vendored
@ -2,6 +2,10 @@ build/
|
|||||||
dist/
|
dist/
|
||||||
*.egg-info/
|
*.egg-info/
|
||||||
**/__pycache__/
|
**/__pycache__/
|
||||||
|
*-checkpoint.ipynb
|
||||||
|
**/.ipynb_checkpoints
|
||||||
|
**/.ipynb_checkpoints/**
|
||||||
|
|
||||||
|
|
||||||
# Docusaurus site
|
# Docusaurus site
|
||||||
website/yarn.lock
|
website/yarn.lock
|
||||||
|
@ -84,3 +84,25 @@ renderer = MeshRenderer(
|
|||||||
shader=PhongShader(device=device, cameras=cameras)
|
shader=PhongShader(device=device, cameras=cameras)
|
||||||
)
|
)
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### A custom shader
|
||||||
|
|
||||||
|
Shaders are the most flexible part of the PyTorch3D rendering API. We have created some examples of shaders in `shaders.py` but this is a non exhaustive set.
|
||||||
|
|
||||||
|
A shader can incorporate several steps:
|
||||||
|
- **texturing** (e.g interpolation of vertex RGB colors or interpolation of vertex UV coordinates followed by sampling from a texture map (interpolation uses barycentric coordinates output from rasterization))
|
||||||
|
- **lighting/shading** (e.g. ambient, diffuse, specular lighting, Phong, Gourad, Flat)
|
||||||
|
- **blending** (e.g. hard blending using only the closest face for each pixel, or soft blending using a weighted sum of the top K faces per pixel)
|
||||||
|
|
||||||
|
We have examples of several combinations of these functions based on the texturing/shading/blending support we have currently. These are summarised in this table below. Many other combinations are possible and we plan to expand the options available for texturing, shading and blending.
|
||||||
|
|
||||||
|
|
||||||
|
|Example Shaders | Vertex Textures| Texture Map| Flat Shading| Gourad Shading| Phong Shading | Hard blending | Soft Blending |
|
||||||
|
| ------------- |:-------------: | :--------------:| :--------------:| :--------------:| :--------------:|:--------------:|:--------------:|
|
||||||
|
| HardPhongShader | :heavy_check_mark: |||| :heavy_check_mark: | :heavy_check_mark:||
|
||||||
|
| SoftPhongShader | :heavy_check_mark: |||| :heavy_check_mark: | | :heavy_check_mark:|
|
||||||
|
| HardGouradShader | :heavy_check_mark: ||| :heavy_check_mark: || :heavy_check_mark:||
|
||||||
|
| SoftGouradShader | :heavy_check_mark: ||| :heavy_check_mark: ||| :heavy_check_mark:|
|
||||||
|
| TexturedSoftPhongShader || :heavy_check_mark: ||| :heavy_check_mark: || :heavy_check_mark:|
|
||||||
|
| HardFlatShader | :heavy_check_mark: || :heavy_check_mark: ||| :heavy_check_mark:||
|
||||||
|
| SoftSilhouetteShader ||||||| :heavy_check_mark:|
|
||||||
|
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
@ -17,13 +17,16 @@ from .cameras import (
|
|||||||
from .lighting import DirectionalLights, PointLights, diffuse, specular
|
from .lighting import DirectionalLights, PointLights, diffuse, specular
|
||||||
from .materials import Materials
|
from .materials import Materials
|
||||||
from .mesh import (
|
from .mesh import (
|
||||||
GouradShader,
|
HardFlatShader,
|
||||||
|
HardGouradShader,
|
||||||
|
HardPhongShader,
|
||||||
MeshRasterizer,
|
MeshRasterizer,
|
||||||
MeshRenderer,
|
MeshRenderer,
|
||||||
PhongShader,
|
|
||||||
RasterizationSettings,
|
RasterizationSettings,
|
||||||
SilhouetteShader,
|
SoftGouradShader,
|
||||||
TexturedPhongShader,
|
SoftPhongShader,
|
||||||
|
SoftSilhouetteShader,
|
||||||
|
TexturedSoftPhongShader,
|
||||||
gourad_shading,
|
gourad_shading,
|
||||||
interpolate_face_attributes,
|
interpolate_face_attributes,
|
||||||
interpolate_texture_map,
|
interpolate_texture_map,
|
||||||
|
@ -4,10 +4,13 @@ from .rasterize_meshes import rasterize_meshes
|
|||||||
from .rasterizer import MeshRasterizer, RasterizationSettings
|
from .rasterizer import MeshRasterizer, RasterizationSettings
|
||||||
from .renderer import MeshRenderer
|
from .renderer import MeshRenderer
|
||||||
from .shader import (
|
from .shader import (
|
||||||
GouradShader,
|
HardFlatShader,
|
||||||
PhongShader,
|
HardGouradShader,
|
||||||
SilhouetteShader,
|
HardPhongShader,
|
||||||
TexturedPhongShader,
|
SoftGouradShader,
|
||||||
|
SoftPhongShader,
|
||||||
|
SoftSilhouetteShader,
|
||||||
|
TexturedSoftPhongShader,
|
||||||
)
|
)
|
||||||
from .shading import gourad_shading, phong_shading
|
from .shading import gourad_shading, phong_shading
|
||||||
from .texturing import ( # isort: skip
|
from .texturing import ( # isort: skip
|
||||||
|
@ -14,7 +14,7 @@ from ..blending import (
|
|||||||
from ..cameras import OpenGLPerspectiveCameras
|
from ..cameras import OpenGLPerspectiveCameras
|
||||||
from ..lighting import PointLights
|
from ..lighting import PointLights
|
||||||
from ..materials import Materials
|
from ..materials import Materials
|
||||||
from .shading import gourad_shading, phong_shading
|
from .shading import flat_shading, gourad_shading, phong_shading
|
||||||
from .texturing import interpolate_texture_map, interpolate_vertex_colors
|
from .texturing import interpolate_texture_map, interpolate_vertex_colors
|
||||||
|
|
||||||
# A Shader should take as input fragments from the output of rasterization
|
# A Shader should take as input fragments from the output of rasterization
|
||||||
@ -26,17 +26,18 @@ from .texturing import interpolate_texture_map, interpolate_vertex_colors
|
|||||||
# - blend colors across top K faces per pixel.
|
# - blend colors across top K faces per pixel.
|
||||||
|
|
||||||
|
|
||||||
class PhongShader(nn.Module):
|
class HardPhongShader(nn.Module):
|
||||||
"""
|
"""
|
||||||
Per pixel lighting. Apply the lighting model using the interpolated coords
|
Per pixel lighting - the lighting model is applied using the interpolated
|
||||||
and normals for each pixel.
|
coordinates and normals for each pixel. The blending function hard assigns
|
||||||
|
the color of the closest face for each pixel.
|
||||||
|
|
||||||
To use the default values, simply initialize the shader with the desired
|
To use the default values, simply initialize the shader with the desired
|
||||||
device e.g.
|
device e.g.
|
||||||
|
|
||||||
.. code-block::
|
.. code-block::
|
||||||
|
|
||||||
shader = PhongShader(device=torch.device("cuda:0"))
|
shader = HardPhongShader(device=torch.device("cuda:0"))
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, device="cpu", cameras=None, lights=None, materials=None):
|
def __init__(self, device="cpu", cameras=None, lights=None, materials=None):
|
||||||
@ -70,17 +71,74 @@ class PhongShader(nn.Module):
|
|||||||
return images
|
return images
|
||||||
|
|
||||||
|
|
||||||
class GouradShader(nn.Module):
|
class SoftPhongShader(nn.Module):
|
||||||
"""
|
"""
|
||||||
Per vertex lighting. Apply the lighting model to the vertex colors and then
|
Per pixel lighting - the lighting model is applied using the interpolated
|
||||||
interpolate using the barycentric coordinates to get colors for each pixel.
|
coordinates and normals for each pixel. The blending function returns the
|
||||||
|
soft aggregated color using all the faces per pixel.
|
||||||
|
|
||||||
To use the default values, simply initialize the shader with the desired
|
To use the default values, simply initialize the shader with the desired
|
||||||
device e.g.
|
device e.g.
|
||||||
|
|
||||||
.. code-block::
|
.. code-block::
|
||||||
|
|
||||||
shader = GouradShader(device=torch.device("cuda:0"))
|
shader = SoftPhongShader(device=torch.device("cuda:0"))
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
device="cpu",
|
||||||
|
cameras=None,
|
||||||
|
lights=None,
|
||||||
|
materials=None,
|
||||||
|
blend_params=None,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.lights = (
|
||||||
|
lights if lights is not None else PointLights(device=device)
|
||||||
|
)
|
||||||
|
self.materials = (
|
||||||
|
materials if materials is not None else Materials(device=device)
|
||||||
|
)
|
||||||
|
self.cameras = (
|
||||||
|
cameras
|
||||||
|
if cameras is not None
|
||||||
|
else OpenGLPerspectiveCameras(device=device)
|
||||||
|
)
|
||||||
|
self.blend_params = (
|
||||||
|
blend_params if blend_params is not None else BlendParams()
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
|
||||||
|
texels = interpolate_vertex_colors(fragments, meshes)
|
||||||
|
cameras = kwargs.get("cameras", self.cameras)
|
||||||
|
lights = kwargs.get("lights", self.lights)
|
||||||
|
materials = kwargs.get("materials", self.materials)
|
||||||
|
colors = phong_shading(
|
||||||
|
meshes=meshes,
|
||||||
|
fragments=fragments,
|
||||||
|
texels=texels,
|
||||||
|
lights=lights,
|
||||||
|
cameras=cameras,
|
||||||
|
materials=materials,
|
||||||
|
)
|
||||||
|
images = softmax_rgb_blend(colors, fragments, self.blend_params)
|
||||||
|
return images
|
||||||
|
|
||||||
|
|
||||||
|
class HardGouradShader(nn.Module):
|
||||||
|
"""
|
||||||
|
Per vertex lighting - the lighting model is applied to the vertex colors and
|
||||||
|
the colors are then interpolated using the barycentric coordinates to
|
||||||
|
obtain the colors for each pixel. The blending function hard assigns
|
||||||
|
the color of the closest face for each pixel.
|
||||||
|
|
||||||
|
To use the default values, simply initialize the shader with the desired
|
||||||
|
device e.g.
|
||||||
|
|
||||||
|
.. code-block::
|
||||||
|
|
||||||
|
shader = HardGouradShader(device=torch.device("cuda:0"))
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, device="cpu", cameras=None, lights=None, materials=None):
|
def __init__(self, device="cpu", cameras=None, lights=None, materials=None):
|
||||||
@ -112,12 +170,69 @@ class GouradShader(nn.Module):
|
|||||||
return images
|
return images
|
||||||
|
|
||||||
|
|
||||||
class TexturedPhongShader(nn.Module):
|
class SoftGouradShader(nn.Module):
|
||||||
|
"""
|
||||||
|
Per vertex lighting - the lighting model is applied to the vertex colors and
|
||||||
|
the colors are then interpolated using the barycentric coordinates to
|
||||||
|
obtain the colors for each pixel. The blending function returns the
|
||||||
|
soft aggregated color using all the faces per pixel.
|
||||||
|
|
||||||
|
To use the default values, simply initialize the shader with the desired
|
||||||
|
device e.g.
|
||||||
|
|
||||||
|
.. code-block::
|
||||||
|
|
||||||
|
shader = SoftGouradShader(device=torch.device("cuda:0"))
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
device="cpu",
|
||||||
|
cameras=None,
|
||||||
|
lights=None,
|
||||||
|
materials=None,
|
||||||
|
blend_params=None,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.lights = (
|
||||||
|
lights if lights is not None else PointLights(device=device)
|
||||||
|
)
|
||||||
|
self.materials = (
|
||||||
|
materials if materials is not None else Materials(device=device)
|
||||||
|
)
|
||||||
|
self.cameras = (
|
||||||
|
cameras
|
||||||
|
if cameras is not None
|
||||||
|
else OpenGLPerspectiveCameras(device=device)
|
||||||
|
)
|
||||||
|
self.blend_params = (
|
||||||
|
blend_params if blend_params is not None else BlendParams()
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
|
||||||
|
cameras = kwargs.get("cameras", self.cameras)
|
||||||
|
lights = kwargs.get("lights", self.lights)
|
||||||
|
materials = kwargs.get("materials", self.materials)
|
||||||
|
pixel_colors = gourad_shading(
|
||||||
|
meshes=meshes,
|
||||||
|
fragments=fragments,
|
||||||
|
lights=lights,
|
||||||
|
cameras=cameras,
|
||||||
|
materials=materials,
|
||||||
|
)
|
||||||
|
images = softmax_rgb_blend(pixel_colors, fragments, self.blend_params)
|
||||||
|
return images
|
||||||
|
|
||||||
|
|
||||||
|
class TexturedSoftPhongShader(nn.Module):
|
||||||
"""
|
"""
|
||||||
Per pixel lighting applied to a texture map. First interpolate the vertex
|
Per pixel lighting applied to a texture map. First interpolate the vertex
|
||||||
uv coordinates and sample from a texture map. Then apply the lighting model
|
uv coordinates and sample from a texture map. Then apply the lighting model
|
||||||
using the interpolated coords and normals for each pixel.
|
using the interpolated coords and normals for each pixel.
|
||||||
|
|
||||||
|
The blending function returns the soft aggregated color using all
|
||||||
|
the faces per pixel.
|
||||||
|
|
||||||
To use the default values, simply initialize the shader with the desired
|
To use the default values, simply initialize the shader with the desired
|
||||||
device e.g.
|
device e.g.
|
||||||
|
|
||||||
@ -167,7 +282,52 @@ class TexturedPhongShader(nn.Module):
|
|||||||
return images
|
return images
|
||||||
|
|
||||||
|
|
||||||
class SilhouetteShader(nn.Module):
|
class HardFlatShader(nn.Module):
|
||||||
|
"""
|
||||||
|
Per face lighting - the lighting model is applied using the average face
|
||||||
|
position and the face normal. The blending function hard assigns
|
||||||
|
the color of the closest face for each pixel.
|
||||||
|
|
||||||
|
To use the default values, simply initialize the shader with the desired
|
||||||
|
device e.g.
|
||||||
|
|
||||||
|
.. code-block::
|
||||||
|
|
||||||
|
shader = HardFlatShader(device=torch.device("cuda:0"))
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, device="cpu", cameras=None, lights=None, materials=None):
|
||||||
|
super().__init__()
|
||||||
|
self.lights = (
|
||||||
|
lights if lights is not None else PointLights(device=device)
|
||||||
|
)
|
||||||
|
self.materials = (
|
||||||
|
materials if materials is not None else Materials(device=device)
|
||||||
|
)
|
||||||
|
self.cameras = (
|
||||||
|
cameras
|
||||||
|
if cameras is not None
|
||||||
|
else OpenGLPerspectiveCameras(device=device)
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
|
||||||
|
texels = interpolate_vertex_colors(fragments, meshes)
|
||||||
|
cameras = kwargs.get("cameras", self.cameras)
|
||||||
|
lights = kwargs.get("lights", self.lights)
|
||||||
|
materials = kwargs.get("materials", self.materials)
|
||||||
|
colors = flat_shading(
|
||||||
|
meshes=meshes,
|
||||||
|
fragments=fragments,
|
||||||
|
texels=texels,
|
||||||
|
lights=lights,
|
||||||
|
cameras=cameras,
|
||||||
|
materials=materials,
|
||||||
|
)
|
||||||
|
images = hard_rgb_blend(colors, fragments)
|
||||||
|
return images
|
||||||
|
|
||||||
|
|
||||||
|
class SoftSilhouetteShader(nn.Module):
|
||||||
"""
|
"""
|
||||||
Calculate the silhouette by blending the top K faces for each pixel based
|
Calculate the silhouette by blending the top K faces for each pixel based
|
||||||
on the 2d euclidean distance of the centre of the pixel to the mesh face.
|
on the 2d euclidean distance of the centre of the pixel to the mesh face.
|
||||||
|
@ -124,3 +124,39 @@ def gourad_shading(
|
|||||||
face_colors = verts_colors_shaded[faces]
|
face_colors = verts_colors_shaded[faces]
|
||||||
colors = interpolate_face_attributes(fragments, face_colors)
|
colors = interpolate_face_attributes(fragments, face_colors)
|
||||||
return colors
|
return colors
|
||||||
|
|
||||||
|
|
||||||
|
def flat_shading(
|
||||||
|
meshes, fragments, lights, cameras, materials, texels
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Apply per face shading. Use the average face position and the face normals
|
||||||
|
to compute the ambient, diffuse and specular lighting. Apply the ambient
|
||||||
|
and diffuse color to the pixel color and add the specular component to
|
||||||
|
determine the final pixel color.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
meshes: Batch of meshes
|
||||||
|
fragments: Fragments named tuple with the outputs of rasterization
|
||||||
|
lights: Lights class containing a batch of lights parameters
|
||||||
|
cameras: Cameras class containing a batch of cameras parameters
|
||||||
|
materials: Materials class containing a batch of material properties
|
||||||
|
texels: texture per pixel of shape (N, H, W, K, 3)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
colors: (N, H, W, K, 3)
|
||||||
|
"""
|
||||||
|
verts = meshes.verts_packed() # (V, 3)
|
||||||
|
faces = meshes.faces_packed() # (F, 3)
|
||||||
|
face_normals = meshes.faces_normals_packed() # (V, 3)
|
||||||
|
faces_verts = verts[faces]
|
||||||
|
face_coords = faces_verts.mean(dim=-2) # (F, 3, XYZ) mean xyz across verts
|
||||||
|
pixel_coords = face_coords[fragments.pix_to_face]
|
||||||
|
pixel_normals = face_normals[fragments.pix_to_face]
|
||||||
|
|
||||||
|
# Calculate the illumination at each face
|
||||||
|
ambient, diffuse, specular = _apply_lighting(
|
||||||
|
pixel_coords, pixel_normals, lights, cameras, materials
|
||||||
|
)
|
||||||
|
colors = (ambient + diffuse) * texels + specular
|
||||||
|
return colors
|
||||||
|
@ -25,10 +25,10 @@ from pytorch3d.renderer.mesh.rasterizer import (
|
|||||||
from pytorch3d.renderer.mesh.renderer import MeshRenderer
|
from pytorch3d.renderer.mesh.renderer import MeshRenderer
|
||||||
from pytorch3d.renderer.mesh.shader import (
|
from pytorch3d.renderer.mesh.shader import (
|
||||||
BlendParams,
|
BlendParams,
|
||||||
GouradShader,
|
HardGouradShader,
|
||||||
PhongShader,
|
HardPhongShader,
|
||||||
SilhouetteShader,
|
SoftSilhouetteShader,
|
||||||
TexturedPhongShader,
|
TexturedSoftPhongShader,
|
||||||
)
|
)
|
||||||
from pytorch3d.renderer.mesh.texturing import Textures
|
from pytorch3d.renderer.mesh.texturing import Textures
|
||||||
from pytorch3d.structures.meshes import Meshes
|
from pytorch3d.structures.meshes import Meshes
|
||||||
@ -92,7 +92,7 @@ class TestRenderingMeshes(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
renderer = MeshRenderer(
|
renderer = MeshRenderer(
|
||||||
rasterizer=rasterizer,
|
rasterizer=rasterizer,
|
||||||
shader=PhongShader(
|
shader=HardPhongShader(
|
||||||
lights=lights, cameras=cameras, materials=materials
|
lights=lights, cameras=cameras, materials=materials
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
@ -133,7 +133,7 @@ class TestRenderingMeshes(unittest.TestCase):
|
|||||||
lights.location = torch.tensor([0.0, 0.0, -2.0], device=device)[None]
|
lights.location = torch.tensor([0.0, 0.0, -2.0], device=device)[None]
|
||||||
renderer = MeshRenderer(
|
renderer = MeshRenderer(
|
||||||
rasterizer=rasterizer,
|
rasterizer=rasterizer,
|
||||||
shader=GouradShader(
|
shader=HardGouradShader(
|
||||||
lights=lights, cameras=cameras, materials=materials
|
lights=lights, cameras=cameras, materials=materials
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
@ -197,7 +197,7 @@ class TestRenderingMeshes(unittest.TestCase):
|
|||||||
rasterizer=MeshRasterizer(
|
rasterizer=MeshRasterizer(
|
||||||
cameras=cameras, raster_settings=raster_settings
|
cameras=cameras, raster_settings=raster_settings
|
||||||
),
|
),
|
||||||
shader=PhongShader(
|
shader=HardPhongShader(
|
||||||
lights=lights, cameras=cameras, materials=materials
|
lights=lights, cameras=cameras, materials=materials
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
@ -242,7 +242,7 @@ class TestRenderingMeshes(unittest.TestCase):
|
|||||||
rasterizer=MeshRasterizer(
|
rasterizer=MeshRasterizer(
|
||||||
cameras=cameras, raster_settings=raster_settings
|
cameras=cameras, raster_settings=raster_settings
|
||||||
),
|
),
|
||||||
shader=SilhouetteShader(blend_params=blend_params),
|
shader=SoftSilhouetteShader(blend_params=blend_params),
|
||||||
)
|
)
|
||||||
images = renderer(sphere_mesh)
|
images = renderer(sphere_mesh)
|
||||||
alpha = images[0, ..., 3].squeeze().cpu()
|
alpha = images[0, ..., 3].squeeze().cpu()
|
||||||
@ -296,7 +296,7 @@ class TestRenderingMeshes(unittest.TestCase):
|
|||||||
rasterizer=MeshRasterizer(
|
rasterizer=MeshRasterizer(
|
||||||
cameras=cameras, raster_settings=raster_settings
|
cameras=cameras, raster_settings=raster_settings
|
||||||
),
|
),
|
||||||
shader=TexturedPhongShader(
|
shader=TexturedSoftPhongShader(
|
||||||
lights=lights, cameras=cameras, materials=materials
|
lights=lights, cameras=cameras, materials=materials
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user