store original declared types in Configurable

Summary: Aid reflection by adding the original declared types of replaced members of a configurable as values in _processed_members.

Reviewed By: davnov134

Differential Revision: D35358422

fbshipit-source-id: 80ef3266144c51c1c2105f349e0dd3464e230429
This commit is contained in:
Jeremy Reizenstein 2022-04-04 07:19:56 -07:00 committed by Facebook GitHub Bot
parent 199309fcf7
commit 3b8a33e9c5
2 changed files with 8 additions and 4 deletions

View File

@ -11,7 +11,7 @@ import itertools
import warnings
from collections import Counter, defaultdict
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, TypeVar, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union
from omegaconf import DictConfig, OmegaConf, open_dict
from pytorch3d.common.datatypes import get_args, get_origin
@ -635,7 +635,9 @@ def expand_args_fields(
- _known_implementations: Dict[str, Type] containing the classes which
have been found from the registry.
(used only to raise a warning if it one has been overwritten)
- _processed_members: a Set[str] of all the members which have been transformed.
- _processed_members: a Dict[str, Any] of all the members which have been
transformed, with values giving the types they were declared to have.
(E.g. {"x": X} or {"x": Optional[X]} in the cases above.)
Args:
some_class: the class to be processed
@ -660,7 +662,7 @@ def expand_args_fields(
# unused.
known_implementations: Dict[str, Type] = {}
# Names of members which have been processed.
processed_members: Set[str] = set()
processed_members: Dict[str, Any] = {}
# For all bases except ReplaceableBase and Configurable and object,
# we need to process them before our own processing. This is
@ -691,6 +693,7 @@ def expand_args_fields(
to_process.append((name, underlying_type, process_type))
for name, underlying_type, process_type in to_process:
processed_members[name] = some_class.__annotations__[name]
_process_member(
name=name,
type_=underlying_type,
@ -700,7 +703,6 @@ def expand_args_fields(
_do_not_process=_do_not_process,
known_implementations=known_implementations,
)
processed_members.add(name)
for key, count in Counter(creation_functions).items():
if count > 1:

View File

@ -255,6 +255,8 @@ class TestConfig(unittest.TestCase):
container_args = get_default_args(Container)
container = Container(**container_args)
self.assertIsInstance(container.fruit, Orange)
self.assertEqual(Container._processed_members, {"fruit": Fruit})
self.assertEqual(container._processed_members, {"fruit": Fruit})
container_defaulted = Container()
container_defaulted.fruit_Pear_args.n_pips += 4