From c311a4cbb93be458f8f48e7b269c6d3ee7fc2cf4 Mon Sep 17 00:00:00 2001
From: Darijan Gudelj
Date: Mon, 3 Oct 2022 08:36:47 -0700
Subject: [PATCH] Enable mixed frame raysampling

Summary:
Changed ray_sampler and metrics to be able to use mixed frame raysampling.
Ray_sampler now has a new member which it passes to the pytorch3d raysampler.
If the ray bundle is heterogeneous, metrics now samples images by padding the
xys first. This reduces memory consumption.

Reviewed By: bottler, kjchalup

Differential Revision: D39542221

fbshipit-source-id: a6fec23838d3049ae5c2fd2e1f641c46c7c927e3
---
 projects/implicitron_trainer/experiment.py    |  1 +
 .../implicitron_trainer/tests/experiment.yaml |  2 +
 pytorch3d/implicitron/models/generic_model.py |  9 ++-
 pytorch3d/implicitron/models/metrics.py       | 60 ++++++++++++++-----
 .../models/renderer/ray_point_refiner.py      | 20 +++----
 .../models/renderer/ray_sampler.py            | 25 +++++++-
 pytorch3d/renderer/implicit/raysampling.py    | 19 +++---
 tests/implicitron/data/overrides.yaml         |  1 +
 8 files changed, 102 insertions(+), 35 deletions(-)

diff --git a/projects/implicitron_trainer/experiment.py b/projects/implicitron_trainer/experiment.py
index a033251a..7795a18e 100755
--- a/projects/implicitron_trainer/experiment.py
+++ b/projects/implicitron_trainer/experiment.py
@@ -222,6 +222,7 @@ class Experiment(Configurable):  # pyre-ignore: 13
             train_loader=train_loader,
             val_loader=val_loader,
             test_loader=test_loader,
+            train_dataset=datasets.train,
             model=model,
             optimizer=optimizer,
             scheduler=scheduler,
diff --git a/projects/implicitron_trainer/tests/experiment.yaml b/projects/implicitron_trainer/tests/experiment.yaml
index bd52beac..d6b6beed 100644
--- a/projects/implicitron_trainer/tests/experiment.yaml
+++ b/projects/implicitron_trainer/tests/experiment.yaml
@@ -197,6 +197,7 @@ model_factory_ImplicitronModelFactory_args:
       n_pts_per_ray_training: 64
       n_pts_per_ray_evaluation: 64
       n_rays_per_image_sampled_from_mask: 1024
+      n_rays_total_training: null
       stratified_point_sampling_training: true
       stratified_point_sampling_evaluation: false
       scene_extent: 8.0
@@ -208,6 +209,7 @@ model_factory_ImplicitronModelFactory_args:
       n_pts_per_ray_training: 64
       n_pts_per_ray_evaluation: 64
       n_rays_per_image_sampled_from_mask: 1024
+      n_rays_total_training: null
       stratified_point_sampling_training: true
       stratified_point_sampling_evaluation: false
       min_depth: 0.1
diff --git a/pytorch3d/implicitron/models/generic_model.py b/pytorch3d/implicitron/models/generic_model.py
index 853e84ef..d2a4248c 100644
--- a/pytorch3d/implicitron/models/generic_model.py
+++ b/pytorch3d/implicitron/models/generic_model.py
@@ -473,7 +473,7 @@ class GenericModel(ImplicitronModelBase):  # pyre-ignore: 13
         self.view_metrics(
             results=preds,
             raymarched=rendered,
-            xys=ray_bundle.xys,
+            ray_bundle=ray_bundle,
             image_rgb=safe_slice_targets(image_rgb),
             depth_map=safe_slice_targets(depth_map),
             fg_probability=safe_slice_targets(fg_probability),
@@ -932,6 +932,11 @@ def _chunk_generator(
     if len(iter) >= tqdm_trigger_threshold:
         iter = tqdm.tqdm(iter)

+    def _safe_slice(
+        tensor: Optional[torch.Tensor], start_idx: int, end_idx: int
+    ) -> Optional[torch.Tensor]:
+        return tensor[start_idx:end_idx] if tensor is not None else None
+
     for start_idx in iter:
         end_idx = min(start_idx + chunk_size_in_rays, n_rays)
         ray_bundle_chunk = ImplicitronRayBundle(
                 :, start_idx:end_idx
             ],
             xys=ray_bundle.xys.reshape(batch_size, -1, 2)[:, start_idx:end_idx],
+            camera_ids=_safe_slice(ray_bundle.camera_ids, start_idx, end_idx),
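+            # Like camera_ids above, camera_counts is None unless the bundle is
+            # packed (heterogeneous); _safe_slice passes None through unchanged.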
+            camera_counts=_safe_slice(ray_bundle.camera_counts, start_idx, end_idx),
         )
         extra_args = kwargs.copy()
         for k, v in chunked_inputs.items():
diff --git a/pytorch3d/implicitron/models/metrics.py b/pytorch3d/implicitron/models/metrics.py
index 63a19724..cc44a518 100644
--- a/pytorch3d/implicitron/models/metrics.py
+++ b/pytorch3d/implicitron/models/metrics.py
@@ -9,8 +9,10 @@ import warnings
 from typing import Any, Dict, Optional

 import torch
+from pytorch3d.implicitron.models.renderer.ray_sampler import ImplicitronRayBundle
 from pytorch3d.implicitron.tools import metric_utils as utils
 from pytorch3d.implicitron.tools.config import registry, ReplaceableBase
+from pytorch3d.ops import packed_to_padded, padded_to_packed
 from pytorch3d.renderer import utils as rend_utils

 from .renderer.base import RendererOutput
@@ -60,7 +62,7 @@ class ViewMetricsBase(ReplaceableBase, torch.nn.Module):
     def forward(
         self,
         raymarched: RendererOutput,
-        xys: torch.Tensor,
+        ray_bundle: ImplicitronRayBundle,
         image_rgb: Optional[torch.Tensor] = None,
         depth_map: Optional[torch.Tensor] = None,
         fg_probability: Optional[torch.Tensor] = None,
@@ -79,10 +81,8 @@ class ViewMetricsBase(ReplaceableBase, torch.nn.Module):
                 names of the output metrics `metric_name_i` with their corresponding
                 values `metric_value_i` represented as 0-dimensional float tensors.
             raymarched: Output of the renderer.
-            xys: A tensor of shape `(B, ..., 2)` containing 2D image locations at which
-                the predictions are defined. All ground truth inputs are sampled at
-                these locations in order to extract values that correspond to the
-                predictions.
+            ray_bundle: The `ImplicitronRayBundle` that was used to produce the
+                raymarched output.
             image_rgb: A tensor of shape `(B, H, W, 3)` containing ground truth rgb
                 values.
             depth_map: A tensor of shape `(B, Hd, Wd, 1)` containing ground truth depth
@@ -141,7 +141,7 @@ class ViewMetrics(ViewMetricsBase):
     def forward(
         self,
         raymarched: RendererOutput,
-        xys: torch.Tensor,
+        ray_bundle: ImplicitronRayBundle,
         image_rgb: Optional[torch.Tensor] = None,
         depth_map: Optional[torch.Tensor] = None,
         fg_probability: Optional[torch.Tensor] = None,
@@ -165,10 +165,8 @@ class ViewMetrics(ViewMetricsBase):
                 input 3D coordinates used to compute the eikonal loss.
             raymarched.aux["density_grid"]: A tensor of shape `(B, Hg, Wg, Dg, 1)`
                 containing a `Hg x Wg x Dg` voxel grid of density values.
-            xys: A tensor of shape `(B, ..., 2)` containing 2D image locations at which
-                the predictions are defined. All ground truth inputs are sampled at
-                these locations in order to extract values that correspond to the
-                predictions.
+            ray_bundle: The `ImplicitronRayBundle` that was used to produce the
+                raymarched output.
             image_rgb: A tensor of shape `(B, H, W, 3)` containing ground truth rgb
                 values.
             depth_map: A tensor of shape `(B, Hd, Wd, 1)` containing ground truth depth
@@ -209,7 +207,7 @@ class ViewMetrics(ViewMetricsBase):
         """
         metrics = self._calculate_stage(
             raymarched,
-            xys,
+            ray_bundle,
             image_rgb,
             depth_map,
             fg_probability,
@@ -221,7 +219,7 @@ class ViewMetrics(ViewMetricsBase):
             metrics.update(
                 self(
                     raymarched.prev_stage,
-                    xys,
+                    ray_bundle,
                     image_rgb,
                     depth_map,
                     fg_probability,
@@ -235,7 +233,7 @@ class ViewMetrics(ViewMetricsBase):
     def _calculate_stage(
         self,
         raymarched: RendererOutput,
-        xys: torch.Tensor,
+        ray_bundle: ImplicitronRayBundle,
         image_rgb: Optional[torch.Tensor] = None,
         depth_map: Optional[torch.Tensor] = None,
         fg_probability: Optional[torch.Tensor] = None,
@@ -253,6 +251,27 @@ class ViewMetrics(ViewMetricsBase):
             _reshape_nongrid_var(x)
             for x in [raymarched.features, raymarched.masks, raymarched.depths]
         ]
+        xys = ray_bundle.xys
+
+        # If ray_bundle is packed, then we can sample images in the padded state to
+        # lower memory requirements. Instead of one image for every element in
+        # ray_bundle, we can then have one image per unique sampled camera.
+        if ray_bundle.is_packed():
+            # pyre-ignore[6]
+            cumsum = torch.cumsum(ray_bundle.camera_counts, dim=0, dtype=torch.long)
+            first_idxs = torch.cat(
+                (
+                    # pyre-ignore[16]
+                    ray_bundle.camera_counts.new_zeros((1,), dtype=torch.long),
+                    cumsum[:-1],
+                )
+            )
+            # pyre-ignore[16]
+            num_inputs = int(ray_bundle.camera_counts.sum())
+            # pyre-ignore[6]
+            max_size = int(torch.max(ray_bundle.camera_counts))
+            xys = packed_to_padded(xys, first_idxs, max_size)
+
         # reshape the sampling grid as well
         # TODO: we can get rid of the singular dimension here and in _reshape_nongrid_var
         # now that we use rend_utils.ndc_grid_sample
@@ -262,7 +281,20 @@ class ViewMetrics(ViewMetricsBase):
         def sample(tensor, mode):
             if tensor is None:
                 return tensor
-            return rend_utils.ndc_grid_sample(tensor, xys, mode=mode)
+            if ray_bundle.is_packed():
+                # select images that correspond to the sampled cameras if the bundle is packed
+                tensor = tensor[ray_bundle.camera_ids]
+            result = rend_utils.ndc_grid_sample(tensor, xys, mode=mode)
+            if ray_bundle.is_packed():
+                # Images after sampling have shape [batch, 3, max_num_rays, 1];
+                # padded_to_packed combines the first two dimensions, so we first
+                # swap the 1st and 2nd dimensions. The result after `[:, None]` is
+                # [n_rays_total_training, 1, 3, 1].
+                result = result.transpose(1, 2)
+                result = padded_to_packed(result, first_idxs, num_inputs)[:, None]
+                result = result.transpose(1, 2)
+
+            return result

         # eval all results in this size
         image_rgb = sample(image_rgb, mode="bilinear")
diff --git a/pytorch3d/implicitron/models/renderer/ray_point_refiner.py b/pytorch3d/implicitron/models/renderer/ray_point_refiner.py
index 55fbc8d6..a69398c6 100644
--- a/pytorch3d/implicitron/models/renderer/ray_point_refiner.py
+++ b/pytorch3d/implicitron/models/renderer/ray_point_refiner.py
@@ -5,8 +5,9 @@
 # LICENSE file in the root directory of this source tree.

 import torch
+from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle
 from pytorch3d.implicitron.tools.config import Configurable, expand_args_fields
-from pytorch3d.renderer import RayBundle
+
 from pytorch3d.renderer.implicit.sample_pdf import sample_pdf

@@ -42,13 +43,13 @@ class RayPointRefiner(Configurable, torch.nn.Module):

     def forward(
         self,
-        input_ray_bundle: RayBundle,
+        input_ray_bundle: ImplicitronRayBundle,
         ray_weights: torch.Tensor,
         **kwargs,
-    ) -> RayBundle:
+    ) -> ImplicitronRayBundle:
         """
         Args:
-            input_ray_bundle: An instance of `RayBundle` specifying the
+            input_ray_bundle: An instance of `ImplicitronRayBundle` specifying the
                 source rays for sampling of the probability distribution.
             ray_weights: A tensor of shape
                 `(..., input_ray_bundle.legths.shape[-1])` with non-negative
                 elements defining the probability distribution to sample
                 ray points from.

         Returns:
-            ray_bundle: A new `RayBundle` instance containing the input ray
+            ray_bundle: A new `ImplicitronRayBundle` instance containing the input ray
                 points together with `n_pts_per_ray` additionally sampled
                 points per ray. For each ray, the lengths are sorted.
         """
@@ -79,9 +80,6 @@ class RayPointRefiner(Configurable, torch.nn.Module):

         # Resort by depth.
         z_vals, _ = torch.sort(z_vals, dim=-1)

-        return RayBundle(
-            origins=input_ray_bundle.origins,
-            directions=input_ray_bundle.directions,
-            lengths=z_vals,
-            xys=input_ray_bundle.xys,
-        )
+        new_bundle = ImplicitronRayBundle(**vars(input_ray_bundle))
+        new_bundle.lengths = z_vals
+        return new_bundle
diff --git a/pytorch3d/implicitron/models/renderer/ray_sampler.py b/pytorch3d/implicitron/models/renderer/ray_sampler.py
index 6d3723ad..225084fc 100644
--- a/pytorch3d/implicitron/models/renderer/ray_sampler.py
+++ b/pytorch3d/implicitron/models/renderer/ray_sampler.py
@@ -72,7 +72,17 @@ class AbstractMaskRaySampler(RaySamplerBase, torch.nn.Module):
         sampling_mode_evaluation: Same as above but for evaluation.
         n_pts_per_ray_training: The number of points sampled along each ray during training.
         n_pts_per_ray_evaluation: The number of points sampled along each ray during evaluation.
-        n_rays_per_image_sampled_from_mask: The amount of rays to be sampled from the image grid
+        n_rays_per_image_sampled_from_mask: The number of rays to be sampled from the
+            image grid. Given a batch of image grids, this many rays are sampled from
+            each of them. `n_rays_per_image_sampled_from_mask` and
+            `n_rays_total_training` cannot both be defined.
+        n_rays_total_training: (optional) How many rays in total to sample from the
+            entire batch of provided image grids. The result is as if
+            `n_rays_total_training` cameras/image grids were sampled with replacement
+            from the cameras/image grids provided, and for every camera one ray was
+            sampled. `n_rays_per_image_sampled_from_mask` and `n_rays_total_training`
+            cannot both be defined; to use this one, set
+            `n_rays_per_image_sampled_from_mask` to None. Used only for EvaluationMode.TRAINING.
         stratified_point_sampling_training: if set, performs stratified random sampling
             along the ray; otherwise takes ray points at deterministic offsets.
         stratified_point_sampling_evaluation: Same as above but for evaluation.
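A minimal usage sketch of the new mutually exclusive pair (hypothetical values;
`AdaptiveRaySampler` is one concrete subclass of `AbstractMaskRaySampler`, and
`expand_args_fields` is the same config helper imported in ray_point_refiner.py
above):

    from pytorch3d.implicitron.models.renderer.ray_sampler import AdaptiveRaySampler
    from pytorch3d.implicitron.tools.config import expand_args_fields

    expand_args_fields(AdaptiveRaySampler)

    # Mixed frame raysampling: unset the per-image ray count and give one total
    # ray budget for the whole batch instead; setting both raises a ValueError.
    sampler = AdaptiveRaySampler(
        n_rays_per_image_sampled_from_mask=None,
        n_rays_total_training=1024,
    )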
@@ -85,7 +95,8 @@ class AbstractMaskRaySampler(RaySamplerBase, torch.nn.Module):
     sampling_mode_evaluation: str = "full_grid"
     n_pts_per_ray_training: int = 64
     n_pts_per_ray_evaluation: int = 64
-    n_rays_per_image_sampled_from_mask: int = 1024
+    n_rays_per_image_sampled_from_mask: Optional[int] = 1024
+    n_rays_total_training: Optional[int] = None
     # stratified sampling vs taking points at deterministic offsets
     stratified_point_sampling_training: bool = True
     stratified_point_sampling_evaluation: bool = False
@@ -93,6 +104,14 @@ class AbstractMaskRaySampler(RaySamplerBase, torch.nn.Module):
     def __post_init__(self):
         super().__init__()

+        if (self.n_rays_per_image_sampled_from_mask is not None) and (
+            self.n_rays_total_training is not None
+        ):
+            raise ValueError(
+                "Cannot both define n_rays_total_training and "
+                "n_rays_per_image_sampled_from_mask."
+            )
+
         self._sampling_mode = {
             EvaluationMode.TRAINING: RenderSamplingMode(self.sampling_mode_training),
             EvaluationMode.EVALUATION: RenderSamplingMode(
                 self.sampling_mode_evaluation
@@ -110,9 +129,11 @@ class AbstractMaskRaySampler(RaySamplerBase, torch.nn.Module):
             if self._sampling_mode[EvaluationMode.TRAINING]
             == RenderSamplingMode.MASK_SAMPLE
             else None,
+            n_rays_total=self.n_rays_total_training,
             unit_directions=True,
             stratified_sampling=self.stratified_point_sampling_training,
         )
+
         self._evaluation_raysampler = NDCMultinomialRaysampler(
             image_width=self.image_width,
             image_height=self.image_height,
diff --git a/pytorch3d/renderer/implicit/raysampling.py b/pytorch3d/renderer/implicit/raysampling.py
index 033f783a..ac08b972 100644
--- a/pytorch3d/renderer/implicit/raysampling.py
+++ b/pytorch3d/renderer/implicit/raysampling.py
@@ -90,11 +90,12 @@ class MultinomialRaysampler(torch.nn.Module):
         min_depth: The minimum depth of a ray-point.
         max_depth: The maximum depth of a ray-point.
         n_rays_per_image: If given, this amount of rays are sampled from the grid.
+            `n_rays_per_image` and `n_rays_total` cannot both be defined.
         n_rays_total: How many rays in total to sample from the cameras provided. The result
-            is as if `n_rays_total` cameras were sampled with replacement from the
-            cameras provided and for every camera one ray was sampled. If set, this disables
-            `n_rays_per_image` and returns the HeterogeneousRayBundle with
-            batch_size=n_rays_total.
+            is as if `n_rays_total` cameras were sampled with replacement from the
+            cameras provided and for every camera one ray was sampled. If set, returns the
+            HeterogeneousRayBundle with batch_size=n_rays_total.
+            `n_rays_per_image` and `n_rays_total` cannot both be defined.
         unit_directions: whether to normalize direction vectors in ray bundle.
         stratified_sampling: if True, performs stratified random sampling
             along the ray; otherwise takes ray points at deterministic offsets.
@@ -144,13 +145,15 @@ class MultinomialRaysampler(torch.nn.Module):
             min_depth: The minimum depth of a ray-point.
             max_depth: The maximum depth of a ray-point.
             n_rays_per_image: If given, this amount of rays are sampled from the grid.
+                `n_rays_per_image` and `n_rays_total` cannot both be defined.
             n_pts_per_ray: The number of points sampled along each ray.
             stratified_sampling: if set, overrides stratified_sampling provided
                 in __init__.
             n_rays_total: How many rays in total to sample from the cameras provided. The result
                 is as if `n_rays_total_training` cameras were sampled with replacement from the
                 cameras provided and for every camera one ray was sampled. If set, returns the
                 HeterogeneousRayBundle with batch_size=n_rays_total.
+                `n_rays_per_image` and `n_rays_total` cannot both be defined.

         Returns:
             A named tuple RayBundle or dataclass HeterogeneousRayBundle with the
             following fields:
@@ -352,13 +355,15 @@ class MonteCarloRaysampler(torch.nn.Module):
         min_y: The smallest y-coordinate of each ray's source pixel.
         max_y: The largest y-coordinate of each ray's source pixel.
         n_rays_per_image: The number of rays randomly sampled in each camera.
+            `n_rays_per_image` and `n_rays_total` cannot both be defined.
         n_pts_per_ray: The number of points sampled along each ray.
         min_depth: The minimum depth of each ray-point.
         max_depth: The maximum depth of each ray-point.
         n_rays_total: How many rays in total to sample from the cameras provided. The result
             is as if `n_rays_total_training` cameras were sampled with replacement from the
-            cameras provided and for every camera one ray was sampled. If set, this returns
-            the HeterogeneousRayBundleyBundle with batch_size=n_rays_total.
+            cameras provided and for every camera one ray was sampled. If set, returns the
+            HeterogeneousRayBundle with batch_size=n_rays_total.
+            `n_rays_per_image` and `n_rays_total` cannot both be defined.
         unit_directions: whether to normalize direction vectors in ray bundle.
         stratified_sampling: if True, performs stratified sampling in n_pts_per_ray
             bins for each ray; otherwise takes n_pts_per_ray deterministic points
diff --git a/tests/implicitron/data/overrides.yaml b/tests/implicitron/data/overrides.yaml
index 1414a5e8..e02f7bf6 100644
--- a/tests/implicitron/data/overrides.yaml
+++ b/tests/implicitron/data/overrides.yaml
@@ -59,6 +59,7 @@ raysampler_AdaptiveRaySampler_args:
   n_pts_per_ray_training: 64
   n_pts_per_ray_evaluation: 64
   n_rays_per_image_sampled_from_mask: 1024
+  n_rays_total_training: null
   stratified_point_sampling_training: true
   stratified_point_sampling_evaluation: false
   scene_extent: 8.0
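For reference, the padded sampling performed in metrics.py above can be
reproduced standalone. A minimal sketch, assuming a packed bundle of 5 rays
drawn from 2 unique cameras (all tensors here are made-up stand-ins;
`packed_to_padded` and `padded_to_packed` are the same ops the diff imports):

    import torch
    from pytorch3d.ops import packed_to_padded, padded_to_packed

    # Packed xys: 5 rays in total, 3 from one camera and 2 from another.
    xys = torch.rand(5, 2) * 2 - 1          # NDC locations, (n_rays_total, 2)
    camera_counts = torch.tensor([3, 2])    # rays per unique camera

    cumsum = torch.cumsum(camera_counts, dim=0, dtype=torch.long)
    first_idxs = torch.cat((camera_counts.new_zeros((1,), dtype=torch.long), cumsum[:-1]))
    num_inputs = int(camera_counts.sum())
    max_size = int(camera_counts.max())

    # Pad to one row of xys per unique camera, so each ground-truth image is
    # grid-sampled once per camera rather than once per ray.
    xys_padded = packed_to_padded(xys, first_idxs, max_size)  # (2, max_size, 2)

    # Stand-in for values sampled from the two camera images at xys_padded.
    sampled_padded = torch.rand(2, max_size, 3)

    # Pack back so downstream losses again see one row per ray.
    sampled_packed = padded_to_packed(sampled_padded, first_idxs, num_inputs)
    print(xys_padded.shape)      # torch.Size([2, 3, 2])
    print(sampled_packed.shape)  # torch.Size([5, 3])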