Return a typed structured config from default_args for callables

Summary:
Before the fix, running get_default_args(C: Callable) returns an unstructured DictConfig which causes Enums to be handled incorrectly. This is a fix.

WIP update: Currently tests still fail whenever a function signature contains an untyped argument: This needs to be somehow fixed.

Reviewed By: bottler

Differential Revision: D34932124

fbshipit-source-id: ecdc45c738633cfea5caa7480ba4f790ece931e8
This commit is contained in:
Roman Shapovalov 2022-03-25 07:08:01 -07:00 committed by Facebook GitHub Bot
parent 8ac5e8f083
commit 645a47d054
4 changed files with 133 additions and 50 deletions

View File

@ -213,7 +213,7 @@ def init_optimizer(
model: GenericModel,
optimizer_state: Optional[Dict[str, Any]],
last_epoch: int,
breed: bool = "adam",
breed: str = "adam",
weight_decay: float = 0.0,
lr_policy: str = "multistep",
lr: float = 0.0005,

View File

@ -2,6 +2,9 @@
# Adapted from RenderingNetwork from IDR
# https://github.com/lioryariv/idr/
# Copyright (c) 2020 Lior Yariv
from typing import List, Tuple
import torch
from pytorch3d.renderer.implicit import HarmonicEmbedding, RayBundle
from torch import nn
@ -10,38 +13,38 @@ from torch import nn
class RayNormalColoringNetwork(torch.nn.Module):
def __init__(
self,
feature_vector_size=3,
mode="idr",
d_in=9,
d_out=3,
dims=(512, 512, 512, 512),
weight_norm=True,
n_harmonic_functions_dir=0,
pooled_feature_dim=0,
):
feature_vector_size: int = 3,
mode: str = "idr",
d_in: int = 9,
d_out: int = 3,
dims: Tuple[int, ...] = (512, 512, 512, 512),
weight_norm: bool = True,
n_harmonic_functions_dir: int = 0,
pooled_feature_dim: int = 0,
) -> None:
super().__init__()
self.mode = mode
self.output_dimensions = d_out
dims = [d_in + feature_vector_size] + list(dims) + [d_out]
dims_full: List[int] = [d_in + feature_vector_size] + list(dims) + [d_out]
self.embedview_fn = None
if n_harmonic_functions_dir > 0:
self.embedview_fn = HarmonicEmbedding(
n_harmonic_functions_dir, append_input=True
)
dims[0] += self.embedview_fn.get_output_dim() - 3
dims_full[0] += self.embedview_fn.get_output_dim() - 3
if pooled_feature_dim > 0:
print("Pooled features in rendering network.")
dims[0] += pooled_feature_dim
dims_full[0] += pooled_feature_dim
self.num_layers = len(dims)
self.num_layers = len(dims_full)
layers = []
for layer_idx in range(self.num_layers - 1):
out_dim = dims[layer_idx + 1]
lin = nn.Linear(dims[layer_idx], out_dim)
out_dim = dims_full[layer_idx + 1]
lin = nn.Linear(dims_full[layer_idx], out_dim)
if weight_norm:
lin = nn.utils.weight_norm(lin)

View File

@ -4,12 +4,12 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import copy
import dataclasses
import inspect
import itertools
import warnings
from collections import Counter, defaultdict
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, TypeVar, cast
from omegaconf import DictConfig, OmegaConf, open_dict
@ -409,13 +409,13 @@ def get_default_args(C, *, _do_not_process: Tuple[type, ...] = ()) -> DictConfig
# straight away if C has already been processed.
expand_args_fields(C, _do_not_process=_do_not_process)
kwargs = {}
if dataclasses.is_dataclass(C):
# Note that if get_default_args_field is used somewhere in C,
# this call is recursive. No special care is needed,
# because in practice get_default_args_field is used for
# separate types than the outer type.
out = OmegaConf.structured(C)
out: DictConfig = OmegaConf.structured(C)
exclude = getattr(C, "_processed_members", ())
with open_dict(out):
for field in exclude:
@ -425,16 +425,56 @@ def get_default_args(C, *, _do_not_process: Tuple[type, ...] = ()) -> DictConfig
if _is_configurable_class(C):
raise ValueError(f"Failed to process {C}")
# returns dict of keyword args of a callable C
sig = inspect.signature(C)
for pname, defval in dict(sig.parameters).items():
if defval.default == inspect.Parameter.empty:
# print('skipping %s' % pname)
# regular class or function
field_annotations = []
for pname, defval in _params_iter(C):
default = defval.default
if default == inspect.Parameter.empty:
# we do not have a default value for the parameter
continue
if defval.annotation == inspect._empty:
raise ValueError(
"All arguments of the input callable have to be typed."
+ f" Argument '{pname}' does not have a type annotation."
)
if isinstance(default, set): # force OmegaConf to convert it to ListConfig
default = tuple(default)
if isinstance(default, (list, dict)):
# OmegaConf will convert to [Dict|List]Config, so it is safe to reuse the value
field_ = dataclasses.field(default_factory=lambda default=default: default)
elif not _is_immutable_type(defval.annotation, default):
continue
else:
kwargs[pname] = copy.deepcopy(defval.default)
# we can use a simple default argument for dataclass.field
field_ = dataclasses.field(default=default)
field_annotations.append((pname, defval.annotation, field_))
return DictConfig(kwargs)
# make a temp dataclass and generate a structured config from it.
return OmegaConf.structured(
dataclasses.make_dataclass(f"__{C.__name__}_default_args__", field_annotations)
)
def _params_iter(C):
"""Returns dict of keyword args of a class or function C."""
if inspect.isclass(C):
return itertools.islice( # exclude `self`
inspect.signature(C.__init__).parameters.items(), 1, None
)
return inspect.signature(C).parameters.items()
def _is_immutable_type(type_: Type, val: Any) -> bool:
PRIMITIVE_TYPES = (int, float, bool, str, bytes, tuple)
# sometimes type can be too relaxed (e.g. Any), so we also check values
if isinstance(val, PRIMITIVE_TYPES):
return True
return type_ in PRIMITIVE_TYPES or issubclass(type_, Enum)
def _is_actually_dataclass(some_class) -> bool:

View File

@ -8,7 +8,7 @@ import textwrap
import unittest
from dataclasses import dataclass, field, is_dataclass
from enum import Enum
from typing import List, Optional, Tuple
from typing import Any, List, Optional, Set, Tuple
from omegaconf import DictConfig, ListConfig, OmegaConf, ValidationError
from pytorch3d.implicitron.tools.config import (
@ -216,7 +216,7 @@ class TestConfig(unittest.TestCase):
# tolerated. But it would be nice to be able to
# configure them.
class Foo:
def __init__(self, a=1, b=2):
def __init__(self, a: Any = 1, b: Any = 2):
self.a, self.b = a, b
@dataclass()
@ -238,7 +238,6 @@ class TestConfig(unittest.TestCase):
container_args = get_default_args(Container)
container = Container(**container_args)
self.assertIsInstance(container.fruit, Orange)
# self.assertIsInstance(container.bar, Bar)
container_defaulted = Container()
container_defaulted.fruit_Pear_args.n_pips += 4
@ -432,8 +431,13 @@ class TestConfig(unittest.TestCase):
tuple_field: tuple = (3, True, "j")
class SimpleClass:
def __init__(self, tuple_member_=(3, 4)):
def __init__(
self,
tuple_member_: Tuple[int, int] = (3, 4),
set_member_: Set[int] = {2}, # noqa
):
self.tuple_member = tuple_member_
self.set_member = set_member_
def get_tuple(self):
return self.tuple_member
@ -459,6 +463,9 @@ class TestConfig(unittest.TestCase):
# OmegaConf converts tuples to ListConfigs (which act like lists).
self.assertEqual(simple.get_tuple(), [3, 4])
self.assertTrue(isinstance(simple.get_tuple(), ListConfig))
# get_default_args converts sets to ListConfigs (which act like lists).
self.assertEqual(simple.set_member, [2])
self.assertTrue(isinstance(simple.set_member, ListConfig))
self.assertEqual(c.a_tuple, [4.0, 3.0])
self.assertTrue(isinstance(c.a_tuple, ListConfig))
self.assertEqual(mydata.tuple_field, (3, True, "j"))
@ -514,20 +521,31 @@ class TestConfig(unittest.TestCase):
B1 = "b1"
B2 = "b2"
# Test for a Configurable class, a function, and a regular class.
class C(Configurable):
a: A = A.B1
base = get_default_args(C)
replaced = OmegaConf.merge(base, {"a": "B2"})
self.assertEqual(replaced.a, A.B2)
with self.assertRaises(ValidationError):
# You can't use a value which is not one of the
# choices, even if it is the str representation
# of one of the choices.
OmegaConf.merge(base, {"a": "b2"})
# Also test for a calllable with enum arguments.
def C_fn(a: A = A.B1):
pass
remerged = OmegaConf.merge(base, OmegaConf.create(OmegaConf.to_yaml(base)))
self.assertEqual(remerged.a, A.B1)
class C_cl:
def __init__(self, a: A = A.B1) -> None:
pass
for C_ in [C, C_fn, C_cl]:
base = get_default_args(C_)
self.assertEqual(base.a, A.B1)
replaced = OmegaConf.merge(base, {"a": "B2"})
self.assertEqual(replaced.a, A.B2)
with self.assertRaises(ValidationError):
# You can't use a value which is not one of the
# choices, even if it is the str representation
# of one of the choices.
OmegaConf.merge(base, {"a": "b2"})
remerged = OmegaConf.merge(base, OmegaConf.create(OmegaConf.to_yaml(base)))
self.assertEqual(remerged.a, A.B1)
def test_remove_unused_components(self):
struct = get_default_args(MainTest)
@ -577,34 +595,56 @@ class TestConfig(unittest.TestCase):
class MockDataclass:
field_no_default: int
field_primitive_type: int = 42
field_reference_type: List[int] = field(default_factory=lambda: [])
field_list_type: List[int] = field(default_factory=lambda: [])
class RefObject:
pass
REF_OBJECT = RefObject()
class MockClassWithInit: # noqa: B903
def __init__(
self,
field_no_nothing,
field_no_default: int,
field_primitive_type: int = 42,
field_reference_type: List[int] = [], # noqa: B006
field_list_type: List[int] = [], # noqa: B006
field_reference_type: RefObject = REF_OBJECT,
):
self.field_no_nothing = field_no_nothing
self.field_no_default = field_no_default
self.field_primitive_type = field_primitive_type
self.field_list_type = field_list_type
self.field_reference_type = field_reference_type
class TestRawClasses(unittest.TestCase):
def setUp(self) -> None:
self._instances = {
MockDataclass: MockDataclass(field_no_default=0),
MockClassWithInit: MockClassWithInit(
field_no_nothing="tratata", field_no_default=0
),
}
def test_get_default_args(self):
for cls in [MockDataclass, MockClassWithInit]:
dataclass_defaults = get_default_args(cls)
inst = cls(field_no_default=0)
dataclass_defaults.field_no_default = 0
# DictConfig fields with missing values are `not in`
self.assertNotIn("field_no_default", dataclass_defaults)
self.assertNotIn("field_no_nothing", dataclass_defaults)
self.assertNotIn("field_reference_type", dataclass_defaults)
if cls == MockDataclass: # we don't remove undefaulted from dataclasses
dataclass_defaults.field_no_default = 0
for name, val in dataclass_defaults.items():
self.assertTrue(hasattr(inst, name))
self.assertEqual(val, getattr(inst, name))
self.assertTrue(hasattr(self._instances[cls], name))
self.assertEqual(val, getattr(self._instances[cls], name))
def test_get_default_args_readonly(self):
for cls in [MockDataclass, MockClassWithInit]:
dataclass_defaults = get_default_args(cls)
dataclass_defaults["field_reference_type"].append(13)
inst = cls(field_no_default=0)
self.assertEqual(inst.field_reference_type, [])
dataclass_defaults["field_list_type"].append(13)
self.assertEqual(self._instances[cls].field_list_type, [])