mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-01 03:12:49 +08:00
Add integrated position encoding based on MIPNerf implementation.
Summary: Add a new implicit module Integral Position Encoding based on [MIP-NeRF](https://arxiv.org/abs/2103.13415). Reviewed By: shapovalov Differential Revision: D46352730 fbshipit-source-id: c6a56134c975d80052b3a11f5e92fd7d95cbff1e
This commit is contained in:
parent
29b8ebd802
commit
ccf860f1db
@ -361,6 +361,7 @@ model_factory_ImplicitronModelFactory_args:
|
|||||||
n_hidden_neurons_dir: 128
|
n_hidden_neurons_dir: 128
|
||||||
input_xyz: true
|
input_xyz: true
|
||||||
xyz_ray_dir_in_camera_coords: false
|
xyz_ray_dir_in_camera_coords: false
|
||||||
|
use_integrated_positional_encoding: false
|
||||||
transformer_dim_down_factor: 2.0
|
transformer_dim_down_factor: 2.0
|
||||||
n_hidden_neurons_xyz: 80
|
n_hidden_neurons_xyz: 80
|
||||||
n_layers_xyz: 2
|
n_layers_xyz: 2
|
||||||
@ -372,6 +373,7 @@ model_factory_ImplicitronModelFactory_args:
|
|||||||
n_hidden_neurons_dir: 128
|
n_hidden_neurons_dir: 128
|
||||||
input_xyz: true
|
input_xyz: true
|
||||||
xyz_ray_dir_in_camera_coords: false
|
xyz_ray_dir_in_camera_coords: false
|
||||||
|
use_integrated_positional_encoding: false
|
||||||
transformer_dim_down_factor: 1.0
|
transformer_dim_down_factor: 1.0
|
||||||
n_hidden_neurons_xyz: 256
|
n_hidden_neurons_xyz: 256
|
||||||
n_layers_xyz: 8
|
n_layers_xyz: 8
|
||||||
@ -741,6 +743,7 @@ model_factory_ImplicitronModelFactory_args:
|
|||||||
n_hidden_neurons_dir: 128
|
n_hidden_neurons_dir: 128
|
||||||
input_xyz: true
|
input_xyz: true
|
||||||
xyz_ray_dir_in_camera_coords: false
|
xyz_ray_dir_in_camera_coords: false
|
||||||
|
use_integrated_positional_encoding: false
|
||||||
transformer_dim_down_factor: 2.0
|
transformer_dim_down_factor: 2.0
|
||||||
n_hidden_neurons_xyz: 80
|
n_hidden_neurons_xyz: 80
|
||||||
n_layers_xyz: 2
|
n_layers_xyz: 2
|
||||||
@ -752,6 +755,7 @@ model_factory_ImplicitronModelFactory_args:
|
|||||||
n_hidden_neurons_dir: 128
|
n_hidden_neurons_dir: 128
|
||||||
input_xyz: true
|
input_xyz: true
|
||||||
xyz_ray_dir_in_camera_coords: false
|
xyz_ray_dir_in_camera_coords: false
|
||||||
|
use_integrated_positional_encoding: false
|
||||||
transformer_dim_down_factor: 1.0
|
transformer_dim_down_factor: 1.0
|
||||||
n_hidden_neurons_xyz: 256
|
n_hidden_neurons_xyz: 256
|
||||||
n_layers_xyz: 8
|
n_layers_xyz: 8
|
||||||
@ -979,6 +983,7 @@ model_factory_ImplicitronModelFactory_args:
|
|||||||
n_hidden_neurons_dir: 128
|
n_hidden_neurons_dir: 128
|
||||||
input_xyz: true
|
input_xyz: true
|
||||||
xyz_ray_dir_in_camera_coords: false
|
xyz_ray_dir_in_camera_coords: false
|
||||||
|
use_integrated_positional_encoding: false
|
||||||
transformer_dim_down_factor: 2.0
|
transformer_dim_down_factor: 2.0
|
||||||
n_hidden_neurons_xyz: 80
|
n_hidden_neurons_xyz: 80
|
||||||
n_layers_xyz: 2
|
n_layers_xyz: 2
|
||||||
@ -990,6 +995,7 @@ model_factory_ImplicitronModelFactory_args:
|
|||||||
n_hidden_neurons_dir: 128
|
n_hidden_neurons_dir: 128
|
||||||
input_xyz: true
|
input_xyz: true
|
||||||
xyz_ray_dir_in_camera_coords: false
|
xyz_ray_dir_in_camera_coords: false
|
||||||
|
use_integrated_positional_encoding: false
|
||||||
transformer_dim_down_factor: 1.0
|
transformer_dim_down_factor: 1.0
|
||||||
n_hidden_neurons_xyz: 256
|
n_hidden_neurons_xyz: 256
|
||||||
n_layers_xyz: 8
|
n_layers_xyz: 8
|
||||||
|
@ -9,11 +9,14 @@ from typing import Optional, Tuple
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
from pytorch3d.common.linear_with_repeat import LinearWithRepeat
|
from pytorch3d.common.linear_with_repeat import LinearWithRepeat
|
||||||
from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle
|
from pytorch3d.implicitron.models.renderer.base import (
|
||||||
|
conical_frustum_to_gaussian,
|
||||||
|
ImplicitronRayBundle,
|
||||||
|
)
|
||||||
from pytorch3d.implicitron.tools.config import expand_args_fields, registry
|
from pytorch3d.implicitron.tools.config import expand_args_fields, registry
|
||||||
from pytorch3d.renderer import ray_bundle_to_ray_points
|
|
||||||
from pytorch3d.renderer.cameras import CamerasBase
|
from pytorch3d.renderer.cameras import CamerasBase
|
||||||
from pytorch3d.renderer.implicit import HarmonicEmbedding
|
from pytorch3d.renderer.implicit import HarmonicEmbedding
|
||||||
|
from pytorch3d.renderer.implicit.utils import ray_bundle_to_ray_points
|
||||||
|
|
||||||
from .base import ImplicitFunctionBase
|
from .base import ImplicitFunctionBase
|
||||||
|
|
||||||
@ -36,6 +39,7 @@ class NeuralRadianceFieldBase(ImplicitFunctionBase, torch.nn.Module):
|
|||||||
input_xyz: bool = True
|
input_xyz: bool = True
|
||||||
xyz_ray_dir_in_camera_coords: bool = False
|
xyz_ray_dir_in_camera_coords: bool = False
|
||||||
color_dim: int = 3
|
color_dim: int = 3
|
||||||
|
use_integrated_positional_encoding: bool = False
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
n_harmonic_functions_xyz: The number of harmonic functions
|
n_harmonic_functions_xyz: The number of harmonic functions
|
||||||
@ -53,6 +57,10 @@ class NeuralRadianceFieldBase(ImplicitFunctionBase, torch.nn.Module):
|
|||||||
n_layers_xyz: The number of layers of the MLP that outputs the
|
n_layers_xyz: The number of layers of the MLP that outputs the
|
||||||
occupancy field.
|
occupancy field.
|
||||||
append_xyz: The list of indices of the skip layers of the occupancy MLP.
|
append_xyz: The list of indices of the skip layers of the occupancy MLP.
|
||||||
|
use_integrated_positional_encoding: If True, use integrated positional encoding
|
||||||
|
as defined in `MIP-NeRF <https://arxiv.org/abs/2103.13415>`_.
|
||||||
|
If False, use the classical harmonic embedding
|
||||||
|
defined in `NeRF <https://arxiv.org/abs/2003.08934>`_.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
@ -149,6 +157,10 @@ class NeuralRadianceFieldBase(ImplicitFunctionBase, torch.nn.Module):
|
|||||||
containing the direction vectors of sampling rays in world coords.
|
containing the direction vectors of sampling rays in world coords.
|
||||||
lengths: A tensor of shape `(minibatch, ..., num_points_per_ray)`
|
lengths: A tensor of shape `(minibatch, ..., num_points_per_ray)`
|
||||||
containing the lengths at which the rays are sampled.
|
containing the lengths at which the rays are sampled.
|
||||||
|
bins: An optional tensor of shape `(minibatch,..., num_points_per_ray + 1)`
|
||||||
|
containing the bins at which the rays are sampled. In this case
|
||||||
|
lengths is equal to the midpoints of bins.
|
||||||
|
|
||||||
fun_viewpool: an optional callback with the signature
|
fun_viewpool: an optional callback with the signature
|
||||||
fun_fiewpool(points) -> pooled_features
|
fun_fiewpool(points) -> pooled_features
|
||||||
where points is a [N_TGT x N x 3] tensor of world coords,
|
where points is a [N_TGT x N x 3] tensor of world coords,
|
||||||
@ -160,11 +172,22 @@ class NeuralRadianceFieldBase(ImplicitFunctionBase, torch.nn.Module):
|
|||||||
denoting the opacitiy of each ray point.
|
denoting the opacitiy of each ray point.
|
||||||
rays_colors: A tensor of shape `(minibatch, ..., num_points_per_ray, 3)`
|
rays_colors: A tensor of shape `(minibatch, ..., num_points_per_ray, 3)`
|
||||||
denoting the color of each ray point.
|
denoting the color of each ray point.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If `use_integrated_positional_encoding` is True and
|
||||||
|
`ray_bundle.bins` is None.
|
||||||
"""
|
"""
|
||||||
# We first convert the ray parametrizations to world
|
if self.use_integrated_positional_encoding and ray_bundle.bins is None:
|
||||||
# coordinates with `ray_bundle_to_ray_points`.
|
raise ValueError(
|
||||||
# pyre-ignore[6]
|
"When use_integrated_positional_encoding is True, ray_bundle.bins must be set."
|
||||||
rays_points_world = ray_bundle_to_ray_points(ray_bundle)
|
"Have you set to True `AbstractMaskRaySampler.use_bins_for_ray_sampling`?"
|
||||||
|
)
|
||||||
|
|
||||||
|
rays_points_world, diag_cov = (
|
||||||
|
conical_frustum_to_gaussian(ray_bundle)
|
||||||
|
if self.use_integrated_positional_encoding
|
||||||
|
else (ray_bundle_to_ray_points(ray_bundle), None) # pyre-ignore
|
||||||
|
)
|
||||||
# rays_points_world.shape = [minibatch x ... x pts_per_ray x 3]
|
# rays_points_world.shape = [minibatch x ... x pts_per_ray x 3]
|
||||||
|
|
||||||
embeds = create_embeddings_for_implicit_function(
|
embeds = create_embeddings_for_implicit_function(
|
||||||
@ -177,6 +200,7 @@ class NeuralRadianceFieldBase(ImplicitFunctionBase, torch.nn.Module):
|
|||||||
fun_viewpool=fun_viewpool,
|
fun_viewpool=fun_viewpool,
|
||||||
xyz_in_camera_coords=self.xyz_ray_dir_in_camera_coords,
|
xyz_in_camera_coords=self.xyz_ray_dir_in_camera_coords,
|
||||||
camera=camera,
|
camera=camera,
|
||||||
|
diag_cov=diag_cov,
|
||||||
)
|
)
|
||||||
|
|
||||||
# embeds.shape = [minibatch x n_src x n_rays x n_pts x self.n_harmonic_functions*6+3]
|
# embeds.shape = [minibatch x n_src x n_rays x n_pts x self.n_harmonic_functions*6+3]
|
||||||
|
@ -36,6 +36,7 @@ def create_embeddings_for_implicit_function(
|
|||||||
camera: Optional[CamerasBase],
|
camera: Optional[CamerasBase],
|
||||||
fun_viewpool: Optional[Callable],
|
fun_viewpool: Optional[Callable],
|
||||||
xyz_embedding_function: Optional[Callable],
|
xyz_embedding_function: Optional[Callable],
|
||||||
|
diag_cov: Optional[torch.Tensor] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
|
||||||
bs, *spatial_size, pts_per_ray, _ = xyz_world.shape
|
bs, *spatial_size, pts_per_ray, _ = xyz_world.shape
|
||||||
@ -59,11 +60,11 @@ def create_embeddings_for_implicit_function(
|
|||||||
prod(spatial_size),
|
prod(spatial_size),
|
||||||
pts_per_ray,
|
pts_per_ray,
|
||||||
0,
|
0,
|
||||||
dtype=xyz_world.dtype,
|
|
||||||
device=xyz_world.device,
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
embeds = xyz_embedding_function(ray_points_for_embed).reshape(
|
|
||||||
|
embeds = xyz_embedding_function(ray_points_for_embed, diag_cov=diag_cov)
|
||||||
|
embeds = embeds.reshape(
|
||||||
bs,
|
bs,
|
||||||
1,
|
1,
|
||||||
prod(spatial_size),
|
prod(spatial_size),
|
||||||
|
@ -4,6 +4,8 @@
|
|||||||
# This source code is licensed under the BSD-style license found in the
|
# This source code is licensed under the BSD-style license found in the
|
||||||
# LICENSE file in the root directory of this source tree.
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|
||||||
@ -16,8 +18,18 @@ class HarmonicEmbedding(torch.nn.Module):
|
|||||||
append_input: bool = True,
|
append_input: bool = True,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Given an input tensor `x` of shape [minibatch, ... , dim],
|
The harmonic embedding layer supports the classical
|
||||||
the harmonic embedding layer converts each feature
|
Nerf positional encoding described in
|
||||||
|
`NeRF <https://arxiv.org/abs/2003.08934>`_
|
||||||
|
and the integrated position encoding in
|
||||||
|
`MIP-NeRF <https://arxiv.org/abs/2103.13415>`_.
|
||||||
|
|
||||||
|
During inference, you can provide the extra argument `diag_cov`.
|
||||||
|
|
||||||
|
If `diag_cov is None`, it converts
|
||||||
|
rays parametrized with a `ray_bundle` to 3D points by
|
||||||
|
extending each ray according to the corresponding length.
|
||||||
|
Then it converts each feature
|
||||||
(i.e. vector along the last dimension) in `x`
|
(i.e. vector along the last dimension) in `x`
|
||||||
into a series of harmonic features `embedding`,
|
into a series of harmonic features `embedding`,
|
||||||
where for each i in range(dim) the following are present
|
where for each i in range(dim) the following are present
|
||||||
@ -38,6 +50,31 @@ class HarmonicEmbedding(torch.nn.Module):
|
|||||||
where N corresponds to `n_harmonic_functions-1`, and f_i is a scalar
|
where N corresponds to `n_harmonic_functions-1`, and f_i is a scalar
|
||||||
denoting the i-th frequency of the harmonic embedding.
|
denoting the i-th frequency of the harmonic embedding.
|
||||||
|
|
||||||
|
|
||||||
|
If `diag_cov is not None`, it approximates
|
||||||
|
conical frustums following a ray bundle as gaussians,
|
||||||
|
defined by x, the means of the gaussians and diag_cov,
|
||||||
|
the diagonal covariances.
|
||||||
|
Then it converts each gaussian
|
||||||
|
into a series of harmonic features `embedding`,
|
||||||
|
where for each i in range(dim) the following are present
|
||||||
|
in embedding[...]::
|
||||||
|
|
||||||
|
[
|
||||||
|
sin(f_1*x[..., i]) * exp(-0.5 * f_1**2 * diag_cov[..., i]),
|
||||||
|
sin(f_2*x[..., i]) * exp(-0.5 * f_2**2 * diag_cov[..., i]),
|
||||||
|
...
|
||||||
|
sin(f_N * x[..., i]) * exp(-0.5 * f_N**2 * diag_cov[..., i]),
|
||||||
|
cos(f_1*x[..., i]) * exp(-0.5 * f_1**2 * diag_cov[..., i]),
|
||||||
|
cos(f_2*x[..., i]) * exp(-0.5 * f_2**2 * diag_cov[..., i]),
|
||||||
|
...
|
||||||
|
cos(f_N * x[..., i]) * exp(-0.5 * f_N**2 * diag_cov[..., i]),
|
||||||
|
x[..., i], # only present if append_input is True.
|
||||||
|
]
|
||||||
|
|
||||||
|
where N equals `n_harmonic_functions-1`, and f_i is a scalar
|
||||||
|
denoting the i-th frequency of the harmonic embedding.
|
||||||
|
|
||||||
If `logspace==True`, the frequencies `[f_1, ..., f_N]` are
|
If `logspace==True`, the frequencies `[f_1, ..., f_N]` are
|
||||||
powers of 2:
|
powers of 2:
|
||||||
`f_1, ..., f_N = 2**torch.arange(n_harmonic_functions)`
|
`f_1, ..., f_N = 2**torch.arange(n_harmonic_functions)`
|
||||||
@ -59,8 +96,7 @@ class HarmonicEmbedding(torch.nn.Module):
|
|||||||
logspace or linear space
|
logspace or linear space
|
||||||
append_input: bool, whether to concat the original
|
append_input: bool, whether to concat the original
|
||||||
input to the harmonic embedding. If true the
|
input to the harmonic embedding. If true the
|
||||||
output is of the form (x, embed.sin(), embed.cos()
|
output is of the form (embed.sin(), embed.cos(), x)
|
||||||
|
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@ -78,23 +114,42 @@ class HarmonicEmbedding(torch.nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.register_buffer("_frequencies", frequencies * omega_0, persistent=False)
|
self.register_buffer("_frequencies", frequencies * omega_0, persistent=False)
|
||||||
|
self.register_buffer(
|
||||||
|
"_zero_half_pi", torch.tensor([0.0, 0.5 * torch.pi]), persistent=False
|
||||||
|
)
|
||||||
self.append_input = append_input
|
self.append_input = append_input
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
def forward(
|
||||||
|
self, x: torch.Tensor, diag_cov: Optional[torch.Tensor] = None, **kwargs
|
||||||
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
x: tensor of shape [..., dim]
|
x: tensor of shape [..., dim]
|
||||||
|
diag_cov: An optional tensor of shape `(..., dim)`
|
||||||
|
representing the diagonal covariance matrices of our Gaussians, joined with x
|
||||||
|
as means of the Gaussians.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
embedding: a harmonic embedding of `x`
|
embedding: a harmonic embedding of `x` of shape
|
||||||
of shape [..., (n_harmonic_functions * 2 + int(append_input)) * dim]
|
[..., (n_harmonic_functions * 2 + int(append_input)) * dim]
|
||||||
"""
|
"""
|
||||||
embed = (x[..., None] * self._frequencies).reshape(*x.shape[:-1], -1)
|
# [..., dim, n_harmonic_functions]
|
||||||
embed = torch.cat(
|
embed = x[..., None] * self._frequencies
|
||||||
(embed.sin(), embed.cos(), x)
|
# [..., 1, dim, n_harmonic_functions] + [2, 1, 1] => [..., 2, dim, n_harmonic_functions]
|
||||||
if self.append_input
|
embed = embed[..., None, :, :] + self._zero_half_pi[..., None, None]
|
||||||
else (embed.sin(), embed.cos()),
|
# Use the trig identity cos(x) = sin(x + pi/2)
|
||||||
dim=-1,
|
# and do one vectorized call to sin([x, x+pi/2]) instead of (sin(x), cos(x)).
|
||||||
)
|
embed = embed.sin()
|
||||||
|
if diag_cov is not None:
|
||||||
|
x_var = diag_cov[..., None] * torch.pow(self._frequencies, 2)
|
||||||
|
exp_var = torch.exp(-0.5 * x_var)
|
||||||
|
# [..., 2, dim, n_harmonic_functions]
|
||||||
|
embed = embed * exp_var[..., None, :, :]
|
||||||
|
|
||||||
|
embed = embed.reshape(*x.shape[:-1], -1)
|
||||||
|
|
||||||
|
if self.append_input:
|
||||||
|
return torch.cat([embed, x], dim=-1)
|
||||||
return embed
|
return embed
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@ -0,0 +1,66 @@
|
|||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the BSD-style license found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from pytorch3d.implicitron.models.implicit_function.base import ImplicitronRayBundle
|
||||||
|
from pytorch3d.implicitron.models.implicit_function.neural_radiance_field import (
|
||||||
|
NeuralRadianceFieldImplicitFunction,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestNeuralRadianceFieldImplicitFunction(unittest.TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
torch.manual_seed(42)
|
||||||
|
|
||||||
|
def test_forward_with_integrated_positionial_embedding(self):
|
||||||
|
shape = [2, 4, 4]
|
||||||
|
ray_bundle = ImplicitronRayBundle(
|
||||||
|
origins=torch.randn(*shape, 3),
|
||||||
|
directions=torch.randn(*shape, 3),
|
||||||
|
bins=torch.randn(*shape, 6 + 1),
|
||||||
|
lengths=torch.randn(*shape, 6),
|
||||||
|
pixel_radii_2d=torch.randn(*shape, 1),
|
||||||
|
xys=None,
|
||||||
|
)
|
||||||
|
model = NeuralRadianceFieldImplicitFunction(
|
||||||
|
n_hidden_neurons_dir=32, use_integrated_positional_encoding=True
|
||||||
|
)
|
||||||
|
raw_densities, ray_colors, _ = model(ray_bundle=ray_bundle)
|
||||||
|
|
||||||
|
self.assertEqual(raw_densities.shape, (*shape, ray_bundle.lengths.shape[-1], 1))
|
||||||
|
self.assertEqual(ray_colors.shape, (*shape, ray_bundle.lengths.shape[-1], 3))
|
||||||
|
|
||||||
|
def test_forward_with_integrated_positionial_embedding_raise_exception(self):
|
||||||
|
shape = [2, 4, 4]
|
||||||
|
ray_bundle = ImplicitronRayBundle(
|
||||||
|
origins=torch.randn(*shape, 3),
|
||||||
|
directions=torch.randn(*shape, 3),
|
||||||
|
bins=None,
|
||||||
|
lengths=torch.randn(*shape, 6),
|
||||||
|
pixel_radii_2d=torch.randn(*shape, 1),
|
||||||
|
xys=None,
|
||||||
|
)
|
||||||
|
model = NeuralRadianceFieldImplicitFunction(
|
||||||
|
n_hidden_neurons_dir=32, use_integrated_positional_encoding=True
|
||||||
|
)
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
_ = model(ray_bundle=ray_bundle)
|
||||||
|
|
||||||
|
def test_forward(self):
|
||||||
|
shape = [2, 4, 4]
|
||||||
|
ray_bundle = ImplicitronRayBundle(
|
||||||
|
origins=torch.randn(*shape, 3),
|
||||||
|
directions=torch.randn(*shape, 3),
|
||||||
|
lengths=torch.randn(*shape, 6),
|
||||||
|
pixel_radii_2d=torch.randn(*shape, 1),
|
||||||
|
xys=None,
|
||||||
|
)
|
||||||
|
model = NeuralRadianceFieldImplicitFunction(n_hidden_neurons_dir=32)
|
||||||
|
raw_densities, ray_colors, _ = model(ray_bundle=ray_bundle)
|
||||||
|
self.assertEqual(raw_densities.shape, (*shape, 6, 1))
|
||||||
|
self.assertEqual(ray_colors.shape, (*shape, 6, 3))
|
@ -8,6 +8,7 @@ import unittest
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
from pytorch3d.renderer.implicit import HarmonicEmbedding
|
from pytorch3d.renderer.implicit import HarmonicEmbedding
|
||||||
|
from torch.distributions import MultivariateNormal
|
||||||
|
|
||||||
from .common_testing import TestCaseMixin
|
from .common_testing import TestCaseMixin
|
||||||
|
|
||||||
@ -36,16 +37,117 @@ class TestHarmonicEmbedding(TestCaseMixin, unittest.TestCase):
|
|||||||
self.assertClose(embed_fun_lin._frequencies, torch.FloatTensor((1.0, 2.5, 4.0)))
|
self.assertClose(embed_fun_lin._frequencies, torch.FloatTensor((1.0, 2.5, 4.0)))
|
||||||
|
|
||||||
def test_correct_embed_out(self):
|
def test_correct_embed_out(self):
|
||||||
embed_fun = HarmonicEmbedding(n_harmonic_functions=2, append_input=False)
|
n_harmonic_functions = 2
|
||||||
x = torch.randn((1, 5))
|
x = torch.randn((1, 5))
|
||||||
D = 5 * 4
|
D = 5 * n_harmonic_functions * 2 # sin + cos
|
||||||
|
|
||||||
|
embed_fun = HarmonicEmbedding(
|
||||||
|
n_harmonic_functions=n_harmonic_functions, append_input=False
|
||||||
|
)
|
||||||
embed_out = embed_fun(x)
|
embed_out = embed_fun(x)
|
||||||
|
|
||||||
self.assertEqual(embed_out.shape, (1, D))
|
self.assertEqual(embed_out.shape, (1, D))
|
||||||
# Sum the squares of the respective frequencies
|
# Sum the squares of the respective frequencies
|
||||||
|
# cos^2(x) + sin^2(x) = 1
|
||||||
sum_squares = embed_out[0, : D // 2] ** 2 + embed_out[0, D // 2 :] ** 2
|
sum_squares = embed_out[0, : D // 2] ** 2 + embed_out[0, D // 2 :] ** 2
|
||||||
self.assertClose(sum_squares, torch.ones((D // 2)))
|
self.assertClose(sum_squares, torch.ones((D // 2)))
|
||||||
embed_fun = HarmonicEmbedding(n_harmonic_functions=2, append_input=True)
|
|
||||||
embed_out = embed_fun(x)
|
# Test append input
|
||||||
self.assertClose(embed_out.shape, torch.tensor((1, 5 * 5)))
|
embed_fun = HarmonicEmbedding(
|
||||||
|
n_harmonic_functions=n_harmonic_functions, append_input=True
|
||||||
|
)
|
||||||
|
embed_out_appended_input = embed_fun(x)
|
||||||
|
self.assertClose(
|
||||||
|
embed_out_appended_input.shape, torch.tensor((1, D + x.shape[-1]))
|
||||||
|
)
|
||||||
# Last plane in output is the input
|
# Last plane in output is the input
|
||||||
self.assertClose(embed_out[..., -5:], x)
|
self.assertClose(embed_out_appended_input[..., -x.shape[-1] :], x)
|
||||||
|
self.assertClose(embed_out_appended_input[..., : -x.shape[-1]], embed_out)
|
||||||
|
|
||||||
|
def test_correct_embed_out_with_diag_cov(self):
|
||||||
|
n_harmonic_functions = 2
|
||||||
|
x = torch.randn((1, 3))
|
||||||
|
diag_cov = torch.randn((1, 3))
|
||||||
|
D = 3 * n_harmonic_functions * 2 # sin + cos
|
||||||
|
|
||||||
|
embed_fun = HarmonicEmbedding(
|
||||||
|
n_harmonic_functions=n_harmonic_functions, append_input=False
|
||||||
|
)
|
||||||
|
embed_out = embed_fun(x, diag_cov=diag_cov)
|
||||||
|
|
||||||
|
self.assertEqual(embed_out.shape, (1, D))
|
||||||
|
|
||||||
|
# Compute the scaling factor introduced in MipNerf
|
||||||
|
scale_factor = (
|
||||||
|
-0.5 * diag_cov[..., None] * torch.pow(embed_fun._frequencies[None, :], 2)
|
||||||
|
)
|
||||||
|
scale_factor = torch.exp(scale_factor).reshape(1, -1).tile((1, 2))
|
||||||
|
# If we remove this scaling factor, we should go back to the
|
||||||
|
# classical harmonic embedding:
|
||||||
|
# Sum the squares of the respective frequencies
|
||||||
|
# cos^2(x) + sin^2(x) = 1
|
||||||
|
embed_out_without_cov = embed_out / scale_factor
|
||||||
|
sum_squares = (
|
||||||
|
embed_out_without_cov[0, : D // 2] ** 2
|
||||||
|
+ embed_out_without_cov[0, D // 2 :] ** 2
|
||||||
|
)
|
||||||
|
self.assertClose(sum_squares, torch.ones((D // 2)))
|
||||||
|
|
||||||
|
# Test append input
|
||||||
|
embed_fun = HarmonicEmbedding(
|
||||||
|
n_harmonic_functions=n_harmonic_functions, append_input=True
|
||||||
|
)
|
||||||
|
embed_out_appended_input = embed_fun(x, diag_cov=diag_cov)
|
||||||
|
self.assertClose(
|
||||||
|
embed_out_appended_input.shape, torch.tensor((1, D + x.shape[-1]))
|
||||||
|
)
|
||||||
|
# Last plane in output is the input
|
||||||
|
self.assertClose(embed_out_appended_input[..., -x.shape[-1] :], x)
|
||||||
|
self.assertClose(embed_out_appended_input[..., : -x.shape[-1]], embed_out)
|
||||||
|
|
||||||
|
def test_correct_behavior_between_ipe_and_its_estimation_from_harmonic_embedding(
|
||||||
|
self,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Check that the HarmonicEmbedding with integrated_position_encoding (IPE) set to
|
||||||
|
True is coherent with the HarmonicEmbedding.
|
||||||
|
|
||||||
|
What is the idea behind this test?
|
||||||
|
|
||||||
|
We wish to produce an IPE that is the expectation
|
||||||
|
of our lifted multivariate gaussian, modulated by the sine and cosine of
|
||||||
|
the coordinates. This expectation has a closed form
|
||||||
|
(see equations 11, 12, 13, 14 of [1]).
|
||||||
|
|
||||||
|
We sample N elements from the multivariate gaussian defined by its mean and covariance
|
||||||
|
and compute the HarmonicEmbedding. The expected value of those embeddings should be
|
||||||
|
equal to our IPE.
|
||||||
|
|
||||||
|
Inspired from:
|
||||||
|
https://github.com/google/mipnerf/blob/84c969e0a623edd183b75693aed72a7e7c22902d/internal/mip_test.py#L359
|
||||||
|
|
||||||
|
References:
|
||||||
|
[1] `MIP-NeRF <https://arxiv.org/abs/2103.13415>`_.
|
||||||
|
"""
|
||||||
|
num_dims = 3
|
||||||
|
n_harmonic_functions = 6
|
||||||
|
mean = torch.randn(num_dims)
|
||||||
|
diag_cov = torch.rand(num_dims)
|
||||||
|
|
||||||
|
he_fun = HarmonicEmbedding(
|
||||||
|
n_harmonic_functions=n_harmonic_functions, logspace=True, append_input=False
|
||||||
|
)
|
||||||
|
ipe_fun = HarmonicEmbedding(
|
||||||
|
n_harmonic_functions=n_harmonic_functions,
|
||||||
|
append_input=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
embedding_ipe = ipe_fun(mean, diag_cov=diag_cov)
|
||||||
|
|
||||||
|
rand_mvn = MultivariateNormal(mean, torch.eye(num_dims) * diag_cov)
|
||||||
|
|
||||||
|
# Providing a large enough number of samples
|
||||||
|
# we should obtain an estimation close to our IPE
|
||||||
|
num_samples = 100000
|
||||||
|
embedding_he = he_fun(rand_mvn.sample_n(num_samples))
|
||||||
|
self.assertClose(embedding_he.mean(0), embedding_ipe, rtol=1e-2, atol=1e-2)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user