fix checkpoint loading

This commit is contained in:
hiyouga
2023-05-29 17:43:16 +08:00
parent ce71cc8b6d
commit c0e5df92d6
4 changed files with 56 additions and 23 deletions

View File

@@ -9,8 +9,10 @@ import gradio as gr
from utils import ModelArguments, auto_configure_device_map, load_pretrained
from transformers import HfArgumentParser
from transformers.utils.versions import require_version
require_version("gradio==3.27.0", "To fix: pip install gradio==3.27.0") # higher version may cause problems
parser = HfArgumentParser(ModelArguments)
model_args, = parser.parse_args_into_dataclasses()
model, tokenizer = load_pretrained(model_args)
@@ -71,10 +73,17 @@ def parse_text(text): # copy from https://github.com/GaiZhenbiao/ChuanhuChatGPT
return text
def format_example(query):
prompt = "Below is an instruction that describes a task. "
prompt += "Write a response that appropriately completes the request.\n"
prompt += "Instruction:\nHuman: {}\nAssistant: ".format(query)
return prompt
def predict(input, chatbot, max_length, top_p, temperature, history):
chatbot.append((parse_text(input), ""))
inputs = tokenizer([input], return_tensors="pt")
inputs = tokenizer([format_example(input)], return_tensors="pt")
inputs = inputs.to(model.device)
gen_kwargs = {
"do_sample": True,