mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
create_x_impl
Summary: Make create_x delegate to create_x_impl so that users can rely on create_x_impl in their overrides of create_x. Reviewed By: shapovalov, davnov134 Differential Revision: D35929810 fbshipit-source-id: 80595894ee93346b881729995775876b016fc08e
This commit is contained in:
parent
3b2300641a
commit
899a3192b6
@ -175,6 +175,8 @@ _unprocessed_warning: str = (
|
||||
TYPE_SUFFIX: str = "_class_type"
|
||||
ARGS_SUFFIX: str = "_args"
|
||||
ENABLED_SUFFIX: str = "_enabled"
|
||||
CREATE_PREFIX: str = "create_"
|
||||
IMPL_SUFFIX: str = "_impl"
|
||||
|
||||
|
||||
class ReplaceableBase:
|
||||
@ -375,25 +377,68 @@ def _default_create(
|
||||
|
||||
Returns:
|
||||
Function taking one argument, the object whose member should be
|
||||
initialized.
|
||||
initialized, i.e. self.
|
||||
"""
|
||||
impl_name = f"{CREATE_PREFIX}{name}{IMPL_SUFFIX}"
|
||||
|
||||
def inner(self):
|
||||
expand_args_fields(type_)
|
||||
impl = getattr(self, impl_name)
|
||||
args = getattr(self, name + ARGS_SUFFIX)
|
||||
setattr(self, name, type_(**args))
|
||||
impl(True, args)
|
||||
|
||||
def inner_optional(self):
|
||||
expand_args_fields(type_)
|
||||
impl = getattr(self, impl_name)
|
||||
enabled = getattr(self, name + ENABLED_SUFFIX)
|
||||
args = getattr(self, name + ARGS_SUFFIX)
|
||||
impl(enabled, args)
|
||||
|
||||
def inner_pluggable(self):
|
||||
type_name = getattr(self, name + TYPE_SUFFIX)
|
||||
impl = getattr(self, impl_name)
|
||||
if type_name is None:
|
||||
args = None
|
||||
else:
|
||||
args = getattr(self, f"{name}_{type_name}{ARGS_SUFFIX}", None)
|
||||
impl(type_name, args)
|
||||
|
||||
if process_type == _ProcessType.OPTIONAL_CONFIGURABLE:
|
||||
return inner_optional
|
||||
return inner if process_type == _ProcessType.CONFIGURABLE else inner_pluggable
|
||||
|
||||
|
||||
def _default_create_impl(
|
||||
name: str, type_: Type, process_type: _ProcessType
|
||||
) -> Callable[[Any, Any, DictConfig], None]:
|
||||
"""
|
||||
Return the default internal function for initialising a member. This is a function
|
||||
which could be called in the create_ function to initialise the member.
|
||||
|
||||
Args:
|
||||
name: name of the member
|
||||
type_: type of the member (with any Optional removed)
|
||||
process_type: Shows whether member's declared type inherits ReplaceableBase,
|
||||
in which case the actual type to be created is decided at
|
||||
runtime.
|
||||
|
||||
Returns:
|
||||
Function taking
|
||||
- self, the object whose member should be initialized.
|
||||
- option for what to do. This is
|
||||
- for pluggables, the type to initialise or None to do nothing
|
||||
- for non pluggables, a bool indicating whether to initialise.
|
||||
- the args for initializing the member.
|
||||
"""
|
||||
|
||||
def create_configurable(self, enabled, args):
|
||||
if enabled:
|
||||
args = getattr(self, name + ARGS_SUFFIX)
|
||||
expand_args_fields(type_)
|
||||
setattr(self, name, type_(**args))
|
||||
else:
|
||||
setattr(self, name, None)
|
||||
|
||||
def inner_pluggable(self):
|
||||
type_name = getattr(self, name + TYPE_SUFFIX)
|
||||
def create_pluggable(self, type_name, args):
|
||||
if type_name is None:
|
||||
setattr(self, name, None)
|
||||
return
|
||||
@ -408,12 +453,11 @@ def _default_create(
|
||||
# were made in the redefinition will not be reflected here.
|
||||
warnings.warn(f"New implementation of {type_name} is being chosen.")
|
||||
expand_args_fields(chosen_class)
|
||||
args = getattr(self, f"{name}_{type_name}{ARGS_SUFFIX}")
|
||||
setattr(self, name, chosen_class(**args))
|
||||
|
||||
if process_type == _ProcessType.OPTIONAL_CONFIGURABLE:
|
||||
return inner_optional
|
||||
return inner if process_type == _ProcessType.CONFIGURABLE else inner_pluggable
|
||||
if process_type in (_ProcessType.CONFIGURABLE, _ProcessType.OPTIONAL_CONFIGURABLE):
|
||||
return create_configurable
|
||||
return create_pluggable
|
||||
|
||||
|
||||
def run_auto_creation(self: Any) -> None:
|
||||
@ -628,11 +672,12 @@ def expand_args_fields(
|
||||
x_Y_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(Y))
|
||||
x_Z_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(Z))
|
||||
def create_x(self):
|
||||
x_type = registry.get(X, self.x_class_type)
|
||||
args = self.getattr(f"x_{self.x_class_type}_args")
|
||||
self.create_x_impl(self.x_class_type, args)
|
||||
def create_x_impl(self, x_type, args):
|
||||
x_type = registry.get(X, x_type)
|
||||
expand_args_fields(x_type)
|
||||
self.x = x_type(
|
||||
**self.getattr(f"x_{self.x_class_type}_args)
|
||||
)
|
||||
self.x = x_type(**args)
|
||||
x_class_type: str = "UNDEFAULTED"
|
||||
|
||||
without adding the optional attributes if they are already there.
|
||||
@ -652,14 +697,19 @@ def expand_args_fields(
|
||||
x_Z_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(Z))
|
||||
def create_x(self):
|
||||
if self.x_class_type is None:
|
||||
args = None
|
||||
else:
|
||||
args = self.getattr(f"x_{self.x_class_type}_args", None)
|
||||
self.create_x_impl(self.x_class_type, args)
|
||||
def create_x_impl(self, x_class_type, args):
|
||||
if x_class_type is None:
|
||||
self.x = None
|
||||
return
|
||||
|
||||
x_type = registry.get(X, self.x_class_type)
|
||||
x_type = registry.get(X, x_class_type)
|
||||
expand_args_fields(x_type)
|
||||
self.x = x_type(
|
||||
**self.getattr(f"x_{self.x_class_type}_args)
|
||||
)
|
||||
assert args is not None
|
||||
self.x = x_type(**args)
|
||||
x_class_type: Optional[str] = "UNDEFAULTED"
|
||||
|
||||
without adding the optional attributes if they are already there.
|
||||
@ -676,8 +726,14 @@ def expand_args_fields(
|
||||
|
||||
x_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(X))
|
||||
def create_x(self):
|
||||
expand_args_fields(X)
|
||||
self.x = X(self.x_args)
|
||||
self.create_x_impl(True, self.x_args)
|
||||
|
||||
def create_x_impl(self, enabled, args):
|
||||
if enabled:
|
||||
expand_args_fields(X)
|
||||
self.x = X(**args)
|
||||
else:
|
||||
self.x = None
|
||||
|
||||
Similarly, replace,
|
||||
|
||||
@ -693,9 +749,12 @@ def expand_args_fields(
|
||||
x_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(X))
|
||||
x_enabled: bool = False
|
||||
def create_x(self):
|
||||
if self.x_enabled:
|
||||
self.create_x_impl(self.x_enabled, self.x_args)
|
||||
|
||||
def create_x_impl(self, enabled, args):
|
||||
if enabled:
|
||||
expand_args_fields(X)
|
||||
self.x = X(self.x_args)
|
||||
self.x = X(**args)
|
||||
else:
|
||||
self.x = None
|
||||
|
||||
@ -703,7 +762,7 @@ def expand_args_fields(
|
||||
Also adds the following class members, unannotated so that dataclass
|
||||
ignores them.
|
||||
- _creation_functions: Tuple[str] of all the create_ functions,
|
||||
including those from base classes.
|
||||
including those from base classes (not the create_x_impl ones).
|
||||
- _known_implementations: Dict[str, Type] containing the classes which
|
||||
have been found from the registry.
|
||||
(used only to raise a warning if it one has been overwritten)
|
||||
@ -918,7 +977,7 @@ def _process_member(
|
||||
some_class.__annotations__[enabled_name] = bool
|
||||
setattr(some_class, enabled_name, False)
|
||||
|
||||
creation_function_name = f"create_{name}"
|
||||
creation_function_name = f"{CREATE_PREFIX}{name}"
|
||||
if not hasattr(some_class, creation_function_name):
|
||||
setattr(
|
||||
some_class,
|
||||
@ -927,6 +986,14 @@ def _process_member(
|
||||
)
|
||||
creation_functions.append(creation_function_name)
|
||||
|
||||
creation_function_impl_name = f"{CREATE_PREFIX}{name}{IMPL_SUFFIX}"
|
||||
if not hasattr(some_class, creation_function_impl_name):
|
||||
setattr(
|
||||
some_class,
|
||||
creation_function_impl_name,
|
||||
_default_create_impl(name, type_, process_type),
|
||||
)
|
||||
|
||||
|
||||
def remove_unused_components(dict_: DictConfig) -> None:
|
||||
"""
|
||||
|
@ -69,6 +69,10 @@ class LargePear(Pear):
|
||||
pass
|
||||
|
||||
|
||||
class BoringConfigurable(Configurable):
|
||||
pass
|
||||
|
||||
|
||||
class MainTest(Configurable):
|
||||
the_fruit: Fruit
|
||||
n_ids: int
|
||||
@ -674,6 +678,57 @@ class TestConfig(unittest.TestCase):
|
||||
remove_unused_components(args)
|
||||
self.assertEqual(OmegaConf.to_yaml(args), "mt_enabled: false\n")
|
||||
|
||||
def test_impls(self):
|
||||
# Check that create_x actually uses create_x_impl to do its work
|
||||
# by using all the member types, both with a faked impl function
|
||||
# and without.
|
||||
# members with _0 are optional and absent, those with _o are
|
||||
# optional and present.
|
||||
control_args = []
|
||||
|
||||
def fake_impl(self, control, args):
|
||||
control_args.append(control)
|
||||
|
||||
for fake in [False, True]:
|
||||
|
||||
class MyClass(Configurable):
|
||||
fruit: Fruit
|
||||
fruit_class_type: str = "Orange"
|
||||
fruit_o: Optional[Fruit]
|
||||
fruit_o_class_type: str = "Orange"
|
||||
fruit_0: Optional[Fruit]
|
||||
fruit_0_class_type: Optional[str] = None
|
||||
boring: BoringConfigurable
|
||||
boring_o: Optional[BoringConfigurable]
|
||||
boring_o_enabled: bool = True
|
||||
boring_0: Optional[BoringConfigurable]
|
||||
|
||||
def __post_init__(self):
|
||||
run_auto_creation(self)
|
||||
|
||||
if fake:
|
||||
MyClass.create_fruit_impl = fake_impl
|
||||
MyClass.create_fruit_o_impl = fake_impl
|
||||
MyClass.create_boring_impl = fake_impl
|
||||
MyClass.create_boring_o_impl = fake_impl
|
||||
|
||||
expand_args_fields(MyClass)
|
||||
instance = MyClass()
|
||||
for name in ["fruit", "fruit_o", "boring", "boring_o"]:
|
||||
self.assertEqual(
|
||||
hasattr(instance, name), not fake, msg=f"{name} {fake}"
|
||||
)
|
||||
|
||||
self.assertIsNone(instance.fruit_0)
|
||||
self.assertIsNone(instance.boring_0)
|
||||
if not fake:
|
||||
self.assertIsInstance(instance.fruit, Orange)
|
||||
self.assertIsInstance(instance.fruit_o, Orange)
|
||||
self.assertIsInstance(instance.boring, BoringConfigurable)
|
||||
self.assertIsInstance(instance.boring_o, BoringConfigurable)
|
||||
|
||||
self.assertEqual(control_args, ["Orange", "Orange", True, True])
|
||||
|
||||
|
||||
@dataclass(eq=False)
|
||||
class MockDataclass:
|
||||
|
Loading…
x
Reference in New Issue
Block a user