From a54ad2b912891a9b22aa858c2daf93918e4bf5ed Mon Sep 17 00:00:00 2001 From: Roman Shapovalov Date: Tue, 29 Mar 2022 11:36:11 -0700 Subject: [PATCH] get_default_args for callables respects non-class type annotations and Optionals Summary: as subj Reviewed By: davnov134 Differential Revision: D35194863 fbshipit-source-id: c8e8f234083d4f0f93dca8d93e090ca0e1e1972d --- pytorch3d/implicitron/tools/config.py | 21 +++++++++++++++++++-- tests/implicitron/test_config.py | 6 ++++++ 2 files changed, 25 insertions(+), 2 deletions(-) diff --git a/pytorch3d/implicitron/tools/config.py b/pytorch3d/implicitron/tools/config.py index c75b6871..79ae2eb7 100644 --- a/pytorch3d/implicitron/tools/config.py +++ b/pytorch3d/implicitron/tools/config.py @@ -474,13 +474,15 @@ def get_default_args(C, *, _do_not_process: Tuple[type, ...] = ()) -> DictConfig + f" Argument '{pname}' does not have a type annotation." ) + _, annotation = _resolve_optional(defval.annotation) + if isinstance(default, set): # force OmegaConf to convert it to ListConfig default = tuple(default) if isinstance(default, (list, dict)): # OmegaConf will convert to [Dict|List]Config, so it is safe to reuse the value field_ = dataclasses.field(default_factory=lambda default=default: default) - elif not _is_immutable_type(defval.annotation, default): + elif not _is_immutable_type(annotation, default): continue else: # we can use a simple default argument for dataclass.field @@ -509,7 +511,22 @@ def _is_immutable_type(type_: Type, val: Any) -> bool: if isinstance(val, PRIMITIVE_TYPES): return True - return type_ in PRIMITIVE_TYPES or issubclass(type_, Enum) + return type_ in PRIMITIVE_TYPES or ( + inspect.isclass(type_) and issubclass(type_, Enum) + ) + + +# copied from OmegaConf +def _resolve_optional(type_: Any) -> Tuple[bool, Any]: + """Check whether `type_` is equivalent to `typing.Optional[T]` for some T.""" + if get_origin(type_) is Union: + args = get_args(type_) + if len(args) == 2 and args[1] == type(None): # noqa E721 + return True, args[0] + if type_ is Any: + return True, Any + + return False, type_ def _is_actually_dataclass(some_class) -> bool: diff --git a/tests/implicitron/test_config.py b/tests/implicitron/test_config.py index 8fe5aafd..9f0cc77f 100644 --- a/tests/implicitron/test_config.py +++ b/tests/implicitron/test_config.py @@ -629,6 +629,8 @@ class TestConfig(unittest.TestCase): class MockDataclass: field_no_default: int field_primitive_type: int = 42 + field_optional_none: Optional[int] = None + field_optional_with_value: Optional[int] = 42 field_list_type: List[int] = field(default_factory=lambda: []) @@ -645,12 +647,16 @@ class MockClassWithInit: # noqa: B903 field_no_nothing, field_no_default: int, field_primitive_type: int = 42, + field_optional_none: Optional[int] = None, + field_optional_with_value: Optional[int] = 42, field_list_type: List[int] = [], # noqa: B006 field_reference_type: RefObject = REF_OBJECT, ): self.field_no_nothing = field_no_nothing self.field_no_default = field_no_default self.field_primitive_type = field_primitive_type + self.field_optional_none = field_optional_none + self.field_optional_with_value = field_optional_with_value self.field_list_type = field_list_type self.field_reference_type = field_reference_type