11 Commits

Author SHA1 Message Date
Yaowei Zheng
95ac3f2373 [release] Bye 2025 (#9702) 2025-12-31 22:22:40 +08:00
Username_Full
000526908a [core deps] upgrade TRL to be between 0.18 and 0.24 (#9617)
Co-authored-by: Yaowei Zheng <hiyouga@buaa.edu.cn>
2025-12-31 20:54:27 +08:00
fivehaitao
c8d7e85b3e [fix] Fix prediction metrics in scripts/vllm_infer.py to match Transformers (#9701)
Co-authored-by: xuht6 <xuht6@asiainfo.com>
2025-12-31 18:30:00 +08:00
浮梦
16735b9e35 [v1] Refactor kernel plugin (#9669)
Co-authored-by: frozenleaves <frozen@Mac.local>
2025-12-31 18:26:48 +08:00
Weize Liu
4e1d69579a [data] add DLR-Web dataset for supervised fine-tuning (#9696) 2025-12-30 20:50:38 +08:00
浮梦
1857fbdd6b [ci] add cuda workflow (#9682)
Co-authored-by: frozenleaves <frozen@Mac.local>
Co-authored-by: Yaowei Zheng <hiyouga@buaa.edu.cn>
2025-12-29 20:03:00 +08:00
Kingsley
bb1ba31005 [misc] lint mca code (#9692) 2025-12-29 11:44:38 +08:00
Copilot
e97d0474fb [ci] Fix NPU device condition in docker workflow (#9688)
Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: hiyouga <16256802+hiyouga@users.noreply.github.com>
2025-12-28 20:04:59 +08:00
Yaowei Zheng
3f0c3dc84d [assets] fix installation (#9687)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2025-12-28 19:29:28 +08:00
Hertz
c107cc22d0 [model] support MiniMax-M1&M2 series (#9680)
Co-authored-by: Yaowei Zheng <hiyouga@buaa.edu.cn>
2025-12-28 19:02:05 +08:00
Yaowei Zheng
7ef1fba34a [version] fix gradio (#9685) 2025-12-28 05:00:51 +08:00
92 changed files with 1383 additions and 884 deletions

View File

@@ -72,7 +72,7 @@ jobs:
password: ${{ secrets.DOCKERHUB_TOKEN }} password: ${{ secrets.DOCKERHUB_TOKEN }}
- name: Login to Quay - name: Login to Quay
if: ${{ github.event_name != 'pull_request' && matrix.device == 'npu'}} if: ${{ github.event_name != 'pull_request' && startsWith(matrix.device, 'npu') }}
uses: docker/login-action@v3 uses: docker/login-action@v3
with: with:
registry: quay.io registry: quay.io

View File

@@ -27,23 +27,23 @@ jobs:
python: python:
- "3.11" - "3.11"
- "3.12" - "3.12"
# - "3.13" # enable after trl is upgraded - "3.13"
os: os:
- "ubuntu-latest" - "ubuntu-latest"
- "windows-latest" - "windows-latest"
- "macos-latest" - "macos-latest"
transformers: transformers:
- null - ""
include: # test backward compatibility include: # test backward compatibility
- python: "3.11"
os: "ubuntu-latest"
transformers: "4.49.0"
- python: "3.11" - python: "3.11"
os: "ubuntu-latest" os: "ubuntu-latest"
transformers: "4.51.0" transformers: "4.51.0"
- python: "3.11" - python: "3.11"
os: "ubuntu-latest" os: "ubuntu-latest"
transformers: "4.53.0" transformers: "4.53.0"
- python: "3.11"
os: "ubuntu-latest"
transformers: "4.55.0"
runs-on: ${{ matrix.os }} runs-on: ${{ matrix.os }}

88
.github/workflows/tests_cuda.yml vendored Normal file
View File

@@ -0,0 +1,88 @@
name: tests_cuda
on:
workflow_dispatch:
push:
branches:
- "main"
paths:
- "**/*.py"
- "pyproject.toml"
- "Makefile"
- ".github/workflows/*.yml"
pull_request:
branches:
- "main"
paths:
- "**/*.py"
- "pyproject.toml"
- "Makefile"
- ".github/workflows/*.yml"
jobs:
tests:
strategy:
fail-fast: false
matrix:
python:
- "3.11"
os:
- "linux-x86_64-gpu-2"
runs-on: ${{ matrix.os }}
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ matrix.os }}-${{ matrix.python }}
cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Install uv
uses: astral-sh/setup-uv@v7
with:
python-version: ${{ matrix.python }}
github-token: ${{ github.token }}
enable-cache: false
- name: Check GPU Status
run: nvidia-smi
- name: Install dependencies
run: |
uv venv
uv pip install -e ".[dev]"
- name: Cache HuggingFace models
id: hf-hub-cache
uses: actions/cache@v4
with:
path: ${{ runner.temp }}/huggingface
key: hf-cache-${{ runner.os }}-${{ hashFiles('tests/version.txt') }}
- name: Check quality
run: |
make style && make quality
env:
UV_NO_SYNC: 1
- name: Check license
run: |
make license
env:
UV_NO_SYNC: 1
- name: Check build
run: |
make build
env:
UV_NO_SYNC: 1
- name: Test with pytest
run: |
make test
env:
UV_NO_SYNC: 1
HF_HOME: ${{ runner.temp }}/huggingface
HF_HUB_OFFLINE: "${{ steps.hf-hub-cache.outputs.cache-hit == 'true' && '1' || '0' }}"

View File

@@ -49,8 +49,11 @@ jobs:
uses: actions/checkout@v4 uses: actions/checkout@v4
- name: Install uv - name: Install uv
run: | uses: astral-sh/setup-uv@v7
curl -LsSf https://astral.sh/uv/install.sh | sh with:
python-version: ${{ matrix.python }}
github-token: ${{ github.token }}
enable-cache: false
- name: Install dependencies - name: Install dependencies
run: | run: |

View File

@@ -309,6 +309,7 @@ Read technical notes:
| [MiMo](https://huggingface.co/XiaomiMiMo) | 7B/309B | mimo/mimo_v2 | | [MiMo](https://huggingface.co/XiaomiMiMo) | 7B/309B | mimo/mimo_v2 |
| [MiniCPM 1-4.1](https://huggingface.co/openbmb) | 0.5B/1B/2B/4B/8B | cpm/cpm3/cpm4 | | [MiniCPM 1-4.1](https://huggingface.co/openbmb) | 0.5B/1B/2B/4B/8B | cpm/cpm3/cpm4 |
| [MiniCPM-o-2.6/MiniCPM-V-2.6](https://huggingface.co/openbmb) | 8B | minicpm_o/minicpm_v | | [MiniCPM-o-2.6/MiniCPM-V-2.6](https://huggingface.co/openbmb) | 8B | minicpm_o/minicpm_v |
| [MiniMax-M1/MiniMax-M2](https://huggingface.co/MiniMaxAI/models) | 229B/456B | minimax1/minimax2 |
| [Ministral 3](https://huggingface.co/mistralai) | 3B/8B/14B | ministral3 | | [Ministral 3](https://huggingface.co/mistralai) | 3B/8B/14B | ministral3 |
| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral | | [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
| [OLMo](https://huggingface.co/allenai) | 1B/7B | - | | [OLMo](https://huggingface.co/allenai) | 1B/7B | - |
@@ -433,6 +434,7 @@ You also can add a custom chat template to [template.py](src/llamafactory/data/t
- [Chinese-DeepSeek-R1-Distill (zh)](https://huggingface.co/datasets/Congliu/Chinese-DeepSeek-R1-Distill-data-110k-SFT) - [Chinese-DeepSeek-R1-Distill (zh)](https://huggingface.co/datasets/Congliu/Chinese-DeepSeek-R1-Distill-data-110k-SFT)
- [LLaVA mixed (en&zh)](https://huggingface.co/datasets/BUAADreamer/llava-en-zh-300k) - [LLaVA mixed (en&zh)](https://huggingface.co/datasets/BUAADreamer/llava-en-zh-300k)
- [Pokemon-gpt4o-captions (en&zh)](https://huggingface.co/datasets/jugg1024/pokemon-gpt4o-captions) - [Pokemon-gpt4o-captions (en&zh)](https://huggingface.co/datasets/jugg1024/pokemon-gpt4o-captions)
- [DLR-Web (en)](https://huggingface.co/datasets/Attention1115/DLR-Web)
- [Open Assistant (de)](https://huggingface.co/datasets/mayflowergmbh/oasst_de) - [Open Assistant (de)](https://huggingface.co/datasets/mayflowergmbh/oasst_de)
- [Dolly 15k (de)](https://huggingface.co/datasets/mayflowergmbh/dolly-15k_de) - [Dolly 15k (de)](https://huggingface.co/datasets/mayflowergmbh/dolly-15k_de)
- [Alpaca GPT4 (de)](https://huggingface.co/datasets/mayflowergmbh/alpaca-gpt4_de) - [Alpaca GPT4 (de)](https://huggingface.co/datasets/mayflowergmbh/alpaca-gpt4_de)
@@ -514,7 +516,7 @@ huggingface-cli login
```bash ```bash
git clone --depth 1 https://github.com/hiyouga/LLaMA-Factory.git git clone --depth 1 https://github.com/hiyouga/LLaMA-Factory.git
cd LLaMA-Factory cd LLaMA-Factory
pip install -e ".[metrics]" --no-build-isolation pip install -e ".[metrics]"
``` ```
Optional dependencies available: `metrics`, `deepspeed`. Install with: `pip install -e ".[metrics,deepspeed]"` Optional dependencies available: `metrics`, `deepspeed`. Install with: `pip install -e ".[metrics,deepspeed]"`
@@ -637,7 +639,7 @@ cd transformers
pip install . pip install .
``` ```
3. Set `double_quantization: false` in the configuration. You can refer to the [example](examples/train_qlora/llama3_lora_sft_bnb_npu.yaml). 3. Set `double_quantization: false` in the configuration. You can refer to the [example](examples/train_qlora/qwen3_lora_sft_bnb_npu.yaml).
</details> </details>
@@ -652,12 +654,12 @@ You can also use **[Easy Dataset](https://github.com/ConardLi/easy-dataset)**, *
### Quickstart ### Quickstart
Use the following 3 commands to run LoRA **fine-tuning**, **inference** and **merging** of the Llama3-8B-Instruct model, respectively. Use the following 3 commands to run LoRA **fine-tuning**, **inference** and **merging** of the Qwen3-4B-Instruct model, respectively.
```bash ```bash
llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml llamafactory-cli train examples/train_lora/qwen3_lora_sft.yaml
llamafactory-cli chat examples/inference/llama3_lora_sft.yaml llamafactory-cli chat examples/inference/qwen3_lora_sft.yaml
llamafactory-cli export examples/merge_lora/llama3_lora_sft.yaml llamafactory-cli export examples/merge_lora/qwen3_lora_sft.yaml
``` ```
See [examples/README.md](examples/README.md) for advanced usage (including distributed training). See [examples/README.md](examples/README.md) for advanced usage (including distributed training).
@@ -780,7 +782,7 @@ When building the Docker image, use `-v ./hf_cache:/root/.cache/huggingface` arg
### Deploy with OpenAI-style API and vLLM ### Deploy with OpenAI-style API and vLLM
```bash ```bash
API_PORT=8000 llamafactory-cli api examples/inference/llama3.yaml infer_backend=vllm vllm_enforce_eager=true API_PORT=8000 llamafactory-cli api examples/inference/qwen3.yaml infer_backend=vllm vllm_enforce_eager=true
``` ```
> [!TIP] > [!TIP]

View File

@@ -311,6 +311,7 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
| [MiMo](https://huggingface.co/XiaomiMiMo) | 7B/309B | mimo/mimo_v2 | | [MiMo](https://huggingface.co/XiaomiMiMo) | 7B/309B | mimo/mimo_v2 |
| [MiniCPM 1-4.1](https://huggingface.co/openbmb) | 0.5B/1B/2B/4B/8B | cpm/cpm3/cpm4 | | [MiniCPM 1-4.1](https://huggingface.co/openbmb) | 0.5B/1B/2B/4B/8B | cpm/cpm3/cpm4 |
| [MiniCPM-o-2.6/MiniCPM-V-2.6](https://huggingface.co/openbmb) | 8B | minicpm_o/minicpm_v | | [MiniCPM-o-2.6/MiniCPM-V-2.6](https://huggingface.co/openbmb) | 8B | minicpm_o/minicpm_v |
| [MiniMax-M1/MiniMax-M2](https://huggingface.co/MiniMaxAI/models) | 229B/456B | minimax1/minimax2 |
| [Ministral 3](https://huggingface.co/mistralai) | 3B/8B/14B | ministral3 | | [Ministral 3](https://huggingface.co/mistralai) | 3B/8B/14B | ministral3 |
| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral | | [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
| [OLMo](https://huggingface.co/allenai) | 1B/7B | - | | [OLMo](https://huggingface.co/allenai) | 1B/7B | - |
@@ -435,6 +436,7 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
- [Chinese-DeepSeek-R1-Distill (zh)](https://huggingface.co/datasets/Congliu/Chinese-DeepSeek-R1-Distill-data-110k-SFT) - [Chinese-DeepSeek-R1-Distill (zh)](https://huggingface.co/datasets/Congliu/Chinese-DeepSeek-R1-Distill-data-110k-SFT)
- [LLaVA mixed (en&zh)](https://huggingface.co/datasets/BUAADreamer/llava-en-zh-300k) - [LLaVA mixed (en&zh)](https://huggingface.co/datasets/BUAADreamer/llava-en-zh-300k)
- [Pokemon-gpt4o-captions (en&zh)](https://huggingface.co/datasets/jugg1024/pokemon-gpt4o-captions) - [Pokemon-gpt4o-captions (en&zh)](https://huggingface.co/datasets/jugg1024/pokemon-gpt4o-captions)
- [DLR-Web (en)](https://huggingface.co/datasets/Attention1115/DLR-Web)
- [Open Assistant (de)](https://huggingface.co/datasets/mayflowergmbh/oasst_de) - [Open Assistant (de)](https://huggingface.co/datasets/mayflowergmbh/oasst_de)
- [Dolly 15k (de)](https://huggingface.co/datasets/mayflowergmbh/dolly-15k_de) - [Dolly 15k (de)](https://huggingface.co/datasets/mayflowergmbh/dolly-15k_de)
- [Alpaca GPT4 (de)](https://huggingface.co/datasets/mayflowergmbh/alpaca-gpt4_de) - [Alpaca GPT4 (de)](https://huggingface.co/datasets/mayflowergmbh/alpaca-gpt4_de)
@@ -516,7 +518,7 @@ huggingface-cli login
```bash ```bash
git clone --depth 1 https://github.com/hiyouga/LLaMA-Factory.git git clone --depth 1 https://github.com/hiyouga/LLaMA-Factory.git
cd LLaMA-Factory cd LLaMA-Factory
pip install -e ".[metrics]" --no-build-isolation pip install -e ".[metrics]"
``` ```
可选的额外依赖项:`metrics`、`deepspeed`。使用 `pip install -e ".[metrics,deepspeed]"` 安装。 可选的额外依赖项:`metrics`、`deepspeed`。使用 `pip install -e ".[metrics,deepspeed]"` 安装。
@@ -639,7 +641,7 @@ cd transformers
pip install . pip install .
``` ```
3. 在训练参数中设置 `double_quantization: false`,可参考[示例](examples/train_qlora/llama3_lora_sft_bnb_npu.yaml)。 3. 在训练参数中设置 `double_quantization: false`,可参考[示例](examples/train_qlora/qwen3_lora_sft_bnb_npu.yaml)。
</details> </details>
@@ -654,12 +656,12 @@ pip install .
### 快速开始 ### 快速开始
下面三行命令分别对 Llama3-8B-Instruct 模型进行 LoRA **微调**、**推理**和**合并**。 下面三行命令分别对 Qwen3-4B-Instruct 模型进行 LoRA **微调**、**推理**和**合并**。
```bash ```bash
llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml llamafactory-cli train examples/train_lora/qwen3_lora_sft.yaml
llamafactory-cli chat examples/inference/llama3_lora_sft.yaml llamafactory-cli chat examples/inference/qwen3_lora_sft.yaml
llamafactory-cli export examples/merge_lora/llama3_lora_sft.yaml llamafactory-cli export examples/merge_lora/qwen3_lora_sft.yaml
``` ```
高级用法请参考 [examples/README_zh.md](examples/README_zh.md)(包括多 GPU 微调)。 高级用法请参考 [examples/README_zh.md](examples/README_zh.md)(包括多 GPU 微调)。
@@ -785,7 +787,7 @@ docker exec -it llamafactory bash
### 利用 vLLM 部署 OpenAI API ### 利用 vLLM 部署 OpenAI API
```bash ```bash
API_PORT=8000 llamafactory-cli api examples/inference/llama3.yaml infer_backend=vllm vllm_enforce_eager=true API_PORT=8000 llamafactory-cli api examples/inference/qwen3.yaml infer_backend=vllm vllm_enforce_eager=true
``` ```
> [!TIP] > [!TIP]

View File

@@ -471,6 +471,14 @@
"ultrachat_de": { "ultrachat_de": {
"hf_hub_url": "mayflowergmbh/ultra-chat_de" "hf_hub_url": "mayflowergmbh/ultra-chat_de"
}, },
"dlr_web": {
"hf_hub_url": "Attention1115/DLR-Web",
"split": "full",
"columns": {
"prompt": "question",
"response": "response"
}
},
"dpo_en_demo": { "dpo_en_demo": {
"file_name": "dpo_en_demo.json", "file_name": "dpo_en_demo.json",
"ranking": true, "ranking": true,

View File

@@ -18,19 +18,19 @@ By default, LLaMA-Factory uses all visible computing devices.
Basic usage: Basic usage:
```bash ```bash
llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml llamafactory-cli train examples/train_lora/qwen3_lora_sft.yaml
``` ```
Advanced usage: Advanced usage:
```bash ```bash
CUDA_VISIBLE_DEVICES=0,1 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml \ CUDA_VISIBLE_DEVICES=0,1 llamafactory-cli train examples/train_lora/qwen3_lora_sft.yaml \
learning_rate=1e-5 \ learning_rate=1e-5 \
logging_steps=1 logging_steps=1
``` ```
```bash ```bash
bash examples/train_lora/llama3_lora_sft.sh bash examples/train_lora/qwen3_lora_sft.sh
``` ```
## Examples ## Examples
@@ -40,49 +40,43 @@ bash examples/train_lora/llama3_lora_sft.sh
#### (Continuous) Pre-Training #### (Continuous) Pre-Training
```bash ```bash
llamafactory-cli train examples/train_lora/llama3_lora_pretrain.yaml llamafactory-cli train examples/train_lora/qwen3_lora_pretrain.yaml
``` ```
#### Supervised Fine-Tuning #### Supervised Fine-Tuning
```bash ```bash
llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml llamafactory-cli train examples/train_lora/qwen3_lora_sft.yaml
``` ```
#### Multimodal Supervised Fine-Tuning #### Multimodal Supervised Fine-Tuning
```bash ```bash
llamafactory-cli train examples/train_lora/qwen2_5vl_lora_sft.yaml llamafactory-cli train examples/train_lora/qwen3vl_lora_sft.yaml
``` ```
#### DPO/ORPO/SimPO Training #### DPO/ORPO/SimPO Training
```bash ```bash
llamafactory-cli train examples/train_lora/llama3_lora_dpo.yaml llamafactory-cli train examples/train_lora/qwen3_lora_dpo.yaml
``` ```
#### Multimodal DPO/ORPO/SimPO Training #### Multimodal DPO/ORPO/SimPO Training
```bash ```bash
llamafactory-cli train examples/train_lora/qwen2_5vl_lora_dpo.yaml llamafactory-cli train examples/train_lora/qwen3vl_lora_dpo.yaml
``` ```
#### Reward Modeling #### Reward Modeling
```bash ```bash
llamafactory-cli train examples/train_lora/llama3_lora_reward.yaml llamafactory-cli train examples/train_lora/qwen3_lora_reward.yaml
```
#### PPO Training
```bash
llamafactory-cli train examples/train_lora/llama3_lora_ppo.yaml
``` ```
#### KTO Training #### KTO Training
```bash ```bash
llamafactory-cli train examples/train_lora/llama3_lora_kto.yaml llamafactory-cli train examples/train_lora/qwen3_lora_kto.yaml
``` ```
#### Preprocess Dataset #### Preprocess Dataset
@@ -90,32 +84,26 @@ llamafactory-cli train examples/train_lora/llama3_lora_kto.yaml
It is useful for large dataset, use `tokenized_path` in config to load the preprocessed dataset. It is useful for large dataset, use `tokenized_path` in config to load the preprocessed dataset.
```bash ```bash
llamafactory-cli train examples/train_lora/llama3_preprocess.yaml llamafactory-cli train examples/train_lora/qwen3_preprocess.yaml
```
#### Evaluating on MMLU/CMMLU/C-Eval Benchmarks
```bash
llamafactory-cli eval examples/train_lora/llama3_lora_eval.yaml
``` ```
#### Supervised Fine-Tuning on Multiple Nodes #### Supervised Fine-Tuning on Multiple Nodes
```bash ```bash
FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/qwen3_lora_sft.yaml
FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/qwen3_lora_sft.yaml
``` ```
#### Supervised Fine-Tuning with DeepSpeed ZeRO-3 (Weight Sharding) #### Supervised Fine-Tuning with DeepSpeed ZeRO-3 (Weight Sharding)
```bash ```bash
FORCE_TORCHRUN=1 llamafactory-cli train examples/train_lora/llama3_lora_sft_ds3.yaml FORCE_TORCHRUN=1 llamafactory-cli train examples/train_lora/qwen3_lora_sft_ds3.yaml
``` ```
#### Supervised Fine-Tuning with Ray on 4 GPUs #### Supervised Fine-Tuning with Ray on 4 GPUs
```bash ```bash
USE_RAY=1 llamafactory-cli train examples/train_lora/llama3_lora_sft_ray.yaml USE_RAY=1 llamafactory-cli train examples/train_lora/qwen3_lora_sft_ray.yaml
``` ```
### QLoRA Fine-Tuning ### QLoRA Fine-Tuning
@@ -123,13 +111,13 @@ USE_RAY=1 llamafactory-cli train examples/train_lora/llama3_lora_sft_ray.yaml
#### Supervised Fine-Tuning with 4/8-bit Bitsandbytes/HQQ/EETQ Quantization (Recommended) #### Supervised Fine-Tuning with 4/8-bit Bitsandbytes/HQQ/EETQ Quantization (Recommended)
```bash ```bash
llamafactory-cli train examples/train_qlora/llama3_lora_sft_otfq.yaml llamafactory-cli train examples/train_qlora/qwen3_lora_sft_otfq.yaml
``` ```
#### Supervised Fine-Tuning with 4-bit Bitsandbytes Quantization on Ascend NPU #### Supervised Fine-Tuning with 4-bit Bitsandbytes Quantization on Ascend NPU
```bash ```bash
llamafactory-cli train examples/train_qlora/llama3_lora_sft_bnb_npu.yaml llamafactory-cli train examples/train_qlora/qwen3_lora_sft_bnb_npu.yaml
``` ```
#### Supervised Fine-Tuning with 4/8-bit GPTQ Quantization #### Supervised Fine-Tuning with 4/8-bit GPTQ Quantization
@@ -155,14 +143,14 @@ llamafactory-cli train examples/train_qlora/llama3_lora_sft_aqlm.yaml
#### Supervised Fine-Tuning on Single Node #### Supervised Fine-Tuning on Single Node
```bash ```bash
FORCE_TORCHRUN=1 llamafactory-cli train examples/train_full/llama3_full_sft.yaml FORCE_TORCHRUN=1 llamafactory-cli train examples/train_full/qwen3_full_sft.yaml
``` ```
#### Supervised Fine-Tuning on Multiple Nodes #### Supervised Fine-Tuning on Multiple Nodes
```bash ```bash
FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/llama3_full_sft.yaml FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/qwen3_full_sft.yaml
FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/llama3_full_sft.yaml FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/qwen3_full_sft.yaml
``` ```
### Elastic and Fault-Tolerant Supervised Fine-Tuning on Multiple Nodes ### Elastic and Fault-Tolerant Supervised Fine-Tuning on Multiple Nodes
@@ -170,13 +158,13 @@ FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500
To launch an elastic job with `MAX_RESTARTS` failures retries, run the following on at least `MIN_NNODES` nodes and at most `MAX_NNODES` nodes. `RDZV_ID` should be set as a unique job id (shared by all nodes participating in the job). See also [torchrun](https://docs.pytorch.org/docs/stable/elastic/run.html). To launch an elastic job with `MAX_RESTARTS` failures retries, run the following on at least `MIN_NNODES` nodes and at most `MAX_NNODES` nodes. `RDZV_ID` should be set as a unique job id (shared by all nodes participating in the job). See also [torchrun](https://docs.pytorch.org/docs/stable/elastic/run.html).
```bash ```bash
FORCE_TORCHRUN=1 MIN_NNODES=1 MAX_NNODES=3 MAX_RESTARTS=3 RDZV_ID=llamafactory MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/llama3_full_sft.yaml FORCE_TORCHRUN=1 MIN_NNODES=1 MAX_NNODES=3 MAX_RESTARTS=3 RDZV_ID=llamafactory MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/qwen3_full_sft.yaml
``` ```
#### Multimodal Supervised Fine-Tuning #### Multimodal Supervised Fine-Tuning
```bash ```bash
FORCE_TORCHRUN=1 llamafactory-cli train examples/train_full/qwen2_5vl_full_sft.yaml FORCE_TORCHRUN=1 llamafactory-cli train examples/train_full/qwen3vl_full_sft.yaml
``` ```
### Merging LoRA Adapters and Quantization ### Merging LoRA Adapters and Quantization
@@ -186,19 +174,19 @@ FORCE_TORCHRUN=1 llamafactory-cli train examples/train_full/qwen2_5vl_full_sft.y
Note: DO NOT use quantized model or `quantization_bit` when merging LoRA adapters. Note: DO NOT use quantized model or `quantization_bit` when merging LoRA adapters.
```bash ```bash
llamafactory-cli export examples/merge_lora/llama3_lora_sft.yaml llamafactory-cli export examples/merge_lora/qwen3_lora_sft.yaml
``` ```
#### Quantizing Model using AutoGPTQ #### Quantizing Model using AutoGPTQ
```bash ```bash
llamafactory-cli export examples/merge_lora/llama3_gptq.yaml llamafactory-cli export examples/merge_lora/qwen3_gptq.yaml
``` ```
### Save Ollama modelfile ### Save Ollama modelfile
```bash ```bash
llamafactory-cli export examples/merge_lora/llama3_full_sft.yaml llamafactory-cli export examples/merge_lora/qwen3_full_sft.yaml
``` ```
### Inferring LoRA Fine-Tuned Models ### Inferring LoRA Fine-Tuned Models
@@ -206,26 +194,26 @@ llamafactory-cli export examples/merge_lora/llama3_full_sft.yaml
#### Evaluation using vLLM's Multi-GPU Inference #### Evaluation using vLLM's Multi-GPU Inference
``` ```
python scripts/vllm_infer.py --model_name_or_path meta-llama/Meta-Llama-3-8B-Instruct --template llama3 --dataset alpaca_en_demo python scripts/vllm_infer.py --model_name_or_path Qwen/Qwen3-4B-Instruct-2507 --template qwen3_nothink --dataset alpaca_en_demo
python scripts/eval_bleu_rouge.py generated_predictions.jsonl python scripts/eval_bleu_rouge.py generated_predictions.jsonl
``` ```
#### Use CLI ChatBox #### Use CLI ChatBox
```bash ```bash
llamafactory-cli chat examples/inference/llama3_lora_sft.yaml llamafactory-cli chat examples/inference/qwen3_lora_sft.yaml
``` ```
#### Use Web UI ChatBox #### Use Web UI ChatBox
```bash ```bash
llamafactory-cli webchat examples/inference/llama3_lora_sft.yaml llamafactory-cli webchat examples/inference/qwen3_lora_sft.yaml
``` ```
#### Launch OpenAI-style API #### Launch OpenAI-style API
```bash ```bash
llamafactory-cli api examples/inference/llama3_lora_sft.yaml llamafactory-cli api examples/inference/qwen3_lora_sft.yaml
``` ```
### Extras ### Extras

View File

@@ -18,19 +18,19 @@ LLaMA-Factory 默认使用所有可见的计算设备。
基础用法: 基础用法:
```bash ```bash
llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml llamafactory-cli train examples/train_lora/qwen3_lora_sft.yaml
``` ```
高级用法: 高级用法:
```bash ```bash
CUDA_VISIBLE_DEVICES=0,1 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml \ CUDA_VISIBLE_DEVICES=0,1 llamafactory-cli train examples/train_lora/qwen3_lora_sft.yaml \
learning_rate=1e-5 \ learning_rate=1e-5 \
logging_steps=1 logging_steps=1
``` ```
```bash ```bash
bash examples/train_lora/llama3_lora_sft.sh bash examples/train_lora/qwen3_lora_sft.sh
``` ```
## 示例 ## 示例
@@ -40,49 +40,43 @@ bash examples/train_lora/llama3_lora_sft.sh
#### (增量)预训练 #### (增量)预训练
```bash ```bash
llamafactory-cli train examples/train_lora/llama3_lora_pretrain.yaml llamafactory-cli train examples/train_lora/qwen3_lora_pretrain.yaml
``` ```
#### 指令监督微调 #### 指令监督微调
```bash ```bash
llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml llamafactory-cli train examples/train_lora/qwen3_lora_sft.yaml
``` ```
#### 多模态指令监督微调 #### 多模态指令监督微调
```bash ```bash
llamafactory-cli train examples/train_lora/qwen2_5vl_lora_sft.yaml llamafactory-cli train examples/train_lora/qwen3vl_lora_sft.yaml
``` ```
#### DPO/ORPO/SimPO 训练 #### DPO/ORPO/SimPO 训练
```bash ```bash
llamafactory-cli train examples/train_lora/llama3_lora_dpo.yaml llamafactory-cli train examples/train_lora/qwen3_lora_dpo.yaml
``` ```
#### 多模态 DPO/ORPO/SimPO 训练 #### 多模态 DPO/ORPO/SimPO 训练
```bash ```bash
llamafactory-cli train examples/train_lora/qwen2_5vl_lora_dpo.yaml llamafactory-cli train examples/train_lora/qwen3vl_lora_dpo.yaml
``` ```
#### 奖励模型训练 #### 奖励模型训练
```bash ```bash
llamafactory-cli train examples/train_lora/llama3_lora_reward.yaml llamafactory-cli train examples/train_lora/qwen3_lora_reward.yaml
```
#### PPO 训练
```bash
llamafactory-cli train examples/train_lora/llama3_lora_ppo.yaml
``` ```
#### KTO 训练 #### KTO 训练
```bash ```bash
llamafactory-cli train examples/train_lora/llama3_lora_kto.yaml llamafactory-cli train examples/train_lora/qwen3_lora_kto.yaml
``` ```
#### 预处理数据集 #### 预处理数据集
@@ -90,20 +84,14 @@ llamafactory-cli train examples/train_lora/llama3_lora_kto.yaml
对于大数据集有帮助,在配置中使用 `tokenized_path` 以加载预处理后的数据集。 对于大数据集有帮助,在配置中使用 `tokenized_path` 以加载预处理后的数据集。
```bash ```bash
llamafactory-cli train examples/train_lora/llama3_preprocess.yaml llamafactory-cli train examples/train_lora/qwen3_preprocess.yaml
```
#### 在 MMLU/CMMLU/C-Eval 上评估
```bash
llamafactory-cli eval examples/train_lora/llama3_lora_eval.yaml
``` ```
#### 多机指令监督微调 #### 多机指令监督微调
```bash ```bash
FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/qwen3_lora_sft.yaml
FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/qwen3_lora_sft.yaml
``` ```
### 支持弹性和容错的多机指令监督微调 ### 支持弹性和容错的多机指令监督微调
@@ -111,19 +99,19 @@ FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500
要启动一个支持弹性节点和容错的多机指令微调,在每个节点上执行以下命令。弹性节点数量范围为 `MIN_NNODES:MAX_NNODES`,每个节点最多允许因为错误重启 `MAX_RESTARTS` 次。`RDZV_ID` 应设置为一个唯一的作业 ID由参与该作业的所有节点共享。更多信息可以参考官方文档 [torchrun](https://docs.pytorch.org/docs/stable/elastic/run.html)。 要启动一个支持弹性节点和容错的多机指令微调,在每个节点上执行以下命令。弹性节点数量范围为 `MIN_NNODES:MAX_NNODES`,每个节点最多允许因为错误重启 `MAX_RESTARTS` 次。`RDZV_ID` 应设置为一个唯一的作业 ID由参与该作业的所有节点共享。更多信息可以参考官方文档 [torchrun](https://docs.pytorch.org/docs/stable/elastic/run.html)。
```bash ```bash
FORCE_TORCHRUN=1 MIN_NNODES=1 MAX_NNODES=3 MAX_RESTARTS=3 RDZV_ID=llamafactory MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/llama3_full_sft.yaml FORCE_TORCHRUN=1 MIN_NNODES=1 MAX_NNODES=3 MAX_RESTARTS=3 RDZV_ID=llamafactory MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/qwen3_full_sft.yaml
``` ```
#### 使用 DeepSpeed ZeRO-3 平均分配显存 #### 使用 DeepSpeed ZeRO-3 平均分配显存
```bash ```bash
FORCE_TORCHRUN=1 llamafactory-cli train examples/train_lora/llama3_lora_sft_ds3.yaml FORCE_TORCHRUN=1 llamafactory-cli train examples/train_lora/qwen3_lora_sft_ds3.yaml
``` ```
#### 使用 Ray 在 4 张 GPU 上微调 #### 使用 Ray 在 4 张 GPU 上微调
```bash ```bash
USE_RAY=1 llamafactory-cli train examples/train_lora/llama3_lora_sft_ray.yaml USE_RAY=1 llamafactory-cli train examples/train_lora/qwen3_lora_sft_ray.yaml
``` ```
### QLoRA 微调 ### QLoRA 微调
@@ -131,13 +119,13 @@ USE_RAY=1 llamafactory-cli train examples/train_lora/llama3_lora_sft_ray.yaml
#### 基于 4/8 比特 Bitsandbytes/HQQ/EETQ 量化进行指令监督微调(推荐) #### 基于 4/8 比特 Bitsandbytes/HQQ/EETQ 量化进行指令监督微调(推荐)
```bash ```bash
llamafactory-cli train examples/train_qlora/llama3_lora_sft_otfq.yaml llamafactory-cli train examples/train_qlora/qwen3_lora_sft_otfq.yaml
``` ```
#### 在 NPU 上基于 4 比特 Bitsandbytes 量化进行指令监督微调 #### 在 NPU 上基于 4 比特 Bitsandbytes 量化进行指令监督微调
```bash ```bash
llamafactory-cli train examples/train_qlora/llama3_lora_sft_bnb_npu.yaml llamafactory-cli train examples/train_qlora/qwen3_lora_sft_bnb_npu.yaml
``` ```
#### 基于 4/8 比特 GPTQ 量化进行指令监督微调 #### 基于 4/8 比特 GPTQ 量化进行指令监督微调
@@ -163,20 +151,20 @@ llamafactory-cli train examples/train_qlora/llama3_lora_sft_aqlm.yaml
#### 在单机上进行指令监督微调 #### 在单机上进行指令监督微调
```bash ```bash
FORCE_TORCHRUN=1 llamafactory-cli train examples/train_full/llama3_full_sft.yaml FORCE_TORCHRUN=1 llamafactory-cli train examples/train_full/qwen3_full_sft.yaml
``` ```
#### 在多机上进行指令监督微调 #### 在多机上进行指令监督微调
```bash ```bash
FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/llama3_full_sft.yaml FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/qwen3_full_sft.yaml
FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/llama3_full_sft.yaml FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/qwen3_full_sft.yaml
``` ```
#### 多模态指令监督微调 #### 多模态指令监督微调
```bash ```bash
FORCE_TORCHRUN=1 llamafactory-cli train examples/train_full/qwen2_5vl_full_sft.yaml FORCE_TORCHRUN=1 llamafactory-cli train examples/train_full/qwen3vl_full_sft.yaml
``` ```
### 合并 LoRA 适配器与模型量化 ### 合并 LoRA 适配器与模型量化
@@ -186,19 +174,19 @@ FORCE_TORCHRUN=1 llamafactory-cli train examples/train_full/qwen2_5vl_full_sft.y
注:请勿使用量化后的模型或 `quantization_bit` 参数来合并 LoRA 适配器。 注:请勿使用量化后的模型或 `quantization_bit` 参数来合并 LoRA 适配器。
```bash ```bash
llamafactory-cli export examples/merge_lora/llama3_lora_sft.yaml llamafactory-cli export examples/merge_lora/qwen3_lora_sft.yaml
``` ```
#### 使用 AutoGPTQ 量化模型 #### 使用 AutoGPTQ 量化模型
```bash ```bash
llamafactory-cli export examples/merge_lora/llama3_gptq.yaml llamafactory-cli export examples/merge_lora/qwen3_gptq.yaml
``` ```
### 保存 Ollama 配置文件 ### 保存 Ollama 配置文件
```bash ```bash
llamafactory-cli export examples/merge_lora/llama3_full_sft.yaml llamafactory-cli export examples/merge_lora/qwen3_full_sft.yaml
``` ```
### 推理 LoRA 模型 ### 推理 LoRA 模型
@@ -206,26 +194,26 @@ llamafactory-cli export examples/merge_lora/llama3_full_sft.yaml
#### 使用 vLLM 多卡推理评估 #### 使用 vLLM 多卡推理评估
``` ```
python scripts/vllm_infer.py --model_name_or_path meta-llama/Meta-Llama-3-8B-Instruct --template llama3 --dataset alpaca_en_demo python scripts/vllm_infer.py --model_name_or_path Qwen/Qwen3-4B-Instruct-2507 --template qwen3_nothink --dataset alpaca_en_demo
python scripts/eval_bleu_rouge.py generated_predictions.jsonl python scripts/eval_bleu_rouge.py generated_predictions.jsonl
``` ```
#### 使用命令行对话框 #### 使用命令行对话框
```bash ```bash
llamafactory-cli chat examples/inference/llama3_lora_sft.yaml llamafactory-cli chat examples/inference/qwen3_lora_sft.yaml
``` ```
#### 使用浏览器对话框 #### 使用浏览器对话框
```bash ```bash
llamafactory-cli webchat examples/inference/llama3_lora_sft.yaml llamafactory-cli webchat examples/inference/qwen3_lora_sft.yaml
``` ```
#### 启动 OpenAI 风格 API #### 启动 OpenAI 风格 API
```bash ```bash
llamafactory-cli api examples/inference/llama3_lora_sft.yaml llamafactory-cli api examples/inference/qwen3_lora_sft.yaml
``` ```
### 杂项 ### 杂项

View File

@@ -1,5 +0,0 @@
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
adapter_name_or_path: saves/llama3-8b/lora/sft
template: llama3
infer_backend: huggingface # choices: [huggingface, vllm, sglang, ktransformers]
trust_remote_code: true

View File

@@ -1,4 +1,4 @@
model_name_or_path: Qwen/Qwen2.5-VL-7B-Instruct model_name_or_path: Qwen/Qwen3-4B-Instruct-2507
template: qwen2_vl template: qwen3_nothink
infer_backend: huggingface # choices: [huggingface, vllm, sglang, ktransformers] infer_backend: huggingface # choices: [huggingface, vllm, sglang, ktransformers]
trust_remote_code: true trust_remote_code: true

View File

@@ -1,4 +1,4 @@
model_name_or_path: saves/llama3-8b/full/sft model_name_or_path: saves/qwen3-4b/full/sft
template: llama3 template: qwen3_nothink
infer_backend: huggingface # choices: [huggingface, vllm, sglang, ktransformers] infer_backend: huggingface # choices: [huggingface, vllm, sglang, ktransformers]
trust_remote_code: true trust_remote_code: true

View File

@@ -0,0 +1,5 @@
model_name_or_path: Qwen/Qwen3-4B-Instruct-2507
adapter_name_or_path: saves/qwen3-4b/lora/sft
template: qwen3_nothink
infer_backend: huggingface # choices: [huggingface, vllm, sglang, ktransformers]
trust_remote_code: true

View File

@@ -1,4 +1,4 @@
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct model_name_or_path: Qwen/Qwen3-VL-4B-Instruct
template: llama3 template: qwen3_vl_nothink
infer_backend: huggingface # choices: [huggingface, vllm, sglang, ktransformers] infer_backend: huggingface # choices: [huggingface, vllm, sglang, ktransformers]
trust_remote_code: true trust_remote_code: true

View File

@@ -1,10 +1,10 @@
### model ### model
model_name_or_path: saves/llama3-8b/full/sft model_name_or_path: saves/qwen3-4b/full/sft
template: llama3 template: qwen3_nothink
trust_remote_code: true trust_remote_code: true
### export ### export
export_dir: output/llama3_full_sft export_dir: saves/qwen3_sft_merged
export_size: 5 export_size: 5
export_device: cpu # choices: [cpu, auto] export_device: cpu # choices: [cpu, auto]
export_legacy_format: false export_legacy_format: false

View File

@@ -1,10 +1,10 @@
### model ### model
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct model_name_or_path: Qwen/Qwen3-4B-Instruct-2507
template: llama3 template: qwen3_nothink
trust_remote_code: true trust_remote_code: true
### export ### export
export_dir: output/llama3_gptq export_dir: saves/qwen3_gptq
export_quantization_bit: 4 export_quantization_bit: 4
export_quantization_dataset: data/c4_demo.jsonl export_quantization_dataset: data/c4_demo.jsonl
export_size: 5 export_size: 5

View File

@@ -1,13 +1,13 @@
### Note: DO NOT use quantized model or quantization_bit when merging lora adapters ### Note: DO NOT use quantized model or quantization_bit when merging lora adapters
### model ### model
model_name_or_path: Qwen/Qwen2.5-VL-7B-Instruct model_name_or_path: Qwen/Qwen3-4B-Instruct-2507
adapter_name_or_path: saves/qwen2_5vl-7b/lora/sft adapter_name_or_path: saves/qwen3-4b/lora/sft
template: qwen2_vl template: qwen3_nothink
trust_remote_code: true trust_remote_code: true
### export ### export
export_dir: output/qwen2_5vl_lora_sft export_dir: saves/qwen3_sft_merged
export_size: 5 export_size: 5
export_device: cpu # choices: [cpu, auto] export_device: cpu # choices: [cpu, auto]
export_legacy_format: false export_legacy_format: false

View File

@@ -1,13 +1,13 @@
### Note: DO NOT use quantized model or quantization_bit when merging lora adapters ### Note: DO NOT use quantized model or quantization_bit when merging lora adapters
### model ### model
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct model_name_or_path: Qwen/Qwen3-VL-4B-Instruct
adapter_name_or_path: saves/llama3-8b/lora/sft adapter_name_or_path: saves/qwen3-vl-4b/lora/sft
template: llama3 template: qwen3_vl_nothink
trust_remote_code: true trust_remote_code: true
### export ### export
export_dir: output/llama3_lora_sft export_dir: saves/qwen3_vl_sft_merged
export_size: 5 export_size: 5
export_device: cpu # choices: [cpu, auto] export_device: cpu # choices: [cpu, auto]
export_legacy_format: false export_legacy_format: false

View File

@@ -1,5 +1,5 @@
### model ### model
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct model_name_or_path: Qwen/Qwen3-4B-Instruct-2507
trust_remote_code: true trust_remote_code: true
### method ### method
@@ -10,15 +10,14 @@ deepspeed: examples/deepspeed/ds_z3_config.json # choices: [ds_z0_config.json,
### dataset ### dataset
dataset: identity,alpaca_en_demo dataset: identity,alpaca_en_demo
template: llama3 template: qwen3_nothink
cutoff_len: 2048 cutoff_len: 2048
max_samples: 1000 max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16 preprocessing_num_workers: 16
dataloader_num_workers: 4 dataloader_num_workers: 4
### output ### output
output_dir: saves/llama3-8b/full/sft output_dir: saves/qwen3-4b/full/sft
logging_steps: 10 logging_steps: 10
save_steps: 500 save_steps: 500
plot_loss: true plot_loss: true

View File

@@ -1,46 +0,0 @@
### model
model_name_or_path: Qwen/Qwen3-32B
trust_remote_code: true
use_v1_kernels: true
### method
stage: sft
do_train: true
finetuning_type: full
deepspeed: examples/deepspeed/ds_z2_autotp_config.json
### dataset
dataset: identity,alpaca_en_demo
template: qwen3
cutoff_len: 2048
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16
dataloader_num_workers: 4
### output
output_dir: saves/qwen3-32b/full/sft_autotp
logging_steps: 1
save_steps: 500
plot_loss: true
overwrite_output_dir: true
save_only_model: false
report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
### train
per_device_train_batch_size: 4
gradient_accumulation_steps: 1
learning_rate: 1.0e-4
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
ddp_timeout: 180000000
resume_from_checkpoint: null
### eval
# eval_dataset: alpaca_en_demo
# val_size: 0.1
# per_device_eval_batch_size: 1
# eval_strategy: steps
# eval_steps: 500

View File

@@ -1,5 +1,5 @@
### model ### model
model_name_or_path: Qwen/Qwen2.5-VL-7B-Instruct model_name_or_path: Qwen/Qwen3-VL-4B-Instruct
image_max_pixels: 262144 image_max_pixels: 262144
video_max_pixels: 16384 video_max_pixels: 16384
trust_remote_code: true trust_remote_code: true
@@ -15,15 +15,14 @@ deepspeed: examples/deepspeed/ds_z3_config.json
### dataset ### dataset
dataset: mllm_demo,identity,alpaca_en_demo dataset: mllm_demo,identity,alpaca_en_demo
template: qwen2_vl template: qwen3_vl_nothink
cutoff_len: 2048 cutoff_len: 2048
max_samples: 1000 max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16 preprocessing_num_workers: 16
dataloader_num_workers: 4 dataloader_num_workers: 4
### output ### output
output_dir: saves/qwen2_5vl-7b/full/sft output_dir: saves/qwen3-vl-4b/full/sft
logging_steps: 10 logging_steps: 10
save_steps: 500 save_steps: 500
plot_loss: true plot_loss: true

View File

@@ -1,19 +0,0 @@
### model
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
adapter_name_or_path: saves/llama3-8b/lora/sft
trust_remote_code: true
### method
finetuning_type: lora
### dataset
task: mmlu_test # choices: [mmlu_test, ceval_validation, cmmlu_test]
template: fewshot
lang: en
n_shot: 5
### output
save_dir: saves/llama3-8b/lora/eval
### eval
batch_size: 4

View File

@@ -1,43 +0,0 @@
### model
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
reward_model: saves/llama3-8b/lora/reward
trust_remote_code: true
### method
stage: ppo
do_train: true
finetuning_type: lora
lora_rank: 8
lora_target: all
### dataset
dataset: identity,alpaca_en_demo
template: llama3
cutoff_len: 2048
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16
dataloader_num_workers: 4
### output
output_dir: saves/llama3-8b/lora/ppo
logging_steps: 10
save_steps: 500
plot_loss: true
overwrite_output_dir: true
report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
### train
per_device_train_batch_size: 1
gradient_accumulation_steps: 8
learning_rate: 1.0e-5
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
ddp_timeout: 180000000
### generate
max_new_tokens: 512
top_k: 0
top_p: 0.9

View File

@@ -1,46 +0,0 @@
### model
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
trust_remote_code: true
### method
stage: sft
do_train: true
finetuning_type: lora
lora_rank: 8
lora_target: all
### dataset
dataset: identity,alpaca_en_demo
template: llama3
cutoff_len: 2048
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16
dataloader_num_workers: 4
### output
output_dir: saves/llama3-8b/lora/sft
logging_steps: 10
save_steps: 500
plot_loss: true
overwrite_output_dir: true
save_only_model: false
report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
### train
per_device_train_batch_size: 1
gradient_accumulation_steps: 8
learning_rate: 1.0e-4
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
ddp_timeout: 180000000
resume_from_checkpoint: null
### eval
# eval_dataset: alpaca_en_demo
# val_size: 0.1
# per_device_eval_batch_size: 1
# eval_strategy: steps
# eval_steps: 500

View File

@@ -1,49 +0,0 @@
# pip install git+https://github.com/hiyouga/transformers.git@llama4_train
### model
model_name_or_path: meta-llama/Llama-4-Scout-17B-16E-Instruct
trust_remote_code: true
### method
stage: sft
do_train: true
finetuning_type: lora
lora_rank: 8
lora_target: all
deepspeed: examples/deepspeed/ds_z3_config.json # choices: [ds_z0_config.json, ds_z2_config.json, ds_z3_config.json]
### dataset
dataset: mllm_demo,identity,alpaca_en_demo
template: llama4
cutoff_len: 2048
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16
dataloader_num_workers: 4
### output
output_dir: saves/llama4-8b/lora/sft
logging_steps: 10
save_steps: 500
plot_loss: true
overwrite_output_dir: true
save_only_model: false
report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
### train
per_device_train_batch_size: 1
gradient_accumulation_steps: 2
learning_rate: 1.0e-4
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
ddp_timeout: 180000000
resume_from_checkpoint: null
### eval
# eval_dataset: alpaca_en_demo
# val_size: 0.1
# per_device_eval_batch_size: 1
# eval_strategy: steps
# eval_steps: 500

View File

@@ -1,5 +1,5 @@
### model ### model
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct model_name_or_path: Qwen/Qwen3-4B-Instruct-2507
trust_remote_code: true trust_remote_code: true
### method ### method
@@ -13,15 +13,14 @@ pref_loss: sigmoid # choices: [sigmoid (dpo), orpo, simpo]
### dataset ### dataset
dataset: dpo_en_demo dataset: dpo_en_demo
template: llama3 template: qwen3_nothink
cutoff_len: 2048 cutoff_len: 2048
max_samples: 1000 max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16 preprocessing_num_workers: 16
dataloader_num_workers: 4 dataloader_num_workers: 4
### output ### output
output_dir: saves/llama3-8b/lora/dpo output_dir: saves/qwen3-4b/lora/dpo
logging_steps: 10 logging_steps: 10
save_steps: 500 save_steps: 500
plot_loss: true plot_loss: true

View File

@@ -1,5 +1,5 @@
### model ### model
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct model_name_or_path: Qwen/Qwen3-4B-Instruct-2507
trust_remote_code: true trust_remote_code: true
### method ### method
@@ -12,15 +12,14 @@ pref_beta: 0.1
### dataset ### dataset
dataset: kto_en_demo dataset: kto_en_demo
template: llama3 template: qwen3_nothink
cutoff_len: 2048 cutoff_len: 2048
max_samples: 1000 max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16 preprocessing_num_workers: 16
dataloader_num_workers: 4 dataloader_num_workers: 4
### output ### output
output_dir: saves/llama3-8b/lora/kto output_dir: saves/qwen3-4b/lora/kto
logging_steps: 10 logging_steps: 10
save_steps: 500 save_steps: 500
plot_loss: true plot_loss: true

View File

@@ -1,5 +1,5 @@
### model ### model
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct model_name_or_path: Qwen/Qwen3-4B-Instruct-2507
trust_remote_code: true trust_remote_code: true
### method ### method
@@ -13,12 +13,11 @@ lora_target: all
dataset: c4_demo dataset: c4_demo
cutoff_len: 2048 cutoff_len: 2048
max_samples: 1000 max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16 preprocessing_num_workers: 16
dataloader_num_workers: 4 dataloader_num_workers: 4
### output ### output
output_dir: saves/llama3-8b/lora/pretrain output_dir: saves/qwen3-4b/lora/pretrain
logging_steps: 10 logging_steps: 10
save_steps: 500 save_steps: 500
plot_loss: true plot_loss: true

View File

@@ -1,5 +1,5 @@
### model ### model
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct model_name_or_path: Qwen/Qwen3-4B-Instruct-2507
trust_remote_code: true trust_remote_code: true
### method ### method
@@ -11,15 +11,14 @@ lora_target: all
### dataset ### dataset
dataset: dpo_en_demo dataset: dpo_en_demo
template: llama3 template: qwen3_nothink
cutoff_len: 2048 cutoff_len: 2048
max_samples: 1000 max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16 preprocessing_num_workers: 16
dataloader_num_workers: 4 dataloader_num_workers: 4
### output ### output
output_dir: saves/llama3-8b/lora/reward output_dir: saves/qwen3-4b/lora/reward
logging_steps: 10 logging_steps: 10
save_steps: 500 save_steps: 500
plot_loss: true plot_loss: true

View File

@@ -2,7 +2,7 @@
set -x set -x
MODEL_PATH=meta-llama/Meta-Llama-3-8B-Instruct MODEL_PATH=Qwen/Qwen3-4B-Instruct-2507
llamafactory-cli train \ llamafactory-cli train \
--model_name_or_path ${MODEL_PATH} \ --model_name_or_path ${MODEL_PATH} \
@@ -13,13 +13,12 @@ llamafactory-cli train \
--lora_rank 8 \ --lora_rank 8 \
--lora_target all \ --lora_target all \
--dataset identity,alpaca_en_demo \ --dataset identity,alpaca_en_demo \
--template llama3 \ --template qwen3_nothink \
--cutoff_len 2048 \ --cutoff_len 2048 \
--max_samples 1000 \ --max_samples 1000 \
--overwrite_cache \
--preprocessing_num_workers 16 \ --preprocessing_num_workers 16 \
--dataloader_num_workers 4 \ --dataloader_num_workers 4 \
--output_dir saves/llama3-8b/lora/sft \ --output_dir saves/qwen3-4b/lora/sft \
--logging_steps 10 \ --logging_steps 10 \
--save_steps 500 \ --save_steps 500 \
--plot_loss \ --plot_loss \

View File

@@ -1,5 +1,5 @@
### model ### model
model_name_or_path: openai/gpt-oss-20b model_name_or_path: Qwen/Qwen3-4B-Instruct-2507
trust_remote_code: true trust_remote_code: true
### method ### method
@@ -11,15 +11,14 @@ lora_target: all
### dataset ### dataset
dataset: identity,alpaca_en_demo dataset: identity,alpaca_en_demo
template: gpt template: qwen3_nothink
cutoff_len: 2048 cutoff_len: 2048
max_samples: 1000 max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16 preprocessing_num_workers: 16
dataloader_num_workers: 4 dataloader_num_workers: 4
### output ### output
output_dir: saves/gpt-20b/lora/sft output_dir: saves/qwen3-4b/lora/sft
logging_steps: 10 logging_steps: 10
save_steps: 500 save_steps: 500
plot_loss: true plot_loss: true

View File

@@ -1,5 +1,5 @@
### model ### model
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct model_name_or_path: Qwen/Qwen3-4B-Instruct-2507
trust_remote_code: true trust_remote_code: true
### method ### method
@@ -12,15 +12,14 @@ deepspeed: examples/deepspeed/ds_z3_config.json # choices: [ds_z0_config.json,
### dataset ### dataset
dataset: identity,alpaca_en_demo dataset: identity,alpaca_en_demo
template: llama3 template: qwen3_nothink
cutoff_len: 2048 cutoff_len: 2048
max_samples: 1000 max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16 preprocessing_num_workers: 16
dataloader_num_workers: 4 dataloader_num_workers: 4
### output ### output
output_dir: saves/llama3-8b/lora/sft output_dir: saves/qwen3-4b/lora/sft
logging_steps: 10 logging_steps: 10
save_steps: 500 save_steps: 500
plot_loss: true plot_loss: true

View File

@@ -1,5 +1,5 @@
### model ### model
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct # or use local absolute path model_name_or_path: Qwen/Qwen3-4B-Instruct-2507 # or use local absolute path
trust_remote_code: true trust_remote_code: true
### method ### method
@@ -12,10 +12,9 @@ lora_target: all
### dataset ### dataset
dataset: identity,alpaca_en_demo dataset: identity,alpaca_en_demo
dataset_dir: REMOTE:llamafactory/demo_data # or use local absolute path dataset_dir: REMOTE:llamafactory/demo_data # or use local absolute path
template: llama3 template: qwen3_nothink
cutoff_len: 2048 cutoff_len: 2048
max_samples: 1000 max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16 preprocessing_num_workers: 16
dataloader_num_workers: 4 dataloader_num_workers: 4
@@ -29,7 +28,7 @@ save_only_model: false
report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow] report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
### ray ### ray
ray_run_name: llama3_8b_sft_lora ray_run_name: qwen3_4b_sft_lora
ray_storage_path: ./saves ray_storage_path: ./saves
ray_num_workers: 4 # Number of GPUs to use. ray_num_workers: 4 # Number of GPUs to use.
placement_strategy: PACK placement_strategy: PACK

View File

@@ -1,5 +1,5 @@
### model ### model
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct model_name_or_path: Qwen/Qwen3-4B-Instruct-2507
trust_remote_code: true trust_remote_code: true
### method ### method
@@ -11,13 +11,12 @@ lora_target: all
### dataset ### dataset
dataset: identity,alpaca_en_demo dataset: identity,alpaca_en_demo
template: llama3 template: qwen3_nothink
cutoff_len: 2048 cutoff_len: 2048
max_samples: 1000 max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16 preprocessing_num_workers: 16
tokenized_path: saves/llama3-8b/dataset/sft tokenized_path: saves/qwen3-4b/dataset/sft
### output ### output (not used)
output_dir: saves/llama3-8b/lora/sft output_dir: saves/qwen3-4b/lora/sft
overwrite_output_dir: true overwrite_output_dir: true

View File

@@ -1,5 +1,5 @@
### model ### model
model_name_or_path: Qwen/Qwen2.5-VL-7B-Instruct model_name_or_path: Qwen/Qwen3-VL-4B-Instruct
image_max_pixels: 262144 image_max_pixels: 262144
video_max_pixels: 16384 video_max_pixels: 16384
trust_remote_code: true trust_remote_code: true
@@ -15,15 +15,14 @@ pref_loss: sigmoid # choices: [sigmoid (dpo), orpo, simpo]
### dataset ### dataset
dataset: rlhf_v dataset: rlhf_v
template: qwen2_vl template: qwen3_vl_nothink
cutoff_len: 2048 cutoff_len: 2048
max_samples: 1000 max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16 preprocessing_num_workers: 16
dataloader_num_workers: 4 dataloader_num_workers: 4
### output ### output
output_dir: saves/qwen2_5vl-7b/lora/dpo output_dir: saves/qwen3-vl-4b/lora/dpo
logging_steps: 10 logging_steps: 10
save_steps: 500 save_steps: 500
plot_loss: true plot_loss: true

View File

@@ -1,5 +1,5 @@
### model ### model
model_name_or_path: Qwen/Qwen2.5-VL-7B-Instruct model_name_or_path: Qwen/Qwen3-VL-4B-Instruct
image_max_pixels: 262144 image_max_pixels: 262144
video_max_pixels: 16384 video_max_pixels: 16384
trust_remote_code: true trust_remote_code: true
@@ -13,15 +13,14 @@ lora_target: all
### dataset ### dataset
dataset: mllm_demo,identity,alpaca_en_demo # video: mllm_video_demo dataset: mllm_demo,identity,alpaca_en_demo # video: mllm_video_demo
template: qwen2_vl template: qwen3_vl_nothink
cutoff_len: 2048 cutoff_len: 2048
max_samples: 1000 max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16 preprocessing_num_workers: 16
dataloader_num_workers: 4 dataloader_num_workers: 4
### output ### output
output_dir: saves/qwen2_5vl-7b/lora/sft output_dir: saves/qwen3-vl-4b/lora/sft
logging_steps: 10 logging_steps: 10
save_steps: 500 save_steps: 500
plot_loss: true plot_loss: true

View File

@@ -14,7 +14,6 @@ dataset: identity,alpaca_en_demo
template: llama3 template: llama3
cutoff_len: 2048 cutoff_len: 2048
max_samples: 1000 max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16 preprocessing_num_workers: 16
dataloader_num_workers: 4 dataloader_num_workers: 4

View File

@@ -14,7 +14,6 @@ dataset: identity,alpaca_en_demo
template: llama3 template: llama3
cutoff_len: 2048 cutoff_len: 2048
max_samples: 1000 max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16 preprocessing_num_workers: 16
dataloader_num_workers: 4 dataloader_num_workers: 4

View File

@@ -14,7 +14,6 @@ dataset: identity,alpaca_en_demo
template: llama3 template: llama3
cutoff_len: 2048 cutoff_len: 2048
max_samples: 1000 max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16 preprocessing_num_workers: 16
dataloader_num_workers: 4 dataloader_num_workers: 4

View File

@@ -1,5 +1,5 @@
### model ### model
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct model_name_or_path: Qwen/Qwen3-4B-Instruct-2507
quantization_bit: 4 quantization_bit: 4
quantization_method: bnb quantization_method: bnb
double_quantization: false double_quantization: false
@@ -14,15 +14,14 @@ lora_target: all
### dataset ### dataset
dataset: identity,alpaca_en_demo dataset: identity,alpaca_en_demo
template: llama3 template: qwen3_nothink
cutoff_len: 2048 cutoff_len: 2048
max_samples: 1000 max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16 preprocessing_num_workers: 16
dataloader_num_workers: 4 dataloader_num_workers: 4
### output ### output
output_dir: saves/llama3-8b/lora/sft output_dir: saves/qwen3-4b/lora/sft
logging_steps: 10 logging_steps: 10
save_steps: 500 save_steps: 500
plot_loss: true plot_loss: true

View File

@@ -1,5 +1,5 @@
### model ### model
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct model_name_or_path: Qwen/Qwen3-4B-Instruct-2507
quantization_bit: 4 # choices: [8 (bnb/hqq/eetq), 4 (bnb/hqq), 3 (hqq), 2 (hqq)] quantization_bit: 4 # choices: [8 (bnb/hqq/eetq), 4 (bnb/hqq), 3 (hqq), 2 (hqq)]
quantization_method: bnb # choices: [bnb, hqq, eetq] quantization_method: bnb # choices: [bnb, hqq, eetq]
trust_remote_code: true trust_remote_code: true
@@ -13,15 +13,14 @@ lora_target: all
### dataset ### dataset
dataset: identity,alpaca_en_demo dataset: identity,alpaca_en_demo
template: llama3 template: qwen3_nothink
cutoff_len: 2048 cutoff_len: 2048
max_samples: 1000 max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16 preprocessing_num_workers: 16
dataloader_num_workers: 4 dataloader_num_workers: 4
### output ### output
output_dir: saves/llama3-8b/lora/sft output_dir: saves/qwen3-4b/lora/sft
logging_steps: 10 logging_steps: 10
save_steps: 500 save_steps: 500
plot_loss: true plot_loss: true

View File

@@ -41,15 +41,14 @@ dependencies = [
"torch>=2.4.0", "torch>=2.4.0",
"torchvision>=0.19.0", "torchvision>=0.19.0",
"torchaudio>=2.4.0", "torchaudio>=2.4.0",
"transformers>=4.49.0,<=4.56.2,!=4.52.0; python_version < '3.10'", "transformers>=4.51.0,<=4.57.1,!=4.52.0,!=4.57.0",
"transformers>=4.49.0,<=4.57.1,!=4.52.0,!=4.57.0; python_version >= '3.10'",
"datasets>=2.16.0,<=4.0.0", "datasets>=2.16.0,<=4.0.0",
"accelerate>=1.3.0,<=1.11.0", "accelerate>=1.3.0,<=1.11.0",
"peft>=0.14.0,<=0.17.1", "peft>=0.14.0,<=0.17.1",
"trl>=0.8.6,<=0.9.6", "trl>=0.18.0,<=0.24.0",
"torchdata>=0.10.0,<=0.11.0", "torchdata>=0.10.0,<=0.11.0",
# gui # gui
"gradio>=4.38.0,<=6.2.0", "gradio>=4.38.0,<=5.50.0",
"matplotlib>=3.7.0", "matplotlib>=3.7.0",
"tyro<0.9.0", "tyro<0.9.0",
# ops # ops

View File

@@ -14,9 +14,12 @@
import gc import gc
import json import json
import time
import av import av
import fire import fire
from datasets import load_dataset
from eval_bleu_rouge import compute_metrics
from tqdm import tqdm from tqdm import tqdm
from transformers import Seq2SeqTrainingArguments from transformers import Seq2SeqTrainingArguments
@@ -51,6 +54,7 @@ def vllm_infer(
max_samples: int | None = None, max_samples: int | None = None,
vllm_config: str = "{}", vllm_config: str = "{}",
save_name: str = "generated_predictions.jsonl", save_name: str = "generated_predictions.jsonl",
matrix_save_name: str = None,
temperature: float = 0.95, temperature: float = 0.95,
top_p: float = 0.7, top_p: float = 0.7,
top_k: int = 50, top_k: int = 50,
@@ -117,6 +121,7 @@ def vllm_infer(
if isinstance(model_args.vllm_config, dict): if isinstance(model_args.vllm_config, dict):
engine_args.update(model_args.vllm_config) engine_args.update(model_args.vllm_config)
model_preparation_start_time = time.time()
llm = LLM(**engine_args) llm = LLM(**engine_args)
# load datasets # load datasets
@@ -142,6 +147,7 @@ def vllm_infer(
all_prompts, all_preds, all_labels = [], [], [] all_prompts, all_preds, all_labels = [], [], []
need_video_kwargs = _need_video_kwargs(template) need_video_kwargs = _need_video_kwargs(template)
model_predict_start_time = time.time()
# Add batch process to avoid the issue of too many files opened # Add batch process to avoid the issue of too many files opened
for i in tqdm(range(0, len(train_dataset), batch_size), desc="Processing batched inference"): for i in tqdm(range(0, len(train_dataset), batch_size), desc="Processing batched inference"):
vllm_inputs, prompts, labels = [], [], [] vllm_inputs, prompts, labels = [], [], []
@@ -218,6 +224,7 @@ def vllm_infer(
all_labels.extend(labels) all_labels.extend(labels)
gc.collect() gc.collect()
model_predict_end_time = time.time()
# Write all results at once outside the loop # Write all results at once outside the loop
with open(save_name, "w", encoding="utf-8") as f: with open(save_name, "w", encoding="utf-8") as f:
for text, pred, label in zip(all_prompts, all_preds, all_labels): for text, pred, label in zip(all_prompts, all_preds, all_labels):
@@ -227,6 +234,49 @@ def vllm_infer(
print(f"{len(all_prompts)} total generated results have been saved at {save_name}.") print(f"{len(all_prompts)} total generated results have been saved at {save_name}.")
print("*" * 70) print("*" * 70)
# Write all matrix results when matrix_save_name is not None,
# The result matrix is referencing src.llamafactory.train.sft.workflow.run_sft # 127~132
# trainer.save_metrics("predict", predict_results.metrics)
#
# {
# "predict_bleu-4": 4.349975,
# "predict_model_preparation_time": 0.0128,
# "predict_rouge-1": 21.873359375,
# "predict_rouge-2": 4.144340625,
# "predict_rouge-l": 10.83949375,
# "predict_runtime": 131.664,
# "predict_samples_per_second": 0.076,
# "predict_steps_per_second": 0.008
# }
#
if matrix_save_name is not None:
predict_time = model_predict_end_time - model_predict_start_time
preparation_time = model_predict_start_time - model_preparation_start_time
start_time = time.time()
dataset = load_dataset("json", data_files=save_name, split="train")
dataset = dataset.map(compute_metrics, num_proc=8, remove_columns=dataset.column_names)
score_dict = dataset.to_dict()
average_score = {}
for task, scores in sorted(score_dict.items(), key=lambda x: x[0]):
score = sum(scores) / len(scores) if scores else 0.0
print(f"predict_{task}: {score:.4f}")
average_score["predict_" + task] = score
average_score["predict_model_preparation_time"] = preparation_time
average_score["predict_runtime"] = predict_time
num_steps = len(range(0, len(train_dataset), batch_size))
average_score["predict_samples_per_second"] = len(dataset) / predict_time if predict_time > 0 else 0.0
average_score["predict_steps_per_second"] = num_steps / predict_time if predict_time > 0 else 0.0
with open(matrix_save_name, "w", encoding="utf-8") as f:
json.dump(average_score, f, indent=4)
print("*" * 70)
print(f"\nDone in {time.time() - start_time:.3f}s.\nScore file saved to {matrix_save_name}.")
print("*" * 70)
if __name__ == "__main__": if __name__ == "__main__":
fire.Fire(vllm_infer) fire.Fire(vllm_infer)

View File

@@ -1673,6 +1673,43 @@ register_template(
) )
register_template(
name="minimax1",
format_user=StringFormatter(
slots=[
"<beginning_of_sentence>user name=user\n{{content}}<end_of_sentence>\n<beginning_of_sentence>ai name=assistant\n"
]
),
format_assistant=StringFormatter(slots=["{{content}}<end_of_sentence>\n"]),
format_system=StringFormatter(
slots=["<beginning_of_sentence>system ai_setting=assistant\n{{content}}<end_of_sentence>\n"]
),
format_function=FunctionFormatter(slots=["{{content}}<end_of_sentence>\n"], tool_format="minimax1"),
format_observation=StringFormatter(
slots=[
"<beginning_of_sentence>tool name=tools\n{{content}}<end_of_sentence>\n<beginning_of_sentence>ai name=assistant\n"
]
),
format_tools=ToolFormatter(tool_format="minimax1"),
default_system="You are a helpful assistant.",
stop_words=["<end_of_sentence>"],
)
register_template(
name="minimax2",
format_user=StringFormatter(slots=["]~b]user\n{{content}}[e~[\n]~b]ai\n"]),
format_assistant=StringFormatter(slots=["{{content}}[e~[\n"]),
format_system=StringFormatter(slots=["]~!b[]~b]system\n{{content}}[e~[\n"]),
format_function=FunctionFormatter(slots=["{{content}}[e~[\n"], tool_format="minimax2"),
format_observation=StringFormatter(slots=["]~b]tool\n<response>{{content}}</response>[e~[\n]~b]ai\n"]),
format_tools=ToolFormatter(tool_format="minimax2"),
default_system="You are a helpful assistant. Your name is MiniMax-M2.1 and is built by MiniMax.",
stop_words=["[e~["],
template_class=ReasoningTemplate,
)
# mistral tokenizer v3 tekken # mistral tokenizer v3 tekken
register_template( register_template(
name="ministral", name="ministral",

View File

@@ -61,6 +61,21 @@ LLAMA3_TOOL_PROMPT = (
"Do not use variables.\n\n{tool_text}" "Do not use variables.\n\n{tool_text}"
) )
MINIMAX_M1_TOOL_PROMPT = (
"You are provided with these tools:\n<tools>\n{tool_text}</tools>\n\n"
"If you need to call tools, please respond with <tool_calls></tool_calls> XML tags, and provide tool-name and "
"json-object of arguments, following the format below:\n<tool_calls>\n"
"""{{"name": <tool-name-1>, "arguments": <args-json-object-1>}}\n...\n</tool_calls>"""
)
MINIMAX_M2_TOOL_PROMPT = (
"\n\n# Tools\n\nYou may call one or more tools to assist with the user query.\n"
"Here are the tools available in JSONSchema format:\n\n<tools>\n{tool_text}</tools>\n\n"
"When making tool calls, use XML format to invoke tools and pass parameters:\n"
"""\n<minimax:tool_call>\n<invoke name="tool-name-1">\n<parameter name="param-key-1">param-value-1</parameter>\n"""
"""<parameter name="param-key-2">param-value-2</parameter>\n...\n</invoke>\n</minimax:tool_call>"""
)
QWEN_TOOL_PROMPT = ( QWEN_TOOL_PROMPT = (
"\n\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\n" "\n\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\n"
"You are provided with function signatures within <tools></tools> XML tags:\n<tools>{tool_text}" "You are provided with function signatures within <tools></tools> XML tags:\n<tools>{tool_text}"
@@ -253,6 +268,109 @@ class Llama3ToolUtils(ToolUtils):
return content return content
class MiniMaxM1ToolUtils(ToolUtils):
    r"""MiniMax-M1 tool using template."""

    @override
    @staticmethod
    def tool_formatter(tools: list[dict[str, Any]]) -> str:
        r"""Serialize the tool schemas as newline-terminated JSON lines inside the M1 prompt."""
        serialized = []
        for tool in tools:
            # unwrap OpenAI-style {"type": "function", "function": {...}} entries
            if tool.get("type") == "function":
                tool = tool.get("function", "")

            serialized.append(json.dumps(tool, ensure_ascii=False) + "\n")

        return MINIMAX_M1_TOOL_PROMPT.format(tool_text="".join(serialized))

    @override
    @staticmethod
    def function_formatter(functions: list["FunctionCall"]) -> str:
        r"""Render function calls as newline-separated JSON objects wrapped in <tool_calls> tags."""
        function_texts = [
            json.dumps({"name": func.name, "arguments": json.loads(func.arguments)}, ensure_ascii=False)
            for func in functions
        ]
        return "<tool_calls>\n" + "\n".join(function_texts) + "\n</tool_calls>"

    @override
    @staticmethod
    def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
        r"""Parse a <tool_calls> block back into FunctionCall objects.

        Returns the content unchanged when no <tool_calls> block is present; lines that are
        not valid JSON are skipped.
        """
        tool_match = re.search(r"<tool_calls>\s*(.+?)\s*</tool_calls>", content, re.DOTALL)
        if tool_match is None:
            return content

        results = []
        for line in tool_match.group(1).split("\n"):
            line = line.strip()
            if not line:
                continue

            try:
                tool_call = json.loads(line)
                results.append(
                    FunctionCall(tool_call["name"], json.dumps(tool_call["arguments"], ensure_ascii=False))
                )
            except json.JSONDecodeError:
                # ignore malformed lines instead of failing the whole extraction
                continue

        return results
class MiniMaxM2ToolUtils(ToolUtils):
    r"""MiniMax-M2 tool using template."""

    @override
    @staticmethod
    def tool_formatter(tools: list[dict[str, Any]]) -> str:
        r"""Serialize each tool schema as JSON wrapped in <tool></tool> tags inside the M2 prompt."""
        tool_text = ""
        for tool in tools:
            # unwrap OpenAI-style {"type": "function", "function": {...}} entries
            tool = tool.get("function", "") if tool.get("type") == "function" else tool
            tool_text += "<tool>" + json.dumps(tool, ensure_ascii=False) + "</tool>\n"

        return MINIMAX_M2_TOOL_PROMPT.format(tool_text=tool_text)

    @override
    @staticmethod
    def function_formatter(functions: list["FunctionCall"]) -> str:
        r"""Render function calls as <minimax:tool_call>/<invoke>/<parameter> XML.

        Non-string argument values are serialized as JSON so they round-trip through
        ``tool_extractor``.
        """
        function_texts = []
        for func in functions:
            name, arguments = func.name, json.loads(func.arguments)
            prompt = f'<invoke name="{name}">'
            for key, value in arguments.items():
                prompt += f'\n<parameter name="{key}">'
                if not isinstance(value, str):
                    value = json.dumps(value, ensure_ascii=False)

                prompt += value + "</parameter>"

            prompt += "\n</invoke>"
            function_texts.append(prompt)

        # bug fix: the assembled invoke blocks were previously never returned
        # (the function fell through and implicitly returned None)
        return "<minimax:tool_call>\n" + "\n".join(function_texts) + "\n</minimax:tool_call>"

    @override
    @staticmethod
    def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
        r"""Parse <minimax:tool_call> blocks back into FunctionCall objects.

        Returns the content unchanged when no tool-call block is present.
        """
        regex = re.compile(r"<minimax:tool_call>\s*(.+?)\s*</minimax:tool_call>", re.DOTALL)
        tool_match = re.search(regex, content)
        if not tool_match:
            return content

        tool_calls_content = tool_match.group(1)
        invoke_regex = re.compile(r"<invoke name=\"(.*?)\">(.*?)</invoke>", re.DOTALL)
        results = []
        for func_name, params_block in re.findall(invoke_regex, tool_calls_content):
            args_dict = {}
            param_pattern = re.compile(r"<parameter name=\"(.*?)\">(.*?)</parameter>", re.DOTALL)
            for key, raw_value in re.findall(param_pattern, params_block):
                value = raw_value.strip()
                try:
                    # values emitted by function_formatter may be JSON (numbers, lists, objects)
                    parsed_value = json.loads(value)
                except json.JSONDecodeError:
                    # plain strings are kept verbatim (unstripped)
                    parsed_value = raw_value

                args_dict[key] = parsed_value

            results.append(FunctionCall(func_name.strip(), json.dumps(args_dict, ensure_ascii=False)))

        return results
class MistralToolUtils(ToolUtils): class MistralToolUtils(ToolUtils):
r"""Mistral v0.3 tool using template.""" r"""Mistral v0.3 tool using template."""
@@ -432,6 +550,8 @@ TOOLS = {
"default": DefaultToolUtils(), "default": DefaultToolUtils(),
"glm4": GLM4ToolUtils(), "glm4": GLM4ToolUtils(),
"llama3": Llama3ToolUtils(), "llama3": Llama3ToolUtils(),
"minimax1": MiniMaxM1ToolUtils(),
"minimax2": MiniMaxM2ToolUtils(),
"mistral": MistralToolUtils(), "mistral": MistralToolUtils(),
"qwen": QwenToolUtils(), "qwen": QwenToolUtils(),
"glm4_moe": GLM4MOEToolUtils(), "glm4_moe": GLM4MOEToolUtils(),

View File

@@ -63,6 +63,7 @@ MCA_SUPPORTED_MODELS = {
"qwen2", "qwen2",
"qwen2_vl", "qwen2_vl",
"qwen2_5_vl", "qwen2_5_vl",
"qwen3_vl",
"qwen3", "qwen3",
"qwen3_moe", "qwen3_moe",
"qwen3_next", "qwen3_next",
@@ -1071,6 +1072,40 @@ register_model_group(
) )
register_model_group(
models={
"MiniMax-Text-01-Instruct": {
DownloadSource.DEFAULT: "MiniMaxAI/MiniMax-Text-01-hf",
DownloadSource.MODELSCOPE: "MiniMaxAI/MiniMax-Text-01",
},
"MiniMax-M1-40k-Thinking": {
DownloadSource.DEFAULT: "MiniMaxAI/MiniMax-M1-40k-hf",
DownloadSource.MODELSCOPE: "MiniMaxAI/MiniMax-M1-40k-hf",
},
"MiniMax-M1-80k-Thinking": {
DownloadSource.DEFAULT: "MiniMaxAI/MiniMax-M1-80k-hf",
DownloadSource.MODELSCOPE: "MiniMaxAI/MiniMax-M1-80k-hf",
},
},
template="minimax1",
)
register_model_group(
models={
"MiniMax-M2-Thinking": {
DownloadSource.DEFAULT: "MiniMaxAI/MiniMax-M2",
DownloadSource.MODELSCOPE: "MiniMaxAI/MiniMax-M2",
},
"MiniMax-M2.1-Thinking": {
DownloadSource.DEFAULT: "MiniMaxAI/MiniMax-M2.1",
DownloadSource.MODELSCOPE: "MiniMaxAI/MiniMax-M2.1",
},
},
template="minimax2",
)
register_model_group( register_model_group(
models={ models={
"Granite-3.0-1B-A400M-Base": { "Granite-3.0-1B-A400M-Base": {

View File

@@ -19,7 +19,7 @@
from collections import OrderedDict from collections import OrderedDict
VERSION = "0.9.4.dev0" VERSION = "0.9.4"
def print_env() -> None: def print_env() -> None:

View File

@@ -94,11 +94,11 @@ def check_version(requirement: str, mandatory: bool = False) -> None:
def check_dependencies() -> None: def check_dependencies() -> None:
r"""Check the version of the required packages.""" r"""Check the version of the required packages."""
check_version("transformers>=4.49.0,<=4.57.1") check_version("transformers>=4.51.0,<=4.57.1")
check_version("datasets>=2.16.0,<=4.0.0") check_version("datasets>=2.16.0,<=4.0.0")
check_version("accelerate>=1.3.0,<=1.11.0") check_version("accelerate>=1.3.0,<=1.11.0")
check_version("peft>=0.14.0,<=0.17.1") check_version("peft>=0.14.0,<=0.17.1")
check_version("trl>=0.8.6,<=0.9.6") check_version("trl>=0.18.0,<=0.24.0")
def calculate_tps(dataset: list[dict[str, Any]], metrics: dict[str, float], stage: Literal["sft", "rm"]) -> float: def calculate_tps(dataset: list[dict[str, Any]], metrics: dict[str, float], stage: Literal["sft", "rm"]) -> float:

View File

@@ -173,7 +173,7 @@ class BaseModelArguments:
default=True, default=True,
metadata={"help": "Whether or not to use KV cache in generation."}, metadata={"help": "Whether or not to use KV cache in generation."},
) )
use_v1_kernels: bool = field( use_v1_kernels: bool | None = field(
default=False, default=False,
metadata={"help": "Whether or not to use high-performance kernels in training."}, metadata={"help": "Whether or not to use high-performance kernels in training."},
) )

View File

@@ -216,9 +216,9 @@ def load_model(
"You are try to using future feature about kernels, please note that this feature " "You are try to using future feature about kernels, please note that this feature "
"is not supported for all models. If get any error, please disable this feature, or report the issue." "is not supported for all models. If get any error, please disable this feature, or report the issue."
) )
from ..v1.plugins.model_plugins.kernels.registry import apply_available_kernels from ..v1.plugins.model_plugins.kernels.interface import apply_default_kernels
model = apply_available_kernels(model) model = apply_default_kernels(model=model, include_kernels=model_args.use_v1_kernels)
trainable_params, all_param = count_parameters(model) trainable_params, all_param = count_parameters(model)
if is_trainable: if is_trainable:

View File

@@ -26,6 +26,7 @@ import torch.nn.functional as F
from transformers import Trainer from transformers import Trainer
from trl import DPOTrainer from trl import DPOTrainer
from trl.trainer import disable_dropout_in_model from trl.trainer import disable_dropout_in_model
from trl.trainer.utils import prepare_deepspeed
from typing_extensions import override from typing_extensions import override
from ...extras.constants import IGNORE_INDEX from ...extras.constants import IGNORE_INDEX
@@ -95,7 +96,7 @@ class CustomDPOTrainer(DPOTrainer):
if not ( if not (
getattr(ref_model, "is_loaded_in_8bit", False) or getattr(ref_model, "is_loaded_in_4bit", False) getattr(ref_model, "is_loaded_in_8bit", False) or getattr(ref_model, "is_loaded_in_4bit", False)
): # quantized models are already set on the correct device ): # quantized models are already set on the correct device
self.ref_model = self._prepare_deepspeed(self.ref_model) self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
else: else:
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True) self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
self.ref_model.eval() self.ref_model.eval()
@@ -210,7 +211,7 @@ class CustomDPOTrainer(DPOTrainer):
@override @override
def concatenated_forward( def concatenated_forward(
self, model: "PreTrainedModel", batch: dict[str, "torch.Tensor"], is_ref_model: bool = False self, model: "PreTrainedModel", batch: dict[str, "torch.Tensor"], is_ref_model: bool = False
) -> tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]: ) -> dict[str, "torch.Tensor"]:
r"""Compute the sum log probabilities of the labels under given logits if loss_type is not IPO, ORPO or SimPO. r"""Compute the sum log probabilities of the labels under given logits if loss_type is not IPO, ORPO or SimPO.
Otherwise the average log probabilities. Otherwise the average log probabilities.
@@ -230,11 +231,18 @@ class CustomDPOTrainer(DPOTrainer):
chosen_logps, rejected_logps = all_logps.split(batch_size, dim=0) chosen_logps, rejected_logps = all_logps.split(batch_size, dim=0)
chosen_logits, rejected_logits = all_logits.split(batch_size, dim=0) chosen_logits, rejected_logits = all_logits.split(batch_size, dim=0)
chosen_length, _ = valid_length.split(batch_size, dim=0) chosen_length, _ = valid_length.split(batch_size, dim=0)
if self.loss_type in ["ipo", "orpo", "simpo"]: if self.loss_type in ["ipo", "orpo", "simpo"]:
return chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_logps chosen_logps_avg = chosen_logps
else: else:
return chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_logps / chosen_length chosen_logps_avg = chosen_logps / chosen_length
return {
"chosen_logps": chosen_logps,
"rejected_logps": rejected_logps,
"chosen_logits": chosen_logits,
"rejected_logits": rejected_logits,
"chosen_logps_avg": chosen_logps_avg,
}
@override @override
def compute_reference_log_probs( def compute_reference_log_probs(
@@ -252,9 +260,9 @@ class CustomDPOTrainer(DPOTrainer):
ref_context = nullcontext() ref_context = nullcontext()
with torch.no_grad(), ref_context: with torch.no_grad(), ref_context:
reference_chosen_logps, reference_rejected_logps, *_ = self.concatenated_forward( ref_output = self.concatenated_forward(ref_model, batch, is_ref_model=True)
ref_model, batch, is_ref_model=True reference_chosen_logps = ref_output["chosen_logps"]
) reference_rejected_logps = ref_output["rejected_logps"]
return reference_chosen_logps, reference_rejected_logps return reference_chosen_logps, reference_rejected_logps
@@ -267,13 +275,13 @@ class CustomDPOTrainer(DPOTrainer):
) -> tuple["torch.Tensor", dict[str, "torch.Tensor"]]: ) -> tuple["torch.Tensor", dict[str, "torch.Tensor"]]:
r"""Compute the DPO loss and other metrics for the given batch of inputs for train or test.""" r"""Compute the DPO loss and other metrics for the given batch of inputs for train or test."""
metrics = {} metrics = {}
(
policy_chosen_logps, model_output = self.concatenated_forward(model, batch)
policy_rejected_logps, policy_chosen_logps = model_output["chosen_logps"]
policy_chosen_logits, policy_rejected_logps = model_output["rejected_logps"]
policy_rejected_logits, policy_chosen_logits = model_output["chosen_logits"]
policy_chosen_logps_avg, policy_rejected_logits = model_output["rejected_logits"]
) = self.concatenated_forward(model, batch) policy_chosen_logps_avg = model_output["chosen_logps_avg"]
reference_chosen_logps, reference_rejected_logps = self.compute_reference_log_probs(model, batch) reference_chosen_logps, reference_rejected_logps = self.compute_reference_log_probs(model, batch)
losses, chosen_rewards, rejected_rewards = self.compute_preference_loss( losses, chosen_rewards, rejected_rewards = self.compute_preference_loss(

View File

@@ -25,6 +25,7 @@ import torch
from transformers import Trainer from transformers import Trainer
from trl import KTOTrainer from trl import KTOTrainer
from trl.trainer import disable_dropout_in_model from trl.trainer import disable_dropout_in_model
from trl.trainer.utils import prepare_deepspeed
from typing_extensions import override from typing_extensions import override
from ...extras.constants import IGNORE_INDEX from ...extras.constants import IGNORE_INDEX
@@ -77,6 +78,13 @@ class CustomKTOTrainer(KTOTrainer):
self.desirable_weight = finetuning_args.kto_chosen_weight self.desirable_weight = finetuning_args.kto_chosen_weight
self.undesirable_weight = finetuning_args.kto_rejected_weight self.undesirable_weight = finetuning_args.kto_rejected_weight
self.ftx_gamma = finetuning_args.pref_ftx self.ftx_gamma = finetuning_args.pref_ftx
# trl
# Not all losses require a KL calculation
self.calculate_KL = True
if hasattr(self, "loss_type") and self.loss_type in ["apo_zero_unpaired"]:
self.calculate_KL = False
else:
self.loss_type = "kto"
Trainer.__init__(self, model=model, **kwargs) Trainer.__init__(self, model=model, **kwargs)
self.model_accepts_loss_kwargs = False # overwrite trainer's default behavior self.model_accepts_loss_kwargs = False # overwrite trainer's default behavior
@@ -90,7 +98,7 @@ class CustomKTOTrainer(KTOTrainer):
if not ( if not (
getattr(ref_model, "is_loaded_in_8bit", False) or getattr(ref_model, "is_loaded_in_4bit", False) getattr(ref_model, "is_loaded_in_8bit", False) or getattr(ref_model, "is_loaded_in_4bit", False)
): # quantized models are already set on the correct device ): # quantized models are already set on the correct device
self.ref_model = self._prepare_deepspeed(self.ref_model) self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
else: else:
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True) self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
self.ref_model.eval() self.ref_model.eval()

View File

@@ -11,14 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""MCA (mcore_adapter) workflows for PT/SFT/DPO stages, aligned with LLaMA-Factory's workflow style."""
from __future__ import annotations
import functools import functools
from collections.abc import Sequence from collections.abc import Sequence
from copy import deepcopy from copy import deepcopy
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any, Optional
from transformers import DataCollatorForSeq2Seq
from ...data import ( from ...data import (
SFTDataCollatorWith4DAttentionMask, SFTDataCollatorWith4DAttentionMask,
@@ -44,11 +43,11 @@ from mcore_adapter.models import AutoConfig, AutoModel
from mcore_adapter.trainer import DPOTrainer as McaDPOTrainer from mcore_adapter.trainer import DPOTrainer as McaDPOTrainer
from mcore_adapter.trainer import McaTrainer from mcore_adapter.trainer import McaTrainer
from mcore_adapter.trainer.dpo_config import DPOConfig from mcore_adapter.trainer.dpo_config import DPOConfig
from mcore_adapter.training_args import Seq2SeqTrainingArguments as McaSeq2SeqTrainingArguments
if TYPE_CHECKING: if TYPE_CHECKING:
from transformers import DataCollatorForSeq2Seq, TrainerCallback from mcore_adapter.training_args import Seq2SeqTrainingArguments as McaSeq2SeqTrainingArguments
from transformers import TrainerCallback
from ...hparams import DataArguments, FinetuningArguments, ModelArguments from ...hparams import DataArguments, FinetuningArguments, ModelArguments
@@ -76,7 +75,7 @@ def _data_collator_wrapper(data_collator: Any):
return wrapper return wrapper
def _check_model_support(model_args: ModelArguments): def _check_model_support(model_args: "ModelArguments"):
from transformers import AutoConfig as HfAutoConfig from transformers import AutoConfig as HfAutoConfig
config = HfAutoConfig.from_pretrained( config = HfAutoConfig.from_pretrained(
@@ -87,11 +86,11 @@ def _check_model_support(model_args: ModelArguments):
def run_pt( def run_pt(
model_args: ModelArguments, model_args: "ModelArguments",
data_args: DataArguments, data_args: "DataArguments",
training_args: McaSeq2SeqTrainingArguments, training_args: "McaSeq2SeqTrainingArguments",
finetuning_args: FinetuningArguments, finetuning_args: "FinetuningArguments",
callbacks: list[TrainerCallback] | None = None, callbacks: Optional[list["TrainerCallback"]] = None,
): ):
tokenizer_module = load_tokenizer(model_args) tokenizer_module = load_tokenizer(model_args)
tokenizer = tokenizer_module["tokenizer"] tokenizer = tokenizer_module["tokenizer"]
@@ -104,10 +103,7 @@ def run_pt(
_check_model_support(model_args) _check_model_support(model_args)
model = AutoModel.from_pretrained(model_args.model_name_or_path, training_args) model = AutoModel.from_pretrained(model_args.model_name_or_path, training_args)
data_collator = DataCollatorForSeq2Seq(
from transformers import DataCollatorForSeq2Seq
data_collator: DataCollatorForSeq2Seq = DataCollatorForSeq2Seq(
tokenizer=tokenizer, tokenizer=tokenizer,
pad_to_multiple_of=8, pad_to_multiple_of=8,
label_pad_token_id=IGNORE_INDEX, label_pad_token_id=IGNORE_INDEX,
@@ -142,11 +138,11 @@ def run_pt(
def run_sft( def run_sft(
model_args: ModelArguments, model_args: "ModelArguments",
data_args: DataArguments, data_args: "DataArguments",
training_args: McaSeq2SeqTrainingArguments, training_args: "McaSeq2SeqTrainingArguments",
finetuning_args: FinetuningArguments, finetuning_args: "FinetuningArguments",
callbacks: list[TrainerCallback] | None = None, callbacks: Optional[list["TrainerCallback"]] = None,
): ):
# align packing flags # align packing flags
# TODO: FIX SequencePacking # TODO: FIX SequencePacking
@@ -166,7 +162,7 @@ def run_sft(
model = AutoModel.from_pretrained(model_args.model_name_or_path, training_args) model = AutoModel.from_pretrained(model_args.model_name_or_path, training_args)
# optional freezing for qwen2_vl, qwen2_5_vl # optional freezing for qwen2_vl, qwen2_5_vl
if getattr(model.config, "hf_model_type", None) in ["qwen2_vl", "qwen2_5_vl"]: if getattr(model.config, "hf_model_type", None) in ["qwen2_vl", "qwen2_5_vl", "qwen3_vl"]:
params_to_freeze = [] params_to_freeze = []
if finetuning_args.freeze_vision_tower: if finetuning_args.freeze_vision_tower:
params_to_freeze.extend(["vision_model.blocks", "vision_model.patch_embed"]) params_to_freeze.extend(["vision_model.blocks", "vision_model.patch_embed"])
@@ -220,11 +216,11 @@ def run_sft(
def run_dpo( def run_dpo(
model_args: ModelArguments, model_args: "ModelArguments",
data_args: DataArguments, data_args: "DataArguments",
training_args: McaSeq2SeqTrainingArguments, training_args: "McaSeq2SeqTrainingArguments",
finetuning_args: FinetuningArguments, finetuning_args: "FinetuningArguments",
callbacks: list[TrainerCallback] | None = None, callbacks: Optional[list["TrainerCallback"]] = None,
): ):
tokenizer_module = load_tokenizer(model_args) tokenizer_module = load_tokenizer(model_args)
tokenizer = tokenizer_module["tokenizer"] tokenizer = tokenizer_module["tokenizer"]

View File

@@ -33,12 +33,12 @@ from transformers.trainer_pt_utils import remove_dummy_checkpoint
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
from transformers.utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME from transformers.utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME
from trl import PPOConfig, PPOTrainer from trl import PPOConfig, PPOTrainer
from trl.core import PPODecorators, logprobs_from_logits from trl import __version__ as trl_version
from trl.models.utils import unwrap_model_for_generation from trl.models.utils import unwrap_model_for_generation
from typing_extensions import override from typing_extensions import override
from ...extras import logging from ...extras import logging
from ...extras.misc import AverageMeter, count_parameters, get_current_device, get_logits_processor from ...extras.misc import AverageMeter, count_parameters, get_current_device, get_logits_processor, torch_gc
from ..callbacks import FixValueHeadModelCallback, SaveProcessorCallback from ..callbacks import FixValueHeadModelCallback, SaveProcessorCallback
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
from .ppo_utils import dump_layernorm, get_rewards_from_server, replace_model, restore_layernorm from .ppo_utils import dump_layernorm, get_rewards_from_server, replace_model, restore_layernorm
@@ -83,6 +83,19 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
if eval_dataset is not None: if eval_dataset is not None:
raise NotImplementedError("PPOTrainer does not support eval dataset yet.") raise NotImplementedError("PPOTrainer does not support eval dataset yet.")
# Check if TRL version is compatible (0.8.6 <= version <= 0.9.6)
try:
from transformers.utils.versions import require_version
require_version(
"trl>=0.8.6,<=0.9.6",
"Incompatible TRL version detected. LLaMA-Factory ppo requires TRL version >=0.8.6,<=0.9.6. "
f"Found version {trl_version}. Please install the correct version with: `pip install trl>=0.8.6,<=0.9.6`\n"
"To fix: run `DISABLE_VERSION_CHECK=1 llamafactory-cli train example_ppo.yaml`\n",
)
except ImportError as e:
raise e
backward_batch_size = training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps backward_batch_size = training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps
ppo_config = PPOConfig( ppo_config = PPOConfig(
model_name=model_args.model_name_or_path, model_name=model_args.model_name_or_path,
@@ -406,7 +419,6 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
return rewards.float().detach() # use fp32 type return rewards.float().detach() # use fp32 type
@override @override
@PPODecorators.empty_device_cache()
def batched_forward_pass( def batched_forward_pass(
self, self,
model: "AutoModelForCausalLMWithValueHead", model: "AutoModelForCausalLMWithValueHead",
@@ -420,6 +432,9 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
Subclass and override to inject custom behavior. Subclass and override to inject custom behavior.
""" """
from trl.core import logprobs_from_logits
torch_gc()
bs = len(queries) bs = len(queries)
fbs = self.config.mini_batch_size fbs = self.config.mini_batch_size
all_logprobs = [] all_logprobs = []

View File

@@ -108,7 +108,7 @@ def create_modelcard_and_push(
elif training_args.push_to_hub: elif training_args.push_to_hub:
trainer.push_to_hub(**kwargs) trainer.push_to_hub(**kwargs)
else: else:
trainer.create_model_card(license="other", **kwargs) # prevent from connecting to hub Trainer.create_model_card(trainer, license="other", **kwargs) # prevent from connecting to hub
def create_ref_model( def create_ref_model(

View File

@@ -112,6 +112,13 @@ class ModelLoader:
model = PeftPlugin(self.args.peft_config.name)(model, self.args.peft_config, self.is_train) model = PeftPlugin(self.args.peft_config.name)(model, self.args.peft_config, self.is_train)
if self.args.kernel_config is not None:
from ..plugins.model_plugins.kernels.interface import KernelPlugin
model = KernelPlugin(self.args.kernel_config.name)(
model=model, include_kernels=self.args.kernel_config.get("include_kernels")
)
return model return model

View File

@@ -0,0 +1,87 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The definition of base kernel class.
Init Phase:
1. Define base kernel class.
2. Define abstract methods.
"""
from abc import ABC, abstractmethod
from typing import Any
from ....accelerator.helper import DeviceType, get_current_accelerator
from ....utils.types import HFModel
class BaseKernel(ABC):
    r"""Base class for all kernel implementations.

    Subclasses must implement the abstract methods and define the required class attributes.
    """

    # kernel ID, any hashable value to identify a kernel implementation
    _kernel_id: Any = ""
    # target device type, e.g. "cuda", "npu", "cpu"
    _device: DeviceType = DeviceType.CPU

    @classmethod
    def get_kernel_id(cls) -> str:
        r"""Return the unique identifier for the kernel."""
        return cls._kernel_id

    @classmethod
    def get_device(cls) -> str:
        r"""Return the device type associated with the kernel (e.g., "cuda", "npu", "cpu")."""
        return cls._device

    @classmethod
    def check_deps(cls) -> bool:
        r"""Check if the required dependencies for the kernel are available.

        Returns:
            bool: ``True`` if dependencies are met, ``False`` otherwise.

        .. note::
            In explicit mode, if a user specifies an implementation but this check fails,
            it should raise an error instead of silently switching.
            Kernels can override this method to implement custom dependency checks.
        """
        # default check: the kernel's declared device must match the active accelerator
        return cls._device == get_current_accelerator().type

    @classmethod
    @abstractmethod
    def apply(cls, **kwargs) -> HFModel:
        r"""Apply the kernel optimization to the model.

        Args:
            **kwargs: Arbitrary keyword arguments, usually containing the model instance and the kernel configuration.

        Returns:
            HFModel: The model with the kernel applied.

        Raises:
            RuntimeError: If the kernel dependencies are not met.
            NotImplementedError: If the method is not implemented by the subclass.

        Example:
            >>> from llamafactory.v1.plugins.model_plugins.kernels.interface import apply_kernel
            >>> model = HFModel(config=config)
            >>> model = apply_kernel(model=model, kernel_id="npu_fused_moe")
        """
        if not cls.check_deps():
            raise RuntimeError(f"{cls.__name__} is not available but {cls.__name__} kernel was called.")

        raise NotImplementedError

View File

@@ -1,23 +0,0 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from enum import Enum
class KernelType(str, Enum):
RMSNORM = "rmsnorm"
SWIGLU = "swiglu"
FLASH_ATTENTION = "flash_attention"
ROPE = "rope"
MOE = "moe"

View File

@@ -0,0 +1,132 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The definition of kernel interface.
Init Phase:
1. Scan all kernels.
2. Register default kernels.
3. Define kernel plugin.
"""
import importlib
from pathlib import Path
from ....utils.logging import get_logger
from ....utils.plugin import BasePlugin
from .registry import Registry
logger = get_logger(__name__)
def scan_all_kernels():
    r"""Scan all kernels in the ``ops`` directory.

    Scans the ``ops`` directory for all ``.py`` files and attempts to import them.
    Importing triggers the :func:`~registry.register_kernel` decorator, which automatically registers the kernels.

    Returns:
        dict[str, type[BaseKernel]]: A dictionary of registered kernels (possibly empty).

    .. note::
        This function assumes that the ``ops`` directory is located in the same directory as this file.
        It recursively searches for ``.py`` files and constructs the module path for import.
    """
    package_root = Path(__file__).parent
    ops_path = package_root / "ops"
    if ops_path.exists():
        base_package = __package__
        for file_path in ops_path.rglob("*.py"):
            if file_path.name == "__init__.py":
                continue

            # calculate the relative path:
            # file_path = .../kernels/ops/mlp/npu_swiglu.py -> rel_path = ops/mlp/npu_swiglu.py
            rel_path = file_path.relative_to(package_root)
            # build the dotted module path, stripping the trailing ".py"
            module_name = ".".join(rel_path.parts)[:-3]
            full_module_name = f"{base_package}.{module_name}"
            try:
                importlib.import_module(full_module_name)
            except Exception as e:
                logger.warning(f"[Kernel Registry] Failed to import {full_module_name} when loading kernels: {e}")

    # bug fix: previously returned None (bare ``return``) when the ops directory was missing,
    # which made ``default_kernels`` None and crashed get_default_kernels()/apply_kernel()
    # with an AttributeError; always return the registry mapping instead.
    return Registry.get_registered_kernels()
# populate the module-level registry once at import time
default_kernels = scan_all_kernels()


def get_default_kernels():
    r"""Get a list of default registered kernel IDs.

    Returns:
        list[str]: List of kernel IDs.
    """
    # iterating a dict yields its keys, so this equals list(default_kernels.keys())
    return list(default_kernels)
def apply_kernel(kernel_id: str, **kwargs):
    r"""Apply a specific kernel to the model.

    Args:
        kernel_id (str): The ID of the kernel to apply.
        **kwargs: Keyword arguments passed to the kernel application function.
            Typically includes the model instance.

    Returns:
        HFModel: The model with applied kernel.

    Raises:
        ValueError: If ``kernel_id`` is not a registered kernel.
    """
    kernel = default_kernels.get(kernel_id)
    if kernel is None:
        raise ValueError(f"Kernel {kernel_id} not found")

    # bug fix: the result of apply() was previously discarded even though the
    # docstring (and BaseKernel.apply) promise the patched model as return value
    return kernel.apply(**kwargs)
class KernelPlugin(BasePlugin):
    r"""Plugin for managing kernel optimizations.

    Behaviors are registered on named instances (e.g. ``@KernelPlugin("auto").register``
    below); the class itself adds nothing beyond :class:`BasePlugin`.
    """

    pass
@KernelPlugin("auto").register
def apply_default_kernels(**kwargs):
    r"""Apply the requested (or all default) registered kernels to the model.

    Args:
        **kwargs: Keyword arguments passed to the kernel application function.
            Typically includes the model instance and the include_kernels configuration.

    Returns:
        HFModel: The model with applied kernels.

    Raises:
        ValueError: If a requested kernel is not registered.
    """
    include_kernels = kwargs.get("include_kernels")
    if not include_kernels:  # None/False/empty string -> no-op
        return kwargs.get("model")
    elif include_kernels == "auto" or include_kernels is True:  # True/auto -> everything registered
        use_kernels = list(default_kernels.keys())
    else:
        # "kernel_id1, kernel_id2" -> robustness fix: tolerate whitespace around commas
        # and skip empty segments (previously " b" would fail the lookup below)
        use_kernels = [name.strip() for name in include_kernels.split(",") if name.strip()]

    for kernel in use_kernels:
        if kernel not in default_kernels:
            raise ValueError(f"Kernel {kernel} not found")

        # kernels are assumed to patch the model in place; the return value is not used here
        apply_kernel(kernel, **kwargs)

    return kwargs.get("model")

View File

@@ -12,22 +12,49 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""The definition of NPU fused MoE kernels.
Init Phase:
1. Define GMM functions.
2. Define NPU fused MoE functions.
3. Register NPU fused MoE kernel.
"""
import types import types
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
import torch_npu
from .....accelerator.helper import DeviceType, is_torch_npu_available
from .....utils.packages import is_transformers_version_greater_than try:
from .....utils.types import HFModel import torch_npu
from ..constants import KernelType except ImportError:
from ..registry import MetaMoEKernel pass
from ......accelerator.helper import DeviceType
from ......utils.packages import is_transformers_version_greater_than
from ......utils.types import HFModel
from ...base import BaseKernel
from ...registry import register_kernel
class GmmFunction(torch.autograd.Function): class GmmFunction(torch.autograd.Function):
r"""Custom autograd function for NPU Grouped Matrix Multiplication (GMM)."""
@staticmethod @staticmethod
def forward(ctx, x, weight, group_list): def forward(ctx, x, weight, group_list):
r"""Performs the forward pass of Grouped Matrix Multiplication.
Args:
ctx: Context object to save tensors for backward pass.
x (Tensor): Input tensor.
weight (Tensor): Weight tensor.
group_list (list): List of group sizes.
Returns:
Tensor: The result of the grouped matrix multiplication.
"""
ctx.save_for_backward(x, weight) ctx.save_for_backward(x, weight)
ctx.group_list = group_list ctx.group_list = group_list
@@ -38,6 +65,15 @@ class GmmFunction(torch.autograd.Function):
@staticmethod @staticmethod
def backward(ctx, grad_output): def backward(ctx, grad_output):
r"""Performs the backward pass of Grouped Matrix Multiplication.
Args:
ctx: Context object containing saved tensors.
grad_output (Tensor): Gradient with respect to the output.
Returns:
tuple: Gradients with respect to input, weight, and None for group_list.
"""
input_tensor, weight = ctx.saved_tensors input_tensor, weight = ctx.saved_tensors
group_list = ctx.group_list group_list = ctx.group_list
@@ -58,8 +94,20 @@ class GmmFunction(torch.autograd.Function):
class HybridGmmFunction(torch.autograd.Function): class HybridGmmFunction(torch.autograd.Function):
r"""Custom autograd function for Hybrid Grouped Matrix Multiplication on NPU."""
@staticmethod @staticmethod
def forward(ctx, num_experts, *args): def forward(ctx, num_experts, *args):
r"""Performs the forward pass of Hybrid GMM.
Args:
ctx: Context object to save tensors.
num_experts (int): Number of experts.
*args: Variable length argument list containing inputs and weights.
Returns:
tuple: The outputs of the grouped matrix multiplication.
"""
x_list = list(args[:num_experts]) x_list = list(args[:num_experts])
weight_list = list(args[num_experts:]) weight_list = list(args[num_experts:])
@@ -76,6 +124,15 @@ class HybridGmmFunction(torch.autograd.Function):
@staticmethod @staticmethod
def backward(ctx, *grad_outputs): def backward(ctx, *grad_outputs):
r"""Performs the backward pass of Hybrid GMM.
Args:
ctx: Context object containing saved tensors.
*grad_outputs: Gradients with respect to the outputs.
Returns:
tuple: Gradients with respect to inputs and weights.
"""
saved_tensors = ctx.saved_tensors saved_tensors = ctx.saved_tensors
num_experts = ctx.num_experts num_experts = ctx.num_experts
split_sizes = ctx.split_sizes split_sizes = ctx.split_sizes
@@ -119,10 +176,23 @@ class HybridGmmFunction(torch.autograd.Function):
class NpuMoeFused: class NpuMoeFused:
r"""Container for NPU fused MoE forward functions."""
@staticmethod @staticmethod
def npu_moe_experts_forward( def npu_moe_experts_forward(
self, hidden_states: torch.Tensor, routing_weights: torch.Tensor, router_indices: torch.Tensor self, hidden_states: torch.Tensor, routing_weights: torch.Tensor, router_indices: torch.Tensor
) -> torch.Tensor: ) -> torch.Tensor:
r"""Forward pass for MoE experts using NPU fused operations.
Args:
self: The MoE layer instance.
hidden_states (Tensor): Input hidden states.
routing_weights (Tensor): Routing weights.
router_indices (Tensor): Router indices.
Returns:
Tensor: Output tensor after expert computation.
"""
batch_size = hidden_states.shape[0] batch_size = hidden_states.shape[0]
hidden_states = hidden_states.reshape(-1, self.hidden_size) hidden_states = hidden_states.reshape(-1, self.hidden_size)
permuted_hidden_states, row_ids_map = torch_npu.npu_moe_token_permute( permuted_hidden_states, row_ids_map = torch_npu.npu_moe_token_permute(
@@ -138,6 +208,15 @@ class NpuMoeFused:
@staticmethod @staticmethod
def npu_moe_sparse_block_forward(self, hidden_states: torch.Tensor) -> torch.Tensor: def npu_moe_sparse_block_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
r"""Forward pass for sparse MoE block using NPU optimization.
Args:
self: The MoE sparse block instance.
hidden_states (Tensor): Input hidden states.
Returns:
Tensor: The routed output.
"""
batch_size = hidden_states.shape[0] batch_size = hidden_states.shape[0]
hidden_states = hidden_states.reshape(-1, self.hidden_size) hidden_states = hidden_states.reshape(-1, self.hidden_size)
router_logits = self.gate(hidden_states) router_logits = self.gate(hidden_states)
@@ -151,8 +230,19 @@ class NpuMoeFused:
class Qwen3NpuMoeFused: class Qwen3NpuMoeFused:
r"""Container for Qwen3 NPU fused MoE forward functions."""
@staticmethod @staticmethod
def qwen3moe_sparse_moe_block_forward(self, hidden_states: torch.Tensor): def qwen3moe_sparse_moe_block_forward(self, hidden_states: torch.Tensor):
r"""Forward pass for Qwen3 sparse MoE block using NPU fused operations.
Args:
self: The Qwen3 MoE block instance.
hidden_states (Tensor): Input hidden states.
Returns:
tuple: A tuple containing the next states and router logits.
"""
batch_size, sequence_length, hidden_dim = hidden_states.shape batch_size, sequence_length, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim) hidden_states = hidden_states.view(-1, hidden_dim)
@@ -206,14 +296,33 @@ if not is_transformers_version_greater_than("5.0.0"):
} }
class NpuMoEFusedMoEKernel(MetaMoEKernel): @register_kernel
type = KernelType.MOE class NpuFusedMoEKernel(BaseKernel):
device = DeviceType.NPU r"""NPU Fused MoE Kernel implementation."""
_kernel_id = "npu_fused_moe"
_device = DeviceType.NPU
@classmethod @classmethod
def apply(cls, model, **kwargs) -> HFModel: def apply(cls, **kwargs) -> HFModel:
if not is_torch_npu_available(): r"""Applies the NPU fused MoE kernel to the model.
return model
Args:
**kwargs: Keyword arguments containing the model.
Returns:
HFModel: The model with patched MoE forward functions.
Raises:
ValueError: If the model is not provided.
RuntimeError: If dependencies are not met.
"""
model = kwargs.get("model", None)
if model is None:
raise ValueError(f"HFModel instance is required for {cls.__name__}.")
if not cls.check_deps():
raise RuntimeError("torch_npu is not available but NpuMoEFusedMoEKernel was called.")
archs = getattr(model.config, "architectures", []) archs = getattr(model.config, "architectures", [])
target_moe_mapping = None target_moe_mapping = None

View File

@@ -12,36 +12,71 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""The definition of NPU fused SwiGLU kernels.
Init Phase:
1. Define SwiGLU forward functions.
2. Register NPU fused SwiGLU kernel.
"""
import re import re
import types import types
import torch import torch
from .....accelerator.helper import DeviceType, is_torch_npu_available from ......accelerator.helper import DeviceType
from .....utils.types import HFModel from ......utils.types import HFModel
from ..constants import KernelType from ...base import BaseKernel
from ..registry import MetaSwiGluKernel from ...registry import register_kernel
def _npu_swiglu_forward(self, hidden_state): try:
import torch_npu import torch_npu
except ImportError:
pass
def npu_swiglu_forward(self, hidden_state):
r"""SwiGLU forward pass for NPU.
Args:
self: The MLP layer instance.
hidden_state (Tensor): Input hidden state.
Returns:
Tensor: Output of SwiGLU.
"""
return self.down_proj( return self.down_proj(
torch_npu.npu_swiglu(torch.cat((self.gate_proj(hidden_state), self.up_proj(hidden_state)), dim=-1), dim=-1) torch_npu.npu_swiglu(torch.cat((self.gate_proj(hidden_state), self.up_proj(hidden_state)), dim=-1), dim=-1)
) )
def _npu_swiglu_glm4_forward(self, hidden_states): def _npu_swiglu_glm4_forward(self, hidden_states):
import torch_npu r"""SwiGLU forward pass for GLM4 on NPU.
Args:
self: The GLM4 MLP layer instance.
hidden_states (Tensor): Input hidden states.
Returns:
Tensor: Output of SwiGLU.
"""
up_states = self.gate_up_proj(hidden_states) up_states = self.gate_up_proj(hidden_states)
gate, up_states = up_states.chunk(2, dim=-1) gate, up_states = up_states.chunk(2, dim=-1)
return self.down_proj(torch_npu.npu_swiglu(torch.cat((gate, up_states), dim=-1), dim=-1)) return self.down_proj(torch_npu.npu_swiglu(torch.cat((gate, up_states), dim=-1), dim=-1))
def _npu_swiglu_gemma3ntext_forward(self, hidden_states): def _npu_swiglu_gemma3ntext_forward(self, hidden_states):
import torch_npu r"""SwiGLU forward pass for Gemma3nText on NPU.
Args:
self: The Gemma3nText MLP layer instance.
hidden_states (Tensor): Input hidden states.
Returns:
Tensor: Output of SwiGLU.
"""
gate_proj = self.gate_proj(hidden_states) gate_proj = self.gate_proj(hidden_states)
if self.activation_sparsity > 0.0: if self.activation_sparsity > 0.0:
gate_proj = self._gaussian_topk(gate_proj) gate_proj = self._gaussian_topk(gate_proj)
@@ -51,12 +86,11 @@ def _npu_swiglu_gemma3ntext_forward(self, hidden_states):
return down_proj return down_proj
class NpuSwiGluKernel(MetaSwiGluKernel): @register_kernel
type = KernelType.SWIGLU class NpuSwiGluKernel(BaseKernel):
device = DeviceType.NPU r"""NPU Kernel for fused SwiGLU activation."""
kernel = _npu_swiglu_forward
# Don't apply the kernel to the following modules # just support apply to the following module layers
expect_modules = frozenset( expect_modules = frozenset(
{ {
"Qwen3VLMoeTextMLP", "Qwen3VLMoeTextMLP",
@@ -87,10 +121,29 @@ class NpuSwiGluKernel(MetaSwiGluKernel):
} }
) )
_kernel_id = "npu_fused_swiglu"
_device = DeviceType.NPU
@classmethod @classmethod
def apply(cls, model, **kwargs) -> "HFModel": def apply(cls, **kwargs) -> "HFModel":
if not is_torch_npu_available(): r"""Applies the NPU fused SwiGLU kernel to the model.
return model
Args:
**kwargs: Keyword arguments containing the model.
Returns:
HFModel: The model with patched SwiGLU forward functions.
Raises:
ValueError: If the model is not provided.
RuntimeError: If dependencies are not met.
"""
model = kwargs.get("model", None)
if model is None:
raise ValueError(f"HFModel instance is required for {cls.__name__}.")
if not cls.check_deps():
raise RuntimeError("torch_npu is not available but NpuSwiGluKernel was called.")
# Mapping of specific mlp modules to their corresponding kernel implementations # Mapping of specific mlp modules to their corresponding kernel implementations
kernel_mapping = { kernel_mapping = {
@@ -109,7 +162,7 @@ class NpuSwiGluKernel(MetaSwiGluKernel):
): ):
# Bind function as an instance method to preserve `self` semantics # Bind function as an instance method to preserve `self` semantics
# and replace the original forward # and replace the original forward
kernel_func = kernel_mapping.get(module.__class__.__name__, _npu_swiglu_forward) kernel_func = kernel_mapping.get(module.__class__.__name__, npu_swiglu_forward)
module.forward = types.MethodType(kernel_func, module) module.forward = types.MethodType(kernel_func, module)
return model return model

View File

@@ -11,40 +11,49 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""The definition of NPU fused RMSNorm kernels.
Init Phase:
1. Define RMSNorm forward function.
2. Register NPU fused RMSNorm kernel.
"""
import re import re
import types import types
from .....accelerator.helper import DeviceType, is_torch_npu_available from ......accelerator.helper import DeviceType
from .....utils.types import HFModel from ......utils.types import HFModel
from ..constants import KernelType from ...base import BaseKernel
from ..registry import MetaRMSNormKernel from ...registry import register_kernel
def _npu_rms_forward(self, hidden_states): def npu_rms_norm_forward(self, hidden_states):
"""NPU forward implementation for RMSNorm. r"""NPU forward implementation for RMSNorm.
Args: Args:
self: RMSNorm module instance with `weight` and `variance_epsilon`. self: RMSNorm module instance with `weight` and `variance_epsilon`.
hidden_states: Input hidden states tensor, same shape as the baseline. hidden_states (Tensor): Input hidden states tensor, same shape as the baseline.
Returns: Returns:
Normalized tensor consistent with the baseline RMSNorm behavior. Tensor: Normalized tensor consistent with the baseline RMSNorm behavior.
""" """
import torch_npu import torch_npu
return torch_npu.npu_rms_norm(hidden_states, self.weight, epsilon=self.variance_epsilon)[0] return torch_npu.npu_rms_norm(hidden_states, self.weight, epsilon=self.variance_epsilon)[0]
class NpuRMSNormKernel(MetaRMSNormKernel): @register_kernel
"""NPU kernel wrapper for RMSNorm that applies the replacement within a model.""" class NpuRMSNormKernel(BaseKernel):
r"""NPU kernel wrapper for RMSNorm that applies the replacement within a model."""
type = KernelType.RMSNORM _kernel_id = "npu_fused_rmsnorm"
device = DeviceType.NPU _device = DeviceType.NPU
kernel = _npu_rms_forward
@classmethod @classmethod
def apply(cls, model, **kwargs) -> HFModel: def apply(cls, **kwargs) -> "HFModel":
"""Iterate the model and apply NPU-optimized forward to matched RMSNorm modules. r"""Iterate the model and apply NPU-optimized forward to matched RMSNorm modules.
Key points: Key points:
- Match modules whose class name contains "RMSNorm" (case-insensitive). - Match modules whose class name contains "RMSNorm" (case-insensitive).
@@ -52,10 +61,23 @@ class NpuRMSNormKernel(MetaRMSNormKernel):
replace the original `forward`. replace the original `forward`.
- Do not modify weights, hyperparameters, or module structure to ensure - Do not modify weights, hyperparameters, or module structure to ensure
numerical behavior and interface consistency. numerical behavior and interface consistency.
"""
if not is_torch_npu_available():
return model
Args:
**kwargs: Keyword arguments containing the model.
Returns:
HFModel: The model with NPU fused RMSNorm.
Raises:
RuntimeError: If torch_npu is not available.
ValueError: If the model is not provided.
"""
model = kwargs.get("model")
if model is None:
raise ValueError(f"HFModel instance is required for {cls.__name__}.")
if not cls.check_deps():
raise RuntimeError(f"torch_npu is not available but {cls.__name__} was called.")
rms_norm_pattern = re.compile("RMSNorm", re.IGNORECASE) rms_norm_pattern = re.compile("RMSNorm", re.IGNORECASE)
for name, module in model.named_modules(): for name, module in model.named_modules():
@@ -63,6 +85,6 @@ class NpuRMSNormKernel(MetaRMSNormKernel):
if re.search(rms_norm_pattern, module.__class__.__name__): if re.search(rms_norm_pattern, module.__class__.__name__):
# Bind function as an instance method to preserve `self` semantics # Bind function as an instance method to preserve `self` semantics
# and replace the original forward # and replace the original forward
module.forward = types.MethodType(cls.kernel, module) module.forward = types.MethodType(npu_rms_norm_forward, module)
return model return model

View File

@@ -0,0 +1,146 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The definition of NPU fused RoPE kernels.
Init Phase:
1. Define RoPE forward functions.
2. Register NPU fused RoPE kernel.
"""
import sys
import torch
from ......accelerator.helper import DeviceType
from ......utils.logging import get_logger
from ......utils.types import HFModel
from ...base import BaseKernel
from ...registry import register_kernel
logger = get_logger(__name__)
try:
import torch_npu
except ImportError:
pass
def _apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
r"""Applies Rotary Position Embedding to the query and key tensors using NPU optimization.
Args:
q (Tensor): Query tensor.
k (Tensor): Key tensor.
cos (Tensor): Cosine part of embedding.
sin (Tensor): Sine part of embedding.
position_ids (Tensor, optional): Position IDs. Default: ``None``.
unsqueeze_dim (int): Dimension to unsqueeze cos and sin. Default: 1.
Returns:
tuple: (q_embed, k_embed) The embedded query and key tensors.
"""
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
q_embed = torch_npu.npu_rotary_mul(q, cos, sin)
k_embed = torch_npu.npu_rotary_mul(k, cos, sin)
return q_embed, k_embed
def _apply_multimodal_rotary_pos_emb_qwen25_vl(q, k, cos, sin, mrope_section, unsqueeze_dim=1):
r"""Applies Rotary Position Embedding with multimodal sections (Qwen2-VL) on NPU.
Args:
q (Tensor): Query tensor.
k (Tensor): Key tensor.
cos (Tensor): Cosine part of embedding.
sin (Tensor): Sine part of embedding.
mrope_section (Tensor): Multimodal RoPE section.
unsqueeze_dim (int): Dimension to unsqueeze cos and sin. Default: 1.
Returns:
tuple: (q_embed, k_embed) The embedded query and key tensors.
"""
mrope_section = mrope_section * 2
cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(
unsqueeze_dim
)
sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(
unsqueeze_dim
)
q_embed = torch_npu.npu_rotary_mul(q, cos, sin)
k_embed = torch_npu.npu_rotary_mul(k, cos, sin)
return q_embed, k_embed
@register_kernel
class NpuRoPEKernel(BaseKernel):
r"""NPU Kernel for Rotary Position Embedding."""
_kernel_id = "npu_fused_rope"
_device = DeviceType.NPU
@classmethod
def apply(cls, **kwargs) -> "HFModel":
r"""Apply RoPE acceleration by monkey-patching `apply_rotary_pos_emb`.
This function iterates through the model's modules to find attention layers,
identifies the module where they are defined, and replaces the original
`apply_rotary_pos_emb` function in that module's namespace with the
NPU-accelerated version from this file.
Args:
**kwargs: Keyword arguments containing the model.
Returns:
HFModel: The model with patched RoPE functions.
Raises:
RuntimeError: If dependencies are not met.
ValueError: If the model is not provided.
"""
if not cls.check_deps():
raise RuntimeError(f"torch_npu is not available but {cls.__name__} was called.")
model = kwargs.get("model", None)
if model is None:
raise ValueError(f"HFModel instance is required for {cls.__name__}.")
_modules = set()
for module in model.modules():
if "Attention" in module.__class__.__name__:
module_name = module.__class__.__module__
if module_name in _modules:
continue
try:
target_module = sys.modules[module_name]
if hasattr(target_module, "apply_rotary_pos_emb"):
if getattr(target_module, "apply_rotary_pos_emb") is not _apply_rotary_pos_emb:
setattr(target_module, "apply_rotary_pos_emb", _apply_rotary_pos_emb)
_modules.add(module_name)
if hasattr(target_module, "apply_multimodal_rotary_pos_emb"):
if (
getattr(target_module, "apply_multimodal_rotary_pos_emb")
is not _apply_multimodal_rotary_pos_emb_qwen25_vl
):
setattr(
target_module,
"apply_multimodal_rotary_pos_emb",
_apply_multimodal_rotary_pos_emb_qwen25_vl,
)
_modules.add(module_name)
except Exception as e:
logger.warning_rank0_once(f"Failed to apply RoPE kernel to module {module_name}: {e}")
return model

View File

@@ -12,247 +12,86 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from abc import ABC, ABCMeta, abstractmethod """The definition of kernel registry.
from collections.abc import Callable
from typing import Any, Optional
from ....accelerator.helper import DeviceType, get_current_accelerator Init Phase:
from ....utils.types import HFModel 1. Define kernel registry.
from .constants import KernelType 2. Register kernels.
"""
from typing import Optional
from ....accelerator.helper import get_current_accelerator
from .base import BaseKernel
class KernelRegistry: __all__ = ["Registry", "register_kernel"]
_instance: Optional["KernelRegistry"] = None
_initialized: bool = False
def __new__(cls, *args: Any, **kwargs: Any) -> "KernelRegistry":
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def __init__(self) -> None:
if self._initialized:
return
self._registry: dict[KernelType, dict[DeviceType, Callable[..., Any]]] = {}
self._initialized = True
def register(
self, kernel_type: KernelType, device_type: DeviceType, kernel_impl: Callable[..., Any] | None
) -> None:
"""Register a kernel implementation.
Args:
kernel_type: the type of the kernel (e.g., KernelType.FLASH_ATTENTION).
device_type: the device type the kernel is adapted to (e.g., DeviceType.CUDA).
kernel_impl: the actual kernel function or class.
"""
if kernel_type not in self._registry:
self._registry[kernel_type] = {}
if device_type in self._registry[kernel_type]:
print(f"Warning: Overwriting kernel for {kernel_type.name} on {device_type.name}.")
self._registry[kernel_type][device_type] = kernel_impl
print(f"Registered kernel {kernel_type.name} for device {device_type.name}.")
def get_kernel(self, kernel_type: KernelType, device_type: DeviceType) -> Callable[..., Any] | None:
return self._registry.get(kernel_type, {}).get(device_type)
KERNEL_REGISTRY = KernelRegistry() class Registry:
r"""Registry for managing kernel implementations.
Storage structure: ``{ "kernel_id": Class }``
class AutoRegisterKernelMeta(ABCMeta):
"""Metaclass that automatically registers kernel classes upon creation.
This metaclass checks if a newly created class has both `type` and `device`
attributes defined. If so, it automatically registers the kernel in the
global KERNEL_REGISTRY, eliminating the need for manual registration.
To disable auto-registration for a specific class, set `auto_register = False`.
""" """
def __new__(mcs, name, bases, namespace, **kwargs): _kernels: dict[str, type[BaseKernel]] = {}
cls = super().__new__(mcs, name, bases, namespace, **kwargs)
# Check if auto-registration is disabled
auto_register = namespace.get("auto_register", True)
# Only auto-register if the class has both type and device attributes defined
# and they are not None (skip base classes like MetaKernel itself)
# and auto_register is True
kernel_type = namespace.get("type")
device_type = namespace.get("device")
if auto_register and kernel_type is not None and device_type is not None:
# Auto-register this kernel
KERNEL_REGISTRY.register(kernel_type, device_type, cls)
return cls
class MetaKernel(ABC, metaclass=AutoRegisterKernelMeta):
"""Base class for all kernel implementations.
Subclasses are automatically registered when they define both `type` and `device`
attributes. To disable auto-registration, set `auto_register = False`.
Attributes:
type: The kernel type (e.g., KernelType.RMSNORM). Must be set in subclasses.
device: The device type (e.g., DeviceType.NPU). Must be set in subclasses.
kernel: The actual kernel function or implementation.
auto_register: Set to False to disable automatic registration (default: True).
"""
type: KernelType | None = None
device: DeviceType | None = None
kernel: Callable | None = None
@classmethod @classmethod
@abstractmethod def register(cls, kernel_cls: type[BaseKernel]):
def apply(cls, model: HFModel, **kwargs) -> HFModel: r"""Decorator to register a kernel class.
"""Apply the kernel to the model.
This method should check if the kernel can be applied (e.g., dependencies The class must inherit from :class:`BaseKernel` and specify ``_kernel_id`` and ``_device`` attributes.
are installed, target modules exist) and perform the kernel replacement.
Args: Args:
model: The HuggingFace model to optimize. kernel_cls (type[BaseKernel]): The kernel class to register.
**kwargs: Additional arguments for kernel application.
Returns: Returns:
The optimized model (may be the same object with modifications). type[BaseKernel]: The registered kernel class.
Raises:
TypeError: If the class does not inherit from :class:`BaseKernel`.
ValueError: If the kernel ID is missing or already registered.
""" """
raise NotImplementedError if not issubclass(kernel_cls, BaseKernel):
raise TypeError(f"Class {kernel_cls} must inherit from BaseKernel")
kernel_id = kernel_cls.get_kernel_id()
device = kernel_cls.get_device()
# The device type of the current accelerator does not match the device type required by the kernel, skip registration
if device != get_current_accelerator().type:
return
if not kernel_id:
raise ValueError(f"Kernel ID (_kernel_id) is needed for {kernel_cls} to register")
if kernel_id in cls._kernels:
raise ValueError(f"{kernel_id} already registered! The registered kernel is {cls._kernels[kernel_id]}")
cls._kernels[kernel_id] = kernel_cls
return kernel_cls
class MetaFlashAttentionKernel(MetaKernel):
@classmethod @classmethod
def apply(cls, model: HFModel, **kwargs) -> HFModel: def get(cls, kernel_id: str) -> Optional[type[BaseKernel]]:
raise NotImplementedError r"""Retrieves a registered kernel implementation by its ID.
Args:
kernel_id (str): The ID of the kernel to retrieve.
Returns:
Optional[type[BaseKernel]]: The kernel class if found, else ``None``.
"""
return cls._kernels.get(kernel_id)
class MetaRMSNormKernel(MetaKernel):
@classmethod @classmethod
def apply(cls, model: HFModel, **kwargs) -> HFModel: def get_registered_kernels(cls) -> dict[str, type[BaseKernel]]:
raise NotImplementedError r"""Returns a dictionary of all registered kernels.
Returns:
dict[str, type[BaseKernel]]: Dictionary mapping kernel IDs to kernel classes.
"""
return cls._kernels
class MetaSwiGluKernel(MetaKernel): # export decorator alias
@classmethod register_kernel = Registry.register
def apply(cls, model: HFModel, **kwargs) -> HFModel:
raise NotImplementedError
class MetaRoPEKernel(MetaKernel):
@classmethod
def apply(cls, model: HFModel, **kwargs) -> HFModel:
raise NotImplementedError
class MetaMoEKernel(MetaKernel):
@classmethod
def apply(cls, model: HFModel, **kwargs) -> HFModel:
raise NotImplementedError
def _ensure_kernels_loaded() -> None:
"""Ensure all kernel implementations are imported and registered.
This function dynamically imports all kernel implementation modules to trigger
their auto-registration. Python's module system ensures each module is only
executed once (cached in sys.modules), so repeated calls are safe and fast.
"""
# List of kernel module paths to import
kernel_modules = [
"rms_norm.npu_rms_norm",
"rope.npu_rope",
"mlp.npu_swiglu",
"mlp.npu_fused_moe",
# Add new kernel modules here as they are created
]
# Import each module to trigger kernel registration
# Python's import system caches modules, so this is fast on subsequent calls
for module_name in kernel_modules:
try:
__import__(f"{__package__}.{module_name}", fromlist=["*"])
except ImportError:
# Silently ignore import errors (e.g., missing dependencies like torch_npu)
pass
def discover_kernels(model: HFModel = None) -> list[type[MetaKernel]]:
"""Discover and return all kernel classes registered for the current device.
This function inspects the runtime environment (device type) and returns
all MetaKernel classes registered for that device. Each kernel's `apply()`
method is responsible for checking if it can actually be applied (e.g.,
required dependencies are installed, target modules exist in the model).
The function automatically discovers all kernels registered in KERNEL_REGISTRY
without requiring manual enumeration. On first call, it dynamically imports
all kernel implementation modules to trigger their auto-registration.
Args:
model: The HuggingFace model to apply kernels to.
TODO: implement the kernel route detection logic by model structure.
Returns:
A list of MetaKernel classes available for the current device.
"""
# Ensure all kernel modules are imported to trigger registration
_ensure_kernels_loaded()
discovered_kernels: list[type[MetaKernel]] = []
# Detect current device type
accelerator = get_current_accelerator()
try:
device_type = DeviceType(accelerator.type)
except ValueError:
# Unknown device type, return empty list
return discovered_kernels
# Skip CPU as it typically doesn't have optimized kernels
if device_type == DeviceType.CPU:
return discovered_kernels
# Iterate through registry and collect all kernels for current device
for devices in KERNEL_REGISTRY._registry.values():
kernel_cls = devices.get(device_type)
if kernel_cls is not None:
discovered_kernels.append(kernel_cls)
return discovered_kernels
def apply_kernel(model: HFModel, kernel: type[MetaKernel] | Any, /, **kwargs) -> "HFModel":
"""Call the MetaKernel's `apply` to perform the replacement.
Corresponding replacement logic is maintained inside each kernel; the only
requirement is that `apply` returns the replaced model.
Example:
from transformers import AutoModelForCausalLM
from .rms_norm.npu_rms_norm import NpuRMSNormKernel
model = AutoModelForCausalLM.from_pretrained("qwen/qwen2.5-0.5B")
model = apply_kernel(model, NpuRMSNormKernel)
"""
if not issubclass(kernel, MetaKernel):
raise ValueError(f"{kernel} must be a MetaKernel instance.")
if kernel.device != get_current_accelerator().type:
raise ValueError(f"{kernel} must be applied to {kernel.device} device, got {get_current_accelerator().type}.")
return kernel.apply(model, **kwargs)
def apply_available_kernels(model: HFModel, **kwargs) -> "HFModel":
"""Apply all available kernels to the model."""
for kernel in discover_kernels(model):
model = apply_kernel(model, kernel, **kwargs)
return model

View File

@@ -1,122 +0,0 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import torch
from .....accelerator.helper import DeviceType, is_torch_npu_available
from .....utils.types import HFModel
from ..constants import KernelType
from ..registry import MetaRoPEKernel
def _apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors."""
import torch_npu
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
q_embed = torch_npu.npu_rotary_mul(q, cos, sin)
k_embed = torch_npu.npu_rotary_mul(k, cos, sin)
return q_embed, k_embed
def _apply_multimodal_rotary_pos_emb_qwen25_vl(q, k, cos, sin, mrope_section, unsqueeze_dim=1):
"""Applies Rotary Position Embedding with multimodal sections (Qwen2-VL)."""
import torch_npu
mrope_section = mrope_section * 2
cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(
unsqueeze_dim
)
sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(
unsqueeze_dim
)
q_embed = torch_npu.npu_rotary_mul(q, cos, sin)
k_embed = torch_npu.npu_rotary_mul(k, cos, sin)
return q_embed, k_embed
class NpuRoPEKernel(MetaRoPEKernel):
type = KernelType.ROPE
device = DeviceType.NPU
kernel = _apply_rotary_pos_emb
@classmethod
def apply(cls, model, **kwargs) -> "HFModel":
"""Apply RoPE acceleration by monkey-patching `apply_rotary_pos_emb`.
This function iterates through the model's modules to find attention layers,
identifies the module where they are defined, and replaces the original
`apply_rotary_pos_emb` function in that module's namespace with the
NPU-accelerated version from this file.
"""
if not is_torch_npu_available():
return model
_modules = set()
for module in model.modules():
if "Attention" in module.__class__.__name__:
module_name = module.__class__.__module__
if module_name in _modules:
continue
try:
target_module = sys.modules[module_name]
if hasattr(target_module, "apply_rotary_pos_emb"):
if getattr(target_module, "apply_rotary_pos_emb") is not cls.kernel:
setattr(target_module, "apply_rotary_pos_emb", cls.kernel)
_modules.add(module_name)
except Exception:
pass
return model
class NpuQwen2VLRoPEKernel(MetaRoPEKernel):
    """Qwen2-VL specific RoPE kernel - not auto-registered.

    This kernel is for specific models (Qwen2-VL) and should be manually
    applied when needed rather than auto-discovered.
    """

    type = KernelType.ROPE
    device = DeviceType.NPU
    kernel = _apply_multimodal_rotary_pos_emb_qwen25_vl
    auto_register = False  # Disable auto-registration for model-specific kernel

    @classmethod
    def apply(cls, model, **kwargs) -> "HFModel":
        """Apply multimodal RoPE acceleration by monkey-patching `apply_multimodal_rotary_pos_emb`.

        Iterates through the model's modules to find attention layers, resolves
        the module where each is defined, and replaces the original
        `apply_multimodal_rotary_pos_emb` in that module's namespace with the
        NPU-accelerated version from this file.

        Args:
            model: Hugging Face model to patch in place.
            **kwargs: Unused; accepted for interface compatibility.

        Returns:
            The (possibly patched) model.
        """
        # Fix: guard on NPU availability like NpuRoPEKernel.apply does. The
        # installed kernel imports torch_npu at call time, so patching on a
        # host without an NPU would crash the first forward pass.
        if not is_torch_npu_available():
            return model

        _modules = set()
        for module in model.modules():
            if "Attention" in module.__class__.__name__:
                module_name = module.__class__.__module__
                if module_name in _modules:
                    continue
                try:
                    target_module = sys.modules[module_name]
                    if hasattr(target_module, "apply_multimodal_rotary_pos_emb"):
                        if getattr(target_module, "apply_multimodal_rotary_pos_emb") is not cls.kernel:
                            setattr(target_module, "apply_multimodal_rotary_pos_emb", cls.kernel)
                        _modules.add(module_name)
                except Exception:
                    # Best effort: a module we cannot resolve is skipped silently.
                    pass
        return model

View File

@@ -18,8 +18,11 @@ Contains shared fixtures, pytest configuration, and custom markers.
""" """
import os import os
from typing import Optional
import pytest import pytest
import torch
import torch.distributed as dist
from pytest import Config, FixtureRequest, Item, MonkeyPatch from pytest import Config, FixtureRequest, Item, MonkeyPatch
from llamafactory.extras.misc import get_current_device, get_device_count, is_env_enabled from llamafactory.extras.misc import get_current_device, get_device_count, is_env_enabled
@@ -70,7 +73,7 @@ def _handle_slow_tests(items: list[Item]):
item.add_marker(skip_slow) item.add_marker(skip_slow)
def _get_visible_devices_env() -> str | None: def _get_visible_devices_env() -> Optional[str]:
"""Return device visibility env var name.""" """Return device visibility env var name."""
if CURRENT_DEVICE == "cuda": if CURRENT_DEVICE == "cuda":
return "CUDA_VISIBLE_DEVICES" return "CUDA_VISIBLE_DEVICES"
@@ -118,6 +121,14 @@ def pytest_collection_modifyitems(config: Config, items: list[Item]):
_handle_device_visibility(items) _handle_device_visibility(items)
@pytest.fixture(autouse=True)
def _cleanup_distributed_state():
"""Cleanup distributed state after each test."""
yield
if dist.is_initialized():
dist.destroy_process_group()
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def _manage_distributed_env(request: FixtureRequest, monkeypatch: MonkeyPatch) -> None: def _manage_distributed_env(request: FixtureRequest, monkeypatch: MonkeyPatch) -> None:
"""Set environment variables for distributed tests if specific devices are requested.""" """Set environment variables for distributed tests if specific devices are requested."""
@@ -145,6 +156,10 @@ def _manage_distributed_env(request: FixtureRequest, monkeypatch: MonkeyPatch) -
monkeypatch.setenv(env_key, visible_devices[0] if visible_devices else "0") monkeypatch.setenv(env_key, visible_devices[0] if visible_devices else "0")
else: else:
monkeypatch.setenv(env_key, "0") monkeypatch.setenv(env_key, "0")
if CURRENT_DEVICE == "cuda":
monkeypatch.setattr(torch.cuda, "device_count", lambda: 1)
elif CURRENT_DEVICE == "npu":
monkeypatch.setattr(torch.npu, "device_count", lambda: 1)
@pytest.fixture @pytest.fixture

View File

@@ -0,0 +1,71 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pathlib
import sys
from unittest.mock import patch
from llamafactory.v1.config.arg_parser import get_args
def test_get_args_from_yaml(tmp_path: pathlib.Path):
    """End-to-end check: a YAML config passed as argv[1] round-trips through `get_args`."""
    yaml_text = """
### model
model: "llamafactory/tiny-random-qwen2.5"
trust_remote_code: true
use_fast_processor: true
model_class: "llm"
kernel_config:
  name: "auto"
  include_kernels: "auto" # choice: null/true/false/auto/kernel_id1,kernel_id2,kernel_id3, default is null
peft_config:
  name: "lora"
  lora_rank: 0.8
quant_config: null

### data
dataset: "llamafactory/tiny-supervised-dataset"
cutoff_len: 2048

### training
output_dir: "outputs/test_run"
micro_batch_size: 1
global_batch_size: 1
learning_rate: 1.0e-4
bf16: false
dist_config: null

### sample
sample_backend: "hf"
max_new_tokens: 128
"""
    cfg_path = tmp_path / "config.yaml"
    cfg_path.write_text(yaml_text, encoding="utf-8")

    # Simulate `python test_args_parser.py config.yaml`.
    argv = ["test_args_parser.py", str(cfg_path)]
    with patch.object(sys, "argv", argv):
        data_args, model_args, training_args, sample_args = get_args()

    # Scalar training fields parsed from YAML.
    expected_training = {
        "output_dir": "outputs/test_run",
        "micro_batch_size": 1,
        "global_batch_size": 1,
        "learning_rate": 1.0e-4,
    }
    for field, expected in expected_training.items():
        assert getattr(training_args, field) == expected
    assert training_args.bf16 is False
    assert training_args.dist_config is None

    # Model and nested plugin-config fields.
    assert model_args.model == "llamafactory/tiny-random-qwen2.5"
    assert model_args.kernel_config.name == "auto"
    assert model_args.kernel_config.get("include_kernels") == "auto"
    assert model_args.peft_config.name == "lora"
    assert model_args.peft_config.get("lora_rank") == 0.8

View File

@@ -18,8 +18,10 @@ Contains shared fixtures, pytest configuration, and custom markers.
""" """
import os import os
import sys
import pytest import pytest
import torch
from pytest import Config, FixtureRequest, Item, MonkeyPatch from pytest import Config, FixtureRequest, Item, MonkeyPatch
from llamafactory.v1.accelerator.helper import get_current_accelerator, get_device_count from llamafactory.v1.accelerator.helper import get_current_accelerator, get_device_count
@@ -139,9 +141,21 @@ def _manage_distributed_env(request: FixtureRequest, monkeypatch: MonkeyPatch) -
devices_str = ",".join(str(i) for i in range(required)) devices_str = ",".join(str(i) for i in range(required))
monkeypatch.setenv(env_key, devices_str) monkeypatch.setenv(env_key, devices_str)
# add project root dir to path for mp run
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
if project_root not in sys.path:
sys.path.insert(0, project_root)
os.environ["PYTHONPATH"] = project_root + os.pathsep + os.environ.get("PYTHONPATH", "")
else: # non-distributed test else: # non-distributed test
if old_value: if old_value:
visible_devices = [v for v in old_value.split(",") if v != ""] visible_devices = [v for v in old_value.split(",") if v != ""]
monkeypatch.setenv(env_key, visible_devices[0] if visible_devices else "0") monkeypatch.setenv(env_key, visible_devices[0] if visible_devices else "0")
else: else:
monkeypatch.setenv(env_key, "0") monkeypatch.setenv(env_key, "0")
if CURRENT_DEVICE == "cuda":
monkeypatch.setattr(torch.cuda, "device_count", lambda: 1)
elif CURRENT_DEVICE == "npu":
monkeypatch.setattr(torch.npu, "device_count", lambda: 1)

View File

@@ -14,7 +14,7 @@
import torch import torch
from llamafactory.v1.config.model_args import ModelArguments from llamafactory.v1.config.model_args import ModelArguments, PluginConfig
from llamafactory.v1.core.model_loader import ModelLoader from llamafactory.v1.core.model_loader import ModelLoader
@@ -29,5 +29,23 @@ def test_tiny_qwen():
assert model_loader.model.dtype == torch.bfloat16 assert model_loader.model.dtype == torch.bfloat16
def test_tiny_qwen_with_kernel_plugin():
from transformers import Qwen2ForCausalLM
from llamafactory.v1.plugins.model_plugins.kernels.ops.rms_norm.npu_rms_norm import npu_rms_norm_forward
model_args = ModelArguments(
model="llamafactory/tiny-random-qwen2.5", kernel_config=PluginConfig(name="auto", include_kernels="auto")
)
model_loader = ModelLoader(model_args)
# test enable apply kernel plugin
if hasattr(torch, "npu"):
assert model_loader.model.model.layers[0].input_layernorm.forward.__code__ == npu_rms_norm_forward.__code__
else:
assert model_loader.model.model.layers[0].input_layernorm.forward.__code__ != npu_rms_norm_forward.__code__
assert isinstance(model_loader.model, Qwen2ForCausalLM)
if __name__ == "__main__": if __name__ == "__main__":
test_tiny_qwen() test_tiny_qwen()
test_tiny_qwen_with_kernel_plugin()

View File

@@ -12,16 +12,13 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import sys
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
import pytest import pytest
from transformers import AutoModelForCausalLM from transformers import AutoModelForCausalLM
from llamafactory.v1.accelerator.helper import get_current_accelerator from llamafactory.v1.accelerator.helper import get_current_accelerator
from llamafactory.v1.plugins.model_plugins.kernels.mlp import npu_swiglu
from llamafactory.v1.plugins.model_plugins.kernels.registry import apply_available_kernels, apply_kernel
from llamafactory.v1.plugins.model_plugins.kernels.rms_norm import npu_rms_norm
from llamafactory.v1.plugins.model_plugins.kernels.rope import npu_rope
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
@@ -29,24 +26,29 @@ def clear_accelerator_cache():
get_current_accelerator.cache_clear() get_current_accelerator.cache_clear()
def reload_kernels():
"""Helper to reload kernel modules to respect mocked accelerator."""
# Unload kernel interface and registry
keys_to_remove = [k for k in sys.modules if k.startswith("llamafactory.v1.plugins.model_plugins.kernels")]
for k in keys_to_remove:
del sys.modules[k]
@patch("torch.accelerator.current_accelerator") @patch("torch.accelerator.current_accelerator")
def test_apply_kernel(mock_get_accelerator: MagicMock): def test_apply_kernel(mock_get_accelerator: MagicMock):
mock_device = MagicMock() mock_device = MagicMock()
setattr(mock_device, "type", "npu") setattr(mock_device, "type", "npu")
mock_get_accelerator.return_value = mock_device mock_get_accelerator.return_value = mock_device
# Force reload of kernels with mocked accelerator
reload_kernels()
from llamafactory.v1.plugins.model_plugins.kernels.interface import apply_default_kernels
model = AutoModelForCausalLM.from_pretrained("llamafactory/tiny-random-qwen2.5") model = AutoModelForCausalLM.from_pretrained("llamafactory/tiny-random-qwen2.5")
original_rmsnorm_forward = model.model.layers[0].input_layernorm.forward original_rmsnorm_forward = model.model.layers[0].input_layernorm.forward
original_swiglu_forward = model.model.layers[0].mlp.forward original_swiglu_forward = model.model.layers[0].mlp.forward
model = apply_default_kernels(model=model, include_kernels="npu_fused_rmsnorm")
apply_kernel(model, npu_rope.NpuRoPEKernel) assert model.model.layers[0].input_layernorm.forward.__func__ is not original_rmsnorm_forward.__func__
assert model.model.layers[0].mlp.forward.__func__ is original_swiglu_forward.__func__
model = apply_kernel(model, npu_rms_norm.NpuRMSNormKernel)
assert model.model.layers[0].input_layernorm is not original_rmsnorm_forward
model = apply_kernel(model, npu_swiglu.NpuSwiGluKernel)
assert model.model.layers[0].mlp.forward is not original_swiglu_forward
@patch("torch.accelerator.current_accelerator") @patch("torch.accelerator.current_accelerator")
@@ -56,12 +58,15 @@ def test_apply_all_kernels(mock_get_accelerator: MagicMock):
setattr(mock_device, "type", "npu") setattr(mock_device, "type", "npu")
mock_get_accelerator.return_value = mock_device mock_get_accelerator.return_value = mock_device
# Force reload of kernels with mocked accelerator
reload_kernels()
from llamafactory.v1.plugins.model_plugins.kernels.interface import apply_default_kernels
model = AutoModelForCausalLM.from_pretrained("llamafactory/tiny-random-qwen2.5") model = AutoModelForCausalLM.from_pretrained("llamafactory/tiny-random-qwen2.5")
original_rmsnorm_forward = model.model.layers[0].input_layernorm.forward original_rmsnorm_forward = model.model.layers[0].input_layernorm.forward
original_swiglu_forward = model.model.layers[0].mlp.forward original_swiglu_forward = model.model.layers[0].mlp.forward
model = apply_available_kernels(model) model = apply_default_kernels(model=model, include_kernels=True)
assert model.model.layers[0].input_layernorm.forward.__func__ is not original_rmsnorm_forward.__func__
assert model.model.layers[0].input_layernorm is not original_rmsnorm_forward assert model.model.layers[0].mlp.forward.__func__ is not original_swiglu_forward.__func__
assert model.model.layers[0].mlp.forward is not original_swiglu_forward