mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-12-16 11:50:35 +08:00
fix checkpoint loading
This commit is contained in:
@@ -21,8 +21,14 @@ def main():
|
||||
model = model.cuda()
|
||||
model.eval()
|
||||
|
||||
def format_example(query):
    """Wrap *query* in the Alpaca-style instruction prompt template.

    Returns the prompt string ending with "Assistant: " so the model's
    generation continues as the assistant's reply.
    """
    return (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request.\n"
        f"Instruction:\nHuman: {query}\nAssistant: "
    )
|
||||
|
||||
def predict(query, history: list):
|
||||
inputs = tokenizer([query], return_tensors="pt")
|
||||
inputs = tokenizer([format_example(query)], return_tensors="pt")
|
||||
inputs = inputs.to(model.device)
|
||||
gen_kwargs = {
|
||||
"do_sample": True,
|
||||
|
||||
Reference in New Issue
Block a user