From 4ecc9ea89d55b51c6ad66996ff0edd013ded0815 Mon Sep 17 00:00:00 2001 From: Tristan Rice Date: Tue, 12 Jul 2022 04:38:33 -0700 Subject: [PATCH] shader: fix HardDepthShader sizes + tests (#1252) Summary: This fixes a indexing bug in HardDepthShader and adds proper unit tests for both of the DepthShaders. This bug was introduced when updating the shader sizes and discovered when I switched my local model onto pytorch3d trunk instead of the patched copy. Pull Request resolved: https://github.com/facebookresearch/pytorch3d/pull/1252 Test Plan: Unit test + custom model code ``` pytest tests/test_shader.py ``` ![image](https://user-images.githubusercontent.com/909104/178397456-f478d0e0-9f6c-467a-a85b-adb4c47adfee.png) Reviewed By: bottler Differential Revision: D37775767 Pulled By: d4l3k fbshipit-source-id: 5f001903985976d7067d1fa0a3102d602790e3e8 --- pytorch3d/renderer/mesh/shader.py | 6 +++--- tests/test_shader.py | 32 +++++++++++++++++++++++++++++++ 2 files changed, 35 insertions(+), 3 deletions(-) diff --git a/pytorch3d/renderer/mesh/shader.py b/pytorch3d/renderer/mesh/shader.py index 72de786d..a919e0ca 100644 --- a/pytorch3d/renderer/mesh/shader.py +++ b/pytorch3d/renderer/mesh/shader.py @@ -374,11 +374,11 @@ class HardDepthShader(ShaderBase): cameras = super()._get_cameras(**kwargs) zfar = kwargs.get("zfar", getattr(cameras, "zfar", 100.0)) - mask = fragments.pix_to_face < 0 + mask = fragments.pix_to_face[..., 0:1] < 0 - zbuf = fragments.zbuf[..., 0].clone() + zbuf = fragments.zbuf[..., 0:1].clone() zbuf[mask] = zfar - return zbuf.unsqueeze(3) + return zbuf class SoftDepthShader(ShaderBase): diff --git a/tests/test_shader.py b/tests/test_shader.py index 93ef71c7..f4ac8116 100644 --- a/tests/test_shader.py +++ b/tests/test_shader.py @@ -91,3 +91,35 @@ class TestShader(TestCaseMixin, unittest.TestCase): with self.assertRaises(ValueError): shader(fragments, meshes) + + def test_depth_shader(self): + shader_classes = [ + HardDepthShader, + SoftDepthShader, + ] + + verts = torch.tensor( + [[-1, -1, 0], [1, -1, 1], [1, 1, 0], [-1, 1, 1]], dtype=torch.float32 + ) + faces = torch.tensor([[0, 1, 2], [2, 3, 0]], dtype=torch.int64) + meshes = Meshes(verts=[verts], faces=[faces]) + + pix_to_face = torch.tensor([0, 1], dtype=torch.int64).view(1, 1, 1, 2) + barycentric_coords = torch.tensor( + [[0.1, 0.2, 0.7], [0.3, 0.5, 0.2]], dtype=torch.float32 + ).view(1, 1, 1, 2, -1) + for faces_per_pixel in [1, 2]: + fragments = Fragments( + pix_to_face=pix_to_face[:, :, :, :faces_per_pixel], + bary_coords=barycentric_coords[:, :, :, :faces_per_pixel], + zbuf=torch.ones_like(pix_to_face), + dists=torch.ones_like(pix_to_face), + ) + R, T = look_at_view_transform() + cameras = PerspectiveCameras(R=R, T=T) + + for shader_class in shader_classes: + shader = shader_class() + + out = shader(fragments, meshes, cameras=cameras) + self.assertEqual(out.shape, (1, 1, 1, 1))