From 46e82efb4e865e24a1351f20a16cf74695cfed5d Mon Sep 17 00:00:00 2001 From: Jeremy Reizenstein Date: Wed, 3 Aug 2022 12:37:31 -0700 Subject: [PATCH] clean IF args Summary: continued - avoid duplicate inputs Reviewed By: davnov134 Differential Revision: D38248827 fbshipit-source-id: 91ed398e304496a936f66e7a70ab3d189eeb5c70 --- .../configs/repro_base.yaml | 1 - .../implicitron_trainer/tests/experiment.yaml | 9 --- pytorch3d/implicitron/models/generic_model.py | 57 +++++++++++-------- .../scene_representation_networks.py | 28 +++++++++ tests/implicitron/data/overrides.yaml | 2 - tests/implicitron/test_srn.py | 5 +- 6 files changed, 65 insertions(+), 37 deletions(-) diff --git a/projects/implicitron_trainer/configs/repro_base.yaml b/projects/implicitron_trainer/configs/repro_base.yaml index 112d269f..9d6af260 100644 --- a/projects/implicitron_trainer/configs/repro_base.yaml +++ b/projects/implicitron_trainer/configs/repro_base.yaml @@ -43,7 +43,6 @@ model_factory_ImplicitronModelFactory_args: n_layers_xyz: 8 append_xyz: - 5 - latent_dim: 0 raysampler_AdaptiveRaySampler_args: n_rays_per_image_sampled_from_mask: 1024 scene_extent: 8.0 diff --git a/projects/implicitron_trainer/tests/experiment.yaml b/projects/implicitron_trainer/tests/experiment.yaml index d5447cfc..f4bf12e9 100644 --- a/projects/implicitron_trainer/tests/experiment.yaml +++ b/projects/implicitron_trainer/tests/experiment.yaml @@ -313,7 +313,6 @@ model_factory_ImplicitronModelFactory_args: - AVG - STD implicit_function_IdrFeatureField_args: - feature_vector_size: 3 d_in: 3 d_out: 1 dims: @@ -331,15 +330,12 @@ model_factory_ImplicitronModelFactory_args: weight_norm: true n_harmonic_functions_xyz: 0 pooled_feature_dim: 0 - encoding_dim: 0 implicit_function_NeRFormerImplicitFunction_args: n_harmonic_functions_xyz: 10 n_harmonic_functions_dir: 4 n_hidden_neurons_dir: 128 - latent_dim: 0 input_xyz: true xyz_ray_dir_in_camera_coords: false - color_dim: 3 transformer_dim_down_factor: 2.0 n_hidden_neurons_xyz: 80 n_layers_xyz: 2 @@ -349,10 +345,8 @@ model_factory_ImplicitronModelFactory_args: n_harmonic_functions_xyz: 10 n_harmonic_functions_dir: 4 n_hidden_neurons_dir: 128 - latent_dim: 0 input_xyz: true xyz_ray_dir_in_camera_coords: false - color_dim: 3 transformer_dim_down_factor: 1.0 n_hidden_neurons_xyz: 256 n_layers_xyz: 8 @@ -367,8 +361,6 @@ model_factory_ImplicitronModelFactory_args: n_layers_hypernet: 1 in_features: 3 out_features: 256 - latent_dim_hypernet: 0 - latent_dim: 0 xyz_in_camera_coords: false pixel_generator_args: n_harmonic_functions: 4 @@ -385,7 +377,6 @@ model_factory_ImplicitronModelFactory_args: n_layers: 2 in_features: 3 out_features: 256 - latent_dim: 0 xyz_in_camera_coords: false raymarch_function: null pixel_generator_args: diff --git a/pytorch3d/implicitron/models/generic_model.py b/pytorch3d/implicitron/models/generic_model.py index 2c9531f1..d4506245 100644 --- a/pytorch3d/implicitron/models/generic_model.py +++ b/pytorch3d/implicitron/models/generic_model.py @@ -677,6 +677,18 @@ class GenericModel(ImplicitronModelBase): # pyre-ignore: 13 """ pass + @classmethod + def implicit_function_tweak_args(cls, type, args: DictConfig) -> None: + """ + We don't expose certain implicit_function fields because we want to set + them based on other inputs. + """ + args.pop("feature_vector_size", None) + args.pop("encoding_dim", None) + args.pop("latent_dim", None) + args.pop("latent_dim_hypernet", None) + args.pop("color_dim", None) + def _construct_implicit_functions(self): """ After run_auto_creation has been called, the arguments @@ -686,32 +698,31 @@ class GenericModel(ImplicitronModelBase): # pyre-ignore: 13 implicit function method. Then the required implicit function(s) are initialized. """ - # nerf preprocessing - nerf_args = self.implicit_function_NeuralRadianceFieldImplicitFunction_args - nerformer_args = self.implicit_function_NeRFormerImplicitFunction_args - nerf_args["latent_dim"] = nerformer_args["latent_dim"] = ( - self._get_viewpooled_feature_dim() + self._get_global_encoder_encoding_dim() - ) - nerf_args["color_dim"] = nerformer_args[ - "color_dim" - ] = self.render_features_dimensions + extra_args = {} + if self.implicit_function_class_type in ( + "NeuralRadianceFieldImplicitFunction", + "NeRFormerImplicitFunction", + ): + extra_args["latent_dim"] = ( + self._get_viewpooled_feature_dim() + + self._get_global_encoder_encoding_dim() + ) + extra_args["color_dim"] = self.render_features_dimensions - # idr preprocessing - idr = self.implicit_function_IdrFeatureField_args - idr["feature_vector_size"] = self.render_features_dimensions - idr["encoding_dim"] = self._get_global_encoder_encoding_dim() + if self.implicit_function_class_type == "IdrFeatureField": + extra_args["feature_vector_size"] = self.render_features_dimensions + extra_args["encoding_dim"] = self._get_global_encoder_encoding_dim() - # srn preprocessing - srn = self.implicit_function_SRNImplicitFunction_args - srn.raymarch_function_args.latent_dim = ( - self._get_viewpooled_feature_dim() + self._get_global_encoder_encoding_dim() - ) + if self.implicit_function_class_type == "SRNImplicitFunction": + extra_args["latent_dim"] = ( + self._get_viewpooled_feature_dim() + + self._get_global_encoder_encoding_dim() + ) # srn_hypernet preprocessing - srn_hypernet = self.implicit_function_SRNHyperNetImplicitFunction_args - srn_hypernet_args = srn_hypernet.hypernet_args - srn_hypernet_args.latent_dim_hypernet = self._get_global_encoder_encoding_dim() - srn_hypernet_args.latent_dim = self._get_viewpooled_feature_dim() + if self.implicit_function_class_type == "SRNHyperNetImplicitFunction": + extra_args["latent_dim"] = self._get_viewpooled_feature_dim() + extra_args["latent_dim_hypernet"] = self._get_global_encoder_encoding_dim() # check that for srn, srn_hypernet, idr we have self.num_passes=1 implicit_function_type = registry.get( @@ -734,7 +745,7 @@ class GenericModel(ImplicitronModelBase): # pyre-ignore: 13 if config is None: raise ValueError(f"{config_name} not present") implicit_functions_list = [ - ImplicitFunctionWrapper(implicit_function_type(**config)) + ImplicitFunctionWrapper(implicit_function_type(**config, **extra_args)) for _ in range(self.num_passes) ] return torch.nn.ModuleList(implicit_functions_list) diff --git a/pytorch3d/implicitron/models/implicit_function/scene_representation_networks.py b/pytorch3d/implicitron/models/implicit_function/scene_representation_networks.py index faf6fdfe..b7f10a95 100644 --- a/pytorch3d/implicitron/models/implicit_function/scene_representation_networks.py +++ b/pytorch3d/implicitron/models/implicit_function/scene_representation_networks.py @@ -4,6 +4,7 @@ from typing import Any, cast, Optional, Tuple import torch +from omegaconf import DictConfig from pytorch3d.common.linear_with_repeat import LinearWithRepeat from pytorch3d.implicitron.third_party import hyperlayers, pytorch_prototyping from pytorch3d.implicitron.tools.config import Configurable, registry, run_auto_creation @@ -327,6 +328,7 @@ class SRNRaymarchHyperNet(Configurable, torch.nn.Module): @registry.register # pyre-fixme[13]: Uninitialized attribute class SRNImplicitFunction(ImplicitFunctionBase, torch.nn.Module): + latent_dim: int = 0 raymarch_function: SRNRaymarchFunction pixel_generator: SRNPixelGenerator @@ -334,6 +336,17 @@ class SRNImplicitFunction(ImplicitFunctionBase, torch.nn.Module): super().__init__() run_auto_creation(self) + def create_raymarch_function(self) -> None: + self.raymarch_function = SRNRaymarchFunction( + latent_dim=self.latent_dim, + # pyre-ignore[32] + **self.raymarch_function_args, + ) + + @classmethod + def raymarch_function_tweak_args(cls, type, args: DictConfig) -> None: + args.pop("latent_dim", None) + def forward( self, ray_bundle: RayBundle, @@ -371,6 +384,8 @@ class SRNHyperNetImplicitFunction(ImplicitFunctionBase, torch.nn.Module): the cache. """ + latent_dim_hypernet: int = 0 + latent_dim: int = 0 hypernet: SRNRaymarchHyperNet pixel_generator: SRNPixelGenerator @@ -378,6 +393,19 @@ class SRNHyperNetImplicitFunction(ImplicitFunctionBase, torch.nn.Module): super().__init__() run_auto_creation(self) + def create_hypernet(self) -> None: + self.hypernet = SRNRaymarchHyperNet( + latent_dim=self.latent_dim, + latent_dim_hypernet=self.latent_dim_hypernet, + # pyre-ignore[32] + **self.hypernet_args, + ) + + @classmethod + def hypernet_tweak_args(cls, type, args: DictConfig) -> None: + args.pop("latent_dim", None) + args.pop("latent_dim_hypernet", None) + def forward( self, ray_bundle: RayBundle, diff --git a/tests/implicitron/data/overrides.yaml b/tests/implicitron/data/overrides.yaml index 958fc649..7bbd5df0 100644 --- a/tests/implicitron/data/overrides.yaml +++ b/tests/implicitron/data/overrides.yaml @@ -103,7 +103,6 @@ view_pooler_args: weight_by_ray_angle_gamma: 1.0 min_ray_angle_weight: 0.1 implicit_function_IdrFeatureField_args: - feature_vector_size: 3 d_in: 3 d_out: 1 dims: @@ -121,6 +120,5 @@ implicit_function_IdrFeatureField_args: weight_norm: true n_harmonic_functions_xyz: 1729 pooled_feature_dim: 0 - encoding_dim: 0 view_metrics_ViewMetrics_args: {} regularization_metrics_RegularizationMetrics_args: {} diff --git a/tests/implicitron/test_srn.py b/tests/implicitron/test_srn.py index 48b71bf1..a50c341c 100644 --- a/tests/implicitron/test_srn.py +++ b/tests/implicitron/test_srn.py @@ -55,9 +55,10 @@ class TestSRN(TestCaseMixin, unittest.TestCase): def test_srn_hypernet_implicit_function(self): # TODO investigate: If latent_dim_hypernet=0, why does this crash and dump core? latent_dim_hypernet = 39 - hypernet_args = {"latent_dim_hypernet": latent_dim_hypernet} device = torch.device("cuda:0") - implicit_function = SRNHyperNetImplicitFunction(hypernet_args=hypernet_args) + implicit_function = SRNHyperNetImplicitFunction( + latent_dim_hypernet=latent_dim_hypernet + ) implicit_function.to(device) global_code = torch.rand(_BATCH_SIZE, latent_dim_hypernet, device=device) bundle = self._get_bundle(device=device)