Merge branch 'main' into minicpmv

Former-commit-id: fc045d7dd871985d621430b5662cba882188a59c
commit f51ac40f0a
Zhangchi Feng, 2025-01-10 20:12:07 +08:00 (committed by GitHub)
43 changed files with 647 additions and 357 deletions


@@ -12,6 +12,7 @@ FORCE_CHECK_IMPORTS=
 LLAMAFACTORY_VERBOSITY=
 USE_MODELSCOPE_HUB=
 USE_OPENMIND_HUB=
+USE_RAY=
 RECORD_VRAM=
 # torchrun
 FORCE_TORCHRUN=

.github/ISSUE_TEMPLATE/1-bug-report.yml (new file, 63 lines)

@@ -0,0 +1,63 @@
name: "\U0001F41B Bug / help"
description: Create a report to help us improve the LLaMA Factory
labels: ["bug", "pending"]
body:
- type: markdown
attributes:
value: |
Issues included in **[FAQs](https://github.com/hiyouga/LLaMA-Factory/issues/4614)** or those with **insufficient** information may be closed without a response.
已经包含在 **[常见问题](https://github.com/hiyouga/LLaMA-Factory/issues/4614)** 内或提供信息**不完整**的 issues 可能不会被回复。
- type: markdown
attributes:
value: |
Please do not create issues that are not related to framework bugs under this category, use **[Discussions](https://github.com/hiyouga/LLaMA-Factory/discussions/categories/q-a)** instead.
请勿在此分类下创建和框架 bug 无关的 issues请使用 **[讨论区](https://github.com/hiyouga/LLaMA-Factory/discussions/categories/q-a)**。
- type: checkboxes
id: reminder
attributes:
label: Reminder
description: |
Please ensure you have read the above rules carefully and searched the existing issues (including FAQs).
请确保您已经认真阅读了上述规则并且搜索过现有的 issues包括常见问题
options:
- label: I have read the above rules and searched the existing issues.
required: true
- type: textarea
id: system-info
validations:
required: true
attributes:
label: System Info
description: |
Please share your system info with us. You can run the command **llamafactory-cli env** and copy-paste its output below.
请提供您的系统信息。您可以在命令行运行 **llamafactory-cli env** 并将其输出复制到该文本框中。
placeholder: llamafactory version, platform, python version, ...
- type: textarea
id: reproduction
validations:
required: true
attributes:
label: Reproduction
description: |
Please provide entry arguments, error messages and stack traces that reproduces the problem.
请提供入口参数,错误日志以及异常堆栈以便于我们复现问题。
Remember to wrap your log messages with \`\`\`.
请务必使用 Markdown 标签 \`\`\` 来包裹您的日志信息。
value: |
```text
Put your message here.
```
- type: textarea
id: others
validations:
required: false
attributes:
label: Others


@@ -0,0 +1,41 @@
name: "\U0001F680 Feature request"
description: Submit a request for a new feature
labels: ["enhancement", "pending"]
body:
- type: markdown
attributes:
value: |
Please do not create issues that are not related to new features under this category.
请勿在此分类下创建和新特性无关的 issues。
- type: checkboxes
id: reminder
attributes:
label: Reminder
description: |
Please ensure you have read the above rules carefully and searched the existing issues.
请确保您已经认真阅读了上述规则并且搜索过现有的 issues。
options:
- label: I have read the above rules and searched the existing issues.
required: true
- type: textarea
id: description
validations:
required: true
attributes:
label: Description
description: |
A clear and concise description of the feature proposal.
请详细描述您希望加入的新功能特性。
- type: textarea
id: contribution
validations:
required: false
attributes:
label: Pull Request
description: |
Have you already created the relevant PR and submitted the code?
您是否已经创建了相关 PR 并提交了代码?


@@ -1,66 +0,0 @@
name: "\U0001F41B Bug / Help"
description: Create a report to help us improve the LLaMA Factory
body:
- type: markdown
attributes:
value: |
Issues included in **FAQs** or those with **insufficient** information may be closed without a response.
包含在**常见问题**内或提供信息**不完整**的 issues 可能不会被回复。
- type: checkboxes
id: reminder
attributes:
label: Reminder
description: |
Please ensure you have read the README carefully and searched the existing issues (including FAQs).
请确保您已经认真阅读了 README 并且搜索过现有的 issues包括常见问题
options:
- label: I have read the README and searched the existing issues.
required: true
- type: textarea
id: system-info
validations:
required: true
attributes:
label: System Info
description: |
Please share your system info with us. You can run the command **llamafactory-cli env** and copy-paste its output below.
请提供您的系统信息。您可以在命令行运行 **llamafactory-cli env** 并将其输出复制到该文本框中。
placeholder: llamafactory version, platform, python version, ...
- type: textarea
id: reproduction
validations:
required: true
attributes:
label: Reproduction
description: |
Please provide code snippets, error messages and stack traces that reproduces the problem.
请提供运行参数,错误信息以及异常堆栈以便于我们复现该问题。
Remember to use Markdown tags to correctly format your code.
请合理使用 Markdown 标签来格式化您的文本。
placeholder: |
```bash
llamafactory-cli train ...
```
- type: textarea
id: expected-behavior
validations:
required: false
attributes:
label: Expected behavior
description: |
Please provide a clear and concise description of what you would expect to happen.
请提供您原本的目的,即这段代码的期望行为。
- type: textarea
id: others
validations:
required: false
attributes:
label: Others

.github/ISSUE_TEMPLATE/config.yml (new file, 1 line)

@@ -0,0 +1 @@
blank_issues_enabled: false


@@ -18,13 +18,15 @@ jobs:
           ISSUE_URL: ${{ github.event.issue.html_url }}
           ISSUE_TITLE: ${{ github.event.issue.title }}
         run: |
-          LABEL=pending
+          LABEL=""
           NPU_KEYWORDS=(npu huawei ascend 华为 昇腾)
           ISSUE_TITLE_LOWER=$(echo $ISSUE_TITLE | tr '[:upper:]' '[:lower:]')
           for KEYWORD in ${NPU_KEYWORDS[@]}; do
             if [[ $ISSUE_TITLE_LOWER == *$KEYWORD* ]] && [[ $ISSUE_TITLE_LOWER != *input* ]]; then
-              LABEL=pending,npu
+              LABEL="npu"
              break
            fi
          done
-          gh issue edit $ISSUE_URL --add-label $LABEL
+          if [ -n "$LABEL" ]; then
+            gh issue edit $ISSUE_URL --add-label $LABEL
+          fi


@@ -88,14 +88,16 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
 ## Changelog
 
+[25/01/10] We supported fine-tuning the **[Phi-4](https://huggingface.co/microsoft/phi-4)** model.
+
 [24/12/21] We supported using **[SwanLab](https://github.com/SwanHubX/SwanLab)** for experiment tracking and visualization. See [this section](#use-swanlab-logger) for details.
 
 [24/11/27] We supported fine-tuning the **[Skywork-o1](https://huggingface.co/Skywork/Skywork-o1-Open-Llama-3.1-8B)** model and the **[OpenO1](https://huggingface.co/datasets/O1-OPEN/OpenO1-SFT)** dataset.
 
-[24/10/09] We supported downloading pre-trained models and datasets from the **[Modelers Hub](https://modelers.cn/models)**. See [this tutorial](#download-from-modelers-hub) for usage.
-
 <details><summary>Full Changelog</summary>
 
+[24/10/09] We supported downloading pre-trained models and datasets from the **[Modelers Hub](https://modelers.cn/models)**. See [this tutorial](#download-from-modelers-hub) for usage.
+
 [24/09/19] We supported fine-tuning the **[Qwen2.5](https://qwenlm.github.io/blog/qwen2.5/)** models.
 
 [24/08/30] We supported fine-tuning the **[Qwen2-VL](https://qwenlm.github.io/blog/qwen2-vl/)** models. Thank [@simonJJJ](https://github.com/simonJJJ)'s PR.
@@ -211,8 +213,9 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
 | [OLMo](https://huggingface.co/allenai) | 1B/7B | - |
 | [PaliGemma/PaliGemma2](https://huggingface.co/google) | 3B/10B/28B | paligemma |
 | [Phi-1.5/Phi-2](https://huggingface.co/microsoft) | 1.3B/2.7B | - |
-| [Phi-3](https://huggingface.co/microsoft) | 4B/14B | phi |
+| [Phi-3/Phi-3.5](https://huggingface.co/microsoft) | 4B/14B | phi |
 | [Phi-3-small](https://huggingface.co/microsoft) | 7B | phi_small |
+| [Phi-4](https://huggingface.co/microsoft) | 14B | phi4 |
 | [Pixtral](https://huggingface.co/mistralai) | 12B | pixtral |
 | [Qwen/QwQ (1-2.5) (Code/Math/MoE)](https://huggingface.co/Qwen) | 0.5B/1.5B/3B/7B/14B/32B/72B/110B | qwen |
 | [Qwen2-VL/QVQ](https://huggingface.co/Qwen) | 2B/7B/72B | qwen2_vl |
@@ -762,7 +765,7 @@ If you have a project that should be incorporated, please contact via email or c
 
 This repository is licensed under the [Apache-2.0 License](LICENSE).
 
-Please follow the model licenses to use the corresponding model weights: [Baichuan 2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM-4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [GPT-2](https://github.com/openai/gpt-2/blob/master/LICENSE) / [Granite](LICENSE) / [Index](https://huggingface.co/IndexTeam/Index-1.9B/blob/main/LICENSE) / [InternLM2](https://github.com/InternLM/InternLM#license) / [Llama](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [Llama 2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [Llama 3](https://llama.meta.com/llama3/license/) / [MiniCPM](https://github.com/OpenBMB/MiniCPM/blob/main/MiniCPM%20Model%20License.md) / [Mistral/Mixtral/Pixtral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/Phi-2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [Skywork](https://huggingface.co/Skywork/Skywork-13B-base/blob/main/Skywork%20Community%20License.pdf) / [StarCoder 2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [TeleChat2](https://huggingface.co/Tele-AI/telechat-7B/blob/main/TeleChat%E6%A8%A1%E5%9E%8B%E7%A4%BE%E5%8C%BA%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.pdf) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan 2](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)
+Please follow the model licenses to use the corresponding model weights: [Baichuan 2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM-4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [GPT-2](https://github.com/openai/gpt-2/blob/master/LICENSE) / [Granite](LICENSE) / [Index](https://huggingface.co/IndexTeam/Index-1.9B/blob/main/LICENSE) / [InternLM2](https://github.com/InternLM/InternLM#license) / [Llama](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [Llama 2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [Llama 3](https://llama.meta.com/llama3/license/) / [MiniCPM](https://github.com/OpenBMB/MiniCPM/blob/main/MiniCPM%20Model%20License.md) / [Mistral/Mixtral/Pixtral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/Phi-2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3/Phi-4](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [Skywork](https://huggingface.co/Skywork/Skywork-13B-base/blob/main/Skywork%20Community%20License.pdf) / [StarCoder 2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [TeleChat2](https://huggingface.co/Tele-AI/telechat-7B/blob/main/TeleChat%E6%A8%A1%E5%9E%8B%E7%A4%BE%E5%8C%BA%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.pdf) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan 2](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)
 
 ## Citation


@@ -89,14 +89,16 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272
 ## 更新日志
 
+[25/01/10] 我们支持了 **[Phi-4](https://huggingface.co/microsoft/phi-4)** 模型的微调。
+
 [24/12/21] 我们支持了使用 **[SwanLab](https://github.com/SwanHubX/SwanLab)** 跟踪与可视化实验。详细用法请参考 [此部分](#使用-swanlab-面板)。
 
 [24/11/27] 我们支持了 **[Skywork-o1](https://huggingface.co/Skywork/Skywork-o1-Open-Llama-3.1-8B)** 模型的微调和 **[OpenO1](https://huggingface.co/datasets/O1-OPEN/OpenO1-SFT)** 数据集。
 
-[24/10/09] 我们支持了从 **[魔乐社区](https://modelers.cn/models)** 下载预训练模型和数据集。详细用法请参照 [此教程](#从魔乐社区下载)。
-
 <details><summary>展开日志</summary>
 
+[24/10/09] 我们支持了从 **[魔乐社区](https://modelers.cn/models)** 下载预训练模型和数据集。详细用法请参照 [此教程](#从魔乐社区下载)。
+
 [24/09/19] 我们支持了 **[Qwen2.5](https://qwenlm.github.io/blog/qwen2.5/)** 模型的微调。
 
 [24/08/30] 我们支持了 **[Qwen2-VL](https://qwenlm.github.io/blog/qwen2-vl/)** 模型的微调。感谢 [@simonJJJ](https://github.com/simonJJJ) 的 PR。
@@ -212,8 +214,9 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272
 | [OLMo](https://huggingface.co/allenai) | 1B/7B | - |
 | [PaliGemma/PaliGemma2](https://huggingface.co/google) | 3B/10B/28B | paligemma |
 | [Phi-1.5/Phi-2](https://huggingface.co/microsoft) | 1.3B/2.7B | - |
-| [Phi-3](https://huggingface.co/microsoft) | 4B/14B | phi |
+| [Phi-3/Phi-3.5](https://huggingface.co/microsoft) | 4B/14B | phi |
 | [Phi-3-small](https://huggingface.co/microsoft) | 7B | phi_small |
+| [Phi-4](https://huggingface.co/microsoft) | 14B | phi4 |
 | [Pixtral](https://huggingface.co/mistralai) | 12B | pixtral |
 | [Qwen/QwQ (1-2.5) (Code/Math/MoE)](https://huggingface.co/Qwen) | 0.5B/1.5B/3B/7B/14B/32B/72B/110B | qwen |
 | [Qwen2-VL/QVQ](https://huggingface.co/Qwen) | 2B/7B/72B | qwen2_vl |
@@ -763,7 +766,7 @@ swanlab_run_name: test_run # 可选
 
 本仓库的代码依照 [Apache-2.0](LICENSE) 协议开源。
 
-使用模型权重时,请遵循对应的模型协议:[Baichuan 2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM-4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [GPT-2](https://github.com/openai/gpt-2/blob/master/LICENSE) / [Granite](LICENSE) / [Index](https://huggingface.co/IndexTeam/Index-1.9B/blob/main/LICENSE) / [InternLM2](https://github.com/InternLM/InternLM#license) / [Llama](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [Llama 2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [Llama 3](https://llama.meta.com/llama3/license/) / [MiniCPM](https://github.com/OpenBMB/MiniCPM/blob/main/MiniCPM%20Model%20License.md) / [Mistral/Mixtral/Pixtral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/Phi-2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [Skywork](https://huggingface.co/Skywork/Skywork-13B-base/blob/main/Skywork%20Community%20License.pdf) / [StarCoder 2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [TeleChat2](https://huggingface.co/Tele-AI/telechat-7B/blob/main/TeleChat%E6%A8%A1%E5%9E%8B%E7%A4%BE%E5%8C%BA%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.pdf) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan 2](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)
+使用模型权重时,请遵循对应的模型协议:[Baichuan 2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM-4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [GPT-2](https://github.com/openai/gpt-2/blob/master/LICENSE) / [Granite](LICENSE) / [Index](https://huggingface.co/IndexTeam/Index-1.9B/blob/main/LICENSE) / [InternLM2](https://github.com/InternLM/InternLM#license) / [Llama](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [Llama 2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [Llama 3](https://llama.meta.com/llama3/license/) / [MiniCPM](https://github.com/OpenBMB/MiniCPM/blob/main/MiniCPM%20Model%20License.md) / [Mistral/Mixtral/Pixtral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/Phi-2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3/Phi-4](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [Skywork](https://huggingface.co/Skywork/Skywork-13B-base/blob/main/Skywork%20Community%20License.pdf) / [StarCoder 2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [TeleChat2](https://huggingface.co/Tele-AI/telechat-7B/blob/main/TeleChat%E6%A8%A1%E5%9E%8B%E7%A4%BE%E5%8C%BA%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.pdf) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan 2](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)
 
 ## 引用

Binary file not shown (image changed: 163 KiB → 164 KiB).

Binary file not shown (image changed: 168 KiB → 167 KiB).


@@ -95,6 +95,12 @@ FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500
 FORCE_TORCHRUN=1 llamafactory-cli train examples/train_lora/llama3_lora_sft_ds3.yaml
 ```
 
+#### Supervised Fine-Tuning with Ray on 4 GPUs
+
+```bash
+USE_RAY=1 llamafactory-cli train examples/train_full/llama3_lora_sft_ray.yaml
+```
+
 ### QLoRA Fine-Tuning
 
 #### Supervised Fine-Tuning with 4/8-bit Bitsandbytes/HQQ/EETQ Quantization (Recommended)


@@ -95,6 +95,12 @@ FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500
 FORCE_TORCHRUN=1 llamafactory-cli train examples/train_lora/llama3_lora_sft_ds3.yaml
 ```
 
+#### 使用 Ray 在 4 张 GPU 上微调
+
+```bash
+USE_RAY=1 llamafactory-cli train examples/train_full/llama3_lora_sft_ray.yaml
+```
+
 ### QLoRA 微调
 
 #### 基于 4/8 比特 Bitsandbytes/HQQ/EETQ 量化进行指令监督微调(推荐)


@@ -0,0 +1,48 @@
### model
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct # or use local absolute path
trust_remote_code: true
### method
stage: sft
do_train: true
finetuning_type: lora
lora_target: all
### dataset
dataset: identity,alpaca_en_demo
dataset_dir: REMOTE:llamafactory/demo_data # or use local absolute path
template: llama3
cutoff_len: 2048
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16
### output
output_dir: tmp_dir
logging_steps: 10
save_steps: 500
plot_loss: true
overwrite_output_dir: true
### train
per_device_train_batch_size: 1
gradient_accumulation_steps: 8
learning_rate: 1.0e-4
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
ddp_timeout: 180000000
### eval
val_size: 0.1
per_device_eval_batch_size: 1
eval_strategy: steps
eval_steps: 500
### ray
ray_run_name: llama3_8b_sft_lora
ray_num_workers: 4 # number of GPUs to use
resources_per_worker:
GPU: 1
placement_strategy: PACK
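
For orientation, the `ray` block above maps naturally onto Ray Train's scaling options. A minimal sketch of that correspondence (the Ray API names below are real; how LLaMA-Factory wires them internally is an assumption, not shown in this diff):

```python
# Sketch: how ray_num_workers / resources_per_worker / placement_strategy
# from the YAML above could translate to Ray Train's public API.
from ray.train import ScalingConfig

scaling_config = ScalingConfig(
    num_workers=4,                    # ray_num_workers: one training worker per GPU
    use_gpu=True,
    resources_per_worker={"GPU": 1},  # resources_per_worker from the YAML
    placement_strategy="PACK",        # placement_strategy from the YAML
)
```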


@@ -63,7 +63,7 @@ class HuggingfaceEngine(BaseEngine):
         try:
             asyncio.get_event_loop()
         except RuntimeError:
-            logger.warning_once("There is no current event loop, creating a new one.")
+            logger.warning_rank0_once("There is no current event loop, creating a new one.")
             loop = asyncio.new_event_loop()
             asyncio.set_event_loop(loop)


@@ -24,7 +24,7 @@ from .chat.chat_model import run_chat
 from .eval.evaluator import run_eval
 from .extras import logging
 from .extras.env import VERSION, print_env
-from .extras.misc import get_device_count
+from .extras.misc import get_device_count, use_ray
 from .train.tuner import export_model, run_exp
 from .webui.interface import run_web_demo, run_web_ui
@@ -87,7 +87,7 @@ def main():
         export_model()
     elif command == Command.TRAIN:
         force_torchrun = os.getenv("FORCE_TORCHRUN", "0").lower() in ["true", "1"]
-        if force_torchrun or get_device_count() > 1:
+        if force_torchrun or (get_device_count() > 1 and not use_ray()):
             master_addr = os.getenv("MASTER_ADDR", "127.0.0.1")
             master_port = os.getenv("MASTER_PORT", str(random.randint(20001, 29999)))
             logger.info_rank0(f"Initializing distributed tasks at: {master_addr}:{master_port}")
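
The gating change keeps Ray runs out of the torchrun path. A minimal sketch of `use_ray()`, assuming it mirrors the `FORCE_TORCHRUN` env-var pattern visible in the same function and reads the new `USE_RAY` variable from the .env diff above (the real helper lives in `extras/misc.py` and is not shown here):

```python
import os

def use_ray() -> bool:
    # Assumed implementation: when USE_RAY is truthy, training is handed to
    # the Ray launcher instead of spawning a torchrun process per GPU.
    return os.getenv("USE_RAY", "0").lower() in ["true", "1"]
```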


@@ -56,12 +56,12 @@ def merge_dataset(
         return all_datasets[0]
     elif data_args.mix_strategy == "concat":
         if data_args.streaming:
-            logger.warning_once("The samples between different datasets will not be mixed in streaming mode.")
+            logger.warning_rank0_once("The samples between different datasets will not be mixed in streaming mode.")
         return concatenate_datasets(all_datasets)
     elif data_args.mix_strategy.startswith("interleave"):
         if not data_args.streaming:
-            logger.warning_once("We recommend using `mix_strategy=concat` in non-streaming mode.")
+            logger.warning_rank0_once("We recommend using `mix_strategy=concat` in non-streaming mode.")
         return interleave_datasets(
             datasets=all_datasets,
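
Both renames here and in `HuggingfaceEngine` above switch to a rank-aware, deduplicated warning. A sketch of the assumed semantics of `warning_rank0_once` (the real helper lives in `extras/logging.py`; the caching and rank-detection details below are guesses):

```python
import functools
import logging
import os

logger = logging.getLogger("llamafactory")

@functools.lru_cache(maxsize=None)  # "once": each distinct message fires a single time
def warning_rank0_once(msg: str) -> None:
    if int(os.getenv("LOCAL_RANK", "0")) == 0:  # "rank0": only the main worker prints
        logger.warning(msg)
```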


@@ -18,11 +18,10 @@ from typing import TYPE_CHECKING, Dict, Literal, Optional, Sequence, Union
 
 import numpy as np
 from datasets import DatasetDict, load_dataset, load_from_disk
-from transformers.utils.versions import require_version
 
 from ..extras import logging
 from ..extras.constants import FILEEXT2TYPE
-from ..extras.misc import has_tokenized_data
+from ..extras.misc import check_version, has_tokenized_data
 from .aligner import align_dataset
 from .data_utils import merge_dataset, split_dataset
 from .parser import get_dataset_list
@@ -84,7 +83,7 @@ def _load_single_dataset(
         raise NotImplementedError(f"Unknown load type: {dataset_attr.load_from}.")
 
     if dataset_attr.load_from == "ms_hub":
-        require_version("modelscope>=1.11.0", "To fix: pip install modelscope>=1.11.0")
+        check_version("modelscope>=1.11.0", mandatory=True)
         from modelscope import MsDataset  # type: ignore
         from modelscope.utils.config_ds import MS_DATASETS_CACHE  # type: ignore
@@ -103,7 +102,7 @@ def _load_single_dataset(
         dataset = dataset.to_hf_dataset()
     elif dataset_attr.load_from == "om_hub":
-        require_version("openmind>=0.8.0", "To fix: pip install openmind>=0.8.0")
+        check_version("openmind>=0.8.0", mandatory=True)
         from openmind import OmDataset  # type: ignore
         from openmind.utils.hub import OM_DATASETS_CACHE  # type: ignore
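
`check_version` replaces direct `require_version` calls here and again in `template.py` below. A plausible sketch of the new helper (assumed; the actual implementation in `extras/misc.py` is not part of this diff):

```python
import os

from transformers.utils.versions import require_version

def check_version(requirement: str, mandatory: bool = False) -> None:
    # Assumed behavior: optional checks can be switched off globally, while
    # mandatory ones (hard runtime dependencies) are always enforced; the
    # pip hint is derived from the requirement instead of repeated by callers.
    if not mandatory and os.getenv("DISABLE_VERSION_CHECK", "0").lower() in ["true", "1"]:
        return
    require_version(requirement, f"To fix: run `pip install {requirement}`.")
```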


@@ -75,10 +75,14 @@ class BasePlugin:
         Validates if this model accepts the input modalities.
         """
         if len(images) != 0 and self.image_token is None:
-            raise ValueError("This model does not support image input.")
+            raise ValueError(
+                "This model does not support image input. Please check whether the correct `template` is used."
+            )
 
         if len(videos) != 0 and self.video_token is None:
-            raise ValueError("This model does not support video input.")
+            raise ValueError(
+                "This model does not support video input. Please check whether the correct `template` is used."
+            )
 
     def _preprocess_image(self, image: "ImageObject", **kwargs) -> "ImageObject":
         r"""


@@ -15,10 +15,10 @@
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union
 
-from transformers.utils.versions import require_version
 from typing_extensions import override
 
 from ..extras import logging
+from ..extras.misc import check_version
 from .data_utils import Role
 from .formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter
 from .mm_plugin import get_mm_plugin
@@ -44,7 +44,6 @@ class Template:
     format_function: "Formatter"
     format_observation: "Formatter"
     format_tools: "Formatter"
-    format_separator: "Formatter"
     format_prefix: "Formatter"
     default_system: str
     stop_words: List[str]
@@ -113,9 +112,6 @@ class Template:
             tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
             elements += self.format_system.apply(content=(system + tool_text))
 
-        if i > 0 and i % 2 == 0:
-            elements += self.format_separator.apply()
-
         if message["role"] == Role.USER.value:
             elements += self.format_user.apply(content=message["content"], idx=str(i // 2))
         elif message["role"] == Role.ASSISTANT.value:
@@ -180,9 +176,6 @@ class Llama2Template(Template):
             tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
             system_text = self.format_system.apply(content=(system + tool_text))[0]
 
-        if i > 0 and i % 2 == 0:
-            elements += self.format_separator.apply()
-
         if message["role"] == Role.USER.value:
             elements += self.format_user.apply(content=system_text + message["content"])
         elif message["role"] == Role.ASSISTANT.value:
@@ -210,7 +203,6 @@ def _register_template(
     format_function: Optional["Formatter"] = None,
    format_observation: Optional["Formatter"] = None,
    format_tools: Optional["Formatter"] = None,
-    format_separator: Optional["Formatter"] = None,
    format_prefix: Optional["Formatter"] = None,
    default_system: str = "",
    stop_words: Sequence[str] = [],
@@ -224,34 +216,28 @@ def _register_template(
     To add the following chat template:
     ```
-    [HUMAN]:
-    user prompt here
-    [AI]:
-    model response here
-
-    [HUMAN]:
-    user prompt here
-    [AI]:
-    model response here
+    <s><user>user prompt here
+    <model>model response here</s>
+    <user>user prompt here
+    <model>model response here</s>
     ```
 
     The corresponding code should be:
     ```
     _register_template(
         name="custom",
-        format_user=StringFormatter(slots=["[HUMAN]:\n{{content}}\n[AI]:\n"]),
-        format_separator=EmptyFormatter(slots=["\n\n"]),
-        efficient_eos=True,
+        format_user=StringFormatter(slots=["<user>{{content}}\n<model>"]),
+        format_assistant=StringFormatter(slots=["{{content}}</s>\n"]),
+        format_prefix=EmptyFormatter("<s>"),
     )
     ```
     """
-    template_class = Llama2Template if any(k in name for k in ("llama2", "mistral")) else Template
+    template_class = Llama2Template if any(k in name for k in ("llama2", "mistral", "pixtral")) else Template
     default_slots = ["{{content}}"] if efficient_eos else ["{{content}}", {"eos_token"}]
     default_user_formatter = StringFormatter(slots=["{{content}}"])
     default_assistant_formatter = StringFormatter(slots=default_slots)
     default_function_formatter = FunctionFormatter(slots=default_slots, tool_format="default")
     default_tool_formatter = ToolFormatter(tool_format="default")
-    default_separator_formatter = EmptyFormatter()
     default_prefix_formatter = EmptyFormatter()
     TEMPLATES[name] = template_class(
         format_user=format_user or default_user_formatter,
@@ -260,7 +246,6 @@ def _register_template(
         format_function=format_function or default_function_formatter,
         format_observation=format_observation or format_user or default_user_formatter,
         format_tools=format_tools or default_tool_formatter,
-        format_separator=format_separator or default_separator_formatter,
         format_prefix=format_prefix or default_prefix_formatter,
         default_system=default_system,
         stop_words=stop_words,
@@ -344,9 +329,7 @@ def _get_jinja_template(template: "Template", tokenizer: "PreTrainedTokenizer")
     jinja_template += "{{ " + user_message + " }}"
 
     jinja_template += "{% elif message['role'] == 'assistant' %}"
-    assistant_message = _convert_slots_to_jinja(
-        template.format_assistant.apply() + template.format_separator.apply(), tokenizer
-    )
+    assistant_message = _convert_slots_to_jinja(template.format_assistant.apply(), tokenizer)
     jinja_template += "{{ " + assistant_message + " }}"
     jinja_template += "{% endif %}"
     jinja_template += "{% endfor %}"
@@ -365,7 +348,7 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args:
         raise ValueError(f"Template {data_args.template} does not exist.")
 
     if template.mm_plugin.__class__.__name__ != "BasePlugin":
-        require_version("transformers>=4.45.0", "To fix: pip install transformers>=4.45.0")
+        check_version("transformers>=4.45.0")
 
     if data_args.train_on_prompt and template.efficient_eos:
         raise ValueError("Current template does not support `train_on_prompt`.")
@@ -411,7 +394,7 @@ get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args:
 _register_template(
     name="alpaca",
     format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n\n### Response:\n"]),
-    format_separator=EmptyFormatter(slots=["\n\n"]),
+    format_assistant=StringFormatter(slots=["{{content}}", {"eos_token"}, "\n\n"]),
     default_system=(
         "Below is an instruction that describes a task. "
         "Write a response that appropriately completes the request.\n\n"
@@ -423,13 +406,13 @@ _register_template(
 _register_template(
     name="aquila",
     format_user=StringFormatter(slots=["Human: {{content}}###Assistant:"]),
-    format_separator=EmptyFormatter(slots=["###"]),
+    format_assistant=StringFormatter(slots=["{{content}}###"]),
+    format_system=StringFormatter(slots=["System: {{content}}###"]),
     default_system=(
         "A chat between a curious human and an artificial intelligence assistant. "
         "The assistant gives helpful, detailed, and polite answers to the human's questions."
     ),
     stop_words=["</s>"],
-    efficient_eos=True,
 )
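
The alpaca and aquila changes above show the pattern applied throughout the rest of this file: the per-round `format_separator` is deleted and its text is folded into `format_assistant`, so each assistant turn carries its own trailing separator. A quick sanity check of that equivalence using the alpaca slots (illustrative only; `</s>` stands in for the tokenizer's eos token):

```python
# Old pipeline: assistant turn = "{{content}}" + eos, then a "\n\n" separator
# was emitted between rounds. New pipeline: "{{content}}" + eos + "\n\n".
eos = "</s>"  # placeholder for the real eos token
old_round = "some response" + eos
old_separator = "\n\n"
new_round = "some response" + eos + "\n\n"
assert old_round + old_separator == new_round  # identical rendered text per round
```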
@@ -459,7 +442,7 @@ _register_template(
 _register_template(
     name="belle",
     format_user=StringFormatter(slots=["Human: {{content}}\n\nBelle: "]),
-    format_separator=EmptyFormatter(slots=["\n\n"]),
+    format_assistant=StringFormatter(slots=["{{content}}", {"eos_token"}, "\n\n"]),
     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
 )
@@ -481,7 +464,6 @@ _register_template(
 _register_template(
     name="chatglm2",
     format_user=StringFormatter(slots=["[Round {{idx}}]\n\n问:{{content}}\n\n答:"]),
-    format_separator=EmptyFormatter(slots=["\n\n"]),
     format_prefix=EmptyFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}]),
     efficient_eos=True,
 )
@@ -506,9 +488,9 @@ _register_template(
 _register_template(
     name="chatml",
     format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
     format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
     format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
-    format_separator=EmptyFormatter(slots=["\n"]),
     stop_words=["<|im_end|>", "<|im_start|>"],
     replace_eos=True,
     replace_jinja_template=True,
@@ -519,9 +501,9 @@ _register_template(
 _register_template(
     name="chatml_de",
     format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
     format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
     format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
-    format_separator=EmptyFormatter(slots=["\n"]),
     default_system="Du bist ein freundlicher und hilfsbereiter KI-Assistent.",
     stop_words=["<|im_end|>", "<|im_start|>"],
     replace_eos=True,
@@ -574,9 +556,11 @@ _register_template(
 )
 
 
+# copied from chatml template
 _register_template(
     name="cpm3",
     format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
     format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
     stop_words=["<|im_end|>"],
@@ -603,9 +587,9 @@ _register_template(
 _register_template(
     name="dbrx",
     format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
     format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
     format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
-    format_separator=EmptyFormatter(slots=["\n"]),
     default_system=(
         "You are DBRX, created by Databricks. You were last updated in December 2023. "
         "You answer questions based on information available up to that point.\n"
@@ -622,7 +606,6 @@ _register_template(
         "ABOUT YOURSELF UNLESS THE INFORMATION IS DIRECTLY PERTINENT TO THE USER'S QUERY."
     ),
     stop_words=["<|im_end|>"],
-    replace_eos=True,
 )
@@ -644,8 +627,7 @@ _register_template(
 _register_template(
     name="deepseekcoder",
     format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n### Response:"]),
-    format_assistant=StringFormatter(slots=["\n{{content}}\n<|EOT|>"]),
-    format_separator=EmptyFormatter(slots=["\n"]),
+    format_assistant=StringFormatter(slots=["\n{{content}}\n<|EOT|>\n"]),
     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
     default_system=(
         "You are an AI programming assistant, utilizing the DeepSeek Coder model, "
@@ -659,8 +641,8 @@ _register_template(
 _register_template(
     name="default",
     format_user=StringFormatter(slots=["Human: {{content}}\nAssistant:"]),
-    format_system=StringFormatter(slots=["{{content}}\n"]),
-    format_separator=EmptyFormatter(slots=["\n"]),
+    format_assistant=StringFormatter(slots=["{{content}}", {"eos_token"}, "\n"]),
+    format_system=StringFormatter(slots=["System: {{content}}\n"]),
 )
@@ -673,22 +655,22 @@ _register_template(
 _register_template(
     name="exaone",
     format_user=StringFormatter(slots=["[|user|]{{content}}\n[|assistant|]"]),
+    format_assistant=StringFormatter(slots=["{{content}}", {"eos_token"}, "\n"]),
     format_system=StringFormatter(slots=["[|system|]{{content}}[|endofturn|]\n"]),
-    format_separator=EmptyFormatter(slots=["\n"]),
 )
 
 
 _register_template(
     name="falcon",
     format_user=StringFormatter(slots=["User: {{content}}\nFalcon:"]),
-    format_separator=EmptyFormatter(slots=["\n"]),
+    format_assistant=StringFormatter(slots=["{{content}}\n"]),
     efficient_eos=True,
 )
 
 
 _register_template(
     name="fewshot",
-    format_separator=EmptyFormatter(slots=["\n\n"]),
+    format_assistant=StringFormatter(slots=["{{content}}\n\n"]),
     efficient_eos=True,
 )
@@ -696,12 +678,11 @@ _register_template(
 _register_template(
     name="gemma",
     format_user=StringFormatter(slots=["<start_of_turn>user\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]),
+    format_assistant=StringFormatter(slots=["{{content}}<end_of_turn>\n"]),
     format_observation=StringFormatter(
         slots=["<start_of_turn>tool\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]
     ),
-    format_separator=EmptyFormatter(slots=["<end_of_turn>\n"]),
     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
-    efficient_eos=True,
 )
@@ -726,8 +707,8 @@ _register_template(
             "<|start_of_role|>user<|end_of_role|>{{content}}<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|>"
         ]
     ),
+    format_assistant=StringFormatter(slots=["{{content}}<|end_of_text|>\n"]),
     format_system=StringFormatter(slots=["<|start_of_role|>system<|end_of_role|>{{content}}<|end_of_text|>\n"]),
-    format_separator=EmptyFormatter(slots=["\n"]),
 )
@@ -742,22 +723,20 @@ _register_template(
 _register_template(
     name="intern",
     format_user=StringFormatter(slots=["<|User|>:{{content}}\n<|Bot|>:"]),
+    format_assistant=StringFormatter(slots=["{{content}}<eoa>\n"]),
     format_system=StringFormatter(slots=["<|System|>:{{content}}\n"]),
-    format_separator=EmptyFormatter(slots=["<eoa>\n"]),
     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
     stop_words=["<eoa>"],
-    efficient_eos=True,
+    efficient_eos=True,  # internlm tokenizer cannot set eos_token_id
 )
 
 
 _register_template(
     name="intern2",
     format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
     format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
-    format_separator=EmptyFormatter(slots=["<|im_end|>\n"]),
     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
     stop_words=["<|im_end|>"],
-    efficient_eos=True,
+    efficient_eos=True,  # internlm2 tokenizer cannot set eos_token_id
 )
@@ -888,6 +867,7 @@ _register_template(
     name="llava_next_mistral",
     format_user=StringFormatter(slots=["[INST] {{content}}[/INST]"]),
     format_assistant=StringFormatter(slots=[" {{content}}", {"eos_token"}]),
+    format_system=StringFormatter(slots=["{{content}}\n\n"]),
     format_function=FunctionFormatter(slots=["[TOOL_CALLS] ", "{{content}}", {"eos_token"}], tool_format="mistral"),
     format_observation=StringFormatter(slots=["""[TOOL_RESULTS] {"content": {{content}}}[/TOOL_RESULTS]"""]),
     format_tools=ToolFormatter(tool_format="mistral"),
@@ -900,16 +880,15 @@ _register_template(
 _register_template(
     name="llava_next_qwen",
     format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
     format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
-    format_function=FunctionFormatter(slots=["{{content}}", "<|im_end|>"], tool_format="qwen"),
+    format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen"),
     format_observation=StringFormatter(
         slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
     ),
     format_tools=ToolFormatter(tool_format="qwen"),
-    format_separator=EmptyFormatter(slots=["\n"]),
     default_system="You are a helpful assistant.",
     stop_words=["<|im_end|>"],
-    replace_eos=True,
     mm_plugin=get_mm_plugin(name="llava_next", image_token="<image>"),
 )
@@ -918,10 +897,9 @@ _register_template(
 _register_template(
     name="llava_next_yi",
     format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
     format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
-    format_separator=EmptyFormatter(slots=["\n"]),
     stop_words=["<|im_end|>"],
-    replace_eos=True,
     mm_plugin=get_mm_plugin(name="llava_next", image_token="<image>"),
 )
@@ -943,6 +921,7 @@ _register_template(
     name="llava_next_video_mistral",
     format_user=StringFormatter(slots=["[INST] {{content}}[/INST]"]),
     format_assistant=StringFormatter(slots=[" {{content}}", {"eos_token"}]),
+    format_system=StringFormatter(slots=["{{content}}\n\n"]),
     format_function=FunctionFormatter(slots=["[TOOL_CALLS] ", "{{content}}", {"eos_token"}], tool_format="mistral"),
     format_observation=StringFormatter(slots=["""[TOOL_RESULTS] {"content": {{content}}}[/TOOL_RESULTS]"""]),
     format_tools=ToolFormatter(tool_format="mistral"),
@@ -955,10 +934,9 @@ _register_template(
 _register_template(
     name="llava_next_video_yi",
     format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
     format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
-    format_separator=EmptyFormatter(slots=["\n"]),
     stop_words=["<|im_end|>"],
-    replace_eos=True,
     mm_plugin=get_mm_plugin(name="llava_next_video", image_token="<image>", video_token="<video>"),
 )
@@ -967,16 +945,15 @@ _register_template(
 _register_template(
     name="marco",
     format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
     format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
     format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
-    format_separator=EmptyFormatter(slots=["\n"]),
     default_system=(
         "你是一个经过良好训练的AI助手你的名字是Marco-o1.由阿里国际数字商业集团的AI Business创造.\n## 重要!!!!!\n"
         "当你回答问题时,你的思考应该在<Thought>内完成,<Output>内输出你的结果。\n"
         "<Thought>应该尽可能是英文但是有2个特例一个是对原文中的引用另一个是是数学应该使用markdown格式<Output>内的输出需要遵循用户输入的语言。\n"
     ),
     stop_words=["<|im_end|>"],
-    replace_eos=True,
 )
@@ -984,6 +961,7 @@ _register_template(
     name="mistral",
     format_user=StringFormatter(slots=["[INST] {{content}}[/INST]"]),
     format_assistant=StringFormatter(slots=[" {{content}}", {"eos_token"}]),
+    format_system=StringFormatter(slots=["{{content}}\n\n"]),
     format_function=FunctionFormatter(slots=["[TOOL_CALLS] ", "{{content}}", {"eos_token"}], tool_format="mistral"),
     format_observation=StringFormatter(slots=["""[TOOL_RESULTS] {"content": {{content}}}[/TOOL_RESULTS]"""]),
     format_tools=ToolFormatter(tool_format="mistral"),
@@ -1017,7 +995,6 @@ _register_template(
     ),
     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
     stop_words=["<|eot_id|>"],
-    replace_eos=True,
 )
@@ -1025,9 +1002,9 @@ _register_template(
 _register_template(
     name="opencoder",
     format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
     format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
     format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
-    format_separator=EmptyFormatter(slots=["\n"]),
     default_system="You are OpenCoder, created by OpenCoder Team.",
     stop_words=["<|im_end|>"],
 )
@@ -1044,12 +1021,11 @@ _register_template(
 _register_template(
     name="paligemma",
     format_user=StringFormatter(slots=["<start_of_turn>user\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]),
+    format_assistant=StringFormatter(slots=["{{content}}<end_of_turn>\n"]),
     format_observation=StringFormatter(
         slots=["<start_of_turn>tool\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]
     ),
-    format_separator=EmptyFormatter(slots=["<end_of_turn>\n"]),
     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
-    efficient_eos=True,
     mm_plugin=get_mm_plugin(name="paligemma", image_token="<image>"),
 )
@@ -1057,28 +1033,37 @@ _register_template(
 _register_template(
     name="phi",
     format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>\n"]),
+    format_assistant=StringFormatter(slots=["{{content}}<|end|>\n"]),
     format_system=StringFormatter(slots=["<|system|>\n{{content}}<|end|>\n"]),
-    format_separator=EmptyFormatter(slots=["\n"]),
-    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
     stop_words=["<|end|>"],
-    replace_eos=True,
 )


 _register_template(
     name="phi_small",
     format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>\n"]),
+    format_assistant=StringFormatter(slots=["{{content}}<|end|>\n"]),
     format_system=StringFormatter(slots=["<|system|>\n{{content}}<|end|>\n"]),
-    format_separator=EmptyFormatter(slots=["\n"]),
     format_prefix=EmptyFormatter(slots=[{"<|endoftext|>"}]),
     stop_words=["<|end|>"],
-    replace_eos=True,
 )


+_register_template(
+    name="phi4",
+    format_user=StringFormatter(
+        slots=["<|im_start|>user<|im_sep|>{{content}}<|im_end|><|im_start|>assistant<|im_sep|>"]
+    ),
+    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>"]),
+    format_system=StringFormatter(slots=["<|im_start|>system<|im_sep|>{{content}}<|im_end|>"]),
+    stop_words=["<|im_end|>"],
+)


 _register_template(
     name="pixtral",
-    format_user=StringFormatter(slots=["[INST] {{content}} [/INST]"]),
+    format_user=StringFormatter(slots=["[INST]{{content}}[/INST]"]),
+    format_system=StringFormatter(slots=["{{content}}\n\n"]),
     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
     mm_plugin=get_mm_plugin(name="pixtral", image_token="[IMG]"),
 )
@@ -1088,13 +1073,13 @@ _register_template(
 _register_template(
     name="qwen",
     format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
     format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
-    format_function=FunctionFormatter(slots=["{{content}}", "<|im_end|>"], tool_format="qwen"),
+    format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen"),
     format_observation=StringFormatter(
         slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
     ),
     format_tools=ToolFormatter(tool_format="qwen"),
-    format_separator=EmptyFormatter(slots=["\n"]),
     default_system="You are a helpful assistant.",
     stop_words=["<|im_end|>"],
 )
@@ -1104,13 +1089,13 @@ _register_template(
 _register_template(
     name="qwen2_vl",
     format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
     format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
-    format_function=FunctionFormatter(slots=["{{content}}", "<|im_end|>"], tool_format="qwen"),
+    format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen"),
     format_observation=StringFormatter(
         slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
     ),
     format_tools=ToolFormatter(tool_format="qwen"),
-    format_separator=EmptyFormatter(slots=["\n"]),
     default_system="You are a helpful assistant.",
     stop_words=["<|im_end|>"],
     mm_plugin=get_mm_plugin(name="qwen2_vl", image_token="<|image_pad|>", video_token="<|video_pad|>"),
@@ -1120,8 +1105,8 @@ _register_template(
 _register_template(
     name="sailor",
     format_user=StringFormatter(slots=["<|im_start|>question\n{{content}}<|im_end|>\n<|im_start|>answer\n"]),
+    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
     format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
-    format_separator=EmptyFormatter(slots=["\n"]),
     default_system=(
         "You are an AI assistant named Sailor created by Sea AI Lab. "
         "Your answer should be friendly, unbiased, faithful, informative and detailed."
@@ -1173,10 +1158,9 @@ _register_template(
 _register_template(
     name="starchat",
     format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>"]),
+    format_assistant=StringFormatter(slots=["{{content}}<|end|>\n"]),
     format_system=StringFormatter(slots=["<|system|>\n{{content}}<|end|>\n"]),
-    format_separator=EmptyFormatter(slots=["\n"]),
     stop_words=["<|end|>"],
-    replace_eos=True,
 )
@@ -1239,8 +1223,8 @@ _register_template(
 _register_template(
     name="yayi",
     format_user=StringFormatter(slots=[{"token": "<|Human|>"}, ":\n{{content}}\n\n", {"token": "<|YaYi|>"}, ":"]),
+    format_assistant=StringFormatter(slots=["{{content}}\n\n"]),
     format_system=StringFormatter(slots=[{"token": "<|System|>"}, ":\n{{content}}\n\n"]),
-    format_separator=EmptyFormatter(slots=["\n\n"]),
     default_system=(
         "You are a helpful, respectful and honest assistant named YaYi "
         "developed by Beijing Wenge Technology Co.,Ltd. "
@@ -1260,17 +1244,16 @@ _register_template(
 _register_template(
     name="yi",
     format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
     format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
-    format_separator=EmptyFormatter(slots=["\n"]),
     stop_words=["<|im_end|>"],
-    replace_eos=True,
 )


 _register_template(
     name="yi_vl",
     format_user=StringFormatter(slots=["### Human: {{content}}\n### Assistant:"]),
-    format_separator=EmptyFormatter(slots=["\n"]),
+    format_assistant=StringFormatter(slots=["{{content}}\n"]),
     default_system=(
         "This is a chat between an inquisitive human and an AI assistant. "
         "Assume the role of the AI assistant. Read all the images carefully, "
@@ -1287,9 +1270,8 @@ _register_template(
 _register_template(
     name="yuan",
     format_user=StringFormatter(slots=["{{content}}", {"token": "<sep>"}]),
-    format_separator=EmptyFormatter(slots=["\n"]),
+    format_assistant=StringFormatter(slots=["{{content}}<eod>\n"]),
     stop_words=["<eod>"],
-    replace_eos=True,
 )
@@ -1304,5 +1286,5 @@ _register_template(
 _register_template(
     name="ziya",
     format_user=StringFormatter(slots=["<human>:{{content}}\n<bot>:"]),
-    format_separator=EmptyFormatter(slots=["\n"]),
+    format_assistant=StringFormatter(slots=["{{content}}\n"]),
 )
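Reviewer note: across all templates above the change is the same — the assistant turn's suffix moves from `format_separator`/`replace_eos` into an explicit `format_assistant`. A minimal sketch of registering a ChatML-style template under the new convention; `_register_template` and `StringFormatter` are the internal helpers from `src/llamafactory/data/`, and the template name `demo` is hypothetical:

```python
# Hedged sketch, assuming this branch is checked out: the assistant turn now
# carries its own "<|im_end|>\n" suffix via format_assistant, so the old
# format_separator + replace_eos pair is no longer needed.
from llamafactory.data.formatter import StringFormatter
from llamafactory.data.template import _register_template

_register_template(
    name="demo",  # hypothetical template name, for illustration only
    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
    stop_words=["<|im_end|>"],
)
```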
View File
@@ -1424,6 +1424,14 @@ register_model_group(
         DownloadSource.DEFAULT: "microsoft/Phi-3-medium-128k-instruct",
         DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-medium-128k-instruct",
     },
+    "Phi-3.5-4B-instruct": {
+        DownloadSource.DEFAULT: "microsoft/Phi-3.5-mini-instruct",
+        DownloadSource.MODELSCOPE: "LLM-Research/Phi-3.5-mini-instruct",
+    },
+    "Phi-3.5-MoE-42B-A6.6B-instruct": {
+        DownloadSource.DEFAULT: "microsoft/Phi-3.5-MoE-instruct",
+        DownloadSource.MODELSCOPE: "LLM-Research/Phi-3.5-MoE-instruct",
+    },
     },
     template="phi",
 )
@@ -1444,6 +1452,17 @@ register_model_group(
 )


+register_model_group(
+    models={
+        "Phi-4-14B-Instruct": {
+            DownloadSource.DEFAULT: "microsoft/phi-4",
+            DownloadSource.MODELSCOPE: "LLM-Research/phi-4",
+        },
+    },
+    template="phi4",
+)


 register_model_group(
     models={
         "Pixtral-12B-Instruct": {
View File
@@ -68,7 +68,7 @@ class LoggerHandler(logging.Handler):
 class _Logger(logging.Logger):
     r"""
-    A logger that supports info_rank0 and warning_once.
+    A logger that supports rank0 logging.
     """

     def info_rank0(self, *args, **kwargs) -> None:
@@ -77,7 +77,7 @@ class _Logger(logging.Logger):
     def warning_rank0(self, *args, **kwargs) -> None:
         self.warning(*args, **kwargs)

-    def warning_once(self, *args, **kwargs) -> None:
+    def warning_rank0_once(self, *args, **kwargs) -> None:
         self.warning(*args, **kwargs)
@@ -163,11 +163,11 @@ def warning_rank0(self: "logging.Logger", *args, **kwargs) -> None:
 @lru_cache(None)
-def warning_once(self: "logging.Logger", *args, **kwargs) -> None:
+def warning_rank0_once(self: "logging.Logger", *args, **kwargs) -> None:
     if int(os.getenv("LOCAL_RANK", "0")) == 0:
         self.warning(*args, **kwargs)


 logging.Logger.info_rank0 = info_rank0
 logging.Logger.warning_rank0 = warning_rank0
-logging.Logger.warning_once = warning_once
+logging.Logger.warning_rank0_once = warning_rank0_once
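Reviewer note: the rename makes the dedup behavior explicit in the name — the module-level `warning_rank0_once` is wrapped in `lru_cache`, so each unique message is emitted once, and only on local rank 0. A small sketch of a call site after this patch:

```python
from llamafactory.extras import logging

logger = logging.get_logger(__name__)
logger.warning_rank0_once("old GC format detected")  # printed when LOCAL_RANK == 0
logger.warning_rank0_once("old GC format detected")  # suppressed by lru_cache
```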
View File
@@ -73,19 +73,31 @@ class AverageMeter:
         self.avg = self.sum / self.count


+def check_version(requirement: str, mandatory: bool = False) -> None:
+    r"""
+    Optionally checks the package version.
+    """
+    if os.getenv("DISABLE_VERSION_CHECK", "0").lower() in ["true", "1"] and not mandatory:
+        logger.warning_rank0_once("Version checking has been disabled, may lead to unexpected behaviors.")
+        return
+
+    if mandatory:
+        hint = f"To fix: run `pip install {requirement}`."
+    else:
+        hint = f"To fix: run `pip install {requirement}` or set `DISABLE_VERSION_CHECK=1` to skip this check."
+
+    require_version(requirement, hint)
+
+
 def check_dependencies() -> None:
     r"""
     Checks the version of the required packages.
     """
-    if os.getenv("DISABLE_VERSION_CHECK", "0").lower() in ["true", "1"]:
-        logger.warning_once("Version checking has been disabled, may lead to unexpected behaviors.")
-        return
-
-    require_version("transformers>=4.41.2", "To fix: pip install transformers>=4.41.2")
-    require_version("datasets>=2.16.0,<=3.1.0", "To fix: pip install datasets>=2.16.0,<=3.1.0")
-    require_version("accelerate>=0.34.0,<=1.0.1", "To fix: pip install accelerate>=0.34.0,<=1.0.1")
-    require_version("peft>=0.11.1,<=0.12.0", "To fix: pip install peft>=0.11.1,<=0.12.0")
-    require_version("trl>=0.8.6,<=0.9.6", "To fix: pip install trl>=0.8.6,<=0.9.6")
+    check_version("transformers>=4.41.2,<=4.46.1")
+    check_version("datasets>=2.16.0,<=3.1.0")
+    check_version("accelerate>=0.34.0,<=1.0.1")
+    check_version("peft>=0.11.1,<=0.12.0")
+    check_version("trl>=0.8.6,<=0.9.6")


 def calculate_tps(dataset: Sequence[Dict[str, Any]], metrics: Dict[str, float], stage: Literal["sft", "rm"]) -> float:
@@ -229,7 +241,7 @@ def skip_check_imports() -> None:
     r"""
     Avoids flash attention import error in custom model files.
     """
-    if os.environ.get("FORCE_CHECK_IMPORTS", "0").lower() not in ["true", "1"]:
+    if os.getenv("FORCE_CHECK_IMPORTS", "0").lower() not in ["true", "1"]:
         transformers.dynamic_module_utils.check_imports = get_relative_imports
@@ -253,7 +265,7 @@ def try_download_model_from_other_hub(model_args: "ModelArguments") -> str:
         return model_args.model_name_or_path

     if use_modelscope():
-        require_version("modelscope>=1.11.0", "To fix: pip install modelscope>=1.11.0")
+        check_version("modelscope>=1.11.0", mandatory=True)
         from modelscope import snapshot_download  # type: ignore

         revision = "master" if model_args.model_revision == "main" else model_args.model_revision
@@ -264,7 +276,7 @@ def try_download_model_from_other_hub(model_args: "ModelArguments") -> str:
         )

     if use_openmind():
-        require_version("openmind>=0.8.0", "To fix: pip install openmind>=0.8.0")
+        check_version("openmind>=0.8.0", mandatory=True)
         from openmind.utils.hub import snapshot_download  # type: ignore

         return snapshot_download(
@@ -275,8 +287,12 @@ def try_download_model_from_other_hub(model_args: "ModelArguments") -> str:
 def use_modelscope() -> bool:
-    return os.environ.get("USE_MODELSCOPE_HUB", "0").lower() in ["true", "1"]
+    return os.getenv("USE_MODELSCOPE_HUB", "0").lower() in ["true", "1"]


 def use_openmind() -> bool:
-    return os.environ.get("USE_OPENMIND_HUB", "0").lower() in ["true", "1"]
+    return os.getenv("USE_OPENMIND_HUB", "0").lower() in ["true", "1"]
+
+
+def use_ray() -> bool:
+    return os.getenv("USE_RAY", "0").lower() in ["true", "1"]
View File
@@ -62,6 +62,10 @@ def is_pillow_available():
     return _is_package_available("PIL")


+def is_ray_available():
+    return _is_package_available("ray")
+
+
 def is_requests_available():
     return _is_package_available("requests")
View File
@@ -17,7 +17,8 @@ from .evaluation_args import EvaluationArguments
 from .finetuning_args import FinetuningArguments
 from .generating_args import GeneratingArguments
 from .model_args import ModelArguments
-from .parser import get_eval_args, get_infer_args, get_train_args
+from .parser import get_eval_args, get_infer_args, get_ray_args, get_train_args, read_args
+from .training_args import RayArguments, TrainingArguments


 __all__ = [
@@ -26,7 +27,11 @@ __all__ = [
     "FinetuningArguments",
     "GeneratingArguments",
     "ModelArguments",
+    "RayArguments",
+    "TrainingArguments",
     "get_eval_args",
     "get_infer_args",
+    "get_ray_args",
     "get_train_args",
+    "read_args",
 ]
View File
@@ -15,56 +15,67 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import json
 import os
 import sys
-from typing import Any, Dict, Optional, Tuple
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Tuple, Union

 import torch
 import transformers
-from transformers import HfArgumentParser, Seq2SeqTrainingArguments
+import yaml
+from transformers import HfArgumentParser
 from transformers.integrations import is_deepspeed_zero3_enabled
 from transformers.trainer_utils import get_last_checkpoint
 from transformers.training_args import ParallelMode
 from transformers.utils import is_torch_bf16_gpu_available, is_torch_npu_available
-from transformers.utils.versions import require_version

 from ..extras import logging
 from ..extras.constants import CHECKPOINT_NAMES
-from ..extras.misc import check_dependencies, get_current_device
+from ..extras.misc import check_dependencies, check_version, get_current_device
 from .data_args import DataArguments
 from .evaluation_args import EvaluationArguments
 from .finetuning_args import FinetuningArguments
 from .generating_args import GeneratingArguments
 from .model_args import ModelArguments
+from .training_args import RayArguments, TrainingArguments


 logger = logging.get_logger(__name__)

 check_dependencies()


-_TRAIN_ARGS = [ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments]
-_TRAIN_CLS = Tuple[ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments]
+_TRAIN_ARGS = [ModelArguments, DataArguments, TrainingArguments, FinetuningArguments, GeneratingArguments]
+_TRAIN_CLS = Tuple[ModelArguments, DataArguments, TrainingArguments, FinetuningArguments, GeneratingArguments]
 _INFER_ARGS = [ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
 _INFER_CLS = Tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
 _EVAL_ARGS = [ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments]
 _EVAL_CLS = Tuple[ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments]


-def _parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = None) -> Tuple[Any]:
+def read_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> Union[Dict[str, Any], List[str]]:
     if args is not None:
-        return parser.parse_dict(args)
+        return args

     if len(sys.argv) == 2 and (sys.argv[1].endswith(".yaml") or sys.argv[1].endswith(".yml")):
-        return parser.parse_yaml_file(os.path.abspath(sys.argv[1]))
-
-    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
-        return parser.parse_json_file(os.path.abspath(sys.argv[1]))
-
-    (*parsed_args, unknown_args) = parser.parse_args_into_dataclasses(return_remaining_strings=True)
+        return yaml.safe_load(Path(sys.argv[1]).absolute().read_text())
+    elif len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
+        return json.loads(Path(sys.argv[1]).absolute().read_text())
+    else:
+        return sys.argv[1:]
+
+
+def _parse_args(
+    parser: "HfArgumentParser", args: Optional[Union[Dict[str, Any], List[str]]] = None, allow_extra_keys: bool = False
+) -> Tuple[Any]:
+    args = read_args(args)
+    if isinstance(args, dict):
+        return parser.parse_dict(args, allow_extra_keys=allow_extra_keys)

-    if unknown_args:
+    (*parsed_args, unknown_args) = parser.parse_args_into_dataclasses(args=args, return_remaining_strings=True)
+    if unknown_args and not allow_extra_keys:
         print(parser.format_help())
         print(f"Got unknown args, potentially deprecated arguments: {unknown_args}")
         raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {unknown_args}")
@@ -110,58 +121,61 @@ def _verify_model_args(
 def _check_extra_dependencies(
     model_args: "ModelArguments",
     finetuning_args: "FinetuningArguments",
-    training_args: Optional["Seq2SeqTrainingArguments"] = None,
+    training_args: Optional["TrainingArguments"] = None,
 ) -> None:
-    if os.getenv("DISABLE_VERSION_CHECK", "0").lower() in ["true", "1"]:
-        logger.warning_once("Version checking has been disabled, may lead to unexpected behaviors.")
-        return
-
     if model_args.use_unsloth:
-        require_version("unsloth", "Please install unsloth: https://github.com/unslothai/unsloth")
+        check_version("unsloth", mandatory=True)

     if model_args.enable_liger_kernel:
-        require_version("liger-kernel", "To fix: pip install liger-kernel")
+        check_version("liger-kernel", mandatory=True)

     if model_args.mixture_of_depths is not None:
-        require_version("mixture-of-depth>=1.1.6", "To fix: pip install mixture-of-depth>=1.1.6")
+        check_version("mixture-of-depth>=1.1.6", mandatory=True)

     if model_args.infer_backend == "vllm":
-        require_version("vllm>=0.4.3,<0.6.7", "To fix: pip install vllm>=0.4.3,<0.6.7")
+        check_version("vllm>=0.4.3,<0.6.7")
+        check_version("vllm", mandatory=True)

     if finetuning_args.use_galore:
-        require_version("galore_torch", "To fix: pip install galore_torch")
+        check_version("galore_torch", mandatory=True)

     if finetuning_args.use_badam:
-        require_version("badam>=1.2.1", "To fix: pip install badam>=1.2.1")
+        check_version("badam>=1.2.1", mandatory=True)

     if finetuning_args.use_adam_mini:
-        require_version("adam-mini", "To fix: pip install adam-mini")
+        check_version("adam-mini", mandatory=True)

     if finetuning_args.plot_loss:
-        require_version("matplotlib", "To fix: pip install matplotlib")
+        check_version("matplotlib", mandatory=True)

     if training_args is not None and training_args.predict_with_generate:
-        require_version("jieba", "To fix: pip install jieba")
-        require_version("nltk", "To fix: pip install nltk")
-        require_version("rouge_chinese", "To fix: pip install rouge-chinese")
+        check_version("jieba", mandatory=True)
+        check_version("nltk", mandatory=True)
+        check_version("rouge_chinese", mandatory=True)


-def _parse_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
+def _parse_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _TRAIN_CLS:
     parser = HfArgumentParser(_TRAIN_ARGS)
     return _parse_args(parser, args)


-def _parse_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
+def _parse_infer_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _INFER_CLS:
     parser = HfArgumentParser(_INFER_ARGS)
     return _parse_args(parser, args)


-def _parse_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS:
+def _parse_eval_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _EVAL_CLS:
     parser = HfArgumentParser(_EVAL_ARGS)
     return _parse_args(parser, args)


+def get_ray_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> RayArguments:
+    parser = HfArgumentParser(RayArguments)
+    (ray_args,) = _parse_args(parser, args, allow_extra_keys=True)
+    return ray_args
+
+
-def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
+def get_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _TRAIN_CLS:
     model_args, data_args, training_args, finetuning_args, generating_args = _parse_train_args(args)

     # Setup logging
@@ -371,7 +385,7 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
     return model_args, data_args, training_args, finetuning_args, generating_args


-def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
+def get_infer_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _INFER_CLS:
     model_args, data_args, finetuning_args, generating_args = _parse_infer_args(args)
     _set_transformers_logging()
@@ -404,7 +418,7 @@ def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
     return model_args, data_args, finetuning_args, generating_args


-def get_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS:
+def get_eval_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _EVAL_CLS:
     model_args, data_args, eval_args, finetuning_args = _parse_eval_args(args)
     _set_transformers_logging()
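Reviewer note: `get_ray_args` re-parses the full payload with `allow_extra_keys=True`, so keys belonging to the other dataclasses are tolerated and it can run before `get_train_args` on the same arguments. A sketch:

```python
from llamafactory.hparams import get_ray_args

payload = {"ray_num_workers": 2, "model_name_or_path": "microsoft/phi-4", "stage": "sft"}
ray_args = get_ray_args(payload)  # extra training keys are ignored here
assert ray_args.ray_num_workers == 2
```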
View File
@@ -0,0 +1,48 @@
+import json
+from dataclasses import dataclass, field
+from typing import Literal, Optional, Union
+
+from transformers import Seq2SeqTrainingArguments
+from transformers.training_args import _convert_str_dict
+
+from ..extras.misc import use_ray
+
+
+@dataclass
+class RayArguments:
+    r"""
+    Arguments pertaining to the Ray training.
+    """
+
+    ray_run_name: Optional[str] = field(
+        default=None,
+        metadata={"help": "The training results will be saved at `saves/ray_run_name`."},
+    )
+    ray_num_workers: int = field(
+        default=1,
+        metadata={"help": "The number of workers for Ray training. Default is 1 worker."},
+    )
+    resources_per_worker: Union[dict, str] = field(
+        default_factory=lambda: {"GPU": 1},
+        metadata={"help": "The resources per worker for Ray training. Default is to use 1 GPU per worker."},
+    )
+    placement_strategy: Literal["SPREAD", "PACK", "STRICT_SPREAD", "STRICT_PACK"] = field(
+        default="PACK",
+        metadata={"help": "The placement strategy for Ray training. Default is PACK."},
+    )
+
+    def __post_init__(self):
+        self.use_ray = use_ray()
+        if isinstance(self.resources_per_worker, str) and self.resources_per_worker.startswith("{"):
+            self.resources_per_worker = _convert_str_dict(json.loads(self.resources_per_worker))
+
+
+@dataclass
+class TrainingArguments(RayArguments, Seq2SeqTrainingArguments):
+    r"""
+    Arguments pertaining to the trainer.
+    """
+
+    def __post_init__(self):
+        Seq2SeqTrainingArguments.__post_init__(self)
+        RayArguments.__post_init__(self)
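Reviewer note: `resources_per_worker` accepts either a dict or a JSON string (the CLI case), which `__post_init__` normalizes via `_convert_str_dict`; `use_ray` is snapshotted from the environment at construction time. A sketch:

```python
import os

from llamafactory.hparams import RayArguments

os.environ["USE_RAY"] = "1"
ray_args = RayArguments(ray_num_workers=2, resources_per_worker='{"GPU": 1}')
assert ray_args.use_ray
assert ray_args.resources_per_worker == {"GPU": 1}  # JSON string normalized to a dict
```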
View File
@@ -15,9 +15,9 @@
 from typing import TYPE_CHECKING

 from transformers.utils import is_flash_attn_2_available, is_torch_sdpa_available
-from transformers.utils.versions import require_version

 from ...extras import logging
+from ...extras.misc import check_version


 if TYPE_CHECKING:
@@ -35,8 +35,8 @@ def configure_attn_implementation(
     if getattr(config, "model_type", None) == "gemma2" and is_trainable:
         if model_args.flash_attn == "auto" or model_args.flash_attn == "fa2":
             if is_flash_attn_2_available():
-                require_version("transformers>=4.42.4", "To fix: pip install transformers>=4.42.4")
-                require_version("flash_attn>=2.6.3", "To fix: pip install flash_attn>=2.6.3")
+                check_version("transformers>=4.42.4")
+                check_version("flash_attn>=2.6.3")
                 if model_args.flash_attn != "fa2":
                     logger.warning_rank0("Gemma-2 should use flash attention 2, change `flash_attn` to fa2.")
                     model_args.flash_attn = "fa2"
View File
@@ -122,7 +122,7 @@ def _gradient_checkpointing_enable(
         if "value" in inspect.signature(self._set_gradient_checkpointing).parameters:  # old GC format
             self.apply(partial(self._set_gradient_checkpointing, value=True))
             self.enable_input_require_grads()
-            logger.warning_once("You are using the old GC format, some features (e.g. BAdam) will be invalid.")
+            logger.warning_rank0_once("You are using the old GC format, some features (e.g. BAdam) will be invalid.")
         else:  # have already enabled input require gradients
             self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=gradient_checkpointing_func)
View File
@@ -31,10 +31,10 @@ from transformers.models.llama.modeling_llama import (
     apply_rotary_pos_emb,
     repeat_kv,
 )
-from transformers.utils.versions import require_version

 from ...extras import logging
 from ...extras.constants import SUPPORTED_CLASS_FOR_S2ATTN
+from ...extras.misc import check_version
 from ...extras.packages import is_transformers_version_greater_than
@@ -353,7 +353,7 @@ def llama_sdpa_attention_forward(
 def _apply_llama_patch() -> None:
-    require_version("transformers>=4.41.2,<=4.46.1", "To fix: pip install transformers>=4.41.2,<=4.46.1")
+    check_version("transformers>=4.41.2,<=4.46.1")
     LlamaAttention.forward = llama_attention_forward
     LlamaFlashAttention2.forward = llama_flash_attention_2_forward
     LlamaSdpaAttention.forward = llama_sdpa_attention_forward
View File
@@ -16,7 +16,8 @@ from typing import TYPE_CHECKING, Sequence
 import torch
 from transformers.integrations import is_deepspeed_zero3_enabled
-from transformers.utils.versions import require_version
+
+from ...extras.misc import check_version


 if TYPE_CHECKING:
@@ -26,7 +27,7 @@ if TYPE_CHECKING:
 def _set_z3_leaf_modules(model: "PreTrainedModel", leaf_modules: Sequence["torch.nn.Module"]) -> None:
-    require_version("deepspeed>=0.13.0", "To fix: pip install deepspeed>=0.13.0")
+    check_version("deepspeed>=0.13.0")
     from deepspeed.utils import set_z3_leaf_modules  # type: ignore

     set_z3_leaf_modules(model, leaf_modules)
View File
@@ -41,9 +41,9 @@ from typing import TYPE_CHECKING, Tuple
 import torch
 import torch.nn.functional as F
-from transformers.utils.versions import require_version

 from ...extras import logging
+from ...extras.misc import check_version
 from ...extras.packages import is_transformers_version_greater_than
@@ -118,6 +118,6 @@ def configure_packing(model_args: "ModelArguments", is_trainable: bool) -> None:
     if not is_trainable or not model_args.block_diag_attn:
         return

-    require_version("transformers>=4.43.0,<=4.46.1", "To fix: pip install transformers>=4.43.0,<=4.46.1")
+    check_version("transformers>=4.43.0,<=4.46.1")
     transformers.modeling_flash_attention_utils._get_unpad_data = get_unpad_data
     logger.info_rank0("Using block diagonal attention for sequence packing without cross-attention.")
View File
@@ -26,11 +26,10 @@ from datasets import load_dataset
 from transformers import BitsAndBytesConfig, EetqConfig, GPTQConfig, HqqConfig
 from transformers.integrations import is_deepspeed_zero3_enabled
 from transformers.modeling_utils import is_fsdp_enabled
-from transformers.utils.versions import require_version

 from ...extras import logging
 from ...extras.constants import FILEEXT2TYPE
-from ...extras.misc import get_current_device
+from ...extras.misc import check_version, get_current_device


 if TYPE_CHECKING:
@@ -118,15 +117,15 @@ def configure_quantization(
         quant_method = quantization_config.get("quant_method", "")

         if quant_method == QuantizationMethod.GPTQ:
-            require_version("auto_gptq>=0.5.0", "To fix: pip install auto_gptq>=0.5.0")
+            check_version("auto_gptq>=0.5.0", mandatory=True)
             quantization_config.pop("disable_exllama", None)  # remove deprecated args
             quantization_config["use_exllama"] = False  # disable exllama

         if quant_method == QuantizationMethod.AWQ:
-            require_version("autoawq", "To fix: pip install autoawq")
+            check_version("autoawq", mandatory=True)

         if quant_method == QuantizationMethod.AQLM:
-            require_version("aqlm>=1.1.0", "To fix: pip install aqlm[gpu]>=1.1.0")
+            check_version("aqlm>=1.1.0", mandatory=True)
             quantization_config["bits"] = 2

         quant_bits = quantization_config.get("bits", "?")
@@ -136,8 +135,8 @@ def configure_quantization(
         if model_args.export_quantization_bit not in [8, 4, 3, 2]:
             raise ValueError("AutoGPTQ only accepts 2/3/4/8-bit quantization.")

-        require_version("optimum>=1.17.0", "To fix: pip install optimum>=1.17.0")
-        require_version("auto_gptq>=0.5.0", "To fix: pip install auto_gptq>=0.5.0")
+        check_version("optimum>=1.17.0", mandatory=True)
+        check_version("auto_gptq>=0.5.0", mandatory=True)
         from accelerate.utils import get_max_memory

         if getattr(config, "model_type", None) == "chatglm":
@@ -154,10 +153,10 @@ def configure_quantization(
     elif model_args.quantization_bit is not None:  # on-the-fly
         if model_args.quantization_method == QuantizationMethod.BITS_AND_BYTES.value:
             if model_args.quantization_bit == 8:
-                require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
+                check_version("bitsandbytes>=0.37.0", mandatory=True)
                 init_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
             elif model_args.quantization_bit == 4:
-                require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
+                check_version("bitsandbytes>=0.39.0", mandatory=True)
                 init_kwargs["quantization_config"] = BitsAndBytesConfig(
                     load_in_4bit=True,
                     bnb_4bit_compute_dtype=model_args.compute_dtype,
@@ -175,7 +174,7 @@ def configure_quantization(
                 if model_args.quantization_bit != 4:
                     raise ValueError("Only 4-bit quantized model can use fsdp+qlora or auto device map.")

-                require_version("bitsandbytes>=0.43.0", "To fix: pip install bitsandbytes>=0.43.0")
+                check_version("bitsandbytes>=0.43.0", mandatory=True)
             else:
                 init_kwargs["device_map"] = {"": get_current_device()}  # change auto device map for inference
@@ -187,7 +186,7 @@ def configure_quantization(
         if is_deepspeed_zero3_enabled() or is_fsdp_enabled():
             raise ValueError("HQQ quantization is incompatible with DeepSpeed ZeRO-3 or FSDP.")

-        require_version("hqq", "To fix: pip install hqq")
+        check_version("hqq", mandatory=True)
         init_kwargs["quantization_config"] = HqqConfig(
             nbits=model_args.quantization_bit, quant_zero=False, quant_scale=False, axis=0
         )  # use ATEN kernel (axis=0) for performance
@@ -199,6 +198,6 @@ def configure_quantization(
         if is_deepspeed_zero3_enabled() or is_fsdp_enabled():
             raise ValueError("EETQ quantization is incompatible with DeepSpeed ZeRO-3 or FSDP.")

-        require_version("eetq", "To fix: pip install eetq")
+        check_version("eetq", mandatory=True)
         init_kwargs["quantization_config"] = EetqConfig()
         logger.info_rank0(f"Quantizing model to {model_args.quantization_bit} bit with EETQ.")
View File
@@ -35,7 +35,7 @@ from typing_extensions import override
 from ..extras import logging
 from ..extras.constants import TRAINER_LOG, V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
-from ..extras.misc import get_peak_memory
+from ..extras.misc import get_peak_memory, use_ray


 if is_safetensors_available():
@@ -194,7 +194,7 @@ class LogCallback(TrainerCallback):
         self.do_train = False

         # Web UI
         self.webui_mode = os.environ.get("LLAMABOARD_ENABLED", "0").lower() in ["true", "1"]
-        if self.webui_mode:
+        if self.webui_mode and not use_ray():
             signal.signal(signal.SIGABRT, self._set_abort)
             self.logger_handler = logging.LoggerHandler(os.environ.get("LLAMABOARD_WORKDIR"))
             logging.add_handler(self.logger_handler)
@@ -239,7 +239,7 @@ class LogCallback(TrainerCallback):
             and os.path.exists(os.path.join(args.output_dir, TRAINER_LOG))
             and args.overwrite_output_dir
         ):
-            logger.warning_once("Previous trainer log in this folder will be deleted.")
+            logger.warning_rank0_once("Previous trainer log in this folder will be deleted.")
             os.remove(os.path.join(args.output_dir, TRAINER_LOG))

         @override
@@ -383,7 +383,7 @@ class ReporterCallback(TrainerCallback):
             )

         if self.finetuning_args.use_swanlab:
-            import swanlab
+            import swanlab  # type: ignore

             swanlab.config.update(
                 {
View File
@@ -31,7 +31,7 @@ from typing_extensions import override
 from ...extras.constants import IGNORE_INDEX
 from ...extras.packages import is_transformers_version_equal_to_4_46, is_transformers_version_greater_than
 from ..callbacks import SaveProcessorCallback
-from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps
+from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps, nested_detach


 if TYPE_CHECKING:
@@ -193,7 +193,7 @@ class CustomDPOTrainer(DPOTrainer):
         Otherwise the average log probabilities.
         """
         if self.finetuning_args.use_ref_model:
-            batch = {k: v.detach().clone() for k, v in batch.items()}  # avoid error
+            batch = nested_detach(batch, clone=True)  # avoid error

         all_logits: "torch.Tensor" = model(**batch, return_dict=True, use_cache=False).logits.to(torch.float32)
         all_logps, valid_length = get_batch_logps(logits=all_logits, labels=batch["labels"])
View File
@@ -30,7 +30,7 @@ from typing_extensions import override
 from ...extras.constants import IGNORE_INDEX
 from ...extras.packages import is_transformers_version_equal_to_4_46, is_transformers_version_greater_than
 from ..callbacks import SaveProcessorCallback
-from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps
+from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps, nested_detach


 if TYPE_CHECKING:
@@ -142,7 +142,7 @@ class CustomKTOTrainer(KTOTrainer):
         r"""
         Runs forward pass and computes the log probabilities.
         """
-        batch = {k: v.detach().clone() for k, v in batch.items()}  # avoid error
+        batch = nested_detach(batch, clone=True)  # avoid error
         model_inputs = {
             "input_ids": batch[f"{prefix}input_ids"],
             "attention_mask": batch[f"{prefix}attention_mask"],
View File
@@ -122,7 +122,7 @@ def run_sft(
     # Predict
     if training_args.do_predict:
-        logger.warning_once("Batch generation can be very slow. Consider using `scripts/vllm_infer.py` instead.")
+        logger.warning_rank0_once("Batch generation can be very slow. Consider using `scripts/vllm_infer.py` instead.")
         predict_results = trainer.predict(dataset_module["eval_dataset"], metric_key_prefix="predict", **gen_kwargs)
         trainer.log_metrics("predict", predict_results.metrics)
         trainer.save_metrics("predict", predict_results.metrics)
View File
@@ -17,7 +17,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union
+from collections.abc import Mapping
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union

 import torch
 from transformers import Trainer
@@ -30,20 +32,25 @@ from typing_extensions import override
 from ..extras import logging
 from ..extras.constants import IGNORE_INDEX
-from ..extras.packages import is_galore_available
+from ..extras.packages import is_galore_available, is_ray_available
 from ..hparams import FinetuningArguments, ModelArguments
 from ..model import find_all_linear_modules, load_model, load_tokenizer, load_valuehead_params


 if is_galore_available():
-    from galore_torch import GaLoreAdafactor, GaLoreAdamW, GaLoreAdamW8bit
+    from galore_torch import GaLoreAdafactor, GaLoreAdamW, GaLoreAdamW8bit  # type: ignore
+
+
+if is_ray_available():
+    from ray.train import RunConfig, ScalingConfig
+    from ray.train.torch import TorchTrainer


 if TYPE_CHECKING:
-    from transformers import PreTrainedModel, Seq2SeqTrainingArguments, TrainerCallback
+    from transformers import PreTrainedModel, TrainerCallback
     from trl import AutoModelForCausalLMWithValueHead

-    from ..hparams import DataArguments
+    from ..hparams import DataArguments, RayArguments, TrainingArguments


 logger = logging.get_logger(__name__)
@@ -74,7 +81,7 @@ def create_modelcard_and_push(
     trainer: "Trainer",
     model_args: "ModelArguments",
     data_args: "DataArguments",
-    training_args: "Seq2SeqTrainingArguments",
+    training_args: "TrainingArguments",
     finetuning_args: "FinetuningArguments",
 ) -> None:
     kwargs = {
@@ -187,7 +194,7 @@ def _get_decay_parameter_names(model: "PreTrainedModel") -> List[str]:
 def _create_galore_optimizer(
     model: "PreTrainedModel",
-    training_args: "Seq2SeqTrainingArguments",
+    training_args: "TrainingArguments",
     finetuning_args: "FinetuningArguments",
 ) -> "torch.optim.Optimizer":
     if len(finetuning_args.galore_target) == 1 and finetuning_args.galore_target[0] == "all":
@@ -271,7 +278,7 @@ def _create_galore_optimizer(
 def _create_loraplus_optimizer(
     model: "PreTrainedModel",
-    training_args: "Seq2SeqTrainingArguments",
+    training_args: "TrainingArguments",
     finetuning_args: "FinetuningArguments",
 ) -> "torch.optim.Optimizer":
     default_lr = training_args.learning_rate
@@ -311,7 +318,7 @@ def _create_loraplus_optimizer(
 def _create_badam_optimizer(
     model: "PreTrainedModel",
-    training_args: "Seq2SeqTrainingArguments",
+    training_args: "TrainingArguments",
     finetuning_args: "FinetuningArguments",
 ) -> "torch.optim.Optimizer":
     decay_params, nodecay_params = [], []
@@ -330,7 +337,7 @@ def _create_badam_optimizer(
     ]

     if finetuning_args.badam_mode == "layer":
-        from badam import BlockOptimizer
+        from badam import BlockOptimizer  # type: ignore

         base_optimizer = optim_class(param_groups, **optim_kwargs)
         optimizer = BlockOptimizer(
@@ -350,7 +357,7 @@ def _create_badam_optimizer(
         )

     elif finetuning_args.badam_mode == "ratio":
-        from badam import BlockOptimizerRatio
+        from badam import BlockOptimizerRatio  # type: ignore

         assert finetuning_args.badam_update_ratio > 1e-6
         optimizer = BlockOptimizerRatio(
@@ -372,9 +379,9 @@ def _create_badam_optimizer(
 def _create_adam_mini_optimizer(
     model: "PreTrainedModel",
-    training_args: "Seq2SeqTrainingArguments",
+    training_args: "TrainingArguments",
 ) -> "torch.optim.Optimizer":
-    from adam_mini import Adam_mini
+    from adam_mini import Adam_mini  # type: ignore

     hidden_size = getattr(model.config, "hidden_size", None)
     num_q_head = getattr(model.config, "num_attention_heads", None)
@@ -397,7 +404,7 @@ def _create_adam_mini_optimizer(
 def create_custom_optimizer(
     model: "PreTrainedModel",
-    training_args: "Seq2SeqTrainingArguments",
+    training_args: "TrainingArguments",
     finetuning_args: "FinetuningArguments",
 ) -> Optional["torch.optim.Optimizer"]:
     if finetuning_args.use_galore:
@@ -414,7 +421,7 @@ def create_custom_optimizer(
 def create_custom_scheduler(
-    training_args: "Seq2SeqTrainingArguments",
+    training_args: "TrainingArguments",
     num_training_steps: int,
     optimizer: Optional["torch.optim.Optimizer"] = None,
 ) -> None:
@@ -459,12 +466,33 @@ def get_batch_logps(
     return (per_token_logps * loss_mask).sum(-1), loss_mask.sum(-1)


+def nested_detach(
+    tensors: Union["torch.Tensor", List["torch.Tensor"], Tuple["torch.Tensor"], Dict[str, "torch.Tensor"]],
+    clone: bool = False,
+):
+    r"""
+    Detach `tensors` (even if it's a nested list/tuple/dict of tensors).
+    """
+    if isinstance(tensors, (list, tuple)):
+        return type(tensors)(nested_detach(t, clone=clone) for t in tensors)
+    elif isinstance(tensors, Mapping):
+        return type(tensors)({k: nested_detach(t, clone=clone) for k, t in tensors.items()})
+
+    if isinstance(tensors, torch.Tensor):
+        if clone:
+            return tensors.detach().clone()
+        else:
+            return tensors.detach()
+    else:
+        return tensors
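Reviewer note: `nested_detach` generalizes the per-dict comprehension previously duplicated in the DPO/KTO trainers, and is a no-op on non-tensor leaves. A sketch of the `clone=True` path those call sites rely on:

```python
import torch

from llamafactory.train.trainer_utils import nested_detach

batch = {"input_ids": torch.ones(2, 4), "labels": torch.zeros(2, 4)}
detached = nested_detach(batch, clone=True)
assert detached["labels"].data_ptr() != batch["labels"].data_ptr()  # cloned storage
assert nested_detach("not a tensor") == "not a tensor"              # non-tensors pass through
```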
 def get_swanlab_callback(finetuning_args: "FinetuningArguments") -> "TrainerCallback":
     r"""
     Gets the callback for logging to SwanLab.
     """
-    import swanlab
-    from swanlab.integration.transformers import SwanLabCallback
+    import swanlab  # type: ignore
+    from swanlab.integration.transformers import SwanLabCallback  # type: ignore

     if finetuning_args.swanlab_api_key is not None:
         swanlab.login(api_key=finetuning_args.swanlab_api_key)
@@ -477,3 +505,28 @@ def get_swanlab_callback(finetuning_args: "FinetuningArguments") -> "TrainerCallback":
         config={"Framework": "🦙LlamaFactory"},
     )
     return swanlab_callback
+
+
+def get_ray_trainer(
+    training_function: Callable,
+    train_loop_config: Dict[str, Any],
+    ray_args: "RayArguments",
+) -> "TorchTrainer":
+    if not ray_args.use_ray:
+        raise ValueError("Ray was not enabled. Please set `USE_RAY=1` to enable ray.")
+
+    trainer = TorchTrainer(
+        training_function,
+        train_loop_config=train_loop_config,
+        scaling_config=ScalingConfig(
+            num_workers=ray_args.ray_num_workers,
+            resources_per_worker=ray_args.resources_per_worker,
+            placement_strategy=ray_args.placement_strategy,
+            use_gpu=True,
+        ),
+        run_config=RunConfig(
+            name=ray_args.ray_run_name,
+            storage_path=Path("./saves").absolute().as_posix(),
+        ),
+    )
+    return trainer
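Reviewer note: `get_ray_trainer` is a thin wrapper over `ray.train.torch.TorchTrainer`; `use_gpu=True` is hard-coded, so every worker requests a GPU. A hedged sketch (requires `pip install "ray[train]"` and `USE_RAY=1`; the training function below is a placeholder):

```python
import os

from llamafactory.hparams import get_ray_args
from llamafactory.train.trainer_utils import get_ray_trainer

os.environ["USE_RAY"] = "1"

def training_function(config):  # placeholder loop for illustration
    print(config["args"])

ray_args = get_ray_args({"ray_num_workers": 1})
trainer = get_ray_trainer(training_function, train_loop_config={"args": {}}, ray_args=ray_args)
# trainer.fit() would launch one GPU worker and store results under ./saves
```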
View File
@ -22,7 +22,8 @@ from transformers import PreTrainedModel
from ..data import get_template_and_fix_tokenizer from ..data import get_template_and_fix_tokenizer
from ..extras import logging from ..extras import logging
from ..extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME from ..extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
from ..hparams import get_infer_args, get_train_args from ..extras.packages import is_ray_available
from ..hparams import get_infer_args, get_ray_args, get_train_args, read_args
from ..model import load_model, load_tokenizer from ..model import load_model, load_tokenizer
from .callbacks import LogCallback, PissaConvertCallback, ReporterCallback from .callbacks import LogCallback, PissaConvertCallback, ReporterCallback
from .dpo import run_dpo from .dpo import run_dpo
@ -31,7 +32,11 @@ from .ppo import run_ppo
from .pt import run_pt from .pt import run_pt
from .rm import run_rm from .rm import run_rm
from .sft import run_sft from .sft import run_sft
from .trainer_utils import get_swanlab_callback from .trainer_utils import get_ray_trainer, get_swanlab_callback
if is_ray_available():
from ray.train.huggingface.transformers import RayTrainReportCallback
if TYPE_CHECKING: if TYPE_CHECKING:
@@ -41,10 +46,12 @@ if TYPE_CHECKING:
 logger = logging.get_logger(__name__)


-def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: List["TrainerCallback"] = []) -> None:
-    callbacks.append(LogCallback())
+def _training_function(config: Dict[str, Any]) -> None:
+    args = config.get("args")
+    callbacks: List[Any] = config.get("callbacks")
     model_args, data_args, training_args, finetuning_args, generating_args = get_train_args(args)
+    callbacks.append(LogCallback())

     if finetuning_args.pissa_convert:
         callbacks.append(PissaConvertCallback())
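
`_training_function` now receives everything through a single `config` dict, which is exactly what Ray passes each worker as `train_loop_config`. The expected shape, sketched with placeholder argument values:

# the contract between run_exp and _training_function, locally and under Ray
config = {"args": {"stage": "sft", "do_train": True}, "callbacks": []}  # placeholder args
_training_function(config)
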
@@ -69,6 +76,22 @@ def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: List["TrainerCallback"] = []) -> None:
         raise ValueError(f"Unknown task: {finetuning_args.stage}.")


+def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: Optional[List["TrainerCallback"]] = None) -> None:
+    args = read_args(args)
+    ray_args = get_ray_args(args)
+    callbacks = callbacks or []
+    if ray_args.use_ray:
+        callbacks.append(RayTrainReportCallback())
+        trainer = get_ray_trainer(
+            training_function=_training_function,
+            train_loop_config={"args": args, "callbacks": callbacks},
+            ray_args=ray_args,
+        )
+        trainer.fit()
+    else:
+        _training_function(config={"args": args, "callbacks": callbacks})
+
+
 def export_model(args: Optional[Dict[str, Any]] = None) -> None:
     model_args, data_args, finetuning_args, _ = get_infer_args(args)
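
After this split, `run_exp` is only a dispatcher: it parses the Ray arguments, then either runs `_training_function` in-process or ships it to Ray workers. A hedged launch sketch (the import path and argument values are placeholders following the usual LLaMA-Factory config keys):

import os

from llamafactory.train.tuner import run_exp  # assumed import path for the module above

os.environ["USE_RAY"] = "1"  # the flag the framework checks to enable Ray; unset it to train in-process
run_exp(args={
    "stage": "sft",
    "do_train": True,
    "model_name_or_path": "meta-llama/Meta-Llama-3-8B-Instruct",
    "dataset": "identity",
    "output_dir": "saves/demo",
    # plus ray_num_workers etc. when the defaults do not fit
})
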

View File

@@ -21,7 +21,7 @@ from typing import TYPE_CHECKING, Any, Dict, Generator, Optional
 from transformers.trainer import TRAINING_ARGS_NAME

 from ..extras.constants import LLAMABOARD_CONFIG, PEFT_METHODS, TRAINING_STAGES
-from ..extras.misc import is_gpu_or_npu_available, torch_gc
+from ..extras.misc import is_gpu_or_npu_available, torch_gc, use_ray
 from ..extras.packages import is_gradio_available, is_transformers_version_equal_to_4_46
 from .common import DEFAULT_CACHE_DIR, DEFAULT_CONFIG_DIR, QUANTIZATION_BITS, get_save_dir, load_config
 from .locales import ALERTS, LOCALES
@@ -394,12 +394,12 @@ class Runner:
                 continue

         if self.do_train:
-            if os.path.exists(os.path.join(output_path, TRAINING_ARGS_NAME)):
+            if os.path.exists(os.path.join(output_path, TRAINING_ARGS_NAME)) or use_ray():
                 finish_info = ALERTS["info_finished"][lang]
             else:
                 finish_info = ALERTS["err_failed"][lang]
         else:
-            if os.path.exists(os.path.join(output_path, "all_results.json")):
+            if os.path.exists(os.path.join(output_path, "all_results.json")) or use_ray():
                 finish_info = get_eval_results(os.path.join(output_path, "all_results.json"))
             else:
                 finish_info = ALERTS["err_failed"][lang]
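
Both success checks now short-circuit on `use_ray()`, since under Ray the result files are written to the Ray storage path rather than `output_path`. The body of `use_ray` (imported from ..extras.misc) is not part of this diff; presumably it just reads the same environment flag, along the lines of:

import os

def use_ray() -> bool:  # assumed implementation, not shown in this commit
    return os.getenv("USE_RAY", "0").lower() in ("true", "yes", "1")
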

View File

@@ -13,7 +13,7 @@
 # limitations under the License.

 import os
-from typing import TYPE_CHECKING, List, Sequence
+from typing import TYPE_CHECKING, Sequence

 import pytest
 from transformers import AutoTokenizer
@@ -42,39 +42,36 @@ MESSAGES = [


 def _check_tokenization(
     tokenizer: "PreTrainedTokenizer", batch_input_ids: Sequence[Sequence[int]], batch_text: Sequence[str]
 ) -> None:
-    r"""
-    Checks token ids and texts.
-    encode(text) == token_ids
-    decode(token_ids) == text
-    """
     for input_ids, text in zip(batch_input_ids, batch_text):
-        assert input_ids == tokenizer.encode(text, add_special_tokens=False)
+        assert tokenizer.encode(text, add_special_tokens=False) == input_ids
         assert tokenizer.decode(input_ids) == text


-def _check_single_template(
-    model_id: str, template_name: str, prompt_str: str, answer_str: str, extra_str: str, use_fast: bool
-) -> List[str]:
-    tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=use_fast, token=HF_TOKEN)
-    content_str = tokenizer.apply_chat_template(MESSAGES, tokenize=False)
-    content_ids = tokenizer.apply_chat_template(MESSAGES, tokenize=True)
-    template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template=template_name))
-    prompt_ids, answer_ids = template.encode_oneturn(tokenizer, MESSAGES)
-    assert content_str == prompt_str + answer_str + extra_str
-    assert content_ids == prompt_ids + answer_ids + tokenizer.encode(extra_str, add_special_tokens=False)
-    _check_tokenization(tokenizer, (prompt_ids, answer_ids), (prompt_str, answer_str))
-    return content_ids
-
-
-def _check_template(model_id: str, template_name: str, prompt_str: str, answer_str: str, extra_str: str = "") -> None:
-    """
-    Checks template for both the slow tokenizer and the fast tokenizer.
+def _check_template(model_id: str, template_name: str, prompt_str: str, answer_str: str, use_fast: bool) -> None:
+    r"""
+    Checks template.

     Args:
         model_id: the model id on hugging face hub.
         template_name: the template name.
         prompt_str: the string corresponding to the prompt part.
         answer_str: the string corresponding to the answer part.
-        extra_str: the extra string in the jinja template of the original tokenizer.
+        use_fast: whether to use fast tokenizer.
     """
-    slow_ids = _check_single_template(model_id, template_name, prompt_str, answer_str, extra_str, use_fast=False)
-    fast_ids = _check_single_template(model_id, template_name, prompt_str, answer_str, extra_str, use_fast=True)
-    assert slow_ids == fast_ids
+    tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=use_fast, token=HF_TOKEN)
+    content_str = tokenizer.apply_chat_template(MESSAGES, tokenize=False)
+    content_ids = tokenizer.apply_chat_template(MESSAGES, tokenize=True)
+    template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template=template_name))
+    prompt_ids, answer_ids = template.encode_oneturn(tokenizer, MESSAGES)
+    assert content_str == prompt_str + answer_str
+    assert content_ids == prompt_ids + answer_ids
+    _check_tokenization(tokenizer, (prompt_ids, answer_ids), (prompt_str, answer_str))


 @pytest.mark.parametrize("use_fast", [True, False])
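
With the two helpers collapsed into one, covering an additional chat template is a single parametrized test. A hypothetical example of the pattern (model id, template name, and control tokens are all made up):

@pytest.mark.parametrize("use_fast", [True, False])
def test_example_template(use_fast: bool):  # illustrative only
    prompt_str = (
        "<|user|>How are you<|end|>"
        "<|assistant|>I am fine!<|end|>"
        "<|user|>你好<|end|>"
        "<|assistant|>"
    )
    answer_str = "很高兴认识你!<|end|>"
    _check_template("example-org/example-chat-model", "example", prompt_str, answer_str, use_fast)
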
@@ -125,19 +122,21 @@ def test_jinja_template(use_fast: bool):


 @pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
-def test_gemma_template():
+@pytest.mark.parametrize("use_fast", [True, False])
+def test_gemma_template(use_fast: bool):
     prompt_str = (
         "<bos><start_of_turn>user\nHow are you<end_of_turn>\n"
         "<start_of_turn>model\nI am fine!<end_of_turn>\n"
         "<start_of_turn>user\n你好<end_of_turn>\n"
         "<start_of_turn>model\n"
     )
-    answer_str = "很高兴认识你!"
-    _check_template("google/gemma-2-9b-it", "gemma", prompt_str, answer_str, extra_str="<end_of_turn>\n")
+    answer_str = "很高兴认识你!<end_of_turn>\n"
+    _check_template("google/gemma-2-9b-it", "gemma", prompt_str, answer_str, use_fast)


 @pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
-def test_llama3_template():
+@pytest.mark.parametrize("use_fast", [True, False])
+def test_llama3_template(use_fast: bool):
     prompt_str = (
         "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nHow are you<|eot_id|>"
         "<|start_header_id|>assistant<|end_header_id|>\n\nI am fine!<|eot_id|>"
@@ -145,10 +144,25 @@ def test_llama3_template():
         "<|start_header_id|>assistant<|end_header_id|>\n\n"
     )
     answer_str = "很高兴认识你!<|eot_id|>"
-    _check_template("meta-llama/Meta-Llama-3-8B-Instruct", "llama3", prompt_str, answer_str)
+    _check_template("meta-llama/Meta-Llama-3-8B-Instruct", "llama3", prompt_str, answer_str, use_fast)


-def test_qwen_template():
+@pytest.mark.parametrize(
+    "use_fast", [True, pytest.param(False, marks=pytest.mark.xfail(reason="Phi-4 slow tokenizer is broken."))]
+)
+def test_phi4_template(use_fast: bool):
+    prompt_str = (
+        "<|im_start|>user<|im_sep|>How are you<|im_end|>"
+        "<|im_start|>assistant<|im_sep|>I am fine!<|im_end|>"
+        "<|im_start|>user<|im_sep|>你好<|im_end|>"
+        "<|im_start|>assistant<|im_sep|>"
+    )
+    answer_str = "很高兴认识你!<|im_end|>"
+    _check_template("microsoft/phi-4", "phi4", prompt_str, answer_str, use_fast)
+
+
+@pytest.mark.parametrize("use_fast", [True, False])
+def test_qwen_template(use_fast: bool):
     prompt_str = (
         "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
         "<|im_start|>user\nHow are you<|im_end|>\n"
@@ -156,17 +170,18 @@ def test_qwen_template():
         "<|im_start|>user\n你好<|im_end|>\n"
         "<|im_start|>assistant\n"
     )
-    answer_str = "很高兴认识你!<|im_end|>"
-    _check_template("Qwen/Qwen2-7B-Instruct", "qwen", prompt_str, answer_str, extra_str="\n")
+    answer_str = "很高兴认识你!<|im_end|>\n"
+    _check_template("Qwen/Qwen2-7B-Instruct", "qwen", prompt_str, answer_str, use_fast)


-@pytest.mark.xfail(reason="The fast tokenizer of Yi model is corrupted.")
-def test_yi_template():
+@pytest.mark.parametrize("use_fast", [True, False])
+@pytest.mark.xfail(reason="Yi tokenizer is broken.")
+def test_yi_template(use_fast: bool):
     prompt_str = (
         "<|im_start|>user\nHow are you<|im_end|>\n"
         "<|im_start|>assistant\nI am fine!<|im_end|>\n"
         "<|im_start|>user\n你好<|im_end|>\n"
         "<|im_start|>assistant\n"
     )
-    answer_str = "很高兴认识你!<|im_end|>"
-    _check_template("01-ai/Yi-1.5-6B-Chat", "yi", prompt_str, answer_str)
+    answer_str = "很高兴认识你!<|im_end|>\n"
+    _check_template("01-ai/Yi-1.5-6B-Chat", "yi", prompt_str, answer_str, use_fast)
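
Dropping `extra_str` reflects the underlying behavioral change: `encode_oneturn` now keeps the turn terminator (e.g. Qwen's trailing `<|im_end|>\n`) inside the answer ids, so prompt and answer concatenate to exactly the tokenizer's own chat-template output. In sketch form, with the test module's imports and MESSAGES in scope:

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-7B-Instruct")
template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template="qwen"))
prompt_ids, answer_ids = template.encode_oneturn(tokenizer, MESSAGES)
# no extra_str needed anymore: the two halves match apply_chat_template() exactly
assert prompt_ids + answer_ids == tokenizer.apply_chat_template(MESSAGES, tokenize=True)
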