56 Commits

Author SHA1 Message Date
Yaowei Zheng
7af909522a [version] release v0.9.5 (#10532) 2026-05-30 23:57:09 +08:00
xvxuopop
e016d2480e [fix] Fix NPU FusedMoE and RMSNorm (#10512) 2026-05-30 21:42:54 +08:00
jiaqiw09
7d719182c9 [model] fix non-packing batch (bsz>1) for Qwen3.5 with flash attention (#10529) 2026-05-30 21:41:41 +08:00
jiaqiw09
01398eb18d [v1] fix padding free with sp (#10513) 2026-05-26 23:49:21 +08:00
cxy
8e68764b65 [v1] Implement dynamic padding-free stretrgy for batching (#10507)
Co-authored-by: cxy-thinkbook <xuanyuchen@seu.edu.cn>
2026-05-25 20:40:21 +08:00
Copilot
16ff5a23cb [fix] use getattr for profiler attrs to support MCA TrainingArguments (#10506)
Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: hiyouga <16256802+hiyouga@users.noreply.github.com>
2026-05-21 17:26:29 +08:00
jiaqiw09
bdcb92d035 [v1] Add FlashAttention selection and implement normal / padding-free / dynamic batching (#10469) 2026-05-21 17:14:19 +08:00
sunyi0505
7e20db5735 [v1] support liger_kernel (#10493) 2026-05-21 11:44:56 +08:00
浮梦
2322bf1cc2 [v1] add cuda fused moe kernel, implementing with triton (#10481) 2026-05-20 20:49:42 +08:00
浮梦
368c48968f [callback] add torch profiler callback (#10463) 2026-05-20 20:47:52 +08:00
浮梦
8b5ea65770 [v1] support reward training stage (#10431)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2026-05-20 20:46:52 +08:00
Dennis Huang
40e786d016 [data] add missing return statement in MiniCPM V Plugin (#10500)
Co-authored-by: Cursor <cursoragent@cursor.com>
2026-05-20 01:50:00 +08:00
xvxuopop
6b9df75ab9 [docker] update npu docker (#10479) 2026-05-13 20:56:43 +08:00
马境远
ca50f22c38 [fix] Fix MiniCPM-V-4.6 image preprocessing behavior (#10478) 2026-05-12 11:35:23 +08:00
马境远
53e77a9bfa [model] support MiniCPM-V-4.6 (#10472) 2026-05-08 18:14:34 +08:00
浮梦
55bd4944b6 [fix] fix qwen3_6 template doc (#10470) 2026-05-08 11:47:02 +08:00
Tai An
7e09152275 fix(data/converter): handle None tool_calls in OpenAI-style messages (#10455) 2026-05-07 17:44:41 +08:00
simulikeit
1e503a982d [assets] correct typo in examples/README_zh.md (#10462) 2026-05-07 00:42:01 +08:00
luca-888
8752280dd7 [data] Optimize QwenVL video dataset preprocessing (#10404)
Co-authored-by: Kingsley <kingsleydodonow@gmail.com>
2026-05-03 18:36:56 +08:00
Kingsley
468723c5d9 [packing] fix GDN crash when meeting dummy image (#10453) 2026-05-01 12:10:13 +08:00
Peilin Li
887ee2b121 [refactor] Add KTransformers AMX MoE SFT support via Accelerate (#10430)
Co-authored-by: mrhaoxx <mr.haoxx@gmail.com>
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-05-01 01:47:58 +08:00
Kingsley
6b08b948c9 [misc] bump transformers version upperbound (#10446) 2026-05-01 01:30:11 +08:00
Hertz
f7f3bfcbd7 [model] support Hy3-Preview (#10432) 2026-04-29 23:21:13 +08:00
Kingsley
3475198d1e [fa2] fix IMA when train qwen3_5 (#10448) 2026-04-29 20:20:55 +08:00
sunyi0505
50945ef850 [v1] fix device_mesh and sp for fsdp2 (#10429) 2026-04-28 11:20:11 +08:00
Octopus
2f0bef207a [export] handle NotImplementedError in export_model for transformers>=5.0 (fixes #10410) (#10438)
Co-authored-by: octo-patch <octo-patch@github.com>
2026-04-27 23:36:23 +08:00
curnane-lab
2092abc217 [npu] add Qwen3.5 support with Partial RoPE and Hybrid Attention (#10421)
Co-authored-by: Curnane <mingliangfu@users.noreply.github.com>
2026-04-27 23:36:07 +08:00
Kingsley
99464b3d03 [misc] code lint (#10439)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2026-04-27 14:07:31 +08:00
jiaqiw09
9a0cfdccfa [v1] fix init on meta in transformers v5 (#10414) 2026-04-27 00:37:09 +08:00
Kingsley
c8890c32db [data] support discard history cot for multiturn (#10435) 2026-04-27 00:32:44 +08:00
Kingsley
79c8332e4c [train] add qwen35 patch for neat_packing (#10436) 2026-04-27 00:31:49 +08:00
jiaqiw09
e0bc3c1971 [v1] fix epoch and steps (#10422) 2026-04-23 17:29:06 +08:00
浮梦
ecca167eb4 [model] support qwen3.6 models (#10415)
Co-authored-by: frozenleaves <frozen@Mac.local>
2026-04-22 19:44:01 +08:00
jiaqiw09
28a6ea1cdc [v1] add deepspeed zero3 trigger for low memory usage weight loading (#10300) 2026-04-21 14:09:52 +08:00
sunyi0505
f5d739b132 [v1] fix device mesh and clip_grad_norm for ulysses cp (#10366) 2026-04-21 10:54:54 +08:00
浮梦
c4bbac49b2 [v1] support resume training from checkpoint (#10280)
Co-authored-by: frozenleaves <frozen@Mac.local>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2026-04-20 20:28:08 +08:00
Cocoon-Break
c5aecaf31d [data] fix SeedToolUtils.tool_extractor returns content when no tool calls found (#10408)
Signed-off-by: Cocoon-Break <54054995+kuishou68@users.noreply.github.com>
2026-04-20 12:22:55 +08:00
Kingsley
436d26bc28 fix: projector lookup for gemma4 modules (#10382)
Co-authored-by: yiluoAK_47 <yiluoAK_47@163.com>
2026-04-12 08:32:14 +08:00
Kingsley
c109c061e5 [model] set mm_projectors for omni models (#10378) 2026-04-10 18:12:57 +08:00
Kingsley
fa09c01c36 fix: gemma4 mm_token_type_ids padding (#10359) 2026-04-06 13:14:45 +08:00
Kingsley
eae6f0b541 [model] gemma4 (#10346) 2026-04-05 12:10:28 +08:00
Kingsley
acac63ef35 [data] fix qwen3vl timestamp (#10338) 2026-04-01 22:40:12 +08:00
浮梦
e5e8546493 [misc] fix moe (#10334)
Co-authored-by: frozenleaves <frozen@Mac.local>
2026-03-31 23:04:45 +08:00
Cui-yshoho
97433c53b6 [feat] support LlamaFactory SFT training by HyperParallel FSDP2 backend (#10289) 2026-03-30 10:47:20 +08:00
sunyi0505
b5afabe3d2 [v1] support ulysses cp for fsdp2 (#10262) 2026-03-27 16:22:48 +08:00
jiaqiw09
df2e6edb7e [v1] add init on rank0 for fsdp2 (#10264) 2026-03-27 14:54:03 +08:00
Goalina
d02fcd3588 [ci] add nginx cache config for Ascend NPU CI environment (#10323) 2026-03-27 10:04:16 +08:00
jiaqiw09
c340aa2a33 [v1] add callbacks (#10255) 2026-03-26 19:59:57 +08:00
Hertz
1e536733c6 [data] fix mimo-v2 tool call (#10315) 2026-03-26 17:37:22 +08:00
Yutong Wu
97d479fa92 [model] support Qwen3.5 liger kernel (#10313) 2026-03-24 18:25:33 +08:00
Kingsley
ffbff33af3 chore: mca workflow compatible with qwen-vl series (#10303) 2026-03-22 02:28:52 +08:00
Kingsley
833f6027b1 [fix] fit neat_packing & mrope model packing (#10283)
Co-authored-by: Yaowei Zheng <hiyouga@buaa.edu.cn>
2026-03-20 16:50:11 +08:00
robertglools
d91d8af89e [data] add SGSC zero-hallucination B2B dataset (NOO-Protocol) (#10284)
Co-authored-by: GloolsGuan <GloolsGuan@gmail.com>
2026-03-20 15:49:03 +08:00
xxddccaa
e67ab9e2f2 fix:MiniCPMVPlugin IndexError in process_messages when training with video (#10276)
Co-authored-by: xxddccaa <xxddccaa@users.noreply.github.com>
2026-03-18 19:18:06 +08:00
LincolnBurrows2017
2c4f121817 [fix] handle empty content list in system message (#10291)
Co-authored-by: AI Assistant <assistant@example.com>
2026-03-18 12:05:49 +08:00
xvxuopop
487f8b8191 [v1] add qwen3 templates and fix rendering plugin. (#10212)
Co-authored-by: Yaowei Zheng <hiyouga@buaa.edu.cn>
2026-03-18 11:30:50 +08:00
151 changed files with 7952 additions and 2620 deletions

105
.ai/CLAUDE.md Normal file
@@ -0,0 +1,105 @@
# CLAUDE.md
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
## Commands
```bash
# Code style (auto-fix)
make style
# Code quality check (no modifications)
make quality
# Run all tests
make test
# Run a single test file
WANDB_DISABLED=true pytest -vv --import-mode=importlib tests/path/to/test_file.py
# Run tests matching a pattern
WANDB_DISABLED=true pytest -vv --import-mode=importlib tests/ -k "test_name"
# License header check
make license
# Build package
make build
```
The project uses `uv` as the preferred package manager. Commands automatically use `uv run` / `uvx` if `uv` is available.
## Architecture
LlamaFactory has two parallel architectures controlled by the `USE_V1` environment variable:
- **v0 (default):** `api, webui > chat, eval, train > data, model > hparams > extras`
- **v1 (experimental, `USE_V1=1`):** `trainers > core > accelerator, plugins, config > utils`
Most active development happens in v0. The v1 architecture lives in `src/llamafactory/v1/`.
### Entry Points
CLI entry point is `llamafactory-cli` / `lmf` → `src/llamafactory/cli.py:main()`, which dispatches to `launcher.py` based on `USE_V1`.
Available subcommands: `train`, `chat`, `api`, `export`, `webchat`, `webui`, `env`, `version`, `help`.
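As a rough illustration of the `USE_V1` switch described above, the following minimal sketch picks an architecture from the environment. The function name `select_architecture` and the set of accepted truthy strings are assumptions made for this example; the source only states that `USE_V1=1` enables v1.

```python
import os


def select_architecture(env=None):
    """Return "v1" when USE_V1 is set to a truthy value, otherwise "v0".

    Hypothetical helper: the accepted truthy strings are an assumption,
    only `USE_V1=1` is documented in the text above.
    """
    env = os.environ if env is None else env
    flag = str(env.get("USE_V1", "0")).strip().lower()
    return "v1" if flag in {"1", "true", "y", "yes"} else "v0"


if __name__ == "__main__":
    # Reads the real process environment when no mapping is passed.
    print(select_architecture())
```

Passing a plain dict instead of `os.environ` keeps the helper easy to exercise in tests without mutating the process environment.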
### Training Flow (v0)
```
run_exp() [tuner.py]
→ read_args() → parse YAML/JSON config
→ get_train_args() → produces typed argument dataclasses
→ routes to: run_sft / run_dpo / run_ppo / run_rm / run_pt / run_kto
→ optional: export_model()
```
Training is invoked with a YAML config: `llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml`
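The routing step in the flow above can be pictured as a lookup from the configured stage to a runner function. This is a toy sketch, not the code in `tuner.py`: the runners below are stand-ins that only return a label, `dispatch_stage` is a hypothetical name, and the `sft` default is an assumption.

```python
from typing import Callable, Dict


def make_runner(name: str) -> Callable[[dict], str]:
    """Build a stand-in runner that reports which stage would run."""

    def runner(config: dict) -> str:
        return f"{name}({config.get('model_name_or_path', '?')})"

    return runner


# One stand-in per stage named in the flow above.
STAGE_RUNNERS: Dict[str, Callable[[dict], str]] = {
    stage: make_runner(f"run_{stage}")
    for stage in ("sft", "dpo", "ppo", "rm", "pt", "kto")
}


def dispatch_stage(config: dict) -> str:
    """Route a parsed config to its stage runner (default stage is assumed)."""
    stage = config.get("stage", "sft")
    if stage not in STAGE_RUNNERS:
        raise ValueError(f"Unknown stage: {stage!r}")
    return STAGE_RUNNERS[stage](config)


if __name__ == "__main__":
    print(dispatch_stage({"stage": "dpo", "model_name_or_path": "demo-model"}))
```

A dictionary lookup keeps the set of supported stages in one place, so an unsupported stage fails early with a clear error.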
### Configuration System
All training parameters are YAML/JSON config files. Argument parsing in `src/llamafactory/hparams/parser.py` produces four typed dataclasses:
- `ModelArguments` — model/tokenizer selection, quantization
- `DataArguments` — datasets, templates, preprocessing
- `FinetuningArguments` — LoRA rank/target, training method (sft/dpo/ppo/rm/pt/kto)
- `TrainingArguments` — extends HuggingFace's `TrainingArguments`
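To make the idea of splitting one flat config into several typed dataclasses concrete, here is a minimal standard-library sketch. It does not reproduce `parser.py`: the class names, the defaults, and the `split_config` helper are hypothetical, and only a few field names that appear in the example YAML files of this repository are used.

```python
from dataclasses import dataclass, fields


@dataclass
class ModelArgs:  # hypothetical, much smaller than the real ModelArguments
    model_name_or_path: str = ""
    trust_remote_code: bool = False


@dataclass
class DataArgs:  # hypothetical subset
    dataset: str = ""
    template: str = ""
    cutoff_len: int = 2048


@dataclass
class FinetuningArgs:  # hypothetical subset
    stage: str = "sft"
    finetuning_type: str = "lora"


def split_config(flat: dict, *classes):
    """Distribute flat key/value pairs across dataclasses by field name."""
    remaining = dict(flat)
    instances = []
    for cls in classes:
        names = {f.name for f in fields(cls)}
        kwargs = {k: remaining.pop(k) for k in list(remaining) if k in names}
        instances.append(cls(**kwargs))
    if remaining:
        # Unknown keys usually indicate a typo in the config file.
        raise ValueError(f"Unknown config keys: {sorted(remaining)}")
    return tuple(instances)


if __name__ == "__main__":
    flat = {
        "model_name_or_path": "Qwen/Qwen3.5-4B",
        "template": "qwen3_5_nothink",
        "stage": "sft",
        "finetuning_type": "full",
    }
    print(split_config(flat, ModelArgs, DataArgs, FinetuningArgs))
```

Rejecting leftover keys is a design choice of this sketch; it surfaces misspelled options instead of silently ignoring them.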
### Key Modules
| Module | Purpose |
|--------|---------|
| `src/llamafactory/model/loader.py` | Loads model + tokenizer; applies quantization, LoRA, patches |
| `src/llamafactory/model/patcher.py` | Model-specific compatibility patches |
| `src/llamafactory/data/template.py` | Prompt templates; `TEMPLATES` dict maps model family → format |
| `src/llamafactory/data/mm_plugin.py` | Multi-modal (image/video/audio) data handling |
| `src/llamafactory/data/processor/` | Per-stage data processors (supervised, pairwise, pretrain, etc.) |
| `src/llamafactory/train/sft/` | SFT trainer; other stages follow same structure |
| `src/llamafactory/chat/` | Inference engines: `hf_engine`, `vllm_engine`, `sglang_engine`, `kt_engine` |
| `src/llamafactory/extras/constants.py` | Enums and constants used across the project |
### Adding Support for a New Model
1. Add a prompt template to `src/llamafactory/data/template.py` in the `TEMPLATES` dict
2. Add any necessary model patches in `src/llamafactory/model/patcher.py`
3. Add multi-modal support in `src/llamafactory/data/mm_plugin.py` if needed
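Step 1 above amounts to adding one entry to a name-to-format registry. The sketch below shows that pattern in isolation; `ToyTemplate`, its two format strings, and this `register_template` signature are invented for illustration and are not the API of `template.py`.

```python
from dataclasses import dataclass
from typing import Dict


@dataclass(frozen=True)
class ToyTemplate:
    """Hypothetical template: one format string per speaker role."""

    user_format: str
    assistant_format: str

    def render(self, user: str, assistant: str) -> str:
        return self.user_format.format(content=user) + self.assistant_format.format(
            content=assistant
        )


# Toy registry playing the role of a TEMPLATES dict.
TEMPLATES: Dict[str, ToyTemplate] = {}


def register_template(name: str, user_format: str, assistant_format: str) -> None:
    """Add a template under a unique name; duplicates are rejected."""
    if name in TEMPLATES:
        raise ValueError(f"Template {name!r} is already registered")
    TEMPLATES[name] = ToyTemplate(user_format, assistant_format)


register_template(
    "toy_chat",
    user_format="<user>{content}</user>\n",
    assistant_format="<assistant>{content}</assistant>\n",
)

if __name__ == "__main__":
    print(TEMPLATES["toy_chat"].render("hi", "hello"))
```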
### Distributed Training
Multi-GPU automatically uses `torchrun`. Additional backends:
- **Ray:** Optional Ray cluster support
- **HyperParallel FSDP2:** `src/llamafactory/train/hyper_parallel/`
- **Megatron-core:** `src/llamafactory/train/mca/`
### Testing
- `tests/` — v0 tests; `tests_v1/` — v1 tests
- Most training tests require GPU hardware
- pytest markers: `@pytest.mark.slow`, `@pytest.mark.runs_on(['cuda'])`
- Always set `WANDB_DISABLED=true` when running tests
### Code Style
- Ruff for linting and formatting (line length 119, Google-style docstrings)
- Python 3.11+ syntax
- Double quotes for strings
- All new files must include Apache 2.0 license header (checked by `make license`)

@@ -109,7 +109,7 @@ jobs:
platforms: linux/amd64,linux/arm64
file: ./docker/docker-npu/Dockerfile
build-args: |
BASE_IMAGE=quay.io/ascend/cann:8.3.rc2-a3-ubuntu22.04-py3.11
BASE_IMAGE=quay.io/ascend/cann:9.0.0-a3-ubuntu22.04-py3.11
push: ${{ github.event_name != 'pull_request' }}
tags: |
docker.io/hiyouga/llamafactory:${{ steps.version.outputs.tag }}-npu-a3

@@ -35,15 +35,12 @@ jobs:
transformers:
- ""
include: # test backward compatibility
- python: "3.11"
os: "ubuntu-latest"
transformers: "4.51.0"
- python: "3.11"
os: "ubuntu-latest"
transformers: "4.53.0"
- python: "3.11"
os: "ubuntu-latest"
transformers: "4.55.0"
- python: "3.11"
os: "ubuntu-latest"
transformers: "4.57.1"
runs-on: ${{ matrix.os }}

@@ -38,7 +38,7 @@ jobs:
cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
container:
image: ascendai/cann:8.3.rc2-910b-ubuntu22.04-py3.11
image: ascendai/cann:9.0.0-910b-ubuntu22.04-py3.11
env:
HF_ENDPOINT: https://hf-mirror.com
HF_TOKEN: ${{ secrets.HF_TOKEN }}
@@ -49,6 +49,12 @@ jobs:
- name: Checkout
uses: actions/checkout@v6
- name: Set nginx-cache for Ascend CI
run: |
sed -Ei 's@(ports|archive).ubuntu.com@cache-service.nginx-pypi-cache.svc.cluster.local:8081@g' /etc/apt/sources.list
pip config set global.index-url http://cache-service.nginx-pypi-cache.svc.cluster.local/pypi/simple
pip config set global.trusted-host cache-service.nginx-pypi-cache.svc.cluster.local
- name: Install uv
uses: astral-sh/setup-uv@v7
with:
@@ -59,8 +65,8 @@ jobs:
- name: Install dependencies
run: |
uv venv
uv pip install -r requirements/npu.txt
uv pip install -e .
uv pip install -r requirements/npu.txt
uv pip install -r requirements/dev.txt
- name: Install node
@@ -83,5 +89,7 @@ jobs:
make build
- name: Test with pytest
shell: bash
run: |
source /usr/local/Ascend/ascend-toolkit/set_env.sh
make test

1
CLAUDE.md Symbolic link
@@ -0,0 +1 @@
.ai/CLAUDE.md

@@ -15,8 +15,6 @@
[![Open in Colab](assets/thirdparty/colab.svg)](https://colab.research.google.com/drive/1eRTPn37ltBbYsISy9Aw2NuI2Aq5CQrD9?usp=sharing)
[![Open in DSW](assets/thirdparty/dsw.svg)](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory)
[![Open in Lab4ai](assets/thirdparty/lab4ai.svg)](https://www.lab4ai.cn/course/detail?id=7c13e60f6137474eb40f6fd3983c0f46&utm_source=LLaMA-Factory)
[![Open in Online](assets/thirdparty/online.svg)](https://www.llamafactory.com.cn/?utm_source=LLaMA-Factory)
[![Open in Spaces](https://img.shields.io/badge/🤗-Open%20in%20Spaces-blue)](https://huggingface.co/spaces/hiyouga/LLaMA-Board)
[![Open in Studios](https://img.shields.io/badge/ModelScope-Open%20in%20Studios-blue)](https://modelscope.cn/studios/hiyouga/LLaMA-Board)
[![Open in Novita](https://img.shields.io/badge/Novita-Deploy%20Template-blue)](https://novita.ai/templates-library/105981?sharer=88115474-394e-4bda-968e-b88e123d0c47)
@@ -38,7 +36,7 @@
</div>
👋 Join our [WeChat](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/main.jpg), [NPU](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/npu.jpg), [Lab4AI](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/lab4ai.jpg), [LLaMA Factory Online](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/online.jpg) user group.
👋 Join our [WeChat](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/main.jpg) and [NPU](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/npu.jpg) user groups.
\[ English | [中文](README_zh.md) \]
@@ -52,14 +50,11 @@ Start local training:
Start cloud training:
- **Colab (free)**: https://colab.research.google.com/drive/1eRTPn37ltBbYsISy9Aw2NuI2Aq5CQrD9?usp=sharing
- **PAI-DSW (free trial)**: https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory
- **LLaMA Factory Online**: https://www.llamafactory.com.cn/?utm_source=LLaMA-Factory
- **Alaya NeW (cloud GPU deal)**: https://docs.alayanew.com/docs/documents/useGuide/LLaMAFactory/mutiple/?utm_source=LLaMA-Factory
Read technical notes:
- **Documentation (WIP)**: https://llamafactory.readthedocs.io/en/latest/
- **Documentation (AMD GPU)**: https://rocm.docs.amd.com/projects/ai-developer-hub/en/latest/notebooks/fine_tune/llama_factory_llama3.html
- **Official Blog**: https://blog.llamafactory.net/en/
- **Official Course**: https://www.lab4ai.cn/course/detail?id=7c13e60f6137474eb40f6fd3983c0f46&utm_source=LLaMA-Factory
> [!NOTE]
> Except for the above links, all other websites are unauthorized third-party websites. Please carefully use them.
@@ -78,7 +73,6 @@ Read technical notes:
- [Data Preparation](#data-preparation)
- [Quickstart](#quickstart)
- [Fine-Tuning with LLaMA Board GUI](#fine-tuning-with-llama-board-gui-powered-by-gradio)
- [LLaMA Factory Online](#llama-factory-online)
- [Build Docker](#build-docker)
- [Deploy with OpenAI-style API and vLLM](#deploy-with-openai-style-api-and-vllm)
- [Download from ModelScope Hub](#download-from-modelscope-hub)
@@ -117,15 +111,11 @@ Read technical notes:
- 💡 [KTransformers Fine-Tuning × LLaMA Factory: Fine-tuning 1000 Billion models with 2 4090-GPU + CPU](https://blog.llamafactory.net/en/posts/ktransformers/) (English)
- 💡 [Easy Dataset × LLaMA Factory: Enabling LLMs to Efficiently Learn Domain Knowledge](https://buaa-act.feishu.cn/wiki/GVzlwYcRFiR8OLkHbL6cQpYin7g) (English)
- [Fine-tune a mental health LLM using LLaMA-Factory](https://www.lab4ai.cn/project/detail?id=25cce32ec131497b9e06a93336a0817f&type=project&utm_source=LLaMA-Factory) (Chinese)
- [Fine-tune GPT-OSS for Role-Playing using LLaMA-Factory](https://docs.llamafactory.com.cn/docs/documents/best-practice/gptroleplay/?utm_source=LLaMA-Factory) (Chinese)
- [A One-Stop Code-Free Model Reinforcement Learning and Deployment Platform based on LLaMA-Factory and EasyR1](https://aws.amazon.com/cn/blogs/china/building-llm-model-hub-based-on-llamafactory-and-easyr1/) (Chinese)
- [How Apoidea Group enhances visual information extraction from banking documents with multimodal models using LLaMA-Factory on Amazon SageMaker HyperPod](https://aws.amazon.com/cn/blogs/machine-learning/how-apoidea-group-enhances-visual-information-extraction-from-banking-documents-with-multimodal-models-using-llama-factory-on-amazon-sagemaker-hyperpod/) (English)
<details><summary>All Blogs</summary>
- [Fine-tune Llama3.1-70B for Medical Diagnosis using LLaMA-Factory](https://docs.alayanew.com/docs/documents/bestPractice/bigModel/llama70B/?utm_source=LLaMA-Factory) (Chinese)
- [Fine-tune Qwen2.5-VL for Autonomous Driving using LLaMA-Factory](https://docs.alayanew.com/docs/documents/useGuide/LLaMAFactory/mutiple/?utm_source=LLaMA-Factory) (Chinese)
- [LLaMA Factory: Fine-tuning the DeepSeek-R1-Distill-Qwen-7B Model for News Classifier](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory_deepseek_r1_distill_7b) (Chinese)
- [A One-Stop Code-Free Model Fine-Tuning \& Deployment Platform based on SageMaker and LLaMA-Factory](https://aws.amazon.com/cn/blogs/china/a-one-stop-code-free-model-fine-tuning-deployment-platform-based-on-sagemaker-and-llama-factory/) (Chinese)
- [LLaMA Factory Multi-Modal Fine-Tuning Practice: Fine-Tuning Qwen2-VL for Personal Tourist Guide](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory_qwen2vl) (Chinese)
@@ -319,7 +309,8 @@ Read technical notes:
| [Pixtral](https://huggingface.co/mistralai) | 12B | pixtral |
| [Qwen2 (Code/Math/MoE/QwQ)](https://huggingface.co/Qwen) | 0.5B/1.5B/3B/7B/14B/32B/72B/110B | qwen |
| [Qwen3 (MoE/Instruct/Thinking/Next)](https://huggingface.co/Qwen) | 0.6B/1.7B/4B/8B/14B/32B/80B/235B | qwen3/qwen3_nothink |
| [Qwen3.5](https://huggingface.co/Qwen) | 0.8B/2B/4B/9B/27B/35B/122B/397B | qwen3_5 |
| [Qwen3.5](https://huggingface.co/Qwen) | 0.8B/2B/4B/9B/27B/35B/122B/397B | qwen3_5/qwen3_5_nothink |
| [Qwen3.6](https://huggingface.co/Qwen) | 27B/35B | qwen3_6 |
| [Qwen2-Audio](https://huggingface.co/Qwen) | 7B | qwen2_audio |
| [Qwen2.5-Omni](https://huggingface.co/Qwen) | 3B/7B | qwen2_omni |
| [Qwen3-Omni](https://huggingface.co/Qwen) | 30B | qwen3_omni |
@@ -660,10 +651,6 @@ See [examples/README.md](examples/README.md) for advanced usage (including distr
llamafactory-cli webui
```
### LLaMA Factory Online
Read our [documentation](https://docs.llamafactory.com.cn/docs/documents/quickstart/getstarted/?utm_source=LLaMA-Factory).
### Build Docker
For CUDA users:

@@ -15,8 +15,6 @@
[![Open in Colab](assets/thirdparty/colab.svg)](https://colab.research.google.com/drive/1d5KQtbemerlSDSxZIfAaWXhKr30QypiK?usp=sharing)
[![Open in DSW](assets/thirdparty/dsw.svg)](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory)
[![Open in Lab4ai](assets/thirdparty/lab4ai.svg)](https://www.lab4ai.cn/course/detail?id=7c13e60f6137474eb40f6fd3983c0f46&utm_source=LLaMA-Factory)
[![Open in Online](assets/thirdparty/online.svg)](https://www.llamafactory.com.cn/?utm_source=LLaMA-Factory)
[![Open in Spaces](https://img.shields.io/badge/🤗-Open%20in%20Spaces-blue)](https://huggingface.co/spaces/hiyouga/LLaMA-Board)
[![Open in Studios](https://img.shields.io/badge/ModelScope-Open%20in%20Studios-blue)](https://modelscope.cn/studios/hiyouga/LLaMA-Board)
[![Open in Novita](https://img.shields.io/badge/Novita-Deploy%20Template-blue)](https://novita.ai/templates-library/105981?sharer=88115474-394e-4bda-968e-b88e123d0c47)
@@ -38,7 +36,7 @@
</div>
👋 加入我们的[微信群](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/main.jpg)、[NPU 用户群](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/npu.jpg)、[大模型实验室群](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/lab4ai.jpg) 或 [LLaMA Factory Online 用户群](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/online.png)。
👋 加入我们的[微信群](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/main.jpg)、[NPU 用户群](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/npu.jpg)。
\[ [English](README.md) | 中文 \]
@@ -52,8 +50,6 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
开始云端训练:
- **Colab(免费)**:https://colab.research.google.com/drive/1d5KQtbemerlSDSxZIfAaWXhKr30QypiK?usp=sharing
- **PAI-DSW(免费试用)**:https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory
- **LLaMA Factory Online(在线微调)**:https://www.llamafactory.com.cn/?utm_source=LLaMA-Factory
- **九章智算云(算力优惠活动)**:https://docs.alayanew.com/docs/documents/useGuide/LLaMAFactory/mutiple/?utm_source=LLaMA-Factory
阅读技术文档:
- **入门教程**:https://zhuanlan.zhihu.com/p/695287607
@@ -61,7 +57,6 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
- **框架文档**:https://llamafactory.readthedocs.io/zh-cn/latest/
- **框架文档(昇腾 NPU)**:https://ascend.github.io/docs/sources/llamafactory/
- **官方博客**:https://blog.llamafactory.net/
- **官方课程**:https://www.lab4ai.cn/course/detail?id=7c13e60f6137474eb40f6fd3983c0f46&utm_source=LLaMA-Factory
> [!NOTE]
> 除上述链接以外的其他网站均为未经许可的第三方网站,请小心甄别。
@@ -80,7 +75,6 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
- [数据准备](#数据准备)
- [快速开始](#快速开始)
- [LLaMA Board 可视化微调](#llama-board-可视化微调由-gradio-驱动)
- [LLaMA Factory Online 在线微调](#llama-factory-online-在线微调)
- [构建 Docker](#构建-docker)
- [利用 vLLM 部署 OpenAI API](#利用-vllm-部署-openai-api)
- [从魔搭社区下载](#从魔搭社区下载)
@@ -119,15 +113,11 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
- 💡 [KTransformers Fine-Tuning × LLaMA Factory: 用2张4090级的GPU+CPU 微调 1000B规模的超大模型](https://swcil84qspu.feishu.cn/wiki/Z1sSwb2poijybxkyPEkcDG6enVc) (中文)
- 💡 [Easy Dataset × LLaMA Factory: 让大模型高效学习领域知识](https://buaa-act.feishu.cn/wiki/KY9xwTGs1iqHrRkjXBwcZP9WnL9)(中文)
- [使用 LLaMA-Factory 微调心理健康大模型](https://www.lab4ai.cn/project/detail?id=25cce32ec131497b9e06a93336a0817f&type=project&utm_source=LLaMA-Factory)(中文)
- [使用 LLaMA-Factory 构建 GPT-OSS 角色扮演模型](https://docs.llamafactory.com.cn/docs/documents/best-practice/gptroleplay/?utm_source=LLaMA-Factory)(中文)
- [基于 LLaMA-Factory 和 EasyR1 打造一站式无代码大模型强化学习和部署平台 LLM Model Hub](https://aws.amazon.com/cn/blogs/china/building-llm-model-hub-based-on-llamafactory-and-easyr1/)(中文)
- [通过亚马逊 SageMaker HyperPod 上的 LLaMA-Factory 增强多模态模型银行文档的视觉信息提取](https://aws.amazon.com/cn/blogs/machine-learning/how-apoidea-group-enhances-visual-information-extraction-from-banking-documents-with-multimodal-models-using-llama-factory-on-amazon-sagemaker-hyperpod/)(英文)
<details><summary>全部博客</summary>
- [使用 LLaMA-Factory 微调 Llama3.1-70B 医学诊断模型](https://docs.alayanew.com/docs/documents/bestPractice/bigModel/llama70B/?utm_source=LLaMA-Factory)(中文)
- [使用 LLaMA-Factory 微调 Qwen2.5-VL 实现自动驾驶场景微调](https://docs.alayanew.com/docs/documents/useGuide/LLaMAFactory/mutiple/?utm_source=LLaMA-Factory)(中文)
- [LLaMA Factory微调 DeepSeek-R1-Distill-Qwen-7B 模型实现新闻标题分类器](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory_deepseek_r1_distill_7b)(中文)
- [基于 Amazon SageMaker 和 LLaMA-Factory 打造一站式无代码模型微调部署平台 Model Hub](https://aws.amazon.com/cn/blogs/china/a-one-stop-code-free-model-fine-tuning-deployment-platform-based-on-sagemaker-and-llama-factory/)(中文)
- [LLaMA Factory 多模态微调实践:微调 Qwen2-VL 构建文旅大模型](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory_qwen2vl)(中文)
@@ -321,7 +311,8 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
| [Pixtral](https://huggingface.co/mistralai) | 12B | pixtral |
| [Qwen2 (Code/Math/MoE/QwQ)](https://huggingface.co/Qwen) | 0.5B/1.5B/3B/7B/14B/32B/72B/110B | qwen |
| [Qwen3 (MoE/Instruct/Thinking/Next)](https://huggingface.co/Qwen) | 0.6B/1.7B/4B/8B/14B/32B/80B/235B | qwen3/qwen3_nothink |
| [Qwen3.5](https://huggingface.co/Qwen) | 0.8B/2B/4B/9B/27B/35B/122B/397B | qwen3_5 |
| [Qwen3.5](https://huggingface.co/Qwen) | 0.8B/2B/4B/9B/27B/35B/122B/397B | qwen3_5/qwen3_5_nothink |
| [Qwen3.6](https://huggingface.co/Qwen) | 27B/35B | qwen3_6 |
| [Qwen2-Audio](https://huggingface.co/Qwen) | 7B | qwen2_audio |
| [Qwen2.5-Omni](https://huggingface.co/Qwen) | 3B/7B | qwen2_omni |
| [Qwen3-Omni](https://huggingface.co/Qwen) | 30B | qwen3_omni |
@@ -661,10 +652,6 @@ llamafactory-cli export examples/merge_lora/qwen3_lora_sft.yaml
llamafactory-cli webui
```
### LLaMA Factory Online 在线微调
详情阅读该[文档](https://docs.llamafactory.com.cn/docs/documents/quickstart/getstarted/?utm_source=LLaMA-Factory)。
### 构建 Docker
CUDA 用户:

@@ -236,6 +236,13 @@
"ms_hub_url": "AI-ModelScope/sharegpt_gpt4",
"formatting": "sharegpt"
},
"sgsc_b2b_entities": {
"hf_hub_url": "Nooxus-AI/NOO-Verified-Global-Entities",
"formatting": "sharegpt",
"columns": {
"messages": "messages"
}
},
"ultrachat_200k": {
"hf_hub_url": "HuggingFaceH4/ultrachat_200k",
"ms_hub_url": "AI-ModelScope/ultrachat_200k",
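The `columns` block in the new `sgsc_b2b_entities` entry overrides which dataset field holds the conversation. A minimal sketch of that override logic follows; the default name `conversations` for sharegpt-style data and the helper `resolve_columns` are assumptions made for this example, not code from the repository.

```python
import json

# Assumed default column name for sharegpt-style datasets (hypothetical here).
DEFAULT_SHAREGPT_COLUMNS = {"messages": "conversations"}

# The entry added in the hunk above, reproduced as a standalone JSON document.
ENTRY_JSON = """
{
  "sgsc_b2b_entities": {
    "hf_hub_url": "Nooxus-AI/NOO-Verified-Global-Entities",
    "formatting": "sharegpt",
    "columns": {"messages": "messages"}
  }
}
"""


def resolve_columns(entry: dict) -> dict:
    """Merge an entry's explicit column mapping over the assumed defaults."""
    resolved = dict(DEFAULT_SHAREGPT_COLUMNS)
    resolved.update(entry.get("columns", {}))
    return resolved


dataset_info = json.loads(ENTRY_JSON)
columns = resolve_columns(dataset_info["sgsc_b2b_entities"])

if __name__ == "__main__":
    print(columns)
```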

@@ -1,6 +1,6 @@
# https://hub.docker.com/r/ascendai/cann/tags
ARG BASE_IMAGE=quay.io/ascend/cann:8.5.1-910b-ubuntu22.04-py3.11
ARG BASE_IMAGE=quay.io/ascend/cann:9.0.0-910b-ubuntu22.04-py3.11
FROM ${BASE_IMAGE}
# Installation arguments

@@ -33,7 +33,7 @@ services:
dockerfile: ./docker/docker-npu/Dockerfile
context: ../..
args:
BASE_IMAGE: quay.io/ascend/cann:8.5.1-a3-ubuntu22.04-py3.11
BASE_IMAGE: quay.io/ascend/cann:9.0.0-a3-ubuntu22.04-py3.11
PIP_INDEX: https://pypi.org/simple
container_name: llamafactory-a3
image: llamafactory:npu-a3

@@ -96,7 +96,7 @@ FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500
### 支持弹性和容错的多机指令监督微调
要启动一个支持弹性节点和容错的多机指令微调,在每个节点上执行以下命令。弹性节点数量范围为 `MIN_NNODES:MAX_NNODES`,每个节点最多允许因为错误重启 `MAX_RESTARTS` 次。`RDZV_ID` 应设置为一个唯一的作业 ID,由参与该作业的所有节点共享。更多可以参考官方文档 [torchrun](https://docs.pytorch.org/docs/stable/elastic/run.html)。
要启动一个支持弹性节点和容错的多机指令微调,在每个节点上执行以下命令。弹性节点数量范围为 `MIN_NNODES:MAX_NNODES`,每个节点最多允许因为错误重启 `MAX_RESTARTS` 次。`RDZV_ID` 应设置为一个唯一的作业 ID,由参与该作业的所有节点共享。更多细节可以参考官方文档 [torchrun](https://docs.pytorch.org/docs/stable/elastic/run.html)。
```bash
FORCE_TORCHRUN=1 MIN_NNODES=1 MAX_NNODES=3 MAX_RESTARTS=3 RDZV_ID=llamafactory MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/qwen3_full_sft.yaml

@@ -0,0 +1,20 @@
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
fsdp_config:
fsdp_version: 2
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_transformer_layer_cls_to_wrap: Qwen3_5DecoderLayer,Qwen3_5VisionBlock
fsdp_cpu_ram_efficient_loading: true
fsdp_offload_params: false
fsdp_reshard_after_forward: true
fsdp_state_dict_type: FULL_STATE_DICT
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 8 # Change to match your NPU count (e.g., 8 for A2, 16 for A3)
rdzv_backend: static
same_network: true
use_cpu: false

@@ -0,0 +1,47 @@
# Start FSDP2 full fine-tuning on Ascend NPU
# Usage:
# accelerate launch \
# --config_file examples/accelerate/fsdp2_config_qwen35.yaml \
# src/train.py examples/ascend/qwen3_5_full_sft_fsdp2.yaml
#
# Note: Change `num_processes` in fsdp2_config_qwen35.yaml to match your NPU count
### model
model_name_or_path: Qwen/Qwen3.5-4B
trust_remote_code: true
use_v1_kernels: true
flash_attn: fa2
### method
stage: sft
do_train: true
finetuning_type: full
### dataset
dataset: alpaca_en_demo
template: qwen3_5_nothink
cutoff_len: 2048
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16
dataloader_num_workers: 4
### output
output_dir: saves/Qwen3.5-4B/full/sft
logging_steps: 1
save_steps: 500
max_steps: 500
plot_loss: true
overwrite_output_dir: true
save_only_model: false
report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
### train
per_device_train_batch_size: 8
gradient_accumulation_steps: 1
learning_rate: 1.0e-5
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
ddp_timeout: 1800
resume_from_checkpoint: null
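For orientation, the batch and warmup settings in this config, combined with `num_processes: 8` from the accelerate config above, work out to 64 sequences per optimizer step and 50 warmup steps. The sketch below shows the arithmetic; the formulas are the usual data-parallel conventions and the rounding with `math.ceil` is an assumption, not something confirmed by this diff.

```python
import math

# Values copied from the two YAML files above.
PER_DEVICE_TRAIN_BATCH_SIZE = 8
GRADIENT_ACCUMULATION_STEPS = 1
NUM_PROCESSES = 8
MAX_STEPS = 500
WARMUP_RATIO = 0.1


def effective_batch_size(per_device: int, grad_accum: int, num_processes: int) -> int:
    """Sequences consumed per optimizer step under plain data parallelism."""
    return per_device * grad_accum * num_processes


def warmup_steps(max_steps: int, warmup_ratio: float) -> int:
    """Number of warmup steps, rounded up (rounding rule is an assumption)."""
    return math.ceil(max_steps * warmup_ratio)


if __name__ == "__main__":
    print(
        effective_batch_size(
            PER_DEVICE_TRAIN_BATCH_SIZE, GRADIENT_ACCUMULATION_STEPS, NUM_PROCESSES
        )
    )
    print(warmup_steps(MAX_STEPS, WARMUP_RATIO))
```

If `num_processes` is changed to 16 for an A3 machine, the same arithmetic gives 128 sequences per step, so the learning rate may need to be revisited.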

@@ -0,0 +1,25 @@
compute_environment: LOCAL_MACHINE
distributed_type: FSDP
fsdp_config:
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_cpu_ram_efficient_loading: true
fsdp_offload_params: false
fsdp_reshard_after_forward: true
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_version: 2
mixed_precision: bf16
num_machines: 1
num_processes: 4 # Adjust based on your GPU count; 4 is suitable for 4 GPUs
rdzv_backend: static
same_network: true
use_cpu: false
kt_config:
enabled: true
kt_backend: AMXBF16 # Use with original BF16 expert weights.
kt_num_threads: 96
kt_tp_enabled: true
kt_threadpool_count: 2
kt_max_cache_depth: 2
kt_share_backward_bb: true
lora_rank: 8

@@ -0,0 +1,25 @@
compute_environment: LOCAL_MACHINE
distributed_type: FSDP
fsdp_config:
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_cpu_ram_efficient_loading: true
fsdp_offload_params: false
fsdp_reshard_after_forward: true
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_version: 2
mixed_precision: bf16
num_machines: 1
num_processes: 4 # Adjust based on your GPU count; 4 is suitable for 4 GPUs
rdzv_backend: static
same_network: true
use_cpu: false
kt_config:
enabled: true
kt_backend: AMXINT4 # Use with online-converted INT4 expert weights
kt_num_threads: 96
kt_tp_enabled: true
kt_threadpool_count: 2
kt_max_cache_depth: 2
kt_share_backward_bb: true
lora_rank: 8

@@ -0,0 +1,25 @@
compute_environment: LOCAL_MACHINE
distributed_type: FSDP
fsdp_config:
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_cpu_ram_efficient_loading: true
fsdp_offload_params: false
fsdp_reshard_after_forward: true
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_version: 2
mixed_precision: bf16
num_machines: 1
num_processes: 4 # Adjust based on your GPU count; 4 is suitable for 4 GPUs
rdzv_backend: static
same_network: true
use_cpu: false
kt_config:
enabled: true
kt_backend: AMXINT8 # Use with online-converted INT8 expert weights
kt_num_threads: 96
kt_tp_enabled: true
kt_threadpool_count: 2
kt_max_cache_depth: 2
kt_share_backward_bb: true
lora_rank: 8

@@ -0,0 +1,25 @@
compute_environment: LOCAL_MACHINE
distributed_type: FSDP
fsdp_config:
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_cpu_ram_efficient_loading: true
fsdp_offload_params: false
fsdp_reshard_after_forward: true
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_version: 2
mixed_precision: bf16
num_machines: 1
num_processes: 1 # Adjust based on your GPU count; 1 is suitable for 1 GPU
rdzv_backend: static
same_network: true
use_cpu: false
kt_config:
enabled: true
kt_backend: AMXINT8 # Use with online-converted INT8 expert weights
kt_num_threads: 96
kt_tp_enabled: true
kt_threadpool_count: 2
kt_max_cache_depth: 2
kt_share_backward_bb: true
lora_rank: 8

@@ -0,0 +1,25 @@
compute_environment: LOCAL_MACHINE
distributed_type: FSDP
fsdp_config:
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_cpu_ram_efficient_loading: true
fsdp_offload_params: false
fsdp_reshard_after_forward: true
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_version: 2
mixed_precision: bf16
num_machines: 1
num_processes: 8 # Adjust based on your GPU count; 8 is suitable for 8 GPUs
rdzv_backend: static
same_network: true
use_cpu: false
kt_config:
enabled: true
kt_backend: AMXINT8 # Use with online-converted INT8 expert weights
kt_num_threads: 96
kt_tp_enabled: true
kt_threadpool_count: 2
kt_max_cache_depth: 2
kt_share_backward_bb: true
lora_rank: 8

@@ -1,10 +0,0 @@
model_name_or_path: deepseek-ai/DeepSeek-V2-Lite
adapter_name_or_path: saves/Kllama_deepseekV2
template: deepseek
infer_backend: ktransformers # choices: [huggingface, vllm, sglang, ktransformers]
trust_remote_code: true
use_kt: true # use KTransformers as the LoRA SFT backend for inference
kt_optimize_rule: examples/ktransformers/kt_optimize_rules/DeepSeek-V2-Lite-Chat-sft-amx.yaml
cpu_infer: 32
chunk_size: 8192

@@ -1,9 +0,0 @@
model_name_or_path: opensourcerelease/DeepSeek-V3-bf16
template: deepseek3
infer_backend: ktransformers # choices: [huggingface, vllm, sglang, ktransformers]
trust_remote_code: true
use_kt: true # use KTransformers as the LoRA SFT backend for inference
kt_optimize_rule: examples/ktransformers/kt_optimize_rules/DeepSeek-V3-Chat-sft-amx-multi-gpu.yaml
cpu_infer: 32
chunk_size: 8192

@@ -1,10 +0,0 @@
model_name_or_path: opensourcerelease/DeepSeek-V3-bf16
adapter_name_or_path: saves/Kllama_deepseekV3
template: deepseek3
infer_backend: ktransformers # choices: [huggingface, vllm, sglang, ktransformers]
trust_remote_code: true
use_kt: true # use KTransformers as the LoRA SFT backend for inference
kt_optimize_rule: examples/ktransformers/kt_optimize_rules/DeepSeek-V3-Chat-sft-amx-multi-gpu.yaml
cpu_infer: 32
chunk_size: 8192

@@ -1,10 +0,0 @@
model_name_or_path: Qwen/Qwen3-235B-A22B-Instruct-2507
adapter_name_or_path: saves/Kllama_Qwen3MoE_235bA22b
template: qwen3_nothink
infer_backend: ktransformers # choices: [huggingface, vllm, sglang, ktransformers]
trust_remote_code: true
use_kt: true # use KTransformers as the LoRA SFT backend for inference
kt_optimize_rule: examples/ktransformers/kt_optimize_rules/Qwen3Moe-sft-amx.yaml
cpu_infer: 32
chunk_size: 8192

@@ -1,69 +0,0 @@
- match:
class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding
replace:
class: ktransformers.operators.RoPE.YarnRotaryEmbedding
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- match:
name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression
class: torch.nn.Linear # only match modules matching name and class simultaneously
replace:
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
generate_op: "KLinearTorch"
prefill_op: "KLinearTorch"
- match:
name: "^lm_head"
class: torch.nn.Linear
replace:
class: ktransformers.operators.linear.KTransformersLinear
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
generate_op: "KLinearTorch"
prefill_op: "KLinearTorch"
- match:
name: "^model\\.layers\\..*\\.mlp$"
class: ktransformers.models.modeling_deepseek.DeepseekV2MoE
replace:
class: ktransformers.operators.experts.KDeepseekV2MoE # mlp module with custom forward function
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- match:
name: "^model\\.layers\\..*\\.mlp\\.experts$"
replace:
class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert parallelism
kwargs:
prefill_device: "cuda"
prefill_op: "KExpertsTorch"
generate_device: "cpu"
generate_op: "KSFTExpertsCPU"
out_device: "cuda"
backend: "AMXInt8" # or "AMXBF16" or "llamafile" (default)
recursive: False # don't recursively inject submodules of this module
- match:
name: "^model\\.layers\\..*\\.self_attn$"
replace:
class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- match:
name: "^model$"
replace:
class: "ktransformers.operators.models.KDeepseekV2Model"
kwargs:
per_layer_prefill_intput_threshold: 0 # 0 disables layer-wise prefill
- match:
name: "^model.embed_tokens"
replace:
class: "default"
kwargs:
generate_device: "cpu"
prefill_device: "cpu"
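
The `name` fields in the rules above are regular expressions, and the negative lookahead is what keeps `self_attn.kv_b_proj` out of the generic `torch.nn.Linear` replacement. A minimal Python sketch of that behaviour follows. The module names are hypothetical examples, and applying the pattern with `re.search` is an assumption made for illustration, not a statement of how KTransformers matches rules internally. The YAML double-quoted `\\.` decodes to the regex `\.`, which is why the raw string below uses a single backslash.

```python
import re

# Pattern as the regex engine sees it after YAML decoding ("\\." becomes "\.").
LINEAR_RULE = re.compile(r"^model\.layers\.(?!.*self_attn\.kv_b_proj).*$")

# Hypothetical module names, chosen only for illustration.
names = [
    "model.layers.0.self_attn.q_proj",
    "model.layers.0.self_attn.kv_b_proj",
    "model.layers.12.mlp.shared_experts.gate_proj",
    "lm_head",
]

# Keep the names the Linear rule would select by name alone.
matched = [n for n in names if LINEAR_RULE.search(n)]
print(matched)
```

Only the two layer modules other than `kv_b_proj` pass the filter; `lm_head` falls outside the `^model\.layers\.` prefix and is covered by its own rule.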

@@ -1,68 +0,0 @@
- match:
class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding
replace:
class: ktransformers.operators.RoPE.YarnRotaryEmbedding
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- match:
name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression
class: torch.nn.Linear # only match modules matching name and class simultaneously
replace:
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
generate_op: "KLinearMarlin"
prefill_op: "KLinearTorch"
- match:
name: "^lm_head"
class: torch.nn.Linear
replace:
class: ktransformers.operators.linear.KTransformersLinear
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
generate_op: "KLinearMarlin"
prefill_op: "KLinearTorch"
- match:
name: "^model\\.layers\\..*\\.mlp$"
class: ktransformers.models.modeling_deepseek.DeepseekV2MoE
replace:
class: ktransformers.operators.experts.KDeepseekV2MoE # mlp module with custom forward function
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- match:
name: "^model\\.layers\\..*\\.mlp\\.experts$"
replace:
class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert parallelism
kwargs:
prefill_device: "cuda"
prefill_op: "KExpertsTorch"
generate_device: "cpu"
generate_op: "KExpertsCPU"
out_device: "cuda"
recursive: False # don't recursively inject submodules of this module
- match:
name: "^model\\.layers\\..*\\.self_attn$"
replace:
class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- match:
name: "^model$"
replace:
class: "ktransformers.operators.models.KDeepseekV2Model"
kwargs:
per_layer_prefill_intput_threshold: 0 # 0 disables layer-wise prefill
- match:
name: "^model.embed_tokens"
replace:
class: "default"
kwargs:
generate_device: "cpu"
prefill_device: "cpu"

@@ -1,139 +0,0 @@
- match:
name: "^model.embed_tokens"
replace:
class: "default"
kwargs:
generate_device: "cpu"
prefill_device: "cpu"
- match:
name: "^model\\.layers\\.(0|[1-9])\\."
class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding
replace:
class: ktransformers.operators.RoPE.YarnRotaryEmbedding
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
- match:
name: "^model\\.layers\\.([12][0-9])\\."
class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding
replace:
class: ktransformers.operators.RoPE.YarnRotaryEmbedding
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
- match:
name: "^model\\.layers\\.(0|[1-9])\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression
class: torch.nn.Linear # only match modules matching name and class simultaneously
replace:
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
generate_op: "KLinearTorch"
prefill_op: "KLinearTorch"
- match:
name: "^model\\.layers\\.([12][0-9])\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression
class: torch.nn.Linear # only match modules matching name and class simultaneously
replace:
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
generate_op: "KLinearTorch"
prefill_op: "KLinearTorch"
- match:
name: "^model\\.layers\\.(0|[1-9])\\.mlp$"
class: ktransformers.models.modeling_deepseek.DeepseekV2MoE
replace:
class: ktransformers.operators.experts.KDeepseekV2MoE # mlp module with custom forward function
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
- match:
name: "^model\\.layers\\.([12][0-9])\\.mlp$"
class: ktransformers.models.modeling_deepseek.DeepseekV2MoE
replace:
class: ktransformers.operators.experts.KDeepseekV2MoE # mlp module with custom forward function
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
- match:
name: "^model\\.layers\\.(0|[1-9])\\.mlp\\.experts$"
replace:
class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert parallelism
kwargs:
prefill_device: "cuda:0"
prefill_op: "KExpertsTorch"
generate_device: "cpu"
generate_op: "KSFTExpertsCPU"
out_device: "cuda:0"
backend: "AMXInt8" # or "AMXBF16" or "llamafile" (default)
recursive: False # don't recursively inject submodules of this module
- match:
name: "^model\\.layers\\.([12][0-9])\\.mlp\\.experts$"
replace:
class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert parallelism
kwargs:
prefill_device: "cuda:1"
prefill_op: "KExpertsTorch"
generate_device: "cpu"
generate_op: "KSFTExpertsCPU"
out_device: "cuda:1"
backend: "AMXInt8" # or "AMXBF16" or "llamafile" (default)
recursive: False # don't recursively inject submodules of this module
- match:
name: "^model\\.layers\\.(0|[1-9])\\.self_attn$"
replace:
class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
- match:
name: "^model\\.layers\\.([12][0-9])\\.self_attn$"
replace:
class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
- match:
name: "^model$"
replace:
class: "ktransformers.operators.models.KDeepseekV2Model"
kwargs:
per_layer_prefill_intput_threshold: 0 # 0 disables layer-wise prefill
transfer_map:
10: "cuda:1"
- match:
name: "^model\\.layers\\.(0|[1-9])\\."
replace:
class: "default"
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
- match:
name: "^lm_head"
class: torch.nn.Linear
replace:
class: ktransformers.operators.linear.KTransformersLinear
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
generate_op: "KLinearTorch"
prefill_op: "KLinearTorch"
- match:
name: "(^model\\.layers\\.([12][0-9])\\.)|(model.norm)"
replace:
class: "default"
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"

@@ -1,69 +0,0 @@
- match:
class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding
replace:
class: ktransformers.operators.RoPE.YarnRotaryEmbedding
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- match:
name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression
class: torch.nn.Linear # only match modules matching name and class simultaneously
replace:
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
generate_op: "KLinearTorch"
prefill_op: "KLinearTorch"
- match:
name: "^lm_head"
class: torch.nn.Linear
replace:
class: ktransformers.operators.linear.KTransformersLinear
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
generate_op: "KLinearTorch"
prefill_op: "KLinearTorch"
- match:
name: "^model\\.layers\\..*\\.mlp$"
class: ktransformers.models.modeling_deepseek.DeepseekV2MoE
replace:
class: ktransformers.operators.experts.KDeepseekV2MoE # mlp module with custom forward function
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- match:
name: "^model\\.layers\\..*\\.mlp\\.experts$"
replace:
class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert parallelism
kwargs:
prefill_device: "cpu"
prefill_op: "KExpertsTorch"
generate_device: "cpu"
generate_op: "KSFTExpertsCPU"
out_device: "cuda"
backend: "AMXInt8" # or "AMXBF16" or "llamafile" (default)
recursive: False # don't recursively inject submodules of this module
- match:
name: "^model\\.layers\\..*\\.self_attn$"
replace:
class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- match:
name: "^model$"
replace:
class: "ktransformers.operators.models.KDeepseekV2Model"
kwargs:
per_layer_prefill_intput_threshold: 0 # 0 disables layer-wise prefill
- match:
name: "^model.embed_tokens"
replace:
class: "default"
kwargs:
generate_device: "cpu"
prefill_device: "cpu"

@@ -1,68 +0,0 @@
- match:
class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding
replace:
class: ktransformers.operators.RoPE.YarnRotaryEmbedding
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- match:
name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression
class: torch.nn.Linear # only match modules matching name and class simultaneously
replace:
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
generate_op: "KLinearTorch"
prefill_op: "KLinearTorch"
- match:
name: "^lm_head"
class: torch.nn.Linear
replace:
class: ktransformers.operators.linear.KTransformersLinear
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
generate_op: "KLinearTorch"
prefill_op: "KLinearTorch"
- match:
name: "^model\\.layers\\..*\\.mlp$"
class: ktransformers.models.modeling_deepseek.DeepseekV2MoE
replace:
class: ktransformers.operators.experts.KDeepseekV2MoE # mlp module with custom forward function
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- match:
name: "^model\\.layers\\..*\\.mlp\\.experts$"
replace:
class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert parallelism
kwargs:
prefill_device: "cpu"
prefill_op: "KExpertsTorch"
generate_device: "cpu"
generate_op: "KSFTExpertsCPU"
out_device: "cuda"
recursive: False # don't recursively inject submodules of this module
- match:
name: "^model\\.layers\\..*\\.self_attn$"
replace:
class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- match:
name: "^model$"
replace:
class: "ktransformers.operators.models.KDeepseekV2Model"
kwargs:
per_layer_prefill_intput_threshold: 0 # 0 disables layer-wise prefill
- match:
name: "^model.embed_tokens"
replace:
class: "default"
kwargs:
generate_device: "cpu"
prefill_device: "cpu"

@@ -1,68 +0,0 @@
- match:
class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding
replace:
class: ktransformers.operators.RoPE.YarnRotaryEmbedding
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- match:
name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression
class: torch.nn.Linear # only match modules matching name and class simultaneously
replace:
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
generate_op: "KLinearMarlin"
prefill_op: "KLinearTorch"
- match:
name: "^lm_head"
class: torch.nn.Linear
replace:
class: ktransformers.operators.linear.KTransformersLinear
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
generate_op: "KLinearMarlin"
prefill_op: "KLinearTorch"
- match:
name: "^model\\.layers\\..*\\.mlp$"
class: ktransformers.models.modeling_deepseek.DeepseekV2MoE
replace:
class: ktransformers.operators.experts.KDeepseekV2MoE # mlp module with custom forward function
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- match:
name: "^model\\.layers\\..*\\.mlp\\.experts$"
replace:
class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert parallelism
kwargs:
prefill_device: "cuda"
prefill_op: "KExpertsTorch"
generate_device: "cpu"
generate_op: "KExpertsCPU"
out_device: "cuda"
recursive: False # don't recursively inject submodules of this module
- match:
name: "^model\\.layers\\..*\\.self_attn$"
replace:
class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- match:
name: "^model$"
replace:
class: "ktransformers.operators.models.KDeepseekV2Model"
kwargs:
per_layer_prefill_intput_threshold: 0 # 0 disables layer-wise prefill
- match:
name: "^model.embed_tokens"
replace:
class: "default"
kwargs:
generate_device: "cpu"
prefill_device: "cpu"

@@ -1,77 +0,0 @@
- match:
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
replace:
class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- match:
name: "^lm_head$" # regular expression
class: torch.nn.Linear # only match modules matching name and class simultaneously
replace:
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
generate_op: "KLinearMarlin"
prefill_op: "KLinearTorch"
- match:
name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression
class: torch.nn.Linear # only match modules matching name and class simultaneously
replace:
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
generate_op: "KLinearMarlin"
prefill_op: "KLinearTorch"
- match:
name: "^model\\.layers\\..*\\.mlp$"
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
replace:
class: ktransformers.operators.experts.KDeepseekV3MoE # mlp module with custom forward function
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- match:
class: ktransformers.models.modeling_deepseek_v3.MoEGate
replace:
class: ktransformers.operators.gate.KMoEGate
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
- match:
name: "^model\\.layers\\..*\\.mlp\\.experts$"
replace:
class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert parallelism
kwargs:
prefill_device: "cuda"
prefill_op: "KExpertsTorch"
generate_device: "cpu"
generate_op: "KExpertsCPU"
out_device: "cuda"
backend: "AMXInt8" # or "AMXBF16" or "llamafile" (default)
recursive: False # don't recursively inject submodules of this module
- match:
name: "^model\\.layers\\..*\\.self_attn$"
replace:
class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
absorb_for_prefill: False # set to True to enable long context (prefill may be slower)
- match:
name: "^model$"
replace:
class: "ktransformers.operators.models.KDeepseekV2Model"
kwargs:
per_layer_prefill_intput_threshold: 0 # 0 disables layer-wise prefill
- match:
name: "^model.embed_tokens"
replace:
class: "default"
kwargs:
generate_device: "cpu"
prefill_device: "cpu"

@@ -1,392 +0,0 @@
- match:
name: "^model.embed_tokens"
replace:
class: "default"
kwargs:
generate_device: "cpu"
prefill_device: "cpu"
# === Rotary Embedding Replacement ===
# GPU 0: layers 0-14
- match:
name: "^model\\.layers\\.([0-9]|1[0-4])\\."
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
replace:
class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
# GPU 1: layers 15-29
- match:
name: "^model\\.layers\\.(1[5-9]|2[0-9])\\."
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
replace:
class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
# GPU 2: layers 30-44
- match:
name: "^model\\.layers\\.(3[0-9]|4[0-4])\\."
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
replace:
class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
kwargs:
generate_device: "cuda:2"
prefill_device: "cuda:2"
# GPU 3: layers 45-60
- match:
name: "^model\\.layers\\.(4[5-9]|5[0-9]|60)\\."
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
replace:
class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
kwargs:
generate_device: "cuda:3"
prefill_device: "cuda:3"
# === Linear Layers Replacement (excluding self_attn.kv_b_proj) ===
# GPU 0: layers 0-14
- match:
name: "^model\\.layers\\.([0-9]|1[0-4])\\.(?!self_attn\\.kv_b_proj).*$"
class: torch.nn.Linear
replace:
class: ktransformers.operators.linear.KTransformersLinear
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
generate_op: "KLinearTorch"
prefill_op: "KLinearTorch"
# GPU 1: layers 15-29
- match:
name: "^model\\.layers\\.(1[5-9]|2[0-9])\\.(?!self_attn\\.kv_b_proj).*$"
class: torch.nn.Linear
replace:
class: ktransformers.operators.linear.KTransformersLinear
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
generate_op: "KLinearTorch"
prefill_op: "KLinearTorch"
# GPU 2: layers 30-44
- match:
name: "^model\\.layers\\.(3[0-9]|4[0-4])\\.(?!self_attn\\.kv_b_proj).*$"
class: torch.nn.Linear
replace:
class: ktransformers.operators.linear.KTransformersLinear
kwargs:
generate_device: "cuda:2"
prefill_device: "cuda:2"
generate_op: "KLinearTorch"
prefill_op: "KLinearTorch"
# GPU 3: layers 45-60
- match:
name: "^model\\.layers\\.(4[5-9]|5[0-9]|60)\\.(?!self_attn\\.kv_b_proj).*$"
class: torch.nn.Linear
replace:
class: ktransformers.operators.linear.KTransformersLinear
kwargs:
generate_device: "cuda:3"
prefill_device: "cuda:3"
generate_op: "KLinearTorch"
prefill_op: "KLinearTorch"
# === MLP (MoE) Replacement ===
# GPU 0: layers 0-14
- match:
name: "^model\\.layers\\.([0-9]|1[0-4])\\.mlp$"
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
replace:
class: ktransformers.operators.experts.KDeepseekV3MoE
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
# GPU 1: layers 15-29
- match:
name: "^model\\.layers\\.(1[5-9]|2[0-9])\\.mlp$"
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
replace:
class: ktransformers.operators.experts.KDeepseekV3MoE
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
# GPU 2: layers 30-44
- match:
name: "^model\\.layers\\.(3[0-9]|4[0-4])\\.mlp$"
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
replace:
class: ktransformers.operators.experts.KDeepseekV3MoE
kwargs:
generate_device: "cuda:2"
prefill_device: "cuda:2"
# GPU 3: layers 45-60
- match:
name: "^model\\.layers\\.(4[5-9]|5[0-9]|60)\\.mlp$"
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
replace:
class: ktransformers.operators.experts.KDeepseekV3MoE
kwargs:
generate_device: "cuda:3"
prefill_device: "cuda:3"
# === MLP Gate Replacement ===
# GPU 0: layers 0-14
- match:
name: "^model\\.layers\\.([0-9]|1[0-4])\\.mlp\\.gate$"
class: ktransformers.models.modeling_deepseek_v3.MoEGate
replace:
class: ktransformers.operators.gate.KMoEGate
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
# GPU 1: layers 15-29
- match:
name: "^model\\.layers\\.(1[5-9]|2[0-9])\\.mlp\\.gate$"
class: ktransformers.models.modeling_deepseek_v3.MoEGate
replace:
class: ktransformers.operators.gate.KMoEGate
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
# GPU 2: layers 30-44
- match:
name: "^model\\.layers\\.(3[0-9]|4[0-4])\\.mlp\\.gate$"
class: ktransformers.models.modeling_deepseek_v3.MoEGate
replace:
class: ktransformers.operators.gate.KMoEGate
kwargs:
generate_device: "cuda:2"
prefill_device: "cuda:2"
# GPU 3: layers 45-60
- match:
name: "^model\\.layers\\.(4[5-9]|5[0-9]|60)\\.mlp\\.gate$"
class: ktransformers.models.modeling_deepseek_v3.MoEGate
replace:
class: ktransformers.operators.gate.KMoEGate
kwargs:
generate_device: "cuda:3"
prefill_device: "cuda:3"
# === MLP Experts Replacement ===
# Replace with Marlin experts. Uncomment and adjust the layer numbers as needed.
# Each layer of Marlin experts takes about 6 GB of GPU memory.
# !!! Remember to disable CUDA graph if you are using Marlin experts. !!!
# !!! KExpertsTorch is untested; we don't have enough VRAM. !!!
# GPU 0: layers 3-4
# - match:
# name: "^model\\.layers\\.([3-4])\\.mlp\\.experts$"
# replace:
# class: ktransformers.operators.experts.KTransformersExperts
# kwargs:
# generate_device: "cuda:0"
# generate_op: "KExpertsMarlin"
# recursive: False
# # GPU 1: layers 15-17
# - match:
# name: "^model\\.layers\\.(1[5-7])\\.mlp\\.experts$"
# replace:
# class: ktransformers.operators.experts.KTransformersExperts
# kwargs:
# generate_device: "cuda:1"
# generate_op: "KExpertsMarlin"
# recursive: False
# # GPU 2: layers 30-32
# - match:
# name: "^model\\.layers\\.(3[0-2])\\.mlp\\.experts$"
# replace:
# class: ktransformers.operators.experts.KTransformersExperts
# kwargs:
# generate_device: "cuda:2"
# generate_op: "KExpertsMarlin"
# recursive: False
# # GPU 3: layers 45-46
# - match:
# name: "^model\\.layers\\.(4[5-6])\\.mlp\\.experts$"
# replace:
# class: ktransformers.operators.experts.KTransformersExperts
# kwargs:
# generate_device: "cuda:3"
# generate_op: "KExpertsMarlin"
# recursive: False
# === MLP Experts Replacement ===
# GPU 0: layers 0-14
- match:
name: "^model\\.layers\\.([0-9]|1[0-4])\\.mlp\\.experts$"
replace:
class: ktransformers.operators.experts.KTransformersExperts
kwargs:
prefill_device: "cuda:0"
prefill_op: "KExpertsTorch"
generate_device: "cpu"
generate_op: "KSFTExpertsCPU"
out_device: "cuda:0"
backend: "AMXInt8" # or "AMXBF16" or "llamafile" (default)
recursive: False
# GPU 1: layers 15-29
- match:
name: "^model\\.layers\\.(1[5-9]|2[0-9])\\.mlp\\.experts$"
replace:
class: ktransformers.operators.experts.KTransformersExperts
kwargs:
prefill_device: "cuda:1"
prefill_op: "KExpertsTorch"
generate_device: "cpu"
generate_op: "KSFTExpertsCPU"
out_device: "cuda:1"
backend: "AMXInt8" # or "AMXBF16" or "llamafile" (default)
recursive: False
# GPU 2: layers 30-44
- match:
name: "^model\\.layers\\.(3[0-9]|4[0-4])\\.mlp\\.experts$"
replace:
class: ktransformers.operators.experts.KTransformersExperts
kwargs:
prefill_device: "cuda:2"
prefill_op: "KExpertsTorch"
generate_device: "cpu"
generate_op: "KSFTExpertsCPU"
out_device: "cuda:2"
backend: "AMXInt8" # or "AMXBF16" or "llamafile" (default)
recursive: False
# GPU 3: layers 45-60
- match:
name: "^model\\.layers\\.(4[5-9]|5[0-9]|60)\\.mlp\\.experts$"
replace:
class: ktransformers.operators.experts.KTransformersExperts
kwargs:
prefill_device: "cuda:3"
prefill_op: "KExpertsTorch"
generate_device: "cpu"
generate_op: "KSFTExpertsCPU"
out_device: "cuda:3"
backend: "AMXInt8" # or "AMXBF16" or "llamafile" (default)
recursive: False
# === Self-Attention Replacement ===
# GPU 0: layers 0-14
- match:
name: "^model\\.layers\\.([0-9]|1[0-4])\\.self_attn$"
replace:
class: ktransformers.operators.attention.KDeepseekV2Attention
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
absorb_for_prefill: False
# GPU 1: layers 15-29
- match:
name: "^model\\.layers\\.(1[5-9]|2[0-9])\\.self_attn$"
replace:
class: ktransformers.operators.attention.KDeepseekV2Attention
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
absorb_for_prefill: False
# GPU 2: layers 30-44
- match:
name: "^model\\.layers\\.(3[0-9]|4[0-4])\\.self_attn$"
replace:
class: ktransformers.operators.attention.KDeepseekV2Attention
kwargs:
generate_device: "cuda:2"
prefill_device: "cuda:2"
absorb_for_prefill: False
# GPU 3: layers 45-60
- match:
name: "^model\\.layers\\.(4[5-9]|5[0-9]|60)\\.self_attn$"
replace:
class: ktransformers.operators.attention.KDeepseekV2Attention
kwargs:
generate_device: "cuda:3"
prefill_device: "cuda:3"
absorb_for_prefill: False
# === Overall Model Replacement with Transfer Map ===
- match:
name: "^model$"
replace:
class: "ktransformers.operators.models.KDeepseekV2Model"
kwargs:
per_layer_prefill_intput_threshold: 0 # 0 disables layer-wise prefill
transfer_map:
15: "cuda:1" # Layers 15+ on GPU 1
30: "cuda:2" # Layers 30+ on GPU 2
45: "cuda:3" # Layers 45+ on GPU 3
# === Default Catch-All for Other Modules ===
# GPU 0: layers 0-14
- match:
name: "^model\\.layers\\.([0-9]|1[0-4])\\."
replace:
class: "default"
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
# GPU 1: layers 15-29
- match:
name: "^model\\.layers\\.(1[5-9]|2[0-9])\\."
replace:
class: "default"
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
# GPU 2: layers 30-44
- match:
name: "^model\\.layers\\.(3[0-9]|4[0-4])\\."
replace:
class: "default"
kwargs:
generate_device: "cuda:2"
prefill_device: "cuda:2"
- match:
name: "^lm_head"
class: torch.nn.Linear
replace:
class: ktransformers.operators.linear.KTransformersLinear
kwargs:
generate_device: "cuda:3"
prefill_device: "cuda:3"
generate_op: "KLinearTorch"
prefill_op: "KLinearTorch"
# For final modules (model.norm), ensure they are on GPU 3 (as in your original config)
- match:
name: "(^model\\.layers\\.(4[5-9]|5[0-9]|60)\\.)|(^model\\.norm)"
replace:
class: "default"
kwargs:
generate_device: "cuda:3"
prefill_device: "cuda:3"
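
The four-GPU rule file above splits the layers by index alternatives in each `name` pattern. A quick way to sanity-check such a split is to confirm that every layer index matches exactly one device group. The sketch below does this in Python for the `self_attn` rules; the helper name `devices_for_layer` is hypothetical, and the index range 0 to 60 is taken from the comments in the rule file rather than from the model definition.

```python
import re

# Layer-index alternatives copied from the four-GPU rule file, keyed by device.
LAYER_GROUPS = {
    "cuda:0": r"([0-9]|1[0-4])",
    "cuda:1": r"(1[5-9]|2[0-9])",
    "cuda:2": r"(3[0-9]|4[0-4])",
    "cuda:3": r"(4[5-9]|5[0-9]|60)",
}


def devices_for_layer(index: int) -> list[str]:
    """Return every device whose self_attn pattern matches this layer index."""
    name = f"model.layers.{index}.self_attn"
    return [
        device
        for device, group in LAYER_GROUPS.items()
        if re.search(rf"^model\.layers\.{group}\.self_attn$", name)
    ]


# Map each layer index to the list of devices that claim it.
assignment = {i: devices_for_layer(i) for i in range(61)}

# Collect any layer that is unclaimed or claimed by more than one device.
problems = {i: devs for i, devs in assignment.items() if len(devs) != 1}
print(problems)
```

An empty `problems` dictionary means the patterns partition the indices cleanly. The same check can be repeated for the other rule groups by changing the suffix after the layer index.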

@@ -1,156 +0,0 @@
- match:
name: "^model.embed_tokens"
replace:
class: "default"
kwargs:
generate_device: "cpu"
prefill_device: "cpu"
- match:
name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\."
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
replace:
class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
- match:
name: "^model\\.layers\\.([3456][0-9])\\."
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
replace:
class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
- match:
name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.(?!self_attn\\.kv_b_proj).*$" # regular expression
class: torch.nn.Linear # only match modules matching name and class simultaneously
replace:
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
generate_op: "KLinearTorch"
prefill_op: "KLinearTorch"
- match:
name: "^model\\.layers\\.([3456][0-9])\\.(?!self_attn\\.kv_b_proj).*$" # regular expression
class: torch.nn.Linear # only match modules matching name and class simultaneously
replace:
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
generate_op: "KLinearTorch"
prefill_op: "KLinearTorch"
- match:
name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp$"
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
replace:
class: ktransformers.operators.experts.KDeepseekV3MoE # mlp module with custom forward function
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
- match:
name: "^model\\.layers\\.([3456][0-9])\\.mlp$"
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
replace:
class: ktransformers.operators.experts.KDeepseekV3MoE # mlp module with custom forward function
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
- match:
name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp\\.gate$"
class: ktransformers.models.modeling_deepseek_v3.MoEGate
replace:
class: ktransformers.operators.gate.KMoEGate
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
- match:
name: "^model\\.layers\\.([3456][0-9])\\.mlp\\.gate$"
class: ktransformers.models.modeling_deepseek_v3.MoEGate
replace:
class: ktransformers.operators.gate.KMoEGate # MoE gate module
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
- match:
name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp\\.experts$"
replace:
class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert parallelism
kwargs:
prefill_device: "cuda:0"
prefill_op: "KExpertsTorch"
generate_device: "cpu"
generate_op: "KSFTExpertsCPU"
out_device: "cuda:0"
backend: "AMXInt8" # or "AMXBF16" or "llamafile" (default)
recursive: False # don't recursively inject submodules of this module
- match:
name: "^model\\.layers\\.([3456][0-9])\\.mlp\\.experts$"
replace:
class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert parallelism
kwargs:
prefill_device: "cuda:1"
prefill_op: "KExpertsTorch"
generate_device: "cpu"
generate_op: "KSFTExpertsCPU"
out_device: "cuda:1"
backend: "AMXInt8" # or "AMXBF16" or "llamafile" (default)
recursive: False # don't recursively inject submodules of this module
- match:
name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.self_attn$"
replace:
class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
- match:
name: "^model\\.layers\\.([3456][0-9])\\.self_attn$"
replace:
class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
- match:
name: "^model$"
replace:
class: "ktransformers.operators.models.KDeepseekV2Model"
kwargs:
per_layer_prefill_intput_threshold: 0 # 0 disables layer-wise prefill
transfer_map:
30: "cuda:1"
- match:
name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\."
replace:
class: "default"
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
- match:
name: "^lm_head"
class: torch.nn.Linear
replace:
class: ktransformers.operators.linear.KTransformersLinear
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
generate_op: "KLinearTorch"
prefill_op: "KLinearTorch"
- match:
name: "(^model\\.layers\\.([3456][0-9])\\.)|(model.norm)"
replace:
class: "default"
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
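The layer-range patterns in the multi-GPU rule file above split the decoder layers between two devices. A minimal sketch of how those two regular expressions partition layer indices is shown here. The 61-layer count and the `self_attn` module names are assumptions for illustration, and `device_for` is a hypothetical helper, not part of KTransformers.

```python
import re

# Patterns copied from the rule file (YAML "\\." is the regex "\.").
GPU0_LAYERS = re.compile(r"^model\.layers\.(0|[1-9]|[12][0-9])\.")
GPU1_LAYERS = re.compile(r"^model\.layers\.([3456][0-9])\.")


def device_for(module_name: str) -> str:
    """Return the device a module name would be routed to, or 'unmatched'."""
    if GPU0_LAYERS.match(module_name):
        return "cuda:0"
    if GPU1_LAYERS.match(module_name):
        return "cuda:1"
    return "unmatched"


# Assumed layer count (61) and module names, for illustration only.
placement = {i: device_for(f"model.layers.{i}.self_attn") for i in range(61)}
gpu0 = [i for i, dev in placement.items() if dev == "cuda:0"]
gpu1 = [i for i, dev in placement.items() if dev == "cuda:1"]
print(f"cuda:0 -> layers {gpu0[0]}..{gpu0[-1]}, cuda:1 -> layers {gpu1[0]}..{gpu1[-1]}")
```

Under these assumptions, layers 0 through 29 land on the first device and layers 30 and above on the second, which lines up with the `transfer_map` entry at layer 30.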


@@ -1,77 +0,0 @@
- match:
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
replace:
class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- match:
name: "^lm_head$" # regular expression
class: torch.nn.Linear # only match modules matching name and class simultaneously
replace:
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
generate_op: "KLinearTorch"
prefill_op: "KLinearTorch"
- match:
name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression
class: torch.nn.Linear # only match modules matching name and class simultaneously
replace:
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
generate_op: "KLinearTorch"
prefill_op: "KLinearTorch"
- match:
name: "^model\\.layers\\..*\\.mlp$"
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
replace:
class: ktransformers.operators.experts.KDeepseekV3MoE # mlp module with custom forward function
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- match:
class: ktransformers.models.modeling_deepseek_v3.MoEGate
replace:
class: ktransformers.operators.gate.KMoEGate
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
- match:
name: "^model\\.layers\\..*\\.mlp\\.experts$"
replace:
class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert parallelism
kwargs:
prefill_device: "cuda"
prefill_op: "KExpertsTorch"
generate_device: "cpu"
generate_op: "KSFTExpertsCPU"
out_device: "cuda"
backend: "AMXInt8" # or "AMXBF16" or "llamafile" (default)
recursive: False # don't recursively inject submodules of this module
- match:
name: "^model\\.layers\\..*\\.self_attn$"
replace:
class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
absorb_for_prefill: False # set this to True to enable long context (prefill may be slower).
- match:
name: "^model$"
replace:
class: "ktransformers.operators.models.KDeepseekV2Model"
kwargs:
per_layer_prefill_intput_threshold: 0 # 0 disables layer-wise prefill
- match:
name: "^model.embed_tokens"
replace:
class: "default"
kwargs:
generate_device: "cpu"
prefill_device: "cpu"
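The `torch.nn.Linear` rule in the file above uses a negative lookahead to skip `self_attn.kv_b_proj` while matching every other module under `model.layers`. A small sketch of that pattern's behavior follows. The module names in `names` are hypothetical examples, not taken from an actual model.

```python
import re

# Pattern copied from the rule file (YAML "\\." is the regex "\.").
LINEAR_RULE = re.compile(r"^model\.layers\.(?!.*self_attn\.kv_b_proj).*$")

# Hypothetical module names, for illustration only.
names = [
    "model.layers.0.self_attn.q_proj",
    "model.layers.0.self_attn.kv_b_proj",
    "model.layers.12.mlp.gate_proj",
    "lm_head",
]

# The lookahead rejects any name that contains "self_attn.kv_b_proj"
# after the "model.layers." prefix; "lm_head" fails the prefix itself.
matched = [name for name in names if LINEAR_RULE.match(name)]
print(matched)
```

Only the `q_proj` and `gate_proj` names match, so `kv_b_proj` would keep its original `torch.nn.Linear` implementation under this rule.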


@@ -1,80 +0,0 @@
- match:
class: ktransformers.models.modeling_qwen2_moe.Qwen2MoeRotaryEmbedding
replace:
class: ktransformers.operators.RoPE.RotaryEmbedding
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- match:
name: "^lm_head$" # regular expression
class: torch.nn.Linear # only match modules matching name and class simultaneously
replace:
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
generate_op: "KLinearTorch"
prefill_op: "KLinearTorch"
# - match:
# name: "^model\\.layers\\..*$" # regular expression
# class: torch.nn.Linear # only match modules matching name and class simultaneously
# replace:
# class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
# kwargs:
# generate_device: "cuda"
# prefill_device: "cuda"
# generate_op: "KLinearTorch"
# prefill_op: "KLinearTorch"
- match:
name: "^model\\.layers\\.(?!.*mlp\\.shared_expert_gate).*$" # regular expression
class: torch.nn.Linear # only match modules matching name and class simultaneously
replace:
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
generate_op: "KLinearTorch"
prefill_op: "KLinearTorch"
- match:
name: "^model\\.layers\\..*\\.mlp$"
replace:
class: ktransformers.operators.experts.KQwen3MoeSparseMoeBlock # mlp module with custom forward function
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- match:
name: "^model\\.layers\\..*\\.mlp\\.experts$"
replace:
class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert parallelism
kwargs:
prefill_device: "cuda"
prefill_op: "KExpertsTorch"
generate_device: "cpu"
generate_op: "KSFTExpertsCPU"
out_device: "cuda"
backend: "AMXInt8" # or "AMXBF16"
recursive: False # don't recursively inject submodules of this module
- match:
name: "^model\\.layers\\..*\\.self_attn$"
replace:
class: ktransformers.operators.attention.KQwen3MoeAttention # optimized attention implementation
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- match:
name: "^model.embed_tokens"
replace:
class: "default"
kwargs:
generate_device: "cpu"
prefill_device: "cpu"
- match:
name: "^model$"
replace:
class: "ktransformers.operators.models.KQwen3MoeModel"
kwargs:
per_layer_prefill_intput_threshold: 0


@@ -19,7 +19,7 @@ preprocessing_num_workers: 16
dataloader_num_workers: 4
### output
output_dir: saves/Kllama_deepseekV2
output_dir: saves/KT_FT_deepseekV2
logging_steps: 10
save_steps: 500
plot_loss: true
@@ -39,14 +39,7 @@ ddp_timeout: 180000000
resume_from_checkpoint: null
### ktransformers
use_kt: true # use KTransformers as LoRA sft backend
kt_optimize_rule: examples/ktransformers/kt_optimize_rules/DeepSeek-V2-Lite-Chat-sft-amx.yaml
cpu_infer: 32
chunk_size: 8192
### eval
# eval_dataset: alpaca_en_demo
# val_size: 0.1
# per_device_eval_batch_size: 1
# eval_strategy: steps
# eval_steps: 500
use_kt: true
# Pair with fsdp2_kt_bf16.yaml for original BF16 checkpoints.
# For pre-converted expert weights, uncomment kt_weight_path and use fsdp2_kt_int8.yaml or fsdp2_kt_int4.yaml.
# kt_weight_path: /path/to/DeepSeek-V2-Lite-AMXINT8


@@ -1,5 +1,5 @@
### model
model_name_or_path: opensourcerelease/DeepSeek-V3-bf16
model_name_or_path: deepseek-ai/DeepSeek-V3-0324-BF16 # need to convert to BF16 checkpoint first
trust_remote_code: true
### method
@@ -19,7 +19,7 @@ preprocessing_num_workers: 16
dataloader_num_workers: 4
### output
output_dir: saves/Kllama_deepseekV3
output_dir: saves/KT_FT_deepseekV3
logging_steps: 10
save_steps: 500
plot_loss: true
@@ -39,14 +39,7 @@ ddp_timeout: 180000000
resume_from_checkpoint: null
### ktransformers
use_kt: true # use KTransformers as LoRA sft backend
kt_optimize_rule: examples/ktransformers/kt_optimize_rules/DeepSeek-V3-Chat-sft-amx-multi-gpu.yaml
cpu_infer: 32
chunk_size: 8192
### eval
# eval_dataset: alpaca_en_demo
# val_size: 0.1
# per_device_eval_batch_size: 1
# eval_strategy: steps
# eval_steps: 500
use_kt: true
# Pair with fsdp2_kt_bf16.yaml for original BF16 checkpoints.
# For pre-converted expert weights, uncomment kt_weight_path and use fsdp2_kt_int8.yaml or fsdp2_kt_int4.yaml.
# kt_weight_path: /path/to/DeepSeek-V3-AMXINT8


@@ -0,0 +1,46 @@
### model
model_name_or_path: Qwen/Qwen3.5-397B-A17B
trust_remote_code: true
### method
stage: sft
do_train: true
finetuning_type: lora
lora_rank: 8
lora_target: all
### dataset
dataset: identity, alpaca_en_demo
template: qwen3_5
cutoff_len: 2048
max_samples: 100000
overwrite_cache: true
preprocessing_num_workers: 16
dataloader_num_workers: 4
### output
output_dir: saves/KT_FT_qwen35Moe
logging_steps: 10
save_steps: 500
plot_loss: true
overwrite_output_dir: true
save_only_model: false
report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
### train
per_device_train_batch_size: 1
gradient_accumulation_steps: 8
learning_rate: 1.0e-4
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
ddp_timeout: 180000000
resume_from_checkpoint: null
### ktransformers
use_kt: true
# For original BF16 checkpoints, start with examples/ktransformers/accelerate/fsdp2_kt_bf16.yaml.
# For pre-converted expert weights, uncomment kt_weight_path and use fsdp2_kt_int8.yaml or fsdp2_kt_int4.yaml.
# Pair the 397B path with fsdp2_kt_int8.yaml, and tune cutoff_len to fit the prepared weights and GPU memory.
# kt_weight_path: /path/to/Qwen3.5-MoE-AMXINT8


@@ -11,7 +11,7 @@ lora_target: all
### dataset
dataset: identity, alpaca_en_demo
template: qwen3_nothink
template: qwen3
cutoff_len: 2048
max_samples: 100000
overwrite_cache: true
@@ -19,9 +19,9 @@ preprocessing_num_workers: 16
dataloader_num_workers: 4
### output
output_dir: saves/Kllama_Qwen3MoE_235bA22b
output_dir: saves/KT_FT_qwen3Moe
logging_steps: 10
save_steps: 200
save_steps: 500
plot_loss: true
overwrite_output_dir: true
save_only_model: false
@@ -31,7 +31,7 @@ report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
per_device_train_batch_size: 1
gradient_accumulation_steps: 8
learning_rate: 1.0e-4
num_train_epochs: 3
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
@@ -39,14 +39,7 @@ ddp_timeout: 180000000
resume_from_checkpoint: null
### ktransformers
use_kt: true # use KTransformers as LoRA sft backend
kt_optimize_rule: examples/ktransformers/kt_optimize_rules/Qwen3Moe-sft-amx.yaml
cpu_infer: 32
chunk_size: 8192
### eval
# eval_dataset: alpaca_en_demo
# val_size: 0.1
# per_device_eval_batch_size: 1
# eval_strategy: steps
# eval_steps: 500
use_kt: true
# Pair with examples/ktransformers/accelerate/fsdp2_kt_bf16.yaml for original BF16 checkpoints.
# For pre-converted expert weights, uncomment kt_weight_path and use fsdp2_kt_int8.yaml or fsdp2_kt_int4.yaml.
# kt_weight_path: /path/to/Qwen3-235B-A22B-Instruct-2507-AMXINT8


@@ -0,0 +1,31 @@
model: Qwen/Qwen3-0.6B
model_class: llm
template: qwen3_nothink
kernel_config:
name: auto
include_kernels: auto # choice: null/true/false/auto/kernel_id1,kernel_id2,kernel_id3, default is null
quant_config: null
dist_config:
name: fsdp2
dcp_path: null # /mnt/f/pretrain_models/Qwen3-0.6B-dcp
### data
train_dataset: data/v1_sft_demo.yaml
### training
output_dir: outputs/test_fsdp2
micro_batch_size: 2
batching_strategy: normal
cutoff_len: 2048
learning_rate: 1.0e-4
max_steps: 10
### sample
sample_backend: hf
max_new_tokens: 128


@@ -0,0 +1,30 @@
model: Qwen/Qwen3-0.6B
model_class: llm
template: qwen3_nothink
kernel_config:
name: auto
include_kernels: auto # choice: null/true/false/auto/kernel_id1,kernel_id2,kernel_id3, default is null
quant_config: null
dist_config:
name: fsdp2
dcp_path: null # /mnt/f/pretrain_models/Qwen3-0.6B-dcp
### data
train_dataset: data/v1_sft_demo.yaml
### training
output_dir: outputs/test_fsdp2
micro_batch_size: 2
batching_strategy: dynamic_batching
cutoff_len: 2048
learning_rate: 1.0e-4
max_steps: 10
### sample
sample_backend: hf
max_new_tokens: 128


@@ -0,0 +1,30 @@
model: Qwen/Qwen3-0.6B
model_class: llm
template: qwen3_nothink
kernel_config:
name: auto
include_kernels: auto # choice: null/true/false/auto/kernel_id1,kernel_id2,kernel_id3, default is null
quant_config: null
dist_config:
name: fsdp2
dcp_path: null # /mnt/f/pretrain_models/Qwen3-0.6B-dcp
### data
train_dataset: data/v1_sft_demo.yaml
### training
output_dir: outputs/test_fsdp2
micro_batch_size: 4
batching_strategy: dynamic_padding_free
flash_attn: flash_attention2
cutoff_len: 2048
learning_rate: 1.0e-4
max_steps: 10
### sample
sample_backend: hf
max_new_tokens: 128


@@ -0,0 +1,30 @@
model: Qwen/Qwen3-0.6B
model_class: llm
template: qwen3_nothink
kernel_config:
name: auto
include_kernels: auto # choice: null/true/false/auto/kernel_id1,kernel_id2,kernel_id3, default is null
quant_config: null
dist_config:
name: fsdp2
dcp_path: null # /mnt/f/pretrain_models/Qwen3-0.6B-dcp
### data
train_dataset: data/v1_sft_demo.yaml
### training
output_dir: outputs/test_fsdp2
micro_batch_size: 4
batching_strategy: padding_free
flash_attn: flash_attention2
cutoff_len: 2048
learning_rate: 1.0e-4
max_steps: 10
### sample
sample_backend: hf
max_new_tokens: 128
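The four example configs above differ mainly in `batching_strategy` (`normal`, `dynamic_batching`, `dynamic_padding_free`, `padding_free`). As a rough illustration of the general idea behind padding-free batching, the sketch below concatenates variable-length sequences and records cumulative sequence boundaries, the layout that variable-length FlashAttention kernels typically consume. This is an assumption about the technique in general and is not the project's implementation; `pack_padding_free` and `padded_size` are hypothetical helpers.

```python
from itertools import accumulate


def pack_padding_free(sequences: list[list[int]]) -> tuple[list[int], list[int]]:
    """Concatenate sequences and return (flat_tokens, cu_seqlens)."""
    flat = [tok for seq in sequences for tok in seq]
    # cu_seqlens[i] .. cu_seqlens[i + 1] delimits sequence i in the flat buffer.
    cu_seqlens = [0, *accumulate(len(seq) for seq in sequences)]
    return flat, cu_seqlens


def padded_size(sequences: list[list[int]]) -> int:
    """Token slots used when every sequence is padded to the longest one."""
    return len(sequences) * max(len(seq) for seq in sequences)


batch = [[1, 2, 3], [4], [5, 6]]
flat, cu_seqlens = pack_padding_free(batch)
print(flat, cu_seqlens, padded_size(batch))
```

In this toy batch the packed layout holds 6 tokens, while padding to the longest sequence would use 9 slots.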


@@ -1,5 +1,4 @@
model: Qwen/Qwen3-4B
trust_remote_code: true
model_class: llm
template: qwen3_nothink
@@ -30,7 +29,6 @@ micro_batch_size: 1
global_batch_size: 4
cutoff_len: 2048
learning_rate: 2.0e-5
bf16: false
max_steps: 10
### sample


@@ -1,5 +1,4 @@
model: Qwen/Qwen3-0.6B
model_class: llm
template: qwen3_nothink
@@ -20,5 +19,4 @@ output_dir: outputs/Qwen3-0.6B-deepspeed
micro_batch_size: 1
cutoff_len: 2048
learning_rate: 1.0e-4
bf16: true
max_steps: 10


@@ -1,5 +1,4 @@
model: Qwen/Qwen3-0.6B
trust_remote_code: true
model_class: llm
template: qwen3_nothink
@@ -22,7 +21,6 @@ output_dir: outputs/test_fsdp2
micro_batch_size: 1
cutoff_len: 2048
learning_rate: 1.0e-4
bf16: false
max_steps: 10
### sample


@@ -0,0 +1,28 @@
model: Qwen/Qwen3-0.6B
model_class: llm
template: qwen3_nothink
kernel_config:
name: liger_kernel
include_kernels: auto # choice: null/true/false/auto/kernel_id1,kernel_id2,kernel_id3, default is null
quant_config: null
dist_config:
name: fsdp2
dcp_path: null # /mnt/f/pretrain_models/Qwen3-0.6B-dcp
### data
train_dataset: data/v1_sft_demo.yaml
### training
output_dir: outputs/test_fsdp2
micro_batch_size: 1
cutoff_len: 2048
learning_rate: 1.0e-4
max_steps: 10
### sample
sample_backend: hf
max_new_tokens: 128


@@ -0,0 +1,23 @@
model: Qwen/Qwen3-0.6B
trust_remote_code: true
model_class: llm
template: qwen3_nothink
# FSDP Config
dist_config:
name: fsdp2
dcp_path: null
cp_mode: ulysses
cp_size: 2
### data
train_dataset: data/v1_sft_demo.yaml
### training
output_dir: outputs/test_ulysses_cp
micro_batch_size: 1
cutoff_len: 2048
learning_rate: 1.0e-4
bf16: false
max_steps: 10


@@ -1,5 +1,4 @@
model: Qwen/Qwen3-4B
trust_remote_code: true
model_class: llm
template: qwen3_nothink
@@ -28,10 +27,8 @@ train_dataset: data/v1_sft_demo.yaml
### training
output_dir: ./outputs/test_lora
micro_batch_size: 1
global_batch_size: 4
cutoff_len: 2048
learning_rate: 1.0e-4
bf16: true
max_steps: 10
### sample


@@ -0,0 +1,40 @@
model: Qwen/Qwen3-4B
model_class: llm
template: qwen3_nothink
# PEFT Configuration
peft_config:
name: lora
r: 16
lora_alpha: 32
lora_dropout: 0.05
target_modules: all
# Kernel Config
kernel_config:
name: auto
include_kernels: auto
# FSDP Config
dist_config:
name: fsdp2
dcp_path: null
init_config:
name: init_on_rank0
### data
train_dataset: data/v1_sft_demo.yaml
### training
output_dir: ./outputs/test_lora
micro_batch_size: 1
cutoff_len: 2048
learning_rate: 1.0e-4
bf16: true
max_steps: 10
### sample
sample_backend: hf
max_new_tokens: 128


@@ -1,5 +1,4 @@
model: Qwen/Qwen3-0.6B
trust_remote_code: true
model_class: llm
template: qwen3_nothink
@@ -35,7 +34,6 @@ output_dir: outputs/test_quantization
micro_batch_size: 1
cutoff_len: 2048
learning_rate: 1.0e-4
bf16: false
max_steps: 10
### sample


@@ -40,7 +40,7 @@ dependencies = [
"torch>=2.4.0",
"torchvision>=0.19.0",
"torchaudio>=2.4.0",
"transformers>=4.51.0,<=5.2.0,!=4.52.0,!=4.57.0",
"transformers>=4.55.0,<=5.6.0,!=4.52.0,!=4.57.0",
"datasets>=2.16.0,<=4.0.0",
"accelerate>=1.3.0,<=1.11.0",
"peft>=0.18.0,<=0.18.1",


@@ -0,0 +1 @@
ktransformers[sft]


@@ -1,4 +1,5 @@
torch==2.7.1
torch-npu==2.7.1.post2
torch-npu==2.7.1.post4
torchvision==0.22.1
torchaudio==2.7.1
decorator

scripts/dcp2hf.py

@@ -0,0 +1,76 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert a DCP checkpoint to HuggingFace model format.
Usage:
python scripts/dcp2hf.py convert --dcp_path=/path/to/dcp --hf_path=/path/to/hf --config_path=/path/to/config
Arguments:
dcp_path: Path to the DCP checkpoint directory.
hf_path: Output path (directory) for HuggingFace model.
config_path: Path to the HuggingFace model directory containing config.json.
"""
import fire
import torch
import torch.distributed.checkpoint as dcp
import transformers
from transformers import AutoConfig
def convert(dcp_path: str, hf_path: str, config_path: str) -> None:
"""Convert DCP model weights to HF.
Note: this script converts only the model weights from a DCP checkpoint to HuggingFace format.
The tokenizer is not converted; you may need to copy its files from the original model.
Args:
dcp_path: DCP checkpoint directory.
hf_path: Output path (directory) for HuggingFace model.
config_path: Path to the HuggingFace model directory containing config.json.
"""
if not dcp_path or not hf_path or not config_path:
raise ValueError("All 'dcp_path', 'hf_path', and 'config_path' are required.")
print(f"Loading config from {config_path}...")
config = AutoConfig.from_pretrained(config_path)
architectures = getattr(config, "architectures", [])
if architectures:
model_cls = getattr(transformers, architectures[0], transformers.AutoModelForCausalLM)
else:
model_cls = transformers.AutoModelForCausalLM
print("Initializing model on CPU...")
model = model_cls(config).to(torch.bfloat16)
print(f"Loading DCP from {dcp_path}...")
state_dict = model.state_dict()
dcp.load(state_dict, checkpoint_id=dcp_path)
model.load_state_dict(state_dict)
print(f"Saving to HF format at {hf_path}...")
model.save_pretrained(hf_path)
config.save_pretrained(hf_path)
print("Done!")
def help() -> None:
"""Show help message."""
print(__doc__)
if __name__ == "__main__":
fire.Fire({"convert": convert, "help": help, "--convert": convert})
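The script above picks the model class from the first entry of `config.architectures` and falls back to `AutoModelForCausalLM` when that name is missing. The sketch below isolates that lookup so it can run without `transformers` installed. The `namespace` object and the two stand-in classes are hypothetical placeholders for the `transformers` module, and `resolve_model_cls` is an illustrative helper that does not exist in the script.

```python
from types import SimpleNamespace


class AutoModelForCausalLM:  # stand-in for the fallback class
    pass


class Qwen3ForCausalLM:  # stand-in for an architecture-specific class
    pass


# Hypothetical namespace playing the role of the `transformers` module.
namespace = SimpleNamespace(
    AutoModelForCausalLM=AutoModelForCausalLM,
    Qwen3ForCausalLM=Qwen3ForCausalLM,
)


def resolve_model_cls(config, namespace):
    """Mirror the script's lookup: first architecture name, else the fallback."""
    architectures = getattr(config, "architectures", None) or []
    if architectures:
        return getattr(namespace, architectures[0], namespace.AutoModelForCausalLM)
    return namespace.AutoModelForCausalLM


known = resolve_model_cls(SimpleNamespace(architectures=["Qwen3ForCausalLM"]), namespace)
unknown = resolve_model_cls(SimpleNamespace(architectures=["NotExported"]), namespace)
missing = resolve_model_cls(SimpleNamespace(), namespace)
print(known.__name__, unknown.__name__, missing.__name__)
```

A name that the namespace does not export, or a config without `architectures`, both resolve to the fallback class.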


@@ -25,7 +25,8 @@ Arguments:
import fire
import torch
import torch.distributed.checkpoint as dcp
from transformers import AutoModelForCausalLM
import transformers
from transformers import AutoConfig
def convert(hf_path: str, dcp_path: str) -> None:
@@ -39,7 +40,14 @@ def convert(hf_path: str, dcp_path: str) -> None:
raise ValueError("Both 'hf_path' and 'dcp_path' are required.")
print(f"Loading HF model from {hf_path}...")
model = AutoModelForCausalLM.from_pretrained(hf_path, device_map="cpu", torch_dtype=torch.bfloat16)
config = AutoConfig.from_pretrained(hf_path)
architectures = getattr(config, "architectures", [])
if architectures:
model_cls = getattr(transformers, architectures[0], transformers.AutoModelForCausalLM)
else:
model_cls = transformers.AutoModelForCausalLM
model = model_cls.from_pretrained(hf_path, device_map="cpu", torch_dtype=torch.bfloat16)
print(f"Saving to DCP format at {dcp_path}...")
dcp.save(model.state_dict(), checkpoint_id=dcp_path)


@@ -88,7 +88,10 @@ def _process_request(
if request.messages[0].role == Role.SYSTEM:
content = request.messages.pop(0).content
system = content[0].text if isinstance(content, list) else content
if isinstance(content, list):
system = content[0].text if content else ""
else:
system = content
else:
system = None
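The hunk above guards against an empty content list before reading `content[0].text`. A minimal sketch of the same branching is given here, using a hypothetical `TextPart` dataclass in place of the request's real content item type. `extract_system` is an illustrative name and not a function in the API module.

```python
from dataclasses import dataclass
from typing import Optional, Union


@dataclass
class TextPart:  # hypothetical stand-in for a multimodal content item
    text: str


def extract_system(content: Union[str, list[TextPart], None]) -> Optional[str]:
    """Return the system prompt text, tolerating an empty content list."""
    if isinstance(content, list):
        # An empty list previously raised IndexError on content[0].
        return content[0].text if content else ""
    return content


print(repr(extract_system([])))
print(repr(extract_system([TextPart("You are helpful.")])))
print(repr(extract_system("plain string")))
```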


@@ -71,16 +71,6 @@ class ChatModel:
"SGLang is not installed, you may need to run `pip install sglang[all]`\n"
"or try to use HuggingFace backend: --infer_backend huggingface"
) from e
elif model_args.infer_backend == EngineName.KT:
try:
from .kt_engine import KTransformersEngine
self.engine: BaseEngine = KTransformersEngine(model_args, data_args, finetuning_args, generating_args)
except ImportError as e:
raise ImportError(
"KTransformers is not installed, you may need to run `pip install ktransformers`\n"
"or try to use HuggingFace backend: --infer_backend huggingface"
) from e
else:
raise NotImplementedError(f"Unknown backend: {model_args.infer_backend}")


@@ -1,284 +0,0 @@
# Copyright 2025 the KVCache.AI team, Approaching AI, and the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import os
import platform
from collections.abc import AsyncGenerator
from threading import Thread
from typing import TYPE_CHECKING, Any, Optional
import torch
from typing_extensions import override
from ..data import get_template_and_fix_tokenizer
from ..extras import logging
from ..extras.constants import EngineName
from ..model import load_model, load_tokenizer
from .base_engine import BaseEngine, Response
if TYPE_CHECKING:
from transformers import PreTrainedTokenizer
from trl import PreTrainedModelWrapper
from ..data.mm_plugin import AudioInput, ImageInput, VideoInput
from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled
from ktransformers.server.config.config import Config
from ktransformers.util.utils import (
get_compute_capability,
prefill_and_generate_capture,
)
from ktransformers.util.vendors import GPUVendor, device_manager
logger = logging.get_logger(__name__)
class KTransformersEngine(BaseEngine):
def __init__(
self,
model_args: "ModelArguments",
data_args: "DataArguments",
finetuning_args: "FinetuningArguments",
generating_args: "GeneratingArguments",
) -> None:
self.name = EngineName.KT
self.can_generate = finetuning_args.stage == "sft"
tok_mod = load_tokenizer(model_args)
self.tokenizer = tok_mod["tokenizer"]
self.tokenizer.padding_side = "left" if self.can_generate else "right"
self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args)
self.model = load_model(
self.tokenizer, model_args, finetuning_args, is_trainable=False, add_valuehead=(not self.can_generate)
)
self.generating_args = generating_args.to_dict()
self.max_new_tokens = model_args.kt_maxlen
self.use_cuda_graph = model_args.kt_use_cuda_graph
self.mode = model_args.kt_mode
self.force_think = model_args.kt_force_think
self.chunk_size = model_args.chunk_size
try:
asyncio.get_event_loop()
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
self.semaphore = asyncio.Semaphore(int(os.getenv("MAX_CONCURRENT", "1")))
@staticmethod
@torch.inference_mode()
def _get_scores(
model: "PreTrainedModelWrapper",
tokenizer: "PreTrainedTokenizer",
batch_input: list[str],
input_kwargs: Optional[dict[str, Any]] = None,
) -> list[float]:
input_kwargs = input_kwargs or {} # avoid sharing a mutable default across calls
max_length: Optional[int] = input_kwargs.pop("max_length", None)
device = getattr(model.pretrained_model, "device", "cuda")
inputs = tokenizer(
batch_input,
padding=True,
truncation=True,
max_length=max_length or getattr(model.config, "max_position_embeddings", 1024),
return_tensors="pt",
add_special_tokens=False,
).to(device)
values: torch.Tensor = model(**inputs, return_dict=True, use_cache=False)[-1]
scores = values.gather(dim=-1, index=(inputs["attention_mask"].sum(dim=-1, keepdim=True) - 1))
return scores
async def _generate(
self,
messages: list[dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
**input_kwargs,
) -> AsyncGenerator[str, None]:
paired = messages + [{"role": "assistant", "content": ""}]
prompt_ids, _ = self.template.encode_oneturn(self.tokenizer, paired, system, tools)
prompt_len = len(prompt_ids)
max_length: Optional[int] = input_kwargs.pop("max_length", None)
max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None)
if "max_new_tokens" in self.generating_args:
max_tokens = int(self.generating_args["max_new_tokens"])
elif "max_length" in self.generating_args:
gl = int(self.generating_args["max_length"])
max_tokens = gl - prompt_len if gl > prompt_len else 1
else:
max_tokens = self.max_new_tokens or 256
if max_length is not None:
max_tokens = max(max_length - prompt_len, 1)
if max_new_tokens is not None:
max_tokens = int(max_new_tokens)
max_tokens = max(1, int(max_tokens))
if self.mode == "long_context":
max_len_cfg = Config().long_context_config["max_seq_len"]
need = prompt_len + max_tokens
assert max_len_cfg > need, f"please set max_seq_len > {need} in ~/.ktransformers/config.yaml"
device = next(self.model.parameters()).device
input_tensor = torch.tensor([prompt_ids], dtype=torch.long, device=device)
if self.force_think:
think = torch.tensor(
[self.tokenizer.encode("<think>\n", add_special_tokens=False)], dtype=torch.long, device=device
)
input_tensor = torch.cat([input_tensor, think], dim=1)
use_flashinfer = (
platform.system() != "Windows"
and getattr(self.model.config, "architectures", [""])[0]
in {"DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"}
and flashinfer_enabled
and get_compute_capability() >= 8
and device_manager.gpu_vendor == GPUVendor.NVIDIA
)
def make_gen():
if use_flashinfer:
return prefill_and_generate_capture(
self.model,
self.tokenizer,
input_tensor,
max_tokens,
self.use_cuda_graph,
mode=self.mode,
force_think=self.force_think,
chunk_size=self.chunk_size,
use_flashinfer_mla=True,
num_heads=self.model.config.num_attention_heads,
head_dim_ckv=getattr(self.model.config, "kv_lora_rank", 0),
head_dim_kpe=getattr(self.model.config, "qk_rope_head_dim", 0),
q_head_dim=getattr(self.model.config, "qk_rope_head_dim", 0)
+ getattr(self.model.config, "qk_nope_head_dim", 0),
echo_stream=False,
)
else:
return prefill_and_generate_capture(
self.model,
self.tokenizer,
input_tensor,
max_tokens,
self.use_cuda_graph,
mode=self.mode,
force_think=self.force_think,
chunk_size=self.chunk_size,
echo_stream=False,
)
loop = asyncio.get_running_loop()
q: asyncio.Queue[Optional[str]] = asyncio.Queue()
def producer():
try:
gen = make_gen()
if hasattr(gen, "__aiter__"):
async def drain_async():
async for t in gen:
loop.call_soon_threadsafe(q.put_nowait, t if isinstance(t, str) else str(t))
asyncio.run(drain_async())
elif hasattr(gen, "__iter__"):
for t in gen:
loop.call_soon_threadsafe(q.put_nowait, t if isinstance(t, str) else str(t))
else:
loop.call_soon_threadsafe(q.put_nowait, gen if isinstance(gen, str) else str(gen))
finally:
loop.call_soon_threadsafe(q.put_nowait, None)
Thread(target=producer, daemon=True).start()
while True:
item = await q.get()
if item is None:
break
yield item
@override
async def chat(
self,
messages: list[dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
images: Optional[list["ImageInput"]] = None,
videos: Optional[list["VideoInput"]] = None,
audios: Optional[list["AudioInput"]] = None,
**input_kwargs,
) -> list["Response"]:
if not self.can_generate:
raise ValueError("The current model does not support `chat`.")
async with self.semaphore:
produced = ""
final_text = ""
async for t in self._generate(messages, system, tools, **input_kwargs):
delta = t
produced = produced + delta
if delta:
final_text += delta
prompt_ids, _ = self.template.encode_oneturn(
self.tokenizer, messages + [{"role": "assistant", "content": ""}], system, tools
)
return [
Response(
response_text=final_text,
response_length=len(self.tokenizer.encode(final_text, add_special_tokens=False)),
prompt_length=len(prompt_ids),
finish_reason="stop",
)
]
@override
async def stream_chat(
self,
messages: list[dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
images: Optional[list["ImageInput"]] = None,
videos: Optional[list["VideoInput"]] = None,
audios: Optional[list["AudioInput"]] = None,
**input_kwargs,
) -> AsyncGenerator[str, None]:
if not self.can_generate:
raise ValueError("The current model does not support `stream_chat`.")
async with self.semaphore:
produced = ""
async for t in self._generate(messages, system, tools, **input_kwargs):
delta = t[len(produced) :] if t.startswith(produced) else t
produced = t
if delta:
yield delta
@override
async def get_scores(
self,
batch_input: list[str],
**input_kwargs,
) -> list[float]:
if self.can_generate:
raise ValueError("Cannot get scores using an auto-regressive model.")
args = (self.model, self.tokenizer, batch_input, input_kwargs)
async with self.semaphore:
return await asyncio.to_thread(self._get_scores, *args)
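The removed engine streamed tokens by running a blocking generator in a worker thread and handing items to the event loop through an `asyncio.Queue`, with `None` as the end-of-stream marker. The sketch below reproduces that bridge pattern in isolation. `stream_from_thread` and `collect` are hypothetical names, and a plain list iterator stands in for the model's generator.

```python
import asyncio
from collections.abc import AsyncGenerator, Callable, Iterable
from threading import Thread


async def stream_from_thread(make_iterable: Callable[[], Iterable]) -> AsyncGenerator[str, None]:
    """Yield items produced by a blocking iterable running in a worker thread."""
    loop = asyncio.get_running_loop()
    queue: asyncio.Queue = asyncio.Queue()

    def producer() -> None:
        try:
            for item in make_iterable():
                # asyncio.Queue is not thread-safe, so schedule the put on the loop.
                loop.call_soon_threadsafe(queue.put_nowait, str(item))
        finally:
            # Always send the sentinel so the consumer cannot wait forever.
            loop.call_soon_threadsafe(queue.put_nowait, None)

    Thread(target=producer, daemon=True).start()
    while True:
        item = await queue.get()
        if item is None:
            break
        yield item


async def collect(make_iterable: Callable[[], Iterable]) -> list[str]:
    return [chunk async for chunk in stream_from_thread(make_iterable)]


chunks = asyncio.run(collect(lambda: iter(["Hel", "lo", "!"])))
print(chunks)
```

Sending the sentinel from a `finally` block means the consumer loop still terminates if the producer raises, which matches the structure of the removed `producer` function.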


@@ -15,6 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import inspect
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Literal, Optional
@@ -25,7 +26,7 @@ import torch.nn.functional as F
from peft import PeftModel
from transformers import DataCollatorForSeq2Seq
from ..extras.constants import AUDIO_PLACEHOLDER, IGNORE_INDEX, IMAGE_PLACEHOLDER
from ..extras.constants import AUDIO_PLACEHOLDER, IGNORE_INDEX, IMAGE_PLACEHOLDER, MROPE_MODELS
from ..extras.packages import is_pillow_available
@@ -39,6 +40,56 @@ if TYPE_CHECKING:
from .template import Template
def _slice_mm_inputs_for_sample(
mm_inputs: dict[str, Any],
batch_imglens: list[int],
batch_vidlens: list[int],
batch_idx: int,
images_per_subseq: Optional[list[int]] = None,
videos_per_subseq: Optional[list[int]] = None,
subseq_idx: Optional[int] = None,
) -> dict[str, Any]:
r"""Slice mm_inputs for one batch sample, optionally for a single sub-sequence when packing.
image_grid_thw / video_grid_thw have shape [num_items, 3]. Indices for sample batch_idx
are batch_imglens[batch_idx] images and batch_vidlens[batch_idx] videos. When subseq_idx
is given, further restrict to that sub-seq's counts via packed_*_counts.
has_dummy_image=True means the batch has no multimodal data and only batch[0] is concatenated with a fake image.
"""
image_start_idx = sum(batch_imglens[:batch_idx])
image_end_idx = sum(batch_imglens[: batch_idx + 1])
video_start_idx = sum(batch_vidlens[:batch_idx])
video_end_idx = sum(batch_vidlens[: batch_idx + 1])
if subseq_idx is not None and images_per_subseq is not None:
image_start_idx += sum(images_per_subseq[:subseq_idx])
image_end_idx = image_start_idx + images_per_subseq[subseq_idx]
if subseq_idx is not None and videos_per_subseq is not None:
video_start_idx += sum(videos_per_subseq[:subseq_idx])
video_end_idx = video_start_idx + videos_per_subseq[subseq_idx]
sliced_mm_inputs: dict[str, Any] = {}
key_to_slice_meta = {
"image_grid_thw": (image_start_idx, image_end_idx, True),
"video_grid_thw": (video_start_idx, video_end_idx, True),
"second_per_grid_ts": (video_start_idx, video_end_idx, False), # qwen2.5vl
"video_second_per_grid": (video_start_idx, video_end_idx, False), # qwen omni
}
for key, (start_idx, end_idx, assign_none_when_empty) in key_to_slice_meta.items():
if key not in mm_inputs:
continue
mm_value = mm_inputs[key]
if mm_value is not None and end_idx > start_idx:
sliced_mm_inputs[key] = mm_value[start_idx:end_idx]
elif assign_none_when_empty:
sliced_mm_inputs[key] = None
return sliced_mm_inputs
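The index arithmetic in `_slice_mm_inputs_for_sample` can be checked on small numbers. The sketch below restates only the start and end computation for one modality, with hypothetical helper names (`sample_range`, `subseq_range`) and made-up counts. It does not touch tensors or the `mm_inputs` dictionary.

```python
def sample_range(counts: list[int], batch_idx: int) -> tuple[int, int]:
    """Row range in the flat grid tensor that belongs to one batch sample."""
    start = sum(counts[:batch_idx])
    end = sum(counts[: batch_idx + 1])
    return start, end


def subseq_range(
    counts: list[int], batch_idx: int, per_subseq: list[int], subseq_idx: int
) -> tuple[int, int]:
    """Narrow the sample's range to a single packed sub-sequence."""
    start, _ = sample_range(counts, batch_idx)
    start += sum(per_subseq[:subseq_idx])
    return start, start + per_subseq[subseq_idx]


# Made-up example: three samples holding 2, 3 and 1 images respectively.
batch_imglens = [2, 3, 1]
# Sample 1 is packed from two sub-sequences holding 1 and 2 images.
images_per_subseq = [1, 2]

print(sample_range(batch_imglens, 1))
print(subseq_range(batch_imglens, 1, images_per_subseq, 1))
```

With these counts, sample 1 owns rows 2 through 4 of `image_grid_thw`, and its second sub-sequence owns rows 3 and 4.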
def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype: "torch.dtype") -> "torch.Tensor":
r"""Expand 2d attention mask to 4d attention mask.
@@ -106,9 +157,174 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
else:
self.get_rope_func = None
def _compute_rope_position_ids(self, features: dict[str, "torch.Tensor"], mm_inputs: dict[str, Any]) -> None:
r"""Compute position_ids and rope_deltas via get_rope_func for VLMs."""
rope_index_kwargs = {
"input_ids": features["input_ids"],
"image_grid_thw": mm_inputs.get("image_grid_thw"),
"video_grid_thw": mm_inputs.get("video_grid_thw"),
"attention_mask": (features["attention_mask"] >= 1).float(),
}
if features["attention_mask"].sum() == 0: # for pad tokens
seq_len = features["input_ids"].shape[-1]
features["position_ids"] = (
torch.arange(seq_len).view(1, 1, seq_len).expand(3, *features["input_ids"].shape).contiguous()
)
features["rope_deltas"] = torch.zeros(features["input_ids"].shape[0])
return
if "mm_token_type_ids" in inspect.signature(self.get_rope_func).parameters:
image_token_id = getattr(self.model.config, "image_token_id", None)
video_token_id = getattr(self.model.config, "video_token_id", None)
if image_token_id is not None or video_token_id is not None:
mm_token_type_ids = torch.zeros_like(features["input_ids"])
if image_token_id is not None:
mm_token_type_ids[features["input_ids"] == image_token_id] = 1
if video_token_id is not None:
mm_token_type_ids[features["input_ids"] == video_token_id] = 2
rope_index_kwargs["mm_token_type_ids"] = mm_token_type_ids
if "second_per_grid_ts" in mm_inputs: # for qwen2vl
rope_index_kwargs["second_per_grid_ts"] = mm_inputs.get("second_per_grid_ts")
elif "video_second_per_grid" in mm_inputs: # for qwen2.5 omni
rope_index_kwargs["second_per_grids"] = mm_inputs.get("video_second_per_grid")
if getattr(self.model.config, "model_type", None) in ["qwen2_5_omni_thinker", "qwen3_omni_moe_thinker"]:
rope_index_kwargs["use_audio_in_video"] = getattr(self.processor, "use_audio_in_video", False)
feature_attention_mask = mm_inputs.get("feature_attention_mask", None)
if feature_attention_mask is not None: # FIXME: need to get video image lengths
audio_feature_lengths = torch.sum(feature_attention_mask, dim=1)
rope_index_kwargs["audio_seqlens"] = audio_feature_lengths # prepare for input
features["position_ids"], rope_deltas = self.get_rope_func(**rope_index_kwargs)
features["rope_deltas"] = rope_deltas - (1 - rope_index_kwargs["attention_mask"]).sum(dim=-1).unsqueeze(-1)
else: # for qwen vl
features["position_ids"], features["rope_deltas"] = self.get_rope_func(**rope_index_kwargs)
def _compute_rope_position_ids_with_packing(
self,
features: dict[str, "torch.Tensor"],
mm_inputs: dict[str, Any],
packing_params_list: list[dict[str, Any] | None],
batch_imglens: list[int],
batch_vidlens: list[int],
batch_audlens: list[int],
has_dummy_image: bool,
) -> None:
r"""Compute position_ids and rope_deltas per sample (or per sub-sequence when packed), then merge and validate."""
bsz = features["input_ids"].size(0)
seq_len = features["input_ids"].size(1)
all_position_ids: list[torch.Tensor] = []
all_rope_deltas: list[torch.Tensor] = []
if has_dummy_image:
# for [0, seq_len] = [0, unpadded_length + right_padding_length + fake_input_ids_len + collator_padding_length]
# FIXME: maybe right_padding_length is large, with improper max_cutoff_len
unpadded_length = int(features["attention_mask"][0].bool().sum().item())
right_padding_length = int((packing_params_list[0] or {}).get("right_padding_length") or 0)
fake_input_padding_length = max(0, seq_len - unpadded_length - right_padding_length)
# avoid continual cuseqlens breaking varlen attention @kuangdd
# https://github.com/hiyouga/LlamaFactory/issues/10452
dummy_image_right_padding_mrope = (
torch.arange(fake_input_padding_length)
.view(1, 1, fake_input_padding_length)
.expand(3, bsz, fake_input_padding_length)
)
dummy_image_right_padding_attention_mask = torch.zeros((bsz, fake_input_padding_length))
assert self.tokenizer.padding_side == "right", "padding_side should be right when fake image is injected"
dummy_mm_inputs = copy.deepcopy(mm_inputs)
for sample_idx in range(bsz):
sample_packing = (packing_params_list[sample_idx] or {}) if sample_idx < len(packing_params_list) else {}
sequence_boundaries = sample_packing.get("sequence_boundaries")
num_sub_seqs = (
(len(sequence_boundaries) - 1) if sequence_boundaries and len(sequence_boundaries) > 1 else 1
)
image_subseq_ids = sample_packing.get("image_subseq_ids") or []
video_subseq_ids = sample_packing.get("video_subseq_ids") or []
images_per_subseq = (
[image_subseq_ids.count(i) for i in range(num_sub_seqs)]
if image_subseq_ids and num_sub_seqs > 1
else None
)
videos_per_subseq = (
[video_subseq_ids.count(i) for i in range(num_sub_seqs)]
if video_subseq_ids and num_sub_seqs > 1
else None
)
if has_dummy_image:
mm_inputs = {}
if num_sub_seqs <= 1:
sample_features = {
"input_ids": features["input_ids"],
"attention_mask": features["attention_mask"][sample_idx : sample_idx + 1],
}
mm_inputs_for_sample = _slice_mm_inputs_for_sample(
mm_inputs, batch_imglens, batch_vidlens, sample_idx=sample_idx
)
self._compute_rope_position_ids(sample_features, mm_inputs_for_sample)
all_position_ids.append(sample_features["position_ids"])
all_rope_deltas.append(sample_features["rope_deltas"])
else:
# when we do packing, don't need rope_deltas when training.
sample_position_ids: list[torch.Tensor] = []
for subseq_idx in range(num_sub_seqs):
subseq_start = sequence_boundaries[subseq_idx]
subseq_end = sequence_boundaries[subseq_idx + 1]
subseq_features = {
"input_ids": features["input_ids"][sample_idx : sample_idx + 1, subseq_start:subseq_end],
"attention_mask": features["attention_mask"][
sample_idx : sample_idx + 1, subseq_start:subseq_end
],
}
mm_inputs_for_subseq = _slice_mm_inputs_for_sample(
mm_inputs,
batch_imglens,
batch_vidlens,
sample_idx,
images_per_subseq,
videos_per_subseq,
subseq_idx,
)
self._compute_rope_position_ids(subseq_features, mm_inputs_for_subseq)
sample_position_ids.append(subseq_features["position_ids"])
all_position_ids.append(torch.cat(sample_position_ids, dim=-1))
batch_dim_for_position_ids = 1 if all_position_ids[0].dim() == 3 else 0
features["position_ids"] = torch.cat(all_position_ids, dim=batch_dim_for_position_ids)
if has_dummy_image:
mm_inputs = dummy_mm_inputs
expected_position_ids_shape = (
(bsz, seq_len)
if all_position_ids[0].dim() == 2
else (
all_position_ids[0].size(0),
bsz,
seq_len,
)
)
# Check if position_ids shape matches expected shape.
# For later use, pad on the right when there are padding tokens on the right.
if has_dummy_image:
features["position_ids"] = torch.cat([features["position_ids"], dummy_image_right_padding_mrope], dim=-1)
features["attention_mask"] = torch.cat(
[features["attention_mask"], dummy_image_right_padding_attention_mask], dim=-1
)
if features["position_ids"].shape != expected_position_ids_shape:
raise ValueError(
"Merged position_ids shape mismatch: "
f"got {features['position_ids'].shape}, expected {expected_position_ids_shape}."
)
def __call__(self, features: list[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
batch_images, batch_videos, batch_audios = [], [], []
batch_imglens, batch_vidlens, batch_audlens, batch_input_ids = [], [], [], []
packing_params_list: list[dict[str, Any] | None] = []
for feature in features:
images = feature.pop("images", None) or []
videos = feature.pop("videos", None) or []
@@ -120,8 +336,10 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
batch_vidlens.append(len(videos))
batch_audlens.append(len(audios))
batch_input_ids.append(feature["input_ids"])
packing_params_list.append(feature.pop("packing_params", None))
fake_input_ids = []
has_dummy_image = False
if (
self.template.mm_plugin.image_token is not None and sum(batch_imglens) == 0 and sum(batch_vidlens) == 0
): # avoid process hanging in zero3/fsdp case
@@ -137,6 +355,7 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
fake_input_ids.extend(_fake_input_ids)
batch_images = fake_images
batch_imglens[0] = 1
has_dummy_image = True
if (
self.template.mm_plugin.audio_token is not None and sum(batch_audlens) == 0
@@ -181,59 +400,63 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
for i, feature in enumerate(features):
feature["token_type_ids"] = token_type_ids[i]
if "mm_token_type_ids" in mm_inputs: # need tensor-like for gemma4
mm_token_type_ids = mm_inputs.pop("mm_token_type_ids")
max_len = max(len(ids) for ids in mm_token_type_ids)
padded = []
for ids in mm_token_type_ids:
pad_len = max_len - len(ids)
if self.tokenizer.padding_side == "right":
padded.append(ids + [0] * pad_len)
else:
padded.append([0] * pad_len + ids)
mm_inputs["mm_token_type_ids"] = torch.tensor(padded, dtype=torch.long)
features: dict[str, torch.Tensor] = super().__call__(features)
bsz, seq_len = features["input_ids"].shape[:2]
model_type = getattr(self.model.config, "model_type", None) if self.model is not None else None
is_omni = model_type in [
"qwen2_5_omni_thinker",
"qwen3_omni_moe_thinker",
]
if self.get_rope_func is not None:
rope_index_kwargs = {
"input_ids": features["input_ids"],
"image_grid_thw": mm_inputs.get("image_grid_thw"),
"video_grid_thw": mm_inputs.get("video_grid_thw"),
"attention_mask": (features["attention_mask"] >= 1).float(),
}
if "mm_token_type_ids" in inspect.signature(self.get_rope_func).parameters:
image_token_id = getattr(self.model.config, "image_token_id", None)
video_token_id = getattr(self.model.config, "video_token_id", None)
if image_token_id is not None or video_token_id is not None:
mm_token_type_ids = torch.zeros_like(features["input_ids"])
if image_token_id is not None:
mm_token_type_ids[features["input_ids"] == image_token_id] = 1
if video_token_id is not None:
mm_token_type_ids[features["input_ids"] == video_token_id] = 2
rope_index_kwargs["mm_token_type_ids"] = mm_token_type_ids
if "second_per_grid_ts" in mm_inputs: # for qwen2vl
rope_index_kwargs["second_per_grid_ts"] = mm_inputs.get("second_per_grid_ts")
elif "video_second_per_grid" in mm_inputs: # for qwen2.5 omni
rope_index_kwargs["second_per_grids"] = mm_inputs.get("video_second_per_grid")
# For mrope, position_ids and rope_deltas must be computed per sample.
# When neat_packing is on, each sample has packing_params; None means no packing for that sample.
boundaries_list = [p.get("sequence_boundaries") if p is not None else None for p in packing_params_list]
has_packing = any(b is not None and len(b) > 2 for b in boundaries_list)
if has_dummy_image and has_packing:
# FIXME: too tricky, need to be refactored @kuangdd
features["has_dummy_image"] = True
if getattr(self.model.config, "model_type", None) in ["qwen2_5_omni_thinker", "qwen3_omni_moe_thinker"]:
rope_index_kwargs["use_audio_in_video"] = getattr(self.processor, "use_audio_in_video", False)
feature_attention_mask = mm_inputs.get("feature_attention_mask", None)
if feature_attention_mask is not None: # FIXME: need to get video image lengths
audio_feature_lengths = torch.sum(feature_attention_mask, dim=1)
rope_index_kwargs["audio_seqlens"] = audio_feature_lengths # prepare for input
# When fake image/audio was injected, sequence_boundaries no longer match the tensor; use non-packing path.
if not has_packing:
self._compute_rope_position_ids(features, mm_inputs)
else:
if is_omni: # TODO: support omni models for packed sequences @kuangdd
raise RuntimeError("Omni models are not supported for packed sequences for now.")
features["position_ids"], rope_deltas = self.get_rope_func(**rope_index_kwargs)
features["rope_deltas"] = rope_deltas - (1 - rope_index_kwargs["attention_mask"]).sum(
dim=-1
).unsqueeze(-1)
else: # for qwen vl
features["position_ids"], features["rope_deltas"] = self.get_rope_func(**rope_index_kwargs)
self._compute_rope_position_ids_with_packing(
features,
mm_inputs,
packing_params_list,
batch_imglens,
batch_vidlens,
batch_audlens,
has_dummy_image,
)
# For transformers compatibility, after https://github.com/huggingface/transformers/issues/39400
if features["position_ids"].dim() == 3:
features["position_ids"] = torch.cat(
[features["position_ids"][0].unsqueeze(0), features["position_ids"]], dim=0
)
if (
self.model is not None
and getattr(self.model.config, "model_type", None)
in [
"glm4v",
"glm_ocr",
"Keye",
"qwen2_vl",
"qwen2_5_vl",
"qwen2_5_omni_thinker",
"qwen3_omni_moe_thinker",
"qwen3_5",
"qwen3_vl",
"qwen3_vl_moe",
]
and getattr(self.model.config, "model_type", None) in MROPE_MODELS
and ("position_ids" not in features or features["position_ids"].dim() != 3)
):
raise ValueError(f"{self.model.config.model_type} requires 3D position ids for mrope.")
@@ -261,12 +484,53 @@ class SFTDataCollatorWith4DAttentionMask(MultiModalDataCollatorForSeq2Seq):
block_diag_attn: bool = False
attn_implementation: Literal["eager", "sdpa", "flash_attention_2"] = "eager"
compute_dtype: "torch.dtype" = torch.float32
neat_packing: bool = False
def __post_init__(self):
super().__post_init__()
if self.neat_packing and self.attn_implementation == "flash_attention_2":
if self.model is not None and getattr(self.model.config, "model_type", None) in ["gemma4", "gpt_oss"]:
raise ValueError("Neat packing is not supported for gemma4, gpt_oss models for now.")
@staticmethod
def _unpad_packed_features(features: dict[str, Any]) -> None:
r"""Trim padded positions for packed FA2 batches."""
attention_mask = features.get("attention_mask")
if not torch.is_tensor(attention_mask) or attention_mask.dim() != 2 or attention_mask.size(0) != 1:
return
seq_len = attention_mask.size(1)
non_padding_indices = torch.nonzero(attention_mask[0] != 0, as_tuple=False).flatten()
if non_padding_indices.numel() == seq_len:
return
keys_on_seq_dim_1 = {"input_ids", "labels", "attention_mask", "token_type_ids"}
for key, value in list(features.items()):
if not torch.is_tensor(value):
continue
if key == "position_ids" and value.size(-1) == seq_len:
features[key] = value.index_select(-1, non_padding_indices)
elif (
key == "cross_attention_mask" and value.dim() >= 2 and value.size(0) == 1 and value.size(1) == seq_len
):
features[key] = value.index_select(1, non_padding_indices)
elif key in keys_on_seq_dim_1 and value.dim() == 2 and value.size(0) == 1 and value.size(1) == seq_len:
features[key] = value.index_select(1, non_padding_indices)
def __call__(self, features: list[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
features = super().__call__(features)
has_dummy_image = features.pop("has_dummy_image", False)
if self.block_diag_attn and self.attn_implementation != "flash_attention_2":
features["attention_mask"] = prepare_4d_attention_mask(features["attention_mask"], self.compute_dtype)
if self.neat_packing and self.attn_implementation == "flash_attention_2": # FIXME compatibility fa3/fa4
assert features["input_ids"].shape[0] == 1, "bsz should be 1 for neat packing"
if not has_dummy_image:
self._unpad_packed_features(features)
features["attention_mask"] = None # let transformers handle causal packed mask.
for key, value in features.items(): # cast data dtype for paligemma
if torch.is_tensor(value) and torch.is_floating_point(value):
features[key] = value.to(self.compute_dtype)
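
The unpadding step for packed FlashAttention-2 batches can be illustrated without tensors. The sketch below is a plain-list analogue of the `index_select` approach, assuming a batch of size 1 and an attention mask whose non-zero entries mark real tokens; `unpad_packed` and the sample values are hypothetical.

```python
def unpad_packed(features, seq_keys=("input_ids", "labels", "attention_mask")):
    """Drop padded positions from a single-row packed batch (list-based sketch)."""
    mask = features["attention_mask"]
    if len(mask) != 1:
        return features  # only a batch of size 1 is handled
    keep = [i for i, m in enumerate(mask[0]) if m != 0]
    if len(keep) == len(mask[0]):
        return features  # nothing to trim
    out = dict(features)
    for key in seq_keys:
        if key in features:
            out[key] = [[features[key][0][i] for i in keep]]
    return out


batch = {
    "input_ids": [[11, 12, 21, 22, 0, 0]],
    "labels": [[-100, 12, -100, 22, -100, -100]],
    "attention_mask": [[1, 1, 2, 2, 0, 0]],  # non-zero values mark real tokens
}
trimmed = unpad_packed(batch)
print(trimmed["input_ids"])
```

The patch applies the same selection along the last dimension of `position_ids`, so the trimmed tensors stay aligned with each other.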

@@ -257,8 +257,8 @@ class OpenAIDatasetConverter(DatasetConverter):
content = message[self.dataset_attr.content_tag]
if role in [self.dataset_attr.assistant_tag, self.dataset_attr.function_tag]:
if "tool_calls" in message and len(message["tool_calls"]) > 0:
tool_calls_list = [tool["function"] for tool in message["tool_calls"]]
if tool_calls := message.get("tool_calls"):
tool_calls_list = [tool["function"] for tool in tool_calls]
content = json.dumps(tool_calls_list, ensure_ascii=False)
role = self.dataset_attr.function_tag
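
The difference between the two checks in this hunk shows up when a message carries `"tool_calls": None`. A minimal sketch, with hypothetical helper names and message dictionaries:

```python
import json


def old_check(message):
    # Previous form: raises TypeError when the key exists but holds None.
    return "tool_calls" in message and len(message["tool_calls"]) > 0


def extract_tool_calls(message):
    # New form: a missing key, None, and an empty list are all falsy.
    if tool_calls := message.get("tool_calls"):
        return json.dumps([tool["function"] for tool in tool_calls], ensure_ascii=False)
    return None


none_message = {"role": "assistant", "content": "hi", "tool_calls": None}
call_message = {
    "role": "assistant",
    "content": "",
    "tool_calls": [{"function": {"name": "get_weather", "arguments": "{}"}}],
}

try:
    old_check(none_message)
    old_raised = False
except TypeError:
    old_raised = True

print(old_raised, extract_tool_calls(none_message), extract_tool_calls(call_message))
```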

@@ -22,16 +22,18 @@ import re
from copy import deepcopy
from dataclasses import dataclass
from io import BytesIO
from typing import TYPE_CHECKING, BinaryIO, Literal, NotRequired, Optional, TypedDict, Union
from types import SimpleNamespace
from typing import TYPE_CHECKING, Any, BinaryIO, Literal, NotRequired, Optional, TypedDict, Union
import numpy as np
import torch
import torchaudio
from transformers.image_utils import get_image_size, is_valid_image, to_numpy_array
from transformers.image_utils import get_image_size, is_valid_image, make_flat_list_of_images, to_numpy_array
from transformers.models.mllama.processing_mllama import (
convert_sparse_cross_attention_mask_to_dense,
get_cross_attention_token_mask,
)
from transformers.video_utils import make_batched_videos
from typing_extensions import override
from ..extras.constants import AUDIO_PLACEHOLDER, IGNORE_INDEX, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
@@ -47,13 +49,6 @@ if is_pyav_available():
import av
if is_transformers_version_greater_than("4.52.0"):
from transformers.image_utils import make_flat_list_of_images
from transformers.video_utils import make_batched_videos
else:
from transformers.image_utils import make_batched_videos, make_flat_list_of_images
if TYPE_CHECKING:
from av.stream import Stream
from numpy.typing import NDArray
@@ -251,6 +246,14 @@ class MMPluginMixin:
sample_frames = min(total_frames, video_maxlen, sample_frames)
return np.linspace(0, total_frames - 1, sample_frames).astype(np.int32)
def _get_video_token_metadata(
self,
videos: list["VideoInput"],
processor: "MMProcessor",
) -> Optional[dict[str, Any]]:
r"""Build metadata used to expand video tokens without decoding frames."""
return None
def _regularize_images(self, images: list["ImageInput"], **kwargs) -> "RegularizedImageOutput":
r"""Regularize images to avoid error. Including reading and pre-processing."""
results = []
@@ -613,6 +616,217 @@ class Gemma3nPlugin(Gemma3Plugin):
return messages
@dataclass
class Gemma4Plugin(BasePlugin):
r"""Plugin for the Gemma4 multimodal model."""
@override
def _regularize_videos(self, videos: list["VideoInput"], **kwargs) -> "RegularizedVideoOutput":
r"""Regularize videos, also tracking per-video FPS and frame indices for timestamp generation."""
results, fps_per_video, durations, frames_indices = [], [], [], []
for video in videos:
frames: list[ImageObject] = []
if _check_video_is_nested_images(video):
frames = video
fps_per_video.append(kwargs.get("video_fps", 2.0))
durations.append(len(frames) / kwargs.get("video_fps", 2.0))
frames_indices.append(list(range(len(frames))))
else:
container = av.open(video, "r")
video_stream = next(stream for stream in container.streams if stream.type == "video")
sample_indices = self._get_video_sample_indices(video_stream, **kwargs)
original_fps = float(video_stream.average_rate)
# needed to calculate timestamps correctly
frames_indices.append([idx / original_fps * kwargs.get("video_fps", 2.0) for idx in sample_indices])
container.seek(0)
for frame_idx, frame in enumerate(container.decode(video_stream)):
if frame_idx in sample_indices:
frames.append(frame.to_image())
if video_stream.duration is None:
durations.append(len(frames) / kwargs.get("video_fps", 2.0))
else:
durations.append(float(video_stream.duration * video_stream.time_base))
frames = self._regularize_images(frames, **kwargs)["images"]
results.append(frames)
return {
"videos": results,
"fps_per_video": fps_per_video,
"durations": durations,
"frames_indices": frames_indices,
}
@override
def _get_mm_inputs(
self,
images: list["ImageInput"],
videos: list["VideoInput"],
audios: list["AudioInput"],
processor: "MMProcessor",
) -> dict[str, Union[list[int], "torch.Tensor"]]:
image_processor = getattr(processor, "image_processor", None)
video_processor = getattr(processor, "video_processor", None)
feature_extractor = getattr(processor, "feature_extractor", None)
mm_inputs = {}
if len(images) != 0:
regularized = self._regularize_images(
images,
image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768),
image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
)["images"]
mm_inputs.update(image_processor(regularized, return_tensors="pt"))
if len(videos) != 0:
video_data = self._regularize_videos(
videos,
image_max_pixels=getattr(processor, "video_max_pixels", 256 * 256),
image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16),
video_fps=getattr(processor, "video_fps", 2.0),
video_maxlen=getattr(processor, "video_maxlen", 128),
)
video_metadata = [
{
"fps": getattr(processor, "video_fps", 2.0),
"duration": duration,
"total_num_frames": len(video),
"frames_indices": sample_indices,
}
for video, duration, sample_indices in zip(
video_data["videos"], video_data["durations"], video_data["frames_indices"]
)
]
mm_inputs.update(
video_processor(
videos=video_data["videos"],
video_metadata=video_metadata,
return_tensors="pt",
return_metadata=True,
do_sample_frames=False,
)
)
if len(audios) != 0: # only for gemma4n
audios = self._regularize_audios(
audios,
sampling_rate=getattr(processor, "audio_sampling_rate", 16000),
)["audios"]
mm_inputs.update(
feature_extractor(
audios,
padding="max_length",
return_tensors="pt",
)
)
return mm_inputs
@override
def process_messages(
self,
messages: list[dict[str, str]],
images: list["ImageInput"],
videos: list["VideoInput"],
audios: list["AudioInput"],
processor: Optional["MMProcessor"],
) -> list[dict[str, str]]:
self._validate_input(processor, images, videos, audios)
self._validate_messages(messages, images, videos, audios)
messages = deepcopy(messages)
boi_token: str = getattr(processor, "boi_token")
eoi_token: str = getattr(processor, "eoi_token")
boa_token: str = getattr(processor, "boa_token")
eoa_token: str = getattr(processor, "eoa_token")
image_token: str = getattr(processor, "image_token")
video_token: str = getattr(processor, "video_token")
audio_token: str = getattr(processor, "audio_token")
if self.expand_mm_tokens:
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
num_image_soft_tokens: list[int] = list(
mm_inputs.get("num_soft_tokens_per_image", [getattr(processor, "image_seq_length", 256)] * len(images))
)
num_video_soft_tokens: list[int] = list(mm_inputs.get("num_soft_tokens_per_video", [1] * len(videos)))
video_metadata = mm_inputs.get("video_metadata", [])
else:
num_image_soft_tokens = [1] * len(images)
num_video_soft_tokens = [1] * len(videos)
video_metadata = [None] * len(videos)
audio_iter = iter(audios)
image_iter = iter(num_image_soft_tokens)
video_iter = iter(zip(num_video_soft_tokens, video_metadata))
for message in messages:
content = message["content"]
while IMAGE_PLACEHOLDER in content:
n = next(image_iter)
content = content.replace(IMAGE_PLACEHOLDER, f"{boi_token}{image_token * n}{eoi_token}", 1)
while VIDEO_PLACEHOLDER in content:
num_soft_tokens_per_frame, metadata = next(video_iter)
if self.expand_mm_tokens:
timestamp_strs = [f"{int(t // 60):02d}:{int(t % 60):02d}" for t in metadata.timestamps]
frame_strs = [
f"{ts} {boi_token}{video_token * num_soft_tokens_per_frame}{eoi_token}"
for ts in timestamp_strs
]
video_str = " ".join(frame_strs)
else:
video_str = f"{boi_token}{video_token * num_soft_tokens_per_frame}{eoi_token}"
content = content.replace(VIDEO_PLACEHOLDER, video_str, 1)
while AUDIO_PLACEHOLDER in content:
current_audio = next(audio_iter)
if self.expand_mm_tokens:
num_audio_tokens = processor._compute_audio_num_tokens(
current_audio, processor.feature_extractor.sampling_rate
)
audio_str = f"{boa_token}{audio_token * num_audio_tokens}{eoa_token}"
else:
audio_str = f"{boa_token}{audio_token}{eoa_token}"
content = content.replace(AUDIO_PLACEHOLDER, audio_str, 1)
message["content"] = content
return messages
@override
def get_mm_inputs(
self,
images: list["ImageInput"],
videos: list["VideoInput"],
audios: list["AudioInput"],
imglens: list[int],
vidlens: list[int],
audlens: list[int],
batch_ids: list[list[int]],
processor: Optional["MMProcessor"],
) -> dict[str, Union[list[int], "torch.Tensor"]]:
self._validate_input(processor, images, videos, audios)
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
# Pop metadata keys that must not be passed to the model.
for key in (
"num_soft_tokens_per_image",
"num_soft_tokens_per_video",
"video_metadata",
"_gemma4_fps_per_video",
"_gemma4_frames_indices",
"_gemma4_num_audio_soft_tokens",
):
mm_inputs.pop(key, None)
mm_inputs["mm_token_type_ids"] = processor.create_mm_token_type_ids(batch_ids)
return mm_inputs
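
The per-frame video string built in `process_messages` above can be sketched on its own. The token strings `<boi>`, `<v>` and `<eoi>` are placeholders chosen for illustration; the real values come from the processor.

```python
def format_timestamps(timestamps):
    """Render timestamps in seconds as zero-padded MM:SS strings."""
    return [f"{int(t // 60):02d}:{int(t % 60):02d}" for t in timestamps]


def build_video_str(timestamps, tokens_per_frame, boi="<boi>", video_token="<v>", eoi="<eoi>"):
    """Join one 'MM:SS <boi>...<eoi>' chunk per sampled frame with spaces."""
    frame_strs = [
        f"{ts} {boi}{video_token * tokens_per_frame}{eoi}"
        for ts in format_timestamps(timestamps)
    ]
    return " ".join(frame_strs)


print(format_timestamps([0.0, 59.9, 60.0, 125.5]))
print(build_video_str([0.0, 0.5], tokens_per_frame=2))
```

Fractional seconds are truncated, so frames sampled less than a second apart can share the same timestamp label.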
@dataclass
class InternVLPlugin(BasePlugin):
@override
@@ -1058,7 +1272,9 @@ class MiniCPMVPlugin(BasePlugin):
chunk_input=True,
sampling_rate=getattr(processor, "audio_sampling_rate", 16000),
)
audio_feature_lens = [torch.tensor(audio_feature_len) for audio_feature_len in audio_feature_lens]
audio_feature_lens = [
x.clone().detach() if isinstance(x, torch.Tensor) else torch.tensor(x) for x in audio_feature_lens
]
mm_inputs.update({"audio_features": audio_features, "audio_feature_lens": audio_feature_lens})
if kwargs.get("ret_phs", False):
mm_inputs.update({"audio_phs": audio_phs})
@@ -1098,7 +1314,7 @@ class MiniCPMVPlugin(BasePlugin):
num_image_tokens += 1
while VIDEO_PLACEHOLDER in content:
video_seqlen = len(mm_inputs["pixel_values"][num_video_tokens]) if self.expand_mm_tokens else 1
video_seqlen = len(mm_inputs["image_sizes"][num_video_tokens]) if self.expand_mm_tokens else 1
content = content.replace(VIDEO_PLACEHOLDER, "{{image}}" * video_seqlen, 1)
num_video_tokens += 1
@@ -1225,6 +1441,225 @@ class MiniCPMVPlugin(BasePlugin):
return mm_inputs
@dataclass
class MiniCPMV4_6Plugin(BasePlugin):
"""Plugin for MiniCPM-V-4.6 with new transformers (NaViT vision + get_placeholder_mask API)."""
def _get_mm_inputs(
self,
images: list["ImageInput"],
videos: list["VideoInput"],
audios: list["AudioInput"],
processor: "MMProcessor",
**kwargs,
) -> dict[str, "torch.Tensor"]:
image_processor = getattr(processor, "image_processor")
video_processor = getattr(processor, "video_processor", None)
mm_inputs = {}
if len(images) != 0:
# The image_processor ignores downsample_mode; target_sizes are always based on patch_size.
# downsample_mode only affects the token divisor in _build_v4_6_placeholder and model forward.
mm_inputs.update(image_processor(images, return_tensors="pt"))
if len(videos) != 0:
if video_processor is not None:
video_inputs = video_processor(videos, return_tensors="pt")
mm_inputs["pixel_values_videos"] = video_inputs["pixel_values_videos"]
mm_inputs["target_sizes_videos"] = video_inputs["target_sizes_videos"]
else:
video_inputs = image_processor(videos, return_tensors="pt")
mm_inputs["pixel_values_videos"] = video_inputs["pixel_values"]
mm_inputs["target_sizes_videos"] = video_inputs["target_sizes"]
if len(audios) != 0:
audio_features, audio_feature_lens, audio_phs = processor.audio_feature_extract(
[audios],
chunk_input=True,
sampling_rate=getattr(processor, "audio_sampling_rate", 16000),
)
audio_feature_lens = [
x.clone().detach() if isinstance(x, torch.Tensor) else torch.tensor(x) for x in audio_feature_lens
]
mm_inputs.update({"audio_features": audio_features, "audio_feature_lens": audio_feature_lens})
if kwargs.get("ret_phs", False):
mm_inputs.update({"audio_phs": audio_phs})
return mm_inputs
def _build_v4_6_placeholder(
self,
image_inputs: dict[str, Any],
image_idx: int,
use_image_id: bool,
processor: "MMProcessor",
) -> str:
"""Build image placeholder for MiniCPM-V-4.6 using NaViT token count computation."""
grids = image_inputs.get("grids", [[0, 0]])
num_patches_per_image = image_inputs.get("num_patches_per_image", [1])
target_sizes = image_inputs.get("target_sizes")
downsample_mode = os.getenv("DOWNSAMPLE_MODE")
if downsample_mode is None:
image_processor = getattr(processor, "image_processor")
downsample_mode = getattr(image_processor, "downsample_mode", "16x")
token_divisor = 4 if downsample_mode == "4x" else 16
flat_index = 0
for idx in range(image_idx):
flat_index += num_patches_per_image[idx]
n_patches = num_patches_per_image[image_idx]
img_target_sizes = target_sizes[flat_index : flat_index + n_patches]
num_tokens_per_patch = img_target_sizes.prod(-1) // token_divisor
num_rows, num_cols = grids[image_idx]
image_start = getattr(processor, "image_start_token", "<image>")
image_end = getattr(processor, "image_end_token", "</image>")
slice_start = getattr(processor, "slice_start_token", "<slice>")
slice_end = getattr(processor, "slice_end_token", "</slice>")
image_id_start = getattr(processor, "image_id_start_token", "<image_id>")
image_id_end = getattr(processor, "image_id_end_token", "</image_id>")
image_token = (
getattr(processor, "image_token", None)
or getattr(getattr(processor, "tokenizer", None), "image_token", None)
or "<image>"
)
image_placeholder = image_start + "<|ph|>" * int(num_tokens_per_patch[0]) + image_end
if use_image_id:
image_placeholder = f"{image_id_start}{image_idx}{image_id_end}" + image_placeholder
slice_mode = getattr(processor, "slice_mode", True)
if slice_mode and num_rows > 0 and num_cols > 0:
per_slice_tokens = int(num_tokens_per_patch[1]) if len(num_tokens_per_patch) > 1 else 0
slice_placeholder = slice_start + "<|ph|>" * per_slice_tokens + slice_end
slices = [slice_placeholder * num_cols for _ in range(num_rows)]
image_placeholder += "\n".join(slices)
return image_placeholder.replace("<|ph|>", image_token)
@override
def process_messages(
self,
messages: list[dict[str, str]],
images: list["ImageInput"],
videos: list["VideoInput"],
audios: list["AudioInput"],
processor: Optional["MMProcessor"],
) -> list[dict[str, str]]:
self._validate_input(processor, images, videos, audios)
self._validate_messages(messages, images, videos, audios)
num_image_tokens, num_video_tokens, num_audio_tokens = 0, 0, 0
messages = deepcopy(messages)
mm_inputs, audio_inputs = {}, {}
if len(images) != 0 and len(videos) != 0:
raise ValueError("MiniCPM-V model does not support input images and videos at the same time.")
use_image_id = getattr(processor, "default_use_image_id", True)
if len(videos) != 0:
use_image_id = False
mm_inputs = self._get_mm_inputs([], videos, [], processor)
for i, message in enumerate(messages):
content = message["content"]
while IMAGE_PLACEHOLDER in content:
content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)
num_image_tokens += 1
while VIDEO_PLACEHOLDER in content:
num_frames = 1
if "num_frames_per_video" in mm_inputs:
num_frames = sum(mm_inputs["num_frames_per_video"])
content = content.replace(VIDEO_PLACEHOLDER, "{{image}}" * num_frames, 1)
num_video_tokens += 1
while AUDIO_PLACEHOLDER in content:
content = content.replace(AUDIO_PLACEHOLDER, "{{audio}}", 1)
num_audio_tokens += 1
message["content"] = content.replace("{{image}}", "(<image>./</image>)").replace(
"{{audio}}", "(<audio>./</audio>)"
)
if len(images):
mm_inputs = self._get_mm_inputs(images, [], [], processor)
if len(audios):
audio_inputs = self._get_mm_inputs([], [], audios, processor, ret_phs=True)
if self.expand_mm_tokens and mm_inputs:
pattern = "(<image>./</image>)"
idx = 0
for index, message in enumerate(messages):
text = message["content"]
image_tags = re.findall(pattern, text)
text_chunks = text.split(pattern)
final_text = ""
for i in range(len(image_tags)):
image_placeholder = self._build_v4_6_placeholder(mm_inputs, idx, use_image_id, processor)
final_text = final_text + text_chunks[i] + image_placeholder
idx += 1
final_text += text_chunks[-1]
messages[index]["content"] = final_text
if self.expand_mm_tokens and audio_inputs:
pattern = "(<audio>./</audio>)"
idx = 0
for index, message in enumerate(messages):
text = message["content"]
audio_tags = re.findall(pattern, text)
text_chunks = text.split(pattern)
final_text = ""
for i in range(len(audio_tags)):
audio_placeholder = audio_inputs["audio_phs"][0][idx]
final_text = final_text + text_chunks[i] + audio_placeholder
idx += 1
final_text += text_chunks[-1]
messages[index]["content"] = final_text
return messages
@override
def get_mm_inputs(
self,
images: list["ImageInput"],
videos: list["VideoInput"],
audios: list["AudioInput"],
imglens: list[int],
vidlens: list[int],
audlens: list[int],
batch_ids: list[list[int]],
processor: Optional["MMProcessor"],
) -> dict[str, Union[list[int], "torch.Tensor"]]:
self._validate_input(processor, images, videos, audios)
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
# v4.6 does NOT use image_bound — the model finds image tokens via get_placeholder_mask
# Ensure target_sizes key name matches the model's expected input
if "target_sizes" not in mm_inputs and "tgt_sizes" in mm_inputs:
mm_inputs["target_sizes"] = mm_inputs.pop("tgt_sizes")
if "target_sizes" not in mm_inputs:
mm_inputs["target_sizes"] = torch.empty(0, 2, dtype=torch.int32)
if "pixel_values" not in mm_inputs:
mm_inputs["pixel_values"] = torch.empty(1, 3, 14, 0)
# Pass downsample_mode to model forward so it matches the placeholder divisor
_ds = os.getenv("DOWNSAMPLE_MODE")
if _ds is None:
_ds = getattr(getattr(processor, "image_processor", None), "downsample_mode", "16x")
mm_inputs["downsample_mode"] = _ds
if len(audios) > 0:
audio_inputs = self._get_mm_inputs([], [], audios, processor)
mm_inputs.update(audio_inputs)
return mm_inputs
@dataclass
class MllamaPlugin(BasePlugin):
@override
@@ -1493,10 +1928,11 @@ class Qwen2VLPlugin(BasePlugin):
@override
def _regularize_videos(self, videos: list["VideoInput"], **kwargs) -> "RegularizedVideoOutput":
results, fps_per_video, durations = [], [], []
results, fps_per_video, durations, frames_indices = [], [], [], []
for video in videos:
frames: list[ImageObject] = []
if _check_video_is_nested_images(video):
# we assume the frames have already been sampled from the video
for frame in video:
if not is_valid_image(frame) and not isinstance(frame, dict) and not os.path.exists(frame):
raise ValueError("Invalid image found in video frames.")
@@ -1504,10 +1940,16 @@ class Qwen2VLPlugin(BasePlugin):
frames = video
fps_per_video.append(kwargs.get("video_fps", 2.0))
durations.append(len(frames) / kwargs.get("video_fps", 2.0))
frames_indices.append(list(range(len(frames))))
else:
container = av.open(video, "r")
video_stream = next(stream for stream in container.streams if stream.type == "video")
sample_indices = self._get_video_sample_indices(video_stream, **kwargs)
original_fps = float(video_stream.average_rate)
# for qwen3vl video timestamp calculation
frames_indices.append(
[idx / original_fps * kwargs.get("video_fps", 2.0) for idx in sample_indices]
) # workaround used when do_sample_frames=False
container.seek(0)
for frame_idx, frame in enumerate(container.decode(video_stream)):
if frame_idx in sample_indices:
@@ -1526,7 +1968,205 @@ class Qwen2VLPlugin(BasePlugin):
frames = self._regularize_images(frames, **kwargs)["images"]
results.append(frames)
return {"videos": results, "fps_per_video": fps_per_video, "durations": durations}
return {
"videos": results,
"fps_per_video": fps_per_video,
"durations": durations,
"frames_indices": frames_indices,
}
def _get_qwen_video_size_after_regularization(
self, width: int, height: int, image_max_pixels: int, image_min_pixels: int
) -> tuple[int, int]:
r"""Compute the frame size produced by Qwen-VL image regularization."""
if (width * height) > image_max_pixels:
resize_factor = math.sqrt(image_max_pixels / (width * height))
width, height = int(width * resize_factor), int(height * resize_factor)
if (width * height) < image_min_pixels:
resize_factor = math.sqrt(image_min_pixels / (width * height))
width, height = int(width * resize_factor), int(height * resize_factor)
if min(width, height) < 28:
width, height = max(width, 28), max(height, 28)
if width / height > 200:
width, height = height * 180, height
if height / width > 200:
width, height = width, width * 180
return width, height
def _get_qwen_video_stream_metadata(
self,
video: "VideoInput",
video_fps: float,
video_maxlen: int,
) -> Optional[dict[str, Any]]:
if not is_pyav_available() or not isinstance(video, (str, os.PathLike)):
return None
try:
container = av.open(video, "r")
except (av.FFmpegError, OSError):
return None
try:
video_stream = next((stream for stream in container.streams if stream.type == "video"), None)
if video_stream is None:
return None
if video_stream.duration is None or video_stream.average_rate is None:
return None
average_fps = float(video_stream.average_rate)
if average_fps <= 0:
return None
sample_indices = self._get_video_sample_indices(
video_stream, video_fps=video_fps, video_maxlen=video_maxlen
)
return {
"width": video_stream.width,
"height": video_stream.height,
"average_fps": average_fps,
"sample_indices": sample_indices,
}
finally:
container.close()
def _get_qwen_video_resize(
self,
num_frames: int,
height: int,
width: int,
patch_size: int,
temporal_patch_size: int,
merge_size: int,
min_pixels: int,
max_pixels: int,
) -> tuple[int, int]:
from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize
return smart_resize(
height=height,
width=width,
factor=patch_size * merge_size,
min_pixels=min_pixels,
max_pixels=max_pixels,
)
def _get_qwen_video_grid_metadata(
self,
videos: list["VideoInput"],
processor: "MMProcessor",
) -> Optional[dict[str, Any]]:
if len(videos) == 0:
return {"video_grid_thw": torch.empty((0, 3), dtype=torch.long), "frames_indices": [], "fps": 2.0}
image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
video_processor: BaseVideoProcessor = getattr(processor, "video_processor", None) or image_processor
if image_processor is None or video_processor is None:
return None
patch_size = getattr(video_processor, "patch_size", None)
temporal_patch_size = getattr(video_processor, "temporal_patch_size", None)
merge_size = getattr(video_processor, "merge_size", None)
size = getattr(video_processor, "size", None)
if patch_size is None or temporal_patch_size is None or merge_size is None or size is None:
return None
if isinstance(size, dict):
min_pixels = size.get("shortest_edge")
max_pixels = size.get("longest_edge")
else:
min_pixels = getattr(size, "shortest_edge", None)
max_pixels = getattr(size, "longest_edge", None)
if min_pixels is None or max_pixels is None:
return None
video_fps = getattr(processor, "video_fps", 2.0)
video_maxlen = getattr(processor, "video_maxlen", 128)
image_max_pixels = getattr(processor, "video_max_pixels", 256 * 256)
image_min_pixels = getattr(processor, "video_min_pixels", 16 * 16)
video_grid_thw = []
frames_indices = []
for video in videos:
metadata = self._get_qwen_video_stream_metadata(video, video_fps, video_maxlen)
if metadata is None:
return None
width, height = self._get_qwen_video_size_after_regularization(
metadata["width"], metadata["height"], image_max_pixels, image_min_pixels
)
num_frames = len(metadata["sample_indices"])
if num_frames % 2 != 0:
num_frames += 1
resized_size = self._get_qwen_video_resize(
num_frames,
height,
width,
patch_size,
temporal_patch_size,
merge_size,
min_pixels,
max_pixels,
)
resized_height, resized_width = resized_size
video_grid_thw.append(
[
math.ceil(num_frames / temporal_patch_size),
resized_height // patch_size,
resized_width // patch_size,
]
)
frames_indices.append([idx / metadata["average_fps"] * video_fps for idx in metadata["sample_indices"]])
return {
"video_grid_thw": torch.tensor(video_grid_thw, dtype=torch.long),
"frames_indices": frames_indices,
"fps": video_fps,
}
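The grid arithmetic in `_get_qwen_video_grid_metadata` can be illustrated in isolation. The following is a minimal sketch, not the plugin's implementation: the function names are hypothetical, the default `patch_size=14` and `temporal_patch_size=2` are illustrative values rather than ones read from a real processor, and the resized height and width are assumed to have already come out of `smart_resize`.

```python
import math


def sketch_video_grid_thw(num_frames, resized_height, resized_width, patch_size=14, temporal_patch_size=2):
    """Return [t, h, w] grid sizes for one video (hypothetical helper)."""
    # An odd frame count is padded to the next even number, as in the loop above.
    if num_frames % 2 != 0:
        num_frames += 1
    return [
        math.ceil(num_frames / temporal_patch_size),  # temporal patches
        resized_height // patch_size,  # patch rows
        resized_width // patch_size,  # patch columns
    ]


def sketch_frame_positions(sample_indices, average_fps, video_fps):
    """Rescale sampled frame indices from the source frame rate to the target one."""
    return [idx / average_fps * video_fps for idx in sample_indices]


grid = sketch_video_grid_thw(num_frames=7, resized_height=280, resized_width=420)
positions = sketch_frame_positions([0, 15, 30], average_fps=30.0, video_fps=2.0)
print(grid, positions)
```

Computing the grid from stream metadata alone means the frames never have to be decoded just to count placeholder tokens, which appears to be the motivation for the metadata path above.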
@override
def _get_video_token_metadata(
self,
videos: list["VideoInput"],
processor: "MMProcessor",
) -> Optional[dict[str, Any]]:
video_metadata = self._get_qwen_video_grid_metadata(videos, processor)
if video_metadata is None:
return None
return {"video_grid_thw": video_metadata["video_grid_thw"]}
def _get_mm_token_metadata(
self,
images: list["ImageInput"],
videos: list["VideoInput"],
audios: list["AudioInput"],
processor: "MMProcessor",
) -> Optional[dict[str, Any]]:
if len(audios) != 0:
return None
mm_inputs = {}
if len(images) != 0:
mm_inputs.update(self._get_mm_inputs(images, [], [], processor))
if len(videos) != 0:
video_inputs = self._get_video_token_metadata(videos, processor)
if video_inputs is None:
return None
mm_inputs.update(video_inputs)
return mm_inputs
@override
def _get_mm_inputs(
@@ -1579,7 +2219,10 @@ class Qwen2VLPlugin(BasePlugin):
merge_length: int = getattr(image_processor, "merge_size") ** 2
if self.expand_mm_tokens:
mm_inputs = self._get_mm_token_metadata(images, videos, audios, processor)
if mm_inputs is None:
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
image_grid_thw = mm_inputs.get("image_grid_thw", [])
video_grid_thw = mm_inputs.get("video_grid_thw", [])
else:
@@ -1613,6 +2256,51 @@ class Qwen2VLPlugin(BasePlugin):
@dataclass
class Qwen3VLPlugin(Qwen2VLPlugin):
@override
def _get_qwen_video_resize(
self,
num_frames: int,
height: int,
width: int,
patch_size: int,
temporal_patch_size: int,
merge_size: int,
min_pixels: int,
max_pixels: int,
) -> tuple[int, int]:
from transformers.models.qwen3_vl.video_processing_qwen3_vl import smart_resize
return smart_resize(
num_frames=num_frames,
height=height,
width=width,
temporal_factor=temporal_patch_size,
factor=patch_size * merge_size,
min_pixels=min_pixels,
max_pixels=max_pixels,
)
@override
def _get_video_token_metadata(
self,
videos: list["VideoInput"],
processor: "MMProcessor",
) -> Optional[dict[str, Any]]:
video_metadata = self._get_qwen_video_grid_metadata(videos, processor)
if video_metadata is None:
return None
return {
"video_grid_thw": video_metadata["video_grid_thw"],
"video_metadata": [
SimpleNamespace(
frames_indices=frames_indices,
fps=video_metadata["fps"],
)
for frames_indices in video_metadata["frames_indices"]
],
}
@override
def _get_mm_inputs(
self,
@@ -1641,8 +2329,15 @@ class Qwen3VLPlugin(Qwen2VLPlugin):
video_maxlen=getattr(processor, "video_maxlen", 128),
)
video_metadata = [
{"fps": getattr(processor, "video_fps", 24.0), "duration": duration, "total_num_frames": len(video)}
for video, duration in zip(videos["videos"], videos["durations"])
{
"fps": getattr(processor, "video_fps", 2.0),
"duration": duration,
"total_num_frames": len(video),
"frames_indices": sample_indices,
}
for video, duration, sample_indices in zip(
videos["videos"], videos["durations"], videos["frames_indices"]
)
]
mm_inputs.update(
video_processor(
@@ -1650,6 +2345,7 @@ class Qwen3VLPlugin(Qwen2VLPlugin):
video_metadata=video_metadata,
fps=getattr(processor, "video_fps", 2.0),
return_metadata=True,
do_sample_frames=False, # avoid changing frames_indices
)
)
temporal_patch_size: int = getattr(image_processor, "temporal_patch_size", 2)
@@ -1677,11 +2373,14 @@ class Qwen3VLPlugin(Qwen2VLPlugin):
image_merge_length: int = getattr(image_processor, "merge_size") ** 2
video_merge_length: int = getattr(video_processor, "merge_size") ** 2
if self.expand_mm_tokens:
mm_inputs = self._get_mm_token_metadata(images, videos, audios, processor)
if mm_inputs is None:
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
image_grid_thw = mm_inputs.get("image_grid_thw", [])
video_grid_thw = mm_inputs.get("video_grid_thw", [])
num_frames = video_grid_thw[0][0] if len(video_grid_thw) > 0 else 0 # hard-coded for now
video_metadata = mm_inputs.get("video_metadata", {})
video_metadata = mm_inputs.get("video_metadata", [])
else:
image_grid_thw = [None] * len(images)
@@ -2204,8 +2903,9 @@ PLUGINS = {
"base": BasePlugin,
"ernie_vl": ErnieVLPlugin,
"gemma3": Gemma3Plugin,
"glm4v": GLM4VPlugin,
"gemma3n": Gemma3nPlugin,
"gemma4": Gemma4Plugin,
"glm4v": GLM4VPlugin,
"intern_vl": InternVLPlugin,
"kimi_vl": KimiVLPlugin,
"llama4": Llama4Plugin,
@@ -2214,6 +2914,7 @@ PLUGINS = {
"llava_next_video": LlavaNextVideoPlugin,
"lfm2_vl": LFMVLPlugin,
"minicpm_v": MiniCPMVPlugin,
"minicpm_v_4_6": MiniCPMV4_6Plugin,
"mllama": MllamaPlugin,
"paligemma": PaliGemmaPlugin,
"pixtral": PixtralPlugin,

@@ -13,7 +13,7 @@
# limitations under the License.
from collections import defaultdict
from dataclasses import dataclass
from dataclasses import asdict, dataclass
from typing import TYPE_CHECKING, Any, Optional
from ...extras import logging
@@ -27,6 +27,25 @@ if TYPE_CHECKING:
logger = logging.get_logger(__name__)
MAX_SU_SEQ_IDX = 2**32 # maximum sub-sequence index
@dataclass
class PackingParams:
r"""Metadata for a packed sequence: sub-sequence boundaries and multimodal data indices.
- sequence_boundaries: cumulative token positions, e.g. [0, 100, 250, 512] means 3 sub-seqs
with token ranges [0,100), [100,250), [250,512). Length = num_sub_seqs + 1.
- image_subseq_ids / video_subseq_ids / audio_subseq_ids: for each mm item, the 0-based
sub-sequence index it belongs to. Length = total number of that mm type in the packed sample.
"""
sequence_boundaries: list[int]
image_subseq_ids: list[int]
video_subseq_ids: list[int]
audio_subseq_ids: list[int]
right_padding_length: int
@dataclass
class SupervisedDatasetProcessor(DatasetProcessor):
@@ -44,7 +63,8 @@ class SupervisedDatasetProcessor(DatasetProcessor):
input_ids, labels = self.template.mm_plugin.process_token_ids(
[], [], images, videos, audios, self.tokenizer, self.processor
)
encoded_pairs = self.template.encode_multiturn(self.tokenizer, messages, system, tools)
discarding_history_cot = self.data_args.mask_history and not self.template.preserve_thinking
encoded_pairs = self.template.encode_multiturn(self.tokenizer, messages, system, tools, discarding_history_cot)
total_length = len(input_ids) + (1 if self.template.efficient_eos else 0)
if self.data_args.mask_history:
encoded_pairs = encoded_pairs[::-1] # high priority for last turns
@@ -162,10 +182,17 @@ class PackedSupervisedDatasetProcessor(SupervisedDatasetProcessor):
valid_num += 1
model_inputs = defaultdict(list)
requires_packing_params = self.data_args.neat_packing
knapsacks = greedy_knapsack(lengths, self.data_args.cutoff_len)
for knapsack in knapsacks:
packed_input_ids, packed_attention_masks, packed_position_ids, packed_labels = [], [], [], []
packed_images, packed_videos, packed_audios = [], [], []
if requires_packing_params:
sequence_boundaries = [0]
image_subseq_ids: list[int] = []
video_subseq_ids: list[int] = []
audio_subseq_ids: list[int] = []
for i, length in enumerate(knapsack):
index = length2indexes[length].pop()
packed_input_ids += batch_input_ids[index]
@@ -174,6 +201,15 @@ class PackedSupervisedDatasetProcessor(SupervisedDatasetProcessor):
packed_images += batch_images[index]
packed_videos += batch_videos[index]
packed_audios += batch_audios[index]
if requires_packing_params:
n_img = len(batch_images[index])
n_vid = len(batch_videos[index])
n_aud = len(batch_audios[index])
sequence_boundaries.append(sequence_boundaries[-1] + len(batch_input_ids[index]))
image_subseq_ids.extend([i] * n_img)
video_subseq_ids.extend([i] * n_vid)
audio_subseq_ids.extend([i] * n_aud)
if self.data_args.neat_packing:
packed_attention_masks += [i + 1] * len(batch_input_ids[index]) # start from 1
else:
@@ -189,10 +225,23 @@ class PackedSupervisedDatasetProcessor(SupervisedDatasetProcessor):
else:
packed_attention_masks += [1] * pad_length # more efficient flash_attn
if requires_packing_params:
sequence_boundaries.append(sequence_boundaries[-1] + pad_length)
if len(packed_input_ids) != self.data_args.cutoff_len + 1:
raise ValueError("The length of packed example should be identical to the cutoff length.")
model_inputs["input_ids"].append(packed_input_ids)
if requires_packing_params:
packing_params = PackingParams(
sequence_boundaries=sequence_boundaries,
image_subseq_ids=image_subseq_ids or [MAX_SU_SEQ_IDX], # avoid dataset concat error
video_subseq_ids=video_subseq_ids or [MAX_SU_SEQ_IDX],
audio_subseq_ids=audio_subseq_ids or [MAX_SU_SEQ_IDX],
right_padding_length=pad_length,
)
model_inputs["packing_params"].append(asdict(packing_params))
model_inputs["attention_mask"].append(packed_attention_masks)
model_inputs["position_ids"].append(packed_position_ids)
model_inputs["labels"].append(packed_labels)
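The bookkeeping behind `PackingParams` can be sketched on its own. This is a simplified illustration under stated assumptions: `PackingParamsSketch`, `build_packing_params`, and `SENTINEL_SUBSEQ_IDX` are hypothetical names, only image indices are tracked, and the padding is computed against a plain `target_len` rather than the `cutoff_len + 1` check used above.

```python
from dataclasses import asdict, dataclass

SENTINEL_SUBSEQ_IDX = 2**32  # hypothetical name; stands in for an empty index list


@dataclass
class PackingParamsSketch:
    sequence_boundaries: list
    image_subseq_ids: list
    right_padding_length: int


def build_packing_params(lengths, image_counts, target_len):
    """Accumulate sub-sequence boundaries and per-image sub-sequence indices."""
    boundaries = [0]
    image_ids = []
    for i, (length, n_img) in enumerate(zip(lengths, image_counts)):
        boundaries.append(boundaries[-1] + length)  # cumulative token position
        image_ids.extend([i] * n_img)  # each image points at sub-sequence i
    pad_length = target_len - boundaries[-1]
    if pad_length < 0:
        raise ValueError("Sub-sequences do not fit into target_len.")
    if pad_length > 0:
        boundaries.append(boundaries[-1] + pad_length)  # padding forms a trailing segment
    return PackingParamsSketch(boundaries, image_ids or [SENTINEL_SUBSEQ_IDX], pad_length)


params = build_packing_params([100, 150, 262], [1, 0, 2], target_len=512)
print(asdict(params))
```

With lengths 100, 150 and 262 the boundaries match the example in the docstring, and the two images of the third sample both map to sub-sequence index 2.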

@@ -54,6 +54,7 @@ class Template:
replace_eos: bool
replace_jinja_template: bool
enable_thinking: Optional[bool]
preserve_thinking: bool
mm_plugin: "BasePlugin"
def encode_oneturn(
@@ -78,6 +79,7 @@ class Template:
messages: list[dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
discarding_history_cot: bool = False, # only affects reasoning templates
) -> list[tuple[list[int], list[int]]]:
r"""Return multiple pairs of token ids representing prompts and responses respectively."""
encoded_messages = self._encode(tokenizer, messages, system, tools)
@@ -414,6 +416,7 @@ class ReasoningTemplate(Template):
tools: Optional[str] = None,
) -> tuple[list[int], list[int]]:
messages = deepcopy(messages)
if not self.preserve_thinking:
for i in range(1, len(messages) - 2, 2):
messages[i]["content"] = self.remove_thought(messages[i]["content"])
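The index arithmetic `range(1, len(messages) - 2, 2)` visits every assistant turn except the last one. The sketch below illustrates that on plain dicts; `remove_thought_sketch` and `strip_history_thoughts` are hypothetical stand-ins, and the regex-based removal is an assumption about what `remove_thought` does rather than a copy of it.

```python
import re
from copy import deepcopy

THOUGHT_WORDS = ("<think>", "</think>")  # assumed delimiters for this illustration


def remove_thought_sketch(content):
    """Drop the delimited reasoning span and any leading newlines left behind."""
    pattern = re.compile(re.escape(THOUGHT_WORDS[0]) + r"(.*?)" + re.escape(THOUGHT_WORDS[1]), re.DOTALL)
    return pattern.sub("", content).lstrip("\n")


def strip_history_thoughts(messages):
    """Remove reasoning from historical assistant turns, keeping the final one intact."""
    messages = deepcopy(messages)  # leave the caller's list untouched
    for i in range(1, len(messages) - 2, 2):  # odd indices are assistant turns
        messages[i]["content"] = remove_thought_sketch(messages[i]["content"])
    return messages


history = [
    {"role": "user", "content": "q1"},
    {"role": "assistant", "content": "<think>a</think>\n\nr1"},
    {"role": "user", "content": "q2"},
    {"role": "assistant", "content": "<think>b</think>\n\nr2"},
]
cleaned = strip_history_thoughts(history)
print(cleaned[1]["content"], "|", cleaned[3]["content"])
```

Only the first assistant turn loses its reasoning here; the last turn keeps it because the loop stops two positions before the end of the list.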
@@ -439,14 +442,24 @@ class ReasoningTemplate(Template):
messages: list[dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
discarding_history_cot: bool = False,
) -> list[tuple[list[int], list[int]]]:
messages = deepcopy(messages)
if self.enable_thinking is False: # remove all cot
for i in range(1, len(messages), 2):
messages[i]["content"] = self.remove_thought(messages[i]["content"])
if discarding_history_cot:
for i in range(1, len(messages) - 2, 2): # preserve the last cot
messages[i]["content"] = self.remove_thought(messages[i]["content"])
encoded_messages = self._encode(tokenizer, messages, system, tools)
for i in range(0, len(messages), 2):
if discarding_history_cot:
turn_indices = [len(messages) - 2]
else:
turn_indices = range(0, len(messages), 2)
for i in turn_indices:
if (
self.thought_words[0].strip() not in messages[i + 1]["content"]
and self.thought_words[1].strip() not in messages[i + 1]["content"]
@@ -491,6 +504,7 @@ def register_template(
replace_eos: bool = False,
replace_jinja_template: bool = False,
enable_thinking: Optional[bool] = True,
preserve_thinking: bool = False,
mm_plugin: "BasePlugin" = get_mm_plugin(name="base"),
template_class: type["Template"] = Template,
) -> None:
@@ -543,6 +557,7 @@ def register_template(
replace_eos=replace_eos,
replace_jinja_template=replace_jinja_template,
enable_thinking=enable_thinking,
preserve_thinking=preserve_thinking,
mm_plugin=mm_plugin,
)
@@ -605,6 +620,7 @@ def parse_template(tokenizer: "PreTrainedTokenizer") -> "Template":
replace_eos=False,
replace_jinja_template=False,
enable_thinking=True,
preserve_thinking=False,
mm_plugin=get_mm_plugin(name="base"),
)
@@ -644,6 +660,7 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args:
"e.g., qwen3_vl_nothink"
)
template.enable_thinking = data_args.enable_thinking
template.preserve_thinking = data_args.preserve_thinking
template.fix_special_tokens(tokenizer)
template.fix_jinja_template(tokenizer)
@@ -816,6 +833,19 @@ register_template(
)
register_template(
name="hy3",
format_user=StringFormatter(slots=["<hy_User>{{content}}<hy_Assistant>"]),
format_assistant=StringFormatter(slots=["{{content}}<hy_eos>"]),
format_system=StringFormatter(slots=["{{content}}"]),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
stop_words=["<hy_eos>"],
replace_eos=True,
thought_words=("<think>", "</think>"),
template_class=ReasoningTemplate,
)
register_template(
name="deepseekcoder",
format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n### Response:"]),
@@ -997,6 +1027,57 @@ register_template(
)
register_template(
name="gemma4",
format_user=StringFormatter(slots=["<|turn>user\n{{content}}<turn|>\n<|turn>model\n"]),
format_assistant=StringFormatter(slots=["{{content}}<turn|>\n"]),
format_system=StringFormatter(
slots=["<|turn>system\n<|think|>{{content}}<turn|>\n"]
), # thought signal included by default
format_observation=StringFormatter(
slots=["<|turn>tool\n{{content}}<turn|>\n<|turn>model\n"]
), # seems inconsistent with the chat template
format_tools=ToolFormatter(tool_format="gemma4"),
format_function=FunctionFormatter(slots=["<|tool>{{content}}<tool|>"], tool_format="gemma4"),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
stop_words=["<turn|>"],
default_system="You are a helpful assistant.", # important for thinking
thought_words=("<|channel>thought\n", "<channel|>"),
replace_eos=True,
mm_plugin=get_mm_plugin(
"gemma4",
image_token="<|image|>",
video_token="<|video|>",
),
template_class=ReasoningTemplate,
)
register_template(
name="gemma4n",
format_user=StringFormatter(slots=["<|turn>user\n{{content}}<turn|>\n<|turn>model\n"]),
format_assistant=StringFormatter(slots=["{{content}}<turn|>\n"]),
format_system=StringFormatter(
slots=["<|turn>system\n<|think|>{{content}}<turn|>\n"]
), # thought signal included by default
format_observation=StringFormatter(slots=["<|turn>tool\n{{content}}<turn|>\n<|turn>model\n"]),
format_tools=ToolFormatter(tool_format="gemma4"),
format_function=FunctionFormatter(slots=["<|tool>{{content}}<tool|>"], tool_format="gemma4"),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
stop_words=["<turn|>"],
default_system="You are a helpful assistant.", # important for thinking
thought_words=("<|channel>thought\n", "<channel|>"),
replace_eos=True,
mm_plugin=get_mm_plugin(
"gemma4",
image_token="<|image|>",
video_token="<|video|>",
audio_token="<|audio|>",
),
template_class=ReasoningTemplate,
)
register_template(
name="glm4",
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]),
@@ -1623,6 +1704,17 @@ register_template(
)
register_template(
name="minicpm_v_4_6",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
stop_words=["<|im_end|>"],
default_system="You are a helpful assistant.",
mm_plugin=get_mm_plugin(name="minicpm_v_4_6", image_token="<image>", video_token="<video>"),
)
# copied from minicpm_v template
register_template(
name="minicpm_o",
@@ -2062,6 +2154,24 @@ register_template(
)
# copied from qwen3_5 template
register_template(
name="qwen3_6",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen3_5"),
format_observation=StringFormatter(
slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
),
format_tools=ToolFormatter(tool_format="qwen3_5"),
stop_words=["<|im_end|>"],
replace_eos=True,
mm_plugin=get_mm_plugin(name="qwen3_vl", image_token="<|image_pad|>", video_token="<|video_pad|>"),
template_class=ReasoningTemplate,
)
register_template(
name="sailor",
format_user=StringFormatter(slots=["<|im_start|>question\n{{content}}<|im_end|>\n<|im_start|>answer\n"]),

@@ -210,6 +210,166 @@ class DefaultToolUtils(ToolUtils):
return results
class Gemma4ToolUtils(ToolUtils):
r"""Gemma-4 tool using template."""
@override
@staticmethod
def tool_formatter(tools: list[dict[str, Any]]) -> str:
def _format_parameters(properties: dict[str, Any]) -> str:
parts: list[str] = []
for name, schema in properties.items():
item_parts: list[str] = []
if schema.get("description"):
item_parts.append(f'description:<|"|>{schema["description"]}<|"|>')
if schema.get("type"):
item_parts.append(f'type:<|"|>{str(schema["type"]).upper()}<|"|>')
parts.append(f"{name}:{{{','.join(item_parts)}}}")
return ",".join(parts)
declarations: list[str] = []
for tool in tools:
function_data = tool.get("function", tool) if tool.get("type") == "function" else tool
declaration = (
f"declaration:{function_data['name']}"
+ "{"
+ f'description:<|"|>{function_data.get("description", "")}<|"|>'
)
params = function_data.get("parameters")
if params:
param_parts: list[str] = []
if params.get("properties"):
param_parts.append(f"properties:{{{_format_parameters(params['properties'])}}}")
if params.get("required"):
required_text = ",".join(f'<|"|>{item}<|"|>' for item in params["required"])
param_parts.append(f"required:[{required_text}]")
if params.get("type"):
param_parts.append(f'type:<|"|>{str(params["type"]).upper()}<|"|>')
declaration += f",parameters:{{{','.join(param_parts)}}}"
response_declaration = function_data.get("response")
if response_declaration:
response_parts: list[str] = []
if response_declaration.get("description"):
response_parts.append(f'description:<|"|>{response_declaration["description"]}<|"|>')
response_type = str(response_declaration.get("type", "")).upper()
if response_type == "OBJECT":
response_parts.append(f'type:<|"|>{response_type}<|"|>')
declaration += f",response:{{{','.join(response_parts)}}}"
declarations.append(declaration + "}")
return "\n".join(declarations)
@override
@staticmethod
def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
regex = re.compile(r"<\|tool_call\>call:([^{\s]+)\{(.*?)\}<tool_call\|>", re.DOTALL)
matches = re.findall(regex, content)
if not matches:
return content
def _parse_arguments(arg_text: str) -> Any:
text = arg_text.strip()
if not text:
return {}
# `function_formatter` writes dict arguments as `k:v,...` inside `{...}`.
# The extractor captures only the inner text, so re-wrap it to parse as JSON object.
object_like_text = "{" + text + "}"
# Convert Gemma string markers (<|"|>value<|"|>) to valid JSON strings.
normalized = re.sub(
r"<\|\"\|\>(.*?)<\|\"\|\>",
lambda m: json.dumps(m.group(1), ensure_ascii=False),
object_like_text,
flags=re.DOTALL,
)
# Quote unquoted object keys so the payload can be parsed by json.loads.
normalized = re.sub(r"(^|[{\s,])([A-Za-z_][A-Za-z0-9_]*)(\s*:)", r'\1"\2"\3', normalized)
try:
return json.loads(normalized)
except json.JSONDecodeError:
pass
try:
return json.loads(text)
except json.JSONDecodeError:
return text
results: list[FunctionCall] = []
for name, arg_block in matches:
parsed_arguments = _parse_arguments(arg_block)
if isinstance(parsed_arguments, str):
arguments = parsed_arguments
else:
arguments = json.dumps(parsed_arguments, ensure_ascii=False)
results.append(FunctionCall(name.strip(), arguments))
return results
@override
@staticmethod
def function_formatter(functions: list["FunctionCall"]) -> str:
def _format_argument(argument: Any, escape_keys: bool = True) -> str:
if isinstance(argument, str):
return f'<|"|>{argument}<|"|>'
if isinstance(argument, bool):
return "true" if argument else "false"
if isinstance(argument, dict):
items: list[str] = []
for key in sorted(argument.keys()):
formatted_key = f'<|"|>{key}<|"|>' if escape_keys else str(key)
formatted_value = _format_argument(argument[key], escape_keys=escape_keys)
items.append(f"{formatted_key}:{formatted_value}")
return "{" + ",".join(items) + "}"
if isinstance(argument, (list, tuple)):
return "[" + ",".join(_format_argument(item, escape_keys=escape_keys) for item in argument) + "]"
if argument is None:
return "null"
return str(argument)
function_texts: list[str] = []
for function in functions:
name = function.name
raw_arguments = function.arguments
try:
parsed_arguments = json.loads(raw_arguments)
except (TypeError, json.JSONDecodeError):
parsed_arguments = raw_arguments
call_text = f"<|tool_call>call:{name}" + "{"
if isinstance(parsed_arguments, dict):
args_text = []
for key in sorted(parsed_arguments.keys()):
value_text = _format_argument(parsed_arguments[key], escape_keys=False)
args_text.append(f"{key}:{value_text}")
call_text += ",".join(args_text)
elif isinstance(parsed_arguments, str):
call_text += parsed_arguments
else:
call_text += _format_argument(parsed_arguments, escape_keys=False)
call_text += "}<tool_call|>"
function_texts.append(call_text)
return "".join(function_texts)
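The argument normalization used by `tool_extractor` can be tried out in isolation. The sketch below mirrors the two regex passes shown above (string markers to JSON strings, then quoting bare keys); the function name `parse_gemma4_style_args` is hypothetical, and unlike the extractor it raises on malformed input instead of falling back to the raw text.

```python
import json
import re


def parse_gemma4_style_args(arg_text):
    """Turn `k:v,...` text with <|"|> string markers into a Python dict (illustrative)."""
    wrapped = "{" + arg_text.strip() + "}"
    # Pass 1: <|"|>value<|"|> becomes a properly escaped JSON string.
    normalized = re.sub(
        r'<\|"\|>(.*?)<\|"\|>',
        lambda m: json.dumps(m.group(1), ensure_ascii=False),
        wrapped,
        flags=re.DOTALL,
    )
    # Pass 2: quote bare identifiers that are followed by a colon.
    normalized = re.sub(r"(^|[{\s,])([A-Za-z_][A-Za-z0-9_]*)(\s*:)", r'\1"\2"\3', normalized)
    return json.loads(normalized)


flat = parse_gemma4_style_args('city:<|"|>Paris<|"|>,days:3')
nested = parse_gemma4_style_args('opts:{unit:<|"|>C<|"|>},verbose:true')
print(flat, nested)
```

One limitation worth keeping in mind: the second pass also runs over string contents, so a value such as `note: x` appearing after a space or comma inside a string would be rewritten as well. That is a property of this regex approach, not something the sketch tries to solve.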
class GLM4ToolUtils(ToolUtils):
r"""GLM-4 tool using template."""
@@ -361,6 +521,8 @@ class MiniMaxM2ToolUtils(ToolUtils):
prompt += "\n</invoke>"
function_texts.append(prompt)
return "\n".join(function_texts)
@override
@staticmethod
def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
@@ -598,7 +760,7 @@ class SeedToolUtils(ToolUtils):
results.append(FunctionCall(func_name.strip(), json.dumps(args_dict, ensure_ascii=False)))
return results
return results if results else content
class LingToolUtils(QwenToolUtils):
@@ -721,6 +883,7 @@ class LFM2ToolUtils(ToolUtils):
TOOLS = {
"default": DefaultToolUtils(),
"gemma4": Gemma4ToolUtils(),
"glm4": GLM4ToolUtils(),
"llama3": Llama3ToolUtils(),
"lfm2": LFM2ToolUtils(),

@@ -77,6 +77,20 @@ METHODS = ["full", "freeze", "lora", "oft"]
MOD_SUPPORTED_MODELS = {"bloom", "falcon", "gemma", "llama", "mistral", "mixtral", "phi", "starcoder2"}
MROPE_MODELS = {
"glm4v",
"glm_ocr",
"Keye",
"qwen2_vl",
"qwen2_5_vl",
"qwen2_5_omni_thinker",
"qwen3_omni_moe_thinker",
"qwen3_vl",
"qwen3_vl_moe",
"qwen3_5",
"qwen3_5_moe",
}
MULTIMODAL_SUPPORTED_MODELS = set()
PEFT_METHODS = {"lora", "oft"}
@@ -125,7 +139,6 @@ class EngineName(StrEnum):
HF = "huggingface"
VLLM = "vllm"
SGLANG = "sglang"
KT = "ktransformers"
class DownloadSource(StrEnum):
@@ -851,6 +864,34 @@ register_model_group(
)
register_model_group(
models={
"Gemma-4-26B-A4B-Thinking": {
DownloadSource.DEFAULT: "google/gemma-4-26B-A4B-it",
},
"Gemma-4-31B-Thinking": {
DownloadSource.DEFAULT: "google/gemma-4-31B-it",
},
},
template="gemma4",
multimodal=True,
)
register_model_group(
models={
"Gemma-4-E2B-Thinking": {
DownloadSource.DEFAULT: "google/gemma-4-E2B-it",
},
"Gemma-4-E4B-Thinking": {
DownloadSource.DEFAULT: "google/gemma-4-E4B-it",
},
},
template="gemma4n",
multimodal=True,
)
register_model_group(
models={
"GLM-4-9B": {
@@ -1215,6 +1256,17 @@ register_model_group(
)
register_model_group(
models={
"Hy3-Preview": {
DownloadSource.DEFAULT: "tencent/Hy3-preview",
DownloadSource.MODELSCOPE: "tencent/Hy3-preview",
},
},
template="hy3",
)
register_model_group(
models={
"Index-1.9B-Base": {
@@ -1896,6 +1948,18 @@ register_model_group(
)
register_model_group(
models={
"MiniCPM-V-4.6": {
DownloadSource.DEFAULT: "openbmb/MiniCPM-V-4_6",
DownloadSource.MODELSCOPE: "OpenBMB/MiniCPM-V-4_6",
},
},
template="minicpm_v_4_6",
multimodal=True,
)
register_model_group(
models={
"Ministral-8B-Instruct-2410": {

@@ -19,7 +19,7 @@
from collections import OrderedDict
VERSION = "0.9.5.dev0"
VERSION = "0.9.5"
def print_env() -> None:

@@ -94,10 +94,10 @@ def check_version(requirement: str, mandatory: bool = False) -> None:
def check_dependencies() -> None:
r"""Check the version of the required packages."""
check_version("transformers>=4.51.0,<=5.2.0")
check_version("transformers>=4.55.0,<=5.6.0")
check_version("datasets>=2.16.0,<=4.0.0")
check_version("accelerate>=1.3.0,<=1.11.0")
check_version("peft>=0.18.0,<=0.18.1")
check_version("accelerate>=1.3.0,<=1.15.0")
check_version("peft>=0.18.0,<=0.20.0")
check_version("trl>=0.18.0,<=0.24.0")

@@ -20,6 +20,7 @@ import importlib.util
from functools import lru_cache
from typing import TYPE_CHECKING
import transformers.utils.import_utils as import_utils
from packaging import version
@@ -70,6 +71,10 @@ def is_matplotlib_available():
return _is_package_available("matplotlib")
def is_hyper_parallel_available():
return _is_package_available("hyper_parallel")
def is_mcore_adapter_available():
return _is_package_available("mcore_adapter")
@@ -83,7 +88,7 @@ def is_ray_available():
def is_kt_available():
return _is_package_available("ktransformers")
return _is_package_available("kt_kernel")
def is_requests_available():
@@ -122,3 +127,26 @@ def is_uvicorn_available():
def is_vllm_available():
return _is_package_available("vllm")
_orig_is_package_available = import_utils._is_package_available
class PackageAvailability(tuple):
__slots__ = ()
def __new__(cls, available: bool, pkg_version: str = "N/A"):
return super().__new__(cls, (bool(available), pkg_version))
def __bool__(self) -> bool:
return self[0]
def _patched_is_package_available(pkg_name: str, return_version: bool = False):
available, version = _orig_is_package_available(pkg_name, return_version=return_version)
return PackageAvailability(available, version)
if is_transformers_version_greater_than("5.3.0"):
import_utils._is_package_available = _patched_is_package_available
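A non-empty tuple is always truthy in Python, so a plain `(False, "N/A")` result would pass an `if` check. The tuple subclass above presumably exists to tie the truth value to the first element while still allowing unpacking. A minimal sketch of that behavior, using the hypothetical name `AvailabilitySketch`:

```python
class AvailabilitySketch(tuple):
    """A (available, version) pair whose truth value follows `available`."""

    __slots__ = ()

    def __new__(cls, available, pkg_version="N/A"):
        return super().__new__(cls, (bool(available), pkg_version))

    def __bool__(self):
        return self[0]


missing = AvailabilitySketch(False)
present = AvailabilitySketch(True, "1.2.3")

# A plain tuple is truthy even when it reports "not available".
print(bool((False, "N/A")), bool(missing), bool(present))

# Callers that unpack the result keep working.
available, pkg_version = present
print(available, pkg_version)
```

This lets the same return value serve callers that write `if _is_package_available(name):` and callers that unpack two values.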

@@ -125,6 +125,10 @@ class DataArguments:
default=True,
metadata={"help": "Whether or not to enable thinking mode for reasoning models."},
)
preserve_thinking: bool = field(
default=False,
metadata={"help": "Whether or not to preserve thinking content in historical turns for reasoning models."},
)
tokenized_path: str | None = field(
default=None,
metadata={

@@ -482,6 +482,24 @@ class FinetuningArguments(
)
},
)
use_hyper_parallel: bool = field(
default=False,
metadata={
"help": (
"Whether or not to use HyperParallel distributed training backend (FSDP/TP). "
"Only supported for the 'sft' stage with full fine-tuning."
)
},
)
hyper_parallel_args: str | None = field(
default=None,
metadata={
"help": (
"Path to a JSON file containing HyperParallel strategy arguments "
"(e.g., tp_size, param_dtype). Used when use_hyper_parallel=True."
)
},
)
use_muon: bool = field(
default=False,
metadata={"help": "Whether or not to use the Muon optimizer."},
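The `hyper_parallel_args` help text above describes a path to a JSON file of strategy arguments. A rough sketch of reading such a file might look like the following. The `load_strategy_args` helper is hypothetical and not part of the diff, and the values are illustrative; only the key names `tp_size` and `param_dtype` come from the help text.

```python
import json
import tempfile
from pathlib import Path
from typing import Any, Optional


def load_strategy_args(path: Optional[str]) -> dict:
    """Hypothetical loader: return the JSON object stored at `path`, or {} if unset."""
    if path is None:
        return {}
    with open(path, encoding="utf-8") as f:
        args: Any = json.load(f)
    if not isinstance(args, dict):
        raise ValueError(f"Expected a JSON object in {path}, got {type(args).__name__}.")
    return args


# Write an example file to a temporary directory, then read it back.
with tempfile.TemporaryDirectory() as tmp_dir:
    config_path = Path(tmp_dir) / "hyper_parallel.json"
    config_path.write_text(json.dumps({"tp_size": 2, "param_dtype": "bfloat16"}), encoding="utf-8")
    strategy_args = load_strategy_args(str(config_path))

unset_args = load_strategy_args(None)
```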

@@ -16,6 +16,7 @@
# limitations under the License.
import json
import os
from dataclasses import asdict, dataclass, field, fields
from typing import Any, Literal, Self
@@ -460,47 +461,81 @@ class SGLangArguments:
@dataclass
class KTransformersArguments:
r"""Arguments pertaining to the KT training."""
r"""Arguments pertaining to KTransformers AMX MoE SFT training.
These fields are normalized into the transformers/accelerate KT config before training starts.
"""
use_kt: bool = field(
default=False,
metadata={"help": "Whether To Use KTransformers Optimizations For LoRA Training."},
metadata={"help": "Whether to use KTransformers AMX MoE backend for SFT training."},
)
kt_optimize_rule: str | None = field(
kt_weight_path: str | None = field(
default=None,
metadata={
"help": "Path To The KTransformers Optimize Rule; See https://github.com/kvcache-ai/ktransformers/."
},
metadata={"help": "Path to pre-quantized INT8 expert weights (.kt files)."},
)
cpu_infer: int | None = field(
default=32,
metadata={"help": "Number Of CPU Cores Used For Computation."},
kt_expert_checkpoint_path: str | None = field(
default=None,
metadata={"help": "Path to expert checkpoint (safetensors) for online conversion."},
)
chunk_size: int | None = field(
default=8192,
metadata={"help": "Chunk Size Used For CPU Compute In KTransformers."},
kt_use_lora_experts: bool | None = field(
default=None,
metadata={"help": "Whether to use GPU-side LoRA Experts."},
)
mode: str | None = field(
default="normal",
metadata={"help": "Normal Or Long_Context For Llama Models."},
kt_lora_expert_num: int | None = field(
default=None,
metadata={"help": "Number of GPU-side LoRA Experts."},
)
kt_lora_expert_intermediate_size: int | None = field(
default=None,
metadata={"help": "Intermediate size for GPU-side LoRA Experts."},
)
kt_maxlen: int = field(
default=4096,
metadata={"help": "Maximum Sequence (Prompt + Response) Length Of The KT Engine."},
)
kt_use_cuda_graph: bool = field(
default=True,
metadata={"help": "Whether To Use CUDA Graphs For The KT Engine."},
)
kt_mode: str = field(
default="normal",
metadata={"help": "Normal Or Long_Context Mode For The KT Engine."},
)
kt_force_think: bool = field(
default=False,
metadata={"help": "Force-Think Toggle For The KT Engine."},
    def get_kt_config_dict(self, finetuning_args: Any, model_max_length: int | None) -> dict[str, Any]:
        r"""Build KT config values from LLaMA-Factory model and LoRA arguments."""
        kt_config = {
            "kt_lora_rank": getattr(finetuning_args, "lora_rank", None),
            "kt_lora_alpha": getattr(finetuning_args, "lora_alpha", None),
            "kt_weight_path": self.kt_weight_path,
            "kt_expert_checkpoint_path": self.kt_expert_checkpoint_path,
            "kt_model_max_length": model_max_length,
            "kt_use_lora_experts": self.kt_use_lora_experts,
            "kt_lora_expert_num": self.kt_lora_expert_num,
            "kt_lora_expert_intermediate_size": self.kt_lora_expert_intermediate_size,
        }
        return {key: value for key, value in kt_config.items() if value is not None}

    def apply_kt_config(self, finetuning_args: Any, training_args: Any, model_max_length: int | None) -> None:
        r"""Apply LLaMA-Factory KT args to transformers/accelerate KT integration points."""
        if not self.use_kt:
            return

        kt_config = self.get_kt_config_dict(finetuning_args, model_max_length)
        env_mapping = {
            "kt_weight_path": "ACCELERATE_KT_WEIGHT_PATH",
            "kt_expert_checkpoint_path": "ACCELERATE_KT_EXPERT_CHECKPOINT_PATH",
            "kt_model_max_length": "ACCELERATE_KT_MODEL_MAX_LENGTH",
            "kt_lora_rank": "ACCELERATE_KT_LORA_RANK",
            "kt_lora_alpha": "ACCELERATE_KT_LORA_ALPHA",
            "kt_use_lora_experts": "ACCELERATE_KT_USE_LORA_EXPERTS",
            "kt_lora_expert_num": "ACCELERATE_KT_LORA_EXPERT_NUM",
            "kt_lora_expert_intermediate_size": "ACCELERATE_KT_LORA_EXPERT_INTERMEDIATE_SIZE",
        }
        for key, env_key in env_mapping.items():
            value = kt_config.get(key)
            if value is not None:
                os.environ[env_key] = str(value)

        hf_kt = getattr(training_args, "hf_kt_config", None)
        if hf_kt is None or not hasattr(hf_kt, "_kt_config") or not isinstance(hf_kt._kt_config, dict):
            return

        hf_kt._kt_config.update(kt_config)
        gc_enabled = getattr(training_args, "gradient_checkpointing", False) or not getattr(
            self, "disable_gradient_checkpointing", True
        )
        if gc_enabled:
            hf_kt._kt_config.setdefault("kt_share_cache_pool", True)
@dataclass
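The two methods added above first drop unset values and then mirror the remaining ones into `ACCELERATE_KT_*` environment variables as strings. The sketch below reproduces that pattern in isolation. The `export_kt_env` helper, the reduced mapping, and the example path are illustrative assumptions, and a plain dict stands in for `os.environ` so the example has no side effects.

```python
from typing import Any

# A subset of the mapping shown in the diff, kept short for illustration.
ENV_MAPPING = {
    "kt_weight_path": "ACCELERATE_KT_WEIGHT_PATH",
    "kt_model_max_length": "ACCELERATE_KT_MODEL_MAX_LENGTH",
    "kt_lora_rank": "ACCELERATE_KT_LORA_RANK",
    "kt_use_lora_experts": "ACCELERATE_KT_USE_LORA_EXPERTS",
}


def export_kt_env(kt_config: dict, env: dict) -> dict:
    """Drop None values, then copy the remaining ones into `env` as strings."""
    resolved: dict = {key: value for key, value in kt_config.items() if value is not None}
    for key, env_key in ENV_MAPPING.items():
        value: Any = resolved.get(key)
        if value is not None:
            env[env_key] = str(value)
    return resolved


env = {}  # stand-in for os.environ
resolved = export_kt_env(
    {
        "kt_weight_path": "/path/to/experts",  # hypothetical path
        "kt_model_max_length": 4096,
        "kt_lora_rank": None,  # unset, so it is neither kept nor exported
        "kt_use_lora_experts": False,
    },
    env,
)
```

Because every value goes through `str()`, booleans arrive as the strings `"True"` or `"False"`, so whatever reads these variables has to parse them back.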

@@ -33,7 +33,7 @@ from transformers.utils import is_torch_bf16_gpu_available, is_torch_npu_availab
from ..extras import logging
from ..extras.constants import CHECKPOINT_NAMES, EngineName
from ..extras.misc import check_dependencies, check_version, get_current_device, is_env_enabled
from ..extras.packages import is_mcore_adapter_available, is_transformers_version_greater_than
from ..extras.packages import is_mcore_adapter_available
from .data_args import DataArguments
from .evaluation_args import EvaluationArguments
from .finetuning_args import FinetuningArguments
@@ -47,7 +47,13 @@ logger = logging.get_logger(__name__)
check_dependencies()
_TRAIN_ARGS = [ModelArguments, DataArguments, TrainingArguments, FinetuningArguments, GeneratingArguments]
_TRAIN_ARGS = [
ModelArguments,
DataArguments,
TrainingArguments,
FinetuningArguments,
GeneratingArguments,
]
_TRAIN_CLS = tuple[ModelArguments, DataArguments, TrainingArguments, FinetuningArguments, GeneratingArguments]
_INFER_ARGS = [ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
_INFER_CLS = tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
@@ -57,9 +63,19 @@ _EVAL_CLS = tuple[ModelArguments, DataArguments, EvaluationArguments, Finetuning
if is_mcore_adapter_available() and is_env_enabled("USE_MCA"):
from mcore_adapter import TrainingArguments as McaTrainingArguments
_TRAIN_MCA_ARGS = [ModelArguments, DataArguments, McaTrainingArguments, FinetuningArguments, GeneratingArguments]
_TRAIN_MCA_ARGS = [
ModelArguments,
DataArguments,
McaTrainingArguments,
FinetuningArguments,
GeneratingArguments,
]
_TRAIN_MCA_CLS = tuple[
ModelArguments, DataArguments, McaTrainingArguments, FinetuningArguments, GeneratingArguments
ModelArguments,
DataArguments,
McaTrainingArguments,
FinetuningArguments,
GeneratingArguments,
]
else:
_TRAIN_MCA_ARGS = []
@@ -192,7 +208,9 @@ def _check_extra_dependencies(
training_args: Optional["TrainingArguments"] = None,
) -> None:
if model_args.use_kt:
check_version("ktransformers", mandatory=True)
check_version("kt-kernel", mandatory=True)
check_version("transformers-kt", mandatory=True)
check_version("accelerate-kt", mandatory=True)
if model_args.use_unsloth:
check_version("unsloth", mandatory=True)
@@ -394,9 +412,6 @@ def get_train_args(args: dict[str, Any] | list[str] | None = None) -> _TRAIN_CLS
if model_args.use_kt and is_deepspeed_zero3_enabled():
raise ValueError("KTransformers is incompatible with DeepSpeed ZeRO-3.")
if data_args.neat_packing and is_transformers_version_greater_than("4.53.0"):
raise ValueError("Neat packing is incompatible with transformers>=4.53.0.")
_set_env_vars()
_verify_model_args(model_args, data_args, finetuning_args)
_check_extra_dependencies(model_args, finetuning_args, training_args)
@@ -513,6 +528,9 @@ def get_train_args(args: dict[str, Any] | list[str] | None = None) -> _TRAIN_CLS
)
transformers.set_seed(training_args.seed)
if model_args.use_kt:
model_args.apply_kt_config(finetuning_args, training_args, model_args.model_max_length)
return model_args, data_args, training_args, finetuning_args, generating_args

@@ -14,6 +14,7 @@
import json
from dataclasses import dataclass, field
from typing import Optional
from transformers import Seq2SeqTrainingArguments
from transformers.training_args import _convert_str_dict
@@ -63,6 +64,58 @@ class RayArguments:
self.ray_init_kwargs = _convert_str_dict(json.loads(self.ray_init_kwargs))
@dataclass
class ProfilerArguments:
    r"""Arguments for torch profiler configuration."""

    enable_torch_profiler: bool = field(
        default=False,
        metadata={"help": "Whether to enable torch profiler for collecting performance traces."},
    )
    profiler_output_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Directory to write profiler traces. Defaults to <output_dir>/profiler if not set."},
    )
    profiler_wait_steps: int = field(
        default=1,
        metadata={"help": "Number of steps to skip at the start of each profiling cycle."},
    )
    profiler_warmup_steps: int = field(
        default=1,
        metadata={"help": "Number of profiler warm-up steps per cycle."},
    )
    profiler_active_steps: int = field(
        default=1,
        metadata={"help": "Number of steps to actively record per cycle."},
    )
    profiler_repeat: int = field(
        default=1,
        metadata={"help": "Number of profiling cycles. Set to 0 for continuous profiling."},
    )
    profiler_record_shapes: bool = field(
        default=True,
        metadata={"help": "Whether to record tensor shapes during profiling."},
    )
    profiler_profile_memory: bool = field(
        default=True,
        metadata={"help": "Whether to profile memory usage."},
    )
    profiler_with_stack: bool = field(
        default=True,
        metadata={"help": "Whether to record stack traces during profiling."},
    )
    profile_modules: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "Comma-separated list of module name patterns to profile with CUDA events. "
                "Supports fnmatch wildcards (e.g. 'model.layers.0.self_attn,model.layers.*.mlp'). "
                "Reports per-module forward/backward timing statistics at each logging step."
            )
        },
    )
@dataclass
class Fp8Arguments:
r"""Arguments pertaining to the FP8 training."""
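A rough, torch-free sketch of how the `ProfilerArguments` fields above could be interpreted. It assumes the wait/warmup/active/repeat values follow the usual cyclic profiler schedule (each cycle is wait, then warm-up, then active steps, and `repeat=0` means the cycle never stops), which the diff does not confirm. The `profile_modules` matching uses Python's `fnmatch`, as the help text states. Both helper functions and the module names are hypothetical.

```python
from fnmatch import fnmatchcase


def profiler_phase(step: int, wait: int = 1, warmup: int = 1, active: int = 1, repeat: int = 1) -> str:
    """Return the phase that a cyclic wait/warmup/active schedule assigns to a 0-based step."""
    cycle_len = wait + warmup + active
    if repeat > 0 and step >= repeat * cycle_len:
        return "idle"  # all requested cycles have finished
    pos = step % cycle_len
    if pos < wait:
        return "wait"
    if pos < wait + warmup:
        return "warmup"
    return "record"


def select_modules(module_names: list, profile_modules: str) -> list:
    """Keep the module names matching any comma-separated fnmatch pattern."""
    patterns = [p.strip() for p in profile_modules.split(",") if p.strip()]
    return [name for name in module_names if any(fnmatchcase(name, p) for p in patterns)]


# With the default values, only the third step of the run is recorded.
phases = [profiler_phase(step) for step in range(5)]

names = [
    "model.layers.0.self_attn",
    "model.layers.0.mlp",
    "model.layers.1.self_attn",
    "model.layers.1.mlp",
]
selected = select_modules(names, "model.layers.0.self_attn,model.layers.*.mlp")
```

Note that an fnmatch `*` also matches dots, so `model.layers.*.mlp` would match nested names such as `model.layers.3.block.mlp` as well.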
@@ -87,7 +140,7 @@ class Fp8Arguments:
@dataclass
class TrainingArguments(Fp8Arguments, RayArguments, BaseTrainingArguments):
class TrainingArguments(ProfilerArguments, Fp8Arguments, RayArguments, BaseTrainingArguments):
r"""Arguments pertaining to the trainer."""
overwrite_output_dir: bool = field(

@@ -20,8 +20,6 @@ from peft import LoraConfig, LoraModel, OFTConfig, PeftModel, TaskType, get_peft
from transformers.integrations import is_deepspeed_zero3_enabled
from ..extras import logging
from ..extras.constants import EngineName
from .model_utils.ktransformers import get_kt_peft_model, load_kt_peft_model
from .model_utils.misc import find_all_linear_modules, find_expanded_modules
from .model_utils.quantization import QuantizationMethod
from .model_utils.unsloth import get_unsloth_peft_model, load_unsloth_peft_model
@@ -125,7 +123,7 @@ def _setup_freeze_tuning(
model_type = getattr(model.config, "model_type", None)
if not finetuning_args.freeze_multi_modal_projector and model_type in COMPOSITE_MODELS:
trainable_layers.append(COMPOSITE_MODELS[model_type].projector_key)
trainable_layers.extend(COMPOSITE_MODELS[model_type].projector_keys)
forbidden_modules = get_forbidden_modules(model.config, finetuning_args)
for name, param in model.named_parameters():
@@ -188,12 +186,6 @@ def _setup_lora_tuning(
"token": model_args.hf_hub_token,
}
if model_args.use_kt:
if model_args.infer_backend != EngineName.KT:
raise ValueError(
"We should use ktransformers as backend to infer the adapter fine-tuned by ktransformers."
)
for adapter in adapter_to_merge:
model: LoraModel = PeftModel.from_pretrained(model, adapter, **init_kwargs)
model = model.merge_and_unload()
@@ -202,9 +194,7 @@ def _setup_lora_tuning(
logger.info_rank0(f"Merged {len(adapter_to_merge)} adapter(s).")
if adapter_to_resume is not None: # resume lora training
if model_args.use_kt:
model = load_kt_peft_model(model_args, model)
elif model_args.use_unsloth:
if model_args.use_unsloth:
model = load_unsloth_peft_model(config, model_args, finetuning_args, is_trainable=is_trainable)
else:
model = PeftModel.from_pretrained(model, adapter_to_resume, is_trainable=is_trainable, **init_kwargs)
@@ -217,16 +207,6 @@ def _setup_lora_tuning(
else:
target_modules = finetuning_args.lora_target
if model_args.use_kt:
new_list = []
for m in target_modules:
if m in ("down_proj", "up_proj", "gate_proj"):
new_list.extend([f"mlp.{m}", f"shared_experts.{m}"])
elif m not in ("generate_linear", "orig_module", "prefill_linear"):
new_list.append(m)
target_modules[:] = new_list
if finetuning_args.use_llama_pro:
target_modules = find_expanded_modules(model, target_modules, finetuning_args.freeze_trainable_layers)
@@ -270,19 +250,11 @@ def _setup_lora_tuning(
}
if model_args.use_kt:
if finetuning_args.finetuning_type == "oft":
raise ValueError("KTransformers is currently not supported for OFT.")
if finetuning_args.finetuning_type == "lora":
peft_config = LoraConfig(
task_type=TaskType.CAUSAL_LM,
inference_mode=False,
**peft_kwargs,
)
else:
raise ValueError("KTransformers is currently only supported for LoRA.")
if finetuning_args.finetuning_type != "lora":
raise ValueError("KTransformers only supports LoRA finetuning.")
model = get_kt_peft_model(model, peft_config)
print(f"KT_model:{model}")
peft_config = LoraConfig(task_type=TaskType.CAUSAL_LM, inference_mode=False, **peft_kwargs)
model = get_peft_model(model, peft_config)
elif model_args.use_unsloth:
if finetuning_args.finetuning_type == "oft":
raise ValueError("Unsloth is currently not supported for OFT.")

@@ -31,7 +31,6 @@ from ..extras import logging
from ..extras.misc import count_parameters, skip_check_imports, try_download_model_from_other_hub
from ..extras.packages import is_torch_version_greater_than
from .adapter import init_adapter
from .model_utils.ktransformers import load_kt_pretrained_model
from .model_utils.liger_kernel import apply_liger_kernel
from .model_utils.misc import register_autoclass
from .model_utils.mod import convert_pretrained_model_to_mod, load_mod_pretrained_model
@@ -144,12 +143,7 @@ def load_model(
model = None
lazy_load = False
if model_args.use_kt:
from ktransformers.sft.monkey_patch_torch_module import install_patch
install_patch()
model = load_kt_pretrained_model(config, model_args)
elif model_args.use_unsloth:
if model_args.use_unsloth:
if model_args.adapter_name_or_path is not None:
lazy_load = True
elif is_trainable:

@@ -1,154 +0,0 @@
# Copyright 2025 the KVCache.AI team, Approaching AI, and the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib.util as _u
from typing import TYPE_CHECKING, Any
import torch
from ...extras import logging
from ...extras.misc import get_current_device
if TYPE_CHECKING:
from ...hparams import FinetuningArguments, ModelArguments
from transformers import AutoConfig, AutoModelForCausalLM, PretrainedConfig, PreTrainedModel
KT_AVAILABLE = _u.find_spec("ktransformers") is not None
if KT_AVAILABLE:
from ktransformers.models.modeling_deepseek import DeepseekV2ForCausalLM
from ktransformers.models.modeling_deepseek_v3 import DeepseekV3ForCausalLM
from ktransformers.models.modeling_llama import LlamaForCausalLM
from ktransformers.models.modeling_mixtral import MixtralForCausalLM
from ktransformers.models.modeling_qwen2_moe import Qwen2MoeForCausalLM
from ktransformers.models.modeling_qwen3_moe import Qwen3MoeForCausalLM
from ktransformers.optimize.optimize import optimize_and_load_gguf
from ktransformers.server.config.config import Config
from ktransformers.sft.lora import inject_lora_layer
from ktransformers.util.custom_loader import GGUFLoader, SafeTensorLoader
from ktransformers.util.globals import GLOBAL_CONFIG
from ktransformers.util.utils import load_weights
logger = logging.get_logger(__name__)
def _get_kt_kwargs(
config: "PretrainedConfig",
model_name_or_path: str,
model_args: "ModelArguments",
finetuning_args: "FinetuningArguments",
) -> dict[str, Any]:
return {
"model_name": model_name_or_path,
"max_seq_length": model_args.model_max_length or 4096,
"dtype": model_args.compute_dtype,
"load_in_4bit": model_args.quantization_bit == 4,
"token": model_args.hf_hub_token,
"full_finetuning": finetuning_args.finetuning_type == "full",
"device_map": {"": get_current_device()},
"rope_scaling": getattr(config, "rope_scaling", None),
"fix_tokenizer": False,
"trust_remote_code": model_args.trust_remote_code,
"use_gradient_checkpointing": "ktransformers",
}
def load_kt_pretrained_model(config: "PretrainedConfig", model_args: "ModelArguments") -> "PreTrainedModel":
r"""Optionally load pretrained model with KTransformers. Used in training."""
custom_models = {
"DeepseekV2ForCausalLM": DeepseekV2ForCausalLM,
"DeepseekV3ForCausalLM": DeepseekV3ForCausalLM,
"Qwen2MoeForCausalLM": Qwen2MoeForCausalLM,
"Qwen3MoeForCausalLM": Qwen3MoeForCausalLM,
"LlamaForCausalLM": LlamaForCausalLM,
"MixtralForCausalLM": MixtralForCausalLM,
}
Config().cpu_infer = model_args.cpu_infer
Config().chunk_size = model_args.chunk_size
config = AutoConfig.from_pretrained(model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code)
if model_args.mode == "long_context":
assert config.architectures[0] == "LlamaForCausalLM", "only LlamaForCausalLM support long_context mode"
torch.set_default_dtype(torch.float16)
else:
torch.set_default_dtype(config.torch_dtype)
with torch.device("meta"):
if config.architectures[0] in custom_models:
print("using custom modeling_xxx.py.")
if "Qwen2Moe" in config.architectures[0]: # Qwen2Moe must use flash_attention_2 to avoid overflow.
config._attn_implementation = "flash_attention_2"
if "Llama" in config.architectures[0]:
config._attn_implementation = "eager"
if "Mixtral" in config.architectures[0]:
config._attn_implementation = "flash_attention_2"
model = custom_models[config.architectures[0]](config)
else:
attn_implementation = "flash_attention_2"
model = AutoModelForCausalLM.from_config(
config, trust_remote_code=True, attn_implementation=attn_implementation
)
optimize_config_path = model_args.kt_optimize_rule
gguf_path = model_args.model_name_or_path
assert optimize_config_path is not None, "optimize_config_path must be provided (path to YAML rules file)."
assert gguf_path is not None, "gguf_path must be provided (path to a folder or .gguf file)."
GLOBAL_CONFIG._config["mod"] = "infer"
optimize_and_load_gguf(model, optimize_config_path, gguf_path, config)
return model
def get_kt_peft_model(model: "PreTrainedModel", peft_kwargs: dict[str, Any]) -> "PreTrainedModel":
r"""Get the peft model for the pretrained model with KTransformers. Used in training."""
from ktransformers.sft.peft_utils.mapping import get_peft_model
return get_peft_model(model, peft_kwargs)
def load_kt_peft_model(model_args: "ModelArguments", model: "PreTrainedModel") -> "PreTrainedModel":
r"""Load peft model with KTransformers. Used in both training and inference."""
load_adapter_name_or_path = model_args.adapter_name_or_path[0]
if load_adapter_name_or_path.endswith(".gguf"):
inject_lora_layer(model, load_adapter_name_or_path)
adapter_gguf_loader = GGUFLoader(load_adapter_name_or_path)
load_weights(model, adapter_gguf_loader, adapter_gguf=True)
model.train()
else:
inject_lora_layer(model, load_adapter_name_or_path)
adapter_loader = SafeTensorLoader(load_adapter_name_or_path)
device = next(model.parameters()).device
for key in adapter_loader.tensor_file_map.keys():
try:
tensor = adapter_loader.load_tensor(key, device=device)
model_key = key.replace("base_model.model.", "")
model_key = model_key.replace(".weight", ".default.weight")
model_key = model_key.replace(".default.default.weight", ".default.weight")
param = model.get_parameter(model_key)
param.data.copy_(tensor.data)
print(f"Loaded adapter weight: {key} -> {model_key}")
except AttributeError:
print(f"Skipping {key}: not a model parameter")
except KeyError:
print(f"Key not found in model: {model_key} (original: {key})")
return model

@@ -45,7 +45,7 @@ def apply_liger_kernel(
from liger_kernel.transformers import apply_liger_kernel_to_gemma3 as apply_liger_kernel
elif model_type == "gemma3_text":
from liger_kernel.transformers import apply_liger_kernel_to_gemma3_text as apply_liger_kernel
elif model_type == "glm4":
elif model_type in ["glm", "glm4"]: # for glm4-9b, glm4-32B respectively
from liger_kernel.transformers import apply_liger_kernel_to_glm4 as apply_liger_kernel
elif model_type == "glm4v":
from liger_kernel.transformers import apply_liger_kernel_to_glm4v as apply_liger_kernel
@@ -79,6 +79,8 @@ def apply_liger_kernel(
from liger_kernel.transformers import apply_liger_kernel_to_qwen3_moe as apply_liger_kernel
elif model_type == "qwen3_next":
from liger_kernel.transformers import apply_liger_kernel_to_qwen3_next as apply_liger_kernel
elif model_type == "qwen3_5":
from liger_kernel.transformers import apply_liger_kernel_to_qwen3_5 as apply_liger_kernel
elif model_type == "gpt_oss":
try:
from liger_kernel.transformers import apply_liger_kernel_to_gpt_oss as apply_liger_kernel

@@ -35,7 +35,7 @@ def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool)
forbidden_modules.add("output")
if model_type in COMPOSITE_MODELS:
forbidden_modules.add(COMPOSITE_MODELS[model_type].projector_key)
forbidden_modules.update(COMPOSITE_MODELS[model_type].projector_keys)
if freeze_vision_tower and model_type in COMPOSITE_MODELS:
forbidden_modules.update(COMPOSITE_MODELS[model_type].vision_model_keys)

@@ -62,6 +62,10 @@ def add_z3_leaf_module(model: "PreTrainedModel") -> None:
# deepseek v3 and kimi vl use custom code
_set_z3_leaf_modules(model, ["DeepseekV3MoE"])
if model_type == "hy_v3":
# hy3 uses custom code
_set_z3_leaf_modules(model, ["HYV3MoE"])
if model_type == "ernie4_5_moe":
from transformers.models.ernie4_5_moe.modeling_ernie4_5_moe import Ernie4_5_MoeSparseMoeBlock
@@ -147,6 +151,11 @@ def add_z3_leaf_module(model: "PreTrainedModel") -> None:
_set_z3_leaf_modules(model, [Qwen3NextSparseMoeBlock])
if model_type == "qwen3_5_moe":
from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import Qwen3_5MoeSparseMoeBlock
_set_z3_leaf_modules(model, [Qwen3_5MoeSparseMoeBlock])
def configure_moe(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
if not is_trainable or not model_args.moe_aux_loss_coef:

@@ -37,7 +37,6 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
from typing import TYPE_CHECKING
import torch
import torch.nn.functional as F
@@ -45,10 +44,6 @@ import torch.nn.functional as F
from ...extras import logging
if TYPE_CHECKING:
from ...hparams import ModelArguments
logger = logging.get_logger(__name__)
@@ -105,13 +100,3 @@ def get_unpad_data(attention_mask: "torch.Tensor") -> tuple["torch.Tensor", "tor
max_seqlen_in_batch = seqlens_in_batch.max().item()
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
return indices, cu_seqlens, max_seqlen_in_batch
def configure_packing(model_args: "ModelArguments", is_trainable: bool) -> None:
if not is_trainable or not model_args.block_diag_attn:
return
import transformers.modeling_flash_attention_utils
transformers.modeling_flash_attention_utils._get_unpad_data = get_unpad_data
logger.info_rank0("Using block diagonal attention for sequence packing without cross-attention.")
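The retained tail of `get_unpad_data` builds cumulative sequence lengths: a running sum of the per-sequence lengths with a zero prepended. A plain-Python analogue, written without torch so it runs anywhere, is sketched below. The sequence lengths are made up and `cumulative_seqlens` is a hypothetical helper.

```python
from itertools import accumulate


def cumulative_seqlens(seqlens: list) -> list:
    """Plain-Python analogue of F.pad(torch.cumsum(seqlens, dim=0), (1, 0))."""
    return [0, *accumulate(seqlens)]


seqlens = [3, 2, 4]  # three sequences packed into one row of 9 tokens
cu_seqlens = cumulative_seqlens(seqlens)
max_seqlen_in_batch = max(seqlens)

# Sequence i occupies tokens cu_seqlens[i]:cu_seqlens[i + 1] of the packed row.
spans = list(zip(cu_seqlens[:-1], cu_seqlens[1:]))
```

Variable-length attention kernels typically consume these offsets to keep packed sequences from attending to each other.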

@@ -24,7 +24,6 @@ import transformers.models
from transformers.activations import ACT2FN
from ...extras import logging
from ...extras.packages import is_transformers_version_greater_than
if TYPE_CHECKING:
@@ -40,16 +39,27 @@ transformers_logger = transformers.utils.logging.get_logger(__name__)
@dataclass
class CompositeModel:
model_type: str
projector_key: str
projector_keys: list[str]
vision_model_keys: list[str]
language_model_keys: list[str]
lora_conflict_keys: list[str]
def get_projector(self, module: "torch.nn.Module") -> "torch.nn.Module":
for key in self.projector_key.split("."):
module = getattr(module, key)
def get_projectors(self, module: "torch.nn.Module") -> list["torch.nn.Module"]:
mm_projectors: list[torch.nn.Module] = []
for projector_key in self.projector_keys:
project_module = module
for key in projector_key.split("."):
project_module = getattr(project_module, key, None)
if project_module is None: # e.g., the larger gemma4 models have no embed_audio
logger.warning_rank0(
f"Projector key {projector_key} not found in module {module.__class__.__name__}."
)
break
return module
if project_module is not None:
mm_projectors.append(project_module)
return mm_projectors
COMPOSITE_MODELS: dict[str, "CompositeModel"] = {}
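The new `get_projectors` method resolves each dotted key attribute by attribute and skips keys that do not exist on the model. A self-contained sketch of that lookup is shown below, using `SimpleNamespace` objects as stand-ins for real modules. The helper names and the fake model are illustrative, and the warning that the diff logs is omitted.

```python
from types import SimpleNamespace
from typing import Any


def resolve_dotted(root: Any, dotted_key: str) -> Any:
    """Follow `a.b.c` attribute by attribute; return None as soon as one is missing."""
    current = root
    for key in dotted_key.split("."):
        current = getattr(current, key, None)
        if current is None:
            return None
    return current


def collect_projectors(root: Any, dotted_keys: list) -> list:
    """Resolve every key and keep only the ones that exist."""
    found = (resolve_dotted(root, key) for key in dotted_keys)
    return [module for module in found if module is not None]


# Stand-in model: it has model.embed_vision but no model.embed_audio.
fake_model = SimpleNamespace(model=SimpleNamespace(embed_vision="vision-projector"))
projectors = collect_projectors(fake_model, ["model.embed_vision", "model.embed_audio"])
```

Returning a list rather than a single module lets one registration cover models with several projectors, such as the vision and audio pair registered for `gemma4`.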
@@ -57,7 +67,7 @@ COMPOSITE_MODELS: dict[str, "CompositeModel"] = {}
def _register_composite_model(
model_type: str,
projector_key: Optional[str] = None,
projector_keys: list[str] | None = None,
vision_model_keys: Optional[list[str]] = None,
language_model_keys: Optional[list[str]] = None,
lora_conflict_keys: Optional[list[str]] = None,
@@ -66,7 +76,7 @@ def _register_composite_model(
Args:
model_type: model type
projector_key: multi_modal_projector
projector_keys: multi_modal_projector
vision_model_keys: vision_tower
language_model_keys: language_model
lora_conflict_keys: None
@@ -74,7 +84,7 @@ def _register_composite_model(
"""
COMPOSITE_MODELS[model_type] = CompositeModel(
model_type=model_type,
projector_key=projector_key or "multi_modal_projector",
projector_keys=projector_keys or ["multi_modal_projector"],
vision_model_keys=vision_model_keys or ["vision_tower"],
language_model_keys=language_model_keys or ["language_model", "lm_head"],
lora_conflict_keys=lora_conflict_keys or [],
@@ -137,11 +147,15 @@ def autocast_projector_dtype(model: "PreTrainedModel", model_args: "ModelArgumen
if getattr(model, "quantization_method", None):
model_type = getattr(model.config, "model_type", None)
if model_type in COMPOSITE_MODELS:
mm_projector = COMPOSITE_MODELS[model_type].get_projector(model)
mm_projectors = COMPOSITE_MODELS[model_type].get_projectors(model)
else:
return
logger.info_rank0(f"Casting multimodal projector outputs in {model_args.compute_dtype}.")
logger.info_rank0(
f"Casting multimodal projector outputs in {model_args.compute_dtype}: "
f"{COMPOSITE_MODELS[model_type].projector_keys}."
)
for mm_projector in mm_projectors:
mm_projector.register_forward_hook(_mm_projector_forward_post_hook)
@@ -167,9 +181,9 @@ def get_forbidden_modules(config: "PretrainedConfig", finetuning_args: "Finetuni
forbidden_modules.update(vision_model_keys)
if finetuning_args.freeze_multi_modal_projector:
projector_key = COMPOSITE_MODELS[model_type].projector_key
logger.info_rank0(f"Set multi model projector not trainable: {projector_key}.")
forbidden_modules.add(projector_key)
projector_keys = COMPOSITE_MODELS[model_type].projector_keys
logger.info_rank0(f"Set multi model projector not trainable: {projector_keys}.")
forbidden_modules.update(projector_keys)
if finetuning_args.freeze_language_model:
language_model_keys = COMPOSITE_MODELS[model_type].language_model_keys
@@ -201,7 +215,7 @@ def patch_target_modules(
_register_composite_model(
model_type="dots_ocr",
projector_key="vision_tower.merger",
projector_keys=["vision_tower.merger"],
vision_model_keys=["vision_tower"],
language_model_keys=["model", "lm_head"],
lora_conflict_keys=["merger"],
@@ -220,10 +234,18 @@ _register_composite_model(
)
_register_composite_model(
model_type="gemma4",
projector_keys=["model.embed_vision", "model.embed_audio"],
vision_model_keys=["vision_tower", "audio_tower"],
lora_conflict_keys=["per_layer_projection_norm"],
)
# copied from qwen2vl
_register_composite_model(
model_type="glm4v",
projector_key="visual.merger",
projector_keys=["visual.merger"],
vision_model_keys=["visual.patch_embed", "visual.blocks"],
language_model_keys=["language_model", "lm_head"],
lora_conflict_keys=["patch_embed"],
@@ -232,7 +254,7 @@ _register_composite_model(
_register_composite_model(
model_type="glm4v_moe",
projector_key="visual.merger",
projector_keys=["visual.merger"],
vision_model_keys=["visual.patch_embed", "visual.blocks"],
language_model_keys=["language_model", "lm_head"],
lora_conflict_keys=["patch_embed"],
@@ -241,7 +263,7 @@ _register_composite_model(
_register_composite_model(
model_type="glm_ocr",
projector_key="visual.merger",
projector_keys=["visual.merger"],
vision_model_keys=["visual.patch_embed", "visual.blocks"],
language_model_keys=["language_model", "lm_head"],
lora_conflict_keys=["patch_embed"],
@@ -258,7 +280,7 @@ _register_composite_model(
_register_composite_model(
model_type="Keye",
projector_key="mlp_AR",
projector_keys=["mlp_AR"],
vision_model_keys=["visual.vision_model.patch_embedding", "visual.vision_model.encoder"],
language_model_keys=["model", "lm_head"],
lora_conflict_keys=["patch_embedding"],
@@ -293,15 +315,23 @@ _register_composite_model(
_register_composite_model(
model_type="minicpmv",
projector_key="resampler",
projector_keys=["resampler"],
vision_model_keys=["vpm"],
language_model_keys=["llm"],
)
_register_composite_model(
model_type="minicpmv4_6",
projector_keys=["model.merger"],
vision_model_keys=["model.vision_tower"],
language_model_keys=["model.language_model", "lm_head"],
)
_register_composite_model(
model_type="minicpmo",
projector_key="resampler",
projector_keys=["resampler"],
vision_model_keys=["vpm", "apm", "audio_avg_pooler", "audio_projection_layer", "tts"],
language_model_keys=["llm"],
lora_conflict_keys=["audio_projection_layer"],
@@ -310,7 +340,7 @@ _register_composite_model(
_register_composite_model(
model_type="mistral3",
projector_key="model.multi_modal_projector",
projector_keys=["model.multi_modal_projector"],
)
@@ -333,7 +363,7 @@ _register_composite_model(
_register_composite_model(
model_type="qwen2_5_omni_thinker",
projector_key="visual.merger",
projector_keys=["visual.merger", "audio_tower.proj"],
vision_model_keys=["visual.patch_embed", "visual.blocks", "audio_tower"],
language_model_keys=["model", "lm_head"],
lora_conflict_keys=["patch_embed"],
@@ -342,29 +372,25 @@ _register_composite_model(
_register_composite_model(
model_type="qwen2_vl",
projector_key="visual.merger",
projector_keys=["visual.merger"],
vision_model_keys=["visual.patch_embed", "visual.blocks"],
language_model_keys=["language_model", "lm_head"]
if is_transformers_version_greater_than("4.52.0")
else ["model", "lm_head"],
language_model_keys=["language_model", "lm_head"],
lora_conflict_keys=["patch_embed"],
)
_register_composite_model(
model_type="qwen2_5_vl",
projector_key="visual.merger",
projector_keys=["visual.merger"],
vision_model_keys=["visual.patch_embed", "visual.blocks"],
language_model_keys=["language_model", "lm_head"]
if is_transformers_version_greater_than("4.52.0")
else ["model", "lm_head"],
language_model_keys=["language_model", "lm_head"],
lora_conflict_keys=["patch_embed"],
)
_register_composite_model(
model_type="qwen3_vl",
projector_key="visual.merger",
projector_keys=["visual.merger"],
vision_model_keys=["visual.pos_embed", "visual.patch_embed", "visual.blocks", "visual.deepstack_merger_list"],
language_model_keys=["language_model", "lm_head"],
lora_conflict_keys=["patch_embed"],
@@ -373,7 +399,7 @@ _register_composite_model(
_register_composite_model(
model_type="qwen3_vl_moe",
projector_key="visual.merger",
projector_keys=["visual.merger"],
vision_model_keys=["visual.pos_embed", "visual.patch_embed", "visual.blocks", "visual.deepstack_merger_list"],
language_model_keys=["language_model", "lm_head"],
lora_conflict_keys=["patch_embed"],
@@ -382,7 +408,7 @@ _register_composite_model(
_register_composite_model(
model_type="qwen3_omni_moe_thinker",
projector_key="visual.merger",
projector_keys=["visual.merger", "audio_tower.proj"],
vision_model_keys=[
"visual.pos_embed",
"visual.patch_embed",
@@ -397,7 +423,7 @@ _register_composite_model(
_register_composite_model(
model_type="qwen3_5",
projector_key="model.visual.merger",
projector_keys=["model.visual.merger"],
vision_model_keys=["visual.pos_embed", "visual.patch_embed", "visual.blocks"],
language_model_keys=["language_model", "lm_head"],
lora_conflict_keys=["patch_embed"],
@@ -406,7 +432,7 @@ _register_composite_model(
_register_composite_model(
model_type="qwen3_5_moe",
projector_key="model.visual.merger",
projector_keys=["model.visual.merger"],
vision_model_keys=["visual.pos_embed", "visual.patch_embed", "visual.blocks"],
language_model_keys=["language_model", "lm_head"],
lora_conflict_keys=["patch_embed"],


@@ -30,7 +30,6 @@ from .model_utils.embedding import resize_embedding_layer
from .model_utils.kv_cache import configure_kv_cache
from .model_utils.longlora import configure_longlora
from .model_utils.moe import add_z3_leaf_module, configure_moe
from .model_utils.packing import configure_packing
from .model_utils.quantization import configure_quantization
from .model_utils.rope import configure_rope
from .model_utils.valuehead import prepare_valuehead_model
@@ -61,6 +60,195 @@ def patch_qwen3_omni_moe_thinker_text_sparse_moe_block():
modeling_qwen3_omni_moe.Qwen3OmniMoeThinkerTextSparseMoeBlock = Qwen3OmniMoeThinkerTextSparseMoeBlock
def _check_fla_dependencies() -> None:
"""Check that the FLA dependencies required for varlen GDN forwarding are available.
Requires ``flash-linear-attention >= 0.4.1`` (which exposes the varlen
``causal_conv1d`` under ``fla.modules.convolution`` and the
``chunk_gated_delta_rule`` / ``fused_recurrent_gated_delta_rule`` kernels
under ``fla.ops.gated_delta_rule``). Raises ``ImportError`` with an
actionable message otherwise.
"""
try:
from fla.modules.convolution import causal_conv1d # noqa: F401
from fla.ops.gated_delta_rule import ( # noqa: F401
chunk_gated_delta_rule,
fused_recurrent_gated_delta_rule,
)
except ImportError as exc:
raise ImportError(
"Qwen3.5 packing-seq forwarding requires `flash-linear-attention>=0.4.1` "
"(provides `fla.modules.convolution.causal_conv1d` and "
"`fla.ops.gated_delta_rule.{chunk,fused_recurrent}_gated_delta_rule`). "
"Please install/upgrade it."
) from exc
def patch_qwen3_5_forward(model: "PreTrainedModel") -> None:
"""Patch the Qwen3.5 decoder-layer and GatedDeltaNet forward methods to support cu_seqlens input. Intended for training only.
Refer to: https://github.com/axolotl-ai-cloud/axolotl/blob/main/src/axolotl/monkeypatch/models/qwen3_5/modeling.py.
"""
if is_transformers_version_greater_than("5.2.0"):
from transformers.models.qwen3_5.modeling_qwen3_5 import apply_mask_to_padding_states
from torch.nn import functional as F
from transformers.modeling_flash_attention_utils import prepare_fa_kwargs_from_position_ids
_check_fla_dependencies()
from fla.modules.convolution import causal_conv1d as fla_causal_conv1d
from fla.ops.gated_delta_rule import chunk_gated_delta_rule
def _patched_decoder_forward(
self,
hidden_states: torch.Tensor,
position_embeddings: tuple[torch.Tensor, torch.Tensor],
attention_mask: torch.Tensor | None = None,
position_ids: torch.LongTensor | None = None,
past_key_values=None,
cache_position: torch.LongTensor | None = None,
**kwargs,
) -> torch.FloatTensor:
"""Decoder layer forward that passes position_ids through to linear attention."""
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
if self.layer_type == "linear_attention":
hidden_states = self.linear_attn(
hidden_states=hidden_states,
cache_params=past_key_values,
cache_position=cache_position,
attention_mask=attention_mask,
position_ids=position_ids, # passing position_ids to linear attention
)
elif self.layer_type == "full_attention":
hidden_states, _ = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids[None, 0], # keep [1, B, L]
past_key_values=past_key_values,
cache_position=cache_position,
position_embeddings=position_embeddings,
**kwargs,
)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
if isinstance(hidden_states, tuple): # MoE returns (hidden_states, router_logits)
hidden_states, _ = hidden_states
hidden_states = residual + hidden_states
return hidden_states
# gdn forward (training only, cache_params is always None)
def _patch_gdn_forward(
self,
hidden_states: torch.Tensor,
cache_params=None,
cache_position: torch.LongTensor | None = None,
attention_mask: torch.Tensor | None = None,
position_ids: torch.LongTensor | None = None,
):
# @kuangdd fix: here attention_mask is None
hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask)
batch_size, seq_len, _ = hidden_states.shape
# Qwen3.5 VL passes 3-D MRoPE position_ids ([axes, B, T]); collapse to [B, T].
if position_ids is not None and position_ids.ndim == 3:
position_ids = position_ids[0]
# cu_seqlens for the FLA varlen path is only needed when batch_size == 1:
# packing / neat-packing: always folded into a single sequence (bsz == 1) -> varlen
# non-packing, bsz == 1: single segment, equivalent to a standard single sequence
# non-packing, bsz > 1: not packed, use cu_seqlens=None and standard batched kernels
if position_ids is not None and batch_size == 1:
cu_seqlens = prepare_fa_kwargs_from_position_ids(position_ids)[0][0]
else:
cu_seqlens = None
# FLA varlen kernels expect [B, T, D] layout, not [B, D, T] like the
# standard causal-conv1d path that the upstream forward uses.
mixed_qkv = self.in_proj_qkv(hidden_states)
z = self.in_proj_z(hidden_states)
z = z.reshape(batch_size, seq_len, -1, self.head_v_dim)
b = self.in_proj_b(hidden_states)
a = self.in_proj_a(hidden_states)
# FLA's causal_conv1d returns (out, final_state); we don't use the state here.
mixed_qkv, _ = fla_causal_conv1d(
x=mixed_qkv,
weight=self.conv1d.weight.squeeze(1),
bias=self.conv1d.bias,
activation=self.activation,
cu_seqlens=cu_seqlens,
)
query, key, value = torch.split(
mixed_qkv,
[
self.key_dim,
self.key_dim,
self.value_dim,
],
dim=-1,
)
query = query.reshape(batch_size, seq_len, -1, self.head_k_dim)
key = key.reshape(batch_size, seq_len, -1, self.head_k_dim)
value = value.reshape(batch_size, seq_len, -1, self.head_v_dim)
beta = b.sigmoid()
# If the model is loaded in fp16, without the .float() here, A might be -inf
g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias)
if self.num_v_heads // self.num_k_heads > 1:
query = query.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2)
key = key.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2)
core_attn_out, _ = chunk_gated_delta_rule(
query,
key,
value,
g=g,
beta=beta,
initial_state=None,
output_final_state=False,
use_qk_l2norm_in_kernel=True,
**({"cu_seqlens": cu_seqlens} if cu_seqlens is not None else {}),
)
core_attn_out = core_attn_out.reshape(-1, self.head_v_dim)
z = z.reshape(-1, self.head_v_dim)
core_attn_out = self.norm(core_attn_out, z)
core_attn_out = core_attn_out.reshape(batch_size, seq_len, -1)
output = self.out_proj(core_attn_out)
return output
if model.config.architectures[0] == "Qwen3_5ForConditionalGeneration":
from transformers.models.qwen3_5.modeling_qwen3_5 import Qwen3_5DecoderLayer, Qwen3_5GatedDeltaNet
Qwen3_5DecoderLayer.forward = _patched_decoder_forward
Qwen3_5GatedDeltaNet.forward = _patch_gdn_forward
elif model.config.architectures[0] == "Qwen3_5MoeForConditionalGeneration":
from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import (
Qwen3_5MoeDecoderLayer,
Qwen3_5MoeGatedDeltaNet,
)
Qwen3_5MoeDecoderLayer.forward = _patched_decoder_forward
Qwen3_5MoeGatedDeltaNet.forward = _patch_gdn_forward
logger.info_rank0("Patched Qwen3.5 decoder forward to support cu_seqlens input (training only).")
def patch_youtu_vl_model(model: "PreTrainedModel") -> None:
original_forward = model.forward
@@ -142,7 +330,6 @@ def patch_config(
configure_quantization(config, tokenizer, model_args, is_trainable, init_kwargs)
configure_moe(config, model_args, is_trainable)
configure_visual_model(config)
configure_packing(model_args, is_trainable)
configure_kv_cache(config, model_args, is_trainable)
if getattr(config, "model_type", None) == "qwen":
@@ -234,6 +421,9 @@ def patch_model(
autocast_projector_dtype(model, model_args)
add_z3_leaf_module(model)
if getattr(model.config, "model_type", None) in ["qwen3_5", "qwen3_5_moe"] and model_args.flash_attn == "fa2":
patch_qwen3_5_forward(model)
if not model_args.use_unsloth:
print_attn_implementation(model.config)
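
The patched GDN forward above derives `cu_seqlens` from `position_ids` when the batch size is 1. The idea can be sketched in pure Python, assuming each packed segment restarts its position ids at 0. `cu_seqlens_from_position_ids` is a hypothetical helper written for illustration; it is not the `prepare_fa_kwargs_from_position_ids` function from transformers, which operates on tensors and returns more than the boundaries.

```python
def cu_seqlens_from_position_ids(position_ids: list[int]) -> list[int]:
    """Return cumulative sequence boundaries for one packed row.

    Assumption (not taken from the diff): every packed segment starts
    with position id 0, so each 0 marks the start of a new segment.
    """
    starts = [i for i, pos in enumerate(position_ids) if pos == 0]
    # Close the last segment with the total length of the row.
    return starts + [len(position_ids)]


# Three sequences of lengths 3, 2 and 4 packed into a single row.
packed = [0, 1, 2, 0, 1, 0, 1, 2, 3]
boundaries = cu_seqlens_from_position_ids(packed)
lengths = [b - a for a, b in zip(boundaries, boundaries[1:])]
```

With a single unpacked sequence the result has only two entries (start and end), which matches the "single segment" case described in the comments of the patched forward.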


@@ -12,11 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import fnmatch
import json
import os
import signal
import sys
import time
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
from datetime import timedelta
from typing import TYPE_CHECKING, Any, Optional
@@ -31,7 +33,7 @@ from typing_extensions import override
from ..extras import logging
from ..extras.constants import TRAINER_LOG, V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
from ..extras.misc import get_peak_memory, is_env_enabled, use_ray
from ..extras.misc import get_peak_memory, is_env_enabled, is_torch_cuda_available, is_torch_npu_available, use_ray
from ..extras.packages import is_safetensors_available
@@ -338,6 +340,96 @@ class LogCallback(TrainerCallback):
self.thread_pool.submit(self._write_log, args.output_dir, logs)
class TorchProfilerCallback(TrainerCallback):
r"""A callback for collecting torch.profiler traces during training.
Activated by setting ``enable_torch_profiler: true`` in the YAML config.
Configuration fields (in YAML):
profiler_output_dir       where to write traces (default: <output_dir>/profiler)
profiler_wait_steps       steps to skip at start of each cycle (default: 1)
profiler_warmup_steps     profiler warm-up steps per cycle (default: 1)
profiler_active_steps     steps to record per cycle (default: 1)
profiler_repeat           number of cycles; 0 = forever (default: 1)
profiler_record_shapes    record tensor shapes (default: true)
profiler_profile_memory   profile memory usage (default: true)
profiler_with_stack       record stack traces (default: true)
Trace files (one per rank, Chrome / TensorBoard JSON format) are written to
``<profiler_output_dir>/rank_<N>/``.
"""
def __init__(self, training_args: "TrainingArguments") -> None:
self.profiler = None
self.profiler_args = training_args
@staticmethod
def _get_rank() -> int:
import torch.distributed as dist
if dist.is_available() and dist.is_initialized():
return dist.get_rank()
return 0
@override
def on_train_begin(
self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs
) -> None:
if self.profiler is not None:
self.profiler.stop()
self.profiler = None
pa = self.profiler_args
output_dir = pa.profiler_output_dir or os.path.join(args.output_dir, "profiler")
rank = self._get_rank()
trace_dir = os.path.join(output_dir, f"rank_{rank}")
os.makedirs(trace_dir, exist_ok=True)
activities = [torch.profiler.ProfilerActivity.CPU]
try:
if is_torch_cuda_available():
activities.append(torch.profiler.ProfilerActivity.CUDA)
if is_torch_npu_available():
activities.append(torch.profiler.ProfilerActivity.NPU)
except Exception:
pass
self.profiler = torch.profiler.profile(
activities=activities,
schedule=torch.profiler.schedule(
wait=pa.profiler_wait_steps,
warmup=pa.profiler_warmup_steps,
active=pa.profiler_active_steps,
repeat=pa.profiler_repeat,
),
on_trace_ready=torch.profiler.tensorboard_trace_handler(trace_dir),
record_shapes=pa.profiler_record_shapes,
profile_memory=pa.profiler_profile_memory,
with_stack=pa.profiler_with_stack,
)
self.profiler.start()
logger.info_rank0(
f"TorchProfiler started — schedule: wait={pa.profiler_wait_steps}, warmup={pa.profiler_warmup_steps}, "
f"active={pa.profiler_active_steps}, repeat={pa.profiler_repeat}. Traces → {output_dir}"
)
@override
def on_step_end(
self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs
) -> None:
if self.profiler is not None:
self.profiler.step()
@override
def on_train_end(
self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs
) -> None:
if self.profiler is not None:
self.profiler.stop()
self.profiler = None
logger.info_rank0("TorchProfiler stopped.")
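
The wait / warmup / active / repeat fields of `TorchProfilerCallback` map onto `torch.profiler.schedule`. The sketch below is a pure-Python approximation of how that schedule assigns an action to each profiler step, written from the documented behaviour rather than copied from PyTorch. `profiler_action` and the string action names are hypothetical stand-ins for the real `ProfilerAction` enum, and the sketch assumes `active >= 1`.

```python
def profiler_action(step: int, wait: int, warmup: int, active: int, repeat: int = 0) -> str:
    """Approximate the per-step action of a wait/warmup/active profiler schedule."""
    cycle = wait + warmup + active
    # After `repeat` full cycles the profiler stays idle; repeat == 0 means forever.
    if repeat > 0 and step >= repeat * cycle:
        return "none"
    pos = step % cycle
    if pos < wait:
        return "none"
    if pos < wait + warmup:
        return "warmup"
    # The last active step of a cycle also triggers the trace handler.
    return "record_and_save" if pos == cycle - 1 else "record"


# Defaults from the docstring: wait=1, warmup=1, active=1, repeat=1.
default_plan = [profiler_action(s, 1, 1, 1, 1) for s in range(5)]
```

Under the default settings only one step per run is recorded, so raising `profiler_active_steps` is the usual way to capture a longer window.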
class ReporterCallback(TrainerCallback):
r"""A callback for reporting training status to external logger."""
@@ -394,3 +486,143 @@ class ReporterCallback(TrainerCallback):
"generating_args": self.generating_args.to_dict(),
}
)
class ModuleProfilerCallback(TrainerCallback):
r"""Profile forward/backward time of specified modules using accelerator events.
Hooks are registered on modules matching the user-provided name patterns.
Timing statistics are logged at each trainer logging step.
Usage in YAML config:
profile_modules: "*.layers.0.self_attn,*.layers.0.mlp"
Supports fnmatch wildcards:
profile_modules: "*.layers.*.self_attn,*.layers.*.mlp.experts"
"""
@staticmethod
def _get_accelerator():
"""Detect available accelerator and return (event_factory, synchronize_fn)."""
if is_torch_cuda_available():
return torch.cuda.Event, torch.cuda.synchronize
if is_torch_npu_available():
return torch.npu.Event, torch.npu.synchronize
return None, None
def __init__(self, profile_modules: str) -> None:
self.patterns = [p.strip() for p in profile_modules.split(",") if p.strip()]
self._create_event, self._synchronize = self._get_accelerator()
self._handles: list[Any] = []
self._forward_times: dict[str, list[float]] = defaultdict(list)
self._backward_times: dict[str, list[float]] = defaultdict(list)
self._pending_forward: dict[str, tuple] = {}
self._pending_backward: dict[str, tuple] = {}
@property
def enabled(self) -> bool:
return self._create_event is not None
def _match(self, name: str) -> bool:
return any(fnmatch.fnmatch(name, pat) for pat in self.patterns)
def _make_forward_pre_hook(self, name: str):
def hook(module, input):
start = self._create_event(enable_timing=True)
end = self._create_event(enable_timing=True)
start.record()
self._pending_forward[name] = (start, end)
return hook
def _make_forward_hook(self, name: str):
def hook(module, input, output):
pair = self._pending_forward.get(name)
if pair is not None:
pair[1].record()
return hook
def _make_backward_pre_hook(self, name: str):
def hook(module, grad_output):
start = self._create_event(enable_timing=True)
end = self._create_event(enable_timing=True)
start.record()
self._pending_backward[name] = (start, end)
return hook
def _make_backward_hook(self, name: str):
def hook(module, grad_input, grad_output):
pair = self._pending_backward.get(name)
if pair is not None:
pair[1].record()
return hook
@override
def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
if not self.enabled:
logger.warning_rank0("ModuleProfiler: no supported accelerator (CUDA/NPU) found, profiling disabled.")
return
model = kwargs.get("model")
if model is None:
return
matched = []
for name, module in model.named_modules():
if not name or not self._match(name):
continue
self._handles.append(module.register_forward_pre_hook(self._make_forward_pre_hook(name)))
self._handles.append(module.register_forward_hook(self._make_forward_hook(name)))
self._handles.append(module.register_full_backward_pre_hook(self._make_backward_pre_hook(name)))
self._handles.append(module.register_full_backward_hook(self._make_backward_hook(name)))
matched.append(name)
if matched:
logger.info_rank0(
f"ModuleProfiler: registered hooks on {len(matched)} modules: {matched[:5]}"
+ (f" ... (+{len(matched) - 5} more)" if len(matched) > 5 else "")
)
else:
logger.warning_rank0(f"ModuleProfiler: no modules matched patterns {self.patterns}")
@override
def on_step_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
if not self.enabled:
return
self._synchronize()
for name, (start, end) in self._pending_forward.items():
self._forward_times[name].append(start.elapsed_time(end))
self._pending_forward.clear()
for name, (start, end) in self._pending_backward.items():
self._backward_times[name].append(start.elapsed_time(end))
self._pending_backward.clear()
@override
def on_log(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
if not self._forward_times and not self._backward_times:
return
lines = ["[ModuleProfiler] Timing (ms):"]
all_names = sorted(set(list(self._forward_times.keys()) + list(self._backward_times.keys())))
for name in all_names:
fwd = self._forward_times.get(name, [])
bwd = self._backward_times.get(name, [])
fwd_mean = sum(fwd) / len(fwd) if fwd else 0.0
bwd_mean = sum(bwd) / len(bwd) if bwd else 0.0
lines.append(f" {name}: fwd={fwd_mean:.3f}, bwd={bwd_mean:.3f}, total={fwd_mean + bwd_mean:.3f}")
logger.info_rank0("\n".join(lines))
self._forward_times.clear()
self._backward_times.clear()
@override
def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
for handle in self._handles:
handle.remove()
self._handles.clear()
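
`ModuleProfilerCallback` selects modules by matching their dotted names against comma-separated `fnmatch` patterns. The sketch below reproduces only that selection step with the standard library. `match_modules` and the module names are made up for illustration and do not come from a real model.

```python
import fnmatch


def match_modules(names: list[str], profile_modules: str) -> list[str]:
    """Return the module names that match any comma-separated fnmatch pattern."""
    patterns = [p.strip() for p in profile_modules.split(",") if p.strip()]
    # Empty names (the root module) are skipped, as in the callback.
    return [n for n in names if n and any(fnmatch.fnmatch(n, p) for p in patterns)]


# Hypothetical names in the style produced by `model.named_modules()`.
names = [
    "",
    "model.layers.0.self_attn",
    "model.layers.0.self_attn.q_proj",
    "model.layers.1.mlp",
    "model.layers.10.mlp",
    "lm_head",
]
selected = match_modules(names, "*.layers.0.self_attn, *.layers.*.mlp")
```

Because `fnmatch` must match the whole name, `*.layers.0.self_attn` does not select the `q_proj` child. Note that `*` also matches across dots, so a broad pattern can select more modules than expected.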


@@ -1,62 +0,0 @@
# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's TRL library.
# https://github.com/huggingface/trl/blob/v0.8.0/trl/trainer/dpo_trainer.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
import torch
from ktransformers.sft.lora import KTrainer # type: ignore
from typing_extensions import override
from ..trainer_utils import get_batch_logps, nested_detach
from .trainer import CustomDPOTrainer
if TYPE_CHECKING:
from transformers import PreTrainedModel
class KDPOTrainer(KTrainer, CustomDPOTrainer):
@override
def concatenated_forward(
self, model: "PreTrainedModel", batch: dict[str, "torch.Tensor"], is_ref_model: bool = False
) -> tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
r"""Compute the sum log probabilities of the labels under given logits if loss_type is not IPO, ORPO or SimPO.
Otherwise the average log probabilities.
"""
if self.finetuning_args.use_ref_model:
batch = nested_detach(batch, clone=True) # avoid error
labels = batch.pop("labels") # dpo do not need compute loss in forward
all_logits: torch.Tensor = model(**batch, return_dict=True, use_cache=False).logits.to(torch.float32)
all_logits = all_logits.to("cpu")
labels = labels.to(all_logits.device)
all_logps, valid_length = get_batch_logps(
logits=all_logits, labels=labels, ld_alpha=(self.ld_alpha if not is_ref_model else None)
)
if self.loss_type in ["ipo", "orpo", "simpo"]:
all_logps = all_logps / valid_length
batch_size = batch["input_ids"].size(0) // 2
chosen_logps, rejected_logps = all_logps.split(batch_size, dim=0)
chosen_logits, rejected_logits = all_logits.split(batch_size, dim=0)
chosen_length, _ = valid_length.split(batch_size, dim=0)
if self.loss_type in ["ipo", "orpo", "simpo"]:
return chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_logps
else:
return chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_logps / chosen_length


@@ -123,10 +123,10 @@ class CustomDPOTrainer(DPOTrainer):
self.running = RunningMoments(self.accelerator)
@override
def create_optimizer(self) -> "torch.optim.Optimizer":
def create_optimizer(self, *args, **kwargs) -> "torch.optim.Optimizer":
if self.optimizer is None:
self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)
return super().create_optimizer()
return super().create_optimizer(*args, **kwargs)
@override
def create_scheduler(
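
The `create_optimizer` change above switches the override to accept and forward `*args, **kwargs`. The diff does not state the motivation; a plausible reading, treated here as an assumption, is that a parent class may call or define `create_optimizer` with extra arguments. The toy classes below are hypothetical and only illustrate the pattern.

```python
class BaseTrainer:
    # Hypothetical parent whose signature accepts an optional argument.
    def create_optimizer(self, model=None):
        return ("base-optimizer", model)


class FixedSignatureTrainer(BaseTrainer):
    # Old style: the override accepts no extra arguments.
    def create_optimizer(self):
        return super().create_optimizer()


class ForwardingTrainer(BaseTrainer):
    # New style: whatever the caller passes is forwarded to the parent.
    def create_optimizer(self, *args, **kwargs):
        return super().create_optimizer(*args, **kwargs)


forwarded = ForwardingTrainer().create_optimizer(model="m")

try:
    FixedSignatureTrainer().create_optimizer(model="m")
    fixed_signature_failed = False
except TypeError:
    fixed_signature_failed = True
```

The forwarding override keeps working whether or not the parent is called with arguments, at the cost of a less explicit signature.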


@@ -62,14 +62,6 @@ def run_dpo(
else:
ref_model = None
if model_args.use_kt:
from ktransformers.util.globals import GLOBAL_CONFIG # type: ignore
from .ktrainer import KDPOTrainer as CustomDPOTrainer
GLOBAL_CONFIG._config["mod"] = "sft"
else:
from .trainer import CustomDPOTrainer
# Initialize our Trainer


@@ -0,0 +1,18 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .workflow import run_sft
__all__ = ["run_sft"]


@@ -0,0 +1,181 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Optional
from ...data import SFTDataCollatorWith4DAttentionMask, get_dataset, get_template_and_fix_tokenizer
from ...extras.constants import IGNORE_INDEX
from ...extras.logging import get_logger
from ...extras.misc import calculate_tps
from ...extras.packages import is_hyper_parallel_available, is_transformers_version_greater_than
from ...extras.ploting import plot_loss
from ...model import load_model, load_tokenizer
from ..callbacks import SaveProcessorCallback
from ..sft.metric import ComputeAccuracy, ComputeSimilarity, eval_logit_processor
from ..trainer_utils import asft_loss_func, create_modelcard_and_push, create_ref_model, dft_loss_func, eaft_loss_func
if TYPE_CHECKING:
from transformers import Seq2SeqTrainingArguments, TrainerCallback
from ...hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
logger = get_logger(__name__)
def run_sft(
model_args: "ModelArguments",
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments",
generating_args: "GeneratingArguments",
callbacks: Optional[list["TrainerCallback"]] = None,
):
if not is_hyper_parallel_available():
raise ImportError("hyper_parallel is not installed. Please install it with `pip install hyper_parallel`.")
from hyper_parallel.integration.llamafactory import ( # pylint: disable=C0415
HyperParallelArguments,
HyperParallelTrainer,
)
tokenizer_module = load_tokenizer(model_args)
tokenizer = tokenizer_module["tokenizer"]
template = get_template_and_fix_tokenizer(tokenizer, data_args)
dataset_module = get_dataset(template, model_args, data_args, training_args, stage="sft", **tokenizer_module)
model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
ref_model = None
if finetuning_args.use_asft_loss:
ref_model = create_ref_model(model_args, finetuning_args)
data_collator = SFTDataCollatorWith4DAttentionMask(
template=template,
model=model if not training_args.predict_with_generate else None,
pad_to_multiple_of=8 if training_args.do_train else None,
label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id,
block_diag_attn=model_args.block_diag_attn,
attn_implementation=getattr(model.config, "_attn_implementation", None),
compute_dtype=model_args.compute_dtype,
**tokenizer_module,
)
# Metric utils
metric_module = {}
if training_args.predict_with_generate:
metric_module["compute_metrics"] = ComputeSimilarity(tokenizer=tokenizer)
elif finetuning_args.compute_accuracy:
metric_module["compute_metrics"] = ComputeAccuracy()
metric_module["preprocess_logits_for_metrics"] = eval_logit_processor
# Keyword arguments for `model.generate`
gen_kwargs = generating_args.to_dict(obey_generation_config=True)
if is_transformers_version_greater_than("4.58.0"):
extra_ids = getattr(tokenizer, "additional_special_tokens_ids", None)
if not isinstance(extra_ids, list):
extra_special_tokens = getattr(tokenizer, "_extra_special_tokens", [])
string_tokens = [str(t) for t in extra_special_tokens]
extra_ids = tokenizer.convert_tokens_to_ids(string_tokens)
all_eos_ids = [tokenizer.eos_token_id] + [i for i in extra_ids if i != -1]
gen_kwargs["eos_token_id"] = list(dict.fromkeys(all_eos_ids))
else:
gen_kwargs["eos_token_id"] = [tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids
gen_kwargs["pad_token_id"] = tokenizer.pad_token_id
hp_args = HyperParallelArguments.from_finetuning_args(finetuning_args)
callbacks = list(callbacks or [])
processor = tokenizer_module.get("processor")
if processor is not None:
callbacks.append(SaveProcessorCallback(processor))
compute_loss_func = None
if finetuning_args.use_dft_loss:
compute_loss_func = dft_loss_func
elif finetuning_args.use_eaft_loss:
compute_loss_func = lambda outputs, labels, num_items_in_batch=None: eaft_loss_func( # noqa: E731
outputs, labels, num_items_in_batch, finetuning_args.eaft_alpha
)
elif finetuning_args.use_asft_loss:
from functools import partial
compute_loss_func = partial(asft_loss_func, asft_alpha=finetuning_args.asft_alpha)
trainer = HyperParallelTrainer(
hp_args=hp_args,
model=model,
args=training_args,
finetuning_args=finetuning_args,
data_collator=data_collator,
callbacks=callbacks,
gen_kwargs=gen_kwargs,
ref_model=ref_model,
compute_loss_func=compute_loss_func,
**dataset_module,
**tokenizer_module,
**metric_module,
)
if finetuning_args.use_badam:
from types import MethodType
from badam import BAdamCallback, clip_grad_norm_old_version # type: ignore[import]
trainer.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, trainer.accelerator)
trainer.add_callback(BAdamCallback)
# Training
if training_args.do_train:
train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
trainer.save_model()
if finetuning_args.include_effective_tokens_per_second:
train_result.metrics["effective_tokens_per_sec"] = calculate_tps(
dataset_module["train_dataset"], train_result.metrics, stage="sft"
)
trainer.log_metrics("train", train_result.metrics)
trainer.save_metrics("train", train_result.metrics)
trainer.save_state()
if trainer.is_world_process_zero() and finetuning_args.plot_loss:
keys = ["loss"]
if isinstance(dataset_module.get("eval_dataset"), dict):
keys += sum(
[[f"eval_{key}_loss", f"eval_{key}_accuracy"] for key in dataset_module["eval_dataset"].keys()],
[],
)
else:
keys += ["eval_loss", "eval_accuracy"]
plot_loss(training_args.output_dir, keys=keys)
if training_args.predict_with_generate:
tokenizer.padding_side = "left"
# Evaluation
if training_args.do_eval:
metrics = trainer.evaluate(metric_key_prefix="eval", **gen_kwargs)
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)
# Predict
if training_args.do_predict:
logger.warning_rank0_once("Batch generation can be very slow. Consider using `scripts/vllm_infer.py` instead.")
predict_results = trainer.predict(dataset_module["eval_dataset"], metric_key_prefix="predict", **gen_kwargs)
trainer.log_metrics("predict", predict_results.metrics)
trainer.save_metrics("predict", predict_results.metrics)
trainer.save_predictions(dataset_module["eval_dataset"], predict_results, generating_args.skip_special_tokens)
# Create model card
create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args)
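
The `eos_token_id` handling in `run_sft` above filters out unknown ids (`-1`) and removes duplicates while keeping order via `dict.fromkeys`. A minimal standalone sketch of that step, with `build_eos_ids` as a hypothetical helper name and made-up token ids:

```python
def build_eos_ids(eos_token_id: int, extra_ids: list[int]) -> list[int]:
    """Combine the EOS id with extra special-token ids, dropping -1 and duplicates."""
    all_ids = [eos_token_id] + [i for i in extra_ids if i != -1]
    # dict.fromkeys keeps the first occurrence of each id and preserves order.
    return list(dict.fromkeys(all_ids))


eos_ids = build_eos_ids(2, [5, 2, -1, 7, 5])
```

Using `dict.fromkeys` instead of `set` keeps the primary EOS id first, which makes the resulting list deterministic.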


@@ -120,10 +120,10 @@ class CustomKTOTrainer(KTOTrainer):
self.add_callback(BAdamCallback)
@override
def create_optimizer(self) -> "torch.optim.Optimizer":
def create_optimizer(self, *args, **kwargs) -> "torch.optim.Optimizer":
if self.optimizer is None:
self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)
return super().create_optimizer()
return super().create_optimizer(*args, **kwargs)
@override
def create_scheduler(


@@ -12,4 +12,62 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# TODO override the original trainer
from typing import Any
import torch.nn.functional as F
from mcore_adapter.trainer import McaTrainer
from torch import Tensor
from transformers import PreTrainedTokenizerBase
from typing_extensions import override
from ...extras.constants import IGNORE_INDEX
class CustomMcaTrainer(McaTrainer):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@override
def _pad_batched_inputs(self, inputs: dict[str, Tensor | Any], seq_length: int):
r"""Override to avoid padding errors when handling 3D position_ids."""
padding_inputs = {
k: v.tolist() if v is not None and isinstance(v, Tensor) else v
for k, v in inputs.items()
if k in self._language_input_names
}
position_ids_3d = None
if isinstance(inputs.get("position_ids"), Tensor) and inputs["position_ids"].dim() == 3:
position_ids_3d = inputs["position_ids"]
padding_inputs.pop("position_ids", None)
if "labels" in padding_inputs:
padding_inputs["labels"] = [
labels + [IGNORE_INDEX] * (seq_length - len(labels)) for labels in padding_inputs["labels"]
]
tokenizer = (
self.processing_class
if isinstance(self.processing_class, PreTrainedTokenizerBase)
else getattr(self.processing_class, "tokenizer", self.processing_class)
)
padding_side = getattr(tokenizer, "padding_side", "right")
padding_inputs = tokenizer.pad(
padding_inputs,
padding="max_length",
max_length=seq_length,
return_tensors="pt",
).to(self.args.device)
inputs.update(padding_inputs)
if position_ids_3d is not None:
current_seq_len = position_ids_3d.size(-1)
if current_seq_len < seq_length:
pad_len = seq_length - current_seq_len
if padding_side == "left":
position_ids_3d = F.pad(position_ids_3d, (pad_len, 0), value=0)
else:
position_ids_3d = F.pad(position_ids_3d, (0, pad_len), value=0)
inputs["position_ids"] = position_ids_3d.to(self.args.device)
return inputs
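
`_pad_batched_inputs` above pads 3D `position_ids` along the last dimension, on the left or right depending on the tokenizer's `padding_side`. The sketch below shows the same idea on nested Python lists so that it runs without PyTorch. `pad_last_dim` is a hypothetical stand-in for the `F.pad(position_ids_3d, (pad_len, 0))` and `F.pad(position_ids_3d, (0, pad_len))` calls, not the trainer's actual code path.

```python
def pad_last_dim(position_ids, seq_length, padding_side="right", value=0):
    """Pad a nested [axes][batch][seq] list along its last dimension."""
    padded = []
    for axis in position_ids:
        rows = []
        for row in axis:
            pad = [value] * max(seq_length - len(row), 0)
            # Left padding prepends the fill values, right padding appends them.
            rows.append(pad + row if padding_side == "left" else row + pad)
        padded.append(rows)
    return padded


# Three position axes, batch size 1, sequence length 3, padded to length 5.
pos = [[[0, 1, 2]], [[0, 1, 2]], [[0, 1, 2]]]
left = pad_last_dim(pos, 5, padding_side="left")
right = pad_last_dim(pos, 5, padding_side="right")
```

Rows that already reach `seq_length` are returned unchanged, mirroring the `current_seq_len < seq_length` guard in the override.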


@@ -19,6 +19,7 @@ from collections.abc import Sequence
from copy import deepcopy
from typing import TYPE_CHECKING, Any, Optional
import torch
from transformers import DataCollatorForSeq2Seq
from ...data import (
@@ -43,9 +44,10 @@ if not is_mcore_adapter_available():
from mcore_adapter.models import AutoConfig, AutoModel
from mcore_adapter.trainer import DPOTrainer as McaDPOTrainer
from mcore_adapter.trainer import McaTrainer
from mcore_adapter.trainer.dpo_config import DPOConfig
from .trainer import CustomMcaTrainer
if TYPE_CHECKING:
from mcore_adapter.training_args import Seq2SeqTrainingArguments as McaSeq2SeqTrainingArguments
@@ -72,13 +74,25 @@ def _data_collator_wrapper(data_collator: Any):
for k in ["attention_mask", "position_ids"]:
if k in feature:
feature[k] = feature[k][:-1]
return data_collator(features)
# For Qwen-VL series models: drop rope_deltas and, if position_ids has 4 leading axes, keep only the last 3.
tmp_features = data_collator(features)
tmp_features.pop("rope_deltas", None)
position_ids = tmp_features.get("position_ids", None)
if position_ids is not None and position_ids.dim() == 3:
if position_ids.shape[0] == 4:
position_ids = position_ids[1:]
tmp_features["position_ids"] = position_ids
return tmp_features
return wrapper
def _check_model_support(model_args: "ModelArguments"):
from transformers import AutoConfig as HfAutoConfig
if os.path.exists(os.path.join(model_args.model_name_or_path, "mca_config.json")): # load from mcore ckpt
mca_config = json.load(open(os.path.join(model_args.model_name_or_path, "mca_config.json")))
model_type = mca_config.get("hf_model_type", None)
@@ -97,17 +111,24 @@ def _check_model_support(model_args: "ModelArguments"):
def _freeze_model_parameters(model: Any, finetuning_args: "FinetuningArguments"):
"""Freeze model parameters for qwen_vl series models based on finetuning arguments."""
-if getattr(model.config, "hf_model_type", None) not in ["qwen2_vl", "qwen2_5_vl", "qwen3_vl", "qwen3_vl_moe", "qwen3_5", "qwen3_5_moe"]:
+if getattr(model.config, "hf_model_type", None) not in [
+"qwen2_vl",
+"qwen2_5_vl",
+"qwen3_vl",
+"qwen3_vl_moe",
+"qwen3_5",
+"qwen3_5_moe",
+]:
return
params_to_freeze = []
if finetuning_args.freeze_vision_tower:
params_to_freeze.extend(["vision_model.blocks", "vision_model.patch_embed"])
-if getattr(model.config, "hf_model_type", None) in ["qwen3_vl", "qwen3_vl_moe"]:
+if getattr(model.config, "hf_model_type", None) in ["qwen3_vl", "qwen3_vl_moe", "qwen3_5", "qwen3_5_moe"]:
params_to_freeze.extend(["vision_model.pos_embed"])
if finetuning_args.freeze_multi_modal_projector:
params_to_freeze.extend(["multi_modal_projector"])
params_to_freeze.extend(["vision_model.merger"])
if finetuning_args.freeze_language_model:
params_to_freeze.extend(["embedding", "decoder", "output_layer"])
@@ -118,6 +139,27 @@ def _freeze_model_parameters(model: Any, finetuning_args: "FinetuningArguments")
p.requires_grad_(False)
+def _build_meta_hf_model_for_collator(model_args: "ModelArguments") -> Any | None:
+r"""Build a lightweight HF model on meta device for compatibility with collator."""
+from transformers import AutoConfig as HfAutoConfig
+from transformers import AutoModel as HfAutoModel
+from transformers import AutoModelForImageTextToText
+try:
+config = HfAutoConfig.from_pretrained(
+model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
+)
+with torch.device("meta"):
+try:
+# Prefer multimodal auto class for VLMs (e.g. qwen2-vl), so get_rope_index is available.
+return AutoModelForImageTextToText.from_config(config)
+except Exception:
+return HfAutoModel.from_config(config)
+except Exception as exc:
+logger.warning("Failed to build meta HF model for collator, fallback to no model. Error: %s", exc)
+return None
def run_pt(
model_args: "ModelArguments",
data_args: "DataArguments",
@@ -143,7 +185,7 @@ def run_pt(
)
data_collator = _data_collator_wrapper(data_collator)
-trainer = McaTrainer(
+trainer = CustomMcaTrainer(
model=model,
args=training_args,
tokenizer=tokenizer,
@@ -193,6 +235,7 @@ def run_sft(
_check_model_support(model_args)
model = AutoModel.from_pretrained(model_args.model_name_or_path, training_args)
+collator_model = _build_meta_hf_model_for_collator(model_args)
# optional freezing for qwen_vl series
_freeze_model_parameters(model, finetuning_args)
@@ -200,6 +243,7 @@ def run_sft(
pad_to_max = training_args.expert_model_parallel_size is not None and training_args.expert_model_parallel_size > 1
data_collator = SFTDataCollatorWith4DAttentionMask(
template=template,
+model=collator_model,
padding="max_length" if pad_to_max else "longest",
max_length=data_args.cutoff_len if pad_to_max else None,
pad_to_multiple_of=64,
@@ -208,7 +252,7 @@ def run_sft(
)
data_collator = _data_collator_wrapper(data_collator)
-trainer = McaTrainer(
+trainer = CustomMcaTrainer(
model=model,
args=training_args,
tokenizer=tokenizer,
@@ -247,6 +291,7 @@ def run_dpo(
_check_model_support(model_args)
model = AutoModel.from_pretrained(model_args.model_name_or_path, training_args)
+collator_model = _build_meta_hf_model_for_collator(model_args)
_freeze_model_parameters(model, finetuning_args)
@@ -270,6 +315,7 @@ def run_dpo(
)
data_collator = PairwiseDataCollatorWithPadding(
template=template,
+model=collator_model,
pad_to_multiple_of=64,
padding="max_length" if pad_to_max else "longest",
max_length=data_args.cutoff_len if pad_to_max else None,
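
The `_data_collator_wrapper` change in this file post-processes the collated batch: it drops `rope_deltas` and, when `position_ids` is 3-D with a leading dimension of 4, keeps only the last three rows. The diff does not state what the dropped first row holds. The sketch below reproduces the shape logic on nested lists; `ndim` and `postprocess` are hypothetical stand-ins, and unlike the wrapper in the diff, this version copies the dict instead of mutating it.

```python
def ndim(x):
    """Nesting depth of a regular nested list (stand-in for Tensor.dim())."""
    depth = 0
    while isinstance(x, list):
        depth += 1
        x = x[0] if x else None
    return depth


def postprocess(batch):
    """Drop rope_deltas and trim a 4-row position_ids down to 3 rows."""
    batch = dict(batch)  # shallow copy so the caller's dict is untouched
    batch.pop("rope_deltas", None)
    position_ids = batch.get("position_ids")
    if position_ids is not None and ndim(position_ids) == 3:
        if len(position_ids) == 4:
            # Keep rows 1..3, mirroring position_ids[1:] in the diff.
            position_ids = position_ids[1:]
        batch["position_ids"] = position_ids
    return batch


four_rows = {"position_ids": [[[9, 9]], [[0, 1]], [[0, 0]], [[0, 1]]], "rope_deltas": [0]}
three_rows = {"position_ids": [[[0, 1]], [[0, 0]], [[0, 1]]]}
flat = {"position_ids": [[0, 1, 2]]}
```

Batches whose `position_ids` are 2-D, or already have three leading rows, pass through unchanged.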


@@ -69,10 +69,10 @@ class CustomTrainer(Trainer):
verify_fp8_status(self.accelerator, training_args)
@override
-def create_optimizer(self) -> "torch.optim.Optimizer":
+def create_optimizer(self, *args, **kwargs) -> "torch.optim.Optimizer":
if self.optimizer is None:
self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)
-return super().create_optimizer()
+return super().create_optimizer(*args, **kwargs)
@override
def create_scheduler(


@@ -65,10 +65,10 @@ class PairwiseTrainer(Trainer):
self.add_callback(BAdamCallback)
@override
-def create_optimizer(self) -> "torch.optim.Optimizer":
+def create_optimizer(self, *args, **kwargs) -> "torch.optim.Optimizer":
if self.optimizer is None:
self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)
-return super().create_optimizer()
+return super().create_optimizer(*args, **kwargs)
@override
def create_scheduler(


@@ -128,10 +128,10 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
verify_fp8_status(self.accelerator, training_args)
@override
-def create_optimizer(self) -> "torch.optim.Optimizer":
+def create_optimizer(self, *args, **kwargs) -> "torch.optim.Optimizer":
if self.optimizer is None:
self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)
-return super().create_optimizer()
+return super().create_optimizer(*args, **kwargs)
@override
def create_scheduler(
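
The three trainer hunks above make the same change: `create_optimizer` now accepts `*args, **kwargs` and forwards them to the parent. The diff does not say why. One plausible reading, offered here as an assumption, is that the parent method may be called with extra arguments in some library versions, and an override with a fixed signature would then fail. The classes below are illustrative stand-ins, not the real `Trainer` API.

```python
class BaseTrainer:
    """Stand-in parent whose create_optimizer accepts an optional argument."""

    def __init__(self):
        self.optimizer = None

    def create_optimizer(self, model=None):
        # Only builds a default optimizer if none has been set yet.
        if self.optimizer is None:
            self.optimizer = ("default", model)
        return self.optimizer


class ForwardingTrainer(BaseTrainer):
    """Override that installs a custom optimizer, then defers to the parent."""

    def create_optimizer(self, *args, **kwargs):
        if self.optimizer is None:
            self.optimizer = ("custom", None)
        # Forward whatever the caller passed, so the override keeps working
        # whether or not the parent signature takes extra arguments.
        return super().create_optimizer(*args, **kwargs)


no_args = ForwardingTrainer().create_optimizer()
with_kwarg = ForwardingTrainer().create_optimizer(model="m")
```

Because the custom optimizer is set before the parent call, the parent sees a non-empty `self.optimizer` and returns it unchanged in both cases.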


@@ -65,6 +65,7 @@ def run_sft(
pad_to_multiple_of=8 if training_args.do_train else None, # for shift short attention
label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id,
block_diag_attn=model_args.block_diag_attn,
+neat_packing=data_args.neat_packing,
attn_implementation=getattr(model.config, "_attn_implementation", None),
compute_dtype=model_args.compute_dtype,
**tokenizer_module,
@@ -102,25 +103,6 @@ def run_sft(
gen_kwargs["pad_token_id"] = tokenizer.pad_token_id
# Initialize our Trainer
-if model_args.use_kt:
-from ktransformers.sft.lora import KTrainer # type: ignore
-from ktransformers.util.globals import GLOBAL_CONFIG # type: ignore
-GLOBAL_CONFIG._config["mod"] = "sft"
-trainer = KTrainer(
-model=model,
-args=training_args,
-tokenizer=tokenizer_module,
-data_collator=data_collator,
-callbacks=callbacks,
-**dataset_module,
-**metric_module,
-)
-trainer.model_accepts_loss_kwargs = False
-model.config.use_cache = False
-else:
trainer = CustomSeq2SeqTrainer(
model=model,
args=training_args,


@@ -50,9 +50,9 @@ if is_apollo_available():
if is_ray_available():
import ray
-from ray.util.state import list_nodes
from ray.util.placement_group import PlacementGroup, placement_group
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
+from ray.util.state import list_nodes
if TYPE_CHECKING:
@@ -103,7 +103,7 @@ def create_modelcard_and_push(
kwargs["tags"] = kwargs["tags"] + ["unsloth"]
if model_args.use_kt:
-kwargs["tags"] = kwargs["tags"] + ["ktransformers"]
+kwargs["tags"] = kwargs["tags"] + ["kt-kernel"]
if not training_args.do_train:
pass


@@ -24,10 +24,21 @@ from ..data import get_template_and_fix_tokenizer
from ..extras import logging
from ..extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
from ..extras.misc import find_available_port, get_device_name, get_torch_device, infer_optim_dtype
-from ..extras.packages import is_mcore_adapter_available, is_ray_available, is_transformers_version_greater_than
+from ..extras.packages import (
+is_hyper_parallel_available,
+is_mcore_adapter_available,
+is_ray_available,
+is_transformers_version_greater_than,
+)
from ..hparams import RayArguments, get_infer_args, get_ray_args, get_train_args, read_args
from ..model import load_model, load_tokenizer
-from .callbacks import LogCallback, PissaConvertCallback, ReporterCallback
+from .callbacks import (
+LogCallback,
+ModuleProfilerCallback,
+PissaConvertCallback,
+ReporterCallback,
+TorchProfilerCallback,
+)
from .dpo import run_dpo
from .kto import run_kto
from .ppo import run_ppo
@@ -69,9 +80,22 @@ def _training_function(config: dict[str, Any]) -> None:
if finetuning_args.early_stopping_steps is not None:
callbacks.append(EarlyStoppingCallback(early_stopping_patience=finetuning_args.early_stopping_steps))
+if getattr(training_args, "enable_torch_profiler", False):
+callbacks.append(TorchProfilerCallback(training_args))
+if getattr(training_args, "profile_modules", None):
+callbacks.append(ModuleProfilerCallback(training_args.profile_modules))
callbacks.append(ReporterCallback(model_args, data_args, finetuning_args, generating_args)) # add to last
-if finetuning_args.stage in ["pt", "sft", "dpo"] and finetuning_args.use_mca:
+if finetuning_args.stage == "sft" and finetuning_args.use_hyper_parallel:
+if not is_hyper_parallel_available():
+raise ImportError("hyper_parallel is not installed. Please install it with `pip install hyper_parallel`.")
+from .hyper_parallel import run_sft as run_sft_hp
+run_sft_hp(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
+elif finetuning_args.stage in ["pt", "sft", "dpo"] and finetuning_args.use_mca:
if not is_mcore_adapter_available():
raise ImportError("mcore_adapter is not installed. Please install it with `pip install mcore-adapter`.")
if finetuning_args.stage == "pt":
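
The profiler callbacks in this hunk are registered through `getattr` with a default, so an arguments object that lacks the profiler fields is treated as "profiling off" rather than raising `AttributeError`. The sketch below shows that pattern with `SimpleNamespace` objects; the tuple and string placeholders stand in for the real callback classes, and `build_profiler_callbacks` is a hypothetical name.

```python
from types import SimpleNamespace


def build_profiler_callbacks(training_args):
    """Collect profiler callbacks, tolerating missing attributes."""
    callbacks = []
    # A missing attribute falls back to False, i.e. the profiler stays off.
    if getattr(training_args, "enable_torch_profiler", False):
        callbacks.append("torch_profiler")
    # A missing or empty module list also disables the module profiler.
    modules = getattr(training_args, "profile_modules", None)
    if modules:
        callbacks.append(("module_profiler", modules))
    return callbacks


full_args = SimpleNamespace(enable_torch_profiler=True, profile_modules=["attn"])
bare_args = SimpleNamespace()  # e.g. an arguments class without profiler fields
```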
@@ -168,7 +192,15 @@ def export_model(args: Optional[dict[str, Any]] = None) -> None:
if not is_transformers_version_greater_than("5.0.0"):
save_kwargs["safe_serialization"] = not model_args.export_legacy_format
+try:
model.save_pretrained(**save_kwargs)
+except NotImplementedError as err:
+raise RuntimeError(
+"Failed to export model: weight conversion reversal is not supported for this model architecture "
+"(NotImplementedError in transformers.core_model_loading.reverse_op). "
+"This is a known issue with transformers>=5.0 for certain model types (e.g. Mistral/Ministral). "
+"Workarounds: (1) use transformers<5.0, or (2) report the issue to the transformers repository."
+) from err
if model_args.export_hub_model_id is not None:
# Prepare push arguments (safe_serialization removed in transformers v5.0.0)
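
The export hunk above wraps `save_pretrained` so that a bare `NotImplementedError` is re-raised as a `RuntimeError` with an actionable message, using `raise ... from err` to keep the original exception attached. A minimal sketch of that chaining follows; `failing_save` and `export` are hypothetical stand-ins and the message is shortened.

```python
def failing_save(**kwargs):
    """Stand-in for a save routine that hits an unsupported code path."""
    raise NotImplementedError("reverse_op")


def export(save_fn, **save_kwargs):
    """Call save_fn and translate NotImplementedError into a clearer error."""
    try:
        save_fn(**save_kwargs)
    except NotImplementedError as err:
        # "from err" stores the original exception on __cause__, so the
        # traceback still shows where the failure really came from.
        raise RuntimeError("Failed to export model: weight conversion is not supported.") from err


try:
    export(failing_save, save_directory="out")
    caught = None
except RuntimeError as exc:
    caught = exc
```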


@@ -123,6 +123,8 @@ class DistributedInterface:
if self._initialized:
return
self.dist_config = config
helper.set_device_index()
self._is_distributed = helper.is_distributed()
self._rank = helper.get_rank()


@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from ..utils.types import AttentionFunction
from .arg_parser import InputArgument, get_args
from .arg_utils import BatchingStrategy, ModelClass, SampleBackend
from .data_args import DataArguments
@@ -21,6 +22,7 @@ from .training_args import TrainingArguments
__all__ = [
+"AttentionFunction",
"BatchingStrategy",
"DataArguments",
"InputArgument",


@@ -57,15 +57,12 @@ def get_args(args: InputArgument = None) -> tuple[ModelArguments, DataArguments,
print(f"Got unknown args, potentially deprecated arguments: {unknown_args}")
raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {unknown_args}")
+model_args, data_args, training_args, sample_args = parsed_args
-# Seed as early as possible after argument parsing so all downstream
-# components (dist init, dataloader, model init in run_* entrypoints) share the same RNG state.
-for arg in parsed_args:
-seed = getattr(arg, "seed", None)
-if seed is not None:
-set_seed(seed)
-break
+set_seed(training_args.seed, full_determinism=training_args.full_determinism)
-return tuple(parsed_args)
+return model_args, data_args, training_args, sample_args
if __name__ == "__main__":
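
The `get_args` hunk above replaces a loop that used the first non-`None` `seed` found on any parsed argument object with a direct read of `training_args.seed`. The sketch below contrasts the two selection rules. The argument objects are hypothetical: whether any class other than `TrainingArguments` actually carries a `seed` field is not shown in this diff.

```python
from types import SimpleNamespace


def pick_seed_old(parsed_args):
    """Old rule: first argument object exposing a non-None seed wins."""
    for arg in parsed_args:
        seed = getattr(arg, "seed", None)
        if seed is not None:
            return seed
    return None


def pick_seed_new(training_args):
    """New rule: the seed always comes from the training arguments."""
    return training_args.seed


# Hypothetical parsed arguments, in the order model, data, training, sample.
model_args = SimpleNamespace()
data_args = SimpleNamespace(seed=1)
training_args = SimpleNamespace(seed=42)
sample_args = SimpleNamespace()
parsed = (model_args, data_args, training_args, sample_args)

old_seed = pick_seed_old(parsed)
new_seed = pick_seed_new(training_args)
```

Under the old rule the result depends on argument order, while the new rule has a single source for the seed.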


@@ -15,6 +15,7 @@
from dataclasses import dataclass, field
+from ..utils.types import AttentionFunction
from .arg_utils import ModelClass, PluginConfig, get_plugin_config
@@ -32,6 +33,12 @@ class ModelArguments:
default=False,
metadata={"help": "Trust remote code from Hugging Face."},
)
+flash_attn: AttentionFunction = field(
+default=AttentionFunction.SDPA,
+metadata={
+"help": "Attention implementation to use: eager, sdpa, or flash_attention_2. SDPA is the default implementation for models."
+},
+)
model_class: ModelClass = field(
default=ModelClass.LLM,
metadata={"help": "Model class from Hugging Face."},
@@ -54,6 +61,12 @@ class ModelArguments:
)
def __post_init__(self) -> None:
+supported_flash_attn = [item.value for item in AttentionFunction]
+if self.flash_attn not in supported_flash_attn:
+raise ValueError(
+f"Unsupported `flash_attn`: {self.flash_attn}. Supported values are: {supported_flash_attn}."
+)
self.init_config = get_plugin_config(self.init_config)
self.peft_config = get_plugin_config(self.peft_config)
self.kernel_config = get_plugin_config(self.kernel_config)
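
The `__post_init__` check above rejects any `flash_attn` value that is not a member value of `AttentionFunction`. The enum definition is not shown in this diff, so the sketch below assumes a string-valued enum with the three names listed in the help text (`eager`, `sdpa`, `flash_attention_2`). The dataclass is reduced to the one field under discussion.

```python
from dataclasses import dataclass
from enum import Enum


class AttentionFunction(str, Enum):
    """Assumed shape of the enum; member names are illustrative."""

    EAGER = "eager"
    SDPA = "sdpa"
    FLASH_ATTENTION_2 = "flash_attention_2"


@dataclass
class ModelArguments:
    flash_attn: AttentionFunction = AttentionFunction.SDPA

    def __post_init__(self) -> None:
        supported = [item.value for item in AttentionFunction]
        # A str-based enum member compares equal to its value, so both
        # AttentionFunction.SDPA and the plain string "sdpa" pass this check.
        if self.flash_attn not in supported:
            raise ValueError(f"Unsupported `flash_attn`: {self.flash_attn}. Supported values are: {supported}.")


default_args = ModelArguments()
eager_args = ModelArguments(flash_attn="eager")
```

Validating in `__post_init__` means a bad value fails at argument-parsing time instead of later, when the model is built.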

Some files were not shown because too many files have changed in this diff.