Optional[Configurable] in config

Summary: A new type of auto-expanded member of a Configurable: something of type Optional[X] where X is a Configurable. This works like X but its construction is controlled by a boolean membername_enabled.

Reviewed By: davnov134

Differential Revision: D35368269

fbshipit-source-id: 7e0c8a3e8c4930b0aa942fa1b325ce65336ebd5f
This commit is contained in:
Jeremy Reizenstein
2022-04-06 05:56:14 -07:00
committed by Facebook GitHub Bot
parent e10a90140d
commit 722646863c
2 changed files with 95 additions and 23 deletions

View File

@@ -106,8 +106,10 @@ class TestConfig(unittest.TestCase):
gt(Optional[Fruit]), (Fruit, _ProcessType.OPTIONAL_REPLACEABLE)
)
self.assertEqual(gt(MainTest), (MainTest, _ProcessType.CONFIGURABLE))
self.assertEqual(
gt(Optional[MainTest]), (MainTest, _ProcessType.OPTIONAL_CONFIGURABLE)
)
self.assertIsNone(gt(Optional[int]))
self.assertIsNone(gt(Optional[MainTest]))
self.assertIsNone(gt(Tuple[Fruit]))
self.assertIsNone(gt(Tuple[Fruit, Animal]))
self.assertIsNone(gt(Optional[List[int]]))
@@ -427,15 +429,24 @@ class TestConfig(unittest.TestCase):
run_auto_creation(self)
class C(Configurable):
b: B
b1: B
b2: Optional[B]
b3: Optional[B]
b2_enabled: bool = True
def __post_init__(self):
run_auto_creation(self)
c_args = get_default_args(C)
c = C(**c_args)
self.assertIsInstance(c.b.a, A)
self.assertEqual(c.b.a.n, 9)
self.assertIsInstance(c.b1.a, A)
self.assertEqual(c.b1.a.n, 9)
self.assertFalse(hasattr(c, "b1_enabled"))
self.assertIsInstance(c.b2.a, A)
self.assertEqual(c.b2.a.n, 9)
self.assertTrue(c.b2_enabled)
self.assertIsNone(c.b3)
self.assertFalse(c.b3_enabled)
def test_doc(self):
# The case in the docstring.
@@ -522,7 +533,7 @@ class TestConfig(unittest.TestCase):
# Like torch.nn.Module, this class contains annotations
# but is not designed to be dataclass'd.
# This test ensures that such classes, when inherited fron,
# are not accidentally expand_args_fields.
# are not accidentally affected by expand_args_fields.
a: int = 9
b: int
@@ -654,6 +665,15 @@ class TestConfig(unittest.TestCase):
self.assertEqual(sorted(instance_data.keys()), expected_keys)
self.assertEqual(instance_data, expected)
def test_remove_unused_components_optional(self):
class MainTestWrapper(Configurable):
mt: Optional[MainTest]
args = get_default_args(MainTestWrapper)
self.assertEqual(list(args.keys()), ["mt_args", "mt_enabled"])
remove_unused_components(args)
self.assertEqual(OmegaConf.to_yaml(args), "mt_enabled: false\n")
@dataclass(eq=False)
class MockDataclass: