Optional ReplaceableBase

Summary: Allow things like `renderer:Optional[BaseRenderer]` in configurables.

Reviewed By: davnov134

Differential Revision: D35118339

fbshipit-source-id: 1219321b2817ed4b26fe924c6d6f73887095c985
This commit is contained in:
Jeremy Reizenstein
2022-03-29 08:43:46 -07:00
committed by Facebook GitHub Bot
parent e332f9ffa4
commit 21262e38c7
4 changed files with 171 additions and 46 deletions

View File

@@ -14,7 +14,9 @@ from omegaconf import DictConfig, ListConfig, OmegaConf, ValidationError
from pytorch3d.implicitron.tools.config import (
Configurable,
ReplaceableBase,
_get_type_to_process,
_is_actually_dataclass,
_ProcessType,
_Registry,
expand_args_fields,
get_default_args,
@@ -94,6 +96,19 @@ class TestConfig(unittest.TestCase):
self.assertFalse(_is_actually_dataclass(B))
self.assertTrue(is_dataclass(B))
def test_get_type_to_process(self):
gt = _get_type_to_process
self.assertIsNone(gt(int))
self.assertEqual(gt(Fruit), (Fruit, _ProcessType.REPLACEABLE))
self.assertEqual(
gt(Optional[Fruit]), (Fruit, _ProcessType.OPTIONAL_REPLACEABLE)
)
self.assertEqual(gt(MainTest), (MainTest, _ProcessType.CONFIGURABLE))
self.assertIsNone(gt(Optional[int]))
self.assertIsNone(gt(Optional[MainTest]))
self.assertIsNone(gt(Tuple[Fruit]))
self.assertIsNone(gt(Tuple[Fruit, Animal]))
def test_simple_replacement(self):
struct = get_default_args(MainTest)
struct.n_ids = 9780
@@ -247,6 +262,7 @@ class TestConfig(unittest.TestCase):
self.assertEqual(container.fruit_Pear_args.n_pips, 13)
def test_inheritance(self):
# Also exercises optional replaceables
class FruitBowl(ReplaceableBase):
main_fruit: Fruit
main_fruit_class_type: str = "Orange"
@@ -255,8 +271,10 @@ class TestConfig(unittest.TestCase):
raise ValueError("This doesn't get called")
class LargeFruitBowl(FruitBowl):
extra_fruit: Fruit
extra_fruit: Optional[Fruit]
extra_fruit_class_type: str = "Kiwi"
no_fruit: Optional[Fruit]
no_fruit_class_type: Optional[str] = None
def __post_init__(self):
run_auto_creation(self)
@@ -267,6 +285,22 @@ class TestConfig(unittest.TestCase):
large = LargeFruitBowl(**large_args)
self.assertIsInstance(large.main_fruit, Orange)
self.assertIsInstance(large.extra_fruit, Kiwi)
self.assertIsNone(large.no_fruit)
self.assertIn("no_fruit_Kiwi_args", large_args)
remove_unused_components(large_args)
large2 = LargeFruitBowl(**large_args)
self.assertIsInstance(large2.main_fruit, Orange)
self.assertIsInstance(large2.extra_fruit, Kiwi)
self.assertIsNone(large2.no_fruit)
needed_args = [
"extra_fruit_Kiwi_args",
"extra_fruit_class_type",
"main_fruit_Orange_args",
"main_fruit_class_type",
"no_fruit_class_type",
]
self.assertEqual(sorted(large_args.keys()), needed_args)
def test_inheritance2(self):
# This is a case where a class could contain an instance