mirror of
				https://github.com/facebookresearch/pytorch3d.git
				synced 2025-11-04 18:02:14 +08:00 
			
		
		
		
	Generic Raymarcher refactor
Summary: Uses the GenericRaymarcher only as an ABC and derives two common implementations - EA raymarcher and Cumsum raymarcher (from neural volumes) Reviewed By: shapovalov Differential Revision: D35927653 fbshipit-source-id: f7e6776e71f8a4e99eefc018a47f29ae769895ee
This commit is contained in:
		
							parent
							
								
									47d06c8924
								
							
						
					
					
						commit
						e85fa03c5a
					
				@ -47,6 +47,7 @@ class RendererOutput:
 | 
			
		||||
    prev_stage: Optional[RendererOutput] = None
 | 
			
		||||
    normals: Optional[torch.Tensor] = None
 | 
			
		||||
    points: Optional[torch.Tensor] = None  # TODO: redundant with depths
 | 
			
		||||
    weights: Optional[torch.Tensor] = None
 | 
			
		||||
    aux: Dict[str, Any] = field(default_factory=lambda: {})
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -4,18 +4,22 @@
 | 
			
		||||
# This source code is licensed under the BSD-style license found in the
 | 
			
		||||
# LICENSE file in the root directory of this source tree.
 | 
			
		||||
 | 
			
		||||
from typing import Tuple
 | 
			
		||||
from typing import List
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
from pytorch3d.implicitron.tools.config import registry
 | 
			
		||||
from pytorch3d.implicitron.models.renderer.base import ImplicitFunctionWrapper
 | 
			
		||||
from pytorch3d.implicitron.tools.config import registry, run_auto_creation
 | 
			
		||||
from pytorch3d.renderer import RayBundle
 | 
			
		||||
 | 
			
		||||
from .base import BaseRenderer, EvaluationMode, RendererOutput
 | 
			
		||||
from .ray_point_refiner import RayPointRefiner
 | 
			
		||||
from .raymarcher import GenericRaymarcher
 | 
			
		||||
from .raymarcher import RaymarcherBase
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@registry.register
 | 
			
		||||
class MultiPassEmissionAbsorptionRenderer(BaseRenderer, torch.nn.Module):
 | 
			
		||||
class MultiPassEmissionAbsorptionRenderer(  # pyre-ignore: 13
 | 
			
		||||
    BaseRenderer, torch.nn.Module
 | 
			
		||||
):
 | 
			
		||||
    """
 | 
			
		||||
    Implements the multi-pass rendering function, in particular,
 | 
			
		||||
    with emission-absorption ray marching used in NeRF [1]. First, it evaluates
 | 
			
		||||
@ -33,7 +37,17 @@ class MultiPassEmissionAbsorptionRenderer(BaseRenderer, torch.nn.Module):
 | 
			
		||||
    ```
 | 
			
		||||
    and the final rendered quantities are computed by a dot-product of ray values
 | 
			
		||||
    with the weights, e.g. `features = sum_n(weight_n * ray_features_n)`.
 | 
			
		||||
    See below for possible values of `cap_fn` and `weight_fn`.
 | 
			
		||||
 | 
			
		||||
    By default, for the EA raymarcher from [1] (
 | 
			
		||||
        activated with `self.raymarcher_class_type="EmissionAbsorptionRaymarcher"`
 | 
			
		||||
    ):
 | 
			
		||||
        ```
 | 
			
		||||
        cap_fn(x) = 1 - exp(-x),
 | 
			
		||||
        weight_fn(x) = w * x.
 | 
			
		||||
        ```
 | 
			
		||||
    Note that the latter can altered by changing `self.raymarcher_class_type`,
 | 
			
		||||
    e.g. to "CumsumRaymarcher" which implements the cumulative-sum raymarcher
 | 
			
		||||
    from NeuralVolumes [2].
 | 
			
		||||
 | 
			
		||||
    Settings:
 | 
			
		||||
        n_pts_per_ray_fine_training: The number of points sampled per ray for the
 | 
			
		||||
@ -46,42 +60,33 @@ class MultiPassEmissionAbsorptionRenderer(BaseRenderer, torch.nn.Module):
 | 
			
		||||
            evaluation.
 | 
			
		||||
        append_coarse_samples_to_fine: Add the fine ray points to the coarse points
 | 
			
		||||
            after sampling.
 | 
			
		||||
        bg_color: The background color. A tuple of either 1 element or of D elements,
 | 
			
		||||
            where D matches the feature dimensionality; it is broadcasted when necessary.
 | 
			
		||||
        density_noise_std_train: Standard deviation of the noise added to the
 | 
			
		||||
            opacity field.
 | 
			
		||||
        capping_function: The capping function of the raymarcher.
 | 
			
		||||
            Options:
 | 
			
		||||
                - "exponential" (`cap_fn(x) = 1 - exp(-x)`)
 | 
			
		||||
                - "cap1" (`cap_fn(x) = min(x, 1)`)
 | 
			
		||||
            Set to "exponential" for the standard Emission Absorption raymarching.
 | 
			
		||||
        weight_function: The weighting function of the raymarcher.
 | 
			
		||||
            Options:
 | 
			
		||||
                - "product" (`weight_fn(w, x) = w * x`)
 | 
			
		||||
                - "minimum" (`weight_fn(w, x) = min(w, x)`)
 | 
			
		||||
            Set to "product" for the standard Emission Absorption raymarching.
 | 
			
		||||
        background_opacity: The raw opacity value (i.e. before exponentiation)
 | 
			
		||||
            of the background.
 | 
			
		||||
        blend_output: If `True`, alpha-blends the output renders with the
 | 
			
		||||
            background color using the rendered opacity mask.
 | 
			
		||||
        return_weights: Enables returning the rendering weights of the EA raymarcher.
 | 
			
		||||
            Setting to `True` can lead to a prohibitivelly large memory consumption.
 | 
			
		||||
        raymarcher_class_type: The type of self.raymarcher corresponding to
 | 
			
		||||
            a child of `RaymarcherBase` in the registry.
 | 
			
		||||
        raymarcher: The raymarcher object used to convert per-point features
 | 
			
		||||
            and opacities to a feature render.
 | 
			
		||||
 | 
			
		||||
    References:
 | 
			
		||||
        [1] Mildenhall, Ben, et al. "Nerf: Representing scenes as neural radiance
 | 
			
		||||
            fields for view synthesis." ECCV 2020.
 | 
			
		||||
        [1] Mildenhall, Ben, et al. "Nerf: Representing Scenes as Neural Radiance
 | 
			
		||||
            Fields for View Synthesis." ECCV 2020.
 | 
			
		||||
        [2] Lombardi, Stephen, et al. "Neural Volumes: Learning Dynamic Renderable
 | 
			
		||||
            Volumes from Images." SIGGRAPH 2019.
 | 
			
		||||
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    raymarcher_class_type: str = "EmissionAbsorptionRaymarcher"
 | 
			
		||||
    raymarcher: RaymarcherBase
 | 
			
		||||
 | 
			
		||||
    n_pts_per_ray_fine_training: int = 64
 | 
			
		||||
    n_pts_per_ray_fine_evaluation: int = 64
 | 
			
		||||
    stratified_sampling_coarse_training: bool = True
 | 
			
		||||
    stratified_sampling_coarse_evaluation: bool = False
 | 
			
		||||
    append_coarse_samples_to_fine: bool = True
 | 
			
		||||
    bg_color: Tuple[float, ...] = (0.0,)
 | 
			
		||||
    density_noise_std_train: float = 0.0
 | 
			
		||||
    capping_function: str = "exponential"  # exponential | cap1
 | 
			
		||||
    weight_function: str = "product"  # product | minimum
 | 
			
		||||
    background_opacity: float = 1e10
 | 
			
		||||
    blend_output: bool = False
 | 
			
		||||
    return_weights: bool = False
 | 
			
		||||
 | 
			
		||||
    def __post_init__(self):
 | 
			
		||||
        super().__init__()
 | 
			
		||||
@ -97,20 +102,12 @@ class MultiPassEmissionAbsorptionRenderer(BaseRenderer, torch.nn.Module):
 | 
			
		||||
                add_input_samples=self.append_coarse_samples_to_fine,
 | 
			
		||||
            ),
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        self._raymarcher = GenericRaymarcher(
 | 
			
		||||
            1,
 | 
			
		||||
            self.bg_color,
 | 
			
		||||
            capping_function=self.capping_function,
 | 
			
		||||
            weight_function=self.weight_function,
 | 
			
		||||
            background_opacity=self.background_opacity,
 | 
			
		||||
            blend_output=self.blend_output,
 | 
			
		||||
        )
 | 
			
		||||
        run_auto_creation(self)
 | 
			
		||||
 | 
			
		||||
    def forward(
 | 
			
		||||
        self,
 | 
			
		||||
        ray_bundle,
 | 
			
		||||
        implicit_functions=[],
 | 
			
		||||
        ray_bundle: RayBundle,
 | 
			
		||||
        implicit_functions: List[ImplicitFunctionWrapper] = [],
 | 
			
		||||
        evaluation_mode: EvaluationMode = EvaluationMode.EVALUATION,
 | 
			
		||||
        **kwargs
 | 
			
		||||
    ) -> RendererOutput:
 | 
			
		||||
@ -149,14 +146,16 @@ class MultiPassEmissionAbsorptionRenderer(BaseRenderer, torch.nn.Module):
 | 
			
		||||
            else 0.0
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        features, depth, mask, weights, aux = self._raymarcher(
 | 
			
		||||
        output = self.raymarcher(
 | 
			
		||||
            *implicit_functions[0](ray_bundle),
 | 
			
		||||
            ray_lengths=ray_bundle.lengths,
 | 
			
		||||
            density_noise_std=density_noise_std,
 | 
			
		||||
        )
 | 
			
		||||
        output = RendererOutput(
 | 
			
		||||
            features=features, depths=depth, masks=mask, aux=aux, prev_stage=prev_stage
 | 
			
		||||
        )
 | 
			
		||||
        output.prev_stage = prev_stage
 | 
			
		||||
 | 
			
		||||
        weights = output.weights
 | 
			
		||||
        if not self.return_weights:
 | 
			
		||||
            output.weights = None
 | 
			
		||||
 | 
			
		||||
        # we may need to make a recursive call
 | 
			
		||||
        if len(implicit_functions) > 1:
 | 
			
		||||
 | 
			
		||||
@ -4,51 +4,99 @@
 | 
			
		||||
# This source code is licensed under the BSD-style license found in the
 | 
			
		||||
# LICENSE file in the root directory of this source tree.
 | 
			
		||||
 | 
			
		||||
from typing import Any, Callable, Dict, Tuple, Union
 | 
			
		||||
from typing import Any, Callable, Dict, Tuple
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
from pytorch3d.implicitron.models.renderer.base import RendererOutput
 | 
			
		||||
from pytorch3d.implicitron.tools.config import ReplaceableBase, registry
 | 
			
		||||
from pytorch3d.renderer.implicit.raymarching import _check_raymarcher_inputs
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
_TTensor = torch.Tensor
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class GenericRaymarcher(torch.nn.Module):
 | 
			
		||||
class RaymarcherBase(ReplaceableBase):
 | 
			
		||||
    """
 | 
			
		||||
    Defines a base class for raymarchers. Specifically, a raymarcher is responsible
 | 
			
		||||
    for taking a set of features and density descriptors along rendering rays
 | 
			
		||||
    and marching along them in order to generate a feature render.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def __init__(self):
 | 
			
		||||
        super().__init__()
 | 
			
		||||
 | 
			
		||||
    def forward(
 | 
			
		||||
        self,
 | 
			
		||||
        rays_densities: torch.Tensor,
 | 
			
		||||
        rays_features: torch.Tensor,
 | 
			
		||||
        aux: Dict[str, Any],
 | 
			
		||||
    ) -> RendererOutput:
 | 
			
		||||
        """
 | 
			
		||||
        Args:
 | 
			
		||||
            rays_densities: Per-ray density values represented with a tensor
 | 
			
		||||
                of shape `(..., n_points_per_ray, 1)`.
 | 
			
		||||
            rays_features: Per-ray feature values represented with a tensor
 | 
			
		||||
                of shape `(..., n_points_per_ray, feature_dim)`.
 | 
			
		||||
            aux: a dictionary with extra information.
 | 
			
		||||
        """
 | 
			
		||||
        raise NotImplementedError()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class AccumulativeRaymarcherBase(RaymarcherBase, torch.nn.Module):
 | 
			
		||||
    """
 | 
			
		||||
    This generalizes the `pytorch3d.renderer.EmissionAbsorptionRaymarcher`
 | 
			
		||||
    and NeuralVolumes' Accumulative ray marcher. It additionally returns
 | 
			
		||||
    and NeuralVolumes' cumsum ray marcher. It additionally returns
 | 
			
		||||
    the rendering weights that can be used in the NVS pipeline to carry out
 | 
			
		||||
    the importance ray-sampling in the refining pass.
 | 
			
		||||
    Different from `EmissionAbsorptionRaymarcher`, it takes raw
 | 
			
		||||
    Different from `pytorch3d.renderer.EmissionAbsorptionRaymarcher`, it takes raw
 | 
			
		||||
    (non-exponentiated) densities.
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
        bg_color: background_color. Must be of shape (1,) or (feature_dim,)
 | 
			
		||||
        surface_thickness: The thickness of the raymarched surface.
 | 
			
		||||
        bg_color: The background color. A tuple of either 1 element or of D elements,
 | 
			
		||||
            where D matches the feature dimensionality; it is broadcast when necessary.
 | 
			
		||||
        background_opacity: The raw opacity value (i.e. before exponentiation)
 | 
			
		||||
            of the background.
 | 
			
		||||
        density_relu: If `True`, passes the input density through ReLU before
 | 
			
		||||
            raymarching.
 | 
			
		||||
        blend_output: If `True`, alpha-blends the output renders with the
 | 
			
		||||
            background color using the rendered opacity mask.
 | 
			
		||||
 | 
			
		||||
        capping_function: The capping function of the raymarcher.
 | 
			
		||||
            Options:
 | 
			
		||||
                - "exponential" (`cap_fn(x) = 1 - exp(-x)`)
 | 
			
		||||
                - "cap1" (`cap_fn(x) = min(x, 1)`)
 | 
			
		||||
            Set to "exponential" for the standard Emission Absorption raymarching.
 | 
			
		||||
        weight_function: The weighting function of the raymarcher.
 | 
			
		||||
            Options:
 | 
			
		||||
                - "product" (`weight_fn(w, x) = w * x`)
 | 
			
		||||
                - "minimum" (`weight_fn(w, x) = min(w, x)`)
 | 
			
		||||
            Set to "product" for the standard Emission Absorption raymarching.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self,
 | 
			
		||||
        surface_thickness: int = 1,
 | 
			
		||||
        bg_color: Union[Tuple[float, ...], _TTensor] = (0.0,),
 | 
			
		||||
        capping_function: str = "exponential",  # exponential | cap1
 | 
			
		||||
        weight_function: str = "product",  # product | minimum
 | 
			
		||||
        background_opacity: float = 0.0,
 | 
			
		||||
        density_relu: bool = True,
 | 
			
		||||
        blend_output: bool = True,
 | 
			
		||||
    ):
 | 
			
		||||
    surface_thickness: int = 1
 | 
			
		||||
    bg_color: Tuple[float, ...] = (0.0,)
 | 
			
		||||
    background_opacity: float = 0.0
 | 
			
		||||
    density_relu: bool = True
 | 
			
		||||
    blend_output: bool = False
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def capping_function_type(self) -> str:
 | 
			
		||||
        raise NotImplementedError()
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def weight_function_type(self) -> str:
 | 
			
		||||
        raise NotImplementedError()
 | 
			
		||||
 | 
			
		||||
    def __post_init__(self):
 | 
			
		||||
        """
 | 
			
		||||
        Args:
 | 
			
		||||
            surface_thickness: Denotes the overlap between the absorption
 | 
			
		||||
                function and the density function.
 | 
			
		||||
        """
 | 
			
		||||
        super().__init__()
 | 
			
		||||
        self.surface_thickness = surface_thickness
 | 
			
		||||
        self.density_relu = density_relu
 | 
			
		||||
        self.background_opacity = background_opacity
 | 
			
		||||
        self.blend_output = blend_output
 | 
			
		||||
        if not isinstance(bg_color, torch.Tensor):
 | 
			
		||||
            bg_color = torch.tensor(bg_color)
 | 
			
		||||
 | 
			
		||||
        bg_color = torch.tensor(self.bg_color)
 | 
			
		||||
        if bg_color.ndim != 1:
 | 
			
		||||
            raise ValueError(f"bg_color (shape {bg_color.shape}) should be a 1D tensor")
 | 
			
		||||
 | 
			
		||||
@ -57,12 +105,12 @@ class GenericRaymarcher(torch.nn.Module):
 | 
			
		||||
        self._capping_function: Callable[[_TTensor], _TTensor] = {
 | 
			
		||||
            "exponential": lambda x: 1.0 - torch.exp(-x),
 | 
			
		||||
            "cap1": lambda x: x.clamp(max=1.0),
 | 
			
		||||
        }[capping_function]
 | 
			
		||||
        }[self.capping_function_type]
 | 
			
		||||
 | 
			
		||||
        self._weight_function: Callable[[_TTensor, _TTensor], _TTensor] = {
 | 
			
		||||
            "product": lambda curr, acc: curr * acc,
 | 
			
		||||
            "minimum": lambda curr, acc: torch.minimum(curr, acc),
 | 
			
		||||
        }[weight_function]
 | 
			
		||||
        }[self.weight_function_type]
 | 
			
		||||
 | 
			
		||||
    def forward(
 | 
			
		||||
        self,
 | 
			
		||||
@ -71,7 +119,8 @@ class GenericRaymarcher(torch.nn.Module):
 | 
			
		||||
        aux: Dict[str, Any],
 | 
			
		||||
        ray_lengths: torch.Tensor,
 | 
			
		||||
        density_noise_std: float = 0.0,
 | 
			
		||||
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Dict[str, Any]]:
 | 
			
		||||
        **kwargs,
 | 
			
		||||
    ) -> RendererOutput:
 | 
			
		||||
        """
 | 
			
		||||
        Args:
 | 
			
		||||
            rays_densities: Per-ray density values represented with a tensor
 | 
			
		||||
@ -87,7 +136,7 @@ class GenericRaymarcher(torch.nn.Module):
 | 
			
		||||
            features: A tensor of shape `(..., feature_dim)` containing
 | 
			
		||||
                the rendered features for each ray.
 | 
			
		||||
            depth: A tensor of shape `(..., 1)` containing estimated depth.
 | 
			
		||||
            opacities: A tensor of shape `(..., 1)` containing rendered opacsities.
 | 
			
		||||
            opacities: A tensor of shape `(..., 1)` containing rendered opacities.
 | 
			
		||||
            weights: A tensor of shape `(..., n_points_per_ray)` containing
 | 
			
		||||
                the ray-specific non-negative opacity weights. In general, they
 | 
			
		||||
                don't sum to 1 but do not overcome it, i.e.
 | 
			
		||||
@ -113,16 +162,15 @@ class GenericRaymarcher(torch.nn.Module):
 | 
			
		||||
        rays_densities = rays_densities[..., 0]
 | 
			
		||||
 | 
			
		||||
        if density_noise_std > 0.0:
 | 
			
		||||
            rays_densities = (
 | 
			
		||||
                rays_densities + torch.randn_like(rays_densities) * density_noise_std
 | 
			
		||||
            )
 | 
			
		||||
            noise: _TTensor = torch.randn_like(rays_densities).mul(density_noise_std)
 | 
			
		||||
            rays_densities = rays_densities + noise
 | 
			
		||||
        if self.density_relu:
 | 
			
		||||
            rays_densities = torch.relu(rays_densities)
 | 
			
		||||
 | 
			
		||||
        weighted_densities = deltas * rays_densities
 | 
			
		||||
        capped_densities = self._capping_function(weighted_densities)
 | 
			
		||||
        capped_densities = self._capping_function(weighted_densities)  # pyre-ignore: 29
 | 
			
		||||
 | 
			
		||||
        rays_opacities = self._capping_function(
 | 
			
		||||
        rays_opacities = self._capping_function(  # pyre-ignore: 29
 | 
			
		||||
            torch.cumsum(weighted_densities, dim=-1)
 | 
			
		||||
        )
 | 
			
		||||
        opacities = rays_opacities[..., -1:]
 | 
			
		||||
@ -131,7 +179,9 @@ class GenericRaymarcher(torch.nn.Module):
 | 
			
		||||
        )
 | 
			
		||||
        absorption_shifted[..., : self.surface_thickness] = 1.0
 | 
			
		||||
 | 
			
		||||
        weights = self._weight_function(capped_densities, absorption_shifted)
 | 
			
		||||
        weights = self._weight_function(  # pyre-ignore: 29
 | 
			
		||||
            capped_densities, absorption_shifted
 | 
			
		||||
        )
 | 
			
		||||
        features = (weights[..., None] * rays_features).sum(dim=-2)
 | 
			
		||||
        depth = (weights * ray_lengths)[..., None].sum(dim=-2)
 | 
			
		||||
 | 
			
		||||
@ -140,4 +190,42 @@ class GenericRaymarcher(torch.nn.Module):
 | 
			
		||||
            raise ValueError("Wrong number of background color channels.")
 | 
			
		||||
        features = alpha * features + (1 - opacities) * self._bg_color
 | 
			
		||||
 | 
			
		||||
        return features, depth, opacities, weights, aux
 | 
			
		||||
        return RendererOutput(
 | 
			
		||||
            features=features,
 | 
			
		||||
            depths=depth,
 | 
			
		||||
            masks=opacities,
 | 
			
		||||
            weights=weights,
 | 
			
		||||
            aux=aux,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@registry.register
 | 
			
		||||
class EmissionAbsorptionRaymarcher(AccumulativeRaymarcherBase):
 | 
			
		||||
    """
 | 
			
		||||
    Implements the EmissionAbsorption raymarcher.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    background_opacity: float = 1e10
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def capping_function_type(self) -> str:
 | 
			
		||||
        return "exponential"
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def weight_function_type(self) -> str:
 | 
			
		||||
        return "product"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@registry.register
 | 
			
		||||
class CumsumRaymarcher(AccumulativeRaymarcherBase):
 | 
			
		||||
    """
 | 
			
		||||
    Implements the NeuralVolumes' cumulative-sum raymarcher.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def capping_function_type(self) -> str:
 | 
			
		||||
        return "cap1"
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def weight_function_type(self) -> str:
 | 
			
		||||
        return "minimum"
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user