mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 20:02:49 +08:00
SplatterPhongShader 1: Pull out common Shader functionality into ShaderBase
Summary: Most of the shaders copypaste exactly the same code into `__init__` and `to`. I will be adding a new shader in the next diff, so let's make it a bit easier. Reviewed By: bottler Differential Revision: D35767884 fbshipit-source-id: 0057e3e2ae3be4eaa49ae7e2bf3e4176953dde9d
This commit is contained in:
parent
9f443ed26b
commit
96889deab9
@ -32,22 +32,7 @@ from .shading import flat_shading, gouraud_shading, phong_shading
|
|||||||
# - sample colors from a texture map
|
# - sample colors from a texture map
|
||||||
# - apply per pixel lighting
|
# - apply per pixel lighting
|
||||||
# - blend colors across top K faces per pixel.
|
# - blend colors across top K faces per pixel.
|
||||||
|
class ShaderBase(nn.Module):
|
||||||
|
|
||||||
class HardPhongShader(nn.Module):
|
|
||||||
"""
|
|
||||||
Per pixel lighting - the lighting model is applied using the interpolated
|
|
||||||
coordinates and normals for each pixel. The blending function hard assigns
|
|
||||||
the color of the closest face for each pixel.
|
|
||||||
|
|
||||||
To use the default values, simply initialize the shader with the desired
|
|
||||||
device e.g.
|
|
||||||
|
|
||||||
.. code-block::
|
|
||||||
|
|
||||||
shader = HardPhongShader(device=torch.device("cuda:0"))
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
device: Device = "cpu",
|
device: Device = "cpu",
|
||||||
@ -74,6 +59,21 @@ class HardPhongShader(nn.Module):
|
|||||||
self.lights = self.lights.to(device)
|
self.lights = self.lights.to(device)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
||||||
|
class HardPhongShader(ShaderBase):
|
||||||
|
"""
|
||||||
|
Per pixel lighting - the lighting model is applied using the interpolated
|
||||||
|
coordinates and normals for each pixel. The blending function hard assigns
|
||||||
|
the color of the closest face for each pixel.
|
||||||
|
|
||||||
|
To use the default values, simply initialize the shader with the desired
|
||||||
|
device e.g.
|
||||||
|
|
||||||
|
.. code-block::
|
||||||
|
|
||||||
|
shader = HardPhongShader(device=torch.device("cuda:0"))
|
||||||
|
"""
|
||||||
|
|
||||||
def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor:
|
def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor:
|
||||||
cameras = kwargs.get("cameras", self.cameras)
|
cameras = kwargs.get("cameras", self.cameras)
|
||||||
if cameras is None:
|
if cameras is None:
|
||||||
@ -97,7 +97,7 @@ class HardPhongShader(nn.Module):
|
|||||||
return images
|
return images
|
||||||
|
|
||||||
|
|
||||||
class SoftPhongShader(nn.Module):
|
class SoftPhongShader(ShaderBase):
|
||||||
"""
|
"""
|
||||||
Per pixel lighting - the lighting model is applied using the interpolated
|
Per pixel lighting - the lighting model is applied using the interpolated
|
||||||
coordinates and normals for each pixel. The blending function returns the
|
coordinates and normals for each pixel. The blending function returns the
|
||||||
@ -111,32 +111,6 @@ class SoftPhongShader(nn.Module):
|
|||||||
shader = SoftPhongShader(device=torch.device("cuda:0"))
|
shader = SoftPhongShader(device=torch.device("cuda:0"))
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
device: Device = "cpu",
|
|
||||||
cameras: Optional[TensorProperties] = None,
|
|
||||||
lights: Optional[TensorProperties] = None,
|
|
||||||
materials: Optional[Materials] = None,
|
|
||||||
blend_params: Optional[BlendParams] = None,
|
|
||||||
) -> None:
|
|
||||||
super().__init__()
|
|
||||||
self.lights = lights if lights is not None else PointLights(device=device)
|
|
||||||
self.materials = (
|
|
||||||
materials if materials is not None else Materials(device=device)
|
|
||||||
)
|
|
||||||
self.cameras = cameras
|
|
||||||
self.blend_params = blend_params if blend_params is not None else BlendParams()
|
|
||||||
|
|
||||||
# pyre-fixme[14]: `to` overrides method defined in `Module` inconsistently.
|
|
||||||
def to(self, device: Device):
|
|
||||||
# Manually move to device modules which are not subclasses of nn.Module
|
|
||||||
cameras = self.cameras
|
|
||||||
if cameras is not None:
|
|
||||||
self.cameras = cameras.to(device)
|
|
||||||
self.materials = self.materials.to(device)
|
|
||||||
self.lights = self.lights.to(device)
|
|
||||||
return self
|
|
||||||
|
|
||||||
def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor:
|
def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor:
|
||||||
cameras = kwargs.get("cameras", self.cameras)
|
cameras = kwargs.get("cameras", self.cameras)
|
||||||
if cameras is None:
|
if cameras is None:
|
||||||
@ -164,7 +138,7 @@ class SoftPhongShader(nn.Module):
|
|||||||
return images
|
return images
|
||||||
|
|
||||||
|
|
||||||
class HardGouraudShader(nn.Module):
|
class HardGouraudShader(ShaderBase):
|
||||||
"""
|
"""
|
||||||
Per vertex lighting - the lighting model is applied to the vertex colors and
|
Per vertex lighting - the lighting model is applied to the vertex colors and
|
||||||
the colors are then interpolated using the barycentric coordinates to
|
the colors are then interpolated using the barycentric coordinates to
|
||||||
@ -179,32 +153,6 @@ class HardGouraudShader(nn.Module):
|
|||||||
shader = HardGouraudShader(device=torch.device("cuda:0"))
|
shader = HardGouraudShader(device=torch.device("cuda:0"))
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
device: Device = "cpu",
|
|
||||||
cameras: Optional[TensorProperties] = None,
|
|
||||||
lights: Optional[TensorProperties] = None,
|
|
||||||
materials: Optional[Materials] = None,
|
|
||||||
blend_params: Optional[BlendParams] = None,
|
|
||||||
) -> None:
|
|
||||||
super().__init__()
|
|
||||||
self.lights = lights if lights is not None else PointLights(device=device)
|
|
||||||
self.materials = (
|
|
||||||
materials if materials is not None else Materials(device=device)
|
|
||||||
)
|
|
||||||
self.cameras = cameras
|
|
||||||
self.blend_params = blend_params if blend_params is not None else BlendParams()
|
|
||||||
|
|
||||||
# pyre-fixme[14]: `to` overrides method defined in `Module` inconsistently.
|
|
||||||
def to(self, device: Device):
|
|
||||||
# Manually move to device modules which are not subclasses of nn.Module
|
|
||||||
cameras = self.cameras
|
|
||||||
if cameras is not None:
|
|
||||||
self.cameras = cameras.to(device)
|
|
||||||
self.materials = self.materials.to(device)
|
|
||||||
self.lights = self.lights.to(device)
|
|
||||||
return self
|
|
||||||
|
|
||||||
def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor:
|
def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor:
|
||||||
cameras = kwargs.get("cameras", self.cameras)
|
cameras = kwargs.get("cameras", self.cameras)
|
||||||
if cameras is None:
|
if cameras is None:
|
||||||
@ -231,7 +179,7 @@ class HardGouraudShader(nn.Module):
|
|||||||
return images
|
return images
|
||||||
|
|
||||||
|
|
||||||
class SoftGouraudShader(nn.Module):
|
class SoftGouraudShader(ShaderBase):
|
||||||
"""
|
"""
|
||||||
Per vertex lighting - the lighting model is applied to the vertex colors and
|
Per vertex lighting - the lighting model is applied to the vertex colors and
|
||||||
the colors are then interpolated using the barycentric coordinates to
|
the colors are then interpolated using the barycentric coordinates to
|
||||||
@ -246,32 +194,6 @@ class SoftGouraudShader(nn.Module):
|
|||||||
shader = SoftGouraudShader(device=torch.device("cuda:0"))
|
shader = SoftGouraudShader(device=torch.device("cuda:0"))
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
device: Device = "cpu",
|
|
||||||
cameras: Optional[TensorProperties] = None,
|
|
||||||
lights: Optional[TensorProperties] = None,
|
|
||||||
materials: Optional[Materials] = None,
|
|
||||||
blend_params: Optional[BlendParams] = None,
|
|
||||||
) -> None:
|
|
||||||
super().__init__()
|
|
||||||
self.lights = lights if lights is not None else PointLights(device=device)
|
|
||||||
self.materials = (
|
|
||||||
materials if materials is not None else Materials(device=device)
|
|
||||||
)
|
|
||||||
self.cameras = cameras
|
|
||||||
self.blend_params = blend_params if blend_params is not None else BlendParams()
|
|
||||||
|
|
||||||
# pyre-fixme[14]: `to` overrides method defined in `Module` inconsistently.
|
|
||||||
def to(self, device: Device):
|
|
||||||
# Manually move to device modules which are not subclasses of nn.Module
|
|
||||||
cameras = self.cameras
|
|
||||||
if cameras is not None:
|
|
||||||
self.cameras = cameras.to(device)
|
|
||||||
self.materials = self.materials.to(device)
|
|
||||||
self.lights = self.lights.to(device)
|
|
||||||
return self
|
|
||||||
|
|
||||||
def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor:
|
def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor:
|
||||||
cameras = kwargs.get("cameras", self.cameras)
|
cameras = kwargs.get("cameras", self.cameras)
|
||||||
if cameras is None:
|
if cameras is None:
|
||||||
@ -320,7 +242,7 @@ def TexturedSoftPhongShader(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class HardFlatShader(nn.Module):
|
class HardFlatShader(ShaderBase):
|
||||||
"""
|
"""
|
||||||
Per face lighting - the lighting model is applied using the average face
|
Per face lighting - the lighting model is applied using the average face
|
||||||
position and the face normal. The blending function hard assigns
|
position and the face normal. The blending function hard assigns
|
||||||
@ -334,32 +256,6 @@ class HardFlatShader(nn.Module):
|
|||||||
shader = HardFlatShader(device=torch.device("cuda:0"))
|
shader = HardFlatShader(device=torch.device("cuda:0"))
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
device: Device = "cpu",
|
|
||||||
cameras: Optional[TensorProperties] = None,
|
|
||||||
lights: Optional[TensorProperties] = None,
|
|
||||||
materials: Optional[Materials] = None,
|
|
||||||
blend_params: Optional[BlendParams] = None,
|
|
||||||
) -> None:
|
|
||||||
super().__init__()
|
|
||||||
self.lights = lights if lights is not None else PointLights(device=device)
|
|
||||||
self.materials = (
|
|
||||||
materials if materials is not None else Materials(device=device)
|
|
||||||
)
|
|
||||||
self.cameras = cameras
|
|
||||||
self.blend_params = blend_params if blend_params is not None else BlendParams()
|
|
||||||
|
|
||||||
# pyre-fixme[14]: `to` overrides method defined in `Module` inconsistently.
|
|
||||||
def to(self, device: Device):
|
|
||||||
# Manually move to device modules which are not subclasses of nn.Module
|
|
||||||
cameras = self.cameras
|
|
||||||
if cameras is not None:
|
|
||||||
self.cameras = cameras.to(device)
|
|
||||||
self.materials = self.materials.to(device)
|
|
||||||
self.lights = self.lights.to(device)
|
|
||||||
return self
|
|
||||||
|
|
||||||
def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor:
|
def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor:
|
||||||
cameras = kwargs.get("cameras", self.cameras)
|
cameras = kwargs.get("cameras", self.cameras)
|
||||||
if cameras is None:
|
if cameras is None:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user