Avoid raysampler dict

Summary:
A significant speedup (e.g. >2% of a forward pass).

Make the NDCMultinomialRaysampler parts of AbstractMaskRaySampler direct members instead of storing them in a dict. The dict was hiding them from the nn.Module system, so their _xy_grid members remained on the CPU and were therefore copied to the GPU in every forward pass.

(We couldn't easily use a ModuleDict here because the enum keys are not strs.)

Reviewed By: shapovalov

Differential Revision: D39668589

fbshipit-source-id: 719b88e4a08fd7263a284e0ab38189e666bd7e3a
This commit is contained in:
Jeremy Reizenstein 2022-09-21 04:29:44 -07:00 committed by Facebook GitHub Bot
parent da7fe2854e
commit 305cf32f6b

View File

@ -100,34 +100,32 @@ class AbstractMaskRaySampler(RaySamplerBase, torch.nn.Module):
),
}
self._raysamplers = {
EvaluationMode.TRAINING: NDCMultinomialRaysampler(
image_width=self.image_width,
image_height=self.image_height,
n_pts_per_ray=self.n_pts_per_ray_training,
min_depth=0.0,
max_depth=0.0,
n_rays_per_image=self.n_rays_per_image_sampled_from_mask
if self._sampling_mode[EvaluationMode.TRAINING]
== RenderSamplingMode.MASK_SAMPLE
else None,
unit_directions=True,
stratified_sampling=self.stratified_point_sampling_training,
),
EvaluationMode.EVALUATION: NDCMultinomialRaysampler(
image_width=self.image_width,
image_height=self.image_height,
n_pts_per_ray=self.n_pts_per_ray_evaluation,
min_depth=0.0,
max_depth=0.0,
n_rays_per_image=self.n_rays_per_image_sampled_from_mask
if self._sampling_mode[EvaluationMode.EVALUATION]
== RenderSamplingMode.MASK_SAMPLE
else None,
unit_directions=True,
stratified_sampling=self.stratified_point_sampling_evaluation,
),
}
self._training_raysampler = NDCMultinomialRaysampler(
image_width=self.image_width,
image_height=self.image_height,
n_pts_per_ray=self.n_pts_per_ray_training,
min_depth=0.0,
max_depth=0.0,
n_rays_per_image=self.n_rays_per_image_sampled_from_mask
if self._sampling_mode[EvaluationMode.TRAINING]
== RenderSamplingMode.MASK_SAMPLE
else None,
unit_directions=True,
stratified_sampling=self.stratified_point_sampling_training,
)
self._evaluation_raysampler = NDCMultinomialRaysampler(
image_width=self.image_width,
image_height=self.image_height,
n_pts_per_ray=self.n_pts_per_ray_evaluation,
min_depth=0.0,
max_depth=0.0,
n_rays_per_image=self.n_rays_per_image_sampled_from_mask
if self._sampling_mode[EvaluationMode.EVALUATION]
== RenderSamplingMode.MASK_SAMPLE
else None,
unit_directions=True,
stratified_sampling=self.stratified_point_sampling_evaluation,
)
# Return the (min_depth, max_depth) pair of depth bounds for `cameras`.
# This implementation always raises NotImplementedError; presumably concrete
# subclasses are expected to override it -- confirm against the full file.
# The caller visible in this diff unpacks the result into min_depth/max_depth
# before invoking the selected raysampler.
def _get_min_max_depth_bounds(self, cameras: CamerasBase) -> Tuple[float, float]:
raise NotImplementedError()
@ -169,11 +167,13 @@ class AbstractMaskRaySampler(RaySamplerBase, torch.nn.Module):
min_depth, max_depth = self._get_min_max_depth_bounds(cameras)
raysampler = {
EvaluationMode.TRAINING: self._training_raysampler,
EvaluationMode.EVALUATION: self._evaluation_raysampler,
}[evaluation_mode]
# pyre-fixme[29]:
# `Union[BoundMethod[typing.Callable(torch.Tensor.__getitem__)[[Named(self,
# torch.Tensor), Named(item, typing.Any)], typing.Any], torch.Tensor],
# torch.Tensor, torch.nn.Module]` is not a function.
ray_bundle = self._raysamplers[evaluation_mode](
ray_bundle = raysampler(
cameras=cameras,
mask=sample_mask,
min_depth=min_depth,