hooks and allow registering base class

Summary: Allow a class to modify its subparts in get_default_args by defining the special function provide_config_hook.

Reviewed By: davnov134

Differential Revision: D36671081

fbshipit-source-id: 3e5b73880cb846c494a209c4479835f6352f45cf
This commit is contained in:
Jeremy Reizenstein
2022-06-10 12:22:46 -07:00
committed by Facebook GitHub Bot
parent 5cd70067e2
commit 8bc0a04e86
2 changed files with 68 additions and 9 deletions

View File

@@ -678,6 +678,36 @@ class TestConfig(unittest.TestCase):
remove_unused_components(args)
self.assertEqual(OmegaConf.to_yaml(args), "mt_enabled: false\n")
def test_tweak_hook(self):
class A(Configurable):
n: int = 9
class Wrapper(Configurable):
fruit: Fruit
fruit_class_type: str = "Pear"
fruit2: Fruit
fruit2_class_type: str = "Pear"
a: A
a2: A
@classmethod
def a_tweak_args(cls, type, args):
assert type == A
args.n = 993
@classmethod
def fruit_tweak_args(cls, type, args):
assert issubclass(type, Fruit)
if type == Pear:
assert args.n_pips == 13
args.n_pips = 19
args = get_default_args(Wrapper)
self.assertEqual(args.a_args.n, 993)
self.assertEqual(args.a2_args.n, 9)
self.assertEqual(args.fruit_Pear_args.n_pips, 19)
self.assertEqual(args.fruit2_Pear_args.n_pips, 13)
def test_impls(self):
# Check that create_x actually uses create_x_impl to do its work
# by using all the member types, both with a faked impl function