mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2026-03-02 02:05:59 +08:00
Enable mixed frame raysampling
Summary: Changed ray_sampler and metrics to be able to use mixed frame raysampling. Ray_sampler now has a new member which it passes to the pytorch3d raysampler. If the raybundle is heterogeneous metrics now samples images by padding xys first. This reduces memory consumption. Reviewed By: bottler, kjchalup Differential Revision: D39542221 fbshipit-source-id: a6fec23838d3049ae5c2fd2e1f641c46c7c927e3
This commit is contained in:
committed by
Facebook GitHub Bot
parent
ad8907d373
commit
c311a4cbb9
@@ -5,8 +5,9 @@
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import torch
|
||||
from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle
|
||||
from pytorch3d.implicitron.tools.config import Configurable, expand_args_fields
|
||||
from pytorch3d.renderer import RayBundle
|
||||
|
||||
from pytorch3d.renderer.implicit.sample_pdf import sample_pdf
|
||||
|
||||
|
||||
@@ -42,13 +43,13 @@ class RayPointRefiner(Configurable, torch.nn.Module):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ray_bundle: RayBundle,
|
||||
input_ray_bundle: ImplicitronRayBundle,
|
||||
ray_weights: torch.Tensor,
|
||||
**kwargs,
|
||||
) -> RayBundle:
|
||||
) -> ImplicitronRayBundle:
|
||||
"""
|
||||
Args:
|
||||
input_ray_bundle: An instance of `RayBundle` specifying the
|
||||
input_ray_bundle: An instance of `ImplicitronRayBundle` specifying the
|
||||
source rays for sampling of the probability distribution.
|
||||
ray_weights: A tensor of shape
|
||||
`(..., input_ray_bundle.legths.shape[-1])` with non-negative
|
||||
@@ -56,7 +57,7 @@ class RayPointRefiner(Configurable, torch.nn.Module):
|
||||
ray points from.
|
||||
|
||||
Returns:
|
||||
ray_bundle: A new `RayBundle` instance containing the input ray
|
||||
ray_bundle: A new `ImplicitronRayBundle` instance containing the input ray
|
||||
points together with `n_pts_per_ray` additionally sampled
|
||||
points per ray. For each ray, the lengths are sorted.
|
||||
"""
|
||||
@@ -79,9 +80,6 @@ class RayPointRefiner(Configurable, torch.nn.Module):
|
||||
# Resort by depth.
|
||||
z_vals, _ = torch.sort(z_vals, dim=-1)
|
||||
|
||||
return RayBundle(
|
||||
origins=input_ray_bundle.origins,
|
||||
directions=input_ray_bundle.directions,
|
||||
lengths=z_vals,
|
||||
xys=input_ray_bundle.xys,
|
||||
)
|
||||
new_bundle = ImplicitronRayBundle(**vars(input_ray_bundle))
|
||||
new_bundle.lengths = z_vals
|
||||
return new_bundle
|
||||
|
||||
Reference in New Issue
Block a user