mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 20:02:49 +08:00
store original declared types in Configurable
Summary: Aid reflection by adding the original declared types of replaced members of a configurable as values in _processed_members. Reviewed By: davnov134 Differential Revision: D35358422 fbshipit-source-id: 80ef3266144c51c1c2105f349e0dd3464e230429
This commit is contained in:
parent
199309fcf7
commit
3b8a33e9c5
@ -11,7 +11,7 @@ import itertools
|
|||||||
import warnings
|
import warnings
|
||||||
from collections import Counter, defaultdict
|
from collections import Counter, defaultdict
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, TypeVar, Union
|
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union
|
||||||
|
|
||||||
from omegaconf import DictConfig, OmegaConf, open_dict
|
from omegaconf import DictConfig, OmegaConf, open_dict
|
||||||
from pytorch3d.common.datatypes import get_args, get_origin
|
from pytorch3d.common.datatypes import get_args, get_origin
|
||||||
@ -635,7 +635,9 @@ def expand_args_fields(
|
|||||||
- _known_implementations: Dict[str, Type] containing the classes which
|
- _known_implementations: Dict[str, Type] containing the classes which
|
||||||
have been found from the registry.
|
have been found from the registry.
|
||||||
(used only to raise a warning if it one has been overwritten)
|
(used only to raise a warning if it one has been overwritten)
|
||||||
- _processed_members: a Set[str] of all the members which have been transformed.
|
- _processed_members: a Dict[str, Any] of all the members which have been
|
||||||
|
transformed, with values giving the types they were declared to have.
|
||||||
|
(E.g. {"x": X} or {"x": Optional[X]} in the cases above.)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
some_class: the class to be processed
|
some_class: the class to be processed
|
||||||
@ -660,7 +662,7 @@ def expand_args_fields(
|
|||||||
# unused.
|
# unused.
|
||||||
known_implementations: Dict[str, Type] = {}
|
known_implementations: Dict[str, Type] = {}
|
||||||
# Names of members which have been processed.
|
# Names of members which have been processed.
|
||||||
processed_members: Set[str] = set()
|
processed_members: Dict[str, Any] = {}
|
||||||
|
|
||||||
# For all bases except ReplaceableBase and Configurable and object,
|
# For all bases except ReplaceableBase and Configurable and object,
|
||||||
# we need to process them before our own processing. This is
|
# we need to process them before our own processing. This is
|
||||||
@ -691,6 +693,7 @@ def expand_args_fields(
|
|||||||
to_process.append((name, underlying_type, process_type))
|
to_process.append((name, underlying_type, process_type))
|
||||||
|
|
||||||
for name, underlying_type, process_type in to_process:
|
for name, underlying_type, process_type in to_process:
|
||||||
|
processed_members[name] = some_class.__annotations__[name]
|
||||||
_process_member(
|
_process_member(
|
||||||
name=name,
|
name=name,
|
||||||
type_=underlying_type,
|
type_=underlying_type,
|
||||||
@ -700,7 +703,6 @@ def expand_args_fields(
|
|||||||
_do_not_process=_do_not_process,
|
_do_not_process=_do_not_process,
|
||||||
known_implementations=known_implementations,
|
known_implementations=known_implementations,
|
||||||
)
|
)
|
||||||
processed_members.add(name)
|
|
||||||
|
|
||||||
for key, count in Counter(creation_functions).items():
|
for key, count in Counter(creation_functions).items():
|
||||||
if count > 1:
|
if count > 1:
|
||||||
|
@ -255,6 +255,8 @@ class TestConfig(unittest.TestCase):
|
|||||||
container_args = get_default_args(Container)
|
container_args = get_default_args(Container)
|
||||||
container = Container(**container_args)
|
container = Container(**container_args)
|
||||||
self.assertIsInstance(container.fruit, Orange)
|
self.assertIsInstance(container.fruit, Orange)
|
||||||
|
self.assertEqual(Container._processed_members, {"fruit": Fruit})
|
||||||
|
self.assertEqual(container._processed_members, {"fruit": Fruit})
|
||||||
|
|
||||||
container_defaulted = Container()
|
container_defaulted = Container()
|
||||||
container_defaulted.fruit_Pear_args.n_pips += 4
|
container_defaulted.fruit_Pear_args.n_pips += 4
|
||||||
|
Loading…
x
Reference in New Issue
Block a user