diff --git a/README.md b/README.md index 1eec5b96..739a13af 100644 --- a/README.md +++ b/README.md @@ -88,14 +88,16 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/ ## Changelog +[25/01/10] We supported fine-tuning the **[Phi-4](https://huggingface.co/microsoft/phi-4)** model. + [24/12/21] We supported using **[SwanLab](https://github.com/SwanHubX/SwanLab)** for experiment tracking and visualization. See [this section](#use-swanlab-logger) for details. [24/11/27] We supported fine-tuning the **[Skywork-o1](https://huggingface.co/Skywork/Skywork-o1-Open-Llama-3.1-8B)** model and the **[OpenO1](https://huggingface.co/datasets/O1-OPEN/OpenO1-SFT)** dataset. -[24/10/09] We supported downloading pre-trained models and datasets from the **[Modelers Hub](https://modelers.cn/models)**. See [this tutorial](#download-from-modelers-hub) for usage. -
Full Changelog +[24/10/09] We supported downloading pre-trained models and datasets from the **[Modelers Hub](https://modelers.cn/models)**. See [this tutorial](#download-from-modelers-hub) for usage. + [24/09/19] We supported fine-tuning the **[Qwen2.5](https://qwenlm.github.io/blog/qwen2.5/)** models. [24/08/30] We supported fine-tuning the **[Qwen2-VL](https://qwenlm.github.io/blog/qwen2-vl/)** models. Thank [@simonJJJ](https://github.com/simonJJJ)'s PR. @@ -211,8 +213,9 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/ | [OLMo](https://huggingface.co/allenai) | 1B/7B | - | | [PaliGemma/PaliGemma2](https://huggingface.co/google) | 3B/10B/28B | paligemma | | [Phi-1.5/Phi-2](https://huggingface.co/microsoft) | 1.3B/2.7B | - | -| [Phi-3](https://huggingface.co/microsoft) | 4B/14B | phi | +| [Phi-3/Phi-3.5](https://huggingface.co/microsoft) | 4B/14B | phi | | [Phi-3-small](https://huggingface.co/microsoft) | 7B | phi_small | +| [Phi-4](https://huggingface.co/microsoft) | 14B | phi4 | | [Pixtral](https://huggingface.co/mistralai) | 12B | pixtral | | [Qwen/QwQ (1-2.5) (Code/Math/MoE)](https://huggingface.co/Qwen) | 0.5B/1.5B/3B/7B/14B/32B/72B/110B | qwen | | [Qwen2-VL/QVQ](https://huggingface.co/Qwen) | 2B/7B/72B | qwen2_vl | @@ -762,7 +765,7 @@ If you have a project that should be incorporated, please contact via email or c This repository is licensed under the [Apache-2.0 License](LICENSE). -Please follow the model licenses to use the corresponding model weights: [Baichuan 2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM-4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [GPT-2](https://github.com/openai/gpt-2/blob/master/LICENSE) / [Granite](LICENSE) / [Index](https://huggingface.co/IndexTeam/Index-1.9B/blob/main/LICENSE) / [InternLM2](https://github.com/InternLM/InternLM#license) / [Llama](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [Llama 2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [Llama 3](https://llama.meta.com/llama3/license/) / [MiniCPM](https://github.com/OpenBMB/MiniCPM/blob/main/MiniCPM%20Model%20License.md) / [Mistral/Mixtral/Pixtral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/Phi-2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [Skywork](https://huggingface.co/Skywork/Skywork-13B-base/blob/main/Skywork%20Community%20License.pdf) / [StarCoder 2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [TeleChat2](https://huggingface.co/Tele-AI/telechat-7B/blob/main/TeleChat%E6%A8%A1%E5%9E%8B%E7%A4%BE%E5%8C%BA%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.pdf) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan 2](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan) +Please follow the model licenses to use the corresponding model weights: [Baichuan 2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM-4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [GPT-2](https://github.com/openai/gpt-2/blob/master/LICENSE) / [Granite](LICENSE) / [Index](https://huggingface.co/IndexTeam/Index-1.9B/blob/main/LICENSE) / [InternLM2](https://github.com/InternLM/InternLM#license) / [Llama](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [Llama 2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [Llama 3](https://llama.meta.com/llama3/license/) / [MiniCPM](https://github.com/OpenBMB/MiniCPM/blob/main/MiniCPM%20Model%20License.md) / [Mistral/Mixtral/Pixtral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/Phi-2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3/Phi-4](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [Skywork](https://huggingface.co/Skywork/Skywork-13B-base/blob/main/Skywork%20Community%20License.pdf) / [StarCoder 2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [TeleChat2](https://huggingface.co/Tele-AI/telechat-7B/blob/main/TeleChat%E6%A8%A1%E5%9E%8B%E7%A4%BE%E5%8C%BA%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.pdf) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan 2](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan) ## Citation diff --git a/README_zh.md b/README_zh.md index 352a483d..e21560d2 100644 --- a/README_zh.md +++ b/README_zh.md @@ -89,14 +89,16 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272 ## 更新日志 +[25/01/10] 我们支持了 **[Phi-4](https://huggingface.co/microsoft/phi-4)** 模型的微调。 + [24/12/21] 我们支持了使用 **[SwanLab](https://github.com/SwanHubX/SwanLab)** 跟踪与可视化实验。详细用法请参考 [此部分](#使用-swanlab-面板)。 [24/11/27] 我们支持了 **[Skywork-o1](https://huggingface.co/Skywork/Skywork-o1-Open-Llama-3.1-8B)** 模型的微调和 **[OpenO1](https://huggingface.co/datasets/O1-OPEN/OpenO1-SFT)** 数据集。 -[24/10/09] 我们支持了从 **[魔乐社区](https://modelers.cn/models)** 下载预训练模型和数据集。详细用法请参照 [此教程](#从魔乐社区下载)。 -
展开日志 +[24/10/09] 我们支持了从 **[魔乐社区](https://modelers.cn/models)** 下载预训练模型和数据集。详细用法请参照 [此教程](#从魔乐社区下载)。 + [24/09/19] 我们支持了 **[Qwen2.5](https://qwenlm.github.io/blog/qwen2.5/)** 模型的微调。 [24/08/30] 我们支持了 **[Qwen2-VL](https://qwenlm.github.io/blog/qwen2-vl/)** 模型的微调。感谢 [@simonJJJ](https://github.com/simonJJJ) 的 PR。 @@ -212,8 +214,9 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272 | [OLMo](https://huggingface.co/allenai) | 1B/7B | - | | [PaliGemma/PaliGemma2](https://huggingface.co/google) | 3B/10B/28B | paligemma | | [Phi-1.5/Phi-2](https://huggingface.co/microsoft) | 1.3B/2.7B | - | -| [Phi-3](https://huggingface.co/microsoft) | 4B/14B | phi | +| [Phi-3/Phi-3.5](https://huggingface.co/microsoft) | 4B/14B | phi | | [Phi-3-small](https://huggingface.co/microsoft) | 7B | phi_small | +| [Phi-4](https://huggingface.co/microsoft) | 14B | phi4 | | [Pixtral](https://huggingface.co/mistralai) | 12B | pixtral | | [Qwen/QwQ (1-2.5) (Code/Math/MoE)](https://huggingface.co/Qwen) | 0.5B/1.5B/3B/7B/14B/32B/72B/110B | qwen | | [Qwen2-VL/QVQ](https://huggingface.co/Qwen) | 2B/7B/72B | qwen2_vl | @@ -763,7 +766,7 @@ swanlab_run_name: test_run # 可选 本仓库的代码依照 [Apache-2.0](LICENSE) 协议开源。 -使用模型权重时,请遵循对应的模型协议:[Baichuan 2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM-4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [GPT-2](https://github.com/openai/gpt-2/blob/master/LICENSE) / [Granite](LICENSE) / [Index](https://huggingface.co/IndexTeam/Index-1.9B/blob/main/LICENSE) / [InternLM2](https://github.com/InternLM/InternLM#license) / [Llama](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [Llama 2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [Llama 3](https://llama.meta.com/llama3/license/) / [MiniCPM](https://github.com/OpenBMB/MiniCPM/blob/main/MiniCPM%20Model%20License.md) / [Mistral/Mixtral/Pixtral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/Phi-2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [Skywork](https://huggingface.co/Skywork/Skywork-13B-base/blob/main/Skywork%20Community%20License.pdf) / [StarCoder 2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [TeleChat2](https://huggingface.co/Tele-AI/telechat-7B/blob/main/TeleChat%E6%A8%A1%E5%9E%8B%E7%A4%BE%E5%8C%BA%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.pdf) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan 2](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan) +使用模型权重时,请遵循对应的模型协议:[Baichuan 2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM-4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [GPT-2](https://github.com/openai/gpt-2/blob/master/LICENSE) / [Granite](LICENSE) / [Index](https://huggingface.co/IndexTeam/Index-1.9B/blob/main/LICENSE) / [InternLM2](https://github.com/InternLM/InternLM#license) / [Llama](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [Llama 2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [Llama 3](https://llama.meta.com/llama3/license/) / [MiniCPM](https://github.com/OpenBMB/MiniCPM/blob/main/MiniCPM%20Model%20License.md) / [Mistral/Mixtral/Pixtral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/Phi-2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3/Phi-4](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [Skywork](https://huggingface.co/Skywork/Skywork-13B-base/blob/main/Skywork%20Community%20License.pdf) / [StarCoder 2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [TeleChat2](https://huggingface.co/Tele-AI/telechat-7B/blob/main/TeleChat%E6%A8%A1%E5%9E%8B%E7%A4%BE%E5%8C%BA%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.pdf) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan 2](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan) ## 引用 diff --git a/src/llamafactory/data/template.py b/src/llamafactory/data/template.py index ebe31553..490b571d 100644 --- a/src/llamafactory/data/template.py +++ b/src/llamafactory/data/template.py @@ -44,7 +44,6 @@ class Template: format_function: "Formatter" format_observation: "Formatter" format_tools: "Formatter" - format_separator: "Formatter" format_prefix: "Formatter" default_system: str stop_words: List[str] @@ -113,9 +112,6 @@ class Template: tool_text = self.format_tools.apply(content=tools)[0] if tools else "" elements += self.format_system.apply(content=(system + tool_text)) - if i > 0 and i % 2 == 0: - elements += self.format_separator.apply() - if message["role"] == Role.USER.value: elements += self.format_user.apply(content=message["content"], idx=str(i // 2)) elif message["role"] == Role.ASSISTANT.value: @@ -180,9 +176,6 @@ class Llama2Template(Template): tool_text = self.format_tools.apply(content=tools)[0] if tools else "" system_text = self.format_system.apply(content=(system + tool_text))[0] - if i > 0 and i % 2 == 0: - elements += self.format_separator.apply() - if message["role"] == Role.USER.value: elements += self.format_user.apply(content=system_text + message["content"]) elif message["role"] == Role.ASSISTANT.value: @@ -210,7 +203,6 @@ def _register_template( format_function: Optional["Formatter"] = None, format_observation: Optional["Formatter"] = None, format_tools: Optional["Formatter"] = None, - format_separator: Optional["Formatter"] = None, format_prefix: Optional["Formatter"] = None, default_system: str = "", stop_words: Sequence[str] = [], @@ -224,34 +216,28 @@ def _register_template( To add the following chat template: ``` - [HUMAN]: - user prompt here - [AI]: - model response here - - [HUMAN]: - user prompt here - [AI]: - model response here + user prompt here + model response here + user prompt here + model response here ``` The corresponding code should be: ``` _register_template( name="custom", - format_user=StringFormatter(slots=["[HUMAN]:\n{{content}}\n[AI]:\n"]), - format_separator=EmptyFormatter(slots=["\n\n"]), - efficient_eos=True, + format_user=StringFormatter(slots=["{{content}}\n"]), + format_assistant=StringFormatter(slots=["{{content}}\n"]), + format_prefix=EmptyFormatter(""), ) ``` """ - template_class = Llama2Template if any(k in name for k in ("llama2", "mistral")) else Template + template_class = Llama2Template if any(k in name for k in ("llama2", "mistral", "pixtral")) else Template default_slots = ["{{content}}"] if efficient_eos else ["{{content}}", {"eos_token"}] default_user_formatter = StringFormatter(slots=["{{content}}"]) default_assistant_formatter = StringFormatter(slots=default_slots) default_function_formatter = FunctionFormatter(slots=default_slots, tool_format="default") default_tool_formatter = ToolFormatter(tool_format="default") - default_separator_formatter = EmptyFormatter() default_prefix_formatter = EmptyFormatter() TEMPLATES[name] = template_class( format_user=format_user or default_user_formatter, @@ -260,7 +246,6 @@ def _register_template( format_function=format_function or default_function_formatter, format_observation=format_observation or format_user or default_user_formatter, format_tools=format_tools or default_tool_formatter, - format_separator=format_separator or default_separator_formatter, format_prefix=format_prefix or default_prefix_formatter, default_system=default_system, stop_words=stop_words, @@ -344,9 +329,7 @@ def _get_jinja_template(template: "Template", tokenizer: "PreTrainedTokenizer") jinja_template += "{{ " + user_message + " }}" jinja_template += "{% elif message['role'] == 'assistant' %}" - assistant_message = _convert_slots_to_jinja( - template.format_assistant.apply() + template.format_separator.apply(), tokenizer - ) + assistant_message = _convert_slots_to_jinja(template.format_assistant.apply(), tokenizer) jinja_template += "{{ " + assistant_message + " }}" jinja_template += "{% endif %}" jinja_template += "{% endfor %}" @@ -411,7 +394,7 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args: _register_template( name="alpaca", format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n\n### Response:\n"]), - format_separator=EmptyFormatter(slots=["\n\n"]), + format_assistant=StringFormatter(slots=["{{content}}", {"eos_token"}, "\n\n"]), default_system=( "Below is an instruction that describes a task. " "Write a response that appropriately completes the request.\n\n" @@ -423,13 +406,13 @@ _register_template( _register_template( name="aquila", format_user=StringFormatter(slots=["Human: {{content}}###Assistant:"]), - format_separator=EmptyFormatter(slots=["###"]), + format_assistant=StringFormatter(slots=["{{content}}###"]), + format_system=StringFormatter(slots=["System: {{content}}###"]), default_system=( "A chat between a curious human and an artificial intelligence assistant. " "The assistant gives helpful, detailed, and polite answers to the human's questions." ), stop_words=[""], - efficient_eos=True, ) @@ -459,7 +442,7 @@ _register_template( _register_template( name="belle", format_user=StringFormatter(slots=["Human: {{content}}\n\nBelle: "]), - format_separator=EmptyFormatter(slots=["\n\n"]), + format_assistant=StringFormatter(slots=["{{content}}", {"eos_token"}, "\n\n"]), format_prefix=EmptyFormatter(slots=[{"bos_token"}]), ) @@ -481,7 +464,6 @@ _register_template( _register_template( name="chatglm2", format_user=StringFormatter(slots=["[Round {{idx}}]\n\n问:{{content}}\n\n答:"]), - format_separator=EmptyFormatter(slots=["\n\n"]), format_prefix=EmptyFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}]), efficient_eos=True, ) @@ -506,9 +488,9 @@ _register_template( _register_template( name="chatml", format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), + format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]), format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]), format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), - format_separator=EmptyFormatter(slots=["\n"]), stop_words=["<|im_end|>", "<|im_start|>"], replace_eos=True, replace_jinja_template=True, @@ -519,9 +501,9 @@ _register_template( _register_template( name="chatml_de", format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), + format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]), format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]), format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), - format_separator=EmptyFormatter(slots=["\n"]), default_system="Du bist ein freundlicher und hilfsbereiter KI-Assistent.", stop_words=["<|im_end|>", "<|im_start|>"], replace_eos=True, @@ -574,9 +556,11 @@ _register_template( ) +# copied from chatml template _register_template( name="cpm3", format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), + format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]), format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]), format_prefix=EmptyFormatter(slots=[{"bos_token"}]), stop_words=["<|im_end|>"], @@ -587,9 +571,9 @@ _register_template( _register_template( name="dbrx", format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), + format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]), format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]), format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), - format_separator=EmptyFormatter(slots=["\n"]), default_system=( "You are DBRX, created by Databricks. You were last updated in December 2023. " "You answer questions based on information available up to that point.\n" @@ -606,7 +590,6 @@ _register_template( "ABOUT YOURSELF UNLESS THE INFORMATION IS DIRECTLY PERTINENT TO THE USER'S QUERY." ), stop_words=["<|im_end|>"], - replace_eos=True, ) @@ -628,8 +611,7 @@ _register_template( _register_template( name="deepseekcoder", format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n### Response:"]), - format_assistant=StringFormatter(slots=["\n{{content}}\n<|EOT|>"]), - format_separator=EmptyFormatter(slots=["\n"]), + format_assistant=StringFormatter(slots=["\n{{content}}\n<|EOT|>\n"]), format_prefix=EmptyFormatter(slots=[{"bos_token"}]), default_system=( "You are an AI programming assistant, utilizing the DeepSeek Coder model, " @@ -643,8 +625,8 @@ _register_template( _register_template( name="default", format_user=StringFormatter(slots=["Human: {{content}}\nAssistant:"]), - format_system=StringFormatter(slots=["{{content}}\n"]), - format_separator=EmptyFormatter(slots=["\n"]), + format_assistant=StringFormatter(slots=["{{content}}", {"eos_token"}, "\n"]), + format_system=StringFormatter(slots=["System: {{content}}\n"]), ) @@ -657,22 +639,22 @@ _register_template( _register_template( name="exaone", format_user=StringFormatter(slots=["[|user|]{{content}}\n[|assistant|]"]), + format_assistant=StringFormatter(slots=["{{content}}", {"eos_token"}, "\n"]), format_system=StringFormatter(slots=["[|system|]{{content}}[|endofturn|]\n"]), - format_separator=EmptyFormatter(slots=["\n"]), ) _register_template( name="falcon", format_user=StringFormatter(slots=["User: {{content}}\nFalcon:"]), - format_separator=EmptyFormatter(slots=["\n"]), + format_assistant=StringFormatter(slots=["{{content}}\n"]), efficient_eos=True, ) _register_template( name="fewshot", - format_separator=EmptyFormatter(slots=["\n\n"]), + format_assistant=StringFormatter(slots=["{{content}}\n\n"]), efficient_eos=True, ) @@ -680,12 +662,11 @@ _register_template( _register_template( name="gemma", format_user=StringFormatter(slots=["user\n{{content}}\nmodel\n"]), + format_assistant=StringFormatter(slots=["{{content}}\n"]), format_observation=StringFormatter( slots=["tool\n{{content}}\nmodel\n"] ), - format_separator=EmptyFormatter(slots=["\n"]), format_prefix=EmptyFormatter(slots=[{"bos_token"}]), - efficient_eos=True, ) @@ -710,8 +691,8 @@ _register_template( "<|start_of_role|>user<|end_of_role|>{{content}}<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|>" ] ), + format_assistant=StringFormatter(slots=["{{content}}<|end_of_text|>\n"]), format_system=StringFormatter(slots=["<|start_of_role|>system<|end_of_role|>{{content}}<|end_of_text|>\n"]), - format_separator=EmptyFormatter(slots=["\n"]), ) @@ -726,22 +707,20 @@ _register_template( _register_template( name="intern", format_user=StringFormatter(slots=["<|User|>:{{content}}\n<|Bot|>:"]), + format_assistant=StringFormatter(slots=["{{content}}\n"]), format_system=StringFormatter(slots=["<|System|>:{{content}}\n"]), - format_separator=EmptyFormatter(slots=["\n"]), format_prefix=EmptyFormatter(slots=[{"bos_token"}]), stop_words=[""], - efficient_eos=True, # internlm tokenizer cannot set eos_token_id ) _register_template( name="intern2", format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), + format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]), format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]), - format_separator=EmptyFormatter(slots=["<|im_end|>\n"]), format_prefix=EmptyFormatter(slots=[{"bos_token"}]), stop_words=["<|im_end|>"], - efficient_eos=True, # internlm2 tokenizer cannot set eos_token_id ) @@ -872,6 +851,7 @@ _register_template( name="llava_next_mistral", format_user=StringFormatter(slots=["[INST] {{content}}[/INST]"]), format_assistant=StringFormatter(slots=[" {{content}}", {"eos_token"}]), + format_system=StringFormatter(slots=["{{content}}\n\n"]), format_function=FunctionFormatter(slots=["[TOOL_CALLS] ", "{{content}}", {"eos_token"}], tool_format="mistral"), format_observation=StringFormatter(slots=["""[TOOL_RESULTS] {"content": {{content}}}[/TOOL_RESULTS]"""]), format_tools=ToolFormatter(tool_format="mistral"), @@ -884,16 +864,15 @@ _register_template( _register_template( name="llava_next_qwen", format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), + format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]), format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]), - format_function=FunctionFormatter(slots=["{{content}}", "<|im_end|>"], tool_format="qwen"), + format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen"), format_observation=StringFormatter( slots=["<|im_start|>user\n\n{{content}}\n<|im_end|>\n<|im_start|>assistant\n"] ), format_tools=ToolFormatter(tool_format="qwen"), - format_separator=EmptyFormatter(slots=["\n"]), default_system="You are a helpful assistant.", stop_words=["<|im_end|>"], - replace_eos=True, mm_plugin=get_mm_plugin(name="llava_next", image_token=""), ) @@ -902,10 +881,9 @@ _register_template( _register_template( name="llava_next_yi", format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), + format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]), format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]), - format_separator=EmptyFormatter(slots=["\n"]), stop_words=["<|im_end|>"], - replace_eos=True, mm_plugin=get_mm_plugin(name="llava_next", image_token=""), ) @@ -927,6 +905,7 @@ _register_template( name="llava_next_video_mistral", format_user=StringFormatter(slots=["[INST] {{content}}[/INST]"]), format_assistant=StringFormatter(slots=[" {{content}}", {"eos_token"}]), + format_system=StringFormatter(slots=["{{content}}\n\n"]), format_function=FunctionFormatter(slots=["[TOOL_CALLS] ", "{{content}}", {"eos_token"}], tool_format="mistral"), format_observation=StringFormatter(slots=["""[TOOL_RESULTS] {"content": {{content}}}[/TOOL_RESULTS]"""]), format_tools=ToolFormatter(tool_format="mistral"), @@ -939,10 +918,9 @@ _register_template( _register_template( name="llava_next_video_yi", format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), + format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]), format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]), - format_separator=EmptyFormatter(slots=["\n"]), stop_words=["<|im_end|>"], - replace_eos=True, mm_plugin=get_mm_plugin(name="llava_next_video", image_token="", video_token="