mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 20:02:49 +08:00
get_default_args for callables respects non-class type annotations and Optionals
Summary: as subj Reviewed By: davnov134 Differential Revision: D35194863 fbshipit-source-id: c8e8f234083d4f0f93dca8d93e090ca0e1e1972d
This commit is contained in:
parent
b602edccc4
commit
a54ad2b912
@ -474,13 +474,15 @@ def get_default_args(C, *, _do_not_process: Tuple[type, ...] = ()) -> DictConfig
|
|||||||
+ f" Argument '{pname}' does not have a type annotation."
|
+ f" Argument '{pname}' does not have a type annotation."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
_, annotation = _resolve_optional(defval.annotation)
|
||||||
|
|
||||||
if isinstance(default, set): # force OmegaConf to convert it to ListConfig
|
if isinstance(default, set): # force OmegaConf to convert it to ListConfig
|
||||||
default = tuple(default)
|
default = tuple(default)
|
||||||
|
|
||||||
if isinstance(default, (list, dict)):
|
if isinstance(default, (list, dict)):
|
||||||
# OmegaConf will convert to [Dict|List]Config, so it is safe to reuse the value
|
# OmegaConf will convert to [Dict|List]Config, so it is safe to reuse the value
|
||||||
field_ = dataclasses.field(default_factory=lambda default=default: default)
|
field_ = dataclasses.field(default_factory=lambda default=default: default)
|
||||||
elif not _is_immutable_type(defval.annotation, default):
|
elif not _is_immutable_type(annotation, default):
|
||||||
continue
|
continue
|
||||||
else:
|
else:
|
||||||
# we can use a simple default argument for dataclass.field
|
# we can use a simple default argument for dataclass.field
|
||||||
@ -509,7 +511,22 @@ def _is_immutable_type(type_: Type, val: Any) -> bool:
|
|||||||
if isinstance(val, PRIMITIVE_TYPES):
|
if isinstance(val, PRIMITIVE_TYPES):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
return type_ in PRIMITIVE_TYPES or issubclass(type_, Enum)
|
return type_ in PRIMITIVE_TYPES or (
|
||||||
|
inspect.isclass(type_) and issubclass(type_, Enum)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# copied from OmegaConf
|
||||||
|
def _resolve_optional(type_: Any) -> Tuple[bool, Any]:
|
||||||
|
"""Check whether `type_` is equivalent to `typing.Optional[T]` for some T."""
|
||||||
|
if get_origin(type_) is Union:
|
||||||
|
args = get_args(type_)
|
||||||
|
if len(args) == 2 and args[1] == type(None): # noqa E721
|
||||||
|
return True, args[0]
|
||||||
|
if type_ is Any:
|
||||||
|
return True, Any
|
||||||
|
|
||||||
|
return False, type_
|
||||||
|
|
||||||
|
|
||||||
def _is_actually_dataclass(some_class) -> bool:
|
def _is_actually_dataclass(some_class) -> bool:
|
||||||
|
@ -629,6 +629,8 @@ class TestConfig(unittest.TestCase):
|
|||||||
class MockDataclass:
|
class MockDataclass:
|
||||||
field_no_default: int
|
field_no_default: int
|
||||||
field_primitive_type: int = 42
|
field_primitive_type: int = 42
|
||||||
|
field_optional_none: Optional[int] = None
|
||||||
|
field_optional_with_value: Optional[int] = 42
|
||||||
field_list_type: List[int] = field(default_factory=lambda: [])
|
field_list_type: List[int] = field(default_factory=lambda: [])
|
||||||
|
|
||||||
|
|
||||||
@ -645,12 +647,16 @@ class MockClassWithInit: # noqa: B903
|
|||||||
field_no_nothing,
|
field_no_nothing,
|
||||||
field_no_default: int,
|
field_no_default: int,
|
||||||
field_primitive_type: int = 42,
|
field_primitive_type: int = 42,
|
||||||
|
field_optional_none: Optional[int] = None,
|
||||||
|
field_optional_with_value: Optional[int] = 42,
|
||||||
field_list_type: List[int] = [], # noqa: B006
|
field_list_type: List[int] = [], # noqa: B006
|
||||||
field_reference_type: RefObject = REF_OBJECT,
|
field_reference_type: RefObject = REF_OBJECT,
|
||||||
):
|
):
|
||||||
self.field_no_nothing = field_no_nothing
|
self.field_no_nothing = field_no_nothing
|
||||||
self.field_no_default = field_no_default
|
self.field_no_default = field_no_default
|
||||||
self.field_primitive_type = field_primitive_type
|
self.field_primitive_type = field_primitive_type
|
||||||
|
self.field_optional_none = field_optional_none
|
||||||
|
self.field_optional_with_value = field_optional_with_value
|
||||||
self.field_list_type = field_list_type
|
self.field_list_type = field_list_type
|
||||||
self.field_reference_type = field_reference_type
|
self.field_reference_type = field_reference_type
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user