mirror of
				https://github.com/facebookresearch/pytorch3d.git
				synced 2025-11-04 18:02:14 +08:00 
			
		
		
		
	hooks and allow registering base class
Summary: Allow a class to modify its subparts in get_default_args by defining the special function provide_config_hook. Reviewed By: davnov134 Differential Revision: D36671081 fbshipit-source-id: 3e5b73880cb846c494a209c4479835f6352f45cf
This commit is contained in:
		
							parent
							
								
									5cd70067e2
								
							
						
					
					
						commit
						8bc0a04e86
					
				@ -11,6 +11,7 @@ import sys
 | 
			
		||||
import warnings
 | 
			
		||||
from collections import Counter, defaultdict
 | 
			
		||||
from enum import Enum
 | 
			
		||||
from functools import partial
 | 
			
		||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union
 | 
			
		||||
 | 
			
		||||
from omegaconf import DictConfig, OmegaConf, open_dict
 | 
			
		||||
@ -177,6 +178,7 @@ ARGS_SUFFIX: str = "_args"
 | 
			
		||||
ENABLED_SUFFIX: str = "_enabled"
 | 
			
		||||
CREATE_PREFIX: str = "create_"
 | 
			
		||||
IMPL_SUFFIX: str = "_impl"
 | 
			
		||||
TWEAK_SUFFIX: str = "_tweak_args"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ReplaceableBase:
 | 
			
		||||
@ -261,13 +263,9 @@ class _Registry:
 | 
			
		||||
                raise ValueError(
 | 
			
		||||
                    f"Cannot register {some_class}. Cannot tell what it is."
 | 
			
		||||
                )
 | 
			
		||||
        if some_class is base_class:
 | 
			
		||||
            raise ValueError(f"Attempted to register the base class {some_class}")
 | 
			
		||||
        self._mapping[base_class][name] = some_class
 | 
			
		||||
 | 
			
		||||
    def get(
 | 
			
		||||
        self, base_class_wanted: Type[ReplaceableBase], name: str
 | 
			
		||||
    ) -> Type[ReplaceableBase]:
 | 
			
		||||
    def get(self, base_class_wanted: Type[_X], name: str) -> Type[_X]:
 | 
			
		||||
        """
 | 
			
		||||
        Retrieve a class from the registry by name
 | 
			
		||||
 | 
			
		||||
@ -295,6 +293,7 @@ class _Registry:
 | 
			
		||||
            raise ValueError(
 | 
			
		||||
                f"{name} resolves to {result} which does not subclass {base_class_wanted}"
 | 
			
		||||
            )
 | 
			
		||||
        # pyre-ignore[7]
 | 
			
		||||
        return result
 | 
			
		||||
 | 
			
		||||
    def get_all(
 | 
			
		||||
@ -773,6 +772,14 @@ def expand_args_fields(
 | 
			
		||||
            transformed, with values giving the types they were declared to have.
 | 
			
		||||
            (E.g. {"x": X} or {"x": Optional[X]} in the cases above.)
 | 
			
		||||
 | 
			
		||||
    In addition, if the class has a member function
 | 
			
		||||
 | 
			
		||||
        @classmethod
 | 
			
		||||
        def x_tweak_args(cls, member_type: Type, args: DictConfig) -> None
 | 
			
		||||
 | 
			
		||||
    then the default_factory of x_args will also have a call to x_tweak_args(X, x_args) and
 | 
			
		||||
    the default_factory of x_Y_args will also have a call to x_tweak_args(Y, x_Y_args).
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
        some_class: the class to be processed
 | 
			
		||||
        _do_not_process: Internal use for get_default_args: Because get_default_args calls
 | 
			
		||||
@ -849,19 +856,29 @@ def expand_args_fields(
 | 
			
		||||
    return some_class
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_default_args_field(C, *, _do_not_process: Tuple[type, ...] = ()):
 | 
			
		||||
def get_default_args_field(
 | 
			
		||||
    C,
 | 
			
		||||
    *,
 | 
			
		||||
    _do_not_process: Tuple[type, ...] = (),
 | 
			
		||||
    _hook: Optional[Callable[[DictConfig], None]] = None,
 | 
			
		||||
):
 | 
			
		||||
    """
 | 
			
		||||
    Get a dataclass field which defaults to get_default_args(...)
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
        As for get_default_args.
 | 
			
		||||
        C: As for get_default_args.
 | 
			
		||||
        _do_not_process: As for get_default_args
 | 
			
		||||
        _hook: Function called on the result before returning.
 | 
			
		||||
 | 
			
		||||
    Returns:
 | 
			
		||||
        function to return new DictConfig object
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def create():
 | 
			
		||||
        return get_default_args(C, _do_not_process=_do_not_process)
 | 
			
		||||
        args = get_default_args(C, _do_not_process=_do_not_process)
 | 
			
		||||
        if _hook is not None:
 | 
			
		||||
            _hook(args)
 | 
			
		||||
        return args
 | 
			
		||||
 | 
			
		||||
    return dataclasses.field(default_factory=create)
 | 
			
		||||
 | 
			
		||||
@ -924,6 +941,7 @@ def _process_member(
 | 
			
		||||
    # sure they go at the end of __annotations__ in case
 | 
			
		||||
    # there are non-defaulted standard class members.
 | 
			
		||||
    del some_class.__annotations__[name]
 | 
			
		||||
    hook = getattr(some_class, name + TWEAK_SUFFIX, None)
 | 
			
		||||
 | 
			
		||||
    if process_type in (_ProcessType.REPLACEABLE, _ProcessType.OPTIONAL_REPLACEABLE):
 | 
			
		||||
        type_name = name + TYPE_SUFFIX
 | 
			
		||||
@ -949,11 +967,17 @@ def _process_member(
 | 
			
		||||
                    f"Cannot generate {args_name} because it is already present."
 | 
			
		||||
                )
 | 
			
		||||
            some_class.__annotations__[args_name] = DictConfig
 | 
			
		||||
            if hook is not None:
 | 
			
		||||
                hook_closed = partial(hook, derived_type)
 | 
			
		||||
            else:
 | 
			
		||||
                hook_closed = None
 | 
			
		||||
            setattr(
 | 
			
		||||
                some_class,
 | 
			
		||||
                args_name,
 | 
			
		||||
                get_default_args_field(
 | 
			
		||||
                    derived_type, _do_not_process=_do_not_process + (some_class,)
 | 
			
		||||
                    derived_type,
 | 
			
		||||
                    _do_not_process=_do_not_process + (some_class,),
 | 
			
		||||
                    _hook=hook_closed,
 | 
			
		||||
                ),
 | 
			
		||||
            )
 | 
			
		||||
    else:
 | 
			
		||||
@ -966,12 +990,17 @@ def _process_member(
 | 
			
		||||
            raise ValueError(f"Cannot process {type_} inside {some_class}")
 | 
			
		||||
 | 
			
		||||
        some_class.__annotations__[args_name] = DictConfig
 | 
			
		||||
        if hook is not None:
 | 
			
		||||
            hook_closed = partial(hook, type_)
 | 
			
		||||
        else:
 | 
			
		||||
            hook_closed = None
 | 
			
		||||
        setattr(
 | 
			
		||||
            some_class,
 | 
			
		||||
            args_name,
 | 
			
		||||
            get_default_args_field(
 | 
			
		||||
                type_,
 | 
			
		||||
                _do_not_process=_do_not_process + (some_class,),
 | 
			
		||||
                _hook=hook_closed,
 | 
			
		||||
            ),
 | 
			
		||||
        )
 | 
			
		||||
        if process_type == _ProcessType.OPTIONAL_CONFIGURABLE:
 | 
			
		||||
 | 
			
		||||
@ -678,6 +678,36 @@ class TestConfig(unittest.TestCase):
 | 
			
		||||
        remove_unused_components(args)
 | 
			
		||||
        self.assertEqual(OmegaConf.to_yaml(args), "mt_enabled: false\n")
 | 
			
		||||
 | 
			
		||||
    def test_tweak_hook(self):
 | 
			
		||||
        class A(Configurable):
 | 
			
		||||
            n: int = 9
 | 
			
		||||
 | 
			
		||||
        class Wrapper(Configurable):
 | 
			
		||||
            fruit: Fruit
 | 
			
		||||
            fruit_class_type: str = "Pear"
 | 
			
		||||
            fruit2: Fruit
 | 
			
		||||
            fruit2_class_type: str = "Pear"
 | 
			
		||||
            a: A
 | 
			
		||||
            a2: A
 | 
			
		||||
 | 
			
		||||
            @classmethod
 | 
			
		||||
            def a_tweak_args(cls, type, args):
 | 
			
		||||
                assert type == A
 | 
			
		||||
                args.n = 993
 | 
			
		||||
 | 
			
		||||
            @classmethod
 | 
			
		||||
            def fruit_tweak_args(cls, type, args):
 | 
			
		||||
                assert issubclass(type, Fruit)
 | 
			
		||||
                if type == Pear:
 | 
			
		||||
                    assert args.n_pips == 13
 | 
			
		||||
                    args.n_pips = 19
 | 
			
		||||
 | 
			
		||||
        args = get_default_args(Wrapper)
 | 
			
		||||
        self.assertEqual(args.a_args.n, 993)
 | 
			
		||||
        self.assertEqual(args.a2_args.n, 9)
 | 
			
		||||
        self.assertEqual(args.fruit_Pear_args.n_pips, 19)
 | 
			
		||||
        self.assertEqual(args.fruit2_Pear_args.n_pips, 13)
 | 
			
		||||
 | 
			
		||||
    def test_impls(self):
 | 
			
		||||
        # Check that create_x actually uses create_x_impl to do its work
 | 
			
		||||
        # by using all the member types, both with a faked impl function
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user