mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 11:52:50 +08:00
Radiance field renderer
Summary: Implements the main NeRF model class that controls the radiance field and its renderer Reviewed By: nikhilaravi Differential Revision: D25684419 fbshipit-source-id: fae45572daa6748c6234bd212f3e68110f778238
This commit is contained in:
parent
bf633ab556
commit
eb908487b8
359
projects/nerf/nerf/nerf_renderer.py
Normal file
359
projects/nerf/nerf/nerf_renderer.py
Normal file
@ -0,0 +1,359 @@
|
|||||||
|
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||||
|
from typing import Tuple, List, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from pytorch3d.renderer import ImplicitRenderer
|
||||||
|
from pytorch3d.renderer.cameras import CamerasBase
|
||||||
|
|
||||||
|
from .implicit_function import NeuralRadianceField
|
||||||
|
from .raymarcher import EmissionAbsorptionNeRFRaymarcher
|
||||||
|
from .raysampler import NeRFRaysampler, ProbabilisticRaysampler
|
||||||
|
from .utils import sample_images_at_mc_locs, calc_psnr, calc_mse
|
||||||
|
|
||||||
|
|
||||||
|
class RadianceFieldRenderer(torch.nn.Module):
|
||||||
|
"""
|
||||||
|
Implements a renderer of a Neural Radiance Field.
|
||||||
|
|
||||||
|
This class holds pointers to the fine and coarse renderer objects, which are
|
||||||
|
instances of `pytorch3d.renderer.ImplicitRenderer`, and pointers to the
|
||||||
|
neural networks representing the fine and coarse Neural Radiance Fields,
|
||||||
|
which are instances of `NeuralRadianceField`.
|
||||||
|
|
||||||
|
The rendering forward pass proceeds as follows:
|
||||||
|
1) For a given input camera, rendering rays are generated with the
|
||||||
|
`NeRFRaysampler` object of `self._renderer['coarse']`.
|
||||||
|
In the training mode (`self.training==True`), the rays are a set
|
||||||
|
of `n_rays_per_image` random 2D locations of the image grid.
|
||||||
|
In the evaluation mode (`self.training==False`), the rays correspond
|
||||||
|
to the full image grid. The rays are further split to
|
||||||
|
`chunk_size_test`-sized chunks to prevent out-of-memory errors.
|
||||||
|
2) For each ray point, the coarse `NeuralRadianceField` MLP is evaluated.
|
||||||
|
The pointer to this MLP is stored in `self._implicit_function['coarse']`
|
||||||
|
3) The coarse radiance field is rendered with the
|
||||||
|
`EmissionAbsorptionNeRFRaymarcher` object of `self._renderer['coarse']`.
|
||||||
|
4) The coarse raymarcher outputs a probability distribution that guides
|
||||||
|
the importance raysampling of the fine rendering pass. The
|
||||||
|
`ProbabilisticRaysampler` stored in `self._renderer['fine'].raysampler`
|
||||||
|
implements the importance ray-sampling.
|
||||||
|
5) Similar to 2) the fine MLP in `self._implicit_function['fine']`
|
||||||
|
labels the ray points with occupancies and colors.
|
||||||
|
6) self._renderer['fine'].raymarcher` generates the final fine render.
|
||||||
|
7) The fine and coarse renders are compared to the ground truth input image
|
||||||
|
with PSNR and MSE metrics.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
image_size: Tuple[int, int],
|
||||||
|
n_pts_per_ray: int,
|
||||||
|
n_pts_per_ray_fine: int,
|
||||||
|
n_rays_per_image: int,
|
||||||
|
min_depth: float,
|
||||||
|
max_depth: float,
|
||||||
|
stratified: bool,
|
||||||
|
stratified_test: bool,
|
||||||
|
chunk_size_test: int,
|
||||||
|
n_harmonic_functions_xyz: int = 6,
|
||||||
|
n_harmonic_functions_dir: int = 4,
|
||||||
|
n_hidden_neurons_xyz: int = 256,
|
||||||
|
n_hidden_neurons_dir: int = 128,
|
||||||
|
n_layers_xyz: int = 8,
|
||||||
|
append_xyz: List[int] = (5,),
|
||||||
|
density_noise_std: float = 0.0,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
image_size: The size of the rendered image (`[height, width]`).
|
||||||
|
n_pts_per_ray: The number of points sampled along each ray for the
|
||||||
|
coarse rendering pass.
|
||||||
|
n_pts_per_ray_fine: The number of points sampled along each ray for the
|
||||||
|
fine rendering pass.
|
||||||
|
n_rays_per_image: Number of Monte Carlo ray samples when training
|
||||||
|
(`self.training==True`).
|
||||||
|
min_depth: The minimum depth of a sampled ray-point for the coarse rendering.
|
||||||
|
max_depth: The maximum depth of a sampled ray-point for the coarse rendering.
|
||||||
|
stratified: If `True`, stratifies (=randomly offsets) the depths
|
||||||
|
of each ray point during training (`self.training==True`).
|
||||||
|
stratified_test: If `True`, stratifies (=randomly offsets) the depths
|
||||||
|
of each ray point during evaluation (`self.training==False`).
|
||||||
|
chunk_size_test: The number of rays in each chunk of image rays.
|
||||||
|
Active only when `self.training==True`.
|
||||||
|
n_harmonic_functions_xyz: The number of harmonic functions
|
||||||
|
used to form the harmonic embedding of 3D point locations.
|
||||||
|
n_harmonic_functions_dir: The number of harmonic functions
|
||||||
|
used to form the harmonic embedding of the ray directions.
|
||||||
|
n_hidden_neurons_xyz: The number of hidden units in the
|
||||||
|
fully connected layers of the MLP that accepts the 3D point
|
||||||
|
locations and outputs the occupancy field with the intermediate
|
||||||
|
features.
|
||||||
|
n_hidden_neurons_dir: The number of hidden units in the
|
||||||
|
fully connected layers of the MLP that accepts the intermediate
|
||||||
|
features and ray directions and outputs the radiance field
|
||||||
|
(per-point colors).
|
||||||
|
n_layers_xyz: The number of layers of the MLP that outputs the
|
||||||
|
occupancy field.
|
||||||
|
append_xyz: The list of indices of the skip layers of the occupancy MLP.
|
||||||
|
Prior to evaluating the skip layers, the tensor which was input to MLP
|
||||||
|
is appended to the skip layer input.
|
||||||
|
density_noise_std: The standard deviation of the random normal noise
|
||||||
|
added to the output of the occupancy MLP.
|
||||||
|
Active only when `self.training==True`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
# The renderers and implicit functions are stored under the fine/coarse
|
||||||
|
# keys in ModuleDict PyTorch modules.
|
||||||
|
self._renderer = torch.nn.ModuleDict()
|
||||||
|
self._implicit_function = torch.nn.ModuleDict()
|
||||||
|
|
||||||
|
# Init the EA raymarcher used by both passes.
|
||||||
|
raymarcher = EmissionAbsorptionNeRFRaymarcher()
|
||||||
|
|
||||||
|
# Parse out image dimensions.
|
||||||
|
image_height, image_width = image_size
|
||||||
|
|
||||||
|
for render_pass in ("coarse", "fine"):
|
||||||
|
if render_pass == "coarse":
|
||||||
|
# Initialize the coarse raysampler.
|
||||||
|
raysampler = NeRFRaysampler(
|
||||||
|
n_pts_per_ray=n_pts_per_ray,
|
||||||
|
min_depth=min_depth,
|
||||||
|
max_depth=max_depth,
|
||||||
|
stratified=stratified,
|
||||||
|
stratified_test=stratified_test,
|
||||||
|
n_rays_per_image=n_rays_per_image,
|
||||||
|
image_height=image_height,
|
||||||
|
image_width=image_width,
|
||||||
|
)
|
||||||
|
elif render_pass == "fine":
|
||||||
|
# Initialize the fine raysampler.
|
||||||
|
raysampler = ProbabilisticRaysampler(
|
||||||
|
n_pts_per_ray=n_pts_per_ray_fine,
|
||||||
|
stratified=stratified,
|
||||||
|
stratified_test=stratified_test,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"No such rendering pass {render_pass}")
|
||||||
|
|
||||||
|
# Initialize the fine/coarse renderer.
|
||||||
|
self._renderer[render_pass] = ImplicitRenderer(
|
||||||
|
raysampler=raysampler,
|
||||||
|
raymarcher=raymarcher,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Instantiate the fine/coarse NeuralRadianceField module.
|
||||||
|
self._implicit_function[render_pass] = NeuralRadianceField(
|
||||||
|
n_harmonic_functions_xyz=n_harmonic_functions_xyz,
|
||||||
|
n_harmonic_functions_dir=n_harmonic_functions_dir,
|
||||||
|
n_hidden_neurons_xyz=n_hidden_neurons_xyz,
|
||||||
|
n_hidden_neurons_dir=n_hidden_neurons_dir,
|
||||||
|
n_layers_xyz=n_layers_xyz,
|
||||||
|
append_xyz=append_xyz,
|
||||||
|
)
|
||||||
|
|
||||||
|
self._density_noise_std = density_noise_std
|
||||||
|
self._chunk_size_test = chunk_size_test
|
||||||
|
self._image_size = image_size
|
||||||
|
|
||||||
|
def precache_rays(
|
||||||
|
self,
|
||||||
|
cache_cameras: List[CamerasBase],
|
||||||
|
cache_camera_hashes: List[str],
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Precaches the rays emitted from the list of cameras `cache_cameras`,
|
||||||
|
where each camera is uniquely identified with the corresponding hash
|
||||||
|
from `cache_camera_hashes`.
|
||||||
|
|
||||||
|
The cached rays are moved to cpu and stored in
|
||||||
|
`self._renderer['coarse']._ray_cache`.
|
||||||
|
|
||||||
|
Raises `ValueError` when caching two cameras with the same hash.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cache_cameras: A list of `N` cameras for which the rays are pre-cached.
|
||||||
|
cache_camera_hashes: A list of `N` unique identifiers for each
|
||||||
|
camera from `cameras`.
|
||||||
|
"""
|
||||||
|
self._renderer["coarse"].raysampler.precache_rays(
|
||||||
|
cache_cameras,
|
||||||
|
cache_camera_hashes,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _process_ray_chunk(
|
||||||
|
self,
|
||||||
|
camera_hash: Optional[str],
|
||||||
|
camera: CamerasBase,
|
||||||
|
image: torch.Tensor,
|
||||||
|
chunk_idx: int,
|
||||||
|
) -> dict:
|
||||||
|
"""
|
||||||
|
Samples and renders a chunk of rays.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
camera_hash: A unique identifier of a pre-cached camera.
|
||||||
|
If `None`, the cache is not searched and the sampled rays are
|
||||||
|
calculated from scratch.
|
||||||
|
camera: A batch of cameras from which the scene is rendered.
|
||||||
|
image: A batch of corresponding ground truth images of shape
|
||||||
|
('batch_size', ·, ·, 3).
|
||||||
|
chunk_idx: The index of the currently rendered ray chunk.
|
||||||
|
Returns:
|
||||||
|
out: `dict` containing the outputs of the rendering:
|
||||||
|
`rgb_coarse`: The result of the coarse rendering pass.
|
||||||
|
`rgb_fine`: The result of the fine rendering pass.
|
||||||
|
`rgb_gt`: The corresponding ground-truth RGB values.
|
||||||
|
"""
|
||||||
|
# Initialize the outputs of the coarse rendering to None.
|
||||||
|
coarse_ray_bundle = None
|
||||||
|
coarse_weights = None
|
||||||
|
|
||||||
|
# First evaluate the coarse rendering pass, then the fine one.
|
||||||
|
for renderer_pass in ("coarse", "fine"):
|
||||||
|
(rgb, weights), ray_bundle_out = self._renderer[renderer_pass](
|
||||||
|
cameras=camera,
|
||||||
|
volumetric_function=self._implicit_function[renderer_pass],
|
||||||
|
chunksize=self._chunk_size_test,
|
||||||
|
chunk_idx=chunk_idx,
|
||||||
|
density_noise_std=(self._density_noise_std if self.training else 0.0),
|
||||||
|
input_ray_bundle=coarse_ray_bundle,
|
||||||
|
ray_weights=coarse_weights,
|
||||||
|
camera_hash=camera_hash,
|
||||||
|
)
|
||||||
|
|
||||||
|
if renderer_pass == "coarse":
|
||||||
|
rgb_coarse = rgb
|
||||||
|
# Store the weights and the rays of the first rendering pass
|
||||||
|
# for the ensuing importance ray-sampling of the fine render.
|
||||||
|
coarse_ray_bundle = ray_bundle_out
|
||||||
|
coarse_weights = weights
|
||||||
|
if image is not None:
|
||||||
|
# Sample the ground truth images at the xy locations of the
|
||||||
|
# rendering ray pixels.
|
||||||
|
rgb_gt = sample_images_at_mc_locs(
|
||||||
|
image[..., :3][None],
|
||||||
|
ray_bundle_out.xys,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
rgb_gt = None
|
||||||
|
|
||||||
|
elif renderer_pass == "fine":
|
||||||
|
rgb_fine = rgb
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise ValueError(f"No such rendering pass {renderer_pass}")
|
||||||
|
|
||||||
|
return {
|
||||||
|
"rgb_fine": rgb_fine,
|
||||||
|
"rgb_coarse": rgb_coarse,
|
||||||
|
"rgb_gt": rgb_gt,
|
||||||
|
# Store the coarse rays/weights only for visualization purposes.
|
||||||
|
"coarse_ray_bundle": type(coarse_ray_bundle)(
|
||||||
|
*[v.detach().cpu() for k, v in coarse_ray_bundle._asdict().items()]
|
||||||
|
),
|
||||||
|
"coarse_weights": coarse_weights.detach().cpu(),
|
||||||
|
}
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
camera_hash: Optional[str],
|
||||||
|
camera: CamerasBase,
|
||||||
|
image: torch.Tensor,
|
||||||
|
) -> Tuple[dict, dict]:
|
||||||
|
"""
|
||||||
|
Performs the coarse and fine rendering passees of the radiance field
|
||||||
|
from the viewpoint of the input `camera`.
|
||||||
|
Afterwards, both renders are compared to the input ground truth `image`
|
||||||
|
by evaluating the peak signal-to-noise ratio and the mean-squared error.
|
||||||
|
|
||||||
|
The rendering result depends on the `self.training` flag:
|
||||||
|
- In the training mode (`self.training==True`), the function renders
|
||||||
|
a random subset of image rays (Monte Carlo rendering).
|
||||||
|
- In evaluation mode (`self.training==False`), the function renders
|
||||||
|
the full image. In order to prevent out-of-memory errors,
|
||||||
|
when `self.training==False`, the rays are sampled and rendered
|
||||||
|
in batches of size `chunksize`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
camera_hash: A unique identifier of a pre-cached camera.
|
||||||
|
If `None`, the cache is not searched and the sampled rays are
|
||||||
|
calculated from scratch.
|
||||||
|
camera: A batch of cameras from which the scene is rendered.
|
||||||
|
image: A batch of corresponding ground truth images of shape
|
||||||
|
('batch_size', ·, ·, 3).
|
||||||
|
Returns:
|
||||||
|
out: `dict` containing the outputs of the rendering:
|
||||||
|
`rgb_coarse`: The result of the coarse rendering pass.
|
||||||
|
`rgb_fine`: The result of the fine rendering pass.
|
||||||
|
`rgb_gt`: The corresponding ground-truth RGB values.
|
||||||
|
|
||||||
|
The shape of `rgb_coarse`, `rgb_fine`, `rgb_gt` depends on the
|
||||||
|
`self.training` flag:
|
||||||
|
If `==True`, all 3 tensors are of shape
|
||||||
|
`(batch_size, n_rays_per_image, 3)` and contain the result
|
||||||
|
of the Monte Carlo training rendering pass.
|
||||||
|
If `==False`, all 3 tensors are of shape
|
||||||
|
`(batch_size, image_size[0], image_size[1], 3)` and contain
|
||||||
|
the result of the full image rendering pass.
|
||||||
|
metrics: `dict` containing the error metrics comparing the fine and
|
||||||
|
coarse renders to the ground truth:
|
||||||
|
`mse_coarse`: Mean-squared error between the coarse render and
|
||||||
|
the input `image`
|
||||||
|
`mse_fine`: Mean-squared error between the fine render and
|
||||||
|
the input `image`
|
||||||
|
`psnr_coarse`: Peak signal-to-noise ratio between the coarse render and
|
||||||
|
the input `image`
|
||||||
|
`psnr_fine`: Peak signal-to-noise ratio between the fine render and
|
||||||
|
the input `image`
|
||||||
|
"""
|
||||||
|
if not self.training:
|
||||||
|
# Full evaluation pass.
|
||||||
|
n_chunks = self._renderer["coarse"].raysampler.get_n_chunks(
|
||||||
|
self._chunk_size_test,
|
||||||
|
camera.R.shape[0],
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# MonteCarlo ray sampling.
|
||||||
|
n_chunks = 1
|
||||||
|
|
||||||
|
# Process the chunks of rays.
|
||||||
|
chunk_outputs = [
|
||||||
|
self._process_ray_chunk(
|
||||||
|
camera_hash,
|
||||||
|
camera,
|
||||||
|
image,
|
||||||
|
chunk_idx,
|
||||||
|
)
|
||||||
|
for chunk_idx in range(n_chunks)
|
||||||
|
]
|
||||||
|
|
||||||
|
if not self.training:
|
||||||
|
# For a full render pass concatenate the output chunks,
|
||||||
|
# and reshape to image size.
|
||||||
|
out = {
|
||||||
|
k: torch.cat(
|
||||||
|
[ch_o[k] for ch_o in chunk_outputs],
|
||||||
|
dim=1,
|
||||||
|
).view(-1, *self._image_size, 3)
|
||||||
|
if chunk_outputs[0][k] is not None
|
||||||
|
else None
|
||||||
|
for k in ("rgb_fine", "rgb_coarse", "rgb_gt")
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
out = chunk_outputs[0]
|
||||||
|
|
||||||
|
# Calc the error metrics.
|
||||||
|
metrics = {}
|
||||||
|
if image is not None:
|
||||||
|
for render_pass in ("coarse", "fine"):
|
||||||
|
for metric_name, metric_fun in zip(
|
||||||
|
("mse", "psnr"), (calc_mse, calc_psnr)
|
||||||
|
):
|
||||||
|
metrics[f"{metric_name}_{render_pass}"] = metric_fun(
|
||||||
|
out["rgb_" + render_pass][..., :3],
|
||||||
|
out["rgb_gt"][..., :3],
|
||||||
|
)
|
||||||
|
|
||||||
|
return out, metrics
|
@ -65,3 +65,55 @@ def sample_pdf(
|
|||||||
samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])
|
samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])
|
||||||
|
|
||||||
return samples
|
return samples
|
||||||
|
|
||||||
|
|
||||||
|
def calc_mse(x: torch.Tensor, y: torch.Tensor):
|
||||||
|
"""
|
||||||
|
Calculates the mean square error between tensors `x` and `y`.
|
||||||
|
"""
|
||||||
|
return torch.mean((x - y) ** 2)
|
||||||
|
|
||||||
|
|
||||||
|
def calc_psnr(x: torch.Tensor, y: torch.Tensor):
|
||||||
|
"""
|
||||||
|
Calculates the Peak-signal-to-noise ratio between tensors `x` and `y`.
|
||||||
|
"""
|
||||||
|
mse = calc_mse(x, y)
|
||||||
|
psnr = -10.0 * torch.log10(mse)
|
||||||
|
return psnr
|
||||||
|
|
||||||
|
|
||||||
|
def sample_images_at_mc_locs(
|
||||||
|
target_images: torch.Tensor,
|
||||||
|
sampled_rays_xy: torch.Tensor,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Given a set of pixel locations `sampled_rays_xy` this method samples the tensor
|
||||||
|
`target_images` at the respective 2D locations.
|
||||||
|
|
||||||
|
This function is used in order to extract the colors from ground truth images
|
||||||
|
that correspond to the colors rendered using a Monte Carlo rendering.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
target_images: A tensor of shape `(batch_size, ..., 3)`.
|
||||||
|
sampled_rays_xy: A tensor of shape `(batch_size, S_1, ..., S_N, 2)`.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
images_sampled: A tensor of shape `(batch_size, S_1, ..., S_N, 3)`
|
||||||
|
containing `target_images` sampled at `sampled_rays_xy`.
|
||||||
|
"""
|
||||||
|
ba = target_images.shape[0]
|
||||||
|
dim = target_images.shape[-1]
|
||||||
|
spatial_size = sampled_rays_xy.shape[1:-1]
|
||||||
|
|
||||||
|
# The coordinate grid convention for grid_sample has both x and y
|
||||||
|
# directions inverted.
|
||||||
|
xy_sample = -sampled_rays_xy.view(ba, -1, 1, 2).clone()
|
||||||
|
|
||||||
|
images_sampled = torch.nn.functional.grid_sample(
|
||||||
|
target_images.permute(0, 3, 1, 2),
|
||||||
|
xy_sample,
|
||||||
|
align_corners=True,
|
||||||
|
mode="bilinear",
|
||||||
|
)
|
||||||
|
return images_sampled.permute(0, 2, 3, 1).view(ba, *spatial_size, dim)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user