Enable mixed frame raysampling

Summary:
Changed ray_sampler and metrics to be able to use mixed frame raysampling.

Ray_sampler now has a new member which it passes to the pytorch3d raysampler.
If the raybundle is heterogeneous metrics now samples images by padding xys first. This reduces memory consumption.

Reviewed By: bottler, kjchalup

Differential Revision: D39542221

fbshipit-source-id: a6fec23838d3049ae5c2fd2e1f641c46c7c927e3
This commit is contained in:
Darijan Gudelj
2022-10-03 08:36:47 -07:00
committed by Facebook GitHub Bot
parent ad8907d373
commit c311a4cbb9
8 changed files with 102 additions and 35 deletions

View File

@@ -5,8 +5,9 @@
# LICENSE file in the root directory of this source tree.
import torch
from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle
from pytorch3d.implicitron.tools.config import Configurable, expand_args_fields
from pytorch3d.renderer import RayBundle
from pytorch3d.renderer.implicit.sample_pdf import sample_pdf
@@ -42,13 +43,13 @@ class RayPointRefiner(Configurable, torch.nn.Module):
def forward(
self,
input_ray_bundle: RayBundle,
input_ray_bundle: ImplicitronRayBundle,
ray_weights: torch.Tensor,
**kwargs,
) -> RayBundle:
) -> ImplicitronRayBundle:
"""
Args:
input_ray_bundle: An instance of `RayBundle` specifying the
input_ray_bundle: An instance of `ImplicitronRayBundle` specifying the
source rays for sampling of the probability distribution.
ray_weights: A tensor of shape
`(..., input_ray_bundle.legths.shape[-1])` with non-negative
@@ -56,7 +57,7 @@ class RayPointRefiner(Configurable, torch.nn.Module):
ray points from.
Returns:
ray_bundle: A new `RayBundle` instance containing the input ray
ray_bundle: A new `ImplicitronRayBundle` instance containing the input ray
points together with `n_pts_per_ray` additionally sampled
points per ray. For each ray, the lengths are sorted.
"""
@@ -79,9 +80,6 @@ class RayPointRefiner(Configurable, torch.nn.Module):
# Resort by depth.
z_vals, _ = torch.sort(z_vals, dim=-1)
return RayBundle(
origins=input_ray_bundle.origins,
directions=input_ray_bundle.directions,
lengths=z_vals,
xys=input_ray_bundle.xys,
)
new_bundle = ImplicitronRayBundle(**vars(input_ray_bundle))
new_bundle.lengths = z_vals
return new_bundle