simplify image_feature_extractor control

Summary: If no view pooling, don't disable image_feature_extractor. Make image_feature_extractor default to absent.

Reviewed By: davnov134

Differential Revision: D36547815

fbshipit-source-id: e51718e1bcbf65b8b365a6e894d4324f136635e9
This commit is contained in:
Jeremy Reizenstein 2022-05-20 08:32:19 -07:00 committed by Facebook GitHub Bot
parent 9fe15da3cd
commit 2d1c6d5d93
2 changed files with 2 additions and 13 deletions

View File

@ -221,7 +221,7 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
# ---- image feature extractor settings
# (This is only created if view_pooler is enabled)
image_feature_extractor: Optional[FeatureExtractorBase]
image_feature_extractor_class_type: Optional[str] = "ResNetFeatureExtractor"
image_feature_extractor_class_type: Optional[str] = None
# ---- view pooler settings
view_pooler_enabled: bool = False
view_pooler: Optional[ViewPooler]
@ -276,8 +276,6 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
raise ValueError(
"image_feature_extractor must be present for view pooling."
)
else:
self.image_feature_extractor_class_type = None
run_auto_creation(self)
self._implicit_functions = self._construct_implicit_functions()
@ -618,16 +616,6 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
**renderer_args
)
def create_view_pooler(self):
"""
Custom creation function called by run_auto_creation checking
that image_feature_extractor is enabled when view_pooler is enabled.
"""
if self.view_pooler_enabled:
self.view_pooler = ViewPooler(**self.view_pooler_args)
else:
self.view_pooler = None
def create_implicit_function(self) -> None:
"""
No-op called by run_auto_creation so that self.implicit_function

View File

@ -69,6 +69,7 @@ class TestGenericModel(unittest.TestCase):
args.view_pooler_args.feature_aggregator_class_type = (
"AngleWeightedIdentityFeatureAggregator"
)
args.image_feature_extractor_class_type = "ResNetFeatureExtractor"
args.implicit_function_class_type = "IdrFeatureField"
idr_args = args.implicit_function_IdrFeatureField_args
idr_args.n_harmonic_functions_xyz = 1729