shader: add SoftDepthShader and HardDepthShader for rendering depth maps (#36)

Summary:
X-link: https://github.com/fairinternal/pytorch3d/pull/36

This adds two shaders for rendering depth maps for meshes. This is useful for structure from motion applications that learn depths based off of camera pair disparities.

There are two shaders: a hard one that just returns the distance of the closest face, and a soft one that turns the per-face probabilities into weights (via a clamped cumsum) and returns the weighted sum of the face distances. Areas that aren't covered by any face are set to the zfar distance.

Output from these shaders is `[N, H, W, 1]`: since it's just depth, there is a single channel.

I haven't tested this in an ML model yet, just in a notebook.

hard:
![hardzshader](https://user-images.githubusercontent.com/909104/170190363-ef662c97-0bd2-488c-8675-0557a3c7dd06.png)

soft:
![softzshader](https://user-images.githubusercontent.com/909104/170190365-65b08cd7-0c49-4119-803e-d33c1d8c676e.png)

Pull Request resolved: https://github.com/facebookresearch/pytorch3d/pull/1208

Reviewed By: bottler

Differential Revision: D36682194

Pulled By: d4l3k

fbshipit-source-id: 5d4e10c6fb0fff5427be4ddd3bd76305a7ccc1e2
This commit is contained in:
Tristan Rice 2022-06-26 04:01:29 -07:00 committed by Facebook GitHub Bot
parent 0e4c53c612
commit 7e0146ece4
2 changed files with 72 additions and 0 deletions

View File

@ -353,3 +353,71 @@ class SplatterPhongShader(ShaderBase):
)
return images
class HardDepthShader(ShaderBase):
    """
    Renders the Z distances of the closest face for each pixel. If no face is
    found it returns the zfar value of the camera.

    Output from this shader is [N, H, W, 1] since it's only depth.

    To use the default values, simply initialize the shader with the desired
    device e.g.

    .. code-block::

        shader = HardDepthShader(device=torch.device("cuda:0"))
    """

    def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor:
        """
        Build a depth map from the first of the K rasterized faces per pixel.

        Args:
            fragments: rasterizer output; only ``zbuf`` and ``pix_to_face``
                are read, both indexed as (N, H, W, K).
            meshes: unused here; accepted to keep the common shader signature.
            **kwargs: may contain ``cameras`` (resolved through
                ``ShaderBase._get_cameras``) and ``zfar``. ``zfar`` falls back
                to the cameras' ``zfar`` attribute and then to 100.0. It may
                be a Python scalar, a 0-dim tensor, a tensor of shape (N,)
                holding one value per batch element, or any tensor that
                broadcasts against (N, H, W, 1).

        Returns:
            Tensor of shape (N, H, W, 1) holding the depth of the first face
            for covered pixels and ``zfar`` for pixels with no face.
        """
        cameras = super()._get_cameras(**kwargs)
        zfar = kwargs.get("zfar", getattr(cameras, "zfar", 100.0))

        # Slice with 0:1 (not 0) so the depth and the mask keep the same
        # (N, H, W, 1) shape; a mask over all K faces cannot index a single
        # depth layer.
        closest_depth = fragments.zbuf[..., 0:1]
        is_background = fragments.pix_to_face[..., 0:1] < 0

        zfar = torch.as_tensor(
            zfar, dtype=closest_depth.dtype, device=closest_depth.device
        )
        if zfar.ndim == 1:
            # One zfar per batch element: align it with the batch axis.
            zfar = zfar.reshape(-1, 1, 1, 1)

        # torch.where allocates a new tensor, so fragments.zbuf is not modified.
        return torch.where(is_background, zfar, closest_depth)
class SoftDepthShader(ShaderBase):
    """
    Renders the Z distances using an aggregate of the distances of each face
    based off of the point distance. If no face is found it returns the zfar
    value of the camera.

    Output from this shader is [N, H, W, 1] since it's only depth.

    To use the default values, simply initialize the shader with the desired
    device e.g.

    .. code-block::

        shader = SoftDepthShader(device=torch.device("cuda:0"))
    """

    def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor:
        """
        Blend the depths of the K faces per pixel into a single soft depth.

        Each face gets a sigmoid probability from its ``dists`` value scaled
        by ``self.blend_params.sigma``. Probabilities are consumed in K-order
        until they sum to 1; whatever is left over is assigned to an extra
        layer placed at ``zfar``.

        Args:
            fragments: rasterizer output; ``zbuf``, ``dists`` and
                ``pix_to_face`` are read, all indexed as (N, H, W, K).
            meshes: unused here; accepted to keep the common shader signature.
            **kwargs: may contain ``cameras`` (resolved through
                ``ShaderBase._get_cameras``) and ``zfar``. ``zfar`` falls back
                to the cameras' ``zfar`` attribute and then to 100.0. It may
                be a Python scalar, a 0-dim tensor, a tensor of shape (N,)
                holding one value per batch element, or any tensor that
                expands to (N, H, W, 1).

        Returns:
            Tensor of shape (N, H, W, 1) with the probability-weighted depth.
        """
        cameras = super()._get_cameras(**kwargs)
        zfar = kwargs.get("zfar", getattr(cameras, "zfar", 100.0))

        zbuf = fragments.zbuf
        N, H, W, _ = fragments.pix_to_face.shape
        mask = fragments.pix_to_face >= 0

        # Sigmoid probability map based on the distance of the pixel to the
        # face; slots with no face (negative index) get probability zero.
        prob_map = torch.sigmoid(-fragments.dists / self.blend_params.sigma) * mask

        # Build the extra zfar layer. A (N,) zfar must be aligned with the
        # batch axis first; multiplying it directly against (N, H, W, 1) would
        # broadcast over the last axis and yield (N, H, W, N).
        zfar = torch.as_tensor(zfar, dtype=zbuf.dtype, device=zbuf.device)
        if zfar.ndim == 1:
            zfar = zfar.reshape(-1, 1, 1, 1)
        zfar_layer = zfar.expand(N, H, W, 1)

        # Append the zfar layer with probability 1 so the weights always
        # sum to 1.
        dists = torch.cat((zbuf, zfar_layer), dim=3)
        probs = torch.cat((prob_map, prob_map.new_ones((N, H, W, 1))), dim=3)

        # Clamp the running total at 1, then difference it back: each layer
        # keeps its probability until the budget of 1 is exhausted.
        probs = probs.cumsum(dim=3).clamp(max=1)
        probs = probs.diff(dim=3, prepend=probs.new_zeros((N, H, W, 1)))

        return (probs * dists).sum(dim=3).unsqueeze(3)

View File

@ -10,9 +10,11 @@ import torch
from pytorch3d.renderer.cameras import look_at_view_transform, PerspectiveCameras
from pytorch3d.renderer.mesh.rasterizer import Fragments
from pytorch3d.renderer.mesh.shader import (
HardDepthShader,
HardFlatShader,
HardGouraudShader,
HardPhongShader,
SoftDepthShader,
SoftPhongShader,
SplatterPhongShader,
)
@ -24,9 +26,11 @@ from .common_testing import TestCaseMixin
class TestShader(TestCaseMixin, unittest.TestCase):
def setUp(self):
self.shader_classes = [
HardDepthShader,
HardFlatShader,
HardGouraudShader,
HardPhongShader,
SoftDepthShader,
SoftPhongShader,
SplatterPhongShader,
]