mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 20:02:49 +08:00
hooks and allow registering base class
Summary: Allow a class to modify its subparts in get_default_args by defining the special function provide_config_hook. Reviewed By: davnov134 Differential Revision: D36671081 fbshipit-source-id: 3e5b73880cb846c494a209c4479835f6352f45cf
This commit is contained in:
parent
5cd70067e2
commit
8bc0a04e86
@ -11,6 +11,7 @@ import sys
|
|||||||
import warnings
|
import warnings
|
||||||
from collections import Counter, defaultdict
|
from collections import Counter, defaultdict
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
from functools import partial
|
||||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union
|
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union
|
||||||
|
|
||||||
from omegaconf import DictConfig, OmegaConf, open_dict
|
from omegaconf import DictConfig, OmegaConf, open_dict
|
||||||
@ -177,6 +178,7 @@ ARGS_SUFFIX: str = "_args"
|
|||||||
ENABLED_SUFFIX: str = "_enabled"
|
ENABLED_SUFFIX: str = "_enabled"
|
||||||
CREATE_PREFIX: str = "create_"
|
CREATE_PREFIX: str = "create_"
|
||||||
IMPL_SUFFIX: str = "_impl"
|
IMPL_SUFFIX: str = "_impl"
|
||||||
|
TWEAK_SUFFIX: str = "_tweak_args"
|
||||||
|
|
||||||
|
|
||||||
class ReplaceableBase:
|
class ReplaceableBase:
|
||||||
@ -261,13 +263,9 @@ class _Registry:
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Cannot register {some_class}. Cannot tell what it is."
|
f"Cannot register {some_class}. Cannot tell what it is."
|
||||||
)
|
)
|
||||||
if some_class is base_class:
|
|
||||||
raise ValueError(f"Attempted to register the base class {some_class}")
|
|
||||||
self._mapping[base_class][name] = some_class
|
self._mapping[base_class][name] = some_class
|
||||||
|
|
||||||
def get(
|
def get(self, base_class_wanted: Type[_X], name: str) -> Type[_X]:
|
||||||
self, base_class_wanted: Type[ReplaceableBase], name: str
|
|
||||||
) -> Type[ReplaceableBase]:
|
|
||||||
"""
|
"""
|
||||||
Retrieve a class from the registry by name
|
Retrieve a class from the registry by name
|
||||||
|
|
||||||
@ -295,6 +293,7 @@ class _Registry:
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"{name} resolves to {result} which does not subclass {base_class_wanted}"
|
f"{name} resolves to {result} which does not subclass {base_class_wanted}"
|
||||||
)
|
)
|
||||||
|
# pyre-ignore[7]
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def get_all(
|
def get_all(
|
||||||
@ -773,6 +772,14 @@ def expand_args_fields(
|
|||||||
transformed, with values giving the types they were declared to have.
|
transformed, with values giving the types they were declared to have.
|
||||||
(E.g. {"x": X} or {"x": Optional[X]} in the cases above.)
|
(E.g. {"x": X} or {"x": Optional[X]} in the cases above.)
|
||||||
|
|
||||||
|
In addition, if the class has a member function
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def x_tweak_args(cls, member_type: Type, args: DictConfig) -> None
|
||||||
|
|
||||||
|
then the default_factory of x_args will also have a call to x_tweak_args(X, x_args) and
|
||||||
|
the default_factory of x_Y_args will also have a call to x_tweak_args(Y, x_Y_args).
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
some_class: the class to be processed
|
some_class: the class to be processed
|
||||||
_do_not_process: Internal use for get_default_args: Because get_default_args calls
|
_do_not_process: Internal use for get_default_args: Because get_default_args calls
|
||||||
@ -849,19 +856,29 @@ def expand_args_fields(
|
|||||||
return some_class
|
return some_class
|
||||||
|
|
||||||
|
|
||||||
def get_default_args_field(C, *, _do_not_process: Tuple[type, ...] = ()):
|
def get_default_args_field(
|
||||||
|
C,
|
||||||
|
*,
|
||||||
|
_do_not_process: Tuple[type, ...] = (),
|
||||||
|
_hook: Optional[Callable[[DictConfig], None]] = None,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Get a dataclass field which defaults to get_default_args(...)
|
Get a dataclass field which defaults to get_default_args(...)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
As for get_default_args.
|
C: As for get_default_args.
|
||||||
|
_do_not_process: As for get_default_args
|
||||||
|
_hook: Function called on the result before returning.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
function to return new DictConfig object
|
function to return new DictConfig object
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def create():
|
def create():
|
||||||
return get_default_args(C, _do_not_process=_do_not_process)
|
args = get_default_args(C, _do_not_process=_do_not_process)
|
||||||
|
if _hook is not None:
|
||||||
|
_hook(args)
|
||||||
|
return args
|
||||||
|
|
||||||
return dataclasses.field(default_factory=create)
|
return dataclasses.field(default_factory=create)
|
||||||
|
|
||||||
@ -924,6 +941,7 @@ def _process_member(
|
|||||||
# sure they go at the end of __annotations__ in case
|
# sure they go at the end of __annotations__ in case
|
||||||
# there are non-defaulted standard class members.
|
# there are non-defaulted standard class members.
|
||||||
del some_class.__annotations__[name]
|
del some_class.__annotations__[name]
|
||||||
|
hook = getattr(some_class, name + TWEAK_SUFFIX, None)
|
||||||
|
|
||||||
if process_type in (_ProcessType.REPLACEABLE, _ProcessType.OPTIONAL_REPLACEABLE):
|
if process_type in (_ProcessType.REPLACEABLE, _ProcessType.OPTIONAL_REPLACEABLE):
|
||||||
type_name = name + TYPE_SUFFIX
|
type_name = name + TYPE_SUFFIX
|
||||||
@ -949,11 +967,17 @@ def _process_member(
|
|||||||
f"Cannot generate {args_name} because it is already present."
|
f"Cannot generate {args_name} because it is already present."
|
||||||
)
|
)
|
||||||
some_class.__annotations__[args_name] = DictConfig
|
some_class.__annotations__[args_name] = DictConfig
|
||||||
|
if hook is not None:
|
||||||
|
hook_closed = partial(hook, derived_type)
|
||||||
|
else:
|
||||||
|
hook_closed = None
|
||||||
setattr(
|
setattr(
|
||||||
some_class,
|
some_class,
|
||||||
args_name,
|
args_name,
|
||||||
get_default_args_field(
|
get_default_args_field(
|
||||||
derived_type, _do_not_process=_do_not_process + (some_class,)
|
derived_type,
|
||||||
|
_do_not_process=_do_not_process + (some_class,),
|
||||||
|
_hook=hook_closed,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@ -966,12 +990,17 @@ def _process_member(
|
|||||||
raise ValueError(f"Cannot process {type_} inside {some_class}")
|
raise ValueError(f"Cannot process {type_} inside {some_class}")
|
||||||
|
|
||||||
some_class.__annotations__[args_name] = DictConfig
|
some_class.__annotations__[args_name] = DictConfig
|
||||||
|
if hook is not None:
|
||||||
|
hook_closed = partial(hook, type_)
|
||||||
|
else:
|
||||||
|
hook_closed = None
|
||||||
setattr(
|
setattr(
|
||||||
some_class,
|
some_class,
|
||||||
args_name,
|
args_name,
|
||||||
get_default_args_field(
|
get_default_args_field(
|
||||||
type_,
|
type_,
|
||||||
_do_not_process=_do_not_process + (some_class,),
|
_do_not_process=_do_not_process + (some_class,),
|
||||||
|
_hook=hook_closed,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
if process_type == _ProcessType.OPTIONAL_CONFIGURABLE:
|
if process_type == _ProcessType.OPTIONAL_CONFIGURABLE:
|
||||||
|
@ -678,6 +678,36 @@ class TestConfig(unittest.TestCase):
|
|||||||
remove_unused_components(args)
|
remove_unused_components(args)
|
||||||
self.assertEqual(OmegaConf.to_yaml(args), "mt_enabled: false\n")
|
self.assertEqual(OmegaConf.to_yaml(args), "mt_enabled: false\n")
|
||||||
|
|
||||||
|
def test_tweak_hook(self):
|
||||||
|
class A(Configurable):
|
||||||
|
n: int = 9
|
||||||
|
|
||||||
|
class Wrapper(Configurable):
|
||||||
|
fruit: Fruit
|
||||||
|
fruit_class_type: str = "Pear"
|
||||||
|
fruit2: Fruit
|
||||||
|
fruit2_class_type: str = "Pear"
|
||||||
|
a: A
|
||||||
|
a2: A
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def a_tweak_args(cls, type, args):
|
||||||
|
assert type == A
|
||||||
|
args.n = 993
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def fruit_tweak_args(cls, type, args):
|
||||||
|
assert issubclass(type, Fruit)
|
||||||
|
if type == Pear:
|
||||||
|
assert args.n_pips == 13
|
||||||
|
args.n_pips = 19
|
||||||
|
|
||||||
|
args = get_default_args(Wrapper)
|
||||||
|
self.assertEqual(args.a_args.n, 993)
|
||||||
|
self.assertEqual(args.a2_args.n, 9)
|
||||||
|
self.assertEqual(args.fruit_Pear_args.n_pips, 19)
|
||||||
|
self.assertEqual(args.fruit2_Pear_args.n_pips, 13)
|
||||||
|
|
||||||
def test_impls(self):
|
def test_impls(self):
|
||||||
# Check that create_x actually uses create_x_impl to do its work
|
# Check that create_x actually uses create_x_impl to do its work
|
||||||
# by using all the member types, both with a faked impl function
|
# by using all the member types, both with a faked impl function
|
||||||
|
Loading…
x
Reference in New Issue
Block a user