diff --git a/src/llamafactory/hparams/parser.py b/src/llamafactory/hparams/parser.py index 91cee729..54becc5b 100644 --- a/src/llamafactory/hparams/parser.py +++ b/src/llamafactory/hparams/parser.py @@ -62,11 +62,11 @@ def read_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> Union[ if sys.argv[1].endswith(".yaml") or sys.argv[1].endswith(".yml"): override_config = OmegaConf.from_cli(sys.argv[2:]) - dict_config = yaml.safe_load(Path(sys.argv[1]).absolute().read_text()) + dict_config = OmegaConf.load(Path(sys.argv[1]).absolute()) return OmegaConf.to_container(OmegaConf.merge(dict_config, override_config)) elif sys.argv[1].endswith(".json"): override_config = OmegaConf.from_cli(sys.argv[2:]) - dict_config = json.loads(Path(sys.argv[1]).absolute().read_text()) + dict_config = OmegaConf.load(Path(sys.argv[1]).absolute()) return OmegaConf.to_container(OmegaConf.merge(dict_config, override_config)) else: return sys.argv[1:]