diff --git a/pytorch3d/implicitron/dataset/json_index_dataset_map_provider.py b/pytorch3d/implicitron/dataset/json_index_dataset_map_provider.py index d4f628c3..35ee7567 100644 --- a/pytorch3d/implicitron/dataset/json_index_dataset_map_provider.py +++ b/pytorch3d/implicitron/dataset/json_index_dataset_map_provider.py @@ -113,9 +113,8 @@ class JsonIndexDatasetMapProvider(DatasetMapProviderBase): # pyre-ignore [13] Called by get_default_args(JsonIndexDatasetMapProvider) to not expose certain fields of each dataset class. """ - with open_dict(args): - for key in _NEED_CONTROL: - del args[key] + for key in _NEED_CONTROL: + del args[key] def create_dataset(self): """ diff --git a/pytorch3d/implicitron/dataset/json_index_dataset_map_provider_v2.py b/pytorch3d/implicitron/dataset/json_index_dataset_map_provider_v2.py index 29027ec6..a5257218 100644 --- a/pytorch3d/implicitron/dataset/json_index_dataset_map_provider_v2.py +++ b/pytorch3d/implicitron/dataset/json_index_dataset_map_provider_v2.py @@ -300,9 +300,8 @@ class JsonIndexDatasetMapProviderV2(DatasetMapProviderBase): # pyre-ignore [13] Called by get_default_args(JsonIndexDatasetMapProviderV2) to not expose certain fields of each dataset class. """ - with open_dict(args): - for key in _NEED_CONTROL: - del args[key] + for key in _NEED_CONTROL: + del args[key] def create_dataset(self): # The dataset object is created inside `self.get_dataset_map` diff --git a/pytorch3d/implicitron/tools/config.py b/pytorch3d/implicitron/tools/config.py index 1605f8b4..9fb7f6c7 100644 --- a/pytorch3d/implicitron/tools/config.py +++ b/pytorch3d/implicitron/tools/config.py @@ -881,7 +881,8 @@ def get_default_args_field( def create(): args = get_default_args(C, _do_not_process=_do_not_process) if _hook is not None: - _hook(args) + with open_dict(args): + _hook(args) return args return dataclasses.field(default_factory=create) @@ -915,7 +916,8 @@ def _get_default_args_field_from_registry( C = registry.get(base_class_wanted=base_class_wanted, name=name) args = get_default_args(C, _do_not_process=_do_not_process) if _hook is not None: - _hook(args) + with open_dict(args): + _hook(args) return args return dataclasses.field(default_factory=create) diff --git a/tests/implicitron/test_config.py b/tests/implicitron/test_config.py index 590b9dea..56e7ecab 100644 --- a/tests/implicitron/test_config.py +++ b/tests/implicitron/test_config.py @@ -691,12 +691,17 @@ class TestConfig(unittest.TestCase): fruit2_class_type: str = "Pear" a: A a2: A + a3: A @classmethod def a_tweak_args(cls, type, args): assert type == A args.n = 993 + @classmethod + def a3_tweak_args(cls, type, args): + del args["n"] + @classmethod def fruit_tweak_args(cls, type, args): assert issubclass(type, Fruit) @@ -707,6 +712,7 @@ class TestConfig(unittest.TestCase): args = get_default_args(Wrapper) self.assertEqual(args.a_args.n, 993) self.assertEqual(args.a2_args.n, 9) + self.assertEqual(args.a3_args, {}) self.assertEqual(args.fruit_Pear_args.n_pips, 19) self.assertEqual(args.fruit2_Pear_args.n_pips, 13)