allow Optional[Dict]=None in config

Summary: Fix recently observed case where enable_get_default_args was missing things declared as Optional[something mutable]=None.

Reviewed By: davnov134

Differential Revision: D36440492

fbshipit-source-id: 192ec07564c325b3b24ccc49b003788f67c63a3d
This commit is contained in:
Jeremy Reizenstein 2022-05-17 05:06:18 -07:00 committed by Facebook GitHub Bot
parent ea5df60d72
commit f36b11fe49
2 changed files with 17 additions and 1 deletions

View File

@ -611,6 +611,9 @@ def _params_iter(C):
def _is_immutable_type(type_: Type, val: Any) -> bool: def _is_immutable_type(type_: Type, val: Any) -> bool:
if val is None:
return True
PRIMITIVE_TYPES = (int, float, bool, str, bytes, tuple) PRIMITIVE_TYPES = (int, float, bool, str, bytes, tuple)
# sometimes type can be too relaxed (e.g. Any), so we also check values # sometimes type can be too relaxed (e.g. Any), so we also check values
if isinstance(val, PRIMITIVE_TYPES): if isinstance(val, PRIMITIVE_TYPES):

View File

@ -9,7 +9,7 @@ import textwrap
import unittest import unittest
from dataclasses import dataclass, field, is_dataclass from dataclasses import dataclass, field, is_dataclass
from enum import Enum from enum import Enum
from typing import Any, List, Optional, Set, Tuple from typing import Any, Dict, List, Optional, Set, Tuple
from omegaconf import DictConfig, ListConfig, OmegaConf, ValidationError from omegaconf import DictConfig, ListConfig, OmegaConf, ValidationError
from pytorch3d.implicitron.tools.config import ( from pytorch3d.implicitron.tools.config import (
@ -735,6 +735,7 @@ class MockDataclass:
field_no_default: int field_no_default: int
field_primitive_type: int = 42 field_primitive_type: int = 42
field_optional_none: Optional[int] = None field_optional_none: Optional[int] = None
field_optional_dict_none: Optional[Dict] = None
field_optional_with_value: Optional[int] = 42 field_optional_with_value: Optional[int] = 42
field_list_type: List[int] = field(default_factory=lambda: []) field_list_type: List[int] = field(default_factory=lambda: [])
@ -753,6 +754,7 @@ class MockClassWithInit: # noqa: B903
field_no_default: int, field_no_default: int,
field_primitive_type: int = 42, field_primitive_type: int = 42,
field_optional_none: Optional[int] = None, field_optional_none: Optional[int] = None,
field_optional_dict_none: Optional[Dict] = None,
field_optional_with_value: Optional[int] = 42, field_optional_with_value: Optional[int] = 42,
field_list_type: List[int] = [], # noqa: B006 field_list_type: List[int] = [], # noqa: B006
field_reference_type: RefObject = REF_OBJECT, field_reference_type: RefObject = REF_OBJECT,
@ -761,6 +763,7 @@ class MockClassWithInit: # noqa: B903
self.field_no_default = field_no_default self.field_no_default = field_no_default
self.field_primitive_type = field_primitive_type self.field_primitive_type = field_primitive_type
self.field_optional_none = field_optional_none self.field_optional_none = field_optional_none
self.field_optional_dict_none = field_optional_dict_none
self.field_optional_with_value = field_optional_with_value self.field_optional_with_value = field_optional_with_value
self.field_list_type = field_list_type self.field_list_type = field_list_type
self.field_reference_type = field_reference_type self.field_reference_type = field_reference_type
@ -785,8 +788,18 @@ class TestRawClasses(unittest.TestCase):
self.assertNotIn("field_no_default", dataclass_defaults) self.assertNotIn("field_no_default", dataclass_defaults)
self.assertNotIn("field_no_nothing", dataclass_defaults) self.assertNotIn("field_no_nothing", dataclass_defaults)
self.assertNotIn("field_reference_type", dataclass_defaults) self.assertNotIn("field_reference_type", dataclass_defaults)
expected_defaults = [
"field_primitive_type",
"field_optional_none",
"field_optional_dict_none",
"field_optional_with_value",
"field_list_type",
]
if cls == MockDataclass: # we don't remove undefaulted from dataclasses if cls == MockDataclass: # we don't remove undefaulted from dataclasses
dataclass_defaults.field_no_default = 0 dataclass_defaults.field_no_default = 0
expected_defaults.insert(0, "field_no_default")
self.assertEqual(list(dataclass_defaults), expected_defaults)
for name, val in dataclass_defaults.items(): for name, val in dataclass_defaults.items():
self.assertTrue(hasattr(self._instances[cls], name)) self.assertTrue(hasattr(self._instances[cls], name))
self.assertEqual(val, getattr(self._instances[cls], name)) self.assertEqual(val, getattr(self._instances[cls], name))