mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-01 03:12:49 +08:00
Add .to methods to the splatter and SplatterPhongShader.
Summary: Needed to properly change devices during OpenGL rasterization. Reviewed By: jcjohnson Differential Revision: D37698568 fbshipit-source-id: 38968149d577322e662d3b5d04880204b0a7be29
This commit is contained in:
parent
78bb6d17fa
commit
36edf2b302
@ -324,6 +324,11 @@ class SplatterPhongShader(ShaderBase):
|
|||||||
self.splatter_blender = None
|
self.splatter_blender = None
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
|
def to(self, device: Device):
|
||||||
|
if self.splatter_blender:
|
||||||
|
self.splatter_blender.to(device)
|
||||||
|
return super().to(device)
|
||||||
|
|
||||||
def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor:
|
def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor:
|
||||||
cameras = super()._get_cameras(**kwargs)
|
cameras = super()._get_cameras(**kwargs)
|
||||||
texels = meshes.sample_textures(fragments)
|
texels = meshes.sample_textures(fragments)
|
||||||
@ -349,7 +354,7 @@ class SplatterPhongShader(ShaderBase):
|
|||||||
pixel_coords_cameras,
|
pixel_coords_cameras,
|
||||||
cameras,
|
cameras,
|
||||||
fragments.pix_to_face < 0,
|
fragments.pix_to_face < 0,
|
||||||
self.blend_params,
|
kwargs.get("blend_params", self.blend_params),
|
||||||
)
|
)
|
||||||
|
|
||||||
return images
|
return images
|
||||||
@ -398,6 +403,9 @@ class SoftDepthShader(ShaderBase):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor:
|
def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor:
|
||||||
|
if fragments.dists is None:
|
||||||
|
raise ValueError("SoftDepthShader requires Fragments.dists to be present.")
|
||||||
|
|
||||||
cameras = super()._get_cameras(**kwargs)
|
cameras = super()._get_cameras(**kwargs)
|
||||||
|
|
||||||
N, H, W, K = fragments.pix_to_face.shape
|
N, H, W, K = fragments.pix_to_face.shape
|
||||||
|
@ -464,6 +464,12 @@ class SplatterBlender(torch.nn.Module):
|
|||||||
input_shape, device
|
input_shape, device
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def to(self, device):
|
||||||
|
self.offsets = self.offsets.to(device)
|
||||||
|
self.crop_ids_h = self.crop_ids_h.to(device)
|
||||||
|
self.crop_ids_w = self.crop_ids_w.to(device)
|
||||||
|
super().to(device)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
colors: torch.Tensor,
|
colors: torch.Tensor,
|
||||||
|
@ -60,7 +60,7 @@ class TestShader(TestCaseMixin, unittest.TestCase):
|
|||||||
self.assertIs(cpu_shader, cuda_shader)
|
self.assertIs(cpu_shader, cuda_shader)
|
||||||
if cameras is None:
|
if cameras is None:
|
||||||
self.assertIsNone(cuda_shader.cameras)
|
self.assertIsNone(cuda_shader.cameras)
|
||||||
with self.assertRaisesRegexp(ValueError, "Cameras must be"):
|
with self.assertRaisesRegex(ValueError, "Cameras must be"):
|
||||||
cuda_shader._get_cameras()
|
cuda_shader._get_cameras()
|
||||||
else:
|
else:
|
||||||
self.assertEqual(cuda_device, cuda_shader.cameras.device)
|
self.assertEqual(cuda_device, cuda_shader.cameras.device)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user