diff --git a/pytorch3d/implicitron/models/renderer/base.py b/pytorch3d/implicitron/models/renderer/base.py
index b57c1de6..5f574cb7 100644
--- a/pytorch3d/implicitron/models/renderer/base.py
+++ b/pytorch3d/implicitron/models/renderer/base.py
@@ -47,6 +47,7 @@ class RendererOutput:
     prev_stage: Optional[RendererOutput] = None
     normals: Optional[torch.Tensor] = None
     points: Optional[torch.Tensor] = None  # TODO: redundant with depths
+    weights: Optional[torch.Tensor] = None
     aux: Dict[str, Any] = field(default_factory=lambda: {})


diff --git a/pytorch3d/implicitron/models/renderer/multipass_ea.py b/pytorch3d/implicitron/models/renderer/multipass_ea.py
index 84872e56..eddd736e 100644
--- a/pytorch3d/implicitron/models/renderer/multipass_ea.py
+++ b/pytorch3d/implicitron/models/renderer/multipass_ea.py
@@ -4,18 +4,22 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

-from typing import Tuple
+from typing import List

 import torch
-from pytorch3d.implicitron.tools.config import registry
+from pytorch3d.implicitron.models.renderer.base import ImplicitFunctionWrapper
+from pytorch3d.implicitron.tools.config import registry, run_auto_creation
+from pytorch3d.renderer import RayBundle

 from .base import BaseRenderer, EvaluationMode, RendererOutput
 from .ray_point_refiner import RayPointRefiner
-from .raymarcher import GenericRaymarcher
+from .raymarcher import RaymarcherBase


 @registry.register
-class MultiPassEmissionAbsorptionRenderer(BaseRenderer, torch.nn.Module):
+class MultiPassEmissionAbsorptionRenderer(  # pyre-ignore: 13
+    BaseRenderer, torch.nn.Module
+):
     """
     Implements the multi-pass rendering function, in particular,
     with emission-absorption ray marching used in NeRF [1]. First, it evaluates
@@ -33,7 +37,17 @@ class MultiPassEmissionAbsorptionRenderer(BaseRenderer, torch.nn.Module):
     ```
     and the final rendered quantities are computed by a dot-product of ray values
     with the weights, e.g. `features = sum_n(weight_n * ray_features_n)`.
-    See below for possible values of `cap_fn` and `weight_fn`.
+
+    By default, for the EA raymarcher from [1] (
+        activated with `self.raymarcher_class_type="EmissionAbsorptionRaymarcher"`
+    ):
+        ```
+        cap_fn(x) = 1 - exp(-x),
+        weight_fn(w, x) = w * x.
+        ```
+    Note that both functions can be altered by changing `self.raymarcher_class_type`,
+    e.g. to "CumsumRaymarcher", which implements the cumulative-sum raymarcher
+    from NeuralVolumes [2].

     Settings:
         n_pts_per_ray_fine_training: The number of points sampled per ray for the
@@ -46,42 +60,33 @@ class MultiPassEmissionAbsorptionRenderer(BaseRenderer, torch.nn.Module):
             evaluation.
         append_coarse_samples_to_fine: Add the fine ray points to the coarse points
             after sampling.
-        bg_color: The background color. A tuple of either 1 element or of D elements,
-            where D matches the feature dimensionality; it is broadcasted when necessary.
         density_noise_std_train: Standard deviation of the noise added to the
             opacity field.
-        capping_function: The capping function of the raymarcher.
-            Options:
-                - "exponential" (`cap_fn(x) = 1 - exp(-x)`)
-                - "cap1" (`cap_fn(x) = min(x, 1)`)
-            Set to "exponential" for the standard Emission Absorption raymarching.
-        weight_function: The weighting function of the raymarcher.
-            Options:
-                - "product" (`weight_fn(w, x) = w * x`)
-                - "minimum" (`weight_fn(w, x) = min(w, x)`)
-            Set to "product" for the standard Emission Absorption raymarching.
-        background_opacity: The raw opacity value (i.e. before exponentiation)
-            of the background.
-        blend_output: If `True`, alpha-blends the output renders with the
-            background color using the rendered opacity mask.
+        return_weights: Enables returning the rendering weights of the EA raymarcher.
+            Setting to `True` can lead to prohibitively large memory consumption.
+        raymarcher_class_type: The type of `self.raymarcher`, corresponding to
+            a child of `RaymarcherBase` in the registry.
+        raymarcher: The raymarcher object used to convert per-point features
+            and opacities to a feature render.

     References:
-        [1] Mildenhall, Ben, et al. "Nerf: Representing scenes as neural radiance
-            fields for view synthesis." ECCV 2020.
+        [1] Mildenhall, Ben, et al. "Nerf: Representing Scenes as Neural Radiance
+            Fields for View Synthesis." ECCV 2020.
+        [2] Lombardi, Stephen, et al. "Neural Volumes: Learning Dynamic Renderable
+            Volumes from Images." SIGGRAPH 2019.
     """

+    raymarcher_class_type: str = "EmissionAbsorptionRaymarcher"
+    raymarcher: RaymarcherBase
+
     n_pts_per_ray_fine_training: int = 64
     n_pts_per_ray_fine_evaluation: int = 64
     stratified_sampling_coarse_training: bool = True
     stratified_sampling_coarse_evaluation: bool = False
     append_coarse_samples_to_fine: bool = True
-    bg_color: Tuple[float, ...] = (0.0,)
     density_noise_std_train: float = 0.0
-    capping_function: str = "exponential"  # exponential | cap1
-    weight_function: str = "product"  # product | minimum
-    background_opacity: float = 1e10
-    blend_output: bool = False
+    return_weights: bool = False

     def __post_init__(self):
         super().__init__()
@@ -97,20 +102,12 @@ class MultiPassEmissionAbsorptionRenderer(BaseRenderer, torch.nn.Module):
                 add_input_samples=self.append_coarse_samples_to_fine,
             ),
         }
-
-        self._raymarcher = GenericRaymarcher(
-            1,
-            self.bg_color,
-            capping_function=self.capping_function,
-            weight_function=self.weight_function,
-            background_opacity=self.background_opacity,
-            blend_output=self.blend_output,
-        )
+        run_auto_creation(self)

     def forward(
         self,
-        ray_bundle,
-        implicit_functions=[],
+        ray_bundle: RayBundle,
+        implicit_functions: List[ImplicitFunctionWrapper] = [],
         evaluation_mode: EvaluationMode = EvaluationMode.EVALUATION,
         **kwargs
     ) -> RendererOutput:
@@ -149,14 +146,16 @@ class MultiPassEmissionAbsorptionRenderer(BaseRenderer, torch.nn.Module):
             else 0.0
         )

-        features, depth, mask, weights, aux = self._raymarcher(
+        output = self.raymarcher(
             *implicit_functions[0](ray_bundle),
             ray_lengths=ray_bundle.lengths,
             density_noise_std=density_noise_std,
         )
-        output = RendererOutput(
-            features=features, depths=depth, masks=mask, aux=aux, prev_stage=prev_stage
-        )
+        output.prev_stage = prev_stage
+
+        weights = output.weights
+        if not self.return_weights:
+            output.weights = None

         # we may need to make a recursive call
         if len(implicit_functions) > 1:
diff --git a/pytorch3d/implicitron/models/renderer/raymarcher.py b/pytorch3d/implicitron/models/renderer/raymarcher.py
index 87e52911..81495bda 100644
--- a/pytorch3d/implicitron/models/renderer/raymarcher.py
+++ b/pytorch3d/implicitron/models/renderer/raymarcher.py
@@ -4,51 +4,99 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

-from typing import Any, Callable, Dict, Tuple, Union
+from typing import Any, Callable, Dict, Tuple

 import torch
+from pytorch3d.implicitron.models.renderer.base import RendererOutput
+from pytorch3d.implicitron.tools.config import ReplaceableBase, registry
 from pytorch3d.renderer.implicit.raymarching import _check_raymarcher_inputs

 _TTensor = torch.Tensor


-class GenericRaymarcher(torch.nn.Module):
+class RaymarcherBase(ReplaceableBase):
+    """
+    Defines a base class for raymarchers. Specifically, a raymarcher is responsible
+    for taking a set of features and density descriptors along rendering rays
+    and marching along them in order to generate a feature render.
+    """
+
+    def __init__(self):
+        super().__init__()
+
+    def forward(
+        self,
+        rays_densities: torch.Tensor,
+        rays_features: torch.Tensor,
+        aux: Dict[str, Any],
+    ) -> RendererOutput:
+        """
+        Args:
+            rays_densities: Per-ray density values represented with a tensor
+                of shape `(..., n_points_per_ray, 1)`.
+            rays_features: Per-ray feature values represented with a tensor
+                of shape `(..., n_points_per_ray, feature_dim)`.
+            aux: a dictionary with extra information.
+        """
+        raise NotImplementedError()
+
+
+class AccumulativeRaymarcherBase(RaymarcherBase, torch.nn.Module):
     """
     This generalizes the `pytorch3d.renderer.EmissionAbsorptionRaymarcher`
-    and NeuralVolumes' Accumulative ray marcher. It additionally returns
+    and NeuralVolumes' cumsum ray marcher. It additionally returns
     the rendering weights that can be used in the NVS pipeline to carry out
     the importance ray-sampling in the refining pass.

-    Different from `EmissionAbsorptionRaymarcher`, it takes raw
+    Different from `pytorch3d.renderer.EmissionAbsorptionRaymarcher`, it takes raw
     (non-exponentiated) densities.

     Args:
-        bg_color: background_color. Must be of shape (1,) or (feature_dim,)
+        surface_thickness: The thickness of the raymarched surface.
+        bg_color: The background color. A tuple of either 1 element or of D elements,
+            where D matches the feature dimensionality; it is broadcast when necessary.
+        background_opacity: The raw opacity value (i.e. before exponentiation)
+            of the background.
+        density_relu: If `True`, passes the input density through ReLU before
+            raymarching.
+        blend_output: If `True`, alpha-blends the output renders with the
+            background color using the rendered opacity mask.
+
+        capping_function: The capping function of the raymarcher.
+            Options:
+                - "exponential" (`cap_fn(x) = 1 - exp(-x)`)
+                - "cap1" (`cap_fn(x) = min(x, 1)`)
+            Set to "exponential" for the standard Emission Absorption raymarching.
+        weight_function: The weighting function of the raymarcher.
+            Options:
+                - "product" (`weight_fn(w, x) = w * x`)
+                - "minimum" (`weight_fn(w, x) = min(w, x)`)
+            Set to "product" for the standard Emission Absorption raymarching.
     """

-    def __init__(
-        self,
-        surface_thickness: int = 1,
-        bg_color: Union[Tuple[float, ...], _TTensor] = (0.0,),
-        capping_function: str = "exponential",  # exponential | cap1
-        weight_function: str = "product",  # product | minimum
-        background_opacity: float = 0.0,
-        density_relu: bool = True,
-        blend_output: bool = True,
-    ):
+    surface_thickness: int = 1
+    bg_color: Tuple[float, ...] = (0.0,)
+    background_opacity: float = 0.0
+    density_relu: bool = True
+    blend_output: bool = False
+
+    @property
+    def capping_function_type(self) -> str:
+        raise NotImplementedError()
+
+    @property
+    def weight_function_type(self) -> str:
+        raise NotImplementedError()
+
+    def __post_init__(self):
         """
         Args:
             surface_thickness: Denotes the overlap between the absorption
                 function and the density function.
         """
         super().__init__()
-        self.surface_thickness = surface_thickness
-        self.density_relu = density_relu
-        self.background_opacity = background_opacity
-        self.blend_output = blend_output
-        if not isinstance(bg_color, torch.Tensor):
-            bg_color = torch.tensor(bg_color)
+        bg_color = torch.tensor(self.bg_color)

         if bg_color.ndim != 1:
             raise ValueError(f"bg_color (shape {bg_color.shape}) should be a 1D tensor")
@@ -57,12 +105,12 @@ class GenericRaymarcher(torch.nn.Module):
         self._capping_function: Callable[[_TTensor], _TTensor] = {
             "exponential": lambda x: 1.0 - torch.exp(-x),
             "cap1": lambda x: x.clamp(max=1.0),
-        }[capping_function]
+        }[self.capping_function_type]

         self._weight_function: Callable[[_TTensor, _TTensor], _TTensor] = {
             "product": lambda curr, acc: curr * acc,
             "minimum": lambda curr, acc: torch.minimum(curr, acc),
-        }[weight_function]
+        }[self.weight_function_type]

     def forward(
         self,
@@ -71,7 +119,8 @@ class GenericRaymarcher(torch.nn.Module):
         aux: Dict[str, Any],
         ray_lengths: torch.Tensor,
         density_noise_std: float = 0.0,
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Dict[str, Any]]:
+        **kwargs,
+    ) -> RendererOutput:
         """
         Args:
             rays_densities: Per-ray density values represented with a tensor
@@ -87,7 +136,7 @@ class GenericRaymarcher(torch.nn.Module):
             features: A tensor of shape `(..., feature_dim)` containing
                 the rendered features for each ray.
             depth: A tensor of shape `(..., 1)` containing estimated depth.
-            opacities: A tensor of shape `(..., 1)` containing rendered opacsities.
+            opacities: A tensor of shape `(..., 1)` containing rendered opacities.
             weights: A tensor of shape `(..., n_points_per_ray)` containing
                 the ray-specific non-negative opacity weights. In general, they
                 don't sum to 1 but do not overcome it, i.e.
@@ -113,16 +162,15 @@ class GenericRaymarcher(torch.nn.Module):
         rays_densities = rays_densities[..., 0]

         if density_noise_std > 0.0:
-            rays_densities = (
-                rays_densities + torch.randn_like(rays_densities) * density_noise_std
-            )
+            noise: _TTensor = torch.randn_like(rays_densities).mul(density_noise_std)
+            rays_densities = rays_densities + noise
         if self.density_relu:
             rays_densities = torch.relu(rays_densities)

         weighted_densities = deltas * rays_densities
-        capped_densities = self._capping_function(weighted_densities)
+        capped_densities = self._capping_function(weighted_densities)  # pyre-ignore: 29

-        rays_opacities = self._capping_function(
+        rays_opacities = self._capping_function(  # pyre-ignore: 29
             torch.cumsum(weighted_densities, dim=-1)
         )
         opacities = rays_opacities[..., -1:]
@@ -131,7 +179,9 @@ class GenericRaymarcher(torch.nn.Module):
         )
         absorption_shifted[..., : self.surface_thickness] = 1.0

-        weights = self._weight_function(capped_densities, absorption_shifted)
+        weights = self._weight_function(  # pyre-ignore: 29
+            capped_densities, absorption_shifted
+        )
         features = (weights[..., None] * rays_features).sum(dim=-2)
         depth = (weights * ray_lengths)[..., None].sum(dim=-2)

@@ -140,4 +190,42 @@ class GenericRaymarcher(torch.nn.Module):
             raise ValueError("Wrong number of background color channels.")

         features = alpha * features + (1 - opacities) * self._bg_color
-        return features, depth, opacities, weights, aux
+        return RendererOutput(
+            features=features,
+            depths=depth,
+            masks=opacities,
+            weights=weights,
+            aux=aux,
+        )
+
+
+@registry.register
+class EmissionAbsorptionRaymarcher(AccumulativeRaymarcherBase):
+    """
+    Implements the EmissionAbsorption raymarcher.
+    """
+
+    background_opacity: float = 1e10
+
+    @property
+    def capping_function_type(self) -> str:
+        return "exponential"
+
+    @property
+    def weight_function_type(self) -> str:
+        return "product"
+
+
+@registry.register
+class CumsumRaymarcher(AccumulativeRaymarcherBase):
+    """
+    Implements NeuralVolumes' cumulative-sum raymarcher.
+    """
+
+    @property
+    def capping_function_type(self) -> str:
+        return "cap1"
+
+    @property
+    def weight_function_type(self) -> str:
+        return "minimum"
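
Usage note (reviewer sketch, not part of the diff): the snippet below illustrates how the now-pluggable raymarcher can be selected through the registry. It assumes the standard implicitron configurable workflow (`expand_args_fields` from `pytorch3d.implicitron.tools.config`) and default values for all other renderer settings; treat it as an illustration rather than a tested API.

```python
# Hypothetical usage sketch for the pluggable raymarcher added in this diff.
# Assumes the usual implicitron configurable workflow (expand_args_fields).
from pytorch3d.implicitron.models.renderer.multipass_ea import (
    MultiPassEmissionAbsorptionRenderer,
)
from pytorch3d.implicitron.tools.config import expand_args_fields

# Resolve the configurable fields of the renderer dataclass, including
# `raymarcher_class_type` and the per-raymarcher argument groups.
expand_args_fields(MultiPassEmissionAbsorptionRenderer)

# Default: the NeRF-style emission-absorption raymarcher.
ea_renderer = MultiPassEmissionAbsorptionRenderer()

# Switch to the NeuralVolumes-style cumulative-sum raymarcher and keep the
# per-point rendering weights in RendererOutput.weights.
cumsum_renderer = MultiPassEmissionAbsorptionRenderer(
    raymarcher_class_type="CumsumRaymarcher",
    return_weights=True,
)
print(type(cumsum_renderer.raymarcher).__name__)  # CumsumRaymarcher
```

Tying the capping/weighting choice to the subclass (the `capping_function_type` / `weight_function_type` properties) instead of free-form string settings means an unsupported combination can no longer be configured by accident.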