mirror of
				https://github.com/facebookresearch/pytorch3d.git
				synced 2025-11-04 18:02:14 +08:00 
			
		
		
		
	SplatterBlender follow-ups
Summary: A few minor additions I didn't fit into the SplatterBlender diffs, as requested by reviewers. Reviewed By: jcjohnson Differential Revision: D36682437 fbshipit-source-id: 57af995e766dfd2674b3984a3ba00aef7ca7db80
This commit is contained in:
		
							parent
							
								
									c31bf85a23
								
							
						
					
					
						commit
						a42a89a5ba
					
				@ -36,7 +36,6 @@ class BlendParams(NamedTuple):
 | 
			
		||||
    sigma: float = 1e-4
 | 
			
		||||
    gamma: float = 1e-4
 | 
			
		||||
    background_color: Union[torch.Tensor, Sequence[float]] = (1.0, 1.0, 1.0)
 | 
			
		||||
    background_alpha: float = 0.0
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _get_background_color(
 | 
			
		||||
 | 
			
		||||
@ -16,6 +16,10 @@ from .rasterize_meshes import rasterize_meshes
 | 
			
		||||
@dataclass(frozen=True)
 | 
			
		||||
class Fragments:
 | 
			
		||||
    """
 | 
			
		||||
    A dataclass representing the outputs of a rasterizer. Can be detached from the
 | 
			
		||||
    computational graph in order to stop the gradients from flowing through the
 | 
			
		||||
    rasterizer.
 | 
			
		||||
 | 
			
		||||
    Members:
 | 
			
		||||
        pix_to_face:
 | 
			
		||||
            LongTensor of shape (N, image_size, image_size, faces_per_pixel) giving
 | 
			
		||||
 | 
			
		||||
@ -55,6 +55,15 @@ class ShaderBase(nn.Module):
 | 
			
		||||
        self.cameras = cameras
 | 
			
		||||
        self.blend_params = blend_params if blend_params is not None else BlendParams()
 | 
			
		||||
 | 
			
		||||
    def _get_cameras(self, **kwargs):
 | 
			
		||||
        cameras = kwargs.get("cameras", self.cameras)
 | 
			
		||||
        if cameras is None:
 | 
			
		||||
            msg = "Cameras must be specified either at initialization \
 | 
			
		||||
                or in the forward pass of the shader."
 | 
			
		||||
            raise ValueError(msg)
 | 
			
		||||
 | 
			
		||||
        return cameras
 | 
			
		||||
 | 
			
		||||
    # pyre-fixme[14]: `to` overrides method defined in `Module` inconsistently.
 | 
			
		||||
    def to(self, device: Device):
 | 
			
		||||
        # Manually move to device modules which are not subclasses of nn.Module
 | 
			
		||||
@ -81,12 +90,7 @@ class HardPhongShader(ShaderBase):
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor:
 | 
			
		||||
        cameras = kwargs.get("cameras", self.cameras)
 | 
			
		||||
        if cameras is None:
 | 
			
		||||
            msg = "Cameras must be specified either at initialization \
 | 
			
		||||
                or in the forward pass of HardPhongShader"
 | 
			
		||||
            raise ValueError(msg)
 | 
			
		||||
 | 
			
		||||
        cameras = super()._get_cameras(**kwargs)
 | 
			
		||||
        texels = meshes.sample_textures(fragments)
 | 
			
		||||
        lights = kwargs.get("lights", self.lights)
 | 
			
		||||
        materials = kwargs.get("materials", self.materials)
 | 
			
		||||
@ -118,12 +122,7 @@ class SoftPhongShader(ShaderBase):
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor:
 | 
			
		||||
        cameras = kwargs.get("cameras", self.cameras)
 | 
			
		||||
        if cameras is None:
 | 
			
		||||
            msg = "Cameras must be specified either at initialization \
 | 
			
		||||
                or in the forward pass of SoftPhongShader"
 | 
			
		||||
            raise ValueError(msg)
 | 
			
		||||
 | 
			
		||||
        cameras = super()._get_cameras(**kwargs)
 | 
			
		||||
        texels = meshes.sample_textures(fragments)
 | 
			
		||||
        lights = kwargs.get("lights", self.lights)
 | 
			
		||||
        materials = kwargs.get("materials", self.materials)
 | 
			
		||||
@ -160,11 +159,7 @@ class HardGouraudShader(ShaderBase):
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor:
 | 
			
		||||
        cameras = kwargs.get("cameras", self.cameras)
 | 
			
		||||
        if cameras is None:
 | 
			
		||||
            msg = "Cameras must be specified either at initialization \
 | 
			
		||||
                or in the forward pass of HardGouraudShader"
 | 
			
		||||
            raise ValueError(msg)
 | 
			
		||||
        cameras = super()._get_cameras(**kwargs)
 | 
			
		||||
        lights = kwargs.get("lights", self.lights)
 | 
			
		||||
        materials = kwargs.get("materials", self.materials)
 | 
			
		||||
        blend_params = kwargs.get("blend_params", self.blend_params)
 | 
			
		||||
@ -201,11 +196,7 @@ class SoftGouraudShader(ShaderBase):
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor:
 | 
			
		||||
        cameras = kwargs.get("cameras", self.cameras)
 | 
			
		||||
        if cameras is None:
 | 
			
		||||
            msg = "Cameras must be specified either at initialization \
 | 
			
		||||
                or in the forward pass of SoftGouraudShader"
 | 
			
		||||
            raise ValueError(msg)
 | 
			
		||||
        cameras = super()._get_cameras(**kwargs)
 | 
			
		||||
        lights = kwargs.get("lights", self.lights)
 | 
			
		||||
        materials = kwargs.get("materials", self.materials)
 | 
			
		||||
        pixel_colors = gouraud_shading(
 | 
			
		||||
@ -263,11 +254,7 @@ class HardFlatShader(ShaderBase):
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor:
 | 
			
		||||
        cameras = kwargs.get("cameras", self.cameras)
 | 
			
		||||
        if cameras is None:
 | 
			
		||||
            msg = "Cameras must be specified either at initialization \
 | 
			
		||||
                or in the forward pass of HardFlatShader"
 | 
			
		||||
            raise ValueError(msg)
 | 
			
		||||
        cameras = super()._get_cameras(**kwargs)
 | 
			
		||||
        texels = meshes.sample_textures(fragments)
 | 
			
		||||
        lights = kwargs.get("lights", self.lights)
 | 
			
		||||
        materials = kwargs.get("materials", self.materials)
 | 
			
		||||
@ -329,11 +316,6 @@ class SplatterPhongShader(ShaderBase):
 | 
			
		||||
 | 
			
		||||
        shader = SplatterPhongShader(device=torch.device("cuda:0"))
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
        detach_rasterizer: If True, stop gradients from flowing through the rasterizer.
 | 
			
		||||
            This simulates the pipeline in [0] which uses a non-differentiable OpenGL
 | 
			
		||||
            rasterizer.
 | 
			
		||||
 | 
			
		||||
    [0] Cole, F. et al., "Differentiable Surface Rendering via Non-differentiable
 | 
			
		||||
        Sampling".
 | 
			
		||||
    """
 | 
			
		||||
@ -343,11 +325,7 @@ class SplatterPhongShader(ShaderBase):
 | 
			
		||||
        super().__init__(**kwargs)
 | 
			
		||||
 | 
			
		||||
    def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor:
 | 
			
		||||
        cameras = kwargs.get("cameras", self.cameras)
 | 
			
		||||
        if cameras is None:
 | 
			
		||||
            msg = "Cameras must be specified either at initialization \
 | 
			
		||||
                or in the forward pass of SplatterPhongShader."
 | 
			
		||||
            raise ValueError(msg)
 | 
			
		||||
        cameras = super()._get_cameras(**kwargs)
 | 
			
		||||
        texels = meshes.sample_textures(fragments)
 | 
			
		||||
        lights = kwargs.get("lights", self.lights)
 | 
			
		||||
        materials = kwargs.get("materials", self.materials)
 | 
			
		||||
 | 
			
		||||
@ -4,6 +4,8 @@
 | 
			
		||||
# This source code is licensed under the BSD-style license found in the
 | 
			
		||||
# LICENSE file in the root directory of this source tree.
 | 
			
		||||
 | 
			
		||||
# This file defines SplatterBlender, which is used for blending in SplatterPhongShader.
 | 
			
		||||
 | 
			
		||||
import itertools
 | 
			
		||||
from typing import Tuple
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -14,6 +14,7 @@ from pytorch3d.renderer.mesh.shader import (
 | 
			
		||||
    HardGouraudShader,
 | 
			
		||||
    HardPhongShader,
 | 
			
		||||
    SoftPhongShader,
 | 
			
		||||
    SplatterPhongShader,
 | 
			
		||||
)
 | 
			
		||||
from pytorch3d.structures.meshes import Meshes
 | 
			
		||||
 | 
			
		||||
@ -21,20 +22,22 @@ from .common_testing import TestCaseMixin
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TestShader(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
    def setUp(self):
 | 
			
		||||
        self.shader_classes = [
 | 
			
		||||
            HardFlatShader,
 | 
			
		||||
            HardGouraudShader,
 | 
			
		||||
            HardPhongShader,
 | 
			
		||||
            SoftPhongShader,
 | 
			
		||||
            SplatterPhongShader,
 | 
			
		||||
        ]
 | 
			
		||||
 | 
			
		||||
    def test_to(self):
 | 
			
		||||
        cpu_device = torch.device("cpu")
 | 
			
		||||
        cuda_device = torch.device("cuda:0")
 | 
			
		||||
 | 
			
		||||
        R, T = look_at_view_transform()
 | 
			
		||||
 | 
			
		||||
        shader_classes = [
 | 
			
		||||
            HardFlatShader,
 | 
			
		||||
            HardGouraudShader,
 | 
			
		||||
            HardPhongShader,
 | 
			
		||||
            SoftPhongShader,
 | 
			
		||||
        ]
 | 
			
		||||
 | 
			
		||||
        for shader_class in shader_classes:
 | 
			
		||||
        for shader_class in self.shader_classes:
 | 
			
		||||
            for cameras_class in (None, PerspectiveCameras):
 | 
			
		||||
                if cameras_class is None:
 | 
			
		||||
                    cameras = None
 | 
			
		||||
@ -53,8 +56,11 @@ class TestShader(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
                self.assertIs(cpu_shader, cuda_shader)
 | 
			
		||||
                if cameras is None:
 | 
			
		||||
                    self.assertIsNone(cuda_shader.cameras)
 | 
			
		||||
                    with self.assertRaisesRegexp(ValueError, "Cameras must be"):
 | 
			
		||||
                        cuda_shader._get_cameras()
 | 
			
		||||
                else:
 | 
			
		||||
                    self.assertEqual(cuda_device, cuda_shader.cameras.device)
 | 
			
		||||
                    self.assertIsInstance(cuda_shader._get_cameras(), cameras_class)
 | 
			
		||||
                self.assertEqual(cuda_device, cuda_shader.materials.device)
 | 
			
		||||
                self.assertEqual(cuda_device, cuda_shader.lights.device)
 | 
			
		||||
 | 
			
		||||
@ -75,14 +81,8 @@ class TestShader(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
            zbuf=torch.ones_like(pix_to_face),
 | 
			
		||||
            dists=torch.ones_like(pix_to_face),
 | 
			
		||||
        )
 | 
			
		||||
        shader_classes = [
 | 
			
		||||
            HardFlatShader,
 | 
			
		||||
            HardGouraudShader,
 | 
			
		||||
            HardPhongShader,
 | 
			
		||||
            SoftPhongShader,
 | 
			
		||||
        ]
 | 
			
		||||
 | 
			
		||||
        for shader_class in shader_classes:
 | 
			
		||||
        for shader_class in self.shader_classes:
 | 
			
		||||
            shader = shader_class()
 | 
			
		||||
 | 
			
		||||
            with self.assertRaises(ValueError):
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user