mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2026-01-17 11:50:35 +08:00
Add integrated position encoding based on MIPNerf implementation.
Summary: Add a new implicit module Integral Position Encoding based on [MIP-NeRF](https://arxiv.org/abs/2103.13415). Reviewed By: shapovalov Differential Revision: D46352730 fbshipit-source-id: c6a56134c975d80052b3a11f5e92fd7d95cbff1e
This commit is contained in:
committed by
Facebook GitHub Bot
parent
29b8ebd802
commit
ccf860f1db
@@ -9,11 +9,14 @@ from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
from pytorch3d.common.linear_with_repeat import LinearWithRepeat
|
||||
from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle
|
||||
from pytorch3d.implicitron.models.renderer.base import (
|
||||
conical_frustum_to_gaussian,
|
||||
ImplicitronRayBundle,
|
||||
)
|
||||
from pytorch3d.implicitron.tools.config import expand_args_fields, registry
|
||||
from pytorch3d.renderer import ray_bundle_to_ray_points
|
||||
from pytorch3d.renderer.cameras import CamerasBase
|
||||
from pytorch3d.renderer.implicit import HarmonicEmbedding
|
||||
from pytorch3d.renderer.implicit.utils import ray_bundle_to_ray_points
|
||||
|
||||
from .base import ImplicitFunctionBase
|
||||
|
||||
@@ -36,6 +39,7 @@ class NeuralRadianceFieldBase(ImplicitFunctionBase, torch.nn.Module):
|
||||
input_xyz: bool = True
|
||||
xyz_ray_dir_in_camera_coords: bool = False
|
||||
color_dim: int = 3
|
||||
use_integrated_positional_encoding: bool = False
|
||||
"""
|
||||
Args:
|
||||
n_harmonic_functions_xyz: The number of harmonic functions
|
||||
@@ -53,6 +57,10 @@ class NeuralRadianceFieldBase(ImplicitFunctionBase, torch.nn.Module):
|
||||
n_layers_xyz: The number of layers of the MLP that outputs the
|
||||
occupancy field.
|
||||
append_xyz: The list of indices of the skip layers of the occupancy MLP.
|
||||
use_integrated_positional_encoding: If True, use integrated positional encoding
|
||||
as defined in `MIP-NeRF <https://arxiv.org/abs/2103.13415>`_.
|
||||
If False, use the classical harmonic embedding
|
||||
defined in `NeRF <https://arxiv.org/abs/2003.08934>`_.
|
||||
"""
|
||||
|
||||
def __post_init__(self):
|
||||
@@ -149,6 +157,10 @@ class NeuralRadianceFieldBase(ImplicitFunctionBase, torch.nn.Module):
|
||||
containing the direction vectors of sampling rays in world coords.
|
||||
lengths: A tensor of shape `(minibatch, ..., num_points_per_ray)`
|
||||
containing the lengths at which the rays are sampled.
|
||||
bins: An optional tensor of shape `(minibatch, ..., num_points_per_ray + 1)`
|
||||
containing the bins at which the rays are sampled. In this case
|
||||
lengths is equal to the midpoints of bins.
|
||||
|
||||
fun_viewpool: an optional callback with the signature
|
||||
fun_viewpool(points) -> pooled_features
|
||||
where points is a [N_TGT x N x 3] tensor of world coords,
|
||||
@@ -160,11 +172,22 @@ class NeuralRadianceFieldBase(ImplicitFunctionBase, torch.nn.Module):
|
||||
denoting the opacity of each ray point.
|
||||
rays_colors: A tensor of shape `(minibatch, ..., num_points_per_ray, 3)`
|
||||
denoting the color of each ray point.
|
||||
|
||||
Raises:
|
||||
ValueError: If `use_integrated_positional_encoding` is True and
|
||||
`ray_bundle.bins` is None.
|
||||
"""
|
||||
# We first convert the ray parametrizations to world
|
||||
# coordinates with `ray_bundle_to_ray_points`.
|
||||
# pyre-ignore[6]
|
||||
rays_points_world = ray_bundle_to_ray_points(ray_bundle)
|
||||
if self.use_integrated_positional_encoding and ray_bundle.bins is None:
|
||||
raise ValueError(
|
||||
"When use_integrated_positional_encoding is True, ray_bundle.bins must be set."
|
||||
"Have you set to True `AbstractMaskRaySampler.use_bins_for_ray_sampling`?"
|
||||
)
|
||||
|
||||
rays_points_world, diag_cov = (
|
||||
conical_frustum_to_gaussian(ray_bundle)
|
||||
if self.use_integrated_positional_encoding
|
||||
else (ray_bundle_to_ray_points(ray_bundle), None) # pyre-ignore
|
||||
)
|
||||
# rays_points_world.shape = [minibatch x ... x pts_per_ray x 3]
|
||||
|
||||
embeds = create_embeddings_for_implicit_function(
|
||||
@@ -177,6 +200,7 @@ class NeuralRadianceFieldBase(ImplicitFunctionBase, torch.nn.Module):
|
||||
fun_viewpool=fun_viewpool,
|
||||
xyz_in_camera_coords=self.xyz_ray_dir_in_camera_coords,
|
||||
camera=camera,
|
||||
diag_cov=diag_cov,
|
||||
)
|
||||
|
||||
# embeds.shape = [minibatch x n_src x n_rays x n_pts x self.n_harmonic_functions*6+3]
|
||||
|
||||
@@ -36,6 +36,7 @@ def create_embeddings_for_implicit_function(
|
||||
camera: Optional[CamerasBase],
|
||||
fun_viewpool: Optional[Callable],
|
||||
xyz_embedding_function: Optional[Callable],
|
||||
diag_cov: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
|
||||
bs, *spatial_size, pts_per_ray, _ = xyz_world.shape
|
||||
@@ -59,11 +60,11 @@ def create_embeddings_for_implicit_function(
|
||||
prod(spatial_size),
|
||||
pts_per_ray,
|
||||
0,
|
||||
dtype=xyz_world.dtype,
|
||||
device=xyz_world.device,
|
||||
)
|
||||
else:
|
||||
embeds = xyz_embedding_function(ray_points_for_embed).reshape(
|
||||
|
||||
embeds = xyz_embedding_function(ray_points_for_embed, diag_cov=diag_cov)
|
||||
embeds = embeds.reshape(
|
||||
bs,
|
||||
1,
|
||||
prod(spatial_size),
|
||||
|
||||
@@ -4,6 +4,8 @@
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
@@ -16,8 +18,18 @@ class HarmonicEmbedding(torch.nn.Module):
|
||||
append_input: bool = True,
|
||||
) -> None:
|
||||
"""
|
||||
Given an input tensor `x` of shape [minibatch, ... , dim],
|
||||
the harmonic embedding layer converts each feature
|
||||
The harmonic embedding layer supports the classical
|
||||
Nerf positional encoding described in
|
||||
`NeRF <https://arxiv.org/abs/2003.08934>`_
|
||||
and the integrated position encoding in
|
||||
`MIP-NeRF <https://arxiv.org/abs/2103.13415>`_.
|
||||
|
||||
During inference, you can provide the extra argument `diag_cov`.
|
||||
|
||||
If `diag_cov is None`, it converts
|
||||
rays parametrized with a `ray_bundle` to 3D points by
|
||||
extending each ray according to the corresponding length.
|
||||
Then it converts each feature
|
||||
(i.e. vector along the last dimension) in `x`
|
||||
into a series of harmonic features `embedding`,
|
||||
where for each i in range(dim) the following are present
|
||||
@@ -38,6 +50,31 @@ class HarmonicEmbedding(torch.nn.Module):
|
||||
where N corresponds to `n_harmonic_functions-1`, and f_i is a scalar
|
||||
denoting the i-th frequency of the harmonic embedding.
|
||||
|
||||
|
||||
If `diag_cov is not None`, it approximates
|
||||
conical frustums following a ray bundle as gaussians,
|
||||
defined by x, the means of the gaussians and diag_cov,
|
||||
the diagonal covariances.
|
||||
Then it converts each gaussian
|
||||
into a series of harmonic features `embedding`,
|
||||
where for each i in range(dim) the following are present
|
||||
in embedding[...]::
|
||||
|
||||
[
|
||||
sin(f_1*x[..., i]) * exp(-0.5 * f_1**2 * diag_cov[..., i]),
|
||||
sin(f_2*x[..., i]) * exp(-0.5 * f_2**2 * diag_cov[..., i]),
|
||||
...
|
||||
sin(f_N * x[..., i]) * exp(-0.5 * f_N**2 * diag_cov[..., i]),
|
||||
cos(f_1*x[..., i]) * exp(-0.5 * f_1**2 * diag_cov[..., i]),
|
||||
cos(f_2*x[..., i]) * exp(-0.5 * f_2**2 * diag_cov[..., i]),
|
||||
...
|
||||
cos(f_N * x[..., i]) * exp(-0.5 * f_N**2 * diag_cov[..., i]),
|
||||
x[..., i], # only present if append_input is True.
|
||||
]
|
||||
|
||||
where N equals `n_harmonic_functions-1`, and f_i is a scalar
|
||||
denoting the i-th frequency of the harmonic embedding.
|
||||
|
||||
If `logspace==True`, the frequencies `[f_1, ..., f_N]` are
|
||||
powers of 2:
|
||||
`f_1, ..., f_N = 2**torch.arange(n_harmonic_functions)`
|
||||
@@ -59,8 +96,7 @@ class HarmonicEmbedding(torch.nn.Module):
|
||||
logspace or linear space
|
||||
append_input: bool, whether to concat the original
|
||||
input to the harmonic embedding. If true the
|
||||
output is of the form (x, embed.sin(), embed.cos()
|
||||
|
||||
output is of the form (embed.sin(), embed.cos(), x)
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
@@ -78,23 +114,42 @@ class HarmonicEmbedding(torch.nn.Module):
|
||||
)
|
||||
|
||||
self.register_buffer("_frequencies", frequencies * omega_0, persistent=False)
|
||||
self.register_buffer(
|
||||
"_zero_half_pi", torch.tensor([0.0, 0.5 * torch.pi]), persistent=False
|
||||
)
|
||||
self.append_input = append_input
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
def forward(
|
||||
self, x: torch.Tensor, diag_cov: Optional[torch.Tensor] = None, **kwargs
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
x: tensor of shape [..., dim]
|
||||
diag_cov: An optional tensor of shape `(..., dim)`
|
||||
representing the diagonal covariance matrices of our Gaussians, joined with x
|
||||
as means of the Gaussians.
|
||||
|
||||
Returns:
|
||||
embedding: a harmonic embedding of `x`
|
||||
of shape [..., (n_harmonic_functions * 2 + int(append_input)) * dim]
|
||||
embedding: a harmonic embedding of `x` of shape
|
||||
[..., (n_harmonic_functions * 2 + int(append_input)) * dim]
|
||||
"""
|
||||
embed = (x[..., None] * self._frequencies).reshape(*x.shape[:-1], -1)
|
||||
embed = torch.cat(
|
||||
(embed.sin(), embed.cos(), x)
|
||||
if self.append_input
|
||||
else (embed.sin(), embed.cos()),
|
||||
dim=-1,
|
||||
)
|
||||
# [..., dim, n_harmonic_functions]
|
||||
embed = x[..., None] * self._frequencies
|
||||
# [..., 1, dim, n_harmonic_functions] + [2, 1, 1] => [..., 2, dim, n_harmonic_functions]
|
||||
embed = embed[..., None, :, :] + self._zero_half_pi[..., None, None]
|
||||
# Use the trig identity cos(x) = sin(x + pi/2)
|
||||
# and do one vectorized call to sin([x, x+pi/2]) instead of (sin(x), cos(x)).
|
||||
embed = embed.sin()
|
||||
if diag_cov is not None:
|
||||
x_var = diag_cov[..., None] * torch.pow(self._frequencies, 2)
|
||||
exp_var = torch.exp(-0.5 * x_var)
|
||||
# [..., 2, dim, n_harmonic_functions]
|
||||
embed = embed * exp_var[..., None, :, :]
|
||||
|
||||
embed = embed.reshape(*x.shape[:-1], -1)
|
||||
|
||||
if self.append_input:
|
||||
return torch.cat([embed, x], dim=-1)
|
||||
return embed
|
||||
|
||||
@staticmethod
|
||||
|
||||
Reference in New Issue
Block a user