mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
fix to get_default_args(instance)
Summary: Small config system fix. Allows get_default_args to work on an instance which has been created with a dict (instead of a DictConfig) as an args field. E.g. ``` gm = GenericModel( raysampler_AdaptiveRaySampler_args={"scene_extent": 4.0} ) OmegaConf.structured(gm1) ``` Reviewed By: shapovalov Differential Revision: D40341047 fbshipit-source-id: 587d0e8262e271df442a80858949a48e5d6db3df
This commit is contained in:
parent
76cddd90be
commit
4d9215b3b4
@ -709,8 +709,8 @@ def expand_args_fields(
|
||||
|
||||
with
|
||||
|
||||
x_Y_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(Y))
|
||||
x_Z_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(Z))
|
||||
x_Y_args: dict = dataclasses.field(default_factory=lambda: get_default_args(Y))
|
||||
x_Z_args: dict = dataclasses.field(default_factory=lambda: get_default_args(Z))
|
||||
def create_x(self):
|
||||
args = self.getattr(f"x_{self.x_class_type}_args")
|
||||
self.create_x_impl(self.x_class_type, args)
|
||||
@ -733,8 +733,8 @@ def expand_args_fields(
|
||||
|
||||
with
|
||||
|
||||
x_Y_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(Y))
|
||||
x_Z_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(Z))
|
||||
x_Y_args: dict = dataclasses.field(default_factory=lambda: get_default_args(Y))
|
||||
x_Z_args: dict = dataclasses.field(default_factory=lambda: get_default_args(Z))
|
||||
def create_x(self):
|
||||
if self.x_class_type is None:
|
||||
args = None
|
||||
@ -764,7 +764,7 @@ def expand_args_fields(
|
||||
|
||||
will be replaced with
|
||||
|
||||
x_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(X))
|
||||
x_args: dict = dataclasses.field(default_factory=lambda: get_default_args(X))
|
||||
def create_x(self):
|
||||
self.create_x_impl(True, self.x_args)
|
||||
|
||||
@ -786,7 +786,7 @@ def expand_args_fields(
|
||||
|
||||
with
|
||||
|
||||
x_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(X))
|
||||
x_args: dict = dataclasses.field(default_factory=lambda: get_default_args(X))
|
||||
x_enabled: bool = False
|
||||
def create_x(self):
|
||||
self.create_x_impl(self.x_enabled, self.x_args)
|
||||
@ -818,6 +818,11 @@ def expand_args_fields(
|
||||
then the default_factory of x_args will also have a call to x_tweak_args(X, x_args) and
|
||||
the default_factory of x_Y_args will also have a call to x_tweak_args(Y, x_Y_args).
|
||||
|
||||
Note that although the *_args members are intended to have type DictConfig, they
|
||||
are actually internally annotated as dicts. OmegaConf is happy to see a DictConfig
|
||||
in place of a dict, but not vice-versa. Allowing dict lets a class user specify
|
||||
x_args as an explicit dict without getting an incomprehensible error.
|
||||
|
||||
Args:
|
||||
some_class: the class to be processed
|
||||
_do_not_process: Internal use for get_default_args: Because get_default_args calls
|
||||
@ -1040,7 +1045,7 @@ def _process_member(
|
||||
raise ValueError(
|
||||
f"Cannot generate {args_name} because it is already present."
|
||||
)
|
||||
some_class.__annotations__[args_name] = DictConfig
|
||||
some_class.__annotations__[args_name] = dict
|
||||
if hook is not None:
|
||||
hook_closed = partial(hook, derived_type)
|
||||
else:
|
||||
@ -1064,7 +1069,7 @@ def _process_member(
|
||||
if issubclass(type_, some_class) or type_ in _do_not_process:
|
||||
raise ValueError(f"Cannot process {type_} inside {some_class}")
|
||||
|
||||
some_class.__annotations__[args_name] = DictConfig
|
||||
some_class.__annotations__[args_name] = dict
|
||||
if hook is not None:
|
||||
hook_closed = partial(hook, type_)
|
||||
else:
|
||||
|
@ -687,6 +687,34 @@ class TestConfig(unittest.TestCase):
|
||||
remove_unused_components(args)
|
||||
self.assertEqual(OmegaConf.to_yaml(args), "mt_enabled: false\n")
|
||||
|
||||
def test_get_instance_args(self):
|
||||
mt1, mt2 = [
|
||||
MainTest(
|
||||
n_ids=0,
|
||||
n_reps=909,
|
||||
the_fruit_class_type="Pear",
|
||||
the_second_fruit_class_type="Pear",
|
||||
the_fruit_Pear_args=DictConfig({}),
|
||||
the_second_fruit_Pear_args={},
|
||||
)
|
||||
for _ in range(2)
|
||||
]
|
||||
# Two equivalent ways to get the DictConfig back out of an instance.
|
||||
cfg1 = OmegaConf.structured(mt1)
|
||||
cfg2 = get_default_args(mt2)
|
||||
self.assertEqual(cfg1, cfg2)
|
||||
self.assertEqual(len(cfg1.the_second_fruit_Pear_args), 0)
|
||||
self.assertEqual(len(mt2.the_second_fruit_Pear_args), 0)
|
||||
|
||||
from_cfg = MainTest(**cfg2)
|
||||
self.assertEqual(len(from_cfg.the_second_fruit_Pear_args), 0)
|
||||
|
||||
# If you want the complete args, merge with the defaults.
|
||||
merged_args = OmegaConf.merge(get_default_args(MainTest), cfg2)
|
||||
from_merged = MainTest(**merged_args)
|
||||
self.assertEqual(len(from_merged.the_second_fruit_Pear_args), 1)
|
||||
self.assertEqual(from_merged.n_reps, 909)
|
||||
|
||||
def test_tweak_hook(self):
|
||||
class A(Configurable):
|
||||
n: int = 9
|
||||
|
Loading…
x
Reference in New Issue
Block a user