pytorch3d/projects/nerf/nerf/nerf_renderer.py
Jeremy Reizenstein 9eeb456e82 Update license for company name
Summary: Update all FB license strings to the new format.

Reviewed By: patricklabatut

Differential Revision: D33403538

fbshipit-source-id: 97a4596c5c888f3c54f44456dc07e718a387a02c
2022-01-04 11:43:38 -08:00

435 lines
18 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import List, Optional, Tuple
import torch
from pytorch3d.renderer import ImplicitRenderer, ray_bundle_to_ray_points
from pytorch3d.renderer.cameras import CamerasBase
from pytorch3d.structures import Pointclouds
from pytorch3d.vis.plotly_vis import plot_scene
from visdom import Visdom
from .implicit_function import NeuralRadianceField
from .raymarcher import EmissionAbsorptionNeRFRaymarcher
from .raysampler import NeRFRaysampler, ProbabilisticRaysampler
from .utils import calc_mse, calc_psnr, sample_images_at_mc_locs
class RadianceFieldRenderer(torch.nn.Module):
    """
    Implements a renderer of a Neural Radiance Field.

    This class holds pointers to the fine and coarse renderer objects, which are
    instances of `pytorch3d.renderer.ImplicitRenderer`, and pointers to the
    neural networks representing the fine and coarse Neural Radiance Fields,
    which are instances of `NeuralRadianceField`.

    The rendering forward pass proceeds as follows:
        1) For a given input camera, rendering rays are generated with the
            `NeRFRaysampler` object of `self._renderer['coarse']`.
            In the training mode (`self.training==True`), the rays are a set
                of `n_rays_per_image` random 2D locations of the image grid.
            In the evaluation mode (`self.training==False`), the rays correspond
                to the full image grid. The rays are further split to
                `chunk_size_test`-sized chunks to prevent out-of-memory errors.
        2) For each ray point, the coarse `NeuralRadianceField` MLP is evaluated.
            The pointer to this MLP is stored in `self._implicit_function['coarse']`.
        3) The coarse radiance field is rendered with the
            `EmissionAbsorptionNeRFRaymarcher` object of `self._renderer['coarse']`.
        4) The coarse raymarcher outputs a probability distribution that guides
            the importance raysampling of the fine rendering pass. The
            `ProbabilisticRaysampler` stored in `self._renderer['fine'].raysampler`
            implements the importance ray-sampling.
        5) Similar to 2) the fine MLP in `self._implicit_function['fine']`
            labels the ray points with occupancies and colors.
        6) `self._renderer['fine'].raymarcher` generates the final fine render.
        7) The fine and coarse renders are compared to the ground truth input image
            with PSNR and MSE metrics.
    """

    def __init__(
        self,
        image_size: Tuple[int, int],
        n_pts_per_ray: int,
        n_pts_per_ray_fine: int,
        n_rays_per_image: int,
        min_depth: float,
        max_depth: float,
        stratified: bool,
        stratified_test: bool,
        chunk_size_test: int,
        n_harmonic_functions_xyz: int = 6,
        n_harmonic_functions_dir: int = 4,
        n_hidden_neurons_xyz: int = 256,
        n_hidden_neurons_dir: int = 128,
        n_layers_xyz: int = 8,
        append_xyz: Tuple[int, ...] = (5,),
        density_noise_std: float = 0.0,
        visualization: bool = False,
    ):
        """
        Args:
            image_size: The size of the rendered image (`[height, width]`).
            n_pts_per_ray: The number of points sampled along each ray for the
                coarse rendering pass.
            n_pts_per_ray_fine: The number of points sampled along each ray for the
                fine rendering pass.
            n_rays_per_image: Number of Monte Carlo ray samples when training
                (`self.training==True`).
            min_depth: The minimum depth of a sampled ray-point for the coarse rendering.
            max_depth: The maximum depth of a sampled ray-point for the coarse rendering.
            stratified: If `True`, stratifies (=randomly offsets) the depths
                of each ray point during training (`self.training==True`).
            stratified_test: If `True`, stratifies (=randomly offsets) the depths
                of each ray point during evaluation (`self.training==False`).
            chunk_size_test: The number of rays in each chunk of image rays.
                Used to split the full-image render into chunks in evaluation
                mode (`self.training==False`); in training mode a single chunk
                is rendered.
            n_harmonic_functions_xyz: The number of harmonic functions
                used to form the harmonic embedding of 3D point locations.
            n_harmonic_functions_dir: The number of harmonic functions
                used to form the harmonic embedding of the ray directions.
            n_hidden_neurons_xyz: The number of hidden units in the
                fully connected layers of the MLP that accepts the 3D point
                locations and outputs the occupancy field with the intermediate
                features.
            n_hidden_neurons_dir: The number of hidden units in the
                fully connected layers of the MLP that accepts the intermediate
                features and ray directions and outputs the radiance field
                (per-point colors).
            n_layers_xyz: The number of layers of the MLP that outputs the
                occupancy field.
            append_xyz: The indices of the skip layers of the occupancy MLP.
                Prior to evaluating the skip layers, the tensor which was input to MLP
                is appended to the skip layer input.
            density_noise_std: The standard deviation of the random normal noise
                added to the output of the occupancy MLP.
                Active only when `self.training==True`.
            visualization: whether to store extra output for visualization.
        """
        super().__init__()

        # The renderers and implicit functions are stored under the fine/coarse
        # keys in ModuleDict PyTorch modules.
        self._renderer = torch.nn.ModuleDict()
        self._implicit_function = torch.nn.ModuleDict()

        # Init the EA raymarcher; the same instance is shared by both passes.
        raymarcher = EmissionAbsorptionNeRFRaymarcher()

        # Parse out image dimensions.
        image_height, image_width = image_size

        # The coarse modules are built (and registered) before the fine ones.
        # Keep this interleaved order: it fixes the ModuleDict registration
        # order and, presumably, the RNG order of parameter initialization.
        for render_pass in ("coarse", "fine"):
            if render_pass == "coarse":
                # Initialize the coarse raysampler.
                raysampler = NeRFRaysampler(
                    n_pts_per_ray=n_pts_per_ray,
                    min_depth=min_depth,
                    max_depth=max_depth,
                    stratified=stratified,
                    stratified_test=stratified_test,
                    n_rays_per_image=n_rays_per_image,
                    image_height=image_height,
                    image_width=image_width,
                )
            elif render_pass == "fine":
                # Initialize the fine (importance) raysampler.
                raysampler = ProbabilisticRaysampler(
                    n_pts_per_ray=n_pts_per_ray_fine,
                    stratified=stratified,
                    stratified_test=stratified_test,
                )
            else:
                raise ValueError(f"No such rendering pass {render_pass}")

            # Initialize the fine/coarse renderer.
            self._renderer[render_pass] = ImplicitRenderer(
                raysampler=raysampler,
                raymarcher=raymarcher,
            )

            # Instantiate the fine/coarse NeuralRadianceField module.
            self._implicit_function[render_pass] = NeuralRadianceField(
                n_harmonic_functions_xyz=n_harmonic_functions_xyz,
                n_harmonic_functions_dir=n_harmonic_functions_dir,
                n_hidden_neurons_xyz=n_hidden_neurons_xyz,
                n_hidden_neurons_dir=n_hidden_neurons_dir,
                n_layers_xyz=n_layers_xyz,
                append_xyz=append_xyz,
            )

        self._density_noise_std = density_noise_std
        self._chunk_size_test = chunk_size_test
        self._image_size = image_size
        self.visualization = visualization

    def precache_rays(
        self,
        cache_cameras: List[CamerasBase],
        cache_camera_hashes: List[str],
    ):
        """
        Precaches the rays emitted from the list of cameras `cache_cameras`,
        where each camera is uniquely identified with the corresponding hash
        from `cache_camera_hashes`.

        The cached rays are moved to cpu and stored in
        `self._renderer['coarse']._ray_cache`.

        Raises `ValueError` when caching two cameras with the same hash.

        Args:
            cache_cameras: A list of `N` cameras for which the rays are pre-cached.
            cache_camera_hashes: A list of `N` unique identifiers for each
                camera from `cache_cameras`.
        """
        # Only the coarse raysampler generates rays from cameras; the fine
        # raysampler resamples the coarse rays.
        self._renderer["coarse"].raysampler.precache_rays(
            cache_cameras,
            cache_camera_hashes,
        )

    def _process_ray_chunk(
        self,
        camera_hash: Optional[str],
        camera: CamerasBase,
        image: Optional[torch.Tensor],
        chunk_idx: int,
    ) -> dict:
        """
        Samples and renders a chunk of rays.

        Args:
            camera_hash: A unique identifier of a pre-cached camera.
                If `None`, the cache is not searched and the sampled rays are
                calculated from scratch.
            camera: A batch of cameras from which the scene is rendered.
            image: The ground truth image(s), or `None` to skip sampling
                the ground truth colors. Only the first 3 channels are used.
                NOTE(review): a leading batch dimension is prepended with
                `[None]` before sampling, so this appears to expect a single
                `(H, W, C)` image rather than a batch — confirm with callers.
            chunk_idx: The index of the currently rendered ray chunk.

        Returns:
            out: `dict` containing the outputs of the rendering:
                `rgb_coarse`: The result of the coarse rendering pass.
                `rgb_fine`: The result of the fine rendering pass.
                `rgb_gt`: The corresponding ground-truth RGB values, or `None`
                    if `image` is `None`.
                If `self.visualization` is `True`, additionally:
                `coarse_ray_bundle`: The coarse rays, detached and moved to cpu.
                `coarse_weights`: The coarse ray weights, detached and moved to cpu.
        """
        # Initialize the outputs of the coarse rendering to None.
        coarse_ray_bundle = None
        coarse_weights = None

        # First evaluate the coarse rendering pass, then the fine one.
        for renderer_pass in ("coarse", "fine"):
            (rgb, weights), ray_bundle_out = self._renderer[renderer_pass](
                cameras=camera,
                volumetric_function=self._implicit_function[renderer_pass],
                chunksize=self._chunk_size_test,
                chunk_idx=chunk_idx,
                # Density noise is a training-time regularizer only.
                density_noise_std=(self._density_noise_std if self.training else 0.0),
                # Both are None for the coarse pass; the fine pass resamples
                # the coarse rays guided by the coarse weights.
                input_ray_bundle=coarse_ray_bundle,
                ray_weights=coarse_weights,
                camera_hash=camera_hash,
            )

            if renderer_pass == "coarse":
                rgb_coarse = rgb
                # Store the weights and the rays of the first rendering pass
                # for the ensuing importance ray-sampling of the fine render.
                coarse_ray_bundle = ray_bundle_out
                coarse_weights = weights
                if image is not None:
                    # Sample the ground truth images at the xy locations of the
                    # rendering ray pixels.
                    rgb_gt = sample_images_at_mc_locs(
                        image[..., :3][None],
                        ray_bundle_out.xys,
                    )
                else:
                    rgb_gt = None

            elif renderer_pass == "fine":
                rgb_fine = rgb

            else:
                raise ValueError(f"No such rendering pass {renderer_pass}")

        out = {"rgb_fine": rgb_fine, "rgb_coarse": rgb_coarse, "rgb_gt": rgb_gt}
        if self.visualization:
            # Store the coarse rays/weights only for visualization purposes.
            out["coarse_ray_bundle"] = type(coarse_ray_bundle)(
                *[v.detach().cpu() for v in coarse_ray_bundle._asdict().values()]
            )
            out["coarse_weights"] = coarse_weights.detach().cpu()

        return out

    def forward(
        self,
        camera_hash: Optional[str],
        camera: CamerasBase,
        image: Optional[torch.Tensor],
    ) -> Tuple[dict, dict]:
        """
        Performs the coarse and fine rendering passes of the radiance field
        from the viewpoint of the input `camera`.
        Afterwards, both renders are compared to the input ground truth `image`
        by evaluating the peak signal-to-noise ratio and the mean-squared error.

        The rendering result depends on the `self.training` flag:
            - In the training mode (`self.training==True`), the function renders
              a random subset of image rays (Monte Carlo rendering).
            - In evaluation mode (`self.training==False`), the function renders
              the full image. In order to prevent out-of-memory errors,
              when `self.training==False`, the rays are sampled and rendered
              in batches of size `chunk_size_test`.

        Args:
            camera_hash: A unique identifier of a pre-cached camera.
                If `None`, the cache is not searched and the sampled rays are
                calculated from scratch.
            camera: A batch of cameras from which the scene is rendered.
            image: The ground truth image(s), or `None` to skip computing
                `rgb_gt` and the error metrics. Only the first 3 channels
                are used.

        Returns:
            out: `dict` containing the outputs of the rendering:
                `rgb_coarse`: The result of the coarse rendering pass.
                `rgb_fine`: The result of the fine rendering pass.
                `rgb_gt`: The corresponding ground-truth RGB values
                    (`None` if `image` is `None`).

                The shape of `rgb_coarse`, `rgb_fine`, `rgb_gt` depends on the
                `self.training` flag:
                    If `==True`, all 3 tensors are of shape
                    `(batch_size, n_rays_per_image, 3)` and contain the result
                    of the Monte Carlo training rendering pass.
                    If `==False`, all 3 tensors are of shape
                    `(batch_size, image_size[0], image_size[1], 3)` and contain
                    the result of the full image rendering pass.

                In training mode with `self.visualization==True`, `out` also
                holds `coarse_ray_bundle` and `coarse_weights`; these are not
                returned in evaluation mode.
            metrics: `dict` containing the error metrics comparing the fine and
                coarse renders to the ground truth (empty if `image` is `None`):
                `mse_coarse`: Mean-squared error between the coarse render and
                    the input `image`
                `mse_fine`: Mean-squared error between the fine render and
                    the input `image`
                `psnr_coarse`: Peak signal-to-noise ratio between the coarse render and
                    the input `image`
                `psnr_fine`: Peak signal-to-noise ratio between the fine render and
                    the input `image`
        """
        if not self.training:
            # Full evaluation pass.
            n_chunks = self._renderer["coarse"].raysampler.get_n_chunks(
                self._chunk_size_test,
                camera.R.shape[0],
            )
        else:
            # MonteCarlo ray sampling.
            n_chunks = 1

        # Process the chunks of rays.
        chunk_outputs = [
            self._process_ray_chunk(
                camera_hash,
                camera,
                image,
                chunk_idx,
            )
            for chunk_idx in range(n_chunks)
        ]

        if not self.training:
            # For a full render pass concatenate the output chunks,
            # and reshape to image size.
            out = {
                k: torch.cat(
                    [ch_o[k] for ch_o in chunk_outputs],
                    dim=1,
                ).view(-1, *self._image_size, 3)
                if chunk_outputs[0][k] is not None
                else None
                for k in ("rgb_fine", "rgb_coarse", "rgb_gt")
            }
        else:
            out = chunk_outputs[0]

        # Calc the error metrics.
        metrics = {}
        if image is not None:
            for render_pass in ("coarse", "fine"):
                for metric_name, metric_fun in zip(
                    ("mse", "psnr"), (calc_mse, calc_psnr)
                ):
                    metrics[f"{metric_name}_{render_pass}"] = metric_fun(
                        out["rgb_" + render_pass][..., :3],
                        out["rgb_gt"][..., :3],
                    )

        return out, metrics
def visualize_nerf_outputs(
    nerf_out: dict, output_cache: List, viz: Visdom, visdom_env: str
):
    """
    Publishes the outputs of `RadianceFieldRenderer` to a visdom server.

    Three windows of `visdom_env` are updated:
        - "images": the training images stored in `output_cache`, placed
          side by side;
        - "images_full": the coarse render, fine render and target image of
          the first element of `nerf_out`, placed side by side and clamped
          to [0, 1];
        - "scenes": a plotly 3D scene with the training cameras and the
          points sampled along their coarse rays.

    Args:
        nerf_out: An output of the validation rendering pass; read keys are
            "rgb_coarse", "rgb_fine" and "rgb_gt".
        output_cache: A list with outputs of several training render passes;
            each entry is read at the keys "image", "camera" and
            "coarse_ray_bundle".
        viz: A visdom connection object.
        visdom_env: The name of visdom environment for visualization.
    """
    # Training images placed next to each other along dim 1.
    # NOTE(review): assumes each "image" is (H, W, C) so that the permute
    # below yields channel-first (C, H, W) — TODO confirm.
    stacked = torch.stack([entry["image"] for entry in output_cache])
    train_strip = torch.cat(stacked.unbind(0), dim=1)
    viz.image(
        train_strip.permute(2, 0, 1),
        env=visdom_env,
        win="images",
        opts={"title": "train_images"},
    )

    # coarse | fine | target panels of the first batch element, channel-first.
    panels = []
    for key in ("rgb_coarse", "rgb_fine", "rgb_gt"):
        channel_first = nerf_out[key][0].permute(2, 0, 1)
        panels.append(channel_first.detach().cpu().clamp(0.0, 1.0))
    side_by_side = torch.cat(panels, dim=2)
    viz.image(
        side_by_side,
        env=visdom_env,
        win="images_full",
        opts={"title": "coarse | fine | target"},
    )

    # 3D scene: all cameras first, then the point clouds of their coarse rays.
    scene_content = {}
    for idx, entry in enumerate(output_cache):
        scene_content[f"camera_{idx:03d}"] = entry["camera"].cpu()
    for idx, entry in enumerate(output_cache):
        points = ray_bundle_to_ray_points(entry["coarse_ray_bundle"])
        flat_points = points.detach().cpu().view(1, -1, 3)
        scene_content[f"ray_pts_{idx:03d}"] = Pointclouds(flat_points)

    figure = plot_scene(
        {"training_scene": scene_content},
        pointcloud_max_points=5000,
        pointcloud_marker_size=1,
        camera_scale=0.3,
    )
    viz.plotlyplot(figure, env=visdom_env, win="scenes")