mirror of
				https://github.com/facebookresearch/pytorch3d.git
				synced 2025-11-04 18:02:14 +08:00 
			
		
		
		
	Add blurpool following MIPNerf paper.
Summary: Add blurpool has defined in [MIP-NeRF](https://arxiv.org/abs/2103.13415). It has been added has an option for RayPointRefiner. Reviewed By: shapovalov Differential Revision: D46356189 fbshipit-source-id: ad841bad86d2b591a68e1cb885d4f781cf26c111
This commit is contained in:
		
							parent
							
								
									ccf860f1db
								
							
						
					
					
						commit
						5910d81b7b
					
				@ -249,6 +249,8 @@ model_factory_ImplicitronModelFactory_args:
 | 
			
		||||
      append_coarse_samples_to_fine: true
 | 
			
		||||
      density_noise_std_train: 0.0
 | 
			
		||||
      return_weights: false
 | 
			
		||||
      blurpool_weights: false
 | 
			
		||||
      sample_pdf_eps: 1.0e-05
 | 
			
		||||
      raymarcher_CumsumRaymarcher_args:
 | 
			
		||||
        surface_thickness: 1
 | 
			
		||||
        bg_color:
 | 
			
		||||
@ -679,6 +681,8 @@ model_factory_ImplicitronModelFactory_args:
 | 
			
		||||
      append_coarse_samples_to_fine: true
 | 
			
		||||
      density_noise_std_train: 0.0
 | 
			
		||||
      return_weights: false
 | 
			
		||||
      blurpool_weights: false
 | 
			
		||||
      sample_pdf_eps: 1.0e-05
 | 
			
		||||
      raymarcher_CumsumRaymarcher_args:
 | 
			
		||||
        surface_thickness: 1
 | 
			
		||||
        bg_color:
 | 
			
		||||
 | 
			
		||||
@ -65,6 +65,9 @@ class MultiPassEmissionAbsorptionRenderer(  # pyre-ignore: 13
 | 
			
		||||
            opacity field.
 | 
			
		||||
        return_weights: Enables returning the rendering weights of the EA raymarcher.
 | 
			
		||||
            Setting to `True` can lead to a prohibitivelly large memory consumption.
 | 
			
		||||
        blurpool_weights: Use blurpool defined in [3], on the input weights of
 | 
			
		||||
            each implicit_function except the first (implicit_functions[0]).
 | 
			
		||||
        sample_pdf_eps: Padding applied to the weights (alpha in equation 18 of [3]).
 | 
			
		||||
        raymarcher_class_type: The type of self.raymarcher corresponding to
 | 
			
		||||
            a child of `RaymarcherBase` in the registry.
 | 
			
		||||
        raymarcher: The raymarcher object used to convert per-point features
 | 
			
		||||
@ -75,6 +78,8 @@ class MultiPassEmissionAbsorptionRenderer(  # pyre-ignore: 13
 | 
			
		||||
            Fields for View Synthesis." ECCV 2020.
 | 
			
		||||
        [2] Lombardi, Stephen, et al. "Neural Volumes: Learning Dynamic Renderable
 | 
			
		||||
            Volumes from Images." SIGGRAPH 2019.
 | 
			
		||||
        [3] Jonathan T. Barron, et al. "Mip-NeRF: A Multiscale Representation
 | 
			
		||||
            for Anti-Aliasing Neural Radiance Fields." ICCV 2021.
 | 
			
		||||
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
@ -88,6 +93,8 @@ class MultiPassEmissionAbsorptionRenderer(  # pyre-ignore: 13
 | 
			
		||||
    append_coarse_samples_to_fine: bool = True
 | 
			
		||||
    density_noise_std_train: float = 0.0
 | 
			
		||||
    return_weights: bool = False
 | 
			
		||||
    blurpool_weights: bool = False
 | 
			
		||||
    sample_pdf_eps: float = 1e-5
 | 
			
		||||
 | 
			
		||||
    def __post_init__(self):
 | 
			
		||||
        self._refiners = {
 | 
			
		||||
@ -95,11 +102,15 @@ class MultiPassEmissionAbsorptionRenderer(  # pyre-ignore: 13
 | 
			
		||||
                n_pts_per_ray=self.n_pts_per_ray_fine_training,
 | 
			
		||||
                random_sampling=self.stratified_sampling_coarse_training,
 | 
			
		||||
                add_input_samples=self.append_coarse_samples_to_fine,
 | 
			
		||||
                blurpool_weights=self.blurpool_weights,
 | 
			
		||||
                sample_pdf_eps=self.sample_pdf_eps,
 | 
			
		||||
            ),
 | 
			
		||||
            EvaluationMode.EVALUATION: RayPointRefiner(
 | 
			
		||||
                n_pts_per_ray=self.n_pts_per_ray_fine_evaluation,
 | 
			
		||||
                random_sampling=self.stratified_sampling_coarse_evaluation,
 | 
			
		||||
                add_input_samples=self.append_coarse_samples_to_fine,
 | 
			
		||||
                blurpool_weights=self.blurpool_weights,
 | 
			
		||||
                sample_pdf_eps=self.sample_pdf_eps,
 | 
			
		||||
            ),
 | 
			
		||||
        }
 | 
			
		||||
        run_auto_creation(self)
 | 
			
		||||
 | 
			
		||||
@ -32,16 +32,27 @@ class RayPointRefiner(Configurable, torch.nn.Module):
 | 
			
		||||
            sampling from that distribution.
 | 
			
		||||
        add_input_samples: Concatenates and returns the sampled values
 | 
			
		||||
            together with the input samples.
 | 
			
		||||
        blurpool_weights: Use blurpool defined in [1], on the input weights.
 | 
			
		||||
        sample_pdf_eps: A constant preventing division by zero in case empty bins
 | 
			
		||||
            are present.
 | 
			
		||||
 | 
			
		||||
    References:
 | 
			
		||||
        [1] Jonathan T. Barron, et al. "Mip-NeRF: A Multiscale Representation
 | 
			
		||||
            for Anti-Aliasing Neural Radiance Fields." ICCV 2021.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    n_pts_per_ray: int
 | 
			
		||||
    random_sampling: bool
 | 
			
		||||
    add_input_samples: bool = True
 | 
			
		||||
    blurpool_weights: bool = False
 | 
			
		||||
    sample_pdf_eps: float = 1e-5
 | 
			
		||||
 | 
			
		||||
    def forward(
 | 
			
		||||
        self,
 | 
			
		||||
        input_ray_bundle: ImplicitronRayBundle,
 | 
			
		||||
        ray_weights: torch.Tensor,
 | 
			
		||||
        blurpool_weights: bool = False,
 | 
			
		||||
        sample_pdf_padding: float = 1e-5,
 | 
			
		||||
        **kwargs,
 | 
			
		||||
    ) -> ImplicitronRayBundle:
 | 
			
		||||
        """
 | 
			
		||||
@ -49,28 +60,38 @@ class RayPointRefiner(Configurable, torch.nn.Module):
 | 
			
		||||
            input_ray_bundle: An instance of `ImplicitronRayBundle` specifying the
 | 
			
		||||
                source rays for sampling of the probability distribution.
 | 
			
		||||
            ray_weights: A tensor of shape
 | 
			
		||||
                `(..., input_ray_bundle.legths.shape[-1])` with non-negative
 | 
			
		||||
                `(..., input_ray_bundle.lengths.shape[-1])` with non-negative
 | 
			
		||||
                elements defining the probability distribution to sample
 | 
			
		||||
                ray points from.
 | 
			
		||||
            blurpool_weights: Use blurpool defined in [1], on the input weights.
 | 
			
		||||
            sample_pdf_padding: A constant preventing division by zero in case empty bins
 | 
			
		||||
                are present.
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            ray_bundle: A new `ImplicitronRayBundle` instance containing the input ray
 | 
			
		||||
                points together with `n_pts_per_ray` additionally sampled
 | 
			
		||||
                points per ray. For each ray, the lengths are sorted.
 | 
			
		||||
 | 
			
		||||
        References:
 | 
			
		||||
            [1] Jonathan T. Barron, et al. "Mip-NeRF: A Multiscale Representation
 | 
			
		||||
                for Anti-Aliasing Neural Radiance Fields." ICCV 2021.
 | 
			
		||||
 | 
			
		||||
        """
 | 
			
		||||
 | 
			
		||||
        z_vals = input_ray_bundle.lengths
 | 
			
		||||
        with torch.no_grad():
 | 
			
		||||
            if self.blurpool_weights:
 | 
			
		||||
                ray_weights = apply_blurpool_on_weights(ray_weights)
 | 
			
		||||
 | 
			
		||||
            z_vals_mid = torch.lerp(z_vals[..., 1:], z_vals[..., :-1], 0.5)
 | 
			
		||||
            z_samples = sample_pdf(
 | 
			
		||||
                z_vals_mid.view(-1, z_vals_mid.shape[-1]),
 | 
			
		||||
                ray_weights.view(-1, ray_weights.shape[-1])[..., 1:-1],
 | 
			
		||||
                self.n_pts_per_ray,
 | 
			
		||||
                det=not self.random_sampling,
 | 
			
		||||
                eps=self.sample_pdf_eps,
 | 
			
		||||
            ).view(*z_vals.shape[:-1], self.n_pts_per_ray)
 | 
			
		||||
 | 
			
		||||
        if self.add_input_samples:
 | 
			
		||||
            # Add the new samples to the input ones.
 | 
			
		||||
            z_vals = torch.cat((z_vals, z_samples), dim=-1)
 | 
			
		||||
        else:
 | 
			
		||||
            z_vals = z_samples
 | 
			
		||||
@ -80,3 +101,31 @@ class RayPointRefiner(Configurable, torch.nn.Module):
 | 
			
		||||
        new_bundle = ImplicitronRayBundle(**vars(input_ray_bundle))
 | 
			
		||||
        new_bundle.lengths = z_vals
 | 
			
		||||
        return new_bundle
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def apply_blurpool_on_weights(weights) -> torch.Tensor:
 | 
			
		||||
    """
 | 
			
		||||
    Filter weights with a 2-tap max filters followed by a 2-tap blur filter,
 | 
			
		||||
    which produces a wide and smooth upper envelope on the weights.
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
        weights: Tensor of shape `(..., dim)`
 | 
			
		||||
 | 
			
		||||
    Returns:
 | 
			
		||||
        blured_weights: Tensor of shape `(..., dim)`
 | 
			
		||||
    """
 | 
			
		||||
    weights_pad = torch.concatenate(
 | 
			
		||||
        [
 | 
			
		||||
            weights[..., :1],
 | 
			
		||||
            weights,
 | 
			
		||||
            weights[..., -1:],
 | 
			
		||||
        ],
 | 
			
		||||
        dim=-1,
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    weights_max = torch.nn.functional.max_pool1d(
 | 
			
		||||
        weights_pad.flatten(end_dim=-2), 2, stride=1
 | 
			
		||||
    )
 | 
			
		||||
    return torch.lerp(weights_max[..., :-1], weights_max[..., 1:], 0.5).reshape_as(
 | 
			
		||||
        weights
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
@ -5,9 +5,14 @@
 | 
			
		||||
# LICENSE file in the root directory of this source tree.
 | 
			
		||||
 | 
			
		||||
import unittest
 | 
			
		||||
from itertools import product
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
from pytorch3d.implicitron.models.renderer.ray_point_refiner import RayPointRefiner
 | 
			
		||||
 | 
			
		||||
from pytorch3d.implicitron.models.renderer.ray_point_refiner import (
 | 
			
		||||
    apply_blurpool_on_weights,
 | 
			
		||||
    RayPointRefiner,
 | 
			
		||||
)
 | 
			
		||||
from pytorch3d.implicitron.models.renderer.ray_sampler import ImplicitronRayBundle
 | 
			
		||||
from tests.common_testing import TestCaseMixin
 | 
			
		||||
 | 
			
		||||
@ -17,11 +22,12 @@ class TestRayPointRefiner(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
        length = 15
 | 
			
		||||
        n_pts_per_ray = 10
 | 
			
		||||
 | 
			
		||||
        for add_input_samples in [False, True]:
 | 
			
		||||
        for add_input_samples, use_blurpool in product([False, True], [False, True]):
 | 
			
		||||
            ray_point_refiner = RayPointRefiner(
 | 
			
		||||
                n_pts_per_ray=n_pts_per_ray,
 | 
			
		||||
                random_sampling=False,
 | 
			
		||||
                add_input_samples=add_input_samples,
 | 
			
		||||
                blurpool_weights=use_blurpool,
 | 
			
		||||
            )
 | 
			
		||||
            lengths = torch.arange(length, dtype=torch.float32).expand(3, 25, length)
 | 
			
		||||
            bundle = ImplicitronRayBundle(
 | 
			
		||||
@ -50,6 +56,7 @@ class TestRayPointRefiner(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
                n_pts_per_ray=n_pts_per_ray,
 | 
			
		||||
                random_sampling=True,
 | 
			
		||||
                add_input_samples=add_input_samples,
 | 
			
		||||
                blurpool_weights=use_blurpool,
 | 
			
		||||
            )
 | 
			
		||||
            refined_random = ray_point_refiner_random(bundle, weights)
 | 
			
		||||
            lengths_random = refined_random.lengths
 | 
			
		||||
@ -62,3 +69,24 @@ class TestRayPointRefiner(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
            self.assertTrue(
 | 
			
		||||
                (lengths_random[..., 1:] - lengths_random[..., :-1] > 0).all()
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
    def test_apply_blurpool_on_weights(self):
 | 
			
		||||
        weights = torch.tensor(
 | 
			
		||||
            [
 | 
			
		||||
                [0.5, 0.6, 0.7],
 | 
			
		||||
                [0.5, 0.3, 0.9],
 | 
			
		||||
            ]
 | 
			
		||||
        )
 | 
			
		||||
        expected_weights = 0.5 * torch.tensor(
 | 
			
		||||
            [
 | 
			
		||||
                [0.5 + 0.6, 0.6 + 0.7, 0.7 + 0.7],
 | 
			
		||||
                [0.5 + 0.5, 0.5 + 0.9, 0.9 + 0.9],
 | 
			
		||||
            ]
 | 
			
		||||
        )
 | 
			
		||||
        out_weights = apply_blurpool_on_weights(weights)
 | 
			
		||||
        self.assertTrue(torch.allclose(out_weights, expected_weights))
 | 
			
		||||
 | 
			
		||||
    def test_shapes_apply_blurpool_on_weights(self):
 | 
			
		||||
        weights = torch.randn((5, 4, 3, 2, 1))
 | 
			
		||||
        out_weights = apply_blurpool_on_weights(weights)
 | 
			
		||||
        self.assertEqual((5, 4, 3, 2, 1), out_weights.shape)
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user