mirror of
				https://github.com/facebookresearch/pytorch3d.git
				synced 2025-11-04 18:02:14 +08:00 
			
		
		
		
	fix to get_default_args(instance)
Summary:
Small config system fix. Allows get_default_args to work on an instance which has been created with a dict (instead of a DictConfig) as an args field. E.g.
```
gm = GenericModel(
        raysampler_AdaptiveRaySampler_args={"scene_extent": 4.0}
    )
    OmegaConf.structured(gm1)
```
Reviewed By: shapovalov
Differential Revision: D40341047
fbshipit-source-id: 587d0e8262e271df442a80858949a48e5d6db3df
			
			
This commit is contained in:
		
							parent
							
								
									76cddd90be
								
							
						
					
					
						commit
						4d9215b3b4
					
				@ -709,8 +709,8 @@ def expand_args_fields(
 | 
			
		||||
 | 
			
		||||
    with
 | 
			
		||||
 | 
			
		||||
        x_Y_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(Y))
 | 
			
		||||
        x_Z_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(Z))
 | 
			
		||||
        x_Y_args: dict = dataclasses.field(default_factory=lambda: get_default_args(Y))
 | 
			
		||||
        x_Z_args: dict = dataclasses.field(default_factory=lambda: get_default_args(Z))
 | 
			
		||||
        def create_x(self):
 | 
			
		||||
            args = self.getattr(f"x_{self.x_class_type}_args")
 | 
			
		||||
            self.create_x_impl(self.x_class_type, args)
 | 
			
		||||
@ -733,8 +733,8 @@ def expand_args_fields(
 | 
			
		||||
 | 
			
		||||
    with
 | 
			
		||||
 | 
			
		||||
        x_Y_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(Y))
 | 
			
		||||
        x_Z_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(Z))
 | 
			
		||||
        x_Y_args: dict = dataclasses.field(default_factory=lambda: get_default_args(Y))
 | 
			
		||||
        x_Z_args: dict = dataclasses.field(default_factory=lambda: get_default_args(Z))
 | 
			
		||||
        def create_x(self):
 | 
			
		||||
            if self.x_class_type is None:
 | 
			
		||||
                args = None
 | 
			
		||||
@ -764,7 +764,7 @@ def expand_args_fields(
 | 
			
		||||
 | 
			
		||||
    will be replaced with
 | 
			
		||||
 | 
			
		||||
        x_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(X))
 | 
			
		||||
        x_args: dict = dataclasses.field(default_factory=lambda: get_default_args(X))
 | 
			
		||||
        def create_x(self):
 | 
			
		||||
            self.create_x_impl(True, self.x_args)
 | 
			
		||||
 | 
			
		||||
@ -786,7 +786,7 @@ def expand_args_fields(
 | 
			
		||||
 | 
			
		||||
    with
 | 
			
		||||
 | 
			
		||||
        x_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(X))
 | 
			
		||||
        x_args: dict = dataclasses.field(default_factory=lambda: get_default_args(X))
 | 
			
		||||
        x_enabled: bool = False
 | 
			
		||||
        def create_x(self):
 | 
			
		||||
            self.create_x_impl(self.x_enabled, self.x_args)
 | 
			
		||||
@ -818,6 +818,11 @@ def expand_args_fields(
 | 
			
		||||
    then the default_factory of x_args will also have a call to x_tweak_args(X, x_args) and
 | 
			
		||||
    the default_factory of x_Y_args will also have a call to x_tweak_args(Y, x_Y_args).
 | 
			
		||||
 | 
			
		||||
    Note that although the *_args members are intended to have type DictConfig, they
 | 
			
		||||
    are actually internally annotated as dicts. OmegaConf is happy to see a DictConfig
 | 
			
		||||
    in place of a dict, but not vice-versa. Allowing dict lets a class user specify
 | 
			
		||||
    x_args as an explicit dict without getting an incomprehensible error.
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
        some_class: the class to be processed
 | 
			
		||||
        _do_not_process: Internal use for get_default_args: Because get_default_args calls
 | 
			
		||||
@ -1040,7 +1045,7 @@ def _process_member(
 | 
			
		||||
                raise ValueError(
 | 
			
		||||
                    f"Cannot generate {args_name} because it is already present."
 | 
			
		||||
                )
 | 
			
		||||
            some_class.__annotations__[args_name] = DictConfig
 | 
			
		||||
            some_class.__annotations__[args_name] = dict
 | 
			
		||||
            if hook is not None:
 | 
			
		||||
                hook_closed = partial(hook, derived_type)
 | 
			
		||||
            else:
 | 
			
		||||
@ -1064,7 +1069,7 @@ def _process_member(
 | 
			
		||||
        if issubclass(type_, some_class) or type_ in _do_not_process:
 | 
			
		||||
            raise ValueError(f"Cannot process {type_} inside {some_class}")
 | 
			
		||||
 | 
			
		||||
        some_class.__annotations__[args_name] = DictConfig
 | 
			
		||||
        some_class.__annotations__[args_name] = dict
 | 
			
		||||
        if hook is not None:
 | 
			
		||||
            hook_closed = partial(hook, type_)
 | 
			
		||||
        else:
 | 
			
		||||
 | 
			
		||||
@ -687,6 +687,34 @@ class TestConfig(unittest.TestCase):
 | 
			
		||||
        remove_unused_components(args)
 | 
			
		||||
        self.assertEqual(OmegaConf.to_yaml(args), "mt_enabled: false\n")
 | 
			
		||||
 | 
			
		||||
    def test_get_instance_args(self):
 | 
			
		||||
        mt1, mt2 = [
 | 
			
		||||
            MainTest(
 | 
			
		||||
                n_ids=0,
 | 
			
		||||
                n_reps=909,
 | 
			
		||||
                the_fruit_class_type="Pear",
 | 
			
		||||
                the_second_fruit_class_type="Pear",
 | 
			
		||||
                the_fruit_Pear_args=DictConfig({}),
 | 
			
		||||
                the_second_fruit_Pear_args={},
 | 
			
		||||
            )
 | 
			
		||||
            for _ in range(2)
 | 
			
		||||
        ]
 | 
			
		||||
        # Two equivalent ways to get the DictConfig back out of an instance.
 | 
			
		||||
        cfg1 = OmegaConf.structured(mt1)
 | 
			
		||||
        cfg2 = get_default_args(mt2)
 | 
			
		||||
        self.assertEqual(cfg1, cfg2)
 | 
			
		||||
        self.assertEqual(len(cfg1.the_second_fruit_Pear_args), 0)
 | 
			
		||||
        self.assertEqual(len(mt2.the_second_fruit_Pear_args), 0)
 | 
			
		||||
 | 
			
		||||
        from_cfg = MainTest(**cfg2)
 | 
			
		||||
        self.assertEqual(len(from_cfg.the_second_fruit_Pear_args), 0)
 | 
			
		||||
 | 
			
		||||
        # If you want the complete args, merge with the defaults.
 | 
			
		||||
        merged_args = OmegaConf.merge(get_default_args(MainTest), cfg2)
 | 
			
		||||
        from_merged = MainTest(**merged_args)
 | 
			
		||||
        self.assertEqual(len(from_merged.the_second_fruit_Pear_args), 1)
 | 
			
		||||
        self.assertEqual(from_merged.n_reps, 909)
 | 
			
		||||
 | 
			
		||||
    def test_tweak_hook(self):
 | 
			
		||||
        class A(Configurable):
 | 
			
		||||
            n: int = 9
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user