From 278df4308dbea047ab27a704103c2f9e4ea32c45 Mon Sep 17 00:00:00 2001
From: hoshi-hiyouga
Date: Mon, 21 Apr 2025 23:30:30 +0800
Subject: [PATCH] [parser] support omegaconf (#7793)

---
 README.md                                  |  2 +-
 README_zh.md                               |  2 +-
 examples/README.md                         | 13 ++++++-
 examples/README_zh.md                      | 13 ++++++-
 examples/inference/llama3.yaml             |  2 +-
 examples/inference/llama3_full_sft.yaml    |  2 +-
 examples/inference/llama3_lora_sft.yaml    |  2 +-
 examples/inference/llama3_sglang.yaml      |  4 --
 examples/inference/llama3_vllm.yaml        |  5 ---
 examples/inference/llava1_5.yaml           |  4 --
 examples/inference/qwen2_vl.yaml           |  4 +-
 examples/merge_lora/llama3_full_sft.yaml   |  2 +-
 examples/merge_lora/llama3_gptq.yaml       |  4 +-
 examples/merge_lora/llama3_lora_sft.yaml   |  2 +-
 examples/merge_lora/qwen2vl_lora_sft.yaml  |  4 +-
 examples/train_lora/llava1_5_lora_sft.yaml | 45 ----------------------
 examples/train_lora/qwen2vl_lora_dpo.yaml  |  2 +-
 examples/train_lora/qwen2vl_lora_sft.yaml  |  2 +-
 requirements.txt                           |  1 +
 src/llamafactory/hparams/parser.py         | 23 ++++++++---
 src/llamafactory/train/dpo/workflow.py     |  3 --
 src/llamafactory/train/kto/workflow.py     |  3 --
 src/llamafactory/train/rm/workflow.py      |  3 --
 src/llamafactory/train/sft/workflow.py     |  5 ---
 tests/train/test_sft_trainer.py            |  4 ++
 25 files changed, 62 insertions(+), 94 deletions(-)
 delete mode 100644 examples/inference/llama3_sglang.yaml
 delete mode 100644 examples/inference/llama3_vllm.yaml
 delete mode 100644 examples/inference/llava1_5.yaml
 delete mode 100644 examples/train_lora/llava1_5_lora_sft.yaml

diff --git a/README.md b/README.md
index 77c8838d..e2abe336 100644
--- a/README.md
+++ b/README.md
@@ -726,7 +726,7 @@ docker exec -it llamafactory bash
 ### Deploy with OpenAI-style API and vLLM
 
 ```bash
-API_PORT=8000 llamafactory-cli api examples/inference/llama3_vllm.yaml
+API_PORT=8000 llamafactory-cli api examples/inference/llama3.yaml infer_backend=vllm vllm_enforce_eager=true
 ```
 
 > [!TIP]
diff --git a/README_zh.md b/README_zh.md
index f1db3910..d052d1f6 100644
--- a/README_zh.md
+++ b/README_zh.md
@@ -730,7 +730,7 @@ docker exec -it llamafactory bash
 ### 利用 vLLM 部署 OpenAI API
 
 ```bash
-API_PORT=8000 llamafactory-cli api examples/inference/llama3_vllm.yaml
+API_PORT=8000 llamafactory-cli api examples/inference/llama3.yaml infer_backend=vllm vllm_enforce_eager=true
 ```
 
 > [!TIP]
diff --git a/examples/README.md b/examples/README.md
index 457ec87f..f421d47d 100644
--- a/examples/README.md
+++ b/examples/README.md
@@ -15,6 +15,18 @@ Use `CUDA_VISIBLE_DEVICES` (GPU) or `ASCEND_RT_VISIBLE_DEVICES` (NPU) to choose
 By default, LLaMA-Factory uses all visible computing devices.
 
+Basic usage:
+
+```bash
+llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
+```
+
+Advanced usage:
+
+```bash
+CUDA_VISIBLE_DEVICES=0,1 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml learning_rate=1e-5 logging_steps=1
+```
+
 ## Examples
 
 ### LoRA Fine-Tuning
 
@@ -34,7 +46,6 @@ llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
 #### Multimodal Supervised Fine-Tuning
 
 ```bash
-llamafactory-cli train examples/train_lora/llava1_5_lora_sft.yaml
 llamafactory-cli train examples/train_lora/qwen2vl_lora_sft.yaml
 ```
diff --git a/examples/README_zh.md b/examples/README_zh.md
index 4899e279..9ff0e994 100644
--- a/examples/README_zh.md
+++ b/examples/README_zh.md
@@ -15,6 +15,18 @@ LLaMA-Factory 默认使用所有可见的计算设备。
 
+基础用法:
+
+```bash
+llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
+```
+
+高级用法:
+
+```bash
+CUDA_VISIBLE_DEVICES=0,1 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml learning_rate=1e-5 logging_steps=1
+```
+
 ## 示例
 
 ### LoRA 微调
 
@@ -34,7 +46,6 @@ llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
 #### 多模态指令监督微调
 
 ```bash
-llamafactory-cli train examples/train_lora/llava1_5_lora_sft.yaml
 llamafactory-cli train examples/train_lora/qwen2vl_lora_sft.yaml
 ```
diff --git a/examples/inference/llama3.yaml b/examples/inference/llama3.yaml
index 2851e9a3..5d5381c8 100644
--- a/examples/inference/llama3.yaml
+++ b/examples/inference/llama3.yaml
@@ -1,4 +1,4 @@
 model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
 template: llama3
-infer_backend: huggingface # choices: [huggingface, vllm]
+infer_backend: huggingface # choices: [huggingface, vllm, sglang]
 trust_remote_code: true
diff --git a/examples/inference/llama3_full_sft.yaml b/examples/inference/llama3_full_sft.yaml
index d4555ca8..5d8acabe 100644
--- a/examples/inference/llama3_full_sft.yaml
+++ b/examples/inference/llama3_full_sft.yaml
@@ -1,4 +1,4 @@
 model_name_or_path: saves/llama3-8b/full/sft
 template: llama3
-infer_backend: huggingface # choices: [huggingface, vllm]
+infer_backend: huggingface # choices: [huggingface, vllm, sglang]
 trust_remote_code: true
diff --git a/examples/inference/llama3_lora_sft.yaml b/examples/inference/llama3_lora_sft.yaml
index 7796c526..0f5e9f84 100644
--- a/examples/inference/llama3_lora_sft.yaml
+++ b/examples/inference/llama3_lora_sft.yaml
@@ -1,5 +1,5 @@
 model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
 adapter_name_or_path: saves/llama3-8b/lora/sft
 template: llama3
-infer_backend: huggingface # choices: [huggingface, vllm]
+infer_backend: huggingface # choices: [huggingface, vllm, sglang]
 trust_remote_code: true
diff --git a/examples/inference/llama3_sglang.yaml b/examples/inference/llama3_sglang.yaml
deleted file mode 100644
index 82418981..00000000
--- a/examples/inference/llama3_sglang.yaml
+++ /dev/null
@@ -1,4 +0,0 @@
-model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
-template: llama3
-infer_backend: sglang
-trust_remote_code: true
diff --git a/examples/inference/llama3_vllm.yaml b/examples/inference/llama3_vllm.yaml
deleted file mode 100644
index 4379956c..00000000
--- a/examples/inference/llama3_vllm.yaml
+++ /dev/null
@@ -1,5 +0,0 @@
-model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
-template: llama3
-infer_backend: vllm
-vllm_enforce_eager: true
-trust_remote_code: true
diff --git a/examples/inference/llava1_5.yaml b/examples/inference/llava1_5.yaml
deleted file mode 100644
index 2e934ddc..00000000
--- a/examples/inference/llava1_5.yaml
+++ /dev/null
@@ -1,4 +0,0 @@
-model_name_or_path: llava-hf/llava-1.5-7b-hf
-template: llava
-infer_backend: huggingface # choices: [huggingface, vllm]
-trust_remote_code: true
diff --git a/examples/inference/qwen2_vl.yaml b/examples/inference/qwen2_vl.yaml
index b5eabc66..d8f88dc2 100644
--- a/examples/inference/qwen2_vl.yaml
+++ b/examples/inference/qwen2_vl.yaml
@@ -1,4 +1,4 @@
-model_name_or_path: Qwen/Qwen2-VL-7B-Instruct
+model_name_or_path: Qwen/Qwen2.5-VL-7B-Instruct
 template: qwen2_vl
-infer_backend: huggingface # choices: [huggingface, vllm]
+infer_backend: huggingface # choices: [huggingface, vllm, sglang]
 trust_remote_code: true
diff --git a/examples/merge_lora/llama3_full_sft.yaml b/examples/merge_lora/llama3_full_sft.yaml
index 4e329fad..dd695372 100644
--- a/examples/merge_lora/llama3_full_sft.yaml
+++ b/examples/merge_lora/llama3_full_sft.yaml
@@ -6,5 +6,5 @@ trust_remote_code: true
 ### export
 export_dir: output/llama3_full_sft
 export_size: 5
-export_device: cpu
+export_device: cpu # choices: [cpu, auto]
 export_legacy_format: false
diff --git a/examples/merge_lora/llama3_gptq.yaml b/examples/merge_lora/llama3_gptq.yaml
index 3a2d9095..2a3d2fd6 100644
--- a/examples/merge_lora/llama3_gptq.yaml
+++ b/examples/merge_lora/llama3_gptq.yaml
@@ -6,7 +6,7 @@ trust_remote_code: true
 ### export
 export_dir: output/llama3_gptq
 export_quantization_bit: 4
-export_quantization_dataset: data/c4_demo.json
+export_quantization_dataset: data/c4_demo.jsonl
 export_size: 5
-export_device: cpu
+export_device: cpu # choices: [cpu, auto]
 export_legacy_format: false
diff --git a/examples/merge_lora/llama3_lora_sft.yaml b/examples/merge_lora/llama3_lora_sft.yaml
index 97bb457b..2b011d8d 100644
--- a/examples/merge_lora/llama3_lora_sft.yaml
+++ b/examples/merge_lora/llama3_lora_sft.yaml
@@ -9,5 +9,5 @@ trust_remote_code: true
 ### export
 export_dir: output/llama3_lora_sft
 export_size: 5
-export_device: cpu
+export_device: cpu # choices: [cpu, auto]
 export_legacy_format: false
diff --git a/examples/merge_lora/qwen2vl_lora_sft.yaml b/examples/merge_lora/qwen2vl_lora_sft.yaml
index 103dbcd8..9b157b3c 100644
--- a/examples/merge_lora/qwen2vl_lora_sft.yaml
+++ b/examples/merge_lora/qwen2vl_lora_sft.yaml
@@ -1,7 +1,7 @@
 ### Note: DO NOT use quantized model or quantization_bit when merging lora adapters
 
 ### model
-model_name_or_path: Qwen/Qwen2-VL-7B-Instruct
+model_name_or_path: Qwen/Qwen2.5-VL-7B-Instruct
 adapter_name_or_path: saves/qwen2_vl-7b/lora/sft
 template: qwen2_vl
 trust_remote_code: true
@@ -9,5 +9,5 @@ trust_remote_code: true
 ### export
 export_dir: output/qwen2_vl_lora_sft
 export_size: 5
-export_device: cpu
+export_device: cpu # choices: [cpu, auto]
 export_legacy_format: false
diff --git a/examples/train_lora/llava1_5_lora_sft.yaml b/examples/train_lora/llava1_5_lora_sft.yaml
deleted file mode 100644
index 63cdcaea..00000000
--- a/examples/train_lora/llava1_5_lora_sft.yaml
+++ /dev/null
@@ -1,45 +0,0 @@
-### model
-model_name_or_path: llava-hf/llava-1.5-7b-hf
-trust_remote_code: true
-
-### method
-stage: sft
-do_train: true
-finetuning_type: lora
-lora_rank: 8
-lora_target: all
-
-### dataset
-dataset: mllm_demo
-template: llava
-cutoff_len: 2048
-max_samples: 1000
-overwrite_cache: true
-preprocessing_num_workers: 16
-dataloader_num_workers: 4
-
-### output
-output_dir: saves/llava1_5-7b/lora/sft
-logging_steps: 10
-save_steps: 500
-plot_loss: true
-overwrite_output_dir: true
-save_only_model: false
-report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
-
-### train
-per_device_train_batch_size: 1
-gradient_accumulation_steps: 8
-learning_rate: 1.0e-4
-num_train_epochs: 3.0
-lr_scheduler_type: cosine
-warmup_ratio: 0.1
-bf16: true
-ddp_timeout: 180000000
-resume_from_checkpoint: null
-
-### eval
-# val_size: 0.1
-# per_device_eval_batch_size: 1
-# eval_strategy: steps
-# eval_steps: 500
diff --git a/examples/train_lora/qwen2vl_lora_dpo.yaml b/examples/train_lora/qwen2vl_lora_dpo.yaml
index 3c990b42..44ec2d8a 100644
--- a/examples/train_lora/qwen2vl_lora_dpo.yaml
+++ b/examples/train_lora/qwen2vl_lora_dpo.yaml
@@ -1,5 +1,5 @@
 ### model
-model_name_or_path: Qwen/Qwen2-VL-7B-Instruct
+model_name_or_path: Qwen/Qwen2.5-VL-7B-Instruct
 image_max_pixels: 262144
 video_max_pixels: 16384
 trust_remote_code: true
diff --git a/examples/train_lora/qwen2vl_lora_sft.yaml b/examples/train_lora/qwen2vl_lora_sft.yaml
index 54ff9842..5951546c 100644
--- a/examples/train_lora/qwen2vl_lora_sft.yaml
+++ b/examples/train_lora/qwen2vl_lora_sft.yaml
@@ -1,5 +1,5 @@
 ### model
-model_name_or_path: Qwen/Qwen2-VL-7B-Instruct
+model_name_or_path: Qwen/Qwen2.5-VL-7B-Instruct
 image_max_pixels: 262144
 video_max_pixels: 16384
 trust_remote_code: true
diff --git a/requirements.txt b/requirements.txt
index c818bb28..c56f9ac5 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -15,6 +15,7 @@ fastapi
 sse-starlette
 matplotlib>=3.7.0
 fire
+omegaconf
 packaging
 pyyaml
 numpy<2.0.0
diff --git a/src/llamafactory/hparams/parser.py b/src/llamafactory/hparams/parser.py
index bc200f60..cfe71498 100644
--- a/src/llamafactory/hparams/parser.py
+++ b/src/llamafactory/hparams/parser.py
@@ -24,6 +24,7 @@ from typing import Any, Optional, Union
 import torch
 import transformers
 import yaml
+from omegaconf import OmegaConf
 from transformers import HfArgumentParser
 from transformers.integrations import is_deepspeed_zero3_enabled
 from transformers.trainer_utils import get_last_checkpoint
@@ -59,10 +60,14 @@ def read_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> Union[
     if args is not None:
         return args
 
-    if len(sys.argv) == 2 and (sys.argv[1].endswith(".yaml") or sys.argv[1].endswith(".yml")):
-        return yaml.safe_load(Path(sys.argv[1]).absolute().read_text())
-    elif len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
-        return json.loads(Path(sys.argv[1]).absolute().read_text())
+    if sys.argv[1].endswith(".yaml") or sys.argv[1].endswith(".yml"):
+        override_config = OmegaConf.from_cli(sys.argv[2:])
+        dict_config = yaml.safe_load(Path(sys.argv[1]).absolute().read_text())
+        return OmegaConf.to_container(OmegaConf.merge(dict_config, override_config))
+    elif sys.argv[1].endswith(".json"):
+        override_config = OmegaConf.from_cli(sys.argv[2:])
+        dict_config = json.loads(Path(sys.argv[1]).absolute().read_text())
+        return OmegaConf.to_container(OmegaConf.merge(dict_config, override_config))
     else:
         return sys.argv[1:]
 
@@ -330,12 +335,20 @@ def get_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _
         logger.warning_rank0("Specify `ref_model` for computing rewards at evaluation.")
 
     # Post-process training arguments
+    training_args.generation_max_length = training_args.generation_max_length or data_args.cutoff_len
+    training_args.generation_num_beams = data_args.eval_num_beams or training_args.generation_num_beams
+    training_args.remove_unused_columns = False  # important for multimodal dataset
+
+    if finetuning_args.finetuning_type == "lora":
+        # https://github.com/huggingface/transformers/blob/v4.50.0/src/transformers/trainer.py#L782
+        training_args.label_names = training_args.label_names or ["labels"]
+
     if (
         training_args.parallel_mode == ParallelMode.DISTRIBUTED
         and training_args.ddp_find_unused_parameters is None
         and finetuning_args.finetuning_type == "lora"
     ):
-        logger.warning_rank0("`ddp_find_unused_parameters` needs to be set as False for LoRA in DDP training.")
+        logger.info_rank0("Set `ddp_find_unused_parameters` to False in DDP training since LoRA is enabled.")
         training_args.ddp_find_unused_parameters = False
 
     if finetuning_args.stage in ["rm", "ppo"] and finetuning_args.finetuning_type in ["full", "freeze"]:
diff --git a/src/llamafactory/train/dpo/workflow.py b/src/llamafactory/train/dpo/workflow.py
index 97262ad5..c0a107d2 100644
--- a/src/llamafactory/train/dpo/workflow.py
+++ b/src/llamafactory/train/dpo/workflow.py
@@ -63,9 +63,6 @@ def run_dpo(
     else:
         ref_model = None
 
-    # Update arguments
-    training_args.remove_unused_columns = False  # important for multimodal and pairwise dataset
-
     # Initialize our Trainer
     trainer = CustomDPOTrainer(
         model=model,
diff --git a/src/llamafactory/train/kto/workflow.py b/src/llamafactory/train/kto/workflow.py
index 7b16d1d0..df0794e3 100644
--- a/src/llamafactory/train/kto/workflow.py
+++ b/src/llamafactory/train/kto/workflow.py
@@ -59,9 +59,6 @@ def run_kto(
     else:
         ref_model = create_ref_model(model_args, finetuning_args)
 
-    # Update arguments
-    training_args.remove_unused_columns = False  # important for multimodal and pairwise dataset
-
     # Initialize our Trainer
     trainer = CustomKTOTrainer(
         model=model,
diff --git a/src/llamafactory/train/rm/workflow.py b/src/llamafactory/train/rm/workflow.py
index 18d562e8..89b2c95c 100644
--- a/src/llamafactory/train/rm/workflow.py
+++ b/src/llamafactory/train/rm/workflow.py
@@ -48,9 +48,6 @@ def run_rm(
         template=template, model=model, pad_to_multiple_of=8, **tokenizer_module
     )
 
-    # Update arguments
-    training_args.remove_unused_columns = False  # important for multimodal and pairwise dataset
-
     # Initialize our Trainer
     trainer = PairwiseTrainer(
         model=model,
diff --git a/src/llamafactory/train/sft/workflow.py b/src/llamafactory/train/sft/workflow.py
index 9e6c549e..aab1cb13 100644
--- a/src/llamafactory/train/sft/workflow.py
+++ b/src/llamafactory/train/sft/workflow.py
@@ -65,11 +65,6 @@ def run_sft(
         **tokenizer_module,
     )
 
-    # Override the decoding parameters of Seq2SeqTrainer
-    training_args.generation_max_length = training_args.generation_max_length or data_args.cutoff_len
-    training_args.generation_num_beams = data_args.eval_num_beams or training_args.generation_num_beams
-    training_args.remove_unused_columns = False  # important for multimodal dataset
-
     # Metric utils
     metric_module = {}
     if training_args.predict_with_generate:
diff --git a/tests/train/test_sft_trainer.py b/tests/train/test_sft_trainer.py
index 66a33af8..9f6ebe41 100644
--- a/tests/train/test_sft_trainer.py
+++ b/tests/train/test_sft_trainer.py
@@ -50,6 +50,10 @@ class DataCollatorWithVerbose(DataCollatorWithPadding):
     verbose_list: list[dict[str, Any]] = field(default_factory=list)
 
     def __call__(self, features: list[dict[str, Any]]) -> dict[str, Any]:
+        features = [
+            {k: v for k, v in feature.items() if k in ["input_ids", "attention_mask", "labels"]}
+            for feature in features
+        ]
        self.verbose_list.extend(features)
        batch = super().__call__(features)
        return {k: v[:, :1] for k, v in batch.items()}  # truncate input length
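
Note on the override mechanism this patch adds to `read_args`: `key=value` tokens after the config path are parsed by OmegaConf and merged over the YAML/JSON file, so CLI values win on conflicts. Below is a minimal standalone sketch of those merge semantics; the config keys and values are illustrative, not part of this patch.

```python
# Sketch of the config-plus-CLI merge used by the new read_args().
# Assumes: pip install omegaconf pyyaml. Keys/values are examples only.
import yaml
from omegaconf import OmegaConf

base_yaml = """\
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
template: llama3
infer_backend: huggingface
trust_remote_code: true
"""

# Equivalent of: llamafactory-cli api llama3.yaml infer_backend=vllm vllm_enforce_eager=true
cli_overrides = ["infer_backend=vllm", "vllm_enforce_eager=true"]

dict_config = yaml.safe_load(base_yaml)               # plain dict from the config file
override_config = OmegaConf.from_cli(cli_overrides)   # DictConfig parsed from CLI tokens
merged = OmegaConf.to_container(OmegaConf.merge(dict_config, override_config))

print(merged["infer_backend"])       # 'vllm' -- the CLI value wins over the file
print(merged["vllm_enforce_eager"])  # True   -- new key added, typed as a boolean
```

`OmegaConf.from_cli` treats the tokens as a dotlist, so values get typed (e.g. `true` becomes a boolean) and nested keys can be addressed with dotted paths such as `a.b=1`. This is what makes the deleted `llama3_vllm.yaml` and `llama3_sglang.yaml` variants redundant: the shared `llama3.yaml` plus a backend override expresses the same configs.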