pytorch3d/tests/implicitron/test_config.py
Tim Hatch 34bbb3ad32 apply import merging for fbcode/vision/fair (2 of 2)
Summary:
Applies new import merging and sorting from µsort v1.0.

When merging imports, µsort will make a best-effort to move associated
comments to match merged elements, but there are known limitations due to
the diynamic nature of Python and developer tooling. These changes should
not produce any dangerous runtime changes, but may require touch-ups to
satisfy linters and other tooling.

Note that µsort uses case-insensitive, lexicographical sorting, which
results in a different ordering compared to isort. This provides a more
consistent sorting order, matching the case-insensitive order used when
sorting import statements by module name, and ensures that "frog", "FROG",
and "Frog" always sort next to each other.

For details on µsort's sorting and merging semantics, see the user guide:
https://usort.readthedocs.io/en/stable/guide.html#sorting

Reviewed By: bottler

Differential Revision: D35553814

fbshipit-source-id: be49bdb6a4c25264ff8d4db3a601f18736d17be1
2022-04-13 06:51:33 -07:00

744 lines
24 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import pickle
import textwrap
import unittest
from dataclasses import dataclass, field, is_dataclass
from enum import Enum
from typing import Any, List, Optional, Set, Tuple
from omegaconf import DictConfig, ListConfig, OmegaConf, ValidationError
from pytorch3d.implicitron.tools.config import (
_get_type_to_process,
_is_actually_dataclass,
_ProcessType,
_Registry,
Configurable,
enable_get_default_args,
expand_args_fields,
get_default_args,
get_default_args_field,
registry,
remove_unused_components,
ReplaceableBase,
run_auto_creation,
)
@dataclass
class Animal(ReplaceableBase):
pass
class Fruit(ReplaceableBase):
pass
@registry.register
class Banana(Fruit):
pips: int
spots: int
bananame: str
@registry.register
class Pear(Fruit):
n_pips: int = 13
class Pineapple(Fruit):
pass
@registry.register
class Orange(Fruit):
pass
@registry.register
class Kiwi(Fruit):
pass
@registry.register
class LargePear(Pear):
pass
class MainTest(Configurable):
the_fruit: Fruit
n_ids: int
n_reps: int = 8
the_second_fruit: Fruit
def create_the_second_fruit(self):
expand_args_fields(Pineapple)
self.the_second_fruit = Pineapple()
def __post_init__(self):
run_auto_creation(self)
class TestConfig(unittest.TestCase):
def test_is_actually_dataclass(self):
@dataclass
class A:
pass
self.assertTrue(_is_actually_dataclass(A))
self.assertTrue(is_dataclass(A))
class B(A):
a: int
self.assertFalse(_is_actually_dataclass(B))
self.assertTrue(is_dataclass(B))
def test_get_type_to_process(self):
gt = _get_type_to_process
self.assertIsNone(gt(int))
self.assertEqual(gt(Fruit), (Fruit, _ProcessType.REPLACEABLE))
self.assertEqual(
gt(Optional[Fruit]), (Fruit, _ProcessType.OPTIONAL_REPLACEABLE)
)
self.assertEqual(gt(MainTest), (MainTest, _ProcessType.CONFIGURABLE))
self.assertEqual(
gt(Optional[MainTest]), (MainTest, _ProcessType.OPTIONAL_CONFIGURABLE)
)
self.assertIsNone(gt(Optional[int]))
self.assertIsNone(gt(Tuple[Fruit]))
self.assertIsNone(gt(Tuple[Fruit, Animal]))
self.assertIsNone(gt(Optional[List[int]]))
def test_simple_replacement(self):
struct = get_default_args(MainTest)
struct.n_ids = 9780
struct.the_fruit_Pear_args.n_pips = 3
struct.the_fruit_class_type = "Pear"
struct.the_second_fruit_class_type = "Pear"
main = MainTest(**struct)
self.assertIsInstance(main.the_fruit, Pear)
self.assertEqual(main.n_reps, 8)
self.assertEqual(main.n_ids, 9780)
self.assertEqual(main.the_fruit.n_pips, 3)
self.assertIsInstance(main.the_second_fruit, Pineapple)
struct2 = get_default_args(MainTest)
self.assertEqual(struct2.the_fruit_Pear_args.n_pips, 13)
self.assertEqual(
MainTest._creation_functions,
("create_the_fruit", "create_the_second_fruit"),
)
def test_detect_bases(self):
# testing the _base_class_from_class function
self.assertIsNone(_Registry._base_class_from_class(ReplaceableBase))
self.assertIsNone(_Registry._base_class_from_class(MainTest))
self.assertIs(_Registry._base_class_from_class(Fruit), Fruit)
self.assertIs(_Registry._base_class_from_class(Pear), Fruit)
class PricklyPear(Pear):
pass
self.assertIs(_Registry._base_class_from_class(PricklyPear), Fruit)
def test_registry_entries(self):
self.assertIs(registry.get(Fruit, "Banana"), Banana)
with self.assertRaisesRegex(ValueError, "Banana has not been registered."):
registry.get(Animal, "Banana")
with self.assertRaisesRegex(ValueError, "PricklyPear has not been registered."):
registry.get(Fruit, "PricklyPear")
self.assertIs(registry.get(Pear, "Pear"), Pear)
self.assertIs(registry.get(Pear, "LargePear"), LargePear)
with self.assertRaisesRegex(ValueError, "Banana resolves to"):
registry.get(Pear, "Banana")
all_fruit = set(registry.get_all(Fruit))
self.assertIn(Banana, all_fruit)
self.assertIn(Pear, all_fruit)
self.assertIn(LargePear, all_fruit)
self.assertEqual(set(registry.get_all(Pear)), {LargePear})
@registry.register
class Apple(Fruit):
pass
@registry.register
class CrabApple(Apple):
pass
self.assertEqual(set(registry.get_all(Apple)), {CrabApple})
self.assertIs(registry.get(Fruit, "CrabApple"), CrabApple)
with self.assertRaisesRegex(ValueError, "Cannot tell what it is."):
@registry.register
class NotAFruit:
pass
def test_recursion(self):
class Shape(ReplaceableBase):
pass
@registry.register
class Triangle(Shape):
a: float = 5.0
@registry.register
class Square(Shape):
a: float = 3.0
@registry.register
class LargeShape(Shape):
inner: Shape
def __post_init__(self):
run_auto_creation(self)
class ShapeContainer(Configurable):
shape: Shape
container = ShapeContainer(**get_default_args(ShapeContainer))
# This is because ShapeContainer is missing __post_init__
with self.assertRaises(AttributeError):
container.shape
class ShapeContainer2(Configurable):
x: Shape
x_class_type: str = "LargeShape"
def __post_init__(self):
self.x_LargeShape_args.inner_class_type = "Triangle"
run_auto_creation(self)
container2_args = get_default_args(ShapeContainer2)
container2_args.x_LargeShape_args.inner_Triangle_args.a += 10
self.assertIn("inner_Square_args", container2_args.x_LargeShape_args)
# We do not perform expansion that would result in an infinite recursion,
# so this member is not present.
self.assertNotIn("inner_LargeShape_args", container2_args.x_LargeShape_args)
container2_args.x_LargeShape_args.inner_Square_args.a += 100
container2 = ShapeContainer2(**container2_args)
self.assertIsInstance(container2.x, LargeShape)
self.assertIsInstance(container2.x.inner, Triangle)
self.assertEqual(container2.x.inner.a, 15.0)
def test_simpleclass_member(self):
# Members which are not dataclasses are
# tolerated. But it would be nice to be able to
# configure them.
class Foo:
def __init__(self, a: Any = 1, b: Any = 2):
self.a, self.b = a, b
enable_get_default_args(Foo)
@dataclass()
class Bar:
aa: int = 9
bb: int = 9
class Container(Configurable):
bar: Bar = Bar()
# TODO make this work?
# foo: Foo = Foo()
fruit: Fruit
fruit_class_type: str = "Orange"
def __post_init__(self):
run_auto_creation(self)
self.assertEqual(get_default_args(Foo), {"a": 1, "b": 2})
container_args = get_default_args(Container)
container = Container(**container_args)
self.assertIsInstance(container.fruit, Orange)
self.assertEqual(Container._processed_members, {"fruit": Fruit})
self.assertEqual(container._processed_members, {"fruit": Fruit})
container_defaulted = Container()
container_defaulted.fruit_Pear_args.n_pips += 4
container_args2 = get_default_args(Container)
container = Container(**container_args2)
self.assertEqual(container.fruit_Pear_args.n_pips, 13)
def test_inheritance(self):
# Also exercises optional replaceables
class FruitBowl(ReplaceableBase):
main_fruit: Fruit
main_fruit_class_type: str = "Orange"
def __post_init__(self):
raise ValueError("This doesn't get called")
class LargeFruitBowl(FruitBowl):
extra_fruit: Optional[Fruit]
extra_fruit_class_type: str = "Kiwi"
no_fruit: Optional[Fruit]
no_fruit_class_type: Optional[str] = None
def __post_init__(self):
run_auto_creation(self)
large_args = get_default_args(LargeFruitBowl)
self.assertNotIn("extra_fruit", large_args)
self.assertNotIn("main_fruit", large_args)
large = LargeFruitBowl(**large_args)
self.assertIsInstance(large.main_fruit, Orange)
self.assertIsInstance(large.extra_fruit, Kiwi)
self.assertIsNone(large.no_fruit)
self.assertIn("no_fruit_Kiwi_args", large_args)
remove_unused_components(large_args)
large2 = LargeFruitBowl(**large_args)
self.assertIsInstance(large2.main_fruit, Orange)
self.assertIsInstance(large2.extra_fruit, Kiwi)
self.assertIsNone(large2.no_fruit)
needed_args = [
"extra_fruit_Kiwi_args",
"extra_fruit_class_type",
"main_fruit_Orange_args",
"main_fruit_class_type",
"no_fruit_class_type",
]
self.assertEqual(sorted(large_args.keys()), needed_args)
def test_inheritance2(self):
# This is a case where a class could contain an instance
# of a subclass, which is ignored.
class Parent(ReplaceableBase):
pass
class Main(Configurable):
parent: Parent
# Note - no __post__init__
@registry.register
class Derived(Parent, Main):
pass
args = get_default_args(Main)
# Derived has been ignored in processing Main.
self.assertCountEqual(args.keys(), ["parent_class_type"])
main = Main(**args)
with self.assertRaisesRegex(ValueError, "UNDEFAULTED has not been registered."):
run_auto_creation(main)
main.parent_class_type = "Derived"
# Illustrates that a dict works fine instead of a DictConfig.
main.parent_Derived_args = {}
with self.assertRaises(AttributeError):
main.parent
run_auto_creation(main)
self.assertIsInstance(main.parent, Derived)
def test_redefine(self):
class FruitBowl(ReplaceableBase):
main_fruit: Fruit
main_fruit_class_type: str = "Grape"
def __post_init__(self):
run_auto_creation(self)
@registry.register
@dataclass
class Grape(Fruit):
large: bool = False
def get_color(self):
return "red"
def __post_init__(self):
raise ValueError("This doesn't get called")
bowl_args = get_default_args(FruitBowl)
@registry.register
@dataclass
class Grape(Fruit): # noqa: F811
large: bool = True
def get_color(self):
return "green"
with self.assertWarnsRegex(
UserWarning, "New implementation of Grape is being chosen."
):
bowl = FruitBowl(**bowl_args)
self.assertIsInstance(bowl.main_fruit, Grape)
# Redefining the same class won't help with defaults because encoded in args
self.assertEqual(bowl.main_fruit.large, False)
# But the override worked.
self.assertEqual(bowl.main_fruit.get_color(), "green")
# 2. Try redefining without the dataclass modifier
# This relies on the fact that default creation processes the class.
# (otherwise incomprehensible messages)
@registry.register
class Grape(Fruit): # noqa: F811
large: bool = True
with self.assertWarnsRegex(
UserWarning, "New implementation of Grape is being chosen."
):
bowl = FruitBowl(**bowl_args)
# 3. Adding a new class doesn't get picked up, because the first
# get_default_args call has frozen FruitBowl. This is intrinsic to
# the way dataclass and expand_args_fields work in-place but
# expand_args_fields is not pure - it depends on the registry.
@registry.register
class Fig(Fruit):
pass
bowl_args2 = get_default_args(FruitBowl)
self.assertIn("main_fruit_Grape_args", bowl_args2)
self.assertNotIn("main_fruit_Fig_args", bowl_args2)
# TODO Is it possible to make this work?
# bowl_args2["main_fruit_Fig_args"] = get_default_args(Fig)
# bowl_args2.main_fruit_class_type = "Fig"
# bowl2 = FruitBowl(**bowl_args2) <= unexpected argument
# Note that it is possible to use Fig if you can set
# bowl2.main_fruit_Fig_args explicitly (not in bowl_args2)
# before run_auto_creation happens. See test_inheritance2
# for an example.
def test_no_replacement(self):
# Test of Configurables without ReplaceableBase
class A(Configurable):
n: int = 9
class B(Configurable):
a: A
def __post_init__(self):
run_auto_creation(self)
class C(Configurable):
b1: B
b2: Optional[B]
b3: Optional[B]
b2_enabled: bool = True
def __post_init__(self):
run_auto_creation(self)
c_args = get_default_args(C)
c = C(**c_args)
self.assertIsInstance(c.b1.a, A)
self.assertEqual(c.b1.a.n, 9)
self.assertFalse(hasattr(c, "b1_enabled"))
self.assertIsInstance(c.b2.a, A)
self.assertEqual(c.b2.a.n, 9)
self.assertTrue(c.b2_enabled)
self.assertIsNone(c.b3)
self.assertFalse(c.b3_enabled)
def test_doc(self):
# The case in the docstring.
class A(ReplaceableBase):
k: int = 1
@registry.register
class A1(A):
m: int = 3
@registry.register
class A2(A):
n: str = "2"
class B(Configurable):
a: A
a_class_type: str = "A2"
def __post_init__(self):
run_auto_creation(self)
b_args = get_default_args(B)
self.assertNotIn("a", b_args)
b = B(**b_args)
self.assertEqual(b.a.n, "2")
def test_raw_types(self):
@dataclass
class MyDataclass:
int_field: int = 0
none_field: Optional[int] = None
float_field: float = 9.3
bool_field: bool = True
tuple_field: tuple = (3, True, "j")
class SimpleClass:
def __init__(
self,
tuple_member_: Tuple[int, int] = (3, 4),
set_member_: Set[int] = {2}, # noqa
):
self.tuple_member = tuple_member_
self.set_member = set_member_
def get_tuple(self):
return self.tuple_member
enable_get_default_args(SimpleClass)
def f(*, a: int = 3, b: str = "kj"):
self.assertEqual(a, 3)
self.assertEqual(b, "kj")
enable_get_default_args(f)
class C(Configurable):
simple: DictConfig = get_default_args_field(SimpleClass)
# simple2: SimpleClass2 = SimpleClass2()
mydata: DictConfig = get_default_args_field(MyDataclass)
a_tuple: Tuple[float] = (4.0, 3.0)
f_args: DictConfig = get_default_args_field(f)
args = get_default_args(C)
c = C(**args)
self.assertCountEqual(args.keys(), ["simple", "mydata", "a_tuple", "f_args"])
mydata = MyDataclass(**c.mydata)
simple = SimpleClass(**c.simple)
# OmegaConf converts tuples to ListConfigs (which act like lists).
self.assertEqual(simple.get_tuple(), [3, 4])
self.assertTrue(isinstance(simple.get_tuple(), ListConfig))
# get_default_args converts sets to ListConfigs (which act like lists).
self.assertEqual(simple.set_member, [2])
self.assertTrue(isinstance(simple.set_member, ListConfig))
self.assertEqual(c.a_tuple, [4.0, 3.0])
self.assertTrue(isinstance(c.a_tuple, ListConfig))
self.assertEqual(mydata.tuple_field, (3, True, "j"))
self.assertTrue(isinstance(mydata.tuple_field, ListConfig))
f(**c.f_args)
def test_irrelevant_bases(self):
class NotADataclass:
# Like torch.nn.Module, this class contains annotations
# but is not designed to be dataclass'd.
# This test ensures that such classes, when inherited fron,
# are not accidentally affected by expand_args_fields.
a: int = 9
b: int
class LeftConfigured(Configurable, NotADataclass):
left: int = 1
class RightConfigured(NotADataclass, Configurable):
right: int = 2
class Outer(Configurable):
left: LeftConfigured
right: RightConfigured
def __post_init__(self):
run_auto_creation(self)
outer = Outer(**get_default_args(Outer))
self.assertEqual(outer.left.left, 1)
self.assertEqual(outer.right.right, 2)
with self.assertRaisesRegex(TypeError, "non-default argument"):
dataclass(NotADataclass)
def test_unprocessed(self):
# behavior of Configurable classes which need processing in __new__,
class Unprocessed(Configurable):
a: int = 9
class UnprocessedReplaceable(ReplaceableBase):
a: int = 1
with self.assertWarnsRegex(UserWarning, "must be processed"):
Unprocessed()
with self.assertWarnsRegex(UserWarning, "must be processed"):
UnprocessedReplaceable()
def test_enum(self):
# Test that enum values are kept, i.e. that OmegaConf's runtime checks
# are in use.
class A(Enum):
B1 = "b1"
B2 = "b2"
# Test for a Configurable class, a function, and a regular class.
class C(Configurable):
a: A = A.B1
# Also test for a calllable with enum arguments.
def C_fn(a: A = A.B1):
pass
enable_get_default_args(C_fn)
class C_cl:
def __init__(self, a: A = A.B1) -> None:
pass
enable_get_default_args(C_cl)
for C_ in [C, C_fn, C_cl]:
base = get_default_args(C_)
self.assertEqual(base.a, A.B1)
replaced = OmegaConf.merge(base, {"a": "B2"})
self.assertEqual(replaced.a, A.B2)
with self.assertRaises(ValidationError):
# You can't use a value which is not one of the
# choices, even if it is the str representation
# of one of the choices.
OmegaConf.merge(base, {"a": "b2"})
remerged = OmegaConf.merge(base, OmegaConf.create(OmegaConf.to_yaml(base)))
self.assertEqual(remerged.a, A.B1)
def test_pickle(self):
def func(a: int = 1, b: str = "3"):
pass
enable_get_default_args(func)
args = get_default_args(func)
args2 = pickle.loads(pickle.dumps(args))
self.assertEqual(args2.a, 1)
self.assertEqual(args2.b, "3")
args_regenerated = get_default_args(func)
pickle.dumps(args_regenerated)
pickle.dumps(args)
def test_remove_unused_components(self):
struct = get_default_args(MainTest)
struct.n_ids = 32
struct.the_fruit_class_type = "Pear"
struct.the_second_fruit_class_type = "Banana"
remove_unused_components(struct)
expected_keys = [
"n_ids",
"n_reps",
"the_fruit_Pear_args",
"the_fruit_class_type",
"the_second_fruit_Banana_args",
"the_second_fruit_class_type",
]
expected_yaml = textwrap.dedent(
"""\
n_ids: 32
n_reps: 8
the_fruit_class_type: Pear
the_fruit_Pear_args:
n_pips: 13
the_second_fruit_class_type: Banana
the_second_fruit_Banana_args:
pips: ???
spots: ???
bananame: ???
"""
)
self.assertEqual(sorted(struct.keys()), expected_keys)
# Check that struct is what we expect
expected = OmegaConf.create(expected_yaml)
self.assertEqual(struct, expected)
# Check that we get what we expect when writing to yaml.
self.assertEqual(OmegaConf.to_yaml(struct, sort_keys=False), expected_yaml)
main = MainTest(**struct)
instance_data = OmegaConf.structured(main)
remove_unused_components(instance_data)
self.assertEqual(sorted(instance_data.keys()), expected_keys)
self.assertEqual(instance_data, expected)
def test_remove_unused_components_optional(self):
class MainTestWrapper(Configurable):
mt: Optional[MainTest]
args = get_default_args(MainTestWrapper)
self.assertEqual(list(args.keys()), ["mt_args", "mt_enabled"])
remove_unused_components(args)
self.assertEqual(OmegaConf.to_yaml(args), "mt_enabled: false\n")
@dataclass(eq=False)
class MockDataclass:
field_no_default: int
field_primitive_type: int = 42
field_optional_none: Optional[int] = None
field_optional_with_value: Optional[int] = 42
field_list_type: List[int] = field(default_factory=lambda: [])
class RefObject:
pass
REF_OBJECT = RefObject()
class MockClassWithInit: # noqa: B903
def __init__(
self,
field_no_nothing,
field_no_default: int,
field_primitive_type: int = 42,
field_optional_none: Optional[int] = None,
field_optional_with_value: Optional[int] = 42,
field_list_type: List[int] = [], # noqa: B006
field_reference_type: RefObject = REF_OBJECT,
):
self.field_no_nothing = field_no_nothing
self.field_no_default = field_no_default
self.field_primitive_type = field_primitive_type
self.field_optional_none = field_optional_none
self.field_optional_with_value = field_optional_with_value
self.field_list_type = field_list_type
self.field_reference_type = field_reference_type
enable_get_default_args(MockClassWithInit)
class TestRawClasses(unittest.TestCase):
def setUp(self) -> None:
self._instances = {
MockDataclass: MockDataclass(field_no_default=0),
MockClassWithInit: MockClassWithInit(
field_no_nothing="tratata", field_no_default=0
),
}
def test_get_default_args(self):
for cls in [MockDataclass, MockClassWithInit]:
dataclass_defaults = get_default_args(cls)
# DictConfig fields with missing values are `not in`
self.assertNotIn("field_no_default", dataclass_defaults)
self.assertNotIn("field_no_nothing", dataclass_defaults)
self.assertNotIn("field_reference_type", dataclass_defaults)
if cls == MockDataclass: # we don't remove undefaulted from dataclasses
dataclass_defaults.field_no_default = 0
for name, val in dataclass_defaults.items():
self.assertTrue(hasattr(self._instances[cls], name))
self.assertEqual(val, getattr(self._instances[cls], name))
def test_get_default_args_readonly(self):
for cls in [MockDataclass, MockClassWithInit]:
dataclass_defaults = get_default_args(cls)
dataclass_defaults["field_list_type"].append(13)
self.assertEqual(self._instances[cls].field_list_type, [])