mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-08-01 11:12:50 +08:00)

[parser] support omegaconf (#7793)

parent 81768df04c
commit 278df4308d
@@ -726,7 +726,7 @@ docker exec -it llamafactory bash
 ### Deploy with OpenAI-style API and vLLM
 
 ```bash
-API_PORT=8000 llamafactory-cli api examples/inference/llama3_vllm.yaml
+API_PORT=8000 llamafactory-cli api examples/inference/llama3.yaml infer_backend=vllm vllm_enforce_eager=true
 ```
 
 > [!TIP]
@@ -730,7 +730,7 @@ docker exec -it llamafactory bash
 ### Deploy OpenAI API with vLLM
 
 ```bash
-API_PORT=8000 llamafactory-cli api examples/inference/llama3_vllm.yaml
+API_PORT=8000 llamafactory-cli api examples/inference/llama3.yaml infer_backend=vllm vllm_enforce_eager=true
 ```
 
 > [!TIP]
@@ -15,6 +15,18 @@ Use `CUDA_VISIBLE_DEVICES` (GPU) or `ASCEND_RT_VISIBLE_DEVICES` (NPU) to choose
 
 By default, LLaMA-Factory uses all visible computing devices.
 
+Basic usage:
+
+```bash
+llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
+```
+
+Advanced usage:
+
+```bash
+CUDA_VISIBLE_DEVICES=0,1 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml learning_rate=1e-5 logging_steps=1
+```
+
 ## Examples
 
 ### LoRA Fine-Tuning
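The override behavior shown under "Advanced usage" is plain OmegaConf merging: trailing `key=value` tokens are parsed into a config and merged over the values loaded from the YAML file, so the command line wins on conflicts. A minimal sketch of that precedence, with illustrative values rather than the shipped defaults:

```python
from omegaconf import OmegaConf

# Illustrative subset of what the YAML file could contain.
file_config = {"learning_rate": 1.0e-4, "logging_steps": 10}

# Trailing `key=value` CLI tokens, parsed the way the CLI now parses them.
cli_config = OmegaConf.from_cli(["learning_rate=1e-5", "logging_steps=1"])

# In OmegaConf.merge, later configs override earlier ones.
merged = OmegaConf.merge(file_config, cli_config)
print(OmegaConf.to_container(merged))  # CLI wins: {'learning_rate': 1e-05, 'logging_steps': 1}
```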
@@ -34,7 +46,6 @@ llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
 #### Multimodal Supervised Fine-Tuning
 
 ```bash
-llamafactory-cli train examples/train_lora/llava1_5_lora_sft.yaml
 llamafactory-cli train examples/train_lora/qwen2vl_lora_sft.yaml
 ```
 
@@ -15,6 +15,18 @@
 
 By default, LLaMA-Factory uses all visible computing devices.
 
+Basic usage:
+
+```bash
+llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
+```
+
+Advanced usage:
+
+```bash
+CUDA_VISIBLE_DEVICES=0,1 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml learning_rate=1e-5 logging_steps=1
+```
+
 ## Examples
 
 ### LoRA Fine-Tuning
@@ -34,7 +46,6 @@ llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
 #### Multimodal Supervised Fine-Tuning
 
 ```bash
-llamafactory-cli train examples/train_lora/llava1_5_lora_sft.yaml
 llamafactory-cli train examples/train_lora/qwen2vl_lora_sft.yaml
 ```
 
@@ -1,4 +1,4 @@
 model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
 template: llama3
-infer_backend: huggingface # choices: [huggingface, vllm]
+infer_backend: huggingface # choices: [huggingface, vllm, sglang]
 trust_remote_code: true

@@ -1,4 +1,4 @@
 model_name_or_path: saves/llama3-8b/full/sft
 template: llama3
-infer_backend: huggingface # choices: [huggingface, vllm]
+infer_backend: huggingface # choices: [huggingface, vllm, sglang]
 trust_remote_code: true

@@ -1,5 +1,5 @@
 model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
 adapter_name_or_path: saves/llama3-8b/lora/sft
 template: llama3
-infer_backend: huggingface # choices: [huggingface, vllm]
+infer_backend: huggingface # choices: [huggingface, vllm, sglang]
 trust_remote_code: true

@@ -1,4 +0,0 @@
-model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
-template: llama3
-infer_backend: sglang
-trust_remote_code: true

@@ -1,5 +0,0 @@
-model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
-template: llama3
-infer_backend: vllm
-vllm_enforce_eager: true
-trust_remote_code: true

@@ -1,4 +0,0 @@
-model_name_or_path: llava-hf/llava-1.5-7b-hf
-template: llava
-infer_backend: huggingface # choices: [huggingface, vllm]
-trust_remote_code: true

@@ -1,4 +1,4 @@
-model_name_or_path: Qwen/Qwen2-VL-7B-Instruct
+model_name_or_path: Qwen/Qwen2.5-VL-7B-Instruct
 template: qwen2_vl
-infer_backend: huggingface # choices: [huggingface, vllm]
+infer_backend: huggingface # choices: [huggingface, vllm, sglang]
 trust_remote_code: true
@@ -6,5 +6,5 @@ trust_remote_code: true
 ### export
 export_dir: output/llama3_full_sft
 export_size: 5
-export_device: cpu
+export_device: cpu # choices: [cpu, auto]
 export_legacy_format: false

@@ -6,7 +6,7 @@ trust_remote_code: true
 ### export
 export_dir: output/llama3_gptq
 export_quantization_bit: 4
-export_quantization_dataset: data/c4_demo.json
+export_quantization_dataset: data/c4_demo.jsonl
 export_size: 5
-export_device: cpu
+export_device: cpu # choices: [cpu, auto]
 export_legacy_format: false

@@ -9,5 +9,5 @@ trust_remote_code: true
 ### export
 export_dir: output/llama3_lora_sft
 export_size: 5
-export_device: cpu
+export_device: cpu # choices: [cpu, auto]
 export_legacy_format: false

@@ -1,7 +1,7 @@
 ### Note: DO NOT use quantized model or quantization_bit when merging lora adapters
 
 ### model
-model_name_or_path: Qwen/Qwen2-VL-7B-Instruct
+model_name_or_path: Qwen/Qwen2.5-VL-7B-Instruct
 adapter_name_or_path: saves/qwen2_vl-7b/lora/sft
 template: qwen2_vl
 trust_remote_code: true

@@ -9,5 +9,5 @@ trust_remote_code: true
 ### export
 export_dir: output/qwen2_vl_lora_sft
 export_size: 5
-export_device: cpu
+export_device: cpu # choices: [cpu, auto]
 export_legacy_format: false
@@ -1,45 +0,0 @@
-### model
-model_name_or_path: llava-hf/llava-1.5-7b-hf
-trust_remote_code: true
-
-### method
-stage: sft
-do_train: true
-finetuning_type: lora
-lora_rank: 8
-lora_target: all
-
-### dataset
-dataset: mllm_demo
-template: llava
-cutoff_len: 2048
-max_samples: 1000
-overwrite_cache: true
-preprocessing_num_workers: 16
-dataloader_num_workers: 4
-
-### output
-output_dir: saves/llava1_5-7b/lora/sft
-logging_steps: 10
-save_steps: 500
-plot_loss: true
-overwrite_output_dir: true
-save_only_model: false
-report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
-
-### train
-per_device_train_batch_size: 1
-gradient_accumulation_steps: 8
-learning_rate: 1.0e-4
-num_train_epochs: 3.0
-lr_scheduler_type: cosine
-warmup_ratio: 0.1
-bf16: true
-ddp_timeout: 180000000
-resume_from_checkpoint: null
-
-### eval
-# val_size: 0.1
-# per_device_eval_batch_size: 1
-# eval_strategy: steps
-# eval_steps: 500
@@ -1,5 +1,5 @@
 ### model
-model_name_or_path: Qwen/Qwen2-VL-7B-Instruct
+model_name_or_path: Qwen/Qwen2.5-VL-7B-Instruct
 image_max_pixels: 262144
 video_max_pixels: 16384
 trust_remote_code: true

@@ -1,5 +1,5 @@
 ### model
-model_name_or_path: Qwen/Qwen2-VL-7B-Instruct
+model_name_or_path: Qwen/Qwen2.5-VL-7B-Instruct
 image_max_pixels: 262144
 video_max_pixels: 16384
 trust_remote_code: true
@@ -15,6 +15,7 @@ fastapi
 sse-starlette
 matplotlib>=3.7.0
 fire
+omegaconf
 packaging
 pyyaml
 numpy<2.0.0
@@ -24,6 +24,7 @@ from typing import Any, Optional, Union
 import torch
 import transformers
 import yaml
+from omegaconf import OmegaConf
 from transformers import HfArgumentParser
 from transformers.integrations import is_deepspeed_zero3_enabled
 from transformers.trainer_utils import get_last_checkpoint
@@ -59,10 +60,14 @@ def read_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> Union[
     if args is not None:
         return args
 
-    if len(sys.argv) == 2 and (sys.argv[1].endswith(".yaml") or sys.argv[1].endswith(".yml")):
-        return yaml.safe_load(Path(sys.argv[1]).absolute().read_text())
-    elif len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
-        return json.loads(Path(sys.argv[1]).absolute().read_text())
+    if sys.argv[1].endswith(".yaml") or sys.argv[1].endswith(".yml"):
+        override_config = OmegaConf.from_cli(sys.argv[2:])
+        dict_config = yaml.safe_load(Path(sys.argv[1]).absolute().read_text())
+        return OmegaConf.to_container(OmegaConf.merge(dict_config, override_config))
+    elif sys.argv[1].endswith(".json"):
+        override_config = OmegaConf.from_cli(sys.argv[2:])
+        dict_config = json.loads(Path(sys.argv[1]).absolute().read_text())
+        return OmegaConf.to_container(OmegaConf.merge(dict_config, override_config))
     else:
         return sys.argv[1:]
 
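The new code path above is compact enough to restate in isolation. A self-contained sketch of the behavior (the function name `read_config` is mine, and the framework imports are stripped):

```python
import json
from pathlib import Path
from typing import Any, Union

import yaml
from omegaconf import OmegaConf


def read_config(argv: list[str]) -> Union[dict[str, Any], list[str]]:
    """Load a .yaml/.yml/.json config from argv[0], then merge
    `key=value` overrides parsed from the remaining tokens."""
    if argv[0].endswith((".yaml", ".yml")):
        dict_config = yaml.safe_load(Path(argv[0]).absolute().read_text())
    elif argv[0].endswith(".json"):
        dict_config = json.loads(Path(argv[0]).absolute().read_text())
    else:
        return argv  # plain CLI arguments are left for HfArgumentParser

    override_config = OmegaConf.from_cli(argv[1:])  # e.g. ["infer_backend=vllm"]
    return OmegaConf.to_container(OmegaConf.merge(dict_config, override_config))


# e.g. read_config(["examples/inference/llama3.yaml", "infer_backend=vllm"])
```

Note the behavioral change: the old code only treated a config file specially when it was the sole argument (`len(sys.argv) == 2`); now any tokens after the file path become overrides.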
@@ -330,12 +335,20 @@ def get_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _
         logger.warning_rank0("Specify `ref_model` for computing rewards at evaluation.")
 
     # Post-process training arguments
+    training_args.generation_max_length = training_args.generation_max_length or data_args.cutoff_len
+    training_args.generation_num_beams = data_args.eval_num_beams or training_args.generation_num_beams
+    training_args.remove_unused_columns = False  # important for multimodal dataset
+
+    if finetuning_args.finetuning_type == "lora":
+        # https://github.com/huggingface/transformers/blob/v4.50.0/src/transformers/trainer.py#L782
+        training_args.label_names = training_args.label_names or ["labels"]
+
     if (
         training_args.parallel_mode == ParallelMode.DISTRIBUTED
         and training_args.ddp_find_unused_parameters is None
         and finetuning_args.finetuning_type == "lora"
     ):
-        logger.warning_rank0("`ddp_find_unused_parameters` needs to be set as False for LoRA in DDP training.")
+        logger.info_rank0("Set `ddp_find_unused_parameters` to False in DDP training since LoRA is enabled.")
         training_args.ddp_find_unused_parameters = False
 
     if finetuning_args.stage in ["rm", "ppo"] and finetuning_args.finetuning_type in ["full", "freeze"]:
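The post-processing block moved here relies on Python's `or` for defaulting, which treats any falsy value (None, 0, an empty list) as "unset". A toy illustration of the two directions used above, with hypothetical values:

```python
# training_args value first, data argument as the fallback.
generation_max_length = None           # user did not set it
cutoff_len = 2048                      # data argument
generation_max_length = generation_max_length or cutoff_len    # -> 2048

# data argument first: eval_num_beams overrides when it is set.
eval_num_beams = 4
generation_num_beams = 1
generation_num_beams = eval_num_beams or generation_num_beams  # -> 4

# default label names when LoRA is enabled (see the transformers link in the hunk above).
label_names = None
label_names = label_names or ["labels"]                        # -> ["labels"]
```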
@@ -63,9 +63,6 @@ def run_dpo(
     else:
         ref_model = None
 
-    # Update arguments
-    training_args.remove_unused_columns = False  # important for multimodal and pairwise dataset
-
     # Initialize our Trainer
     trainer = CustomDPOTrainer(
         model=model,
@@ -59,9 +59,6 @@ def run_kto(
     else:
         ref_model = create_ref_model(model_args, finetuning_args)
 
-    # Update arguments
-    training_args.remove_unused_columns = False  # important for multimodal and pairwise dataset
-
     # Initialize our Trainer
     trainer = CustomKTOTrainer(
         model=model,
@@ -48,9 +48,6 @@ def run_rm(
         template=template, model=model, pad_to_multiple_of=8, **tokenizer_module
     )
 
-    # Update arguments
-    training_args.remove_unused_columns = False  # important for multimodal and pairwise dataset
-
     # Initialize our Trainer
     trainer = PairwiseTrainer(
         model=model,
@@ -65,11 +65,6 @@ def run_sft(
         **tokenizer_module,
     )
 
-    # Override the decoding parameters of Seq2SeqTrainer
-    training_args.generation_max_length = training_args.generation_max_length or data_args.cutoff_len
-    training_args.generation_num_beams = data_args.eval_num_beams or training_args.generation_num_beams
-    training_args.remove_unused_columns = False  # important for multimodal dataset
-
     # Metric utils
     metric_module = {}
     if training_args.predict_with_generate:
@@ -50,6 +50,10 @@ class DataCollatorWithVerbose(DataCollatorWithPadding):
     verbose_list: list[dict[str, Any]] = field(default_factory=list)
 
     def __call__(self, features: list[dict[str, Any]]) -> dict[str, Any]:
+        features = [
+            {k: v for k, v in feature.items() if k in ["input_ids", "attention_mask", "labels"]}
+            for feature in features
+        ]
         self.verbose_list.extend(features)
         batch = super().__call__(features)
         return {k: v[:, :1] for k, v in batch.items()}  # truncate input length
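The inserted comprehension whitelists the keys the padding collator knows how to handle, dropping everything else from each feature before it is recorded and batched. The filter in isolation, on made-up feature dicts:

```python
# Made-up features; only the whitelisted keys survive.
features = [
    {"input_ids": [1, 2, 3], "attention_mask": [1, 1, 1], "labels": [2, 3, 4], "images": None},
]
keep = ["input_ids", "attention_mask", "labels"]
features = [{k: v for k, v in f.items() if k in keep} for f in features]
print(features)  # [{'input_ids': [1, 2, 3], 'attention_mask': [1, 1, 1], 'labels': [2, 3, 4]}]
```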