fix to get_default_args(instance)

Summary:
Small config system fix. Allows get_default_args to work on an instance which has been created with a dict (instead of a DictConfig) as an args field. E.g.

```
gm = GenericModel(
        raysampler_AdaptiveRaySampler_args={"scene_extent": 4.0}
    )
    OmegaConf.structured(gm1)
```

Reviewed By: shapovalov

Differential Revision: D40341047

fbshipit-source-id: 587d0e8262e271df442a80858949a48e5d6db3df
This commit is contained in:
Jeremy Reizenstein
2022-10-13 06:05:07 -07:00
committed by Facebook GitHub Bot
parent 76cddd90be
commit 4d9215b3b4
2 changed files with 41 additions and 8 deletions

View File

@@ -709,8 +709,8 @@ def expand_args_fields(
with
x_Y_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(Y))
x_Z_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(Z))
x_Y_args: dict = dataclasses.field(default_factory=lambda: get_default_args(Y))
x_Z_args: dict = dataclasses.field(default_factory=lambda: get_default_args(Z))
def create_x(self):
args = self.getattr(f"x_{self.x_class_type}_args")
self.create_x_impl(self.x_class_type, args)
@@ -733,8 +733,8 @@ def expand_args_fields(
with
x_Y_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(Y))
x_Z_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(Z))
x_Y_args: dict = dataclasses.field(default_factory=lambda: get_default_args(Y))
x_Z_args: dict = dataclasses.field(default_factory=lambda: get_default_args(Z))
def create_x(self):
if self.x_class_type is None:
args = None
@@ -764,7 +764,7 @@ def expand_args_fields(
will be replaced with
x_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(X))
x_args: dict = dataclasses.field(default_factory=lambda: get_default_args(X))
def create_x(self):
self.create_x_impl(True, self.x_args)
@@ -786,7 +786,7 @@ def expand_args_fields(
with
x_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(X))
x_args: dict = dataclasses.field(default_factory=lambda: get_default_args(X))
x_enabled: bool = False
def create_x(self):
self.create_x_impl(self.x_enabled, self.x_args)
@@ -818,6 +818,11 @@ def expand_args_fields(
then the default_factory of x_args will also have a call to x_tweak_args(X, x_args) and
the default_factory of x_Y_args will also have a call to x_tweak_args(Y, x_Y_args).
Note that although the *_args members are intended to have type DictConfig, they
are actually internally annotated as dicts. OmegaConf is happy to see a DictConfig
in place of a dict, but not vice-versa. Allowing dict lets a class user specify
x_args as an explicit dict without getting an incomprehensible error.
Args:
some_class: the class to be processed
_do_not_process: Internal use for get_default_args: Because get_default_args calls
@@ -1040,7 +1045,7 @@ def _process_member(
raise ValueError(
f"Cannot generate {args_name} because it is already present."
)
some_class.__annotations__[args_name] = DictConfig
some_class.__annotations__[args_name] = dict
if hook is not None:
hook_closed = partial(hook, derived_type)
else:
@@ -1064,7 +1069,7 @@ def _process_member(
if issubclass(type_, some_class) or type_ in _do_not_process:
raise ValueError(f"Cannot process {type_} inside {some_class}")
some_class.__annotations__[args_name] = DictConfig
some_class.__annotations__[args_name] = dict
if hook is not None:
hook_closed = partial(hook, type_)
else: