mirror of
				https://github.com/facebookresearch/pytorch3d.git
				synced 2025-11-04 18:02:14 +08:00 
			
		
		
		
	Add .to methods to the splatter and SplatterPhongShader.
Summary: Needed to properly change devices during OpenGL rasterization. Reviewed By: jcjohnson Differential Revision: D37698568 fbshipit-source-id: 38968149d577322e662d3b5d04880204b0a7be29
This commit is contained in:
		
							parent
							
								
									78bb6d17fa
								
							
						
					
					
						commit
						36edf2b302
					
				@ -324,6 +324,11 @@ class SplatterPhongShader(ShaderBase):
 | 
			
		||||
        self.splatter_blender = None
 | 
			
		||||
        super().__init__(**kwargs)
 | 
			
		||||
 | 
			
		||||
    def to(self, device: Device):
 | 
			
		||||
        if self.splatter_blender:
 | 
			
		||||
            self.splatter_blender.to(device)
 | 
			
		||||
        return super().to(device)
 | 
			
		||||
 | 
			
		||||
    def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor:
 | 
			
		||||
        cameras = super()._get_cameras(**kwargs)
 | 
			
		||||
        texels = meshes.sample_textures(fragments)
 | 
			
		||||
@ -349,7 +354,7 @@ class SplatterPhongShader(ShaderBase):
 | 
			
		||||
            pixel_coords_cameras,
 | 
			
		||||
            cameras,
 | 
			
		||||
            fragments.pix_to_face < 0,
 | 
			
		||||
            self.blend_params,
 | 
			
		||||
            kwargs.get("blend_params", self.blend_params),
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        return images
 | 
			
		||||
@ -398,6 +403,9 @@ class SoftDepthShader(ShaderBase):
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor:
 | 
			
		||||
        if fragments.dists is None:
 | 
			
		||||
            raise ValueError("SoftDepthShader requires Fragments.dists to be present.")
 | 
			
		||||
 | 
			
		||||
        cameras = super()._get_cameras(**kwargs)
 | 
			
		||||
 | 
			
		||||
        N, H, W, K = fragments.pix_to_face.shape
 | 
			
		||||
 | 
			
		||||
@ -464,6 +464,12 @@ class SplatterBlender(torch.nn.Module):
 | 
			
		||||
            input_shape, device
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def to(self, device):
 | 
			
		||||
        self.offsets = self.offsets.to(device)
 | 
			
		||||
        self.crop_ids_h = self.crop_ids_h.to(device)
 | 
			
		||||
        self.crop_ids_w = self.crop_ids_w.to(device)
 | 
			
		||||
        super().to(device)
 | 
			
		||||
 | 
			
		||||
    def forward(
 | 
			
		||||
        self,
 | 
			
		||||
        colors: torch.Tensor,
 | 
			
		||||
 | 
			
		||||
@ -60,7 +60,7 @@ class TestShader(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
                self.assertIs(cpu_shader, cuda_shader)
 | 
			
		||||
                if cameras is None:
 | 
			
		||||
                    self.assertIsNone(cuda_shader.cameras)
 | 
			
		||||
                    with self.assertRaisesRegexp(ValueError, "Cameras must be"):
 | 
			
		||||
                    with self.assertRaisesRegex(ValueError, "Cameras must be"):
 | 
			
		||||
                        cuda_shader._get_cameras()
 | 
			
		||||
                else:
 | 
			
		||||
                    self.assertEqual(cuda_device, cuda_shader.cameras.device)
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user