From 645a47d054e25fad657a0724aa5497f622a00ea4 Mon Sep 17 00:00:00 2001 From: Roman Shapovalov Date: Fri, 25 Mar 2022 07:08:01 -0700 Subject: [PATCH] Return a typed structured config from default_args for callables Summary: Before the fix, running get_default_args(C: Callable) returns an unstructured DictConfig which causes Enums to be handled incorrectly. This is a fix. WIP update: Currently tests still fail whenever a function signature contains an untyped argument: This needs to be somehow fixed. Reviewed By: bottler Differential Revision: D34932124 fbshipit-source-id: ecdc45c738633cfea5caa7480ba4f790ece931e8 --- projects/implicitron_trainer/experiment.py | 2 +- .../implicitron/models/renderer/rgb_net.py | 33 +++---- pytorch3d/implicitron/tools/config.py | 62 ++++++++++--- tests/implicitron/test_config.py | 86 ++++++++++++++----- 4 files changed, 133 insertions(+), 50 deletions(-) diff --git a/projects/implicitron_trainer/experiment.py b/projects/implicitron_trainer/experiment.py index 64090153..ae7a00b1 100755 --- a/projects/implicitron_trainer/experiment.py +++ b/projects/implicitron_trainer/experiment.py @@ -213,7 +213,7 @@ def init_optimizer( model: GenericModel, optimizer_state: Optional[Dict[str, Any]], last_epoch: int, - breed: bool = "adam", + breed: str = "adam", weight_decay: float = 0.0, lr_policy: str = "multistep", lr: float = 0.0005, diff --git a/pytorch3d/implicitron/models/renderer/rgb_net.py b/pytorch3d/implicitron/models/renderer/rgb_net.py index 0e444a43..7e2fbd69 100644 --- a/pytorch3d/implicitron/models/renderer/rgb_net.py +++ b/pytorch3d/implicitron/models/renderer/rgb_net.py @@ -2,6 +2,9 @@ # Adapted from RenderingNetwork from IDR # https://github.com/lioryariv/idr/ # Copyright (c) 2020 Lior Yariv + +from typing import List, Tuple + import torch from pytorch3d.renderer.implicit import HarmonicEmbedding, RayBundle from torch import nn @@ -10,38 +13,38 @@ from torch import nn class RayNormalColoringNetwork(torch.nn.Module): def __init__( self, - feature_vector_size=3, - mode="idr", - d_in=9, - d_out=3, - dims=(512, 512, 512, 512), - weight_norm=True, - n_harmonic_functions_dir=0, - pooled_feature_dim=0, - ): + feature_vector_size: int = 3, + mode: str = "idr", + d_in: int = 9, + d_out: int = 3, + dims: Tuple[int, ...] = (512, 512, 512, 512), + weight_norm: bool = True, + n_harmonic_functions_dir: int = 0, + pooled_feature_dim: int = 0, + ) -> None: super().__init__() self.mode = mode self.output_dimensions = d_out - dims = [d_in + feature_vector_size] + list(dims) + [d_out] + dims_full: List[int] = [d_in + feature_vector_size] + list(dims) + [d_out] self.embedview_fn = None if n_harmonic_functions_dir > 0: self.embedview_fn = HarmonicEmbedding( n_harmonic_functions_dir, append_input=True ) - dims[0] += self.embedview_fn.get_output_dim() - 3 + dims_full[0] += self.embedview_fn.get_output_dim() - 3 if pooled_feature_dim > 0: print("Pooled features in rendering network.") - dims[0] += pooled_feature_dim + dims_full[0] += pooled_feature_dim - self.num_layers = len(dims) + self.num_layers = len(dims_full) layers = [] for layer_idx in range(self.num_layers - 1): - out_dim = dims[layer_idx + 1] - lin = nn.Linear(dims[layer_idx], out_dim) + out_dim = dims_full[layer_idx + 1] + lin = nn.Linear(dims_full[layer_idx], out_dim) if weight_norm: lin = nn.utils.weight_norm(lin) diff --git a/pytorch3d/implicitron/tools/config.py b/pytorch3d/implicitron/tools/config.py index edbec7d3..c867caee 100644 --- a/pytorch3d/implicitron/tools/config.py +++ b/pytorch3d/implicitron/tools/config.py @@ -4,12 +4,12 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. - -import copy import dataclasses import inspect +import itertools import warnings from collections import Counter, defaultdict +from enum import Enum from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, TypeVar, cast from omegaconf import DictConfig, OmegaConf, open_dict @@ -409,13 +409,13 @@ def get_default_args(C, *, _do_not_process: Tuple[type, ...] = ()) -> DictConfig # straight away if C has already been processed. expand_args_fields(C, _do_not_process=_do_not_process) - kwargs = {} if dataclasses.is_dataclass(C): # Note that if get_default_args_field is used somewhere in C, # this call is recursive. No special care is needed, # because in practice get_default_args_field is used for # separate types than the outer type. - out = OmegaConf.structured(C) + + out: DictConfig = OmegaConf.structured(C) exclude = getattr(C, "_processed_members", ()) with open_dict(out): for field in exclude: @@ -425,16 +425,56 @@ def get_default_args(C, *, _do_not_process: Tuple[type, ...] = ()) -> DictConfig if _is_configurable_class(C): raise ValueError(f"Failed to process {C}") - # returns dict of keyword args of a callable C - sig = inspect.signature(C) - for pname, defval in dict(sig.parameters).items(): - if defval.default == inspect.Parameter.empty: - # print('skipping %s' % pname) + # regular class or function + field_annotations = [] + for pname, defval in _params_iter(C): + default = defval.default + if default == inspect.Parameter.empty: + # we do not have a default value for the parameter + continue + + if defval.annotation == inspect._empty: + raise ValueError( + "All arguments of the input callable have to be typed." + + f" Argument '{pname}' does not have a type annotation." + ) + + if isinstance(default, set): # force OmegaConf to convert it to ListConfig + default = tuple(default) + + if isinstance(default, (list, dict)): + # OmegaConf will convert to [Dict|List]Config, so it is safe to reuse the value + field_ = dataclasses.field(default_factory=lambda default=default: default) + elif not _is_immutable_type(defval.annotation, default): continue else: - kwargs[pname] = copy.deepcopy(defval.default) + # we can use a simple default argument for dataclass.field + field_ = dataclasses.field(default=default) + field_annotations.append((pname, defval.annotation, field_)) - return DictConfig(kwargs) + # make a temp dataclass and generate a structured config from it. + return OmegaConf.structured( + dataclasses.make_dataclass(f"__{C.__name__}_default_args__", field_annotations) + ) + + +def _params_iter(C): + """Returns dict of keyword args of a class or function C.""" + if inspect.isclass(C): + return itertools.islice( # exclude `self` + inspect.signature(C.__init__).parameters.items(), 1, None + ) + + return inspect.signature(C).parameters.items() + + +def _is_immutable_type(type_: Type, val: Any) -> bool: + PRIMITIVE_TYPES = (int, float, bool, str, bytes, tuple) + # sometimes type can be too relaxed (e.g. Any), so we also check values + if isinstance(val, PRIMITIVE_TYPES): + return True + + return type_ in PRIMITIVE_TYPES or issubclass(type_, Enum) def _is_actually_dataclass(some_class) -> bool: diff --git a/tests/implicitron/test_config.py b/tests/implicitron/test_config.py index a1dd7fd0..dcc98848 100644 --- a/tests/implicitron/test_config.py +++ b/tests/implicitron/test_config.py @@ -8,7 +8,7 @@ import textwrap import unittest from dataclasses import dataclass, field, is_dataclass from enum import Enum -from typing import List, Optional, Tuple +from typing import Any, List, Optional, Set, Tuple from omegaconf import DictConfig, ListConfig, OmegaConf, ValidationError from pytorch3d.implicitron.tools.config import ( @@ -216,7 +216,7 @@ class TestConfig(unittest.TestCase): # tolerated. But it would be nice to be able to # configure them. class Foo: - def __init__(self, a=1, b=2): + def __init__(self, a: Any = 1, b: Any = 2): self.a, self.b = a, b @dataclass() @@ -238,7 +238,6 @@ class TestConfig(unittest.TestCase): container_args = get_default_args(Container) container = Container(**container_args) self.assertIsInstance(container.fruit, Orange) - # self.assertIsInstance(container.bar, Bar) container_defaulted = Container() container_defaulted.fruit_Pear_args.n_pips += 4 @@ -432,8 +431,13 @@ class TestConfig(unittest.TestCase): tuple_field: tuple = (3, True, "j") class SimpleClass: - def __init__(self, tuple_member_=(3, 4)): + def __init__( + self, + tuple_member_: Tuple[int, int] = (3, 4), + set_member_: Set[int] = {2}, # noqa + ): self.tuple_member = tuple_member_ + self.set_member = set_member_ def get_tuple(self): return self.tuple_member @@ -459,6 +463,9 @@ class TestConfig(unittest.TestCase): # OmegaConf converts tuples to ListConfigs (which act like lists). self.assertEqual(simple.get_tuple(), [3, 4]) self.assertTrue(isinstance(simple.get_tuple(), ListConfig)) + # get_default_args converts sets to ListConfigs (which act like lists). + self.assertEqual(simple.set_member, [2]) + self.assertTrue(isinstance(simple.set_member, ListConfig)) self.assertEqual(c.a_tuple, [4.0, 3.0]) self.assertTrue(isinstance(c.a_tuple, ListConfig)) self.assertEqual(mydata.tuple_field, (3, True, "j")) @@ -514,20 +521,31 @@ class TestConfig(unittest.TestCase): B1 = "b1" B2 = "b2" + # Test for a Configurable class, a function, and a regular class. class C(Configurable): a: A = A.B1 - base = get_default_args(C) - replaced = OmegaConf.merge(base, {"a": "B2"}) - self.assertEqual(replaced.a, A.B2) - with self.assertRaises(ValidationError): - # You can't use a value which is not one of the - # choices, even if it is the str representation - # of one of the choices. - OmegaConf.merge(base, {"a": "b2"}) + # Also test for a calllable with enum arguments. + def C_fn(a: A = A.B1): + pass - remerged = OmegaConf.merge(base, OmegaConf.create(OmegaConf.to_yaml(base))) - self.assertEqual(remerged.a, A.B1) + class C_cl: + def __init__(self, a: A = A.B1) -> None: + pass + + for C_ in [C, C_fn, C_cl]: + base = get_default_args(C_) + self.assertEqual(base.a, A.B1) + replaced = OmegaConf.merge(base, {"a": "B2"}) + self.assertEqual(replaced.a, A.B2) + with self.assertRaises(ValidationError): + # You can't use a value which is not one of the + # choices, even if it is the str representation + # of one of the choices. + OmegaConf.merge(base, {"a": "b2"}) + + remerged = OmegaConf.merge(base, OmegaConf.create(OmegaConf.to_yaml(base))) + self.assertEqual(remerged.a, A.B1) def test_remove_unused_components(self): struct = get_default_args(MainTest) @@ -577,34 +595,56 @@ class TestConfig(unittest.TestCase): class MockDataclass: field_no_default: int field_primitive_type: int = 42 - field_reference_type: List[int] = field(default_factory=lambda: []) + field_list_type: List[int] = field(default_factory=lambda: []) + + +class RefObject: + pass + + +REF_OBJECT = RefObject() class MockClassWithInit: # noqa: B903 def __init__( self, + field_no_nothing, field_no_default: int, field_primitive_type: int = 42, - field_reference_type: List[int] = [], # noqa: B006 + field_list_type: List[int] = [], # noqa: B006 + field_reference_type: RefObject = REF_OBJECT, ): + self.field_no_nothing = field_no_nothing self.field_no_default = field_no_default self.field_primitive_type = field_primitive_type + self.field_list_type = field_list_type self.field_reference_type = field_reference_type class TestRawClasses(unittest.TestCase): + def setUp(self) -> None: + self._instances = { + MockDataclass: MockDataclass(field_no_default=0), + MockClassWithInit: MockClassWithInit( + field_no_nothing="tratata", field_no_default=0 + ), + } + def test_get_default_args(self): for cls in [MockDataclass, MockClassWithInit]: dataclass_defaults = get_default_args(cls) - inst = cls(field_no_default=0) - dataclass_defaults.field_no_default = 0 + # DictConfig fields with missing values are `not in` + self.assertNotIn("field_no_default", dataclass_defaults) + self.assertNotIn("field_no_nothing", dataclass_defaults) + self.assertNotIn("field_reference_type", dataclass_defaults) + if cls == MockDataclass: # we don't remove undefaulted from dataclasses + dataclass_defaults.field_no_default = 0 for name, val in dataclass_defaults.items(): - self.assertTrue(hasattr(inst, name)) - self.assertEqual(val, getattr(inst, name)) + self.assertTrue(hasattr(self._instances[cls], name)) + self.assertEqual(val, getattr(self._instances[cls], name)) def test_get_default_args_readonly(self): for cls in [MockDataclass, MockClassWithInit]: dataclass_defaults = get_default_args(cls) - dataclass_defaults["field_reference_type"].append(13) - inst = cls(field_no_default=0) - self.assertEqual(inst.field_reference_type, []) + dataclass_defaults["field_list_type"].append(13) + self.assertEqual(self._instances[cls].field_list_type, [])