LLaMA-Factory/src/cli_demo.py
hiyouga 8875f565ad add prompt template class
Former-commit-id: 3d7e3a38d00aa5d9664824093043951af8c3f707
2023-06-07 11:55:25 +08:00

74 lines
2.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# coding=utf-8
# Implements streaming chat on the command line for fine-tuned models.
# Usage: python cli_demo.py --model_name_or_path path_to_model --checkpoint_dir path_to_checkpoint
from utils import (
Template,
load_pretrained,
prepare_infer_args,
get_logits_processor
)
from threading import Thread
from transformers import TextIteratorStreamer
def main():
    """Run an interactive, streaming chat session in the terminal.

    Inference arguments are obtained from ``prepare_infer_args()``, the model
    and tokenizer from ``load_pretrained()``, and prompts are built with the
    ``Template`` selected by ``data_args.prompt_template``.

    The loop reads one line per turn:

    * ``stop``  -- leave the loop and return.
    * ``clear`` -- drop the accumulated conversation history.
    * blank     -- ignored (nothing is sent to the model).
    * otherwise -- the text is sent to the model and the reply is streamed
      to stdout as it is generated.

    End-of-input (Ctrl-D, or a piped stdin running dry) ends the session
    cleanly instead of surfacing an ``EOFError`` traceback.
    """
    model_args, data_args, finetuning_args = prepare_infer_args()
    # Match case-insensitively so paths such as ".../BLOOM-7b1" are labelled correctly.
    model_name = "BLOOM" if "bloom" in model_args.model_name_or_path.lower() else "LLaMA"
    model, tokenizer = load_pretrained(model_args, finetuning_args)
    prompt_template = Template(data_args.prompt_template)
    # One streamer is shared by every turn; it is created with a 60 second timeout.
    streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)

    def predict_and_print(query, history: list):
        """Generate a reply to ``query``, echo it while streaming, and return the extended history.

        ``history`` is a list of ``(query, response)`` tuples. It is not
        mutated; a new list with the latest exchange appended is returned.
        """
        prompt = prompt_template.get_prompt(query, history)
        input_ids = tokenizer([prompt], return_tensors="pt")["input_ids"]
        input_ids = input_ids.to(model.device)

        gen_kwargs = {
            "input_ids": input_ids,
            "do_sample": True,
            "top_p": 0.7,
            "temperature": 0.95,
            "num_beams": 1,
            "max_new_tokens": 512,
            "repetition_penalty": 1.0,
            "logits_processor": get_logits_processor(),
            "streamer": streamer,
        }

        # Generation runs in a worker thread so this thread can print text
        # chunks as soon as the streamer yields them.
        thread = Thread(target=model.generate, kwargs=gen_kwargs)
        thread.start()

        print("{}: ".format(model_name), end="")
        chunks = []
        for new_text in streamer:
            print(new_text, end="", flush=True)
            chunks.append(new_text)
        print()

        # The stream has ended normally, so the worker is finishing up; wait
        # for it so generation threads never overlap across turns. The join is
        # deliberately not in a ``finally``: if the streamer loop raised, the
        # worker may be stuck and joining could block indefinitely.
        thread.join()

        return history + [(query, "".join(chunks))]

    history = []
    # Banner, roughly: "Welcome to the {} model; type to chat, 'clear' resets
    # the history, 'stop' exits."
    # NOTE(review): the banner text looks like it lost its punctuation somewhere
    # upstream of this copy -- confirm against the original file before editing.
    print("欢迎使用 {} 模型输入内容即可对话clear清空对话历史stop终止程序".format(model_name))

    while True:
        try:
            query = input("\nInput: ")
        except UnicodeDecodeError:
            print("Detected decoding error at the inputs, please set the terminal encoding to utf-8.")
            continue
        except EOFError:
            # stdin was closed: finish the prompt line and exit quietly.
            print()
            break

        command = query.strip()
        if not command:
            # A blank line carries no request; do not spend a generation on it
            # or record an empty turn in the history.
            continue
        if command == "stop":
            break
        if command == "clear":
            history = []
            print("History has been removed.")
            continue

        history = predict_and_print(query, history)
# Start the chat loop only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()