Return a typed structured config from default_args for callables

Summary:
Before the fix, running get_default_args(C: Callable) returns an unstructured DictConfig which causes Enums to be handled incorrectly. This is a fix.

WIP update: Currently tests still fail whenever a function signature contains an untyped argument: This needs to be somehow fixed.

Reviewed By: bottler

Differential Revision: D34932124

fbshipit-source-id: ecdc45c738633cfea5caa7480ba4f790ece931e8
This commit is contained in:
Roman Shapovalov 2022-03-25 07:08:01 -07:00 committed by Facebook GitHub Bot
parent 8ac5e8f083
commit 645a47d054
4 changed files with 133 additions and 50 deletions

View File

@ -213,7 +213,7 @@ def init_optimizer(
model: GenericModel, model: GenericModel,
optimizer_state: Optional[Dict[str, Any]], optimizer_state: Optional[Dict[str, Any]],
last_epoch: int, last_epoch: int,
breed: bool = "adam", breed: str = "adam",
weight_decay: float = 0.0, weight_decay: float = 0.0,
lr_policy: str = "multistep", lr_policy: str = "multistep",
lr: float = 0.0005, lr: float = 0.0005,

View File

@ -2,6 +2,9 @@
# Adapted from RenderingNetwork from IDR # Adapted from RenderingNetwork from IDR
# https://github.com/lioryariv/idr/ # https://github.com/lioryariv/idr/
# Copyright (c) 2020 Lior Yariv # Copyright (c) 2020 Lior Yariv
from typing import List, Tuple
import torch import torch
from pytorch3d.renderer.implicit import HarmonicEmbedding, RayBundle from pytorch3d.renderer.implicit import HarmonicEmbedding, RayBundle
from torch import nn from torch import nn
@ -10,38 +13,38 @@ from torch import nn
class RayNormalColoringNetwork(torch.nn.Module): class RayNormalColoringNetwork(torch.nn.Module):
def __init__( def __init__(
self, self,
feature_vector_size=3, feature_vector_size: int = 3,
mode="idr", mode: str = "idr",
d_in=9, d_in: int = 9,
d_out=3, d_out: int = 3,
dims=(512, 512, 512, 512), dims: Tuple[int, ...] = (512, 512, 512, 512),
weight_norm=True, weight_norm: bool = True,
n_harmonic_functions_dir=0, n_harmonic_functions_dir: int = 0,
pooled_feature_dim=0, pooled_feature_dim: int = 0,
): ) -> None:
super().__init__() super().__init__()
self.mode = mode self.mode = mode
self.output_dimensions = d_out self.output_dimensions = d_out
dims = [d_in + feature_vector_size] + list(dims) + [d_out] dims_full: List[int] = [d_in + feature_vector_size] + list(dims) + [d_out]
self.embedview_fn = None self.embedview_fn = None
if n_harmonic_functions_dir > 0: if n_harmonic_functions_dir > 0:
self.embedview_fn = HarmonicEmbedding( self.embedview_fn = HarmonicEmbedding(
n_harmonic_functions_dir, append_input=True n_harmonic_functions_dir, append_input=True
) )
dims[0] += self.embedview_fn.get_output_dim() - 3 dims_full[0] += self.embedview_fn.get_output_dim() - 3
if pooled_feature_dim > 0: if pooled_feature_dim > 0:
print("Pooled features in rendering network.") print("Pooled features in rendering network.")
dims[0] += pooled_feature_dim dims_full[0] += pooled_feature_dim
self.num_layers = len(dims) self.num_layers = len(dims_full)
layers = [] layers = []
for layer_idx in range(self.num_layers - 1): for layer_idx in range(self.num_layers - 1):
out_dim = dims[layer_idx + 1] out_dim = dims_full[layer_idx + 1]
lin = nn.Linear(dims[layer_idx], out_dim) lin = nn.Linear(dims_full[layer_idx], out_dim)
if weight_norm: if weight_norm:
lin = nn.utils.weight_norm(lin) lin = nn.utils.weight_norm(lin)

View File

@ -4,12 +4,12 @@
# This source code is licensed under the BSD-style license found in the # This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
import copy
import dataclasses import dataclasses
import inspect import inspect
import itertools
import warnings import warnings
from collections import Counter, defaultdict from collections import Counter, defaultdict
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, TypeVar, cast from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, TypeVar, cast
from omegaconf import DictConfig, OmegaConf, open_dict from omegaconf import DictConfig, OmegaConf, open_dict
@ -409,13 +409,13 @@ def get_default_args(C, *, _do_not_process: Tuple[type, ...] = ()) -> DictConfig
# straight away if C has already been processed. # straight away if C has already been processed.
expand_args_fields(C, _do_not_process=_do_not_process) expand_args_fields(C, _do_not_process=_do_not_process)
kwargs = {}
if dataclasses.is_dataclass(C): if dataclasses.is_dataclass(C):
# Note that if get_default_args_field is used somewhere in C, # Note that if get_default_args_field is used somewhere in C,
# this call is recursive. No special care is needed, # this call is recursive. No special care is needed,
# because in practice get_default_args_field is used for # because in practice get_default_args_field is used for
# separate types than the outer type. # separate types than the outer type.
out = OmegaConf.structured(C)
out: DictConfig = OmegaConf.structured(C)
exclude = getattr(C, "_processed_members", ()) exclude = getattr(C, "_processed_members", ())
with open_dict(out): with open_dict(out):
for field in exclude: for field in exclude:
@ -425,16 +425,56 @@ def get_default_args(C, *, _do_not_process: Tuple[type, ...] = ()) -> DictConfig
if _is_configurable_class(C): if _is_configurable_class(C):
raise ValueError(f"Failed to process {C}") raise ValueError(f"Failed to process {C}")
# returns dict of keyword args of a callable C # regular class or function
sig = inspect.signature(C) field_annotations = []
for pname, defval in dict(sig.parameters).items(): for pname, defval in _params_iter(C):
if defval.default == inspect.Parameter.empty: default = defval.default
# print('skipping %s' % pname) if default == inspect.Parameter.empty:
# we do not have a default value for the parameter
continue
if defval.annotation == inspect._empty:
raise ValueError(
"All arguments of the input callable have to be typed."
+ f" Argument '{pname}' does not have a type annotation."
)
if isinstance(default, set): # force OmegaConf to convert it to ListConfig
default = tuple(default)
if isinstance(default, (list, dict)):
# OmegaConf will convert to [Dict|List]Config, so it is safe to reuse the value
field_ = dataclasses.field(default_factory=lambda default=default: default)
elif not _is_immutable_type(defval.annotation, default):
continue continue
else: else:
kwargs[pname] = copy.deepcopy(defval.default) # we can use a simple default argument for dataclass.field
field_ = dataclasses.field(default=default)
field_annotations.append((pname, defval.annotation, field_))
return DictConfig(kwargs) # make a temp dataclass and generate a structured config from it.
return OmegaConf.structured(
dataclasses.make_dataclass(f"__{C.__name__}_default_args__", field_annotations)
)
def _params_iter(C):
"""Returns dict of keyword args of a class or function C."""
if inspect.isclass(C):
return itertools.islice( # exclude `self`
inspect.signature(C.__init__).parameters.items(), 1, None
)
return inspect.signature(C).parameters.items()
def _is_immutable_type(type_: Type, val: Any) -> bool:
PRIMITIVE_TYPES = (int, float, bool, str, bytes, tuple)
# sometimes type can be too relaxed (e.g. Any), so we also check values
if isinstance(val, PRIMITIVE_TYPES):
return True
return type_ in PRIMITIVE_TYPES or issubclass(type_, Enum)
def _is_actually_dataclass(some_class) -> bool: def _is_actually_dataclass(some_class) -> bool:

View File

@ -8,7 +8,7 @@ import textwrap
import unittest import unittest
from dataclasses import dataclass, field, is_dataclass from dataclasses import dataclass, field, is_dataclass
from enum import Enum from enum import Enum
from typing import List, Optional, Tuple from typing import Any, List, Optional, Set, Tuple
from omegaconf import DictConfig, ListConfig, OmegaConf, ValidationError from omegaconf import DictConfig, ListConfig, OmegaConf, ValidationError
from pytorch3d.implicitron.tools.config import ( from pytorch3d.implicitron.tools.config import (
@ -216,7 +216,7 @@ class TestConfig(unittest.TestCase):
# tolerated. But it would be nice to be able to # tolerated. But it would be nice to be able to
# configure them. # configure them.
class Foo: class Foo:
def __init__(self, a=1, b=2): def __init__(self, a: Any = 1, b: Any = 2):
self.a, self.b = a, b self.a, self.b = a, b
@dataclass() @dataclass()
@ -238,7 +238,6 @@ class TestConfig(unittest.TestCase):
container_args = get_default_args(Container) container_args = get_default_args(Container)
container = Container(**container_args) container = Container(**container_args)
self.assertIsInstance(container.fruit, Orange) self.assertIsInstance(container.fruit, Orange)
# self.assertIsInstance(container.bar, Bar)
container_defaulted = Container() container_defaulted = Container()
container_defaulted.fruit_Pear_args.n_pips += 4 container_defaulted.fruit_Pear_args.n_pips += 4
@ -432,8 +431,13 @@ class TestConfig(unittest.TestCase):
tuple_field: tuple = (3, True, "j") tuple_field: tuple = (3, True, "j")
class SimpleClass: class SimpleClass:
def __init__(self, tuple_member_=(3, 4)): def __init__(
self,
tuple_member_: Tuple[int, int] = (3, 4),
set_member_: Set[int] = {2}, # noqa
):
self.tuple_member = tuple_member_ self.tuple_member = tuple_member_
self.set_member = set_member_
def get_tuple(self): def get_tuple(self):
return self.tuple_member return self.tuple_member
@ -459,6 +463,9 @@ class TestConfig(unittest.TestCase):
# OmegaConf converts tuples to ListConfigs (which act like lists). # OmegaConf converts tuples to ListConfigs (which act like lists).
self.assertEqual(simple.get_tuple(), [3, 4]) self.assertEqual(simple.get_tuple(), [3, 4])
self.assertTrue(isinstance(simple.get_tuple(), ListConfig)) self.assertTrue(isinstance(simple.get_tuple(), ListConfig))
# get_default_args converts sets to ListConfigs (which act like lists).
self.assertEqual(simple.set_member, [2])
self.assertTrue(isinstance(simple.set_member, ListConfig))
self.assertEqual(c.a_tuple, [4.0, 3.0]) self.assertEqual(c.a_tuple, [4.0, 3.0])
self.assertTrue(isinstance(c.a_tuple, ListConfig)) self.assertTrue(isinstance(c.a_tuple, ListConfig))
self.assertEqual(mydata.tuple_field, (3, True, "j")) self.assertEqual(mydata.tuple_field, (3, True, "j"))
@ -514,10 +521,21 @@ class TestConfig(unittest.TestCase):
B1 = "b1" B1 = "b1"
B2 = "b2" B2 = "b2"
# Test for a Configurable class, a function, and a regular class.
class C(Configurable): class C(Configurable):
a: A = A.B1 a: A = A.B1
base = get_default_args(C) # Also test for a calllable with enum arguments.
def C_fn(a: A = A.B1):
pass
class C_cl:
def __init__(self, a: A = A.B1) -> None:
pass
for C_ in [C, C_fn, C_cl]:
base = get_default_args(C_)
self.assertEqual(base.a, A.B1)
replaced = OmegaConf.merge(base, {"a": "B2"}) replaced = OmegaConf.merge(base, {"a": "B2"})
self.assertEqual(replaced.a, A.B2) self.assertEqual(replaced.a, A.B2)
with self.assertRaises(ValidationError): with self.assertRaises(ValidationError):
@ -577,34 +595,56 @@ class TestConfig(unittest.TestCase):
class MockDataclass: class MockDataclass:
field_no_default: int field_no_default: int
field_primitive_type: int = 42 field_primitive_type: int = 42
field_reference_type: List[int] = field(default_factory=lambda: []) field_list_type: List[int] = field(default_factory=lambda: [])
class RefObject:
pass
REF_OBJECT = RefObject()
class MockClassWithInit: # noqa: B903 class MockClassWithInit: # noqa: B903
def __init__( def __init__(
self, self,
field_no_nothing,
field_no_default: int, field_no_default: int,
field_primitive_type: int = 42, field_primitive_type: int = 42,
field_reference_type: List[int] = [], # noqa: B006 field_list_type: List[int] = [], # noqa: B006
field_reference_type: RefObject = REF_OBJECT,
): ):
self.field_no_nothing = field_no_nothing
self.field_no_default = field_no_default self.field_no_default = field_no_default
self.field_primitive_type = field_primitive_type self.field_primitive_type = field_primitive_type
self.field_list_type = field_list_type
self.field_reference_type = field_reference_type self.field_reference_type = field_reference_type
class TestRawClasses(unittest.TestCase): class TestRawClasses(unittest.TestCase):
def setUp(self) -> None:
self._instances = {
MockDataclass: MockDataclass(field_no_default=0),
MockClassWithInit: MockClassWithInit(
field_no_nothing="tratata", field_no_default=0
),
}
def test_get_default_args(self): def test_get_default_args(self):
for cls in [MockDataclass, MockClassWithInit]: for cls in [MockDataclass, MockClassWithInit]:
dataclass_defaults = get_default_args(cls) dataclass_defaults = get_default_args(cls)
inst = cls(field_no_default=0) # DictConfig fields with missing values are `not in`
self.assertNotIn("field_no_default", dataclass_defaults)
self.assertNotIn("field_no_nothing", dataclass_defaults)
self.assertNotIn("field_reference_type", dataclass_defaults)
if cls == MockDataclass: # we don't remove undefaulted from dataclasses
dataclass_defaults.field_no_default = 0 dataclass_defaults.field_no_default = 0
for name, val in dataclass_defaults.items(): for name, val in dataclass_defaults.items():
self.assertTrue(hasattr(inst, name)) self.assertTrue(hasattr(self._instances[cls], name))
self.assertEqual(val, getattr(inst, name)) self.assertEqual(val, getattr(self._instances[cls], name))
def test_get_default_args_readonly(self): def test_get_default_args_readonly(self):
for cls in [MockDataclass, MockClassWithInit]: for cls in [MockDataclass, MockClassWithInit]:
dataclass_defaults = get_default_args(cls) dataclass_defaults = get_default_args(cls)
dataclass_defaults["field_reference_type"].append(13) dataclass_defaults["field_list_type"].append(13)
inst = cls(field_no_default=0) self.assertEqual(self._instances[cls].field_list_type, [])
self.assertEqual(inst.field_reference_type, [])