mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
simplify image_feature_extractor control
Summary: If no view pooling, don't disable image_feature_extractor. Make image_feature_extractor default to absent. Reviewed By: davnov134 Differential Revision: D36547815 fbshipit-source-id: e51718e1bcbf65b8b365a6e894d4324f136635e9
This commit is contained in:
parent
9fe15da3cd
commit
2d1c6d5d93
@ -221,7 +221,7 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
|
||||
# ---- image feature extractor settings
|
||||
# (This is only created if view_pooler is enabled)
|
||||
image_feature_extractor: Optional[FeatureExtractorBase]
|
||||
image_feature_extractor_class_type: Optional[str] = "ResNetFeatureExtractor"
|
||||
image_feature_extractor_class_type: Optional[str] = None
|
||||
# ---- view pooler settings
|
||||
view_pooler_enabled: bool = False
|
||||
view_pooler: Optional[ViewPooler]
|
||||
@ -276,8 +276,6 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
|
||||
raise ValueError(
|
||||
"image_feature_extractor must be present for view pooling."
|
||||
)
|
||||
else:
|
||||
self.image_feature_extractor_class_type = None
|
||||
run_auto_creation(self)
|
||||
|
||||
self._implicit_functions = self._construct_implicit_functions()
|
||||
@ -618,16 +616,6 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
|
||||
**renderer_args
|
||||
)
|
||||
|
||||
def create_view_pooler(self):
|
||||
"""
|
||||
Custom creation function called by run_auto_creation checking
|
||||
that image_feature_extractor is enabled when view_pooler is enabled.
|
||||
"""
|
||||
if self.view_pooler_enabled:
|
||||
self.view_pooler = ViewPooler(**self.view_pooler_args)
|
||||
else:
|
||||
self.view_pooler = None
|
||||
|
||||
def create_implicit_function(self) -> None:
|
||||
"""
|
||||
No-op called by run_auto_creation so that self.implicit_function
|
||||
|
@ -69,6 +69,7 @@ class TestGenericModel(unittest.TestCase):
|
||||
args.view_pooler_args.feature_aggregator_class_type = (
|
||||
"AngleWeightedIdentityFeatureAggregator"
|
||||
)
|
||||
args.image_feature_extractor_class_type = "ResNetFeatureExtractor"
|
||||
args.implicit_function_class_type = "IdrFeatureField"
|
||||
idr_args = args.implicit_function_IdrFeatureField_args
|
||||
idr_args.n_harmonic_functions_xyz = 1729
|
||||
|
Loading…
x
Reference in New Issue
Block a user