Radiance field renderer

Summary: Implements the main NeRF model class that controls the radiance field and its renderer

Reviewed By: nikhilaravi

Differential Revision: D25684419

fbshipit-source-id: fae45572daa6748c6234bd212f3e68110f778238
Author: David Novotny <2021-02-02 05:42:59 -08:00>, committed by Facebook GitHub Bot
Parent: bf633ab556
Commit: eb908487b8
2 changed files with 411 additions and 0 deletions


@ -0,0 +1,359 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from typing import Tuple, List, Optional
import torch
from pytorch3d.renderer import ImplicitRenderer
from pytorch3d.renderer.cameras import CamerasBase
from .implicit_function import NeuralRadianceField
from .raymarcher import EmissionAbsorptionNeRFRaymarcher
from .raysampler import NeRFRaysampler, ProbabilisticRaysampler
from .utils import sample_images_at_mc_locs, calc_psnr, calc_mse
class RadianceFieldRenderer(torch.nn.Module):
"""
Implements a renderer of a Neural Radiance Field.
This class holds pointers to the fine and coarse renderer objects, which are
instances of `pytorch3d.renderer.ImplicitRenderer`, and pointers to the
neural networks representing the fine and coarse Neural Radiance Fields,
which are instances of `NeuralRadianceField`.
The rendering forward pass proceeds as follows:
1) For a given input camera, rendering rays are generated with the
`NeRFRaysampler` object of `self._renderer['coarse']`.
In the training mode (`self.training==True`), the rays are a set
of `n_rays_per_image` random 2D locations of the image grid.
In the evaluation mode (`self.training==False`), the rays correspond
            to the full image grid. The rays are further split into
            `chunk_size_test`-sized chunks to prevent out-of-memory errors.
2) For each ray point, the coarse `NeuralRadianceField` MLP is evaluated.
            The pointer to this MLP is stored in `self._implicit_function['coarse']`.
3) The coarse radiance field is rendered with the
`EmissionAbsorptionNeRFRaymarcher` object of `self._renderer['coarse']`.
4) The coarse raymarcher outputs a probability distribution that guides
the importance raysampling of the fine rendering pass. The
`ProbabilisticRaysampler` stored in `self._renderer['fine'].raysampler`
implements the importance ray-sampling.
5) Similar to 2) the fine MLP in `self._implicit_function['fine']`
labels the ray points with occupancies and colors.
        6) The raymarcher in `self._renderer['fine']` generates the final fine render.
7) The fine and coarse renders are compared to the ground truth input image
with PSNR and MSE metrics.
"""
def __init__(
self,
image_size: Tuple[int, int],
n_pts_per_ray: int,
n_pts_per_ray_fine: int,
n_rays_per_image: int,
min_depth: float,
max_depth: float,
stratified: bool,
stratified_test: bool,
chunk_size_test: int,
n_harmonic_functions_xyz: int = 6,
n_harmonic_functions_dir: int = 4,
n_hidden_neurons_xyz: int = 256,
n_hidden_neurons_dir: int = 128,
n_layers_xyz: int = 8,
        append_xyz: Tuple[int, ...] = (5,),
density_noise_std: float = 0.0,
):
"""
Args:
image_size: The size of the rendered image (`[height, width]`).
n_pts_per_ray: The number of points sampled along each ray for the
coarse rendering pass.
n_pts_per_ray_fine: The number of points sampled along each ray for the
fine rendering pass.
n_rays_per_image: Number of Monte Carlo ray samples when training
(`self.training==True`).
min_depth: The minimum depth of a sampled ray-point for the coarse rendering.
max_depth: The maximum depth of a sampled ray-point for the coarse rendering.
stratified: If `True`, stratifies (=randomly offsets) the depths
of each ray point during training (`self.training==True`).
stratified_test: If `True`, stratifies (=randomly offsets) the depths
of each ray point during evaluation (`self.training==False`).
            chunk_size_test: The number of rays in each chunk of image rays.
                Active only when `self.training==False`.
n_harmonic_functions_xyz: The number of harmonic functions
used to form the harmonic embedding of 3D point locations.
n_harmonic_functions_dir: The number of harmonic functions
used to form the harmonic embedding of the ray directions.
n_hidden_neurons_xyz: The number of hidden units in the
fully connected layers of the MLP that accepts the 3D point
locations and outputs the occupancy field with the intermediate
features.
n_hidden_neurons_dir: The number of hidden units in the
fully connected layers of the MLP that accepts the intermediate
features and ray directions and outputs the radiance field
(per-point colors).
n_layers_xyz: The number of layers of the MLP that outputs the
occupancy field.
append_xyz: The list of indices of the skip layers of the occupancy MLP.
                Prior to evaluating the skip layers, the tensor which was input to
                the MLP is appended to the skip layer input.
density_noise_std: The standard deviation of the random normal noise
added to the output of the occupancy MLP.
Active only when `self.training==True`.
"""
super().__init__()
        # The renderers and implicit functions are stored under the
        # "coarse" / "fine" keys of `torch.nn.ModuleDict` containers.
self._renderer = torch.nn.ModuleDict()
self._implicit_function = torch.nn.ModuleDict()
# Init the EA raymarcher used by both passes.
raymarcher = EmissionAbsorptionNeRFRaymarcher()
# Parse out image dimensions.
image_height, image_width = image_size
for render_pass in ("coarse", "fine"):
if render_pass == "coarse":
# Initialize the coarse raysampler.
raysampler = NeRFRaysampler(
n_pts_per_ray=n_pts_per_ray,
min_depth=min_depth,
max_depth=max_depth,
stratified=stratified,
stratified_test=stratified_test,
n_rays_per_image=n_rays_per_image,
image_height=image_height,
image_width=image_width,
)
elif render_pass == "fine":
# Initialize the fine raysampler.
raysampler = ProbabilisticRaysampler(
n_pts_per_ray=n_pts_per_ray_fine,
stratified=stratified,
stratified_test=stratified_test,
)
else:
raise ValueError(f"No such rendering pass {render_pass}")
# Initialize the fine/coarse renderer.
self._renderer[render_pass] = ImplicitRenderer(
raysampler=raysampler,
raymarcher=raymarcher,
)
# Instantiate the fine/coarse NeuralRadianceField module.
self._implicit_function[render_pass] = NeuralRadianceField(
n_harmonic_functions_xyz=n_harmonic_functions_xyz,
n_harmonic_functions_dir=n_harmonic_functions_dir,
n_hidden_neurons_xyz=n_hidden_neurons_xyz,
n_hidden_neurons_dir=n_hidden_neurons_dir,
n_layers_xyz=n_layers_xyz,
append_xyz=append_xyz,
)
self._density_noise_std = density_noise_std
self._chunk_size_test = chunk_size_test
self._image_size = image_size
def precache_rays(
self,
cache_cameras: List[CamerasBase],
cache_camera_hashes: List[str],
):
"""
Precaches the rays emitted from the list of cameras `cache_cameras`,
where each camera is uniquely identified with the corresponding hash
from `cache_camera_hashes`.
        The cached rays are moved to the CPU and stored in
`self._renderer['coarse']._ray_cache`.
Raises `ValueError` when caching two cameras with the same hash.
Args:
cache_cameras: A list of `N` cameras for which the rays are pre-cached.
cache_camera_hashes: A list of `N` unique identifiers for each
                camera from `cache_cameras`.
"""
self._renderer["coarse"].raysampler.precache_rays(
cache_cameras,
cache_camera_hashes,
)
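
    # A hypothetical usage sketch (names like `train_cameras` are illustrative,
    # not from this file): precaching the per-camera ray grids once avoids
    # re-generating them on every evaluation pass.
    #
    #   renderer.precache_rays(
    #       cache_cameras=train_cameras,
    #       cache_camera_hashes=[f"cam_{i}" for i in range(len(train_cameras))],
    #   )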
def _process_ray_chunk(
self,
camera_hash: Optional[str],
camera: CamerasBase,
image: torch.Tensor,
chunk_idx: int,
) -> dict:
"""
Samples and renders a chunk of rays.
Args:
camera_hash: A unique identifier of a pre-cached camera.
If `None`, the cache is not searched and the sampled rays are
calculated from scratch.
camera: A batch of cameras from which the scene is rendered.
            image: The corresponding ground-truth image of shape `(·, ·, 3)`
                (a batch dimension is added internally before sampling).
chunk_idx: The index of the currently rendered ray chunk.
Returns:
out: `dict` containing the outputs of the rendering:
`rgb_coarse`: The result of the coarse rendering pass.
`rgb_fine`: The result of the fine rendering pass.
`rgb_gt`: The corresponding ground-truth RGB values.
"""
# Initialize the outputs of the coarse rendering to None.
coarse_ray_bundle = None
coarse_weights = None
# First evaluate the coarse rendering pass, then the fine one.
for renderer_pass in ("coarse", "fine"):
(rgb, weights), ray_bundle_out = self._renderer[renderer_pass](
cameras=camera,
volumetric_function=self._implicit_function[renderer_pass],
chunksize=self._chunk_size_test,
chunk_idx=chunk_idx,
density_noise_std=(self._density_noise_std if self.training else 0.0),
input_ray_bundle=coarse_ray_bundle,
ray_weights=coarse_weights,
camera_hash=camera_hash,
)
if renderer_pass == "coarse":
rgb_coarse = rgb
# Store the weights and the rays of the first rendering pass
# for the ensuing importance ray-sampling of the fine render.
coarse_ray_bundle = ray_bundle_out
coarse_weights = weights
if image is not None:
# Sample the ground truth images at the xy locations of the
# rendering ray pixels.
rgb_gt = sample_images_at_mc_locs(
image[..., :3][None],
ray_bundle_out.xys,
)
else:
rgb_gt = None
elif renderer_pass == "fine":
rgb_fine = rgb
else:
raise ValueError(f"No such rendering pass {renderer_pass}")
return {
"rgb_fine": rgb_fine,
"rgb_coarse": rgb_coarse,
"rgb_gt": rgb_gt,
# Store the coarse rays/weights only for visualization purposes.
"coarse_ray_bundle": type(coarse_ray_bundle)(
                *[v.detach().cpu() for v in coarse_ray_bundle._asdict().values()]
),
"coarse_weights": coarse_weights.detach().cpu(),
}
def forward(
self,
camera_hash: Optional[str],
camera: CamerasBase,
image: torch.Tensor,
) -> Tuple[dict, dict]:
"""
        Performs the coarse and fine rendering passes of the radiance field
from the viewpoint of the input `camera`.
Afterwards, both renders are compared to the input ground truth `image`
by evaluating the peak signal-to-noise ratio and the mean-squared error.
The rendering result depends on the `self.training` flag:
- In the training mode (`self.training==True`), the function renders
a random subset of image rays (Monte Carlo rendering).
- In evaluation mode (`self.training==False`), the function renders
the full image. In order to prevent out-of-memory errors,
when `self.training==False`, the rays are sampled and rendered
            in batches of size `chunk_size_test`.
Args:
camera_hash: A unique identifier of a pre-cached camera.
If `None`, the cache is not searched and the sampled rays are
calculated from scratch.
camera: A batch of cameras from which the scene is rendered.
            image: The corresponding ground-truth image of shape `(·, ·, 3)`
                (a batch dimension is added internally before sampling).
Returns:
out: `dict` containing the outputs of the rendering:
`rgb_coarse`: The result of the coarse rendering pass.
`rgb_fine`: The result of the fine rendering pass.
`rgb_gt`: The corresponding ground-truth RGB values.
The shape of `rgb_coarse`, `rgb_fine`, `rgb_gt` depends on the
`self.training` flag:
                If `self.training==True`, all 3 tensors are of shape
`(batch_size, n_rays_per_image, 3)` and contain the result
of the Monte Carlo training rendering pass.
                If `self.training==False`, all 3 tensors are of shape
`(batch_size, image_size[0], image_size[1], 3)` and contain
the result of the full image rendering pass.
metrics: `dict` containing the error metrics comparing the fine and
coarse renders to the ground truth:
`mse_coarse`: Mean-squared error between the coarse render and
the input `image`
`mse_fine`: Mean-squared error between the fine render and
the input `image`
`psnr_coarse`: Peak signal-to-noise ratio between the coarse render and
the input `image`
`psnr_fine`: Peak signal-to-noise ratio between the fine render and
the input `image`
"""
if not self.training:
# Full evaluation pass.
n_chunks = self._renderer["coarse"].raysampler.get_n_chunks(
self._chunk_size_test,
camera.R.shape[0],
)
else:
            # Monte Carlo ray sampling.
n_chunks = 1
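            # (In training, the raysampler draws a single Monte Carlo batch of
            # `n_rays_per_image` rays, so one chunk suffices.)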
# Process the chunks of rays.
chunk_outputs = [
self._process_ray_chunk(
camera_hash,
camera,
image,
chunk_idx,
)
for chunk_idx in range(n_chunks)
]
if not self.training:
# For a full render pass concatenate the output chunks,
# and reshape to image size.
out = {
k: torch.cat(
[ch_o[k] for ch_o in chunk_outputs],
dim=1,
).view(-1, *self._image_size, 3)
if chunk_outputs[0][k] is not None
else None
for k in ("rgb_fine", "rgb_coarse", "rgb_gt")
}
else:
out = chunk_outputs[0]
# Calc the error metrics.
metrics = {}
if image is not None:
for render_pass in ("coarse", "fine"):
for metric_name, metric_fun in zip(
("mse", "psnr"), (calc_mse, calc_psnr)
):
metrics[f"{metric_name}_{render_pass}"] = metric_fun(
out["rgb_" + render_pass][..., :3],
out["rgb_gt"][..., :3],
)
return out, metrics
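
To make the two-pass control flow concrete, here is a minimal usage sketch.
The constructor values and the synthetic camera/image below are illustrative
assumptions, not settings taken from this commit; the call signature and the
output shapes follow the `forward` docstring above.

import torch
from pytorch3d.renderer import PerspectiveCameras

renderer = RadianceFieldRenderer(
    image_size=(64, 64),
    n_pts_per_ray=32,
    n_pts_per_ray_fine=64,
    n_rays_per_image=1024,
    min_depth=0.1,
    max_depth=4.0,
    stratified=True,
    stratified_test=False,
    chunk_size_test=4096,
)

camera = PerspectiveCameras()  # a single default camera
image = torch.rand(64, 64, 3)  # toy ground-truth image of shape (H, W, 3)

# Training mode: Monte Carlo rendering of `n_rays_per_image` random rays.
renderer.train()
out, metrics = renderer(None, camera, image)
# out["rgb_fine"] has shape (1, 1024, 3); `metrics` holds the
# mse/psnr values for both passes.

# Evaluation mode: the full image grid, rendered chunk by chunk.
renderer.eval()
with torch.no_grad():
    out, metrics = renderer(None, camera, image)
# out["rgb_fine"] has shape (1, 64, 64, 3).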


@ -65,3 +65,55 @@ def sample_pdf(
samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])
return samples
def calc_mse(x: torch.Tensor, y: torch.Tensor):
"""
    Calculates the mean squared error between tensors `x` and `y`.
"""
return torch.mean((x - y) ** 2)
def calc_psnr(x: torch.Tensor, y: torch.Tensor):
"""
    Calculates the peak signal-to-noise ratio between tensors `x` and `y`.
"""
mse = calc_mse(x, y)
psnr = -10.0 * torch.log10(mse)
return psnr
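
# Worked example: if calc_mse(x, y) == 0.01, then calc_psnr(x, y)
# == -10 * log10(0.01) == 20.0 dB; halving the MSE raises the PSNR
# by roughly 3 dB.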
def sample_images_at_mc_locs(
target_images: torch.Tensor,
sampled_rays_xy: torch.Tensor,
):
"""
    Given a set of pixel locations `sampled_rays_xy`, this method samples the
    tensor `target_images` at the respective 2D locations.
    This function is used to extract the colors from ground-truth images that
    correspond to the colors rendered with Monte Carlo rendering.
Args:
target_images: A tensor of shape `(batch_size, ..., 3)`.
sampled_rays_xy: A tensor of shape `(batch_size, S_1, ..., S_N, 2)`.
Returns:
images_sampled: A tensor of shape `(batch_size, S_1, ..., S_N, 3)`
containing `target_images` sampled at `sampled_rays_xy`.
"""
ba = target_images.shape[0]
dim = target_images.shape[-1]
spatial_size = sampled_rays_xy.shape[1:-1]
# The coordinate grid convention for grid_sample has both x and y
# directions inverted.
xy_sample = -sampled_rays_xy.view(ba, -1, 1, 2).clone()
images_sampled = torch.nn.functional.grid_sample(
target_images.permute(0, 3, 1, 2),
xy_sample,
align_corners=True,
mode="bilinear",
)
return images_sampled.permute(0, 2, 3, 1).view(ba, *spatial_size, dim)
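
As a quick sanity check of the inverted-grid convention above, the following
sketch (toy tensors, illustrative only) samples a 4x4 image whose red channel
encodes the column index. The NDC location (+1, +1) maps through the sign flip
onto grid_sample's (-1, -1), i.e. the top-left pixel, so the sampled red value
should be 0:

import torch

target = torch.zeros(1, 4, 4, 3)  # (batch, H, W, 3)
target[0, :, :, 0] = torch.arange(4.0).view(1, 4) / 3  # red encodes column / 3

xy = torch.tensor([[[1.0, 1.0]]])  # (batch, 1, 2), NDC coordinates
sampled = sample_images_at_mc_locs(target, xy)
print(sampled.shape)  # torch.Size([1, 1, 3])
print(sampled[0, 0, 0])  # tensor(0.) -- the top-left (column 0) pixel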