mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 20:02:49 +08:00
simplify image_feature_extractor control
Summary: If no view pooling, don't disable image_feature_extractor. Make image_feature_extractor default to absent. Reviewed By: davnov134 Differential Revision: D36547815 fbshipit-source-id: e51718e1bcbf65b8b365a6e894d4324f136635e9
This commit is contained in:
parent
9fe15da3cd
commit
2d1c6d5d93
@ -221,7 +221,7 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
|
|||||||
# ---- image feature extractor settings
|
# ---- image feature extractor settings
|
||||||
# (This is only created if view_pooler is enabled)
|
# (This is only created if view_pooler is enabled)
|
||||||
image_feature_extractor: Optional[FeatureExtractorBase]
|
image_feature_extractor: Optional[FeatureExtractorBase]
|
||||||
image_feature_extractor_class_type: Optional[str] = "ResNetFeatureExtractor"
|
image_feature_extractor_class_type: Optional[str] = None
|
||||||
# ---- view pooler settings
|
# ---- view pooler settings
|
||||||
view_pooler_enabled: bool = False
|
view_pooler_enabled: bool = False
|
||||||
view_pooler: Optional[ViewPooler]
|
view_pooler: Optional[ViewPooler]
|
||||||
@ -276,8 +276,6 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
"image_feature_extractor must be present for view pooling."
|
"image_feature_extractor must be present for view pooling."
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
self.image_feature_extractor_class_type = None
|
|
||||||
run_auto_creation(self)
|
run_auto_creation(self)
|
||||||
|
|
||||||
self._implicit_functions = self._construct_implicit_functions()
|
self._implicit_functions = self._construct_implicit_functions()
|
||||||
@ -618,16 +616,6 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
|
|||||||
**renderer_args
|
**renderer_args
|
||||||
)
|
)
|
||||||
|
|
||||||
def create_view_pooler(self):
|
|
||||||
"""
|
|
||||||
Custom creation function called by run_auto_creation checking
|
|
||||||
that image_feature_extractor is enabled when view_pooler is enabled.
|
|
||||||
"""
|
|
||||||
if self.view_pooler_enabled:
|
|
||||||
self.view_pooler = ViewPooler(**self.view_pooler_args)
|
|
||||||
else:
|
|
||||||
self.view_pooler = None
|
|
||||||
|
|
||||||
def create_implicit_function(self) -> None:
|
def create_implicit_function(self) -> None:
|
||||||
"""
|
"""
|
||||||
No-op called by run_auto_creation so that self.implicit_function
|
No-op called by run_auto_creation so that self.implicit_function
|
||||||
|
@ -69,6 +69,7 @@ class TestGenericModel(unittest.TestCase):
|
|||||||
args.view_pooler_args.feature_aggregator_class_type = (
|
args.view_pooler_args.feature_aggregator_class_type = (
|
||||||
"AngleWeightedIdentityFeatureAggregator"
|
"AngleWeightedIdentityFeatureAggregator"
|
||||||
)
|
)
|
||||||
|
args.image_feature_extractor_class_type = "ResNetFeatureExtractor"
|
||||||
args.implicit_function_class_type = "IdrFeatureField"
|
args.implicit_function_class_type = "IdrFeatureField"
|
||||||
idr_args = args.implicit_function_IdrFeatureField_args
|
idr_args = args.implicit_function_IdrFeatureField_args
|
||||||
idr_args.n_harmonic_functions_xyz = 1729
|
idr_args.n_harmonic_functions_xyz = 1729
|
||||||
|
Loading…
x
Reference in New Issue
Block a user