mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
Return a typed structured config from default_args for callables
Summary: Before the fix, running get_default_args(C: Callable) returns an unstructured DictConfig which causes Enums to be handled incorrectly. This is a fix. WIP update: Currently tests still fail whenever a function signature contains an untyped argument: This needs to be somehow fixed. Reviewed By: bottler Differential Revision: D34932124 fbshipit-source-id: ecdc45c738633cfea5caa7480ba4f790ece931e8
This commit is contained in:
parent
8ac5e8f083
commit
645a47d054
@ -213,7 +213,7 @@ def init_optimizer(
|
||||
model: GenericModel,
|
||||
optimizer_state: Optional[Dict[str, Any]],
|
||||
last_epoch: int,
|
||||
breed: bool = "adam",
|
||||
breed: str = "adam",
|
||||
weight_decay: float = 0.0,
|
||||
lr_policy: str = "multistep",
|
||||
lr: float = 0.0005,
|
||||
|
@ -2,6 +2,9 @@
|
||||
# Adapted from RenderingNetwork from IDR
|
||||
# https://github.com/lioryariv/idr/
|
||||
# Copyright (c) 2020 Lior Yariv
|
||||
|
||||
from typing import List, Tuple
|
||||
|
||||
import torch
|
||||
from pytorch3d.renderer.implicit import HarmonicEmbedding, RayBundle
|
||||
from torch import nn
|
||||
@ -10,38 +13,38 @@ from torch import nn
|
||||
class RayNormalColoringNetwork(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
feature_vector_size=3,
|
||||
mode="idr",
|
||||
d_in=9,
|
||||
d_out=3,
|
||||
dims=(512, 512, 512, 512),
|
||||
weight_norm=True,
|
||||
n_harmonic_functions_dir=0,
|
||||
pooled_feature_dim=0,
|
||||
):
|
||||
feature_vector_size: int = 3,
|
||||
mode: str = "idr",
|
||||
d_in: int = 9,
|
||||
d_out: int = 3,
|
||||
dims: Tuple[int, ...] = (512, 512, 512, 512),
|
||||
weight_norm: bool = True,
|
||||
n_harmonic_functions_dir: int = 0,
|
||||
pooled_feature_dim: int = 0,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.mode = mode
|
||||
self.output_dimensions = d_out
|
||||
dims = [d_in + feature_vector_size] + list(dims) + [d_out]
|
||||
dims_full: List[int] = [d_in + feature_vector_size] + list(dims) + [d_out]
|
||||
|
||||
self.embedview_fn = None
|
||||
if n_harmonic_functions_dir > 0:
|
||||
self.embedview_fn = HarmonicEmbedding(
|
||||
n_harmonic_functions_dir, append_input=True
|
||||
)
|
||||
dims[0] += self.embedview_fn.get_output_dim() - 3
|
||||
dims_full[0] += self.embedview_fn.get_output_dim() - 3
|
||||
|
||||
if pooled_feature_dim > 0:
|
||||
print("Pooled features in rendering network.")
|
||||
dims[0] += pooled_feature_dim
|
||||
dims_full[0] += pooled_feature_dim
|
||||
|
||||
self.num_layers = len(dims)
|
||||
self.num_layers = len(dims_full)
|
||||
|
||||
layers = []
|
||||
for layer_idx in range(self.num_layers - 1):
|
||||
out_dim = dims[layer_idx + 1]
|
||||
lin = nn.Linear(dims[layer_idx], out_dim)
|
||||
out_dim = dims_full[layer_idx + 1]
|
||||
lin = nn.Linear(dims_full[layer_idx], out_dim)
|
||||
|
||||
if weight_norm:
|
||||
lin = nn.utils.weight_norm(lin)
|
||||
|
@ -4,12 +4,12 @@
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
import copy
|
||||
import dataclasses
|
||||
import inspect
|
||||
import itertools
|
||||
import warnings
|
||||
from collections import Counter, defaultdict
|
||||
from enum import Enum
|
||||
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, TypeVar, cast
|
||||
|
||||
from omegaconf import DictConfig, OmegaConf, open_dict
|
||||
@ -409,13 +409,13 @@ def get_default_args(C, *, _do_not_process: Tuple[type, ...] = ()) -> DictConfig
|
||||
# straight away if C has already been processed.
|
||||
expand_args_fields(C, _do_not_process=_do_not_process)
|
||||
|
||||
kwargs = {}
|
||||
if dataclasses.is_dataclass(C):
|
||||
# Note that if get_default_args_field is used somewhere in C,
|
||||
# this call is recursive. No special care is needed,
|
||||
# because in practice get_default_args_field is used for
|
||||
# separate types than the outer type.
|
||||
out = OmegaConf.structured(C)
|
||||
|
||||
out: DictConfig = OmegaConf.structured(C)
|
||||
exclude = getattr(C, "_processed_members", ())
|
||||
with open_dict(out):
|
||||
for field in exclude:
|
||||
@ -425,16 +425,56 @@ def get_default_args(C, *, _do_not_process: Tuple[type, ...] = ()) -> DictConfig
|
||||
if _is_configurable_class(C):
|
||||
raise ValueError(f"Failed to process {C}")
|
||||
|
||||
# returns dict of keyword args of a callable C
|
||||
sig = inspect.signature(C)
|
||||
for pname, defval in dict(sig.parameters).items():
|
||||
if defval.default == inspect.Parameter.empty:
|
||||
# print('skipping %s' % pname)
|
||||
# regular class or function
|
||||
field_annotations = []
|
||||
for pname, defval in _params_iter(C):
|
||||
default = defval.default
|
||||
if default == inspect.Parameter.empty:
|
||||
# we do not have a default value for the parameter
|
||||
continue
|
||||
|
||||
if defval.annotation == inspect._empty:
|
||||
raise ValueError(
|
||||
"All arguments of the input callable have to be typed."
|
||||
+ f" Argument '{pname}' does not have a type annotation."
|
||||
)
|
||||
|
||||
if isinstance(default, set): # force OmegaConf to convert it to ListConfig
|
||||
default = tuple(default)
|
||||
|
||||
if isinstance(default, (list, dict)):
|
||||
# OmegaConf will convert to [Dict|List]Config, so it is safe to reuse the value
|
||||
field_ = dataclasses.field(default_factory=lambda default=default: default)
|
||||
elif not _is_immutable_type(defval.annotation, default):
|
||||
continue
|
||||
else:
|
||||
kwargs[pname] = copy.deepcopy(defval.default)
|
||||
# we can use a simple default argument for dataclass.field
|
||||
field_ = dataclasses.field(default=default)
|
||||
field_annotations.append((pname, defval.annotation, field_))
|
||||
|
||||
return DictConfig(kwargs)
|
||||
# make a temp dataclass and generate a structured config from it.
|
||||
return OmegaConf.structured(
|
||||
dataclasses.make_dataclass(f"__{C.__name__}_default_args__", field_annotations)
|
||||
)
|
||||
|
||||
|
||||
def _params_iter(C):
|
||||
"""Returns dict of keyword args of a class or function C."""
|
||||
if inspect.isclass(C):
|
||||
return itertools.islice( # exclude `self`
|
||||
inspect.signature(C.__init__).parameters.items(), 1, None
|
||||
)
|
||||
|
||||
return inspect.signature(C).parameters.items()
|
||||
|
||||
|
||||
def _is_immutable_type(type_: Type, val: Any) -> bool:
|
||||
PRIMITIVE_TYPES = (int, float, bool, str, bytes, tuple)
|
||||
# sometimes type can be too relaxed (e.g. Any), so we also check values
|
||||
if isinstance(val, PRIMITIVE_TYPES):
|
||||
return True
|
||||
|
||||
return type_ in PRIMITIVE_TYPES or issubclass(type_, Enum)
|
||||
|
||||
|
||||
def _is_actually_dataclass(some_class) -> bool:
|
||||
|
@ -8,7 +8,7 @@ import textwrap
|
||||
import unittest
|
||||
from dataclasses import dataclass, field, is_dataclass
|
||||
from enum import Enum
|
||||
from typing import List, Optional, Tuple
|
||||
from typing import Any, List, Optional, Set, Tuple
|
||||
|
||||
from omegaconf import DictConfig, ListConfig, OmegaConf, ValidationError
|
||||
from pytorch3d.implicitron.tools.config import (
|
||||
@ -216,7 +216,7 @@ class TestConfig(unittest.TestCase):
|
||||
# tolerated. But it would be nice to be able to
|
||||
# configure them.
|
||||
class Foo:
|
||||
def __init__(self, a=1, b=2):
|
||||
def __init__(self, a: Any = 1, b: Any = 2):
|
||||
self.a, self.b = a, b
|
||||
|
||||
@dataclass()
|
||||
@ -238,7 +238,6 @@ class TestConfig(unittest.TestCase):
|
||||
container_args = get_default_args(Container)
|
||||
container = Container(**container_args)
|
||||
self.assertIsInstance(container.fruit, Orange)
|
||||
# self.assertIsInstance(container.bar, Bar)
|
||||
|
||||
container_defaulted = Container()
|
||||
container_defaulted.fruit_Pear_args.n_pips += 4
|
||||
@ -432,8 +431,13 @@ class TestConfig(unittest.TestCase):
|
||||
tuple_field: tuple = (3, True, "j")
|
||||
|
||||
class SimpleClass:
|
||||
def __init__(self, tuple_member_=(3, 4)):
|
||||
def __init__(
|
||||
self,
|
||||
tuple_member_: Tuple[int, int] = (3, 4),
|
||||
set_member_: Set[int] = {2}, # noqa
|
||||
):
|
||||
self.tuple_member = tuple_member_
|
||||
self.set_member = set_member_
|
||||
|
||||
def get_tuple(self):
|
||||
return self.tuple_member
|
||||
@ -459,6 +463,9 @@ class TestConfig(unittest.TestCase):
|
||||
# OmegaConf converts tuples to ListConfigs (which act like lists).
|
||||
self.assertEqual(simple.get_tuple(), [3, 4])
|
||||
self.assertTrue(isinstance(simple.get_tuple(), ListConfig))
|
||||
# get_default_args converts sets to ListConfigs (which act like lists).
|
||||
self.assertEqual(simple.set_member, [2])
|
||||
self.assertTrue(isinstance(simple.set_member, ListConfig))
|
||||
self.assertEqual(c.a_tuple, [4.0, 3.0])
|
||||
self.assertTrue(isinstance(c.a_tuple, ListConfig))
|
||||
self.assertEqual(mydata.tuple_field, (3, True, "j"))
|
||||
@ -514,20 +521,31 @@ class TestConfig(unittest.TestCase):
|
||||
B1 = "b1"
|
||||
B2 = "b2"
|
||||
|
||||
# Test for a Configurable class, a function, and a regular class.
|
||||
class C(Configurable):
|
||||
a: A = A.B1
|
||||
|
||||
base = get_default_args(C)
|
||||
replaced = OmegaConf.merge(base, {"a": "B2"})
|
||||
self.assertEqual(replaced.a, A.B2)
|
||||
with self.assertRaises(ValidationError):
|
||||
# You can't use a value which is not one of the
|
||||
# choices, even if it is the str representation
|
||||
# of one of the choices.
|
||||
OmegaConf.merge(base, {"a": "b2"})
|
||||
# Also test for a calllable with enum arguments.
|
||||
def C_fn(a: A = A.B1):
|
||||
pass
|
||||
|
||||
remerged = OmegaConf.merge(base, OmegaConf.create(OmegaConf.to_yaml(base)))
|
||||
self.assertEqual(remerged.a, A.B1)
|
||||
class C_cl:
|
||||
def __init__(self, a: A = A.B1) -> None:
|
||||
pass
|
||||
|
||||
for C_ in [C, C_fn, C_cl]:
|
||||
base = get_default_args(C_)
|
||||
self.assertEqual(base.a, A.B1)
|
||||
replaced = OmegaConf.merge(base, {"a": "B2"})
|
||||
self.assertEqual(replaced.a, A.B2)
|
||||
with self.assertRaises(ValidationError):
|
||||
# You can't use a value which is not one of the
|
||||
# choices, even if it is the str representation
|
||||
# of one of the choices.
|
||||
OmegaConf.merge(base, {"a": "b2"})
|
||||
|
||||
remerged = OmegaConf.merge(base, OmegaConf.create(OmegaConf.to_yaml(base)))
|
||||
self.assertEqual(remerged.a, A.B1)
|
||||
|
||||
def test_remove_unused_components(self):
|
||||
struct = get_default_args(MainTest)
|
||||
@ -577,34 +595,56 @@ class TestConfig(unittest.TestCase):
|
||||
class MockDataclass:
|
||||
field_no_default: int
|
||||
field_primitive_type: int = 42
|
||||
field_reference_type: List[int] = field(default_factory=lambda: [])
|
||||
field_list_type: List[int] = field(default_factory=lambda: [])
|
||||
|
||||
|
||||
class RefObject:
|
||||
pass
|
||||
|
||||
|
||||
REF_OBJECT = RefObject()
|
||||
|
||||
|
||||
class MockClassWithInit: # noqa: B903
|
||||
def __init__(
|
||||
self,
|
||||
field_no_nothing,
|
||||
field_no_default: int,
|
||||
field_primitive_type: int = 42,
|
||||
field_reference_type: List[int] = [], # noqa: B006
|
||||
field_list_type: List[int] = [], # noqa: B006
|
||||
field_reference_type: RefObject = REF_OBJECT,
|
||||
):
|
||||
self.field_no_nothing = field_no_nothing
|
||||
self.field_no_default = field_no_default
|
||||
self.field_primitive_type = field_primitive_type
|
||||
self.field_list_type = field_list_type
|
||||
self.field_reference_type = field_reference_type
|
||||
|
||||
|
||||
class TestRawClasses(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
self._instances = {
|
||||
MockDataclass: MockDataclass(field_no_default=0),
|
||||
MockClassWithInit: MockClassWithInit(
|
||||
field_no_nothing="tratata", field_no_default=0
|
||||
),
|
||||
}
|
||||
|
||||
def test_get_default_args(self):
|
||||
for cls in [MockDataclass, MockClassWithInit]:
|
||||
dataclass_defaults = get_default_args(cls)
|
||||
inst = cls(field_no_default=0)
|
||||
dataclass_defaults.field_no_default = 0
|
||||
# DictConfig fields with missing values are `not in`
|
||||
self.assertNotIn("field_no_default", dataclass_defaults)
|
||||
self.assertNotIn("field_no_nothing", dataclass_defaults)
|
||||
self.assertNotIn("field_reference_type", dataclass_defaults)
|
||||
if cls == MockDataclass: # we don't remove undefaulted from dataclasses
|
||||
dataclass_defaults.field_no_default = 0
|
||||
for name, val in dataclass_defaults.items():
|
||||
self.assertTrue(hasattr(inst, name))
|
||||
self.assertEqual(val, getattr(inst, name))
|
||||
self.assertTrue(hasattr(self._instances[cls], name))
|
||||
self.assertEqual(val, getattr(self._instances[cls], name))
|
||||
|
||||
def test_get_default_args_readonly(self):
|
||||
for cls in [MockDataclass, MockClassWithInit]:
|
||||
dataclass_defaults = get_default_args(cls)
|
||||
dataclass_defaults["field_reference_type"].append(13)
|
||||
inst = cls(field_no_default=0)
|
||||
self.assertEqual(inst.field_reference_type, [])
|
||||
dataclass_defaults["field_list_type"].append(13)
|
||||
self.assertEqual(self._instances[cls].field_list_type, [])
|
||||
|
Loading…
x
Reference in New Issue
Block a user