mirror of https://github.com/hiyouga/LLaMA-Factory.git

commit 509e35ffc8 (parent 9898712a24)
fix #2282 and update tool prompt

Former-commit-id: b2fb0eca56ee835438cc20e83f36ac3f6eb95c83
@@ -1 +1 @@
-fc9a6a3458caca2af8dafc6181773fe10c6d8657
+34c723573fbc2d7601f6d9c882ccf5aa4f9bcc4b
@@ -15,7 +15,7 @@ def convert_alpaca(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr")
     outputs = {"prompt": [], "response": [], "system": [], "tools": []}
     for i in range(len(examples[dataset_attr.prompt])):
         prompt = []
-        if dataset_attr.history:
+        if dataset_attr.history and isinstance(examples[dataset_attr.history][i], list):
             for old_prompt, old_response in examples[dataset_attr.history][i]:
                 prompt.append({"role": Role.USER, "content": old_prompt})
                 prompt.append({"role": Role.ASSISTANT, "content": old_response})
@@ -25,13 +25,10 @@ def convert_alpaca(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr")
             instruction += "\n" + examples[dataset_attr.query][i]
         prompt.append({"role": Role.USER, "content": instruction})

-        if dataset_attr.response:
-            if isinstance(examples[dataset_attr.response][i], list):
-                response = [
-                    {"role": Role.ASSISTANT, "content": content} for content in examples[dataset_attr.response][i]
-                ]
-            else:
-                response = [{"role": Role.ASSISTANT, "content": examples[dataset_attr.response][i]}]
+        if dataset_attr.response and isinstance(examples[dataset_attr.response][i], list):
+            response = [{"role": Role.ASSISTANT, "content": content} for content in examples[dataset_attr.response][i]]
+        elif dataset_attr.response and isinstance(examples[dataset_attr.response][i], str):
+            response = [{"role": Role.ASSISTANT, "content": examples[dataset_attr.response][i]}]
         else:
             response = []

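The net effect of the two new `isinstance` guards: rows whose history or response field carries an unexpected type (say, a `None` value in a hand-written JSON file) now degrade to an empty response, which the preprocessing stage later filters out, instead of crashing the `datasets` map worker. A minimal, self-contained sketch of the response dispatch, with plain role strings standing in for the project's `Role` enum:

```python
from typing import Any, Dict, List

# Minimal sketch of the response dispatch above; plain role strings stand in
# for llmtuner's Role enum. A list becomes one assistant message per item,
# a string becomes a single message, and anything else (e.g. None from a
# malformed row) yields an empty response that is filtered out downstream.
def build_response(raw: Any) -> List[Dict[str, str]]:
    if isinstance(raw, list):
        return [{"role": "assistant", "content": content} for content in raw]
    elif isinstance(raw, str):
        return [{"role": "assistant", "content": raw}]
    else:
        return []

print(build_response(["good answer", "bad answer"]))  # two turns, e.g. for ranking data
print(build_response("the only answer"))              # single assistant turn
print(build_response(None))                           # [] -> row skipped later
```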
@ -15,10 +15,10 @@ JSON_FORMAT_PROMPT = (
|
|||||||
|
|
||||||
TOOL_SYSTEM_PROMPT = (
|
TOOL_SYSTEM_PROMPT = (
|
||||||
"You have access to the following tools:\n{tool_text}"
|
"You have access to the following tools:\n{tool_text}"
|
||||||
"Use the following format to answer the question:\n"
|
"Use the following format if using a tool:\n"
|
||||||
"```\n"
|
"```\n"
|
||||||
"Action: the action to take, should be one of [{tool_names}] if using a tool.\n"
|
"Action: tool name (one of [{tool_names}]).\n"
|
||||||
"Action Input: the input to the action{format_prompt}.\n"
|
"Action Input: the input to the tool{format_prompt}.\n"
|
||||||
"```\n"
|
"```\n"
|
||||||
)
|
)
|
||||||
|
|
||||||
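To see what the reworded template produces, here is a sketch that fills the placeholders with invented values; the real `tool_text`, `tool_names`, and `format_prompt` are assembled from the registered tool schemas:

```python
TOOL_SYSTEM_PROMPT = (
    "You have access to the following tools:\n{tool_text}"
    "Use the following format if using a tool:\n"
    "```\n"
    "Action: tool name (one of [{tool_names}]).\n"
    "Action Input: the input to the tool{format_prompt}.\n"
    "```\n"
)

# Illustrative fill-ins only; the real values come from the tool registry.
print(TOOL_SYSTEM_PROMPT.format(
    tool_text="get_weather: Look up current weather for a city.\n",
    tool_names="get_weather",
    format_prompt=", in JSON format",
))
```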
@@ -95,12 +95,15 @@ class StringFormatter(Formatter):
         for slot in self.slots:
             if isinstance(slot, str):
                 for name, value in kwargs.items():
+                    if not isinstance(value, str):
+                        raise RuntimeError("Expected a string, got {}".format(value))
+
                     slot = slot.replace("{{" + name + "}}", value, 1)
                 elements.append(slot)
             elif isinstance(slot, (dict, set)):
                 elements.append(slot)
             else:
-                raise ValueError("Input must be string, set[str] or dict[str, str], got {}".format(type(slot)))
+                raise RuntimeError("Input must be string, set[str] or dict[str, str], got {}".format(type(slot)))

         return elements

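A self-contained sketch of what the stricter formatter buys: a non-string value now fails loudly at formatting time instead of silently corrupting the prompt. The class below is a simplified, hypothetical stand-in for `StringFormatter`, keeping only the slot-filling logic shown in the hunk:

```python
class StringFormatterSketch:
    def __init__(self, slots):
        self.slots = slots

    def __call__(self, **kwargs):
        elements = []
        for slot in self.slots:
            if isinstance(slot, str):
                for name, value in kwargs.items():
                    if not isinstance(value, str):
                        raise RuntimeError("Expected a string, got {}".format(value))
                    slot = slot.replace("{{" + name + "}}", value, 1)
                elements.append(slot)
            elif isinstance(slot, (dict, set)):
                elements.append(slot)  # special-token slots pass through untouched
            else:
                raise RuntimeError("Input must be string, set[str] or dict[str, str], got {}".format(type(slot)))
        return elements

fmt = StringFormatterSketch(["<|user|>\n{{content}}"])
print(fmt(content="Hello"))          # ['<|user|>\nHello']
try:
    fmt(content=["oops"])            # non-string -> RuntimeError, not a garbled prompt
except RuntimeError as err:
    print(err)
```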
@@ -124,7 +127,7 @@ class FunctionFormatter(Formatter):
             elif isinstance(slot, (dict, set)):
                 elements.append(slot)
             else:
-                raise ValueError("Input must be string, set[str] or dict[str, str], got {}".format(type(slot)))
+                raise RuntimeError("Input must be string, set[str] or dict[str, str], got {}".format(type(slot)))

         return elements

@@ -4,6 +4,7 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Tuple

 from ..extras.constants import IGNORE_INDEX
 from ..extras.logging import get_logger
+from .utils import Role


 if TYPE_CHECKING:
@@ -51,7 +52,7 @@ def preprocess_supervised_dataset(
     model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}

     for i in range(len(examples["prompt"])):
-        if len(examples["prompt"][i]) == 0 or len(examples["response"][i]) != 1:
+        if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) != 1:
             continue

         messages = examples["prompt"][i] + examples["response"][i]
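The switch from `== 0` to `% 2 != 1` tightens the validity check: after the aligner flattens each example into alternating user/assistant turns plus a final user query, a well-formed prompt always has odd length, so the parity test rejects both empty prompts (the old check) and histories with a dangling turn. A tiny sketch of the invariant:

```python
# A well-formed prompt is [user, assistant, user, assistant, ..., user]:
# pairs of history turns plus the final query, hence an odd length.
def is_valid_prompt(prompt):
    return len(prompt) % 2 == 1

print(is_valid_prompt(["q1"]))              # True: single query
print(is_valid_prompt(["q1", "a1", "q2"]))  # True: one history pair + query
print(is_valid_prompt([]))                  # False: empty (old check caught this)
print(is_valid_prompt(["q1", "a1"]))        # False: dangling turn (new check catches this)
```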
@@ -93,7 +94,7 @@ def preprocess_packed_supervised_dataset(
     model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
     input_ids, labels = [], []
     for i in range(len(examples["prompt"])):
-        if len(examples["prompt"][i]) == 0 or len(examples["response"][i]) != 1:
+        if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) != 1:
             continue

         messages = examples["prompt"][i] + examples["response"][i]
@@ -137,10 +138,14 @@ def preprocess_unsupervised_dataset(
     model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}

     for i in range(len(examples["prompt"])):
-        if len(examples["prompt"][i]) == 0 or len(examples["response"][i]) != 1:
+        if len(examples["prompt"][i]) % 2 != 1:
             continue

-        messages = examples["prompt"][i] + examples["response"][i]
+        if len(examples["response"][i]) == 1:
+            messages = examples["prompt"][i] + examples["response"][i]
+        else:
+            messages = examples["prompt"][i] + [{"role": Role.ASSISTANT, "content": ""}]
+
         input_ids, labels = template.encode_oneturn(
             tokenizer, messages, examples["system"][i], examples["tools"][i], data_args.cutoff_len
         )
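For the unsupervised path the response is optional, so rather than dropping response-less rows the code now pads with an empty assistant turn, keeping `encode_oneturn` supplied with a complete user/assistant pair. A hedged sketch of the selection logic, with a plain string standing in for `Role.ASSISTANT`:

```python
ASSISTANT = "assistant"  # stands in for Role.ASSISTANT

def pick_messages(prompt, response):
    # Keep the real response when there is exactly one; otherwise pad
    # with an empty assistant turn so the template still encodes one turn.
    if len(response) == 1:
        return prompt + response
    return prompt + [{"role": ASSISTANT, "content": ""}]

print(pick_messages([{"role": "user", "content": "Hi"}], []))
# [{'role': 'user', 'content': 'Hi'}, {'role': 'assistant', 'content': ''}]
```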
@@ -164,7 +169,7 @@ def preprocess_pairwise_dataset(
     # build input pairs with format `<bos> X`, `Y1 <eos>` and `Y2 <eos>`
     model_inputs = {"prompt_ids": [], "chosen_ids": [], "rejected_ids": []}
     for i in range(len(examples["prompt"])):
-        if len(examples["prompt"][i]) == 0 or len(examples["response"][i]) < 2:
+        if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) < 2:
             continue

         chosen_messages = examples["prompt"][i] + [examples["response"][i][0]]
@@ -60,7 +60,7 @@ def _verify_model_args(model_args: "ModelArguments", finetuning_args: "Finetunin
         if finetuning_args.finetuning_type != "lora":
             raise ValueError("Quantization is only compatible with the LoRA method.")

-        if finetuning_args.create_new_adapter:
+        if model_args.adapter_name_or_path is not None and finetuning_args.create_new_adapter:
             raise ValueError("Cannot create new adapter upon a quantized model.")

     if model_args.adapter_name_or_path is not None and len(model_args.adapter_name_or_path) != 1:
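The added `adapter_name_or_path is not None` condition appears to narrow the quantized-model guard: requesting a fresh adapter on a quantized base is no longer an error by itself; the error fires only when an existing adapter is also being loaded. A reduced, self-contained sketch with stand-in argument classes (the real ones are `ModelArguments` and `FinetuningArguments` in `llmtuner.hparams`):

```python
from dataclasses import dataclass
from typing import List, Optional

# Stand-in argument containers for illustration only.
@dataclass
class ModelArgs:
    quantization_bit: Optional[int] = None
    adapter_name_or_path: Optional[List[str]] = None

@dataclass
class FinetuneArgs:
    finetuning_type: str = "lora"
    create_new_adapter: bool = False

def verify(model_args: ModelArgs, finetuning_args: FinetuneArgs) -> None:
    if model_args.quantization_bit is not None:
        if finetuning_args.finetuning_type != "lora":
            raise ValueError("Quantization is only compatible with the LoRA method.")
        # Only an error when an existing adapter is loaded AND a new one is requested.
        if model_args.adapter_name_or_path is not None and finetuning_args.create_new_adapter:
            raise ValueError("Cannot create new adapter upon a quantized model.")

# Previously rejected, now allowed: fresh QLoRA adapter with no adapter loaded.
verify(ModelArgs(quantization_bit=4), FinetuneArgs(create_new_adapter=True))
print("ok")
```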