mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-14 11:26:24 +08:00
Optional ReplaceableBase
Summary: Allow things like `renderer:Optional[BaseRenderer]` in configurables. Reviewed By: davnov134 Differential Revision: D35118339 fbshipit-source-id: 1219321b2817ed4b26fe924c6d6f73887095c985
This commit is contained in:
committed by
Facebook GitHub Bot
parent
e332f9ffa4
commit
21262e38c7
@@ -14,7 +14,9 @@ from omegaconf import DictConfig, ListConfig, OmegaConf, ValidationError
|
||||
from pytorch3d.implicitron.tools.config import (
|
||||
Configurable,
|
||||
ReplaceableBase,
|
||||
_get_type_to_process,
|
||||
_is_actually_dataclass,
|
||||
_ProcessType,
|
||||
_Registry,
|
||||
expand_args_fields,
|
||||
get_default_args,
|
||||
@@ -94,6 +96,19 @@ class TestConfig(unittest.TestCase):
|
||||
self.assertFalse(_is_actually_dataclass(B))
|
||||
self.assertTrue(is_dataclass(B))
|
||||
|
||||
def test_get_type_to_process(self):
|
||||
gt = _get_type_to_process
|
||||
self.assertIsNone(gt(int))
|
||||
self.assertEqual(gt(Fruit), (Fruit, _ProcessType.REPLACEABLE))
|
||||
self.assertEqual(
|
||||
gt(Optional[Fruit]), (Fruit, _ProcessType.OPTIONAL_REPLACEABLE)
|
||||
)
|
||||
self.assertEqual(gt(MainTest), (MainTest, _ProcessType.CONFIGURABLE))
|
||||
self.assertIsNone(gt(Optional[int]))
|
||||
self.assertIsNone(gt(Optional[MainTest]))
|
||||
self.assertIsNone(gt(Tuple[Fruit]))
|
||||
self.assertIsNone(gt(Tuple[Fruit, Animal]))
|
||||
|
||||
def test_simple_replacement(self):
|
||||
struct = get_default_args(MainTest)
|
||||
struct.n_ids = 9780
|
||||
@@ -247,6 +262,7 @@ class TestConfig(unittest.TestCase):
|
||||
self.assertEqual(container.fruit_Pear_args.n_pips, 13)
|
||||
|
||||
def test_inheritance(self):
|
||||
# Also exercises optional replaceables
|
||||
class FruitBowl(ReplaceableBase):
|
||||
main_fruit: Fruit
|
||||
main_fruit_class_type: str = "Orange"
|
||||
@@ -255,8 +271,10 @@ class TestConfig(unittest.TestCase):
|
||||
raise ValueError("This doesn't get called")
|
||||
|
||||
class LargeFruitBowl(FruitBowl):
|
||||
extra_fruit: Fruit
|
||||
extra_fruit: Optional[Fruit]
|
||||
extra_fruit_class_type: str = "Kiwi"
|
||||
no_fruit: Optional[Fruit]
|
||||
no_fruit_class_type: Optional[str] = None
|
||||
|
||||
def __post_init__(self):
|
||||
run_auto_creation(self)
|
||||
@@ -267,6 +285,22 @@ class TestConfig(unittest.TestCase):
|
||||
large = LargeFruitBowl(**large_args)
|
||||
self.assertIsInstance(large.main_fruit, Orange)
|
||||
self.assertIsInstance(large.extra_fruit, Kiwi)
|
||||
self.assertIsNone(large.no_fruit)
|
||||
self.assertIn("no_fruit_Kiwi_args", large_args)
|
||||
|
||||
remove_unused_components(large_args)
|
||||
large2 = LargeFruitBowl(**large_args)
|
||||
self.assertIsInstance(large2.main_fruit, Orange)
|
||||
self.assertIsInstance(large2.extra_fruit, Kiwi)
|
||||
self.assertIsNone(large2.no_fruit)
|
||||
needed_args = [
|
||||
"extra_fruit_Kiwi_args",
|
||||
"extra_fruit_class_type",
|
||||
"main_fruit_Orange_args",
|
||||
"main_fruit_class_type",
|
||||
"no_fruit_class_type",
|
||||
]
|
||||
self.assertEqual(sorted(large_args.keys()), needed_args)
|
||||
|
||||
def test_inheritance2(self):
|
||||
# This is a case where a class could contain an instance
|
||||
|
||||
Reference in New Issue
Block a user