diff --git a/projects/implicitron_trainer/configs/repro_multiseq_idr_ad.yaml b/projects/implicitron_trainer/configs/repro_multiseq_idr_ad.yaml index 56684f6f..f6bb1fe4 100644 --- a/projects/implicitron_trainer/configs/repro_multiseq_idr_ad.yaml +++ b/projects/implicitron_trainer/configs/repro_multiseq_idr_ad.yaml @@ -43,7 +43,6 @@ model_factory_ImplicitronModelFactory_args: line_step_iters: 3 n_secant_steps: 8 n_steps: 100 - object_bounding_sphere: 8.0 sdf_threshold: 5.0e-05 ray_normal_coloring_network_args: d_in: 9 diff --git a/projects/implicitron_trainer/configs/repro_singleseq_idr.yaml b/projects/implicitron_trainer/configs/repro_singleseq_idr.yaml index c936d092..7224b9d5 100644 --- a/projects/implicitron_trainer/configs/repro_singleseq_idr.yaml +++ b/projects/implicitron_trainer/configs/repro_singleseq_idr.yaml @@ -36,7 +36,6 @@ model_factory_ImplicitronModelFactory_args: line_step_iters: 3 n_secant_steps: 8 n_steps: 100 - object_bounding_sphere: 8.0 sdf_threshold: 5.0e-05 ray_normal_coloring_network_args: d_in: 9 diff --git a/projects/implicitron_trainer/tests/experiment.yaml b/projects/implicitron_trainer/tests/experiment.yaml index 90f73a39..d5447cfc 100644 --- a/projects/implicitron_trainer/tests/experiment.yaml +++ b/projects/implicitron_trainer/tests/experiment.yaml @@ -241,15 +241,6 @@ model_factory_ImplicitronModelFactory_args: density_relu: true blend_output: false renderer_SignedDistanceFunctionRenderer_args: - render_features_dimensions: 3 - ray_tracer_args: - object_bounding_sphere: 1.0 - sdf_threshold: 5.0e-05 - line_search_step: 0.5 - line_step_iters: 1 - sphere_tracing_iters: 10 - n_steps: 100 - n_secant_steps: 8 ray_normal_coloring_network_args: feature_vector_size: 3 mode: idr @@ -266,6 +257,13 @@ model_factory_ImplicitronModelFactory_args: bg_color: - 0.0 soft_mask_alpha: 50.0 + ray_tracer_args: + sdf_threshold: 5.0e-05 + line_search_step: 0.5 + line_step_iters: 1 + sphere_tracing_iters: 10 + n_steps: 100 + n_secant_steps: 8 image_feature_extractor_ResNetFeatureExtractor_args: name: resnet34 pretrained: true diff --git a/pytorch3d/implicitron/models/generic_model.py b/pytorch3d/implicitron/models/generic_model.py index afd4616f..2c9531f1 100644 --- a/pytorch3d/implicitron/models/generic_model.py +++ b/pytorch3d/implicitron/models/generic_model.py @@ -641,35 +641,32 @@ class GenericModel(ImplicitronModelBase): # pyre-ignore: 13 **raysampler_args, **extra_args ) + @classmethod + def renderer_tweak_args(cls, type, args: DictConfig) -> None: + """ + We don't expose certain fields of the renderer because we want to set + them based on other inputs. + """ + args.pop("render_features_dimensions", None) + args.pop("object_bounding_sphere", None) + def create_renderer(self): - raysampler_args = getattr( - self, "raysampler_" + self.raysampler_class_type + "_args" - ) - self.renderer_MultiPassEmissionAbsorptionRenderer_args[ - "stratified_sampling_coarse_training" - ] = raysampler_args["stratified_point_sampling_training"] - self.renderer_MultiPassEmissionAbsorptionRenderer_args[ - "stratified_sampling_coarse_evaluation" - ] = raysampler_args["stratified_point_sampling_evaluation"] - self.renderer_SignedDistanceFunctionRenderer_args[ - "render_features_dimensions" - ] = self.render_features_dimensions + extra_args = {} if self.renderer_class_type == "SignedDistanceFunctionRenderer": - if "scene_extent" not in raysampler_args: + extra_args["render_features_dimensions"] = self.render_features_dimensions + if not hasattr(self.raysampler, "scene_extent"): raise ValueError( "SignedDistanceFunctionRenderer requires" + " a raysampler that defines the 'scene_extent' field" + " (this field is supported by, e.g., the adaptive raysampler - " + " self.raysampler_class_type='AdaptiveRaySampler')." ) - self.renderer_SignedDistanceFunctionRenderer_args.ray_tracer_args[ - "object_bounding_sphere" - ] = self.raysampler_AdaptiveRaySampler_args["scene_extent"] + extra_args["object_bounding_sphere"] = self.raysampler.scene_extent renderer_args = getattr(self, "renderer_" + self.renderer_class_type + "_args") self.renderer = registry.get(BaseRenderer, self.renderer_class_type)( - **renderer_args + **renderer_args, **extra_args ) def create_implicit_function(self) -> None: diff --git a/pytorch3d/implicitron/models/renderer/multipass_ea.py b/pytorch3d/implicitron/models/renderer/multipass_ea.py index 76718848..89bceae1 100644 --- a/pytorch3d/implicitron/models/renderer/multipass_ea.py +++ b/pytorch3d/implicitron/models/renderer/multipass_ea.py @@ -53,10 +53,12 @@ class MultiPassEmissionAbsorptionRenderer( # pyre-ignore: 13 fine rendering pass during training. n_pts_per_ray_fine_evaluation: The number of points sampled per ray for the fine rendering pass during evaluation. - stratified_sampling_coarse_training: Enable/disable stratified sampling during - training. - stratified_sampling_coarse_evaluation: Enable/disable stratified sampling during - evaluation. + stratified_sampling_coarse_training: Enable/disable stratified sampling in the + refiner during training. Only matters if there are multiple implicit + functions (i.e. in GenericModel if num_passes>1). + stratified_sampling_coarse_evaluation: Enable/disable stratified sampling in + the refiner during evaluation. Only matters if there are multiple implicit + functions (i.e. in GenericModel if num_passes>1). append_coarse_samples_to_fine: Add the fine ray points to the coarse points after sampling. density_noise_std_train: Standard deviation of the noise added to the diff --git a/pytorch3d/implicitron/models/renderer/sdf_renderer.py b/pytorch3d/implicitron/models/renderer/sdf_renderer.py index edda575f..aa85693e 100644 --- a/pytorch3d/implicitron/models/renderer/sdf_renderer.py +++ b/pytorch3d/implicitron/models/renderer/sdf_renderer.py @@ -8,7 +8,11 @@ from typing import List, Optional, Tuple import torch from omegaconf import DictConfig -from pytorch3d.implicitron.tools.config import get_default_args_field, registry +from pytorch3d.implicitron.tools.config import ( + get_default_args_field, + registry, + run_auto_creation, +) from pytorch3d.implicitron.tools.utils import evaluating from pytorch3d.renderer import RayBundle @@ -18,9 +22,10 @@ from .rgb_net import RayNormalColoringNetwork @registry.register -class SignedDistanceFunctionRenderer(BaseRenderer, torch.nn.Module): +class SignedDistanceFunctionRenderer(BaseRenderer, torch.nn.Module): # pyre-ignore[13] render_features_dimensions: int = 3 - ray_tracer_args: DictConfig = get_default_args_field(RayTracing) + object_bounding_sphere: float = 1.0 + ray_tracer: RayTracing ray_normal_coloring_network_args: DictConfig = get_default_args_field( RayNormalColoringNetwork ) @@ -37,8 +42,7 @@ class SignedDistanceFunctionRenderer(BaseRenderer, torch.nn.Module): f"Background color should have {render_features_dimensions} entries." ) - self.ray_tracer = RayTracing(**self.ray_tracer_args) - self.object_bounding_sphere = self.ray_tracer_args.get("object_bounding_sphere") + run_auto_creation(self) self.ray_normal_coloring_network_args[ "feature_vector_size" @@ -49,6 +53,17 @@ class SignedDistanceFunctionRenderer(BaseRenderer, torch.nn.Module): self.register_buffer("_bg_color", torch.tensor(self.bg_color), persistent=False) + @classmethod + def ray_tracer_tweak_args(cls, type, args: DictConfig) -> None: + del args["object_bounding_sphere"] + + def create_ray_tracer(self) -> None: + self.ray_tracer = RayTracing( + # pyre-ignore[32] + **self.ray_tracer_args, + object_bounding_sphere=self.object_bounding_sphere, + ) + def requires_object_mask(self) -> bool: return True @@ -97,7 +112,6 @@ class SignedDistanceFunctionRenderer(BaseRenderer, torch.nn.Module): object_mask = object_mask.reshape(batch_size, -1) with torch.no_grad(), evaluating(implicit_function): - # pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function. points, network_object_mask, dists = self.ray_tracer( sdf=lambda x: implicit_function(x)[ :, 0 @@ -128,7 +142,6 @@ class SignedDistanceFunctionRenderer(BaseRenderer, torch.nn.Module): N = surface_points.shape[0] # Sample points for the eikonal loss - # pyre-fixme[9] eik_bounding_box: float = self.object_bounding_sphere n_eik_points = batch_size * num_pixels // 2 eikonal_points = torch.empty(