Raysampler as pluggable

Summary: This converts raysamplers to ReplaceableBase so that users can hack their own raysampling impls.

Context: Andrea tried to implement TensoRF within implicitron but could not due to the need to implement his own raysampler.

Reviewed By: shapovalov

Differential Revision: D36016318

fbshipit-source-id: ef746f3365282bdfa9c15f7b371090a5aae7f8da

parent e85fa03c5a
commit e767c4b548
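With raysamplers now deriving from ReplaceableBase, a user-defined sampler can be registered and then selected purely from configuration. The sketch below is illustrative only: `FixedGridRaySampler` and its fields are hypothetical, the absolute module paths are inferred from the relative imports in the hunks below, and the class mirrors the `AbstractMaskRaySampler` pattern introduced by this commit.

from typing import Optional

import torch
from pytorch3d.implicitron.models.renderer.base import EvaluationMode
from pytorch3d.implicitron.models.renderer.ray_sampler import RaySamplerBase
from pytorch3d.implicitron.tools.config import registry
from pytorch3d.renderer import NDCMultinomialRaysampler, RayBundle
from pytorch3d.renderer.cameras import CamerasBase


@registry.register
class FixedGridRaySampler(RaySamplerBase, torch.nn.Module):  # hypothetical example
    """Toy sampler: full-grid NDC sampling with hard-coded near/far planes."""

    # Hypothetical configurable fields, exposed automatically by the config system.
    image_width: int = 64
    image_height: int = 64
    n_pts_per_ray: int = 32

    def __post_init__(self):
        super().__init__()
        self._grid_sampler = NDCMultinomialRaysampler(
            image_width=self.image_width,
            image_height=self.image_height,
            n_pts_per_ray=self.n_pts_per_ray,
            min_depth=0.1,
            max_depth=8.0,
        )

    def forward(
        self,
        cameras: CamerasBase,
        evaluation_mode: EvaluationMode,
        mask: Optional[torch.Tensor] = None,
    ) -> RayBundle:
        # This toy sampler ignores the mask and samples the full grid in both modes.
        return self._grid_sampler(cameras=cameras)

Selecting it would then be a config-only change, e.g. `raysampler_class_type: FixedGridRaySampler` plus a matching `raysampler_FixedGridRaySampler_args:` block, following the renamed keys in the YAML hunks below.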
@@ -49,10 +49,8 @@ generic_model_args:
     append_xyz:
     - 5
     latent_dim: 0
-  raysampler_args:
+  raysampler_AdaptiveRaySampler_args:
     n_rays_per_image_sampled_from_mask: 1024
-    min_depth: 0.0
-    max_depth: 0.0
     scene_extent: 8.0
     n_pts_per_ray_training: 64
     n_pts_per_ray_evaluation: 64

@@ -54,7 +54,7 @@ generic_model_args:
       n_harmonic_functions_dir: 4
       pooled_feature_dim: 0
       weight_norm: true
-  raysampler_args:
+  raysampler_AdaptiveRaySampler_args:
     n_rays_per_image_sampled_from_mask: 1024
     n_pts_per_ray_training: 0
     n_pts_per_ray_evaluation: 0

@@ -6,5 +6,5 @@ clip_grad: 1.0
 generic_model_args:
   chunk_size_grid: 16000
   view_pooler_enabled: true
-  raysampler_args:
+  raysampler_AdaptiveRaySampler_args:
     n_rays_per_image_sampled_from_mask: 850

@@ -4,7 +4,7 @@ defaults:
 - _self_
 generic_model_args:
   chunk_size_grid: 16000
-  raysampler_args:
+  raysampler_AdaptiveRaySampler_args:
     n_rays_per_image_sampled_from_mask: 800
     n_pts_per_ray_training: 32
     n_pts_per_ray_evaluation: 32

@@ -1,17 +1,6 @@
 defaults:
-- repro_multiseq_base.yaml
-- repro_feat_extractor_transformer.yaml
+- repro_multiseq_nerformer.yaml
 - _self_
 generic_model_args:
-  chunk_size_grid: 16000
-  raysampler_args:
-    n_rays_per_image_sampled_from_mask: 800
-    n_pts_per_ray_training: 32
-    n_pts_per_ray_evaluation: 32
-  renderer_MultiPassEmissionAbsorptionRenderer_args:
-    n_pts_per_ray_fine_training: 16
-    n_pts_per_ray_fine_evaluation: 16
-  implicit_function_class_type: NeRFormerImplicitFunction
-  view_pooler_enabled: true
   view_pooler_args:
     feature_aggregator_class_type: AngleWeightedIdentityFeatureAggregator

@@ -16,11 +16,11 @@ generic_model_args:
   sequence_autodecoder_args:
     encoding_dim: 256
     n_instances: 20000
-  raysampler_args:
+  raysampler_class_type: NearFarRaySampler
+  raysampler_NearFarRaySampler_args:
     n_rays_per_image_sampled_from_mask: 2048
     min_depth: 0.05
     max_depth: 0.05
-    scene_extent: 0.0
     n_pts_per_ray_training: 1
     n_pts_per_ray_evaluation: 1
     stratified_point_sampling_training: false

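The renamed keys above follow the ReplaceableBase convention: `raysampler_class_type` names a registered subclass, and the sibling `raysampler_<ClassName>_args:` block holds that subclass's fields. A tiny sketch of how the two keys pair up (plain string manipulation, mirroring the lookup in `create_raysampler` further down):

# Illustrative only: the args attribute name is derived from the class type string.
raysampler_class_type = "NearFarRaySampler"
args_field = f"raysampler_{raysampler_class_type}_args"
assert args_field == "raysampler_NearFarRaySampler_args"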
@@ -13,11 +13,11 @@ generic_model_args:
     loss_prev_stage_mask_bce: 0.0
     loss_autodecoder_norm: 0.0
     depth_neg_penalty: 10000.0
-  raysampler_args:
+  raysampler_class_type: NearFarRaySampler
+  raysampler_NearFarRaySampler_args:
     n_rays_per_image_sampled_from_mask: 2048
     min_depth: 0.05
     max_depth: 0.05
-    scene_extent: 0.0
     n_pts_per_ray_training: 1
     n_pts_per_ray_evaluation: 1
     stratified_point_sampling_training: false

@@ -49,7 +49,7 @@ generic_model_args:
       n_harmonic_functions_dir: 4
       pooled_feature_dim: 0
       weight_norm: true
-  raysampler_args:
+  raysampler_AdaptiveRaySampler_args:
     n_rays_per_image_sampled_from_mask: 1024
     n_pts_per_ray_training: 0
     n_pts_per_ray_evaluation: 0

@@ -1,4 +1,3 @@
 defaults:
 - repro_singleseq_base
 - _self_
 exp_dir: ./data/nerf_single_apple/

@@ -5,5 +5,5 @@ defaults:
 generic_model_args:
   chunk_size_grid: 16000
   view_pooler_enabled: true
-  raysampler_args:
+  raysampler_AdaptiveRaySampler_args:
     n_rays_per_image_sampled_from_mask: 850

@@ -6,7 +6,7 @@ generic_model_args:
   chunk_size_grid: 16000
   view_pooler_enabled: true
   implicit_function_class_type: NeRFormerImplicitFunction
-  raysampler_args:
+  raysampler_AdaptiveRaySampler_args:
     n_rays_per_image_sampled_from_mask: 800
     n_pts_per_ray_training: 32
     n_pts_per_ray_evaluation: 32

@@ -12,11 +12,11 @@ generic_model_args:
     loss_prev_stage_mask_bce: 0.0
     loss_autodecoder_norm: 0.0
     depth_neg_penalty: 10000.0
-  raysampler_args:
+  raysampler_class_type: NearFarRaySampler
+  raysampler_NearFarRaySampler_args:
     n_rays_per_image_sampled_from_mask: 2048
     min_depth: 0.05
     max_depth: 0.05
-    scene_extent: 0.0
     n_pts_per_ray_training: 1
     n_pts_per_ray_evaluation: 1
     stratified_point_sampling_training: false

@@ -13,11 +13,11 @@ generic_model_args:
     loss_prev_stage_mask_bce: 0.0
     loss_autodecoder_norm: 0.0
     depth_neg_penalty: 10000.0
-  raysampler_args:
+  raysampler_class_type: NearFarRaySampler
+  raysampler_NearFarRaySampler_args:
     n_rays_per_image_sampled_from_mask: 2048
     min_depth: 0.05
     max_depth: 0.05
-    scene_extent: 0.0
     n_pts_per_ray_training: 1
     n_pts_per_ray_evaluation: 1
     stratified_point_sampling_training: false

@@ -19,7 +19,7 @@ from pytorch3d.implicitron.tools.config import (
     run_auto_creation,
 )
 from pytorch3d.implicitron.tools.rasterize_mc import rasterize_mc_samples
-from pytorch3d.implicitron.tools.utils import cat_dataclass
+from pytorch3d.implicitron.tools.utils import cat_dataclass, setattr_if_hasattr
 from pytorch3d.renderer import RayBundle, utils as rend_utils
 from pytorch3d.renderer.cameras import CamerasBase
 from visdom import Visdom

@@ -46,7 +46,7 @@ from .renderer.base import (
 )
 from .renderer.lstm_renderer import LSTMRenderer  # noqa
 from .renderer.multipass_ea import MultiPassEmissionAbsorptionRenderer  # noqa
-from .renderer.ray_sampler import RaySampler
+from .renderer.ray_sampler import RaySamplerBase
 from .renderer.sdf_renderer import SignedDistanceFunctionRenderer  # noqa
 from .resnet_feature_extractor import ResNetFeatureExtractor
 from .view_pooler.view_pooler import ViewPooler

@@ -160,6 +160,8 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module):  # pyre-ignore: 13
             the scene such as multiple objects or morphing objects. It is up to the implicit
             function definition how to use it, but the most typical way is to broadcast and
             concatenate to the other inputs for the implicit function.
+        raysampler_class_type: The name of the raysampler class which is available
+            in the global registry.
         raysampler: An instance of RaySampler which is used to emit
             rays from the target view(s).
         renderer_class_type: The name of the renderer class which is available in the global

@@ -204,7 +206,8 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module):  # pyre-ignore: 13
     sequence_autodecoder: Autodecoder

     # ---- raysampler
-    raysampler: RaySampler
+    raysampler_class_type: str = "AdaptiveRaySampler"
+    raysampler: RaySamplerBase

     # ---- renderer configs
     renderer_class_type: str = "MultiPassEmissionAbsorptionRenderer"

@@ -262,11 +265,6 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module):  # pyre-ignore: 13
         super().__init__()
         self.view_metrics = ViewMetrics()

-        self._check_and_preprocess_renderer_configs()
-        self.raysampler_args["sampling_mode_training"] = self.sampling_mode_training
-        self.raysampler_args["sampling_mode_evaluation"] = self.sampling_mode_evaluation
-        self.raysampler_args["image_width"] = self.render_image_width
-        self.raysampler_args["image_height"] = self.render_image_height
         run_auto_creation(self)

         self._implicit_functions = self._construct_implicit_functions()

@@ -339,7 +337,7 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module):  # pyre-ignore: 13
         )

         # (1) Sample rendering rays with the ray sampler.
-        ray_bundle: RayBundle = self.raysampler(
+        ray_bundle: RayBundle = self.raysampler(  # pyre-fixme[29]
             target_cameras,
             evaluation_mode,
             mask=mask_crop[:n_targets]

 | 
			
		||||
            else 0
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def _check_and_preprocess_renderer_configs(self):
 | 
			
		||||
    def create_raysampler(self):
 | 
			
		||||
        raysampler_args = getattr(
 | 
			
		||||
            self, "raysampler_" + self.raysampler_class_type + "_args"
 | 
			
		||||
        )
 | 
			
		||||
        setattr_if_hasattr(
 | 
			
		||||
            raysampler_args, "sampling_mode_training", self.sampling_mode_training
 | 
			
		||||
        )
 | 
			
		||||
        setattr_if_hasattr(
 | 
			
		||||
            raysampler_args, "sampling_mode_evaluation", self.sampling_mode_evaluation
 | 
			
		||||
        )
 | 
			
		||||
        setattr_if_hasattr(raysampler_args, "image_width", self.render_image_width)
 | 
			
		||||
        setattr_if_hasattr(raysampler_args, "image_height", self.render_image_height)
 | 
			
		||||
        self.raysampler = registry.get(RaySamplerBase, self.raysampler_class_type)(
 | 
			
		||||
            **raysampler_args
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def create_renderer(self):
 | 
			
		||||
        raysampler_args = getattr(
 | 
			
		||||
            self, "raysampler_" + self.raysampler_class_type + "_args"
 | 
			
		||||
        )
 | 
			
		||||
        self.renderer_MultiPassEmissionAbsorptionRenderer_args[
 | 
			
		||||
            "stratified_sampling_coarse_training"
 | 
			
		||||
        ] = self.raysampler_args["stratified_point_sampling_training"]
 | 
			
		||||
        ] = raysampler_args["stratified_point_sampling_training"]
 | 
			
		||||
        self.renderer_MultiPassEmissionAbsorptionRenderer_args[
 | 
			
		||||
            "stratified_sampling_coarse_evaluation"
 | 
			
		||||
        ] = self.raysampler_args["stratified_point_sampling_evaluation"]
 | 
			
		||||
        ] = raysampler_args["stratified_point_sampling_evaluation"]
 | 
			
		||||
        self.renderer_SignedDistanceFunctionRenderer_args[
 | 
			
		||||
            "render_features_dimensions"
 | 
			
		||||
        ] = self.render_features_dimensions
 | 
			
		||||
        self.renderer_SignedDistanceFunctionRenderer_args.ray_tracer_args[
 | 
			
		||||
            "object_bounding_sphere"
 | 
			
		||||
        ] = self.raysampler_args["scene_extent"]
 | 
			
		||||
 | 
			
		||||
        if self.renderer_class_type == "SignedDistanceFunctionRenderer":
 | 
			
		||||
            if "scene_extent" not in raysampler_args:
 | 
			
		||||
                raise ValueError(
 | 
			
		||||
                    "SignedDistanceFunctionRenderer requires"
 | 
			
		||||
                    + " a raysampler that defines the 'scene_extent' field"
 | 
			
		||||
                    + " (this field is supported by, e.g., the adaptive raysampler - "
 | 
			
		||||
                    + " self.raysampler_class_type='AdaptiveRaySampler')."
 | 
			
		||||
                )
 | 
			
		||||
            self.renderer_SignedDistanceFunctionRenderer_args.ray_tracer_args[
 | 
			
		||||
                "object_bounding_sphere"
 | 
			
		||||
            ] = self.raysampler_AdaptiveRaySampler_args["scene_extent"]
 | 
			
		||||
 | 
			
		||||
        renderer_args = getattr(self, "renderer_" + self.renderer_class_type + "_args")
 | 
			
		||||
        self.renderer = registry.get(BaseRenderer, self.renderer_class_type)(
 | 
			
		||||
            **renderer_args
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def create_view_pooler(self):
 | 
			
		||||
        """
 | 
			
		||||
 | 
			
		||||
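The new `create_raysampler` / `create_renderer` methods act as the per-field creation hooks picked up by `run_auto_creation(self)` (still called from `__post_init__` above); inside them, the `*_class_type` string is resolved through the implicitron registry. A minimal, self-contained sketch of that resolution with hypothetical classes (`Widget`, `BigWidget`), assuming the `registry` API behaves as used in this diff:

from pytorch3d.implicitron.tools.config import ReplaceableBase, registry


class Widget(ReplaceableBase):  # hypothetical replaceable base
    pass


@registry.register
class BigWidget(Widget):  # hypothetical concrete implementation
    pass


# Same resolution as `registry.get(RaySamplerBase, self.raysampler_class_type)` above:
# the string from `*_class_type` picks the registered subclass, which is then
# instantiated with the matching `*_<ClassName>_args`.
assert registry.get(Widget, "BigWidget") is BigWidget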
@@ -9,16 +9,48 @@ from typing import Optional, Tuple

 import torch
 from pytorch3d.implicitron.tools import camera_utils
-from pytorch3d.implicitron.tools.config import Configurable
+from pytorch3d.implicitron.tools.config import ReplaceableBase, registry
 from pytorch3d.renderer import NDCMultinomialRaysampler, RayBundle
 from pytorch3d.renderer.cameras import CamerasBase

 from .base import EvaluationMode, RenderSamplingMode


-class RaySampler(Configurable, torch.nn.Module):
+class RaySamplerBase(ReplaceableBase):
+    """
+    Base class for ray samplers.
+    """
+
+    def __init__(self):
+        super().__init__()
+
+    def forward(
+        self,
+        cameras: CamerasBase,
+        evaluation_mode: EvaluationMode,
+        mask: Optional[torch.Tensor] = None,
+    ) -> RayBundle:
+        """
+        Args:
+            cameras: A batch of `batch_size` cameras from which the rays are emitted.
+            evaluation_mode: one of `EvaluationMode.TRAINING` or
+                `EvaluationMode.EVALUATION` which determines the sampling mode
+                that is used.
+            mask: Active for the `RenderSamplingMode.MASK_SAMPLE` sampling mode.
+                Defines a non-negative mask of shape
+                `(batch_size, image_height, image_width)` where each per-pixel
+                value is proportional to the probability of sampling the
+                corresponding pixel's ray.
+
+        Returns:
+            ray_bundle: A `RayBundle` object containing the parametrizations of the
+                sampled rendering rays.
+        """
+        raise NotImplementedError()
+
+
+class AbstractMaskRaySampler(RaySamplerBase, torch.nn.Module):
     """
     Samples a fixed number of points along rays which are in turn sampled for
     each camera in a batch.

@@ -29,46 +61,19 @@ class RaySampler(Configurable, torch.nn.Module):
     for training and evaluation by setting `self.sampling_mode_training`
     and `self.sampling_mode_training` accordingly.

-    The class allows two modes of sampling points along the rays:
-        1) Sampling between fixed near and far z-planes:
-            Active when `self.scene_extent <= 0`, samples points along each ray
-            with approximately uniform spacing of z-coordinates between
-            the minimum depth `self.min_depth` and the maximum depth `self.max_depth`.
-            This sampling is useful for rendering scenes where the camera is
-            in a constant distance from the focal point of the scene.
-        2) Adaptive near/far plane estimation around the world scene center:
-            Active when `self.scene_extent > 0`. Samples points on each
-            ray between near and far planes whose depths are determined based on
-            the distance from the camera center to a predefined scene center.
-            More specifically,
-            `min_depth = max(
-                (self.scene_center-camera_center).norm() - self.scene_extent, eps
-            )` and
-            `max_depth = (self.scene_center-camera_center).norm() + self.scene_extent`.
-            This sampling is ideal for object-centric scenes whose contents are
-            centered around a known `self.scene_center` and fit into a bounding sphere
-            with a radius of `self.scene_extent`.
-
-        Similar to the sampling mode, the sampling parameters can be set separately
-        for training and evaluation.
+    The class allows to adjust the sampling points along rays by overwriting the
+    `AbstractMaskRaySampler._get_min_max_depth_bounds` function which returns
+    the near/far planes (`min_depth`/`max_depth`) `NDCMultinomialRaysampler`.

     Settings:
         image_width: The horizontal size of the image grid.
         image_height: The vertical size of the image grid.
-        scene_center: The xyz coordinates of the center of the scene used
-            along with `scene_extent` to compute the min and max depth planes
-            for sampling ray-points.
-        scene_extent: The radius of the scene bounding sphere centered at `scene_center`.
-            If `scene_extent <= 0`, the raysampler samples points between
-            `self.min_depth` and `self.max_depth` depths instead.
         sampling_mode_training: The ray sampling mode for training. This should be a str
             option from the RenderSamplingMode Enum
         sampling_mode_evaluation: Same as above but for evaluation.
         n_pts_per_ray_training: The number of points sampled along each ray during training.
         n_pts_per_ray_evaluation: The number of points sampled along each ray during evaluation.
         n_rays_per_image_sampled_from_mask: The amount of rays to be sampled from the image grid
-        min_depth: The minimum depth of a ray-point. Active when `self.scene_extent > 0`.
-        max_depth: The maximum depth of a ray-point. Active when `self.scene_extent > 0`.
         stratified_point_sampling_training: if set, performs stratified random sampling
             along the ray; otherwise takes ray points at deterministic offsets.
         stratified_point_sampling_evaluation: Same as above but for evaluation.

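The rewritten docstring above points at `_get_min_max_depth_bounds` as the extension hook for custom depth ranges. A hedged sketch of a third sampler built on it (the class and its fields are hypothetical; only the base class, the hook name, and module paths inferred from this diff are taken from the commit):

from typing import Tuple

from pytorch3d.implicitron.models.renderer.ray_sampler import AbstractMaskRaySampler
from pytorch3d.implicitron.tools.config import registry
from pytorch3d.renderer.cameras import CamerasBase


@registry.register
class ScaledNearFarRaySampler(AbstractMaskRaySampler):  # hypothetical example
    """Fixed near/far planes scaled by a single configurable factor."""

    base_min_depth: float = 0.1  # hypothetical fields
    base_max_depth: float = 8.0
    depth_scale: float = 1.0

    def _get_min_max_depth_bounds(self, cameras: CamerasBase) -> Tuple[float, float]:
        # The returned pair feeds the internal NDCMultinomialRaysampler, as described above.
        return (
            self.base_min_depth * self.depth_scale,
            self.base_max_depth * self.depth_scale,
        )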
@@ -77,24 +82,17 @@ class RaySampler(Configurable, torch.nn.Module):

     image_width: int = 400
     image_height: int = 400
-    scene_center: Tuple[float, float, float] = field(
-        default_factory=lambda: (0.0, 0.0, 0.0)
-    )
-    scene_extent: float = 0.0
     sampling_mode_training: str = "mask_sample"
     sampling_mode_evaluation: str = "full_grid"
     n_pts_per_ray_training: int = 64
     n_pts_per_ray_evaluation: int = 64
     n_rays_per_image_sampled_from_mask: int = 1024
-    min_depth: float = 0.1
-    max_depth: float = 8.0
     # stratified sampling vs taking points at deterministic offsets
     stratified_point_sampling_training: bool = True
     stratified_point_sampling_evaluation: bool = False

     def __post_init__(self):
         super().__init__()
-        self.scene_center = torch.FloatTensor(self.scene_center)

         self._sampling_mode = {
             EvaluationMode.TRAINING: RenderSamplingMode(self.sampling_mode_training),

@@ -108,8 +106,8 @@ class RaySampler(Configurable, torch.nn.Module):
                 image_width=self.image_width,
                 image_height=self.image_height,
                 n_pts_per_ray=self.n_pts_per_ray_training,
-                min_depth=self.min_depth,
-                max_depth=self.max_depth,
+                min_depth=0.0,
+                max_depth=0.0,
                 n_rays_per_image=self.n_rays_per_image_sampled_from_mask
                 if self._sampling_mode[EvaluationMode.TRAINING]
                 == RenderSamplingMode.MASK_SAMPLE

@@ -121,8 +119,8 @@ class RaySampler(Configurable, torch.nn.Module):
                 image_width=self.image_width,
                 image_height=self.image_height,
                 n_pts_per_ray=self.n_pts_per_ray_evaluation,
-                min_depth=self.min_depth,
-                max_depth=self.max_depth,
+                min_depth=0.0,
+                max_depth=0.0,
                 n_rays_per_image=self.n_rays_per_image_sampled_from_mask
                 if self._sampling_mode[EvaluationMode.EVALUATION]
                 == RenderSamplingMode.MASK_SAMPLE

@@ -132,6 +130,9 @@ class RaySampler(Configurable, torch.nn.Module):
             ),
         }

+    def _get_min_max_depth_bounds(self, cameras: CamerasBase) -> Tuple[float, float]:
+        raise NotImplementedError()
+
     def forward(
         self,
         cameras: CamerasBase,

@@ -169,12 +170,7 @@ class RaySampler(Configurable, torch.nn.Module):
                 mode="nearest",
             )[:, 0]

-        if self.scene_extent > 0.0:
-            # Override the min/max depth set in initialization based on the
-            # input cameras.
-            min_depth, max_depth = camera_utils.get_min_max_depth_bounds(
-                cameras, self.scene_center, self.scene_extent
-            )
+        min_depth, max_depth = self._get_min_max_depth_bounds(cameras)

         # pyre-fixme[29]:
         #  `Union[BoundMethod[typing.Callable(torch.Tensor.__getitem__)[[Named(self,

@@ -183,8 +179,75 @@ class RaySampler(Configurable, torch.nn.Module):
         ray_bundle = self._raysamplers[evaluation_mode](
             cameras=cameras,
             mask=sample_mask,
-            min_depth=float(min_depth[0]) if self.scene_extent > 0.0 else None,
-            max_depth=float(max_depth[0]) if self.scene_extent > 0.0 else None,
+            min_depth=min_depth,
+            max_depth=max_depth,
         )

         return ray_bundle
+
+
+@registry.register
+class AdaptiveRaySampler(AbstractMaskRaySampler):
+    """
+    Adaptively samples points on each ray between near and far planes whose
+    depths are determined based on the distance from the camera center
+    to a predefined scene center.
+
+    More specifically,
+    `min_depth = max(
+        (self.scene_center-camera_center).norm() - self.scene_extent, eps
+    )` and
+    `max_depth = (self.scene_center-camera_center).norm() + self.scene_extent`.
+
+    This sampling is ideal for object-centric scenes whose contents are
+    centered around a known `self.scene_center` and fit into a bounding sphere
+    with a radius of `self.scene_extent`.
+
+    Args:
+        scene_center: The xyz coordinates of the center of the scene used
+            along with `scene_extent` to compute the min and max depth planes
+            for sampling ray-points.
+        scene_extent: The radius of the scene bounding box centered at `scene_center`.
+    """
+
+    scene_extent: float = 8.0
+    scene_center: Tuple[float, float, float] = (0.0, 0.0, 0.0)
+
+    def __post_init__(self):
+        super().__post_init__()
+        if self.scene_extent <= 0.0:
+            raise ValueError("Adaptive raysampler requires self.scene_extent > 0.")
+        self._scene_center = torch.FloatTensor(self.scene_center)
+
+    def _get_min_max_depth_bounds(self, cameras: CamerasBase) -> Tuple[float, float]:
+        """
+        Returns the adaptivelly calculated near/far planes.
+        """
+        min_depth, max_depth = camera_utils.get_min_max_depth_bounds(
+            cameras, self._scene_center, self.scene_extent
+        )
+        return float(min_depth[0]), float(max_depth[0])
+
+
+@registry.register
+class NearFarRaySampler(AbstractMaskRaySampler):
+    """
+    Samples a fixed number of points between fixed near and far z-planes.
+    Specifically, samples points along each ray with approximately uniform spacing
+    of z-coordinates between the minimum depth `self.min_depth` and the maximum depth
+    `self.max_depth`. This sampling is useful for rendering scenes where the camera is
+    in a constant distance from the focal point of the scene.
+
+    Args:
+        min_depth: The minimum depth of a ray-point.
+        max_depth: The maximum depth of a ray-point.
+    """
+
+    min_depth: float = 0.1
+    max_depth: float = 8.0
+
+    def _get_min_max_depth_bounds(self, cameras: CamerasBase) -> Tuple[float, float]:
+        """
+        Returns the stored near/far planes.
+        """
+        return self.min_depth, self.max_depth

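As a quick numeric illustration of the `AdaptiveRaySampler` formula above (plain arithmetic, not a library call; the camera distance and epsilon below are made-up values): a camera 10 units from `scene_center` with the default `scene_extent = 8.0` gets near/far planes of 2 and 18.

# Restates the docstring formula with hypothetical numbers.
camera_to_center = 10.0  # hypothetical ||scene_center - camera_center||
scene_extent = 8.0       # AdaptiveRaySampler default
eps = 1e-2               # hypothetical epsilon; the real one lives in camera_utils

min_depth = max(camera_to_center - scene_extent, eps)  # 2.0
max_depth = camera_to_center + scene_extent            # 18.0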
@@ -157,6 +157,15 @@ def cat_dataclass(batch, tensor_collator: Callable):
     return type(elem)(**collated)


+def setattr_if_hasattr(obj, name, value):
+    """
+    Same as setattr(obj, name, value), but does nothing in case `name` is
+    not an attribe of `obj`.
+    """
+    if hasattr(obj, name):
+        setattr(obj, name, value)
+
+
 class Timer:
     """
     A simple class for timing execution.

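A small usage example of the new helper (the `Point` dataclass is hypothetical), matching how `create_raysampler` uses it to forward shared settings only to samplers that actually declare them:

from dataclasses import dataclass

from pytorch3d.implicitron.tools.utils import setattr_if_hasattr  # added in this commit


@dataclass
class Point:  # hypothetical example object
    x: int = 0


p = Point()
setattr_if_hasattr(p, "x", 5)        # `x` exists -> updated
setattr_if_hasattr(p, "label", "a")  # `label` does not exist -> silently skipped
assert p.x == 5 and not hasattr(p, "label")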