mirror of
				https://github.com/hiyouga/LLaMA-Factory.git
				synced 2025-11-04 18:02:19 +08:00 
			
		
		
		
	[data] bailing template v2 & openai data converter (#9112)
This commit is contained in:
		
							parent
							
								
									db223e3975
								
							
						
					
					
						commit
						a22dab97fd
					
				@ -11,7 +11,7 @@
 | 
			
		||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
 | 
			
		||||
import json
 | 
			
		||||
import os
 | 
			
		||||
from abc import abstractmethod
 | 
			
		||||
from dataclasses import dataclass
 | 
			
		||||
@ -227,9 +227,151 @@ class SharegptDatasetConverter(DatasetConverter):
 | 
			
		||||
        return output
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@dataclass
 | 
			
		||||
class OpenAIDatasetConverter(DatasetConverter):
 | 
			
		||||
    def __call__(self, example: dict[str, Any]) -> dict[str, Any]:
 | 
			
		||||
        tag_mapping = {
 | 
			
		||||
            self.dataset_attr.user_tag: Role.USER.value,
 | 
			
		||||
            self.dataset_attr.assistant_tag: Role.ASSISTANT.value,
 | 
			
		||||
            self.dataset_attr.observation_tag: Role.OBSERVATION.value,
 | 
			
		||||
            self.dataset_attr.function_tag: Role.FUNCTION.value,
 | 
			
		||||
            self.dataset_attr.system_tag: Role.SYSTEM.value,
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        messages = example[self.dataset_attr.messages]
 | 
			
		||||
        if (
 | 
			
		||||
            self.dataset_attr.system_tag
 | 
			
		||||
            and len(messages) != 0
 | 
			
		||||
            and messages[0][self.dataset_attr.role_tag] == self.dataset_attr.system_tag
 | 
			
		||||
        ):
 | 
			
		||||
            system = messages[0][self.dataset_attr.content_tag]
 | 
			
		||||
            messages = messages[1:]
 | 
			
		||||
        else:
 | 
			
		||||
            system = example.get(self.dataset_attr.system, "") if self.dataset_attr.system else ""
 | 
			
		||||
 | 
			
		||||
        aligned_messages = []
 | 
			
		||||
        tool_responses = []
 | 
			
		||||
        broken_data = False
 | 
			
		||||
        for turn_idx, message in enumerate(messages):
 | 
			
		||||
            role = message[self.dataset_attr.role_tag]
 | 
			
		||||
            content = message[self.dataset_attr.content_tag]
 | 
			
		||||
 | 
			
		||||
            if role in [self.dataset_attr.assistant_tag, self.dataset_attr.function_tag]:
 | 
			
		||||
                if "tool_calls" in message and len(message['tool_calls']) > 0:
 | 
			
		||||
                    tool_calls_list = [tool['function'] for tool in message['tool_calls']]
 | 
			
		||||
                    content = json.dumps(tool_calls_list, ensure_ascii=False)
 | 
			
		||||
                    role = self.dataset_attr.function_tag
 | 
			
		||||
 | 
			
		||||
            if role == self.dataset_attr.observation_tag:
 | 
			
		||||
                tool_responses.append(content)
 | 
			
		||||
                continue
 | 
			
		||||
            elif len(tool_responses) > 0:
 | 
			
		||||
                _content = "\n</tool_response>\n<tool_response>\n".join(tool_responses)
 | 
			
		||||
                aligned_messages.append(
 | 
			
		||||
                    {
 | 
			
		||||
                        "role": Role.OBSERVATION.value,
 | 
			
		||||
                        "content": _content,
 | 
			
		||||
                    }
 | 
			
		||||
                )
 | 
			
		||||
                tool_responses = []
 | 
			
		||||
 | 
			
		||||
            aligned_messages.append(
 | 
			
		||||
                {
 | 
			
		||||
                    "role": tag_mapping[role],
 | 
			
		||||
                    "content": content,
 | 
			
		||||
                }
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        odd_tags = (Role.USER.value, Role.OBSERVATION.value)
 | 
			
		||||
        even_tags = (Role.ASSISTANT.value, Role.FUNCTION.value)
 | 
			
		||||
        accept_tags = (odd_tags, even_tags)
 | 
			
		||||
        for turn_idx, message in enumerate(aligned_messages):
 | 
			
		||||
            if message["role"] not in accept_tags[turn_idx % 2]:
 | 
			
		||||
                logger.warning_rank0(f"Invalid role tag in {messages}.")
 | 
			
		||||
                broken_data = True
 | 
			
		||||
                break
 | 
			
		||||
 | 
			
		||||
        if (not self.dataset_attr.ranking and len(aligned_messages) % 2 != 0) or (
 | 
			
		||||
            self.dataset_attr.ranking and len(aligned_messages) % 2 == 0
 | 
			
		||||
        ):
 | 
			
		||||
            logger.warning_rank0(f"Invalid message count in {messages}.")
 | 
			
		||||
            broken_data = True
 | 
			
		||||
 | 
			
		||||
        if broken_data:
 | 
			
		||||
            logger.warning_rank0("Skipping this abnormal example.")
 | 
			
		||||
            prompt, response = [], []
 | 
			
		||||
        elif self.dataset_attr.kto_tag and isinstance(example[self.dataset_attr.kto_tag], bool):  # kto example
 | 
			
		||||
            prompt = aligned_messages[:-1]
 | 
			
		||||
            response = aligned_messages[-1:]
 | 
			
		||||
            if example[self.dataset_attr.kto_tag]:
 | 
			
		||||
                response = response + [{"role": Role.ASSISTANT.value, "content": ""}]
 | 
			
		||||
            else:
 | 
			
		||||
                response = [{"role": Role.ASSISTANT.value, "content": ""}] + response
 | 
			
		||||
        elif (
 | 
			
		||||
            self.dataset_attr.ranking
 | 
			
		||||
            and isinstance(example[self.dataset_attr.chosen], dict)
 | 
			
		||||
            and isinstance(example[self.dataset_attr.rejected], dict)
 | 
			
		||||
        ):  # pairwise example
 | 
			
		||||
            chosen = example[self.dataset_attr.chosen]
 | 
			
		||||
            rejected = example[self.dataset_attr.rejected]
 | 
			
		||||
            if (
 | 
			
		||||
                chosen[self.dataset_attr.role_tag] not in accept_tags[-1]
 | 
			
		||||
                or rejected[self.dataset_attr.role_tag] not in accept_tags[-1]
 | 
			
		||||
            ):
 | 
			
		||||
                logger.warning_rank0(f"Invalid role tag in {[chosen, rejected]}.")
 | 
			
		||||
                broken_data = True
 | 
			
		||||
 | 
			
		||||
            prompt = aligned_messages
 | 
			
		||||
            response = [
 | 
			
		||||
                {
 | 
			
		||||
                    "role": tag_mapping[chosen[self.dataset_attr.role_tag]],
 | 
			
		||||
                    "content": chosen[self.dataset_attr.content_tag],
 | 
			
		||||
                },
 | 
			
		||||
                {
 | 
			
		||||
                    "role": tag_mapping[rejected[self.dataset_attr.role_tag]],
 | 
			
		||||
                    "content": rejected[self.dataset_attr.content_tag],
 | 
			
		||||
                },
 | 
			
		||||
            ]
 | 
			
		||||
        else:  # normal example
 | 
			
		||||
            prompt = aligned_messages[:-1]
 | 
			
		||||
            response = aligned_messages[-1:]
 | 
			
		||||
 | 
			
		||||
        tools = example.get(self.dataset_attr.tools, "") if self.dataset_attr.tools else ""
 | 
			
		||||
        if isinstance(tools, dict) or isinstance(tools, list):
 | 
			
		||||
            tools = json.dumps(tools, ensure_ascii=False)
 | 
			
		||||
 | 
			
		||||
        short_system_prompt = 'detailed thinking off'
 | 
			
		||||
        if not system:
 | 
			
		||||
            if not tools:
 | 
			
		||||
                system = short_system_prompt
 | 
			
		||||
            else:
 | 
			
		||||
                pass
 | 
			
		||||
        else:
 | 
			
		||||
            if not tools:
 | 
			
		||||
                if 'detailed thinking on' in system or 'detailed thinking off' in system:
 | 
			
		||||
                    pass
 | 
			
		||||
                else:
 | 
			
		||||
                    system += '\n' + short_system_prompt
 | 
			
		||||
            else:
 | 
			
		||||
                system += '\n'
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
        output = {
 | 
			
		||||
            "_prompt": prompt,
 | 
			
		||||
            "_response": response,
 | 
			
		||||
            "_system": system,
 | 
			
		||||
            "_tools": tools,
 | 
			
		||||
            "_images": self._find_medias(example[self.dataset_attr.images]) if self.dataset_attr.images else None,
 | 
			
		||||
            "_videos": self._find_medias(example[self.dataset_attr.videos]) if self.dataset_attr.videos else None,
 | 
			
		||||
            "_audios": self._find_medias(example[self.dataset_attr.audios]) if self.dataset_attr.audios else None,
 | 
			
		||||
        }
 | 
			
		||||
        return output
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
DATASET_CONVERTERS = {
 | 
			
		||||
    "alpaca": AlpacaDatasetConverter,
 | 
			
		||||
    "sharegpt": SharegptDatasetConverter,
 | 
			
		||||
    "openai": OpenAIDatasetConverter,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -679,6 +679,21 @@ register_template(
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
register_template(
 | 
			
		||||
    name="bailing_v2",
 | 
			
		||||
    format_user=StringFormatter(slots=["<role>HUMAN</role>{{content}}<|role_end|><role>ASSISTANT</role>"]),
 | 
			
		||||
    format_system=StringFormatter(slots=["<role>SYSTEM</role>{{content}}<|role_end|>"]),
 | 
			
		||||
    format_assistant=StringFormatter(slots=["{{content}}<|role_end|>"]),
 | 
			
		||||
    format_observation=StringFormatter(
 | 
			
		||||
        slots=["<role>OBSERVATION</role>\n<tool_response>\n{{content}}\n</tool_response><|role_end|><role>ASSISTANT</role>"]
 | 
			
		||||
    ),
 | 
			
		||||
    format_function=FunctionFormatter(slots=["{{content}}<|role_end|>"], tool_format="ling"),
 | 
			
		||||
    format_tools=ToolFormatter(tool_format="ling"),
 | 
			
		||||
    stop_words=["<|endoftext|>"],
 | 
			
		||||
    efficient_eos=True,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
register_template(
 | 
			
		||||
    name="belle",
 | 
			
		||||
    format_user=StringFormatter(slots=["Human: {{content}}\n\nBelle: "]),
 | 
			
		||||
 | 
			
		||||
@ -79,6 +79,15 @@ SEED_TOOL_PROMPT = (
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
LING_TOOL_PROMPT = (
 | 
			
		||||
    "# Tools\n\nYou may call one or more functions to assist with the user query.\n\n"
 | 
			
		||||
    "You are provided with function signatures within <tools></tools> XML tags:\n<tools>{tool_text}"
 | 
			
		||||
    "\n</tools>\n\nFor each function call, return a json object with function name and arguments within "
 | 
			
		||||
    """<tool_call></tool_call> XML tags:\n<tool_call>\n{{"name": <function-name>, """
 | 
			
		||||
    """"arguments": <args-json-object>}}\n</tool_call>"""
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@dataclass
 | 
			
		||||
class ToolUtils(ABC):
 | 
			
		||||
    """Base class for tool utilities."""
 | 
			
		||||
@ -406,6 +415,20 @@ class SeedToolUtils(ToolUtils):
 | 
			
		||||
        return results
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class LingToolUtils(QwenToolUtils):
 | 
			
		||||
    r"""Ling v2 tool using template."""
 | 
			
		||||
 | 
			
		||||
    @override
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def tool_formatter(tools: list[dict[str, Any]]) -> str:
 | 
			
		||||
        tool_text = ""
 | 
			
		||||
        for tool in tools:
 | 
			
		||||
            wrapped_tool = tool if tool.get("type") == "function" else {"type": "function", "function": tool}
 | 
			
		||||
            tool_text += "\n" + json.dumps(wrapped_tool, ensure_ascii=False)
 | 
			
		||||
 | 
			
		||||
        return LING_TOOL_PROMPT.format(tool_text=tool_text) + "\n" + "detailed thinking off"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
TOOLS = {
 | 
			
		||||
    "default": DefaultToolUtils(),
 | 
			
		||||
    "glm4": GLM4ToolUtils(),
 | 
			
		||||
@ -414,6 +437,7 @@ TOOLS = {
 | 
			
		||||
    "qwen": QwenToolUtils(),
 | 
			
		||||
    "glm4_moe": GLM4MOEToolUtils(),
 | 
			
		||||
    "seed_oss": SeedToolUtils(),
 | 
			
		||||
    "ling": LingToolUtils(),
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user