From eb908487b8df2e52ab41fa8cabc72c2767fdfb8c Mon Sep 17 00:00:00 2001 From: David Novotny Date: Tue, 2 Feb 2021 05:42:59 -0800 Subject: [PATCH] Radiance field renderer Summary: Implements the main NeRF model class that controls the radiance field and its renderer Reviewed By: nikhilaravi Differential Revision: D25684419 fbshipit-source-id: fae45572daa6748c6234bd212f3e68110f778238 --- projects/nerf/nerf/nerf_renderer.py | 359 ++++++++++++++++++++++++++++ projects/nerf/nerf/utils.py | 52 ++++ 2 files changed, 411 insertions(+) create mode 100644 projects/nerf/nerf/nerf_renderer.py diff --git a/projects/nerf/nerf/nerf_renderer.py b/projects/nerf/nerf/nerf_renderer.py new file mode 100644 index 00000000..57ec5db6 --- /dev/null +++ b/projects/nerf/nerf/nerf_renderer.py @@ -0,0 +1,359 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +from typing import Tuple, List, Optional + +import torch +from pytorch3d.renderer import ImplicitRenderer +from pytorch3d.renderer.cameras import CamerasBase + +from .implicit_function import NeuralRadianceField +from .raymarcher import EmissionAbsorptionNeRFRaymarcher +from .raysampler import NeRFRaysampler, ProbabilisticRaysampler +from .utils import sample_images_at_mc_locs, calc_psnr, calc_mse + + +class RadianceFieldRenderer(torch.nn.Module): + """ + Implements a renderer of a Neural Radiance Field. + + This class holds pointers to the fine and coarse renderer objects, which are + instances of `pytorch3d.renderer.ImplicitRenderer`, and pointers to the + neural networks representing the fine and coarse Neural Radiance Fields, + which are instances of `NeuralRadianceField`. + + The rendering forward pass proceeds as follows: + 1) For a given input camera, rendering rays are generated with the + `NeRFRaysampler` object of `self._renderer['coarse']`. + In the training mode (`self.training==True`), the rays are a set + of `n_rays_per_image` random 2D locations of the image grid. + In the evaluation mode (`self.training==False`), the rays correspond + to the full image grid. The rays are further split to + `chunk_size_test`-sized chunks to prevent out-of-memory errors. + 2) For each ray point, the coarse `NeuralRadianceField` MLP is evaluated. + The pointer to this MLP is stored in `self._implicit_function['coarse']` + 3) The coarse radiance field is rendered with the + `EmissionAbsorptionNeRFRaymarcher` object of `self._renderer['coarse']`. + 4) The coarse raymarcher outputs a probability distribution that guides + the importance raysampling of the fine rendering pass. The + `ProbabilisticRaysampler` stored in `self._renderer['fine'].raysampler` + implements the importance ray-sampling. + 5) Similar to 2) the fine MLP in `self._implicit_function['fine']` + labels the ray points with occupancies and colors. + 6) self._renderer['fine'].raymarcher` generates the final fine render. + 7) The fine and coarse renders are compared to the ground truth input image + with PSNR and MSE metrics. + """ + + def __init__( + self, + image_size: Tuple[int, int], + n_pts_per_ray: int, + n_pts_per_ray_fine: int, + n_rays_per_image: int, + min_depth: float, + max_depth: float, + stratified: bool, + stratified_test: bool, + chunk_size_test: int, + n_harmonic_functions_xyz: int = 6, + n_harmonic_functions_dir: int = 4, + n_hidden_neurons_xyz: int = 256, + n_hidden_neurons_dir: int = 128, + n_layers_xyz: int = 8, + append_xyz: List[int] = (5,), + density_noise_std: float = 0.0, + ): + """ + Args: + image_size: The size of the rendered image (`[height, width]`). + n_pts_per_ray: The number of points sampled along each ray for the + coarse rendering pass. + n_pts_per_ray_fine: The number of points sampled along each ray for the + fine rendering pass. + n_rays_per_image: Number of Monte Carlo ray samples when training + (`self.training==True`). + min_depth: The minimum depth of a sampled ray-point for the coarse rendering. + max_depth: The maximum depth of a sampled ray-point for the coarse rendering. + stratified: If `True`, stratifies (=randomly offsets) the depths + of each ray point during training (`self.training==True`). + stratified_test: If `True`, stratifies (=randomly offsets) the depths + of each ray point during evaluation (`self.training==False`). + chunk_size_test: The number of rays in each chunk of image rays. + Active only when `self.training==True`. + n_harmonic_functions_xyz: The number of harmonic functions + used to form the harmonic embedding of 3D point locations. + n_harmonic_functions_dir: The number of harmonic functions + used to form the harmonic embedding of the ray directions. + n_hidden_neurons_xyz: The number of hidden units in the + fully connected layers of the MLP that accepts the 3D point + locations and outputs the occupancy field with the intermediate + features. + n_hidden_neurons_dir: The number of hidden units in the + fully connected layers of the MLP that accepts the intermediate + features and ray directions and outputs the radiance field + (per-point colors). + n_layers_xyz: The number of layers of the MLP that outputs the + occupancy field. + append_xyz: The list of indices of the skip layers of the occupancy MLP. + Prior to evaluating the skip layers, the tensor which was input to MLP + is appended to the skip layer input. + density_noise_std: The standard deviation of the random normal noise + added to the output of the occupancy MLP. + Active only when `self.training==True`. + """ + + super().__init__() + + # The renderers and implicit functions are stored under the fine/coarse + # keys in ModuleDict PyTorch modules. + self._renderer = torch.nn.ModuleDict() + self._implicit_function = torch.nn.ModuleDict() + + # Init the EA raymarcher used by both passes. + raymarcher = EmissionAbsorptionNeRFRaymarcher() + + # Parse out image dimensions. + image_height, image_width = image_size + + for render_pass in ("coarse", "fine"): + if render_pass == "coarse": + # Initialize the coarse raysampler. + raysampler = NeRFRaysampler( + n_pts_per_ray=n_pts_per_ray, + min_depth=min_depth, + max_depth=max_depth, + stratified=stratified, + stratified_test=stratified_test, + n_rays_per_image=n_rays_per_image, + image_height=image_height, + image_width=image_width, + ) + elif render_pass == "fine": + # Initialize the fine raysampler. + raysampler = ProbabilisticRaysampler( + n_pts_per_ray=n_pts_per_ray_fine, + stratified=stratified, + stratified_test=stratified_test, + ) + else: + raise ValueError(f"No such rendering pass {render_pass}") + + # Initialize the fine/coarse renderer. + self._renderer[render_pass] = ImplicitRenderer( + raysampler=raysampler, + raymarcher=raymarcher, + ) + + # Instantiate the fine/coarse NeuralRadianceField module. + self._implicit_function[render_pass] = NeuralRadianceField( + n_harmonic_functions_xyz=n_harmonic_functions_xyz, + n_harmonic_functions_dir=n_harmonic_functions_dir, + n_hidden_neurons_xyz=n_hidden_neurons_xyz, + n_hidden_neurons_dir=n_hidden_neurons_dir, + n_layers_xyz=n_layers_xyz, + append_xyz=append_xyz, + ) + + self._density_noise_std = density_noise_std + self._chunk_size_test = chunk_size_test + self._image_size = image_size + + def precache_rays( + self, + cache_cameras: List[CamerasBase], + cache_camera_hashes: List[str], + ): + """ + Precaches the rays emitted from the list of cameras `cache_cameras`, + where each camera is uniquely identified with the corresponding hash + from `cache_camera_hashes`. + + The cached rays are moved to cpu and stored in + `self._renderer['coarse']._ray_cache`. + + Raises `ValueError` when caching two cameras with the same hash. + + Args: + cache_cameras: A list of `N` cameras for which the rays are pre-cached. + cache_camera_hashes: A list of `N` unique identifiers for each + camera from `cameras`. + """ + self._renderer["coarse"].raysampler.precache_rays( + cache_cameras, + cache_camera_hashes, + ) + + def _process_ray_chunk( + self, + camera_hash: Optional[str], + camera: CamerasBase, + image: torch.Tensor, + chunk_idx: int, + ) -> dict: + """ + Samples and renders a chunk of rays. + + Args: + camera_hash: A unique identifier of a pre-cached camera. + If `None`, the cache is not searched and the sampled rays are + calculated from scratch. + camera: A batch of cameras from which the scene is rendered. + image: A batch of corresponding ground truth images of shape + ('batch_size', ·, ·, 3). + chunk_idx: The index of the currently rendered ray chunk. + Returns: + out: `dict` containing the outputs of the rendering: + `rgb_coarse`: The result of the coarse rendering pass. + `rgb_fine`: The result of the fine rendering pass. + `rgb_gt`: The corresponding ground-truth RGB values. + """ + # Initialize the outputs of the coarse rendering to None. + coarse_ray_bundle = None + coarse_weights = None + + # First evaluate the coarse rendering pass, then the fine one. + for renderer_pass in ("coarse", "fine"): + (rgb, weights), ray_bundle_out = self._renderer[renderer_pass]( + cameras=camera, + volumetric_function=self._implicit_function[renderer_pass], + chunksize=self._chunk_size_test, + chunk_idx=chunk_idx, + density_noise_std=(self._density_noise_std if self.training else 0.0), + input_ray_bundle=coarse_ray_bundle, + ray_weights=coarse_weights, + camera_hash=camera_hash, + ) + + if renderer_pass == "coarse": + rgb_coarse = rgb + # Store the weights and the rays of the first rendering pass + # for the ensuing importance ray-sampling of the fine render. + coarse_ray_bundle = ray_bundle_out + coarse_weights = weights + if image is not None: + # Sample the ground truth images at the xy locations of the + # rendering ray pixels. + rgb_gt = sample_images_at_mc_locs( + image[..., :3][None], + ray_bundle_out.xys, + ) + else: + rgb_gt = None + + elif renderer_pass == "fine": + rgb_fine = rgb + + else: + raise ValueError(f"No such rendering pass {renderer_pass}") + + return { + "rgb_fine": rgb_fine, + "rgb_coarse": rgb_coarse, + "rgb_gt": rgb_gt, + # Store the coarse rays/weights only for visualization purposes. + "coarse_ray_bundle": type(coarse_ray_bundle)( + *[v.detach().cpu() for k, v in coarse_ray_bundle._asdict().items()] + ), + "coarse_weights": coarse_weights.detach().cpu(), + } + + def forward( + self, + camera_hash: Optional[str], + camera: CamerasBase, + image: torch.Tensor, + ) -> Tuple[dict, dict]: + """ + Performs the coarse and fine rendering passees of the radiance field + from the viewpoint of the input `camera`. + Afterwards, both renders are compared to the input ground truth `image` + by evaluating the peak signal-to-noise ratio and the mean-squared error. + + The rendering result depends on the `self.training` flag: + - In the training mode (`self.training==True`), the function renders + a random subset of image rays (Monte Carlo rendering). + - In evaluation mode (`self.training==False`), the function renders + the full image. In order to prevent out-of-memory errors, + when `self.training==False`, the rays are sampled and rendered + in batches of size `chunksize`. + + Args: + camera_hash: A unique identifier of a pre-cached camera. + If `None`, the cache is not searched and the sampled rays are + calculated from scratch. + camera: A batch of cameras from which the scene is rendered. + image: A batch of corresponding ground truth images of shape + ('batch_size', ·, ·, 3). + Returns: + out: `dict` containing the outputs of the rendering: + `rgb_coarse`: The result of the coarse rendering pass. + `rgb_fine`: The result of the fine rendering pass. + `rgb_gt`: The corresponding ground-truth RGB values. + + The shape of `rgb_coarse`, `rgb_fine`, `rgb_gt` depends on the + `self.training` flag: + If `==True`, all 3 tensors are of shape + `(batch_size, n_rays_per_image, 3)` and contain the result + of the Monte Carlo training rendering pass. + If `==False`, all 3 tensors are of shape + `(batch_size, image_size[0], image_size[1], 3)` and contain + the result of the full image rendering pass. + metrics: `dict` containing the error metrics comparing the fine and + coarse renders to the ground truth: + `mse_coarse`: Mean-squared error between the coarse render and + the input `image` + `mse_fine`: Mean-squared error between the fine render and + the input `image` + `psnr_coarse`: Peak signal-to-noise ratio between the coarse render and + the input `image` + `psnr_fine`: Peak signal-to-noise ratio between the fine render and + the input `image` + """ + if not self.training: + # Full evaluation pass. + n_chunks = self._renderer["coarse"].raysampler.get_n_chunks( + self._chunk_size_test, + camera.R.shape[0], + ) + else: + # MonteCarlo ray sampling. + n_chunks = 1 + + # Process the chunks of rays. + chunk_outputs = [ + self._process_ray_chunk( + camera_hash, + camera, + image, + chunk_idx, + ) + for chunk_idx in range(n_chunks) + ] + + if not self.training: + # For a full render pass concatenate the output chunks, + # and reshape to image size. + out = { + k: torch.cat( + [ch_o[k] for ch_o in chunk_outputs], + dim=1, + ).view(-1, *self._image_size, 3) + if chunk_outputs[0][k] is not None + else None + for k in ("rgb_fine", "rgb_coarse", "rgb_gt") + } + else: + out = chunk_outputs[0] + + # Calc the error metrics. + metrics = {} + if image is not None: + for render_pass in ("coarse", "fine"): + for metric_name, metric_fun in zip( + ("mse", "psnr"), (calc_mse, calc_psnr) + ): + metrics[f"{metric_name}_{render_pass}"] = metric_fun( + out["rgb_" + render_pass][..., :3], + out["rgb_gt"][..., :3], + ) + + return out, metrics diff --git a/projects/nerf/nerf/utils.py b/projects/nerf/nerf/utils.py index f50464ef..ca2897a5 100644 --- a/projects/nerf/nerf/utils.py +++ b/projects/nerf/nerf/utils.py @@ -65,3 +65,55 @@ def sample_pdf( samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0]) return samples + + +def calc_mse(x: torch.Tensor, y: torch.Tensor): + """ + Calculates the mean square error between tensors `x` and `y`. + """ + return torch.mean((x - y) ** 2) + + +def calc_psnr(x: torch.Tensor, y: torch.Tensor): + """ + Calculates the Peak-signal-to-noise ratio between tensors `x` and `y`. + """ + mse = calc_mse(x, y) + psnr = -10.0 * torch.log10(mse) + return psnr + + +def sample_images_at_mc_locs( + target_images: torch.Tensor, + sampled_rays_xy: torch.Tensor, +): + """ + Given a set of pixel locations `sampled_rays_xy` this method samples the tensor + `target_images` at the respective 2D locations. + + This function is used in order to extract the colors from ground truth images + that correspond to the colors rendered using a Monte Carlo rendering. + + Args: + target_images: A tensor of shape `(batch_size, ..., 3)`. + sampled_rays_xy: A tensor of shape `(batch_size, S_1, ..., S_N, 2)`. + + Returns: + images_sampled: A tensor of shape `(batch_size, S_1, ..., S_N, 3)` + containing `target_images` sampled at `sampled_rays_xy`. + """ + ba = target_images.shape[0] + dim = target_images.shape[-1] + spatial_size = sampled_rays_xy.shape[1:-1] + + # The coordinate grid convention for grid_sample has both x and y + # directions inverted. + xy_sample = -sampled_rays_xy.view(ba, -1, 1, 2).clone() + + images_sampled = torch.nn.functional.grid_sample( + target_images.permute(0, 3, 1, 2), + xy_sample, + align_corners=True, + mode="bilinear", + ) + return images_sampled.permute(0, 2, 3, 1).view(ba, *spatial_size, dim)