diff --git a/projects/implicitron_trainer/configs/repro_base.yaml b/projects/implicitron_trainer/configs/repro_base.yaml index a65401f7..fbfbba49 100644 --- a/projects/implicitron_trainer/configs/repro_base.yaml +++ b/projects/implicitron_trainer/configs/repro_base.yaml @@ -49,10 +49,8 @@ generic_model_args: append_xyz: - 5 latent_dim: 0 - raysampler_args: + raysampler_AdaptiveRaySampler_args: n_rays_per_image_sampled_from_mask: 1024 - min_depth: 0.0 - max_depth: 0.0 scene_extent: 8.0 n_pts_per_ray_training: 64 n_pts_per_ray_evaluation: 64 diff --git a/projects/implicitron_trainer/configs/repro_multiseq_idr_ad.yaml b/projects/implicitron_trainer/configs/repro_multiseq_idr_ad.yaml index 3be35876..d0181eec 100644 --- a/projects/implicitron_trainer/configs/repro_multiseq_idr_ad.yaml +++ b/projects/implicitron_trainer/configs/repro_multiseq_idr_ad.yaml @@ -54,7 +54,7 @@ generic_model_args: n_harmonic_functions_dir: 4 pooled_feature_dim: 0 weight_norm: true - raysampler_args: + raysampler_AdaptiveRaySampler_args: n_rays_per_image_sampled_from_mask: 1024 n_pts_per_ray_training: 0 n_pts_per_ray_evaluation: 0 diff --git a/projects/implicitron_trainer/configs/repro_multiseq_nerf_wce.yaml b/projects/implicitron_trainer/configs/repro_multiseq_nerf_wce.yaml index 2fa03575..00140db6 100644 --- a/projects/implicitron_trainer/configs/repro_multiseq_nerf_wce.yaml +++ b/projects/implicitron_trainer/configs/repro_multiseq_nerf_wce.yaml @@ -6,5 +6,5 @@ clip_grad: 1.0 generic_model_args: chunk_size_grid: 16000 view_pooler_enabled: true - raysampler_args: + raysampler_AdaptiveRaySampler_args: n_rays_per_image_sampled_from_mask: 850 diff --git a/projects/implicitron_trainer/configs/repro_multiseq_nerformer.yaml b/projects/implicitron_trainer/configs/repro_multiseq_nerformer.yaml index 213e1124..c4a20f6e 100644 --- a/projects/implicitron_trainer/configs/repro_multiseq_nerformer.yaml +++ b/projects/implicitron_trainer/configs/repro_multiseq_nerformer.yaml @@ -4,7 +4,7 @@ defaults: - _self_ generic_model_args: chunk_size_grid: 16000 - raysampler_args: + raysampler_AdaptiveRaySampler_args: n_rays_per_image_sampled_from_mask: 800 n_pts_per_ray_training: 32 n_pts_per_ray_evaluation: 32 diff --git a/projects/implicitron_trainer/configs/repro_multiseq_nerformer_angle_w.yaml b/projects/implicitron_trainer/configs/repro_multiseq_nerformer_angle_w.yaml index 982c5eaa..61f6ebb4 100644 --- a/projects/implicitron_trainer/configs/repro_multiseq_nerformer_angle_w.yaml +++ b/projects/implicitron_trainer/configs/repro_multiseq_nerformer_angle_w.yaml @@ -1,17 +1,6 @@ defaults: -- repro_multiseq_base.yaml -- repro_feat_extractor_transformer.yaml +- repro_multiseq_nerformer.yaml - _self_ generic_model_args: - chunk_size_grid: 16000 - raysampler_args: - n_rays_per_image_sampled_from_mask: 800 - n_pts_per_ray_training: 32 - n_pts_per_ray_evaluation: 32 - renderer_MultiPassEmissionAbsorptionRenderer_args: - n_pts_per_ray_fine_training: 16 - n_pts_per_ray_fine_evaluation: 16 - implicit_function_class_type: NeRFormerImplicitFunction - view_pooler_enabled: true view_pooler_args: feature_aggregator_class_type: AngleWeightedIdentityFeatureAggregator diff --git a/projects/implicitron_trainer/configs/repro_multiseq_srn_ad_hypernet.yaml b/projects/implicitron_trainer/configs/repro_multiseq_srn_ad_hypernet.yaml index e719f646..8c8ef0d7 100644 --- a/projects/implicitron_trainer/configs/repro_multiseq_srn_ad_hypernet.yaml +++ b/projects/implicitron_trainer/configs/repro_multiseq_srn_ad_hypernet.yaml @@ -16,11 +16,11 @@ generic_model_args: sequence_autodecoder_args: encoding_dim: 256 n_instances: 20000 - raysampler_args: + raysampler_class_type: NearFarRaySampler + raysampler_NearFarRaySampler_args: n_rays_per_image_sampled_from_mask: 2048 min_depth: 0.05 max_depth: 0.05 - scene_extent: 0.0 n_pts_per_ray_training: 1 n_pts_per_ray_evaluation: 1 stratified_point_sampling_training: false diff --git a/projects/implicitron_trainer/configs/repro_multiseq_srn_wce.yaml b/projects/implicitron_trainer/configs/repro_multiseq_srn_wce.yaml index 02425a6b..d340c18a 100644 --- a/projects/implicitron_trainer/configs/repro_multiseq_srn_wce.yaml +++ b/projects/implicitron_trainer/configs/repro_multiseq_srn_wce.yaml @@ -13,11 +13,11 @@ generic_model_args: loss_prev_stage_mask_bce: 0.0 loss_autodecoder_norm: 0.0 depth_neg_penalty: 10000.0 - raysampler_args: + raysampler_class_type: NearFarRaySampler + raysampler_NearFarRaySampler_args: n_rays_per_image_sampled_from_mask: 2048 min_depth: 0.05 max_depth: 0.05 - scene_extent: 0.0 n_pts_per_ray_training: 1 n_pts_per_ray_evaluation: 1 stratified_point_sampling_training: false diff --git a/projects/implicitron_trainer/configs/repro_singleseq_idr.yaml b/projects/implicitron_trainer/configs/repro_singleseq_idr.yaml index 1baa2523..bb587056 100644 --- a/projects/implicitron_trainer/configs/repro_singleseq_idr.yaml +++ b/projects/implicitron_trainer/configs/repro_singleseq_idr.yaml @@ -49,7 +49,7 @@ generic_model_args: n_harmonic_functions_dir: 4 pooled_feature_dim: 0 weight_norm: true - raysampler_args: + raysampler_AdaptiveRaySampler_args: n_rays_per_image_sampled_from_mask: 1024 n_pts_per_ray_training: 0 n_pts_per_ray_evaluation: 0 diff --git a/projects/implicitron_trainer/configs/repro_singleseq_nerf.yaml b/projects/implicitron_trainer/configs/repro_singleseq_nerf.yaml index d6d45585..fd85af5e 100644 --- a/projects/implicitron_trainer/configs/repro_singleseq_nerf.yaml +++ b/projects/implicitron_trainer/configs/repro_singleseq_nerf.yaml @@ -1,4 +1,3 @@ defaults: - repro_singleseq_base - _self_ -exp_dir: ./data/nerf_single_apple/ diff --git a/projects/implicitron_trainer/configs/repro_singleseq_nerf_wce.yaml b/projects/implicitron_trainer/configs/repro_singleseq_nerf_wce.yaml index 31ee863e..5a587dbe 100644 --- a/projects/implicitron_trainer/configs/repro_singleseq_nerf_wce.yaml +++ b/projects/implicitron_trainer/configs/repro_singleseq_nerf_wce.yaml @@ -5,5 +5,5 @@ defaults: generic_model_args: chunk_size_grid: 16000 view_pooler_enabled: true - raysampler_args: + raysampler_AdaptiveRaySampler_args: n_rays_per_image_sampled_from_mask: 850 diff --git a/projects/implicitron_trainer/configs/repro_singleseq_nerformer.yaml b/projects/implicitron_trainer/configs/repro_singleseq_nerformer.yaml index 2fad73c6..37b08dfa 100644 --- a/projects/implicitron_trainer/configs/repro_singleseq_nerformer.yaml +++ b/projects/implicitron_trainer/configs/repro_singleseq_nerformer.yaml @@ -6,7 +6,7 @@ generic_model_args: chunk_size_grid: 16000 view_pooler_enabled: true implicit_function_class_type: NeRFormerImplicitFunction - raysampler_args: + raysampler_AdaptiveRaySampler_args: n_rays_per_image_sampled_from_mask: 800 n_pts_per_ray_training: 32 n_pts_per_ray_evaluation: 32 diff --git a/projects/implicitron_trainer/configs/repro_singleseq_srn.yaml b/projects/implicitron_trainer/configs/repro_singleseq_srn.yaml index e6a84489..6500f56a 100644 --- a/projects/implicitron_trainer/configs/repro_singleseq_srn.yaml +++ b/projects/implicitron_trainer/configs/repro_singleseq_srn.yaml @@ -12,11 +12,11 @@ generic_model_args: loss_prev_stage_mask_bce: 0.0 loss_autodecoder_norm: 0.0 depth_neg_penalty: 10000.0 - raysampler_args: + raysampler_class_type: NearFarRaySampler + raysampler_NearFarRaySampler_args: n_rays_per_image_sampled_from_mask: 2048 min_depth: 0.05 max_depth: 0.05 - scene_extent: 0.0 n_pts_per_ray_training: 1 n_pts_per_ray_evaluation: 1 stratified_point_sampling_training: false diff --git a/projects/implicitron_trainer/configs/repro_singleseq_srn_wce.yaml b/projects/implicitron_trainer/configs/repro_singleseq_srn_wce.yaml index c66b21c2..3da29f06 100644 --- a/projects/implicitron_trainer/configs/repro_singleseq_srn_wce.yaml +++ b/projects/implicitron_trainer/configs/repro_singleseq_srn_wce.yaml @@ -13,11 +13,11 @@ generic_model_args: loss_prev_stage_mask_bce: 0.0 loss_autodecoder_norm: 0.0 depth_neg_penalty: 10000.0 - raysampler_args: + raysampler_class_type: NearFarRaySampler + raysampler_NearFarRaySampler_args: n_rays_per_image_sampled_from_mask: 2048 min_depth: 0.05 max_depth: 0.05 - scene_extent: 0.0 n_pts_per_ray_training: 1 n_pts_per_ray_evaluation: 1 stratified_point_sampling_training: false diff --git a/pytorch3d/implicitron/models/generic_model.py b/pytorch3d/implicitron/models/generic_model.py index 80483644..6311c6d5 100644 --- a/pytorch3d/implicitron/models/generic_model.py +++ b/pytorch3d/implicitron/models/generic_model.py @@ -19,7 +19,7 @@ from pytorch3d.implicitron.tools.config import ( run_auto_creation, ) from pytorch3d.implicitron.tools.rasterize_mc import rasterize_mc_samples -from pytorch3d.implicitron.tools.utils import cat_dataclass +from pytorch3d.implicitron.tools.utils import cat_dataclass, setattr_if_hasattr from pytorch3d.renderer import RayBundle, utils as rend_utils from pytorch3d.renderer.cameras import CamerasBase from visdom import Visdom @@ -46,7 +46,7 @@ from .renderer.base import ( ) from .renderer.lstm_renderer import LSTMRenderer # noqa from .renderer.multipass_ea import MultiPassEmissionAbsorptionRenderer # noqa -from .renderer.ray_sampler import RaySampler +from .renderer.ray_sampler import RaySamplerBase from .renderer.sdf_renderer import SignedDistanceFunctionRenderer # noqa from .resnet_feature_extractor import ResNetFeatureExtractor from .view_pooler.view_pooler import ViewPooler @@ -160,6 +160,8 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13 the scene such as multiple objects or morphing objects. It is up to the implicit function definition how to use it, but the most typical way is to broadcast and concatenate to the other inputs for the implicit function. + raysampler_class_type: The name of the raysampler class which is available + in the global registry. raysampler: An instance of RaySampler which is used to emit rays from the target view(s). renderer_class_type: The name of the renderer class which is available in the global @@ -204,7 +206,8 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13 sequence_autodecoder: Autodecoder # ---- raysampler - raysampler: RaySampler + raysampler_class_type: str = "AdaptiveRaySampler" + raysampler: RaySamplerBase # ---- renderer configs renderer_class_type: str = "MultiPassEmissionAbsorptionRenderer" @@ -262,11 +265,6 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13 super().__init__() self.view_metrics = ViewMetrics() - self._check_and_preprocess_renderer_configs() - self.raysampler_args["sampling_mode_training"] = self.sampling_mode_training - self.raysampler_args["sampling_mode_evaluation"] = self.sampling_mode_evaluation - self.raysampler_args["image_width"] = self.render_image_width - self.raysampler_args["image_height"] = self.render_image_height run_auto_creation(self) self._implicit_functions = self._construct_implicit_functions() @@ -339,7 +337,7 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13 ) # (1) Sample rendering rays with the ray sampler. - ray_bundle: RayBundle = self.raysampler( + ray_bundle: RayBundle = self.raysampler( # pyre-fixme[29] target_cameras, evaluation_mode, mask=mask_crop[:n_targets] @@ -565,19 +563,52 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13 else 0 ) - def _check_and_preprocess_renderer_configs(self): + def create_raysampler(self): + raysampler_args = getattr( + self, "raysampler_" + self.raysampler_class_type + "_args" + ) + setattr_if_hasattr( + raysampler_args, "sampling_mode_training", self.sampling_mode_training + ) + setattr_if_hasattr( + raysampler_args, "sampling_mode_evaluation", self.sampling_mode_evaluation + ) + setattr_if_hasattr(raysampler_args, "image_width", self.render_image_width) + setattr_if_hasattr(raysampler_args, "image_height", self.render_image_height) + self.raysampler = registry.get(RaySamplerBase, self.raysampler_class_type)( + **raysampler_args + ) + + def create_renderer(self): + raysampler_args = getattr( + self, "raysampler_" + self.raysampler_class_type + "_args" + ) self.renderer_MultiPassEmissionAbsorptionRenderer_args[ "stratified_sampling_coarse_training" - ] = self.raysampler_args["stratified_point_sampling_training"] + ] = raysampler_args["stratified_point_sampling_training"] self.renderer_MultiPassEmissionAbsorptionRenderer_args[ "stratified_sampling_coarse_evaluation" - ] = self.raysampler_args["stratified_point_sampling_evaluation"] + ] = raysampler_args["stratified_point_sampling_evaluation"] self.renderer_SignedDistanceFunctionRenderer_args[ "render_features_dimensions" ] = self.render_features_dimensions - self.renderer_SignedDistanceFunctionRenderer_args.ray_tracer_args[ - "object_bounding_sphere" - ] = self.raysampler_args["scene_extent"] + + if self.renderer_class_type == "SignedDistanceFunctionRenderer": + if "scene_extent" not in raysampler_args: + raise ValueError( + "SignedDistanceFunctionRenderer requires" + + " a raysampler that defines the 'scene_extent' field" + + " (this field is supported by, e.g., the adaptive raysampler - " + + " self.raysampler_class_type='AdaptiveRaySampler')." + ) + self.renderer_SignedDistanceFunctionRenderer_args.ray_tracer_args[ + "object_bounding_sphere" + ] = self.raysampler_AdaptiveRaySampler_args["scene_extent"] + + renderer_args = getattr(self, "renderer_" + self.renderer_class_type + "_args") + self.renderer = registry.get(BaseRenderer, self.renderer_class_type)( + **renderer_args + ) def create_view_pooler(self): """ diff --git a/pytorch3d/implicitron/models/renderer/ray_sampler.py b/pytorch3d/implicitron/models/renderer/ray_sampler.py index 4f5cee69..ef381a2f 100644 --- a/pytorch3d/implicitron/models/renderer/ray_sampler.py +++ b/pytorch3d/implicitron/models/renderer/ray_sampler.py @@ -9,16 +9,48 @@ from typing import Optional, Tuple import torch from pytorch3d.implicitron.tools import camera_utils -from pytorch3d.implicitron.tools.config import Configurable +from pytorch3d.implicitron.tools.config import ReplaceableBase, registry from pytorch3d.renderer import NDCMultinomialRaysampler, RayBundle from pytorch3d.renderer.cameras import CamerasBase from .base import EvaluationMode, RenderSamplingMode -class RaySampler(Configurable, torch.nn.Module): +class RaySamplerBase(ReplaceableBase): + """ + Base class for ray samplers. """ + def __init__(self): + super().__init__() + + def forward( + self, + cameras: CamerasBase, + evaluation_mode: EvaluationMode, + mask: Optional[torch.Tensor] = None, + ) -> RayBundle: + """ + Args: + cameras: A batch of `batch_size` cameras from which the rays are emitted. + evaluation_mode: one of `EvaluationMode.TRAINING` or + `EvaluationMode.EVALUATION` which determines the sampling mode + that is used. + mask: Active for the `RenderSamplingMode.MASK_SAMPLE` sampling mode. + Defines a non-negative mask of shape + `(batch_size, image_height, image_width)` where each per-pixel + value is proportional to the probability of sampling the + corresponding pixel's ray. + + Returns: + ray_bundle: A `RayBundle` object containing the parametrizations of the + sampled rendering rays. + """ + raise NotImplementedError() + + +class AbstractMaskRaySampler(RaySamplerBase, torch.nn.Module): + """ Samples a fixed number of points along rays which are in turn sampled for each camera in a batch. @@ -29,46 +61,19 @@ class RaySampler(Configurable, torch.nn.Module): for training and evaluation by setting `self.sampling_mode_training` and `self.sampling_mode_training` accordingly. - The class allows two modes of sampling points along the rays: - 1) Sampling between fixed near and far z-planes: - Active when `self.scene_extent <= 0`, samples points along each ray - with approximately uniform spacing of z-coordinates between - the minimum depth `self.min_depth` and the maximum depth `self.max_depth`. - This sampling is useful for rendering scenes where the camera is - in a constant distance from the focal point of the scene. - 2) Adaptive near/far plane estimation around the world scene center: - Active when `self.scene_extent > 0`. Samples points on each - ray between near and far planes whose depths are determined based on - the distance from the camera center to a predefined scene center. - More specifically, - `min_depth = max( - (self.scene_center-camera_center).norm() - self.scene_extent, eps - )` and - `max_depth = (self.scene_center-camera_center).norm() + self.scene_extent`. - This sampling is ideal for object-centric scenes whose contents are - centered around a known `self.scene_center` and fit into a bounding sphere - with a radius of `self.scene_extent`. - - Similar to the sampling mode, the sampling parameters can be set separately - for training and evaluation. + The class allows to adjust the sampling points along rays by overwriting the + `AbstractMaskRaySampler._get_min_max_depth_bounds` function which returns + the near/far planes (`min_depth`/`max_depth`) `NDCMultinomialRaysampler`. Settings: image_width: The horizontal size of the image grid. image_height: The vertical size of the image grid. - scene_center: The xyz coordinates of the center of the scene used - along with `scene_extent` to compute the min and max depth planes - for sampling ray-points. - scene_extent: The radius of the scene bounding sphere centered at `scene_center`. - If `scene_extent <= 0`, the raysampler samples points between - `self.min_depth` and `self.max_depth` depths instead. sampling_mode_training: The ray sampling mode for training. This should be a str option from the RenderSamplingMode Enum sampling_mode_evaluation: Same as above but for evaluation. n_pts_per_ray_training: The number of points sampled along each ray during training. n_pts_per_ray_evaluation: The number of points sampled along each ray during evaluation. n_rays_per_image_sampled_from_mask: The amount of rays to be sampled from the image grid - min_depth: The minimum depth of a ray-point. Active when `self.scene_extent > 0`. - max_depth: The maximum depth of a ray-point. Active when `self.scene_extent > 0`. stratified_point_sampling_training: if set, performs stratified random sampling along the ray; otherwise takes ray points at deterministic offsets. stratified_point_sampling_evaluation: Same as above but for evaluation. @@ -77,24 +82,17 @@ class RaySampler(Configurable, torch.nn.Module): image_width: int = 400 image_height: int = 400 - scene_center: Tuple[float, float, float] = field( - default_factory=lambda: (0.0, 0.0, 0.0) - ) - scene_extent: float = 0.0 sampling_mode_training: str = "mask_sample" sampling_mode_evaluation: str = "full_grid" n_pts_per_ray_training: int = 64 n_pts_per_ray_evaluation: int = 64 n_rays_per_image_sampled_from_mask: int = 1024 - min_depth: float = 0.1 - max_depth: float = 8.0 # stratified sampling vs taking points at deterministic offsets stratified_point_sampling_training: bool = True stratified_point_sampling_evaluation: bool = False def __post_init__(self): super().__init__() - self.scene_center = torch.FloatTensor(self.scene_center) self._sampling_mode = { EvaluationMode.TRAINING: RenderSamplingMode(self.sampling_mode_training), @@ -108,8 +106,8 @@ class RaySampler(Configurable, torch.nn.Module): image_width=self.image_width, image_height=self.image_height, n_pts_per_ray=self.n_pts_per_ray_training, - min_depth=self.min_depth, - max_depth=self.max_depth, + min_depth=0.0, + max_depth=0.0, n_rays_per_image=self.n_rays_per_image_sampled_from_mask if self._sampling_mode[EvaluationMode.TRAINING] == RenderSamplingMode.MASK_SAMPLE @@ -121,8 +119,8 @@ class RaySampler(Configurable, torch.nn.Module): image_width=self.image_width, image_height=self.image_height, n_pts_per_ray=self.n_pts_per_ray_evaluation, - min_depth=self.min_depth, - max_depth=self.max_depth, + min_depth=0.0, + max_depth=0.0, n_rays_per_image=self.n_rays_per_image_sampled_from_mask if self._sampling_mode[EvaluationMode.EVALUATION] == RenderSamplingMode.MASK_SAMPLE @@ -132,6 +130,9 @@ class RaySampler(Configurable, torch.nn.Module): ), } + def _get_min_max_depth_bounds(self, cameras: CamerasBase) -> Tuple[float, float]: + raise NotImplementedError() + def forward( self, cameras: CamerasBase, @@ -169,12 +170,7 @@ class RaySampler(Configurable, torch.nn.Module): mode="nearest", )[:, 0] - if self.scene_extent > 0.0: - # Override the min/max depth set in initialization based on the - # input cameras. - min_depth, max_depth = camera_utils.get_min_max_depth_bounds( - cameras, self.scene_center, self.scene_extent - ) + min_depth, max_depth = self._get_min_max_depth_bounds(cameras) # pyre-fixme[29]: # `Union[BoundMethod[typing.Callable(torch.Tensor.__getitem__)[[Named(self, @@ -183,8 +179,75 @@ class RaySampler(Configurable, torch.nn.Module): ray_bundle = self._raysamplers[evaluation_mode]( cameras=cameras, mask=sample_mask, - min_depth=float(min_depth[0]) if self.scene_extent > 0.0 else None, - max_depth=float(max_depth[0]) if self.scene_extent > 0.0 else None, + min_depth=min_depth, + max_depth=max_depth, ) return ray_bundle + + +@registry.register +class AdaptiveRaySampler(AbstractMaskRaySampler): + """ + Adaptively samples points on each ray between near and far planes whose + depths are determined based on the distance from the camera center + to a predefined scene center. + + More specifically, + `min_depth = max( + (self.scene_center-camera_center).norm() - self.scene_extent, eps + )` and + `max_depth = (self.scene_center-camera_center).norm() + self.scene_extent`. + + This sampling is ideal for object-centric scenes whose contents are + centered around a known `self.scene_center` and fit into a bounding sphere + with a radius of `self.scene_extent`. + + Args: + scene_center: The xyz coordinates of the center of the scene used + along with `scene_extent` to compute the min and max depth planes + for sampling ray-points. + scene_extent: The radius of the scene bounding box centered at `scene_center`. + """ + + scene_extent: float = 8.0 + scene_center: Tuple[float, float, float] = (0.0, 0.0, 0.0) + + def __post_init__(self): + super().__post_init__() + if self.scene_extent <= 0.0: + raise ValueError("Adaptive raysampler requires self.scene_extent > 0.") + self._scene_center = torch.FloatTensor(self.scene_center) + + def _get_min_max_depth_bounds(self, cameras: CamerasBase) -> Tuple[float, float]: + """ + Returns the adaptivelly calculated near/far planes. + """ + min_depth, max_depth = camera_utils.get_min_max_depth_bounds( + cameras, self._scene_center, self.scene_extent + ) + return float(min_depth[0]), float(max_depth[0]) + + +@registry.register +class NearFarRaySampler(AbstractMaskRaySampler): + """ + Samples a fixed number of points between fixed near and far z-planes. + Specifically, samples points along each ray with approximately uniform spacing + of z-coordinates between the minimum depth `self.min_depth` and the maximum depth + `self.max_depth`. This sampling is useful for rendering scenes where the camera is + in a constant distance from the focal point of the scene. + + Args: + min_depth: The minimum depth of a ray-point. + max_depth: The maximum depth of a ray-point. + """ + + min_depth: float = 0.1 + max_depth: float = 8.0 + + def _get_min_max_depth_bounds(self, cameras: CamerasBase) -> Tuple[float, float]: + """ + Returns the stored near/far planes. + """ + return self.min_depth, self.max_depth diff --git a/pytorch3d/implicitron/tools/utils.py b/pytorch3d/implicitron/tools/utils.py index 5e70c1c5..4a00dee2 100644 --- a/pytorch3d/implicitron/tools/utils.py +++ b/pytorch3d/implicitron/tools/utils.py @@ -157,6 +157,15 @@ def cat_dataclass(batch, tensor_collator: Callable): return type(elem)(**collated) +def setattr_if_hasattr(obj, name, value): + """ + Same as setattr(obj, name, value), but does nothing in case `name` is + not an attribe of `obj`. + """ + if hasattr(obj, name): + setattr(obj, name, value) + + class Timer: """ A simple class for timing execution.