mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-19 05:40:34 +08:00
Optional[Configurable] in config
Summary: A new type of auto-expanded member of a Configurable: something of type Optional[X] where X is a Configurable. This works like X but its construction is controlled by a boolean membername_enabled. Reviewed By: davnov134 Differential Revision: D35368269 fbshipit-source-id: 7e0c8a3e8c4930b0aa942fa1b325ce65336ebd5f
This commit is contained in:
committed by
Facebook GitHub Bot
parent
e10a90140d
commit
722646863c
@@ -106,8 +106,10 @@ class TestConfig(unittest.TestCase):
|
||||
gt(Optional[Fruit]), (Fruit, _ProcessType.OPTIONAL_REPLACEABLE)
|
||||
)
|
||||
self.assertEqual(gt(MainTest), (MainTest, _ProcessType.CONFIGURABLE))
|
||||
self.assertEqual(
|
||||
gt(Optional[MainTest]), (MainTest, _ProcessType.OPTIONAL_CONFIGURABLE)
|
||||
)
|
||||
self.assertIsNone(gt(Optional[int]))
|
||||
self.assertIsNone(gt(Optional[MainTest]))
|
||||
self.assertIsNone(gt(Tuple[Fruit]))
|
||||
self.assertIsNone(gt(Tuple[Fruit, Animal]))
|
||||
self.assertIsNone(gt(Optional[List[int]]))
|
||||
@@ -427,15 +429,24 @@ class TestConfig(unittest.TestCase):
|
||||
run_auto_creation(self)
|
||||
|
||||
class C(Configurable):
|
||||
b: B
|
||||
b1: B
|
||||
b2: Optional[B]
|
||||
b3: Optional[B]
|
||||
b2_enabled: bool = True
|
||||
|
||||
def __post_init__(self):
|
||||
run_auto_creation(self)
|
||||
|
||||
c_args = get_default_args(C)
|
||||
c = C(**c_args)
|
||||
self.assertIsInstance(c.b.a, A)
|
||||
self.assertEqual(c.b.a.n, 9)
|
||||
self.assertIsInstance(c.b1.a, A)
|
||||
self.assertEqual(c.b1.a.n, 9)
|
||||
self.assertFalse(hasattr(c, "b1_enabled"))
|
||||
self.assertIsInstance(c.b2.a, A)
|
||||
self.assertEqual(c.b2.a.n, 9)
|
||||
self.assertTrue(c.b2_enabled)
|
||||
self.assertIsNone(c.b3)
|
||||
self.assertFalse(c.b3_enabled)
|
||||
|
||||
def test_doc(self):
|
||||
# The case in the docstring.
|
||||
@@ -522,7 +533,7 @@ class TestConfig(unittest.TestCase):
|
||||
# Like torch.nn.Module, this class contains annotations
|
||||
# but is not designed to be dataclass'd.
|
||||
# This test ensures that such classes, when inherited fron,
|
||||
# are not accidentally expand_args_fields.
|
||||
# are not accidentally affected by expand_args_fields.
|
||||
a: int = 9
|
||||
b: int
|
||||
|
||||
@@ -654,6 +665,15 @@ class TestConfig(unittest.TestCase):
|
||||
self.assertEqual(sorted(instance_data.keys()), expected_keys)
|
||||
self.assertEqual(instance_data, expected)
|
||||
|
||||
def test_remove_unused_components_optional(self):
|
||||
class MainTestWrapper(Configurable):
|
||||
mt: Optional[MainTest]
|
||||
|
||||
args = get_default_args(MainTestWrapper)
|
||||
self.assertEqual(list(args.keys()), ["mt_args", "mt_enabled"])
|
||||
remove_unused_components(args)
|
||||
self.assertEqual(OmegaConf.to_yaml(args), "mt_enabled: false\n")
|
||||
|
||||
|
||||
@dataclass(eq=False)
|
||||
class MockDataclass:
|
||||
|
||||
Reference in New Issue
Block a user