Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-08-02 03:32:50 +08:00)

[infer] fix vllm args (#7235)

Former-commit-id: ef7af457fc44b1e8cad0c78717848617f98364f0
Parent: 0a43bc1960
Commit: 317d0855d2
README.md (21 lines changed)
@@ -5,7 +5,7 @@

 [](https://github.com/hiyouga/LLaMA-Factory/graphs/contributors)
 [](https://github.com/hiyouga/LLaMA-Factory/actions/workflows/tests.yml)
 [](https://pypi.org/project/llamafactory/)
 [](https://scholar.google.com/scholar?cites=12620864006390196564)
 [](https://github.com/hiyouga/LLaMA-Factory/pulls)
 [](https://twitter.com/llamafactory_ai)
@@ -412,15 +412,14 @@ huggingface-cli login

 \* *estimated*

-| Method                   | Bits | 7B    | 13B   | 30B   | 70B    | 110B   | 8x7B  | 8x22B  |
-| ------------------------ | ---- | ----- | ----- | ----- | ------ | ------ | ----- | ------ |
-| Full                     | 32   | 120GB | 240GB | 600GB | 1200GB | 2000GB | 900GB | 2400GB |
-| Full                     | 16   | 60GB  | 120GB | 300GB | 600GB  | 900GB  | 400GB | 1200GB |
-| Freeze                   | 16   | 20GB  | 40GB  | 80GB  | 200GB  | 360GB  | 160GB | 400GB  |
-| LoRA/GaLore/APOLLO/BAdam | 16   | 16GB  | 32GB  | 64GB  | 160GB  | 240GB  | 120GB | 320GB  |
-| QLoRA                    | 8    | 10GB  | 20GB  | 40GB  | 80GB   | 140GB  | 60GB  | 160GB  |
-| QLoRA                    | 4    | 6GB   | 12GB  | 24GB  | 48GB   | 72GB   | 30GB  | 96GB   |
-| QLoRA                    | 2    | 4GB   | 8GB   | 16GB  | 24GB   | 48GB   | 18GB  | 48GB   |
+| Method                          | Bits | 7B    | 14B   | 30B   | 70B    | `x`B    |
+| ------------------------------- | ---- | ----- | ----- | ----- | ------ | ------- |
+| Full (`bf16` or `fp16`)         | 32   | 120GB | 240GB | 600GB | 1200GB | `18x`GB |
+| Full (`pure_bf16`)              | 16   | 60GB  | 120GB | 300GB | 600GB  | `8x`GB  |
+| Freeze/LoRA/GaLore/APOLLO/BAdam | 16   | 16GB  | 32GB  | 64GB  | 160GB  | `2x`GB  |
+| QLoRA                           | 8    | 10GB  | 20GB  | 40GB  | 80GB   | `x`GB   |
+| QLoRA                           | 4    | 6GB   | 12GB  | 24GB  | 48GB   | `x/2`GB |
+| QLoRA                           | 2    | 4GB   | 8GB   | 16GB  | 24GB   | `x/4`GB |

 ## Getting Started
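The rewritten table replaces the fixed per-model columns (110B, 8x7B, 8x22B) with a single rule of thumb keyed to the parameter count `x` in billions. Below is a minimal sketch of that heuristic, reading the multipliers straight off the new rows; the function and its names are illustrative, not part of the repo, and the per-size columns are separate estimates, so they do not always equal the multiplier exactly.

```python
# Rough GPU memory estimate in GB for fine-tuning an x-billion-parameter model,
# using the multipliers from the updated README table. Illustrative sketch only.
MEMORY_MULTIPLIER_GB = {
    ("full_amp", 32): 18.0,                       # Full (bf16 or fp16 mixed precision)
    ("full_pure_bf16", 16): 8.0,                  # Full (pure_bf16)
    ("freeze_lora_galore_apollo_badam", 16): 2.0, # Freeze/LoRA/GaLore/APOLLO/BAdam
    ("qlora", 8): 1.0,
    ("qlora", 4): 0.5,
    ("qlora", 2): 0.25,
}


def estimate_memory_gb(params_billion: float, method: str, bits: int) -> float:
    """Estimated GPU memory (GB) for a given method and precision, per the table above."""
    return MEMORY_MULTIPLIER_GB[(method, bits)] * params_billion


# Example: 4-bit QLoRA on a 70B model -> 35.0 GB by the rule of thumb;
# the table's 70B column lists a more conservative 48GB.
print(estimate_memory_gb(70, "qlora", 4))
```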
@@ -560,6 +559,8 @@ See [examples/README.md](examples/README.md) for advanced usage (including distributed training).

 > [!TIP]
 > Use `llamafactory-cli help` to show help information.
+>
+> Read [FAQs](https://github.com/hiyouga/LLaMA-Factory/issues/4614) first if you encounter any problems.

 ### Fine-Tuning with LLaMA Board GUI (powered by [Gradio](https://github.com/gradio-app/gradio))
README_zh.md (21 lines changed)
@@ -5,7 +5,7 @@

 [](https://github.com/hiyouga/LLaMA-Factory/graphs/contributors)
 [](https://github.com/hiyouga/LLaMA-Factory/actions/workflows/tests.yml)
 [](https://pypi.org/project/llamafactory/)
 [](https://scholar.google.com/scholar?cites=12620864006390196564)
 [](https://github.com/hiyouga/LLaMA-Factory/pulls)
 [](https://twitter.com/llamafactory_ai)
@@ -414,15 +414,14 @@ huggingface-cli login

 \* *estimated*

-| Method                   | Precision | 7B    | 13B   | 30B   | 70B    | 110B   | 8x7B  | 8x22B  |
-| ------------------------ | --------- | ----- | ----- | ----- | ------ | ------ | ----- | ------ |
-| Full                     | 32        | 120GB | 240GB | 600GB | 1200GB | 2000GB | 900GB | 2400GB |
-| Full                     | 16        | 60GB  | 120GB | 300GB | 600GB  | 900GB  | 400GB | 1200GB |
-| Freeze                   | 16        | 20GB  | 40GB  | 80GB  | 200GB  | 360GB  | 160GB | 400GB  |
-| LoRA/GaLore/APOLLO/BAdam | 16        | 16GB  | 32GB  | 64GB  | 160GB  | 240GB  | 120GB | 320GB  |
-| QLoRA                    | 8         | 10GB  | 20GB  | 40GB  | 80GB   | 140GB  | 60GB  | 160GB  |
-| QLoRA                    | 4         | 6GB   | 12GB  | 24GB  | 48GB   | 72GB   | 30GB  | 96GB   |
-| QLoRA                    | 2         | 4GB   | 8GB   | 16GB  | 24GB   | 48GB   | 18GB  | 48GB   |
+| Method                          | Precision | 7B    | 14B   | 30B   | 70B    | `x`B    |
+| ------------------------------- | --------- | ----- | ----- | ----- | ------ | ------- |
+| Full (`bf16` or `fp16`)         | 32        | 120GB | 240GB | 600GB | 1200GB | `18x`GB |
+| Full (`pure_bf16`)              | 16        | 60GB  | 120GB | 300GB | 600GB  | `8x`GB  |
+| Freeze/LoRA/GaLore/APOLLO/BAdam | 16        | 16GB  | 32GB  | 64GB  | 160GB  | `2x`GB  |
+| QLoRA                           | 8         | 10GB  | 20GB  | 40GB  | 80GB   | `x`GB   |
+| QLoRA                           | 4         | 6GB   | 12GB  | 24GB  | 48GB   | `x/2`GB |
+| QLoRA                           | 2         | 4GB   | 8GB   | 16GB  | 24GB   | `x/4`GB |

 ## Getting Started
@@ -563,6 +562,8 @@ llamafactory-cli export examples/merge_lora/llama3_lora_sft.yaml

 > [!TIP]
 > Use `llamafactory-cli help` to show help information.
+>
+> Read the [FAQs](https://github.com/hiyouga/LLaMA-Factory/issues/4614) first if you run into errors.

 ### Fine-Tuning with LLaMA Board GUI (powered by [Gradio](https://github.com/gradio-app/gradio))
@@ -38,7 +38,7 @@ def vllm_infer(
     dataset_dir: str = "data",
     template: str = "default",
     cutoff_len: int = 2048,
-    max_samples: int = None,
+    max_samples: Optional[int] = None,
     vllm_config: str = "{}",
     save_name: str = "generated_predictions.jsonl",
     temperature: float = 0.95,
@@ -46,6 +46,7 @@ def vllm_infer(
     top_k: int = 50,
     max_new_tokens: int = 1024,
     repetition_penalty: float = 1.0,
+    skip_special_tokens: bool = True,
     seed: Optional[int] = None,
     pipeline_parallel_size: int = 1,
     image_max_pixels: int = 768 * 768,
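Two signature fixes land here: `max_samples` gets a correct `Optional[int]` annotation (its default is `None`), and a new `skip_special_tokens` flag (default `True`) replaces the hard-coded `skip_special_tokens=False` seen in the next hunk when decoding prompts, labels, and generations. Below is a minimal sketch of what the flag toggles at the tokenizer level; the checkpoint name is only an example and not tied to this script.

```python
from transformers import AutoTokenizer

# Any chat tokenizer works here; the model name is an arbitrary, non-gated example.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")

# Append EOS to mimic a finished label sequence that contains a special token.
ids = tokenizer("Hello")["input_ids"] + [tokenizer.eos_token_id]

print(tokenizer.decode(ids, skip_special_tokens=False))  # keeps markers, e.g. "Hello<|im_end|>"
print(tokenizer.decode(ids, skip_special_tokens=True))   # strips them -> "Hello"
```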
@@ -97,19 +98,21 @@ def vllm_infer(
             multi_modal_data = None

         inputs.append({"prompt_token_ids": sample["input_ids"], "multi_modal_data": multi_modal_data})
-        prompts.append(tokenizer.decode(sample["input_ids"], skip_special_tokens=False))
+        prompts.append(tokenizer.decode(sample["input_ids"], skip_special_tokens=skip_special_tokens))
         labels.append(
-            tokenizer.decode(list(filter(lambda x: x != IGNORE_INDEX, sample["labels"])), skip_special_tokens=False)
+            tokenizer.decode(
+                list(filter(lambda x: x != IGNORE_INDEX, sample["labels"])), skip_special_tokens=skip_special_tokens
+            )
         )

     sampling_params = SamplingParams(
         repetition_penalty=generating_args.repetition_penalty or 1.0,  # repetition_penalty must > 0
         temperature=generating_args.temperature,
         top_p=generating_args.top_p or 1.0,  # top_p must > 0
-        top_k=generating_args.top_k,
+        top_k=generating_args.top_k or -1,  # top_k must > 0
         stop_token_ids=template_obj.get_stop_token_ids(tokenizer),
         max_tokens=generating_args.max_new_tokens,
-        skip_special_tokens=False,
+        skip_special_tokens=skip_special_tokens,
         seed=seed,
     )
     if model_args.adapter_name_or_path is not None:
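The `or -1` guard on `top_k` follows the same pattern already used for `repetition_penalty` and `top_p`: vLLM's `SamplingParams` rejects `top_k=0` and uses `-1` to mean "consider all tokens", so a falsy value coming from the generating args is coerced to `-1`. The same fix is applied in `VllmEngine` in the last hunk below. A minimal, vLLM-free sketch of the coercion follows; the helper name is made up for illustration.

```python
def normalize_top_k(top_k):
    # vLLM expects top_k to be a positive integer, or -1 to disable top-k sampling;
    # 0 (and None) would be rejected, so map falsy values to -1.
    return top_k or -1


assert normalize_top_k(0) == -1    # "disabled" instead of an invalid 0
assert normalize_top_k(None) == -1
assert normalize_top_k(50) == 50   # valid values pass through unchanged
```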
@@ -121,6 +124,7 @@ def vllm_infer(
         "model": model_args.model_name_or_path,
         "trust_remote_code": True,
         "dtype": model_args.infer_dtype,
+        "max_model_len": cutoff_len + max_new_tokens,
         "tensor_parallel_size": (get_device_count() // pipeline_parallel_size) or 1,
         "pipeline_parallel_size": pipeline_parallel_size,
         "disable_log_stats": True,
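The new `"max_model_len": cutoff_len + max_new_tokens` entry sizes the engine's context window to what the script can actually produce: prompts are truncated to `cutoff_len` and generation is capped at `max_new_tokens`, so their sum is the longest sequence vLLM ever needs to hold. A hedged sketch of assembling such an engine-args dict is shown below; the model identifier and values are placeholders, not this script's resolved arguments.

```python
# Sketch of the engine-args construction with the new max_model_len entry.
cutoff_len = 2048        # prompt budget (tokens)
max_new_tokens = 1024    # generation budget (tokens)

engine_args = {
    "model": "path/or/hub-id",                      # placeholder model identifier
    "trust_remote_code": True,
    "max_model_len": cutoff_len + max_new_tokens,   # 3072-token window covers prompt + output
    "disable_log_stats": True,
}
# The script then builds the engine roughly as: llm = LLM(**engine_args)
```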
@@ -170,7 +170,7 @@ class VllmEngine(BaseEngine):
             or 1.0,  # repetition_penalty must > 0
             temperature=temperature if temperature is not None else self.generating_args["temperature"],
             top_p=(top_p if top_p is not None else self.generating_args["top_p"]) or 1.0,  # top_p must > 0
-            top_k=top_k if top_k is not None else self.generating_args["top_k"],
+            top_k=(top_k if top_k is not None else self.generating_args["top_k"]) or -1,  # top_k must > 0
             stop=stop,
             stop_token_ids=self.template.get_stop_token_ids(self.tokenizer),
             max_tokens=max_tokens,