From 899a3192b6d34f892c35764cda581fb9f7fffd9c Mon Sep 17 00:00:00 2001 From: Jeremy Reizenstein Date: Mon, 16 May 2022 04:42:03 -0700 Subject: [PATCH] create_x_impl Summary: Make create_x delegate to create_x_impl so that users can rely on create_x_impl in their overrides of create_x. Reviewed By: shapovalov, davnov134 Differential Revision: D35929810 fbshipit-source-id: 80595894ee93346b881729995775876b016fc08e --- pytorch3d/implicitron/tools/config.py | 113 ++++++++++++++++++++------ tests/implicitron/test_config.py | 55 +++++++++++++ 2 files changed, 145 insertions(+), 23 deletions(-) diff --git a/pytorch3d/implicitron/tools/config.py b/pytorch3d/implicitron/tools/config.py index 2e806d4d..cf9970b6 100644 --- a/pytorch3d/implicitron/tools/config.py +++ b/pytorch3d/implicitron/tools/config.py @@ -175,6 +175,8 @@ _unprocessed_warning: str = ( TYPE_SUFFIX: str = "_class_type" ARGS_SUFFIX: str = "_args" ENABLED_SUFFIX: str = "_enabled" +CREATE_PREFIX: str = "create_" +IMPL_SUFFIX: str = "_impl" class ReplaceableBase: @@ -375,25 +377,68 @@ def _default_create( Returns: Function taking one argument, the object whose member should be - initialized. + initialized, i.e. self. """ + impl_name = f"{CREATE_PREFIX}{name}{IMPL_SUFFIX}" def inner(self): expand_args_fields(type_) + impl = getattr(self, impl_name) args = getattr(self, name + ARGS_SUFFIX) - setattr(self, name, type_(**args)) + impl(True, args) def inner_optional(self): expand_args_fields(type_) + impl = getattr(self, impl_name) enabled = getattr(self, name + ENABLED_SUFFIX) + args = getattr(self, name + ARGS_SUFFIX) + impl(enabled, args) + + def inner_pluggable(self): + type_name = getattr(self, name + TYPE_SUFFIX) + impl = getattr(self, impl_name) + if type_name is None: + args = None + else: + args = getattr(self, f"{name}_{type_name}{ARGS_SUFFIX}", None) + impl(type_name, args) + + if process_type == _ProcessType.OPTIONAL_CONFIGURABLE: + return inner_optional + return inner if process_type == _ProcessType.CONFIGURABLE else inner_pluggable + + +def _default_create_impl( + name: str, type_: Type, process_type: _ProcessType +) -> Callable[[Any, Any, DictConfig], None]: + """ + Return the default internal function for initialising a member. This is a function + which could be called in the create_ function to initialise the member. + + Args: + name: name of the member + type_: type of the member (with any Optional removed) + process_type: Shows whether member's declared type inherits ReplaceableBase, + in which case the actual type to be created is decided at + runtime. + + Returns: + Function taking + - self, the object whose member should be initialized. + - option for what to do. This is + - for pluggables, the type to initialise or None to do nothing + - for non pluggables, a bool indicating whether to initialise. + - the args for initializing the member. + """ + + def create_configurable(self, enabled, args): if enabled: - args = getattr(self, name + ARGS_SUFFIX) + expand_args_fields(type_) setattr(self, name, type_(**args)) else: setattr(self, name, None) - def inner_pluggable(self): - type_name = getattr(self, name + TYPE_SUFFIX) + def create_pluggable(self, type_name, args): if type_name is None: setattr(self, name, None) return @@ -408,12 +453,11 @@ def _default_create( # were made in the redefinition will not be reflected here. warnings.warn(f"New implementation of {type_name} is being chosen.") expand_args_fields(chosen_class) - args = getattr(self, f"{name}_{type_name}{ARGS_SUFFIX}") setattr(self, name, chosen_class(**args)) - if process_type == _ProcessType.OPTIONAL_CONFIGURABLE: - return inner_optional - return inner if process_type == _ProcessType.CONFIGURABLE else inner_pluggable + if process_type in (_ProcessType.CONFIGURABLE, _ProcessType.OPTIONAL_CONFIGURABLE): + return create_configurable + return create_pluggable def run_auto_creation(self: Any) -> None: @@ -628,11 +672,12 @@ def expand_args_fields( x_Y_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(Y)) x_Z_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(Z)) def create_x(self): - x_type = registry.get(X, self.x_class_type) + args = self.getattr(f"x_{self.x_class_type}_args") + self.create_x_impl(self.x_class_type, args) + def create_x_impl(self, x_type, args): + x_type = registry.get(X, x_type) expand_args_fields(x_type) - self.x = x_type( - **self.getattr(f"x_{self.x_class_type}_args) - ) + self.x = x_type(**args) x_class_type: str = "UNDEFAULTED" without adding the optional attributes if they are already there. @@ -652,14 +697,19 @@ def expand_args_fields( x_Z_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(Z)) def create_x(self): if self.x_class_type is None: + args = None + else: + args = self.getattr(f"x_{self.x_class_type}_args", None) + self.create_x_impl(self.x_class_type, args) + def create_x_impl(self, x_class_type, args): + if x_class_type is None: self.x = None return - x_type = registry.get(X, self.x_class_type) + x_type = registry.get(X, x_class_type) expand_args_fields(x_type) - self.x = x_type( - **self.getattr(f"x_{self.x_class_type}_args) - ) + assert args is not None + self.x = x_type(**args) x_class_type: Optional[str] = "UNDEFAULTED" without adding the optional attributes if they are already there. @@ -676,8 +726,14 @@ def expand_args_fields( x_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(X)) def create_x(self): - expand_args_fields(X) - self.x = X(self.x_args) + self.create_x_impl(True, self.x_args) + + def create_x_impl(self, enabled, args): + if enabled: + expand_args_fields(X) + self.x = X(**args) + else: + self.x = None Similarly, replace, @@ -693,9 +749,12 @@ def expand_args_fields( x_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(X)) x_enabled: bool = False def create_x(self): - if self.x_enabled: + self.create_x_impl(self.x_enabled, self.x_args) + + def create_x_impl(self, enabled, args): + if enabled: expand_args_fields(X) - self.x = X(self.x_args) + self.x = X(**args) else: self.x = None @@ -703,7 +762,7 @@ def expand_args_fields( Also adds the following class members, unannotated so that dataclass ignores them. - _creation_functions: Tuple[str] of all the create_ functions, - including those from base classes. + including those from base classes (not the create_x_impl ones). - _known_implementations: Dict[str, Type] containing the classes which have been found from the registry. (used only to raise a warning if it one has been overwritten) @@ -918,7 +977,7 @@ def _process_member( some_class.__annotations__[enabled_name] = bool setattr(some_class, enabled_name, False) - creation_function_name = f"create_{name}" + creation_function_name = f"{CREATE_PREFIX}{name}" if not hasattr(some_class, creation_function_name): setattr( some_class, @@ -927,6 +986,14 @@ def _process_member( ) creation_functions.append(creation_function_name) + creation_function_impl_name = f"{CREATE_PREFIX}{name}{IMPL_SUFFIX}" + if not hasattr(some_class, creation_function_impl_name): + setattr( + some_class, + creation_function_impl_name, + _default_create_impl(name, type_, process_type), + ) + def remove_unused_components(dict_: DictConfig) -> None: """ diff --git a/tests/implicitron/test_config.py b/tests/implicitron/test_config.py index 7076c656..2511566f 100644 --- a/tests/implicitron/test_config.py +++ b/tests/implicitron/test_config.py @@ -69,6 +69,10 @@ class LargePear(Pear): pass +class BoringConfigurable(Configurable): + pass + + class MainTest(Configurable): the_fruit: Fruit n_ids: int @@ -674,6 +678,57 @@ class TestConfig(unittest.TestCase): remove_unused_components(args) self.assertEqual(OmegaConf.to_yaml(args), "mt_enabled: false\n") + def test_impls(self): + # Check that create_x actually uses create_x_impl to do its work + # by using all the member types, both with a faked impl function + # and without. + # members with _0 are optional and absent, those with _o are + # optional and present. + control_args = [] + + def fake_impl(self, control, args): + control_args.append(control) + + for fake in [False, True]: + + class MyClass(Configurable): + fruit: Fruit + fruit_class_type: str = "Orange" + fruit_o: Optional[Fruit] + fruit_o_class_type: str = "Orange" + fruit_0: Optional[Fruit] + fruit_0_class_type: Optional[str] = None + boring: BoringConfigurable + boring_o: Optional[BoringConfigurable] + boring_o_enabled: bool = True + boring_0: Optional[BoringConfigurable] + + def __post_init__(self): + run_auto_creation(self) + + if fake: + MyClass.create_fruit_impl = fake_impl + MyClass.create_fruit_o_impl = fake_impl + MyClass.create_boring_impl = fake_impl + MyClass.create_boring_o_impl = fake_impl + + expand_args_fields(MyClass) + instance = MyClass() + for name in ["fruit", "fruit_o", "boring", "boring_o"]: + self.assertEqual( + hasattr(instance, name), not fake, msg=f"{name} {fake}" + ) + + self.assertIsNone(instance.fruit_0) + self.assertIsNone(instance.boring_0) + if not fake: + self.assertIsInstance(instance.fruit, Orange) + self.assertIsInstance(instance.fruit_o, Orange) + self.assertIsInstance(instance.boring, BoringConfigurable) + self.assertIsInstance(instance.boring_o, BoringConfigurable) + + self.assertEqual(control_args, ["Orange", "Orange", True, True]) + @dataclass(eq=False) class MockDataclass: