mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-14 11:26:24 +08:00
Return a typed structured config from default_args for callables
Summary: Before the fix, running get_default_args(C: Callable) returns an unstructured DictConfig which causes Enums to be handled incorrectly. This is a fix. WIP update: Currently tests still fail whenever a function signature contains an untyped argument: This needs to be somehow fixed. Reviewed By: bottler Differential Revision: D34932124 fbshipit-source-id: ecdc45c738633cfea5caa7480ba4f790ece931e8
This commit is contained in:
committed by
Facebook GitHub Bot
parent
8ac5e8f083
commit
645a47d054
@@ -8,7 +8,7 @@ import textwrap
|
||||
import unittest
|
||||
from dataclasses import dataclass, field, is_dataclass
|
||||
from enum import Enum
|
||||
from typing import List, Optional, Tuple
|
||||
from typing import Any, List, Optional, Set, Tuple
|
||||
|
||||
from omegaconf import DictConfig, ListConfig, OmegaConf, ValidationError
|
||||
from pytorch3d.implicitron.tools.config import (
|
||||
@@ -216,7 +216,7 @@ class TestConfig(unittest.TestCase):
|
||||
# tolerated. But it would be nice to be able to
|
||||
# configure them.
|
||||
class Foo:
|
||||
def __init__(self, a=1, b=2):
|
||||
def __init__(self, a: Any = 1, b: Any = 2):
|
||||
self.a, self.b = a, b
|
||||
|
||||
@dataclass()
|
||||
@@ -238,7 +238,6 @@ class TestConfig(unittest.TestCase):
|
||||
container_args = get_default_args(Container)
|
||||
container = Container(**container_args)
|
||||
self.assertIsInstance(container.fruit, Orange)
|
||||
# self.assertIsInstance(container.bar, Bar)
|
||||
|
||||
container_defaulted = Container()
|
||||
container_defaulted.fruit_Pear_args.n_pips += 4
|
||||
@@ -432,8 +431,13 @@ class TestConfig(unittest.TestCase):
|
||||
tuple_field: tuple = (3, True, "j")
|
||||
|
||||
class SimpleClass:
|
||||
def __init__(self, tuple_member_=(3, 4)):
|
||||
def __init__(
|
||||
self,
|
||||
tuple_member_: Tuple[int, int] = (3, 4),
|
||||
set_member_: Set[int] = {2}, # noqa
|
||||
):
|
||||
self.tuple_member = tuple_member_
|
||||
self.set_member = set_member_
|
||||
|
||||
def get_tuple(self):
|
||||
return self.tuple_member
|
||||
@@ -459,6 +463,9 @@ class TestConfig(unittest.TestCase):
|
||||
# OmegaConf converts tuples to ListConfigs (which act like lists).
|
||||
self.assertEqual(simple.get_tuple(), [3, 4])
|
||||
self.assertTrue(isinstance(simple.get_tuple(), ListConfig))
|
||||
# get_default_args converts sets to ListConfigs (which act like lists).
|
||||
self.assertEqual(simple.set_member, [2])
|
||||
self.assertTrue(isinstance(simple.set_member, ListConfig))
|
||||
self.assertEqual(c.a_tuple, [4.0, 3.0])
|
||||
self.assertTrue(isinstance(c.a_tuple, ListConfig))
|
||||
self.assertEqual(mydata.tuple_field, (3, True, "j"))
|
||||
@@ -514,20 +521,31 @@ class TestConfig(unittest.TestCase):
|
||||
B1 = "b1"
|
||||
B2 = "b2"
|
||||
|
||||
# Test for a Configurable class, a function, and a regular class.
|
||||
class C(Configurable):
|
||||
a: A = A.B1
|
||||
|
||||
base = get_default_args(C)
|
||||
replaced = OmegaConf.merge(base, {"a": "B2"})
|
||||
self.assertEqual(replaced.a, A.B2)
|
||||
with self.assertRaises(ValidationError):
|
||||
# You can't use a value which is not one of the
|
||||
# choices, even if it is the str representation
|
||||
# of one of the choices.
|
||||
OmegaConf.merge(base, {"a": "b2"})
|
||||
# Also test for a calllable with enum arguments.
|
||||
def C_fn(a: A = A.B1):
|
||||
pass
|
||||
|
||||
remerged = OmegaConf.merge(base, OmegaConf.create(OmegaConf.to_yaml(base)))
|
||||
self.assertEqual(remerged.a, A.B1)
|
||||
class C_cl:
|
||||
def __init__(self, a: A = A.B1) -> None:
|
||||
pass
|
||||
|
||||
for C_ in [C, C_fn, C_cl]:
|
||||
base = get_default_args(C_)
|
||||
self.assertEqual(base.a, A.B1)
|
||||
replaced = OmegaConf.merge(base, {"a": "B2"})
|
||||
self.assertEqual(replaced.a, A.B2)
|
||||
with self.assertRaises(ValidationError):
|
||||
# You can't use a value which is not one of the
|
||||
# choices, even if it is the str representation
|
||||
# of one of the choices.
|
||||
OmegaConf.merge(base, {"a": "b2"})
|
||||
|
||||
remerged = OmegaConf.merge(base, OmegaConf.create(OmegaConf.to_yaml(base)))
|
||||
self.assertEqual(remerged.a, A.B1)
|
||||
|
||||
def test_remove_unused_components(self):
|
||||
struct = get_default_args(MainTest)
|
||||
@@ -577,34 +595,56 @@ class TestConfig(unittest.TestCase):
|
||||
class MockDataclass:
|
||||
field_no_default: int
|
||||
field_primitive_type: int = 42
|
||||
field_reference_type: List[int] = field(default_factory=lambda: [])
|
||||
field_list_type: List[int] = field(default_factory=lambda: [])
|
||||
|
||||
|
||||
class RefObject:
|
||||
pass
|
||||
|
||||
|
||||
REF_OBJECT = RefObject()
|
||||
|
||||
|
||||
class MockClassWithInit: # noqa: B903
|
||||
def __init__(
|
||||
self,
|
||||
field_no_nothing,
|
||||
field_no_default: int,
|
||||
field_primitive_type: int = 42,
|
||||
field_reference_type: List[int] = [], # noqa: B006
|
||||
field_list_type: List[int] = [], # noqa: B006
|
||||
field_reference_type: RefObject = REF_OBJECT,
|
||||
):
|
||||
self.field_no_nothing = field_no_nothing
|
||||
self.field_no_default = field_no_default
|
||||
self.field_primitive_type = field_primitive_type
|
||||
self.field_list_type = field_list_type
|
||||
self.field_reference_type = field_reference_type
|
||||
|
||||
|
||||
class TestRawClasses(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
self._instances = {
|
||||
MockDataclass: MockDataclass(field_no_default=0),
|
||||
MockClassWithInit: MockClassWithInit(
|
||||
field_no_nothing="tratata", field_no_default=0
|
||||
),
|
||||
}
|
||||
|
||||
def test_get_default_args(self):
|
||||
for cls in [MockDataclass, MockClassWithInit]:
|
||||
dataclass_defaults = get_default_args(cls)
|
||||
inst = cls(field_no_default=0)
|
||||
dataclass_defaults.field_no_default = 0
|
||||
# DictConfig fields with missing values are `not in`
|
||||
self.assertNotIn("field_no_default", dataclass_defaults)
|
||||
self.assertNotIn("field_no_nothing", dataclass_defaults)
|
||||
self.assertNotIn("field_reference_type", dataclass_defaults)
|
||||
if cls == MockDataclass: # we don't remove undefaulted from dataclasses
|
||||
dataclass_defaults.field_no_default = 0
|
||||
for name, val in dataclass_defaults.items():
|
||||
self.assertTrue(hasattr(inst, name))
|
||||
self.assertEqual(val, getattr(inst, name))
|
||||
self.assertTrue(hasattr(self._instances[cls], name))
|
||||
self.assertEqual(val, getattr(self._instances[cls], name))
|
||||
|
||||
def test_get_default_args_readonly(self):
|
||||
for cls in [MockDataclass, MockClassWithInit]:
|
||||
dataclass_defaults = get_default_args(cls)
|
||||
dataclass_defaults["field_reference_type"].append(13)
|
||||
inst = cls(field_no_default=0)
|
||||
self.assertEqual(inst.field_reference_type, [])
|
||||
dataclass_defaults["field_list_type"].append(13)
|
||||
self.assertEqual(self._instances[cls].field_list_type, [])
|
||||
|
||||
Reference in New Issue
Block a user