mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
Add blurpool following MIPNerf paper.
Summary: Add blurpool has defined in [MIP-NeRF](https://arxiv.org/abs/2103.13415). It has been added has an option for RayPointRefiner. Reviewed By: shapovalov Differential Revision: D46356189 fbshipit-source-id: ad841bad86d2b591a68e1cb885d4f781cf26c111
This commit is contained in:
parent
ccf860f1db
commit
5910d81b7b
@ -249,6 +249,8 @@ model_factory_ImplicitronModelFactory_args:
|
|||||||
append_coarse_samples_to_fine: true
|
append_coarse_samples_to_fine: true
|
||||||
density_noise_std_train: 0.0
|
density_noise_std_train: 0.0
|
||||||
return_weights: false
|
return_weights: false
|
||||||
|
blurpool_weights: false
|
||||||
|
sample_pdf_eps: 1.0e-05
|
||||||
raymarcher_CumsumRaymarcher_args:
|
raymarcher_CumsumRaymarcher_args:
|
||||||
surface_thickness: 1
|
surface_thickness: 1
|
||||||
bg_color:
|
bg_color:
|
||||||
@ -679,6 +681,8 @@ model_factory_ImplicitronModelFactory_args:
|
|||||||
append_coarse_samples_to_fine: true
|
append_coarse_samples_to_fine: true
|
||||||
density_noise_std_train: 0.0
|
density_noise_std_train: 0.0
|
||||||
return_weights: false
|
return_weights: false
|
||||||
|
blurpool_weights: false
|
||||||
|
sample_pdf_eps: 1.0e-05
|
||||||
raymarcher_CumsumRaymarcher_args:
|
raymarcher_CumsumRaymarcher_args:
|
||||||
surface_thickness: 1
|
surface_thickness: 1
|
||||||
bg_color:
|
bg_color:
|
||||||
|
@ -65,6 +65,9 @@ class MultiPassEmissionAbsorptionRenderer( # pyre-ignore: 13
|
|||||||
opacity field.
|
opacity field.
|
||||||
return_weights: Enables returning the rendering weights of the EA raymarcher.
|
return_weights: Enables returning the rendering weights of the EA raymarcher.
|
||||||
Setting to `True` can lead to a prohibitivelly large memory consumption.
|
Setting to `True` can lead to a prohibitivelly large memory consumption.
|
||||||
|
blurpool_weights: Use blurpool defined in [3], on the input weights of
|
||||||
|
each implicit_function except the first (implicit_functions[0]).
|
||||||
|
sample_pdf_eps: Padding applied to the weights (alpha in equation 18 of [3]).
|
||||||
raymarcher_class_type: The type of self.raymarcher corresponding to
|
raymarcher_class_type: The type of self.raymarcher corresponding to
|
||||||
a child of `RaymarcherBase` in the registry.
|
a child of `RaymarcherBase` in the registry.
|
||||||
raymarcher: The raymarcher object used to convert per-point features
|
raymarcher: The raymarcher object used to convert per-point features
|
||||||
@ -75,6 +78,8 @@ class MultiPassEmissionAbsorptionRenderer( # pyre-ignore: 13
|
|||||||
Fields for View Synthesis." ECCV 2020.
|
Fields for View Synthesis." ECCV 2020.
|
||||||
[2] Lombardi, Stephen, et al. "Neural Volumes: Learning Dynamic Renderable
|
[2] Lombardi, Stephen, et al. "Neural Volumes: Learning Dynamic Renderable
|
||||||
Volumes from Images." SIGGRAPH 2019.
|
Volumes from Images." SIGGRAPH 2019.
|
||||||
|
[3] Jonathan T. Barron, et al. "Mip-NeRF: A Multiscale Representation
|
||||||
|
for Anti-Aliasing Neural Radiance Fields." ICCV 2021.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@ -88,6 +93,8 @@ class MultiPassEmissionAbsorptionRenderer( # pyre-ignore: 13
|
|||||||
append_coarse_samples_to_fine: bool = True
|
append_coarse_samples_to_fine: bool = True
|
||||||
density_noise_std_train: float = 0.0
|
density_noise_std_train: float = 0.0
|
||||||
return_weights: bool = False
|
return_weights: bool = False
|
||||||
|
blurpool_weights: bool = False
|
||||||
|
sample_pdf_eps: float = 1e-5
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
self._refiners = {
|
self._refiners = {
|
||||||
@ -95,11 +102,15 @@ class MultiPassEmissionAbsorptionRenderer( # pyre-ignore: 13
|
|||||||
n_pts_per_ray=self.n_pts_per_ray_fine_training,
|
n_pts_per_ray=self.n_pts_per_ray_fine_training,
|
||||||
random_sampling=self.stratified_sampling_coarse_training,
|
random_sampling=self.stratified_sampling_coarse_training,
|
||||||
add_input_samples=self.append_coarse_samples_to_fine,
|
add_input_samples=self.append_coarse_samples_to_fine,
|
||||||
|
blurpool_weights=self.blurpool_weights,
|
||||||
|
sample_pdf_eps=self.sample_pdf_eps,
|
||||||
),
|
),
|
||||||
EvaluationMode.EVALUATION: RayPointRefiner(
|
EvaluationMode.EVALUATION: RayPointRefiner(
|
||||||
n_pts_per_ray=self.n_pts_per_ray_fine_evaluation,
|
n_pts_per_ray=self.n_pts_per_ray_fine_evaluation,
|
||||||
random_sampling=self.stratified_sampling_coarse_evaluation,
|
random_sampling=self.stratified_sampling_coarse_evaluation,
|
||||||
add_input_samples=self.append_coarse_samples_to_fine,
|
add_input_samples=self.append_coarse_samples_to_fine,
|
||||||
|
blurpool_weights=self.blurpool_weights,
|
||||||
|
sample_pdf_eps=self.sample_pdf_eps,
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
run_auto_creation(self)
|
run_auto_creation(self)
|
||||||
|
@ -32,16 +32,27 @@ class RayPointRefiner(Configurable, torch.nn.Module):
|
|||||||
sampling from that distribution.
|
sampling from that distribution.
|
||||||
add_input_samples: Concatenates and returns the sampled values
|
add_input_samples: Concatenates and returns the sampled values
|
||||||
together with the input samples.
|
together with the input samples.
|
||||||
|
blurpool_weights: Use blurpool defined in [1], on the input weights.
|
||||||
|
sample_pdf_eps: A constant preventing division by zero in case empty bins
|
||||||
|
are present.
|
||||||
|
|
||||||
|
References:
|
||||||
|
[1] Jonathan T. Barron, et al. "Mip-NeRF: A Multiscale Representation
|
||||||
|
for Anti-Aliasing Neural Radiance Fields." ICCV 2021.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
n_pts_per_ray: int
|
n_pts_per_ray: int
|
||||||
random_sampling: bool
|
random_sampling: bool
|
||||||
add_input_samples: bool = True
|
add_input_samples: bool = True
|
||||||
|
blurpool_weights: bool = False
|
||||||
|
sample_pdf_eps: float = 1e-5
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ray_bundle: ImplicitronRayBundle,
|
input_ray_bundle: ImplicitronRayBundle,
|
||||||
ray_weights: torch.Tensor,
|
ray_weights: torch.Tensor,
|
||||||
|
blurpool_weights: bool = False,
|
||||||
|
sample_pdf_padding: float = 1e-5,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> ImplicitronRayBundle:
|
) -> ImplicitronRayBundle:
|
||||||
"""
|
"""
|
||||||
@ -49,28 +60,38 @@ class RayPointRefiner(Configurable, torch.nn.Module):
|
|||||||
input_ray_bundle: An instance of `ImplicitronRayBundle` specifying the
|
input_ray_bundle: An instance of `ImplicitronRayBundle` specifying the
|
||||||
source rays for sampling of the probability distribution.
|
source rays for sampling of the probability distribution.
|
||||||
ray_weights: A tensor of shape
|
ray_weights: A tensor of shape
|
||||||
`(..., input_ray_bundle.legths.shape[-1])` with non-negative
|
`(..., input_ray_bundle.lengths.shape[-1])` with non-negative
|
||||||
elements defining the probability distribution to sample
|
elements defining the probability distribution to sample
|
||||||
ray points from.
|
ray points from.
|
||||||
|
blurpool_weights: Use blurpool defined in [1], on the input weights.
|
||||||
|
sample_pdf_padding: A constant preventing division by zero in case empty bins
|
||||||
|
are present.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
ray_bundle: A new `ImplicitronRayBundle` instance containing the input ray
|
ray_bundle: A new `ImplicitronRayBundle` instance containing the input ray
|
||||||
points together with `n_pts_per_ray` additionally sampled
|
points together with `n_pts_per_ray` additionally sampled
|
||||||
points per ray. For each ray, the lengths are sorted.
|
points per ray. For each ray, the lengths are sorted.
|
||||||
|
|
||||||
|
References:
|
||||||
|
[1] Jonathan T. Barron, et al. "Mip-NeRF: A Multiscale Representation
|
||||||
|
for Anti-Aliasing Neural Radiance Fields." ICCV 2021.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
z_vals = input_ray_bundle.lengths
|
z_vals = input_ray_bundle.lengths
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
|
if self.blurpool_weights:
|
||||||
|
ray_weights = apply_blurpool_on_weights(ray_weights)
|
||||||
|
|
||||||
z_vals_mid = torch.lerp(z_vals[..., 1:], z_vals[..., :-1], 0.5)
|
z_vals_mid = torch.lerp(z_vals[..., 1:], z_vals[..., :-1], 0.5)
|
||||||
z_samples = sample_pdf(
|
z_samples = sample_pdf(
|
||||||
z_vals_mid.view(-1, z_vals_mid.shape[-1]),
|
z_vals_mid.view(-1, z_vals_mid.shape[-1]),
|
||||||
ray_weights.view(-1, ray_weights.shape[-1])[..., 1:-1],
|
ray_weights.view(-1, ray_weights.shape[-1])[..., 1:-1],
|
||||||
self.n_pts_per_ray,
|
self.n_pts_per_ray,
|
||||||
det=not self.random_sampling,
|
det=not self.random_sampling,
|
||||||
|
eps=self.sample_pdf_eps,
|
||||||
).view(*z_vals.shape[:-1], self.n_pts_per_ray)
|
).view(*z_vals.shape[:-1], self.n_pts_per_ray)
|
||||||
|
|
||||||
if self.add_input_samples:
|
if self.add_input_samples:
|
||||||
# Add the new samples to the input ones.
|
|
||||||
z_vals = torch.cat((z_vals, z_samples), dim=-1)
|
z_vals = torch.cat((z_vals, z_samples), dim=-1)
|
||||||
else:
|
else:
|
||||||
z_vals = z_samples
|
z_vals = z_samples
|
||||||
@ -80,3 +101,31 @@ class RayPointRefiner(Configurable, torch.nn.Module):
|
|||||||
new_bundle = ImplicitronRayBundle(**vars(input_ray_bundle))
|
new_bundle = ImplicitronRayBundle(**vars(input_ray_bundle))
|
||||||
new_bundle.lengths = z_vals
|
new_bundle.lengths = z_vals
|
||||||
return new_bundle
|
return new_bundle
|
||||||
|
|
||||||
|
|
||||||
|
def apply_blurpool_on_weights(weights) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Filter weights with a 2-tap max filters followed by a 2-tap blur filter,
|
||||||
|
which produces a wide and smooth upper envelope on the weights.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
weights: Tensor of shape `(..., dim)`
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
blured_weights: Tensor of shape `(..., dim)`
|
||||||
|
"""
|
||||||
|
weights_pad = torch.concatenate(
|
||||||
|
[
|
||||||
|
weights[..., :1],
|
||||||
|
weights,
|
||||||
|
weights[..., -1:],
|
||||||
|
],
|
||||||
|
dim=-1,
|
||||||
|
)
|
||||||
|
|
||||||
|
weights_max = torch.nn.functional.max_pool1d(
|
||||||
|
weights_pad.flatten(end_dim=-2), 2, stride=1
|
||||||
|
)
|
||||||
|
return torch.lerp(weights_max[..., :-1], weights_max[..., 1:], 0.5).reshape_as(
|
||||||
|
weights
|
||||||
|
)
|
||||||
|
@ -5,9 +5,14 @@
|
|||||||
# LICENSE file in the root directory of this source tree.
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
|
from itertools import product
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from pytorch3d.implicitron.models.renderer.ray_point_refiner import RayPointRefiner
|
|
||||||
|
from pytorch3d.implicitron.models.renderer.ray_point_refiner import (
|
||||||
|
apply_blurpool_on_weights,
|
||||||
|
RayPointRefiner,
|
||||||
|
)
|
||||||
from pytorch3d.implicitron.models.renderer.ray_sampler import ImplicitronRayBundle
|
from pytorch3d.implicitron.models.renderer.ray_sampler import ImplicitronRayBundle
|
||||||
from tests.common_testing import TestCaseMixin
|
from tests.common_testing import TestCaseMixin
|
||||||
|
|
||||||
@ -17,11 +22,12 @@ class TestRayPointRefiner(TestCaseMixin, unittest.TestCase):
|
|||||||
length = 15
|
length = 15
|
||||||
n_pts_per_ray = 10
|
n_pts_per_ray = 10
|
||||||
|
|
||||||
for add_input_samples in [False, True]:
|
for add_input_samples, use_blurpool in product([False, True], [False, True]):
|
||||||
ray_point_refiner = RayPointRefiner(
|
ray_point_refiner = RayPointRefiner(
|
||||||
n_pts_per_ray=n_pts_per_ray,
|
n_pts_per_ray=n_pts_per_ray,
|
||||||
random_sampling=False,
|
random_sampling=False,
|
||||||
add_input_samples=add_input_samples,
|
add_input_samples=add_input_samples,
|
||||||
|
blurpool_weights=use_blurpool,
|
||||||
)
|
)
|
||||||
lengths = torch.arange(length, dtype=torch.float32).expand(3, 25, length)
|
lengths = torch.arange(length, dtype=torch.float32).expand(3, 25, length)
|
||||||
bundle = ImplicitronRayBundle(
|
bundle = ImplicitronRayBundle(
|
||||||
@ -50,6 +56,7 @@ class TestRayPointRefiner(TestCaseMixin, unittest.TestCase):
|
|||||||
n_pts_per_ray=n_pts_per_ray,
|
n_pts_per_ray=n_pts_per_ray,
|
||||||
random_sampling=True,
|
random_sampling=True,
|
||||||
add_input_samples=add_input_samples,
|
add_input_samples=add_input_samples,
|
||||||
|
blurpool_weights=use_blurpool,
|
||||||
)
|
)
|
||||||
refined_random = ray_point_refiner_random(bundle, weights)
|
refined_random = ray_point_refiner_random(bundle, weights)
|
||||||
lengths_random = refined_random.lengths
|
lengths_random = refined_random.lengths
|
||||||
@ -62,3 +69,24 @@ class TestRayPointRefiner(TestCaseMixin, unittest.TestCase):
|
|||||||
self.assertTrue(
|
self.assertTrue(
|
||||||
(lengths_random[..., 1:] - lengths_random[..., :-1] > 0).all()
|
(lengths_random[..., 1:] - lengths_random[..., :-1] > 0).all()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_apply_blurpool_on_weights(self):
|
||||||
|
weights = torch.tensor(
|
||||||
|
[
|
||||||
|
[0.5, 0.6, 0.7],
|
||||||
|
[0.5, 0.3, 0.9],
|
||||||
|
]
|
||||||
|
)
|
||||||
|
expected_weights = 0.5 * torch.tensor(
|
||||||
|
[
|
||||||
|
[0.5 + 0.6, 0.6 + 0.7, 0.7 + 0.7],
|
||||||
|
[0.5 + 0.5, 0.5 + 0.9, 0.9 + 0.9],
|
||||||
|
]
|
||||||
|
)
|
||||||
|
out_weights = apply_blurpool_on_weights(weights)
|
||||||
|
self.assertTrue(torch.allclose(out_weights, expected_weights))
|
||||||
|
|
||||||
|
def test_shapes_apply_blurpool_on_weights(self):
|
||||||
|
weights = torch.randn((5, 4, 3, 2, 1))
|
||||||
|
out_weights = apply_blurpool_on_weights(weights)
|
||||||
|
self.assertEqual((5, 4, 3, 2, 1), out_weights.shape)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user