mirror of
				https://github.com/facebookresearch/pytorch3d.git
				synced 2025-11-04 18:02:14 +08:00 
			
		
		
		
	Avoid raysampler dict
Summary: This gives a significant speedup (e.g. >2% of a forward pass). Move the NDCMultinomialRaysampler parts of AbstractMaskRaySampler into member attributes instead of keeping them in a dict. The dict was hiding them from the nn.Module system, so their _xy_grid members remained on the CPU and were therefore copied to the GPU in every forward pass. (We couldn't easily use a ModuleDict here because the enum keys are not strs.) Reviewed By: shapovalov Differential Revision: D39668589 fbshipit-source-id: 719b88e4a08fd7263a284e0ab38189e666bd7e3a
This commit is contained in:
		
							parent
							
								
									da7fe2854e
								
							
						
					
					
						commit
						305cf32f6b
					
				@ -100,34 +100,32 @@ class AbstractMaskRaySampler(RaySamplerBase, torch.nn.Module):
 | 
			
		||||
            ),
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        self._raysamplers = {
 | 
			
		||||
            EvaluationMode.TRAINING: NDCMultinomialRaysampler(
 | 
			
		||||
                image_width=self.image_width,
 | 
			
		||||
                image_height=self.image_height,
 | 
			
		||||
                n_pts_per_ray=self.n_pts_per_ray_training,
 | 
			
		||||
                min_depth=0.0,
 | 
			
		||||
                max_depth=0.0,
 | 
			
		||||
                n_rays_per_image=self.n_rays_per_image_sampled_from_mask
 | 
			
		||||
                if self._sampling_mode[EvaluationMode.TRAINING]
 | 
			
		||||
                == RenderSamplingMode.MASK_SAMPLE
 | 
			
		||||
                else None,
 | 
			
		||||
                unit_directions=True,
 | 
			
		||||
                stratified_sampling=self.stratified_point_sampling_training,
 | 
			
		||||
            ),
 | 
			
		||||
            EvaluationMode.EVALUATION: NDCMultinomialRaysampler(
 | 
			
		||||
                image_width=self.image_width,
 | 
			
		||||
                image_height=self.image_height,
 | 
			
		||||
                n_pts_per_ray=self.n_pts_per_ray_evaluation,
 | 
			
		||||
                min_depth=0.0,
 | 
			
		||||
                max_depth=0.0,
 | 
			
		||||
                n_rays_per_image=self.n_rays_per_image_sampled_from_mask
 | 
			
		||||
                if self._sampling_mode[EvaluationMode.EVALUATION]
 | 
			
		||||
                == RenderSamplingMode.MASK_SAMPLE
 | 
			
		||||
                else None,
 | 
			
		||||
                unit_directions=True,
 | 
			
		||||
                stratified_sampling=self.stratified_point_sampling_evaluation,
 | 
			
		||||
            ),
 | 
			
		||||
        }
 | 
			
		||||
        self._training_raysampler = NDCMultinomialRaysampler(
 | 
			
		||||
            image_width=self.image_width,
 | 
			
		||||
            image_height=self.image_height,
 | 
			
		||||
            n_pts_per_ray=self.n_pts_per_ray_training,
 | 
			
		||||
            min_depth=0.0,
 | 
			
		||||
            max_depth=0.0,
 | 
			
		||||
            n_rays_per_image=self.n_rays_per_image_sampled_from_mask
 | 
			
		||||
            if self._sampling_mode[EvaluationMode.TRAINING]
 | 
			
		||||
            == RenderSamplingMode.MASK_SAMPLE
 | 
			
		||||
            else None,
 | 
			
		||||
            unit_directions=True,
 | 
			
		||||
            stratified_sampling=self.stratified_point_sampling_training,
 | 
			
		||||
        )
 | 
			
		||||
        self._evaluation_raysampler = NDCMultinomialRaysampler(
 | 
			
		||||
            image_width=self.image_width,
 | 
			
		||||
            image_height=self.image_height,
 | 
			
		||||
            n_pts_per_ray=self.n_pts_per_ray_evaluation,
 | 
			
		||||
            min_depth=0.0,
 | 
			
		||||
            max_depth=0.0,
 | 
			
		||||
            n_rays_per_image=self.n_rays_per_image_sampled_from_mask
 | 
			
		||||
            if self._sampling_mode[EvaluationMode.EVALUATION]
 | 
			
		||||
            == RenderSamplingMode.MASK_SAMPLE
 | 
			
		||||
            else None,
 | 
			
		||||
            unit_directions=True,
 | 
			
		||||
            stratified_sampling=self.stratified_point_sampling_evaluation,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def _get_min_max_depth_bounds(self, cameras: CamerasBase) -> Tuple[float, float]:
 | 
			
		||||
        raise NotImplementedError()
 | 
			
		||||
@ -169,11 +167,13 @@ class AbstractMaskRaySampler(RaySamplerBase, torch.nn.Module):
 | 
			
		||||
 | 
			
		||||
        min_depth, max_depth = self._get_min_max_depth_bounds(cameras)
 | 
			
		||||
 | 
			
		||||
        raysampler = {
 | 
			
		||||
            EvaluationMode.TRAINING: self._training_raysampler,
 | 
			
		||||
            EvaluationMode.EVALUATION: self._evaluation_raysampler,
 | 
			
		||||
        }[evaluation_mode]
 | 
			
		||||
 | 
			
		||||
        # pyre-fixme[29]:
 | 
			
		||||
        #  `Union[BoundMethod[typing.Callable(torch.Tensor.__getitem__)[[Named(self,
 | 
			
		||||
        #  torch.Tensor), Named(item, typing.Any)], typing.Any], torch.Tensor],
 | 
			
		||||
        #  torch.Tensor, torch.nn.Module]` is not a function.
 | 
			
		||||
        ray_bundle = self._raysamplers[evaluation_mode](
 | 
			
		||||
        ray_bundle = raysampler(
 | 
			
		||||
            cameras=cameras,
 | 
			
		||||
            mask=sample_mask,
 | 
			
		||||
            min_depth=min_depth,
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user