mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
Radiance field renderer
Summary: Implements the main NeRF model class that controls the radiance field and its renderer Reviewed By: nikhilaravi Differential Revision: D25684419 fbshipit-source-id: fae45572daa6748c6234bd212f3e68110f778238
This commit is contained in:
parent
bf633ab556
commit
eb908487b8
359
projects/nerf/nerf/nerf_renderer.py
Normal file
359
projects/nerf/nerf/nerf_renderer.py
Normal file
@ -0,0 +1,359 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
from typing import Tuple, List, Optional
|
||||
|
||||
import torch
|
||||
from pytorch3d.renderer import ImplicitRenderer
|
||||
from pytorch3d.renderer.cameras import CamerasBase
|
||||
|
||||
from .implicit_function import NeuralRadianceField
|
||||
from .raymarcher import EmissionAbsorptionNeRFRaymarcher
|
||||
from .raysampler import NeRFRaysampler, ProbabilisticRaysampler
|
||||
from .utils import sample_images_at_mc_locs, calc_psnr, calc_mse
|
||||
|
||||
|
||||
class RadianceFieldRenderer(torch.nn.Module):
|
||||
"""
|
||||
Implements a renderer of a Neural Radiance Field.
|
||||
|
||||
This class holds pointers to the fine and coarse renderer objects, which are
|
||||
instances of `pytorch3d.renderer.ImplicitRenderer`, and pointers to the
|
||||
neural networks representing the fine and coarse Neural Radiance Fields,
|
||||
which are instances of `NeuralRadianceField`.
|
||||
|
||||
The rendering forward pass proceeds as follows:
|
||||
1) For a given input camera, rendering rays are generated with the
|
||||
`NeRFRaysampler` object of `self._renderer['coarse']`.
|
||||
In the training mode (`self.training==True`), the rays are a set
|
||||
of `n_rays_per_image` random 2D locations of the image grid.
|
||||
In the evaluation mode (`self.training==False`), the rays correspond
|
||||
to the full image grid. The rays are further split to
|
||||
`chunk_size_test`-sized chunks to prevent out-of-memory errors.
|
||||
2) For each ray point, the coarse `NeuralRadianceField` MLP is evaluated.
|
||||
The pointer to this MLP is stored in `self._implicit_function['coarse']`
|
||||
3) The coarse radiance field is rendered with the
|
||||
`EmissionAbsorptionNeRFRaymarcher` object of `self._renderer['coarse']`.
|
||||
4) The coarse raymarcher outputs a probability distribution that guides
|
||||
the importance raysampling of the fine rendering pass. The
|
||||
`ProbabilisticRaysampler` stored in `self._renderer['fine'].raysampler`
|
||||
implements the importance ray-sampling.
|
||||
5) Similar to 2) the fine MLP in `self._implicit_function['fine']`
|
||||
labels the ray points with occupancies and colors.
|
||||
6) self._renderer['fine'].raymarcher` generates the final fine render.
|
||||
7) The fine and coarse renders are compared to the ground truth input image
|
||||
with PSNR and MSE metrics.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
image_size: Tuple[int, int],
|
||||
n_pts_per_ray: int,
|
||||
n_pts_per_ray_fine: int,
|
||||
n_rays_per_image: int,
|
||||
min_depth: float,
|
||||
max_depth: float,
|
||||
stratified: bool,
|
||||
stratified_test: bool,
|
||||
chunk_size_test: int,
|
||||
n_harmonic_functions_xyz: int = 6,
|
||||
n_harmonic_functions_dir: int = 4,
|
||||
n_hidden_neurons_xyz: int = 256,
|
||||
n_hidden_neurons_dir: int = 128,
|
||||
n_layers_xyz: int = 8,
|
||||
append_xyz: List[int] = (5,),
|
||||
density_noise_std: float = 0.0,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
image_size: The size of the rendered image (`[height, width]`).
|
||||
n_pts_per_ray: The number of points sampled along each ray for the
|
||||
coarse rendering pass.
|
||||
n_pts_per_ray_fine: The number of points sampled along each ray for the
|
||||
fine rendering pass.
|
||||
n_rays_per_image: Number of Monte Carlo ray samples when training
|
||||
(`self.training==True`).
|
||||
min_depth: The minimum depth of a sampled ray-point for the coarse rendering.
|
||||
max_depth: The maximum depth of a sampled ray-point for the coarse rendering.
|
||||
stratified: If `True`, stratifies (=randomly offsets) the depths
|
||||
of each ray point during training (`self.training==True`).
|
||||
stratified_test: If `True`, stratifies (=randomly offsets) the depths
|
||||
of each ray point during evaluation (`self.training==False`).
|
||||
chunk_size_test: The number of rays in each chunk of image rays.
|
||||
Active only when `self.training==True`.
|
||||
n_harmonic_functions_xyz: The number of harmonic functions
|
||||
used to form the harmonic embedding of 3D point locations.
|
||||
n_harmonic_functions_dir: The number of harmonic functions
|
||||
used to form the harmonic embedding of the ray directions.
|
||||
n_hidden_neurons_xyz: The number of hidden units in the
|
||||
fully connected layers of the MLP that accepts the 3D point
|
||||
locations and outputs the occupancy field with the intermediate
|
||||
features.
|
||||
n_hidden_neurons_dir: The number of hidden units in the
|
||||
fully connected layers of the MLP that accepts the intermediate
|
||||
features and ray directions and outputs the radiance field
|
||||
(per-point colors).
|
||||
n_layers_xyz: The number of layers of the MLP that outputs the
|
||||
occupancy field.
|
||||
append_xyz: The list of indices of the skip layers of the occupancy MLP.
|
||||
Prior to evaluating the skip layers, the tensor which was input to MLP
|
||||
is appended to the skip layer input.
|
||||
density_noise_std: The standard deviation of the random normal noise
|
||||
added to the output of the occupancy MLP.
|
||||
Active only when `self.training==True`.
|
||||
"""
|
||||
|
||||
super().__init__()
|
||||
|
||||
# The renderers and implicit functions are stored under the fine/coarse
|
||||
# keys in ModuleDict PyTorch modules.
|
||||
self._renderer = torch.nn.ModuleDict()
|
||||
self._implicit_function = torch.nn.ModuleDict()
|
||||
|
||||
# Init the EA raymarcher used by both passes.
|
||||
raymarcher = EmissionAbsorptionNeRFRaymarcher()
|
||||
|
||||
# Parse out image dimensions.
|
||||
image_height, image_width = image_size
|
||||
|
||||
for render_pass in ("coarse", "fine"):
|
||||
if render_pass == "coarse":
|
||||
# Initialize the coarse raysampler.
|
||||
raysampler = NeRFRaysampler(
|
||||
n_pts_per_ray=n_pts_per_ray,
|
||||
min_depth=min_depth,
|
||||
max_depth=max_depth,
|
||||
stratified=stratified,
|
||||
stratified_test=stratified_test,
|
||||
n_rays_per_image=n_rays_per_image,
|
||||
image_height=image_height,
|
||||
image_width=image_width,
|
||||
)
|
||||
elif render_pass == "fine":
|
||||
# Initialize the fine raysampler.
|
||||
raysampler = ProbabilisticRaysampler(
|
||||
n_pts_per_ray=n_pts_per_ray_fine,
|
||||
stratified=stratified,
|
||||
stratified_test=stratified_test,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"No such rendering pass {render_pass}")
|
||||
|
||||
# Initialize the fine/coarse renderer.
|
||||
self._renderer[render_pass] = ImplicitRenderer(
|
||||
raysampler=raysampler,
|
||||
raymarcher=raymarcher,
|
||||
)
|
||||
|
||||
# Instantiate the fine/coarse NeuralRadianceField module.
|
||||
self._implicit_function[render_pass] = NeuralRadianceField(
|
||||
n_harmonic_functions_xyz=n_harmonic_functions_xyz,
|
||||
n_harmonic_functions_dir=n_harmonic_functions_dir,
|
||||
n_hidden_neurons_xyz=n_hidden_neurons_xyz,
|
||||
n_hidden_neurons_dir=n_hidden_neurons_dir,
|
||||
n_layers_xyz=n_layers_xyz,
|
||||
append_xyz=append_xyz,
|
||||
)
|
||||
|
||||
self._density_noise_std = density_noise_std
|
||||
self._chunk_size_test = chunk_size_test
|
||||
self._image_size = image_size
|
||||
|
||||
def precache_rays(
|
||||
self,
|
||||
cache_cameras: List[CamerasBase],
|
||||
cache_camera_hashes: List[str],
|
||||
):
|
||||
"""
|
||||
Precaches the rays emitted from the list of cameras `cache_cameras`,
|
||||
where each camera is uniquely identified with the corresponding hash
|
||||
from `cache_camera_hashes`.
|
||||
|
||||
The cached rays are moved to cpu and stored in
|
||||
`self._renderer['coarse']._ray_cache`.
|
||||
|
||||
Raises `ValueError` when caching two cameras with the same hash.
|
||||
|
||||
Args:
|
||||
cache_cameras: A list of `N` cameras for which the rays are pre-cached.
|
||||
cache_camera_hashes: A list of `N` unique identifiers for each
|
||||
camera from `cameras`.
|
||||
"""
|
||||
self._renderer["coarse"].raysampler.precache_rays(
|
||||
cache_cameras,
|
||||
cache_camera_hashes,
|
||||
)
|
||||
|
||||
def _process_ray_chunk(
|
||||
self,
|
||||
camera_hash: Optional[str],
|
||||
camera: CamerasBase,
|
||||
image: torch.Tensor,
|
||||
chunk_idx: int,
|
||||
) -> dict:
|
||||
"""
|
||||
Samples and renders a chunk of rays.
|
||||
|
||||
Args:
|
||||
camera_hash: A unique identifier of a pre-cached camera.
|
||||
If `None`, the cache is not searched and the sampled rays are
|
||||
calculated from scratch.
|
||||
camera: A batch of cameras from which the scene is rendered.
|
||||
image: A batch of corresponding ground truth images of shape
|
||||
('batch_size', ·, ·, 3).
|
||||
chunk_idx: The index of the currently rendered ray chunk.
|
||||
Returns:
|
||||
out: `dict` containing the outputs of the rendering:
|
||||
`rgb_coarse`: The result of the coarse rendering pass.
|
||||
`rgb_fine`: The result of the fine rendering pass.
|
||||
`rgb_gt`: The corresponding ground-truth RGB values.
|
||||
"""
|
||||
# Initialize the outputs of the coarse rendering to None.
|
||||
coarse_ray_bundle = None
|
||||
coarse_weights = None
|
||||
|
||||
# First evaluate the coarse rendering pass, then the fine one.
|
||||
for renderer_pass in ("coarse", "fine"):
|
||||
(rgb, weights), ray_bundle_out = self._renderer[renderer_pass](
|
||||
cameras=camera,
|
||||
volumetric_function=self._implicit_function[renderer_pass],
|
||||
chunksize=self._chunk_size_test,
|
||||
chunk_idx=chunk_idx,
|
||||
density_noise_std=(self._density_noise_std if self.training else 0.0),
|
||||
input_ray_bundle=coarse_ray_bundle,
|
||||
ray_weights=coarse_weights,
|
||||
camera_hash=camera_hash,
|
||||
)
|
||||
|
||||
if renderer_pass == "coarse":
|
||||
rgb_coarse = rgb
|
||||
# Store the weights and the rays of the first rendering pass
|
||||
# for the ensuing importance ray-sampling of the fine render.
|
||||
coarse_ray_bundle = ray_bundle_out
|
||||
coarse_weights = weights
|
||||
if image is not None:
|
||||
# Sample the ground truth images at the xy locations of the
|
||||
# rendering ray pixels.
|
||||
rgb_gt = sample_images_at_mc_locs(
|
||||
image[..., :3][None],
|
||||
ray_bundle_out.xys,
|
||||
)
|
||||
else:
|
||||
rgb_gt = None
|
||||
|
||||
elif renderer_pass == "fine":
|
||||
rgb_fine = rgb
|
||||
|
||||
else:
|
||||
raise ValueError(f"No such rendering pass {renderer_pass}")
|
||||
|
||||
return {
|
||||
"rgb_fine": rgb_fine,
|
||||
"rgb_coarse": rgb_coarse,
|
||||
"rgb_gt": rgb_gt,
|
||||
# Store the coarse rays/weights only for visualization purposes.
|
||||
"coarse_ray_bundle": type(coarse_ray_bundle)(
|
||||
*[v.detach().cpu() for k, v in coarse_ray_bundle._asdict().items()]
|
||||
),
|
||||
"coarse_weights": coarse_weights.detach().cpu(),
|
||||
}
|
||||
|
||||
def forward(
|
||||
self,
|
||||
camera_hash: Optional[str],
|
||||
camera: CamerasBase,
|
||||
image: torch.Tensor,
|
||||
) -> Tuple[dict, dict]:
|
||||
"""
|
||||
Performs the coarse and fine rendering passees of the radiance field
|
||||
from the viewpoint of the input `camera`.
|
||||
Afterwards, both renders are compared to the input ground truth `image`
|
||||
by evaluating the peak signal-to-noise ratio and the mean-squared error.
|
||||
|
||||
The rendering result depends on the `self.training` flag:
|
||||
- In the training mode (`self.training==True`), the function renders
|
||||
a random subset of image rays (Monte Carlo rendering).
|
||||
- In evaluation mode (`self.training==False`), the function renders
|
||||
the full image. In order to prevent out-of-memory errors,
|
||||
when `self.training==False`, the rays are sampled and rendered
|
||||
in batches of size `chunksize`.
|
||||
|
||||
Args:
|
||||
camera_hash: A unique identifier of a pre-cached camera.
|
||||
If `None`, the cache is not searched and the sampled rays are
|
||||
calculated from scratch.
|
||||
camera: A batch of cameras from which the scene is rendered.
|
||||
image: A batch of corresponding ground truth images of shape
|
||||
('batch_size', ·, ·, 3).
|
||||
Returns:
|
||||
out: `dict` containing the outputs of the rendering:
|
||||
`rgb_coarse`: The result of the coarse rendering pass.
|
||||
`rgb_fine`: The result of the fine rendering pass.
|
||||
`rgb_gt`: The corresponding ground-truth RGB values.
|
||||
|
||||
The shape of `rgb_coarse`, `rgb_fine`, `rgb_gt` depends on the
|
||||
`self.training` flag:
|
||||
If `==True`, all 3 tensors are of shape
|
||||
`(batch_size, n_rays_per_image, 3)` and contain the result
|
||||
of the Monte Carlo training rendering pass.
|
||||
If `==False`, all 3 tensors are of shape
|
||||
`(batch_size, image_size[0], image_size[1], 3)` and contain
|
||||
the result of the full image rendering pass.
|
||||
metrics: `dict` containing the error metrics comparing the fine and
|
||||
coarse renders to the ground truth:
|
||||
`mse_coarse`: Mean-squared error between the coarse render and
|
||||
the input `image`
|
||||
`mse_fine`: Mean-squared error between the fine render and
|
||||
the input `image`
|
||||
`psnr_coarse`: Peak signal-to-noise ratio between the coarse render and
|
||||
the input `image`
|
||||
`psnr_fine`: Peak signal-to-noise ratio between the fine render and
|
||||
the input `image`
|
||||
"""
|
||||
if not self.training:
|
||||
# Full evaluation pass.
|
||||
n_chunks = self._renderer["coarse"].raysampler.get_n_chunks(
|
||||
self._chunk_size_test,
|
||||
camera.R.shape[0],
|
||||
)
|
||||
else:
|
||||
# MonteCarlo ray sampling.
|
||||
n_chunks = 1
|
||||
|
||||
# Process the chunks of rays.
|
||||
chunk_outputs = [
|
||||
self._process_ray_chunk(
|
||||
camera_hash,
|
||||
camera,
|
||||
image,
|
||||
chunk_idx,
|
||||
)
|
||||
for chunk_idx in range(n_chunks)
|
||||
]
|
||||
|
||||
if not self.training:
|
||||
# For a full render pass concatenate the output chunks,
|
||||
# and reshape to image size.
|
||||
out = {
|
||||
k: torch.cat(
|
||||
[ch_o[k] for ch_o in chunk_outputs],
|
||||
dim=1,
|
||||
).view(-1, *self._image_size, 3)
|
||||
if chunk_outputs[0][k] is not None
|
||||
else None
|
||||
for k in ("rgb_fine", "rgb_coarse", "rgb_gt")
|
||||
}
|
||||
else:
|
||||
out = chunk_outputs[0]
|
||||
|
||||
# Calc the error metrics.
|
||||
metrics = {}
|
||||
if image is not None:
|
||||
for render_pass in ("coarse", "fine"):
|
||||
for metric_name, metric_fun in zip(
|
||||
("mse", "psnr"), (calc_mse, calc_psnr)
|
||||
):
|
||||
metrics[f"{metric_name}_{render_pass}"] = metric_fun(
|
||||
out["rgb_" + render_pass][..., :3],
|
||||
out["rgb_gt"][..., :3],
|
||||
)
|
||||
|
||||
return out, metrics
|
@ -65,3 +65,55 @@ def sample_pdf(
|
||||
samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])
|
||||
|
||||
return samples
|
||||
|
||||
|
||||
def calc_mse(x: torch.Tensor, y: torch.Tensor):
|
||||
"""
|
||||
Calculates the mean square error between tensors `x` and `y`.
|
||||
"""
|
||||
return torch.mean((x - y) ** 2)
|
||||
|
||||
|
||||
def calc_psnr(x: torch.Tensor, y: torch.Tensor):
|
||||
"""
|
||||
Calculates the Peak-signal-to-noise ratio between tensors `x` and `y`.
|
||||
"""
|
||||
mse = calc_mse(x, y)
|
||||
psnr = -10.0 * torch.log10(mse)
|
||||
return psnr
|
||||
|
||||
|
||||
def sample_images_at_mc_locs(
|
||||
target_images: torch.Tensor,
|
||||
sampled_rays_xy: torch.Tensor,
|
||||
):
|
||||
"""
|
||||
Given a set of pixel locations `sampled_rays_xy` this method samples the tensor
|
||||
`target_images` at the respective 2D locations.
|
||||
|
||||
This function is used in order to extract the colors from ground truth images
|
||||
that correspond to the colors rendered using a Monte Carlo rendering.
|
||||
|
||||
Args:
|
||||
target_images: A tensor of shape `(batch_size, ..., 3)`.
|
||||
sampled_rays_xy: A tensor of shape `(batch_size, S_1, ..., S_N, 2)`.
|
||||
|
||||
Returns:
|
||||
images_sampled: A tensor of shape `(batch_size, S_1, ..., S_N, 3)`
|
||||
containing `target_images` sampled at `sampled_rays_xy`.
|
||||
"""
|
||||
ba = target_images.shape[0]
|
||||
dim = target_images.shape[-1]
|
||||
spatial_size = sampled_rays_xy.shape[1:-1]
|
||||
|
||||
# The coordinate grid convention for grid_sample has both x and y
|
||||
# directions inverted.
|
||||
xy_sample = -sampled_rays_xy.view(ba, -1, 1, 2).clone()
|
||||
|
||||
images_sampled = torch.nn.functional.grid_sample(
|
||||
target_images.permute(0, 3, 1, 2),
|
||||
xy_sample,
|
||||
align_corners=True,
|
||||
mode="bilinear",
|
||||
)
|
||||
return images_sampled.permute(0, 2, 3, 1).view(ba, *spatial_size, dim)
|
||||
|
Loading…
x
Reference in New Issue
Block a user