diff --git a/pytorch3d/renderer/mesh/shader.py b/pytorch3d/renderer/mesh/shader.py index a919e0ca..71b32c0b 100644 --- a/pytorch3d/renderer/mesh/shader.py +++ b/pytorch3d/renderer/mesh/shader.py @@ -324,6 +324,11 @@ class SplatterPhongShader(ShaderBase): self.splatter_blender = None super().__init__(**kwargs) + def to(self, device: Device): + if self.splatter_blender: + self.splatter_blender.to(device) + return super().to(device) + def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor: cameras = super()._get_cameras(**kwargs) texels = meshes.sample_textures(fragments) @@ -349,7 +354,7 @@ class SplatterPhongShader(ShaderBase): pixel_coords_cameras, cameras, fragments.pix_to_face < 0, - self.blend_params, + kwargs.get("blend_params", self.blend_params), ) return images @@ -398,6 +403,9 @@ class SoftDepthShader(ShaderBase): """ def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor: + if fragments.dists is None: + raise ValueError("SoftDepthShader requires Fragments.dists to be present.") + cameras = super()._get_cameras(**kwargs) N, H, W, K = fragments.pix_to_face.shape diff --git a/pytorch3d/renderer/splatter_blend.py b/pytorch3d/renderer/splatter_blend.py index 26107663..0149dfb3 100644 --- a/pytorch3d/renderer/splatter_blend.py +++ b/pytorch3d/renderer/splatter_blend.py @@ -464,6 +464,12 @@ class SplatterBlender(torch.nn.Module): input_shape, device ) + def to(self, device): + self.offsets = self.offsets.to(device) + self.crop_ids_h = self.crop_ids_h.to(device) + self.crop_ids_w = self.crop_ids_w.to(device) + super().to(device) + def forward( self, colors: torch.Tensor, diff --git a/tests/test_shader.py b/tests/test_shader.py index f4ac8116..396132e1 100644 --- a/tests/test_shader.py +++ b/tests/test_shader.py @@ -60,7 +60,7 @@ class TestShader(TestCaseMixin, unittest.TestCase): self.assertIs(cpu_shader, cuda_shader) if cameras is None: self.assertIsNone(cuda_shader.cameras) - with self.assertRaisesRegexp(ValueError, "Cameras must be"): + with self.assertRaisesRegex(ValueError, "Cameras must be"): cuda_shader._get_cameras() else: self.assertEqual(cuda_device, cuda_shader.cameras.device)