mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
hooks and allow registering base class
Summary: Allow a class to modify its subparts in get_default_args by defining the special function provide_config_hook. Reviewed By: davnov134 Differential Revision: D36671081 fbshipit-source-id: 3e5b73880cb846c494a209c4479835f6352f45cf
This commit is contained in:
parent
5cd70067e2
commit
8bc0a04e86
@ -11,6 +11,7 @@ import sys
|
||||
import warnings
|
||||
from collections import Counter, defaultdict
|
||||
from enum import Enum
|
||||
from functools import partial
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union
|
||||
|
||||
from omegaconf import DictConfig, OmegaConf, open_dict
|
||||
@ -177,6 +178,7 @@ ARGS_SUFFIX: str = "_args"
|
||||
ENABLED_SUFFIX: str = "_enabled"
|
||||
CREATE_PREFIX: str = "create_"
|
||||
IMPL_SUFFIX: str = "_impl"
|
||||
TWEAK_SUFFIX: str = "_tweak_args"
|
||||
|
||||
|
||||
class ReplaceableBase:
|
||||
@ -261,13 +263,9 @@ class _Registry:
|
||||
raise ValueError(
|
||||
f"Cannot register {some_class}. Cannot tell what it is."
|
||||
)
|
||||
if some_class is base_class:
|
||||
raise ValueError(f"Attempted to register the base class {some_class}")
|
||||
self._mapping[base_class][name] = some_class
|
||||
|
||||
def get(
|
||||
self, base_class_wanted: Type[ReplaceableBase], name: str
|
||||
) -> Type[ReplaceableBase]:
|
||||
def get(self, base_class_wanted: Type[_X], name: str) -> Type[_X]:
|
||||
"""
|
||||
Retrieve a class from the registry by name
|
||||
|
||||
@ -295,6 +293,7 @@ class _Registry:
|
||||
raise ValueError(
|
||||
f"{name} resolves to {result} which does not subclass {base_class_wanted}"
|
||||
)
|
||||
# pyre-ignore[7]
|
||||
return result
|
||||
|
||||
def get_all(
|
||||
@ -773,6 +772,14 @@ def expand_args_fields(
|
||||
transformed, with values giving the types they were declared to have.
|
||||
(E.g. {"x": X} or {"x": Optional[X]} in the cases above.)
|
||||
|
||||
In addition, if the class has a member function
|
||||
|
||||
@classmethod
|
||||
def x_tweak_args(cls, member_type: Type, args: DictConfig) -> None
|
||||
|
||||
then the default_factory of x_args will also have a call to x_tweak_args(X, x_args) and
|
||||
the default_factory of x_Y_args will also have a call to x_tweak_args(Y, x_Y_args).
|
||||
|
||||
Args:
|
||||
some_class: the class to be processed
|
||||
_do_not_process: Internal use for get_default_args: Because get_default_args calls
|
||||
@ -849,19 +856,29 @@ def expand_args_fields(
|
||||
return some_class
|
||||
|
||||
|
||||
def get_default_args_field(C, *, _do_not_process: Tuple[type, ...] = ()):
|
||||
def get_default_args_field(
|
||||
C,
|
||||
*,
|
||||
_do_not_process: Tuple[type, ...] = (),
|
||||
_hook: Optional[Callable[[DictConfig], None]] = None,
|
||||
):
|
||||
"""
|
||||
Get a dataclass field which defaults to get_default_args(...)
|
||||
|
||||
Args:
|
||||
As for get_default_args.
|
||||
C: As for get_default_args.
|
||||
_do_not_process: As for get_default_args
|
||||
_hook: Function called on the result before returning.
|
||||
|
||||
Returns:
|
||||
function to return new DictConfig object
|
||||
"""
|
||||
|
||||
def create():
|
||||
return get_default_args(C, _do_not_process=_do_not_process)
|
||||
args = get_default_args(C, _do_not_process=_do_not_process)
|
||||
if _hook is not None:
|
||||
_hook(args)
|
||||
return args
|
||||
|
||||
return dataclasses.field(default_factory=create)
|
||||
|
||||
@ -924,6 +941,7 @@ def _process_member(
|
||||
# sure they go at the end of __annotations__ in case
|
||||
# there are non-defaulted standard class members.
|
||||
del some_class.__annotations__[name]
|
||||
hook = getattr(some_class, name + TWEAK_SUFFIX, None)
|
||||
|
||||
if process_type in (_ProcessType.REPLACEABLE, _ProcessType.OPTIONAL_REPLACEABLE):
|
||||
type_name = name + TYPE_SUFFIX
|
||||
@ -949,11 +967,17 @@ def _process_member(
|
||||
f"Cannot generate {args_name} because it is already present."
|
||||
)
|
||||
some_class.__annotations__[args_name] = DictConfig
|
||||
if hook is not None:
|
||||
hook_closed = partial(hook, derived_type)
|
||||
else:
|
||||
hook_closed = None
|
||||
setattr(
|
||||
some_class,
|
||||
args_name,
|
||||
get_default_args_field(
|
||||
derived_type, _do_not_process=_do_not_process + (some_class,)
|
||||
derived_type,
|
||||
_do_not_process=_do_not_process + (some_class,),
|
||||
_hook=hook_closed,
|
||||
),
|
||||
)
|
||||
else:
|
||||
@ -966,12 +990,17 @@ def _process_member(
|
||||
raise ValueError(f"Cannot process {type_} inside {some_class}")
|
||||
|
||||
some_class.__annotations__[args_name] = DictConfig
|
||||
if hook is not None:
|
||||
hook_closed = partial(hook, type_)
|
||||
else:
|
||||
hook_closed = None
|
||||
setattr(
|
||||
some_class,
|
||||
args_name,
|
||||
get_default_args_field(
|
||||
type_,
|
||||
_do_not_process=_do_not_process + (some_class,),
|
||||
_hook=hook_closed,
|
||||
),
|
||||
)
|
||||
if process_type == _ProcessType.OPTIONAL_CONFIGURABLE:
|
||||
|
@ -678,6 +678,36 @@ class TestConfig(unittest.TestCase):
|
||||
remove_unused_components(args)
|
||||
self.assertEqual(OmegaConf.to_yaml(args), "mt_enabled: false\n")
|
||||
|
||||
def test_tweak_hook(self):
|
||||
class A(Configurable):
|
||||
n: int = 9
|
||||
|
||||
class Wrapper(Configurable):
|
||||
fruit: Fruit
|
||||
fruit_class_type: str = "Pear"
|
||||
fruit2: Fruit
|
||||
fruit2_class_type: str = "Pear"
|
||||
a: A
|
||||
a2: A
|
||||
|
||||
@classmethod
|
||||
def a_tweak_args(cls, type, args):
|
||||
assert type == A
|
||||
args.n = 993
|
||||
|
||||
@classmethod
|
||||
def fruit_tweak_args(cls, type, args):
|
||||
assert issubclass(type, Fruit)
|
||||
if type == Pear:
|
||||
assert args.n_pips == 13
|
||||
args.n_pips = 19
|
||||
|
||||
args = get_default_args(Wrapper)
|
||||
self.assertEqual(args.a_args.n, 993)
|
||||
self.assertEqual(args.a2_args.n, 9)
|
||||
self.assertEqual(args.fruit_Pear_args.n_pips, 19)
|
||||
self.assertEqual(args.fruit2_Pear_args.n_pips, 13)
|
||||
|
||||
def test_impls(self):
|
||||
# Check that create_x actually uses create_x_impl to do its work
|
||||
# by using all the member types, both with a faked impl function
|
||||
|
Loading…
x
Reference in New Issue
Block a user