mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
shader: fix HardDepthShader sizes + tests (#1252)
Summary: This fixes a indexing bug in HardDepthShader and adds proper unit tests for both of the DepthShaders. This bug was introduced when updating the shader sizes and discovered when I switched my local model onto pytorch3d trunk instead of the patched copy. Pull Request resolved: https://github.com/facebookresearch/pytorch3d/pull/1252 Test Plan: Unit test + custom model code ``` pytest tests/test_shader.py ```  Reviewed By: bottler Differential Revision: D37775767 Pulled By: d4l3k fbshipit-source-id: 5f001903985976d7067d1fa0a3102d602790e3e8
This commit is contained in:
parent
8d10ba52b2
commit
4ecc9ea89d
@ -374,11 +374,11 @@ class HardDepthShader(ShaderBase):
|
||||
cameras = super()._get_cameras(**kwargs)
|
||||
|
||||
zfar = kwargs.get("zfar", getattr(cameras, "zfar", 100.0))
|
||||
mask = fragments.pix_to_face < 0
|
||||
mask = fragments.pix_to_face[..., 0:1] < 0
|
||||
|
||||
zbuf = fragments.zbuf[..., 0].clone()
|
||||
zbuf = fragments.zbuf[..., 0:1].clone()
|
||||
zbuf[mask] = zfar
|
||||
return zbuf.unsqueeze(3)
|
||||
return zbuf
|
||||
|
||||
|
||||
class SoftDepthShader(ShaderBase):
|
||||
|
@ -91,3 +91,35 @@ class TestShader(TestCaseMixin, unittest.TestCase):
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
shader(fragments, meshes)
|
||||
|
||||
def test_depth_shader(self):
|
||||
shader_classes = [
|
||||
HardDepthShader,
|
||||
SoftDepthShader,
|
||||
]
|
||||
|
||||
verts = torch.tensor(
|
||||
[[-1, -1, 0], [1, -1, 1], [1, 1, 0], [-1, 1, 1]], dtype=torch.float32
|
||||
)
|
||||
faces = torch.tensor([[0, 1, 2], [2, 3, 0]], dtype=torch.int64)
|
||||
meshes = Meshes(verts=[verts], faces=[faces])
|
||||
|
||||
pix_to_face = torch.tensor([0, 1], dtype=torch.int64).view(1, 1, 1, 2)
|
||||
barycentric_coords = torch.tensor(
|
||||
[[0.1, 0.2, 0.7], [0.3, 0.5, 0.2]], dtype=torch.float32
|
||||
).view(1, 1, 1, 2, -1)
|
||||
for faces_per_pixel in [1, 2]:
|
||||
fragments = Fragments(
|
||||
pix_to_face=pix_to_face[:, :, :, :faces_per_pixel],
|
||||
bary_coords=barycentric_coords[:, :, :, :faces_per_pixel],
|
||||
zbuf=torch.ones_like(pix_to_face),
|
||||
dists=torch.ones_like(pix_to_face),
|
||||
)
|
||||
R, T = look_at_view_transform()
|
||||
cameras = PerspectiveCameras(R=R, T=T)
|
||||
|
||||
for shader_class in shader_classes:
|
||||
shader = shader_class()
|
||||
|
||||
out = shader(fragments, meshes, cameras=cameras)
|
||||
self.assertEqual(out.shape, (1, 1, 1, 1))
|
||||
|
Loading…
x
Reference in New Issue
Block a user