mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-01 03:12:49 +08:00
Enable mixed frame raysampling
Summary: Changed ray_sampler and metrics to be able to use mixed frame raysampling. Ray_sampler now has a new member which it passes to the pytorch3d raysampler. If the raybundle is heterogeneous metrics now samples images by padding xys first. This reduces memory consumption. Reviewed By: bottler, kjchalup Differential Revision: D39542221 fbshipit-source-id: a6fec23838d3049ae5c2fd2e1f641c46c7c927e3
This commit is contained in:
parent
ad8907d373
commit
c311a4cbb9
@ -222,6 +222,7 @@ class Experiment(Configurable): # pyre-ignore: 13
|
||||
train_loader=train_loader,
|
||||
val_loader=val_loader,
|
||||
test_loader=test_loader,
|
||||
train_dataset=datasets.train,
|
||||
model=model,
|
||||
optimizer=optimizer,
|
||||
scheduler=scheduler,
|
||||
|
@ -197,6 +197,7 @@ model_factory_ImplicitronModelFactory_args:
|
||||
n_pts_per_ray_training: 64
|
||||
n_pts_per_ray_evaluation: 64
|
||||
n_rays_per_image_sampled_from_mask: 1024
|
||||
n_rays_total_training: null
|
||||
stratified_point_sampling_training: true
|
||||
stratified_point_sampling_evaluation: false
|
||||
scene_extent: 8.0
|
||||
@ -208,6 +209,7 @@ model_factory_ImplicitronModelFactory_args:
|
||||
n_pts_per_ray_training: 64
|
||||
n_pts_per_ray_evaluation: 64
|
||||
n_rays_per_image_sampled_from_mask: 1024
|
||||
n_rays_total_training: null
|
||||
stratified_point_sampling_training: true
|
||||
stratified_point_sampling_evaluation: false
|
||||
min_depth: 0.1
|
||||
|
@ -473,7 +473,7 @@ class GenericModel(ImplicitronModelBase): # pyre-ignore: 13
|
||||
self.view_metrics(
|
||||
results=preds,
|
||||
raymarched=rendered,
|
||||
xys=ray_bundle.xys,
|
||||
ray_bundle=ray_bundle,
|
||||
image_rgb=safe_slice_targets(image_rgb),
|
||||
depth_map=safe_slice_targets(depth_map),
|
||||
fg_probability=safe_slice_targets(fg_probability),
|
||||
@ -932,6 +932,11 @@ def _chunk_generator(
|
||||
if len(iter) >= tqdm_trigger_threshold:
|
||||
iter = tqdm.tqdm(iter)
|
||||
|
||||
def _safe_slice(
|
||||
tensor: Optional[torch.Tensor], start_idx: int, end_idx: int
|
||||
) -> Optional[torch.Tensor]:
|
||||
return tensor[start_idx:end_idx] if tensor is not None else None
|
||||
|
||||
for start_idx in iter:
|
||||
end_idx = min(start_idx + chunk_size_in_rays, n_rays)
|
||||
ray_bundle_chunk = ImplicitronRayBundle(
|
||||
@ -943,6 +948,8 @@ def _chunk_generator(
|
||||
:, start_idx:end_idx
|
||||
],
|
||||
xys=ray_bundle.xys.reshape(batch_size, -1, 2)[:, start_idx:end_idx],
|
||||
camera_ids=_safe_slice(ray_bundle.camera_ids, start_idx, end_idx),
|
||||
camera_counts=_safe_slice(ray_bundle.camera_counts, start_idx, end_idx),
|
||||
)
|
||||
extra_args = kwargs.copy()
|
||||
for k, v in chunked_inputs.items():
|
||||
|
@ -9,8 +9,10 @@ import warnings
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import torch
|
||||
from pytorch3d.implicitron.models.renderer.ray_sampler import ImplicitronRayBundle
|
||||
from pytorch3d.implicitron.tools import metric_utils as utils
|
||||
from pytorch3d.implicitron.tools.config import registry, ReplaceableBase
|
||||
from pytorch3d.ops import packed_to_padded, padded_to_packed
|
||||
from pytorch3d.renderer import utils as rend_utils
|
||||
|
||||
from .renderer.base import RendererOutput
|
||||
@ -60,7 +62,7 @@ class ViewMetricsBase(ReplaceableBase, torch.nn.Module):
|
||||
def forward(
|
||||
self,
|
||||
raymarched: RendererOutput,
|
||||
xys: torch.Tensor,
|
||||
ray_bundle: ImplicitronRayBundle,
|
||||
image_rgb: Optional[torch.Tensor] = None,
|
||||
depth_map: Optional[torch.Tensor] = None,
|
||||
fg_probability: Optional[torch.Tensor] = None,
|
||||
@ -79,10 +81,8 @@ class ViewMetricsBase(ReplaceableBase, torch.nn.Module):
|
||||
names of the output metrics `metric_name_i` with their corresponding
|
||||
values `metric_value_i` represented as 0-dimensional float tensors.
|
||||
raymarched: Output of the renderer.
|
||||
xys: A tensor of shape `(B, ..., 2)` containing 2D image locations at which
|
||||
the predictions are defined. All ground truth inputs are sampled at
|
||||
these locations in order to extract values that correspond to the
|
||||
predictions.
|
||||
ray_bundle: ImplicitronRayBundle object which was used to produce the raymarched
|
||||
object
|
||||
image_rgb: A tensor of shape `(B, H, W, 3)` containing ground truth rgb
|
||||
values.
|
||||
depth_map: A tensor of shape `(B, Hd, Wd, 1)` containing ground truth depth
|
||||
@ -141,7 +141,7 @@ class ViewMetrics(ViewMetricsBase):
|
||||
def forward(
|
||||
self,
|
||||
raymarched: RendererOutput,
|
||||
xys: torch.Tensor,
|
||||
ray_bundle: ImplicitronRayBundle,
|
||||
image_rgb: Optional[torch.Tensor] = None,
|
||||
depth_map: Optional[torch.Tensor] = None,
|
||||
fg_probability: Optional[torch.Tensor] = None,
|
||||
@ -165,10 +165,8 @@ class ViewMetrics(ViewMetricsBase):
|
||||
input 3D coordinates used to compute the eikonal loss.
|
||||
raymarched.aux["density_grid"]: A tensor of shape `(B, Hg, Wg, Dg, 1)`
|
||||
containing a `Hg x Wg x Dg` voxel grid of density values.
|
||||
xys: A tensor of shape `(B, ..., 2)` containing 2D image locations at which
|
||||
the predictions are defined. All ground truth inputs are sampled at
|
||||
these locations in order to extract values that correspond to the
|
||||
predictions.
|
||||
ray_bundle: ImplicitronRayBundle object which was used to produce the raymarched
|
||||
object
|
||||
image_rgb: A tensor of shape `(B, H, W, 3)` containing ground truth rgb
|
||||
values.
|
||||
depth_map: A tensor of shape `(B, Hd, Wd, 1)` containing ground truth depth
|
||||
@ -209,7 +207,7 @@ class ViewMetrics(ViewMetricsBase):
|
||||
"""
|
||||
metrics = self._calculate_stage(
|
||||
raymarched,
|
||||
xys,
|
||||
ray_bundle,
|
||||
image_rgb,
|
||||
depth_map,
|
||||
fg_probability,
|
||||
@ -221,7 +219,7 @@ class ViewMetrics(ViewMetricsBase):
|
||||
metrics.update(
|
||||
self(
|
||||
raymarched.prev_stage,
|
||||
xys,
|
||||
ray_bundle,
|
||||
image_rgb,
|
||||
depth_map,
|
||||
fg_probability,
|
||||
@ -235,7 +233,7 @@ class ViewMetrics(ViewMetricsBase):
|
||||
def _calculate_stage(
|
||||
self,
|
||||
raymarched: RendererOutput,
|
||||
xys: torch.Tensor,
|
||||
ray_bundle: ImplicitronRayBundle,
|
||||
image_rgb: Optional[torch.Tensor] = None,
|
||||
depth_map: Optional[torch.Tensor] = None,
|
||||
fg_probability: Optional[torch.Tensor] = None,
|
||||
@ -253,6 +251,27 @@ class ViewMetrics(ViewMetricsBase):
|
||||
_reshape_nongrid_var(x)
|
||||
for x in [raymarched.features, raymarched.masks, raymarched.depths]
|
||||
]
|
||||
xys = ray_bundle.xys
|
||||
|
||||
# If ray_bundle is packed than we can sample images in padded state to lower
|
||||
# memory requirements. Instead of having one image for every element in
|
||||
# ray_bundle we can than have one image per unique sampled camera.
|
||||
if ray_bundle.is_packed():
|
||||
# pyre-ignore[6]
|
||||
cumsum = torch.cumsum(ray_bundle.camera_counts, dim=0, dtype=torch.long)
|
||||
first_idxs = torch.cat(
|
||||
(
|
||||
# pyre-ignore[16]
|
||||
ray_bundle.camera_counts.new_zeros((1,), dtype=torch.long),
|
||||
cumsum[:-1],
|
||||
)
|
||||
)
|
||||
# pyre-ignore[16]
|
||||
num_inputs = int(ray_bundle.camera_counts.sum())
|
||||
# pyre-ignore[6]
|
||||
max_size = int(torch.max(ray_bundle.camera_counts))
|
||||
xys = packed_to_padded(xys, first_idxs, max_size)
|
||||
|
||||
# reshape the sampling grid as well
|
||||
# TODO: we can get rid of the singular dimension here and in _reshape_nongrid_var
|
||||
# now that we use rend_utils.ndc_grid_sample
|
||||
@ -262,7 +281,20 @@ class ViewMetrics(ViewMetricsBase):
|
||||
def sample(tensor, mode):
|
||||
if tensor is None:
|
||||
return tensor
|
||||
return rend_utils.ndc_grid_sample(tensor, xys, mode=mode)
|
||||
if ray_bundle.is_packed():
|
||||
# select images that corespond to sampled cameras if raybundle is packed
|
||||
tensor = tensor[ray_bundle.camera_ids]
|
||||
result = rend_utils.ndc_grid_sample(tensor, xys, mode=mode)
|
||||
if ray_bundle.is_packed():
|
||||
# Images after sampling are in a form [batch, 3, max_num_rays, 1],
|
||||
# packed_to_padded combines first two dimensions so we need to swap 1st
|
||||
# and 2nd dimension. the result is [n_rays_total_training, 1, 3, 1]
|
||||
# (we use keepdim=True).
|
||||
result = result.transpose(1, 2)
|
||||
result = padded_to_packed(result, first_idxs, num_inputs)[:, None]
|
||||
result = result.transpose(1, 2)
|
||||
|
||||
return result
|
||||
|
||||
# eval all results in this size
|
||||
image_rgb = sample(image_rgb, mode="bilinear")
|
||||
|
@ -5,8 +5,9 @@
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import torch
|
||||
from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle
|
||||
from pytorch3d.implicitron.tools.config import Configurable, expand_args_fields
|
||||
from pytorch3d.renderer import RayBundle
|
||||
|
||||
from pytorch3d.renderer.implicit.sample_pdf import sample_pdf
|
||||
|
||||
|
||||
@ -42,13 +43,13 @@ class RayPointRefiner(Configurable, torch.nn.Module):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ray_bundle: RayBundle,
|
||||
input_ray_bundle: ImplicitronRayBundle,
|
||||
ray_weights: torch.Tensor,
|
||||
**kwargs,
|
||||
) -> RayBundle:
|
||||
) -> ImplicitronRayBundle:
|
||||
"""
|
||||
Args:
|
||||
input_ray_bundle: An instance of `RayBundle` specifying the
|
||||
input_ray_bundle: An instance of `ImplicitronRayBundle` specifying the
|
||||
source rays for sampling of the probability distribution.
|
||||
ray_weights: A tensor of shape
|
||||
`(..., input_ray_bundle.legths.shape[-1])` with non-negative
|
||||
@ -56,7 +57,7 @@ class RayPointRefiner(Configurable, torch.nn.Module):
|
||||
ray points from.
|
||||
|
||||
Returns:
|
||||
ray_bundle: A new `RayBundle` instance containing the input ray
|
||||
ray_bundle: A new `ImplicitronRayBundle` instance containing the input ray
|
||||
points together with `n_pts_per_ray` additionally sampled
|
||||
points per ray. For each ray, the lengths are sorted.
|
||||
"""
|
||||
@ -79,9 +80,6 @@ class RayPointRefiner(Configurable, torch.nn.Module):
|
||||
# Resort by depth.
|
||||
z_vals, _ = torch.sort(z_vals, dim=-1)
|
||||
|
||||
return RayBundle(
|
||||
origins=input_ray_bundle.origins,
|
||||
directions=input_ray_bundle.directions,
|
||||
lengths=z_vals,
|
||||
xys=input_ray_bundle.xys,
|
||||
)
|
||||
new_bundle = ImplicitronRayBundle(**vars(input_ray_bundle))
|
||||
new_bundle.lengths = z_vals
|
||||
return new_bundle
|
||||
|
@ -72,7 +72,17 @@ class AbstractMaskRaySampler(RaySamplerBase, torch.nn.Module):
|
||||
sampling_mode_evaluation: Same as above but for evaluation.
|
||||
n_pts_per_ray_training: The number of points sampled along each ray during training.
|
||||
n_pts_per_ray_evaluation: The number of points sampled along each ray during evaluation.
|
||||
n_rays_per_image_sampled_from_mask: The amount of rays to be sampled from the image grid
|
||||
n_rays_per_image_sampled_from_mask: The amount of rays to be sampled from the image
|
||||
grid. Given a batch of image grids, this many is sampled from each.
|
||||
`n_rays_per_image_sampled_from_mask` and `n_rays_total_training` cannot both be
|
||||
defined.
|
||||
n_rays_total_training: (optional) How many rays in total to sample from the entire
|
||||
batch of provided image grid. The result is as if `n_rays_total_training`
|
||||
cameras/image grids were sampled with replacement from the cameras / image grids
|
||||
provided and for every camera one ray was sampled.
|
||||
`n_rays_per_image_sampled_from_mask` and `n_rays_total_training` cannot both be
|
||||
defined, to use you have to set `n_rays_per_image` to None.
|
||||
Used only for EvaluationMode.TRAINING.
|
||||
stratified_point_sampling_training: if set, performs stratified random sampling
|
||||
along the ray; otherwise takes ray points at deterministic offsets.
|
||||
stratified_point_sampling_evaluation: Same as above but for evaluation.
|
||||
@ -85,7 +95,8 @@ class AbstractMaskRaySampler(RaySamplerBase, torch.nn.Module):
|
||||
sampling_mode_evaluation: str = "full_grid"
|
||||
n_pts_per_ray_training: int = 64
|
||||
n_pts_per_ray_evaluation: int = 64
|
||||
n_rays_per_image_sampled_from_mask: int = 1024
|
||||
n_rays_per_image_sampled_from_mask: Optional[int] = 1024
|
||||
n_rays_total_training: Optional[int] = None
|
||||
# stratified sampling vs taking points at deterministic offsets
|
||||
stratified_point_sampling_training: bool = True
|
||||
stratified_point_sampling_evaluation: bool = False
|
||||
@ -93,6 +104,14 @@ class AbstractMaskRaySampler(RaySamplerBase, torch.nn.Module):
|
||||
def __post_init__(self):
|
||||
super().__init__()
|
||||
|
||||
if (self.n_rays_per_image_sampled_from_mask is not None) and (
|
||||
self.n_rays_total_training is not None
|
||||
):
|
||||
raise ValueError(
|
||||
"Cannot both define n_rays_total_training and "
|
||||
"n_rays_per_image_sampled_from_mask."
|
||||
)
|
||||
|
||||
self._sampling_mode = {
|
||||
EvaluationMode.TRAINING: RenderSamplingMode(self.sampling_mode_training),
|
||||
EvaluationMode.EVALUATION: RenderSamplingMode(
|
||||
@ -110,9 +129,11 @@ class AbstractMaskRaySampler(RaySamplerBase, torch.nn.Module):
|
||||
if self._sampling_mode[EvaluationMode.TRAINING]
|
||||
== RenderSamplingMode.MASK_SAMPLE
|
||||
else None,
|
||||
n_rays_total=self.n_rays_total_training,
|
||||
unit_directions=True,
|
||||
stratified_sampling=self.stratified_point_sampling_training,
|
||||
)
|
||||
|
||||
self._evaluation_raysampler = NDCMultinomialRaysampler(
|
||||
image_width=self.image_width,
|
||||
image_height=self.image_height,
|
||||
|
@ -90,11 +90,12 @@ class MultinomialRaysampler(torch.nn.Module):
|
||||
min_depth: The minimum depth of a ray-point.
|
||||
max_depth: The maximum depth of a ray-point.
|
||||
n_rays_per_image: If given, this amount of rays are sampled from the grid.
|
||||
`n_rays_per_image` and `n_rays_total` cannot both be defined.
|
||||
n_rays_total: How many rays in total to sample from the cameras provided. The result
|
||||
is as if `n_rays_total` cameras were sampled with replacement from the
|
||||
cameras provided and for every camera one ray was sampled. If set, this disables
|
||||
`n_rays_per_image` and returns the HeterogeneousRayBundle with
|
||||
batch_size=n_rays_total.
|
||||
is as if `n_rays_total_training` cameras were sampled with replacement from the
|
||||
cameras provided and for every camera one ray was sampled. If set returns the
|
||||
HeterogeneousRayBundle with batch_size=n_rays_total.
|
||||
`n_rays_per_image` and `n_rays_total` cannot both be defined.
|
||||
unit_directions: whether to normalize direction vectors in ray bundle.
|
||||
stratified_sampling: if True, performs stratified random sampling
|
||||
along the ray; otherwise takes ray points at deterministic offsets.
|
||||
@ -144,13 +145,15 @@ class MultinomialRaysampler(torch.nn.Module):
|
||||
min_depth: The minimum depth of a ray-point.
|
||||
max_depth: The maximum depth of a ray-point.
|
||||
n_rays_per_image: If given, this amount of rays are sampled from the grid.
|
||||
`n_rays_per_image` and `n_rays_total` cannot both be defined.
|
||||
n_pts_per_ray: The number of points sampled along each ray.
|
||||
stratified_sampling: if set, overrides stratified_sampling provided
|
||||
in __init__.
|
||||
n_rays_total: How many rays in total to sample from the cameras provided. The result
|
||||
is as if `n_rays_total_training` cameras were sampled with replacement from the
|
||||
cameras provided and for every camera one ray was sampled. If set, returns the
|
||||
cameras provided and for every camera one ray was sampled. If set returns the
|
||||
HeterogeneousRayBundle with batch_size=n_rays_total.
|
||||
`n_rays_per_image` and `n_rays_total` cannot both be defined.
|
||||
Returns:
|
||||
A named tuple RayBundle or dataclass HeterogeneousRayBundle with the
|
||||
following fields:
|
||||
@ -352,13 +355,15 @@ class MonteCarloRaysampler(torch.nn.Module):
|
||||
min_y: The smallest y-coordinate of each ray's source pixel.
|
||||
max_y: The largest y-coordinate of each ray's source pixel.
|
||||
n_rays_per_image: The number of rays randomly sampled in each camera.
|
||||
`n_rays_per_image` and `n_rays_total` cannot both be defined.
|
||||
n_pts_per_ray: The number of points sampled along each ray.
|
||||
min_depth: The minimum depth of each ray-point.
|
||||
max_depth: The maximum depth of each ray-point.
|
||||
n_rays_total: How many rays in total to sample from the cameras provided. The result
|
||||
is as if `n_rays_total_training` cameras were sampled with replacement from the
|
||||
cameras provided and for every camera one ray was sampled. If set, this returns
|
||||
the HeterogeneousRayBundleyBundle with batch_size=n_rays_total.
|
||||
cameras provided and for every camera one ray was sampled. If set returns the
|
||||
HeterogeneousRayBundle with batch_size=n_rays_total.
|
||||
`n_rays_per_image` and `n_rays_total` cannot both be defined.
|
||||
unit_directions: whether to normalize direction vectors in ray bundle.
|
||||
stratified_sampling: if True, performs stratified sampling in n_pts_per_ray
|
||||
bins for each ray; otherwise takes n_pts_per_ray deterministic points
|
||||
|
@ -59,6 +59,7 @@ raysampler_AdaptiveRaySampler_args:
|
||||
n_pts_per_ray_training: 64
|
||||
n_pts_per_ray_evaluation: 64
|
||||
n_rays_per_image_sampled_from_mask: 1024
|
||||
n_rays_total_training: null
|
||||
stratified_point_sampling_training: true
|
||||
stratified_point_sampling_evaluation: false
|
||||
scene_extent: 8.0
|
||||
|
Loading…
x
Reference in New Issue
Block a user