mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-19 05:40:34 +08:00
hooks and allow registering base class
Summary: Allow a class to modify its subparts in get_default_args by defining the special function provide_config_hook. Reviewed By: davnov134 Differential Revision: D36671081 fbshipit-source-id: 3e5b73880cb846c494a209c4479835f6352f45cf
This commit is contained in:
committed by
Facebook GitHub Bot
parent
5cd70067e2
commit
8bc0a04e86
@@ -678,6 +678,36 @@ class TestConfig(unittest.TestCase):
|
||||
remove_unused_components(args)
|
||||
self.assertEqual(OmegaConf.to_yaml(args), "mt_enabled: false\n")
|
||||
|
||||
def test_tweak_hook(self):
|
||||
class A(Configurable):
|
||||
n: int = 9
|
||||
|
||||
class Wrapper(Configurable):
|
||||
fruit: Fruit
|
||||
fruit_class_type: str = "Pear"
|
||||
fruit2: Fruit
|
||||
fruit2_class_type: str = "Pear"
|
||||
a: A
|
||||
a2: A
|
||||
|
||||
@classmethod
|
||||
def a_tweak_args(cls, type, args):
|
||||
assert type == A
|
||||
args.n = 993
|
||||
|
||||
@classmethod
|
||||
def fruit_tweak_args(cls, type, args):
|
||||
assert issubclass(type, Fruit)
|
||||
if type == Pear:
|
||||
assert args.n_pips == 13
|
||||
args.n_pips = 19
|
||||
|
||||
args = get_default_args(Wrapper)
|
||||
self.assertEqual(args.a_args.n, 993)
|
||||
self.assertEqual(args.a2_args.n, 9)
|
||||
self.assertEqual(args.fruit_Pear_args.n_pips, 19)
|
||||
self.assertEqual(args.fruit2_Pear_args.n_pips, 13)
|
||||
|
||||
def test_impls(self):
|
||||
# Check that create_x actually uses create_x_impl to do its work
|
||||
# by using all the member types, both with a faked impl function
|
||||
|
||||
Reference in New Issue
Block a user