Optional[Configurable] in config

Summary: A new type of auto-expanded member of a Configurable: something of type Optional[X] where X is a Configurable. This works like X but its construction is controlled by a boolean membername_enabled.

Reviewed By: davnov134

Differential Revision: D35368269

fbshipit-source-id: 7e0c8a3e8c4930b0aa942fa1b325ce65336ebd5f
This commit is contained in:
Jeremy Reizenstein 2022-04-06 05:56:14 -07:00 committed by Facebook GitHub Bot
parent e10a90140d
commit 722646863c
2 changed files with 95 additions and 23 deletions

View File

@ -60,7 +60,7 @@ in dataclass style.
def __post_init__(self):
run_auto_creation(self)
It can be used like
Then it can be used like
b_args = get_default_args(B)
b = B(**b_args)
@ -82,8 +82,8 @@ something like the following. (The modification itself is done by the function
self.a = A(**self.a_args)
2. Pluggability. Instead of a dataclass-style member being given a concrete class,
you can give a base class and the implementation is looked up by name in the global
`registry` in this module. E.g.
it can be given a base class and the implementation will be looked up by name in the
global `registry` in this module. E.g.
class A(ReplaceableBase):
k: int = 1
@ -126,14 +126,14 @@ will expand to
default_factory=lambda: DictConfig({"k": 1, "m": 3}
)
a_A2_args: DictConfig = dataclasses.field(
default_factory=lambda: DictConfig({"k": 1, "m": 3}
default_factory=lambda: DictConfig({"k": 1, "n": 2}
)
b_class_type: Optional[str] = "A2"
b_A1_args: DictConfig = dataclasses.field(
default_factory=lambda: DictConfig({"k": 1, "m": 3}
)
b_A2_args: DictConfig = dataclasses.field(
default_factory=lambda: DictConfig({"k": 1, "m": 3}
default_factory=lambda: DictConfig({"k": 1, "n": 2}
)
def __post_init__(self):
@ -155,12 +155,14 @@ will expand to
3. Aside from these classes, the members of these classes should be things
which DictConfig is happy with: e.g. (bool, int, str, None, float) and what
can be built from them with DictConfigs and lists of them.
can be built from them with `DictConfig`s and lists of them.
In addition, you can call get_default_args on a function or class to get
the DictConfig of its defaulted arguments, assuming those are all things
which DictConfig is happy with. If you want to use such a thing as a member
of another configured class, `get_default_args_field` is a helper.
In addition, you can call `get_default_args` on a function or class to get
the `DictConfig` of its defaulted arguments, assuming those are all things
which `DictConfig` is happy with, so long as you add a call to
`enable_get_default_args` after its definition. If you want to use such a
thing as the default for a member of another configured class,
`get_default_args_field` is a helper.
"""
@ -172,6 +174,7 @@ _unprocessed_warning: str = (
TYPE_SUFFIX: str = "_class_type"
ARGS_SUFFIX: str = "_args"
ENABLED_SUFFIX: str = "_enabled"
class ReplaceableBase:
@ -351,7 +354,8 @@ class _ProcessType(Enum):
CONFIGURABLE = 1
REPLACEABLE = 2
OPTIONAL_REPLACEABLE = 3
OPTIONAL_CONFIGURABLE = 3
OPTIONAL_REPLACEABLE = 4
def _default_create(
@ -379,6 +383,15 @@ def _default_create(
args = getattr(self, name + ARGS_SUFFIX)
setattr(self, name, type_(**args))
def inner_optional(self):
expand_args_fields(type_)
enabled = getattr(self, name + ENABLED_SUFFIX)
if enabled:
args = getattr(self, name + ARGS_SUFFIX)
setattr(self, name, type_(**args))
else:
setattr(self, name, None)
def inner_pluggable(self):
type_name = getattr(self, name + TYPE_SUFFIX)
if type_name is None:
@ -398,6 +411,8 @@ def _default_create(
args = getattr(self, f"{name}_{type_name}{ARGS_SUFFIX}")
setattr(self, name, chosen_class(**args))
if process_type == _ProcessType.OPTIONAL_CONFIGURABLE:
return inner_optional
return inner if process_type == _ProcessType.CONFIGURABLE else inner_pluggable
@ -608,8 +623,8 @@ def expand_args_fields(
with
x_Y_args : DictConfig = dataclasses.field(default_factory=lambda: DictConfig())
x_Z_args : DictConfig = dataclasses.field(default_factory=lambda: DictConfig())
x_Y_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(Y))
x_Z_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(Z))
def create_x(self):
self.x = registry.get(X, self.x_class_type)(
**self.getattr(f"x_{self.x_class_type}_args)
@ -629,8 +644,8 @@ def expand_args_fields(
with
x_Y_args : DictConfig = dataclasses.field(default_factory=lambda: DictConfig())
x_Z_args : DictConfig = dataclasses.field(default_factory=lambda: DictConfig())
x_Y_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(Y))
x_Z_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(Z))
def create_x(self):
if self.x_class_type is None:
self.x = None
@ -653,10 +668,30 @@ def expand_args_fields(
will be replaced with
x_args : DictConfig = dataclasses.field(default_factory=lambda: DictConfig())
x_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(X))
def create_x(self):
self.x = X(self.x_args)
Similarly, replace,
x: Optional[X]
and optionally
def create_x(self):...
x_enabled: bool = ...
with
x_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(X))
x_enabled: bool = False
def create_x(self):
if self.x_enabled:
self.x = X(self.x_args)
else:
self.x = None
Also adds the following class members, unannotated so that dataclass
ignores them.
- _creation_functions: Tuple[str] of all the create_ functions,
@ -779,6 +814,9 @@ def _get_type_to_process(type_) -> Optional[Tuple[Type, _ProcessType]]:
):
return underlying, _ProcessType.OPTIONAL_REPLACEABLE
if isinstance(underlying, type) and issubclass(underlying, Configurable):
return underlying, _ProcessType.OPTIONAL_CONFIGURABLE
if not isinstance(type_, type):
# e.g. any other Union or Tuple
return
@ -817,7 +855,7 @@ def _process_member(
# there are non-defaulted standard class members.
del some_class.__annotations__[name]
if process_type != _ProcessType.CONFIGURABLE:
if process_type in (_ProcessType.REPLACEABLE, _ProcessType.OPTIONAL_REPLACEABLE):
type_name = name + TYPE_SUFFIX
if type_name not in some_class.__annotations__:
if process_type == _ProcessType.OPTIONAL_REPLACEABLE:
@ -866,6 +904,11 @@ def _process_member(
_do_not_process=_do_not_process + (some_class,),
),
)
if process_type == _ProcessType.OPTIONAL_CONFIGURABLE:
enabled_name = name + ENABLED_SUFFIX
if enabled_name not in some_class.__annotations__:
some_class.__annotations__[enabled_name] = bool
setattr(some_class, enabled_name, False)
creation_function_name = f"create_{name}"
if not hasattr(some_class, creation_function_name):
@ -884,7 +927,8 @@ def remove_unused_components(dict_: DictConfig) -> None:
pluggable parts which are not in use.
For example, if renderer_class_type is SignedDistanceFunctionRenderer,
the renderer_MultiPassEmissionAbsorptionRenderer_args will be
removed.
removed. Also, if chocolate_enabled is False, then chocolate_args will
be removed.
Args:
dict_: (MODIFIED IN PLACE) a DictConfig instance
@ -904,6 +948,14 @@ def remove_unused_components(dict_: DictConfig) -> None:
if key.startswith(replaceable + "_") and key != expect:
del dict_[key]
suffix_length = len(ENABLED_SUFFIX)
enableables = [key[:-suffix_length] for key in keys if key.endswith(ENABLED_SUFFIX)]
for enableable in enableables:
enabled = dict_[enableable + ENABLED_SUFFIX]
if not enabled:
with open_dict(dict_):
dict_.pop(enableable + ARGS_SUFFIX, None)
for key in dict_:
if isinstance(dict_.get(key), DictConfig):
remove_unused_components(dict_[key])

View File

@ -106,8 +106,10 @@ class TestConfig(unittest.TestCase):
gt(Optional[Fruit]), (Fruit, _ProcessType.OPTIONAL_REPLACEABLE)
)
self.assertEqual(gt(MainTest), (MainTest, _ProcessType.CONFIGURABLE))
self.assertEqual(
gt(Optional[MainTest]), (MainTest, _ProcessType.OPTIONAL_CONFIGURABLE)
)
self.assertIsNone(gt(Optional[int]))
self.assertIsNone(gt(Optional[MainTest]))
self.assertIsNone(gt(Tuple[Fruit]))
self.assertIsNone(gt(Tuple[Fruit, Animal]))
self.assertIsNone(gt(Optional[List[int]]))
@ -427,15 +429,24 @@ class TestConfig(unittest.TestCase):
run_auto_creation(self)
class C(Configurable):
b: B
b1: B
b2: Optional[B]
b3: Optional[B]
b2_enabled: bool = True
def __post_init__(self):
run_auto_creation(self)
c_args = get_default_args(C)
c = C(**c_args)
self.assertIsInstance(c.b.a, A)
self.assertEqual(c.b.a.n, 9)
self.assertIsInstance(c.b1.a, A)
self.assertEqual(c.b1.a.n, 9)
self.assertFalse(hasattr(c, "b1_enabled"))
self.assertIsInstance(c.b2.a, A)
self.assertEqual(c.b2.a.n, 9)
self.assertTrue(c.b2_enabled)
self.assertIsNone(c.b3)
self.assertFalse(c.b3_enabled)
def test_doc(self):
# The case in the docstring.
@ -522,7 +533,7 @@ class TestConfig(unittest.TestCase):
# Like torch.nn.Module, this class contains annotations
# but is not designed to be dataclass'd.
# This test ensures that such classes, when inherited fron,
# are not accidentally expand_args_fields.
# are not accidentally affected by expand_args_fields.
a: int = 9
b: int
@ -654,6 +665,15 @@ class TestConfig(unittest.TestCase):
self.assertEqual(sorted(instance_data.keys()), expected_keys)
self.assertEqual(instance_data, expected)
def test_remove_unused_components_optional(self):
class MainTestWrapper(Configurable):
mt: Optional[MainTest]
args = get_default_args(MainTestWrapper)
self.assertEqual(list(args.keys()), ["mt_args", "mt_enabled"])
remove_unused_components(args)
self.assertEqual(OmegaConf.to_yaml(args), "mt_enabled: false\n")
@dataclass(eq=False)
class MockDataclass: