mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
get_default_args for callables respects non-class type annotations and Optionals
Summary: as subj Reviewed By: davnov134 Differential Revision: D35194863 fbshipit-source-id: c8e8f234083d4f0f93dca8d93e090ca0e1e1972d
This commit is contained in:
parent
b602edccc4
commit
a54ad2b912
@ -474,13 +474,15 @@ def get_default_args(C, *, _do_not_process: Tuple[type, ...] = ()) -> DictConfig
|
||||
+ f" Argument '{pname}' does not have a type annotation."
|
||||
)
|
||||
|
||||
_, annotation = _resolve_optional(defval.annotation)
|
||||
|
||||
if isinstance(default, set): # force OmegaConf to convert it to ListConfig
|
||||
default = tuple(default)
|
||||
|
||||
if isinstance(default, (list, dict)):
|
||||
# OmegaConf will convert to [Dict|List]Config, so it is safe to reuse the value
|
||||
field_ = dataclasses.field(default_factory=lambda default=default: default)
|
||||
elif not _is_immutable_type(defval.annotation, default):
|
||||
elif not _is_immutable_type(annotation, default):
|
||||
continue
|
||||
else:
|
||||
# we can use a simple default argument for dataclass.field
|
||||
@ -509,7 +511,22 @@ def _is_immutable_type(type_: Type, val: Any) -> bool:
|
||||
if isinstance(val, PRIMITIVE_TYPES):
|
||||
return True
|
||||
|
||||
return type_ in PRIMITIVE_TYPES or issubclass(type_, Enum)
|
||||
return type_ in PRIMITIVE_TYPES or (
|
||||
inspect.isclass(type_) and issubclass(type_, Enum)
|
||||
)
|
||||
|
||||
|
||||
# copied from OmegaConf
|
||||
def _resolve_optional(type_: Any) -> Tuple[bool, Any]:
|
||||
"""Check whether `type_` is equivalent to `typing.Optional[T]` for some T."""
|
||||
if get_origin(type_) is Union:
|
||||
args = get_args(type_)
|
||||
if len(args) == 2 and args[1] == type(None): # noqa E721
|
||||
return True, args[0]
|
||||
if type_ is Any:
|
||||
return True, Any
|
||||
|
||||
return False, type_
|
||||
|
||||
|
||||
def _is_actually_dataclass(some_class) -> bool:
|
||||
|
@ -629,6 +629,8 @@ class TestConfig(unittest.TestCase):
|
||||
class MockDataclass:
|
||||
field_no_default: int
|
||||
field_primitive_type: int = 42
|
||||
field_optional_none: Optional[int] = None
|
||||
field_optional_with_value: Optional[int] = 42
|
||||
field_list_type: List[int] = field(default_factory=lambda: [])
|
||||
|
||||
|
||||
@ -645,12 +647,16 @@ class MockClassWithInit: # noqa: B903
|
||||
field_no_nothing,
|
||||
field_no_default: int,
|
||||
field_primitive_type: int = 42,
|
||||
field_optional_none: Optional[int] = None,
|
||||
field_optional_with_value: Optional[int] = 42,
|
||||
field_list_type: List[int] = [], # noqa: B006
|
||||
field_reference_type: RefObject = REF_OBJECT,
|
||||
):
|
||||
self.field_no_nothing = field_no_nothing
|
||||
self.field_no_default = field_no_default
|
||||
self.field_primitive_type = field_primitive_type
|
||||
self.field_optional_none = field_optional_none
|
||||
self.field_optional_with_value = field_optional_with_value
|
||||
self.field_list_type = field_list_type
|
||||
self.field_reference_type = field_reference_type
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user