mirror of
				https://github.com/facebookresearch/pytorch3d.git
				synced 2025-11-04 18:02:14 +08:00 
			
		
		
		
	make x_enabled compulsory
Summary: Optional[some_configurable] won't autogenerate the enabled flag Reviewed By: shapovalov Differential Revision: D41522104 fbshipit-source-id: 555ff6b343faf6f18aad2f92fbb7c341f5e991c6
This commit is contained in:
		
							parent
							
								
									1706eb8216
								
							
						
					
					
						commit
						60ab1cdb72
					
				@ -783,16 +783,16 @@ def expand_args_fields(
 | 
			
		||||
    Similarly, replace,
 | 
			
		||||
 | 
			
		||||
        x: Optional[X]
 | 
			
		||||
        x_enabled: bool = ...
 | 
			
		||||
 | 
			
		||||
    and optionally
 | 
			
		||||
 | 
			
		||||
        def create_x(self):...
 | 
			
		||||
        x_enabled: bool = ...
 | 
			
		||||
 | 
			
		||||
    with
 | 
			
		||||
 | 
			
		||||
        x_args: dict = dataclasses.field(default_factory=lambda: get_default_args(X))
 | 
			
		||||
        x_enabled: bool = False
 | 
			
		||||
        x_enabled: bool = ...
 | 
			
		||||
        def create_x(self):
 | 
			
		||||
            self.create_x_impl(self.x_enabled, self.x_args)
 | 
			
		||||
 | 
			
		||||
@ -1091,8 +1091,10 @@ def _process_member(
 | 
			
		||||
        if process_type == _ProcessType.OPTIONAL_CONFIGURABLE:
 | 
			
		||||
            enabled_name = name + ENABLED_SUFFIX
 | 
			
		||||
            if enabled_name not in some_class.__annotations__:
 | 
			
		||||
                some_class.__annotations__[enabled_name] = bool
 | 
			
		||||
                setattr(some_class, enabled_name, False)
 | 
			
		||||
                raise ValueError(
 | 
			
		||||
                    f"{name} is an Optional[{type_.__name__}] member "
 | 
			
		||||
                    f"but there is no corresponding member {enabled_name}."
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
    creation_function_name = f"{CREATE_PREFIX}{name}"
 | 
			
		||||
    if not hasattr(some_class, creation_function_name):
 | 
			
		||||
 | 
			
		||||
@ -446,6 +446,7 @@ class TestConfig(unittest.TestCase):
 | 
			
		||||
            b2: Optional[B]
 | 
			
		||||
            b3: Optional[B]
 | 
			
		||||
            b2_enabled: bool = True
 | 
			
		||||
            b3_enabled: bool = False
 | 
			
		||||
 | 
			
		||||
            def __post_init__(self):
 | 
			
		||||
                run_auto_creation(self)
 | 
			
		||||
@ -681,9 +682,10 @@ class TestConfig(unittest.TestCase):
 | 
			
		||||
    def test_remove_unused_components_optional(self):
 | 
			
		||||
        class MainTestWrapper(Configurable):
 | 
			
		||||
            mt: Optional[MainTest]
 | 
			
		||||
            mt_enabled: bool = False
 | 
			
		||||
 | 
			
		||||
        args = get_default_args(MainTestWrapper)
 | 
			
		||||
        self.assertEqual(list(args.keys()), ["mt_args", "mt_enabled"])
 | 
			
		||||
        self.assertEqual(list(args.keys()), ["mt_enabled", "mt_args"])
 | 
			
		||||
        remove_unused_components(args)
 | 
			
		||||
        self.assertEqual(OmegaConf.to_yaml(args), "mt_enabled: false\n")
 | 
			
		||||
 | 
			
		||||
@ -775,6 +777,7 @@ class TestConfig(unittest.TestCase):
 | 
			
		||||
                boring_o: Optional[BoringConfigurable]
 | 
			
		||||
                boring_o_enabled: bool = True
 | 
			
		||||
                boring_0: Optional[BoringConfigurable]
 | 
			
		||||
                boring_0_enabled: bool = False
 | 
			
		||||
 | 
			
		||||
                def __post_init__(self):
 | 
			
		||||
                    run_auto_creation(self)
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user