From 235cdcaceed577411172eb17c0d37cb542a3e35b Mon Sep 17 00:00:00 2001 From: hiyouga Date: Wed, 4 Dec 2024 13:50:00 +0000 Subject: [PATCH] support batch infer in vllm Former-commit-id: 1324d158f954d777f1fbf09f46149c372704b388 --- .gitignore | 1 + README.md | 3 +- README_zh.md | 2 +- examples/README.md | 30 +-- examples/README_zh.md | 30 +-- .../nlg_eval}/llama3_lora_predict.yaml | 3 + examples/inference/llama3.yaml | 1 + examples/inference/llama3_lora_sft.yaml | 1 + examples/inference/llava1_5.yaml | 1 + examples/inference/qwen2_vl.yaml | 1 + examples/train_full/llama3_full_predict.yaml | 23 -- scripts/{ => api_example}/test_image.py | 0 scripts/{ => api_example}/test_toolcall.py | 0 scripts/async_call_api.py | 223 ------------------ .../{ => convert_ckpt}/llamafy_baichuan2.py | 0 scripts/{ => convert_ckpt}/llamafy_qwen.py | 0 scripts/{ => stat_utils}/cal_flops.py | 0 scripts/{ => stat_utils}/cal_lr.py | 0 scripts/{ => stat_utils}/cal_mfu.py | 0 scripts/{ => stat_utils}/cal_ppl.py | 0 scripts/{ => stat_utils}/length_cdf.py | 0 scripts/vllm_infer.py | 161 ++++++------- setup.py | 2 +- src/llamafactory/extras/misc.py | 25 +- src/llamafactory/hparams/parser.py | 2 +- src/llamafactory/train/dpo/workflow.py | 15 +- src/llamafactory/train/sft/trainer.py | 11 +- src/llamafactory/train/sft/workflow.py | 18 +- src/llamafactory/webui/interface.py | 2 +- 29 files changed, 148 insertions(+), 407 deletions(-) rename examples/{train_lora => extras/nlg_eval}/llama3_lora_predict.yaml (80%) delete mode 100644 examples/train_full/llama3_full_predict.yaml rename scripts/{ => api_example}/test_image.py (100%) rename scripts/{ => api_example}/test_toolcall.py (100%) delete mode 100644 scripts/async_call_api.py rename scripts/{ => convert_ckpt}/llamafy_baichuan2.py (100%) rename scripts/{ => convert_ckpt}/llamafy_qwen.py (100%) rename scripts/{ => stat_utils}/cal_flops.py (100%) rename scripts/{ => stat_utils}/cal_lr.py (100%) rename scripts/{ => stat_utils}/cal_mfu.py (100%) rename scripts/{ => stat_utils}/cal_ppl.py (100%) rename scripts/{ => stat_utils}/length_cdf.py (100%) diff --git a/.gitignore b/.gitignore index 630760ed..88c36ca2 100644 --- a/.gitignore +++ b/.gitignore @@ -171,3 +171,4 @@ config/ saves/ output/ wandb/ +generated_predictions.jsonl diff --git a/README.md b/README.md index 91a55e2d..d0a94dc6 100644 --- a/README.md +++ b/README.md @@ -594,7 +594,7 @@ API_PORT=8000 llamafactory-cli api examples/inference/llama3_vllm.yaml > [!TIP] > Visit [this page](https://platform.openai.com/docs/api-reference/chat/create) for API document. > -> Examples: [Image understanding](scripts/test_image.py) | [Function calling](scripts/test_toolcall.py) +> Examples: [Image understanding](scripts/api_example/test_image.py) | [Function calling](scripts/api_example/test_toolcall.py) ### Download from ModelScope Hub @@ -727,7 +727,6 @@ If you have a project that should be incorporated, please contact via email or c 1. **[LazyLLM](https://github.com/LazyAGI/LazyLLM)**: An easy and lazy way for building multi-agent LLMs applications and supports model fine-tuning via LLaMA Factory. 1. **[RAG-Retrieval](https://github.com/NLPJCL/RAG-Retrieval)**: A full pipeline for RAG retrieval model fine-tuning, inference, and distillation. [[blog]](https://zhuanlan.zhihu.com/p/987727357) - ## License diff --git a/README_zh.md b/README_zh.md index 88ee7d9e..7e5b914b 100644 --- a/README_zh.md +++ b/README_zh.md @@ -594,7 +594,7 @@ API_PORT=8000 llamafactory-cli api examples/inference/llama3_vllm.yaml > [!TIP] > API 文档请查阅[这里](https://platform.openai.com/docs/api-reference/chat/create)。 > -> 示例:[图像理解](scripts/test_image.py) | [工具调用](scripts/test_toolcall.py) +> 示例:[图像理解](scripts/api_example/test_image.py) | [工具调用](scripts/api_example/test_toolcall.py) ### 从魔搭社区下载 diff --git a/examples/README.md b/examples/README.md index 3a98a088..9413ef2a 100644 --- a/examples/README.md +++ b/examples/README.md @@ -13,6 +13,8 @@ Make sure to execute these commands in the `LLaMA-Factory` directory. Use `CUDA_VISIBLE_DEVICES` (GPU) or `ASCEND_RT_VISIBLE_DEVICES` (NPU) to choose computing devices. +By default, LLaMA-Factory uses all visible computing devices. + ## Examples ### LoRA Fine-Tuning @@ -80,12 +82,6 @@ llamafactory-cli train examples/train_lora/llama3_preprocess.yaml llamafactory-cli eval examples/train_lora/llama3_lora_eval.yaml ``` -#### Batch Predicting and Computing BLEU and ROUGE Scores - -```bash -llamafactory-cli train examples/train_lora/llama3_lora_predict.yaml -``` - #### Supervised Fine-Tuning on Multiple Nodes ```bash @@ -146,12 +142,6 @@ FORCE_TORCHRUN=1 NNODES=2 RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llama FORCE_TORCHRUN=1 llamafactory-cli train examples/train_full/qwen2vl_full_sft.yaml ``` -#### Batch Predicting and Computing BLEU and ROUGE Scores - -```bash -llamafactory-cli train examples/train_full/llama3_full_predict.yaml -``` - ### Merging LoRA Adapters and Quantization #### Merge LoRA Adapters @@ -170,13 +160,19 @@ llamafactory-cli export examples/merge_lora/llama3_gptq.yaml ### Inferring LoRA Fine-Tuned Models -#### Use CLI +#### Batch Generation using vLLM Tensor Parallel + +``` +python scripts/vllm_infer.py --model_name_or_path path_to_merged_model --dataset alpaca_en_demo +``` + +#### Use CLI ChatBox ```bash llamafactory-cli chat examples/inference/llama3_lora_sft.yaml ``` -#### Use Web UI +#### Use Web UI ChatBox ```bash llamafactory-cli webchat examples/inference/llama3_lora_sft.yaml @@ -238,3 +234,9 @@ llamafactory-cli train examples/extras/llama_pro/llama3_freeze_sft.yaml ```bash bash examples/extras/fsdp_qlora/train.sh ``` + +#### Computing BLEU and ROUGE Scores + +```bash +llamafactory-cli train examples/extras/nlg_eval/llama3_lora_predict.yaml +``` diff --git a/examples/README_zh.md b/examples/README_zh.md index 45e96bcf..9aa4ca90 100644 --- a/examples/README_zh.md +++ b/examples/README_zh.md @@ -13,6 +13,8 @@ 使用 `CUDA_VISIBLE_DEVICES`(GPU)或 `ASCEND_RT_VISIBLE_DEVICES`(NPU)选择计算设备。 +LLaMA-Factory 默认使用所有可见的计算设备。 + ## 示例 ### LoRA 微调 @@ -80,12 +82,6 @@ llamafactory-cli train examples/train_lora/llama3_preprocess.yaml llamafactory-cli eval examples/train_lora/llama3_lora_eval.yaml ``` -#### 批量预测并计算 BLEU 和 ROUGE 分数 - -```bash -llamafactory-cli train examples/train_lora/llama3_lora_predict.yaml -``` - #### 多机指令监督微调 ```bash @@ -146,12 +142,6 @@ FORCE_TORCHRUN=1 NNODES=2 RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llama FORCE_TORCHRUN=1 llamafactory-cli train examples/train_full/qwen2vl_full_sft.yaml ``` -#### 批量预测并计算 BLEU 和 ROUGE 分数 - -```bash -llamafactory-cli train examples/train_full/llama3_full_predict.yaml -``` - ### 合并 LoRA 适配器与模型量化 #### 合并 LoRA 适配器 @@ -170,13 +160,19 @@ llamafactory-cli export examples/merge_lora/llama3_gptq.yaml ### 推理 LoRA 模型 -#### 使用命令行接口 +#### 使用 vLLM+TP 批量推理 + +``` +python scripts/vllm_infer.py --model_name_or_path path_to_merged_model --dataset alpaca_en_demo +``` + +#### 使用命令行对话框 ```bash llamafactory-cli chat examples/inference/llama3_lora_sft.yaml ``` -#### 使用浏览器界面 +#### 使用浏览器对话框 ```bash llamafactory-cli webchat examples/inference/llama3_lora_sft.yaml @@ -238,3 +234,9 @@ llamafactory-cli train examples/extras/llama_pro/llama3_freeze_sft.yaml ```bash bash examples/extras/fsdp_qlora/train.sh ``` + +#### 计算 BLEU 和 ROUGE 分数 + +```bash +llamafactory-cli train examples/extras/nlg_eval/llama3_lora_predict.yaml +``` diff --git a/examples/train_lora/llama3_lora_predict.yaml b/examples/extras/nlg_eval/llama3_lora_predict.yaml similarity index 80% rename from examples/train_lora/llama3_lora_predict.yaml rename to examples/extras/nlg_eval/llama3_lora_predict.yaml index f7119a8a..3cdca843 100644 --- a/examples/train_lora/llama3_lora_predict.yaml +++ b/examples/extras/nlg_eval/llama3_lora_predict.yaml @@ -1,3 +1,6 @@ +# The batch generation can be SLOW using this config. +# For faster inference, we recommend to use `scripts/vllm_infer.py`. + ### model model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct adapter_name_or_path: saves/llama3-8b/lora/sft diff --git a/examples/inference/llama3.yaml b/examples/inference/llama3.yaml index ffc5be82..a2a8c6fd 100644 --- a/examples/inference/llama3.yaml +++ b/examples/inference/llama3.yaml @@ -1,2 +1,3 @@ model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct template: llama3 +infer_backend: huggingface # choices: [huggingface, vllm] diff --git a/examples/inference/llama3_lora_sft.yaml b/examples/inference/llama3_lora_sft.yaml index 262f4445..ec5d8732 100644 --- a/examples/inference/llama3_lora_sft.yaml +++ b/examples/inference/llama3_lora_sft.yaml @@ -2,3 +2,4 @@ model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct adapter_name_or_path: saves/llama3-8b/lora/sft template: llama3 finetuning_type: lora +infer_backend: huggingface # choices: [huggingface, vllm] diff --git a/examples/inference/llava1_5.yaml b/examples/inference/llava1_5.yaml index 68f3b8ff..163d6e1e 100644 --- a/examples/inference/llava1_5.yaml +++ b/examples/inference/llava1_5.yaml @@ -1,2 +1,3 @@ model_name_or_path: llava-hf/llava-1.5-7b-hf template: llava +infer_backend: huggingface # choices: [huggingface, vllm] diff --git a/examples/inference/qwen2_vl.yaml b/examples/inference/qwen2_vl.yaml index ed1cef6c..f6dc5539 100644 --- a/examples/inference/qwen2_vl.yaml +++ b/examples/inference/qwen2_vl.yaml @@ -1,2 +1,3 @@ model_name_or_path: Qwen/Qwen2-VL-7B-Instruct template: qwen2_vl +infer_backend: huggingface # choices: [huggingface, vllm] diff --git a/examples/train_full/llama3_full_predict.yaml b/examples/train_full/llama3_full_predict.yaml deleted file mode 100644 index dcac4925..00000000 --- a/examples/train_full/llama3_full_predict.yaml +++ /dev/null @@ -1,23 +0,0 @@ -### model -model_name_or_path: saves/llama3-8b/full/sft - -### method -stage: sft -do_predict: true -finetuning_type: full - -### dataset -eval_dataset: identity,alpaca_en_demo -template: llama3 -cutoff_len: 2048 -max_samples: 50 -overwrite_cache: true -preprocessing_num_workers: 16 - -### output -output_dir: saves/llama3-8b/full/predict -overwrite_output_dir: true - -### eval -per_device_eval_batch_size: 1 -predict_with_generate: true diff --git a/scripts/test_image.py b/scripts/api_example/test_image.py similarity index 100% rename from scripts/test_image.py rename to scripts/api_example/test_image.py diff --git a/scripts/test_toolcall.py b/scripts/api_example/test_toolcall.py similarity index 100% rename from scripts/test_toolcall.py rename to scripts/api_example/test_toolcall.py diff --git a/scripts/async_call_api.py b/scripts/async_call_api.py deleted file mode 100644 index 3dcf4b87..00000000 --- a/scripts/async_call_api.py +++ /dev/null @@ -1,223 +0,0 @@ -# Copyright 2024 the LlamaFactory team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# pip install langchain langchain_openai - -import os -import sys -import json -import asyncio - - -import fire -from tqdm import tqdm -from dataclasses import dataclass -from aiolimiter import AsyncLimiter -from typing import List -import pandas as pd -from langchain_openai import ChatOpenAI -from dotenv import load_dotenv - -from llamafactory.hparams import get_train_args -from llamafactory.extras.constants import IGNORE_INDEX -from llamafactory.data.loader import _get_merged_dataset - -load_dotenv() - - -class AsyncLLM: - def __init__( - self, - model: str = "gpt-3.5-turbo", - base_url: str = "http://localhost:{}/v1/".format( - os.environ.get("API_PORT", 8000) - ), - api_key: str = "{}".format(os.environ.get("API_KEY", "0")), - num_per_second: int = 6, - **kwargs, - ): - self.model = model - self.base_url = base_url - self.api_key = api_key - self.num_per_second = num_per_second - - # 创建限速器,每秒最多发出 5 个请求 - self.limiter = AsyncLimiter(self.num_per_second, 1) - - self.llm = ChatOpenAI( - model=self.model, base_url=self.base_url, api_key=self.api_key, **kwargs - ) - - async def __call__(self, text): - # 限速 - async with self.limiter: - return await self.llm.ainvoke([text]) - - -llm = AsyncLLM( - base_url="http://localhost:{}/v1/".format(os.environ.get("API_PORT", 8000)), - api_key="{}".format(os.environ.get("API_KEY", "0")), - num_per_second=10, -) -llms = [llm] - - -@dataclass -class AsyncAPICall: - uid: str = "0" - - @staticmethod - async def _run_task_with_progress(task, pbar): - result = await task - pbar.update(1) - return result - - @staticmethod - def async_run( - llms: List[AsyncLLM], - data: List[str], - keyword: str = "", - output_dir: str = "output", - chunk_size=500, - ) -> List[str]: - - async def infer_chunk(llms: List[AsyncLLM], data: List): - """ - 逐块进行推理,为避免处理庞大数据时,程序崩溃导致已推理数据丢失 - """ - results = [llms[i % len(llms)](text) for i, text in enumerate(data)] - - with tqdm(total=len(results)) as pbar: - results = await asyncio.gather( - *[ - AsyncAPICall._run_task_with_progress(task, pbar) - for task in results - ] - ) - return results - - idx = 0 - all_df = [] - file_exist_skip = False - user_confirm = False - - while idx < len(data): - file_path = os.path.join(output_dir, "tmp", f"{idx}.csv.temp") - - if os.path.exists(file_path): - if not user_confirm: - while True: - user_response = input( - f"Find {file_path} file already exists. Do you want to skip them forever?\ny or Y to skip, n or N to rerun to overwrite: " - ) - if user_response.lower() == "y": - user_confirm = True - file_exist_skip = True - break - elif user_response.lower() == "n": - user_confirm = True - file_exist_skip = False - break - - if file_exist_skip: - tmp_df = pd.read_csv(file_path) - all_df.append(tmp_df) - idx += chunk_size - continue - - tmp_data = data[idx : idx + chunk_size] - loop = asyncio.get_event_loop() - tmp_result = loop.run_until_complete(infer_chunk(llms=llms, data=tmp_data)) - tmp_result = [item.content for item in tmp_result] - - tmp_df = pd.DataFrame({"infer": tmp_result}) - - if not os.path.exists(p := os.path.dirname(file_path)): - os.makedirs(p, exist_ok=True) - - tmp_df.to_csv(file_path, index=False) - all_df.append(tmp_df) - idx += chunk_size - - all_df = pd.concat(all_df) - return all_df["infer"] - - -def async_api_infer( - model_name_or_path: str = "", - eval_dataset: str = "", - template: str = "", - dataset_dir: str = "data", - do_predict: bool = True, - predict_with_generate: bool = True, - max_samples: int = None, - output_dir: str = "output", - chunk_size=50, -): - - if len(sys.argv) == 1: - model_args, data_args, training_args, finetuning_args, generating_args = ( - get_train_args( - dict( - model_name_or_path=model_name_or_path, - dataset_dir=dataset_dir, - eval_dataset=eval_dataset, - template=template, - output_dir=output_dir, - do_predict=True, - predict_with_generate=True, - max_samples=max_samples, - ) - ) - ) - else: - model_args, data_args, training_args, finetuning_args, generating_args = ( - get_train_args() - ) - - dataset = _get_merged_dataset( - data_args.eval_dataset, model_args, data_args, training_args, "sft" - ) - - labels = [item[0]["content"] for item in dataset["_response"]] - prompts = [item[0]["content"] for item in dataset["_prompt"]] - - infers = AsyncAPICall.async_run( - llms, - prompts, - chunk_size=chunk_size, - output_dir=training_args.output_dir, - ) - - if not os.path.exists(training_args.output_dir): - os.makedirs(training_args.output_dir, exist_ok=True) - - output_prediction_file = os.path.join( - training_args.output_dir, "generated_predictions.jsonl" - ) - - with open(output_prediction_file, "w", encoding="utf-8") as writer: - res: List[str] = [] - for text, pred, label in zip(prompts, infers, labels): - res.append( - json.dumps( - {"prompt": text, "predict": pred, "label": label}, - ensure_ascii=False, - ) - ) - writer.write("\n".join(res)) - - -if __name__ == "__main__": - fire.Fire(async_api_infer) diff --git a/scripts/llamafy_baichuan2.py b/scripts/convert_ckpt/llamafy_baichuan2.py similarity index 100% rename from scripts/llamafy_baichuan2.py rename to scripts/convert_ckpt/llamafy_baichuan2.py diff --git a/scripts/llamafy_qwen.py b/scripts/convert_ckpt/llamafy_qwen.py similarity index 100% rename from scripts/llamafy_qwen.py rename to scripts/convert_ckpt/llamafy_qwen.py diff --git a/scripts/cal_flops.py b/scripts/stat_utils/cal_flops.py similarity index 100% rename from scripts/cal_flops.py rename to scripts/stat_utils/cal_flops.py diff --git a/scripts/cal_lr.py b/scripts/stat_utils/cal_lr.py similarity index 100% rename from scripts/cal_lr.py rename to scripts/stat_utils/cal_lr.py diff --git a/scripts/cal_mfu.py b/scripts/stat_utils/cal_mfu.py similarity index 100% rename from scripts/cal_mfu.py rename to scripts/stat_utils/cal_mfu.py diff --git a/scripts/cal_ppl.py b/scripts/stat_utils/cal_ppl.py similarity index 100% rename from scripts/cal_ppl.py rename to scripts/stat_utils/cal_ppl.py diff --git a/scripts/length_cdf.py b/scripts/stat_utils/length_cdf.py similarity index 100% rename from scripts/length_cdf.py rename to scripts/stat_utils/length_cdf.py diff --git a/scripts/vllm_infer.py b/scripts/vllm_infer.py index 0d498959..1eae6842 100644 --- a/scripts/vllm_infer.py +++ b/scripts/vllm_infer.py @@ -12,127 +12,106 @@ # See the License for the specific language governing permissions and # limitations under the License. - import json -import os -import sys -from typing import List import fire +from transformers import Seq2SeqTrainingArguments from vllm import LLM, SamplingParams from vllm.lora.request import LoRARequest from llamafactory.data import get_dataset, get_template_and_fix_tokenizer from llamafactory.extras.constants import IGNORE_INDEX -from llamafactory.hparams import get_train_args +from llamafactory.extras.misc import get_device_count +from llamafactory.hparams import get_infer_args from llamafactory.model import load_tokenizer -max_tokens = 2048 - - def vllm_infer( - model_name_or_path: str = None, + model_name_or_path: str, adapter_name_or_path: str = None, + dataset: str = "alpaca_en_demo", dataset_dir: str = "data", - eval_dataset: str = None, template: str = "default", - max_sample: int = None, - preprocessing_num_workers: int = 16, - predict_with_generate: bool = True, - do_predict: bool = True, - temperature: float = 0.7, + cutoff_len: int = 2048, + max_samples: int = None, + vllm_config: str = "{}", + save_name: str = "generated_predictions.jsonl", + temperature: float = 0.95, top_p: float = 0.7, - top_k: float = 50, - output_dir: str = "output", + top_k: int = 50, + max_new_tokens: int = 1024, + repetition_penalty: float = 1.0, ): - - if len(sys.argv) == 1: - model_args, data_args, training_args, finetuning_args, generating_args = ( - get_train_args( - dict( - model_name_or_path=model_name_or_path, - adapter_name_or_path=adapter_name_or_path, - dataset_dir=dataset_dir, - eval_dataset=eval_dataset, - template=template, - max_sample=max_sample, - preprocessing_num_workers=preprocessing_num_workers, - predict_with_generate=predict_with_generate, - do_predict=do_predict, - temperature=temperature, - top_p=top_p, - top_k=top_k, - output_dir=output_dir, - ) - ) - ) - else: - model_args, data_args, training_args, finetuning_args, generating_args = ( - get_train_args() + r""" + Performs batch generation using vLLM engine, which supports tensor parallelism. + Usage: python vllm_infer.py --model_name_or_path meta-llama/Llama-2-7b-hf --template llama --dataset alpaca_en_demo + """ + model_args, data_args, _, generating_args = get_infer_args( + dict( + model_name_or_path=model_name_or_path, + adapter_name_or_path=adapter_name_or_path, + dataset=dataset, + dataset_dir=dataset_dir, + template=template, + cutoff_len=cutoff_len, + max_samples=max_samples, + vllm_config=vllm_config, + temperature=temperature, + top_p=top_p, + top_k=top_k, + max_new_tokens=max_new_tokens, + repetition_penalty=repetition_penalty, ) + ) - tokenizer = load_tokenizer(model_args)["tokenizer"] + training_args = Seq2SeqTrainingArguments(output_dir="dummy_dir", predict_with_generate=True) + tokenizer_module = load_tokenizer(model_args) + tokenizer = tokenizer_module["tokenizer"] template = get_template_and_fix_tokenizer(tokenizer, data_args) + dataset = get_dataset(template, model_args, data_args, training_args, "ppo", **tokenizer_module)["train_dataset"] - eval_dataset = get_dataset( - template, model_args, data_args, training_args, finetuning_args.stage, tokenizer - )["eval_dataset"] - - prompts = [item["input_ids"] for item in eval_dataset] - prompts = tokenizer.batch_decode(prompts, skip_special_tokens=False) - - labels = [ - list(filter(lambda x: x != IGNORE_INDEX, item["labels"])) - for item in eval_dataset - ] - labels = tokenizer.batch_decode(labels, skip_special_tokens=True) + inputs, prompts, labels = [], [], [] + for sample in dataset: + inputs.append({"prompt_token_ids": sample["input_ids"]}) + prompts.append(tokenizer.decode(sample["input_ids"], skip_special_tokens=False)) + labels.append( + tokenizer.decode(list(filter(lambda x: x != IGNORE_INDEX, sample["labels"])), skip_special_tokens=False) + ) sampling_params = SamplingParams( + repetition_penalty=generating_args.repetition_penalty or 1.0, # repetition_penalty must > 0 temperature=generating_args.temperature, + top_p=generating_args.top_p or 1.0, # top_p must > 0 top_k=generating_args.top_k, - top_p=generating_args.top_p, - max_tokens=max_tokens, + stop_token_ids=[tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids, + max_tokens=generating_args.max_new_tokens, + skip_special_tokens=False, ) - - if model_args.adapter_name_or_path: - if isinstance(model_args.adapter_name_or_path, list): - lora_path = model_args.adapter_name_or_path[0] - else: - lora_path = model_args.adapter_name_or_path - - lora_requests = LoRARequest("lora_adapter_0", 0, lora_path=lora_path) - enable_lora = True + if model_args.adapter_name_or_path is not None: + lora_request = LoRARequest("default", 1, model_args.adapter_name_or_path[0]) else: - lora_requests = None - enable_lora = False + lora_request = None - llm = LLM( - model=model_args.model_name_or_path, - trust_remote_code=True, - tokenizer=model_args.model_name_or_path, - enable_lora=enable_lora, - ) + engine_args = { + "model": model_args.model_name_or_path, + "trust_remote_code": True, + "dtype": model_args.infer_dtype, + "tensor_parallel_size": get_device_count() or 1, + "disable_log_stats": True, + "enable_lora": model_args.adapter_name_or_path is not None, + } + if isinstance(model_args.vllm_config, dict): + engine_args.update(model_args.vllm_config) - outputs = llm.generate(prompts, sampling_params, lora_request=lora_requests) + results = LLM(**engine_args).generate(inputs, sampling_params, lora_request=lora_request) + preds = [result.outputs[0].text for result in results] + with open(save_name, "w", encoding="utf-8") as f: + for text, pred, label in zip(prompts, preds, labels): + f.write(json.dumps({"prompt": text, "predict": pred, "label": label}, ensure_ascii=False) + "\n") - if not os.path.exists(training_args.output_dir): - os.makedirs(training_args.output_dir, exist_ok=True) - - output_prediction_file = os.path.join( - training_args.output_dir, "generated_predictions.jsonl" - ) - - with open(output_prediction_file, "w", encoding="utf-8") as writer: - res: List[str] = [] - for text, pred, label in zip(prompts, outputs, labels): - res.append( - json.dumps( - {"prompt": text, "predict": pred.outputs[0].text, "label": label}, - ensure_ascii=False, - ) - ) - writer.write("\n".join(res)) + print("*" * 70) + print(f"{len(prompts)} generated results have been saved at {save_name}.") + print("*" * 70) if __name__ == "__main__": diff --git a/setup.py b/setup.py index 74b6f73b..862e9b94 100644 --- a/setup.py +++ b/setup.py @@ -54,7 +54,7 @@ extra_require = { "gptq": ["optimum>=1.17.0", "auto-gptq>=0.5.0"], "awq": ["autoawq"], "aqlm": ["aqlm[gpu]>=1.1.0"], - "vllm": ["vllm>=0.4.3,<0.6.4"], + "vllm": ["vllm>=0.4.3,<0.6.5"], "galore": ["galore-torch"], "badam": ["badam>=1.2.1"], "adam-mini": ["adam-mini"], diff --git a/src/llamafactory/extras/misc.py b/src/llamafactory/extras/misc.py index f46c0f88..ae0da5ee 100644 --- a/src/llamafactory/extras/misc.py +++ b/src/llamafactory/extras/misc.py @@ -17,7 +17,7 @@ import gc import os -from typing import TYPE_CHECKING, Tuple, Union +from typing import TYPE_CHECKING, Any, Dict, Literal, Sequence, Tuple, Union import torch import torch.distributed as dist @@ -87,6 +87,21 @@ def check_dependencies() -> None: require_version("trl>=0.8.6,<=0.9.6", "To fix: pip install trl>=0.8.6,<=0.9.6") +def calculate_tps(dataset: Sequence[Dict[str, Any]], metrics: Dict[str, float], stage: Literal["sft", "rm"]) -> float: + r""" + Calculates effective tokens per second. + """ + effective_token_num = 0 + for data in dataset: + if stage == "sft": + effective_token_num += len(data["input_ids"]) + elif stage == "rm": + effective_token_num += len(data["chosen_input_ids"]) + len(data["rejected_input_ids"]) + + result = effective_token_num * metrics["epoch"] / metrics["train_runtime"] + return result / dist.get_world_size() if dist.is_initialized() else result + + def count_parameters(model: "torch.nn.Module") -> Tuple[int, int]: r""" Returns the number of trainable parameters and number of all parameters in the model. @@ -264,11 +279,3 @@ def use_modelscope() -> bool: def use_openmind() -> bool: return os.environ.get("USE_OPENMIND_HUB", "0").lower() in ["true", "1"] - - -def cal_effective_tokens(effective_token_num, epoch, train_runtime) -> int: - r""" - calculate effective tokens. - """ - result = effective_token_num * epoch / train_runtime - return result / dist.get_world_size() if dist.is_initialized() else result diff --git a/src/llamafactory/hparams/parser.py b/src/llamafactory/hparams/parser.py index 5bc16dac..5b542c2e 100644 --- a/src/llamafactory/hparams/parser.py +++ b/src/llamafactory/hparams/parser.py @@ -122,7 +122,7 @@ def _check_extra_dependencies( require_version("mixture-of-depth>=1.1.6", "To fix: pip install mixture-of-depth>=1.1.6") if model_args.infer_backend == "vllm": - require_version("vllm>=0.4.3,<0.6.4", "To fix: pip install vllm>=0.4.3,<0.6.4") + require_version("vllm>=0.4.3,<0.6.5", "To fix: pip install vllm>=0.4.3,<0.6.5") if finetuning_args.use_galore: require_version("galore_torch", "To fix: pip install galore_torch") diff --git a/src/llamafactory/train/dpo/workflow.py b/src/llamafactory/train/dpo/workflow.py index 8c3e7401..e3d6e660 100644 --- a/src/llamafactory/train/dpo/workflow.py +++ b/src/llamafactory/train/dpo/workflow.py @@ -19,7 +19,7 @@ from typing import TYPE_CHECKING, List, Optional from ...data import PairwiseDataCollatorWithPadding, get_dataset, get_template_and_fix_tokenizer from ...extras.constants import IGNORE_INDEX -from ...extras.misc import cal_effective_tokens +from ...extras.misc import calculate_tps from ...extras.ploting import plot_loss from ...hparams import ModelArguments from ...model import load_model, load_tokenizer @@ -65,12 +65,6 @@ def run_dpo( # Update arguments training_args.remove_unused_columns = False # important for multimodal and pairwise dataset - effective_token_num = 0.0 - if finetuning_args.include_effective_tokens_per_second: - for data in dataset_module["train_dataset"]: - effective_token_num += len(data["chosen_input_ids"]) - effective_token_num += len(data["rejected_input_ids"]) - # Initialize our Trainer trainer = CustomDPOTrainer( model=model, @@ -86,13 +80,12 @@ def run_dpo( # Training if training_args.do_train: train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint) - + trainer.save_model() if finetuning_args.include_effective_tokens_per_second: - train_result.metrics["effective_tokens_per_sec"] = cal_effective_tokens( - effective_token_num, train_result.metrics["epoch"], train_result.metrics["train_runtime"] + train_result.metrics["effective_tokens_per_sec"] = calculate_tps( + dataset_module["train_dataset"], train_result.metrics, stage="rm" ) - trainer.save_model() trainer.log_metrics("train", train_result.metrics) trainer.save_metrics("train", train_result.metrics) trainer.save_state() diff --git a/src/llamafactory/train/sft/trainer.py b/src/llamafactory/train/sft/trainer.py index f49fbd27..85ce6e8a 100644 --- a/src/llamafactory/train/sft/trainer.py +++ b/src/llamafactory/train/sft/trainer.py @@ -161,12 +161,9 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer): preds[i] = np.concatenate((preds[i][pad_len[0] :], preds[i][: pad_len[0]]), axis=-1) decoded_inputs = self.tokenizer.batch_decode(dataset["input_ids"], skip_special_tokens=True) - decoded_labels = self.tokenizer.batch_decode(labels, skip_special_tokens=True) decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True) + decoded_labels = self.tokenizer.batch_decode(labels, skip_special_tokens=True) - with open(output_prediction_file, "w", encoding="utf-8") as writer: - res: List[str] = [] - for text, label, pred in zip(decoded_inputs, decoded_labels, decoded_preds): - res.append(json.dumps({"prompt": text, "label": label, "predict": pred}, ensure_ascii=False)) - - writer.write("\n".join(res)) + with open(output_prediction_file, "w", encoding="utf-8") as f: + for text, pred, label in zip(decoded_inputs, decoded_preds, decoded_labels): + f.write(json.dumps({"prompt": text, "predict": pred, "label": label}, ensure_ascii=False) + "\n") diff --git a/src/llamafactory/train/sft/workflow.py b/src/llamafactory/train/sft/workflow.py index d8dafc5f..bc7ccb50 100644 --- a/src/llamafactory/train/sft/workflow.py +++ b/src/llamafactory/train/sft/workflow.py @@ -19,7 +19,8 @@ from typing import TYPE_CHECKING, List, Optional from ...data import SFTDataCollatorWith4DAttentionMask, get_dataset, get_template_and_fix_tokenizer from ...extras.constants import IGNORE_INDEX -from ...extras.misc import cal_effective_tokens, get_logits_processor +from ...extras.logging import get_logger +from ...extras.misc import calculate_tps, get_logits_processor from ...extras.ploting import plot_loss from ...model import load_model, load_tokenizer from ..trainer_utils import create_modelcard_and_push @@ -33,6 +34,9 @@ if TYPE_CHECKING: from ...hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments +logger = get_logger(__name__) + + def run_sft( model_args: "ModelArguments", data_args: "DataArguments", @@ -65,11 +69,6 @@ def run_sft( training_args.generation_num_beams = data_args.eval_num_beams or training_args.generation_num_beams training_args.remove_unused_columns = False # important for multimodal dataset - effective_token_num = 0.0 - if finetuning_args.include_effective_tokens_per_second: - for data in dataset_module["train_dataset"]: - effective_token_num += len(data["input_ids"]) - # Metric utils metric_module = {} if training_args.predict_with_generate: @@ -99,12 +98,12 @@ def run_sft( # Training if training_args.do_train: train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint) + trainer.save_model() if finetuning_args.include_effective_tokens_per_second: - train_result.metrics["effective_tokens_per_sec"] = cal_effective_tokens( - effective_token_num, train_result.metrics["epoch"], train_result.metrics["train_runtime"] + train_result.metrics["effective_tokens_per_sec"] = calculate_tps( + dataset_module["train_dataset"], train_result.metrics, stage="sft" ) - trainer.save_model() trainer.log_metrics("train", train_result.metrics) trainer.save_metrics("train", train_result.metrics) trainer.save_state() @@ -124,6 +123,7 @@ def run_sft( # Predict if training_args.do_predict: + logger.warning_once("Batch generation can be very slow. Consider using `scripts/vllm_infer.py` instead.") predict_results = trainer.predict(dataset_module["eval_dataset"], metric_key_prefix="predict", **gen_kwargs) if training_args.predict_with_generate: # predict_loss will be wrong if predict_with_generate is enabled predict_results.metrics.pop("predict_loss", None) diff --git a/src/llamafactory/webui/interface.py b/src/llamafactory/webui/interface.py index a4f2c1fa..8a45ce5f 100644 --- a/src/llamafactory/webui/interface.py +++ b/src/llamafactory/webui/interface.py @@ -35,7 +35,7 @@ if is_gradio_available(): def create_ui(demo_mode: bool = False) -> "gr.Blocks": engine = Engine(demo_mode=demo_mode, pure_chat=False) - hostname = os.getenv("HOSTNAME", os.getenv("COMPUTERNAME", platform.node())).split('.')[0] + hostname = os.getenv("HOSTNAME", os.getenv("COMPUTERNAME", platform.node())).split(".")[0] with gr.Blocks(title=f"LLaMA Board ({hostname})", css=CSS) as demo: if demo_mode: