From 2d1c6d5d9382651bbf825f5f1677d576305d1f92 Mon Sep 17 00:00:00 2001 From: Jeremy Reizenstein Date: Fri, 20 May 2022 08:32:19 -0700 Subject: [PATCH] simplify image_feature_extractor control Summary: If no view pooling, don't disable image_feature_extractor. Make image_feature_extractor default to absent. Reviewed By: davnov134 Differential Revision: D36547815 fbshipit-source-id: e51718e1bcbf65b8b365a6e894d4324f136635e9 --- pytorch3d/implicitron/models/generic_model.py | 14 +------------- tests/implicitron/test_config_use.py | 1 + 2 files changed, 2 insertions(+), 13 deletions(-) diff --git a/pytorch3d/implicitron/models/generic_model.py b/pytorch3d/implicitron/models/generic_model.py index bd433e69..e798dfe2 100644 --- a/pytorch3d/implicitron/models/generic_model.py +++ b/pytorch3d/implicitron/models/generic_model.py @@ -221,7 +221,7 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13 # ---- image feature extractor settings # (This is only created if view_pooler is enabled) image_feature_extractor: Optional[FeatureExtractorBase] - image_feature_extractor_class_type: Optional[str] = "ResNetFeatureExtractor" + image_feature_extractor_class_type: Optional[str] = None # ---- view pooler settings view_pooler_enabled: bool = False view_pooler: Optional[ViewPooler] @@ -276,8 +276,6 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13 raise ValueError( "image_feature_extractor must be present for view pooling." ) - else: - self.image_feature_extractor_class_type = None run_auto_creation(self) self._implicit_functions = self._construct_implicit_functions() @@ -618,16 +616,6 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13 **renderer_args ) - def create_view_pooler(self): - """ - Custom creation function called by run_auto_creation checking - that image_feature_extractor is enabled when view_pooler is enabled. - """ - if self.view_pooler_enabled: - self.view_pooler = ViewPooler(**self.view_pooler_args) - else: - self.view_pooler = None - def create_implicit_function(self) -> None: """ No-op called by run_auto_creation so that self.implicit_function diff --git a/tests/implicitron/test_config_use.py b/tests/implicitron/test_config_use.py index 3820b93a..a5b87ba5 100644 --- a/tests/implicitron/test_config_use.py +++ b/tests/implicitron/test_config_use.py @@ -69,6 +69,7 @@ class TestGenericModel(unittest.TestCase): args.view_pooler_args.feature_aggregator_class_type = ( "AngleWeightedIdentityFeatureAggregator" ) + args.image_feature_extractor_class_type = "ResNetFeatureExtractor" args.implicit_function_class_type = "IdrFeatureField" idr_args = args.implicit_function_IdrFeatureField_args idr_args.n_harmonic_functions_xyz = 1729