mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-23 06:12:50 +08:00
Merge branch 'main' into minicpmv
Former-commit-id: fc045d7dd871985d621430b5662cba882188a59c
This commit is contained in:
commit
f51ac40f0a
@ -12,6 +12,7 @@ FORCE_CHECK_IMPORTS=
|
||||
LLAMAFACTORY_VERBOSITY=
|
||||
USE_MODELSCOPE_HUB=
|
||||
USE_OPENMIND_HUB=
|
||||
USE_RAY=
|
||||
RECORD_VRAM=
|
||||
# torchrun
|
||||
FORCE_TORCHRUN=
|
||||
|
63
.github/ISSUE_TEMPLATE/1-bug-report.yml
vendored
Normal file
63
.github/ISSUE_TEMPLATE/1-bug-report.yml
vendored
Normal file
@ -0,0 +1,63 @@
|
||||
name: "\U0001F41B Bug / help"
|
||||
description: Create a report to help us improve the LLaMA Factory
|
||||
labels: ["bug", "pending"]
|
||||
body:
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: |
|
||||
Issues included in **[FAQs](https://github.com/hiyouga/LLaMA-Factory/issues/4614)** or those with **insufficient** information may be closed without a response.
|
||||
已经包含在 **[常见问题](https://github.com/hiyouga/LLaMA-Factory/issues/4614)** 内或提供信息**不完整**的 issues 可能不会被回复。
|
||||
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: |
|
||||
Please do not create issues that are not related to framework bugs under this category, use **[Discussions](https://github.com/hiyouga/LLaMA-Factory/discussions/categories/q-a)** instead.
|
||||
请勿在此分类下创建和框架 bug 无关的 issues,请使用 **[讨论区](https://github.com/hiyouga/LLaMA-Factory/discussions/categories/q-a)**。
|
||||
|
||||
- type: checkboxes
|
||||
id: reminder
|
||||
attributes:
|
||||
label: Reminder
|
||||
description: |
|
||||
Please ensure you have read the above rules carefully and searched the existing issues (including FAQs).
|
||||
请确保您已经认真阅读了上述规则并且搜索过现有的 issues(包括常见问题)。
|
||||
|
||||
options:
|
||||
- label: I have read the above rules and searched the existing issues.
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
id: system-info
|
||||
validations:
|
||||
required: true
|
||||
attributes:
|
||||
label: System Info
|
||||
description: |
|
||||
Please share your system info with us. You can run the command **llamafactory-cli env** and copy-paste its output below.
|
||||
请提供您的系统信息。您可以在命令行运行 **llamafactory-cli env** 并将其输出复制到该文本框中。
|
||||
|
||||
placeholder: llamafactory version, platform, python version, ...
|
||||
|
||||
- type: textarea
|
||||
id: reproduction
|
||||
validations:
|
||||
required: true
|
||||
attributes:
|
||||
label: Reproduction
|
||||
description: |
|
||||
Please provide entry arguments, error messages and stack traces that reproduces the problem.
|
||||
请提供入口参数,错误日志以及异常堆栈以便于我们复现问题。
|
||||
Remember to wrap your log messages with \`\`\`.
|
||||
请务必使用 Markdown 标签 \`\`\` 来包裹您的日志信息。
|
||||
|
||||
value: |
|
||||
```text
|
||||
Put your message here.
|
||||
```
|
||||
|
||||
- type: textarea
|
||||
id: others
|
||||
validations:
|
||||
required: false
|
||||
attributes:
|
||||
label: Others
|
41
.github/ISSUE_TEMPLATE/2-feature-request.yml
vendored
Normal file
41
.github/ISSUE_TEMPLATE/2-feature-request.yml
vendored
Normal file
@ -0,0 +1,41 @@
|
||||
name: "\U0001F680 Feature request"
|
||||
description: Submit a request for a new feature
|
||||
labels: ["enhancement", "pending"]
|
||||
body:
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: |
|
||||
Please do not create issues that are not related to new features under this category.
|
||||
请勿在此分类下创建和新特性无关的 issues。
|
||||
|
||||
- type: checkboxes
|
||||
id: reminder
|
||||
attributes:
|
||||
label: Reminder
|
||||
description: |
|
||||
Please ensure you have read the above rules carefully and searched the existing issues.
|
||||
请确保您已经认真阅读了上述规则并且搜索过现有的 issues。
|
||||
|
||||
options:
|
||||
- label: I have read the above rules and searched the existing issues.
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
id: description
|
||||
validations:
|
||||
required: true
|
||||
attributes:
|
||||
label: Description
|
||||
description: |
|
||||
A clear and concise description of the feature proposal.
|
||||
请详细描述您希望加入的新功能特性。
|
||||
|
||||
- type: textarea
|
||||
id: contribution
|
||||
validations:
|
||||
required: false
|
||||
attributes:
|
||||
label: Pull Request
|
||||
description: |
|
||||
Have you already created the relevant PR and submitted the code?
|
||||
您是否已经创建了相关 PR 并提交了代码?
|
66
.github/ISSUE_TEMPLATE/bug-report.yml
vendored
66
.github/ISSUE_TEMPLATE/bug-report.yml
vendored
@ -1,66 +0,0 @@
|
||||
name: "\U0001F41B Bug / Help"
|
||||
description: Create a report to help us improve the LLaMA Factory
|
||||
body:
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: |
|
||||
Issues included in **FAQs** or those with **insufficient** information may be closed without a response.
|
||||
包含在**常见问题**内或提供信息**不完整**的 issues 可能不会被回复。
|
||||
|
||||
- type: checkboxes
|
||||
id: reminder
|
||||
attributes:
|
||||
label: Reminder
|
||||
description: |
|
||||
Please ensure you have read the README carefully and searched the existing issues (including FAQs).
|
||||
请确保您已经认真阅读了 README 并且搜索过现有的 issues(包括常见问题)。
|
||||
|
||||
options:
|
||||
- label: I have read the README and searched the existing issues.
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
id: system-info
|
||||
validations:
|
||||
required: true
|
||||
attributes:
|
||||
label: System Info
|
||||
description: |
|
||||
Please share your system info with us. You can run the command **llamafactory-cli env** and copy-paste its output below.
|
||||
请提供您的系统信息。您可以在命令行运行 **llamafactory-cli env** 并将其输出复制到该文本框中。
|
||||
|
||||
placeholder: llamafactory version, platform, python version, ...
|
||||
|
||||
- type: textarea
|
||||
id: reproduction
|
||||
validations:
|
||||
required: true
|
||||
attributes:
|
||||
label: Reproduction
|
||||
description: |
|
||||
Please provide code snippets, error messages and stack traces that reproduces the problem.
|
||||
请提供运行参数,错误信息以及异常堆栈以便于我们复现该问题。
|
||||
Remember to use Markdown tags to correctly format your code.
|
||||
请合理使用 Markdown 标签来格式化您的文本。
|
||||
|
||||
placeholder: |
|
||||
```bash
|
||||
llamafactory-cli train ...
|
||||
```
|
||||
|
||||
- type: textarea
|
||||
id: expected-behavior
|
||||
validations:
|
||||
required: false
|
||||
attributes:
|
||||
label: Expected behavior
|
||||
description: |
|
||||
Please provide a clear and concise description of what you would expect to happen.
|
||||
请提供您原本的目的,即这段代码的期望行为。
|
||||
|
||||
- type: textarea
|
||||
id: others
|
||||
validations:
|
||||
required: false
|
||||
attributes:
|
||||
label: Others
|
1
.github/ISSUE_TEMPLATE/config.yml
vendored
Normal file
1
.github/ISSUE_TEMPLATE/config.yml
vendored
Normal file
@ -0,0 +1 @@
|
||||
blank_issues_enabled: false
|
8
.github/workflows/label_issue.yml
vendored
8
.github/workflows/label_issue.yml
vendored
@ -18,13 +18,15 @@ jobs:
|
||||
ISSUE_URL: ${{ github.event.issue.html_url }}
|
||||
ISSUE_TITLE: ${{ github.event.issue.title }}
|
||||
run: |
|
||||
LABEL=pending
|
||||
LABEL=""
|
||||
NPU_KEYWORDS=(npu huawei ascend 华为 昇腾)
|
||||
ISSUE_TITLE_LOWER=$(echo $ISSUE_TITLE | tr '[:upper:]' '[:lower:]')
|
||||
for KEYWORD in ${NPU_KEYWORDS[@]}; do
|
||||
if [[ $ISSUE_TITLE_LOWER == *$KEYWORD* ]] && [[ $ISSUE_TITLE_LOWER != *input* ]]; then
|
||||
LABEL=pending,npu
|
||||
LABEL="npu"
|
||||
break
|
||||
fi
|
||||
done
|
||||
gh issue edit $ISSUE_URL --add-label $LABEL
|
||||
if [ -n "$LABEL" ]; then
|
||||
gh issue edit $ISSUE_URL --add-label $LABEL
|
||||
fi
|
||||
|
11
README.md
11
README.md
@ -88,14 +88,16 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
|
||||
|
||||
## Changelog
|
||||
|
||||
[25/01/10] We supported fine-tuning the **[Phi-4](https://huggingface.co/microsoft/phi-4)** model.
|
||||
|
||||
[24/12/21] We supported using **[SwanLab](https://github.com/SwanHubX/SwanLab)** for experiment tracking and visualization. See [this section](#use-swanlab-logger) for details.
|
||||
|
||||
[24/11/27] We supported fine-tuning the **[Skywork-o1](https://huggingface.co/Skywork/Skywork-o1-Open-Llama-3.1-8B)** model and the **[OpenO1](https://huggingface.co/datasets/O1-OPEN/OpenO1-SFT)** dataset.
|
||||
|
||||
[24/10/09] We supported downloading pre-trained models and datasets from the **[Modelers Hub](https://modelers.cn/models)**. See [this tutorial](#download-from-modelers-hub) for usage.
|
||||
|
||||
<details><summary>Full Changelog</summary>
|
||||
|
||||
[24/10/09] We supported downloading pre-trained models and datasets from the **[Modelers Hub](https://modelers.cn/models)**. See [this tutorial](#download-from-modelers-hub) for usage.
|
||||
|
||||
[24/09/19] We supported fine-tuning the **[Qwen2.5](https://qwenlm.github.io/blog/qwen2.5/)** models.
|
||||
|
||||
[24/08/30] We supported fine-tuning the **[Qwen2-VL](https://qwenlm.github.io/blog/qwen2-vl/)** models. Thank [@simonJJJ](https://github.com/simonJJJ)'s PR.
|
||||
@ -211,8 +213,9 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
|
||||
| [OLMo](https://huggingface.co/allenai) | 1B/7B | - |
|
||||
| [PaliGemma/PaliGemma2](https://huggingface.co/google) | 3B/10B/28B | paligemma |
|
||||
| [Phi-1.5/Phi-2](https://huggingface.co/microsoft) | 1.3B/2.7B | - |
|
||||
| [Phi-3](https://huggingface.co/microsoft) | 4B/14B | phi |
|
||||
| [Phi-3/Phi-3.5](https://huggingface.co/microsoft) | 4B/14B | phi |
|
||||
| [Phi-3-small](https://huggingface.co/microsoft) | 7B | phi_small |
|
||||
| [Phi-4](https://huggingface.co/microsoft) | 14B | phi4 |
|
||||
| [Pixtral](https://huggingface.co/mistralai) | 12B | pixtral |
|
||||
| [Qwen/QwQ (1-2.5) (Code/Math/MoE)](https://huggingface.co/Qwen) | 0.5B/1.5B/3B/7B/14B/32B/72B/110B | qwen |
|
||||
| [Qwen2-VL/QVQ](https://huggingface.co/Qwen) | 2B/7B/72B | qwen2_vl |
|
||||
@ -762,7 +765,7 @@ If you have a project that should be incorporated, please contact via email or c
|
||||
|
||||
This repository is licensed under the [Apache-2.0 License](LICENSE).
|
||||
|
||||
Please follow the model licenses to use the corresponding model weights: [Baichuan 2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM-4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [GPT-2](https://github.com/openai/gpt-2/blob/master/LICENSE) / [Granite](LICENSE) / [Index](https://huggingface.co/IndexTeam/Index-1.9B/blob/main/LICENSE) / [InternLM2](https://github.com/InternLM/InternLM#license) / [Llama](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [Llama 2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [Llama 3](https://llama.meta.com/llama3/license/) / [MiniCPM](https://github.com/OpenBMB/MiniCPM/blob/main/MiniCPM%20Model%20License.md) / [Mistral/Mixtral/Pixtral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/Phi-2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [Skywork](https://huggingface.co/Skywork/Skywork-13B-base/blob/main/Skywork%20Community%20License.pdf) / [StarCoder 2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [TeleChat2](https://huggingface.co/Tele-AI/telechat-7B/blob/main/TeleChat%E6%A8%A1%E5%9E%8B%E7%A4%BE%E5%8C%BA%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.pdf) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan 2](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)
|
||||
Please follow the model licenses to use the corresponding model weights: [Baichuan 2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM-4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [GPT-2](https://github.com/openai/gpt-2/blob/master/LICENSE) / [Granite](LICENSE) / [Index](https://huggingface.co/IndexTeam/Index-1.9B/blob/main/LICENSE) / [InternLM2](https://github.com/InternLM/InternLM#license) / [Llama](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [Llama 2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [Llama 3](https://llama.meta.com/llama3/license/) / [MiniCPM](https://github.com/OpenBMB/MiniCPM/blob/main/MiniCPM%20Model%20License.md) / [Mistral/Mixtral/Pixtral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/Phi-2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3/Phi-4](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [Skywork](https://huggingface.co/Skywork/Skywork-13B-base/blob/main/Skywork%20Community%20License.pdf) / [StarCoder 2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [TeleChat2](https://huggingface.co/Tele-AI/telechat-7B/blob/main/TeleChat%E6%A8%A1%E5%9E%8B%E7%A4%BE%E5%8C%BA%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.pdf) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan 2](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)
|
||||
|
||||
## Citation
|
||||
|
||||
|
11
README_zh.md
11
README_zh.md
@ -89,14 +89,16 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272
|
||||
|
||||
## 更新日志
|
||||
|
||||
[25/01/10] 我们支持了 **[Phi-4](https://huggingface.co/microsoft/phi-4)** 模型的微调。
|
||||
|
||||
[24/12/21] 我们支持了使用 **[SwanLab](https://github.com/SwanHubX/SwanLab)** 跟踪与可视化实验。详细用法请参考 [此部分](#使用-swanlab-面板)。
|
||||
|
||||
[24/11/27] 我们支持了 **[Skywork-o1](https://huggingface.co/Skywork/Skywork-o1-Open-Llama-3.1-8B)** 模型的微调和 **[OpenO1](https://huggingface.co/datasets/O1-OPEN/OpenO1-SFT)** 数据集。
|
||||
|
||||
[24/10/09] 我们支持了从 **[魔乐社区](https://modelers.cn/models)** 下载预训练模型和数据集。详细用法请参照 [此教程](#从魔乐社区下载)。
|
||||
|
||||
<details><summary>展开日志</summary>
|
||||
|
||||
[24/10/09] 我们支持了从 **[魔乐社区](https://modelers.cn/models)** 下载预训练模型和数据集。详细用法请参照 [此教程](#从魔乐社区下载)。
|
||||
|
||||
[24/09/19] 我们支持了 **[Qwen2.5](https://qwenlm.github.io/blog/qwen2.5/)** 模型的微调。
|
||||
|
||||
[24/08/30] 我们支持了 **[Qwen2-VL](https://qwenlm.github.io/blog/qwen2-vl/)** 模型的微调。感谢 [@simonJJJ](https://github.com/simonJJJ) 的 PR。
|
||||
@ -212,8 +214,9 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272
|
||||
| [OLMo](https://huggingface.co/allenai) | 1B/7B | - |
|
||||
| [PaliGemma/PaliGemma2](https://huggingface.co/google) | 3B/10B/28B | paligemma |
|
||||
| [Phi-1.5/Phi-2](https://huggingface.co/microsoft) | 1.3B/2.7B | - |
|
||||
| [Phi-3](https://huggingface.co/microsoft) | 4B/14B | phi |
|
||||
| [Phi-3/Phi-3.5](https://huggingface.co/microsoft) | 4B/14B | phi |
|
||||
| [Phi-3-small](https://huggingface.co/microsoft) | 7B | phi_small |
|
||||
| [Phi-4](https://huggingface.co/microsoft) | 14B | phi4 |
|
||||
| [Pixtral](https://huggingface.co/mistralai) | 12B | pixtral |
|
||||
| [Qwen/QwQ (1-2.5) (Code/Math/MoE)](https://huggingface.co/Qwen) | 0.5B/1.5B/3B/7B/14B/32B/72B/110B | qwen |
|
||||
| [Qwen2-VL/QVQ](https://huggingface.co/Qwen) | 2B/7B/72B | qwen2_vl |
|
||||
@ -763,7 +766,7 @@ swanlab_run_name: test_run # 可选
|
||||
|
||||
本仓库的代码依照 [Apache-2.0](LICENSE) 协议开源。
|
||||
|
||||
使用模型权重时,请遵循对应的模型协议:[Baichuan 2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM-4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [GPT-2](https://github.com/openai/gpt-2/blob/master/LICENSE) / [Granite](LICENSE) / [Index](https://huggingface.co/IndexTeam/Index-1.9B/blob/main/LICENSE) / [InternLM2](https://github.com/InternLM/InternLM#license) / [Llama](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [Llama 2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [Llama 3](https://llama.meta.com/llama3/license/) / [MiniCPM](https://github.com/OpenBMB/MiniCPM/blob/main/MiniCPM%20Model%20License.md) / [Mistral/Mixtral/Pixtral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/Phi-2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [Skywork](https://huggingface.co/Skywork/Skywork-13B-base/blob/main/Skywork%20Community%20License.pdf) / [StarCoder 2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [TeleChat2](https://huggingface.co/Tele-AI/telechat-7B/blob/main/TeleChat%E6%A8%A1%E5%9E%8B%E7%A4%BE%E5%8C%BA%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.pdf) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan 2](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)
|
||||
使用模型权重时,请遵循对应的模型协议:[Baichuan 2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM-4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [GPT-2](https://github.com/openai/gpt-2/blob/master/LICENSE) / [Granite](LICENSE) / [Index](https://huggingface.co/IndexTeam/Index-1.9B/blob/main/LICENSE) / [InternLM2](https://github.com/InternLM/InternLM#license) / [Llama](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [Llama 2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [Llama 3](https://llama.meta.com/llama3/license/) / [MiniCPM](https://github.com/OpenBMB/MiniCPM/blob/main/MiniCPM%20Model%20License.md) / [Mistral/Mixtral/Pixtral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/Phi-2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3/Phi-4](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [Skywork](https://huggingface.co/Skywork/Skywork-13B-base/blob/main/Skywork%20Community%20License.pdf) / [StarCoder 2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [TeleChat2](https://huggingface.co/Tele-AI/telechat-7B/blob/main/TeleChat%E6%A8%A1%E5%9E%8B%E7%A4%BE%E5%8C%BA%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.pdf) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan 2](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)
|
||||
|
||||
## 引用
|
||||
|
||||
|
Binary file not shown.
Before Width: | Height: | Size: 163 KiB After Width: | Height: | Size: 164 KiB |
Binary file not shown.
Before Width: | Height: | Size: 168 KiB After Width: | Height: | Size: 167 KiB |
@ -23,10 +23,10 @@ ARG HTTP_PROXY=
|
||||
WORKDIR /app
|
||||
|
||||
# Set http proxy
|
||||
RUN if [ -n "$HTTP_PROXY" ]; then \
|
||||
echo "Configuring proxy..."; \
|
||||
export http_proxy=$HTTP_PROXY; \
|
||||
export https_proxy=$HTTP_PROXY; \
|
||||
RUN if [ -n "$HTTP_PROXY" ]; then \
|
||||
echo "Configuring proxy..."; \
|
||||
export http_proxy=$HTTP_PROXY; \
|
||||
export https_proxy=$HTTP_PROXY; \
|
||||
fi
|
||||
|
||||
# Install the requirements
|
||||
@ -34,10 +34,10 @@ COPY requirements.txt /app
|
||||
RUN pip config set global.index-url "$PIP_INDEX" && \
|
||||
pip config set global.extra-index-url "$PIP_INDEX" && \
|
||||
python -m pip install --upgrade pip && \
|
||||
if [ -n "$HTTP_PROXY" ]; then \
|
||||
python -m pip install --proxy=$HTTP_PROXY -r requirements.txt; \
|
||||
else \
|
||||
python -m pip install -r requirements.txt; \
|
||||
if [ -n "$HTTP_PROXY" ]; then \
|
||||
python -m pip install --proxy=$HTTP_PROXY -r requirements.txt; \
|
||||
else \
|
||||
python -m pip install -r requirements.txt; \
|
||||
fi
|
||||
|
||||
# Copy the rest of the application into the image
|
||||
@ -63,10 +63,10 @@ RUN EXTRA_PACKAGES="metrics"; \
|
||||
if [ "$INSTALL_EETQ" == "true" ]; then \
|
||||
EXTRA_PACKAGES="${EXTRA_PACKAGES},eetq"; \
|
||||
fi; \
|
||||
if [ -n "$HTTP_PROXY" ]; then \
|
||||
pip install --proxy=$HTTP_PROXY -e ".[$EXTRA_PACKAGES]"; \
|
||||
else \
|
||||
pip install -e ".[$EXTRA_PACKAGES]"; \
|
||||
if [ -n "$HTTP_PROXY" ]; then \
|
||||
pip install --proxy=$HTTP_PROXY -e ".[$EXTRA_PACKAGES]"; \
|
||||
else \
|
||||
pip install -e ".[$EXTRA_PACKAGES]"; \
|
||||
fi
|
||||
|
||||
# Rebuild flash attention
|
||||
@ -76,8 +76,8 @@ RUN pip uninstall -y transformer-engine flash-attn && \
|
||||
if [ -n "$HTTP_PROXY" ]; then \
|
||||
pip install --proxy=$HTTP_PROXY ninja && \
|
||||
pip install --proxy=$HTTP_PROXY --no-cache-dir flash-attn --no-build-isolation; \
|
||||
else \
|
||||
pip install ninja && \
|
||||
else \
|
||||
pip install ninja && \
|
||||
pip install --no-cache-dir flash-attn --no-build-isolation; \
|
||||
fi; \
|
||||
fi
|
||||
|
@ -18,10 +18,10 @@ ARG HTTP_PROXY=
|
||||
WORKDIR /app
|
||||
|
||||
# Set http proxy
|
||||
RUN if [ -n "$HTTP_PROXY" ]; then \
|
||||
echo "Configuring proxy..."; \
|
||||
export http_proxy=$HTTP_PROXY; \
|
||||
export https_proxy=$HTTP_PROXY; \
|
||||
RUN if [ -n "$HTTP_PROXY" ]; then \
|
||||
echo "Configuring proxy..."; \
|
||||
export http_proxy=$HTTP_PROXY; \
|
||||
export https_proxy=$HTTP_PROXY; \
|
||||
fi
|
||||
|
||||
# Install the requirements
|
||||
@ -29,10 +29,10 @@ COPY requirements.txt /app
|
||||
RUN pip config set global.index-url "$PIP_INDEX" && \
|
||||
pip config set global.extra-index-url "$TORCH_INDEX" && \
|
||||
python -m pip install --upgrade pip && \
|
||||
if [ -n "$HTTP_PROXY" ]; then \
|
||||
python -m pip install --proxy=$HTTP_PROXY -r requirements.txt; \
|
||||
else \
|
||||
python -m pip install -r requirements.txt; \
|
||||
if [ -n "$HTTP_PROXY" ]; then \
|
||||
python -m pip install --proxy=$HTTP_PROXY -r requirements.txt; \
|
||||
else \
|
||||
python -m pip install -r requirements.txt; \
|
||||
fi
|
||||
|
||||
# Copy the rest of the application into the image
|
||||
@ -43,10 +43,10 @@ RUN EXTRA_PACKAGES="torch-npu,metrics"; \
|
||||
if [ "$INSTALL_DEEPSPEED" == "true" ]; then \
|
||||
EXTRA_PACKAGES="${EXTRA_PACKAGES},deepspeed"; \
|
||||
fi; \
|
||||
if [ -n "$HTTP_PROXY" ]; then \
|
||||
pip install --proxy=$HTTP_PROXY -e ".[$EXTRA_PACKAGES]"; \
|
||||
else \
|
||||
pip install -e ".[$EXTRA_PACKAGES]"; \
|
||||
if [ -n "$HTTP_PROXY" ]; then \
|
||||
pip install --proxy=$HTTP_PROXY -e ".[$EXTRA_PACKAGES]"; \
|
||||
else \
|
||||
pip install -e ".[$EXTRA_PACKAGES]"; \
|
||||
fi
|
||||
|
||||
# Unset http proxy
|
||||
|
@ -19,10 +19,10 @@ ARG HTTP_PROXY=
|
||||
WORKDIR /app
|
||||
|
||||
# Set http proxy
|
||||
RUN if [ -n "$HTTP_PROXY" ]; then \
|
||||
echo "Configuring proxy..."; \
|
||||
export http_proxy=$HTTP_PROXY; \
|
||||
export https_proxy=$HTTP_PROXY; \
|
||||
RUN if [ -n "$HTTP_PROXY" ]; then \
|
||||
echo "Configuring proxy..."; \
|
||||
export http_proxy=$HTTP_PROXY; \
|
||||
export https_proxy=$HTTP_PROXY; \
|
||||
fi
|
||||
|
||||
# Install the requirements
|
||||
@ -30,10 +30,10 @@ COPY requirements.txt /app
|
||||
RUN pip config set global.index-url "$PIP_INDEX" && \
|
||||
pip config set global.extra-index-url "$PIP_INDEX" && \
|
||||
python -m pip install --upgrade pip && \
|
||||
if [ -n "$HTTP_PROXY" ]; then \
|
||||
python -m pip install --proxy=$HTTP_PROXY -r requirements.txt; \
|
||||
else \
|
||||
python -m pip install -r requirements.txt; \
|
||||
if [ -n "$HTTP_PROXY" ]; then \
|
||||
python -m pip install --proxy=$HTTP_PROXY -r requirements.txt; \
|
||||
else \
|
||||
python -m pip install -r requirements.txt; \
|
||||
fi
|
||||
|
||||
# Copy the rest of the application into the image
|
||||
@ -56,10 +56,10 @@ RUN EXTRA_PACKAGES="metrics"; \
|
||||
if [ "$INSTALL_HQQ" == "true" ]; then \
|
||||
EXTRA_PACKAGES="${EXTRA_PACKAGES},hqq"; \
|
||||
fi; \
|
||||
if [ -n "$HTTP_PROXY" ]; then \
|
||||
pip install --proxy=$HTTP_PROXY -e ".[$EXTRA_PACKAGES]"; \
|
||||
else \
|
||||
pip install -e ".[$EXTRA_PACKAGES]"; \
|
||||
if [ -n "$HTTP_PROXY" ]; then \
|
||||
pip install --proxy=$HTTP_PROXY -e ".[$EXTRA_PACKAGES]"; \
|
||||
else \
|
||||
pip install -e ".[$EXTRA_PACKAGES]"; \
|
||||
fi
|
||||
|
||||
# Rebuild flash attention
|
||||
@ -69,8 +69,8 @@ RUN pip uninstall -y transformer-engine flash-attn && \
|
||||
if [ -n "$HTTP_PROXY" ]; then \
|
||||
pip install --proxy=$HTTP_PROXY ninja && \
|
||||
pip install --proxy=$HTTP_PROXY --no-cache-dir flash-attn --no-build-isolation; \
|
||||
else \
|
||||
pip install ninja && \
|
||||
else \
|
||||
pip install ninja && \
|
||||
pip install --no-cache-dir flash-attn --no-build-isolation; \
|
||||
fi; \
|
||||
fi
|
||||
|
@ -95,6 +95,12 @@ FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500
|
||||
FORCE_TORCHRUN=1 llamafactory-cli train examples/train_lora/llama3_lora_sft_ds3.yaml
|
||||
```
|
||||
|
||||
#### Supervised Fine-Tuning with Ray on 4 GPUs
|
||||
|
||||
```bash
|
||||
USE_RAY=1 llamafactory-cli train examples/train_full/llama3_lora_sft_ray.yaml
|
||||
```
|
||||
|
||||
### QLoRA Fine-Tuning
|
||||
|
||||
#### Supervised Fine-Tuning with 4/8-bit Bitsandbytes/HQQ/EETQ Quantization (Recommended)
|
||||
|
@ -95,6 +95,12 @@ FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500
|
||||
FORCE_TORCHRUN=1 llamafactory-cli train examples/train_lora/llama3_lora_sft_ds3.yaml
|
||||
```
|
||||
|
||||
#### 使用 Ray 在 4 张 GPU 上微调
|
||||
|
||||
```bash
|
||||
USE_RAY=1 llamafactory-cli train examples/train_full/llama3_lora_sft_ray.yaml
|
||||
```
|
||||
|
||||
### QLoRA 微调
|
||||
|
||||
#### 基于 4/8 比特 Bitsandbytes/HQQ/EETQ 量化进行指令监督微调(推荐)
|
||||
|
48
examples/train_lora/llama3_lora_sft_ray.yaml
Normal file
48
examples/train_lora/llama3_lora_sft_ray.yaml
Normal file
@ -0,0 +1,48 @@
|
||||
### model
|
||||
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct # or use local absolute path
|
||||
trust_remote_code: true
|
||||
|
||||
### method
|
||||
stage: sft
|
||||
do_train: true
|
||||
finetuning_type: lora
|
||||
lora_target: all
|
||||
|
||||
### dataset
|
||||
dataset: identity,alpaca_en_demo
|
||||
dataset_dir: REMOTE:llamafactory/demo_data # or use local absolute path
|
||||
template: llama3
|
||||
cutoff_len: 2048
|
||||
max_samples: 1000
|
||||
overwrite_cache: true
|
||||
preprocessing_num_workers: 16
|
||||
|
||||
### output
|
||||
output_dir: tmp_dir
|
||||
logging_steps: 10
|
||||
save_steps: 500
|
||||
plot_loss: true
|
||||
overwrite_output_dir: true
|
||||
|
||||
### train
|
||||
per_device_train_batch_size: 1
|
||||
gradient_accumulation_steps: 8
|
||||
learning_rate: 1.0e-4
|
||||
num_train_epochs: 3.0
|
||||
lr_scheduler_type: cosine
|
||||
warmup_ratio: 0.1
|
||||
bf16: true
|
||||
ddp_timeout: 180000000
|
||||
|
||||
### eval
|
||||
val_size: 0.1
|
||||
per_device_eval_batch_size: 1
|
||||
eval_strategy: steps
|
||||
eval_steps: 500
|
||||
|
||||
### ray
|
||||
ray_run_name: llama3_8b_sft_lora
|
||||
ray_num_workers: 4 # number of GPUs to use
|
||||
resources_per_worker:
|
||||
GPU: 1
|
||||
placement_strategy: PACK
|
@ -63,7 +63,7 @@ class HuggingfaceEngine(BaseEngine):
|
||||
try:
|
||||
asyncio.get_event_loop()
|
||||
except RuntimeError:
|
||||
logger.warning_once("There is no current event loop, creating a new one.")
|
||||
logger.warning_rank0_once("There is no current event loop, creating a new one.")
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
|
@ -24,7 +24,7 @@ from .chat.chat_model import run_chat
|
||||
from .eval.evaluator import run_eval
|
||||
from .extras import logging
|
||||
from .extras.env import VERSION, print_env
|
||||
from .extras.misc import get_device_count
|
||||
from .extras.misc import get_device_count, use_ray
|
||||
from .train.tuner import export_model, run_exp
|
||||
from .webui.interface import run_web_demo, run_web_ui
|
||||
|
||||
@ -87,7 +87,7 @@ def main():
|
||||
export_model()
|
||||
elif command == Command.TRAIN:
|
||||
force_torchrun = os.getenv("FORCE_TORCHRUN", "0").lower() in ["true", "1"]
|
||||
if force_torchrun or get_device_count() > 1:
|
||||
if force_torchrun or (get_device_count() > 1 and not use_ray()):
|
||||
master_addr = os.getenv("MASTER_ADDR", "127.0.0.1")
|
||||
master_port = os.getenv("MASTER_PORT", str(random.randint(20001, 29999)))
|
||||
logger.info_rank0(f"Initializing distributed tasks at: {master_addr}:{master_port}")
|
||||
|
@ -56,12 +56,12 @@ def merge_dataset(
|
||||
return all_datasets[0]
|
||||
elif data_args.mix_strategy == "concat":
|
||||
if data_args.streaming:
|
||||
logger.warning_once("The samples between different datasets will not be mixed in streaming mode.")
|
||||
logger.warning_rank0_once("The samples between different datasets will not be mixed in streaming mode.")
|
||||
|
||||
return concatenate_datasets(all_datasets)
|
||||
elif data_args.mix_strategy.startswith("interleave"):
|
||||
if not data_args.streaming:
|
||||
logger.warning_once("We recommend using `mix_strategy=concat` in non-streaming mode.")
|
||||
logger.warning_rank0_once("We recommend using `mix_strategy=concat` in non-streaming mode.")
|
||||
|
||||
return interleave_datasets(
|
||||
datasets=all_datasets,
|
||||
|
@ -18,11 +18,10 @@ from typing import TYPE_CHECKING, Dict, Literal, Optional, Sequence, Union
|
||||
|
||||
import numpy as np
|
||||
from datasets import DatasetDict, load_dataset, load_from_disk
|
||||
from transformers.utils.versions import require_version
|
||||
|
||||
from ..extras import logging
|
||||
from ..extras.constants import FILEEXT2TYPE
|
||||
from ..extras.misc import has_tokenized_data
|
||||
from ..extras.misc import check_version, has_tokenized_data
|
||||
from .aligner import align_dataset
|
||||
from .data_utils import merge_dataset, split_dataset
|
||||
from .parser import get_dataset_list
|
||||
@ -84,7 +83,7 @@ def _load_single_dataset(
|
||||
raise NotImplementedError(f"Unknown load type: {dataset_attr.load_from}.")
|
||||
|
||||
if dataset_attr.load_from == "ms_hub":
|
||||
require_version("modelscope>=1.11.0", "To fix: pip install modelscope>=1.11.0")
|
||||
check_version("modelscope>=1.11.0", mandatory=True)
|
||||
from modelscope import MsDataset # type: ignore
|
||||
from modelscope.utils.config_ds import MS_DATASETS_CACHE # type: ignore
|
||||
|
||||
@ -103,7 +102,7 @@ def _load_single_dataset(
|
||||
dataset = dataset.to_hf_dataset()
|
||||
|
||||
elif dataset_attr.load_from == "om_hub":
|
||||
require_version("openmind>=0.8.0", "To fix: pip install openmind>=0.8.0")
|
||||
check_version("openmind>=0.8.0", mandatory=True)
|
||||
from openmind import OmDataset # type: ignore
|
||||
from openmind.utils.hub import OM_DATASETS_CACHE # type: ignore
|
||||
|
||||
|
@ -75,10 +75,14 @@ class BasePlugin:
|
||||
Validates if this model accepts the input modalities.
|
||||
"""
|
||||
if len(images) != 0 and self.image_token is None:
|
||||
raise ValueError("This model does not support image input.")
|
||||
raise ValueError(
|
||||
"This model does not support image input. Please check whether the correct `template` is used."
|
||||
)
|
||||
|
||||
if len(videos) != 0 and self.video_token is None:
|
||||
raise ValueError("This model does not support video input.")
|
||||
raise ValueError(
|
||||
"This model does not support video input. Please check whether the correct `template` is used."
|
||||
)
|
||||
|
||||
def _preprocess_image(self, image: "ImageObject", **kwargs) -> "ImageObject":
|
||||
r"""
|
||||
|
@ -15,10 +15,10 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union
|
||||
|
||||
from transformers.utils.versions import require_version
|
||||
from typing_extensions import override
|
||||
|
||||
from ..extras import logging
|
||||
from ..extras.misc import check_version
|
||||
from .data_utils import Role
|
||||
from .formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter
|
||||
from .mm_plugin import get_mm_plugin
|
||||
@ -44,7 +44,6 @@ class Template:
|
||||
format_function: "Formatter"
|
||||
format_observation: "Formatter"
|
||||
format_tools: "Formatter"
|
||||
format_separator: "Formatter"
|
||||
format_prefix: "Formatter"
|
||||
default_system: str
|
||||
stop_words: List[str]
|
||||
@ -113,9 +112,6 @@ class Template:
|
||||
tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
|
||||
elements += self.format_system.apply(content=(system + tool_text))
|
||||
|
||||
if i > 0 and i % 2 == 0:
|
||||
elements += self.format_separator.apply()
|
||||
|
||||
if message["role"] == Role.USER.value:
|
||||
elements += self.format_user.apply(content=message["content"], idx=str(i // 2))
|
||||
elif message["role"] == Role.ASSISTANT.value:
|
||||
@ -180,9 +176,6 @@ class Llama2Template(Template):
|
||||
tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
|
||||
system_text = self.format_system.apply(content=(system + tool_text))[0]
|
||||
|
||||
if i > 0 and i % 2 == 0:
|
||||
elements += self.format_separator.apply()
|
||||
|
||||
if message["role"] == Role.USER.value:
|
||||
elements += self.format_user.apply(content=system_text + message["content"])
|
||||
elif message["role"] == Role.ASSISTANT.value:
|
||||
@ -210,7 +203,6 @@ def _register_template(
|
||||
format_function: Optional["Formatter"] = None,
|
||||
format_observation: Optional["Formatter"] = None,
|
||||
format_tools: Optional["Formatter"] = None,
|
||||
format_separator: Optional["Formatter"] = None,
|
||||
format_prefix: Optional["Formatter"] = None,
|
||||
default_system: str = "",
|
||||
stop_words: Sequence[str] = [],
|
||||
@ -224,34 +216,28 @@ def _register_template(
|
||||
|
||||
To add the following chat template:
|
||||
```
|
||||
[HUMAN]:
|
||||
user prompt here
|
||||
[AI]:
|
||||
model response here
|
||||
|
||||
[HUMAN]:
|
||||
user prompt here
|
||||
[AI]:
|
||||
model response here
|
||||
<s><user>user prompt here
|
||||
<model>model response here</s>
|
||||
<user>user prompt here
|
||||
<model>model response here</s>
|
||||
```
|
||||
|
||||
The corresponding code should be:
|
||||
```
|
||||
_register_template(
|
||||
name="custom",
|
||||
format_user=StringFormatter(slots=["[HUMAN]:\n{{content}}\n[AI]:\n"]),
|
||||
format_separator=EmptyFormatter(slots=["\n\n"]),
|
||||
efficient_eos=True,
|
||||
format_user=StringFormatter(slots=["<user>{{content}}\n<model>"]),
|
||||
format_assistant=StringFormatter(slots=["{{content}}</s>\n"]),
|
||||
format_prefix=EmptyFormatter("<s>"),
|
||||
)
|
||||
```
|
||||
"""
|
||||
template_class = Llama2Template if any(k in name for k in ("llama2", "mistral")) else Template
|
||||
template_class = Llama2Template if any(k in name for k in ("llama2", "mistral", "pixtral")) else Template
|
||||
default_slots = ["{{content}}"] if efficient_eos else ["{{content}}", {"eos_token"}]
|
||||
default_user_formatter = StringFormatter(slots=["{{content}}"])
|
||||
default_assistant_formatter = StringFormatter(slots=default_slots)
|
||||
default_function_formatter = FunctionFormatter(slots=default_slots, tool_format="default")
|
||||
default_tool_formatter = ToolFormatter(tool_format="default")
|
||||
default_separator_formatter = EmptyFormatter()
|
||||
default_prefix_formatter = EmptyFormatter()
|
||||
TEMPLATES[name] = template_class(
|
||||
format_user=format_user or default_user_formatter,
|
||||
@ -260,7 +246,6 @@ def _register_template(
|
||||
format_function=format_function or default_function_formatter,
|
||||
format_observation=format_observation or format_user or default_user_formatter,
|
||||
format_tools=format_tools or default_tool_formatter,
|
||||
format_separator=format_separator or default_separator_formatter,
|
||||
format_prefix=format_prefix or default_prefix_formatter,
|
||||
default_system=default_system,
|
||||
stop_words=stop_words,
|
||||
@ -344,9 +329,7 @@ def _get_jinja_template(template: "Template", tokenizer: "PreTrainedTokenizer")
|
||||
jinja_template += "{{ " + user_message + " }}"
|
||||
|
||||
jinja_template += "{% elif message['role'] == 'assistant' %}"
|
||||
assistant_message = _convert_slots_to_jinja(
|
||||
template.format_assistant.apply() + template.format_separator.apply(), tokenizer
|
||||
)
|
||||
assistant_message = _convert_slots_to_jinja(template.format_assistant.apply(), tokenizer)
|
||||
jinja_template += "{{ " + assistant_message + " }}"
|
||||
jinja_template += "{% endif %}"
|
||||
jinja_template += "{% endfor %}"
|
||||
@ -365,7 +348,7 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args:
|
||||
raise ValueError(f"Template {data_args.template} does not exist.")
|
||||
|
||||
if template.mm_plugin.__class__.__name__ != "BasePlugin":
|
||||
require_version("transformers>=4.45.0", "To fix: pip install transformers>=4.45.0")
|
||||
check_version("transformers>=4.45.0")
|
||||
|
||||
if data_args.train_on_prompt and template.efficient_eos:
|
||||
raise ValueError("Current template does not support `train_on_prompt`.")
|
||||
@ -411,7 +394,7 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args:
|
||||
_register_template(
|
||||
name="alpaca",
|
||||
format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n\n### Response:\n"]),
|
||||
format_separator=EmptyFormatter(slots=["\n\n"]),
|
||||
format_assistant=StringFormatter(slots=["{{content}}", {"eos_token"}, "\n\n"]),
|
||||
default_system=(
|
||||
"Below is an instruction that describes a task. "
|
||||
"Write a response that appropriately completes the request.\n\n"
|
||||
@ -423,13 +406,13 @@ _register_template(
|
||||
_register_template(
|
||||
name="aquila",
|
||||
format_user=StringFormatter(slots=["Human: {{content}}###Assistant:"]),
|
||||
format_separator=EmptyFormatter(slots=["###"]),
|
||||
format_assistant=StringFormatter(slots=["{{content}}###"]),
|
||||
format_system=StringFormatter(slots=["System: {{content}}###"]),
|
||||
default_system=(
|
||||
"A chat between a curious human and an artificial intelligence assistant. "
|
||||
"The assistant gives helpful, detailed, and polite answers to the human's questions."
|
||||
),
|
||||
stop_words=["</s>"],
|
||||
efficient_eos=True,
|
||||
)
|
||||
|
||||
|
||||
@ -459,7 +442,7 @@ _register_template(
|
||||
_register_template(
|
||||
name="belle",
|
||||
format_user=StringFormatter(slots=["Human: {{content}}\n\nBelle: "]),
|
||||
format_separator=EmptyFormatter(slots=["\n\n"]),
|
||||
format_assistant=StringFormatter(slots=["{{content}}", {"eos_token"}, "\n\n"]),
|
||||
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
||||
)
|
||||
|
||||
@ -481,7 +464,6 @@ _register_template(
|
||||
_register_template(
|
||||
name="chatglm2",
|
||||
format_user=StringFormatter(slots=["[Round {{idx}}]\n\n问:{{content}}\n\n答:"]),
|
||||
format_separator=EmptyFormatter(slots=["\n\n"]),
|
||||
format_prefix=EmptyFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}]),
|
||||
efficient_eos=True,
|
||||
)
|
||||
@ -506,9 +488,9 @@ _register_template(
|
||||
_register_template(
|
||||
name="chatml",
|
||||
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
|
||||
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
|
||||
format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||
format_separator=EmptyFormatter(slots=["\n"]),
|
||||
stop_words=["<|im_end|>", "<|im_start|>"],
|
||||
replace_eos=True,
|
||||
replace_jinja_template=True,
|
||||
@ -519,9 +501,9 @@ _register_template(
|
||||
_register_template(
|
||||
name="chatml_de",
|
||||
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
|
||||
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
|
||||
format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||
format_separator=EmptyFormatter(slots=["\n"]),
|
||||
default_system="Du bist ein freundlicher und hilfsbereiter KI-Assistent.",
|
||||
stop_words=["<|im_end|>", "<|im_start|>"],
|
||||
replace_eos=True,
|
||||
@ -574,9 +556,11 @@ _register_template(
|
||||
)
|
||||
|
||||
|
||||
# copied from chatml template
|
||||
_register_template(
|
||||
name="cpm3",
|
||||
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
|
||||
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
|
||||
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
||||
stop_words=["<|im_end|>"],
|
||||
@ -603,9 +587,9 @@ _register_template(
|
||||
_register_template(
|
||||
name="dbrx",
|
||||
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
|
||||
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
|
||||
format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||
format_separator=EmptyFormatter(slots=["\n"]),
|
||||
default_system=(
|
||||
"You are DBRX, created by Databricks. You were last updated in December 2023. "
|
||||
"You answer questions based on information available up to that point.\n"
|
||||
@ -622,7 +606,6 @@ _register_template(
|
||||
"ABOUT YOURSELF UNLESS THE INFORMATION IS DIRECTLY PERTINENT TO THE USER'S QUERY."
|
||||
),
|
||||
stop_words=["<|im_end|>"],
|
||||
replace_eos=True,
|
||||
)
|
||||
|
||||
|
||||
@ -644,8 +627,7 @@ _register_template(
|
||||
_register_template(
|
||||
name="deepseekcoder",
|
||||
format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n### Response:"]),
|
||||
format_assistant=StringFormatter(slots=["\n{{content}}\n<|EOT|>"]),
|
||||
format_separator=EmptyFormatter(slots=["\n"]),
|
||||
format_assistant=StringFormatter(slots=["\n{{content}}\n<|EOT|>\n"]),
|
||||
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
||||
default_system=(
|
||||
"You are an AI programming assistant, utilizing the DeepSeek Coder model, "
|
||||
@ -659,8 +641,8 @@ _register_template(
|
||||
_register_template(
|
||||
name="default",
|
||||
format_user=StringFormatter(slots=["Human: {{content}}\nAssistant:"]),
|
||||
format_system=StringFormatter(slots=["{{content}}\n"]),
|
||||
format_separator=EmptyFormatter(slots=["\n"]),
|
||||
format_assistant=StringFormatter(slots=["{{content}}", {"eos_token"}, "\n"]),
|
||||
format_system=StringFormatter(slots=["System: {{content}}\n"]),
|
||||
)
|
||||
|
||||
|
||||
@ -673,22 +655,22 @@ _register_template(
|
||||
_register_template(
|
||||
name="exaone",
|
||||
format_user=StringFormatter(slots=["[|user|]{{content}}\n[|assistant|]"]),
|
||||
format_assistant=StringFormatter(slots=["{{content}}", {"eos_token"}, "\n"]),
|
||||
format_system=StringFormatter(slots=["[|system|]{{content}}[|endofturn|]\n"]),
|
||||
format_separator=EmptyFormatter(slots=["\n"]),
|
||||
)
|
||||
|
||||
|
||||
_register_template(
|
||||
name="falcon",
|
||||
format_user=StringFormatter(slots=["User: {{content}}\nFalcon:"]),
|
||||
format_separator=EmptyFormatter(slots=["\n"]),
|
||||
format_assistant=StringFormatter(slots=["{{content}}\n"]),
|
||||
efficient_eos=True,
|
||||
)
|
||||
|
||||
|
||||
_register_template(
|
||||
name="fewshot",
|
||||
format_separator=EmptyFormatter(slots=["\n\n"]),
|
||||
format_assistant=StringFormatter(slots=["{{content}}\n\n"]),
|
||||
efficient_eos=True,
|
||||
)
|
||||
|
||||
@ -696,12 +678,11 @@ _register_template(
|
||||
_register_template(
|
||||
name="gemma",
|
||||
format_user=StringFormatter(slots=["<start_of_turn>user\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]),
|
||||
format_assistant=StringFormatter(slots=["{{content}}<end_of_turn>\n"]),
|
||||
format_observation=StringFormatter(
|
||||
slots=["<start_of_turn>tool\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]
|
||||
),
|
||||
format_separator=EmptyFormatter(slots=["<end_of_turn>\n"]),
|
||||
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
||||
efficient_eos=True,
|
||||
)
|
||||
|
||||
|
||||
@ -726,8 +707,8 @@ _register_template(
|
||||
"<|start_of_role|>user<|end_of_role|>{{content}}<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|>"
|
||||
]
|
||||
),
|
||||
format_assistant=StringFormatter(slots=["{{content}}<|end_of_text|>\n"]),
|
||||
format_system=StringFormatter(slots=["<|start_of_role|>system<|end_of_role|>{{content}}<|end_of_text|>\n"]),
|
||||
format_separator=EmptyFormatter(slots=["\n"]),
|
||||
)
|
||||
|
||||
|
||||
@ -742,22 +723,20 @@ _register_template(
|
||||
_register_template(
|
||||
name="intern",
|
||||
format_user=StringFormatter(slots=["<|User|>:{{content}}\n<|Bot|>:"]),
|
||||
format_assistant=StringFormatter(slots=["{{content}}<eoa>\n"]),
|
||||
format_system=StringFormatter(slots=["<|System|>:{{content}}\n"]),
|
||||
format_separator=EmptyFormatter(slots=["<eoa>\n"]),
|
||||
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
||||
stop_words=["<eoa>"],
|
||||
efficient_eos=True, # internlm tokenizer cannot set eos_token_id
|
||||
)
|
||||
|
||||
|
||||
_register_template(
|
||||
name="intern2",
|
||||
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
|
||||
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
|
||||
format_separator=EmptyFormatter(slots=["<|im_end|>\n"]),
|
||||
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
||||
stop_words=["<|im_end|>"],
|
||||
efficient_eos=True, # internlm2 tokenizer cannot set eos_token_id
|
||||
)
|
||||
|
||||
|
||||
@ -888,6 +867,7 @@ _register_template(
|
||||
name="llava_next_mistral",
|
||||
format_user=StringFormatter(slots=["[INST] {{content}}[/INST]"]),
|
||||
format_assistant=StringFormatter(slots=[" {{content}}", {"eos_token"}]),
|
||||
format_system=StringFormatter(slots=["{{content}}\n\n"]),
|
||||
format_function=FunctionFormatter(slots=["[TOOL_CALLS] ", "{{content}}", {"eos_token"}], tool_format="mistral"),
|
||||
format_observation=StringFormatter(slots=["""[TOOL_RESULTS] {"content": {{content}}}[/TOOL_RESULTS]"""]),
|
||||
format_tools=ToolFormatter(tool_format="mistral"),
|
||||
@ -900,16 +880,15 @@ _register_template(
|
||||
_register_template(
|
||||
name="llava_next_qwen",
|
||||
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
|
||||
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
|
||||
format_function=FunctionFormatter(slots=["{{content}}", "<|im_end|>"], tool_format="qwen"),
|
||||
format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen"),
|
||||
format_observation=StringFormatter(
|
||||
slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
|
||||
),
|
||||
format_tools=ToolFormatter(tool_format="qwen"),
|
||||
format_separator=EmptyFormatter(slots=["\n"]),
|
||||
default_system="You are a helpful assistant.",
|
||||
stop_words=["<|im_end|>"],
|
||||
replace_eos=True,
|
||||
mm_plugin=get_mm_plugin(name="llava_next", image_token="<image>"),
|
||||
)
|
||||
|
||||
@ -918,10 +897,9 @@ _register_template(
|
||||
_register_template(
|
||||
name="llava_next_yi",
|
||||
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
|
||||
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
|
||||
format_separator=EmptyFormatter(slots=["\n"]),
|
||||
stop_words=["<|im_end|>"],
|
||||
replace_eos=True,
|
||||
mm_plugin=get_mm_plugin(name="llava_next", image_token="<image>"),
|
||||
)
|
||||
|
||||
@ -943,6 +921,7 @@ _register_template(
|
||||
name="llava_next_video_mistral",
|
||||
format_user=StringFormatter(slots=["[INST] {{content}}[/INST]"]),
|
||||
format_assistant=StringFormatter(slots=[" {{content}}", {"eos_token"}]),
|
||||
format_system=StringFormatter(slots=["{{content}}\n\n"]),
|
||||
format_function=FunctionFormatter(slots=["[TOOL_CALLS] ", "{{content}}", {"eos_token"}], tool_format="mistral"),
|
||||
format_observation=StringFormatter(slots=["""[TOOL_RESULTS] {"content": {{content}}}[/TOOL_RESULTS]"""]),
|
||||
format_tools=ToolFormatter(tool_format="mistral"),
|
||||
@ -955,10 +934,9 @@ _register_template(
|
||||
_register_template(
|
||||
name="llava_next_video_yi",
|
||||
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
|
||||
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
|
||||
format_separator=EmptyFormatter(slots=["\n"]),
|
||||
stop_words=["<|im_end|>"],
|
||||
replace_eos=True,
|
||||
mm_plugin=get_mm_plugin(name="llava_next_video", image_token="<image>", video_token="<video>"),
|
||||
)
|
||||
|
||||
@ -967,16 +945,15 @@ _register_template(
|
||||
_register_template(
|
||||
name="marco",
|
||||
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
|
||||
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
|
||||
format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||
format_separator=EmptyFormatter(slots=["\n"]),
|
||||
default_system=(
|
||||
"你是一个经过良好训练的AI助手,你的名字是Marco-o1.由阿里国际数字商业集团的AI Business创造.\n## 重要!!!!!\n"
|
||||
"当你回答问题时,你的思考应该在<Thought>内完成,<Output>内输出你的结果。\n"
|
||||
"<Thought>应该尽可能是英文,但是有2个特例,一个是对原文中的引用,另一个是是数学应该使用markdown格式,<Output>内的输出需要遵循用户输入的语言。\n"
|
||||
),
|
||||
stop_words=["<|im_end|>"],
|
||||
replace_eos=True,
|
||||
)
|
||||
|
||||
|
||||
@ -984,6 +961,7 @@ _register_template(
|
||||
name="mistral",
|
||||
format_user=StringFormatter(slots=["[INST] {{content}}[/INST]"]),
|
||||
format_assistant=StringFormatter(slots=[" {{content}}", {"eos_token"}]),
|
||||
format_system=StringFormatter(slots=["{{content}}\n\n"]),
|
||||
format_function=FunctionFormatter(slots=["[TOOL_CALLS] ", "{{content}}", {"eos_token"}], tool_format="mistral"),
|
||||
format_observation=StringFormatter(slots=["""[TOOL_RESULTS] {"content": {{content}}}[/TOOL_RESULTS]"""]),
|
||||
format_tools=ToolFormatter(tool_format="mistral"),
|
||||
@ -1017,7 +995,6 @@ _register_template(
|
||||
),
|
||||
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
||||
stop_words=["<|eot_id|>"],
|
||||
replace_eos=True,
|
||||
)
|
||||
|
||||
|
||||
@ -1025,9 +1002,9 @@ _register_template(
|
||||
_register_template(
|
||||
name="opencoder",
|
||||
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
|
||||
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
|
||||
format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||
format_separator=EmptyFormatter(slots=["\n"]),
|
||||
default_system="You are OpenCoder, created by OpenCoder Team.",
|
||||
stop_words=["<|im_end|>"],
|
||||
)
|
||||
@ -1044,12 +1021,11 @@ _register_template(
|
||||
_register_template(
|
||||
name="paligemma",
|
||||
format_user=StringFormatter(slots=["<start_of_turn>user\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]),
|
||||
format_assistant=StringFormatter(slots=["{{content}}<end_of_turn>\n"]),
|
||||
format_observation=StringFormatter(
|
||||
slots=["<start_of_turn>tool\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]
|
||||
),
|
||||
format_separator=EmptyFormatter(slots=["<end_of_turn>\n"]),
|
||||
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
||||
efficient_eos=True,
|
||||
mm_plugin=get_mm_plugin(name="paligemma", image_token="<image>"),
|
||||
)
|
||||
|
||||
@ -1057,28 +1033,37 @@ _register_template(
|
||||
_register_template(
|
||||
name="phi",
|
||||
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>\n"]),
|
||||
format_assistant=StringFormatter(slots=["{{content}}<|end|>\n"]),
|
||||
format_system=StringFormatter(slots=["<|system|>\n{{content}}<|end|>\n"]),
|
||||
format_separator=EmptyFormatter(slots=["\n"]),
|
||||
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
||||
stop_words=["<|end|>"],
|
||||
replace_eos=True,
|
||||
)
|
||||
|
||||
|
||||
_register_template(
|
||||
name="phi_small",
|
||||
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>\n"]),
|
||||
format_assistant=StringFormatter(slots=["{{content}}<|end|>\n"]),
|
||||
format_system=StringFormatter(slots=["<|system|>\n{{content}}<|end|>\n"]),
|
||||
format_separator=EmptyFormatter(slots=["\n"]),
|
||||
format_prefix=EmptyFormatter(slots=[{"<|endoftext|>"}]),
|
||||
stop_words=["<|end|>"],
|
||||
replace_eos=True,
|
||||
)
|
||||
|
||||
|
||||
_register_template(
|
||||
name="phi4",
|
||||
format_user=StringFormatter(
|
||||
slots=["<|im_start|>user<|im_sep|>{{content}}<|im_end|><|im_start|>assistant<|im_sep|>"]
|
||||
),
|
||||
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>"]),
|
||||
format_system=StringFormatter(slots=["<|im_start|>system<|im_sep|>{{content}}<|im_end|>"]),
|
||||
stop_words=["<|im_end|>"],
|
||||
)
|
||||
|
||||
|
||||
_register_template(
|
||||
name="pixtral",
|
||||
format_user=StringFormatter(slots=["[INST] {{content}} [/INST]"]),
|
||||
format_user=StringFormatter(slots=["[INST]{{content}}[/INST]"]),
|
||||
format_system=StringFormatter(slots=["{{content}}\n\n"]),
|
||||
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
||||
mm_plugin=get_mm_plugin(name="pixtral", image_token="[IMG]"),
|
||||
)
|
||||
@ -1088,13 +1073,13 @@ _register_template(
|
||||
_register_template(
|
||||
name="qwen",
|
||||
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
|
||||
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
|
||||
format_function=FunctionFormatter(slots=["{{content}}", "<|im_end|>"], tool_format="qwen"),
|
||||
format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen"),
|
||||
format_observation=StringFormatter(
|
||||
slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
|
||||
),
|
||||
format_tools=ToolFormatter(tool_format="qwen"),
|
||||
format_separator=EmptyFormatter(slots=["\n"]),
|
||||
default_system="You are a helpful assistant.",
|
||||
stop_words=["<|im_end|>"],
|
||||
)
|
||||
@ -1104,13 +1089,13 @@ _register_template(
|
||||
_register_template(
|
||||
name="qwen2_vl",
|
||||
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
|
||||
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
|
||||
format_function=FunctionFormatter(slots=["{{content}}", "<|im_end|>"], tool_format="qwen"),
|
||||
format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen"),
|
||||
format_observation=StringFormatter(
|
||||
slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
|
||||
),
|
||||
format_tools=ToolFormatter(tool_format="qwen"),
|
||||
format_separator=EmptyFormatter(slots=["\n"]),
|
||||
default_system="You are a helpful assistant.",
|
||||
stop_words=["<|im_end|>"],
|
||||
mm_plugin=get_mm_plugin(name="qwen2_vl", image_token="<|image_pad|>", video_token="<|video_pad|>"),
|
||||
@ -1120,8 +1105,8 @@ _register_template(
|
||||
_register_template(
|
||||
name="sailor",
|
||||
format_user=StringFormatter(slots=["<|im_start|>question\n{{content}}<|im_end|>\n<|im_start|>answer\n"]),
|
||||
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
|
||||
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
|
||||
format_separator=EmptyFormatter(slots=["\n"]),
|
||||
default_system=(
|
||||
"You are an AI assistant named Sailor created by Sea AI Lab. "
|
||||
"Your answer should be friendly, unbiased, faithful, informative and detailed."
|
||||
@ -1173,10 +1158,9 @@ _register_template(
|
||||
_register_template(
|
||||
name="starchat",
|
||||
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>"]),
|
||||
format_assistant=StringFormatter(slots=["{{content}}<|end|>\n"]),
|
||||
format_system=StringFormatter(slots=["<|system|>\n{{content}}<|end|>\n"]),
|
||||
format_separator=EmptyFormatter(slots=["\n"]),
|
||||
stop_words=["<|end|>"],
|
||||
replace_eos=True,
|
||||
)
|
||||
|
||||
|
||||
@ -1239,8 +1223,8 @@ _register_template(
|
||||
_register_template(
|
||||
name="yayi",
|
||||
format_user=StringFormatter(slots=[{"token": "<|Human|>"}, ":\n{{content}}\n\n", {"token": "<|YaYi|>"}, ":"]),
|
||||
format_assistant=StringFormatter(slots=["{{content}}\n\n"]),
|
||||
format_system=StringFormatter(slots=[{"token": "<|System|>"}, ":\n{{content}}\n\n"]),
|
||||
format_separator=EmptyFormatter(slots=["\n\n"]),
|
||||
default_system=(
|
||||
"You are a helpful, respectful and honest assistant named YaYi "
|
||||
"developed by Beijing Wenge Technology Co.,Ltd. "
|
||||
@ -1260,17 +1244,16 @@ _register_template(
|
||||
_register_template(
|
||||
name="yi",
|
||||
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
|
||||
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
|
||||
format_separator=EmptyFormatter(slots=["\n"]),
|
||||
stop_words=["<|im_end|>"],
|
||||
replace_eos=True,
|
||||
)
|
||||
|
||||
|
||||
_register_template(
|
||||
name="yi_vl",
|
||||
format_user=StringFormatter(slots=["### Human: {{content}}\n### Assistant:"]),
|
||||
format_separator=EmptyFormatter(slots=["\n"]),
|
||||
format_assistant=StringFormatter(slots=["{{content}}\n"]),
|
||||
default_system=(
|
||||
"This is a chat between an inquisitive human and an AI assistant. "
|
||||
"Assume the role of the AI assistant. Read all the images carefully, "
|
||||
@ -1287,9 +1270,8 @@ _register_template(
|
||||
_register_template(
|
||||
name="yuan",
|
||||
format_user=StringFormatter(slots=["{{content}}", {"token": "<sep>"}]),
|
||||
format_separator=EmptyFormatter(slots=["\n"]),
|
||||
format_assistant=StringFormatter(slots=["{{content}}<eod>\n"]),
|
||||
stop_words=["<eod>"],
|
||||
replace_eos=True,
|
||||
)
|
||||
|
||||
|
||||
@ -1304,5 +1286,5 @@ _register_template(
|
||||
_register_template(
|
||||
name="ziya",
|
||||
format_user=StringFormatter(slots=["<human>:{{content}}\n<bot>:"]),
|
||||
format_separator=EmptyFormatter(slots=["\n"]),
|
||||
format_assistant=StringFormatter(slots=["{{content}}\n"]),
|
||||
)
|
||||
|
@ -1424,6 +1424,14 @@ register_model_group(
|
||||
DownloadSource.DEFAULT: "microsoft/Phi-3-medium-128k-instruct",
|
||||
DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-medium-128k-instruct",
|
||||
},
|
||||
"Phi-3.5-4B-instruct": {
|
||||
DownloadSource.DEFAULT: "microsoft/Phi-3.5-mini-instruct",
|
||||
DownloadSource.MODELSCOPE: "LLM-Research/Phi-3.5-mini-instruct",
|
||||
},
|
||||
"Phi-3.5-MoE-42B-A6.6B-instruct": {
|
||||
DownloadSource.DEFAULT: "microsoft/Phi-3.5-MoE-instruct",
|
||||
DownloadSource.MODELSCOPE: "LLM-Research/Phi-3.5-MoE-instruct",
|
||||
},
|
||||
},
|
||||
template="phi",
|
||||
)
|
||||
@ -1444,6 +1452,17 @@ register_model_group(
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Phi-4-14B-Instruct": {
|
||||
DownloadSource.DEFAULT: "microsoft/phi-4",
|
||||
DownloadSource.MODELSCOPE: "LLM-Research/phi-4",
|
||||
},
|
||||
},
|
||||
template="phi4",
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Pixtral-12B-Instruct": {
|
||||
|
@ -68,7 +68,7 @@ class LoggerHandler(logging.Handler):
|
||||
|
||||
class _Logger(logging.Logger):
|
||||
r"""
|
||||
A logger that supports info_rank0 and warning_once.
|
||||
A logger that supports rank0 logging.
|
||||
"""
|
||||
|
||||
def info_rank0(self, *args, **kwargs) -> None:
|
||||
@ -77,7 +77,7 @@ class _Logger(logging.Logger):
|
||||
def warning_rank0(self, *args, **kwargs) -> None:
|
||||
self.warning(*args, **kwargs)
|
||||
|
||||
def warning_once(self, *args, **kwargs) -> None:
|
||||
def warning_rank0_once(self, *args, **kwargs) -> None:
|
||||
self.warning(*args, **kwargs)
|
||||
|
||||
|
||||
@ -163,11 +163,11 @@ def warning_rank0(self: "logging.Logger", *args, **kwargs) -> None:
|
||||
|
||||
|
||||
@lru_cache(None)
|
||||
def warning_once(self: "logging.Logger", *args, **kwargs) -> None:
|
||||
def warning_rank0_once(self: "logging.Logger", *args, **kwargs) -> None:
|
||||
if int(os.getenv("LOCAL_RANK", "0")) == 0:
|
||||
self.warning(*args, **kwargs)
|
||||
|
||||
|
||||
logging.Logger.info_rank0 = info_rank0
|
||||
logging.Logger.warning_rank0 = warning_rank0
|
||||
logging.Logger.warning_once = warning_once
|
||||
logging.Logger.warning_rank0_once = warning_rank0_once
|
||||
|
@ -73,19 +73,31 @@ class AverageMeter:
|
||||
self.avg = self.sum / self.count
|
||||
|
||||
|
||||
def check_version(requirement: str, mandatory: bool = False) -> None:
|
||||
r"""
|
||||
Optionally checks the package version.
|
||||
"""
|
||||
if os.getenv("DISABLE_VERSION_CHECK", "0").lower() in ["true", "1"] and not mandatory:
|
||||
logger.warning_rank0_once("Version checking has been disabled, may lead to unexpected behaviors.")
|
||||
return
|
||||
|
||||
if mandatory:
|
||||
hint = f"To fix: run `pip install {requirement}`."
|
||||
else:
|
||||
hint = f"To fix: run `pip install {requirement}` or set `DISABLE_VERSION_CHECK=1` to skip this check."
|
||||
|
||||
require_version(requirement, hint)
|
||||
|
||||
|
||||
def check_dependencies() -> None:
|
||||
r"""
|
||||
Checks the version of the required packages.
|
||||
"""
|
||||
if os.getenv("DISABLE_VERSION_CHECK", "0").lower() in ["true", "1"]:
|
||||
logger.warning_once("Version checking has been disabled, may lead to unexpected behaviors.")
|
||||
return
|
||||
|
||||
require_version("transformers>=4.41.2", "To fix: pip install transformers>=4.41.2")
|
||||
require_version("datasets>=2.16.0,<=3.1.0", "To fix: pip install datasets>=2.16.0,<=3.1.0")
|
||||
require_version("accelerate>=0.34.0,<=1.0.1", "To fix: pip install accelerate>=0.34.0,<=1.0.1")
|
||||
require_version("peft>=0.11.1,<=0.12.0", "To fix: pip install peft>=0.11.1,<=0.12.0")
|
||||
require_version("trl>=0.8.6,<=0.9.6", "To fix: pip install trl>=0.8.6,<=0.9.6")
|
||||
check_version("transformers>=4.41.2,<=4.46.1")
|
||||
check_version("datasets>=2.16.0,<=3.1.0")
|
||||
check_version("accelerate>=0.34.0,<=1.0.1")
|
||||
check_version("peft>=0.11.1,<=0.12.0")
|
||||
check_version("trl>=0.8.6,<=0.9.6")
|
||||
|
||||
|
||||
def calculate_tps(dataset: Sequence[Dict[str, Any]], metrics: Dict[str, float], stage: Literal["sft", "rm"]) -> float:
|
||||
@ -229,7 +241,7 @@ def skip_check_imports() -> None:
|
||||
r"""
|
||||
Avoids flash attention import error in custom model files.
|
||||
"""
|
||||
if os.environ.get("FORCE_CHECK_IMPORTS", "0").lower() not in ["true", "1"]:
|
||||
if os.getenv("FORCE_CHECK_IMPORTS", "0").lower() not in ["true", "1"]:
|
||||
transformers.dynamic_module_utils.check_imports = get_relative_imports
|
||||
|
||||
|
||||
@ -253,7 +265,7 @@ def try_download_model_from_other_hub(model_args: "ModelArguments") -> str:
|
||||
return model_args.model_name_or_path
|
||||
|
||||
if use_modelscope():
|
||||
require_version("modelscope>=1.11.0", "To fix: pip install modelscope>=1.11.0")
|
||||
check_version("modelscope>=1.11.0", mandatory=True)
|
||||
from modelscope import snapshot_download # type: ignore
|
||||
|
||||
revision = "master" if model_args.model_revision == "main" else model_args.model_revision
|
||||
@ -264,7 +276,7 @@ def try_download_model_from_other_hub(model_args: "ModelArguments") -> str:
|
||||
)
|
||||
|
||||
if use_openmind():
|
||||
require_version("openmind>=0.8.0", "To fix: pip install openmind>=0.8.0")
|
||||
check_version("openmind>=0.8.0", mandatory=True)
|
||||
from openmind.utils.hub import snapshot_download # type: ignore
|
||||
|
||||
return snapshot_download(
|
||||
@ -275,8 +287,12 @@ def try_download_model_from_other_hub(model_args: "ModelArguments") -> str:
|
||||
|
||||
|
||||
def use_modelscope() -> bool:
|
||||
return os.environ.get("USE_MODELSCOPE_HUB", "0").lower() in ["true", "1"]
|
||||
return os.getenv("USE_MODELSCOPE_HUB", "0").lower() in ["true", "1"]
|
||||
|
||||
|
||||
def use_openmind() -> bool:
|
||||
return os.environ.get("USE_OPENMIND_HUB", "0").lower() in ["true", "1"]
|
||||
return os.getenv("USE_OPENMIND_HUB", "0").lower() in ["true", "1"]
|
||||
|
||||
|
||||
def use_ray() -> bool:
|
||||
return os.getenv("USE_RAY", "0").lower() in ["true", "1"]
|
||||
|
@ -62,6 +62,10 @@ def is_pillow_available():
|
||||
return _is_package_available("PIL")
|
||||
|
||||
|
||||
def is_ray_available():
|
||||
return _is_package_available("ray")
|
||||
|
||||
|
||||
def is_requests_available():
|
||||
return _is_package_available("requests")
|
||||
|
||||
|
@ -17,7 +17,8 @@ from .evaluation_args import EvaluationArguments
|
||||
from .finetuning_args import FinetuningArguments
|
||||
from .generating_args import GeneratingArguments
|
||||
from .model_args import ModelArguments
|
||||
from .parser import get_eval_args, get_infer_args, get_train_args
|
||||
from .parser import get_eval_args, get_infer_args, get_ray_args, get_train_args, read_args
|
||||
from .training_args import RayArguments, TrainingArguments
|
||||
|
||||
|
||||
__all__ = [
|
||||
@ -26,7 +27,11 @@ __all__ = [
|
||||
"FinetuningArguments",
|
||||
"GeneratingArguments",
|
||||
"ModelArguments",
|
||||
"RayArguments",
|
||||
"TrainingArguments",
|
||||
"get_eval_args",
|
||||
"get_infer_args",
|
||||
"get_ray_args",
|
||||
"get_train_args",
|
||||
"read_args",
|
||||
]
|
||||
|
@ -15,56 +15,67 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import transformers
|
||||
from transformers import HfArgumentParser, Seq2SeqTrainingArguments
|
||||
import yaml
|
||||
from transformers import HfArgumentParser
|
||||
from transformers.integrations import is_deepspeed_zero3_enabled
|
||||
from transformers.trainer_utils import get_last_checkpoint
|
||||
from transformers.training_args import ParallelMode
|
||||
from transformers.utils import is_torch_bf16_gpu_available, is_torch_npu_available
|
||||
from transformers.utils.versions import require_version
|
||||
|
||||
from ..extras import logging
|
||||
from ..extras.constants import CHECKPOINT_NAMES
|
||||
from ..extras.misc import check_dependencies, get_current_device
|
||||
from ..extras.misc import check_dependencies, check_version, get_current_device
|
||||
from .data_args import DataArguments
|
||||
from .evaluation_args import EvaluationArguments
|
||||
from .finetuning_args import FinetuningArguments
|
||||
from .generating_args import GeneratingArguments
|
||||
from .model_args import ModelArguments
|
||||
from .training_args import RayArguments, TrainingArguments
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
check_dependencies()
|
||||
|
||||
|
||||
_TRAIN_ARGS = [ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments]
|
||||
_TRAIN_CLS = Tuple[ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments]
|
||||
_TRAIN_ARGS = [ModelArguments, DataArguments, TrainingArguments, FinetuningArguments, GeneratingArguments]
|
||||
_TRAIN_CLS = Tuple[ModelArguments, DataArguments, TrainingArguments, FinetuningArguments, GeneratingArguments]
|
||||
_INFER_ARGS = [ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
|
||||
_INFER_CLS = Tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
|
||||
_EVAL_ARGS = [ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments]
|
||||
_EVAL_CLS = Tuple[ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments]
|
||||
|
||||
|
||||
def _parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = None) -> Tuple[Any]:
|
||||
def read_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> Union[Dict[str, Any], List[str]]:
|
||||
if args is not None:
|
||||
return parser.parse_dict(args)
|
||||
return args
|
||||
|
||||
if len(sys.argv) == 2 and (sys.argv[1].endswith(".yaml") or sys.argv[1].endswith(".yml")):
|
||||
return parser.parse_yaml_file(os.path.abspath(sys.argv[1]))
|
||||
return yaml.safe_load(Path(sys.argv[1]).absolute().read_text())
|
||||
elif len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
|
||||
return json.loads(Path(sys.argv[1]).absolute().read_text())
|
||||
else:
|
||||
return sys.argv[1:]
|
||||
|
||||
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
|
||||
return parser.parse_json_file(os.path.abspath(sys.argv[1]))
|
||||
|
||||
(*parsed_args, unknown_args) = parser.parse_args_into_dataclasses(return_remaining_strings=True)
|
||||
def _parse_args(
|
||||
parser: "HfArgumentParser", args: Optional[Union[Dict[str, Any], List[str]]] = None, allow_extra_keys: bool = False
|
||||
) -> Tuple[Any]:
|
||||
args = read_args(args)
|
||||
if isinstance(args, dict):
|
||||
return parser.parse_dict(args, allow_extra_keys=allow_extra_keys)
|
||||
|
||||
if unknown_args:
|
||||
(*parsed_args, unknown_args) = parser.parse_args_into_dataclasses(args=args, return_remaining_strings=True)
|
||||
|
||||
if unknown_args and not allow_extra_keys:
|
||||
print(parser.format_help())
|
||||
print(f"Got unknown args, potentially deprecated arguments: {unknown_args}")
|
||||
raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {unknown_args}")
|
||||
@ -110,58 +121,61 @@ def _verify_model_args(
|
||||
def _check_extra_dependencies(
|
||||
model_args: "ModelArguments",
|
||||
finetuning_args: "FinetuningArguments",
|
||||
training_args: Optional["Seq2SeqTrainingArguments"] = None,
|
||||
training_args: Optional["TrainingArguments"] = None,
|
||||
) -> None:
|
||||
if os.getenv("DISABLE_VERSION_CHECK", "0").lower() in ["true", "1"]:
|
||||
logger.warning_once("Version checking has been disabled, may lead to unexpected behaviors.")
|
||||
return
|
||||
|
||||
if model_args.use_unsloth:
|
||||
require_version("unsloth", "Please install unsloth: https://github.com/unslothai/unsloth")
|
||||
check_version("unsloth", mandatory=True)
|
||||
|
||||
if model_args.enable_liger_kernel:
|
||||
require_version("liger-kernel", "To fix: pip install liger-kernel")
|
||||
check_version("liger-kernel", mandatory=True)
|
||||
|
||||
if model_args.mixture_of_depths is not None:
|
||||
require_version("mixture-of-depth>=1.1.6", "To fix: pip install mixture-of-depth>=1.1.6")
|
||||
check_version("mixture-of-depth>=1.1.6", mandatory=True)
|
||||
|
||||
if model_args.infer_backend == "vllm":
|
||||
require_version("vllm>=0.4.3,<0.6.7", "To fix: pip install vllm>=0.4.3,<0.6.7")
|
||||
check_version("vllm>=0.4.3,<0.6.7")
|
||||
check_version("vllm", mandatory=True)
|
||||
|
||||
if finetuning_args.use_galore:
|
||||
require_version("galore_torch", "To fix: pip install galore_torch")
|
||||
check_version("galore_torch", mandatory=True)
|
||||
|
||||
if finetuning_args.use_badam:
|
||||
require_version("badam>=1.2.1", "To fix: pip install badam>=1.2.1")
|
||||
check_version("badam>=1.2.1", mandatory=True)
|
||||
|
||||
if finetuning_args.use_adam_mini:
|
||||
require_version("adam-mini", "To fix: pip install adam-mini")
|
||||
check_version("adam-mini", mandatory=True)
|
||||
|
||||
if finetuning_args.plot_loss:
|
||||
require_version("matplotlib", "To fix: pip install matplotlib")
|
||||
check_version("matplotlib", mandatory=True)
|
||||
|
||||
if training_args is not None and training_args.predict_with_generate:
|
||||
require_version("jieba", "To fix: pip install jieba")
|
||||
require_version("nltk", "To fix: pip install nltk")
|
||||
require_version("rouge_chinese", "To fix: pip install rouge-chinese")
|
||||
check_version("jieba", mandatory=True)
|
||||
check_version("nltk", mandatory=True)
|
||||
check_version("rouge_chinese", mandatory=True)
|
||||
|
||||
|
||||
def _parse_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
|
||||
def _parse_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _TRAIN_CLS:
|
||||
parser = HfArgumentParser(_TRAIN_ARGS)
|
||||
return _parse_args(parser, args)
|
||||
|
||||
|
||||
def _parse_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
|
||||
def _parse_infer_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _INFER_CLS:
|
||||
parser = HfArgumentParser(_INFER_ARGS)
|
||||
return _parse_args(parser, args)
|
||||
|
||||
|
||||
def _parse_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS:
|
||||
def _parse_eval_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _EVAL_CLS:
|
||||
parser = HfArgumentParser(_EVAL_ARGS)
|
||||
return _parse_args(parser, args)
|
||||
|
||||
|
||||
def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
|
||||
def get_ray_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> RayArguments:
|
||||
parser = HfArgumentParser(RayArguments)
|
||||
(ray_args,) = _parse_args(parser, args, allow_extra_keys=True)
|
||||
return ray_args
|
||||
|
||||
|
||||
def get_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _TRAIN_CLS:
|
||||
model_args, data_args, training_args, finetuning_args, generating_args = _parse_train_args(args)
|
||||
|
||||
# Setup logging
|
||||
@ -371,7 +385,7 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
|
||||
return model_args, data_args, training_args, finetuning_args, generating_args
|
||||
|
||||
|
||||
def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
|
||||
def get_infer_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _INFER_CLS:
|
||||
model_args, data_args, finetuning_args, generating_args = _parse_infer_args(args)
|
||||
|
||||
_set_transformers_logging()
|
||||
@ -404,7 +418,7 @@ def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
|
||||
return model_args, data_args, finetuning_args, generating_args
|
||||
|
||||
|
||||
def get_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS:
|
||||
def get_eval_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _EVAL_CLS:
|
||||
model_args, data_args, eval_args, finetuning_args = _parse_eval_args(args)
|
||||
|
||||
_set_transformers_logging()
|
||||
|
48
src/llamafactory/hparams/training_args.py
Normal file
48
src/llamafactory/hparams/training_args.py
Normal file
@ -0,0 +1,48 @@
|
||||
import json
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Literal, Optional, Union
|
||||
|
||||
from transformers import Seq2SeqTrainingArguments
|
||||
from transformers.training_args import _convert_str_dict
|
||||
|
||||
from ..extras.misc import use_ray
|
||||
|
||||
|
||||
@dataclass
|
||||
class RayArguments:
|
||||
r"""
|
||||
Arguments pertaining to the Ray training.
|
||||
"""
|
||||
|
||||
ray_run_name: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "The training results will be saved at `saves/ray_run_name`."},
|
||||
)
|
||||
ray_num_workers: int = field(
|
||||
default=1,
|
||||
metadata={"help": "The number of workers for Ray training. Default is 1 worker."},
|
||||
)
|
||||
resources_per_worker: Union[dict, str] = field(
|
||||
default_factory=lambda: {"GPU": 1},
|
||||
metadata={"help": "The resources per worker for Ray training. Default is to use 1 GPU per worker."},
|
||||
)
|
||||
placement_strategy: Literal["SPREAD", "PACK", "STRICT_SPREAD", "STRICT_PACK"] = field(
|
||||
default="PACK",
|
||||
metadata={"help": "The placement strategy for Ray training. Default is PACK."},
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
self.use_ray = use_ray()
|
||||
if isinstance(self.resources_per_worker, str) and self.resources_per_worker.startswith("{"):
|
||||
self.resources_per_worker = _convert_str_dict(json.loads(self.resources_per_worker))
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainingArguments(RayArguments, Seq2SeqTrainingArguments):
|
||||
r"""
|
||||
Arguments pertaining to the trainer.
|
||||
"""
|
||||
|
||||
def __post_init__(self):
|
||||
Seq2SeqTrainingArguments.__post_init__(self)
|
||||
RayArguments.__post_init__(self)
|
@ -15,9 +15,9 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from transformers.utils import is_flash_attn_2_available, is_torch_sdpa_available
|
||||
from transformers.utils.versions import require_version
|
||||
|
||||
from ...extras import logging
|
||||
from ...extras.misc import check_version
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -35,8 +35,8 @@ def configure_attn_implementation(
|
||||
if getattr(config, "model_type", None) == "gemma2" and is_trainable:
|
||||
if model_args.flash_attn == "auto" or model_args.flash_attn == "fa2":
|
||||
if is_flash_attn_2_available():
|
||||
require_version("transformers>=4.42.4", "To fix: pip install transformers>=4.42.4")
|
||||
require_version("flash_attn>=2.6.3", "To fix: pip install flash_attn>=2.6.3")
|
||||
check_version("transformers>=4.42.4")
|
||||
check_version("flash_attn>=2.6.3")
|
||||
if model_args.flash_attn != "fa2":
|
||||
logger.warning_rank0("Gemma-2 should use flash attention 2, change `flash_attn` to fa2.")
|
||||
model_args.flash_attn = "fa2"
|
||||
|
@ -122,7 +122,7 @@ def _gradient_checkpointing_enable(
|
||||
if "value" in inspect.signature(self._set_gradient_checkpointing).parameters: # old GC format
|
||||
self.apply(partial(self._set_gradient_checkpointing, value=True))
|
||||
self.enable_input_require_grads()
|
||||
logger.warning_once("You are using the old GC format, some features (e.g. BAdam) will be invalid.")
|
||||
logger.warning_rank0_once("You are using the old GC format, some features (e.g. BAdam) will be invalid.")
|
||||
else: # have already enabled input require gradients
|
||||
self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=gradient_checkpointing_func)
|
||||
|
||||
|
@ -31,10 +31,10 @@ from transformers.models.llama.modeling_llama import (
|
||||
apply_rotary_pos_emb,
|
||||
repeat_kv,
|
||||
)
|
||||
from transformers.utils.versions import require_version
|
||||
|
||||
from ...extras import logging
|
||||
from ...extras.constants import SUPPORTED_CLASS_FOR_S2ATTN
|
||||
from ...extras.misc import check_version
|
||||
from ...extras.packages import is_transformers_version_greater_than
|
||||
|
||||
|
||||
@ -353,7 +353,7 @@ def llama_sdpa_attention_forward(
|
||||
|
||||
|
||||
def _apply_llama_patch() -> None:
|
||||
require_version("transformers>=4.41.2,<=4.46.1", "To fix: pip install transformers>=4.41.2,<=4.46.1")
|
||||
check_version("transformers>=4.41.2,<=4.46.1")
|
||||
LlamaAttention.forward = llama_attention_forward
|
||||
LlamaFlashAttention2.forward = llama_flash_attention_2_forward
|
||||
LlamaSdpaAttention.forward = llama_sdpa_attention_forward
|
||||
|
@ -16,7 +16,8 @@ from typing import TYPE_CHECKING, Sequence
|
||||
|
||||
import torch
|
||||
from transformers.integrations import is_deepspeed_zero3_enabled
|
||||
from transformers.utils.versions import require_version
|
||||
|
||||
from ...extras.misc import check_version
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -26,7 +27,7 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
def _set_z3_leaf_modules(model: "PreTrainedModel", leaf_modules: Sequence["torch.nn.Module"]) -> None:
|
||||
require_version("deepspeed>=0.13.0", "To fix: pip install deepspeed>=0.13.0")
|
||||
check_version("deepspeed>=0.13.0")
|
||||
from deepspeed.utils import set_z3_leaf_modules # type: ignore
|
||||
|
||||
set_z3_leaf_modules(model, leaf_modules)
|
||||
|
@ -41,9 +41,9 @@ from typing import TYPE_CHECKING, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from transformers.utils.versions import require_version
|
||||
|
||||
from ...extras import logging
|
||||
from ...extras.misc import check_version
|
||||
from ...extras.packages import is_transformers_version_greater_than
|
||||
|
||||
|
||||
@ -118,6 +118,6 @@ def configure_packing(model_args: "ModelArguments", is_trainable: bool) -> None:
|
||||
if not is_trainable or not model_args.block_diag_attn:
|
||||
return
|
||||
|
||||
require_version("transformers>=4.43.0,<=4.46.1", "To fix: pip install transformers>=4.43.0,<=4.46.1")
|
||||
check_version("transformers>=4.43.0,<=4.46.1")
|
||||
transformers.modeling_flash_attention_utils._get_unpad_data = get_unpad_data
|
||||
logger.info_rank0("Using block diagonal attention for sequence packing without cross-attention.")
|
||||
|
@ -26,11 +26,10 @@ from datasets import load_dataset
|
||||
from transformers import BitsAndBytesConfig, EetqConfig, GPTQConfig, HqqConfig
|
||||
from transformers.integrations import is_deepspeed_zero3_enabled
|
||||
from transformers.modeling_utils import is_fsdp_enabled
|
||||
from transformers.utils.versions import require_version
|
||||
|
||||
from ...extras import logging
|
||||
from ...extras.constants import FILEEXT2TYPE
|
||||
from ...extras.misc import get_current_device
|
||||
from ...extras.misc import check_version, get_current_device
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -118,15 +117,15 @@ def configure_quantization(
|
||||
quant_method = quantization_config.get("quant_method", "")
|
||||
|
||||
if quant_method == QuantizationMethod.GPTQ:
|
||||
require_version("auto_gptq>=0.5.0", "To fix: pip install auto_gptq>=0.5.0")
|
||||
check_version("auto_gptq>=0.5.0", mandatory=True)
|
||||
quantization_config.pop("disable_exllama", None) # remove deprecated args
|
||||
quantization_config["use_exllama"] = False # disable exllama
|
||||
|
||||
if quant_method == QuantizationMethod.AWQ:
|
||||
require_version("autoawq", "To fix: pip install autoawq")
|
||||
check_version("autoawq", mandatory=True)
|
||||
|
||||
if quant_method == QuantizationMethod.AQLM:
|
||||
require_version("aqlm>=1.1.0", "To fix: pip install aqlm[gpu]>=1.1.0")
|
||||
check_version("aqlm>=1.1.0", mandatory=True)
|
||||
quantization_config["bits"] = 2
|
||||
|
||||
quant_bits = quantization_config.get("bits", "?")
|
||||
@ -136,8 +135,8 @@ def configure_quantization(
|
||||
if model_args.export_quantization_bit not in [8, 4, 3, 2]:
|
||||
raise ValueError("AutoGPTQ only accepts 2/3/4/8-bit quantization.")
|
||||
|
||||
require_version("optimum>=1.17.0", "To fix: pip install optimum>=1.17.0")
|
||||
require_version("auto_gptq>=0.5.0", "To fix: pip install auto_gptq>=0.5.0")
|
||||
check_version("optimum>=1.17.0", mandatory=True)
|
||||
check_version("auto_gptq>=0.5.0", mandatory=True)
|
||||
from accelerate.utils import get_max_memory
|
||||
|
||||
if getattr(config, "model_type", None) == "chatglm":
|
||||
@ -154,10 +153,10 @@ def configure_quantization(
|
||||
elif model_args.quantization_bit is not None: # on-the-fly
|
||||
if model_args.quantization_method == QuantizationMethod.BITS_AND_BYTES.value:
|
||||
if model_args.quantization_bit == 8:
|
||||
require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
|
||||
check_version("bitsandbytes>=0.37.0", mandatory=True)
|
||||
init_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
|
||||
elif model_args.quantization_bit == 4:
|
||||
require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
|
||||
check_version("bitsandbytes>=0.39.0", mandatory=True)
|
||||
init_kwargs["quantization_config"] = BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
bnb_4bit_compute_dtype=model_args.compute_dtype,
|
||||
@ -175,7 +174,7 @@ def configure_quantization(
|
||||
if model_args.quantization_bit != 4:
|
||||
raise ValueError("Only 4-bit quantized model can use fsdp+qlora or auto device map.")
|
||||
|
||||
require_version("bitsandbytes>=0.43.0", "To fix: pip install bitsandbytes>=0.43.0")
|
||||
check_version("bitsandbytes>=0.43.0", mandatory=True)
|
||||
else:
|
||||
init_kwargs["device_map"] = {"": get_current_device()} # change auto device map for inference
|
||||
|
||||
@ -187,7 +186,7 @@ def configure_quantization(
|
||||
if is_deepspeed_zero3_enabled() or is_fsdp_enabled():
|
||||
raise ValueError("HQQ quantization is incompatible with DeepSpeed ZeRO-3 or FSDP.")
|
||||
|
||||
require_version("hqq", "To fix: pip install hqq")
|
||||
check_version("hqq", mandatory=True)
|
||||
init_kwargs["quantization_config"] = HqqConfig(
|
||||
nbits=model_args.quantization_bit, quant_zero=False, quant_scale=False, axis=0
|
||||
) # use ATEN kernel (axis=0) for performance
|
||||
@ -199,6 +198,6 @@ def configure_quantization(
|
||||
if is_deepspeed_zero3_enabled() or is_fsdp_enabled():
|
||||
raise ValueError("EETQ quantization is incompatible with DeepSpeed ZeRO-3 or FSDP.")
|
||||
|
||||
require_version("eetq", "To fix: pip install eetq")
|
||||
check_version("eetq", mandatory=True)
|
||||
init_kwargs["quantization_config"] = EetqConfig()
|
||||
logger.info_rank0(f"Quantizing model to {model_args.quantization_bit} bit with EETQ.")
|
||||
|
@ -35,7 +35,7 @@ from typing_extensions import override
|
||||
|
||||
from ..extras import logging
|
||||
from ..extras.constants import TRAINER_LOG, V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
|
||||
from ..extras.misc import get_peak_memory
|
||||
from ..extras.misc import get_peak_memory, use_ray
|
||||
|
||||
|
||||
if is_safetensors_available():
|
||||
@ -194,7 +194,7 @@ class LogCallback(TrainerCallback):
|
||||
self.do_train = False
|
||||
# Web UI
|
||||
self.webui_mode = os.environ.get("LLAMABOARD_ENABLED", "0").lower() in ["true", "1"]
|
||||
if self.webui_mode:
|
||||
if self.webui_mode and not use_ray():
|
||||
signal.signal(signal.SIGABRT, self._set_abort)
|
||||
self.logger_handler = logging.LoggerHandler(os.environ.get("LLAMABOARD_WORKDIR"))
|
||||
logging.add_handler(self.logger_handler)
|
||||
@ -239,7 +239,7 @@ class LogCallback(TrainerCallback):
|
||||
and os.path.exists(os.path.join(args.output_dir, TRAINER_LOG))
|
||||
and args.overwrite_output_dir
|
||||
):
|
||||
logger.warning_once("Previous trainer log in this folder will be deleted.")
|
||||
logger.warning_rank0_once("Previous trainer log in this folder will be deleted.")
|
||||
os.remove(os.path.join(args.output_dir, TRAINER_LOG))
|
||||
|
||||
@override
|
||||
@ -383,7 +383,7 @@ class ReporterCallback(TrainerCallback):
|
||||
)
|
||||
|
||||
if self.finetuning_args.use_swanlab:
|
||||
import swanlab
|
||||
import swanlab # type: ignore
|
||||
|
||||
swanlab.config.update(
|
||||
{
|
||||
|
@ -31,7 +31,7 @@ from typing_extensions import override
|
||||
from ...extras.constants import IGNORE_INDEX
|
||||
from ...extras.packages import is_transformers_version_equal_to_4_46, is_transformers_version_greater_than
|
||||
from ..callbacks import SaveProcessorCallback
|
||||
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps
|
||||
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps, nested_detach
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -193,7 +193,7 @@ class CustomDPOTrainer(DPOTrainer):
|
||||
Otherwise the average log probabilities.
|
||||
"""
|
||||
if self.finetuning_args.use_ref_model:
|
||||
batch = {k: v.detach().clone() for k, v in batch.items()} # avoid error
|
||||
batch = nested_detach(batch, clone=True) # avoid error
|
||||
|
||||
all_logits: "torch.Tensor" = model(**batch, return_dict=True, use_cache=False).logits.to(torch.float32)
|
||||
all_logps, valid_length = get_batch_logps(logits=all_logits, labels=batch["labels"])
|
||||
|
@ -30,7 +30,7 @@ from typing_extensions import override
|
||||
from ...extras.constants import IGNORE_INDEX
|
||||
from ...extras.packages import is_transformers_version_equal_to_4_46, is_transformers_version_greater_than
|
||||
from ..callbacks import SaveProcessorCallback
|
||||
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps
|
||||
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps, nested_detach
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -142,7 +142,7 @@ class CustomKTOTrainer(KTOTrainer):
|
||||
r"""
|
||||
Runs forward pass and computes the log probabilities.
|
||||
"""
|
||||
batch = {k: v.detach().clone() for k, v in batch.items()} # avoid error
|
||||
batch = nested_detach(batch, clone=True) # avoid error
|
||||
model_inputs = {
|
||||
"input_ids": batch[f"{prefix}input_ids"],
|
||||
"attention_mask": batch[f"{prefix}attention_mask"],
|
||||
|
@ -122,7 +122,7 @@ def run_sft(
|
||||
|
||||
# Predict
|
||||
if training_args.do_predict:
|
||||
logger.warning_once("Batch generation can be very slow. Consider using `scripts/vllm_infer.py` instead.")
|
||||
logger.warning_rank0_once("Batch generation can be very slow. Consider using `scripts/vllm_infer.py` instead.")
|
||||
predict_results = trainer.predict(dataset_module["eval_dataset"], metric_key_prefix="predict", **gen_kwargs)
|
||||
trainer.log_metrics("predict", predict_results.metrics)
|
||||
trainer.save_metrics("predict", predict_results.metrics)
|
||||
|
@ -17,7 +17,9 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union
|
||||
from collections.abc import Mapping
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from transformers import Trainer
|
||||
@ -30,20 +32,25 @@ from typing_extensions import override
|
||||
|
||||
from ..extras import logging
|
||||
from ..extras.constants import IGNORE_INDEX
|
||||
from ..extras.packages import is_galore_available
|
||||
from ..extras.packages import is_galore_available, is_ray_available
|
||||
from ..hparams import FinetuningArguments, ModelArguments
|
||||
from ..model import find_all_linear_modules, load_model, load_tokenizer, load_valuehead_params
|
||||
|
||||
|
||||
if is_galore_available():
|
||||
from galore_torch import GaLoreAdafactor, GaLoreAdamW, GaLoreAdamW8bit
|
||||
from galore_torch import GaLoreAdafactor, GaLoreAdamW, GaLoreAdamW8bit # type: ignore
|
||||
|
||||
|
||||
if is_ray_available():
|
||||
from ray.train import RunConfig, ScalingConfig
|
||||
from ray.train.torch import TorchTrainer
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import PreTrainedModel, Seq2SeqTrainingArguments, TrainerCallback
|
||||
from transformers import PreTrainedModel, TrainerCallback
|
||||
from trl import AutoModelForCausalLMWithValueHead
|
||||
|
||||
from ..hparams import DataArguments
|
||||
from ..hparams import DataArguments, RayArguments, TrainingArguments
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
@ -74,7 +81,7 @@ def create_modelcard_and_push(
|
||||
trainer: "Trainer",
|
||||
model_args: "ModelArguments",
|
||||
data_args: "DataArguments",
|
||||
training_args: "Seq2SeqTrainingArguments",
|
||||
training_args: "TrainingArguments",
|
||||
finetuning_args: "FinetuningArguments",
|
||||
) -> None:
|
||||
kwargs = {
|
||||
@ -187,7 +194,7 @@ def _get_decay_parameter_names(model: "PreTrainedModel") -> List[str]:
|
||||
|
||||
def _create_galore_optimizer(
|
||||
model: "PreTrainedModel",
|
||||
training_args: "Seq2SeqTrainingArguments",
|
||||
training_args: "TrainingArguments",
|
||||
finetuning_args: "FinetuningArguments",
|
||||
) -> "torch.optim.Optimizer":
|
||||
if len(finetuning_args.galore_target) == 1 and finetuning_args.galore_target[0] == "all":
|
||||
@ -271,7 +278,7 @@ def _create_galore_optimizer(
|
||||
|
||||
def _create_loraplus_optimizer(
|
||||
model: "PreTrainedModel",
|
||||
training_args: "Seq2SeqTrainingArguments",
|
||||
training_args: "TrainingArguments",
|
||||
finetuning_args: "FinetuningArguments",
|
||||
) -> "torch.optim.Optimizer":
|
||||
default_lr = training_args.learning_rate
|
||||
@ -311,7 +318,7 @@ def _create_loraplus_optimizer(
|
||||
|
||||
def _create_badam_optimizer(
|
||||
model: "PreTrainedModel",
|
||||
training_args: "Seq2SeqTrainingArguments",
|
||||
training_args: "TrainingArguments",
|
||||
finetuning_args: "FinetuningArguments",
|
||||
) -> "torch.optim.Optimizer":
|
||||
decay_params, nodecay_params = [], []
|
||||
@ -330,7 +337,7 @@ def _create_badam_optimizer(
|
||||
]
|
||||
|
||||
if finetuning_args.badam_mode == "layer":
|
||||
from badam import BlockOptimizer
|
||||
from badam import BlockOptimizer # type: ignore
|
||||
|
||||
base_optimizer = optim_class(param_groups, **optim_kwargs)
|
||||
optimizer = BlockOptimizer(
|
||||
@ -350,7 +357,7 @@ def _create_badam_optimizer(
|
||||
)
|
||||
|
||||
elif finetuning_args.badam_mode == "ratio":
|
||||
from badam import BlockOptimizerRatio
|
||||
from badam import BlockOptimizerRatio # type: ignore
|
||||
|
||||
assert finetuning_args.badam_update_ratio > 1e-6
|
||||
optimizer = BlockOptimizerRatio(
|
||||
@ -372,9 +379,9 @@ def _create_badam_optimizer(
|
||||
|
||||
def _create_adam_mini_optimizer(
|
||||
model: "PreTrainedModel",
|
||||
training_args: "Seq2SeqTrainingArguments",
|
||||
training_args: "TrainingArguments",
|
||||
) -> "torch.optim.Optimizer":
|
||||
from adam_mini import Adam_mini
|
||||
from adam_mini import Adam_mini # type: ignore
|
||||
|
||||
hidden_size = getattr(model.config, "hidden_size", None)
|
||||
num_q_head = getattr(model.config, "num_attention_heads", None)
|
||||
@ -397,7 +404,7 @@ def _create_adam_mini_optimizer(
|
||||
|
||||
def create_custom_optimizer(
|
||||
model: "PreTrainedModel",
|
||||
training_args: "Seq2SeqTrainingArguments",
|
||||
training_args: "TrainingArguments",
|
||||
finetuning_args: "FinetuningArguments",
|
||||
) -> Optional["torch.optim.Optimizer"]:
|
||||
if finetuning_args.use_galore:
|
||||
@ -414,7 +421,7 @@ def create_custom_optimizer(
|
||||
|
||||
|
||||
def create_custom_scheduler(
|
||||
training_args: "Seq2SeqTrainingArguments",
|
||||
training_args: "TrainingArguments",
|
||||
num_training_steps: int,
|
||||
optimizer: Optional["torch.optim.Optimizer"] = None,
|
||||
) -> None:
|
||||
@ -459,12 +466,33 @@ def get_batch_logps(
|
||||
return (per_token_logps * loss_mask).sum(-1), loss_mask.sum(-1)
|
||||
|
||||
|
||||
def nested_detach(
|
||||
tensors: Union["torch.Tensor", List["torch.Tensor"], Tuple["torch.Tensor"], Dict[str, "torch.Tensor"]],
|
||||
clone: bool = False,
|
||||
):
|
||||
r"""
|
||||
Detach `tensors` (even if it's a nested list/tuple/dict of tensors).
|
||||
"""
|
||||
if isinstance(tensors, (list, tuple)):
|
||||
return type(tensors)(nested_detach(t, clone=clone) for t in tensors)
|
||||
elif isinstance(tensors, Mapping):
|
||||
return type(tensors)({k: nested_detach(t, clone=clone) for k, t in tensors.items()})
|
||||
|
||||
if isinstance(tensors, torch.Tensor):
|
||||
if clone:
|
||||
return tensors.detach().clone()
|
||||
else:
|
||||
return tensors.detach()
|
||||
else:
|
||||
return tensors
|
||||
|
||||
|
||||
def get_swanlab_callback(finetuning_args: "FinetuningArguments") -> "TrainerCallback":
|
||||
r"""
|
||||
Gets the callback for logging to SwanLab.
|
||||
"""
|
||||
import swanlab
|
||||
from swanlab.integration.transformers import SwanLabCallback
|
||||
import swanlab # type: ignore
|
||||
from swanlab.integration.transformers import SwanLabCallback # type: ignore
|
||||
|
||||
if finetuning_args.swanlab_api_key is not None:
|
||||
swanlab.login(api_key=finetuning_args.swanlab_api_key)
|
||||
@ -477,3 +505,28 @@ def get_swanlab_callback(finetuning_args: "FinetuningArguments") -> "TrainerCall
|
||||
config={"Framework": "🦙LlamaFactory"},
|
||||
)
|
||||
return swanlab_callback
|
||||
|
||||
|
||||
def get_ray_trainer(
|
||||
training_function: Callable,
|
||||
train_loop_config: Dict[str, Any],
|
||||
ray_args: "RayArguments",
|
||||
) -> "TorchTrainer":
|
||||
if not ray_args.use_ray:
|
||||
raise ValueError("Ray was not enabled. Please set `USE_RAY=1` to enable ray.")
|
||||
|
||||
trainer = TorchTrainer(
|
||||
training_function,
|
||||
train_loop_config=train_loop_config,
|
||||
scaling_config=ScalingConfig(
|
||||
num_workers=ray_args.ray_num_workers,
|
||||
resources_per_worker=ray_args.resources_per_worker,
|
||||
placement_strategy=ray_args.placement_strategy,
|
||||
use_gpu=True,
|
||||
),
|
||||
run_config=RunConfig(
|
||||
name=ray_args.ray_run_name,
|
||||
storage_path=Path("./saves").absolute().as_posix(),
|
||||
),
|
||||
)
|
||||
return trainer
|
||||
|
@ -22,7 +22,8 @@ from transformers import PreTrainedModel
|
||||
from ..data import get_template_and_fix_tokenizer
|
||||
from ..extras import logging
|
||||
from ..extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
|
||||
from ..hparams import get_infer_args, get_train_args
|
||||
from ..extras.packages import is_ray_available
|
||||
from ..hparams import get_infer_args, get_ray_args, get_train_args, read_args
|
||||
from ..model import load_model, load_tokenizer
|
||||
from .callbacks import LogCallback, PissaConvertCallback, ReporterCallback
|
||||
from .dpo import run_dpo
|
||||
@ -31,7 +32,11 @@ from .ppo import run_ppo
|
||||
from .pt import run_pt
|
||||
from .rm import run_rm
|
||||
from .sft import run_sft
|
||||
from .trainer_utils import get_swanlab_callback
|
||||
from .trainer_utils import get_ray_trainer, get_swanlab_callback
|
||||
|
||||
|
||||
if is_ray_available():
|
||||
from ray.train.huggingface.transformers import RayTrainReportCallback
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -41,10 +46,12 @@ if TYPE_CHECKING:
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: List["TrainerCallback"] = []) -> None:
|
||||
callbacks.append(LogCallback())
|
||||
def _training_function(config: Dict[str, Any]) -> None:
|
||||
args = config.get("args")
|
||||
callbacks: List[Any] = config.get("callbacks")
|
||||
model_args, data_args, training_args, finetuning_args, generating_args = get_train_args(args)
|
||||
|
||||
callbacks.append(LogCallback())
|
||||
if finetuning_args.pissa_convert:
|
||||
callbacks.append(PissaConvertCallback())
|
||||
|
||||
@ -69,6 +76,22 @@ def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: List["TrainerCallb
|
||||
raise ValueError(f"Unknown task: {finetuning_args.stage}.")
|
||||
|
||||
|
||||
def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: Optional[List["TrainerCallback"]] = None) -> None:
|
||||
args = read_args(args)
|
||||
ray_args = get_ray_args(args)
|
||||
callbacks = callbacks or []
|
||||
if ray_args.use_ray:
|
||||
callbacks.append(RayTrainReportCallback())
|
||||
trainer = get_ray_trainer(
|
||||
training_function=_training_function,
|
||||
train_loop_config={"args": args, "callbacks": callbacks},
|
||||
ray_args=ray_args,
|
||||
)
|
||||
trainer.fit()
|
||||
else:
|
||||
_training_function(config={"args": args, "callbacks": callbacks})
|
||||
|
||||
|
||||
def export_model(args: Optional[Dict[str, Any]] = None) -> None:
|
||||
model_args, data_args, finetuning_args, _ = get_infer_args(args)
|
||||
|
||||
|
@ -21,7 +21,7 @@ from typing import TYPE_CHECKING, Any, Dict, Generator, Optional
|
||||
from transformers.trainer import TRAINING_ARGS_NAME
|
||||
|
||||
from ..extras.constants import LLAMABOARD_CONFIG, PEFT_METHODS, TRAINING_STAGES
|
||||
from ..extras.misc import is_gpu_or_npu_available, torch_gc
|
||||
from ..extras.misc import is_gpu_or_npu_available, torch_gc, use_ray
|
||||
from ..extras.packages import is_gradio_available, is_transformers_version_equal_to_4_46
|
||||
from .common import DEFAULT_CACHE_DIR, DEFAULT_CONFIG_DIR, QUANTIZATION_BITS, get_save_dir, load_config
|
||||
from .locales import ALERTS, LOCALES
|
||||
@ -394,12 +394,12 @@ class Runner:
|
||||
continue
|
||||
|
||||
if self.do_train:
|
||||
if os.path.exists(os.path.join(output_path, TRAINING_ARGS_NAME)):
|
||||
if os.path.exists(os.path.join(output_path, TRAINING_ARGS_NAME)) or use_ray():
|
||||
finish_info = ALERTS["info_finished"][lang]
|
||||
else:
|
||||
finish_info = ALERTS["err_failed"][lang]
|
||||
else:
|
||||
if os.path.exists(os.path.join(output_path, "all_results.json")):
|
||||
if os.path.exists(os.path.join(output_path, "all_results.json")) or use_ray():
|
||||
finish_info = get_eval_results(os.path.join(output_path, "all_results.json"))
|
||||
else:
|
||||
finish_info = ALERTS["err_failed"][lang]
|
||||
|
@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
from typing import TYPE_CHECKING, List, Sequence
|
||||
from typing import TYPE_CHECKING, Sequence
|
||||
|
||||
import pytest
|
||||
from transformers import AutoTokenizer
|
||||
@ -42,39 +42,36 @@ MESSAGES = [
|
||||
def _check_tokenization(
|
||||
tokenizer: "PreTrainedTokenizer", batch_input_ids: Sequence[Sequence[int]], batch_text: Sequence[str]
|
||||
) -> None:
|
||||
r"""
|
||||
Checks token ids and texts.
|
||||
|
||||
encode(text) == token_ids
|
||||
decode(token_ids) == text
|
||||
"""
|
||||
for input_ids, text in zip(batch_input_ids, batch_text):
|
||||
assert input_ids == tokenizer.encode(text, add_special_tokens=False)
|
||||
assert tokenizer.encode(text, add_special_tokens=False) == input_ids
|
||||
assert tokenizer.decode(input_ids) == text
|
||||
|
||||
|
||||
def _check_single_template(
|
||||
model_id: str, template_name: str, prompt_str: str, answer_str: str, extra_str: str, use_fast: bool
|
||||
) -> List[str]:
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=use_fast, token=HF_TOKEN)
|
||||
content_str = tokenizer.apply_chat_template(MESSAGES, tokenize=False)
|
||||
content_ids = tokenizer.apply_chat_template(MESSAGES, tokenize=True)
|
||||
template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template=template_name))
|
||||
prompt_ids, answer_ids = template.encode_oneturn(tokenizer, MESSAGES)
|
||||
assert content_str == prompt_str + answer_str + extra_str
|
||||
assert content_ids == prompt_ids + answer_ids + tokenizer.encode(extra_str, add_special_tokens=False)
|
||||
_check_tokenization(tokenizer, (prompt_ids, answer_ids), (prompt_str, answer_str))
|
||||
return content_ids
|
||||
|
||||
|
||||
def _check_template(model_id: str, template_name: str, prompt_str: str, answer_str: str, extra_str: str = "") -> None:
|
||||
"""
|
||||
Checks template for both the slow tokenizer and the fast tokenizer.
|
||||
def _check_template(model_id: str, template_name: str, prompt_str: str, answer_str: str, use_fast: bool) -> None:
|
||||
r"""
|
||||
Checks template.
|
||||
|
||||
Args:
|
||||
model_id: the model id on hugging face hub.
|
||||
template_name: the template name.
|
||||
prompt_str: the string corresponding to the prompt part.
|
||||
answer_str: the string corresponding to the answer part.
|
||||
extra_str: the extra string in the jinja template of the original tokenizer.
|
||||
use_fast: whether to use fast tokenizer.
|
||||
"""
|
||||
slow_ids = _check_single_template(model_id, template_name, prompt_str, answer_str, extra_str, use_fast=False)
|
||||
fast_ids = _check_single_template(model_id, template_name, prompt_str, answer_str, extra_str, use_fast=True)
|
||||
assert slow_ids == fast_ids
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=use_fast, token=HF_TOKEN)
|
||||
content_str = tokenizer.apply_chat_template(MESSAGES, tokenize=False)
|
||||
content_ids = tokenizer.apply_chat_template(MESSAGES, tokenize=True)
|
||||
template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template=template_name))
|
||||
prompt_ids, answer_ids = template.encode_oneturn(tokenizer, MESSAGES)
|
||||
assert content_str == prompt_str + answer_str
|
||||
assert content_ids == prompt_ids + answer_ids
|
||||
_check_tokenization(tokenizer, (prompt_ids, answer_ids), (prompt_str, answer_str))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_fast", [True, False])
|
||||
@ -125,19 +122,21 @@ def test_jinja_template(use_fast: bool):
|
||||
|
||||
|
||||
@pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
|
||||
def test_gemma_template():
|
||||
@pytest.mark.parametrize("use_fast", [True, False])
|
||||
def test_gemma_template(use_fast: bool):
|
||||
prompt_str = (
|
||||
"<bos><start_of_turn>user\nHow are you<end_of_turn>\n"
|
||||
"<start_of_turn>model\nI am fine!<end_of_turn>\n"
|
||||
"<start_of_turn>user\n你好<end_of_turn>\n"
|
||||
"<start_of_turn>model\n"
|
||||
)
|
||||
answer_str = "很高兴认识你!"
|
||||
_check_template("google/gemma-2-9b-it", "gemma", prompt_str, answer_str, extra_str="<end_of_turn>\n")
|
||||
answer_str = "很高兴认识你!<end_of_turn>\n"
|
||||
_check_template("google/gemma-2-9b-it", "gemma", prompt_str, answer_str, use_fast)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
|
||||
def test_llama3_template():
|
||||
@pytest.mark.parametrize("use_fast", [True, False])
|
||||
def test_llama3_template(use_fast: bool):
|
||||
prompt_str = (
|
||||
"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nHow are you<|eot_id|>"
|
||||
"<|start_header_id|>assistant<|end_header_id|>\n\nI am fine!<|eot_id|>"
|
||||
@ -145,10 +144,25 @@ def test_llama3_template():
|
||||
"<|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||
)
|
||||
answer_str = "很高兴认识你!<|eot_id|>"
|
||||
_check_template("meta-llama/Meta-Llama-3-8B-Instruct", "llama3", prompt_str, answer_str)
|
||||
_check_template("meta-llama/Meta-Llama-3-8B-Instruct", "llama3", prompt_str, answer_str, use_fast)
|
||||
|
||||
|
||||
def test_qwen_template():
|
||||
@pytest.mark.parametrize(
|
||||
"use_fast", [True, pytest.param(False, marks=pytest.mark.xfail(reason="Phi-4 slow tokenizer is broken."))]
|
||||
)
|
||||
def test_phi4_template(use_fast: bool):
|
||||
prompt_str = (
|
||||
"<|im_start|>user<|im_sep|>How are you<|im_end|>"
|
||||
"<|im_start|>assistant<|im_sep|>I am fine!<|im_end|>"
|
||||
"<|im_start|>user<|im_sep|>你好<|im_end|>"
|
||||
"<|im_start|>assistant<|im_sep|>"
|
||||
)
|
||||
answer_str = "很高兴认识你!<|im_end|>"
|
||||
_check_template("microsoft/phi-4", "phi4", prompt_str, answer_str, use_fast)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_fast", [True, False])
|
||||
def test_qwen_template(use_fast: bool):
|
||||
prompt_str = (
|
||||
"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
|
||||
"<|im_start|>user\nHow are you<|im_end|>\n"
|
||||
@ -156,17 +170,18 @@ def test_qwen_template():
|
||||
"<|im_start|>user\n你好<|im_end|>\n"
|
||||
"<|im_start|>assistant\n"
|
||||
)
|
||||
answer_str = "很高兴认识你!<|im_end|>"
|
||||
_check_template("Qwen/Qwen2-7B-Instruct", "qwen", prompt_str, answer_str, extra_str="\n")
|
||||
answer_str = "很高兴认识你!<|im_end|>\n"
|
||||
_check_template("Qwen/Qwen2-7B-Instruct", "qwen", prompt_str, answer_str, use_fast)
|
||||
|
||||
|
||||
@pytest.mark.xfail(reason="The fast tokenizer of Yi model is corrupted.")
|
||||
def test_yi_template():
|
||||
@pytest.mark.parametrize("use_fast", [True, False])
|
||||
@pytest.mark.xfail(reason="Yi tokenizer is broken.")
|
||||
def test_yi_template(use_fast: bool):
|
||||
prompt_str = (
|
||||
"<|im_start|>user\nHow are you<|im_end|>\n"
|
||||
"<|im_start|>assistant\nI am fine!<|im_end|>\n"
|
||||
"<|im_start|>user\n你好<|im_end|>\n"
|
||||
"<|im_start|>assistant\n"
|
||||
)
|
||||
answer_str = "很高兴认识你!<|im_end|>"
|
||||
_check_template("01-ai/Yi-1.5-6B-Chat", "yi", prompt_str, answer_str)
|
||||
answer_str = "很高兴认识你!<|im_end|>\n"
|
||||
_check_template("01-ai/Yi-1.5-6B-Chat", "yi", prompt_str, answer_str, use_fast)
|
||||
|
Loading…
x
Reference in New Issue
Block a user