Add .to methods to the splatter and SplatterPhongShader.

Summary: Needed to properly change devices during OpenGL rasterization.

Reviewed By: jcjohnson

Differential Revision: D37698568

fbshipit-source-id: 38968149d577322e662d3b5d04880204b0a7be29
This commit is contained in:
Krzysztof Chalupka 2022-07-22 14:36:22 -07:00 committed by Facebook GitHub Bot
parent 78bb6d17fa
commit 36edf2b302
3 changed files with 16 additions and 2 deletions

View File

@ -324,6 +324,11 @@ class SplatterPhongShader(ShaderBase):
self.splatter_blender = None
super().__init__(**kwargs)
def to(self, device: Device):
if self.splatter_blender:
self.splatter_blender.to(device)
return super().to(device)
def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor:
cameras = super()._get_cameras(**kwargs)
texels = meshes.sample_textures(fragments)
@ -349,7 +354,7 @@ class SplatterPhongShader(ShaderBase):
pixel_coords_cameras,
cameras,
fragments.pix_to_face < 0,
self.blend_params,
kwargs.get("blend_params", self.blend_params),
)
return images
@ -398,6 +403,9 @@ class SoftDepthShader(ShaderBase):
"""
def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor:
if fragments.dists is None:
raise ValueError("SoftDepthShader requires Fragments.dists to be present.")
cameras = super()._get_cameras(**kwargs)
N, H, W, K = fragments.pix_to_face.shape

View File

@ -464,6 +464,12 @@ class SplatterBlender(torch.nn.Module):
input_shape, device
)
def to(self, device):
self.offsets = self.offsets.to(device)
self.crop_ids_h = self.crop_ids_h.to(device)
self.crop_ids_w = self.crop_ids_w.to(device)
super().to(device)
def forward(
self,
colors: torch.Tensor,

View File

@ -60,7 +60,7 @@ class TestShader(TestCaseMixin, unittest.TestCase):
self.assertIs(cpu_shader, cuda_shader)
if cameras is None:
self.assertIsNone(cuda_shader.cameras)
with self.assertRaisesRegexp(ValueError, "Cameras must be"):
with self.assertRaisesRegex(ValueError, "Cameras must be"):
cuda_shader._get_cameras()
else:
self.assertEqual(cuda_device, cuda_shader.cameras.device)