From 8bc0a04e863b87d58c60efe71c79acab80c93818 Mon Sep 17 00:00:00 2001 From: Jeremy Reizenstein Date: Fri, 10 Jun 2022 12:22:46 -0700 Subject: [PATCH] hooks and allow registering base class Summary: Allow a class to modify its subparts in get_default_args by defining the special function provide_config_hook. Reviewed By: davnov134 Differential Revision: D36671081 fbshipit-source-id: 3e5b73880cb846c494a209c4479835f6352f45cf --- pytorch3d/implicitron/tools/config.py | 47 ++++++++++++++++++++++----- tests/implicitron/test_config.py | 30 +++++++++++++++++ 2 files changed, 68 insertions(+), 9 deletions(-) diff --git a/pytorch3d/implicitron/tools/config.py b/pytorch3d/implicitron/tools/config.py index eaf91151..11e344e0 100644 --- a/pytorch3d/implicitron/tools/config.py +++ b/pytorch3d/implicitron/tools/config.py @@ -11,6 +11,7 @@ import sys import warnings from collections import Counter, defaultdict from enum import Enum +from functools import partial from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union from omegaconf import DictConfig, OmegaConf, open_dict @@ -177,6 +178,7 @@ ARGS_SUFFIX: str = "_args" ENABLED_SUFFIX: str = "_enabled" CREATE_PREFIX: str = "create_" IMPL_SUFFIX: str = "_impl" +TWEAK_SUFFIX: str = "_tweak_args" class ReplaceableBase: @@ -261,13 +263,9 @@ class _Registry: raise ValueError( f"Cannot register {some_class}. Cannot tell what it is." ) - if some_class is base_class: - raise ValueError(f"Attempted to register the base class {some_class}") self._mapping[base_class][name] = some_class - def get( - self, base_class_wanted: Type[ReplaceableBase], name: str - ) -> Type[ReplaceableBase]: + def get(self, base_class_wanted: Type[_X], name: str) -> Type[_X]: """ Retrieve a class from the registry by name @@ -295,6 +293,7 @@ class _Registry: raise ValueError( f"{name} resolves to {result} which does not subclass {base_class_wanted}" ) + # pyre-ignore[7] return result def get_all( @@ -773,6 +772,14 @@ def expand_args_fields( transformed, with values giving the types they were declared to have. (E.g. {"x": X} or {"x": Optional[X]} in the cases above.) + In addition, if the class has a member function + + @classmethod + def x_tweak_args(cls, member_type: Type, args: DictConfig) -> None + + then the default_factory of x_args will also have a call to x_tweak_args(X, x_args) and + the default_factory of x_Y_args will also have a call to x_tweak_args(Y, x_Y_args). + Args: some_class: the class to be processed _do_not_process: Internal use for get_default_args: Because get_default_args calls @@ -849,19 +856,29 @@ def expand_args_fields( return some_class -def get_default_args_field(C, *, _do_not_process: Tuple[type, ...] = ()): +def get_default_args_field( + C, + *, + _do_not_process: Tuple[type, ...] = (), + _hook: Optional[Callable[[DictConfig], None]] = None, +): """ Get a dataclass field which defaults to get_default_args(...) Args: - As for get_default_args. + C: As for get_default_args. + _do_not_process: As for get_default_args + _hook: Function called on the result before returning. Returns: function to return new DictConfig object """ def create(): - return get_default_args(C, _do_not_process=_do_not_process) + args = get_default_args(C, _do_not_process=_do_not_process) + if _hook is not None: + _hook(args) + return args return dataclasses.field(default_factory=create) @@ -924,6 +941,7 @@ def _process_member( # sure they go at the end of __annotations__ in case # there are non-defaulted standard class members. del some_class.__annotations__[name] + hook = getattr(some_class, name + TWEAK_SUFFIX, None) if process_type in (_ProcessType.REPLACEABLE, _ProcessType.OPTIONAL_REPLACEABLE): type_name = name + TYPE_SUFFIX @@ -949,11 +967,17 @@ def _process_member( f"Cannot generate {args_name} because it is already present." ) some_class.__annotations__[args_name] = DictConfig + if hook is not None: + hook_closed = partial(hook, derived_type) + else: + hook_closed = None setattr( some_class, args_name, get_default_args_field( - derived_type, _do_not_process=_do_not_process + (some_class,) + derived_type, + _do_not_process=_do_not_process + (some_class,), + _hook=hook_closed, ), ) else: @@ -966,12 +990,17 @@ def _process_member( raise ValueError(f"Cannot process {type_} inside {some_class}") some_class.__annotations__[args_name] = DictConfig + if hook is not None: + hook_closed = partial(hook, type_) + else: + hook_closed = None setattr( some_class, args_name, get_default_args_field( type_, _do_not_process=_do_not_process + (some_class,), + _hook=hook_closed, ), ) if process_type == _ProcessType.OPTIONAL_CONFIGURABLE: diff --git a/tests/implicitron/test_config.py b/tests/implicitron/test_config.py index 2f591edd..626d0a0f 100644 --- a/tests/implicitron/test_config.py +++ b/tests/implicitron/test_config.py @@ -678,6 +678,36 @@ class TestConfig(unittest.TestCase): remove_unused_components(args) self.assertEqual(OmegaConf.to_yaml(args), "mt_enabled: false\n") + def test_tweak_hook(self): + class A(Configurable): + n: int = 9 + + class Wrapper(Configurable): + fruit: Fruit + fruit_class_type: str = "Pear" + fruit2: Fruit + fruit2_class_type: str = "Pear" + a: A + a2: A + + @classmethod + def a_tweak_args(cls, type, args): + assert type == A + args.n = 993 + + @classmethod + def fruit_tweak_args(cls, type, args): + assert issubclass(type, Fruit) + if type == Pear: + assert args.n_pips == 13 + args.n_pips = 19 + + args = get_default_args(Wrapper) + self.assertEqual(args.a_args.n, 993) + self.assertEqual(args.a2_args.n, 9) + self.assertEqual(args.fruit_Pear_args.n_pips, 19) + self.assertEqual(args.fruit2_Pear_args.n_pips, 13) + def test_impls(self): # Check that create_x actually uses create_x_impl to do its work # by using all the member types, both with a faked impl function