Optional[Configurable] in config

Summary: A new type of auto-expanded member of a Configurable: something of type Optional[X] where X is a Configurable. This works like X but its construction is controlled by a boolean membername_enabled.

Reviewed By: davnov134

Differential Revision: D35368269

fbshipit-source-id: 7e0c8a3e8c4930b0aa942fa1b325ce65336ebd5f
This commit is contained in:
Jeremy Reizenstein 2022-04-06 05:56:14 -07:00 committed by Facebook GitHub Bot
parent e10a90140d
commit 722646863c
2 changed files with 95 additions and 23 deletions

View File

@ -60,7 +60,7 @@ in dataclass style.
def __post_init__(self): def __post_init__(self):
run_auto_creation(self) run_auto_creation(self)
It can be used like Then it can be used like
b_args = get_default_args(B) b_args = get_default_args(B)
b = B(**b_args) b = B(**b_args)
@ -82,8 +82,8 @@ something like the following. (The modification itself is done by the function
self.a = A(**self.a_args) self.a = A(**self.a_args)
2. Pluggability. Instead of a dataclass-style member being given a concrete class, 2. Pluggability. Instead of a dataclass-style member being given a concrete class,
you can give a base class and the implementation is looked up by name in the global it can be given a base class and the implementation will be looked up by name in the
`registry` in this module. E.g. global `registry` in this module. E.g.
class A(ReplaceableBase): class A(ReplaceableBase):
k: int = 1 k: int = 1
@ -126,14 +126,14 @@ will expand to
default_factory=lambda: DictConfig({"k": 1, "m": 3} default_factory=lambda: DictConfig({"k": 1, "m": 3}
) )
a_A2_args: DictConfig = dataclasses.field( a_A2_args: DictConfig = dataclasses.field(
default_factory=lambda: DictConfig({"k": 1, "m": 3} default_factory=lambda: DictConfig({"k": 1, "n": 2}
) )
b_class_type: Optional[str] = "A2" b_class_type: Optional[str] = "A2"
b_A1_args: DictConfig = dataclasses.field( b_A1_args: DictConfig = dataclasses.field(
default_factory=lambda: DictConfig({"k": 1, "m": 3} default_factory=lambda: DictConfig({"k": 1, "m": 3}
) )
b_A2_args: DictConfig = dataclasses.field( b_A2_args: DictConfig = dataclasses.field(
default_factory=lambda: DictConfig({"k": 1, "m": 3} default_factory=lambda: DictConfig({"k": 1, "n": 2}
) )
def __post_init__(self): def __post_init__(self):
@ -155,12 +155,14 @@ will expand to
3. Aside from these classes, the members of these classes should be things 3. Aside from these classes, the members of these classes should be things
which DictConfig is happy with: e.g. (bool, int, str, None, float) and what which DictConfig is happy with: e.g. (bool, int, str, None, float) and what
can be built from them with DictConfigs and lists of them. can be built from them with `DictConfig`s and lists of them.
In addition, you can call get_default_args on a function or class to get In addition, you can call `get_default_args` on a function or class to get
the DictConfig of its defaulted arguments, assuming those are all things the `DictConfig` of its defaulted arguments, assuming those are all things
which DictConfig is happy with. If you want to use such a thing as a member which `DictConfig` is happy with, so long as you add a call to
of another configured class, `get_default_args_field` is a helper. `enable_get_default_args` after its definition. If you want to use such a
thing as the default for a member of another configured class,
`get_default_args_field` is a helper.
""" """
@ -172,6 +174,7 @@ _unprocessed_warning: str = (
TYPE_SUFFIX: str = "_class_type" TYPE_SUFFIX: str = "_class_type"
ARGS_SUFFIX: str = "_args" ARGS_SUFFIX: str = "_args"
ENABLED_SUFFIX: str = "_enabled"
class ReplaceableBase: class ReplaceableBase:
@ -351,7 +354,8 @@ class _ProcessType(Enum):
CONFIGURABLE = 1 CONFIGURABLE = 1
REPLACEABLE = 2 REPLACEABLE = 2
OPTIONAL_REPLACEABLE = 3 OPTIONAL_CONFIGURABLE = 3
OPTIONAL_REPLACEABLE = 4
def _default_create( def _default_create(
@ -379,6 +383,15 @@ def _default_create(
args = getattr(self, name + ARGS_SUFFIX) args = getattr(self, name + ARGS_SUFFIX)
setattr(self, name, type_(**args)) setattr(self, name, type_(**args))
def inner_optional(self):
expand_args_fields(type_)
enabled = getattr(self, name + ENABLED_SUFFIX)
if enabled:
args = getattr(self, name + ARGS_SUFFIX)
setattr(self, name, type_(**args))
else:
setattr(self, name, None)
def inner_pluggable(self): def inner_pluggable(self):
type_name = getattr(self, name + TYPE_SUFFIX) type_name = getattr(self, name + TYPE_SUFFIX)
if type_name is None: if type_name is None:
@ -398,6 +411,8 @@ def _default_create(
args = getattr(self, f"{name}_{type_name}{ARGS_SUFFIX}") args = getattr(self, f"{name}_{type_name}{ARGS_SUFFIX}")
setattr(self, name, chosen_class(**args)) setattr(self, name, chosen_class(**args))
if process_type == _ProcessType.OPTIONAL_CONFIGURABLE:
return inner_optional
return inner if process_type == _ProcessType.CONFIGURABLE else inner_pluggable return inner if process_type == _ProcessType.CONFIGURABLE else inner_pluggable
@ -608,8 +623,8 @@ def expand_args_fields(
with with
x_Y_args : DictConfig = dataclasses.field(default_factory=lambda: DictConfig()) x_Y_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(Y))
x_Z_args : DictConfig = dataclasses.field(default_factory=lambda: DictConfig()) x_Z_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(Z))
def create_x(self): def create_x(self):
self.x = registry.get(X, self.x_class_type)( self.x = registry.get(X, self.x_class_type)(
**self.getattr(f"x_{self.x_class_type}_args) **self.getattr(f"x_{self.x_class_type}_args)
@ -629,8 +644,8 @@ def expand_args_fields(
with with
x_Y_args : DictConfig = dataclasses.field(default_factory=lambda: DictConfig()) x_Y_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(Y))
x_Z_args : DictConfig = dataclasses.field(default_factory=lambda: DictConfig()) x_Z_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(Z))
def create_x(self): def create_x(self):
if self.x_class_type is None: if self.x_class_type is None:
self.x = None self.x = None
@ -653,10 +668,30 @@ def expand_args_fields(
will be replaced with will be replaced with
x_args : DictConfig = dataclasses.field(default_factory=lambda: DictConfig()) x_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(X))
def create_x(self): def create_x(self):
self.x = X(self.x_args) self.x = X(self.x_args)
Similarly, replace,
x: Optional[X]
and optionally
def create_x(self):...
x_enabled: bool = ...
with
x_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(X))
x_enabled: bool = False
def create_x(self):
if self.x_enabled:
self.x = X(self.x_args)
else:
self.x = None
Also adds the following class members, unannotated so that dataclass Also adds the following class members, unannotated so that dataclass
ignores them. ignores them.
- _creation_functions: Tuple[str] of all the create_ functions, - _creation_functions: Tuple[str] of all the create_ functions,
@ -779,6 +814,9 @@ def _get_type_to_process(type_) -> Optional[Tuple[Type, _ProcessType]]:
): ):
return underlying, _ProcessType.OPTIONAL_REPLACEABLE return underlying, _ProcessType.OPTIONAL_REPLACEABLE
if isinstance(underlying, type) and issubclass(underlying, Configurable):
return underlying, _ProcessType.OPTIONAL_CONFIGURABLE
if not isinstance(type_, type): if not isinstance(type_, type):
# e.g. any other Union or Tuple # e.g. any other Union or Tuple
return return
@ -817,7 +855,7 @@ def _process_member(
# there are non-defaulted standard class members. # there are non-defaulted standard class members.
del some_class.__annotations__[name] del some_class.__annotations__[name]
if process_type != _ProcessType.CONFIGURABLE: if process_type in (_ProcessType.REPLACEABLE, _ProcessType.OPTIONAL_REPLACEABLE):
type_name = name + TYPE_SUFFIX type_name = name + TYPE_SUFFIX
if type_name not in some_class.__annotations__: if type_name not in some_class.__annotations__:
if process_type == _ProcessType.OPTIONAL_REPLACEABLE: if process_type == _ProcessType.OPTIONAL_REPLACEABLE:
@ -866,6 +904,11 @@ def _process_member(
_do_not_process=_do_not_process + (some_class,), _do_not_process=_do_not_process + (some_class,),
), ),
) )
if process_type == _ProcessType.OPTIONAL_CONFIGURABLE:
enabled_name = name + ENABLED_SUFFIX
if enabled_name not in some_class.__annotations__:
some_class.__annotations__[enabled_name] = bool
setattr(some_class, enabled_name, False)
creation_function_name = f"create_{name}" creation_function_name = f"create_{name}"
if not hasattr(some_class, creation_function_name): if not hasattr(some_class, creation_function_name):
@ -884,7 +927,8 @@ def remove_unused_components(dict_: DictConfig) -> None:
pluggable parts which are not in use. pluggable parts which are not in use.
For example, if renderer_class_type is SignedDistanceFunctionRenderer, For example, if renderer_class_type is SignedDistanceFunctionRenderer,
the renderer_MultiPassEmissionAbsorptionRenderer_args will be the renderer_MultiPassEmissionAbsorptionRenderer_args will be
removed. removed. Also, if chocolate_enabled is False, then chocolate_args will
be removed.
Args: Args:
dict_: (MODIFIED IN PLACE) a DictConfig instance dict_: (MODIFIED IN PLACE) a DictConfig instance
@ -904,6 +948,14 @@ def remove_unused_components(dict_: DictConfig) -> None:
if key.startswith(replaceable + "_") and key != expect: if key.startswith(replaceable + "_") and key != expect:
del dict_[key] del dict_[key]
suffix_length = len(ENABLED_SUFFIX)
enableables = [key[:-suffix_length] for key in keys if key.endswith(ENABLED_SUFFIX)]
for enableable in enableables:
enabled = dict_[enableable + ENABLED_SUFFIX]
if not enabled:
with open_dict(dict_):
dict_.pop(enableable + ARGS_SUFFIX, None)
for key in dict_: for key in dict_:
if isinstance(dict_.get(key), DictConfig): if isinstance(dict_.get(key), DictConfig):
remove_unused_components(dict_[key]) remove_unused_components(dict_[key])

View File

@ -106,8 +106,10 @@ class TestConfig(unittest.TestCase):
gt(Optional[Fruit]), (Fruit, _ProcessType.OPTIONAL_REPLACEABLE) gt(Optional[Fruit]), (Fruit, _ProcessType.OPTIONAL_REPLACEABLE)
) )
self.assertEqual(gt(MainTest), (MainTest, _ProcessType.CONFIGURABLE)) self.assertEqual(gt(MainTest), (MainTest, _ProcessType.CONFIGURABLE))
self.assertEqual(
gt(Optional[MainTest]), (MainTest, _ProcessType.OPTIONAL_CONFIGURABLE)
)
self.assertIsNone(gt(Optional[int])) self.assertIsNone(gt(Optional[int]))
self.assertIsNone(gt(Optional[MainTest]))
self.assertIsNone(gt(Tuple[Fruit])) self.assertIsNone(gt(Tuple[Fruit]))
self.assertIsNone(gt(Tuple[Fruit, Animal])) self.assertIsNone(gt(Tuple[Fruit, Animal]))
self.assertIsNone(gt(Optional[List[int]])) self.assertIsNone(gt(Optional[List[int]]))
@ -427,15 +429,24 @@ class TestConfig(unittest.TestCase):
run_auto_creation(self) run_auto_creation(self)
class C(Configurable): class C(Configurable):
b: B b1: B
b2: Optional[B]
b3: Optional[B]
b2_enabled: bool = True
def __post_init__(self): def __post_init__(self):
run_auto_creation(self) run_auto_creation(self)
c_args = get_default_args(C) c_args = get_default_args(C)
c = C(**c_args) c = C(**c_args)
self.assertIsInstance(c.b.a, A) self.assertIsInstance(c.b1.a, A)
self.assertEqual(c.b.a.n, 9) self.assertEqual(c.b1.a.n, 9)
self.assertFalse(hasattr(c, "b1_enabled"))
self.assertIsInstance(c.b2.a, A)
self.assertEqual(c.b2.a.n, 9)
self.assertTrue(c.b2_enabled)
self.assertIsNone(c.b3)
self.assertFalse(c.b3_enabled)
def test_doc(self): def test_doc(self):
# The case in the docstring. # The case in the docstring.
@ -522,7 +533,7 @@ class TestConfig(unittest.TestCase):
# Like torch.nn.Module, this class contains annotations # Like torch.nn.Module, this class contains annotations
# but is not designed to be dataclass'd. # but is not designed to be dataclass'd.
# This test ensures that such classes, when inherited fron, # This test ensures that such classes, when inherited fron,
# are not accidentally expand_args_fields. # are not accidentally affected by expand_args_fields.
a: int = 9 a: int = 9
b: int b: int
@ -654,6 +665,15 @@ class TestConfig(unittest.TestCase):
self.assertEqual(sorted(instance_data.keys()), expected_keys) self.assertEqual(sorted(instance_data.keys()), expected_keys)
self.assertEqual(instance_data, expected) self.assertEqual(instance_data, expected)
def test_remove_unused_components_optional(self):
class MainTestWrapper(Configurable):
mt: Optional[MainTest]
args = get_default_args(MainTestWrapper)
self.assertEqual(list(args.keys()), ["mt_args", "mt_enabled"])
remove_unused_components(args)
self.assertEqual(OmegaConf.to_yaml(args), "mt_enabled: false\n")
@dataclass(eq=False) @dataclass(eq=False)
class MockDataclass: class MockDataclass: