mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
make x_enabled compulsory
Summary: Optional[some_configurable] won't autogenerate the enabled flag Reviewed By: shapovalov Differential Revision: D41522104 fbshipit-source-id: 555ff6b343faf6f18aad2f92fbb7c341f5e991c6
This commit is contained in:
parent
1706eb8216
commit
60ab1cdb72
@ -783,16 +783,16 @@ def expand_args_fields(
|
|||||||
Similarly, replace,
|
Similarly, replace,
|
||||||
|
|
||||||
x: Optional[X]
|
x: Optional[X]
|
||||||
|
x_enabled: bool = ...
|
||||||
|
|
||||||
and optionally
|
and optionally
|
||||||
|
|
||||||
def create_x(self):...
|
def create_x(self):...
|
||||||
x_enabled: bool = ...
|
|
||||||
|
|
||||||
with
|
with
|
||||||
|
|
||||||
x_args: dict = dataclasses.field(default_factory=lambda: get_default_args(X))
|
x_args: dict = dataclasses.field(default_factory=lambda: get_default_args(X))
|
||||||
x_enabled: bool = False
|
x_enabled: bool = ...
|
||||||
def create_x(self):
|
def create_x(self):
|
||||||
self.create_x_impl(self.x_enabled, self.x_args)
|
self.create_x_impl(self.x_enabled, self.x_args)
|
||||||
|
|
||||||
@ -1091,8 +1091,10 @@ def _process_member(
|
|||||||
if process_type == _ProcessType.OPTIONAL_CONFIGURABLE:
|
if process_type == _ProcessType.OPTIONAL_CONFIGURABLE:
|
||||||
enabled_name = name + ENABLED_SUFFIX
|
enabled_name = name + ENABLED_SUFFIX
|
||||||
if enabled_name not in some_class.__annotations__:
|
if enabled_name not in some_class.__annotations__:
|
||||||
some_class.__annotations__[enabled_name] = bool
|
raise ValueError(
|
||||||
setattr(some_class, enabled_name, False)
|
f"{name} is an Optional[{type_.__name__}] member "
|
||||||
|
f"but there is no corresponding member {enabled_name}."
|
||||||
|
)
|
||||||
|
|
||||||
creation_function_name = f"{CREATE_PREFIX}{name}"
|
creation_function_name = f"{CREATE_PREFIX}{name}"
|
||||||
if not hasattr(some_class, creation_function_name):
|
if not hasattr(some_class, creation_function_name):
|
||||||
|
@ -446,6 +446,7 @@ class TestConfig(unittest.TestCase):
|
|||||||
b2: Optional[B]
|
b2: Optional[B]
|
||||||
b3: Optional[B]
|
b3: Optional[B]
|
||||||
b2_enabled: bool = True
|
b2_enabled: bool = True
|
||||||
|
b3_enabled: bool = False
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
run_auto_creation(self)
|
run_auto_creation(self)
|
||||||
@ -681,9 +682,10 @@ class TestConfig(unittest.TestCase):
|
|||||||
def test_remove_unused_components_optional(self):
|
def test_remove_unused_components_optional(self):
|
||||||
class MainTestWrapper(Configurable):
|
class MainTestWrapper(Configurable):
|
||||||
mt: Optional[MainTest]
|
mt: Optional[MainTest]
|
||||||
|
mt_enabled: bool = False
|
||||||
|
|
||||||
args = get_default_args(MainTestWrapper)
|
args = get_default_args(MainTestWrapper)
|
||||||
self.assertEqual(list(args.keys()), ["mt_args", "mt_enabled"])
|
self.assertEqual(list(args.keys()), ["mt_enabled", "mt_args"])
|
||||||
remove_unused_components(args)
|
remove_unused_components(args)
|
||||||
self.assertEqual(OmegaConf.to_yaml(args), "mt_enabled: false\n")
|
self.assertEqual(OmegaConf.to_yaml(args), "mt_enabled: false\n")
|
||||||
|
|
||||||
@ -775,6 +777,7 @@ class TestConfig(unittest.TestCase):
|
|||||||
boring_o: Optional[BoringConfigurable]
|
boring_o: Optional[BoringConfigurable]
|
||||||
boring_o_enabled: bool = True
|
boring_o_enabled: bool = True
|
||||||
boring_0: Optional[BoringConfigurable]
|
boring_0: Optional[BoringConfigurable]
|
||||||
|
boring_0_enabled: bool = False
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
run_auto_creation(self)
|
run_auto_creation(self)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user