fix checkpoint loading

Former-commit-id: d31aa5c2c0bcb6a4ef4a62e21693548dd9acaae6
This commit is contained in:
hiyouga
2023-05-29 17:43:16 +08:00
parent 35d04a2c05
commit 304be6dc28
4 changed files with 56 additions and 23 deletions

View File

@@ -21,8 +21,14 @@ def main():
model = model.cuda()
model.eval()
def format_example(query):
    """Wrap *query* in the fixed instruction prompt template.

    Args:
        query: The user's input; it is substituted into the ``Human:``
            slot of the template via ``str.format``.

    Returns:
        The full prompt string: a two-sentence preamble, an
        ``Instruction:`` header, the ``Human: <query>`` line, and a
        trailing ``Assistant: `` cue (with trailing space) left open for
        the generated reply.
    """
    # Adjacent literals are joined at compile time. Only the last segment
    # contains a replacement field, so formatting the whole template is
    # equivalent to formatting that segment alone.
    template = (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request.\n"
        "Instruction:\nHuman: {}\nAssistant: "
    )
    return template.format(query)
def predict(query, history: list):
inputs = tokenizer([query], return_tensors="pt")
inputs = tokenizer([format_example(query)], return_tensors="pt")
inputs = inputs.to(model.device)
gen_kwargs = {
"do_sample": True,