mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-12-16 03:40:34 +08:00
add ziya prompt template
This commit is contained in:
@@ -4,14 +4,16 @@
|
||||
|
||||
|
||||
import torch
|
||||
from utils import ModelArguments, FinetuningArguments, load_pretrained, get_logits_processor
|
||||
from transformers import HfArgumentParser
|
||||
from utils import (
|
||||
load_pretrained,
|
||||
prepare_infer_args,
|
||||
get_logits_processor
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
parser = HfArgumentParser((ModelArguments, FinetuningArguments))
|
||||
model_args, finetuning_args = parser.parse_args_into_dataclasses()
|
||||
model_args, data_args, finetuning_args = prepare_infer_args()
|
||||
model_name = "BLOOM" if "bloom" in model_args.model_name_or_path else "LLaMA"
|
||||
model, tokenizer = load_pretrained(model_args, finetuning_args)
|
||||
|
||||
@@ -24,14 +26,26 @@ def main():
|
||||
|
||||
model.eval()
|
||||
|
||||
def format_example(query):
|
||||
def format_example_alpaca(query, history):
|
||||
prompt = "Below is an instruction that describes a task. "
|
||||
prompt += "Write a response that appropriately completes the request.\n"
|
||||
prompt += "Instruction:\nHuman: {}\nAssistant: ".format(query)
|
||||
prompt += "Instruction:\n"
|
||||
for old_query, response in history:
|
||||
prompt += "Human: {}\nAssistant: {}\n".format(old_query, response)
|
||||
prompt += "Human: {}\nAssistant:".format(query)
|
||||
return prompt
|
||||
|
||||
def format_example_ziya(query, history):
|
||||
prompt = ""
|
||||
for old_query, response in history:
|
||||
prompt += "<human>: {}\n<bot>: {}\n".format(old_query, response)
|
||||
prompt += "<human>: {}\n<bot>:".format(query)
|
||||
return prompt
|
||||
|
||||
format_example = format_example_alpaca if data_args.prompt_template == "alpaca" else format_example_ziya
|
||||
|
||||
def predict(query, history: list):
|
||||
input_ids = tokenizer([format_example(query)], return_tensors="pt")["input_ids"]
|
||||
input_ids = tokenizer([format_example(query, history)], return_tensors="pt")["input_ids"]
|
||||
input_ids = input_ids.to(model.device)
|
||||
gen_kwargs = {
|
||||
"do_sample": True,
|
||||
@@ -65,6 +79,7 @@ def main():
|
||||
|
||||
if query.strip() == "clear":
|
||||
history = []
|
||||
print("History has been removed.")
|
||||
continue
|
||||
|
||||
response, history = predict(query, history)
|
||||
|
||||
Reference in New Issue
Block a user