mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 20:02:49 +08:00
redefinition -> defaults kept in config
Summary: This is an internal change in the config systen. It allows redefining a pluggable implementation with new default values. This is useful in notebooks / interactive use. For example, this now works. class A(ReplaceableBase): pass registry.register class B(A): i: int = 4 class C(Configurable): a: A a_class_type: str = "B" def __post_init__(self): run_auto_creation(self) expand_args_fields(C) registry.register class B(A): i: int = 5 c = C() assert c.a.i == 5 Reviewed By: shapovalov Differential Revision: D38219371 fbshipit-source-id: 72911a9bd3426d3359cf8802cc016fc7f6d7713b
This commit is contained in:
parent
cb49550486
commit
6b481595f0
@ -887,6 +887,40 @@ def get_default_args_field(
|
|||||||
return dataclasses.field(default_factory=create)
|
return dataclasses.field(default_factory=create)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_default_args_field_from_registry(
|
||||||
|
*,
|
||||||
|
base_class_wanted: Type[_X],
|
||||||
|
name: str,
|
||||||
|
_do_not_process: Tuple[type, ...] = (),
|
||||||
|
_hook: Optional[Callable[[DictConfig], None]] = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Get a dataclass field which defaults to
|
||||||
|
get_default_args(registry.get(base_class_wanted, name)).
|
||||||
|
|
||||||
|
This is used internally in place of get_default_args_field in
|
||||||
|
order that default values are updated if a class is redefined.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
base_class_wanted: As for registry.get.
|
||||||
|
name: As for registry.get.
|
||||||
|
_do_not_process: As for get_default_args
|
||||||
|
_hook: Function called on the result before returning.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
function to return new DictConfig object
|
||||||
|
"""
|
||||||
|
|
||||||
|
def create():
|
||||||
|
C = registry.get(base_class_wanted=base_class_wanted, name=name)
|
||||||
|
args = get_default_args(C, _do_not_process=_do_not_process)
|
||||||
|
if _hook is not None:
|
||||||
|
_hook(args)
|
||||||
|
return args
|
||||||
|
|
||||||
|
return dataclasses.field(default_factory=create)
|
||||||
|
|
||||||
|
|
||||||
def _get_type_to_process(type_) -> Optional[Tuple[Type, _ProcessType]]:
|
def _get_type_to_process(type_) -> Optional[Tuple[Type, _ProcessType]]:
|
||||||
"""
|
"""
|
||||||
If a member is annotated as `type_`, and that should expanded in
|
If a member is annotated as `type_`, and that should expanded in
|
||||||
@ -978,8 +1012,9 @@ def _process_member(
|
|||||||
setattr(
|
setattr(
|
||||||
some_class,
|
some_class,
|
||||||
args_name,
|
args_name,
|
||||||
get_default_args_field(
|
_get_default_args_field_from_registry(
|
||||||
derived_type,
|
base_class_wanted=type_,
|
||||||
|
name=derived_type.__name__,
|
||||||
_do_not_process=_do_not_process + (some_class,),
|
_do_not_process=_do_not_process + (some_class,),
|
||||||
_hook=hook_closed,
|
_hook=hook_closed,
|
||||||
),
|
),
|
||||||
|
@ -378,14 +378,20 @@ class TestConfig(unittest.TestCase):
|
|||||||
with self.assertWarnsRegex(
|
with self.assertWarnsRegex(
|
||||||
UserWarning, "New implementation of Grape is being chosen."
|
UserWarning, "New implementation of Grape is being chosen."
|
||||||
):
|
):
|
||||||
bowl = FruitBowl(**bowl_args)
|
defaulted_bowl = FruitBowl()
|
||||||
self.assertIsInstance(bowl.main_fruit, Grape)
|
self.assertIsInstance(defaulted_bowl.main_fruit, Grape)
|
||||||
|
self.assertEqual(defaulted_bowl.main_fruit.large, True)
|
||||||
|
self.assertEqual(defaulted_bowl.main_fruit.get_color(), "green")
|
||||||
|
|
||||||
|
with self.assertWarnsRegex(
|
||||||
|
UserWarning, "New implementation of Grape is being chosen."
|
||||||
|
):
|
||||||
|
args_bowl = FruitBowl(**bowl_args)
|
||||||
|
self.assertIsInstance(args_bowl.main_fruit, Grape)
|
||||||
# Redefining the same class won't help with defaults because encoded in args
|
# Redefining the same class won't help with defaults because encoded in args
|
||||||
self.assertEqual(bowl.main_fruit.large, False)
|
self.assertEqual(args_bowl.main_fruit.large, False)
|
||||||
|
|
||||||
# But the override worked.
|
# But the override worked.
|
||||||
self.assertEqual(bowl.main_fruit.get_color(), "green")
|
self.assertEqual(args_bowl.main_fruit.get_color(), "green")
|
||||||
|
|
||||||
# 2. Try redefining without the dataclass modifier
|
# 2. Try redefining without the dataclass modifier
|
||||||
# This relies on the fact that default creation processes the class.
|
# This relies on the fact that default creation processes the class.
|
||||||
@ -397,7 +403,7 @@ class TestConfig(unittest.TestCase):
|
|||||||
with self.assertWarnsRegex(
|
with self.assertWarnsRegex(
|
||||||
UserWarning, "New implementation of Grape is being chosen."
|
UserWarning, "New implementation of Grape is being chosen."
|
||||||
):
|
):
|
||||||
bowl = FruitBowl(**bowl_args)
|
FruitBowl(**bowl_args)
|
||||||
|
|
||||||
# 3. Adding a new class doesn't get picked up, because the first
|
# 3. Adding a new class doesn't get picked up, because the first
|
||||||
# get_default_args call has frozen FruitBowl. This is intrinsic to
|
# get_default_args call has frozen FruitBowl. This is intrinsic to
|
||||||
|
Loading…
x
Reference in New Issue
Block a user