create_x_impl

Summary: Make create_x delegate to create_x_impl so that users can rely on create_x_impl in their overrides of create_x.

Reviewed By: shapovalov, davnov134

Differential Revision: D35929810

fbshipit-source-id: 80595894ee93346b881729995775876b016fc08e
This commit is contained in:
Jeremy Reizenstein 2022-05-16 04:42:03 -07:00 committed by Facebook GitHub Bot
parent 3b2300641a
commit 899a3192b6
2 changed files with 145 additions and 23 deletions

View File

@ -175,6 +175,8 @@ _unprocessed_warning: str = (
TYPE_SUFFIX: str = "_class_type"
ARGS_SUFFIX: str = "_args"
ENABLED_SUFFIX: str = "_enabled"
CREATE_PREFIX: str = "create_"
IMPL_SUFFIX: str = "_impl"
class ReplaceableBase:
@ -375,25 +377,68 @@ def _default_create(
Returns:
Function taking one argument, the object whose member should be
initialized.
initialized, i.e. self.
"""
impl_name = f"{CREATE_PREFIX}{name}{IMPL_SUFFIX}"
def inner(self):
expand_args_fields(type_)
impl = getattr(self, impl_name)
args = getattr(self, name + ARGS_SUFFIX)
setattr(self, name, type_(**args))
impl(True, args)
def inner_optional(self):
expand_args_fields(type_)
impl = getattr(self, impl_name)
enabled = getattr(self, name + ENABLED_SUFFIX)
args = getattr(self, name + ARGS_SUFFIX)
impl(enabled, args)
def inner_pluggable(self):
type_name = getattr(self, name + TYPE_SUFFIX)
impl = getattr(self, impl_name)
if type_name is None:
args = None
else:
args = getattr(self, f"{name}_{type_name}{ARGS_SUFFIX}", None)
impl(type_name, args)
if process_type == _ProcessType.OPTIONAL_CONFIGURABLE:
return inner_optional
return inner if process_type == _ProcessType.CONFIGURABLE else inner_pluggable
def _default_create_impl(
name: str, type_: Type, process_type: _ProcessType
) -> Callable[[Any, Any, DictConfig], None]:
"""
Return the default internal function for initialising a member. This is a function
which could be called in the create_ function to initialise the member.
Args:
name: name of the member
type_: type of the member (with any Optional removed)
process_type: Shows whether member's declared type inherits ReplaceableBase,
in which case the actual type to be created is decided at
runtime.
Returns:
Function taking
- self, the object whose member should be initialized.
- option for what to do. This is
- for pluggables, the type to initialise or None to do nothing
- for non pluggables, a bool indicating whether to initialise.
- the args for initializing the member.
"""
def create_configurable(self, enabled, args):
if enabled:
args = getattr(self, name + ARGS_SUFFIX)
expand_args_fields(type_)
setattr(self, name, type_(**args))
else:
setattr(self, name, None)
def inner_pluggable(self):
type_name = getattr(self, name + TYPE_SUFFIX)
def create_pluggable(self, type_name, args):
if type_name is None:
setattr(self, name, None)
return
@ -408,12 +453,11 @@ def _default_create(
# were made in the redefinition will not be reflected here.
warnings.warn(f"New implementation of {type_name} is being chosen.")
expand_args_fields(chosen_class)
args = getattr(self, f"{name}_{type_name}{ARGS_SUFFIX}")
setattr(self, name, chosen_class(**args))
if process_type == _ProcessType.OPTIONAL_CONFIGURABLE:
return inner_optional
return inner if process_type == _ProcessType.CONFIGURABLE else inner_pluggable
if process_type in (_ProcessType.CONFIGURABLE, _ProcessType.OPTIONAL_CONFIGURABLE):
return create_configurable
return create_pluggable
def run_auto_creation(self: Any) -> None:
@ -628,11 +672,12 @@ def expand_args_fields(
x_Y_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(Y))
x_Z_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(Z))
def create_x(self):
x_type = registry.get(X, self.x_class_type)
args = self.getattr(f"x_{self.x_class_type}_args")
self.create_x_impl(self.x_class_type, args)
def create_x_impl(self, x_type, args):
x_type = registry.get(X, x_type)
expand_args_fields(x_type)
self.x = x_type(
**self.getattr(f"x_{self.x_class_type}_args)
)
self.x = x_type(**args)
x_class_type: str = "UNDEFAULTED"
without adding the optional attributes if they are already there.
@ -652,14 +697,19 @@ def expand_args_fields(
x_Z_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(Z))
def create_x(self):
if self.x_class_type is None:
args = None
else:
args = self.getattr(f"x_{self.x_class_type}_args", None)
self.create_x_impl(self.x_class_type, args)
def create_x_impl(self, x_class_type, args):
if x_class_type is None:
self.x = None
return
x_type = registry.get(X, self.x_class_type)
x_type = registry.get(X, x_class_type)
expand_args_fields(x_type)
self.x = x_type(
**self.getattr(f"x_{self.x_class_type}_args)
)
assert args is not None
self.x = x_type(**args)
x_class_type: Optional[str] = "UNDEFAULTED"
without adding the optional attributes if they are already there.
@ -676,8 +726,14 @@ def expand_args_fields(
x_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(X))
def create_x(self):
expand_args_fields(X)
self.x = X(self.x_args)
self.create_x_impl(True, self.x_args)
def create_x_impl(self, enabled, args):
if enabled:
expand_args_fields(X)
self.x = X(**args)
else:
self.x = None
Similarly, replace,
@ -693,9 +749,12 @@ def expand_args_fields(
x_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(X))
x_enabled: bool = False
def create_x(self):
if self.x_enabled:
self.create_x_impl(self.x_enabled, self.x_args)
def create_x_impl(self, enabled, args):
if enabled:
expand_args_fields(X)
self.x = X(self.x_args)
self.x = X(**args)
else:
self.x = None
@ -703,7 +762,7 @@ def expand_args_fields(
Also adds the following class members, unannotated so that dataclass
ignores them.
- _creation_functions: Tuple[str] of all the create_ functions,
including those from base classes.
including those from base classes (not the create_x_impl ones).
- _known_implementations: Dict[str, Type] containing the classes which
have been found from the registry.
(used only to raise a warning if it one has been overwritten)
@ -918,7 +977,7 @@ def _process_member(
some_class.__annotations__[enabled_name] = bool
setattr(some_class, enabled_name, False)
creation_function_name = f"create_{name}"
creation_function_name = f"{CREATE_PREFIX}{name}"
if not hasattr(some_class, creation_function_name):
setattr(
some_class,
@ -927,6 +986,14 @@ def _process_member(
)
creation_functions.append(creation_function_name)
creation_function_impl_name = f"{CREATE_PREFIX}{name}{IMPL_SUFFIX}"
if not hasattr(some_class, creation_function_impl_name):
setattr(
some_class,
creation_function_impl_name,
_default_create_impl(name, type_, process_type),
)
def remove_unused_components(dict_: DictConfig) -> None:
"""

View File

@ -69,6 +69,10 @@ class LargePear(Pear):
pass
class BoringConfigurable(Configurable):
pass
class MainTest(Configurable):
the_fruit: Fruit
n_ids: int
@ -674,6 +678,57 @@ class TestConfig(unittest.TestCase):
remove_unused_components(args)
self.assertEqual(OmegaConf.to_yaml(args), "mt_enabled: false\n")
def test_impls(self):
# Check that create_x actually uses create_x_impl to do its work
# by using all the member types, both with a faked impl function
# and without.
# members with _0 are optional and absent, those with _o are
# optional and present.
control_args = []
def fake_impl(self, control, args):
control_args.append(control)
for fake in [False, True]:
class MyClass(Configurable):
fruit: Fruit
fruit_class_type: str = "Orange"
fruit_o: Optional[Fruit]
fruit_o_class_type: str = "Orange"
fruit_0: Optional[Fruit]
fruit_0_class_type: Optional[str] = None
boring: BoringConfigurable
boring_o: Optional[BoringConfigurable]
boring_o_enabled: bool = True
boring_0: Optional[BoringConfigurable]
def __post_init__(self):
run_auto_creation(self)
if fake:
MyClass.create_fruit_impl = fake_impl
MyClass.create_fruit_o_impl = fake_impl
MyClass.create_boring_impl = fake_impl
MyClass.create_boring_o_impl = fake_impl
expand_args_fields(MyClass)
instance = MyClass()
for name in ["fruit", "fruit_o", "boring", "boring_o"]:
self.assertEqual(
hasattr(instance, name), not fake, msg=f"{name} {fake}"
)
self.assertIsNone(instance.fruit_0)
self.assertIsNone(instance.boring_0)
if not fake:
self.assertIsInstance(instance.fruit, Orange)
self.assertIsInstance(instance.fruit_o, Orange)
self.assertIsInstance(instance.boring, BoringConfigurable)
self.assertIsInstance(instance.boring_o, BoringConfigurable)
self.assertEqual(control_args, ["Orange", "Orange", True, True])
@dataclass(eq=False)
class MockDataclass: