diff --git a/pytorch3d/implicitron/tools/config.py b/pytorch3d/implicitron/tools/config.py index cf9970b6..eaf91151 100644 --- a/pytorch3d/implicitron/tools/config.py +++ b/pytorch3d/implicitron/tools/config.py @@ -611,6 +611,9 @@ def _params_iter(C): def _is_immutable_type(type_: Type, val: Any) -> bool: + if val is None: + return True + PRIMITIVE_TYPES = (int, float, bool, str, bytes, tuple) # sometimes type can be too relaxed (e.g. Any), so we also check values if isinstance(val, PRIMITIVE_TYPES): diff --git a/tests/implicitron/test_config.py b/tests/implicitron/test_config.py index 2511566f..2f591edd 100644 --- a/tests/implicitron/test_config.py +++ b/tests/implicitron/test_config.py @@ -9,7 +9,7 @@ import textwrap import unittest from dataclasses import dataclass, field, is_dataclass from enum import Enum -from typing import Any, List, Optional, Set, Tuple +from typing import Any, Dict, List, Optional, Set, Tuple from omegaconf import DictConfig, ListConfig, OmegaConf, ValidationError from pytorch3d.implicitron.tools.config import ( @@ -735,6 +735,7 @@ class MockDataclass: field_no_default: int field_primitive_type: int = 42 field_optional_none: Optional[int] = None + field_optional_dict_none: Optional[Dict] = None field_optional_with_value: Optional[int] = 42 field_list_type: List[int] = field(default_factory=lambda: []) @@ -753,6 +754,7 @@ class MockClassWithInit: # noqa: B903 field_no_default: int, field_primitive_type: int = 42, field_optional_none: Optional[int] = None, + field_optional_dict_none: Optional[Dict] = None, field_optional_with_value: Optional[int] = 42, field_list_type: List[int] = [], # noqa: B006 field_reference_type: RefObject = REF_OBJECT, @@ -761,6 +763,7 @@ class MockClassWithInit: # noqa: B903 self.field_no_default = field_no_default self.field_primitive_type = field_primitive_type self.field_optional_none = field_optional_none + self.field_optional_dict_none = field_optional_dict_none self.field_optional_with_value = field_optional_with_value self.field_list_type = field_list_type self.field_reference_type = field_reference_type @@ -785,8 +788,18 @@ class TestRawClasses(unittest.TestCase): self.assertNotIn("field_no_default", dataclass_defaults) self.assertNotIn("field_no_nothing", dataclass_defaults) self.assertNotIn("field_reference_type", dataclass_defaults) + expected_defaults = [ + "field_primitive_type", + "field_optional_none", + "field_optional_dict_none", + "field_optional_with_value", + "field_list_type", + ] + if cls == MockDataclass: # we don't remove undefaulted from dataclasses dataclass_defaults.field_no_default = 0 + expected_defaults.insert(0, "field_no_default") + self.assertEqual(list(dataclass_defaults), expected_defaults) for name, val in dataclass_defaults.items(): self.assertTrue(hasattr(self._instances[cls], name)) self.assertEqual(val, getattr(self._instances[cls], name))