mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
allow Optional[Dict]=None in config
Summary: Fix recently observed case where enable_get_default_args was missing things declared as Optional[something mutable]=None. Reviewed By: davnov134 Differential Revision: D36440492 fbshipit-source-id: 192ec07564c325b3b24ccc49b003788f67c63a3d
This commit is contained in:
parent
ea5df60d72
commit
f36b11fe49
@ -611,6 +611,9 @@ def _params_iter(C):
|
||||
|
||||
|
||||
def _is_immutable_type(type_: Type, val: Any) -> bool:
|
||||
if val is None:
|
||||
return True
|
||||
|
||||
PRIMITIVE_TYPES = (int, float, bool, str, bytes, tuple)
|
||||
# sometimes type can be too relaxed (e.g. Any), so we also check values
|
||||
if isinstance(val, PRIMITIVE_TYPES):
|
||||
|
@ -9,7 +9,7 @@ import textwrap
|
||||
import unittest
|
||||
from dataclasses import dataclass, field, is_dataclass
|
||||
from enum import Enum
|
||||
from typing import Any, List, Optional, Set, Tuple
|
||||
from typing import Any, Dict, List, Optional, Set, Tuple
|
||||
|
||||
from omegaconf import DictConfig, ListConfig, OmegaConf, ValidationError
|
||||
from pytorch3d.implicitron.tools.config import (
|
||||
@ -735,6 +735,7 @@ class MockDataclass:
|
||||
field_no_default: int
|
||||
field_primitive_type: int = 42
|
||||
field_optional_none: Optional[int] = None
|
||||
field_optional_dict_none: Optional[Dict] = None
|
||||
field_optional_with_value: Optional[int] = 42
|
||||
field_list_type: List[int] = field(default_factory=lambda: [])
|
||||
|
||||
@ -753,6 +754,7 @@ class MockClassWithInit: # noqa: B903
|
||||
field_no_default: int,
|
||||
field_primitive_type: int = 42,
|
||||
field_optional_none: Optional[int] = None,
|
||||
field_optional_dict_none: Optional[Dict] = None,
|
||||
field_optional_with_value: Optional[int] = 42,
|
||||
field_list_type: List[int] = [], # noqa: B006
|
||||
field_reference_type: RefObject = REF_OBJECT,
|
||||
@ -761,6 +763,7 @@ class MockClassWithInit: # noqa: B903
|
||||
self.field_no_default = field_no_default
|
||||
self.field_primitive_type = field_primitive_type
|
||||
self.field_optional_none = field_optional_none
|
||||
self.field_optional_dict_none = field_optional_dict_none
|
||||
self.field_optional_with_value = field_optional_with_value
|
||||
self.field_list_type = field_list_type
|
||||
self.field_reference_type = field_reference_type
|
||||
@ -785,8 +788,18 @@ class TestRawClasses(unittest.TestCase):
|
||||
self.assertNotIn("field_no_default", dataclass_defaults)
|
||||
self.assertNotIn("field_no_nothing", dataclass_defaults)
|
||||
self.assertNotIn("field_reference_type", dataclass_defaults)
|
||||
expected_defaults = [
|
||||
"field_primitive_type",
|
||||
"field_optional_none",
|
||||
"field_optional_dict_none",
|
||||
"field_optional_with_value",
|
||||
"field_list_type",
|
||||
]
|
||||
|
||||
if cls == MockDataclass: # we don't remove undefaulted from dataclasses
|
||||
dataclass_defaults.field_no_default = 0
|
||||
expected_defaults.insert(0, "field_no_default")
|
||||
self.assertEqual(list(dataclass_defaults), expected_defaults)
|
||||
for name, val in dataclass_defaults.items():
|
||||
self.assertTrue(hasattr(self._instances[cls], name))
|
||||
self.assertEqual(val, getattr(self._instances[cls], name))
|
||||
|
Loading…
x
Reference in New Issue
Block a user