Mirror of https://github.com/facebookresearch/pytorch3d.git (synced 2025-08-02 20:02:49 +08:00)

Generic Raymarcher refactor
Summary: Uses the GenericRaymarcher only as an ABC and derives two common implementations: the EA raymarcher and the Cumsum raymarcher (from Neural Volumes).

Reviewed By: shapovalov

Differential Revision: D35927653

fbshipit-source-id: f7e6776e71f8a4e99eefc018a47f29ae769895ee
parent 47d06c8924
commit e85fa03c5a
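At the API level, the refactor makes the raymarcher a pluggable member of `MultiPassEmissionAbsorptionRenderer`, selected by name from the implicitron registry. A minimal usage sketch, assuming the usual implicitron construction flow via `expand_args_fields` and the module path `pytorch3d.implicitron.models.renderer.multipass_ea` (neither is shown in this diff):

from pytorch3d.implicitron.models.renderer.multipass_ea import (
    MultiPassEmissionAbsorptionRenderer,
)
from pytorch3d.implicitron.tools.config import expand_args_fields

# Process the configurable class so it can be constructed with keyword arguments.
expand_args_fields(MultiPassEmissionAbsorptionRenderer)

# Default: NeRF-style emission-absorption raymarching.
renderer_ea = MultiPassEmissionAbsorptionRenderer(
    raymarcher_class_type="EmissionAbsorptionRaymarcher"
)

# Neural Volumes' cumulative-sum raymarching instead.
renderer_cs = MultiPassEmissionAbsorptionRenderer(
    raymarcher_class_type="CumsumRaymarcher"
)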
@@ -47,6 +47,7 @@ class RendererOutput:
     prev_stage: Optional[RendererOutput] = None
     normals: Optional[torch.Tensor] = None
     points: Optional[torch.Tensor] = None  # TODO: redundant with depths
+    weights: Optional[torch.Tensor] = None
     aux: Dict[str, Any] = field(default_factory=lambda: {})


@@ -4,18 +4,22 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

-from typing import Tuple
+from typing import List

 import torch
-from pytorch3d.implicitron.tools.config import registry
+from pytorch3d.implicitron.models.renderer.base import ImplicitFunctionWrapper
+from pytorch3d.implicitron.tools.config import registry, run_auto_creation
+from pytorch3d.renderer import RayBundle

 from .base import BaseRenderer, EvaluationMode, RendererOutput
 from .ray_point_refiner import RayPointRefiner
-from .raymarcher import GenericRaymarcher
+from .raymarcher import RaymarcherBase


 @registry.register
-class MultiPassEmissionAbsorptionRenderer(BaseRenderer, torch.nn.Module):
+class MultiPassEmissionAbsorptionRenderer(  # pyre-ignore: 13
+    BaseRenderer, torch.nn.Module
+):
     """
     Implements the multi-pass rendering function, in particular,
     with emission-absorption ray marching used in NeRF [1]. First, it evaluates
@@ -33,7 +37,17 @@ class MultiPassEmissionAbsorptionRenderer(BaseRenderer, torch.nn.Module):
     ```
     and the final rendered quantities are computed by a dot-product of ray values
     with the weights, e.g. `features = sum_n(weight_n * ray_features_n)`.
-    See below for possible values of `cap_fn` and `weight_fn`.
+
+    By default, for the EA raymarcher from [1] (
+    activated with `self.raymarcher_class_type="EmissionAbsorptionRaymarcher"`
+    ):
+    ```
+    cap_fn(x) = 1 - exp(-x),
+    weight_fn(x) = w * x.
+    ```
+    Note that the latter can be altered by changing `self.raymarcher_class_type`,
+    e.g. to "CumsumRaymarcher" which implements the cumulative-sum raymarcher
+    from NeuralVolumes [2].

     Settings:
         n_pts_per_ray_fine_training: The number of points sampled per ray for the
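For reference, a small numeric sketch of the two capping/weighting schemes named in the docstring above, written directly against torch rather than through the raymarcher API (the tensors here are made-up examples):

import torch

x = torch.tensor([0.0, 0.5, 2.0])   # per-point density * step length along a ray
w = torch.tensor([1.0, 0.8, 0.3])   # remaining (accumulated) absorption weight

# EmissionAbsorptionRaymarcher: cap_fn(x) = 1 - exp(-x), weight_fn(w, x) = w * x
ea_capped = 1.0 - torch.exp(-x)
ea_weights = w * ea_capped

# CumsumRaymarcher (Neural Volumes): cap_fn(x) = min(x, 1), weight_fn(w, x) = min(w, x)
cs_capped = x.clamp(max=1.0)
cs_weights = torch.minimum(w, cs_capped)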
@@ -46,42 +60,33 @@ class MultiPassEmissionAbsorptionRenderer(BaseRenderer, torch.nn.Module):
             evaluation.
         append_coarse_samples_to_fine: Add the fine ray points to the coarse points
             after sampling.
-        bg_color: The background color. A tuple of either 1 element or of D elements,
-            where D matches the feature dimensionality; it is broadcasted when necessary.
         density_noise_std_train: Standard deviation of the noise added to the
             opacity field.
-        capping_function: The capping function of the raymarcher.
-            Options:
-                - "exponential" (`cap_fn(x) = 1 - exp(-x)`)
-                - "cap1" (`cap_fn(x) = min(x, 1)`)
-            Set to "exponential" for the standard Emission Absorption raymarching.
-        weight_function: The weighting function of the raymarcher.
-            Options:
-                - "product" (`weight_fn(w, x) = w * x`)
-                - "minimum" (`weight_fn(w, x) = min(w, x)`)
-            Set to "product" for the standard Emission Absorption raymarching.
-        background_opacity: The raw opacity value (i.e. before exponentiation)
-            of the background.
-        blend_output: If `True`, alpha-blends the output renders with the
-            background color using the rendered opacity mask.
+        return_weights: Enables returning the rendering weights of the EA raymarcher.
+            Setting to `True` can lead to a prohibitively large memory consumption.
+        raymarcher_class_type: The type of self.raymarcher corresponding to
+            a child of `RaymarcherBase` in the registry.
+        raymarcher: The raymarcher object used to convert per-point features
+            and opacities to a feature render.

     References:
-        [1] Mildenhall, Ben, et al. "Nerf: Representing scenes as neural radiance
-            fields for view synthesis." ECCV 2020.
+        [1] Mildenhall, Ben, et al. "Nerf: Representing Scenes as Neural Radiance
+            Fields for View Synthesis." ECCV 2020.
+        [2] Lombardi, Stephen, et al. "Neural Volumes: Learning Dynamic Renderable
+            Volumes from Images." SIGGRAPH 2019.
     """

+    raymarcher_class_type: str = "EmissionAbsorptionRaymarcher"
+    raymarcher: RaymarcherBase
+
     n_pts_per_ray_fine_training: int = 64
     n_pts_per_ray_fine_evaluation: int = 64
     stratified_sampling_coarse_training: bool = True
     stratified_sampling_coarse_evaluation: bool = False
     append_coarse_samples_to_fine: bool = True
-    bg_color: Tuple[float, ...] = (0.0,)
     density_noise_std_train: float = 0.0
-    capping_function: str = "exponential"  # exponential | cap1
-    weight_function: str = "product"  # product | minimum
-    background_opacity: float = 1e10
-    blend_output: bool = False
+    return_weights: bool = False

     def __post_init__(self):
         super().__init__()
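The new `raymarcher_class_type` / `raymarcher` pair follows the implicitron pattern for replaceable members, so the chosen raymarcher's own parameters are expected to be exposed under a `raymarcher_<ClassType>_args` sub-config. A hedged configuration sketch; the exact key names and the use of `get_default_args` are assumptions based on that convention, not shown in this diff:

from pytorch3d.implicitron.models.renderer.multipass_ea import (
    MultiPassEmissionAbsorptionRenderer,
)
from pytorch3d.implicitron.tools.config import get_default_args

cfg = get_default_args(MultiPassEmissionAbsorptionRenderer)
cfg.raymarcher_class_type = "EmissionAbsorptionRaymarcher"
cfg.raymarcher_EmissionAbsorptionRaymarcher_args.blend_output = True  # assumed key
cfg.return_weights = True  # keep the per-point weights on the output

renderer = MultiPassEmissionAbsorptionRenderer(**cfg)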
@@ -97,20 +102,12 @@ class MultiPassEmissionAbsorptionRenderer(BaseRenderer, torch.nn.Module):
                 add_input_samples=self.append_coarse_samples_to_fine,
             ),
         }
-
-        self._raymarcher = GenericRaymarcher(
-            1,
-            self.bg_color,
-            capping_function=self.capping_function,
-            weight_function=self.weight_function,
-            background_opacity=self.background_opacity,
-            blend_output=self.blend_output,
-        )
+        run_auto_creation(self)

     def forward(
         self,
-        ray_bundle,
-        implicit_functions=[],
+        ray_bundle: RayBundle,
+        implicit_functions: List[ImplicitFunctionWrapper] = [],
         evaluation_mode: EvaluationMode = EvaluationMode.EVALUATION,
         **kwargs
     ) -> RendererOutput:
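The manual `GenericRaymarcher(...)` construction above is replaced by `run_auto_creation(self)`, which builds `self.raymarcher` from `raymarcher_class_type` and the corresponding class registered in the registry. A self-contained toy of that pattern, using made-up names (`Greeter`, `LoudGreeter`, `Owner`) purely for illustration:

from pytorch3d.implicitron.tools.config import (
    Configurable,
    ReplaceableBase,
    expand_args_fields,
    registry,
    run_auto_creation,
)


class Greeter(ReplaceableBase):
    """Pluggable base class; concrete choices are registered below."""

    def greet(self) -> str:
        raise NotImplementedError()


@registry.register
class LoudGreeter(Greeter):
    def greet(self) -> str:
        return "HELLO"


class Owner(Configurable):
    # The `<member>` / `<member>_class_type` pair mirrors `raymarcher` /
    # `raymarcher_class_type` in the renderer above.
    greeter: Greeter
    greeter_class_type: str = "LoudGreeter"

    def __post_init__(self):
        run_auto_creation(self)  # instantiates self.greeter from the registry


expand_args_fields(Owner)
assert Owner().greeter.greet() == "HELLO"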
@@ -149,14 +146,16 @@ class MultiPassEmissionAbsorptionRenderer(BaseRenderer, torch.nn.Module):
             else 0.0
         )

-        features, depth, mask, weights, aux = self._raymarcher(
+        output = self.raymarcher(
             *implicit_functions[0](ray_bundle),
             ray_lengths=ray_bundle.lengths,
             density_noise_std=density_noise_std,
         )
-        output = RendererOutput(
-            features=features, depths=depth, masks=mask, aux=aux, prev_stage=prev_stage
-        )
+        output.prev_stage = prev_stage
+
+        weights = output.weights
+        if not self.return_weights:
+            output.weights = None

         # we may need to make a recursive call
         if len(implicit_functions) > 1:
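A standalone illustration of the weight-gating logic added above: the raymarcher always produces the `(..., n_points_per_ray)` weight tensor (used locally for the fine re-sampling pass), but it is stripped from the returned output unless `return_weights` is enabled. The tensors below are dummies:

import torch
from pytorch3d.implicitron.models.renderer.base import RendererOutput

return_weights = False
output = RendererOutput(
    features=torch.zeros(2, 3),
    depths=torch.zeros(2, 1),
    masks=torch.zeros(2, 1),
    weights=torch.zeros(2, 64),
)

weights = output.weights      # kept for importance sampling of the next pass
if not return_weights:
    output.weights = None     # avoid returning a potentially huge tensor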
@@ -4,51 +4,99 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

-from typing import Any, Callable, Dict, Tuple, Union
+from typing import Any, Callable, Dict, Tuple

 import torch
+from pytorch3d.implicitron.models.renderer.base import RendererOutput
+from pytorch3d.implicitron.tools.config import ReplaceableBase, registry
 from pytorch3d.renderer.implicit.raymarching import _check_raymarcher_inputs


 _TTensor = torch.Tensor


-class GenericRaymarcher(torch.nn.Module):
+class RaymarcherBase(ReplaceableBase):
+    """
+    Defines a base class for raymarchers. Specifically, a raymarcher is responsible
+    for taking a set of features and density descriptors along rendering rays
+    and marching along them in order to generate a feature render.
+    """
+
+    def __init__(self):
+        super().__init__()
+
+    def forward(
+        self,
+        rays_densities: torch.Tensor,
+        rays_features: torch.Tensor,
+        aux: Dict[str, Any],
+    ) -> RendererOutput:
+        """
+        Args:
+            rays_densities: Per-ray density values represented with a tensor
+                of shape `(..., n_points_per_ray, 1)`.
+            rays_features: Per-ray feature values represented with a tensor
+                of shape `(..., n_points_per_ray, feature_dim)`.
+            aux: a dictionary with extra information.
+        """
+        raise NotImplementedError()
+
+
+class AccumulativeRaymarcherBase(RaymarcherBase, torch.nn.Module):
     """
     This generalizes the `pytorch3d.renderer.EmissionAbsorptionRaymarcher`
-    and NeuralVolumes' Accumulative ray marcher. It additionally returns
+    and NeuralVolumes' cumsum ray marcher. It additionally returns
     the rendering weights that can be used in the NVS pipeline to carry out
     the importance ray-sampling in the refining pass.
-    Different from `EmissionAbsorptionRaymarcher`, it takes raw
+    Different from `pytorch3d.renderer.EmissionAbsorptionRaymarcher`, it takes raw
     (non-exponentiated) densities.

     Args:
-        bg_color: background_color. Must be of shape (1,) or (feature_dim,)
+        surface_thickness: The thickness of the raymarched surface.
+        bg_color: The background color. A tuple of either 1 element or of D elements,
+            where D matches the feature dimensionality; it is broadcast when necessary.
+        background_opacity: The raw opacity value (i.e. before exponentiation)
+            of the background.
+        density_relu: If `True`, passes the input density through ReLU before
+            raymarching.
+        blend_output: If `True`, alpha-blends the output renders with the
+            background color using the rendered opacity mask.
+
+        capping_function: The capping function of the raymarcher.
+            Options:
+                - "exponential" (`cap_fn(x) = 1 - exp(-x)`)
+                - "cap1" (`cap_fn(x) = min(x, 1)`)
+            Set to "exponential" for the standard Emission Absorption raymarching.
+        weight_function: The weighting function of the raymarcher.
+            Options:
+                - "product" (`weight_fn(w, x) = w * x`)
+                - "minimum" (`weight_fn(w, x) = min(w, x)`)
+            Set to "product" for the standard Emission Absorption raymarching.
     """

-    def __init__(
-        self,
-        surface_thickness: int = 1,
-        bg_color: Union[Tuple[float, ...], _TTensor] = (0.0,),
-        capping_function: str = "exponential",  # exponential | cap1
-        weight_function: str = "product",  # product | minimum
-        background_opacity: float = 0.0,
-        density_relu: bool = True,
-        blend_output: bool = True,
-    ):
+    surface_thickness: int = 1
+    bg_color: Tuple[float, ...] = (0.0,)
+    background_opacity: float = 0.0
+    density_relu: bool = True
+    blend_output: bool = False
+
+    @property
+    def capping_function_type(self) -> str:
+        raise NotImplementedError()
+
+    @property
+    def weight_function_type(self) -> str:
+        raise NotImplementedError()
+
+    def __post_init__(self):
         """
         Args:
             surface_thickness: Denotes the overlap between the absorption
                 function and the density function.
         """
         super().__init__()
-        self.surface_thickness = surface_thickness
-        self.density_relu = density_relu
-        self.background_opacity = background_opacity
-        self.blend_output = blend_output
-        if not isinstance(bg_color, torch.Tensor):
-            bg_color = torch.tensor(bg_color)

+        bg_color = torch.tensor(self.bg_color)
         if bg_color.ndim != 1:
             raise ValueError(f"bg_color (shape {bg_color.shape}) should be a 1D tensor")

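With `RaymarcherBase` now a `ReplaceableBase`, a project can plug in its own raymarcher by registering a subclass, which then becomes selectable through `raymarcher_class_type`. A hedged sketch with a made-up `MeanFeatureRaymarcher` (the module path is assumed from the diff):

from typing import Any, Dict

import torch
from pytorch3d.implicitron.models.renderer.base import RendererOutput
from pytorch3d.implicitron.models.renderer.raymarcher import RaymarcherBase
from pytorch3d.implicitron.tools.config import registry


@registry.register
class MeanFeatureRaymarcher(RaymarcherBase, torch.nn.Module):
    """Toy raymarcher: averages features along the ray and ignores densities."""

    def __post_init__(self):
        super().__init__()  # initialize torch.nn.Module, as AccumulativeRaymarcherBase does

    def forward(
        self,
        rays_densities: torch.Tensor,   # (..., n_points_per_ray, 1)
        rays_features: torch.Tensor,    # (..., n_points_per_ray, feature_dim)
        aux: Dict[str, Any],
        ray_lengths: torch.Tensor,      # (..., n_points_per_ray)
        **kwargs,
    ) -> RendererOutput:
        features = rays_features.mean(dim=-2)
        depths = ray_lengths.mean(dim=-1, keepdim=True)
        masks = rays_densities[..., 0].max(dim=-1, keepdim=True).values
        return RendererOutput(features=features, depths=depths, masks=masks, aux=aux)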
@@ -57,12 +105,12 @@ class GenericRaymarcher(torch.nn.Module):
         self._capping_function: Callable[[_TTensor], _TTensor] = {
             "exponential": lambda x: 1.0 - torch.exp(-x),
             "cap1": lambda x: x.clamp(max=1.0),
-        }[capping_function]
+        }[self.capping_function_type]

         self._weight_function: Callable[[_TTensor, _TTensor], _TTensor] = {
             "product": lambda curr, acc: curr * acc,
             "minimum": lambda curr, acc: torch.minimum(curr, acc),
-        }[weight_function]
+        }[self.weight_function_type]

     def forward(
         self,
@@ -71,7 +119,8 @@ class GenericRaymarcher(torch.nn.Module):
         aux: Dict[str, Any],
         ray_lengths: torch.Tensor,
         density_noise_std: float = 0.0,
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Dict[str, Any]]:
+        **kwargs,
+    ) -> RendererOutput:
         """
         Args:
             rays_densities: Per-ray density values represented with a tensor
@@ -87,7 +136,7 @@ class GenericRaymarcher(torch.nn.Module):
             features: A tensor of shape `(..., feature_dim)` containing
                 the rendered features for each ray.
             depth: A tensor of shape `(..., 1)` containing estimated depth.
-            opacities: A tensor of shape `(..., 1)` containing rendered opacsities.
+            opacities: A tensor of shape `(..., 1)` containing rendered opacities.
             weights: A tensor of shape `(..., n_points_per_ray)` containing
                 the ray-specific non-negative opacity weights. In general, they
                 don't sum to 1 but do not overcome it, i.e.
@@ -113,16 +162,15 @@
         rays_densities = rays_densities[..., 0]

         if density_noise_std > 0.0:
-            rays_densities = (
-                rays_densities + torch.randn_like(rays_densities) * density_noise_std
-            )
+            noise: _TTensor = torch.randn_like(rays_densities).mul(density_noise_std)
+            rays_densities = rays_densities + noise
         if self.density_relu:
             rays_densities = torch.relu(rays_densities)

         weighted_densities = deltas * rays_densities
-        capped_densities = self._capping_function(weighted_densities)
+        capped_densities = self._capping_function(weighted_densities)  # pyre-ignore: 29

-        rays_opacities = self._capping_function(
+        rays_opacities = self._capping_function(  # pyre-ignore: 29
             torch.cumsum(weighted_densities, dim=-1)
         )
         opacities = rays_opacities[..., -1:]
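For intuition, a standalone numeric sketch of what the emission-absorption branch of this method computes, in a mathematically equivalent form (for `surface_thickness=1`): each weight is the capped per-point density attenuated by the absorption of the points in front of it, and the weights sum to the ray opacity, which never exceeds 1.

import torch

torch.manual_seed(0)
weighted = torch.rand(2, 8)                  # deltas * raw densities, per ray point

capped = 1.0 - torch.exp(-weighted)          # cap_fn("exponential")
cum = torch.cumsum(weighted, dim=-1)
transmittance = torch.exp(-(cum - weighted)) # absorption by the preceding points
weights = capped * transmittance             # weight_fn("product")

opacity = 1.0 - torch.exp(-cum[..., -1:])    # total ray opacity
assert torch.allclose(weights.sum(dim=-1, keepdim=True), opacity, atol=1e-5)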
@@ -131,7 +179,9 @@
         )
         absorption_shifted[..., : self.surface_thickness] = 1.0

-        weights = self._weight_function(capped_densities, absorption_shifted)
+        weights = self._weight_function(  # pyre-ignore: 29
+            capped_densities, absorption_shifted
+        )
         features = (weights[..., None] * rays_features).sum(dim=-2)
         depth = (weights * ray_lengths)[..., None].sum(dim=-2)

@@ -140,4 +190,42 @@
                 raise ValueError("Wrong number of background color channels.")
             features = alpha * features + (1 - opacities) * self._bg_color

-        return features, depth, opacities, weights, aux
+        return RendererOutput(
+            features=features,
+            depths=depth,
+            masks=opacities,
+            weights=weights,
+            aux=aux,
+        )
+
+
+@registry.register
+class EmissionAbsorptionRaymarcher(AccumulativeRaymarcherBase):
+    """
+    Implements the EmissionAbsorption raymarcher.
+    """
+
+    background_opacity: float = 1e10
+
+    @property
+    def capping_function_type(self) -> str:
+        return "exponential"
+
+    @property
+    def weight_function_type(self) -> str:
+        return "product"
+
+
+@registry.register
+class CumsumRaymarcher(AccumulativeRaymarcherBase):
+    """
+    Implements the NeuralVolumes' cumulative-sum raymarcher.
+    """
+
+    @property
+    def capping_function_type(self) -> str:
+        return "cap1"
+
+    @property
+    def weight_function_type(self) -> str:
+        return "minimum"
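Finally, a hedged end-to-end sketch calling the two registered raymarchers directly on dummy inputs (the module path and the use of `expand_args_fields` for construction are assumptions consistent with the implicitron config system, not taken from this diff):

import torch
from pytorch3d.implicitron.models.renderer.raymarcher import (
    CumsumRaymarcher,
    EmissionAbsorptionRaymarcher,
)
from pytorch3d.implicitron.tools.config import expand_args_fields

n_rays, n_pts, feat_dim = 5, 32, 3
rays_densities = torch.rand(n_rays, n_pts, 1)
rays_features = torch.rand(n_rays, n_pts, feat_dim)
ray_lengths = torch.linspace(0.1, 2.0, n_pts).expand(n_rays, n_pts)

for cls in (EmissionAbsorptionRaymarcher, CumsumRaymarcher):
    expand_args_fields(cls)
    raymarcher = cls()
    out = raymarcher(rays_densities, rays_features, {}, ray_lengths=ray_lengths)
    # Each raymarcher returns a RendererOutput with features, depths, masks and weights.
    print(cls.__name__, out.features.shape, out.depths.shape, out.weights.shape)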