mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
open_dict for tweaking
Summary: Made the config system call open_dict when it calls the tweak function. Reviewed By: shapovalov Differential Revision: D38315334 fbshipit-source-id: 5924a92d8d0bf399bbf3788247f81fc990e265e7
This commit is contained in:
parent
c3f8dad55c
commit
5f069dbb7e
@ -113,9 +113,8 @@ class JsonIndexDatasetMapProvider(DatasetMapProviderBase): # pyre-ignore [13]
|
||||
Called by get_default_args(JsonIndexDatasetMapProvider) to
|
||||
not expose certain fields of each dataset class.
|
||||
"""
|
||||
with open_dict(args):
|
||||
for key in _NEED_CONTROL:
|
||||
del args[key]
|
||||
for key in _NEED_CONTROL:
|
||||
del args[key]
|
||||
|
||||
def create_dataset(self):
|
||||
"""
|
||||
|
@ -300,9 +300,8 @@ class JsonIndexDatasetMapProviderV2(DatasetMapProviderBase): # pyre-ignore [13]
|
||||
Called by get_default_args(JsonIndexDatasetMapProviderV2) to
|
||||
not expose certain fields of each dataset class.
|
||||
"""
|
||||
with open_dict(args):
|
||||
for key in _NEED_CONTROL:
|
||||
del args[key]
|
||||
for key in _NEED_CONTROL:
|
||||
del args[key]
|
||||
|
||||
def create_dataset(self):
|
||||
# The dataset object is created inside `self.get_dataset_map`
|
||||
|
@ -881,7 +881,8 @@ def get_default_args_field(
|
||||
def create():
|
||||
args = get_default_args(C, _do_not_process=_do_not_process)
|
||||
if _hook is not None:
|
||||
_hook(args)
|
||||
with open_dict(args):
|
||||
_hook(args)
|
||||
return args
|
||||
|
||||
return dataclasses.field(default_factory=create)
|
||||
@ -915,7 +916,8 @@ def _get_default_args_field_from_registry(
|
||||
C = registry.get(base_class_wanted=base_class_wanted, name=name)
|
||||
args = get_default_args(C, _do_not_process=_do_not_process)
|
||||
if _hook is not None:
|
||||
_hook(args)
|
||||
with open_dict(args):
|
||||
_hook(args)
|
||||
return args
|
||||
|
||||
return dataclasses.field(default_factory=create)
|
||||
|
@ -691,12 +691,17 @@ class TestConfig(unittest.TestCase):
|
||||
fruit2_class_type: str = "Pear"
|
||||
a: A
|
||||
a2: A
|
||||
a3: A
|
||||
|
||||
@classmethod
|
||||
def a_tweak_args(cls, type, args):
|
||||
assert type == A
|
||||
args.n = 993
|
||||
|
||||
@classmethod
|
||||
def a3_tweak_args(cls, type, args):
|
||||
del args["n"]
|
||||
|
||||
@classmethod
|
||||
def fruit_tweak_args(cls, type, args):
|
||||
assert issubclass(type, Fruit)
|
||||
@ -707,6 +712,7 @@ class TestConfig(unittest.TestCase):
|
||||
args = get_default_args(Wrapper)
|
||||
self.assertEqual(args.a_args.n, 993)
|
||||
self.assertEqual(args.a2_args.n, 9)
|
||||
self.assertEqual(args.a3_args, {})
|
||||
self.assertEqual(args.fruit_Pear_args.n_pips, 19)
|
||||
self.assertEqual(args.fruit2_Pear_args.n_pips, 13)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user