mirror of
				https://github.com/hiyouga/LLaMA-Factory.git
				synced 2025-11-04 18:02:19 +08:00 
			
		
		
		
	[data] optimize qwen3 loss computation (#7923)
This commit is contained in:
		
							parent
							
								
									73198a6645
								
							
						
					
					
						commit
						052ca871bd
					
				@ -5,7 +5,7 @@
 | 
			
		||||
[](https://github.com/hiyouga/LLaMA-Factory/graphs/contributors)
 | 
			
		||||
[](https://github.com/hiyouga/LLaMA-Factory/actions/workflows/tests.yml)
 | 
			
		||||
[](https://pypi.org/project/llamafactory/)
 | 
			
		||||
[](https://scholar.google.com/scholar?cites=12620864006390196564)
 | 
			
		||||
[](https://scholar.google.com/scholar?cites=12620864006390196564)
 | 
			
		||||
[](https://github.com/hiyouga/LLaMA-Factory/pulls)
 | 
			
		||||
 | 
			
		||||
[](https://twitter.com/llamafactory_ai)
 | 
			
		||||
@ -246,11 +246,11 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
 | 
			
		||||
| [Command R](https://huggingface.co/CohereForAI)                   | 35B/104B                         | cohere              |
 | 
			
		||||
| [DeepSeek (Code/MoE)](https://huggingface.co/deepseek-ai)         | 7B/16B/67B/236B                  | deepseek            |
 | 
			
		||||
| [DeepSeek 2.5/3](https://huggingface.co/deepseek-ai)              | 236B/671B                        | deepseek3           |
 | 
			
		||||
| [DeepSeek R1 (Distill)](https://huggingface.co/deepseek-ai)       | 1.5B/7B/8B/14B/32B/70B/671B      | deepseek3           |
 | 
			
		||||
| [DeepSeek R1 (Distill)](https://huggingface.co/deepseek-ai)       | 1.5B/7B/8B/14B/32B/70B/671B      | deepseekr1          |
 | 
			
		||||
| [Falcon](https://huggingface.co/tiiuae)                           | 7B/11B/40B/180B                  | falcon              |
 | 
			
		||||
| [Gemma/Gemma 2/CodeGemma](https://huggingface.co/google)          | 2B/7B/9B/27B                     | gemma               |
 | 
			
		||||
| [Gemma 3](https://huggingface.co/google)                          | 1B/4B/12B/27B                    | gemma3/gemma (1B)   |
 | 
			
		||||
| [GLM-4/GLM-4-0414/GLM-Z1](https://huggingface.co/THUDM)           | 9B/32B                           | glm4                |
 | 
			
		||||
| [GLM-4/GLM-4-0414/GLM-Z1](https://huggingface.co/THUDM)           | 9B/32B                           | glm4/glmz1          |
 | 
			
		||||
| [GPT-2](https://huggingface.co/openai-community)                  | 0.1B/0.4B/0.8B/1.5B              | -                   |
 | 
			
		||||
| [Granite 3.0-3.3](https://huggingface.co/ibm-granite)             | 1B/2B/3B/8B                      | granite3            |
 | 
			
		||||
| [Hunyuan](https://huggingface.co/tencent/)                        | 7B                               | hunyuan             |
 | 
			
		||||
 | 
			
		||||
@ -5,7 +5,7 @@
 | 
			
		||||
[](https://github.com/hiyouga/LLaMA-Factory/graphs/contributors)
 | 
			
		||||
[](https://github.com/hiyouga/LLaMA-Factory/actions/workflows/tests.yml)
 | 
			
		||||
[](https://pypi.org/project/llamafactory/)
 | 
			
		||||
[](https://scholar.google.com/scholar?cites=12620864006390196564)
 | 
			
		||||
[](https://scholar.google.com/scholar?cites=12620864006390196564)
 | 
			
		||||
[](https://github.com/hiyouga/LLaMA-Factory/pulls)
 | 
			
		||||
 | 
			
		||||
[](https://twitter.com/llamafactory_ai)
 | 
			
		||||
@ -249,11 +249,11 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
 | 
			
		||||
| [Command R](https://huggingface.co/CohereForAI)                   | 35B/104B                         | cohere              |
 | 
			
		||||
| [DeepSeek (Code/MoE)](https://huggingface.co/deepseek-ai)         | 7B/16B/67B/236B                  | deepseek            |
 | 
			
		||||
| [DeepSeek 2.5/3](https://huggingface.co/deepseek-ai)              | 236B/671B                        | deepseek3           |
 | 
			
		||||
| [DeepSeek R1 (Distill)](https://huggingface.co/deepseek-ai)       | 1.5B/7B/8B/14B/32B/70B/671B      | deepseek3           |
 | 
			
		||||
| [DeepSeek R1 (Distill)](https://huggingface.co/deepseek-ai)       | 1.5B/7B/8B/14B/32B/70B/671B      | deepseekr1          |
 | 
			
		||||
| [Falcon](https://huggingface.co/tiiuae)                           | 7B/11B/40B/180B                  | falcon              |
 | 
			
		||||
| [Gemma/Gemma 2/CodeGemma](https://huggingface.co/google)          | 2B/7B/9B/27B                     | gemma               |
 | 
			
		||||
| [Gemma 3](https://huggingface.co/google)                          | 1B/4B/12B/27B                    | gemma3/gemma (1B)   |
 | 
			
		||||
| [GLM-4/GLM-4-0414/GLM-Z1](https://huggingface.co/THUDM)           | 9B/32B                           | glm4                |
 | 
			
		||||
| [GLM-4/GLM-4-0414/GLM-Z1](https://huggingface.co/THUDM)           | 9B/32B                           | glm4/glmz1          |
 | 
			
		||||
| [GPT-2](https://huggingface.co/openai-community)                  | 0.1B/0.4B/0.8B/1.5B              | -                   |
 | 
			
		||||
| [Granite 3.0-3.3](https://huggingface.co/ibm-granite)             | 1B/2B/3B/8B                      | granite3            |
 | 
			
		||||
| [Hunyuan](https://huggingface.co/tencent/)                        | 7B                               | hunyuan             |
 | 
			
		||||
 | 
			
		||||
@ -103,9 +103,11 @@ class HuggingfaceEngine(BaseEngine):
 | 
			
		||||
        messages = template.mm_plugin.process_messages(
 | 
			
		||||
            messages, mm_input_dict["images"], mm_input_dict["videos"], mm_input_dict["audios"], processor
 | 
			
		||||
        )
 | 
			
		||||
        paired_messages = messages + [{"role": "assistant", "content": ""}]
 | 
			
		||||
        # add thought words to avoid skipping thinking
 | 
			
		||||
        paired_messages = messages + [{"role": "assistant", "content": template.add_thought("")}]
 | 
			
		||||
        system = system or generating_args["default_system"]
 | 
			
		||||
        prompt_ids, _ = template.encode_oneturn(tokenizer, paired_messages, system, tools)
 | 
			
		||||
        enable_thinking = input_kwargs.pop("enable_thinking", True)
 | 
			
		||||
        prompt_ids, _ = template.encode_oneturn(tokenizer, paired_messages, system, tools, enable_thinking)
 | 
			
		||||
        prompt_ids, _ = template.mm_plugin.process_token_ids(
 | 
			
		||||
            prompt_ids,
 | 
			
		||||
            None,
 | 
			
		||||
 | 
			
		||||
@ -146,9 +146,11 @@ class SGLangEngine(BaseEngine):
 | 
			
		||||
        messages = self.template.mm_plugin.process_messages(
 | 
			
		||||
            messages, images or [], videos or [], audios or [], self.processor
 | 
			
		||||
        )
 | 
			
		||||
        paired_messages = messages + [{"role": "assistant", "content": ""}]
 | 
			
		||||
        # add thought words to avoid skipping thinking
 | 
			
		||||
        paired_messages = messages + [{"role": "assistant", "content": self.template.add_thought("")}]
 | 
			
		||||
        system = system or self.generating_args["default_system"]
 | 
			
		||||
        prompt_ids, _ = self.template.encode_oneturn(self.tokenizer, paired_messages, system, tools)
 | 
			
		||||
        enable_thinking = input_kwargs.pop("enable_thinking", True)
 | 
			
		||||
        prompt_ids, _ = self.template.encode_oneturn(self.tokenizer, paired_messages, system, tools, enable_thinking)
 | 
			
		||||
        prompt_length = len(prompt_ids)
 | 
			
		||||
 | 
			
		||||
        temperature: Optional[float] = input_kwargs.pop("temperature", None)
 | 
			
		||||
 | 
			
		||||
@ -123,9 +123,11 @@ class VllmEngine(BaseEngine):
 | 
			
		||||
        messages = self.template.mm_plugin.process_messages(
 | 
			
		||||
            messages, images or [], videos or [], audios or [], self.processor
 | 
			
		||||
        )
 | 
			
		||||
        paired_messages = messages + [{"role": "assistant", "content": ""}]
 | 
			
		||||
        # add thought words to avoid skipping thinking
 | 
			
		||||
        paired_messages = messages + [{"role": "assistant", "content": self.template.add_thought("")}]
 | 
			
		||||
        system = system or self.generating_args["default_system"]
 | 
			
		||||
        prompt_ids, _ = self.template.encode_oneturn(self.tokenizer, paired_messages, system, tools)
 | 
			
		||||
        enable_thinking = input_kwargs.pop("enable_thinking", True)
 | 
			
		||||
        prompt_ids, _ = self.template.encode_oneturn(self.tokenizer, paired_messages, system, tools, enable_thinking)
 | 
			
		||||
        prompt_length = len(prompt_ids)
 | 
			
		||||
 | 
			
		||||
        temperature: Optional[float] = input_kwargs.pop("temperature", None)
 | 
			
		||||
 | 
			
		||||
@ -13,6 +13,7 @@
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
 | 
			
		||||
import re
 | 
			
		||||
from copy import deepcopy
 | 
			
		||||
from dataclasses import dataclass
 | 
			
		||||
from typing import TYPE_CHECKING, Optional, Union
 | 
			
		||||
 | 
			
		||||
@ -59,9 +60,10 @@ class Template:
 | 
			
		||||
        messages: list[dict[str, str]],
 | 
			
		||||
        system: Optional[str] = None,
 | 
			
		||||
        tools: Optional[str] = None,
 | 
			
		||||
        enable_thinking: bool = True,
 | 
			
		||||
    ) -> tuple[list[int], list[int]]:
 | 
			
		||||
        r"""Return a single pair of token ids representing prompt and response respectively."""
 | 
			
		||||
        encoded_messages = self._encode(tokenizer, messages, system, tools, remove_thought=True)
 | 
			
		||||
        encoded_messages = self._encode(tokenizer, messages, system, tools)
 | 
			
		||||
        prompt_ids = []
 | 
			
		||||
        for encoded_ids in encoded_messages[:-1]:
 | 
			
		||||
            prompt_ids += encoded_ids
 | 
			
		||||
@ -77,7 +79,7 @@ class Template:
 | 
			
		||||
        tools: Optional[str] = None,
 | 
			
		||||
    ) -> list[tuple[list[int], list[int]]]:
 | 
			
		||||
        r"""Return multiple pairs of token ids representing prompts and responses respectively."""
 | 
			
		||||
        encoded_messages = self._encode(tokenizer, messages, system, tools, remove_thought=False)
 | 
			
		||||
        encoded_messages = self._encode(tokenizer, messages, system, tools)
 | 
			
		||||
        return [(encoded_messages[i], encoded_messages[i + 1]) for i in range(0, len(encoded_messages), 2)]
 | 
			
		||||
 | 
			
		||||
    def extract_tool(self, content: str) -> Union[str, list["FunctionCall"]]:
 | 
			
		||||
@ -92,6 +94,19 @@ class Template:
 | 
			
		||||
 | 
			
		||||
        return list(stop_token_ids)
 | 
			
		||||
 | 
			
		||||
    def add_thought(self, content: str) -> str:
 | 
			
		||||
        r"""Add empty thought to assistant message."""
 | 
			
		||||
        return f"{self.thought_words[0]}\n\n{self.thought_words[1]}\n\n" + content
 | 
			
		||||
 | 
			
		||||
    def remove_thought(self, content: str) -> str:
 | 
			
		||||
        r"""Remove thought from assistant message."""
 | 
			
		||||
        pattern = re.compile(f"{re.escape(self.thought_words[0])}(.*?){re.escape(self.thought_words[1])}", re.DOTALL)
 | 
			
		||||
        return re.sub(pattern, "", content).lstrip("\n")
 | 
			
		||||
 | 
			
		||||
    def get_thought_word_ids(self, tokenizer: "PreTrainedTokenizer") -> list[int]:
 | 
			
		||||
        r"""Get the token ids of thought words."""
 | 
			
		||||
        return tokenizer.encode(f"{self.thought_words[0]}\n\n{self.thought_words[1]}\n\n", add_special_tokens=False)
 | 
			
		||||
 | 
			
		||||
    def _convert_elements_to_ids(self, tokenizer: "PreTrainedTokenizer", elements: "SLOTS") -> list[int]:
 | 
			
		||||
        r"""Convert elements to token ids."""
 | 
			
		||||
        token_ids = []
 | 
			
		||||
@ -111,18 +126,12 @@ class Template:
 | 
			
		||||
 | 
			
		||||
        return token_ids
 | 
			
		||||
 | 
			
		||||
    def _remove_thought(self, content: str) -> str:
 | 
			
		||||
        r"""Remove thought from assistant message."""
 | 
			
		||||
        pattern = re.compile(f"{re.escape(self.thought_words[0])}(.*?){re.escape(self.thought_words[1])}", re.DOTALL)
 | 
			
		||||
        return re.sub(pattern, "", content).lstrip("\n")
 | 
			
		||||
 | 
			
		||||
    def _encode(
 | 
			
		||||
        self,
 | 
			
		||||
        tokenizer: "PreTrainedTokenizer",
 | 
			
		||||
        messages: list[dict[str, str]],
 | 
			
		||||
        system: Optional[str],
 | 
			
		||||
        tools: Optional[str],
 | 
			
		||||
        remove_thought: bool,
 | 
			
		||||
    ) -> list[list[int]]:
 | 
			
		||||
        r"""Encode formatted inputs to pairs of token ids.
 | 
			
		||||
 | 
			
		||||
@ -140,18 +149,14 @@ class Template:
 | 
			
		||||
                    tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
 | 
			
		||||
                    elements += self.format_system.apply(content=(system + tool_text))
 | 
			
		||||
 | 
			
		||||
            content = message["content"]
 | 
			
		||||
            if remove_thought and message["role"] == Role.ASSISTANT and (i != len(messages) - 1):
 | 
			
		||||
                content = self._remove_thought(content)
 | 
			
		||||
 | 
			
		||||
            if message["role"] == Role.USER:
 | 
			
		||||
                elements += self.format_user.apply(content=content, idx=str(i // 2))
 | 
			
		||||
                elements += self.format_user.apply(content=message["content"], idx=str(i // 2))
 | 
			
		||||
            elif message["role"] == Role.ASSISTANT:
 | 
			
		||||
                elements += self.format_assistant.apply(content=content)
 | 
			
		||||
                elements += self.format_assistant.apply(content=message["content"])
 | 
			
		||||
            elif message["role"] == Role.OBSERVATION:
 | 
			
		||||
                elements += self.format_observation.apply(content=content)
 | 
			
		||||
                elements += self.format_observation.apply(content=message["content"])
 | 
			
		||||
            elif message["role"] == Role.FUNCTION:
 | 
			
		||||
                elements += self.format_function.apply(content=content)
 | 
			
		||||
                elements += self.format_function.apply(content=message["content"])
 | 
			
		||||
            else:
 | 
			
		||||
                raise NotImplementedError("Unexpected role: {}".format(message["role"]))
 | 
			
		||||
 | 
			
		||||
@ -331,7 +336,6 @@ class Llama2Template(Template):
 | 
			
		||||
        messages: list[dict[str, str]],
 | 
			
		||||
        system: str,
 | 
			
		||||
        tools: str,
 | 
			
		||||
        remove_thought: bool,
 | 
			
		||||
    ) -> list[list[int]]:
 | 
			
		||||
        system = system or self.default_system
 | 
			
		||||
        encoded_messages = []
 | 
			
		||||
@ -345,18 +349,14 @@ class Llama2Template(Template):
 | 
			
		||||
                    tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
 | 
			
		||||
                    system_text = self.format_system.apply(content=(system + tool_text))[0]
 | 
			
		||||
 | 
			
		||||
            content = message["content"]
 | 
			
		||||
            if remove_thought and message["role"] == Role.ASSISTANT and (i != len(messages) - 1):
 | 
			
		||||
                content = self._remove_thought(content)
 | 
			
		||||
 | 
			
		||||
            if message["role"] == Role.USER:
 | 
			
		||||
                elements += self.format_user.apply(content=system_text + content)
 | 
			
		||||
                elements += self.format_user.apply(content=system_text + message["content"])
 | 
			
		||||
            elif message["role"] == Role.ASSISTANT:
 | 
			
		||||
                elements += self.format_assistant.apply(content=content)
 | 
			
		||||
                elements += self.format_assistant.apply(content=message["content"])
 | 
			
		||||
            elif message["role"] == Role.OBSERVATION:
 | 
			
		||||
                elements += self.format_observation.apply(content=content)
 | 
			
		||||
                elements += self.format_observation.apply(content=message["content"])
 | 
			
		||||
            elif message["role"] == Role.FUNCTION:
 | 
			
		||||
                elements += self.format_function.apply(content=content)
 | 
			
		||||
                elements += self.format_function.apply(content=message["content"])
 | 
			
		||||
            else:
 | 
			
		||||
                raise NotImplementedError("Unexpected role: {}".format(message["role"]))
 | 
			
		||||
 | 
			
		||||
@ -395,6 +395,60 @@ class Llama2Template(Template):
 | 
			
		||||
        return jinja_template
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@dataclass
 | 
			
		||||
class ReasoningTemplate(Template):
 | 
			
		||||
    r"""A template that add thought to assistant message."""
 | 
			
		||||
 | 
			
		||||
    @override
 | 
			
		||||
    def encode_oneturn(
 | 
			
		||||
        self,
 | 
			
		||||
        tokenizer: "PreTrainedTokenizer",
 | 
			
		||||
        messages: list[dict[str, str]],
 | 
			
		||||
        system: Optional[str] = None,
 | 
			
		||||
        tools: Optional[str] = None,
 | 
			
		||||
        enable_thinking: bool = True,
 | 
			
		||||
    ) -> tuple[list[int], list[int]]:
 | 
			
		||||
        messages = deepcopy(messages)
 | 
			
		||||
        for i in range(len(messages)):
 | 
			
		||||
            if messages[i]["role"] == Role.ASSISTANT and (i != len(messages) - 1):
 | 
			
		||||
                messages[i]["content"] = self.remove_thought(messages[i]["content"])
 | 
			
		||||
 | 
			
		||||
        encoded_messages = self._encode(tokenizer, messages, system, tools)
 | 
			
		||||
        prompt_ids = []
 | 
			
		||||
        for encoded_ids in encoded_messages[:-1]:
 | 
			
		||||
            prompt_ids += encoded_ids
 | 
			
		||||
 | 
			
		||||
        if not enable_thinking or (
 | 
			
		||||
            messages[-1]["role"] == Role.ASSISTANT
 | 
			
		||||
            and self.thought_words[0] not in messages[-1]["content"]
 | 
			
		||||
            and self.thought_words[1] not in messages[-1]["content"]
 | 
			
		||||
        ):
 | 
			
		||||
            prompt_ids += self.get_thought_word_ids(tokenizer)
 | 
			
		||||
 | 
			
		||||
        response_ids = encoded_messages[-1]
 | 
			
		||||
        return prompt_ids, response_ids
 | 
			
		||||
 | 
			
		||||
    @override
 | 
			
		||||
    def encode_multiturn(
 | 
			
		||||
        self,
 | 
			
		||||
        tokenizer: "PreTrainedTokenizer",
 | 
			
		||||
        messages: list[dict[str, str]],
 | 
			
		||||
        system: Optional[str] = None,
 | 
			
		||||
        tools: Optional[str] = None,
 | 
			
		||||
    ) -> list[tuple[list[int], list[int]]]:
 | 
			
		||||
        messages = deepcopy(messages)
 | 
			
		||||
        encoded_messages = self._encode(tokenizer, messages, system, tools)
 | 
			
		||||
        for i in range(len(messages) - 1):
 | 
			
		||||
            if (
 | 
			
		||||
                messages[i + 1]["role"] == Role.ASSISTANT
 | 
			
		||||
                and self.thought_words[0] not in messages[i + 1]["content"]
 | 
			
		||||
                and self.thought_words[1] not in messages[i + 1]["content"]
 | 
			
		||||
            ):
 | 
			
		||||
                encoded_messages[i] += self.get_thought_word_ids(tokenizer)
 | 
			
		||||
 | 
			
		||||
        return [(encoded_messages[i], encoded_messages[i + 1]) for i in range(0, len(encoded_messages), 2)]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
TEMPLATES: dict[str, "Template"] = {}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -778,6 +832,15 @@ register_template(
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# copied from deepseek3 template
 | 
			
		||||
register_template(
 | 
			
		||||
    name="deepseekr1",
 | 
			
		||||
    format_user=StringFormatter(slots=["<|User|>{{content}}<|Assistant|>"]),
 | 
			
		||||
    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
 | 
			
		||||
    template_class=ReasoningTemplate,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
register_template(
 | 
			
		||||
    name="deepseekcoder",
 | 
			
		||||
    format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n### Response:"]),
 | 
			
		||||
@ -878,6 +941,22 @@ register_template(
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# copied from glm4 template
 | 
			
		||||
register_template(
 | 
			
		||||
    name="glmz1",
 | 
			
		||||
    format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]),
 | 
			
		||||
    format_assistant=StringFormatter(slots=["\n{{content}}"]),
 | 
			
		||||
    format_system=StringFormatter(slots=["<|system|>\n{{content}}"]),
 | 
			
		||||
    format_function=FunctionFormatter(slots=["{{content}}"], tool_format="glm4"),
 | 
			
		||||
    format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>"]),
 | 
			
		||||
    format_tools=ToolFormatter(tool_format="glm4"),
 | 
			
		||||
    format_prefix=EmptyFormatter(slots=["[gMASK]<sop>"]),
 | 
			
		||||
    stop_words=["<|user|>", "<|observation|>"],
 | 
			
		||||
    efficient_eos=True,
 | 
			
		||||
    template_class=ReasoningTemplate,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
register_template(
 | 
			
		||||
    name="granite3",
 | 
			
		||||
    format_user=StringFormatter(
 | 
			
		||||
@ -1458,6 +1537,7 @@ register_template(
 | 
			
		||||
    format_tools=ToolFormatter(tool_format="qwen"),
 | 
			
		||||
    stop_words=["<|im_end|>"],
 | 
			
		||||
    replace_eos=True,
 | 
			
		||||
    template_class=ReasoningTemplate,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -533,6 +533,17 @@ register_model_group(
 | 
			
		||||
            DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V3",
 | 
			
		||||
            DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V3",
 | 
			
		||||
        },
 | 
			
		||||
        "DeepSeek-V3-671B-0324-Chat": {
 | 
			
		||||
            DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V3-0324",
 | 
			
		||||
            DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V3-0324",
 | 
			
		||||
        },
 | 
			
		||||
    },
 | 
			
		||||
    template="deepseek3",
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
register_model_group(
 | 
			
		||||
    models={
 | 
			
		||||
        "DeepSeek-R1-1.5B-Distill": {
 | 
			
		||||
            DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
 | 
			
		||||
            DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
 | 
			
		||||
@ -566,7 +577,7 @@ register_model_group(
 | 
			
		||||
            DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-R1",
 | 
			
		||||
        },
 | 
			
		||||
    },
 | 
			
		||||
    template="deepseek3",
 | 
			
		||||
    template="deepseekr1",
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -737,6 +748,13 @@ register_model_group(
 | 
			
		||||
            DownloadSource.DEFAULT: "THUDM/GLM-4-32B-0414",
 | 
			
		||||
            DownloadSource.MODELSCOPE: "ZhipuAI/GLM-4-32B-0414",
 | 
			
		||||
        },
 | 
			
		||||
    },
 | 
			
		||||
    template="glm4",
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
register_model_group(
 | 
			
		||||
    models={
 | 
			
		||||
        "GLM-Z1-9B-0414-Chat": {
 | 
			
		||||
            DownloadSource.DEFAULT: "THUDM/GLM-Z1-9B-0414",
 | 
			
		||||
            DownloadSource.MODELSCOPE: "ZhipuAI/GLM-Z1-9B-0414",
 | 
			
		||||
@ -746,7 +764,7 @@ register_model_group(
 | 
			
		||||
            DownloadSource.MODELSCOPE: "ZhipuAI/GLM-Z1-32B-0414",
 | 
			
		||||
        },
 | 
			
		||||
    },
 | 
			
		||||
    template="glm4",
 | 
			
		||||
    template="glmz1",
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -191,6 +191,7 @@ class WebChatModel(ChatModel):
 | 
			
		||||
        temperature: float,
 | 
			
		||||
        skip_special_tokens: bool,
 | 
			
		||||
        escape_html: bool,
 | 
			
		||||
        enable_thinking: bool,
 | 
			
		||||
    ) -> Generator[tuple[list[dict[str, str]], list[dict[str, str]]], None, None]:
 | 
			
		||||
        r"""Generate output text in stream.
 | 
			
		||||
 | 
			
		||||
@ -210,6 +211,7 @@ class WebChatModel(ChatModel):
 | 
			
		||||
            top_p=top_p,
 | 
			
		||||
            temperature=temperature,
 | 
			
		||||
            skip_special_tokens=skip_special_tokens,
 | 
			
		||||
            enable_thinking=enable_thinking,
 | 
			
		||||
        ):
 | 
			
		||||
            response += new_text
 | 
			
		||||
            if tools:
 | 
			
		||||
 | 
			
		||||
@ -79,6 +79,7 @@ def create_chat_box(
 | 
			
		||||
                temperature = gr.Slider(minimum=0.01, maximum=1.5, value=0.95, step=0.01)
 | 
			
		||||
                skip_special_tokens = gr.Checkbox(value=True)
 | 
			
		||||
                escape_html = gr.Checkbox(value=True)
 | 
			
		||||
                enable_thinking = gr.Checkbox(value=True)
 | 
			
		||||
                clear_btn = gr.Button()
 | 
			
		||||
 | 
			
		||||
    tools.input(check_json_schema, inputs=[tools, engine.manager.get_elem_by_id("top.lang")])
 | 
			
		||||
@ -103,6 +104,7 @@ def create_chat_box(
 | 
			
		||||
            temperature,
 | 
			
		||||
            skip_special_tokens,
 | 
			
		||||
            escape_html,
 | 
			
		||||
            enable_thinking,
 | 
			
		||||
        ],
 | 
			
		||||
        [chatbot, messages],
 | 
			
		||||
    )
 | 
			
		||||
@ -127,6 +129,7 @@ def create_chat_box(
 | 
			
		||||
            temperature=temperature,
 | 
			
		||||
            skip_special_tokens=skip_special_tokens,
 | 
			
		||||
            escape_html=escape_html,
 | 
			
		||||
            enable_thinking=enable_thinking,
 | 
			
		||||
            clear_btn=clear_btn,
 | 
			
		||||
        ),
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
@ -2468,6 +2468,23 @@ LOCALES = {
 | 
			
		||||
            "label": "HTML タグをエスケープ",
 | 
			
		||||
        },
 | 
			
		||||
    },
 | 
			
		||||
    "enable_thinking": {
 | 
			
		||||
        "en": {
 | 
			
		||||
            "label": "Enable thinking",
 | 
			
		||||
        },
 | 
			
		||||
        "ru": {
 | 
			
		||||
            "label": "Включить мышление",
 | 
			
		||||
        },
 | 
			
		||||
        "zh": {
 | 
			
		||||
            "label": "启用思考",
 | 
			
		||||
        },
 | 
			
		||||
        "ko": {
 | 
			
		||||
            "label": "사고를 활성화하다",
 | 
			
		||||
        },
 | 
			
		||||
        "ja": {
 | 
			
		||||
            "label": "思考を可能にする",
 | 
			
		||||
        },
 | 
			
		||||
    },
 | 
			
		||||
    "clear_btn": {
 | 
			
		||||
        "en": {
 | 
			
		||||
            "value": "Clear history",
 | 
			
		||||
 | 
			
		||||
@ -125,6 +125,37 @@ def test_encode_multiturn(use_fast: bool):
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@pytest.mark.parametrize("use_fast", [True, False])
 | 
			
		||||
def test_reasoning_encode_oneturn(use_fast: bool):
 | 
			
		||||
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B", use_fast=use_fast)
 | 
			
		||||
    template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template="qwen3"))
 | 
			
		||||
    prompt_ids, answer_ids = template.encode_oneturn(tokenizer, MESSAGES)
 | 
			
		||||
    prompt_str = (
 | 
			
		||||
        "<|im_start|>user\nHow are you<|im_end|>\n"
 | 
			
		||||
        "<|im_start|>assistant\nI am fine!<|im_end|>\n"
 | 
			
		||||
        "<|im_start|>user\n你好<|im_end|>\n"
 | 
			
		||||
        "<|im_start|>assistant\n<think>\n\n</think>\n\n"
 | 
			
		||||
    )
 | 
			
		||||
    answer_str = "很高兴认识你!<|im_end|>\n"
 | 
			
		||||
    _check_tokenization(tokenizer, (prompt_ids, answer_ids), (prompt_str, answer_str))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@pytest.mark.parametrize("use_fast", [True, False])
 | 
			
		||||
def test_reasoning_encode_multiturn(use_fast: bool):
 | 
			
		||||
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B", use_fast=use_fast)
 | 
			
		||||
    template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template="qwen3"))
 | 
			
		||||
    encoded_pairs = template.encode_multiturn(tokenizer, MESSAGES)
 | 
			
		||||
    prompt_str_1 = "<|im_start|>user\nHow are you<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
 | 
			
		||||
    answer_str_1 = "I am fine!<|im_end|>\n"
 | 
			
		||||
    prompt_str_2 = "<|im_start|>user\n你好<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
 | 
			
		||||
    answer_str_2 = "很高兴认识你!<|im_end|>\n"
 | 
			
		||||
    _check_tokenization(
 | 
			
		||||
        tokenizer,
 | 
			
		||||
        (encoded_pairs[0][0], encoded_pairs[0][1], encoded_pairs[1][0], encoded_pairs[1][1]),
 | 
			
		||||
        (prompt_str_1, answer_str_1, prompt_str_2, answer_str_2),
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@pytest.mark.parametrize("use_fast", [True, False])
 | 
			
		||||
def test_jinja_template(use_fast: bool):
 | 
			
		||||
    tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3, use_fast=use_fast)
 | 
			
		||||
@ -227,6 +258,15 @@ def test_qwen2_5_template(use_fast: bool):
 | 
			
		||||
 | 
			
		||||
@pytest.mark.parametrize("use_fast", [True, False])
 | 
			
		||||
def test_qwen3_template(use_fast: bool):
 | 
			
		||||
    prompt_str = (
 | 
			
		||||
        "<|im_start|>user\nHow are you<|im_end|>\n"
 | 
			
		||||
        "<|im_start|>assistant\nI am fine!<|im_end|>\n"
 | 
			
		||||
        "<|im_start|>user\n你好<|im_end|>\n"
 | 
			
		||||
        "<|im_start|>assistant\n<think>\n\n</think>\n\n"
 | 
			
		||||
    )
 | 
			
		||||
    answer_str = "很高兴认识你!<|im_end|>\n"
 | 
			
		||||
    _check_template("Qwen/Qwen3-8B", "qwen3", prompt_str, answer_str, use_fast)
 | 
			
		||||
 | 
			
		||||
    prompt_str = (
 | 
			
		||||
        "<|im_start|>user\nHow are you<|im_end|>\n"
 | 
			
		||||
        "<|im_start|>assistant\nI am fine!<|im_end|>\n"
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user