From 052ca871bd380812e26397bdc0340012c12efda1 Mon Sep 17 00:00:00 2001 From: hoshi-hiyouga Date: Wed, 30 Apr 2025 16:18:00 +0800 Subject: [PATCH] [data] optimize qwen3 loss computation (#7923) --- README.md | 6 +- README_zh.md | 6 +- src/llamafactory/chat/hf_engine.py | 6 +- src/llamafactory/chat/sglang_engine.py | 6 +- src/llamafactory/chat/vllm_engine.py | 6 +- src/llamafactory/data/template.py | 130 +++++++++++++++---- src/llamafactory/extras/constants.py | 22 +++- src/llamafactory/webui/chatter.py | 2 + src/llamafactory/webui/components/chatbot.py | 3 + src/llamafactory/webui/locales.py | 17 +++ tests/data/test_template.py | 40 ++++++ 11 files changed, 205 insertions(+), 39 deletions(-) diff --git a/README.md b/README.md index 710b46cd..2e2ddfbb 100644 --- a/README.md +++ b/README.md @@ -5,7 +5,7 @@ [![GitHub contributors](https://img.shields.io/github/contributors/hiyouga/LLaMA-Factory?color=orange)](https://github.com/hiyouga/LLaMA-Factory/graphs/contributors) [![GitHub workflow](https://github.com/hiyouga/LLaMA-Factory/actions/workflows/tests.yml/badge.svg)](https://github.com/hiyouga/LLaMA-Factory/actions/workflows/tests.yml) [![PyPI](https://img.shields.io/pypi/v/llamafactory)](https://pypi.org/project/llamafactory/) -[![Citation](https://img.shields.io/badge/citation-429-green)](https://scholar.google.com/scholar?cites=12620864006390196564) +[![Citation](https://img.shields.io/badge/citation-447-green)](https://scholar.google.com/scholar?cites=12620864006390196564) [![GitHub pull request](https://img.shields.io/badge/PRs-welcome-blue)](https://github.com/hiyouga/LLaMA-Factory/pulls) [![Twitter](https://img.shields.io/twitter/follow/llamafactory_ai)](https://twitter.com/llamafactory_ai) @@ -246,11 +246,11 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/ | [Command R](https://huggingface.co/CohereForAI) | 35B/104B | cohere | | [DeepSeek (Code/MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B/236B | deepseek | | [DeepSeek 2.5/3](https://huggingface.co/deepseek-ai) | 236B/671B | deepseek3 | -| [DeepSeek R1 (Distill)](https://huggingface.co/deepseek-ai) | 1.5B/7B/8B/14B/32B/70B/671B | deepseek3 | +| [DeepSeek R1 (Distill)](https://huggingface.co/deepseek-ai) | 1.5B/7B/8B/14B/32B/70B/671B | deepseekr1 | | [Falcon](https://huggingface.co/tiiuae) | 7B/11B/40B/180B | falcon | | [Gemma/Gemma 2/CodeGemma](https://huggingface.co/google) | 2B/7B/9B/27B | gemma | | [Gemma 3](https://huggingface.co/google) | 1B/4B/12B/27B | gemma3/gemma (1B) | -| [GLM-4/GLM-4-0414/GLM-Z1](https://huggingface.co/THUDM) | 9B/32B | glm4 | +| [GLM-4/GLM-4-0414/GLM-Z1](https://huggingface.co/THUDM) | 9B/32B | glm4/glmz1 | | [GPT-2](https://huggingface.co/openai-community) | 0.1B/0.4B/0.8B/1.5B | - | | [Granite 3.0-3.3](https://huggingface.co/ibm-granite) | 1B/2B/3B/8B | granite3 | | [Hunyuan](https://huggingface.co/tencent/) | 7B | hunyuan | diff --git a/README_zh.md b/README_zh.md index bba64696..a3914f7c 100644 --- a/README_zh.md +++ b/README_zh.md @@ -5,7 +5,7 @@ [![GitHub contributors](https://img.shields.io/github/contributors/hiyouga/LLaMA-Factory?color=orange)](https://github.com/hiyouga/LLaMA-Factory/graphs/contributors) [![GitHub workflow](https://github.com/hiyouga/LLaMA-Factory/actions/workflows/tests.yml/badge.svg)](https://github.com/hiyouga/LLaMA-Factory/actions/workflows/tests.yml) [![PyPI](https://img.shields.io/pypi/v/llamafactory)](https://pypi.org/project/llamafactory/) 
-[![Citation](https://img.shields.io/badge/citation-429-green)](https://scholar.google.com/scholar?cites=12620864006390196564) +[![Citation](https://img.shields.io/badge/citation-447-green)](https://scholar.google.com/scholar?cites=12620864006390196564) [![GitHub pull request](https://img.shields.io/badge/PRs-welcome-blue)](https://github.com/hiyouga/LLaMA-Factory/pulls) [![Twitter](https://img.shields.io/twitter/follow/llamafactory_ai)](https://twitter.com/llamafactory_ai) @@ -249,11 +249,11 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc | [Command R](https://huggingface.co/CohereForAI) | 35B/104B | cohere | | [DeepSeek (Code/MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B/236B | deepseek | | [DeepSeek 2.5/3](https://huggingface.co/deepseek-ai) | 236B/671B | deepseek3 | -| [DeepSeek R1 (Distill)](https://huggingface.co/deepseek-ai) | 1.5B/7B/8B/14B/32B/70B/671B | deepseek3 | +| [DeepSeek R1 (Distill)](https://huggingface.co/deepseek-ai) | 1.5B/7B/8B/14B/32B/70B/671B | deepseekr1 | | [Falcon](https://huggingface.co/tiiuae) | 7B/11B/40B/180B | falcon | | [Gemma/Gemma 2/CodeGemma](https://huggingface.co/google) | 2B/7B/9B/27B | gemma | | [Gemma 3](https://huggingface.co/google) | 1B/4B/12B/27B | gemma3/gemma (1B) | -| [GLM-4/GLM-4-0414/GLM-Z1](https://huggingface.co/THUDM) | 9B/32B | glm4 | +| [GLM-4/GLM-4-0414/GLM-Z1](https://huggingface.co/THUDM) | 9B/32B | glm4/glmz1 | | [GPT-2](https://huggingface.co/openai-community) | 0.1B/0.4B/0.8B/1.5B | - | | [Granite 3.0-3.3](https://huggingface.co/ibm-granite) | 1B/2B/3B/8B | granite3 | | [Hunyuan](https://huggingface.co/tencent/) | 7B | hunyuan | diff --git a/src/llamafactory/chat/hf_engine.py b/src/llamafactory/chat/hf_engine.py index 20a3c190..5f335d8b 100644 --- a/src/llamafactory/chat/hf_engine.py +++ b/src/llamafactory/chat/hf_engine.py @@ -103,9 +103,11 @@ class HuggingfaceEngine(BaseEngine): messages = template.mm_plugin.process_messages( messages, mm_input_dict["images"], mm_input_dict["videos"], mm_input_dict["audios"], processor ) - paired_messages = messages + [{"role": "assistant", "content": ""}] + # add thought words to avoid skipping thinking + paired_messages = messages + [{"role": "assistant", "content": template.add_thought("")}] system = system or generating_args["default_system"] - prompt_ids, _ = template.encode_oneturn(tokenizer, paired_messages, system, tools) + enable_thinking = input_kwargs.pop("enable_thinking", True) + prompt_ids, _ = template.encode_oneturn(tokenizer, paired_messages, system, tools, enable_thinking) prompt_ids, _ = template.mm_plugin.process_token_ids( prompt_ids, None, diff --git a/src/llamafactory/chat/sglang_engine.py b/src/llamafactory/chat/sglang_engine.py index 3fc3aeb5..44414a05 100644 --- a/src/llamafactory/chat/sglang_engine.py +++ b/src/llamafactory/chat/sglang_engine.py @@ -146,9 +146,11 @@ class SGLangEngine(BaseEngine): messages = self.template.mm_plugin.process_messages( messages, images or [], videos or [], audios or [], self.processor ) - paired_messages = messages + [{"role": "assistant", "content": ""}] + # add thought words to avoid skipping thinking + paired_messages = messages + [{"role": "assistant", "content": self.template.add_thought("")}] system = system or self.generating_args["default_system"] - prompt_ids, _ = self.template.encode_oneturn(self.tokenizer, paired_messages, system, tools) + enable_thinking = input_kwargs.pop("enable_thinking", True) + prompt_ids, _ = self.template.encode_oneturn(self.tokenizer, paired_messages, 
system, tools, enable_thinking) prompt_length = len(prompt_ids) temperature: Optional[float] = input_kwargs.pop("temperature", None) diff --git a/src/llamafactory/chat/vllm_engine.py b/src/llamafactory/chat/vllm_engine.py index 1100fc8a..274d12e3 100644 --- a/src/llamafactory/chat/vllm_engine.py +++ b/src/llamafactory/chat/vllm_engine.py @@ -123,9 +123,11 @@ class VllmEngine(BaseEngine): messages = self.template.mm_plugin.process_messages( messages, images or [], videos or [], audios or [], self.processor ) - paired_messages = messages + [{"role": "assistant", "content": ""}] + # add thought words to avoid skipping thinking + paired_messages = messages + [{"role": "assistant", "content": self.template.add_thought("")}] system = system or self.generating_args["default_system"] - prompt_ids, _ = self.template.encode_oneturn(self.tokenizer, paired_messages, system, tools) + enable_thinking = input_kwargs.pop("enable_thinking", True) + prompt_ids, _ = self.template.encode_oneturn(self.tokenizer, paired_messages, system, tools, enable_thinking) prompt_length = len(prompt_ids) temperature: Optional[float] = input_kwargs.pop("temperature", None) diff --git a/src/llamafactory/data/template.py b/src/llamafactory/data/template.py index 3b39a768..5aa1f1be 100644 --- a/src/llamafactory/data/template.py +++ b/src/llamafactory/data/template.py @@ -13,6 +13,7 @@ # limitations under the License. import re +from copy import deepcopy from dataclasses import dataclass from typing import TYPE_CHECKING, Optional, Union @@ -59,9 +60,10 @@ class Template: messages: list[dict[str, str]], system: Optional[str] = None, tools: Optional[str] = None, + enable_thinking: bool = True, ) -> tuple[list[int], list[int]]: r"""Return a single pair of token ids representing prompt and response respectively.""" - encoded_messages = self._encode(tokenizer, messages, system, tools, remove_thought=True) + encoded_messages = self._encode(tokenizer, messages, system, tools) prompt_ids = [] for encoded_ids in encoded_messages[:-1]: prompt_ids += encoded_ids @@ -77,7 +79,7 @@ class Template: tools: Optional[str] = None, ) -> list[tuple[list[int], list[int]]]: r"""Return multiple pairs of token ids representing prompts and responses respectively.""" - encoded_messages = self._encode(tokenizer, messages, system, tools, remove_thought=False) + encoded_messages = self._encode(tokenizer, messages, system, tools) return [(encoded_messages[i], encoded_messages[i + 1]) for i in range(0, len(encoded_messages), 2)] def extract_tool(self, content: str) -> Union[str, list["FunctionCall"]]: @@ -92,6 +94,19 @@ class Template: return list(stop_token_ids) + def add_thought(self, content: str) -> str: + r"""Add empty thought to assistant message.""" + return f"{self.thought_words[0]}\n\n{self.thought_words[1]}\n\n" + content + + def remove_thought(self, content: str) -> str: + r"""Remove thought from assistant message.""" + pattern = re.compile(f"{re.escape(self.thought_words[0])}(.*?){re.escape(self.thought_words[1])}", re.DOTALL) + return re.sub(pattern, "", content).lstrip("\n") + + def get_thought_word_ids(self, tokenizer: "PreTrainedTokenizer") -> list[int]: + r"""Get the token ids of thought words.""" + return tokenizer.encode(f"{self.thought_words[0]}\n\n{self.thought_words[1]}\n\n", add_special_tokens=False) + def _convert_elements_to_ids(self, tokenizer: "PreTrainedTokenizer", elements: "SLOTS") -> list[int]: r"""Convert elements to token ids.""" token_ids = [] @@ -111,18 +126,12 @@ class Template: return token_ids - def _remove_thought(self, 
content: str) -> str: - r"""Remove thought from assistant message.""" - pattern = re.compile(f"{re.escape(self.thought_words[0])}(.*?){re.escape(self.thought_words[1])}", re.DOTALL) - return re.sub(pattern, "", content).lstrip("\n") - def _encode( self, tokenizer: "PreTrainedTokenizer", messages: list[dict[str, str]], system: Optional[str], tools: Optional[str], - remove_thought: bool, ) -> list[list[int]]: r"""Encode formatted inputs to pairs of token ids. @@ -140,18 +149,14 @@ class Template: tool_text = self.format_tools.apply(content=tools)[0] if tools else "" elements += self.format_system.apply(content=(system + tool_text)) - content = message["content"] - if remove_thought and message["role"] == Role.ASSISTANT and (i != len(messages) - 1): - content = self._remove_thought(content) - if message["role"] == Role.USER: - elements += self.format_user.apply(content=content, idx=str(i // 2)) + elements += self.format_user.apply(content=message["content"], idx=str(i // 2)) elif message["role"] == Role.ASSISTANT: - elements += self.format_assistant.apply(content=content) + elements += self.format_assistant.apply(content=message["content"]) elif message["role"] == Role.OBSERVATION: - elements += self.format_observation.apply(content=content) + elements += self.format_observation.apply(content=message["content"]) elif message["role"] == Role.FUNCTION: - elements += self.format_function.apply(content=content) + elements += self.format_function.apply(content=message["content"]) else: raise NotImplementedError("Unexpected role: {}".format(message["role"])) @@ -331,7 +336,6 @@ class Llama2Template(Template): messages: list[dict[str, str]], system: str, tools: str, - remove_thought: bool, ) -> list[list[int]]: system = system or self.default_system encoded_messages = [] @@ -345,18 +349,14 @@ class Llama2Template(Template): tool_text = self.format_tools.apply(content=tools)[0] if tools else "" system_text = self.format_system.apply(content=(system + tool_text))[0] - content = message["content"] - if remove_thought and message["role"] == Role.ASSISTANT and (i != len(messages) - 1): - content = self._remove_thought(content) - if message["role"] == Role.USER: - elements += self.format_user.apply(content=system_text + content) + elements += self.format_user.apply(content=system_text + message["content"]) elif message["role"] == Role.ASSISTANT: - elements += self.format_assistant.apply(content=content) + elements += self.format_assistant.apply(content=message["content"]) elif message["role"] == Role.OBSERVATION: - elements += self.format_observation.apply(content=content) + elements += self.format_observation.apply(content=message["content"]) elif message["role"] == Role.FUNCTION: - elements += self.format_function.apply(content=content) + elements += self.format_function.apply(content=message["content"]) else: raise NotImplementedError("Unexpected role: {}".format(message["role"])) @@ -395,6 +395,60 @@ class Llama2Template(Template): return jinja_template +@dataclass +class ReasoningTemplate(Template): + r"""A template that add thought to assistant message.""" + + @override + def encode_oneturn( + self, + tokenizer: "PreTrainedTokenizer", + messages: list[dict[str, str]], + system: Optional[str] = None, + tools: Optional[str] = None, + enable_thinking: bool = True, + ) -> tuple[list[int], list[int]]: + messages = deepcopy(messages) + for i in range(len(messages)): + if messages[i]["role"] == Role.ASSISTANT and (i != len(messages) - 1): + messages[i]["content"] = 
self.remove_thought(messages[i]["content"]) + + encoded_messages = self._encode(tokenizer, messages, system, tools) + prompt_ids = [] + for encoded_ids in encoded_messages[:-1]: + prompt_ids += encoded_ids + + if not enable_thinking or ( + messages[-1]["role"] == Role.ASSISTANT + and self.thought_words[0] not in messages[-1]["content"] + and self.thought_words[1] not in messages[-1]["content"] + ): + prompt_ids += self.get_thought_word_ids(tokenizer) + + response_ids = encoded_messages[-1] + return prompt_ids, response_ids + + @override + def encode_multiturn( + self, + tokenizer: "PreTrainedTokenizer", + messages: list[dict[str, str]], + system: Optional[str] = None, + tools: Optional[str] = None, + ) -> list[tuple[list[int], list[int]]]: + messages = deepcopy(messages) + encoded_messages = self._encode(tokenizer, messages, system, tools) + for i in range(len(messages) - 1): + if ( + messages[i + 1]["role"] == Role.ASSISTANT + and self.thought_words[0] not in messages[i + 1]["content"] + and self.thought_words[1] not in messages[i + 1]["content"] + ): + encoded_messages[i] += self.get_thought_word_ids(tokenizer) + + return [(encoded_messages[i], encoded_messages[i + 1]) for i in range(0, len(encoded_messages), 2)] + + TEMPLATES: dict[str, "Template"] = {} @@ -778,6 +832,15 @@ register_template( ) +# copied from deepseek3 template +register_template( + name="deepseekr1", + format_user=StringFormatter(slots=["<|User|>{{content}}<|Assistant|>"]), + format_prefix=EmptyFormatter(slots=[{"bos_token"}]), + template_class=ReasoningTemplate, +) + + register_template( name="deepseekcoder", format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n### Response:"]), @@ -878,6 +941,22 @@ register_template( ) +# copied from glm4 template +register_template( + name="glmz1", + format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]), + format_assistant=StringFormatter(slots=["\n{{content}}"]), + format_system=StringFormatter(slots=["<|system|>\n{{content}}"]), + format_function=FunctionFormatter(slots=["{{content}}"], tool_format="glm4"), + format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>"]), + format_tools=ToolFormatter(tool_format="glm4"), + format_prefix=EmptyFormatter(slots=["[gMASK]"]), + stop_words=["<|user|>", "<|observation|>"], + efficient_eos=True, + template_class=ReasoningTemplate, +) + + register_template( name="granite3", format_user=StringFormatter( @@ -1458,6 +1537,7 @@ register_template( format_tools=ToolFormatter(tool_format="qwen"), stop_words=["<|im_end|>"], replace_eos=True, + template_class=ReasoningTemplate, ) diff --git a/src/llamafactory/extras/constants.py b/src/llamafactory/extras/constants.py index be4a417d..4e4d0760 100644 --- a/src/llamafactory/extras/constants.py +++ b/src/llamafactory/extras/constants.py @@ -533,6 +533,17 @@ register_model_group( DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V3", DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V3", }, + "DeepSeek-V3-671B-0324-Chat": { + DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V3-0324", + DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V3-0324", + }, + }, + template="deepseek3", +) + + +register_model_group( + models={ "DeepSeek-R1-1.5B-Distill": { DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B", DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B", @@ -566,7 +577,7 @@ register_model_group( DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-R1", }, }, - template="deepseek3", + template="deepseekr1", ) @@ -737,6 +748,13 
@@ register_model_group(
             DownloadSource.DEFAULT: "THUDM/GLM-4-32B-0414",
             DownloadSource.MODELSCOPE: "ZhipuAI/GLM-4-32B-0414",
         },
+    },
+    template="glm4",
+)
+
+
+register_model_group(
+    models={
         "GLM-Z1-9B-0414-Chat": {
             DownloadSource.DEFAULT: "THUDM/GLM-Z1-9B-0414",
             DownloadSource.MODELSCOPE: "ZhipuAI/GLM-Z1-9B-0414",
         },
@@ -746,7 +764,7 @@ register_model_group(
             DownloadSource.MODELSCOPE: "ZhipuAI/GLM-Z1-32B-0414",
         },
     },
-    template="glm4",
+    template="glmz1",
 )
diff --git a/src/llamafactory/webui/chatter.py b/src/llamafactory/webui/chatter.py
index a2242bb3..9f5733b3 100644
--- a/src/llamafactory/webui/chatter.py
+++ b/src/llamafactory/webui/chatter.py
@@ -191,6 +191,7 @@ class WebChatModel(ChatModel):
         temperature: float,
         skip_special_tokens: bool,
         escape_html: bool,
+        enable_thinking: bool,
     ) -> Generator[tuple[list[dict[str, str]], list[dict[str, str]]], None, None]:
         r"""Generate output text in stream.
 
@@ -210,6 +211,7 @@ class WebChatModel(ChatModel):
             top_p=top_p,
             temperature=temperature,
             skip_special_tokens=skip_special_tokens,
+            enable_thinking=enable_thinking,
         ):
             response += new_text
             if tools:
diff --git a/src/llamafactory/webui/components/chatbot.py b/src/llamafactory/webui/components/chatbot.py
index 52217e16..53ac3e4e 100644
--- a/src/llamafactory/webui/components/chatbot.py
+++ b/src/llamafactory/webui/components/chatbot.py
@@ -79,6 +79,7 @@ def create_chat_box(
                 temperature = gr.Slider(minimum=0.01, maximum=1.5, value=0.95, step=0.01)
                 skip_special_tokens = gr.Checkbox(value=True)
                 escape_html = gr.Checkbox(value=True)
+                enable_thinking = gr.Checkbox(value=True)
         clear_btn = gr.Button()
 
     tools.input(check_json_schema, inputs=[tools, engine.manager.get_elem_by_id("top.lang")])
@@ -103,6 +104,7 @@ def create_chat_box(
             temperature,
             skip_special_tokens,
             escape_html,
+            enable_thinking,
         ],
         [chatbot, messages],
     )
@@ -127,6 +129,7 @@ def create_chat_box(
             temperature=temperature,
             skip_special_tokens=skip_special_tokens,
             escape_html=escape_html,
+            enable_thinking=enable_thinking,
             clear_btn=clear_btn,
         ),
     )
diff --git a/src/llamafactory/webui/locales.py b/src/llamafactory/webui/locales.py
index ad8ebeb1..a1ac2c51 100644
--- a/src/llamafactory/webui/locales.py
+++ b/src/llamafactory/webui/locales.py
@@ -2468,6 +2468,23 @@ LOCALES = {
             "label": "HTML タグをエスケープ",
         },
     },
+    "enable_thinking": {
+        "en": {
+            "label": "Enable thinking",
+        },
+        "ru": {
+            "label": "Включить мышление",
+        },
+        "zh": {
+            "label": "启用思考",
+        },
+        "ko": {
+            "label": "사고를 활성화하다",
+        },
+        "ja": {
+            "label": "思考を可能にする",
+        },
+    },
     "clear_btn": {
         "en": {
             "value": "Clear history",
diff --git a/tests/data/test_template.py b/tests/data/test_template.py
index 74f95df8..e74e8b45 100644
--- a/tests/data/test_template.py
+++ b/tests/data/test_template.py
@@ -125,6 +125,37 @@ def test_encode_multiturn(use_fast: bool):
     )
 
 
+@pytest.mark.parametrize("use_fast", [True, False])
+def test_reasoning_encode_oneturn(use_fast: bool):
+    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B", use_fast=use_fast)
+    template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template="qwen3"))
+    prompt_ids, answer_ids = template.encode_oneturn(tokenizer, MESSAGES)
+    prompt_str = (
+        "<|im_start|>user\nHow are you<|im_end|>\n"
+        "<|im_start|>assistant\nI am fine!<|im_end|>\n"
+        "<|im_start|>user\n你好<|im_end|>\n"
+        "<|im_start|>assistant\n<think>\n\n</think>\n\n"
+    )
+    answer_str = "很高兴认识你!<|im_end|>\n"
+    _check_tokenization(tokenizer, (prompt_ids, answer_ids), (prompt_str, answer_str))
+
+
+@pytest.mark.parametrize("use_fast", [True, False])
+def test_reasoning_encode_multiturn(use_fast: bool):
+    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B", use_fast=use_fast)
+    template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template="qwen3"))
+    encoded_pairs = template.encode_multiturn(tokenizer, MESSAGES)
+    prompt_str_1 = "<|im_start|>user\nHow are you<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
+    answer_str_1 = "I am fine!<|im_end|>\n"
+    prompt_str_2 = "<|im_start|>user\n你好<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
+    answer_str_2 = "很高兴认识你!<|im_end|>\n"
+    _check_tokenization(
+        tokenizer,
+        (encoded_pairs[0][0], encoded_pairs[0][1], encoded_pairs[1][0], encoded_pairs[1][1]),
+        (prompt_str_1, answer_str_1, prompt_str_2, answer_str_2),
+    )
+
+
 @pytest.mark.parametrize("use_fast", [True, False])
 def test_jinja_template(use_fast: bool):
     tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3, use_fast=use_fast)
@@ -227,6 +258,15 @@ def test_qwen2_5_template(use_fast: bool):
 
 @pytest.mark.parametrize("use_fast", [True, False])
 def test_qwen3_template(use_fast: bool):
+    prompt_str = (
+        "<|im_start|>user\nHow are you<|im_end|>\n"
+        "<|im_start|>assistant\nI am fine!<|im_end|>\n"
+        "<|im_start|>user\n你好<|im_end|>\n"
+        "<|im_start|>assistant\n<think>\n\n</think>\n\n"
+    )
+    answer_str = "很高兴认识你!<|im_end|>\n"
+    _check_template("Qwen/Qwen3-8B", "qwen3", prompt_str, answer_str, use_fast)
+
     prompt_str = (
         "<|im_start|>user\nHow are you<|im_end|>\n"
         "<|im_start|>assistant\nI am fine!<|im_end|>\n"
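
A minimal sketch of the thought handling this patch introduces, assuming the default Qwen3-style thought words ("<think>", "</think>"); the standalone helpers below only mirror Template.add_thought / Template.remove_thought and the enable_thinking check in ReasoningTemplate.encode_oneturn, and their names are local to this sketch rather than the library API.

    # Sketch of the empty-thought logic added by this patch (assumed thought words).
    import re

    THOUGHT_WORDS = ("<think>", "</think>")


    def add_thought(content: str) -> str:
        """Prefix a message with an empty thought block, as Template.add_thought does."""
        return f"{THOUGHT_WORDS[0]}\n\n{THOUGHT_WORDS[1]}\n\n" + content


    def remove_thought(content: str) -> str:
        """Strip a thought block from a message, as Template.remove_thought does."""
        pattern = re.compile(f"{re.escape(THOUGHT_WORDS[0])}(.*?){re.escape(THOUGHT_WORDS[1])}", re.DOTALL)
        return re.sub(pattern, "", content).lstrip("\n")


    def needs_empty_thought(response: str, enable_thinking: bool = True) -> bool:
        """Mirror the check in ReasoningTemplate.encode_oneturn: an empty thought block is
        spliced into the prompt when thinking is disabled or the target response carries no
        thought of its own, so the response tokens (and the loss over them) keep one layout."""
        has_thought = THOUGHT_WORDS[0] in response and THOUGHT_WORDS[1] in response
        return not enable_thinking or not has_thought


    assert add_thought("Hello") == "<think>\n\n</think>\n\nHello"
    assert remove_thought("<think>some reasoning</think>\n\nHello") == "Hello"
    assert needs_empty_thought("很高兴认识你!") is True

This is also why the chat engines above append template.add_thought("") to the trailing assistant message before encoding: it keeps inference prompts aligned with the token layout used during training.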