Enable mixed frame raysampling

Summary:
Changed ray_sampler and metrics so that they can use mixed-frame raysampling.

Ray_sampler now has a new member, `n_rays_total_training`, which it passes to the pytorch3d raysampler. If the ray bundle is heterogeneous, metrics now samples images by padding xys first, which reduces memory consumption.
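For context, a minimal sketch (hypothetical values, not part of this diff) of the two mutually exclusive ways to configure ray sampling after this change:

# Hypothetical config values illustrating the two mutually exclusive modes.
# Per-image sampling (the previous behaviour):
per_image_args = {
    "n_rays_per_image_sampled_from_mask": 1024,
    "n_rays_total_training": None,
}
# Mixed-frame sampling (new): rays are drawn across the whole batch, as if
# 1024 cameras were sampled with replacement and one ray taken per camera.
mixed_frame_args = {
    "n_rays_per_image_sampled_from_mask": None,
    "n_rays_total_training": 1024,
}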

Reviewed By: bottler, kjchalup

Differential Revision: D39542221

fbshipit-source-id: a6fec23838d3049ae5c2fd2e1f641c46c7c927e3
Author: Darijan Gudelj (committed by Facebook GitHub Bot)
Date: 2022-10-03 08:36:47 -07:00
parent ad8907d373
commit c311a4cbb9
8 changed files with 102 additions and 35 deletions

View File

@@ -222,6 +222,7 @@ class Experiment(Configurable): # pyre-ignore: 13
train_loader=train_loader,
val_loader=val_loader,
test_loader=test_loader,
train_dataset=datasets.train,
model=model,
optimizer=optimizer,
scheduler=scheduler,

View File

@@ -197,6 +197,7 @@ model_factory_ImplicitronModelFactory_args:
n_pts_per_ray_training: 64
n_pts_per_ray_evaluation: 64
n_rays_per_image_sampled_from_mask: 1024
n_rays_total_training: null
stratified_point_sampling_training: true
stratified_point_sampling_evaluation: false
scene_extent: 8.0
@@ -208,6 +209,7 @@ model_factory_ImplicitronModelFactory_args:
n_pts_per_ray_training: 64
n_pts_per_ray_evaluation: 64
n_rays_per_image_sampled_from_mask: 1024
n_rays_total_training: null
stratified_point_sampling_training: true
stratified_point_sampling_evaluation: false
min_depth: 0.1

View File

@@ -473,7 +473,7 @@ class GenericModel(ImplicitronModelBase): # pyre-ignore: 13
self.view_metrics(
results=preds,
raymarched=rendered,
xys=ray_bundle.xys,
ray_bundle=ray_bundle,
image_rgb=safe_slice_targets(image_rgb),
depth_map=safe_slice_targets(depth_map),
fg_probability=safe_slice_targets(fg_probability),
@@ -932,6 +932,11 @@ def _chunk_generator(
if len(iter) >= tqdm_trigger_threshold:
iter = tqdm.tqdm(iter)
def _safe_slice(
tensor: Optional[torch.Tensor], start_idx: int, end_idx: int
) -> Optional[torch.Tensor]:
return tensor[start_idx:end_idx] if tensor is not None else None
for start_idx in iter:
end_idx = min(start_idx + chunk_size_in_rays, n_rays)
ray_bundle_chunk = ImplicitronRayBundle(
@@ -943,6 +948,8 @@ def _chunk_generator(
:, start_idx:end_idx
],
xys=ray_bundle.xys.reshape(batch_size, -1, 2)[:, start_idx:end_idx],
camera_ids=_safe_slice(ray_bundle.camera_ids, start_idx, end_idx),
camera_counts=_safe_slice(ray_bundle.camera_counts, start_idx, end_idx),
)
extra_args = kwargs.copy()
for k, v in chunked_inputs.items():
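As a quick illustration, a self-contained sketch of the `_safe_slice` helper added above; it slices when the optional field is present and passes None through for homogeneous bundles whose camera_ids/camera_counts are unset:

from typing import Optional

import torch

def _safe_slice(
    tensor: Optional[torch.Tensor], start_idx: int, end_idx: int
) -> Optional[torch.Tensor]:
    # Slice when the field is present; pass None through otherwise.
    return tensor[start_idx:end_idx] if tensor is not None else None

assert _safe_slice(None, 0, 3) is None
assert _safe_slice(torch.arange(10), 2, 5).tolist() == [2, 3, 4]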

View File

@@ -9,8 +9,10 @@ import warnings
from typing import Any, Dict, Optional
import torch
from pytorch3d.implicitron.models.renderer.ray_sampler import ImplicitronRayBundle
from pytorch3d.implicitron.tools import metric_utils as utils
from pytorch3d.implicitron.tools.config import registry, ReplaceableBase
from pytorch3d.ops import packed_to_padded, padded_to_packed
from pytorch3d.renderer import utils as rend_utils
from .renderer.base import RendererOutput
@@ -60,7 +62,7 @@ class ViewMetricsBase(ReplaceableBase, torch.nn.Module):
def forward(
self,
raymarched: RendererOutput,
xys: torch.Tensor,
ray_bundle: ImplicitronRayBundle,
image_rgb: Optional[torch.Tensor] = None,
depth_map: Optional[torch.Tensor] = None,
fg_probability: Optional[torch.Tensor] = None,
@@ -79,10 +81,8 @@ class ViewMetricsBase(ReplaceableBase, torch.nn.Module):
names of the output metrics `metric_name_i` with their corresponding
values `metric_value_i` represented as 0-dimensional float tensors.
raymarched: Output of the renderer.
xys: A tensor of shape `(B, ..., 2)` containing 2D image locations at which
the predictions are defined. All ground truth inputs are sampled at
these locations in order to extract values that correspond to the
predictions.
ray_bundle: ImplicitronRayBundle object that was used to produce the
raymarched object.
image_rgb: A tensor of shape `(B, H, W, 3)` containing ground truth rgb
values.
depth_map: A tensor of shape `(B, Hd, Wd, 1)` containing ground truth depth
@@ -141,7 +141,7 @@ class ViewMetrics(ViewMetricsBase):
def forward(
self,
raymarched: RendererOutput,
xys: torch.Tensor,
ray_bundle: ImplicitronRayBundle,
image_rgb: Optional[torch.Tensor] = None,
depth_map: Optional[torch.Tensor] = None,
fg_probability: Optional[torch.Tensor] = None,
@@ -165,10 +165,8 @@ class ViewMetrics(ViewMetricsBase):
input 3D coordinates used to compute the eikonal loss.
raymarched.aux["density_grid"]: A tensor of shape `(B, Hg, Wg, Dg, 1)`
containing a `Hg x Wg x Dg` voxel grid of density values.
xys: A tensor of shape `(B, ..., 2)` containing 2D image locations at which
the predictions are defined. All ground truth inputs are sampled at
these locations in order to extract values that correspond to the
predictions.
ray_bundle: ImplicitronRayBundle object that was used to produce the
raymarched object.
image_rgb: A tensor of shape `(B, H, W, 3)` containing ground truth rgb
values.
depth_map: A tensor of shape `(B, Hd, Wd, 1)` containing ground truth depth
@@ -209,7 +207,7 @@ class ViewMetrics(ViewMetricsBase):
"""
metrics = self._calculate_stage(
raymarched,
xys,
ray_bundle,
image_rgb,
depth_map,
fg_probability,
@@ -221,7 +219,7 @@ class ViewMetrics(ViewMetricsBase):
metrics.update(
self(
raymarched.prev_stage,
xys,
ray_bundle,
image_rgb,
depth_map,
fg_probability,
@@ -235,7 +233,7 @@ class ViewMetrics(ViewMetricsBase):
def _calculate_stage(
self,
raymarched: RendererOutput,
xys: torch.Tensor,
ray_bundle: ImplicitronRayBundle,
image_rgb: Optional[torch.Tensor] = None,
depth_map: Optional[torch.Tensor] = None,
fg_probability: Optional[torch.Tensor] = None,
@@ -253,6 +251,27 @@ class ViewMetrics(ViewMetricsBase):
_reshape_nongrid_var(x)
for x in [raymarched.features, raymarched.masks, raymarched.depths]
]
xys = ray_bundle.xys
# If ray_bundle is packed, then we can sample images in the padded state to
# lower memory requirements: instead of having one image for every element
# in ray_bundle, we can then have one image per unique sampled camera.
if ray_bundle.is_packed():
# pyre-ignore[6]
cumsum = torch.cumsum(ray_bundle.camera_counts, dim=0, dtype=torch.long)
first_idxs = torch.cat(
(
# pyre-ignore[16]
ray_bundle.camera_counts.new_zeros((1,), dtype=torch.long),
cumsum[:-1],
)
)
# pyre-ignore[16]
num_inputs = int(ray_bundle.camera_counts.sum())
# pyre-ignore[6]
max_size = int(torch.max(ray_bundle.camera_counts))
xys = packed_to_padded(xys, first_idxs, max_size)
# reshape the sampling grid as well
# TODO: we can get rid of the singular dimension here and in _reshape_nongrid_var
# now that we use rend_utils.ndc_grid_sample
@@ -262,7 +281,20 @@ class ViewMetrics(ViewMetricsBase):
def sample(tensor, mode):
if tensor is None:
return tensor
return rend_utils.ndc_grid_sample(tensor, xys, mode=mode)
if ray_bundle.is_packed():
# select images that correspond to the sampled cameras if the ray bundle is packed
tensor = tensor[ray_bundle.camera_ids]
result = rend_utils.ndc_grid_sample(tensor, xys, mode=mode)
if ray_bundle.is_packed():
# Images after sampling have shape [batch, 3, max_num_rays, 1];
# padded_to_packed merges the first two dimensions, so we need to swap the
# 1st and 2nd dimensions. The result is [n_rays_total_training, 1, 3, 1]
# (we use keepdim=True).
result = result.transpose(1, 2)
result = padded_to_packed(result, first_idxs, num_inputs)[:, None]
result = result.transpose(1, 2)
return result
# eval all results in this size
image_rgb = sample(image_rgb, mode="bilinear")
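A minimal numeric sketch (made-up counts) of the packed/padded round trip used above: it shows how first_idxs is derived from camera_counts and that padded_to_packed undoes the padding:

import torch
from pytorch3d.ops import packed_to_padded, padded_to_packed

# Suppose 7 rays were sampled from 3 unique cameras: 3, 1 and 3 rays each.
camera_counts = torch.tensor([3, 1, 3])
xys_packed = torch.arange(14, dtype=torch.float32).reshape(7, 2)

cumsum = torch.cumsum(camera_counts, dim=0, dtype=torch.long)
first_idxs = torch.cat(
    (camera_counts.new_zeros((1,), dtype=torch.long), cumsum[:-1])
)
max_size = int(camera_counts.max())

# (7, 2) packed -> (3, 3, 2) padded; rows of shorter cameras are zero-padded.
xys_padded = packed_to_padded(xys_packed, first_idxs, max_size)
# And back: (3, 3, 2) padded -> (7, 2) packed.
xys_roundtrip = padded_to_packed(xys_padded, first_idxs, int(camera_counts.sum()))
assert torch.equal(xys_roundtrip, xys_packed)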

View File

@@ -5,8 +5,9 @@
# LICENSE file in the root directory of this source tree.
import torch
from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle
from pytorch3d.implicitron.tools.config import Configurable, expand_args_fields
from pytorch3d.renderer import RayBundle
from pytorch3d.renderer.implicit.sample_pdf import sample_pdf
@@ -42,13 +43,13 @@ class RayPointRefiner(Configurable, torch.nn.Module):
def forward(
self,
input_ray_bundle: RayBundle,
input_ray_bundle: ImplicitronRayBundle,
ray_weights: torch.Tensor,
**kwargs,
) -> RayBundle:
) -> ImplicitronRayBundle:
"""
Args:
input_ray_bundle: An instance of `RayBundle` specifying the
input_ray_bundle: An instance of `ImplicitronRayBundle` specifying the
source rays for sampling of the probability distribution.
ray_weights: A tensor of shape
`(..., input_ray_bundle.lengths.shape[-1])` with non-negative
@@ -56,7 +57,7 @@ class RayPointRefiner(Configurable, torch.nn.Module):
ray points from.
Returns:
ray_bundle: A new `RayBundle` instance containing the input ray
ray_bundle: A new `ImplicitronRayBundle` instance containing the input ray
points together with `n_pts_per_ray` additionally sampled
points per ray. For each ray, the lengths are sorted.
"""
@@ -79,9 +80,6 @@ class RayPointRefiner(Configurable, torch.nn.Module):
# Resort by depth.
z_vals, _ = torch.sort(z_vals, dim=-1)
return RayBundle(
origins=input_ray_bundle.origins,
directions=input_ray_bundle.directions,
lengths=z_vals,
xys=input_ray_bundle.xys,
)
new_bundle = ImplicitronRayBundle(**vars(input_ray_bundle))
new_bundle.lengths = z_vals
return new_bundle
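The new return path shallow-copies the bundle and swaps in the re-sorted lengths. A toy illustration of the vars() copy pattern, using a hypothetical Bundle stand-in rather than the real ImplicitronRayBundle:

import dataclasses

import torch

@dataclasses.dataclass
class Bundle:  # hypothetical stand-in for ImplicitronRayBundle
    origins: torch.Tensor
    lengths: torch.Tensor

b = Bundle(origins=torch.zeros(2, 3), lengths=torch.rand(2, 64))
# vars(b) is the instance __dict__, so this rebuilds a bundle with identical
# fields; only `lengths` is then replaced, as RayPointRefiner does above.
b2 = Bundle(**vars(b))
b2.lengths, _ = torch.sort(torch.cat([b.lengths, torch.rand(2, 16)], dim=-1), dim=-1)
assert b2.origins is b.origins and b2.lengths.shape == (2, 80)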

View File

@@ -72,7 +72,17 @@ class AbstractMaskRaySampler(RaySamplerBase, torch.nn.Module):
sampling_mode_evaluation: Same as above but for evaluation.
n_pts_per_ray_training: The number of points sampled along each ray during training.
n_pts_per_ray_evaluation: The number of points sampled along each ray during evaluation.
n_rays_per_image_sampled_from_mask: The amount of rays to be sampled from the image grid
n_rays_per_image_sampled_from_mask: The number of rays to be sampled from the image
grid. Given a batch of image grids, this many are sampled from each.
`n_rays_per_image_sampled_from_mask` and `n_rays_total_training` cannot both be
defined.
n_rays_total_training: (optional) How many rays in total to sample from the entire
batch of provided image grids. The result is as if `n_rays_total_training`
cameras/image grids were sampled with replacement from the cameras/image grids
provided and for every camera one ray was sampled (see the sketch below).
`n_rays_per_image_sampled_from_mask` and `n_rays_total_training` cannot both be
defined; to use this option, set `n_rays_per_image_sampled_from_mask` to None.
Used only for EvaluationMode.TRAINING.
stratified_point_sampling_training: if set, performs stratified random sampling
along the ray; otherwise takes ray points at deterministic offsets.
stratified_point_sampling_evaluation: Same as above but for evaluation.
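A hedged sketch of the "sampled with replacement" semantics described in the docstring above; this illustrates the behaviour, it is not the library's implementation:

import torch

n_cameras, n_rays_total_training = 4, 1024

# As if n_rays_total_training cameras were drawn with replacement...
sampled = torch.randint(n_cameras, (n_rays_total_training,))
# ...and one ray (one NDC pixel location) were sampled per drawn camera.
xys = torch.rand(n_rays_total_training, 1, 2) * 2 - 1

# A heterogeneous bundle summarizes this as unique camera ids plus
# per-camera ray counts (cf. camera_ids / camera_counts in metrics above).
camera_ids, camera_counts = torch.unique(sampled, return_counts=True)
assert int(camera_counts.sum()) == n_rays_total_training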
@@ -85,7 +95,8 @@ class AbstractMaskRaySampler(RaySamplerBase, torch.nn.Module):
sampling_mode_evaluation: str = "full_grid"
n_pts_per_ray_training: int = 64
n_pts_per_ray_evaluation: int = 64
n_rays_per_image_sampled_from_mask: int = 1024
n_rays_per_image_sampled_from_mask: Optional[int] = 1024
n_rays_total_training: Optional[int] = None
# stratified sampling vs taking points at deterministic offsets
stratified_point_sampling_training: bool = True
stratified_point_sampling_evaluation: bool = False
@@ -93,6 +104,14 @@ class AbstractMaskRaySampler(RaySamplerBase, torch.nn.Module):
def __post_init__(self):
super().__init__()
if (self.n_rays_per_image_sampled_from_mask is not None) and (
self.n_rays_total_training is not None
):
raise ValueError(
"Cannot both define n_rays_total_training and "
"n_rays_per_image_sampled_from_mask."
)
self._sampling_mode = {
EvaluationMode.TRAINING: RenderSamplingMode(self.sampling_mode_training),
EvaluationMode.EVALUATION: RenderSamplingMode(
@@ -110,9 +129,11 @@ class AbstractMaskRaySampler(RaySamplerBase, torch.nn.Module):
if self._sampling_mode[EvaluationMode.TRAINING]
== RenderSamplingMode.MASK_SAMPLE
else None,
n_rays_total=self.n_rays_total_training,
unit_directions=True,
stratified_sampling=self.stratified_point_sampling_training,
)
self._evaluation_raysampler = NDCMultinomialRaysampler(
image_width=self.image_width,
image_height=self.image_height,

View File

@@ -90,11 +90,12 @@ class MultinomialRaysampler(torch.nn.Module):
min_depth: The minimum depth of a ray-point.
max_depth: The maximum depth of a ray-point.
n_rays_per_image: If given, this many rays are sampled from the grid.
`n_rays_per_image` and `n_rays_total` cannot both be defined.
n_rays_total: How many rays in total to sample from the cameras provided. The result
is as if `n_rays_total` cameras were sampled with replacement from the
cameras provided and for every camera one ray was sampled. If set, this disables
`n_rays_per_image` and returns the HeterogeneousRayBundle with
batch_size=n_rays_total.
is as if `n_rays_total` cameras were sampled with replacement from the
cameras provided and for every camera one ray was sampled (see the usage
sketch below). If set, returns the HeterogeneousRayBundle with
batch_size=n_rays_total.
`n_rays_per_image` and `n_rays_total` cannot both be defined.
unit_directions: whether to normalize direction vectors in ray bundle.
stratified_sampling: if True, performs stratified random sampling
along the ray; otherwise takes ray points at deterministic offsets.
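A usage sketch, assuming only what the hunk above documents (that the public NDCMultinomialRaysampler accepts n_rays_total); shapes and values are illustrative:

import torch
from pytorch3d.renderer import NDCMultinomialRaysampler, PerspectiveCameras

cameras = PerspectiveCameras(
    R=torch.eye(3)[None].expand(2, 3, 3), T=torch.zeros(2, 3)
)
raysampler = NDCMultinomialRaysampler(
    image_width=8,
    image_height=8,
    n_pts_per_ray=16,
    min_depth=0.1,
    max_depth=2.0,
    n_rays_total=64,  # mutually exclusive with n_rays_per_image
)
# With n_rays_total set, the sampler returns a HeterogeneousRayBundle with
# batch_size=64 instead of a grid-shaped RayBundle.
bundle = raysampler(cameras)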
@@ -144,13 +145,15 @@ class MultinomialRaysampler(torch.nn.Module):
min_depth: The minimum depth of a ray-point.
max_depth: The maximum depth of a ray-point.
n_rays_per_image: If given, this many rays are sampled from the grid.
`n_rays_per_image` and `n_rays_total` cannot both be defined.
n_pts_per_ray: The number of points sampled along each ray.
stratified_sampling: if set, overrides stratified_sampling provided
in __init__.
n_rays_total: How many rays in total to sample from the cameras provided. The result
is as if `n_rays_total` cameras were sampled with replacement from the
cameras provided and for every camera one ray was sampled. If set, returns the
cameras provided and for every camera one ray was sampled. If set returns the
HeterogeneousRayBundle with batch_size=n_rays_total.
`n_rays_per_image` and `n_rays_total` cannot both be defined.
Returns:
A named tuple RayBundle or dataclass HeterogeneousRayBundle with the
following fields:
@@ -352,13 +355,15 @@ class MonteCarloRaysampler(torch.nn.Module):
min_y: The smallest y-coordinate of each ray's source pixel.
max_y: The largest y-coordinate of each ray's source pixel.
n_rays_per_image: The number of rays randomly sampled in each camera.
`n_rays_per_image` and `n_rays_total` cannot both be defined.
n_pts_per_ray: The number of points sampled along each ray.
min_depth: The minimum depth of each ray-point.
max_depth: The maximum depth of each ray-point.
n_rays_total: How many rays in total to sample from the cameras provided. The result
is as if `n_rays_total` cameras were sampled with replacement from the
cameras provided and for every camera one ray was sampled. If set, this returns
the HeterogeneousRayBundleyBundle with batch_size=n_rays_total.
cameras provided and for every camera one ray was sampled. If set, returns the
HeterogeneousRayBundle with batch_size=n_rays_total.
`n_rays_per_image` and `n_rays_total` cannot both be defined.
unit_directions: whether to normalize direction vectors in ray bundle.
stratified_sampling: if True, performs stratified sampling in n_pts_per_ray
bins for each ray; otherwise takes n_pts_per_ray deterministic points

View File

@@ -59,6 +59,7 @@ raysampler_AdaptiveRaySampler_args:
n_pts_per_ray_training: 64
n_pts_per_ray_evaluation: 64
n_rays_per_image_sampled_from_mask: 1024
n_rays_total_training: null
stratified_point_sampling_training: true
stratified_point_sampling_evaluation: false
scene_extent: 8.0