diff --git a/pytorch3d/implicitron/tools/config.py b/pytorch3d/implicitron/tools/config.py index 79cda30a..1605f8b4 100644 --- a/pytorch3d/implicitron/tools/config.py +++ b/pytorch3d/implicitron/tools/config.py @@ -887,6 +887,40 @@ def get_default_args_field( return dataclasses.field(default_factory=create) +def _get_default_args_field_from_registry( + *, + base_class_wanted: Type[_X], + name: str, + _do_not_process: Tuple[type, ...] = (), + _hook: Optional[Callable[[DictConfig], None]] = None, +): + """ + Get a dataclass field which defaults to + get_default_args(registry.get(base_class_wanted, name)). + + This is used internally in place of get_default_args_field in + order that default values are updated if a class is redefined. + + Args: + base_class_wanted: As for registry.get. + name: As for registry.get. + _do_not_process: As for get_default_args + _hook: Function called on the result before returning. + + Returns: + function to return new DictConfig object + """ + + def create(): + C = registry.get(base_class_wanted=base_class_wanted, name=name) + args = get_default_args(C, _do_not_process=_do_not_process) + if _hook is not None: + _hook(args) + return args + + return dataclasses.field(default_factory=create) + + def _get_type_to_process(type_) -> Optional[Tuple[Type, _ProcessType]]: """ If a member is annotated as `type_`, and that should expanded in @@ -978,8 +1012,9 @@ def _process_member( setattr( some_class, args_name, - get_default_args_field( - derived_type, + _get_default_args_field_from_registry( + base_class_wanted=type_, + name=derived_type.__name__, _do_not_process=_do_not_process + (some_class,), _hook=hook_closed, ), diff --git a/tests/implicitron/test_config.py b/tests/implicitron/test_config.py index 503f8ab5..590b9dea 100644 --- a/tests/implicitron/test_config.py +++ b/tests/implicitron/test_config.py @@ -378,14 +378,20 @@ class TestConfig(unittest.TestCase): with self.assertWarnsRegex( UserWarning, "New implementation of Grape is being chosen." ): - bowl = FruitBowl(**bowl_args) - self.assertIsInstance(bowl.main_fruit, Grape) + defaulted_bowl = FruitBowl() + self.assertIsInstance(defaulted_bowl.main_fruit, Grape) + self.assertEqual(defaulted_bowl.main_fruit.large, True) + self.assertEqual(defaulted_bowl.main_fruit.get_color(), "green") + with self.assertWarnsRegex( + UserWarning, "New implementation of Grape is being chosen." + ): + args_bowl = FruitBowl(**bowl_args) + self.assertIsInstance(args_bowl.main_fruit, Grape) # Redefining the same class won't help with defaults because encoded in args - self.assertEqual(bowl.main_fruit.large, False) - + self.assertEqual(args_bowl.main_fruit.large, False) # But the override worked. - self.assertEqual(bowl.main_fruit.get_color(), "green") + self.assertEqual(args_bowl.main_fruit.get_color(), "green") # 2. Try redefining without the dataclass modifier # This relies on the fact that default creation processes the class. @@ -397,7 +403,7 @@ class TestConfig(unittest.TestCase): with self.assertWarnsRegex( UserWarning, "New implementation of Grape is being chosen." ): - bowl = FruitBowl(**bowl_args) + FruitBowl(**bowl_args) # 3. Adding a new class doesn't get picked up, because the first # get_default_args call has frozen FruitBowl. This is intrinsic to