mirror of
				https://github.com/facebookresearch/pytorch3d.git
				synced 2025-11-04 18:02:14 +08:00 
			
		
		
		
	create_x_impl
Summary: Make create_x delegate to create_x_impl so that users can rely on create_x_impl in their overrides of create_x. Reviewed By: shapovalov, davnov134 Differential Revision: D35929810 fbshipit-source-id: 80595894ee93346b881729995775876b016fc08e
This commit is contained in:
		
							parent
							
								
									3b2300641a
								
							
						
					
					
						commit
						899a3192b6
					
				@ -175,6 +175,8 @@ _unprocessed_warning: str = (
 | 
			
		||||
TYPE_SUFFIX: str = "_class_type"
 | 
			
		||||
ARGS_SUFFIX: str = "_args"
 | 
			
		||||
ENABLED_SUFFIX: str = "_enabled"
 | 
			
		||||
CREATE_PREFIX: str = "create_"
 | 
			
		||||
IMPL_SUFFIX: str = "_impl"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ReplaceableBase:
 | 
			
		||||
@ -375,25 +377,68 @@ def _default_create(
 | 
			
		||||
 | 
			
		||||
    Returns:
 | 
			
		||||
        Function taking one argument, the object whose member should be
 | 
			
		||||
            initialized.
 | 
			
		||||
            initialized, i.e. self.
 | 
			
		||||
    """
 | 
			
		||||
    impl_name = f"{CREATE_PREFIX}{name}{IMPL_SUFFIX}"
 | 
			
		||||
 | 
			
		||||
    def inner(self):
 | 
			
		||||
        expand_args_fields(type_)
 | 
			
		||||
        impl = getattr(self, impl_name)
 | 
			
		||||
        args = getattr(self, name + ARGS_SUFFIX)
 | 
			
		||||
        setattr(self, name, type_(**args))
 | 
			
		||||
        impl(True, args)
 | 
			
		||||
 | 
			
		||||
    def inner_optional(self):
 | 
			
		||||
        expand_args_fields(type_)
 | 
			
		||||
        impl = getattr(self, impl_name)
 | 
			
		||||
        enabled = getattr(self, name + ENABLED_SUFFIX)
 | 
			
		||||
        args = getattr(self, name + ARGS_SUFFIX)
 | 
			
		||||
        impl(enabled, args)
 | 
			
		||||
 | 
			
		||||
    def inner_pluggable(self):
 | 
			
		||||
        type_name = getattr(self, name + TYPE_SUFFIX)
 | 
			
		||||
        impl = getattr(self, impl_name)
 | 
			
		||||
        if type_name is None:
 | 
			
		||||
            args = None
 | 
			
		||||
        else:
 | 
			
		||||
            args = getattr(self, f"{name}_{type_name}{ARGS_SUFFIX}", None)
 | 
			
		||||
        impl(type_name, args)
 | 
			
		||||
 | 
			
		||||
    if process_type == _ProcessType.OPTIONAL_CONFIGURABLE:
 | 
			
		||||
        return inner_optional
 | 
			
		||||
    return inner if process_type == _ProcessType.CONFIGURABLE else inner_pluggable
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _default_create_impl(
 | 
			
		||||
    name: str, type_: Type, process_type: _ProcessType
 | 
			
		||||
) -> Callable[[Any, Any, DictConfig], None]:
 | 
			
		||||
    """
 | 
			
		||||
    Return the default internal function for initialising a member. This is a function
 | 
			
		||||
    which could be called in the create_ function to initialise the member.
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
        name: name of the member
 | 
			
		||||
        type_: type of the member (with any Optional removed)
 | 
			
		||||
        process_type: Shows whether member's declared type inherits ReplaceableBase,
 | 
			
		||||
                    in which case the actual type to be created is decided at
 | 
			
		||||
                    runtime.
 | 
			
		||||
 | 
			
		||||
    Returns:
 | 
			
		||||
        Function taking
 | 
			
		||||
            - self, the object whose member should be initialized.
 | 
			
		||||
            - option for what to do. This is
 | 
			
		||||
                - for pluggables, the type to initialise or None to do nothing
 | 
			
		||||
                - for non pluggables, a bool indicating whether to initialise.
 | 
			
		||||
            - the args for initializing the member.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def create_configurable(self, enabled, args):
 | 
			
		||||
        if enabled:
 | 
			
		||||
            args = getattr(self, name + ARGS_SUFFIX)
 | 
			
		||||
            expand_args_fields(type_)
 | 
			
		||||
            setattr(self, name, type_(**args))
 | 
			
		||||
        else:
 | 
			
		||||
            setattr(self, name, None)
 | 
			
		||||
 | 
			
		||||
    def inner_pluggable(self):
 | 
			
		||||
        type_name = getattr(self, name + TYPE_SUFFIX)
 | 
			
		||||
    def create_pluggable(self, type_name, args):
 | 
			
		||||
        if type_name is None:
 | 
			
		||||
            setattr(self, name, None)
 | 
			
		||||
            return
 | 
			
		||||
@ -408,12 +453,11 @@ def _default_create(
 | 
			
		||||
            # were made in the redefinition will not be reflected here.
 | 
			
		||||
            warnings.warn(f"New implementation of {type_name} is being chosen.")
 | 
			
		||||
        expand_args_fields(chosen_class)
 | 
			
		||||
        args = getattr(self, f"{name}_{type_name}{ARGS_SUFFIX}")
 | 
			
		||||
        setattr(self, name, chosen_class(**args))
 | 
			
		||||
 | 
			
		||||
    if process_type == _ProcessType.OPTIONAL_CONFIGURABLE:
 | 
			
		||||
        return inner_optional
 | 
			
		||||
    return inner if process_type == _ProcessType.CONFIGURABLE else inner_pluggable
 | 
			
		||||
    if process_type in (_ProcessType.CONFIGURABLE, _ProcessType.OPTIONAL_CONFIGURABLE):
 | 
			
		||||
        return create_configurable
 | 
			
		||||
    return create_pluggable
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def run_auto_creation(self: Any) -> None:
 | 
			
		||||
@ -628,11 +672,12 @@ def expand_args_fields(
 | 
			
		||||
        x_Y_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(Y))
 | 
			
		||||
        x_Z_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(Z))
 | 
			
		||||
        def create_x(self):
 | 
			
		||||
            x_type = registry.get(X, self.x_class_type)
 | 
			
		||||
            args = self.getattr(f"x_{self.x_class_type}_args")
 | 
			
		||||
            self.create_x_impl(self.x_class_type, args)
 | 
			
		||||
        def create_x_impl(self, x_type, args):
 | 
			
		||||
            x_type = registry.get(X, x_type)
 | 
			
		||||
            expand_args_fields(x_type)
 | 
			
		||||
            self.x = x_type(
 | 
			
		||||
                **self.getattr(f"x_{self.x_class_type}_args)
 | 
			
		||||
            )
 | 
			
		||||
            self.x = x_type(**args)
 | 
			
		||||
        x_class_type: str = "UNDEFAULTED"
 | 
			
		||||
 | 
			
		||||
    without adding the optional attributes if they are already there.
 | 
			
		||||
@ -652,14 +697,19 @@ def expand_args_fields(
 | 
			
		||||
        x_Z_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(Z))
 | 
			
		||||
        def create_x(self):
 | 
			
		||||
            if self.x_class_type is None:
 | 
			
		||||
                args = None
 | 
			
		||||
            else:
 | 
			
		||||
                args = self.getattr(f"x_{self.x_class_type}_args", None)
 | 
			
		||||
            self.create_x_impl(self.x_class_type, args)
 | 
			
		||||
        def create_x_impl(self, x_class_type, args):
 | 
			
		||||
            if x_class_type is None:
 | 
			
		||||
                self.x = None
 | 
			
		||||
                return
 | 
			
		||||
 | 
			
		||||
            x_type = registry.get(X, self.x_class_type)
 | 
			
		||||
            x_type = registry.get(X, x_class_type)
 | 
			
		||||
            expand_args_fields(x_type)
 | 
			
		||||
            self.x = x_type(
 | 
			
		||||
                **self.getattr(f"x_{self.x_class_type}_args)
 | 
			
		||||
            )
 | 
			
		||||
            assert args is not None
 | 
			
		||||
            self.x = x_type(**args)
 | 
			
		||||
        x_class_type: Optional[str] = "UNDEFAULTED"
 | 
			
		||||
 | 
			
		||||
    without adding the optional attributes if they are already there.
 | 
			
		||||
@ -676,8 +726,14 @@ def expand_args_fields(
 | 
			
		||||
 | 
			
		||||
        x_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(X))
 | 
			
		||||
        def create_x(self):
 | 
			
		||||
            expand_args_fields(X)
 | 
			
		||||
            self.x = X(self.x_args)
 | 
			
		||||
            self.create_x_impl(True, self.x_args)
 | 
			
		||||
 | 
			
		||||
        def create_x_impl(self, enabled, args):
 | 
			
		||||
            if enabled:
 | 
			
		||||
                expand_args_fields(X)
 | 
			
		||||
                self.x = X(**args)
 | 
			
		||||
            else:
 | 
			
		||||
                self.x = None
 | 
			
		||||
 | 
			
		||||
    Similarly, replace,
 | 
			
		||||
 | 
			
		||||
@ -693,9 +749,12 @@ def expand_args_fields(
 | 
			
		||||
        x_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(X))
 | 
			
		||||
        x_enabled: bool = False
 | 
			
		||||
        def create_x(self):
 | 
			
		||||
            if self.x_enabled:
 | 
			
		||||
            self.create_x_impl(self.x_enabled, self.x_args)
 | 
			
		||||
 | 
			
		||||
        def create_x_impl(self, enabled, args):
 | 
			
		||||
            if enabled:
 | 
			
		||||
                expand_args_fields(X)
 | 
			
		||||
                self.x = X(self.x_args)
 | 
			
		||||
                self.x = X(**args)
 | 
			
		||||
            else:
 | 
			
		||||
                self.x = None
 | 
			
		||||
 | 
			
		||||
@ -703,7 +762,7 @@ def expand_args_fields(
 | 
			
		||||
    Also adds the following class members, unannotated so that dataclass
 | 
			
		||||
    ignores them.
 | 
			
		||||
        - _creation_functions: Tuple[str] of all the create_ functions,
 | 
			
		||||
            including those from base classes.
 | 
			
		||||
            including those from base classes (not the create_x_impl ones).
 | 
			
		||||
        - _known_implementations: Dict[str, Type] containing the classes which
 | 
			
		||||
            have been found from the registry.
 | 
			
		||||
            (used only to raise a warning if it one has been overwritten)
 | 
			
		||||
@ -918,7 +977,7 @@ def _process_member(
 | 
			
		||||
                some_class.__annotations__[enabled_name] = bool
 | 
			
		||||
                setattr(some_class, enabled_name, False)
 | 
			
		||||
 | 
			
		||||
    creation_function_name = f"create_{name}"
 | 
			
		||||
    creation_function_name = f"{CREATE_PREFIX}{name}"
 | 
			
		||||
    if not hasattr(some_class, creation_function_name):
 | 
			
		||||
        setattr(
 | 
			
		||||
            some_class,
 | 
			
		||||
@ -927,6 +986,14 @@ def _process_member(
 | 
			
		||||
        )
 | 
			
		||||
    creation_functions.append(creation_function_name)
 | 
			
		||||
 | 
			
		||||
    creation_function_impl_name = f"{CREATE_PREFIX}{name}{IMPL_SUFFIX}"
 | 
			
		||||
    if not hasattr(some_class, creation_function_impl_name):
 | 
			
		||||
        setattr(
 | 
			
		||||
            some_class,
 | 
			
		||||
            creation_function_impl_name,
 | 
			
		||||
            _default_create_impl(name, type_, process_type),
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def remove_unused_components(dict_: DictConfig) -> None:
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
@ -69,6 +69,10 @@ class LargePear(Pear):
 | 
			
		||||
    pass
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class BoringConfigurable(Configurable):
 | 
			
		||||
    pass
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class MainTest(Configurable):
 | 
			
		||||
    the_fruit: Fruit
 | 
			
		||||
    n_ids: int
 | 
			
		||||
@ -674,6 +678,57 @@ class TestConfig(unittest.TestCase):
 | 
			
		||||
        remove_unused_components(args)
 | 
			
		||||
        self.assertEqual(OmegaConf.to_yaml(args), "mt_enabled: false\n")
 | 
			
		||||
 | 
			
		||||
    def test_impls(self):
 | 
			
		||||
        # Check that create_x actually uses create_x_impl to do its work
 | 
			
		||||
        # by using all the member types, both with a faked impl function
 | 
			
		||||
        # and without.
 | 
			
		||||
        # members with _0 are optional and absent, those with _o are
 | 
			
		||||
        # optional and present.
 | 
			
		||||
        control_args = []
 | 
			
		||||
 | 
			
		||||
        def fake_impl(self, control, args):
 | 
			
		||||
            control_args.append(control)
 | 
			
		||||
 | 
			
		||||
        for fake in [False, True]:
 | 
			
		||||
 | 
			
		||||
            class MyClass(Configurable):
 | 
			
		||||
                fruit: Fruit
 | 
			
		||||
                fruit_class_type: str = "Orange"
 | 
			
		||||
                fruit_o: Optional[Fruit]
 | 
			
		||||
                fruit_o_class_type: str = "Orange"
 | 
			
		||||
                fruit_0: Optional[Fruit]
 | 
			
		||||
                fruit_0_class_type: Optional[str] = None
 | 
			
		||||
                boring: BoringConfigurable
 | 
			
		||||
                boring_o: Optional[BoringConfigurable]
 | 
			
		||||
                boring_o_enabled: bool = True
 | 
			
		||||
                boring_0: Optional[BoringConfigurable]
 | 
			
		||||
 | 
			
		||||
                def __post_init__(self):
 | 
			
		||||
                    run_auto_creation(self)
 | 
			
		||||
 | 
			
		||||
            if fake:
 | 
			
		||||
                MyClass.create_fruit_impl = fake_impl
 | 
			
		||||
                MyClass.create_fruit_o_impl = fake_impl
 | 
			
		||||
                MyClass.create_boring_impl = fake_impl
 | 
			
		||||
                MyClass.create_boring_o_impl = fake_impl
 | 
			
		||||
 | 
			
		||||
            expand_args_fields(MyClass)
 | 
			
		||||
            instance = MyClass()
 | 
			
		||||
            for name in ["fruit", "fruit_o", "boring", "boring_o"]:
 | 
			
		||||
                self.assertEqual(
 | 
			
		||||
                    hasattr(instance, name), not fake, msg=f"{name} {fake}"
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
            self.assertIsNone(instance.fruit_0)
 | 
			
		||||
            self.assertIsNone(instance.boring_0)
 | 
			
		||||
            if not fake:
 | 
			
		||||
                self.assertIsInstance(instance.fruit, Orange)
 | 
			
		||||
                self.assertIsInstance(instance.fruit_o, Orange)
 | 
			
		||||
                self.assertIsInstance(instance.boring, BoringConfigurable)
 | 
			
		||||
                self.assertIsInstance(instance.boring_o, BoringConfigurable)
 | 
			
		||||
 | 
			
		||||
        self.assertEqual(control_args, ["Orange", "Orange", True, True])
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@dataclass(eq=False)
 | 
			
		||||
class MockDataclass:
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user