From 7e0146ece438f6de98a7ae930a339587311cb410 Mon Sep 17 00:00:00 2001 From: Tristan Rice Date: Sun, 26 Jun 2022 04:01:29 -0700 Subject: [PATCH] shader: add SoftDepthShader and HardDepthShader for rendering depth maps (#36) Summary: X-link: https://github.com/fairinternal/pytorch3d/pull/36 This adds two shaders for rendering depth maps for meshes. This is useful for structure from motion applications that learn depths based off of camera pair disparities. There's two shaders, one hard which just returns the distances and then a second that does a cumsum on the probabilities of the points with a weighted sum. Areas that don't have any z faces are set to the zfar distance. Output from this renderer is `[N, H, W]` since it's just depth no need for channels. I haven't tested this in an ML model yet just in a notebook. hard: ![hardzshader](https://user-images.githubusercontent.com/909104/170190363-ef662c97-0bd2-488c-8675-0557a3c7dd06.png) soft: ![softzshader](https://user-images.githubusercontent.com/909104/170190365-65b08cd7-0c49-4119-803e-d33c1d8c676e.png) Pull Request resolved: https://github.com/facebookresearch/pytorch3d/pull/1208 Reviewed By: bottler Differential Revision: D36682194 Pulled By: d4l3k fbshipit-source-id: 5d4e10c6fb0fff5427be4ddd3bd76305a7ccc1e2 --- pytorch3d/renderer/mesh/shader.py | 68 +++++++++++++++++++++++++++++++ tests/test_shader.py | 4 ++ 2 files changed, 72 insertions(+) diff --git a/pytorch3d/renderer/mesh/shader.py b/pytorch3d/renderer/mesh/shader.py index 677812a7..72de786d 100644 --- a/pytorch3d/renderer/mesh/shader.py +++ b/pytorch3d/renderer/mesh/shader.py @@ -353,3 +353,71 @@ class SplatterPhongShader(ShaderBase): ) return images + + +class HardDepthShader(ShaderBase): + """ + Renders the Z distances of the closest face for each pixel. If no face is + found it returns the zfar value of the camera. + + Output from this shader is [N, H, W, 1] since it's only depth. + + To use the default values, simply initialize the shader with the desired + device e.g. + + .. code-block:: + + shader = HardDepthShader(device=torch.device("cuda:0")) + """ + + def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor: + cameras = super()._get_cameras(**kwargs) + + zfar = kwargs.get("zfar", getattr(cameras, "zfar", 100.0)) + mask = fragments.pix_to_face < 0 + + zbuf = fragments.zbuf[..., 0].clone() + zbuf[mask] = zfar + return zbuf.unsqueeze(3) + + +class SoftDepthShader(ShaderBase): + """ + Renders the Z distances using an aggregate of the distances of each face + based off of the point distance. If no face is found it returns the zfar + value of the camera. + + Output from this shader is [N, H, W, 1] since it's only depth. + + To use the default values, simply initialize the shader with the desired + device e.g. + + .. code-block:: + + shader = SoftDepthShader(device=torch.device("cuda:0")) + """ + + def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor: + cameras = super()._get_cameras(**kwargs) + + N, H, W, K = fragments.pix_to_face.shape + device = fragments.zbuf.device + mask = fragments.pix_to_face >= 0 + + zfar = kwargs.get("zfar", getattr(cameras, "zfar", 100.0)) + + # Sigmoid probability map based on the distance of the pixel to the face. + prob_map = torch.sigmoid(-fragments.dists / self.blend_params.sigma) * mask + + # append extra face for zfar + dists = torch.cat( + (fragments.zbuf, torch.ones((N, H, W, 1), device=device) * zfar), dim=3 + ) + probs = torch.cat((prob_map, torch.ones((N, H, W, 1), device=device)), dim=3) + + # compute weighting based off of probabilities using cumsum + probs = probs.cumsum(dim=3) + probs = probs.clamp(max=1) + probs = probs.diff(dim=3, prepend=torch.zeros((N, H, W, 1), device=device)) + + return (probs * dists).sum(dim=3).unsqueeze(3) diff --git a/tests/test_shader.py b/tests/test_shader.py index 3b751e8b..93ef71c7 100644 --- a/tests/test_shader.py +++ b/tests/test_shader.py @@ -10,9 +10,11 @@ import torch from pytorch3d.renderer.cameras import look_at_view_transform, PerspectiveCameras from pytorch3d.renderer.mesh.rasterizer import Fragments from pytorch3d.renderer.mesh.shader import ( + HardDepthShader, HardFlatShader, HardGouraudShader, HardPhongShader, + SoftDepthShader, SoftPhongShader, SplatterPhongShader, ) @@ -24,9 +26,11 @@ from .common_testing import TestCaseMixin class TestShader(TestCaseMixin, unittest.TestCase): def setUp(self): self.shader_classes = [ + HardDepthShader, HardFlatShader, HardGouraudShader, HardPhongShader, + SoftDepthShader, SoftPhongShader, SplatterPhongShader, ]