get_default_args for callables respects non-class type annotations and Optionals

Summary: as subj

Reviewed By: davnov134

Differential Revision: D35194863

fbshipit-source-id: c8e8f234083d4f0f93dca8d93e090ca0e1e1972d
This commit is contained in:
Roman Shapovalov 2022-03-29 11:36:11 -07:00 committed by Facebook GitHub Bot
parent b602edccc4
commit a54ad2b912
2 changed files with 25 additions and 2 deletions

View File

@ -474,13 +474,15 @@ def get_default_args(C, *, _do_not_process: Tuple[type, ...] = ()) -> DictConfig
+ f" Argument '{pname}' does not have a type annotation."
)
_, annotation = _resolve_optional(defval.annotation)
if isinstance(default, set): # force OmegaConf to convert it to ListConfig
default = tuple(default)
if isinstance(default, (list, dict)):
# OmegaConf will convert to [Dict|List]Config, so it is safe to reuse the value
field_ = dataclasses.field(default_factory=lambda default=default: default)
elif not _is_immutable_type(defval.annotation, default):
elif not _is_immutable_type(annotation, default):
continue
else:
# we can use a simple default argument for dataclass.field
@ -509,7 +511,22 @@ def _is_immutable_type(type_: Type, val: Any) -> bool:
if isinstance(val, PRIMITIVE_TYPES):
return True
return type_ in PRIMITIVE_TYPES or issubclass(type_, Enum)
return type_ in PRIMITIVE_TYPES or (
inspect.isclass(type_) and issubclass(type_, Enum)
)
# copied from OmegaConf
def _resolve_optional(type_: Any) -> Tuple[bool, Any]:
"""Check whether `type_` is equivalent to `typing.Optional[T]` for some T."""
if get_origin(type_) is Union:
args = get_args(type_)
if len(args) == 2 and args[1] == type(None): # noqa E721
return True, args[0]
if type_ is Any:
return True, Any
return False, type_
def _is_actually_dataclass(some_class) -> bool:

View File

@ -629,6 +629,8 @@ class TestConfig(unittest.TestCase):
class MockDataclass:
field_no_default: int
field_primitive_type: int = 42
field_optional_none: Optional[int] = None
field_optional_with_value: Optional[int] = 42
field_list_type: List[int] = field(default_factory=lambda: [])
@ -645,12 +647,16 @@ class MockClassWithInit: # noqa: B903
field_no_nothing,
field_no_default: int,
field_primitive_type: int = 42,
field_optional_none: Optional[int] = None,
field_optional_with_value: Optional[int] = 42,
field_list_type: List[int] = [], # noqa: B006
field_reference_type: RefObject = REF_OBJECT,
):
self.field_no_nothing = field_no_nothing
self.field_no_default = field_no_default
self.field_primitive_type = field_primitive_type
self.field_optional_none = field_optional_none
self.field_optional_with_value = field_optional_with_value
self.field_list_type = field_list_type
self.field_reference_type = field_reference_type