diff --git a/examples/extras/eaft/qwen25_05b_eaft_full.yaml b/examples/extras/eaft/qwen25_05b_eaft_full.yaml
index 904858f73..72c0b0184 100644
--- a/examples/extras/eaft/qwen25_05b_eaft_full.yaml
+++ b/examples/extras/eaft/qwen25_05b_eaft_full.yaml
@@ -36,5 +36,3 @@ lr_scheduler_type: cosine
 warmup_ratio: 0.1
 bf16: true
 ddp_timeout: 180000000
-
-
diff --git a/src/llamafactory/hparams/parser.py b/src/llamafactory/hparams/parser.py
index b1ffbb706..6e9541b84 100644
--- a/src/llamafactory/hparams/parser.py
+++ b/src/llamafactory/hparams/parser.py
@@ -15,6 +15,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import json
 import os
 import sys
 from pathlib import Path
@@ -70,13 +71,13 @@ def read_args(args: dict[str, Any] | list[str] | None = None) -> dict[str, Any]
     if args is not None:
         return args
 
-    if sys.argv[1].endswith(".yaml") or sys.argv[1].endswith(".yml"):
+    if len(sys.argv) > 1 and (sys.argv[1].endswith(".yaml") or sys.argv[1].endswith(".yml")):
         override_config = OmegaConf.from_cli(sys.argv[2:])
         dict_config = OmegaConf.load(Path(sys.argv[1]).absolute())
         return OmegaConf.to_container(OmegaConf.merge(dict_config, override_config))
-    elif sys.argv[1].endswith(".json"):
+    elif len(sys.argv) > 1 and sys.argv[1].endswith(".json"):
         override_config = OmegaConf.from_cli(sys.argv[2:])
-        dict_config = OmegaConf.load(Path(sys.argv[1]).absolute())
+        dict_config = OmegaConf.create(json.loads(Path(sys.argv[1]).absolute().read_text()))
         return OmegaConf.to_container(OmegaConf.merge(dict_config, override_config))
     else:
         return sys.argv[1:]
diff --git a/src/llamafactory/v1/config/arg_parser.py b/src/llamafactory/v1/config/arg_parser.py
index adec3e4bb..aee30efaf 100644
--- a/src/llamafactory/v1/config/arg_parser.py
+++ b/src/llamafactory/v1/config/arg_parser.py
@@ -30,21 +30,6 @@ from .training_args import TrainingArguments
 InputArgument = dict[str, Any] | list[str] | None
 
 
-def validate_args(
-    data_args: DataArguments,
-    model_args: ModelArguments,
-    training_args: TrainingArguments,
-    sample_args: SampleArguments,
-):
-    """Validate arguments."""
-    if (
-        model_args.quant_config is not None
-        and training_args.dist_config is not None
-        and training_args.dist_config.name == "deepspeed"
-    ):
-        raise ValueError("Quantization is not supported with deepspeed backend.")
-
-
 def get_args(args: InputArgument = None) -> tuple[DataArguments, ModelArguments, TrainingArguments, SampleArguments]:
     """Parse arguments from command line or config file."""
     parser = HfArgumentParser([DataArguments, ModelArguments, TrainingArguments, SampleArguments])
@@ -71,8 +56,6 @@ def get_args(args: InputArgument = None) -> tuple[DataArguments, ModelArguments,
         print(f"Got unknown args, potentially deprecated arguments: {unknown_args}")
         raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {unknown_args}")
 
-    validate_args(*parsed_args)
-
     return tuple(parsed_args)
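Not part of the patch: a minimal usage sketch of how the updated read_args is expected to behave, assuming llamafactory is importable; the "config.json" file and the override values below are hypothetical, chosen only to illustrate the new JSON branch and the len(sys.argv) > 1 guard.

# --- usage sketch (illustration only, not part of the diff above) ---
import json
import sys
from pathlib import Path

from llamafactory.hparams.parser import read_args

# A JSON config is now parsed with json and wrapped in an OmegaConf config,
# then merged with dotlist-style CLI overrides.
Path("config.json").write_text(json.dumps({"warmup_ratio": 0.1, "bf16": True}))
sys.argv = ["llamafactory-cli", "config.json", "warmup_ratio=0.2"]
print(read_args())  # expected: {'warmup_ratio': 0.2, 'bf16': True}

# With no positional argument, the new len(sys.argv) > 1 guard lets the
# function fall through to `return sys.argv[1:]` instead of raising IndexError.
sys.argv = ["llamafactory-cli"]
print(read_args())  # expected: []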