diff --git a/pytorch3d/renderer/mesh/shader.py b/pytorch3d/renderer/mesh/shader.py index aecf6e8f..797efbcc 100644 --- a/pytorch3d/renderer/mesh/shader.py +++ b/pytorch3d/renderer/mesh/shader.py @@ -32,22 +32,7 @@ from .shading import flat_shading, gouraud_shading, phong_shading # - sample colors from a texture map # - apply per pixel lighting # - blend colors across top K faces per pixel. - - -class HardPhongShader(nn.Module): - """ - Per pixel lighting - the lighting model is applied using the interpolated - coordinates and normals for each pixel. The blending function hard assigns - the color of the closest face for each pixel. - - To use the default values, simply initialize the shader with the desired - device e.g. - - .. code-block:: - - shader = HardPhongShader(device=torch.device("cuda:0")) - """ - +class ShaderBase(nn.Module): def __init__( self, device: Device = "cpu", @@ -74,6 +59,21 @@ class HardPhongShader(nn.Module): self.lights = self.lights.to(device) return self + +class HardPhongShader(ShaderBase): + """ + Per pixel lighting - the lighting model is applied using the interpolated + coordinates and normals for each pixel. The blending function hard assigns + the color of the closest face for each pixel. + + To use the default values, simply initialize the shader with the desired + device e.g. + + .. code-block:: + + shader = HardPhongShader(device=torch.device("cuda:0")) + """ + def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor: cameras = kwargs.get("cameras", self.cameras) if cameras is None: @@ -97,7 +97,7 @@ class HardPhongShader(nn.Module): return images -class SoftPhongShader(nn.Module): +class SoftPhongShader(ShaderBase): """ Per pixel lighting - the lighting model is applied using the interpolated coordinates and normals for each pixel. The blending function returns the @@ -111,32 +111,6 @@ class SoftPhongShader(nn.Module): shader = SoftPhongShader(device=torch.device("cuda:0")) """ - def __init__( - self, - device: Device = "cpu", - cameras: Optional[TensorProperties] = None, - lights: Optional[TensorProperties] = None, - materials: Optional[Materials] = None, - blend_params: Optional[BlendParams] = None, - ) -> None: - super().__init__() - self.lights = lights if lights is not None else PointLights(device=device) - self.materials = ( - materials if materials is not None else Materials(device=device) - ) - self.cameras = cameras - self.blend_params = blend_params if blend_params is not None else BlendParams() - - # pyre-fixme[14]: `to` overrides method defined in `Module` inconsistently. - def to(self, device: Device): - # Manually move to device modules which are not subclasses of nn.Module - cameras = self.cameras - if cameras is not None: - self.cameras = cameras.to(device) - self.materials = self.materials.to(device) - self.lights = self.lights.to(device) - return self - def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor: cameras = kwargs.get("cameras", self.cameras) if cameras is None: @@ -164,7 +138,7 @@ class SoftPhongShader(nn.Module): return images -class HardGouraudShader(nn.Module): +class HardGouraudShader(ShaderBase): """ Per vertex lighting - the lighting model is applied to the vertex colors and the colors are then interpolated using the barycentric coordinates to @@ -179,32 +153,6 @@ class HardGouraudShader(nn.Module): shader = HardGouraudShader(device=torch.device("cuda:0")) """ - def __init__( - self, - device: Device = "cpu", - cameras: Optional[TensorProperties] = None, - lights: Optional[TensorProperties] = None, - materials: Optional[Materials] = None, - blend_params: Optional[BlendParams] = None, - ) -> None: - super().__init__() - self.lights = lights if lights is not None else PointLights(device=device) - self.materials = ( - materials if materials is not None else Materials(device=device) - ) - self.cameras = cameras - self.blend_params = blend_params if blend_params is not None else BlendParams() - - # pyre-fixme[14]: `to` overrides method defined in `Module` inconsistently. - def to(self, device: Device): - # Manually move to device modules which are not subclasses of nn.Module - cameras = self.cameras - if cameras is not None: - self.cameras = cameras.to(device) - self.materials = self.materials.to(device) - self.lights = self.lights.to(device) - return self - def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor: cameras = kwargs.get("cameras", self.cameras) if cameras is None: @@ -231,7 +179,7 @@ class HardGouraudShader(nn.Module): return images -class SoftGouraudShader(nn.Module): +class SoftGouraudShader(ShaderBase): """ Per vertex lighting - the lighting model is applied to the vertex colors and the colors are then interpolated using the barycentric coordinates to @@ -246,32 +194,6 @@ class SoftGouraudShader(nn.Module): shader = SoftGouraudShader(device=torch.device("cuda:0")) """ - def __init__( - self, - device: Device = "cpu", - cameras: Optional[TensorProperties] = None, - lights: Optional[TensorProperties] = None, - materials: Optional[Materials] = None, - blend_params: Optional[BlendParams] = None, - ) -> None: - super().__init__() - self.lights = lights if lights is not None else PointLights(device=device) - self.materials = ( - materials if materials is not None else Materials(device=device) - ) - self.cameras = cameras - self.blend_params = blend_params if blend_params is not None else BlendParams() - - # pyre-fixme[14]: `to` overrides method defined in `Module` inconsistently. - def to(self, device: Device): - # Manually move to device modules which are not subclasses of nn.Module - cameras = self.cameras - if cameras is not None: - self.cameras = cameras.to(device) - self.materials = self.materials.to(device) - self.lights = self.lights.to(device) - return self - def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor: cameras = kwargs.get("cameras", self.cameras) if cameras is None: @@ -320,7 +242,7 @@ def TexturedSoftPhongShader( ) -class HardFlatShader(nn.Module): +class HardFlatShader(ShaderBase): """ Per face lighting - the lighting model is applied using the average face position and the face normal. The blending function hard assigns @@ -334,32 +256,6 @@ class HardFlatShader(nn.Module): shader = HardFlatShader(device=torch.device("cuda:0")) """ - def __init__( - self, - device: Device = "cpu", - cameras: Optional[TensorProperties] = None, - lights: Optional[TensorProperties] = None, - materials: Optional[Materials] = None, - blend_params: Optional[BlendParams] = None, - ) -> None: - super().__init__() - self.lights = lights if lights is not None else PointLights(device=device) - self.materials = ( - materials if materials is not None else Materials(device=device) - ) - self.cameras = cameras - self.blend_params = blend_params if blend_params is not None else BlendParams() - - # pyre-fixme[14]: `to` overrides method defined in `Module` inconsistently. - def to(self, device: Device): - # Manually move to device modules which are not subclasses of nn.Module - cameras = self.cameras - if cameras is not None: - self.cameras = cameras.to(device) - self.materials = self.materials.to(device) - self.lights = self.lights.to(device) - return self - def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor: cameras = kwargs.get("cameras", self.cameras) if cameras is None: