Raysampler as pluggable

Summary:
This converts raysamplers to ReplaceableBase so that users can hack their own raysampling impls.

Context: Andrea tried to implement TensoRF within implicitron but could not due to the need to implement his own raysampler.

Reviewed By: shapovalov

Differential Revision: D36016318

fbshipit-source-id: ef746f3365282bdfa9c15f7b371090a5aae7f8da
This commit is contained in:
David Novotny 2022-05-12 15:39:35 -07:00 committed by Facebook GitHub Bot
parent e85fa03c5a
commit e767c4b548
16 changed files with 185 additions and 96 deletions

View File

@ -49,10 +49,8 @@ generic_model_args:
append_xyz: append_xyz:
- 5 - 5
latent_dim: 0 latent_dim: 0
raysampler_args: raysampler_AdaptiveRaySampler_args:
n_rays_per_image_sampled_from_mask: 1024 n_rays_per_image_sampled_from_mask: 1024
min_depth: 0.0
max_depth: 0.0
scene_extent: 8.0 scene_extent: 8.0
n_pts_per_ray_training: 64 n_pts_per_ray_training: 64
n_pts_per_ray_evaluation: 64 n_pts_per_ray_evaluation: 64

View File

@ -54,7 +54,7 @@ generic_model_args:
n_harmonic_functions_dir: 4 n_harmonic_functions_dir: 4
pooled_feature_dim: 0 pooled_feature_dim: 0
weight_norm: true weight_norm: true
raysampler_args: raysampler_AdaptiveRaySampler_args:
n_rays_per_image_sampled_from_mask: 1024 n_rays_per_image_sampled_from_mask: 1024
n_pts_per_ray_training: 0 n_pts_per_ray_training: 0
n_pts_per_ray_evaluation: 0 n_pts_per_ray_evaluation: 0

View File

@ -6,5 +6,5 @@ clip_grad: 1.0
generic_model_args: generic_model_args:
chunk_size_grid: 16000 chunk_size_grid: 16000
view_pooler_enabled: true view_pooler_enabled: true
raysampler_args: raysampler_AdaptiveRaySampler_args:
n_rays_per_image_sampled_from_mask: 850 n_rays_per_image_sampled_from_mask: 850

View File

@ -4,7 +4,7 @@ defaults:
- _self_ - _self_
generic_model_args: generic_model_args:
chunk_size_grid: 16000 chunk_size_grid: 16000
raysampler_args: raysampler_AdaptiveRaySampler_args:
n_rays_per_image_sampled_from_mask: 800 n_rays_per_image_sampled_from_mask: 800
n_pts_per_ray_training: 32 n_pts_per_ray_training: 32
n_pts_per_ray_evaluation: 32 n_pts_per_ray_evaluation: 32

View File

@ -1,17 +1,6 @@
defaults: defaults:
- repro_multiseq_base.yaml - repro_multiseq_nerformer.yaml
- repro_feat_extractor_transformer.yaml
- _self_ - _self_
generic_model_args: generic_model_args:
chunk_size_grid: 16000
raysampler_args:
n_rays_per_image_sampled_from_mask: 800
n_pts_per_ray_training: 32
n_pts_per_ray_evaluation: 32
renderer_MultiPassEmissionAbsorptionRenderer_args:
n_pts_per_ray_fine_training: 16
n_pts_per_ray_fine_evaluation: 16
implicit_function_class_type: NeRFormerImplicitFunction
view_pooler_enabled: true
view_pooler_args: view_pooler_args:
feature_aggregator_class_type: AngleWeightedIdentityFeatureAggregator feature_aggregator_class_type: AngleWeightedIdentityFeatureAggregator

View File

@ -16,11 +16,11 @@ generic_model_args:
sequence_autodecoder_args: sequence_autodecoder_args:
encoding_dim: 256 encoding_dim: 256
n_instances: 20000 n_instances: 20000
raysampler_args: raysampler_class_type: NearFarRaySampler
raysampler_NearFarRaySampler_args:
n_rays_per_image_sampled_from_mask: 2048 n_rays_per_image_sampled_from_mask: 2048
min_depth: 0.05 min_depth: 0.05
max_depth: 0.05 max_depth: 0.05
scene_extent: 0.0
n_pts_per_ray_training: 1 n_pts_per_ray_training: 1
n_pts_per_ray_evaluation: 1 n_pts_per_ray_evaluation: 1
stratified_point_sampling_training: false stratified_point_sampling_training: false

View File

@ -13,11 +13,11 @@ generic_model_args:
loss_prev_stage_mask_bce: 0.0 loss_prev_stage_mask_bce: 0.0
loss_autodecoder_norm: 0.0 loss_autodecoder_norm: 0.0
depth_neg_penalty: 10000.0 depth_neg_penalty: 10000.0
raysampler_args: raysampler_class_type: NearFarRaySampler
raysampler_NearFarRaySampler_args:
n_rays_per_image_sampled_from_mask: 2048 n_rays_per_image_sampled_from_mask: 2048
min_depth: 0.05 min_depth: 0.05
max_depth: 0.05 max_depth: 0.05
scene_extent: 0.0
n_pts_per_ray_training: 1 n_pts_per_ray_training: 1
n_pts_per_ray_evaluation: 1 n_pts_per_ray_evaluation: 1
stratified_point_sampling_training: false stratified_point_sampling_training: false

View File

@ -49,7 +49,7 @@ generic_model_args:
n_harmonic_functions_dir: 4 n_harmonic_functions_dir: 4
pooled_feature_dim: 0 pooled_feature_dim: 0
weight_norm: true weight_norm: true
raysampler_args: raysampler_AdaptiveRaySampler_args:
n_rays_per_image_sampled_from_mask: 1024 n_rays_per_image_sampled_from_mask: 1024
n_pts_per_ray_training: 0 n_pts_per_ray_training: 0
n_pts_per_ray_evaluation: 0 n_pts_per_ray_evaluation: 0

View File

@ -1,4 +1,3 @@
defaults: defaults:
- repro_singleseq_base - repro_singleseq_base
- _self_ - _self_
exp_dir: ./data/nerf_single_apple/

View File

@ -5,5 +5,5 @@ defaults:
generic_model_args: generic_model_args:
chunk_size_grid: 16000 chunk_size_grid: 16000
view_pooler_enabled: true view_pooler_enabled: true
raysampler_args: raysampler_AdaptiveRaySampler_args:
n_rays_per_image_sampled_from_mask: 850 n_rays_per_image_sampled_from_mask: 850

View File

@ -6,7 +6,7 @@ generic_model_args:
chunk_size_grid: 16000 chunk_size_grid: 16000
view_pooler_enabled: true view_pooler_enabled: true
implicit_function_class_type: NeRFormerImplicitFunction implicit_function_class_type: NeRFormerImplicitFunction
raysampler_args: raysampler_AdaptiveRaySampler_args:
n_rays_per_image_sampled_from_mask: 800 n_rays_per_image_sampled_from_mask: 800
n_pts_per_ray_training: 32 n_pts_per_ray_training: 32
n_pts_per_ray_evaluation: 32 n_pts_per_ray_evaluation: 32

View File

@ -12,11 +12,11 @@ generic_model_args:
loss_prev_stage_mask_bce: 0.0 loss_prev_stage_mask_bce: 0.0
loss_autodecoder_norm: 0.0 loss_autodecoder_norm: 0.0
depth_neg_penalty: 10000.0 depth_neg_penalty: 10000.0
raysampler_args: raysampler_class_type: NearFarRaySampler
raysampler_NearFarRaySampler_args:
n_rays_per_image_sampled_from_mask: 2048 n_rays_per_image_sampled_from_mask: 2048
min_depth: 0.05 min_depth: 0.05
max_depth: 0.05 max_depth: 0.05
scene_extent: 0.0
n_pts_per_ray_training: 1 n_pts_per_ray_training: 1
n_pts_per_ray_evaluation: 1 n_pts_per_ray_evaluation: 1
stratified_point_sampling_training: false stratified_point_sampling_training: false

View File

@ -13,11 +13,11 @@ generic_model_args:
loss_prev_stage_mask_bce: 0.0 loss_prev_stage_mask_bce: 0.0
loss_autodecoder_norm: 0.0 loss_autodecoder_norm: 0.0
depth_neg_penalty: 10000.0 depth_neg_penalty: 10000.0
raysampler_args: raysampler_class_type: NearFarRaySampler
raysampler_NearFarRaySampler_args:
n_rays_per_image_sampled_from_mask: 2048 n_rays_per_image_sampled_from_mask: 2048
min_depth: 0.05 min_depth: 0.05
max_depth: 0.05 max_depth: 0.05
scene_extent: 0.0
n_pts_per_ray_training: 1 n_pts_per_ray_training: 1
n_pts_per_ray_evaluation: 1 n_pts_per_ray_evaluation: 1
stratified_point_sampling_training: false stratified_point_sampling_training: false

View File

@ -19,7 +19,7 @@ from pytorch3d.implicitron.tools.config import (
run_auto_creation, run_auto_creation,
) )
from pytorch3d.implicitron.tools.rasterize_mc import rasterize_mc_samples from pytorch3d.implicitron.tools.rasterize_mc import rasterize_mc_samples
from pytorch3d.implicitron.tools.utils import cat_dataclass from pytorch3d.implicitron.tools.utils import cat_dataclass, setattr_if_hasattr
from pytorch3d.renderer import RayBundle, utils as rend_utils from pytorch3d.renderer import RayBundle, utils as rend_utils
from pytorch3d.renderer.cameras import CamerasBase from pytorch3d.renderer.cameras import CamerasBase
from visdom import Visdom from visdom import Visdom
@ -46,7 +46,7 @@ from .renderer.base import (
) )
from .renderer.lstm_renderer import LSTMRenderer # noqa from .renderer.lstm_renderer import LSTMRenderer # noqa
from .renderer.multipass_ea import MultiPassEmissionAbsorptionRenderer # noqa from .renderer.multipass_ea import MultiPassEmissionAbsorptionRenderer # noqa
from .renderer.ray_sampler import RaySampler from .renderer.ray_sampler import RaySamplerBase
from .renderer.sdf_renderer import SignedDistanceFunctionRenderer # noqa from .renderer.sdf_renderer import SignedDistanceFunctionRenderer # noqa
from .resnet_feature_extractor import ResNetFeatureExtractor from .resnet_feature_extractor import ResNetFeatureExtractor
from .view_pooler.view_pooler import ViewPooler from .view_pooler.view_pooler import ViewPooler
@ -160,6 +160,8 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
the scene such as multiple objects or morphing objects. It is up to the implicit the scene such as multiple objects or morphing objects. It is up to the implicit
function definition how to use it, but the most typical way is to broadcast and function definition how to use it, but the most typical way is to broadcast and
concatenate to the other inputs for the implicit function. concatenate to the other inputs for the implicit function.
raysampler_class_type: The name of the raysampler class which is available
in the global registry.
raysampler: An instance of RaySampler which is used to emit raysampler: An instance of RaySampler which is used to emit
rays from the target view(s). rays from the target view(s).
renderer_class_type: The name of the renderer class which is available in the global renderer_class_type: The name of the renderer class which is available in the global
@ -204,7 +206,8 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
sequence_autodecoder: Autodecoder sequence_autodecoder: Autodecoder
# ---- raysampler # ---- raysampler
raysampler: RaySampler raysampler_class_type: str = "AdaptiveRaySampler"
raysampler: RaySamplerBase
# ---- renderer configs # ---- renderer configs
renderer_class_type: str = "MultiPassEmissionAbsorptionRenderer" renderer_class_type: str = "MultiPassEmissionAbsorptionRenderer"
@ -262,11 +265,6 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
super().__init__() super().__init__()
self.view_metrics = ViewMetrics() self.view_metrics = ViewMetrics()
self._check_and_preprocess_renderer_configs()
self.raysampler_args["sampling_mode_training"] = self.sampling_mode_training
self.raysampler_args["sampling_mode_evaluation"] = self.sampling_mode_evaluation
self.raysampler_args["image_width"] = self.render_image_width
self.raysampler_args["image_height"] = self.render_image_height
run_auto_creation(self) run_auto_creation(self)
self._implicit_functions = self._construct_implicit_functions() self._implicit_functions = self._construct_implicit_functions()
@ -339,7 +337,7 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
) )
# (1) Sample rendering rays with the ray sampler. # (1) Sample rendering rays with the ray sampler.
ray_bundle: RayBundle = self.raysampler( ray_bundle: RayBundle = self.raysampler( # pyre-fixme[29]
target_cameras, target_cameras,
evaluation_mode, evaluation_mode,
mask=mask_crop[:n_targets] mask=mask_crop[:n_targets]
@ -565,19 +563,52 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
else 0 else 0
) )
def _check_and_preprocess_renderer_configs(self): def create_raysampler(self):
raysampler_args = getattr(
self, "raysampler_" + self.raysampler_class_type + "_args"
)
setattr_if_hasattr(
raysampler_args, "sampling_mode_training", self.sampling_mode_training
)
setattr_if_hasattr(
raysampler_args, "sampling_mode_evaluation", self.sampling_mode_evaluation
)
setattr_if_hasattr(raysampler_args, "image_width", self.render_image_width)
setattr_if_hasattr(raysampler_args, "image_height", self.render_image_height)
self.raysampler = registry.get(RaySamplerBase, self.raysampler_class_type)(
**raysampler_args
)
def create_renderer(self):
raysampler_args = getattr(
self, "raysampler_" + self.raysampler_class_type + "_args"
)
self.renderer_MultiPassEmissionAbsorptionRenderer_args[ self.renderer_MultiPassEmissionAbsorptionRenderer_args[
"stratified_sampling_coarse_training" "stratified_sampling_coarse_training"
] = self.raysampler_args["stratified_point_sampling_training"] ] = raysampler_args["stratified_point_sampling_training"]
self.renderer_MultiPassEmissionAbsorptionRenderer_args[ self.renderer_MultiPassEmissionAbsorptionRenderer_args[
"stratified_sampling_coarse_evaluation" "stratified_sampling_coarse_evaluation"
] = self.raysampler_args["stratified_point_sampling_evaluation"] ] = raysampler_args["stratified_point_sampling_evaluation"]
self.renderer_SignedDistanceFunctionRenderer_args[ self.renderer_SignedDistanceFunctionRenderer_args[
"render_features_dimensions" "render_features_dimensions"
] = self.render_features_dimensions ] = self.render_features_dimensions
self.renderer_SignedDistanceFunctionRenderer_args.ray_tracer_args[
"object_bounding_sphere" if self.renderer_class_type == "SignedDistanceFunctionRenderer":
] = self.raysampler_args["scene_extent"] if "scene_extent" not in raysampler_args:
raise ValueError(
"SignedDistanceFunctionRenderer requires"
+ " a raysampler that defines the 'scene_extent' field"
+ " (this field is supported by, e.g., the adaptive raysampler - "
+ " self.raysampler_class_type='AdaptiveRaySampler')."
)
self.renderer_SignedDistanceFunctionRenderer_args.ray_tracer_args[
"object_bounding_sphere"
] = self.raysampler_AdaptiveRaySampler_args["scene_extent"]
renderer_args = getattr(self, "renderer_" + self.renderer_class_type + "_args")
self.renderer = registry.get(BaseRenderer, self.renderer_class_type)(
**renderer_args
)
def create_view_pooler(self): def create_view_pooler(self):
""" """

View File

@ -9,16 +9,48 @@ from typing import Optional, Tuple
import torch import torch
from pytorch3d.implicitron.tools import camera_utils from pytorch3d.implicitron.tools import camera_utils
from pytorch3d.implicitron.tools.config import Configurable from pytorch3d.implicitron.tools.config import ReplaceableBase, registry
from pytorch3d.renderer import NDCMultinomialRaysampler, RayBundle from pytorch3d.renderer import NDCMultinomialRaysampler, RayBundle
from pytorch3d.renderer.cameras import CamerasBase from pytorch3d.renderer.cameras import CamerasBase
from .base import EvaluationMode, RenderSamplingMode from .base import EvaluationMode, RenderSamplingMode
class RaySampler(Configurable, torch.nn.Module): class RaySamplerBase(ReplaceableBase):
"""
Base class for ray samplers.
""" """
def __init__(self):
super().__init__()
def forward(
self,
cameras: CamerasBase,
evaluation_mode: EvaluationMode,
mask: Optional[torch.Tensor] = None,
) -> RayBundle:
"""
Args:
cameras: A batch of `batch_size` cameras from which the rays are emitted.
evaluation_mode: one of `EvaluationMode.TRAINING` or
`EvaluationMode.EVALUATION` which determines the sampling mode
that is used.
mask: Active for the `RenderSamplingMode.MASK_SAMPLE` sampling mode.
Defines a non-negative mask of shape
`(batch_size, image_height, image_width)` where each per-pixel
value is proportional to the probability of sampling the
corresponding pixel's ray.
Returns:
ray_bundle: A `RayBundle` object containing the parametrizations of the
sampled rendering rays.
"""
raise NotImplementedError()
class AbstractMaskRaySampler(RaySamplerBase, torch.nn.Module):
"""
Samples a fixed number of points along rays which are in turn sampled for Samples a fixed number of points along rays which are in turn sampled for
each camera in a batch. each camera in a batch.
@ -29,46 +61,19 @@ class RaySampler(Configurable, torch.nn.Module):
for training and evaluation by setting `self.sampling_mode_training` for training and evaluation by setting `self.sampling_mode_training`
and `self.sampling_mode_training` accordingly. and `self.sampling_mode_training` accordingly.
The class allows two modes of sampling points along the rays: The class allows to adjust the sampling points along rays by overwriting the
1) Sampling between fixed near and far z-planes: `AbstractMaskRaySampler._get_min_max_depth_bounds` function which returns
Active when `self.scene_extent <= 0`, samples points along each ray the near/far planes (`min_depth`/`max_depth`) `NDCMultinomialRaysampler`.
with approximately uniform spacing of z-coordinates between
the minimum depth `self.min_depth` and the maximum depth `self.max_depth`.
This sampling is useful for rendering scenes where the camera is
in a constant distance from the focal point of the scene.
2) Adaptive near/far plane estimation around the world scene center:
Active when `self.scene_extent > 0`. Samples points on each
ray between near and far planes whose depths are determined based on
the distance from the camera center to a predefined scene center.
More specifically,
`min_depth = max(
(self.scene_center-camera_center).norm() - self.scene_extent, eps
)` and
`max_depth = (self.scene_center-camera_center).norm() + self.scene_extent`.
This sampling is ideal for object-centric scenes whose contents are
centered around a known `self.scene_center` and fit into a bounding sphere
with a radius of `self.scene_extent`.
Similar to the sampling mode, the sampling parameters can be set separately
for training and evaluation.
Settings: Settings:
image_width: The horizontal size of the image grid. image_width: The horizontal size of the image grid.
image_height: The vertical size of the image grid. image_height: The vertical size of the image grid.
scene_center: The xyz coordinates of the center of the scene used
along with `scene_extent` to compute the min and max depth planes
for sampling ray-points.
scene_extent: The radius of the scene bounding sphere centered at `scene_center`.
If `scene_extent <= 0`, the raysampler samples points between
`self.min_depth` and `self.max_depth` depths instead.
sampling_mode_training: The ray sampling mode for training. This should be a str sampling_mode_training: The ray sampling mode for training. This should be a str
option from the RenderSamplingMode Enum option from the RenderSamplingMode Enum
sampling_mode_evaluation: Same as above but for evaluation. sampling_mode_evaluation: Same as above but for evaluation.
n_pts_per_ray_training: The number of points sampled along each ray during training. n_pts_per_ray_training: The number of points sampled along each ray during training.
n_pts_per_ray_evaluation: The number of points sampled along each ray during evaluation. n_pts_per_ray_evaluation: The number of points sampled along each ray during evaluation.
n_rays_per_image_sampled_from_mask: The amount of rays to be sampled from the image grid n_rays_per_image_sampled_from_mask: The amount of rays to be sampled from the image grid
min_depth: The minimum depth of a ray-point. Active when `self.scene_extent > 0`.
max_depth: The maximum depth of a ray-point. Active when `self.scene_extent > 0`.
stratified_point_sampling_training: if set, performs stratified random sampling stratified_point_sampling_training: if set, performs stratified random sampling
along the ray; otherwise takes ray points at deterministic offsets. along the ray; otherwise takes ray points at deterministic offsets.
stratified_point_sampling_evaluation: Same as above but for evaluation. stratified_point_sampling_evaluation: Same as above but for evaluation.
@ -77,24 +82,17 @@ class RaySampler(Configurable, torch.nn.Module):
image_width: int = 400 image_width: int = 400
image_height: int = 400 image_height: int = 400
scene_center: Tuple[float, float, float] = field(
default_factory=lambda: (0.0, 0.0, 0.0)
)
scene_extent: float = 0.0
sampling_mode_training: str = "mask_sample" sampling_mode_training: str = "mask_sample"
sampling_mode_evaluation: str = "full_grid" sampling_mode_evaluation: str = "full_grid"
n_pts_per_ray_training: int = 64 n_pts_per_ray_training: int = 64
n_pts_per_ray_evaluation: int = 64 n_pts_per_ray_evaluation: int = 64
n_rays_per_image_sampled_from_mask: int = 1024 n_rays_per_image_sampled_from_mask: int = 1024
min_depth: float = 0.1
max_depth: float = 8.0
# stratified sampling vs taking points at deterministic offsets # stratified sampling vs taking points at deterministic offsets
stratified_point_sampling_training: bool = True stratified_point_sampling_training: bool = True
stratified_point_sampling_evaluation: bool = False stratified_point_sampling_evaluation: bool = False
def __post_init__(self): def __post_init__(self):
super().__init__() super().__init__()
self.scene_center = torch.FloatTensor(self.scene_center)
self._sampling_mode = { self._sampling_mode = {
EvaluationMode.TRAINING: RenderSamplingMode(self.sampling_mode_training), EvaluationMode.TRAINING: RenderSamplingMode(self.sampling_mode_training),
@ -108,8 +106,8 @@ class RaySampler(Configurable, torch.nn.Module):
image_width=self.image_width, image_width=self.image_width,
image_height=self.image_height, image_height=self.image_height,
n_pts_per_ray=self.n_pts_per_ray_training, n_pts_per_ray=self.n_pts_per_ray_training,
min_depth=self.min_depth, min_depth=0.0,
max_depth=self.max_depth, max_depth=0.0,
n_rays_per_image=self.n_rays_per_image_sampled_from_mask n_rays_per_image=self.n_rays_per_image_sampled_from_mask
if self._sampling_mode[EvaluationMode.TRAINING] if self._sampling_mode[EvaluationMode.TRAINING]
== RenderSamplingMode.MASK_SAMPLE == RenderSamplingMode.MASK_SAMPLE
@ -121,8 +119,8 @@ class RaySampler(Configurable, torch.nn.Module):
image_width=self.image_width, image_width=self.image_width,
image_height=self.image_height, image_height=self.image_height,
n_pts_per_ray=self.n_pts_per_ray_evaluation, n_pts_per_ray=self.n_pts_per_ray_evaluation,
min_depth=self.min_depth, min_depth=0.0,
max_depth=self.max_depth, max_depth=0.0,
n_rays_per_image=self.n_rays_per_image_sampled_from_mask n_rays_per_image=self.n_rays_per_image_sampled_from_mask
if self._sampling_mode[EvaluationMode.EVALUATION] if self._sampling_mode[EvaluationMode.EVALUATION]
== RenderSamplingMode.MASK_SAMPLE == RenderSamplingMode.MASK_SAMPLE
@ -132,6 +130,9 @@ class RaySampler(Configurable, torch.nn.Module):
), ),
} }
def _get_min_max_depth_bounds(self, cameras: CamerasBase) -> Tuple[float, float]:
raise NotImplementedError()
def forward( def forward(
self, self,
cameras: CamerasBase, cameras: CamerasBase,
@ -169,12 +170,7 @@ class RaySampler(Configurable, torch.nn.Module):
mode="nearest", mode="nearest",
)[:, 0] )[:, 0]
if self.scene_extent > 0.0: min_depth, max_depth = self._get_min_max_depth_bounds(cameras)
# Override the min/max depth set in initialization based on the
# input cameras.
min_depth, max_depth = camera_utils.get_min_max_depth_bounds(
cameras, self.scene_center, self.scene_extent
)
# pyre-fixme[29]: # pyre-fixme[29]:
# `Union[BoundMethod[typing.Callable(torch.Tensor.__getitem__)[[Named(self, # `Union[BoundMethod[typing.Callable(torch.Tensor.__getitem__)[[Named(self,
@ -183,8 +179,75 @@ class RaySampler(Configurable, torch.nn.Module):
ray_bundle = self._raysamplers[evaluation_mode]( ray_bundle = self._raysamplers[evaluation_mode](
cameras=cameras, cameras=cameras,
mask=sample_mask, mask=sample_mask,
min_depth=float(min_depth[0]) if self.scene_extent > 0.0 else None, min_depth=min_depth,
max_depth=float(max_depth[0]) if self.scene_extent > 0.0 else None, max_depth=max_depth,
) )
return ray_bundle return ray_bundle
@registry.register
class AdaptiveRaySampler(AbstractMaskRaySampler):
"""
Adaptively samples points on each ray between near and far planes whose
depths are determined based on the distance from the camera center
to a predefined scene center.
More specifically,
`min_depth = max(
(self.scene_center-camera_center).norm() - self.scene_extent, eps
)` and
`max_depth = (self.scene_center-camera_center).norm() + self.scene_extent`.
This sampling is ideal for object-centric scenes whose contents are
centered around a known `self.scene_center` and fit into a bounding sphere
with a radius of `self.scene_extent`.
Args:
scene_center: The xyz coordinates of the center of the scene used
along with `scene_extent` to compute the min and max depth planes
for sampling ray-points.
scene_extent: The radius of the scene bounding box centered at `scene_center`.
"""
scene_extent: float = 8.0
scene_center: Tuple[float, float, float] = (0.0, 0.0, 0.0)
def __post_init__(self):
super().__post_init__()
if self.scene_extent <= 0.0:
raise ValueError("Adaptive raysampler requires self.scene_extent > 0.")
self._scene_center = torch.FloatTensor(self.scene_center)
def _get_min_max_depth_bounds(self, cameras: CamerasBase) -> Tuple[float, float]:
"""
Returns the adaptivelly calculated near/far planes.
"""
min_depth, max_depth = camera_utils.get_min_max_depth_bounds(
cameras, self._scene_center, self.scene_extent
)
return float(min_depth[0]), float(max_depth[0])
@registry.register
class NearFarRaySampler(AbstractMaskRaySampler):
"""
Samples a fixed number of points between fixed near and far z-planes.
Specifically, samples points along each ray with approximately uniform spacing
of z-coordinates between the minimum depth `self.min_depth` and the maximum depth
`self.max_depth`. This sampling is useful for rendering scenes where the camera is
in a constant distance from the focal point of the scene.
Args:
min_depth: The minimum depth of a ray-point.
max_depth: The maximum depth of a ray-point.
"""
min_depth: float = 0.1
max_depth: float = 8.0
def _get_min_max_depth_bounds(self, cameras: CamerasBase) -> Tuple[float, float]:
"""
Returns the stored near/far planes.
"""
return self.min_depth, self.max_depth

View File

@ -157,6 +157,15 @@ def cat_dataclass(batch, tensor_collator: Callable):
return type(elem)(**collated) return type(elem)(**collated)
def setattr_if_hasattr(obj, name, value):
"""
Same as setattr(obj, name, value), but does nothing in case `name` is
not an attribe of `obj`.
"""
if hasattr(obj, name):
setattr(obj, name, value)
class Timer: class Timer:
""" """
A simple class for timing execution. A simple class for timing execution.