support batch infer in vllm

Former-commit-id: 1324d158f954d777f1fbf09f46149c372704b388
This commit is contained in:
hiyouga 2024-12-04 13:50:00 +00:00
parent b2c67a989a
commit 235cdcacee
29 changed files with 148 additions and 407 deletions

1
.gitignore vendored
View File

@ -171,3 +171,4 @@ config/
saves/
output/
wandb/
generated_predictions.jsonl

View File

@ -594,7 +594,7 @@ API_PORT=8000 llamafactory-cli api examples/inference/llama3_vllm.yaml
> [!TIP]
> Visit [this page](https://platform.openai.com/docs/api-reference/chat/create) for the API documentation.
>
> Examples: [Image understanding](scripts/test_image.py) | [Function calling](scripts/test_toolcall.py)
> Examples: [Image understanding](scripts/api_example/test_image.py) | [Function calling](scripts/api_example/test_toolcall.py)
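For a quick sanity check without those example scripts, here is a minimal sketch using the `openai` Python client against the server started above; the port, `api_key`, and model name are assumptions rather than values fixed by the project.

```python
# Minimal sketch (assumptions: the API server above runs on localhost:8000,
# the `openai` package is installed, and the model name is a placeholder).
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="0")
response = client.chat.completions.create(
    model="llama3",  # placeholder name; the server serves whatever model it loaded
    messages=[{"role": "user", "content": "Hello!"}],
)
print(response.choices[0].message.content)
```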
### Download from ModelScope Hub
@ -727,7 +727,6 @@ If you have a project that should be incorporated, please contact via email or c
1. **[LazyLLM](https://github.com/LazyAGI/LazyLLM)**: An easy and lazy way for building multi-agent LLMs applications and supports model fine-tuning via LLaMA Factory.
1. **[RAG-Retrieval](https://github.com/NLPJCL/RAG-Retrieval)**: A full pipeline for RAG retrieval model fine-tuning, inference, and distillation. [[blog]](https://zhuanlan.zhihu.com/p/987727357)
</details>
## License

View File

@ -594,7 +594,7 @@ API_PORT=8000 llamafactory-cli api examples/inference/llama3_vllm.yaml
> [!TIP]
> API 文档请查阅[这里](https://platform.openai.com/docs/api-reference/chat/create)。
>
> 示例:[图像理解](scripts/test_image.py) | [工具调用](scripts/test_toolcall.py)
> 示例:[图像理解](scripts/api_example/test_image.py) | [工具调用](scripts/api_example/test_toolcall.py)
### 从魔搭社区下载

View File

@ -13,6 +13,8 @@ Make sure to execute these commands in the `LLaMA-Factory` directory.
Use `CUDA_VISIBLE_DEVICES` (GPU) or `ASCEND_RT_VISIBLE_DEVICES` (NPU) to choose computing devices.
By default, LLaMA-Factory uses all visible computing devices.
## Examples
### LoRA Fine-Tuning
@ -80,12 +82,6 @@ llamafactory-cli train examples/train_lora/llama3_preprocess.yaml
llamafactory-cli eval examples/train_lora/llama3_lora_eval.yaml
```
#### Batch Predicting and Computing BLEU and ROUGE Scores
```bash
llamafactory-cli train examples/train_lora/llama3_lora_predict.yaml
```
#### Supervised Fine-Tuning on Multiple Nodes
```bash
@ -146,12 +142,6 @@ FORCE_TORCHRUN=1 NNODES=2 RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llama
FORCE_TORCHRUN=1 llamafactory-cli train examples/train_full/qwen2vl_full_sft.yaml
```
#### Batch Predicting and Computing BLEU and ROUGE Scores
```bash
llamafactory-cli train examples/train_full/llama3_full_predict.yaml
```
### Merging LoRA Adapters and Quantization
#### Merge LoRA Adapters
@ -170,13 +160,19 @@ llamafactory-cli export examples/merge_lora/llama3_gptq.yaml
### Inferring LoRA Fine-Tuned Models
#### Use CLI
#### Batch Generation using vLLM Tensor Parallel
```bash
python scripts/vllm_infer.py --model_name_or_path path_to_merged_model --dataset alpaca_en_demo
```
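A hedged variant with a LoRA adapter and an explicit output file; the flags mirror the arguments of `scripts/vllm_infer.py` added in this commit, while the adapter path itself is illustrative.

```bash
# Illustrative paths; all flags correspond to arguments of scripts/vllm_infer.py.
python scripts/vllm_infer.py \
    --model_name_or_path meta-llama/Meta-Llama-3-8B-Instruct \
    --adapter_name_or_path saves/llama3-8b/lora/sft \
    --template llama3 \
    --dataset alpaca_en_demo \
    --save_name generated_predictions.jsonl
```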
#### Use CLI ChatBox
```bash
llamafactory-cli chat examples/inference/llama3_lora_sft.yaml
```
#### Use Web UI
#### Use Web UI ChatBox
```bash
llamafactory-cli webchat examples/inference/llama3_lora_sft.yaml
@ -238,3 +234,9 @@ llamafactory-cli train examples/extras/llama_pro/llama3_freeze_sft.yaml
```bash
bash examples/extras/fsdp_qlora/train.sh
```
#### Computing BLEU and ROUGE Scores
```bash
llamafactory-cli train examples/extras/nlg_eval/llama3_lora_predict.yaml
```

View File

@ -13,6 +13,8 @@
使用 `CUDA_VISIBLE_DEVICES`(GPU)或 `ASCEND_RT_VISIBLE_DEVICES`(NPU)选择计算设备。
LLaMA-Factory 默认使用所有可见的计算设备。
## 示例
### LoRA 微调
@ -80,12 +82,6 @@ llamafactory-cli train examples/train_lora/llama3_preprocess.yaml
llamafactory-cli eval examples/train_lora/llama3_lora_eval.yaml
```
#### 批量预测并计算 BLEU 和 ROUGE 分数
```bash
llamafactory-cli train examples/train_lora/llama3_lora_predict.yaml
```
#### 多机指令监督微调
```bash
@ -146,12 +142,6 @@ FORCE_TORCHRUN=1 NNODES=2 RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llama
FORCE_TORCHRUN=1 llamafactory-cli train examples/train_full/qwen2vl_full_sft.yaml
```
#### 批量预测并计算 BLEU 和 ROUGE 分数
```bash
llamafactory-cli train examples/train_full/llama3_full_predict.yaml
```
### 合并 LoRA 适配器与模型量化
#### 合并 LoRA 适配器
@ -170,13 +160,19 @@ llamafactory-cli export examples/merge_lora/llama3_gptq.yaml
### 推理 LoRA 模型
#### 使用命令行接口
#### 使用 vLLM+TP 批量推理
```bash
python scripts/vllm_infer.py --model_name_or_path path_to_merged_model --dataset alpaca_en_demo
```
#### 使用命令行对话框
```bash
llamafactory-cli chat examples/inference/llama3_lora_sft.yaml
```
#### 使用浏览器界面
#### 使用浏览器对话框
```bash
llamafactory-cli webchat examples/inference/llama3_lora_sft.yaml
@ -238,3 +234,9 @@ llamafactory-cli train examples/extras/llama_pro/llama3_freeze_sft.yaml
```bash
bash examples/extras/fsdp_qlora/train.sh
```
#### 计算 BLEU 和 ROUGE 分数
```bash
llamafactory-cli train examples/extras/nlg_eval/llama3_lora_predict.yaml
```

View File

@ -1,3 +1,6 @@
# Batch generation can be SLOW using this config.
# For faster inference, we recommend using `scripts/vllm_infer.py`.
### model
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
adapter_name_or_path: saves/llama3-8b/lora/sft

View File

@ -1,2 +1,3 @@
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
template: llama3
infer_backend: huggingface # choices: [huggingface, vllm]
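The HuggingFace backend stays the default here; a minimal sketch of a vLLM-backed variant (the repository's `examples/inference/llama3_vllm.yaml` may carry additional options):

```yaml
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
template: llama3
infer_backend: vllm  # requires the vllm extra declared in setup.py
```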

View File

@ -2,3 +2,4 @@ model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
adapter_name_or_path: saves/llama3-8b/lora/sft
template: llama3
finetuning_type: lora
infer_backend: huggingface # choices: [huggingface, vllm]

View File

@ -1,2 +1,3 @@
model_name_or_path: llava-hf/llava-1.5-7b-hf
template: llava
infer_backend: huggingface # choices: [huggingface, vllm]

View File

@ -1,2 +1,3 @@
model_name_or_path: Qwen/Qwen2-VL-7B-Instruct
template: qwen2_vl
infer_backend: huggingface # choices: [huggingface, vllm]

View File

@ -1,23 +0,0 @@
### model
model_name_or_path: saves/llama3-8b/full/sft
### method
stage: sft
do_predict: true
finetuning_type: full
### dataset
eval_dataset: identity,alpaca_en_demo
template: llama3
cutoff_len: 2048
max_samples: 50
overwrite_cache: true
preprocessing_num_workers: 16
### output
output_dir: saves/llama3-8b/full/predict
overwrite_output_dir: true
### eval
per_device_eval_batch_size: 1
predict_with_generate: true

View File

@ -1,223 +0,0 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pip install langchain langchain_openai
import os
import sys
import json
import asyncio
import fire
from tqdm import tqdm
from dataclasses import dataclass
from aiolimiter import AsyncLimiter
from typing import List
import pandas as pd
from langchain_openai import ChatOpenAI
from dotenv import load_dotenv
from llamafactory.hparams import get_train_args
from llamafactory.extras.constants import IGNORE_INDEX
from llamafactory.data.loader import _get_merged_dataset
load_dotenv()
class AsyncLLM:
def __init__(
self,
model: str = "gpt-3.5-turbo",
base_url: str = "http://localhost:{}/v1/".format(
os.environ.get("API_PORT", 8000)
),
api_key: str = "{}".format(os.environ.get("API_KEY", "0")),
num_per_second: int = 6,
**kwargs,
):
self.model = model
self.base_url = base_url
self.api_key = api_key
self.num_per_second = num_per_second
# Rate limiter: allow at most `num_per_second` requests per second
self.limiter = AsyncLimiter(self.num_per_second, 1)
self.llm = ChatOpenAI(
model=self.model, base_url=self.base_url, api_key=self.api_key, **kwargs
)
async def __call__(self, text):
# Throttle the request rate
async with self.limiter:
return await self.llm.ainvoke([text])
llm = AsyncLLM(
base_url="http://localhost:{}/v1/".format(os.environ.get("API_PORT", 8000)),
api_key="{}".format(os.environ.get("API_KEY", "0")),
num_per_second=10,
)
llms = [llm]
@dataclass
class AsyncAPICall:
uid: str = "0"
@staticmethod
async def _run_task_with_progress(task, pbar):
result = await task
pbar.update(1)
return result
@staticmethod
def async_run(
llms: List[AsyncLLM],
data: List[str],
keyword: str = "",
output_dir: str = "output",
chunk_size=500,
) -> List[str]:
async def infer_chunk(llms: List[AsyncLLM], data: List):
"""
Run inference chunk by chunk, so that a crash on a large dataset does not lose the results that were already generated.
"""
results = [llms[i % len(llms)](text) for i, text in enumerate(data)]
with tqdm(total=len(results)) as pbar:
results = await asyncio.gather(
*[
AsyncAPICall._run_task_with_progress(task, pbar)
for task in results
]
)
return results
idx = 0
all_df = []
file_exist_skip = False
user_confirm = False
while idx < len(data):
file_path = os.path.join(output_dir, "tmp", f"{idx}.csv.temp")
if os.path.exists(file_path):
if not user_confirm:
while True:
user_response = input(
f"Find {file_path} file already exists. Do you want to skip them forever?\ny or Y to skip, n or N to rerun to overwrite: "
)
if user_response.lower() == "y":
user_confirm = True
file_exist_skip = True
break
elif user_response.lower() == "n":
user_confirm = True
file_exist_skip = False
break
if file_exist_skip:
tmp_df = pd.read_csv(file_path)
all_df.append(tmp_df)
idx += chunk_size
continue
tmp_data = data[idx : idx + chunk_size]
loop = asyncio.get_event_loop()
tmp_result = loop.run_until_complete(infer_chunk(llms=llms, data=tmp_data))
tmp_result = [item.content for item in tmp_result]
tmp_df = pd.DataFrame({"infer": tmp_result})
if not os.path.exists(p := os.path.dirname(file_path)):
os.makedirs(p, exist_ok=True)
tmp_df.to_csv(file_path, index=False)
all_df.append(tmp_df)
idx += chunk_size
all_df = pd.concat(all_df)
return all_df["infer"]
def async_api_infer(
model_name_or_path: str = "",
eval_dataset: str = "",
template: str = "",
dataset_dir: str = "data",
do_predict: bool = True,
predict_with_generate: bool = True,
max_samples: int = None,
output_dir: str = "output",
chunk_size=50,
):
if len(sys.argv) == 1:
model_args, data_args, training_args, finetuning_args, generating_args = (
get_train_args(
dict(
model_name_or_path=model_name_or_path,
dataset_dir=dataset_dir,
eval_dataset=eval_dataset,
template=template,
output_dir=output_dir,
do_predict=True,
predict_with_generate=True,
max_samples=max_samples,
)
)
)
else:
model_args, data_args, training_args, finetuning_args, generating_args = (
get_train_args()
)
dataset = _get_merged_dataset(
data_args.eval_dataset, model_args, data_args, training_args, "sft"
)
labels = [item[0]["content"] for item in dataset["_response"]]
prompts = [item[0]["content"] for item in dataset["_prompt"]]
infers = AsyncAPICall.async_run(
llms,
prompts,
chunk_size=chunk_size,
output_dir=training_args.output_dir,
)
if not os.path.exists(training_args.output_dir):
os.makedirs(training_args.output_dir, exist_ok=True)
output_prediction_file = os.path.join(
training_args.output_dir, "generated_predictions.jsonl"
)
with open(output_prediction_file, "w", encoding="utf-8") as writer:
res: List[str] = []
for text, pred, label in zip(prompts, infers, labels):
res.append(
json.dumps(
{"prompt": text, "predict": pred, "label": label},
ensure_ascii=False,
)
)
writer.write("\n".join(res))
if __name__ == "__main__":
fire.Fire(async_api_infer)

View File

@ -12,127 +12,106 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
import sys
from typing import List
import fire
from transformers import Seq2SeqTrainingArguments
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest
from llamafactory.data import get_dataset, get_template_and_fix_tokenizer
from llamafactory.extras.constants import IGNORE_INDEX
from llamafactory.hparams import get_train_args
from llamafactory.extras.misc import get_device_count
from llamafactory.hparams import get_infer_args
from llamafactory.model import load_tokenizer
max_tokens = 2048
def vllm_infer(
model_name_or_path: str = None,
model_name_or_path: str,
adapter_name_or_path: str = None,
dataset: str = "alpaca_en_demo",
dataset_dir: str = "data",
eval_dataset: str = None,
template: str = "default",
max_sample: int = None,
preprocessing_num_workers: int = 16,
predict_with_generate: bool = True,
do_predict: bool = True,
temperature: float = 0.7,
cutoff_len: int = 2048,
max_samples: int = None,
vllm_config: str = "{}",
save_name: str = "generated_predictions.jsonl",
temperature: float = 0.95,
top_p: float = 0.7,
top_k: float = 50,
output_dir: str = "output",
top_k: int = 50,
max_new_tokens: int = 1024,
repetition_penalty: float = 1.0,
):
if len(sys.argv) == 1:
model_args, data_args, training_args, finetuning_args, generating_args = (
get_train_args(
dict(
model_name_or_path=model_name_or_path,
adapter_name_or_path=adapter_name_or_path,
dataset_dir=dataset_dir,
eval_dataset=eval_dataset,
template=template,
max_sample=max_sample,
preprocessing_num_workers=preprocessing_num_workers,
predict_with_generate=predict_with_generate,
do_predict=do_predict,
temperature=temperature,
top_p=top_p,
top_k=top_k,
output_dir=output_dir,
)
)
)
else:
model_args, data_args, training_args, finetuning_args, generating_args = (
get_train_args()
r"""
Performs batch generation using vLLM engine, which supports tensor parallelism.
Usage: python vllm_infer.py --model_name_or_path meta-llama/Llama-2-7b-hf --template llama --dataset alpaca_en_demo
"""
model_args, data_args, _, generating_args = get_infer_args(
dict(
model_name_or_path=model_name_or_path,
adapter_name_or_path=adapter_name_or_path,
dataset=dataset,
dataset_dir=dataset_dir,
template=template,
cutoff_len=cutoff_len,
max_samples=max_samples,
vllm_config=vllm_config,
temperature=temperature,
top_p=top_p,
top_k=top_k,
max_new_tokens=max_new_tokens,
repetition_penalty=repetition_penalty,
)
)
tokenizer = load_tokenizer(model_args)["tokenizer"]
training_args = Seq2SeqTrainingArguments(output_dir="dummy_dir", predict_with_generate=True)
tokenizer_module = load_tokenizer(model_args)
tokenizer = tokenizer_module["tokenizer"]
template = get_template_and_fix_tokenizer(tokenizer, data_args)
dataset = get_dataset(template, model_args, data_args, training_args, "ppo", **tokenizer_module)["train_dataset"]
eval_dataset = get_dataset(
template, model_args, data_args, training_args, finetuning_args.stage, tokenizer
)["eval_dataset"]
prompts = [item["input_ids"] for item in eval_dataset]
prompts = tokenizer.batch_decode(prompts, skip_special_tokens=False)
labels = [
list(filter(lambda x: x != IGNORE_INDEX, item["labels"]))
for item in eval_dataset
]
labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
inputs, prompts, labels = [], [], []
for sample in dataset:
inputs.append({"prompt_token_ids": sample["input_ids"]})
prompts.append(tokenizer.decode(sample["input_ids"], skip_special_tokens=False))
labels.append(
tokenizer.decode(list(filter(lambda x: x != IGNORE_INDEX, sample["labels"])), skip_special_tokens=False)
)
sampling_params = SamplingParams(
repetition_penalty=generating_args.repetition_penalty or 1.0, # repetition_penalty must > 0
temperature=generating_args.temperature,
top_p=generating_args.top_p or 1.0, # top_p must > 0
top_k=generating_args.top_k,
top_p=generating_args.top_p,
max_tokens=max_tokens,
stop_token_ids=[tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids,
max_tokens=generating_args.max_new_tokens,
skip_special_tokens=False,
)
if model_args.adapter_name_or_path:
if isinstance(model_args.adapter_name_or_path, list):
lora_path = model_args.adapter_name_or_path[0]
else:
lora_path = model_args.adapter_name_or_path
lora_requests = LoRARequest("lora_adapter_0", 0, lora_path=lora_path)
enable_lora = True
if model_args.adapter_name_or_path is not None:
lora_request = LoRARequest("default", 1, model_args.adapter_name_or_path[0])
else:
lora_requests = None
enable_lora = False
lora_request = None
llm = LLM(
model=model_args.model_name_or_path,
trust_remote_code=True,
tokenizer=model_args.model_name_or_path,
enable_lora=enable_lora,
)
engine_args = {
"model": model_args.model_name_or_path,
"trust_remote_code": True,
"dtype": model_args.infer_dtype,
"tensor_parallel_size": get_device_count() or 1,
"disable_log_stats": True,
"enable_lora": model_args.adapter_name_or_path is not None,
}
if isinstance(model_args.vllm_config, dict):
engine_args.update(model_args.vllm_config)
outputs = llm.generate(prompts, sampling_params, lora_request=lora_requests)
results = LLM(**engine_args).generate(inputs, sampling_params, lora_request=lora_request)
preds = [result.outputs[0].text for result in results]
with open(save_name, "w", encoding="utf-8") as f:
for text, pred, label in zip(prompts, preds, labels):
f.write(json.dumps({"prompt": text, "predict": pred, "label": label}, ensure_ascii=False) + "\n")
if not os.path.exists(training_args.output_dir):
os.makedirs(training_args.output_dir, exist_ok=True)
output_prediction_file = os.path.join(
training_args.output_dir, "generated_predictions.jsonl"
)
with open(output_prediction_file, "w", encoding="utf-8") as writer:
res: List[str] = []
for text, pred, label in zip(prompts, outputs, labels):
res.append(
json.dumps(
{"prompt": text, "predict": pred.outputs[0].text, "label": label},
ensure_ascii=False,
)
)
writer.write("\n".join(res))
print("*" * 70)
print(f"{len(prompts)} generated results have been saved at {save_name}.")
print("*" * 70)
if __name__ == "__main__":
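The saved predictions can then be post-processed as needed; a minimal sketch for reading them back (assumes the default `save_name`):

```python
import json

# Each line is one {"prompt", "predict", "label"} record written by vllm_infer.py.
with open("generated_predictions.jsonl", encoding="utf-8") as f:
    records = [json.loads(line) for line in f if line.strip()]

print(f"loaded {len(records)} predictions")
print(records[0]["predict"])
```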

View File

@ -54,7 +54,7 @@ extra_require = {
"gptq": ["optimum>=1.17.0", "auto-gptq>=0.5.0"],
"awq": ["autoawq"],
"aqlm": ["aqlm[gpu]>=1.1.0"],
"vllm": ["vllm>=0.4.3,<0.6.4"],
"vllm": ["vllm>=0.4.3,<0.6.5"],
"galore": ["galore-torch"],
"badam": ["badam>=1.2.1"],
"adam-mini": ["adam-mini"],

View File

@ -17,7 +17,7 @@
import gc
import os
from typing import TYPE_CHECKING, Tuple, Union
from typing import TYPE_CHECKING, Any, Dict, Literal, Sequence, Tuple, Union
import torch
import torch.distributed as dist
@ -87,6 +87,21 @@ def check_dependencies() -> None:
require_version("trl>=0.8.6,<=0.9.6", "To fix: pip install trl>=0.8.6,<=0.9.6")
def calculate_tps(dataset: Sequence[Dict[str, Any]], metrics: Dict[str, float], stage: Literal["sft", "rm"]) -> float:
r"""
Calculates effective tokens per second.
"""
effective_token_num = 0
for data in dataset:
if stage == "sft":
effective_token_num += len(data["input_ids"])
elif stage == "rm":
effective_token_num += len(data["chosen_input_ids"]) + len(data["rejected_input_ids"])
result = effective_token_num * metrics["epoch"] / metrics["train_runtime"]
return result / dist.get_world_size() if dist.is_initialized() else result
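A toy, single-process illustration of the formula above (the numbers are made up):

```python
# Hypothetical run: 100 samples of 512 tokens, 3 epochs, 600 seconds of training.
from llamafactory.extras.misc import calculate_tps

dataset = [{"input_ids": [0] * 512} for _ in range(100)]  # 51,200 effective tokens
metrics = {"epoch": 3.0, "train_runtime": 600.0}
print(calculate_tps(dataset, metrics, stage="sft"))       # 51200 * 3 / 600 = 256.0
```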
def count_parameters(model: "torch.nn.Module") -> Tuple[int, int]:
r"""
Returns the number of trainable parameters and number of all parameters in the model.
@ -264,11 +279,3 @@ def use_modelscope() -> bool:
def use_openmind() -> bool:
return os.environ.get("USE_OPENMIND_HUB", "0").lower() in ["true", "1"]
def cal_effective_tokens(effective_token_num, epoch, train_runtime) -> int:
r"""
calculate effective tokens.
"""
result = effective_token_num * epoch / train_runtime
return result / dist.get_world_size() if dist.is_initialized() else result

View File

@ -122,7 +122,7 @@ def _check_extra_dependencies(
require_version("mixture-of-depth>=1.1.6", "To fix: pip install mixture-of-depth>=1.1.6")
if model_args.infer_backend == "vllm":
require_version("vllm>=0.4.3,<0.6.4", "To fix: pip install vllm>=0.4.3,<0.6.4")
require_version("vllm>=0.4.3,<0.6.5", "To fix: pip install vllm>=0.4.3,<0.6.5")
if finetuning_args.use_galore:
require_version("galore_torch", "To fix: pip install galore_torch")

View File

@ -19,7 +19,7 @@ from typing import TYPE_CHECKING, List, Optional
from ...data import PairwiseDataCollatorWithPadding, get_dataset, get_template_and_fix_tokenizer
from ...extras.constants import IGNORE_INDEX
from ...extras.misc import cal_effective_tokens
from ...extras.misc import calculate_tps
from ...extras.ploting import plot_loss
from ...hparams import ModelArguments
from ...model import load_model, load_tokenizer
@ -65,12 +65,6 @@ def run_dpo(
# Update arguments
training_args.remove_unused_columns = False # important for multimodal and pairwise dataset
effective_token_num = 0.0
if finetuning_args.include_effective_tokens_per_second:
for data in dataset_module["train_dataset"]:
effective_token_num += len(data["chosen_input_ids"])
effective_token_num += len(data["rejected_input_ids"])
# Initialize our Trainer
trainer = CustomDPOTrainer(
model=model,
@ -86,13 +80,12 @@ def run_dpo(
# Training
if training_args.do_train:
train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
trainer.save_model()
if finetuning_args.include_effective_tokens_per_second:
train_result.metrics["effective_tokens_per_sec"] = cal_effective_tokens(
effective_token_num, train_result.metrics["epoch"], train_result.metrics["train_runtime"]
train_result.metrics["effective_tokens_per_sec"] = calculate_tps(
dataset_module["train_dataset"], train_result.metrics, stage="rm"
)
trainer.save_model()
trainer.log_metrics("train", train_result.metrics)
trainer.save_metrics("train", train_result.metrics)
trainer.save_state()

View File

@ -161,12 +161,9 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
preds[i] = np.concatenate((preds[i][pad_len[0] :], preds[i][: pad_len[0]]), axis=-1)
decoded_inputs = self.tokenizer.batch_decode(dataset["input_ids"], skip_special_tokens=True)
decoded_labels = self.tokenizer.batch_decode(labels, skip_special_tokens=True)
decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True)
decoded_labels = self.tokenizer.batch_decode(labels, skip_special_tokens=True)
with open(output_prediction_file, "w", encoding="utf-8") as writer:
res: List[str] = []
for text, label, pred in zip(decoded_inputs, decoded_labels, decoded_preds):
res.append(json.dumps({"prompt": text, "label": label, "predict": pred}, ensure_ascii=False))
writer.write("\n".join(res))
with open(output_prediction_file, "w", encoding="utf-8") as f:
for text, pred, label in zip(decoded_inputs, decoded_preds, decoded_labels):
f.write(json.dumps({"prompt": text, "predict": pred, "label": label}, ensure_ascii=False) + "\n")

View File

@ -19,7 +19,8 @@ from typing import TYPE_CHECKING, List, Optional
from ...data import SFTDataCollatorWith4DAttentionMask, get_dataset, get_template_and_fix_tokenizer
from ...extras.constants import IGNORE_INDEX
from ...extras.misc import cal_effective_tokens, get_logits_processor
from ...extras.logging import get_logger
from ...extras.misc import calculate_tps, get_logits_processor
from ...extras.ploting import plot_loss
from ...model import load_model, load_tokenizer
from ..trainer_utils import create_modelcard_and_push
@ -33,6 +34,9 @@ if TYPE_CHECKING:
from ...hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
logger = get_logger(__name__)
def run_sft(
model_args: "ModelArguments",
data_args: "DataArguments",
@ -65,11 +69,6 @@ def run_sft(
training_args.generation_num_beams = data_args.eval_num_beams or training_args.generation_num_beams
training_args.remove_unused_columns = False # important for multimodal dataset
effective_token_num = 0.0
if finetuning_args.include_effective_tokens_per_second:
for data in dataset_module["train_dataset"]:
effective_token_num += len(data["input_ids"])
# Metric utils
metric_module = {}
if training_args.predict_with_generate:
@ -99,12 +98,12 @@ def run_sft(
# Training
if training_args.do_train:
train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
trainer.save_model()
if finetuning_args.include_effective_tokens_per_second:
train_result.metrics["effective_tokens_per_sec"] = cal_effective_tokens(
effective_token_num, train_result.metrics["epoch"], train_result.metrics["train_runtime"]
train_result.metrics["effective_tokens_per_sec"] = calculate_tps(
dataset_module["train_dataset"], train_result.metrics, stage="sft"
)
trainer.save_model()
trainer.log_metrics("train", train_result.metrics)
trainer.save_metrics("train", train_result.metrics)
trainer.save_state()
@ -124,6 +123,7 @@ def run_sft(
# Predict
if training_args.do_predict:
logger.warning_once("Batch generation can be very slow. Consider using `scripts/vllm_infer.py` instead.")
predict_results = trainer.predict(dataset_module["eval_dataset"], metric_key_prefix="predict", **gen_kwargs)
if training_args.predict_with_generate: # predict_loss will be wrong if predict_with_generate is enabled
predict_results.metrics.pop("predict_loss", None)

View File

@ -35,7 +35,7 @@ if is_gradio_available():
def create_ui(demo_mode: bool = False) -> "gr.Blocks":
engine = Engine(demo_mode=demo_mode, pure_chat=False)
hostname = os.getenv("HOSTNAME", os.getenv("COMPUTERNAME", platform.node())).split('.')[0]
hostname = os.getenv("HOSTNAME", os.getenv("COMPUTERNAME", platform.node())).split(".")[0]
with gr.Blocks(title=f"LLaMA Board ({hostname})", css=CSS) as demo:
if demo_mode: