NeRF Raysampler
Summary: Implements the NeRF raysampler.

Reviewed By: nikhilaravi

Differential Revision: D25684403

fbshipit-source-id: 616a60f047c79479f60a6a75d214f87cbfb06d28
parent fba419b7f7
commit 7cbda3ec17
364 projects/nerf/nerf/raysampler.py Normal file
@@ -0,0 +1,364 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import math
from typing import List

import torch
from pytorch3d.renderer import RayBundle, NDCGridRaysampler, MonteCarloRaysampler
from pytorch3d.renderer.cameras import CamerasBase

from .utils import sample_pdf


class ProbabilisticRaysampler(torch.nn.Module):
    """
    Implements the importance sampling of points along rays.
    The input is a `RayBundle` object with a `ray_weights` tensor
    which specifies the probabilities of sampling a point along each ray.

    This raysampler is used for the fine rendering pass of NeRF.
    As such, the forward pass accepts the RayBundle output by the
    raysampling of the coarse rendering pass. Hence, it does not
    take cameras as input.
    """

    def __init__(
        self,
        n_pts_per_ray: int,
        stratified: bool,
        stratified_test: bool,
        add_input_samples: bool = True,
    ):
        """
        Args:
            n_pts_per_ray: The number of points to sample along each ray.
            stratified: If `True`, the input `ray_weights` are assumed to be
                sampled at equidistant intervals.
            stratified_test: Same as `stratified` with the difference that this
                setting is applied when the module is in the `eval` mode
                (`self.training==False`).
            add_input_samples: If `True`, the sampled values are returned
                concatenated with the input samples.
        """
        super().__init__()
        self._n_pts_per_ray = n_pts_per_ray
        self._stratified = stratified
        self._stratified_test = stratified_test
        self._add_input_samples = add_input_samples

    def forward(
        self,
        input_ray_bundle: RayBundle,
        ray_weights: torch.Tensor,
        **kwargs,
    ) -> RayBundle:
        """
        Args:
            input_ray_bundle: An instance of `RayBundle` specifying the
                source rays for sampling of the probability distribution.
            ray_weights: A tensor of shape
                `(..., input_ray_bundle.lengths.shape[-1])` with non-negative
                elements defining the probability distribution to sample
                ray points from.

        Returns:
            ray_bundle: A new `RayBundle` instance containing the input ray
                points together with `n_pts_per_ray` additional sampled
                points per ray.
        """

        # Calculate the mid-points between the ray depths.
        z_vals = input_ray_bundle.lengths
        batch_size = z_vals.shape[0]
        z_vals_mid = 0.5 * (z_vals[..., 1:] + z_vals[..., :-1])

        # Carry out the importance sampling.
        z_samples = (
            sample_pdf(
                z_vals_mid.view(-1, z_vals_mid.shape[-1]),
                ray_weights.view(-1, ray_weights.shape[-1])[..., 1:-1],
                self._n_pts_per_ray,
                det=not (
                    (self._stratified and self.training)
                    or (self._stratified_test and not self.training)
                ),
            )
            .detach()
            .view(batch_size, z_vals.shape[1], self._n_pts_per_ray)
        )

        if self._add_input_samples:
            # Add the new samples to the input ones.
            z_vals = torch.cat((z_vals, z_samples), dim=-1)
        else:
            z_vals = z_samples
        # Re-sort by depth.
        z_vals, _ = torch.sort(z_vals, dim=-1)

        return RayBundle(
            origins=input_ray_bundle.origins,
            directions=input_ray_bundle.directions,
            lengths=z_vals,
            xys=input_ray_bundle.xys,
        )


class NeRFRaysampler(torch.nn.Module):
    """
    Implements the raysampler of NeRF.

    Depending on the `self.training` flag, the raysampler either samples
    a chunk of random rays (`self.training==True`), or returns a subset of rays
    of the full image grid (`self.training==False`).
    The chunking of rays allows for efficient evaluation of the NeRF implicit
    surface function without encountering out-of-GPU-memory errors.

    Additionally, this raysampler supports pre-caching of the ray bundles
    for a set of input cameras (`self.precache_rays`).
    Pre-caching the rays before training greatly speeds up the ensuing
    raysampling step of the training NeRF iterations.
    """

    def __init__(
        self,
        n_pts_per_ray: int,
        min_depth: float,
        max_depth: float,
        n_rays_per_image: int,
        image_width: int,
        image_height: int,
        stratified: bool = False,
        stratified_test: bool = False,
    ):
        """
        Args:
            n_pts_per_ray: The number of points sampled along each ray.
            min_depth: The minimum depth of a ray-point.
            max_depth: The maximum depth of a ray-point.
            n_rays_per_image: Number of Monte Carlo ray samples when training
                (`self.training==True`).
            image_width: The horizontal size of the image grid.
            image_height: The vertical size of the image grid.
            stratified: If `True`, stratifies (=randomly offsets) the depths
                of each ray point during training (`self.training==True`).
            stratified_test: If `True`, stratifies (=randomly offsets) the depths
                of each ray point during evaluation (`self.training==False`).
        """

        super().__init__()
        self._stratified = stratified
        self._stratified_test = stratified_test

        # Initialize the grid ray sampler.
        self._grid_raysampler = NDCGridRaysampler(
            image_width=image_width,
            image_height=image_height,
            n_pts_per_ray=n_pts_per_ray,
            min_depth=min_depth,
            max_depth=max_depth,
        )

        # Initialize the Monte Carlo ray sampler.
        self._mc_raysampler = MonteCarloRaysampler(
            min_x=-1.0,
            max_x=1.0,
            min_y=-1.0,
            max_y=1.0,
            n_rays_per_image=n_rays_per_image,
            n_pts_per_ray=n_pts_per_ray,
            min_depth=min_depth,
            max_depth=max_depth,
        )

        # Create an empty ray cache.
        self._ray_cache = {}

    def get_n_chunks(self, chunksize: int, batch_size: int):
        """
        Returns the total number of `chunksize`-sized chunks
        of the raysampler's rays.

        Args:
            chunksize: The number of rays per chunk.
            batch_size: The size of the batch of the raysampler.

        Returns:
            n_chunks: The total number of chunks.
        """
        # `_xy_grid` stores an (x, y) pair per pixel, so the 0.5 factor converts
        # its element count to the number of rays (one ray per pixel).
        return int(
            math.ceil(
                (self._grid_raysampler._xy_grid.numel() * 0.5 * batch_size) / chunksize
            )
        )

    def _print_precaching_progress(self, i, total, bar_len=30):
        """
        Print a progress bar for ray precaching.
        """
        position = round((i + 1) / total * bar_len)
        pbar = "[" + "█" * position + " " * (bar_len - position) + "]"
        print(pbar, end="\r")

    def precache_rays(self, cameras: List[CamerasBase], camera_hashes: List):
        """
        Precaches the rays emitted from the list of cameras `cameras`,
        where each camera is uniquely identified with the corresponding hash
        from `camera_hashes`.

        The cached rays are moved to cpu and stored in `self._ray_cache`.
        Raises `ValueError` when caching two cameras with the same hash.

        Args:
            cameras: A list of `N` cameras for which the rays are pre-cached.
            camera_hashes: A list of `N` unique identifiers of each
                camera from `cameras`.
        """
        print(f"Precaching {len(cameras)} ray bundles ...")
        full_chunksize = (
            self._grid_raysampler._xy_grid.numel()
            // 2
            * self._grid_raysampler._n_pts_per_ray
        )
        if self.get_n_chunks(full_chunksize, 1) != 1:
            raise ValueError("There has to be one chunk for precaching rays!")
        for camera_i, (camera, camera_hash) in enumerate(zip(cameras, camera_hashes)):
            ray_bundle = self.forward(
                camera,
                caching=True,
                chunksize=full_chunksize,
            )
            if camera_hash in self._ray_cache:
                raise ValueError("There are redundant cameras!")
            self._ray_cache[camera_hash] = RayBundle(
                *[v.to("cpu").detach() for v in ray_bundle]
            )
            self._print_precaching_progress(camera_i, len(cameras))
        print("")

    def _stratify_ray_bundle(self, ray_bundle: RayBundle):
        """
        Stratifies the lengths of the input `ray_bundle`.

        More specifically, the stratification replaces each ray point's depth `z`
        with a sample from a uniform random distribution on
        `[z - delta_depth, z + delta_depth]`, where `delta_depth` is the difference
        of depths of the consecutive ray depth values.

        Args:
            `ray_bundle`: The input `RayBundle`.

        Returns:
            `stratified_ray_bundle`: `ray_bundle` whose `lengths` field is replaced
                with the stratified samples.
        """
        z_vals = ray_bundle.lengths
        # Get intervals between samples.
        mids = 0.5 * (z_vals[..., 1:] + z_vals[..., :-1])
        upper = torch.cat((mids, z_vals[..., -1:]), dim=-1)
        lower = torch.cat((z_vals[..., :1], mids), dim=-1)
        # Stratified samples in those intervals.
        z_vals = lower + (upper - lower) * torch.rand_like(lower)
        return ray_bundle._replace(lengths=z_vals)

    def _normalize_raybundle(self, ray_bundle: RayBundle):
        """
        Normalizes the ray directions of the input `RayBundle` to unit norm.
        """
        ray_bundle = ray_bundle._replace(
            directions=torch.nn.functional.normalize(ray_bundle.directions, dim=-1)
        )
        return ray_bundle

    def forward(
        self,
        cameras: CamerasBase,
        chunksize: int = None,
        chunk_idx: int = 0,
        camera_hash: str = None,
        caching: bool = False,
        **kwargs,
    ) -> RayBundle:
        """
        Args:
            cameras: A batch of `batch_size` cameras from which the rays are emitted.
            chunksize: The number of rays per chunk.
                Active only when `self.training==False`.
            chunk_idx: The index of the ray chunk. The number has to be in
                `[0, self.get_n_chunks(chunksize, batch_size)-1]`.
                Active only when `self.training==False`.
            camera_hash: A unique identifier of a pre-cached camera. If `None`,
                the cache is not searched and the rays are calculated from scratch.
            caching: If `True`, activates the caching mode that returns the `RayBundle`
                that should be stored into the cache.
        Returns:
            A named tuple `RayBundle` with the following fields:
            origins: A tensor of shape
                `(batch_size, n_rays_per_image, 3)`
                denoting the locations of ray origins in the world coordinates.
            directions: A tensor of shape
                `(batch_size, n_rays_per_image, 3)`
                denoting the directions of each ray in the world coordinates.
            lengths: A tensor of shape
                `(batch_size, n_rays_per_image, n_pts_per_ray)`
                containing the z-coordinate (=depth) of each ray in world units.
            xys: A tensor of shape
                `(batch_size, n_rays_per_image, 2)`
                containing the 2D image coordinates of each ray.
        """

        batch_size = cameras.R.shape[0]  # pyre-ignore
        device = cameras.device

        if (camera_hash is None) and (not caching) and self.training:
            # Sample random rays from scratch.
            ray_bundle = self._mc_raysampler(cameras)
            ray_bundle = self._normalize_raybundle(ray_bundle)
        else:
            if camera_hash is not None:
                # The case where we retrieve a camera from cache.
                if batch_size != 1:
                    raise NotImplementedError(
                        "Ray caching works only for batches with a single camera!"
                    )
                full_ray_bundle = self._ray_cache[camera_hash]
            else:
                # We generate a full ray grid from scratch.
                full_ray_bundle = self._grid_raysampler(cameras)
                full_ray_bundle = self._normalize_raybundle(full_ray_bundle)

            n_pixels = full_ray_bundle.directions.shape[:-1].numel()

            if self.training:
                # During training we randomly subsample rays.
                sel_rays = torch.randperm(n_pixels, device=device)[
                    : self._mc_raysampler._n_rays_per_image
                ]
            else:
                # At test time we take only the requested chunk.
                if chunksize is None:
                    chunksize = n_pixels * batch_size
                start = chunk_idx * chunksize * batch_size
                end = min(start + chunksize, n_pixels)
                sel_rays = torch.arange(
                    start,
                    end,
                    dtype=torch.long,
                    device=full_ray_bundle.lengths.device,
                )

            # Take the "sel_rays" rays from the full ray bundle.
            ray_bundle = RayBundle(
                *[
                    v.view(n_pixels, -1)[sel_rays]
                    .view(batch_size, sel_rays.numel() // batch_size, -1)
                    .to(device)
                    for v in full_ray_bundle
                ]
            )

        if (
            (self._stratified and self.training)
            or (self._stratified_test and not self.training)
        ) and not caching:  # Make sure not to stratify when caching!
            ray_bundle = self._stratify_ray_bundle(ray_bundle)

        return ray_bundle
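For context, a minimal usage sketch (not part of this diff) of how the two samplers added above fit into the coarse-to-fine NeRF pipeline. The camera, image size, ray counts, and the random stand-in for the coarse opacities are illustrative assumptions; only constructor arguments documented above are used.

import torch
from pytorch3d.renderer import PerspectiveCameras
from nerf.raysampler import NeRFRaysampler, ProbabilisticRaysampler

coarse_raysampler = NeRFRaysampler(
    n_pts_per_ray=64,
    min_depth=0.1,
    max_depth=10.0,
    n_rays_per_image=1024,
    image_width=400,
    image_height=400,
    stratified=True,
)
fine_raysampler = ProbabilisticRaysampler(
    n_pts_per_ray=64,
    stratified=True,
    stratified_test=False,
)

cameras = PerspectiveCameras(R=torch.eye(3)[None], T=torch.zeros(1, 3))

# Coarse pass (training mode): 1024 Monte Carlo rays with 64 stratified depths each.
coarse_raysampler.train()
coarse_bundle = coarse_raysampler(cameras)

# In the real pipeline these weights come from the coarse radiance field's
# predicted opacities; random values stand in for them here.
coarse_weights = torch.rand_like(coarse_bundle.lengths)

# Fine pass: importance-sample 64 extra depths per ray and merge them with the
# coarse ones (add_input_samples=True by default).
fine_bundle = fine_raysampler(coarse_bundle, coarse_weights)
print(fine_bundle.lengths.shape)  # torch.Size([1, 1024, 128])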
67 projects/nerf/nerf/utils.py Normal file
@@ -0,0 +1,67 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import torch


def sample_pdf(
    bins: torch.Tensor,
    weights: torch.Tensor,
    N_samples: int,
    det: bool = False,
    eps: float = 1e-5,
):
    """
    Samples probability density functions defined by bin edges `bins` and
    the non-negative per-bin probabilities `weights`.

    Note: This is a direct conversion of the TensorFlow function from the original
    release [1] to PyTorch.

    Args:
        bins: Tensor of shape `(..., n_bins+1)` denoting the edges of the sampling bins.
        weights: Tensor of shape `(..., n_bins)` containing non-negative numbers
            representing the probability of sampling the corresponding bin.
        N_samples: The number of samples to draw from each set of bins.
        det: If `False`, the sampling is random. `True` yields deterministic
            uniformly-spaced sampling from the inverse cumulative density function.
        eps: A constant preventing division by zero in case empty bins are present.

    Returns:
        samples: Tensor of shape `(..., N_samples)` containing `N_samples` samples
            drawn from each probability distribution.

    Refs:
        [1] https://github.com/bmild/nerf/blob/55d8b00244d7b5178f4d003526ab6667683c9da9/run_nerf_helpers.py#L183  # noqa E501
    """

    # Get pdf
    weights = weights + eps  # prevent nans
    pdf = weights / weights.sum(dim=-1, keepdim=True)
    cdf = torch.cumsum(pdf, -1)
    cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1)

    # Take uniform samples
    if det:
        u = torch.linspace(0.0, 1.0, N_samples, device=cdf.device, dtype=cdf.dtype)
        u = u.expand(list(cdf.shape[:-1]) + [N_samples]).contiguous()
    else:
        u = torch.rand(
            list(cdf.shape[:-1]) + [N_samples], device=cdf.device, dtype=cdf.dtype
        )

    # Invert CDF
    inds = torch.searchsorted(cdf, u, right=True)
    below = (inds - 1).clamp(0)
    above = inds.clamp(max=cdf.shape[-1] - 1)
    inds_g = torch.stack([below, above], -1).view(
        *below.shape[:-1], below.shape[-1] * 2
    )

    cdf_g = torch.gather(cdf, -1, inds_g).view(*below.shape, 2)
    bins_g = torch.gather(bins, -1, inds_g).view(*below.shape, 2)

    denom = cdf_g[..., 1] - cdf_g[..., 0]
    denom = torch.where(denom < eps, torch.ones_like(denom), denom)
    t = (u - cdf_g[..., 0]) / denom
    samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])

    return samples
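A small self-contained check (not part of this diff) of the inverse-CDF sampling that `sample_pdf` implements: with uniform `weights` and `det=True`, the deterministic samples should land, up to floating point error, on the bin edges themselves. The concrete bins and weights are illustrative assumptions.

import torch
from nerf.utils import sample_pdf

bins = torch.linspace(0.0, 4.0, 5)[None]  # edges 0, 1, 2, 3, 4 -> 4 bins
weights = torch.ones(1, 4)                # a uniform distribution over the 4 bins
samples = sample_pdf(bins, weights, N_samples=5, det=True)
print(samples)  # approximately tensor([[0., 1., 2., 3., 4.]])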
122 projects/nerf/tests/test_raysampler.py Normal file
@@ -0,0 +1,122 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

import unittest

import torch
from nerf.raysampler import NeRFRaysampler, ProbabilisticRaysampler
from pytorch3d.renderer import PerspectiveCameras
from pytorch3d.transforms.rotation_conversions import random_rotations


class TestRaysampler(unittest.TestCase):
    def setUp(self) -> None:
        torch.manual_seed(42)

    def test_raysampler_caching(self, batch_size=10):
        """
        Tests the consistency of the NeRF raysampler caching.
        """

        raysampler = NeRFRaysampler(
            min_x=0.0,
            max_x=10.0,
            min_y=0.0,
            max_y=10.0,
            n_pts_per_ray=10,
            min_depth=0.1,
            max_depth=10.0,
            n_rays_per_image=12,
            image_width=10,
            image_height=10,
            stratified=False,
            stratified_test=False,
            invert_directions=True,
        )

        raysampler.eval()

        cameras, rays = [], []

        for _ in range(batch_size):

            R = random_rotations(1)
            T = torch.randn(1, 3)
            focal_length = torch.rand(1, 2) + 0.5
            principal_point = torch.randn(1, 2)

            camera = PerspectiveCameras(
                focal_length=focal_length,
                principal_point=principal_point,
                R=R,
                T=T,
            )

            cameras.append(camera)
            rays.append(raysampler(camera))

        raysampler.precache_rays(cameras, list(range(batch_size)))

        for cam_index, rays_ in enumerate(rays):
            rays_cached_ = raysampler(
                cameras=cameras[cam_index],
                chunksize=None,
                chunk_idx=0,
                camera_hash=cam_index,
                caching=False,
            )

            for v, v_cached in zip(rays_, rays_cached_):
                self.assertTrue(torch.allclose(v, v_cached))

    def test_probabilistic_raysampler(self, batch_size=1, n_pts_per_ray=60):
        """
        Check that the probabilistic ray sampler does not crash for various
        settings.
        """

        raysampler_grid = NeRFRaysampler(
            min_x=0.0,
            max_x=10.0,
            min_y=0.0,
            max_y=10.0,
            n_pts_per_ray=n_pts_per_ray,
            min_depth=1.0,
            max_depth=10.0,
            n_rays_per_image=12,
            image_width=10,
            image_height=10,
            stratified=False,
            stratified_test=False,
            invert_directions=True,
        )

        R = random_rotations(batch_size)
        T = torch.randn(batch_size, 3)
        focal_length = torch.rand(batch_size, 2) + 0.5
        principal_point = torch.randn(batch_size, 2)
        camera = PerspectiveCameras(
            focal_length=focal_length,
            principal_point=principal_point,
            R=R,
            T=T,
        )

        raysampler_grid.eval()

        ray_bundle = raysampler_grid(cameras=camera)

        ray_weights = torch.rand_like(ray_bundle.lengths)

        # Just check that we don't crash for all possible settings.
        for stratified_test in (True, False):
            for stratified in (True, False):
                raysampler_prob = ProbabilisticRaysampler(
                    n_pts_per_ray=n_pts_per_ray,
                    stratified=stratified,
                    stratified_test=stratified_test,
                    add_input_samples=True,
                )
                for mode in ("train", "eval"):
                    getattr(raysampler_prob, mode)()
                    for _ in range(10):
                        raysampler_prob(ray_bundle, ray_weights)
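For completeness, a sketch (not part of this diff) of the evaluation-time pattern that the caching test above exercises: precache the full ray grid of a camera once, then draw it back in memory-sized chunks. The camera, hash string, and chunk size are illustrative assumptions, and only parameters documented in `NeRFRaysampler.__init__` above are used.

import torch
from pytorch3d.renderer import PerspectiveCameras
from nerf.raysampler import NeRFRaysampler

raysampler = NeRFRaysampler(
    n_pts_per_ray=32,
    min_depth=0.1,
    max_depth=10.0,
    n_rays_per_image=1024,
    image_width=64,
    image_height=64,
)
raysampler.eval()

cameras = [PerspectiveCameras(R=torch.eye(3)[None], T=torch.zeros(1, 3))]
raysampler.precache_rays(cameras, camera_hashes=["cam_0"])

chunksize = 512  # rays per forward pass, chosen to fit in memory
n_chunks = raysampler.get_n_chunks(chunksize, batch_size=1)  # 64*64 / 512 = 8
for chunk_idx in range(n_chunks):
    chunk_bundle = raysampler(
        cameras=cameras[0],
        chunksize=chunksize,
        chunk_idx=chunk_idx,
        camera_hash="cam_0",
    )
    # ... evaluate the radiance field on `chunk_bundle` here ...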