mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
store original declared types in Configurable
Summary: Aid reflection by adding the original declared types of replaced members of a configurable as values in _processed_members. Reviewed By: davnov134 Differential Revision: D35358422 fbshipit-source-id: 80ef3266144c51c1c2105f349e0dd3464e230429
This commit is contained in:
parent
199309fcf7
commit
3b8a33e9c5
@ -11,7 +11,7 @@ import itertools
|
||||
import warnings
|
||||
from collections import Counter, defaultdict
|
||||
from enum import Enum
|
||||
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, TypeVar, Union
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union
|
||||
|
||||
from omegaconf import DictConfig, OmegaConf, open_dict
|
||||
from pytorch3d.common.datatypes import get_args, get_origin
|
||||
@ -635,7 +635,9 @@ def expand_args_fields(
|
||||
- _known_implementations: Dict[str, Type] containing the classes which
|
||||
have been found from the registry.
|
||||
(used only to raise a warning if it one has been overwritten)
|
||||
- _processed_members: a Set[str] of all the members which have been transformed.
|
||||
- _processed_members: a Dict[str, Any] of all the members which have been
|
||||
transformed, with values giving the types they were declared to have.
|
||||
(E.g. {"x": X} or {"x": Optional[X]} in the cases above.)
|
||||
|
||||
Args:
|
||||
some_class: the class to be processed
|
||||
@ -660,7 +662,7 @@ def expand_args_fields(
|
||||
# unused.
|
||||
known_implementations: Dict[str, Type] = {}
|
||||
# Names of members which have been processed.
|
||||
processed_members: Set[str] = set()
|
||||
processed_members: Dict[str, Any] = {}
|
||||
|
||||
# For all bases except ReplaceableBase and Configurable and object,
|
||||
# we need to process them before our own processing. This is
|
||||
@ -691,6 +693,7 @@ def expand_args_fields(
|
||||
to_process.append((name, underlying_type, process_type))
|
||||
|
||||
for name, underlying_type, process_type in to_process:
|
||||
processed_members[name] = some_class.__annotations__[name]
|
||||
_process_member(
|
||||
name=name,
|
||||
type_=underlying_type,
|
||||
@ -700,7 +703,6 @@ def expand_args_fields(
|
||||
_do_not_process=_do_not_process,
|
||||
known_implementations=known_implementations,
|
||||
)
|
||||
processed_members.add(name)
|
||||
|
||||
for key, count in Counter(creation_functions).items():
|
||||
if count > 1:
|
||||
|
@ -255,6 +255,8 @@ class TestConfig(unittest.TestCase):
|
||||
container_args = get_default_args(Container)
|
||||
container = Container(**container_args)
|
||||
self.assertIsInstance(container.fruit, Orange)
|
||||
self.assertEqual(Container._processed_members, {"fruit": Fruit})
|
||||
self.assertEqual(container._processed_members, {"fruit": Fruit})
|
||||
|
||||
container_defaulted = Container()
|
||||
container_defaulted.fruit_Pear_args.n_pips += 4
|
||||
|
Loading…
x
Reference in New Issue
Block a user