add prompt template class

This commit is contained in:
hiyouga
2023-06-07 11:55:25 +08:00
parent 5d021d4ad5
commit 909af8f496
8 changed files with 67 additions and 40 deletions

View File

@@ -21,11 +21,10 @@ import datetime
from fastapi import FastAPI, Request
from utils import (
Template,
load_pretrained,
prepare_infer_args,
get_logits_processor,
prompt_template_alpaca,
prompt_template_ziya
get_logits_processor
)
@@ -43,7 +42,7 @@ app = FastAPI()
@app.post("/")
async def create_item(request: Request):
global model, tokenizer, format_example
global model, tokenizer, prompt_template
# Parse the request JSON
json_post_raw = await request.json()
@@ -53,7 +52,7 @@ async def create_item(request: Request):
history = json_post_list.get("history")
# Tokenize the input prompt
input_ids = tokenizer([format_example(prompt, history)], return_tensors="pt")["input_ids"]
input_ids = tokenizer([prompt_template.get_prompt(prompt, history)], return_tensors="pt")["input_ids"]
input_ids = input_ids.to(model.device)
# Generation arguments
@@ -98,6 +97,6 @@ async def create_item(request: Request):
if __name__ == "__main__":
model_args, data_args, finetuning_args = prepare_infer_args()
model, tokenizer = load_pretrained(model_args, finetuning_args)
format_example = prompt_template_alpaca if data_args.prompt_template == "alpaca" else prompt_template_ziya
prompt_template = Template(data_args.prompt_template)
uvicorn.run(app, host='0.0.0.0', port=8000, workers=1)