diff --git a/pytorch3d/implicitron/tools/config.py b/pytorch3d/implicitron/tools/config.py index 3c2c4d88..0883ca76 100644 --- a/pytorch3d/implicitron/tools/config.py +++ b/pytorch3d/implicitron/tools/config.py @@ -11,7 +11,7 @@ import itertools import warnings from collections import Counter, defaultdict from enum import Enum -from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, TypeVar, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union from omegaconf import DictConfig, OmegaConf, open_dict from pytorch3d.common.datatypes import get_args, get_origin @@ -635,7 +635,9 @@ def expand_args_fields( - _known_implementations: Dict[str, Type] containing the classes which have been found from the registry. (used only to raise a warning if it one has been overwritten) - - _processed_members: a Set[str] of all the members which have been transformed. + - _processed_members: a Dict[str, Any] of all the members which have been + transformed, with values giving the types they were declared to have. + (E.g. {"x": X} or {"x": Optional[X]} in the cases above.) Args: some_class: the class to be processed @@ -660,7 +662,7 @@ def expand_args_fields( # unused. known_implementations: Dict[str, Type] = {} # Names of members which have been processed. - processed_members: Set[str] = set() + processed_members: Dict[str, Any] = {} # For all bases except ReplaceableBase and Configurable and object, # we need to process them before our own processing. This is @@ -691,6 +693,7 @@ def expand_args_fields( to_process.append((name, underlying_type, process_type)) for name, underlying_type, process_type in to_process: + processed_members[name] = some_class.__annotations__[name] _process_member( name=name, type_=underlying_type, @@ -700,7 +703,6 @@ def expand_args_fields( _do_not_process=_do_not_process, known_implementations=known_implementations, ) - processed_members.add(name) for key, count in Counter(creation_functions).items(): if count > 1: diff --git a/tests/implicitron/test_config.py b/tests/implicitron/test_config.py index 6f238412..54c9d7f6 100644 --- a/tests/implicitron/test_config.py +++ b/tests/implicitron/test_config.py @@ -255,6 +255,8 @@ class TestConfig(unittest.TestCase): container_args = get_default_args(Container) container = Container(**container_args) self.assertIsInstance(container.fruit, Orange) + self.assertEqual(Container._processed_members, {"fruit": Fruit}) + self.assertEqual(container._processed_members, {"fruit": Fruit}) container_defaulted = Container() container_defaulted.fruit_Pear_args.n_pips += 4