Avoid raysampler dict

Summary:
A significant speedup (e.g. >2% of a forward pass).

Move NDCMultinomialRaysampler parts of AbstractMaskRaySampler to members instead of living in a dict. The dict was hiding them from the nn.Module system so their _xy_grid members were remaining on the CPU. Therefore they were being copied to the GPU in every forward pass.

(We couldn't easily use a ModuleDict here because the enum keys are not strs.)

Reviewed By: shapovalov

Differential Revision: D39668589

fbshipit-source-id: 719b88e4a08fd7263a284e0ab38189e666bd7e3a
This commit is contained in:
Jeremy Reizenstein 2022-09-21 04:29:44 -07:00 committed by Facebook GitHub Bot
parent da7fe2854e
commit 305cf32f6b

View File

@ -100,8 +100,7 @@ class AbstractMaskRaySampler(RaySamplerBase, torch.nn.Module):
), ),
} }
self._raysamplers = { self._training_raysampler = NDCMultinomialRaysampler(
EvaluationMode.TRAINING: NDCMultinomialRaysampler(
image_width=self.image_width, image_width=self.image_width,
image_height=self.image_height, image_height=self.image_height,
n_pts_per_ray=self.n_pts_per_ray_training, n_pts_per_ray=self.n_pts_per_ray_training,
@ -113,8 +112,8 @@ class AbstractMaskRaySampler(RaySamplerBase, torch.nn.Module):
else None, else None,
unit_directions=True, unit_directions=True,
stratified_sampling=self.stratified_point_sampling_training, stratified_sampling=self.stratified_point_sampling_training,
), )
EvaluationMode.EVALUATION: NDCMultinomialRaysampler( self._evaluation_raysampler = NDCMultinomialRaysampler(
image_width=self.image_width, image_width=self.image_width,
image_height=self.image_height, image_height=self.image_height,
n_pts_per_ray=self.n_pts_per_ray_evaluation, n_pts_per_ray=self.n_pts_per_ray_evaluation,
@ -126,8 +125,7 @@ class AbstractMaskRaySampler(RaySamplerBase, torch.nn.Module):
else None, else None,
unit_directions=True, unit_directions=True,
stratified_sampling=self.stratified_point_sampling_evaluation, stratified_sampling=self.stratified_point_sampling_evaluation,
), )
}
def _get_min_max_depth_bounds(self, cameras: CamerasBase) -> Tuple[float, float]: def _get_min_max_depth_bounds(self, cameras: CamerasBase) -> Tuple[float, float]:
raise NotImplementedError() raise NotImplementedError()
@ -169,11 +167,13 @@ class AbstractMaskRaySampler(RaySamplerBase, torch.nn.Module):
min_depth, max_depth = self._get_min_max_depth_bounds(cameras) min_depth, max_depth = self._get_min_max_depth_bounds(cameras)
raysampler = {
EvaluationMode.TRAINING: self._training_raysampler,
EvaluationMode.EVALUATION: self._evaluation_raysampler,
}[evaluation_mode]
# pyre-fixme[29]: # pyre-fixme[29]:
# `Union[BoundMethod[typing.Callable(torch.Tensor.__getitem__)[[Named(self, ray_bundle = raysampler(
# torch.Tensor), Named(item, typing.Any)], typing.Any], torch.Tensor],
# torch.Tensor, torch.nn.Module]` is not a function.
ray_bundle = self._raysamplers[evaluation_mode](
cameras=cameras, cameras=cameras,
mask=sample_mask, mask=sample_mask,
min_depth=min_depth, min_depth=min_depth,