make x_enabled compulsory

Summary: Optional[some_configurable] won't autogenerate the enabled flag

Reviewed By: shapovalov

Differential Revision: D41522104

fbshipit-source-id: 555ff6b343faf6f18aad2f92fbb7c341f5e991c6
This commit is contained in:
Jeremy Reizenstein 2022-11-24 09:38:02 -08:00 committed by Facebook GitHub Bot
parent 1706eb8216
commit 60ab1cdb72
2 changed files with 10 additions and 5 deletions

View File

@ -783,16 +783,16 @@ def expand_args_fields(
Similarly, replace, Similarly, replace,
x: Optional[X] x: Optional[X]
x_enabled: bool = ...
and optionally and optionally
def create_x(self):... def create_x(self):...
x_enabled: bool = ...
with with
x_args: dict = dataclasses.field(default_factory=lambda: get_default_args(X)) x_args: dict = dataclasses.field(default_factory=lambda: get_default_args(X))
x_enabled: bool = False x_enabled: bool = ...
def create_x(self): def create_x(self):
self.create_x_impl(self.x_enabled, self.x_args) self.create_x_impl(self.x_enabled, self.x_args)
@ -1091,8 +1091,10 @@ def _process_member(
if process_type == _ProcessType.OPTIONAL_CONFIGURABLE: if process_type == _ProcessType.OPTIONAL_CONFIGURABLE:
enabled_name = name + ENABLED_SUFFIX enabled_name = name + ENABLED_SUFFIX
if enabled_name not in some_class.__annotations__: if enabled_name not in some_class.__annotations__:
some_class.__annotations__[enabled_name] = bool raise ValueError(
setattr(some_class, enabled_name, False) f"{name} is an Optional[{type_.__name__}] member "
f"but there is no corresponding member {enabled_name}."
)
creation_function_name = f"{CREATE_PREFIX}{name}" creation_function_name = f"{CREATE_PREFIX}{name}"
if not hasattr(some_class, creation_function_name): if not hasattr(some_class, creation_function_name):

View File

@ -446,6 +446,7 @@ class TestConfig(unittest.TestCase):
b2: Optional[B] b2: Optional[B]
b3: Optional[B] b3: Optional[B]
b2_enabled: bool = True b2_enabled: bool = True
b3_enabled: bool = False
def __post_init__(self): def __post_init__(self):
run_auto_creation(self) run_auto_creation(self)
@ -681,9 +682,10 @@ class TestConfig(unittest.TestCase):
def test_remove_unused_components_optional(self): def test_remove_unused_components_optional(self):
class MainTestWrapper(Configurable): class MainTestWrapper(Configurable):
mt: Optional[MainTest] mt: Optional[MainTest]
mt_enabled: bool = False
args = get_default_args(MainTestWrapper) args = get_default_args(MainTestWrapper)
self.assertEqual(list(args.keys()), ["mt_args", "mt_enabled"]) self.assertEqual(list(args.keys()), ["mt_enabled", "mt_args"])
remove_unused_components(args) remove_unused_components(args)
self.assertEqual(OmegaConf.to_yaml(args), "mt_enabled: false\n") self.assertEqual(OmegaConf.to_yaml(args), "mt_enabled: false\n")
@ -775,6 +777,7 @@ class TestConfig(unittest.TestCase):
boring_o: Optional[BoringConfigurable] boring_o: Optional[BoringConfigurable]
boring_o_enabled: bool = True boring_o_enabled: bool = True
boring_0: Optional[BoringConfigurable] boring_0: Optional[BoringConfigurable]
boring_0_enabled: bool = False
def __post_init__(self): def __post_init__(self):
run_auto_creation(self) run_auto_creation(self)