diff --git a/pytorch3d/renderer/blending.py b/pytorch3d/renderer/blending.py index 69318ca1..1a7cbc9a 100644 --- a/pytorch3d/renderer/blending.py +++ b/pytorch3d/renderer/blending.py @@ -36,7 +36,6 @@ class BlendParams(NamedTuple): sigma: float = 1e-4 gamma: float = 1e-4 background_color: Union[torch.Tensor, Sequence[float]] = (1.0, 1.0, 1.0) - background_alpha: float = 0.0 def _get_background_color( diff --git a/pytorch3d/renderer/mesh/rasterizer.py b/pytorch3d/renderer/mesh/rasterizer.py index 54f3bb47..6f43e7b3 100644 --- a/pytorch3d/renderer/mesh/rasterizer.py +++ b/pytorch3d/renderer/mesh/rasterizer.py @@ -16,6 +16,10 @@ from .rasterize_meshes import rasterize_meshes @dataclass(frozen=True) class Fragments: """ + A dataclass representing the outputs of a rasterizer. Can be detached from the + computational graph in order to stop the gradients from flowing through the + rasterizer. + Members: pix_to_face: LongTensor of shape (N, image_size, image_size, faces_per_pixel) giving diff --git a/pytorch3d/renderer/mesh/shader.py b/pytorch3d/renderer/mesh/shader.py index 815dd85e..677812a7 100644 --- a/pytorch3d/renderer/mesh/shader.py +++ b/pytorch3d/renderer/mesh/shader.py @@ -55,6 +55,15 @@ class ShaderBase(nn.Module): self.cameras = cameras self.blend_params = blend_params if blend_params is not None else BlendParams() + def _get_cameras(self, **kwargs): + cameras = kwargs.get("cameras", self.cameras) + if cameras is None: + msg = "Cameras must be specified either at initialization \ + or in the forward pass of the shader." + raise ValueError(msg) + + return cameras + # pyre-fixme[14]: `to` overrides method defined in `Module` inconsistently. def to(self, device: Device): # Manually move to device modules which are not subclasses of nn.Module @@ -81,12 +90,7 @@ class HardPhongShader(ShaderBase): """ def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor: - cameras = kwargs.get("cameras", self.cameras) - if cameras is None: - msg = "Cameras must be specified either at initialization \ - or in the forward pass of HardPhongShader" - raise ValueError(msg) - + cameras = super()._get_cameras(**kwargs) texels = meshes.sample_textures(fragments) lights = kwargs.get("lights", self.lights) materials = kwargs.get("materials", self.materials) @@ -118,12 +122,7 @@ class SoftPhongShader(ShaderBase): """ def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor: - cameras = kwargs.get("cameras", self.cameras) - if cameras is None: - msg = "Cameras must be specified either at initialization \ - or in the forward pass of SoftPhongShader" - raise ValueError(msg) - + cameras = super()._get_cameras(**kwargs) texels = meshes.sample_textures(fragments) lights = kwargs.get("lights", self.lights) materials = kwargs.get("materials", self.materials) @@ -160,11 +159,7 @@ class HardGouraudShader(ShaderBase): """ def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor: - cameras = kwargs.get("cameras", self.cameras) - if cameras is None: - msg = "Cameras must be specified either at initialization \ - or in the forward pass of HardGouraudShader" - raise ValueError(msg) + cameras = super()._get_cameras(**kwargs) lights = kwargs.get("lights", self.lights) materials = kwargs.get("materials", self.materials) blend_params = kwargs.get("blend_params", self.blend_params) @@ -201,11 +196,7 @@ class SoftGouraudShader(ShaderBase): """ def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor: - cameras = kwargs.get("cameras", self.cameras) - if cameras is None: - msg = "Cameras must be specified either at initialization \ - or in the forward pass of SoftGouraudShader" - raise ValueError(msg) + cameras = super()._get_cameras(**kwargs) lights = kwargs.get("lights", self.lights) materials = kwargs.get("materials", self.materials) pixel_colors = gouraud_shading( @@ -263,11 +254,7 @@ class HardFlatShader(ShaderBase): """ def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor: - cameras = kwargs.get("cameras", self.cameras) - if cameras is None: - msg = "Cameras must be specified either at initialization \ - or in the forward pass of HardFlatShader" - raise ValueError(msg) + cameras = super()._get_cameras(**kwargs) texels = meshes.sample_textures(fragments) lights = kwargs.get("lights", self.lights) materials = kwargs.get("materials", self.materials) @@ -329,11 +316,6 @@ class SplatterPhongShader(ShaderBase): shader = SplatterPhongShader(device=torch.device("cuda:0")) - Args: - detach_rasterizer: If True, stop gradients from flowing through the rasterizer. - This simulates the pipeline in [0] which uses a non-differentiable OpenGL - rasterizer. - [0] Cole, F. et al., "Differentiable Surface Rendering via Non-differentiable Sampling". """ @@ -343,11 +325,7 @@ class SplatterPhongShader(ShaderBase): super().__init__(**kwargs) def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor: - cameras = kwargs.get("cameras", self.cameras) - if cameras is None: - msg = "Cameras must be specified either at initialization \ - or in the forward pass of SplatterPhongShader." - raise ValueError(msg) + cameras = super()._get_cameras(**kwargs) texels = meshes.sample_textures(fragments) lights = kwargs.get("lights", self.lights) materials = kwargs.get("materials", self.materials) diff --git a/pytorch3d/renderer/splatter_blend.py b/pytorch3d/renderer/splatter_blend.py index 4bdb0a73..92f5220a 100644 --- a/pytorch3d/renderer/splatter_blend.py +++ b/pytorch3d/renderer/splatter_blend.py @@ -4,6 +4,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# This file defines SplatterBlender, which is used for blending in SplatterPhongShader. + import itertools from typing import Tuple diff --git a/tests/test_shader.py b/tests/test_shader.py index 4b9654d3..3b751e8b 100644 --- a/tests/test_shader.py +++ b/tests/test_shader.py @@ -14,6 +14,7 @@ from pytorch3d.renderer.mesh.shader import ( HardGouraudShader, HardPhongShader, SoftPhongShader, + SplatterPhongShader, ) from pytorch3d.structures.meshes import Meshes @@ -21,20 +22,22 @@ from .common_testing import TestCaseMixin class TestShader(TestCaseMixin, unittest.TestCase): + def setUp(self): + self.shader_classes = [ + HardFlatShader, + HardGouraudShader, + HardPhongShader, + SoftPhongShader, + SplatterPhongShader, + ] + def test_to(self): cpu_device = torch.device("cpu") cuda_device = torch.device("cuda:0") R, T = look_at_view_transform() - shader_classes = [ - HardFlatShader, - HardGouraudShader, - HardPhongShader, - SoftPhongShader, - ] - - for shader_class in shader_classes: + for shader_class in self.shader_classes: for cameras_class in (None, PerspectiveCameras): if cameras_class is None: cameras = None @@ -53,8 +56,11 @@ class TestShader(TestCaseMixin, unittest.TestCase): self.assertIs(cpu_shader, cuda_shader) if cameras is None: self.assertIsNone(cuda_shader.cameras) + with self.assertRaisesRegexp(ValueError, "Cameras must be"): + cuda_shader._get_cameras() else: self.assertEqual(cuda_device, cuda_shader.cameras.device) + self.assertIsInstance(cuda_shader._get_cameras(), cameras_class) self.assertEqual(cuda_device, cuda_shader.materials.device) self.assertEqual(cuda_device, cuda_shader.lights.device) @@ -75,14 +81,8 @@ class TestShader(TestCaseMixin, unittest.TestCase): zbuf=torch.ones_like(pix_to_face), dists=torch.ones_like(pix_to_face), ) - shader_classes = [ - HardFlatShader, - HardGouraudShader, - HardPhongShader, - SoftPhongShader, - ] - for shader_class in shader_classes: + for shader_class in self.shader_classes: shader = shader_class() with self.assertRaises(ValueError):