NeRF Raysampler

Summary: Implements the NeRF raysampler.

Reviewed By: nikhilaravi

Differential Revision: D25684403

fbshipit-source-id: 616a60f047c79479f60a6a75d214f87cbfb06d28
David Novotny 2021-02-02 05:42:59 -08:00 committed by Facebook GitHub Bot
parent fba419b7f7
commit 7cbda3ec17
3 changed files with 553 additions and 0 deletions

@@ -0,0 +1,364 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import math
from typing import List
import torch
from pytorch3d.renderer import RayBundle, NDCGridRaysampler, MonteCarloRaysampler
from pytorch3d.renderer.cameras import CamerasBase
from .utils import sample_pdf
class ProbabilisticRaysampler(torch.nn.Module):
"""
Implements the importance sampling of points along rays.
The inputs are a `RayBundle` object and a `ray_weights` tensor
which specifies the probabilities of sampling a point along each ray.
This raysampler is used for the fine rendering pass of NeRF.
As such, the forward pass accepts the RayBundle output by the
raysampling of the coarse rendering pass. Hence, it does not
take cameras as input.
"""
def __init__(
self,
n_pts_per_ray: int,
stratified: bool,
stratified_test: bool,
add_input_samples: bool = True,
):
"""
Args:
n_pts_per_ray: The number of points to sample along each ray.
stratified: If `True`, the new depth values are sampled at random from
the importance distribution when the module is in the `train` mode
(`self.training==True`); if `False`, they are placed at deterministic,
uniformly spaced quantiles of the distribution.
stratified_test: Same as `stratified` with the difference that this
setting is applied when the module is in the `eval` mode
(`self.training==False`).
add_input_samples: If `True`, the sampled values are concatenated with
the input samples before being returned.
"""
super().__init__()
self._n_pts_per_ray = n_pts_per_ray
self._stratified = stratified
self._stratified_test = stratified_test
self._add_input_samples = add_input_samples
def forward(
self,
input_ray_bundle: RayBundle,
ray_weights: torch.Tensor,
**kwargs,
) -> RayBundle:
"""
Args:
input_ray_bundle: An instance of `RayBundle` specifying the
source rays for sampling of the probability distribution.
ray_weights: A tensor of shape
`(..., input_ray_bundle.lengths.shape[-1])` with non-negative
elements defining the probability distribution to sample
ray points from.
Returns:
ray_bundle: A new `RayBundle` instance containing the input ray
points together with `n_pts_per_ray` additional sampled
points per ray.
"""
# Calculate the mid-points between the ray depths.
z_vals = input_ray_bundle.lengths
batch_size = z_vals.shape[0]
z_vals_mid = 0.5 * (z_vals[..., 1:] + z_vals[..., :-1])
# Carry out the importance sampling.
z_samples = (
sample_pdf(
z_vals_mid.view(-1, z_vals_mid.shape[-1]),
ray_weights.view(-1, ray_weights.shape[-1])[..., 1:-1],
self._n_pts_per_ray,
det=not (
(self._stratified and self.training)
or (self._stratified_test and not self.training)
),
)
.detach()
.view(batch_size, z_vals.shape[1], self._n_pts_per_ray)
)
if self._add_input_samples:
# Add the new samples to the input ones.
z_vals = torch.cat((z_vals, z_samples), dim=-1)
else:
z_vals = z_samples
# Resort by depth.
z_vals, _ = torch.sort(z_vals, dim=-1)
return RayBundle(
origins=input_ray_bundle.origins,
directions=input_ray_bundle.directions,
lengths=z_vals,
xys=input_ray_bundle.xys,
)
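# A minimal usage sketch of the fine pass (illustrative only; the names
# `coarse_bundle` and `coarse_weights` stand for the coarse-pass `RayBundle`
# and the per-point weights produced by the coarse renderer):
#
#     fine_raysampler = ProbabilisticRaysampler(
#         n_pts_per_ray=64, stratified=True, stratified_test=False
#     )
#     fine_bundle = fine_raysampler(
#         input_ray_bundle=coarse_bundle, ray_weights=coarse_weights
#     )
#
# With `add_input_samples=True` (the default), `fine_bundle.lengths` holds
# the input depths concatenated with 64 importance-sampled depths per ray,
# sorted along each ray.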
class NeRFRaysampler(torch.nn.Module):
"""
Implements the raysampler of NeRF.
Depending on the `self.training` flag, the raysampler either samples
a chunk of random rays (`self.training==True`), or returns a subset of rays
of the full image grid (`self.training==False`).
The chunking of rays allows for efficient evaluation of the NeRF implicit
surface function without encountering out-of-GPU-memory errors.
Additionally, this raysampler supports pre-caching of the ray bundles
for a set of input cameras (`self.precache_rays`).
Pre-caching the rays before training greatly speeds up the ensuing
raysampling step of the NeRF training iterations.
"""
def __init__(
self,
n_pts_per_ray: int,
min_depth: float,
max_depth: float,
n_rays_per_image: int,
image_width: int,
image_height: int,
stratified: bool = False,
stratified_test: bool = False,
):
"""
Args:
n_pts_per_ray: The number of points sampled along each ray.
min_depth: The minimum depth of a ray-point.
max_depth: The maximum depth of a ray-point.
n_rays_per_image: Number of Monte Carlo ray samples when training
(`self.training==True`).
image_width: The horizontal size of the image grid.
image_height: The vertical size of the image grid.
stratified: If `True`, stratifies (=randomly offsets) the depths
of each ray point during training (`self.training==True`).
stratified_test: If `True`, stratifies (=randomly offsets) the depths
of each ray point during evaluation (`self.training==False`).
"""
super().__init__()
self._stratified = stratified
self._stratified_test = stratified_test
# Initialize the grid ray sampler.
self._grid_raysampler = NDCGridRaysampler(
image_width=image_width,
image_height=image_height,
n_pts_per_ray=n_pts_per_ray,
min_depth=min_depth,
max_depth=max_depth,
)
# Initialize the Monte Carlo ray sampler.
self._mc_raysampler = MonteCarloRaysampler(
min_x=-1.0,
max_x=1.0,
min_y=-1.0,
max_y=1.0,
n_rays_per_image=n_rays_per_image,
n_pts_per_ray=n_pts_per_ray,
min_depth=min_depth,
max_depth=max_depth,
)
# create empty ray cache
self._ray_cache = {}
def get_n_chunks(self, chunksize: int, batch_size: int):
"""
Returns the total number of `chunksize`-sized chunks
of the raysampler's rays.
Args:
chunksize: The number of rays per chunk.
batch_size: The size of the batch of the raysampler.
Returns:
n_chunks: The total number of chunks.
"""
return int(
math.ceil(
(self._grid_raysampler._xy_grid.numel() * 0.5 * batch_size) / chunksize
)
)
def _print_precaching_progress(self, i, total, bar_len=30):
"""
Print a progress bar for ray precaching.
"""
position = round((i + 1) / total * bar_len)
pbar = "[" + "" * position + " " * (bar_len - position) + "]"
print(pbar, end="\r")
def precache_rays(self, cameras: List[CamerasBase], camera_hashes: List):
"""
Precaches the rays emitted from the list of cameras `cameras`,
where each camera is uniquely identified with the corresponding hash
from `camera_hashes`.
The cached rays are moved to cpu and stored in `self._ray_cache`.
Raises `ValueError` when caching two cameras with the same hash.
Args:
cameras: A list of `N` cameras for which the rays are pre-cached.
camera_hashes: A list of `N` unique identifiers of each
camera from `cameras`.
"""
print(f"Precaching {len(cameras)} ray bundles ...")
full_chunksize = (
self._grid_raysampler._xy_grid.numel()
// 2
* self._grid_raysampler._n_pts_per_ray
)
if self.get_n_chunks(full_chunksize, 1) != 1:
raise ValueError("There has to be one chunk for precaching rays!")
for camera_i, (camera, camera_hash) in enumerate(zip(cameras, camera_hashes)):
ray_bundle = self.forward(
camera,
caching=True,
chunksize=full_chunksize,
)
if camera_hash in self._ray_cache:
raise ValueError("There are redundant cameras!")
self._ray_cache[camera_hash] = RayBundle(
*[v.to("cpu").detach() for v in ray_bundle]
)
self._print_precaching_progress(camera_i, len(cameras))
print("")
def _stratify_ray_bundle(self, ray_bundle: RayBundle):
"""
Stratifies the lengths of the input `ray_bundle`.
More specifically, the stratification replaces each ray point's depth `z`
with a sample from a uniform random distribution on the interval between
the midpoints of `z` and its neighboring depth values, so each depth is
jittered within its own sampling cell.
Args:
`ray_bundle`: The input `RayBundle`.
Returns:
`stratified_ray_bundle`: `ray_bundle` whose `lengths` field is replaced
with the stratified samples.
"""
z_vals = ray_bundle.lengths
# Get intervals between samples.
mids = 0.5 * (z_vals[..., 1:] + z_vals[..., :-1])
upper = torch.cat((mids, z_vals[..., -1:]), dim=-1)
lower = torch.cat((z_vals[..., :1], mids), dim=-1)
# Stratified samples in those intervals.
z_vals = lower + (upper - lower) * torch.rand_like(lower)
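# For example, for per-ray depths z = [1.0, 2.0, 3.0] the midpoints are
# [1.5, 2.5], so the three stratified samples are drawn uniformly from
# the intervals [1.0, 1.5], [1.5, 2.5] and [2.5, 3.0], respectively.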
return ray_bundle._replace(lengths=z_vals)
def _normalize_raybundle(self, ray_bundle: RayBundle):
"""
Normalizes the ray directions of the input `RayBundle` to unit norm.
"""
ray_bundle = ray_bundle._replace(
directions=torch.nn.functional.normalize(ray_bundle.directions, dim=-1)
)
return ray_bundle
def forward(
self,
cameras: CamerasBase,
chunksize: int = None,
chunk_idx: int = 0,
camera_hash: str = None,
caching: bool = False,
**kwargs,
) -> RayBundle:
"""
Args:
cameras: A batch of `batch_size` cameras from which the rays are emitted.
chunksize: The number of rays per chunk.
Active only when `self.training==False`.
chunk_idx: The index of the ray chunk. The number has to be in
`[0, self.get_n_chunks(chunksize, batch_size)-1]`.
Active only when `self.training==False`.
camera_hash: A unique identifier of a pre-cached camera. If `None`,
the cache is not searched and the rays are calculated from scratch.
caching: If `True`, activates the caching mode that returns the `RayBundle`
that should be stored into the cache.
Returns:
A named tuple `RayBundle` with the following fields:
origins: A tensor of shape
`(batch_size, n_rays_per_image, 3)`
denoting the locations of ray origins in the world coordinates.
directions: A tensor of shape
`(batch_size, n_rays_per_image, 3)`
denoting the directions of each ray in the world coordinates.
lengths: A tensor of shape
`(batch_size, n_rays_per_image, n_pts_per_ray)`
containing the z-coordinate (=depth) of each ray in world units.
xys: A tensor of shape
`(batch_size, n_rays_per_image, 2)`
containing the 2D image coordinates of each ray.
"""
batch_size = cameras.R.shape[0] # pyre-ignore
device = cameras.device
if (camera_hash is None) and (not caching) and self.training:
# Sample random rays from scratch.
ray_bundle = self._mc_raysampler(cameras)
ray_bundle = self._normalize_raybundle(ray_bundle)
else:
if camera_hash is not None:
# The case where we retrieve a camera from cache.
if batch_size != 1:
raise NotImplementedError(
"Ray caching works only for batches with a single camera!"
)
full_ray_bundle = self._ray_cache[camera_hash]
else:
# We generate a full ray grid from scratch.
full_ray_bundle = self._grid_raysampler(cameras)
full_ray_bundle = self._normalize_raybundle(full_ray_bundle)
n_pixels = full_ray_bundle.directions.shape[:-1].numel()
if self.training:
# During training we randomly subsample rays.
sel_rays = torch.randperm(n_pixels, device=device)[
: self._mc_raysampler._n_rays_per_image
]
else:
# At test time, we take only the requested chunk.
if chunksize is None:
chunksize = n_pixels * batch_size
start = chunk_idx * chunksize * batch_size
end = min(start + chunksize, n_pixels)
sel_rays = torch.arange(
start,
end,
dtype=torch.long,
device=full_ray_bundle.lengths.device,
)
# Take the "sel_rays" rays from the full ray bundle.
ray_bundle = RayBundle(
*[
v.view(n_pixels, -1)[sel_rays]
.view(batch_size, sel_rays.numel() // batch_size, -1)
.to(device)
for v in full_ray_bundle
]
)
if (
(self._stratified and self.training)
or (self._stratified_test and not self.training)
) and not caching: # Make sure not to stratify when caching!
ray_bundle = self._stratify_ray_bundle(ray_bundle)
return ray_bundle
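
A minimal sketch of driving both modes of the sampler (the camera set-up
below is an illustrative assumption, not part of this revision):

from pytorch3d.renderer import PerspectiveCameras

train_cameras = [PerspectiveCameras() for _ in range(3)]  # stand-in cameras
hashes = [f"cam_{i}" for i in range(3)]
test_camera = PerspectiveCameras()

raysampler = NeRFRaysampler(
    n_pts_per_ray=64,
    min_depth=0.1,
    max_depth=10.0,
    n_rays_per_image=1024,
    image_width=128,
    image_height=128,
)

# Pre-cache full ray grids in eval mode (in train mode the sampler would
# subsample rays), then switch to train mode and draw a random chunk of
# rays from the cache at every iteration.
raysampler.eval()
raysampler.precache_rays(train_cameras, hashes)
raysampler.train()
bundle = raysampler(cameras=train_cameras[0], camera_hash=hashes[0])

# Evaluation: sweep over the full image grid in fixed-size chunks.
raysampler.eval()
chunksize = 4096
for chunk_idx in range(raysampler.get_n_chunks(chunksize, batch_size=1)):
    bundle = raysampler(cameras=test_camera, chunksize=chunksize, chunk_idx=chunk_idx)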

@@ -0,0 +1,67 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import torch
def sample_pdf(
bins: torch.Tensor,
weights: torch.Tensor,
N_samples: int,
det: bool = False,
eps: float = 1e-5,
):
"""
Samples probability density functions defined by bin edges `bins` and
the non-negative per-bin probabilities `weights`.
Note: This is a direct conversion of the TensorFlow function from the original
release [1] to PyTorch.
Args:
bins: Tensor of shape `(..., n_bins+1)` denoting the edges of the sampling bins.
weights: Tensor of shape `(..., n_bins)` containing non-negative numbers
representing the probability of sampling the corresponding bin.
N_samples: The number of samples to draw from each set of bins.
det: If `False`, the sampling is random. `True` yields deterministic,
uniformly spaced sampling from the inverse cumulative distribution function.
eps: A constant preventing division by zero in case empty bins are present.
Returns:
samples: Tensor of shape `(..., N_samples)` containing `N_samples` samples
drawn from each probability distribution.
Refs:
[1] https://github.com/bmild/nerf/blob/55d8b00244d7b5178f4d003526ab6667683c9da9/run_nerf_helpers.py#L183 # noqa E501
"""
# Get pdf
weights = weights + eps # prevent nans
pdf = weights / weights.sum(dim=-1, keepdim=True)
cdf = torch.cumsum(pdf, -1)
cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1)
# Take uniform samples
if det:
u = torch.linspace(0.0, 1.0, N_samples, device=cdf.device, dtype=cdf.dtype)
u = u.expand(list(cdf.shape[:-1]) + [N_samples]).contiguous()
else:
u = torch.rand(
list(cdf.shape[:-1]) + [N_samples], device=cdf.device, dtype=cdf.dtype
)
# Invert CDF
inds = torch.searchsorted(cdf, u, right=True)
below = (inds - 1).clamp(0)
above = inds.clamp(max=cdf.shape[-1] - 1)
inds_g = torch.stack([below, above], -1).view(
*below.shape[:-1], below.shape[-1] * 2
)
cdf_g = torch.gather(cdf, -1, inds_g).view(*below.shape, 2)
bins_g = torch.gather(bins, -1, inds_g).view(*below.shape, 2)
denom = cdf_g[..., 1] - cdf_g[..., 0]
denom = torch.where(denom < eps, torch.ones_like(denom), denom)
t = (u - cdf_g[..., 0]) / denom
samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])
return samples
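
A quick illustration of the inverse-CDF sampling on a toy two-bin
distribution (a sketch for intuition only; the numbers are not from this
revision):

bins = torch.tensor([[0.0, 1.0, 2.0]])  # edges of two bins
weights = torch.tensor([[1.0, 3.0]])  # the second bin is 3x as likely
samples = sample_pdf(bins, weights, N_samples=4, det=True)
# det=True places the samples at uniformly spaced quantiles of the CDF,
# so most of them land inside the second (heavier) bin:
# samples is approximately tensor([[0.00, 1.11, 1.56, 2.00]])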

@@ -0,0 +1,122 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import unittest
import torch
from nerf.raysampler import NeRFRaysampler, ProbabilisticRaysampler
from pytorch3d.renderer import PerspectiveCameras
from pytorch3d.transforms.rotation_conversions import random_rotations
class TestRaysampler(unittest.TestCase):
def setUp(self) -> None:
torch.manual_seed(42)
def test_raysampler_caching(self, batch_size=10):
"""
Tests the consistency of the NeRF raysampler caching.
"""
raysampler = NeRFRaysampler(
    n_pts_per_ray=10,
    min_depth=0.1,
    max_depth=10.0,
    n_rays_per_image=12,
    image_width=10,
    image_height=10,
    stratified=False,
    stratified_test=False,
)
raysampler.eval()
cameras, rays = [], []
for _ in range(batch_size):
R = random_rotations(1)
T = torch.randn(1, 3)
focal_length = torch.rand(1, 2) + 0.5
principal_point = torch.randn(1, 2)
camera = PerspectiveCameras(
focal_length=focal_length,
principal_point=principal_point,
R=R,
T=T,
)
cameras.append(camera)
rays.append(raysampler(camera))
raysampler.precache_rays(cameras, list(range(batch_size)))
for cam_index, rays_ in enumerate(rays):
rays_cached_ = raysampler(
cameras=cameras[cam_index],
chunksize=None,
chunk_idx=0,
camera_hash=cam_index,
caching=False,
)
for v, v_cached in zip(rays_, rays_cached_):
self.assertTrue(torch.allclose(v, v_cached))
def test_probabilistic_raysampler(self, batch_size=1, n_pts_per_ray=60):
"""
Check that the probabilistic ray sampler does not crash for various
settings.
"""
raysampler_grid = NeRFRaysampler(
    n_pts_per_ray=n_pts_per_ray,
    min_depth=1.0,
    max_depth=10.0,
    n_rays_per_image=12,
    image_width=10,
    image_height=10,
    stratified=False,
    stratified_test=False,
)
R = random_rotations(batch_size)
T = torch.randn(batch_size, 3)
focal_length = torch.rand(batch_size, 2) + 0.5
principal_point = torch.randn(batch_size, 2)
camera = PerspectiveCameras(
focal_length=focal_length,
principal_point=principal_point,
R=R,
T=T,
)
raysampler_grid.eval()
ray_bundle = raysampler_grid(cameras=camera)
ray_weights = torch.rand_like(ray_bundle.lengths)
# Just check that we don't crash for all possible settings.
for stratified_test in (True, False):
for stratified in (True, False):
raysampler_prob = ProbabilisticRaysampler(
n_pts_per_ray=n_pts_per_ray,
stratified=stratified,
stratified_test=stratified_test,
add_input_samples=True,
)
for mode in ("train", "eval"):
getattr(raysampler_prob, mode)()
for _ in range(10):
raysampler_prob(ray_bundle, ray_weights)
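# Sanity note: with `add_input_samples=True`, the bundle returned by the
# probabilistic sampler has `lengths` of shape
# (batch_size, n_rays, 2 * n_pts_per_ray): the input depths plus the newly
# importance-sampled depths for each ray.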