hooks and allow registering base class

Summary: Allow a class to modify its subparts in get_default_args by defining the special function provide_config_hook.

Reviewed By: davnov134

Differential Revision: D36671081

fbshipit-source-id: 3e5b73880cb846c494a209c4479835f6352f45cf
This commit is contained in:
Jeremy Reizenstein 2022-06-10 12:22:46 -07:00 committed by Facebook GitHub Bot
parent 5cd70067e2
commit 8bc0a04e86
2 changed files with 68 additions and 9 deletions

View File

@ -11,6 +11,7 @@ import sys
import warnings
from collections import Counter, defaultdict
from enum import Enum
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union
from omegaconf import DictConfig, OmegaConf, open_dict
@ -177,6 +178,7 @@ ARGS_SUFFIX: str = "_args"
ENABLED_SUFFIX: str = "_enabled"
CREATE_PREFIX: str = "create_"
IMPL_SUFFIX: str = "_impl"
TWEAK_SUFFIX: str = "_tweak_args"
class ReplaceableBase:
@ -261,13 +263,9 @@ class _Registry:
raise ValueError(
f"Cannot register {some_class}. Cannot tell what it is."
)
if some_class is base_class:
raise ValueError(f"Attempted to register the base class {some_class}")
self._mapping[base_class][name] = some_class
def get(
self, base_class_wanted: Type[ReplaceableBase], name: str
) -> Type[ReplaceableBase]:
def get(self, base_class_wanted: Type[_X], name: str) -> Type[_X]:
"""
Retrieve a class from the registry by name
@ -295,6 +293,7 @@ class _Registry:
raise ValueError(
f"{name} resolves to {result} which does not subclass {base_class_wanted}"
)
# pyre-ignore[7]
return result
def get_all(
@ -773,6 +772,14 @@ def expand_args_fields(
transformed, with values giving the types they were declared to have.
(E.g. {"x": X} or {"x": Optional[X]} in the cases above.)
In addition, if the class has a member function
@classmethod
def x_tweak_args(cls, member_type: Type, args: DictConfig) -> None
then the default_factory of x_args will also have a call to x_tweak_args(X, x_args) and
the default_factory of x_Y_args will also have a call to x_tweak_args(Y, x_Y_args).
Args:
some_class: the class to be processed
_do_not_process: Internal use for get_default_args: Because get_default_args calls
@ -849,19 +856,29 @@ def expand_args_fields(
return some_class
def get_default_args_field(C, *, _do_not_process: Tuple[type, ...] = ()):
def get_default_args_field(
C,
*,
_do_not_process: Tuple[type, ...] = (),
_hook: Optional[Callable[[DictConfig], None]] = None,
):
"""
Get a dataclass field which defaults to get_default_args(...)
Args:
As for get_default_args.
C: As for get_default_args.
_do_not_process: As for get_default_args
_hook: Function called on the result before returning.
Returns:
function to return new DictConfig object
"""
def create():
return get_default_args(C, _do_not_process=_do_not_process)
args = get_default_args(C, _do_not_process=_do_not_process)
if _hook is not None:
_hook(args)
return args
return dataclasses.field(default_factory=create)
@ -924,6 +941,7 @@ def _process_member(
# sure they go at the end of __annotations__ in case
# there are non-defaulted standard class members.
del some_class.__annotations__[name]
hook = getattr(some_class, name + TWEAK_SUFFIX, None)
if process_type in (_ProcessType.REPLACEABLE, _ProcessType.OPTIONAL_REPLACEABLE):
type_name = name + TYPE_SUFFIX
@ -949,11 +967,17 @@ def _process_member(
f"Cannot generate {args_name} because it is already present."
)
some_class.__annotations__[args_name] = DictConfig
if hook is not None:
hook_closed = partial(hook, derived_type)
else:
hook_closed = None
setattr(
some_class,
args_name,
get_default_args_field(
derived_type, _do_not_process=_do_not_process + (some_class,)
derived_type,
_do_not_process=_do_not_process + (some_class,),
_hook=hook_closed,
),
)
else:
@ -966,12 +990,17 @@ def _process_member(
raise ValueError(f"Cannot process {type_} inside {some_class}")
some_class.__annotations__[args_name] = DictConfig
if hook is not None:
hook_closed = partial(hook, type_)
else:
hook_closed = None
setattr(
some_class,
args_name,
get_default_args_field(
type_,
_do_not_process=_do_not_process + (some_class,),
_hook=hook_closed,
),
)
if process_type == _ProcessType.OPTIONAL_CONFIGURABLE:

View File

@ -678,6 +678,36 @@ class TestConfig(unittest.TestCase):
remove_unused_components(args)
self.assertEqual(OmegaConf.to_yaml(args), "mt_enabled: false\n")
def test_tweak_hook(self):
class A(Configurable):
n: int = 9
class Wrapper(Configurable):
fruit: Fruit
fruit_class_type: str = "Pear"
fruit2: Fruit
fruit2_class_type: str = "Pear"
a: A
a2: A
@classmethod
def a_tweak_args(cls, type, args):
assert type == A
args.n = 993
@classmethod
def fruit_tweak_args(cls, type, args):
assert issubclass(type, Fruit)
if type == Pear:
assert args.n_pips == 13
args.n_pips = 19
args = get_default_args(Wrapper)
self.assertEqual(args.a_args.n, 993)
self.assertEqual(args.a2_args.n, 9)
self.assertEqual(args.fruit_Pear_args.n_pips, 19)
self.assertEqual(args.fruit2_Pear_args.n_pips, 13)
def test_impls(self):
# Check that create_x actually uses create_x_impl to do its work
# by using all the member types, both with a faked impl function