mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-12-16 11:50:35 +08:00)

commit: create chat model
@@ -2,46 +2,11 @@
 # Implements stream chat in command line for fine-tuned models.
 # Usage: python cli_demo.py --model_name_or_path path_to_model --checkpoint_dir path_to_checkpoint
 
-from threading import Thread
-from transformers import TextIteratorStreamer
-
-from llmtuner import Template, get_infer_args, load_model_and_tokenizer, get_logits_processor
+from llmtuner import ChatModel, get_infer_args
 
 
 def main():
-    model_args, data_args, finetuning_args, generating_args = get_infer_args()
-    model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
-
-    prompt_template = Template(data_args.prompt_template)
-    source_prefix = data_args.source_prefix if data_args.source_prefix else ""
-
-    def predict_and_print(query, history: list) -> list:
-        input_ids = tokenizer([prompt_template.get_prompt(query, history, source_prefix)], return_tensors="pt")["input_ids"]
-        input_ids = input_ids.to(model.device)
-
-        streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
-
-        gen_kwargs = generating_args.to_dict()
-        gen_kwargs.update({
-            "input_ids": input_ids,
-            "logits_processor": get_logits_processor(),
-            "streamer": streamer
-        })
-
-        thread = Thread(target=model.generate, kwargs=gen_kwargs)
-        thread.start()
-
-        print("Assistant: ", end="", flush=True)
-
-        response = ""
-        for new_text in streamer:
-            print(new_text, end="", flush=True)
-            response += new_text
-        print()
-
-        history = history + [(query, response)]
-        return history
-
+    chat_model = ChatModel(*get_infer_args())
     history = []
     print("Welcome to the CLI application, use `clear` to remove the history, use `exit` to exit the application.")
 
@@ -62,7 +27,15 @@ def main():
             print("History has been removed.")
             continue
 
-        history = predict_and_print(query, history)
+        print("Assistant: ", end="", flush=True)
+
+        response = ""
+        for new_text in chat_model.stream_chat(query, history):
+            print(new_text, end="", flush=True)
+            response += new_text
+        print()
+
+        history = history + [(query, response)]
 
 
 if __name__ == "__main__":
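
Taken together, the two hunks reduce cli_demo.py to a thin loop around the new ChatModel class. Below is a condensed sketch of the resulting script; the while-loop scaffolding (the input prompt and the `exit`/`clear` commands) is inferred from the welcome message and the hunk context, not shown verbatim in this diff.

    # Condensed sketch of cli_demo.py after this commit. The input loop
    # is inferred from the hunk context, not copied from the diff.
    from llmtuner import ChatModel, get_infer_args


    def main():
        # get_infer_args() parses the inference arguments; ChatModel
        # bundles model loading, prompt templating, and token streaming.
        chat_model = ChatModel(*get_infer_args())
        history = []
        print("Welcome to the CLI application, use `clear` to remove the history, use `exit` to exit the application.")

        while True:
            query = input("\nUser: ")
            if query.strip() == "exit":
                break
            if query.strip() == "clear":
                history = []
                print("History has been removed.")
                continue

            print("Assistant: ", end="", flush=True)
            response = ""
            # stream_chat yields decoded text chunks as they are generated
            for new_text in chat_model.stream_chat(query, history):
                print(new_text, end="", flush=True)
                response += new_text
            print()
            history = history + [(query, response)]


    if __name__ == "__main__":
        main()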
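For reference, the streaming machinery deleted in the first hunk follows the standard transformers pattern that ChatModel.stream_chat presumably now encapsulates: generate() blocks until decoding finishes, so it runs on a worker thread while the main thread iterates over a TextIteratorStreamer. A self-contained sketch of just that pattern, using gpt2 as an arbitrary stand-in for the fine-tuned model:

    from threading import Thread

    from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

    # gpt2 is a stand-in; the removed code used the model returned by
    # load_model_and_tokenizer() and a prompt built by Template.
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2")

    input_ids = tokenizer(["The quick brown fox"], return_tensors="pt")["input_ids"]
    input_ids = input_ids.to(model.device)

    # skip_prompt=True drops the echoed prompt from the stream;
    # skip_special_tokens is forwarded to tokenizer.decode().
    streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)

    # generate() runs in the background; the streamer is an iterator
    # that yields decoded text chunks as tokens arrive.
    thread = Thread(target=model.generate, kwargs={
        "input_ids": input_ids,
        "max_new_tokens": 32,
        "streamer": streamer,
    })
    thread.start()

    response = ""
    for new_text in streamer:
        print(new_text, end="", flush=True)
        response += new_text
    print()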