mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 20:02:49 +08:00
open_dict for tweaking
Summary: Made the config system call open_dict when it calls the tweak function. Reviewed By: shapovalov Differential Revision: D38315334 fbshipit-source-id: 5924a92d8d0bf399bbf3788247f81fc990e265e7
This commit is contained in:
parent
c3f8dad55c
commit
5f069dbb7e
@ -113,7 +113,6 @@ class JsonIndexDatasetMapProvider(DatasetMapProviderBase): # pyre-ignore [13]
|
|||||||
Called by get_default_args(JsonIndexDatasetMapProvider) to
|
Called by get_default_args(JsonIndexDatasetMapProvider) to
|
||||||
not expose certain fields of each dataset class.
|
not expose certain fields of each dataset class.
|
||||||
"""
|
"""
|
||||||
with open_dict(args):
|
|
||||||
for key in _NEED_CONTROL:
|
for key in _NEED_CONTROL:
|
||||||
del args[key]
|
del args[key]
|
||||||
|
|
||||||
|
@ -300,7 +300,6 @@ class JsonIndexDatasetMapProviderV2(DatasetMapProviderBase): # pyre-ignore [13]
|
|||||||
Called by get_default_args(JsonIndexDatasetMapProviderV2) to
|
Called by get_default_args(JsonIndexDatasetMapProviderV2) to
|
||||||
not expose certain fields of each dataset class.
|
not expose certain fields of each dataset class.
|
||||||
"""
|
"""
|
||||||
with open_dict(args):
|
|
||||||
for key in _NEED_CONTROL:
|
for key in _NEED_CONTROL:
|
||||||
del args[key]
|
del args[key]
|
||||||
|
|
||||||
|
@ -881,6 +881,7 @@ def get_default_args_field(
|
|||||||
def create():
|
def create():
|
||||||
args = get_default_args(C, _do_not_process=_do_not_process)
|
args = get_default_args(C, _do_not_process=_do_not_process)
|
||||||
if _hook is not None:
|
if _hook is not None:
|
||||||
|
with open_dict(args):
|
||||||
_hook(args)
|
_hook(args)
|
||||||
return args
|
return args
|
||||||
|
|
||||||
@ -915,6 +916,7 @@ def _get_default_args_field_from_registry(
|
|||||||
C = registry.get(base_class_wanted=base_class_wanted, name=name)
|
C = registry.get(base_class_wanted=base_class_wanted, name=name)
|
||||||
args = get_default_args(C, _do_not_process=_do_not_process)
|
args = get_default_args(C, _do_not_process=_do_not_process)
|
||||||
if _hook is not None:
|
if _hook is not None:
|
||||||
|
with open_dict(args):
|
||||||
_hook(args)
|
_hook(args)
|
||||||
return args
|
return args
|
||||||
|
|
||||||
|
@ -691,12 +691,17 @@ class TestConfig(unittest.TestCase):
|
|||||||
fruit2_class_type: str = "Pear"
|
fruit2_class_type: str = "Pear"
|
||||||
a: A
|
a: A
|
||||||
a2: A
|
a2: A
|
||||||
|
a3: A
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def a_tweak_args(cls, type, args):
|
def a_tweak_args(cls, type, args):
|
||||||
assert type == A
|
assert type == A
|
||||||
args.n = 993
|
args.n = 993
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def a3_tweak_args(cls, type, args):
|
||||||
|
del args["n"]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def fruit_tweak_args(cls, type, args):
|
def fruit_tweak_args(cls, type, args):
|
||||||
assert issubclass(type, Fruit)
|
assert issubclass(type, Fruit)
|
||||||
@ -707,6 +712,7 @@ class TestConfig(unittest.TestCase):
|
|||||||
args = get_default_args(Wrapper)
|
args = get_default_args(Wrapper)
|
||||||
self.assertEqual(args.a_args.n, 993)
|
self.assertEqual(args.a_args.n, 993)
|
||||||
self.assertEqual(args.a2_args.n, 9)
|
self.assertEqual(args.a2_args.n, 9)
|
||||||
|
self.assertEqual(args.a3_args, {})
|
||||||
self.assertEqual(args.fruit_Pear_args.n_pips, 19)
|
self.assertEqual(args.fruit_Pear_args.n_pips, 19)
|
||||||
self.assertEqual(args.fruit2_Pear_args.n_pips, 13)
|
self.assertEqual(args.fruit2_Pear_args.n_pips, 13)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user