allow Optional[Dict]=None in config

Summary: Fix recently observed case where enable_get_default_args was missing things declared as Optional[something mutable]=None.

Reviewed By: davnov134

Differential Revision: D36440492

fbshipit-source-id: 192ec07564c325b3b24ccc49b003788f67c63a3d
This commit is contained in:
Jeremy Reizenstein 2022-05-17 05:06:18 -07:00 committed by Facebook GitHub Bot
parent ea5df60d72
commit f36b11fe49
2 changed files with 17 additions and 1 deletions

View File

@ -611,6 +611,9 @@ def _params_iter(C):
def _is_immutable_type(type_: Type, val: Any) -> bool:
if val is None:
return True
PRIMITIVE_TYPES = (int, float, bool, str, bytes, tuple)
# sometimes type can be too relaxed (e.g. Any), so we also check values
if isinstance(val, PRIMITIVE_TYPES):

View File

@ -9,7 +9,7 @@ import textwrap
import unittest
from dataclasses import dataclass, field, is_dataclass
from enum import Enum
from typing import Any, List, Optional, Set, Tuple
from typing import Any, Dict, List, Optional, Set, Tuple
from omegaconf import DictConfig, ListConfig, OmegaConf, ValidationError
from pytorch3d.implicitron.tools.config import (
@ -735,6 +735,7 @@ class MockDataclass:
field_no_default: int
field_primitive_type: int = 42
field_optional_none: Optional[int] = None
field_optional_dict_none: Optional[Dict] = None
field_optional_with_value: Optional[int] = 42
field_list_type: List[int] = field(default_factory=lambda: [])
@ -753,6 +754,7 @@ class MockClassWithInit: # noqa: B903
field_no_default: int,
field_primitive_type: int = 42,
field_optional_none: Optional[int] = None,
field_optional_dict_none: Optional[Dict] = None,
field_optional_with_value: Optional[int] = 42,
field_list_type: List[int] = [], # noqa: B006
field_reference_type: RefObject = REF_OBJECT,
@ -761,6 +763,7 @@ class MockClassWithInit: # noqa: B903
self.field_no_default = field_no_default
self.field_primitive_type = field_primitive_type
self.field_optional_none = field_optional_none
self.field_optional_dict_none = field_optional_dict_none
self.field_optional_with_value = field_optional_with_value
self.field_list_type = field_list_type
self.field_reference_type = field_reference_type
@ -785,8 +788,18 @@ class TestRawClasses(unittest.TestCase):
self.assertNotIn("field_no_default", dataclass_defaults)
self.assertNotIn("field_no_nothing", dataclass_defaults)
self.assertNotIn("field_reference_type", dataclass_defaults)
expected_defaults = [
"field_primitive_type",
"field_optional_none",
"field_optional_dict_none",
"field_optional_with_value",
"field_list_type",
]
if cls == MockDataclass: # we don't remove undefaulted from dataclasses
dataclass_defaults.field_no_default = 0
expected_defaults.insert(0, "field_no_default")
self.assertEqual(list(dataclass_defaults), expected_defaults)
for name, val in dataclass_defaults.items():
self.assertTrue(hasattr(self._instances[cls], name))
self.assertEqual(val, getattr(self._instances[cls], name))