mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
Summary: Applies new import merging and sorting from µsort v1.0. When merging imports, µsort will make a best-effort to move associated comments to match merged elements, but there are known limitations due to the dynamic nature of Python and developer tooling. These changes should not produce any dangerous runtime changes, but may require touch-ups to satisfy linters and other tooling. Note that µsort uses case-insensitive, lexicographical sorting, which results in a different ordering compared to isort. This provides a more consistent sorting order, matching the case-insensitive order used when sorting import statements by module name, and ensures that "frog", "FROG", and "Frog" always sort next to each other. For details on µsort's sorting and merging semantics, see the user guide: https://usort.readthedocs.io/en/stable/guide.html#sorting Reviewed By: bottler Differential Revision: D35553814 fbshipit-source-id: be49bdb6a4c25264ff8d4db3a601f18736d17be1
744 lines
24 KiB
Python
744 lines
24 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the BSD-style license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
import pickle
|
|
import textwrap
|
|
import unittest
|
|
from dataclasses import dataclass, field, is_dataclass
|
|
from enum import Enum
|
|
from typing import Any, List, Optional, Set, Tuple
|
|
|
|
from omegaconf import DictConfig, ListConfig, OmegaConf, ValidationError
|
|
from pytorch3d.implicitron.tools.config import (
|
|
_get_type_to_process,
|
|
_is_actually_dataclass,
|
|
_ProcessType,
|
|
_Registry,
|
|
Configurable,
|
|
enable_get_default_args,
|
|
expand_args_fields,
|
|
get_default_args,
|
|
get_default_args_field,
|
|
registry,
|
|
remove_unused_components,
|
|
ReplaceableBase,
|
|
run_auto_creation,
|
|
)
|
|
|
|
|
|
@dataclass
class Animal(ReplaceableBase):
    # A second replaceable base, unrelated to Fruit. Nothing in this module
    # registers a subclass of it, so lookups of fruit names against Animal
    # are expected to fail.
    ...
|
|
|
|
|
|
class Fruit(ReplaceableBase):
    # Replaceable base for the registered fruit implementations in this
    # module; MainTest and many tests declare members of this type.
    ...
|
|
|
|
|
|
@registry.register
class Banana(Fruit):
    # Registered Fruit whose fields have no defaults: get_default_args reports
    # them as missing ("???"). Keep the declaration order, because the
    # generated YAML lists the fields in this order.
    pips: int
    spots: int
    bananame: str
|
|
|
|
|
|
@registry.register
class Pear(Fruit):
    # Registered Fruit with a single defaulted field. LargePear subclasses it,
    # and several tests check that n_pips defaults to 13.
    n_pips: int = 13
|
|
|
|
|
|
class Pineapple(Fruit):
    # Not registered: MainTest.create_the_second_fruit builds it directly
    # instead of going through the registry.
    ...
|
|
|
|
|
|
@registry.register
class Orange(Fruit):
    # Field-less registered Fruit; tests use "Orange" as the default
    # class_type of Fruit members.
    ...
|
|
|
|
|
|
@registry.register
class Kiwi(Fruit):
    # Field-less registered Fruit; test_inheritance uses "Kiwi" as the default
    # class_type of an Optional[Fruit] member.
    ...
|
|
|
|
|
|
@registry.register
class LargePear(Pear):
    # Registered subclass of a registered Fruit. The tests expect
    # registry.get_all(Pear) to return exactly {LargePear}.
    ...
|
|
|
|
|
|
class MainTest(Configurable):
    """Configurable holding two replaceable ``Fruit`` members.

    ``create_the_fruit`` is not defined here; the config machinery supplies
    it. ``create_the_second_fruit`` is overridden below and always builds a
    ``Pineapple``, whatever ``the_second_fruit_class_type`` says.
    """

    the_fruit: Fruit
    n_ids: int
    n_reps: int = 8
    the_second_fruit: Fruit

    def create_the_second_fruit(self):
        # Custom creation function that ignores the configured class type.
        # Pineapple must be processed by expand_args_fields before it is
        # instantiated.
        expand_args_fields(Pineapple)
        self.the_second_fruit = Pineapple()

    def __post_init__(self):
        # Builds the_fruit and the_second_fruit via their creation functions.
        run_auto_creation(self)
|
|
|
|
|
|
class TestConfig(unittest.TestCase):
    """Tests for Configurable / ReplaceableBase processing, the registry,
    get_default_args and remove_unused_components.

    Several tests define classes locally and register them with the
    module-level ``registry`` as a side effect.
    """

    def test_is_actually_dataclass(self):
        @dataclass
        class A:
            pass

        self.assertTrue(_is_actually_dataclass(A))
        self.assertTrue(is_dataclass(A))

        # B inherits A's dataclass machinery but is not itself decorated.
        class B(A):
            a: int

        self.assertFalse(_is_actually_dataclass(B))
        self.assertTrue(is_dataclass(B))

    def test_get_type_to_process(self):
        # Only Configurable / ReplaceableBase types, optionally wrapped in
        # Optional, are processed; everything else maps to None.
        gt = _get_type_to_process
        self.assertIsNone(gt(int))
        self.assertEqual(gt(Fruit), (Fruit, _ProcessType.REPLACEABLE))
        self.assertEqual(
            gt(Optional[Fruit]), (Fruit, _ProcessType.OPTIONAL_REPLACEABLE)
        )
        self.assertEqual(gt(MainTest), (MainTest, _ProcessType.CONFIGURABLE))
        self.assertEqual(
            gt(Optional[MainTest]), (MainTest, _ProcessType.OPTIONAL_CONFIGURABLE)
        )
        self.assertIsNone(gt(Optional[int]))
        self.assertIsNone(gt(Tuple[Fruit]))
        self.assertIsNone(gt(Tuple[Fruit, Animal]))
        self.assertIsNone(gt(Optional[List[int]]))

    def test_simple_replacement(self):
        struct = get_default_args(MainTest)
        struct.n_ids = 9780
        struct.the_fruit_Pear_args.n_pips = 3
        struct.the_fruit_class_type = "Pear"
        struct.the_second_fruit_class_type = "Pear"

        main = MainTest(**struct)
        self.assertIsInstance(main.the_fruit, Pear)
        self.assertEqual(main.n_reps, 8)
        self.assertEqual(main.n_ids, 9780)
        self.assertEqual(main.the_fruit.n_pips, 3)
        # The custom create_the_second_fruit ignores the "Pear" class type.
        self.assertIsInstance(main.the_second_fruit, Pineapple)

        # Editing the first struct must not change freshly generated defaults.
        struct2 = get_default_args(MainTest)
        self.assertEqual(struct2.the_fruit_Pear_args.n_pips, 13)

        self.assertEqual(
            MainTest._creation_functions,
            ("create_the_fruit", "create_the_second_fruit"),
        )

    def test_detect_bases(self):
        # testing the _base_class_from_class function
        self.assertIsNone(_Registry._base_class_from_class(ReplaceableBase))
        self.assertIsNone(_Registry._base_class_from_class(MainTest))
        self.assertIs(_Registry._base_class_from_class(Fruit), Fruit)
        self.assertIs(_Registry._base_class_from_class(Pear), Fruit)

        class PricklyPear(Pear):
            pass

        self.assertIs(_Registry._base_class_from_class(PricklyPear), Fruit)

    def test_registry_entries(self):
        self.assertIs(registry.get(Fruit, "Banana"), Banana)
        with self.assertRaisesRegex(ValueError, "Banana has not been registered."):
            registry.get(Animal, "Banana")
        with self.assertRaisesRegex(ValueError, "PricklyPear has not been registered."):
            registry.get(Fruit, "PricklyPear")

        self.assertIs(registry.get(Pear, "Pear"), Pear)
        self.assertIs(registry.get(Pear, "LargePear"), LargePear)
        with self.assertRaisesRegex(ValueError, "Banana resolves to"):
            registry.get(Pear, "Banana")

        all_fruit = set(registry.get_all(Fruit))
        self.assertIn(Banana, all_fruit)
        self.assertIn(Pear, all_fruit)
        self.assertIn(LargePear, all_fruit)
        self.assertEqual(set(registry.get_all(Pear)), {LargePear})

        @registry.register
        class Apple(Fruit):
            pass

        @registry.register
        class CrabApple(Apple):
            pass

        self.assertEqual(set(registry.get_all(Apple)), {CrabApple})

        self.assertIs(registry.get(Fruit, "CrabApple"), CrabApple)

        # A class with no ReplaceableBase ancestor cannot be registered.
        with self.assertRaisesRegex(ValueError, "Cannot tell what it is."):

            @registry.register
            class NotAFruit:
                pass

    def test_recursion(self):
        class Shape(ReplaceableBase):
            pass

        @registry.register
        class Triangle(Shape):
            a: float = 5.0

        @registry.register
        class Square(Shape):
            a: float = 3.0

        @registry.register
        class LargeShape(Shape):
            inner: Shape

            def __post_init__(self):
                run_auto_creation(self)

        class ShapeContainer(Configurable):
            shape: Shape

        container = ShapeContainer(**get_default_args(ShapeContainer))
        # This is because ShapeContainer is missing __post_init__
        with self.assertRaises(AttributeError):
            container.shape

        class ShapeContainer2(Configurable):
            x: Shape
            x_class_type: str = "LargeShape"

            def __post_init__(self):
                self.x_LargeShape_args.inner_class_type = "Triangle"
                run_auto_creation(self)

        container2_args = get_default_args(ShapeContainer2)
        container2_args.x_LargeShape_args.inner_Triangle_args.a += 10
        self.assertIn("inner_Square_args", container2_args.x_LargeShape_args)
        # We do not perform expansion that would result in an infinite recursion,
        # so this member is not present.
        self.assertNotIn("inner_LargeShape_args", container2_args.x_LargeShape_args)
        container2_args.x_LargeShape_args.inner_Square_args.a += 100
        container2 = ShapeContainer2(**container2_args)
        self.assertIsInstance(container2.x, LargeShape)
        self.assertIsInstance(container2.x.inner, Triangle)
        self.assertEqual(container2.x.inner.a, 15.0)

    def test_simpleclass_member(self):
        # Members which are not dataclasses are
        # tolerated. But it would be nice to be able to
        # configure them.
        class Foo:
            def __init__(self, a: Any = 1, b: Any = 2):
                self.a, self.b = a, b

        enable_get_default_args(Foo)

        @dataclass()
        class Bar:
            aa: int = 9
            bb: int = 9

        class Container(Configurable):
            bar: Bar = Bar()
            # TODO make this work?
            # foo: Foo = Foo()
            fruit: Fruit
            fruit_class_type: str = "Orange"

            def __post_init__(self):
                run_auto_creation(self)

        self.assertEqual(get_default_args(Foo), {"a": 1, "b": 2})
        container_args = get_default_args(Container)
        container = Container(**container_args)
        self.assertIsInstance(container.fruit, Orange)
        self.assertEqual(Container._processed_members, {"fruit": Fruit})
        self.assertEqual(container._processed_members, {"fruit": Fruit})

        # Mutating one instance's args must not leak into new defaults.
        container_defaulted = Container()
        container_defaulted.fruit_Pear_args.n_pips += 4

        container_args2 = get_default_args(Container)
        container = Container(**container_args2)
        self.assertEqual(container.fruit_Pear_args.n_pips, 13)

    def test_inheritance(self):
        # Also exercises optional replaceables
        class FruitBowl(ReplaceableBase):
            main_fruit: Fruit
            main_fruit_class_type: str = "Orange"

            def __post_init__(self):
                raise ValueError("This doesn't get called")

        class LargeFruitBowl(FruitBowl):
            extra_fruit: Optional[Fruit]
            extra_fruit_class_type: str = "Kiwi"
            no_fruit: Optional[Fruit]
            no_fruit_class_type: Optional[str] = None

            def __post_init__(self):
                run_auto_creation(self)

        large_args = get_default_args(LargeFruitBowl)
        self.assertNotIn("extra_fruit", large_args)
        self.assertNotIn("main_fruit", large_args)
        large = LargeFruitBowl(**large_args)
        self.assertIsInstance(large.main_fruit, Orange)
        self.assertIsInstance(large.extra_fruit, Kiwi)
        self.assertIsNone(large.no_fruit)
        self.assertIn("no_fruit_Kiwi_args", large_args)

        remove_unused_components(large_args)
        large2 = LargeFruitBowl(**large_args)
        self.assertIsInstance(large2.main_fruit, Orange)
        self.assertIsInstance(large2.extra_fruit, Kiwi)
        self.assertIsNone(large2.no_fruit)
        needed_args = [
            "extra_fruit_Kiwi_args",
            "extra_fruit_class_type",
            "main_fruit_Orange_args",
            "main_fruit_class_type",
            "no_fruit_class_type",
        ]
        self.assertEqual(sorted(large_args.keys()), needed_args)

    def test_inheritance2(self):
        # This is a case where a class could contain an instance
        # of a subclass, which is ignored.
        class Parent(ReplaceableBase):
            pass

        class Main(Configurable):
            parent: Parent
            # Note - no __post_init__, so run_auto_creation is called manually.

        @registry.register
        class Derived(Parent, Main):
            pass

        args = get_default_args(Main)
        # Derived has been ignored in processing Main.
        self.assertCountEqual(args.keys(), ["parent_class_type"])

        main = Main(**args)

        with self.assertRaisesRegex(ValueError, "UNDEFAULTED has not been registered."):
            run_auto_creation(main)

        main.parent_class_type = "Derived"
        # Illustrates that a dict works fine instead of a DictConfig.
        main.parent_Derived_args = {}
        with self.assertRaises(AttributeError):
            main.parent
        run_auto_creation(main)
        self.assertIsInstance(main.parent, Derived)

    def test_redefine(self):
        class FruitBowl(ReplaceableBase):
            main_fruit: Fruit
            main_fruit_class_type: str = "Grape"

            def __post_init__(self):
                run_auto_creation(self)

        @registry.register
        @dataclass
        class Grape(Fruit):
            large: bool = False

            def get_color(self):
                return "red"

            def __post_init__(self):
                raise ValueError("This doesn't get called")

        bowl_args = get_default_args(FruitBowl)

        # 1. Redefine Grape, again with the dataclass decorator.
        @registry.register
        @dataclass
        class Grape(Fruit):  # noqa: F811
            large: bool = True

            def get_color(self):
                return "green"

        with self.assertWarnsRegex(
            UserWarning, "New implementation of Grape is being chosen."
        ):
            bowl = FruitBowl(**bowl_args)
        self.assertIsInstance(bowl.main_fruit, Grape)

        # Redefining the same class won't help with defaults because encoded in args
        self.assertEqual(bowl.main_fruit.large, False)

        # But the override worked.
        self.assertEqual(bowl.main_fruit.get_color(), "green")

        # 2. Try redefining without the dataclass modifier
        # This relies on the fact that default creation processes the class.
        # (otherwise incomprehensible messages)
        @registry.register
        class Grape(Fruit):  # noqa: F811
            large: bool = True

        with self.assertWarnsRegex(
            UserWarning, "New implementation of Grape is being chosen."
        ):
            bowl = FruitBowl(**bowl_args)

        # 3. Adding a new class doesn't get picked up, because the first
        # get_default_args call has frozen FruitBowl. This is intrinsic to
        # the way dataclass and expand_args_fields work in-place but
        # expand_args_fields is not pure - it depends on the registry.
        @registry.register
        class Fig(Fruit):
            pass

        bowl_args2 = get_default_args(FruitBowl)
        self.assertIn("main_fruit_Grape_args", bowl_args2)
        self.assertNotIn("main_fruit_Fig_args", bowl_args2)

        # TODO Is it possible to make this work?
        # bowl_args2["main_fruit_Fig_args"] = get_default_args(Fig)
        # bowl_args2.main_fruit_class_type = "Fig"
        # bowl2 = FruitBowl(**bowl_args2) <= unexpected argument

        # Note that it is possible to use Fig if you can set
        # bowl2.main_fruit_Fig_args explicitly (not in bowl_args2)
        # before run_auto_creation happens. See test_inheritance2
        # for an example.

    def test_no_replacement(self):
        # Test of Configurables without ReplaceableBase
        class A(Configurable):
            n: int = 9

        class B(Configurable):
            a: A

            def __post_init__(self):
                run_auto_creation(self)

        class C(Configurable):
            b1: B
            b2: Optional[B]
            b3: Optional[B]
            b2_enabled: bool = True

            def __post_init__(self):
                run_auto_creation(self)

        c_args = get_default_args(C)
        c = C(**c_args)
        self.assertIsInstance(c.b1.a, A)
        self.assertEqual(c.b1.a.n, 9)
        self.assertFalse(hasattr(c, "b1_enabled"))
        self.assertIsInstance(c.b2.a, A)
        self.assertEqual(c.b2.a.n, 9)
        self.assertTrue(c.b2_enabled)
        self.assertIsNone(c.b3)
        self.assertFalse(c.b3_enabled)

    def test_doc(self):
        # The case in the docstring.
        class A(ReplaceableBase):
            k: int = 1

        @registry.register
        class A1(A):
            m: int = 3

        @registry.register
        class A2(A):
            n: str = "2"

        class B(Configurable):
            a: A
            a_class_type: str = "A2"

            def __post_init__(self):
                run_auto_creation(self)

        b_args = get_default_args(B)
        self.assertNotIn("a", b_args)
        b = B(**b_args)
        self.assertEqual(b.a.n, "2")

    def test_raw_types(self):
        @dataclass
        class MyDataclass:
            int_field: int = 0
            none_field: Optional[int] = None
            float_field: float = 9.3
            bool_field: bool = True
            tuple_field: tuple = (3, True, "j")

        class SimpleClass:
            def __init__(
                self,
                tuple_member_: Tuple[int, int] = (3, 4),
                set_member_: Set[int] = {2},  # noqa
            ):
                self.tuple_member = tuple_member_
                self.set_member = set_member_

            def get_tuple(self):
                return self.tuple_member

        enable_get_default_args(SimpleClass)

        def f(*, a: int = 3, b: str = "kj"):
            self.assertEqual(a, 3)
            self.assertEqual(b, "kj")

        enable_get_default_args(f)

        class C(Configurable):
            simple: DictConfig = get_default_args_field(SimpleClass)
            # simple2: SimpleClass2 = SimpleClass2()
            mydata: DictConfig = get_default_args_field(MyDataclass)
            a_tuple: Tuple[float] = (4.0, 3.0)
            f_args: DictConfig = get_default_args_field(f)

        args = get_default_args(C)
        c = C(**args)
        self.assertCountEqual(args.keys(), ["simple", "mydata", "a_tuple", "f_args"])

        mydata = MyDataclass(**c.mydata)
        simple = SimpleClass(**c.simple)

        # OmegaConf converts tuples to ListConfigs (which act like lists).
        self.assertEqual(simple.get_tuple(), [3, 4])
        self.assertIsInstance(simple.get_tuple(), ListConfig)
        # get_default_args converts sets to ListConfigs (which act like lists).
        self.assertEqual(simple.set_member, [2])
        self.assertIsInstance(simple.set_member, ListConfig)
        self.assertEqual(c.a_tuple, [4.0, 3.0])
        self.assertIsInstance(c.a_tuple, ListConfig)
        self.assertEqual(mydata.tuple_field, (3, True, "j"))
        self.assertIsInstance(mydata.tuple_field, ListConfig)
        # f asserts that it receives its own defaults back.
        f(**c.f_args)

    def test_irrelevant_bases(self):
        class NotADataclass:
            # Like torch.nn.Module, this class contains annotations
            # but is not designed to be dataclass'd.
            # This test ensures that such classes, when inherited from,
            # are not accidentally affected by expand_args_fields.
            a: int = 9
            b: int

        class LeftConfigured(Configurable, NotADataclass):
            left: int = 1

        class RightConfigured(NotADataclass, Configurable):
            right: int = 2

        class Outer(Configurable):
            left: LeftConfigured
            right: RightConfigured

            def __post_init__(self):
                run_auto_creation(self)

        outer = Outer(**get_default_args(Outer))
        self.assertEqual(outer.left.left, 1)
        self.assertEqual(outer.right.right, 2)
        # NotADataclass itself is still not a valid dataclass.
        with self.assertRaisesRegex(TypeError, "non-default argument"):
            dataclass(NotADataclass)

    def test_unprocessed(self):
        # Instantiating Configurable / ReplaceableBase classes that were never
        # processed (no get_default_args or expand_args_fields call) warns.
        class Unprocessed(Configurable):
            a: int = 9

        class UnprocessedReplaceable(ReplaceableBase):
            a: int = 1

        with self.assertWarnsRegex(UserWarning, "must be processed"):
            Unprocessed()
        with self.assertWarnsRegex(UserWarning, "must be processed"):
            UnprocessedReplaceable()

    def test_enum(self):
        # Test that enum values are kept, i.e. that OmegaConf's runtime checks
        # are in use.

        class A(Enum):
            B1 = "b1"
            B2 = "b2"

        # Test for a Configurable class, a function, and a regular class.
        class C(Configurable):
            a: A = A.B1

        # Also test for a callable with enum arguments.
        def C_fn(a: A = A.B1):
            pass

        enable_get_default_args(C_fn)

        class C_cl:
            def __init__(self, a: A = A.B1) -> None:
                pass

        enable_get_default_args(C_cl)

        for C_ in [C, C_fn, C_cl]:
            base = get_default_args(C_)
            self.assertEqual(base.a, A.B1)
            replaced = OmegaConf.merge(base, {"a": "B2"})
            self.assertEqual(replaced.a, A.B2)
            with self.assertRaises(ValidationError):
                # You can't use a value which is not one of the
                # choices, even if it is the str representation
                # of one of the choices.
                OmegaConf.merge(base, {"a": "b2"})

            remerged = OmegaConf.merge(base, OmegaConf.create(OmegaConf.to_yaml(base)))
            self.assertEqual(remerged.a, A.B1)

    def test_pickle(self):
        def func(a: int = 1, b: str = "3"):
            pass

        enable_get_default_args(func)

        args = get_default_args(func)
        args2 = pickle.loads(pickle.dumps(args))
        self.assertEqual(args2.a, 1)
        self.assertEqual(args2.b, "3")

        # Both a regenerated config and the original must stay picklable.
        args_regenerated = get_default_args(func)
        pickle.dumps(args_regenerated)
        pickle.dumps(args)

    def test_remove_unused_components(self):
        struct = get_default_args(MainTest)
        struct.n_ids = 32
        struct.the_fruit_class_type = "Pear"
        struct.the_second_fruit_class_type = "Banana"
        remove_unused_components(struct)
        expected_keys = [
            "n_ids",
            "n_reps",
            "the_fruit_Pear_args",
            "the_fruit_class_type",
            "the_second_fruit_Banana_args",
            "the_second_fruit_class_type",
        ]
        expected_yaml = textwrap.dedent(
            """\
            n_ids: 32
            n_reps: 8
            the_fruit_class_type: Pear
            the_fruit_Pear_args:
              n_pips: 13
            the_second_fruit_class_type: Banana
            the_second_fruit_Banana_args:
              pips: ???
              spots: ???
              bananame: ???
            """
        )
        self.assertEqual(sorted(struct.keys()), expected_keys)

        # Check that struct is what we expect
        expected = OmegaConf.create(expected_yaml)
        self.assertEqual(struct, expected)

        # Check that we get what we expect when writing to yaml.
        self.assertEqual(OmegaConf.to_yaml(struct, sort_keys=False), expected_yaml)

        main = MainTest(**struct)
        instance_data = OmegaConf.structured(main)
        remove_unused_components(instance_data)
        self.assertEqual(sorted(instance_data.keys()), expected_keys)
        self.assertEqual(instance_data, expected)

    def test_remove_unused_components_optional(self):
        class MainTestWrapper(Configurable):
            mt: Optional[MainTest]

        args = get_default_args(MainTestWrapper)
        self.assertEqual(list(args.keys()), ["mt_args", "mt_enabled"])
        remove_unused_components(args)
        self.assertEqual(OmegaConf.to_yaml(args), "mt_enabled: false\n")
|
|
|
|
|
|
@dataclass(eq=False)
class MockDataclass:
    """Plain dataclass covering the field kinds get_default_args must handle.

    Fields: one without a default, a primitive default, ``Optional`` with and
    without a value, and a list built by a default factory. ``eq=False``
    keeps the default identity-based equality and hashing.
    """

    field_no_default: int
    field_primitive_type: int = 42
    field_optional_none: Optional[int] = None
    field_optional_with_value: Optional[int] = 42
    # ``list`` is the idiomatic factory; each instance gets its own empty list.
    field_list_type: List[int] = field(default_factory=list)
|
|
|
|
|
|
class RefObject:
    # Arbitrary non-primitive type, used only as the type of a default value.
    ...


# Shared instance used as the default of
# MockClassWithInit.field_reference_type.
REF_OBJECT = RefObject()
|
|
|
|
|
|
class MockClassWithInit:  # noqa: B903
    """Plain class whose ``__init__`` signature mixes many parameter kinds.

    The parameters cover:

    * an unannotated parameter with no default;
    * an annotated parameter with no default;
    * primitive and ``Optional`` defaults;
    * a mutable list default;
    * a default of a non-primitive type (``RefObject``).

    Every argument is stored unchanged on an attribute of the same name.
    """

    def __init__(
        self,
        field_no_nothing,
        field_no_default: int,
        field_primitive_type: int = 42,
        field_optional_none: Optional[int] = None,
        field_optional_with_value: Optional[int] = 42,
        field_list_type: List[int] = [],  # noqa: B006
        field_reference_type: RefObject = REF_OBJECT,
    ):
        # The shared ``[]`` default is intentional (hence the noqa). The
        # readonly test relies on it to check that mutating get_default_args
        # output does not touch this default.
        attributes = {
            "field_no_nothing": field_no_nothing,
            "field_no_default": field_no_default,
            "field_primitive_type": field_primitive_type,
            "field_optional_none": field_optional_none,
            "field_optional_with_value": field_optional_with_value,
            "field_list_type": field_list_type,
            "field_reference_type": field_reference_type,
        }
        for attr_name, attr_value in attributes.items():
            setattr(self, attr_name, attr_value)


enable_get_default_args(MockClassWithInit)
|
|
|
|
|
|
class TestRawClasses(unittest.TestCase):
    """get_default_args on a plain dataclass and on a class with ``__init__``."""

    def setUp(self) -> None:
        # One reference instance per class, built with explicit values for
        # the arguments that have no default.
        self._instances = {
            MockDataclass: MockDataclass(field_no_default=0),
            MockClassWithInit: MockClassWithInit(
                field_no_nothing="tratata", field_no_default=0
            ),
        }

    def test_get_default_args(self):
        for cls in [MockDataclass, MockClassWithInit]:
            # subTest reports which class failed instead of stopping silently
            # at the first one.
            with self.subTest(cls=cls.__name__):
                dataclass_defaults = get_default_args(cls)
                # DictConfig fields with missing values are `not in`
                self.assertNotIn("field_no_default", dataclass_defaults)
                self.assertNotIn("field_no_nothing", dataclass_defaults)
                self.assertNotIn("field_reference_type", dataclass_defaults)
                # We don't remove undefaulted fields from dataclasses, so fill
                # it in to match the reference instance.
                if cls is MockDataclass:
                    dataclass_defaults.field_no_default = 0
                for name, val in dataclass_defaults.items():
                    self.assertTrue(hasattr(self._instances[cls], name))
                    self.assertEqual(val, getattr(self._instances[cls], name))

    def test_get_default_args_readonly(self):
        # Mutating the returned config must not modify the default value held
        # by already constructed instances.
        for cls in [MockDataclass, MockClassWithInit]:
            with self.subTest(cls=cls.__name__):
                dataclass_defaults = get_default_args(cls)
                dataclass_defaults["field_list_type"].append(13)
                self.assertEqual(self._instances[cls].field_list_type, [])
|