mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-01 11:12:50 +08:00
[data] optimize qwen3 loss computation (#7923)
This commit is contained in:
parent
a8430f4244
commit
d8295cd601
@ -5,7 +5,7 @@
|
||||
[](https://github.com/hiyouga/LLaMA-Factory/graphs/contributors)
|
||||
[](https://github.com/hiyouga/LLaMA-Factory/actions/workflows/tests.yml)
|
||||
[](https://pypi.org/project/llamafactory/)
|
||||
[](https://scholar.google.com/scholar?cites=12620864006390196564)
|
||||
[](https://scholar.google.com/scholar?cites=12620864006390196564)
|
||||
[](https://github.com/hiyouga/LLaMA-Factory/pulls)
|
||||
|
||||
[](https://twitter.com/llamafactory_ai)
|
||||
@ -246,11 +246,11 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
|
||||
| [Command R](https://huggingface.co/CohereForAI) | 35B/104B | cohere |
|
||||
| [DeepSeek (Code/MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B/236B | deepseek |
|
||||
| [DeepSeek 2.5/3](https://huggingface.co/deepseek-ai) | 236B/671B | deepseek3 |
|
||||
| [DeepSeek R1 (Distill)](https://huggingface.co/deepseek-ai) | 1.5B/7B/8B/14B/32B/70B/671B | deepseek3 |
|
||||
| [DeepSeek R1 (Distill)](https://huggingface.co/deepseek-ai) | 1.5B/7B/8B/14B/32B/70B/671B | deepseekr1 |
|
||||
| [Falcon](https://huggingface.co/tiiuae) | 7B/11B/40B/180B | falcon |
|
||||
| [Gemma/Gemma 2/CodeGemma](https://huggingface.co/google) | 2B/7B/9B/27B | gemma |
|
||||
| [Gemma 3](https://huggingface.co/google) | 1B/4B/12B/27B | gemma3/gemma (1B) |
|
||||
| [GLM-4/GLM-4-0414/GLM-Z1](https://huggingface.co/THUDM) | 9B/32B | glm4 |
|
||||
| [GLM-4/GLM-4-0414/GLM-Z1](https://huggingface.co/THUDM) | 9B/32B | glm4/glmz1 |
|
||||
| [GPT-2](https://huggingface.co/openai-community) | 0.1B/0.4B/0.8B/1.5B | - |
|
||||
| [Granite 3.0-3.3](https://huggingface.co/ibm-granite) | 1B/2B/3B/8B | granite3 |
|
||||
| [Hunyuan](https://huggingface.co/tencent/) | 7B | hunyuan |
|
||||
|
@ -5,7 +5,7 @@
|
||||
[](https://github.com/hiyouga/LLaMA-Factory/graphs/contributors)
|
||||
[](https://github.com/hiyouga/LLaMA-Factory/actions/workflows/tests.yml)
|
||||
[](https://pypi.org/project/llamafactory/)
|
||||
[](https://scholar.google.com/scholar?cites=12620864006390196564)
|
||||
[](https://scholar.google.com/scholar?cites=12620864006390196564)
|
||||
[](https://github.com/hiyouga/LLaMA-Factory/pulls)
|
||||
|
||||
[](https://twitter.com/llamafactory_ai)
|
||||
@ -249,11 +249,11 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
|
||||
| [Command R](https://huggingface.co/CohereForAI) | 35B/104B | cohere |
|
||||
| [DeepSeek (Code/MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B/236B | deepseek |
|
||||
| [DeepSeek 2.5/3](https://huggingface.co/deepseek-ai) | 236B/671B | deepseek3 |
|
||||
| [DeepSeek R1 (Distill)](https://huggingface.co/deepseek-ai) | 1.5B/7B/8B/14B/32B/70B/671B | deepseek3 |
|
||||
| [DeepSeek R1 (Distill)](https://huggingface.co/deepseek-ai) | 1.5B/7B/8B/14B/32B/70B/671B | deepseekr1 |
|
||||
| [Falcon](https://huggingface.co/tiiuae) | 7B/11B/40B/180B | falcon |
|
||||
| [Gemma/Gemma 2/CodeGemma](https://huggingface.co/google) | 2B/7B/9B/27B | gemma |
|
||||
| [Gemma 3](https://huggingface.co/google) | 1B/4B/12B/27B | gemma3/gemma (1B) |
|
||||
| [GLM-4/GLM-4-0414/GLM-Z1](https://huggingface.co/THUDM) | 9B/32B | glm4 |
|
||||
| [GLM-4/GLM-4-0414/GLM-Z1](https://huggingface.co/THUDM) | 9B/32B | glm4/glmz1 |
|
||||
| [GPT-2](https://huggingface.co/openai-community) | 0.1B/0.4B/0.8B/1.5B | - |
|
||||
| [Granite 3.0-3.3](https://huggingface.co/ibm-granite) | 1B/2B/3B/8B | granite3 |
|
||||
| [Hunyuan](https://huggingface.co/tencent/) | 7B | hunyuan |
|
||||
|
@ -103,9 +103,11 @@ class HuggingfaceEngine(BaseEngine):
|
||||
messages = template.mm_plugin.process_messages(
|
||||
messages, mm_input_dict["images"], mm_input_dict["videos"], mm_input_dict["audios"], processor
|
||||
)
|
||||
paired_messages = messages + [{"role": "assistant", "content": ""}]
|
||||
# add thought words to avoid skipping thinking
|
||||
paired_messages = messages + [{"role": "assistant", "content": template.add_thought("")}]
|
||||
system = system or generating_args["default_system"]
|
||||
prompt_ids, _ = template.encode_oneturn(tokenizer, paired_messages, system, tools)
|
||||
enable_thinking = input_kwargs.pop("enable_thinking", True)
|
||||
prompt_ids, _ = template.encode_oneturn(tokenizer, paired_messages, system, tools, enable_thinking)
|
||||
prompt_ids, _ = template.mm_plugin.process_token_ids(
|
||||
prompt_ids,
|
||||
None,
|
||||
|
@ -146,9 +146,11 @@ class SGLangEngine(BaseEngine):
|
||||
messages = self.template.mm_plugin.process_messages(
|
||||
messages, images or [], videos or [], audios or [], self.processor
|
||||
)
|
||||
paired_messages = messages + [{"role": "assistant", "content": ""}]
|
||||
# add thought words to avoid skipping thinking
|
||||
paired_messages = messages + [{"role": "assistant", "content": self.template.add_thought("")}]
|
||||
system = system or self.generating_args["default_system"]
|
||||
prompt_ids, _ = self.template.encode_oneturn(self.tokenizer, paired_messages, system, tools)
|
||||
enable_thinking = input_kwargs.pop("enable_thinking", True)
|
||||
prompt_ids, _ = self.template.encode_oneturn(self.tokenizer, paired_messages, system, tools, enable_thinking)
|
||||
prompt_length = len(prompt_ids)
|
||||
|
||||
temperature: Optional[float] = input_kwargs.pop("temperature", None)
|
||||
|
@ -123,9 +123,11 @@ class VllmEngine(BaseEngine):
|
||||
messages = self.template.mm_plugin.process_messages(
|
||||
messages, images or [], videos or [], audios or [], self.processor
|
||||
)
|
||||
paired_messages = messages + [{"role": "assistant", "content": ""}]
|
||||
# add thought words to avoid skipping thinking
|
||||
paired_messages = messages + [{"role": "assistant", "content": self.template.add_thought("")}]
|
||||
system = system or self.generating_args["default_system"]
|
||||
prompt_ids, _ = self.template.encode_oneturn(self.tokenizer, paired_messages, system, tools)
|
||||
enable_thinking = input_kwargs.pop("enable_thinking", True)
|
||||
prompt_ids, _ = self.template.encode_oneturn(self.tokenizer, paired_messages, system, tools, enable_thinking)
|
||||
prompt_length = len(prompt_ids)
|
||||
|
||||
temperature: Optional[float] = input_kwargs.pop("temperature", None)
|
||||
|
@ -13,6 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import re
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Optional, Union
|
||||
|
||||
@ -59,9 +60,10 @@ class Template:
|
||||
messages: list[dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
enable_thinking: bool = True,
|
||||
) -> tuple[list[int], list[int]]:
|
||||
r"""Return a single pair of token ids representing prompt and response respectively."""
|
||||
encoded_messages = self._encode(tokenizer, messages, system, tools, remove_thought=True)
|
||||
encoded_messages = self._encode(tokenizer, messages, system, tools)
|
||||
prompt_ids = []
|
||||
for encoded_ids in encoded_messages[:-1]:
|
||||
prompt_ids += encoded_ids
|
||||
@ -77,7 +79,7 @@ class Template:
|
||||
tools: Optional[str] = None,
|
||||
) -> list[tuple[list[int], list[int]]]:
|
||||
r"""Return multiple pairs of token ids representing prompts and responses respectively."""
|
||||
encoded_messages = self._encode(tokenizer, messages, system, tools, remove_thought=False)
|
||||
encoded_messages = self._encode(tokenizer, messages, system, tools)
|
||||
return [(encoded_messages[i], encoded_messages[i + 1]) for i in range(0, len(encoded_messages), 2)]
|
||||
|
||||
def extract_tool(self, content: str) -> Union[str, list["FunctionCall"]]:
|
||||
@ -92,6 +94,19 @@ class Template:
|
||||
|
||||
return list(stop_token_ids)
|
||||
|
||||
def add_thought(self, content: str) -> str:
|
||||
r"""Add empty thought to assistant message."""
|
||||
return f"{self.thought_words[0]}\n\n{self.thought_words[1]}\n\n" + content
|
||||
|
||||
def remove_thought(self, content: str) -> str:
|
||||
r"""Remove thought from assistant message."""
|
||||
pattern = re.compile(f"{re.escape(self.thought_words[0])}(.*?){re.escape(self.thought_words[1])}", re.DOTALL)
|
||||
return re.sub(pattern, "", content).lstrip("\n")
|
||||
|
||||
def get_thought_word_ids(self, tokenizer: "PreTrainedTokenizer") -> list[int]:
|
||||
r"""Get the token ids of thought words."""
|
||||
return tokenizer.encode(f"{self.thought_words[0]}\n\n{self.thought_words[1]}\n\n", add_special_tokens=False)
|
||||
|
||||
def _convert_elements_to_ids(self, tokenizer: "PreTrainedTokenizer", elements: "SLOTS") -> list[int]:
|
||||
r"""Convert elements to token ids."""
|
||||
token_ids = []
|
||||
@ -111,18 +126,12 @@ class Template:
|
||||
|
||||
return token_ids
|
||||
|
||||
def _remove_thought(self, content: str) -> str:
|
||||
r"""Remove thought from assistant message."""
|
||||
pattern = re.compile(f"{re.escape(self.thought_words[0])}(.*?){re.escape(self.thought_words[1])}", re.DOTALL)
|
||||
return re.sub(pattern, "", content).lstrip("\n")
|
||||
|
||||
def _encode(
|
||||
self,
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
messages: list[dict[str, str]],
|
||||
system: Optional[str],
|
||||
tools: Optional[str],
|
||||
remove_thought: bool,
|
||||
) -> list[list[int]]:
|
||||
r"""Encode formatted inputs to pairs of token ids.
|
||||
|
||||
@ -140,18 +149,14 @@ class Template:
|
||||
tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
|
||||
elements += self.format_system.apply(content=(system + tool_text))
|
||||
|
||||
content = message["content"]
|
||||
if remove_thought and message["role"] == Role.ASSISTANT and (i != len(messages) - 1):
|
||||
content = self._remove_thought(content)
|
||||
|
||||
if message["role"] == Role.USER:
|
||||
elements += self.format_user.apply(content=content, idx=str(i // 2))
|
||||
elements += self.format_user.apply(content=message["content"], idx=str(i // 2))
|
||||
elif message["role"] == Role.ASSISTANT:
|
||||
elements += self.format_assistant.apply(content=content)
|
||||
elements += self.format_assistant.apply(content=message["content"])
|
||||
elif message["role"] == Role.OBSERVATION:
|
||||
elements += self.format_observation.apply(content=content)
|
||||
elements += self.format_observation.apply(content=message["content"])
|
||||
elif message["role"] == Role.FUNCTION:
|
||||
elements += self.format_function.apply(content=content)
|
||||
elements += self.format_function.apply(content=message["content"])
|
||||
else:
|
||||
raise NotImplementedError("Unexpected role: {}".format(message["role"]))
|
||||
|
||||
@ -331,7 +336,6 @@ class Llama2Template(Template):
|
||||
messages: list[dict[str, str]],
|
||||
system: str,
|
||||
tools: str,
|
||||
remove_thought: bool,
|
||||
) -> list[list[int]]:
|
||||
system = system or self.default_system
|
||||
encoded_messages = []
|
||||
@ -345,18 +349,14 @@ class Llama2Template(Template):
|
||||
tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
|
||||
system_text = self.format_system.apply(content=(system + tool_text))[0]
|
||||
|
||||
content = message["content"]
|
||||
if remove_thought and message["role"] == Role.ASSISTANT and (i != len(messages) - 1):
|
||||
content = self._remove_thought(content)
|
||||
|
||||
if message["role"] == Role.USER:
|
||||
elements += self.format_user.apply(content=system_text + content)
|
||||
elements += self.format_user.apply(content=system_text + message["content"])
|
||||
elif message["role"] == Role.ASSISTANT:
|
||||
elements += self.format_assistant.apply(content=content)
|
||||
elements += self.format_assistant.apply(content=message["content"])
|
||||
elif message["role"] == Role.OBSERVATION:
|
||||
elements += self.format_observation.apply(content=content)
|
||||
elements += self.format_observation.apply(content=message["content"])
|
||||
elif message["role"] == Role.FUNCTION:
|
||||
elements += self.format_function.apply(content=content)
|
||||
elements += self.format_function.apply(content=message["content"])
|
||||
else:
|
||||
raise NotImplementedError("Unexpected role: {}".format(message["role"]))
|
||||
|
||||
@ -395,6 +395,60 @@ class Llama2Template(Template):
|
||||
return jinja_template
|
||||
|
||||
|
||||
@dataclass
|
||||
class ReasoningTemplate(Template):
|
||||
r"""A template that add thought to assistant message."""
|
||||
|
||||
@override
|
||||
def encode_oneturn(
|
||||
self,
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
messages: list[dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
enable_thinking: bool = True,
|
||||
) -> tuple[list[int], list[int]]:
|
||||
messages = deepcopy(messages)
|
||||
for i in range(len(messages)):
|
||||
if messages[i]["role"] == Role.ASSISTANT and (i != len(messages) - 1):
|
||||
messages[i]["content"] = self.remove_thought(messages[i]["content"])
|
||||
|
||||
encoded_messages = self._encode(tokenizer, messages, system, tools)
|
||||
prompt_ids = []
|
||||
for encoded_ids in encoded_messages[:-1]:
|
||||
prompt_ids += encoded_ids
|
||||
|
||||
if not enable_thinking or (
|
||||
messages[-1]["role"] == Role.ASSISTANT
|
||||
and self.thought_words[0] not in messages[-1]["content"]
|
||||
and self.thought_words[1] not in messages[-1]["content"]
|
||||
):
|
||||
prompt_ids += self.get_thought_word_ids(tokenizer)
|
||||
|
||||
response_ids = encoded_messages[-1]
|
||||
return prompt_ids, response_ids
|
||||
|
||||
@override
|
||||
def encode_multiturn(
|
||||
self,
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
messages: list[dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
) -> list[tuple[list[int], list[int]]]:
|
||||
messages = deepcopy(messages)
|
||||
encoded_messages = self._encode(tokenizer, messages, system, tools)
|
||||
for i in range(len(messages) - 1):
|
||||
if (
|
||||
messages[i + 1]["role"] == Role.ASSISTANT
|
||||
and self.thought_words[0] not in messages[i + 1]["content"]
|
||||
and self.thought_words[1] not in messages[i + 1]["content"]
|
||||
):
|
||||
encoded_messages[i] += self.get_thought_word_ids(tokenizer)
|
||||
|
||||
return [(encoded_messages[i], encoded_messages[i + 1]) for i in range(0, len(encoded_messages), 2)]
|
||||
|
||||
|
||||
TEMPLATES: dict[str, "Template"] = {}
|
||||
|
||||
|
||||
@ -778,6 +832,15 @@ register_template(
|
||||
)
|
||||
|
||||
|
||||
# copied from deepseek3 template
|
||||
register_template(
|
||||
name="deepseekr1",
|
||||
format_user=StringFormatter(slots=["<|User|>{{content}}<|Assistant|>"]),
|
||||
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
||||
template_class=ReasoningTemplate,
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="deepseekcoder",
|
||||
format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n### Response:"]),
|
||||
@ -878,6 +941,22 @@ register_template(
|
||||
)
|
||||
|
||||
|
||||
# copied from glm4 template
|
||||
register_template(
|
||||
name="glmz1",
|
||||
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]),
|
||||
format_assistant=StringFormatter(slots=["\n{{content}}"]),
|
||||
format_system=StringFormatter(slots=["<|system|>\n{{content}}"]),
|
||||
format_function=FunctionFormatter(slots=["{{content}}"], tool_format="glm4"),
|
||||
format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>"]),
|
||||
format_tools=ToolFormatter(tool_format="glm4"),
|
||||
format_prefix=EmptyFormatter(slots=["[gMASK]<sop>"]),
|
||||
stop_words=["<|user|>", "<|observation|>"],
|
||||
efficient_eos=True,
|
||||
template_class=ReasoningTemplate,
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="granite3",
|
||||
format_user=StringFormatter(
|
||||
@ -1458,6 +1537,7 @@ register_template(
|
||||
format_tools=ToolFormatter(tool_format="qwen"),
|
||||
stop_words=["<|im_end|>"],
|
||||
replace_eos=True,
|
||||
template_class=ReasoningTemplate,
|
||||
)
|
||||
|
||||
|
||||
|
@ -533,6 +533,17 @@ register_model_group(
|
||||
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V3",
|
||||
DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V3",
|
||||
},
|
||||
"DeepSeek-V3-671B-0324-Chat": {
|
||||
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V3-0324",
|
||||
DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V3-0324",
|
||||
},
|
||||
},
|
||||
template="deepseek3",
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"DeepSeek-R1-1.5B-Distill": {
|
||||
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
|
||||
DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
|
||||
@ -566,7 +577,7 @@ register_model_group(
|
||||
DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-R1",
|
||||
},
|
||||
},
|
||||
template="deepseek3",
|
||||
template="deepseekr1",
|
||||
)
|
||||
|
||||
|
||||
@ -737,6 +748,13 @@ register_model_group(
|
||||
DownloadSource.DEFAULT: "THUDM/GLM-4-32B-0414",
|
||||
DownloadSource.MODELSCOPE: "ZhipuAI/GLM-4-32B-0414",
|
||||
},
|
||||
},
|
||||
template="glm4",
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"GLM-Z1-9B-0414-Chat": {
|
||||
DownloadSource.DEFAULT: "THUDM/GLM-Z1-9B-0414",
|
||||
DownloadSource.MODELSCOPE: "ZhipuAI/GLM-Z1-9B-0414",
|
||||
@ -746,7 +764,7 @@ register_model_group(
|
||||
DownloadSource.MODELSCOPE: "ZhipuAI/GLM-Z1-32B-0414",
|
||||
},
|
||||
},
|
||||
template="glm4",
|
||||
template="glmz1",
|
||||
)
|
||||
|
||||
|
||||
|
@ -191,6 +191,7 @@ class WebChatModel(ChatModel):
|
||||
temperature: float,
|
||||
skip_special_tokens: bool,
|
||||
escape_html: bool,
|
||||
enable_thinking: bool,
|
||||
) -> Generator[tuple[list[dict[str, str]], list[dict[str, str]]], None, None]:
|
||||
r"""Generate output text in stream.
|
||||
|
||||
@ -210,6 +211,7 @@ class WebChatModel(ChatModel):
|
||||
top_p=top_p,
|
||||
temperature=temperature,
|
||||
skip_special_tokens=skip_special_tokens,
|
||||
enable_thinking=enable_thinking,
|
||||
):
|
||||
response += new_text
|
||||
if tools:
|
||||
|
@ -79,6 +79,7 @@ def create_chat_box(
|
||||
temperature = gr.Slider(minimum=0.01, maximum=1.5, value=0.95, step=0.01)
|
||||
skip_special_tokens = gr.Checkbox(value=True)
|
||||
escape_html = gr.Checkbox(value=True)
|
||||
enable_thinking = gr.Checkbox(value=True)
|
||||
clear_btn = gr.Button()
|
||||
|
||||
tools.input(check_json_schema, inputs=[tools, engine.manager.get_elem_by_id("top.lang")])
|
||||
@ -103,6 +104,7 @@ def create_chat_box(
|
||||
temperature,
|
||||
skip_special_tokens,
|
||||
escape_html,
|
||||
enable_thinking,
|
||||
],
|
||||
[chatbot, messages],
|
||||
)
|
||||
@ -127,6 +129,7 @@ def create_chat_box(
|
||||
temperature=temperature,
|
||||
skip_special_tokens=skip_special_tokens,
|
||||
escape_html=escape_html,
|
||||
enable_thinking=enable_thinking,
|
||||
clear_btn=clear_btn,
|
||||
),
|
||||
)
|
||||
|
@ -2468,6 +2468,23 @@ LOCALES = {
|
||||
"label": "HTML タグをエスケープ",
|
||||
},
|
||||
},
|
||||
"enable_thinking": {
|
||||
"en": {
|
||||
"label": "Enable thinking",
|
||||
},
|
||||
"ru": {
|
||||
"label": "Включить мышление",
|
||||
},
|
||||
"zh": {
|
||||
"label": "启用思考",
|
||||
},
|
||||
"ko": {
|
||||
"label": "사고를 활성화하다",
|
||||
},
|
||||
"ja": {
|
||||
"label": "思考を可能にする",
|
||||
},
|
||||
},
|
||||
"clear_btn": {
|
||||
"en": {
|
||||
"value": "Clear history",
|
||||
|
@ -125,6 +125,37 @@ def test_encode_multiturn(use_fast: bool):
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_fast", [True, False])
|
||||
def test_reasoning_encode_oneturn(use_fast: bool):
|
||||
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B", use_fast=use_fast)
|
||||
template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template="qwen3"))
|
||||
prompt_ids, answer_ids = template.encode_oneturn(tokenizer, MESSAGES)
|
||||
prompt_str = (
|
||||
"<|im_start|>user\nHow are you<|im_end|>\n"
|
||||
"<|im_start|>assistant\nI am fine!<|im_end|>\n"
|
||||
"<|im_start|>user\n你好<|im_end|>\n"
|
||||
"<|im_start|>assistant\n<think>\n\n</think>\n\n"
|
||||
)
|
||||
answer_str = "很高兴认识你!<|im_end|>\n"
|
||||
_check_tokenization(tokenizer, (prompt_ids, answer_ids), (prompt_str, answer_str))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_fast", [True, False])
|
||||
def test_reasoning_encode_multiturn(use_fast: bool):
|
||||
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B", use_fast=use_fast)
|
||||
template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template="qwen3"))
|
||||
encoded_pairs = template.encode_multiturn(tokenizer, MESSAGES)
|
||||
prompt_str_1 = "<|im_start|>user\nHow are you<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
|
||||
answer_str_1 = "I am fine!<|im_end|>\n"
|
||||
prompt_str_2 = "<|im_start|>user\n你好<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
|
||||
answer_str_2 = "很高兴认识你!<|im_end|>\n"
|
||||
_check_tokenization(
|
||||
tokenizer,
|
||||
(encoded_pairs[0][0], encoded_pairs[0][1], encoded_pairs[1][0], encoded_pairs[1][1]),
|
||||
(prompt_str_1, answer_str_1, prompt_str_2, answer_str_2),
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_fast", [True, False])
|
||||
def test_jinja_template(use_fast: bool):
|
||||
tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3, use_fast=use_fast)
|
||||
@ -227,6 +258,15 @@ def test_qwen2_5_template(use_fast: bool):
|
||||
|
||||
@pytest.mark.parametrize("use_fast", [True, False])
|
||||
def test_qwen3_template(use_fast: bool):
|
||||
prompt_str = (
|
||||
"<|im_start|>user\nHow are you<|im_end|>\n"
|
||||
"<|im_start|>assistant\nI am fine!<|im_end|>\n"
|
||||
"<|im_start|>user\n你好<|im_end|>\n"
|
||||
"<|im_start|>assistant\n<think>\n\n</think>\n\n"
|
||||
)
|
||||
answer_str = "很高兴认识你!<|im_end|>\n"
|
||||
_check_template("Qwen/Qwen3-8B", "qwen3", prompt_str, answer_str, use_fast)
|
||||
|
||||
prompt_str = (
|
||||
"<|im_start|>user\nHow are you<|im_end|>\n"
|
||||
"<|im_start|>assistant\nI am fine!<|im_end|>\n"
|
||||
|
Loading…
x
Reference in New Issue
Block a user