mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
Raymarching
Summary: Implements two basic raymarchers. Reviewed By: gkioxari Differential Revision: D24064250 fbshipit-source-id: 18071bd039995336b7410caa403ea29fafb5c66f
This commit is contained in:
parent
aa9bcaf04c
commit
1af1a36bd6
@ -20,6 +20,7 @@ from .cameras import (
|
||||
look_at_rotation,
|
||||
look_at_view_transform,
|
||||
)
|
||||
from .implicit import AbsorptionOnlyRaymarcher, EmissionAbsorptionRaymarcher
|
||||
from .lighting import DirectionalLights, PointLights, diffuse, specular
|
||||
from .materials import Materials
|
||||
from .mesh import (
|
||||
|
6
pytorch3d/renderer/implicit/__init__.py
Normal file
6
pytorch3d/renderer/implicit/__init__.py
Normal file
@ -0,0 +1,6 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
from .raymarching import AbsorptionOnlyRaymarcher, EmissionAbsorptionRaymarcher
|
||||
|
||||
|
||||
__all__ = [k for k in globals().keys() if not k.startswith("_")]
|
223
pytorch3d/renderer/implicit/raymarching.py
Normal file
223
pytorch3d/renderer/implicit/raymarching.py
Normal file
@ -0,0 +1,223 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
import warnings
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class EmissionAbsorptionRaymarcher(torch.nn.Module):
|
||||
"""
|
||||
Raymarch using the Emission-Absorption (EA) algorithm.
|
||||
|
||||
The algorithm independently renders each ray by analyzing density and
|
||||
feature values sampled at (typically uniformly) spaced 3D locations along
|
||||
each ray. The density values `rays_densities` are of shape
|
||||
`(..., n_points_per_ray)`, their values should range between [0, 1], and
|
||||
represent the opaqueness of each point (the higher the less transparent).
|
||||
The feature values `rays_features` of shape
|
||||
`(..., n_points_per_ray, feature_dim)` represent the content of the
|
||||
point that is supposed to be rendered in case the given point is opaque
|
||||
(i.e. its density -> 1.0).
|
||||
|
||||
EA first utilizes `rays_densities` to compute the absorption function
|
||||
along each ray as follows:
|
||||
```
|
||||
absorption = cumprod(1 - rays_densities, dim=-1)
|
||||
```
|
||||
The value of absorption at position `absorption[..., k]` specifies
|
||||
how much light has reached `k`-th point along a ray since starting
|
||||
its trajectory at `k=0`-th point.
|
||||
|
||||
Each ray is then rendered into a tensor `features` of shape `(..., feature_dim)`
|
||||
by taking a weighed combination of per-ray features `rays_features` as follows:
|
||||
```
|
||||
weights = absorption * rays_densities
|
||||
features = (rays_features * weights).sum(dim=-2)
|
||||
```
|
||||
Where `weights` denote a function that has a strong peak around the location
|
||||
of the first surface point that a given ray passes through.
|
||||
|
||||
Note that for a perfectly bounded volume (with a strictly binary density),
|
||||
the `weights = cumprod(1 - rays_densities, dim=-1) * rays_densities`
|
||||
function would yield 0 everywhere. In order to prevent this,
|
||||
the result of the cumulative product is shifted `self.surface_thickness`
|
||||
elements along the ray direction.
|
||||
"""
|
||||
|
||||
def __init__(self, surface_thickness: int = 1):
|
||||
"""
|
||||
Args:
|
||||
surface_thickness: Denotes the overlap between the absorption
|
||||
function and the density function.
|
||||
"""
|
||||
super().__init__()
|
||||
self.surface_thickness = surface_thickness
|
||||
|
||||
def forward(
|
||||
self,
|
||||
rays_densities: torch.Tensor,
|
||||
rays_features: torch.Tensor,
|
||||
eps: float = 1e-10,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
rays_densities: Per-ray density values represented with a tensor
|
||||
of shape `(..., n_points_per_ray, 1)` whose values range in [0, 1].
|
||||
rays_features: Per-ray feature values represented with a tensor
|
||||
of shape `(..., n_points_per_ray, feature_dim)`.
|
||||
eps: A lower bound added to `rays_densities` before computing
|
||||
the absorbtion function (cumprod of `1-rays_densities` along
|
||||
each ray). This prevents the cumprod to yield exact 0
|
||||
which would inhibit any gradient-based learning.
|
||||
|
||||
Returns:
|
||||
features_opacities: A tensor of shape `(..., feature_dim+1)`
|
||||
that concatenates two tensors alonng the last dimension:
|
||||
1) features: A tensor of per-ray renders
|
||||
of shape `(..., feature_dim)`.
|
||||
2) opacities: A tensor of per-ray opacity values
|
||||
of shape `(..., 1)`. Its values range between [0, 1] and
|
||||
denote the total amount of light that has been absorbed
|
||||
for each ray. E.g. a value of 0 corresponds to the ray
|
||||
completely passing through a volume. Please refer to the
|
||||
`AbsorptionOnlyRaymarcher` documentation for the
|
||||
explanation of the algorithm that computes `opacities`.
|
||||
"""
|
||||
_check_raymarcher_inputs(
|
||||
rays_densities,
|
||||
rays_features,
|
||||
None,
|
||||
z_can_be_none=True,
|
||||
features_can_be_none=False,
|
||||
density_1d=True,
|
||||
)
|
||||
_check_density_bounds(rays_densities)
|
||||
rays_densities = rays_densities[..., 0]
|
||||
absorption = _shifted_cumprod(
|
||||
(1.0 + eps) - rays_densities, shift=self.surface_thickness
|
||||
)
|
||||
weights = rays_densities * absorption
|
||||
features = (weights[..., None] * rays_features).sum(dim=-2)
|
||||
opacities = 1.0 - torch.prod(1.0 - rays_densities, dim=-1, keepdim=True)
|
||||
|
||||
return torch.cat((features, opacities), dim=-1)
|
||||
|
||||
|
||||
class AbsorptionOnlyRaymarcher(torch.nn.Module):
|
||||
"""
|
||||
Raymarch using the Absorption-Only (AO) algorithm.
|
||||
|
||||
The algorithm independently renders each ray by analyzing density and
|
||||
feature values sampled at (typically uniformly) spaced 3D locations along
|
||||
each ray. The density values `rays_densities` are of shape
|
||||
`(..., n_points_per_ray, 1)`, their values should range between [0, 1], and
|
||||
represent the opaqueness of each point (the higher the less transparent).
|
||||
The algorithm only measures the total amount of light absorbed along each ray
|
||||
and, besides outputting per-ray `opacity` values of shape `(...,)`,
|
||||
does not produce any feature renderings.
|
||||
|
||||
The algorithm simply computes `total_transmission = prod(1 - rays_densities)`
|
||||
of shape `(..., 1)` which, for each ray, measures the total amount of light
|
||||
that passed through the volume.
|
||||
It then returns `opacities = 1 - total_transmission`.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(
|
||||
self, rays_densities: torch.Tensor, **kwargs
|
||||
) -> Union[None, torch.Tensor]:
|
||||
"""
|
||||
Args:
|
||||
rays_densities: Per-ray density values represented with a tensor
|
||||
of shape `(..., n_points_per_ray)` whose values range in [0, 1].
|
||||
|
||||
Returns:
|
||||
opacities: A tensor of per-ray opacity values of shape `(..., 1)`.
|
||||
Its values range between [0, 1] and denote the total amount
|
||||
of light that has been absorbed for each ray. E.g. a value
|
||||
of 0 corresponds to the ray completely passing through a volume.
|
||||
"""
|
||||
|
||||
_check_raymarcher_inputs(
|
||||
rays_densities,
|
||||
None,
|
||||
None,
|
||||
features_can_be_none=True,
|
||||
z_can_be_none=True,
|
||||
density_1d=True,
|
||||
)
|
||||
rays_densities = rays_densities[..., 0]
|
||||
_check_density_bounds(rays_densities)
|
||||
total_transmission = torch.prod(1 - rays_densities, dim=-1, keepdim=True)
|
||||
opacities = 1.0 - total_transmission
|
||||
return opacities
|
||||
|
||||
|
||||
def _shifted_cumprod(x, shift=1):
|
||||
"""
|
||||
Computes `torch.cumprod(x, dim=-1)` and prepends `shift` number of
|
||||
ones and removes `shift` trailing elements to/from the last dimension
|
||||
of the result.
|
||||
"""
|
||||
x_cumprod = torch.cumprod(x, dim=-1)
|
||||
x_cumprod_shift = torch.cat(
|
||||
[torch.ones_like(x_cumprod[..., :shift]), x_cumprod[..., :-shift]], dim=-1
|
||||
)
|
||||
return x_cumprod_shift
|
||||
|
||||
|
||||
def _check_density_bounds(
|
||||
rays_densities: torch.Tensor, bounds: Tuple[float, float] = (0.0, 1.0)
|
||||
):
|
||||
"""
|
||||
Checks whether the elements of `rays_densities` range within `bounds`.
|
||||
If not issues a warning.
|
||||
"""
|
||||
if ((rays_densities > bounds[1]) | (rays_densities < bounds[0])).any():
|
||||
warnings.warn(
|
||||
"One or more elements of rays_densities are outside of valid"
|
||||
+ f"range {str(bounds)}"
|
||||
)
|
||||
|
||||
|
||||
def _check_raymarcher_inputs(
|
||||
rays_densities: torch.Tensor,
|
||||
rays_features: Optional[torch.Tensor],
|
||||
rays_z: Optional[torch.Tensor],
|
||||
features_can_be_none: bool = False,
|
||||
z_can_be_none: bool = False,
|
||||
density_1d: bool = True,
|
||||
):
|
||||
"""
|
||||
Checks the validity of the inputs to raymarching algorithms.
|
||||
"""
|
||||
if not torch.is_tensor(rays_densities):
|
||||
raise ValueError("rays_densities has to be an instance of torch.Tensor.")
|
||||
|
||||
if not z_can_be_none and not torch.is_tensor(rays_z):
|
||||
raise ValueError("rays_z has to be an instance of torch.Tensor.")
|
||||
|
||||
if not features_can_be_none and not torch.is_tensor(rays_features):
|
||||
raise ValueError("rays_features has to be an instance of torch.Tensor.")
|
||||
|
||||
if rays_densities.ndim < 1:
|
||||
raise ValueError("rays_densities have to have at least one dimension.")
|
||||
|
||||
if density_1d and rays_densities.shape[-1] != 1:
|
||||
raise ValueError(
|
||||
"The size of the last dimension of rays_densities has to be one."
|
||||
)
|
||||
|
||||
rays_shape = rays_densities.shape[:-1]
|
||||
|
||||
if not z_can_be_none and rays_z.shape != rays_shape:
|
||||
raise ValueError("rays_z have to be of the same shape as rays_densities.")
|
||||
|
||||
if not features_can_be_none and rays_features.shape[:-1] != rays_shape:
|
||||
raise ValueError(
|
||||
"The first to previous to last dimensions of rays_features"
|
||||
" have to be the same as all dimensions of rays_densities."
|
||||
)
|
19
tests/bm_raymarching.py
Normal file
19
tests/bm_raymarching.py
Normal file
@ -0,0 +1,19 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
import itertools
|
||||
|
||||
from fvcore.common.benchmark import benchmark
|
||||
from pytorch3d.renderer import AbsorptionOnlyRaymarcher, EmissionAbsorptionRaymarcher
|
||||
from test_raymarching import TestRaymarching
|
||||
|
||||
|
||||
def bm_raymarching() -> None:
|
||||
case_grid = {
|
||||
"raymarcher_type": [EmissionAbsorptionRaymarcher, AbsorptionOnlyRaymarcher],
|
||||
"n_rays": [10, 1000, 10000],
|
||||
"n_pts_per_ray": [10, 1000, 10000],
|
||||
}
|
||||
test_cases = itertools.product(*case_grid.values())
|
||||
kwargs_list = [dict(zip(case_grid.keys(), case)) for case in test_cases]
|
||||
|
||||
benchmark(TestRaymarching.raymarcher, "RAYMARCHER", kwargs_list, warmup_iters=1)
|
196
tests/test_raymarching.py
Normal file
196
tests/test_raymarching.py
Normal file
@ -0,0 +1,196 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
from common_testing import TestCaseMixin
|
||||
from pytorch3d.renderer import AbsorptionOnlyRaymarcher, EmissionAbsorptionRaymarcher
|
||||
|
||||
|
||||
class TestRaymarching(TestCaseMixin, unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
torch.manual_seed(42)
|
||||
|
||||
@staticmethod
|
||||
def _init_random_rays(
|
||||
n_rays=10, n_pts_per_ray=9, device="cuda", dtype=torch.float32
|
||||
):
|
||||
"""
|
||||
Generate a batch of ray points with features, densities, and z-coodinates
|
||||
such that their EmissionAbsorption renderring results in
|
||||
feature renders `features_gt`, depth renders `depths_gt`,
|
||||
and opacity renders `opacities_gt`.
|
||||
"""
|
||||
|
||||
# generate trivial ray z-coordinates of sampled points coinciding with
|
||||
# each point's order along a ray.
|
||||
rays_z = torch.arange(n_pts_per_ray, dtype=dtype, device=device)[None].repeat(
|
||||
n_rays, 1
|
||||
)
|
||||
|
||||
# generate ground truth depth values of the underlying surface.
|
||||
depths_gt = torch.randint(
|
||||
low=1, high=n_pts_per_ray + 2, size=(n_rays,)
|
||||
).type_as(rays_z)
|
||||
|
||||
# compute ideal densities that are 0 before the surface and 1 after
|
||||
# the corresponding ground truth depth value
|
||||
rays_densities = (rays_z >= depths_gt[..., None]).type_as(rays_z)[..., None]
|
||||
opacities_gt = (depths_gt < n_pts_per_ray).type_as(rays_z)
|
||||
|
||||
# generate random per-ray features
|
||||
rays_features = torch.rand(
|
||||
(n_rays, n_pts_per_ray, 3), device=rays_z.device, dtype=rays_z.dtype
|
||||
)
|
||||
|
||||
# infer the expected feature render "features_gt"
|
||||
gt_surface = ((rays_z - depths_gt[..., None]).abs() <= 1e-4).type_as(rays_z)
|
||||
features_gt = (rays_features * gt_surface[..., None]).sum(dim=-2)
|
||||
|
||||
return (
|
||||
rays_z,
|
||||
rays_densities,
|
||||
rays_features,
|
||||
depths_gt,
|
||||
features_gt,
|
||||
opacities_gt,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def raymarcher(
|
||||
raymarcher_type=EmissionAbsorptionRaymarcher, n_rays=10, n_pts_per_ray=10
|
||||
):
|
||||
(
|
||||
rays_z,
|
||||
rays_densities,
|
||||
rays_features,
|
||||
depths_gt,
|
||||
features_gt,
|
||||
opacities_gt,
|
||||
) = TestRaymarching._init_random_rays(
|
||||
n_rays=n_rays, n_pts_per_ray=n_pts_per_ray
|
||||
)
|
||||
|
||||
raymarcher = raymarcher_type()
|
||||
|
||||
def run_raymarcher():
|
||||
raymarcher(
|
||||
rays_densities=rays_densities,
|
||||
rays_features=rays_features,
|
||||
rays_z=rays_z,
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
return run_raymarcher
|
||||
|
||||
def test_emission_absorption_inputs(self):
|
||||
"""
|
||||
Test the checks of validity of the inputs to `EmissionAbsorptionRaymarcher`.
|
||||
"""
|
||||
|
||||
# init the EA raymarcher
|
||||
raymarcher_ea = EmissionAbsorptionRaymarcher()
|
||||
|
||||
# bad ways of passing densities and features
|
||||
# [rays_densities, rays_features, rays_z]
|
||||
bad_inputs = [
|
||||
[torch.rand(10, 5, 4), None],
|
||||
[torch.Tensor(3)[0], torch.rand(10, 5, 4)],
|
||||
[1.0, torch.rand(10, 5, 4)],
|
||||
[torch.rand(10, 5, 4), 1.0],
|
||||
[torch.rand(10, 5, 4), None],
|
||||
[torch.rand(10, 5, 4), torch.rand(10, 5, 4)],
|
||||
[torch.rand(10, 5, 4), torch.rand(10, 5, 4, 3)],
|
||||
[torch.rand(10, 5, 4, 3), torch.rand(10, 5, 4, 3)],
|
||||
]
|
||||
|
||||
for bad_input in bad_inputs:
|
||||
with self.assertRaises(ValueError):
|
||||
raymarcher_ea(*bad_input)
|
||||
|
||||
def test_absorption_only_inputs(self):
|
||||
"""
|
||||
Test the checks of validity of the inputs to `AbsorptionOnlyRaymarcher`.
|
||||
"""
|
||||
|
||||
# init the AO raymarcher
|
||||
raymarcher_ao = AbsorptionOnlyRaymarcher()
|
||||
|
||||
# bad ways of passing densities and features
|
||||
# [rays_densities, rays_features, rays_z]
|
||||
bad_inputs = [[torch.Tensor(3)[0]]]
|
||||
|
||||
for bad_input in bad_inputs:
|
||||
with self.assertRaises(ValueError):
|
||||
raymarcher_ao(*bad_input)
|
||||
|
||||
def test_emission_absorption(self):
|
||||
"""
|
||||
Test the EA raymarching algorithm.
|
||||
"""
|
||||
(
|
||||
rays_z,
|
||||
rays_densities,
|
||||
rays_features,
|
||||
depths_gt,
|
||||
features_gt,
|
||||
opacities_gt,
|
||||
) = TestRaymarching._init_random_rays(
|
||||
n_rays=1000, n_pts_per_ray=9, device=None, dtype=torch.float32
|
||||
)
|
||||
|
||||
# init the EA raymarcher
|
||||
raymarcher_ea = EmissionAbsorptionRaymarcher()
|
||||
|
||||
# allow gradients for a differentiability check
|
||||
rays_densities.requires_grad = True
|
||||
rays_features.requires_grad = True
|
||||
|
||||
# render the features first and check with gt
|
||||
data_render = raymarcher_ea(rays_densities, rays_features)
|
||||
features_render, opacities_render = data_render[..., :-1], data_render[..., -1]
|
||||
self.assertClose(opacities_render, opacities_gt)
|
||||
self.assertClose(
|
||||
features_render * opacities_render[..., None],
|
||||
features_gt * opacities_gt[..., None],
|
||||
)
|
||||
|
||||
# get the depth map by rendering the ray z components and check with gt
|
||||
depths_render = raymarcher_ea(rays_densities, rays_z[..., None])[..., 0]
|
||||
self.assertClose(depths_render * opacities_render, depths_gt * opacities_gt)
|
||||
|
||||
# check differentiability
|
||||
loss = features_render.mean()
|
||||
loss.backward()
|
||||
for field in (rays_densities, rays_features):
|
||||
self.assertTrue(field.grad.data.isfinite().all())
|
||||
|
||||
def test_absorption_only(self):
|
||||
"""
|
||||
Test the AO raymarching algorithm.
|
||||
"""
|
||||
(
|
||||
rays_z,
|
||||
rays_densities,
|
||||
rays_features,
|
||||
depths_gt,
|
||||
features_gt,
|
||||
opacities_gt,
|
||||
) = TestRaymarching._init_random_rays(
|
||||
n_rays=1000, n_pts_per_ray=9, device=None, dtype=torch.float32
|
||||
)
|
||||
|
||||
# init the AO raymarcher
|
||||
raymarcher_ao = AbsorptionOnlyRaymarcher()
|
||||
|
||||
# allow gradients for a differentiability check
|
||||
rays_densities.requires_grad = True
|
||||
|
||||
# render opacities, check with gt and check that returned features are None
|
||||
opacities_render = raymarcher_ao(rays_densities)[..., 0]
|
||||
self.assertClose(opacities_render, opacities_gt)
|
||||
|
||||
# check differentiability
|
||||
loss = opacities_render.mean()
|
||||
loss.backward()
|
||||
self.assertTrue(rays_densities.grad.data.isfinite().all())
|
Loading…
x
Reference in New Issue
Block a user