Generic Raymarcher refactor

Summary: Uses the GenericRaymarcher only as an ABC and derives two common implementations - EA raymarcher and Cumsum raymarcher (from neural volumes)

Reviewed By: shapovalov

Differential Revision: D35927653

fbshipit-source-id: f7e6776e71f8a4e99eefc018a47f29ae769895ee
This commit is contained in:
David Novotny 2022-05-12 14:57:50 -07:00 committed by Facebook GitHub Bot
parent 47d06c8924
commit e85fa03c5a
3 changed files with 163 additions and 75 deletions

View File

@ -47,6 +47,7 @@ class RendererOutput:
prev_stage: Optional[RendererOutput] = None prev_stage: Optional[RendererOutput] = None
normals: Optional[torch.Tensor] = None normals: Optional[torch.Tensor] = None
points: Optional[torch.Tensor] = None # TODO: redundant with depths points: Optional[torch.Tensor] = None # TODO: redundant with depths
weights: Optional[torch.Tensor] = None
aux: Dict[str, Any] = field(default_factory=lambda: {}) aux: Dict[str, Any] = field(default_factory=lambda: {})

View File

@ -4,18 +4,22 @@
# This source code is licensed under the BSD-style license found in the # This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
from typing import Tuple from typing import List
import torch import torch
from pytorch3d.implicitron.tools.config import registry from pytorch3d.implicitron.models.renderer.base import ImplicitFunctionWrapper
from pytorch3d.implicitron.tools.config import registry, run_auto_creation
from pytorch3d.renderer import RayBundle
from .base import BaseRenderer, EvaluationMode, RendererOutput from .base import BaseRenderer, EvaluationMode, RendererOutput
from .ray_point_refiner import RayPointRefiner from .ray_point_refiner import RayPointRefiner
from .raymarcher import GenericRaymarcher from .raymarcher import RaymarcherBase
@registry.register @registry.register
class MultiPassEmissionAbsorptionRenderer(BaseRenderer, torch.nn.Module): class MultiPassEmissionAbsorptionRenderer( # pyre-ignore: 13
BaseRenderer, torch.nn.Module
):
""" """
Implements the multi-pass rendering function, in particular, Implements the multi-pass rendering function, in particular,
with emission-absorption ray marching used in NeRF [1]. First, it evaluates with emission-absorption ray marching used in NeRF [1]. First, it evaluates
@ -33,7 +37,17 @@ class MultiPassEmissionAbsorptionRenderer(BaseRenderer, torch.nn.Module):
``` ```
and the final rendered quantities are computed by a dot-product of ray values and the final rendered quantities are computed by a dot-product of ray values
with the weights, e.g. `features = sum_n(weight_n * ray_features_n)`. with the weights, e.g. `features = sum_n(weight_n * ray_features_n)`.
See below for possible values of `cap_fn` and `weight_fn`.
By default, for the EA raymarcher from [1] (
activated with `self.raymarcher_class_type="EmissionAbsorptionRaymarcher"`
):
```
cap_fn(x) = 1 - exp(-x),
weight_fn(w, x) = w * x.
```
Note that the latter can be altered by changing `self.raymarcher_class_type`,
e.g. to "CumsumRaymarcher" which implements the cumulative-sum raymarcher
from NeuralVolumes [2].
Settings: Settings:
n_pts_per_ray_fine_training: The number of points sampled per ray for the n_pts_per_ray_fine_training: The number of points sampled per ray for the
@ -46,42 +60,33 @@ class MultiPassEmissionAbsorptionRenderer(BaseRenderer, torch.nn.Module):
evaluation. evaluation.
append_coarse_samples_to_fine: Add the fine ray points to the coarse points append_coarse_samples_to_fine: Add the fine ray points to the coarse points
after sampling. after sampling.
bg_color: The background color. A tuple of either 1 element or of D elements,
where D matches the feature dimensionality; it is broadcasted when necessary.
density_noise_std_train: Standard deviation of the noise added to the density_noise_std_train: Standard deviation of the noise added to the
opacity field. opacity field.
capping_function: The capping function of the raymarcher. return_weights: Enables returning the rendering weights of the EA raymarcher.
Options: Setting to `True` can lead to a prohibitively large memory consumption.
- "exponential" (`cap_fn(x) = 1 - exp(-x)`) raymarcher_class_type: The type of self.raymarcher corresponding to
- "cap1" (`cap_fn(x) = min(x, 1)`) a child of `RaymarcherBase` in the registry.
Set to "exponential" for the standard Emission Absorption raymarching. raymarcher: The raymarcher object used to convert per-point features
weight_function: The weighting function of the raymarcher. and opacities to a feature render.
Options:
- "product" (`weight_fn(w, x) = w * x`)
- "minimum" (`weight_fn(w, x) = min(w, x)`)
Set to "product" for the standard Emission Absorption raymarching.
background_opacity: The raw opacity value (i.e. before exponentiation)
of the background.
blend_output: If `True`, alpha-blends the output renders with the
background color using the rendered opacity mask.
References: References:
[1] Mildenhall, Ben, et al. "Nerf: Representing scenes as neural radiance [1] Mildenhall, Ben, et al. "Nerf: Representing Scenes as Neural Radiance
fields for view synthesis." ECCV 2020. Fields for View Synthesis." ECCV 2020.
[2] Lombardi, Stephen, et al. "Neural Volumes: Learning Dynamic Renderable
Volumes from Images." SIGGRAPH 2019.
""" """
raymarcher_class_type: str = "EmissionAbsorptionRaymarcher"
raymarcher: RaymarcherBase
n_pts_per_ray_fine_training: int = 64 n_pts_per_ray_fine_training: int = 64
n_pts_per_ray_fine_evaluation: int = 64 n_pts_per_ray_fine_evaluation: int = 64
stratified_sampling_coarse_training: bool = True stratified_sampling_coarse_training: bool = True
stratified_sampling_coarse_evaluation: bool = False stratified_sampling_coarse_evaluation: bool = False
append_coarse_samples_to_fine: bool = True append_coarse_samples_to_fine: bool = True
bg_color: Tuple[float, ...] = (0.0,)
density_noise_std_train: float = 0.0 density_noise_std_train: float = 0.0
capping_function: str = "exponential" # exponential | cap1 return_weights: bool = False
weight_function: str = "product" # product | minimum
background_opacity: float = 1e10
blend_output: bool = False
def __post_init__(self): def __post_init__(self):
super().__init__() super().__init__()
@ -97,20 +102,12 @@ class MultiPassEmissionAbsorptionRenderer(BaseRenderer, torch.nn.Module):
add_input_samples=self.append_coarse_samples_to_fine, add_input_samples=self.append_coarse_samples_to_fine,
), ),
} }
run_auto_creation(self)
self._raymarcher = GenericRaymarcher(
1,
self.bg_color,
capping_function=self.capping_function,
weight_function=self.weight_function,
background_opacity=self.background_opacity,
blend_output=self.blend_output,
)
def forward( def forward(
self, self,
ray_bundle, ray_bundle: RayBundle,
implicit_functions=[], implicit_functions: List[ImplicitFunctionWrapper] = [],
evaluation_mode: EvaluationMode = EvaluationMode.EVALUATION, evaluation_mode: EvaluationMode = EvaluationMode.EVALUATION,
**kwargs **kwargs
) -> RendererOutput: ) -> RendererOutput:
@ -149,14 +146,16 @@ class MultiPassEmissionAbsorptionRenderer(BaseRenderer, torch.nn.Module):
else 0.0 else 0.0
) )
features, depth, mask, weights, aux = self._raymarcher( output = self.raymarcher(
*implicit_functions[0](ray_bundle), *implicit_functions[0](ray_bundle),
ray_lengths=ray_bundle.lengths, ray_lengths=ray_bundle.lengths,
density_noise_std=density_noise_std, density_noise_std=density_noise_std,
) )
output = RendererOutput( output.prev_stage = prev_stage
features=features, depths=depth, masks=mask, aux=aux, prev_stage=prev_stage
) weights = output.weights
if not self.return_weights:
output.weights = None
# we may need to make a recursive call # we may need to make a recursive call
if len(implicit_functions) > 1: if len(implicit_functions) > 1:

View File

@ -4,51 +4,99 @@
# This source code is licensed under the BSD-style license found in the # This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
from typing import Any, Callable, Dict, Tuple, Union from typing import Any, Callable, Dict, Tuple
import torch import torch
from pytorch3d.implicitron.models.renderer.base import RendererOutput
from pytorch3d.implicitron.tools.config import ReplaceableBase, registry
from pytorch3d.renderer.implicit.raymarching import _check_raymarcher_inputs from pytorch3d.renderer.implicit.raymarching import _check_raymarcher_inputs
_TTensor = torch.Tensor _TTensor = torch.Tensor
class GenericRaymarcher(torch.nn.Module): class RaymarcherBase(ReplaceableBase):
"""
Defines a base class for raymarchers. Specifically, a raymarcher is responsible
for taking a set of features and density descriptors along rendering rays
and marching along them in order to generate a feature render.
"""
def __init__(self):
super().__init__()
def forward(
self,
rays_densities: torch.Tensor,
rays_features: torch.Tensor,
aux: Dict[str, Any],
) -> RendererOutput:
"""
Args:
rays_densities: Per-ray density values represented with a tensor
of shape `(..., n_points_per_ray, 1)`.
rays_features: Per-ray feature values represented with a tensor
of shape `(..., n_points_per_ray, feature_dim)`.
aux: a dictionary with extra information.
"""
raise NotImplementedError()
class AccumulativeRaymarcherBase(RaymarcherBase, torch.nn.Module):
""" """
This generalizes the `pytorch3d.renderer.EmissionAbsorptionRaymarcher` This generalizes the `pytorch3d.renderer.EmissionAbsorptionRaymarcher`
and NeuralVolumes' Accumulative ray marcher. It additionally returns and NeuralVolumes' cumsum ray marcher. It additionally returns
the rendering weights that can be used in the NVS pipeline to carry out the rendering weights that can be used in the NVS pipeline to carry out
the importance ray-sampling in the refining pass. the importance ray-sampling in the refining pass.
Different from `EmissionAbsorptionRaymarcher`, it takes raw Different from `pytorch3d.renderer.EmissionAbsorptionRaymarcher`, it takes raw
(non-exponentiated) densities. (non-exponentiated) densities.
Args: Args:
bg_color: background_color. Must be of shape (1,) or (feature_dim,) surface_thickness: The thickness of the raymarched surface.
bg_color: The background color. A tuple of either 1 element or of D elements,
where D matches the feature dimensionality; it is broadcast when necessary.
background_opacity: The raw opacity value (i.e. before exponentiation)
of the background.
density_relu: If `True`, passes the input density through ReLU before
raymarching.
blend_output: If `True`, alpha-blends the output renders with the
background color using the rendered opacity mask.
capping_function: The capping function of the raymarcher.
Options:
- "exponential" (`cap_fn(x) = 1 - exp(-x)`)
- "cap1" (`cap_fn(x) = min(x, 1)`)
Set to "exponential" for the standard Emission Absorption raymarching.
weight_function: The weighting function of the raymarcher.
Options:
- "product" (`weight_fn(w, x) = w * x`)
- "minimum" (`weight_fn(w, x) = min(w, x)`)
Set to "product" for the standard Emission Absorption raymarching.
""" """
def __init__( surface_thickness: int = 1
self, bg_color: Tuple[float, ...] = (0.0,)
surface_thickness: int = 1, background_opacity: float = 0.0
bg_color: Union[Tuple[float, ...], _TTensor] = (0.0,), density_relu: bool = True
capping_function: str = "exponential", # exponential | cap1 blend_output: bool = False
weight_function: str = "product", # product | minimum
background_opacity: float = 0.0, @property
density_relu: bool = True, def capping_function_type(self) -> str:
blend_output: bool = True, raise NotImplementedError()
):
@property
def weight_function_type(self) -> str:
raise NotImplementedError()
def __post_init__(self):
""" """
Args: Args:
surface_thickness: Denotes the overlap between the absorption surface_thickness: Denotes the overlap between the absorption
function and the density function. function and the density function.
""" """
super().__init__() super().__init__()
self.surface_thickness = surface_thickness
self.density_relu = density_relu
self.background_opacity = background_opacity
self.blend_output = blend_output
if not isinstance(bg_color, torch.Tensor):
bg_color = torch.tensor(bg_color)
bg_color = torch.tensor(self.bg_color)
if bg_color.ndim != 1: if bg_color.ndim != 1:
raise ValueError(f"bg_color (shape {bg_color.shape}) should be a 1D tensor") raise ValueError(f"bg_color (shape {bg_color.shape}) should be a 1D tensor")
@ -57,12 +105,12 @@ class GenericRaymarcher(torch.nn.Module):
self._capping_function: Callable[[_TTensor], _TTensor] = { self._capping_function: Callable[[_TTensor], _TTensor] = {
"exponential": lambda x: 1.0 - torch.exp(-x), "exponential": lambda x: 1.0 - torch.exp(-x),
"cap1": lambda x: x.clamp(max=1.0), "cap1": lambda x: x.clamp(max=1.0),
}[capping_function] }[self.capping_function_type]
self._weight_function: Callable[[_TTensor, _TTensor], _TTensor] = { self._weight_function: Callable[[_TTensor, _TTensor], _TTensor] = {
"product": lambda curr, acc: curr * acc, "product": lambda curr, acc: curr * acc,
"minimum": lambda curr, acc: torch.minimum(curr, acc), "minimum": lambda curr, acc: torch.minimum(curr, acc),
}[weight_function] }[self.weight_function_type]
def forward( def forward(
self, self,
@ -71,7 +119,8 @@ class GenericRaymarcher(torch.nn.Module):
aux: Dict[str, Any], aux: Dict[str, Any],
ray_lengths: torch.Tensor, ray_lengths: torch.Tensor,
density_noise_std: float = 0.0, density_noise_std: float = 0.0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Dict[str, Any]]: **kwargs,
) -> RendererOutput:
""" """
Args: Args:
rays_densities: Per-ray density values represented with a tensor rays_densities: Per-ray density values represented with a tensor
@ -87,7 +136,7 @@ class GenericRaymarcher(torch.nn.Module):
features: A tensor of shape `(..., feature_dim)` containing features: A tensor of shape `(..., feature_dim)` containing
the rendered features for each ray. the rendered features for each ray.
depth: A tensor of shape `(..., 1)` containing estimated depth. depth: A tensor of shape `(..., 1)` containing estimated depth.
opacities: A tensor of shape `(..., 1)` containing rendered opacsities. opacities: A tensor of shape `(..., 1)` containing rendered opacities.
weights: A tensor of shape `(..., n_points_per_ray)` containing weights: A tensor of shape `(..., n_points_per_ray)` containing
the ray-specific non-negative opacity weights. In general, they the ray-specific non-negative opacity weights. In general, they
don't sum to 1 but do not exceed it, i.e. don't sum to 1 but do not exceed it, i.e.
@ -113,16 +162,15 @@ class GenericRaymarcher(torch.nn.Module):
rays_densities = rays_densities[..., 0] rays_densities = rays_densities[..., 0]
if density_noise_std > 0.0: if density_noise_std > 0.0:
rays_densities = ( noise: _TTensor = torch.randn_like(rays_densities).mul(density_noise_std)
rays_densities + torch.randn_like(rays_densities) * density_noise_std rays_densities = rays_densities + noise
)
if self.density_relu: if self.density_relu:
rays_densities = torch.relu(rays_densities) rays_densities = torch.relu(rays_densities)
weighted_densities = deltas * rays_densities weighted_densities = deltas * rays_densities
capped_densities = self._capping_function(weighted_densities) capped_densities = self._capping_function(weighted_densities) # pyre-ignore: 29
rays_opacities = self._capping_function( rays_opacities = self._capping_function( # pyre-ignore: 29
torch.cumsum(weighted_densities, dim=-1) torch.cumsum(weighted_densities, dim=-1)
) )
opacities = rays_opacities[..., -1:] opacities = rays_opacities[..., -1:]
@ -131,7 +179,9 @@ class GenericRaymarcher(torch.nn.Module):
) )
absorption_shifted[..., : self.surface_thickness] = 1.0 absorption_shifted[..., : self.surface_thickness] = 1.0
weights = self._weight_function(capped_densities, absorption_shifted) weights = self._weight_function( # pyre-ignore: 29
capped_densities, absorption_shifted
)
features = (weights[..., None] * rays_features).sum(dim=-2) features = (weights[..., None] * rays_features).sum(dim=-2)
depth = (weights * ray_lengths)[..., None].sum(dim=-2) depth = (weights * ray_lengths)[..., None].sum(dim=-2)
@ -140,4 +190,42 @@ class GenericRaymarcher(torch.nn.Module):
raise ValueError("Wrong number of background color channels.") raise ValueError("Wrong number of background color channels.")
features = alpha * features + (1 - opacities) * self._bg_color features = alpha * features + (1 - opacities) * self._bg_color
return features, depth, opacities, weights, aux return RendererOutput(
features=features,
depths=depth,
masks=opacities,
weights=weights,
aux=aux,
)
@registry.register
class EmissionAbsorptionRaymarcher(AccumulativeRaymarcherBase):
    """
    Emission-absorption raymarcher.

    Picks the "exponential" capping function and the "product" weight
    function from the lookup tables of `AccumulativeRaymarcherBase`, and
    sets the raw background opacity default to `1e10`.
    """

    background_opacity: float = 1e10

    @property
    def weight_function_type(self) -> str:
        # Key into the base class' weight-function table.
        return "product"

    @property
    def capping_function_type(self) -> str:
        # Key into the base class' capping-function table.
        return "exponential"
@registry.register
class CumsumRaymarcher(AccumulativeRaymarcherBase):
    """
    Cumulative-sum raymarcher in the style of NeuralVolumes.

    Picks the "cap1" capping function and the "minimum" weight function
    from the lookup tables of `AccumulativeRaymarcherBase`; all other
    settings keep the base-class defaults.
    """

    @property
    def weight_function_type(self) -> str:
        # Key into the base class' weight-function table.
        return "minimum"

    @property
    def capping_function_type(self) -> str:
        # Key into the base class' capping-function table.
        return "cap1"