From 7cbda3ec174f52831ffc4a5677afecc69b446191 Mon Sep 17 00:00:00 2001
From: David Novotny
Date: Tue, 2 Feb 2021 05:42:59 -0800
Subject: [PATCH] NeRF Raysampler

Summary: Implements the NeRF raysampler.

Reviewed By: nikhilaravi

Differential Revision: D25684403

fbshipit-source-id: 616a60f047c79479f60a6a75d214f87cbfb06d28
---
 projects/nerf/nerf/raysampler.py       | 364 +++++++++++++++++++++++++
 projects/nerf/nerf/utils.py            |  67 +++++
 projects/nerf/tests/test_raysampler.py | 122 +++++++++
 3 files changed, 553 insertions(+)
 create mode 100644 projects/nerf/nerf/raysampler.py
 create mode 100644 projects/nerf/nerf/utils.py
 create mode 100644 projects/nerf/tests/test_raysampler.py

diff --git a/projects/nerf/nerf/raysampler.py b/projects/nerf/nerf/raysampler.py
new file mode 100644
index 00000000..0d32d574
--- /dev/null
+++ b/projects/nerf/nerf/raysampler.py
@@ -0,0 +1,364 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+import math
+from typing import List
+
+import torch
+from pytorch3d.renderer import RayBundle, NDCGridRaysampler, MonteCarloRaysampler
+from pytorch3d.renderer.cameras import CamerasBase
+
+from .utils import sample_pdf
+
+
+class ProbabilisticRaysampler(torch.nn.Module):
+    """
+    Implements the importance sampling of points along rays.
+    The input is a `RayBundle` object with a `ray_weights` tensor
+    which specifies the probabilities of sampling a point along each ray.
+
+    This raysampler is used for the fine rendering pass of NeRF.
+    As such, the forward pass accepts the `RayBundle` output by the
+    raysampling of the coarse rendering pass. Hence, it does not
+    take cameras as input.
+    """
+
+    def __init__(
+        self,
+        n_pts_per_ray: int,
+        stratified: bool,
+        stratified_test: bool,
+        add_input_samples: bool = True,
+    ):
+        """
+        Args:
+            n_pts_per_ray: The number of points to sample along each ray.
+            stratified: If `True`, the input `ray_weights` are assumed to be
+                sampled at equidistant intervals.
+            stratified_test: Same as `stratified`, with the difference that
+                this setting is applied when the module is in the `eval` mode
+                (`self.training==False`).
+            add_input_samples: If `True`, the sampled values are concatenated
+                with the input samples before being returned.
+        """
+        super().__init__()
+        self._n_pts_per_ray = n_pts_per_ray
+        self._stratified = stratified
+        self._stratified_test = stratified_test
+        self._add_input_samples = add_input_samples
+
+    def forward(
+        self,
+        input_ray_bundle: RayBundle,
+        ray_weights: torch.Tensor,
+        **kwargs,
+    ) -> RayBundle:
+        """
+        Args:
+            input_ray_bundle: An instance of `RayBundle` specifying the
+                source rays for sampling of the probability distribution.
+            ray_weights: A tensor of shape
+                `(..., input_ray_bundle.lengths.shape[-1])` with non-negative
+                elements defining the probability distribution to sample
+                ray points from.
+
+        Returns:
+            ray_bundle: A new `RayBundle` instance containing the input ray
+                points together with `n_pts_per_ray` additional sampled
+                points per ray.
+        """
+
+        # Calculate the mid-points between the ray depths.
+        z_vals = input_ray_bundle.lengths
+        batch_size = z_vals.shape[0]
+        z_vals_mid = 0.5 * (z_vals[..., 1:] + z_vals[..., :-1])
+
+        # Carry out the importance sampling.
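+        # `sample_pdf` treats the mid-point depths `z_vals_mid` as the edges
+        # of a piecewise-constant pdf whose unnormalized per-bin probabilities
+        # are the interior ray weights, and draws `n_pts_per_ray` new depths
+        # per ray by inverting the resulting cdf. The draws are deterministic
+        # (equispaced in cdf space) unless stratified sampling is enabled for
+        # the current train/eval mode.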
+        z_samples = (
+            sample_pdf(
+                z_vals_mid.view(-1, z_vals_mid.shape[-1]),
+                ray_weights.view(-1, ray_weights.shape[-1])[..., 1:-1],
+                self._n_pts_per_ray,
+                det=not (
+                    (self._stratified and self.training)
+                    or (self._stratified_test and not self.training)
+                ),
+            )
+            .detach()
+            .view(batch_size, z_vals.shape[1], self._n_pts_per_ray)
+        )
+
+        if self._add_input_samples:
+            # Add the new samples to the input ones.
+            z_vals = torch.cat((z_vals, z_samples), dim=-1)
+        else:
+            z_vals = z_samples
+        # Resort by depth.
+        z_vals, _ = torch.sort(z_vals, dim=-1)
+
+        return RayBundle(
+            origins=input_ray_bundle.origins,
+            directions=input_ray_bundle.directions,
+            lengths=z_vals,
+            xys=input_ray_bundle.xys,
+        )
+
+
+class NeRFRaysampler(torch.nn.Module):
+    """
+    Implements the raysampler of NeRF.
+
+    Depending on the `self.training` flag, the raysampler either samples
+    a chunk of random rays (`self.training==True`), or returns a subset of rays
+    of the full image grid (`self.training==False`).
+    The chunking of rays allows for efficient evaluation of the NeRF implicit
+    surface function without encountering out-of-GPU-memory errors.
+
+    Additionally, this raysampler supports pre-caching of the ray bundles
+    for a set of input cameras (`self.precache_rays`).
+    Pre-caching the rays before training greatly speeds up the ensuing
+    raysampling step of the NeRF training iterations.
+    """
+
+    def __init__(
+        self,
+        n_pts_per_ray: int,
+        min_depth: float,
+        max_depth: float,
+        n_rays_per_image: int,
+        image_width: int,
+        image_height: int,
+        stratified: bool = False,
+        stratified_test: bool = False,
+    ):
+        """
+        Args:
+            n_pts_per_ray: The number of points sampled along each ray.
+            min_depth: The minimum depth of a ray-point.
+            max_depth: The maximum depth of a ray-point.
+            n_rays_per_image: Number of Monte Carlo ray samples when training
+                (`self.training==True`).
+            image_width: The horizontal size of the image grid.
+            image_height: The vertical size of the image grid.
+            stratified: If `True`, stratifies (=randomly offsets) the depths
+                of each ray point during training (`self.training==True`).
+            stratified_test: If `True`, stratifies (=randomly offsets) the depths
+                of each ray point during evaluation (`self.training==False`).
+        """
+
+        super().__init__()
+        self._stratified = stratified
+        self._stratified_test = stratified_test
+
+        # Initialize the grid ray sampler.
+        self._grid_raysampler = NDCGridRaysampler(
+            image_width=image_width,
+            image_height=image_height,
+            n_pts_per_ray=n_pts_per_ray,
+            min_depth=min_depth,
+            max_depth=max_depth,
+        )
+
+        # Initialize the Monte Carlo ray sampler.
+        self._mc_raysampler = MonteCarloRaysampler(
+            min_x=-1.0,
+            max_x=1.0,
+            min_y=-1.0,
+            max_y=1.0,
+            n_rays_per_image=n_rays_per_image,
+            n_pts_per_ray=n_pts_per_ray,
+            min_depth=min_depth,
+            max_depth=max_depth,
+        )
+
+        # Create an empty ray cache.
+        self._ray_cache = {}
+
+    def get_n_chunks(self, chunksize: int, batch_size: int):
+        """
+        Returns the total number of `chunksize`-sized chunks
+        of the raysampler's rays.
+
+        Args:
+            chunksize: The number of rays per chunk.
+            batch_size: The size of the batch of the raysampler.
+
+        Returns:
+            n_chunks: The total number of chunks.
+        """
+        return int(
+            math.ceil(
+                (self._grid_raysampler._xy_grid.numel() * 0.5 * batch_size) / chunksize
+            )
+        )
+
+    def _print_precaching_progress(self, i, total, bar_len=30):
+        """
+        Prints a progress bar for ray precaching.
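+
+        Args:
+            i: The 0-based index of the most recently cached camera.
+            total: The total number of cameras to cache.
+            bar_len: The length of the printed progress bar in characters.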
+ """ + position = round((i + 1) / total * bar_len) + pbar = "[" + "█" * position + " " * (bar_len - position) + "]" + print(pbar, end="\r") + + def precache_rays(self, cameras: List[CamerasBase], camera_hashes: List): + """ + Precaches the rays emitted from the list of cameras `cameras`, + where each camera is uniquely identified with the corresponding hash + from `camera_hashes`. + + The cached rays are moved to cpu and stored in `self._ray_cache`. + Raises `ValueError` when caching two cameras with the same hash. + + Args: + cameras: A list of `N` cameras for which the rays are pre-cached. + camera_hashes: A list of `N` unique identifiers of each + camera from `cameras`. + """ + print(f"Precaching {len(cameras)} ray bundles ...") + full_chunksize = ( + self._grid_raysampler._xy_grid.numel() + // 2 + * self._grid_raysampler._n_pts_per_ray + ) + if self.get_n_chunks(full_chunksize, 1) != 1: + raise ValueError("There has to be one chunk for precaching rays!") + for camera_i, (camera, camera_hash) in enumerate(zip(cameras, camera_hashes)): + ray_bundle = self.forward( + camera, + caching=True, + chunksize=full_chunksize, + ) + if camera_hash in self._ray_cache: + raise ValueError("There are redundant cameras!") + self._ray_cache[camera_hash] = RayBundle( + *[v.to("cpu").detach() for v in ray_bundle] + ) + self._print_precaching_progress(camera_i, len(cameras)) + print("") + + def _stratify_ray_bundle(self, ray_bundle: RayBundle): + """ + Stratifies the lengths of the input `ray_bundle`. + + More specifically, the stratification replaces each ray points' depth `z` + with a sample from a uniform random distribution on + `[z - delta_depth, z+delta_depth]`, where `delta_depth` is the difference + of depths of the consecutive ray depth values. + + Args: + `ray_bundle`: The input `RayBundle`. + + Returns: + `stratified_ray_bundle`: `ray_bundle` whose `lengths` field is replaced + with the stratified samples. + """ + z_vals = ray_bundle.lengths + # Get intervals between samples. + mids = 0.5 * (z_vals[..., 1:] + z_vals[..., :-1]) + upper = torch.cat((mids, z_vals[..., -1:]), dim=-1) + lower = torch.cat((z_vals[..., :1], mids), dim=-1) + # Stratified samples in those intervals. + z_vals = lower + (upper - lower) * torch.rand_like(lower) + return ray_bundle._replace(lengths=z_vals) + + def _normalize_raybundle(self, ray_bundle: RayBundle): + """ + Normalizes the ray directions of the input `RayBundle` to unit norm. + """ + ray_bundle = ray_bundle._replace( + directions=torch.nn.functional.normalize(ray_bundle.directions, dim=-1) + ) + return ray_bundle + + def forward( + self, + cameras: CamerasBase, + chunksize: int = None, + chunk_idx: int = 0, + camera_hash: str = None, + caching: bool = False, + **kwargs, + ) -> RayBundle: + """ + Args: + cameras: A batch of `batch_size` cameras from which the rays are emitted. + chunksize: The number of rays per chunk. + Active only when `self.training==False`. + chunk_idx: The index of the ray chunk. The number has to be in + `[0, self.get_n_chunks(chunksize, batch_size)-1]`. + Active only when `self.training==False`. + camera_hash: A unique identifier of a pre-cached camera. If `None`, + the cache is not searched and the rays are calculated from scratch. + caching: If `True`, activates the caching mode that returns the `RayBundle` + that should be stored into the cache. 
+        Returns:
+            A named tuple `RayBundle` with the following fields:
+                origins: A tensor of shape
+                    `(batch_size, n_rays_per_image, 3)`
+                    denoting the locations of ray origins in world coordinates.
+                directions: A tensor of shape
+                    `(batch_size, n_rays_per_image, 3)`
+                    denoting the directions of each ray in world coordinates.
+                lengths: A tensor of shape
+                    `(batch_size, n_rays_per_image, n_pts_per_ray)`
+                    containing the z-coordinate (=depth) of each ray point
+                    in world units.
+                xys: A tensor of shape
+                    `(batch_size, n_rays_per_image, 2)`
+                    containing the 2D image coordinates of each ray.
+        """
+
+        batch_size = cameras.R.shape[0]  # pyre-ignore
+        device = cameras.device
+
+        if (camera_hash is None) and (not caching) and self.training:
+            # Sample random rays from scratch.
+            ray_bundle = self._mc_raysampler(cameras)
+            ray_bundle = self._normalize_raybundle(ray_bundle)
+        else:
+            if camera_hash is not None:
+                # The case where we retrieve a camera from the cache.
+                if batch_size != 1:
+                    raise NotImplementedError(
+                        "Ray caching works only for batches with a single camera!"
+                    )
+                full_ray_bundle = self._ray_cache[camera_hash]
+            else:
+                # We generate a full ray grid from scratch.
+                full_ray_bundle = self._grid_raysampler(cameras)
+                full_ray_bundle = self._normalize_raybundle(full_ray_bundle)
+
+            n_pixels = full_ray_bundle.directions.shape[:-1].numel()
+
+            if self.training:
+                # During training we randomly subsample rays.
+                sel_rays = torch.randperm(n_pixels, device=device)[
+                    : self._mc_raysampler._n_rays_per_image
+                ]
+            else:
+                # During testing we take only the requested chunk.
+                if chunksize is None:
+                    chunksize = n_pixels * batch_size
+                start = chunk_idx * chunksize * batch_size
+                end = min(start + chunksize, n_pixels)
+                sel_rays = torch.arange(
+                    start,
+                    end,
+                    dtype=torch.long,
+                    device=full_ray_bundle.lengths.device,
+                )
+
+            # Take the "sel_rays" rays from the full ray bundle.
+            ray_bundle = RayBundle(
+                *[
+                    v.view(n_pixels, -1)[sel_rays]
+                    .view(batch_size, sel_rays.numel() // batch_size, -1)
+                    .to(device)
+                    for v in full_ray_bundle
+                ]
+            )
+
+        if (
+            (self._stratified and self.training)
+            or (self._stratified_test and not self.training)
+        ) and not caching:  # Make sure not to stratify when caching!
+            ray_bundle = self._stratify_ray_bundle(ray_bundle)
+
+        return ray_bundle
diff --git a/projects/nerf/nerf/utils.py b/projects/nerf/nerf/utils.py
new file mode 100644
index 00000000..f50464ef
--- /dev/null
+++ b/projects/nerf/nerf/utils.py
@@ -0,0 +1,67 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+import torch
+
+
+def sample_pdf(
+    bins: torch.Tensor,
+    weights: torch.Tensor,
+    N_samples: int,
+    det: bool = False,
+    eps: float = 1e-5,
+):
+    """
+    Samples a probability density function defined by the bin edges `bins` and
+    the non-negative per-bin probabilities `weights`.
+
+    Note: This is a direct conversion of the TensorFlow function from the original
+    release [1] to PyTorch.
+
+    Args:
+        bins: Tensor of shape `(..., n_bins+1)` denoting the edges of the sampling bins.
+        weights: Tensor of shape `(..., n_bins)` containing non-negative numbers
+            representing the probability of sampling the corresponding bin.
+        N_samples: The number of samples to draw from each set of bins.
+        det: If `False`, the sampling is random. `True` yields deterministic
+            uniformly-spaced sampling from the inverse cumulative density function.
+        eps: A constant preventing division by zero in case empty bins are present.
+
+    Returns:
+        samples: Tensor of shape `(..., N_samples)` containing `N_samples` samples
+            drawn from each set's probability distribution.
+
+    Refs:
+        [1] https://github.com/bmild/nerf/blob/55d8b00244d7b5178f4d003526ab6667683c9da9/run_nerf_helpers.py#L183  # noqa E501
+    """
+
+    # Get the pdf.
+    weights = weights + eps  # prevent nans
+    pdf = weights / weights.sum(dim=-1, keepdim=True)
+    cdf = torch.cumsum(pdf, -1)
+    cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1)
+
+    # Take uniform samples.
+    if det:
+        u = torch.linspace(0.0, 1.0, N_samples, device=cdf.device, dtype=cdf.dtype)
+        u = u.expand(list(cdf.shape[:-1]) + [N_samples]).contiguous()
+    else:
+        u = torch.rand(
+            list(cdf.shape[:-1]) + [N_samples], device=cdf.device, dtype=cdf.dtype
+        )
+
+    # Invert the CDF.
+    inds = torch.searchsorted(cdf, u, right=True)
+    below = (inds - 1).clamp(0)
+    above = inds.clamp(max=cdf.shape[-1] - 1)
+    inds_g = torch.stack([below, above], -1).view(
+        *below.shape[:-1], below.shape[-1] * 2
+    )
+
+    cdf_g = torch.gather(cdf, -1, inds_g).view(*below.shape, 2)
+    bins_g = torch.gather(bins, -1, inds_g).view(*below.shape, 2)
+
+    denom = cdf_g[..., 1] - cdf_g[..., 0]
+    denom = torch.where(denom < eps, torch.ones_like(denom), denom)
+    t = (u - cdf_g[..., 0]) / denom
+    samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])
+
+    return samples
diff --git a/projects/nerf/tests/test_raysampler.py b/projects/nerf/tests/test_raysampler.py
new file mode 100644
index 00000000..bae85ad4
--- /dev/null
+++ b/projects/nerf/tests/test_raysampler.py
@@ -0,0 +1,122 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+
+import unittest
+
+import torch
+from nerf.raysampler import NeRFRaysampler, ProbabilisticRaysampler
+from pytorch3d.renderer import PerspectiveCameras
+from pytorch3d.transforms.rotation_conversions import random_rotations
+
+
+class TestRaysampler(unittest.TestCase):
+    def setUp(self) -> None:
+        torch.manual_seed(42)
+
+    def test_raysampler_caching(self, batch_size=10):
+        """
+        Tests the consistency of the NeRF raysampler caching.
+        """
+
+        raysampler = NeRFRaysampler(
+            n_pts_per_ray=10,
+            min_depth=0.1,
+            max_depth=10.0,
+            n_rays_per_image=12,
+            image_width=10,
+            image_height=10,
+            stratified=False,
+            stratified_test=False,
+        )
+
+        raysampler.eval()
+
+        cameras, rays = [], []
+
+        for _ in range(batch_size):
+            R = random_rotations(1)
+            T = torch.randn(1, 3)
+            focal_length = torch.rand(1, 2) + 0.5
+            principal_point = torch.randn(1, 2)
+
+            camera = PerspectiveCameras(
+                focal_length=focal_length,
+                principal_point=principal_point,
+                R=R,
+                T=T,
+            )
+
+            cameras.append(camera)
+            rays.append(raysampler(camera))
+
+        raysampler.precache_rays(cameras, list(range(batch_size)))
+
+        for cam_index, rays_ in enumerate(rays):
+            rays_cached_ = raysampler(
+                cameras=cameras[cam_index],
+                chunksize=None,
+                chunk_idx=0,
+                camera_hash=cam_index,
+                caching=False,
+            )
+
+            for v, v_cached in zip(rays_, rays_cached_):
+                self.assertTrue(torch.allclose(v, v_cached))
+
+    def test_probabilistic_raysampler(self, batch_size=1, n_pts_per_ray=60):
+        """
+        Checks that the probabilistic ray sampler does not crash for various
+        settings.
+ """ + + raysampler_grid = NeRFRaysampler( + min_x=0.0, + max_x=10.0, + min_y=0.0, + max_y=10.0, + n_pts_per_ray=n_pts_per_ray, + min_depth=1.0, + max_depth=10.0, + n_rays_per_image=12, + image_width=10, + image_height=10, + stratified=False, + stratified_test=False, + invert_directions=True, + ) + + R = random_rotations(batch_size) + T = torch.randn(batch_size, 3) + focal_length = torch.rand(batch_size, 2) + 0.5 + principal_point = torch.randn(batch_size, 2) + camera = PerspectiveCameras( + focal_length=focal_length, + principal_point=principal_point, + R=R, + T=T, + ) + + raysampler_grid.eval() + + ray_bundle = raysampler_grid(cameras=camera) + + ray_weights = torch.rand_like(ray_bundle.lengths) + + # Just check that we dont crash for all possible settings. + for stratified_test in (True, False): + for stratified in (True, False): + raysampler_prob = ProbabilisticRaysampler( + n_pts_per_ray=n_pts_per_ray, + stratified=stratified, + stratified_test=stratified_test, + add_input_samples=True, + ) + for mode in ("train", "eval"): + getattr(raysampler_prob, mode)() + for _ in range(10): + raysampler_prob(ray_bundle, ray_weights)