Raymarching

Summary: Implements two basic raymarchers.

Reviewed By: gkioxari

Differential Revision: D24064250

fbshipit-source-id: 18071bd039995336b7410caa403ea29fafb5c66f
This commit is contained in:
David Novotny 2021-01-05 03:58:22 -08:00 committed by Facebook GitHub Bot
parent aa9bcaf04c
commit 1af1a36bd6
5 changed files with 445 additions and 0 deletions

View File

@ -20,6 +20,7 @@ from .cameras import (
look_at_rotation,
look_at_view_transform,
)
from .implicit import AbsorptionOnlyRaymarcher, EmissionAbsorptionRaymarcher
from .lighting import DirectionalLights, PointLights, diffuse, specular
from .materials import Materials
from .mesh import (

View File

@ -0,0 +1,6 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from .raymarching import AbsorptionOnlyRaymarcher, EmissionAbsorptionRaymarcher
__all__ = [k for k in globals().keys() if not k.startswith("_")]

View File

@ -0,0 +1,223 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import warnings
from typing import Optional, Tuple, Union
import torch
class EmissionAbsorptionRaymarcher(torch.nn.Module):
"""
Raymarch using the Emission-Absorption (EA) algorithm.
The algorithm independently renders each ray by analyzing density and
feature values sampled at (typically uniformly) spaced 3D locations along
each ray. The density values `rays_densities` are of shape
`(..., n_points_per_ray)`, their values should range between [0, 1], and
represent the opaqueness of each point (the higher the less transparent).
The feature values `rays_features` of shape
`(..., n_points_per_ray, feature_dim)` represent the content of the
point that is supposed to be rendered in case the given point is opaque
(i.e. its density -> 1.0).
EA first utilizes `rays_densities` to compute the absorption function
along each ray as follows:
```
absorption = cumprod(1 - rays_densities, dim=-1)
```
The value of absorption at position `absorption[..., k]` specifies
how much light has reached `k`-th point along a ray since starting
its trajectory at `k=0`-th point.
Each ray is then rendered into a tensor `features` of shape `(..., feature_dim)`
by taking a weighed combination of per-ray features `rays_features` as follows:
```
weights = absorption * rays_densities
features = (rays_features * weights).sum(dim=-2)
```
Where `weights` denote a function that has a strong peak around the location
of the first surface point that a given ray passes through.
Note that for a perfectly bounded volume (with a strictly binary density),
the `weights = cumprod(1 - rays_densities, dim=-1) * rays_densities`
function would yield 0 everywhere. In order to prevent this,
the result of the cumulative product is shifted `self.surface_thickness`
elements along the ray direction.
"""
def __init__(self, surface_thickness: int = 1):
"""
Args:
surface_thickness: Denotes the overlap between the absorption
function and the density function.
"""
super().__init__()
self.surface_thickness = surface_thickness
def forward(
self,
rays_densities: torch.Tensor,
rays_features: torch.Tensor,
eps: float = 1e-10,
**kwargs,
) -> torch.Tensor:
"""
Args:
rays_densities: Per-ray density values represented with a tensor
of shape `(..., n_points_per_ray, 1)` whose values range in [0, 1].
rays_features: Per-ray feature values represented with a tensor
of shape `(..., n_points_per_ray, feature_dim)`.
eps: A lower bound added to `rays_densities` before computing
the absorbtion function (cumprod of `1-rays_densities` along
each ray). This prevents the cumprod to yield exact 0
which would inhibit any gradient-based learning.
Returns:
features_opacities: A tensor of shape `(..., feature_dim+1)`
that concatenates two tensors alonng the last dimension:
1) features: A tensor of per-ray renders
of shape `(..., feature_dim)`.
2) opacities: A tensor of per-ray opacity values
of shape `(..., 1)`. Its values range between [0, 1] and
denote the total amount of light that has been absorbed
for each ray. E.g. a value of 0 corresponds to the ray
completely passing through a volume. Please refer to the
`AbsorptionOnlyRaymarcher` documentation for the
explanation of the algorithm that computes `opacities`.
"""
_check_raymarcher_inputs(
rays_densities,
rays_features,
None,
z_can_be_none=True,
features_can_be_none=False,
density_1d=True,
)
_check_density_bounds(rays_densities)
rays_densities = rays_densities[..., 0]
absorption = _shifted_cumprod(
(1.0 + eps) - rays_densities, shift=self.surface_thickness
)
weights = rays_densities * absorption
features = (weights[..., None] * rays_features).sum(dim=-2)
opacities = 1.0 - torch.prod(1.0 - rays_densities, dim=-1, keepdim=True)
return torch.cat((features, opacities), dim=-1)
class AbsorptionOnlyRaymarcher(torch.nn.Module):
"""
Raymarch using the Absorption-Only (AO) algorithm.
The algorithm independently renders each ray by analyzing density and
feature values sampled at (typically uniformly) spaced 3D locations along
each ray. The density values `rays_densities` are of shape
`(..., n_points_per_ray, 1)`, their values should range between [0, 1], and
represent the opaqueness of each point (the higher the less transparent).
The algorithm only measures the total amount of light absorbed along each ray
and, besides outputting per-ray `opacity` values of shape `(...,)`,
does not produce any feature renderings.
The algorithm simply computes `total_transmission = prod(1 - rays_densities)`
of shape `(..., 1)` which, for each ray, measures the total amount of light
that passed through the volume.
It then returns `opacities = 1 - total_transmission`.
"""
def __init__(self):
super().__init__()
def forward(
self, rays_densities: torch.Tensor, **kwargs
) -> Union[None, torch.Tensor]:
"""
Args:
rays_densities: Per-ray density values represented with a tensor
of shape `(..., n_points_per_ray)` whose values range in [0, 1].
Returns:
opacities: A tensor of per-ray opacity values of shape `(..., 1)`.
Its values range between [0, 1] and denote the total amount
of light that has been absorbed for each ray. E.g. a value
of 0 corresponds to the ray completely passing through a volume.
"""
_check_raymarcher_inputs(
rays_densities,
None,
None,
features_can_be_none=True,
z_can_be_none=True,
density_1d=True,
)
rays_densities = rays_densities[..., 0]
_check_density_bounds(rays_densities)
total_transmission = torch.prod(1 - rays_densities, dim=-1, keepdim=True)
opacities = 1.0 - total_transmission
return opacities
def _shifted_cumprod(x, shift=1):
"""
Computes `torch.cumprod(x, dim=-1)` and prepends `shift` number of
ones and removes `shift` trailing elements to/from the last dimension
of the result.
"""
x_cumprod = torch.cumprod(x, dim=-1)
x_cumprod_shift = torch.cat(
[torch.ones_like(x_cumprod[..., :shift]), x_cumprod[..., :-shift]], dim=-1
)
return x_cumprod_shift
def _check_density_bounds(
rays_densities: torch.Tensor, bounds: Tuple[float, float] = (0.0, 1.0)
):
"""
Checks whether the elements of `rays_densities` range within `bounds`.
If not issues a warning.
"""
if ((rays_densities > bounds[1]) | (rays_densities < bounds[0])).any():
warnings.warn(
"One or more elements of rays_densities are outside of valid"
+ f"range {str(bounds)}"
)
def _check_raymarcher_inputs(
rays_densities: torch.Tensor,
rays_features: Optional[torch.Tensor],
rays_z: Optional[torch.Tensor],
features_can_be_none: bool = False,
z_can_be_none: bool = False,
density_1d: bool = True,
):
"""
Checks the validity of the inputs to raymarching algorithms.
"""
if not torch.is_tensor(rays_densities):
raise ValueError("rays_densities has to be an instance of torch.Tensor.")
if not z_can_be_none and not torch.is_tensor(rays_z):
raise ValueError("rays_z has to be an instance of torch.Tensor.")
if not features_can_be_none and not torch.is_tensor(rays_features):
raise ValueError("rays_features has to be an instance of torch.Tensor.")
if rays_densities.ndim < 1:
raise ValueError("rays_densities have to have at least one dimension.")
if density_1d and rays_densities.shape[-1] != 1:
raise ValueError(
"The size of the last dimension of rays_densities has to be one."
)
rays_shape = rays_densities.shape[:-1]
if not z_can_be_none and rays_z.shape != rays_shape:
raise ValueError("rays_z have to be of the same shape as rays_densities.")
if not features_can_be_none and rays_features.shape[:-1] != rays_shape:
raise ValueError(
"The first to previous to last dimensions of rays_features"
" have to be the same as all dimensions of rays_densities."
)

19
tests/bm_raymarching.py Normal file
View File

@ -0,0 +1,19 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import itertools
from fvcore.common.benchmark import benchmark
from pytorch3d.renderer import AbsorptionOnlyRaymarcher, EmissionAbsorptionRaymarcher
from test_raymarching import TestRaymarching
def bm_raymarching() -> None:
case_grid = {
"raymarcher_type": [EmissionAbsorptionRaymarcher, AbsorptionOnlyRaymarcher],
"n_rays": [10, 1000, 10000],
"n_pts_per_ray": [10, 1000, 10000],
}
test_cases = itertools.product(*case_grid.values())
kwargs_list = [dict(zip(case_grid.keys(), case)) for case in test_cases]
benchmark(TestRaymarching.raymarcher, "RAYMARCHER", kwargs_list, warmup_iters=1)

196
tests/test_raymarching.py Normal file
View File

@ -0,0 +1,196 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import unittest
import torch
from common_testing import TestCaseMixin
from pytorch3d.renderer import AbsorptionOnlyRaymarcher, EmissionAbsorptionRaymarcher
class TestRaymarching(TestCaseMixin, unittest.TestCase):
def setUp(self) -> None:
torch.manual_seed(42)
@staticmethod
def _init_random_rays(
n_rays=10, n_pts_per_ray=9, device="cuda", dtype=torch.float32
):
"""
Generate a batch of ray points with features, densities, and z-coodinates
such that their EmissionAbsorption renderring results in
feature renders `features_gt`, depth renders `depths_gt`,
and opacity renders `opacities_gt`.
"""
# generate trivial ray z-coordinates of sampled points coinciding with
# each point's order along a ray.
rays_z = torch.arange(n_pts_per_ray, dtype=dtype, device=device)[None].repeat(
n_rays, 1
)
# generate ground truth depth values of the underlying surface.
depths_gt = torch.randint(
low=1, high=n_pts_per_ray + 2, size=(n_rays,)
).type_as(rays_z)
# compute ideal densities that are 0 before the surface and 1 after
# the corresponding ground truth depth value
rays_densities = (rays_z >= depths_gt[..., None]).type_as(rays_z)[..., None]
opacities_gt = (depths_gt < n_pts_per_ray).type_as(rays_z)
# generate random per-ray features
rays_features = torch.rand(
(n_rays, n_pts_per_ray, 3), device=rays_z.device, dtype=rays_z.dtype
)
# infer the expected feature render "features_gt"
gt_surface = ((rays_z - depths_gt[..., None]).abs() <= 1e-4).type_as(rays_z)
features_gt = (rays_features * gt_surface[..., None]).sum(dim=-2)
return (
rays_z,
rays_densities,
rays_features,
depths_gt,
features_gt,
opacities_gt,
)
@staticmethod
def raymarcher(
raymarcher_type=EmissionAbsorptionRaymarcher, n_rays=10, n_pts_per_ray=10
):
(
rays_z,
rays_densities,
rays_features,
depths_gt,
features_gt,
opacities_gt,
) = TestRaymarching._init_random_rays(
n_rays=n_rays, n_pts_per_ray=n_pts_per_ray
)
raymarcher = raymarcher_type()
def run_raymarcher():
raymarcher(
rays_densities=rays_densities,
rays_features=rays_features,
rays_z=rays_z,
)
torch.cuda.synchronize()
return run_raymarcher
def test_emission_absorption_inputs(self):
"""
Test the checks of validity of the inputs to `EmissionAbsorptionRaymarcher`.
"""
# init the EA raymarcher
raymarcher_ea = EmissionAbsorptionRaymarcher()
# bad ways of passing densities and features
# [rays_densities, rays_features, rays_z]
bad_inputs = [
[torch.rand(10, 5, 4), None],
[torch.Tensor(3)[0], torch.rand(10, 5, 4)],
[1.0, torch.rand(10, 5, 4)],
[torch.rand(10, 5, 4), 1.0],
[torch.rand(10, 5, 4), None],
[torch.rand(10, 5, 4), torch.rand(10, 5, 4)],
[torch.rand(10, 5, 4), torch.rand(10, 5, 4, 3)],
[torch.rand(10, 5, 4, 3), torch.rand(10, 5, 4, 3)],
]
for bad_input in bad_inputs:
with self.assertRaises(ValueError):
raymarcher_ea(*bad_input)
def test_absorption_only_inputs(self):
"""
Test the checks of validity of the inputs to `AbsorptionOnlyRaymarcher`.
"""
# init the AO raymarcher
raymarcher_ao = AbsorptionOnlyRaymarcher()
# bad ways of passing densities and features
# [rays_densities, rays_features, rays_z]
bad_inputs = [[torch.Tensor(3)[0]]]
for bad_input in bad_inputs:
with self.assertRaises(ValueError):
raymarcher_ao(*bad_input)
def test_emission_absorption(self):
"""
Test the EA raymarching algorithm.
"""
(
rays_z,
rays_densities,
rays_features,
depths_gt,
features_gt,
opacities_gt,
) = TestRaymarching._init_random_rays(
n_rays=1000, n_pts_per_ray=9, device=None, dtype=torch.float32
)
# init the EA raymarcher
raymarcher_ea = EmissionAbsorptionRaymarcher()
# allow gradients for a differentiability check
rays_densities.requires_grad = True
rays_features.requires_grad = True
# render the features first and check with gt
data_render = raymarcher_ea(rays_densities, rays_features)
features_render, opacities_render = data_render[..., :-1], data_render[..., -1]
self.assertClose(opacities_render, opacities_gt)
self.assertClose(
features_render * opacities_render[..., None],
features_gt * opacities_gt[..., None],
)
# get the depth map by rendering the ray z components and check with gt
depths_render = raymarcher_ea(rays_densities, rays_z[..., None])[..., 0]
self.assertClose(depths_render * opacities_render, depths_gt * opacities_gt)
# check differentiability
loss = features_render.mean()
loss.backward()
for field in (rays_densities, rays_features):
self.assertTrue(field.grad.data.isfinite().all())
def test_absorption_only(self):
"""
Test the AO raymarching algorithm.
"""
(
rays_z,
rays_densities,
rays_features,
depths_gt,
features_gt,
opacities_gt,
) = TestRaymarching._init_random_rays(
n_rays=1000, n_pts_per_ray=9, device=None, dtype=torch.float32
)
# init the AO raymarcher
raymarcher_ao = AbsorptionOnlyRaymarcher()
# allow gradients for a differentiability check
rays_densities.requires_grad = True
# render opacities, check with gt and check that returned features are None
opacities_render = raymarcher_ao(rays_densities)[..., 0]
self.assertClose(opacities_render, opacities_gt)
# check differentiability
loss = opacities_render.mean()
loss.backward()
self.assertTrue(rays_densities.grad.data.isfinite().all())