mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
Optional[Configurable] in config
Summary: A new type of auto-expanded member of a Configurable: something of type Optional[X] where X is a Configurable. This works like X but its construction is controlled by a boolean membername_enabled. Reviewed By: davnov134 Differential Revision: D35368269 fbshipit-source-id: 7e0c8a3e8c4930b0aa942fa1b325ce65336ebd5f
This commit is contained in:
parent
e10a90140d
commit
722646863c
@ -60,7 +60,7 @@ in dataclass style.
|
||||
def __post_init__(self):
|
||||
run_auto_creation(self)
|
||||
|
||||
It can be used like
|
||||
Then it can be used like
|
||||
|
||||
b_args = get_default_args(B)
|
||||
b = B(**b_args)
|
||||
@ -82,8 +82,8 @@ something like the following. (The modification itself is done by the function
|
||||
self.a = A(**self.a_args)
|
||||
|
||||
2. Pluggability. Instead of a dataclass-style member being given a concrete class,
|
||||
you can give a base class and the implementation is looked up by name in the global
|
||||
`registry` in this module. E.g.
|
||||
it can be given a base class and the implementation will be looked up by name in the
|
||||
global `registry` in this module. E.g.
|
||||
|
||||
class A(ReplaceableBase):
|
||||
k: int = 1
|
||||
@ -126,14 +126,14 @@ will expand to
|
||||
default_factory=lambda: DictConfig({"k": 1, "m": 3}
|
||||
)
|
||||
a_A2_args: DictConfig = dataclasses.field(
|
||||
default_factory=lambda: DictConfig({"k": 1, "m": 3}
|
||||
default_factory=lambda: DictConfig({"k": 1, "n": 2}
|
||||
)
|
||||
b_class_type: Optional[str] = "A2"
|
||||
b_A1_args: DictConfig = dataclasses.field(
|
||||
default_factory=lambda: DictConfig({"k": 1, "m": 3}
|
||||
)
|
||||
b_A2_args: DictConfig = dataclasses.field(
|
||||
default_factory=lambda: DictConfig({"k": 1, "m": 3}
|
||||
default_factory=lambda: DictConfig({"k": 1, "n": 2}
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
@ -155,12 +155,14 @@ will expand to
|
||||
|
||||
3. Aside from these classes, the members of these classes should be things
|
||||
which DictConfig is happy with: e.g. (bool, int, str, None, float) and what
|
||||
can be built from them with DictConfigs and lists of them.
|
||||
can be built from them with `DictConfig`s and lists of them.
|
||||
|
||||
In addition, you can call get_default_args on a function or class to get
|
||||
the DictConfig of its defaulted arguments, assuming those are all things
|
||||
which DictConfig is happy with. If you want to use such a thing as a member
|
||||
of another configured class, `get_default_args_field` is a helper.
|
||||
In addition, you can call `get_default_args` on a function or class to get
|
||||
the `DictConfig` of its defaulted arguments, assuming those are all things
|
||||
which `DictConfig` is happy with, so long as you add a call to
|
||||
`enable_get_default_args` after its definition. If you want to use such a
|
||||
thing as the default for a member of another configured class,
|
||||
`get_default_args_field` is a helper.
|
||||
"""
|
||||
|
||||
|
||||
@ -172,6 +174,7 @@ _unprocessed_warning: str = (
|
||||
|
||||
TYPE_SUFFIX: str = "_class_type"
|
||||
ARGS_SUFFIX: str = "_args"
|
||||
ENABLED_SUFFIX: str = "_enabled"
|
||||
|
||||
|
||||
class ReplaceableBase:
|
||||
@ -351,7 +354,8 @@ class _ProcessType(Enum):
|
||||
|
||||
CONFIGURABLE = 1
|
||||
REPLACEABLE = 2
|
||||
OPTIONAL_REPLACEABLE = 3
|
||||
OPTIONAL_CONFIGURABLE = 3
|
||||
OPTIONAL_REPLACEABLE = 4
|
||||
|
||||
|
||||
def _default_create(
|
||||
@ -379,6 +383,15 @@ def _default_create(
|
||||
args = getattr(self, name + ARGS_SUFFIX)
|
||||
setattr(self, name, type_(**args))
|
||||
|
||||
def inner_optional(self):
|
||||
expand_args_fields(type_)
|
||||
enabled = getattr(self, name + ENABLED_SUFFIX)
|
||||
if enabled:
|
||||
args = getattr(self, name + ARGS_SUFFIX)
|
||||
setattr(self, name, type_(**args))
|
||||
else:
|
||||
setattr(self, name, None)
|
||||
|
||||
def inner_pluggable(self):
|
||||
type_name = getattr(self, name + TYPE_SUFFIX)
|
||||
if type_name is None:
|
||||
@ -398,6 +411,8 @@ def _default_create(
|
||||
args = getattr(self, f"{name}_{type_name}{ARGS_SUFFIX}")
|
||||
setattr(self, name, chosen_class(**args))
|
||||
|
||||
if process_type == _ProcessType.OPTIONAL_CONFIGURABLE:
|
||||
return inner_optional
|
||||
return inner if process_type == _ProcessType.CONFIGURABLE else inner_pluggable
|
||||
|
||||
|
||||
@ -608,8 +623,8 @@ def expand_args_fields(
|
||||
|
||||
with
|
||||
|
||||
x_Y_args : DictConfig = dataclasses.field(default_factory=lambda: DictConfig())
|
||||
x_Z_args : DictConfig = dataclasses.field(default_factory=lambda: DictConfig())
|
||||
x_Y_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(Y))
|
||||
x_Z_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(Z))
|
||||
def create_x(self):
|
||||
self.x = registry.get(X, self.x_class_type)(
|
||||
**self.getattr(f"x_{self.x_class_type}_args)
|
||||
@ -629,8 +644,8 @@ def expand_args_fields(
|
||||
|
||||
with
|
||||
|
||||
x_Y_args : DictConfig = dataclasses.field(default_factory=lambda: DictConfig())
|
||||
x_Z_args : DictConfig = dataclasses.field(default_factory=lambda: DictConfig())
|
||||
x_Y_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(Y))
|
||||
x_Z_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(Z))
|
||||
def create_x(self):
|
||||
if self.x_class_type is None:
|
||||
self.x = None
|
||||
@ -653,10 +668,30 @@ def expand_args_fields(
|
||||
|
||||
will be replaced with
|
||||
|
||||
x_args : DictConfig = dataclasses.field(default_factory=lambda: DictConfig())
|
||||
x_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(X))
|
||||
def create_x(self):
|
||||
self.x = X(self.x_args)
|
||||
|
||||
Similarly, replace,
|
||||
|
||||
x: Optional[X]
|
||||
|
||||
and optionally
|
||||
|
||||
def create_x(self):...
|
||||
x_enabled: bool = ...
|
||||
|
||||
with
|
||||
|
||||
x_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(X))
|
||||
x_enabled: bool = False
|
||||
def create_x(self):
|
||||
if self.x_enabled:
|
||||
self.x = X(self.x_args)
|
||||
else:
|
||||
self.x = None
|
||||
|
||||
|
||||
Also adds the following class members, unannotated so that dataclass
|
||||
ignores them.
|
||||
- _creation_functions: Tuple[str] of all the create_ functions,
|
||||
@ -779,6 +814,9 @@ def _get_type_to_process(type_) -> Optional[Tuple[Type, _ProcessType]]:
|
||||
):
|
||||
return underlying, _ProcessType.OPTIONAL_REPLACEABLE
|
||||
|
||||
if isinstance(underlying, type) and issubclass(underlying, Configurable):
|
||||
return underlying, _ProcessType.OPTIONAL_CONFIGURABLE
|
||||
|
||||
if not isinstance(type_, type):
|
||||
# e.g. any other Union or Tuple
|
||||
return
|
||||
@ -817,7 +855,7 @@ def _process_member(
|
||||
# there are non-defaulted standard class members.
|
||||
del some_class.__annotations__[name]
|
||||
|
||||
if process_type != _ProcessType.CONFIGURABLE:
|
||||
if process_type in (_ProcessType.REPLACEABLE, _ProcessType.OPTIONAL_REPLACEABLE):
|
||||
type_name = name + TYPE_SUFFIX
|
||||
if type_name not in some_class.__annotations__:
|
||||
if process_type == _ProcessType.OPTIONAL_REPLACEABLE:
|
||||
@ -866,6 +904,11 @@ def _process_member(
|
||||
_do_not_process=_do_not_process + (some_class,),
|
||||
),
|
||||
)
|
||||
if process_type == _ProcessType.OPTIONAL_CONFIGURABLE:
|
||||
enabled_name = name + ENABLED_SUFFIX
|
||||
if enabled_name not in some_class.__annotations__:
|
||||
some_class.__annotations__[enabled_name] = bool
|
||||
setattr(some_class, enabled_name, False)
|
||||
|
||||
creation_function_name = f"create_{name}"
|
||||
if not hasattr(some_class, creation_function_name):
|
||||
@ -884,7 +927,8 @@ def remove_unused_components(dict_: DictConfig) -> None:
|
||||
pluggable parts which are not in use.
|
||||
For example, if renderer_class_type is SignedDistanceFunctionRenderer,
|
||||
the renderer_MultiPassEmissionAbsorptionRenderer_args will be
|
||||
removed.
|
||||
removed. Also, if chocolate_enabled is False, then chocolate_args will
|
||||
be removed.
|
||||
|
||||
Args:
|
||||
dict_: (MODIFIED IN PLACE) a DictConfig instance
|
||||
@ -904,6 +948,14 @@ def remove_unused_components(dict_: DictConfig) -> None:
|
||||
if key.startswith(replaceable + "_") and key != expect:
|
||||
del dict_[key]
|
||||
|
||||
suffix_length = len(ENABLED_SUFFIX)
|
||||
enableables = [key[:-suffix_length] for key in keys if key.endswith(ENABLED_SUFFIX)]
|
||||
for enableable in enableables:
|
||||
enabled = dict_[enableable + ENABLED_SUFFIX]
|
||||
if not enabled:
|
||||
with open_dict(dict_):
|
||||
dict_.pop(enableable + ARGS_SUFFIX, None)
|
||||
|
||||
for key in dict_:
|
||||
if isinstance(dict_.get(key), DictConfig):
|
||||
remove_unused_components(dict_[key])
|
||||
|
@ -106,8 +106,10 @@ class TestConfig(unittest.TestCase):
|
||||
gt(Optional[Fruit]), (Fruit, _ProcessType.OPTIONAL_REPLACEABLE)
|
||||
)
|
||||
self.assertEqual(gt(MainTest), (MainTest, _ProcessType.CONFIGURABLE))
|
||||
self.assertEqual(
|
||||
gt(Optional[MainTest]), (MainTest, _ProcessType.OPTIONAL_CONFIGURABLE)
|
||||
)
|
||||
self.assertIsNone(gt(Optional[int]))
|
||||
self.assertIsNone(gt(Optional[MainTest]))
|
||||
self.assertIsNone(gt(Tuple[Fruit]))
|
||||
self.assertIsNone(gt(Tuple[Fruit, Animal]))
|
||||
self.assertIsNone(gt(Optional[List[int]]))
|
||||
@ -427,15 +429,24 @@ class TestConfig(unittest.TestCase):
|
||||
run_auto_creation(self)
|
||||
|
||||
class C(Configurable):
|
||||
b: B
|
||||
b1: B
|
||||
b2: Optional[B]
|
||||
b3: Optional[B]
|
||||
b2_enabled: bool = True
|
||||
|
||||
def __post_init__(self):
|
||||
run_auto_creation(self)
|
||||
|
||||
c_args = get_default_args(C)
|
||||
c = C(**c_args)
|
||||
self.assertIsInstance(c.b.a, A)
|
||||
self.assertEqual(c.b.a.n, 9)
|
||||
self.assertIsInstance(c.b1.a, A)
|
||||
self.assertEqual(c.b1.a.n, 9)
|
||||
self.assertFalse(hasattr(c, "b1_enabled"))
|
||||
self.assertIsInstance(c.b2.a, A)
|
||||
self.assertEqual(c.b2.a.n, 9)
|
||||
self.assertTrue(c.b2_enabled)
|
||||
self.assertIsNone(c.b3)
|
||||
self.assertFalse(c.b3_enabled)
|
||||
|
||||
def test_doc(self):
|
||||
# The case in the docstring.
|
||||
@ -522,7 +533,7 @@ class TestConfig(unittest.TestCase):
|
||||
# Like torch.nn.Module, this class contains annotations
|
||||
# but is not designed to be dataclass'd.
|
||||
# This test ensures that such classes, when inherited fron,
|
||||
# are not accidentally expand_args_fields.
|
||||
# are not accidentally affected by expand_args_fields.
|
||||
a: int = 9
|
||||
b: int
|
||||
|
||||
@ -654,6 +665,15 @@ class TestConfig(unittest.TestCase):
|
||||
self.assertEqual(sorted(instance_data.keys()), expected_keys)
|
||||
self.assertEqual(instance_data, expected)
|
||||
|
||||
def test_remove_unused_components_optional(self):
|
||||
class MainTestWrapper(Configurable):
|
||||
mt: Optional[MainTest]
|
||||
|
||||
args = get_default_args(MainTestWrapper)
|
||||
self.assertEqual(list(args.keys()), ["mt_args", "mt_enabled"])
|
||||
remove_unused_components(args)
|
||||
self.assertEqual(OmegaConf.to_yaml(args), "mt_enabled: false\n")
|
||||
|
||||
|
||||
@dataclass(eq=False)
|
||||
class MockDataclass:
|
||||
|
Loading…
x
Reference in New Issue
Block a user