SplatterBlender follow-ups

Summary: A few minor additions I didn't fit into the SplatterBlender diffs, as requested by reviewers.

Reviewed By: jcjohnson

Differential Revision: D36682437

fbshipit-source-id: 57af995e766dfd2674b3984a3ba00aef7ca7db80
This commit is contained in:
Krzysztof Chalupka 2022-05-26 13:03:57 -07:00 committed by Facebook GitHub Bot
parent c31bf85a23
commit a42a89a5ba
5 changed files with 36 additions and 53 deletions

View File

@ -36,7 +36,6 @@ class BlendParams(NamedTuple):
sigma: float = 1e-4
gamma: float = 1e-4
background_color: Union[torch.Tensor, Sequence[float]] = (1.0, 1.0, 1.0)
background_alpha: float = 0.0
def _get_background_color(

View File

@ -16,6 +16,10 @@ from .rasterize_meshes import rasterize_meshes
@dataclass(frozen=True)
class Fragments:
"""
A dataclass representing the outputs of a rasterizer. Can be detached from the
computational graph in order to stop the gradients from flowing through the
rasterizer.
Members:
pix_to_face:
LongTensor of shape (N, image_size, image_size, faces_per_pixel) giving

View File

@ -55,6 +55,15 @@ class ShaderBase(nn.Module):
self.cameras = cameras
self.blend_params = blend_params if blend_params is not None else BlendParams()
def _get_cameras(self, **kwargs):
cameras = kwargs.get("cameras", self.cameras)
if cameras is None:
msg = "Cameras must be specified either at initialization \
or in the forward pass of the shader."
raise ValueError(msg)
return cameras
# pyre-fixme[14]: `to` overrides method defined in `Module` inconsistently.
def to(self, device: Device):
# Manually move to device modules which are not subclasses of nn.Module
@ -81,12 +90,7 @@ class HardPhongShader(ShaderBase):
"""
def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor:
cameras = kwargs.get("cameras", self.cameras)
if cameras is None:
msg = "Cameras must be specified either at initialization \
or in the forward pass of HardPhongShader"
raise ValueError(msg)
cameras = super()._get_cameras(**kwargs)
texels = meshes.sample_textures(fragments)
lights = kwargs.get("lights", self.lights)
materials = kwargs.get("materials", self.materials)
@ -118,12 +122,7 @@ class SoftPhongShader(ShaderBase):
"""
def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor:
cameras = kwargs.get("cameras", self.cameras)
if cameras is None:
msg = "Cameras must be specified either at initialization \
or in the forward pass of SoftPhongShader"
raise ValueError(msg)
cameras = super()._get_cameras(**kwargs)
texels = meshes.sample_textures(fragments)
lights = kwargs.get("lights", self.lights)
materials = kwargs.get("materials", self.materials)
@ -160,11 +159,7 @@ class HardGouraudShader(ShaderBase):
"""
def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor:
cameras = kwargs.get("cameras", self.cameras)
if cameras is None:
msg = "Cameras must be specified either at initialization \
or in the forward pass of HardGouraudShader"
raise ValueError(msg)
cameras = super()._get_cameras(**kwargs)
lights = kwargs.get("lights", self.lights)
materials = kwargs.get("materials", self.materials)
blend_params = kwargs.get("blend_params", self.blend_params)
@ -201,11 +196,7 @@ class SoftGouraudShader(ShaderBase):
"""
def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor:
cameras = kwargs.get("cameras", self.cameras)
if cameras is None:
msg = "Cameras must be specified either at initialization \
or in the forward pass of SoftGouraudShader"
raise ValueError(msg)
cameras = super()._get_cameras(**kwargs)
lights = kwargs.get("lights", self.lights)
materials = kwargs.get("materials", self.materials)
pixel_colors = gouraud_shading(
@ -263,11 +254,7 @@ class HardFlatShader(ShaderBase):
"""
def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor:
cameras = kwargs.get("cameras", self.cameras)
if cameras is None:
msg = "Cameras must be specified either at initialization \
or in the forward pass of HardFlatShader"
raise ValueError(msg)
cameras = super()._get_cameras(**kwargs)
texels = meshes.sample_textures(fragments)
lights = kwargs.get("lights", self.lights)
materials = kwargs.get("materials", self.materials)
@ -329,11 +316,6 @@ class SplatterPhongShader(ShaderBase):
shader = SplatterPhongShader(device=torch.device("cuda:0"))
Args:
detach_rasterizer: If True, stop gradients from flowing through the rasterizer.
This simulates the pipeline in [0] which uses a non-differentiable OpenGL
rasterizer.
[0] Cole, F. et al., "Differentiable Surface Rendering via Non-differentiable
Sampling".
"""
@ -343,11 +325,7 @@ class SplatterPhongShader(ShaderBase):
super().__init__(**kwargs)
def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor:
cameras = kwargs.get("cameras", self.cameras)
if cameras is None:
msg = "Cameras must be specified either at initialization \
or in the forward pass of SplatterPhongShader."
raise ValueError(msg)
cameras = super()._get_cameras(**kwargs)
texels = meshes.sample_textures(fragments)
lights = kwargs.get("lights", self.lights)
materials = kwargs.get("materials", self.materials)

View File

@ -4,6 +4,8 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# This file defines SplatterBlender, which is used for blending in SplatterPhongShader.
import itertools
from typing import Tuple

View File

@ -14,6 +14,7 @@ from pytorch3d.renderer.mesh.shader import (
HardGouraudShader,
HardPhongShader,
SoftPhongShader,
SplatterPhongShader,
)
from pytorch3d.structures.meshes import Meshes
@ -21,20 +22,22 @@ from .common_testing import TestCaseMixin
class TestShader(TestCaseMixin, unittest.TestCase):
def setUp(self):
self.shader_classes = [
HardFlatShader,
HardGouraudShader,
HardPhongShader,
SoftPhongShader,
SplatterPhongShader,
]
def test_to(self):
cpu_device = torch.device("cpu")
cuda_device = torch.device("cuda:0")
R, T = look_at_view_transform()
shader_classes = [
HardFlatShader,
HardGouraudShader,
HardPhongShader,
SoftPhongShader,
]
for shader_class in shader_classes:
for shader_class in self.shader_classes:
for cameras_class in (None, PerspectiveCameras):
if cameras_class is None:
cameras = None
@ -53,8 +56,11 @@ class TestShader(TestCaseMixin, unittest.TestCase):
self.assertIs(cpu_shader, cuda_shader)
if cameras is None:
self.assertIsNone(cuda_shader.cameras)
with self.assertRaisesRegexp(ValueError, "Cameras must be"):
cuda_shader._get_cameras()
else:
self.assertEqual(cuda_device, cuda_shader.cameras.device)
self.assertIsInstance(cuda_shader._get_cameras(), cameras_class)
self.assertEqual(cuda_device, cuda_shader.materials.device)
self.assertEqual(cuda_device, cuda_shader.lights.device)
@ -75,14 +81,8 @@ class TestShader(TestCaseMixin, unittest.TestCase):
zbuf=torch.ones_like(pix_to_face),
dists=torch.ones_like(pix_to_face),
)
shader_classes = [
HardFlatShader,
HardGouraudShader,
HardPhongShader,
SoftPhongShader,
]
for shader_class in shader_classes:
for shader_class in self.shader_classes:
shader = shader_class()
with self.assertRaises(ValueError):