Mirror of https://github.com/hiyouga/LLaMA-Factory.git, synced 2025-08-01 03:02:51 +08:00
add async call api
Former-commit-id: 4c61368600a6648ec20753a41536ad3c7986703b
parent 99265c7d2f
commit d4bf81b36a
scripts/async_call_api.py: 223 additions (new file)
@@ -0,0 +1,223 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pip install langchain langchain_openai

import os
import sys
import json
import asyncio

import fire
from tqdm import tqdm
from dataclasses import dataclass
from aiolimiter import AsyncLimiter
from typing import List
import pandas as pd
from langchain_openai import ChatOpenAI
from dotenv import load_dotenv

from llamafactory.hparams import get_train_args
from llamafactory.extras.constants import IGNORE_INDEX
from llamafactory.data.loader import _get_merged_dataset

load_dotenv()


class AsyncLLM:
    def __init__(
        self,
        model: str = "gpt-3.5-turbo",
        base_url: str = "http://localhost:{}/v1/".format(os.environ.get("API_PORT", 8000)),
        api_key: str = "{}".format(os.environ.get("API_KEY", "0")),
        num_per_second: int = 6,
        **kwargs,
    ):
        self.model = model
        self.base_url = base_url
        self.api_key = api_key
        self.num_per_second = num_per_second

        # Rate limiter: allow at most num_per_second requests per second
        self.limiter = AsyncLimiter(self.num_per_second, 1)

        self.llm = ChatOpenAI(
            model=self.model, base_url=self.base_url, api_key=self.api_key, **kwargs
        )

    async def __call__(self, text):
        # Throttle: wait for limiter capacity before sending the request
        async with self.limiter:
            return await self.llm.ainvoke([text])


llm = AsyncLLM(
    base_url="http://localhost:{}/v1/".format(os.environ.get("API_PORT", 8000)),
    api_key="{}".format(os.environ.get("API_KEY", "0")),
    num_per_second=10,
)
llms = [llm]


@dataclass
class AsyncAPICall:
    uid: str = "0"

    @staticmethod
    async def _run_task_with_progress(task, pbar):
        result = await task
        pbar.update(1)
        return result

    @staticmethod
    def async_run(
        llms: List[AsyncLLM],
        data: List[str],
        keyword: str = "",
        output_dir: str = "output",
        chunk_size=500,
    ) -> List[str]:
        async def infer_chunk(llms: List[AsyncLLM], data: List):
            """Run inference chunk by chunk, so that a crash while processing
            a large dataset does not lose the results already inferred."""
            # Round-robin the prompts over the available endpoints
            results = [llms[i % len(llms)](text) for i, text in enumerate(data)]

            with tqdm(total=len(results)) as pbar:
                results = await asyncio.gather(
                    *[
                        AsyncAPICall._run_task_with_progress(task, pbar)
                        for task in results
                    ]
                )
            return results

        idx = 0
        all_df = []
        file_exist_skip = False
        user_confirm = False

        while idx < len(data):
            file_path = os.path.join(output_dir, "tmp", f"{idx}.csv.temp")

            if os.path.exists(file_path):
                if not user_confirm:
                    while True:
                        user_response = input(
                            f"Found existing file {file_path}. Skip all existing chunk files?\ny or Y to skip, n or N to rerun and overwrite: "
                        )
                        if user_response.lower() == "y":
                            user_confirm = True
                            file_exist_skip = True
                            break
                        elif user_response.lower() == "n":
                            user_confirm = True
                            file_exist_skip = False
                            break

                if file_exist_skip:
                    tmp_df = pd.read_csv(file_path)
                    all_df.append(tmp_df)
                    idx += chunk_size
                    continue

            tmp_data = data[idx : idx + chunk_size]
            # One event loop is reused across all chunks; run_until_complete
            # drives each chunk of requests to completion
            loop = asyncio.get_event_loop()
            tmp_result = loop.run_until_complete(infer_chunk(llms=llms, data=tmp_data))
            tmp_result = [item.content for item in tmp_result]

            tmp_df = pd.DataFrame({"infer": tmp_result})

            if not os.path.exists(p := os.path.dirname(file_path)):
                os.makedirs(p, exist_ok=True)

            # Checkpoint this chunk to disk before moving on
            tmp_df.to_csv(file_path, index=False)
            all_df.append(tmp_df)
            idx += chunk_size

        all_df = pd.concat(all_df)
        return all_df["infer"]


def async_api_infer(
    model_name_or_path: str = "",
    eval_dataset: str = "",
    template: str = "",
    dataset_dir: str = "data",
    do_predict: bool = True,
    predict_with_generate: bool = True,
    max_samples: int = None,
    output_dir: str = "output",
    chunk_size=50,
):
    if len(sys.argv) == 1:
        model_args, data_args, training_args, finetuning_args, generating_args = (
            get_train_args(
                dict(
                    model_name_or_path=model_name_or_path,
                    dataset_dir=dataset_dir,
                    eval_dataset=eval_dataset,
                    template=template,
                    output_dir=output_dir,
                    do_predict=True,
                    predict_with_generate=True,
                    max_samples=max_samples,
                )
            )
        )
    else:
        model_args, data_args, training_args, finetuning_args, generating_args = (
            get_train_args()
        )

    dataset = _get_merged_dataset(
        data_args.eval_dataset, model_args, data_args, training_args, "sft"
    )

    labels = [item[0]["content"] for item in dataset["_response"]]
    prompts = [item[0]["content"] for item in dataset["_prompt"]]

    infers = AsyncAPICall.async_run(
        llms,
        prompts,
        chunk_size=chunk_size,
        output_dir=training_args.output_dir,
    )

    if not os.path.exists(training_args.output_dir):
        os.makedirs(training_args.output_dir, exist_ok=True)

    output_prediction_file = os.path.join(
        training_args.output_dir, "generated_predictions.jsonl"
    )

    with open(output_prediction_file, "w", encoding="utf-8") as writer:
        res: List[str] = []
        for text, pred, label in zip(prompts, infers, labels):
            res.append(
                json.dumps(
                    {"prompt": text, "predict": pred, "label": label},
                    ensure_ascii=False,
                )
            )
        writer.write("\n".join(res))


if __name__ == "__main__":
    fire.Fire(async_api_infer)
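
Usage note: because the script hands async_api_infer to fire.Fire, each keyword argument doubles as a command-line flag. A typical invocation might look like the following sketch, where the model, dataset, and template values are placeholders and the LLaMA-Factory OpenAI-style API server is assumed to already be listening on API_PORT:

API_PORT=8000 python scripts/async_call_api.py --model_name_or_path <path_to_model> --eval_dataset <dataset_name> --template <template_name> --output_dir output --chunk_size 50

With no flags at all (len(sys.argv) == 1), the arguments are built from the function defaults; otherwise get_train_args() parses the command line itself.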
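The pacing pattern at the heart of AsyncLLM, a shared AsyncLimiter that every coroutine acquires before it fires while asyncio.gather drives them concurrently, can be exercised in isolation. A minimal sketch (illustration only; the 6-per-second rate and the sleep standing in for an API call are assumptions, not part of the script):

import asyncio

from aiolimiter import AsyncLimiter


async def fake_request(i: int, limiter: AsyncLimiter) -> int:
    # Wait for limiter capacity before "sending", so at most
    # 6 requests are started in any 1-second window.
    async with limiter:
        await asyncio.sleep(0.1)  # stand-in for a real API call
        return i


async def main():
    limiter = AsyncLimiter(6, 1)  # max_rate=6, time_period=1 second
    return await asyncio.gather(*(fake_request(i, limiter) for i in range(20)))


print(asyncio.run(main()))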