11 Commits

Author SHA1 Message Date
Yaowei Zheng
95ac3f2373 [release] Bye 2025 (#9702) 2025-12-31 22:22:40 +08:00
Username_Full
000526908a [core deps] upgrade TRL to be between 0.18 and 0.24 (#9617)
Co-authored-by: Yaowei Zheng <hiyouga@buaa.edu.cn>
2025-12-31 20:54:27 +08:00
fivehaitao
c8d7e85b3e [fix] Fix prediction metrics in scripts/vllm_infer.py to match Transformers (#9701)
Co-authored-by: xuht6 <xuht6@asiainfo.com>
2025-12-31 18:30:00 +08:00
浮梦
16735b9e35 [v1] Refactor kernel plugin (#9669)
Co-authored-by: frozenleaves <frozen@Mac.local>
2025-12-31 18:26:48 +08:00
Weize Liu
4e1d69579a [data] add DLR-Web dataset for supervised fine-tuning (#9696) 2025-12-30 20:50:38 +08:00
浮梦
1857fbdd6b [ci] add cuda workflow (#9682)
Co-authored-by: frozenleaves <frozen@Mac.local>
Co-authored-by: Yaowei Zheng <hiyouga@buaa.edu.cn>
2025-12-29 20:03:00 +08:00
Kingsley
bb1ba31005 [misc] lint mca code (#9692) 2025-12-29 11:44:38 +08:00
Copilot
e97d0474fb [ci] Fix NPU device condition in docker workflow (#9688)
Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: hiyouga <16256802+hiyouga@users.noreply.github.com>
2025-12-28 20:04:59 +08:00
Yaowei Zheng
3f0c3dc84d [assets] fix installation (#9687)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2025-12-28 19:29:28 +08:00
Hertz
c107cc22d0 [model] support MiniMax-M1&M2 series (#9680)
Co-authored-by: Yaowei Zheng <hiyouga@buaa.edu.cn>
2025-12-28 19:02:05 +08:00
Yaowei Zheng
7ef1fba34a [version] fix gradio (#9685) 2025-12-28 05:00:51 +08:00
92 changed files with 1383 additions and 884 deletions

View File

@@ -72,7 +72,7 @@ jobs:
password: ${{ secrets.DOCKERHUB_TOKEN }}
- name: Login to Quay
if: ${{ github.event_name != 'pull_request' && matrix.device == 'npu'}}
if: ${{ github.event_name != 'pull_request' && startsWith(matrix.device, 'npu') }}
uses: docker/login-action@v3
with:
registry: quay.io

View File

@@ -27,23 +27,23 @@ jobs:
python:
- "3.11"
- "3.12"
# - "3.13" # enable after trl is upgraded
- "3.13"
os:
- "ubuntu-latest"
- "windows-latest"
- "macos-latest"
transformers:
- null
- ""
include: # test backward compatibility
- python: "3.11"
os: "ubuntu-latest"
transformers: "4.49.0"
- python: "3.11"
os: "ubuntu-latest"
transformers: "4.51.0"
- python: "3.11"
os: "ubuntu-latest"
transformers: "4.53.0"
- python: "3.11"
os: "ubuntu-latest"
transformers: "4.55.0"
runs-on: ${{ matrix.os }}

88
.github/workflows/tests_cuda.yml vendored Normal file
View File

@@ -0,0 +1,88 @@
name: tests_cuda
on:
workflow_dispatch:
push:
branches:
- "main"
paths:
- "**/*.py"
- "pyproject.toml"
- "Makefile"
- ".github/workflows/*.yml"
pull_request:
branches:
- "main"
paths:
- "**/*.py"
- "pyproject.toml"
- "Makefile"
- ".github/workflows/*.yml"
jobs:
tests:
strategy:
fail-fast: false
matrix:
python:
- "3.11"
os:
- "linux-x86_64-gpu-2"
runs-on: ${{ matrix.os }}
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ matrix.os }}-${{ matrix.python }}
cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Install uv
uses: astral-sh/setup-uv@v7
with:
python-version: ${{ matrix.python }}
github-token: ${{ github.token }}
enable-cache: false
- name: Check GPU Status
run: nvidia-smi
- name: Install dependencies
run: |
uv venv
uv pip install -e ".[dev]"
- name: Cache HuggingFace models
id: hf-hub-cache
uses: actions/cache@v4
with:
path: ${{ runner.temp }}/huggingface
key: hf-cache-${{ runner.os }}-${{ hashFiles('tests/version.txt') }}
- name: Check quality
run: |
make style && make quality
env:
UV_NO_SYNC: 1
- name: Check license
run: |
make license
env:
UV_NO_SYNC: 1
- name: Check build
run: |
make build
env:
UV_NO_SYNC: 1
- name: Test with pytest
run: |
make test
env:
UV_NO_SYNC: 1
HF_HOME: ${{ runner.temp }}/huggingface
HF_HUB_OFFLINE: "${{ steps.hf-hub-cache.outputs.cache-hit == 'true' && '1' || '0' }}"

View File

@@ -49,8 +49,11 @@ jobs:
uses: actions/checkout@v4
- name: Install uv
run: |
curl -LsSf https://astral.sh/uv/install.sh | sh
uses: astral-sh/setup-uv@v7
with:
python-version: ${{ matrix.python }}
github-token: ${{ github.token }}
enable-cache: false
- name: Install dependencies
run: |

View File

@@ -309,6 +309,7 @@ Read technical notes:
| [MiMo](https://huggingface.co/XiaomiMiMo) | 7B/309B | mimo/mimo_v2 |
| [MiniCPM 1-4.1](https://huggingface.co/openbmb) | 0.5B/1B/2B/4B/8B | cpm/cpm3/cpm4 |
| [MiniCPM-o-2.6/MiniCPM-V-2.6](https://huggingface.co/openbmb) | 8B | minicpm_o/minicpm_v |
| [MiniMax-M1/MiniMax-M2](https://huggingface.co/MiniMaxAI/models) | 229B/456B | minimax1/minimax2 |
| [Ministral 3](https://huggingface.co/mistralai) | 3B/8B/14B | ministral3 |
| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
| [OLMo](https://huggingface.co/allenai) | 1B/7B | - |
@@ -433,6 +434,7 @@ You also can add a custom chat template to [template.py](src/llamafactory/data/t
- [Chinese-DeepSeek-R1-Distill (zh)](https://huggingface.co/datasets/Congliu/Chinese-DeepSeek-R1-Distill-data-110k-SFT)
- [LLaVA mixed (en&zh)](https://huggingface.co/datasets/BUAADreamer/llava-en-zh-300k)
- [Pokemon-gpt4o-captions (en&zh)](https://huggingface.co/datasets/jugg1024/pokemon-gpt4o-captions)
- [DLR-Web (en)](https://huggingface.co/datasets/Attention1115/DLR-Web)
- [Open Assistant (de)](https://huggingface.co/datasets/mayflowergmbh/oasst_de)
- [Dolly 15k (de)](https://huggingface.co/datasets/mayflowergmbh/dolly-15k_de)
- [Alpaca GPT4 (de)](https://huggingface.co/datasets/mayflowergmbh/alpaca-gpt4_de)
@@ -514,7 +516,7 @@ huggingface-cli login
```bash
git clone --depth 1 https://github.com/hiyouga/LLaMA-Factory.git
cd LLaMA-Factory
pip install -e ".[metrics]" --no-build-isolation
pip install -e ".[metrics]"
```
Optional dependencies available: `metrics`, `deepspeed`. Install with: `pip install -e ".[metrics,deepspeed]"`
@@ -637,7 +639,7 @@ cd transformers
pip install .
```
3. Set `double_quantization: false` in the configuration. You can refer to the [example](examples/train_qlora/llama3_lora_sft_bnb_npu.yaml).
3. Set `double_quantization: false` in the configuration. You can refer to the [example](examples/train_qlora/qwen3_lora_sft_bnb_npu.yaml).
</details>
@@ -652,12 +654,12 @@ You can also use **[Easy Dataset](https://github.com/ConardLi/easy-dataset)**, *
### Quickstart
Use the following 3 commands to run LoRA **fine-tuning**, **inference** and **merging** of the Llama3-8B-Instruct model, respectively.
Use the following 3 commands to run LoRA **fine-tuning**, **inference** and **merging** of the Qwen3-4B-Instruct model, respectively.
```bash
llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
llamafactory-cli chat examples/inference/llama3_lora_sft.yaml
llamafactory-cli export examples/merge_lora/llama3_lora_sft.yaml
llamafactory-cli train examples/train_lora/qwen3_lora_sft.yaml
llamafactory-cli chat examples/inference/qwen3_lora_sft.yaml
llamafactory-cli export examples/merge_lora/qwen3_lora_sft.yaml
```
See [examples/README.md](examples/README.md) for advanced usage (including distributed training).
@@ -780,7 +782,7 @@ When building the Docker image, use `-v ./hf_cache:/root/.cache/huggingface` arg
### Deploy with OpenAI-style API and vLLM
```bash
API_PORT=8000 llamafactory-cli api examples/inference/llama3.yaml infer_backend=vllm vllm_enforce_eager=true
API_PORT=8000 llamafactory-cli api examples/inference/qwen3.yaml infer_backend=vllm vllm_enforce_eager=true
```
> [!TIP]

View File

@@ -311,6 +311,7 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
| [MiMo](https://huggingface.co/XiaomiMiMo) | 7B/309B | mimo/mimo_v2 |
| [MiniCPM 1-4.1](https://huggingface.co/openbmb) | 0.5B/1B/2B/4B/8B | cpm/cpm3/cpm4 |
| [MiniCPM-o-2.6/MiniCPM-V-2.6](https://huggingface.co/openbmb) | 8B | minicpm_o/minicpm_v |
| [MiniMax-M1/MiniMax-M2](https://huggingface.co/MiniMaxAI/models) | 229B/456B | minimax1/minimax2 |
| [Ministral 3](https://huggingface.co/mistralai) | 3B/8B/14B | ministral3 |
| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
| [OLMo](https://huggingface.co/allenai) | 1B/7B | - |
@@ -435,6 +436,7 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
- [Chinese-DeepSeek-R1-Distill (zh)](https://huggingface.co/datasets/Congliu/Chinese-DeepSeek-R1-Distill-data-110k-SFT)
- [LLaVA mixed (en&zh)](https://huggingface.co/datasets/BUAADreamer/llava-en-zh-300k)
- [Pokemon-gpt4o-captions (en&zh)](https://huggingface.co/datasets/jugg1024/pokemon-gpt4o-captions)
- [DLR-Web (en)](https://huggingface.co/datasets/Attention1115/DLR-Web)
- [Open Assistant (de)](https://huggingface.co/datasets/mayflowergmbh/oasst_de)
- [Dolly 15k (de)](https://huggingface.co/datasets/mayflowergmbh/dolly-15k_de)
- [Alpaca GPT4 (de)](https://huggingface.co/datasets/mayflowergmbh/alpaca-gpt4_de)
@@ -516,7 +518,7 @@ huggingface-cli login
```bash
git clone --depth 1 https://github.com/hiyouga/LLaMA-Factory.git
cd LLaMA-Factory
pip install -e ".[metrics]" --no-build-isolation
pip install -e ".[metrics]"
```
可选的额外依赖项:`metrics``deepspeed`。使用 `pip install -e ".[metrics,deepspeed]"` 安装。
@@ -639,7 +641,7 @@ cd transformers
pip install .
```
3. 在训练参数中设置 `double_quantization: false`,可参考[示例](examples/train_qlora/llama3_lora_sft_bnb_npu.yaml)。
3. 在训练参数中设置 `double_quantization: false`,可参考[示例](examples/train_qlora/qwen3_lora_sft_bnb_npu.yaml)。
</details>
@@ -654,12 +656,12 @@ pip install .
### 快速开始
下面三行命令分别对 Llama3-8B-Instruct 模型进行 LoRA **微调**、**推理**和**合并**。
下面三行命令分别对 Qwen3-4B-Instruct 模型进行 LoRA **微调**、**推理**和**合并**。
```bash
llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
llamafactory-cli chat examples/inference/llama3_lora_sft.yaml
llamafactory-cli export examples/merge_lora/llama3_lora_sft.yaml
llamafactory-cli train examples/train_lora/qwen3_lora_sft.yaml
llamafactory-cli chat examples/inference/qwen3_lora_sft.yaml
llamafactory-cli export examples/merge_lora/qwen3_lora_sft.yaml
```
高级用法请参考 [examples/README_zh.md](examples/README_zh.md)(包括多 GPU 微调)。
@@ -785,7 +787,7 @@ docker exec -it llamafactory bash
### 利用 vLLM 部署 OpenAI API
```bash
API_PORT=8000 llamafactory-cli api examples/inference/llama3.yaml infer_backend=vllm vllm_enforce_eager=true
API_PORT=8000 llamafactory-cli api examples/inference/qwen3.yaml infer_backend=vllm vllm_enforce_eager=true
```
> [!TIP]

View File

@@ -471,6 +471,14 @@
"ultrachat_de": {
"hf_hub_url": "mayflowergmbh/ultra-chat_de"
},
"dlr_web": {
"hf_hub_url": "Attention1115/DLR-Web",
"split": "full",
"columns": {
"prompt": "question",
"response": "response"
}
},
"dpo_en_demo": {
"file_name": "dpo_en_demo.json",
"ranking": true,

View File

@@ -18,19 +18,19 @@ By default, LLaMA-Factory uses all visible computing devices.
Basic usage:
```bash
llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
llamafactory-cli train examples/train_lora/qwen3_lora_sft.yaml
```
Advanced usage:
```bash
CUDA_VISIBLE_DEVICES=0,1 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml \
CUDA_VISIBLE_DEVICES=0,1 llamafactory-cli train examples/train_lora/qwen3_lora_sft.yaml \
learning_rate=1e-5 \
logging_steps=1
```
```bash
bash examples/train_lora/llama3_lora_sft.sh
bash examples/train_lora/qwen3_lora_sft.sh
```
## Examples
@@ -40,49 +40,43 @@ bash examples/train_lora/llama3_lora_sft.sh
#### (Continuous) Pre-Training
```bash
llamafactory-cli train examples/train_lora/llama3_lora_pretrain.yaml
llamafactory-cli train examples/train_lora/qwen3_lora_pretrain.yaml
```
#### Supervised Fine-Tuning
```bash
llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
llamafactory-cli train examples/train_lora/qwen3_lora_sft.yaml
```
#### Multimodal Supervised Fine-Tuning
```bash
llamafactory-cli train examples/train_lora/qwen2_5vl_lora_sft.yaml
llamafactory-cli train examples/train_lora/qwen3vl_lora_sft.yaml
```
#### DPO/ORPO/SimPO Training
```bash
llamafactory-cli train examples/train_lora/llama3_lora_dpo.yaml
llamafactory-cli train examples/train_lora/qwen3_lora_dpo.yaml
```
#### Multimodal DPO/ORPO/SimPO Training
```bash
llamafactory-cli train examples/train_lora/qwen2_5vl_lora_dpo.yaml
llamafactory-cli train examples/train_lora/qwen3vl_lora_dpo.yaml
```
#### Reward Modeling
```bash
llamafactory-cli train examples/train_lora/llama3_lora_reward.yaml
```
#### PPO Training
```bash
llamafactory-cli train examples/train_lora/llama3_lora_ppo.yaml
llamafactory-cli train examples/train_lora/qwen3_lora_reward.yaml
```
#### KTO Training
```bash
llamafactory-cli train examples/train_lora/llama3_lora_kto.yaml
llamafactory-cli train examples/train_lora/qwen3_lora_kto.yaml
```
#### Preprocess Dataset
@@ -90,32 +84,26 @@ llamafactory-cli train examples/train_lora/llama3_lora_kto.yaml
It is useful for large dataset, use `tokenized_path` in config to load the preprocessed dataset.
```bash
llamafactory-cli train examples/train_lora/llama3_preprocess.yaml
```
#### Evaluating on MMLU/CMMLU/C-Eval Benchmarks
```bash
llamafactory-cli eval examples/train_lora/llama3_lora_eval.yaml
llamafactory-cli train examples/train_lora/qwen3_preprocess.yaml
```
#### Supervised Fine-Tuning on Multiple Nodes
```bash
FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/qwen3_lora_sft.yaml
FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/qwen3_lora_sft.yaml
```
#### Supervised Fine-Tuning with DeepSpeed ZeRO-3 (Weight Sharding)
```bash
FORCE_TORCHRUN=1 llamafactory-cli train examples/train_lora/llama3_lora_sft_ds3.yaml
FORCE_TORCHRUN=1 llamafactory-cli train examples/train_lora/qwen3_lora_sft_ds3.yaml
```
#### Supervised Fine-Tuning with Ray on 4 GPUs
```bash
USE_RAY=1 llamafactory-cli train examples/train_lora/llama3_lora_sft_ray.yaml
USE_RAY=1 llamafactory-cli train examples/train_lora/qwen3_lora_sft_ray.yaml
```
### QLoRA Fine-Tuning
@@ -123,13 +111,13 @@ USE_RAY=1 llamafactory-cli train examples/train_lora/llama3_lora_sft_ray.yaml
#### Supervised Fine-Tuning with 4/8-bit Bitsandbytes/HQQ/EETQ Quantization (Recommended)
```bash
llamafactory-cli train examples/train_qlora/llama3_lora_sft_otfq.yaml
llamafactory-cli train examples/train_qlora/qwen3_lora_sft_otfq.yaml
```
#### Supervised Fine-Tuning with 4-bit Bitsandbytes Quantization on Ascend NPU
```bash
llamafactory-cli train examples/train_qlora/llama3_lora_sft_bnb_npu.yaml
llamafactory-cli train examples/train_qlora/qwen3_lora_sft_bnb_npu.yaml
```
#### Supervised Fine-Tuning with 4/8-bit GPTQ Quantization
@@ -155,14 +143,14 @@ llamafactory-cli train examples/train_qlora/llama3_lora_sft_aqlm.yaml
#### Supervised Fine-Tuning on Single Node
```bash
FORCE_TORCHRUN=1 llamafactory-cli train examples/train_full/llama3_full_sft.yaml
FORCE_TORCHRUN=1 llamafactory-cli train examples/train_full/qwen3_full_sft.yaml
```
#### Supervised Fine-Tuning on Multiple Nodes
```bash
FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/llama3_full_sft.yaml
FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/llama3_full_sft.yaml
FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/qwen3_full_sft.yaml
FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/qwen3_full_sft.yaml
```
### Elastic and Fault-Tolerant Supervised Fine-Tuning on Multiple Nodes
@@ -170,13 +158,13 @@ FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500
To launch an elastic job with `MAX_RESTARTS` failures retries, run the following on at least `MIN_NNODES` nodes and at most `MAX_NNODES` nodes. `RDZV_ID` should be set as a unique job id (shared by all nodes participating in the job). See also [torchrun](https://docs.pytorch.org/docs/stable/elastic/run.html).
```bash
FORCE_TORCHRUN=1 MIN_NNODES=1 MAX_NNODES=3 MAX_RESTARTS=3 RDZV_ID=llamafactory MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/llama3_full_sft.yaml
FORCE_TORCHRUN=1 MIN_NNODES=1 MAX_NNODES=3 MAX_RESTARTS=3 RDZV_ID=llamafactory MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/qwen3_full_sft.yaml
```
#### Multimodal Supervised Fine-Tuning
```bash
FORCE_TORCHRUN=1 llamafactory-cli train examples/train_full/qwen2_5vl_full_sft.yaml
FORCE_TORCHRUN=1 llamafactory-cli train examples/train_full/qwen3vl_full_sft.yaml
```
### Merging LoRA Adapters and Quantization
@@ -186,19 +174,19 @@ FORCE_TORCHRUN=1 llamafactory-cli train examples/train_full/qwen2_5vl_full_sft.y
Note: DO NOT use quantized model or `quantization_bit` when merging LoRA adapters.
```bash
llamafactory-cli export examples/merge_lora/llama3_lora_sft.yaml
llamafactory-cli export examples/merge_lora/qwen3_lora_sft.yaml
```
#### Quantizing Model using AutoGPTQ
```bash
llamafactory-cli export examples/merge_lora/llama3_gptq.yaml
llamafactory-cli export examples/merge_lora/qwen3_gptq.yaml
```
### Save Ollama modelfile
```bash
llamafactory-cli export examples/merge_lora/llama3_full_sft.yaml
llamafactory-cli export examples/merge_lora/qwen3_full_sft.yaml
```
### Inferring LoRA Fine-Tuned Models
@@ -206,26 +194,26 @@ llamafactory-cli export examples/merge_lora/llama3_full_sft.yaml
#### Evaluation using vLLM's Multi-GPU Inference
```
python scripts/vllm_infer.py --model_name_or_path meta-llama/Meta-Llama-3-8B-Instruct --template llama3 --dataset alpaca_en_demo
python scripts/vllm_infer.py --model_name_or_path Qwen/Qwen3-4B-Instruct-2507 --template qwen3_nothink --dataset alpaca_en_demo
python scripts/eval_bleu_rouge.py generated_predictions.jsonl
```
#### Use CLI ChatBox
```bash
llamafactory-cli chat examples/inference/llama3_lora_sft.yaml
llamafactory-cli chat examples/inference/qwen3_lora_sft.yaml
```
#### Use Web UI ChatBox
```bash
llamafactory-cli webchat examples/inference/llama3_lora_sft.yaml
llamafactory-cli webchat examples/inference/qwen3_lora_sft.yaml
```
#### Launch OpenAI-style API
```bash
llamafactory-cli api examples/inference/llama3_lora_sft.yaml
llamafactory-cli api examples/inference/qwen3_lora_sft.yaml
```
### Extras

View File

@@ -18,19 +18,19 @@ LLaMA-Factory 默认使用所有可见的计算设备。
基础用法:
```bash
llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
llamafactory-cli train examples/train_lora/qwen3_lora_sft.yaml
```
高级用法:
```bash
CUDA_VISIBLE_DEVICES=0,1 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml \
CUDA_VISIBLE_DEVICES=0,1 llamafactory-cli train examples/train_lora/qwen3_lora_sft.yaml \
learning_rate=1e-5 \
logging_steps=1
```
```bash
bash examples/train_lora/llama3_lora_sft.sh
bash examples/train_lora/qwen3_lora_sft.sh
```
## 示例
@@ -40,49 +40,43 @@ bash examples/train_lora/llama3_lora_sft.sh
#### (增量)预训练
```bash
llamafactory-cli train examples/train_lora/llama3_lora_pretrain.yaml
llamafactory-cli train examples/train_lora/qwen3_lora_pretrain.yaml
```
#### 指令监督微调
```bash
llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
llamafactory-cli train examples/train_lora/qwen3_lora_sft.yaml
```
#### 多模态指令监督微调
```bash
llamafactory-cli train examples/train_lora/qwen2_5vl_lora_sft.yaml
llamafactory-cli train examples/train_lora/qwen3vl_lora_sft.yaml
```
#### DPO/ORPO/SimPO 训练
```bash
llamafactory-cli train examples/train_lora/llama3_lora_dpo.yaml
llamafactory-cli train examples/train_lora/qwen3_lora_dpo.yaml
```
#### 多模态 DPO/ORPO/SimPO 训练
```bash
llamafactory-cli train examples/train_lora/qwen2_5vl_lora_dpo.yaml
llamafactory-cli train examples/train_lora/qwen3vl_lora_dpo.yaml
```
#### 奖励模型训练
```bash
llamafactory-cli train examples/train_lora/llama3_lora_reward.yaml
```
#### PPO 训练
```bash
llamafactory-cli train examples/train_lora/llama3_lora_ppo.yaml
llamafactory-cli train examples/train_lora/qwen3_lora_reward.yaml
```
#### KTO 训练
```bash
llamafactory-cli train examples/train_lora/llama3_lora_kto.yaml
llamafactory-cli train examples/train_lora/qwen3_lora_kto.yaml
```
#### 预处理数据集
@@ -90,20 +84,14 @@ llamafactory-cli train examples/train_lora/llama3_lora_kto.yaml
对于大数据集有帮助,在配置中使用 `tokenized_path` 以加载预处理后的数据集。
```bash
llamafactory-cli train examples/train_lora/llama3_preprocess.yaml
```
#### 在 MMLU/CMMLU/C-Eval 上评估
```bash
llamafactory-cli eval examples/train_lora/llama3_lora_eval.yaml
llamafactory-cli train examples/train_lora/qwen3_preprocess.yaml
```
#### 多机指令监督微调
```bash
FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/qwen3_lora_sft.yaml
FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/qwen3_lora_sft.yaml
```
### 支持弹性和容错的多机指令监督微调
@@ -111,19 +99,19 @@ FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500
要启动一个支持弹性节点和容错的多机指令微调,在每个节点上执行以下命令。弹性节点数量范围为 `MIN_NNODES:MAX_NNODES`,每个节点最多允许因为错误重启 `MAX_RESTARTS` 次。`RDZV_ID` 应设置为一个唯一的作业 ID由参与该作业的所有节点共享。更多新可以参考官方文档 [torchrun](https://docs.pytorch.org/docs/stable/elastic/run.html)。
```bash
FORCE_TORCHRUN=1 MIN_NNODES=1 MAX_NNODES=3 MAX_RESTARTS=3 RDZV_ID=llamafactory MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/llama3_full_sft.yaml
FORCE_TORCHRUN=1 MIN_NNODES=1 MAX_NNODES=3 MAX_RESTARTS=3 RDZV_ID=llamafactory MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/qwen3_full_sft.yaml
```
#### 使用 DeepSpeed ZeRO-3 平均分配显存
```bash
FORCE_TORCHRUN=1 llamafactory-cli train examples/train_lora/llama3_lora_sft_ds3.yaml
FORCE_TORCHRUN=1 llamafactory-cli train examples/train_lora/qwen3_lora_sft_ds3.yaml
```
#### 使用 Ray 在 4 张 GPU 上微调
```bash
USE_RAY=1 llamafactory-cli train examples/train_lora/llama3_lora_sft_ray.yaml
USE_RAY=1 llamafactory-cli train examples/train_lora/qwen3_lora_sft_ray.yaml
```
### QLoRA 微调
@@ -131,13 +119,13 @@ USE_RAY=1 llamafactory-cli train examples/train_lora/llama3_lora_sft_ray.yaml
#### 基于 4/8 比特 Bitsandbytes/HQQ/EETQ 量化进行指令监督微调(推荐)
```bash
llamafactory-cli train examples/train_qlora/llama3_lora_sft_otfq.yaml
llamafactory-cli train examples/train_qlora/qwen3_lora_sft_otfq.yaml
```
#### 在 NPU 上基于 4 比特 Bitsandbytes 量化进行指令监督微调
```bash
llamafactory-cli train examples/train_qlora/llama3_lora_sft_bnb_npu.yaml
llamafactory-cli train examples/train_qlora/qwen3_lora_sft_bnb_npu.yaml
```
#### 基于 4/8 比特 GPTQ 量化进行指令监督微调
@@ -163,20 +151,20 @@ llamafactory-cli train examples/train_qlora/llama3_lora_sft_aqlm.yaml
#### 在单机上进行指令监督微调
```bash
FORCE_TORCHRUN=1 llamafactory-cli train examples/train_full/llama3_full_sft.yaml
FORCE_TORCHRUN=1 llamafactory-cli train examples/train_full/qwen3_full_sft.yaml
```
#### 在多机上进行指令监督微调
```bash
FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/llama3_full_sft.yaml
FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/llama3_full_sft.yaml
FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/qwen3_full_sft.yaml
FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/qwen3_full_sft.yaml
```
#### 多模态指令监督微调
```bash
FORCE_TORCHRUN=1 llamafactory-cli train examples/train_full/qwen2_5vl_full_sft.yaml
FORCE_TORCHRUN=1 llamafactory-cli train examples/train_full/qwen3vl_full_sft.yaml
```
### 合并 LoRA 适配器与模型量化
@@ -186,19 +174,19 @@ FORCE_TORCHRUN=1 llamafactory-cli train examples/train_full/qwen2_5vl_full_sft.y
注:请勿使用量化后的模型或 `quantization_bit` 参数来合并 LoRA 适配器。
```bash
llamafactory-cli export examples/merge_lora/llama3_lora_sft.yaml
llamafactory-cli export examples/merge_lora/qwen3_lora_sft.yaml
```
#### 使用 AutoGPTQ 量化模型
```bash
llamafactory-cli export examples/merge_lora/llama3_gptq.yaml
llamafactory-cli export examples/merge_lora/qwen3_gptq.yaml
```
### 保存 Ollama 配置文件
```bash
llamafactory-cli export examples/merge_lora/llama3_full_sft.yaml
llamafactory-cli export examples/merge_lora/qwen3_full_sft.yaml
```
### 推理 LoRA 模型
@@ -206,26 +194,26 @@ llamafactory-cli export examples/merge_lora/llama3_full_sft.yaml
#### 使用 vLLM 多卡推理评估
```
python scripts/vllm_infer.py --model_name_or_path meta-llama/Meta-Llama-3-8B-Instruct --template llama3 --dataset alpaca_en_demo
python scripts/vllm_infer.py --model_name_or_path Qwen/Qwen3-4B-Instruct-2507 --template qwen3_nothink --dataset alpaca_en_demo
python scripts/eval_bleu_rouge.py generated_predictions.jsonl
```
#### 使用命令行对话框
```bash
llamafactory-cli chat examples/inference/llama3_lora_sft.yaml
llamafactory-cli chat examples/inference/qwen3_lora_sft.yaml
```
#### 使用浏览器对话框
```bash
llamafactory-cli webchat examples/inference/llama3_lora_sft.yaml
llamafactory-cli webchat examples/inference/qwen3_lora_sft.yaml
```
#### 启动 OpenAI 风格 API
```bash
llamafactory-cli api examples/inference/llama3_lora_sft.yaml
llamafactory-cli api examples/inference/qwen3_lora_sft.yaml
```
### 杂项

View File

@@ -1,5 +0,0 @@
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
adapter_name_or_path: saves/llama3-8b/lora/sft
template: llama3
infer_backend: huggingface # choices: [huggingface, vllm, sglang, ktransformers]
trust_remote_code: true

View File

@@ -1,4 +1,4 @@
model_name_or_path: Qwen/Qwen2.5-VL-7B-Instruct
template: qwen2_vl
model_name_or_path: Qwen/Qwen3-4B-Instruct-2507
template: qwen3_nothink
infer_backend: huggingface # choices: [huggingface, vllm, sglang, ktransformers]
trust_remote_code: true

View File

@@ -1,4 +1,4 @@
model_name_or_path: saves/llama3-8b/full/sft
template: llama3
model_name_or_path: saves/qwen3-4b/full/sft
template: qwen3_nothink
infer_backend: huggingface # choices: [huggingface, vllm, sglang, ktransformers]
trust_remote_code: true

View File

@@ -0,0 +1,5 @@
model_name_or_path: Qwen/Qwen3-4B-Instruct-2507
adapter_name_or_path: saves/qwen3-4b/lora/sft
template: qwen3_nothink
infer_backend: huggingface # choices: [huggingface, vllm, sglang, ktransformers]
trust_remote_code: true

View File

@@ -1,4 +1,4 @@
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
template: llama3
model_name_or_path: Qwen/Qwen3-VL-4B-Instruct
template: qwen3_vl_nothink
infer_backend: huggingface # choices: [huggingface, vllm, sglang, ktransformers]
trust_remote_code: true

View File

@@ -1,10 +1,10 @@
### model
model_name_or_path: saves/llama3-8b/full/sft
template: llama3
model_name_or_path: saves/qwen3-4b/full/sft
template: qwen3_nothink
trust_remote_code: true
### export
export_dir: output/llama3_full_sft
export_dir: saves/qwen3_sft_merged
export_size: 5
export_device: cpu # choices: [cpu, auto]
export_legacy_format: false

View File

@@ -1,10 +1,10 @@
### model
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
template: llama3
model_name_or_path: Qwen/Qwen3-4B-Instruct-2507
template: qwen3_nothink
trust_remote_code: true
### export
export_dir: output/llama3_gptq
export_dir: saves/qwen3_gptq
export_quantization_bit: 4
export_quantization_dataset: data/c4_demo.jsonl
export_size: 5

View File

@@ -1,13 +1,13 @@
### Note: DO NOT use quantized model or quantization_bit when merging lora adapters
### model
model_name_or_path: Qwen/Qwen2.5-VL-7B-Instruct
adapter_name_or_path: saves/qwen2_5vl-7b/lora/sft
template: qwen2_vl
model_name_or_path: Qwen/Qwen3-4B-Instruct-2507
adapter_name_or_path: saves/qwen3-4b/lora/sft
template: qwen3_nothink
trust_remote_code: true
### export
export_dir: output/qwen2_5vl_lora_sft
export_dir: saves/qwen3_sft_merged
export_size: 5
export_device: cpu # choices: [cpu, auto]
export_legacy_format: false

View File

@@ -1,13 +1,13 @@
### Note: DO NOT use quantized model or quantization_bit when merging lora adapters
### model
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
adapter_name_or_path: saves/llama3-8b/lora/sft
template: llama3
model_name_or_path: Qwen/Qwen3-VL-4B-Instruct
adapter_name_or_path: saves/qwen3-vl-4b/lora/sft
template: qwen3_vl_nothink
trust_remote_code: true
### export
export_dir: output/llama3_lora_sft
export_dir: saves/qwen3_vl_sft_merged
export_size: 5
export_device: cpu # choices: [cpu, auto]
export_legacy_format: false

View File

@@ -1,5 +1,5 @@
### model
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
model_name_or_path: Qwen/Qwen3-4B-Instruct-2507
trust_remote_code: true
### method
@@ -10,15 +10,14 @@ deepspeed: examples/deepspeed/ds_z3_config.json # choices: [ds_z0_config.json,
### dataset
dataset: identity,alpaca_en_demo
template: llama3
template: qwen3_nothink
cutoff_len: 2048
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16
dataloader_num_workers: 4
### output
output_dir: saves/llama3-8b/full/sft
output_dir: saves/qwen3-4b/full/sft
logging_steps: 10
save_steps: 500
plot_loss: true

View File

@@ -1,46 +0,0 @@
### model
model_name_or_path: Qwen/Qwen3-32B
trust_remote_code: true
use_v1_kernels: true
### method
stage: sft
do_train: true
finetuning_type: full
deepspeed: examples/deepspeed/ds_z2_autotp_config.json
### dataset
dataset: identity,alpaca_en_demo
template: qwen3
cutoff_len: 2048
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16
dataloader_num_workers: 4
### output
output_dir: saves/qwen3-32b/full/sft_autotp
logging_steps: 1
save_steps: 500
plot_loss: true
overwrite_output_dir: true
save_only_model: false
report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
### train
per_device_train_batch_size: 4
gradient_accumulation_steps: 1
learning_rate: 1.0e-4
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
ddp_timeout: 180000000
resume_from_checkpoint: null
### eval
# eval_dataset: alpaca_en_demo
# val_size: 0.1
# per_device_eval_batch_size: 1
# eval_strategy: steps
# eval_steps: 500

View File

@@ -1,5 +1,5 @@
### model
model_name_or_path: Qwen/Qwen2.5-VL-7B-Instruct
model_name_or_path: Qwen/Qwen3-VL-4B-Instruct
image_max_pixels: 262144
video_max_pixels: 16384
trust_remote_code: true
@@ -15,15 +15,14 @@ deepspeed: examples/deepspeed/ds_z3_config.json
### dataset
dataset: mllm_demo,identity,alpaca_en_demo
template: qwen2_vl
template: qwen3_vl_nothink
cutoff_len: 2048
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16
dataloader_num_workers: 4
### output
output_dir: saves/qwen2_5vl-7b/full/sft
output_dir: saves/qwen3-vl-4b/full/sft
logging_steps: 10
save_steps: 500
plot_loss: true

View File

@@ -1,19 +0,0 @@
### model
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
adapter_name_or_path: saves/llama3-8b/lora/sft
trust_remote_code: true
### method
finetuning_type: lora
### dataset
task: mmlu_test # choices: [mmlu_test, ceval_validation, cmmlu_test]
template: fewshot
lang: en
n_shot: 5
### output
save_dir: saves/llama3-8b/lora/eval
### eval
batch_size: 4

View File

@@ -1,43 +0,0 @@
### model
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
reward_model: saves/llama3-8b/lora/reward
trust_remote_code: true
### method
stage: ppo
do_train: true
finetuning_type: lora
lora_rank: 8
lora_target: all
### dataset
dataset: identity,alpaca_en_demo
template: llama3
cutoff_len: 2048
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16
dataloader_num_workers: 4
### output
output_dir: saves/llama3-8b/lora/ppo
logging_steps: 10
save_steps: 500
plot_loss: true
overwrite_output_dir: true
report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
### train
per_device_train_batch_size: 1
gradient_accumulation_steps: 8
learning_rate: 1.0e-5
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
ddp_timeout: 180000000
### generate
max_new_tokens: 512
top_k: 0
top_p: 0.9

View File

@@ -1,46 +0,0 @@
### model
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
trust_remote_code: true
### method
stage: sft
do_train: true
finetuning_type: lora
lora_rank: 8
lora_target: all
### dataset
dataset: identity,alpaca_en_demo
template: llama3
cutoff_len: 2048
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16
dataloader_num_workers: 4
### output
output_dir: saves/llama3-8b/lora/sft
logging_steps: 10
save_steps: 500
plot_loss: true
overwrite_output_dir: true
save_only_model: false
report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
### train
per_device_train_batch_size: 1
gradient_accumulation_steps: 8
learning_rate: 1.0e-4
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
ddp_timeout: 180000000
resume_from_checkpoint: null
### eval
# eval_dataset: alpaca_en_demo
# val_size: 0.1
# per_device_eval_batch_size: 1
# eval_strategy: steps
# eval_steps: 500

View File

@@ -1,49 +0,0 @@
# pip install git+https://github.com/hiyouga/transformers.git@llama4_train
### model
model_name_or_path: meta-llama/Llama-4-Scout-17B-16E-Instruct
trust_remote_code: true
### method
stage: sft
do_train: true
finetuning_type: lora
lora_rank: 8
lora_target: all
deepspeed: examples/deepspeed/ds_z3_config.json # choices: [ds_z0_config.json, ds_z2_config.json, ds_z3_config.json]
### dataset
dataset: mllm_demo,identity,alpaca_en_demo
template: llama4
cutoff_len: 2048
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16
dataloader_num_workers: 4
### output
output_dir: saves/llama4-8b/lora/sft
logging_steps: 10
save_steps: 500
plot_loss: true
overwrite_output_dir: true
save_only_model: false
report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
### train
per_device_train_batch_size: 1
gradient_accumulation_steps: 2
learning_rate: 1.0e-4
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
ddp_timeout: 180000000
resume_from_checkpoint: null
### eval
# eval_dataset: alpaca_en_demo
# val_size: 0.1
# per_device_eval_batch_size: 1
# eval_strategy: steps
# eval_steps: 500

View File

@@ -1,5 +1,5 @@
### model
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
model_name_or_path: Qwen/Qwen3-4B-Instruct-2507
trust_remote_code: true
### method
@@ -13,15 +13,14 @@ pref_loss: sigmoid # choices: [sigmoid (dpo), orpo, simpo]
### dataset
dataset: dpo_en_demo
template: llama3
template: qwen3_nothink
cutoff_len: 2048
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16
dataloader_num_workers: 4
### output
output_dir: saves/llama3-8b/lora/dpo
output_dir: saves/qwen3-4b/lora/dpo
logging_steps: 10
save_steps: 500
plot_loss: true

View File

@@ -1,5 +1,5 @@
### model
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
model_name_or_path: Qwen/Qwen3-4B-Instruct-2507
trust_remote_code: true
### method
@@ -12,15 +12,14 @@ pref_beta: 0.1
### dataset
dataset: kto_en_demo
template: llama3
template: qwen3_nothink
cutoff_len: 2048
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16
dataloader_num_workers: 4
### output
output_dir: saves/llama3-8b/lora/kto
output_dir: saves/qwen3-4b/lora/kto
logging_steps: 10
save_steps: 500
plot_loss: true

View File

@@ -1,5 +1,5 @@
### model
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
model_name_or_path: Qwen/Qwen3-4B-Instruct-2507
trust_remote_code: true
### method
@@ -13,12 +13,11 @@ lora_target: all
dataset: c4_demo
cutoff_len: 2048
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16
dataloader_num_workers: 4
### output
output_dir: saves/llama3-8b/lora/pretrain
output_dir: saves/qwen3-4b/lora/pretrain
logging_steps: 10
save_steps: 500
plot_loss: true

View File

@@ -1,5 +1,5 @@
### model
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
model_name_or_path: Qwen/Qwen3-4B-Instruct-2507
trust_remote_code: true
### method
@@ -11,15 +11,14 @@ lora_target: all
### dataset
dataset: dpo_en_demo
template: llama3
template: qwen3_nothink
cutoff_len: 2048
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16
dataloader_num_workers: 4
### output
output_dir: saves/llama3-8b/lora/reward
output_dir: saves/qwen3-4b/lora/reward
logging_steps: 10
save_steps: 500
plot_loss: true

View File

@@ -2,7 +2,7 @@
set -x
MODEL_PATH=meta-llama/Meta-Llama-3-8B-Instruct
MODEL_PATH=Qwen/Qwen3-4B-Instruct-2507
llamafactory-cli train \
--model_name_or_path ${MODEL_PATH} \
@@ -13,13 +13,12 @@ llamafactory-cli train \
--lora_rank 8 \
--lora_target all \
--dataset identity,alpaca_en_demo \
--template llama3 \
--template qwen3_nothink \
--cutoff_len 2048 \
--max_samples 1000 \
--overwrite_cache \
--preprocessing_num_workers 16 \
--dataloader_num_workers 4 \
--output_dir saves/llama3-8b/lora/sft \
--output_dir saves/qwen3-4b/lora/sft \
--logging_steps 10 \
--save_steps 500 \
--plot_loss \

View File

@@ -1,5 +1,5 @@
### model
model_name_or_path: openai/gpt-oss-20b
model_name_or_path: Qwen/Qwen3-4B-Instruct-2507
trust_remote_code: true
### method
@@ -11,15 +11,14 @@ lora_target: all
### dataset
dataset: identity,alpaca_en_demo
template: gpt
template: qwen3_nothink
cutoff_len: 2048
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16
dataloader_num_workers: 4
### output
output_dir: saves/gpt-20b/lora/sft
output_dir: saves/qwen3-4b/lora/sft
logging_steps: 10
save_steps: 500
plot_loss: true

View File

@@ -1,5 +1,5 @@
### model
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
model_name_or_path: Qwen/Qwen3-4B-Instruct-2507
trust_remote_code: true
### method
@@ -12,15 +12,14 @@ deepspeed: examples/deepspeed/ds_z3_config.json # choices: [ds_z0_config.json,
### dataset
dataset: identity,alpaca_en_demo
template: llama3
template: qwen3_nothink
cutoff_len: 2048
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16
dataloader_num_workers: 4
### output
output_dir: saves/llama3-8b/lora/sft
output_dir: saves/qwen3-4b/lora/sft
logging_steps: 10
save_steps: 500
plot_loss: true

View File

@@ -1,5 +1,5 @@
### model
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct # or use local absolute path
model_name_or_path: Qwen/Qwen3-4B-Instruct-2507 # or use local absolute path
trust_remote_code: true
### method
@@ -12,10 +12,9 @@ lora_target: all
### dataset
dataset: identity,alpaca_en_demo
dataset_dir: REMOTE:llamafactory/demo_data # or use local absolute path
template: llama3
template: qwen3_nothink
cutoff_len: 2048
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16
dataloader_num_workers: 4
@@ -29,7 +28,7 @@ save_only_model: false
report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
### ray
ray_run_name: llama3_8b_sft_lora
ray_run_name: qwen3_4b_sft_lora
ray_storage_path: ./saves
ray_num_workers: 4 # Number of GPUs to use.
placement_strategy: PACK

View File

@@ -1,5 +1,5 @@
### model
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
model_name_or_path: Qwen/Qwen3-4B-Instruct-2507
trust_remote_code: true
### method
@@ -11,13 +11,12 @@ lora_target: all
### dataset
dataset: identity,alpaca_en_demo
template: llama3
template: qwen3_nothink
cutoff_len: 2048
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16
tokenized_path: saves/llama3-8b/dataset/sft
tokenized_path: saves/qwen3-4b/dataset/sft
### output
output_dir: saves/llama3-8b/lora/sft
### output (not used)
output_dir: saves/qwen3-4b/lora/sft
overwrite_output_dir: true

View File

@@ -1,5 +1,5 @@
### model
model_name_or_path: Qwen/Qwen2.5-VL-7B-Instruct
model_name_or_path: Qwen/Qwen3-VL-4B-Instruct
image_max_pixels: 262144
video_max_pixels: 16384
trust_remote_code: true
@@ -15,15 +15,14 @@ pref_loss: sigmoid # choices: [sigmoid (dpo), orpo, simpo]
### dataset
dataset: rlhf_v
template: qwen2_vl
template: qwen3_vl_nothink
cutoff_len: 2048
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16
dataloader_num_workers: 4
### output
output_dir: saves/qwen2_5vl-7b/lora/dpo
output_dir: saves/qwen3-vl-4b/lora/dpo
logging_steps: 10
save_steps: 500
plot_loss: true

View File

@@ -1,5 +1,5 @@
### model
model_name_or_path: Qwen/Qwen2.5-VL-7B-Instruct
model_name_or_path: Qwen/Qwen3-VL-4B-Instruct
image_max_pixels: 262144
video_max_pixels: 16384
trust_remote_code: true
@@ -13,15 +13,14 @@ lora_target: all
### dataset
dataset: mllm_demo,identity,alpaca_en_demo # video: mllm_video_demo
template: qwen2_vl
template: qwen3_vl_nothink
cutoff_len: 2048
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16
dataloader_num_workers: 4
### output
output_dir: saves/qwen2_5vl-7b/lora/sft
output_dir: saves/qwen3-vl-4b/lora/sft
logging_steps: 10
save_steps: 500
plot_loss: true

View File

@@ -14,7 +14,6 @@ dataset: identity,alpaca_en_demo
template: llama3
cutoff_len: 2048
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16
dataloader_num_workers: 4

View File

@@ -14,7 +14,6 @@ dataset: identity,alpaca_en_demo
template: llama3
cutoff_len: 2048
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16
dataloader_num_workers: 4

View File

@@ -14,7 +14,6 @@ dataset: identity,alpaca_en_demo
template: llama3
cutoff_len: 2048
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16
dataloader_num_workers: 4

View File

@@ -1,5 +1,5 @@
### model
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
model_name_or_path: Qwen/Qwen3-4B-Instruct-2507
quantization_bit: 4
quantization_method: bnb
double_quantization: false
@@ -14,15 +14,14 @@ lora_target: all
### dataset
dataset: identity,alpaca_en_demo
template: llama3
template: qwen3_nothink
cutoff_len: 2048
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16
dataloader_num_workers: 4
### output
output_dir: saves/llama3-8b/lora/sft
output_dir: saves/qwen3-4b/lora/sft
logging_steps: 10
save_steps: 500
plot_loss: true

View File

@@ -1,5 +1,5 @@
### model
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
model_name_or_path: Qwen/Qwen3-4B-Instruct-2507
quantization_bit: 4 # choices: [8 (bnb/hqq/eetq), 4 (bnb/hqq), 3 (hqq), 2 (hqq)]
quantization_method: bnb # choices: [bnb, hqq, eetq]
trust_remote_code: true
@@ -13,15 +13,14 @@ lora_target: all
### dataset
dataset: identity,alpaca_en_demo
template: llama3
template: qwen3_nothink
cutoff_len: 2048
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16
dataloader_num_workers: 4
### output
output_dir: saves/llama3-8b/lora/sft
output_dir: saves/qwen3-4b/lora/sft
logging_steps: 10
save_steps: 500
plot_loss: true

View File

@@ -41,15 +41,14 @@ dependencies = [
"torch>=2.4.0",
"torchvision>=0.19.0",
"torchaudio>=2.4.0",
"transformers>=4.49.0,<=4.56.2,!=4.52.0; python_version < '3.10'",
"transformers>=4.49.0,<=4.57.1,!=4.52.0,!=4.57.0; python_version >= '3.10'",
"transformers>=4.51.0,<=4.57.1,!=4.52.0,!=4.57.0",
"datasets>=2.16.0,<=4.0.0",
"accelerate>=1.3.0,<=1.11.0",
"peft>=0.14.0,<=0.17.1",
"trl>=0.8.6,<=0.9.6",
"trl>=0.18.0,<=0.24.0",
"torchdata>=0.10.0,<=0.11.0",
# gui
"gradio>=4.38.0,<=6.2.0",
"gradio>=4.38.0,<=5.50.0",
"matplotlib>=3.7.0",
"tyro<0.9.0",
# ops

View File

@@ -14,9 +14,12 @@
import gc
import json
import time
import av
import fire
from datasets import load_dataset
from eval_bleu_rouge import compute_metrics
from tqdm import tqdm
from transformers import Seq2SeqTrainingArguments
@@ -51,6 +54,7 @@ def vllm_infer(
max_samples: int | None = None,
vllm_config: str = "{}",
save_name: str = "generated_predictions.jsonl",
matrix_save_name: str = None,
temperature: float = 0.95,
top_p: float = 0.7,
top_k: int = 50,
@@ -117,6 +121,7 @@ def vllm_infer(
if isinstance(model_args.vllm_config, dict):
engine_args.update(model_args.vllm_config)
model_preparation_start_time = time.time()
llm = LLM(**engine_args)
# load datasets
@@ -142,6 +147,7 @@ def vllm_infer(
all_prompts, all_preds, all_labels = [], [], []
need_video_kwargs = _need_video_kwargs(template)
model_predict_start_time = time.time()
# Add batch process to avoid the issue of too many files opened
for i in tqdm(range(0, len(train_dataset), batch_size), desc="Processing batched inference"):
vllm_inputs, prompts, labels = [], [], []
@@ -218,6 +224,7 @@ def vllm_infer(
all_labels.extend(labels)
gc.collect()
model_predict_end_time = time.time()
# Write all results at once outside the loop
with open(save_name, "w", encoding="utf-8") as f:
for text, pred, label in zip(all_prompts, all_preds, all_labels):
@@ -227,6 +234,49 @@ def vllm_infer(
print(f"{len(all_prompts)} total generated results have been saved at {save_name}.")
print("*" * 70)
# Write all matrix results when matrix_save_name is not None,
# The result matrix is referencing src.llamafactory.train.sft.workflow.run_sft # 127~132
# trainer.save_metrics("predict", predict_results.metrics)
#
# {
# "predict_bleu-4": 4.349975,
# "predict_model_preparation_time": 0.0128,
# "predict_rouge-1": 21.873359375,
# "predict_rouge-2": 4.144340625,
# "predict_rouge-l": 10.83949375,
# "predict_runtime": 131.664,
# "predict_samples_per_second": 0.076,
# "predict_steps_per_second": 0.008
# }
#
if matrix_save_name is not None:
predict_time = model_predict_end_time - model_predict_start_time
preparation_time = model_predict_start_time - model_preparation_start_time
start_time = time.time()
dataset = load_dataset("json", data_files=save_name, split="train")
dataset = dataset.map(compute_metrics, num_proc=8, remove_columns=dataset.column_names)
score_dict = dataset.to_dict()
average_score = {}
for task, scores in sorted(score_dict.items(), key=lambda x: x[0]):
score = sum(scores) / len(scores) if scores else 0.0
print(f"predict_{task}: {score:.4f}")
average_score["predict_" + task] = score
average_score["predict_model_preparation_time"] = preparation_time
average_score["predict_runtime"] = predict_time
num_steps = len(range(0, len(train_dataset), batch_size))
average_score["predict_samples_per_second"] = len(dataset) / predict_time if predict_time > 0 else 0.0
average_score["predict_steps_per_second"] = num_steps / predict_time if predict_time > 0 else 0.0
with open(matrix_save_name, "w", encoding="utf-8") as f:
json.dump(average_score, f, indent=4)
print("*" * 70)
print(f"\nDone in {time.time() - start_time:.3f}s.\nScore file saved to {matrix_save_name}.")
print("*" * 70)
if __name__ == "__main__":
fire.Fire(vllm_infer)

View File

@@ -1673,6 +1673,43 @@ register_template(
)
register_template(
name="minimax1",
format_user=StringFormatter(
slots=[
"<beginning_of_sentence>user name=user\n{{content}}<end_of_sentence>\n<beginning_of_sentence>ai name=assistant\n"
]
),
format_assistant=StringFormatter(slots=["{{content}}<end_of_sentence>\n"]),
format_system=StringFormatter(
slots=["<beginning_of_sentence>system ai_setting=assistant\n{{content}}<end_of_sentence>\n"]
),
format_function=FunctionFormatter(slots=["{{content}}<end_of_sentence>\n"], tool_format="minimax1"),
format_observation=StringFormatter(
slots=[
"<beginning_of_sentence>tool name=tools\n{{content}}<end_of_sentence>\n<beginning_of_sentence>ai name=assistant\n"
]
),
format_tools=ToolFormatter(tool_format="minimax1"),
default_system="You are a helpful assistant.",
stop_words=["<end_of_sentence>"],
)
register_template(
name="minimax2",
format_user=StringFormatter(slots=["]~b]user\n{{content}}[e~[\n]~b]ai\n"]),
format_assistant=StringFormatter(slots=["{{content}}[e~[\n"]),
format_system=StringFormatter(slots=["]~!b[]~b]system\n{{content}}[e~[\n"]),
format_function=FunctionFormatter(slots=["{{content}}[e~[\n"], tool_format="minimax2"),
format_observation=StringFormatter(slots=["]~b]tool\n<response>{{content}}</response>[e~[\n]~b]ai\n"]),
format_tools=ToolFormatter(tool_format="minimax2"),
default_system="You are a helpful assistant. Your name is MiniMax-M2.1 and is built by MiniMax.",
stop_words=["[e~["],
template_class=ReasoningTemplate,
)
# mistral tokenizer v3 tekken
register_template(
name="ministral",

View File

@@ -61,6 +61,21 @@ LLAMA3_TOOL_PROMPT = (
"Do not use variables.\n\n{tool_text}"
)
MINIMAX_M1_TOOL_PROMPT = (
"You are provided with these tools:\n<tools>\n{tool_text}</tools>\n\n"
"If you need to call tools, please respond with <tool_calls></tool_calls> XML tags, and provide tool-name and "
"json-object of arguments, following the format below:\n<tool_calls>\n"
"""{{"name": <tool-name-1>, "arguments": <args-json-object-1>}}\n...\n</tool_calls>"""
)
MINIMAX_M2_TOOL_PROMPT = (
"\n\n# Tools\n\nYou may call one or more tools to assist with the user query.\n"
"Here are the tools available in JSONSchema format:\n\n<tools>\n{tool_text}</tools>\n\n"
"When making tool calls, use XML format to invoke tools and pass parameters:\n"
"""\n<minimax:tool_call>\n<invoke name="tool-name-1">\n<parameter name="param-key-1">param-value-1</parameter>\n"""
"""<parameter name="param-key-2">param-value-2</parameter>\n...\n</invoke>\n</minimax:tool_call>"""
)
QWEN_TOOL_PROMPT = (
"\n\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\n"
"You are provided with function signatures within <tools></tools> XML tags:\n<tools>{tool_text}"
@@ -253,6 +268,109 @@ class Llama3ToolUtils(ToolUtils):
return content
class MiniMaxM1ToolUtils(ToolUtils):
r"""MiniMax-M1 tool using template."""
@override
@staticmethod
def tool_formatter(tools: list[dict[str, Any]]) -> str:
tool_text = ""
for tool in tools:
tool = tool.get("function", "") if tool.get("type") == "function" else tool
tool_text += json.dumps(tool, ensure_ascii=False) + "\n"
return MINIMAX_M1_TOOL_PROMPT.format(tool_text=tool_text)
@override
@staticmethod
def function_formatter(functions: list["FunctionCall"]) -> str:
function_texts = []
for func in functions:
name, arguments = func.name, json.loads(func.arguments)
function_texts.append(json.dumps({"name": name, "arguments": arguments}, ensure_ascii=False))
return "<tool_calls>\n" + "\n".join(function_texts) + "\n</tool_calls>"
@override
@staticmethod
def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
regex = re.compile(r"<tool_calls>\s*(.+?)\s*</tool_calls>", re.DOTALL)
tool_match = re.search(regex, content)
if not tool_match:
return content
tool_calls_content = tool_match.group(1)
results = []
for line in tool_calls_content.split("\n"):
line = line.strip()
if not line:
continue
try:
tool_call = json.loads(line)
results.append(FunctionCall(tool_call["name"], json.dumps(tool_call["arguments"], ensure_ascii=False)))
except json.JSONDecodeError:
continue
return results
class MiniMaxM2ToolUtils(ToolUtils):
r"""MiniMax-M2 tool using template."""
@override
@staticmethod
def tool_formatter(tools: list[dict[str, Any]]) -> str:
tool_text = ""
for tool in tools:
tool = tool.get("function", "") if tool.get("type") == "function" else tool
tool_text += "<tool>" + json.dumps(tool, ensure_ascii=False) + "</tool>\n"
return MINIMAX_M2_TOOL_PROMPT.format(tool_text=tool_text)
@override
@staticmethod
def function_formatter(functions: list["FunctionCall"]) -> str:
function_texts = []
for func in functions:
name, arguments = func.name, json.loads(func.arguments)
prompt = f'<invoke name="{name}">'
for key, value in arguments.items():
prompt += f'\n<parameter name="{key}">'
if not isinstance(value, str):
value = json.dumps(value, ensure_ascii=False)
prompt += value + "</parameter>"
prompt += "\n</invoke>"
function_texts.append(prompt)
@override
@staticmethod
def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
regex = re.compile(r"<minimax:tool_call>\s*(.+?)\s*</minimax:tool_call>", re.DOTALL)
tool_match = re.search(regex, content)
if not tool_match:
return content
tool_calls_content = tool_match.group(1)
invoke_regex = re.compile(r"<invoke name=\"(.*?)\">(.*?)</invoke>", re.DOTALL)
results = []
for func_name, params_block in re.findall(invoke_regex, tool_calls_content):
args_dict = {}
param_pattern = re.compile(r"<parameter name=\"(.*?)\">(.*?)</parameter>", re.DOTALL)
for key, raw_value in re.findall(param_pattern, params_block):
value = raw_value.strip()
try:
parsed_value = json.loads(value)
except json.JSONDecodeError:
parsed_value = raw_value
args_dict[key] = parsed_value
results.append(FunctionCall(func_name.strip(), json.dumps(args_dict, ensure_ascii=False)))
return results
class MistralToolUtils(ToolUtils):
r"""Mistral v0.3 tool using template."""
@@ -432,6 +550,8 @@ TOOLS = {
"default": DefaultToolUtils(),
"glm4": GLM4ToolUtils(),
"llama3": Llama3ToolUtils(),
"minimax1": MiniMaxM1ToolUtils(),
"minimax2": MiniMaxM2ToolUtils(),
"mistral": MistralToolUtils(),
"qwen": QwenToolUtils(),
"glm4_moe": GLM4MOEToolUtils(),

View File

@@ -63,6 +63,7 @@ MCA_SUPPORTED_MODELS = {
"qwen2",
"qwen2_vl",
"qwen2_5_vl",
"qwen3_vl",
"qwen3",
"qwen3_moe",
"qwen3_next",
@@ -1071,6 +1072,40 @@ register_model_group(
)
register_model_group(
models={
"MiniMax-Text-01-Instruct": {
DownloadSource.DEFAULT: "MiniMaxAI/MiniMax-Text-01-hf",
DownloadSource.MODELSCOPE: "MiniMaxAI/MiniMax-Text-01",
},
"MiniMax-M1-40k-Thinking": {
DownloadSource.DEFAULT: "MiniMaxAI/MiniMax-M1-40k-hf",
DownloadSource.MODELSCOPE: "MiniMaxAI/MiniMax-M1-40k-hf",
},
"MiniMax-M1-80k-Thinking": {
DownloadSource.DEFAULT: "MiniMaxAI/MiniMax-M1-80k-hf",
DownloadSource.MODELSCOPE: "MiniMaxAI/MiniMax-M1-80k-hf",
},
},
template="minimax1",
)
register_model_group(
models={
"MiniMax-M2-Thinking": {
DownloadSource.DEFAULT: "MiniMaxAI/MiniMax-M2",
DownloadSource.MODELSCOPE: "MiniMaxAI/MiniMax-M2",
},
"MiniMax-M2.1-Thinking": {
DownloadSource.DEFAULT: "MiniMaxAI/MiniMax-M2.1",
DownloadSource.MODELSCOPE: "MiniMaxAI/MiniMax-M2.1",
},
},
template="minimax2",
)
register_model_group(
models={
"Granite-3.0-1B-A400M-Base": {

View File

@@ -19,7 +19,7 @@
from collections import OrderedDict
VERSION = "0.9.4.dev0"
VERSION = "0.9.4"
def print_env() -> None:

View File

@@ -94,11 +94,11 @@ def check_version(requirement: str, mandatory: bool = False) -> None:
def check_dependencies() -> None:
r"""Check the version of the required packages."""
check_version("transformers>=4.49.0,<=4.57.1")
check_version("transformers>=4.51.0,<=4.57.1")
check_version("datasets>=2.16.0,<=4.0.0")
check_version("accelerate>=1.3.0,<=1.11.0")
check_version("peft>=0.14.0,<=0.17.1")
check_version("trl>=0.8.6,<=0.9.6")
check_version("trl>=0.18.0,<=0.24.0")
def calculate_tps(dataset: list[dict[str, Any]], metrics: dict[str, float], stage: Literal["sft", "rm"]) -> float:

View File

@@ -173,7 +173,7 @@ class BaseModelArguments:
default=True,
metadata={"help": "Whether or not to use KV cache in generation."},
)
use_v1_kernels: bool = field(
use_v1_kernels: bool | None = field(
default=False,
metadata={"help": "Whether or not to use high-performance kernels in training."},
)

View File

@@ -216,9 +216,9 @@ def load_model(
"You are try to using future feature about kernels, please note that this feature "
"is not supported for all models. If get any error, please disable this feature, or report the issue."
)
from ..v1.plugins.model_plugins.kernels.registry import apply_available_kernels
from ..v1.plugins.model_plugins.kernels.interface import apply_default_kernels
model = apply_available_kernels(model)
model = apply_default_kernels(model=model, include_kernels=model_args.use_v1_kernels)
trainable_params, all_param = count_parameters(model)
if is_trainable:

View File

@@ -26,6 +26,7 @@ import torch.nn.functional as F
from transformers import Trainer
from trl import DPOTrainer
from trl.trainer import disable_dropout_in_model
from trl.trainer.utils import prepare_deepspeed
from typing_extensions import override
from ...extras.constants import IGNORE_INDEX
@@ -95,7 +96,7 @@ class CustomDPOTrainer(DPOTrainer):
if not (
getattr(ref_model, "is_loaded_in_8bit", False) or getattr(ref_model, "is_loaded_in_4bit", False)
): # quantized models are already set on the correct device
self.ref_model = self._prepare_deepspeed(self.ref_model)
self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
else:
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
self.ref_model.eval()
@@ -210,7 +211,7 @@ class CustomDPOTrainer(DPOTrainer):
@override
def concatenated_forward(
self, model: "PreTrainedModel", batch: dict[str, "torch.Tensor"], is_ref_model: bool = False
) -> tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
) -> dict[str, "torch.Tensor"]:
r"""Compute the sum log probabilities of the labels under given logits if loss_type is not IPO, ORPO or SimPO.
Otherwise the average log probabilities.
@@ -230,11 +231,18 @@ class CustomDPOTrainer(DPOTrainer):
chosen_logps, rejected_logps = all_logps.split(batch_size, dim=0)
chosen_logits, rejected_logits = all_logits.split(batch_size, dim=0)
chosen_length, _ = valid_length.split(batch_size, dim=0)
if self.loss_type in ["ipo", "orpo", "simpo"]:
return chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_logps
chosen_logps_avg = chosen_logps
else:
return chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_logps / chosen_length
chosen_logps_avg = chosen_logps / chosen_length
return {
"chosen_logps": chosen_logps,
"rejected_logps": rejected_logps,
"chosen_logits": chosen_logits,
"rejected_logits": rejected_logits,
"chosen_logps_avg": chosen_logps_avg,
}
@override
def compute_reference_log_probs(
@@ -252,9 +260,9 @@ class CustomDPOTrainer(DPOTrainer):
ref_context = nullcontext()
with torch.no_grad(), ref_context:
reference_chosen_logps, reference_rejected_logps, *_ = self.concatenated_forward(
ref_model, batch, is_ref_model=True
)
ref_output = self.concatenated_forward(ref_model, batch, is_ref_model=True)
reference_chosen_logps = ref_output["chosen_logps"]
reference_rejected_logps = ref_output["rejected_logps"]
return reference_chosen_logps, reference_rejected_logps
@@ -267,13 +275,13 @@ class CustomDPOTrainer(DPOTrainer):
) -> tuple["torch.Tensor", dict[str, "torch.Tensor"]]:
r"""Compute the DPO loss and other metrics for the given batch of inputs for train or test."""
metrics = {}
(
policy_chosen_logps,
policy_rejected_logps,
policy_chosen_logits,
policy_rejected_logits,
policy_chosen_logps_avg,
) = self.concatenated_forward(model, batch)
model_output = self.concatenated_forward(model, batch)
policy_chosen_logps = model_output["chosen_logps"]
policy_rejected_logps = model_output["rejected_logps"]
policy_chosen_logits = model_output["chosen_logits"]
policy_rejected_logits = model_output["rejected_logits"]
policy_chosen_logps_avg = model_output["chosen_logps_avg"]
reference_chosen_logps, reference_rejected_logps = self.compute_reference_log_probs(model, batch)
losses, chosen_rewards, rejected_rewards = self.compute_preference_loss(

View File

@@ -25,6 +25,7 @@ import torch
from transformers import Trainer
from trl import KTOTrainer
from trl.trainer import disable_dropout_in_model
from trl.trainer.utils import prepare_deepspeed
from typing_extensions import override
from ...extras.constants import IGNORE_INDEX
@@ -77,6 +78,13 @@ class CustomKTOTrainer(KTOTrainer):
self.desirable_weight = finetuning_args.kto_chosen_weight
self.undesirable_weight = finetuning_args.kto_rejected_weight
self.ftx_gamma = finetuning_args.pref_ftx
# trl
# Not all losses require a KL calculation
self.calculate_KL = True
if hasattr(self, "loss_type") and self.loss_type in ["apo_zero_unpaired"]:
self.calculate_KL = False
else:
self.loss_type = "kto"
Trainer.__init__(self, model=model, **kwargs)
self.model_accepts_loss_kwargs = False # overwrite trainer's default behavior
@@ -90,7 +98,7 @@ class CustomKTOTrainer(KTOTrainer):
if not (
getattr(ref_model, "is_loaded_in_8bit", False) or getattr(ref_model, "is_loaded_in_4bit", False)
): # quantized models are already set on the correct device
self.ref_model = self._prepare_deepspeed(self.ref_model)
self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
else:
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
self.ref_model.eval()

View File

@@ -11,14 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MCA (mcore_adapter) workflows for PT/SFT/DPO stages, aligned with LLaMA-Factory's workflow style."""
from __future__ import annotations
import functools
from collections.abc import Sequence
from copy import deepcopy
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Optional
from transformers import DataCollatorForSeq2Seq
from ...data import (
SFTDataCollatorWith4DAttentionMask,
@@ -44,11 +43,11 @@ from mcore_adapter.models import AutoConfig, AutoModel
from mcore_adapter.trainer import DPOTrainer as McaDPOTrainer
from mcore_adapter.trainer import McaTrainer
from mcore_adapter.trainer.dpo_config import DPOConfig
from mcore_adapter.training_args import Seq2SeqTrainingArguments as McaSeq2SeqTrainingArguments
if TYPE_CHECKING:
from transformers import DataCollatorForSeq2Seq, TrainerCallback
from mcore_adapter.training_args import Seq2SeqTrainingArguments as McaSeq2SeqTrainingArguments
from transformers import TrainerCallback
from ...hparams import DataArguments, FinetuningArguments, ModelArguments
@@ -76,7 +75,7 @@ def _data_collator_wrapper(data_collator: Any):
return wrapper
def _check_model_support(model_args: ModelArguments):
def _check_model_support(model_args: "ModelArguments"):
from transformers import AutoConfig as HfAutoConfig
config = HfAutoConfig.from_pretrained(
@@ -87,11 +86,11 @@ def _check_model_support(model_args: ModelArguments):
def run_pt(
model_args: ModelArguments,
data_args: DataArguments,
training_args: McaSeq2SeqTrainingArguments,
finetuning_args: FinetuningArguments,
callbacks: list[TrainerCallback] | None = None,
model_args: "ModelArguments",
data_args: "DataArguments",
training_args: "McaSeq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments",
callbacks: Optional[list["TrainerCallback"]] = None,
):
tokenizer_module = load_tokenizer(model_args)
tokenizer = tokenizer_module["tokenizer"]
@@ -104,10 +103,7 @@ def run_pt(
_check_model_support(model_args)
model = AutoModel.from_pretrained(model_args.model_name_or_path, training_args)
from transformers import DataCollatorForSeq2Seq
data_collator: DataCollatorForSeq2Seq = DataCollatorForSeq2Seq(
data_collator = DataCollatorForSeq2Seq(
tokenizer=tokenizer,
pad_to_multiple_of=8,
label_pad_token_id=IGNORE_INDEX,
@@ -142,11 +138,11 @@ def run_pt(
def run_sft(
model_args: ModelArguments,
data_args: DataArguments,
training_args: McaSeq2SeqTrainingArguments,
finetuning_args: FinetuningArguments,
callbacks: list[TrainerCallback] | None = None,
model_args: "ModelArguments",
data_args: "DataArguments",
training_args: "McaSeq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments",
callbacks: Optional[list["TrainerCallback"]] = None,
):
# align packing flags
# TODO: FIX SequencePacking
@@ -166,7 +162,7 @@ def run_sft(
model = AutoModel.from_pretrained(model_args.model_name_or_path, training_args)
# optional freezing for qwen2_vl, qwen2_5_vl
if getattr(model.config, "hf_model_type", None) in ["qwen2_vl", "qwen2_5_vl"]:
if getattr(model.config, "hf_model_type", None) in ["qwen2_vl", "qwen2_5_vl", "qwen3_vl"]:
params_to_freeze = []
if finetuning_args.freeze_vision_tower:
params_to_freeze.extend(["vision_model.blocks", "vision_model.patch_embed"])
@@ -220,11 +216,11 @@ def run_sft(
def run_dpo(
model_args: ModelArguments,
data_args: DataArguments,
training_args: McaSeq2SeqTrainingArguments,
finetuning_args: FinetuningArguments,
callbacks: list[TrainerCallback] | None = None,
model_args: "ModelArguments",
data_args: "DataArguments",
training_args: "McaSeq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments",
callbacks: Optional[list["TrainerCallback"]] = None,
):
tokenizer_module = load_tokenizer(model_args)
tokenizer = tokenizer_module["tokenizer"]

View File

@@ -33,12 +33,12 @@ from transformers.trainer_pt_utils import remove_dummy_checkpoint
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
from transformers.utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME
from trl import PPOConfig, PPOTrainer
from trl.core import PPODecorators, logprobs_from_logits
from trl import __version__ as trl_version
from trl.models.utils import unwrap_model_for_generation
from typing_extensions import override
from ...extras import logging
from ...extras.misc import AverageMeter, count_parameters, get_current_device, get_logits_processor
from ...extras.misc import AverageMeter, count_parameters, get_current_device, get_logits_processor, torch_gc
from ..callbacks import FixValueHeadModelCallback, SaveProcessorCallback
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
from .ppo_utils import dump_layernorm, get_rewards_from_server, replace_model, restore_layernorm
@@ -83,6 +83,19 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
if eval_dataset is not None:
raise NotImplementedError("PPOTrainer does not support eval dataset yet.")
# Check if TRL version is compatible (0.8.6 <= version <= 0.9.6)
try:
from transformers.utils.versions import require_version
require_version(
"trl>=0.8.6,<=0.9.6",
"Incompatible TRL version detected. LLaMA-Factory ppo requires TRL version >=0.8.6,<=0.9.6. "
f"Found version {trl_version}. Please install the correct version with: `pip install trl>=0.8.6,<=0.9.6`\n"
"To fix: run `DISABLE_VERSION_CHECK=1 llamafactory-cli train example_ppo.yaml`\n",
)
except ImportError as e:
raise e
backward_batch_size = training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps
ppo_config = PPOConfig(
model_name=model_args.model_name_or_path,
@@ -406,7 +419,6 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
return rewards.float().detach() # use fp32 type
@override
@PPODecorators.empty_device_cache()
def batched_forward_pass(
self,
model: "AutoModelForCausalLMWithValueHead",
@@ -420,6 +432,9 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
Subclass and override to inject custom behavior.
"""
from trl.core import logprobs_from_logits
torch_gc()
bs = len(queries)
fbs = self.config.mini_batch_size
all_logprobs = []

View File

@@ -108,7 +108,7 @@ def create_modelcard_and_push(
elif training_args.push_to_hub:
trainer.push_to_hub(**kwargs)
else:
trainer.create_model_card(license="other", **kwargs) # prevent from connecting to hub
Trainer.create_model_card(trainer, license="other", **kwargs) # prevent from connecting to hub
def create_ref_model(

View File

@@ -112,6 +112,13 @@ class ModelLoader:
model = PeftPlugin(self.args.peft_config.name)(model, self.args.peft_config, self.is_train)
if self.args.kernel_config is not None:
from ..plugins.model_plugins.kernels.interface import KernelPlugin
model = KernelPlugin(self.args.kernel_config.name)(
model=model, include_kernels=self.args.kernel_config.get("include_kernels")
)
return model

View File

@@ -0,0 +1,87 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The definition of base kernel class.
Init Phase:
1. Define base kernel class.
2. Define abstract methods.
"""
from abc import ABC, abstractmethod
from typing import Any
from ....accelerator.helper import DeviceType, get_current_accelerator
from ....utils.types import HFModel
class BaseKernel(ABC):
r"""Base class for all kernel implementations.
Subclasses must implement the abstract methods and define the required class attributes.
"""
_kernel_id: Any = "" # kernel ID, any hashable value to identify a kernel implementation
_device: DeviceType = DeviceType.CPU # "cuda", "npu", "cpu", etc.
@classmethod
def get_kernel_id(cls) -> str:
r"""Returns the unique identifier for the kernel."""
return cls._kernel_id
@classmethod
def get_device(cls) -> str:
r"""Returns the device type associated with the kernel (e.g., "cuda", "npu", "cpu")."""
return cls._device
@classmethod
def check_deps(cls) -> bool:
r"""Checks if the required dependencies for the kernel are available.
Returns:
bool: ``True`` if dependencies are met, ``False`` otherwise.
.. note::
In explicit mode, if a user specifies an implementation but this check fails,
it should raise an error instead of silently switching.
Kernels can override this method to implement custom dependency checks.
"""
if cls._device != get_current_accelerator().type:
return False
return True
@classmethod
@abstractmethod
def apply(cls, **kwargs) -> HFModel:
r"""Applies the kernel optimization to the model.
Args:
**kwargs: Arbitrary keyword arguments, usually containing the model instance and the kernel configuration.
Returns:
HFModel: The model with the kernel applied.
Raises:
RuntimeError: If the kernel dependencies are not met.
NotImplementedError: If the method is not implemented by the subclass.
Example:
>>> from llamafactory.v1.plugins.model_plugins.kernels.interface import apply_kernel
>>> model = HFModel(config=config)
>>> model = apply_kernel(model=model, kernel_id="npu_fused_moe")
"""
if not cls.check_deps():
raise RuntimeError(f"{cls.__name__} is not available but {cls.__name__} kernel was called.")
raise NotImplementedError

View File

@@ -1,23 +0,0 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from enum import Enum
class KernelType(str, Enum):
RMSNORM = "rmsnorm"
SWIGLU = "swiglu"
FLASH_ATTENTION = "flash_attention"
ROPE = "rope"
MOE = "moe"

View File

@@ -0,0 +1,132 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The definition of kernel interface.
Init Phase:
1. Scan all kernels.
2. Register default kernels.
3. Define kernel plugin.
"""
import importlib
from pathlib import Path
from ....utils.logging import get_logger
from ....utils.plugin import BasePlugin
from .registry import Registry
logger = get_logger(__name__)
def scan_all_kernels():
r"""Scan all kernels in the ``ops`` directory.
Scans the ``ops`` directory for all ``.py`` files and attempts to import them.
Importing triggers the :func:`~registry.register_kernel` decorator, which automatically registers the kernels.
Returns:
dict[str, type[BaseKernel]]: A dictionary of registered kernels.
.. note::
This function assumes that the ``ops`` directory is located in the same directory as this file.
It recursively searches for ``.py`` files and constructs the module path for import.
"""
ops_path = Path(__file__).parent / "ops"
if not ops_path.exists():
return
base_package = __package__
for file_path in ops_path.rglob("*.py"):
if file_path.name == "__init__.py":
continue
# calculate the relative path:
# file_path = .../kernels_v2/ops/mlp/npu_swiglu.py
# rel_path = ops/mlp/npu_swiglu.py
rel_path = file_path.relative_to(Path(__file__).parent)
# build module path:
module_name = ".".join(rel_path.parts)[:-3]
full_module_name = f"{base_package}.{module_name}"
try:
importlib.import_module(full_module_name)
except Exception as e:
logger.warning(f"[Kernel Registry] Failed to import {full_module_name} when loading kernels: {e}")
return Registry.get_registered_kernels()
default_kernels = scan_all_kernels()
def get_default_kernels():
r"""Get a list of default registered kernel IDs.
Returns:
list[str]: List of kernel IDs.
"""
return list(default_kernels.keys())
def apply_kernel(kernel_id: str, **kwargs):
r"""Applies a specific kernel to the model.
Args:
kernel_id (str): The ID of the kernel to apply.
**kwargs: Keyword arguments passed to the kernel application function.
Typically includes the model instance.
Returns:
HFModel: The model with applied kernel.
"""
kernel = default_kernels.get(kernel_id)
if kernel is None:
raise ValueError(f"Kernel {kernel_id} not found")
kernel.apply(**kwargs)
class KernelPlugin(BasePlugin):
r"""Plugin for managing kernel optimizations."""
pass
@KernelPlugin("auto").register
def apply_default_kernels(**kwargs):
r"""Applies all default registered kernels to the model.
Args:
**kwargs: Keyword arguments passed to the kernel application function.
Typically includes the model instance and the include_kernels configuration.
Returns:
HFModel: The model with applied kernels.
"""
if not kwargs.get("include_kernels"): # None/False/empty string
return kwargs.get("model")
elif kwargs.get("include_kernels") == "auto" or kwargs.get("include_kernels") is True: # True/auto
use_kernels = default_kernels.keys()
else:
use_kernels = kwargs.get("include_kernels").split(",") # "kernel_id1,kernel_id2,kernel_id3"
for kernel in use_kernels:
if kernel not in default_kernels:
raise ValueError(f"Kernel {kernel} not found")
apply_kernel(kernel, **kwargs)
return kwargs.get("model")

View File

@@ -12,22 +12,49 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""The definition of NPU fused MoE kernels.
Init Phase:
1. Define GMM functions.
2. Define NPU fused MoE functions.
3. Register NPU fused MoE kernel.
"""
import types
import torch
import torch.nn.functional as F
import torch_npu
from .....accelerator.helper import DeviceType, is_torch_npu_available
from .....utils.packages import is_transformers_version_greater_than
from .....utils.types import HFModel
from ..constants import KernelType
from ..registry import MetaMoEKernel
try:
import torch_npu
except ImportError:
pass
from ......accelerator.helper import DeviceType
from ......utils.packages import is_transformers_version_greater_than
from ......utils.types import HFModel
from ...base import BaseKernel
from ...registry import register_kernel
class GmmFunction(torch.autograd.Function):
r"""Custom autograd function for NPU Grouped Matrix Multiplication (GMM)."""
@staticmethod
def forward(ctx, x, weight, group_list):
r"""Performs the forward pass of Grouped Matrix Multiplication.
Args:
ctx: Context object to save tensors for backward pass.
x (Tensor): Input tensor.
weight (Tensor): Weight tensor.
group_list (list): List of group sizes.
Returns:
Tensor: The result of the grouped matrix multiplication.
"""
ctx.save_for_backward(x, weight)
ctx.group_list = group_list
@@ -38,6 +65,15 @@ class GmmFunction(torch.autograd.Function):
@staticmethod
def backward(ctx, grad_output):
r"""Performs the backward pass of Grouped Matrix Multiplication.
Args:
ctx: Context object containing saved tensors.
grad_output (Tensor): Gradient with respect to the output.
Returns:
tuple: Gradients with respect to input, weight, and None for group_list.
"""
input_tensor, weight = ctx.saved_tensors
group_list = ctx.group_list
@@ -58,8 +94,20 @@ class GmmFunction(torch.autograd.Function):
class HybridGmmFunction(torch.autograd.Function):
r"""Custom autograd function for Hybrid Grouped Matrix Multiplication on NPU."""
@staticmethod
def forward(ctx, num_experts, *args):
r"""Performs the forward pass of Hybrid GMM.
Args:
ctx: Context object to save tensors.
num_experts (int): Number of experts.
*args: Variable length argument list containing inputs and weights.
Returns:
tuple: The outputs of the grouped matrix multiplication.
"""
x_list = list(args[:num_experts])
weight_list = list(args[num_experts:])
@@ -76,6 +124,15 @@ class HybridGmmFunction(torch.autograd.Function):
@staticmethod
def backward(ctx, *grad_outputs):
r"""Performs the backward pass of Hybrid GMM.
Args:
ctx: Context object containing saved tensors.
*grad_outputs: Gradients with respect to the outputs.
Returns:
tuple: Gradients with respect to inputs and weights.
"""
saved_tensors = ctx.saved_tensors
num_experts = ctx.num_experts
split_sizes = ctx.split_sizes
@@ -119,10 +176,23 @@ class HybridGmmFunction(torch.autograd.Function):
class NpuMoeFused:
r"""Container for NPU fused MoE forward functions."""
@staticmethod
def npu_moe_experts_forward(
self, hidden_states: torch.Tensor, routing_weights: torch.Tensor, router_indices: torch.Tensor
) -> torch.Tensor:
r"""Forward pass for MoE experts using NPU fused operations.
Args:
self: The MoE layer instance.
hidden_states (Tensor): Input hidden states.
routing_weights (Tensor): Routing weights.
router_indices (Tensor): Router indices.
Returns:
Tensor: Output tensor after expert computation.
"""
batch_size = hidden_states.shape[0]
hidden_states = hidden_states.reshape(-1, self.hidden_size)
permuted_hidden_states, row_ids_map = torch_npu.npu_moe_token_permute(
@@ -138,6 +208,15 @@ class NpuMoeFused:
@staticmethod
def npu_moe_sparse_block_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
r"""Forward pass for sparse MoE block using NPU optimization.
Args:
self: The MoE sparse block instance.
hidden_states (Tensor): Input hidden states.
Returns:
Tensor: The routed output.
"""
batch_size = hidden_states.shape[0]
hidden_states = hidden_states.reshape(-1, self.hidden_size)
router_logits = self.gate(hidden_states)
@@ -151,8 +230,19 @@ class NpuMoeFused:
class Qwen3NpuMoeFused:
r"""Container for Qwen3 NPU fused MoE forward functions."""
@staticmethod
def qwen3moe_sparse_moe_block_forward(self, hidden_states: torch.Tensor):
r"""Forward pass for Qwen3 sparse MoE block using NPU fused operations.
Args:
self: The Qwen3 MoE block instance.
hidden_states (Tensor): Input hidden states.
Returns:
tuple: A tuple containing the next states and router logits.
"""
batch_size, sequence_length, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
@@ -206,14 +296,33 @@ if not is_transformers_version_greater_than("5.0.0"):
}
class NpuMoEFusedMoEKernel(MetaMoEKernel):
type = KernelType.MOE
device = DeviceType.NPU
@register_kernel
class NpuFusedMoEKernel(BaseKernel):
r"""NPU Fused MoE Kernel implementation."""
_kernel_id = "npu_fused_moe"
_device = DeviceType.NPU
@classmethod
def apply(cls, model, **kwargs) -> HFModel:
if not is_torch_npu_available():
return model
def apply(cls, **kwargs) -> HFModel:
r"""Applies the NPU fused MoE kernel to the model.
Args:
**kwargs: Keyword arguments containing the model.
Returns:
HFModel: The model with patched MoE forward functions.
Raises:
ValueError: If the model is not provided.
RuntimeError: If dependencies are not met.
"""
model = kwargs.get("model", None)
if model is None:
raise ValueError(f"HFModel instance is required for {cls.__name__}.")
if not cls.check_deps():
raise RuntimeError("torch_npu is not available but NpuMoEFusedMoEKernel was called.")
archs = getattr(model.config, "architectures", [])
target_moe_mapping = None

View File

@@ -12,36 +12,71 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""The definition of NPU fused SwiGLU kernels.
Init Phase:
1. Define SwiGLU forward functions.
2. Register NPU fused SwiGLU kernel.
"""
import re
import types
import torch
from .....accelerator.helper import DeviceType, is_torch_npu_available
from .....utils.types import HFModel
from ..constants import KernelType
from ..registry import MetaSwiGluKernel
from ......accelerator.helper import DeviceType
from ......utils.types import HFModel
from ...base import BaseKernel
from ...registry import register_kernel
def _npu_swiglu_forward(self, hidden_state):
try:
import torch_npu
except ImportError:
pass
def npu_swiglu_forward(self, hidden_state):
r"""SwiGLU forward pass for NPU.
Args:
self: The MLP layer instance.
hidden_state (Tensor): Input hidden state.
Returns:
Tensor: Output of SwiGLU.
"""
return self.down_proj(
torch_npu.npu_swiglu(torch.cat((self.gate_proj(hidden_state), self.up_proj(hidden_state)), dim=-1), dim=-1)
)
def _npu_swiglu_glm4_forward(self, hidden_states):
import torch_npu
r"""SwiGLU forward pass for GLM4 on NPU.
Args:
self: The GLM4 MLP layer instance.
hidden_states (Tensor): Input hidden states.
Returns:
Tensor: Output of SwiGLU.
"""
up_states = self.gate_up_proj(hidden_states)
gate, up_states = up_states.chunk(2, dim=-1)
return self.down_proj(torch_npu.npu_swiglu(torch.cat((gate, up_states), dim=-1), dim=-1))
def _npu_swiglu_gemma3ntext_forward(self, hidden_states):
import torch_npu
r"""SwiGLU forward pass for Gemma3nText on NPU.
Args:
self: The Gemma3nText MLP layer instance.
hidden_states (Tensor): Input hidden states.
Returns:
Tensor: Output of SwiGLU.
"""
gate_proj = self.gate_proj(hidden_states)
if self.activation_sparsity > 0.0:
gate_proj = self._gaussian_topk(gate_proj)
@@ -51,12 +86,11 @@ def _npu_swiglu_gemma3ntext_forward(self, hidden_states):
return down_proj
class NpuSwiGluKernel(MetaSwiGluKernel):
type = KernelType.SWIGLU
device = DeviceType.NPU
kernel = _npu_swiglu_forward
@register_kernel
class NpuSwiGluKernel(BaseKernel):
r"""NPU Kernel for fused SwiGLU activation."""
# Don't apply the kernel to the following modules
# just support apply to the following module layers
expect_modules = frozenset(
{
"Qwen3VLMoeTextMLP",
@@ -87,10 +121,29 @@ class NpuSwiGluKernel(MetaSwiGluKernel):
}
)
_kernel_id = "npu_fused_swiglu"
_device = DeviceType.NPU
@classmethod
def apply(cls, model, **kwargs) -> "HFModel":
if not is_torch_npu_available():
return model
def apply(cls, **kwargs) -> "HFModel":
r"""Applies the NPU fused SwiGLU kernel to the model.
Args:
**kwargs: Keyword arguments containing the model.
Returns:
HFModel: The model with patched SwiGLU forward functions.
Raises:
ValueError: If the model is not provided.
RuntimeError: If dependencies are not met.
"""
model = kwargs.get("model", None)
if model is None:
raise ValueError(f"HFModel instance is required for {cls.__name__}.")
if not cls.check_deps():
raise RuntimeError("torch_npu is not available but NpuSwiGluKernel was called.")
# Mapping of specific mlp modules to their corresponding kernel implementations
kernel_mapping = {
@@ -109,7 +162,7 @@ class NpuSwiGluKernel(MetaSwiGluKernel):
):
# Bind function as an instance method to preserve `self` semantics
# and replace the original forward
kernel_func = kernel_mapping.get(module.__class__.__name__, _npu_swiglu_forward)
kernel_func = kernel_mapping.get(module.__class__.__name__, npu_swiglu_forward)
module.forward = types.MethodType(kernel_func, module)
return model

View File

@@ -11,40 +11,49 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The definition of NPU fused RMSNorm kernels.
Init Phase:
1. Define RMSNorm forward function.
2. Register NPU fused RMSNorm kernel.
"""
import re
import types
from .....accelerator.helper import DeviceType, is_torch_npu_available
from .....utils.types import HFModel
from ..constants import KernelType
from ..registry import MetaRMSNormKernel
from ......accelerator.helper import DeviceType
from ......utils.types import HFModel
from ...base import BaseKernel
from ...registry import register_kernel
def _npu_rms_forward(self, hidden_states):
"""NPU forward implementation for RMSNorm.
def npu_rms_norm_forward(self, hidden_states):
r"""NPU forward implementation for RMSNorm.
Args:
self: RMSNorm module instance with `weight` and `variance_epsilon`.
hidden_states: Input hidden states tensor, same shape as the baseline.
hidden_states (Tensor): Input hidden states tensor, same shape as the baseline.
Returns:
Normalized tensor consistent with the baseline RMSNorm behavior.
Tensor: Normalized tensor consistent with the baseline RMSNorm behavior.
"""
import torch_npu
return torch_npu.npu_rms_norm(hidden_states, self.weight, epsilon=self.variance_epsilon)[0]
class NpuRMSNormKernel(MetaRMSNormKernel):
"""NPU kernel wrapper for RMSNorm that applies the replacement within a model."""
@register_kernel
class NpuRMSNormKernel(BaseKernel):
r"""NPU kernel wrapper for RMSNorm that applies the replacement within a model."""
type = KernelType.RMSNORM
device = DeviceType.NPU
kernel = _npu_rms_forward
_kernel_id = "npu_fused_rmsnorm"
_device = DeviceType.NPU
@classmethod
def apply(cls, model, **kwargs) -> HFModel:
"""Iterate the model and apply NPU-optimized forward to matched RMSNorm modules.
def apply(cls, **kwargs) -> "HFModel":
r"""Iterate the model and apply NPU-optimized forward to matched RMSNorm modules.
Key points:
- Match modules whose class name contains "RMSNorm" (case-insensitive).
@@ -52,10 +61,23 @@ class NpuRMSNormKernel(MetaRMSNormKernel):
replace the original `forward`.
- Do not modify weights, hyperparameters, or module structure to ensure
numerical behavior and interface consistency.
"""
if not is_torch_npu_available():
return model
Args:
**kwargs: Keyword arguments containing the model.
Returns:
HFModel: The model with NPU fused RMSNorm.
Raises:
RuntimeError: If torch_npu is not available.
ValueError: If the model is not provided.
"""
model = kwargs.get("model")
if model is None:
raise ValueError(f"HFModel instance is required for {cls.__name__}.")
if not cls.check_deps():
raise RuntimeError(f"torch_npu is not available but {cls.__name__} was called.")
rms_norm_pattern = re.compile("RMSNorm", re.IGNORECASE)
for name, module in model.named_modules():
@@ -63,6 +85,6 @@ class NpuRMSNormKernel(MetaRMSNormKernel):
if re.search(rms_norm_pattern, module.__class__.__name__):
# Bind function as an instance method to preserve `self` semantics
# and replace the original forward
module.forward = types.MethodType(cls.kernel, module)
module.forward = types.MethodType(npu_rms_norm_forward, module)
return model

View File

@@ -0,0 +1,146 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The definition of NPU fused RoPE kernels.
Init Phase:
1. Define RoPE forward functions.
2. Register NPU fused RoPE kernel.
"""
import sys
import torch
from ......accelerator.helper import DeviceType
from ......utils.logging import get_logger
from ......utils.types import HFModel
from ...base import BaseKernel
from ...registry import register_kernel
logger = get_logger(__name__)
try:
import torch_npu
except ImportError:
pass
def _apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
r"""Applies Rotary Position Embedding to the query and key tensors using NPU optimization.
Args:
q (Tensor): Query tensor.
k (Tensor): Key tensor.
cos (Tensor): Cosine part of embedding.
sin (Tensor): Sine part of embedding.
position_ids (Tensor, optional): Position IDs. Default: ``None``.
unsqueeze_dim (int): Dimension to unsqueeze cos and sin. Default: 1.
Returns:
tuple: (q_embed, k_embed) The embedded query and key tensors.
"""
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
q_embed = torch_npu.npu_rotary_mul(q, cos, sin)
k_embed = torch_npu.npu_rotary_mul(k, cos, sin)
return q_embed, k_embed
def _apply_multimodal_rotary_pos_emb_qwen25_vl(q, k, cos, sin, mrope_section, unsqueeze_dim=1):
r"""Applies Rotary Position Embedding with multimodal sections (Qwen2-VL) on NPU.
Args:
q (Tensor): Query tensor.
k (Tensor): Key tensor.
cos (Tensor): Cosine part of embedding.
sin (Tensor): Sine part of embedding.
mrope_section (Tensor): Multimodal RoPE section.
unsqueeze_dim (int): Dimension to unsqueeze cos and sin. Default: 1.
Returns:
tuple: (q_embed, k_embed) The embedded query and key tensors.
"""
mrope_section = mrope_section * 2
cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(
unsqueeze_dim
)
sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(
unsqueeze_dim
)
q_embed = torch_npu.npu_rotary_mul(q, cos, sin)
k_embed = torch_npu.npu_rotary_mul(k, cos, sin)
return q_embed, k_embed
@register_kernel
class NpuRoPEKernel(BaseKernel):
r"""NPU Kernel for Rotary Position Embedding."""
_kernel_id = "npu_fused_rope"
_device = DeviceType.NPU
@classmethod
def apply(cls, **kwargs) -> "HFModel":
r"""Apply RoPE acceleration by monkey-patching `apply_rotary_pos_emb`.
This function iterates through the model's modules to find attention layers,
identifies the module where they are defined, and replaces the original
`apply_rotary_pos_emb` function in that module's namespace with the
NPU-accelerated version from this file.
Args:
**kwargs: Keyword arguments containing the model.
Returns:
HFModel: The model with patched RoPE functions.
Raises:
RuntimeError: If dependencies are not met.
ValueError: If the model is not provided.
"""
if not cls.check_deps():
raise RuntimeError(f"torch_npu is not available but {cls.__name__} was called.")
model = kwargs.get("model", None)
if model is None:
raise ValueError(f"HFModel instance is required for {cls.__name__}.")
_modules = set()
for module in model.modules():
if "Attention" in module.__class__.__name__:
module_name = module.__class__.__module__
if module_name in _modules:
continue
try:
target_module = sys.modules[module_name]
if hasattr(target_module, "apply_rotary_pos_emb"):
if getattr(target_module, "apply_rotary_pos_emb") is not _apply_rotary_pos_emb:
setattr(target_module, "apply_rotary_pos_emb", _apply_rotary_pos_emb)
_modules.add(module_name)
if hasattr(target_module, "apply_multimodal_rotary_pos_emb"):
if (
getattr(target_module, "apply_multimodal_rotary_pos_emb")
is not _apply_multimodal_rotary_pos_emb_qwen25_vl
):
setattr(
target_module,
"apply_multimodal_rotary_pos_emb",
_apply_multimodal_rotary_pos_emb_qwen25_vl,
)
_modules.add(module_name)
except Exception as e:
logger.warning_rank0_once(f"Failed to apply RoPE kernel to module {module_name}: {e}")
return model

View File

@@ -12,247 +12,86 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, ABCMeta, abstractmethod
from collections.abc import Callable
from typing import Any, Optional
"""The definition of kernel registry.
from ....accelerator.helper import DeviceType, get_current_accelerator
from ....utils.types import HFModel
from .constants import KernelType
Init Phase:
1. Define kernel registry.
2. Register kernels.
"""
from typing import Optional
from ....accelerator.helper import get_current_accelerator
from .base import BaseKernel
class KernelRegistry:
_instance: Optional["KernelRegistry"] = None
_initialized: bool = False
def __new__(cls, *args: Any, **kwargs: Any) -> "KernelRegistry":
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def __init__(self) -> None:
if self._initialized:
return
self._registry: dict[KernelType, dict[DeviceType, Callable[..., Any]]] = {}
self._initialized = True
def register(
self, kernel_type: KernelType, device_type: DeviceType, kernel_impl: Callable[..., Any] | None
) -> None:
"""Register a kernel implementation.
Args:
kernel_type: the type of the kernel (e.g., KernelType.FLASH_ATTENTION).
device_type: the device type the kernel is adapted to (e.g., DeviceType.CUDA).
kernel_impl: the actual kernel function or class.
"""
if kernel_type not in self._registry:
self._registry[kernel_type] = {}
if device_type in self._registry[kernel_type]:
print(f"Warning: Overwriting kernel for {kernel_type.name} on {device_type.name}.")
self._registry[kernel_type][device_type] = kernel_impl
print(f"Registered kernel {kernel_type.name} for device {device_type.name}.")
def get_kernel(self, kernel_type: KernelType, device_type: DeviceType) -> Callable[..., Any] | None:
return self._registry.get(kernel_type, {}).get(device_type)
__all__ = ["Registry", "register_kernel"]
KERNEL_REGISTRY = KernelRegistry()
class Registry:
r"""Registry for managing kernel implementations.
class AutoRegisterKernelMeta(ABCMeta):
"""Metaclass that automatically registers kernel classes upon creation.
This metaclass checks if a newly created class has both `type` and `device`
attributes defined. If so, it automatically registers the kernel in the
global KERNEL_REGISTRY, eliminating the need for manual registration.
To disable auto-registration for a specific class, set `auto_register = False`.
Storage structure: ``{ "kernel_id": Class }``
"""
def __new__(mcs, name, bases, namespace, **kwargs):
cls = super().__new__(mcs, name, bases, namespace, **kwargs)
# Check if auto-registration is disabled
auto_register = namespace.get("auto_register", True)
# Only auto-register if the class has both type and device attributes defined
# and they are not None (skip base classes like MetaKernel itself)
# and auto_register is True
kernel_type = namespace.get("type")
device_type = namespace.get("device")
if auto_register and kernel_type is not None and device_type is not None:
# Auto-register this kernel
KERNEL_REGISTRY.register(kernel_type, device_type, cls)
return cls
class MetaKernel(ABC, metaclass=AutoRegisterKernelMeta):
"""Base class for all kernel implementations.
Subclasses are automatically registered when they define both `type` and `device`
attributes. To disable auto-registration, set `auto_register = False`.
Attributes:
type: The kernel type (e.g., KernelType.RMSNORM). Must be set in subclasses.
device: The device type (e.g., DeviceType.NPU). Must be set in subclasses.
kernel: The actual kernel function or implementation.
auto_register: Set to False to disable automatic registration (default: True).
"""
type: KernelType | None = None
device: DeviceType | None = None
kernel: Callable | None = None
_kernels: dict[str, type[BaseKernel]] = {}
@classmethod
@abstractmethod
def apply(cls, model: HFModel, **kwargs) -> HFModel:
"""Apply the kernel to the model.
def register(cls, kernel_cls: type[BaseKernel]):
r"""Decorator to register a kernel class.
This method should check if the kernel can be applied (e.g., dependencies
are installed, target modules exist) and perform the kernel replacement.
The class must inherit from :class:`BaseKernel` and specify ``_kernel_id`` and ``_device`` attributes.
Args:
model: The HuggingFace model to optimize.
**kwargs: Additional arguments for kernel application.
kernel_cls (type[BaseKernel]): The kernel class to register.
Returns:
The optimized model (may be the same object with modifications).
type[BaseKernel]: The registered kernel class.
Raises:
TypeError: If the class does not inherit from :class:`BaseKernel`.
ValueError: If the kernel ID is missing or already registered.
"""
raise NotImplementedError
if not issubclass(kernel_cls, BaseKernel):
raise TypeError(f"Class {kernel_cls} must inherit from BaseKernel")
kernel_id = kernel_cls.get_kernel_id()
device = kernel_cls.get_device()
# The device type of the current accelerator does not match the device type required by the kernel, skip registration
if device != get_current_accelerator().type:
return
if not kernel_id:
raise ValueError(f"Kernel ID (_kernel_id) is needed for {kernel_cls} to register")
if kernel_id in cls._kernels:
raise ValueError(f"{kernel_id} already registered! The registered kernel is {cls._kernels[kernel_id]}")
cls._kernels[kernel_id] = kernel_cls
return kernel_cls
class MetaFlashAttentionKernel(MetaKernel):
@classmethod
def apply(cls, model: HFModel, **kwargs) -> HFModel:
raise NotImplementedError
def get(cls, kernel_id: str) -> Optional[type[BaseKernel]]:
r"""Retrieves a registered kernel implementation by its ID.
Args:
kernel_id (str): The ID of the kernel to retrieve.
Returns:
Optional[type[BaseKernel]]: The kernel class if found, else ``None``.
"""
return cls._kernels.get(kernel_id)
class MetaRMSNormKernel(MetaKernel):
@classmethod
def apply(cls, model: HFModel, **kwargs) -> HFModel:
raise NotImplementedError
def get_registered_kernels(cls) -> dict[str, type[BaseKernel]]:
r"""Returns a dictionary of all registered kernels.
Returns:
dict[str, type[BaseKernel]]: Dictionary mapping kernel IDs to kernel classes.
"""
return cls._kernels
class MetaSwiGluKernel(MetaKernel):
@classmethod
def apply(cls, model: HFModel, **kwargs) -> HFModel:
raise NotImplementedError
class MetaRoPEKernel(MetaKernel):
@classmethod
def apply(cls, model: HFModel, **kwargs) -> HFModel:
raise NotImplementedError
class MetaMoEKernel(MetaKernel):
@classmethod
def apply(cls, model: HFModel, **kwargs) -> HFModel:
raise NotImplementedError
def _ensure_kernels_loaded() -> None:
"""Ensure all kernel implementations are imported and registered.
This function dynamically imports all kernel implementation modules to trigger
their auto-registration. Python's module system ensures each module is only
executed once (cached in sys.modules), so repeated calls are safe and fast.
"""
# List of kernel module paths to import
kernel_modules = [
"rms_norm.npu_rms_norm",
"rope.npu_rope",
"mlp.npu_swiglu",
"mlp.npu_fused_moe",
# Add new kernel modules here as they are created
]
# Import each module to trigger kernel registration
# Python's import system caches modules, so this is fast on subsequent calls
for module_name in kernel_modules:
try:
__import__(f"{__package__}.{module_name}", fromlist=["*"])
except ImportError:
# Silently ignore import errors (e.g., missing dependencies like torch_npu)
pass
def discover_kernels(model: HFModel = None) -> list[type[MetaKernel]]:
"""Discover and return all kernel classes registered for the current device.
This function inspects the runtime environment (device type) and returns
all MetaKernel classes registered for that device. Each kernel's `apply()`
method is responsible for checking if it can actually be applied (e.g.,
required dependencies are installed, target modules exist in the model).
The function automatically discovers all kernels registered in KERNEL_REGISTRY
without requiring manual enumeration. On first call, it dynamically imports
all kernel implementation modules to trigger their auto-registration.
Args:
model: The HuggingFace model to apply kernels to.
TODO: implement the kernel route detection logic by model structure.
Returns:
A list of MetaKernel classes available for the current device.
"""
# Ensure all kernel modules are imported to trigger registration
_ensure_kernels_loaded()
discovered_kernels: list[type[MetaKernel]] = []
# Detect current device type
accelerator = get_current_accelerator()
try:
device_type = DeviceType(accelerator.type)
except ValueError:
# Unknown device type, return empty list
return discovered_kernels
# Skip CPU as it typically doesn't have optimized kernels
if device_type == DeviceType.CPU:
return discovered_kernels
# Iterate through registry and collect all kernels for current device
for devices in KERNEL_REGISTRY._registry.values():
kernel_cls = devices.get(device_type)
if kernel_cls is not None:
discovered_kernels.append(kernel_cls)
return discovered_kernels
def apply_kernel(model: HFModel, kernel: type[MetaKernel] | Any, /, **kwargs) -> "HFModel":
"""Call the MetaKernel's `apply` to perform the replacement.
Corresponding replacement logic is maintained inside each kernel; the only
requirement is that `apply` returns the replaced model.
Example:
from transformers import AutoModelForCausalLM
from .rms_norm.npu_rms_norm import NpuRMSNormKernel
model = AutoModelForCausalLM.from_pretrained("qwen/qwen2.5-0.5B")
model = apply_kernel(model, NpuRMSNormKernel)
"""
if not issubclass(kernel, MetaKernel):
raise ValueError(f"{kernel} must be a MetaKernel instance.")
if kernel.device != get_current_accelerator().type:
raise ValueError(f"{kernel} must be applied to {kernel.device} device, got {get_current_accelerator().type}.")
return kernel.apply(model, **kwargs)
def apply_available_kernels(model: HFModel, **kwargs) -> "HFModel":
"""Apply all available kernels to the model."""
for kernel in discover_kernels(model):
model = apply_kernel(model, kernel, **kwargs)
return model
# export decorator alias
register_kernel = Registry.register

View File

@@ -1,122 +0,0 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import torch
from .....accelerator.helper import DeviceType, is_torch_npu_available
from .....utils.types import HFModel
from ..constants import KernelType
from ..registry import MetaRoPEKernel
def _apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors."""
import torch_npu
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
q_embed = torch_npu.npu_rotary_mul(q, cos, sin)
k_embed = torch_npu.npu_rotary_mul(k, cos, sin)
return q_embed, k_embed
def _apply_multimodal_rotary_pos_emb_qwen25_vl(q, k, cos, sin, mrope_section, unsqueeze_dim=1):
"""Applies Rotary Position Embedding with multimodal sections (Qwen2-VL)."""
import torch_npu
mrope_section = mrope_section * 2
cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(
unsqueeze_dim
)
sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(
unsqueeze_dim
)
q_embed = torch_npu.npu_rotary_mul(q, cos, sin)
k_embed = torch_npu.npu_rotary_mul(k, cos, sin)
return q_embed, k_embed
class NpuRoPEKernel(MetaRoPEKernel):
type = KernelType.ROPE
device = DeviceType.NPU
kernel = _apply_rotary_pos_emb
@classmethod
def apply(cls, model, **kwargs) -> "HFModel":
"""Apply RoPE acceleration by monkey-patching `apply_rotary_pos_emb`.
This function iterates through the model's modules to find attention layers,
identifies the module where they are defined, and replaces the original
`apply_rotary_pos_emb` function in that module's namespace with the
NPU-accelerated version from this file.
"""
if not is_torch_npu_available():
return model
_modules = set()
for module in model.modules():
if "Attention" in module.__class__.__name__:
module_name = module.__class__.__module__
if module_name in _modules:
continue
try:
target_module = sys.modules[module_name]
if hasattr(target_module, "apply_rotary_pos_emb"):
if getattr(target_module, "apply_rotary_pos_emb") is not cls.kernel:
setattr(target_module, "apply_rotary_pos_emb", cls.kernel)
_modules.add(module_name)
except Exception:
pass
return model
class NpuQwen2VLRoPEKernel(MetaRoPEKernel):
"""Qwen2-VL specific RoPE kernel - not auto-registered.
This kernel is for specific models (Qwen2-VL) and should be manually
applied when needed rather than auto-discovered.
"""
type = KernelType.ROPE
device = DeviceType.NPU
kernel = _apply_multimodal_rotary_pos_emb_qwen25_vl
auto_register = False # Disable auto-registration for model-specific kernel
@classmethod
def apply(cls, model, **kwargs) -> "HFModel":
"""Apply RoPE acceleration by monkey-patching `apply_rotary_pos_emb`.
This function iterates through the model's modules to find attention layers,
identifies the module where they are defined, and replaces the original
`apply_rotary_pos_emb` function in that module's namespace with the
NPU-accelerated version from this file.
"""
_modules = set()
for module in model.modules():
if "Attention" in module.__class__.__name__:
module_name = module.__class__.__module__
if module_name in _modules:
continue
try:
target_module = sys.modules[module_name]
if hasattr(target_module, "apply_multimodal_rotary_pos_emb"):
if getattr(target_module, "apply_multimodal_rotary_pos_emb") is not cls.kernel:
setattr(target_module, "apply_multimodal_rotary_pos_emb", cls.kernel)
_modules.add(module_name)
except Exception:
pass
return model

View File

@@ -18,8 +18,11 @@ Contains shared fixtures, pytest configuration, and custom markers.
"""
import os
from typing import Optional
import pytest
import torch
import torch.distributed as dist
from pytest import Config, FixtureRequest, Item, MonkeyPatch
from llamafactory.extras.misc import get_current_device, get_device_count, is_env_enabled
@@ -70,7 +73,7 @@ def _handle_slow_tests(items: list[Item]):
item.add_marker(skip_slow)
def _get_visible_devices_env() -> str | None:
def _get_visible_devices_env() -> Optional[str]:
"""Return device visibility env var name."""
if CURRENT_DEVICE == "cuda":
return "CUDA_VISIBLE_DEVICES"
@@ -118,6 +121,14 @@ def pytest_collection_modifyitems(config: Config, items: list[Item]):
_handle_device_visibility(items)
@pytest.fixture(autouse=True)
def _cleanup_distributed_state():
"""Cleanup distributed state after each test."""
yield
if dist.is_initialized():
dist.destroy_process_group()
@pytest.fixture(autouse=True)
def _manage_distributed_env(request: FixtureRequest, monkeypatch: MonkeyPatch) -> None:
"""Set environment variables for distributed tests if specific devices are requested."""
@@ -145,6 +156,10 @@ def _manage_distributed_env(request: FixtureRequest, monkeypatch: MonkeyPatch) -
monkeypatch.setenv(env_key, visible_devices[0] if visible_devices else "0")
else:
monkeypatch.setenv(env_key, "0")
if CURRENT_DEVICE == "cuda":
monkeypatch.setattr(torch.cuda, "device_count", lambda: 1)
elif CURRENT_DEVICE == "npu":
monkeypatch.setattr(torch.npu, "device_count", lambda: 1)
@pytest.fixture

View File

@@ -0,0 +1,71 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pathlib
import sys
from unittest.mock import patch
from llamafactory.v1.config.arg_parser import get_args
def test_get_args_from_yaml(tmp_path: pathlib.Path):
config_yaml = """
### model
model: "llamafactory/tiny-random-qwen2.5"
trust_remote_code: true
use_fast_processor: true
model_class: "llm"
kernel_config:
name: "auto"
include_kernels: "auto" # choice: null/true/false/auto/kernel_id1,kernel_id2,kernel_id3, default is null
peft_config:
name: "lora"
lora_rank: 0.8
quant_config: null
### data
dataset: "llamafactory/tiny-supervised-dataset"
cutoff_len: 2048
### training
output_dir: "outputs/test_run"
micro_batch_size: 1
global_batch_size: 1
learning_rate: 1.0e-4
bf16: false
dist_config: null
### sample
sample_backend: "hf"
max_new_tokens: 128
"""
config_file = tmp_path / "config.yaml"
config_file.write_text(config_yaml, encoding="utf-8")
test_argv = ["test_args_parser.py", str(config_file)]
with patch.object(sys, "argv", test_argv):
data_args, model_args, training_args, sample_args = get_args()
assert training_args.output_dir == "outputs/test_run"
assert training_args.micro_batch_size == 1
assert training_args.global_batch_size == 1
assert training_args.learning_rate == 1.0e-4
assert training_args.bf16 is False
assert training_args.dist_config is None
assert model_args.model == "llamafactory/tiny-random-qwen2.5"
assert model_args.kernel_config.name == "auto"
assert model_args.kernel_config.get("include_kernels") == "auto"
assert model_args.peft_config.name == "lora"
assert model_args.peft_config.get("lora_rank") == 0.8

View File

@@ -18,8 +18,10 @@ Contains shared fixtures, pytest configuration, and custom markers.
"""
import os
import sys
import pytest
import torch
from pytest import Config, FixtureRequest, Item, MonkeyPatch
from llamafactory.v1.accelerator.helper import get_current_accelerator, get_device_count
@@ -139,9 +141,21 @@ def _manage_distributed_env(request: FixtureRequest, monkeypatch: MonkeyPatch) -
devices_str = ",".join(str(i) for i in range(required))
monkeypatch.setenv(env_key, devices_str)
# add project root dir to path for mp run
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
if project_root not in sys.path:
sys.path.insert(0, project_root)
os.environ["PYTHONPATH"] = project_root + os.pathsep + os.environ.get("PYTHONPATH", "")
else: # non-distributed test
if old_value:
visible_devices = [v for v in old_value.split(",") if v != ""]
monkeypatch.setenv(env_key, visible_devices[0] if visible_devices else "0")
else:
monkeypatch.setenv(env_key, "0")
if CURRENT_DEVICE == "cuda":
monkeypatch.setattr(torch.cuda, "device_count", lambda: 1)
elif CURRENT_DEVICE == "npu":
monkeypatch.setattr(torch.npu, "device_count", lambda: 1)

View File

@@ -14,7 +14,7 @@
import torch
from llamafactory.v1.config.model_args import ModelArguments
from llamafactory.v1.config.model_args import ModelArguments, PluginConfig
from llamafactory.v1.core.model_loader import ModelLoader
@@ -29,5 +29,23 @@ def test_tiny_qwen():
assert model_loader.model.dtype == torch.bfloat16
def test_tiny_qwen_with_kernel_plugin():
from transformers import Qwen2ForCausalLM
from llamafactory.v1.plugins.model_plugins.kernels.ops.rms_norm.npu_rms_norm import npu_rms_norm_forward
model_args = ModelArguments(
model="llamafactory/tiny-random-qwen2.5", kernel_config=PluginConfig(name="auto", include_kernels="auto")
)
model_loader = ModelLoader(model_args)
# test enable apply kernel plugin
if hasattr(torch, "npu"):
assert model_loader.model.model.layers[0].input_layernorm.forward.__code__ == npu_rms_norm_forward.__code__
else:
assert model_loader.model.model.layers[0].input_layernorm.forward.__code__ != npu_rms_norm_forward.__code__
assert isinstance(model_loader.model, Qwen2ForCausalLM)
if __name__ == "__main__":
test_tiny_qwen()
test_tiny_qwen_with_kernel_plugin()

View File

@@ -12,16 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
from unittest.mock import MagicMock, patch
import pytest
from transformers import AutoModelForCausalLM
from llamafactory.v1.accelerator.helper import get_current_accelerator
from llamafactory.v1.plugins.model_plugins.kernels.mlp import npu_swiglu
from llamafactory.v1.plugins.model_plugins.kernels.registry import apply_available_kernels, apply_kernel
from llamafactory.v1.plugins.model_plugins.kernels.rms_norm import npu_rms_norm
from llamafactory.v1.plugins.model_plugins.kernels.rope import npu_rope
@pytest.fixture(autouse=True)
@@ -29,24 +26,29 @@ def clear_accelerator_cache():
get_current_accelerator.cache_clear()
def reload_kernels():
"""Helper to reload kernel modules to respect mocked accelerator."""
# Unload kernel interface and registry
keys_to_remove = [k for k in sys.modules if k.startswith("llamafactory.v1.plugins.model_plugins.kernels")]
for k in keys_to_remove:
del sys.modules[k]
@patch("torch.accelerator.current_accelerator")
def test_apply_kernel(mock_get_accelerator: MagicMock):
mock_device = MagicMock()
setattr(mock_device, "type", "npu")
mock_get_accelerator.return_value = mock_device
# Force reload of kernels with mocked accelerator
reload_kernels()
from llamafactory.v1.plugins.model_plugins.kernels.interface import apply_default_kernels
model = AutoModelForCausalLM.from_pretrained("llamafactory/tiny-random-qwen2.5")
original_rmsnorm_forward = model.model.layers[0].input_layernorm.forward
original_swiglu_forward = model.model.layers[0].mlp.forward
apply_kernel(model, npu_rope.NpuRoPEKernel)
model = apply_kernel(model, npu_rms_norm.NpuRMSNormKernel)
assert model.model.layers[0].input_layernorm is not original_rmsnorm_forward
model = apply_kernel(model, npu_swiglu.NpuSwiGluKernel)
assert model.model.layers[0].mlp.forward is not original_swiglu_forward
model = apply_default_kernels(model=model, include_kernels="npu_fused_rmsnorm")
assert model.model.layers[0].input_layernorm.forward.__func__ is not original_rmsnorm_forward.__func__
assert model.model.layers[0].mlp.forward.__func__ is original_swiglu_forward.__func__
@patch("torch.accelerator.current_accelerator")
@@ -56,12 +58,15 @@ def test_apply_all_kernels(mock_get_accelerator: MagicMock):
setattr(mock_device, "type", "npu")
mock_get_accelerator.return_value = mock_device
# Force reload of kernels with mocked accelerator
reload_kernels()
from llamafactory.v1.plugins.model_plugins.kernels.interface import apply_default_kernels
model = AutoModelForCausalLM.from_pretrained("llamafactory/tiny-random-qwen2.5")
original_rmsnorm_forward = model.model.layers[0].input_layernorm.forward
original_swiglu_forward = model.model.layers[0].mlp.forward
model = apply_available_kernels(model)
assert model.model.layers[0].input_layernorm is not original_rmsnorm_forward
assert model.model.layers[0].mlp.forward is not original_swiglu_forward
model = apply_default_kernels(model=model, include_kernels=True)
assert model.model.layers[0].input_layernorm.forward.__func__ is not original_rmsnorm_forward.__func__
assert model.model.layers[0].mlp.forward.__func__ is not original_swiglu_forward.__func__