add ziya prompt template

This commit is contained in:
hiyouga
2023-06-03 19:05:51 +08:00
parent 771f454ff1
commit de09ee1315
6 changed files with 79 additions and 24 deletions

View File

@@ -4,14 +4,16 @@
import torch
from utils import ModelArguments, FinetuningArguments, load_pretrained, get_logits_processor
from transformers import HfArgumentParser
from utils import (
load_pretrained,
prepare_infer_args,
get_logits_processor
)
def main():
parser = HfArgumentParser((ModelArguments, FinetuningArguments))
model_args, finetuning_args = parser.parse_args_into_dataclasses()
model_args, data_args, finetuning_args = prepare_infer_args()
model_name = "BLOOM" if "bloom" in model_args.model_name_or_path else "LLaMA"
model, tokenizer = load_pretrained(model_args, finetuning_args)
@@ -24,14 +26,26 @@ def main():
model.eval()
def format_example(query):
def format_example_alpaca(query, history):
prompt = "Below is an instruction that describes a task. "
prompt += "Write a response that appropriately completes the request.\n"
prompt += "Instruction:\nHuman: {}\nAssistant: ".format(query)
prompt += "Instruction:\n"
for old_query, response in history:
prompt += "Human: {}\nAssistant: {}\n".format(old_query, response)
prompt += "Human: {}\nAssistant:".format(query)
return prompt
def format_example_ziya(query, history):
prompt = ""
for old_query, response in history:
prompt += "<human>: {}\n<bot>: {}\n".format(old_query, response)
prompt += "<human>: {}\n<bot>:".format(query)
return prompt
format_example = format_example_alpaca if data_args.prompt_template == "alpaca" else format_example_ziya
def predict(query, history: list):
input_ids = tokenizer([format_example(query)], return_tensors="pt")["input_ids"]
input_ids = tokenizer([format_example(query, history)], return_tensors="pt")["input_ids"]
input_ids = input_ids.to(model.device)
gen_kwargs = {
"do_sample": True,
@@ -65,6 +79,7 @@ def main():
if query.strip() == "clear":
history = []
print("History has been removed.")
continue
response, history = predict(query, history)