mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-03 04:12:48 +08:00
Return a typed structured config from default_args for callables
Summary: Before the fix, running get_default_args(C: Callable) returns an unstructured DictConfig which causes Enums to be handled incorrectly. This is a fix. WIP update: Currently tests still fail whenever a function signature contains an untyped argument: This needs to be somehow fixed. Reviewed By: bottler Differential Revision: D34932124 fbshipit-source-id: ecdc45c738633cfea5caa7480ba4f790ece931e8
This commit is contained in:
parent
8ac5e8f083
commit
645a47d054
@ -213,7 +213,7 @@ def init_optimizer(
|
|||||||
model: GenericModel,
|
model: GenericModel,
|
||||||
optimizer_state: Optional[Dict[str, Any]],
|
optimizer_state: Optional[Dict[str, Any]],
|
||||||
last_epoch: int,
|
last_epoch: int,
|
||||||
breed: bool = "adam",
|
breed: str = "adam",
|
||||||
weight_decay: float = 0.0,
|
weight_decay: float = 0.0,
|
||||||
lr_policy: str = "multistep",
|
lr_policy: str = "multistep",
|
||||||
lr: float = 0.0005,
|
lr: float = 0.0005,
|
||||||
|
@ -2,6 +2,9 @@
|
|||||||
# Adapted from RenderingNetwork from IDR
|
# Adapted from RenderingNetwork from IDR
|
||||||
# https://github.com/lioryariv/idr/
|
# https://github.com/lioryariv/idr/
|
||||||
# Copyright (c) 2020 Lior Yariv
|
# Copyright (c) 2020 Lior Yariv
|
||||||
|
|
||||||
|
from typing import List, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from pytorch3d.renderer.implicit import HarmonicEmbedding, RayBundle
|
from pytorch3d.renderer.implicit import HarmonicEmbedding, RayBundle
|
||||||
from torch import nn
|
from torch import nn
|
||||||
@ -10,38 +13,38 @@ from torch import nn
|
|||||||
class RayNormalColoringNetwork(torch.nn.Module):
|
class RayNormalColoringNetwork(torch.nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
feature_vector_size=3,
|
feature_vector_size: int = 3,
|
||||||
mode="idr",
|
mode: str = "idr",
|
||||||
d_in=9,
|
d_in: int = 9,
|
||||||
d_out=3,
|
d_out: int = 3,
|
||||||
dims=(512, 512, 512, 512),
|
dims: Tuple[int, ...] = (512, 512, 512, 512),
|
||||||
weight_norm=True,
|
weight_norm: bool = True,
|
||||||
n_harmonic_functions_dir=0,
|
n_harmonic_functions_dir: int = 0,
|
||||||
pooled_feature_dim=0,
|
pooled_feature_dim: int = 0,
|
||||||
):
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.mode = mode
|
self.mode = mode
|
||||||
self.output_dimensions = d_out
|
self.output_dimensions = d_out
|
||||||
dims = [d_in + feature_vector_size] + list(dims) + [d_out]
|
dims_full: List[int] = [d_in + feature_vector_size] + list(dims) + [d_out]
|
||||||
|
|
||||||
self.embedview_fn = None
|
self.embedview_fn = None
|
||||||
if n_harmonic_functions_dir > 0:
|
if n_harmonic_functions_dir > 0:
|
||||||
self.embedview_fn = HarmonicEmbedding(
|
self.embedview_fn = HarmonicEmbedding(
|
||||||
n_harmonic_functions_dir, append_input=True
|
n_harmonic_functions_dir, append_input=True
|
||||||
)
|
)
|
||||||
dims[0] += self.embedview_fn.get_output_dim() - 3
|
dims_full[0] += self.embedview_fn.get_output_dim() - 3
|
||||||
|
|
||||||
if pooled_feature_dim > 0:
|
if pooled_feature_dim > 0:
|
||||||
print("Pooled features in rendering network.")
|
print("Pooled features in rendering network.")
|
||||||
dims[0] += pooled_feature_dim
|
dims_full[0] += pooled_feature_dim
|
||||||
|
|
||||||
self.num_layers = len(dims)
|
self.num_layers = len(dims_full)
|
||||||
|
|
||||||
layers = []
|
layers = []
|
||||||
for layer_idx in range(self.num_layers - 1):
|
for layer_idx in range(self.num_layers - 1):
|
||||||
out_dim = dims[layer_idx + 1]
|
out_dim = dims_full[layer_idx + 1]
|
||||||
lin = nn.Linear(dims[layer_idx], out_dim)
|
lin = nn.Linear(dims_full[layer_idx], out_dim)
|
||||||
|
|
||||||
if weight_norm:
|
if weight_norm:
|
||||||
lin = nn.utils.weight_norm(lin)
|
lin = nn.utils.weight_norm(lin)
|
||||||
|
@ -4,12 +4,12 @@
|
|||||||
# This source code is licensed under the BSD-style license found in the
|
# This source code is licensed under the BSD-style license found in the
|
||||||
# LICENSE file in the root directory of this source tree.
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
|
|
||||||
import copy
|
|
||||||
import dataclasses
|
import dataclasses
|
||||||
import inspect
|
import inspect
|
||||||
|
import itertools
|
||||||
import warnings
|
import warnings
|
||||||
from collections import Counter, defaultdict
|
from collections import Counter, defaultdict
|
||||||
|
from enum import Enum
|
||||||
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, TypeVar, cast
|
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, TypeVar, cast
|
||||||
|
|
||||||
from omegaconf import DictConfig, OmegaConf, open_dict
|
from omegaconf import DictConfig, OmegaConf, open_dict
|
||||||
@ -409,13 +409,13 @@ def get_default_args(C, *, _do_not_process: Tuple[type, ...] = ()) -> DictConfig
|
|||||||
# straight away if C has already been processed.
|
# straight away if C has already been processed.
|
||||||
expand_args_fields(C, _do_not_process=_do_not_process)
|
expand_args_fields(C, _do_not_process=_do_not_process)
|
||||||
|
|
||||||
kwargs = {}
|
|
||||||
if dataclasses.is_dataclass(C):
|
if dataclasses.is_dataclass(C):
|
||||||
# Note that if get_default_args_field is used somewhere in C,
|
# Note that if get_default_args_field is used somewhere in C,
|
||||||
# this call is recursive. No special care is needed,
|
# this call is recursive. No special care is needed,
|
||||||
# because in practice get_default_args_field is used for
|
# because in practice get_default_args_field is used for
|
||||||
# separate types than the outer type.
|
# separate types than the outer type.
|
||||||
out = OmegaConf.structured(C)
|
|
||||||
|
out: DictConfig = OmegaConf.structured(C)
|
||||||
exclude = getattr(C, "_processed_members", ())
|
exclude = getattr(C, "_processed_members", ())
|
||||||
with open_dict(out):
|
with open_dict(out):
|
||||||
for field in exclude:
|
for field in exclude:
|
||||||
@ -425,16 +425,56 @@ def get_default_args(C, *, _do_not_process: Tuple[type, ...] = ()) -> DictConfig
|
|||||||
if _is_configurable_class(C):
|
if _is_configurable_class(C):
|
||||||
raise ValueError(f"Failed to process {C}")
|
raise ValueError(f"Failed to process {C}")
|
||||||
|
|
||||||
# returns dict of keyword args of a callable C
|
# regular class or function
|
||||||
sig = inspect.signature(C)
|
field_annotations = []
|
||||||
for pname, defval in dict(sig.parameters).items():
|
for pname, defval in _params_iter(C):
|
||||||
if defval.default == inspect.Parameter.empty:
|
default = defval.default
|
||||||
# print('skipping %s' % pname)
|
if default == inspect.Parameter.empty:
|
||||||
|
# we do not have a default value for the parameter
|
||||||
|
continue
|
||||||
|
|
||||||
|
if defval.annotation == inspect._empty:
|
||||||
|
raise ValueError(
|
||||||
|
"All arguments of the input callable have to be typed."
|
||||||
|
+ f" Argument '{pname}' does not have a type annotation."
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(default, set): # force OmegaConf to convert it to ListConfig
|
||||||
|
default = tuple(default)
|
||||||
|
|
||||||
|
if isinstance(default, (list, dict)):
|
||||||
|
# OmegaConf will convert to [Dict|List]Config, so it is safe to reuse the value
|
||||||
|
field_ = dataclasses.field(default_factory=lambda default=default: default)
|
||||||
|
elif not _is_immutable_type(defval.annotation, default):
|
||||||
continue
|
continue
|
||||||
else:
|
else:
|
||||||
kwargs[pname] = copy.deepcopy(defval.default)
|
# we can use a simple default argument for dataclass.field
|
||||||
|
field_ = dataclasses.field(default=default)
|
||||||
|
field_annotations.append((pname, defval.annotation, field_))
|
||||||
|
|
||||||
return DictConfig(kwargs)
|
# make a temp dataclass and generate a structured config from it.
|
||||||
|
return OmegaConf.structured(
|
||||||
|
dataclasses.make_dataclass(f"__{C.__name__}_default_args__", field_annotations)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _params_iter(C):
|
||||||
|
"""Returns dict of keyword args of a class or function C."""
|
||||||
|
if inspect.isclass(C):
|
||||||
|
return itertools.islice( # exclude `self`
|
||||||
|
inspect.signature(C.__init__).parameters.items(), 1, None
|
||||||
|
)
|
||||||
|
|
||||||
|
return inspect.signature(C).parameters.items()
|
||||||
|
|
||||||
|
|
||||||
|
def _is_immutable_type(type_: Type, val: Any) -> bool:
|
||||||
|
PRIMITIVE_TYPES = (int, float, bool, str, bytes, tuple)
|
||||||
|
# sometimes type can be too relaxed (e.g. Any), so we also check values
|
||||||
|
if isinstance(val, PRIMITIVE_TYPES):
|
||||||
|
return True
|
||||||
|
|
||||||
|
return type_ in PRIMITIVE_TYPES or issubclass(type_, Enum)
|
||||||
|
|
||||||
|
|
||||||
def _is_actually_dataclass(some_class) -> bool:
|
def _is_actually_dataclass(some_class) -> bool:
|
||||||
|
@ -8,7 +8,7 @@ import textwrap
|
|||||||
import unittest
|
import unittest
|
||||||
from dataclasses import dataclass, field, is_dataclass
|
from dataclasses import dataclass, field, is_dataclass
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import List, Optional, Tuple
|
from typing import Any, List, Optional, Set, Tuple
|
||||||
|
|
||||||
from omegaconf import DictConfig, ListConfig, OmegaConf, ValidationError
|
from omegaconf import DictConfig, ListConfig, OmegaConf, ValidationError
|
||||||
from pytorch3d.implicitron.tools.config import (
|
from pytorch3d.implicitron.tools.config import (
|
||||||
@ -216,7 +216,7 @@ class TestConfig(unittest.TestCase):
|
|||||||
# tolerated. But it would be nice to be able to
|
# tolerated. But it would be nice to be able to
|
||||||
# configure them.
|
# configure them.
|
||||||
class Foo:
|
class Foo:
|
||||||
def __init__(self, a=1, b=2):
|
def __init__(self, a: Any = 1, b: Any = 2):
|
||||||
self.a, self.b = a, b
|
self.a, self.b = a, b
|
||||||
|
|
||||||
@dataclass()
|
@dataclass()
|
||||||
@ -238,7 +238,6 @@ class TestConfig(unittest.TestCase):
|
|||||||
container_args = get_default_args(Container)
|
container_args = get_default_args(Container)
|
||||||
container = Container(**container_args)
|
container = Container(**container_args)
|
||||||
self.assertIsInstance(container.fruit, Orange)
|
self.assertIsInstance(container.fruit, Orange)
|
||||||
# self.assertIsInstance(container.bar, Bar)
|
|
||||||
|
|
||||||
container_defaulted = Container()
|
container_defaulted = Container()
|
||||||
container_defaulted.fruit_Pear_args.n_pips += 4
|
container_defaulted.fruit_Pear_args.n_pips += 4
|
||||||
@ -432,8 +431,13 @@ class TestConfig(unittest.TestCase):
|
|||||||
tuple_field: tuple = (3, True, "j")
|
tuple_field: tuple = (3, True, "j")
|
||||||
|
|
||||||
class SimpleClass:
|
class SimpleClass:
|
||||||
def __init__(self, tuple_member_=(3, 4)):
|
def __init__(
|
||||||
|
self,
|
||||||
|
tuple_member_: Tuple[int, int] = (3, 4),
|
||||||
|
set_member_: Set[int] = {2}, # noqa
|
||||||
|
):
|
||||||
self.tuple_member = tuple_member_
|
self.tuple_member = tuple_member_
|
||||||
|
self.set_member = set_member_
|
||||||
|
|
||||||
def get_tuple(self):
|
def get_tuple(self):
|
||||||
return self.tuple_member
|
return self.tuple_member
|
||||||
@ -459,6 +463,9 @@ class TestConfig(unittest.TestCase):
|
|||||||
# OmegaConf converts tuples to ListConfigs (which act like lists).
|
# OmegaConf converts tuples to ListConfigs (which act like lists).
|
||||||
self.assertEqual(simple.get_tuple(), [3, 4])
|
self.assertEqual(simple.get_tuple(), [3, 4])
|
||||||
self.assertTrue(isinstance(simple.get_tuple(), ListConfig))
|
self.assertTrue(isinstance(simple.get_tuple(), ListConfig))
|
||||||
|
# get_default_args converts sets to ListConfigs (which act like lists).
|
||||||
|
self.assertEqual(simple.set_member, [2])
|
||||||
|
self.assertTrue(isinstance(simple.set_member, ListConfig))
|
||||||
self.assertEqual(c.a_tuple, [4.0, 3.0])
|
self.assertEqual(c.a_tuple, [4.0, 3.0])
|
||||||
self.assertTrue(isinstance(c.a_tuple, ListConfig))
|
self.assertTrue(isinstance(c.a_tuple, ListConfig))
|
||||||
self.assertEqual(mydata.tuple_field, (3, True, "j"))
|
self.assertEqual(mydata.tuple_field, (3, True, "j"))
|
||||||
@ -514,10 +521,21 @@ class TestConfig(unittest.TestCase):
|
|||||||
B1 = "b1"
|
B1 = "b1"
|
||||||
B2 = "b2"
|
B2 = "b2"
|
||||||
|
|
||||||
|
# Test for a Configurable class, a function, and a regular class.
|
||||||
class C(Configurable):
|
class C(Configurable):
|
||||||
a: A = A.B1
|
a: A = A.B1
|
||||||
|
|
||||||
base = get_default_args(C)
|
# Also test for a calllable with enum arguments.
|
||||||
|
def C_fn(a: A = A.B1):
|
||||||
|
pass
|
||||||
|
|
||||||
|
class C_cl:
|
||||||
|
def __init__(self, a: A = A.B1) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
for C_ in [C, C_fn, C_cl]:
|
||||||
|
base = get_default_args(C_)
|
||||||
|
self.assertEqual(base.a, A.B1)
|
||||||
replaced = OmegaConf.merge(base, {"a": "B2"})
|
replaced = OmegaConf.merge(base, {"a": "B2"})
|
||||||
self.assertEqual(replaced.a, A.B2)
|
self.assertEqual(replaced.a, A.B2)
|
||||||
with self.assertRaises(ValidationError):
|
with self.assertRaises(ValidationError):
|
||||||
@ -577,34 +595,56 @@ class TestConfig(unittest.TestCase):
|
|||||||
class MockDataclass:
|
class MockDataclass:
|
||||||
field_no_default: int
|
field_no_default: int
|
||||||
field_primitive_type: int = 42
|
field_primitive_type: int = 42
|
||||||
field_reference_type: List[int] = field(default_factory=lambda: [])
|
field_list_type: List[int] = field(default_factory=lambda: [])
|
||||||
|
|
||||||
|
|
||||||
|
class RefObject:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
REF_OBJECT = RefObject()
|
||||||
|
|
||||||
|
|
||||||
class MockClassWithInit: # noqa: B903
|
class MockClassWithInit: # noqa: B903
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
field_no_nothing,
|
||||||
field_no_default: int,
|
field_no_default: int,
|
||||||
field_primitive_type: int = 42,
|
field_primitive_type: int = 42,
|
||||||
field_reference_type: List[int] = [], # noqa: B006
|
field_list_type: List[int] = [], # noqa: B006
|
||||||
|
field_reference_type: RefObject = REF_OBJECT,
|
||||||
):
|
):
|
||||||
|
self.field_no_nothing = field_no_nothing
|
||||||
self.field_no_default = field_no_default
|
self.field_no_default = field_no_default
|
||||||
self.field_primitive_type = field_primitive_type
|
self.field_primitive_type = field_primitive_type
|
||||||
|
self.field_list_type = field_list_type
|
||||||
self.field_reference_type = field_reference_type
|
self.field_reference_type = field_reference_type
|
||||||
|
|
||||||
|
|
||||||
class TestRawClasses(unittest.TestCase):
|
class TestRawClasses(unittest.TestCase):
|
||||||
|
def setUp(self) -> None:
|
||||||
|
self._instances = {
|
||||||
|
MockDataclass: MockDataclass(field_no_default=0),
|
||||||
|
MockClassWithInit: MockClassWithInit(
|
||||||
|
field_no_nothing="tratata", field_no_default=0
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
def test_get_default_args(self):
|
def test_get_default_args(self):
|
||||||
for cls in [MockDataclass, MockClassWithInit]:
|
for cls in [MockDataclass, MockClassWithInit]:
|
||||||
dataclass_defaults = get_default_args(cls)
|
dataclass_defaults = get_default_args(cls)
|
||||||
inst = cls(field_no_default=0)
|
# DictConfig fields with missing values are `not in`
|
||||||
|
self.assertNotIn("field_no_default", dataclass_defaults)
|
||||||
|
self.assertNotIn("field_no_nothing", dataclass_defaults)
|
||||||
|
self.assertNotIn("field_reference_type", dataclass_defaults)
|
||||||
|
if cls == MockDataclass: # we don't remove undefaulted from dataclasses
|
||||||
dataclass_defaults.field_no_default = 0
|
dataclass_defaults.field_no_default = 0
|
||||||
for name, val in dataclass_defaults.items():
|
for name, val in dataclass_defaults.items():
|
||||||
self.assertTrue(hasattr(inst, name))
|
self.assertTrue(hasattr(self._instances[cls], name))
|
||||||
self.assertEqual(val, getattr(inst, name))
|
self.assertEqual(val, getattr(self._instances[cls], name))
|
||||||
|
|
||||||
def test_get_default_args_readonly(self):
|
def test_get_default_args_readonly(self):
|
||||||
for cls in [MockDataclass, MockClassWithInit]:
|
for cls in [MockDataclass, MockClassWithInit]:
|
||||||
dataclass_defaults = get_default_args(cls)
|
dataclass_defaults = get_default_args(cls)
|
||||||
dataclass_defaults["field_reference_type"].append(13)
|
dataclass_defaults["field_list_type"].append(13)
|
||||||
inst = cls(field_no_default=0)
|
self.assertEqual(self._instances[cls].field_list_type, [])
|
||||||
self.assertEqual(inst.field_reference_type, [])
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user