Mirror of https://github.com/facebookresearch/pytorch3d.git (synced 2025-08-02)
Raysampler as pluggable
Summary: This converts raysamplers to ReplaceableBase so that users can plug in their own raysampling implementations.

Context: Andrea tried to implement TensoRF within implicitron but could not, due to the need to implement his own raysampler.

Reviewed By: shapovalov

Differential Revision: D36016318

fbshipit-source-id: ef746f3365282bdfa9c15f7b371090a5aae7f8da
commit e767c4b548
parent e85fa03c5a
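As context for the diff below, here is a minimal sketch of what the change enables. It assumes the post-commit API shown in this diff (RaySamplerBase / AbstractMaskRaySampler registered through pytorch3d.implicitron.tools.config.registry, with the module path inferred from the relative imports below); the class name FixedBoundsRaySampler and its near/far fields are hypothetical, not part of the commit. A custom raysampler only has to subclass AbstractMaskRaySampler, register itself, and override _get_min_max_depth_bounds:

# Hypothetical user-defined raysampler, sketched against the API added by this commit.
# Module path assumed from the relative imports shown in the diff below.
from typing import Tuple

from pytorch3d.implicitron.models.renderer.ray_sampler import AbstractMaskRaySampler
from pytorch3d.implicitron.tools.config import registry
from pytorch3d.renderer.cameras import CamerasBase


@registry.register
class FixedBoundsRaySampler(AbstractMaskRaySampler):
    """Toy raysampler that always samples between constant near/far planes."""

    # Configurable fields, exposed to the config system like the built-in samplers.
    near: float = 0.5
    far: float = 4.0

    def _get_min_max_depth_bounds(self, cameras: CamerasBase) -> Tuple[float, float]:
        # Ignore the cameras and return the constant bounds.
        return self.near, self.far

Once registered, such a class would be selected via raysampler_class_type: FixedBoundsRaySampler together with a matching raysampler_FixedBoundsRaySampler_args block, in the same way the configs below switch between AdaptiveRaySampler and NearFarRaySampler.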
@@ -49,10 +49,8 @@ generic_model_args:
     append_xyz:
     - 5
     latent_dim: 0
-  raysampler_args:
+  raysampler_AdaptiveRaySampler_args:
     n_rays_per_image_sampled_from_mask: 1024
-    min_depth: 0.0
-    max_depth: 0.0
     scene_extent: 8.0
     n_pts_per_ray_training: 64
     n_pts_per_ray_evaluation: 64
@@ -54,7 +54,7 @@ generic_model_args:
     n_harmonic_functions_dir: 4
     pooled_feature_dim: 0
     weight_norm: true
-  raysampler_args:
+  raysampler_AdaptiveRaySampler_args:
     n_rays_per_image_sampled_from_mask: 1024
     n_pts_per_ray_training: 0
     n_pts_per_ray_evaluation: 0
@@ -6,5 +6,5 @@ clip_grad: 1.0
 generic_model_args:
   chunk_size_grid: 16000
   view_pooler_enabled: true
-  raysampler_args:
+  raysampler_AdaptiveRaySampler_args:
     n_rays_per_image_sampled_from_mask: 850
@@ -4,7 +4,7 @@ defaults:
 - _self_
 generic_model_args:
   chunk_size_grid: 16000
-  raysampler_args:
+  raysampler_AdaptiveRaySampler_args:
     n_rays_per_image_sampled_from_mask: 800
     n_pts_per_ray_training: 32
     n_pts_per_ray_evaluation: 32
@@ -1,17 +1,6 @@
 defaults:
-- repro_multiseq_base.yaml
-- repro_feat_extractor_transformer.yaml
+- repro_multiseq_nerformer.yaml
 - _self_
 generic_model_args:
-  chunk_size_grid: 16000
-  raysampler_args:
-    n_rays_per_image_sampled_from_mask: 800
-    n_pts_per_ray_training: 32
-    n_pts_per_ray_evaluation: 32
-  renderer_MultiPassEmissionAbsorptionRenderer_args:
-    n_pts_per_ray_fine_training: 16
-    n_pts_per_ray_fine_evaluation: 16
-  implicit_function_class_type: NeRFormerImplicitFunction
-  view_pooler_enabled: true
   view_pooler_args:
     feature_aggregator_class_type: AngleWeightedIdentityFeatureAggregator
@@ -16,11 +16,11 @@ generic_model_args:
   sequence_autodecoder_args:
     encoding_dim: 256
     n_instances: 20000
-  raysampler_args:
+  raysampler_class_type: NearFarRaySampler
+  raysampler_NearFarRaySampler_args:
     n_rays_per_image_sampled_from_mask: 2048
     min_depth: 0.05
     max_depth: 0.05
-    scene_extent: 0.0
     n_pts_per_ray_training: 1
     n_pts_per_ray_evaluation: 1
     stratified_point_sampling_training: false
@@ -13,11 +13,11 @@ generic_model_args:
     loss_prev_stage_mask_bce: 0.0
     loss_autodecoder_norm: 0.0
     depth_neg_penalty: 10000.0
-  raysampler_args:
+  raysampler_class_type: NearFarRaySampler
+  raysampler_NearFarRaySampler_args:
     n_rays_per_image_sampled_from_mask: 2048
     min_depth: 0.05
     max_depth: 0.05
-    scene_extent: 0.0
     n_pts_per_ray_training: 1
     n_pts_per_ray_evaluation: 1
     stratified_point_sampling_training: false
@@ -49,7 +49,7 @@ generic_model_args:
     n_harmonic_functions_dir: 4
     pooled_feature_dim: 0
     weight_norm: true
-  raysampler_args:
+  raysampler_AdaptiveRaySampler_args:
     n_rays_per_image_sampled_from_mask: 1024
     n_pts_per_ray_training: 0
     n_pts_per_ray_evaluation: 0
@@ -1,4 +1,3 @@
 defaults:
 - repro_singleseq_base
 - _self_
-exp_dir: ./data/nerf_single_apple/
@@ -5,5 +5,5 @@ defaults:
 generic_model_args:
   chunk_size_grid: 16000
   view_pooler_enabled: true
-  raysampler_args:
+  raysampler_AdaptiveRaySampler_args:
     n_rays_per_image_sampled_from_mask: 850
@@ -6,7 +6,7 @@ generic_model_args:
   chunk_size_grid: 16000
   view_pooler_enabled: true
   implicit_function_class_type: NeRFormerImplicitFunction
-  raysampler_args:
+  raysampler_AdaptiveRaySampler_args:
     n_rays_per_image_sampled_from_mask: 800
     n_pts_per_ray_training: 32
     n_pts_per_ray_evaluation: 32
@@ -12,11 +12,11 @@ generic_model_args:
     loss_prev_stage_mask_bce: 0.0
     loss_autodecoder_norm: 0.0
     depth_neg_penalty: 10000.0
-  raysampler_args:
+  raysampler_class_type: NearFarRaySampler
+  raysampler_NearFarRaySampler_args:
     n_rays_per_image_sampled_from_mask: 2048
     min_depth: 0.05
     max_depth: 0.05
-    scene_extent: 0.0
     n_pts_per_ray_training: 1
     n_pts_per_ray_evaluation: 1
     stratified_point_sampling_training: false
@@ -13,11 +13,11 @@ generic_model_args:
     loss_prev_stage_mask_bce: 0.0
     loss_autodecoder_norm: 0.0
     depth_neg_penalty: 10000.0
-  raysampler_args:
+  raysampler_class_type: NearFarRaySampler
+  raysampler_NearFarRaySampler_args:
     n_rays_per_image_sampled_from_mask: 2048
     min_depth: 0.05
     max_depth: 0.05
-    scene_extent: 0.0
     n_pts_per_ray_training: 1
     n_pts_per_ray_evaluation: 1
     stratified_point_sampling_training: false
@@ -19,7 +19,7 @@ from pytorch3d.implicitron.tools.config import (
     run_auto_creation,
 )
 from pytorch3d.implicitron.tools.rasterize_mc import rasterize_mc_samples
-from pytorch3d.implicitron.tools.utils import cat_dataclass
+from pytorch3d.implicitron.tools.utils import cat_dataclass, setattr_if_hasattr
 from pytorch3d.renderer import RayBundle, utils as rend_utils
 from pytorch3d.renderer.cameras import CamerasBase
 from visdom import Visdom
@@ -46,7 +46,7 @@ from .renderer.base import (
 )
 from .renderer.lstm_renderer import LSTMRenderer  # noqa
 from .renderer.multipass_ea import MultiPassEmissionAbsorptionRenderer  # noqa
-from .renderer.ray_sampler import RaySampler
+from .renderer.ray_sampler import RaySamplerBase
 from .renderer.sdf_renderer import SignedDistanceFunctionRenderer  # noqa
 from .resnet_feature_extractor import ResNetFeatureExtractor
 from .view_pooler.view_pooler import ViewPooler
@@ -160,6 +160,8 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module):  # pyre-ignore: 13
             the scene such as multiple objects or morphing objects. It is up to the implicit
             function definition how to use it, but the most typical way is to broadcast and
             concatenate to the other inputs for the implicit function.
+        raysampler_class_type: The name of the raysampler class which is available
+            in the global registry.
         raysampler: An instance of RaySampler which is used to emit
             rays from the target view(s).
         renderer_class_type: The name of the renderer class which is available in the global
@@ -204,7 +206,8 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module):  # pyre-ignore: 13
     sequence_autodecoder: Autodecoder

     # ---- raysampler
-    raysampler: RaySampler
+    raysampler_class_type: str = "AdaptiveRaySampler"
+    raysampler: RaySamplerBase

     # ---- renderer configs
     renderer_class_type: str = "MultiPassEmissionAbsorptionRenderer"
@@ -262,11 +265,6 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module):  # pyre-ignore: 13
         super().__init__()
         self.view_metrics = ViewMetrics()

-        self._check_and_preprocess_renderer_configs()
-        self.raysampler_args["sampling_mode_training"] = self.sampling_mode_training
-        self.raysampler_args["sampling_mode_evaluation"] = self.sampling_mode_evaluation
-        self.raysampler_args["image_width"] = self.render_image_width
-        self.raysampler_args["image_height"] = self.render_image_height
         run_auto_creation(self)

         self._implicit_functions = self._construct_implicit_functions()
@@ -339,7 +337,7 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module):  # pyre-ignore: 13
         )

         # (1) Sample rendering rays with the ray sampler.
-        ray_bundle: RayBundle = self.raysampler(
+        ray_bundle: RayBundle = self.raysampler(  # pyre-fixme[29]
             target_cameras,
             evaluation_mode,
             mask=mask_crop[:n_targets]
@@ -565,19 +563,52 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module):  # pyre-ignore: 13
             else 0
         )

-    def _check_and_preprocess_renderer_configs(self):
+    def create_raysampler(self):
+        raysampler_args = getattr(
+            self, "raysampler_" + self.raysampler_class_type + "_args"
+        )
+        setattr_if_hasattr(
+            raysampler_args, "sampling_mode_training", self.sampling_mode_training
+        )
+        setattr_if_hasattr(
+            raysampler_args, "sampling_mode_evaluation", self.sampling_mode_evaluation
+        )
+        setattr_if_hasattr(raysampler_args, "image_width", self.render_image_width)
+        setattr_if_hasattr(raysampler_args, "image_height", self.render_image_height)
+        self.raysampler = registry.get(RaySamplerBase, self.raysampler_class_type)(
+            **raysampler_args
+        )
+
+    def create_renderer(self):
+        raysampler_args = getattr(
+            self, "raysampler_" + self.raysampler_class_type + "_args"
+        )
         self.renderer_MultiPassEmissionAbsorptionRenderer_args[
             "stratified_sampling_coarse_training"
-        ] = self.raysampler_args["stratified_point_sampling_training"]
+        ] = raysampler_args["stratified_point_sampling_training"]
         self.renderer_MultiPassEmissionAbsorptionRenderer_args[
             "stratified_sampling_coarse_evaluation"
-        ] = self.raysampler_args["stratified_point_sampling_evaluation"]
+        ] = raysampler_args["stratified_point_sampling_evaluation"]
         self.renderer_SignedDistanceFunctionRenderer_args[
             "render_features_dimensions"
         ] = self.render_features_dimensions
-        self.renderer_SignedDistanceFunctionRenderer_args.ray_tracer_args[
-            "object_bounding_sphere"
-        ] = self.raysampler_args["scene_extent"]
+
+        if self.renderer_class_type == "SignedDistanceFunctionRenderer":
+            if "scene_extent" not in raysampler_args:
+                raise ValueError(
+                    "SignedDistanceFunctionRenderer requires"
+                    + " a raysampler that defines the 'scene_extent' field"
+                    + " (this field is supported by, e.g., the adaptive raysampler - "
+                    + " self.raysampler_class_type='AdaptiveRaySampler')."
+                )
+            self.renderer_SignedDistanceFunctionRenderer_args.ray_tracer_args[
+                "object_bounding_sphere"
+            ] = self.raysampler_AdaptiveRaySampler_args["scene_extent"]
+
+        renderer_args = getattr(self, "renderer_" + self.renderer_class_type + "_args")
+        self.renderer = registry.get(BaseRenderer, self.renderer_class_type)(
+            **renderer_args
+        )

     def create_view_pooler(self):
         """
@@ -9,16 +9,48 @@ from typing import Optional, Tuple

 import torch
 from pytorch3d.implicitron.tools import camera_utils
-from pytorch3d.implicitron.tools.config import Configurable
+from pytorch3d.implicitron.tools.config import ReplaceableBase, registry
 from pytorch3d.renderer import NDCMultinomialRaysampler, RayBundle
 from pytorch3d.renderer.cameras import CamerasBase

 from .base import EvaluationMode, RenderSamplingMode


-class RaySampler(Configurable, torch.nn.Module):
+class RaySamplerBase(ReplaceableBase):
+    """
+    Base class for ray samplers.
     """
+
+    def __init__(self):
+        super().__init__()
+
+    def forward(
+        self,
+        cameras: CamerasBase,
+        evaluation_mode: EvaluationMode,
+        mask: Optional[torch.Tensor] = None,
+    ) -> RayBundle:
+        """
+        Args:
+            cameras: A batch of `batch_size` cameras from which the rays are emitted.
+            evaluation_mode: one of `EvaluationMode.TRAINING` or
+                `EvaluationMode.EVALUATION` which determines the sampling mode
+                that is used.
+            mask: Active for the `RenderSamplingMode.MASK_SAMPLE` sampling mode.
+                Defines a non-negative mask of shape
+                `(batch_size, image_height, image_width)` where each per-pixel
+                value is proportional to the probability of sampling the
+                corresponding pixel's ray.
+
+        Returns:
+            ray_bundle: A `RayBundle` object containing the parametrizations of the
+                sampled rendering rays.
+        """
+        raise NotImplementedError()
+
+
+class AbstractMaskRaySampler(RaySamplerBase, torch.nn.Module):
+    """
     Samples a fixed number of points along rays which are in turn sampled for
     each camera in a batch.

@@ -29,46 +61,19 @@ class RaySampler(Configurable, torch.nn.Module):
     for training and evaluation by setting `self.sampling_mode_training`
     and `self.sampling_mode_training` accordingly.

-    The class allows two modes of sampling points along the rays:
-      1) Sampling between fixed near and far z-planes:
-         Active when `self.scene_extent <= 0`, samples points along each ray
-         with approximately uniform spacing of z-coordinates between
-         the minimum depth `self.min_depth` and the maximum depth `self.max_depth`.
-         This sampling is useful for rendering scenes where the camera is
-         in a constant distance from the focal point of the scene.
-      2) Adaptive near/far plane estimation around the world scene center:
-         Active when `self.scene_extent > 0`. Samples points on each
-         ray between near and far planes whose depths are determined based on
-         the distance from the camera center to a predefined scene center.
-         More specifically,
-         `min_depth = max(
-             (self.scene_center-camera_center).norm() - self.scene_extent, eps
-         )` and
-         `max_depth = (self.scene_center-camera_center).norm() + self.scene_extent`.
-         This sampling is ideal for object-centric scenes whose contents are
-         centered around a known `self.scene_center` and fit into a bounding sphere
-         with a radius of `self.scene_extent`.
-
-    Similar to the sampling mode, the sampling parameters can be set separately
-    for training and evaluation.
+    The class allows to adjust the sampling points along rays by overwriting the
+    `AbstractMaskRaySampler._get_min_max_depth_bounds` function which returns
+    the near/far planes (`min_depth`/`max_depth`) `NDCMultinomialRaysampler`.

     Settings:
         image_width: The horizontal size of the image grid.
         image_height: The vertical size of the image grid.
-        scene_center: The xyz coordinates of the center of the scene used
-            along with `scene_extent` to compute the min and max depth planes
-            for sampling ray-points.
-        scene_extent: The radius of the scene bounding sphere centered at `scene_center`.
-            If `scene_extent <= 0`, the raysampler samples points between
-            `self.min_depth` and `self.max_depth` depths instead.
         sampling_mode_training: The ray sampling mode for training. This should be a str
             option from the RenderSamplingMode Enum
         sampling_mode_evaluation: Same as above but for evaluation.
         n_pts_per_ray_training: The number of points sampled along each ray during training.
         n_pts_per_ray_evaluation: The number of points sampled along each ray during evaluation.
         n_rays_per_image_sampled_from_mask: The amount of rays to be sampled from the image grid
-        min_depth: The minimum depth of a ray-point. Active when `self.scene_extent > 0`.
-        max_depth: The maximum depth of a ray-point. Active when `self.scene_extent > 0`.
         stratified_point_sampling_training: if set, performs stratified random sampling
             along the ray; otherwise takes ray points at deterministic offsets.
         stratified_point_sampling_evaluation: Same as above but for evaluation.
@@ -77,24 +82,17 @@ class RaySampler(Configurable, torch.nn.Module):

     image_width: int = 400
     image_height: int = 400
-    scene_center: Tuple[float, float, float] = field(
-        default_factory=lambda: (0.0, 0.0, 0.0)
-    )
-    scene_extent: float = 0.0
     sampling_mode_training: str = "mask_sample"
     sampling_mode_evaluation: str = "full_grid"
     n_pts_per_ray_training: int = 64
     n_pts_per_ray_evaluation: int = 64
     n_rays_per_image_sampled_from_mask: int = 1024
-    min_depth: float = 0.1
-    max_depth: float = 8.0
     # stratified sampling vs taking points at deterministic offsets
     stratified_point_sampling_training: bool = True
     stratified_point_sampling_evaluation: bool = False

     def __post_init__(self):
         super().__init__()
-        self.scene_center = torch.FloatTensor(self.scene_center)

         self._sampling_mode = {
             EvaluationMode.TRAINING: RenderSamplingMode(self.sampling_mode_training),
@@ -108,8 +106,8 @@ class RaySampler(Configurable, torch.nn.Module):
                 image_width=self.image_width,
                 image_height=self.image_height,
                 n_pts_per_ray=self.n_pts_per_ray_training,
-                min_depth=self.min_depth,
-                max_depth=self.max_depth,
+                min_depth=0.0,
+                max_depth=0.0,
                 n_rays_per_image=self.n_rays_per_image_sampled_from_mask
                 if self._sampling_mode[EvaluationMode.TRAINING]
                 == RenderSamplingMode.MASK_SAMPLE
@@ -121,8 +119,8 @@ class RaySampler(Configurable, torch.nn.Module):
                 image_width=self.image_width,
                 image_height=self.image_height,
                 n_pts_per_ray=self.n_pts_per_ray_evaluation,
-                min_depth=self.min_depth,
-                max_depth=self.max_depth,
+                min_depth=0.0,
+                max_depth=0.0,
                 n_rays_per_image=self.n_rays_per_image_sampled_from_mask
                 if self._sampling_mode[EvaluationMode.EVALUATION]
                 == RenderSamplingMode.MASK_SAMPLE
@@ -132,6 +130,9 @@ class RaySampler(Configurable, torch.nn.Module):
             ),
         }

+    def _get_min_max_depth_bounds(self, cameras: CamerasBase) -> Tuple[float, float]:
+        raise NotImplementedError()
+
     def forward(
         self,
         cameras: CamerasBase,
@@ -169,12 +170,7 @@ class RaySampler(Configurable, torch.nn.Module):
                 mode="nearest",
             )[:, 0]

-        if self.scene_extent > 0.0:
-            # Override the min/max depth set in initialization based on the
-            # input cameras.
-            min_depth, max_depth = camera_utils.get_min_max_depth_bounds(
-                cameras, self.scene_center, self.scene_extent
-            )
+        min_depth, max_depth = self._get_min_max_depth_bounds(cameras)

         # pyre-fixme[29]:
         # `Union[BoundMethod[typing.Callable(torch.Tensor.__getitem__)[[Named(self,
@@ -183,8 +179,75 @@ class RaySampler(Configurable, torch.nn.Module):
         ray_bundle = self._raysamplers[evaluation_mode](
             cameras=cameras,
             mask=sample_mask,
-            min_depth=float(min_depth[0]) if self.scene_extent > 0.0 else None,
-            max_depth=float(max_depth[0]) if self.scene_extent > 0.0 else None,
+            min_depth=min_depth,
+            max_depth=max_depth,
         )

         return ray_bundle
+
+
+@registry.register
+class AdaptiveRaySampler(AbstractMaskRaySampler):
+    """
+    Adaptively samples points on each ray between near and far planes whose
+    depths are determined based on the distance from the camera center
+    to a predefined scene center.
+
+    More specifically,
+    `min_depth = max(
+        (self.scene_center-camera_center).norm() - self.scene_extent, eps
+    )` and
+    `max_depth = (self.scene_center-camera_center).norm() + self.scene_extent`.
+
+    This sampling is ideal for object-centric scenes whose contents are
+    centered around a known `self.scene_center` and fit into a bounding sphere
+    with a radius of `self.scene_extent`.
+
+    Args:
+        scene_center: The xyz coordinates of the center of the scene used
+            along with `scene_extent` to compute the min and max depth planes
+            for sampling ray-points.
+        scene_extent: The radius of the scene bounding box centered at `scene_center`.
+    """
+
+    scene_extent: float = 8.0
+    scene_center: Tuple[float, float, float] = (0.0, 0.0, 0.0)
+
+    def __post_init__(self):
+        super().__post_init__()
+        if self.scene_extent <= 0.0:
+            raise ValueError("Adaptive raysampler requires self.scene_extent > 0.")
+        self._scene_center = torch.FloatTensor(self.scene_center)
+
+    def _get_min_max_depth_bounds(self, cameras: CamerasBase) -> Tuple[float, float]:
+        """
+        Returns the adaptivelly calculated near/far planes.
+        """
+        min_depth, max_depth = camera_utils.get_min_max_depth_bounds(
+            cameras, self._scene_center, self.scene_extent
+        )
+        return float(min_depth[0]), float(max_depth[0])
+
+
+@registry.register
+class NearFarRaySampler(AbstractMaskRaySampler):
+    """
+    Samples a fixed number of points between fixed near and far z-planes.
+    Specifically, samples points along each ray with approximately uniform spacing
+    of z-coordinates between the minimum depth `self.min_depth` and the maximum depth
+    `self.max_depth`. This sampling is useful for rendering scenes where the camera is
+    in a constant distance from the focal point of the scene.
+
+    Args:
+        min_depth: The minimum depth of a ray-point.
+        max_depth: The maximum depth of a ray-point.
+    """
+
+    min_depth: float = 0.1
+    max_depth: float = 8.0
+
+    def _get_min_max_depth_bounds(self, cameras: CamerasBase) -> Tuple[float, float]:
+        """
+        Returns the stored near/far planes.
+        """
+        return self.min_depth, self.max_depth
@@ -157,6 +157,15 @@ def cat_dataclass(batch, tensor_collator: Callable):
     return type(elem)(**collated)


+def setattr_if_hasattr(obj, name, value):
+    """
+    Same as setattr(obj, name, value), but does nothing in case `name` is
+    not an attribe of `obj`.
+    """
+    if hasattr(obj, name):
+        setattr(obj, name, value)
+
+
 class Timer:
     """
     A simple class for timing execution.
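A quick, hypothetical illustration of the setattr_if_hasattr helper added above (the _ToyArgs dataclass exists only for this example; the import path follows the generic_model.py import in this diff):

from dataclasses import dataclass

from pytorch3d.implicitron.tools.utils import setattr_if_hasattr


@dataclass
class _ToyArgs:
    image_width: int = 400


args = _ToyArgs()
setattr_if_hasattr(args, "image_width", 800)   # field exists: updated to 800
setattr_if_hasattr(args, "scene_extent", 8.0)  # no such field: silently skipped

This is what lets create_raysampler push shared settings such as image_width into whichever raysampler args are active, without requiring every raysampler class to declare every field.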