From 98f23c6584587a368267dade9b0d4bb544e6968c Mon Sep 17 00:00:00 2001 From: hoshi-hiyouga Date: Tue, 29 Apr 2025 09:34:05 +0800 Subject: [PATCH] [model] add qwen3 (#7885) --- README.md | 9 ++-- README_zh.md | 9 ++-- src/llamafactory/data/template.py | 68 ++++++++++++++++++++-------- src/llamafactory/extras/constants.py | 63 ++++++++++++++++++++++++++ src/llamafactory/webui/control.py | 6 +-- tests/data/test_template.py | 47 +++++++++++++++++-- tests/version.txt | 2 +- 7 files changed, 171 insertions(+), 33 deletions(-) diff --git a/README.md b/README.md index bf20c055..06e9335a 100644 --- a/README.md +++ b/README.md @@ -87,7 +87,7 @@ Choose your path: | Support Date | Model Name | | ------------ | ------------------------------------------------------------ | -| Day 0 | Qwen2.5 / Qwen2.5-VL / Gemma 3 / InternLM 3 / MiniCPM-o-2.6 | +| Day 0 | Qwen3 / Qwen2.5-VL / Gemma 3 / InternLM 3 / MiniCPM-o-2.6 | | Day 1 | Llama 3 / GLM-4 / Mistral Small / PaliGemma2 / Llama 4 | ## Benchmark @@ -107,6 +107,8 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/ ## Changelog +[25/04/28] We supported fine-tuning the **[Qwen3](https://qwenlm.github.io/blog/qwen3/)** model family. + [25/04/21] We supported the **[Muon](https://github.com/KellerJordan/Muon)** optimizer. See [examples](examples/README.md) for usage. Thank [@tianshijing](https://github.com/tianshijing)'s PR. [25/04/16] We supported fine-tuning the **[InternVL3](https://huggingface.co/OpenGVLab/InternVL3-8B)** model. See [PR #7258](https://github.com/hiyouga/LLaMA-Factory/pull/7258) to get started. @@ -115,10 +117,10 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/ [25/04/06] We supported fine-tuning the **[Llama 4](https://ai.meta.com/blog/llama-4-multimodal-intelligence/)** model. See [PR #7611](https://github.com/hiyouga/LLaMA-Factory/pull/7611) to get started. -[25/03/31] We supported fine-tuning the **[Qwen2.5 Omni](https://qwenlm.github.io/blog/qwen2.5-omni/)** model. See [PR #7537](https://github.com/hiyouga/LLaMA-Factory/pull/7537) to get started. -
 <details><summary>Full Changelog</summary>

+[25/03/31] We supported fine-tuning the **[Qwen2.5 Omni](https://qwenlm.github.io/blog/qwen2.5-omni/)** model. See [PR #7537](https://github.com/hiyouga/LLaMA-Factory/pull/7537) to get started.
+
 [25/03/15] We supported **[SGLang](https://github.com/sgl-project/sglang)** as inference backend. Try `infer_backend: sglang` to accelerate inference.

 [25/03/12] We supported fine-tuning the **[Gemma 3](https://huggingface.co/blog/gemma3)** model.
@@ -274,6 +276,7 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
 | [Phi-4](https://huggingface.co/microsoft) | 14B | phi4 |
 | [Pixtral](https://huggingface.co/mistralai) | 12B | pixtral |
 | [Qwen (1-2.5) (Code/Math/MoE/QwQ)](https://huggingface.co/Qwen) | 0.5B/1.5B/3B/7B/14B/32B/72B/110B | qwen |
+| [Qwen3 (MoE)](https://huggingface.co/Qwen) | 0.6B/1.7B/4B/8B/14B/32B/235B | qwen3 |
 | [Qwen2-Audio](https://huggingface.co/Qwen) | 7B | qwen2_audio |
 | [Qwen2.5-Omni](https://huggingface.co/Qwen)\*\* | 7B | qwen2_omni |
 | [Qwen2-VL/Qwen2.5-VL/QVQ](https://huggingface.co/Qwen) | 2B/3B/7B/32B/72B | qwen2_vl |
diff --git a/README_zh.md b/README_zh.md
index a8f0492f..e9faa49b 100644
--- a/README_zh.md
+++ b/README_zh.md
@@ -90,7 +90,7 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
 | 适配时间 | 模型名称 |
 | ------------ | ------------------------------------------------------------ |
-| Day 0 | Qwen2.5 / Qwen2.5-VL / Gemma 3 / InternLM 3 / MiniCPM-o-2.6 |
+| Day 0 | Qwen3 / Qwen2.5-VL / Gemma 3 / InternLM 3 / MiniCPM-o-2.6 |
 | Day 1 | Llama 3 / GLM-4 / Mistral Small / PaliGemma2 / Llama 4 |

 ## 性能指标
@@ -110,6 +110,8 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
 ## 更新日志

+[25/04/28] 我们支持了 **[Qwen3](https://qwenlm.github.io/blog/qwen3/)** 系列模型的微调。
+
 [25/04/21] 我们支持了 **[Muon](https://github.com/KellerJordan/Muon)** 优化器。详细用法请参照 [examples](examples/README_zh.md)。感谢 [@tianshijing](https://github.com/tianshijing) 的 PR。

 [25/04/16] 我们支持了 **[InternVL3](https://huggingface.co/OpenGVLab/InternVL3-8B)** 模型的微调。查看 [PR #7258](https://github.com/hiyouga/LLaMA-Factory/pull/7258) 以使用。

 [25/04/06] 我们支持了 **[Llama 4](https://ai.meta.com/blog/llama-4-multimodal-intelligence/)** 模型的微调。查看 [PR #7611](https://github.com/hiyouga/LLaMA-Factory/pull/7611) 以使用。

-[25/03/31] 我们支持了 **[Qwen2.5 Omni](https://qwenlm.github.io/blog/qwen2.5-omni/)** 模型的微调。查看 [PR #7537](https://github.com/hiyouga/LLaMA-Factory/pull/7537) 以使用。
-
 <details><summary>展开日志</summary>

+[25/03/31] 我们支持了 **[Qwen2.5 Omni](https://qwenlm.github.io/blog/qwen2.5-omni/)** 模型的微调。查看 [PR #7537](https://github.com/hiyouga/LLaMA-Factory/pull/7537) 以使用。
+
 [25/03/15] 我们支持了 **[SGLang](https://github.com/sgl-project/sglang)** 推理后端，请使用 `infer_backend: sglang` 启用。

 [25/03/12] 我们支持了 **[Gemma 3](https://huggingface.co/blog/gemma3)** 模型的微调。
@@ -277,6 +279,7 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
 | [Phi-4](https://huggingface.co/microsoft) | 14B | phi4 |
 | [Pixtral](https://huggingface.co/mistralai) | 12B | pixtral |
 | [Qwen (1-2.5) (Code/Math/MoE/QwQ)](https://huggingface.co/Qwen) | 0.5B/1.5B/3B/7B/14B/32B/72B/110B | qwen |
+| [Qwen3 (MoE)](https://huggingface.co/Qwen) | 0.6B/1.7B/4B/8B/14B/32B/235B | qwen3 |
 | [Qwen2-Audio](https://huggingface.co/Qwen) | 7B | qwen2_audio |
 | [Qwen2.5-Omni](https://huggingface.co/Qwen)\*\* | 7B | qwen2_omni |
 | [Qwen2-VL/Qwen2.5-VL/QVQ](https://huggingface.co/Qwen) | 2B/3B/7B/32B/72B | qwen2_vl |
diff --git a/src/llamafactory/data/template.py b/src/llamafactory/data/template.py
index 553e0285..1be20c5a 100644
--- a/src/llamafactory/data/template.py
+++ b/src/llamafactory/data/template.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import re
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Optional, Union

@@ -60,7 +61,7 @@ class Template:
         tools: Optional[str] = None,
     ) -> tuple[list[int], list[int]]:
         r"""Return a single pair of token ids representing prompt and response respectively."""
-        encoded_messages = self._encode(tokenizer, messages, system, tools)
+        encoded_messages = self._encode(tokenizer, messages, system, tools, remove_thought=True)
         prompt_ids = []
         for encoded_ids in encoded_messages[:-1]:
             prompt_ids += encoded_ids
@@ -76,7 +77,7 @@ class Template:
         tools: Optional[str] = None,
     ) -> list[tuple[list[int], list[int]]]:
         r"""Return multiple pairs of token ids representing prompts and responses respectively."""
-        encoded_messages = self._encode(tokenizer, messages, system, tools)
+        encoded_messages = self._encode(tokenizer, messages, system, tools, remove_thought=False)
         return [(encoded_messages[i], encoded_messages[i + 1]) for i in range(0, len(encoded_messages), 2)]

     def extract_tool(self, content: str) -> Union[str, list["FunctionCall"]]:
@@ -110,12 +111,18 @@ class Template:

         return token_ids

+    def _remove_thought(self, content: str) -> str:
+        r"""Remove thought from assistant message."""
+        pattern = re.compile(f"{re.escape(self.thought_words[0])}(.*?){re.escape(self.thought_words[1])}", re.DOTALL)
+        return re.sub(pattern, "", content).lstrip("\n")
+
     def _encode(
         self,
         tokenizer: "PreTrainedTokenizer",
         messages: list[dict[str, str]],
         system: Optional[str],
         tools: Optional[str],
+        remove_thought: bool,
     ) -> list[list[int]]:
         r"""Encode formatted inputs to pairs of token ids.

@@ -133,14 +140,18 @@ class Template:
                     tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
                     elements += self.format_system.apply(content=(system + tool_text))

-            if message["role"] == Role.USER.value:
-                elements += self.format_user.apply(content=message["content"], idx=str(i // 2))
-            elif message["role"] == Role.ASSISTANT.value:
-                elements += self.format_assistant.apply(content=message["content"])
-            elif message["role"] == Role.OBSERVATION.value:
-                elements += self.format_observation.apply(content=message["content"])
-            elif message["role"] == Role.FUNCTION.value:
-                elements += self.format_function.apply(content=message["content"])
+            content = message["content"]
+            if remove_thought and message["role"] == Role.ASSISTANT and (i != len(messages) - 1):
+                content = self._remove_thought(content)
+
+            if message["role"] == Role.USER:
+                elements += self.format_user.apply(content=content, idx=str(i // 2))
+            elif message["role"] == Role.ASSISTANT:
+                elements += self.format_assistant.apply(content=content)
+            elif message["role"] == Role.OBSERVATION:
+                elements += self.format_observation.apply(content=content)
+            elif message["role"] == Role.FUNCTION:
+                elements += self.format_function.apply(content=content)
             else:
                 raise NotImplementedError("Unexpected role: {}".format(message["role"]))

@@ -317,6 +328,7 @@ class Llama2Template(Template):
         messages: list[dict[str, str]],
         system: str,
         tools: str,
+        remove_thought: bool,
     ) -> list[list[int]]:
         system = system or self.default_system
         encoded_messages = []
@@ -330,14 +342,18 @@ class Llama2Template(Template):
                     tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
                     system_text = self.format_system.apply(content=(system + tool_text))[0]

-            if message["role"] == Role.USER.value:
-                elements += self.format_user.apply(content=system_text + message["content"])
-            elif message["role"] == Role.ASSISTANT.value:
-                elements += self.format_assistant.apply(content=message["content"])
-            elif message["role"] == Role.OBSERVATION.value:
-                elements += self.format_observation.apply(content=message["content"])
-            elif message["role"] == Role.FUNCTION.value:
-                elements += self.format_function.apply(content=message["content"])
+            content = message["content"]
+            if remove_thought and message["role"] == Role.ASSISTANT and (i != len(messages) - 1):
+                content = self._remove_thought(content)
+
+            if message["role"] == Role.USER:
+                elements += self.format_user.apply(content=system_text + content)
+            elif message["role"] == Role.ASSISTANT:
+                elements += self.format_assistant.apply(content=content)
+            elif message["role"] == Role.OBSERVATION:
+                elements += self.format_observation.apply(content=content)
+            elif message["role"] == Role.FUNCTION:
+                elements += self.format_function.apply(content=content)
             else:
                 raise NotImplementedError("Unexpected role: {}".format(message["role"]))

@@ -476,6 +492,7 @@ def parse_template(tokenizer: "PreTrainedTokenizer") -> "Template":
     messages = [{"role": "user", "content": "{{content}}"}, {"role": "assistant", "content": "{{content}}"}]
     assistant_slot = tokenizer.apply_chat_template(messages, add_generation_prompt=False, tokenize=False)
     assistant_slot = assistant_slot[len(prefix) + len(user_slot) :]
+    assistant_slot = assistant_slot.replace("<think>", "").replace("</think>", "").lstrip("\n")  # remove thought tags

     if len(user_slot) > len(user_slot_empty_system):
         default_system = find_diff(user_slot_empty_system, user_slot)
@@ -1411,6 +1428,21 @@ register_template(
 )


+# copied from qwen template
+register_template(
+    name="qwen3",
+    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
+    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
+    format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen"),
+    format_observation=StringFormatter(
+        slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
+    ),
+    format_tools=ToolFormatter(tool_format="qwen"),
+    stop_words=["<|im_end|>"],
+)
+
+
 # copied from chatml template
 register_template(
     name="qwen2_audio",
diff --git a/src/llamafactory/extras/constants.py b/src/llamafactory/extras/constants.py
index a26ed6c2..be4a417d 100644
--- a/src/llamafactory/extras/constants.py
+++ b/src/llamafactory/extras/constants.py
@@ -2403,6 +2403,69 @@ register_model_group(
 )


+register_model_group(
+    models={
+        "Qwen3-0.6B-Base": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3-0.6B-Base",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3-0.6B-Base",
+        },
+        "Qwen3-1.7B-Base": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3-1.7B-Base",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3-1.7B-Base",
+        },
+        "Qwen3-4B-Base": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3-4B-Base",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3-4B-Base",
+        },
+        "Qwen3-8B-Base": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3-8B-Base",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3-8B-Base",
+        },
+        "Qwen3-14B-Base": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3-14B-Base",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3-14B-Base",
+        },
+        "Qwen3-30B-A3B-Base": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3-30B-A3B-Base",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3-30B-A3B-Base",
+        },
+        "Qwen3-0.6B-Instruct": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3-0.6B",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3-0.6B",
+        },
+        "Qwen3-1.7B-Instruct": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3-1.7B",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3-1.7B",
+        },
+        "Qwen3-4B-Instruct": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3-4B",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3-4B",
+        },
+        "Qwen3-8B-Instruct": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3-8B",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3-8B",
+        },
+        "Qwen3-14B-Instruct": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3-14B",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3-14B",
+        },
+        "Qwen3-32B-Instruct": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3-32B",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3-32B",
+        },
+        "Qwen3-30B-A3B-Instruct": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3-30B-A3B",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3-30B-A3B",
+        },
+        "Qwen3-235B-A22B-Instruct": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3-235B-A22B",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3-235B-A22B",
+        },
+    },
+    template="qwen3",
+)
+
+
 register_model_group(
     models={
         "Qwen2-Audio-7B": {
diff --git a/src/llamafactory/webui/control.py b/src/llamafactory/webui/control.py
index 9939456f..c45073dc 100644
--- a/src/llamafactory/webui/control.py
+++ b/src/llamafactory/webui/control.py
@@ -56,11 +56,11 @@ def can_quantize_to(quantization_method: str) -> "gr.Dropdown":
     Inputs: top.quantization_method
     Outputs: top.quantization_bit
     """
-    if quantization_method == QuantizationMethod.BNB.value:
+    if quantization_method == QuantizationMethod.BNB:
         available_bits = ["none", "8", "4"]
-    elif quantization_method == QuantizationMethod.HQQ.value:
+    elif quantization_method == QuantizationMethod.HQQ:
         available_bits = ["none", "8", "6", "5", "4", "3", "2", "1"]
-    elif quantization_method == QuantizationMethod.EETQ.value:
+    elif quantization_method == QuantizationMethod.EETQ:
         available_bits = ["none", "8"]

     return gr.Dropdown(choices=available_bits)
diff --git a/tests/data/test_template.py b/tests/data/test_template.py
index 9ec25392..74f95df8 100644
--- a/tests/data/test_template.py
+++ b/tests/data/test_template.py
@@ -39,6 +39,13 @@ MESSAGES = [
     {"role": "assistant", "content": "很高兴认识你！"},
 ]

+MESSAGES_WITH_THOUGHT = [
+    {"role": "user", "content": "How are you"},
+    {"role": "assistant", "content": "<think>\nModel thought here\n</think>\n\nI am fine!"},
+    {"role": "user", "content": "你好"},
+    {"role": "assistant", "content": "<think>\n模型思考内容\n</think>\n\n很高兴认识你!"},
+]
+

 def _check_tokenization(
     tokenizer: "PreTrainedTokenizer", batch_input_ids: list[list[int]], batch_text: list[str]
@@ -53,7 +60,14 @@ def _check_tokenization(
         assert tokenizer.decode(input_ids) == text


-def _check_template(model_id: str, template_name: str, prompt_str: str, answer_str: str, use_fast: bool) -> None:
+def _check_template(
+    model_id: str,
+    template_name: str,
+    prompt_str: str,
+    answer_str: str,
+    use_fast: bool,
+    messages: list[dict[str, str]] = MESSAGES,
+) -> None:
     r"""Check template.

     Args:
@@ -62,13 +76,14 @@ def _check_template(model_id: str, template_name: str, prompt_str: str, answer_s
         prompt_str: the string corresponding to the prompt part.
         answer_str: the string corresponding to the answer part.
         use_fast: whether to use fast tokenizer.
+        messages: the list of messages.

     """
     tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=use_fast, token=HF_TOKEN)
-    content_str = tokenizer.apply_chat_template(MESSAGES, tokenize=False)
-    content_ids = tokenizer.apply_chat_template(MESSAGES, tokenize=True)
+    content_str = tokenizer.apply_chat_template(messages, tokenize=False)
+    content_ids = tokenizer.apply_chat_template(messages, tokenize=True)
     template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template=template_name))
-    prompt_ids, answer_ids = template.encode_oneturn(tokenizer, MESSAGES)
+    prompt_ids, answer_ids = template.encode_oneturn(tokenizer, messages)
     assert content_str == prompt_str + answer_str
     assert content_ids == prompt_ids + answer_ids
     _check_tokenization(tokenizer, (prompt_ids, answer_ids), (prompt_str, answer_str))
@@ -198,7 +213,7 @@ def test_phi4_template(use_fast: bool):


 @pytest.mark.parametrize("use_fast", [True, False])
-def test_qwen_template(use_fast: bool):
+def test_qwen2_5_template(use_fast: bool):
     prompt_str = (
         "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n"
         "<|im_start|>user\nHow are you<|im_end|>\n"
@@ -210,6 +225,18 @@ def test_qwen_template(use_fast: bool):
     _check_template("Qwen/Qwen2.5-7B-Instruct", "qwen", prompt_str, answer_str, use_fast)


+@pytest.mark.parametrize("use_fast", [True, False])
+def test_qwen3_template(use_fast: bool):
+    prompt_str = (
+        "<|im_start|>user\nHow are you<|im_end|>\n"
+        "<|im_start|>assistant\nI am fine!<|im_end|>\n"
+        "<|im_start|>user\n你好<|im_end|>\n"
+        "<|im_start|>assistant\n"
+    )
+    answer_str = "<think>\n模型思考内容\n</think>\n\n很高兴认识你!<|im_end|>\n"
+    _check_template("Qwen/Qwen3-8B", "qwen3", prompt_str, answer_str, use_fast, messages=MESSAGES_WITH_THOUGHT)
+
+
 def test_parse_llama3_template():
     tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3, token=HF_TOKEN)
     template = parse_template(tokenizer)
@@ -231,3 +258,13 @@ def test_parse_qwen_template():
     assert template.format_system.slots == ["<|im_start|>system\n{{content}}<|im_end|>\n"]
     assert template.format_prefix.slots == []
     assert template.default_system == "You are Qwen, created by Alibaba Cloud. 
You are a helpful assistant." + + +def test_parse_qwen3_template(): + tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B", token=HF_TOKEN) + template = parse_template(tokenizer) + assert template.format_user.slots == ["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"] + assert template.format_assistant.slots == ["{{content}}<|im_end|>\n"] + assert template.format_system.slots == ["<|im_start|>system\n{{content}}<|im_end|>\n"] + assert template.format_prefix.slots == [] + assert template.default_system == "" diff --git a/tests/version.txt b/tests/version.txt index f8affd0f..bbb71ad5 100644 --- a/tests/version.txt +++ b/tests/version.txt @@ -1,2 +1,2 @@ # change if test fails or cache is outdated -0.9.3.105 +0.9.3.106
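For reference, the thought-stripping behavior introduced in this patch is easiest to see in isolation. The sketch below mirrors the regex used in `Template._remove_thought`; it is illustrative only (not part of the patch) and assumes the template's `thought_words` default to the `<think>` / `</think>` pair that the new tests exercise.

```python
import re

# Standalone sketch of Template._remove_thought, assuming the default
# thought_words pair ("<think>", "</think>") used by the qwen3 tests above.
THOUGHT_WORDS = ("<think>", "</think>")
PATTERN = re.compile(f"{re.escape(THOUGHT_WORDS[0])}(.*?){re.escape(THOUGHT_WORDS[1])}", re.DOTALL)


def remove_thought(content: str) -> str:
    # Drop the <think>...</think> span, then strip the leading newlines it leaves behind.
    return re.sub(PATTERN, "", content).lstrip("\n")


# Same assistant message as in MESSAGES_WITH_THOUGHT above.
print(remove_thought("<think>\nModel thought here\n</think>\n\nI am fine!"))  # -> "I am fine!"
```

Because `encode_oneturn` calls `_encode` with `remove_thought=True`, this stripping is applied to every assistant turn except the last one, so earlier reasoning traces are dropped from the prompt history while the final response keeps its thought block; `encode_multiturn` passes `remove_thought=False` and leaves all turns untouched.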