From 722646863c660dcd1bd8f492eab9116bb7a141c1 Mon Sep 17 00:00:00 2001 From: Jeremy Reizenstein Date: Wed, 6 Apr 2022 05:56:14 -0700 Subject: [PATCH] Optional[Configurable] in config Summary: A new type of auto-expanded member of a Configurable: something of type Optional[X] where X is a Configurable. This works like X but its construction is controlled by a boolean membername_enabled. Reviewed By: davnov134 Differential Revision: D35368269 fbshipit-source-id: 7e0c8a3e8c4930b0aa942fa1b325ce65336ebd5f --- pytorch3d/implicitron/tools/config.py | 88 +++++++++++++++++++++------ tests/implicitron/test_config.py | 30 +++++++-- 2 files changed, 95 insertions(+), 23 deletions(-) diff --git a/pytorch3d/implicitron/tools/config.py b/pytorch3d/implicitron/tools/config.py index 476c4711..f104007f 100644 --- a/pytorch3d/implicitron/tools/config.py +++ b/pytorch3d/implicitron/tools/config.py @@ -60,7 +60,7 @@ in dataclass style. def __post_init__(self): run_auto_creation(self) -It can be used like +Then it can be used like b_args = get_default_args(B) b = B(**b_args) @@ -82,8 +82,8 @@ something like the following. (The modification itself is done by the function self.a = A(**self.a_args) 2. Pluggability. Instead of a dataclass-style member being given a concrete class, -you can give a base class and the implementation is looked up by name in the global -`registry` in this module. E.g. +it can be given a base class and the implementation will be looked up by name in the +global `registry` in this module. E.g. class A(ReplaceableBase): k: int = 1 @@ -126,14 +126,14 @@ will expand to default_factory=lambda: DictConfig({"k": 1, "m": 3} ) a_A2_args: DictConfig = dataclasses.field( - default_factory=lambda: DictConfig({"k": 1, "m": 3} + default_factory=lambda: DictConfig({"k": 1, "n": 2} ) b_class_type: Optional[str] = "A2" b_A1_args: DictConfig = dataclasses.field( default_factory=lambda: DictConfig({"k": 1, "m": 3} ) b_A2_args: DictConfig = dataclasses.field( - default_factory=lambda: DictConfig({"k": 1, "m": 3} + default_factory=lambda: DictConfig({"k": 1, "n": 2} ) def __post_init__(self): @@ -155,12 +155,14 @@ will expand to 3. Aside from these classes, the members of these classes should be things which DictConfig is happy with: e.g. (bool, int, str, None, float) and what -can be built from them with DictConfigs and lists of them. +can be built from them with `DictConfig`s and lists of them. -In addition, you can call get_default_args on a function or class to get -the DictConfig of its defaulted arguments, assuming those are all things -which DictConfig is happy with. If you want to use such a thing as a member -of another configured class, `get_default_args_field` is a helper. +In addition, you can call `get_default_args` on a function or class to get +the `DictConfig` of its defaulted arguments, assuming those are all things +which `DictConfig` is happy with, so long as you add a call to +`enable_get_default_args` after its definition. If you want to use such a +thing as the default for a member of another configured class, +`get_default_args_field` is a helper. """ @@ -172,6 +174,7 @@ _unprocessed_warning: str = ( TYPE_SUFFIX: str = "_class_type" ARGS_SUFFIX: str = "_args" +ENABLED_SUFFIX: str = "_enabled" class ReplaceableBase: @@ -351,7 +354,8 @@ class _ProcessType(Enum): CONFIGURABLE = 1 REPLACEABLE = 2 - OPTIONAL_REPLACEABLE = 3 + OPTIONAL_CONFIGURABLE = 3 + OPTIONAL_REPLACEABLE = 4 def _default_create( @@ -379,6 +383,15 @@ def _default_create( args = getattr(self, name + ARGS_SUFFIX) setattr(self, name, type_(**args)) + def inner_optional(self): + expand_args_fields(type_) + enabled = getattr(self, name + ENABLED_SUFFIX) + if enabled: + args = getattr(self, name + ARGS_SUFFIX) + setattr(self, name, type_(**args)) + else: + setattr(self, name, None) + def inner_pluggable(self): type_name = getattr(self, name + TYPE_SUFFIX) if type_name is None: @@ -398,6 +411,8 @@ def _default_create( args = getattr(self, f"{name}_{type_name}{ARGS_SUFFIX}") setattr(self, name, chosen_class(**args)) + if process_type == _ProcessType.OPTIONAL_CONFIGURABLE: + return inner_optional return inner if process_type == _ProcessType.CONFIGURABLE else inner_pluggable @@ -608,8 +623,8 @@ def expand_args_fields( with - x_Y_args : DictConfig = dataclasses.field(default_factory=lambda: DictConfig()) - x_Z_args : DictConfig = dataclasses.field(default_factory=lambda: DictConfig()) + x_Y_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(Y)) + x_Z_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(Z)) def create_x(self): self.x = registry.get(X, self.x_class_type)( **self.getattr(f"x_{self.x_class_type}_args) @@ -629,8 +644,8 @@ def expand_args_fields( with - x_Y_args : DictConfig = dataclasses.field(default_factory=lambda: DictConfig()) - x_Z_args : DictConfig = dataclasses.field(default_factory=lambda: DictConfig()) + x_Y_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(Y)) + x_Z_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(Z)) def create_x(self): if self.x_class_type is None: self.x = None @@ -653,10 +668,30 @@ def expand_args_fields( will be replaced with - x_args : DictConfig = dataclasses.field(default_factory=lambda: DictConfig()) + x_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(X)) def create_x(self): self.x = X(self.x_args) + Similarly, replace, + + x: Optional[X] + + and optionally + + def create_x(self):... + x_enabled: bool = ... + + with + + x_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(X)) + x_enabled: bool = False + def create_x(self): + if self.x_enabled: + self.x = X(self.x_args) + else: + self.x = None + + Also adds the following class members, unannotated so that dataclass ignores them. - _creation_functions: Tuple[str] of all the create_ functions, @@ -779,6 +814,9 @@ def _get_type_to_process(type_) -> Optional[Tuple[Type, _ProcessType]]: ): return underlying, _ProcessType.OPTIONAL_REPLACEABLE + if isinstance(underlying, type) and issubclass(underlying, Configurable): + return underlying, _ProcessType.OPTIONAL_CONFIGURABLE + if not isinstance(type_, type): # e.g. any other Union or Tuple return @@ -817,7 +855,7 @@ def _process_member( # there are non-defaulted standard class members. del some_class.__annotations__[name] - if process_type != _ProcessType.CONFIGURABLE: + if process_type in (_ProcessType.REPLACEABLE, _ProcessType.OPTIONAL_REPLACEABLE): type_name = name + TYPE_SUFFIX if type_name not in some_class.__annotations__: if process_type == _ProcessType.OPTIONAL_REPLACEABLE: @@ -866,6 +904,11 @@ def _process_member( _do_not_process=_do_not_process + (some_class,), ), ) + if process_type == _ProcessType.OPTIONAL_CONFIGURABLE: + enabled_name = name + ENABLED_SUFFIX + if enabled_name not in some_class.__annotations__: + some_class.__annotations__[enabled_name] = bool + setattr(some_class, enabled_name, False) creation_function_name = f"create_{name}" if not hasattr(some_class, creation_function_name): @@ -884,7 +927,8 @@ def remove_unused_components(dict_: DictConfig) -> None: pluggable parts which are not in use. For example, if renderer_class_type is SignedDistanceFunctionRenderer, the renderer_MultiPassEmissionAbsorptionRenderer_args will be - removed. + removed. Also, if chocolate_enabled is False, then chocolate_args will + be removed. Args: dict_: (MODIFIED IN PLACE) a DictConfig instance @@ -904,6 +948,14 @@ def remove_unused_components(dict_: DictConfig) -> None: if key.startswith(replaceable + "_") and key != expect: del dict_[key] + suffix_length = len(ENABLED_SUFFIX) + enableables = [key[:-suffix_length] for key in keys if key.endswith(ENABLED_SUFFIX)] + for enableable in enableables: + enabled = dict_[enableable + ENABLED_SUFFIX] + if not enabled: + with open_dict(dict_): + dict_.pop(enableable + ARGS_SUFFIX, None) + for key in dict_: if isinstance(dict_.get(key), DictConfig): remove_unused_components(dict_[key]) diff --git a/tests/implicitron/test_config.py b/tests/implicitron/test_config.py index e53fe0bd..2be4595f 100644 --- a/tests/implicitron/test_config.py +++ b/tests/implicitron/test_config.py @@ -106,8 +106,10 @@ class TestConfig(unittest.TestCase): gt(Optional[Fruit]), (Fruit, _ProcessType.OPTIONAL_REPLACEABLE) ) self.assertEqual(gt(MainTest), (MainTest, _ProcessType.CONFIGURABLE)) + self.assertEqual( + gt(Optional[MainTest]), (MainTest, _ProcessType.OPTIONAL_CONFIGURABLE) + ) self.assertIsNone(gt(Optional[int])) - self.assertIsNone(gt(Optional[MainTest])) self.assertIsNone(gt(Tuple[Fruit])) self.assertIsNone(gt(Tuple[Fruit, Animal])) self.assertIsNone(gt(Optional[List[int]])) @@ -427,15 +429,24 @@ class TestConfig(unittest.TestCase): run_auto_creation(self) class C(Configurable): - b: B + b1: B + b2: Optional[B] + b3: Optional[B] + b2_enabled: bool = True def __post_init__(self): run_auto_creation(self) c_args = get_default_args(C) c = C(**c_args) - self.assertIsInstance(c.b.a, A) - self.assertEqual(c.b.a.n, 9) + self.assertIsInstance(c.b1.a, A) + self.assertEqual(c.b1.a.n, 9) + self.assertFalse(hasattr(c, "b1_enabled")) + self.assertIsInstance(c.b2.a, A) + self.assertEqual(c.b2.a.n, 9) + self.assertTrue(c.b2_enabled) + self.assertIsNone(c.b3) + self.assertFalse(c.b3_enabled) def test_doc(self): # The case in the docstring. @@ -522,7 +533,7 @@ class TestConfig(unittest.TestCase): # Like torch.nn.Module, this class contains annotations # but is not designed to be dataclass'd. # This test ensures that such classes, when inherited fron, - # are not accidentally expand_args_fields. + # are not accidentally affected by expand_args_fields. a: int = 9 b: int @@ -654,6 +665,15 @@ class TestConfig(unittest.TestCase): self.assertEqual(sorted(instance_data.keys()), expected_keys) self.assertEqual(instance_data, expected) + def test_remove_unused_components_optional(self): + class MainTestWrapper(Configurable): + mt: Optional[MainTest] + + args = get_default_args(MainTestWrapper) + self.assertEqual(list(args.keys()), ["mt_args", "mt_enabled"]) + remove_unused_components(args) + self.assertEqual(OmegaConf.to_yaml(args), "mt_enabled: false\n") + @dataclass(eq=False) class MockDataclass: