mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 11:52:50 +08:00
create_x_impl
Summary: Make create_x delegate to create_x_impl so that users can rely on create_x_impl in their overrides of create_x. Reviewed By: shapovalov, davnov134 Differential Revision: D35929810 fbshipit-source-id: 80595894ee93346b881729995775876b016fc08e
This commit is contained in:
parent
3b2300641a
commit
899a3192b6
@ -175,6 +175,8 @@ _unprocessed_warning: str = (
|
|||||||
TYPE_SUFFIX: str = "_class_type"
|
TYPE_SUFFIX: str = "_class_type"
|
||||||
ARGS_SUFFIX: str = "_args"
|
ARGS_SUFFIX: str = "_args"
|
||||||
ENABLED_SUFFIX: str = "_enabled"
|
ENABLED_SUFFIX: str = "_enabled"
|
||||||
|
CREATE_PREFIX: str = "create_"
|
||||||
|
IMPL_SUFFIX: str = "_impl"
|
||||||
|
|
||||||
|
|
||||||
class ReplaceableBase:
|
class ReplaceableBase:
|
||||||
@ -375,25 +377,68 @@ def _default_create(
|
|||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Function taking one argument, the object whose member should be
|
Function taking one argument, the object whose member should be
|
||||||
initialized.
|
initialized, i.e. self.
|
||||||
"""
|
"""
|
||||||
|
impl_name = f"{CREATE_PREFIX}{name}{IMPL_SUFFIX}"
|
||||||
|
|
||||||
def inner(self):
|
def inner(self):
|
||||||
expand_args_fields(type_)
|
expand_args_fields(type_)
|
||||||
|
impl = getattr(self, impl_name)
|
||||||
args = getattr(self, name + ARGS_SUFFIX)
|
args = getattr(self, name + ARGS_SUFFIX)
|
||||||
setattr(self, name, type_(**args))
|
impl(True, args)
|
||||||
|
|
||||||
def inner_optional(self):
|
def inner_optional(self):
|
||||||
expand_args_fields(type_)
|
expand_args_fields(type_)
|
||||||
|
impl = getattr(self, impl_name)
|
||||||
enabled = getattr(self, name + ENABLED_SUFFIX)
|
enabled = getattr(self, name + ENABLED_SUFFIX)
|
||||||
|
args = getattr(self, name + ARGS_SUFFIX)
|
||||||
|
impl(enabled, args)
|
||||||
|
|
||||||
|
def inner_pluggable(self):
|
||||||
|
type_name = getattr(self, name + TYPE_SUFFIX)
|
||||||
|
impl = getattr(self, impl_name)
|
||||||
|
if type_name is None:
|
||||||
|
args = None
|
||||||
|
else:
|
||||||
|
args = getattr(self, f"{name}_{type_name}{ARGS_SUFFIX}", None)
|
||||||
|
impl(type_name, args)
|
||||||
|
|
||||||
|
if process_type == _ProcessType.OPTIONAL_CONFIGURABLE:
|
||||||
|
return inner_optional
|
||||||
|
return inner if process_type == _ProcessType.CONFIGURABLE else inner_pluggable
|
||||||
|
|
||||||
|
|
||||||
|
def _default_create_impl(
|
||||||
|
name: str, type_: Type, process_type: _ProcessType
|
||||||
|
) -> Callable[[Any, Any, DictConfig], None]:
|
||||||
|
"""
|
||||||
|
Return the default internal function for initialising a member. This is a function
|
||||||
|
which could be called in the create_ function to initialise the member.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: name of the member
|
||||||
|
type_: type of the member (with any Optional removed)
|
||||||
|
process_type: Shows whether member's declared type inherits ReplaceableBase,
|
||||||
|
in which case the actual type to be created is decided at
|
||||||
|
runtime.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Function taking
|
||||||
|
- self, the object whose member should be initialized.
|
||||||
|
- option for what to do. This is
|
||||||
|
- for pluggables, the type to initialise or None to do nothing
|
||||||
|
- for non pluggables, a bool indicating whether to initialise.
|
||||||
|
- the args for initializing the member.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def create_configurable(self, enabled, args):
|
||||||
if enabled:
|
if enabled:
|
||||||
args = getattr(self, name + ARGS_SUFFIX)
|
expand_args_fields(type_)
|
||||||
setattr(self, name, type_(**args))
|
setattr(self, name, type_(**args))
|
||||||
else:
|
else:
|
||||||
setattr(self, name, None)
|
setattr(self, name, None)
|
||||||
|
|
||||||
def inner_pluggable(self):
|
def create_pluggable(self, type_name, args):
|
||||||
type_name = getattr(self, name + TYPE_SUFFIX)
|
|
||||||
if type_name is None:
|
if type_name is None:
|
||||||
setattr(self, name, None)
|
setattr(self, name, None)
|
||||||
return
|
return
|
||||||
@ -408,12 +453,11 @@ def _default_create(
|
|||||||
# were made in the redefinition will not be reflected here.
|
# were made in the redefinition will not be reflected here.
|
||||||
warnings.warn(f"New implementation of {type_name} is being chosen.")
|
warnings.warn(f"New implementation of {type_name} is being chosen.")
|
||||||
expand_args_fields(chosen_class)
|
expand_args_fields(chosen_class)
|
||||||
args = getattr(self, f"{name}_{type_name}{ARGS_SUFFIX}")
|
|
||||||
setattr(self, name, chosen_class(**args))
|
setattr(self, name, chosen_class(**args))
|
||||||
|
|
||||||
if process_type == _ProcessType.OPTIONAL_CONFIGURABLE:
|
if process_type in (_ProcessType.CONFIGURABLE, _ProcessType.OPTIONAL_CONFIGURABLE):
|
||||||
return inner_optional
|
return create_configurable
|
||||||
return inner if process_type == _ProcessType.CONFIGURABLE else inner_pluggable
|
return create_pluggable
|
||||||
|
|
||||||
|
|
||||||
def run_auto_creation(self: Any) -> None:
|
def run_auto_creation(self: Any) -> None:
|
||||||
@ -628,11 +672,12 @@ def expand_args_fields(
|
|||||||
x_Y_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(Y))
|
x_Y_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(Y))
|
||||||
x_Z_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(Z))
|
x_Z_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(Z))
|
||||||
def create_x(self):
|
def create_x(self):
|
||||||
x_type = registry.get(X, self.x_class_type)
|
args = self.getattr(f"x_{self.x_class_type}_args")
|
||||||
|
self.create_x_impl(self.x_class_type, args)
|
||||||
|
def create_x_impl(self, x_type, args):
|
||||||
|
x_type = registry.get(X, x_type)
|
||||||
expand_args_fields(x_type)
|
expand_args_fields(x_type)
|
||||||
self.x = x_type(
|
self.x = x_type(**args)
|
||||||
**self.getattr(f"x_{self.x_class_type}_args)
|
|
||||||
)
|
|
||||||
x_class_type: str = "UNDEFAULTED"
|
x_class_type: str = "UNDEFAULTED"
|
||||||
|
|
||||||
without adding the optional attributes if they are already there.
|
without adding the optional attributes if they are already there.
|
||||||
@ -652,14 +697,19 @@ def expand_args_fields(
|
|||||||
x_Z_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(Z))
|
x_Z_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(Z))
|
||||||
def create_x(self):
|
def create_x(self):
|
||||||
if self.x_class_type is None:
|
if self.x_class_type is None:
|
||||||
|
args = None
|
||||||
|
else:
|
||||||
|
args = self.getattr(f"x_{self.x_class_type}_args", None)
|
||||||
|
self.create_x_impl(self.x_class_type, args)
|
||||||
|
def create_x_impl(self, x_class_type, args):
|
||||||
|
if x_class_type is None:
|
||||||
self.x = None
|
self.x = None
|
||||||
return
|
return
|
||||||
|
|
||||||
x_type = registry.get(X, self.x_class_type)
|
x_type = registry.get(X, x_class_type)
|
||||||
expand_args_fields(x_type)
|
expand_args_fields(x_type)
|
||||||
self.x = x_type(
|
assert args is not None
|
||||||
**self.getattr(f"x_{self.x_class_type}_args)
|
self.x = x_type(**args)
|
||||||
)
|
|
||||||
x_class_type: Optional[str] = "UNDEFAULTED"
|
x_class_type: Optional[str] = "UNDEFAULTED"
|
||||||
|
|
||||||
without adding the optional attributes if they are already there.
|
without adding the optional attributes if they are already there.
|
||||||
@ -676,8 +726,14 @@ def expand_args_fields(
|
|||||||
|
|
||||||
x_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(X))
|
x_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(X))
|
||||||
def create_x(self):
|
def create_x(self):
|
||||||
expand_args_fields(X)
|
self.create_x_impl(True, self.x_args)
|
||||||
self.x = X(self.x_args)
|
|
||||||
|
def create_x_impl(self, enabled, args):
|
||||||
|
if enabled:
|
||||||
|
expand_args_fields(X)
|
||||||
|
self.x = X(**args)
|
||||||
|
else:
|
||||||
|
self.x = None
|
||||||
|
|
||||||
Similarly, replace,
|
Similarly, replace,
|
||||||
|
|
||||||
@ -693,9 +749,12 @@ def expand_args_fields(
|
|||||||
x_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(X))
|
x_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(X))
|
||||||
x_enabled: bool = False
|
x_enabled: bool = False
|
||||||
def create_x(self):
|
def create_x(self):
|
||||||
if self.x_enabled:
|
self.create_x_impl(self.x_enabled, self.x_args)
|
||||||
|
|
||||||
|
def create_x_impl(self, enabled, args):
|
||||||
|
if enabled:
|
||||||
expand_args_fields(X)
|
expand_args_fields(X)
|
||||||
self.x = X(self.x_args)
|
self.x = X(**args)
|
||||||
else:
|
else:
|
||||||
self.x = None
|
self.x = None
|
||||||
|
|
||||||
@ -703,7 +762,7 @@ def expand_args_fields(
|
|||||||
Also adds the following class members, unannotated so that dataclass
|
Also adds the following class members, unannotated so that dataclass
|
||||||
ignores them.
|
ignores them.
|
||||||
- _creation_functions: Tuple[str] of all the create_ functions,
|
- _creation_functions: Tuple[str] of all the create_ functions,
|
||||||
including those from base classes.
|
including those from base classes (not the create_x_impl ones).
|
||||||
- _known_implementations: Dict[str, Type] containing the classes which
|
- _known_implementations: Dict[str, Type] containing the classes which
|
||||||
have been found from the registry.
|
have been found from the registry.
|
||||||
(used only to raise a warning if it one has been overwritten)
|
(used only to raise a warning if it one has been overwritten)
|
||||||
@ -918,7 +977,7 @@ def _process_member(
|
|||||||
some_class.__annotations__[enabled_name] = bool
|
some_class.__annotations__[enabled_name] = bool
|
||||||
setattr(some_class, enabled_name, False)
|
setattr(some_class, enabled_name, False)
|
||||||
|
|
||||||
creation_function_name = f"create_{name}"
|
creation_function_name = f"{CREATE_PREFIX}{name}"
|
||||||
if not hasattr(some_class, creation_function_name):
|
if not hasattr(some_class, creation_function_name):
|
||||||
setattr(
|
setattr(
|
||||||
some_class,
|
some_class,
|
||||||
@ -927,6 +986,14 @@ def _process_member(
|
|||||||
)
|
)
|
||||||
creation_functions.append(creation_function_name)
|
creation_functions.append(creation_function_name)
|
||||||
|
|
||||||
|
creation_function_impl_name = f"{CREATE_PREFIX}{name}{IMPL_SUFFIX}"
|
||||||
|
if not hasattr(some_class, creation_function_impl_name):
|
||||||
|
setattr(
|
||||||
|
some_class,
|
||||||
|
creation_function_impl_name,
|
||||||
|
_default_create_impl(name, type_, process_type),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def remove_unused_components(dict_: DictConfig) -> None:
|
def remove_unused_components(dict_: DictConfig) -> None:
|
||||||
"""
|
"""
|
||||||
|
@ -69,6 +69,10 @@ class LargePear(Pear):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class BoringConfigurable(Configurable):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class MainTest(Configurable):
|
class MainTest(Configurable):
|
||||||
the_fruit: Fruit
|
the_fruit: Fruit
|
||||||
n_ids: int
|
n_ids: int
|
||||||
@ -674,6 +678,57 @@ class TestConfig(unittest.TestCase):
|
|||||||
remove_unused_components(args)
|
remove_unused_components(args)
|
||||||
self.assertEqual(OmegaConf.to_yaml(args), "mt_enabled: false\n")
|
self.assertEqual(OmegaConf.to_yaml(args), "mt_enabled: false\n")
|
||||||
|
|
||||||
|
def test_impls(self):
|
||||||
|
# Check that create_x actually uses create_x_impl to do its work
|
||||||
|
# by using all the member types, both with a faked impl function
|
||||||
|
# and without.
|
||||||
|
# members with _0 are optional and absent, those with _o are
|
||||||
|
# optional and present.
|
||||||
|
control_args = []
|
||||||
|
|
||||||
|
def fake_impl(self, control, args):
|
||||||
|
control_args.append(control)
|
||||||
|
|
||||||
|
for fake in [False, True]:
|
||||||
|
|
||||||
|
class MyClass(Configurable):
|
||||||
|
fruit: Fruit
|
||||||
|
fruit_class_type: str = "Orange"
|
||||||
|
fruit_o: Optional[Fruit]
|
||||||
|
fruit_o_class_type: str = "Orange"
|
||||||
|
fruit_0: Optional[Fruit]
|
||||||
|
fruit_0_class_type: Optional[str] = None
|
||||||
|
boring: BoringConfigurable
|
||||||
|
boring_o: Optional[BoringConfigurable]
|
||||||
|
boring_o_enabled: bool = True
|
||||||
|
boring_0: Optional[BoringConfigurable]
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
run_auto_creation(self)
|
||||||
|
|
||||||
|
if fake:
|
||||||
|
MyClass.create_fruit_impl = fake_impl
|
||||||
|
MyClass.create_fruit_o_impl = fake_impl
|
||||||
|
MyClass.create_boring_impl = fake_impl
|
||||||
|
MyClass.create_boring_o_impl = fake_impl
|
||||||
|
|
||||||
|
expand_args_fields(MyClass)
|
||||||
|
instance = MyClass()
|
||||||
|
for name in ["fruit", "fruit_o", "boring", "boring_o"]:
|
||||||
|
self.assertEqual(
|
||||||
|
hasattr(instance, name), not fake, msg=f"{name} {fake}"
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertIsNone(instance.fruit_0)
|
||||||
|
self.assertIsNone(instance.boring_0)
|
||||||
|
if not fake:
|
||||||
|
self.assertIsInstance(instance.fruit, Orange)
|
||||||
|
self.assertIsInstance(instance.fruit_o, Orange)
|
||||||
|
self.assertIsInstance(instance.boring, BoringConfigurable)
|
||||||
|
self.assertIsInstance(instance.boring_o, BoringConfigurable)
|
||||||
|
|
||||||
|
self.assertEqual(control_args, ["Orange", "Orange", True, True])
|
||||||
|
|
||||||
|
|
||||||
@dataclass(eq=False)
|
@dataclass(eq=False)
|
||||||
class MockDataclass:
|
class MockDataclass:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user