[model] add qwen3 (#7885)
commit ae392e054c
parent 369474451d
@@ -87,7 +87,7 @@ Choose your path:

 | Support Date | Model Name |
 | ------------ | ------------------------------------------------------------ |
-| Day 0 | Qwen2.5 / Qwen2.5-VL / Gemma 3 / InternLM 3 / MiniCPM-o-2.6 |
+| Day 0 | Qwen3 / Qwen2.5-VL / Gemma 3 / InternLM 3 / MiniCPM-o-2.6 |
 | Day 1 | Llama 3 / GLM-4 / Mistral Small / PaliGemma2 / Llama 4 |

 ## Benchmark

@@ -107,6 +107,8 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/

 ## Changelog

+[25/04/28] We supported fine-tuning the **[Qwen3](https://qwenlm.github.io/blog/qwen3/)** model family.
+
 [25/04/21] We supported the **[Muon](https://github.com/KellerJordan/Muon)** optimizer. See [examples](examples/README.md) for usage. Thanks to [@tianshijing](https://github.com/tianshijing)'s PR.

 [25/04/16] We supported fine-tuning the **[InternVL3](https://huggingface.co/OpenGVLab/InternVL3-8B)** model. See [PR #7258](https://github.com/hiyouga/LLaMA-Factory/pull/7258) to get started.

@@ -115,10 +117,10 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/

 [25/04/06] We supported fine-tuning the **[Llama 4](https://ai.meta.com/blog/llama-4-multimodal-intelligence/)** model. See [PR #7611](https://github.com/hiyouga/LLaMA-Factory/pull/7611) to get started.

-[25/03/31] We supported fine-tuning the **[Qwen2.5 Omni](https://qwenlm.github.io/blog/qwen2.5-omni/)** model. See [PR #7537](https://github.com/hiyouga/LLaMA-Factory/pull/7537) to get started.
-
 <details><summary>Full Changelog</summary>

+[25/03/31] We supported fine-tuning the **[Qwen2.5 Omni](https://qwenlm.github.io/blog/qwen2.5-omni/)** model. See [PR #7537](https://github.com/hiyouga/LLaMA-Factory/pull/7537) to get started.
+
 [25/03/15] We supported **[SGLang](https://github.com/sgl-project/sglang)** as an inference backend. Try `infer_backend: sglang` to accelerate inference.

 [25/03/12] We supported fine-tuning the **[Gemma 3](https://huggingface.co/blog/gemma3)** model.

@@ -274,6 +276,7 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/

 | [Phi-4](https://huggingface.co/microsoft) | 14B | phi4 |
 | [Pixtral](https://huggingface.co/mistralai) | 12B | pixtral |
 | [Qwen (1-2.5) (Code/Math/MoE/QwQ)](https://huggingface.co/Qwen) | 0.5B/1.5B/3B/7B/14B/32B/72B/110B | qwen |
+| [Qwen3 (MoE)](https://huggingface.co/Qwen) | 0.6B/1.7B/4B/8B/14B/32B/235B | qwen3 |
 | [Qwen2-Audio](https://huggingface.co/Qwen) | 7B | qwen2_audio |
 | [Qwen2.5-Omni](https://huggingface.co/Qwen)\*\* | 7B | qwen2_omni |
 | [Qwen2-VL/Qwen2.5-VL/QVQ](https://huggingface.co/Qwen) | 2B/3B/7B/32B/72B | qwen2_vl |
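The [25/03/15] entry above points at the `infer_backend: sglang` option. As a minimal, hedged sketch (not part of this commit), the same switch can be made through the Python chat API; the model id, template name, and exact argument handling below are assumptions to verify against the installed version:

```python
# A minimal sketch, not part of this commit: enabling the SGLang backend named in the
# [25/03/15] changelog entry via LLaMA-Factory's Python chat API.
# The model id and template below are illustrative assumptions.
from llamafactory.chat import ChatModel

chat_model = ChatModel({
    "model_name_or_path": "Qwen/Qwen3-8B",  # assumed model choice
    "template": "qwen3",                    # template registered by this commit
    "infer_backend": "sglang",              # option mentioned in the changelog entry
})

messages = [{"role": "user", "content": "How are you"}]
print(chat_model.chat(messages)[0].response_text)
```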
@@ -90,7 +90,7 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc

 | Support Date | Model Name |
 | ------------ | ------------------------------------------------------------ |
-| Day 0 | Qwen2.5 / Qwen2.5-VL / Gemma 3 / InternLM 3 / MiniCPM-o-2.6 |
+| Day 0 | Qwen3 / Qwen2.5-VL / Gemma 3 / InternLM 3 / MiniCPM-o-2.6 |
 | Day 1 | Llama 3 / GLM-4 / Mistral Small / PaliGemma2 / Llama 4 |

 ## Benchmark

@@ -110,6 +110,8 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc

 ## Changelog

+[25/04/28] We supported fine-tuning the **[Qwen3](https://qwenlm.github.io/blog/qwen3/)** model family.
+
 [25/04/21] We supported the **[Muon](https://github.com/KellerJordan/Muon)** optimizer. See [examples](examples/README_zh.md) for usage. Thanks to [@tianshijing](https://github.com/tianshijing)'s PR.

 [25/04/16] We supported fine-tuning the **[InternVL3](https://huggingface.co/OpenGVLab/InternVL3-8B)** model. See [PR #7258](https://github.com/hiyouga/LLaMA-Factory/pull/7258) to get started.

@@ -118,10 +120,10 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc

 [25/04/06] We supported fine-tuning the **[Llama 4](https://ai.meta.com/blog/llama-4-multimodal-intelligence/)** model. See [PR #7611](https://github.com/hiyouga/LLaMA-Factory/pull/7611) to get started.

-[25/03/31] We supported fine-tuning the **[Qwen2.5 Omni](https://qwenlm.github.io/blog/qwen2.5-omni/)** model. See [PR #7537](https://github.com/hiyouga/LLaMA-Factory/pull/7537) to get started.
-
 <details><summary>Full Changelog</summary>

+[25/03/31] We supported fine-tuning the **[Qwen2.5 Omni](https://qwenlm.github.io/blog/qwen2.5-omni/)** model. See [PR #7537](https://github.com/hiyouga/LLaMA-Factory/pull/7537) to get started.
+
 [25/03/15] We supported **[SGLang](https://github.com/sgl-project/sglang)** as an inference backend. Use `infer_backend: sglang` to enable it.

 [25/03/12] We supported fine-tuning the **[Gemma 3](https://huggingface.co/blog/gemma3)** model.

@@ -277,6 +279,7 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc

 | [Phi-4](https://huggingface.co/microsoft) | 14B | phi4 |
 | [Pixtral](https://huggingface.co/mistralai) | 12B | pixtral |
 | [Qwen (1-2.5) (Code/Math/MoE/QwQ)](https://huggingface.co/Qwen) | 0.5B/1.5B/3B/7B/14B/32B/72B/110B | qwen |
+| [Qwen3 (MoE)](https://huggingface.co/Qwen) | 0.6B/1.7B/4B/8B/14B/32B/235B | qwen3 |
 | [Qwen2-Audio](https://huggingface.co/Qwen) | 7B | qwen2_audio |
 | [Qwen2.5-Omni](https://huggingface.co/Qwen)\*\* | 7B | qwen2_omni |
 | [Qwen2-VL/Qwen2.5-VL/QVQ](https://huggingface.co/Qwen) | 2B/3B/7B/32B/72B | qwen2_vl |
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import re
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Optional, Union

@@ -60,7 +61,7 @@ class Template:
         tools: Optional[str] = None,
     ) -> tuple[list[int], list[int]]:
         r"""Return a single pair of token ids representing prompt and response respectively."""
-        encoded_messages = self._encode(tokenizer, messages, system, tools)
+        encoded_messages = self._encode(tokenizer, messages, system, tools, remove_thought=True)
         prompt_ids = []
         for encoded_ids in encoded_messages[:-1]:
             prompt_ids += encoded_ids

@@ -76,7 +77,7 @@ class Template:
         tools: Optional[str] = None,
     ) -> list[tuple[list[int], list[int]]]:
         r"""Return multiple pairs of token ids representing prompts and responses respectively."""
-        encoded_messages = self._encode(tokenizer, messages, system, tools)
+        encoded_messages = self._encode(tokenizer, messages, system, tools, remove_thought=False)
         return [(encoded_messages[i], encoded_messages[i + 1]) for i in range(0, len(encoded_messages), 2)]

     def extract_tool(self, content: str) -> Union[str, list["FunctionCall"]]:

@@ -110,12 +111,18 @@ class Template:

         return token_ids

+    def _remove_thought(self, content: str) -> str:
+        r"""Remove thought from assistant message."""
+        pattern = re.compile(f"{re.escape(self.thought_words[0])}(.*?){re.escape(self.thought_words[1])}", re.DOTALL)
+        return re.sub(pattern, "", content).lstrip("\n")
+
     def _encode(
         self,
         tokenizer: "PreTrainedTokenizer",
         messages: list[dict[str, str]],
         system: Optional[str],
         tools: Optional[str],
+        remove_thought: bool,
     ) -> list[list[int]]:
         r"""Encode formatted inputs to pairs of token ids.

@@ -133,14 +140,18 @@ class Template:
                     tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
                     elements += self.format_system.apply(content=(system + tool_text))

-            if message["role"] == Role.USER.value:
-                elements += self.format_user.apply(content=message["content"], idx=str(i // 2))
-            elif message["role"] == Role.ASSISTANT.value:
-                elements += self.format_assistant.apply(content=message["content"])
-            elif message["role"] == Role.OBSERVATION.value:
-                elements += self.format_observation.apply(content=message["content"])
-            elif message["role"] == Role.FUNCTION.value:
-                elements += self.format_function.apply(content=message["content"])
+            content = message["content"]
+            if remove_thought and message["role"] == Role.ASSISTANT and (i != len(messages) - 1):
+                content = self._remove_thought(content)
+
+            if message["role"] == Role.USER:
+                elements += self.format_user.apply(content=content, idx=str(i // 2))
+            elif message["role"] == Role.ASSISTANT:
+                elements += self.format_assistant.apply(content=content)
+            elif message["role"] == Role.OBSERVATION:
+                elements += self.format_observation.apply(content=content)
+            elif message["role"] == Role.FUNCTION:
+                elements += self.format_function.apply(content=content)
             else:
                 raise NotImplementedError("Unexpected role: {}".format(message["role"]))

@@ -317,6 +328,7 @@ class Llama2Template(Template):
         messages: list[dict[str, str]],
         system: str,
         tools: str,
+        remove_thought: bool,
     ) -> list[list[int]]:
         system = system or self.default_system
         encoded_messages = []

@@ -330,14 +342,18 @@ class Llama2Template(Template):
                     tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
                     system_text = self.format_system.apply(content=(system + tool_text))[0]

-            if message["role"] == Role.USER.value:
-                elements += self.format_user.apply(content=system_text + message["content"])
-            elif message["role"] == Role.ASSISTANT.value:
-                elements += self.format_assistant.apply(content=message["content"])
-            elif message["role"] == Role.OBSERVATION.value:
-                elements += self.format_observation.apply(content=message["content"])
-            elif message["role"] == Role.FUNCTION.value:
-                elements += self.format_function.apply(content=message["content"])
+            content = message["content"]
+            if remove_thought and message["role"] == Role.ASSISTANT and (i != len(messages) - 1):
+                content = self._remove_thought(content)
+
+            if message["role"] == Role.USER:
+                elements += self.format_user.apply(content=system_text + content)
+            elif message["role"] == Role.ASSISTANT:
+                elements += self.format_assistant.apply(content=content)
+            elif message["role"] == Role.OBSERVATION:
+                elements += self.format_observation.apply(content=content)
+            elif message["role"] == Role.FUNCTION:
+                elements += self.format_function.apply(content=content)
             else:
                 raise NotImplementedError("Unexpected role: {}".format(message["role"]))

@@ -476,6 +492,7 @@ def parse_template(tokenizer: "PreTrainedTokenizer") -> "Template":
     messages = [{"role": "user", "content": "{{content}}"}, {"role": "assistant", "content": "{{content}}"}]
     assistant_slot = tokenizer.apply_chat_template(messages, add_generation_prompt=False, tokenize=False)
     assistant_slot = assistant_slot[len(prefix) + len(user_slot) :]
+    assistant_slot = assistant_slot.replace("<think>", "").replace("</think>", "").lstrip("\n")  # remove thought tags

     if len(user_slot) > len(user_slot_empty_system):
         default_system = find_diff(user_slot_empty_system, user_slot)

@@ -1411,6 +1428,21 @@ register_template(
 )


+# copied from qwen template
+register_template(
+    name="qwen3",
+    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
+    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
+    format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen"),
+    format_observation=StringFormatter(
+        slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
+    ),
+    format_tools=ToolFormatter(tool_format="qwen"),
+    stop_words=["<|im_end|>"],
+)
+
+
 # copied from chatml template
 register_template(
     name="qwen2_audio",
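The `_remove_thought` helper added above strips `<think>…</think>` blocks from non-final assistant turns when `remove_thought=True` (the `encode_oneturn` path), while `encode_multiturn` keeps them. A standalone sketch of the regex behaviour, assuming `thought_words` is `("<think>", "</think>")` as in the Qwen3 test data:

```python
# Standalone sketch of the thought-stripping regex used by Template._remove_thought,
# assuming thought_words == ("<think>", "</think>").
import re

thought_words = ("<think>", "</think>")
pattern = re.compile(f"{re.escape(thought_words[0])}(.*?){re.escape(thought_words[1])}", re.DOTALL)

content = "<think>\nModel thought here\n</think>\n\nI am fine!"
print(re.sub(pattern, "", content).lstrip("\n"))  # prints: I am fine!
```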
@@ -2403,6 +2403,69 @@ register_model_group(
 )


+register_model_group(
+    models={
+        "Qwen3-0.6B-Base": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3-0.6B-Base",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3-0.6B-Base",
+        },
+        "Qwen3-1.7B-Base": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3-1.7B-Base",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3-1.7B-Base",
+        },
+        "Qwen3-4B-Base": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3-4B-Base",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3-4B-Base",
+        },
+        "Qwen3-8B-Base": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3-8B-Base",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3-8B-Base",
+        },
+        "Qwen3-14B-Base": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3-14B-Base",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3-14B-Base",
+        },
+        "Qwen3-30B-A3B-Base": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3-30B-A3B-Base",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3-30B-A3B-Base",
+        },
+        "Qwen3-0.6B-Instruct": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3-0.6B",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3-0.6B",
+        },
+        "Qwen3-1.7B-Instruct": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3-1.7B",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3-1.7B",
+        },
+        "Qwen3-4B-Instruct": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3-4B",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3-4B",
+        },
+        "Qwen3-8B-Instruct": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3-8B",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3-8B",
+        },
+        "Qwen3-14B-Instruct": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3-14B",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3-14B",
+        },
+        "Qwen3-32B-Instruct": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3-32B",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3-32B",
+        },
+        "Qwen3-30B-A3B-Instruct": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3-30B-A3B",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3-30B-A3B",
+        },
+        "Qwen3-235B-A22B-Instruct": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3-235B-A22B",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3-235B-A22B",
+        },
+    },
+    template="qwen3",
+)
+
+
 register_model_group(
     models={
         "Qwen2-Audio-7B": {
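With the Qwen3 checkpoints registered above under `template="qwen3"`, the template can be resolved and applied the same way the test helpers below do. A minimal sketch; the checkpoint id is an illustrative assumption and the import paths are assumed from the package layout:

```python
# Minimal sketch of exercising the newly registered "qwen3" template, mirroring the
# helpers used in the tests below; the checkpoint id is an assumption.
from transformers import AutoTokenizer

from llamafactory.data import get_template_and_fix_tokenizer
from llamafactory.hparams import DataArguments

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template="qwen3"))

messages = [
    {"role": "user", "content": "How are you"},
    {"role": "assistant", "content": "I am fine!"},
]
prompt_ids, answer_ids = template.encode_oneturn(tokenizer, messages)
print(tokenizer.decode(prompt_ids))  # roughly "<|im_start|>user\nHow are you<|im_end|>\n<|im_start|>assistant\n"
print(tokenizer.decode(answer_ids))  # roughly "I am fine!<|im_end|>\n"
```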
@@ -56,11 +56,11 @@ def can_quantize_to(quantization_method: str) -> "gr.Dropdown":
     Inputs: top.quantization_method
     Outputs: top.quantization_bit
     """
-    if quantization_method == QuantizationMethod.BNB.value:
+    if quantization_method == QuantizationMethod.BNB:
         available_bits = ["none", "8", "4"]
-    elif quantization_method == QuantizationMethod.HQQ.value:
+    elif quantization_method == QuantizationMethod.HQQ:
         available_bits = ["none", "8", "6", "5", "4", "3", "2", "1"]
-    elif quantization_method == QuantizationMethod.EETQ.value:
+    elif quantization_method == QuantizationMethod.EETQ:
         available_bits = ["none", "8"]

     return gr.Dropdown(choices=available_bits)
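The `can_quantize_to` change above drops the `.value` suffix when comparing against `QuantizationMethod` members. A minimal sketch of why the comparison still holds if `QuantizationMethod` is a `str`-backed enum; the member values below are assumptions for illustration only:

```python
# Sketch: a str-backed enum compares equal to plain strings, so dropping ".value" is
# safe if QuantizationMethod is defined this way. Member values here are assumptions.
from enum import Enum


class QuantizationMethod(str, Enum):
    BNB = "bitsandbytes"
    HQQ = "hqq"
    EETQ = "eetq"


assert QuantizationMethod.BNB == "bitsandbytes"
assert "hqq" == QuantizationMethod.HQQ
print("str-enum members compare equal to plain strings without .value")
```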
@@ -39,6 +39,13 @@ MESSAGES = [
     {"role": "assistant", "content": "很高兴认识你!"},
 ]

+MESSAGES_WITH_THOUGHT = [
+    {"role": "user", "content": "How are you"},
+    {"role": "assistant", "content": "<think>\nModel thought here\n</think>\n\nI am fine!"},
+    {"role": "user", "content": "你好"},
+    {"role": "assistant", "content": "<think>\n模型思考内容\n</think>\n\n很高兴认识你!"},
+]
+

 def _check_tokenization(
     tokenizer: "PreTrainedTokenizer", batch_input_ids: list[list[int]], batch_text: list[str]

@@ -53,7 +60,14 @@ def _check_tokenization(
         assert tokenizer.decode(input_ids) == text


-def _check_template(model_id: str, template_name: str, prompt_str: str, answer_str: str, use_fast: bool) -> None:
+def _check_template(
+    model_id: str,
+    template_name: str,
+    prompt_str: str,
+    answer_str: str,
+    use_fast: bool,
+    messages: list[dict[str, str]] = MESSAGES,
+) -> None:
     r"""Check template.

     Args:

@@ -62,13 +76,14 @@ def _check_template(model_id: str, template_name: str, prompt_str: str, answer_s
         prompt_str: the string corresponding to the prompt part.
         answer_str: the string corresponding to the answer part.
         use_fast: whether to use fast tokenizer.
+        messages: the list of messages.

     """
     tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=use_fast, token=HF_TOKEN)
-    content_str = tokenizer.apply_chat_template(MESSAGES, tokenize=False)
-    content_ids = tokenizer.apply_chat_template(MESSAGES, tokenize=True)
+    content_str = tokenizer.apply_chat_template(messages, tokenize=False)
+    content_ids = tokenizer.apply_chat_template(messages, tokenize=True)
     template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template=template_name))
-    prompt_ids, answer_ids = template.encode_oneturn(tokenizer, MESSAGES)
+    prompt_ids, answer_ids = template.encode_oneturn(tokenizer, messages)
     assert content_str == prompt_str + answer_str
     assert content_ids == prompt_ids + answer_ids
     _check_tokenization(tokenizer, (prompt_ids, answer_ids), (prompt_str, answer_str))

@@ -198,7 +213,7 @@ def test_phi4_template(use_fast: bool):


 @pytest.mark.parametrize("use_fast", [True, False])
-def test_qwen_template(use_fast: bool):
+def test_qwen2_5_template(use_fast: bool):
     prompt_str = (
         "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n"
         "<|im_start|>user\nHow are you<|im_end|>\n"

@@ -210,6 +225,18 @@ def test_qwen_template(use_fast: bool):
     _check_template("Qwen/Qwen2.5-7B-Instruct", "qwen", prompt_str, answer_str, use_fast)


+@pytest.mark.parametrize("use_fast", [True, False])
+def test_qwen3_template(use_fast: bool):
+    prompt_str = (
+        "<|im_start|>user\nHow are you<|im_end|>\n"
+        "<|im_start|>assistant\nI am fine!<|im_end|>\n"
+        "<|im_start|>user\n你好<|im_end|>\n"
+        "<|im_start|>assistant\n"
+    )
+    answer_str = "<think>\n模型思考内容\n</think>\n\n很高兴认识你!<|im_end|>\n"
+    _check_template("Qwen/Qwen3-8B", "qwen3", prompt_str, answer_str, use_fast, messages=MESSAGES_WITH_THOUGHT)
+
+
 def test_parse_llama3_template():
     tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3, token=HF_TOKEN)
     template = parse_template(tokenizer)

@@ -231,3 +258,13 @@ def test_parse_qwen_template():
     assert template.format_system.slots == ["<|im_start|>system\n{{content}}<|im_end|>\n"]
     assert template.format_prefix.slots == []
     assert template.default_system == "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."
+
+
+def test_parse_qwen3_template():
+    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B", token=HF_TOKEN)
+    template = parse_template(tokenizer)
+    assert template.format_user.slots == ["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]
+    assert template.format_assistant.slots == ["{{content}}<|im_end|>\n"]
+    assert template.format_system.slots == ["<|im_start|>system\n{{content}}<|im_end|>\n"]
+    assert template.format_prefix.slots == []
+    assert template.default_system == ""
@@ -1,2 +1,2 @@
 # change if test fails or cache is outdated
-0.9.3.105
+0.9.3.106