mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 11:52:50 +08:00
small fixes to config
Summary: - indicate location of OmegaConf.structured failures - split the data gathering from enable_get_default_args to ease experimenting with it. - comment fixes. - nicer error when a_class_type has weird type. Reviewed By: kjchalup Differential Revision: D39434447 fbshipit-source-id: b80c7941547ca450e848038ef5be95b7ebbe8f3e
This commit is contained in:
parent
cb7bd33e7f
commit
da7fe2854e
@ -287,6 +287,10 @@ class _Registry:
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Cannot look up {base_class_wanted}. Cannot tell what it is."
|
f"Cannot look up {base_class_wanted}. Cannot tell what it is."
|
||||||
)
|
)
|
||||||
|
if not isinstance(name, str):
|
||||||
|
raise ValueError(
|
||||||
|
f"Cannot look up a {type(name)} in the registry. Got {name}."
|
||||||
|
)
|
||||||
result = self._mapping[base_class].get(name)
|
result = self._mapping[base_class].get(name)
|
||||||
if result is None:
|
if result is None:
|
||||||
raise ValueError(f"{name} has not been registered.")
|
raise ValueError(f"{name} has not been registered.")
|
||||||
@ -446,6 +450,11 @@ def _default_create_impl(
|
|||||||
setattr(self, name, None)
|
setattr(self, name, None)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
if not isinstance(type_name, str):
|
||||||
|
raise ValueError(
|
||||||
|
f"A {type(type_name)} was received as the type of {name}."
|
||||||
|
+ f" Perhaps this is from {name}{TYPE_SUFFIX}?"
|
||||||
|
)
|
||||||
chosen_class = registry.get(type_, type_name)
|
chosen_class = registry.get(type_, type_name)
|
||||||
if self._known_implementations.get(type_name, chosen_class) is not chosen_class:
|
if self._known_implementations.get(type_name, chosen_class) is not chosen_class:
|
||||||
# If this warning is raised, it means that a new definition of
|
# If this warning is raised, it means that a new definition of
|
||||||
@ -514,7 +523,10 @@ def get_default_args(C, *, _do_not_process: Tuple[type, ...] = ()) -> DictConfig
|
|||||||
# because in practice get_default_args_field is used for
|
# because in practice get_default_args_field is used for
|
||||||
# separate types than the outer type.
|
# separate types than the outer type.
|
||||||
|
|
||||||
out: DictConfig = OmegaConf.structured(C)
|
try:
|
||||||
|
out: DictConfig = OmegaConf.structured(C)
|
||||||
|
except Exception as e:
|
||||||
|
raise ValueError(f"OmegaConf.structured({C}) failed") from e
|
||||||
exclude = getattr(C, "_processed_members", ())
|
exclude = getattr(C, "_processed_members", ())
|
||||||
with open_dict(out):
|
with open_dict(out):
|
||||||
for field in exclude:
|
for field in exclude:
|
||||||
@ -534,7 +546,11 @@ def get_default_args(C, *, _do_not_process: Tuple[type, ...] = ()) -> DictConfig
|
|||||||
f"Cannot get args for {C}. Was enable_get_default_args forgotten?"
|
f"Cannot get args for {C}. Was enable_get_default_args forgotten?"
|
||||||
)
|
)
|
||||||
|
|
||||||
return OmegaConf.structured(dataclass)
|
try:
|
||||||
|
out: DictConfig = OmegaConf.structured(dataclass)
|
||||||
|
except Exception as e:
|
||||||
|
raise ValueError(f"OmegaConf.structured failed for {dataclass_name}") from e
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
def _dataclass_name_for_function(C: Any) -> str:
|
def _dataclass_name_for_function(C: Any) -> str:
|
||||||
@ -546,6 +562,53 @@ def _dataclass_name_for_function(C: Any) -> str:
|
|||||||
return name
|
return name
|
||||||
|
|
||||||
|
|
||||||
|
def _field_annotations_for_default_args(
|
||||||
|
C: Any,
|
||||||
|
) -> List[Tuple[str, Any, dataclasses.Field]]:
|
||||||
|
"""
|
||||||
|
If C is a function or a plain class with an __init__ function,
|
||||||
|
return the fields which `enable_get_default_args(C)` will need
|
||||||
|
to make a dataclass with.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
C: a function, or a class with an __init__ function. Must
|
||||||
|
have types for all its defaulted args.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
a list of fields for a dataclass.
|
||||||
|
"""
|
||||||
|
|
||||||
|
field_annotations = []
|
||||||
|
for pname, defval in _params_iter(C):
|
||||||
|
default = defval.default
|
||||||
|
if default == inspect.Parameter.empty:
|
||||||
|
# we do not have a default value for the parameter
|
||||||
|
continue
|
||||||
|
|
||||||
|
if defval.annotation == inspect._empty:
|
||||||
|
raise ValueError(
|
||||||
|
"All arguments of the input to enable_get_default_args have to"
|
||||||
|
f" be typed. Argument '{pname}' does not have a type annotation."
|
||||||
|
)
|
||||||
|
|
||||||
|
_, annotation = _resolve_optional(defval.annotation)
|
||||||
|
|
||||||
|
if isinstance(default, set): # force OmegaConf to convert it to ListConfig
|
||||||
|
default = tuple(default)
|
||||||
|
|
||||||
|
if isinstance(default, (list, dict)):
|
||||||
|
# OmegaConf will convert to [Dict|List]Config, so it is safe to reuse the value
|
||||||
|
field_ = dataclasses.field(default_factory=lambda default=default: default)
|
||||||
|
elif not _is_immutable_type(annotation, default):
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
# we can use a simple default argument for dataclass.field
|
||||||
|
field_ = dataclasses.field(default=default)
|
||||||
|
field_annotations.append((pname, defval.annotation, field_))
|
||||||
|
|
||||||
|
return field_annotations
|
||||||
|
|
||||||
|
|
||||||
def enable_get_default_args(C: Any, *, overwrite: bool = True) -> None:
|
def enable_get_default_args(C: Any, *, overwrite: bool = True) -> None:
|
||||||
"""
|
"""
|
||||||
If C is a function or a plain class with an __init__ function,
|
If C is a function or a plain class with an __init__ function,
|
||||||
@ -563,33 +626,7 @@ def enable_get_default_args(C: Any, *, overwrite: bool = True) -> None:
|
|||||||
if not inspect.isfunction(C) and not inspect.isclass(C):
|
if not inspect.isfunction(C) and not inspect.isclass(C):
|
||||||
raise ValueError(f"Unexpected {C}")
|
raise ValueError(f"Unexpected {C}")
|
||||||
|
|
||||||
field_annotations = []
|
field_annotations = _field_annotations_for_default_args(C)
|
||||||
for pname, defval in _params_iter(C):
|
|
||||||
default = defval.default
|
|
||||||
if default == inspect.Parameter.empty:
|
|
||||||
# we do not have a default value for the parameter
|
|
||||||
continue
|
|
||||||
|
|
||||||
if defval.annotation == inspect._empty:
|
|
||||||
raise ValueError(
|
|
||||||
"All arguments of the input callable have to be typed."
|
|
||||||
+ f" Argument '{pname}' does not have a type annotation."
|
|
||||||
)
|
|
||||||
|
|
||||||
_, annotation = _resolve_optional(defval.annotation)
|
|
||||||
|
|
||||||
if isinstance(default, set): # force OmegaConf to convert it to ListConfig
|
|
||||||
default = tuple(default)
|
|
||||||
|
|
||||||
if isinstance(default, (list, dict)):
|
|
||||||
# OmegaConf will convert to [Dict|List]Config, so it is safe to reuse the value
|
|
||||||
field_ = dataclasses.field(default_factory=lambda default=default: default)
|
|
||||||
elif not _is_immutable_type(annotation, default):
|
|
||||||
continue
|
|
||||||
else:
|
|
||||||
# we can use a simple default argument for dataclass.field
|
|
||||||
field_ = dataclasses.field(default=default)
|
|
||||||
field_annotations.append((pname, defval.annotation, field_))
|
|
||||||
|
|
||||||
name = _dataclass_name_for_function(C)
|
name = _dataclass_name_for_function(C)
|
||||||
module = sys.modules[C.__module__]
|
module = sys.modules[C.__module__]
|
||||||
@ -767,7 +804,7 @@ def expand_args_fields(
|
|||||||
|
|
||||||
Also adds the following class members, unannotated so that dataclass
|
Also adds the following class members, unannotated so that dataclass
|
||||||
ignores them.
|
ignores them.
|
||||||
- _creation_functions: Tuple[str] of all the create_ functions,
|
- _creation_functions: Tuple[str, ...] of all the create_ functions,
|
||||||
including those from base classes (not the create_x_impl ones).
|
including those from base classes (not the create_x_impl ones).
|
||||||
- _known_implementations: Dict[str, Type] containing the classes which
|
- _known_implementations: Dict[str, Type] containing the classes which
|
||||||
have been found from the registry.
|
have been found from the registry.
|
||||||
@ -945,7 +982,7 @@ def _get_type_to_process(type_) -> Optional[Tuple[Type, _ProcessType]]:
|
|||||||
return underlying, _ProcessType.OPTIONAL_CONFIGURABLE
|
return underlying, _ProcessType.OPTIONAL_CONFIGURABLE
|
||||||
|
|
||||||
if not isinstance(type_, type):
|
if not isinstance(type_, type):
|
||||||
# e.g. any other Union or Tuple
|
# e.g. any other Union or Tuple. Or ClassVar.
|
||||||
return
|
return
|
||||||
|
|
||||||
if issubclass(type_, ReplaceableBase) and ReplaceableBase in type_.__bases__:
|
if issubclass(type_, ReplaceableBase) and ReplaceableBase in type_.__bases__:
|
||||||
|
@ -168,7 +168,7 @@ class TestConfig(unittest.TestCase):
|
|||||||
self.assertIn(Banana, all_fruit)
|
self.assertIn(Banana, all_fruit)
|
||||||
self.assertIn(Pear, all_fruit)
|
self.assertIn(Pear, all_fruit)
|
||||||
self.assertIn(LargePear, all_fruit)
|
self.assertIn(LargePear, all_fruit)
|
||||||
self.assertEqual(set(registry.get_all(Pear)), {LargePear})
|
self.assertEqual(registry.get_all(Pear), [LargePear])
|
||||||
|
|
||||||
@registry.register
|
@registry.register
|
||||||
class Apple(Fruit):
|
class Apple(Fruit):
|
||||||
@ -178,7 +178,7 @@ class TestConfig(unittest.TestCase):
|
|||||||
class CrabApple(Apple):
|
class CrabApple(Apple):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
self.assertEqual(set(registry.get_all(Apple)), {CrabApple})
|
self.assertEqual(registry.get_all(Apple), [CrabApple])
|
||||||
|
|
||||||
self.assertIs(registry.get(Fruit, "CrabApple"), CrabApple)
|
self.assertIs(registry.get(Fruit, "CrabApple"), CrabApple)
|
||||||
|
|
||||||
@ -601,6 +601,7 @@ class TestConfig(unittest.TestCase):
|
|||||||
|
|
||||||
for C_ in [C, C_fn, C_cl]:
|
for C_ in [C, C_fn, C_cl]:
|
||||||
base = get_default_args(C_)
|
base = get_default_args(C_)
|
||||||
|
self.assertEqual(OmegaConf.to_yaml(base), "a: B1\n")
|
||||||
self.assertEqual(base.a, A.B1)
|
self.assertEqual(base.a, A.B1)
|
||||||
replaced = OmegaConf.merge(base, {"a": "B2"})
|
replaced = OmegaConf.merge(base, {"a": "B2"})
|
||||||
self.assertEqual(replaced.a, A.B2)
|
self.assertEqual(replaced.a, A.B2)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user