mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2026-03-03 18:55:59 +08:00
Add blurpool following MIPNerf paper.
Summary: Add blurpool has defined in [MIP-NeRF](https://arxiv.org/abs/2103.13415). It has been added has an option for RayPointRefiner. Reviewed By: shapovalov Differential Revision: D46356189 fbshipit-source-id: ad841bad86d2b591a68e1cb885d4f781cf26c111
This commit is contained in:
committed by
Facebook GitHub Bot
parent
ccf860f1db
commit
5910d81b7b
@@ -32,16 +32,27 @@ class RayPointRefiner(Configurable, torch.nn.Module):
|
||||
sampling from that distribution.
|
||||
add_input_samples: Concatenates and returns the sampled values
|
||||
together with the input samples.
|
||||
blurpool_weights: Use blurpool defined in [1], on the input weights.
|
||||
sample_pdf_eps: A constant preventing division by zero in case empty bins
|
||||
are present.
|
||||
|
||||
References:
|
||||
[1] Jonathan T. Barron, et al. "Mip-NeRF: A Multiscale Representation
|
||||
for Anti-Aliasing Neural Radiance Fields." ICCV 2021.
|
||||
"""
|
||||
|
||||
n_pts_per_ray: int
|
||||
random_sampling: bool
|
||||
add_input_samples: bool = True
|
||||
blurpool_weights: bool = False
|
||||
sample_pdf_eps: float = 1e-5
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ray_bundle: ImplicitronRayBundle,
|
||||
ray_weights: torch.Tensor,
|
||||
blurpool_weights: bool = False,
|
||||
sample_pdf_padding: float = 1e-5,
|
||||
**kwargs,
|
||||
) -> ImplicitronRayBundle:
|
||||
"""
|
||||
@@ -49,28 +60,38 @@ class RayPointRefiner(Configurable, torch.nn.Module):
|
||||
input_ray_bundle: An instance of `ImplicitronRayBundle` specifying the
|
||||
source rays for sampling of the probability distribution.
|
||||
ray_weights: A tensor of shape
|
||||
`(..., input_ray_bundle.legths.shape[-1])` with non-negative
|
||||
`(..., input_ray_bundle.lengths.shape[-1])` with non-negative
|
||||
elements defining the probability distribution to sample
|
||||
ray points from.
|
||||
blurpool_weights: Use blurpool defined in [1], on the input weights.
|
||||
sample_pdf_padding: A constant preventing division by zero in case empty bins
|
||||
are present.
|
||||
|
||||
Returns:
|
||||
ray_bundle: A new `ImplicitronRayBundle` instance containing the input ray
|
||||
points together with `n_pts_per_ray` additionally sampled
|
||||
points per ray. For each ray, the lengths are sorted.
|
||||
|
||||
References:
|
||||
[1] Jonathan T. Barron, et al. "Mip-NeRF: A Multiscale Representation
|
||||
for Anti-Aliasing Neural Radiance Fields." ICCV 2021.
|
||||
|
||||
"""
|
||||
|
||||
z_vals = input_ray_bundle.lengths
|
||||
with torch.no_grad():
|
||||
if self.blurpool_weights:
|
||||
ray_weights = apply_blurpool_on_weights(ray_weights)
|
||||
|
||||
z_vals_mid = torch.lerp(z_vals[..., 1:], z_vals[..., :-1], 0.5)
|
||||
z_samples = sample_pdf(
|
||||
z_vals_mid.view(-1, z_vals_mid.shape[-1]),
|
||||
ray_weights.view(-1, ray_weights.shape[-1])[..., 1:-1],
|
||||
self.n_pts_per_ray,
|
||||
det=not self.random_sampling,
|
||||
eps=self.sample_pdf_eps,
|
||||
).view(*z_vals.shape[:-1], self.n_pts_per_ray)
|
||||
|
||||
if self.add_input_samples:
|
||||
# Add the new samples to the input ones.
|
||||
z_vals = torch.cat((z_vals, z_samples), dim=-1)
|
||||
else:
|
||||
z_vals = z_samples
|
||||
@@ -80,3 +101,31 @@ class RayPointRefiner(Configurable, torch.nn.Module):
|
||||
new_bundle = ImplicitronRayBundle(**vars(input_ray_bundle))
|
||||
new_bundle.lengths = z_vals
|
||||
return new_bundle
|
||||
|
||||
|
||||
def apply_blurpool_on_weights(weights) -> torch.Tensor:
|
||||
"""
|
||||
Filter weights with a 2-tap max filters followed by a 2-tap blur filter,
|
||||
which produces a wide and smooth upper envelope on the weights.
|
||||
|
||||
Args:
|
||||
weights: Tensor of shape `(..., dim)`
|
||||
|
||||
Returns:
|
||||
blured_weights: Tensor of shape `(..., dim)`
|
||||
"""
|
||||
weights_pad = torch.concatenate(
|
||||
[
|
||||
weights[..., :1],
|
||||
weights,
|
||||
weights[..., -1:],
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
|
||||
weights_max = torch.nn.functional.max_pool1d(
|
||||
weights_pad.flatten(end_dim=-2), 2, stride=1
|
||||
)
|
||||
return torch.lerp(weights_max[..., :-1], weights_max[..., 1:], 0.5).reshape_as(
|
||||
weights
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user