mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
Avoid raysampler dict
Summary: A significant speedup (e.g. >2% of a forward pass). Move NDCMultinomialRaysampler parts of AbstractMaskRaySampler to members instead of living in a dict. The dict was hiding them from the nn.Module system so their _xy_grid members were remaining on the CPU. Therefore they were being copied to the GPU in every forward pass. (We couldn't easily use a ModuleDict here because the enum keys are not strs.) Reviewed By: shapovalov Differential Revision: D39668589 fbshipit-source-id: 719b88e4a08fd7263a284e0ab38189e666bd7e3a
This commit is contained in:
parent
da7fe2854e
commit
305cf32f6b
@ -100,34 +100,32 @@ class AbstractMaskRaySampler(RaySamplerBase, torch.nn.Module):
|
||||
),
|
||||
}
|
||||
|
||||
self._raysamplers = {
|
||||
EvaluationMode.TRAINING: NDCMultinomialRaysampler(
|
||||
image_width=self.image_width,
|
||||
image_height=self.image_height,
|
||||
n_pts_per_ray=self.n_pts_per_ray_training,
|
||||
min_depth=0.0,
|
||||
max_depth=0.0,
|
||||
n_rays_per_image=self.n_rays_per_image_sampled_from_mask
|
||||
if self._sampling_mode[EvaluationMode.TRAINING]
|
||||
== RenderSamplingMode.MASK_SAMPLE
|
||||
else None,
|
||||
unit_directions=True,
|
||||
stratified_sampling=self.stratified_point_sampling_training,
|
||||
),
|
||||
EvaluationMode.EVALUATION: NDCMultinomialRaysampler(
|
||||
image_width=self.image_width,
|
||||
image_height=self.image_height,
|
||||
n_pts_per_ray=self.n_pts_per_ray_evaluation,
|
||||
min_depth=0.0,
|
||||
max_depth=0.0,
|
||||
n_rays_per_image=self.n_rays_per_image_sampled_from_mask
|
||||
if self._sampling_mode[EvaluationMode.EVALUATION]
|
||||
== RenderSamplingMode.MASK_SAMPLE
|
||||
else None,
|
||||
unit_directions=True,
|
||||
stratified_sampling=self.stratified_point_sampling_evaluation,
|
||||
),
|
||||
}
|
||||
self._training_raysampler = NDCMultinomialRaysampler(
|
||||
image_width=self.image_width,
|
||||
image_height=self.image_height,
|
||||
n_pts_per_ray=self.n_pts_per_ray_training,
|
||||
min_depth=0.0,
|
||||
max_depth=0.0,
|
||||
n_rays_per_image=self.n_rays_per_image_sampled_from_mask
|
||||
if self._sampling_mode[EvaluationMode.TRAINING]
|
||||
== RenderSamplingMode.MASK_SAMPLE
|
||||
else None,
|
||||
unit_directions=True,
|
||||
stratified_sampling=self.stratified_point_sampling_training,
|
||||
)
|
||||
self._evaluation_raysampler = NDCMultinomialRaysampler(
|
||||
image_width=self.image_width,
|
||||
image_height=self.image_height,
|
||||
n_pts_per_ray=self.n_pts_per_ray_evaluation,
|
||||
min_depth=0.0,
|
||||
max_depth=0.0,
|
||||
n_rays_per_image=self.n_rays_per_image_sampled_from_mask
|
||||
if self._sampling_mode[EvaluationMode.EVALUATION]
|
||||
== RenderSamplingMode.MASK_SAMPLE
|
||||
else None,
|
||||
unit_directions=True,
|
||||
stratified_sampling=self.stratified_point_sampling_evaluation,
|
||||
)
|
||||
|
||||
def _get_min_max_depth_bounds(self, cameras: CamerasBase) -> Tuple[float, float]:
|
||||
raise NotImplementedError()
|
||||
@ -169,11 +167,13 @@ class AbstractMaskRaySampler(RaySamplerBase, torch.nn.Module):
|
||||
|
||||
min_depth, max_depth = self._get_min_max_depth_bounds(cameras)
|
||||
|
||||
raysampler = {
|
||||
EvaluationMode.TRAINING: self._training_raysampler,
|
||||
EvaluationMode.EVALUATION: self._evaluation_raysampler,
|
||||
}[evaluation_mode]
|
||||
|
||||
# pyre-fixme[29]:
|
||||
# `Union[BoundMethod[typing.Callable(torch.Tensor.__getitem__)[[Named(self,
|
||||
# torch.Tensor), Named(item, typing.Any)], typing.Any], torch.Tensor],
|
||||
# torch.Tensor, torch.nn.Module]` is not a function.
|
||||
ray_bundle = self._raysamplers[evaluation_mode](
|
||||
ray_bundle = raysampler(
|
||||
cameras=cameras,
|
||||
mask=sample_mask,
|
||||
min_depth=min_depth,
|
||||
|
Loading…
x
Reference in New Issue
Block a user