add prompt template class
@@ -21,11 +21,10 @@ import datetime
 from fastapi import FastAPI, Request
 
 from utils import (
+    Template,
     load_pretrained,
     prepare_infer_args,
-    get_logits_processor,
-    prompt_template_alpaca,
-    prompt_template_ziya
+    get_logits_processor
 )
 
 
@@ -43,7 +42,7 @@ app = FastAPI()
 
 @app.post("/")
 async def create_item(request: Request):
-    global model, tokenizer, format_example
+    global model, tokenizer, prompt_template
 
     # Parse the request JSON
     json_post_raw = await request.json()
@@ -53,7 +52,7 @@ async def create_item(request: Request):
     history = json_post_list.get("history")
 
     # Tokenize the input prompt
-    input_ids = tokenizer([format_example(prompt, history)], return_tensors="pt")["input_ids"]
+    input_ids = tokenizer([prompt_template.get_prompt(prompt, history)], return_tensors="pt")["input_ids"]
     input_ids = input_ids.to(model.device)
 
     # Generation arguments
@@ -98,6 +97,6 @@ async def create_item(request: Request):
 if __name__ == "__main__":
     model_args, data_args, finetuning_args = prepare_infer_args()
     model, tokenizer = load_pretrained(model_args, finetuning_args)
-    format_example = prompt_template_alpaca if data_args.prompt_template == "alpaca" else prompt_template_ziya
+    prompt_template = Template(data_args.prompt_template)
 
     uvicorn.run(app, host='0.0.0.0', port=8000, workers=1)
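
Note on the new interface: from the call sites above, `Template` is constructed from a template name (`data_args.prompt_template`) and exposes `get_prompt(query, history)`, replacing the per-format helper functions `prompt_template_alpaca` and `prompt_template_ziya`. A minimal sketch of such a class, assuming it covers the same "alpaca"/"ziya" formats as the removed helpers; the field names and separators below are illustrative, not the actual `utils.Template` implementation:

from dataclasses import dataclass
from typing import List, Optional, Tuple


@dataclass
class Template:
    # Prompt-format name, e.g. "alpaca" or "ziya" (hypothetical sketch only;
    # the real utils.Template may differ in fields and wording).
    name: str

    def get_prompt(self, query: str, history: Optional[List[Tuple[str, str]]] = None) -> str:
        # Render prior (query, response) turns plus the current query
        # in the chosen format.
        turns = history or []
        if self.name == "alpaca":
            prompt = ""
            for old_query, response in turns:
                prompt += "### Instruction:\n{}\n\n### Response:\n{}\n\n".format(old_query, response)
            return prompt + "### Instruction:\n{}\n\n### Response:\n".format(query)
        if self.name == "ziya":
            prompt = ""
            for old_query, response in turns:
                prompt += "<human>:{}\n<bot>:{}\n".format(old_query, response)
            return prompt + "<human>:{}\n<bot>:".format(query)
        raise ValueError("unknown prompt template: " + self.name)

With a class like this, `Template("alpaca").get_prompt("Hi", [])` yields the formatted string directly, and adding a new prompt format means touching one class instead of defining another module-level helper and extending the if/else at startup.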
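
To exercise the endpoint once the server is running: the diff shows the app listening on 0.0.0.0:8000 with a single POST route at "/", reading `history` from the request JSON (and, judging by the variable name, presumably a `prompt` field parsed outside the hunks shown). A quick client sketch; the `prompt` key and the response shape are assumptions, not confirmed by this diff:

import requests

# Hypothetical client for the API demo above; the "prompt" key and the
# response format are inferred from variable names in the diff.
payload = {
    "prompt": "Hello, who are you?",
    "history": [],  # prior (query, response) turns, per json_post_list.get("history")
}
resp = requests.post("http://127.0.0.1:8000/", json=payload)
print(resp.json())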