diff --git a/pytorch3d/implicitron/models/renderer/ray_sampler.py b/pytorch3d/implicitron/models/renderer/ray_sampler.py index b876d906..76f9f5bc 100644 --- a/pytorch3d/implicitron/models/renderer/ray_sampler.py +++ b/pytorch3d/implicitron/models/renderer/ray_sampler.py @@ -100,34 +100,32 @@ class AbstractMaskRaySampler(RaySamplerBase, torch.nn.Module): ), } - self._raysamplers = { - EvaluationMode.TRAINING: NDCMultinomialRaysampler( - image_width=self.image_width, - image_height=self.image_height, - n_pts_per_ray=self.n_pts_per_ray_training, - min_depth=0.0, - max_depth=0.0, - n_rays_per_image=self.n_rays_per_image_sampled_from_mask - if self._sampling_mode[EvaluationMode.TRAINING] - == RenderSamplingMode.MASK_SAMPLE - else None, - unit_directions=True, - stratified_sampling=self.stratified_point_sampling_training, - ), - EvaluationMode.EVALUATION: NDCMultinomialRaysampler( - image_width=self.image_width, - image_height=self.image_height, - n_pts_per_ray=self.n_pts_per_ray_evaluation, - min_depth=0.0, - max_depth=0.0, - n_rays_per_image=self.n_rays_per_image_sampled_from_mask - if self._sampling_mode[EvaluationMode.EVALUATION] - == RenderSamplingMode.MASK_SAMPLE - else None, - unit_directions=True, - stratified_sampling=self.stratified_point_sampling_evaluation, - ), - } + self._training_raysampler = NDCMultinomialRaysampler( + image_width=self.image_width, + image_height=self.image_height, + n_pts_per_ray=self.n_pts_per_ray_training, + min_depth=0.0, + max_depth=0.0, + n_rays_per_image=self.n_rays_per_image_sampled_from_mask + if self._sampling_mode[EvaluationMode.TRAINING] + == RenderSamplingMode.MASK_SAMPLE + else None, + unit_directions=True, + stratified_sampling=self.stratified_point_sampling_training, + ) + self._evaluation_raysampler = NDCMultinomialRaysampler( + image_width=self.image_width, + image_height=self.image_height, + n_pts_per_ray=self.n_pts_per_ray_evaluation, + min_depth=0.0, + max_depth=0.0, + n_rays_per_image=self.n_rays_per_image_sampled_from_mask + if self._sampling_mode[EvaluationMode.EVALUATION] + == RenderSamplingMode.MASK_SAMPLE + else None, + unit_directions=True, + stratified_sampling=self.stratified_point_sampling_evaluation, + ) def _get_min_max_depth_bounds(self, cameras: CamerasBase) -> Tuple[float, float]: raise NotImplementedError() @@ -169,11 +167,13 @@ class AbstractMaskRaySampler(RaySamplerBase, torch.nn.Module): min_depth, max_depth = self._get_min_max_depth_bounds(cameras) + raysampler = { + EvaluationMode.TRAINING: self._training_raysampler, + EvaluationMode.EVALUATION: self._evaluation_raysampler, + }[evaluation_mode] + # pyre-fixme[29]: - # `Union[BoundMethod[typing.Callable(torch.Tensor.__getitem__)[[Named(self, - # torch.Tensor), Named(item, typing.Any)], typing.Any], torch.Tensor], - # torch.Tensor, torch.nn.Module]` is not a function. - ray_bundle = self._raysamplers[evaluation_mode]( + ray_bundle = raysampler( cameras=cameras, mask=sample_mask, min_depth=min_depth,