[parser] support omegaconf (#7793)

hoshi-hiyouga 2025-04-21 23:30:30 +08:00 committed by GitHub
parent 81768df04c
commit 278df4308d
25 changed files with 62 additions and 94 deletions

View File

@@ -726,7 +726,7 @@ docker exec -it llamafactory bash
### Deploy with OpenAI-style API and vLLM
```bash
-API_PORT=8000 llamafactory-cli api examples/inference/llama3_vllm.yaml
+API_PORT=8000 llamafactory-cli api examples/inference/llama3.yaml infer_backend=vllm vllm_enforce_eager=true
```
> [!TIP]

View File

@@ -730,7 +730,7 @@ docker exec -it llamafactory bash
### Deploy OpenAI API with vLLM
```bash
-API_PORT=8000 llamafactory-cli api examples/inference/llama3_vllm.yaml
+API_PORT=8000 llamafactory-cli api examples/inference/llama3.yaml infer_backend=vllm vllm_enforce_eager=true
```
> [!TIP]

View File

@@ -15,6 +15,18 @@ Use `CUDA_VISIBLE_DEVICES` (GPU) or `ASCEND_RT_VISIBLE_DEVICES` (NPU) to choose
By default, LLaMA-Factory uses all visible computing devices.
+Basic usage:
+```bash
+llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
+```
+Advanced usage:
+```bash
+CUDA_VISIBLE_DEVICES=0,1 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml learning_rate=1e-5 logging_steps=1
+```
## Examples
### LoRA Fine-Tuning
@@ -34,7 +46,6 @@ llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
#### Multimodal Supervised Fine-Tuning
```bash
-llamafactory-cli train examples/train_lora/llava1_5_lora_sft.yaml
llamafactory-cli train examples/train_lora/qwen2vl_lora_sft.yaml
```
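The `key=value` pairs documented above are parsed by OmegaConf and merged over the values loaded from the YAML file, with the command-line values taking precedence. A minimal sketch of that precedence, using made-up keys and values for illustration:

```python
# Minimal sketch: CLI `key=value` overrides win over values from the YAML file.
from omegaconf import OmegaConf

file_config = {"learning_rate": 1.0e-4, "logging_steps": 10}  # as if loaded from the YAML file
cli_overrides = OmegaConf.from_cli(["learning_rate=1.0e-5", "logging_steps=1"])
merged = OmegaConf.to_container(OmegaConf.merge(file_config, cli_overrides))
print(merged)  # {'learning_rate': 1e-05, 'logging_steps': 1}
```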

View File

@@ -15,6 +15,18 @@
By default, LLaMA-Factory uses all visible computing devices.
+Basic usage:
+```bash
+llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
+```
+Advanced usage:
+```bash
+CUDA_VISIBLE_DEVICES=0,1 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml learning_rate=1e-5 logging_steps=1
+```
## Examples
### LoRA Fine-Tuning
@@ -34,7 +46,6 @@ llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
#### Multimodal Supervised Fine-Tuning
```bash
-llamafactory-cli train examples/train_lora/llava1_5_lora_sft.yaml
llamafactory-cli train examples/train_lora/qwen2vl_lora_sft.yaml
```

View File

@@ -1,4 +1,4 @@
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
template: llama3
-infer_backend: huggingface # choices: [huggingface, vllm]
+infer_backend: huggingface # choices: [huggingface, vllm, sglang]
trust_remote_code: true

View File

@@ -1,4 +1,4 @@
model_name_or_path: saves/llama3-8b/full/sft
template: llama3
-infer_backend: huggingface # choices: [huggingface, vllm]
+infer_backend: huggingface # choices: [huggingface, vllm, sglang]
trust_remote_code: true

View File

@@ -1,5 +1,5 @@
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
adapter_name_or_path: saves/llama3-8b/lora/sft
template: llama3
-infer_backend: huggingface # choices: [huggingface, vllm]
+infer_backend: huggingface # choices: [huggingface, vllm, sglang]
trust_remote_code: true

View File

@@ -1,4 +0,0 @@
-model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
-template: llama3
-infer_backend: sglang
-trust_remote_code: true

View File

@@ -1,5 +0,0 @@
-model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
-template: llama3
-infer_backend: vllm
-vllm_enforce_eager: true
-trust_remote_code: true

View File

@@ -1,4 +0,0 @@
-model_name_or_path: llava-hf/llava-1.5-7b-hf
-template: llava
-infer_backend: huggingface # choices: [huggingface, vllm]
-trust_remote_code: true

View File

@@ -1,4 +1,4 @@
-model_name_or_path: Qwen/Qwen2-VL-7B-Instruct
+model_name_or_path: Qwen/Qwen2.5-VL-7B-Instruct
template: qwen2_vl
-infer_backend: huggingface # choices: [huggingface, vllm]
+infer_backend: huggingface # choices: [huggingface, vllm, sglang]
trust_remote_code: true

View File

@@ -6,5 +6,5 @@ trust_remote_code: true
### export
export_dir: output/llama3_full_sft
export_size: 5
-export_device: cpu
+export_device: cpu # choices: [cpu, auto]
export_legacy_format: false

View File

@@ -6,7 +6,7 @@ trust_remote_code: true
### export
export_dir: output/llama3_gptq
export_quantization_bit: 4
-export_quantization_dataset: data/c4_demo.json
+export_quantization_dataset: data/c4_demo.jsonl
export_size: 5
-export_device: cpu
+export_device: cpu # choices: [cpu, auto]
export_legacy_format: false

View File

@@ -9,5 +9,5 @@ trust_remote_code: true
### export
export_dir: output/llama3_lora_sft
export_size: 5
-export_device: cpu
+export_device: cpu # choices: [cpu, auto]
export_legacy_format: false

View File

@@ -1,7 +1,7 @@
### Note: DO NOT use quantized model or quantization_bit when merging lora adapters
### model
-model_name_or_path: Qwen/Qwen2-VL-7B-Instruct
+model_name_or_path: Qwen/Qwen2.5-VL-7B-Instruct
adapter_name_or_path: saves/qwen2_vl-7b/lora/sft
template: qwen2_vl
trust_remote_code: true
@@ -9,5 +9,5 @@ trust_remote_code: true
### export
export_dir: output/qwen2_vl_lora_sft
export_size: 5
-export_device: cpu
+export_device: cpu # choices: [cpu, auto]
export_legacy_format: false

View File

@@ -1,45 +0,0 @@
-### model
-model_name_or_path: llava-hf/llava-1.5-7b-hf
-trust_remote_code: true
-### method
-stage: sft
-do_train: true
-finetuning_type: lora
-lora_rank: 8
-lora_target: all
-### dataset
-dataset: mllm_demo
-template: llava
-cutoff_len: 2048
-max_samples: 1000
-overwrite_cache: true
-preprocessing_num_workers: 16
-dataloader_num_workers: 4
-### output
-output_dir: saves/llava1_5-7b/lora/sft
-logging_steps: 10
-save_steps: 500
-plot_loss: true
-overwrite_output_dir: true
-save_only_model: false
-report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
-### train
-per_device_train_batch_size: 1
-gradient_accumulation_steps: 8
-learning_rate: 1.0e-4
-num_train_epochs: 3.0
-lr_scheduler_type: cosine
-warmup_ratio: 0.1
-bf16: true
-ddp_timeout: 180000000
-resume_from_checkpoint: null
-### eval
-# val_size: 0.1
-# per_device_eval_batch_size: 1
-# eval_strategy: steps
-# eval_steps: 500

View File

@@ -1,5 +1,5 @@
### model
-model_name_or_path: Qwen/Qwen2-VL-7B-Instruct
+model_name_or_path: Qwen/Qwen2.5-VL-7B-Instruct
image_max_pixels: 262144
video_max_pixels: 16384
trust_remote_code: true

View File

@@ -1,5 +1,5 @@
### model
-model_name_or_path: Qwen/Qwen2-VL-7B-Instruct
+model_name_or_path: Qwen/Qwen2.5-VL-7B-Instruct
image_max_pixels: 262144
video_max_pixels: 16384
trust_remote_code: true

View File

@@ -15,6 +15,7 @@ fastapi
sse-starlette
matplotlib>=3.7.0
fire
+omegaconf
packaging
pyyaml
numpy<2.0.0

View File

@@ -24,6 +24,7 @@ from typing import Any, Optional, Union
import torch
import transformers
import yaml
+from omegaconf import OmegaConf
from transformers import HfArgumentParser
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.trainer_utils import get_last_checkpoint
@@ -59,10 +60,14 @@ def read_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> Union[
    if args is not None:
        return args
-    if len(sys.argv) == 2 and (sys.argv[1].endswith(".yaml") or sys.argv[1].endswith(".yml")):
-        return yaml.safe_load(Path(sys.argv[1]).absolute().read_text())
-    elif len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
-        return json.loads(Path(sys.argv[1]).absolute().read_text())
+    if sys.argv[1].endswith(".yaml") or sys.argv[1].endswith(".yml"):
+        override_config = OmegaConf.from_cli(sys.argv[2:])
+        dict_config = yaml.safe_load(Path(sys.argv[1]).absolute().read_text())
+        return OmegaConf.to_container(OmegaConf.merge(dict_config, override_config))
+    elif sys.argv[1].endswith(".json"):
+        override_config = OmegaConf.from_cli(sys.argv[2:])
+        dict_config = json.loads(Path(sys.argv[1]).absolute().read_text())
+        return OmegaConf.to_container(OmegaConf.merge(dict_config, override_config))
    else:
        return sys.argv[1:]
@@ -330,12 +335,20 @@ def get_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _
        logger.warning_rank0("Specify `ref_model` for computing rewards at evaluation.")
    # Post-process training arguments
+    training_args.generation_max_length = training_args.generation_max_length or data_args.cutoff_len
+    training_args.generation_num_beams = data_args.eval_num_beams or training_args.generation_num_beams
+    training_args.remove_unused_columns = False # important for multimodal dataset
+    if finetuning_args.finetuning_type == "lora":
+        # https://github.com/huggingface/transformers/blob/v4.50.0/src/transformers/trainer.py#L782
+        training_args.label_names = training_args.label_names or ["labels"]
    if (
        training_args.parallel_mode == ParallelMode.DISTRIBUTED
        and training_args.ddp_find_unused_parameters is None
        and finetuning_args.finetuning_type == "lora"
    ):
-        logger.warning_rank0("`ddp_find_unused_parameters` needs to be set as False for LoRA in DDP training.")
+        logger.info_rank0("Set `ddp_find_unused_parameters` to False in DDP training since LoRA is enabled.")
        training_args.ddp_find_unused_parameters = False
    if finetuning_args.stage in ["rm", "ppo"] and finetuning_args.finetuning_type in ["full", "freeze"]:
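For reference, a self-contained rehearsal of the YAML branch that the new `read_args()` implements: load the config file named by the first argument, then merge any trailing `key=value` overrides on top of it. The config file, its keys, and the override value below are made up for illustration.

```python
# Hypothetical, self-contained sketch of the new read_args() YAML path.
import tempfile
from pathlib import Path

import yaml
from omegaconf import OmegaConf

# Write a throwaway config file standing in for a real training YAML.
with tempfile.NamedTemporaryFile("w", suffix=".yaml", delete=False) as f:
    yaml.safe_dump({"model_name_or_path": "meta-llama/Meta-Llama-3-8B-Instruct", "learning_rate": 1.0e-4}, f)
    config_path = f.name

# argv as it would look after `llamafactory-cli train <config.yaml> <overrides...>`
argv = [config_path, "learning_rate=1.0e-5"]

override_config = OmegaConf.from_cli(argv[1:])  # parse trailing key=value overrides
dict_config = yaml.safe_load(Path(argv[0]).absolute().read_text())  # load the YAML config
args = OmegaConf.to_container(OmegaConf.merge(dict_config, override_config))
print(args["learning_rate"])  # 1e-05 -- the CLI override wins over the file value
```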

View File

@@ -63,9 +63,6 @@ def run_dpo(
    else:
        ref_model = None
-    # Update arguments
-    training_args.remove_unused_columns = False # important for multimodal and pairwise dataset
    # Initialize our Trainer
    trainer = CustomDPOTrainer(
        model=model,

View File

@@ -59,9 +59,6 @@ def run_kto(
    else:
        ref_model = create_ref_model(model_args, finetuning_args)
-    # Update arguments
-    training_args.remove_unused_columns = False # important for multimodal and pairwise dataset
    # Initialize our Trainer
    trainer = CustomKTOTrainer(
        model=model,

View File

@@ -48,9 +48,6 @@ def run_rm(
        template=template, model=model, pad_to_multiple_of=8, **tokenizer_module
    )
-    # Update arguments
-    training_args.remove_unused_columns = False # important for multimodal and pairwise dataset
    # Initialize our Trainer
    trainer = PairwiseTrainer(
        model=model,

View File

@@ -65,11 +65,6 @@ def run_sft(
        **tokenizer_module,
    )
-    # Override the decoding parameters of Seq2SeqTrainer
-    training_args.generation_max_length = training_args.generation_max_length or data_args.cutoff_len
-    training_args.generation_num_beams = data_args.eval_num_beams or training_args.generation_num_beams
-    training_args.remove_unused_columns = False # important for multimodal dataset
    # Metric utils
    metric_module = {}
    if training_args.predict_with_generate:

View File

@@ -50,6 +50,10 @@ class DataCollatorWithVerbose(DataCollatorWithPadding):
    verbose_list: list[dict[str, Any]] = field(default_factory=list)
    def __call__(self, features: list[dict[str, Any]]) -> dict[str, Any]:
+        features = [
+            {k: v for k, v in feature.items() if k in ["input_ids", "attention_mask", "labels"]}
+            for feature in features
+        ]
        self.verbose_list.extend(features)
        batch = super().__call__(features)
        return {k: v[:, :1] for k, v in batch.items()} # truncate input length