mirror of
				https://github.com/facebookresearch/pytorch3d.git
				synced 2025-11-04 18:02:14 +08:00 
			
		
		
		
	SplatterPhongShader 1: Pull out common Shader functionality into ShaderBase
Summary: Most of the shaders copypaste exactly the same code into `__init__` and `to`. I will be adding a new shader in the next diff, so let's make it a bit easier. Reviewed By: bottler Differential Revision: D35767884 fbshipit-source-id: 0057e3e2ae3be4eaa49ae7e2bf3e4176953dde9d
This commit is contained in:
		
							parent
							
								
									9f443ed26b
								
							
						
					
					
						commit
						96889deab9
					
				@ -32,22 +32,7 @@ from .shading import flat_shading, gouraud_shading, phong_shading
 | 
			
		||||
#     - sample colors from a texture map
 | 
			
		||||
#     - apply per pixel lighting
 | 
			
		||||
#     - blend colors across top K faces per pixel.
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class HardPhongShader(nn.Module):
 | 
			
		||||
    """
 | 
			
		||||
    Per pixel lighting - the lighting model is applied using the interpolated
 | 
			
		||||
    coordinates and normals for each pixel. The blending function hard assigns
 | 
			
		||||
    the color of the closest face for each pixel.
 | 
			
		||||
 | 
			
		||||
    To use the default values, simply initialize the shader with the desired
 | 
			
		||||
    device e.g.
 | 
			
		||||
 | 
			
		||||
    .. code-block::
 | 
			
		||||
 | 
			
		||||
        shader = HardPhongShader(device=torch.device("cuda:0"))
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
class ShaderBase(nn.Module):
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self,
 | 
			
		||||
        device: Device = "cpu",
 | 
			
		||||
@ -74,6 +59,21 @@ class HardPhongShader(nn.Module):
 | 
			
		||||
        self.lights = self.lights.to(device)
 | 
			
		||||
        return self
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class HardPhongShader(ShaderBase):
 | 
			
		||||
    """
 | 
			
		||||
    Per pixel lighting - the lighting model is applied using the interpolated
 | 
			
		||||
    coordinates and normals for each pixel. The blending function hard assigns
 | 
			
		||||
    the color of the closest face for each pixel.
 | 
			
		||||
 | 
			
		||||
    To use the default values, simply initialize the shader with the desired
 | 
			
		||||
    device e.g.
 | 
			
		||||
 | 
			
		||||
    .. code-block::
 | 
			
		||||
 | 
			
		||||
        shader = HardPhongShader(device=torch.device("cuda:0"))
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor:
 | 
			
		||||
        cameras = kwargs.get("cameras", self.cameras)
 | 
			
		||||
        if cameras is None:
 | 
			
		||||
@ -97,7 +97,7 @@ class HardPhongShader(nn.Module):
 | 
			
		||||
        return images
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class SoftPhongShader(nn.Module):
 | 
			
		||||
class SoftPhongShader(ShaderBase):
 | 
			
		||||
    """
 | 
			
		||||
    Per pixel lighting - the lighting model is applied using the interpolated
 | 
			
		||||
    coordinates and normals for each pixel. The blending function returns the
 | 
			
		||||
@ -111,32 +111,6 @@ class SoftPhongShader(nn.Module):
 | 
			
		||||
        shader = SoftPhongShader(device=torch.device("cuda:0"))
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self,
 | 
			
		||||
        device: Device = "cpu",
 | 
			
		||||
        cameras: Optional[TensorProperties] = None,
 | 
			
		||||
        lights: Optional[TensorProperties] = None,
 | 
			
		||||
        materials: Optional[Materials] = None,
 | 
			
		||||
        blend_params: Optional[BlendParams] = None,
 | 
			
		||||
    ) -> None:
 | 
			
		||||
        super().__init__()
 | 
			
		||||
        self.lights = lights if lights is not None else PointLights(device=device)
 | 
			
		||||
        self.materials = (
 | 
			
		||||
            materials if materials is not None else Materials(device=device)
 | 
			
		||||
        )
 | 
			
		||||
        self.cameras = cameras
 | 
			
		||||
        self.blend_params = blend_params if blend_params is not None else BlendParams()
 | 
			
		||||
 | 
			
		||||
    # pyre-fixme[14]: `to` overrides method defined in `Module` inconsistently.
 | 
			
		||||
    def to(self, device: Device):
 | 
			
		||||
        # Manually move to device modules which are not subclasses of nn.Module
 | 
			
		||||
        cameras = self.cameras
 | 
			
		||||
        if cameras is not None:
 | 
			
		||||
            self.cameras = cameras.to(device)
 | 
			
		||||
        self.materials = self.materials.to(device)
 | 
			
		||||
        self.lights = self.lights.to(device)
 | 
			
		||||
        return self
 | 
			
		||||
 | 
			
		||||
    def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor:
 | 
			
		||||
        cameras = kwargs.get("cameras", self.cameras)
 | 
			
		||||
        if cameras is None:
 | 
			
		||||
@ -164,7 +138,7 @@ class SoftPhongShader(nn.Module):
 | 
			
		||||
        return images
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class HardGouraudShader(nn.Module):
 | 
			
		||||
class HardGouraudShader(ShaderBase):
 | 
			
		||||
    """
 | 
			
		||||
    Per vertex lighting - the lighting model is applied to the vertex colors and
 | 
			
		||||
    the colors are then interpolated using the barycentric coordinates to
 | 
			
		||||
@ -179,32 +153,6 @@ class HardGouraudShader(nn.Module):
 | 
			
		||||
        shader = HardGouraudShader(device=torch.device("cuda:0"))
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self,
 | 
			
		||||
        device: Device = "cpu",
 | 
			
		||||
        cameras: Optional[TensorProperties] = None,
 | 
			
		||||
        lights: Optional[TensorProperties] = None,
 | 
			
		||||
        materials: Optional[Materials] = None,
 | 
			
		||||
        blend_params: Optional[BlendParams] = None,
 | 
			
		||||
    ) -> None:
 | 
			
		||||
        super().__init__()
 | 
			
		||||
        self.lights = lights if lights is not None else PointLights(device=device)
 | 
			
		||||
        self.materials = (
 | 
			
		||||
            materials if materials is not None else Materials(device=device)
 | 
			
		||||
        )
 | 
			
		||||
        self.cameras = cameras
 | 
			
		||||
        self.blend_params = blend_params if blend_params is not None else BlendParams()
 | 
			
		||||
 | 
			
		||||
    # pyre-fixme[14]: `to` overrides method defined in `Module` inconsistently.
 | 
			
		||||
    def to(self, device: Device):
 | 
			
		||||
        # Manually move to device modules which are not subclasses of nn.Module
 | 
			
		||||
        cameras = self.cameras
 | 
			
		||||
        if cameras is not None:
 | 
			
		||||
            self.cameras = cameras.to(device)
 | 
			
		||||
        self.materials = self.materials.to(device)
 | 
			
		||||
        self.lights = self.lights.to(device)
 | 
			
		||||
        return self
 | 
			
		||||
 | 
			
		||||
    def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor:
 | 
			
		||||
        cameras = kwargs.get("cameras", self.cameras)
 | 
			
		||||
        if cameras is None:
 | 
			
		||||
@ -231,7 +179,7 @@ class HardGouraudShader(nn.Module):
 | 
			
		||||
        return images
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class SoftGouraudShader(nn.Module):
 | 
			
		||||
class SoftGouraudShader(ShaderBase):
 | 
			
		||||
    """
 | 
			
		||||
    Per vertex lighting - the lighting model is applied to the vertex colors and
 | 
			
		||||
    the colors are then interpolated using the barycentric coordinates to
 | 
			
		||||
@ -246,32 +194,6 @@ class SoftGouraudShader(nn.Module):
 | 
			
		||||
        shader = SoftGouraudShader(device=torch.device("cuda:0"))
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self,
 | 
			
		||||
        device: Device = "cpu",
 | 
			
		||||
        cameras: Optional[TensorProperties] = None,
 | 
			
		||||
        lights: Optional[TensorProperties] = None,
 | 
			
		||||
        materials: Optional[Materials] = None,
 | 
			
		||||
        blend_params: Optional[BlendParams] = None,
 | 
			
		||||
    ) -> None:
 | 
			
		||||
        super().__init__()
 | 
			
		||||
        self.lights = lights if lights is not None else PointLights(device=device)
 | 
			
		||||
        self.materials = (
 | 
			
		||||
            materials if materials is not None else Materials(device=device)
 | 
			
		||||
        )
 | 
			
		||||
        self.cameras = cameras
 | 
			
		||||
        self.blend_params = blend_params if blend_params is not None else BlendParams()
 | 
			
		||||
 | 
			
		||||
    # pyre-fixme[14]: `to` overrides method defined in `Module` inconsistently.
 | 
			
		||||
    def to(self, device: Device):
 | 
			
		||||
        # Manually move to device modules which are not subclasses of nn.Module
 | 
			
		||||
        cameras = self.cameras
 | 
			
		||||
        if cameras is not None:
 | 
			
		||||
            self.cameras = cameras.to(device)
 | 
			
		||||
        self.materials = self.materials.to(device)
 | 
			
		||||
        self.lights = self.lights.to(device)
 | 
			
		||||
        return self
 | 
			
		||||
 | 
			
		||||
    def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor:
 | 
			
		||||
        cameras = kwargs.get("cameras", self.cameras)
 | 
			
		||||
        if cameras is None:
 | 
			
		||||
@ -320,7 +242,7 @@ def TexturedSoftPhongShader(
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class HardFlatShader(nn.Module):
 | 
			
		||||
class HardFlatShader(ShaderBase):
 | 
			
		||||
    """
 | 
			
		||||
    Per face lighting - the lighting model is applied using the average face
 | 
			
		||||
    position and the face normal. The blending function hard assigns
 | 
			
		||||
@ -334,32 +256,6 @@ class HardFlatShader(nn.Module):
 | 
			
		||||
        shader = HardFlatShader(device=torch.device("cuda:0"))
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self,
 | 
			
		||||
        device: Device = "cpu",
 | 
			
		||||
        cameras: Optional[TensorProperties] = None,
 | 
			
		||||
        lights: Optional[TensorProperties] = None,
 | 
			
		||||
        materials: Optional[Materials] = None,
 | 
			
		||||
        blend_params: Optional[BlendParams] = None,
 | 
			
		||||
    ) -> None:
 | 
			
		||||
        super().__init__()
 | 
			
		||||
        self.lights = lights if lights is not None else PointLights(device=device)
 | 
			
		||||
        self.materials = (
 | 
			
		||||
            materials if materials is not None else Materials(device=device)
 | 
			
		||||
        )
 | 
			
		||||
        self.cameras = cameras
 | 
			
		||||
        self.blend_params = blend_params if blend_params is not None else BlendParams()
 | 
			
		||||
 | 
			
		||||
    # pyre-fixme[14]: `to` overrides method defined in `Module` inconsistently.
 | 
			
		||||
    def to(self, device: Device):
 | 
			
		||||
        # Manually move to device modules which are not subclasses of nn.Module
 | 
			
		||||
        cameras = self.cameras
 | 
			
		||||
        if cameras is not None:
 | 
			
		||||
            self.cameras = cameras.to(device)
 | 
			
		||||
        self.materials = self.materials.to(device)
 | 
			
		||||
        self.lights = self.lights.to(device)
 | 
			
		||||
        return self
 | 
			
		||||
 | 
			
		||||
    def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor:
 | 
			
		||||
        cameras = kwargs.get("cameras", self.cameras)
 | 
			
		||||
        if cameras is None:
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user