diff --git a/pytorch3d/implicitron/tools/config.py b/pytorch3d/implicitron/tools/config.py index 9fb7f6c7..50193662 100644 --- a/pytorch3d/implicitron/tools/config.py +++ b/pytorch3d/implicitron/tools/config.py @@ -287,6 +287,10 @@ class _Registry: raise ValueError( f"Cannot look up {base_class_wanted}. Cannot tell what it is." ) + if not isinstance(name, str): + raise ValueError( + f"Cannot look up a {type(name)} in the registry. Got {name}." + ) result = self._mapping[base_class].get(name) if result is None: raise ValueError(f"{name} has not been registered.") @@ -446,6 +450,11 @@ def _default_create_impl( setattr(self, name, None) return + if not isinstance(type_name, str): + raise ValueError( + f"A {type(type_name)} was received as the type of {name}." + + f" Perhaps this is from {name}{TYPE_SUFFIX}?" + ) chosen_class = registry.get(type_, type_name) if self._known_implementations.get(type_name, chosen_class) is not chosen_class: # If this warning is raised, it means that a new definition of @@ -514,7 +523,10 @@ def get_default_args(C, *, _do_not_process: Tuple[type, ...] = ()) -> DictConfig # because in practice get_default_args_field is used for # separate types than the outer type. - out: DictConfig = OmegaConf.structured(C) + try: + out: DictConfig = OmegaConf.structured(C) + except Exception as e: + raise ValueError(f"OmegaConf.structured({C}) failed") from e exclude = getattr(C, "_processed_members", ()) with open_dict(out): for field in exclude: @@ -534,7 +546,11 @@ def get_default_args(C, *, _do_not_process: Tuple[type, ...] = ()) -> DictConfig f"Cannot get args for {C}. Was enable_get_default_args forgotten?" ) - return OmegaConf.structured(dataclass) + try: + out: DictConfig = OmegaConf.structured(dataclass) + except Exception as e: + raise ValueError(f"OmegaConf.structured failed for {dataclass_name}") from e + return out def _dataclass_name_for_function(C: Any) -> str: @@ -546,6 +562,53 @@ def _dataclass_name_for_function(C: Any) -> str: return name +def _field_annotations_for_default_args( + C: Any, +) -> List[Tuple[str, Any, dataclasses.Field]]: + """ + If C is a function or a plain class with an __init__ function, + return the fields which `enable_get_default_args(C)` will need + to make a dataclass with. + + Args: + C: a function, or a class with an __init__ function. Must + have types for all its defaulted args. + + Returns: + a list of fields for a dataclass. + """ + + field_annotations = [] + for pname, defval in _params_iter(C): + default = defval.default + if default == inspect.Parameter.empty: + # we do not have a default value for the parameter + continue + + if defval.annotation == inspect._empty: + raise ValueError( + "All arguments of the input to enable_get_default_args have to" + f" be typed. Argument '{pname}' does not have a type annotation." + ) + + _, annotation = _resolve_optional(defval.annotation) + + if isinstance(default, set): # force OmegaConf to convert it to ListConfig + default = tuple(default) + + if isinstance(default, (list, dict)): + # OmegaConf will convert to [Dict|List]Config, so it is safe to reuse the value + field_ = dataclasses.field(default_factory=lambda default=default: default) + elif not _is_immutable_type(annotation, default): + continue + else: + # we can use a simple default argument for dataclass.field + field_ = dataclasses.field(default=default) + field_annotations.append((pname, defval.annotation, field_)) + + return field_annotations + + def enable_get_default_args(C: Any, *, overwrite: bool = True) -> None: """ If C is a function or a plain class with an __init__ function, @@ -563,33 +626,7 @@ def enable_get_default_args(C: Any, *, overwrite: bool = True) -> None: if not inspect.isfunction(C) and not inspect.isclass(C): raise ValueError(f"Unexpected {C}") - field_annotations = [] - for pname, defval in _params_iter(C): - default = defval.default - if default == inspect.Parameter.empty: - # we do not have a default value for the parameter - continue - - if defval.annotation == inspect._empty: - raise ValueError( - "All arguments of the input callable have to be typed." - + f" Argument '{pname}' does not have a type annotation." - ) - - _, annotation = _resolve_optional(defval.annotation) - - if isinstance(default, set): # force OmegaConf to convert it to ListConfig - default = tuple(default) - - if isinstance(default, (list, dict)): - # OmegaConf will convert to [Dict|List]Config, so it is safe to reuse the value - field_ = dataclasses.field(default_factory=lambda default=default: default) - elif not _is_immutable_type(annotation, default): - continue - else: - # we can use a simple default argument for dataclass.field - field_ = dataclasses.field(default=default) - field_annotations.append((pname, defval.annotation, field_)) + field_annotations = _field_annotations_for_default_args(C) name = _dataclass_name_for_function(C) module = sys.modules[C.__module__] @@ -767,7 +804,7 @@ def expand_args_fields( Also adds the following class members, unannotated so that dataclass ignores them. - - _creation_functions: Tuple[str] of all the create_ functions, + - _creation_functions: Tuple[str, ...] of all the create_ functions, including those from base classes (not the create_x_impl ones). - _known_implementations: Dict[str, Type] containing the classes which have been found from the registry. @@ -945,7 +982,7 @@ def _get_type_to_process(type_) -> Optional[Tuple[Type, _ProcessType]]: return underlying, _ProcessType.OPTIONAL_CONFIGURABLE if not isinstance(type_, type): - # e.g. any other Union or Tuple + # e.g. any other Union or Tuple. Or ClassVar. return if issubclass(type_, ReplaceableBase) and ReplaceableBase in type_.__bases__: diff --git a/tests/implicitron/test_config.py b/tests/implicitron/test_config.py index 56e7ecab..374677cf 100644 --- a/tests/implicitron/test_config.py +++ b/tests/implicitron/test_config.py @@ -168,7 +168,7 @@ class TestConfig(unittest.TestCase): self.assertIn(Banana, all_fruit) self.assertIn(Pear, all_fruit) self.assertIn(LargePear, all_fruit) - self.assertEqual(set(registry.get_all(Pear)), {LargePear}) + self.assertEqual(registry.get_all(Pear), [LargePear]) @registry.register class Apple(Fruit): @@ -178,7 +178,7 @@ class TestConfig(unittest.TestCase): class CrabApple(Apple): pass - self.assertEqual(set(registry.get_all(Apple)), {CrabApple}) + self.assertEqual(registry.get_all(Apple), [CrabApple]) self.assertIs(registry.get(Fruit, "CrabApple"), CrabApple) @@ -601,6 +601,7 @@ class TestConfig(unittest.TestCase): for C_ in [C, C_fn, C_cl]: base = get_default_args(C_) + self.assertEqual(OmegaConf.to_yaml(base), "a: B1\n") self.assertEqual(base.a, A.B1) replaced = OmegaConf.merge(base, {"a": "B2"}) self.assertEqual(replaced.a, A.B2)