From ccf860f1db38b839db9dcde206b6b5091ac50385 Mon Sep 17 00:00:00 2001 From: Emilien Garreau Date: Thu, 6 Jul 2023 02:20:53 -0700 Subject: [PATCH] Add integrated position encoding based on MIPNerf implementation. Summary: Add a new implicit module Integral Position Encoding based on [MIP-NeRF](https://arxiv.org/abs/2103.13415). Reviewed By: shapovalov Differential Revision: D46352730 fbshipit-source-id: c6a56134c975d80052b3a11f5e92fd7d95cbff1e --- .../implicitron_trainer/tests/experiment.yaml | 6 + .../neural_radiance_field.py | 36 +++++- .../models/implicit_function/utils.py | 7 +- .../renderer/implicit/harmonic_embedding.py | 83 ++++++++++--- ...implicit_function_neural_radiance_field.py | 66 ++++++++++ tests/test_harmonic_embedding.py | 114 +++++++++++++++++- 6 files changed, 283 insertions(+), 29 deletions(-) create mode 100644 tests/implicitron/test_implicit_function_neural_radiance_field.py diff --git a/projects/implicitron_trainer/tests/experiment.yaml b/projects/implicitron_trainer/tests/experiment.yaml index c60d444d..d9f1284b 100644 --- a/projects/implicitron_trainer/tests/experiment.yaml +++ b/projects/implicitron_trainer/tests/experiment.yaml @@ -361,6 +361,7 @@ model_factory_ImplicitronModelFactory_args: n_hidden_neurons_dir: 128 input_xyz: true xyz_ray_dir_in_camera_coords: false + use_integrated_positional_encoding: false transformer_dim_down_factor: 2.0 n_hidden_neurons_xyz: 80 n_layers_xyz: 2 @@ -372,6 +373,7 @@ model_factory_ImplicitronModelFactory_args: n_hidden_neurons_dir: 128 input_xyz: true xyz_ray_dir_in_camera_coords: false + use_integrated_positional_encoding: false transformer_dim_down_factor: 1.0 n_hidden_neurons_xyz: 256 n_layers_xyz: 8 @@ -741,6 +743,7 @@ model_factory_ImplicitronModelFactory_args: n_hidden_neurons_dir: 128 input_xyz: true xyz_ray_dir_in_camera_coords: false + use_integrated_positional_encoding: false transformer_dim_down_factor: 2.0 n_hidden_neurons_xyz: 80 n_layers_xyz: 2 @@ -752,6 +755,7 @@ 
model_factory_ImplicitronModelFactory_args: n_hidden_neurons_dir: 128 input_xyz: true xyz_ray_dir_in_camera_coords: false + use_integrated_positional_encoding: false transformer_dim_down_factor: 1.0 n_hidden_neurons_xyz: 256 n_layers_xyz: 8 @@ -979,6 +983,7 @@ model_factory_ImplicitronModelFactory_args: n_hidden_neurons_dir: 128 input_xyz: true xyz_ray_dir_in_camera_coords: false + use_integrated_positional_encoding: false transformer_dim_down_factor: 2.0 n_hidden_neurons_xyz: 80 n_layers_xyz: 2 @@ -990,6 +995,7 @@ model_factory_ImplicitronModelFactory_args: n_hidden_neurons_dir: 128 input_xyz: true xyz_ray_dir_in_camera_coords: false + use_integrated_positional_encoding: false transformer_dim_down_factor: 1.0 n_hidden_neurons_xyz: 256 n_layers_xyz: 8 diff --git a/pytorch3d/implicitron/models/implicit_function/neural_radiance_field.py b/pytorch3d/implicitron/models/implicit_function/neural_radiance_field.py index db42a4cc..0706d9a8 100644 --- a/pytorch3d/implicitron/models/implicit_function/neural_radiance_field.py +++ b/pytorch3d/implicitron/models/implicit_function/neural_radiance_field.py @@ -9,11 +9,14 @@ from typing import Optional, Tuple import torch from pytorch3d.common.linear_with_repeat import LinearWithRepeat -from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle +from pytorch3d.implicitron.models.renderer.base import ( + conical_frustum_to_gaussian, + ImplicitronRayBundle, +) from pytorch3d.implicitron.tools.config import expand_args_fields, registry -from pytorch3d.renderer import ray_bundle_to_ray_points from pytorch3d.renderer.cameras import CamerasBase from pytorch3d.renderer.implicit import HarmonicEmbedding +from pytorch3d.renderer.implicit.utils import ray_bundle_to_ray_points from .base import ImplicitFunctionBase @@ -36,6 +39,7 @@ class NeuralRadianceFieldBase(ImplicitFunctionBase, torch.nn.Module): input_xyz: bool = True xyz_ray_dir_in_camera_coords: bool = False color_dim: int = 3 + use_integrated_positional_encoding: 
bool = False """ Args: n_harmonic_functions_xyz: The number of harmonic functions @@ -53,6 +57,10 @@ class NeuralRadianceFieldBase(ImplicitFunctionBase, torch.nn.Module): n_layers_xyz: The number of layers of the MLP that outputs the occupancy field. append_xyz: The list of indices of the skip layers of the occupancy MLP. + use_integrated_positional_encoding: If True, use integrated positional encoding + as defined in `MIP-NeRF <https://arxiv.org/abs/2103.13415>`_. + If False, use the classical harmonic embedding + defined in `NeRF <https://arxiv.org/abs/2003.08934>`_. """ def __post_init__(self): @@ -149,6 +157,10 @@ class NeuralRadianceFieldBase(ImplicitFunctionBase, torch.nn.Module): containing the direction vectors of sampling rays in world coords. lengths: A tensor of shape `(minibatch, ..., num_points_per_ray)` containing the lengths at which the rays are sampled. + bins: An optional tensor of shape `(minibatch,..., num_points_per_ray + 1)` + containing the bins at which the rays are sampled. In this case + lengths is equal to the midpoints of bins. + fun_viewpool: an optional callback with the signature fun_fiewpool(points) -> pooled_features where points is a [N_TGT x N x 3] tensor of world coords, @@ -160,11 +172,22 @@ class NeuralRadianceFieldBase(ImplicitFunctionBase, torch.nn.Module): denoting the opacitiy of each ray point. rays_colors: A tensor of shape `(minibatch, ..., num_points_per_ray, 3)` denoting the color of each ray point. + + Raises: + ValueError: If `use_integrated_positional_encoding` is True and + `ray_bundle.bins` is None. """ - # We first convert the ray parametrizations to world - # coordinates with `ray_bundle_to_ray_points`. - # pyre-ignore[6] - rays_points_world = ray_bundle_to_ray_points(ray_bundle) + if self.use_integrated_positional_encoding and ray_bundle.bins is None: + raise ValueError( + "When use_integrated_positional_encoding is True, ray_bundle.bins must be set." + "Have you set to True `AbstractMaskRaySampler.use_bins_for_ray_sampling`?" 
+ ) + + rays_points_world, diag_cov = ( + conical_frustum_to_gaussian(ray_bundle) + if self.use_integrated_positional_encoding + else (ray_bundle_to_ray_points(ray_bundle), None) # pyre-ignore + ) # rays_points_world.shape = [minibatch x ... x pts_per_ray x 3] embeds = create_embeddings_for_implicit_function( @@ -177,6 +200,7 @@ class NeuralRadianceFieldBase(ImplicitFunctionBase, torch.nn.Module): fun_viewpool=fun_viewpool, xyz_in_camera_coords=self.xyz_ray_dir_in_camera_coords, camera=camera, + diag_cov=diag_cov, ) # embeds.shape = [minibatch x n_src x n_rays x n_pts x self.n_harmonic_functions*6+3] diff --git a/pytorch3d/implicitron/models/implicit_function/utils.py b/pytorch3d/implicitron/models/implicit_function/utils.py index e9b688ef..25ec3fcb 100644 --- a/pytorch3d/implicitron/models/implicit_function/utils.py +++ b/pytorch3d/implicitron/models/implicit_function/utils.py @@ -36,6 +36,7 @@ def create_embeddings_for_implicit_function( camera: Optional[CamerasBase], fun_viewpool: Optional[Callable], xyz_embedding_function: Optional[Callable], + diag_cov: Optional[torch.Tensor] = None, ) -> torch.Tensor: bs, *spatial_size, pts_per_ray, _ = xyz_world.shape @@ -59,11 +60,11 @@ def create_embeddings_for_implicit_function( prod(spatial_size), pts_per_ray, 0, - dtype=xyz_world.dtype, - device=xyz_world.device, ) else: - embeds = xyz_embedding_function(ray_points_for_embed).reshape( + + embeds = xyz_embedding_function(ray_points_for_embed, diag_cov=diag_cov) + embeds = embeds.reshape( bs, 1, prod(spatial_size), diff --git a/pytorch3d/renderer/implicit/harmonic_embedding.py b/pytorch3d/renderer/implicit/harmonic_embedding.py index 44907014..418eaa73 100644 --- a/pytorch3d/renderer/implicit/harmonic_embedding.py +++ b/pytorch3d/renderer/implicit/harmonic_embedding.py @@ -4,6 +4,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
+from typing import Optional + import torch @@ -16,8 +18,18 @@ class HarmonicEmbedding(torch.nn.Module): append_input: bool = True, ) -> None: """ - Given an input tensor `x` of shape [minibatch, ... , dim], - the harmonic embedding layer converts each feature + The harmonic embedding layer supports the classical + NeRF positional encoding described in + `NeRF <https://arxiv.org/abs/2003.08934>`_ + and the integrated position encoding in + `MIP-NeRF <https://arxiv.org/abs/2103.13415>`_. + + During inference, you can provide the extra argument `diag_cov`. + + If `diag_cov is None`, it converts + rays parametrized with a `ray_bundle` to 3D points by + extending each ray according to the corresponding length. + Then it converts each feature (i.e. vector along the last dimension) in `x` into a series of harmonic features `embedding`, where for each i in range(dim) the following are present @@ -38,6 +50,31 @@ class HarmonicEmbedding(torch.nn.Module): where N corresponds to `n_harmonic_functions-1`, and f_i is a scalar denoting the i-th frequency of the harmonic embedding. + + If `diag_cov is not None`, it approximates + conical frustums following a ray bundle as gaussians, + defined by x, the means of the gaussians and diag_cov, + the diagonal covariances. + Then it converts each gaussian + into a series of harmonic features `embedding`, + where for each i in range(dim) the following are present + in embedding[...]:: + + [ + sin(f_1*x[..., i]) * exp(-0.5 * f_1**2 * diag_cov[..., i]), + sin(f_2*x[..., i]) * exp(-0.5 * f_2**2 * diag_cov[..., i]), + ... + sin(f_N * x[..., i]) * exp(-0.5 * f_N**2 * diag_cov[..., i]), + cos(f_1*x[..., i]) * exp(-0.5 * f_1**2 * diag_cov[..., i]), + cos(f_2*x[..., i]) * exp(-0.5 * f_2**2 * diag_cov[..., i]), + ... + cos(f_N * x[..., i]) * exp(-0.5 * f_N**2 * diag_cov[..., i]), + x[..., i], # only present if append_input is True. + ] + + where N equals `n_harmonic_functions-1`, and f_i is a scalar + denoting the i-th frequency of the harmonic embedding. 
+ If `logspace==True`, the frequencies `[f_1, ..., f_N]` are powers of 2: `f_1, ..., f_N = 2**torch.arange(n_harmonic_functions)` @@ -59,8 +96,7 @@ class HarmonicEmbedding(torch.nn.Module): logspace or linear space append_input: bool, whether to concat the original input to the harmonic embedding. If true the - output is of the form (x, embed.sin(), embed.cos() - + output is of the form (embed.sin(), embed.cos(), x) """ super().__init__() @@ -78,23 +114,42 @@ class HarmonicEmbedding(torch.nn.Module): ) self.register_buffer("_frequencies", frequencies * omega_0, persistent=False) + self.register_buffer( + "_zero_half_pi", torch.tensor([0.0, 0.5 * torch.pi]), persistent=False + ) self.append_input = append_input - def forward(self, x: torch.Tensor) -> torch.Tensor: + def forward( + self, x: torch.Tensor, diag_cov: Optional[torch.Tensor] = None, **kwargs + ) -> torch.Tensor: """ Args: x: tensor of shape [..., dim] + diag_cov: An optional tensor of shape `(..., dim)` + representing the diagonal covariance matrices of our Gaussians, joined with x + as means of the Gaussians. + Returns: - embedding: a harmonic embedding of `x` - of shape [..., (n_harmonic_functions * 2 + int(append_input)) * dim] + embedding: a harmonic embedding of `x` of shape + [..., (n_harmonic_functions * 2 + int(append_input)) * dim] """ - embed = (x[..., None] * self._frequencies).reshape(*x.shape[:-1], -1) - embed = torch.cat( - (embed.sin(), embed.cos(), x) - if self.append_input - else (embed.sin(), embed.cos()), - dim=-1, - ) + # [..., dim, n_harmonic_functions] + embed = x[..., None] * self._frequencies + # [..., 1, dim, n_harmonic_functions] + [2, 1, 1] => [..., 2, dim, n_harmonic_functions] + embed = embed[..., None, :, :] + self._zero_half_pi[..., None, None] + # Use the trig identity cos(x) = sin(x + pi/2) + # and do one vectorized call to sin([x, x+pi/2]) instead of (sin(x), cos(x)). 
+ embed = embed.sin() + if diag_cov is not None: + x_var = diag_cov[..., None] * torch.pow(self._frequencies, 2) + exp_var = torch.exp(-0.5 * x_var) + # [..., 2, dim, n_harmonic_functions] + embed = embed * exp_var[..., None, :, :] + + embed = embed.reshape(*x.shape[:-1], -1) + + if self.append_input: + return torch.cat([embed, x], dim=-1) return embed @staticmethod diff --git a/tests/implicitron/test_implicit_function_neural_radiance_field.py b/tests/implicitron/test_implicit_function_neural_radiance_field.py new file mode 100644 index 00000000..f31dfcd1 --- /dev/null +++ b/tests/implicitron/test_implicit_function_neural_radiance_field.py @@ -0,0 +1,66 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import unittest + +import torch +from pytorch3d.implicitron.models.implicit_function.base import ImplicitronRayBundle +from pytorch3d.implicitron.models.implicit_function.neural_radiance_field import ( + NeuralRadianceFieldImplicitFunction, +) + + +class TestNeuralRadianceFieldImplicitFunction(unittest.TestCase): + def setUp(self): + torch.manual_seed(42) + + def test_forward_with_integrated_positionial_embedding(self): + shape = [2, 4, 4] + ray_bundle = ImplicitronRayBundle( + origins=torch.randn(*shape, 3), + directions=torch.randn(*shape, 3), + bins=torch.randn(*shape, 6 + 1), + lengths=torch.randn(*shape, 6), + pixel_radii_2d=torch.randn(*shape, 1), + xys=None, + ) + model = NeuralRadianceFieldImplicitFunction( + n_hidden_neurons_dir=32, use_integrated_positional_encoding=True + ) + raw_densities, ray_colors, _ = model(ray_bundle=ray_bundle) + + self.assertEqual(raw_densities.shape, (*shape, ray_bundle.lengths.shape[-1], 1)) + self.assertEqual(ray_colors.shape, (*shape, ray_bundle.lengths.shape[-1], 3)) + + def test_forward_with_integrated_positionial_embedding_raise_exception(self): + shape = [2, 
4, 4] + ray_bundle = ImplicitronRayBundle( + origins=torch.randn(*shape, 3), + directions=torch.randn(*shape, 3), + bins=None, + lengths=torch.randn(*shape, 6), + pixel_radii_2d=torch.randn(*shape, 1), + xys=None, + ) + model = NeuralRadianceFieldImplicitFunction( + n_hidden_neurons_dir=32, use_integrated_positional_encoding=True + ) + with self.assertRaises(ValueError): + _ = model(ray_bundle=ray_bundle) + + def test_forward(self): + shape = [2, 4, 4] + ray_bundle = ImplicitronRayBundle( + origins=torch.randn(*shape, 3), + directions=torch.randn(*shape, 3), + lengths=torch.randn(*shape, 6), + pixel_radii_2d=torch.randn(*shape, 1), + xys=None, + ) + model = NeuralRadianceFieldImplicitFunction(n_hidden_neurons_dir=32) + raw_densities, ray_colors, _ = model(ray_bundle=ray_bundle) + self.assertEqual(raw_densities.shape, (*shape, 6, 1)) + self.assertEqual(ray_colors.shape, (*shape, 6, 3)) diff --git a/tests/test_harmonic_embedding.py b/tests/test_harmonic_embedding.py index b2b35ba4..91e1ba8b 100644 --- a/tests/test_harmonic_embedding.py +++ b/tests/test_harmonic_embedding.py @@ -8,6 +8,7 @@ import unittest import torch from pytorch3d.renderer.implicit import HarmonicEmbedding +from torch.distributions import MultivariateNormal from .common_testing import TestCaseMixin @@ -36,16 +37,117 @@ class TestHarmonicEmbedding(TestCaseMixin, unittest.TestCase): self.assertClose(embed_fun_lin._frequencies, torch.FloatTensor((1.0, 2.5, 4.0))) def test_correct_embed_out(self): - embed_fun = HarmonicEmbedding(n_harmonic_functions=2, append_input=False) + n_harmonic_functions = 2 x = torch.randn((1, 5)) - D = 5 * 4 + D = 5 * n_harmonic_functions * 2 # sin + cos + + embed_fun = HarmonicEmbedding( + n_harmonic_functions=n_harmonic_functions, append_input=False + ) embed_out = embed_fun(x) + self.assertEqual(embed_out.shape, (1, D)) # Sum the squares of the respective frequencies + # cos^2(x) + sin^2(x) = 1 sum_squares = embed_out[0, : D // 2] ** 2 + embed_out[0, D // 2 :] ** 2 
self.assertClose(sum_squares, torch.ones((D // 2))) - embed_fun = HarmonicEmbedding(n_harmonic_functions=2, append_input=True) - embed_out = embed_fun(x) - self.assertClose(embed_out.shape, torch.tensor((1, 5 * 5))) + + # Test append input + embed_fun = HarmonicEmbedding( + n_harmonic_functions=n_harmonic_functions, append_input=True + ) + embed_out_appended_input = embed_fun(x) + self.assertClose( + embed_out_appended_input.shape, torch.tensor((1, D + x.shape[-1])) + ) # Last plane in output is the input - self.assertClose(embed_out[..., -5:], x) + self.assertClose(embed_out_appended_input[..., -x.shape[-1] :], x) + self.assertClose(embed_out_appended_input[..., : -x.shape[-1]], embed_out) + + def test_correct_embed_out_with_diag_cov(self): + n_harmonic_functions = 2 + x = torch.randn((1, 3)) + diag_cov = torch.randn((1, 3)) + D = 3 * n_harmonic_functions * 2 # sin + cos + + embed_fun = HarmonicEmbedding( + n_harmonic_functions=n_harmonic_functions, append_input=False + ) + embed_out = embed_fun(x, diag_cov=diag_cov) + + self.assertEqual(embed_out.shape, (1, D)) + + # Compute the scaling factor introduced in MipNerf + scale_factor = ( + -0.5 * diag_cov[..., None] * torch.pow(embed_fun._frequencies[None, :], 2) + ) + scale_factor = torch.exp(scale_factor).reshape(1, -1).tile((1, 2)) + # If we remove this scaling factor, we should go back to the + # classical harmonic embedding: + # Sum the squares of the respective frequencies + # cos^2(x) + sin^2(x) = 1 + embed_out_without_cov = embed_out / scale_factor + sum_squares = ( + embed_out_without_cov[0, : D // 2] ** 2 + + embed_out_without_cov[0, D // 2 :] ** 2 + ) + self.assertClose(sum_squares, torch.ones((D // 2))) + + # Test append input + embed_fun = HarmonicEmbedding( + n_harmonic_functions=n_harmonic_functions, append_input=True + ) + embed_out_appended_input = embed_fun(x, diag_cov=diag_cov) + self.assertClose( + embed_out_appended_input.shape, torch.tensor((1, D + x.shape[-1])) + ) + # Last plane in output is 
the input + self.assertClose(embed_out_appended_input[..., -x.shape[-1] :], x) + self.assertClose(embed_out_appended_input[..., : -x.shape[-1]], embed_out) + + def test_correct_behavior_between_ipe_and_its_estimation_from_harmonic_embedding( + self, + ): + """ + Check that the HarmonicEmbedding called with `diag_cov`, i.e. the integrated + position encoding (IPE), is coherent with the plain HarmonicEmbedding. + + What is the idea behind this test? + + We wish to produce an IPE that is the expectation + of our lifted multivariate gaussian, modulated by the sine and cosine of + the coordinates. This expectation has a closed form + (see equations 11, 12, 13, 14 of [1]). + + We sample N elements from the multivariate gaussian defined by its mean and covariance + and compute the HarmonicEmbedding. The expected value of those embeddings should be + equal to our IPE. + + Inspired by: + https://github.com/google/mipnerf/blob/84c969e0a623edd183b75693aed72a7e7c22902d/internal/mip_test.py#L359 + + References: + [1] `MIP-NeRF <https://arxiv.org/abs/2103.13415>`_. + """ + num_dims = 3 + n_harmonic_functions = 6 + mean = torch.randn(num_dims) + diag_cov = torch.rand(num_dims) + + he_fun = HarmonicEmbedding( + n_harmonic_functions=n_harmonic_functions, logspace=True, append_input=False + ) + ipe_fun = HarmonicEmbedding( + n_harmonic_functions=n_harmonic_functions, + append_input=False, + ) + + embedding_ipe = ipe_fun(mean, diag_cov=diag_cov) + + rand_mvn = MultivariateNormal(mean, torch.eye(num_dims) * diag_cov) + + # Providing a large enough number of samples + # we should obtain an estimation close to our IPE + num_samples = 100000 + embedding_he = he_fun(rand_mvn.sample_n(num_samples)) + self.assertClose(embedding_he.mean(0), embedding_ipe, rtol=1e-2, atol=1e-2)