Add integrated positional encoding based on the MIP-NeRF implementation.

Summary: Add a new implicit module, Integrated Positional Encoding, based on [MIP-NeRF](https://arxiv.org/abs/2103.13415).

Reviewed By: shapovalov

Differential Revision: D46352730

fbshipit-source-id: c6a56134c975d80052b3a11f5e92fd7d95cbff1e
Emilien Garreau 2023-07-06 02:20:53 -07:00 committed by Facebook GitHub Bot
parent 29b8ebd802
commit ccf860f1db
6 changed files with 283 additions and 29 deletions
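For context, a minimal end-to-end sketch of the new flag, adapted from the unit tests added in this commit (tensor sizes are illustrative; in a real pipeline `lengths` holds the midpoints of `bins`):

import torch
from pytorch3d.implicitron.models.implicit_function.base import ImplicitronRayBundle
from pytorch3d.implicitron.models.implicit_function.neural_radiance_field import (
    NeuralRadianceFieldImplicitFunction,
)

shape = [2, 4, 4]
ray_bundle = ImplicitronRayBundle(
    origins=torch.randn(*shape, 3),
    directions=torch.randn(*shape, 3),
    bins=torch.randn(*shape, 6 + 1),  # n + 1 bin edges per ray
    lengths=torch.randn(*shape, 6),   # n sample points per ray
    pixel_radii_2d=torch.randn(*shape, 1),
    xys=None,
)
model = NeuralRadianceFieldImplicitFunction(
    n_hidden_neurons_dir=32, use_integrated_positional_encoding=True
)
# Returns densities of shape [..., 6, 1] and colors [..., 6, 3], as the tests assert.
raw_densities, ray_colors, _ = model(ray_bundle=ray_bundle)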

View File

@@ -361,6 +361,7 @@ model_factory_ImplicitronModelFactory_args:
n_hidden_neurons_dir: 128
input_xyz: true
xyz_ray_dir_in_camera_coords: false
use_integrated_positional_encoding: false
transformer_dim_down_factor: 2.0
n_hidden_neurons_xyz: 80
n_layers_xyz: 2
@@ -372,6 +373,7 @@ model_factory_ImplicitronModelFactory_args:
n_hidden_neurons_dir: 128
input_xyz: true
xyz_ray_dir_in_camera_coords: false
use_integrated_positional_encoding: false
transformer_dim_down_factor: 1.0
n_hidden_neurons_xyz: 256
n_layers_xyz: 8
@@ -741,6 +743,7 @@ model_factory_ImplicitronModelFactory_args:
n_hidden_neurons_dir: 128
input_xyz: true
xyz_ray_dir_in_camera_coords: false
use_integrated_positional_encoding: false
transformer_dim_down_factor: 2.0
n_hidden_neurons_xyz: 80
n_layers_xyz: 2
@@ -752,6 +755,7 @@ model_factory_ImplicitronModelFactory_args:
n_hidden_neurons_dir: 128
input_xyz: true
xyz_ray_dir_in_camera_coords: false
use_integrated_positional_encoding: false
transformer_dim_down_factor: 1.0
n_hidden_neurons_xyz: 256
n_layers_xyz: 8
@@ -979,6 +983,7 @@ model_factory_ImplicitronModelFactory_args:
n_hidden_neurons_dir: 128
input_xyz: true
xyz_ray_dir_in_camera_coords: false
use_integrated_positional_encoding: false
transformer_dim_down_factor: 2.0
n_hidden_neurons_xyz: 80
n_layers_xyz: 2
@@ -990,6 +995,7 @@ model_factory_ImplicitronModelFactory_args:
n_hidden_neurons_dir: 128
input_xyz: true
xyz_ray_dir_in_camera_coords: false
use_integrated_positional_encoding: false
transformer_dim_down_factor: 1.0
n_hidden_neurons_xyz: 256
n_layers_xyz: 8

View File

@@ -9,11 +9,14 @@ from typing import Optional, Tuple
import torch
from pytorch3d.common.linear_with_repeat import LinearWithRepeat
from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle
from pytorch3d.implicitron.models.renderer.base import (
conical_frustum_to_gaussian,
ImplicitronRayBundle,
)
from pytorch3d.implicitron.tools.config import expand_args_fields, registry
from pytorch3d.renderer import ray_bundle_to_ray_points
from pytorch3d.renderer.cameras import CamerasBase
from pytorch3d.renderer.implicit import HarmonicEmbedding
from pytorch3d.renderer.implicit.utils import ray_bundle_to_ray_points
from .base import ImplicitFunctionBase
@@ -36,6 +39,7 @@ class NeuralRadianceFieldBase(ImplicitFunctionBase, torch.nn.Module):
input_xyz: bool = True
xyz_ray_dir_in_camera_coords: bool = False
color_dim: int = 3
use_integrated_positional_encoding: bool = False
"""
Args:
n_harmonic_functions_xyz: The number of harmonic functions
@@ -53,6 +57,10 @@ class NeuralRadianceFieldBase(ImplicitFunctionBase, torch.nn.Module):
n_layers_xyz: The number of layers of the MLP that outputs the
occupancy field.
append_xyz: The list of indices of the skip layers of the occupancy MLP.
use_integrated_positional_encoding: If True, use integrated positional encoding
as defined in `MIP-NeRF <https://arxiv.org/abs/2103.13415>`_.
If False, use the classical harmonic embedding
defined in `NeRF <https://arxiv.org/abs/2003.08934>`_.
"""
def __post_init__(self):
@@ -149,6 +157,10 @@ class NeuralRadianceFieldBase(ImplicitFunctionBase, torch.nn.Module):
containing the direction vectors of sampling rays in world coords.
lengths: A tensor of shape `(minibatch, ..., num_points_per_ray)`
containing the lengths at which the rays are sampled.
bins: An optional tensor of shape `(minibatch,..., num_points_per_ray + 1)`
containing the bin edges at which the rays are sampled. In this case,
`lengths` is equal to the midpoints of `bins`.
fun_viewpool: an optional callback with the signature
fun_viewpool(points) -> pooled_features
where points is a [N_TGT x N x 3] tensor of world coords,
@@ -160,11 +172,22 @@ class NeuralRadianceFieldBase(ImplicitFunctionBase, torch.nn.Module):
denoting the opacity of each ray point.
rays_colors: A tensor of shape `(minibatch, ..., num_points_per_ray, 3)`
denoting the color of each ray point.
Raises:
ValueError: If `use_integrated_positional_encoding` is True and
`ray_bundle.bins` is None.
"""
# We first convert the ray parametrizations to world
# coordinates with `ray_bundle_to_ray_points`.
# pyre-ignore[6]
rays_points_world = ray_bundle_to_ray_points(ray_bundle)
if self.use_integrated_positional_encoding and ray_bundle.bins is None:
raise ValueError(
"When use_integrated_positional_encoding is True, ray_bundle.bins must be set."
"Have you set to True `AbstractMaskRaySampler.use_bins_for_ray_sampling`?"
)
rays_points_world, diag_cov = (
conical_frustum_to_gaussian(ray_bundle)
if self.use_integrated_positional_encoding
else (ray_bundle_to_ray_points(ray_bundle), None) # pyre-ignore
)
# rays_points_world.shape = [minibatch x ... x pts_per_ray x 3]
embeds = create_embeddings_for_implicit_function(
@@ -177,6 +200,7 @@ class NeuralRadianceFieldBase(ImplicitFunctionBase, torch.nn.Module):
fun_viewpool=fun_viewpool,
xyz_in_camera_coords=self.xyz_ray_dir_in_camera_coords,
camera=camera,
diag_cov=diag_cov,
)
# embeds.shape = [minibatch x n_src x n_rays x n_pts x self.n_harmonic_functions*6+3]

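As the `bins` docstring above notes, when bins are used the per-ray `lengths` are the bin midpoints; a small illustrative check (names are local to this sketch):

import torch

bins = torch.tensor([0.0, 1.0, 2.0, 4.0])         # n + 1 = 4 bin edges
lengths = 0.5 * (bins[..., :-1] + bins[..., 1:])  # midpoints: [0.5, 1.5, 3.0]
assert lengths.shape[-1] == bins.shape[-1] - 1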
View File

@@ -36,6 +36,7 @@ def create_embeddings_for_implicit_function(
camera: Optional[CamerasBase],
fun_viewpool: Optional[Callable],
xyz_embedding_function: Optional[Callable],
diag_cov: Optional[torch.Tensor] = None,
) -> torch.Tensor:
bs, *spatial_size, pts_per_ray, _ = xyz_world.shape
@@ -59,11 +60,11 @@
prod(spatial_size),
pts_per_ray,
0,
dtype=xyz_world.dtype,
device=xyz_world.device,
)
else:
embeds = xyz_embedding_function(ray_points_for_embed).reshape(
embeds = xyz_embedding_function(ray_points_for_embed, diag_cov=diag_cov)
embeds = embeds.reshape(
bs,
1,
prod(spatial_size),

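The embedding callback now receives the covariances as a keyword argument; a minimal sketch of that call contract, with a HarmonicEmbedding standing in for `xyz_embedding_function` (sizes are illustrative):

import torch
from pytorch3d.renderer.implicit import HarmonicEmbedding

embedder = HarmonicEmbedding(n_harmonic_functions=10, append_input=True)
pts = torch.randn(2, 1024, 64, 3)  # [minibatch, n_rays, pts_per_ray, 3]
cov = torch.rand(2, 1024, 64, 3)   # matching diagonal covariances, or None
# The same keyword call that create_embeddings_for_implicit_function now makes:
embeds = embedder(pts, diag_cov=cov)
assert embeds.shape == (2, 1024, 64, (10 * 2 + 1) * 3)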
View File

@@ -4,6 +4,8 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Optional
import torch
@@ -16,8 +18,18 @@ class HarmonicEmbedding(torch.nn.Module):
append_input: bool = True,
) -> None:
"""
Given an input tensor `x` of shape [minibatch, ... , dim],
the harmonic embedding layer converts each feature
The harmonic embedding layer supports the classical
NeRF positional encoding described in
`NeRF <https://arxiv.org/abs/2003.08934>`_
and the integrated positional encoding described in
`MIP-NeRF <https://arxiv.org/abs/2103.13415>`_.
During inference, you can provide the extra argument `diag_cov`.
If `diag_cov` is None, the layer converts each feature
(i.e. vector along the last dimension) in `x`
into a series of harmonic features `embedding`,
where for each i in range(dim) the following are present
@@ -38,6 +50,31 @@ class HarmonicEmbedding(torch.nn.Module):
where N corresponds to `n_harmonic_functions-1`, and f_i is a scalar
denoting the i-th frequency of the harmonic embedding.
If `diag_cov` is not None, the layer treats `x` and `diag_cov`
as the means and diagonal covariances of Gaussians approximating
the conical frustums along a ray bundle.
Then it converts each Gaussian
into a series of harmonic features `embedding`,
where for each i in range(dim) the following are present
in embedding[...]::
[
sin(f_1*x[..., i]) * exp(-0.5 * f_1**2 * diag_cov[..., i]),
sin(f_2*x[..., i]) * exp(-0.5 * f_2**2 * diag_cov[..., i]),
...
sin(f_N * x[..., i]) * exp(-0.5 * f_N**2 * diag_cov[..., i]),
cos(f_1*x[..., i]) * exp(-0.5 * f_1**2 * diag_cov[..., i]),
cos(f_2*x[..., i]) * exp(-0.5 * f_2**2 * diag_cov[..., i]),
...
cos(f_N * x[..., i]) * exp(-0.5 * f_N**2 * diag_cov[..., i]),
x[..., i], # only present if append_input is True.
]
where N equals `n_harmonic_functions-1`, and f_i is a scalar
denoting the i-th frequency of the harmonic embedding.
If `logspace==True`, the frequencies `[f_1, ..., f_N]` are
powers of 2:
`f_1, ..., f_N = 2**torch.arange(n_harmonic_functions)`
@@ -59,8 +96,7 @@ class HarmonicEmbedding(torch.nn.Module):
logspace or linear space
append_input: bool, whether to concat the original
input to the harmonic embedding. If true the
output is of the form (x, embed.sin(), embed.cos()
output is of the form (embed.sin(), embed.cos(), x)
"""
super().__init__()
@@ -78,23 +114,42 @@ class HarmonicEmbedding(torch.nn.Module):
)
self.register_buffer("_frequencies", frequencies * omega_0, persistent=False)
self.register_buffer(
"_zero_half_pi", torch.tensor([0.0, 0.5 * torch.pi]), persistent=False
)
self.append_input = append_input
def forward(self, x: torch.Tensor) -> torch.Tensor:
def forward(
self, x: torch.Tensor, diag_cov: Optional[torch.Tensor] = None, **kwargs
) -> torch.Tensor:
"""
Args:
x: tensor of shape [..., dim]
diag_cov: An optional tensor of shape `(..., dim)`
representing the diagonal covariance matrices of our Gaussians,
where `x` provides the corresponding means.
Returns:
embedding: a harmonic embedding of `x`
of shape [..., (n_harmonic_functions * 2 + int(append_input)) * dim]
embedding: a harmonic embedding of `x` of shape
[..., (n_harmonic_functions * 2 + int(append_input)) * dim]
"""
embed = (x[..., None] * self._frequencies).reshape(*x.shape[:-1], -1)
embed = torch.cat(
(embed.sin(), embed.cos(), x)
if self.append_input
else (embed.sin(), embed.cos()),
dim=-1,
)
# [..., dim, n_harmonic_functions]
embed = x[..., None] * self._frequencies
# [..., 1, dim, n_harmonic_functions] + [2, 1, 1] => [..., 2, dim, n_harmonic_functions]
embed = embed[..., None, :, :] + self._zero_half_pi[..., None, None]
# Use the trig identity cos(x) = sin(x + pi/2)
# and do one vectorized call to sin([x, x+pi/2]) instead of (sin(x), cos(x)).
embed = embed.sin()
if diag_cov is not None:
x_var = diag_cov[..., None] * torch.pow(self._frequencies, 2)
exp_var = torch.exp(-0.5 * x_var)
# [..., 2, dim, n_harmonic_functions]
embed = embed * exp_var[..., None, :, :]
embed = embed.reshape(*x.shape[:-1], -1)
if self.append_input:
return torch.cat([embed, x], dim=-1)
return embed
@staticmethod

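A runnable sketch exercising both paths of the reworked forward; the output width follows the formula documented above:

import torch
from pytorch3d.renderer.implicit import HarmonicEmbedding

embed = HarmonicEmbedding(n_harmonic_functions=2, append_input=True)
x = torch.randn(1, 3)        # means, shape [..., dim]
diag_cov = torch.rand(1, 3)  # diagonal covariances, shape [..., dim]

he = embed(x)                      # classical NeRF encoding
ipe = embed(x, diag_cov=diag_cov)  # MIP-NeRF integrated encoding
# (n_harmonic_functions * 2 + int(append_input)) * dim = (2 * 2 + 1) * 3 = 15
assert he.shape == ipe.shape == (1, 15)

Since `diag_cov` only rescales each frequency component by exp(-0.5 * f**2 * var), high-frequency entries of the IPE shrink toward zero as the variance grows, which is what gives MIP-NeRF its anti-aliasing behavior.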
View File

@@ -0,0 +1,66 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import unittest
import torch
from pytorch3d.implicitron.models.implicit_function.base import ImplicitronRayBundle
from pytorch3d.implicitron.models.implicit_function.neural_radiance_field import (
NeuralRadianceFieldImplicitFunction,
)
class TestNeuralRadianceFieldImplicitFunction(unittest.TestCase):
def setUp(self):
torch.manual_seed(42)
def test_forward_with_integrated_positional_embedding(self):
shape = [2, 4, 4]
ray_bundle = ImplicitronRayBundle(
origins=torch.randn(*shape, 3),
directions=torch.randn(*shape, 3),
bins=torch.randn(*shape, 6 + 1),
lengths=torch.randn(*shape, 6),
pixel_radii_2d=torch.randn(*shape, 1),
xys=None,
)
model = NeuralRadianceFieldImplicitFunction(
n_hidden_neurons_dir=32, use_integrated_positional_encoding=True
)
raw_densities, ray_colors, _ = model(ray_bundle=ray_bundle)
self.assertEqual(raw_densities.shape, (*shape, ray_bundle.lengths.shape[-1], 1))
self.assertEqual(ray_colors.shape, (*shape, ray_bundle.lengths.shape[-1], 3))
def test_forward_with_integrated_positional_embedding_raise_exception(self):
shape = [2, 4, 4]
ray_bundle = ImplicitronRayBundle(
origins=torch.randn(*shape, 3),
directions=torch.randn(*shape, 3),
bins=None,
lengths=torch.randn(*shape, 6),
pixel_radii_2d=torch.randn(*shape, 1),
xys=None,
)
model = NeuralRadianceFieldImplicitFunction(
n_hidden_neurons_dir=32, use_integrated_positional_encoding=True
)
with self.assertRaises(ValueError):
_ = model(ray_bundle=ray_bundle)
def test_forward(self):
shape = [2, 4, 4]
ray_bundle = ImplicitronRayBundle(
origins=torch.randn(*shape, 3),
directions=torch.randn(*shape, 3),
lengths=torch.randn(*shape, 6),
pixel_radii_2d=torch.randn(*shape, 1),
xys=None,
)
model = NeuralRadianceFieldImplicitFunction(n_hidden_neurons_dir=32)
raw_densities, ray_colors, _ = model(ray_bundle=ray_bundle)
self.assertEqual(raw_densities.shape, (*shape, 6, 1))
self.assertEqual(ray_colors.shape, (*shape, 6, 3))

View File

@@ -8,6 +8,7 @@ import unittest
import torch
from pytorch3d.renderer.implicit import HarmonicEmbedding
from torch.distributions import MultivariateNormal
from .common_testing import TestCaseMixin
@@ -36,16 +37,117 @@ class TestHarmonicEmbedding(TestCaseMixin, unittest.TestCase):
self.assertClose(embed_fun_lin._frequencies, torch.FloatTensor((1.0, 2.5, 4.0)))
def test_correct_embed_out(self):
embed_fun = HarmonicEmbedding(n_harmonic_functions=2, append_input=False)
n_harmonic_functions = 2
x = torch.randn((1, 5))
D = 5 * 4
D = 5 * n_harmonic_functions * 2 # sin + cos
embed_fun = HarmonicEmbedding(
n_harmonic_functions=n_harmonic_functions, append_input=False
)
embed_out = embed_fun(x)
self.assertEqual(embed_out.shape, (1, D))
# Sum the squares of the respective frequencies
# cos^2(x) + sin^2(x) = 1
sum_squares = embed_out[0, : D // 2] ** 2 + embed_out[0, D // 2 :] ** 2
self.assertClose(sum_squares, torch.ones((D // 2)))
embed_fun = HarmonicEmbedding(n_harmonic_functions=2, append_input=True)
embed_out = embed_fun(x)
self.assertClose(embed_out.shape, torch.tensor((1, 5 * 5)))
# Test append input
embed_fun = HarmonicEmbedding(
n_harmonic_functions=n_harmonic_functions, append_input=True
)
embed_out_appended_input = embed_fun(x)
self.assertClose(
embed_out_appended_input.shape, torch.tensor((1, D + x.shape[-1]))
)
# Last plane in output is the input
self.assertClose(embed_out[..., -5:], x)
self.assertClose(embed_out_appended_input[..., -x.shape[-1] :], x)
self.assertClose(embed_out_appended_input[..., : -x.shape[-1]], embed_out)
def test_correct_embed_out_with_diag_cov(self):
n_harmonic_functions = 2
x = torch.randn((1, 3))
diag_cov = torch.randn((1, 3))
D = 3 * n_harmonic_functions * 2 # sin + cos
embed_fun = HarmonicEmbedding(
n_harmonic_functions=n_harmonic_functions, append_input=False
)
embed_out = embed_fun(x, diag_cov=diag_cov)
self.assertEqual(embed_out.shape, (1, D))
# Compute the scaling factor introduced in MIP-NeRF
scale_factor = (
-0.5 * diag_cov[..., None] * torch.pow(embed_fun._frequencies[None, :], 2)
)
scale_factor = torch.exp(scale_factor).reshape(1, -1).tile((1, 2))
# If we remove this scaling factor, we should go back to the
# classical harmonic embedding:
# Sum the squares of the respective frequencies
# cos^2(x) + sin^2(x) = 1
embed_out_without_cov = embed_out / scale_factor
sum_squares = (
embed_out_without_cov[0, : D // 2] ** 2
+ embed_out_without_cov[0, D // 2 :] ** 2
)
self.assertClose(sum_squares, torch.ones((D // 2)))
# Test append input
embed_fun = HarmonicEmbedding(
n_harmonic_functions=n_harmonic_functions, append_input=True
)
embed_out_appended_input = embed_fun(x, diag_cov=diag_cov)
self.assertClose(
embed_out_appended_input.shape, torch.tensor((1, D + x.shape[-1]))
)
# Last plane in output is the input
self.assertClose(embed_out_appended_input[..., -x.shape[-1] :], x)
self.assertClose(embed_out_appended_input[..., : -x.shape[-1]], embed_out)
def test_correct_behavior_between_ipe_and_its_estimation_from_harmonic_embedding(
self,
):
"""
Check that the HarmonicEmbedding with integrated positional encoding (IPE)
is consistent with the classical HarmonicEmbedding.
What is the idea behind this test?
We wish to produce an IPE that is the expectation
of our lifted multivariate Gaussian, modulated by the sine and cosine of
the coordinates. This expectation has a closed form
(see equations 11, 12, 13 and 14 of [1]).
We sample N elements from the multivariate Gaussian defined by its mean
and covariance and compute their HarmonicEmbedding. The mean of those
embeddings should be close to our IPE.
Inspired by:
https://github.com/google/mipnerf/blob/84c969e0a623edd183b75693aed72a7e7c22902d/internal/mip_test.py#L359
References:
[1] `MIP-NeRF <https://arxiv.org/abs/2103.13415>`_.
"""
num_dims = 3
n_harmonic_functions = 6
mean = torch.randn(num_dims)
diag_cov = torch.rand(num_dims)
he_fun = HarmonicEmbedding(
n_harmonic_functions=n_harmonic_functions, logspace=True, append_input=False
)
ipe_fun = HarmonicEmbedding(
n_harmonic_functions=n_harmonic_functions,
append_input=False,
)
embedding_ipe = ipe_fun(mean, diag_cov=diag_cov)
rand_mvn = MultivariateNormal(mean, torch.eye(num_dims) * diag_cov)
# With a large enough number of samples,
# the empirical mean should be close to our IPE
num_samples = 100000
embedding_he = he_fun(rand_mvn.sample((num_samples,)))
self.assertClose(embedding_he.mean(0), embedding_ipe, rtol=1e-2, atol=1e-2)
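For reference, the closed-form expectations that the Monte-Carlo estimate converges to (standard Gaussian identities, matching equations 11-14 of MIP-NeRF):

\mathbb{E}_{x \sim \mathcal{N}(\mu, \sigma^2)}\left[\sin(f x)\right] = \sin(f \mu)\, e^{-\frac{1}{2} f^2 \sigma^2},
\qquad
\mathbb{E}_{x \sim \mathcal{N}(\mu, \sigma^2)}\left[\cos(f x)\right] = \cos(f \mu)\, e^{-\frac{1}{2} f^2 \sigma^2}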