mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-09-18 11:02:49 +08:00
[data] bailing template v2 & openai data converter (#9112)
This commit is contained in:
parent
59f2bf1ea3
commit
aff6923fd1
@ -11,7 +11,7 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
import json
|
||||||
import os
|
import os
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
@ -227,9 +227,151 @@ class SharegptDatasetConverter(DatasetConverter):
|
|||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class OpenAIDatasetConverter(DatasetConverter):
    r"""Convert examples stored as OpenAI-style chat messages into the aligned internal format.

    Processing steps, all driven by the tag names configured on ``self.dataset_attr``:

    1. A leading message whose role equals ``system_tag`` is taken as the system prompt;
       otherwise the system prompt is read from the ``system`` column (if configured).
    2. Assistant/function messages carrying a non-empty ``tool_calls`` list are rewritten into
       a function turn whose content is the JSON dump of each call's ``function`` entry.
    3. Consecutive observation (tool result) messages are merged into one observation turn,
       joined with ``"\n</tool_response>\n<tool_response>\n"``.
    4. The aligned turns must alternate (user | observation) / (assistant | function) and have
       an even count (odd for ranking data); otherwise the example is emitted with empty
       prompt and response.
    5. The "detailed thinking off" marker is appended to the system prompt when no tools are
       given and the prompt does not already carry a "detailed thinking on/off" marker.
    """

    @staticmethod
    def _merge_tool_responses(tool_responses: list[str]) -> dict[str, str]:
        r"""Fold a run of consecutive tool outputs into a single observation turn."""
        return {
            "role": Role.OBSERVATION.value,
            "content": "\n</tool_response>\n<tool_response>\n".join(tool_responses),
        }

    def __call__(self, example: dict[str, Any]) -> dict[str, Any]:
        r"""Convert one raw example.

        Args:
            example: A dataset row holding the message list plus the optional system, tools,
                kto, chosen/rejected and media columns named by ``self.dataset_attr``.

        Returns:
            A dict with the keys ``_prompt``, ``_response``, ``_system``, ``_tools``,
            ``_images``, ``_videos`` and ``_audios``. ``_prompt`` and ``_response`` are empty
            lists when the example is considered abnormal.
        """
        attr = self.dataset_attr
        tag_mapping = {
            attr.user_tag: Role.USER.value,
            attr.assistant_tag: Role.ASSISTANT.value,
            attr.observation_tag: Role.OBSERVATION.value,
            attr.function_tag: Role.FUNCTION.value,
            attr.system_tag: Role.SYSTEM.value,
        }

        messages = example[attr.messages]
        if attr.system_tag and len(messages) != 0 and messages[0][attr.role_tag] == attr.system_tag:
            system = messages[0][attr.content_tag]
            messages = messages[1:]
        else:
            system = example.get(attr.system, "") if attr.system else ""

        aligned_messages = []
        tool_responses = []
        broken_data = False
        for message in messages:
            role = message[attr.role_tag]
            content = message[attr.content_tag]

            if role in (attr.assistant_tag, attr.function_tag):
                # ``tool_calls`` may be absent, empty, or present with a None value; only a
                # non-empty list turns the message into a function turn.
                tool_calls = message.get("tool_calls")
                if tool_calls:
                    content = json.dumps([tool["function"] for tool in tool_calls], ensure_ascii=False)
                    role = attr.function_tag

            if role == attr.observation_tag:
                # Buffer tool outputs so that parallel results become a single turn.
                tool_responses.append(content)
                continue

            if tool_responses:
                aligned_messages.append(self._merge_tool_responses(tool_responses))
                tool_responses = []

            if role not in tag_mapping:
                # Unknown role: take the warn-and-skip path instead of raising KeyError.
                logger.warning_rank0(f"Invalid role tag in {messages}.")
                broken_data = True
                break

            aligned_messages.append({"role": tag_mapping[role], "content": content})

        if tool_responses:
            # Tool outputs that end the conversation must not be dropped silently; for
            # pairwise data they are the legitimate last prompt turn.
            aligned_messages.append(self._merge_tool_responses(tool_responses))

        odd_tags = (Role.USER.value, Role.OBSERVATION.value)
        even_tags = (Role.ASSISTANT.value, Role.FUNCTION.value)
        accept_tags = (odd_tags, even_tags)
        for turn_idx, message in enumerate(aligned_messages):
            if message["role"] not in accept_tags[turn_idx % 2]:
                logger.warning_rank0(f"Invalid role tag in {messages}.")
                broken_data = True
                break

        if (not attr.ranking and len(aligned_messages) % 2 != 0) or (
            attr.ranking and len(aligned_messages) % 2 == 0
        ):
            logger.warning_rank0(f"Invalid message count in {messages}.")
            broken_data = True

        if broken_data:
            logger.warning_rank0("Skipping this abnormal example.")
            prompt, response = [], []
        elif attr.kto_tag and isinstance(example[attr.kto_tag], bool):  # kto example
            prompt = aligned_messages[:-1]
            response = aligned_messages[-1:]
            # The real answer sits first when the label is True, second otherwise;
            # the other slot is an empty assistant placeholder.
            if example[attr.kto_tag]:
                response = response + [{"role": Role.ASSISTANT.value, "content": ""}]
            else:
                response = [{"role": Role.ASSISTANT.value, "content": ""}] + response
        elif (
            attr.ranking
            and isinstance(example[attr.chosen], dict)
            and isinstance(example[attr.rejected], dict)
        ):  # pairwise example
            chosen = example[attr.chosen]
            rejected = example[attr.rejected]
            # Validate the canonical (mapped) roles: the raw dataset tags may differ from the
            # Role values, and an unmapped tag must not raise KeyError.
            chosen_role = tag_mapping.get(chosen[attr.role_tag])
            rejected_role = tag_mapping.get(rejected[attr.role_tag])
            if chosen_role not in even_tags or rejected_role not in even_tags:
                logger.warning_rank0(f"Invalid role tag in {[chosen, rejected]}.")
                logger.warning_rank0("Skipping this abnormal example.")
                prompt, response = [], []
            else:
                prompt = aligned_messages
                response = [
                    {"role": chosen_role, "content": chosen[attr.content_tag]},
                    {"role": rejected_role, "content": rejected[attr.content_tag]},
                ]
        else:  # normal example
            prompt = aligned_messages[:-1]
            response = aligned_messages[-1:]

        tools = example.get(attr.tools, "") if attr.tools else ""
        if isinstance(tools, (dict, list)):
            tools = json.dumps(tools, ensure_ascii=False)

        short_system_prompt = "detailed thinking off"
        if tools:
            # With tools, only a newline separator is added to a non-empty system prompt.
            if system:
                system += "\n"
        elif not system:
            system = short_system_prompt
        elif "detailed thinking on" not in system and "detailed thinking off" not in system:
            system += "\n" + short_system_prompt

        output = {
            "_prompt": prompt,
            "_response": response,
            "_system": system,
            "_tools": tools,
            "_images": self._find_medias(example[attr.images]) if attr.images else None,
            "_videos": self._find_medias(example[attr.videos]) if attr.videos else None,
            "_audios": self._find_medias(example[attr.audios]) if attr.audios else None,
        }
        return output
|
||||||
|
|
||||||
|
|
||||||
# Registry of the available converters, keyed by the name used to select one.
DATASET_CONVERTERS = {
    "alpaca": AlpacaDatasetConverter,
    "sharegpt": SharegptDatasetConverter,
    "openai": OpenAIDatasetConverter,
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -679,6 +679,21 @@ register_template(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# "bailing_v2" chat template: every turn is introduced by a <role>...</role> header and
# closed with <|role_end|>; user and observation turns additionally open the assistant
# header so that generation continues from there. Tool output is wrapped in
# <tool_response> tags, and function calls / tool descriptions use the "ling" tool format.
register_template(
    name="bailing_v2",
    format_user=StringFormatter(slots=["<role>HUMAN</role>{{content}}<|role_end|><role>ASSISTANT</role>"]),
    format_system=StringFormatter(slots=["<role>SYSTEM</role>{{content}}<|role_end|>"]),
    format_assistant=StringFormatter(slots=["{{content}}<|role_end|>"]),
    format_observation=StringFormatter(
        slots=[
            "<role>OBSERVATION</role>\n<tool_response>\n{{content}}\n</tool_response>"
            "<|role_end|><role>ASSISTANT</role>"
        ]
    ),
    format_function=FunctionFormatter(slots=["{{content}}<|role_end|>"], tool_format="ling"),
    format_tools=ToolFormatter(tool_format="ling"),
    stop_words=["<|endoftext|>"],
    efficient_eos=True,
)
|
||||||
|
|
||||||
|
|
||||||
register_template(
|
register_template(
|
||||||
name="belle",
|
name="belle",
|
||||||
format_user=StringFormatter(slots=["Human: {{content}}\n\nBelle: "]),
|
format_user=StringFormatter(slots=["Human: {{content}}\n\nBelle: "]),
|
||||||
|
@ -79,6 +79,15 @@ SEED_TOOL_PROMPT = (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Tool section of the system prompt for the "ling" tool format. It is a ``str.format``
# template: ``{tool_text}`` receives the serialized tool signatures, while the doubled
# braces render as the literal JSON braces of the example tool call.
LING_TOOL_PROMPT = (
    "# Tools\n\n"
    "You may call one or more functions to assist with the user query.\n\n"
    "You are provided with function signatures within <tools></tools> XML tags:\n"
    "<tools>{tool_text}\n</tools>\n\n"
    "For each function call, return a json object with function name and arguments "
    "within <tool_call></tool_call> XML tags:\n"
    '<tool_call>\n{{"name": <function-name>, "arguments": <args-json-object>}}\n</tool_call>'
)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ToolUtils(ABC):
|
class ToolUtils(ABC):
|
||||||
"""Base class for tool utilities."""
|
"""Base class for tool utilities."""
|
||||||
@ -406,6 +415,20 @@ class SeedToolUtils(ToolUtils):
|
|||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
class LingToolUtils(QwenToolUtils):
    r"""Ling v2 tool using template."""

    @override
    @staticmethod
    def tool_formatter(tools: list[dict[str, Any]]) -> str:
        r"""Render the tool definitions into the Ling tool prompt.

        Each tool is written on its own line as JSON. A tool that is not already of the form
        ``{"type": "function", ...}`` is wrapped as ``{"type": "function", "function": tool}``.
        The result is ``LING_TOOL_PROMPT`` followed by a "detailed thinking off" line.
        """
        serialized_tools = [
            "\n"
            + json.dumps(
                tool if tool.get("type") == "function" else {"type": "function", "function": tool},
                ensure_ascii=False,
            )
            for tool in tools
        ]
        prompt = LING_TOOL_PROMPT.format(tool_text="".join(serialized_tools))
        return prompt + "\n" + "detailed thinking off"
|
||||||
|
|
||||||
|
|
||||||
TOOLS = {
|
TOOLS = {
|
||||||
"default": DefaultToolUtils(),
|
"default": DefaultToolUtils(),
|
||||||
"glm4": GLM4ToolUtils(),
|
"glm4": GLM4ToolUtils(),
|
||||||
@ -414,6 +437,7 @@ TOOLS = {
|
|||||||
"qwen": QwenToolUtils(),
|
"qwen": QwenToolUtils(),
|
||||||
"glm4_moe": GLM4MOEToolUtils(),
|
"glm4_moe": GLM4MOEToolUtils(),
|
||||||
"seed_oss": SeedToolUtils(),
|
"seed_oss": SeedToolUtils(),
|
||||||
|
"ling": LingToolUtils(),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user