Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-08-01 11:12:50 +08:00)
Compare commits (50 commits; the author and date columns were empty in this capture and are omitted)

4ba7de0434, ea8a2d60d0, ae0ef374a3, edd112f35c, 7218d4aa96, 4380b7b35e, 3307ff1d4a, 2aadc90c2d, 2353e16e20, 6812f5e1f5, 2077875622, 678b7d69d2, f00742b078, fdb70c04e0, 95ed6c45cd, cf1087d409, 766884fa5c, 6a8d88826e, 043103e1c9, 5817583630, 62bd2c8047, 1b549e3199, c6290db118, d30cbcdfa5, 62c6943699, 8e7727f4ee, e117e3c2b7, dcd75e7063, 4465e4347e, c5a08291f4, 544b7dc2ed, ac6c93df1f, 0b188ca00c, 0a004904bd, bb7bf51554, 7242caf0ff, ed57b7ba2a, b10333dafb, 6b46c8b689, be27eae175, 31b0787e12, fffa43be86, 8ed085e403, 1221533542, 8a3bddc7fa, 3a119ed5a2, 0d7d0ea972, 0e1fea71d2, ec04d7b89c, cabc9207be
.github/ISSUE_TEMPLATE/config.yml (vendored, 7 changes)

@@ -1 +1,8 @@
+ blank_issues_enabled: false
+ contact_links:
+   - name: 📚 FAQs | 常见问题
+     url: https://github.com/hiyouga/LLaMA-Factory/issues/4614
+     about: Reading in advance is recommended | 建议提前阅读
+   - name: Discussions | 讨论区
+     url: https://github.com/hiyouga/LLaMA-Factory/discussions
+     about: Please ask fine-tuning questions here | 请在这里讨论训练问题
.github/workflows/docker.yml (vendored, 62 changes)

@@ -21,10 +21,17 @@ on:

  jobs:
    build:
+     strategy:
+       fail-fast: false
+       matrix:
+         device:
+           - "cuda"
+           - "npu"
+
      runs-on: ubuntu-latest

      concurrency:
-       group: ${{ github.workflow }}-${{ github.ref }}
+       group: ${{ github.workflow }}-${{ github.ref }}-${{ matrix.device }}
        cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}

      environment:

@@ -33,27 +40,44 @@ jobs:

      steps:
        - name: Free up disk space
-         run: |
-           df -h
-           sudo rm -rf /usr/share/dotnet
-           sudo rm -rf /opt/ghc
-           sudo rm -rf /opt/hostedtoolcache
-           df -h
+         uses: jlumbroso/free-disk-space@54081f138730dfa15788a46383842cd2f914a1be # v1.3.1
+         with:
+           tool-cache: true
+           docker-images: false

        - name: Checkout
          uses: actions/checkout@v4

+       - name: Set up Python
+         uses: actions/setup-python@v5
+         with:
+           python-version: "3.9"
+
+       - name: Get llamafactory version
+         id: version
+         run: |
+           echo "tag=$(python setup.py --version | sed 's/\.dev0//')" >> "$GITHUB_OUTPUT"
+
        - name: Set up Docker Buildx
          uses: docker/setup-buildx-action@v3

        - name: Login to Docker Hub
-         if: github.event_name != 'pull_request'
+         if: ${{ github.event_name != 'pull_request' }}
          uses: docker/login-action@v3
          with:
            username: ${{ vars.DOCKERHUB_USERNAME }}
            password: ${{ secrets.DOCKERHUB_TOKEN }}

-       - name: Build and push Docker image
+       - name: Login to Quay
+         if: ${{ github.event_name != 'pull_request' && matrix.device == 'npu' }}
+         uses: docker/login-action@v3
+         with:
+           registry: quay.io
+           username: ${{ vars.QUAY_ASCEND_USERNAME }}
+           password: ${{ secrets.QUAY_ASCEND_TOKEN }}
+
+       - name: Build and push Docker image (CUDA)
+         if: ${{ matrix.device == 'cuda' }}
          uses: docker/build-push-action@v6
          with:
            context: .

@@ -61,6 +85,24 @@ jobs:

            build-args: |
              EXTRAS=metrics,deepspeed,liger-kernel
            push: ${{ github.event_name != 'pull_request' }}
-           tags: docker.io/hiyouga/llamafactory:latest
+           tags: |
+             docker.io/hiyouga/llamafactory:latest
+             docker.io/hiyouga/llamafactory:${{ steps.version.outputs.tag }}
            cache-from: type=gha
            cache-to: type=gha,mode=max

+       - name: Build and push Docker image (NPU)
+         if: ${{ matrix.device == 'npu' }}
+         uses: docker/build-push-action@v6
+         with:
+           context: .
+           platforms: linux/amd64,linux/arm64
+           file: ./docker/docker-npu/Dockerfile
+           push: ${{ github.event_name != 'pull_request' }}
+           tags: |
+             docker.io/hiyouga/llamafactory:latest-npu-a2
+             docker.io/hiyouga/llamafactory:${{ steps.version.outputs.tag }}-npu-a2
+             quay.io/ascend/llamafactory:latest-npu-a2
+             quay.io/ascend/llamafactory:${{ steps.version.outputs.tag }}-npu-a2
+           cache-from: type=gha
+           cache-to: type=gha,mode=max
.github/workflows/label_issue.yml (vendored, 2 changes)

@@ -19,7 +19,7 @@ jobs:

          ISSUE_TITLE: ${{ github.event.issue.title }}
        run: |
          LABEL=""
-         NPU_KEYWORDS=(npu huawei ascend 华为 昇腾)
+         NPU_KEYWORDS=(npu huawei ascend 华为 昇腾 910)
          ISSUE_TITLE_LOWER=$(echo $ISSUE_TITLE | tr '[:upper:]' '[:lower:]')
          for KEYWORD in ${NPU_KEYWORDS[@]}; do
            if [[ $ISSUE_TITLE_LOWER == *$KEYWORD* ]] && [[ $ISSUE_TITLE_LOWER != *input* ]]; then
.github/workflows/tests.yml (vendored, 7 changes)

@@ -6,14 +6,14 @@ on:

      branches:
        - "main"
      paths:
-       - "**.py"
+       - "**/*.py"
        - "requirements.txt"
        - ".github/workflows/*.yml"
    pull_request:
      branches:
        - "main"
      paths:
-       - "**.py"
+       - "**/*.py"
        - "requirements.txt"
        - ".github/workflows/*.yml"

@@ -34,9 +34,6 @@ jobs:

          transformers:
            - null
          include: # test backward compatibility
-           - python: "3.9"
-             os: "ubuntu-latest"
-             transformers: "4.45.0"
            - python: "3.9"
              os: "ubuntu-latest"
              transformers: "4.49.0"
README.md (37 changes)

@@ -5,7 +5,7 @@

  [](https://github.com/hiyouga/LLaMA-Factory/graphs/contributors)
  [](https://github.com/hiyouga/LLaMA-Factory/actions/workflows/tests.yml)
  [](https://pypi.org/project/llamafactory/)
  [](https://scholar.google.com/scholar?cites=12620864006390196564)
  [](https://hub.docker.com/r/hiyouga/llamafactory/tags)

  [](https://twitter.com/llamafactory_ai)

@@ -51,7 +51,8 @@ https://github.com/user-attachments/assets/3991a3a8-4276-4d30-9cab-4cb0c4b9b99e

  Choose your path:

- - **Documentation**: https://llamafactory.readthedocs.io/en/latest/
+ - **Documentation (WIP)**: https://llamafactory.readthedocs.io/en/latest/
+ - **Documentation (AMD GPU)**: https://rocm.docs.amd.com/projects/ai-developer-hub/en/latest/notebooks/fine_tune/llama_factory_llama3.html
  - **Colab (free)**: https://colab.research.google.com/drive/1eRTPn37ltBbYsISy9Aw2NuI2Aq5CQrD9?usp=sharing
  - **Local machine**: Please refer to [usage](#getting-started)
  - **PAI-DSW (free trial)**: https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory

@@ -98,28 +99,32 @@ Choose your path:

  ### Day-N Support for Fine-Tuning Cutting-Edge Models

- | Support Date | Model Name                                                |
- | ------------ | --------------------------------------------------------- |
- | Day 0        | Qwen3 / Qwen2.5-VL / Gemma 3 / InternLM 3 / MiniCPM-o-2.6 |
- | Day 1        | Llama 3 / GLM-4 / Mistral Small / PaliGemma2 / Llama 4    |
+ | Support Date | Model Name                                                            |
+ | ------------ | --------------------------------------------------------------------- |
+ | Day 0        | Qwen3 / Qwen2.5-VL / Gemma 3 / GLM-4.1V / InternLM 3 / MiniCPM-o-2.6  |
+ | Day 1        | Llama 3 / GLM-4 / Mistral Small / PaliGemma2 / Llama 4                |

  ## Blogs

  - [Fine-tune Qwen2.5-VL for Autonomous Driving using LLaMA-Factory](https://docs.alayanew.com/docs/documents/useGuide/LLaMAFactory/mutiple/?utm_source=LLaMA-Factory) (Chinese)
  - [Fine-tune Llama3.1-70B for Medical Diagnosis using LLaMA-Factory](https://docs.alayanew.com/docs/documents/bestPractice/bigModel/llama70B/) (Chinese)
  - [A One-Stop Code-Free Model Reinforcement Learning and Deployment Platform based on LLaMA-Factory and EasyR1](https://aws.amazon.com/cn/blogs/china/building-llm-model-hub-based-on-llamafactory-and-easyr1/) (Chinese)
  - [How Apoidea Group enhances visual information extraction from banking documents with multimodal models using LLaMA-Factory on Amazon SageMaker HyperPod](https://aws.amazon.com/cn/blogs/machine-learning/how-apoidea-group-enhances-visual-information-extraction-from-banking-documents-with-multimodal-models-using-llama-factory-on-amazon-sagemaker-hyperpod/) (English)
  - [Easy Dataset × LLaMA Factory: Enabling LLMs to Efficiently Learn Domain Knowledge](https://buaa-act.feishu.cn/wiki/GVzlwYcRFiR8OLkHbL6cQpYin7g) (English)

  <details><summary>All Blogs</summary>

  - [Fine-tune Qwen2.5-VL for Autonomous Driving using LLaMA-Factory](https://docs.alayanew.com/docs/documents/useGuide/LLaMAFactory/mutiple/?utm_source=LLaMA-Factory) (Chinese)
  - [LLaMA Factory: Fine-tuning the DeepSeek-R1-Distill-Qwen-7B Model for News Classifier](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory_deepseek_r1_distill_7b) (Chinese)
  - [A One-Stop Code-Free Model Fine-Tuning \& Deployment Platform based on SageMaker and LLaMA-Factory](https://aws.amazon.com/cn/blogs/china/a-one-stop-code-free-model-fine-tuning-deployment-platform-based-on-sagemaker-and-llama-factory/) (Chinese)
  - [LLaMA Factory Multi-Modal Fine-Tuning Practice: Fine-Tuning Qwen2-VL for Personal Tourist Guide](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory_qwen2vl) (Chinese)
- - [LLaMA Factory: Fine-tuning the LLaMA3 Model for Role-Playing](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory) (Chinese)
+ - [LLaMA Factory: Fine-tuning Llama3 for Role-Playing](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory) (Chinese)

  </details>

  ## Changelog

+ [25/07/02] We supported fine-tuning the **[GLM-4.1V-9B-Thinking](https://github.com/THUDM/GLM-4.1V-Thinking)** model. Please install transformers from the **main** branch to use it.
+
  [25/04/28] We supported fine-tuning the **[Qwen3](https://qwenlm.github.io/blog/qwen3/)** model family.

  [25/04/21] We supported the **[Muon](https://github.com/KellerJordan/Muon)** optimizer. See [examples](examples/README.md) for usage. Thanks to [@tianshijing](https://github.com/tianshijing)'s PR.

@@ -261,11 +266,15 @@ Choose your path:

  | [DeepSeek 2.5/3](https://huggingface.co/deepseek-ai)        | 236B/671B                   | deepseek3         |
  | [DeepSeek R1 (Distill)](https://huggingface.co/deepseek-ai) | 1.5B/7B/8B/14B/32B/70B/671B | deepseekr1        |
  | [Falcon](https://huggingface.co/tiiuae)                     | 7B/11B/40B/180B             | falcon            |
- | [Gemma/Gemma 2/CodeGemma](https://huggingface.co/google)    | 2B/7B/9B/27B                | gemma             |
- | [Gemma 3](https://huggingface.co/google)                    | 1B/4B/12B/27B               | gemma3/gemma (1B) |
- | [GLM-4/GLM-4-0414/GLM-Z1](https://huggingface.co/THUDM)     | 9B/32B                      | glm4/glmz1        |
+ | [Falcon-H1](https://huggingface.co/tiiuae)                  | 0.5B/1.5B/3B/7B/34B         | falcon_h1         |
+ | [Gemma/Gemma 2/CodeGemma](https://huggingface.co/google)    | 2B/7B/9B/27B                | gemma/gemma2      |
+ | [Gemma 3/Gemma 3n](https://huggingface.co/google)           | 1B/4B/6B/8B/12B/27B         | gemma3/gemma3n    |
+ | [GLM-4/GLM-4-0414/GLM-Z1](https://huggingface.co/zai-org)   | 9B/32B                      | glm4/glmz1        |
+ | [GLM-4.1V](https://huggingface.co/zai-org)*                 | 9B                          | glm4v             |
+ | [GLM-4.5](https://huggingface.co/zai-org)*                  | 106B/355B                   | glm4_moe          |
  | [GPT-2](https://huggingface.co/openai-community)            | 0.1B/0.4B/0.8B/1.5B         | -                 |
  | [Granite 3.0-3.3](https://huggingface.co/ibm-granite)       | 1B/2B/3B/8B                 | granite3          |
+ | [Granite 4](https://huggingface.co/ibm-granite)             | 7B                          | granite4          |
  | [Hunyuan](https://huggingface.co/tencent/)                  | 7B                          | hunyuan           |
  | [Index](https://huggingface.co/IndexTeam)                   | 1.9B                        | index             |
  | [InternLM 2-3](https://huggingface.co/internlm)             | 7B/8B/20B                   | intern2           |

@@ -443,7 +452,7 @@ huggingface-cli login

  | python       | 3.9    | 3.10   |
  | torch        | 2.0.0  | 2.6.0  |
  | torchvision  | 0.15.0 | 0.21.0 |
- | transformers | 4.45.0 | 4.50.0 |
+ | transformers | 4.49.0 | 4.50.0 |
  | datasets     | 2.16.0 | 3.2.0  |
  | accelerate   | 0.34.0 | 1.2.1  |
  | peft         | 0.14.0 | 0.15.1 |

@@ -485,7 +494,7 @@ cd LLaMA-Factory

  pip install -e ".[torch,metrics]" --no-build-isolation
  ```

- Extra dependencies available: torch, torch-npu, metrics, deepspeed, liger-kernel, bitsandbytes, hqq, eetq, gptq, aqlm, vllm, sglang, galore, apollo, badam, adam-mini, qwen, minicpm_v, modelscope, openmind, swanlab, dev
+ Extra dependencies available: torch, torch-npu, metrics, deepspeed, liger-kernel, bitsandbytes, hqq, eetq, gptq, aqlm, vllm, sglang, galore, apollo, badam, adam-mini, qwen, minicpm_v, openmind, swanlab, dev

  #### Install from Docker Image

@@ -620,7 +629,7 @@ Please refer to [data/README.md](data/README.md) for checking the details about

  > [!NOTE]
  > Please update `data/dataset_info.json` to use your custom dataset.

- You can also use **[Easy Dataset](https://github.com/ConardLi/easy-dataset)** or **[GraphGen](https://github.com/open-sciencelab/GraphGen)** to create synthetic data for fine-tuning.
+ You can also use **[Easy Dataset](https://github.com/ConardLi/easy-dataset)**, **[DataFlow](https://github.com/OpenDCAI/DataFlow)** and **[GraphGen](https://github.com/open-sciencelab/GraphGen)** to create synthetic data for fine-tuning.

  ### Quickstart
README_zh.md (35 changes)

@@ -5,7 +5,7 @@

  [](https://github.com/hiyouga/LLaMA-Factory/graphs/contributors)
  [](https://github.com/hiyouga/LLaMA-Factory/actions/workflows/tests.yml)
  [](https://pypi.org/project/llamafactory/)
  [](https://scholar.google.com/scholar?cites=12620864006390196564)
  [](https://hub.docker.com/r/hiyouga/llamafactory/tags)

  [](https://twitter.com/llamafactory_ai)

@@ -52,6 +52,7 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc

  选择你的打开方式:

  - **入门教程**:https://zhuanlan.zhihu.com/p/695287607
+ - **微调视频教程**:https://www.bilibili.com/video/BV1djgRzxEts/
  - **框架文档**:https://llamafactory.readthedocs.io/zh-cn/latest/
  - **框架文档(昇腾 NPU)**:https://ascend.github.io/docs/sources/llamafactory/
  - **Colab(免费)**:https://colab.research.google.com/drive/1d5KQtbemerlSDSxZIfAaWXhKr30QypiK?usp=sharing

@@ -100,28 +101,32 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc

  ### 最新模型的 Day-N 微调适配

- | 适配时间 | 模型名称                                                   |
- | -------- | --------------------------------------------------------- |
- | Day 0    | Qwen3 / Qwen2.5-VL / Gemma 3 / InternLM 3 / MiniCPM-o-2.6 |
- | Day 1    | Llama 3 / GLM-4 / Mistral Small / PaliGemma2 / Llama 4    |
+ | 适配时间 | 模型名称                                                              |
+ | -------- | --------------------------------------------------------------------- |
+ | Day 0    | Qwen3 / Qwen2.5-VL / Gemma 3 / GLM-4.1V / InternLM 3 / MiniCPM-o-2.6  |
+ | Day 1    | Llama 3 / GLM-4 / Mistral Small / PaliGemma2 / Llama 4                |

  ## 官方博客

  - [使用 LLaMA-Factory 微调 Qwen2.5-VL 实现自动驾驶场景微调](https://docs.alayanew.com/docs/documents/useGuide/LLaMAFactory/mutiple/?utm_source=LLaMA-Factory)(中文)
  - [使用 LLaMA-Factory 微调 Llama3.1-70B 医学诊断模型](https://docs.alayanew.com/docs/documents/bestPractice/bigModel/llama70B/)(中文)
  - [基于 LLaMA-Factory 和 EasyR1 打造一站式无代码大模型强化学习和部署平台 LLM Model Hub](https://aws.amazon.com/cn/blogs/china/building-llm-model-hub-based-on-llamafactory-and-easyr1/)(中文)
  - [通过亚马逊 SageMaker HyperPod 上的 LLaMA-Factory 增强多模态模型银行文档的视觉信息提取](https://aws.amazon.com/cn/blogs/machine-learning/how-apoidea-group-enhances-visual-information-extraction-from-banking-documents-with-multimodal-models-using-llama-factory-on-amazon-sagemaker-hyperpod/)(英文)
  - [Easy Dataset × LLaMA Factory: 让大模型高效学习领域知识](https://buaa-act.feishu.cn/wiki/KY9xwTGs1iqHrRkjXBwcZP9WnL9)(中文)

  <details><summary>全部博客</summary>

  - [使用 LLaMA-Factory 微调 Qwen2.5-VL 实现自动驾驶场景微调](https://docs.alayanew.com/docs/documents/useGuide/LLaMAFactory/mutiple/?utm_source=LLaMA-Factory)(中文)
  - [LLaMA Factory:微调 DeepSeek-R1-Distill-Qwen-7B 模型实现新闻标题分类器](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory_deepseek_r1_distill_7b)(中文)
  - [基于 Amazon SageMaker 和 LLaMA-Factory 打造一站式无代码模型微调部署平台 Model Hub](https://aws.amazon.com/cn/blogs/china/a-one-stop-code-free-model-fine-tuning-deployment-platform-based-on-sagemaker-and-llama-factory/)(中文)
  - [LLaMA Factory 多模态微调实践:微调 Qwen2-VL 构建文旅大模型](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory_qwen2vl)(中文)
- - [LLaMA Factory:微调LLaMA3模型实现角色扮演](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory)(中文)
+ - [LLaMA Factory:微调 Llama3 模型实现角色扮演](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory)(中文)

  </details>

  ## 更新日志

+ [25/07/02] 我们支持了 **[GLM-4.1V-9B-Thinking](https://github.com/THUDM/GLM-4.1V-Thinking)** 模型的微调。请安装 transformers 的 main 分支版本以使用。
+
  [25/04/28] 我们支持了 **[Qwen3](https://qwenlm.github.io/blog/qwen3/)** 系列模型的微调。

  [25/04/21] 我们支持了 **[Muon](https://github.com/KellerJordan/Muon)** 优化器。详细用法请参照 [examples](examples/README_zh.md)。感谢 [@tianshijing](https://github.com/tianshijing) 的 PR。

@@ -263,11 +268,15 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc

  | [DeepSeek 2.5/3](https://huggingface.co/deepseek-ai)        | 236B/671B                   | deepseek3         |
  | [DeepSeek R1 (Distill)](https://huggingface.co/deepseek-ai) | 1.5B/7B/8B/14B/32B/70B/671B | deepseekr1        |
  | [Falcon](https://huggingface.co/tiiuae)                     | 7B/11B/40B/180B             | falcon            |
- | [Gemma/Gemma 2/CodeGemma](https://huggingface.co/google)    | 2B/7B/9B/27B                | gemma             |
- | [Gemma 3](https://huggingface.co/google)                    | 1B/4B/12B/27B               | gemma3/gemma (1B) |
- | [GLM-4/GLM-4-0414/GLM-Z1](https://huggingface.co/THUDM)     | 9B/32B                      | glm4/glmz1        |
+ | [Falcon-H1](https://huggingface.co/tiiuae)                  | 0.5B/1.5B/3B/7B/34B         | falcon_h1         |
+ | [Gemma/Gemma 2/CodeGemma](https://huggingface.co/google)    | 2B/7B/9B/27B                | gemma/gemma2      |
+ | [Gemma 3/Gemma 3n](https://huggingface.co/google)           | 1B/4B/6B/8B/12B/27B         | gemma3/gemma3n    |
+ | [GLM-4/GLM-4-0414/GLM-Z1](https://huggingface.co/zai-org)   | 9B/32B                      | glm4/glmz1        |
+ | [GLM-4.1V](https://huggingface.co/zai-org)*                 | 9B                          | glm4v             |
+ | [GLM-4.5](https://huggingface.co/zai-org)*                  | 106B/355B                   | glm4_moe          |
  | [GPT-2](https://huggingface.co/openai-community)            | 0.1B/0.4B/0.8B/1.5B         | -                 |
  | [Granite 3.0-3.3](https://huggingface.co/ibm-granite)       | 1B/2B/3B/8B                 | granite3          |
+ | [Granite 4](https://huggingface.co/ibm-granite)             | 7B                          | granite4          |
  | [Hunyuan](https://huggingface.co/tencent/)                  | 7B                          | hunyuan           |
  | [Index](https://huggingface.co/IndexTeam)                   | 1.9B                        | index             |
  | [InternLM 2-3](https://huggingface.co/internlm)             | 7B/8B/20B                   | intern2           |

@@ -445,7 +454,7 @@ huggingface-cli login

  | python       | 3.9    | 3.10   |
  | torch        | 2.0.0  | 2.6.0  |
  | torchvision  | 0.15.0 | 0.21.0 |
- | transformers | 4.45.0 | 4.50.0 |
+ | transformers | 4.49.0 | 4.50.0 |
  | datasets     | 2.16.0 | 3.2.0  |
  | accelerate   | 0.34.0 | 1.2.1  |
  | peft         | 0.14.0 | 0.15.1 |

@@ -487,7 +496,7 @@ cd LLaMA-Factory

  pip install -e ".[torch,metrics]" --no-build-isolation
  ```

- 可选的额外依赖项:torch、torch-npu、metrics、deepspeed、liger-kernel、bitsandbytes、hqq、eetq、gptq、aqlm、vllm、sglang、galore、apollo、badam、adam-mini、qwen、minicpm_v、modelscope、openmind、swanlab、dev
+ 可选的额外依赖项:torch、torch-npu、metrics、deepspeed、liger-kernel、bitsandbytes、hqq、eetq、gptq、aqlm、vllm、sglang、galore、apollo、badam、adam-mini、qwen、minicpm_v、openmind、swanlab、dev

  #### 从镜像安装

@@ -622,7 +631,7 @@ pip install .

  > [!NOTE]
  > 使用自定义数据集时,请更新 `data/dataset_info.json` 文件。

- 您也可以使用 **[Easy Dataset](https://github.com/ConardLi/easy-dataset)** 或 **[GraphGen](https://github.com/open-sciencelab/GraphGen)** 构建用于微调的合成数据。
+ 您也可以使用 **[Easy Dataset](https://github.com/ConardLi/easy-dataset)**、**[DataFlow](https://github.com/OpenDCAI/DataFlow)** 和 **[GraphGen](https://github.com/open-sciencelab/GraphGen)** 构建用于微调的合成数据。

  ### 快速开始
Two binary image assets changed (contents not shown): one resized from 168 KiB to 166 KiB, the other from 168 KiB to 171 KiB.
data/README.md

@@ -173,7 +173,7 @@ An additional column `audios` is required. Please refer to the [sharegpt](#sharegpt) format.

  Compared to the alpaca format, the sharegpt format allows the datasets to have **more roles**, such as human, gpt, observation and function. They are presented in a list of objects in the `conversations` column.

- Note that the human and observation should appear in odd positions, while gpt and function should appear in even positions.
+ Note that the human and observation should appear in odd positions, while gpt and function should appear in even positions. The gpt and function will be learned by the model.

  ```json
  [
    ...
  ]
  ```
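A minimal sketch of the ordering rule described above (illustrative only; `validate_sharegpt_roles` is a hypothetical helper, not part of this changeset):

```python
# Odd positions (1st, 3rd, ...) hold user-side turns; even positions
# (2nd, 4th, ...) hold model-side turns, which are used as training labels.
ODD_ROLES = {"human", "observation"}
EVEN_ROLES = {"gpt", "function"}

def validate_sharegpt_roles(conversations: list[dict[str, str]]) -> bool:
    for i, turn in enumerate(conversations):  # i is 0-based
        expected = ODD_ROLES if i % 2 == 0 else EVEN_ROLES
        if turn["from"] not in expected:
            return False
    return True

# Example: a well-formed two-turn conversation.
assert validate_sharegpt_roles([
    {"from": "human", "value": "What is LLaMA-Factory?"},
    {"from": "gpt", "value": "A unified fine-tuning framework."},
])
```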
data/README_zh.md

@@ -172,7 +172,7 @@ KTO 数据集需要提供额外的 `kto_tag` 列。详情请参阅 [sharegpt](#sharegpt) 格式。

  相比 alpaca 格式的数据集,sharegpt 格式支持**更多的角色种类**,例如 human、gpt、observation、function 等等。它们构成一个对象列表呈现在 `conversations` 列中。

- 注意其中 human 和 observation 必须出现在奇数位置,gpt 和 function 必须出现在偶数位置。
+ 注意其中 human 和 observation 必须出现在奇数位置,gpt 和 function 必须出现在偶数位置。默认所有的 gpt 和 function 会被用于学习。

  ```json
  [
    ...
  ]
  ```
Alpaca-format JSON dataset (one record removed)

@@ -4154,11 +4154,6 @@

      "input": "The beauty of friendship",
      "output": "In life we all embark on journeys grand,\nAnd wander through the trials of our days;\nAt times we'll come across a barren land,\nWhere hardships test and challenge in new ways.\n\nBut through the toil, we're never all alone,\nFor friendship shines as brilliant as the sun;\nWith bonds that strengthen, deepen and have grown,\nA beauty that cannot be rivaled, none.\n\nA friend who's there with hands to lift us high,\nThrough laughter and the joy, through tears and pain,\nA beacon when the road ahead's unclear,\nA treasure more than riches one could gain.\n\nOh friendship, pure and ever shining bright,\nA beauty that brings everlasting light."
    },
-   {
-     "instruction": "The news and other sources have been claiming that Democracy is losing ground around the world. Give an example of a recent event to back this claim up.",
-     "input": "",
-     "output": "One recent example that supports this claim is the events in Hong Kong. In 2019 and 2020, the Chinese government imposed a national security law to crack down on pro-democracy protests in Hong Kong, effectively curtailing the freedoms and autonomy that the region once enjoyed. This move by Beijing to assert greater control over Hong Kong has led to concerns by human rights advocates and international leaders that democracy is being undermined in the territory. There is fear that Hong Kong’s outspoken media, independent judiciary, and civil liberties look increasingly vulnerable as Beijing tightens its grip on the city. Additionally, there have been arrests of multiple leading pro-Democracy figures such as Joshua Wong, Agnes Chow and Jimmy Lai in addition to many others who have become political prisoners."
-   },
    {
      "instruction": "Redact the following sentence to remove all personal information: \"John Smith from Birmingham works in a factory.\"",
      "input": "",
docker/docker-cuda/README.md (new file, 111 lines)

# Docker Setup for NVIDIA GPUs

This directory contains Docker configuration files for running LLaMA Factory with NVIDIA GPU support.

## Prerequisites

### Linux-specific Requirements

Before running the Docker container with GPU support, you need to install the following packages:

1. **Docker**: The container runtime

   ```bash
   # Ubuntu/Debian
   sudo apt-get update
   sudo apt-get install docker.io

   # Or install Docker Engine from the official repository:
   # https://docs.docker.com/engine/install/
   ```

2. **Docker Compose** (if using the docker-compose method):

   ```bash
   # Ubuntu/Debian
   sudo apt-get install docker-compose

   # Or install the latest version:
   # https://docs.docker.com/compose/install/
   ```

3. **NVIDIA Container Toolkit** (required for GPU support):

   ```bash
   # Add the NVIDIA GPG key and repository
   distribution=$(. /etc/os-release;echo $ID$VERSION_ID)
   curl -s -L https://nvidia.github.io/nvidia-docker/gpgkey | sudo apt-key add -
   curl -s -L https://nvidia.github.io/nvidia-docker/$distribution/nvidia-docker.list | sudo tee /etc/apt/sources.list.d/nvidia-docker.list

   # Install nvidia-container-toolkit
   sudo apt-get update
   sudo apt-get install -y nvidia-container-toolkit

   # Restart Docker to apply changes
   sudo systemctl restart docker
   ```

**Note**: Without `nvidia-container-toolkit`, the Docker container will not be able to access your NVIDIA GPU.

### Verify GPU Access

After installation, verify that Docker can access your GPU:

```bash
sudo docker run --rm --gpus all nvidia/cuda:12.4.0-base-ubuntu22.04 nvidia-smi
```

If successful, you should see your GPU information displayed.

## Usage

### Using Docker Compose (Recommended)

```bash
cd docker/docker-cuda/
docker compose up -d
docker compose exec llamafactory bash
```

### Using Docker Run

```bash
# Build the image
docker build -f ./docker/docker-cuda/Dockerfile \
    --build-arg PIP_INDEX=https://pypi.org/simple \
    --build-arg EXTRAS=metrics \
    -t llamafactory:latest .

# Run the container
docker run -dit --ipc=host --gpus=all \
    -p 7860:7860 \
    -p 8000:8000 \
    --name llamafactory \
    llamafactory:latest

# Enter the container
docker exec -it llamafactory bash
```

## Troubleshooting

### GPU Not Detected

If your GPU is not detected inside the container:

1. Ensure `nvidia-container-toolkit` is installed
2. Check that the Docker daemon has been restarted after installation
3. Verify your NVIDIA drivers are properly installed: `nvidia-smi`
4. Check Docker GPU support: `docker run --rm --gpus all ubuntu nvidia-smi`

### Permission Denied

If you get permission errors, ensure your user is in the docker group:

```bash
sudo usermod -aG docker $USER
# Log out and back in for changes to take effect
```

## Additional Notes

- The default image is built on Ubuntu 22.04 (x86_64), CUDA 12.4, Python 3.11, PyTorch 2.6.0, and Flash-attn 2.7.4
- For different CUDA versions, you may need to adjust the base image in the Dockerfile
- Make sure your NVIDIA driver version is compatible with the CUDA version used in the Docker image
docker/docker-npu/Dockerfile

@@ -1,11 +1,12 @@

  # https://hub.docker.com/r/ascendai/cann/tags
- ARG BASE_IMAGE=ascendai/cann:8.0.0-910b-ubuntu22.04-py3.11
+ ARG BASE_IMAGE=ascendai/cann:8.1.rc1-910b-ubuntu22.04-py3.11
  FROM ${BASE_IMAGE}

  # Installation arguments
  ARG PIP_INDEX=https://pypi.org/simple
  ARG EXTRAS=torch-npu,metrics
  ARG HTTP_PROXY=""
+ ARG PYTORCH_INDEX=https://download.pytorch.org/whl/cpu

  # Define environments
  ENV MAX_JOBS=16

@@ -28,6 +29,10 @@ RUN pip config set global.index-url "${PIP_INDEX}" && \

      pip config set global.extra-index-url "${PIP_INDEX}" && \
      pip install --no-cache-dir --upgrade pip packaging wheel setuptools

+ # Install torch-npu
+ RUN pip uninstall -y torch torchvision torchaudio && \
+     pip install --no-cache-dir "torch-npu==2.5.1" "torchvision==0.20.1" --index-url "${PYTORCH_INDEX}"
+
  # Install the requirements
  COPY requirements.txt /app
  RUN pip install --no-cache-dir -r requirements.txt
requirements.txt

@@ -1,27 +1,36 @@

- transformers>=4.45.0,<=4.52.4,!=4.46.*,!=4.47.*,!=4.48.0,!=4.52.0; sys_platform != 'darwin'
- transformers>=4.45.0,<=4.51.3,!=4.46.*,!=4.47.*,!=4.48.0,!=4.52.0; sys_platform == 'darwin'
+ # core deps
+ transformers>=4.49.0,<=4.52.4,!=4.52.0; sys_platform != 'darwin'
+ transformers>=4.49.0,<=4.51.3,!=4.52.0; sys_platform == 'darwin'
  datasets>=2.16.0,<=3.6.0
- accelerate>=0.34.0,<=1.7.0
+ accelerate>=1.3.0,<=1.7.0
  peft>=0.14.0,<=0.15.2
  trl>=0.8.6,<=0.9.6
  tokenizers>=0.19.0,<=0.21.1
+ # gui
  gradio>=4.38.0,<=5.31.0
- scipy
+ matplotlib>=3.7.0
+ tyro<0.9.0
+ # ops
  einops
+ numpy<2.0.0
+ pandas>=2.0.0
+ scipy
+ # model and tokenizer
  sentencepiece
  tiktoken
- protobuf
- uvicorn
- fastapi
- sse-starlette
- matplotlib>=3.7.0
+ modelscope>=1.14.0
+ hf-transfer
+ # python
  fire
  omegaconf
  packaging
+ protobuf
  pyyaml
- numpy<2.0.0
  pydantic<=2.10.6
- pandas>=2.0.0
+ # api
+ uvicorn
+ fastapi
+ sse-starlette
+ # media
  av
  librosa
- tyro<0.9.0
setup.py (5 changes)

@@ -43,7 +43,7 @@ def get_console_scripts() -> list[str]:

  extra_require = {
      "torch": ["torch>=2.0.0", "torchvision>=0.15.0"],
-     "torch-npu": ["torch==2.4.0", "torch-npu==2.4.0.post2", "decorator"],
+     "torch-npu": ["torch-npu==2.5.1", "torchvision==0.20.1", "decorator"],
      "metrics": ["nltk", "jieba", "rouge-chinese"],
      "deepspeed": ["deepspeed>=0.10.0,<=0.16.9"],
      "liger-kernel": ["liger-kernel>=0.5.5"],

@@ -52,7 +52,7 @@ extra_require = {

      "eetq": ["eetq"],
      "gptq": ["optimum>=1.24.0", "gptqmodel>=2.0.0"],
      "aqlm": ["aqlm[gpu]>=1.1.0"],
-     "vllm": ["vllm>=0.4.3,<=0.8.6"],
+     "vllm": ["vllm>=0.4.3,<=0.10.0"],
      "sglang": ["sglang[srt]>=0.4.5", "transformers==4.51.1"],
      "galore": ["galore-torch"],
      "apollo": ["apollo-torch"],

@@ -68,7 +68,6 @@ extra_require = {

          "referencing",
          "jsonschema_specifications",
      ],
-     "modelscope": ["modelscope"],
      "openmind": ["openmind"],
      "swanlab": ["swanlab"],
      "dev": ["pre-commit", "ruff", "pytest", "build"],
src/llamafactory/api/chat.py

@@ -132,7 +132,7 @@ def _process_request(

      if re.match(r"^data:video\/(mp4|mkv|avi|mov);base64,(.+)$", video_url):  # base64 video
          video_stream = io.BytesIO(base64.b64decode(video_url.split(",", maxsplit=1)[1]))
      elif os.path.isfile(video_url):  # local file
-         video_stream = open(video_url, "rb")
+         video_stream = video_url
      else:  # web uri
          video_stream = requests.get(video_url, stream=True).raw

@@ -143,7 +143,7 @@ def _process_request(

      if re.match(r"^data:audio\/(mpeg|mp3|wav|ogg);base64,(.+)$", audio_url):  # base64 audio
          audio_stream = io.BytesIO(base64.b64decode(audio_url.split(",", maxsplit=1)[1]))
      elif os.path.isfile(audio_url):  # local file
-         audio_stream = open(audio_url, "rb")
+         audio_stream = audio_url
      else:  # web uri
          audio_stream = requests.get(audio_url, stream=True).raw
src/llamafactory/data/collator.py

@@ -210,10 +210,11 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):

      if (
          self.model is not None
-         and getattr(self.model.config, "model_type", None) in ["qwen2_vl", "qwen2_5_vl", "qwen2_5_omni_thinker"]
+         and getattr(self.model.config, "model_type", None)
+         in ["glm4v", "Keye", "qwen2_vl", "qwen2_5_vl", "qwen2_5_omni_thinker"]
          and ("position_ids" not in features or features["position_ids"].dim() != 3)
      ):
-         raise ValueError("Qwen2-VL/Qwen2.5-Omni model requires 3D position ids for mrope.")
+         raise ValueError(f"{self.model.config.model_type} requires 3D position ids for mrope.")

      if "cross_attention_mask" in mm_inputs:  # for mllama inputs when pad_to_multiple_of is enabled
          cross_attention_mask = mm_inputs.pop("cross_attention_mask")
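For context, a minimal sketch of the tensor shape the collator's check above is guarding (illustrative only; it assumes the usual mrope layout of one index plane each for the temporal, height, and width axes, as used by Qwen2-VL-style models):

```python
import torch

# mrope position ids stack three index planes in the first dimension,
# giving a (3, batch, seq_len) tensor, so dim() == 3 as checked above.
batch_size, seq_len = 2, 128
position_ids = torch.zeros(3, batch_size, seq_len, dtype=torch.long)
assert position_ids.dim() == 3  # a plain (batch, seq_len) tensor would raise
```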
src/llamafactory/data/loader.py

@@ -91,7 +91,7 @@ def _load_single_dataset(

          raise NotImplementedError(f"Unknown load type: {dataset_attr.load_from}.")

      if dataset_attr.load_from == "ms_hub":
-         check_version("modelscope>=1.11.0", mandatory=True)
+         check_version("modelscope>=1.14.0", mandatory=True)
          from modelscope import MsDataset  # type: ignore
          from modelscope.utils.config_ds import MS_DATASETS_CACHE  # type: ignore
src/llamafactory/data/mm_plugin.py

@@ -27,6 +27,10 @@ from typing import TYPE_CHECKING, BinaryIO, Literal, Optional, TypedDict, Union

  import numpy as np
  import torch
  from transformers.image_utils import get_image_size, is_valid_image, to_numpy_array
+ from transformers.models.mllama.processing_mllama import (
+     convert_sparse_cross_attention_mask_to_dense,
+     get_cross_attention_token_mask,
+ )
  from typing_extensions import override

  from ..extras.constants import AUDIO_PLACEHOLDER, IGNORE_INDEX, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER

@@ -51,17 +55,10 @@ if is_pyav_available():

      import av


- if is_transformers_version_greater_than("4.45.0"):
-     from transformers.models.mllama.processing_mllama import (
-         convert_sparse_cross_attention_mask_to_dense,
-         get_cross_attention_token_mask,
-     )
-
-
  if is_transformers_version_greater_than("4.52.0"):
      from transformers.image_utils import make_flat_list_of_images
      from transformers.video_utils import make_batched_videos
- elif is_transformers_version_greater_than("4.49.0"):
+ else:
      from transformers.image_utils import make_batched_videos, make_flat_list_of_images
@@ -298,11 +295,8 @@ class MMPluginMixin:

          r"""Regularizes audios to avoid error. Including reading and resampling."""
          results, sampling_rates = [], []
          for audio in audios:
-             if isinstance(audio, (str, BinaryIO)):
-                 audio, sampling_rate = librosa.load(audio, sr=sampling_rate)
-
-             if not isinstance(audio, np.ndarray):
-                 raise ValueError(f"Expect input is a list of audios, but got {type(audio)}.")
+             audio, sampling_rate = librosa.load(audio, sr=sampling_rate)

              results.append(audio)
              sampling_rates.append(sampling_rate)

@@ -391,7 +385,7 @@ class MMPluginMixin:

              return_tensors="pt",
          )
      )
-     mm_inputs["feature_attention_mask"] = mm_inputs.pop("attention_mask")  # prevent conflicts
+     mm_inputs["feature_attention_mask"] = mm_inputs.pop("attention_mask", None)  # prevent conflicts

      return mm_inputs
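The simplified `_regularize_audios` above leans on `librosa.load` accepting both file paths and file-like objects and resampling in the same call; a minimal sketch of that behavior (the file name is made up):

```python
import librosa

# Reads the audio (path or file-like object) and resamples it to 16 kHz;
# returns a float numpy array plus the sampling rate actually used.
audio, sampling_rate = librosa.load("sample.wav", sr=16000)
```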
@@ -512,6 +506,39 @@ class Gemma3Plugin(BasePlugin):

          return mm_inputs


+ class Gemma3nPlugin(Gemma3Plugin):
+     @override
+     def process_messages(
+         self,
+         messages: list[dict[str, str]],
+         images: list["ImageInput"],
+         videos: list["VideoInput"],
+         audios: list["AudioInput"],
+         processor: Optional["MMProcessor"],
+     ) -> list[dict[str, str]]:
+         self._validate_input(processor, images, videos, audios)
+         self._validate_messages(messages, images, videos, audios)
+         messages = deepcopy(messages)
+         boi_token: str = getattr(processor, "boi_token")
+         boa_token: str = getattr(processor, "boa_token")
+         full_image_sequence: str = getattr(processor, "full_image_sequence")
+         full_audio_sequence: str = getattr(processor, "full_audio_sequence")
+         image_str = full_image_sequence if self.expand_mm_tokens else boi_token
+         audio_str = full_audio_sequence if self.expand_mm_tokens else boa_token
+
+         for message in messages:
+             content = message["content"]
+             while IMAGE_PLACEHOLDER in content:
+                 content = content.replace(IMAGE_PLACEHOLDER, image_str, 1)
+
+             while AUDIO_PLACEHOLDER in content:
+                 content = content.replace(AUDIO_PLACEHOLDER, audio_str, 1)
+
+             message["content"] = content
+
+         return messages
+
+
  @dataclass
  class InternVLPlugin(BasePlugin):
      @override
@@ -1501,6 +1528,133 @@ class Qwen2VLPlugin(BasePlugin):

          return messages


+ @dataclass
+ class GLM4VPlugin(Qwen2VLPlugin):
+     @override
+     def _get_mm_inputs(
+         self,
+         images: list["ImageInput"],
+         videos: list["VideoInput"],
+         audios: list["AudioInput"],
+         processor: "MMProcessor",
+     ) -> dict[str, "torch.Tensor"]:
+         image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
+         video_processor: BaseImageProcessor = getattr(processor, "video_processor", None)
+         mm_inputs = {}
+         if len(images) != 0:
+             images = self._regularize_images(
+                 images,
+                 image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768),
+                 image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
+             )["images"]
+             mm_inputs.update(image_processor(images, return_tensors="pt"))
+
+         if len(videos) != 0:
+             video_data = self._regularize_videos(
+                 videos,
+                 image_max_pixels=getattr(processor, "video_max_pixels", 256 * 256),
+                 image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16),
+                 video_fps=getattr(processor, "video_fps", 2.0),
+                 video_maxlen=getattr(processor, "video_maxlen", 128),
+             )
+             # prepare video metadata
+             video_metadata = [
+                 {"fps": 2, "duration": len(video), "total_frames": len(video)} for video in video_data["videos"]
+             ]
+             mm_inputs.update(video_processor(images=None, videos=video_data["videos"], video_metadata=video_metadata))
+
+         return mm_inputs
+
+     @override
+     def process_messages(
+         self,
+         messages: list[dict[str, str]],
+         images: list["ImageInput"],
+         videos: list["VideoInput"],
+         audios: list["AudioInput"],
+         processor: Optional["MMProcessor"],
+     ) -> list[dict[str, str]]:
+         self._validate_input(processor, images, videos, audios)
+         self._validate_messages(messages, images, videos, audios)
+         num_image_tokens, num_video_tokens = 0, 0
+         messages = deepcopy(messages)
+         image_processor: BaseImageProcessor = getattr(processor, "image_processor")
+
+         merge_length: int = getattr(image_processor, "merge_size") ** 2
+         if self.expand_mm_tokens:
+             mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
+             image_grid_thw = mm_inputs.get("image_grid_thw", [])
+             video_grid_thw = mm_inputs.get("video_grid_thw", [])
+             num_frames = video_grid_thw[0][0] if len(video_grid_thw) > 0 else 0  # hard code for now
+             timestamps = mm_inputs.get("timestamps", [])
+
+             if hasattr(timestamps, "tolist"):
+                 timestamps = timestamps.tolist()
+
+             if not timestamps:
+                 timestamps_list = []
+             elif isinstance(timestamps[0], list):
+                 timestamps_list = timestamps[0]
+             else:
+                 timestamps_list = timestamps
+
+             unique_timestamps = timestamps_list.copy()
+             selected_timestamps = unique_timestamps[:num_frames]
+             while len(selected_timestamps) < num_frames:
+                 selected_timestamps.append(selected_timestamps[-1] if selected_timestamps else 0)
+
+         else:
+             image_grid_thw = [None] * len(images)
+             video_grid_thw = [None] * len(videos)
+             num_frames = 0
+             selected_timestamps = [0]
+
+         for message in messages:
+             content = message["content"]
+             while IMAGE_PLACEHOLDER in content:
+                 image_seqlen = image_grid_thw[num_image_tokens].prod() // merge_length if self.expand_mm_tokens else 1
+                 content = content.replace(
+                     IMAGE_PLACEHOLDER, f"<|begin_of_image|>{self.image_token * image_seqlen}<|end_of_image|>", 1
+                 )
+                 num_image_tokens += 1
+
+             while VIDEO_PLACEHOLDER in content:
+                 video_structure = ""
+                 for frame_index in range(num_frames):
+                     video_seqlen = (
+                         video_grid_thw[num_video_tokens][1:].prod() // merge_length if self.expand_mm_tokens else 1
+                     )
+                     timestamp_sec = selected_timestamps[frame_index]
+                     frame_structure = (
+                         f"<|begin_of_image|>{self.image_token * video_seqlen}<|end_of_image|>{timestamp_sec}"
+                     )
+                     video_structure += frame_structure
+
+                 content = content.replace(VIDEO_PLACEHOLDER, f"<|begin_of_video|>{video_structure}<|end_of_video|>", 1)
+                 num_video_tokens += 1
+
+             message["content"] = content
+
+         return messages
+
+     @override
+     def get_mm_inputs(
+         self,
+         images: list["ImageInput"],
+         videos: list["VideoInput"],
+         audios: list["AudioInput"],
+         imglens: list[int],
+         vidlens: list[int],
+         audlens: list[int],
+         batch_ids: list[list[int]],
+         processor: Optional["ProcessorMixin"],
+     ) -> dict[str, Union[list[int], "torch.Tensor"]]:
+         self._validate_input(processor, images, videos, audios)
+         mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
+         mm_inputs.pop("timestamps", None)
+         return mm_inputs
+
+
  class Qwen2OmniPlugin(Qwen2VLPlugin):
      @override
      def _get_mm_inputs(

@@ -1718,6 +1872,8 @@ class VideoLlavaPlugin(BasePlugin):

  PLUGINS = {
      "base": BasePlugin,
      "gemma3": Gemma3Plugin,
+     "glm4v": GLM4VPlugin,
+     "gemma3n": Gemma3nPlugin,
      "intern_vl": InternVLPlugin,
      "kimi_vl": KimiVLPlugin,
      "llama4": Llama4Plugin,
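For intuition about the `image_seqlen` arithmetic in `GLM4VPlugin.process_messages` above, a worked example with made-up grid values (the real grid comes from the image processor):

```python
# image_grid_thw holds (t, h, w) patch counts per image; merging folds
# merge_size ** 2 patches into one visual token, matching the code above.
grid_t, grid_h, grid_w = 1, 24, 24
merge_length = 2 ** 2  # merge_size = 2
image_seqlen = (grid_t * grid_h * grid_w) // merge_length
assert image_seqlen == 144  # tokens inserted between <|begin_of_image|> and <|end_of_image|>
```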
src/llamafactory/data/template.py

@@ -916,6 +916,18 @@ register_template(

  )


+ # copied from chatml template
+ register_template(
+     name="falcon_h1",
+     format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+     format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
+     format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
+     format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+     stop_words=["<|im_end|>", "<|end_of_text|>"],
+ )
+
+
  register_template(
      name="fewshot",
      format_assistant=StringFormatter(slots=["{{content}}\n\n"]),

@@ -939,6 +951,22 @@ register_template(

  )


+ # copied from gemma template
+ register_template(
+     name="gemma2",
+     format_user=StringFormatter(slots=["<start_of_turn>user\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]),
+     format_assistant=StringFormatter(slots=["{{content}}<end_of_turn>\n"]),
+     format_system=StringFormatter(slots=["{{content}}\n\n"]),
+     format_observation=StringFormatter(
+         slots=["<start_of_turn>tool\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]
+     ),
+     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+     stop_words=["<eos>", "<end_of_turn>"],
+     efficient_eos=True,
+     template_class=Llama2Template,
+ )
+
+
  # copied from gemma template
  register_template(
      name="gemma3",

@@ -956,6 +984,22 @@ register_template(

  )


+ register_template(
+     name="gemma3n",
+     format_user=StringFormatter(slots=["<start_of_turn>user\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]),
+     format_assistant=StringFormatter(slots=["{{content}}<end_of_turn>\n"]),
+     format_system=StringFormatter(slots=["{{content}}\n\n"]),
+     format_observation=StringFormatter(
+         slots=["<start_of_turn>tool\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]
+     ),
+     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+     stop_words=["<end_of_turn>"],
+     replace_eos=True,
+     mm_plugin=get_mm_plugin("gemma3n", image_token="<image_soft_token>", audio_token="<audio_soft_token>"),
+     template_class=Llama2Template,
+ )
+
+
  register_template(
      name="glm4",
      format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]),

@@ -970,6 +1014,38 @@ register_template(

  )


+ register_template(
+     name="glm4_moe",
+     format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]),
+     format_assistant=StringFormatter(slots=["\n{{content}}"]),
+     format_system=StringFormatter(slots=["<|system|>\n{{content}}"]),
+     format_function=FunctionFormatter(slots=["{{content}}"], tool_format="glm4_moe"),
+     format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>"]),
+     format_tools=ToolFormatter(tool_format="glm4_moe"),
+     format_prefix=EmptyFormatter(slots=["[gMASK]<sop>"]),
+     stop_words=["<|user|>", "<|observation|>"],
+     efficient_eos=True,
+     template_class=ReasoningTemplate,
+ )
+
+
+ # copied from glm4 template
+ register_template(
+     name="glm4v",
+     format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]),
+     format_assistant=StringFormatter(slots=["\n{{content}}"]),
+     format_system=StringFormatter(slots=["<|system|>\n{{content}}"]),
+     format_function=FunctionFormatter(slots=["{{content}}"], tool_format="glm4"),
+     format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>"]),
+     format_tools=ToolFormatter(tool_format="glm4"),
+     format_prefix=EmptyFormatter(slots=["[gMASK]<sop>"]),
+     stop_words=["<|user|>", "<|observation|>", "</answer>"],
+     efficient_eos=True,
+     mm_plugin=get_mm_plugin(name="glm4v", image_token="<|image|>", video_token="<|video|>"),
+     template_class=ReasoningTemplate,
+ )
+
+
  # copied from glm4 template
  register_template(
      name="glmz1",

@@ -1010,6 +1086,25 @@ register_template(

  )


+ register_template(
+     name="granite4",
+     format_user=StringFormatter(
+         slots=[
+             "<|start_of_role|>user<|end_of_role|>{{content}}<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|>"
+         ]
+     ),
+     format_assistant=StringFormatter(slots=["{{content}}<|end_of_text|>\n"]),
+     format_system=StringFormatter(slots=["<|start_of_role|>system<|end_of_role|>{{content}}<|end_of_text|>\n"]),
+     format_function=FunctionFormatter(slots=["{{content}}<|end_of_text|>\n"], tool_format="default"),
+     format_observation=StringFormatter(
+         slots=["<|start_of_role|>tool<|end_of_role|>{{content}}<|end_of_text|>\n<|start_of_role|>assistant\n"]
+     ),
+     format_tools=ToolFormatter(tool_format="default"),
+     stop_words=["<|end_of_text|>"],
+     default_system=("You are Granite, developed by IBM. You are a helpful AI assistant."),
+ )
+
+
  register_template(
      name="index",
      format_user=StringFormatter(slots=["reserved_0{{content}}reserved_1"]),

@@ -1076,6 +1171,24 @@ register_template(

  )


+ # copied from qwen template
+ register_template(
+     name="keye_vl",
+     format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+     format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
+     format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
+     format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen"),
+     format_observation=StringFormatter(
+         slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
+     ),
+     format_tools=ToolFormatter(tool_format="qwen"),
+     stop_words=["<|im_end|>"],
+     replace_eos=True,
+     mm_plugin=get_mm_plugin(name="qwen2_vl", image_token="<|image_pad|>", video_token="<|video_pad|>"),
+     template_class=ReasoningTemplate,
+ )
+
+
  register_template(
      name="kimi_vl",
      format_user=StringFormatter(
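To see what these slot definitions produce, here is a hand-substituted sketch of one `falcon_h1` (ChatML-style) exchange; the message text is made up and this is not part of the changeset:

```python
# Substituting {{content}} into the falcon_h1 slots registered above.
system = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
user = "<|im_start|>user\nHello!<|im_end|>\n<|im_start|>assistant\n"
assistant = "Hi there!<|im_end|>\n"
rendered = system + user + assistant  # the bos_token from format_prefix precedes this
```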
src/llamafactory/data/tool_utils.py

@@ -38,8 +38,20 @@ DEFAULT_TOOL_PROMPT = (

  )

  GLM4_TOOL_PROMPT = (
-     "你是一个名为 ChatGLM 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,"
-     "你的任务是针对用户的问题和要求提供适当的答复和支持。# 可用工具{tool_text}"
+     "你是一个名为 ChatGLM 的人工智能助手。你是基于智谱 AI 公司训练的语言模型 GLM-4 模型开发的,"
+     "你的任务是针对用户的问题和要求提供适当的答复和支持。\n\n# 可用工具{tool_text}"
  )

+ GLM4_MOE_TOOL_PROMPT = (
+     "\n\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\n"
+     "You are provided with function signatures within <tools></tools> XML tags:\n<tools>{tool_text}"
+     "\n</tools>\n\nFor each function call, output the function name and arguments within the following XML format:"
+     "\n<tool_call>{{function-name}}"
+     "\n<arg_key>{{arg-key-1}}</arg_key>"
+     "\n<arg_value>{{arg-value-1}}</arg_value>"
+     "\n<arg_key>{{arg-key-2}}</arg_key>"
+     "\n<arg_value>{{arg-value-2}}</arg_value>"
+     "\n...\n</tool_call>\n"
+ )
+
  LLAMA3_TOOL_PROMPT = (

@@ -303,12 +315,45 @@ class QwenToolUtils(ToolUtils):

      return results


+ class GLM4MOEToolUtils(QwenToolUtils):
+     r"""GLM-4-MOE tool using template."""
+
+     @override
+     @staticmethod
+     def tool_formatter(tools: list[dict[str, Any]]) -> str:
+         tool_text = ""
+         for tool in tools:
+             wrapped_tool = tool if tool.get("type") == "function" else {"type": "function", "function": tool}
+             tool_text += "\n" + json.dumps(wrapped_tool, ensure_ascii=False)
+
+         return GLM4_MOE_TOOL_PROMPT.format(tool_text=tool_text)
+
+     @override
+     @staticmethod
+     def function_formatter(functions: list["FunctionCall"]) -> str:
+         function_json = [
+             {"func_name": name, "func_key_values": json.loads(arguments)} for name, arguments in functions
+         ]
+         function_texts = []
+         for func in function_json:
+             prompt = "\n<tool_call>" + func["func_name"]
+             for key, value in func["func_key_values"].items():
+                 prompt += "\n<arg_key>" + key + "</arg_key>"
+                 if not isinstance(value, str):
+                     value = json.dumps(value, ensure_ascii=False)
+                 prompt += "\n<arg_value>" + value + "</arg_value>"
+             function_texts.append(prompt)
+
+         return "\n".join(function_texts)
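To make the wire format concrete, a standalone re-implementation of `function_formatter` above applied to a sample call (the tool name and arguments are made up; this sketch is not part of the changeset):

```python
import json

def format_glm4_moe_call(name: str, arguments: str) -> str:
    # Mirrors GLM4MOEToolUtils.function_formatter for a single function call.
    prompt = "\n<tool_call>" + name
    for key, value in json.loads(arguments).items():
        prompt += "\n<arg_key>" + key + "</arg_key>"
        if not isinstance(value, str):
            value = json.dumps(value, ensure_ascii=False)
        prompt += "\n<arg_value>" + value + "</arg_value>"
    return prompt

print(format_glm4_moe_call("get_weather", '{"city": "Beijing", "days": 3}'))
# (after the leading newline)
# <tool_call>get_weather
# <arg_key>city</arg_key>
# <arg_value>Beijing</arg_value>
# <arg_key>days</arg_key>
# <arg_value>3</arg_value>
```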
  TOOLS = {
      "default": DefaultToolUtils(),
      "glm4": GLM4ToolUtils(),
      "llama3": Llama3ToolUtils(),
      "mistral": MistralToolUtils(),
      "qwen": QwenToolUtils(),
+     "glm4_moe": GLM4MOEToolUtils(),
  }
@ -143,7 +143,7 @@ def register_model_group(
|
||||
for name, path in models.items():
|
||||
SUPPORTED_MODELS[name] = path
|
||||
if template is not None and (
|
||||
any(suffix in name for suffix in ("-Chat", "-Distill", "-Instruct")) or multimodal
|
||||
any(suffix in name for suffix in ("-Chat", "-Distill", "-Instruct", "-Thinking")) or multimodal
|
||||
):
|
||||
DEFAULT_TEMPLATE[name] = template
|
||||
|
||||
@ -276,7 +276,7 @@ register_model_group(
|
||||
register_model_group(
|
||||
models={
|
||||
"ChatGLM2-6B-Chat": {
|
||||
DownloadSource.DEFAULT: "THUDM/chatglm2-6b",
|
||||
DownloadSource.DEFAULT: "zai-org/chatglm2-6b",
|
||||
DownloadSource.MODELSCOPE: "ZhipuAI/chatglm2-6b",
|
||||
}
|
||||
},
|
||||
@ -287,11 +287,11 @@ register_model_group(
|
||||
register_model_group(
|
||||
models={
|
||||
"ChatGLM3-6B-Base": {
|
||||
DownloadSource.DEFAULT: "THUDM/chatglm3-6b-base",
|
||||
DownloadSource.DEFAULT: "zai-org/chatglm3-6b-base",
|
||||
DownloadSource.MODELSCOPE: "ZhipuAI/chatglm3-6b-base",
|
||||
},
|
||||
"ChatGLM3-6B-Chat": {
|
||||
DownloadSource.DEFAULT: "THUDM/chatglm3-6b",
|
||||
DownloadSource.DEFAULT: "zai-org/chatglm3-6b",
|
||||
DownloadSource.MODELSCOPE: "ZhipuAI/chatglm3-6b",
|
||||
},
|
||||
},
|
||||
@ -333,7 +333,7 @@ register_model_group(
|
||||
register_model_group(
|
||||
models={
|
||||
"CodeGeeX4-9B-Chat": {
|
||||
DownloadSource.DEFAULT: "THUDM/codegeex4-all-9b",
|
||||
DownloadSource.DEFAULT: "zai-org/codegeex4-all-9b",
|
||||
DownloadSource.MODELSCOPE: "ZhipuAI/codegeex4-all-9b",
|
||||
},
|
||||
},
|
||||
@ -589,6 +589,17 @@ register_model_group(
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Devstral-Small-2507-Instruct": {
|
||||
DownloadSource.DEFAULT: "mistralai/Devstral-Small-2507",
|
||||
DownloadSource.MODELSCOPE: "mistralai/Devstral-Small-2507",
|
||||
},
|
||||
},
|
||||
template="mistral_small",
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"EXAONE-3.0-7.8B-Instruct": {
|
||||
@ -633,6 +644,60 @@ register_model_group(
|
||||
template="falcon",
|
||||
)
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Falcon-H1-0.5B-Base": {
|
||||
DownloadSource.DEFAULT: "tiiuae/Falcon-H1-0.5B-Base",
|
||||
DownloadSource.MODELSCOPE: "tiiuae/Falcon-H1-0.5B-Base",
|
||||
},
|
||||
"Falcon-H1-1.5B-Base": {
|
||||
DownloadSource.DEFAULT: "tiiuae/Falcon-H1-1.5B-Base",
|
||||
DownloadSource.MODELSCOPE: "tiiuae/Falcon-H1-1.5B-Base",
|
||||
},
|
||||
"Falcon-H1-1.5B-Deep-Base": {
|
||||
DownloadSource.DEFAULT: "tiuae/Falcon-H1-1.5B-Deep-Base",
|
||||
DownloadSource.MODELSCOPE: "tiiuae/Falcon-H1-1.5B-Deep-Base",
|
||||
},
|
||||
"Falcon-H1-3B-Base": {
|
||||
DownloadSource.DEFAULT: "tiiuae/Falcon-H1-3B-Base",
|
||||
DownloadSource.MODELSCOPE: "tiiuae/Falcon-H1-3B-Base",
|
||||
},
|
||||
"Falcon-H1-7B-Base": {
|
||||
DownloadSource.DEFAULT: "tiiuae/Falcon-H1-7B-Base",
|
||||
DownloadSource.MODELSCOPE: "tiiuae/Falcon-H1-7B-Base",
|
||||
},
|
||||
"Falcon-H1-34B-Base": {
|
||||
DownloadSource.DEFAULT: "tiiuae/Falcon-H1-34B-Base",
|
||||
DownloadSource.MODELSCOPE: "tiiuae/Falcon-H1-34B-Base",
|
||||
},
|
||||
"Falcon-H1-0.5B-Instruct": {
|
||||
DownloadSource.DEFAULT: "tiiuae/Falcon-H1-0.5B-Instruct",
|
||||
DownloadSource.MODELSCOPE: "tiiuae/Falcon-H1-0.5B-Instruct",
|
||||
},
|
||||
"Falcon-H1-1.5B-Instruct": {
|
||||
DownloadSource.DEFAULT: "tiiuae/Falcon-H1-1.5B-Instruct",
|
||||
DownloadSource.MODELSCOPE: "tiiuae/Falcon-H1-1.5B-Instruct",
|
||||
},
|
||||
"Falcon-H1-1.5B-Deep-Instruct": {
|
||||
DownloadSource.DEFAULT: "tiiuae/Falcon-H1-1.5B-Deep-Instruct",
|
||||
DownloadSource.MODELSCOPE: "tiiuae/Falcon-H1-1.5B-Deep-Instruct",
|
||||
},
|
||||
"Falcon-H1-3B-Instruct": {
|
||||
DownloadSource.DEFAULT: "tiiuae/Falcon-H1-3B-Instruct",
|
||||
DownloadSource.MODELSCOPE: "tiiuae/Falcon-H1-3B-Instruct",
|
||||
},
|
||||
"Falcon-H1-7B-Instruct": {
|
||||
DownloadSource.DEFAULT: "tiiuae/Falcon-H1-7B-Instruct",
|
||||
DownloadSource.MODELSCOPE: "tiiuae/Falcon-H1-7B-Instruct",
|
||||
},
|
||||
"Falcon-H1-34B-Instruct": {
|
||||
DownloadSource.DEFAULT: "tiiuae/Falcon-H1-34B-Instruct",
|
||||
DownloadSource.MODELSCOPE: "tiiuae/Falcon-H1-34B-Instruct",
|
||||
},
|
||||
},
|
||||
template="falcon_h1",
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
```diff
@@ -658,6 +723,13 @@ register_model_group(
         "Gemma-1.1-7B-Instruct": {
             DownloadSource.DEFAULT: "google/gemma-1.1-7b-it",
         },
+    },
+    template="gemma",
+)
+
+
+register_model_group(
+    models={
         "Gemma-2-2B": {
             DownloadSource.DEFAULT: "google/gemma-2-2b",
             DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-2b",
@@ -697,7 +769,7 @@ register_model_group(
             DownloadSource.MODELSCOPE: "google/medgemma-27b-text-it",
         },
     },
-    template="gemma",
+    template="gemma2",
 )
 
 
@@ -741,31 +813,55 @@ register_model_group(
 )
 
 
+register_model_group(
+    models={
+        "Gemma-3n-E2B": {
+            DownloadSource.DEFAULT: "google/gemma-3n-E2B",
+            DownloadSource.MODELSCOPE: "LLM-Research/gemma-3n-E2B",
+        },
+        "Gemma-3n-E4B": {
+            DownloadSource.DEFAULT: "google/gemma-3n-E4B",
+            DownloadSource.MODELSCOPE: "LLM-Research/gemma-3n-E4B",
+        },
+        "Gemma-3n-E2B-Instruct": {
+            DownloadSource.DEFAULT: "google/gemma-3n-E2B-it",
+            DownloadSource.MODELSCOPE: "LLM-Research/gemma-3n-E2B-it",
+        },
+        "Gemma-3n-E4B-Instruct": {
+            DownloadSource.DEFAULT: "google/gemma-3n-E4B-it",
+            DownloadSource.MODELSCOPE: "LLM-Research/gemma-3n-E4B-it",
+        },
+    },
+    template="gemma3n",
+    multimodal=True,
+)
+
+
 register_model_group(
     models={
         "GLM-4-9B": {
-            DownloadSource.DEFAULT: "THUDM/glm-4-9b",
+            DownloadSource.DEFAULT: "zai-org/glm-4-9b",
             DownloadSource.MODELSCOPE: "ZhipuAI/glm-4-9b",
         },
         "GLM-4-9B-Chat": {
-            DownloadSource.DEFAULT: "THUDM/glm-4-9b-chat",
+            DownloadSource.DEFAULT: "zai-org/glm-4-9b-chat",
             DownloadSource.MODELSCOPE: "ZhipuAI/glm-4-9b-chat",
             DownloadSource.OPENMIND: "LlamaFactory/glm-4-9b-chat",
         },
         "GLM-4-9B-1M-Chat": {
-            DownloadSource.DEFAULT: "THUDM/glm-4-9b-chat-1m",
+            DownloadSource.DEFAULT: "zai-org/glm-4-9b-chat-1m",
             DownloadSource.MODELSCOPE: "ZhipuAI/glm-4-9b-chat-1m",
         },
         "GLM-4-0414-9B-Chat": {
-            DownloadSource.DEFAULT: "THUDM/GLM-4-9B-0414",
+            DownloadSource.DEFAULT: "zai-org/GLM-4-9B-0414",
             DownloadSource.MODELSCOPE: "ZhipuAI/GLM-4-9B-0414",
         },
         "GLM-4-0414-32B-Base": {
-            DownloadSource.DEFAULT: "THUDM/GLM-4-32B-Base-0414",
+            DownloadSource.DEFAULT: "zai-org/GLM-4-32B-Base-0414",
             DownloadSource.MODELSCOPE: "ZhipuAI/GLM-4-32B-Base-0414",
         },
         "GLM-4-0414-32B-Chat": {
-            DownloadSource.DEFAULT: "THUDM/GLM-4-32B-0414",
+            DownloadSource.DEFAULT: "zai-org/GLM-4-32B-0414",
             DownloadSource.MODELSCOPE: "ZhipuAI/GLM-4-32B-0414",
         },
     },
```
```diff
@@ -773,14 +869,53 @@ register_model_group(
 )
 
 
+register_model_group(
+    models={
+        "GLM-4.1V-9B-Base": {
+            DownloadSource.DEFAULT: "zai-org/GLM-4.1V-9B-Base",
+            DownloadSource.MODELSCOPE: "ZhipuAI/GLM-4.1V-9B-Base",
+        },
+        "GLM-4.1V-9B-Thinking": {
+            DownloadSource.DEFAULT: "zai-org/GLM-4.1V-9B-Thinking",
+            DownloadSource.MODELSCOPE: "ZhipuAI/GLM-4.1V-9B-Thinking",
+        },
+    },
+    template="glm4v",
+    multimodal=True,
+)
+
+
+register_model_group(
+    models={
+        "GLM-4.5-Air-Base": {
+            DownloadSource.DEFAULT: "zai-org/GLM-4.5-Air-Base",
+            DownloadSource.MODELSCOPE: "ZhipuAI/GLM-4.5-Air-Base",
+        },
+        "GLM-4.5-Base": {
+            DownloadSource.DEFAULT: "zai-org/GLM-4.5-Base",
+            DownloadSource.MODELSCOPE: "ZhipuAI/GLM-4.5-Base",
+        },
+        "GLM-4.5-Air-Chat": {
+            DownloadSource.DEFAULT: "zai-org/GLM-4.5-Air",
+            DownloadSource.MODELSCOPE: "ZhipuAI/GLM-4.5-Air",
+        },
+        "GLM-4.5-Chat": {
+            DownloadSource.DEFAULT: "zai-org/GLM-4.5",
+            DownloadSource.MODELSCOPE: "ZhipuAI/GLM-4.5",
+        },
+    },
+    template="glm4_moe",
+)
+
+
 register_model_group(
     models={
         "GLM-Z1-0414-9B-Chat": {
-            DownloadSource.DEFAULT: "THUDM/GLM-Z1-9B-0414",
+            DownloadSource.DEFAULT: "zai-org/GLM-Z1-9B-0414",
             DownloadSource.MODELSCOPE: "ZhipuAI/GLM-Z1-9B-0414",
         },
         "GLM-Z1-0414-32B-Chat": {
-            DownloadSource.DEFAULT: "THUDM/GLM-Z1-32B-0414",
+            DownloadSource.DEFAULT: "zai-org/GLM-Z1-32B-0414",
             DownloadSource.MODELSCOPE: "ZhipuAI/GLM-Z1-32B-0414",
         },
     },
@@ -917,6 +1052,17 @@ register_model_group(
 )
 
 
+register_model_group(
+    models={
+        "Granite-4.0-tiny-preview": {
+            DownloadSource.DEFAULT: "ibm-granite/granite-4.0-tiny-preview",
+            DownloadSource.MODELSCOPE: "ibm-granite/granite-4.0-tiny-preview",
+        },
+    },
+    template="granite4",
+)
+
+
 register_model_group(
     models={
         "Hunyuan-7B-Instruct": {
@@ -1089,6 +1235,29 @@ register_model_group(
 )
 
 
+register_model_group(
+    models={
+        "Keye-VL-8B-Chat": {
+            DownloadSource.DEFAULT: "Kwai-Keye/Keye-VL-8B-Preview",
+            DownloadSource.MODELSCOPE: "Kwai-Keye/Keye-VL-8B-Preview",
+        },
+    },
+    template="keye_vl",
+    multimodal=True,
+)
+
+
+register_model_group(
+    models={
+        "Kimi-Dev-72B-Instruct": {
+            DownloadSource.DEFAULT: "moonshotai/Kimi-Dev-72B",
+            DownloadSource.MODELSCOPE: "moonshotai/Kimi-Dev-72B",
+        },
+    },
+    template="qwen",
+)
+
+
 register_model_group(
     models={
         "Kimi-VL-A3B-Instruct": {
@@ -1099,6 +1268,10 @@ register_model_group(
             DownloadSource.DEFAULT: "moonshotai/Kimi-VL-A3B-Thinking",
             DownloadSource.MODELSCOPE: "moonshotai/Kimi-VL-A3B-Thinking",
         },
+        "Kimi-VL-A3B-Thinking-2506": {
+            DownloadSource.DEFAULT: "moonshotai/Kimi-VL-A3B-Thinking-2506",
+            DownloadSource.MODELSCOPE: "moonshotai/Kimi-VL-A3B-Thinking-2506",
+        },
     },
     template="kimi_vl",
     multimodal=True,
@@ -1617,6 +1790,10 @@ register_model_group(
             DownloadSource.DEFAULT: "mistralai/Mistral-Small-3.1-24B-Instruct-2503",
             DownloadSource.MODELSCOPE: "mistralai/Mistral-Small-3.1-24B-Instruct-2503",
         },
+        "Mistral-Small-3.2-24B-Instruct": {
+            DownloadSource.DEFAULT: "mistralai/Mistral-Small-3.2-24B-Instruct-2506",
+            DownloadSource.MODELSCOPE: "mistralai/Mistral-Small-3.2-24B-Instruct-2506",
+        },
     },
     template="mistral_small",
     multimodal=True,
```
```diff
@@ -2538,67 +2715,83 @@ register_model_group(
             DownloadSource.DEFAULT: "Qwen/Qwen3-30B-A3B-Base",
             DownloadSource.MODELSCOPE: "Qwen/Qwen3-30B-A3B-Base",
         },
-        "Qwen3-0.6B-Instruct": {
+        "Qwen3-0.6B-Thinking": {
             DownloadSource.DEFAULT: "Qwen/Qwen3-0.6B",
             DownloadSource.MODELSCOPE: "Qwen/Qwen3-0.6B",
         },
-        "Qwen3-1.7B-Instruct": {
+        "Qwen3-1.7B-Thinking": {
             DownloadSource.DEFAULT: "Qwen/Qwen3-1.7B",
             DownloadSource.MODELSCOPE: "Qwen/Qwen3-1.7B",
         },
-        "Qwen3-4B-Instruct": {
+        "Qwen3-4B-Thinking": {
             DownloadSource.DEFAULT: "Qwen/Qwen3-4B",
             DownloadSource.MODELSCOPE: "Qwen/Qwen3-4B",
         },
-        "Qwen3-8B-Instruct": {
+        "Qwen3-8B-Thinking": {
             DownloadSource.DEFAULT: "Qwen/Qwen3-8B",
             DownloadSource.MODELSCOPE: "Qwen/Qwen3-8B",
         },
-        "Qwen3-14B-Instruct": {
+        "Qwen3-14B-Thinking": {
             DownloadSource.DEFAULT: "Qwen/Qwen3-14B",
             DownloadSource.MODELSCOPE: "Qwen/Qwen3-14B",
         },
-        "Qwen3-32B-Instruct": {
+        "Qwen3-32B-Thinking": {
             DownloadSource.DEFAULT: "Qwen/Qwen3-32B",
             DownloadSource.MODELSCOPE: "Qwen/Qwen3-32B",
         },
-        "Qwen3-30B-A3B-Instruct": {
+        "Qwen3-30B-A3B-Thinking": {
             DownloadSource.DEFAULT: "Qwen/Qwen3-30B-A3B",
             DownloadSource.MODELSCOPE: "Qwen/Qwen3-30B-A3B",
         },
+        "Qwen3-30B-A3B-Instruct-2507": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3-30B-A3B-Instruct-2507",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3-30B-A3B-Instruct-2507",
+        },
+        "Qwen3-30B-A3B-Thinking-2507": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3-30B-A3B-Thinking-2507",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3-30B-A3B-Thinking-2507",
+        },
-        "Qwen3-235B-A22B-Instruct": {
+        "Qwen3-235B-A22B-Thinking": {
             DownloadSource.DEFAULT: "Qwen/Qwen3-235B-A22B",
             DownloadSource.MODELSCOPE: "Qwen/Qwen3-235B-A22B",
         },
+        "Qwen3-235B-A22B-Instruct-2507": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3-235B-A22B-Instruct-2507",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3-235B-A22B-Instruct-2507",
+        },
+        "Qwen3-235B-A22B-Thinking-2507": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3-235B-A22B-Thinking-2507",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3-235B-A22B-Thinking-2507",
+        },
-        "Qwen3-0.6B-Instruct-GPTQ-Int8": {
+        "Qwen3-0.6B-Thinking-GPTQ-Int8": {
             DownloadSource.DEFAULT: "Qwen/Qwen3-0.6B-GPTQ-Int8",
             DownloadSource.MODELSCOPE: "Qwen/Qwen3-0.6B-GPTQ-Int8",
         },
-        "Qwen3-1.7B-Instruct-GPTQ-Int8": {
+        "Qwen3-1.7B-Thinking-GPTQ-Int8": {
             DownloadSource.DEFAULT: "Qwen/Qwen3-1.7B-GPTQ-Int8",
             DownloadSource.MODELSCOPE: "Qwen/Qwen3-1.7B-GPTQ-Int8",
         },
-        "Qwen3-4B-Instruct-AWQ": {
+        "Qwen3-4B-Thinking-AWQ": {
             DownloadSource.DEFAULT: "Qwen/Qwen3-4B-AWQ",
             DownloadSource.MODELSCOPE: "Qwen/Qwen3-4B-AWQ",
         },
-        "Qwen3-8B-Instruct-AWQ": {
+        "Qwen3-8B-Thinking-AWQ": {
             DownloadSource.DEFAULT: "Qwen/Qwen3-8B-AWQ",
             DownloadSource.MODELSCOPE: "Qwen/Qwen3-8B-AWQ",
         },
-        "Qwen3-14B-Instruct-AWQ": {
+        "Qwen3-14B-Thinking-AWQ": {
             DownloadSource.DEFAULT: "Qwen/Qwen3-14B-AWQ",
             DownloadSource.MODELSCOPE: "Qwen/Qwen3-14B-AWQ",
         },
-        "Qwen3-32B-Instruct-AWQ": {
+        "Qwen3-32B-Thinking-AWQ": {
             DownloadSource.DEFAULT: "Qwen/Qwen3-32B-AWQ",
             DownloadSource.MODELSCOPE: "Qwen/Qwen3-32B-AWQ",
         },
-        "Qwen3-30B-A3B-Instruct-GPTQ-Int4": {
+        "Qwen3-30B-A3B-Thinking-GPTQ-Int4": {
             DownloadSource.DEFAULT: "Qwen/Qwen3-30B-A3B-GPTQ-Int4",
             DownloadSource.MODELSCOPE: "Qwen/Qwen3-30B-A3B-GPTQ-Int4",
         },
-        "Qwen3-235B-A22B-Instruct-GPTQ-Int4": {
+        "Qwen3-235B-A22B-Thinking-GPTQ-Int4": {
             DownloadSource.DEFAULT: "Qwen/Qwen3-235B-A22B-GPTQ-Int4",
             DownloadSource.MODELSCOPE: "Qwen/Qwen3-235B-A22B-GPTQ-Int4",
         },
```
```diff
@@ -27,7 +27,7 @@ import trl
 from transformers.utils import is_torch_cuda_available, is_torch_npu_available
 
 
-VERSION = "0.9.3"
+VERSION = "0.9.4.dev0"
 
 
 def print_env() -> None:
```

```diff
@@ -50,7 +50,7 @@ class LoggerHandler(logging.Handler):
 
     def _write_log(self, log_entry: str) -> None:
         with open(self.running_log, "a", encoding="utf-8") as f:
-            f.write(log_entry + "\n\n")
+            f.write(log_entry + "\n")
 
     def emit(self, record) -> None:
         if record.name == "httpx":
```
```diff
@@ -23,6 +23,7 @@ from typing import TYPE_CHECKING, Any, Literal, Union
 
 import torch
 import torch.distributed as dist
 import transformers.dynamic_module_utils
+from huggingface_hub.utils import WeakFileLock
 from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList
 from transformers.dynamic_module_utils import get_relative_imports
 from transformers.utils import (
@@ -35,7 +36,6 @@ from transformers.utils import (
 from transformers.utils.versions import require_version
 
 from . import logging
-from .packages import is_transformers_version_greater_than
 
 
 _is_fp16_available = is_torch_npu_available() or is_torch_cuda_available()
```
```diff
@@ -94,15 +94,11 @@ def check_version(requirement: str, mandatory: bool = False) -> None:
 
 def check_dependencies() -> None:
     r"""Check the version of the required packages."""
-    check_version(
-        "transformers>=4.45.0,<=4.52.4,!=4.46.0,!=4.46.1,!=4.46.2,!=4.46.3,!=4.47.0,!=4.47.1,!=4.48.0,!=4.52.0"
-    )
+    check_version("transformers>=4.49.0,<=4.52.4,!=4.52.0")
     check_version("datasets>=2.16.0,<=3.6.0")
-    check_version("accelerate>=0.34.0,<=1.7.0")
+    check_version("accelerate>=1.3.0,<=1.7.0")
     check_version("peft>=0.14.0,<=0.15.2")
     check_version("trl>=0.8.6,<=0.9.6")
-    if is_transformers_version_greater_than("4.46.0") and not is_transformers_version_greater_than("4.48.1"):
-        logger.warning_rank0_once("There are known bugs in transformers v4.46.0-v4.48.0, please use other versions.")
 
 
 def calculate_tps(dataset: list[dict[str, Any]], metrics: dict[str, float], stage: Literal["sft", "rm"]) -> float:
```
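The requirement strings above are pip-style specifiers; `check_version` delegates to `require_version` from `transformers` (imported earlier in this file), which accepts comma-separated constraints and raises if the installed version falls outside them. A minimal illustration of the same call, with a hypothetical hint message:

```python
# Sketch only: shows the specifier format used by check_dependencies above.
from transformers.utils.versions import require_version

# Raises ImportError with the hint if the installed datasets version is out of range.
require_version("datasets>=2.16.0,<=3.6.0", "To fix: pip install 'datasets>=2.16.0,<=3.6.0'")
```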
```diff
@@ -182,8 +178,22 @@ def get_logits_processor() -> "LogitsProcessorList":
     return logits_processor
 
 
+def get_current_memory() -> tuple[int, int]:
+    r"""Get the available and total memory for the current device (in Bytes)."""
+    if is_torch_xpu_available():
+        return torch.xpu.mem_get_info()
+    elif is_torch_npu_available():
+        return torch.npu.mem_get_info()
+    elif is_torch_mps_available():
+        return torch.mps.current_allocated_memory(), torch.mps.recommended_max_memory()
+    elif is_torch_cuda_available():
+        return torch.cuda.mem_get_info()
+    else:
+        return 0, -1
+
+
 def get_peak_memory() -> tuple[int, int]:
-    r"""Get the peak memory usage for the current device (in Bytes)."""
+    r"""Get the peak memory usage (allocated, reserved) for the current device (in Bytes)."""
     if is_torch_xpu_available():
         return torch.xpu.max_memory_allocated(), torch.xpu.max_memory_reserved()
     elif is_torch_npu_available():
@@ -193,7 +203,7 @@ def get_peak_memory() -> tuple[int, int]:
     elif is_torch_cuda_available():
         return torch.cuda.max_memory_allocated(), torch.cuda.max_memory_reserved()
     else:
-        return 0, 0
+        return 0, -1
 
 
 def has_tokenized_data(path: "os.PathLike") -> bool:
```
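Both helpers now return `-1` as the "no accelerator" sentinel, so callers can distinguish an idle device from a missing one. A minimal sketch of how a caller might turn the `(free, total)` tuple into a GB figure, mirroring what the Web UI footer added later in this diff does; the CUDA-only branch here is a simplification for illustration:

```python
import torch


def describe_memory() -> str:
    # mem_get_info returns (free_bytes, total_bytes) on CUDA devices
    if torch.cuda.is_available():
        free, total = torch.cuda.mem_get_info()
    else:
        free, total = 0, -1  # same sentinel the helper above uses off-device

    if total == -1:
        return "no accelerator detected"

    used_gb = (total - free) / (1024**3)
    return f"{used_gb:.2f} GB used of {total / (1024**3):.2f} GB"


print(describe_memory())
```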
```diff
@@ -259,25 +269,36 @@ def try_download_model_from_other_hub(model_args: "ModelArguments") -> str:
         return model_args.model_name_or_path
 
     if use_modelscope():
-        check_version("modelscope>=1.11.0", mandatory=True)
+        check_version("modelscope>=1.14.0", mandatory=True)
         from modelscope import snapshot_download  # type: ignore
         from modelscope.hub.api import HubApi  # type: ignore
 
         if model_args.ms_hub_token:
             api = HubApi()
             api.login(model_args.ms_hub_token)
 
         revision = "master" if model_args.model_revision == "main" else model_args.model_revision
-        return snapshot_download(
-            model_args.model_name_or_path,
-            revision=revision,
-            cache_dir=model_args.cache_dir,
-        )
+        with WeakFileLock(os.path.abspath(os.path.expanduser("~/.cache/llamafactory/modelscope.lock"))):
+            model_path = snapshot_download(
+                model_args.model_name_or_path,
+                revision=revision,
+                cache_dir=model_args.cache_dir,
+            )
+
+        return model_path
 
     if use_openmind():
         check_version("openmind>=0.8.0", mandatory=True)
         from openmind.utils.hub import snapshot_download  # type: ignore
 
-        return snapshot_download(
-            model_args.model_name_or_path,
-            revision=model_args.model_revision,
-            cache_dir=model_args.cache_dir,
-        )
+        with WeakFileLock(os.path.abspath(os.path.expanduser("~/.cache/llamafactory/openmind.lock"))):
+            model_path = snapshot_download(
+                model_args.model_name_or_path,
+                revision=model_args.model_revision,
+                cache_dir=model_args.cache_dir,
+            )
+
+        return model_path
 
 
 def use_modelscope() -> bool:
```
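Wrapping the download in a file lock keeps concurrent workers (for example, several ranks of a distributed job) from racing on the same snapshot directory. A minimal sketch, assuming only that `WeakFileLock` behaves as an inter-process file-lock context manager, which is how the diff itself uses it:

```python
import os

from huggingface_hub.utils import WeakFileLock

lock_path = os.path.abspath(os.path.expanduser("~/.cache/llamafactory/demo.lock"))
os.makedirs(os.path.dirname(lock_path), exist_ok=True)

# Only one process at a time enters this block; the others block on the lock file.
with WeakFileLock(lock_path):
    print("download (or reuse) the snapshot here")
```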
```diff
@@ -305,5 +326,5 @@ def fix_proxy(ipv6_enabled: bool = False) -> None:
     r"""Fix proxy settings for gradio ui."""
     os.environ["no_proxy"] = "localhost,127.0.0.1,0.0.0.0"
     if ipv6_enabled:
-        os.environ.pop("http_proxy", None)
-        os.environ.pop("HTTP_PROXY", None)
+        for name in ("http_proxy", "https_proxy", "HTTP_PROXY", "HTTPS_PROXY"):
+            os.environ.pop(name, None)
```
```diff
@@ -15,7 +15,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import json
 import os
 import sys
 from pathlib import Path
@@ -23,7 +22,6 @@ from typing import Any, Optional, Union
 
 import torch
 import transformers
-import yaml
 from omegaconf import OmegaConf
 from transformers import HfArgumentParser
 from transformers.integrations import is_deepspeed_zero3_enabled
@@ -62,11 +60,11 @@ def read_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> Union[
 
     if sys.argv[1].endswith(".yaml") or sys.argv[1].endswith(".yml"):
         override_config = OmegaConf.from_cli(sys.argv[2:])
-        dict_config = yaml.safe_load(Path(sys.argv[1]).absolute().read_text())
+        dict_config = OmegaConf.load(Path(sys.argv[1]).absolute())
         return OmegaConf.to_container(OmegaConf.merge(dict_config, override_config))
     elif sys.argv[1].endswith(".json"):
         override_config = OmegaConf.from_cli(sys.argv[2:])
-        dict_config = json.loads(Path(sys.argv[1]).absolute().read_text())
+        dict_config = OmegaConf.load(Path(sys.argv[1]).absolute())
         return OmegaConf.to_container(OmegaConf.merge(dict_config, override_config))
     else:
         return sys.argv[1:]
```
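Routing both YAML and JSON through `OmegaConf.load` (JSON is a subset of YAML) lets dotted CLI overrides merge uniformly with the config file. A self-contained sketch of the same merge flow, using a hypothetical `learning_rate` key:

```python
from omegaconf import OmegaConf

base = OmegaConf.create({"model_name_or_path": "Qwen/Qwen3-0.6B", "learning_rate": 1e-4})
# OmegaConf.from_cli parses dotlist arguments such as ["learning_rate=5e-5"]
override = OmegaConf.from_cli(["learning_rate=5e-5"])
merged = OmegaConf.to_container(OmegaConf.merge(base, override))
print(merged["learning_rate"])  # 5e-05 — the CLI value wins
```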
```diff
@@ -148,7 +146,7 @@ def _check_extra_dependencies(
         check_version("mixture-of-depth>=1.1.6", mandatory=True)
 
     if model_args.infer_backend == EngineName.VLLM:
-        check_version("vllm>=0.4.3,<=0.8.6")
+        check_version("vllm>=0.4.3,<=0.10.0")
         check_version("vllm", mandatory=True)
     elif model_args.infer_backend == EngineName.SGLANG:
         check_version("sglang>=0.4.5")
@@ -166,6 +164,9 @@ def _check_extra_dependencies(
     if finetuning_args.use_adam_mini:
         check_version("adam-mini", mandatory=True)
 
+    if finetuning_args.use_swanlab:
+        check_version("swanlab", mandatory=True)
+
     if finetuning_args.plot_loss:
         check_version("matplotlib", mandatory=True)
 
@@ -348,6 +349,9 @@ def get_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _
     # https://github.com/huggingface/transformers/blob/v4.50.0/src/transformers/trainer.py#L782
     training_args.label_names = training_args.label_names or ["labels"]
 
+    if "swanlab" in training_args.report_to and finetuning_args.use_swanlab:
+        training_args.report_to.remove("swanlab")
+
     if (
         training_args.parallel_mode == ParallelMode.DISTRIBUTED
         and training_args.ddp_find_unused_parameters is None
```
```diff
@@ -188,7 +188,7 @@ def _setup_lora_tuning(
 
     if adapter_to_resume is not None:  # resume lora training
         if model_args.use_unsloth:
-            model = load_unsloth_peft_model(config, model_args, is_trainable=is_trainable)
+            model = load_unsloth_peft_model(config, model_args, finetuning_args, is_trainable=is_trainable)
         else:
             model = PeftModel.from_pretrained(model, adapter_to_resume, is_trainable=is_trainable, **init_kwargs)
```

```diff
@@ -19,6 +19,7 @@ import torch
 from transformers import (
     AutoConfig,
     AutoModelForCausalLM,
+    AutoModelForImageTextToText,
     AutoModelForSeq2SeqLM,
     AutoModelForTextToWaveform,
     AutoModelForVision2Seq,
@@ -29,7 +30,6 @@ from trl import AutoModelForCausalLMWithValueHead
 
 from ..extras import logging
 from ..extras.misc import count_parameters, skip_check_imports, try_download_model_from_other_hub
-from ..extras.packages import is_transformers_version_greater_than
 from .adapter import init_adapter
 from .model_utils.liger_kernel import apply_liger_kernel
 from .model_utils.misc import register_autoclass
@@ -39,10 +39,6 @@ from .model_utils.valuehead import load_valuehead_params
 from .patcher import patch_config, patch_model, patch_processor, patch_tokenizer, patch_valuehead_model
 
 
-if is_transformers_version_greater_than("4.46.0"):
-    from transformers import AutoModelForImageTextToText
-
-
 if TYPE_CHECKING:
     from transformers import PretrainedConfig, PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
```
```diff
@@ -111,9 +107,8 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
             **init_kwargs,
         )
     except Exception as e:
-        raise OSError("Failed to load processor.") from e
-
-    patch_processor(processor, tokenizer, model_args)
+        logger.info_rank0(f"Failed to load processor: {e}.")
+        processor = None
 
     # Avoid load tokenizer, see:
     # https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/auto/processing_auto.py#L324
@@ -121,6 +116,9 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
         logger.debug("The loaded processor is not an instance of Processor. Dropping it.")
         processor = None
 
+    if processor is not None:
+        patch_processor(processor, tokenizer, model_args)
+
     return {"tokenizer": tokenizer, "processor": processor}
```
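With this change, a checkpoint that ships no processor degrades to text-only loading instead of aborting. A minimal sketch of the same try/fall-back pattern (the model id is arbitrary, for illustration only):

```python
from transformers import AutoProcessor

try:
    processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
except Exception as e:
    # Some checkpoints have no usable processor config; log and continue text-only.
    print(f"Failed to load processor: {e}.")
    processor = None
```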
```diff
@@ -160,10 +158,7 @@ def load_model(
     else:
         if type(config) in AutoModelForVision2Seq._model_mapping.keys():  # image-text
             load_class = AutoModelForVision2Seq
-        elif (
-            is_transformers_version_greater_than("4.46.0")
-            and type(config) in AutoModelForImageTextToText._model_mapping.keys()
-        ):  # image-text
+        elif type(config) in AutoModelForImageTextToText._model_mapping.keys():  # image-text
             load_class = AutoModelForImageTextToText
         elif type(config) in AutoModelForSeq2SeqLM._model_mapping.keys():  # audio-text
             load_class = AutoModelForSeq2SeqLM
```

```diff
@@ -57,6 +57,11 @@ def add_z3_leaf_module(model: "PreTrainedModel") -> None:
 
         _set_z3_leaf_modules(model, [GraniteMoeMoE])
 
+    if model_type == "glm4_moe":
+        from transformers.models.glm4_moe.modeling_glm4_moe import Glm4MoeMoE
+
+        _set_z3_leaf_modules(model, [Glm4MoeMoE])
+
     if model_type == "jamba":
         from transformers.models.jamba.modeling_jamba import JambaSparseMoeBlock
```
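Marking an MoE block as a ZeRO-3 "leaf" stops DeepSpeed from partitioning parameters inside it, which avoids a gather/scatter on every expert dispatch. A hedged sketch using DeepSpeed's public helper; the repo's internal `_set_z3_leaf_modules` is assumed to wrap something equivalent:

```python
from deepspeed.utils import set_z3_leaf_modules  # public DeepSpeed helper
from transformers import PreTrainedModel


def mark_moe_leaves(model: PreTrainedModel) -> None:
    # ZeRO-3 partitioning stops at whole Glm4MoeMoE modules instead of their sub-weights.
    from transformers.models.glm4_moe.modeling_glm4_moe import Glm4MoeMoE

    set_z3_leaf_modules(model, [Glm4MoeMoE])
```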
```diff
@@ -80,12 +80,15 @@ def get_unsloth_peft_model(
 
 
 def load_unsloth_peft_model(
-    config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool
+    config: "PretrainedConfig",
+    model_args: "ModelArguments",
+    finetuning_args: "FinetuningArguments",
+    is_trainable: bool,
 ) -> "PreTrainedModel":
     r"""Load peft model with unsloth. Used in both training and inference."""
     from unsloth import FastLanguageModel  # type: ignore
 
-    unsloth_kwargs = _get_unsloth_kwargs(config, model_args.adapter_name_or_path[0], model_args)
+    unsloth_kwargs = _get_unsloth_kwargs(config, model_args.adapter_name_or_path[0], model_args, finetuning_args)
     try:
         if not is_trainable:
             unsloth_kwargs["use_gradient_checkpointing"] = False
```
```diff
@@ -49,7 +49,7 @@ def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") ->
 
     try:
         vhead_file = cached_file(filename=V_HEAD_WEIGHTS_NAME, **kwargs)
-        return torch.load(vhead_file, map_location="cpu")
+        return torch.load(vhead_file, map_location="cpu", weights_only=True)
     except Exception as err:
         err_text = str(err)
```
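`weights_only=True` makes `torch.load` refuse to unpickle arbitrary objects, so a tampered checkpoint cannot execute code on load; only tensors and a small allowlist of types pass. A quick self-contained illustration (the key name is hypothetical):

```python
import torch

torch.save({"v_head.summary.weight": torch.zeros(1, 8)}, "vhead_demo.bin")
# Plain tensors load fine; a checkpoint containing pickled code objects would raise.
state_dict = torch.load("vhead_demo.bin", map_location="cpu", weights_only=True)
print(list(state_dict))
```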
```diff
@@ -204,11 +204,37 @@ _register_composite_model(
 )
 
 
+_register_composite_model(
+    model_type="gemma3n",
+    vision_model_keys=["vision_tower", "audio_tower"],
+    lora_conflict_keys=["timm_model", "subsample_conv_projection"],
+)
+
+
+# copied from qwen2vl
+_register_composite_model(
+    model_type="glm4v",
+    projector_key="visual.merger",
+    vision_model_keys=["visual.patch_embed", "visual.blocks"],
+    language_model_keys=["language_model", "lm_head"],
+    lora_conflict_keys=["patch_embed"],
+)
+
+
 _register_composite_model(
     model_type="internvl",
 )
 
 
+_register_composite_model(
+    model_type="Keye",
+    projector_key="mlp_AR",
+    vision_model_keys=["visual.vision_model.patch_embedding", "visual.vision_model.encoder"],
+    language_model_keys=["model", "lm_head"],
+    lora_conflict_keys=["patch_embedding"],
+)
+
+
 _register_composite_model(
     model_type="llama4",
     vision_model_keys=["vision_model"],
```

```diff
@@ -178,6 +178,9 @@ def patch_model(
         resize_embedding_layer(model, tokenizer)
 
     if is_trainable:
+        if getattr(model.config, "model_type", None) == "gemma3n":
+            setattr(model_args, "disable_gradient_checkpointing", True)
+
         prepare_model_for_training(model, model_args)
         autocast_projector_dtype(model, model_args)
         add_z3_leaf_module(model)
```

```diff
@@ -76,7 +76,7 @@ def fix_valuehead_checkpoint(
         state_dict: dict[str, torch.Tensor] = {key: f.get_tensor(key) for key in f.keys()}
     else:
         path_to_checkpoint = os.path.join(output_dir, WEIGHTS_NAME)
-        state_dict: dict[str, torch.Tensor] = torch.load(path_to_checkpoint, map_location="cpu")
+        state_dict: dict[str, torch.Tensor] = torch.load(path_to_checkpoint, map_location="cpu", weights_only=True)
 
     os.remove(path_to_checkpoint)
     decoder_state_dict, v_head_state_dict = {}, {}
```
```diff
@@ -77,14 +77,19 @@ def load_config() -> dict[str, Union[str, dict[str, Any]]]:
         with open(_get_config_path(), encoding="utf-8") as f:
             return safe_load(f)
     except Exception:
-        return {"lang": None, "last_model": None, "path_dict": {}, "cache_dir": None}
+        return {"lang": None, "hub_name": None, "last_model": None, "path_dict": {}, "cache_dir": None}
 
 
-def save_config(lang: str, model_name: Optional[str] = None, model_path: Optional[str] = None) -> None:
+def save_config(
+    lang: str, hub_name: Optional[str] = None, model_name: Optional[str] = None, model_path: Optional[str] = None
+) -> None:
     r"""Save user config."""
     os.makedirs(DEFAULT_CACHE_DIR, exist_ok=True)
     user_config = load_config()
     user_config["lang"] = lang or user_config["lang"]
+    if hub_name:
+        user_config["hub_name"] = hub_name
+
     if model_name:
         user_config["last_model"] = model_name
 
@@ -247,7 +252,7 @@ def create_ds_config() -> None:
         "stage": 2,
         "allgather_partitions": True,
         "allgather_bucket_size": 5e8,
-        "overlap_comm": True,
+        "overlap_comm": False,
         "reduce_scatter": True,
         "reduce_bucket_size": 5e8,
         "contiguous_gradients": True,
@@ -262,7 +267,7 @@ def create_ds_config() -> None:
 
     ds_config["zero_optimization"] = {
         "stage": 3,
-        "overlap_comm": True,
+        "overlap_comm": False,
         "contiguous_gradients": True,
         "sub_group_size": 1e9,
         "reduce_bucket_size": "auto",
```
```diff
@@ -15,6 +15,7 @@
 from .chatbot import create_chat_box
 from .eval import create_eval_tab
 from .export import create_export_tab
+from .footer import create_footer
 from .infer import create_infer_tab
 from .top import create_top
 from .train import create_train_tab
@@ -24,6 +25,7 @@ __all__ = [
     "create_chat_box",
     "create_eval_tab",
     "create_export_tab",
+    "create_footer",
     "create_infer_tab",
     "create_top",
     "create_train_tab",
```
```diff
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import inspect
 import json
 from typing import TYPE_CHECKING
 
@@ -50,7 +51,14 @@ def create_chat_box(
 ) -> tuple["Component", "Component", dict[str, "Component"]]:
     lang = engine.manager.get_elem_by_id("top.lang")
     with gr.Column(visible=visible) as chat_box:
-        chatbot = gr.Chatbot(type="messages", show_copy_button=True)
+        kwargs = {}
+        if "show_copy_button" in inspect.signature(gr.Chatbot.__init__).parameters:
+            kwargs["show_copy_button"] = True
+
+        if "resizable" in inspect.signature(gr.Chatbot.__init__).parameters:
+            kwargs["resizable"] = True
+
+        chatbot = gr.Chatbot(type="messages", **kwargs)
         messages = gr.State([])
         with gr.Row():
             with gr.Column(scale=4):
```

src/llamafactory/webui/components/footer.py (new file, 45 lines)

```diff
@@ -0,0 +1,45 @@
+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING
+
+from ...extras.misc import get_current_memory
+from ...extras.packages import is_gradio_available
+
+
+if is_gradio_available():
+    import gradio as gr
+
+
+if TYPE_CHECKING:
+    from gradio.components import Component
+
+
+def get_device_memory() -> "gr.Slider":
+    free, total = get_current_memory()
+    if total != -1:
+        used = round((total - free) / (1024**3), 2)
+        total = round(total / (1024**3), 2)
+        return gr.Slider(minimum=0, maximum=total, value=used, step=0.01, visible=True)
+    else:
+        return gr.Slider(visible=False)
+
+
+def create_footer() -> dict[str, "Component"]:
+    with gr.Row():
+        device_memory = gr.Slider(visible=False, interactive=False)
+        timer = gr.Timer(value=5)
+
+    timer.tick(get_device_memory, outputs=[device_memory], queue=False)
+    return dict(device_memory=device_memory)
```
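One note on the `inspect.signature` probing in `create_chat_box` above: feature-detecting constructor kwargs keeps a single codebase working across Gradio versions that added or dropped parameters. The same pattern, factored into a hypothetical reusable helper for illustration:

```python
import inspect


def supported_kwargs(callable_obj, candidates: dict) -> dict:
    """Keep only the kwargs that callable_obj actually accepts."""
    params = inspect.signature(callable_obj).parameters
    return {k: v for k, v in candidates.items() if k in params}


def make_widget(*, label: str = "demo", resizable: bool = False) -> dict:
    return {"label": label, "resizable": resizable}


# "tooltip" is silently dropped because make_widget does not accept it.
print(make_widget(**supported_kwargs(make_widget, {"resizable": True, "tooltip": "hi"})))
```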
```diff
@@ -16,9 +16,10 @@ from typing import TYPE_CHECKING
 
 from ...data import TEMPLATES
 from ...extras.constants import METHODS, SUPPORTED_MODELS
+from ...extras.misc import use_modelscope, use_openmind
 from ...extras.packages import is_gradio_available
 from ..common import save_config
-from ..control import can_quantize, can_quantize_to, check_template, get_model_info, list_checkpoints
+from ..control import can_quantize, can_quantize_to, check_template, get_model_info, list_checkpoints, switch_hub
 
 
 if is_gradio_available():
@@ -33,8 +34,10 @@ def create_top() -> dict[str, "Component"]:
     with gr.Row():
         lang = gr.Dropdown(choices=["en", "ru", "zh", "ko", "ja"], value=None, scale=1)
         available_models = list(SUPPORTED_MODELS.keys()) + ["Custom"]
-        model_name = gr.Dropdown(choices=available_models, value=None, scale=3)
-        model_path = gr.Textbox(scale=3)
+        model_name = gr.Dropdown(choices=available_models, value=None, scale=2)
+        model_path = gr.Textbox(scale=2)
+        default_hub = "modelscope" if use_modelscope() else "openmind" if use_openmind() else "huggingface"
+        hub_name = gr.Dropdown(choices=["huggingface", "modelscope", "openmind"], value=default_hub, scale=2)
 
     with gr.Row():
         finetuning_type = gr.Dropdown(choices=METHODS, value="lora", scale=1)
@@ -50,18 +53,25 @@ def create_top() -> dict[str, "Component"]:
     model_name.change(get_model_info, [model_name], [model_path, template], queue=False).then(
         list_checkpoints, [model_name, finetuning_type], [checkpoint_path], queue=False
     ).then(check_template, [lang, template])
-    model_name.input(save_config, inputs=[lang, model_name], queue=False)
-    model_path.input(save_config, inputs=[lang, model_name, model_path], queue=False)
+    model_name.input(save_config, inputs=[lang, hub_name, model_name], queue=False)
+    model_path.input(save_config, inputs=[lang, hub_name, model_name, model_path], queue=False)
     finetuning_type.change(can_quantize, [finetuning_type], [quantization_bit], queue=False).then(
         list_checkpoints, [model_name, finetuning_type], [checkpoint_path], queue=False
     )
     checkpoint_path.focus(list_checkpoints, [model_name, finetuning_type], [checkpoint_path], queue=False)
     quantization_method.change(can_quantize_to, [quantization_method], [quantization_bit], queue=False)
+    hub_name.change(switch_hub, inputs=[hub_name], queue=False).then(
+        get_model_info, [model_name], [model_path, template], queue=False
+    ).then(list_checkpoints, [model_name, finetuning_type], [checkpoint_path], queue=False).then(
+        check_template, [lang, template]
+    )
+    hub_name.input(save_config, inputs=[lang, hub_name], queue=False)
 
     return dict(
         lang=lang,
         model_name=model_name,
         model_path=model_path,
+        hub_name=hub_name,
         finetuning_type=finetuning_type,
         checkpoint_path=checkpoint_path,
         quantization_bit=quantization_bit,
```
````diff
@@ -38,6 +38,15 @@ if is_gradio_available():
     import gradio as gr
 
 
+def switch_hub(hub_name: str) -> None:
+    r"""Switch model hub.
+
+    Inputs: top.hub_name
+    """
+    os.environ["USE_MODELSCOPE_HUB"] = "1" if hub_name == "modelscope" else "0"
+    os.environ["USE_OPENMIND_HUB"] = "1" if hub_name == "openmind" else "0"
+
+
 def can_quantize(finetuning_type: str) -> "gr.Dropdown":
     r"""Judge if the quantization is available in this finetuning type.
 
@@ -112,7 +121,7 @@ def get_trainer_info(lang: str, output_path: os.PathLike, do_train: bool) -> tup
     running_log_path = os.path.join(output_path, RUNNING_LOG)
     if os.path.isfile(running_log_path):
         with open(running_log_path, encoding="utf-8") as f:
-            running_log = f.read()[-20000:]  # avoid lengthy log
+            running_log = "```\n" + f.read()[-20000:] + "\n```\n"  # avoid lengthy log
 
     trainer_log_path = os.path.join(output_path, TRAINER_LOG)
     if os.path.isfile(trainer_log_path):
````
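`switch_hub` only flips environment variables; the download helpers re-read them on each call, so switching takes effect without restarting the UI. A sketch under the assumption that `use_modelscope()` consults the same variable (its body is not shown in this diff):

```python
import os


def use_modelscope() -> bool:
    # Assumed implementation: mirrors the env var that switch_hub sets.
    return os.getenv("USE_MODELSCOPE_HUB", "0") == "1"


os.environ["USE_MODELSCOPE_HUB"] = "1"  # what switch_hub("modelscope") does
print(use_modelscope())  # True
```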
```diff
@@ -49,11 +49,13 @@ class Engine:
     def resume(self):
         r"""Get the initial value of gradio components and restores training status if necessary."""
         user_config = load_config() if not self.demo_mode else {}  # do not use config in demo mode
-        lang = user_config.get("lang", None) or "en"
+        lang = user_config.get("lang") or "en"
         init_dict = {"top.lang": {"value": lang}, "infer.chat_box": {"visible": self.chatter.loaded}}
 
         if not self.pure_chat:
             current_time = get_time()
+            hub_name = user_config.get("hub_name") or "huggingface"
+            init_dict["top.hub_name"] = {"value": hub_name}
             init_dict["train.current_time"] = {"value": current_time}
             init_dict["train.output_dir"] = {"value": f"train_{current_time}"}
             init_dict["train.config_path"] = {"value": f"{current_time}.yaml"}
```

```diff
@@ -22,6 +22,7 @@ from .components import (
     create_chat_box,
     create_eval_tab,
     create_export_tab,
+    create_footer,
     create_infer_tab,
     create_top,
     create_train_tab,
@@ -38,15 +39,13 @@ def create_ui(demo_mode: bool = False) -> "gr.Blocks":
     engine = Engine(demo_mode=demo_mode, pure_chat=False)
     hostname = os.getenv("HOSTNAME", os.getenv("COMPUTERNAME", platform.node())).split(".")[0]
 
-    with gr.Blocks(title=f"LLaMA Board ({hostname})", css=CSS) as demo:
+    with gr.Blocks(title=f"LLaMA Factory ({hostname})", css=CSS) as demo:
+        title = gr.HTML()
+        subtitle = gr.HTML()
         if demo_mode:
-            gr.HTML("<h1><center>LLaMA Board: A One-stop Web UI for Getting Started with LLaMA Factory</center></h1>")
-            gr.HTML(
-                '<h3><center>Visit <a href="https://github.com/hiyouga/LLaMA-Factory" target="_blank">'
-                "LLaMA Factory</a> for details.</center></h3>"
-            )
             gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button")
 
+        engine.manager.add_elems("head", {"title": title, "subtitle": subtitle})
         engine.manager.add_elems("top", create_top())
         lang: gr.Dropdown = engine.manager.get_elem_by_id("top.lang")
 
@@ -63,6 +62,7 @@ def create_ui(demo_mode: bool = False) -> "gr.Blocks":
         with gr.Tab("Export"):
             engine.manager.add_elems("export", create_export_tab(engine))
 
+        engine.manager.add_elems("footer", create_footer())
         demo.load(engine.resume, outputs=engine.manager.get_elem_list(), concurrency_limit=None)
         lang.change(engine.change_lang, [lang], engine.manager.get_elem_list(), queue=False)
         lang.input(save_config, inputs=[lang], queue=False)
```
```diff
@@ -13,6 +13,55 @@
 # limitations under the License.
 
 LOCALES = {
+    "title": {
+        "en": {
+            "value": "<h1><center>🦙🏭LLaMA Factory: Unified Efficient Fine-Tuning of 100+ LLMs</center></h1>",
+        },
+        "ru": {
+            "value": "<h1><center>🦙🏭LLaMA Factory: Унифицированная эффективная тонкая настройка 100+ LLMs</center></h1>",
+        },
+        "zh": {
+            "value": "<h1><center>🦙🏭LLaMA Factory: 一站式大模型高效微调平台</center></h1>",
+        },
+        "ko": {
+            "value": "<h1><center>🦙🏭LLaMA Factory: 100+ LLMs를 위한 통합 효율적인 튜닝</center></h1>",
+        },
+        "ja": {
+            "value": "<h1><center>🦙🏭LLaMA Factory: 100+ LLMs の統合効率的なチューニング</center></h1>",
+        },
+    },
+    "subtitle": {
+        "en": {
+            "value": (
+                "<h3><center>Visit <a href='https://github.com/hiyouga/LLaMA-Factory' target='_blank'>"
+                "GitHub Page</a></center></h3>"
+            ),
+        },
+        "ru": {
+            "value": (
+                "<h3><center>Посетить <a href='https://github.com/hiyouga/LLaMA-Factory' target='_blank'>"
+                "страницу GitHub</a></center></h3>"
+            ),
+        },
+        "zh": {
+            "value": (
+                "<h3><center>访问 <a href='https://github.com/hiyouga/LLaMA-Factory' target='_blank'>"
+                "GitHub 主页</a></center></h3>"
+            ),
+        },
+        "ko": {
+            "value": (
+                "<h3><center><a href='https://github.com/hiyouga/LLaMA-Factory' target='_blank'>"
+                "GitHub 페이지</a>를 방문하세요.</center></h3>"
+            ),
+        },
+        "ja": {
+            "value": (
+                "<h3><center><a href='https://github.com/hiyouga/LLaMA-Factory' target='_blank'>"
+                "GitHub ページ</a>にアクセスする</center></h3>"
+            ),
+        },
+    },
     "lang": {
         "en": {
             "label": "Language",
@@ -74,6 +123,28 @@ LOCALES = {
             "info": "事前学習済みモデルへのパス、または Hugging Face のモデル識別子。",
         },
     },
+    "hub_name": {
+        "en": {
+            "label": "Hub name",
+            "info": "Choose the model download source.",
+        },
+        "ru": {
+            "label": "Имя хаба",
+            "info": "Выберите источник загрузки модели.",
+        },
+        "zh": {
+            "label": "模型下载源",
+            "info": "选择模型下载源。(网络受限环境推荐使用 ModelScope)",
+        },
+        "ko": {
+            "label": "모델 다운로드 소스",
+            "info": "모델 다운로드 소스를 선택하세요.",
+        },
+        "ja": {
+            "label": "モデルダウンロードソース",
+            "info": "モデルをダウンロードするためのソースを選択してください。",
+        },
+    },
     "finetuning_type": {
         "en": {
             "label": "Finetuning method",
@@ -2849,6 +2920,28 @@ LOCALES = {
             "value": "エクスポート",
         },
     },
+    "device_memory": {
+        "en": {
+            "label": "Device memory",
+            "info": "Current memory usage of the device (GB).",
+        },
+        "ru": {
+            "label": "Память устройства",
+            "info": "Текущая память на устройстве (GB).",
+        },
+        "zh": {
+            "label": "设备显存",
+            "info": "当前设备的显存(GB)。",
+        },
+        "ko": {
+            "label": "디바이스 메모리",
+            "info": "지금 사용 중인 기기 메모리 (GB).",
+        },
+        "ja": {
+            "label": "デバイスメモリ",
+            "info": "現在のデバイスのメモリ(GB)。",
+        },
+    },
 }
```
````diff
@@ -16,14 +16,13 @@ import json
 import os
 from collections.abc import Generator
 from copy import deepcopy
-from subprocess import Popen, TimeoutExpired
+from subprocess import PIPE, Popen, TimeoutExpired
 from typing import TYPE_CHECKING, Any, Optional
 
 from transformers.trainer import TRAINING_ARGS_NAME
-from transformers.utils import is_torch_npu_available
 
 from ..extras.constants import LLAMABOARD_CONFIG, MULTIMODAL_SUPPORTED_MODELS, PEFT_METHODS, TRAINING_STAGES
-from ..extras.misc import is_accelerator_available, torch_gc, use_ray
+from ..extras.misc import is_accelerator_available, torch_gc
 from ..extras.packages import is_gradio_available
 from .common import (
     DEFAULT_CACHE_DIR,
@@ -114,7 +113,7 @@ class Runner:
 
         return ""
 
-    def _finalize(self, lang: str, finish_info: str) -> str:
+    def _finalize(self, lang: str, finish_info: str) -> None:
         r"""Clean the cached memory and resets the runner."""
         finish_info = ALERTS["info_aborted"][lang] if self.aborted else finish_info
         gr.Info(finish_info)
@@ -123,7 +122,6 @@ class Runner:
         self.running = False
         self.running_data = None
         torch_gc()
-        return finish_info
 
     def _parse_train_args(self, data: dict["Component", Any]) -> dict[str, Any]:
         r"""Build and validate the training arguments."""
@@ -314,11 +312,13 @@ class Runner:
             max_samples=int(get("eval.max_samples")),
             per_device_eval_batch_size=get("eval.batch_size"),
             predict_with_generate=True,
+            report_to="none",
             max_new_tokens=get("eval.max_new_tokens"),
             top_p=get("eval.top_p"),
             temperature=get("eval.temperature"),
             output_dir=get_save_dir(model_name, finetuning_type, get("eval.output_dir")),
             trust_remote_code=True,
+            ddp_timeout=180000000,
         )
 
         if get("eval.predict"):
@@ -375,7 +375,7 @@ class Runner:
             env["FORCE_TORCHRUN"] = "1"
 
         # NOTE: DO NOT USE shell=True to avoid security risk
-        self.trainer = Popen(["llamafactory-cli", "train", save_cmd(args)], env=env)
+        self.trainer = Popen(["llamafactory-cli", "train", save_cmd(args)], env=env, stderr=PIPE, text=True)
         yield from self.monitor()
 
     def _build_config_dict(self, data: dict["Component", Any]) -> dict[str, Any]:
@@ -417,7 +417,8 @@ class Runner:
         swanlab_link = self.manager.get_elem_by_id("train.swanlab_link") if self.do_train else None
 
         running_log = ""
-        while self.trainer is not None:
+        return_code = -1
+        while return_code == -1:
             if self.aborted:
                 yield {
                     output_box: ALERTS["info_aborting"][lang],
@@ -436,27 +437,26 @@ class Runner:
                     return_dict[swanlab_link] = running_info["swanlab_link"]
 
                 yield return_dict
 
             try:
-                self.trainer.wait(2)
-                self.trainer = None
+                stderr = self.trainer.communicate(timeout=2)[1]
+                return_code = self.trainer.returncode
             except TimeoutExpired:
                 continue
 
-        if self.do_train:
-            if os.path.exists(os.path.join(output_path, TRAINING_ARGS_NAME)) or use_ray():
-                finish_info = ALERTS["info_finished"][lang]
-            else:
-                finish_info = ALERTS["err_failed"][lang]
-        else:
-            if os.path.exists(os.path.join(output_path, "all_results.json")) or use_ray():
-                finish_info = load_eval_results(os.path.join(output_path, "all_results.json"))
-            else:
-                finish_info = ALERTS["err_failed"][lang]
+        if return_code == 0 or self.aborted:
+            finish_info = ALERTS["info_finished"][lang]
+            if self.do_train:
+                finish_log = ALERTS["info_finished"][lang] + "\n\n" + running_log
+            else:
+                finish_log = load_eval_results(os.path.join(output_path, "all_results.json")) + "\n\n" + running_log
+        else:
+            print(stderr)
+            finish_info = ALERTS["err_failed"][lang]
+            finish_log = ALERTS["err_failed"][lang] + f" Exit code: {return_code}\n\n```\n{stderr}\n```\n"
 
-        return_dict = {
-            output_box: self._finalize(lang, finish_info) + "\n\n" + running_log,
-            progress_bar: gr.Slider(visible=False),
-        }
+        self._finalize(lang, finish_info)
+        return_dict = {output_box: finish_log, progress_bar: gr.Slider(visible=False)}
         yield return_dict
 
     def save_args(self, data):
````
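The reworked monitor loop polls the child with a 2-second `communicate` timeout so the UI can keep streaming progress between checks, then judges success by the actual exit code instead of probing for output files. A standalone sketch of that polling pattern; a trivial child command stands in for the training process:

```python
from subprocess import PIPE, Popen, TimeoutExpired

proc = Popen(
    ["python", "-c", "import sys, time; time.sleep(3); sys.stderr.write('done')"],
    stderr=PIPE,
    text=True,
)

return_code = -1
while return_code == -1:
    try:
        # communicate raises TimeoutExpired until the child exits; retrying is safe.
        stderr = proc.communicate(timeout=2)[1]
        return_code = proc.returncode
    except TimeoutExpired:
        print("still running; refresh the progress bar here")

print(return_code, stderr)
```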
```diff
@@ -110,8 +110,8 @@ def test_glm4_function_formatter():
 def test_glm4_tool_formatter():
     formatter = ToolFormatter(tool_format="glm4")
     assert formatter.apply(content=json.dumps(TOOLS)) == [
-        "你是一个名为 ChatGLM 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,"
-        "你的任务是针对用户的问题和要求提供适当的答复和支持。# 可用工具\n\n"
+        "你是一个名为 ChatGLM 的人工智能助手。你是基于智谱 AI 公司训练的语言模型 GLM-4 模型开发的,"
+        "你的任务是针对用户的问题和要求提供适当的答复和支持。\n\n# 可用工具\n\n"
         f"## test_tool\n\n{json.dumps(TOOLS[0], indent=4, ensure_ascii=False)}\n"
         "在调用上述函数时,请使用 Json 格式表示调用的参数。"
     ]
```

```diff
@@ -238,7 +238,6 @@ def test_llama4_plugin():
     _check_plugin(**check_inputs)
 
 
-@pytest.mark.skipif(not is_transformers_version_greater_than("4.47.0"), reason="Requires transformers>=4.47.0")
 def test_llava_plugin():
     image_seqlen = 576
     tokenizer_module = _load_tokenizer_module(model_name_or_path="llava-hf/llava-1.5-7b-hf")
```
```diff
@@ -226,6 +226,19 @@ def test_gemma_template(use_fast: bool):
     _check_template("google/gemma-3-4b-it", "gemma", prompt_str, answer_str, use_fast)
 
 
+@pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
+@pytest.mark.parametrize("use_fast", [True, False])
+def test_gemma2_template(use_fast: bool):
+    prompt_str = (
+        f"<bos><start_of_turn>user\n{MESSAGES[0]['content']}<end_of_turn>\n"
+        f"<start_of_turn>model\n{MESSAGES[1]['content']}<end_of_turn>\n"
+        f"<start_of_turn>user\n{MESSAGES[2]['content']}<end_of_turn>\n"
+        "<start_of_turn>model\n"
+    )
+    answer_str = f"{MESSAGES[3]['content']}<end_of_turn>\n"
+    _check_template("google/gemma-2-2b-it", "gemma2", prompt_str, answer_str, use_fast)
+
+
 @pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
 @pytest.mark.parametrize("use_fast", [True, False])
 def test_llama3_template(use_fast: bool):
```
```diff
@@ -1,2 +1,2 @@
 # change if test fails or cache is outdated
-0.9.3.108
+0.9.4.100
```