mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
Implicit function
Summary: Implements the radiance field function of NeRF Reviewed By: nikhilaravi Differential Revision: D25684413 fbshipit-source-id: 4bf6dd5d22e6134a09f7b9f314536ec16670f737
This commit is contained in:
parent
1e82341da7
commit
bf633ab556
275
projects/nerf/nerf/implicit_function.py
Normal file
275
projects/nerf/nerf/implicit_function.py
Normal file
@ -0,0 +1,275 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
from pytorch3d.renderer import RayBundle, ray_bundle_to_ray_points
|
||||
|
||||
from .harmonic_embedding import HarmonicEmbedding
|
||||
|
||||
|
||||
def _xavier_init(linear):
|
||||
"""
|
||||
Performs the Xavier weight initialization of the linear layer `linear`.
|
||||
"""
|
||||
torch.nn.init.xavier_uniform_(linear.weight.data)
|
||||
|
||||
|
||||
class NeuralRadianceField(torch.nn.Module):
    """
    The Neural Radiance Field (NeRF) module.

    Given a bundle of sampling rays, attaches to each 3D ray point a 1D
    opacity (density) and a 3D RGB color:
      1) the harmonic embeddings of the ray points are mapped by an MLP
         with input skips (`self.mlp_xyz`) to per-point features;
      2) `self.density_layer` converts the features to raw densities, which
         `_get_densities` re-weights by the depth step sizes and squashes
         to the [0-1] range;
      3) `self.color_layer` decodes the features, concatenated with the
         harmonic embedding of the normalized ray directions, into RGB
         colors in [0-1] (sigmoid output).
    """

    def __init__(
        self,
        n_harmonic_functions_xyz: int = 6,
        n_harmonic_functions_dir: int = 4,
        n_hidden_neurons_xyz: int = 256,
        n_hidden_neurons_dir: int = 128,
        n_layers_xyz: int = 8,
        append_xyz: List[int] = (5,),
        **kwargs,
    ):
        """
        Args:
            n_harmonic_functions_xyz: The number of harmonic functions
                used to form the harmonic embedding of 3D point locations.
            n_harmonic_functions_dir: The number of harmonic functions
                used to form the harmonic embedding of the ray directions.
            n_hidden_neurons_xyz: The number of hidden units in the
                fully connected layers of the MLP that accepts the 3D point
                locations and outputs the occupancy field with the intermediate
                features.
            n_hidden_neurons_dir: The number of hidden units in the
                fully connected layers of the MLP that accepts the intermediate
                features and ray directions and outputs the radiance field
                (per-point colors).
            n_layers_xyz: The number of layers of the MLP that outputs the
                occupancy field.
            append_xyz: The list of indices of the skip layers of the occupancy MLP.
        """
        super().__init__()

        # The harmonic embedding layer converts input 3D coordinates
        # to a representation that is more suitable for
        # processing with a deep neural network.
        self.harmonic_embedding_xyz = HarmonicEmbedding(n_harmonic_functions_xyz)
        self.harmonic_embedding_dir = HarmonicEmbedding(n_harmonic_functions_dir)
        # Embedding dimensions: 2 values (sin, cos) per harmonic function per
        # coordinate, plus 3 channels for the raw xyz/direction vector that is
        # concatenated to the embedding (see `forward` and `_get_colors`).
        embedding_dim_xyz = n_harmonic_functions_xyz * 2 * 3 + 3
        embedding_dim_dir = n_harmonic_functions_dir * 2 * 3 + 3

        # MLP mapping embedded 3D points to intermediate features; the point
        # embedding is re-appended at the layer indices given by `append_xyz`.
        self.mlp_xyz = MLPWithInputSkips(
            n_layers_xyz,
            embedding_dim_xyz,
            n_hidden_neurons_xyz,
            embedding_dim_xyz,
            n_hidden_neurons_xyz,
            input_skips=append_xyz,
        )

        # Extra linear layer applied to the features before they enter the
        # color branch (see `_get_colors`).
        self.intermediate_linear = torch.nn.Linear(
            n_hidden_neurons_xyz, n_hidden_neurons_xyz
        )
        _xavier_init(self.intermediate_linear)

        # Maps per-point features to a single raw density value.
        self.density_layer = torch.nn.Linear(n_hidden_neurons_xyz, 1)
        _xavier_init(self.density_layer)

        # Zero the bias of the density layer to avoid
        # a completely transparent initialization.
        self.density_layer.bias.data[:] = 0.0  # fixme: Sometimes this is not enough

        # Maps features + embedded view direction to an RGB color in [0-1]
        # (the final Sigmoid bounds the output).
        self.color_layer = torch.nn.Sequential(
            torch.nn.Linear(
                n_hidden_neurons_xyz + embedding_dim_dir, n_hidden_neurons_dir
            ),
            torch.nn.ReLU(True),
            torch.nn.Linear(n_hidden_neurons_dir, 3),
            torch.nn.Sigmoid(),
        )

    def _get_densities(
        self,
        features: torch.Tensor,
        depth_values: torch.Tensor,
        density_noise_std: float,
    ):
        """
        This function takes `features` predicted by `self.mlp_xyz`
        and converts them to `raw_densities` with `self.density_layer`.
        `raw_densities` are later re-weighted using the depth step sizes
        and mapped to [0-1] range with 1 - inverse exponential of `raw_densities`.

        Args:
            features: Per-point features as output by `self.mlp_xyz`.
            depth_values: Depths of the ray samples of shape
                `(minibatch, ..., num_points_per_ray)`.
            density_noise_std: Standard deviation of the Gaussian noise added
                to the raw densities; 0.0 disables the noise.

        Returns:
            densities: Tensor of shape `(minibatch, ..., num_points_per_ray, 1)`
                with values in [0-1].
        """
        raw_densities = self.density_layer(features)
        # Per-sample depth steps; the final step is set to a huge value (1e10)
        # so the last ray point absorbs the remaining transmittance.
        deltas = torch.cat(
            (
                depth_values[..., 1:] - depth_values[..., :-1],
                1e10 * torch.ones_like(depth_values[..., :1]),
            ),
            dim=-1,
        )[..., None]
        if density_noise_std > 0.0:
            # Regularizing noise; per the `forward` docstring, this can
            # prevent floating artifacts.
            raw_densities = (
                raw_densities + torch.randn_like(raw_densities) * density_noise_std
            )
        densities = 1 - (-deltas * torch.relu(raw_densities)).exp()
        return densities

    def _get_colors(self, features: torch.Tensor, rays_directions: torch.Tensor):
        """
        This function takes per-point `features` predicted by `self.mlp_xyz`
        and evaluates the color model in order to attach to each
        point a 3D vector of its RGB color.

        Args:
            features: Per-point features as output by `self.mlp_xyz`.
            rays_directions: Per-ray direction vectors of shape
                `(minibatch, ..., 3)`; normalized internally.

        Returns:
            A tensor of per-point RGB colors in [0-1].
        """
        spatial_size = features.shape[:-1]
        # Normalize the ray_directions to unit l2 norm.
        rays_directions_normed = torch.nn.functional.normalize(rays_directions, dim=-1)

        # Obtain the harmonic embedding of the normalized ray directions.
        rays_embedding = torch.cat(
            (
                self.harmonic_embedding_dir(rays_directions_normed),
                rays_directions_normed,
            ),
            dim=-1,
        )

        # Expand the ray directions tensor so that its spatial size
        # is equal to the size of features.
        rays_embedding_expand = rays_embedding[..., None, :].expand(
            *spatial_size, rays_embedding.shape[-1]
        )

        # Concatenate ray direction embeddings with
        # features and evaluate the color model.
        color_layer_input = torch.cat(
            (self.intermediate_linear(features), rays_embedding_expand), dim=-1
        )
        return self.color_layer(color_layer_input)

    def forward(
        self,
        ray_bundle: RayBundle,
        density_noise_std: float = 0.0,
        **kwargs,
    ):
        """
        The forward function accepts the parametrizations of
        3D points sampled along projection rays. The forward
        pass is responsible for attaching a 3D vector
        and a 1D scalar representing the point's
        RGB color and opacity respectively.

        Args:
            ray_bundle: A RayBundle object containing the following variables:
                origins: A tensor of shape `(minibatch, ..., 3)` denoting the
                    origins of the sampling rays in world coords.
                directions: A tensor of shape `(minibatch, ..., 3)`
                    containing the direction vectors of sampling rays in world coords.
                lengths: A tensor of shape `(minibatch, ..., num_points_per_ray)`
                    containing the lengths at which the rays are sampled.
            density_noise_std: A floating point value representing the
                variance of the random normal noise added to the output of
                the opacity function. This can prevent floating artifacts.

        Returns:
            rays_densities: A tensor of shape `(minibatch, ..., num_points_per_ray, 1)`
                denoting the opacity of each ray point.
            rays_colors: A tensor of shape `(minibatch, ..., num_points_per_ray, 3)`
                denoting the color of each ray point.
        """
        # We first convert the ray parametrizations to world
        # coordinates with `ray_bundle_to_ray_points`.
        rays_points_world = ray_bundle_to_ray_points(ray_bundle)
        # rays_points_world.shape = [minibatch x ... x 3]

        # For each 3D world coordinate, we obtain its harmonic embedding.
        embeds_xyz = torch.cat(
            (self.harmonic_embedding_xyz(rays_points_world), rays_points_world),
            dim=-1,
        )
        # embeds_xyz.shape = [minibatch x ... x self.n_harmonic_functions*6 + 3]

        # self.mlp maps each harmonic embedding to a latent feature space.
        # The embedding is passed twice: once as the input and once as the
        # skip tensor re-appended at the skip layers.
        features = self.mlp_xyz(embeds_xyz, embeds_xyz)
        # features.shape = [minibatch x ... x self.n_hidden_neurons_xyz]

        rays_densities = self._get_densities(
            features, ray_bundle.lengths, density_noise_std
        )
        # rays_densities.shape = [minibatch x ... x 1] in [0-1]

        rays_colors = self._get_colors(features, ray_bundle.directions)
        # rays_colors.shape = [minibatch x ... x 3] in [0-1]

        return rays_densities, rays_colors
|
||||
|
||||
|
||||
class MLPWithInputSkips(torch.nn.Module):
    """
    The multi-layer perceptron backbone of the Neural Radiance Field.

    A stack of `n_layers` linear layers, each followed by a ReLU activation.
    For every layer index listed in `input_skips`, the skip tensor `z` is
    concatenated to that layer's input along the last dimension.

    Note that this follows the architecture described in the Supplementary
    Material (Fig. 7) of [1].

    NOTE(review): `output_dim` is accepted but never used; every layer,
    including the last one, emits `hidden_dim` channels.

    References:
        [1] Ben Mildenhall and Pratul P. Srinivasan and Matthew Tancik
            and Jonathan T. Barron and Ravi Ramamoorthi and Ren Ng:
            NeRF: Representing Scenes as Neural Radiance Fields for View
            Synthesis, ECCV2020
    """

    def __init__(
        self,
        n_layers: int,
        input_dim: int,
        output_dim: int,
        skip_dim: int,
        hidden_dim: int,
        input_skips: List[int] = (),
    ):
        """
        Args:
            n_layers: The number of linear layers of the MLP.
            input_dim: The number of channels of the input tensor.
            output_dim: The number of channels of the output.
            skip_dim: The number of channels of the tensor `z` appended when
                evaluating the skip layers.
            hidden_dim: The number of hidden units of the MLP.
            input_skips: The list of layer indices at which we append the skip
                tensor `z`.
        """
        super().__init__()
        skip_at = set(input_skips)
        stages = []
        for stage_idx in range(n_layers):
            # The first layer consumes the raw input; skip layers consume the
            # previous activation with `z` concatenated; all layers produce
            # `hidden_dim` channels.
            if stage_idx == 0:
                in_features = input_dim
            elif stage_idx in skip_at:
                in_features = hidden_dim + skip_dim
            else:
                in_features = hidden_dim
            fc = torch.nn.Linear(in_features, hidden_dim)
            # Xavier-uniform weight initialization, matching the rest of the
            # model (same call as the module-level `_xavier_init` helper).
            torch.nn.init.xavier_uniform_(fc.weight.data)
            stages.append(torch.nn.Sequential(fc, torch.nn.ReLU(True)))
        self.mlp = torch.nn.ModuleList(stages)
        self._input_skips = skip_at

    def forward(self, x, z):
        """
        Args:
            x: The input tensor of shape `(..., input_dim)`.
            z: The input skip tensor of shape `(..., skip_dim)` which is appended
                to layers whose indices are specified by `input_skips`.
        Returns:
            y: The output tensor of shape `(..., hidden_dim)`.
        """
        activation = x
        for stage_idx, stage in enumerate(self.mlp):
            if stage_idx in self._input_skips:
                activation = torch.cat((activation, z), dim=-1)
            activation = stage(activation)
        return activation
|
Loading…
x
Reference in New Issue
Block a user