diff --git a/pytorch3d/renderer/mesh/shader.py b/pytorch3d/renderer/mesh/shader.py index 677812a7..72de786d 100644 --- a/pytorch3d/renderer/mesh/shader.py +++ b/pytorch3d/renderer/mesh/shader.py @@ -353,3 +353,71 @@ class SplatterPhongShader(ShaderBase): ) return images + + +class HardDepthShader(ShaderBase): + """ + Renders the Z distances of the closest face for each pixel. If no face is + found it returns the zfar value of the camera. + + Output from this shader is [N, H, W, 1] since it's only depth. + + To use the default values, simply initialize the shader with the desired + device e.g. + + .. code-block:: + + shader = HardDepthShader(device=torch.device("cuda:0")) + """ + + def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor: + cameras = super()._get_cameras(**kwargs) + + zfar = kwargs.get("zfar", getattr(cameras, "zfar", 100.0)) + mask = fragments.pix_to_face < 0 + + zbuf = fragments.zbuf[..., 0].clone() + zbuf[mask] = zfar + return zbuf.unsqueeze(3) + + +class SoftDepthShader(ShaderBase): + """ + Renders the Z distances using an aggregate of the distances of each face + based off of the point distance. If no face is found it returns the zfar + value of the camera. + + Output from this shader is [N, H, W, 1] since it's only depth. + + To use the default values, simply initialize the shader with the desired + device e.g. + + .. code-block:: + + shader = SoftDepthShader(device=torch.device("cuda:0")) + """ + + def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor: + cameras = super()._get_cameras(**kwargs) + + N, H, W, K = fragments.pix_to_face.shape + device = fragments.zbuf.device + mask = fragments.pix_to_face >= 0 + + zfar = kwargs.get("zfar", getattr(cameras, "zfar", 100.0)) + + # Sigmoid probability map based on the distance of the pixel to the face. + prob_map = torch.sigmoid(-fragments.dists / self.blend_params.sigma) * mask + + # append extra face for zfar + dists = torch.cat( + (fragments.zbuf, torch.ones((N, H, W, 1), device=device) * zfar), dim=3 + ) + probs = torch.cat((prob_map, torch.ones((N, H, W, 1), device=device)), dim=3) + + # compute weighting based off of probabilities using cumsum + probs = probs.cumsum(dim=3) + probs = probs.clamp(max=1) + probs = probs.diff(dim=3, prepend=torch.zeros((N, H, W, 1), device=device)) + + return (probs * dists).sum(dim=3).unsqueeze(3) diff --git a/tests/test_shader.py b/tests/test_shader.py index 3b751e8b..93ef71c7 100644 --- a/tests/test_shader.py +++ b/tests/test_shader.py @@ -10,9 +10,11 @@ import torch from pytorch3d.renderer.cameras import look_at_view_transform, PerspectiveCameras from pytorch3d.renderer.mesh.rasterizer import Fragments from pytorch3d.renderer.mesh.shader import ( + HardDepthShader, HardFlatShader, HardGouraudShader, HardPhongShader, + SoftDepthShader, SoftPhongShader, SplatterPhongShader, ) @@ -24,9 +26,11 @@ from .common_testing import TestCaseMixin class TestShader(TestCaseMixin, unittest.TestCase): def setUp(self): self.shader_classes = [ + HardDepthShader, HardFlatShader, HardGouraudShader, HardPhongShader, + SoftDepthShader, SoftPhongShader, SplatterPhongShader, ]