mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
SplatterPhongShader 1: Pull out common Shader functionality into ShaderBase
Summary: Most of the shaders copypaste exactly the same code into `__init__` and `to`. I will be adding a new shader in the next diff, so let's make it a bit easier. Reviewed By: bottler Differential Revision: D35767884 fbshipit-source-id: 0057e3e2ae3be4eaa49ae7e2bf3e4176953dde9d
This commit is contained in:
parent
9f443ed26b
commit
96889deab9
@ -32,22 +32,7 @@ from .shading import flat_shading, gouraud_shading, phong_shading
|
||||
# - sample colors from a texture map
|
||||
# - apply per pixel lighting
|
||||
# - blend colors across top K faces per pixel.
|
||||
|
||||
|
||||
class HardPhongShader(nn.Module):
|
||||
"""
|
||||
Per pixel lighting - the lighting model is applied using the interpolated
|
||||
coordinates and normals for each pixel. The blending function hard assigns
|
||||
the color of the closest face for each pixel.
|
||||
|
||||
To use the default values, simply initialize the shader with the desired
|
||||
device e.g.
|
||||
|
||||
.. code-block::
|
||||
|
||||
shader = HardPhongShader(device=torch.device("cuda:0"))
|
||||
"""
|
||||
|
||||
class ShaderBase(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
device: Device = "cpu",
|
||||
@ -74,6 +59,21 @@ class HardPhongShader(nn.Module):
|
||||
self.lights = self.lights.to(device)
|
||||
return self
|
||||
|
||||
|
||||
class HardPhongShader(ShaderBase):
|
||||
"""
|
||||
Per pixel lighting - the lighting model is applied using the interpolated
|
||||
coordinates and normals for each pixel. The blending function hard assigns
|
||||
the color of the closest face for each pixel.
|
||||
|
||||
To use the default values, simply initialize the shader with the desired
|
||||
device e.g.
|
||||
|
||||
.. code-block::
|
||||
|
||||
shader = HardPhongShader(device=torch.device("cuda:0"))
|
||||
"""
|
||||
|
||||
def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor:
|
||||
cameras = kwargs.get("cameras", self.cameras)
|
||||
if cameras is None:
|
||||
@ -97,7 +97,7 @@ class HardPhongShader(nn.Module):
|
||||
return images
|
||||
|
||||
|
||||
class SoftPhongShader(nn.Module):
|
||||
class SoftPhongShader(ShaderBase):
|
||||
"""
|
||||
Per pixel lighting - the lighting model is applied using the interpolated
|
||||
coordinates and normals for each pixel. The blending function returns the
|
||||
@ -111,32 +111,6 @@ class SoftPhongShader(nn.Module):
|
||||
shader = SoftPhongShader(device=torch.device("cuda:0"))
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
device: Device = "cpu",
|
||||
cameras: Optional[TensorProperties] = None,
|
||||
lights: Optional[TensorProperties] = None,
|
||||
materials: Optional[Materials] = None,
|
||||
blend_params: Optional[BlendParams] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.lights = lights if lights is not None else PointLights(device=device)
|
||||
self.materials = (
|
||||
materials if materials is not None else Materials(device=device)
|
||||
)
|
||||
self.cameras = cameras
|
||||
self.blend_params = blend_params if blend_params is not None else BlendParams()
|
||||
|
||||
# pyre-fixme[14]: `to` overrides method defined in `Module` inconsistently.
|
||||
def to(self, device: Device):
|
||||
# Manually move to device modules which are not subclasses of nn.Module
|
||||
cameras = self.cameras
|
||||
if cameras is not None:
|
||||
self.cameras = cameras.to(device)
|
||||
self.materials = self.materials.to(device)
|
||||
self.lights = self.lights.to(device)
|
||||
return self
|
||||
|
||||
def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor:
|
||||
cameras = kwargs.get("cameras", self.cameras)
|
||||
if cameras is None:
|
||||
@ -164,7 +138,7 @@ class SoftPhongShader(nn.Module):
|
||||
return images
|
||||
|
||||
|
||||
class HardGouraudShader(nn.Module):
|
||||
class HardGouraudShader(ShaderBase):
|
||||
"""
|
||||
Per vertex lighting - the lighting model is applied to the vertex colors and
|
||||
the colors are then interpolated using the barycentric coordinates to
|
||||
@ -179,32 +153,6 @@ class HardGouraudShader(nn.Module):
|
||||
shader = HardGouraudShader(device=torch.device("cuda:0"))
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
device: Device = "cpu",
|
||||
cameras: Optional[TensorProperties] = None,
|
||||
lights: Optional[TensorProperties] = None,
|
||||
materials: Optional[Materials] = None,
|
||||
blend_params: Optional[BlendParams] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.lights = lights if lights is not None else PointLights(device=device)
|
||||
self.materials = (
|
||||
materials if materials is not None else Materials(device=device)
|
||||
)
|
||||
self.cameras = cameras
|
||||
self.blend_params = blend_params if blend_params is not None else BlendParams()
|
||||
|
||||
# pyre-fixme[14]: `to` overrides method defined in `Module` inconsistently.
|
||||
def to(self, device: Device):
|
||||
# Manually move to device modules which are not subclasses of nn.Module
|
||||
cameras = self.cameras
|
||||
if cameras is not None:
|
||||
self.cameras = cameras.to(device)
|
||||
self.materials = self.materials.to(device)
|
||||
self.lights = self.lights.to(device)
|
||||
return self
|
||||
|
||||
def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor:
|
||||
cameras = kwargs.get("cameras", self.cameras)
|
||||
if cameras is None:
|
||||
@ -231,7 +179,7 @@ class HardGouraudShader(nn.Module):
|
||||
return images
|
||||
|
||||
|
||||
class SoftGouraudShader(nn.Module):
|
||||
class SoftGouraudShader(ShaderBase):
|
||||
"""
|
||||
Per vertex lighting - the lighting model is applied to the vertex colors and
|
||||
the colors are then interpolated using the barycentric coordinates to
|
||||
@ -246,32 +194,6 @@ class SoftGouraudShader(nn.Module):
|
||||
shader = SoftGouraudShader(device=torch.device("cuda:0"))
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
device: Device = "cpu",
|
||||
cameras: Optional[TensorProperties] = None,
|
||||
lights: Optional[TensorProperties] = None,
|
||||
materials: Optional[Materials] = None,
|
||||
blend_params: Optional[BlendParams] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.lights = lights if lights is not None else PointLights(device=device)
|
||||
self.materials = (
|
||||
materials if materials is not None else Materials(device=device)
|
||||
)
|
||||
self.cameras = cameras
|
||||
self.blend_params = blend_params if blend_params is not None else BlendParams()
|
||||
|
||||
# pyre-fixme[14]: `to` overrides method defined in `Module` inconsistently.
|
||||
def to(self, device: Device):
|
||||
# Manually move to device modules which are not subclasses of nn.Module
|
||||
cameras = self.cameras
|
||||
if cameras is not None:
|
||||
self.cameras = cameras.to(device)
|
||||
self.materials = self.materials.to(device)
|
||||
self.lights = self.lights.to(device)
|
||||
return self
|
||||
|
||||
def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor:
|
||||
cameras = kwargs.get("cameras", self.cameras)
|
||||
if cameras is None:
|
||||
@ -320,7 +242,7 @@ def TexturedSoftPhongShader(
|
||||
)
|
||||
|
||||
|
||||
class HardFlatShader(nn.Module):
|
||||
class HardFlatShader(ShaderBase):
|
||||
"""
|
||||
Per face lighting - the lighting model is applied using the average face
|
||||
position and the face normal. The blending function hard assigns
|
||||
@ -334,32 +256,6 @@ class HardFlatShader(nn.Module):
|
||||
shader = HardFlatShader(device=torch.device("cuda:0"))
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
device: Device = "cpu",
|
||||
cameras: Optional[TensorProperties] = None,
|
||||
lights: Optional[TensorProperties] = None,
|
||||
materials: Optional[Materials] = None,
|
||||
blend_params: Optional[BlendParams] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.lights = lights if lights is not None else PointLights(device=device)
|
||||
self.materials = (
|
||||
materials if materials is not None else Materials(device=device)
|
||||
)
|
||||
self.cameras = cameras
|
||||
self.blend_params = blend_params if blend_params is not None else BlendParams()
|
||||
|
||||
# pyre-fixme[14]: `to` overrides method defined in `Module` inconsistently.
|
||||
def to(self, device: Device):
|
||||
# Manually move to device modules which are not subclasses of nn.Module
|
||||
cameras = self.cameras
|
||||
if cameras is not None:
|
||||
self.cameras = cameras.to(device)
|
||||
self.materials = self.materials.to(device)
|
||||
self.lights = self.lights.to(device)
|
||||
return self
|
||||
|
||||
def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor:
|
||||
cameras = kwargs.get("cameras", self.cameras)
|
||||
if cameras is None:
|
||||
|
Loading…
x
Reference in New Issue
Block a user