SplatterPhongShader 1: Pull out common Shader functionality into ShaderBase

Summary: Most of the shaders copypaste exactly the same code into `__init__` and `to`. I will be adding a new shader in the next diff, so let's make it a bit easier.

Reviewed By: bottler

Differential Revision: D35767884

fbshipit-source-id: 0057e3e2ae3be4eaa49ae7e2bf3e4176953dde9d
This commit is contained in:
Krzysztof Chalupka 2022-04-27 12:07:51 -07:00 committed by Facebook GitHub Bot
parent 9f443ed26b
commit 96889deab9

View File

@ -32,22 +32,7 @@ from .shading import flat_shading, gouraud_shading, phong_shading
# - sample colors from a texture map
# - apply per pixel lighting
# - blend colors across top K faces per pixel.
class HardPhongShader(nn.Module):
"""
Per pixel lighting - the lighting model is applied using the interpolated
coordinates and normals for each pixel. The blending function hard assigns
the color of the closest face for each pixel.
To use the default values, simply initialize the shader with the desired
device e.g.
.. code-block::
shader = HardPhongShader(device=torch.device("cuda:0"))
"""
class ShaderBase(nn.Module):
def __init__(
self,
device: Device = "cpu",
@ -74,6 +59,21 @@ class HardPhongShader(nn.Module):
self.lights = self.lights.to(device)
return self
class HardPhongShader(ShaderBase):
"""
Per pixel lighting - the lighting model is applied using the interpolated
coordinates and normals for each pixel. The blending function hard assigns
the color of the closest face for each pixel.
To use the default values, simply initialize the shader with the desired
device e.g.
.. code-block::
shader = HardPhongShader(device=torch.device("cuda:0"))
"""
def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor:
cameras = kwargs.get("cameras", self.cameras)
if cameras is None:
@ -97,7 +97,7 @@ class HardPhongShader(nn.Module):
return images
class SoftPhongShader(nn.Module):
class SoftPhongShader(ShaderBase):
"""
Per pixel lighting - the lighting model is applied using the interpolated
coordinates and normals for each pixel. The blending function returns the
@ -111,32 +111,6 @@ class SoftPhongShader(nn.Module):
shader = SoftPhongShader(device=torch.device("cuda:0"))
"""
def __init__(
self,
device: Device = "cpu",
cameras: Optional[TensorProperties] = None,
lights: Optional[TensorProperties] = None,
materials: Optional[Materials] = None,
blend_params: Optional[BlendParams] = None,
) -> None:
super().__init__()
self.lights = lights if lights is not None else PointLights(device=device)
self.materials = (
materials if materials is not None else Materials(device=device)
)
self.cameras = cameras
self.blend_params = blend_params if blend_params is not None else BlendParams()
# pyre-fixme[14]: `to` overrides method defined in `Module` inconsistently.
def to(self, device: Device):
# Manually move to device modules which are not subclasses of nn.Module
cameras = self.cameras
if cameras is not None:
self.cameras = cameras.to(device)
self.materials = self.materials.to(device)
self.lights = self.lights.to(device)
return self
def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor:
cameras = kwargs.get("cameras", self.cameras)
if cameras is None:
@ -164,7 +138,7 @@ class SoftPhongShader(nn.Module):
return images
class HardGouraudShader(nn.Module):
class HardGouraudShader(ShaderBase):
"""
Per vertex lighting - the lighting model is applied to the vertex colors and
the colors are then interpolated using the barycentric coordinates to
@ -179,32 +153,6 @@ class HardGouraudShader(nn.Module):
shader = HardGouraudShader(device=torch.device("cuda:0"))
"""
def __init__(
self,
device: Device = "cpu",
cameras: Optional[TensorProperties] = None,
lights: Optional[TensorProperties] = None,
materials: Optional[Materials] = None,
blend_params: Optional[BlendParams] = None,
) -> None:
super().__init__()
self.lights = lights if lights is not None else PointLights(device=device)
self.materials = (
materials if materials is not None else Materials(device=device)
)
self.cameras = cameras
self.blend_params = blend_params if blend_params is not None else BlendParams()
# pyre-fixme[14]: `to` overrides method defined in `Module` inconsistently.
def to(self, device: Device):
# Manually move to device modules which are not subclasses of nn.Module
cameras = self.cameras
if cameras is not None:
self.cameras = cameras.to(device)
self.materials = self.materials.to(device)
self.lights = self.lights.to(device)
return self
def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor:
cameras = kwargs.get("cameras", self.cameras)
if cameras is None:
@ -231,7 +179,7 @@ class HardGouraudShader(nn.Module):
return images
class SoftGouraudShader(nn.Module):
class SoftGouraudShader(ShaderBase):
"""
Per vertex lighting - the lighting model is applied to the vertex colors and
the colors are then interpolated using the barycentric coordinates to
@ -246,32 +194,6 @@ class SoftGouraudShader(nn.Module):
shader = SoftGouraudShader(device=torch.device("cuda:0"))
"""
def __init__(
self,
device: Device = "cpu",
cameras: Optional[TensorProperties] = None,
lights: Optional[TensorProperties] = None,
materials: Optional[Materials] = None,
blend_params: Optional[BlendParams] = None,
) -> None:
super().__init__()
self.lights = lights if lights is not None else PointLights(device=device)
self.materials = (
materials if materials is not None else Materials(device=device)
)
self.cameras = cameras
self.blend_params = blend_params if blend_params is not None else BlendParams()
# pyre-fixme[14]: `to` overrides method defined in `Module` inconsistently.
def to(self, device: Device):
# Manually move to device modules which are not subclasses of nn.Module
cameras = self.cameras
if cameras is not None:
self.cameras = cameras.to(device)
self.materials = self.materials.to(device)
self.lights = self.lights.to(device)
return self
def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor:
cameras = kwargs.get("cameras", self.cameras)
if cameras is None:
@ -320,7 +242,7 @@ def TexturedSoftPhongShader(
)
class HardFlatShader(nn.Module):
class HardFlatShader(ShaderBase):
"""
Per face lighting - the lighting model is applied using the average face
position and the face normal. The blending function hard assigns
@ -334,32 +256,6 @@ class HardFlatShader(nn.Module):
shader = HardFlatShader(device=torch.device("cuda:0"))
"""
def __init__(
self,
device: Device = "cpu",
cameras: Optional[TensorProperties] = None,
lights: Optional[TensorProperties] = None,
materials: Optional[Materials] = None,
blend_params: Optional[BlendParams] = None,
) -> None:
super().__init__()
self.lights = lights if lights is not None else PointLights(device=device)
self.materials = (
materials if materials is not None else Materials(device=device)
)
self.cameras = cameras
self.blend_params = blend_params if blend_params is not None else BlendParams()
# pyre-fixme[14]: `to` overrides method defined in `Module` inconsistently.
def to(self, device: Device):
# Manually move to device modules which are not subclasses of nn.Module
cameras = self.cameras
if cameras is not None:
self.cameras = cameras.to(device)
self.materials = self.materials.to(device)
self.lights = self.lights.to(device)
return self
def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor:
cameras = kwargs.get("cameras", self.cameras)
if cameras is None: