Return a typed structured config from default_args for callables

Summary:
Before this fix, running get_default_args(C: Callable) returned an unstructured DictConfig, which caused Enums to be handled incorrectly. The function now builds a temporary dataclass from the callable's signature and returns a typed structured config instead.

WIP update: tests currently still fail whenever a function signature contains an untyped argument; this still needs to be fixed.
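For illustration, a minimal self-contained sketch of the mechanism the new code path uses (Mode, __render_default_args__, and the field values are made up for this example; they are not part of the diff):

from dataclasses import field, make_dataclass
from enum import Enum

from omegaconf import OmegaConf

class Mode(Enum):  # hypothetical enum, for illustration only
    IDR = "idr"
    NERF = "nerf"

# Build a temporary dataclass from (name, annotation, field) triples,
# then let OmegaConf derive a typed structured config from it.
Args = make_dataclass(
    "__render_default_args__",
    [
        ("mode", Mode, field(default=Mode.IDR)),
        ("n_samples", int, field(default=64)),
    ],
)
cfg = OmegaConf.structured(Args)
assert cfg.mode == Mode.IDR  # the Enum type survives, unlike with a plain DictConfig
assert OmegaConf.merge(cfg, {"mode": "NERF"}).mode == Mode.NERF  # member names are coerced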

Reviewed By: bottler

Differential Revision: D34932124

fbshipit-source-id: ecdc45c738633cfea5caa7480ba4f790ece931e8
Author: Roman Shapovalov
Date: 2022-03-25 07:08:01 -07:00
Committed by: Facebook GitHub Bot
parent 8ac5e8f083
commit 645a47d054
4 changed files with 133 additions and 50 deletions


@@ -2,6 +2,9 @@
 # Adapted from RenderingNetwork from IDR
 # https://github.com/lioryariv/idr/
 # Copyright (c) 2020 Lior Yariv
+from typing import List, Tuple
 import torch
 from pytorch3d.renderer.implicit import HarmonicEmbedding, RayBundle
 from torch import nn
@@ -10,38 +13,38 @@ from torch import nn
 class RayNormalColoringNetwork(torch.nn.Module):
     def __init__(
         self,
-        feature_vector_size=3,
-        mode="idr",
-        d_in=9,
-        d_out=3,
-        dims=(512, 512, 512, 512),
-        weight_norm=True,
-        n_harmonic_functions_dir=0,
-        pooled_feature_dim=0,
-    ):
+        feature_vector_size: int = 3,
+        mode: str = "idr",
+        d_in: int = 9,
+        d_out: int = 3,
+        dims: Tuple[int, ...] = (512, 512, 512, 512),
+        weight_norm: bool = True,
+        n_harmonic_functions_dir: int = 0,
+        pooled_feature_dim: int = 0,
+    ) -> None:
         super().__init__()
         self.mode = mode
         self.output_dimensions = d_out
-        dims = [d_in + feature_vector_size] + list(dims) + [d_out]
+        dims_full: List[int] = [d_in + feature_vector_size] + list(dims) + [d_out]
         self.embedview_fn = None
         if n_harmonic_functions_dir > 0:
             self.embedview_fn = HarmonicEmbedding(
                 n_harmonic_functions_dir, append_input=True
             )
-            dims[0] += self.embedview_fn.get_output_dim() - 3
+            dims_full[0] += self.embedview_fn.get_output_dim() - 3
         if pooled_feature_dim > 0:
             print("Pooled features in rendering network.")
-            dims[0] += pooled_feature_dim
+            dims_full[0] += pooled_feature_dim
-        self.num_layers = len(dims)
+        self.num_layers = len(dims_full)
         layers = []
         for layer_idx in range(self.num_layers - 1):
-            out_dim = dims[layer_idx + 1]
-            lin = nn.Linear(dims[layer_idx], out_dim)
+            out_dim = dims_full[layer_idx + 1]
+            lin = nn.Linear(dims_full[layer_idx], out_dim)
             if weight_norm:
                 lin = nn.utils.weight_norm(lin)

@@ -4,12 +4,12 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
-import copy
 import dataclasses
 import inspect
+import itertools
 import warnings
 from collections import Counter, defaultdict
+from enum import Enum
 from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, TypeVar, cast
 from omegaconf import DictConfig, OmegaConf, open_dict
@@ -409,13 +409,13 @@ def get_default_args(C, *, _do_not_process: Tuple[type, ...] = ()) -> DictConfig
     # straight away if C has already been processed.
     expand_args_fields(C, _do_not_process=_do_not_process)
-    kwargs = {}
     if dataclasses.is_dataclass(C):
         # Note that if get_default_args_field is used somewhere in C,
         # this call is recursive. No special care is needed,
         # because in practice get_default_args_field is used for
         # separate types than the outer type.
-        out = OmegaConf.structured(C)
+        out: DictConfig = OmegaConf.structured(C)
         exclude = getattr(C, "_processed_members", ())
         with open_dict(out):
             for field in exclude:
@@ -425,16 +425,56 @@ def get_default_args(C, *, _do_not_process: Tuple[type, ...] = ()) -> DictConfig
     if _is_configurable_class(C):
         raise ValueError(f"Failed to process {C}")
-    # returns dict of keyword args of a callable C
-    sig = inspect.signature(C)
-    for pname, defval in dict(sig.parameters).items():
-        if defval.default == inspect.Parameter.empty:
-            # print('skipping %s' % pname)
+    # regular class or function
+    field_annotations = []
+    for pname, defval in _params_iter(C):
+        default = defval.default
+        if default == inspect.Parameter.empty:
+            # we do not have a default value for the parameter
             continue
+        if defval.annotation == inspect._empty:
+            raise ValueError(
+                "All arguments of the input callable have to be typed."
+                + f" Argument '{pname}' does not have a type annotation."
+            )
+        if isinstance(default, set):  # force OmegaConf to convert it to ListConfig
+            default = tuple(default)
+        if isinstance(default, (list, dict)):
+            # OmegaConf will convert to [Dict|List]Config, so it is safe to reuse the value
+            field_ = dataclasses.field(default_factory=lambda default=default: default)
+        elif not _is_immutable_type(defval.annotation, default):
+            continue
+        else:
-            kwargs[pname] = copy.deepcopy(defval.default)
+            # we can use a simple default argument for dataclasses.field
+            field_ = dataclasses.field(default=default)
+        field_annotations.append((pname, defval.annotation, field_))
-    return DictConfig(kwargs)
+    # make a temp dataclass and generate a structured config from it.
+    return OmegaConf.structured(
+        dataclasses.make_dataclass(f"__{C.__name__}_default_args__", field_annotations)
+    )
+
+
+def _params_iter(C):
+    """Iterates over the parameters (keyword arguments) of a class or function C."""
+    if inspect.isclass(C):
+        return itertools.islice(  # exclude `self`
+            inspect.signature(C.__init__).parameters.items(), 1, None
+        )
+    return inspect.signature(C).parameters.items()
+
+
+def _is_immutable_type(type_: Type, val: Any) -> bool:
+    PRIMITIVE_TYPES = (int, float, bool, str, bytes, tuple)
+    # sometimes the type can be too relaxed (e.g. Any), so we also check the value
+    if isinstance(val, PRIMITIVE_TYPES):
+        return True
+    return type_ in PRIMITIVE_TYPES or issubclass(type_, Enum)
+
+
 def _is_actually_dataclass(some_class) -> bool:
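Taken together, the new path behaves roughly as follows; a hedged usage sketch (the import path is assumed, and Color/Net are invented for this example):

# from pytorch3d.implicitron.tools.config import get_default_args  # module path assumed
from enum import Enum

class Color(Enum):  # hypothetical
    RED = "red"
    BLUE = "blue"

class Net:
    def __init__(self, color: Color = Color.RED, width: int = 8) -> None:
        pass

cfg = get_default_args(Net)    # _params_iter skips `self`
assert cfg.color == Color.RED  # Enum preserved via the temporary dataclass
assert cfg.width == 8

def f(x=1):  # default present, annotation missing
    pass

# get_default_args(f)  # would raise ValueError, per the WIP note above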