mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 20:02:49 +08:00
Optional[Configurable] in config
Summary: A new type of auto-expanded member of a Configurable: something of type Optional[X] where X is a Configurable. This works like X but its construction is controlled by a boolean membername_enabled. Reviewed By: davnov134 Differential Revision: D35368269 fbshipit-source-id: 7e0c8a3e8c4930b0aa942fa1b325ce65336ebd5f
This commit is contained in:
parent
e10a90140d
commit
722646863c
@ -60,7 +60,7 @@ in dataclass style.
|
|||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
run_auto_creation(self)
|
run_auto_creation(self)
|
||||||
|
|
||||||
It can be used like
|
Then it can be used like
|
||||||
|
|
||||||
b_args = get_default_args(B)
|
b_args = get_default_args(B)
|
||||||
b = B(**b_args)
|
b = B(**b_args)
|
||||||
@ -82,8 +82,8 @@ something like the following. (The modification itself is done by the function
|
|||||||
self.a = A(**self.a_args)
|
self.a = A(**self.a_args)
|
||||||
|
|
||||||
2. Pluggability. Instead of a dataclass-style member being given a concrete class,
|
2. Pluggability. Instead of a dataclass-style member being given a concrete class,
|
||||||
you can give a base class and the implementation is looked up by name in the global
|
it can be given a base class and the implementation will be looked up by name in the
|
||||||
`registry` in this module. E.g.
|
global `registry` in this module. E.g.
|
||||||
|
|
||||||
class A(ReplaceableBase):
|
class A(ReplaceableBase):
|
||||||
k: int = 1
|
k: int = 1
|
||||||
@ -126,14 +126,14 @@ will expand to
|
|||||||
default_factory=lambda: DictConfig({"k": 1, "m": 3}
|
default_factory=lambda: DictConfig({"k": 1, "m": 3}
|
||||||
)
|
)
|
||||||
a_A2_args: DictConfig = dataclasses.field(
|
a_A2_args: DictConfig = dataclasses.field(
|
||||||
default_factory=lambda: DictConfig({"k": 1, "m": 3}
|
default_factory=lambda: DictConfig({"k": 1, "n": 2}
|
||||||
)
|
)
|
||||||
b_class_type: Optional[str] = "A2"
|
b_class_type: Optional[str] = "A2"
|
||||||
b_A1_args: DictConfig = dataclasses.field(
|
b_A1_args: DictConfig = dataclasses.field(
|
||||||
default_factory=lambda: DictConfig({"k": 1, "m": 3}
|
default_factory=lambda: DictConfig({"k": 1, "m": 3}
|
||||||
)
|
)
|
||||||
b_A2_args: DictConfig = dataclasses.field(
|
b_A2_args: DictConfig = dataclasses.field(
|
||||||
default_factory=lambda: DictConfig({"k": 1, "m": 3}
|
default_factory=lambda: DictConfig({"k": 1, "n": 2}
|
||||||
)
|
)
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
@ -155,12 +155,14 @@ will expand to
|
|||||||
|
|
||||||
3. Aside from these classes, the members of these classes should be things
|
3. Aside from these classes, the members of these classes should be things
|
||||||
which DictConfig is happy with: e.g. (bool, int, str, None, float) and what
|
which DictConfig is happy with: e.g. (bool, int, str, None, float) and what
|
||||||
can be built from them with DictConfigs and lists of them.
|
can be built from them with `DictConfig`s and lists of them.
|
||||||
|
|
||||||
In addition, you can call get_default_args on a function or class to get
|
In addition, you can call `get_default_args` on a function or class to get
|
||||||
the DictConfig of its defaulted arguments, assuming those are all things
|
the `DictConfig` of its defaulted arguments, assuming those are all things
|
||||||
which DictConfig is happy with. If you want to use such a thing as a member
|
which `DictConfig` is happy with, so long as you add a call to
|
||||||
of another configured class, `get_default_args_field` is a helper.
|
`enable_get_default_args` after its definition. If you want to use such a
|
||||||
|
thing as the default for a member of another configured class,
|
||||||
|
`get_default_args_field` is a helper.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
@ -172,6 +174,7 @@ _unprocessed_warning: str = (
|
|||||||
|
|
||||||
TYPE_SUFFIX: str = "_class_type"
|
TYPE_SUFFIX: str = "_class_type"
|
||||||
ARGS_SUFFIX: str = "_args"
|
ARGS_SUFFIX: str = "_args"
|
||||||
|
ENABLED_SUFFIX: str = "_enabled"
|
||||||
|
|
||||||
|
|
||||||
class ReplaceableBase:
|
class ReplaceableBase:
|
||||||
@ -351,7 +354,8 @@ class _ProcessType(Enum):
|
|||||||
|
|
||||||
CONFIGURABLE = 1
|
CONFIGURABLE = 1
|
||||||
REPLACEABLE = 2
|
REPLACEABLE = 2
|
||||||
OPTIONAL_REPLACEABLE = 3
|
OPTIONAL_CONFIGURABLE = 3
|
||||||
|
OPTIONAL_REPLACEABLE = 4
|
||||||
|
|
||||||
|
|
||||||
def _default_create(
|
def _default_create(
|
||||||
@ -379,6 +383,15 @@ def _default_create(
|
|||||||
args = getattr(self, name + ARGS_SUFFIX)
|
args = getattr(self, name + ARGS_SUFFIX)
|
||||||
setattr(self, name, type_(**args))
|
setattr(self, name, type_(**args))
|
||||||
|
|
||||||
|
def inner_optional(self):
|
||||||
|
expand_args_fields(type_)
|
||||||
|
enabled = getattr(self, name + ENABLED_SUFFIX)
|
||||||
|
if enabled:
|
||||||
|
args = getattr(self, name + ARGS_SUFFIX)
|
||||||
|
setattr(self, name, type_(**args))
|
||||||
|
else:
|
||||||
|
setattr(self, name, None)
|
||||||
|
|
||||||
def inner_pluggable(self):
|
def inner_pluggable(self):
|
||||||
type_name = getattr(self, name + TYPE_SUFFIX)
|
type_name = getattr(self, name + TYPE_SUFFIX)
|
||||||
if type_name is None:
|
if type_name is None:
|
||||||
@ -398,6 +411,8 @@ def _default_create(
|
|||||||
args = getattr(self, f"{name}_{type_name}{ARGS_SUFFIX}")
|
args = getattr(self, f"{name}_{type_name}{ARGS_SUFFIX}")
|
||||||
setattr(self, name, chosen_class(**args))
|
setattr(self, name, chosen_class(**args))
|
||||||
|
|
||||||
|
if process_type == _ProcessType.OPTIONAL_CONFIGURABLE:
|
||||||
|
return inner_optional
|
||||||
return inner if process_type == _ProcessType.CONFIGURABLE else inner_pluggable
|
return inner if process_type == _ProcessType.CONFIGURABLE else inner_pluggable
|
||||||
|
|
||||||
|
|
||||||
@ -608,8 +623,8 @@ def expand_args_fields(
|
|||||||
|
|
||||||
with
|
with
|
||||||
|
|
||||||
x_Y_args : DictConfig = dataclasses.field(default_factory=lambda: DictConfig())
|
x_Y_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(Y))
|
||||||
x_Z_args : DictConfig = dataclasses.field(default_factory=lambda: DictConfig())
|
x_Z_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(Z))
|
||||||
def create_x(self):
|
def create_x(self):
|
||||||
self.x = registry.get(X, self.x_class_type)(
|
self.x = registry.get(X, self.x_class_type)(
|
||||||
**self.getattr(f"x_{self.x_class_type}_args)
|
**self.getattr(f"x_{self.x_class_type}_args)
|
||||||
@ -629,8 +644,8 @@ def expand_args_fields(
|
|||||||
|
|
||||||
with
|
with
|
||||||
|
|
||||||
x_Y_args : DictConfig = dataclasses.field(default_factory=lambda: DictConfig())
|
x_Y_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(Y))
|
||||||
x_Z_args : DictConfig = dataclasses.field(default_factory=lambda: DictConfig())
|
x_Z_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(Z))
|
||||||
def create_x(self):
|
def create_x(self):
|
||||||
if self.x_class_type is None:
|
if self.x_class_type is None:
|
||||||
self.x = None
|
self.x = None
|
||||||
@ -653,10 +668,30 @@ def expand_args_fields(
|
|||||||
|
|
||||||
will be replaced with
|
will be replaced with
|
||||||
|
|
||||||
x_args : DictConfig = dataclasses.field(default_factory=lambda: DictConfig())
|
x_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(X))
|
||||||
def create_x(self):
|
def create_x(self):
|
||||||
self.x = X(self.x_args)
|
self.x = X(self.x_args)
|
||||||
|
|
||||||
|
Similarly, replace,
|
||||||
|
|
||||||
|
x: Optional[X]
|
||||||
|
|
||||||
|
and optionally
|
||||||
|
|
||||||
|
def create_x(self):...
|
||||||
|
x_enabled: bool = ...
|
||||||
|
|
||||||
|
with
|
||||||
|
|
||||||
|
x_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(X))
|
||||||
|
x_enabled: bool = False
|
||||||
|
def create_x(self):
|
||||||
|
if self.x_enabled:
|
||||||
|
self.x = X(self.x_args)
|
||||||
|
else:
|
||||||
|
self.x = None
|
||||||
|
|
||||||
|
|
||||||
Also adds the following class members, unannotated so that dataclass
|
Also adds the following class members, unannotated so that dataclass
|
||||||
ignores them.
|
ignores them.
|
||||||
- _creation_functions: Tuple[str] of all the create_ functions,
|
- _creation_functions: Tuple[str] of all the create_ functions,
|
||||||
@ -779,6 +814,9 @@ def _get_type_to_process(type_) -> Optional[Tuple[Type, _ProcessType]]:
|
|||||||
):
|
):
|
||||||
return underlying, _ProcessType.OPTIONAL_REPLACEABLE
|
return underlying, _ProcessType.OPTIONAL_REPLACEABLE
|
||||||
|
|
||||||
|
if isinstance(underlying, type) and issubclass(underlying, Configurable):
|
||||||
|
return underlying, _ProcessType.OPTIONAL_CONFIGURABLE
|
||||||
|
|
||||||
if not isinstance(type_, type):
|
if not isinstance(type_, type):
|
||||||
# e.g. any other Union or Tuple
|
# e.g. any other Union or Tuple
|
||||||
return
|
return
|
||||||
@ -817,7 +855,7 @@ def _process_member(
|
|||||||
# there are non-defaulted standard class members.
|
# there are non-defaulted standard class members.
|
||||||
del some_class.__annotations__[name]
|
del some_class.__annotations__[name]
|
||||||
|
|
||||||
if process_type != _ProcessType.CONFIGURABLE:
|
if process_type in (_ProcessType.REPLACEABLE, _ProcessType.OPTIONAL_REPLACEABLE):
|
||||||
type_name = name + TYPE_SUFFIX
|
type_name = name + TYPE_SUFFIX
|
||||||
if type_name not in some_class.__annotations__:
|
if type_name not in some_class.__annotations__:
|
||||||
if process_type == _ProcessType.OPTIONAL_REPLACEABLE:
|
if process_type == _ProcessType.OPTIONAL_REPLACEABLE:
|
||||||
@ -866,6 +904,11 @@ def _process_member(
|
|||||||
_do_not_process=_do_not_process + (some_class,),
|
_do_not_process=_do_not_process + (some_class,),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
if process_type == _ProcessType.OPTIONAL_CONFIGURABLE:
|
||||||
|
enabled_name = name + ENABLED_SUFFIX
|
||||||
|
if enabled_name not in some_class.__annotations__:
|
||||||
|
some_class.__annotations__[enabled_name] = bool
|
||||||
|
setattr(some_class, enabled_name, False)
|
||||||
|
|
||||||
creation_function_name = f"create_{name}"
|
creation_function_name = f"create_{name}"
|
||||||
if not hasattr(some_class, creation_function_name):
|
if not hasattr(some_class, creation_function_name):
|
||||||
@ -884,7 +927,8 @@ def remove_unused_components(dict_: DictConfig) -> None:
|
|||||||
pluggable parts which are not in use.
|
pluggable parts which are not in use.
|
||||||
For example, if renderer_class_type is SignedDistanceFunctionRenderer,
|
For example, if renderer_class_type is SignedDistanceFunctionRenderer,
|
||||||
the renderer_MultiPassEmissionAbsorptionRenderer_args will be
|
the renderer_MultiPassEmissionAbsorptionRenderer_args will be
|
||||||
removed.
|
removed. Also, if chocolate_enabled is False, then chocolate_args will
|
||||||
|
be removed.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
dict_: (MODIFIED IN PLACE) a DictConfig instance
|
dict_: (MODIFIED IN PLACE) a DictConfig instance
|
||||||
@ -904,6 +948,14 @@ def remove_unused_components(dict_: DictConfig) -> None:
|
|||||||
if key.startswith(replaceable + "_") and key != expect:
|
if key.startswith(replaceable + "_") and key != expect:
|
||||||
del dict_[key]
|
del dict_[key]
|
||||||
|
|
||||||
|
suffix_length = len(ENABLED_SUFFIX)
|
||||||
|
enableables = [key[:-suffix_length] for key in keys if key.endswith(ENABLED_SUFFIX)]
|
||||||
|
for enableable in enableables:
|
||||||
|
enabled = dict_[enableable + ENABLED_SUFFIX]
|
||||||
|
if not enabled:
|
||||||
|
with open_dict(dict_):
|
||||||
|
dict_.pop(enableable + ARGS_SUFFIX, None)
|
||||||
|
|
||||||
for key in dict_:
|
for key in dict_:
|
||||||
if isinstance(dict_.get(key), DictConfig):
|
if isinstance(dict_.get(key), DictConfig):
|
||||||
remove_unused_components(dict_[key])
|
remove_unused_components(dict_[key])
|
||||||
|
@ -106,8 +106,10 @@ class TestConfig(unittest.TestCase):
|
|||||||
gt(Optional[Fruit]), (Fruit, _ProcessType.OPTIONAL_REPLACEABLE)
|
gt(Optional[Fruit]), (Fruit, _ProcessType.OPTIONAL_REPLACEABLE)
|
||||||
)
|
)
|
||||||
self.assertEqual(gt(MainTest), (MainTest, _ProcessType.CONFIGURABLE))
|
self.assertEqual(gt(MainTest), (MainTest, _ProcessType.CONFIGURABLE))
|
||||||
|
self.assertEqual(
|
||||||
|
gt(Optional[MainTest]), (MainTest, _ProcessType.OPTIONAL_CONFIGURABLE)
|
||||||
|
)
|
||||||
self.assertIsNone(gt(Optional[int]))
|
self.assertIsNone(gt(Optional[int]))
|
||||||
self.assertIsNone(gt(Optional[MainTest]))
|
|
||||||
self.assertIsNone(gt(Tuple[Fruit]))
|
self.assertIsNone(gt(Tuple[Fruit]))
|
||||||
self.assertIsNone(gt(Tuple[Fruit, Animal]))
|
self.assertIsNone(gt(Tuple[Fruit, Animal]))
|
||||||
self.assertIsNone(gt(Optional[List[int]]))
|
self.assertIsNone(gt(Optional[List[int]]))
|
||||||
@ -427,15 +429,24 @@ class TestConfig(unittest.TestCase):
|
|||||||
run_auto_creation(self)
|
run_auto_creation(self)
|
||||||
|
|
||||||
class C(Configurable):
|
class C(Configurable):
|
||||||
b: B
|
b1: B
|
||||||
|
b2: Optional[B]
|
||||||
|
b3: Optional[B]
|
||||||
|
b2_enabled: bool = True
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
run_auto_creation(self)
|
run_auto_creation(self)
|
||||||
|
|
||||||
c_args = get_default_args(C)
|
c_args = get_default_args(C)
|
||||||
c = C(**c_args)
|
c = C(**c_args)
|
||||||
self.assertIsInstance(c.b.a, A)
|
self.assertIsInstance(c.b1.a, A)
|
||||||
self.assertEqual(c.b.a.n, 9)
|
self.assertEqual(c.b1.a.n, 9)
|
||||||
|
self.assertFalse(hasattr(c, "b1_enabled"))
|
||||||
|
self.assertIsInstance(c.b2.a, A)
|
||||||
|
self.assertEqual(c.b2.a.n, 9)
|
||||||
|
self.assertTrue(c.b2_enabled)
|
||||||
|
self.assertIsNone(c.b3)
|
||||||
|
self.assertFalse(c.b3_enabled)
|
||||||
|
|
||||||
def test_doc(self):
|
def test_doc(self):
|
||||||
# The case in the docstring.
|
# The case in the docstring.
|
||||||
@ -522,7 +533,7 @@ class TestConfig(unittest.TestCase):
|
|||||||
# Like torch.nn.Module, this class contains annotations
|
# Like torch.nn.Module, this class contains annotations
|
||||||
# but is not designed to be dataclass'd.
|
# but is not designed to be dataclass'd.
|
||||||
# This test ensures that such classes, when inherited fron,
|
# This test ensures that such classes, when inherited fron,
|
||||||
# are not accidentally expand_args_fields.
|
# are not accidentally affected by expand_args_fields.
|
||||||
a: int = 9
|
a: int = 9
|
||||||
b: int
|
b: int
|
||||||
|
|
||||||
@ -654,6 +665,15 @@ class TestConfig(unittest.TestCase):
|
|||||||
self.assertEqual(sorted(instance_data.keys()), expected_keys)
|
self.assertEqual(sorted(instance_data.keys()), expected_keys)
|
||||||
self.assertEqual(instance_data, expected)
|
self.assertEqual(instance_data, expected)
|
||||||
|
|
||||||
|
def test_remove_unused_components_optional(self):
|
||||||
|
class MainTestWrapper(Configurable):
|
||||||
|
mt: Optional[MainTest]
|
||||||
|
|
||||||
|
args = get_default_args(MainTestWrapper)
|
||||||
|
self.assertEqual(list(args.keys()), ["mt_args", "mt_enabled"])
|
||||||
|
remove_unused_components(args)
|
||||||
|
self.assertEqual(OmegaConf.to_yaml(args), "mt_enabled: false\n")
|
||||||
|
|
||||||
|
|
||||||
@dataclass(eq=False)
|
@dataclass(eq=False)
|
||||||
class MockDataclass:
|
class MockDataclass:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user