mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2026-02-28 08:46:00 +08:00
Compare commits
22 Commits
v0.9.4
...
df4c45c9ae
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
df4c45c9ae | ||
|
|
af3b6f5418 | ||
|
|
5aacbe8434 | ||
|
|
5fb5d7ebd3 | ||
|
|
03a70ba8dd | ||
|
|
5cfd804b59 | ||
|
|
4c1eb922e2 | ||
|
|
958fb523a2 | ||
|
|
b4e051bea4 | ||
|
|
d43e1007e8 | ||
|
|
f89d9367e5 | ||
|
|
d22de0d4bf | ||
|
|
ea0b4e2466 | ||
|
|
e944dc442c | ||
|
|
68119e5522 | ||
|
|
f60a6e3d01 | ||
|
|
81b8a50aa5 | ||
|
|
8600530002 | ||
|
|
9ae62c6fc0 | ||
|
|
0087bc253b | ||
|
|
355d5c5e5a | ||
|
|
6fe6bd290b |
11
.github/workflows/tests.yml
vendored
11
.github/workflows/tests.yml
vendored
@@ -54,6 +54,7 @@ jobs:
|
||||
env:
|
||||
HF_TOKEN: ${{ secrets.HF_TOKEN }}
|
||||
OS_NAME: ${{ matrix.os }}
|
||||
UV_NO_SYNC: 1
|
||||
|
||||
steps:
|
||||
- name: Checkout
|
||||
@@ -70,7 +71,8 @@ jobs:
|
||||
run: |
|
||||
uv venv
|
||||
uv pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
|
||||
uv pip install -e ".[dev]"
|
||||
uv pip install -e .
|
||||
uv pip install -r requirements/dev.txt
|
||||
|
||||
- name: Install transformers
|
||||
if: ${{ matrix.transformers }}
|
||||
@@ -87,25 +89,18 @@ jobs:
|
||||
- name: Check quality
|
||||
run: |
|
||||
make style && make quality
|
||||
env:
|
||||
UV_NO_SYNC: 1
|
||||
|
||||
- name: Check license
|
||||
run: |
|
||||
make license
|
||||
env:
|
||||
UV_NO_SYNC: 1
|
||||
|
||||
- name: Check build
|
||||
run: |
|
||||
make build
|
||||
env:
|
||||
UV_NO_SYNC: 1
|
||||
|
||||
- name: Test with pytest
|
||||
run: |
|
||||
make test
|
||||
env:
|
||||
UV_NO_SYNC: 1
|
||||
HF_HOME: ${{ runner.temp }}/huggingface
|
||||
HF_HUB_OFFLINE: "${{ steps.hf-hub-cache.outputs.cache-hit == 'true' && '1' || '0' }}"
|
||||
|
||||
27
.github/workflows/tests_cuda.yml
vendored
27
.github/workflows/tests_cuda.yml
vendored
@@ -35,6 +35,13 @@ jobs:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}-${{ matrix.os }}-${{ matrix.python }}
|
||||
cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
|
||||
|
||||
env:
|
||||
HF_HOME: "${{ github.workspace }}/../.runner_cache/huggingface"
|
||||
UV_CACHE_DIR: "${{ github.workspace }}/../.runner_cache/uv"
|
||||
HF_TOKEN: ${{ secrets.HF_TOKEN }}
|
||||
OS_NAME: ${{ matrix.os }}
|
||||
UV_NO_SYNC: 1
|
||||
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
@@ -52,37 +59,21 @@ jobs:
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
uv venv
|
||||
uv pip install -e ".[dev]"
|
||||
|
||||
- name: Cache HuggingFace models
|
||||
id: hf-hub-cache
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ${{ runner.temp }}/huggingface
|
||||
key: hf-cache-${{ runner.os }}-${{ hashFiles('tests/version.txt') }}
|
||||
uv pip install -e .
|
||||
uv pip install -r requirements/dev.txt
|
||||
|
||||
- name: Check quality
|
||||
run: |
|
||||
make style && make quality
|
||||
env:
|
||||
UV_NO_SYNC: 1
|
||||
|
||||
- name: Check license
|
||||
run: |
|
||||
make license
|
||||
env:
|
||||
UV_NO_SYNC: 1
|
||||
|
||||
- name: Check build
|
||||
run: |
|
||||
make build
|
||||
env:
|
||||
UV_NO_SYNC: 1
|
||||
|
||||
- name: Test with pytest
|
||||
run: |
|
||||
make test
|
||||
env:
|
||||
UV_NO_SYNC: 1
|
||||
HF_HOME: ${{ runner.temp }}/huggingface
|
||||
HF_HUB_OFFLINE: "${{ steps.hf-hub-cache.outputs.cache-hit == 'true' && '1' || '0' }}"
|
||||
|
||||
23
.github/workflows/tests_npu.yml
vendored
23
.github/workflows/tests_npu.yml
vendored
@@ -43,6 +43,7 @@ jobs:
|
||||
HF_ENDPOINT: https://hf-mirror.com
|
||||
HF_TOKEN: ${{ secrets.HF_TOKEN }}
|
||||
OS_NAME: ${{ matrix.os }}
|
||||
UV_NO_SYNC: 1
|
||||
|
||||
steps:
|
||||
- name: Checkout
|
||||
@@ -58,8 +59,9 @@ jobs:
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
uv venv
|
||||
uv pip install torch-npu==${{matrix.pytorch_npu}}
|
||||
uv pip install -e ".[dev]"
|
||||
uv pip install -r requirements/npu.txt
|
||||
uv pip install -e .
|
||||
uv pip install -r requirements/dev.txt
|
||||
|
||||
- name: Install node
|
||||
run: |
|
||||
@@ -68,35 +70,18 @@ jobs:
|
||||
curl -fsSL https://deb.nodesource.com/setup_20.x | bash -
|
||||
apt-get install -y nodejs
|
||||
|
||||
- name: Cache files
|
||||
id: hf-hub-cache
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ${{ runner.temp }}/huggingface
|
||||
key: huggingface-${{ matrix.os }}-${{ matrix.python }}-${{ hashFiles('tests/version.txt') }}
|
||||
|
||||
- name: Check quality
|
||||
run: |
|
||||
make style && make quality
|
||||
env:
|
||||
UV_NO_SYNC: 1
|
||||
|
||||
- name: Check license
|
||||
run: |
|
||||
make license
|
||||
env:
|
||||
UV_NO_SYNC: 1
|
||||
|
||||
- name: Check build
|
||||
run: |
|
||||
make build
|
||||
env:
|
||||
UV_NO_SYNC: 1
|
||||
|
||||
- name: Test with pytest
|
||||
run: |
|
||||
make test
|
||||
env:
|
||||
UV_NO_SYNC: 1
|
||||
HF_HOME: /root/.cache/huggingface
|
||||
HF_HUB_OFFLINE: "${{ steps.hf-hub-cache.outputs.cache-hit == 'true' && '1' || '0' }}"
|
||||
|
||||
61
README.md
61
README.md
@@ -92,7 +92,7 @@ Read technical notes:
|
||||
|
||||
## Features
|
||||
|
||||
- **Various models**: LLaMA, LLaVA, Mistral, Mixtral-MoE, Qwen, Qwen2-VL, DeepSeek, Yi, Gemma, ChatGLM, Phi, etc.
|
||||
- **Various models**: LLaMA, LLaVA, Mistral, Mixtral-MoE, Qwen3, Qwen3-VL, DeepSeek, Gemma, GLM, Phi, etc.
|
||||
- **Integrated methods**: (Continuous) pre-training, (multimodal) supervised fine-tuning, reward modeling, PPO, DPO, KTO, ORPO, etc.
|
||||
- **Scalable resources**: 16-bit full-tuning, freeze-tuning, LoRA and 2/3/4/5/6/8-bit QLoRA via AQLM/AWQ/GPTQ/LLM.int8/HQQ/EETQ.
|
||||
- **Advanced algorithms**: [GaLore](https://github.com/jiaweizzhao/GaLore), [BAdam](https://github.com/Ledzy/BAdam), [APOLLO](https://github.com/zhuhanqing/APOLLO), [Adam-mini](https://github.com/zyushun/Adam-mini), [Muon](https://github.com/KellerJordan/Muon), [OFT](https://github.com/huggingface/peft/tree/main/src/peft/tuners/oft), DoRA, LongLoRA, LLaMA Pro, Mixture-of-Depths, LoRA+, LoftQ and PiSSA.
|
||||
@@ -279,11 +279,10 @@ Read technical notes:
|
||||
| Model | Model size | Template |
|
||||
| ----------------------------------------------------------------- | -------------------------------- | -------------------- |
|
||||
| [BLOOM/BLOOMZ](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - |
|
||||
| [Command R](https://huggingface.co/CohereForAI) | 35B/104B | cohere |
|
||||
| [DeepSeek (LLM/Code/MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B/236B | deepseek |
|
||||
| [DeepSeek 3-3.2](https://huggingface.co/deepseek-ai) | 236B/671B | deepseek3 |
|
||||
| [DeepSeek R1 (Distill)](https://huggingface.co/deepseek-ai) | 1.5B/7B/8B/14B/32B/70B/671B | deepseekr1 |
|
||||
| [ERNIE-4.5](https://huggingface.co/baidu) | 0.3B/21B/300B | ernie/ernie_nothink |
|
||||
| [ERNIE-4.5](https://huggingface.co/baidu) | 0.3B/21B/300B | ernie_nothink |
|
||||
| [Falcon/Falcon H1](https://huggingface.co/tiiuae) | 0.5B/1.5B/3B/7B/11B/34B/40B/180B | falcon/falcon_h1 |
|
||||
| [Gemma/Gemma 2/CodeGemma](https://huggingface.co/google) | 2B/7B/9B/27B | gemma/gemma2 |
|
||||
| [Gemma 3/Gemma 3n](https://huggingface.co/google) | 270M/1B/4B/6B/8B/12B/27B | gemma3/gemma3n |
|
||||
@@ -295,9 +294,10 @@ Read technical notes:
|
||||
| [Hunyuan (MT)](https://huggingface.co/tencent/) | 7B | hunyuan |
|
||||
| [InternLM 2-3](https://huggingface.co/internlm) | 7B/8B/20B | intern2 |
|
||||
| [InternVL 2.5-3.5](https://huggingface.co/OpenGVLab) | 1B/2B/4B/8B/14B/30B/38B/78B/241B | intern_vl |
|
||||
| [InternLM/Intern-S1-mini](https://huggingface.co/internlm/) | 8B | intern_s1 |
|
||||
| [Intern-S1-mini](https://huggingface.co/internlm/) | 8B | intern_s1 |
|
||||
| [Kimi-VL](https://huggingface.co/moonshotai) | 16B | kimi_vl |
|
||||
| [Ling 2.0 (mini/flash)](https://huggingface.co/inclusionAI) | 16B/100B | bailing_v2 |
|
||||
| [LFM 2.5 (VL)](https://huggingface.co/LiquidAI) | 1.2B/1.6B | lfm2/lfm2_vl |
|
||||
| [Llama](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - |
|
||||
| [Llama 2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
|
||||
| [Llama 3-3.3](https://huggingface.co/meta-llama) | 1B/3B/8B/70B | llama3 |
|
||||
@@ -307,18 +307,17 @@ Read technical notes:
|
||||
| [LLaVA-NeXT](https://huggingface.co/llava-hf) | 7B/8B/13B/34B/72B/110B | llava_next |
|
||||
| [LLaVA-NeXT-Video](https://huggingface.co/llava-hf) | 7B/34B | llava_next_video |
|
||||
| [MiMo](https://huggingface.co/XiaomiMiMo) | 7B/309B | mimo/mimo_v2 |
|
||||
| [MiniCPM 1-4.1](https://huggingface.co/openbmb) | 0.5B/1B/2B/4B/8B | cpm/cpm3/cpm4 |
|
||||
| [MiniCPM 4](https://huggingface.co/openbmb) | 0.5B/8B | cpm4 |
|
||||
| [MiniCPM-o-2.6/MiniCPM-V-2.6](https://huggingface.co/openbmb) | 8B | minicpm_o/minicpm_v |
|
||||
| [MiniMax-M1/MiniMax-M2](https://huggingface.co/MiniMaxAI/models) | 229B/456B | minimax1/minimax2 |
|
||||
| [Ministral 3](https://huggingface.co/mistralai) | 3B/8B/14B | ministral3 |
|
||||
| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
|
||||
| [OLMo](https://huggingface.co/allenai) | 1B/7B | - |
|
||||
| [PaliGemma/PaliGemma2](https://huggingface.co/google) | 3B/10B/28B | paligemma |
|
||||
| [Phi-3/Phi-3.5](https://huggingface.co/microsoft) | 4B/14B | phi |
|
||||
| [Phi-3-small](https://huggingface.co/microsoft) | 7B | phi_small |
|
||||
| [Phi-4](https://huggingface.co/microsoft) | 14B | phi4 |
|
||||
| [Phi-4-mini/Phi-4](https://huggingface.co/microsoft) | 3.8B/14B | phi4_mini/phi4 |
|
||||
| [Pixtral](https://huggingface.co/mistralai) | 12B | pixtral |
|
||||
| [Qwen (1-2.5) (Code/Math/MoE/QwQ)](https://huggingface.co/Qwen) | 0.5B/1.5B/3B/7B/14B/32B/72B/110B | qwen |
|
||||
| [Qwen2 (Code/Math/MoE/QwQ)](https://huggingface.co/Qwen) | 0.5B/1.5B/3B/7B/14B/32B/72B/110B | qwen |
|
||||
| [Qwen3 (MoE/Instruct/Thinking/Next)](https://huggingface.co/Qwen) | 0.6B/1.7B/4B/8B/14B/32B/80B/235B | qwen3/qwen3_nothink |
|
||||
| [Qwen2-Audio](https://huggingface.co/Qwen) | 7B | qwen2_audio |
|
||||
| [Qwen2.5-Omni](https://huggingface.co/Qwen) | 3B/7B | qwen2_omni |
|
||||
@@ -327,8 +326,6 @@ Read technical notes:
|
||||
| [Qwen3-VL](https://huggingface.co/Qwen) | 2B/4B/8B/30B/32B/235B | qwen3_vl |
|
||||
| [Seed (OSS/Coder)](https://huggingface.co/ByteDance-Seed) | 8B/36B | seed_oss/seed_coder |
|
||||
| [StarCoder 2](https://huggingface.co/bigcode) | 3B/7B/15B | - |
|
||||
| [VibeThinker-1.5B](https://huggingface.co/WeiboAI) | 1.5B | qwen3 |
|
||||
| [Yi/Yi-1.5 (Code)](https://huggingface.co/01-ai) | 1.5B/6B/9B/34B | yi |
|
||||
| [Yuan 2](https://huggingface.co/IEITYuan) | 2B/51B/102B | yuan |
|
||||
|
||||
> [!NOTE]
|
||||
@@ -514,12 +511,13 @@ huggingface-cli login
|
||||
#### Install from Source
|
||||
|
||||
```bash
|
||||
git clone --depth 1 https://github.com/hiyouga/LLaMA-Factory.git
|
||||
cd LLaMA-Factory
|
||||
pip install -e ".[metrics]"
|
||||
git clone --depth 1 https://github.com/hiyouga/LlamaFactory.git
|
||||
cd LlamaFactory
|
||||
pip install -e .
|
||||
pip install -r requirements/metrics.txt
|
||||
```
|
||||
|
||||
Optional dependencies available: `metrics`, `deepspeed`. Install with: `pip install -e ".[metrics,deepspeed]"`
|
||||
Optional dependencies available: `metrics`, `deepspeed`. Install with: `pip install -e . && pip install -r requirements/metrics.txt -r requirements/deepspeed.txt`
|
||||
|
||||
Additional dependencies for specific features are available in `examples/requirements/`.
|
||||
|
||||
@@ -577,36 +575,21 @@ To enable FlashAttention-2 on the Windows platform, please use the script from [
|
||||
|
||||
<details><summary>For Ascend NPU users</summary>
|
||||
|
||||
To install LLaMA Factory on Ascend NPU devices, please upgrade Python to version 3.10 or higher: `pip install -e . torch-npu==2.7.1`. Additionally, you need to install the **[Ascend CANN Toolkit and Kernels](https://www.hiascend.com/developer/download/community/result?module=cann)**. Please follow the [installation tutorial](https://www.hiascend.com/document/detail/en/CANNCommunityEdition/600alphaX/softwareinstall/instg/atlasdeploy_03_0031.html) or use the following commands:
|
||||
To install LLaMA Factory on Ascend NPU devices, please upgrade Python to version 3.10 or higher: `pip install -r requirements/npu.txt`. Additionally, you need to install the **Ascend CANN Toolkit and Kernels**. Please follow the [installation tutorial](https://llamafactory.readthedocs.io/en/latest/advanced/npu_installation.html).
|
||||
|
||||
|
||||
You can also download the pre-built Docker images:
|
||||
|
||||
```bash
|
||||
# replace the url according to your CANN version and devices
|
||||
# install CANN Toolkit
|
||||
wget https://ascend-repo.obs.cn-east-2.myhuaweicloud.com/Milan-ASL/Milan-ASL%20V100R001C20SPC702/Ascend-cann-toolkit_8.0.0.alpha002_linux-"$(uname -i)".run
|
||||
bash Ascend-cann-toolkit_8.0.0.alpha002_linux-"$(uname -i)".run --install
|
||||
# Docker Hub
|
||||
docker pull hiyouga/llamafactory:latest-npu-a2
|
||||
docker pull hiyouga/llamafactory:latest-npu-a3
|
||||
|
||||
# install CANN Kernels
|
||||
wget https://ascend-repo.obs.cn-east-2.myhuaweicloud.com/Milan-ASL/Milan-ASL%20V100R001C20SPC702/Ascend-cann-kernels-910b_8.0.0.alpha002_linux-"$(uname -i)".run
|
||||
bash Ascend-cann-kernels-910b_8.0.0.alpha002_linux-"$(uname -i)".run --install
|
||||
|
||||
# set env variables
|
||||
source /usr/local/Ascend/ascend-toolkit/set_env.sh
|
||||
# quay.io
|
||||
docker pull quay.io/ascend/llamafactory:latest-npu-a2
|
||||
docker pull quay.io/ascend/llamafactory:latest-npu-a3
|
||||
```
|
||||
|
||||
| Requirement | Minimum | Recommend |
|
||||
| ------------ | ------- | -------------- |
|
||||
| CANN | 8.0.RC1 | 8.0.0.alpha002 |
|
||||
| torch | 2.1.0 | 2.7.1 |
|
||||
| torch-npu | 2.1.0 | 2.7.1 |
|
||||
| deepspeed | 0.13.2 | 0.13.2 |
|
||||
| vllm-ascend | - | 0.7.3 |
|
||||
|
||||
Remember to use `ASCEND_RT_VISIBLE_DEVICES` instead of `CUDA_VISIBLE_DEVICES` to specify the device to use.
|
||||
|
||||
If you cannot infer model on NPU devices, try setting `do_sample: false` in the configurations.
|
||||
|
||||
Download the pre-built Docker images: [32GB](http://mirrors.cn-central-221.ovaijisuan.com/detail/130.html) | [64GB](http://mirrors.cn-central-221.ovaijisuan.com/detail/131.html)
|
||||
|
||||
#### Install BitsAndBytes
|
||||
|
||||
To use QLoRA based on bitsandbytes on Ascend NPU, please follow these 3 steps:
|
||||
|
||||
60
README_zh.md
60
README_zh.md
@@ -94,7 +94,7 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
|
||||
|
||||
## 项目特色
|
||||
|
||||
- **多种模型**:LLaMA、LLaVA、Mistral、Mixtral-MoE、Qwen、Qwen2-VL、DeepSeek、Yi、Gemma、ChatGLM、Phi 等等。
|
||||
- **多种模型**:LLaMA、LLaVA、Mistral、Mixtral-MoE、Qwen3、Qwen3-VL、DeepSeek、Gemma、GLM、Phi 等等。
|
||||
- **集成方法**:(增量)预训练、(多模态)指令监督微调、奖励模型训练、PPO 训练、DPO 训练、KTO 训练、ORPO 训练等等。
|
||||
- **多种精度**:16 比特全参数微调、冻结微调、LoRA 微调和基于 AQLM/AWQ/GPTQ/LLM.int8/HQQ/EETQ 的 2/3/4/5/6/8 比特 QLoRA 微调。
|
||||
- **先进算法**:[GaLore](https://github.com/jiaweizzhao/GaLore)、[BAdam](https://github.com/Ledzy/BAdam)、[APOLLO](https://github.com/zhuhanqing/APOLLO)、[Adam-mini](https://github.com/zyushun/Adam-mini)、[Muon](https://github.com/KellerJordan/Muon)、[OFT](https://github.com/huggingface/peft/tree/main/src/peft/tuners/oft)、DoRA、LongLoRA、LLaMA Pro、Mixture-of-Depths、LoRA+、LoftQ 和 PiSSA。
|
||||
@@ -281,11 +281,10 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
|
||||
| 模型名 | 参数量 | Template |
|
||||
| ----------------------------------------------------------------- | -------------------------------- | -------------------- |
|
||||
| [BLOOM/BLOOMZ](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - |
|
||||
| [Command R](https://huggingface.co/CohereForAI) | 35B/104B | cohere |
|
||||
| [DeepSeek (LLM/Code/MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B/236B | deepseek |
|
||||
| [DeepSeek 3-3.2](https://huggingface.co/deepseek-ai) | 236B/671B | deepseek3 |
|
||||
| [DeepSeek R1 (Distill)](https://huggingface.co/deepseek-ai) | 1.5B/7B/8B/14B/32B/70B/671B | deepseekr1 |
|
||||
| [ERNIE-4.5](https://huggingface.co/baidu) | 0.3B/21B/300B | ernie/ernie_nothink |
|
||||
| [ERNIE-4.5](https://huggingface.co/baidu) | 0.3B/21B/300B | ernie_nothink |
|
||||
| [Falcon/Falcon H1](https://huggingface.co/tiiuae) | 0.5B/1.5B/3B/7B/11B/34B/40B/180B | falcon/falcon_h1 |
|
||||
| [Gemma/Gemma 2/CodeGemma](https://huggingface.co/google) | 2B/7B/9B/27B | gemma/gemma2 |
|
||||
| [Gemma 3/Gemma 3n](https://huggingface.co/google) | 270M/1B/4B/6B/8B/12B/27B | gemma3/gemma3n |
|
||||
@@ -297,9 +296,10 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
|
||||
| [Hunyuan (MT)](https://huggingface.co/tencent/) | 7B | hunyuan |
|
||||
| [InternLM 2-3](https://huggingface.co/internlm) | 7B/8B/20B | intern2 |
|
||||
| [InternVL 2.5-3.5](https://huggingface.co/OpenGVLab) | 1B/2B/4B/8B/14B/30B/38B/78B/241B | intern_vl |
|
||||
| [InternLM/Intern-S1-mini](https://huggingface.co/internlm/) | 8B | intern_s1 |
|
||||
| [Intern-S1-mini](https://huggingface.co/internlm/) | 8B | intern_s1 |
|
||||
| [Kimi-VL](https://huggingface.co/moonshotai) | 16B | kimi_vl |
|
||||
| [Ling 2.0 (mini/flash)](https://huggingface.co/inclusionAI) | 16B/100B | bailing_v2 |
|
||||
| [LFM 2.5 (VL)](https://huggingface.co/LiquidAI) | 1.2B/1.6B | lfm2/lfm2_vl |
|
||||
| [Llama](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - |
|
||||
| [Llama 2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
|
||||
| [Llama 3-3.3](https://huggingface.co/meta-llama) | 1B/3B/8B/70B | llama3 |
|
||||
@@ -309,18 +309,17 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
|
||||
| [LLaVA-NeXT](https://huggingface.co/llava-hf) | 7B/8B/13B/34B/72B/110B | llava_next |
|
||||
| [LLaVA-NeXT-Video](https://huggingface.co/llava-hf) | 7B/34B | llava_next_video |
|
||||
| [MiMo](https://huggingface.co/XiaomiMiMo) | 7B/309B | mimo/mimo_v2 |
|
||||
| [MiniCPM 1-4.1](https://huggingface.co/openbmb) | 0.5B/1B/2B/4B/8B | cpm/cpm3/cpm4 |
|
||||
| [MiniCPM 4](https://huggingface.co/openbmb) | 0.5B/8B | cpm4 |
|
||||
| [MiniCPM-o-2.6/MiniCPM-V-2.6](https://huggingface.co/openbmb) | 8B | minicpm_o/minicpm_v |
|
||||
| [MiniMax-M1/MiniMax-M2](https://huggingface.co/MiniMaxAI/models) | 229B/456B | minimax1/minimax2 |
|
||||
| [Ministral 3](https://huggingface.co/mistralai) | 3B/8B/14B | ministral3 |
|
||||
| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
|
||||
| [OLMo](https://huggingface.co/allenai) | 1B/7B | - |
|
||||
| [PaliGemma/PaliGemma2](https://huggingface.co/google) | 3B/10B/28B | paligemma |
|
||||
| [Phi-3/Phi-3.5](https://huggingface.co/microsoft) | 4B/14B | phi |
|
||||
| [Phi-3-small](https://huggingface.co/microsoft) | 7B | phi_small |
|
||||
| [Phi-4](https://huggingface.co/microsoft) | 14B | phi4 |
|
||||
| [Phi-4-mini/Phi-4](https://huggingface.co/microsoft) | 3.8B/14B | phi4_mini/phi4 |
|
||||
| [Pixtral](https://huggingface.co/mistralai) | 12B | pixtral |
|
||||
| [Qwen (1-2.5) (Code/Math/MoE/QwQ)](https://huggingface.co/Qwen) | 0.5B/1.5B/3B/7B/14B/32B/72B/110B | qwen |
|
||||
| [Qwen2 (Code/Math/MoE/QwQ)](https://huggingface.co/Qwen) | 0.5B/1.5B/3B/7B/14B/32B/72B/110B | qwen |
|
||||
| [Qwen3 (MoE/Instruct/Thinking/Next)](https://huggingface.co/Qwen) | 0.6B/1.7B/4B/8B/14B/32B/80B/235B | qwen3/qwen3_nothink |
|
||||
| [Qwen2-Audio](https://huggingface.co/Qwen) | 7B | qwen2_audio |
|
||||
| [Qwen2.5-Omni](https://huggingface.co/Qwen) | 3B/7B | qwen2_omni |
|
||||
@@ -329,8 +328,6 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
|
||||
| [Qwen3-VL](https://huggingface.co/Qwen) | 2B/4B/8B/30B/32B/235B | qwen3_vl |
|
||||
| [Seed (OSS/Coder)](https://huggingface.co/ByteDance-Seed) | 8B/36B | seed_oss/seed_coder |
|
||||
| [StarCoder 2](https://huggingface.co/bigcode) | 3B/7B/15B | - |
|
||||
| [VibeThinker-1.5B](https://huggingface.co/WeiboAI) | 1.5B | qwen3 |
|
||||
| [Yi/Yi-1.5 (Code)](https://huggingface.co/01-ai) | 1.5B/6B/9B/34B | yi |
|
||||
| [Yuan 2](https://huggingface.co/IEITYuan) | 2B/51B/102B | yuan |
|
||||
|
||||
> [!NOTE]
|
||||
@@ -516,12 +513,13 @@ huggingface-cli login
|
||||
#### 从源码安装
|
||||
|
||||
```bash
|
||||
git clone --depth 1 https://github.com/hiyouga/LLaMA-Factory.git
|
||||
cd LLaMA-Factory
|
||||
pip install -e ".[metrics]"
|
||||
git clone --depth 1 https://github.com/hiyouga/LlamaFactory.git
|
||||
cd LlamaFactory
|
||||
pip install -e .
|
||||
pip install -r requirements/metrics.txt
|
||||
```
|
||||
|
||||
可选的额外依赖项:`metrics`、`deepspeed`。使用 `pip install -e ".[metrics,deepspeed]"` 安装。
|
||||
可选的额外依赖项:`metrics`、`deepspeed`。使用 `pip install -e . && pip install -r requirements/metrics.txt -r requirements/deepspeed.txt` 安装。
|
||||
|
||||
其他可选依赖项请参考 `examples/requirements/` 目录下的文件。
|
||||
|
||||
@@ -579,36 +577,20 @@ pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/downl
|
||||
|
||||
<details><summary>昇腾 NPU 用户指南</summary>
|
||||
|
||||
在昇腾 NPU 设备上安装 LLaMA Factory 时,请升级 Python 到 3.10 及以上,并需要指定额外依赖项,使用 `pip install -e . torch-npu==2.7.1` 命令安装。此外,还需要安装 **[Ascend CANN Toolkit 与 Kernels](https://www.hiascend.com/developer/download/community/result?module=cann)**,安装方法请参考[安装教程](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/80RC2alpha002/quickstart/quickstart/quickstart_18_0004.html)或使用以下命令:
|
||||
在昇腾 NPU 设备上安装 LLaMA Factory 时,请升级 Python 到 3.10 及以上,并需要指定额外依赖项,使用 `pip install -r requirements/npu.txt` 命令安装。此外,还需要安装 **Ascend CANN Toolkit 与 Kernels**,安装方法请参考[安装教程](https://llamafactory.readthedocs.io/zh-cn/latest/advanced/npu_installation.html)。
|
||||
|
||||
您可以直接下载预安装的最新docker镜像:
|
||||
|
||||
```bash
|
||||
# 请替换 URL 为 CANN 版本和设备型号对应的 URL
|
||||
# 安装 CANN Toolkit
|
||||
wget https://ascend-repo.obs.cn-east-2.myhuaweicloud.com/Milan-ASL/Milan-ASL%20V100R001C17SPC701/Ascend-cann-toolkit_8.0.RC1.alpha001_linux-"$(uname -i)".run
|
||||
bash Ascend-cann-toolkit_8.0.RC1.alpha001_linux-"$(uname -i)".run --install
|
||||
# Docker Hub
|
||||
docker pull hiyouga/llamafactory:latest-npu-a2
|
||||
docker pull hiyouga/llamafactory:latest-npu-a3
|
||||
|
||||
# 安装 CANN Kernels
|
||||
wget https://ascend-repo.obs.cn-east-2.myhuaweicloud.com/Milan-ASL/Milan-ASL%20V100R001C17SPC701/Ascend-cann-kernels-910b_8.0.RC1.alpha001_linux.run
|
||||
bash Ascend-cann-kernels-910b_8.0.RC1.alpha001_linux.run --install
|
||||
|
||||
# 设置环境变量
|
||||
source /usr/local/Ascend/ascend-toolkit/set_env.sh
|
||||
# quay.io
|
||||
docker pull quay.io/ascend/llamafactory:latest-npu-a2
|
||||
docker pull quay.io/ascend/llamafactory:latest-npu-a3
|
||||
```
|
||||
|
||||
| 依赖项 | 至少 | 推荐 |
|
||||
| ------------ | ------- | -------------- |
|
||||
| CANN | 8.0.RC1 | 8.0.0.alpha002 |
|
||||
| torch | 2.1.0 | 2.7.1 |
|
||||
| torch-npu | 2.1.0 | 2.7.1 |
|
||||
| deepspeed | 0.13.2 | 0.13.2 |
|
||||
| vllm-ascend | - | 0.7.3 |
|
||||
|
||||
请使用 `ASCEND_RT_VISIBLE_DEVICES` 而非 `CUDA_VISIBLE_DEVICES` 来指定运算设备。
|
||||
|
||||
如果遇到无法正常推理的情况,请尝试设置 `do_sample: false`。
|
||||
|
||||
下载预构建 Docker 镜像:[32GB](http://mirrors.cn-central-221.ovaijisuan.com/detail/130.html) | [64GB](http://mirrors.cn-central-221.ovaijisuan.com/detail/131.html)
|
||||
|
||||
#### 安装 BitsAndBytes
|
||||
|
||||
如果要在 Ascend NPU 上进行基于 bitsandbytes 的 QLoRA 量化微调,请执行如下步骤:
|
||||
|
||||
File diff suppressed because one or more lines are too long
@@ -32,7 +32,8 @@ RUN pip config set global.index-url "${PIP_INDEX}" && \
|
||||
COPY . /app
|
||||
|
||||
# Install LLaMA Factory
|
||||
RUN pip install --no-cache-dir --no-build-isolation -e ".[metrics,deepspeed]"
|
||||
RUN pip install --no-cache-dir --no-build-isolation -e . && \
|
||||
pip install --no-cache-dir --no-build-isolation -r requirements/metrics.txt -r requirements/deepspeed.txt
|
||||
|
||||
# Rebuild flash attention
|
||||
RUN if [ "${INSTALL_FLASHATTN}" == "true" ]; then \
|
||||
|
||||
@@ -60,7 +60,8 @@ WORKDIR /app
|
||||
COPY . /app
|
||||
|
||||
# Install LLaMA Factory
|
||||
RUN pip install --no-cache-dir -e ".[metrics]" --no-build-isolation
|
||||
RUN pip install --no-cache-dir -e . --no-build-isolation && \
|
||||
pip install --no-cache-dir -r requirements/metrics.txt --no-build-isolation
|
||||
|
||||
RUN pip install "git+https://github.com/alibaba/roll.git#subdirectory=mcore_adapter"
|
||||
|
||||
|
||||
@@ -35,7 +35,8 @@ COPY . /app
|
||||
# Install torch-npu
|
||||
RUN pip uninstall -y torch torchvision torchaudio && \
|
||||
pip install --no-cache-dir "torch==2.7.1" "torch-npu==2.7.1" "torchvision==0.22.1" "torchaudio==2.7.1" --index-url "${PYTORCH_INDEX}" && \
|
||||
pip install --no-cache-dir -e ".[metrics]" --no-build-isolation
|
||||
pip install --no-cache-dir -e . --no-build-isolation && \
|
||||
pip install --no-cache-dir -r requirements/metrics.txt --no-build-isolation
|
||||
|
||||
# Set up volumes
|
||||
# VOLUME [ "/root/.cache/huggingface", "/app/shared_data", "/app/output" ]
|
||||
|
||||
@@ -34,7 +34,8 @@ COPY . /app
|
||||
|
||||
# Reinstall pytorch rocm and install LLaMA Factory
|
||||
RUN pip uninstall -y torch torchvision torchaudio && \
|
||||
pip install --no-cache-dir --no-build-isolation -e --pre ".[metrics,deepspeed]" --index-url "${PYTORCH_INDEX}"
|
||||
pip install --no-cache-dir --no-build-isolation -e --pre . --index-url "${PYTORCH_INDEX}" && \
|
||||
pip install --no-cache-dir --no-build-isolation -r requirements/metrics.txt -r requirements/deepspeed.txt --index-url "${PYTORCH_INDEX}"
|
||||
|
||||
# Rebuild flash attention
|
||||
RUN if [ "${INSTALL_FLASHATTN}" == "true" ]; then \
|
||||
|
||||
38
examples/extras/eaft/qwen25_05b_eaft_full.yaml
Normal file
38
examples/extras/eaft/qwen25_05b_eaft_full.yaml
Normal file
@@ -0,0 +1,38 @@
|
||||
### model
|
||||
model_name_or_path: Qwen/Qwen2.5-0.5B-Instruct
|
||||
trust_remote_code: true
|
||||
|
||||
### method
|
||||
stage: sft
|
||||
do_train: true
|
||||
finetuning_type: full
|
||||
use_eaft_loss: true
|
||||
|
||||
### dataset
|
||||
dataset: identity,alpaca_en_demo
|
||||
template: qwen
|
||||
cutoff_len: 2048
|
||||
max_samples: 1000
|
||||
overwrite_cache: true
|
||||
preprocessing_num_workers: 16
|
||||
dataloader_num_workers: 4
|
||||
|
||||
### output
|
||||
output_dir: qwen2.5-0_5b/full/sft_eaft
|
||||
logging_steps: 1
|
||||
save_steps: 500
|
||||
plot_loss: true
|
||||
overwrite_output_dir: true
|
||||
save_only_model: false
|
||||
report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
|
||||
|
||||
|
||||
### train
|
||||
per_device_train_batch_size: 2
|
||||
gradient_accumulation_steps: 8
|
||||
learning_rate: 1.0e-5
|
||||
num_train_epochs: 3.0
|
||||
lr_scheduler_type: cosine
|
||||
warmup_ratio: 0.1
|
||||
bf16: true
|
||||
ddp_timeout: 180000000
|
||||
@@ -5,6 +5,6 @@ infer_backend: ktransformers # choices: [huggingface, vllm, sglang, ktransforme
|
||||
trust_remote_code: true
|
||||
|
||||
use_kt: true # use KTransformers as LoRA sft backend to inference
|
||||
kt_optimize_rule: examples/kt_optimize_rules/DeepSeek-V2-Lite-Chat-sft-amx.yaml
|
||||
kt_optimize_rule: examples/ktransformers/kt_optimize_rules/DeepSeek-V2-Lite-Chat-sft-amx.yaml
|
||||
cpu_infer: 32
|
||||
chunk_size: 8192
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
model_name_or_path: opensourcerelease/DeepSeek-V3-bf16
|
||||
template: deepseek
|
||||
template: deepseek3
|
||||
infer_backend: ktransformers # choices: [huggingface, vllm, sglang, ktransformers]
|
||||
trust_remote_code: true
|
||||
|
||||
use_kt: true # use KTransformers as LoRA sft backend to inference
|
||||
kt_optimize_rule: examples/kt_optimize_rules/DeepSeek-V3-Chat-sft-amx-multi-gpu.yaml
|
||||
kt_optimize_rule: examples/ktransformers/kt_optimize_rules/DeepSeek-V3-Chat-sft-amx-multi-gpu.yaml
|
||||
cpu_infer: 32
|
||||
chunk_size: 8192
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
model_name_or_path: opensourcerelease/DeepSeek-V3-bf16
|
||||
adapter_name_or_path: saves/Kllama_deepseekV3
|
||||
template: deepseek
|
||||
template: deepseek3
|
||||
infer_backend: ktransformers # choices: [huggingface, vllm, sglang, ktransformers]
|
||||
trust_remote_code: true
|
||||
|
||||
use_kt: true # use KTransformers as LoRA sft backend to inference
|
||||
kt_optimize_rule: examples/kt_optimize_rules/DeepSeek-V3-Chat-sft-amx-multi-gpu.yaml
|
||||
kt_optimize_rule: examples/ktransformers/kt_optimize_rules/DeepSeek-V3-Chat-sft-amx-multi-gpu.yaml
|
||||
cpu_infer: 32
|
||||
chunk_size: 8192
|
||||
|
||||
@@ -5,6 +5,6 @@ infer_backend: ktransformers # choices: [huggingface, vllm, sglang, ktransforme
|
||||
trust_remote_code: true
|
||||
|
||||
use_kt: true # use KTransformers as LoRA sft backend to inference
|
||||
kt_optimize_rule: examples/kt_optimize_rules/Qwen3Moe-sft-amx.yaml
|
||||
kt_optimize_rule: examples/ktransformers/kt_optimize_rules/Qwen3Moe-sft-amx.yaml
|
||||
cpu_infer: 32
|
||||
chunk_size: 8192
|
||||
|
||||
@@ -10,7 +10,7 @@ lora_rank: 8
|
||||
lora_target: all
|
||||
|
||||
### dataset
|
||||
dataset: identity
|
||||
dataset: identity, alpaca_en_demo
|
||||
template: deepseek
|
||||
cutoff_len: 2048
|
||||
max_samples: 100000
|
||||
@@ -40,7 +40,7 @@ resume_from_checkpoint: null
|
||||
|
||||
### ktransformers
|
||||
use_kt: true # use KTransformers as LoRA sft backend
|
||||
kt_optimize_rule: examples/kt_optimize_rules/DeepSeek-V2-Lite-Chat-sft-amx.yaml
|
||||
kt_optimize_rule: examples/ktransformers/kt_optimize_rules/DeepSeek-V2-Lite-Chat-sft-amx.yaml
|
||||
cpu_infer: 32
|
||||
chunk_size: 8192
|
||||
|
||||
|
||||
@@ -10,8 +10,8 @@ lora_rank: 8
|
||||
lora_target: all
|
||||
|
||||
### dataset
|
||||
dataset: identity
|
||||
template: deepseek
|
||||
dataset: identity, alpaca_en_demo
|
||||
template: deepseek3
|
||||
cutoff_len: 2048
|
||||
max_samples: 100000
|
||||
overwrite_cache: true
|
||||
@@ -40,7 +40,7 @@ resume_from_checkpoint: null
|
||||
|
||||
### ktransformers
|
||||
use_kt: true # use KTransformers as LoRA sft backend
|
||||
kt_optimize_rule: examples/kt_optimize_rules/DeepSeek-V3-Chat-sft-amx-multi-gpu.yaml
|
||||
kt_optimize_rule: examples/ktransformers/kt_optimize_rules/DeepSeek-V3-Chat-sft-amx-multi-gpu.yaml
|
||||
cpu_infer: 32
|
||||
chunk_size: 8192
|
||||
|
||||
|
||||
@@ -40,7 +40,7 @@ resume_from_checkpoint: null
|
||||
|
||||
### ktransformers
|
||||
use_kt: true # use KTransformers as LoRA sft backend
|
||||
kt_optimize_rule: examples/kt_optimize_rules/Qwen3Moe-sft-amx.yaml
|
||||
kt_optimize_rule: examples/ktransformers/kt_optimize_rules/Qwen3Moe-sft-amx.yaml
|
||||
cpu_infer: 32
|
||||
chunk_size: 8192
|
||||
|
||||
|
||||
@@ -30,7 +30,6 @@ classifiers = [
|
||||
"License :: OSI Approved :: Apache Software License",
|
||||
"Operating System :: OS Independent",
|
||||
"Programming Language :: Python :: 3",
|
||||
"Programming Language :: Python :: 3.10",
|
||||
"Programming Language :: Python :: 3.11",
|
||||
"Programming Language :: Python :: 3.12",
|
||||
"Programming Language :: Python :: 3.13",
|
||||
@@ -73,14 +72,9 @@ dependencies = [
|
||||
# api
|
||||
"uvicorn",
|
||||
"fastapi",
|
||||
"sse-starlette"
|
||||
"sse-starlette",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
dev = ["pre-commit", "ruff", "pytest", "build"]
|
||||
metrics = ["nltk", "jieba", "rouge-chinese"]
|
||||
deepspeed = ["deepspeed>=0.10.0,<=0.16.9"]
|
||||
|
||||
[project.scripts]
|
||||
llamafactory-cli = "llamafactory.cli:main"
|
||||
lmf = "llamafactory.cli:main"
|
||||
|
||||
1
requirements/deepspeed.txt
Normal file
1
requirements/deepspeed.txt
Normal file
@@ -0,0 +1 @@
|
||||
deepspeed>=0.10.0,<=0.16.9
|
||||
4
requirements/dev.txt
Normal file
4
requirements/dev.txt
Normal file
@@ -0,0 +1,4 @@
|
||||
pre-commit
|
||||
ruff
|
||||
pytest
|
||||
build
|
||||
3
requirements/metrics.txt
Normal file
3
requirements/metrics.txt
Normal file
@@ -0,0 +1,3 @@
|
||||
nltk
|
||||
jieba
|
||||
rouge-chinese
|
||||
4
requirements/npu.txt
Normal file
4
requirements/npu.txt
Normal file
@@ -0,0 +1,4 @@
|
||||
torch==2.7.1
|
||||
torch-npu==2.7.1
|
||||
torchvision==0.22.1
|
||||
torchaudio==2.7.1
|
||||
@@ -28,7 +28,7 @@ try:
|
||||
jieba.setLogLevel(logging.CRITICAL)
|
||||
jieba.initialize()
|
||||
except ImportError:
|
||||
print("Please install llamafactory with `pip install -e .[metrics]`.")
|
||||
print("Please install llamafactory with `pip install -r requirements/metrics.txt`.")
|
||||
raise
|
||||
|
||||
|
||||
|
||||
@@ -2092,6 +2092,73 @@ class VideoLlavaPlugin(BasePlugin):
|
||||
return messages
|
||||
|
||||
|
||||
@dataclass
|
||||
class LFMVLPlugin(BasePlugin):
|
||||
r"""Plugin for LFM2.5-VL vision-language models.
|
||||
|
||||
LFM2.5-VL uses dynamic image token counts based on image resolution.
|
||||
The image processor returns spatial_shapes tensor with [height, width] grid dimensions.
|
||||
Token count per image = (spatial_h * spatial_w) / (downsample_factor^2)
|
||||
"""
|
||||
|
||||
@override
|
||||
def _get_mm_inputs(
|
||||
self,
|
||||
images: list["ImageInput"],
|
||||
videos: list["VideoInput"],
|
||||
audios: list["AudioInput"],
|
||||
processor: "MMProcessor",
|
||||
) -> dict[str, "torch.Tensor"]:
|
||||
image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
|
||||
mm_inputs = {}
|
||||
if len(images) != 0:
|
||||
images = self._regularize_images(
|
||||
images,
|
||||
image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768),
|
||||
image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
|
||||
)["images"]
|
||||
mm_inputs.update(image_processor(images, return_tensors="pt"))
|
||||
return mm_inputs
|
||||
|
||||
@override
|
||||
def process_messages(
|
||||
self,
|
||||
messages: list[dict[str, str]],
|
||||
images: list["ImageInput"],
|
||||
videos: list["VideoInput"],
|
||||
audios: list["AudioInput"],
|
||||
processor: Optional["MMProcessor"],
|
||||
) -> list[dict[str, str]]:
|
||||
self._validate_input(processor, images, videos, audios)
|
||||
self._validate_messages(messages, images, videos, audios)
|
||||
num_image_tokens = 0
|
||||
messages = deepcopy(messages)
|
||||
image_processor: BaseImageProcessor = getattr(processor, "image_processor")
|
||||
downsample_factor: int = getattr(image_processor, "downsample_factor", 2)
|
||||
|
||||
if self.expand_mm_tokens and len(images) > 0:
|
||||
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
|
||||
spatial_shapes = mm_inputs.get("spatial_shapes", [])
|
||||
else:
|
||||
spatial_shapes = []
|
||||
|
||||
for message in messages:
|
||||
content = message["content"]
|
||||
while IMAGE_PLACEHOLDER in content:
|
||||
if self.expand_mm_tokens and len(spatial_shapes) > num_image_tokens:
|
||||
h, w = spatial_shapes[num_image_tokens].tolist()
|
||||
image_seqlen = (h * w) // (downsample_factor * downsample_factor)
|
||||
else:
|
||||
image_seqlen = 1
|
||||
|
||||
content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
|
||||
num_image_tokens += 1
|
||||
|
||||
message["content"] = content.replace("{{image}}", self.image_token)
|
||||
|
||||
return messages
|
||||
|
||||
|
||||
PLUGINS = {
|
||||
"base": BasePlugin,
|
||||
"ernie_vl": ErnieVLPlugin,
|
||||
@@ -2104,6 +2171,7 @@ PLUGINS = {
|
||||
"llava": LlavaPlugin,
|
||||
"llava_next": LlavaNextPlugin,
|
||||
"llava_next_video": LlavaNextVideoPlugin,
|
||||
"lfm2_vl": LFMVLPlugin,
|
||||
"minicpm_v": MiniCPMVPlugin,
|
||||
"mllama": MllamaPlugin,
|
||||
"paligemma": PaliGemmaPlugin,
|
||||
|
||||
@@ -649,42 +649,6 @@ register_template(
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="aquila",
|
||||
format_user=StringFormatter(slots=["Human: {{content}}###Assistant:"]),
|
||||
format_assistant=StringFormatter(slots=["{{content}}###"]),
|
||||
format_system=StringFormatter(slots=["System: {{content}}###"]),
|
||||
default_system=(
|
||||
"A chat between a curious human and an artificial intelligence assistant. "
|
||||
"The assistant gives helpful, detailed, and polite answers to the human's questions."
|
||||
),
|
||||
stop_words=["</s>"],
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="atom",
|
||||
format_user=StringFormatter(
|
||||
slots=[{"bos_token"}, "Human: {{content}}\n", {"eos_token"}, {"bos_token"}, "Assistant:"]
|
||||
),
|
||||
format_assistant=StringFormatter(slots=["{{content}}\n", {"eos_token"}]),
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="baichuan",
|
||||
format_user=StringFormatter(slots=[{"token": "<reserved_102>"}, "{{content}}", {"token": "<reserved_103>"}]),
|
||||
efficient_eos=True,
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="baichuan2",
|
||||
format_user=StringFormatter(slots=["<reserved_106>{{content}}<reserved_107>"]),
|
||||
efficient_eos=True,
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="bailing",
|
||||
format_user=StringFormatter(slots=["<role>HUMAN</role>{{content}}<role>ASSISTANT</role>"]),
|
||||
@@ -712,20 +676,6 @@ register_template(
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="belle",
|
||||
format_user=StringFormatter(slots=["Human: {{content}}\n\nBelle: "]),
|
||||
format_assistant=StringFormatter(slots=["{{content}}", {"eos_token"}, "\n\n"]),
|
||||
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="bluelm",
|
||||
format_user=StringFormatter(slots=[{"token": "[|Human|]:"}, "{{content}}", {"token": "[|AI|]:"}]),
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="breeze",
|
||||
format_user=StringFormatter(slots=["[INST] {{content}} [/INST] "]),
|
||||
@@ -734,14 +684,6 @@ register_template(
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="chatglm2",
|
||||
format_user=StringFormatter(slots=["[Round {{idx}}]\n\n问:{{content}}\n\n答:"]),
|
||||
format_prefix=EmptyFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}]),
|
||||
efficient_eos=True,
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="chatglm3",
|
||||
format_user=StringFormatter(slots=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]),
|
||||
@@ -784,29 +726,6 @@ register_template(
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="codegeex2",
|
||||
format_prefix=EmptyFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}]),
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="codegeex4",
|
||||
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>\n"]),
|
||||
format_system=StringFormatter(slots=["<|system|>\n{{content}}"]),
|
||||
format_function=FunctionFormatter(slots=["{{content}}"], tool_format="glm4"),
|
||||
format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>\n"]),
|
||||
format_tools=ToolFormatter(tool_format="glm4"),
|
||||
format_prefix=EmptyFormatter(slots=["[gMASK]<sop>"]),
|
||||
default_system=(
|
||||
"你是一位智能编程助手,你叫CodeGeeX。你会为用户回答关于编程、代码、计算机方面的任何问题,"
|
||||
"并提供格式规范、可以执行、准确安全的代码,并在必要时提供详细的解释。"
|
||||
),
|
||||
stop_words=["<|user|>", "<|observation|>"],
|
||||
efficient_eos=True,
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="cohere",
|
||||
format_user=StringFormatter(
|
||||
@@ -822,25 +741,6 @@ register_template(
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="cpm",
|
||||
format_user=StringFormatter(slots=["<用户>{{content}}<AI>"]),
|
||||
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
||||
)
|
||||
|
||||
|
||||
# copied from chatml template
|
||||
register_template(
|
||||
name="cpm3",
|
||||
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
|
||||
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
|
||||
format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
||||
stop_words=["<|im_end|>"],
|
||||
)
|
||||
|
||||
|
||||
# copied from chatml template
|
||||
register_template(
|
||||
name="cpm4",
|
||||
@@ -1238,23 +1138,6 @@ register_template(
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="intern",
|
||||
format_user=StringFormatter(slots=["<|User|>:{{content}}\n<|Bot|>:"]),
|
||||
format_assistant=StringFormatter(slots=["{{content}}<eoa>\n"]),
|
||||
format_system=StringFormatter(slots=["<|System|>:{{content}}\n"]),
|
||||
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
||||
default_system=(
|
||||
"You are an AI assistant whose name is InternLM (书生·浦语).\n"
|
||||
"- InternLM (书生·浦语) is a conversational language model that is developed by Shanghai AI Laboratory "
|
||||
"(上海人工智能实验室). It is designed to be helpful, honest, and harmless.\n"
|
||||
"- InternLM (书生·浦语) can understand and communicate fluently in the language "
|
||||
"chosen by the user such as English and 中文."
|
||||
),
|
||||
stop_words=["<eoa>"],
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="intern2",
|
||||
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||
@@ -1330,6 +1213,47 @@ register_template(
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="lfm2",
|
||||
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
|
||||
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
|
||||
format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="lfm2"),
|
||||
format_observation=StringFormatter(
|
||||
slots=[
|
||||
"<|im_start|>tool\n<|tool_response_start|>{{content}}<|tool_response_end|><|im_end|>\n"
|
||||
"<|im_start|>assistant\n"
|
||||
]
|
||||
),
|
||||
format_tools=ToolFormatter(tool_format="lfm2"),
|
||||
default_system="You are a helpful AI assistant.",
|
||||
stop_words=["<|im_end|>"],
|
||||
tool_call_words=("<|tool_call_start|>", "<|tool_call_end|>"),
|
||||
replace_eos=True,
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="lfm2_vl",
|
||||
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
|
||||
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
|
||||
format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="lfm2"),
|
||||
format_observation=StringFormatter(
|
||||
slots=[
|
||||
"<|im_start|>tool\n<|tool_response_start|>{{content}}<|tool_response_end|><|im_end|>\n"
|
||||
"<|im_start|>assistant\n"
|
||||
]
|
||||
),
|
||||
format_tools=ToolFormatter(tool_format="lfm2"),
|
||||
default_system="You are a helpful multimodal assistant by Liquid AI.",
|
||||
stop_words=["<|im_end|>"],
|
||||
tool_call_words=("<|tool_call_start|>", "<|tool_call_end|>"),
|
||||
replace_eos=True,
|
||||
mm_plugin=get_mm_plugin(name="lfm2_vl", image_token="<image>"),
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="llama2",
|
||||
format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]),
|
||||
@@ -1576,23 +1500,6 @@ register_template(
|
||||
)
|
||||
|
||||
|
||||
# copied from chatml template
|
||||
register_template(
|
||||
name="marco",
|
||||
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
|
||||
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
|
||||
format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||
default_system=(
|
||||
"你是一个经过良好训练的AI助手,你的名字是Marco-o1."
|
||||
"由阿里国际数字商业集团的AI Business创造.\n## 重要!!!!!\n"
|
||||
"当你回答问题时,你的思考应该在<Thought>内完成,<Output>内输出你的结果。\n"
|
||||
"<Thought>应该尽可能是英文,但是有2个特例,一个是对原文中的引用,另一个是是数学应该使用markdown格式,<Output>内的输出需要遵循用户输入的语言。\n"
|
||||
),
|
||||
stop_words=["<|im_end|>"],
|
||||
)
|
||||
|
||||
|
||||
# copied from qwen template
|
||||
register_template(
|
||||
name="mimo",
|
||||
@@ -1804,13 +1711,6 @@ register_template(
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="orion",
|
||||
format_user=StringFormatter(slots=["Human: {{content}}\n\nAssistant: ", {"eos_token"}]),
|
||||
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="paligemma",
|
||||
format_user=StringFormatter(slots=["{{content}}\n"]),
|
||||
@@ -1869,6 +1769,17 @@ register_template(
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="phi4_mini",
|
||||
format_user=StringFormatter(slots=["<|user|>{{content}}<|end|><|assistant|>"]),
|
||||
format_assistant=StringFormatter(slots=["{{content}}<|end|>"]),
|
||||
format_system=StringFormatter(slots=["<|system|>{{content}}<|end|>"]),
|
||||
format_tools=StringFormatter(slots=["<|tool|>{{content}}<|/tool|>"]),
|
||||
stop_words=["<|end|>"],
|
||||
replace_eos=True,
|
||||
)
|
||||
|
||||
|
||||
# copied from ministral template
|
||||
register_template(
|
||||
name="pixtral",
|
||||
@@ -2104,41 +2015,6 @@ register_template(
|
||||
)
|
||||
|
||||
|
||||
# copied from llama3 template
|
||||
register_template(
|
||||
name="skywork_o1",
|
||||
format_user=StringFormatter(
|
||||
slots=[
|
||||
(
|
||||
"<|start_header_id|>user<|end_header_id|>\n\n{{content}}<|eot_id|>"
|
||||
"<|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||
)
|
||||
]
|
||||
),
|
||||
format_assistant=StringFormatter(slots=["{{content}}<|eot_id|>"]),
|
||||
format_system=StringFormatter(slots=["<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]),
|
||||
format_function=FunctionFormatter(slots=["{{content}}<|eot_id|>"], tool_format="llama3"),
|
||||
format_observation=StringFormatter(
|
||||
slots=[
|
||||
(
|
||||
"<|start_header_id|>ipython<|end_header_id|>\n\n{{content}}<|eot_id|>"
|
||||
"<|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||
)
|
||||
]
|
||||
),
|
||||
format_tools=ToolFormatter(tool_format="llama3"),
|
||||
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
||||
default_system=(
|
||||
"You are Skywork-o1, a thinking model developed by Skywork AI, specializing in solving complex problems "
|
||||
"involving mathematics, coding, and logical reasoning through deep thought. When faced with a user's request, "
|
||||
"you first engage in a lengthy and in-depth thinking process to explore possible solutions to the problem. "
|
||||
"After completing your thoughts, you then provide a detailed explanation of the solution process "
|
||||
"in your response."
|
||||
),
|
||||
stop_words=["<|eot_id|>", "<|eom_id|>"],
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="smollm",
|
||||
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||
@@ -2175,13 +2051,6 @@ register_template(
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="telechat",
|
||||
format_user=StringFormatter(slots=["<_user>{{content}}<_bot>"]),
|
||||
format_system=StringFormatter(slots=["<_system>{{content}}<_end>"]),
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="telechat2",
|
||||
format_user=StringFormatter(slots=["<_user>{{content}}<_bot>"]),
|
||||
@@ -2225,32 +2094,6 @@ register_template(
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="xverse",
|
||||
format_user=StringFormatter(slots=["Human: {{content}}\n\nAssistant: "]),
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="yayi",
|
||||
format_user=StringFormatter(slots=[{"token": "<|Human|>"}, ":\n{{content}}\n\n", {"token": "<|YaYi|>"}, ":"]),
|
||||
format_assistant=StringFormatter(slots=["{{content}}\n\n"]),
|
||||
format_system=StringFormatter(slots=[{"token": "<|System|>"}, ":\n{{content}}\n\n"]),
|
||||
default_system=(
|
||||
"You are a helpful, respectful and honest assistant named YaYi "
|
||||
"developed by Beijing Wenge Technology Co.,Ltd. "
|
||||
"Always answer as helpfully as possible, while being safe. "
|
||||
"Your answers should not include any harmful, unethical, "
|
||||
"racist, sexist, toxic, dangerous, or illegal content. "
|
||||
"Please ensure that your responses are socially unbiased and positive in nature.\n\n"
|
||||
"If a question does not make any sense, or is not factually coherent, "
|
||||
"explain why instead of answering something not correct. "
|
||||
"If you don't know the answer to a question, please don't share false information."
|
||||
),
|
||||
stop_words=["<|End|>"],
|
||||
)
|
||||
|
||||
|
||||
# copied from chatml template
|
||||
register_template(
|
||||
name="yi",
|
||||
@@ -2278,6 +2121,21 @@ register_template(
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="youtu",
|
||||
format_user=StringFormatter(slots=["<|User|>{{content}}<|Assistant|>"]),
|
||||
format_assistant=StringFormatter(slots=["{{content}}<|end_of_text|>"]),
|
||||
format_system=StringFormatter(slots=["{{content}}"]),
|
||||
format_function=FunctionFormatter(slots=["{{content}}"], tool_format="default"),
|
||||
format_observation=StringFormatter(slots=["<tool_response>\n{{content}}\n</tool_response><|Assistant|>"]),
|
||||
format_tools=ToolFormatter(tool_format="default"),
|
||||
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
||||
stop_words=["<|end_of_text|>"],
|
||||
replace_eos=True,
|
||||
template_class=ReasoningTemplate,
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="yuan",
|
||||
format_user=StringFormatter(slots=["{{content}}", {"token": "<sep>"}]),
|
||||
@@ -2292,10 +2150,3 @@ register_template(
|
||||
format_system=StringFormatter(slots=["<|system|>\n{{content}}", {"eos_token"}]),
|
||||
default_system="You are Zephyr, a helpful assistant.",
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="ziya",
|
||||
format_user=StringFormatter(slots=["<human>:{{content}}\n<bot>:"]),
|
||||
format_assistant=StringFormatter(slots=["{{content}}\n"]),
|
||||
)
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import ast
|
||||
import json
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
@@ -101,6 +102,8 @@ LING_TOOL_PROMPT = (
|
||||
""""arguments": <args-json-object>}}\n</tool_call>"""
|
||||
)
|
||||
|
||||
LFM2_TOOL_PROMPT = "List of tools: <|tool_list_start|>{tool_text}<|tool_list_end|>"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolUtils(ABC):
|
||||
@@ -546,10 +549,115 @@ class LingToolUtils(QwenToolUtils):
|
||||
return LING_TOOL_PROMPT.format(tool_text=tool_text) + "\n" + "detailed thinking off"
|
||||
|
||||
|
||||
class LFM2ToolUtils(ToolUtils):
|
||||
r"""LFM2.5 tool using template with Pythonic function call syntax."""
|
||||
|
||||
@override
|
||||
@staticmethod
|
||||
def tool_formatter(tools: list[dict[str, Any]]) -> str:
|
||||
tool_list = []
|
||||
for tool in tools:
|
||||
tool = tool.get("function", tool) if tool.get("type") == "function" else tool
|
||||
tool_list.append(tool)
|
||||
|
||||
return LFM2_TOOL_PROMPT.format(tool_text=json.dumps(tool_list, ensure_ascii=False))
|
||||
|
||||
@override
|
||||
@staticmethod
|
||||
def function_formatter(functions: list["FunctionCall"]) -> str:
|
||||
calls = []
|
||||
for name, args_json in functions:
|
||||
args = json.loads(args_json)
|
||||
kwargs_parts = []
|
||||
for key, value in args.items():
|
||||
if isinstance(value, str):
|
||||
kwargs_parts.append(f'{key}="{value}"')
|
||||
else:
|
||||
kwargs_parts.append(f"{key}={json.dumps(value, ensure_ascii=False)}")
|
||||
|
||||
calls.append(f"{name}({', '.join(kwargs_parts)})")
|
||||
|
||||
return f"<|tool_call_start|>[{', '.join(calls)}]<|tool_call_end|>"
|
||||
|
||||
@staticmethod
|
||||
def _ast_to_value(node: ast.AST) -> Any:
|
||||
"""Convert an AST node to a Python value, handling JSON-style booleans/null."""
|
||||
# Handle JSON-style true/false/null as Name nodes
|
||||
if isinstance(node, ast.Name):
|
||||
if node.id == "true":
|
||||
return True
|
||||
elif node.id == "false":
|
||||
return False
|
||||
elif node.id == "null":
|
||||
return None
|
||||
else:
|
||||
raise ValueError(f"Unknown identifier: {node.id}")
|
||||
|
||||
# Use literal_eval for other cases (strings, numbers, lists, dicts)
|
||||
return ast.literal_eval(node)
|
||||
|
||||
@override
|
||||
@staticmethod
|
||||
def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
|
||||
# Extract content between tool call markers
|
||||
start_marker = "<|tool_call_start|>"
|
||||
end_marker = "<|tool_call_end|>"
|
||||
|
||||
start_idx = content.find(start_marker)
|
||||
if start_idx == -1:
|
||||
return content
|
||||
|
||||
end_idx = content.find(end_marker, start_idx)
|
||||
if end_idx == -1:
|
||||
return content
|
||||
|
||||
tool_call_str = content[start_idx + len(start_marker) : end_idx].strip()
|
||||
|
||||
# Parse Pythonic function call syntax using AST
|
||||
try:
|
||||
tree = ast.parse(tool_call_str, mode="eval")
|
||||
except SyntaxError:
|
||||
return content
|
||||
|
||||
# Handle both single call and list of calls
|
||||
if isinstance(tree.body, ast.List):
|
||||
call_nodes = tree.body.elts
|
||||
elif isinstance(tree.body, ast.Call):
|
||||
call_nodes = [tree.body]
|
||||
else:
|
||||
return content
|
||||
|
||||
results = []
|
||||
for node in call_nodes:
|
||||
if not isinstance(node, ast.Call):
|
||||
return content
|
||||
|
||||
# Extract function name
|
||||
if isinstance(node.func, ast.Name):
|
||||
func_name = node.func.id
|
||||
else:
|
||||
return content
|
||||
|
||||
# Extract keyword arguments
|
||||
args_dict = {}
|
||||
for keyword in node.keywords:
|
||||
key = keyword.arg
|
||||
try:
|
||||
value = LFM2ToolUtils._ast_to_value(keyword.value)
|
||||
except (ValueError, SyntaxError):
|
||||
return content
|
||||
args_dict[key] = value
|
||||
|
||||
results.append(FunctionCall(func_name, json.dumps(args_dict, ensure_ascii=False)))
|
||||
|
||||
return results if results else content
|
||||
|
||||
|
||||
TOOLS = {
|
||||
"default": DefaultToolUtils(),
|
||||
"glm4": GLM4ToolUtils(),
|
||||
"llama3": Llama3ToolUtils(),
|
||||
"lfm2": LFM2ToolUtils(),
|
||||
"minimax1": MiniMaxM1ToolUtils(),
|
||||
"minimax2": MiniMaxM2ToolUtils(),
|
||||
"mistral": MistralToolUtils(),
|
||||
|
||||
@@ -181,51 +181,6 @@ register_model_group(
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Baichuan-7B-Base": {
|
||||
DownloadSource.DEFAULT: "baichuan-inc/Baichuan-7B",
|
||||
DownloadSource.MODELSCOPE: "baichuan-inc/baichuan-7B",
|
||||
},
|
||||
"Baichuan-13B-Base": {
|
||||
DownloadSource.DEFAULT: "baichuan-inc/Baichuan-13B-Base",
|
||||
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan-13B-Base",
|
||||
},
|
||||
"Baichuan-13B-Chat": {
|
||||
DownloadSource.DEFAULT: "baichuan-inc/Baichuan-13B-Chat",
|
||||
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan-13B-Chat",
|
||||
},
|
||||
},
|
||||
template="baichuan",
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Baichuan2-7B-Base": {
|
||||
DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-7B-Base",
|
||||
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-7B-Base",
|
||||
},
|
||||
"Baichuan2-13B-Base": {
|
||||
DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-13B-Base",
|
||||
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-13B-Base",
|
||||
DownloadSource.OPENMIND: "Baichuan/Baichuan2_13b_base_pt",
|
||||
},
|
||||
"Baichuan2-7B-Chat": {
|
||||
DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-7B-Chat",
|
||||
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-7B-Chat",
|
||||
DownloadSource.OPENMIND: "Baichuan/Baichuan2_7b_chat_pt",
|
||||
},
|
||||
"Baichuan2-13B-Chat": {
|
||||
DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-13B-Chat",
|
||||
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-13B-Chat",
|
||||
DownloadSource.OPENMIND: "Baichuan/Baichuan2_13b_chat_pt",
|
||||
},
|
||||
},
|
||||
template="baichuan2",
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"BLOOM-560M": {
|
||||
@@ -262,21 +217,6 @@ register_model_group(
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"BlueLM-7B-Base": {
|
||||
DownloadSource.DEFAULT: "vivo-ai/BlueLM-7B-Base",
|
||||
DownloadSource.MODELSCOPE: "vivo-ai/BlueLM-7B-Base",
|
||||
},
|
||||
"BlueLM-7B-Chat": {
|
||||
DownloadSource.DEFAULT: "vivo-ai/BlueLM-7B-Chat",
|
||||
DownloadSource.MODELSCOPE: "vivo-ai/BlueLM-7B-Chat",
|
||||
},
|
||||
},
|
||||
template="bluelm",
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Breeze-7B": {
|
||||
@@ -290,17 +230,6 @@ register_model_group(
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"ChatGLM2-6B-Chat": {
|
||||
DownloadSource.DEFAULT: "zai-org/chatglm2-6b",
|
||||
DownloadSource.MODELSCOPE: "ZhipuAI/chatglm2-6b",
|
||||
}
|
||||
},
|
||||
template="chatglm2",
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"ChatGLM3-6B-Base": {
|
||||
@@ -347,17 +276,6 @@ register_model_group(
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"CodeGeeX4-9B-Chat": {
|
||||
DownloadSource.DEFAULT: "zai-org/codegeex4-all-9b",
|
||||
DownloadSource.MODELSCOPE: "ZhipuAI/codegeex4-all-9b",
|
||||
},
|
||||
},
|
||||
template="codegeex4",
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"CodeGemma-7B": {
|
||||
@@ -642,15 +560,15 @@ register_model_group(
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"ERNIE-4.5-0.3B-PT": {
|
||||
"ERNIE-4.5-0.3B-Instruct": {
|
||||
DownloadSource.DEFAULT: "baidu/ERNIE-4.5-0.3B-PT",
|
||||
DownloadSource.MODELSCOPE: "PaddlePaddle/ERNIE-4.5-0.3B-PT",
|
||||
},
|
||||
"ERNIE-4.5-21B-A3B-PT": {
|
||||
"ERNIE-4.5-21B-A3B-Instruct": {
|
||||
DownloadSource.DEFAULT: "baidu/ERNIE-4.5-21B-A3B-PT",
|
||||
DownloadSource.MODELSCOPE: "PaddlePaddle/ERNIE-4.5-21B-A3B-PT",
|
||||
},
|
||||
"ERNIE-4.5-300B-A47B-PT": {
|
||||
"ERNIE-4.5-300B-A47B-Instruct": {
|
||||
DownloadSource.DEFAULT: "baidu/ERNIE-4.5-300B-A47B-PT",
|
||||
DownloadSource.MODELSCOPE: "PaddlePaddle/ERNIE-4.5-300B-A47B-PT",
|
||||
},
|
||||
@@ -661,7 +579,7 @@ register_model_group(
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"ERNIE-4.5-VL-28B-A3B-PT": {
|
||||
"ERNIE-4.5-VL-28B-A3B-Instruct": {
|
||||
DownloadSource.DEFAULT: "baidu/ERNIE-4.5-VL-28B-A3B-PT",
|
||||
DownloadSource.MODELSCOPE: "PaddlePaddle/ERNIE-4.5-VL-28B-A3B-PT",
|
||||
},
|
||||
@@ -669,7 +587,7 @@ register_model_group(
|
||||
DownloadSource.DEFAULT: "baidu/ERNIE-4.5-VL-28B-A3B-Thinking",
|
||||
DownloadSource.MODELSCOPE: "PaddlePaddle/ERNIE-4.5-VL-28B-A3B-Thinking",
|
||||
},
|
||||
"ERNIE-4.5-VL-424B-A47B-Base-PT": {
|
||||
"ERNIE-4.5-VL-424B-A47B-Instruct": {
|
||||
DownloadSource.DEFAULT: "baidu/ERNIE-4.5-VL-424B-A47B-PT",
|
||||
DownloadSource.MODELSCOPE: "PaddlePaddle/ERNIE-4.5-VL-424B-A47B-PT",
|
||||
},
|
||||
@@ -1266,29 +1184,6 @@ register_model_group(
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"InternLM-7B": {
|
||||
DownloadSource.DEFAULT: "internlm/internlm-7b",
|
||||
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm-7b",
|
||||
},
|
||||
"InternLM-20B": {
|
||||
DownloadSource.DEFAULT: "internlm/internlm-20b",
|
||||
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm-20b",
|
||||
},
|
||||
"InternLM-7B-Chat": {
|
||||
DownloadSource.DEFAULT: "internlm/internlm-chat-7b",
|
||||
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm-chat-7b",
|
||||
},
|
||||
"InternLM-20B-Chat": {
|
||||
DownloadSource.DEFAULT: "internlm/internlm-chat-20b",
|
||||
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm-chat-20b",
|
||||
},
|
||||
},
|
||||
template="intern",
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"InternLM2-7B": {
|
||||
@@ -1485,11 +1380,25 @@ register_model_group(
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"LingoWhale-8B": {
|
||||
DownloadSource.DEFAULT: "deeplang-ai/LingoWhale-8B",
|
||||
DownloadSource.MODELSCOPE: "DeepLang/LingoWhale-8B",
|
||||
}
|
||||
"LFM2.5-1.2B": {
|
||||
DownloadSource.DEFAULT: "LiquidAI/LFM2.5-1.2B-Base",
|
||||
},
|
||||
"LFM2.5-1.2B-Instruct": {
|
||||
DownloadSource.DEFAULT: "LiquidAI/LFM2.5-1.2B-Instruct",
|
||||
},
|
||||
},
|
||||
template="lfm2",
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"LFM2.5-VL-1.6B": {
|
||||
DownloadSource.DEFAULT: "LiquidAI/LFM2.5-VL-1.6B",
|
||||
},
|
||||
},
|
||||
template="lfm2_vl",
|
||||
multimodal=True,
|
||||
)
|
||||
|
||||
|
||||
@@ -1804,17 +1713,6 @@ register_model_group(
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Marco-o1-Chat": {
|
||||
DownloadSource.DEFAULT: "AIDC-AI/Marco-o1",
|
||||
DownloadSource.MODELSCOPE: "AIDC-AI/Marco-o1",
|
||||
},
|
||||
},
|
||||
template="marco",
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"MiMo-7B-Base": {
|
||||
@@ -1885,33 +1783,6 @@ register_model_group(
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"MiniCPM-2B-SFT-Chat": {
|
||||
DownloadSource.DEFAULT: "openbmb/MiniCPM-2B-sft-bf16",
|
||||
DownloadSource.MODELSCOPE: "OpenBMB/miniCPM-bf16",
|
||||
},
|
||||
"MiniCPM-2B-DPO-Chat": {
|
||||
DownloadSource.DEFAULT: "openbmb/MiniCPM-2B-dpo-bf16",
|
||||
DownloadSource.MODELSCOPE: "OpenBMB/MiniCPM-2B-dpo-bf16",
|
||||
},
|
||||
},
|
||||
template="cpm",
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"MiniCPM3-4B-Chat": {
|
||||
DownloadSource.DEFAULT: "openbmb/MiniCPM3-4B",
|
||||
DownloadSource.MODELSCOPE: "OpenBMB/MiniCPM3-4B",
|
||||
DownloadSource.OPENMIND: "LlamaFactory/MiniCPM3-4B",
|
||||
},
|
||||
},
|
||||
template="cpm3",
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"MiniCPM4-0.5B-Chat": {
|
||||
@@ -1949,26 +1820,10 @@ register_model_group(
|
||||
DownloadSource.DEFAULT: "openbmb/MiniCPM-V-2_6",
|
||||
DownloadSource.MODELSCOPE: "OpenBMB/MiniCPM-V-2_6",
|
||||
},
|
||||
},
|
||||
template="minicpm_v",
|
||||
multimodal=True,
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"MiniCPM-V-4": {
|
||||
DownloadSource.DEFAULT: "openbmb/MiniCPM-V-4",
|
||||
DownloadSource.MODELSCOPE: "OpenBMB/MiniCPM-V-4",
|
||||
},
|
||||
},
|
||||
template="minicpm_v",
|
||||
multimodal=True,
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"MiniCPM-V-4.5": {
|
||||
DownloadSource.DEFAULT: "openbmb/MiniCPM-V-4_5",
|
||||
DownloadSource.MODELSCOPE: "OpenBMB/MiniCPM-V-4_5",
|
||||
@@ -2226,33 +2081,6 @@ register_model_group(
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Orion-14B-Base": {
|
||||
DownloadSource.DEFAULT: "OrionStarAI/Orion-14B-Base",
|
||||
DownloadSource.MODELSCOPE: "OrionStarAI/Orion-14B-Base",
|
||||
},
|
||||
"Orion-14B-Chat": {
|
||||
DownloadSource.DEFAULT: "OrionStarAI/Orion-14B-Chat",
|
||||
DownloadSource.MODELSCOPE: "OrionStarAI/Orion-14B-Chat",
|
||||
},
|
||||
"Orion-14B-Long-Chat": {
|
||||
DownloadSource.DEFAULT: "OrionStarAI/Orion-14B-LongChat",
|
||||
DownloadSource.MODELSCOPE: "OrionStarAI/Orion-14B-LongChat",
|
||||
},
|
||||
"Orion-14B-RAG-Chat": {
|
||||
DownloadSource.DEFAULT: "OrionStarAI/Orion-14B-Chat-RAG",
|
||||
DownloadSource.MODELSCOPE: "OrionStarAI/Orion-14B-Chat-RAG",
|
||||
},
|
||||
"Orion-14B-Plugin-Chat": {
|
||||
DownloadSource.DEFAULT: "OrionStarAI/Orion-14B-Chat-Plugin",
|
||||
DownloadSource.MODELSCOPE: "OrionStarAI/Orion-14B-Chat-Plugin",
|
||||
},
|
||||
},
|
||||
template="orion",
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"PaliGemma-3B-pt-224": {
|
||||
@@ -2349,20 +2177,6 @@ register_model_group(
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Phi-1.5-1.3B": {
|
||||
DownloadSource.DEFAULT: "microsoft/phi-1_5",
|
||||
DownloadSource.MODELSCOPE: "allspace/PHI_1-5",
|
||||
},
|
||||
"Phi-2-2.7B": {
|
||||
DownloadSource.DEFAULT: "microsoft/phi-2",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/phi-2",
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Phi-3-4B-4k-Instruct": {
|
||||
@@ -2419,6 +2233,15 @@ register_model_group(
|
||||
template="phi4",
|
||||
)
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Phi-4-3.8B-instruct": {
|
||||
DownloadSource.DEFAULT: "microsoft/Phi-4-mini-instruct",
|
||||
DownloadSource.MODELSCOPE: "LLM-Research/Phi-4-mini-instruct",
|
||||
},
|
||||
},
|
||||
template="phi4_mini",
|
||||
)
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
@@ -2432,228 +2255,6 @@ register_model_group(
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Qwen-1.8B": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen-1_8B",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen-1_8B",
|
||||
},
|
||||
"Qwen-7B": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen-7B",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen-7B",
|
||||
},
|
||||
"Qwen-14B": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen-14B",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen-14B",
|
||||
},
|
||||
"Qwen-72B": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen-72B",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen-72B",
|
||||
},
|
||||
"Qwen-1.8B-Chat": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen-1_8B-Chat",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen-1_8B-Chat",
|
||||
},
|
||||
"Qwen-7B-Chat": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen-7B-Chat",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen-7B-Chat",
|
||||
},
|
||||
"Qwen-14B-Chat": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen-14B-Chat",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen-14B-Chat",
|
||||
},
|
||||
"Qwen-72B-Chat": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen-72B-Chat",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen-72B-Chat",
|
||||
},
|
||||
"Qwen-1.8B-Chat-Int8": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen-1_8B-Chat-Int8",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen-1_8B-Chat-Int8",
|
||||
},
|
||||
"Qwen-1.8B-Chat-Int4": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen-1_8B-Chat-Int4",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen-1_8B-Chat-Int4",
|
||||
},
|
||||
"Qwen-7B-Chat-Int8": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen-7B-Chat-Int8",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen-7B-Chat-Int8",
|
||||
},
|
||||
"Qwen-7B-Chat-Int4": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen-7B-Chat-Int4",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen-7B-Chat-Int4",
|
||||
},
|
||||
"Qwen-14B-Chat-Int8": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen-14B-Chat-Int8",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen-14B-Chat-Int8",
|
||||
},
|
||||
"Qwen-14B-Chat-Int4": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen-14B-Chat-Int4",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen-14B-Chat-Int4",
|
||||
},
|
||||
"Qwen-72B-Chat-Int8": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen-72B-Chat-Int8",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen-72B-Chat-Int8",
|
||||
},
|
||||
"Qwen-72B-Chat-Int4": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen-72B-Chat-Int4",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen-72B-Chat-Int4",
|
||||
},
|
||||
},
|
||||
template="qwen",
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Qwen1.5-0.5B": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-0.5B",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-0.5B",
|
||||
},
|
||||
"Qwen1.5-1.8B": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-1.8B",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-1.8B",
|
||||
},
|
||||
"Qwen1.5-4B": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-4B",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-4B",
|
||||
},
|
||||
"Qwen1.5-7B": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-7B",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-7B",
|
||||
},
|
||||
"Qwen1.5-14B": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-14B",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-14B",
|
||||
},
|
||||
"Qwen1.5-32B": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-32B",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-32B",
|
||||
},
|
||||
"Qwen1.5-72B": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-72B",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-72B",
|
||||
},
|
||||
"Qwen1.5-110B": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-110B",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-110B",
|
||||
},
|
||||
"Qwen1.5-MoE-A2.7B": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-MoE-A2.7B",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-MoE-A2.7B",
|
||||
},
|
||||
"Qwen1.5-0.5B-Chat": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-0.5B-Chat",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-0.5B-Chat",
|
||||
},
|
||||
"Qwen1.5-1.8B-Chat": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-1.8B-Chat",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-1.8B-Chat",
|
||||
},
|
||||
"Qwen1.5-4B-Chat": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-4B-Chat",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-4B-Chat",
|
||||
},
|
||||
"Qwen1.5-7B-Chat": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-7B-Chat",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-7B-Chat",
|
||||
},
|
||||
"Qwen1.5-14B-Chat": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-14B-Chat",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-14B-Chat",
|
||||
},
|
||||
"Qwen1.5-32B-Chat": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-32B-Chat",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-32B-Chat",
|
||||
},
|
||||
"Qwen1.5-72B-Chat": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-72B-Chat",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-72B-Chat",
|
||||
},
|
||||
"Qwen1.5-110B-Chat": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-110B-Chat",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-110B-Chat",
|
||||
},
|
||||
"Qwen1.5-MoE-A2.7B-Chat": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-MoE-A2.7B-Chat",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-MoE-A2.7B-Chat",
|
||||
},
|
||||
"Qwen1.5-0.5B-Chat-GPTQ-Int8": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-0.5B-Chat-GPTQ-Int8",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-0.5B-Chat-GPTQ-Int8",
|
||||
},
|
||||
"Qwen1.5-0.5B-Chat-AWQ": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-0.5B-Chat-AWQ",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-0.5B-Chat-AWQ",
|
||||
},
|
||||
"Qwen1.5-1.8B-Chat-GPTQ-Int8": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-1.8B-Chat-GPTQ-Int8",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-1.8B-Chat-GPTQ-Int8",
|
||||
},
|
||||
"Qwen1.5-1.8B-Chat-AWQ": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-1.8B-Chat-AWQ",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-1.8B-Chat-AWQ",
|
||||
},
|
||||
"Qwen1.5-4B-Chat-GPTQ-Int8": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-4B-Chat-GPTQ-Int8",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-4B-Chat-GPTQ-Int8",
|
||||
},
|
||||
"Qwen1.5-4B-Chat-AWQ": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-4B-Chat-AWQ",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-4B-Chat-AWQ",
|
||||
},
|
||||
"Qwen1.5-7B-Chat-GPTQ-Int8": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-7B-Chat-GPTQ-Int8",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-7B-Chat-GPTQ-Int8",
|
||||
},
|
||||
"Qwen1.5-7B-Chat-AWQ": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-7B-Chat-AWQ",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-7B-Chat-AWQ",
|
||||
},
|
||||
"Qwen1.5-14B-Chat-GPTQ-Int8": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-14B-Chat-GPTQ-Int8",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-14B-Chat-GPTQ-Int8",
|
||||
},
|
||||
"Qwen1.5-14B-Chat-AWQ": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-14B-Chat-AWQ",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-14B-Chat-AWQ",
|
||||
},
|
||||
"Qwen1.5-32B-Chat-AWQ": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-32B-Chat-AWQ",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-32B-Chat-AWQ",
|
||||
},
|
||||
"Qwen1.5-72B-Chat-GPTQ-Int8": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-72B-Chat-GPTQ-Int8",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-72B-Chat-GPTQ-Int8",
|
||||
},
|
||||
"Qwen1.5-72B-Chat-AWQ": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-72B-Chat-AWQ",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-72B-Chat-AWQ",
|
||||
},
|
||||
"Qwen1.5-110B-Chat-AWQ": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-110B-Chat-AWQ",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-110B-Chat-AWQ",
|
||||
},
|
||||
"Qwen1.5-MoE-A2.7B-Chat-GPTQ-Int4": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-MoE-A2.7B-Chat-GPTQ-Int4",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-MoE-A2.7B-Chat-GPTQ-Int4",
|
||||
},
|
||||
"CodeQwen1.5-7B": {
|
||||
DownloadSource.DEFAULT: "Qwen/CodeQwen1.5-7B",
|
||||
DownloadSource.MODELSCOPE: "Qwen/CodeQwen1.5-7B",
|
||||
},
|
||||
"CodeQwen1.5-7B-Chat": {
|
||||
DownloadSource.DEFAULT: "Qwen/CodeQwen1.5-7B-Chat",
|
||||
DownloadSource.MODELSCOPE: "Qwen/CodeQwen1.5-7B-Chat",
|
||||
},
|
||||
"CodeQwen1.5-7B-Chat-AWQ": {
|
||||
DownloadSource.DEFAULT: "Qwen/CodeQwen1.5-7B-Chat-AWQ",
|
||||
DownloadSource.MODELSCOPE: "Qwen/CodeQwen1.5-7B-Chat-AWQ",
|
||||
},
|
||||
},
|
||||
template="qwen",
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Qwen2-0.5B": {
|
||||
@@ -3421,27 +3022,6 @@ register_model_group(
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Skywork-13B-Base": {
|
||||
DownloadSource.DEFAULT: "Skywork/Skywork-13B-base",
|
||||
DownloadSource.MODELSCOPE: "skywork/Skywork-13B-base",
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Skywork-o1-Open-Llama-3.1-8B": {
|
||||
DownloadSource.DEFAULT: "Skywork/Skywork-o1-Open-Llama-3.1-8B",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/Skywork-o1-Open-Llama-3.1-8B",
|
||||
}
|
||||
},
|
||||
template="skywork_o1",
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"SmolLM-135M": {
|
||||
@@ -3536,30 +3116,6 @@ register_model_group(
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"TeleChat-1B-Chat": {
|
||||
DownloadSource.DEFAULT: "Tele-AI/TeleChat-1B",
|
||||
DownloadSource.MODELSCOPE: "TeleAI/TeleChat-1B",
|
||||
},
|
||||
"TeleChat-7B-Chat": {
|
||||
DownloadSource.DEFAULT: "Tele-AI/telechat-7B",
|
||||
DownloadSource.MODELSCOPE: "TeleAI/telechat-7B",
|
||||
DownloadSource.OPENMIND: "TeleAI/TeleChat-7B-pt",
|
||||
},
|
||||
"TeleChat-12B-Chat": {
|
||||
DownloadSource.DEFAULT: "Tele-AI/TeleChat-12B-v2",
|
||||
DownloadSource.MODELSCOPE: "TeleAI/TeleChat-12B-v2",
|
||||
DownloadSource.OPENMIND: "TeleAI/TeleChat-12B-pt",
|
||||
},
|
||||
"TeleChat-52B-Chat": {
|
||||
DownloadSource.DEFAULT: "Tele-AI/TeleChat-52B",
|
||||
},
|
||||
},
|
||||
template="telechat",
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"TeleChat2-3B-Chat": {
|
||||
@@ -3674,80 +3230,6 @@ register_model_group(
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"XVERSE-7B": {
|
||||
DownloadSource.DEFAULT: "xverse/XVERSE-7B",
|
||||
DownloadSource.MODELSCOPE: "xverse/XVERSE-7B",
|
||||
},
|
||||
"XVERSE-13B": {
|
||||
DownloadSource.DEFAULT: "xverse/XVERSE-13B",
|
||||
DownloadSource.MODELSCOPE: "xverse/XVERSE-13B",
|
||||
},
|
||||
"XVERSE-65B": {
|
||||
DownloadSource.DEFAULT: "xverse/XVERSE-65B",
|
||||
DownloadSource.MODELSCOPE: "xverse/XVERSE-65B",
|
||||
},
|
||||
"XVERSE-65B-2": {
|
||||
DownloadSource.DEFAULT: "xverse/XVERSE-65B-2",
|
||||
DownloadSource.MODELSCOPE: "xverse/XVERSE-65B-2",
|
||||
},
|
||||
"XVERSE-7B-Chat": {
|
||||
DownloadSource.DEFAULT: "xverse/XVERSE-7B-Chat",
|
||||
DownloadSource.MODELSCOPE: "xverse/XVERSE-7B-Chat",
|
||||
},
|
||||
"XVERSE-13B-Chat": {
|
||||
DownloadSource.DEFAULT: "xverse/XVERSE-13B-Chat",
|
||||
DownloadSource.MODELSCOPE: "xverse/XVERSE-13B-Chat",
|
||||
},
|
||||
"XVERSE-65B-Chat": {
|
||||
DownloadSource.DEFAULT: "xverse/XVERSE-65B-Chat",
|
||||
DownloadSource.MODELSCOPE: "xverse/XVERSE-65B-Chat",
|
||||
},
|
||||
"XVERSE-MoE-A4.2B": {
|
||||
DownloadSource.DEFAULT: "xverse/XVERSE-MoE-A4.2B",
|
||||
DownloadSource.MODELSCOPE: "xverse/XVERSE-MoE-A4.2B",
|
||||
},
|
||||
"XVERSE-7B-Chat-GPTQ-Int8": {
|
||||
DownloadSource.DEFAULT: "xverse/XVERSE-7B-Chat-GPTQ-Int8",
|
||||
DownloadSource.MODELSCOPE: "xverse/XVERSE-7B-Chat-GPTQ-Int8",
|
||||
},
|
||||
"XVERSE-7B-Chat-GPTQ-Int4": {
|
||||
DownloadSource.DEFAULT: "xverse/XVERSE-7B-Chat-GPTQ-Int4",
|
||||
DownloadSource.MODELSCOPE: "xverse/XVERSE-7B-Chat-GPTQ-Int4",
|
||||
},
|
||||
"XVERSE-13B-Chat-GPTQ-Int8": {
|
||||
DownloadSource.DEFAULT: "xverse/XVERSE-13B-Chat-GPTQ-Int8",
|
||||
DownloadSource.MODELSCOPE: "xverse/XVERSE-13B-Chat-GPTQ-Int8",
|
||||
},
|
||||
"XVERSE-13B-Chat-GPTQ-Int4": {
|
||||
DownloadSource.DEFAULT: "xverse/XVERSE-13B-Chat-GPTQ-Int4",
|
||||
DownloadSource.MODELSCOPE: "xverse/XVERSE-13B-Chat-GPTQ-Int4",
|
||||
},
|
||||
"XVERSE-65B-Chat-GPTQ-Int4": {
|
||||
DownloadSource.DEFAULT: "xverse/XVERSE-65B-Chat-GPTQ-Int4",
|
||||
DownloadSource.MODELSCOPE: "xverse/XVERSE-65B-Chat-GPTQ-Int4",
|
||||
},
|
||||
},
|
||||
template="xverse",
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Yayi-7B": {
|
||||
DownloadSource.DEFAULT: "wenge-research/yayi-7b-llama2",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/yayi-7b-llama2",
|
||||
},
|
||||
"Yayi-13B": {
|
||||
DownloadSource.DEFAULT: "wenge-research/yayi-13b-llama2",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/yayi-13b-llama2",
|
||||
},
|
||||
},
|
||||
template="yayi",
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Yi-6B": {
|
||||
@@ -3846,6 +3328,21 @@ register_model_group(
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Youtu-LLM-2B-Instruct": {
|
||||
DownloadSource.DEFAULT: "tencent/Youtu-LLM-2B",
|
||||
DownloadSource.MODELSCOPE: "Tencent-YouTu-Research/Youtu-LLM-2B",
|
||||
},
|
||||
"Youtu-LLM-2B-Base": {
|
||||
DownloadSource.DEFAULT: "tencent/Youtu-LLM-2B-Base",
|
||||
DownloadSource.MODELSCOPE: "Tencent-YouTu-Research/Youtu-LLM-2B-Base",
|
||||
},
|
||||
},
|
||||
template="youtu",
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Yuan2-2B-Chat": {
|
||||
|
||||
@@ -19,7 +19,7 @@
|
||||
from collections import OrderedDict
|
||||
|
||||
|
||||
VERSION = "0.9.4"
|
||||
VERSION = "0.9.5.dev0"
|
||||
|
||||
|
||||
def print_env() -> None:
|
||||
|
||||
@@ -490,6 +490,14 @@ class FinetuningArguments(
|
||||
default=False,
|
||||
metadata={"help": "Whether to use the DFT loss."},
|
||||
)
|
||||
use_eaft_loss: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether to use the EAFT loss."},
|
||||
)
|
||||
eaft_alpha: float = field(
|
||||
default=1.0,
|
||||
metadata={"help": "The alpha parameter for EAFT loss to control the power of adaptive weight."},
|
||||
)
|
||||
freeze_vision_tower: bool = field(
|
||||
default=True,
|
||||
metadata={"help": "Whether ot not to freeze the vision tower in MLLM training."},
|
||||
|
||||
@@ -298,23 +298,6 @@ class QuantizationArguments:
|
||||
default=None,
|
||||
metadata={"help": "Device map used to infer the 4-bit quantized model, needs bitsandbytes>=0.43.0."},
|
||||
)
|
||||
fp8: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "Enable FP8 mixed precision training via HuggingFace Accelerate. "
|
||||
"Requires PyTorch 2.7+ and Hopper architecture GPUs."
|
||||
},
|
||||
)
|
||||
fp8_backend: str = field(
|
||||
default="auto",
|
||||
metadata={
|
||||
"help": "FP8 backend to use ('auto', 'torchao', 'te', 'msamp'). 'auto' selects best available backend."
|
||||
},
|
||||
)
|
||||
fp8_enable_fsdp_float8_all_gather: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Enable FP8 optimizations for FSDP2 all-gather operations."},
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
@@ -70,13 +71,13 @@ def read_args(args: dict[str, Any] | list[str] | None = None) -> dict[str, Any]
|
||||
if args is not None:
|
||||
return args
|
||||
|
||||
if sys.argv[1].endswith(".yaml") or sys.argv[1].endswith(".yml"):
|
||||
if len(sys.argv) > 1 and (sys.argv[1].endswith(".yaml") or sys.argv[1].endswith(".yml")):
|
||||
override_config = OmegaConf.from_cli(sys.argv[2:])
|
||||
dict_config = OmegaConf.load(Path(sys.argv[1]).absolute())
|
||||
return OmegaConf.to_container(OmegaConf.merge(dict_config, override_config))
|
||||
elif sys.argv[1].endswith(".json"):
|
||||
elif len(sys.argv) > 1 and sys.argv[1].endswith(".json"):
|
||||
override_config = OmegaConf.from_cli(sys.argv[2:])
|
||||
dict_config = OmegaConf.load(Path(sys.argv[1]).absolute())
|
||||
dict_config = OmegaConf.create(json.load(Path(sys.argv[1]).absolute()))
|
||||
return OmegaConf.to_container(OmegaConf.merge(dict_config, override_config))
|
||||
else:
|
||||
return sys.argv[1:]
|
||||
@@ -142,14 +143,6 @@ def _verify_model_args(
|
||||
logger.warning_rank0("We should use slow tokenizer for the Yi models. Change `use_fast_tokenizer` to False.")
|
||||
model_args.use_fast_tokenizer = False
|
||||
|
||||
# Validate advanced training features
|
||||
if model_args.fp8 and model_args.quantization_bit is not None:
|
||||
raise ValueError("FP8 training is not compatible with quantization. Please disable one of them.")
|
||||
|
||||
if model_args.fp8_enable_fsdp_float8_all_gather and not model_args.fp8:
|
||||
logger.warning_rank0("fp8_enable_fsdp_float8_all_gather requires fp8=True. Setting fp8=True.")
|
||||
model_args.fp8 = True
|
||||
|
||||
|
||||
def _check_extra_dependencies(
|
||||
model_args: "ModelArguments",
|
||||
@@ -347,6 +340,9 @@ def get_train_args(args: dict[str, Any] | list[str] | None = None) -> _TRAIN_CLS
|
||||
if training_args.deepspeed is not None and (finetuning_args.use_galore or finetuning_args.use_apollo):
|
||||
raise ValueError("GaLore and APOLLO are incompatible with DeepSpeed yet.")
|
||||
|
||||
if training_args.fp8 and model_args.quantization_bit is not None:
|
||||
raise ValueError("FP8 training is not compatible with quantization. Please disable one of them.")
|
||||
|
||||
if model_args.infer_backend != EngineName.HF:
|
||||
raise ValueError("vLLM/SGLang backend is only available for API, CLI and Web.")
|
||||
|
||||
@@ -363,6 +359,10 @@ def get_train_args(args: dict[str, Any] | list[str] | None = None) -> _TRAIN_CLS
|
||||
_verify_model_args(model_args, data_args, finetuning_args)
|
||||
_check_extra_dependencies(model_args, finetuning_args, training_args)
|
||||
|
||||
if training_args.fp8_enable_fsdp_float8_all_gather and not training_args.fp8:
|
||||
logger.warning_rank0("fp8_enable_fsdp_float8_all_gather requires fp8=True. Setting fp8=True.")
|
||||
model_args.fp8 = True
|
||||
|
||||
if (
|
||||
training_args.do_train
|
||||
and finetuning_args.finetuning_type == "lora"
|
||||
|
||||
@@ -92,7 +92,30 @@ class RayArguments:
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainingArguments(RayArguments, BaseTrainingArguments):
|
||||
class Fp8Arguments:
|
||||
r"""Arguments pertaining to the FP8 training."""
|
||||
|
||||
fp8: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "Enable FP8 mixed precision training via HuggingFace Accelerate. "
|
||||
"Requires PyTorch 2.7+ and Hopper architecture GPUs."
|
||||
},
|
||||
)
|
||||
fp8_backend: str = field(
|
||||
default="auto",
|
||||
metadata={
|
||||
"help": "FP8 backend to use ('auto', 'torchao', 'te', 'msamp'). 'auto' selects best available backend."
|
||||
},
|
||||
)
|
||||
fp8_enable_fsdp_float8_all_gather: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Enable FP8 optimizations for FSDP2 all-gather operations."},
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainingArguments(Fp8Arguments, RayArguments, BaseTrainingArguments):
|
||||
r"""Arguments pertaining to the trainer."""
|
||||
|
||||
overwrite_output_dir: bool = field(
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Any, Optional, TypedDict
|
||||
|
||||
import torch
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoModelForCausalLM,
|
||||
@@ -29,6 +30,7 @@ from trl import AutoModelForCausalLMWithValueHead
|
||||
|
||||
from ..extras import logging
|
||||
from ..extras.misc import count_parameters, skip_check_imports, try_download_model_from_other_hub
|
||||
from ..extras.packages import is_torch_version_greater_than
|
||||
from .adapter import init_adapter
|
||||
from .model_utils.ktransformers import load_kt_pretrained_model
|
||||
from .model_utils.liger_kernel import apply_liger_kernel
|
||||
@@ -203,6 +205,16 @@ def load_model(
|
||||
model.load_state_dict(vhead_params, strict=False)
|
||||
logger.info_rank0(f"Loaded valuehead from checkpoint: {vhead_path}")
|
||||
|
||||
# Conv3D is not recommended when using torch 2.9.x
|
||||
if is_torch_version_greater_than("2.9.0") and not is_torch_version_greater_than("2.10.0"):
|
||||
if any(isinstance(m, torch.nn.Conv3d) for m in model.modules()):
|
||||
raise ValueError(
|
||||
"Unsupported torch version detected: torch 2.9.x with Conv3D. "
|
||||
"This combination is known to cause severe performance regression. "
|
||||
"Please downgrade torch to <2.9 or remove Conv3D. "
|
||||
"See https://github.com/pytorch/pytorch/issues/166122"
|
||||
)
|
||||
|
||||
if not is_trainable:
|
||||
model.requires_grad_(False)
|
||||
model.eval()
|
||||
|
||||
@@ -138,18 +138,25 @@ def patch_config(
|
||||
if getattr(config, "model_type", None) == "kimi_vl" and is_trainable:
|
||||
setattr(config.text_config, "topk_method", "greedy")
|
||||
|
||||
if "InternVLChatModel" in getattr(config, "architectures", []):
|
||||
architectures = getattr(config, "architectures", None)
|
||||
if isinstance(architectures, list) and "InternVLChatModel" in architectures:
|
||||
raise ValueError(
|
||||
"Please download the internvl models in a Hugging Face–compatible format "
|
||||
"(for example, https://huggingface.co/OpenGVLab/InternVL3-8B-hf)."
|
||||
)
|
||||
|
||||
if "LlavaLlamaForCausalLM" in getattr(config, "architectures", []):
|
||||
if isinstance(architectures, list) and "LlavaLlamaForCausalLM" in architectures:
|
||||
raise ValueError("Please download llava models with hf-compatible format: https://huggingface.co/llava-hf")
|
||||
|
||||
if getattr(config, "model_type", None) == "internlm3" and not is_transformers_version_greater_than("4.47.1"):
|
||||
raise RuntimeError("InternLM3 model requires transformers>=4.47.1, please upgrade it.")
|
||||
|
||||
if getattr(config, "model_type", None) == "lfm2_vl" and not is_transformers_version_greater_than("4.58.0"):
|
||||
raise RuntimeError(
|
||||
"LFM2.5-VL model requires transformers>=4.58.0 or install from commit: "
|
||||
"pip install git+https://github.com/huggingface/transformers.git@3c2517727ce28a30f5044e01663ee204deb1cdbe"
|
||||
)
|
||||
|
||||
if getattr(config, "model_type", None) == "qwen3_omni_moe":
|
||||
patch_qwen3_omni_moe_thinker_text_sparse_moe_block()
|
||||
|
||||
|
||||
@@ -12,35 +12,45 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import types
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
from ..extras import logging
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..hparams import ModelArguments
|
||||
from ..hparams import TrainingArguments
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def create_fp8_kwargs(model_args: "ModelArguments") -> list[Any]:
|
||||
def create_fp8_kwargs(training_args: "TrainingArguments") -> list[Any]:
|
||||
"""Create AORecipeKwargs for FP8 training with HuggingFace Accelerate.
|
||||
|
||||
Args:
|
||||
model_args: Model arguments containing FP8 configuration
|
||||
training_args: Training arguments containing FP8 configuration
|
||||
|
||||
Returns:
|
||||
List containing AORecipeKwargs if FP8 is enabled and supported, empty list otherwise
|
||||
"""
|
||||
if not model_args.fp8:
|
||||
if not training_args.fp8:
|
||||
return []
|
||||
|
||||
try:
|
||||
# Check if AORecipeKwargs is available (Accelerate 1.8.0+)
|
||||
from accelerate.utils import AORecipeKwargs
|
||||
backend = getattr(training_args, "fp8_backend", "auto")
|
||||
logger.info_rank0(f"Creating FP8 configuration with backend: {backend}")
|
||||
|
||||
backend = getattr(model_args, "fp8_backend", "auto")
|
||||
logger.info_rank0(f"Creating FP8 configuration with backend: {backend}")
|
||||
try:
|
||||
# Use Transformer Engine backend (optimal for Hopper GPUs)
|
||||
if backend == "te":
|
||||
from accelerate.utils import FP8RecipeKwargs
|
||||
|
||||
logger.info_rank0("Using Transformer Engine FP8 backend")
|
||||
return [FP8RecipeKwargs(backend="TE", fp8_format="HYBRID", amax_history_len=16, amax_compute_algo="max")]
|
||||
|
||||
# Use TorchAO backend (default)
|
||||
from accelerate.utils import AORecipeKwargs
|
||||
|
||||
# Create Float8LinearConfig if torchao backend is used
|
||||
config = None
|
||||
@@ -83,7 +93,10 @@ def create_fp8_kwargs(model_args: "ModelArguments") -> list[Any]:
|
||||
return True
|
||||
|
||||
# Map FSDP all-gather setting if available (this affects the underlying implementation)
|
||||
if hasattr(model_args, "fp8_enable_fsdp_float8_all_gather") and model_args.fp8_enable_fsdp_float8_all_gather:
|
||||
if (
|
||||
hasattr(training_args, "fp8_enable_fsdp_float8_all_gather")
|
||||
and training_args.fp8_enable_fsdp_float8_all_gather
|
||||
):
|
||||
logger.info_rank0("FSDP float8 all-gather optimization requested")
|
||||
|
||||
return [AORecipeKwargs(config=config, module_filter_func=module_filter_func)]
|
||||
@@ -92,19 +105,19 @@ def create_fp8_kwargs(model_args: "ModelArguments") -> list[Any]:
|
||||
return []
|
||||
|
||||
|
||||
def get_fp8_mixed_precision(model_args: "ModelArguments") -> Optional[str]:
|
||||
def get_fp8_mixed_precision(training_args: "TrainingArguments") -> Optional[str]:
|
||||
"""Get the mixed precision setting for Accelerate when using FP8.
|
||||
|
||||
Args:
|
||||
model_args: Model arguments containing FP8 configuration
|
||||
training_args: Training arguments containing FP8 configuration
|
||||
|
||||
Returns:
|
||||
"fp8" if FP8 is enabled, None otherwise
|
||||
"""
|
||||
return "fp8" if model_args.fp8 else None
|
||||
return "fp8" if training_args.fp8 else None
|
||||
|
||||
|
||||
def configure_fp8_environment(model_args: "ModelArguments") -> None:
|
||||
def configure_fp8_environment(training_args: "TrainingArguments") -> None:
|
||||
"""Configure FP8 environment for HuggingFace Accelerate.
|
||||
|
||||
FP8 training is handled entirely through HuggingFace Accelerate, regardless of whether
|
||||
@@ -112,11 +125,9 @@ def configure_fp8_environment(model_args: "ModelArguments") -> None:
|
||||
variables and validates the FP8 configuration.
|
||||
|
||||
Args:
|
||||
model_args: Model arguments containing FP8 configuration
|
||||
training_args: Training arguments containing FP8 configuration
|
||||
"""
|
||||
import os
|
||||
|
||||
if not model_args.fp8:
|
||||
if not training_args.fp8:
|
||||
return
|
||||
|
||||
# Set mixed precision to fp8 for HuggingFace Accelerate
|
||||
@@ -124,38 +135,38 @@ def configure_fp8_environment(model_args: "ModelArguments") -> None:
|
||||
logger.info_rank0("Set ACCELERATE_MIXED_PRECISION=fp8")
|
||||
|
||||
# Configure FP8 backend and options
|
||||
backend = getattr(model_args, "fp8_backend", "auto")
|
||||
backend = getattr(training_args, "fp8_backend", "auto")
|
||||
if backend != "auto":
|
||||
os.environ["FP8_BACKEND"] = backend
|
||||
logger.info_rank0(f"Set FP8_BACKEND={backend}")
|
||||
|
||||
# Create and validate FP8 recipe kwargs (for logging/debugging)
|
||||
fp8_kwargs = create_fp8_kwargs(model_args)
|
||||
fp8_kwargs = create_fp8_kwargs(training_args)
|
||||
logger.info_rank0(f"FP8 AORecipeKwargs created: {len(fp8_kwargs)} items")
|
||||
|
||||
# Enable FSDP float8 all-gather optimization if requested
|
||||
if hasattr(model_args, "fp8_enable_fsdp_float8_all_gather") and model_args.fp8_enable_fsdp_float8_all_gather:
|
||||
if hasattr(training_args, "fp8_enable_fsdp_float8_all_gather") and training_args.fp8_enable_fsdp_float8_all_gather:
|
||||
os.environ["FP8_ENABLE_FSDP_FLOAT8_ALL_GATHER"] = "true"
|
||||
logger.info_rank0("Set FP8_ENABLE_FSDP_FLOAT8_ALL_GATHER=true")
|
||||
|
||||
logger.info_rank0("FP8 environment configured - all FP8 training handled by HuggingFace Accelerate")
|
||||
|
||||
|
||||
def verify_fp8_status(accelerator, model_args: "ModelArguments") -> None:
|
||||
def verify_fp8_status(accelerator, training_args: "TrainingArguments") -> None:
|
||||
"""Verify that FP8 training is actually working after model preparation.
|
||||
|
||||
Args:
|
||||
accelerator: The HuggingFace Accelerator instance
|
||||
model_args: Model arguments containing FP8 configuration
|
||||
training_args: Training arguments containing FP8 configuration
|
||||
"""
|
||||
if not model_args.fp8:
|
||||
if not training_args.fp8:
|
||||
return
|
||||
|
||||
# Check Accelerate's FP8 status
|
||||
fp8_enabled = getattr(accelerator, "fp8_enabled", False)
|
||||
fp8_backend_type = getattr(accelerator, "fp8_backend", "UNKNOWN")
|
||||
|
||||
backend = getattr(model_args, "fp8_backend", "auto")
|
||||
backend = getattr(training_args, "fp8_backend", "auto")
|
||||
if backend == "torchao" or backend == "auto":
|
||||
logger.info_rank0(
|
||||
"FP8 training enabled with TorchAO backend. For optimal performance, "
|
||||
@@ -169,3 +180,50 @@ def verify_fp8_status(accelerator, model_args: "ModelArguments") -> None:
|
||||
|
||||
if not fp8_enabled:
|
||||
logger.info_rank0("WARNING: FP8 was requested but Accelerate shows fp8_enabled=False. FP8 may not be working.")
|
||||
|
||||
|
||||
def patch_accelerator_for_fp8() -> None:
|
||||
"""Patch Accelerator to inject FP8 recipe kwargs.
|
||||
|
||||
This is needed because HuggingFace Trainer doesn't pass kwargs_handlers to Accelerator.
|
||||
We monkey-patch Accelerator.__init__ to inject the FP8 recipe and force mixed_precision='fp8'.
|
||||
"""
|
||||
import transformer_engine.pytorch as te
|
||||
from accelerate import Accelerator
|
||||
|
||||
# Guard against multiple patches
|
||||
if getattr(Accelerator, "_te_fp8_patched", False):
|
||||
return
|
||||
|
||||
# Stub for Accelerate 1.12+ compatibility (te.fp8.check_mxfp8_support doesn't exist yet)
|
||||
if not hasattr(te, "fp8"):
|
||||
te.fp8 = types.ModuleType("fp8")
|
||||
te.fp8.check_mxfp8_support = lambda: (False, "MXFP8 not supported")
|
||||
|
||||
try:
|
||||
from accelerate.utils import TERecipeKwargs as FP8Recipe
|
||||
|
||||
use_te_recipe = True
|
||||
except ImportError:
|
||||
from accelerate.utils import FP8RecipeKwargs as FP8Recipe
|
||||
|
||||
use_te_recipe = False
|
||||
|
||||
original_init = Accelerator.__init__
|
||||
|
||||
def patched_init(self, *args, **kwargs):
|
||||
if "kwargs_handlers" not in kwargs or not kwargs["kwargs_handlers"]:
|
||||
if use_te_recipe:
|
||||
kwargs["kwargs_handlers"] = [
|
||||
FP8Recipe(fp8_format="HYBRID", amax_history_len=16, amax_compute_algo="max")
|
||||
]
|
||||
else:
|
||||
kwargs["kwargs_handlers"] = [
|
||||
FP8Recipe(backend="TE", fp8_format="HYBRID", amax_history_len=16, amax_compute_algo="max")
|
||||
]
|
||||
# Only force mixed_precision when we inject handlers
|
||||
kwargs["mixed_precision"] = "fp8"
|
||||
return original_init(self, *args, **kwargs)
|
||||
|
||||
Accelerator.__init__ = patched_init
|
||||
Accelerator._te_fp8_patched = True
|
||||
|
||||
@@ -19,16 +19,15 @@ import torch
|
||||
from transformers import Trainer
|
||||
from typing_extensions import override
|
||||
|
||||
from ...extras.packages import is_transformers_version_greater_than
|
||||
from ..callbacks import SaveProcessorCallback
|
||||
from ..fp8_utils import configure_fp8_environment, verify_fp8_status
|
||||
from ..fp8_utils import configure_fp8_environment, patch_accelerator_for_fp8, verify_fp8_status
|
||||
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import ProcessorMixin
|
||||
|
||||
from ...hparams import FinetuningArguments, ModelArguments
|
||||
from ...hparams import FinetuningArguments, ModelArguments, TrainingArguments
|
||||
|
||||
|
||||
class CustomTrainer(Trainer):
|
||||
@@ -41,11 +40,13 @@ class CustomTrainer(Trainer):
|
||||
model_args: Optional["ModelArguments"] = None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
kwargs["processing_class"] = kwargs.pop("tokenizer")
|
||||
# Configure FP8 environment if enabled
|
||||
if model_args is not None and model_args.fp8:
|
||||
configure_fp8_environment(model_args)
|
||||
if is_transformers_version_greater_than("4.46"):
|
||||
kwargs["processing_class"] = kwargs.pop("tokenizer")
|
||||
training_args: TrainingArguments = kwargs.get("args")
|
||||
if training_args.fp8:
|
||||
configure_fp8_environment(training_args)
|
||||
if getattr(training_args, "fp8_backend", "auto") == "te":
|
||||
patch_accelerator_for_fp8()
|
||||
|
||||
super().__init__(**kwargs)
|
||||
if processor is not None:
|
||||
@@ -64,9 +65,8 @@ class CustomTrainer(Trainer):
|
||||
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
|
||||
self.add_callback(BAdamCallback)
|
||||
|
||||
# Verify FP8 status after trainer initialization (accelerator should be available)
|
||||
if model_args is not None and model_args.fp8 and hasattr(self, "accelerator"):
|
||||
verify_fp8_status(self.accelerator, model_args)
|
||||
if training_args.fp8 and hasattr(self, "accelerator"): # verify FP8 status after trainer initialization
|
||||
verify_fp8_status(self.accelerator, training_args)
|
||||
|
||||
@override
|
||||
def create_optimizer(self) -> "torch.optim.Optimizer":
|
||||
|
||||
@@ -27,18 +27,17 @@ from typing_extensions import override
|
||||
|
||||
from ...extras import logging
|
||||
from ...extras.constants import IGNORE_INDEX
|
||||
from ...extras.packages import is_transformers_version_greater_than
|
||||
from ..callbacks import SaveProcessorCallback
|
||||
from ..fp8_utils import configure_fp8_environment, verify_fp8_status
|
||||
from ..fp8_utils import configure_fp8_environment, patch_accelerator_for_fp8, verify_fp8_status
|
||||
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from torch.utils.data import Dataset
|
||||
from transformers import PreTrainedTokenizer, ProcessorMixin
|
||||
from transformers import ProcessorMixin
|
||||
from transformers.trainer import PredictionOutput
|
||||
|
||||
from ...hparams import FinetuningArguments, ModelArguments
|
||||
from ...hparams import FinetuningArguments, ModelArguments, TrainingArguments
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
@@ -55,13 +54,13 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
|
||||
gen_kwargs: Optional[dict[str, Any]] = None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
kwargs["processing_class"] = kwargs.pop("tokenizer")
|
||||
# Configure FP8 environment if enabled
|
||||
if model_args is not None and model_args.fp8:
|
||||
configure_fp8_environment(model_args)
|
||||
if is_transformers_version_greater_than("4.46"):
|
||||
kwargs["processing_class"] = kwargs.pop("tokenizer")
|
||||
else:
|
||||
self.processing_class: PreTrainedTokenizer = kwargs.get("tokenizer")
|
||||
training_args: TrainingArguments = kwargs.get("args")
|
||||
if training_args.fp8:
|
||||
configure_fp8_environment(training_args)
|
||||
if getattr(training_args, "fp8_backend", "auto") == "te":
|
||||
patch_accelerator_for_fp8()
|
||||
|
||||
super().__init__(**kwargs)
|
||||
if processor is not None:
|
||||
@@ -88,9 +87,15 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
|
||||
|
||||
self.compute_loss_func = dft_loss_func
|
||||
|
||||
# Verify FP8 status after trainer initialization (accelerator should be available)
|
||||
if model_args is not None and model_args.fp8 and hasattr(self, "accelerator"):
|
||||
verify_fp8_status(self.accelerator, model_args)
|
||||
elif finetuning_args.use_eaft_loss:
|
||||
from ..trainer_utils import eaft_loss_func
|
||||
|
||||
self.compute_loss_func = lambda outputs, labels, num_items_in_batch=None: eaft_loss_func(
|
||||
outputs, labels, num_items_in_batch, finetuning_args.eaft_alpha
|
||||
)
|
||||
|
||||
if training_args.fp8 and hasattr(self, "accelerator"): # verify FP8 status after trainer initialization
|
||||
verify_fp8_status(self.accelerator, training_args)
|
||||
|
||||
@override
|
||||
def create_optimizer(self) -> "torch.optim.Optimizer":
|
||||
|
||||
@@ -634,7 +634,9 @@ def get_batch_logps(
|
||||
return logps, valid_length
|
||||
|
||||
|
||||
def dft_loss_func(outputs, labels, num_items_in_batch=None):
|
||||
def dft_loss_func(
|
||||
outputs: "torch.Tensor", labels: "torch.Tensor", num_items_in_batch: Optional["torch.Tensor"] = None
|
||||
):
|
||||
logits = outputs.get("logits")
|
||||
if logits is None:
|
||||
return outputs.get("loss", torch.tensor(0.0))
|
||||
@@ -652,11 +654,11 @@ def dft_loss_func(outputs, labels, num_items_in_batch=None):
|
||||
|
||||
|
||||
def _dft_cross_entropy(
|
||||
source: torch.Tensor,
|
||||
target: torch.Tensor,
|
||||
num_items_in_batch: Optional[torch.Tensor] = None,
|
||||
source: "torch.Tensor",
|
||||
target: "torch.Tensor",
|
||||
num_items_in_batch: Optional["torch.Tensor"] = None,
|
||||
ignore_index: int = -100,
|
||||
) -> torch.Tensor:
|
||||
) -> "torch.Tensor":
|
||||
per_token_loss = torch.nn.functional.cross_entropy(source, target, ignore_index=ignore_index, reduction="none")
|
||||
valid_mask = target != ignore_index
|
||||
if not valid_mask.any():
|
||||
@@ -679,6 +681,67 @@ def _dft_cross_entropy(
|
||||
return loss
|
||||
|
||||
|
||||
def eaft_loss_func(
|
||||
outputs: "torch.Tensor",
|
||||
labels: "torch.Tensor",
|
||||
num_items_in_batch: Optional["torch.Tensor"] = None,
|
||||
alpha: float = 1.0,
|
||||
) -> "torch.Tensor":
|
||||
logits = outputs.get("logits")
|
||||
if logits is None:
|
||||
return outputs.get("loss", torch.tensor(0.0))
|
||||
|
||||
logits = logits.float()
|
||||
vocab_size = logits.size(-1)
|
||||
labels = torch.nn.functional.pad(labels, (0, 1), value=-100)
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
logits = logits.view(-1, vocab_size)
|
||||
shift_labels = shift_labels.view(-1)
|
||||
shift_labels = shift_labels.to(logits.device)
|
||||
|
||||
loss = _eaft_cross_entropy(logits, shift_labels, num_items_in_batch, alpha)
|
||||
return loss
|
||||
|
||||
|
||||
def _eaft_cross_entropy(
|
||||
source: "torch.Tensor",
|
||||
target: "torch.Tensor",
|
||||
num_items_in_batch: Optional["torch.Tensor"] = None,
|
||||
alpha: float = 1.0,
|
||||
ignore_index: int = -100,
|
||||
) -> "torch.Tensor":
|
||||
per_token_loss = torch.nn.functional.cross_entropy(source, target, ignore_index=ignore_index, reduction="none")
|
||||
valid_mask = target != ignore_index
|
||||
if not valid_mask.any():
|
||||
return torch.tensor(0.0, device=source.device, dtype=source.dtype)
|
||||
|
||||
valid_losses = per_token_loss[valid_mask]
|
||||
|
||||
with torch.no_grad():
|
||||
source_detached = source[valid_mask].detach()
|
||||
|
||||
topk_val, _ = torch.topk(source_detached, k=20, dim=-1)
|
||||
logsumexp_topk = torch.logsumexp(topk_val, dim=-1, keepdim=True)
|
||||
log_probs_topk = topk_val - logsumexp_topk
|
||||
probs_topk = torch.exp(log_probs_topk)
|
||||
entropy_approx = -(probs_topk * log_probs_topk).sum(dim=-1)
|
||||
|
||||
entropy_term = entropy_approx / 3.0
|
||||
adaptive_weight = torch.pow(entropy_term, alpha)
|
||||
|
||||
weighted_losses = valid_losses * adaptive_weight
|
||||
|
||||
if num_items_in_batch is not None:
|
||||
total_loss = weighted_losses.sum()
|
||||
if torch.is_tensor(num_items_in_batch):
|
||||
num_items_in_batch = num_items_in_batch.to(total_loss.device)
|
||||
loss = total_loss / num_items_in_batch
|
||||
else:
|
||||
loss = weighted_losses.mean()
|
||||
|
||||
return loss
|
||||
|
||||
|
||||
def nested_detach(
|
||||
tensors: Union["torch.Tensor", list["torch.Tensor"], tuple["torch.Tensor"], dict[str, "torch.Tensor"]],
|
||||
clone: bool = False,
|
||||
|
||||
@@ -119,9 +119,19 @@ def synchronize() -> None:
|
||||
|
||||
|
||||
@requires_accelerator
|
||||
def set_device() -> None:
|
||||
"""Set current accelerator."""
|
||||
torch.accelerator.set_device_index(get_local_rank())
|
||||
def set_device_index() -> None:
|
||||
"""Set current accelerator index to local rank."""
|
||||
if get_current_accelerator().type != DeviceType.CPU:
|
||||
torch.accelerator.set_device_index(get_local_rank())
|
||||
|
||||
|
||||
@requires_accelerator
|
||||
def get_current_device() -> torch.device:
|
||||
"""Get current accelerator device."""
|
||||
if get_current_accelerator().type == DeviceType.CPU:
|
||||
return torch.device(DeviceType.CPU.value)
|
||||
else:
|
||||
return torch.device(type=get_current_accelerator().type, index=torch.accelerator.current_device_index())
|
||||
|
||||
|
||||
def is_torch_cuda_available():
|
||||
|
||||
@@ -34,10 +34,14 @@ from typing import Any, Optional
|
||||
from torch.distributed import barrier, destroy_process_group, init_process_group
|
||||
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
|
||||
|
||||
from ..utils.types import DistributedConfig, ProcessGroup, Tensor, TensorLike
|
||||
from ..utils import logging
|
||||
from ..utils.types import DistributedConfig, ProcessGroup, TensorLike
|
||||
from . import helper
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class Dim(str, Enum):
|
||||
"""Dimension names."""
|
||||
|
||||
@@ -119,12 +123,13 @@ class DistributedInterface:
|
||||
if self._initialized:
|
||||
return
|
||||
|
||||
helper.set_device_index()
|
||||
self._is_distributed = helper.is_distributed()
|
||||
self._rank = helper.get_rank()
|
||||
self._world_size = helper.get_world_size()
|
||||
self._local_rank = helper.get_local_rank()
|
||||
self._local_world_size = helper.get_local_world_size()
|
||||
self.current_accelerator = helper.get_current_accelerator()
|
||||
self.current_device = helper.get_current_device()
|
||||
self.device_count = helper.get_device_count()
|
||||
|
||||
if config is None:
|
||||
@@ -140,15 +145,14 @@ class DistributedInterface:
|
||||
timeout = config.get("timeout", 18000)
|
||||
|
||||
if self._is_distributed:
|
||||
helper.set_device()
|
||||
init_process_group(timeout=timedelta(seconds=timeout))
|
||||
self.model_device_mesh = init_device_mesh(
|
||||
device_type=self.current_accelerator.type,
|
||||
device_type=self.current_device.type,
|
||||
mesh_shape=self.strategy.model_mesh_shape,
|
||||
mesh_dim_names=self.strategy.model_mesh_dim_names,
|
||||
)
|
||||
self.data_device_mesh = init_device_mesh(
|
||||
device_type=self.current_accelerator.type,
|
||||
device_type=self.current_device.type,
|
||||
mesh_shape=self.strategy.data_mesh_shape,
|
||||
mesh_dim_names=self.strategy.data_mesh_dim_names,
|
||||
)
|
||||
@@ -157,11 +161,12 @@ class DistributedInterface:
|
||||
self.data_device_mesh = None
|
||||
|
||||
self._initialized = True
|
||||
logger.info_rank0(f"DistributedInterface initialized: {self}.")
|
||||
|
||||
def __str__(self) -> str:
|
||||
return (
|
||||
f"DistributedInterface(strategy={self.strategy}), is_distributed={self._is_distributed}, "
|
||||
f"current_accelerator={self.current_accelerator}, rank={self._rank}, world_size={self._world_size}, "
|
||||
f"current_device={self.current_device}, rank={self._rank}, world_size={self._world_size}, "
|
||||
f"model_device_mesh={self.model_device_mesh}, data_device_mesh={self.data_device_mesh}"
|
||||
)
|
||||
|
||||
@@ -209,7 +214,7 @@ class DistributedInterface:
|
||||
"""Get parallel local world size."""
|
||||
return self._local_world_size
|
||||
|
||||
def all_gather(self, data: Tensor, dim: Dim | None = Dim.DP) -> Tensor:
|
||||
def all_gather(self, data: TensorLike, dim: Dim | None = Dim.DP) -> TensorLike:
|
||||
"""Gather tensor across specified parallel group."""
|
||||
if self.model_device_mesh is not None:
|
||||
return helper.operate_tensorlike(helper.all_gather, data, group=self.get_group(dim))
|
||||
@@ -246,4 +251,7 @@ class DistributedInterface:
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print(DistributedInterface(DistributedStrategy()))
|
||||
"""
|
||||
python -m llamafactory.v1.accelerator.interface
|
||||
"""
|
||||
print(DistributedInterface())
|
||||
|
||||
@@ -0,0 +1,32 @@
|
||||
# Copyright 2025 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .arg_parser import InputArgument, get_args
|
||||
from .arg_utils import ModelClass, SampleBackend
|
||||
from .data_args import DataArguments
|
||||
from .model_args import ModelArguments
|
||||
from .sample_args import SampleArguments
|
||||
from .training_args import TrainingArguments
|
||||
|
||||
|
||||
__all__ = [
|
||||
"DataArguments",
|
||||
"InputArgument",
|
||||
"ModelArguments",
|
||||
"ModelClass",
|
||||
"SampleArguments",
|
||||
"SampleBackend",
|
||||
"TrainingArguments",
|
||||
"get_args",
|
||||
]
|
||||
|
||||
@@ -30,21 +30,6 @@ from .training_args import TrainingArguments
|
||||
InputArgument = dict[str, Any] | list[str] | None
|
||||
|
||||
|
||||
def validate_args(
|
||||
data_args: DataArguments,
|
||||
model_args: ModelArguments,
|
||||
training_args: TrainingArguments,
|
||||
sample_args: SampleArguments,
|
||||
):
|
||||
"""Validate arguments."""
|
||||
if (
|
||||
model_args.quant_config is not None
|
||||
and training_args.dist_config is not None
|
||||
and training_args.dist_config.name == "deepspeed"
|
||||
):
|
||||
raise ValueError("Quantization is not supported with deepspeed backend.")
|
||||
|
||||
|
||||
def get_args(args: InputArgument = None) -> tuple[DataArguments, ModelArguments, TrainingArguments, SampleArguments]:
|
||||
"""Parse arguments from command line or config file."""
|
||||
parser = HfArgumentParser([DataArguments, ModelArguments, TrainingArguments, SampleArguments])
|
||||
@@ -71,8 +56,6 @@ def get_args(args: InputArgument = None) -> tuple[DataArguments, ModelArguments,
|
||||
print(f"Got unknown args, potentially deprecated arguments: {unknown_args}")
|
||||
raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {unknown_args}")
|
||||
|
||||
validate_args(*parsed_args)
|
||||
|
||||
return tuple(parsed_args)
|
||||
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@
|
||||
|
||||
|
||||
import json
|
||||
from enum import Enum, unique
|
||||
from enum import StrEnum, unique
|
||||
|
||||
|
||||
class PluginConfig(dict):
|
||||
@@ -36,7 +36,7 @@ PluginArgument = PluginConfig | dict | str | None
|
||||
|
||||
|
||||
@unique
|
||||
class ModelClass(str, Enum):
|
||||
class ModelClass(StrEnum):
|
||||
"""Auto class for model config."""
|
||||
|
||||
LLM = "llm"
|
||||
@@ -45,7 +45,7 @@ class ModelClass(str, Enum):
|
||||
|
||||
|
||||
@unique
|
||||
class SampleBackend(str, Enum):
|
||||
class SampleBackend(StrEnum):
|
||||
HF = "hf"
|
||||
VLLM = "vllm"
|
||||
|
||||
|
||||
@@ -21,20 +21,25 @@ from .arg_utils import ModelClass, PluginConfig, get_plugin_config
|
||||
@dataclass
|
||||
class ModelArguments:
|
||||
model: str = field(
|
||||
default="Qwen/Qwen3-4B-Instruct-2507",
|
||||
metadata={"help": "Path to the model or model identifier from Hugging Face."},
|
||||
)
|
||||
template: str = field(
|
||||
default="qwen3_nothink",
|
||||
metadata={"help": "Template for the model."},
|
||||
)
|
||||
trust_remote_code: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Trust remote code from Hugging Face."},
|
||||
)
|
||||
use_fast_processor: bool = field(
|
||||
default=True,
|
||||
metadata={"help": "Use fast processor from Hugging Face."},
|
||||
)
|
||||
model_class: ModelClass = field(
|
||||
default=ModelClass.LLM,
|
||||
metadata={"help": "Model class from Hugging Face."},
|
||||
)
|
||||
init_config: PluginConfig | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Initialization configuration for the model."},
|
||||
)
|
||||
peft_config: PluginConfig | None = field(
|
||||
default=None,
|
||||
metadata={"help": "PEFT configuration for the model."},
|
||||
@@ -49,6 +54,7 @@ class ModelArguments:
|
||||
)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self.init_config = get_plugin_config(self.init_config)
|
||||
self.peft_config = get_plugin_config(self.peft_config)
|
||||
self.kernel_config = get_plugin_config(self.kernel_config)
|
||||
self.quant_config = get_plugin_config(self.quant_config)
|
||||
|
||||
@@ -22,7 +22,7 @@ from .arg_utils import PluginConfig, get_plugin_config
|
||||
@dataclass
|
||||
class TrainingArguments:
|
||||
output_dir: str = field(
|
||||
default=os.path.join("outputs", str(uuid4())),
|
||||
default=os.path.join("outputs", str(uuid4().hex)),
|
||||
metadata={"help": "Path to the output directory."},
|
||||
)
|
||||
micro_batch_size: int = field(
|
||||
|
||||
181
src/llamafactory/v1/core/base_sampler.py
Normal file
181
src/llamafactory/v1/core/base_sampler.py
Normal file
@@ -0,0 +1,181 @@
|
||||
# Copyright 2025 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import AsyncGenerator, Generator
|
||||
from threading import Thread
|
||||
|
||||
import torch
|
||||
from transformers import TextIteratorStreamer
|
||||
|
||||
from ..accelerator.interface import DistributedInterface
|
||||
from ..config import ModelArguments, SampleArguments, SampleBackend
|
||||
from ..utils.helper import get_tokenizer
|
||||
from ..utils.types import HFModel, Message, Sample, TorchDataset
|
||||
from .utils.rendering import Renderer
|
||||
|
||||
|
||||
class BaseEngine(ABC):
|
||||
@abstractmethod
|
||||
def __init__(
|
||||
self,
|
||||
args: SampleArguments,
|
||||
model_args: ModelArguments,
|
||||
model: HFModel,
|
||||
renderer: Renderer,
|
||||
) -> None:
|
||||
"""Initialize the engine.
|
||||
|
||||
Args:
|
||||
args: Sample arguments.
|
||||
model_args: Model arguments.
|
||||
model: Model.
|
||||
renderer: Renderer.
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def generate(self, messages: list[Message], tools: str | None = None) -> AsyncGenerator[str, None]:
|
||||
"""Generate tokens asynchronously.
|
||||
|
||||
Args:
|
||||
messages: List of messages.
|
||||
tools: Tools string.
|
||||
|
||||
Yields:
|
||||
Generated tokens.
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def batch_infer(self, dataset: TorchDataset) -> list[Sample]:
|
||||
"""Batch infer samples.
|
||||
|
||||
Args:
|
||||
dataset: Torch dataset.
|
||||
|
||||
Returns:
|
||||
List of samples.
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class HuggingFaceEngine(BaseEngine):
|
||||
def __init__(
|
||||
self,
|
||||
args: SampleArguments,
|
||||
model_args: ModelArguments,
|
||||
model: HFModel,
|
||||
renderer: Renderer,
|
||||
) -> None:
|
||||
self.args = args
|
||||
self.model_args = model_args
|
||||
self.model = model
|
||||
self.renderer = renderer
|
||||
self.semaphore = asyncio.Semaphore(int(os.getenv("MAX_CONCURRENT", "1")))
|
||||
|
||||
@torch.inference_mode()
|
||||
def get_response(self, messages: list[Message], tools: str | None = None) -> Generator[str, None, None]:
|
||||
model_inputs = self.renderer.render_messages(messages, tools, is_generate=True)
|
||||
streamer = TextIteratorStreamer(
|
||||
tokenizer=get_tokenizer(self.renderer.processor),
|
||||
skip_prompt=True,
|
||||
skip_special_tokens=True, # TODO: configurable
|
||||
)
|
||||
device = DistributedInterface().current_device
|
||||
kwargs = {
|
||||
"input_ids": torch.tensor([model_inputs["input_ids"]]).to(device),
|
||||
"attention_mask": torch.tensor([model_inputs["attention_mask"]]).to(device),
|
||||
"max_new_tokens": self.args.max_new_tokens,
|
||||
"streamer": streamer,
|
||||
}
|
||||
thread = Thread(target=self.model.generate, kwargs=kwargs, daemon=True)
|
||||
thread.start()
|
||||
|
||||
def stream():
|
||||
try:
|
||||
return streamer.__next__()
|
||||
except StopIteration:
|
||||
raise StopAsyncIteration()
|
||||
|
||||
return stream
|
||||
|
||||
async def generate(self, messages: list[Message], tools: str | None = None) -> AsyncGenerator[str, None]:
|
||||
async with self.semaphore:
|
||||
response = self.get_response(messages, tools)
|
||||
while True:
|
||||
try:
|
||||
yield await asyncio.to_thread(response)
|
||||
except StopAsyncIteration:
|
||||
break
|
||||
|
||||
async def batch_infer(self, dataset: TorchDataset) -> list[Sample]:
|
||||
"""Batch infer samples.
|
||||
|
||||
Args:
|
||||
dataset: Torch dataset.
|
||||
|
||||
Returns:
|
||||
List of samples.
|
||||
"""
|
||||
raise NotImplementedError("Batch infer is not implemented.")
|
||||
|
||||
|
||||
class BaseSampler:
|
||||
"""Base sampler.
|
||||
|
||||
Args:
|
||||
args: Sample arguments.
|
||||
model_args: Model arguments.
|
||||
model: Model.
|
||||
renderer: Renderer.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
args: SampleArguments,
|
||||
model_args: ModelArguments,
|
||||
model: HFModel,
|
||||
renderer: Renderer,
|
||||
) -> None:
|
||||
if args.sample_backend == SampleBackend.HF:
|
||||
self.engine = HuggingFaceEngine(args, model_args, model, renderer)
|
||||
else:
|
||||
raise ValueError(f"Unknown sample backend: {args.sample_backend}")
|
||||
|
||||
async def generate(self, messages: list[Message], tools: str | None = None) -> AsyncGenerator[str, None]:
|
||||
"""Generate tokens asynchronously.
|
||||
|
||||
Args:
|
||||
messages: List of messages.
|
||||
tools: Tools string.
|
||||
|
||||
Yields:
|
||||
Generated tokens.
|
||||
"""
|
||||
async for token in self.engine.generate(messages, tools):
|
||||
yield token
|
||||
|
||||
async def batch_infer(self, dataset: TorchDataset) -> list[Sample]:
|
||||
"""Batch infer samples.
|
||||
|
||||
Args:
|
||||
dataset: Torch dataset.
|
||||
|
||||
Returns:
|
||||
List of samples.
|
||||
"""
|
||||
return await self.engine.batch_infer(dataset)
|
||||
@@ -1,44 +0,0 @@
|
||||
# Copyright 2025 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from ..config.sample_args import SampleArguments, SampleBackend
|
||||
from .model_loader import ModelLoader
|
||||
|
||||
|
||||
class BaseEngine(ABC):
|
||||
@abstractmethod
|
||||
def __init__(self, sample_args: SampleArguments, model_loader: ModelLoader) -> None: ...
|
||||
|
||||
@abstractmethod
|
||||
async def generate(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def batch_infer(self):
|
||||
pass
|
||||
|
||||
|
||||
class HuggingFaceEngine(BaseEngine):
|
||||
def __init__(self, model_loader: ModelLoader, sample_args: SampleArguments) -> None:
|
||||
self.args = sample_args
|
||||
|
||||
|
||||
class ChatSampler:
|
||||
def __init__(self, model_loader: ModelLoader, sample_args: SampleArguments) -> None:
|
||||
if sample_args.sample_backend == SampleBackend.HF:
|
||||
self.engine = HuggingFaceEngine(model_loader, sample_args)
|
||||
else:
|
||||
raise ValueError(f"Unknown sample backend: {sample_args.sample_backend}")
|
||||
@@ -14,15 +14,23 @@
|
||||
|
||||
"""The definition of data engine.
|
||||
|
||||
Init Data engine:
|
||||
How to use:
|
||||
data_engine = DataEngine(data_args)
|
||||
data_engine[i]: Get the sample via index.
|
||||
|
||||
Init workflow:
|
||||
1. Parse dataset info from arguments.
|
||||
2. Load datasets according to dataset info.
|
||||
3. Build data index (and reweight samples if necessary).
|
||||
|
||||
Get Data Sample:
|
||||
Get data sample:
|
||||
1. Get sample from data index.
|
||||
2. Convert sample to standard format.
|
||||
3. Return sample.
|
||||
|
||||
Note:
|
||||
1. The data engine is equivalent to the torch dataset.
|
||||
2. The data engine is agnostic to the model used.
|
||||
"""
|
||||
|
||||
import os
|
||||
@@ -98,10 +106,10 @@ class DataEngine(Dataset):
|
||||
|
||||
size = self.dataset_infos[dataset_name].get("size")
|
||||
weight = self.dataset_infos[dataset_name].get("weight")
|
||||
if size or weight: # data index plugin
|
||||
from ..plugins.data_plugins.loader import DataIndexPlugin
|
||||
if size or weight:
|
||||
from ..plugins.data_plugins.loader import adjust_data_index
|
||||
|
||||
data_index = DataIndexPlugin().adjust_data_index(data_index, size, weight)
|
||||
data_index = adjust_data_index(data_index, size, weight)
|
||||
|
||||
self.data_index.extend(data_index)
|
||||
|
||||
@@ -150,9 +158,9 @@ class DataEngine(Dataset):
|
||||
dataset_name, sample_index = self.data_index[index]
|
||||
return self._convert_data_sample(self.datasets[dataset_name][sample_index], dataset_name)
|
||||
else: # data selector plugin
|
||||
from ..plugins.data_plugins.loader import DataSelectorPlugin
|
||||
from ..plugins.data_plugins.loader import select_data_sample
|
||||
|
||||
selected_index = DataSelectorPlugin().select(self.data_index, index)
|
||||
selected_index = select_data_sample(self.data_index, index)
|
||||
if isinstance(selected_index, list):
|
||||
return [
|
||||
self._convert_data_sample(self.datasets[dataset_name][sample_index], dataset_name)
|
||||
|
||||
@@ -12,34 +12,44 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""The definition of model loader.
|
||||
"""The definition of model engine.
|
||||
|
||||
Init Phase:
|
||||
How to use:
|
||||
model_engine = ModelEngine(model_args, is_train=True)
|
||||
model_engine.processor: Get the tokenizer or multi-modal processor.
|
||||
model_engine.renderer: Get the renderer.
|
||||
model_engine.model_config: Get the model configuration.
|
||||
model_engine.model: Get the HF model.
|
||||
|
||||
Init workflow:
|
||||
1. Init processor.
|
||||
2. Init render.
|
||||
2. Init model config.
|
||||
3. Init model.
|
||||
4. Init adapter.
|
||||
|
||||
"""
|
||||
|
||||
import torch
|
||||
from accelerate import init_empty_weights
|
||||
from transformers import AutoConfig, AutoProcessor
|
||||
|
||||
from ..accelerator.helper import DeviceType
|
||||
from ..accelerator.interface import DistributedInterface
|
||||
from ..config.model_args import ModelArguments, ModelClass
|
||||
from ..utils import logging
|
||||
from ..utils.types import HFConfig, HFModel, Processor
|
||||
from .utils.rendering import Renderer
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class ModelLoader:
|
||||
"""Model loader.
|
||||
class ModelEngine:
|
||||
"""Model engine.
|
||||
|
||||
Args:
|
||||
model_args: Model arguments.
|
||||
is_trainable: Whether to train the model.
|
||||
is_train: Whether to train the model.
|
||||
"""
|
||||
|
||||
def __init__(self, model_args: ModelArguments, is_train: bool = False) -> None:
|
||||
@@ -49,17 +59,22 @@ class ModelLoader:
|
||||
"""Whether to train the model."""
|
||||
self.processor = self._init_processor()
|
||||
"""Tokenizer or multi-modal processor."""
|
||||
self.renderer = Renderer(self.args.template, self.processor)
|
||||
"""Renderer."""
|
||||
self.model_config = self._init_model_config()
|
||||
"""Model configuration."""
|
||||
self.model = self._init_model()
|
||||
"""HF model."""
|
||||
|
||||
def _init_processor(self) -> Processor:
|
||||
"""Init processor."""
|
||||
"""Init processor.
|
||||
|
||||
NOTE: Transformers v5 always use fast tokenizer.
|
||||
https://github.com/huggingface/transformers/blob/v5.0.0rc1/src/transformers/models/auto/tokenization_auto.py#L642
|
||||
"""
|
||||
return AutoProcessor.from_pretrained(
|
||||
self.args.model,
|
||||
trust_remote_code=self.args.trust_remote_code,
|
||||
use_fast=self.args.use_fast_processor,
|
||||
)
|
||||
|
||||
def _init_model_config(self) -> HFConfig:
|
||||
@@ -92,14 +107,24 @@ class ModelLoader:
|
||||
|
||||
AutoClass = AutoModel
|
||||
|
||||
# map the entire model to the current accelerator
|
||||
model = AutoClass.from_pretrained(
|
||||
self.args.model,
|
||||
config=self.model_config,
|
||||
dtype="auto",
|
||||
device_map=DistributedInterface().current_accelerator,
|
||||
trust_remote_code=self.args.trust_remote_code,
|
||||
)
|
||||
if self.args.init_config is not None:
|
||||
from ..plugins.model_plugins.initialization import InitPlugin
|
||||
|
||||
init_device = InitPlugin(self.args.init_config.name)()
|
||||
else:
|
||||
init_device = DistributedInterface().current_device
|
||||
|
||||
if init_device.type == DeviceType.META:
|
||||
with init_empty_weights():
|
||||
model = AutoClass.from_config(self.model_config)
|
||||
else:
|
||||
model = AutoClass.from_pretrained(
|
||||
self.args.model,
|
||||
config=self.model_config,
|
||||
dtype="auto",
|
||||
device_map=init_device,
|
||||
trust_remote_code=self.args.trust_remote_code,
|
||||
)
|
||||
|
||||
if self.args.peft_config is None:
|
||||
if self.is_train:
|
||||
@@ -124,12 +149,12 @@ class ModelLoader:
|
||||
|
||||
if __name__ == "__main__":
|
||||
"""
|
||||
python -m llamafactory.v1.core.model_loader --model llamafactory/tiny-random-qwen2.5
|
||||
python -m llamafactory.v1.core.model_engine --model llamafactory/tiny-random-qwen2.5
|
||||
"""
|
||||
from ..config.arg_parser import get_args
|
||||
|
||||
_, model_args, *_ = get_args()
|
||||
model_loader = ModelLoader(model_args=model_args)
|
||||
print(model_loader.processor)
|
||||
print(model_loader.model_config)
|
||||
print(model_loader.model)
|
||||
model_engine = ModelEngine(model_args=model_args)
|
||||
print(model_engine.processor)
|
||||
print(model_engine.model_config)
|
||||
print(model_engine.model)
|
||||
99
src/llamafactory/v1/core/utils/rendering.py
Normal file
99
src/llamafactory/v1/core/utils/rendering.py
Normal file
@@ -0,0 +1,99 @@
|
||||
# Copyright 2025 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from ...utils.constants import IGNORE_INDEX
|
||||
from ...utils.helper import get_tokenizer
|
||||
from ...utils.types import Message, ModelInput, Processor
|
||||
|
||||
|
||||
def render_chatml_messages(
|
||||
processor: Processor,
|
||||
messages: list[Message],
|
||||
tools: str | None = None,
|
||||
is_generate: bool = False,
|
||||
) -> ModelInput:
|
||||
"""Apply chatml template to messages and convert them to model input.
|
||||
|
||||
See https://huggingface.co/spaces/huggingfacejs/chat-template-playground?modelId=Qwen/Qwen2-7B-Instruct
|
||||
"""
|
||||
tokenizer = get_tokenizer(processor)
|
||||
input_ids, labels, loss_weights = [], [], []
|
||||
|
||||
for message in messages:
|
||||
temp_str = "<|im_start|>" + message["role"] + "\n"
|
||||
for content in message["content"]:
|
||||
if content["type"] == "text":
|
||||
temp_str += content["value"]
|
||||
else:
|
||||
raise ValueError(f"Unsupported content type: {content['type']}")
|
||||
|
||||
temp_str += "<|im_end|>\n"
|
||||
temp_weight = message.get("loss_weight", 1.0 if message["role"] == "assistant" else 0.0)
|
||||
temp_ids = tokenizer.encode(temp_str, add_special_tokens=False)
|
||||
input_ids.extend(temp_ids)
|
||||
loss_weights.extend([temp_weight] * len(temp_ids))
|
||||
if temp_weight > 1e-6:
|
||||
labels.extend(temp_ids)
|
||||
else:
|
||||
labels.extend([IGNORE_INDEX] * len(temp_ids))
|
||||
|
||||
if is_generate:
|
||||
temp_ids = tokenizer.encode("<|im_start|>assistant\n", add_special_tokens=False)
|
||||
input_ids.extend(temp_ids)
|
||||
loss_weights.extend([0.0] * len(temp_ids))
|
||||
labels.extend([IGNORE_INDEX] * len(temp_ids))
|
||||
|
||||
return ModelInput(
|
||||
input_ids=input_ids,
|
||||
attention_mask=[1] * len(input_ids),
|
||||
labels=labels,
|
||||
loss_weights=loss_weights,
|
||||
)
|
||||
|
||||
|
||||
def parse_chatml_message(generated_text: str) -> Message:
|
||||
"""Parse a message in ChatML format. Supports interleaved reasoning and tool calls.
|
||||
|
||||
Args:
|
||||
generated_text (str): The generated text in ChatML format.
|
||||
|
||||
Returns:
|
||||
Message: The parsed message.
|
||||
"""
|
||||
return Message(role="assistant", content=[{"type": "text", "value": generated_text}])
|
||||
|
||||
|
||||
class Renderer:
|
||||
def __init__(self, template: str, processor: Processor):
|
||||
self.template = template
|
||||
self.processor = processor
|
||||
|
||||
def render_messages(
|
||||
self, messages: list[Message], tools: str | None = None, is_generate: bool = False
|
||||
) -> ModelInput:
|
||||
if self.template == "chatml":
|
||||
return render_chatml_messages(self.processor, messages, tools, is_generate)
|
||||
else:
|
||||
from ...plugins.model_plugins.rendering import RenderingPlugin
|
||||
|
||||
return RenderingPlugin(self.template).render_messages(self.processor, messages, tools, is_generate)
|
||||
|
||||
def parse_message(self, generated_text: str) -> Message:
|
||||
if self.template == "chatml":
|
||||
return parse_chatml_message(generated_text)
|
||||
else:
|
||||
from ...plugins.model_plugins.rendering import RenderingPlugin
|
||||
|
||||
return RenderingPlugin(self.template).parse_message(generated_text)
|
||||
@@ -49,6 +49,11 @@ def launch():
|
||||
|
||||
run_sft()
|
||||
|
||||
elif command == "chat":
|
||||
from .samplers.cli_sampler import run_chat
|
||||
|
||||
run_chat()
|
||||
|
||||
elif command == "env":
|
||||
print_env()
|
||||
|
||||
|
||||
@@ -13,11 +13,12 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import json
|
||||
from typing import Any, Literal, NotRequired, TypedDict
|
||||
|
||||
from ...utils import logging
|
||||
from ...utils.plugin import BasePlugin
|
||||
from ...utils.types import DPOSample, Sample, SFTSample
|
||||
from ...utils.types import DPOSample, Sample, SFTSample, ToolCall
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
@@ -61,7 +62,7 @@ class DataConverterPlugin(BasePlugin):
|
||||
return super().__call__(raw_sample)
|
||||
|
||||
|
||||
@DataConverterPlugin("alpaca").register
|
||||
@DataConverterPlugin("alpaca").register()
|
||||
def alpaca_converter(raw_sample: AlpacaSample) -> SFTSample:
|
||||
"""Convert Alpaca sample to SFT sample.
|
||||
|
||||
@@ -98,7 +99,7 @@ def alpaca_converter(raw_sample: AlpacaSample) -> SFTSample:
|
||||
return {"messages": messages}
|
||||
|
||||
|
||||
@DataConverterPlugin("sharegpt").register
|
||||
@DataConverterPlugin("sharegpt").register()
|
||||
def sharegpt_converter(raw_sample: SharegptSample) -> SFTSample:
|
||||
"""Convert ShareGPT sample to SFT sample.
|
||||
|
||||
@@ -118,17 +119,32 @@ def sharegpt_converter(raw_sample: SharegptSample) -> SFTSample:
|
||||
"function_call": "assistant",
|
||||
}
|
||||
messages = []
|
||||
tools = raw_sample.get("tools", "")
|
||||
tools = raw_sample.get("tools")
|
||||
if tools:
|
||||
try:
|
||||
tools: list[dict[str, Any]] = json.loads(tools)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning_rank0(f"Invalid tools format: {str(tools)}")
|
||||
tools = []
|
||||
|
||||
for message in raw_sample.get("conversations", []):
|
||||
tag = message["from"]
|
||||
if tag not in tag_mapping:
|
||||
logger.warning_rank0(f"Unsupported role tag {tag} in message: {message}")
|
||||
elif tag == "function_call":
|
||||
try:
|
||||
tool_calls: ToolCall | list[ToolCall] = json.loads(message["value"])
|
||||
except json.JSONDecodeError:
|
||||
logger.warning_rank0(f"Invalid tool call format: {str(message['value'])}")
|
||||
continue
|
||||
|
||||
if not isinstance(tool_calls, list):
|
||||
tool_calls = [tool_calls]
|
||||
|
||||
messages.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [{"type": "tool_calls", "value": message["value"]}],
|
||||
"content": [{"type": "tool_call", "value": json.dumps(tool_call)} for tool_call in tool_calls],
|
||||
"loss_weight": 1.0,
|
||||
}
|
||||
)
|
||||
@@ -142,15 +158,12 @@ def sharegpt_converter(raw_sample: SharegptSample) -> SFTSample:
|
||||
)
|
||||
|
||||
if tools:
|
||||
if messages and messages[0]["role"] == "system":
|
||||
messages[0]["content"].append({"type": "tools", "value": tools})
|
||||
else:
|
||||
messages.insert(0, {"role": "system", "content": [{"type": "tools", "value": tools}], "loss_weight": 0.0})
|
||||
|
||||
return {"messages": messages}
|
||||
return {"messages": messages, "tools": json.dumps(tools)}
|
||||
else:
|
||||
return {"messages": messages}
|
||||
|
||||
|
||||
@DataConverterPlugin("pair").register
|
||||
@DataConverterPlugin("pair").register()
|
||||
def pair_converter(raw_sample: PairSample) -> DPOSample:
|
||||
"""Convert Pair sample to DPO sample.
|
||||
|
||||
|
||||
@@ -49,7 +49,7 @@ def _get_builder_name(path: str) -> Literal["arrow", "csv", "json", "parquet", "
|
||||
raise ValueError(f"Unknown dataset filetype: {filetype}.")
|
||||
|
||||
|
||||
@DataLoaderPlugin("local").register
|
||||
@DataLoaderPlugin("local").register()
|
||||
def load_data_from_file(filepath: str, split: str, streaming: bool) -> HFDataset:
|
||||
if os.path.isdir(filepath):
|
||||
filetype = _get_builder_name(os.listdir(filepath)[0])
|
||||
@@ -66,49 +66,43 @@ def load_data_from_file(filepath: str, split: str, streaming: bool) -> HFDataset
|
||||
return dataset
|
||||
|
||||
|
||||
class DataIndexPlugin(BasePlugin):
|
||||
"""Plugin for adjusting dataset index."""
|
||||
def adjust_data_index(
|
||||
data_index: list[tuple[str, int]], size: int | None, weight: float | None
|
||||
) -> list[tuple[str, int]]:
|
||||
"""Adjust dataset index by size and weight.
|
||||
|
||||
def adjust_data_index(
|
||||
self, data_index: list[tuple[str, int]], size: int | None, weight: float | None
|
||||
) -> list[tuple[str, int]]:
|
||||
"""Adjust dataset index by size and weight.
|
||||
Args:
|
||||
data_index (list[tuple[str, int]]): List of (dataset_name, sample_index).
|
||||
size (Optional[int]): Desired dataset size.
|
||||
weight (Optional[float]): Desired dataset weight.
|
||||
|
||||
Args:
|
||||
data_index (list[tuple[str, int]]): List of (dataset_name, sample_index).
|
||||
size (Optional[int]): Desired dataset size.
|
||||
weight (Optional[float]): Desired dataset weight.
|
||||
Returns:
|
||||
list[tuple[str, int]]: Adjusted dataset index.
|
||||
"""
|
||||
if size is not None:
|
||||
data_index = random.choices(data_index, k=size)
|
||||
|
||||
Returns:
|
||||
list[tuple[str, int]]: Adjusted dataset index.
|
||||
"""
|
||||
if size is not None:
|
||||
data_index = random.choices(data_index, k=size)
|
||||
if weight is not None:
|
||||
data_index = random.choices(data_index, k=int(len(data_index) * weight))
|
||||
|
||||
if weight is not None:
|
||||
data_index = random.choices(data_index, k=int(len(data_index) * weight))
|
||||
|
||||
return data_index
|
||||
return data_index
|
||||
|
||||
|
||||
class DataSelectorPlugin(BasePlugin):
|
||||
"""Plugin for selecting dataset samples."""
|
||||
def select_data_sample(
|
||||
data_index: list[tuple[str, int]], index: slice | list[int] | Any
|
||||
) -> tuple[str, int] | list[tuple[str, int]]:
|
||||
"""Select dataset samples.
|
||||
|
||||
def select(
|
||||
self, data_index: list[tuple[str, int]], index: slice | list[int] | Any
|
||||
) -> tuple[str, int] | list[tuple[str, int]]:
|
||||
"""Select dataset samples.
|
||||
Args:
|
||||
data_index (list[tuple[str, int]]): List of (dataset_name, sample_index).
|
||||
index (Union[slice, list[int], Any]): Index of dataset samples.
|
||||
|
||||
Args:
|
||||
data_index (list[tuple[str, int]]): List of (dataset_name, sample_index).
|
||||
index (Union[slice, list[int], Any]): Index of dataset samples.
|
||||
|
||||
Returns:
|
||||
Union[tuple[str, int], list[tuple[str, int]]]: Selected dataset samples.
|
||||
"""
|
||||
if isinstance(index, slice):
|
||||
return [data_index[i] for i in range(*index.indices(len(data_index)))]
|
||||
elif isinstance(index, list):
|
||||
return [data_index[i] for i in index]
|
||||
else:
|
||||
raise ValueError(f"Invalid index type {type(index)}.")
|
||||
Returns:
|
||||
Union[tuple[str, int], list[tuple[str, int]]]: Selected dataset samples.
|
||||
"""
|
||||
if isinstance(index, slice):
|
||||
return [data_index[i] for i in range(*index.indices(len(data_index)))]
|
||||
elif isinstance(index, list):
|
||||
return [data_index[i] for i in index]
|
||||
else:
|
||||
raise ValueError(f"Invalid index type {type(index)}.")
|
||||
|
||||
@@ -1,133 +0,0 @@
|
||||
# Copyright 2025 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class Template:
|
||||
user_template: str
|
||||
assistant_template: str
|
||||
system_template: str
|
||||
|
||||
def render_message(self, message: dict[str, str]) -> str:
|
||||
return self.user_template.format(**message)
|
||||
|
||||
|
||||
@dataclass
|
||||
class QwenTemplate:
|
||||
message_template: str = "<|im_start|>{role}\n{content}<|im_end|>\n" # FIXME if role: tool
|
||||
thinking_template: str = "<think>\n{content}\n</think>\n\n"
|
||||
|
||||
def _extract_content(self, content_data: str | list[dict[str, str]]) -> str:
|
||||
if isinstance(content_data, str):
|
||||
return content_data.strip()
|
||||
|
||||
if isinstance(content_data, list):
|
||||
parts = []
|
||||
for item in content_data:
|
||||
if item.get("type") == "text":
|
||||
parts.append(item.get("value", ""))
|
||||
elif item.get("type") == "image_url":
|
||||
pass
|
||||
return "\n".join(parts).strip()
|
||||
|
||||
return ""
|
||||
|
||||
def render_message(self, message: dict[str, str | list[dict[str, str]]]) -> str:
|
||||
role = message["role"]
|
||||
content = self._extract_content(message.get("content", ""))
|
||||
|
||||
if role == "assistant":
|
||||
reasoning_content = message.get("reasoning_content", "")
|
||||
if reasoning_content:
|
||||
reasoning_content = self.thinking_template.format(content=str(reasoning_content).strip())
|
||||
return self.message_template.format(role="assistant", content=reasoning_content + content)
|
||||
else:
|
||||
return self.message_template.format(role=role, content=content)
|
||||
|
||||
def encode_messages(self, tokenizer, messages: list[dict[str, str]], max_seq_len: int = 8192) -> any:
|
||||
"""Encode one message."""
|
||||
input_ids, attention_mask, labels = [], [], []
|
||||
for message in messages:
|
||||
content_str = self.render_message(message)
|
||||
content_ids = tokenizer.encode(content_str, add_special_tokens=False)
|
||||
input_ids += content_ids
|
||||
attention_mask += [1] * len(content_ids)
|
||||
|
||||
if hasattr(message, "loss_weight"):
|
||||
loss_weight = message["loss_weight"]
|
||||
else:
|
||||
loss_weight = 1 if message["role"] == "assistant" else 0
|
||||
if loss_weight == 1:
|
||||
labels += content_ids
|
||||
else:
|
||||
labels += [-100] * len(content_ids)
|
||||
model_inputs = {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}
|
||||
model_inputs.update({"position_ids": list(range(len(input_ids)))})
|
||||
model_inputs = {k: v[-max_seq_len:] for k, v in model_inputs.items()}
|
||||
return model_inputs
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
def to_qwen3_messages(template: QwenTemplate, messages: list[dict]):
|
||||
out = []
|
||||
for m in messages:
|
||||
role = m["role"]
|
||||
content = template._extract_content(m.get("content", ""))
|
||||
if role == "assistant":
|
||||
reasoning = (m.get("reasoning_content") or "").strip()
|
||||
if reasoning:
|
||||
content = template.thinking_template.format(content=reasoning) + content
|
||||
out.append({"role": role, "content": content})
|
||||
return out
|
||||
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
tok = AutoTokenizer.from_pretrained(
|
||||
"Qwen/Qwen3-30B-A3B-Thinking-2507",
|
||||
trust_remote_code=True,
|
||||
)
|
||||
|
||||
test_messages = [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{"type": "text", "text": "1+1等于几?"}, {"type": "text", "text": "2+2等于几?"}],
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"reasoning_content": "这是一个简单的数学问题。1加1的结果是2。",
|
||||
"content": [{"type": "text", "text": "1+1=2"}, {"type": "text", "text": "2+2=4"}],
|
||||
},
|
||||
]
|
||||
|
||||
template = QwenTemplate()
|
||||
rendered_custom = "".join([template.render_message(m) for m in test_messages])
|
||||
|
||||
qwen3_messages = to_qwen3_messages(template, test_messages)
|
||||
rendered_hf = tok.apply_chat_template(qwen3_messages, tokenize=False, add_generation_prompt=False)
|
||||
|
||||
print("==== custom ====")
|
||||
print(rendered_custom)
|
||||
print("==== hf ====")
|
||||
print(rendered_hf)
|
||||
|
||||
assert rendered_custom.strip() == rendered_hf.strip(), "Rendered text mismatch"
|
||||
|
||||
ids_custom = tok.encode(rendered_custom, add_special_tokens=False)
|
||||
ids_hf = tok.apply_chat_template(qwen3_messages, tokenize=True, add_generation_prompt=False)
|
||||
assert ids_custom == ids_hf, f"Token ids mismatch: custom={len(ids_custom)} hf={len(ids_hf)}"
|
||||
@@ -0,0 +1,43 @@
|
||||
# Copyright 2025 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import torch
|
||||
|
||||
from ...accelerator.helper import DeviceType
|
||||
from ...accelerator.interface import DistributedInterface
|
||||
from ...utils.plugin import BasePlugin
|
||||
|
||||
|
||||
class InitPlugin(BasePlugin):
|
||||
def __call__(self) -> torch.device:
|
||||
return super().__call__()
|
||||
|
||||
|
||||
@InitPlugin("init_on_meta").register()
|
||||
def init_on_meta() -> torch.device:
|
||||
return torch.device(DeviceType.META.value)
|
||||
|
||||
|
||||
@InitPlugin("init_on_rank0").register()
|
||||
def init_on_rank0() -> torch.device:
|
||||
if DistributedInterface().get_rank() == 0:
|
||||
return torch.device(DeviceType.CPU.value)
|
||||
else:
|
||||
return torch.device(DeviceType.META.value)
|
||||
|
||||
|
||||
@InitPlugin("init_on_default").register()
|
||||
def init_on_default() -> torch.device:
|
||||
return DistributedInterface().current_device
|
||||
|
||||
@@ -38,17 +38,17 @@ class BaseKernel(ABC):
|
||||
|
||||
@classmethod
|
||||
def get_kernel_id(cls) -> str:
|
||||
r"""Returns the unique identifier for the kernel."""
|
||||
"""Returns the unique identifier for the kernel."""
|
||||
return cls._kernel_id
|
||||
|
||||
@classmethod
|
||||
def get_device(cls) -> str:
|
||||
r"""Returns the device type associated with the kernel (e.g., "cuda", "npu", "cpu")."""
|
||||
"""Returns the device type associated with the kernel (e.g., "cuda", "npu", "cpu")."""
|
||||
return cls._device
|
||||
|
||||
@classmethod
|
||||
def check_deps(cls) -> bool:
|
||||
r"""Checks if the required dependencies for the kernel are available.
|
||||
"""Checks if the required dependencies for the kernel are available.
|
||||
|
||||
Returns:
|
||||
bool: ``True`` if dependencies are met, ``False`` otherwise.
|
||||
@@ -65,7 +65,7 @@ class BaseKernel(ABC):
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def apply(cls, **kwargs) -> HFModel:
|
||||
r"""Applies the kernel optimization to the model.
|
||||
"""Applies the kernel optimization to the model.
|
||||
|
||||
Args:
|
||||
**kwargs: Arbitrary keyword arguments, usually containing the model instance and the kernel configuration.
|
||||
|
||||
@@ -33,7 +33,7 @@ logger = get_logger(__name__)
|
||||
|
||||
|
||||
def scan_all_kernels():
|
||||
r"""Scan all kernels in the ``ops`` directory.
|
||||
"""Scan all kernels in the ``ops`` directory.
|
||||
|
||||
Scans the ``ops`` directory for all ``.py`` files and attempts to import them.
|
||||
Importing triggers the :func:`~registry.register_kernel` decorator, which automatically registers the kernels.
|
||||
@@ -77,7 +77,7 @@ default_kernels = scan_all_kernels()
|
||||
|
||||
|
||||
def get_default_kernels():
|
||||
r"""Get a list of default registered kernel IDs.
|
||||
"""Get a list of default registered kernel IDs.
|
||||
|
||||
Returns:
|
||||
list[str]: List of kernel IDs.
|
||||
@@ -86,7 +86,7 @@ def get_default_kernels():
|
||||
|
||||
|
||||
def apply_kernel(kernel_id: str, **kwargs):
|
||||
r"""Applies a specific kernel to the model.
|
||||
"""Applies a specific kernel to the model.
|
||||
|
||||
Args:
|
||||
kernel_id (str): The ID of the kernel to apply.
|
||||
@@ -99,18 +99,19 @@ def apply_kernel(kernel_id: str, **kwargs):
|
||||
kernel = default_kernels.get(kernel_id)
|
||||
if kernel is None:
|
||||
raise ValueError(f"Kernel {kernel_id} not found")
|
||||
|
||||
kernel.apply(**kwargs)
|
||||
|
||||
|
||||
class KernelPlugin(BasePlugin):
|
||||
r"""Plugin for managing kernel optimizations."""
|
||||
"""Plugin for managing kernel optimizations."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@KernelPlugin("auto").register
|
||||
@KernelPlugin("auto").register()
|
||||
def apply_default_kernels(**kwargs):
|
||||
r"""Applies all default registered kernels to the model.
|
||||
"""Applies all default registered kernels to the model.
|
||||
|
||||
Args:
|
||||
**kwargs: Keyword arguments passed to the kernel application function.
|
||||
@@ -125,8 +126,11 @@ def apply_default_kernels(**kwargs):
|
||||
use_kernels = default_kernels.keys()
|
||||
else:
|
||||
use_kernels = kwargs.get("include_kernels").split(",") # "kernel_id1,kernel_id2,kernel_id3"
|
||||
|
||||
for kernel in use_kernels:
|
||||
if kernel not in default_kernels:
|
||||
raise ValueError(f"Kernel {kernel} not found")
|
||||
|
||||
apply_kernel(kernel, **kwargs)
|
||||
|
||||
return kwargs.get("model")
|
||||
|
||||
@@ -40,11 +40,11 @@ from ...registry import register_kernel
|
||||
|
||||
|
||||
class GmmFunction(torch.autograd.Function):
|
||||
r"""Custom autograd function for NPU Grouped Matrix Multiplication (GMM)."""
|
||||
"""Custom autograd function for NPU Grouped Matrix Multiplication (GMM)."""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, x, weight, group_list):
|
||||
r"""Performs the forward pass of Grouped Matrix Multiplication.
|
||||
"""Performs the forward pass of Grouped Matrix Multiplication.
|
||||
|
||||
Args:
|
||||
ctx: Context object to save tensors for backward pass.
|
||||
@@ -65,7 +65,7 @@ class GmmFunction(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
r"""Performs the backward pass of Grouped Matrix Multiplication.
|
||||
"""Performs the backward pass of Grouped Matrix Multiplication.
|
||||
|
||||
Args:
|
||||
ctx: Context object containing saved tensors.
|
||||
@@ -94,11 +94,11 @@ class GmmFunction(torch.autograd.Function):
|
||||
|
||||
|
||||
class HybridGmmFunction(torch.autograd.Function):
|
||||
r"""Custom autograd function for Hybrid Grouped Matrix Multiplication on NPU."""
|
||||
"""Custom autograd function for Hybrid Grouped Matrix Multiplication on NPU."""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, num_experts, *args):
|
||||
r"""Performs the forward pass of Hybrid GMM.
|
||||
"""Performs the forward pass of Hybrid GMM.
|
||||
|
||||
Args:
|
||||
ctx: Context object to save tensors.
|
||||
@@ -124,7 +124,7 @@ class HybridGmmFunction(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, *grad_outputs):
|
||||
r"""Performs the backward pass of Hybrid GMM.
|
||||
"""Performs the backward pass of Hybrid GMM.
|
||||
|
||||
Args:
|
||||
ctx: Context object containing saved tensors.
|
||||
@@ -176,13 +176,13 @@ class HybridGmmFunction(torch.autograd.Function):
|
||||
|
||||
|
||||
class NpuMoeFused:
|
||||
r"""Container for NPU fused MoE forward functions."""
|
||||
"""Container for NPU fused MoE forward functions."""
|
||||
|
||||
@staticmethod
|
||||
def npu_moe_experts_forward(
|
||||
self, hidden_states: torch.Tensor, routing_weights: torch.Tensor, router_indices: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
r"""Forward pass for MoE experts using NPU fused operations.
|
||||
"""Forward pass for MoE experts using NPU fused operations.
|
||||
|
||||
Args:
|
||||
self: The MoE layer instance.
|
||||
@@ -230,11 +230,11 @@ class NpuMoeFused:
|
||||
|
||||
|
||||
class Qwen3NpuMoeFused:
|
||||
r"""Container for Qwen3 NPU fused MoE forward functions."""
|
||||
"""Container for Qwen3 NPU fused MoE forward functions."""
|
||||
|
||||
@staticmethod
|
||||
def qwen3moe_sparse_moe_block_forward(self, hidden_states: torch.Tensor):
|
||||
r"""Forward pass for Qwen3 sparse MoE block using NPU fused operations.
|
||||
"""Forward pass for Qwen3 sparse MoE block using NPU fused operations.
|
||||
|
||||
Args:
|
||||
self: The Qwen3 MoE block instance.
|
||||
@@ -298,14 +298,14 @@ if not is_transformers_version_greater_than("5.0.0"):
|
||||
|
||||
@register_kernel
|
||||
class NpuFusedMoEKernel(BaseKernel):
|
||||
r"""NPU Fused MoE Kernel implementation."""
|
||||
"""NPU Fused MoE Kernel implementation."""
|
||||
|
||||
_kernel_id = "npu_fused_moe"
|
||||
_device = DeviceType.NPU
|
||||
|
||||
@classmethod
|
||||
def apply(cls, **kwargs) -> HFModel:
|
||||
r"""Applies the NPU fused MoE kernel to the model.
|
||||
"""Applies the NPU fused MoE kernel to the model.
|
||||
|
||||
Args:
|
||||
**kwargs: Keyword arguments containing the model.
|
||||
@@ -333,6 +333,7 @@ class NpuFusedMoEKernel(BaseKernel):
|
||||
|
||||
if target_moe_mapping is None:
|
||||
return model
|
||||
|
||||
for module in model.modules():
|
||||
class_name = module.__class__.__name__
|
||||
if class_name in target_moe_mapping:
|
||||
|
||||
@@ -38,7 +38,7 @@ except ImportError:
|
||||
|
||||
|
||||
def npu_swiglu_forward(self, hidden_state):
|
||||
r"""SwiGLU forward pass for NPU.
|
||||
"""SwiGLU forward pass for NPU.
|
||||
|
||||
Args:
|
||||
self: The MLP layer instance.
|
||||
@@ -53,7 +53,7 @@ def npu_swiglu_forward(self, hidden_state):
|
||||
|
||||
|
||||
def _npu_swiglu_glm4_forward(self, hidden_states):
|
||||
r"""SwiGLU forward pass for GLM4 on NPU.
|
||||
"""SwiGLU forward pass for GLM4 on NPU.
|
||||
|
||||
Args:
|
||||
self: The GLM4 MLP layer instance.
|
||||
@@ -68,7 +68,7 @@ def _npu_swiglu_glm4_forward(self, hidden_states):
|
||||
|
||||
|
||||
def _npu_swiglu_gemma3ntext_forward(self, hidden_states):
|
||||
r"""SwiGLU forward pass for Gemma3nText on NPU.
|
||||
"""SwiGLU forward pass for Gemma3nText on NPU.
|
||||
|
||||
Args:
|
||||
self: The Gemma3nText MLP layer instance.
|
||||
@@ -88,7 +88,7 @@ def _npu_swiglu_gemma3ntext_forward(self, hidden_states):
|
||||
|
||||
@register_kernel
|
||||
class NpuSwiGluKernel(BaseKernel):
|
||||
r"""NPU Kernel for fused SwiGLU activation."""
|
||||
"""NPU Kernel for fused SwiGLU activation."""
|
||||
|
||||
# just support apply to the following module layers
|
||||
expect_modules = frozenset(
|
||||
@@ -126,7 +126,7 @@ class NpuSwiGluKernel(BaseKernel):
|
||||
|
||||
@classmethod
|
||||
def apply(cls, **kwargs) -> "HFModel":
|
||||
r"""Applies the NPU fused SwiGLU kernel to the model.
|
||||
"""Applies the NPU fused SwiGLU kernel to the model.
|
||||
|
||||
Args:
|
||||
**kwargs: Keyword arguments containing the model.
|
||||
|
||||
@@ -30,7 +30,7 @@ from ...registry import register_kernel
|
||||
|
||||
|
||||
def npu_rms_norm_forward(self, hidden_states):
|
||||
r"""NPU forward implementation for RMSNorm.
|
||||
"""NPU forward implementation for RMSNorm.
|
||||
|
||||
Args:
|
||||
self: RMSNorm module instance with `weight` and `variance_epsilon`.
|
||||
@@ -46,14 +46,14 @@ def npu_rms_norm_forward(self, hidden_states):
|
||||
|
||||
@register_kernel
|
||||
class NpuRMSNormKernel(BaseKernel):
|
||||
r"""NPU kernel wrapper for RMSNorm that applies the replacement within a model."""
|
||||
"""NPU kernel wrapper for RMSNorm that applies the replacement within a model."""
|
||||
|
||||
_kernel_id = "npu_fused_rmsnorm"
|
||||
_device = DeviceType.NPU
|
||||
|
||||
@classmethod
|
||||
def apply(cls, **kwargs) -> "HFModel":
|
||||
r"""Iterate the model and apply NPU-optimized forward to matched RMSNorm modules.
|
||||
"""Iterate the model and apply NPU-optimized forward to matched RMSNorm modules.
|
||||
|
||||
Key points:
|
||||
- Match modules whose class name contains "RMSNorm" (case-insensitive).
|
||||
@@ -78,6 +78,7 @@ class NpuRMSNormKernel(BaseKernel):
|
||||
|
||||
if not cls.check_deps():
|
||||
raise RuntimeError(f"torch_npu is not available but {cls.__name__} was called.")
|
||||
|
||||
rms_norm_pattern = re.compile("RMSNorm", re.IGNORECASE)
|
||||
|
||||
for name, module in model.named_modules():
|
||||
|
||||
@@ -40,7 +40,7 @@ except ImportError:
|
||||
|
||||
|
||||
def _apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
|
||||
r"""Applies Rotary Position Embedding to the query and key tensors using NPU optimization.
|
||||
"""Applies Rotary Position Embedding to the query and key tensors using NPU optimization.
|
||||
|
||||
Args:
|
||||
q (Tensor): Query tensor.
|
||||
@@ -61,7 +61,7 @@ def _apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
|
||||
|
||||
|
||||
def _apply_multimodal_rotary_pos_emb_qwen25_vl(q, k, cos, sin, mrope_section, unsqueeze_dim=1):
|
||||
r"""Applies Rotary Position Embedding with multimodal sections (Qwen2-VL) on NPU.
|
||||
"""Applies Rotary Position Embedding with multimodal sections (Qwen2-VL) on NPU.
|
||||
|
||||
Args:
|
||||
q (Tensor): Query tensor.
|
||||
@@ -89,14 +89,14 @@ def _apply_multimodal_rotary_pos_emb_qwen25_vl(q, k, cos, sin, mrope_section, un
|
||||
|
||||
@register_kernel
|
||||
class NpuRoPEKernel(BaseKernel):
|
||||
r"""NPU Kernel for Rotary Position Embedding."""
|
||||
"""NPU Kernel for Rotary Position Embedding."""
|
||||
|
||||
_kernel_id = "npu_fused_rope"
|
||||
_device = DeviceType.NPU
|
||||
|
||||
@classmethod
|
||||
def apply(cls, **kwargs) -> "HFModel":
|
||||
r"""Apply RoPE acceleration by monkey-patching `apply_rotary_pos_emb`.
|
||||
"""Apply RoPE acceleration by monkey-patching `apply_rotary_pos_emb`.
|
||||
|
||||
This function iterates through the model's modules to find attention layers,
|
||||
identifies the module where they are defined, and replaces the original
|
||||
@@ -115,9 +115,11 @@ class NpuRoPEKernel(BaseKernel):
|
||||
"""
|
||||
if not cls.check_deps():
|
||||
raise RuntimeError(f"torch_npu is not available but {cls.__name__} was called.")
|
||||
|
||||
model = kwargs.get("model", None)
|
||||
if model is None:
|
||||
raise ValueError(f"HFModel instance is required for {cls.__name__}.")
|
||||
|
||||
_modules = set()
|
||||
for module in model.modules():
|
||||
if "Attention" in module.__class__.__name__:
|
||||
@@ -143,4 +145,5 @@ class NpuRoPEKernel(BaseKernel):
|
||||
_modules.add(module_name)
|
||||
except Exception as e:
|
||||
logger.warning_rank0_once(f"Failed to apply RoPE kernel to module {module_name}: {e}")
|
||||
|
||||
return model
|
||||
|
||||
@@ -30,7 +30,7 @@ __all__ = ["Registry", "register_kernel"]
|
||||
|
||||
|
||||
class Registry:
|
||||
r"""Registry for managing kernel implementations.
|
||||
"""Registry for managing kernel implementations.
|
||||
|
||||
Storage structure: ``{ "kernel_id": Class }``
|
||||
"""
|
||||
@@ -38,8 +38,8 @@ class Registry:
|
||||
_kernels: dict[str, type[BaseKernel]] = {}
|
||||
|
||||
@classmethod
|
||||
def register(cls, kernel_cls: type[BaseKernel]):
|
||||
r"""Decorator to register a kernel class.
|
||||
def register(cls, kernel_cls: type[BaseKernel]) -> type[BaseKernel] | None:
|
||||
"""Decorator to register a kernel class.
|
||||
|
||||
The class must inherit from :class:`BaseKernel` and specify ``_kernel_id`` and ``_device`` attributes.
|
||||
|
||||
@@ -47,7 +47,7 @@ class Registry:
|
||||
kernel_cls (type[BaseKernel]): The kernel class to register.
|
||||
|
||||
Returns:
|
||||
type[BaseKernel]: The registered kernel class.
|
||||
type[BaseKernel] | None: The registered kernel class if the device type matches the current accelerator
|
||||
|
||||
Raises:
|
||||
TypeError: If the class does not inherit from :class:`BaseKernel`.
|
||||
@@ -55,6 +55,7 @@ class Registry:
|
||||
"""
|
||||
if not issubclass(kernel_cls, BaseKernel):
|
||||
raise TypeError(f"Class {kernel_cls} must inherit from BaseKernel")
|
||||
|
||||
kernel_id = kernel_cls.get_kernel_id()
|
||||
device = kernel_cls.get_device()
|
||||
|
||||
@@ -73,7 +74,7 @@ class Registry:
|
||||
|
||||
@classmethod
|
||||
def get(cls, kernel_id: str) -> Optional[type[BaseKernel]]:
|
||||
r"""Retrieves a registered kernel implementation by its ID.
|
||||
"""Retrieves a registered kernel implementation by its ID.
|
||||
|
||||
Args:
|
||||
kernel_id (str): The ID of the kernel to retrieve.
|
||||
@@ -85,7 +86,7 @@ class Registry:
|
||||
|
||||
@classmethod
|
||||
def get_registered_kernels(cls) -> dict[str, type[BaseKernel]]:
|
||||
r"""Returns a dictionary of all registered kernels.
|
||||
"""Returns a dictionary of all registered kernels.
|
||||
|
||||
Returns:
|
||||
dict[str, type[BaseKernel]]: Dictionary mapping kernel IDs to kernel classes.
|
||||
|
||||
@@ -45,13 +45,13 @@ class PeftPlugin(BasePlugin):
|
||||
return super().__call__(model, config)
|
||||
|
||||
|
||||
@PeftPlugin("lora").register
|
||||
@PeftPlugin("lora").register()
|
||||
def get_lora_model(model: HFModel, config: LoraConfigDict, is_train: bool) -> PeftModel:
|
||||
peft_config = LoraConfig(**config)
|
||||
model = get_peft_model(model, peft_config)
|
||||
return model
|
||||
|
||||
|
||||
@PeftPlugin("freeze").register
|
||||
@PeftPlugin("freeze").register()
|
||||
def get_freeze_model(model: HFModel, config: FreezeConfigDict, is_train: bool) -> HFModel:
|
||||
raise NotImplementedError()
|
||||
|
||||
212
src/llamafactory/v1/plugins/model_plugins/rendering.py
Normal file
212
src/llamafactory/v1/plugins/model_plugins/rendering.py
Normal file
@@ -0,0 +1,212 @@
|
||||
# Copyright 2025 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import json
|
||||
import re
|
||||
|
||||
from ...utils.constants import IGNORE_INDEX
|
||||
from ...utils.helper import get_tokenizer
|
||||
from ...utils.plugin import BasePlugin
|
||||
from ...utils.types import Message, ModelInput, Processor, ToolCall
|
||||
|
||||
|
||||
class RenderingPlugin(BasePlugin):
|
||||
pass
|
||||
|
||||
|
||||
def _update_model_input(
|
||||
processor: Processor,
|
||||
input_ids: list[int],
|
||||
labels: list[int],
|
||||
loss_weights: list[int],
|
||||
temp_str: str,
|
||||
temp_weight: float,
|
||||
) -> str:
|
||||
"""Update model input with temporary string."""
|
||||
if not temp_str:
|
||||
return ""
|
||||
|
||||
tokenizer = get_tokenizer(processor)
|
||||
temp_ids = tokenizer.encode(temp_str, add_special_tokens=False)
|
||||
input_ids.extend(temp_ids)
|
||||
loss_weights.extend([temp_weight] * len(temp_ids))
|
||||
if temp_weight > 1e-6:
|
||||
labels.extend(temp_ids)
|
||||
else:
|
||||
labels.extend([IGNORE_INDEX] * len(temp_ids))
|
||||
|
||||
return ""
|
||||
|
||||
|
||||
@RenderingPlugin("qwen3_nothink").register("render_messages")
|
||||
def render_qwen_messages(
|
||||
processor: Processor,
|
||||
messages: list[Message],
|
||||
tools: str | None = None,
|
||||
is_generate: bool = False,
|
||||
) -> ModelInput:
|
||||
input_ids, labels, loss_weights = [], [], []
|
||||
temp_str, temp_weight = "", 0.0
|
||||
if tools:
|
||||
temp_str += "<|im_start|>system\n"
|
||||
if messages[0]["role"] == "system":
|
||||
for content in messages[0]["content"]:
|
||||
if content["type"] == "text":
|
||||
temp_str += content["value"]
|
||||
else:
|
||||
raise ValueError(f"Unsupported content type: {content['type']}")
|
||||
|
||||
temp_str += "\n\n"
|
||||
temp_weight = messages[0].get("loss_weight", 0.0)
|
||||
|
||||
temp_str += (
|
||||
"# Tools\n\nYou may call one or more functions to assist with the user query.\n\n"
|
||||
"You are provided with function signatures within <tools></tools> XML tags:\n<tools>"
|
||||
)
|
||||
try:
|
||||
tools = json.loads(tools)
|
||||
except json.JSONDecodeError:
|
||||
raise ValueError(f"Invalid tools format: {str(tools)}.")
|
||||
|
||||
if not isinstance(tools, list):
|
||||
tools = [tools]
|
||||
|
||||
for tool in tools:
|
||||
temp_str += "\n" + json.dumps(tool, ensure_ascii=False)
|
||||
|
||||
temp_str += (
|
||||
"\n</tools>\n\nFor each function call, return a json object with function name "
|
||||
'and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{"name": '
|
||||
'<function-name>, "arguments": <args-json-object>}\n</tool_call><|im_end|>\n'
|
||||
)
|
||||
elif messages[0]["role"] == "system":
|
||||
temp_str += "<|im_start|>system\n"
|
||||
for content in messages[0]["content"]:
|
||||
if content["type"] == "text":
|
||||
temp_str += content["value"]
|
||||
else:
|
||||
raise ValueError(f"Unsupported content type: {content['type']}")
|
||||
|
||||
temp_str += "<|im_end|>\n"
|
||||
temp_weight = messages[0].get("loss_weight", 0.0)
|
||||
|
||||
temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)
|
||||
|
||||
for turn_idx, message in enumerate(messages):
|
||||
if message["role"] == "user" or (message["role"] == "system" and turn_idx != 0):
|
||||
temp_str += "<|im_start|>" + message["role"] + "\n"
|
||||
for content in message["content"]:
|
||||
if content["type"] == "text":
|
||||
temp_str += content["value"]
|
||||
else:
|
||||
raise ValueError(f"Unsupported content type: {content['type']}")
|
||||
|
||||
temp_str += "<|im_end|>\n"
|
||||
temp_weight = message.get("loss_weight", 0.0)
|
||||
elif message["role"] == "assistant":
|
||||
temp_str += "<|im_start|>" + message["role"] + "\n"
|
||||
for val_idx, content in enumerate(message["content"]):
|
||||
if content["type"] == "text":
|
||||
temp_str += content["value"]
|
||||
elif content["type"] == "reasoning":
|
||||
temp_str += "<thinking>\n" + content["value"] + "\n</thinking>\n\n" # avoid using special tokens
|
||||
elif content["type"] == "tool_call":
|
||||
if val_idx != 0 and message["content"][val_idx - 1]["type"] in ["text", "tool_call"]:
|
||||
temp_str += "\n"
|
||||
|
||||
try:
|
||||
tool_call: ToolCall = json.loads(content["value"])
|
||||
except json.JSONDecodeError:
|
||||
raise ValueError(f"Invalid tool call format: {content['value']}.")
|
||||
|
||||
temp_str += (
|
||||
'<tool_call>\n{"name": "'
|
||||
+ tool_call["name"]
|
||||
+ '", "arguments": '
|
||||
+ json.dumps(tool_call["arguments"], ensure_ascii=False)
|
||||
+ "}\n</tool_call>"
|
||||
)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unsupported content type: {content['type']}")
|
||||
|
||||
temp_str += "<|im_end|>\n"
|
||||
temp_weight = message.get("loss_weight", 1.0)
|
||||
elif message["role"] == "tool":
|
||||
if turn_idx == 0 or messages[turn_idx - 1]["role"] != "tool":
|
||||
temp_str += "<|im_start|>user"
|
||||
|
||||
temp_str += "\n<tool_response>\n"
|
||||
for content in message["content"]:
|
||||
if content["type"] == "text":
|
||||
temp_str += content["value"]
|
||||
else:
|
||||
raise ValueError(f"Unsupported content type: {content['type']}")
|
||||
|
||||
temp_str += "\n</tool_response>"
|
||||
if turn_idx == len(messages) - 1 or messages[turn_idx + 1]["role"] != "tool":
|
||||
temp_str += "<|im_end|>\n"
|
||||
|
||||
temp_weight = message.get("loss_weight", 0.0)
|
||||
|
||||
temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)
|
||||
|
||||
if is_generate:
|
||||
temp_str += "<|im_start|>assistant\n"
|
||||
temp_weight = 0.0
|
||||
|
||||
temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)
|
||||
|
||||
attention_mask = [1] * len(input_ids)
|
||||
return ModelInput(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
labels=labels,
|
||||
loss_weights=loss_weights,
|
||||
)
|
||||
|
||||
|
||||
@RenderingPlugin("qwen3_nothink").register("parse_message")
|
||||
def parse_qwen_message(generated_text: str) -> Message:
|
||||
pattern = re.compile(r"<(thinking|tool_call)>\s*(.*?)\s*</\1>\s*", re.DOTALL)
|
||||
content = []
|
||||
last_end = 0
|
||||
for match in pattern.finditer(generated_text):
|
||||
start, end = match.span()
|
||||
if start > last_end:
|
||||
text = generated_text[last_end:start].strip()
|
||||
if text:
|
||||
content.append({"type": "text", "value": text})
|
||||
|
||||
tag_type = match.group(1)
|
||||
tag_value = match.group(2).strip()
|
||||
if tag_type == "thinking":
|
||||
content.append({"type": "reasoning", "value": tag_value.strip()})
|
||||
elif tag_type == "tool_call":
|
||||
try:
|
||||
json.loads(tag_value.strip())
|
||||
except json.JSONDecodeError:
|
||||
raise ValueError(f"Invalid tool call format: {tag_value.strip()}.")
|
||||
|
||||
content.append({"type": "tool_call", "value": tag_value.strip()})
|
||||
|
||||
last_end = end
|
||||
|
||||
if last_end < len(generated_text):
|
||||
text = generated_text[last_end:].strip()
|
||||
if text:
|
||||
content.append({"type": "text", "value": text})
|
||||
|
||||
return Message(role="assistant", content=content)
|
||||
125
src/llamafactory/v1/samplers/cli_sampler.py
Normal file
125
src/llamafactory/v1/samplers/cli_sampler.py
Normal file
@@ -0,0 +1,125 @@
|
||||
# Copyright 2025 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from collections.abc import Generator
|
||||
from threading import Thread
|
||||
|
||||
from ..config import InputArgument, ModelArguments, SampleArguments, SampleBackend, get_args
|
||||
from ..core.base_sampler import BaseSampler
|
||||
from ..core.data_engine import DataEngine
|
||||
from ..core.model_engine import ModelEngine
|
||||
from ..core.utils.rendering import Renderer
|
||||
from ..utils.types import HFModel, Message, Sample, TorchDataset
|
||||
|
||||
|
||||
class SyncSampler(BaseSampler):
|
||||
def __init__(
|
||||
self,
|
||||
args: SampleArguments,
|
||||
model_args: ModelArguments,
|
||||
model: HFModel,
|
||||
renderer: Renderer,
|
||||
) -> None:
|
||||
def _start_background_loop(loop: asyncio.AbstractEventLoop) -> None:
|
||||
asyncio.set_event_loop(loop)
|
||||
loop.run_forever()
|
||||
|
||||
super().__init__(args, model_args, model, renderer)
|
||||
self._loop = asyncio.new_event_loop()
|
||||
self._thread = Thread(target=_start_background_loop, args=(self._loop,), daemon=True)
|
||||
self._thread.start()
|
||||
|
||||
def generate(self, messages: list[Message], tools: str | None = None) -> Generator[str, None, None]:
|
||||
"""Generate tokens synchronously.
|
||||
|
||||
Args:
|
||||
messages: List of messages.
|
||||
tools: Tools string.
|
||||
|
||||
Yields:
|
||||
Generated tokens.
|
||||
"""
|
||||
generator = super().generate(messages, tools)
|
||||
while True:
|
||||
try:
|
||||
token = asyncio.run_coroutine_threadsafe(generator.__anext__(), self._loop).result()
|
||||
yield token
|
||||
except StopAsyncIteration:
|
||||
break
|
||||
|
||||
def batch_infer(self, dataset: TorchDataset) -> list[Sample]:
|
||||
"""Batch infer samples synchronously.
|
||||
|
||||
Args:
|
||||
dataset: Torch dataset.
|
||||
|
||||
Returns:
|
||||
List of samples.
|
||||
"""
|
||||
return asyncio.run_coroutine_threadsafe(super().batch_infer(dataset), self._loop).result()
|
||||
|
||||
|
||||
def run_chat(args: InputArgument = None):
|
||||
data_args, model_args, _, sample_args = get_args(args)
|
||||
if sample_args.sample_backend != SampleBackend.HF:
|
||||
model_args.init_plugin = {"name": "init_on_meta"}
|
||||
|
||||
model_engine = ModelEngine(model_args)
|
||||
sampler = SyncSampler(sample_args, model_args, model_engine.model, model_engine.renderer)
|
||||
if data_args.dataset is not None:
|
||||
dataset = DataEngine(data_args)
|
||||
sampler.batch_infer(dataset)
|
||||
else:
|
||||
if os.name != "nt":
|
||||
try:
|
||||
import readline # noqa: F401
|
||||
except ImportError:
|
||||
print("Install `readline` for a better experience.")
|
||||
|
||||
messages = []
|
||||
print("Welcome to the CLI application, use `clear` to remove the history, use `exit` to exit the application.")
|
||||
|
||||
while True:
|
||||
try:
|
||||
query = input("\nUser: ")
|
||||
except UnicodeDecodeError:
|
||||
print("Detected decoding error at the inputs, please set the terminal encoding to utf-8.")
|
||||
continue
|
||||
except Exception:
|
||||
raise
|
||||
|
||||
if query.strip() == "exit":
|
||||
break
|
||||
|
||||
if query.strip() == "clear":
|
||||
messages = []
|
||||
print("History has been removed.")
|
||||
continue
|
||||
|
||||
messages.append({"role": "user", "content": [{"type": "text", "value": query}]})
|
||||
print("Assistant: ", end="", flush=True)
|
||||
|
||||
response = ""
|
||||
for new_text in sampler.generate(messages):
|
||||
print(new_text, end="", flush=True)
|
||||
response += new_text
|
||||
|
||||
print()
|
||||
messages.append(model_engine.renderer.parse_message(response))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_chat()
|
||||
@@ -17,7 +17,7 @@ from ..accelerator.interface import DistributedInterface
|
||||
from ..config.arg_parser import get_args
|
||||
from ..core.base_trainer import BaseTrainer
|
||||
from ..core.data_engine import DataEngine
|
||||
from ..core.model_loader import ModelLoader
|
||||
from ..core.model_engine import ModelEngine
|
||||
|
||||
|
||||
class SFTTrainer(BaseTrainer):
|
||||
@@ -28,11 +28,11 @@ def run_sft(user_args):
|
||||
model_args, data_args, training_args, _ = get_args(user_args)
|
||||
DistributedInterface(training_args.dist_config)
|
||||
data_engine = DataEngine(data_args)
|
||||
model_loader = ModelLoader(model_args)
|
||||
model_engine = ModelEngine(model_args)
|
||||
trainer = SFTTrainer(
|
||||
args=training_args,
|
||||
model=model_loader.model,
|
||||
processor=model_loader.processor,
|
||||
model=model_engine.model,
|
||||
processor=model_engine.processor,
|
||||
dataset=data_engine,
|
||||
)
|
||||
trainer.fit()
|
||||
|
||||
@@ -11,3 +11,5 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
IGNORE_INDEX = -100
|
||||
|
||||
29
src/llamafactory/v1/utils/helper.py
Normal file
29
src/llamafactory/v1/utils/helper.py
Normal file
@@ -0,0 +1,29 @@
|
||||
# Copyright 2025 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from transformers import PreTrainedTokenizer
|
||||
|
||||
from .types import Processor
|
||||
|
||||
|
||||
def get_tokenizer(processor: Processor) -> PreTrainedTokenizer:
|
||||
"""Get tokenizer from processor.
|
||||
|
||||
Args:
|
||||
processor: Processor.
|
||||
|
||||
Returns:
|
||||
Tokenizer.
|
||||
"""
|
||||
return processor.tokenizer if hasattr(processor, "tokenizer") else processor
|
||||
@@ -54,7 +54,7 @@ def _get_default_logging_level() -> "logging._Level":
|
||||
|
||||
|
||||
def _get_library_name() -> str:
|
||||
return __name__.split(".")[0]
|
||||
return ".".join(__name__.split(".")[:2]) # llamafactory.v1
|
||||
|
||||
|
||||
def _get_library_root_logger() -> "_Logger":
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from collections import defaultdict
|
||||
from collections.abc import Callable
|
||||
|
||||
from . import logging
|
||||
@@ -27,7 +28,7 @@ class BasePlugin:
|
||||
A plugin is a callable object that can be registered and called by name.
|
||||
"""
|
||||
|
||||
_registry: dict[str, Callable] = {}
|
||||
_registry: dict[str, dict[str, Callable]] = defaultdict(dict)
|
||||
|
||||
def __init__(self, name: str | None = None):
|
||||
"""Initialize the plugin with a name.
|
||||
@@ -37,8 +38,7 @@ class BasePlugin:
|
||||
"""
|
||||
self.name = name
|
||||
|
||||
@property
|
||||
def register(self):
|
||||
def register(self, method_name: str = "__call__"):
|
||||
"""Decorator to register a function as a plugin.
|
||||
|
||||
Example usage:
|
||||
@@ -46,16 +46,21 @@ class BasePlugin:
|
||||
@PrintPlugin("hello").register()
|
||||
def print_hello():
|
||||
print("Hello world!")
|
||||
|
||||
|
||||
@PrintPlugin("hello").register("again")
|
||||
def print_hello_again():
|
||||
print("Hello world! Again.")
|
||||
```
|
||||
"""
|
||||
if self.name is None:
|
||||
raise ValueError("Plugin name is not specified.")
|
||||
raise ValueError("Plugin name should be specified.")
|
||||
|
||||
if self.name in self._registry:
|
||||
logger.warning_rank0_once(f"Plugin {self.name} is already registered.")
|
||||
if method_name in self._registry[self.name]:
|
||||
logger.warning_rank0_once(f"Method {method_name} of plugin {self.name} is already registered.")
|
||||
|
||||
def decorator(func: Callable) -> Callable:
|
||||
self._registry[self.name] = func
|
||||
self._registry[self.name][method_name] = func
|
||||
return func
|
||||
|
||||
return decorator
|
||||
@@ -68,10 +73,23 @@ class BasePlugin:
|
||||
PrintPlugin("hello")()
|
||||
```
|
||||
"""
|
||||
if self.name not in self._registry:
|
||||
raise ValueError(f"Plugin {self.name} is not registered.")
|
||||
if "__call__" not in self._registry[self.name]:
|
||||
raise ValueError(f"Method __call__ of plugin {self.name} is not registered.")
|
||||
|
||||
return self._registry[self.name](*args, **kwargs)
|
||||
return self._registry[self.name]["__call__"](*args, **kwargs)
|
||||
|
||||
def __getattr__(self, method_name: str):
|
||||
"""Get the registered function with the given name.
|
||||
|
||||
Example usage:
|
||||
```python
|
||||
PrintPlugin("hello").again()
|
||||
```
|
||||
"""
|
||||
if method_name not in self._registry[self.name]:
|
||||
raise ValueError(f"Method {method_name} of plugin {self.name} is not registered.")
|
||||
|
||||
return self._registry[self.name][method_name]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@@ -82,8 +100,13 @@ if __name__ == "__main__":
|
||||
class PrintPlugin(BasePlugin):
|
||||
pass
|
||||
|
||||
@PrintPlugin("hello").register
|
||||
@PrintPlugin("hello").register()
|
||||
def print_hello():
|
||||
print("Hello world!")
|
||||
|
||||
@PrintPlugin("hello").register("again")
|
||||
def print_hello_again():
|
||||
print("Hello world! Again.")
|
||||
|
||||
PrintPlugin("hello")()
|
||||
PrintPlugin("hello").again()
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import TYPE_CHECKING, Literal, NotRequired, TypedDict, Union
|
||||
from typing import TYPE_CHECKING, Any, Literal, NotRequired, TypedDict, Union
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -84,27 +84,63 @@ class DistributedConfig(TypedDict, total=False):
|
||||
|
||||
|
||||
class Content(TypedDict):
|
||||
type: Literal["text", "reasoning", "tools", "tool_calls", "image_url"]
|
||||
type: Literal["text", "reasoning", "tool_call", "image_url"]
|
||||
"""Type of the content."""
|
||||
value: str
|
||||
"""Value of the content."""
|
||||
|
||||
|
||||
class Message(TypedDict):
|
||||
role: Literal["system", "user", "assistant", "tool"]
|
||||
"""Role of the message."""
|
||||
content: list[Content]
|
||||
loss_weight: float
|
||||
"""Content of the message."""
|
||||
loss_weight: NotRequired[float]
|
||||
"""Loss weight for this message, default to 1.0. Required in training."""
|
||||
|
||||
|
||||
class SFTSample(TypedDict):
|
||||
messages: list[Message]
|
||||
"""Messages in the sample."""
|
||||
tools: NotRequired[str]
|
||||
"""Tools for the sample in JSON string format."""
|
||||
extra_info: NotRequired[str]
|
||||
"""Extra information for the sample, e.g. kto_labels."""
|
||||
_dataset_name: NotRequired[str]
|
||||
"""Dataset name for the sample."""
|
||||
|
||||
|
||||
class DPOSample(TypedDict):
|
||||
chosen_messages: list[Message]
|
||||
"""Chosen messages in the sample."""
|
||||
rejected_messages: list[Message]
|
||||
"""Rejected messages in the sample."""
|
||||
tools: NotRequired[str]
|
||||
"""Tools for the sample in JSON string format."""
|
||||
extra_info: NotRequired[str]
|
||||
"""Extra information for the sample, e.g. kto_labels."""
|
||||
_dataset_name: NotRequired[str]
|
||||
"""Dataset name for the sample."""
|
||||
|
||||
|
||||
Sample = Union[SFTSample, DPOSample]
|
||||
|
||||
|
||||
class ToolCall(TypedDict):
|
||||
name: str
|
||||
"""Function name."""
|
||||
arguments: dict[str, Any]
|
||||
"""Function arguments."""
|
||||
|
||||
|
||||
class ModelInput(TypedDict, total=False):
|
||||
input_ids: list[int]
|
||||
"""Input ids for the model."""
|
||||
attention_mask: list[int]
|
||||
"""Attention mask for the model."""
|
||||
labels: list[int]
|
||||
"""Labels for the model."""
|
||||
loss_weights: list[float]
|
||||
"""Loss weight for each token, default to 1.0."""
|
||||
position_ids: NotRequired[list[int] | list[list[int]]]
|
||||
"""Position ids for the model (optional)."""
|
||||
|
||||
@@ -12,13 +12,13 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""LLaMA-Factory test configuration.
|
||||
"""LlamaFactory test configuration.
|
||||
|
||||
Contains shared fixtures, pytest configuration, and custom markers.
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Optional
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
@@ -73,7 +73,7 @@ def _handle_slow_tests(items: list[Item]):
|
||||
item.add_marker(skip_slow)
|
||||
|
||||
|
||||
def _get_visible_devices_env() -> Optional[str]:
|
||||
def _get_visible_devices_env() -> str | None:
|
||||
"""Return device visibility env var name."""
|
||||
if CURRENT_DEVICE == "cuda":
|
||||
return "CUDA_VISIBLE_DEVICES"
|
||||
@@ -110,11 +110,10 @@ def _handle_device_visibility(items: list[Item]):
|
||||
def pytest_collection_modifyitems(config: Config, items: list[Item]):
|
||||
"""Modify test collection based on markers and environment."""
|
||||
# Handle version compatibility (from HEAD)
|
||||
if not is_transformers_version_greater_than("4.57.0"):
|
||||
skip_bc = pytest.mark.skip(reason="Skip backward compatibility tests")
|
||||
for item in items:
|
||||
if "tests_v1" in str(item.fspath):
|
||||
item.add_marker(skip_bc)
|
||||
skip_bc = pytest.mark.skip(reason="Skip backward compatibility tests")
|
||||
for item in items:
|
||||
if "tests_v1" in str(item.fspath) and not is_transformers_version_greater_than("4.57.0"):
|
||||
item.add_marker(skip_bc)
|
||||
|
||||
_handle_slow_tests(items)
|
||||
_handle_runs_on(items)
|
||||
@@ -150,12 +149,21 @@ def _manage_distributed_env(request: FixtureRequest, monkeypatch: MonkeyPatch) -
|
||||
devices_str = ",".join(str(i) for i in range(required))
|
||||
|
||||
monkeypatch.setenv(env_key, devices_str)
|
||||
|
||||
# add project root dir to path for mp run
|
||||
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||
if project_root not in sys.path:
|
||||
sys.path.insert(0, project_root)
|
||||
|
||||
os.environ["PYTHONPATH"] = project_root + os.pathsep + os.environ.get("PYTHONPATH", "")
|
||||
|
||||
else: # non-distributed test
|
||||
if old_value:
|
||||
visible_devices = [v for v in old_value.split(",") if v != ""]
|
||||
monkeypatch.setenv(env_key, visible_devices[0] if visible_devices else "0")
|
||||
else:
|
||||
monkeypatch.setenv(env_key, "0")
|
||||
|
||||
if CURRENT_DEVICE == "cuda":
|
||||
monkeypatch.setattr(torch.cuda, "device_count", lambda: 1)
|
||||
elif CURRENT_DEVICE == "npu":
|
||||
|
||||
@@ -292,3 +292,91 @@ def test_qwen_multi_tool_extractor():
|
||||
("test_tool", """{"foo": "bar", "size": 10}"""),
|
||||
("another_tool", """{"foo": "job", "size": 2}"""),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.runs_on(["cpu", "mps"])
|
||||
def test_lfm2_function_formatter():
|
||||
formatter = FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="lfm2")
|
||||
tool_calls = json.dumps(FUNCTION)
|
||||
assert formatter.apply(content=tool_calls) == [
|
||||
"""<|tool_call_start|>[tool_name(foo="bar", size=10)]<|tool_call_end|><|im_end|>\n"""
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.runs_on(["cpu", "mps"])
|
||||
def test_lfm2_multi_function_formatter():
|
||||
formatter = FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="lfm2")
|
||||
tool_calls = json.dumps([FUNCTION] * 2)
|
||||
assert formatter.apply(content=tool_calls) == [
|
||||
"""<|tool_call_start|>[tool_name(foo="bar", size=10), tool_name(foo="bar", size=10)]<|tool_call_end|>"""
|
||||
"<|im_end|>\n"
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.runs_on(["cpu", "mps"])
|
||||
def test_lfm2_tool_formatter():
|
||||
formatter = ToolFormatter(tool_format="lfm2")
|
||||
assert formatter.apply(content=json.dumps(TOOLS)) == [
|
||||
"List of tools: <|tool_list_start|>" + json.dumps(TOOLS, ensure_ascii=False) + "<|tool_list_end|>"
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.runs_on(["cpu", "mps"])
|
||||
def test_lfm2_tool_extractor():
|
||||
formatter = ToolFormatter(tool_format="lfm2")
|
||||
result = """<|tool_call_start|>[test_tool(foo="bar", size=10)]<|tool_call_end|>"""
|
||||
assert formatter.extract(result) == [("test_tool", """{"foo": "bar", "size": 10}""")]
|
||||
|
||||
|
||||
@pytest.mark.runs_on(["cpu", "mps"])
|
||||
def test_lfm2_multi_tool_extractor():
|
||||
formatter = ToolFormatter(tool_format="lfm2")
|
||||
result = """<|tool_call_start|>[test_tool(foo="bar", size=10), another_tool(foo="job", size=2)]<|tool_call_end|>"""
|
||||
assert formatter.extract(result) == [
|
||||
("test_tool", """{"foo": "bar", "size": 10}"""),
|
||||
("another_tool", """{"foo": "job", "size": 2}"""),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.runs_on(["cpu", "mps"])
|
||||
def test_lfm2_tool_extractor_with_nested_dict():
|
||||
formatter = ToolFormatter(tool_format="lfm2")
|
||||
result = """<|tool_call_start|>[search(query="test", options={"limit": 10, "offset": 0})]<|tool_call_end|>"""
|
||||
extracted = formatter.extract(result)
|
||||
assert len(extracted) == 1
|
||||
assert extracted[0][0] == "search"
|
||||
args = json.loads(extracted[0][1])
|
||||
assert args["query"] == "test"
|
||||
assert args["options"] == {"limit": 10, "offset": 0}
|
||||
|
||||
|
||||
@pytest.mark.runs_on(["cpu", "mps"])
|
||||
def test_lfm2_tool_extractor_with_list_arg():
|
||||
formatter = ToolFormatter(tool_format="lfm2")
|
||||
result = """<|tool_call_start|>[batch_process(items=[1, 2, 3], enabled=True)]<|tool_call_end|>"""
|
||||
extracted = formatter.extract(result)
|
||||
assert len(extracted) == 1
|
||||
assert extracted[0][0] == "batch_process"
|
||||
args = json.loads(extracted[0][1])
|
||||
assert args["items"] == [1, 2, 3]
|
||||
assert args["enabled"] is True
|
||||
|
||||
|
||||
@pytest.mark.runs_on(["cpu", "mps"])
|
||||
def test_lfm2_tool_extractor_no_match():
|
||||
formatter = ToolFormatter(tool_format="lfm2")
|
||||
result = "This is a regular response without tool calls."
|
||||
extracted = formatter.extract(result)
|
||||
assert extracted == result
|
||||
|
||||
|
||||
@pytest.mark.runs_on(["cpu", "mps"])
|
||||
def test_lfm2_tool_round_trip():
|
||||
formatter = FunctionFormatter(slots=["{{content}}"], tool_format="lfm2")
|
||||
tool_formatter = ToolFormatter(tool_format="lfm2")
|
||||
original = {"name": "my_func", "arguments": {"arg1": "hello", "arg2": 42, "arg3": True}}
|
||||
formatted = formatter.apply(content=json.dumps(original))
|
||||
extracted = tool_formatter.extract(formatted[0])
|
||||
assert len(extracted) == 1
|
||||
assert extracted[0][0] == original["name"]
|
||||
assert json.loads(extracted[0][1]) == original["arguments"]
|
||||
|
||||
@@ -419,3 +419,15 @@ def test_video_llava_plugin():
|
||||
]
|
||||
check_inputs["expected_mm_inputs"] = _get_mm_inputs(tokenizer_module["processor"])
|
||||
_check_plugin(**check_inputs)
|
||||
|
||||
|
||||
@pytest.mark.runs_on(["cpu", "mps"])
|
||||
def test_lfm2_vl_plugin():
|
||||
"""Test LFM2.5-VL plugin instantiation."""
|
||||
# Test plugin can be instantiated with correct tokens
|
||||
lfm2_vl_plugin = get_mm_plugin(name="lfm2_vl", image_token="<image>")
|
||||
assert lfm2_vl_plugin is not None
|
||||
assert lfm2_vl_plugin.image_token == "<image>"
|
||||
assert lfm2_vl_plugin.video_token is None
|
||||
assert lfm2_vl_plugin.audio_token is None
|
||||
assert lfm2_vl_plugin.__class__.__name__ == "LFMVLPlugin"
|
||||
|
||||
@@ -1,2 +1,2 @@
|
||||
# change if test fails or cache is outdated
|
||||
0.9.4.105
|
||||
0.9.5.103
|
||||
|
||||
@@ -24,7 +24,6 @@ def test_get_args_from_yaml(tmp_path: pathlib.Path):
|
||||
### model
|
||||
model: "llamafactory/tiny-random-qwen2.5"
|
||||
trust_remote_code: true
|
||||
use_fast_processor: true
|
||||
model_class: "llm"
|
||||
kernel_config:
|
||||
name: "auto"
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user