fix to get_default_args(instance)

Summary:
Small config system fix. Allows get_default_args to work on an instance which has been created with a dict (instead of a DictConfig) as an args field. E.g.

```
gm = GenericModel(
        raysampler_AdaptiveRaySampler_args={"scene_extent": 4.0}
    )
    OmegaConf.structured(gm1)
```

Reviewed By: shapovalov

Differential Revision: D40341047

fbshipit-source-id: 587d0e8262e271df442a80858949a48e5d6db3df
This commit is contained in:
Jeremy Reizenstein
2022-10-13 06:05:07 -07:00
committed by Facebook GitHub Bot
parent 76cddd90be
commit 4d9215b3b4
2 changed files with 41 additions and 8 deletions

View File

@@ -687,6 +687,34 @@ class TestConfig(unittest.TestCase):
remove_unused_components(args)
self.assertEqual(OmegaConf.to_yaml(args), "mt_enabled: false\n")
def test_get_instance_args(self):
mt1, mt2 = [
MainTest(
n_ids=0,
n_reps=909,
the_fruit_class_type="Pear",
the_second_fruit_class_type="Pear",
the_fruit_Pear_args=DictConfig({}),
the_second_fruit_Pear_args={},
)
for _ in range(2)
]
# Two equivalent ways to get the DictConfig back out of an instance.
cfg1 = OmegaConf.structured(mt1)
cfg2 = get_default_args(mt2)
self.assertEqual(cfg1, cfg2)
self.assertEqual(len(cfg1.the_second_fruit_Pear_args), 0)
self.assertEqual(len(mt2.the_second_fruit_Pear_args), 0)
from_cfg = MainTest(**cfg2)
self.assertEqual(len(from_cfg.the_second_fruit_Pear_args), 0)
# If you want the complete args, merge with the defaults.
merged_args = OmegaConf.merge(get_default_args(MainTest), cfg2)
from_merged = MainTest(**merged_args)
self.assertEqual(len(from_merged.the_second_fruit_Pear_args), 1)
self.assertEqual(from_merged.n_reps, 909)
def test_tweak_hook(self):
class A(Configurable):
n: int = 9