diff --git a/pytorch3d/implicitron/tools/config.py b/pytorch3d/implicitron/tools/config.py index c777ea74..44cdb2ed 100644 --- a/pytorch3d/implicitron/tools/config.py +++ b/pytorch3d/implicitron/tools/config.py @@ -709,8 +709,8 @@ def expand_args_fields( with - x_Y_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(Y)) - x_Z_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(Z)) + x_Y_args: dict = dataclasses.field(default_factory=lambda: get_default_args(Y)) + x_Z_args: dict = dataclasses.field(default_factory=lambda: get_default_args(Z)) def create_x(self): args = self.getattr(f"x_{self.x_class_type}_args") self.create_x_impl(self.x_class_type, args) @@ -733,8 +733,8 @@ def expand_args_fields( with - x_Y_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(Y)) - x_Z_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(Z)) + x_Y_args: dict = dataclasses.field(default_factory=lambda: get_default_args(Y)) + x_Z_args: dict = dataclasses.field(default_factory=lambda: get_default_args(Z)) def create_x(self): if self.x_class_type is None: args = None @@ -764,7 +764,7 @@ def expand_args_fields( will be replaced with - x_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(X)) + x_args: dict = dataclasses.field(default_factory=lambda: get_default_args(X)) def create_x(self): self.create_x_impl(True, self.x_args) @@ -786,7 +786,7 @@ def expand_args_fields( with - x_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(X)) + x_args: dict = dataclasses.field(default_factory=lambda: get_default_args(X)) x_enabled: bool = False def create_x(self): self.create_x_impl(self.x_enabled, self.x_args) @@ -818,6 +818,11 @@ def expand_args_fields( then the default_factory of x_args will also have a call to x_tweak_args(X, x_args) and the default_factory of x_Y_args will also have a call to x_tweak_args(Y, x_Y_args). + Note that although the *_args members are intended to have type DictConfig, they + are actually internally annotated as dicts. OmegaConf is happy to see a DictConfig + in place of a dict, but not vice-versa. Allowing dict lets a class user specify + x_args as an explicit dict without getting an incomprehensible error. + Args: some_class: the class to be processed _do_not_process: Internal use for get_default_args: Because get_default_args calls @@ -1040,7 +1045,7 @@ def _process_member( raise ValueError( f"Cannot generate {args_name} because it is already present." ) - some_class.__annotations__[args_name] = DictConfig + some_class.__annotations__[args_name] = dict if hook is not None: hook_closed = partial(hook, derived_type) else: @@ -1064,7 +1069,7 @@ def _process_member( if issubclass(type_, some_class) or type_ in _do_not_process: raise ValueError(f"Cannot process {type_} inside {some_class}") - some_class.__annotations__[args_name] = DictConfig + some_class.__annotations__[args_name] = dict if hook is not None: hook_closed = partial(hook, type_) else: diff --git a/tests/implicitron/test_config.py b/tests/implicitron/test_config.py index ed1e0696..e86d12aa 100644 --- a/tests/implicitron/test_config.py +++ b/tests/implicitron/test_config.py @@ -687,6 +687,34 @@ class TestConfig(unittest.TestCase): remove_unused_components(args) self.assertEqual(OmegaConf.to_yaml(args), "mt_enabled: false\n") + def test_get_instance_args(self): + mt1, mt2 = [ + MainTest( + n_ids=0, + n_reps=909, + the_fruit_class_type="Pear", + the_second_fruit_class_type="Pear", + the_fruit_Pear_args=DictConfig({}), + the_second_fruit_Pear_args={}, + ) + for _ in range(2) + ] + # Two equivalent ways to get the DictConfig back out of an instance. + cfg1 = OmegaConf.structured(mt1) + cfg2 = get_default_args(mt2) + self.assertEqual(cfg1, cfg2) + self.assertEqual(len(cfg1.the_second_fruit_Pear_args), 0) + self.assertEqual(len(mt2.the_second_fruit_Pear_args), 0) + + from_cfg = MainTest(**cfg2) + self.assertEqual(len(from_cfg.the_second_fruit_Pear_args), 0) + + # If you want the complete args, merge with the defaults. + merged_args = OmegaConf.merge(get_default_args(MainTest), cfg2) + from_merged = MainTest(**merged_args) + self.assertEqual(len(from_merged.the_second_fruit_Pear_args), 1) + self.assertEqual(from_merged.n_reps, 909) + def test_tweak_hook(self): class A(Configurable): n: int = 9