From aff6923fd1fa41771f0b1103e0827913e9e88b09 Mon Sep 17 00:00:00 2001 From: wangshaofei Date: Sun, 14 Sep 2025 02:20:34 +0800 Subject: [PATCH] [data] bailing template v2 & openai data converter (#9112) --- src/llamafactory/data/converter.py | 144 +++++++++++++++++++++++++++- src/llamafactory/data/template.py | 15 +++ src/llamafactory/data/tool_utils.py | 24 +++++ 3 files changed, 182 insertions(+), 1 deletion(-) diff --git a/src/llamafactory/data/converter.py b/src/llamafactory/data/converter.py index 2a284f19..4981466a 100644 --- a/src/llamafactory/data/converter.py +++ b/src/llamafactory/data/converter.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import json import os from abc import abstractmethod from dataclasses import dataclass @@ -227,9 +227,151 @@ class SharegptDatasetConverter(DatasetConverter): return output +@dataclass +class OpenAIDatasetConverter(DatasetConverter): + def __call__(self, example: dict[str, Any]) -> dict[str, Any]: + tag_mapping = { + self.dataset_attr.user_tag: Role.USER.value, + self.dataset_attr.assistant_tag: Role.ASSISTANT.value, + self.dataset_attr.observation_tag: Role.OBSERVATION.value, + self.dataset_attr.function_tag: Role.FUNCTION.value, + self.dataset_attr.system_tag: Role.SYSTEM.value, + } + + messages = example[self.dataset_attr.messages] + if ( + self.dataset_attr.system_tag + and len(messages) != 0 + and messages[0][self.dataset_attr.role_tag] == self.dataset_attr.system_tag + ): + system = messages[0][self.dataset_attr.content_tag] + messages = messages[1:] + else: + system = example.get(self.dataset_attr.system, "") if self.dataset_attr.system else "" + + aligned_messages = [] + tool_responses = [] + broken_data = False + for turn_idx, message in enumerate(messages): + role = message[self.dataset_attr.role_tag] + content = message[self.dataset_attr.content_tag] + + 
if role in [self.dataset_attr.assistant_tag, self.dataset_attr.function_tag]:
+                if "tool_calls" in message and len(message['tool_calls']) > 0:
+                    tool_calls_list = [tool['function'] for tool in message['tool_calls']]
+                    content = json.dumps(tool_calls_list, ensure_ascii=False)
+                    role = self.dataset_attr.function_tag
+
+            if role == self.dataset_attr.observation_tag:
+                tool_responses.append(content)
+                continue
+            elif len(tool_responses) > 0:
+                _content = "\n</tool_response>\n<tool_response>\n".join(tool_responses)
+                aligned_messages.append(
+                    {
+                        "role": Role.OBSERVATION.value,
+                        "content": _content,
+                    }
+                )
+                tool_responses = []
+
+            aligned_messages.append(
+                {
+                    "role": tag_mapping[role],
+                    "content": content,
+                }
+            )
+
+        odd_tags = (Role.USER.value, Role.OBSERVATION.value)
+        even_tags = (Role.ASSISTANT.value, Role.FUNCTION.value)
+        accept_tags = (odd_tags, even_tags)
+        for turn_idx, message in enumerate(aligned_messages):
+            if message["role"] not in accept_tags[turn_idx % 2]:
+                logger.warning_rank0(f"Invalid role tag in {messages}.")
+                broken_data = True
+                break
+
+        if (not self.dataset_attr.ranking and len(aligned_messages) % 2 != 0) or (
+            self.dataset_attr.ranking and len(aligned_messages) % 2 == 0
+        ):
+            logger.warning_rank0(f"Invalid message count in {messages}.")
+            broken_data = True
+
+        if broken_data:
+            logger.warning_rank0("Skipping this abnormal example.")
+            prompt, response = [], []
+        elif self.dataset_attr.kto_tag and isinstance(example[self.dataset_attr.kto_tag], bool):  # kto example
+            prompt = aligned_messages[:-1]
+            response = aligned_messages[-1:]
+            if example[self.dataset_attr.kto_tag]:
+                response = response + [{"role": Role.ASSISTANT.value, "content": ""}]
+            else:
+                response = [{"role": Role.ASSISTANT.value, "content": ""}] + response
+        elif (
+            self.dataset_attr.ranking
+            and isinstance(example[self.dataset_attr.chosen], dict)
+            and isinstance(example[self.dataset_attr.rejected], dict)
+        ):  # pairwise example
+            chosen = example[self.dataset_attr.chosen]
+            rejected = 
example[self.dataset_attr.rejected] + if ( + chosen[self.dataset_attr.role_tag] not in accept_tags[-1] + or rejected[self.dataset_attr.role_tag] not in accept_tags[-1] + ): + logger.warning_rank0(f"Invalid role tag in {[chosen, rejected]}.") + broken_data = True + + prompt = aligned_messages + response = [ + { + "role": tag_mapping[chosen[self.dataset_attr.role_tag]], + "content": chosen[self.dataset_attr.content_tag], + }, + { + "role": tag_mapping[rejected[self.dataset_attr.role_tag]], + "content": rejected[self.dataset_attr.content_tag], + }, + ] + else: # normal example + prompt = aligned_messages[:-1] + response = aligned_messages[-1:] + + tools = example.get(self.dataset_attr.tools, "") if self.dataset_attr.tools else "" + if isinstance(tools, dict) or isinstance(tools, list): + tools = json.dumps(tools, ensure_ascii=False) + + short_system_prompt = 'detailed thinking off' + if not system: + if not tools: + system = short_system_prompt + else: + pass + else: + if not tools: + if 'detailed thinking on' in system or 'detailed thinking off' in system: + pass + else: + system += '\n' + short_system_prompt + else: + system += '\n' + + + output = { + "_prompt": prompt, + "_response": response, + "_system": system, + "_tools": tools, + "_images": self._find_medias(example[self.dataset_attr.images]) if self.dataset_attr.images else None, + "_videos": self._find_medias(example[self.dataset_attr.videos]) if self.dataset_attr.videos else None, + "_audios": self._find_medias(example[self.dataset_attr.audios]) if self.dataset_attr.audios else None, + } + return output + + DATASET_CONVERTERS = { "alpaca": AlpacaDatasetConverter, "sharegpt": SharegptDatasetConverter, + "openai": OpenAIDatasetConverter, } diff --git a/src/llamafactory/data/template.py b/src/llamafactory/data/template.py index 60650086..b57d00a5 100644 --- a/src/llamafactory/data/template.py +++ b/src/llamafactory/data/template.py @@ -679,6 +679,21 @@ register_template( ) +register_template( + 
name="bailing_v2",
+    format_user=StringFormatter(slots=["<role>HUMAN</role>{{content}}<|role_end|><role>ASSISTANT</role>"]),
+    format_system=StringFormatter(slots=["<role>SYSTEM</role>{{content}}<|role_end|>"]),
+    format_assistant=StringFormatter(slots=["{{content}}<|role_end|>"]),
+    format_observation=StringFormatter(
+        slots=["<role>OBSERVATION</role>\n<tool_response>\n{{content}}\n</tool_response><|role_end|><role>ASSISTANT</role>"]
+    ),
+    format_function=FunctionFormatter(slots=["{{content}}<|role_end|>"], tool_format="ling"),
+    format_tools=ToolFormatter(tool_format="ling"),
+    stop_words=["<|endoftext|>"],
+    efficient_eos=True,
+)
+
+
 register_template(
     name="belle",
     format_user=StringFormatter(slots=["Human: {{content}}\n\nBelle: "]),
diff --git a/src/llamafactory/data/tool_utils.py b/src/llamafactory/data/tool_utils.py
index 29257551..3e059d87 100644
--- a/src/llamafactory/data/tool_utils.py
+++ b/src/llamafactory/data/tool_utils.py
@@ -79,6 +79,15 @@ SEED_TOOL_PROMPT = (
 )
 
 
+LING_TOOL_PROMPT = (
+    "# Tools\n\nYou may call one or more functions to assist with the user query.\n\n"
+    "You are provided with function signatures within <tools></tools> XML tags:\n<tools>{tool_text}"
+    "\n</tools>\n\nFor each function call, return a json object with function name and arguments within "
+    """<tool_call></tool_call> XML tags:\n<tool_call>\n{{"name": <function-name>, """
+    """"arguments": <args-json-object>}}\n</tool_call>"""
+)
+
+
 @dataclass
 class ToolUtils(ABC):
     """Base class for tool utilities."""
@@ -406,6 +415,20 @@ class SeedToolUtils(ToolUtils):
         return results
 
 
+class LingToolUtils(QwenToolUtils):
+    r"""Ling v2 tool using template."""
+
+    @override
+    @staticmethod
+    def tool_formatter(tools: list[dict[str, Any]]) -> str:
+        tool_text = ""
+        for tool in tools:
+            wrapped_tool = tool if tool.get("type") == "function" else {"type": "function", "function": tool}
+            tool_text += "\n" + json.dumps(wrapped_tool, ensure_ascii=False)
+
+        return LING_TOOL_PROMPT.format(tool_text=tool_text) + "\n" + "detailed thinking off"
+
+
 TOOLS = {
     "default": DefaultToolUtils(),
     "glm4": GLM4ToolUtils(),
@@ -414,6 +437,7 @@ TOOLS = {
     "qwen": QwenToolUtils(),
     "glm4_moe": GLM4MOEToolUtils(),
     
"seed_oss": SeedToolUtils(), + "ling": LingToolUtils(), }