open_dict for tweaking

Summary: Made the config system call open_dict when it calls the tweak function.

Reviewed By: shapovalov

Differential Revision: D38315334

fbshipit-source-id: 5924a92d8d0bf399bbf3788247f81fc990e265e7
This commit is contained in:
Darijan Gudelj 2022-08-03 05:33:46 -07:00 committed by Facebook GitHub Bot
parent c3f8dad55c
commit 5f069dbb7e
4 changed files with 14 additions and 8 deletions

View File

@ -113,9 +113,8 @@ class JsonIndexDatasetMapProvider(DatasetMapProviderBase): # pyre-ignore [13]
Called by get_default_args(JsonIndexDatasetMapProvider) to
not expose certain fields of each dataset class.
"""
with open_dict(args):
for key in _NEED_CONTROL:
del args[key]
for key in _NEED_CONTROL:
del args[key]
def create_dataset(self):
"""

View File

@ -300,9 +300,8 @@ class JsonIndexDatasetMapProviderV2(DatasetMapProviderBase): # pyre-ignore [13]
Called by get_default_args(JsonIndexDatasetMapProviderV2) to
not expose certain fields of each dataset class.
"""
with open_dict(args):
for key in _NEED_CONTROL:
del args[key]
for key in _NEED_CONTROL:
del args[key]
def create_dataset(self):
# The dataset object is created inside `self.get_dataset_map`

View File

@ -881,7 +881,8 @@ def get_default_args_field(
def create():
args = get_default_args(C, _do_not_process=_do_not_process)
if _hook is not None:
_hook(args)
with open_dict(args):
_hook(args)
return args
return dataclasses.field(default_factory=create)
@ -915,7 +916,8 @@ def _get_default_args_field_from_registry(
C = registry.get(base_class_wanted=base_class_wanted, name=name)
args = get_default_args(C, _do_not_process=_do_not_process)
if _hook is not None:
_hook(args)
with open_dict(args):
_hook(args)
return args
return dataclasses.field(default_factory=create)

View File

@ -691,12 +691,17 @@ class TestConfig(unittest.TestCase):
fruit2_class_type: str = "Pear"
a: A
a2: A
a3: A
@classmethod
def a_tweak_args(cls, type, args):
assert type == A
args.n = 993
@classmethod
def a3_tweak_args(cls, type, args):
del args["n"]
@classmethod
def fruit_tweak_args(cls, type, args):
assert issubclass(type, Fruit)
@ -707,6 +712,7 @@ class TestConfig(unittest.TestCase):
args = get_default_args(Wrapper)
self.assertEqual(args.a_args.n, 993)
self.assertEqual(args.a2_args.n, 9)
self.assertEqual(args.a3_args, {})
self.assertEqual(args.fruit_Pear_args.n_pips, 19)
self.assertEqual(args.fruit2_Pear_args.n_pips, 13)