mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-22 15:20:34 +08:00
implicitron v0 (#1133)
Co-authored-by: Jeremy Francis Reizenstein <bottler@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
0e377c6850
commit
cdd2142dd5
610
tests/implicitron/test_config.py
Normal file
610
tests/implicitron/test_config.py
Normal file
@@ -0,0 +1,610 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import textwrap
|
||||
import unittest
|
||||
from dataclasses import dataclass, field, is_dataclass
|
||||
from enum import Enum
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from omegaconf import DictConfig, ListConfig, OmegaConf, ValidationError
|
||||
from pytorch3d.implicitron.tools.config import (
|
||||
Configurable,
|
||||
ReplaceableBase,
|
||||
_is_actually_dataclass,
|
||||
_Registry,
|
||||
expand_args_fields,
|
||||
get_default_args,
|
||||
get_default_args_field,
|
||||
registry,
|
||||
remove_unused_components,
|
||||
run_auto_creation,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Animal(ReplaceableBase):
|
||||
pass
|
||||
|
||||
|
||||
class Fruit(ReplaceableBase):
|
||||
pass
|
||||
|
||||
|
||||
@registry.register
|
||||
class Banana(Fruit):
|
||||
pips: int
|
||||
spots: int
|
||||
bananame: str
|
||||
|
||||
|
||||
@registry.register
|
||||
class Pear(Fruit):
|
||||
n_pips: int = 13
|
||||
|
||||
|
||||
class Pineapple(Fruit):
|
||||
pass
|
||||
|
||||
|
||||
@registry.register
|
||||
class Orange(Fruit):
|
||||
pass
|
||||
|
||||
|
||||
@registry.register
|
||||
class Kiwi(Fruit):
|
||||
pass
|
||||
|
||||
|
||||
@registry.register
|
||||
class LargePear(Pear):
|
||||
pass
|
||||
|
||||
|
||||
class MainTest(Configurable):
|
||||
the_fruit: Fruit
|
||||
n_ids: int
|
||||
n_reps: int = 8
|
||||
the_second_fruit: Fruit
|
||||
|
||||
def create_the_second_fruit(self):
|
||||
expand_args_fields(Pineapple)
|
||||
self.the_second_fruit = Pineapple()
|
||||
|
||||
def __post_init__(self):
|
||||
run_auto_creation(self)
|
||||
|
||||
|
||||
class TestConfig(unittest.TestCase):
|
||||
def test_is_actually_dataclass(self):
|
||||
@dataclass
|
||||
class A:
|
||||
pass
|
||||
|
||||
self.assertTrue(_is_actually_dataclass(A))
|
||||
self.assertTrue(is_dataclass(A))
|
||||
|
||||
class B(A):
|
||||
a: int
|
||||
|
||||
self.assertFalse(_is_actually_dataclass(B))
|
||||
self.assertTrue(is_dataclass(B))
|
||||
|
||||
def test_simple_replacement(self):
|
||||
struct = get_default_args(MainTest)
|
||||
struct.n_ids = 9780
|
||||
struct.the_fruit_Pear_args.n_pips = 3
|
||||
struct.the_fruit_class_type = "Pear"
|
||||
struct.the_second_fruit_class_type = "Pear"
|
||||
|
||||
main = MainTest(**struct)
|
||||
self.assertIsInstance(main.the_fruit, Pear)
|
||||
self.assertEqual(main.n_reps, 8)
|
||||
self.assertEqual(main.n_ids, 9780)
|
||||
self.assertEqual(main.the_fruit.n_pips, 3)
|
||||
self.assertIsInstance(main.the_second_fruit, Pineapple)
|
||||
|
||||
struct2 = get_default_args(MainTest)
|
||||
self.assertEqual(struct2.the_fruit_Pear_args.n_pips, 13)
|
||||
|
||||
self.assertEqual(
|
||||
MainTest._creation_functions,
|
||||
("create_the_fruit", "create_the_second_fruit"),
|
||||
)
|
||||
|
||||
def test_detect_bases(self):
|
||||
# testing the _base_class_from_class function
|
||||
self.assertIsNone(_Registry._base_class_from_class(ReplaceableBase))
|
||||
self.assertIsNone(_Registry._base_class_from_class(MainTest))
|
||||
self.assertIs(_Registry._base_class_from_class(Fruit), Fruit)
|
||||
self.assertIs(_Registry._base_class_from_class(Pear), Fruit)
|
||||
|
||||
class PricklyPear(Pear):
|
||||
pass
|
||||
|
||||
self.assertIs(_Registry._base_class_from_class(PricklyPear), Fruit)
|
||||
|
||||
def test_registry_entries(self):
|
||||
self.assertIs(registry.get(Fruit, "Banana"), Banana)
|
||||
with self.assertRaisesRegex(ValueError, "Banana has not been registered."):
|
||||
registry.get(Animal, "Banana")
|
||||
with self.assertRaisesRegex(ValueError, "PricklyPear has not been registered."):
|
||||
registry.get(Fruit, "PricklyPear")
|
||||
|
||||
self.assertIs(registry.get(Pear, "Pear"), Pear)
|
||||
self.assertIs(registry.get(Pear, "LargePear"), LargePear)
|
||||
with self.assertRaisesRegex(ValueError, "Banana resolves to"):
|
||||
registry.get(Pear, "Banana")
|
||||
|
||||
all_fruit = set(registry.get_all(Fruit))
|
||||
self.assertIn(Banana, all_fruit)
|
||||
self.assertIn(Pear, all_fruit)
|
||||
self.assertIn(LargePear, all_fruit)
|
||||
self.assertEqual(set(registry.get_all(Pear)), {LargePear})
|
||||
|
||||
@registry.register
|
||||
class Apple(Fruit):
|
||||
pass
|
||||
|
||||
@registry.register
|
||||
class CrabApple(Apple):
|
||||
pass
|
||||
|
||||
self.assertEqual(set(registry.get_all(Apple)), {CrabApple})
|
||||
|
||||
self.assertIs(registry.get(Fruit, "CrabApple"), CrabApple)
|
||||
|
||||
with self.assertRaisesRegex(ValueError, "Cannot tell what it is."):
|
||||
|
||||
@registry.register
|
||||
class NotAFruit:
|
||||
pass
|
||||
|
||||
def test_recursion(self):
|
||||
class Shape(ReplaceableBase):
|
||||
pass
|
||||
|
||||
@registry.register
|
||||
class Triangle(Shape):
|
||||
a: float = 5.0
|
||||
|
||||
@registry.register
|
||||
class Square(Shape):
|
||||
a: float = 3.0
|
||||
|
||||
@registry.register
|
||||
class LargeShape(Shape):
|
||||
inner: Shape
|
||||
|
||||
def __post_init__(self):
|
||||
run_auto_creation(self)
|
||||
|
||||
class ShapeContainer(Configurable):
|
||||
shape: Shape
|
||||
|
||||
container = ShapeContainer(**get_default_args(ShapeContainer))
|
||||
# This is because ShapeContainer is missing __post_init__
|
||||
with self.assertRaises(AttributeError):
|
||||
container.shape
|
||||
|
||||
class ShapeContainer2(Configurable):
|
||||
x: Shape
|
||||
x_class_type: str = "LargeShape"
|
||||
|
||||
def __post_init__(self):
|
||||
self.x_LargeShape_args.inner_class_type = "Triangle"
|
||||
run_auto_creation(self)
|
||||
|
||||
container2_args = get_default_args(ShapeContainer2)
|
||||
container2_args.x_LargeShape_args.inner_Triangle_args.a += 10
|
||||
self.assertIn("inner_Square_args", container2_args.x_LargeShape_args)
|
||||
# We do not perform expansion that would result in an infinite recursion,
|
||||
# so this member is not present.
|
||||
self.assertNotIn("inner_LargeShape_args", container2_args.x_LargeShape_args)
|
||||
container2_args.x_LargeShape_args.inner_Square_args.a += 100
|
||||
container2 = ShapeContainer2(**container2_args)
|
||||
self.assertIsInstance(container2.x, LargeShape)
|
||||
self.assertIsInstance(container2.x.inner, Triangle)
|
||||
self.assertEqual(container2.x.inner.a, 15.0)
|
||||
|
||||
def test_simpleclass_member(self):
|
||||
# Members which are not dataclasses are
|
||||
# tolerated. But it would be nice to be able to
|
||||
# configure them.
|
||||
class Foo:
|
||||
def __init__(self, a=1, b=2):
|
||||
self.a, self.b = a, b
|
||||
|
||||
@dataclass()
|
||||
class Bar:
|
||||
aa: int = 9
|
||||
bb: int = 9
|
||||
|
||||
class Container(Configurable):
|
||||
bar: Bar = Bar()
|
||||
# TODO make this work?
|
||||
# foo: Foo = Foo()
|
||||
fruit: Fruit
|
||||
fruit_class_type: str = "Orange"
|
||||
|
||||
def __post_init__(self):
|
||||
run_auto_creation(self)
|
||||
|
||||
self.assertEqual(get_default_args(Foo), {"a": 1, "b": 2})
|
||||
container_args = get_default_args(Container)
|
||||
container = Container(**container_args)
|
||||
self.assertIsInstance(container.fruit, Orange)
|
||||
# self.assertIsInstance(container.bar, Bar)
|
||||
|
||||
container_defaulted = Container()
|
||||
container_defaulted.fruit_Pear_args.n_pips += 4
|
||||
|
||||
container_args2 = get_default_args(Container)
|
||||
container = Container(**container_args2)
|
||||
self.assertEqual(container.fruit_Pear_args.n_pips, 13)
|
||||
|
||||
def test_inheritance(self):
|
||||
class FruitBowl(ReplaceableBase):
|
||||
main_fruit: Fruit
|
||||
main_fruit_class_type: str = "Orange"
|
||||
|
||||
def __post_init__(self):
|
||||
raise ValueError("This doesn't get called")
|
||||
|
||||
class LargeFruitBowl(FruitBowl):
|
||||
extra_fruit: Fruit
|
||||
extra_fruit_class_type: str = "Kiwi"
|
||||
|
||||
def __post_init__(self):
|
||||
run_auto_creation(self)
|
||||
|
||||
large_args = get_default_args(LargeFruitBowl)
|
||||
self.assertNotIn("extra_fruit", large_args)
|
||||
self.assertNotIn("main_fruit", large_args)
|
||||
large = LargeFruitBowl(**large_args)
|
||||
self.assertIsInstance(large.main_fruit, Orange)
|
||||
self.assertIsInstance(large.extra_fruit, Kiwi)
|
||||
|
||||
def test_inheritance2(self):
|
||||
# This is a case where a class could contain an instance
|
||||
# of a subclass, which is ignored.
|
||||
class Parent(ReplaceableBase):
|
||||
pass
|
||||
|
||||
class Main(Configurable):
|
||||
parent: Parent
|
||||
# Note - no __post__init__
|
||||
|
||||
@registry.register
|
||||
class Derived(Parent, Main):
|
||||
pass
|
||||
|
||||
args = get_default_args(Main)
|
||||
# Derived has been ignored in processing Main.
|
||||
self.assertCountEqual(args.keys(), ["parent_class_type"])
|
||||
|
||||
main = Main(**args)
|
||||
|
||||
with self.assertRaisesRegex(ValueError, "UNDEFAULTED has not been registered."):
|
||||
run_auto_creation(main)
|
||||
|
||||
main.parent_class_type = "Derived"
|
||||
# Illustrates that a dict works fine instead of a DictConfig.
|
||||
main.parent_Derived_args = {}
|
||||
with self.assertRaises(AttributeError):
|
||||
main.parent
|
||||
run_auto_creation(main)
|
||||
self.assertIsInstance(main.parent, Derived)
|
||||
|
||||
def test_redefine(self):
|
||||
class FruitBowl(ReplaceableBase):
|
||||
main_fruit: Fruit
|
||||
main_fruit_class_type: str = "Grape"
|
||||
|
||||
def __post_init__(self):
|
||||
run_auto_creation(self)
|
||||
|
||||
@registry.register
|
||||
@dataclass
|
||||
class Grape(Fruit):
|
||||
large: bool = False
|
||||
|
||||
def get_color(self):
|
||||
return "red"
|
||||
|
||||
def __post_init__(self):
|
||||
raise ValueError("This doesn't get called")
|
||||
|
||||
bowl_args = get_default_args(FruitBowl)
|
||||
|
||||
@registry.register
|
||||
@dataclass
|
||||
class Grape(Fruit): # noqa: F811
|
||||
large: bool = True
|
||||
|
||||
def get_color(self):
|
||||
return "green"
|
||||
|
||||
with self.assertWarnsRegex(
|
||||
UserWarning, "New implementation of Grape is being chosen."
|
||||
):
|
||||
bowl = FruitBowl(**bowl_args)
|
||||
self.assertIsInstance(bowl.main_fruit, Grape)
|
||||
|
||||
# Redefining the same class won't help with defaults because encoded in args
|
||||
self.assertEqual(bowl.main_fruit.large, False)
|
||||
|
||||
# But the override worked.
|
||||
self.assertEqual(bowl.main_fruit.get_color(), "green")
|
||||
|
||||
# 2. Try redefining without the dataclass modifier
|
||||
# This relies on the fact that default creation processes the class.
|
||||
# (otherwise incomprehensible messages)
|
||||
@registry.register
|
||||
class Grape(Fruit): # noqa: F811
|
||||
large: bool = True
|
||||
|
||||
with self.assertWarnsRegex(
|
||||
UserWarning, "New implementation of Grape is being chosen."
|
||||
):
|
||||
bowl = FruitBowl(**bowl_args)
|
||||
|
||||
# 3. Adding a new class doesn't get picked up, because the first
|
||||
# get_default_args call has frozen FruitBowl. This is intrinsic to
|
||||
# the way dataclass and expand_args_fields work in-place but
|
||||
# expand_args_fields is not pure - it depends on the registry.
|
||||
@registry.register
|
||||
class Fig(Fruit):
|
||||
pass
|
||||
|
||||
bowl_args2 = get_default_args(FruitBowl)
|
||||
self.assertIn("main_fruit_Grape_args", bowl_args2)
|
||||
self.assertNotIn("main_fruit_Fig_args", bowl_args2)
|
||||
|
||||
# TODO Is it possible to make this work?
|
||||
# bowl_args2["main_fruit_Fig_args"] = get_default_args(Fig)
|
||||
# bowl_args2.main_fruit_class_type = "Fig"
|
||||
# bowl2 = FruitBowl(**bowl_args2) <= unexpected argument
|
||||
|
||||
# Note that it is possible to use Fig if you can set
|
||||
# bowl2.main_fruit_Fig_args explicitly (not in bowl_args2)
|
||||
# before run_auto_creation happens. See test_inheritance2
|
||||
# for an example.
|
||||
|
||||
def test_no_replacement(self):
|
||||
# Test of Configurables without ReplaceableBase
|
||||
class A(Configurable):
|
||||
n: int = 9
|
||||
|
||||
class B(Configurable):
|
||||
a: A
|
||||
|
||||
def __post_init__(self):
|
||||
run_auto_creation(self)
|
||||
|
||||
class C(Configurable):
|
||||
b: B
|
||||
|
||||
def __post_init__(self):
|
||||
run_auto_creation(self)
|
||||
|
||||
c_args = get_default_args(C)
|
||||
c = C(**c_args)
|
||||
self.assertIsInstance(c.b.a, A)
|
||||
self.assertEqual(c.b.a.n, 9)
|
||||
|
||||
def test_doc(self):
|
||||
# The case in the docstring.
|
||||
class A(ReplaceableBase):
|
||||
k: int = 1
|
||||
|
||||
@registry.register
|
||||
class A1(A):
|
||||
m: int = 3
|
||||
|
||||
@registry.register
|
||||
class A2(A):
|
||||
n: str = "2"
|
||||
|
||||
class B(Configurable):
|
||||
a: A
|
||||
a_class_type: str = "A2"
|
||||
|
||||
def __post_init__(self):
|
||||
run_auto_creation(self)
|
||||
|
||||
b_args = get_default_args(B)
|
||||
self.assertNotIn("a", b_args)
|
||||
b = B(**b_args)
|
||||
self.assertEqual(b.a.n, "2")
|
||||
|
||||
def test_raw_types(self):
|
||||
@dataclass
|
||||
class MyDataclass:
|
||||
int_field: int = 0
|
||||
none_field: Optional[int] = None
|
||||
float_field: float = 9.3
|
||||
bool_field: bool = True
|
||||
tuple_field: tuple = (3, True, "j")
|
||||
|
||||
class SimpleClass:
|
||||
def __init__(self, tuple_member_=(3, 4)):
|
||||
self.tuple_member = tuple_member_
|
||||
|
||||
def get_tuple(self):
|
||||
return self.tuple_member
|
||||
|
||||
def f(*, a: int = 3, b: str = "kj"):
|
||||
self.assertEqual(a, 3)
|
||||
self.assertEqual(b, "kj")
|
||||
|
||||
class C(Configurable):
|
||||
simple: DictConfig = get_default_args_field(SimpleClass)
|
||||
# simple2: SimpleClass2 = SimpleClass2()
|
||||
mydata: DictConfig = get_default_args_field(MyDataclass)
|
||||
a_tuple: Tuple[float] = (4.0, 3.0)
|
||||
f_args: DictConfig = get_default_args_field(f)
|
||||
|
||||
args = get_default_args(C)
|
||||
c = C(**args)
|
||||
self.assertCountEqual(args.keys(), ["simple", "mydata", "a_tuple", "f_args"])
|
||||
|
||||
mydata = MyDataclass(**c.mydata)
|
||||
simple = SimpleClass(**c.simple)
|
||||
|
||||
# OmegaConf converts tuples to ListConfigs (which act like lists).
|
||||
self.assertEqual(simple.get_tuple(), [3, 4])
|
||||
self.assertTrue(isinstance(simple.get_tuple(), ListConfig))
|
||||
self.assertEqual(c.a_tuple, [4.0, 3.0])
|
||||
self.assertTrue(isinstance(c.a_tuple, ListConfig))
|
||||
self.assertEqual(mydata.tuple_field, (3, True, "j"))
|
||||
self.assertTrue(isinstance(mydata.tuple_field, ListConfig))
|
||||
f(**c.f_args)
|
||||
|
||||
def test_irrelevant_bases(self):
|
||||
class NotADataclass:
|
||||
# Like torch.nn.Module, this class contains annotations
|
||||
# but is not designed to be dataclass'd.
|
||||
# This test ensures that such classes, when inherited fron,
|
||||
# are not accidentally expand_args_fields.
|
||||
a: int = 9
|
||||
b: int
|
||||
|
||||
class LeftConfigured(Configurable, NotADataclass):
|
||||
left: int = 1
|
||||
|
||||
class RightConfigured(NotADataclass, Configurable):
|
||||
right: int = 2
|
||||
|
||||
class Outer(Configurable):
|
||||
left: LeftConfigured
|
||||
right: RightConfigured
|
||||
|
||||
def __post_init__(self):
|
||||
run_auto_creation(self)
|
||||
|
||||
outer = Outer(**get_default_args(Outer))
|
||||
self.assertEqual(outer.left.left, 1)
|
||||
self.assertEqual(outer.right.right, 2)
|
||||
with self.assertRaisesRegex(TypeError, "non-default argument"):
|
||||
dataclass(NotADataclass)
|
||||
|
||||
def test_unprocessed(self):
|
||||
# behavior of Configurable classes which need processing in __new__,
|
||||
class Unprocessed(Configurable):
|
||||
a: int = 9
|
||||
|
||||
class UnprocessedReplaceable(ReplaceableBase):
|
||||
a: int = 1
|
||||
|
||||
with self.assertWarnsRegex(UserWarning, "must be processed"):
|
||||
Unprocessed()
|
||||
with self.assertWarnsRegex(UserWarning, "must be processed"):
|
||||
UnprocessedReplaceable()
|
||||
|
||||
def test_enum(self):
|
||||
# Test that enum values are kept, i.e. that OmegaConf's runtime checks
|
||||
# are in use.
|
||||
|
||||
class A(Enum):
|
||||
B1 = "b1"
|
||||
B2 = "b2"
|
||||
|
||||
class C(Configurable):
|
||||
a: A = A.B1
|
||||
|
||||
base = get_default_args(C)
|
||||
replaced = OmegaConf.merge(base, {"a": "B2"})
|
||||
self.assertEqual(replaced.a, A.B2)
|
||||
with self.assertRaises(ValidationError):
|
||||
# You can't use a value which is not one of the
|
||||
# choices, even if it is the str representation
|
||||
# of one of the choices.
|
||||
OmegaConf.merge(base, {"a": "b2"})
|
||||
|
||||
remerged = OmegaConf.merge(base, OmegaConf.create(OmegaConf.to_yaml(base)))
|
||||
self.assertEqual(remerged.a, A.B1)
|
||||
|
||||
def test_remove_unused_components(self):
|
||||
struct = get_default_args(MainTest)
|
||||
struct.n_ids = 32
|
||||
struct.the_fruit_class_type = "Pear"
|
||||
struct.the_second_fruit_class_type = "Banana"
|
||||
remove_unused_components(struct)
|
||||
expected_keys = [
|
||||
"n_ids",
|
||||
"n_reps",
|
||||
"the_fruit_Pear_args",
|
||||
"the_fruit_class_type",
|
||||
"the_second_fruit_Banana_args",
|
||||
"the_second_fruit_class_type",
|
||||
]
|
||||
expected_yaml = textwrap.dedent(
|
||||
"""\
|
||||
n_ids: 32
|
||||
n_reps: 8
|
||||
the_fruit_class_type: Pear
|
||||
the_fruit_Pear_args:
|
||||
n_pips: 13
|
||||
the_second_fruit_class_type: Banana
|
||||
the_second_fruit_Banana_args:
|
||||
pips: ???
|
||||
spots: ???
|
||||
bananame: ???
|
||||
"""
|
||||
)
|
||||
self.assertEqual(sorted(struct.keys()), expected_keys)
|
||||
|
||||
# Check that struct is what we expect
|
||||
expected = OmegaConf.create(expected_yaml)
|
||||
self.assertEqual(struct, expected)
|
||||
|
||||
# Check that we get what we expect when writing to yaml.
|
||||
self.assertEqual(OmegaConf.to_yaml(struct, sort_keys=False), expected_yaml)
|
||||
|
||||
main = MainTest(**struct)
|
||||
instance_data = OmegaConf.structured(main)
|
||||
remove_unused_components(instance_data)
|
||||
self.assertEqual(sorted(instance_data.keys()), expected_keys)
|
||||
self.assertEqual(instance_data, expected)
|
||||
|
||||
|
||||
@dataclass(eq=False)
|
||||
class MockDataclass:
|
||||
field_no_default: int
|
||||
field_primitive_type: int = 42
|
||||
field_reference_type: List[int] = field(default_factory=lambda: [])
|
||||
|
||||
|
||||
class MockClassWithInit: # noqa: B903
|
||||
def __init__(
|
||||
self,
|
||||
field_no_default: int,
|
||||
field_primitive_type: int = 42,
|
||||
field_reference_type: List[int] = [], # noqa: B006
|
||||
):
|
||||
self.field_no_default = field_no_default
|
||||
self.field_primitive_type = field_primitive_type
|
||||
self.field_reference_type = field_reference_type
|
||||
|
||||
|
||||
class TestRawClasses(unittest.TestCase):
|
||||
def test_get_default_args(self):
|
||||
for cls in [MockDataclass, MockClassWithInit]:
|
||||
dataclass_defaults = get_default_args(cls)
|
||||
inst = cls(field_no_default=0)
|
||||
dataclass_defaults.field_no_default = 0
|
||||
for name, val in dataclass_defaults.items():
|
||||
self.assertTrue(hasattr(inst, name))
|
||||
self.assertEqual(val, getattr(inst, name))
|
||||
|
||||
def test_get_default_args_readonly(self):
|
||||
for cls in [MockDataclass, MockClassWithInit]:
|
||||
dataclass_defaults = get_default_args(cls)
|
||||
dataclass_defaults["field_reference_type"].append(13)
|
||||
inst = cls(field_no_default=0)
|
||||
self.assertEqual(inst.field_reference_type, [])
|
||||
Reference in New Issue
Block a user