Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2026-02-27 00:05:58 +08:00)

Compare commits: eceec8ab69 ... v0.9.4 (11 commits)
| SHA1 |
|---|
| 95ac3f2373 |
| 000526908a |
| c8d7e85b3e |
| 16735b9e35 |
| 4e1d69579a |
| 1857fbdd6b |
| bb1ba31005 |
| e97d0474fb |
| 3f0c3dc84d |
| c107cc22d0 |
| 7ef1fba34a |
.github/workflows/docker.yml (2 changes)

@@ -72,7 +72,7 @@ jobs:
           password: ${{ secrets.DOCKERHUB_TOKEN }}

       - name: Login to Quay
-        if: ${{ github.event_name != 'pull_request' && matrix.device == 'npu'}}
+        if: ${{ github.event_name != 'pull_request' && startsWith(matrix.device, 'npu') }}
         uses: docker/login-action@v3
         with:
           registry: quay.io

.github/workflows/tests.yml (10 changes)

@@ -27,23 +27,23 @@ jobs:
         python:
           - "3.11"
           - "3.12"
-          # - "3.13" # enable after trl is upgraded
+          - "3.13"
         os:
           - "ubuntu-latest"
           - "windows-latest"
           - "macos-latest"
         transformers:
-          - null
+          - ""
         include: # test backward compatibility
-          - python: "3.11"
-            os: "ubuntu-latest"
-            transformers: "4.49.0"
           - python: "3.11"
             os: "ubuntu-latest"
             transformers: "4.51.0"
           - python: "3.11"
             os: "ubuntu-latest"
             transformers: "4.53.0"
+          - python: "3.11"
+            os: "ubuntu-latest"
+            transformers: "4.55.0"

     runs-on: ${{ matrix.os }}

.github/workflows/tests_cuda.yml (new file, 88 lines)

@@ -0,0 +1,88 @@
+name: tests_cuda
+
+on:
+  workflow_dispatch:
+  push:
+    branches:
+      - "main"
+    paths:
+      - "**/*.py"
+      - "pyproject.toml"
+      - "Makefile"
+      - ".github/workflows/*.yml"
+  pull_request:
+    branches:
+      - "main"
+    paths:
+      - "**/*.py"
+      - "pyproject.toml"
+      - "Makefile"
+      - ".github/workflows/*.yml"
+
+jobs:
+  tests:
+    strategy:
+      fail-fast: false
+      matrix:
+        python:
+          - "3.11"
+        os:
+          - "linux-x86_64-gpu-2"
+
+    runs-on: ${{ matrix.os }}
+
+    concurrency:
+      group: ${{ github.workflow }}-${{ github.ref }}-${{ matrix.os }}-${{ matrix.python }}
+      cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
+
+    steps:
+      - name: Checkout
+        uses: actions/checkout@v4
+
+      - name: Install uv
+        uses: astral-sh/setup-uv@v7
+        with:
+          python-version: ${{ matrix.python }}
+          github-token: ${{ github.token }}
+          enable-cache: false
+
+      - name: Check GPU Status
+        run: nvidia-smi
+
+      - name: Install dependencies
+        run: |
+          uv venv
+          uv pip install -e ".[dev]"
+
+      - name: Cache HuggingFace models
+        id: hf-hub-cache
+        uses: actions/cache@v4
+        with:
+          path: ${{ runner.temp }}/huggingface
+          key: hf-cache-${{ runner.os }}-${{ hashFiles('tests/version.txt') }}
+
+      - name: Check quality
+        run: |
+          make style && make quality
+        env:
+          UV_NO_SYNC: 1
+
+      - name: Check license
+        run: |
+          make license
+        env:
+          UV_NO_SYNC: 1
+
+      - name: Check build
+        run: |
+          make build
+        env:
+          UV_NO_SYNC: 1
+
+      - name: Test with pytest
+        run: |
+          make test
+        env:
+          UV_NO_SYNC: 1
+          HF_HOME: ${{ runner.temp }}/huggingface
+          HF_HUB_OFFLINE: "${{ steps.hf-hub-cache.outputs.cache-hit == 'true' && '1' || '0' }}"

.github/workflows/tests_npu.yml (7 changes)

@@ -49,8 +49,11 @@ jobs:
         uses: actions/checkout@v4

      - name: Install uv
-       run: |
-         curl -LsSf https://astral.sh/uv/install.sh | sh
+       uses: astral-sh/setup-uv@v7
+       with:
+         python-version: ${{ matrix.python }}
+         github-token: ${{ github.token }}
+         enable-cache: false

      - name: Install dependencies
        run: |

README.md (16 changes)

@@ -309,6 +309,7 @@ Read technical notes:
 | [MiMo](https://huggingface.co/XiaomiMiMo) | 7B/309B | mimo/mimo_v2 |
 | [MiniCPM 1-4.1](https://huggingface.co/openbmb) | 0.5B/1B/2B/4B/8B | cpm/cpm3/cpm4 |
 | [MiniCPM-o-2.6/MiniCPM-V-2.6](https://huggingface.co/openbmb) | 8B | minicpm_o/minicpm_v |
 | [MiniMax-M1/MiniMax-M2](https://huggingface.co/MiniMaxAI/models) | 229B/456B | minimax1/minimax2 |
+| [Ministral 3](https://huggingface.co/mistralai) | 3B/8B/14B | ministral3 |
 | [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
 | [OLMo](https://huggingface.co/allenai) | 1B/7B | - |

@@ -433,6 +434,7 @@ You also can add a custom chat template to [template.py](src/llamafactory/data/t
 - [Chinese-DeepSeek-R1-Distill (zh)](https://huggingface.co/datasets/Congliu/Chinese-DeepSeek-R1-Distill-data-110k-SFT)
 - [LLaVA mixed (en&zh)](https://huggingface.co/datasets/BUAADreamer/llava-en-zh-300k)
 - [Pokemon-gpt4o-captions (en&zh)](https://huggingface.co/datasets/jugg1024/pokemon-gpt4o-captions)
+- [DLR-Web (en)](https://huggingface.co/datasets/Attention1115/DLR-Web)
 - [Open Assistant (de)](https://huggingface.co/datasets/mayflowergmbh/oasst_de)
 - [Dolly 15k (de)](https://huggingface.co/datasets/mayflowergmbh/dolly-15k_de)
 - [Alpaca GPT4 (de)](https://huggingface.co/datasets/mayflowergmbh/alpaca-gpt4_de)

@@ -514,7 +516,7 @@ huggingface-cli login
 ```bash
 git clone --depth 1 https://github.com/hiyouga/LLaMA-Factory.git
 cd LLaMA-Factory
-pip install -e ".[metrics]" --no-build-isolation
+pip install -e ".[metrics]"
 ```

 Optional dependencies available: `metrics`, `deepspeed`. Install with: `pip install -e ".[metrics,deepspeed]"`

@@ -637,7 +639,7 @@ cd transformers
 pip install .
 ```

-3. Set `double_quantization: false` in the configuration. You can refer to the [example](examples/train_qlora/llama3_lora_sft_bnb_npu.yaml).
+3. Set `double_quantization: false` in the configuration. You can refer to the [example](examples/train_qlora/qwen3_lora_sft_bnb_npu.yaml).

 </details>

@@ -652,12 +654,12 @@ You can also use **[Easy Dataset](https://github.com/ConardLi/easy-dataset)**, *

 ### Quickstart

-Use the following 3 commands to run LoRA **fine-tuning**, **inference** and **merging** of the Llama3-8B-Instruct model, respectively.
+Use the following 3 commands to run LoRA **fine-tuning**, **inference** and **merging** of the Qwen3-4B-Instruct model, respectively.

 ```bash
-llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
-llamafactory-cli chat examples/inference/llama3_lora_sft.yaml
-llamafactory-cli export examples/merge_lora/llama3_lora_sft.yaml
+llamafactory-cli train examples/train_lora/qwen3_lora_sft.yaml
+llamafactory-cli chat examples/inference/qwen3_lora_sft.yaml
+llamafactory-cli export examples/merge_lora/qwen3_lora_sft.yaml
 ```

 See [examples/README.md](examples/README.md) for advanced usage (including distributed training).

@@ -780,7 +782,7 @@ When building the Docker image, use `-v ./hf_cache:/root/.cache/huggingface` arg

 ### Deploy with OpenAI-style API and vLLM

 ```bash
-API_PORT=8000 llamafactory-cli api examples/inference/llama3.yaml infer_backend=vllm vllm_enforce_eager=true
+API_PORT=8000 llamafactory-cli api examples/inference/qwen3.yaml infer_backend=vllm vllm_enforce_eager=true
 ```

 > [!TIP]

README_zh.md (16 changes)

@@ -311,6 +311,7 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
 | [MiMo](https://huggingface.co/XiaomiMiMo) | 7B/309B | mimo/mimo_v2 |
 | [MiniCPM 1-4.1](https://huggingface.co/openbmb) | 0.5B/1B/2B/4B/8B | cpm/cpm3/cpm4 |
 | [MiniCPM-o-2.6/MiniCPM-V-2.6](https://huggingface.co/openbmb) | 8B | minicpm_o/minicpm_v |
 | [MiniMax-M1/MiniMax-M2](https://huggingface.co/MiniMaxAI/models) | 229B/456B | minimax1/minimax2 |
+| [Ministral 3](https://huggingface.co/mistralai) | 3B/8B/14B | ministral3 |
 | [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
 | [OLMo](https://huggingface.co/allenai) | 1B/7B | - |

@@ -435,6 +436,7 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
 - [Chinese-DeepSeek-R1-Distill (zh)](https://huggingface.co/datasets/Congliu/Chinese-DeepSeek-R1-Distill-data-110k-SFT)
 - [LLaVA mixed (en&zh)](https://huggingface.co/datasets/BUAADreamer/llava-en-zh-300k)
 - [Pokemon-gpt4o-captions (en&zh)](https://huggingface.co/datasets/jugg1024/pokemon-gpt4o-captions)
+- [DLR-Web (en)](https://huggingface.co/datasets/Attention1115/DLR-Web)
 - [Open Assistant (de)](https://huggingface.co/datasets/mayflowergmbh/oasst_de)
 - [Dolly 15k (de)](https://huggingface.co/datasets/mayflowergmbh/dolly-15k_de)
 - [Alpaca GPT4 (de)](https://huggingface.co/datasets/mayflowergmbh/alpaca-gpt4_de)

@@ -516,7 +518,7 @@ huggingface-cli login
 ```bash
 git clone --depth 1 https://github.com/hiyouga/LLaMA-Factory.git
 cd LLaMA-Factory
-pip install -e ".[metrics]" --no-build-isolation
+pip install -e ".[metrics]"
 ```

 可选的额外依赖项:`metrics`、`deepspeed`。使用 `pip install -e ".[metrics,deepspeed]"` 安装。

@@ -639,7 +641,7 @@ cd transformers
 pip install .
 ```

-3. 在训练参数中设置 `double_quantization: false`,可参考[示例](examples/train_qlora/llama3_lora_sft_bnb_npu.yaml)。
+3. 在训练参数中设置 `double_quantization: false`,可参考[示例](examples/train_qlora/qwen3_lora_sft_bnb_npu.yaml)。

 </details>

@@ -654,12 +656,12 @@ pip install .

 ### 快速开始

-下面三行命令分别对 Llama3-8B-Instruct 模型进行 LoRA **微调**、**推理**和**合并**。
+下面三行命令分别对 Qwen3-4B-Instruct 模型进行 LoRA **微调**、**推理**和**合并**。

 ```bash
-llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
-llamafactory-cli chat examples/inference/llama3_lora_sft.yaml
-llamafactory-cli export examples/merge_lora/llama3_lora_sft.yaml
+llamafactory-cli train examples/train_lora/qwen3_lora_sft.yaml
+llamafactory-cli chat examples/inference/qwen3_lora_sft.yaml
+llamafactory-cli export examples/merge_lora/qwen3_lora_sft.yaml
 ```

 高级用法请参考 [examples/README_zh.md](examples/README_zh.md)(包括多 GPU 微调)。

@@ -785,7 +787,7 @@ docker exec -it llamafactory bash

 ### 利用 vLLM 部署 OpenAI API

 ```bash
-API_PORT=8000 llamafactory-cli api examples/inference/llama3.yaml infer_backend=vllm vllm_enforce_eager=true
+API_PORT=8000 llamafactory-cli api examples/inference/qwen3.yaml infer_backend=vllm vllm_enforce_eager=true
 ```

 > [!TIP]

data/dataset_info.json

@@ -471,6 +471,14 @@
   "ultrachat_de": {
     "hf_hub_url": "mayflowergmbh/ultra-chat_de"
   },
+  "dlr_web": {
+    "hf_hub_url": "Attention1115/DLR-Web",
+    "split": "full",
+    "columns": {
+      "prompt": "question",
+      "response": "response"
+    }
+  },
   "dpo_en_demo": {
     "file_name": "dpo_en_demo.json",
     "ranking": true,

examples/README.md

@@ -18,19 +18,19 @@ By default, LLaMA-Factory uses all visible computing devices.
 Basic usage:

 ```bash
-llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
+llamafactory-cli train examples/train_lora/qwen3_lora_sft.yaml
 ```

 Advanced usage:

 ```bash
-CUDA_VISIBLE_DEVICES=0,1 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml \
+CUDA_VISIBLE_DEVICES=0,1 llamafactory-cli train examples/train_lora/qwen3_lora_sft.yaml \
     learning_rate=1e-5 \
     logging_steps=1
 ```

 ```bash
-bash examples/train_lora/llama3_lora_sft.sh
+bash examples/train_lora/qwen3_lora_sft.sh
 ```

 ## Examples

@@ -40,49 +40,43 @@ bash examples/train_lora/llama3_lora_sft.sh
 #### (Continuous) Pre-Training

 ```bash
-llamafactory-cli train examples/train_lora/llama3_lora_pretrain.yaml
+llamafactory-cli train examples/train_lora/qwen3_lora_pretrain.yaml
 ```

 #### Supervised Fine-Tuning

 ```bash
-llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
+llamafactory-cli train examples/train_lora/qwen3_lora_sft.yaml
 ```

 #### Multimodal Supervised Fine-Tuning

 ```bash
-llamafactory-cli train examples/train_lora/qwen2_5vl_lora_sft.yaml
+llamafactory-cli train examples/train_lora/qwen3vl_lora_sft.yaml
 ```

 #### DPO/ORPO/SimPO Training

 ```bash
-llamafactory-cli train examples/train_lora/llama3_lora_dpo.yaml
+llamafactory-cli train examples/train_lora/qwen3_lora_dpo.yaml
 ```

 #### Multimodal DPO/ORPO/SimPO Training

 ```bash
-llamafactory-cli train examples/train_lora/qwen2_5vl_lora_dpo.yaml
+llamafactory-cli train examples/train_lora/qwen3vl_lora_dpo.yaml
 ```

 #### Reward Modeling

 ```bash
-llamafactory-cli train examples/train_lora/llama3_lora_reward.yaml
-```
-
-#### PPO Training
-
-```bash
-llamafactory-cli train examples/train_lora/llama3_lora_ppo.yaml
+llamafactory-cli train examples/train_lora/qwen3_lora_reward.yaml
 ```

 #### KTO Training

 ```bash
-llamafactory-cli train examples/train_lora/llama3_lora_kto.yaml
+llamafactory-cli train examples/train_lora/qwen3_lora_kto.yaml
 ```

 #### Preprocess Dataset

@@ -90,32 +84,26 @@ llamafactory-cli train examples/train_lora/llama3_lora_kto.yaml
 It is useful for large dataset, use `tokenized_path` in config to load the preprocessed dataset.

 ```bash
-llamafactory-cli train examples/train_lora/llama3_preprocess.yaml
-```
-
-#### Evaluating on MMLU/CMMLU/C-Eval Benchmarks
-
-```bash
-llamafactory-cli eval examples/train_lora/llama3_lora_eval.yaml
+llamafactory-cli train examples/train_lora/qwen3_preprocess.yaml
 ```

 #### Supervised Fine-Tuning on Multiple Nodes

 ```bash
-FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
-FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
+FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/qwen3_lora_sft.yaml
+FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/qwen3_lora_sft.yaml
 ```

 #### Supervised Fine-Tuning with DeepSpeed ZeRO-3 (Weight Sharding)

 ```bash
-FORCE_TORCHRUN=1 llamafactory-cli train examples/train_lora/llama3_lora_sft_ds3.yaml
+FORCE_TORCHRUN=1 llamafactory-cli train examples/train_lora/qwen3_lora_sft_ds3.yaml
 ```

 #### Supervised Fine-Tuning with Ray on 4 GPUs

 ```bash
-USE_RAY=1 llamafactory-cli train examples/train_lora/llama3_lora_sft_ray.yaml
+USE_RAY=1 llamafactory-cli train examples/train_lora/qwen3_lora_sft_ray.yaml
 ```

 ### QLoRA Fine-Tuning

@@ -123,13 +111,13 @@ USE_RAY=1 llamafactory-cli train examples/train_lora/llama3_lora_sft_ray.yaml
 #### Supervised Fine-Tuning with 4/8-bit Bitsandbytes/HQQ/EETQ Quantization (Recommended)

 ```bash
-llamafactory-cli train examples/train_qlora/llama3_lora_sft_otfq.yaml
+llamafactory-cli train examples/train_qlora/qwen3_lora_sft_otfq.yaml
 ```

 #### Supervised Fine-Tuning with 4-bit Bitsandbytes Quantization on Ascend NPU

 ```bash
-llamafactory-cli train examples/train_qlora/llama3_lora_sft_bnb_npu.yaml
+llamafactory-cli train examples/train_qlora/qwen3_lora_sft_bnb_npu.yaml
 ```

 #### Supervised Fine-Tuning with 4/8-bit GPTQ Quantization

@@ -155,14 +143,14 @@ llamafactory-cli train examples/train_qlora/llama3_lora_sft_aqlm.yaml
 #### Supervised Fine-Tuning on Single Node

 ```bash
-FORCE_TORCHRUN=1 llamafactory-cli train examples/train_full/llama3_full_sft.yaml
+FORCE_TORCHRUN=1 llamafactory-cli train examples/train_full/qwen3_full_sft.yaml
 ```

 #### Supervised Fine-Tuning on Multiple Nodes

 ```bash
-FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/llama3_full_sft.yaml
-FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/llama3_full_sft.yaml
+FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/qwen3_full_sft.yaml
+FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/qwen3_full_sft.yaml
 ```

 ### Elastic and Fault-Tolerant Supervised Fine-Tuning on Multiple Nodes

@@ -170,13 +158,13 @@ FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500
 To launch an elastic job with `MAX_RESTARTS` failures retries, run the following on at least `MIN_NNODES` nodes and at most `MAX_NNODES` nodes. `RDZV_ID` should be set as a unique job id (shared by all nodes participating in the job). See also [torchrun](https://docs.pytorch.org/docs/stable/elastic/run.html).

 ```bash
-FORCE_TORCHRUN=1 MIN_NNODES=1 MAX_NNODES=3 MAX_RESTARTS=3 RDZV_ID=llamafactory MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/llama3_full_sft.yaml
+FORCE_TORCHRUN=1 MIN_NNODES=1 MAX_NNODES=3 MAX_RESTARTS=3 RDZV_ID=llamafactory MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/qwen3_full_sft.yaml
 ```

 #### Multimodal Supervised Fine-Tuning

 ```bash
-FORCE_TORCHRUN=1 llamafactory-cli train examples/train_full/qwen2_5vl_full_sft.yaml
+FORCE_TORCHRUN=1 llamafactory-cli train examples/train_full/qwen3vl_full_sft.yaml
 ```

 ### Merging LoRA Adapters and Quantization

@@ -186,19 +174,19 @@ FORCE_TORCHRUN=1 llamafactory-cli train examples/train_full/qwen2_5vl_full_sft.y
 Note: DO NOT use quantized model or `quantization_bit` when merging LoRA adapters.

 ```bash
-llamafactory-cli export examples/merge_lora/llama3_lora_sft.yaml
+llamafactory-cli export examples/merge_lora/qwen3_lora_sft.yaml
 ```

 #### Quantizing Model using AutoGPTQ

 ```bash
-llamafactory-cli export examples/merge_lora/llama3_gptq.yaml
+llamafactory-cli export examples/merge_lora/qwen3_gptq.yaml
 ```

 ### Save Ollama modelfile

 ```bash
-llamafactory-cli export examples/merge_lora/llama3_full_sft.yaml
+llamafactory-cli export examples/merge_lora/qwen3_full_sft.yaml
 ```

 ### Inferring LoRA Fine-Tuned Models

@@ -206,26 +194,26 @@ llamafactory-cli export examples/merge_lora/llama3_full_sft.yaml
 #### Evaluation using vLLM's Multi-GPU Inference

 ```
-python scripts/vllm_infer.py --model_name_or_path meta-llama/Meta-Llama-3-8B-Instruct --template llama3 --dataset alpaca_en_demo
+python scripts/vllm_infer.py --model_name_or_path Qwen/Qwen3-4B-Instruct-2507 --template qwen3_nothink --dataset alpaca_en_demo
 python scripts/eval_bleu_rouge.py generated_predictions.jsonl
 ```

 #### Use CLI ChatBox

 ```bash
-llamafactory-cli chat examples/inference/llama3_lora_sft.yaml
+llamafactory-cli chat examples/inference/qwen3_lora_sft.yaml
 ```

 #### Use Web UI ChatBox

 ```bash
-llamafactory-cli webchat examples/inference/llama3_lora_sft.yaml
+llamafactory-cli webchat examples/inference/qwen3_lora_sft.yaml
 ```

 #### Launch OpenAI-style API

 ```bash
-llamafactory-cli api examples/inference/llama3_lora_sft.yaml
+llamafactory-cli api examples/inference/qwen3_lora_sft.yaml
 ```

 ### Extras

examples/README_zh.md

@@ -18,19 +18,19 @@ LLaMA-Factory 默认使用所有可见的计算设备。
 基础用法:

 ```bash
-llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
+llamafactory-cli train examples/train_lora/qwen3_lora_sft.yaml
 ```

 高级用法:

 ```bash
-CUDA_VISIBLE_DEVICES=0,1 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml \
+CUDA_VISIBLE_DEVICES=0,1 llamafactory-cli train examples/train_lora/qwen3_lora_sft.yaml \
     learning_rate=1e-5 \
     logging_steps=1
 ```

 ```bash
-bash examples/train_lora/llama3_lora_sft.sh
+bash examples/train_lora/qwen3_lora_sft.sh
 ```

 ## 示例

@@ -40,49 +40,43 @@ bash examples/train_lora/llama3_lora_sft.sh
 #### (增量)预训练

 ```bash
-llamafactory-cli train examples/train_lora/llama3_lora_pretrain.yaml
+llamafactory-cli train examples/train_lora/qwen3_lora_pretrain.yaml
 ```

 #### 指令监督微调

 ```bash
-llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
+llamafactory-cli train examples/train_lora/qwen3_lora_sft.yaml
 ```

 #### 多模态指令监督微调

 ```bash
-llamafactory-cli train examples/train_lora/qwen2_5vl_lora_sft.yaml
+llamafactory-cli train examples/train_lora/qwen3vl_lora_sft.yaml
 ```

 #### DPO/ORPO/SimPO 训练

 ```bash
-llamafactory-cli train examples/train_lora/llama3_lora_dpo.yaml
+llamafactory-cli train examples/train_lora/qwen3_lora_dpo.yaml
 ```

 #### 多模态 DPO/ORPO/SimPO 训练

 ```bash
-llamafactory-cli train examples/train_lora/qwen2_5vl_lora_dpo.yaml
+llamafactory-cli train examples/train_lora/qwen3vl_lora_dpo.yaml
 ```

 #### 奖励模型训练

 ```bash
-llamafactory-cli train examples/train_lora/llama3_lora_reward.yaml
-```
-
-#### PPO 训练
-
-```bash
-llamafactory-cli train examples/train_lora/llama3_lora_ppo.yaml
+llamafactory-cli train examples/train_lora/qwen3_lora_reward.yaml
 ```

 #### KTO 训练

 ```bash
-llamafactory-cli train examples/train_lora/llama3_lora_kto.yaml
+llamafactory-cli train examples/train_lora/qwen3_lora_kto.yaml
 ```

 #### 预处理数据集

@@ -90,20 +84,14 @@ llamafactory-cli train examples/train_lora/llama3_lora_kto.yaml
 对于大数据集有帮助,在配置中使用 `tokenized_path` 以加载预处理后的数据集。

 ```bash
-llamafactory-cli train examples/train_lora/llama3_preprocess.yaml
-```
-
-#### 在 MMLU/CMMLU/C-Eval 上评估
-
-```bash
-llamafactory-cli eval examples/train_lora/llama3_lora_eval.yaml
+llamafactory-cli train examples/train_lora/qwen3_preprocess.yaml
 ```

 #### 多机指令监督微调

 ```bash
-FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
-FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
+FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/qwen3_lora_sft.yaml
+FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/qwen3_lora_sft.yaml
 ```

 ### 支持弹性和容错的多机指令监督微调

@@ -111,19 +99,19 @@ FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500
 要启动一个支持弹性节点和容错的多机指令微调,在每个节点上执行以下命令。弹性节点数量范围为 `MIN_NNODES:MAX_NNODES`,每个节点最多允许因为错误重启 `MAX_RESTARTS` 次。`RDZV_ID` 应设置为一个唯一的作业 ID(由参与该作业的所有节点共享)。更多新可以参考官方文档 [torchrun](https://docs.pytorch.org/docs/stable/elastic/run.html)。

 ```bash
-FORCE_TORCHRUN=1 MIN_NNODES=1 MAX_NNODES=3 MAX_RESTARTS=3 RDZV_ID=llamafactory MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/llama3_full_sft.yaml
+FORCE_TORCHRUN=1 MIN_NNODES=1 MAX_NNODES=3 MAX_RESTARTS=3 RDZV_ID=llamafactory MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/qwen3_full_sft.yaml
 ```

 #### 使用 DeepSpeed ZeRO-3 平均分配显存

 ```bash
-FORCE_TORCHRUN=1 llamafactory-cli train examples/train_lora/llama3_lora_sft_ds3.yaml
+FORCE_TORCHRUN=1 llamafactory-cli train examples/train_lora/qwen3_lora_sft_ds3.yaml
 ```

 #### 使用 Ray 在 4 张 GPU 上微调

 ```bash
-USE_RAY=1 llamafactory-cli train examples/train_lora/llama3_lora_sft_ray.yaml
+USE_RAY=1 llamafactory-cli train examples/train_lora/qwen3_lora_sft_ray.yaml
 ```

 ### QLoRA 微调

@@ -131,13 +119,13 @@ USE_RAY=1 llamafactory-cli train examples/train_lora/llama3_lora_sft_ray.yaml
 #### 基于 4/8 比特 Bitsandbytes/HQQ/EETQ 量化进行指令监督微调(推荐)

 ```bash
-llamafactory-cli train examples/train_qlora/llama3_lora_sft_otfq.yaml
+llamafactory-cli train examples/train_qlora/qwen3_lora_sft_otfq.yaml
 ```

 #### 在 NPU 上基于 4 比特 Bitsandbytes 量化进行指令监督微调

 ```bash
-llamafactory-cli train examples/train_qlora/llama3_lora_sft_bnb_npu.yaml
+llamafactory-cli train examples/train_qlora/qwen3_lora_sft_bnb_npu.yaml
 ```

 #### 基于 4/8 比特 GPTQ 量化进行指令监督微调

@@ -163,20 +151,20 @@ llamafactory-cli train examples/train_qlora/llama3_lora_sft_aqlm.yaml
 #### 在单机上进行指令监督微调

 ```bash
-FORCE_TORCHRUN=1 llamafactory-cli train examples/train_full/llama3_full_sft.yaml
+FORCE_TORCHRUN=1 llamafactory-cli train examples/train_full/qwen3_full_sft.yaml
 ```

 #### 在多机上进行指令监督微调

 ```bash
-FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/llama3_full_sft.yaml
-FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/llama3_full_sft.yaml
+FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/qwen3_full_sft.yaml
+FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/qwen3_full_sft.yaml
 ```

 #### 多模态指令监督微调

 ```bash
-FORCE_TORCHRUN=1 llamafactory-cli train examples/train_full/qwen2_5vl_full_sft.yaml
+FORCE_TORCHRUN=1 llamafactory-cli train examples/train_full/qwen3vl_full_sft.yaml
 ```

 ### 合并 LoRA 适配器与模型量化

@@ -186,19 +174,19 @@ FORCE_TORCHRUN=1 llamafactory-cli train examples/train_full/qwen2_5vl_full_sft.y
 注:请勿使用量化后的模型或 `quantization_bit` 参数来合并 LoRA 适配器。

 ```bash
-llamafactory-cli export examples/merge_lora/llama3_lora_sft.yaml
+llamafactory-cli export examples/merge_lora/qwen3_lora_sft.yaml
 ```

 #### 使用 AutoGPTQ 量化模型

 ```bash
-llamafactory-cli export examples/merge_lora/llama3_gptq.yaml
+llamafactory-cli export examples/merge_lora/qwen3_gptq.yaml
 ```

 ### 保存 Ollama 配置文件

 ```bash
-llamafactory-cli export examples/merge_lora/llama3_full_sft.yaml
+llamafactory-cli export examples/merge_lora/qwen3_full_sft.yaml
 ```

 ### 推理 LoRA 模型

@@ -206,26 +194,26 @@ llamafactory-cli export examples/merge_lora/llama3_full_sft.yaml
 #### 使用 vLLM 多卡推理评估

 ```
-python scripts/vllm_infer.py --model_name_or_path meta-llama/Meta-Llama-3-8B-Instruct --template llama3 --dataset alpaca_en_demo
+python scripts/vllm_infer.py --model_name_or_path Qwen/Qwen3-4B-Instruct-2507 --template qwen3_nothink --dataset alpaca_en_demo
 python scripts/eval_bleu_rouge.py generated_predictions.jsonl
 ```

 #### 使用命令行对话框

 ```bash
-llamafactory-cli chat examples/inference/llama3_lora_sft.yaml
+llamafactory-cli chat examples/inference/qwen3_lora_sft.yaml
 ```

 #### 使用浏览器对话框

 ```bash
-llamafactory-cli webchat examples/inference/llama3_lora_sft.yaml
+llamafactory-cli webchat examples/inference/qwen3_lora_sft.yaml
 ```

 #### 启动 OpenAI 风格 API

 ```bash
-llamafactory-cli api examples/inference/llama3_lora_sft.yaml
+llamafactory-cli api examples/inference/qwen3_lora_sft.yaml
 ```

 ### 杂项

examples/inference/llama3_lora_sft.yaml (deleted)

@@ -1,5 +0,0 @@
-model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
-adapter_name_or_path: saves/llama3-8b/lora/sft
-template: llama3
-infer_backend: huggingface # choices: [huggingface, vllm, sglang, ktransformers]
-trust_remote_code: true

@@ -1,4 +1,4 @@
-model_name_or_path: Qwen/Qwen2.5-VL-7B-Instruct
-template: qwen2_vl
+model_name_or_path: Qwen/Qwen3-4B-Instruct-2507
+template: qwen3_nothink
 infer_backend: huggingface # choices: [huggingface, vllm, sglang, ktransformers]
 trust_remote_code: true

@@ -1,4 +1,4 @@
-model_name_or_path: saves/llama3-8b/full/sft
-template: llama3
+model_name_or_path: saves/qwen3-4b/full/sft
+template: qwen3_nothink
 infer_backend: huggingface # choices: [huggingface, vllm, sglang, ktransformers]
 trust_remote_code: true

examples/inference/qwen3_lora_sft.yaml (new file, 5 lines)

@@ -0,0 +1,5 @@
+model_name_or_path: Qwen/Qwen3-4B-Instruct-2507
+adapter_name_or_path: saves/qwen3-4b/lora/sft
+template: qwen3_nothink
+infer_backend: huggingface # choices: [huggingface, vllm, sglang, ktransformers]
+trust_remote_code: true

@@ -1,4 +1,4 @@
-model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
-template: llama3
+model_name_or_path: Qwen/Qwen3-VL-4B-Instruct
+template: qwen3_vl_nothink
 infer_backend: huggingface # choices: [huggingface, vllm, sglang, ktransformers]
 trust_remote_code: true

@@ -1,10 +1,10 @@
 ### model
-model_name_or_path: saves/llama3-8b/full/sft
-template: llama3
+model_name_or_path: saves/qwen3-4b/full/sft
+template: qwen3_nothink
 trust_remote_code: true

 ### export
-export_dir: output/llama3_full_sft
+export_dir: saves/qwen3_sft_merged
 export_size: 5
 export_device: cpu # choices: [cpu, auto]
 export_legacy_format: false

@@ -1,10 +1,10 @@
 ### model
-model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
-template: llama3
+model_name_or_path: Qwen/Qwen3-4B-Instruct-2507
+template: qwen3_nothink
 trust_remote_code: true

 ### export
-export_dir: output/llama3_gptq
+export_dir: saves/qwen3_gptq
 export_quantization_bit: 4
 export_quantization_dataset: data/c4_demo.jsonl
 export_size: 5

@@ -1,13 +1,13 @@
 ### Note: DO NOT use quantized model or quantization_bit when merging lora adapters

 ### model
-model_name_or_path: Qwen/Qwen2.5-VL-7B-Instruct
-adapter_name_or_path: saves/qwen2_5vl-7b/lora/sft
-template: qwen2_vl
+model_name_or_path: Qwen/Qwen3-4B-Instruct-2507
+adapter_name_or_path: saves/qwen3-4b/lora/sft
+template: qwen3_nothink
 trust_remote_code: true

 ### export
-export_dir: output/qwen2_5vl_lora_sft
+export_dir: saves/qwen3_sft_merged
 export_size: 5
 export_device: cpu # choices: [cpu, auto]
 export_legacy_format: false

@@ -1,13 +1,13 @@
 ### Note: DO NOT use quantized model or quantization_bit when merging lora adapters

 ### model
-model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
-adapter_name_or_path: saves/llama3-8b/lora/sft
-template: llama3
+model_name_or_path: Qwen/Qwen3-VL-4B-Instruct
+adapter_name_or_path: saves/qwen3-vl-4b/lora/sft
+template: qwen3_vl_nothink
 trust_remote_code: true

 ### export
-export_dir: output/llama3_lora_sft
+export_dir: saves/qwen3_vl_sft_merged
 export_size: 5
 export_device: cpu # choices: [cpu, auto]
 export_legacy_format: false

@@ -1,5 +1,5 @@
 ### model
-model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
+model_name_or_path: Qwen/Qwen3-4B-Instruct-2507
 trust_remote_code: true

 ### method
@@ -10,15 +10,14 @@ deepspeed: examples/deepspeed/ds_z3_config.json # choices: [ds_z0_config.json,

 ### dataset
 dataset: identity,alpaca_en_demo
-template: llama3
+template: qwen3_nothink
 cutoff_len: 2048
 max_samples: 1000
 overwrite_cache: true
 preprocessing_num_workers: 16
-dataloader_num_workers: 4

 ### output
-output_dir: saves/llama3-8b/full/sft
+output_dir: saves/qwen3-4b/full/sft
 logging_steps: 10
 save_steps: 500
 plot_loss: true

@@ -1,46 +0,0 @@
-### model
-model_name_or_path: Qwen/Qwen3-32B
-trust_remote_code: true
-use_v1_kernels: true
-
-### method
-stage: sft
-do_train: true
-finetuning_type: full
-deepspeed: examples/deepspeed/ds_z2_autotp_config.json
-
-### dataset
-dataset: identity,alpaca_en_demo
-template: qwen3
-cutoff_len: 2048
-max_samples: 1000
-overwrite_cache: true
-preprocessing_num_workers: 16
-dataloader_num_workers: 4
-
-### output
-output_dir: saves/qwen3-32b/full/sft_autotp
-logging_steps: 1
-save_steps: 500
-plot_loss: true
-overwrite_output_dir: true
-save_only_model: false
-report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
-
-### train
-per_device_train_batch_size: 4
-gradient_accumulation_steps: 1
-learning_rate: 1.0e-4
-num_train_epochs: 3.0
-lr_scheduler_type: cosine
-warmup_ratio: 0.1
-bf16: true
-ddp_timeout: 180000000
-resume_from_checkpoint: null
-
-### eval
-# eval_dataset: alpaca_en_demo
-# val_size: 0.1
-# per_device_eval_batch_size: 1
-# eval_strategy: steps
-# eval_steps: 500

@@ -1,5 +1,5 @@
 ### model
-model_name_or_path: Qwen/Qwen2.5-VL-7B-Instruct
+model_name_or_path: Qwen/Qwen3-VL-4B-Instruct
 image_max_pixels: 262144
 video_max_pixels: 16384
 trust_remote_code: true
@@ -15,15 +15,14 @@ deepspeed: examples/deepspeed/ds_z3_config.json

 ### dataset
 dataset: mllm_demo,identity,alpaca_en_demo
-template: qwen2_vl
+template: qwen3_vl_nothink
 cutoff_len: 2048
 max_samples: 1000
 overwrite_cache: true
 preprocessing_num_workers: 16
-dataloader_num_workers: 4

 ### output
-output_dir: saves/qwen2_5vl-7b/full/sft
+output_dir: saves/qwen3-vl-4b/full/sft
 logging_steps: 10
 save_steps: 500
 plot_loss: true

examples/train_lora/llama3_lora_eval.yaml (deleted)

@@ -1,19 +0,0 @@
-### model
-model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
-adapter_name_or_path: saves/llama3-8b/lora/sft
-trust_remote_code: true
-
-### method
-finetuning_type: lora
-
-### dataset
-task: mmlu_test # choices: [mmlu_test, ceval_validation, cmmlu_test]
-template: fewshot
-lang: en
-n_shot: 5
-
-### output
-save_dir: saves/llama3-8b/lora/eval
-
-### eval
-batch_size: 4

examples/train_lora/llama3_lora_ppo.yaml (deleted)

@@ -1,43 +0,0 @@
-### model
-model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
-reward_model: saves/llama3-8b/lora/reward
-trust_remote_code: true
-
-### method
-stage: ppo
-do_train: true
-finetuning_type: lora
-lora_rank: 8
-lora_target: all
-
-### dataset
-dataset: identity,alpaca_en_demo
-template: llama3
-cutoff_len: 2048
-max_samples: 1000
-overwrite_cache: true
-preprocessing_num_workers: 16
-dataloader_num_workers: 4
-
-### output
-output_dir: saves/llama3-8b/lora/ppo
-logging_steps: 10
-save_steps: 500
-plot_loss: true
-overwrite_output_dir: true
-report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
-
-### train
-per_device_train_batch_size: 1
-gradient_accumulation_steps: 8
-learning_rate: 1.0e-5
-num_train_epochs: 3.0
-lr_scheduler_type: cosine
-warmup_ratio: 0.1
-bf16: true
-ddp_timeout: 180000000
-
-### generate
-max_new_tokens: 512
-top_k: 0
-top_p: 0.9

examples/train_lora/llama3_lora_sft.yaml (deleted)

@@ -1,46 +0,0 @@
-### model
-model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
-trust_remote_code: true
-
-### method
-stage: sft
-do_train: true
-finetuning_type: lora
-lora_rank: 8
-lora_target: all
-
-### dataset
-dataset: identity,alpaca_en_demo
-template: llama3
-cutoff_len: 2048
-max_samples: 1000
-overwrite_cache: true
-preprocessing_num_workers: 16
-dataloader_num_workers: 4
-
-### output
-output_dir: saves/llama3-8b/lora/sft
-logging_steps: 10
-save_steps: 500
-plot_loss: true
-overwrite_output_dir: true
-save_only_model: false
-report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
-
-### train
-per_device_train_batch_size: 1
-gradient_accumulation_steps: 8
-learning_rate: 1.0e-4
-num_train_epochs: 3.0
-lr_scheduler_type: cosine
-warmup_ratio: 0.1
-bf16: true
-ddp_timeout: 180000000
-resume_from_checkpoint: null
-
-### eval
-# eval_dataset: alpaca_en_demo
-# val_size: 0.1
-# per_device_eval_batch_size: 1
-# eval_strategy: steps
-# eval_steps: 500

examples/train_lora/llama4_lora_sft.yaml (deleted)

@@ -1,49 +0,0 @@
-# pip install git+https://github.com/hiyouga/transformers.git@llama4_train
-
-### model
-model_name_or_path: meta-llama/Llama-4-Scout-17B-16E-Instruct
-trust_remote_code: true
-
-### method
-stage: sft
-do_train: true
-finetuning_type: lora
-lora_rank: 8
-lora_target: all
-deepspeed: examples/deepspeed/ds_z3_config.json # choices: [ds_z0_config.json, ds_z2_config.json, ds_z3_config.json]
-
-### dataset
-dataset: mllm_demo,identity,alpaca_en_demo
-template: llama4
-cutoff_len: 2048
-max_samples: 1000
-overwrite_cache: true
-preprocessing_num_workers: 16
-dataloader_num_workers: 4
-
-### output
-output_dir: saves/llama4-8b/lora/sft
-logging_steps: 10
-save_steps: 500
-plot_loss: true
-overwrite_output_dir: true
-save_only_model: false
-report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
-
-### train
-per_device_train_batch_size: 1
-gradient_accumulation_steps: 2
-learning_rate: 1.0e-4
-num_train_epochs: 3.0
-lr_scheduler_type: cosine
-warmup_ratio: 0.1
-bf16: true
-ddp_timeout: 180000000
-resume_from_checkpoint: null
-
-### eval
-# eval_dataset: alpaca_en_demo
-# val_size: 0.1
-# per_device_eval_batch_size: 1
-# eval_strategy: steps
-# eval_steps: 500

@@ -1,5 +1,5 @@
 ### model
-model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
+model_name_or_path: Qwen/Qwen3-4B-Instruct-2507
 trust_remote_code: true

 ### method
@@ -13,15 +13,14 @@ pref_loss: sigmoid # choices: [sigmoid (dpo), orpo, simpo]

 ### dataset
 dataset: dpo_en_demo
-template: llama3
+template: qwen3_nothink
 cutoff_len: 2048
 max_samples: 1000
 overwrite_cache: true
 preprocessing_num_workers: 16
-dataloader_num_workers: 4

 ### output
-output_dir: saves/llama3-8b/lora/dpo
+output_dir: saves/qwen3-4b/lora/dpo
 logging_steps: 10
 save_steps: 500
 plot_loss: true

@@ -1,5 +1,5 @@
 ### model
-model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
+model_name_or_path: Qwen/Qwen3-4B-Instruct-2507
 trust_remote_code: true

 ### method
@@ -12,15 +12,14 @@ pref_beta: 0.1

 ### dataset
 dataset: kto_en_demo
-template: llama3
+template: qwen3_nothink
 cutoff_len: 2048
 max_samples: 1000
 overwrite_cache: true
 preprocessing_num_workers: 16
-dataloader_num_workers: 4

 ### output
-output_dir: saves/llama3-8b/lora/kto
+output_dir: saves/qwen3-4b/lora/kto
 logging_steps: 10
 save_steps: 500
 plot_loss: true

@@ -1,5 +1,5 @@
 ### model
-model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
+model_name_or_path: Qwen/Qwen3-4B-Instruct-2507
 trust_remote_code: true

 ### method
@@ -13,12 +13,11 @@ lora_target: all
 dataset: c4_demo
 cutoff_len: 2048
 max_samples: 1000
 overwrite_cache: true
 preprocessing_num_workers: 16
-dataloader_num_workers: 4

 ### output
-output_dir: saves/llama3-8b/lora/pretrain
+output_dir: saves/qwen3-4b/lora/pretrain
 logging_steps: 10
 save_steps: 500
 plot_loss: true

@@ -1,5 +1,5 @@
 ### model
-model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
+model_name_or_path: Qwen/Qwen3-4B-Instruct-2507
 trust_remote_code: true

 ### method
@@ -11,15 +11,14 @@ lora_target: all

 ### dataset
 dataset: dpo_en_demo
-template: llama3
+template: qwen3_nothink
 cutoff_len: 2048
 max_samples: 1000
 overwrite_cache: true
 preprocessing_num_workers: 16
-dataloader_num_workers: 4

 ### output
-output_dir: saves/llama3-8b/lora/reward
+output_dir: saves/qwen3-4b/lora/reward
 logging_steps: 10
 save_steps: 500
 plot_loss: true

@@ -2,7 +2,7 @@

 set -x

-MODEL_PATH=meta-llama/Meta-Llama-3-8B-Instruct
+MODEL_PATH=Qwen/Qwen3-4B-Instruct-2507

 llamafactory-cli train \
     --model_name_or_path ${MODEL_PATH} \
@@ -13,13 +13,12 @@ llamafactory-cli train \
     --lora_rank 8 \
     --lora_target all \
     --dataset identity,alpaca_en_demo \
-    --template llama3 \
+    --template qwen3_nothink \
     --cutoff_len 2048 \
     --max_samples 1000 \
     --overwrite_cache \
     --preprocessing_num_workers 16 \
-    --dataloader_num_workers 4 \
-    --output_dir saves/llama3-8b/lora/sft \
+    --output_dir saves/qwen3-4b/lora/sft \
     --logging_steps 10 \
     --save_steps 500 \
     --plot_loss \

@@ -1,5 +1,5 @@
 ### model
-model_name_or_path: openai/gpt-oss-20b
+model_name_or_path: Qwen/Qwen3-4B-Instruct-2507
 trust_remote_code: true

 ### method
@@ -11,15 +11,14 @@ lora_target: all

 ### dataset
 dataset: identity,alpaca_en_demo
-template: gpt
+template: qwen3_nothink
 cutoff_len: 2048
 max_samples: 1000
 overwrite_cache: true
 preprocessing_num_workers: 16
-dataloader_num_workers: 4

 ### output
-output_dir: saves/gpt-20b/lora/sft
+output_dir: saves/qwen3-4b/lora/sft
 logging_steps: 10
 save_steps: 500
 plot_loss: true

@@ -1,5 +1,5 @@
 ### model
-model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
+model_name_or_path: Qwen/Qwen3-4B-Instruct-2507
 trust_remote_code: true

 ### method
@@ -12,15 +12,14 @@ deepspeed: examples/deepspeed/ds_z3_config.json # choices: [ds_z0_config.json,

 ### dataset
 dataset: identity,alpaca_en_demo
-template: llama3
+template: qwen3_nothink
 cutoff_len: 2048
 max_samples: 1000
 overwrite_cache: true
 preprocessing_num_workers: 16
-dataloader_num_workers: 4

 ### output
-output_dir: saves/llama3-8b/lora/sft
+output_dir: saves/qwen3-4b/lora/sft
 logging_steps: 10
 save_steps: 500
 plot_loss: true

@@ -1,5 +1,5 @@
 ### model
-model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct # or use local absolute path
+model_name_or_path: Qwen/Qwen3-4B-Instruct-2507 # or use local absolute path
 trust_remote_code: true

 ### method
@@ -12,10 +12,9 @@ lora_target: all
 ### dataset
 dataset: identity,alpaca_en_demo
 dataset_dir: REMOTE:llamafactory/demo_data # or use local absolute path
-template: llama3
+template: qwen3_nothink
 cutoff_len: 2048
 max_samples: 1000
 overwrite_cache: true
 preprocessing_num_workers: 16
-dataloader_num_workers: 4

@@ -29,7 +28,7 @@ save_only_model: false
 report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]

 ### ray
-ray_run_name: llama3_8b_sft_lora
+ray_run_name: qwen3_4b_sft_lora
 ray_storage_path: ./saves
 ray_num_workers: 4 # Number of GPUs to use.
 placement_strategy: PACK

@@ -1,5 +1,5 @@
 ### model
-model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
+model_name_or_path: Qwen/Qwen3-4B-Instruct-2507
 trust_remote_code: true

 ### method
@@ -11,13 +11,12 @@ lora_target: all

 ### dataset
 dataset: identity,alpaca_en_demo
-template: llama3
+template: qwen3_nothink
 cutoff_len: 2048
 max_samples: 1000
 overwrite_cache: true
 preprocessing_num_workers: 16
-tokenized_path: saves/llama3-8b/dataset/sft
+tokenized_path: saves/qwen3-4b/dataset/sft

-### output
-output_dir: saves/llama3-8b/lora/sft
+### output (not used)
+output_dir: saves/qwen3-4b/lora/sft
 overwrite_output_dir: true

@@ -1,5 +1,5 @@
 ### model
-model_name_or_path: Qwen/Qwen2.5-VL-7B-Instruct
+model_name_or_path: Qwen/Qwen3-VL-4B-Instruct
 image_max_pixels: 262144
 video_max_pixels: 16384
 trust_remote_code: true
@@ -15,15 +15,14 @@ pref_loss: sigmoid # choices: [sigmoid (dpo), orpo, simpo]

 ### dataset
 dataset: rlhf_v
-template: qwen2_vl
+template: qwen3_vl_nothink
 cutoff_len: 2048
 max_samples: 1000
 overwrite_cache: true
 preprocessing_num_workers: 16
-dataloader_num_workers: 4

 ### output
-output_dir: saves/qwen2_5vl-7b/lora/dpo
+output_dir: saves/qwen3-vl-4b/lora/dpo
 logging_steps: 10
 save_steps: 500
 plot_loss: true

@@ -1,5 +1,5 @@
 ### model
-model_name_or_path: Qwen/Qwen2.5-VL-7B-Instruct
+model_name_or_path: Qwen/Qwen3-VL-4B-Instruct
 image_max_pixels: 262144
 video_max_pixels: 16384
 trust_remote_code: true
@@ -13,15 +13,14 @@ lora_target: all

 ### dataset
 dataset: mllm_demo,identity,alpaca_en_demo # video: mllm_video_demo
-template: qwen2_vl
+template: qwen3_vl_nothink
 cutoff_len: 2048
 max_samples: 1000
 overwrite_cache: true
 preprocessing_num_workers: 16
-dataloader_num_workers: 4

 ### output
-output_dir: saves/qwen2_5vl-7b/lora/sft
+output_dir: saves/qwen3-vl-4b/lora/sft
 logging_steps: 10
 save_steps: 500
 plot_loss: true

@@ -14,7 +14,6 @@ dataset: identity,alpaca_en_demo
 template: llama3
 cutoff_len: 2048
 max_samples: 1000
 overwrite_cache: true
 preprocessing_num_workers: 16
-dataloader_num_workers: 4

@@ -14,7 +14,6 @@ dataset: identity,alpaca_en_demo
 template: llama3
 cutoff_len: 2048
 max_samples: 1000
 overwrite_cache: true
 preprocessing_num_workers: 16
-dataloader_num_workers: 4

@@ -14,7 +14,6 @@ dataset: identity,alpaca_en_demo
 template: llama3
 cutoff_len: 2048
 max_samples: 1000
 overwrite_cache: true
 preprocessing_num_workers: 16
-dataloader_num_workers: 4

@@ -1,5 +1,5 @@
 ### model
-model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
+model_name_or_path: Qwen/Qwen3-4B-Instruct-2507
 quantization_bit: 4
 quantization_method: bnb
 double_quantization: false
@@ -14,15 +14,14 @@ lora_target: all

 ### dataset
 dataset: identity,alpaca_en_demo
-template: llama3
+template: qwen3_nothink
 cutoff_len: 2048
 max_samples: 1000
 overwrite_cache: true
 preprocessing_num_workers: 16
-dataloader_num_workers: 4

 ### output
-output_dir: saves/llama3-8b/lora/sft
+output_dir: saves/qwen3-4b/lora/sft
 logging_steps: 10
 save_steps: 500
 plot_loss: true

@@ -1,5 +1,5 @@
 ### model
-model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
+model_name_or_path: Qwen/Qwen3-4B-Instruct-2507
 quantization_bit: 4 # choices: [8 (bnb/hqq/eetq), 4 (bnb/hqq), 3 (hqq), 2 (hqq)]
 quantization_method: bnb # choices: [bnb, hqq, eetq]
 trust_remote_code: true
@@ -13,15 +13,14 @@ lora_target: all

 ### dataset
 dataset: identity,alpaca_en_demo
-template: llama3
+template: qwen3_nothink
 cutoff_len: 2048
 max_samples: 1000
 overwrite_cache: true
 preprocessing_num_workers: 16
-dataloader_num_workers: 4

 ### output
-output_dir: saves/llama3-8b/lora/sft
+output_dir: saves/qwen3-4b/lora/sft
 logging_steps: 10
 save_steps: 500
 plot_loss: true

pyproject.toml

@@ -41,15 +41,14 @@ dependencies = [
     "torch>=2.4.0",
     "torchvision>=0.19.0",
     "torchaudio>=2.4.0",
-    "transformers>=4.49.0,<=4.56.2,!=4.52.0; python_version < '3.10'",
-    "transformers>=4.49.0,<=4.57.1,!=4.52.0,!=4.57.0; python_version >= '3.10'",
+    "transformers>=4.51.0,<=4.57.1,!=4.52.0,!=4.57.0",
     "datasets>=2.16.0,<=4.0.0",
     "accelerate>=1.3.0,<=1.11.0",
     "peft>=0.14.0,<=0.17.1",
-    "trl>=0.8.6,<=0.9.6",
+    "trl>=0.18.0,<=0.24.0",
     "torchdata>=0.10.0,<=0.11.0",
     # gui
-    "gradio>=4.38.0,<=6.2.0",
+    "gradio>=4.38.0,<=5.50.0",
     "matplotlib>=3.7.0",
     "tyro<0.9.0",
     # ops

scripts/vllm_infer.py

@@ -14,9 +14,12 @@

 import gc
 import json
+import time

 import av
 import fire
+from datasets import load_dataset
+from eval_bleu_rouge import compute_metrics
 from tqdm import tqdm
 from transformers import Seq2SeqTrainingArguments

@@ -51,6 +54,7 @@ def vllm_infer(
     max_samples: int | None = None,
     vllm_config: str = "{}",
     save_name: str = "generated_predictions.jsonl",
+    matrix_save_name: str = None,
     temperature: float = 0.95,
     top_p: float = 0.7,
     top_k: int = 50,

@@ -117,6 +121,7 @@ def vllm_infer(
     if isinstance(model_args.vllm_config, dict):
         engine_args.update(model_args.vllm_config)

+    model_preparation_start_time = time.time()
     llm = LLM(**engine_args)

     # load datasets

@@ -142,6 +147,7 @@ def vllm_infer(
     all_prompts, all_preds, all_labels = [], [], []
     need_video_kwargs = _need_video_kwargs(template)

+    model_predict_start_time = time.time()
     # Add batch process to avoid the issue of too many files opened
     for i in tqdm(range(0, len(train_dataset), batch_size), desc="Processing batched inference"):
         vllm_inputs, prompts, labels = [], [], []

@@ -218,6 +224,7 @@ def vllm_infer(
         all_labels.extend(labels)
         gc.collect()

+    model_predict_end_time = time.time()
     # Write all results at once outside the loop
     with open(save_name, "w", encoding="utf-8") as f:
         for text, pred, label in zip(all_prompts, all_preds, all_labels):

@@ -227,6 +234,49 @@ def vllm_infer(
     print(f"{len(all_prompts)} total generated results have been saved at {save_name}.")
     print("*" * 70)

+    # Write all matrix results when matrix_save_name is not None,
+    # The result matrix is referencing src.llamafactory.train.sft.workflow.run_sft # 127~132
+    # trainer.save_metrics("predict", predict_results.metrics)
+    #
+    # {
+    #     "predict_bleu-4": 4.349975,
+    #     "predict_model_preparation_time": 0.0128,
+    #     "predict_rouge-1": 21.873359375,
+    #     "predict_rouge-2": 4.144340625,
+    #     "predict_rouge-l": 10.83949375,
+    #     "predict_runtime": 131.664,
+    #     "predict_samples_per_second": 0.076,
+    #     "predict_steps_per_second": 0.008
+    # }
+    #
+    if matrix_save_name is not None:
+        predict_time = model_predict_end_time - model_predict_start_time
+        preparation_time = model_predict_start_time - model_preparation_start_time
+
+        start_time = time.time()
+        dataset = load_dataset("json", data_files=save_name, split="train")
+        dataset = dataset.map(compute_metrics, num_proc=8, remove_columns=dataset.column_names)
+        score_dict = dataset.to_dict()
+
+        average_score = {}
+        for task, scores in sorted(score_dict.items(), key=lambda x: x[0]):
+            score = sum(scores) / len(scores) if scores else 0.0
+            print(f"predict_{task}: {score:.4f}")
+            average_score["predict_" + task] = score
+
+        average_score["predict_model_preparation_time"] = preparation_time
+        average_score["predict_runtime"] = predict_time
+        num_steps = len(range(0, len(train_dataset), batch_size))
+        average_score["predict_samples_per_second"] = len(dataset) / predict_time if predict_time > 0 else 0.0
+        average_score["predict_steps_per_second"] = num_steps / predict_time if predict_time > 0 else 0.0
+
+        with open(matrix_save_name, "w", encoding="utf-8") as f:
+            json.dump(average_score, f, indent=4)
+
+        print("*" * 70)
+        print(f"\nDone in {time.time() - start_time:.3f}s.\nScore file saved to {matrix_save_name}.")
+        print("*" * 70)


 if __name__ == "__main__":
     fire.Fire(vllm_infer)

@@ -1673,6 +1673,43 @@ register_template(
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="minimax1",
|
||||
format_user=StringFormatter(
|
||||
slots=[
|
||||
"<beginning_of_sentence>user name=user\n{{content}}<end_of_sentence>\n<beginning_of_sentence>ai name=assistant\n"
|
||||
]
|
||||
),
|
||||
format_assistant=StringFormatter(slots=["{{content}}<end_of_sentence>\n"]),
|
||||
format_system=StringFormatter(
|
||||
slots=["<beginning_of_sentence>system ai_setting=assistant\n{{content}}<end_of_sentence>\n"]
|
||||
),
|
||||
format_function=FunctionFormatter(slots=["{{content}}<end_of_sentence>\n"], tool_format="minimax1"),
|
||||
format_observation=StringFormatter(
|
||||
slots=[
|
||||
"<beginning_of_sentence>tool name=tools\n{{content}}<end_of_sentence>\n<beginning_of_sentence>ai name=assistant\n"
|
||||
]
|
||||
),
|
||||
format_tools=ToolFormatter(tool_format="minimax1"),
|
||||
default_system="You are a helpful assistant.",
|
||||
stop_words=["<end_of_sentence>"],
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="minimax2",
|
||||
format_user=StringFormatter(slots=["]~b]user\n{{content}}[e~[\n]~b]ai\n"]),
|
||||
format_assistant=StringFormatter(slots=["{{content}}[e~[\n"]),
|
||||
format_system=StringFormatter(slots=["]~!b[]~b]system\n{{content}}[e~[\n"]),
|
||||
format_function=FunctionFormatter(slots=["{{content}}[e~[\n"], tool_format="minimax2"),
|
||||
format_observation=StringFormatter(slots=["]~b]tool\n<response>{{content}}</response>[e~[\n]~b]ai\n"]),
|
||||
format_tools=ToolFormatter(tool_format="minimax2"),
|
||||
default_system="You are a helpful assistant. Your name is MiniMax-M2.1 and is built by MiniMax.",
|
||||
stop_words=["[e~["],
|
||||
template_class=ReasoningTemplate,
|
||||
)
|
||||
|
||||
|
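For orientation, a single-turn minimax1 conversation renders roughly as follows (user text is illustrative; the assistant reply is then appended with a trailing <end_of_sentence>):

<beginning_of_sentence>system ai_setting=assistant
You are a helpful assistant.<end_of_sentence>
<beginning_of_sentence>user name=user
Hi!<end_of_sentence>
<beginning_of_sentence>ai name=assistant
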
# mistral tokenizer v3 tekken
register_template(
    name="ministral",

@@ -61,6 +61,21 @@ LLAMA3_TOOL_PROMPT = (
    "Do not use variables.\n\n{tool_text}"
)

MINIMAX_M1_TOOL_PROMPT = (
    "You are provided with these tools:\n<tools>\n{tool_text}</tools>\n\n"
    "If you need to call tools, please respond with <tool_calls></tool_calls> XML tags, and provide tool-name and "
    "json-object of arguments, following the format below:\n<tool_calls>\n"
    """{{"name": <tool-name-1>, "arguments": <args-json-object-1>}}\n...\n</tool_calls>"""
)

MINIMAX_M2_TOOL_PROMPT = (
    "\n\n# Tools\n\nYou may call one or more tools to assist with the user query.\n"
    "Here are the tools available in JSONSchema format:\n\n<tools>\n{tool_text}</tools>\n\n"
    "When making tool calls, use XML format to invoke tools and pass parameters:\n"
    """\n<minimax:tool_call>\n<invoke name="tool-name-1">\n<parameter name="param-key-1">param-value-1</parameter>\n"""
    """<parameter name="param-key-2">param-value-2</parameter>\n...\n</invoke>\n</minimax:tool_call>"""
)

QWEN_TOOL_PROMPT = (
    "\n\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\n"
    "You are provided with function signatures within <tools></tools> XML tags:\n<tools>{tool_text}"
@@ -253,6 +268,109 @@ class Llama3ToolUtils(ToolUtils):
        return content


class MiniMaxM1ToolUtils(ToolUtils):
    r"""MiniMax-M1 tool using template."""

    @override
    @staticmethod
    def tool_formatter(tools: list[dict[str, Any]]) -> str:
        tool_text = ""
        for tool in tools:
            tool = tool.get("function", "") if tool.get("type") == "function" else tool
            tool_text += json.dumps(tool, ensure_ascii=False) + "\n"

        return MINIMAX_M1_TOOL_PROMPT.format(tool_text=tool_text)

    @override
    @staticmethod
    def function_formatter(functions: list["FunctionCall"]) -> str:
        function_texts = []
        for func in functions:
            name, arguments = func.name, json.loads(func.arguments)
            function_texts.append(json.dumps({"name": name, "arguments": arguments}, ensure_ascii=False))

        return "<tool_calls>\n" + "\n".join(function_texts) + "\n</tool_calls>"

    @override
    @staticmethod
    def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
        regex = re.compile(r"<tool_calls>\s*(.+?)\s*</tool_calls>", re.DOTALL)
        tool_match = re.search(regex, content)
        if not tool_match:
            return content

        tool_calls_content = tool_match.group(1)
        results = []
        for line in tool_calls_content.split("\n"):
            line = line.strip()
            if not line:
                continue

            try:
                tool_call = json.loads(line)
                results.append(FunctionCall(tool_call["name"], json.dumps(tool_call["arguments"], ensure_ascii=False)))
            except json.JSONDecodeError:
                continue

        return results

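A quick round trip of the M1 format, derived directly from the formatter and extractor above (values are illustrative):

out = '<tool_calls>\n{"name": "get_weather", "arguments": {"city": "Paris"}}\n</tool_calls>'
MiniMaxM1ToolUtils.tool_extractor(out)
# -> [FunctionCall("get_weather", '{"city": "Paris"}')]
# function_formatter() reproduces the same <tool_calls> block from that FunctionCall.
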
class MiniMaxM2ToolUtils(ToolUtils):
    r"""MiniMax-M2 tool using template."""

    @override
    @staticmethod
    def tool_formatter(tools: list[dict[str, Any]]) -> str:
        tool_text = ""
        for tool in tools:
            tool = tool.get("function", "") if tool.get("type") == "function" else tool
            tool_text += "<tool>" + json.dumps(tool, ensure_ascii=False) + "</tool>\n"

        return MINIMAX_M2_TOOL_PROMPT.format(tool_text=tool_text)

    @override
    @staticmethod
    def function_formatter(functions: list["FunctionCall"]) -> str:
        function_texts = []
        for func in functions:
            name, arguments = func.name, json.loads(func.arguments)
            prompt = f'<invoke name="{name}">'
            for key, value in arguments.items():
                prompt += f'\n<parameter name="{key}">'
                if not isinstance(value, str):
                    value = json.dumps(value, ensure_ascii=False)
                prompt += value + "</parameter>"
            prompt += "\n</invoke>"
            function_texts.append(prompt)

        return "<minimax:tool_call>\n" + "\n".join(function_texts) + "\n</minimax:tool_call>"

    @override
    @staticmethod
    def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
        regex = re.compile(r"<minimax:tool_call>\s*(.+?)\s*</minimax:tool_call>", re.DOTALL)
        tool_match = re.search(regex, content)
        if not tool_match:
            return content

        tool_calls_content = tool_match.group(1)
        invoke_regex = re.compile(r"<invoke name=\"(.*?)\">(.*?)</invoke>", re.DOTALL)
        results = []

        for func_name, params_block in re.findall(invoke_regex, tool_calls_content):
            args_dict = {}
            param_pattern = re.compile(r"<parameter name=\"(.*?)\">(.*?)</parameter>", re.DOTALL)
            for key, raw_value in re.findall(param_pattern, params_block):
                value = raw_value.strip()
                try:
                    parsed_value = json.loads(value)
                except json.JSONDecodeError:
                    parsed_value = raw_value
                args_dict[key] = parsed_value

            results.append(FunctionCall(func_name.strip(), json.dumps(args_dict, ensure_ascii=False)))

        return results

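The M2 extractor parses each <parameter> value as JSON and falls back to the raw string; a sketch derived from the code above (values are illustrative):

out = (
    '<minimax:tool_call>\n<invoke name="get_weather">\n'
    '<parameter name="city">Paris</parameter>\n<parameter name="days">3</parameter>\n'
    '</invoke>\n</minimax:tool_call>'
)
MiniMaxM2ToolUtils.tool_extractor(out)
# -> [FunctionCall("get_weather", '{"city": "Paris", "days": 3}')]
# "days" parses as a JSON int; "Paris" is not valid JSON, so it stays a raw string.
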
class MistralToolUtils(ToolUtils):
    r"""Mistral v0.3 tool using template."""

@@ -432,6 +550,8 @@ TOOLS = {
    "default": DefaultToolUtils(),
    "glm4": GLM4ToolUtils(),
    "llama3": Llama3ToolUtils(),
    "minimax1": MiniMaxM1ToolUtils(),
    "minimax2": MiniMaxM2ToolUtils(),
    "mistral": MistralToolUtils(),
    "qwen": QwenToolUtils(),
    "glm4_moe": GLM4MOEToolUtils(),

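These utilities are resolved by name through a template's tool_format; a minimal sketch (tools and model_response are placeholders):

utils = TOOLS["minimax2"]
system_block = utils.tool_formatter(tools)     # rendered into the system prompt
parsed = utils.tool_extractor(model_response)  # str if no call, else list[FunctionCall]
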
@@ -63,6 +63,7 @@ MCA_SUPPORTED_MODELS = {
    "qwen2",
    "qwen2_vl",
    "qwen2_5_vl",
    "qwen3_vl",
    "qwen3",
    "qwen3_moe",
    "qwen3_next",
@@ -1071,6 +1072,40 @@ register_model_group(
)


register_model_group(
    models={
        "MiniMax-Text-01-Instruct": {
            DownloadSource.DEFAULT: "MiniMaxAI/MiniMax-Text-01-hf",
            DownloadSource.MODELSCOPE: "MiniMaxAI/MiniMax-Text-01",
        },
        "MiniMax-M1-40k-Thinking": {
            DownloadSource.DEFAULT: "MiniMaxAI/MiniMax-M1-40k-hf",
            DownloadSource.MODELSCOPE: "MiniMaxAI/MiniMax-M1-40k-hf",
        },
        "MiniMax-M1-80k-Thinking": {
            DownloadSource.DEFAULT: "MiniMaxAI/MiniMax-M1-80k-hf",
            DownloadSource.MODELSCOPE: "MiniMaxAI/MiniMax-M1-80k-hf",
        },
    },
    template="minimax1",
)


register_model_group(
    models={
        "MiniMax-M2-Thinking": {
            DownloadSource.DEFAULT: "MiniMaxAI/MiniMax-M2",
            DownloadSource.MODELSCOPE: "MiniMaxAI/MiniMax-M2",
        },
        "MiniMax-M2.1-Thinking": {
            DownloadSource.DEFAULT: "MiniMaxAI/MiniMax-M2.1",
            DownloadSource.MODELSCOPE: "MiniMaxAI/MiniMax-M2.1",
        },
    },
    template="minimax2",
)


register_model_group(
    models={
        "Granite-3.0-1B-A400M-Base": {

@@ -19,7 +19,7 @@
from collections import OrderedDict


VERSION = "0.9.4.dev0"
VERSION = "0.9.4"


def print_env() -> None:

@@ -94,11 +94,11 @@ def check_version(requirement: str, mandatory: bool = False) -> None:

def check_dependencies() -> None:
    r"""Check the version of the required packages."""
    check_version("transformers>=4.49.0,<=4.57.1")
    check_version("transformers>=4.51.0,<=4.57.1")
    check_version("datasets>=2.16.0,<=4.0.0")
    check_version("accelerate>=1.3.0,<=1.11.0")
    check_version("peft>=0.14.0,<=0.17.1")
    check_version("trl>=0.8.6,<=0.9.6")
    check_version("trl>=0.18.0,<=0.24.0")


def calculate_tps(dataset: list[dict[str, Any]], metrics: dict[str, float], stage: Literal["sft", "rm"]) -> float:

@@ -173,7 +173,7 @@ class BaseModelArguments:
        default=True,
        metadata={"help": "Whether or not to use KV cache in generation."},
    )
    use_v1_kernels: bool = field(
    use_v1_kernels: bool | None = field(
        default=False,
        metadata={"help": "Whether or not to use high-performance kernels in training."},
    )

@@ -216,9 +216,9 @@ def load_model(
            "You are trying to use an experimental kernels feature; please note that this feature "
            "is not supported for all models. If you get any error, please disable this feature or report the issue."
        )
        from ..v1.plugins.model_plugins.kernels.registry import apply_available_kernels
        from ..v1.plugins.model_plugins.kernels.interface import apply_default_kernels

        model = apply_available_kernels(model)
        model = apply_default_kernels(model=model, include_kernels=model_args.use_v1_kernels)

    trainable_params, all_param = count_parameters(model)
    if is_trainable:

@@ -26,6 +26,7 @@ import torch.nn.functional as F
from transformers import Trainer
from trl import DPOTrainer
from trl.trainer import disable_dropout_in_model
from trl.trainer.utils import prepare_deepspeed
from typing_extensions import override

from ...extras.constants import IGNORE_INDEX
@@ -95,7 +96,7 @@ class CustomDPOTrainer(DPOTrainer):
                if not (
                    getattr(ref_model, "is_loaded_in_8bit", False) or getattr(ref_model, "is_loaded_in_4bit", False)
                ):  # quantized models are already set on the correct device
                    self.ref_model = self._prepare_deepspeed(self.ref_model)
                    self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
                else:
                    self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
                    self.ref_model.eval()
@@ -210,7 +211,7 @@ class CustomDPOTrainer(DPOTrainer):
    @override
    def concatenated_forward(
        self, model: "PreTrainedModel", batch: dict[str, "torch.Tensor"], is_ref_model: bool = False
    ) -> tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
    ) -> dict[str, "torch.Tensor"]:
        r"""Compute the sum log probabilities of the labels under given logits if loss_type is not IPO, ORPO or SimPO.

        Otherwise the average log probabilities.
@@ -230,11 +231,18 @@ class CustomDPOTrainer(DPOTrainer):
        chosen_logps, rejected_logps = all_logps.split(batch_size, dim=0)
        chosen_logits, rejected_logits = all_logits.split(batch_size, dim=0)
        chosen_length, _ = valid_length.split(batch_size, dim=0)

        if self.loss_type in ["ipo", "orpo", "simpo"]:
            return chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_logps
            chosen_logps_avg = chosen_logps
        else:
            return chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_logps / chosen_length
            chosen_logps_avg = chosen_logps / chosen_length

        return {
            "chosen_logps": chosen_logps,
            "rejected_logps": rejected_logps,
            "chosen_logits": chosen_logits,
            "rejected_logits": rejected_logits,
            "chosen_logps_avg": chosen_logps_avg,
        }

    @override
    def compute_reference_log_probs(
@@ -252,9 +260,9 @@ class CustomDPOTrainer(DPOTrainer):
            ref_context = nullcontext()

        with torch.no_grad(), ref_context:
            reference_chosen_logps, reference_rejected_logps, *_ = self.concatenated_forward(
                ref_model, batch, is_ref_model=True
            )
            ref_output = self.concatenated_forward(ref_model, batch, is_ref_model=True)
            reference_chosen_logps = ref_output["chosen_logps"]
            reference_rejected_logps = ref_output["rejected_logps"]

        return reference_chosen_logps, reference_rejected_logps

@@ -267,13 +275,13 @@ class CustomDPOTrainer(DPOTrainer):
    ) -> tuple["torch.Tensor", dict[str, "torch.Tensor"]]:
        r"""Compute the DPO loss and other metrics for the given batch of inputs for train or test."""
        metrics = {}
        (
            policy_chosen_logps,
            policy_rejected_logps,
            policy_chosen_logits,
            policy_rejected_logits,
            policy_chosen_logps_avg,
        ) = self.concatenated_forward(model, batch)

        model_output = self.concatenated_forward(model, batch)
        policy_chosen_logps = model_output["chosen_logps"]
        policy_rejected_logps = model_output["rejected_logps"]
        policy_chosen_logits = model_output["chosen_logits"]
        policy_rejected_logits = model_output["rejected_logits"]
        policy_chosen_logps_avg = model_output["chosen_logps_avg"]

        reference_chosen_logps, reference_rejected_logps = self.compute_reference_log_probs(model, batch)
        losses, chosen_rewards, rejected_rewards = self.compute_preference_loss(

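The dict return mirrors newer TRL DPOTrainer conventions, so call sites select keys instead of relying on tuple order. A minimal sketch; the ftx usage is an assumption based on the surrounding trainer, not shown in this hunk:

output = self.concatenated_forward(model, batch)
sft_loss = -output["chosen_logps_avg"]  # hypothetical use: the optional ftx SFT term
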
@@ -25,6 +25,7 @@ import torch
from transformers import Trainer
from trl import KTOTrainer
from trl.trainer import disable_dropout_in_model
from trl.trainer.utils import prepare_deepspeed
from typing_extensions import override

from ...extras.constants import IGNORE_INDEX
@@ -77,6 +78,13 @@ class CustomKTOTrainer(KTOTrainer):
        self.desirable_weight = finetuning_args.kto_chosen_weight
        self.undesirable_weight = finetuning_args.kto_rejected_weight
        self.ftx_gamma = finetuning_args.pref_ftx
        # trl: not all losses require a KL calculation
        self.calculate_KL = True
        if hasattr(self, "loss_type") and self.loss_type in ["apo_zero_unpaired"]:
            self.calculate_KL = False
        else:
            self.loss_type = "kto"

        Trainer.__init__(self, model=model, **kwargs)
        self.model_accepts_loss_kwargs = False  # overwrite trainer's default behavior
@@ -90,7 +98,7 @@ class CustomKTOTrainer(KTOTrainer):
                if not (
                    getattr(ref_model, "is_loaded_in_8bit", False) or getattr(ref_model, "is_loaded_in_4bit", False)
                ):  # quantized models are already set on the correct device
                    self.ref_model = self._prepare_deepspeed(self.ref_model)
                    self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
                else:
                    self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
                    self.ref_model.eval()

@@ -11,14 +11,13 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""MCA (mcore_adapter) workflows for PT/SFT/DPO stages, aligned with LLaMA-Factory's workflow style."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
from collections.abc import Sequence
|
||||
from copy import deepcopy
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
from transformers import DataCollatorForSeq2Seq
|
||||
|
||||
from ...data import (
|
||||
SFTDataCollatorWith4DAttentionMask,
|
||||
@@ -44,11 +43,11 @@ from mcore_adapter.models import AutoConfig, AutoModel
|
||||
from mcore_adapter.trainer import DPOTrainer as McaDPOTrainer
|
||||
from mcore_adapter.trainer import McaTrainer
|
||||
from mcore_adapter.trainer.dpo_config import DPOConfig
|
||||
from mcore_adapter.training_args import Seq2SeqTrainingArguments as McaSeq2SeqTrainingArguments
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import DataCollatorForSeq2Seq, TrainerCallback
|
||||
from mcore_adapter.training_args import Seq2SeqTrainingArguments as McaSeq2SeqTrainingArguments
|
||||
from transformers import TrainerCallback
|
||||
|
||||
from ...hparams import DataArguments, FinetuningArguments, ModelArguments
|
||||
|
||||
@@ -76,7 +75,7 @@ def _data_collator_wrapper(data_collator: Any):
|
||||
return wrapper
|
||||
|
||||
|
||||
def _check_model_support(model_args: ModelArguments):
|
||||
def _check_model_support(model_args: "ModelArguments"):
|
||||
from transformers import AutoConfig as HfAutoConfig
|
||||
|
||||
config = HfAutoConfig.from_pretrained(
|
||||
@@ -87,11 +86,11 @@ def _check_model_support(model_args: ModelArguments):
|
||||
|
||||
|
||||
def run_pt(
|
||||
model_args: ModelArguments,
|
||||
data_args: DataArguments,
|
||||
training_args: McaSeq2SeqTrainingArguments,
|
||||
finetuning_args: FinetuningArguments,
|
||||
callbacks: list[TrainerCallback] | None = None,
|
||||
model_args: "ModelArguments",
|
||||
data_args: "DataArguments",
|
||||
training_args: "McaSeq2SeqTrainingArguments",
|
||||
finetuning_args: "FinetuningArguments",
|
||||
callbacks: Optional[list["TrainerCallback"]] = None,
|
||||
):
|
||||
tokenizer_module = load_tokenizer(model_args)
|
||||
tokenizer = tokenizer_module["tokenizer"]
|
||||
@@ -104,10 +103,7 @@ def run_pt(
|
||||
|
||||
_check_model_support(model_args)
|
||||
model = AutoModel.from_pretrained(model_args.model_name_or_path, training_args)
|
||||
|
||||
from transformers import DataCollatorForSeq2Seq
|
||||
|
||||
data_collator: DataCollatorForSeq2Seq = DataCollatorForSeq2Seq(
|
||||
data_collator = DataCollatorForSeq2Seq(
|
||||
tokenizer=tokenizer,
|
||||
pad_to_multiple_of=8,
|
||||
label_pad_token_id=IGNORE_INDEX,
|
||||
@@ -142,11 +138,11 @@ def run_pt(
|
||||
|
||||
|
||||
def run_sft(
|
||||
model_args: ModelArguments,
|
||||
data_args: DataArguments,
|
||||
training_args: McaSeq2SeqTrainingArguments,
|
||||
finetuning_args: FinetuningArguments,
|
||||
callbacks: list[TrainerCallback] | None = None,
|
||||
model_args: "ModelArguments",
|
||||
data_args: "DataArguments",
|
||||
training_args: "McaSeq2SeqTrainingArguments",
|
||||
finetuning_args: "FinetuningArguments",
|
||||
callbacks: Optional[list["TrainerCallback"]] = None,
|
||||
):
|
||||
# align packing flags
|
||||
# TODO: FIX SequencePacking
|
||||
@@ -166,7 +162,7 @@ def run_sft(
|
||||
model = AutoModel.from_pretrained(model_args.model_name_or_path, training_args)
|
||||
|
||||
# optional freezing for qwen2_vl, qwen2_5_vl
|
||||
if getattr(model.config, "hf_model_type", None) in ["qwen2_vl", "qwen2_5_vl"]:
|
||||
if getattr(model.config, "hf_model_type", None) in ["qwen2_vl", "qwen2_5_vl", "qwen3_vl"]:
|
||||
params_to_freeze = []
|
||||
if finetuning_args.freeze_vision_tower:
|
||||
params_to_freeze.extend(["vision_model.blocks", "vision_model.patch_embed"])
|
||||
@@ -220,11 +216,11 @@ def run_sft(
|
||||
|
||||
|
||||
def run_dpo(
|
||||
model_args: ModelArguments,
|
||||
data_args: DataArguments,
|
||||
training_args: McaSeq2SeqTrainingArguments,
|
||||
finetuning_args: FinetuningArguments,
|
||||
callbacks: list[TrainerCallback] | None = None,
|
||||
model_args: "ModelArguments",
|
||||
data_args: "DataArguments",
|
||||
training_args: "McaSeq2SeqTrainingArguments",
|
||||
finetuning_args: "FinetuningArguments",
|
||||
callbacks: Optional[list["TrainerCallback"]] = None,
|
||||
):
|
||||
tokenizer_module = load_tokenizer(model_args)
|
||||
tokenizer = tokenizer_module["tokenizer"]
|
||||
|
@@ -33,12 +33,12 @@ from transformers.trainer_pt_utils import remove_dummy_checkpoint
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
from transformers.utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME
from trl import PPOConfig, PPOTrainer
from trl.core import PPODecorators, logprobs_from_logits
from trl import __version__ as trl_version
from trl.models.utils import unwrap_model_for_generation
from typing_extensions import override

from ...extras import logging
from ...extras.misc import AverageMeter, count_parameters, get_current_device, get_logits_processor
from ...extras.misc import AverageMeter, count_parameters, get_current_device, get_logits_processor, torch_gc
from ..callbacks import FixValueHeadModelCallback, SaveProcessorCallback
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
from .ppo_utils import dump_layernorm, get_rewards_from_server, replace_model, restore_layernorm
@@ -83,6 +83,19 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
        if eval_dataset is not None:
            raise NotImplementedError("PPOTrainer does not support eval dataset yet.")

        # Check if TRL version is compatible (0.8.6 <= version <= 0.9.6)
        try:
            from transformers.utils.versions import require_version

            require_version(
                "trl>=0.8.6,<=0.9.6",
                "Incompatible TRL version detected. LLaMA-Factory ppo requires TRL version >=0.8.6,<=0.9.6. "
                f"Found version {trl_version}. Please install the correct version with: `pip install trl>=0.8.6,<=0.9.6`\n"
                "To fix: run `DISABLE_VERSION_CHECK=1 llamafactory-cli train example_ppo.yaml`\n",
            )
        except ImportError as e:
            raise e

        backward_batch_size = training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps
        ppo_config = PPOConfig(
            model_name=model_args.model_name_or_path,
@@ -406,7 +419,6 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
        return rewards.float().detach()  # use fp32 type

    @override
    @PPODecorators.empty_device_cache()
    def batched_forward_pass(
        self,
        model: "AutoModelForCausalLMWithValueHead",
@@ -420,6 +432,9 @@ class CustomPPOTrainer(PPOTrainer, Trainer):

        Subclass and override to inject custom behavior.
        """
        from trl.core import logprobs_from_logits

        torch_gc()
        bs = len(queries)
        fbs = self.config.mini_batch_size
        all_logprobs = []

@@ -108,7 +108,7 @@ def create_modelcard_and_push(
    elif training_args.push_to_hub:
        trainer.push_to_hub(**kwargs)
    else:
        trainer.create_model_card(license="other", **kwargs)  # prevent from connecting to hub
        Trainer.create_model_card(trainer, license="other", **kwargs)  # prevent from connecting to hub


def create_ref_model(

@@ -112,6 +112,13 @@ class ModelLoader:

        model = PeftPlugin(self.args.peft_config.name)(model, self.args.peft_config, self.is_train)

        if self.args.kernel_config is not None:
            from ..plugins.model_plugins.kernels.interface import KernelPlugin

            model = KernelPlugin(self.args.kernel_config.name)(
                model=model, include_kernels=self.args.kernel_config.get("include_kernels")
            )

        return model

src/llamafactory/v1/plugins/model_plugins/kernels/base.py (new file, 87 lines)
@@ -0,0 +1,87 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""The definition of the base kernel class.

Init Phase:
    1. Define base kernel class.
    2. Define abstract methods.
"""

from abc import ABC, abstractmethod
from typing import Any

from ....accelerator.helper import DeviceType, get_current_accelerator
from ....utils.types import HFModel


class BaseKernel(ABC):
    r"""Base class for all kernel implementations.

    Subclasses must implement the abstract methods and define the required class attributes.
    """

    _kernel_id: Any = ""  # kernel ID, any hashable value to identify a kernel implementation
    _device: DeviceType = DeviceType.CPU  # "cuda", "npu", "cpu", etc.

    @classmethod
    def get_kernel_id(cls) -> str:
        r"""Returns the unique identifier for the kernel."""
        return cls._kernel_id

    @classmethod
    def get_device(cls) -> str:
        r"""Returns the device type associated with the kernel (e.g., "cuda", "npu", "cpu")."""
        return cls._device

    @classmethod
    def check_deps(cls) -> bool:
        r"""Checks whether the required dependencies for the kernel are available.

        Returns:
            bool: ``True`` if dependencies are met, ``False`` otherwise.

        .. note::
            In explicit mode, if a user specifies an implementation but this check fails,
            it should raise an error instead of silently switching.
            Kernels can override this method to implement custom dependency checks.
        """
        if cls._device != get_current_accelerator().type:
            return False
        return True

    @classmethod
    @abstractmethod
    def apply(cls, **kwargs) -> HFModel:
        r"""Applies the kernel optimization to the model.

        Args:
            **kwargs: Arbitrary keyword arguments, usually containing the model instance and the kernel configuration.

        Returns:
            HFModel: The model with the kernel applied.

        Raises:
            RuntimeError: If the kernel dependencies are not met.
            NotImplementedError: If the method is not implemented by the subclass.

        Example:
            >>> from llamafactory.v1.plugins.model_plugins.kernels.interface import apply_kernel
            >>> model = HFModel(config=config)
            >>> model = apply_kernel(model=model, kernel_id="npu_fused_moe")
        """
        if not cls.check_deps():
            raise RuntimeError(f"{cls.__name__} is not available but {cls.__name__} kernel was called.")
        raise NotImplementedError
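A minimal sketch of a third-party kernel built on this base class; the import paths follow the relative imports above, and MyNoopKernel with its id are hypothetical:

from llamafactory.v1.accelerator.helper import DeviceType
from llamafactory.v1.plugins.model_plugins.kernels.base import BaseKernel
from llamafactory.v1.plugins.model_plugins.kernels.registry import register_kernel

@register_kernel
class MyNoopKernel(BaseKernel):
    _kernel_id = "my_noop"    # any unique hashable value
    _device = DeviceType.CPU  # skipped at registration unless the current accelerator matches

    @classmethod
    def apply(cls, **kwargs):
        model = kwargs.get("model")
        if model is None:
            raise ValueError(f"HFModel instance is required for {cls.__name__}.")
        return model  # a real kernel would patch module forwards here
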
@@ -1,23 +0,0 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from enum import Enum


class KernelType(str, Enum):
    RMSNORM = "rmsnorm"
    SWIGLU = "swiglu"
    FLASH_ATTENTION = "flash_attention"
    ROPE = "rope"
    MOE = "moe"
src/llamafactory/v1/plugins/model_plugins/kernels/interface.py (new file, 132 lines)
@@ -0,0 +1,132 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""The definition of the kernel interface.

Init Phase:
    1. Scan all kernels.
    2. Register default kernels.
    3. Define kernel plugin.
"""

import importlib
from pathlib import Path

from ....utils.logging import get_logger
from ....utils.plugin import BasePlugin
from .registry import Registry


logger = get_logger(__name__)


def scan_all_kernels():
    r"""Scan all kernels in the ``ops`` directory.

    Scans the ``ops`` directory for all ``.py`` files and attempts to import them.
    Importing triggers the :func:`~registry.register_kernel` decorator, which automatically registers the kernels.

    Returns:
        dict[str, type[BaseKernel]]: A dictionary of registered kernels.

    .. note::
        This function assumes that the ``ops`` directory is located in the same directory as this file.
        It recursively searches for ``.py`` files and constructs the module path for import.
    """
    ops_path = Path(__file__).parent / "ops"

    if not ops_path.exists():
        return

    base_package = __package__

    for file_path in ops_path.rglob("*.py"):
        if file_path.name == "__init__.py":
            continue

        # calculate the relative path:
        #   file_path = .../kernels_v2/ops/mlp/npu_swiglu.py
        #   rel_path  = ops/mlp/npu_swiglu.py
        rel_path = file_path.relative_to(Path(__file__).parent)

        # build the module path:
        module_name = ".".join(rel_path.parts)[:-3]
        full_module_name = f"{base_package}.{module_name}"

        try:
            importlib.import_module(full_module_name)
        except Exception as e:
            logger.warning(f"[Kernel Registry] Failed to import {full_module_name} when loading kernels: {e}")

    return Registry.get_registered_kernels()


default_kernels = scan_all_kernels()


def get_default_kernels():
    r"""Get a list of default registered kernel IDs.

    Returns:
        list[str]: List of kernel IDs.
    """
    return list(default_kernels.keys())


def apply_kernel(kernel_id: str, **kwargs):
    r"""Applies a specific kernel to the model.

    Args:
        kernel_id (str): The ID of the kernel to apply.
        **kwargs: Keyword arguments passed to the kernel application function.
            Typically includes the model instance.

    Returns:
        HFModel: The model with the applied kernel.
    """
    kernel = default_kernels.get(kernel_id)
    if kernel is None:
        raise ValueError(f"Kernel {kernel_id} not found")
    return kernel.apply(**kwargs)


class KernelPlugin(BasePlugin):
    r"""Plugin for managing kernel optimizations."""

    pass


@KernelPlugin("auto").register
def apply_default_kernels(**kwargs):
    r"""Applies all default registered kernels to the model.

    Args:
        **kwargs: Keyword arguments passed to the kernel application function.
            Typically includes the model instance and the include_kernels configuration.

    Returns:
        HFModel: The model with the applied kernels.
    """
    if not kwargs.get("include_kernels"):  # None/False/empty string
        return kwargs.get("model")
    elif kwargs.get("include_kernels") == "auto" or kwargs.get("include_kernels") is True:  # True/auto
        use_kernels = default_kernels.keys()
    else:
        use_kernels = kwargs.get("include_kernels").split(",")  # "kernel_id1,kernel_id2,kernel_id3"
    for kernel in use_kernels:
        if kernel not in default_kernels:
            raise ValueError(f"Kernel {kernel} not found")
        apply_kernel(kernel, **kwargs)
    return kwargs.get("model")
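Taken together, include_kernels accepts three shapes; a usage sketch (the available kernel ids depend on what the ops scan registered for the current accelerator):

from llamafactory.v1.plugins.model_plugins.kernels.interface import (
    apply_default_kernels,
    get_default_kernels,
)

print(get_default_kernels())  # e.g. ["npu_fused_moe", "npu_fused_swiglu", ...] on an NPU host
model = apply_default_kernels(model=model, include_kernels=True)                 # all defaults
model = apply_default_kernels(model=model, include_kernels="npu_fused_rmsnorm")  # explicit subset
model = apply_default_kernels(model=model, include_kernels=False)                # no-op
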
@@ -12,22 +12,49 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""The definition of NPU fused MoE kernels.
|
||||
|
||||
Init Phase:
|
||||
1. Define GMM functions.
|
||||
2. Define NPU fused MoE functions.
|
||||
3. Register NPU fused MoE kernel.
|
||||
|
||||
"""
|
||||
|
||||
import types
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch_npu
|
||||
|
||||
from .....accelerator.helper import DeviceType, is_torch_npu_available
|
||||
from .....utils.packages import is_transformers_version_greater_than
|
||||
from .....utils.types import HFModel
|
||||
from ..constants import KernelType
|
||||
from ..registry import MetaMoEKernel
|
||||
|
||||
try:
|
||||
import torch_npu
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
from ......accelerator.helper import DeviceType
|
||||
from ......utils.packages import is_transformers_version_greater_than
|
||||
from ......utils.types import HFModel
|
||||
from ...base import BaseKernel
|
||||
from ...registry import register_kernel
|
||||
|
||||
|
||||
class GmmFunction(torch.autograd.Function):
|
||||
r"""Custom autograd function for NPU Grouped Matrix Multiplication (GMM)."""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, x, weight, group_list):
|
||||
r"""Performs the forward pass of Grouped Matrix Multiplication.
|
||||
|
||||
Args:
|
||||
ctx: Context object to save tensors for backward pass.
|
||||
x (Tensor): Input tensor.
|
||||
weight (Tensor): Weight tensor.
|
||||
group_list (list): List of group sizes.
|
||||
|
||||
Returns:
|
||||
Tensor: The result of the grouped matrix multiplication.
|
||||
"""
|
||||
ctx.save_for_backward(x, weight)
|
||||
ctx.group_list = group_list
|
||||
|
||||
@@ -38,6 +65,15 @@ class GmmFunction(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
r"""Performs the backward pass of Grouped Matrix Multiplication.
|
||||
|
||||
Args:
|
||||
ctx: Context object containing saved tensors.
|
||||
grad_output (Tensor): Gradient with respect to the output.
|
||||
|
||||
Returns:
|
||||
tuple: Gradients with respect to input, weight, and None for group_list.
|
||||
"""
|
||||
input_tensor, weight = ctx.saved_tensors
|
||||
group_list = ctx.group_list
|
||||
|
||||
@@ -58,8 +94,20 @@ class GmmFunction(torch.autograd.Function):
|
||||
|
||||
|
||||
class HybridGmmFunction(torch.autograd.Function):
|
||||
r"""Custom autograd function for Hybrid Grouped Matrix Multiplication on NPU."""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, num_experts, *args):
|
||||
r"""Performs the forward pass of Hybrid GMM.
|
||||
|
||||
Args:
|
||||
ctx: Context object to save tensors.
|
||||
num_experts (int): Number of experts.
|
||||
*args: Variable length argument list containing inputs and weights.
|
||||
|
||||
Returns:
|
||||
tuple: The outputs of the grouped matrix multiplication.
|
||||
"""
|
||||
x_list = list(args[:num_experts])
|
||||
weight_list = list(args[num_experts:])
|
||||
|
||||
@@ -76,6 +124,15 @@ class HybridGmmFunction(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, *grad_outputs):
|
||||
r"""Performs the backward pass of Hybrid GMM.
|
||||
|
||||
Args:
|
||||
ctx: Context object containing saved tensors.
|
||||
*grad_outputs: Gradients with respect to the outputs.
|
||||
|
||||
Returns:
|
||||
tuple: Gradients with respect to inputs and weights.
|
||||
"""
|
||||
saved_tensors = ctx.saved_tensors
|
||||
num_experts = ctx.num_experts
|
||||
split_sizes = ctx.split_sizes
|
||||
@@ -119,10 +176,23 @@ class HybridGmmFunction(torch.autograd.Function):
|
||||
|
||||
|
||||
class NpuMoeFused:
|
||||
r"""Container for NPU fused MoE forward functions."""
|
||||
|
||||
@staticmethod
|
||||
def npu_moe_experts_forward(
|
||||
self, hidden_states: torch.Tensor, routing_weights: torch.Tensor, router_indices: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
r"""Forward pass for MoE experts using NPU fused operations.
|
||||
|
||||
Args:
|
||||
self: The MoE layer instance.
|
||||
hidden_states (Tensor): Input hidden states.
|
||||
routing_weights (Tensor): Routing weights.
|
||||
router_indices (Tensor): Router indices.
|
||||
|
||||
Returns:
|
||||
Tensor: Output tensor after expert computation.
|
||||
"""
|
||||
batch_size = hidden_states.shape[0]
|
||||
hidden_states = hidden_states.reshape(-1, self.hidden_size)
|
||||
permuted_hidden_states, row_ids_map = torch_npu.npu_moe_token_permute(
|
||||
@@ -138,6 +208,15 @@ class NpuMoeFused:
|
||||
|
||||
@staticmethod
|
||||
def npu_moe_sparse_block_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
r"""Forward pass for sparse MoE block using NPU optimization.
|
||||
|
||||
Args:
|
||||
self: The MoE sparse block instance.
|
||||
hidden_states (Tensor): Input hidden states.
|
||||
|
||||
Returns:
|
||||
Tensor: The routed output.
|
||||
"""
|
||||
batch_size = hidden_states.shape[0]
|
||||
hidden_states = hidden_states.reshape(-1, self.hidden_size)
|
||||
router_logits = self.gate(hidden_states)
|
||||
@@ -151,8 +230,19 @@ class NpuMoeFused:
|
||||
|
||||
|
||||
class Qwen3NpuMoeFused:
|
||||
r"""Container for Qwen3 NPU fused MoE forward functions."""
|
||||
|
||||
@staticmethod
|
||||
def qwen3moe_sparse_moe_block_forward(self, hidden_states: torch.Tensor):
|
||||
r"""Forward pass for Qwen3 sparse MoE block using NPU fused operations.
|
||||
|
||||
Args:
|
||||
self: The Qwen3 MoE block instance.
|
||||
hidden_states (Tensor): Input hidden states.
|
||||
|
||||
Returns:
|
||||
tuple: A tuple containing the next states and router logits.
|
||||
"""
|
||||
batch_size, sequence_length, hidden_dim = hidden_states.shape
|
||||
hidden_states = hidden_states.view(-1, hidden_dim)
|
||||
|
||||
@@ -206,14 +296,33 @@ if not is_transformers_version_greater_than("5.0.0"):
|
||||
}
|
||||
|
||||
|
||||
class NpuMoEFusedMoEKernel(MetaMoEKernel):
|
||||
type = KernelType.MOE
|
||||
device = DeviceType.NPU
|
||||
@register_kernel
|
||||
class NpuFusedMoEKernel(BaseKernel):
|
||||
r"""NPU Fused MoE Kernel implementation."""
|
||||
|
||||
_kernel_id = "npu_fused_moe"
|
||||
_device = DeviceType.NPU
|
||||
|
||||
@classmethod
|
||||
def apply(cls, model, **kwargs) -> HFModel:
|
||||
if not is_torch_npu_available():
|
||||
return model
|
||||
def apply(cls, **kwargs) -> HFModel:
|
||||
r"""Applies the NPU fused MoE kernel to the model.
|
||||
|
||||
Args:
|
||||
**kwargs: Keyword arguments containing the model.
|
||||
|
||||
Returns:
|
||||
HFModel: The model with patched MoE forward functions.
|
||||
|
||||
Raises:
|
||||
ValueError: If the model is not provided.
|
||||
RuntimeError: If dependencies are not met.
|
||||
"""
|
||||
model = kwargs.get("model", None)
|
||||
if model is None:
|
||||
raise ValueError(f"HFModel instance is required for {cls.__name__}.")
|
||||
|
||||
if not cls.check_deps():
|
||||
raise RuntimeError("torch_npu is not available but NpuMoEFusedMoEKernel was called.")
|
||||
|
||||
archs = getattr(model.config, "architectures", [])
|
||||
target_moe_mapping = None
|
||||
@@ -12,36 +12,71 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""The definition of NPU fused SwiGLU kernels.
|
||||
|
||||
Init Phase:
|
||||
1. Define SwiGLU forward functions.
|
||||
2. Register NPU fused SwiGLU kernel.
|
||||
|
||||
"""
|
||||
|
||||
import re
|
||||
import types
|
||||
|
||||
import torch
|
||||
|
||||
from .....accelerator.helper import DeviceType, is_torch_npu_available
|
||||
from .....utils.types import HFModel
|
||||
from ..constants import KernelType
|
||||
from ..registry import MetaSwiGluKernel
|
||||
from ......accelerator.helper import DeviceType
|
||||
from ......utils.types import HFModel
|
||||
from ...base import BaseKernel
|
||||
from ...registry import register_kernel
|
||||
|
||||
|
||||
def _npu_swiglu_forward(self, hidden_state):
|
||||
try:
|
||||
import torch_npu
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
def npu_swiglu_forward(self, hidden_state):
|
||||
r"""SwiGLU forward pass for NPU.
|
||||
|
||||
Args:
|
||||
self: The MLP layer instance.
|
||||
hidden_state (Tensor): Input hidden state.
|
||||
|
||||
Returns:
|
||||
Tensor: Output of SwiGLU.
|
||||
"""
|
||||
return self.down_proj(
|
||||
torch_npu.npu_swiglu(torch.cat((self.gate_proj(hidden_state), self.up_proj(hidden_state)), dim=-1), dim=-1)
|
||||
)
|
||||
|
||||
|
||||
def _npu_swiglu_glm4_forward(self, hidden_states):
|
||||
import torch_npu
|
||||
r"""SwiGLU forward pass for GLM4 on NPU.
|
||||
|
||||
Args:
|
||||
self: The GLM4 MLP layer instance.
|
||||
hidden_states (Tensor): Input hidden states.
|
||||
|
||||
Returns:
|
||||
Tensor: Output of SwiGLU.
|
||||
"""
|
||||
up_states = self.gate_up_proj(hidden_states)
|
||||
gate, up_states = up_states.chunk(2, dim=-1)
|
||||
return self.down_proj(torch_npu.npu_swiglu(torch.cat((gate, up_states), dim=-1), dim=-1))
|
||||
|
||||
|
||||
def _npu_swiglu_gemma3ntext_forward(self, hidden_states):
|
||||
import torch_npu
|
||||
r"""SwiGLU forward pass for Gemma3nText on NPU.
|
||||
|
||||
Args:
|
||||
self: The Gemma3nText MLP layer instance.
|
||||
hidden_states (Tensor): Input hidden states.
|
||||
|
||||
Returns:
|
||||
Tensor: Output of SwiGLU.
|
||||
"""
|
||||
gate_proj = self.gate_proj(hidden_states)
|
||||
if self.activation_sparsity > 0.0:
|
||||
gate_proj = self._gaussian_topk(gate_proj)
|
||||
@@ -51,12 +86,11 @@ def _npu_swiglu_gemma3ntext_forward(self, hidden_states):
|
||||
return down_proj
|
||||
|
||||
|
||||
class NpuSwiGluKernel(MetaSwiGluKernel):
|
||||
type = KernelType.SWIGLU
|
||||
device = DeviceType.NPU
|
||||
kernel = _npu_swiglu_forward
|
||||
@register_kernel
|
||||
class NpuSwiGluKernel(BaseKernel):
|
||||
r"""NPU Kernel for fused SwiGLU activation."""
|
||||
|
||||
# Don't apply the kernel to the following modules
|
||||
# just support apply to the following module layers
|
||||
expect_modules = frozenset(
|
||||
{
|
||||
"Qwen3VLMoeTextMLP",
|
||||
@@ -87,10 +121,29 @@ class NpuSwiGluKernel(MetaSwiGluKernel):
|
||||
}
|
||||
)
|
||||
|
||||
_kernel_id = "npu_fused_swiglu"
|
||||
_device = DeviceType.NPU
|
||||
|
||||
@classmethod
|
||||
def apply(cls, model, **kwargs) -> "HFModel":
|
||||
if not is_torch_npu_available():
|
||||
return model
|
||||
def apply(cls, **kwargs) -> "HFModel":
|
||||
r"""Applies the NPU fused SwiGLU kernel to the model.
|
||||
|
||||
Args:
|
||||
**kwargs: Keyword arguments containing the model.
|
||||
|
||||
Returns:
|
||||
HFModel: The model with patched SwiGLU forward functions.
|
||||
|
||||
Raises:
|
||||
ValueError: If the model is not provided.
|
||||
RuntimeError: If dependencies are not met.
|
||||
"""
|
||||
model = kwargs.get("model", None)
|
||||
if model is None:
|
||||
raise ValueError(f"HFModel instance is required for {cls.__name__}.")
|
||||
|
||||
if not cls.check_deps():
|
||||
raise RuntimeError("torch_npu is not available but NpuSwiGluKernel was called.")
|
||||
|
||||
# Mapping of specific mlp modules to their corresponding kernel implementations
|
||||
kernel_mapping = {
|
||||
@@ -109,7 +162,7 @@ class NpuSwiGluKernel(MetaSwiGluKernel):
|
||||
):
|
||||
# Bind function as an instance method to preserve `self` semantics
|
||||
# and replace the original forward
|
||||
kernel_func = kernel_mapping.get(module.__class__.__name__, _npu_swiglu_forward)
|
||||
kernel_func = kernel_mapping.get(module.__class__.__name__, npu_swiglu_forward)
|
||||
module.forward = types.MethodType(kernel_func, module)
|
||||
|
||||
return model
|
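For intuition, npu_swiglu over the concatenated projections computes the standard SwiGLU; a pure-PyTorch reference sketch of the same math (not the NPU implementation):

import torch
import torch.nn.functional as F

def swiglu_reference(gate_and_up: torch.Tensor) -> torch.Tensor:
    # npu_swiglu(cat([gate, up], dim=-1), dim=-1) is equivalent to silu(gate) * up
    gate, up = gate_and_up.chunk(2, dim=-1)
    return F.silu(gate) * up
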
@@ -11,40 +11,49 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""The definition of NPU fused RMSNorm kernels.

Init Phase:
    1. Define RMSNorm forward function.
    2. Register NPU fused RMSNorm kernel.
"""

import re
import types

from .....accelerator.helper import DeviceType, is_torch_npu_available
from .....utils.types import HFModel
from ..constants import KernelType
from ..registry import MetaRMSNormKernel
from ......accelerator.helper import DeviceType
from ......utils.types import HFModel
from ...base import BaseKernel
from ...registry import register_kernel


def _npu_rms_forward(self, hidden_states):
    """NPU forward implementation for RMSNorm.
def npu_rms_norm_forward(self, hidden_states):
    r"""NPU forward implementation for RMSNorm.

    Args:
        self: RMSNorm module instance with `weight` and `variance_epsilon`.
        hidden_states: Input hidden states tensor, same shape as the baseline.
        hidden_states (Tensor): Input hidden states tensor, same shape as the baseline.

    Returns:
        Normalized tensor consistent with the baseline RMSNorm behavior.
        Tensor: Normalized tensor consistent with the baseline RMSNorm behavior.
    """
    import torch_npu

    return torch_npu.npu_rms_norm(hidden_states, self.weight, epsilon=self.variance_epsilon)[0]


class NpuRMSNormKernel(MetaRMSNormKernel):
    """NPU kernel wrapper for RMSNorm that applies the replacement within a model."""
@register_kernel
class NpuRMSNormKernel(BaseKernel):
    r"""NPU kernel wrapper for RMSNorm that applies the replacement within a model."""

    type = KernelType.RMSNORM
    device = DeviceType.NPU
    kernel = _npu_rms_forward
    _kernel_id = "npu_fused_rmsnorm"
    _device = DeviceType.NPU

    @classmethod
    def apply(cls, model, **kwargs) -> HFModel:
        """Iterate the model and apply NPU-optimized forward to matched RMSNorm modules.
    def apply(cls, **kwargs) -> "HFModel":
        r"""Iterate the model and apply NPU-optimized forward to matched RMSNorm modules.

        Key points:
        - Match modules whose class name contains "RMSNorm" (case-insensitive).
@@ -52,10 +61,23 @@ class NpuRMSNormKernel(MetaRMSNormKernel):
          replace the original `forward`.
        - Do not modify weights, hyperparameters, or module structure to ensure
          numerical behavior and interface consistency.
        """
        if not is_torch_npu_available():
            return model

        Args:
            **kwargs: Keyword arguments containing the model.

        Returns:
            HFModel: The model with NPU fused RMSNorm.

        Raises:
            RuntimeError: If torch_npu is not available.
            ValueError: If the model is not provided.
        """
        model = kwargs.get("model")
        if model is None:
            raise ValueError(f"HFModel instance is required for {cls.__name__}.")

        if not cls.check_deps():
            raise RuntimeError(f"torch_npu is not available but {cls.__name__} was called.")
        rms_norm_pattern = re.compile("RMSNorm", re.IGNORECASE)

        for name, module in model.named_modules():
@@ -63,6 +85,6 @@ class NpuRMSNormKernel(MetaRMSNormKernel):
            if re.search(rms_norm_pattern, module.__class__.__name__):
                # Bind function as an instance method to preserve `self` semantics
                # and replace the original forward
                module.forward = types.MethodType(cls.kernel, module)
                module.forward = types.MethodType(npu_rms_norm_forward, module)

        return model
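For reference, the fused call stands in for the usual eager RMSNorm; a sketch of the baseline computation it replaces, matching the common transformers pattern:

import torch

def rms_norm_reference(hidden_states: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    input_dtype = hidden_states.dtype
    hidden_states = hidden_states.to(torch.float32)
    variance = hidden_states.pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + eps)
    return weight * hidden_states.to(input_dtype)
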
@@ -0,0 +1,146 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""The definition of NPU fused RoPE kernels.

Init Phase:
    1. Define RoPE forward functions.
    2. Register NPU fused RoPE kernel.
"""

import sys

import torch

from ......accelerator.helper import DeviceType
from ......utils.logging import get_logger
from ......utils.types import HFModel
from ...base import BaseKernel
from ...registry import register_kernel


logger = get_logger(__name__)

try:
    import torch_npu
except ImportError:
    pass


def _apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    r"""Applies Rotary Position Embedding to the query and key tensors using NPU optimization.

    Args:
        q (Tensor): Query tensor.
        k (Tensor): Key tensor.
        cos (Tensor): Cosine part of embedding.
        sin (Tensor): Sine part of embedding.
        position_ids (Tensor, optional): Position IDs. Default: ``None``.
        unsqueeze_dim (int): Dimension to unsqueeze cos and sin. Default: 1.

    Returns:
        tuple: (q_embed, k_embed) The embedded query and key tensors.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = torch_npu.npu_rotary_mul(q, cos, sin)
    k_embed = torch_npu.npu_rotary_mul(k, cos, sin)
    return q_embed, k_embed


def _apply_multimodal_rotary_pos_emb_qwen25_vl(q, k, cos, sin, mrope_section, unsqueeze_dim=1):
    r"""Applies Rotary Position Embedding with multimodal sections (Qwen2-VL) on NPU.

    Args:
        q (Tensor): Query tensor.
        k (Tensor): Key tensor.
        cos (Tensor): Cosine part of embedding.
        sin (Tensor): Sine part of embedding.
        mrope_section (Tensor): Multimodal RoPE section.
        unsqueeze_dim (int): Dimension to unsqueeze cos and sin. Default: 1.

    Returns:
        tuple: (q_embed, k_embed) The embedded query and key tensors.
    """
    mrope_section = mrope_section * 2
    cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(
        unsqueeze_dim
    )
    sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(
        unsqueeze_dim
    )

    q_embed = torch_npu.npu_rotary_mul(q, cos, sin)
    k_embed = torch_npu.npu_rotary_mul(k, cos, sin)
    return q_embed, k_embed


@register_kernel
class NpuRoPEKernel(BaseKernel):
    r"""NPU Kernel for Rotary Position Embedding."""

    _kernel_id = "npu_fused_rope"
    _device = DeviceType.NPU

    @classmethod
    def apply(cls, **kwargs) -> "HFModel":
        r"""Apply RoPE acceleration by monkey-patching `apply_rotary_pos_emb`.

        This function iterates through the model's modules to find attention layers,
        identifies the module where they are defined, and replaces the original
        `apply_rotary_pos_emb` function in that module's namespace with the
        NPU-accelerated version from this file.

        Args:
            **kwargs: Keyword arguments containing the model.

        Returns:
            HFModel: The model with patched RoPE functions.

        Raises:
            RuntimeError: If dependencies are not met.
            ValueError: If the model is not provided.
        """
        if not cls.check_deps():
            raise RuntimeError(f"torch_npu is not available but {cls.__name__} was called.")
        model = kwargs.get("model", None)
        if model is None:
            raise ValueError(f"HFModel instance is required for {cls.__name__}.")
        _modules = set()
        for module in model.modules():
            if "Attention" in module.__class__.__name__:
                module_name = module.__class__.__module__
                if module_name in _modules:
                    continue
                try:
                    target_module = sys.modules[module_name]
                    if hasattr(target_module, "apply_rotary_pos_emb"):
                        if getattr(target_module, "apply_rotary_pos_emb") is not _apply_rotary_pos_emb:
                            setattr(target_module, "apply_rotary_pos_emb", _apply_rotary_pos_emb)
                            _modules.add(module_name)
                    if hasattr(target_module, "apply_multimodal_rotary_pos_emb"):
                        if (
                            getattr(target_module, "apply_multimodal_rotary_pos_emb")
                            is not _apply_multimodal_rotary_pos_emb_qwen25_vl
                        ):
                            setattr(
                                target_module,
                                "apply_multimodal_rotary_pos_emb",
                                _apply_multimodal_rotary_pos_emb_qwen25_vl,
                            )
                            _modules.add(module_name)
                except Exception as e:
                    logger.warning_rank0_once(f"Failed to apply RoPE kernel to module {module_name}: {e}")
        return model
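The patch relies on attention layers resolving apply_rotary_pos_emb from their defining module at call time; a minimal illustration of the same mechanism (module path and stub are hypothetical):

import sys

def fused_rope_stub(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    raise NotImplementedError  # a real replacement would call the fused op here

mod = sys.modules["transformers.models.llama.modeling_llama"]  # illustrative target module
if getattr(mod, "apply_rotary_pos_emb", None) is not fused_rope_stub:
    mod.apply_rotary_pos_emb = fused_rope_stub  # every attention layer defined in mod now resolves the stub
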
@@ -12,247 +12,86 @@
|
||||
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC, ABCMeta, abstractmethod
from collections.abc import Callable
from typing import Any, Optional
"""The definition of kernel registry.

from ....accelerator.helper import DeviceType, get_current_accelerator
from ....utils.types import HFModel
from .constants import KernelType
Init Phase:
1. Define kernel registry.
2. Register kernels.
"""

from typing import Optional

from ....accelerator.helper import get_current_accelerator
from .base import BaseKernel


class KernelRegistry:
    _instance: Optional["KernelRegistry"] = None
    _initialized: bool = False

    def __new__(cls, *args: Any, **kwargs: Any) -> "KernelRegistry":
        if cls._instance is None:
            cls._instance = super().__new__(cls)

        return cls._instance

    def __init__(self) -> None:
        if self._initialized:
            return

        self._registry: dict[KernelType, dict[DeviceType, Callable[..., Any]]] = {}
        self._initialized = True

    def register(
        self, kernel_type: KernelType, device_type: DeviceType, kernel_impl: Callable[..., Any] | None
    ) -> None:
        """Register a kernel implementation.

        Args:
            kernel_type: the type of the kernel (e.g., KernelType.FLASH_ATTENTION).
            device_type: the device type the kernel is adapted to (e.g., DeviceType.CUDA).
            kernel_impl: the actual kernel function or class.
        """
        if kernel_type not in self._registry:
            self._registry[kernel_type] = {}

        if device_type in self._registry[kernel_type]:
            print(f"Warning: Overwriting kernel for {kernel_type.name} on {device_type.name}.")

        self._registry[kernel_type][device_type] = kernel_impl
        print(f"Registered kernel {kernel_type.name} for device {device_type.name}.")

    def get_kernel(self, kernel_type: KernelType, device_type: DeviceType) -> Callable[..., Any] | None:
        return self._registry.get(kernel_type, {}).get(device_type)
__all__ = ["Registry", "register_kernel"]


KERNEL_REGISTRY = KernelRegistry()
class Registry:
    r"""Registry for managing kernel implementations.


class AutoRegisterKernelMeta(ABCMeta):
    """Metaclass that automatically registers kernel classes upon creation.

    This metaclass checks if a newly created class has both `type` and `device`
    attributes defined. If so, it automatically registers the kernel in the
    global KERNEL_REGISTRY, eliminating the need for manual registration.

    To disable auto-registration for a specific class, set `auto_register = False`.
    Storage structure: ``{ "kernel_id": Class }``
    """

    def __new__(mcs, name, bases, namespace, **kwargs):
        cls = super().__new__(mcs, name, bases, namespace, **kwargs)

        # Check if auto-registration is disabled
        auto_register = namespace.get("auto_register", True)

        # Only auto-register if the class has both type and device attributes defined
        # and they are not None (skip base classes like MetaKernel itself)
        # and auto_register is True
        kernel_type = namespace.get("type")
        device_type = namespace.get("device")

        if auto_register and kernel_type is not None and device_type is not None:
            # Auto-register this kernel
            KERNEL_REGISTRY.register(kernel_type, device_type, cls)

        return cls
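
For reference, the removed auto-registration path worked purely as a class-creation side effect. A minimal sketch, assuming the names defined in this module (the MetaKernel base class appears just below); the kernel function itself is hypothetical:

def _identity_rms_norm(x, weight, eps):  # hypothetical kernel function
    return x

class SketchRMSNormKernel(MetaKernel):
    type = KernelType.RMSNORM  # both attributes non-None, so the metaclass
    device = DeviceType.NPU    # registers the class as a side effect
    kernel = _identity_rms_norm

    @classmethod
    def apply(cls, model, **kwargs):
        return model  # a real kernel would patch modules here

# Equivalent explicit call performed by AutoRegisterKernelMeta.__new__:
# KERNEL_REGISTRY.register(KernelType.RMSNORM, DeviceType.NPU, SketchRMSNormKernel)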


class MetaKernel(ABC, metaclass=AutoRegisterKernelMeta):
    """Base class for all kernel implementations.

    Subclasses are automatically registered when they define both `type` and `device`
    attributes. To disable auto-registration, set `auto_register = False`.

    Attributes:
        type: The kernel type (e.g., KernelType.RMSNORM). Must be set in subclasses.
        device: The device type (e.g., DeviceType.NPU). Must be set in subclasses.
        kernel: The actual kernel function or implementation.
        auto_register: Set to False to disable automatic registration (default: True).
    """

    type: KernelType | None = None
    device: DeviceType | None = None
    kernel: Callable | None = None
    _kernels: dict[str, type[BaseKernel]] = {}

    @classmethod
    @abstractmethod
    def apply(cls, model: HFModel, **kwargs) -> HFModel:
        """Apply the kernel to the model.
    def register(cls, kernel_cls: type[BaseKernel]):
        r"""Decorator to register a kernel class.

        This method should check if the kernel can be applied (e.g., dependencies
        are installed, target modules exist) and perform the kernel replacement.
        The class must inherit from :class:`BaseKernel` and specify ``_kernel_id`` and ``_device`` attributes.

        Args:
            model: The HuggingFace model to optimize.
            **kwargs: Additional arguments for kernel application.
            kernel_cls (type[BaseKernel]): The kernel class to register.

        Returns:
            The optimized model (may be the same object with modifications).
            type[BaseKernel]: The registered kernel class.

        Raises:
            TypeError: If the class does not inherit from :class:`BaseKernel`.
            ValueError: If the kernel ID is missing or already registered.
        """
        raise NotImplementedError
        if not issubclass(kernel_cls, BaseKernel):
            raise TypeError(f"Class {kernel_cls} must inherit from BaseKernel")
        kernel_id = kernel_cls.get_kernel_id()
        device = kernel_cls.get_device()

        # The device type of the current accelerator does not match the device
        # type required by the kernel, skip registration
        if device != get_current_accelerator().type:
            return

        if not kernel_id:
            raise ValueError(f"Kernel ID (_kernel_id) is needed for {kernel_cls} to register")

        if kernel_id in cls._kernels:
            raise ValueError(f"{kernel_id} already registered! The registered kernel is {cls._kernels[kernel_id]}")

        cls._kernels[kernel_id] = kernel_cls
        return kernel_cls
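
The replacement path makes registration explicit. A hedged sketch of the decorator in use, assuming BaseKernel exposes the ``_kernel_id`` and ``_device`` class attributes through get_kernel_id()/get_device() as the code above implies; the kernel id mirrors the one exercised in the tests further down:

from llamafactory.v1.plugins.model_plugins.kernels.base import BaseKernel
from llamafactory.v1.plugins.model_plugins.kernels.registry import register_kernel

@register_kernel
class SketchNpuRMSNormKernel(BaseKernel):
    _kernel_id = "npu_fused_rmsnorm"  # key used for Registry.get() lookups
    _device = "npu"  # registration is silently skipped on other accelerators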


class MetaFlashAttentionKernel(MetaKernel):
    @classmethod
    def apply(cls, model: HFModel, **kwargs) -> HFModel:
        raise NotImplementedError
    def get(cls, kernel_id: str) -> Optional[type[BaseKernel]]:
        r"""Retrieves a registered kernel implementation by its ID.

        Args:
            kernel_id (str): The ID of the kernel to retrieve.

        Returns:
            Optional[type[BaseKernel]]: The kernel class if found, else ``None``.
        """
        return cls._kernels.get(kernel_id)


class MetaRMSNormKernel(MetaKernel):
    @classmethod
    def apply(cls, model: HFModel, **kwargs) -> HFModel:
        raise NotImplementedError
    def get_registered_kernels(cls) -> dict[str, type[BaseKernel]]:
        r"""Returns a dictionary of all registered kernels.

        Returns:
            dict[str, type[BaseKernel]]: Dictionary mapping kernel IDs to kernel classes.
        """
        return cls._kernels


class MetaSwiGluKernel(MetaKernel):
    @classmethod
    def apply(cls, model: HFModel, **kwargs) -> HFModel:
        raise NotImplementedError


class MetaRoPEKernel(MetaKernel):
    @classmethod
    def apply(cls, model: HFModel, **kwargs) -> HFModel:
        raise NotImplementedError


class MetaMoEKernel(MetaKernel):
    @classmethod
    def apply(cls, model: HFModel, **kwargs) -> HFModel:
        raise NotImplementedError


def _ensure_kernels_loaded() -> None:
    """Ensure all kernel implementations are imported and registered.

    This function dynamically imports all kernel implementation modules to trigger
    their auto-registration. Python's module system ensures each module is only
    executed once (cached in sys.modules), so repeated calls are safe and fast.
    """
    # List of kernel module paths to import
    kernel_modules = [
        "rms_norm.npu_rms_norm",
        "rope.npu_rope",
        "mlp.npu_swiglu",
        "mlp.npu_fused_moe",
        # Add new kernel modules here as they are created
    ]

    # Import each module to trigger kernel registration
    # Python's import system caches modules, so this is fast on subsequent calls
    for module_name in kernel_modules:
        try:
            __import__(f"{__package__}.{module_name}", fromlist=["*"])
        except ImportError:
            # Silently ignore import errors (e.g., missing dependencies like torch_npu)
            pass
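
A side note on the import idiom: importlib.import_module is the idiomatic equivalent of __import__ with a fromlist, and this sketch shows the swap under that assumption:

import importlib

def _import_kernel_module(package: str, module_name: str) -> None:
    # Same effect as __import__(f"{package}.{module_name}", fromlist=["*"]):
    # the first import runs the module body, which performs registration.
    try:
        importlib.import_module(f"{package}.{module_name}")
    except ImportError:
        pass  # optional backend (e.g., torch_npu) is not installed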


def discover_kernels(model: HFModel | None = None) -> list[type[MetaKernel]]:
    """Discover and return all kernel classes registered for the current device.

    This function inspects the runtime environment (device type) and returns
    all MetaKernel classes registered for that device. Each kernel's `apply()`
    method is responsible for checking if it can actually be applied (e.g.,
    required dependencies are installed, target modules exist in the model).

    The function automatically discovers all kernels registered in KERNEL_REGISTRY
    without requiring manual enumeration. On first call, it dynamically imports
    all kernel implementation modules to trigger their auto-registration.

    Args:
        model: The HuggingFace model to apply kernels to.
            TODO: implement the kernel route detection logic by model structure.

    Returns:
        A list of MetaKernel classes available for the current device.
    """
    # Ensure all kernel modules are imported to trigger registration
    _ensure_kernels_loaded()

    discovered_kernels: list[type[MetaKernel]] = []

    # Detect current device type
    accelerator = get_current_accelerator()
    try:
        device_type = DeviceType(accelerator.type)
    except ValueError:
        # Unknown device type, return empty list
        return discovered_kernels

    # Skip CPU as it typically doesn't have optimized kernels
    if device_type == DeviceType.CPU:
        return discovered_kernels

    # Iterate through registry and collect all kernels for current device
    for devices in KERNEL_REGISTRY._registry.values():
        kernel_cls = devices.get(device_type)
        if kernel_cls is not None:
            discovered_kernels.append(kernel_cls)

    return discovered_kernels


def apply_kernel(model: HFModel, kernel: type[MetaKernel] | Any, /, **kwargs) -> "HFModel":
    """Call the MetaKernel's `apply` to perform the replacement.

    Corresponding replacement logic is maintained inside each kernel; the only
    requirement is that `apply` returns the replaced model.

    Example:
        from transformers import AutoModelForCausalLM

        from .rms_norm.npu_rms_norm import NpuRMSNormKernel

        model = AutoModelForCausalLM.from_pretrained("qwen/qwen2.5-0.5B")
        model = apply_kernel(model, NpuRMSNormKernel)
    """
    if not issubclass(kernel, MetaKernel):
        raise ValueError(f"{kernel} must be a MetaKernel subclass.")

    if kernel.device != get_current_accelerator().type:
        raise ValueError(f"{kernel} must be applied to {kernel.device} device, got {get_current_accelerator().type}.")

    return kernel.apply(model, **kwargs)


def apply_available_kernels(model: HFModel, **kwargs) -> "HFModel":
    """Apply all available kernels to the model."""
    for kernel in discover_kernels(model):
        model = apply_kernel(model, kernel, **kwargs)

    return model


# export decorator alias
register_kernel = Registry.register
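
Taken together, the new surface is just Registry plus the register_kernel alias. A lookup-side sketch; the kernel id is borrowed from the tests below, and whether it is registered at all depends on the current accelerator:

from llamafactory.v1.plugins.model_plugins.kernels.registry import Registry

kernel_cls = Registry.get("npu_fused_rmsnorm")
if kernel_cls is None:
    print("kernel unavailable on this accelerator")
else:
    print(Registry.get_registered_kernels())  # {"npu_fused_rmsnorm": <class ...>}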

@@ -1,122 +0,0 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys

import torch

from .....accelerator.helper import DeviceType, is_torch_npu_available
from .....utils.types import HFModel
from ..constants import KernelType
from ..registry import MetaRoPEKernel


def _apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors."""
    import torch_npu

    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = torch_npu.npu_rotary_mul(q, cos, sin)
    k_embed = torch_npu.npu_rotary_mul(k, cos, sin)
    return q_embed, k_embed


def _apply_multimodal_rotary_pos_emb_qwen25_vl(q, k, cos, sin, mrope_section, unsqueeze_dim=1):
    """Applies Rotary Position Embedding with multimodal sections (Qwen2-VL)."""
    import torch_npu

    mrope_section = mrope_section * 2
    cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(
        unsqueeze_dim
    )
    sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(
        unsqueeze_dim
    )

    q_embed = torch_npu.npu_rotary_mul(q, cos, sin)
    k_embed = torch_npu.npu_rotary_mul(k, cos, sin)
    return q_embed, k_embed
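
For context, a standalone sketch of what the mrope split above computes, assuming a Qwen2.5-VL-style mrope_section of [16, 24, 24] and a head_dim of 128 (cos stores two copies of the 64 rotary frequencies, hence the doubling; i % 3 picks the temporal/height/width component for each chunk):

import torch

mrope_section = [16, 24, 24]      # sums to 64; head_dim is 2 * 64 = 128
cos = torch.randn(3, 1, 10, 128)  # (components t/h/w, batch, seq, head_dim)

sections = mrope_section * 2      # [16, 24, 24, 16, 24, 24]
chunks = cos.split(sections, dim=-1)
merged = torch.cat([m[i % 3] for i, m in enumerate(chunks)], dim=-1)
print(merged.shape)               # torch.Size([1, 10, 128]); then unsqueezed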


class NpuRoPEKernel(MetaRoPEKernel):
    type = KernelType.ROPE
    device = DeviceType.NPU
    kernel = _apply_rotary_pos_emb

    @classmethod
    def apply(cls, model, **kwargs) -> "HFModel":
        """Apply RoPE acceleration by monkey-patching `apply_rotary_pos_emb`.

        This function iterates through the model's modules to find attention layers,
        identifies the module where they are defined, and replaces the original
        `apply_rotary_pos_emb` function in that module's namespace with the
        NPU-accelerated version from this file.
        """
        if not is_torch_npu_available():
            return model

        _modules = set()
        for module in model.modules():
            if "Attention" in module.__class__.__name__:
                module_name = module.__class__.__module__
                if module_name in _modules:
                    continue
                try:
                    target_module = sys.modules[module_name]
                    if hasattr(target_module, "apply_rotary_pos_emb"):
                        if getattr(target_module, "apply_rotary_pos_emb") is not cls.kernel:
                            setattr(target_module, "apply_rotary_pos_emb", cls.kernel)
                            _modules.add(module_name)
                except Exception:
                    pass
        return model


class NpuQwen2VLRoPEKernel(MetaRoPEKernel):
    """Qwen2-VL specific RoPE kernel - not auto-registered.

    This kernel is for specific models (Qwen2-VL) and should be manually
    applied when needed rather than auto-discovered.
    """

    type = KernelType.ROPE
    device = DeviceType.NPU
    kernel = _apply_multimodal_rotary_pos_emb_qwen25_vl
    auto_register = False  # Disable auto-registration for model-specific kernel

    @classmethod
    def apply(cls, model, **kwargs) -> "HFModel":
        """Apply RoPE acceleration by monkey-patching `apply_multimodal_rotary_pos_emb`.

        This function iterates through the model's modules to find attention layers,
        identifies the module where they are defined, and replaces the original
        `apply_multimodal_rotary_pos_emb` function in that module's namespace with
        the NPU-accelerated version from this file.
        """
        _modules = set()
        for module in model.modules():
            if "Attention" in module.__class__.__name__:
                module_name = module.__class__.__module__
                if module_name in _modules:
                    continue
                try:
                    target_module = sys.modules[module_name]
                    if hasattr(target_module, "apply_multimodal_rotary_pos_emb"):
                        if getattr(target_module, "apply_multimodal_rotary_pos_emb") is not cls.kernel:
                            setattr(target_module, "apply_multimodal_rotary_pos_emb", cls.kernel)
                            _modules.add(module_name)
                except Exception:
                    pass
        return model
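
Because auto_register is False, this kernel was meant to be applied by hand while the file existed. A minimal sketch under that assumption; the model id and auto class are illustrative, not taken from the diff:

from transformers import AutoModelForVision2Seq

from ..registry import apply_kernel

model = AutoModelForVision2Seq.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct")
model = apply_kernel(model, NpuQwen2VLRoPEKernel)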

@@ -18,8 +18,11 @@ Contains shared fixtures, pytest configuration, and custom markers.
"""

import os
from typing import Optional

import pytest
import torch
import torch.distributed as dist
from pytest import Config, FixtureRequest, Item, MonkeyPatch

from llamafactory.extras.misc import get_current_device, get_device_count, is_env_enabled
@@ -70,7 +73,7 @@ def _handle_slow_tests(items: list[Item]):
            item.add_marker(skip_slow)


def _get_visible_devices_env() -> str | None:
def _get_visible_devices_env() -> Optional[str]:
    """Return device visibility env var name."""
    if CURRENT_DEVICE == "cuda":
        return "CUDA_VISIBLE_DEVICES"
@@ -118,6 +121,14 @@ def pytest_collection_modifyitems(config: Config, items: list[Item]):
    _handle_device_visibility(items)


@pytest.fixture(autouse=True)
def _cleanup_distributed_state():
    """Cleanup distributed state after each test."""
    yield
    if dist.is_initialized():
        dist.destroy_process_group()


@pytest.fixture(autouse=True)
def _manage_distributed_env(request: FixtureRequest, monkeypatch: MonkeyPatch) -> None:
    """Set environment variables for distributed tests if specific devices are requested."""
@@ -145,6 +156,10 @@ def _manage_distributed_env(request: FixtureRequest, monkeypatch: MonkeyPatch) -
                monkeypatch.setenv(env_key, visible_devices[0] if visible_devices else "0")
            else:
                monkeypatch.setenv(env_key, "0")
            if CURRENT_DEVICE == "cuda":
                monkeypatch.setattr(torch.cuda, "device_count", lambda: 1)
            elif CURRENT_DEVICE == "npu":
                monkeypatch.setattr(torch.npu, "device_count", lambda: 1)


@pytest.fixture
71
tests_v1/config/test_args_parser.py
Normal file

@@ -0,0 +1,71 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pathlib
import sys
from unittest.mock import patch

from llamafactory.v1.config.arg_parser import get_args


def test_get_args_from_yaml(tmp_path: pathlib.Path):
    config_yaml = """
### model
model: "llamafactory/tiny-random-qwen2.5"
trust_remote_code: true
use_fast_processor: true
model_class: "llm"
kernel_config:
  name: "auto"
  include_kernels: "auto"  # choices: null/true/false/auto/kernel_id1,kernel_id2,kernel_id3; default is null
peft_config:
  name: "lora"
  lora_rank: 0.8
quant_config: null

### data
dataset: "llamafactory/tiny-supervised-dataset"
cutoff_len: 2048

### training
output_dir: "outputs/test_run"
micro_batch_size: 1
global_batch_size: 1
learning_rate: 1.0e-4
bf16: false
dist_config: null

### sample
sample_backend: "hf"
max_new_tokens: 128
"""

    config_file = tmp_path / "config.yaml"
    config_file.write_text(config_yaml, encoding="utf-8")

    test_argv = ["test_args_parser.py", str(config_file)]

    with patch.object(sys, "argv", test_argv):
        data_args, model_args, training_args, sample_args = get_args()

    assert training_args.output_dir == "outputs/test_run"
    assert training_args.micro_batch_size == 1
    assert training_args.global_batch_size == 1
    assert training_args.learning_rate == 1.0e-4
    assert training_args.bf16 is False
    assert training_args.dist_config is None
    assert model_args.model == "llamafactory/tiny-random-qwen2.5"
    assert model_args.kernel_config.name == "auto"
    assert model_args.kernel_config.get("include_kernels") == "auto"
    assert model_args.peft_config.name == "lora"
    assert model_args.peft_config.get("lora_rank") == 0.8
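
The assertions above mix attribute access (.name) with .get(...) lookups on the plugin configs. A purely illustrative sketch of a dict-backed config with that behavior (not the actual PluginConfig implementation):

from dataclasses import dataclass, field
from typing import Any

@dataclass
class PluginConfigSketch:
    name: str = ""
    extra: dict[str, Any] = field(default_factory=dict)

    def get(self, key: str, default: Any = None) -> Any:
        # Free-form plugin options live in a dict, so unknown keys return default.
        return self.extra.get(key, default)

cfg = PluginConfigSketch(name="lora", extra={"lora_rank": 0.8})
assert cfg.get("lora_rank") == 0.8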

@@ -18,8 +18,10 @@ Contains shared fixtures, pytest configuration, and custom markers.
"""

import os
import sys

import pytest
import torch
from pytest import Config, FixtureRequest, Item, MonkeyPatch

from llamafactory.v1.accelerator.helper import get_current_accelerator, get_device_count
@@ -139,9 +141,21 @@ def _manage_distributed_env(request: FixtureRequest, monkeypatch: MonkeyPatch) -
            devices_str = ",".join(str(i) for i in range(required))

            monkeypatch.setenv(env_key, devices_str)

            # add the project root dir to the path for multiprocessing runs
            project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
            if project_root not in sys.path:
                sys.path.insert(0, project_root)

            os.environ["PYTHONPATH"] = project_root + os.pathsep + os.environ.get("PYTHONPATH", "")

        else:  # non-distributed test
            if old_value:
                visible_devices = [v for v in old_value.split(",") if v != ""]
                monkeypatch.setenv(env_key, visible_devices[0] if visible_devices else "0")
            else:
                monkeypatch.setenv(env_key, "0")
            if CURRENT_DEVICE == "cuda":
                monkeypatch.setattr(torch.cuda, "device_count", lambda: 1)
            elif CURRENT_DEVICE == "npu":
                monkeypatch.setattr(torch.npu, "device_count", lambda: 1)
@@ -14,7 +14,7 @@

import torch

from llamafactory.v1.config.model_args import ModelArguments
from llamafactory.v1.config.model_args import ModelArguments, PluginConfig
from llamafactory.v1.core.model_loader import ModelLoader


@@ -29,5 +29,23 @@ def test_tiny_qwen():
    assert model_loader.model.dtype == torch.bfloat16


def test_tiny_qwen_with_kernel_plugin():
    from transformers import Qwen2ForCausalLM

    from llamafactory.v1.plugins.model_plugins.kernels.ops.rms_norm.npu_rms_norm import npu_rms_norm_forward

    model_args = ModelArguments(
        model="llamafactory/tiny-random-qwen2.5", kernel_config=PluginConfig(name="auto", include_kernels="auto")
    )
    model_loader = ModelLoader(model_args)
    # test that the kernel plugin is applied only when an NPU backend is present
    if hasattr(torch, "npu"):
        assert model_loader.model.model.layers[0].input_layernorm.forward.__code__ == npu_rms_norm_forward.__code__
    else:
        assert model_loader.model.model.layers[0].input_layernorm.forward.__code__ != npu_rms_norm_forward.__code__
    assert isinstance(model_loader.model, Qwen2ForCausalLM)


if __name__ == "__main__":
    test_tiny_qwen()
    test_tiny_qwen_with_kernel_plugin()
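
The __code__ comparison is how this test detects the monkey-patch: once a module's forward is rebound to npu_rms_norm_forward, the bound method shares that function's code object. A standalone sketch of the same check:

class Norm:
    def forward(self, x):
        return x

def patched_forward(self, x):
    return x * 2

norm = Norm()
norm.forward = patched_forward.__get__(norm, Norm)  # instance-level patch
assert norm.forward.__code__ is patched_forward.__code__
assert Norm.forward.__code__ is not patched_forward.__code__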

@@ -12,16 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
from unittest.mock import MagicMock, patch

import pytest
from transformers import AutoModelForCausalLM

from llamafactory.v1.accelerator.helper import get_current_accelerator
from llamafactory.v1.plugins.model_plugins.kernels.mlp import npu_swiglu
from llamafactory.v1.plugins.model_plugins.kernels.registry import apply_available_kernels, apply_kernel
from llamafactory.v1.plugins.model_plugins.kernels.rms_norm import npu_rms_norm
from llamafactory.v1.plugins.model_plugins.kernels.rope import npu_rope


@pytest.fixture(autouse=True)
@@ -29,24 +26,29 @@ def clear_accelerator_cache():
    get_current_accelerator.cache_clear()


def reload_kernels():
    """Helper to reload kernel modules to respect the mocked accelerator."""
    # Unload kernel interface and registry
    keys_to_remove = [k for k in sys.modules if k.startswith("llamafactory.v1.plugins.model_plugins.kernels")]
    for k in keys_to_remove:
        del sys.modules[k]


@patch("torch.accelerator.current_accelerator")
def test_apply_kernel(mock_get_accelerator: MagicMock):
    mock_device = MagicMock()
    setattr(mock_device, "type", "npu")
    mock_get_accelerator.return_value = mock_device
    # Force reload of kernels with mocked accelerator
    reload_kernels()
    from llamafactory.v1.plugins.model_plugins.kernels.interface import apply_default_kernels

    model = AutoModelForCausalLM.from_pretrained("llamafactory/tiny-random-qwen2.5")

    original_rmsnorm_forward = model.model.layers[0].input_layernorm.forward
    original_swiglu_forward = model.model.layers[0].mlp.forward

    apply_kernel(model, npu_rope.NpuRoPEKernel)

    model = apply_kernel(model, npu_rms_norm.NpuRMSNormKernel)
    assert model.model.layers[0].input_layernorm is not original_rmsnorm_forward

    model = apply_kernel(model, npu_swiglu.NpuSwiGluKernel)
    assert model.model.layers[0].mlp.forward is not original_swiglu_forward
    model = apply_default_kernels(model=model, include_kernels="npu_fused_rmsnorm")
    assert model.model.layers[0].input_layernorm.forward.__func__ is not original_rmsnorm_forward.__func__
    assert model.model.layers[0].mlp.forward.__func__ is original_swiglu_forward.__func__


@patch("torch.accelerator.current_accelerator")
@@ -56,12 +58,15 @@ def test_apply_all_kernels(mock_get_accelerator: MagicMock):
    setattr(mock_device, "type", "npu")
    mock_get_accelerator.return_value = mock_device

    # Force reload of kernels with mocked accelerator
    reload_kernels()
    from llamafactory.v1.plugins.model_plugins.kernels.interface import apply_default_kernels

    model = AutoModelForCausalLM.from_pretrained("llamafactory/tiny-random-qwen2.5")

    original_rmsnorm_forward = model.model.layers[0].input_layernorm.forward
    original_swiglu_forward = model.model.layers[0].mlp.forward

    model = apply_available_kernels(model)

    assert model.model.layers[0].input_layernorm is not original_rmsnorm_forward
    assert model.model.layers[0].mlp.forward is not original_swiglu_forward
    model = apply_default_kernels(model=model, include_kernels=True)
    assert model.model.layers[0].input_layernorm.forward.__func__ is not original_rmsnorm_forward.__func__
    assert model.model.layers[0].mlp.forward.__func__ is not original_swiglu_forward.__func__