Return a typed structured config from default_args for callables

Summary:
Before the fix, running get_default_args(C: Callable) returns an unstructured DictConfig which causes Enums to be handled incorrectly. This is a fix.

WIP update: Currently tests still fail whenever a function signature contains an untyped argument: This needs to be somehow fixed.

Reviewed By: bottler

Differential Revision: D34932124

fbshipit-source-id: ecdc45c738633cfea5caa7480ba4f790ece931e8
This commit is contained in:
Roman Shapovalov
2022-03-25 07:08:01 -07:00
committed by Facebook GitHub Bot
parent 8ac5e8f083
commit 645a47d054
4 changed files with 133 additions and 50 deletions

View File

@@ -8,7 +8,7 @@ import textwrap
import unittest
from dataclasses import dataclass, field, is_dataclass
from enum import Enum
from typing import List, Optional, Tuple
from typing import Any, List, Optional, Set, Tuple
from omegaconf import DictConfig, ListConfig, OmegaConf, ValidationError
from pytorch3d.implicitron.tools.config import (
@@ -216,7 +216,7 @@ class TestConfig(unittest.TestCase):
# tolerated. But it would be nice to be able to
# configure them.
class Foo:
def __init__(self, a=1, b=2):
def __init__(self, a: Any = 1, b: Any = 2):
self.a, self.b = a, b
@dataclass()
@@ -238,7 +238,6 @@ class TestConfig(unittest.TestCase):
container_args = get_default_args(Container)
container = Container(**container_args)
self.assertIsInstance(container.fruit, Orange)
# self.assertIsInstance(container.bar, Bar)
container_defaulted = Container()
container_defaulted.fruit_Pear_args.n_pips += 4
@@ -432,8 +431,13 @@ class TestConfig(unittest.TestCase):
tuple_field: tuple = (3, True, "j")
class SimpleClass:
def __init__(self, tuple_member_=(3, 4)):
def __init__(
self,
tuple_member_: Tuple[int, int] = (3, 4),
set_member_: Set[int] = {2}, # noqa
):
self.tuple_member = tuple_member_
self.set_member = set_member_
def get_tuple(self):
return self.tuple_member
@@ -459,6 +463,9 @@ class TestConfig(unittest.TestCase):
# OmegaConf converts tuples to ListConfigs (which act like lists).
self.assertEqual(simple.get_tuple(), [3, 4])
self.assertTrue(isinstance(simple.get_tuple(), ListConfig))
# get_default_args converts sets to ListConfigs (which act like lists).
self.assertEqual(simple.set_member, [2])
self.assertTrue(isinstance(simple.set_member, ListConfig))
self.assertEqual(c.a_tuple, [4.0, 3.0])
self.assertTrue(isinstance(c.a_tuple, ListConfig))
self.assertEqual(mydata.tuple_field, (3, True, "j"))
@@ -514,20 +521,31 @@ class TestConfig(unittest.TestCase):
B1 = "b1"
B2 = "b2"
# Test for a Configurable class, a function, and a regular class.
class C(Configurable):
a: A = A.B1
base = get_default_args(C)
replaced = OmegaConf.merge(base, {"a": "B2"})
self.assertEqual(replaced.a, A.B2)
with self.assertRaises(ValidationError):
# You can't use a value which is not one of the
# choices, even if it is the str representation
# of one of the choices.
OmegaConf.merge(base, {"a": "b2"})
# Also test for a calllable with enum arguments.
def C_fn(a: A = A.B1):
pass
remerged = OmegaConf.merge(base, OmegaConf.create(OmegaConf.to_yaml(base)))
self.assertEqual(remerged.a, A.B1)
class C_cl:
def __init__(self, a: A = A.B1) -> None:
pass
for C_ in [C, C_fn, C_cl]:
base = get_default_args(C_)
self.assertEqual(base.a, A.B1)
replaced = OmegaConf.merge(base, {"a": "B2"})
self.assertEqual(replaced.a, A.B2)
with self.assertRaises(ValidationError):
# You can't use a value which is not one of the
# choices, even if it is the str representation
# of one of the choices.
OmegaConf.merge(base, {"a": "b2"})
remerged = OmegaConf.merge(base, OmegaConf.create(OmegaConf.to_yaml(base)))
self.assertEqual(remerged.a, A.B1)
def test_remove_unused_components(self):
struct = get_default_args(MainTest)
@@ -577,34 +595,56 @@ class TestConfig(unittest.TestCase):
class MockDataclass:
field_no_default: int
field_primitive_type: int = 42
field_reference_type: List[int] = field(default_factory=lambda: [])
field_list_type: List[int] = field(default_factory=lambda: [])
class RefObject:
pass
REF_OBJECT = RefObject()
class MockClassWithInit: # noqa: B903
def __init__(
self,
field_no_nothing,
field_no_default: int,
field_primitive_type: int = 42,
field_reference_type: List[int] = [], # noqa: B006
field_list_type: List[int] = [], # noqa: B006
field_reference_type: RefObject = REF_OBJECT,
):
self.field_no_nothing = field_no_nothing
self.field_no_default = field_no_default
self.field_primitive_type = field_primitive_type
self.field_list_type = field_list_type
self.field_reference_type = field_reference_type
class TestRawClasses(unittest.TestCase):
def setUp(self) -> None:
self._instances = {
MockDataclass: MockDataclass(field_no_default=0),
MockClassWithInit: MockClassWithInit(
field_no_nothing="tratata", field_no_default=0
),
}
def test_get_default_args(self):
for cls in [MockDataclass, MockClassWithInit]:
dataclass_defaults = get_default_args(cls)
inst = cls(field_no_default=0)
dataclass_defaults.field_no_default = 0
# DictConfig fields with missing values are `not in`
self.assertNotIn("field_no_default", dataclass_defaults)
self.assertNotIn("field_no_nothing", dataclass_defaults)
self.assertNotIn("field_reference_type", dataclass_defaults)
if cls == MockDataclass: # we don't remove undefaulted from dataclasses
dataclass_defaults.field_no_default = 0
for name, val in dataclass_defaults.items():
self.assertTrue(hasattr(inst, name))
self.assertEqual(val, getattr(inst, name))
self.assertTrue(hasattr(self._instances[cls], name))
self.assertEqual(val, getattr(self._instances[cls], name))
def test_get_default_args_readonly(self):
for cls in [MockDataclass, MockClassWithInit]:
dataclass_defaults = get_default_args(cls)
dataclass_defaults["field_reference_type"].append(13)
inst = cls(field_no_default=0)
self.assertEqual(inst.field_reference_type, [])
dataclass_defaults["field_list_type"].append(13)
self.assertEqual(self._instances[cls].field_list_type, [])