diff --git a/projects/implicitron_trainer/experiment.py b/projects/implicitron_trainer/experiment.py index 96de5f67..198eb77d 100755 --- a/projects/implicitron_trainer/experiment.py +++ b/projects/implicitron_trainer/experiment.py @@ -708,9 +708,8 @@ class ExperimentConfig(Configurable): expand_args_fields(ExperimentConfig) -if __name__ == "__main__": - cs = hydra.core.config_store.ConfigStore.instance() - cs.store(name="default_config", node=ExperimentConfig) +cs = hydra.core.config_store.ConfigStore.instance() +cs.store(name="default_config", node=ExperimentConfig) @hydra.main(config_path="./configs/", config_name="default_config") diff --git a/projects/implicitron_trainer/tests/test_experiment.py b/projects/implicitron_trainer/tests/test_experiment.py index beafe75b..ee41ce9f 100644 --- a/projects/implicitron_trainer/tests/test_experiment.py +++ b/projects/implicitron_trainer/tests/test_experiment.py @@ -10,6 +10,7 @@ from pathlib import Path import experiment import torch +from hydra import compose, initialize_config_dir from omegaconf import OmegaConf @@ -23,6 +24,7 @@ def interactive_testing_requested() -> bool: DATA_DIR = Path(__file__).resolve().parent +IMPLICITRON_CONFIGS_DIR = Path(__file__).resolve().parent.parent / "configs" DEBUG: bool = False # TODO: @@ -65,3 +67,20 @@ class TestExperiment(unittest.TestCase): if DEBUG: (DATA_DIR / "experiment.yaml").write_text(yaml) self.assertEqual(yaml, (DATA_DIR / "experiment.yaml").read_text()) + + def test_load_configs(self): + config_files = [] + + for pattern in ("repro_singleseq*.yaml", "repro_multiseq*.yaml"): + config_files.extend( + [ + f + for f in IMPLICITRON_CONFIGS_DIR.glob(pattern) + if not f.name.endswith("_base.yaml") + ] + ) + + for file in config_files: + with self.subTest(file.name): + with initialize_config_dir(config_dir=str(IMPLICITRON_CONFIGS_DIR)): + compose(file.name)