mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
SplatterBlender follow-ups
Summary: A few minor additions I didn't fit into the SplatterBlender diffs, as requested by reviewers. Reviewed By: jcjohnson Differential Revision: D36682437 fbshipit-source-id: 57af995e766dfd2674b3984a3ba00aef7ca7db80
This commit is contained in:
parent
c31bf85a23
commit
a42a89a5ba
@ -36,7 +36,6 @@ class BlendParams(NamedTuple):
|
||||
sigma: float = 1e-4
|
||||
gamma: float = 1e-4
|
||||
background_color: Union[torch.Tensor, Sequence[float]] = (1.0, 1.0, 1.0)
|
||||
background_alpha: float = 0.0
|
||||
|
||||
|
||||
def _get_background_color(
|
||||
|
@ -16,6 +16,10 @@ from .rasterize_meshes import rasterize_meshes
|
||||
@dataclass(frozen=True)
|
||||
class Fragments:
|
||||
"""
|
||||
A dataclass representing the outputs of a rasterizer. Can be detached from the
|
||||
computational graph in order to stop the gradients from flowing through the
|
||||
rasterizer.
|
||||
|
||||
Members:
|
||||
pix_to_face:
|
||||
LongTensor of shape (N, image_size, image_size, faces_per_pixel) giving
|
||||
|
@ -55,6 +55,15 @@ class ShaderBase(nn.Module):
|
||||
self.cameras = cameras
|
||||
self.blend_params = blend_params if blend_params is not None else BlendParams()
|
||||
|
||||
def _get_cameras(self, **kwargs):
|
||||
cameras = kwargs.get("cameras", self.cameras)
|
||||
if cameras is None:
|
||||
msg = "Cameras must be specified either at initialization \
|
||||
or in the forward pass of the shader."
|
||||
raise ValueError(msg)
|
||||
|
||||
return cameras
|
||||
|
||||
# pyre-fixme[14]: `to` overrides method defined in `Module` inconsistently.
|
||||
def to(self, device: Device):
|
||||
# Manually move to device modules which are not subclasses of nn.Module
|
||||
@ -81,12 +90,7 @@ class HardPhongShader(ShaderBase):
|
||||
"""
|
||||
|
||||
def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor:
|
||||
cameras = kwargs.get("cameras", self.cameras)
|
||||
if cameras is None:
|
||||
msg = "Cameras must be specified either at initialization \
|
||||
or in the forward pass of HardPhongShader"
|
||||
raise ValueError(msg)
|
||||
|
||||
cameras = super()._get_cameras(**kwargs)
|
||||
texels = meshes.sample_textures(fragments)
|
||||
lights = kwargs.get("lights", self.lights)
|
||||
materials = kwargs.get("materials", self.materials)
|
||||
@ -118,12 +122,7 @@ class SoftPhongShader(ShaderBase):
|
||||
"""
|
||||
|
||||
def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor:
|
||||
cameras = kwargs.get("cameras", self.cameras)
|
||||
if cameras is None:
|
||||
msg = "Cameras must be specified either at initialization \
|
||||
or in the forward pass of SoftPhongShader"
|
||||
raise ValueError(msg)
|
||||
|
||||
cameras = super()._get_cameras(**kwargs)
|
||||
texels = meshes.sample_textures(fragments)
|
||||
lights = kwargs.get("lights", self.lights)
|
||||
materials = kwargs.get("materials", self.materials)
|
||||
@ -160,11 +159,7 @@ class HardGouraudShader(ShaderBase):
|
||||
"""
|
||||
|
||||
def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor:
|
||||
cameras = kwargs.get("cameras", self.cameras)
|
||||
if cameras is None:
|
||||
msg = "Cameras must be specified either at initialization \
|
||||
or in the forward pass of HardGouraudShader"
|
||||
raise ValueError(msg)
|
||||
cameras = super()._get_cameras(**kwargs)
|
||||
lights = kwargs.get("lights", self.lights)
|
||||
materials = kwargs.get("materials", self.materials)
|
||||
blend_params = kwargs.get("blend_params", self.blend_params)
|
||||
@ -201,11 +196,7 @@ class SoftGouraudShader(ShaderBase):
|
||||
"""
|
||||
|
||||
def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor:
|
||||
cameras = kwargs.get("cameras", self.cameras)
|
||||
if cameras is None:
|
||||
msg = "Cameras must be specified either at initialization \
|
||||
or in the forward pass of SoftGouraudShader"
|
||||
raise ValueError(msg)
|
||||
cameras = super()._get_cameras(**kwargs)
|
||||
lights = kwargs.get("lights", self.lights)
|
||||
materials = kwargs.get("materials", self.materials)
|
||||
pixel_colors = gouraud_shading(
|
||||
@ -263,11 +254,7 @@ class HardFlatShader(ShaderBase):
|
||||
"""
|
||||
|
||||
def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor:
|
||||
cameras = kwargs.get("cameras", self.cameras)
|
||||
if cameras is None:
|
||||
msg = "Cameras must be specified either at initialization \
|
||||
or in the forward pass of HardFlatShader"
|
||||
raise ValueError(msg)
|
||||
cameras = super()._get_cameras(**kwargs)
|
||||
texels = meshes.sample_textures(fragments)
|
||||
lights = kwargs.get("lights", self.lights)
|
||||
materials = kwargs.get("materials", self.materials)
|
||||
@ -329,11 +316,6 @@ class SplatterPhongShader(ShaderBase):
|
||||
|
||||
shader = SplatterPhongShader(device=torch.device("cuda:0"))
|
||||
|
||||
Args:
|
||||
detach_rasterizer: If True, stop gradients from flowing through the rasterizer.
|
||||
This simulates the pipeline in [0] which uses a non-differentiable OpenGL
|
||||
rasterizer.
|
||||
|
||||
[0] Cole, F. et al., "Differentiable Surface Rendering via Non-differentiable
|
||||
Sampling".
|
||||
"""
|
||||
@ -343,11 +325,7 @@ class SplatterPhongShader(ShaderBase):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor:
|
||||
cameras = kwargs.get("cameras", self.cameras)
|
||||
if cameras is None:
|
||||
msg = "Cameras must be specified either at initialization \
|
||||
or in the forward pass of SplatterPhongShader."
|
||||
raise ValueError(msg)
|
||||
cameras = super()._get_cameras(**kwargs)
|
||||
texels = meshes.sample_textures(fragments)
|
||||
lights = kwargs.get("lights", self.lights)
|
||||
materials = kwargs.get("materials", self.materials)
|
||||
|
@ -4,6 +4,8 @@
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
# This file defines SplatterBlender, which is used for blending in SplatterPhongShader.
|
||||
|
||||
import itertools
|
||||
from typing import Tuple
|
||||
|
||||
|
@ -14,6 +14,7 @@ from pytorch3d.renderer.mesh.shader import (
|
||||
HardGouraudShader,
|
||||
HardPhongShader,
|
||||
SoftPhongShader,
|
||||
SplatterPhongShader,
|
||||
)
|
||||
from pytorch3d.structures.meshes import Meshes
|
||||
|
||||
@ -21,20 +22,22 @@ from .common_testing import TestCaseMixin
|
||||
|
||||
|
||||
class TestShader(TestCaseMixin, unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.shader_classes = [
|
||||
HardFlatShader,
|
||||
HardGouraudShader,
|
||||
HardPhongShader,
|
||||
SoftPhongShader,
|
||||
SplatterPhongShader,
|
||||
]
|
||||
|
||||
def test_to(self):
|
||||
cpu_device = torch.device("cpu")
|
||||
cuda_device = torch.device("cuda:0")
|
||||
|
||||
R, T = look_at_view_transform()
|
||||
|
||||
shader_classes = [
|
||||
HardFlatShader,
|
||||
HardGouraudShader,
|
||||
HardPhongShader,
|
||||
SoftPhongShader,
|
||||
]
|
||||
|
||||
for shader_class in shader_classes:
|
||||
for shader_class in self.shader_classes:
|
||||
for cameras_class in (None, PerspectiveCameras):
|
||||
if cameras_class is None:
|
||||
cameras = None
|
||||
@ -53,8 +56,11 @@ class TestShader(TestCaseMixin, unittest.TestCase):
|
||||
self.assertIs(cpu_shader, cuda_shader)
|
||||
if cameras is None:
|
||||
self.assertIsNone(cuda_shader.cameras)
|
||||
with self.assertRaisesRegexp(ValueError, "Cameras must be"):
|
||||
cuda_shader._get_cameras()
|
||||
else:
|
||||
self.assertEqual(cuda_device, cuda_shader.cameras.device)
|
||||
self.assertIsInstance(cuda_shader._get_cameras(), cameras_class)
|
||||
self.assertEqual(cuda_device, cuda_shader.materials.device)
|
||||
self.assertEqual(cuda_device, cuda_shader.lights.device)
|
||||
|
||||
@ -75,14 +81,8 @@ class TestShader(TestCaseMixin, unittest.TestCase):
|
||||
zbuf=torch.ones_like(pix_to_face),
|
||||
dists=torch.ones_like(pix_to_face),
|
||||
)
|
||||
shader_classes = [
|
||||
HardFlatShader,
|
||||
HardGouraudShader,
|
||||
HardPhongShader,
|
||||
SoftPhongShader,
|
||||
]
|
||||
|
||||
for shader_class in shader_classes:
|
||||
for shader_class in self.shader_classes:
|
||||
shader = shader_class()
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
|
Loading…
x
Reference in New Issue
Block a user