From 317d0855d29485815dc60a6264b5bc4facfc39c7 Mon Sep 17 00:00:00 2001
From: hoshi-hiyouga
Date: Tue, 11 Mar 2025 01:15:35 +0800
Subject: [PATCH] [infer] fix vllm args (#7235)

Former-commit-id: ef7af457fc44b1e8cad0c78717848617f98364f0
---
 README.md                            | 21 +++++++++++----------
 README_zh.md                         | 21 +++++++++++----------
 scripts/vllm_infer.py                | 14 +++++++++-----
 src/llamafactory/chat/vllm_engine.py |  2 +-
 4 files changed, 32 insertions(+), 26 deletions(-)

diff --git a/README.md b/README.md
index a97d4077..93006e8b 100644
--- a/README.md
+++ b/README.md
@@ -5,7 +5,7 @@
 [![GitHub contributors](https://img.shields.io/github/contributors/hiyouga/LLaMA-Factory?color=orange)](https://github.com/hiyouga/LLaMA-Factory/graphs/contributors)
 [![GitHub workflow](https://github.com/hiyouga/LLaMA-Factory/actions/workflows/tests.yml/badge.svg)](https://github.com/hiyouga/LLaMA-Factory/actions/workflows/tests.yml)
 [![PyPI](https://img.shields.io/pypi/v/llamafactory)](https://pypi.org/project/llamafactory/)
-[![Citation](https://img.shields.io/badge/citation-341-green)](https://scholar.google.com/scholar?cites=12620864006390196564)
+[![Citation](https://img.shields.io/badge/citation-349-green)](https://scholar.google.com/scholar?cites=12620864006390196564)
 [![GitHub pull request](https://img.shields.io/badge/PRs-welcome-blue)](https://github.com/hiyouga/LLaMA-Factory/pulls)
 [![Twitter](https://img.shields.io/twitter/follow/llamafactory_ai)](https://twitter.com/llamafactory_ai)
 
@@ -412,15 +412,14 @@ huggingface-cli login
 
 \* *estimated*
 
-| Method | Bits | 7B | 13B | 30B | 70B | 110B | 8x7B | 8x22B |
-| ------------------------ | ---- | ----- | ----- | ----- | ------ | ------ | ----- | ------ |
-| Full | 32 | 120GB | 240GB | 600GB | 1200GB | 2000GB | 900GB | 2400GB |
-| Full | 16 | 60GB | 120GB | 300GB | 600GB | 900GB | 400GB | 1200GB |
-| Freeze | 16 | 20GB | 40GB | 80GB | 200GB | 360GB | 160GB | 400GB |
-| LoRA/GaLore/APOLLO/BAdam | 16 | 16GB | 32GB | 64GB | 160GB | 240GB | 120GB | 320GB |
-| QLoRA | 8 | 10GB | 20GB | 40GB | 80GB | 140GB | 60GB | 160GB |
-| QLoRA | 4 | 6GB | 12GB | 24GB | 48GB | 72GB | 30GB | 96GB |
-| QLoRA | 2 | 4GB | 8GB | 16GB | 24GB | 48GB | 18GB | 48GB |
+| Method | Bits | 7B | 14B | 30B | 70B | `x`B |
+| ------------------------------- | ---- | ----- | ----- | ----- | ------ | ------- |
+| Full (`bf16` or `fp16`) | 32 | 120GB | 240GB | 600GB | 1200GB | `18x`GB |
+| Full (`pure_bf16`) | 16 | 60GB | 120GB | 300GB | 600GB | `8x`GB |
+| Freeze/LoRA/GaLore/APOLLO/BAdam | 16 | 16GB | 32GB | 64GB | 160GB | `2x`GB |
+| QLoRA | 8 | 10GB | 20GB | 40GB | 80GB | `x`GB |
+| QLoRA | 4 | 6GB | 12GB | 24GB | 48GB | `x/2`GB |
+| QLoRA | 2 | 4GB | 8GB | 16GB | 24GB | `x/4`GB |
 
 ## Getting Started
 
@@ -560,6 +559,8 @@ See [examples/README.md](examples/README.md) for advanced usage (including distr
 
 > [!TIP]
 > Use `llamafactory-cli help` to show help information.
+>
+> Read [FAQs](https://github.com/hiyouga/LLaMA-Factory/issues/4614) first if you encounter any problems.
 
 ### Fine-Tuning with LLaMA Board GUI (powered by [Gradio](https://github.com/gradio-app/gradio))
 
diff --git a/README_zh.md b/README_zh.md
index e85071f6..92349517 100644
--- a/README_zh.md
+++ b/README_zh.md
@@ -5,7 +5,7 @@
 [![GitHub contributors](https://img.shields.io/github/contributors/hiyouga/LLaMA-Factory?color=orange)](https://github.com/hiyouga/LLaMA-Factory/graphs/contributors)
 [![GitHub workflow](https://github.com/hiyouga/LLaMA-Factory/actions/workflows/tests.yml/badge.svg)](https://github.com/hiyouga/LLaMA-Factory/actions/workflows/tests.yml)
 [![PyPI](https://img.shields.io/pypi/v/llamafactory)](https://pypi.org/project/llamafactory/)
-[![Citation](https://img.shields.io/badge/citation-341-green)](https://scholar.google.com/scholar?cites=12620864006390196564)
+[![Citation](https://img.shields.io/badge/citation-349-green)](https://scholar.google.com/scholar?cites=12620864006390196564)
 [![GitHub pull request](https://img.shields.io/badge/PRs-welcome-blue)](https://github.com/hiyouga/LLaMA-Factory/pulls)
 [![Twitter](https://img.shields.io/twitter/follow/llamafactory_ai)](https://twitter.com/llamafactory_ai)
 
@@ -414,15 +414,14 @@ huggingface-cli login
 
 \* *估算值*
 
-| 方法 | 精度 | 7B | 13B | 30B | 70B | 110B | 8x7B | 8x22B |
-| ------------------------ | ---- | ----- | ----- | ----- | ------ | ------ | ----- | ------ |
-| Full | 32 | 120GB | 240GB | 600GB | 1200GB | 2000GB | 900GB | 2400GB |
-| Full | 16 | 60GB | 120GB | 300GB | 600GB | 900GB | 400GB | 1200GB |
-| Freeze | 16 | 20GB | 40GB | 80GB | 200GB | 360GB | 160GB | 400GB |
-| LoRA/GaLore/APOLLO/BAdam | 16 | 16GB | 32GB | 64GB | 160GB | 240GB | 120GB | 320GB |
-| QLoRA | 8 | 10GB | 20GB | 40GB | 80GB | 140GB | 60GB | 160GB |
-| QLoRA | 4 | 6GB | 12GB | 24GB | 48GB | 72GB | 30GB | 96GB |
-| QLoRA | 2 | 4GB | 8GB | 16GB | 24GB | 48GB | 18GB | 48GB |
+| 方法 | 精度 | 7B | 14B | 30B | 70B | `x`B |
+| ------------------------------- | ---- | ----- | ----- | ----- | ------ | ------- |
+| Full (`bf16` or `fp16`) | 32 | 120GB | 240GB | 600GB | 1200GB | `18x`GB |
+| Full (`pure_bf16`) | 16 | 60GB | 120GB | 300GB | 600GB | `8x`GB |
+| Freeze/LoRA/GaLore/APOLLO/BAdam | 16 | 16GB | 32GB | 64GB | 160GB | `2x`GB |
+| QLoRA | 8 | 10GB | 20GB | 40GB | 80GB | `x`GB |
+| QLoRA | 4 | 6GB | 12GB | 24GB | 48GB | `x/2`GB |
+| QLoRA | 2 | 4GB | 8GB | 16GB | 24GB | `x/4`GB |
 
 ## 如何使用
 
@@ -563,6 +562,8 @@ llamafactory-cli export examples/merge_lora/llama3_lora_sft.yaml
 
 > [!TIP]
 > 使用 `llamafactory-cli help` 显示帮助信息。
+>
+> 遇到报错请先看[常见问题](https://github.com/hiyouga/LLaMA-Factory/issues/4614)。
 
 ### LLaMA Board 可视化微调(由 [Gradio](https://github.com/gradio-app/gradio) 驱动)
 
diff --git a/scripts/vllm_infer.py b/scripts/vllm_infer.py
index 6f47f173..02d20ee5 100644
--- a/scripts/vllm_infer.py
+++ b/scripts/vllm_infer.py
@@ -38,7 +38,7 @@ def vllm_infer(
     dataset_dir: str = "data",
     template: str = "default",
     cutoff_len: int = 2048,
-    max_samples: int = None,
+    max_samples: Optional[int] = None,
     vllm_config: str = "{}",
     save_name: str = "generated_predictions.jsonl",
     temperature: float = 0.95,
@@ -46,6 +46,7 @@
     top_k: int = 50,
     max_new_tokens: int = 1024,
     repetition_penalty: float = 1.0,
+    skip_special_tokens: bool = True,
     seed: Optional[int] = None,
     pipeline_parallel_size: int = 1,
     image_max_pixels: int = 768 * 768,
@@ -97,19 +98,21 @@
             multi_modal_data = None
 
         inputs.append({"prompt_token_ids": sample["input_ids"], "multi_modal_data": multi_modal_data})
-        prompts.append(tokenizer.decode(sample["input_ids"], skip_special_tokens=False))
+        prompts.append(tokenizer.decode(sample["input_ids"], skip_special_tokens=skip_special_tokens))
         labels.append(
-            tokenizer.decode(list(filter(lambda x: x != IGNORE_INDEX, sample["labels"])), skip_special_tokens=False)
+            tokenizer.decode(
+                list(filter(lambda x: x != IGNORE_INDEX, sample["labels"])), skip_special_tokens=skip_special_tokens
+            )
         )
 
     sampling_params = SamplingParams(
         repetition_penalty=generating_args.repetition_penalty or 1.0,  # repetition_penalty must > 0
         temperature=generating_args.temperature,
         top_p=generating_args.top_p or 1.0,  # top_p must > 0
-        top_k=generating_args.top_k,
+        top_k=generating_args.top_k or -1,  # top_k must > 0
         stop_token_ids=template_obj.get_stop_token_ids(tokenizer),
         max_tokens=generating_args.max_new_tokens,
-        skip_special_tokens=False,
+        skip_special_tokens=skip_special_tokens,
         seed=seed,
     )
     if model_args.adapter_name_or_path is not None:
@@ -121,6 +124,7 @@
         "model": model_args.model_name_or_path,
         "trust_remote_code": True,
         "dtype": model_args.infer_dtype,
+        "max_model_len": cutoff_len + max_new_tokens,
         "tensor_parallel_size": (get_device_count() // pipeline_parallel_size) or 1,
         "pipeline_parallel_size": pipeline_parallel_size,
         "disable_log_stats": True,
diff --git a/src/llamafactory/chat/vllm_engine.py b/src/llamafactory/chat/vllm_engine.py
index d5041261..0acfc370 100644
--- a/src/llamafactory/chat/vllm_engine.py
+++ b/src/llamafactory/chat/vllm_engine.py
@@ -170,7 +170,7 @@ class VllmEngine(BaseEngine):
             or 1.0,  # repetition_penalty must > 0
             temperature=temperature if temperature is not None else self.generating_args["temperature"],
             top_p=(top_p if top_p is not None else self.generating_args["top_p"]) or 1.0,  # top_p must > 0
-            top_k=top_k if top_k is not None else self.generating_args["top_k"],
+            top_k=(top_k if top_k is not None else self.generating_args["top_k"]) or -1,  # top_k must > 0
             stop=stop,
             stop_token_ids=self.template.get_stop_token_ids(self.tokenizer),
             max_tokens=max_tokens,
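
Background on the argument fix: vLLM's SamplingParams disables top-k sampling with top_k=-1 and rejects 0, while HF-style generating arguments use 0 for the same purpose, so the patch maps the disabled case with `or -1` (and guards top_p and repetition_penalty, which must be > 0). It also threads a skip_special_tokens flag through both decoding and sampling, and sizes max_model_len as cutoff_len + max_new_tokens so the context window covers the prompt plus the generation budget. The sketch below is illustrative only; build_sampling_params is a hypothetical helper, not the script's actual code path, and it assumes vllm is installed.

    from typing import Optional
    from vllm import SamplingParams

    def build_sampling_params(
        temperature: float = 0.95,
        top_p: float = 0.7,
        top_k: int = 0,
        repetition_penalty: float = 1.0,
        max_new_tokens: int = 1024,
        skip_special_tokens: bool = True,
        seed: Optional[int] = None,
    ) -> SamplingParams:
        # Map HF-style "disabled" values onto what vLLM accepts:
        # repetition_penalty and top_p must be > 0, and top_k is disabled with -1, not 0.
        return SamplingParams(
            repetition_penalty=repetition_penalty or 1.0,
            temperature=temperature,
            top_p=top_p or 1.0,
            top_k=top_k or -1,
            max_tokens=max_new_tokens,
            skip_special_tokens=skip_special_tokens,
            seed=seed,
        )

Since scripts/vllm_infer.py exposes its keyword arguments on the command line via fire, passing `--skip_special_tokens False` should reproduce the previous behavior of keeping special tokens in the decoded prompts and predictions.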