mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 20:02:49 +08:00
Avoid raysampler dict
Summary: A significant speedup (e.g. >2% of a forward pass). Move NDCMultinomialRaysampler parts of AbstractMaskRaySampler to members instead of living in a dict. The dict was hiding them from the nn.Module system so their _xy_grid members were remaining on the CPU. Therefore they were being copied to the GPU in every forward pass. (We couldn't easily use a ModuleDict here because the enum keys are not strs.) Reviewed By: shapovalov Differential Revision: D39668589 fbshipit-source-id: 719b88e4a08fd7263a284e0ab38189e666bd7e3a
This commit is contained in:
parent
da7fe2854e
commit
305cf32f6b
@ -100,8 +100,7 @@ class AbstractMaskRaySampler(RaySamplerBase, torch.nn.Module):
|
|||||||
),
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
self._raysamplers = {
|
self._training_raysampler = NDCMultinomialRaysampler(
|
||||||
EvaluationMode.TRAINING: NDCMultinomialRaysampler(
|
|
||||||
image_width=self.image_width,
|
image_width=self.image_width,
|
||||||
image_height=self.image_height,
|
image_height=self.image_height,
|
||||||
n_pts_per_ray=self.n_pts_per_ray_training,
|
n_pts_per_ray=self.n_pts_per_ray_training,
|
||||||
@ -113,8 +112,8 @@ class AbstractMaskRaySampler(RaySamplerBase, torch.nn.Module):
|
|||||||
else None,
|
else None,
|
||||||
unit_directions=True,
|
unit_directions=True,
|
||||||
stratified_sampling=self.stratified_point_sampling_training,
|
stratified_sampling=self.stratified_point_sampling_training,
|
||||||
),
|
)
|
||||||
EvaluationMode.EVALUATION: NDCMultinomialRaysampler(
|
self._evaluation_raysampler = NDCMultinomialRaysampler(
|
||||||
image_width=self.image_width,
|
image_width=self.image_width,
|
||||||
image_height=self.image_height,
|
image_height=self.image_height,
|
||||||
n_pts_per_ray=self.n_pts_per_ray_evaluation,
|
n_pts_per_ray=self.n_pts_per_ray_evaluation,
|
||||||
@ -126,8 +125,7 @@ class AbstractMaskRaySampler(RaySamplerBase, torch.nn.Module):
|
|||||||
else None,
|
else None,
|
||||||
unit_directions=True,
|
unit_directions=True,
|
||||||
stratified_sampling=self.stratified_point_sampling_evaluation,
|
stratified_sampling=self.stratified_point_sampling_evaluation,
|
||||||
),
|
)
|
||||||
}
|
|
||||||
|
|
||||||
def _get_min_max_depth_bounds(self, cameras: CamerasBase) -> Tuple[float, float]:
|
def _get_min_max_depth_bounds(self, cameras: CamerasBase) -> Tuple[float, float]:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
@ -169,11 +167,13 @@ class AbstractMaskRaySampler(RaySamplerBase, torch.nn.Module):
|
|||||||
|
|
||||||
min_depth, max_depth = self._get_min_max_depth_bounds(cameras)
|
min_depth, max_depth = self._get_min_max_depth_bounds(cameras)
|
||||||
|
|
||||||
|
raysampler = {
|
||||||
|
EvaluationMode.TRAINING: self._training_raysampler,
|
||||||
|
EvaluationMode.EVALUATION: self._evaluation_raysampler,
|
||||||
|
}[evaluation_mode]
|
||||||
|
|
||||||
# pyre-fixme[29]:
|
# pyre-fixme[29]:
|
||||||
# `Union[BoundMethod[typing.Callable(torch.Tensor.__getitem__)[[Named(self,
|
ray_bundle = raysampler(
|
||||||
# torch.Tensor), Named(item, typing.Any)], typing.Any], torch.Tensor],
|
|
||||||
# torch.Tensor, torch.nn.Module]` is not a function.
|
|
||||||
ray_bundle = self._raysamplers[evaluation_mode](
|
|
||||||
cameras=cameras,
|
cameras=cameras,
|
||||||
mask=sample_mask,
|
mask=sample_mask,
|
||||||
min_depth=min_depth,
|
min_depth=min_depth,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user