Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2026-06-17 12:48:55 +08:00)
Compare commits: b5cb7cb0e6...v0.9.5 (70 commits)
| Author | SHA1 | Date | |
|---|---|---|---|
| | 7af909522a | | |
| | e016d2480e | | |
| | 7d719182c9 | | |
| | 01398eb18d | | |
| | 8e68764b65 | | |
| | 16ff5a23cb | | |
| | bdcb92d035 | | |
| | 7e20db5735 | | |
| | 2322bf1cc2 | | |
| | 368c48968f | | |
| | 8b5ea65770 | | |
| | 40e786d016 | | |
| | 6b9df75ab9 | | |
| | ca50f22c38 | | |
| | 53e77a9bfa | | |
| | 55bd4944b6 | | |
| | 7e09152275 | | |
| | 1e503a982d | | |
| | 8752280dd7 | | |
| | 468723c5d9 | | |
| | 887ee2b121 | | |
| | 6b08b948c9 | | |
| | f7f3bfcbd7 | | |
| | 3475198d1e | | |
| | 50945ef850 | | |
| | 2f0bef207a | | |
| | 2092abc217 | | |
| | 99464b3d03 | | |
| | 9a0cfdccfa | | |
| | c8890c32db | | |
| | 79c8332e4c | | |
| | e0bc3c1971 | | |
| | ecca167eb4 | | |
| | 28a6ea1cdc | | |
| | f5d739b132 | | |
| | c4bbac49b2 | | |
| | c5aecaf31d | | |
| | 436d26bc28 | | |
| | c109c061e5 | | |
| | fa09c01c36 | | |
| | eae6f0b541 | | |
| | acac63ef35 | | |
| | e5e8546493 | | |
| | 97433c53b6 | | |
| | b5afabe3d2 | | |
| | df2e6edb7e | | |
| | d02fcd3588 | | |
| | c340aa2a33 | | |
| | 1e536733c6 | | |
| | 97d479fa92 | | |
| | ffbff33af3 | | |
| | 833f6027b1 | | |
| | d91d8af89e | | |
| | e67ab9e2f2 | | |
| | 2c4f121817 | | |
| | 487f8b8191 | | |
| | 78cad1e332 | | |
| | 70653026f5 | | |
| | 246192abd2 | | |
| | 0258dc14d0 | | |
| | 3045adf0ba | | |
| | a3d44e3152 | | |
| | edeb953bc7 | | |
| | d045794387 | | |
| | 9501c3308a | | |
| | 0ee1c42c2b | | |
| | 3061f48d55 | | |
| | 2d9bd2aa14 | | |
| | c0245c43fc | | |
| | eb976d75a2 | | |
105
.ai/CLAUDE.md
Normal file
@@ -0,0 +1,105 @@
# CLAUDE.md

This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.

## Commands

```bash
# Code style (auto-fix)
make style

# Code quality check (no modifications)
make quality

# Run all tests
make test

# Run a single test file
WANDB_DISABLED=true pytest -vv --import-mode=importlib tests/path/to/test_file.py

# Run tests matching a pattern
WANDB_DISABLED=true pytest -vv --import-mode=importlib tests/ -k "test_name"

# License header check
make license

# Build package
make build
```

The project uses `uv` as the preferred package manager. Commands automatically use `uv run` / `uvx` if `uv` is available.

## Architecture

LlamaFactory has two parallel architectures controlled by the `USE_V1` environment variable:

- **v0 (default):** `api, webui > chat, eval, train > data, model > hparams > extras`
- **v1 (experimental, `USE_V1=1`):** `trainers > core > accelerator, plugins, config > utils`

Most active development happens in v0. The v1 architecture lives in `src/llamafactory/v1/`.

### Entry Points

CLI entry point is `llamafactory-cli` / `lmf` → `src/llamafactory/cli.py:main()`, which dispatches to `launcher.py` based on `USE_V1`.

Available subcommands: `train`, `chat`, `api`, `export`, `webchat`, `webui`, `env`, `version`, `help`.
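
The `USE_V1` switch described above can be pictured with a small sketch. This is not the code in `cli.py` or `launcher.py`; the helper names `select_architecture` and `describe` are hypothetical, and the set of strings treated as "enabled" is an assumption made for illustration. Only the layer orderings are taken from the Architecture section.

```python
import os

# Layer order copied from the Architecture section (outermost layer first).
LAYERS = {
    "v0": ["api, webui", "chat, eval, train", "data, model", "hparams", "extras"],
    "v1": ["trainers", "core", "accelerator, plugins, config", "utils"],
}


def select_architecture(env=None):
    """Return "v1" when USE_V1 looks enabled, otherwise "v0" (the default).

    Assumption: the accepted truthy strings below are illustrative only.
    """
    env = os.environ if env is None else env
    value = env.get("USE_V1", "0").strip().lower()
    return "v1" if value in {"1", "true", "y", "yes"} else "v0"


def describe(env=None):
    """Render the dependency chain of the selected architecture."""
    arch = select_architecture(env)
    return f"{arch}: " + " > ".join(LAYERS[arch])


if __name__ == "__main__":
    print(describe())
```

Passing an explicit mapping instead of reading `os.environ` directly keeps the selection logic easy to test without mutating the process environment.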

### Training Flow (v0)

```
run_exp() [tuner.py]
  → read_args() → parse YAML/JSON config
  → get_train_args() → produces typed argument dataclasses
  → routes to: run_sft / run_dpo / run_ppo / run_rm / run_pt / run_kto
  → optional: export_model()
```

Training is invoked with a YAML config: `llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml`
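
The routing step in the flow above amounts to a lookup from the configured stage to a runner. The sketch below illustrates that idea with stub runners; `run_exp_sketch`, `STAGE_RUNNERS`, and the `export` flag are hypothetical names, and defaulting to `sft` when no stage is given is an assumption, not a statement about `tuner.py`.

```python
STAGES = ("sft", "dpo", "ppo", "rm", "pt", "kto")


def _make_stub(stage):
    """Build a placeholder runner that only reports which stage it stands for."""
    def runner(config):
        return f"run_{stage}"
    return runner


# Stage name -> runner, mirroring the "routes to" line in the flow diagram.
STAGE_RUNNERS = {stage: _make_stub(stage) for stage in STAGES}


def run_exp_sketch(config, export=False):
    """Dispatch on config["stage"] and optionally append an export step."""
    stage = config.get("stage", "sft")  # assumption: default stage is sft
    if stage not in STAGE_RUNNERS:
        raise ValueError(f"Unknown stage: {stage!r}")
    steps = [STAGE_RUNNERS[stage](config)]
    if export:
        steps.append("export_model")
    return steps


if __name__ == "__main__":
    print(run_exp_sketch({"stage": "dpo"}))
```

A dictionary lookup keeps the set of supported stages in one place, so an unsupported value fails early with a clear error instead of falling through a chain of conditionals.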

### Configuration System

All training parameters are YAML/JSON config files. Argument parsing in `src/llamafactory/hparams/parser.py` produces four typed dataclasses:
- `ModelArguments` — model/tokenizer selection, quantization
- `DataArguments` — datasets, templates, preprocessing
- `FinetuningArguments` — LoRA rank/target, training method (sft/dpo/ppo/rm/pt/kto)
- `TrainingArguments` — extends HuggingFace's `TrainingArguments`
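
To make the "one flat config, four typed groups" idea concrete, here is a standard-library sketch. It does not reproduce `parser.py`: the `*Args` classes are hypothetical stand-ins with a handful of fields borrowed from the example YAML files in this diff, and both the grouping of those fields and their defaults are assumptions.

```python
from dataclasses import dataclass, fields


@dataclass
class ModelArgs:  # stand-in for ModelArguments
    model_name_or_path: str = ""
    trust_remote_code: bool = False


@dataclass
class DataArgs:  # stand-in for DataArguments
    dataset: str = ""
    template: str = ""
    cutoff_len: int = 2048


@dataclass
class FinetuningArgs:  # stand-in for FinetuningArguments
    stage: str = "sft"
    finetuning_type: str = "lora"


@dataclass
class TrainingArgs:  # stand-in for TrainingArguments
    output_dir: str = ""
    learning_rate: float = 5e-5
    bf16: bool = False


def split_config(config):
    """Split one flat mapping into the four typed groups; reject unknown keys."""
    remaining = dict(config)
    parsed = []
    for cls in (ModelArgs, DataArgs, FinetuningArgs, TrainingArgs):
        names = {f.name for f in fields(cls)}
        kwargs = {key: remaining.pop(key) for key in list(remaining) if key in names}
        parsed.append(cls(**kwargs))
    if remaining:
        raise ValueError(f"Unknown config keys: {sorted(remaining)}")
    return tuple(parsed)


if __name__ == "__main__":
    for group in split_config({"stage": "sft", "template": "qwen3_5_nothink", "bf16": True}):
        print(group)
```

Rejecting leftover keys is a deliberate choice in this sketch: a misspelled option in a YAML file is reported instead of being silently ignored.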

### Key Modules

| Module | Purpose |
|--------|---------|
| `src/llamafactory/model/loader.py` | Loads model + tokenizer; applies quantization, LoRA, patches |
| `src/llamafactory/model/patcher.py` | Model-specific compatibility patches |
| `src/llamafactory/data/template.py` | Prompt templates; `TEMPLATES` dict maps model family → format |
| `src/llamafactory/data/mm_plugin.py` | Multi-modal (image/video/audio) data handling |
| `src/llamafactory/data/processor/` | Per-stage data processors (supervised, pairwise, pretrain, etc.) |
| `src/llamafactory/train/sft/` | SFT trainer; other stages follow same structure |
| `src/llamafactory/chat/` | Inference engines: `hf_engine`, `vllm_engine`, `sglang_engine`, `kt_engine` |
| `src/llamafactory/extras/constants.py` | Enums and constants used across the project |

### Adding Support for a New Model

1. Add a prompt template to `src/llamafactory/data/template.py` in the `TEMPLATES` dict
2. Add any necessary model patches in `src/llamafactory/model/patcher.py`
3. Add multi-modal support in `src/llamafactory/data/mm_plugin.py` if needed

### Distributed Training

Multi-GPU automatically uses `torchrun`. Additional backends:
- **Ray:** Optional Ray cluster support
- **HyperParallel FSDP2:** `src/llamafactory/train/hyper_parallel/`
- **Megatron-core:** `src/llamafactory/train/mca/`

### Testing

- `tests/` — v0 tests; `tests_v1/` — v1 tests
- Most training tests require GPU hardware
- pytest markers: `@pytest.mark.slow`, `@pytest.mark.runs_on(['cuda'])`
- Always set `WANDB_DISABLED=true` when running tests
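
When tests are launched from a script rather than a shell, the same rule (always set `WANDB_DISABLED=true`) can be applied programmatically. The helpers below are a hypothetical convenience, not part of the repository; the pytest flags are the ones shown in the Commands section.

```python
import os


def pytest_command(target="tests/", pattern=None):
    """Build the pytest argument list used in the Commands section."""
    cmd = ["pytest", "-vv", "--import-mode=importlib", target]
    if pattern:
        cmd += ["-k", pattern]  # select tests whose names match the pattern
    return cmd


def pytest_env(base=None):
    """Copy the environment and force WANDB_DISABLED=true."""
    env = dict(os.environ if base is None else base)
    env["WANDB_DISABLED"] = "true"
    return env


if __name__ == "__main__":
    print(" ".join(pytest_command("tests/", "test_name")))
```

To actually run the tests, these helpers could be passed to `subprocess.run(pytest_command(...), env=pytest_env(), check=True)`.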

### Code Style

- Ruff for linting and formatting (line length 119, Google-style docstrings)
- Python 3.11+ syntax
- Double quotes for strings
- All new files must include Apache 2.0 license header (checked by `make license`)
|
||||
2
.github/workflows/docker.yml
vendored
@@ -109,7 +109,7 @@ jobs:
|
||||
platforms: linux/amd64,linux/arm64
|
||||
file: ./docker/docker-npu/Dockerfile
|
||||
build-args: |
|
||||
BASE_IMAGE=quay.io/ascend/cann:8.3.rc2-a3-ubuntu22.04-py3.11
|
||||
BASE_IMAGE=quay.io/ascend/cann:9.0.0-a3-ubuntu22.04-py3.11
|
||||
push: ${{ github.event_name != 'pull_request' }}
|
||||
tags: |
|
||||
docker.io/hiyouga/llamafactory:${{ steps.version.outputs.tag }}-npu-a3
|
||||
|
||||
9
.github/workflows/tests.yml
vendored
@@ -35,15 +35,12 @@ jobs:
|
||||
transformers:
|
||||
- ""
|
||||
include: # test backward compatibility
|
||||
- python: "3.11"
|
||||
os: "ubuntu-latest"
|
||||
transformers: "4.51.0"
|
||||
- python: "3.11"
|
||||
os: "ubuntu-latest"
|
||||
transformers: "4.53.0"
|
||||
- python: "3.11"
|
||||
os: "ubuntu-latest"
|
||||
transformers: "4.55.0"
|
||||
- python: "3.11"
|
||||
os: "ubuntu-latest"
|
||||
transformers: "4.57.1"
|
||||
|
||||
runs-on: ${{ matrix.os }}
|
||||
|
||||
|
||||
12
.github/workflows/tests_npu.yml
vendored
@@ -38,7 +38,7 @@ jobs:
|
||||
cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
|
||||
|
||||
container:
|
||||
image: ascendai/cann:8.3.rc2-910b-ubuntu22.04-py3.11
|
||||
image: ascendai/cann:9.0.0-910b-ubuntu22.04-py3.11
|
||||
env:
|
||||
HF_ENDPOINT: https://hf-mirror.com
|
||||
HF_TOKEN: ${{ secrets.HF_TOKEN }}
|
||||
@@ -49,6 +49,12 @@ jobs:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Set nginx-cache for Ascend CI
|
||||
run: |
|
||||
sed -Ei 's@(ports|archive).ubuntu.com@cache-service.nginx-pypi-cache.svc.cluster.local:8081@g' /etc/apt/sources.list
|
||||
pip config set global.index-url http://cache-service.nginx-pypi-cache.svc.cluster.local/pypi/simple
|
||||
pip config set global.trusted-host cache-service.nginx-pypi-cache.svc.cluster.local
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v7
|
||||
with:
|
||||
@@ -59,8 +65,8 @@ jobs:
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
uv venv
|
||||
uv pip install -r requirements/npu.txt
|
||||
uv pip install -e .
|
||||
uv pip install -r requirements/npu.txt
|
||||
uv pip install -r requirements/dev.txt
|
||||
|
||||
- name: Install node
|
||||
@@ -83,5 +89,7 @@ jobs:
|
||||
make build
|
||||
|
||||
- name: Test with pytest
|
||||
shell: bash
|
||||
run: |
|
||||
source /usr/local/Ascend/ascend-toolkit/set_env.sh
|
||||
make test
|
||||
|
||||
21
README.md
@@ -15,8 +15,6 @@
|
||||
|
||||
[](https://colab.research.google.com/drive/1eRTPn37ltBbYsISy9Aw2NuI2Aq5CQrD9?usp=sharing)
|
||||
[](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory)
|
||||
[](https://www.lab4ai.cn/course/detail?id=7c13e60f6137474eb40f6fd3983c0f46&utm_source=LLaMA-Factory)
|
||||
[](https://www.llamafactory.com.cn/?utm_source=LLaMA-Factory)
|
||||
[](https://huggingface.co/spaces/hiyouga/LLaMA-Board)
|
||||
[](https://modelscope.cn/studios/hiyouga/LLaMA-Board)
|
||||
[](https://novita.ai/templates-library/105981?sharer=88115474-394e-4bda-968e-b88e123d0c47)
|
||||
@@ -38,7 +36,7 @@
|
||||
|
||||
</div>
|
||||
|
||||
👋 Join our [WeChat](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/main.jpg), [NPU](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/npu.jpg), [Lab4AI](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/lab4ai.jpg), [LLaMA Factory Online](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/online.jpg) user group.
|
||||
👋 Join our [WeChat](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/main.jpg) and [NPU](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/npu.jpg) user groups.
|
||||
|
||||
\[ English | [中文](README_zh.md) \]
|
||||
|
||||
@@ -52,14 +50,11 @@ Start local training:
|
||||
Start cloud training:
|
||||
- **Colab (free)**: https://colab.research.google.com/drive/1eRTPn37ltBbYsISy9Aw2NuI2Aq5CQrD9?usp=sharing
|
||||
- **PAI-DSW (free trial)**: https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory
|
||||
- **LLaMA Factory Online**: https://www.llamafactory.com.cn/?utm_source=LLaMA-Factory
|
||||
- **Alaya NeW (cloud GPU deal)**: https://docs.alayanew.com/docs/documents/useGuide/LLaMAFactory/mutiple/?utm_source=LLaMA-Factory
|
||||
|
||||
Read technical notes:
|
||||
- **Documentation (WIP)**: https://llamafactory.readthedocs.io/en/latest/
|
||||
- **Documentation (AMD GPU)**: https://rocm.docs.amd.com/projects/ai-developer-hub/en/latest/notebooks/fine_tune/llama_factory_llama3.html
|
||||
- **Official Blog**: https://blog.llamafactory.net/en/
|
||||
- **Official Course**: https://www.lab4ai.cn/course/detail?id=7c13e60f6137474eb40f6fd3983c0f46&utm_source=LLaMA-Factory
|
||||
|
||||
> [!NOTE]
|
||||
> Except for the above links, all other websites are unauthorized third-party websites. Please carefully use them.
|
||||
@@ -78,7 +73,6 @@ Read technical notes:
|
||||
- [Data Preparation](#data-preparation)
|
||||
- [Quickstart](#quickstart)
|
||||
- [Fine-Tuning with LLaMA Board GUI](#fine-tuning-with-llama-board-gui-powered-by-gradio)
|
||||
- [LLaMA Factory Online](#llama-factory-online)
|
||||
- [Build Docker](#build-docker)
|
||||
- [Deploy with OpenAI-style API and vLLM](#deploy-with-openai-style-api-and-vllm)
|
||||
- [Download from ModelScope Hub](#download-from-modelscope-hub)
|
||||
@@ -117,15 +111,11 @@ Read technical notes:
|
||||
|
||||
- 💡 [KTransformers Fine-Tuning × LLaMA Factory: Fine-tuning 1000 Billion models with 2 4090-GPU + CPU](https://blog.llamafactory.net/en/posts/ktransformers/) (English)
|
||||
- 💡 [Easy Dataset × LLaMA Factory: Enabling LLMs to Efficiently Learn Domain Knowledge](https://buaa-act.feishu.cn/wiki/GVzlwYcRFiR8OLkHbL6cQpYin7g) (English)
|
||||
- [Fine-tune a mental health LLM using LLaMA-Factory](https://www.lab4ai.cn/project/detail?id=25cce32ec131497b9e06a93336a0817f&type=project&utm_source=LLaMA-Factory) (Chinese)
|
||||
- [Fine-tune GPT-OSS for Role-Playing using LLaMA-Factory](https://docs.llamafactory.com.cn/docs/documents/best-practice/gptroleplay/?utm_source=LLaMA-Factory) (Chinese)
|
||||
- [A One-Stop Code-Free Model Reinforcement Learning and Deployment Platform based on LLaMA-Factory and EasyR1](https://aws.amazon.com/cn/blogs/china/building-llm-model-hub-based-on-llamafactory-and-easyr1/) (Chinese)
|
||||
- [How Apoidea Group enhances visual information extraction from banking documents with multimodal models using LLaMA-Factory on Amazon SageMaker HyperPod](https://aws.amazon.com/cn/blogs/machine-learning/how-apoidea-group-enhances-visual-information-extraction-from-banking-documents-with-multimodal-models-using-llama-factory-on-amazon-sagemaker-hyperpod/) (English)
|
||||
|
||||
<details><summary>All Blogs</summary>
|
||||
|
||||
- [Fine-tune Llama3.1-70B for Medical Diagnosis using LLaMA-Factory](https://docs.alayanew.com/docs/documents/bestPractice/bigModel/llama70B/?utm_source=LLaMA-Factory) (Chinese)
|
||||
- [Fine-tune Qwen2.5-VL for Autonomous Driving using LLaMA-Factory](https://docs.alayanew.com/docs/documents/useGuide/LLaMAFactory/mutiple/?utm_source=LLaMA-Factory) (Chinese)
|
||||
- [LLaMA Factory: Fine-tuning the DeepSeek-R1-Distill-Qwen-7B Model for News Classifier](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory_deepseek_r1_distill_7b) (Chinese)
|
||||
- [A One-Stop Code-Free Model Fine-Tuning \& Deployment Platform based on SageMaker and LLaMA-Factory](https://aws.amazon.com/cn/blogs/china/a-one-stop-code-free-model-fine-tuning-deployment-platform-based-on-sagemaker-and-llama-factory/) (Chinese)
|
||||
- [LLaMA Factory Multi-Modal Fine-Tuning Practice: Fine-Tuning Qwen2-VL for Personal Tourist Guide](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory_qwen2vl) (Chinese)
|
||||
@@ -319,7 +309,8 @@ Read technical notes:
|
||||
| [Pixtral](https://huggingface.co/mistralai) | 12B | pixtral |
|
||||
| [Qwen2 (Code/Math/MoE/QwQ)](https://huggingface.co/Qwen) | 0.5B/1.5B/3B/7B/14B/32B/72B/110B | qwen |
|
||||
| [Qwen3 (MoE/Instruct/Thinking/Next)](https://huggingface.co/Qwen) | 0.6B/1.7B/4B/8B/14B/32B/80B/235B | qwen3/qwen3_nothink |
|
||||
| [Qwen3.5](https://huggingface.co/Qwen) | 27B/35B/122B/397B | qwen3_5 |
|
||||
| [Qwen3.5](https://huggingface.co/Qwen) | 0.8B/2B/4B/9B/27B/35B/122B/397B | qwen3_5/qwen3_5_nothink |
|
||||
| [Qwen3.6](https://huggingface.co/Qwen) | 27B/35B | qwen3_6 |
|
||||
| [Qwen2-Audio](https://huggingface.co/Qwen) | 7B | qwen2_audio |
|
||||
| [Qwen2.5-Omni](https://huggingface.co/Qwen) | 3B/7B | qwen2_omni |
|
||||
| [Qwen3-Omni](https://huggingface.co/Qwen) | 30B | qwen3_omni |
|
||||
@@ -473,7 +464,7 @@ huggingface-cli login
|
||||
|
||||
| Mandatory | Minimum | Recommend |
|
||||
| ------------ | ------- | --------- |
|
||||
| python | 3.9 | 3.10 |
|
||||
| python | 3.11 | >=3.11 |
|
||||
| torch | 2.0.0 | 2.6.0 |
|
||||
| torchvision | 0.15.0 | 0.21.0 |
|
||||
| transformers | 4.49.0 | 4.50.0 |
|
||||
@@ -660,10 +651,6 @@ See [examples/README.md](examples/README.md) for advanced usage (including distr
|
||||
llamafactory-cli webui
|
||||
```
|
||||
|
||||
### LLaMA Factory Online
|
||||
|
||||
Read our [documentation](https://docs.llamafactory.com.cn/docs/documents/quickstart/getstarted/?utm_source=LLaMA-Factory).
|
||||
|
||||
### Build Docker
|
||||
|
||||
For CUDA users:
|
||||
|
||||
21
README_zh.md
@@ -15,8 +15,6 @@
|
||||
|
||||
[](https://colab.research.google.com/drive/1d5KQtbemerlSDSxZIfAaWXhKr30QypiK?usp=sharing)
|
||||
[](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory)
|
||||
[](https://www.lab4ai.cn/course/detail?id=7c13e60f6137474eb40f6fd3983c0f46&utm_source=LLaMA-Factory)
|
||||
[](https://www.llamafactory.com.cn/?utm_source=LLaMA-Factory)
|
||||
[](https://huggingface.co/spaces/hiyouga/LLaMA-Board)
|
||||
[](https://modelscope.cn/studios/hiyouga/LLaMA-Board)
|
||||
[](https://novita.ai/templates-library/105981?sharer=88115474-394e-4bda-968e-b88e123d0c47)
|
||||
@@ -38,7 +36,7 @@
|
||||
|
||||
</div>
|
||||
|
||||
👋 加入我们的[微信群](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/main.jpg)、[NPU 用户群](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/npu.jpg)、[大模型实验室群](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/lab4ai.jpg) 或 [LLaMA Factory Online 用户群](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/online.png)。
|
||||
👋 加入我们的[微信群](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/main.jpg)和 [NPU 用户群](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/npu.jpg)。
|
||||
|
||||
\[ [English](README.md) | 中文 \]
|
||||
|
||||
@@ -52,8 +50,6 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
|
||||
开始云端训练:
|
||||
- **Colab(免费)**:https://colab.research.google.com/drive/1d5KQtbemerlSDSxZIfAaWXhKr30QypiK?usp=sharing
|
||||
- **PAI-DSW(免费试用)**:https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory
|
||||
- **LLaMA Factory Online(在线微调)**:https://www.llamafactory.com.cn/?utm_source=LLaMA-Factory
|
||||
- **九章智算云(算力优惠活动)**:https://docs.alayanew.com/docs/documents/useGuide/LLaMAFactory/mutiple/?utm_source=LLaMA-Factory
|
||||
|
||||
阅读技术文档:
|
||||
- **入门教程**:https://zhuanlan.zhihu.com/p/695287607
|
||||
@@ -61,7 +57,6 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
|
||||
- **框架文档**:https://llamafactory.readthedocs.io/zh-cn/latest/
|
||||
- **框架文档(昇腾 NPU)**:https://ascend.github.io/docs/sources/llamafactory/
|
||||
- **官方博客**:https://blog.llamafactory.net/
|
||||
- **官方课程**:https://www.lab4ai.cn/course/detail?id=7c13e60f6137474eb40f6fd3983c0f46&utm_source=LLaMA-Factory
|
||||
|
||||
> [!NOTE]
|
||||
> 除上述链接以外的其他网站均为未经许可的第三方网站,请小心甄别。
|
||||
@@ -80,7 +75,6 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
|
||||
- [数据准备](#数据准备)
|
||||
- [快速开始](#快速开始)
|
||||
- [LLaMA Board 可视化微调](#llama-board-可视化微调由-gradio-驱动)
|
||||
- [LLaMA Factory Online 在线微调](#llama-factory-online-在线微调)
|
||||
- [构建 Docker](#构建-docker)
|
||||
- [利用 vLLM 部署 OpenAI API](#利用-vllm-部署-openai-api)
|
||||
- [从魔搭社区下载](#从魔搭社区下载)
|
||||
@@ -119,15 +113,11 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
|
||||
|
||||
- 💡 [KTransformers Fine-Tuning × LLaMA Factory: 用2张4090级的GPU+CPU 微调 1000B规模的超大模型](https://swcil84qspu.feishu.cn/wiki/Z1sSwb2poijybxkyPEkcDG6enVc) (中文)
|
||||
- 💡 [Easy Dataset × LLaMA Factory: 让大模型高效学习领域知识](https://buaa-act.feishu.cn/wiki/KY9xwTGs1iqHrRkjXBwcZP9WnL9)(中文)
|
||||
- [使用 LLaMA-Factory 微调心理健康大模型](https://www.lab4ai.cn/project/detail?id=25cce32ec131497b9e06a93336a0817f&type=project&utm_source=LLaMA-Factory)(中文)
|
||||
- [使用 LLaMA-Factory 构建 GPT-OSS 角色扮演模型](https://docs.llamafactory.com.cn/docs/documents/best-practice/gptroleplay/?utm_source=LLaMA-Factory)(中文)
|
||||
- [基于 LLaMA-Factory 和 EasyR1 打造一站式无代码大模型强化学习和部署平台 LLM Model Hub](https://aws.amazon.com/cn/blogs/china/building-llm-model-hub-based-on-llamafactory-and-easyr1/)(中文)
|
||||
- [通过亚马逊 SageMaker HyperPod 上的 LLaMA-Factory 增强多模态模型银行文档的视觉信息提取](https://aws.amazon.com/cn/blogs/machine-learning/how-apoidea-group-enhances-visual-information-extraction-from-banking-documents-with-multimodal-models-using-llama-factory-on-amazon-sagemaker-hyperpod/)(英文)
|
||||
|
||||
<details><summary>全部博客</summary>
|
||||
|
||||
- [使用 LLaMA-Factory 微调 Llama3.1-70B 医学诊断模型](https://docs.alayanew.com/docs/documents/bestPractice/bigModel/llama70B/?utm_source=LLaMA-Factory)(中文)
|
||||
- [使用 LLaMA-Factory 微调 Qwen2.5-VL 实现自动驾驶场景微调](https://docs.alayanew.com/docs/documents/useGuide/LLaMAFactory/mutiple/?utm_source=LLaMA-Factory)(中文)
|
||||
- [LLaMA Factory:微调 DeepSeek-R1-Distill-Qwen-7B 模型实现新闻标题分类器](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory_deepseek_r1_distill_7b)(中文)
|
||||
- [基于 Amazon SageMaker 和 LLaMA-Factory 打造一站式无代码模型微调部署平台 Model Hub](https://aws.amazon.com/cn/blogs/china/a-one-stop-code-free-model-fine-tuning-deployment-platform-based-on-sagemaker-and-llama-factory/)(中文)
|
||||
- [LLaMA Factory 多模态微调实践:微调 Qwen2-VL 构建文旅大模型](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory_qwen2vl)(中文)
|
||||
@@ -321,7 +311,8 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
|
||||
| [Pixtral](https://huggingface.co/mistralai) | 12B | pixtral |
|
||||
| [Qwen2 (Code/Math/MoE/QwQ)](https://huggingface.co/Qwen) | 0.5B/1.5B/3B/7B/14B/32B/72B/110B | qwen |
|
||||
| [Qwen3 (MoE/Instruct/Thinking/Next)](https://huggingface.co/Qwen) | 0.6B/1.7B/4B/8B/14B/32B/80B/235B | qwen3/qwen3_nothink |
|
||||
| [Qwen3.5](https://huggingface.co/Qwen) | 27B/35B/122B/397B | qwen3_5 |
|
||||
| [Qwen3.5](https://huggingface.co/Qwen) | 0.8B/2B/4B/9B/27B/35B/122B/397B | qwen3_5/qwen3_5_nothink |
|
||||
| [Qwen3.6](https://huggingface.co/Qwen) | 27B/35B | qwen3_6 |
|
||||
| [Qwen2-Audio](https://huggingface.co/Qwen) | 7B | qwen2_audio |
|
||||
| [Qwen2.5-Omni](https://huggingface.co/Qwen) | 3B/7B | qwen2_omni |
|
||||
| [Qwen3-Omni](https://huggingface.co/Qwen) | 30B | qwen3_omni |
|
||||
@@ -475,7 +466,7 @@ huggingface-cli login
|
||||
|
||||
| 必需项 | 至少 | 推荐 |
|
||||
| ------------ | ------- | --------- |
|
||||
| python | 3.9 | 3.10 |
|
||||
| python | 3.11 | >=3.11 |
|
||||
| torch | 2.0.0 | 2.6.0 |
|
||||
| torchvision | 0.15.0 | 0.21.0 |
|
||||
| transformers | 4.49.0 | 4.50.0 |
|
||||
@@ -661,10 +652,6 @@ llamafactory-cli export examples/merge_lora/qwen3_lora_sft.yaml
|
||||
llamafactory-cli webui
|
||||
```
|
||||
|
||||
### LLaMA Factory Online 在线微调
|
||||
|
||||
详情阅读该[文档](https://docs.llamafactory.com.cn/docs/documents/quickstart/getstarted/?utm_source=LLaMA-Factory)。
|
||||
|
||||
### 构建 Docker
|
||||
|
||||
CUDA 用户:
|
||||
|
||||
@@ -236,6 +236,13 @@
|
||||
"ms_hub_url": "AI-ModelScope/sharegpt_gpt4",
|
||||
"formatting": "sharegpt"
|
||||
},
|
||||
"sgsc_b2b_entities": {
|
||||
"hf_hub_url": "Nooxus-AI/NOO-Verified-Global-Entities",
|
||||
"formatting": "sharegpt",
|
||||
"columns": {
|
||||
"messages": "messages"
|
||||
}
|
||||
},
|
||||
"ultrachat_200k": {
|
||||
"hf_hub_url": "HuggingFaceH4/ultrachat_200k",
|
||||
"ms_hub_url": "AI-ModelScope/ultrachat_200k",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# https://hub.docker.com/r/ascendai/cann/tags
|
||||
|
||||
ARG BASE_IMAGE=quay.io/ascend/cann:8.3.rc2-910b-ubuntu22.04-py3.11
|
||||
ARG BASE_IMAGE=quay.io/ascend/cann:9.0.0-910b-ubuntu22.04-py3.11
|
||||
FROM ${BASE_IMAGE}
|
||||
|
||||
# Installation arguments
|
||||
@@ -33,9 +33,11 @@ RUN pip config set global.index-url "${PIP_INDEX}" && \
|
||||
COPY . /app
|
||||
|
||||
# Install torch-npu
|
||||
RUN pip uninstall -y torch torchvision torchaudio && \
|
||||
pip install --no-cache-dir "torch==2.7.1" "torch-npu==2.7.1" "torchvision==0.22.1" "torchaudio==2.7.1" --index-url "${PYTORCH_INDEX}" && \
|
||||
pip install --no-cache-dir -e . --no-build-isolation && \
|
||||
RUN source /usr/local/Ascend/ascend-toolkit/set_env.sh
|
||||
RUN pip uninstall -y torch torchvision torchaudio
|
||||
RUN pip install --no-cache-dir -r requirements/npu.txt --index-url "${PYTORCH_INDEX}"
|
||||
RUN pip install --no-cache-dir -r requirements/deepspeed.txt
|
||||
RUN pip install --no-cache-dir -e . --no-build-isolation && \
|
||||
pip install --no-cache-dir -r requirements/metrics.txt --no-build-isolation
|
||||
|
||||
# Set up volumes
|
||||
|
||||
@@ -33,7 +33,7 @@ services:
|
||||
dockerfile: ./docker/docker-npu/Dockerfile
|
||||
context: ../..
|
||||
args:
|
||||
BASE_IMAGE: quay.io/ascend/cann:8.3.rc2-a3-ubuntu22.04-py3.11
|
||||
BASE_IMAGE: quay.io/ascend/cann:9.0.0-a3-ubuntu22.04-py3.11
|
||||
PIP_INDEX: https://pypi.org/simple
|
||||
container_name: llamafactory-a3
|
||||
image: llamafactory:npu-a3
|
||||
|
||||
@@ -96,7 +96,7 @@ FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500
|
||||
|
||||
### 支持弹性和容错的多机指令监督微调
|
||||
|
||||
要启动一个支持弹性节点和容错的多机指令微调,在每个节点上执行以下命令。弹性节点数量范围为 `MIN_NNODES:MAX_NNODES`,每个节点最多允许因为错误重启 `MAX_RESTARTS` 次。`RDZV_ID` 应设置为一个唯一的作业 ID(由参与该作业的所有节点共享)。更多新可以参考官方文档 [torchrun](https://docs.pytorch.org/docs/stable/elastic/run.html)。
|
||||
要启动一个支持弹性节点和容错的多机指令微调,在每个节点上执行以下命令。弹性节点数量范围为 `MIN_NNODES:MAX_NNODES`,每个节点最多允许因为错误重启 `MAX_RESTARTS` 次。`RDZV_ID` 应设置为一个唯一的作业 ID(由参与该作业的所有节点共享)。更多细节可以参考官方文档 [torchrun](https://docs.pytorch.org/docs/stable/elastic/run.html)。
|
||||
|
||||
```bash
|
||||
FORCE_TORCHRUN=1 MIN_NNODES=1 MAX_NNODES=3 MAX_RESTARTS=3 RDZV_ID=llamafactory MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/qwen3_full_sft.yaml
|
||||
|
||||
20
examples/accelerate/fsdp2_config_qwen35.yaml
Normal file
@@ -0,0 +1,20 @@
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
fsdp_config:
  fsdp_version: 2
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_transformer_layer_cls_to_wrap: Qwen3_5DecoderLayer,Qwen3_5VisionBlock
  fsdp_cpu_ram_efficient_loading: true
  fsdp_offload_params: false
  fsdp_reshard_after_forward: true
  fsdp_state_dict_type: FULL_STATE_DICT
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 8 # Change to match your NPU count (e.g., 8 for A2, 16 for A3)
rdzv_backend: static
same_network: true
use_cpu: false
|
||||
47
examples/ascend/qwen3_5_full_sft_fsdp2.yaml
Normal file
@@ -0,0 +1,47 @@
# Start FSDP2 full fine-tuning on Ascend NPU
# Usage:
# accelerate launch \
# --config_file examples/accelerate/fsdp2_config_qwen35.yaml \
# src/train.py examples/ascend/qwen3_5_full_sft_fsdp2.yaml
#
# Note: Change `num_processes` in fsdp2_config_qwen35.yaml to match your NPU count

### model
model_name_or_path: Qwen/Qwen3.5-4B
trust_remote_code: true
use_v1_kernels: true
flash_attn: fa2

### method
stage: sft
do_train: true
finetuning_type: full

### dataset
dataset: alpaca_en_demo
template: qwen3_5_nothink
cutoff_len: 2048
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16
dataloader_num_workers: 4

### output
output_dir: saves/Qwen3.5-4B/full/sft
logging_steps: 1
save_steps: 500
max_steps: 500
plot_loss: true
overwrite_output_dir: true
save_only_model: false
report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]

### train
per_device_train_batch_size: 8
gradient_accumulation_steps: 1
learning_rate: 1.0e-5
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
ddp_timeout: 1800
resume_from_checkpoint: null
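
Because the note in this file asks readers to change `num_processes`, it may help to see how that value interacts with the batch settings. The sketch below works out the effective global batch size, assuming plain data parallelism in which every process consumes its own micro-batch on each step; the helper names are hypothetical, and the numbers come from this file and `fsdp2_config_qwen35.yaml`.

```python
def global_batch_size(per_device_batch, grad_accum_steps, num_processes):
    """Samples contributing to one optimizer step across all processes."""
    return per_device_batch * grad_accum_steps * num_processes


def samples_consumed(max_steps, batch_size):
    """Total samples drawn over a run of max_steps optimizer steps."""
    return max_steps * batch_size


if __name__ == "__main__":
    # per_device_train_batch_size=8, gradient_accumulation_steps=1, num_processes=8
    batch = global_batch_size(8, 1, 8)
    print(batch)                         # 64
    print(samples_consumed(500, batch))  # 32000
    # The same settings with num_processes=16
    print(global_batch_size(8, 1, 16))   # 128
```

Doubling `num_processes` doubles the global batch size unless `per_device_train_batch_size` or `gradient_accumulation_steps` is lowered to compensate, which may matter when comparing runs across different device counts.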
|
||||
25
examples/ktransformers/accelerate/fsdp2_kt_bf16.yaml
Normal file
@@ -0,0 +1,25 @@
|
||||
compute_environment: LOCAL_MACHINE
|
||||
distributed_type: FSDP
|
||||
fsdp_config:
|
||||
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||
fsdp_cpu_ram_efficient_loading: true
|
||||
fsdp_offload_params: false
|
||||
fsdp_reshard_after_forward: true
|
||||
fsdp_state_dict_type: FULL_STATE_DICT
|
||||
fsdp_version: 2
|
||||
mixed_precision: bf16
|
||||
num_machines: 1
|
||||
num_processes: 4 # Adjust based on your GPU count; 4 is suitable for 4 GPUs
|
||||
rdzv_backend: static
|
||||
same_network: true
|
||||
use_cpu: false
|
||||
|
||||
kt_config:
|
||||
enabled: true
|
||||
kt_backend: AMXBF16 # Use with original BF16 expert weights.
|
||||
kt_num_threads: 96
|
||||
kt_tp_enabled: true
|
||||
kt_threadpool_count: 2
|
||||
kt_max_cache_depth: 2
|
||||
kt_share_backward_bb: true
|
||||
lora_rank: 8
|
||||
25
examples/ktransformers/accelerate/fsdp2_kt_int4.yaml
Normal file
@@ -0,0 +1,25 @@
|
||||
compute_environment: LOCAL_MACHINE
|
||||
distributed_type: FSDP
|
||||
fsdp_config:
|
||||
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||
fsdp_cpu_ram_efficient_loading: true
|
||||
fsdp_offload_params: false
|
||||
fsdp_reshard_after_forward: true
|
||||
fsdp_state_dict_type: FULL_STATE_DICT
|
||||
fsdp_version: 2
|
||||
mixed_precision: bf16
|
||||
num_machines: 1
|
||||
num_processes: 4 # Adjust based on your GPU count; 4 is suitable for 4 GPUs
|
||||
rdzv_backend: static
|
||||
same_network: true
|
||||
use_cpu: false
|
||||
|
||||
kt_config:
|
||||
enabled: true
|
||||
kt_backend: AMXINT4 # Use with online-converted INT4 expert weights
|
||||
kt_num_threads: 96
|
||||
kt_tp_enabled: true
|
||||
kt_threadpool_count: 2
|
||||
kt_max_cache_depth: 2
|
||||
kt_share_backward_bb: true
|
||||
lora_rank: 8
|
||||
25
examples/ktransformers/accelerate/fsdp2_kt_int8.yaml
Normal file
@@ -0,0 +1,25 @@
|
||||
compute_environment: LOCAL_MACHINE
|
||||
distributed_type: FSDP
|
||||
fsdp_config:
|
||||
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||
fsdp_cpu_ram_efficient_loading: true
|
||||
fsdp_offload_params: false
|
||||
fsdp_reshard_after_forward: true
|
||||
fsdp_state_dict_type: FULL_STATE_DICT
|
||||
fsdp_version: 2
|
||||
mixed_precision: bf16
|
||||
num_machines: 1
|
||||
num_processes: 4 # Adjust based on your GPU count; 4 is suitable for 4 GPUs
|
||||
rdzv_backend: static
|
||||
same_network: true
|
||||
use_cpu: false
|
||||
|
||||
kt_config:
|
||||
enabled: true
|
||||
kt_backend: AMXINT8 # Use with online-converted INT8 expert weights
|
||||
kt_num_threads: 96
|
||||
kt_tp_enabled: true
|
||||
kt_threadpool_count: 2
|
||||
kt_max_cache_depth: 2
|
||||
kt_share_backward_bb: true
|
||||
lora_rank: 8
|
||||
25
examples/ktransformers/accelerate/fsdp2_kt_int8_1gpu.yaml
Normal file
@@ -0,0 +1,25 @@
|
||||
compute_environment: LOCAL_MACHINE
|
||||
distributed_type: FSDP
|
||||
fsdp_config:
|
||||
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||
fsdp_cpu_ram_efficient_loading: true
|
||||
fsdp_offload_params: false
|
||||
fsdp_reshard_after_forward: true
|
||||
fsdp_state_dict_type: FULL_STATE_DICT
|
||||
fsdp_version: 2
|
||||
mixed_precision: bf16
|
||||
num_machines: 1
|
||||
num_processes: 1 # Adjust based on your GPU count; 1 is suitable for 1 GPU
|
||||
rdzv_backend: static
|
||||
same_network: true
|
||||
use_cpu: false
|
||||
|
||||
kt_config:
|
||||
enabled: true
|
||||
kt_backend: AMXINT8 # Use with online-converted INT8 expert weights
|
||||
kt_num_threads: 96
|
||||
kt_tp_enabled: true
|
||||
kt_threadpool_count: 2
|
||||
kt_max_cache_depth: 2
|
||||
kt_share_backward_bb: true
|
||||
lora_rank: 8
|
||||
25
examples/ktransformers/accelerate/fsdp2_kt_int8_8gpu.yaml
Normal file
@@ -0,0 +1,25 @@
|
||||
compute_environment: LOCAL_MACHINE
|
||||
distributed_type: FSDP
|
||||
fsdp_config:
|
||||
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||
fsdp_cpu_ram_efficient_loading: true
|
||||
fsdp_offload_params: false
|
||||
fsdp_reshard_after_forward: true
|
||||
fsdp_state_dict_type: FULL_STATE_DICT
|
||||
fsdp_version: 2
|
||||
mixed_precision: bf16
|
||||
num_machines: 1
|
||||
num_processes: 8 # Adjust based on your GPU count; 8 is suitable for 8 GPUs
|
||||
rdzv_backend: static
|
||||
same_network: true
|
||||
use_cpu: false
|
||||
|
||||
kt_config:
|
||||
enabled: true
|
||||
kt_backend: AMXINT8 # Use with online-converted INT8 expert weights
|
||||
kt_num_threads: 96
|
||||
kt_tp_enabled: true
|
||||
kt_threadpool_count: 2
|
||||
kt_max_cache_depth: 2
|
||||
kt_share_backward_bb: true
|
||||
lora_rank: 8
|
||||
@@ -1,10 +0,0 @@
|
||||
model_name_or_path: deepseek-ai/DeepSeek-V2-Lite
|
||||
adapter_name_or_path: saves/Kllama_deepseekV2
|
||||
template: deepseek
|
||||
infer_backend: ktransformers # choices: [huggingface, vllm, sglang, ktransformers]
|
||||
trust_remote_code: true
|
||||
|
||||
use_kt: true # use KTransformers as the LoRA SFT backend for inference
|
||||
kt_optimize_rule: examples/ktransformers/kt_optimize_rules/DeepSeek-V2-Lite-Chat-sft-amx.yaml
|
||||
cpu_infer: 32
|
||||
chunk_size: 8192
|
||||
@@ -1,9 +0,0 @@
|
||||
model_name_or_path: opensourcerelease/DeepSeek-V3-bf16
|
||||
template: deepseek3
|
||||
infer_backend: ktransformers # choices: [huggingface, vllm, sglang, ktransformers]
|
||||
trust_remote_code: true
|
||||
|
||||
use_kt: true # use KTransformers as the LoRA SFT backend for inference
|
||||
kt_optimize_rule: examples/ktransformers/kt_optimize_rules/DeepSeek-V3-Chat-sft-amx-multi-gpu.yaml
|
||||
cpu_infer: 32
|
||||
chunk_size: 8192
|
||||
@@ -1,10 +0,0 @@
|
||||
model_name_or_path: opensourcerelease/DeepSeek-V3-bf16
|
||||
adapter_name_or_path: saves/Kllama_deepseekV3
|
||||
template: deepseek3
|
||||
infer_backend: ktransformers # choices: [huggingface, vllm, sglang, ktransformers]
|
||||
trust_remote_code: true
|
||||
|
||||
use_kt: true # use KTransformers as the LoRA SFT backend for inference
|
||||
kt_optimize_rule: examples/ktransformers/kt_optimize_rules/DeepSeek-V3-Chat-sft-amx-multi-gpu.yaml
|
||||
cpu_infer: 32
|
||||
chunk_size: 8192
|
||||
@@ -1,10 +0,0 @@
|
||||
model_name_or_path: Qwen/Qwen3-235B-A22B-Instruct-2507
|
||||
adapter_name_or_path: saves/Kllama_Qwen3MoE_235bA22b
|
||||
template: qwen3_nothink
|
||||
infer_backend: ktransformers # choices: [huggingface, vllm, sglang, ktransformers]
|
||||
trust_remote_code: true
|
||||
|
||||
use_kt: true # use KTransformers as the LoRA SFT backend for inference
|
||||
kt_optimize_rule: examples/ktransformers/kt_optimize_rules/Qwen3Moe-sft-amx.yaml
|
||||
cpu_infer: 32
|
||||
chunk_size: 8192
|
||||
@@ -1,69 +0,0 @@
|
||||
- match:
|
||||
class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding
|
||||
replace:
|
||||
class: ktransformers.operators.RoPE.YarnRotaryEmbedding
|
||||
kwargs:
|
||||
generate_device: "cuda"
|
||||
prefill_device: "cuda"
|
||||
- match:
|
||||
name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression
|
||||
class: torch.nn.Linear # only match modules matching name and class simultaneously
|
||||
replace:
|
||||
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
|
||||
kwargs:
|
||||
generate_device: "cuda"
|
||||
prefill_device: "cuda"
|
||||
generate_op: "KLinearTorch"
|
||||
prefill_op: "KLinearTorch"
|
||||
|
||||
- match:
|
||||
name: "^lm_head"
|
||||
class: torch.nn.Linear
|
||||
replace:
|
||||
class: ktransformers.operators.linear.KTransformersLinear
|
||||
kwargs:
|
||||
generate_device: "cuda"
|
||||
prefill_device: "cuda"
|
||||
generate_op: "KLinearTorch"
|
||||
prefill_op: "KLinearTorch"
|
||||
|
||||
- match:
|
||||
name: "^model\\.layers\\..*\\.mlp$"
|
||||
class: ktransformers.models.modeling_deepseek.DeepseekV2MoE
|
||||
replace:
|
||||
class: ktransformers.operators.experts.KDeepseekV2MoE # mlp module with custom forward function
|
||||
kwargs:
|
||||
generate_device: "cuda"
|
||||
prefill_device: "cuda"
|
||||
- match:
|
||||
name: "^model\\.layers\\..*\\.mlp\\.experts$"
|
||||
replace:
|
||||
class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert parallelism
|
||||
kwargs:
|
||||
prefill_device: "cuda"
|
||||
prefill_op: "KExpertsTorch"
|
||||
generate_device: "cpu"
|
||||
generate_op: "KSFTExpertsCPU"
|
||||
out_device: "cuda"
|
||||
backend: "AMXInt8" # or "AMXBF16" or "llamafile" (default)
|
||||
recursive: False # don't recursively inject submodules of this module
|
||||
- match:
|
||||
name: "^model\\.layers\\..*\\.self_attn$"
|
||||
replace:
|
||||
class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation
|
||||
kwargs:
|
||||
generate_device: "cuda"
|
||||
prefill_device: "cuda"
|
||||
- match:
|
||||
name: "^model$"
|
||||
replace:
|
||||
class: "ktransformers.operators.models.KDeepseekV2Model"
|
||||
kwargs:
|
||||
per_layer_prefill_intput_threshold: 0 # 0 disables layer-wise prefill
|
||||
- match:
|
||||
name: "^model.embed_tokens"
|
||||
replace:
|
||||
class: "default"
|
||||
kwargs:
|
||||
generate_device: "cpu"
|
||||
prefill_device: "cpu"
|
||||
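The `name` rule for linear layers in the file above relies on a negative lookahead to skip `self_attn.kv_b_proj`. The following is a minimal sketch of how that pattern behaves under Python's standard `re` semantics. The module names are hypothetical examples, and the sketch only checks the name; the real rule also requires the module class to be `torch.nn.Linear`.

```python
import re

# Pattern copied from the rule above, with YAML's "\\" unescaped to "\".
LINEAR_RULE = re.compile(r"^model\.layers\.(?!.*self_attn\.kv_b_proj).*$")

# Hypothetical module names, chosen only to exercise the pattern.
candidates = [
    "model.layers.0.self_attn.q_proj",
    "model.layers.0.self_attn.kv_b_proj",
    "model.layers.5.mlp.shared_experts.gate_proj",
    "lm_head",
]


def is_replaced(name: str) -> bool:
    """Return True if the name passes the linear-layer name rule."""
    return LINEAR_RULE.match(name) is not None


matched = [name for name in candidates if is_replaced(name)]
print(matched)
```

Under these assumptions, `kv_b_proj` is rejected by the lookahead and `lm_head` is rejected by the `^model\.layers\.` prefix, which is why `lm_head` gets its own rule.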
@@ -1,68 +0,0 @@
|
||||
- match:
|
||||
class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding
|
||||
replace:
|
||||
class: ktransformers.operators.RoPE.YarnRotaryEmbedding
|
||||
kwargs:
|
||||
generate_device: "cuda"
|
||||
prefill_device: "cuda"
|
||||
- match:
|
||||
name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression
|
||||
class: torch.nn.Linear # only match modules matching name and class simultaneously
|
||||
replace:
|
||||
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
|
||||
kwargs:
|
||||
generate_device: "cuda"
|
||||
prefill_device: "cuda"
|
||||
generate_op: "KLinearMarlin"
|
||||
prefill_op: "KLinearTorch"
|
||||
|
||||
- match:
|
||||
name: "^lm_head"
|
||||
class: torch.nn.Linear
|
||||
replace:
|
||||
class: ktransformers.operators.linear.KTransformersLinear
|
||||
kwargs:
|
||||
generate_device: "cuda"
|
||||
prefill_device: "cuda"
|
||||
generate_op: "KLinearMarlin"
|
||||
prefill_op: "KLinearTorch"
|
||||
|
||||
- match:
|
||||
name: "^model\\.layers\\..*\\.mlp$"
|
||||
class: ktransformers.models.modeling_deepseek.DeepseekV2MoE
|
||||
replace:
|
||||
class: ktransformers.operators.experts.KDeepseekV2MoE # mlp module with custom forward function
|
||||
kwargs:
|
||||
generate_device: "cuda"
|
||||
prefill_device: "cuda"
|
||||
- match:
|
||||
name: "^model\\.layers\\..*\\.mlp\\.experts$"
|
||||
replace:
|
||||
class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert parallelism
|
||||
kwargs:
|
||||
prefill_device: "cuda"
|
||||
prefill_op: "KExpertsTorch"
|
||||
generate_device: "cpu"
|
||||
generate_op: "KExpertsCPU"
|
||||
out_device: "cuda"
|
||||
recursive: False # don't recursively inject submodules of this module
|
||||
- match:
|
||||
name: "^model\\.layers\\..*\\.self_attn$"
|
||||
replace:
|
||||
class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation
|
||||
kwargs:
|
||||
generate_device: "cuda"
|
||||
prefill_device: "cuda"
|
||||
- match:
|
||||
name: "^model$"
|
||||
replace:
|
||||
class: "ktransformers.operators.models.KDeepseekV2Model"
|
||||
kwargs:
|
||||
per_layer_prefill_intput_threshold: 0 # 0 disables layer-wise prefill
|
||||
- match:
|
||||
name: "^model.embed_tokens"
|
||||
replace:
|
||||
class: "default"
|
||||
kwargs:
|
||||
generate_device: "cpu"
|
||||
prefill_device: "cpu"
|
||||
@@ -1,139 +0,0 @@
|
||||
- match:
|
||||
name: "^model.embed_tokens"
|
||||
replace:
|
||||
class: "default"
|
||||
kwargs:
|
||||
generate_device: "cpu"
|
||||
prefill_device: "cpu"
|
||||
|
||||
- match:
|
||||
name: "^model\\.layers\\.(0|[1-9])\\."
|
||||
class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding
|
||||
replace:
|
||||
class: ktransformers.operators.RoPE.YarnRotaryEmbedding
|
||||
kwargs:
|
||||
generate_device: "cuda:0"
|
||||
prefill_device: "cuda:0"
|
||||
- match:
|
||||
name: "^model\\.layers\\.([12][0-9])\\."
|
||||
class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding
|
||||
replace:
|
||||
class: ktransformers.operators.RoPE.YarnRotaryEmbedding
|
||||
kwargs:
|
||||
generate_device: "cuda:1"
|
||||
prefill_device: "cuda:1"
|
||||
|
||||
- match:
|
||||
name: "^model\\.layers\\.(0|[1-9])\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression
|
||||
class: torch.nn.Linear # only match modules matching name and class simultaneously
|
||||
replace:
|
||||
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
|
||||
kwargs:
|
||||
generate_device: "cuda:0"
|
||||
prefill_device: "cuda:0"
|
||||
generate_op: "KLinearTorch"
|
||||
prefill_op: "KLinearTorch"
|
||||
|
||||
- match:
|
||||
name: "^model\\.layers\\.([12][0-9])\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression
|
||||
class: torch.nn.Linear # only match modules matching name and class simultaneously
|
||||
replace:
|
||||
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
|
||||
kwargs:
|
||||
generate_device: "cuda:1"
|
||||
prefill_device: "cuda:1"
|
||||
generate_op: "KLinearTorch"
|
||||
prefill_op: "KLinearTorch"
|
||||
|
||||
- match:
|
||||
name: "^model\\.layers\\.(0|[1-9])\\.mlp$"
|
||||
class: ktransformers.models.modeling_deepseek.DeepseekV2MoE
|
||||
replace:
|
||||
class: ktransformers.operators.experts.KDeepseekV2MoE # mlp module with custom forward function
|
||||
kwargs:
|
||||
generate_device: "cuda:0"
|
||||
prefill_device: "cuda:0"
|
||||
- match:
|
||||
name: "^model\\.layers\\.([12][0-9])\\.mlp$"
|
||||
class: ktransformers.models.modeling_deepseek.DeepseekV2MoE
|
||||
replace:
|
||||
class: ktransformers.operators.experts.KDeepseekV2MoE # mlp module with custom forward function
|
||||
kwargs:
|
||||
generate_device: "cuda:1"
|
||||
prefill_device: "cuda:1"
|
||||
|
||||
- match:
|
||||
name: "^model\\.layers\\.(0|[1-9])\\.mlp\\.experts$"
|
||||
replace:
|
||||
class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert parallelism
|
||||
kwargs:
|
||||
prefill_device: "cuda:0"
|
||||
prefill_op: "KExpertsTorch"
|
||||
generate_device: "cpu"
|
||||
generate_op: "KSFTExpertsCPU"
|
||||
out_device: "cuda:0"
|
||||
backend: "AMXInt8" # or "AMXBF16" or "llamafile" (default)
|
||||
recursive: False # don't recursively inject submodules of this module
|
||||
|
||||
- match:
|
||||
name: "^model\\.layers\\.([12][0-9])\\.mlp\\.experts$"
|
||||
replace:
|
||||
class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert parallelism
|
||||
kwargs:
|
||||
prefill_device: "cuda:1"
|
||||
prefill_op: "KExpertsTorch"
|
||||
generate_device: "cpu"
|
||||
generate_op: "KSFTExpertsCPU"
|
||||
out_device: "cuda:1"
|
||||
backend: "AMXInt8" # or "AMXBF16" or "llamafile" (default)
|
||||
recursive: False # don't recursively inject submodules of this module
|
||||
|
||||
- match:
|
||||
name: "^model\\.layers\\.(0|[1-9])\\.self_attn$"
|
||||
replace:
|
||||
class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation
|
||||
kwargs:
|
||||
generate_device: "cuda:0"
|
||||
prefill_device: "cuda:0"
|
||||
- match:
|
||||
name: "^model\\.layers\\.([12][0-9])\\.self_attn$"
|
||||
replace:
|
||||
class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation
|
||||
kwargs:
|
||||
generate_device: "cuda:1"
|
||||
prefill_device: "cuda:1"
|
||||
- match:
|
||||
name: "^model$"
|
||||
replace:
|
||||
class: "ktransformers.operators.models.KDeepseekV2Model"
|
||||
kwargs:
|
||||
per_layer_prefill_intput_threshold: 0 # 0 disables layer-wise prefill
|
||||
transfer_map:
|
||||
10: "cuda:1"
|
||||
|
||||
- match:
|
||||
name: "^model\\.layers\\.(0|[1-9])\\."
|
||||
replace:
|
||||
class: "default"
|
||||
kwargs:
|
||||
generate_device: "cuda:0"
|
||||
prefill_device: "cuda:0"
|
||||
|
||||
- match:
|
||||
name: "^lm_head"
|
||||
class: torch.nn.Linear
|
||||
replace:
|
||||
class: ktransformers.operators.linear.KTransformersLinear
|
||||
kwargs:
|
||||
generate_device: "cuda:1"
|
||||
prefill_device: "cuda:1"
|
||||
generate_op: "KLinearTorch"
|
||||
prefill_op: "KLinearTorch"
|
||||
|
||||
- match:
|
||||
name: "(^model\\.layers\\.([12][0-9])\\.)|(model.norm)"
|
||||
replace:
|
||||
class: "default"
|
||||
kwargs:
|
||||
generate_device: "cuda:1"
|
||||
prefill_device: "cuda:1"
|
||||
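The two-GPU rule file above splits layers by regex: `(0|[1-9])` for `cuda:0` and `([12][0-9])` for `cuda:1`, with `transfer_map` switching devices at layer 10. The following is a rough sketch for checking that each layer index matches exactly one device pattern. The layer count of 27 and the probe module name are assumptions for illustration, not values taken from the diff.

```python
import re

# Patterns copied from the rule file above, with YAML's "\\" unescaped to "\".
DEVICE_PATTERNS = {
    "cuda:0": re.compile(r"^model\.layers\.(0|[1-9])\."),
    "cuda:1": re.compile(r"^model\.layers\.([12][0-9])\."),
}

NUM_LAYERS = 27  # assumption: layer count of the target model


def device_for_layer(layer_idx: int) -> str:
    """Return the single device whose pattern matches this layer index."""
    # Hypothetical module name used only to probe the patterns.
    name = f"model.layers.{layer_idx}.self_attn.q_proj"
    hits = [dev for dev, pat in DEVICE_PATTERNS.items() if pat.search(name)]
    if len(hits) != 1:
        raise ValueError(f"layer {layer_idx} matched {len(hits)} device patterns")
    return hits[0]


assignment = {i: device_for_layer(i) for i in range(NUM_LAYERS)}
first_on_gpu1 = min(i for i, dev in assignment.items() if dev == "cuda:1")
print(first_on_gpu1)
```

The first layer assigned to `cuda:1` should agree with the key in `transfer_map`. A layer index of 30 or above would match neither pattern and raise, which is a cheap way to catch a rule file reused for a deeper model.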
@@ -1,69 +0,0 @@
|
||||
- match:
|
||||
class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding
|
||||
replace:
|
||||
class: ktransformers.operators.RoPE.YarnRotaryEmbedding
|
||||
kwargs:
|
||||
generate_device: "cuda"
|
||||
prefill_device: "cuda"
|
||||
- match:
|
||||
name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression
|
||||
class: torch.nn.Linear # only match modules matching name and class simultaneously
|
||||
replace:
|
||||
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
|
||||
kwargs:
|
||||
generate_device: "cuda"
|
||||
prefill_device: "cuda"
|
||||
generate_op: "KLinearTorch"
|
||||
prefill_op: "KLinearTorch"
|
||||
|
||||
- match:
|
||||
name: "^lm_head"
|
||||
class: torch.nn.Linear
|
||||
replace:
|
||||
class: ktransformers.operators.linear.KTransformersLinear
|
||||
kwargs:
|
||||
generate_device: "cuda"
|
||||
prefill_device: "cuda"
|
||||
generate_op: "KLinearTorch"
|
||||
prefill_op: "KLinearTorch"
|
||||
|
||||
- match:
|
||||
name: "^model\\.layers\\..*\\.mlp$"
|
||||
class: ktransformers.models.modeling_deepseek.DeepseekV2MoE
|
||||
replace:
|
||||
class: ktransformers.operators.experts.KDeepseekV2MoE # mlp module with custom forward function
|
||||
kwargs:
|
||||
generate_device: "cuda"
|
||||
prefill_device: "cuda"
|
||||
- match:
|
||||
name: "^model\\.layers\\..*\\.mlp\\.experts$"
|
||||
replace:
|
||||
class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert parallelism
|
||||
kwargs:
|
||||
prefill_device: "cpu"
|
||||
prefill_op: "KExpertsTorch"
|
||||
generate_device: "cpu"
|
||||
generate_op: "KSFTExpertsCPU"
|
||||
out_device: "cuda"
|
||||
backend: "AMXInt8" # or "AMXBF16" or "llamafile" (default)
|
||||
recursive: False # don't recursively inject submodules of this module
|
||||
- match:
|
||||
name: "^model\\.layers\\..*\\.self_attn$"
|
||||
replace:
|
||||
class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation
|
||||
kwargs:
|
||||
generate_device: "cuda"
|
||||
prefill_device: "cuda"
|
||||
- match:
|
||||
name: "^model$"
|
||||
replace:
|
||||
class: "ktransformers.operators.models.KDeepseekV2Model"
|
||||
kwargs:
|
||||
per_layer_prefill_intput_threshold: 0 # 0 disables layer-wise prefill
|
||||
- match:
|
||||
name: "^model.embed_tokens"
|
||||
replace:
|
||||
class: "default"
|
||||
kwargs:
|
||||
generate_device: "cpu"
|
||||
prefill_device: "cpu"
|
||||
@@ -1,68 +0,0 @@
|
||||
- match:
|
||||
class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding
|
||||
replace:
|
||||
class: ktransformers.operators.RoPE.YarnRotaryEmbedding
|
||||
kwargs:
|
||||
generate_device: "cuda"
|
||||
prefill_device: "cuda"
|
||||
- match:
|
||||
name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression
|
||||
class: torch.nn.Linear # only match modules matching name and class simultaneously
|
||||
replace:
|
||||
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
|
||||
kwargs:
|
||||
generate_device: "cuda"
|
||||
prefill_device: "cuda"
|
||||
generate_op: "KLinearTorch"
|
||||
prefill_op: "KLinearTorch"
|
||||
|
||||
- match:
|
||||
name: "^lm_head"
|
||||
class: torch.nn.Linear
|
||||
replace:
|
||||
class: ktransformers.operators.linear.KTransformersLinear
|
||||
kwargs:
|
||||
generate_device: "cuda"
|
||||
prefill_device: "cuda"
|
||||
generate_op: "KLinearTorch"
|
||||
prefill_op: "KLinearTorch"
|
||||
|
||||
- match:
|
||||
name: "^model\\.layers\\..*\\.mlp$"
|
||||
class: ktransformers.models.modeling_deepseek.DeepseekV2MoE
|
||||
replace:
|
||||
class: ktransformers.operators.experts.KDeepseekV2MoE # mlp module with custom forward function
|
||||
kwargs:
|
||||
generate_device: "cuda"
|
||||
prefill_device: "cuda"
|
||||
- match:
|
||||
name: "^model\\.layers\\..*\\.mlp\\.experts$"
|
||||
replace:
|
||||
class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert parallelism
|
||||
kwargs:
|
||||
prefill_device: "cpu"
|
||||
prefill_op: "KExpertsTorch"
|
||||
generate_device: "cpu"
|
||||
generate_op: "KSFTExpertsCPU"
|
||||
out_device: "cuda"
|
||||
recursive: False # don't recursively inject submodules of this module
|
||||
- match:
|
||||
name: "^model\\.layers\\..*\\.self_attn$"
|
||||
replace:
|
||||
class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation
|
||||
kwargs:
|
||||
generate_device: "cuda"
|
||||
prefill_device: "cuda"
|
||||
- match:
|
||||
name: "^model$"
|
||||
replace:
|
||||
class: "ktransformers.operators.models.KDeepseekV2Model"
|
||||
kwargs:
|
||||
per_layer_prefill_intput_threshold: 0 # 0 disables layer-wise prefill
|
||||
- match:
|
||||
name: "^model.embed_tokens"
|
||||
replace:
|
||||
class: "default"
|
||||
kwargs:
|
||||
generate_device: "cpu"
|
||||
prefill_device: "cpu"
|
||||
@@ -1,68 +0,0 @@
|
||||
- match:
|
||||
class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding
|
||||
replace:
|
||||
class: ktransformers.operators.RoPE.YarnRotaryEmbedding
|
||||
kwargs:
|
||||
generate_device: "cuda"
|
||||
prefill_device: "cuda"
|
||||
- match:
|
||||
name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression
|
||||
class: torch.nn.Linear # only match modules matching name and class simultaneously
|
||||
replace:
|
||||
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
|
||||
kwargs:
|
||||
generate_device: "cuda"
|
||||
prefill_device: "cuda"
|
||||
generate_op: "KLinearMarlin"
|
||||
prefill_op: "KLinearTorch"
|
||||
|
||||
- match:
|
||||
name: "^lm_head"
|
||||
class: torch.nn.Linear
|
||||
replace:
|
||||
class: ktransformers.operators.linear.KTransformersLinear
|
||||
kwargs:
|
||||
generate_device: "cuda"
|
||||
prefill_device: "cuda"
|
||||
generate_op: "KLinearMarlin"
|
||||
prefill_op: "KLinearTorch"
|
||||
|
||||
- match:
|
||||
name: "^model\\.layers\\..*\\.mlp$"
|
||||
class: ktransformers.models.modeling_deepseek.DeepseekV2MoE
|
||||
replace:
|
||||
class: ktransformers.operators.experts.KDeepseekV2MoE # mlp module with custom forward function
|
||||
kwargs:
|
||||
generate_device: "cuda"
|
||||
prefill_device: "cuda"
|
||||
- match:
|
||||
name: "^model\\.layers\\..*\\.mlp\\.experts$"
|
||||
replace:
|
||||
class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert parallelism
|
||||
kwargs:
|
||||
prefill_device: "cuda"
|
||||
prefill_op: "KExpertsTorch"
|
||||
generate_device: "cpu"
|
||||
generate_op: "KExpertsCPU"
|
||||
out_device: "cuda"
|
||||
recursive: False # don't recursively inject submodules of this module
|
||||
- match:
|
||||
name: "^model\\.layers\\..*\\.self_attn$"
|
||||
replace:
|
||||
class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation
|
||||
kwargs:
|
||||
generate_device: "cuda"
|
||||
prefill_device: "cuda"
|
||||
- match:
|
||||
name: "^model$"
|
||||
replace:
|
||||
class: "ktransformers.operators.models.KDeepseekV2Model"
|
||||
kwargs:
|
||||
per_layer_prefill_intput_threshold: 0 # 0 disables layer-wise prefill
|
||||
- match:
|
||||
name: "^model.embed_tokens"
|
||||
replace:
|
||||
class: "default"
|
||||
kwargs:
|
||||
generate_device: "cpu"
|
||||
prefill_device: "cpu"
|
||||
@@ -1,77 +0,0 @@
|
||||
- match:
|
||||
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
|
||||
replace:
|
||||
class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
|
||||
kwargs:
|
||||
generate_device: "cuda"
|
||||
prefill_device: "cuda"
|
||||
|
||||
- match:
|
||||
name: "^lm_head$" # regular expression
|
||||
class: torch.nn.Linear # only match modules matching name and class simultaneously
|
||||
replace:
|
||||
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
|
||||
kwargs:
|
||||
generate_device: "cuda"
|
||||
prefill_device: "cuda"
|
||||
generate_op: "KLinearMarlin"
|
||||
prefill_op: "KLinearTorch"
|
||||
|
||||
- match:
|
||||
name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression
|
||||
class: torch.nn.Linear # only match modules matching name and class simultaneously
|
||||
replace:
|
||||
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
|
||||
kwargs:
|
||||
generate_device: "cuda"
|
||||
prefill_device: "cuda"
|
||||
generate_op: "KLinearMarlin"
|
||||
prefill_op: "KLinearTorch"
|
||||
- match:
|
||||
name: "^model\\.layers\\..*\\.mlp$"
|
||||
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
|
||||
replace:
|
||||
class: ktransformers.operators.experts.KDeepseekV3MoE # mlp module with custom forward function
|
||||
kwargs:
|
||||
generate_device: "cuda"
|
||||
prefill_device: "cuda"
|
||||
- match:
|
||||
class: ktransformers.models.modeling_deepseek_v3.MoEGate
|
||||
replace:
|
||||
class: ktransformers.operators.gate.KMoEGate
|
||||
kwargs:
|
||||
generate_device: "cuda:0"
|
||||
prefill_device: "cuda:0"
|
||||
- match:
|
||||
name: "^model\\.layers\\..*\\.mlp\\.experts$"
|
||||
replace:
|
||||
class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert parallelism
|
||||
kwargs:
|
||||
prefill_device: "cuda"
|
||||
prefill_op: "KExpertsTorch"
|
||||
generate_device: "cpu"
|
||||
generate_op: "KExpertsCPU"
|
||||
out_device: "cuda"
|
||||
backend: "AMXInt8" # or "AMXBF16" or "llamafile" (default)
|
||||
recursive: False # don't recursively inject submodules of this module
|
||||
- match:
|
||||
name: "^model\\.layers\\..*\\.self_attn$"
|
||||
replace:
|
||||
class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation
|
||||
kwargs:
|
||||
generate_device: "cuda"
|
||||
prefill_device: "cuda"
|
||||
absorb_for_prefill: False # set to True to enable long context (prefill may be slower)
|
||||
- match:
|
||||
name: "^model$"
|
||||
replace:
|
||||
class: "ktransformers.operators.models.KDeepseekV2Model"
|
||||
kwargs:
|
||||
per_layer_prefill_intput_threshold: 0 # 0 disables layer-wise prefill
|
||||
- match:
|
||||
name: "^model.embed_tokens"
|
||||
replace:
|
||||
class: "default"
|
||||
kwargs:
|
||||
generate_device: "cpu"
|
||||
prefill_device: "cpu"
|
||||
@@ -1,392 +0,0 @@
|
||||
- match:
|
||||
name: "^model.embed_tokens"
|
||||
replace:
|
||||
class: "default"
|
||||
kwargs:
|
||||
generate_device: "cpu"
|
||||
prefill_device: "cpu"
|
||||
|
||||
# === Rotary Embedding Replacement ===
|
||||
|
||||
# GPU 0: layers 0–14
|
||||
- match:
|
||||
name: "^model\\.layers\\.([0-9]|1[0-4])\\."
|
||||
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
|
||||
replace:
|
||||
class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
|
||||
kwargs:
|
||||
generate_device: "cuda:0"
|
||||
prefill_device: "cuda:0"
|
||||
|
||||
# GPU 1: layers 15–29
|
||||
- match:
|
||||
name: "^model\\.layers\\.(1[5-9]|2[0-9])\\."
|
||||
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
|
||||
replace:
|
||||
class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
|
||||
kwargs:
|
||||
generate_device: "cuda:1"
|
||||
prefill_device: "cuda:1"
|
||||
|
||||
# GPU 2: layers 30–44
|
||||
- match:
|
||||
name: "^model\\.layers\\.(3[0-9]|4[0-4])\\."
|
||||
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
|
||||
replace:
|
||||
class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
|
||||
kwargs:
|
||||
generate_device: "cuda:2"
|
||||
prefill_device: "cuda:2"
|
||||
|
||||
# GPU 3: layers 45–60
|
||||
- match:
|
||||
name: "^model\\.layers\\.(4[5-9]|5[0-9]|60)\\."
|
||||
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
|
||||
replace:
|
||||
class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
|
||||
kwargs:
|
||||
generate_device: "cuda:3"
|
||||
prefill_device: "cuda:3"
|
||||
|
||||
# === Linear Layers Replacement (excluding self_attn.kv_b_proj) ===
|
||||
|
||||
# GPU 0: layers 0–14
|
||||
- match:
|
||||
name: "^model\\.layers\\.([0-9]|1[0-4])\\.(?!self_attn\\.kv_b_proj).*$"
|
||||
class: torch.nn.Linear
|
||||
replace:
|
||||
class: ktransformers.operators.linear.KTransformersLinear
|
||||
kwargs:
|
||||
generate_device: "cuda:0"
|
||||
prefill_device: "cuda:0"
|
||||
generate_op: "KLinearTorch"
|
||||
prefill_op: "KLinearTorch"
|
||||
|
||||
# GPU 1: layers 15–29
|
||||
- match:
|
||||
name: "^model\\.layers\\.(1[5-9]|2[0-9])\\.(?!self_attn\\.kv_b_proj).*$"
|
||||
class: torch.nn.Linear
|
||||
replace:
|
||||
class: ktransformers.operators.linear.KTransformersLinear
|
||||
kwargs:
|
||||
generate_device: "cuda:1"
|
||||
prefill_device: "cuda:1"
|
||||
generate_op: "KLinearTorch"
|
||||
prefill_op: "KLinearTorch"
|
||||
|
||||
# GPU 2: layers 30–44
|
||||
- match:
|
||||
name: "^model\\.layers\\.(3[0-9]|4[0-4])\\.(?!self_attn\\.kv_b_proj).*$"
|
||||
class: torch.nn.Linear
|
||||
replace:
|
||||
class: ktransformers.operators.linear.KTransformersLinear
|
||||
kwargs:
|
||||
generate_device: "cuda:2"
|
||||
prefill_device: "cuda:2"
|
||||
generate_op: "KLinearTorch"
|
||||
prefill_op: "KLinearTorch"
|
||||
|
||||
# GPU 3: layers 45–60
|
||||
- match:
|
||||
name: "^model\\.layers\\.(4[5-9]|5[0-9]|60)\\.(?!self_attn\\.kv_b_proj).*$"
|
||||
class: torch.nn.Linear
|
||||
replace:
|
||||
class: ktransformers.operators.linear.KTransformersLinear
|
||||
kwargs:
|
||||
generate_device: "cuda:3"
|
||||
prefill_device: "cuda:3"
|
||||
generate_op: "KLinearTorch"
|
||||
prefill_op: "KLinearTorch"
|
||||
|
||||
# === MLP (MoE) Replacement ===
|
||||
|
||||
# GPU 0: layers 0–14
|
||||
- match:
|
||||
name: "^model\\.layers\\.([0-9]|1[0-4])\\.mlp$"
|
||||
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
|
||||
replace:
|
||||
class: ktransformers.operators.experts.KDeepseekV3MoE
|
||||
kwargs:
|
||||
generate_device: "cuda:0"
|
||||
prefill_device: "cuda:0"
|
||||
|
||||
# GPU 1: layers 15–29
|
||||
- match:
|
||||
name: "^model\\.layers\\.(1[5-9]|2[0-9])\\.mlp$"
|
||||
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
|
||||
replace:
|
||||
class: ktransformers.operators.experts.KDeepseekV3MoE
|
||||
kwargs:
|
||||
generate_device: "cuda:1"
|
||||
prefill_device: "cuda:1"
|
||||
|
||||
# GPU 2: layers 30–44
|
||||
- match:
|
||||
name: "^model\\.layers\\.(3[0-9]|4[0-4])\\.mlp$"
|
||||
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
|
||||
replace:
|
||||
class: ktransformers.operators.experts.KDeepseekV3MoE
|
||||
kwargs:
|
||||
generate_device: "cuda:2"
|
||||
prefill_device: "cuda:2"
|
||||
|
||||
# GPU 3: layers 45–60
|
||||
- match:
|
||||
name: "^model\\.layers\\.(4[5-9]|5[0-9]|60)\\.mlp$"
|
||||
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
|
||||
replace:
|
||||
class: ktransformers.operators.experts.KDeepseekV3MoE
|
||||
kwargs:
|
||||
generate_device: "cuda:3"
|
||||
prefill_device: "cuda:3"
|
||||
|
||||
# === MLP Gate Replacement ===
|
||||
|
||||
# GPU 0: layers 0–14
|
||||
- match:
|
||||
name: "^model\\.layers\\.([0-9]|1[0-4])\\.mlp\\.gate$"
|
||||
class: ktransformers.models.modeling_deepseek_v3.MoEGate
|
||||
replace:
|
||||
class: ktransformers.operators.gate.KMoEGate
|
||||
kwargs:
|
||||
generate_device: "cuda:0"
|
||||
prefill_device: "cuda:0"
|
||||
|
||||
# GPU 1: layers 15–29
|
||||
- match:
|
||||
name: "^model\\.layers\\.(1[5-9]|2[0-9])\\.mlp\\.gate$"
|
||||
class: ktransformers.models.modeling_deepseek_v3.MoEGate
|
||||
replace:
|
||||
class: ktransformers.operators.gate.KMoEGate
|
||||
kwargs:
|
||||
generate_device: "cuda:1"
|
||||
prefill_device: "cuda:1"
|
||||
|
||||
# GPU 2: layers 30–44
|
||||
- match:
|
||||
name: "^model\\.layers\\.(3[0-9]|4[0-4])\\.mlp\\.gate$"
|
||||
class: ktransformers.models.modeling_deepseek_v3.MoEGate
|
||||
replace:
|
||||
class: ktransformers.operators.gate.KMoEGate
|
||||
kwargs:
|
||||
generate_device: "cuda:2"
|
||||
prefill_device: "cuda:2"
|
||||
|
||||
# GPU 3: layers 45–60
|
||||
- match:
|
||||
name: "^model\\.layers\\.(4[5-9]|5[0-9]|60)\\.mlp\\.gate$"
|
||||
class: ktransformers.models.modeling_deepseek_v3.MoEGate
|
||||
replace:
|
||||
class: ktransformers.operators.gate.KMoEGate
|
||||
kwargs:
|
||||
generate_device: "cuda:3"
|
||||
prefill_device: "cuda:3"
|
||||
|
||||
# === MLP Experts Replacement ===
|
||||
# Replace with Marlin experts. Uncomment and adjust the layer numbers as needed.
|
||||
# Each layer of Marlin experts takes about 6 GB of GPU memory.
|
||||
# !!! Remember to disable CUDA graphs if you are using Marlin experts. !!!
|
||||
# !!!KExpertsTorch is untested, we don't have enough VRAM.!!!
|
||||
|
||||
# GPU 0: layers 3–4
|
||||
# - match:
|
||||
# name: "^model\\.layers\\.([3-4])\\.mlp\\.experts$"
|
||||
# replace:
|
||||
# class: ktransformers.operators.experts.KTransformersExperts
|
||||
# kwargs:
|
||||
# generate_device: "cuda:0"
|
||||
# generate_op: "KExpertsMarlin"
|
||||
# recursive: False
|
||||
|
||||
# # GPU 1: layers 15–17
|
||||
# - match:
|
||||
# name: "^model\\.layers\\.(1[5-7])\\.mlp\\.experts$"
|
||||
# replace:
|
||||
# class: ktransformers.operators.experts.KTransformersExperts
|
||||
# kwargs:
|
||||
# generate_device: "cuda:1"
|
||||
# generate_op: "KExpertsMarlin"
|
||||
# recursive: False
|
||||
|
||||
# # GPU 2: layers 30–32
|
||||
# - match:
|
||||
# name: "^model\\.layers\\.(3[0-2])\\.mlp\\.experts$"
|
||||
# replace:
|
||||
# class: ktransformers.operators.experts.KTransformersExperts
|
||||
# kwargs:
|
||||
# generate_device: "cuda:2"
|
||||
# generate_op: "KExpertsMarlin"
|
||||
# recursive: False
|
||||
|
||||
# # GPU 3: layers 45–46
|
||||
# - match:
|
||||
# name: "^model\\.layers\\.(4[5-6])\\.mlp\\.experts$"
|
||||
# replace:
|
||||
# class: ktransformers.operators.experts.KTransformersExperts
|
||||
# kwargs:
|
||||
# generate_device: "cuda:3"
|
||||
# generate_op: "KExpertsMarlin"
|
||||
# recursive: False
|
||||
|
||||
|
||||
# === MLP Experts Replacement ===
|
||||
|
||||
# GPU 0: layers 0–14
|
||||
- match:
|
||||
name: "^model\\.layers\\.([0-9]|1[0-4])\\.mlp\\.experts$"
|
||||
replace:
|
||||
class: ktransformers.operators.experts.KTransformersExperts
|
||||
kwargs:
|
||||
prefill_device: "cuda:0"
|
||||
prefill_op: "KExpertsTorch"
|
||||
generate_device: "cpu"
|
||||
generate_op: "KSFTExpertsCPU"
|
||||
out_device: "cuda:0"
|
||||
backend: "AMXInt8" # or "AMXBF16" or "llamafile" (default)
|
||||
recursive: False
|
||||
|
||||
# GPU 1: layers 15–29
|
||||
- match:
|
||||
name: "^model\\.layers\\.(1[5-9]|2[0-9])\\.mlp\\.experts$"
|
||||
replace:
|
||||
class: ktransformers.operators.experts.KTransformersExperts
|
||||
kwargs:
|
||||
prefill_device: "cuda:1"
|
||||
prefill_op: "KExpertsTorch"
|
||||
generate_device: "cpu"
|
||||
generate_op: "KSFTExpertsCPU"
|
||||
out_device: "cuda:1"
|
||||
backend: "AMXInt8" # or "AMXBF16" or "llamafile" (default)
|
||||
recursive: False
|
||||
|
||||
# GPU 2: layers 30–44
|
||||
- match:
|
||||
name: "^model\\.layers\\.(3[0-9]|4[0-4])\\.mlp\\.experts$"
|
||||
replace:
|
||||
class: ktransformers.operators.experts.KTransformersExperts
|
||||
kwargs:
|
||||
prefill_device: "cuda:2"
|
||||
prefill_op: "KExpertsTorch"
|
||||
generate_device: "cpu"
|
||||
generate_op: "KSFTExpertsCPU"
|
||||
out_device: "cuda:2"
|
||||
backend: "AMXInt8" # or "AMXBF16" or "llamafile" (default)
|
||||
recursive: False
|
||||
|
||||
# GPU 3: layers 45–60
|
||||
- match:
|
||||
name: "^model\\.layers\\.(4[5-9]|5[0-9]|60)\\.mlp\\.experts$"
|
||||
replace:
|
||||
class: ktransformers.operators.experts.KTransformersExperts
|
||||
kwargs:
|
||||
prefill_device: "cuda:3"
|
||||
prefill_op: "KExpertsTorch"
|
||||
generate_device: "cpu"
|
||||
generate_op: "KSFTExpertsCPU"
|
||||
out_device: "cuda:3"
|
||||
backend: "AMXInt8" # or "AMXBF16" or "llamafile" (default)
|
||||
recursive: False
|
||||
|
||||
# === Self-Attention Replacement ===
|
||||
|
||||
# GPU 0: layers 0–14
|
||||
- match:
|
||||
name: "^model\\.layers\\.([0-9]|1[0-4])\\.self_attn$"
|
||||
replace:
|
||||
class: ktransformers.operators.attention.KDeepseekV2Attention
|
||||
kwargs:
|
||||
generate_device: "cuda:0"
|
||||
prefill_device: "cuda:0"
|
||||
absorb_for_prefill: False
|
||||
|
||||
# GPU 1: layers 15–29
|
||||
- match:
|
||||
name: "^model\\.layers\\.(1[5-9]|2[0-9])\\.self_attn$"
|
||||
replace:
|
||||
class: ktransformers.operators.attention.KDeepseekV2Attention
|
||||
kwargs:
|
||||
generate_device: "cuda:1"
|
||||
prefill_device: "cuda:1"
|
||||
absorb_for_prefill: False
|
||||
|
||||
# GPU 2: layers 30–44
|
||||
- match:
|
||||
name: "^model\\.layers\\.(3[0-9]|4[0-4])\\.self_attn$"
|
||||
replace:
|
||||
class: ktransformers.operators.attention.KDeepseekV2Attention
|
||||
kwargs:
|
||||
generate_device: "cuda:2"
|
||||
prefill_device: "cuda:2"
|
||||
absorb_for_prefill: False
|
||||
|
||||
# GPU 3: layers 45–60
|
||||
- match:
|
||||
name: "^model\\.layers\\.(4[5-9]|5[0-9]|60)\\.self_attn$"
|
||||
replace:
|
||||
class: ktransformers.operators.attention.KDeepseekV2Attention
|
||||
kwargs:
|
||||
generate_device: "cuda:3"
|
||||
prefill_device: "cuda:3"
|
||||
absorb_for_prefill: False
|
||||
|
||||
# === Overall Model Replacement with Transfer Map ===
|
||||
|
||||
- match:
|
||||
name: "^model$"
|
||||
replace:
|
||||
class: "ktransformers.operators.models.KDeepseekV2Model"
|
||||
kwargs:
|
||||
per_layer_prefill_intput_threshold: 0 # 0 disables layer-wise prefill
|
||||
transfer_map:
|
||||
15: "cuda:1" # Layers 15+ on GPU 1
|
||||
30: "cuda:2" # Layers 30+ on GPU 2
|
||||
45: "cuda:3" # Layers 45+ on GPU 3
|
||||
|
||||
# === Default Catch-All for Other Modules ===
|
||||
|
||||
# GPU 0: layers 0–14
|
||||
- match:
|
||||
name: "^model\\.layers\\.([0-9]|1[0-4])\\."
|
||||
replace:
|
||||
class: "default"
|
||||
kwargs:
|
||||
generate_device: "cuda:0"
|
||||
prefill_device: "cuda:0"
|
||||
|
||||
# GPU 1: layers 15–29
|
||||
- match:
|
||||
name: "^model\\.layers\\.(1[5-9]|2[0-9])\\."
|
||||
replace:
|
||||
class: "default"
|
||||
kwargs:
|
||||
generate_device: "cuda:1"
|
||||
prefill_device: "cuda:1"
|
||||
|
||||
# GPU 2: layers 30–44
|
||||
- match:
|
||||
name: "^model\\.layers\\.(3[0-9]|4[0-4])\\."
|
||||
replace:
|
||||
class: "default"
|
||||
kwargs:
|
||||
generate_device: "cuda:2"
|
||||
prefill_device: "cuda:2"
|
||||
|
||||
- match:
|
||||
name: "^lm_head"
|
||||
class: torch.nn.Linear
|
||||
replace:
|
||||
class: ktransformers.operators.linear.KTransformersLinear
|
||||
kwargs:
|
||||
generate_device: "cuda:3"
|
||||
prefill_device: "cuda:3"
|
||||
generate_op: "KLinearTorch"
|
||||
prefill_op: "KLinearTorch"
|
||||
|
||||
# GPU 3: layers 45–60 and the final norm (model.norm)
|
||||
- match:
|
||||
name: "(^model\\.layers\\.(4[5-9]|5[0-9]|60)\\.)|(^model\\.norm)"
|
||||
replace:
|
||||
class: "default"
|
||||
kwargs:
|
||||
generate_device: "cuda:3"
|
||||
prefill_device: "cuda:3"
|
||||
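The four layer ranges in this rule file are meant to split layers 0 to 60 across the GPUs with no gaps or overlaps, and the `transfer_map` keys (15, 30, 45) should mark the same boundaries. The sketch below checks that. It assumes `transfer_map` means "from this layer index onward, use this device", as the inline comments suggest. `device_from_patterns` and `device_from_transfer_map` are hypothetical helpers written for this check and are not part of ktransformers.

```python
import re

# Expert-layer patterns copied from the rules above.
# In the YAML file "\\." is an escaped backslash; after parsing it is "\.".
DEVICE_PATTERNS = {
    "cuda:0": r"^model\.layers\.([0-9]|1[0-4])\.mlp\.experts$",
    "cuda:1": r"^model\.layers\.(1[5-9]|2[0-9])\.mlp\.experts$",
    "cuda:2": r"^model\.layers\.(3[0-9]|4[0-4])\.mlp\.experts$",
    "cuda:3": r"^model\.layers\.(4[5-9]|5[0-9]|60)\.mlp\.experts$",
}


def device_from_patterns(layer_idx: int) -> str:
    """Return the single device whose pattern matches this layer's experts module."""
    name = f"model.layers.{layer_idx}.mlp.experts"
    hits = [dev for dev, pat in DEVICE_PATTERNS.items() if re.match(pat, name)]
    if len(hits) != 1:
        raise ValueError(f"layer {layer_idx} matched {len(hits)} patterns: {hits}")
    return hits[0]


def device_from_transfer_map(layer_idx: int, transfer_map: dict, first_device: str = "cuda:0") -> str:
    """Assumed semantics: each key is the first layer index placed on that device."""
    device = first_device
    for start in sorted(transfer_map):
        if layer_idx >= start:
            device = transfer_map[start]
    return device


if __name__ == "__main__":
    transfer_map = {15: "cuda:1", 30: "cuda:2", 45: "cuda:3"}
    for idx in range(61):
        assert device_from_patterns(idx) == device_from_transfer_map(idx, transfer_map)
    print("patterns and transfer_map agree for layers 0-60")
```

Running a check like this after editing the ranges catches the common mistake of leaving a layer unmatched or matched twice.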
@@ -1,156 +0,0 @@
|
||||
- match:
|
||||
name: "^model.embed_tokens"
|
||||
replace:
|
||||
class: "default"
|
||||
kwargs:
|
||||
generate_device: "cpu"
|
||||
prefill_device: "cpu"
|
||||
|
||||
- match:
|
||||
name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\."
|
||||
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
|
||||
replace:
|
||||
class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
|
||||
kwargs:
|
||||
generate_device: "cuda:0"
|
||||
prefill_device: "cuda:0"
|
||||
- match:
|
||||
name: "^model\\.layers\\.([3456][0-9])\\."
|
||||
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
|
||||
replace:
|
||||
class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
|
||||
kwargs:
|
||||
generate_device: "cuda:1"
|
||||
prefill_device: "cuda:1"
|
||||
|
||||
- match:
|
||||
name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.(?!self_attn\\.kv_b_proj).*$" # regular expression
|
||||
class: torch.nn.Linear # only match modules matching name and class simultaneously
|
||||
replace:
|
||||
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
|
||||
kwargs:
|
||||
generate_device: "cuda:0"
|
||||
prefill_device: "cuda:0"
|
||||
generate_op: "KLinearTorch"
|
||||
prefill_op: "KLinearTorch"
|
||||
|
||||
- match:
|
||||
name: "^model\\.layers\\.([3456][0-9])\\.(?!self_attn\\.kv_b_proj).*$" # regular expression
|
||||
class: torch.nn.Linear # only match modules matching name and class simultaneously
|
||||
replace:
|
||||
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
|
||||
kwargs:
|
||||
generate_device: "cuda:1"
|
||||
prefill_device: "cuda:1"
|
||||
generate_op: "KLinearTorch"
|
||||
prefill_op: "KLinearTorch"
|
||||
|
||||
- match:
|
||||
name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp$"
|
||||
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
|
||||
replace:
|
||||
class: ktransformers.operators.experts.KDeepseekV3MoE # mlp module with custom forward function
|
||||
kwargs:
|
||||
generate_device: "cuda:0"
|
||||
prefill_device: "cuda:0"
|
||||
- match:
|
||||
name: "^model\\.layers\\.([3456][0-9])\\.mlp$"
|
||||
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
|
||||
replace:
|
||||
class: ktransformers.operators.experts.KDeepseekV3MoE # mlp module with custom forward function
|
||||
kwargs:
|
||||
generate_device: "cuda:1"
|
||||
prefill_device: "cuda:1"
|
||||
|
||||
- match:
|
||||
name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp\\.gate$"
|
||||
class: ktransformers.models.modeling_deepseek_v3.MoEGate
|
||||
replace:
|
||||
class: ktransformers.operators.gate.KMoEGate
|
||||
kwargs:
|
||||
generate_device: "cuda:0"
|
||||
prefill_device: "cuda:0"
|
||||
- match:
|
||||
name: "^model\\.layers\\.([3456][0-9])\\.mlp\\.gate$"
|
||||
class: ktransformers.models.modeling_deepseek_v3.MoEGate
|
||||
replace:
|
||||
class: ktransformers.operators.gate.KMoEGate # mlp module with custom forward function
|
||||
kwargs:
|
||||
generate_device: "cuda:1"
|
||||
prefill_device: "cuda:1"
|
||||
|
||||
- match:
|
||||
name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp\\.experts$"
|
||||
replace:
|
||||
class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert parallelism
|
||||
kwargs:
|
||||
prefill_device: "cuda:0"
|
||||
prefill_op: "KExpertsTorch"
|
||||
generate_device: "cpu"
|
||||
generate_op: "KSFTExpertsCPU"
|
||||
out_device: "cuda:0"
|
||||
backend: "AMXInt8" # or "AMXBF16" or "llamafile" (default)
|
||||
recursive: False # don't recursively inject submodules of this module
|
||||
|
||||
- match:
|
||||
name: "^model\\.layers\\.([3456][0-9])\\.mlp\\.experts$"
|
||||
replace:
|
||||
class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert parallelism
|
||||
kwargs:
|
||||
prefill_device: "cuda:1"
|
||||
prefill_op: "KExpertsTorch"
|
||||
generate_device: "cpu"
|
||||
generate_op: "KSFTExpertsCPU"
|
||||
out_device: "cuda:1"
|
||||
backend: "AMXInt8" # or "AMXBF16" or "llamafile" (default)
|
||||
recursive: False # don't recursively inject submodules of this module
|
||||
|
||||
- match:
|
||||
name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.self_attn$"
|
||||
replace:
|
||||
class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation
|
||||
kwargs:
|
||||
generate_device: "cuda:0"
|
||||
prefill_device: "cuda:0"
|
||||
- match:
|
||||
name: "^model\\.layers\\.([3456][0-9])\\.self_attn$"
|
||||
replace:
|
||||
class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation
|
||||
kwargs:
|
||||
generate_device: "cuda:1"
|
||||
prefill_device: "cuda:1"
|
||||
- match:
|
||||
name: "^model$"
|
||||
replace:
|
||||
class: "ktransformers.operators.models.KDeepseekV2Model"
|
||||
kwargs:
|
||||
per_layer_prefill_intput_threshold: 0 # 0 disables layer-wise prefill
|
||||
transfer_map:
|
||||
30: "cuda:1"
|
||||
|
||||
- match:
|
||||
name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\."
|
||||
replace:
|
||||
class: "default"
|
||||
kwargs:
|
||||
generate_device: "cuda:0"
|
||||
prefill_device: "cuda:0"
|
||||
|
||||
- match:
|
||||
name: "^lm_head"
|
||||
class: torch.nn.Linear
|
||||
replace:
|
||||
class: ktransformers.operators.linear.KTransformersLinear
|
||||
kwargs:
|
||||
generate_device: "cuda:1"
|
||||
prefill_device: "cuda:1"
|
||||
generate_op: "KLinearTorch"
|
||||
prefill_op: "KLinearTorch"
|
||||
|
||||
- match:
|
||||
name: "(^model\\.layers\\.([3456][0-9])\\.)|(model.norm)"
|
||||
replace:
|
||||
class: "default"
|
||||
kwargs:
|
||||
generate_device: "cuda:1"
|
||||
prefill_device: "cuda:1"
|
||||
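The Linear rules in this file use a negative lookahead, `(?!self_attn\.kv_b_proj)`, so that every `torch.nn.Linear` in a layer is replaced except `kv_b_proj`. The sketch below shows the name-matching half of that rule for the cuda:0 range. The real rule also requires the module class to be `torch.nn.Linear`, which is not modelled here. `is_replaced` is a hypothetical helper, and the module names passed to it are illustrative.

```python
import re

# Name pattern from the cuda:0 Linear rule above, with YAML's "\\." unescaped to "\.".
LINEAR_RULE = re.compile(
    r"^model\.layers\.(0|[1-9]|[12][0-9])\.(?!self_attn\.kv_b_proj).*$"
)


def is_replaced(module_name: str) -> bool:
    """True if the module name is selected by the rule's name pattern."""
    return LINEAR_RULE.match(module_name) is not None


if __name__ == "__main__":
    for name in (
        "model.layers.3.self_attn.o_proj",     # layer in range, not excluded
        "model.layers.3.self_attn.kv_b_proj",  # excluded by the lookahead
        "model.layers.30.self_attn.o_proj",    # layer outside the 0-29 range
    ):
        print(name, is_replaced(name))
```

The lookahead sits directly after the layer index, so it only excludes names that continue with `self_attn.kv_b_proj` at that position.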
@@ -1,77 +0,0 @@
|
||||
- match:
|
||||
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
|
||||
replace:
|
||||
class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
|
||||
kwargs:
|
||||
generate_device: "cuda"
|
||||
prefill_device: "cuda"
|
||||
|
||||
- match:
|
||||
name: "^lm_head$" # regular expression
|
||||
class: torch.nn.Linear # only match modules matching name and class simultaneously
|
||||
replace:
|
||||
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
|
||||
kwargs:
|
||||
generate_device: "cuda"
|
||||
prefill_device: "cuda"
|
||||
generate_op: "KLinearTorch"
|
||||
prefill_op: "KLinearTorch"
|
||||
|
||||
- match:
|
||||
name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression
|
||||
class: torch.nn.Linear # only match modules matching name and class simultaneously
|
||||
replace:
|
||||
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
|
||||
kwargs:
|
||||
generate_device: "cuda"
|
||||
prefill_device: "cuda"
|
||||
generate_op: "KLinearTorch"
|
||||
prefill_op: "KLinearTorch"
|
||||
- match:
|
||||
name: "^model\\.layers\\..*\\.mlp$"
|
||||
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
|
||||
replace:
|
||||
class: ktransformers.operators.experts.KDeepseekV3MoE # mlp module with custom forward function
|
||||
kwargs:
|
||||
generate_device: "cuda"
|
||||
prefill_device: "cuda"
|
||||
- match:
|
||||
class: ktransformers.models.modeling_deepseek_v3.MoEGate
|
||||
replace:
|
||||
class: ktransformers.operators.gate.KMoEGate
|
||||
kwargs:
|
||||
generate_device: "cuda:0"
|
||||
prefill_device: "cuda:0"
|
||||
- match:
|
||||
name: "^model\\.layers\\..*\\.mlp\\.experts$"
|
||||
replace:
|
||||
class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert parallelism
|
||||
kwargs:
|
||||
prefill_device: "cuda"
|
||||
prefill_op: "KExpertsTorch"
|
||||
generate_device: "cpu"
|
||||
generate_op: "KSFTExpertsCPU"
|
||||
out_device: "cuda"
|
||||
backend: "AMXInt8" # or "AMXBF16" or "llamafile" (default)
|
||||
recursive: False # don't recursively inject submodules of this module
|
||||
- match:
|
||||
name: "^model\\.layers\\..*\\.self_attn$"
|
||||
replace:
|
||||
class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation
|
||||
kwargs:
|
||||
generate_device: "cuda"
|
||||
prefill_device: "cuda"
|
||||
absorb_for_prefill: False # set to True to enable long context (prefill may be slower)
|
||||
- match:
|
||||
name: "^model$"
|
||||
replace:
|
||||
class: "ktransformers.operators.models.KDeepseekV2Model"
|
||||
kwargs:
|
||||
per_layer_prefill_intput_threshold: 0 # 0 disables layer-wise prefill
|
||||
- match:
|
||||
name: "^model.embed_tokens"
|
||||
replace:
|
||||
class: "default"
|
||||
kwargs:
|
||||
generate_device: "cpu"
|
||||
prefill_device: "cpu"
|
||||
@@ -1,80 +0,0 @@
|
||||
- match:
|
||||
class: ktransformers.models.modeling_qwen2_moe.Qwen2MoeRotaryEmbedding
|
||||
replace:
|
||||
class: ktransformers.operators.RoPE.RotaryEmbedding
|
||||
kwargs:
|
||||
generate_device: "cuda"
|
||||
prefill_device: "cuda"
|
||||
|
||||
- match:
|
||||
name: "^lm_head$" # regular expression
|
||||
class: torch.nn.Linear # only match modules matching name and class simultaneously
|
||||
replace:
|
||||
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
|
||||
kwargs:
|
||||
generate_device: "cuda"
|
||||
prefill_device: "cuda"
|
||||
generate_op: "KLinearTorch"
|
||||
prefill_op: "KLinearTorch"
|
||||
|
||||
# - match:
|
||||
# name: "^model\\.layers\\..*$" # regular expression
|
||||
# class: torch.nn.Linear # only match modules matching name and class simultaneously
|
||||
# replace:
|
||||
# class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
|
||||
# kwargs:
|
||||
# generate_device: "cuda"
|
||||
# prefill_device: "cuda"
|
||||
# generate_op: "KLinearTorch"
|
||||
# prefill_op: "KLinearTorch"
|
||||
- match:
|
||||
name: "^model\\.layers\\.(?!.*mlp\\.shared_expert_gate).*$" # regular expression
|
||||
class: torch.nn.Linear # only match modules matching name and class simultaneously
|
||||
replace:
|
||||
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
|
||||
kwargs:
|
||||
generate_device: "cuda"
|
||||
prefill_device: "cuda"
|
||||
generate_op: "KLinearTorch"
|
||||
prefill_op: "KLinearTorch"
|
||||
- match:
|
||||
name: "^model\\.layers\\..*\\.mlp$"
|
||||
replace:
|
||||
class: ktransformers.operators.experts.KQwen3MoeSparseMoeBlock # mlp module with custom forward function
|
||||
kwargs:
|
||||
generate_device: "cuda"
|
||||
prefill_device: "cuda"
|
||||
|
||||
- match:
|
||||
name: "^model\\.layers\\..*\\.mlp\\.experts$"
|
||||
replace:
|
||||
class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert parallelism
|
||||
kwargs:
|
||||
prefill_device: "cuda"
|
||||
prefill_op: "KExpertsTorch"
|
||||
generate_device: "cpu"
|
||||
generate_op: "KSFTExpertsCPU"
|
||||
out_device: "cuda"
|
||||
backend: "AMXInt8" # or "AMXBF16"
|
||||
recursive: False # don't recursively inject submodules of this module
|
||||
- match:
|
||||
name: "^model\\.layers\\..*\\.self_attn$"
|
||||
replace:
|
||||
class: ktransformers.operators.attention.KQwen3MoeAttention # optimized attention implementation
|
||||
kwargs:
|
||||
generate_device: "cuda"
|
||||
prefill_device: "cuda"
|
||||
- match:
|
||||
name: "^model.embed_tokens"
|
||||
replace:
|
||||
class: "default"
|
||||
kwargs:
|
||||
generate_device: "cpu"
|
||||
prefill_device: "cpu"
|
||||
|
||||
- match:
|
||||
name: "^model$"
|
||||
replace:
|
||||
class: "ktransformers.operators.models.KQwen3MoeModel"
|
||||
kwargs:
|
||||
per_layer_prefill_intput_threshold: 0
|
||||
@@ -19,7 +19,7 @@ preprocessing_num_workers: 16
|
||||
dataloader_num_workers: 4
|
||||
|
||||
### output
|
||||
output_dir: saves/Kllama_deepseekV2
|
||||
output_dir: saves/KT_FT_deepseekV2
|
||||
logging_steps: 10
|
||||
save_steps: 500
|
||||
plot_loss: true
|
||||
@@ -39,14 +39,7 @@ ddp_timeout: 180000000
|
||||
resume_from_checkpoint: null
|
||||
|
||||
### ktransformers
|
||||
use_kt: true # use KTransformers as LoRA sft backend
|
||||
kt_optimize_rule: examples/ktransformers/kt_optimize_rules/DeepSeek-V2-Lite-Chat-sft-amx.yaml
|
||||
cpu_infer: 32
|
||||
chunk_size: 8192
|
||||
|
||||
### eval
|
||||
# eval_dataset: alpaca_en_demo
|
||||
# val_size: 0.1
|
||||
# per_device_eval_batch_size: 1
|
||||
# eval_strategy: steps
|
||||
# eval_steps: 500
|
||||
use_kt: true
|
||||
# Pair with fsdp2_kt_bf16.yaml for original BF16 checkpoints.
|
||||
# For pre-converted expert weights, uncomment kt_weight_path and use fsdp2_kt_int8.yaml or fsdp2_kt_int4.yaml.
|
||||
# kt_weight_path: /path/to/DeepSeek-V2-Lite-AMXINT8
|
||||
@@ -1,5 +1,5 @@
|
||||
### model
|
||||
model_name_or_path: opensourcerelease/DeepSeek-V3-bf16
|
||||
model_name_or_path: deepseek-ai/DeepSeek-V3-0324-BF16 # need to convert to BF16 checkpoint first
|
||||
trust_remote_code: true
|
||||
|
||||
### method
|
||||
@@ -19,7 +19,7 @@ preprocessing_num_workers: 16
|
||||
dataloader_num_workers: 4
|
||||
|
||||
### output
|
||||
output_dir: saves/Kllama_deepseekV3
|
||||
output_dir: saves/KT_FT_deepseekV3
|
||||
logging_steps: 10
|
||||
save_steps: 500
|
||||
plot_loss: true
|
||||
@@ -39,14 +39,7 @@ ddp_timeout: 180000000
|
||||
resume_from_checkpoint: null
|
||||
|
||||
### ktransformers
|
||||
use_kt: true # use KTransformers as LoRA sft backend
|
||||
kt_optimize_rule: examples/ktransformers/kt_optimize_rules/DeepSeek-V3-Chat-sft-amx-multi-gpu.yaml
|
||||
cpu_infer: 32
|
||||
chunk_size: 8192
|
||||
|
||||
### eval
|
||||
# eval_dataset: alpaca_en_demo
|
||||
# val_size: 0.1
|
||||
# per_device_eval_batch_size: 1
|
||||
# eval_strategy: steps
|
||||
# eval_steps: 500
|
||||
use_kt: true
|
||||
# Pair with fsdp2_kt_bf16.yaml for original BF16 checkpoints.
|
||||
# For pre-converted expert weights, uncomment kt_weight_path and use fsdp2_kt_int8.yaml or fsdp2_kt_int4.yaml.
|
||||
# kt_weight_path: /path/to/DeepSeek-V3-AMXINT8
|
||||
@@ -0,0 +1,46 @@
|
||||
### model
|
||||
model_name_or_path: Qwen/Qwen3.5-397B-A17B
|
||||
trust_remote_code: true
|
||||
|
||||
### method
|
||||
stage: sft
|
||||
do_train: true
|
||||
finetuning_type: lora
|
||||
lora_rank: 8
|
||||
lora_target: all
|
||||
|
||||
### dataset
|
||||
dataset: identity, alpaca_en_demo
|
||||
template: qwen3_5
|
||||
cutoff_len: 2048
|
||||
max_samples: 100000
|
||||
overwrite_cache: true
|
||||
preprocessing_num_workers: 16
|
||||
dataloader_num_workers: 4
|
||||
|
||||
### output
|
||||
output_dir: saves/KT_FT_qwen35Moe
|
||||
logging_steps: 10
|
||||
save_steps: 500
|
||||
plot_loss: true
|
||||
overwrite_output_dir: true
|
||||
save_only_model: false
|
||||
report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
|
||||
|
||||
### train
|
||||
per_device_train_batch_size: 1
|
||||
gradient_accumulation_steps: 8
|
||||
learning_rate: 1.0e-4
|
||||
num_train_epochs: 3.0
|
||||
lr_scheduler_type: cosine
|
||||
warmup_ratio: 0.1
|
||||
bf16: true
|
||||
ddp_timeout: 180000000
|
||||
resume_from_checkpoint: null
|
||||
|
||||
### ktransformers
|
||||
use_kt: true
|
||||
# For original BF16 checkpoints, start with examples/ktransformers/accelerate/fsdp2_kt_bf16.yaml.
|
||||
# For pre-converted expert weights, uncomment kt_weight_path and use fsdp2_kt_int8.yaml or fsdp2_kt_int4.yaml.
|
||||
# Pair the 397B path with fsdp2_kt_int8.yaml, and tune cutoff_len to fit the prepared weights and available GPU memory.
|
||||
# kt_weight_path: /path/to/Qwen3.5-MoE-AMXINT8
|
||||
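The train block above sets `per_device_train_batch_size: 1` and `gradient_accumulation_steps: 8`. Under the usual data-parallel convention, the number of samples per optimizer step also depends on how many devices run the job, which this config does not state. The sketch below shows the arithmetic. The device count is an assumed input, and `effective_batch_size` is a hypothetical helper, not a LLaMA-Factory function.

```python
def effective_batch_size(per_device_batch: int, grad_accum_steps: int, num_devices: int) -> int:
    """Samples consumed per optimizer step under plain data parallelism."""
    return per_device_batch * grad_accum_steps * num_devices


if __name__ == "__main__":
    # Values from the config above; the device counts are assumptions for illustration.
    for num_devices in (1, 4, 8):
        print(num_devices, effective_batch_size(1, 8, num_devices))
```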
@@ -11,7 +11,7 @@ lora_target: all
|
||||
|
||||
### dataset
|
||||
dataset: identity, alpaca_en_demo
|
||||
template: qwen3_nothink
|
||||
template: qwen3
|
||||
cutoff_len: 2048
|
||||
max_samples: 100000
|
||||
overwrite_cache: true
|
||||
@@ -19,9 +19,9 @@ preprocessing_num_workers: 16
|
||||
dataloader_num_workers: 4
|
||||
|
||||
### output
|
||||
output_dir: saves/Kllama_Qwen3MoE_235bA22b
|
||||
output_dir: saves/KT_FT_qwen3Moe
|
||||
logging_steps: 10
|
||||
save_steps: 200
|
||||
save_steps: 500
|
||||
plot_loss: true
|
||||
overwrite_output_dir: true
|
||||
save_only_model: false
|
||||
@@ -31,7 +31,7 @@ report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
|
||||
per_device_train_batch_size: 1
|
||||
gradient_accumulation_steps: 8
|
||||
learning_rate: 1.0e-4
|
||||
num_train_epochs: 3
|
||||
num_train_epochs: 3.0
|
||||
lr_scheduler_type: cosine
|
||||
warmup_ratio: 0.1
|
||||
bf16: true
|
||||
@@ -39,14 +39,7 @@ ddp_timeout: 180000000
|
||||
resume_from_checkpoint: null
|
||||
|
||||
### ktransformers
|
||||
use_kt: true # use KTransformers as LoRA sft backend
|
||||
kt_optimize_rule: examples/ktransformers/kt_optimize_rules/Qwen3Moe-sft-amx.yaml
|
||||
cpu_infer: 32
|
||||
chunk_size: 8192
|
||||
|
||||
### eval
|
||||
# eval_dataset: alpaca_en_demo
|
||||
# val_size: 0.1
|
||||
# per_device_eval_batch_size: 1
|
||||
# eval_strategy: steps
|
||||
# eval_steps: 500
|
||||
use_kt: true
|
||||
# Pair with examples/ktransformers/accelerate/fsdp2_kt_bf16.yaml for original BF16 checkpoints.
|
||||
# For pre-converted expert weights, uncomment kt_weight_path and use fsdp2_kt_int8.yaml or fsdp2_kt_int4.yaml.
|
||||
# kt_weight_path: /path/to/Qwen3-235B-A22B-Instruct-2507-AMXINT8
|
||||
|
||||
@@ -28,12 +28,7 @@ save_only_model: false
|
||||
report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
|
||||
|
||||
### ray
|
||||
ray_run_name: qwen3_4b_sft_lora
|
||||
ray_storage_path: ./saves
|
||||
ray_num_workers: 4 # Number of GPUs to use.
|
||||
placement_strategy: PACK
|
||||
resources_per_worker:
|
||||
GPU: 1
|
||||
# ray_init_kwargs:
|
||||
# runtime_env:
|
||||
# env_vars:
|
||||
|
||||
@@ -0,0 +1,31 @@
|
||||
model: Qwen/Qwen3-0.6B
|
||||
model_class: llm
|
||||
|
||||
template: qwen3_nothink
|
||||
|
||||
|
||||
kernel_config:
|
||||
name: auto
|
||||
include_kernels: auto # choice: null/true/false/auto/kernel_id1,kernel_id2,kernel_id3, default is null
|
||||
|
||||
quant_config: null
|
||||
|
||||
dist_config:
|
||||
name: fsdp2
|
||||
dcp_path: null # /mnt/f/pretrain_models/Qwen3-0.6B-dcp
|
||||
|
||||
### data
|
||||
train_dataset: data/v1_sft_demo.yaml
|
||||
|
||||
### training
|
||||
output_dir: outputs/test_fsdp2
|
||||
micro_batch_size: 2
|
||||
batching_strategy: normal
|
||||
|
||||
cutoff_len: 2048
|
||||
learning_rate: 1.0e-4
|
||||
max_steps: 10
|
||||
|
||||
### sample
|
||||
sample_backend: hf
|
||||
max_new_tokens: 128
|
||||
@@ -0,0 +1,30 @@
|
||||
model: Qwen/Qwen3-0.6B
|
||||
model_class: llm
|
||||
|
||||
template: qwen3_nothink
|
||||
|
||||
kernel_config:
|
||||
name: auto
|
||||
include_kernels: auto # choice: null/true/false/auto/kernel_id1,kernel_id2,kernel_id3, default is null
|
||||
|
||||
quant_config: null
|
||||
|
||||
dist_config:
|
||||
name: fsdp2
|
||||
dcp_path: null # /mnt/f/pretrain_models/Qwen3-0.6B-dcp
|
||||
|
||||
### data
|
||||
train_dataset: data/v1_sft_demo.yaml
|
||||
|
||||
|
||||
### training
|
||||
output_dir: outputs/test_fsdp2
|
||||
micro_batch_size: 2
|
||||
batching_strategy: dynamic_batching
|
||||
cutoff_len: 2048
|
||||
learning_rate: 1.0e-4
|
||||
max_steps: 10
|
||||
|
||||
### sample
|
||||
sample_backend: hf
|
||||
max_new_tokens: 128
|
||||
@@ -0,0 +1,30 @@
|
||||
model: Qwen/Qwen3-0.6B
|
||||
model_class: llm
|
||||
|
||||
template: qwen3_nothink
|
||||
|
||||
kernel_config:
|
||||
name: auto
|
||||
include_kernels: auto # choice: null/true/false/auto/kernel_id1,kernel_id2,kernel_id3, default is null
|
||||
|
||||
quant_config: null
|
||||
|
||||
dist_config:
|
||||
name: fsdp2
|
||||
dcp_path: null # /mnt/f/pretrain_models/Qwen3-0.6B-dcp
|
||||
|
||||
### data
|
||||
train_dataset: data/v1_sft_demo.yaml
|
||||
|
||||
### training
|
||||
output_dir: outputs/test_fsdp2
|
||||
micro_batch_size: 4
|
||||
batching_strategy: dynamic_padding_free
|
||||
flash_attn: flash_attention2
|
||||
cutoff_len: 2048
|
||||
learning_rate: 1.0e-4
|
||||
max_steps: 10
|
||||
|
||||
### sample
|
||||
sample_backend: hf
|
||||
max_new_tokens: 128
|
||||
@@ -0,0 +1,30 @@
|
||||
model: Qwen/Qwen3-0.6B
|
||||
model_class: llm
|
||||
|
||||
template: qwen3_nothink
|
||||
|
||||
kernel_config:
|
||||
name: auto
|
||||
include_kernels: auto # choice: null/true/false/auto/kernel_id1,kernel_id2,kernel_id3, default is null
|
||||
|
||||
quant_config: null
|
||||
|
||||
dist_config:
|
||||
name: fsdp2
|
||||
dcp_path: null # /mnt/f/pretrain_models/Qwen3-0.6B-dcp
|
||||
|
||||
### data
|
||||
train_dataset: data/v1_sft_demo.yaml
|
||||
|
||||
### training
|
||||
output_dir: outputs/test_fsdp2
|
||||
micro_batch_size: 4
|
||||
batching_strategy: padding_free
|
||||
flash_attn: flash_attention2
|
||||
cutoff_len: 2048
|
||||
learning_rate: 1.0e-4
|
||||
max_steps: 10
|
||||
|
||||
### sample
|
||||
sample_backend: hf
|
||||
max_new_tokens: 128
|
||||
@@ -1,5 +1,4 @@
|
||||
model: Qwen/Qwen3-4B
|
||||
trust_remote_code: true
|
||||
model_class: llm
|
||||
|
||||
template: qwen3_nothink
|
||||
@@ -30,7 +29,6 @@ micro_batch_size: 1
|
||||
global_batch_size: 4
|
||||
cutoff_len: 2048
|
||||
learning_rate: 2.0e-5
|
||||
bf16: false
|
||||
max_steps: 10
|
||||
|
||||
### sample
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
model: Qwen/Qwen3-0.6B
|
||||
|
||||
model_class: llm
|
||||
|
||||
template: qwen3_nothink
|
||||
@@ -20,5 +19,4 @@ output_dir: outputs/Qwen3-0.6B-deepspeed
|
||||
micro_batch_size: 1
|
||||
cutoff_len: 2048
|
||||
learning_rate: 1.0e-4
|
||||
bf16: true
|
||||
max_steps: 10
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
model: Qwen/Qwen3-0.6B
|
||||
trust_remote_code: true
|
||||
model_class: llm
|
||||
|
||||
template: qwen3_nothink
|
||||
@@ -22,7 +21,6 @@ output_dir: outputs/test_fsdp2
|
||||
micro_batch_size: 1
|
||||
cutoff_len: 2048
|
||||
learning_rate: 1.0e-4
|
||||
bf16: false
|
||||
max_steps: 10
|
||||
|
||||
### sample
|
||||
|
||||
28
examples/v1/train_full/train_full_liger_kernel.yaml
Normal file
@@ -0,0 +1,28 @@
|
||||
model: Qwen/Qwen3-0.6B
|
||||
model_class: llm
|
||||
|
||||
template: qwen3_nothink
|
||||
|
||||
kernel_config:
|
||||
name: liger_kernel
|
||||
include_kernels: auto # choice: null/true/false/auto/kernel_id1,kernel_id2,kernel_id3, default is null
|
||||
|
||||
quant_config: null
|
||||
|
||||
dist_config:
|
||||
name: fsdp2
|
||||
dcp_path: null # /mnt/f/pretrain_models/Qwen3-0.6B-dcp
|
||||
|
||||
### data
|
||||
train_dataset: data/v1_sft_demo.yaml
|
||||
|
||||
### training
|
||||
output_dir: outputs/test_fsdp2
|
||||
micro_batch_size: 1
|
||||
cutoff_len: 2048
|
||||
learning_rate: 1.0e-4
|
||||
max_steps: 10
|
||||
|
||||
### sample
|
||||
sample_backend: hf
|
||||
max_new_tokens: 128
|
||||
23
examples/v1/train_full/train_full_ulysses_cp.yaml
Normal file
@@ -0,0 +1,23 @@
|
||||
model: Qwen/Qwen3-0.6B
|
||||
trust_remote_code: true
|
||||
model_class: llm
|
||||
|
||||
template: qwen3_nothink
|
||||
|
||||
# FSDP Config
|
||||
dist_config:
|
||||
name: fsdp2
|
||||
dcp_path: null
|
||||
cp_mode: ulysses
|
||||
cp_size: 2
|
||||
|
||||
### data
|
||||
train_dataset: data/v1_sft_demo.yaml
|
||||
|
||||
### training
|
||||
output_dir: outputs/test_ulysses_cp
|
||||
micro_batch_size: 1
|
||||
cutoff_len: 2048
|
||||
learning_rate: 1.0e-4
|
||||
bf16: false
|
||||
max_steps: 10
|
||||
@@ -1,5 +1,4 @@
|
||||
model: Qwen/Qwen3-4B
|
||||
trust_remote_code: true
|
||||
model_class: llm
|
||||
|
||||
template: qwen3_nothink
|
||||
@@ -28,10 +27,8 @@ train_dataset: data/v1_sft_demo.yaml
|
||||
### training
|
||||
output_dir: ./outputs/test_lora
|
||||
micro_batch_size: 1
|
||||
global_batch_size: 4
|
||||
cutoff_len: 2048
|
||||
learning_rate: 1.0e-4
|
||||
bf16: true
|
||||
max_steps: 10
|
||||
|
||||
### sample
|
||||
|
||||
40
examples/v1/train_lora/train_lora_sft_rank0.yaml
Normal file
@@ -0,0 +1,40 @@
|
||||
model: Qwen/Qwen3-4B
|
||||
model_class: llm
|
||||
|
||||
template: qwen3_nothink
|
||||
|
||||
# PEFT Configuration
|
||||
peft_config:
|
||||
name: lora
|
||||
r: 16
|
||||
lora_alpha: 32
|
||||
lora_dropout: 0.05
|
||||
target_modules: all
|
||||
|
||||
# Kernel Config
|
||||
kernel_config:
|
||||
name: auto
|
||||
include_kernels: auto
|
||||
|
||||
# FSDP Config
|
||||
dist_config:
|
||||
name: fsdp2
|
||||
dcp_path: null
|
||||
|
||||
init_config:
|
||||
name: init_on_rank0
|
||||
|
||||
### data
|
||||
train_dataset: data/v1_sft_demo.yaml
|
||||
|
||||
### training
|
||||
output_dir: ./outputs/test_lora
|
||||
micro_batch_size: 1
|
||||
cutoff_len: 2048
|
||||
learning_rate: 1.0e-4
|
||||
bf16: true
|
||||
max_steps: 10
|
||||
|
||||
### sample
|
||||
sample_backend: hf
|
||||
max_new_tokens: 128
|
||||
@@ -1,5 +1,4 @@
|
||||
model: Qwen/Qwen3-0.6B
|
||||
trust_remote_code: true
|
||||
model_class: llm
|
||||
|
||||
template: qwen3_nothink
|
||||
@@ -35,7 +34,6 @@ output_dir: outputs/test_quantization
|
||||
micro_batch_size: 1
|
||||
cutoff_len: 2048
|
||||
learning_rate: 1.0e-4
|
||||
bf16: false
|
||||
max_steps: 10
|
||||
|
||||
### sample
|
||||
|
||||
@@ -40,7 +40,7 @@ dependencies = [
|
||||
"torch>=2.4.0",
|
||||
"torchvision>=0.19.0",
|
||||
"torchaudio>=2.4.0",
|
||||
"transformers>=4.51.0,<=5.2.0,!=4.52.0,!=4.57.0",
|
||||
"transformers>=4.55.0,<=5.6.0,!=4.52.0,!=4.57.0",
|
||||
"datasets>=2.16.0,<=4.0.0",
|
||||
"accelerate>=1.3.0,<=1.11.0",
|
||||
"peft>=0.18.0,<=0.18.1",
|
||||
|
||||
1
requirements/ktransformers.txt
Normal file
@@ -0,0 +1 @@
|
||||
ktransformers[sft]
|
||||
@@ -1,4 +1,5 @@
|
||||
torch==2.7.1
|
||||
torch-npu==2.7.1
|
||||
torch-npu==2.7.1.post4
|
||||
torchvision==0.22.1
|
||||
torchaudio==2.7.1
|
||||
decorator
|
||||
|
||||
76
scripts/dcp2hf.py
Normal file
@@ -0,0 +1,76 @@
|
||||
# Copyright 2025 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Convert a DCP checkpoint to HuggingFace model format.
|
||||
|
||||
Usage:
|
||||
python scripts/dcp2hf.py convert --dcp_path=/path/to/dcp --hf_path=/path/to/hf --config_path=/path/to/config
|
||||
|
||||
Arguments:
|
||||
dcp_path: Path to the DCP checkpoint directory.
|
||||
hf_path: Output path (directory) for HuggingFace model.
|
||||
config_path: Path to the HuggingFace model directory containing config.json.
|
||||
"""
|
||||
|
||||
import fire
|
||||
import torch
|
||||
import torch.distributed.checkpoint as dcp
|
||||
import transformers
|
||||
from transformers import AutoConfig
|
||||
|
||||
|
||||
def convert(dcp_path: str, hf_path: str, config_path: str) -> None:
|
||||
"""Convert DCP model weights to HF.
|
||||
|
||||
Note: this script converts only the model weights from a DCP checkpoint to HuggingFace format.
|
||||
It does not convert the tokenizer; if you need one, copy the tokenizer files
|
||||
from the original model directory.
|
||||
|
||||
Args:
|
||||
dcp_path: DCP checkpoint directory.
|
||||
hf_path: Output path (directory) for HuggingFace model.
|
||||
config_path: Path to the HuggingFace model directory containing config.json.
|
||||
"""
|
||||
if not dcp_path or not hf_path or not config_path:
|
||||
raise ValueError("All 'dcp_path', 'hf_path', and 'config_path' are required.")
|
||||
|
||||
print(f"Loading config from {config_path}...")
|
||||
config = AutoConfig.from_pretrained(config_path)
|
||||
architectures = getattr(config, "architectures", [])
|
||||
if architectures:
|
||||
model_cls = getattr(transformers, architectures[0], transformers.AutoModelForCausalLM)
|
||||
else:
|
||||
model_cls = transformers.AutoModelForCausalLM
|
||||
|
||||
print("Initializing model on CPU...")
|
||||
model = model_cls(config).to(torch.bfloat16)
|
||||
|
||||
print(f"Loading DCP from {dcp_path}...")
|
||||
state_dict = model.state_dict()
|
||||
dcp.load(state_dict, checkpoint_id=dcp_path)
|
||||
model.load_state_dict(state_dict)
|
||||
|
||||
print(f"Saving to HF format at {hf_path}...")
|
||||
model.save_pretrained(hf_path)
|
||||
config.save_pretrained(hf_path)
|
||||
print("Done!")
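The class lookup in `convert` above (use the first entry of `config.architectures` if `transformers` exports it, otherwise fall back to `AutoModelForCausalLM`) can be sketched without loading any model. This is a minimal illustration only: `registry`, `FancyModel`, `DefaultModel`, and `resolve_model_cls` are hypothetical stand-ins, not names from the script or from `transformers`.

```python
from types import SimpleNamespace


class DefaultModel:
    """Hypothetical stand-in for transformers.AutoModelForCausalLM."""


class FancyModel:
    """Hypothetical stand-in for an architecture-specific model class."""


# Hypothetical stand-in for the `transformers` module namespace.
registry = SimpleNamespace(FancyModel=FancyModel, AutoModelForCausalLM=DefaultModel)


def resolve_model_cls(config, module=registry):
    """Pick the class named by config.architectures[0], else the default class."""
    architectures = getattr(config, "architectures", None) or []
    if architectures:
        # Unknown architecture names fall back to the default class.
        return getattr(module, architectures[0], module.AutoModelForCausalLM)
    return module.AutoModelForCausalLM


known = resolve_model_cls(SimpleNamespace(architectures=["FancyModel"]))
unknown = resolve_model_cls(SimpleNamespace(architectures=["NotExported"]))
missing = resolve_model_cls(SimpleNamespace(architectures=None))
```

Here `known` resolves to the architecture-specific class, while `unknown` and `missing` both resolve to the default class, which mirrors the two fallback branches in the script.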
|
||||
|
||||
|
||||
def help() -> None:
|
||||
"""Show help message."""
|
||||
print(__doc__)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire({"convert": convert, "help": help, "--convert": convert})
|
||||
@@ -25,7 +25,8 @@ Arguments:
|
||||
import fire
|
||||
import torch
|
||||
import torch.distributed.checkpoint as dcp
|
||||
from transformers import AutoModelForCausalLM
|
||||
import transformers
|
||||
from transformers import AutoConfig
|
||||
|
||||
|
||||
def convert(hf_path: str, dcp_path: str) -> None:
|
||||
@@ -39,7 +40,14 @@ def convert(hf_path: str, dcp_path: str) -> None:
|
||||
raise ValueError("Both 'hf_path' and 'dcp_path' are required.")
|
||||
|
||||
print(f"Loading HF model from {hf_path}...")
|
||||
model = AutoModelForCausalLM.from_pretrained(hf_path, device_map="cpu", torch_dtype=torch.bfloat16)
|
||||
config = AutoConfig.from_pretrained(hf_path)
|
||||
architectures = getattr(config, "architectures", [])
|
||||
if architectures:
|
||||
model_cls = getattr(transformers, architectures[0], transformers.AutoModelForCausalLM)
|
||||
else:
|
||||
model_cls = transformers.AutoModelForCausalLM
|
||||
|
||||
model = model_cls.from_pretrained(hf_path, device_map="cpu", torch_dtype=torch.bfloat16)
|
||||
|
||||
print(f"Saving to DCP format at {dcp_path}...")
|
||||
dcp.save(model.state_dict(), checkpoint_id=dcp_path)
|
||||
|
||||
@@ -71,6 +71,7 @@ def convert(
|
||||
pipeline_model_parallel_size: int = 1,
|
||||
expert_model_parallel_size: int = 1,
|
||||
virtual_pipeline_model_parallel_size: int | None = None,
|
||||
moe_grouped_gemm: bool | None = None,
|
||||
):
|
||||
"""Convert checkpoint between MCA and HuggingFace formats.
|
||||
|
||||
@@ -84,6 +85,10 @@ def convert(
|
||||
pipeline_model_parallel_size: Pipeline model parallel size
|
||||
expert_model_parallel_size: Expert model parallel size
|
||||
virtual_pipeline_model_parallel_size: Virtual pipeline model parallel size
|
||||
moe_grouped_gemm: Use grouped gemm for MoE experts. When enabled, expert
|
||||
weights are stored in a flattened format (linear_fc1.weight0, weight1, ...)
|
||||
rather than per-expert format (local_experts.0.linear_fc1.weight, ...).
|
||||
Must match the format used when saving the checkpoint.
|
||||
"""
|
||||
if bf16 and fp16:
|
||||
raise ValueError("bf16 and fp16 cannot be both True.")
|
||||
@@ -97,8 +102,9 @@ def convert(
|
||||
pipeline_model_parallel_size=pipeline_model_parallel_size,
|
||||
expert_model_parallel_size=expert_model_parallel_size,
|
||||
virtual_pipeline_model_parallel_size=virtual_pipeline_model_parallel_size,
|
||||
moe_grouped_gemm=moe_grouped_gemm,
|
||||
transformer_impl="transformer_engine", # hard-coded because training uses Transformer Engine by default
|
||||
)
|
||||
|
||||
convert_checkpoint_to_mca(
|
||||
checkpoint_path,
|
||||
output_path,
|
||||
|
||||
@@ -88,7 +88,10 @@ def _process_request(
|
||||
|
||||
if request.messages[0].role == Role.SYSTEM:
|
||||
content = request.messages.pop(0).content
|
||||
system = content[0].text if isinstance(content, list) else content
|
||||
if isinstance(content, list):
|
||||
system = content[0].text if content else ""
|
||||
else:
|
||||
system = content
|
||||
else:
|
||||
system = None
|
||||
|
||||
|
||||
@@ -71,16 +71,6 @@ class ChatModel:
|
||||
"SGLang not install, you may need to run `pip install sglang[all]`\n"
|
||||
"or try to use HuggingFace backend: --infer_backend huggingface"
|
||||
) from e
|
||||
elif model_args.infer_backend == EngineName.KT:
|
||||
try:
|
||||
from .kt_engine import KTransformersEngine
|
||||
|
||||
self.engine: BaseEngine = KTransformersEngine(model_args, data_args, finetuning_args, generating_args)
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"KTransformers not install, you may need to run `pip install ktransformers`\n"
|
||||
"or try to use HuggingFace backend: --infer_backend huggingface"
|
||||
) from e
|
||||
else:
|
||||
raise NotImplementedError(f"Unknown backend: {model_args.infer_backend}")
|
||||
|
||||
|
||||
@@ -1,284 +0,0 @@
|
||||
# Copyright 2025 the KVCache.AI team, Approaching AI, and the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import platform
|
||||
from collections.abc import AsyncGenerator
|
||||
from threading import Thread
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
import torch
|
||||
from typing_extensions import override
|
||||
|
||||
from ..data import get_template_and_fix_tokenizer
|
||||
from ..extras import logging
|
||||
from ..extras.constants import EngineName
|
||||
from ..model import load_model, load_tokenizer
|
||||
from .base_engine import BaseEngine, Response
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import PreTrainedTokenizer
|
||||
from trl import PreTrainedModelWrapper
|
||||
|
||||
from ..data.mm_plugin import AudioInput, ImageInput, VideoInput
|
||||
from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
|
||||
|
||||
from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled
|
||||
from ktransformers.server.config.config import Config
|
||||
from ktransformers.util.utils import (
|
||||
get_compute_capability,
|
||||
prefill_and_generate_capture,
|
||||
)
|
||||
from ktransformers.util.vendors import GPUVendor, device_manager
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class KTransformersEngine(BaseEngine):
|
||||
def __init__(
|
||||
self,
|
||||
model_args: "ModelArguments",
|
||||
data_args: "DataArguments",
|
||||
finetuning_args: "FinetuningArguments",
|
||||
generating_args: "GeneratingArguments",
|
||||
) -> None:
|
||||
self.name = EngineName.KT
|
||||
self.can_generate = finetuning_args.stage == "sft"
|
||||
|
||||
tok_mod = load_tokenizer(model_args)
|
||||
self.tokenizer = tok_mod["tokenizer"]
|
||||
self.tokenizer.padding_side = "left" if self.can_generate else "right"
|
||||
self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args)
|
||||
|
||||
self.model = load_model(
|
||||
self.tokenizer, model_args, finetuning_args, is_trainable=False, add_valuehead=(not self.can_generate)
|
||||
)
|
||||
|
||||
self.generating_args = generating_args.to_dict()
|
||||
self.max_new_tokens = model_args.kt_maxlen
|
||||
self.use_cuda_graph = model_args.kt_use_cuda_graph
|
||||
self.mode = model_args.kt_mode
|
||||
self.force_think = model_args.kt_force_think
|
||||
self.chunk_size = model_args.chunk_size
|
||||
|
||||
try:
|
||||
asyncio.get_event_loop()
|
||||
except RuntimeError:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
self.semaphore = asyncio.Semaphore(int(os.getenv("MAX_CONCURRENT", "1")))
|
||||
|
||||
@staticmethod
|
||||
@torch.inference_mode()
|
||||
def _get_scores(
|
||||
model: "PreTrainedModelWrapper",
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
batch_input: list[str],
|
||||
input_kwargs: Optional[dict[str, Any]] = {},
|
||||
) -> list[float]:
|
||||
max_length: Optional[int] = input_kwargs.pop("max_length", None)
|
||||
device = getattr(model.pretrained_model, "device", "cuda")
|
||||
inputs = tokenizer(
|
||||
batch_input,
|
||||
padding=True,
|
||||
truncation=True,
|
||||
max_length=max_length or getattr(model.config, "max_position_embeddings", 1024),
|
||||
return_tensors="pt",
|
||||
add_special_tokens=False,
|
||||
).to(device)
|
||||
values: torch.Tensor = model(**inputs, return_dict=True, use_cache=False)[-1]
|
||||
scores = values.gather(dim=-1, index=(inputs["attention_mask"].sum(dim=-1, keepdim=True) - 1))
|
||||
return scores
|
||||
|
||||
async def _generate(
|
||||
self,
|
||||
messages: list[dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
**input_kwargs,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
paired = messages + [{"role": "assistant", "content": ""}]
|
||||
prompt_ids, _ = self.template.encode_oneturn(self.tokenizer, paired, system, tools)
|
||||
prompt_len = len(prompt_ids)
|
||||
|
||||
max_length: Optional[int] = input_kwargs.pop("max_length", None)
|
||||
max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None)
|
||||
|
||||
if "max_new_tokens" in self.generating_args:
|
||||
max_tokens = int(self.generating_args["max_new_tokens"])
|
||||
elif "max_length" in self.generating_args:
|
||||
gl = int(self.generating_args["max_length"])
|
||||
max_tokens = gl - prompt_len if gl > prompt_len else 1
|
||||
else:
|
||||
max_tokens = self.max_new_tokens or 256
|
||||
|
||||
if max_length is not None:
|
||||
max_tokens = max(max_length - prompt_len, 1)
|
||||
if max_new_tokens is not None:
|
||||
max_tokens = int(max_new_tokens)
|
||||
max_tokens = max(1, int(max_tokens))
|
||||
|
||||
if self.mode == "long_context":
|
||||
max_len_cfg = Config().long_context_config["max_seq_len"]
|
||||
need = prompt_len + max_tokens
|
||||
assert max_len_cfg > need, f"please set max_seq_len > {need} in ~/.ktransformers/config.yaml"
|
||||
|
||||
device = next(self.model.parameters()).device
|
||||
input_tensor = torch.tensor([prompt_ids], dtype=torch.long, device=device)
|
||||
if self.force_think:
|
||||
think = torch.tensor(
|
||||
[self.tokenizer.encode("<think>\n", add_special_tokens=False)], dtype=torch.long, device=device
|
||||
)
|
||||
input_tensor = torch.cat([input_tensor, think], dim=1)
|
||||
|
||||
use_flashinfer = (
|
||||
platform.system() != "Windows"
|
||||
and getattr(self.model.config, "architectures", [""])[0]
|
||||
in {"DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"}
|
||||
and flashinfer_enabled
|
||||
and get_compute_capability() >= 8
|
||||
and device_manager.gpu_vendor == GPUVendor.NVIDIA
|
||||
)
|
||||
|
||||
def make_gen():
|
||||
if use_flashinfer:
|
||||
return prefill_and_generate_capture(
|
||||
self.model,
|
||||
self.tokenizer,
|
||||
input_tensor,
|
||||
max_tokens,
|
||||
self.use_cuda_graph,
|
||||
mode=self.mode,
|
||||
force_think=self.force_think,
|
||||
chunk_size=self.chunk_size,
|
||||
use_flashinfer_mla=True,
|
||||
num_heads=self.model.config.num_attention_heads,
|
||||
head_dim_ckv=getattr(self.model.config, "kv_lora_rank", 0),
|
||||
head_dim_kpe=getattr(self.model.config, "qk_rope_head_dim", 0),
|
||||
q_head_dim=getattr(self.model.config, "qk_rope_head_dim", 0)
|
||||
+ getattr(self.model.config, "qk_nope_head_dim", 0),
|
||||
echo_stream=False,
|
||||
)
|
||||
else:
|
||||
return prefill_and_generate_capture(
|
||||
self.model,
|
||||
self.tokenizer,
|
||||
input_tensor,
|
||||
max_tokens,
|
||||
self.use_cuda_graph,
|
||||
mode=self.mode,
|
||||
force_think=self.force_think,
|
||||
chunk_size=self.chunk_size,
|
||||
echo_stream=False,
|
||||
)
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
q: asyncio.Queue[Optional[str]] = asyncio.Queue()
|
||||
|
||||
def producer():
|
||||
try:
|
||||
gen = make_gen()
|
||||
if hasattr(gen, "__aiter__"):
|
||||
|
||||
async def drain_async():
|
||||
async for t in gen:
|
||||
loop.call_soon_threadsafe(q.put_nowait, t if isinstance(t, str) else str(t))
|
||||
|
||||
asyncio.run(drain_async())
|
||||
elif hasattr(gen, "__iter__"):
|
||||
for t in gen:
|
||||
loop.call_soon_threadsafe(q.put_nowait, t if isinstance(t, str) else str(t))
|
||||
else:
|
||||
loop.call_soon_threadsafe(q.put_nowait, gen if isinstance(gen, str) else str(gen))
|
||||
finally:
|
||||
loop.call_soon_threadsafe(q.put_nowait, None)
|
||||
|
||||
Thread(target=producer, daemon=True).start()
|
||||
|
||||
while True:
|
||||
item = await q.get()
|
||||
if item is None:
|
||||
break
|
||||
yield item
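The pattern used by `_generate` above, where a blocking generator runs in a worker thread and its items are handed to an async consumer through an `asyncio.Queue`, can be sketched in isolation. This is a simplified illustration that assumes a plain synchronous iterator; `stream_from_thread` and `collect` are hypothetical names, and `None` is used as the end-of-stream sentinel as in the code above.

```python
import asyncio
from threading import Thread


async def stream_from_thread(make_iter):
    """Yield items produced by a blocking iterator that runs in a worker thread."""
    loop = asyncio.get_running_loop()
    queue = asyncio.Queue()

    def producer():
        try:
            for item in make_iter():
                # Queue operations must run on the event loop's own thread.
                loop.call_soon_threadsafe(queue.put_nowait, item)
        finally:
            # Always send the sentinel so the consumer cannot wait forever.
            loop.call_soon_threadsafe(queue.put_nowait, None)

    Thread(target=producer, daemon=True).start()

    while True:
        item = await queue.get()
        if item is None:
            break
        yield item


async def collect():
    return [t async for t in stream_from_thread(lambda: iter(["a", "b", "c"]))]


result = asyncio.run(collect())
```

Sending the sentinel from a `finally` block means the consumer loop also terminates when the producer raises, at the cost of silently swallowing the error in the worker thread.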
|
||||
|
||||
@override
|
||||
async def chat(
|
||||
self,
|
||||
messages: list[dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
images: Optional[list["ImageInput"]] = None,
|
||||
videos: Optional[list["VideoInput"]] = None,
|
||||
audios: Optional[list["AudioInput"]] = None,
|
||||
**input_kwargs,
|
||||
) -> list["Response"]:
|
||||
if not self.can_generate:
|
||||
raise ValueError("The current model does not support `chat`.")
|
||||
async with self.semaphore:
|
||||
produced = ""
|
||||
final_text = ""
|
||||
async for t in self._generate(messages, system, tools, **input_kwargs):
|
||||
delta = t
|
||||
produced = produced + delta
|
||||
if delta:
|
||||
final_text += delta
|
||||
|
||||
prompt_ids, _ = self.template.encode_oneturn(
|
||||
self.tokenizer, messages + [{"role": "assistant", "content": ""}], system, tools
|
||||
)
|
||||
return [
|
||||
Response(
|
||||
response_text=final_text,
|
||||
response_length=len(self.tokenizer.encode(final_text, add_special_tokens=False)),
|
||||
prompt_length=len(prompt_ids),
|
||||
finish_reason="stop",
|
||||
)
|
||||
]
|
||||
|
||||
@override
|
||||
async def stream_chat(
|
||||
self,
|
||||
messages: list[dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
images: Optional[list["ImageInput"]] = None,
|
||||
videos: Optional[list["VideoInput"]] = None,
|
||||
audios: Optional[list["AudioInput"]] = None,
|
||||
**input_kwargs,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
if not self.can_generate:
|
||||
raise ValueError("The current model does not support `stream_chat`.")
|
||||
async with self.semaphore:
|
||||
produced = ""
|
||||
async for t in self._generate(messages, system, tools, **input_kwargs):
|
||||
delta = t[len(produced) :] if t.startswith(produced) else t
|
||||
produced = t
|
||||
if delta:
|
||||
yield delta
|
||||
|
||||
@override
|
||||
async def get_scores(
|
||||
self,
|
||||
batch_input: list[str],
|
||||
**input_kwargs,
|
||||
) -> list[float]:
|
||||
if self.can_generate:
|
||||
raise ValueError("Cannot get scores using an auto-regressive model.")
|
||||
args = (self.model, self.tokenizer, batch_input, input_kwargs)
|
||||
async with self.semaphore:
|
||||
return await asyncio.to_thread(self._get_scores, *args)
|
||||
@@ -15,6 +15,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import copy
|
||||
import inspect
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, Literal, Optional
|
||||
@@ -25,7 +26,7 @@ import torch.nn.functional as F
|
||||
from peft import PeftModel
|
||||
from transformers import DataCollatorForSeq2Seq
|
||||
|
||||
from ..extras.constants import AUDIO_PLACEHOLDER, IGNORE_INDEX, IMAGE_PLACEHOLDER
|
||||
from ..extras.constants import AUDIO_PLACEHOLDER, IGNORE_INDEX, IMAGE_PLACEHOLDER, MROPE_MODELS
|
||||
from ..extras.packages import is_pillow_available
|
||||
|
||||
|
||||
@@ -39,6 +40,56 @@ if TYPE_CHECKING:
|
||||
from .template import Template
|
||||
|
||||
|
||||
def _slice_mm_inputs_for_sample(
|
||||
mm_inputs: dict[str, Any],
|
||||
batch_imglens: list[int],
|
||||
batch_vidlens: list[int],
|
||||
batch_idx: int,
|
||||
images_per_subseq: Optional[list[int]] = None,
|
||||
videos_per_subseq: Optional[list[int]] = None,
|
||||
subseq_idx: Optional[int] = None,
|
||||
) -> dict[str, Any]:
|
||||
r"""Slice mm_inputs for one batch sample, optionally for a single sub-sequence when packing.
|
||||
|
||||
image_grid_thw / video_grid_thw have shape [num_items, 3]. The items for sample batch_idx
are its batch_imglens[batch_idx] images and batch_vidlens[batch_idx] videos. When subseq_idx
is given, further restrict to that sub-sequence's items via images_per_subseq / videos_per_subseq.
has_dummy_image=True means only batch[0] is concatenated with a fake image and there is no multimodal data.
|
||||
"""
|
||||
image_start_idx = sum(batch_imglens[:batch_idx])
|
||||
image_end_idx = sum(batch_imglens[: batch_idx + 1])
|
||||
video_start_idx = sum(batch_vidlens[:batch_idx])
|
||||
video_end_idx = sum(batch_vidlens[: batch_idx + 1])
|
||||
|
||||
if subseq_idx is not None and images_per_subseq is not None:
|
||||
image_start_idx += sum(images_per_subseq[:subseq_idx])
|
||||
image_end_idx = image_start_idx + images_per_subseq[subseq_idx]
|
||||
|
||||
if subseq_idx is not None and videos_per_subseq is not None:
|
||||
video_start_idx += sum(videos_per_subseq[:subseq_idx])
|
||||
video_end_idx = video_start_idx + videos_per_subseq[subseq_idx]
|
||||
|
||||
sliced_mm_inputs: dict[str, Any] = {}
|
||||
key_to_slice_meta = {
|
||||
"image_grid_thw": (image_start_idx, image_end_idx, True),
|
||||
"video_grid_thw": (video_start_idx, video_end_idx, True),
|
||||
"second_per_grid_ts": (video_start_idx, video_end_idx, False), # qwen2.5vl
|
||||
"video_second_per_grid": (video_start_idx, video_end_idx, False), # qwen omni
|
||||
}
|
||||
|
||||
for key, (start_idx, end_idx, assign_none_when_empty) in key_to_slice_meta.items():
|
||||
if key not in mm_inputs:
|
||||
continue
|
||||
|
||||
mm_value = mm_inputs[key]
|
||||
if mm_value is not None and end_idx > start_idx:
|
||||
sliced_mm_inputs[key] = mm_value[start_idx:end_idx]
|
||||
elif assign_none_when_empty:
|
||||
sliced_mm_inputs[key] = None
|
||||
|
||||
return sliced_mm_inputs
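The index arithmetic in `_slice_mm_inputs_for_sample` above is a prefix sum over per-sample item counts, optionally narrowed by per-sub-sequence counts. A minimal sketch with plain lists follows; `slice_bounds` is a hypothetical helper written for illustration and is not part of the collator.

```python
def slice_bounds(counts, batch_idx, per_subseq=None, subseq_idx=None):
    """Return (start, end) into a flat per-item list for one sample or one sub-sequence."""
    # Items of earlier samples come first in the flat list.
    start = sum(counts[:batch_idx])
    end = sum(counts[: batch_idx + 1])
    if subseq_idx is not None and per_subseq is not None:
        # Skip the items that belong to earlier sub-sequences of this sample.
        start += sum(per_subseq[:subseq_idx])
        end = start + per_subseq[subseq_idx]
    return start, end


batch_imglens = [2, 0, 3]  # images per sample in the batch

first_sample = slice_bounds(batch_imglens, 0)
empty_sample = slice_bounds(batch_imglens, 1)
last_sample = slice_bounds(batch_imglens, 2)
# Sample 2 packs two sub-sequences holding 1 and 2 images respectively.
second_subseq = slice_bounds(batch_imglens, 2, per_subseq=[1, 2], subseq_idx=1)
```

A sample without images yields an empty range (`start == end`), which is the case where the function above assigns `None` to the grid keys instead of an empty slice.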
|
||||
|
||||
|
||||
def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype: "torch.dtype") -> "torch.Tensor":
|
||||
r"""Expand 2d attention mask to 4d attention mask.
|
||||
|
||||
@@ -106,9 +157,174 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
||||
else:
|
||||
self.get_rope_func = None
|
||||
|
||||
def _compute_rope_position_ids(self, features: dict[str, "torch.Tensor"], mm_inputs: dict[str, Any]) -> None:
|
||||
r"""Compute position_ids and rope_deltas via get_rope_func for VLMs."""
|
||||
rope_index_kwargs = {
|
||||
"input_ids": features["input_ids"],
|
||||
"image_grid_thw": mm_inputs.get("image_grid_thw"),
|
||||
"video_grid_thw": mm_inputs.get("video_grid_thw"),
|
||||
"attention_mask": (features["attention_mask"] >= 1).float(),
|
||||
}
|
||||
if features["attention_mask"].sum() == 0: # for pad tokens
|
||||
seq_len = features["input_ids"].shape[-1]
|
||||
features["position_ids"] = (
|
||||
torch.arange(seq_len).view(1, 1, seq_len).expand(3, *features["input_ids"].shape).contiguous()
|
||||
)
|
||||
features["rope_deltas"] = torch.zeros(features["input_ids"].shape[0])
|
||||
return
|
||||
|
||||
if "mm_token_type_ids" in inspect.signature(self.get_rope_func).parameters:
|
||||
image_token_id = getattr(self.model.config, "image_token_id", None)
|
||||
video_token_id = getattr(self.model.config, "video_token_id", None)
|
||||
if image_token_id is not None or video_token_id is not None:
|
||||
mm_token_type_ids = torch.zeros_like(features["input_ids"])
|
||||
if image_token_id is not None:
|
||||
mm_token_type_ids[features["input_ids"] == image_token_id] = 1
|
||||
if video_token_id is not None:
|
||||
mm_token_type_ids[features["input_ids"] == video_token_id] = 2
|
||||
rope_index_kwargs["mm_token_type_ids"] = mm_token_type_ids
|
||||
|
||||
if "second_per_grid_ts" in mm_inputs: # for qwen2.5vl
|
||||
rope_index_kwargs["second_per_grid_ts"] = mm_inputs.get("second_per_grid_ts")
|
||||
elif "video_second_per_grid" in mm_inputs: # for qwen2.5 omni
|
||||
rope_index_kwargs["second_per_grids"] = mm_inputs.get("video_second_per_grid")
|
||||
|
||||
if getattr(self.model.config, "model_type", None) in ["qwen2_5_omni_thinker", "qwen3_omni_moe_thinker"]:
|
||||
rope_index_kwargs["use_audio_in_video"] = getattr(self.processor, "use_audio_in_video", False)
|
||||
feature_attention_mask = mm_inputs.get("feature_attention_mask", None)
|
||||
if feature_attention_mask is not None: # FIXME: need to get video image lengths
|
||||
audio_feature_lengths = torch.sum(feature_attention_mask, dim=1)
|
||||
rope_index_kwargs["audio_seqlens"] = audio_feature_lengths # prepare for input
|
||||
|
||||
features["position_ids"], rope_deltas = self.get_rope_func(**rope_index_kwargs)
|
||||
features["rope_deltas"] = rope_deltas - (1 - rope_index_kwargs["attention_mask"]).sum(dim=-1).unsqueeze(-1)
|
||||
else: # for qwen vl
|
||||
features["position_ids"], features["rope_deltas"] = self.get_rope_func(**rope_index_kwargs)
|
||||
|
||||
def _compute_rope_position_ids_with_packing(
|
||||
self,
|
||||
features: dict[str, "torch.Tensor"],
|
||||
mm_inputs: dict[str, Any],
|
||||
packing_params_list: list[dict[str, Any] | None],
|
||||
batch_imglens: list[int],
|
||||
batch_vidlens: list[int],
|
||||
batch_audlens: list[int],
|
||||
has_dummy_image: bool,
|
||||
) -> None:
|
||||
r"""Compute position_ids and rope_deltas per sample (or per sub-sequence when packed), then merge and validate."""
|
||||
bsz = features["input_ids"].size(0)
|
||||
seq_len = features["input_ids"].size(1)
|
||||
all_position_ids: list[torch.Tensor] = []
|
||||
all_rope_deltas: list[torch.Tensor] = []
|
||||
|
||||
if has_dummy_image:
|
||||
# for [0, seq_len] = [0, unpadded_length + right_padding_length + fake_input_ids_len + collator_padding_length]
|
||||
# FIXME: maybe right_padding_length is large, with improper max_cutoff_len
|
||||
unpadded_length = int(features["attention_mask"][0].bool().sum().item())
|
||||
right_padding_length = int((packing_params_list[0] or {}).get("right_padding_length") or 0)
|
||||
fake_input_padding_length = max(0, seq_len - unpadded_length - right_padding_length)
|
||||
# avoid continual cuseqlens breaking varlen attention @kuangdd
|
||||
# https://github.com/hiyouga/LlamaFactory/issues/10452
|
||||
dummy_image_right_padding_mrope = (
|
||||
torch.arange(fake_input_padding_length)
|
||||
.view(1, 1, fake_input_padding_length)
|
||||
.expand(3, bsz, fake_input_padding_length)
|
||||
)
|
||||
dummy_image_right_padding_attention_mask = torch.zeros((bsz, fake_input_padding_length))
|
||||
assert self.tokenizer.padding_side == "right", "padding_side should be right when fake image is injected"
|
||||
dummy_mm_inputs = copy.deepcopy(mm_inputs)
|
||||
|
||||
for sample_idx in range(bsz):
|
||||
sample_packing = (packing_params_list[sample_idx] or {}) if sample_idx < len(packing_params_list) else {}
|
||||
sequence_boundaries = sample_packing.get("sequence_boundaries")
|
||||
num_sub_seqs = (
|
||||
(len(sequence_boundaries) - 1) if sequence_boundaries and len(sequence_boundaries) > 1 else 1
|
||||
)
|
||||
image_subseq_ids = sample_packing.get("image_subseq_ids") or []
|
||||
video_subseq_ids = sample_packing.get("video_subseq_ids") or []
|
||||
images_per_subseq = (
|
||||
[image_subseq_ids.count(i) for i in range(num_sub_seqs)]
|
||||
if image_subseq_ids and num_sub_seqs > 1
|
||||
else None
|
||||
)
|
||||
videos_per_subseq = (
|
||||
[video_subseq_ids.count(i) for i in range(num_sub_seqs)]
|
||||
if video_subseq_ids and num_sub_seqs > 1
|
||||
else None
|
||||
)
|
||||
if has_dummy_image:
|
||||
mm_inputs = {}
|
||||
|
||||
if num_sub_seqs <= 1:
|
||||
sample_features = {
|
||||
"input_ids": features["input_ids"],
|
||||
"attention_mask": features["attention_mask"][sample_idx : sample_idx + 1],
|
||||
}
|
||||
mm_inputs_for_sample = _slice_mm_inputs_for_sample(
|
||||
mm_inputs, batch_imglens, batch_vidlens, batch_idx=sample_idx
|
||||
)
|
||||
self._compute_rope_position_ids(sample_features, mm_inputs_for_sample)
|
||||
all_position_ids.append(sample_features["position_ids"])
|
||||
all_rope_deltas.append(sample_features["rope_deltas"])
|
||||
else:
|
||||
# rope_deltas are not needed during training when sequences are packed.
|
||||
sample_position_ids: list[torch.Tensor] = []
|
||||
for subseq_idx in range(num_sub_seqs):
|
||||
subseq_start = sequence_boundaries[subseq_idx]
|
||||
subseq_end = sequence_boundaries[subseq_idx + 1]
|
||||
subseq_features = {
|
||||
"input_ids": features["input_ids"][sample_idx : sample_idx + 1, subseq_start:subseq_end],
|
||||
"attention_mask": features["attention_mask"][
|
||||
sample_idx : sample_idx + 1, subseq_start:subseq_end
|
||||
],
|
||||
}
|
||||
mm_inputs_for_subseq = _slice_mm_inputs_for_sample(
|
||||
mm_inputs,
|
||||
batch_imglens,
|
||||
batch_vidlens,
|
||||
sample_idx,
|
||||
images_per_subseq,
|
||||
videos_per_subseq,
|
||||
subseq_idx,
|
||||
)
|
||||
self._compute_rope_position_ids(subseq_features, mm_inputs_for_subseq)
|
||||
sample_position_ids.append(subseq_features["position_ids"])
|
||||
|
||||
all_position_ids.append(torch.cat(sample_position_ids, dim=-1))
|
||||
|
||||
batch_dim_for_position_ids = 1 if all_position_ids[0].dim() == 3 else 0
|
||||
|
||||
features["position_ids"] = torch.cat(all_position_ids, dim=batch_dim_for_position_ids)
|
||||
if has_dummy_image:
|
||||
mm_inputs = dummy_mm_inputs
|
||||
|
||||
expected_position_ids_shape = (
|
||||
(bsz, seq_len)
|
||||
if all_position_ids[0].dim() == 2
|
||||
else (
|
||||
all_position_ids[0].size(0),
|
||||
bsz,
|
||||
seq_len,
|
||||
)
|
||||
)
|
||||
# Check if position_ids shape matches expected shape.
|
||||
# For later use, pad position_ids on the right when there are padding tokens on the right.
|
||||
if has_dummy_image:
|
||||
features["position_ids"] = torch.cat([features["position_ids"], dummy_image_right_padding_mrope], dim=-1)
|
||||
features["attention_mask"] = torch.cat(
|
||||
[features["attention_mask"], dummy_image_right_padding_attention_mask], dim=-1
|
||||
)
|
||||
|
||||
if features["position_ids"].shape != expected_position_ids_shape:
|
||||
raise ValueError(
|
||||
"Merged position_ids shape mismatch: "
|
||||
f"got {features['position_ids'].shape}, expected {expected_position_ids_shape}."
|
||||
)
|
||||
|
||||
def __call__(self, features: list[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
|
||||
batch_images, batch_videos, batch_audios = [], [], []
|
||||
batch_imglens, batch_vidlens, batch_audlens, batch_input_ids = [], [], [], []
|
||||
packing_params_list: list[dict[str, Any] | None] = []
|
||||
for feature in features:
|
||||
images = feature.pop("images", None) or []
|
||||
videos = feature.pop("videos", None) or []
|
||||
@@ -120,8 +336,10 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
||||
batch_vidlens.append(len(videos))
|
||||
batch_audlens.append(len(audios))
|
||||
batch_input_ids.append(feature["input_ids"])
|
||||
packing_params_list.append(feature.pop("packing_params", None))
|
||||
|
||||
fake_input_ids = []
|
||||
has_dummy_image = False
|
||||
if (
|
||||
self.template.mm_plugin.image_token is not None and sum(batch_imglens) == 0 and sum(batch_vidlens) == 0
|
||||
): # avoid process hanging in zero3/fsdp case
|
||||
@@ -137,6 +355,7 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
||||
fake_input_ids.extend(_fake_input_ids)
|
||||
batch_images = fake_images
|
||||
batch_imglens[0] = 1
|
||||
has_dummy_image = True
|
||||
|
||||
if (
|
||||
self.template.mm_plugin.audio_token is not None and sum(batch_audlens) == 0
|
||||
@@ -181,59 +400,63 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
||||
for i, feature in enumerate(features):
|
||||
feature["token_type_ids"] = token_type_ids[i]
|
||||
|
||||
if "mm_token_type_ids" in mm_inputs: # need tensor-like for gemma4
|
||||
mm_token_type_ids = mm_inputs.pop("mm_token_type_ids")
|
||||
max_len = max(len(ids) for ids in mm_token_type_ids)
|
||||
padded = []
|
||||
for ids in mm_token_type_ids:
|
||||
pad_len = max_len - len(ids)
|
||||
if self.tokenizer.padding_side == "right":
|
||||
padded.append(ids + [0] * pad_len)
|
||||
else:
|
||||
padded.append([0] * pad_len + ids)
|
||||
|
||||
mm_inputs["mm_token_type_ids"] = torch.tensor(padded, dtype=torch.long)
|
||||
|
||||
features: dict[str, torch.Tensor] = super().__call__(features)
|
||||
|
||||
bsz, seq_len = features["input_ids"].shape[:2]
|
||||
model_type = getattr(self.model.config, "model_type", None) if self.model is not None else None
|
||||
is_omni = model_type in [
|
||||
"qwen2_5_omni_thinker",
|
||||
"qwen3_omni_moe_thinker",
|
||||
]
|
||||
|
||||
if self.get_rope_func is not None:
|
||||
rope_index_kwargs = {
|
||||
"input_ids": features["input_ids"],
|
||||
"image_grid_thw": mm_inputs.get("image_grid_thw"),
|
||||
"video_grid_thw": mm_inputs.get("video_grid_thw"),
|
||||
"attention_mask": (features["attention_mask"] >= 1).float(),
|
||||
}
|
||||
if "mm_token_type_ids" in inspect.signature(self.get_rope_func).parameters:
|
||||
image_token_id = getattr(self.model.config, "image_token_id", None)
|
||||
video_token_id = getattr(self.model.config, "video_token_id", None)
|
||||
if image_token_id is not None or video_token_id is not None:
|
||||
mm_token_type_ids = torch.zeros_like(features["input_ids"])
|
||||
if image_token_id is not None:
|
||||
mm_token_type_ids[features["input_ids"] == image_token_id] = 1
|
||||
if video_token_id is not None:
|
||||
mm_token_type_ids[features["input_ids"] == video_token_id] = 2
|
||||
rope_index_kwargs["mm_token_type_ids"] = mm_token_type_ids
|
||||
if "second_per_grid_ts" in mm_inputs: # for qwen2vl
|
||||
rope_index_kwargs["second_per_grid_ts"] = mm_inputs.get("second_per_grid_ts")
|
||||
elif "video_second_per_grid" in mm_inputs: # for qwen2.5 omni
|
||||
rope_index_kwargs["second_per_grids"] = mm_inputs.get("video_second_per_grid")
|
||||
# For mrope, position_ids and rope_deltas must be computed per sample.
|
||||
# When neat_packing is on, each sample has packing_params; None means no packing for that sample.
|
||||
boundaries_list = [p.get("sequence_boundaries") if p is not None else None for p in packing_params_list]
|
||||
has_packing = any(b is not None and len(b) > 2 for b in boundaries_list)
|
||||
if has_dummy_image and has_packing:
|
||||
# FIXME: too tricky, need to be refactored @kuangdd
|
||||
features["has_dummy_image"] = True
|
||||
|
||||
if getattr(self.model.config, "model_type", None) in ["qwen2_5_omni_thinker", "qwen3_omni_moe_thinker"]:
|
||||
rope_index_kwargs["use_audio_in_video"] = getattr(self.processor, "use_audio_in_video", False)
|
||||
feature_attention_mask = mm_inputs.get("feature_attention_mask", None)
|
||||
if feature_attention_mask is not None: # FIXME: need to get video image lengths
|
||||
audio_feature_lengths = torch.sum(feature_attention_mask, dim=1)
|
||||
rope_index_kwargs["audio_seqlens"] = audio_feature_lengths # prepare for input
|
||||
# When fake image/audio was injected, sequence_boundaries no longer match the tensor; use non-packing path.
|
||||
if not has_packing:
|
||||
self._compute_rope_position_ids(features, mm_inputs)
|
||||
else:
|
||||
if is_omni: # TODO: support omni models for packed sequences @kuangdd
|
||||
raise RuntimeError("Omni models are not supported for packed sequences for now.")
|
||||
|
||||
features["position_ids"], rope_deltas = self.get_rope_func(**rope_index_kwargs)
|
||||
features["rope_deltas"] = rope_deltas - (1 - rope_index_kwargs["attention_mask"]).sum(
|
||||
dim=-1
|
||||
).unsqueeze(-1)
|
||||
else: # for qwen vl
|
||||
features["position_ids"], features["rope_deltas"] = self.get_rope_func(**rope_index_kwargs)
|
||||
self._compute_rope_position_ids_with_packing(
|
||||
features,
|
||||
mm_inputs,
|
||||
packing_params_list,
|
||||
batch_imglens,
|
||||
batch_vidlens,
|
||||
batch_audlens,
|
||||
has_dummy_image,
|
||||
)
|
||||
|
||||
# For transformers compatibility, after https://github.com/huggingface/transformers/issues/39400
|
||||
if features["position_ids"].dim() == 3:
|
||||
features["position_ids"] = torch.cat(
|
||||
[features["position_ids"][0].unsqueeze(0), features["position_ids"]], dim=0
|
||||
)
|
||||
|
||||
if (
|
||||
self.model is not None
|
||||
and getattr(self.model.config, "model_type", None)
|
||||
in [
|
||||
"glm4v",
|
||||
"glm_ocr",
|
||||
"Keye",
|
||||
"qwen2_vl",
|
||||
"qwen2_5_vl",
|
||||
"qwen2_5_omni_thinker",
|
||||
"qwen3_omni_moe_thinker",
|
||||
"qwen3_5",
|
||||
"qwen3_vl",
|
||||
"qwen3_vl_moe",
|
||||
]
|
||||
and getattr(self.model.config, "model_type", None) in MROPE_MODELS
|
||||
and ("position_ids" not in features or features["position_ids"].dim() != 3)
|
||||
):
|
||||
raise ValueError(f"{self.model.config.model_type} requires 3D position ids for mrope.")
|
||||
@@ -261,12 +484,53 @@ class SFTDataCollatorWith4DAttentionMask(MultiModalDataCollatorForSeq2Seq):
|
||||
block_diag_attn: bool = False
|
||||
attn_implementation: Literal["eager", "sdpa", "flash_attention_2"] = "eager"
|
||||
compute_dtype: "torch.dtype" = torch.float32
|
||||
neat_packing: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
if self.neat_packing and self.attn_implementation == "flash_attention_2":
|
||||
if self.model is not None and getattr(self.model.config, "model_type", None) in ["gemma4", "gpt_oss"]:
|
||||
raise ValueError("Neat packing is not supported for gemma4, gpt_oss models for now.")
|
||||
|
||||
@staticmethod
|
||||
def _unpad_packed_features(features: dict[str, Any]) -> None:
|
||||
r"""Trim padded positions for packed FA2 batches."""
|
||||
attention_mask = features.get("attention_mask")
|
||||
if not torch.is_tensor(attention_mask) or attention_mask.dim() != 2 or attention_mask.size(0) != 1:
|
||||
return
|
||||
|
||||
seq_len = attention_mask.size(1)
|
||||
non_padding_indices = torch.nonzero(attention_mask[0] != 0, as_tuple=False).flatten()
|
||||
if non_padding_indices.numel() == seq_len:
|
||||
return
|
||||
|
||||
keys_on_seq_dim_1 = {"input_ids", "labels", "attention_mask", "token_type_ids"}
|
||||
for key, value in list(features.items()):
|
||||
if not torch.is_tensor(value):
|
||||
continue
|
||||
|
||||
if key == "position_ids" and value.size(-1) == seq_len:
|
||||
features[key] = value.index_select(-1, non_padding_indices)
|
||||
elif (
|
||||
key == "cross_attention_mask" and value.dim() >= 2 and value.size(0) == 1 and value.size(1) == seq_len
|
||||
):
|
||||
features[key] = value.index_select(1, non_padding_indices)
|
||||
elif key in keys_on_seq_dim_1 and value.dim() == 2 and value.size(0) == 1 and value.size(1) == seq_len:
|
||||
features[key] = value.index_select(1, non_padding_indices)
|
||||
|
||||
def __call__(self, features: list[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
|
||||
features = super().__call__(features)
|
||||
has_dummy_image = features.pop("has_dummy_image", False)
|
||||
if self.block_diag_attn and self.attn_implementation != "flash_attention_2":
|
||||
features["attention_mask"] = prepare_4d_attention_mask(features["attention_mask"], self.compute_dtype)
|
||||
|
||||
if self.neat_packing and self.attn_implementation == "flash_attention_2": # FIXME compatibility fa3/fa4
|
||||
assert features["input_ids"].shape[0] == 1, "bsz should be 1 for neat packing"
|
||||
if not has_dummy_image:
|
||||
self._unpad_packed_features(features)
|
||||
|
||||
features["attention_mask"] = None # let transformers handle causal packed mask.
|
||||
|
||||
for key, value in features.items(): # cast data dtype for paligemma
|
||||
if torch.is_tensor(value) and torch.is_floating_point(value):
|
||||
features[key] = value.to(self.compute_dtype)
|
||||
|
||||
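The `_unpad_packed_features` helper in the hunk above trims padded positions from a packed batch of size 1 before the attention mask is dropped. A minimal sketch of that idea follows, with plain Python lists standing in for tensors. The function name `unpad_packed` and the sample values are assumptions for illustration only. The real helper works in place on `torch` tensors and also handles `position_ids` and `cross_attention_mask`.

```python
from typing import Any


def unpad_packed(features: dict[str, list[Any]]) -> dict[str, list[Any]]:
    """Drop positions whose attention-mask entry is 0 from every per-token field."""
    mask = features["attention_mask"]
    keep = [i for i, m in enumerate(mask) if m != 0]
    if len(keep) == len(mask):  # nothing is padded, return the input untouched
        return features
    return {
        # only fields that share the sequence length are trimmed
        key: [value[i] for i in keep] if len(value) == len(mask) else value
        for key, value in features.items()
    }


# Hypothetical packed sample: non-zero mask entries mark real tokens, 0 marks padding.
packed = {
    "input_ids": [5, 6, 7, 0, 0],
    "attention_mask": [1, 1, 2, 0, 0],
    "labels": [-100, 6, 7, -100, -100],
}
trimmed = unpad_packed(packed)
```

Filtering on `m != 0` rather than `m == 1` mirrors the `attention_mask[0] != 0` test in the hunk, so mask values other than 1 still count as real tokens.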
@@ -257,8 +257,8 @@ class OpenAIDatasetConverter(DatasetConverter):
             content = message[self.dataset_attr.content_tag]

             if role in [self.dataset_attr.assistant_tag, self.dataset_attr.function_tag]:
-                if "tool_calls" in message and len(message["tool_calls"]) > 0:
-                    tool_calls_list = [tool["function"] for tool in message["tool_calls"]]
+                if tool_calls := message.get("tool_calls"):
+                    tool_calls_list = [tool["function"] for tool in tool_calls]
                     content = json.dumps(tool_calls_list, ensure_ascii=False)
                     role = self.dataset_attr.function_tag

|
||||
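The `OpenAIDatasetConverter` hunk above replaces a membership check plus `len(...)` with `message.get("tool_calls")` bound by an assignment expression. A small sketch of the behavioural difference follows. `serialize_tool_calls` and the sample messages are hypothetical and not part of the repository. The point is that the new form treats a missing key, `None`, and an empty list alike, whereas calling `len` on a `None` value raises `TypeError`.

```python
import json
from typing import Any, Optional


def serialize_tool_calls(message: dict[str, Any]) -> Optional[str]:
    """Return the JSON-encoded function calls, or None when there are none."""
    if tool_calls := message.get("tool_calls"):
        return json.dumps([tool["function"] for tool in tool_calls], ensure_ascii=False)
    return None


# Hypothetical messages covering the cases the new check handles uniformly.
missing = serialize_tool_calls({"role": "assistant", "content": "hi"})
explicit_none = serialize_tool_calls({"role": "assistant", "tool_calls": None})
empty = serialize_tool_calls({"role": "assistant", "tool_calls": []})
present = serialize_tool_calls(
    {"role": "assistant", "tool_calls": [{"function": {"name": "lookup", "arguments": "{}"}}]}
)
```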
@@ -196,7 +196,7 @@ def read_cloud_json(cloud_path: str) -> list[Any]:

     # filter out non-JSON files
     files = [x["Key"] for x in fs.listdir(cloud_path)] if fs.isdir(cloud_path) else [cloud_path]
-    files = filter(lambda file: file.endswith(".json") or file.endswith(".jsonl"), files)
+    files = list(filter(lambda file: file.endswith(".json") or file.endswith(".jsonl"), files))
     if not files:
         raise ValueError(f"No JSON/JSONL files found in the specified path: {cloud_path}.")

|
||||
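The `read_cloud_json` hunk above wraps `filter(...)` in `list(...)`. A short sketch of why this matters: in Python 3, `filter` returns a lazy iterator that is always truthy, so an `if not files` check on it can never fire. The helper `select_json_files` and the file names are assumptions made up for this illustration.

```python
def select_json_files(files: list[str]) -> list[str]:
    """Keep only .json/.jsonl paths and fail loudly when none remain."""
    selected = list(filter(lambda f: f.endswith((".json", ".jsonl")), files))
    if not selected:
        raise ValueError("No JSON/JSONL files found.")
    return selected


# A bare filter object is truthy even when it would yield nothing.
lazy = filter(lambda f: f.endswith(".json"), ["notes.txt"])
lazy_is_truthy = bool(lazy)
lazy_items = list(lazy)

kept = select_json_files(["a.json", "b.txt", "c.jsonl"])
```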
@@ -22,16 +22,18 @@ import re
 from copy import deepcopy
 from dataclasses import dataclass
 from io import BytesIO
-from typing import TYPE_CHECKING, BinaryIO, Literal, NotRequired, Optional, TypedDict, Union
+from types import SimpleNamespace
+from typing import TYPE_CHECKING, Any, BinaryIO, Literal, NotRequired, Optional, TypedDict, Union

 import numpy as np
 import torch
 import torchaudio
-from transformers.image_utils import get_image_size, is_valid_image, to_numpy_array
+from transformers.image_utils import get_image_size, is_valid_image, make_flat_list_of_images, to_numpy_array
 from transformers.models.mllama.processing_mllama import (
     convert_sparse_cross_attention_mask_to_dense,
     get_cross_attention_token_mask,
 )
+from transformers.video_utils import make_batched_videos
 from typing_extensions import override

 from ..extras.constants import AUDIO_PLACEHOLDER, IGNORE_INDEX, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
@@ -47,13 +49,6 @@ if is_pyav_available():
|
||||
import av
|
||||
|
||||
|
||||
if is_transformers_version_greater_than("4.52.0"):
|
||||
from transformers.image_utils import make_flat_list_of_images
|
||||
from transformers.video_utils import make_batched_videos
|
||||
else:
|
||||
from transformers.image_utils import make_batched_videos, make_flat_list_of_images
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from av.stream import Stream
|
||||
from numpy.typing import NDArray
|
||||
@@ -161,7 +156,9 @@ class MMPluginMixin:
|
||||
video_processor: BaseImageProcessor = getattr(
|
||||
processor, "video_processor", getattr(processor, "image_processor", None)
|
||||
)
|
||||
feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None)
|
||||
feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None) or getattr(
|
||||
processor, "audio_processor", None
|
||||
)
|
||||
if len(images) != 0 and self.image_token is None:
|
||||
raise ValueError(
|
||||
"This model does not support image input. Please check whether the correct `template` is used."
|
||||
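The `feature_extractor` lookup in the hunk above now falls back to `audio_processor`. A minimal sketch of the pattern follows, using `SimpleNamespace` objects as stand-in processors. The helper name `resolve_audio_extractor` and the string values are hypothetical. Chaining with `or` also falls back when the first attribute exists but is `None`, which a nested `getattr` default would not do.

```python
from types import SimpleNamespace
from typing import Any, Optional


def resolve_audio_extractor(processor: Any) -> Optional[Any]:
    """Prefer `feature_extractor`; fall back to `audio_processor` when it is missing or None."""
    return getattr(processor, "feature_extractor", None) or getattr(processor, "audio_processor", None)


# Hypothetical processors used only for illustration.
legacy = SimpleNamespace(feature_extractor="fe")
renamed = SimpleNamespace(feature_extractor=None, audio_processor="ap")
neither = SimpleNamespace()

results = [resolve_audio_extractor(p) for p in (legacy, renamed, neither)]
```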
@@ -249,6 +246,14 @@ class MMPluginMixin:
|
||||
sample_frames = min(total_frames, video_maxlen, sample_frames)
|
||||
return np.linspace(0, total_frames - 1, sample_frames).astype(np.int32)
|
||||
|
||||
def _get_video_token_metadata(
|
||||
self,
|
||||
videos: list["VideoInput"],
|
||||
processor: "MMProcessor",
|
||||
) -> Optional[dict[str, Any]]:
|
||||
r"""Build metadata used to expand video tokens without decoding frames."""
|
||||
return None
|
||||
|
||||
def _regularize_images(self, images: list["ImageInput"], **kwargs) -> "RegularizedImageOutput":
|
||||
r"""Regularize images to avoid error. Including reading and pre-processing."""
|
||||
results = []
|
||||
@@ -390,7 +395,9 @@ class MMPluginMixin:
|
||||
mm_inputs.update(video_processor(videos, return_tensors="pt"))
|
||||
|
||||
if len(audios) != 0:
|
||||
feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None)
|
||||
feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None) or getattr(
|
||||
processor, "audio_processor", None
|
||||
)
|
||||
audios = self._regularize_audios(
|
||||
audios,
|
||||
sampling_rate=getattr(processor, "audio_sampling_rate", 16000),
|
||||
@@ -609,6 +616,217 @@ class Gemma3nPlugin(Gemma3Plugin):
|
||||
return messages
|
||||
|
||||
|
||||
@dataclass
|
||||
class Gemma4Plugin(BasePlugin):
|
||||
r"""Plugin for the Gemma4 multimodal model."""
|
||||
|
||||
@override
|
||||
def _regularize_videos(self, videos: list["VideoInput"], **kwargs) -> "RegularizedVideoOutput":
|
||||
r"""Regularize videos, also tracking per-video FPS and frame indices for timestamp generation."""
|
||||
results, fps_per_video, durations, frames_indices = [], [], [], []
|
||||
for video in videos:
|
||||
frames: list[ImageObject] = []
|
||||
if _check_video_is_nested_images(video):
|
||||
frames = video
|
||||
fps_per_video.append(kwargs.get("video_fps", 2.0))
|
||||
durations.append(len(frames) / kwargs.get("video_fps", 2.0))
|
||||
frames_indices.append(list(range(len(frames))))
|
||||
else:
|
||||
container = av.open(video, "r")
|
||||
video_stream = next(stream for stream in container.streams if stream.type == "video")
|
||||
sample_indices = self._get_video_sample_indices(video_stream, **kwargs)
|
||||
original_fps = float(video_stream.average_rate)
|
||||
# used to calculate timestamps correctly
|
||||
frames_indices.append([idx / original_fps * kwargs.get("video_fps", 2.0) for idx in sample_indices])
|
||||
container.seek(0)
|
||||
for frame_idx, frame in enumerate(container.decode(video_stream)):
|
||||
if frame_idx in sample_indices:
|
||||
frames.append(frame.to_image())
|
||||
|
||||
if video_stream.duration is None:
|
||||
durations.append(len(frames) / kwargs.get("video_fps", 2.0))
|
||||
else:
|
||||
durations.append(float(video_stream.duration * video_stream.time_base))
|
||||
|
||||
frames = self._regularize_images(frames, **kwargs)["images"]
|
||||
results.append(frames)
|
||||
|
||||
return {
|
||||
"videos": results,
|
||||
"fps_per_video": fps_per_video,
|
||||
"durations": durations,
|
||||
"frames_indices": frames_indices,
|
||||
}
|
||||
|
||||
@override
|
||||
def _get_mm_inputs(
|
||||
self,
|
||||
images: list["ImageInput"],
|
||||
videos: list["VideoInput"],
|
||||
audios: list["AudioInput"],
|
||||
processor: "MMProcessor",
|
||||
) -> dict[str, Union[list[int], "torch.Tensor"]]:
|
||||
image_processor = getattr(processor, "image_processor", None)
|
||||
video_processor = getattr(processor, "video_processor", None)
|
||||
feature_extractor = getattr(processor, "feature_extractor", None)
|
||||
mm_inputs = {}
|
||||
|
||||
if len(images) != 0:
|
||||
regularized = self._regularize_images(
|
||||
images,
|
||||
image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768),
|
||||
image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
|
||||
)["images"]
|
||||
mm_inputs.update(image_processor(regularized, return_tensors="pt"))
|
||||
|
||||
if len(videos) != 0:
|
||||
video_data = self._regularize_videos(
|
||||
videos,
|
||||
image_max_pixels=getattr(processor, "video_max_pixels", 256 * 256),
|
||||
image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16),
|
||||
video_fps=getattr(processor, "video_fps", 2.0),
|
||||
video_maxlen=getattr(processor, "video_maxlen", 128),
|
||||
)
|
||||
video_metadata = [
|
||||
{
|
||||
"fps": getattr(processor, "video_fps", 2.0),
|
||||
"duration": duration,
|
||||
"total_num_frames": len(video),
|
||||
"frames_indices": sample_indices,
|
||||
}
|
||||
for video, duration, sample_indices in zip(
|
||||
video_data["videos"], video_data["durations"], video_data["frames_indices"]
|
||||
)
|
||||
]
|
||||
mm_inputs.update(
|
||||
video_processor(
|
||||
videos=video_data["videos"],
|
||||
video_metadata=video_metadata,
|
||||
return_tensors="pt",
|
||||
return_metadata=True,
|
||||
do_sample_frames=False,
|
||||
)
|
||||
)
|
||||
|
||||
if len(audios) != 0: # only for gemma4n
|
||||
audios = self._regularize_audios(
|
||||
audios,
|
||||
sampling_rate=getattr(processor, "audio_sampling_rate", 16000),
|
||||
)["audios"]
|
||||
|
||||
mm_inputs.update(
|
||||
feature_extractor(
|
||||
audios,
|
||||
padding="max_length",
|
||||
return_tensors="pt",
|
||||
)
|
||||
)
|
||||
|
||||
return mm_inputs
|
||||
|
||||
@override
|
||||
def process_messages(
|
||||
self,
|
||||
messages: list[dict[str, str]],
|
||||
images: list["ImageInput"],
|
||||
videos: list["VideoInput"],
|
||||
audios: list["AudioInput"],
|
||||
processor: Optional["MMProcessor"],
|
||||
) -> list[dict[str, str]]:
|
||||
self._validate_input(processor, images, videos, audios)
|
||||
self._validate_messages(messages, images, videos, audios)
|
||||
messages = deepcopy(messages)
|
||||
|
||||
boi_token: str = getattr(processor, "boi_token")
|
||||
eoi_token: str = getattr(processor, "eoi_token")
|
||||
boa_token: str = getattr(processor, "boa_token")
|
||||
eoa_token: str = getattr(processor, "eoa_token")
|
||||
image_token: str = getattr(processor, "image_token")
|
||||
video_token: str = getattr(processor, "video_token")
|
||||
audio_token: str = getattr(processor, "audio_token")
|
||||
|
||||
if self.expand_mm_tokens:
|
||||
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
|
||||
num_image_soft_tokens: list[int] = list(
|
||||
mm_inputs.get("num_soft_tokens_per_image", [getattr(processor, "image_seq_length", 256)] * len(images))
|
||||
)
|
||||
num_video_soft_tokens: list[int] = list(mm_inputs.get("num_soft_tokens_per_video", [1] * len(videos)))
|
||||
video_metadata = mm_inputs.get("video_metadata", [])
|
||||
else:
|
||||
num_image_soft_tokens = [1] * len(images)
|
||||
num_video_soft_tokens = [1] * len(videos)
|
||||
video_metadata = [None] * len(videos)
|
||||
|
||||
audio_iter = iter(audios)
|
||||
image_iter = iter(num_image_soft_tokens)
|
||||
video_iter = iter(zip(num_video_soft_tokens, video_metadata))
|
||||
|
||||
for message in messages:
|
||||
content = message["content"]
|
||||
|
||||
while IMAGE_PLACEHOLDER in content:
|
||||
n = next(image_iter)
|
||||
content = content.replace(IMAGE_PLACEHOLDER, f"{boi_token}{image_token * n}{eoi_token}", 1)
|
||||
|
||||
while VIDEO_PLACEHOLDER in content:
|
||||
num_soft_tokens_per_frame, metadata = next(video_iter)
|
||||
if self.expand_mm_tokens:
|
||||
timestamp_strs = [f"{int(t // 60):02d}:{int(t % 60):02d}" for t in metadata.timestamps]
|
||||
frame_strs = [
|
||||
f"{ts} {boi_token}{video_token * num_soft_tokens_per_frame}{eoi_token}"
|
||||
for ts in timestamp_strs
|
||||
]
|
||||
video_str = " ".join(frame_strs)
|
||||
else:
|
||||
video_str = f"{boi_token}{video_token * num_soft_tokens_per_frame}{eoi_token}"
|
||||
content = content.replace(VIDEO_PLACEHOLDER, video_str, 1)
|
||||
|
||||
while AUDIO_PLACEHOLDER in content:
|
||||
current_audio = next(audio_iter)
|
||||
if self.expand_mm_tokens:
|
||||
num_audio_tokens = processor._compute_audio_num_tokens(
|
||||
current_audio, processor.feature_extractor.sampling_rate
|
||||
)
|
||||
audio_str = f"{boa_token}{audio_token * num_audio_tokens}{eoa_token}"
|
||||
else:
|
||||
audio_str = f"{boa_token}{audio_token}{eoa_token}"
|
||||
|
||||
content = content.replace(AUDIO_PLACEHOLDER, audio_str, 1)
|
||||
|
||||
message["content"] = content
|
||||
|
||||
return messages
|
||||
|
||||
@override
|
||||
def get_mm_inputs(
|
||||
self,
|
||||
images: list["ImageInput"],
|
||||
videos: list["VideoInput"],
|
||||
audios: list["AudioInput"],
|
||||
imglens: list[int],
|
||||
vidlens: list[int],
|
||||
audlens: list[int],
|
||||
batch_ids: list[list[int]],
|
||||
processor: Optional["MMProcessor"],
|
||||
) -> dict[str, Union[list[int], "torch.Tensor"]]:
|
||||
self._validate_input(processor, images, videos, audios)
|
||||
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
|
||||
# Pop metadata keys that must not be passed to the model.
|
||||
for key in (
|
||||
"num_soft_tokens_per_image",
|
||||
"num_soft_tokens_per_video",
|
||||
"video_metadata",
|
||||
"_gemma4_fps_per_video",
|
||||
"_gemma4_frames_indices",
|
||||
"_gemma4_num_audio_soft_tokens",
|
||||
):
|
||||
mm_inputs.pop(key, None)
|
||||
|
||||
mm_inputs["mm_token_type_ids"] = processor.create_mm_token_type_ids(batch_ids)
|
||||
|
||||
return mm_inputs
|
||||
|
||||
|
||||
@dataclass
|
||||
class InternVLPlugin(BasePlugin):
|
||||
@override
|
||||
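`Gemma4Plugin.process_messages` in the hunk above expands each video placeholder into one timestamped group of soft tokens per sampled frame. The sketch below reproduces only that string construction. The token strings `<boi>`, `<eoi>`, and `<vid>` are placeholder assumptions (the plugin reads the real ones from the processor), and both helper names are hypothetical.

```python
def format_timestamp(seconds: float) -> str:
    """Render a frame time in seconds as zero-padded MM:SS."""
    return f"{int(seconds // 60):02d}:{int(seconds % 60):02d}"


def build_video_str(
    timestamps: list[float],
    tokens_per_frame: int,
    boi: str = "<boi>",  # assumed begin-of-image token
    eoi: str = "<eoi>",  # assumed end-of-image token
    video_token: str = "<vid>",  # assumed video soft token
) -> str:
    """Join one 'MM:SS <boi>...<eoi>' group per frame with single spaces."""
    frames = [f"{format_timestamp(t)} {boi}{video_token * tokens_per_frame}{eoi}" for t in timestamps]
    return " ".join(frames)


stamps = [format_timestamp(t) for t in (0.0, 0.5, 61.2, 125.0)]
video_str = build_video_str([0.0, 0.5], tokens_per_frame=2)
```

Because the seconds are truncated with `int`, frames sampled less than a second apart can share the same timestamp string.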
@@ -1054,7 +1272,9 @@ class MiniCPMVPlugin(BasePlugin):
|
||||
chunk_input=True,
|
||||
sampling_rate=getattr(processor, "audio_sampling_rate", 16000),
|
||||
)
|
||||
audio_feature_lens = [torch.tensor(audio_feature_len) for audio_feature_len in audio_feature_lens]
|
||||
audio_feature_lens = [
|
||||
x.clone().detach() if isinstance(x, torch.Tensor) else torch.tensor(x) for x in audio_feature_lens
|
||||
]
|
||||
mm_inputs.update({"audio_features": audio_features, "audio_feature_lens": audio_feature_lens})
|
||||
if kwargs.get("ret_phs", False):
|
||||
mm_inputs.update({"audio_phs": audio_phs})
|
||||
@@ -1094,7 +1314,7 @@ class MiniCPMVPlugin(BasePlugin):
|
||||
num_image_tokens += 1
|
||||
|
||||
while VIDEO_PLACEHOLDER in content:
|
||||
video_seqlen = len(mm_inputs["pixel_values"][num_video_tokens]) if self.expand_mm_tokens else 1
|
||||
video_seqlen = len(mm_inputs["image_sizes"][num_video_tokens]) if self.expand_mm_tokens else 1
|
||||
content = content.replace(VIDEO_PLACEHOLDER, "{{image}}" * video_seqlen, 1)
|
||||
num_video_tokens += 1
|
||||
|
||||
@@ -1221,6 +1441,225 @@ class MiniCPMVPlugin(BasePlugin):
|
||||
return mm_inputs
|
||||
|
||||
|
||||
@dataclass
|
||||
class MiniCPMV4_6Plugin(BasePlugin):
|
||||
"""Plugin for MiniCPM-V-4.6 with new transformers (NaViT vision + get_placeholder_mask API)."""
|
||||
|
||||
def _get_mm_inputs(
|
||||
self,
|
||||
images: list["ImageInput"],
|
||||
videos: list["VideoInput"],
|
||||
audios: list["AudioInput"],
|
||||
processor: "MMProcessor",
|
||||
**kwargs,
|
||||
) -> dict[str, "torch.Tensor"]:
|
||||
image_processor = getattr(processor, "image_processor")
|
||||
video_processor = getattr(processor, "video_processor", None)
|
||||
mm_inputs = {}
|
||||
|
||||
if len(images) != 0:
|
||||
# The image_processor ignores downsample_mode; target_sizes are always based on patch_size.
|
||||
# downsample_mode only affects the token divisor in _build_v4_6_placeholder and model forward.
|
||||
mm_inputs.update(image_processor(images, return_tensors="pt"))
|
||||
|
||||
if len(videos) != 0:
|
||||
if video_processor is not None:
|
||||
video_inputs = video_processor(videos, return_tensors="pt")
|
||||
mm_inputs["pixel_values_videos"] = video_inputs["pixel_values_videos"]
|
||||
mm_inputs["target_sizes_videos"] = video_inputs["target_sizes_videos"]
|
||||
else:
|
||||
video_inputs = image_processor(videos, return_tensors="pt")
|
||||
mm_inputs["pixel_values_videos"] = video_inputs["pixel_values"]
|
||||
mm_inputs["target_sizes_videos"] = video_inputs["target_sizes"]
|
||||
|
||||
if len(audios) != 0:
|
||||
audio_features, audio_feature_lens, audio_phs = processor.audio_feature_extract(
|
||||
[audios],
|
||||
chunk_input=True,
|
||||
sampling_rate=getattr(processor, "audio_sampling_rate", 16000),
|
||||
)
|
||||
audio_feature_lens = [
|
||||
x.clone().detach() if isinstance(x, torch.Tensor) else torch.tensor(x) for x in audio_feature_lens
|
||||
]
|
||||
mm_inputs.update({"audio_features": audio_features, "audio_feature_lens": audio_feature_lens})
|
||||
if kwargs.get("ret_phs", False):
|
||||
mm_inputs.update({"audio_phs": audio_phs})
|
||||
|
||||
return mm_inputs
|
||||
|
||||
def _build_v4_6_placeholder(
|
||||
self,
|
||||
image_inputs: dict[str, Any],
|
||||
image_idx: int,
|
||||
use_image_id: bool,
|
||||
processor: "MMProcessor",
|
||||
) -> str:
|
||||
"""Build image placeholder for MiniCPM-V-4.6 using NaViT token count computation."""
|
||||
grids = image_inputs.get("grids", [[0, 0]])
|
||||
num_patches_per_image = image_inputs.get("num_patches_per_image", [1])
|
||||
target_sizes = image_inputs.get("target_sizes")
|
||||
|
||||
downsample_mode = os.getenv("DOWNSAMPLE_MODE")
|
||||
if downsample_mode is None:
|
||||
image_processor = getattr(processor, "image_processor")
|
||||
downsample_mode = getattr(image_processor, "downsample_mode", "16x")
|
||||
token_divisor = 4 if downsample_mode == "4x" else 16
|
||||
|
||||
flat_index = 0
|
||||
for idx in range(image_idx):
|
||||
flat_index += num_patches_per_image[idx]
|
||||
n_patches = num_patches_per_image[image_idx]
|
||||
|
||||
img_target_sizes = target_sizes[flat_index : flat_index + n_patches]
|
||||
num_tokens_per_patch = img_target_sizes.prod(-1) // token_divisor
|
||||
num_rows, num_cols = grids[image_idx]
|
||||
|
||||
image_start = getattr(processor, "image_start_token", "<image>")
|
||||
image_end = getattr(processor, "image_end_token", "</image>")
|
||||
slice_start = getattr(processor, "slice_start_token", "<slice>")
|
||||
slice_end = getattr(processor, "slice_end_token", "</slice>")
|
||||
image_id_start = getattr(processor, "image_id_start_token", "<image_id>")
|
||||
image_id_end = getattr(processor, "image_id_end_token", "</image_id>")
|
||||
image_token = (
|
||||
getattr(processor, "image_token", None)
|
||||
or getattr(getattr(processor, "tokenizer", None), "image_token", None)
|
||||
or "<image>"
|
||||
)
|
||||
|
||||
image_placeholder = image_start + "<|ph|>" * int(num_tokens_per_patch[0]) + image_end
|
||||
if use_image_id:
|
||||
image_placeholder = f"{image_id_start}{image_idx}{image_id_end}" + image_placeholder
|
||||
|
||||
slice_mode = getattr(processor, "slice_mode", True)
|
||||
if slice_mode and num_rows > 0 and num_cols > 0:
|
||||
per_slice_tokens = int(num_tokens_per_patch[1]) if len(num_tokens_per_patch) > 1 else 0
|
||||
slice_placeholder = slice_start + "<|ph|>" * per_slice_tokens + slice_end
|
||||
slices = [slice_placeholder * num_cols for _ in range(num_rows)]
|
||||
image_placeholder += "\n".join(slices)
|
||||
|
||||
return image_placeholder.replace("<|ph|>", image_token)
|
||||
|
||||
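`_build_v4_6_placeholder` above derives the number of image tokens per patch from `target_sizes` and a divisor chosen by `downsample_mode`, after locating the patches of one image in a flat list. A rough sketch of those two computations on plain Python values follows. The helper names and sample sizes are assumptions. The plugin itself indexes a tensor and reads the mode from the `DOWNSAMPLE_MODE` environment variable or the image processor.

```python
def tokens_per_patch(target_sizes: list[tuple[int, int]], downsample_mode: str = "16x") -> list[int]:
    """Token count for each patch: height * width divided by the downsample divisor."""
    divisor = 4 if downsample_mode == "4x" else 16
    return [(h * w) // divisor for h, w in target_sizes]


def patch_slice(num_patches_per_image: list[int], image_idx: int) -> tuple[int, int]:
    """Start and end offsets of one image's patches inside the flat patch list."""
    start = sum(num_patches_per_image[:image_idx])
    return start, start + num_patches_per_image[image_idx]


# Hypothetical sizes: the first patch is the overview image, the second a slice.
default_counts = tokens_per_patch([(32, 32), (16, 24)])
fine_counts = tokens_per_patch([(32, 32), (16, 24)], downsample_mode="4x")
offsets = patch_slice([3, 1, 5], image_idx=2)
```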
@override
|
||||
def process_messages(
|
||||
self,
|
||||
messages: list[dict[str, str]],
|
||||
images: list["ImageInput"],
|
||||
videos: list["VideoInput"],
|
||||
audios: list["AudioInput"],
|
||||
processor: Optional["MMProcessor"],
|
||||
) -> list[dict[str, str]]:
|
||||
self._validate_input(processor, images, videos, audios)
|
||||
self._validate_messages(messages, images, videos, audios)
|
||||
num_image_tokens, num_video_tokens, num_audio_tokens = 0, 0, 0
|
||||
messages = deepcopy(messages)
|
||||
mm_inputs, audio_inputs = {}, {}
|
||||
if len(images) != 0 and len(videos) != 0:
|
||||
raise ValueError("MiniCPM-V model does not support input images and videos at the same time.")
|
||||
|
||||
use_image_id = getattr(processor, "default_use_image_id", True)
|
||||
|
||||
if len(videos) != 0:
|
||||
use_image_id = False
|
||||
mm_inputs = self._get_mm_inputs([], videos, [], processor)
|
||||
|
||||
for i, message in enumerate(messages):
|
||||
content = message["content"]
|
||||
while IMAGE_PLACEHOLDER in content:
|
||||
content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)
|
||||
num_image_tokens += 1
|
||||
|
||||
while VIDEO_PLACEHOLDER in content:
|
||||
num_frames = 1
|
||||
if "num_frames_per_video" in mm_inputs:
|
||||
num_frames = sum(mm_inputs["num_frames_per_video"])
|
||||
content = content.replace(VIDEO_PLACEHOLDER, "{{image}}" * num_frames, 1)
|
||||
num_video_tokens += 1
|
||||
|
||||
while AUDIO_PLACEHOLDER in content:
|
||||
content = content.replace(AUDIO_PLACEHOLDER, "{{audio}}", 1)
|
||||
num_audio_tokens += 1
|
||||
|
||||
message["content"] = content.replace("{{image}}", "(<image>./</image>)").replace(
|
||||
"{{audio}}", "(<audio>./</audio>)"
|
||||
)
|
||||
|
||||
if len(images):
|
||||
mm_inputs = self._get_mm_inputs(images, [], [], processor)
|
||||
|
||||
if len(audios):
|
||||
audio_inputs = self._get_mm_inputs([], [], audios, processor, ret_phs=True)
|
||||
|
||||
if self.expand_mm_tokens and mm_inputs:
|
||||
pattern = "(<image>./</image>)"
|
||||
idx = 0
|
||||
for index, message in enumerate(messages):
|
||||
text = message["content"]
|
||||
image_tags = re.findall(pattern, text)
|
||||
text_chunks = text.split(pattern)
|
||||
final_text = ""
|
||||
for i in range(len(image_tags)):
|
||||
image_placeholder = self._build_v4_6_placeholder(mm_inputs, idx, use_image_id, processor)
|
||||
final_text = final_text + text_chunks[i] + image_placeholder
|
||||
idx += 1
|
||||
final_text += text_chunks[-1]
|
||||
messages[index]["content"] = final_text
|
||||
|
||||
if self.expand_mm_tokens and audio_inputs:
|
||||
pattern = "(<audio>./</audio>)"
|
||||
idx = 0
|
||||
for index, message in enumerate(messages):
|
||||
text = message["content"]
|
||||
audio_tags = re.findall(pattern, text)
|
||||
text_chunks = text.split(pattern)
|
||||
final_text = ""
|
||||
for i in range(len(audio_tags)):
|
||||
audio_placeholder = audio_inputs["audio_phs"][0][idx]
|
||||
final_text = final_text + text_chunks[i] + audio_placeholder
|
||||
idx += 1
|
||||
final_text += text_chunks[-1]
|
||||
messages[index]["content"] = final_text
|
||||
|
||||
return messages
|
||||
|
||||
@override
|
||||
def get_mm_inputs(
|
||||
self,
|
||||
images: list["ImageInput"],
|
||||
videos: list["VideoInput"],
|
||||
audios: list["AudioInput"],
|
||||
imglens: list[int],
|
||||
vidlens: list[int],
|
||||
audlens: list[int],
|
||||
batch_ids: list[list[int]],
|
||||
processor: Optional["MMProcessor"],
|
||||
) -> dict[str, Union[list[int], "torch.Tensor"]]:
|
||||
self._validate_input(processor, images, videos, audios)
|
||||
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
|
||||
|
||||
# v4.6 does NOT use image_bound — the model finds image tokens via get_placeholder_mask
|
||||
# Ensure target_sizes key name matches the model's expected input
|
||||
if "target_sizes" not in mm_inputs and "tgt_sizes" in mm_inputs:
|
||||
mm_inputs["target_sizes"] = mm_inputs.pop("tgt_sizes")
|
||||
|
||||
if "target_sizes" not in mm_inputs:
|
||||
mm_inputs["target_sizes"] = torch.empty(0, 2, dtype=torch.int32)
|
||||
|
||||
if "pixel_values" not in mm_inputs:
|
||||
mm_inputs["pixel_values"] = torch.empty(1, 3, 14, 0)
|
||||
|
||||
# Pass downsample_mode to model forward so it matches the placeholder divisor
|
||||
_ds = os.getenv("DOWNSAMPLE_MODE")
|
||||
if _ds is None:
|
||||
_ds = getattr(getattr(processor, "image_processor", None), "downsample_mode", "16x")
|
||||
mm_inputs["downsample_mode"] = _ds
|
||||
|
||||
if len(audios) > 0:
|
||||
audio_inputs = self._get_mm_inputs([], [], audios, processor)
|
||||
mm_inputs.update(audio_inputs)
|
||||
|
||||
return mm_inputs
|
||||
|
||||
|
||||
@dataclass
|
||||
class MllamaPlugin(BasePlugin):
|
||||
@override
|
||||
@@ -1489,10 +1928,11 @@ class Qwen2VLPlugin(BasePlugin):
|
||||
|
||||
@override
|
||||
def _regularize_videos(self, videos: list["VideoInput"], **kwargs) -> "RegularizedVideoOutput":
|
||||
results, fps_per_video, durations = [], [], []
|
||||
results, fps_per_video, durations, frames_indices = [], [], [], []
|
||||
for video in videos:
|
||||
frames: list[ImageObject] = []
|
||||
if _check_video_is_nested_images(video):
|
||||
# we assume the frames have already been sampled from the video
|
||||
for frame in video:
|
||||
if not is_valid_image(frame) and not isinstance(frame, dict) and not os.path.exists(frame):
|
||||
raise ValueError("Invalid image found in video frames.")
|
||||
@@ -1500,10 +1940,16 @@ class Qwen2VLPlugin(BasePlugin):
|
||||
frames = video
|
||||
fps_per_video.append(kwargs.get("video_fps", 2.0))
|
||||
durations.append(len(frames) / kwargs.get("video_fps", 2.0))
|
||||
frames_indices.append(list(range(len(frames))))
|
||||
else:
|
||||
container = av.open(video, "r")
|
||||
video_stream = next(stream for stream in container.streams if stream.type == "video")
|
||||
sample_indices = self._get_video_sample_indices(video_stream, **kwargs)
|
||||
original_fps = float(video_stream.average_rate)
|
||||
# for qwen3vl video timestamp calculation
|
||||
frames_indices.append(
|
||||
[idx / original_fps * kwargs.get("video_fps", 2.0) for idx in sample_indices]
|
||||
) # hack usage when do_sample_frames=False
|
||||
container.seek(0)
|
||||
for frame_idx, frame in enumerate(container.decode(video_stream)):
|
||||
if frame_idx in sample_indices:
|
||||
@@ -1522,7 +1968,205 @@ class Qwen2VLPlugin(BasePlugin):
|
||||
frames = self._regularize_images(frames, **kwargs)["images"]
|
||||
results.append(frames)
|
||||
|
||||
return {"videos": results, "fps_per_video": fps_per_video, "durations": durations}
|
||||
return {
|
||||
"videos": results,
|
||||
"fps_per_video": fps_per_video,
|
||||
"durations": durations,
|
||||
"frames_indices": frames_indices,
|
||||
}
|
||||
|
||||
def _get_qwen_video_size_after_regularization(
|
||||
self, width: int, height: int, image_max_pixels: int, image_min_pixels: int
|
||||
) -> tuple[int, int]:
|
||||
r"""Compute the frame size produced by Qwen-VL image regularization."""
|
||||
if (width * height) > image_max_pixels:
|
||||
resize_factor = math.sqrt(image_max_pixels / (width * height))
|
||||
width, height = int(width * resize_factor), int(height * resize_factor)
|
||||
|
||||
if (width * height) < image_min_pixels:
|
||||
resize_factor = math.sqrt(image_min_pixels / (width * height))
|
||||
width, height = int(width * resize_factor), int(height * resize_factor)
|
||||
|
||||
if min(width, height) < 28:
|
||||
width, height = max(width, 28), max(height, 28)
|
||||
|
||||
if width / height > 200:
|
||||
width, height = height * 180, height
|
||||
|
||||
if height / width > 200:
|
||||
width, height = width, width * 180
|
||||
|
||||
return width, height
|
||||
|
||||
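`_get_qwen_video_size_after_regularization` above predicts the frame size that image regularization would produce, without decoding any frame. A standalone sketch of the same sequence of steps is given below. The function name `size_after_regularization` and the sample inputs are assumptions. The thresholds (28, 200, 180) are copied from the hunk.

```python
import math


def size_after_regularization(width: int, height: int, max_pixels: int, min_pixels: int) -> tuple[int, int]:
    """Apply pixel-budget scaling, a 28-pixel floor, and an aspect-ratio cap, in that order."""
    if width * height > max_pixels:  # shrink to fit the upper pixel budget
        factor = math.sqrt(max_pixels / (width * height))
        width, height = int(width * factor), int(height * factor)

    if width * height < min_pixels:  # enlarge to reach the lower pixel budget
        factor = math.sqrt(min_pixels / (width * height))
        width, height = int(width * factor), int(height * factor)

    if min(width, height) < 28:  # each side must be at least 28 pixels
        width, height = max(width, 28), max(height, 28)

    if width / height > 200:  # cap very wide frames
        width, height = height * 180, height

    if height / width > 200:  # cap very tall frames
        width, height = width, width * 180

    return width, height


shrunk = size_after_regularization(1000, 500, max_pixels=256 * 256, min_pixels=16 * 16)
floored = size_after_regularization(10, 20, max_pixels=256 * 256, min_pixels=16 * 16)
capped = size_after_regularization(6000, 28, max_pixels=10**9, min_pixels=16 * 16)
```

Truncation with `int` means the scaled area can land slightly below the budget, which is why the steps are order-sensitive.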
def _get_qwen_video_stream_metadata(
|
||||
self,
|
||||
video: "VideoInput",
|
||||
video_fps: float,
|
||||
video_maxlen: int,
|
||||
) -> Optional[dict[str, Any]]:
|
||||
if not is_pyav_available() or not isinstance(video, (str, os.PathLike)):
|
||||
return None
|
||||
|
||||
try:
|
||||
container = av.open(video, "r")
|
||||
except (av.FFmpegError, OSError):
|
||||
return None
|
||||
|
||||
try:
|
||||
video_stream = next((stream for stream in container.streams if stream.type == "video"), None)
|
||||
if video_stream is None:
|
||||
return None
|
||||
|
||||
if video_stream.duration is None or video_stream.average_rate is None:
|
||||
return None
|
||||
|
||||
average_fps = float(video_stream.average_rate)
|
||||
if average_fps <= 0:
|
||||
return None
|
||||
|
||||
sample_indices = self._get_video_sample_indices(
|
||||
video_stream, video_fps=video_fps, video_maxlen=video_maxlen
|
||||
)
|
||||
return {
|
||||
"width": video_stream.width,
|
||||
"height": video_stream.height,
|
||||
"average_fps": average_fps,
|
||||
"sample_indices": sample_indices,
|
||||
}
|
||||
finally:
|
||||
container.close()
|
||||
|
||||
def _get_qwen_video_resize(
|
||||
self,
|
||||
num_frames: int,
|
||||
height: int,
|
||||
width: int,
|
||||
patch_size: int,
|
||||
temporal_patch_size: int,
|
||||
merge_size: int,
|
||||
min_pixels: int,
|
||||
max_pixels: int,
|
||||
) -> tuple[int, int]:
|
||||
from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize
|
||||
|
||||
return smart_resize(
|
||||
height=height,
|
||||
width=width,
|
||||
factor=patch_size * merge_size,
|
||||
min_pixels=min_pixels,
|
||||
max_pixels=max_pixels,
|
||||
)
|
||||
|
||||
def _get_qwen_video_grid_metadata(
|
||||
self,
|
||||
videos: list["VideoInput"],
|
||||
processor: "MMProcessor",
|
||||
) -> Optional[dict[str, Any]]:
|
||||
if len(videos) == 0:
|
||||
return {"video_grid_thw": torch.empty((0, 3), dtype=torch.long), "frames_indices": [], "fps": 2.0}
|
||||
|
||||
image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
|
||||
video_processor: BaseVideoProcessor = getattr(processor, "video_processor", None) or image_processor
|
||||
if image_processor is None or video_processor is None:
|
||||
return None
|
||||
|
||||
patch_size = getattr(video_processor, "patch_size", None)
|
||||
temporal_patch_size = getattr(video_processor, "temporal_patch_size", None)
|
||||
merge_size = getattr(video_processor, "merge_size", None)
|
||||
size = getattr(video_processor, "size", None)
|
||||
if patch_size is None or temporal_patch_size is None or merge_size is None or size is None:
|
||||
return None
|
||||
|
||||
if isinstance(size, dict):
|
||||
min_pixels = size.get("shortest_edge")
|
||||
max_pixels = size.get("longest_edge")
|
||||
else:
|
||||
min_pixels = getattr(size, "shortest_edge", None)
|
||||
max_pixels = getattr(size, "longest_edge", None)
|
||||
|
||||
if min_pixels is None or max_pixels is None:
|
||||
return None
|
||||
|
||||
video_fps = getattr(processor, "video_fps", 2.0)
|
||||
video_maxlen = getattr(processor, "video_maxlen", 128)
|
||||
image_max_pixels = getattr(processor, "video_max_pixels", 256 * 256)
|
||||
image_min_pixels = getattr(processor, "video_min_pixels", 16 * 16)
|
||||
|
||||
video_grid_thw = []
|
||||
frames_indices = []
|
||||
for video in videos:
|
||||
metadata = self._get_qwen_video_stream_metadata(video, video_fps, video_maxlen)
|
||||
if metadata is None:
|
||||
return None
|
||||
|
||||
width, height = self._get_qwen_video_size_after_regularization(
|
||||
metadata["width"], metadata["height"], image_max_pixels, image_min_pixels
|
||||
)
|
||||
num_frames = len(metadata["sample_indices"])
|
||||
if num_frames % 2 != 0:
|
||||
num_frames += 1
|
||||
|
||||
resized_size = self._get_qwen_video_resize(
|
||||
num_frames,
|
||||
height,
|
||||
width,
|
||||
patch_size,
|
||||
temporal_patch_size,
|
||||
merge_size,
|
||||
min_pixels,
|
||||
max_pixels,
|
||||
)
|
||||
|
||||
resized_height, resized_width = resized_size
|
||||
video_grid_thw.append(
|
||||
[
|
||||
math.ceil(num_frames / temporal_patch_size),
|
||||
resized_height // patch_size,
|
||||
resized_width // patch_size,
|
||||
]
|
||||
)
|
||||
frames_indices.append([idx / metadata["average_fps"] * video_fps for idx in metadata["sample_indices"]])
|
||||
|
||||
return {
|
||||
"video_grid_thw": torch.tensor(video_grid_thw, dtype=torch.long),
|
||||
"frames_indices": frames_indices,
|
||||
"fps": video_fps,
|
||||
}
|
||||
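The grid computation in `_get_qwen_video_grid_metadata` can be illustrated with a minimal standalone sketch. The function name `video_grid_thw` and the example sizes are hypothetical; the sketch assumes the frame has already been resized to a multiple of the patch size, which is what `smart_resize` is used for in the diff.

```python
import math


def video_grid_thw(
    num_frames: int,
    resized_height: int,
    resized_width: int,
    patch_size: int,
    temporal_patch_size: int,
) -> list[int]:
    """Hypothetical helper mirroring the (t, h, w) grid arithmetic in the diff."""
    # The diff pads an odd frame count up to the next even number.
    if num_frames % 2 != 0:
        num_frames += 1

    return [
        math.ceil(num_frames / temporal_patch_size),  # temporal patches
        resized_height // patch_size,  # patches along the height
        resized_width // patch_size,  # patches along the width
    ]


# Example values are assumptions chosen for illustration only.
grid = video_grid_thw(
    num_frames=7, resized_height=224, resized_width=336, patch_size=14, temporal_patch_size=2
)
print(grid)
```

With these inputs the 7 sampled frames are padded to 8, giving 4 temporal patches over a 16 x 24 spatial grid.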
|
||||
@override
|
||||
def _get_video_token_metadata(
|
||||
self,
|
||||
videos: list["VideoInput"],
|
||||
processor: "MMProcessor",
|
||||
) -> Optional[dict[str, Any]]:
|
||||
video_metadata = self._get_qwen_video_grid_metadata(videos, processor)
|
||||
if video_metadata is None:
|
||||
return None
|
||||
|
||||
return {"video_grid_thw": video_metadata["video_grid_thw"]}
|
||||
|
||||
def _get_mm_token_metadata(
|
||||
self,
|
||||
images: list["ImageInput"],
|
||||
videos: list["VideoInput"],
|
||||
audios: list["AudioInput"],
|
||||
processor: "MMProcessor",
|
||||
) -> Optional[dict[str, Any]]:
|
||||
if len(audios) != 0:
|
||||
return None
|
||||
|
||||
mm_inputs = {}
|
||||
if len(images) != 0:
|
||||
mm_inputs.update(self._get_mm_inputs(images, [], [], processor))
|
||||
|
||||
if len(videos) != 0:
|
||||
video_inputs = self._get_video_token_metadata(videos, processor)
|
||||
if video_inputs is None:
|
||||
return None
|
||||
|
||||
mm_inputs.update(video_inputs)
|
||||
|
||||
return mm_inputs
|
||||
|
||||
@override
|
||||
def _get_mm_inputs(
|
||||
@@ -1575,7 +2219,10 @@ class Qwen2VLPlugin(BasePlugin):
|
||||
|
||||
merge_length: int = getattr(image_processor, "merge_size") ** 2
|
||||
if self.expand_mm_tokens:
|
||||
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
|
||||
mm_inputs = self._get_mm_token_metadata(images, videos, audios, processor)
|
||||
if mm_inputs is None:
|
||||
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
|
||||
|
||||
image_grid_thw = mm_inputs.get("image_grid_thw", [])
|
||||
video_grid_thw = mm_inputs.get("video_grid_thw", [])
|
||||
else:
|
||||
@@ -1609,6 +2256,51 @@ class Qwen2VLPlugin(BasePlugin):
|
||||
|
||||
@dataclass
|
||||
class Qwen3VLPlugin(Qwen2VLPlugin):
|
||||
@override
|
||||
def _get_qwen_video_resize(
|
||||
self,
|
||||
num_frames: int,
|
||||
height: int,
|
||||
width: int,
|
||||
patch_size: int,
|
||||
temporal_patch_size: int,
|
||||
merge_size: int,
|
||||
min_pixels: int,
|
||||
max_pixels: int,
|
||||
) -> tuple[int, int]:
|
||||
from transformers.models.qwen3_vl.video_processing_qwen3_vl import smart_resize
|
||||
|
||||
return smart_resize(
|
||||
num_frames=num_frames,
|
||||
height=height,
|
||||
width=width,
|
||||
temporal_factor=temporal_patch_size,
|
||||
factor=patch_size * merge_size,
|
||||
min_pixels=min_pixels,
|
||||
max_pixels=max_pixels,
|
||||
)
|
||||
|
||||
@override
|
||||
def _get_video_token_metadata(
|
||||
self,
|
||||
videos: list["VideoInput"],
|
||||
processor: "MMProcessor",
|
||||
) -> Optional[dict[str, Any]]:
|
||||
video_metadata = self._get_qwen_video_grid_metadata(videos, processor)
|
||||
if video_metadata is None:
|
||||
return None
|
||||
|
||||
return {
|
||||
"video_grid_thw": video_metadata["video_grid_thw"],
|
||||
"video_metadata": [
|
||||
SimpleNamespace(
|
||||
frames_indices=frames_indices,
|
||||
fps=video_metadata["fps"],
|
||||
)
|
||||
for frames_indices in video_metadata["frames_indices"]
|
||||
],
|
||||
}
|
||||
|
||||
@override
|
||||
def _get_mm_inputs(
|
||||
self,
|
||||
@@ -1637,8 +2329,15 @@ class Qwen3VLPlugin(Qwen2VLPlugin):
|
||||
video_maxlen=getattr(processor, "video_maxlen", 128),
|
||||
)
|
||||
video_metadata = [
|
||||
{"fps": getattr(processor, "video_fps", 24.0), "duration": duration, "total_num_frames": len(video)}
|
||||
for video, duration in zip(videos["videos"], videos["durations"])
|
||||
{
|
||||
"fps": getattr(processor, "video_fps", 2.0),
|
||||
"duration": duration,
|
||||
"total_num_frames": len(video),
|
||||
"frames_indices": sample_indices,
|
||||
}
|
||||
for video, duration, sample_indices in zip(
|
||||
videos["videos"], videos["durations"], videos["frames_indices"]
|
||||
)
|
||||
]
|
||||
mm_inputs.update(
|
||||
video_processor(
|
||||
@@ -1646,6 +2345,7 @@ class Qwen3VLPlugin(Qwen2VLPlugin):
|
||||
video_metadata=video_metadata,
|
||||
fps=getattr(processor, "video_fps", 2.0),
|
||||
return_metadata=True,
|
||||
do_sample_frames=False, # avoid changing frames_indices
|
||||
)
|
||||
)
|
||||
temporal_patch_size: int = getattr(image_processor, "temporal_patch_size", 2)
|
||||
@@ -1673,11 +2373,14 @@ class Qwen3VLPlugin(Qwen2VLPlugin):
|
||||
image_merge_length: int = getattr(image_processor, "merge_size") ** 2
|
||||
video_merge_length: int = getattr(video_processor, "merge_size") ** 2
|
||||
if self.expand_mm_tokens:
|
||||
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
|
||||
mm_inputs = self._get_mm_token_metadata(images, videos, audios, processor)
|
||||
if mm_inputs is None:
|
||||
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
|
||||
|
||||
image_grid_thw = mm_inputs.get("image_grid_thw", [])
|
||||
video_grid_thw = mm_inputs.get("video_grid_thw", [])
|
||||
num_frames = video_grid_thw[0][0] if len(video_grid_thw) > 0 else 0 # hard code for now
|
||||
video_metadata = mm_inputs.get("video_metadata", {})
|
||||
video_metadata = mm_inputs.get("video_metadata", [])
|
||||
|
||||
else:
|
||||
image_grid_thw = [None] * len(images)
|
||||
@@ -1876,7 +2579,9 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
|
||||
) -> dict[str, "torch.Tensor"]:
|
||||
image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
|
||||
video_processor: BaseVideoProcessor = getattr(processor, "video_processor", None)
|
||||
feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None)
|
||||
feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None) or getattr(
|
||||
processor, "audio_processor", None
|
||||
)
|
||||
mm_inputs = {}
|
||||
if len(images) != 0:
|
||||
images = self._regularize_images(
|
||||
@@ -1981,6 +2686,7 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
|
||||
f"Each {VIDEO_PLACEHOLDER} must be followed by an {AUDIO_PLACEHOLDER} when using audio in video."
|
||||
)
|
||||
|
||||
position_id_per_seconds: int = getattr(processor, "position_id_per_seconds", 25)
|
||||
audio_t_index = torch.arange(audio_lengths[num_audio_tokens])
|
||||
video_t_index = (
|
||||
torch.arange(video_grid_thw[num_video_tokens][0])
|
||||
@@ -1992,9 +2698,9 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
|
||||
)
|
||||
.flatten()
|
||||
* mm_inputs["video_second_per_grid"][num_video_tokens]
|
||||
* 25 # FIXME hardcode of position_id_per_seconds=25
|
||||
* position_id_per_seconds
|
||||
).long()
|
||||
t_ntoken_per_chunk = 50 # FIXME hardcode: [25 * 2]
|
||||
t_ntoken_per_chunk = position_id_per_seconds * 2
|
||||
video_chunk_indices = processor.get_chunked_index(video_t_index, t_ntoken_per_chunk)
|
||||
audio_chunk_indices = processor.get_chunked_index(audio_t_index, t_ntoken_per_chunk)
|
||||
placeholder_string = ""
|
||||
@@ -2197,8 +2903,9 @@ PLUGINS = {
|
||||
"base": BasePlugin,
|
||||
"ernie_vl": ErnieVLPlugin,
|
||||
"gemma3": Gemma3Plugin,
|
||||
"glm4v": GLM4VPlugin,
|
||||
"gemma3n": Gemma3nPlugin,
|
||||
"gemma4": Gemma4Plugin,
|
||||
"glm4v": GLM4VPlugin,
|
||||
"intern_vl": InternVLPlugin,
|
||||
"kimi_vl": KimiVLPlugin,
|
||||
"llama4": Llama4Plugin,
|
||||
@@ -2207,6 +2914,7 @@ PLUGINS = {
|
||||
"llava_next_video": LlavaNextVideoPlugin,
|
||||
"lfm2_vl": LFMVLPlugin,
|
||||
"minicpm_v": MiniCPMVPlugin,
|
||||
"minicpm_v_4_6": MiniCPMV4_6Plugin,
|
||||
"mllama": MllamaPlugin,
|
||||
"paligemma": PaliGemmaPlugin,
|
||||
"pixtral": PixtralPlugin,
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import asdict, dataclass
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
from ...extras import logging
|
||||
@@ -27,6 +27,25 @@ if TYPE_CHECKING:
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
MAX_SU_SEQ_IDX = 2**32 # maximum sub-sequence index
|
||||
|
||||
|
||||
@dataclass
|
||||
class PackingParams:
|
||||
r"""Metadata for a packed sequence: sub-sequence boundaries and multimodal data indices.
|
||||
|
||||
- sequence_boundaries: cumulative token positions, e.g. [0, 100, 250, 512] means 3 sub-seqs
|
||||
with token ranges [0,100), [100,250), [250,512). Length = num_sub_seqs + 1.
|
||||
- image_subseq_ids / video_subseq_ids / audio_subseq_ids: for each mm item, the 0-based
|
||||
sub-sequence index it belongs to. Length = total number of that mm type in the packed sample.
|
||||
"""
|
||||
|
||||
sequence_boundaries: list[int]
|
||||
image_subseq_ids: list[int]
|
||||
video_subseq_ids: list[int]
|
||||
audio_subseq_ids: list[int]
|
||||
right_padding_length: int
|
||||
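As a rough illustration of the `PackingParams` fields, the sketch below rebuilds the bookkeeping for a packed sample that contains only images. The helper `build_packing_params` is hypothetical, and treating the right padding as one extra trailing segment is an assumption read off the packing loop in this diff.

```python
from dataclasses import asdict, dataclass

MAX_SU_SEQ_IDX = 2**32  # same sentinel value as in the diff


@dataclass
class PackingParams:
    sequence_boundaries: list[int]
    image_subseq_ids: list[int]
    video_subseq_ids: list[int]
    audio_subseq_ids: list[int]
    right_padding_length: int


def build_packing_params(
    sub_seq_lengths: list[int], images_per_seq: list[int], pad_length: int
) -> PackingParams:
    """Hypothetical helper: images only, no videos or audios."""
    boundaries = [0]
    image_ids: list[int] = []
    for i, (length, n_img) in enumerate(zip(sub_seq_lengths, images_per_seq)):
        # Cumulative token position where sub-sequence i ends.
        boundaries.append(boundaries[-1] + length)
        # Each image records the 0-based sub-sequence it belongs to.
        image_ids.extend([i] * n_img)

    if pad_length > 0:
        # Assumption: padding is recorded as one more boundary at the end.
        boundaries.append(boundaries[-1] + pad_length)

    return PackingParams(
        sequence_boundaries=boundaries,
        image_subseq_ids=image_ids or [MAX_SU_SEQ_IDX],  # sentinel for "none"
        video_subseq_ids=[MAX_SU_SEQ_IDX],
        audio_subseq_ids=[MAX_SU_SEQ_IDX],
        right_padding_length=pad_length,
    )


params = build_packing_params([100, 150], images_per_seq=[1, 0], pad_length=262)
print(asdict(params))
```

Here two sub-sequences of 100 and 150 tokens plus 262 padding tokens give the boundaries `[0, 100, 250, 512]`, matching the example in the docstring.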
|
||||
|
||||
@dataclass
|
||||
class SupervisedDatasetProcessor(DatasetProcessor):
|
||||
@@ -44,7 +63,8 @@ class SupervisedDatasetProcessor(DatasetProcessor):
|
||||
input_ids, labels = self.template.mm_plugin.process_token_ids(
|
||||
[], [], images, videos, audios, self.tokenizer, self.processor
|
||||
)
|
||||
encoded_pairs = self.template.encode_multiturn(self.tokenizer, messages, system, tools)
|
||||
discarding_history_cot = self.data_args.mask_history and not self.template.preserve_thinking
|
||||
encoded_pairs = self.template.encode_multiturn(self.tokenizer, messages, system, tools, discarding_history_cot)
|
||||
total_length = len(input_ids) + (1 if self.template.efficient_eos else 0)
|
||||
if self.data_args.mask_history:
|
||||
encoded_pairs = encoded_pairs[::-1] # high priority for last turns
|
||||
@@ -162,10 +182,17 @@ class PackedSupervisedDatasetProcessor(SupervisedDatasetProcessor):
|
||||
valid_num += 1
|
||||
|
||||
model_inputs = defaultdict(list)
|
||||
requires_packing_params = self.data_args.neat_packing
|
||||
knapsacks = greedy_knapsack(lengths, self.data_args.cutoff_len)
|
||||
for knapsack in knapsacks:
|
||||
packed_input_ids, packed_attention_masks, packed_position_ids, packed_labels = [], [], [], []
|
||||
packed_images, packed_videos, packed_audios = [], [], []
|
||||
if requires_packing_params:
|
||||
sequence_boundaries = [0]
|
||||
image_subseq_ids: list[int] = []
|
||||
video_subseq_ids: list[int] = []
|
||||
audio_subseq_ids: list[int] = []
|
||||
|
||||
for i, length in enumerate(knapsack):
|
||||
index = length2indexes[length].pop()
|
||||
packed_input_ids += batch_input_ids[index]
|
||||
@@ -174,6 +201,15 @@ class PackedSupervisedDatasetProcessor(SupervisedDatasetProcessor):
|
||||
packed_images += batch_images[index]
|
||||
packed_videos += batch_videos[index]
|
||||
packed_audios += batch_audios[index]
|
||||
if requires_packing_params:
|
||||
n_img = len(batch_images[index])
|
||||
n_vid = len(batch_videos[index])
|
||||
n_aud = len(batch_audios[index])
|
||||
sequence_boundaries.append(sequence_boundaries[-1] + len(batch_input_ids[index]))
|
||||
image_subseq_ids.extend([i] * n_img)
|
||||
video_subseq_ids.extend([i] * n_vid)
|
||||
audio_subseq_ids.extend([i] * n_aud)
|
||||
|
||||
if self.data_args.neat_packing:
|
||||
packed_attention_masks += [i + 1] * len(batch_input_ids[index]) # start from 1
|
||||
else:
|
||||
@@ -189,10 +225,23 @@ class PackedSupervisedDatasetProcessor(SupervisedDatasetProcessor):
|
||||
else:
|
||||
packed_attention_masks += [1] * pad_length # more efficient flash_attn
|
||||
|
||||
if requires_packing_params:
|
||||
sequence_boundaries.append(sequence_boundaries[-1] + pad_length)
|
||||
|
||||
if len(packed_input_ids) != self.data_args.cutoff_len + 1:
|
||||
raise ValueError("The length of packed example should be identical to the cutoff length.")
|
||||
|
||||
model_inputs["input_ids"].append(packed_input_ids)
|
||||
if requires_packing_params:
|
||||
packing_params = PackingParams(
|
||||
sequence_boundaries=sequence_boundaries,
|
||||
image_subseq_ids=image_subseq_ids or [MAX_SU_SEQ_IDX], # avoid dataset concat error
|
||||
video_subseq_ids=video_subseq_ids or [MAX_SU_SEQ_IDX],
|
||||
audio_subseq_ids=audio_subseq_ids or [MAX_SU_SEQ_IDX],
|
||||
right_padding_length=pad_length,
|
||||
)
|
||||
model_inputs["packing_params"].append(asdict(packing_params))
|
||||
|
||||
model_inputs["attention_mask"].append(packed_attention_masks)
|
||||
model_inputs["position_ids"].append(packed_position_ids)
|
||||
model_inputs["labels"].append(packed_labels)
|
||||
|
||||
@@ -54,6 +54,7 @@ class Template:
|
||||
replace_eos: bool
|
||||
replace_jinja_template: bool
|
||||
enable_thinking: Optional[bool]
|
||||
preserve_thinking: bool
|
||||
mm_plugin: "BasePlugin"
|
||||
|
||||
def encode_oneturn(
|
||||
@@ -78,6 +79,7 @@ class Template:
|
||||
messages: list[dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
discarding_history_cot: bool = False, # only affects reasoning templates
|
||||
) -> list[tuple[list[int], list[int]]]:
|
||||
r"""Return multiple pairs of token ids representing prompts and responses respectively."""
|
||||
encoded_messages = self._encode(tokenizer, messages, system, tools)
|
||||
@@ -414,8 +416,9 @@ class ReasoningTemplate(Template):
|
||||
tools: Optional[str] = None,
|
||||
) -> tuple[list[int], list[int]]:
|
||||
messages = deepcopy(messages)
|
||||
for i in range(1, len(messages) - 2, 2):
|
||||
messages[i]["content"] = self.remove_thought(messages[i]["content"])
|
||||
if not self.preserve_thinking:
|
||||
for i in range(1, len(messages) - 2, 2):
|
||||
messages[i]["content"] = self.remove_thought(messages[i]["content"])
|
||||
|
||||
if self.enable_thinking is False: # remove all cot
|
||||
messages[-1]["content"] = self.remove_thought(messages[-1]["content"])
|
||||
@@ -439,14 +442,24 @@ class ReasoningTemplate(Template):
|
||||
messages: list[dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
discarding_history_cot: bool = False,
|
||||
) -> list[tuple[list[int], list[int]]]:
|
||||
messages = deepcopy(messages)
|
||||
if self.enable_thinking is False: # remove all cot
|
||||
for i in range(1, len(messages), 2):
|
||||
messages[i]["content"] = self.remove_thought(messages[i]["content"])
|
||||
|
||||
if discarding_history_cot:
|
||||
for i in range(1, len(messages) - 2, 2): # preserve the last cot
|
||||
messages[i]["content"] = self.remove_thought(messages[i]["content"])
|
||||
|
||||
encoded_messages = self._encode(tokenizer, messages, system, tools)
|
||||
for i in range(0, len(messages), 2):
|
||||
if discarding_history_cot:
|
||||
turn_indices = [len(messages) - 2]
|
||||
else:
|
||||
turn_indices = range(0, len(messages), 2)
|
||||
|
||||
for i in turn_indices:
|
||||
if (
|
||||
self.thought_words[0].strip() not in messages[i + 1]["content"]
|
||||
and self.thought_words[1].strip() not in messages[i + 1]["content"]
|
||||
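The effect of `discarding_history_cot` on the message list can be sketched in isolation. This is a simplified illustration, not the template's code: `remove_thought` below is a hypothetical stand-in for `Template.remove_thought`, and the `<think>` markers are assumed thought words.

```python
import re
from copy import deepcopy

THOUGHT_WORDS = ("<think>", "</think>")  # assumed thought markers


def remove_thought(content: str) -> str:
    """Hypothetical stand-in: drop the thought span and leading newlines."""
    pattern = re.compile(
        re.escape(THOUGHT_WORDS[0]) + r"(.*?)" + re.escape(THOUGHT_WORDS[1]), re.DOTALL
    )
    return pattern.sub("", content).lstrip("\n")


def discard_history_cot(messages: list[dict[str, str]]) -> list[dict[str, str]]:
    """Strip thoughts from every assistant turn except the last one."""
    messages = deepcopy(messages)  # leave the caller's list untouched
    # Assistant turns sit at odd indices; the range stops before the final one.
    for i in range(1, len(messages) - 2, 2):
        messages[i]["content"] = remove_thought(messages[i]["content"])
    return messages


history = [
    {"role": "user", "content": "Q1"},
    {"role": "assistant", "content": "<think>step a</think>\n\nA1"},
    {"role": "user", "content": "Q2"},
    {"role": "assistant", "content": "<think>step b</think>\n\nA2"},
]
cleaned = discard_history_cot(history)
print([m["content"] for m in cleaned])
```

Only the first assistant turn loses its thought span; the final turn keeps it, which corresponds to the "preserve the last cot" comment in the diff.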
@@ -491,6 +504,7 @@ def register_template(
|
||||
replace_eos: bool = False,
|
||||
replace_jinja_template: bool = False,
|
||||
enable_thinking: Optional[bool] = True,
|
||||
preserve_thinking: bool = False,
|
||||
mm_plugin: "BasePlugin" = get_mm_plugin(name="base"),
|
||||
template_class: type["Template"] = Template,
|
||||
) -> None:
|
||||
@@ -543,6 +557,7 @@ def register_template(
|
||||
replace_eos=replace_eos,
|
||||
replace_jinja_template=replace_jinja_template,
|
||||
enable_thinking=enable_thinking,
|
||||
preserve_thinking=preserve_thinking,
|
||||
mm_plugin=mm_plugin,
|
||||
)
|
||||
|
||||
@@ -605,6 +620,7 @@ def parse_template(tokenizer: "PreTrainedTokenizer") -> "Template":
|
||||
replace_eos=False,
|
||||
replace_jinja_template=False,
|
||||
enable_thinking=True,
|
||||
preserve_thinking=False,
|
||||
mm_plugin=get_mm_plugin(name="base"),
|
||||
)
|
||||
|
||||
@@ -644,6 +660,7 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args:
|
||||
"e.g., qwen3_vl_nothink"
|
||||
)
|
||||
template.enable_thinking = data_args.enable_thinking
|
||||
template.preserve_thinking = data_args.preserve_thinking
|
||||
|
||||
template.fix_special_tokens(tokenizer)
|
||||
template.fix_jinja_template(tokenizer)
|
||||
@@ -816,6 +833,19 @@ register_template(
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="hy3",
|
||||
format_user=StringFormatter(slots=["<|hy_User|>{{content}}<|hy_Assistant|>"]),
|
||||
format_assistant=StringFormatter(slots=["{{content}}<|hy_eos|>"]),
|
||||
format_system=StringFormatter(slots=["{{content}}"]),
|
||||
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
||||
stop_words=["<|hy_eos|>"],
|
||||
replace_eos=True,
|
||||
thought_words=("<think>", "</think>"),
|
||||
template_class=ReasoningTemplate,
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="deepseekcoder",
|
||||
format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n### Response:"]),
|
||||
@@ -997,6 +1027,57 @@ register_template(
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="gemma4",
|
||||
format_user=StringFormatter(slots=["<|turn>user\n{{content}}<turn|>\n<|turn>model\n"]),
|
||||
format_assistant=StringFormatter(slots=["{{content}}<turn|>\n"]),
|
||||
format_system=StringFormatter(
|
||||
slots=["<|turn>system\n<|think|>{{content}}<turn|>\n"]
|
||||
), # contains the default thought signal
|
||||
format_observation=StringFormatter(
|
||||
slots=["<|turn>tool\n{{content}}<turn|>\n<|turn>model\n"]
|
||||
), # seems inconsistent with the chat template
|
||||
format_tools=ToolFormatter(tool_format="gemma4"),
|
||||
format_function=FunctionFormatter(slots=["<|tool>{{content}}<tool|>"], tool_format="gemma4"),
|
||||
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
||||
stop_words=["<turn|>"],
|
||||
default_system="You are a helpful assistant.", # important for thinking
|
||||
thought_words=("<|channel>thought\n", "<channel|>"),
|
||||
replace_eos=True,
|
||||
mm_plugin=get_mm_plugin(
|
||||
"gemma4",
|
||||
image_token="<|image|>",
|
||||
video_token="<|video|>",
|
||||
),
|
||||
template_class=ReasoningTemplate,
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="gemma4n",
|
||||
format_user=StringFormatter(slots=["<|turn>user\n{{content}}<turn|>\n<|turn>model\n"]),
|
||||
format_assistant=StringFormatter(slots=["{{content}}<turn|>\n"]),
|
||||
format_system=StringFormatter(
|
||||
slots=["<|turn>system\n<|think|>{{content}}<turn|>\n"]
|
||||
), # contains the default thought signal
|
||||
format_observation=StringFormatter(slots=["<|turn>tool\n{{content}}<turn|>\n<|turn>model\n"]),
|
||||
format_tools=ToolFormatter(tool_format="gemma4"),
|
||||
format_function=FunctionFormatter(slots=["<|tool>{{content}}<tool|>"], tool_format="gemma4"),
|
||||
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
||||
stop_words=["<turn|>"],
|
||||
default_system="You are a helpful assistant.", # important for thinking
|
||||
thought_words=("<|channel>thought\n", "<channel|>"),
|
||||
replace_eos=True,
|
||||
mm_plugin=get_mm_plugin(
|
||||
"gemma4",
|
||||
image_token="<|image|>",
|
||||
video_token="<|video|>",
|
||||
audio_token="<|audio|>",
|
||||
),
|
||||
template_class=ReasoningTemplate,
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="glm4",
|
||||
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]),
|
||||
@@ -1113,7 +1194,7 @@ register_template(
|
||||
register_template(
|
||||
name="gpt_oss",
|
||||
format_user=StringFormatter(slots=["<|start|>user<|message|>{{content}}<|end|><|start|>assistant"]),
|
||||
format_assistant=StringFormatter(slots=["{{content}}<|end|>"]),
|
||||
format_assistant=StringFormatter(slots=["{{content}}"]),
|
||||
format_system=StringFormatter(slots=["<|start|>system<|message|>{{content}}<|end|>"]),
|
||||
default_system="You are ChatGPT, a large language model trained by OpenAI.",
|
||||
thought_words=("<|channel|>analysis<|message|>", "<|end|><|start|>assistant<|channel|>final<|message|>"),
|
||||
@@ -1623,6 +1704,17 @@ register_template(
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="minicpm_v_4_6",
|
||||
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
|
||||
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
|
||||
stop_words=["<|im_end|>"],
|
||||
default_system="You are a helpful assistant.",
|
||||
mm_plugin=get_mm_plugin(name="minicpm_v_4_6", image_token="<image>", video_token="<video>"),
|
||||
)
|
||||
|
||||
|
||||
# copied from minicpm_v template
|
||||
register_template(
|
||||
name="minicpm_o",
|
||||
@@ -2062,6 +2154,24 @@ register_template(
|
||||
)
|
||||
|
||||
|
||||
# copied from qwen3_5 template
|
||||
register_template(
|
||||
name="qwen3_6",
|
||||
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
|
||||
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
|
||||
format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen3_5"),
|
||||
format_observation=StringFormatter(
|
||||
slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
|
||||
),
|
||||
format_tools=ToolFormatter(tool_format="qwen3_5"),
|
||||
stop_words=["<|im_end|>"],
|
||||
replace_eos=True,
|
||||
mm_plugin=get_mm_plugin(name="qwen3_vl", image_token="<|image_pad|>", video_token="<|video_pad|>"),
|
||||
template_class=ReasoningTemplate,
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="sailor",
|
||||
format_user=StringFormatter(slots=["<|im_start|>question\n{{content}}<|im_end|>\n<|im_start|>answer\n"]),
|
||||
|
||||
@@ -210,6 +210,166 @@ class DefaultToolUtils(ToolUtils):
|
||||
return results
|
||||
|
||||
|
||||
class Gemma4ToolUtils(ToolUtils):
|
||||
r"""Gemma-4 tool using template."""
|
||||
|
||||
@override
|
||||
@staticmethod
|
||||
def tool_formatter(tools: list[dict[str, Any]]) -> str:
|
||||
def _format_parameters(properties: dict[str, Any]) -> str:
|
||||
parts: list[str] = []
|
||||
for name, schema in properties.items():
|
||||
item_parts: list[str] = []
|
||||
if schema.get("description"):
|
||||
item_parts.append(f'description:<|"|>{schema["description"]}<|"|>')
|
||||
if schema.get("type"):
|
||||
item_parts.append(f'type:<|"|>{str(schema["type"]).upper()}<|"|>')
|
||||
parts.append(f"{name}:{{{','.join(item_parts)}}}")
|
||||
|
||||
return ",".join(parts)
|
||||
|
||||
declarations: list[str] = []
|
||||
for tool in tools:
|
||||
function_data = tool.get("function", tool) if tool.get("type") == "function" else tool
|
||||
declaration = (
|
||||
f"declaration:{function_data['name']}"
|
||||
+ "{"
|
||||
+ f'description:<|"|>{function_data.get("description", "")}<|"|>'
|
||||
)
|
||||
|
||||
params = function_data.get("parameters")
|
||||
if params:
|
||||
param_parts: list[str] = []
|
||||
if params.get("properties"):
|
||||
param_parts.append(f"properties:{{{_format_parameters(params['properties'])}}}")
|
||||
|
||||
if params.get("required"):
|
||||
required_text = ",".join(f'<|"|>{item}<|"|>' for item in params["required"])
|
||||
param_parts.append(f"required:[{required_text}]")
|
||||
|
||||
if params.get("type"):
|
||||
param_parts.append(f'type:<|"|>{str(params["type"]).upper()}<|"|>')
|
||||
|
||||
declaration += f",parameters:{{{','.join(param_parts)}}}"
|
||||
|
||||
response_declaration = function_data.get("response")
|
||||
if response_declaration:
|
||||
response_parts: list[str] = []
|
||||
if response_declaration.get("description"):
|
||||
response_parts.append(f'description:<|"|>{response_declaration["description"]}<|"|>')
|
||||
|
||||
response_type = str(response_declaration.get("type", "")).upper()
|
||||
|
||||
if response_type == "OBJECT":
|
||||
response_parts.append(f'type:<|"|>{response_type}<|"|>')
|
||||
|
||||
declaration += f",response:{{{','.join(response_parts)}}}"
|
||||
|
||||
declarations.append(declaration + "}")
|
||||
|
||||
return "\n".join(declarations)
|
||||
|
||||
@override
|
||||
@staticmethod
|
||||
def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
|
||||
regex = re.compile(r"<\|tool_call\>call:([^{\s]+)\{(.*?)\}<tool_call\|>", re.DOTALL)
|
||||
matches = re.findall(regex, content)
|
||||
if not matches:
|
||||
return content
|
||||
|
||||
def _parse_arguments(arg_text: str) -> Any:
|
||||
text = arg_text.strip()
|
||||
if not text:
|
||||
return {}
|
||||
|
||||
# `function_formatter` writes dict arguments as `k:v,...` inside `{...}`.
|
||||
# The extractor captures only the inner text, so re-wrap it to parse as JSON object.
|
||||
object_like_text = "{" + text + "}"
|
||||
# Convert Gemma string markers (<|"|>value<|"|>) to valid JSON strings.
|
||||
normalized = re.sub(
|
||||
r"<\|\"\|\>(.*?)<\|\"\|\>",
|
||||
lambda m: json.dumps(m.group(1), ensure_ascii=False),
|
||||
object_like_text,
|
||||
flags=re.DOTALL,
|
||||
)
|
||||
# Quote unquoted object keys so the payload can be parsed by json.loads.
|
||||
normalized = re.sub(r"(^|[{\s,])([A-Za-z_][A-Za-z0-9_]*)(\s*:)", r'\1"\2"\3', normalized)
|
||||
try:
|
||||
return json.loads(normalized)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
try:
|
||||
return json.loads(text)
|
||||
except json.JSONDecodeError:
|
||||
return text
|
||||
|
||||
results: list[FunctionCall] = []
|
||||
for name, arg_block in matches:
|
||||
parsed_arguments = _parse_arguments(arg_block)
|
||||
if isinstance(parsed_arguments, str):
|
||||
arguments = parsed_arguments
|
||||
else:
|
||||
arguments = json.dumps(parsed_arguments, ensure_ascii=False)
|
||||
results.append(FunctionCall(name.strip(), arguments))
|
||||
|
||||
return results
|
||||
|
||||
@override
|
||||
@staticmethod
|
||||
def function_formatter(functions: list["FunctionCall"]) -> str:
|
||||
def _format_argument(argument: Any, escape_keys: bool = True) -> str:
|
||||
if isinstance(argument, str):
|
||||
return f'<|"|>{argument}<|"|>'
|
||||
|
||||
if isinstance(argument, bool):
|
||||
return "true" if argument else "false"
|
||||
|
||||
if isinstance(argument, dict):
|
||||
items: list[str] = []
|
||||
for key in sorted(argument.keys()):
|
||||
formatted_key = f'<|"|>{key}<|"|>' if escape_keys else str(key)
|
||||
formatted_value = _format_argument(argument[key], escape_keys=escape_keys)
|
||||
items.append(f"{formatted_key}:{formatted_value}")
|
||||
return "{" + ",".join(items) + "}"
|
||||
|
||||
if isinstance(argument, (list, tuple)):
|
||||
return "[" + ",".join(_format_argument(item, escape_keys=escape_keys) for item in argument) + "]"
|
||||
|
||||
if argument is None:
|
||||
return "null"
|
||||
|
||||
return str(argument)
|
||||
|
||||
function_texts: list[str] = []
|
||||
for function in functions:
|
||||
name = function.name
|
||||
raw_arguments = function.arguments
|
||||
|
||||
try:
|
||||
parsed_arguments = json.loads(raw_arguments)
|
||||
except (TypeError, json.JSONDecodeError):
|
||||
parsed_arguments = raw_arguments
|
||||
|
||||
call_text = f"<|tool_call>call:{name}" + "{"
|
||||
if isinstance(parsed_arguments, dict):
|
||||
args_text = []
|
||||
for key in sorted(parsed_arguments.keys()):
|
||||
value_text = _format_argument(parsed_arguments[key], escape_keys=False)
|
||||
args_text.append(f"{key}:{value_text}")
|
||||
|
||||
call_text += ",".join(args_text)
|
||||
elif isinstance(parsed_arguments, str):
|
||||
call_text += parsed_arguments
|
||||
else:
|
||||
call_text += _format_argument(parsed_arguments, escape_keys=False)
|
||||
|
||||
call_text += "}<tool_call|>"
|
||||
function_texts.append(call_text)
|
||||
|
||||
return "".join(function_texts)
|
||||
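The normalization steps used by the extractor above can be tried on a single call string. The sketch below is a simplified standalone version, assuming the same `<|tool_call>call:name{...}<tool_call|>` layout; `parse_gemma4_calls` is a hypothetical name and it skips the fallback paths of the original.

```python
import json
import re

CALL_PATTERN = re.compile(r"<\|tool_call>call:([^{\s]+)\{(.*?)\}<tool_call\|>", re.DOTALL)
STRING_PATTERN = re.compile(r'<\|"\|>(.*?)<\|"\|>', re.DOTALL)
KEY_PATTERN = re.compile(r"(^|[{\s,])([A-Za-z_][A-Za-z0-9_]*)(\s*:)")


def parse_gemma4_calls(content: str) -> list[tuple[str, dict]]:
    """Hypothetical helper: return (name, arguments) pairs for dict-style calls."""
    calls = []
    for name, arg_text in CALL_PATTERN.findall(content):
        # The regex captures only the inner text, so re-wrap it as an object.
        text = "{" + arg_text.strip() + "}"
        # Turn <|"|>value<|"|> markers into JSON string literals.
        text = STRING_PATTERN.sub(lambda m: json.dumps(m.group(1), ensure_ascii=False), text)
        # Quote bare keys so that json.loads accepts the payload.
        text = KEY_PATTERN.sub(r'\1"\2"\3', text)
        calls.append((name.strip(), json.loads(text)))
    return calls


sample = '<|tool_call>call:get_weather{city:<|"|>Paris<|"|>,days:3}<tool_call|>'
calls = parse_gemma4_calls(sample)
print(calls)
```

In this sketch the key-quoting step does not know about string boundaries, so a string value that itself contains text shaped like ` key:` would make `json.loads` raise; the original handles such failures by falling back to the raw text.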
|
||||
|
||||
class GLM4ToolUtils(ToolUtils):
|
||||
r"""GLM-4 tool using template."""
|
||||
|
||||
@@ -361,6 +521,8 @@ class MiniMaxM2ToolUtils(ToolUtils):
|
||||
prompt += "\n</invoke>"
|
||||
function_texts.append(prompt)
|
||||
|
||||
return "\n".join(function_texts)
|
||||
|
||||
@override
|
||||
@staticmethod
|
||||
def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
|
||||
@@ -598,7 +760,7 @@ class SeedToolUtils(ToolUtils):
|
||||
|
||||
results.append(FunctionCall(func_name.strip(), json.dumps(args_dict, ensure_ascii=False)))
|
||||
|
||||
return results
|
||||
return results if results else content
|
||||
|
||||
|
||||
class LingToolUtils(QwenToolUtils):
|
||||
@@ -721,6 +883,7 @@ class LFM2ToolUtils(ToolUtils):
|
||||
|
||||
TOOLS = {
|
||||
"default": DefaultToolUtils(),
|
||||
"gemma4": Gemma4ToolUtils(),
|
||||
"glm4": GLM4ToolUtils(),
|
||||
"llama3": Llama3ToolUtils(),
|
||||
"lfm2": LFM2ToolUtils(),
|
||||
|
||||
@@ -69,12 +69,28 @@ MCA_SUPPORTED_MODELS = {
|
||||
"qwen3",
|
||||
"qwen3_moe",
|
||||
"qwen3_next",
|
||||
"qwen3_5",
|
||||
"qwen3_5_moe",
|
||||
}
|
||||
|
||||
METHODS = ["full", "freeze", "lora", "oft"]
|
||||
|
||||
MOD_SUPPORTED_MODELS = {"bloom", "falcon", "gemma", "llama", "mistral", "mixtral", "phi", "starcoder2"}
|
||||
|
||||
MROPE_MODELS = {
|
||||
"glm4v",
|
||||
"glm_ocr",
|
||||
"Keye",
|
||||
"qwen2_vl",
|
||||
"qwen2_5_vl",
|
||||
"qwen2_5_omni_thinker",
|
||||
"qwen3_omni_moe_thinker",
|
||||
"qwen3_vl",
|
||||
"qwen3_vl_moe",
|
||||
"qwen3_5",
|
||||
"qwen3_5_moe",
|
||||
}
|
||||
|
||||
MULTIMODAL_SUPPORTED_MODELS = set()
|
||||
|
||||
PEFT_METHODS = {"lora", "oft"}
|
||||
@@ -123,7 +139,6 @@ class EngineName(StrEnum):
|
||||
HF = "huggingface"
|
||||
VLLM = "vllm"
|
||||
SGLANG = "sglang"
|
||||
KT = "ktransformers"
|
||||
|
||||
|
||||
class DownloadSource(StrEnum):
|
||||
@@ -849,6 +864,34 @@ register_model_group(
 )
 
 
+register_model_group(
+    models={
+        "Gemma-4-26B-A4B-Thinking": {
+            DownloadSource.DEFAULT: "google/gemma-4-26B-A4B-it",
+        },
+        "Gemma-4-31B-Thinking": {
+            DownloadSource.DEFAULT: "google/gemma-4-31B-it",
+        },
+    },
+    template="gemma4",
+    multimodal=True,
+)
+
+
+register_model_group(
+    models={
+        "Gemma-4-E2B-Thinking": {
+            DownloadSource.DEFAULT: "google/gemma-4-E2B-it",
+        },
+        "Gemma-4-E4B-Thinking": {
+            DownloadSource.DEFAULT: "google/gemma-4-E4B-it",
+        },
+    },
+    template="gemma4n",
+    multimodal=True,
+)
+
+
 register_model_group(
     models={
         "GLM-4-9B": {
@@ -1213,6 +1256,17 @@ register_model_group(
 )
 
 
+register_model_group(
+    models={
+        "Hy3-Preview": {
+            DownloadSource.DEFAULT: "tencent/Hy3-preview",
+            DownloadSource.MODELSCOPE: "tencent/Hy3-preview",
+        },
+    },
+    template="hy3",
+)
+
+
 register_model_group(
     models={
         "Index-1.9B-Base": {
@@ -1894,6 +1948,18 @@ register_model_group(
 )
 
 
+register_model_group(
+    models={
+        "MiniCPM-V-4.6": {
+            DownloadSource.DEFAULT: "openbmb/MiniCPM-V-4_6",
+            DownloadSource.MODELSCOPE: "OpenBMB/MiniCPM-V-4_6",
+        },
+    },
+    template="minicpm_v_4_6",
+    multimodal=True,
+)
+
+
 register_model_group(
     models={
         "Ministral-8B-Instruct-2410": {
@@ -2812,10 +2878,42 @@ register_model_group(
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Qwen3.5-0.8B-Base": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen3.5-0.8B-Base",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen3.5-0.8B-Base",
|
||||
},
|
||||
"Qwen3.5-2B-Base": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen3.5-2B-Base",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen3.5-2B-Base",
|
||||
},
|
||||
"Qwen3.5-4B-Base": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen3.5-4B-Base",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen3.5-4B-Base",
|
||||
},
|
||||
"Qwen3.5-9B-Base": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen3.5-9B-Base",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen3.5-9B-Base",
|
||||
},
|
||||
"Qwen3.5-35B-A3B-Base": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen3.5-35B-A3B-Base",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen3.5-35B-A3B-Base",
|
||||
},
|
||||
"Qwen3.5-0.8B-Thinking": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen3.5-0.8B",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen3.5-0.8B",
|
||||
},
|
||||
"Qwen3.5-2B-Thinking": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen3.5-2B",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen3.5-2B",
|
||||
},
|
||||
"Qwen3.5-4B-Thinking": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen3.5-4B",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen3.5-4B",
|
||||
},
|
||||
"Qwen3.5-9B-Thinking": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen3.5-9B",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen3.5-9B",
|
||||
},
|
||||
"Qwen3.5-27B-Thinking": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen3.5-27B",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen3.5-27B",
|
||||
|
||||
@@ -19,7 +19,7 @@
 from collections import OrderedDict
 
 
-VERSION = "0.9.5.dev0"
+VERSION = "0.9.5"
 
 
 def print_env() -> None:
@@ -94,10 +94,10 @@ def check_version(requirement: str, mandatory: bool = False) -> None:
 
 def check_dependencies() -> None:
     r"""Check the version of the required packages."""
-    check_version("transformers>=4.51.0,<=5.2.0")
+    check_version("transformers>=4.55.0,<=5.6.0")
     check_version("datasets>=2.16.0,<=4.0.0")
-    check_version("accelerate>=1.3.0,<=1.11.0")
-    check_version("peft>=0.18.0,<=0.18.1")
+    check_version("accelerate>=1.3.0,<=1.15.0")
+    check_version("peft>=0.18.0,<=0.20.0")
     check_version("trl>=0.18.0,<=0.24.0")
 
 
@@ -20,6 +20,7 @@ import importlib.util
 from functools import lru_cache
 from typing import TYPE_CHECKING
 
+import transformers.utils.import_utils as import_utils
 from packaging import version
 
 
@@ -70,6 +71,10 @@ def is_matplotlib_available():
     return _is_package_available("matplotlib")
 
 
+def is_hyper_parallel_available():
+    return _is_package_available("hyper_parallel")
+
+
 def is_mcore_adapter_available():
     return _is_package_available("mcore_adapter")
 
@@ -83,7 +88,7 @@ def is_ray_available():
 
 
 def is_kt_available():
-    return _is_package_available("ktransformers")
+    return _is_package_available("kt_kernel")
 
 
 def is_requests_available():
@@ -122,3 +127,26 @@ def is_uvicorn_available():
 
 def is_vllm_available():
     return _is_package_available("vllm")
+
+
+_orig_is_package_available = import_utils._is_package_available
+
+
+class PackageAvailability(tuple):
+    __slots__ = ()
+
+    def __new__(cls, available: bool, pkg_version: str = "N/A"):
+        return super().__new__(cls, (bool(available), pkg_version))
+
+    def __bool__(self) -> bool:
+        return self[0]
+
+
+def _patched_is_package_available(pkg_name: str, return_version: bool = False):
+    available, version = _orig_is_package_available(pkg_name, return_version=return_version)
+
+    return PackageAvailability(available, version)
+
+
+if is_transformers_version_greater_than("5.3.0"):
+    import_utils._is_package_available = _patched_is_package_available
@@ -125,6 +125,10 @@ class DataArguments:
         default=True,
         metadata={"help": "Whether or not to enable thinking mode for reasoning models."},
     )
+    preserve_thinking: bool = field(
+        default=False,
+        metadata={"help": "Whether or not to preserve thinking content in historical turns for reasoning models."},
+    )
     tokenized_path: str | None = field(
         default=None,
         metadata={
@@ -482,6 +482,24 @@ class FinetuningArguments(
             )
         },
     )
+    use_hyper_parallel: bool = field(
+        default=False,
+        metadata={
+            "help": (
+                "Whether or not to use HyperParallel distributed training backend (FSDP/TP). "
+                "Only supported for the 'sft' stage with full fine-tuning."
+            )
+        },
+    )
+    hyper_parallel_args: str | None = field(
+        default=None,
+        metadata={
+            "help": (
+                "Path to a JSON file containing HyperParallel strategy arguments "
+                "(e.g., tp_size, param_dtype). Used when use_hyper_parallel=True."
+            )
+        },
+    )
     use_muon: bool = field(
         default=False,
         metadata={"help": "Whether or not to use the Muon optimizer."},
@@ -16,6 +16,7 @@
 # limitations under the License.
 
 import json
+import os
 from dataclasses import asdict, dataclass, field, fields
 from typing import Any, Literal, Self
 
@@ -460,47 +461,81 @@ class SGLangArguments:
 
 @dataclass
 class KTransformersArguments:
-    r"""Arguments pertaining to the KT training."""
+    r"""Arguments pertaining to KTransformers AMX MoE SFT training.
+
+    These fields are normalized into the transformers/accelerate KT config before training starts.
+    """
 
     use_kt: bool = field(
         default=False,
-        metadata={"help": "Whether To Use KTransformers Optimizations For LoRA Training."},
+        metadata={"help": "Whether to use KTransformers AMX MoE backend for SFT training."},
     )
-    kt_optimize_rule: str | None = field(
+    kt_weight_path: str | None = field(
         default=None,
-        metadata={
-            "help": "Path To The KTransformers Optimize Rule; See https://github.com/kvcache-ai/ktransformers/."
-        },
+        metadata={"help": "Path to pre-quantized INT8 expert weights (.kt files)."},
     )
-    cpu_infer: int | None = field(
-        default=32,
-        metadata={"help": "Number Of CPU Cores Used For Computation."},
+    kt_expert_checkpoint_path: str | None = field(
+        default=None,
+        metadata={"help": "Path to expert checkpoint (safetensors) for online conversion."},
     )
-    chunk_size: int | None = field(
-        default=8192,
-        metadata={"help": "Chunk Size Used For CPU Compute In KTransformers."},
+    kt_use_lora_experts: bool | None = field(
+        default=None,
+        metadata={"help": "Whether to use GPU-side LoRA Experts."},
     )
-    mode: str | None = field(
-        default="normal",
-        metadata={"help": "Normal Or Long_Context For Llama Models."},
+    kt_lora_expert_num: int | None = field(
+        default=None,
+        metadata={"help": "Number of GPU-side LoRA Experts."},
     )
+    kt_lora_expert_intermediate_size: int | None = field(
+        default=None,
+        metadata={"help": "Intermediate size for GPU-side LoRA Experts."},
+    )
 
-    kt_maxlen: int = field(
-        default=4096,
-        metadata={"help": "Maximum Sequence (Prompt + Response) Length Of The KT Engine."},
-    )
-    kt_use_cuda_graph: bool = field(
-        default=True,
-        metadata={"help": "Whether To Use CUDA Graphs For The KT Engine."},
-    )
-    kt_mode: str = field(
-        default="normal",
-        metadata={"help": "Normal Or Long_Context Mode For The KT Engine."},
-    )
-    kt_force_think: bool = field(
-        default=False,
-        metadata={"help": "Force-Think Toggle For The KT Engine."},
-    )
+    def get_kt_config_dict(self, finetuning_args: Any, model_max_length: int | None) -> dict[str, Any]:
+        r"""Build KT config values from LLaMA-Factory model and LoRA arguments."""
+        kt_config = {
+            "kt_lora_rank": getattr(finetuning_args, "lora_rank", None),
+            "kt_lora_alpha": getattr(finetuning_args, "lora_alpha", None),
+            "kt_weight_path": self.kt_weight_path,
+            "kt_expert_checkpoint_path": self.kt_expert_checkpoint_path,
+            "kt_model_max_length": model_max_length,
+            "kt_use_lora_experts": self.kt_use_lora_experts,
+            "kt_lora_expert_num": self.kt_lora_expert_num,
+            "kt_lora_expert_intermediate_size": self.kt_lora_expert_intermediate_size,
+        }
+        return {key: value for key, value in kt_config.items() if value is not None}
+
+    def apply_kt_config(self, finetuning_args: Any, training_args: Any, model_max_length: int | None) -> None:
+        r"""Apply LLaMA-Factory KT args to transformers/accelerate KT integration points."""
+        if not self.use_kt:
+            return
+
+        kt_config = self.get_kt_config_dict(finetuning_args, model_max_length)
+        env_mapping = {
+            "kt_weight_path": "ACCELERATE_KT_WEIGHT_PATH",
+            "kt_expert_checkpoint_path": "ACCELERATE_KT_EXPERT_CHECKPOINT_PATH",
+            "kt_model_max_length": "ACCELERATE_KT_MODEL_MAX_LENGTH",
+            "kt_lora_rank": "ACCELERATE_KT_LORA_RANK",
+            "kt_lora_alpha": "ACCELERATE_KT_LORA_ALPHA",
+            "kt_use_lora_experts": "ACCELERATE_KT_USE_LORA_EXPERTS",
+            "kt_lora_expert_num": "ACCELERATE_KT_LORA_EXPERT_NUM",
+            "kt_lora_expert_intermediate_size": "ACCELERATE_KT_LORA_EXPERT_INTERMEDIATE_SIZE",
+        }
+        for key, env_key in env_mapping.items():
+            value = kt_config.get(key)
+            if value is not None:
+                os.environ[env_key] = str(value)
+
+        hf_kt = getattr(training_args, "hf_kt_config", None)
+        if hf_kt is None or not hasattr(hf_kt, "_kt_config") or not isinstance(hf_kt._kt_config, dict):
+            return
+
+        hf_kt._kt_config.update(kt_config)
+        gc_enabled = getattr(training_args, "gradient_checkpointing", False) or not getattr(
+            self, "disable_gradient_checkpointing", True
+        )
+        if gc_enabled:
+            hf_kt._kt_config.setdefault("kt_share_cache_pool", True)
 
 
 @dataclass
@@ -33,7 +33,7 @@ from transformers.utils import is_torch_bf16_gpu_available, is_torch_npu_availab
 from ..extras import logging
 from ..extras.constants import CHECKPOINT_NAMES, EngineName
 from ..extras.misc import check_dependencies, check_version, get_current_device, is_env_enabled
-from ..extras.packages import is_mcore_adapter_available, is_transformers_version_greater_than
+from ..extras.packages import is_mcore_adapter_available
 from .data_args import DataArguments
 from .evaluation_args import EvaluationArguments
 from .finetuning_args import FinetuningArguments
@@ -47,7 +47,13 @@ logger = logging.get_logger(__name__)
 check_dependencies()
 
 
-_TRAIN_ARGS = [ModelArguments, DataArguments, TrainingArguments, FinetuningArguments, GeneratingArguments]
+_TRAIN_ARGS = [
+    ModelArguments,
+    DataArguments,
+    TrainingArguments,
+    FinetuningArguments,
+    GeneratingArguments,
+]
 _TRAIN_CLS = tuple[ModelArguments, DataArguments, TrainingArguments, FinetuningArguments, GeneratingArguments]
 _INFER_ARGS = [ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
 _INFER_CLS = tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
@@ -57,9 +63,19 @@ _EVAL_CLS = tuple[ModelArguments, DataArguments, EvaluationArguments, Finetuning
 if is_mcore_adapter_available() and is_env_enabled("USE_MCA"):
     from mcore_adapter import TrainingArguments as McaTrainingArguments
 
-    _TRAIN_MCA_ARGS = [ModelArguments, DataArguments, McaTrainingArguments, FinetuningArguments, GeneratingArguments]
+    _TRAIN_MCA_ARGS = [
+        ModelArguments,
+        DataArguments,
+        McaTrainingArguments,
+        FinetuningArguments,
+        GeneratingArguments,
+    ]
     _TRAIN_MCA_CLS = tuple[
-        ModelArguments, DataArguments, McaTrainingArguments, FinetuningArguments, GeneratingArguments
+        ModelArguments,
+        DataArguments,
+        McaTrainingArguments,
+        FinetuningArguments,
+        GeneratingArguments,
     ]
 else:
     _TRAIN_MCA_ARGS = []
@@ -100,6 +116,52 @@ def _parse_args(
     return tuple(parsed_args)
 
 
+def _verify_trackio_args(training_args: "TrainingArguments") -> None:
+    """Validates Trackio-specific arguments.
+
+    Args:
+        training_args: TrainingArguments instance (not a dictionary)
+    """
+    report_to = training_args.report_to
+    if not report_to:
+        return
+
+    if isinstance(report_to, str):
+        report_to = [report_to]
+
+    if "trackio" not in report_to:
+        return
+
+    # --- Enforce project (required by Trackio) ---
+    if not training_args.project:
+        raise ValueError("`--project` must be specified when using Trackio.")
+
+    # --- Validate trackio_space_id format ---
+    space_id = training_args.trackio_space_id
+    if space_id:
+        if space_id != "trackio" and "/" not in space_id:
+            logger.warning(
+                f"trackio_space_id '{space_id}' should typically be in format "
+                "'org/space' for Hugging Face Spaces deployment."
+            )
+
+    # --- Inform about default project usage ---
+    if training_args.project == "huggingface":
+        logger.info(
+            "Using default project name 'huggingface'. "
+            "Consider setting a custom project name with --project "
+            "for better organization."
+        )
+
+    # --- Validate hub repo privacy flag ---
+    if training_args.hub_private_repo:
+        logger.info("Repository will be created as private on Hugging Face Hub.")
+
+    # --- Recommend run_name for experiment clarity ---
+    if not training_args.run_name:
+        logger.warning("Consider setting --run_name for better experiment tracking clarity.")
+
+
 def _set_transformers_logging() -> None:
     if os.getenv("LLAMAFACTORY_VERBOSITY", "INFO") in ["DEBUG", "INFO"]:
         transformers.utils.logging.set_verbosity_info()
@@ -146,7 +208,9 @@ def _check_extra_dependencies(
     training_args: Optional["TrainingArguments"] = None,
 ) -> None:
     if model_args.use_kt:
-        check_version("ktransformers", mandatory=True)
+        check_version("kt-kernel", mandatory=True)
+        check_version("transformers-kt", mandatory=True)
+        check_version("accelerate-kt", mandatory=True)
 
     if model_args.use_unsloth:
         check_version("unsloth", mandatory=True)
@@ -278,8 +342,10 @@ def get_train_args(args: dict[str, Any] | list[str] | None = None) -> _TRAIN_CLS
         if finetuning_args.reward_model_type == "lora" and model_args.use_unsloth:
             raise ValueError("Unsloth does not support lora reward model.")
 
-        if training_args.report_to and training_args.report_to[0] not in ["wandb", "tensorboard"]:
-            raise ValueError("PPO only accepts wandb or tensorboard logger.")
+        if training_args.report_to and any(
+            logger not in ("wandb", "tensorboard", "trackio", "none") for logger in training_args.report_to
+        ):
+            raise ValueError("PPO only accepts wandb, tensorboard, or trackio logger.")
 
     if not model_args.use_kt and training_args.parallel_mode == ParallelMode.NOT_DISTRIBUTED:
         raise ValueError("Please launch distributed training with `llamafactory-cli` or `torchrun`.")
@@ -346,12 +412,10 @@ def get_train_args(args: dict[str, Any] | list[str] | None = None) -> _TRAIN_CLS
     if model_args.use_kt and is_deepspeed_zero3_enabled():
         raise ValueError("KTransformers is incompatible with DeepSpeed ZeRO-3.")
 
-    if data_args.neat_packing and is_transformers_version_greater_than("4.53.0"):
-        raise ValueError("Neat packing is incompatible with transformers>=4.53.0.")
-
     _set_env_vars()
     _verify_model_args(model_args, data_args, finetuning_args)
     _check_extra_dependencies(model_args, finetuning_args, training_args)
+    _verify_trackio_args(training_args)
 
     if not finetuning_args.use_mca and training_args.fp8_enable_fsdp_float8_all_gather and not training_args.fp8:
         logger.warning_rank0("fp8_enable_fsdp_float8_all_gather requires fp8=True. Setting fp8=True.")
@@ -421,7 +485,7 @@ def get_train_args(args: dict[str, Any] | list[str] | None = None) -> _TRAIN_CLS
         training_args.resume_from_checkpoint is None
         and training_args.do_train
         and os.path.isdir(training_args.output_dir)
-        and not training_args.overwrite_output_dir
+        and not getattr(training_args, "overwrite_output_dir", False)  # for mca training args and transformers >= 5.0
         and can_resume_from_checkpoint
     ):
         last_checkpoint = get_last_checkpoint(training_args.output_dir)
@@ -464,6 +528,9 @@ def get_train_args(args: dict[str, Any] | list[str] | None = None) -> _TRAIN_CLS
     )
     transformers.set_seed(training_args.seed)
 
+    if model_args.use_kt:
+        model_args.apply_kt_config(finetuning_args, training_args, model_args.model_max_length)
+
     return model_args, data_args, training_args, finetuning_args, generating_args
 
 
@@ -14,6 +14,7 @@
 
 import json
 from dataclasses import dataclass, field
+from typing import Optional
 
 from transformers import Seq2SeqTrainingArguments
 from transformers.training_args import _convert_str_dict
@@ -63,6 +64,58 @@ class RayArguments:
             self.ray_init_kwargs = _convert_str_dict(json.loads(self.ray_init_kwargs))
 
 
+@dataclass
+class ProfilerArguments:
+    r"""Arguments for torch profiler configuration."""
+
+    enable_torch_profiler: bool = field(
+        default=False,
+        metadata={"help": "Whether to enable torch profiler for collecting performance traces."},
+    )
+    profiler_output_dir: Optional[str] = field(
+        default=None,
+        metadata={"help": "Directory to write profiler traces. Defaults to <output_dir>/profiler if not set."},
+    )
+    profiler_wait_steps: int = field(
+        default=1,
+        metadata={"help": "Number of steps to skip at the start of each profiling cycle."},
+    )
+    profiler_warmup_steps: int = field(
+        default=1,
+        metadata={"help": "Number of profiler warm-up steps per cycle."},
+    )
+    profiler_active_steps: int = field(
+        default=1,
+        metadata={"help": "Number of steps to actively record per cycle."},
+    )
+    profiler_repeat: int = field(
+        default=1,
+        metadata={"help": "Number of profiling cycles. Set to 0 for continuous profiling."},
+    )
+    profiler_record_shapes: bool = field(
+        default=True,
+        metadata={"help": "Whether to record tensor shapes during profiling."},
+    )
+    profiler_profile_memory: bool = field(
+        default=True,
+        metadata={"help": "Whether to profile memory usage."},
+    )
+    profiler_with_stack: bool = field(
+        default=True,
+        metadata={"help": "Whether to record stack traces during profiling."},
+    )
+    profile_modules: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": (
+                "Comma-separated list of module name patterns to profile with CUDA events. "
+                "Supports fnmatch wildcards (e.g. 'model.layers.0.self_attn,model.layers.*.mlp'). "
+                "Reports per-module forward/backward timing statistics at each logging step."
+            )
+        },
+    )
+
+
 @dataclass
 class Fp8Arguments:
     r"""Arguments pertaining to the FP8 training."""
@@ -87,7 +140,7 @@ class Fp8Arguments:
 
 
 @dataclass
-class TrainingArguments(Fp8Arguments, RayArguments, BaseTrainingArguments):
+class TrainingArguments(ProfilerArguments, Fp8Arguments, RayArguments, BaseTrainingArguments):
     r"""Arguments pertaining to the trainer."""
 
     overwrite_output_dir: bool = field(
@@ -20,8 +20,6 @@ from peft import LoraConfig, LoraModel, OFTConfig, PeftModel, TaskType, get_peft
 from transformers.integrations import is_deepspeed_zero3_enabled
 
 from ..extras import logging
-from ..extras.constants import EngineName
-from .model_utils.ktransformers import get_kt_peft_model, load_kt_peft_model
 from .model_utils.misc import find_all_linear_modules, find_expanded_modules
 from .model_utils.quantization import QuantizationMethod
 from .model_utils.unsloth import get_unsloth_peft_model, load_unsloth_peft_model
@@ -125,7 +123,7 @@ def _setup_freeze_tuning(
 
     model_type = getattr(model.config, "model_type", None)
     if not finetuning_args.freeze_multi_modal_projector and model_type in COMPOSITE_MODELS:
-        trainable_layers.append(COMPOSITE_MODELS[model_type].projector_key)
+        trainable_layers.extend(COMPOSITE_MODELS[model_type].projector_keys)
 
     forbidden_modules = get_forbidden_modules(model.config, finetuning_args)
     for name, param in model.named_parameters():
@@ -188,12 +186,6 @@ def _setup_lora_tuning(
             "token": model_args.hf_hub_token,
         }
 
-        if model_args.use_kt:
-            if model_args.infer_backend != EngineName.KT:
-                raise ValueError(
-                    "We should use ktransformers as backend to infer the adapter fine-tuned by ktransformers."
-                )
-
         for adapter in adapter_to_merge:
             model: LoraModel = PeftModel.from_pretrained(model, adapter, **init_kwargs)
             model = model.merge_and_unload()
@@ -202,9 +194,7 @@ def _setup_lora_tuning(
             logger.info_rank0(f"Merged {len(adapter_to_merge)} adapter(s).")
 
         if adapter_to_resume is not None:  # resume lora training
-            if model_args.use_kt:
-                model = load_kt_peft_model(model_args, model)
-            elif model_args.use_unsloth:
+            if model_args.use_unsloth:
                 model = load_unsloth_peft_model(config, model_args, finetuning_args, is_trainable=is_trainable)
             else:
                 model = PeftModel.from_pretrained(model, adapter_to_resume, is_trainable=is_trainable, **init_kwargs)
@@ -217,16 +207,6 @@ def _setup_lora_tuning(
         else:
             target_modules = finetuning_args.lora_target
 
-        if model_args.use_kt:
-            new_list = []
-            for m in target_modules:
-                if m in ("down_proj", "up_proj", "gate_proj"):
-                    new_list.extend([f"mlp.{m}", f"shared_experts.{m}"])
-                elif m not in ("generate_linear", "orig_module", "prefill_linear"):
-                    new_list.append(m)
-
-            target_modules[:] = new_list
-
         if finetuning_args.use_llama_pro:
             target_modules = find_expanded_modules(model, target_modules, finetuning_args.freeze_trainable_layers)
 
@@ -270,19 +250,11 @@ def _setup_lora_tuning(
             }
 
         if model_args.use_kt:
-            if finetuning_args.finetuning_type == "oft":
-                raise ValueError("KTransformers is currently not supported for OFT.")
-            if finetuning_args.finetuning_type == "lora":
-                peft_config = LoraConfig(
-                    task_type=TaskType.CAUSAL_LM,
-                    inference_mode=False,
-                    **peft_kwargs,
-                )
-            else:
-                raise ValueError("KTransformers is currently only supported for LoRA.")
+            if finetuning_args.finetuning_type != "lora":
+                raise ValueError("KTransformers only supports LoRA finetuning.")
 
-            model = get_kt_peft_model(model, peft_config)
-            print(f"KT_model:{model}")
+            peft_config = LoraConfig(task_type=TaskType.CAUSAL_LM, inference_mode=False, **peft_kwargs)
+            model = get_peft_model(model, peft_config)
         elif model_args.use_unsloth:
             if finetuning_args.finetuning_type == "oft":
                 raise ValueError("Unsloth is currently not supported for OFT.")
@@ -31,7 +31,6 @@ from ..extras import logging
 from ..extras.misc import count_parameters, skip_check_imports, try_download_model_from_other_hub
 from ..extras.packages import is_torch_version_greater_than
 from .adapter import init_adapter
-from .model_utils.ktransformers import load_kt_pretrained_model
 from .model_utils.liger_kernel import apply_liger_kernel
 from .model_utils.misc import register_autoclass
 from .model_utils.mod import convert_pretrained_model_to_mod, load_mod_pretrained_model
@@ -144,12 +143,7 @@ def load_model(
 
     model = None
     lazy_load = False
-    if model_args.use_kt:
-        from ktransformers.sft.monkey_patch_torch_module import install_patch
-
-        install_patch()
-        model = load_kt_pretrained_model(config, model_args)
-    elif model_args.use_unsloth:
+    if model_args.use_unsloth:
         if model_args.adapter_name_or_path is not None:
             lazy_load = True
         elif is_trainable:
@@ -1,154 +0,0 @@
|
||||
# Copyright 2025 the KVCache.AI team, Approaching AI, and the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import importlib.util as _u
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import torch
|
||||
|
||||
from ...extras import logging
|
||||
from ...extras.misc import get_current_device
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ...hparams import FinetuningArguments, ModelArguments
|
||||
|
||||
from transformers import AutoConfig, AutoModelForCausalLM, PretrainedConfig, PreTrainedModel
|
||||
|
||||
|
||||
KT_AVAILABLE = _u.find_spec("ktransformers") is not None
|
||||
if KT_AVAILABLE:
|
||||
from ktransformers.models.modeling_deepseek import DeepseekV2ForCausalLM
|
||||
from ktransformers.models.modeling_deepseek_v3 import DeepseekV3ForCausalLM
|
||||
from ktransformers.models.modeling_llama import LlamaForCausalLM
|
||||
from ktransformers.models.modeling_mixtral import MixtralForCausalLM
|
||||
from ktransformers.models.modeling_qwen2_moe import Qwen2MoeForCausalLM
|
||||
from ktransformers.models.modeling_qwen3_moe import Qwen3MoeForCausalLM
|
||||
from ktransformers.optimize.optimize import optimize_and_load_gguf
|
||||
from ktransformers.server.config.config import Config
|
||||
from ktransformers.sft.lora import inject_lora_layer
|
||||
from ktransformers.util.custom_loader import GGUFLoader, SafeTensorLoader
|
||||
from ktransformers.util.globals import GLOBAL_CONFIG
|
||||
from ktransformers.util.utils import load_weights
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def _get_kt_kwargs(
|
||||
config: "PretrainedConfig",
|
||||
model_name_or_path: str,
|
||||
model_args: "ModelArguments",
|
||||
finetuning_args: "FinetuningArguments",
|
||||
) -> dict[str, Any]:
|
||||
return {
|
||||
"model_name": model_name_or_path,
|
||||
"max_seq_length": model_args.model_max_length or 4096,
|
||||
"dtype": model_args.compute_dtype,
|
||||
"load_in_4bit": model_args.quantization_bit == 4,
|
||||
"token": model_args.hf_hub_token,
|
||||
"full_finetuning": finetuning_args.finetuning_type == "full",
|
||||
"device_map": {"": get_current_device()},
|
||||
"rope_scaling": getattr(config, "rope_scaling", None),
|
||||
"fix_tokenizer": False,
|
||||
"trust_remote_code": model_args.trust_remote_code,
|
||||
"use_gradient_checkpointing": "ktransformers",
|
||||
}
|
||||
|
||||
|
||||
def load_kt_pretrained_model(config: "PretrainedConfig", model_args: "ModelArguments") -> "PreTrainedModel":
|
||||
r"""Optionally load pretrained model with KTransformers. Used in training."""
|
||||
custom_models = {
|
||||
"DeepseekV2ForCausalLM": DeepseekV2ForCausalLM,
|
||||
"DeepseekV3ForCausalLM": DeepseekV3ForCausalLM,
|
||||
"Qwen2MoeForCausalLM": Qwen2MoeForCausalLM,
|
||||
"Qwen3MoeForCausalLM": Qwen3MoeForCausalLM,
|
||||
"LlamaForCausalLM": LlamaForCausalLM,
|
||||
"MixtralForCausalLM": MixtralForCausalLM,
|
||||
}
|
||||
Config().cpu_infer = model_args.cpu_infer
|
||||
Config().chunk_size = model_args.chunk_size
|
||||
    config = AutoConfig.from_pretrained(model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code)

    if model_args.mode == "long_context":
        assert config.architectures[0] == "LlamaForCausalLM", "only LlamaForCausalLM supports long_context mode"
        torch.set_default_dtype(torch.float16)
    else:
        torch.set_default_dtype(config.torch_dtype)

    with torch.device("meta"):
        if config.architectures[0] in custom_models:
            print("using custom modeling_xxx.py.")
            if "Qwen2Moe" in config.architectures[0]:  # Qwen2Moe must use flash_attention_2 to avoid overflow.
                config._attn_implementation = "flash_attention_2"
            if "Llama" in config.architectures[0]:
                config._attn_implementation = "eager"
            if "Mixtral" in config.architectures[0]:
                config._attn_implementation = "flash_attention_2"
            model = custom_models[config.architectures[0]](config)
        else:
            attn_implementation = "flash_attention_2"
            model = AutoModelForCausalLM.from_config(
                config, trust_remote_code=True, attn_implementation=attn_implementation
            )

    optimize_config_path = model_args.kt_optimize_rule
    gguf_path = model_args.model_name_or_path

    assert optimize_config_path is not None, "optimize_config_path must be provided (path to YAML rules file)."
    assert gguf_path is not None, "gguf_path must be provided (path to a folder or .gguf file)."

    GLOBAL_CONFIG._config["mod"] = "infer"
    optimize_and_load_gguf(model, optimize_config_path, gguf_path, config)

    return model


def get_kt_peft_model(model: "PreTrainedModel", peft_kwargs: dict[str, Any]) -> "PreTrainedModel":
    r"""Get the peft model for the pretrained model with KTransformers. Used in training."""
    from ktransformers.sft.peft_utils.mapping import get_peft_model

    return get_peft_model(model, peft_kwargs)


def load_kt_peft_model(model_args: "ModelArguments", model: "PreTrainedModel") -> "PreTrainedModel":
    r"""Load peft model with KTransformers. Used in both training and inference."""
    load_adapter_name_or_path = model_args.adapter_name_or_path[0]
    if load_adapter_name_or_path.endswith(".gguf"):
        inject_lora_layer(model, load_adapter_name_or_path)
        adapter_gguf_loader = GGUFLoader(load_adapter_name_or_path)
        load_weights(model, adapter_gguf_loader, adapter_gguf=True)
        model.train()
    else:
        inject_lora_layer(model, load_adapter_name_or_path)

        adapter_loader = SafeTensorLoader(load_adapter_name_or_path)
        device = next(model.parameters()).device
        for key in adapter_loader.tensor_file_map.keys():
            try:
                tensor = adapter_loader.load_tensor(key, device=device)

                model_key = key.replace("base_model.model.", "")
                model_key = model_key.replace(".weight", ".default.weight")
                model_key = model_key.replace(".default.default.weight", ".default.weight")

                param = model.get_parameter(model_key)
                param.data.copy_(tensor.data)

                print(f"Loaded adapter weight: {key} -> {model_key}")
            except AttributeError:
                print(f"Skipping {key}: not a model parameter")
            except KeyError:
                print(f"Key not found in model: {model_key} (original: {key})")

    return model
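The key remapping that `load_kt_peft_model` applies to safetensors adapter weights can be tried in isolation. The following is a minimal sketch and not part of the repository: the helper name `remap_adapter_key` and the sample key are hypothetical, and the function only mirrors the three `str.replace` calls in the code above.

```python
def remap_adapter_key(key: str) -> str:
    """Map a saved adapter key to the in-model parameter name (hypothetical helper)."""
    # Drop the wrapper prefix found on keys in the saved adapter file.
    model_key = key.replace("base_model.model.", "")
    # Insert the adapter name ("default") in front of the trailing ".weight".
    model_key = model_key.replace(".weight", ".default.weight")
    # Keys that already contained ".default" now carry it twice; collapse it.
    model_key = model_key.replace(".default.default.weight", ".default.weight")
    return model_key


if __name__ == "__main__":
    # Illustrative key only; real adapter files may use different module paths.
    print(remap_adapter_key("base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight"))
```

The third replacement makes the mapping idempotent with respect to the adapter name, so keys saved with or without `.default` resolve to the same parameter.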
@@ -45,7 +45,7 @@ def apply_liger_kernel(
         from liger_kernel.transformers import apply_liger_kernel_to_gemma3 as apply_liger_kernel
     elif model_type == "gemma3_text":
         from liger_kernel.transformers import apply_liger_kernel_to_gemma3_text as apply_liger_kernel
-    elif model_type == "glm4":
+    elif model_type in ["glm", "glm4"]:  # for glm4-9b, glm4-32B respectively
         from liger_kernel.transformers import apply_liger_kernel_to_glm4 as apply_liger_kernel
     elif model_type == "glm4v":
         from liger_kernel.transformers import apply_liger_kernel_to_glm4v as apply_liger_kernel
@@ -79,6 +79,8 @@ def apply_liger_kernel(
         from liger_kernel.transformers import apply_liger_kernel_to_qwen3_moe as apply_liger_kernel
     elif model_type == "qwen3_next":
         from liger_kernel.transformers import apply_liger_kernel_to_qwen3_next as apply_liger_kernel
+    elif model_type == "qwen3_5":
+        from liger_kernel.transformers import apply_liger_kernel_to_qwen3_5 as apply_liger_kernel
     elif model_type == "gpt_oss":
         try:
             from liger_kernel.transformers import apply_liger_kernel_to_gpt_oss as apply_liger_kernel
@@ -35,7 +35,7 @@ def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool)
         forbidden_modules.add("output")

     if model_type in COMPOSITE_MODELS:
-        forbidden_modules.add(COMPOSITE_MODELS[model_type].projector_key)
+        forbidden_modules.update(COMPOSITE_MODELS[model_type].projector_keys)

     if freeze_vision_tower and model_type in COMPOSITE_MODELS:
         forbidden_modules.update(COMPOSITE_MODELS[model_type].vision_model_keys)
@@ -62,6 +62,10 @@ def add_z3_leaf_module(model: "PreTrainedModel") -> None:
         # deepseek v3 and kimi vl use custom code
         _set_z3_leaf_modules(model, ["DeepseekV3MoE"])

+    if model_type == "hy_v3":
+        # hy3 uses custom code
+        _set_z3_leaf_modules(model, ["HYV3MoE"])
+
     if model_type == "ernie4_5_moe":
         from transformers.models.ernie4_5_moe.modeling_ernie4_5_moe import Ernie4_5_MoeSparseMoeBlock

@@ -147,6 +151,11 @@ def add_z3_leaf_module(model: "PreTrainedModel") -> None:

         _set_z3_leaf_modules(model, [Qwen3NextSparseMoeBlock])
+
+    if model_type == "qwen3_5_moe":
+        from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import Qwen3_5MoeSparseMoeBlock
+
+        _set_z3_leaf_modules(model, [Qwen3_5MoeSparseMoeBlock])


 def configure_moe(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
     if not is_trainable or not model_args.moe_aux_loss_coef:
@@ -37,7 +37,6 @@
 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 # SOFTWARE.

-from typing import TYPE_CHECKING

 import torch
 import torch.nn.functional as F
@@ -45,10 +44,6 @@ import torch.nn.functional as F
 from ...extras import logging


-if TYPE_CHECKING:
-    from ...hparams import ModelArguments
-
-
 logger = logging.get_logger(__name__)


@@ -105,13 +100,3 @@ def get_unpad_data(attention_mask: "torch.Tensor") -> tuple["torch.Tensor", "tor
     max_seqlen_in_batch = seqlens_in_batch.max().item()
     cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
     return indices, cu_seqlens, max_seqlen_in_batch
-
-
-def configure_packing(model_args: "ModelArguments", is_trainable: bool) -> None:
-    if not is_trainable or not model_args.block_diag_attn:
-        return
-
-    import transformers.modeling_flash_attention_utils
-
-    transformers.modeling_flash_attention_utils._get_unpad_data = get_unpad_data
-    logger.info_rank0("Using block diagonal attention for sequence packing without cross-attention.")
@@ -24,7 +24,6 @@ import transformers.models
 from transformers.activations import ACT2FN

 from ...extras import logging
-from ...extras.packages import is_transformers_version_greater_than


 if TYPE_CHECKING:
@@ -40,16 +39,27 @@ transformers_logger = transformers.utils.logging.get_logger(__name__)
 @dataclass
 class CompositeModel:
     model_type: str
-    projector_key: str
+    projector_keys: list[str]
     vision_model_keys: list[str]
     language_model_keys: list[str]
     lora_conflict_keys: list[str]

-    def get_projector(self, module: "torch.nn.Module") -> "torch.nn.Module":
-        for key in self.projector_key.split("."):
-            module = getattr(module, key)
+    def get_projectors(self, module: "torch.nn.Module") -> list["torch.nn.Module"]:
+        mm_projectors: list[torch.nn.Module] = []
+        for projector_key in self.projector_keys:
+            project_module = module
+            for key in projector_key.split("."):
+                project_module = getattr(project_module, key, None)
+                if project_module is None:  # e.g., the larger gemma4 model has no embed_audio
+                    logger.warning_rank0(
+                        f"Projector key {projector_key} not found in module {module.__class__.__name__}."
+                    )
+                    break

-        return module
+            if project_module is not None:
+                mm_projectors.append(project_module)
+
+        return mm_projectors


 COMPOSITE_MODELS: dict[str, "CompositeModel"] = {}
@@ -57,7 +67,7 @@ COMPOSITE_MODELS: dict[str, "CompositeModel"] = {}

 def _register_composite_model(
     model_type: str,
-    projector_key: Optional[str] = None,
+    projector_keys: list[str] | None = None,
     vision_model_keys: Optional[list[str]] = None,
     language_model_keys: Optional[list[str]] = None,
     lora_conflict_keys: Optional[list[str]] = None,
@@ -66,7 +76,7 @@ def _register_composite_model(

     Args:
         model_type: model type
-        projector_key: multi_modal_projector
+        projector_keys: multi_modal_projector
         vision_model_keys: vision_tower
         language_model_keys: language_model
         lora_conflict_keys: None
@@ -74,7 +84,7 @@ def _register_composite_model(
     """
     COMPOSITE_MODELS[model_type] = CompositeModel(
         model_type=model_type,
-        projector_key=projector_key or "multi_modal_projector",
+        projector_keys=projector_keys or ["multi_modal_projector"],
         vision_model_keys=vision_model_keys or ["vision_tower"],
         language_model_keys=language_model_keys or ["language_model", "lm_head"],
         lora_conflict_keys=lora_conflict_keys or [],
@@ -137,12 +147,16 @@ def autocast_projector_dtype(model: "PreTrainedModel", model_args: "ModelArgumen
     if getattr(model, "quantization_method", None):
         model_type = getattr(model.config, "model_type", None)
         if model_type in COMPOSITE_MODELS:
-            mm_projector = COMPOSITE_MODELS[model_type].get_projector(model)
+            mm_projectors = COMPOSITE_MODELS[model_type].get_projectors(model)
         else:
             return

-        logger.info_rank0(f"Casting multimodal projector outputs in {model_args.compute_dtype}.")
-        mm_projector.register_forward_hook(_mm_projector_forward_post_hook)
+        logger.info_rank0(
+            f"Casting multimodal projector outputs in {model_args.compute_dtype}: "
+            f"{COMPOSITE_MODELS[model_type].projector_keys}."
+        )
+        for mm_projector in mm_projectors:
+            mm_projector.register_forward_hook(_mm_projector_forward_post_hook)


 def configure_visual_model(config: "PretrainedConfig") -> None:
@@ -167,9 +181,9 @@ def get_forbidden_modules(config: "PretrainedConfig", finetuning_args: "Finetuni
             forbidden_modules.update(vision_model_keys)

         if finetuning_args.freeze_multi_modal_projector:
-            projector_key = COMPOSITE_MODELS[model_type].projector_key
-            logger.info_rank0(f"Set multi model projector not trainable: {projector_key}.")
-            forbidden_modules.add(projector_key)
+            projector_keys = COMPOSITE_MODELS[model_type].projector_keys
+            logger.info_rank0(f"Set multi model projector not trainable: {projector_keys}.")
+            forbidden_modules.update(projector_keys)

         if finetuning_args.freeze_language_model:
             language_model_keys = COMPOSITE_MODELS[model_type].language_model_keys
@@ -201,7 +215,7 @@ def patch_target_modules(

 _register_composite_model(
     model_type="dots_ocr",
-    projector_key="vision_tower.merger",
+    projector_keys=["vision_tower.merger"],
     vision_model_keys=["vision_tower"],
     language_model_keys=["model", "lm_head"],
     lora_conflict_keys=["merger"],
@@ -220,10 +234,18 @@ _register_composite_model(
 )


+_register_composite_model(
+    model_type="gemma4",
+    projector_keys=["model.embed_vision", "model.embed_audio"],
+    vision_model_keys=["vision_tower", "audio_tower"],
+    lora_conflict_keys=["per_layer_projection_norm"],
+)
+
+
 # copied from qwen2vl
 _register_composite_model(
     model_type="glm4v",
-    projector_key="visual.merger",
+    projector_keys=["visual.merger"],
     vision_model_keys=["visual.patch_embed", "visual.blocks"],
     language_model_keys=["language_model", "lm_head"],
     lora_conflict_keys=["patch_embed"],
@@ -232,7 +254,7 @@ _register_composite_model(

 _register_composite_model(
     model_type="glm4v_moe",
-    projector_key="visual.merger",
+    projector_keys=["visual.merger"],
     vision_model_keys=["visual.patch_embed", "visual.blocks"],
     language_model_keys=["language_model", "lm_head"],
     lora_conflict_keys=["patch_embed"],
@@ -241,7 +263,7 @@ _register_composite_model(

 _register_composite_model(
     model_type="glm_ocr",
-    projector_key="visual.merger",
+    projector_keys=["visual.merger"],
     vision_model_keys=["visual.patch_embed", "visual.blocks"],
     language_model_keys=["language_model", "lm_head"],
     lora_conflict_keys=["patch_embed"],
@@ -258,7 +280,7 @@ _register_composite_model(

 _register_composite_model(
     model_type="Keye",
-    projector_key="mlp_AR",
+    projector_keys=["mlp_AR"],
     vision_model_keys=["visual.vision_model.patch_embedding", "visual.vision_model.encoder"],
     language_model_keys=["model", "lm_head"],
     lora_conflict_keys=["patch_embedding"],
@@ -293,15 +315,23 @@ _register_composite_model(

 _register_composite_model(
     model_type="minicpmv",
-    projector_key="resampler",
+    projector_keys=["resampler"],
     vision_model_keys=["vpm"],
     language_model_keys=["llm"],
 )


+_register_composite_model(
+    model_type="minicpmv4_6",
+    projector_keys=["model.merger"],
+    vision_model_keys=["model.vision_tower"],
+    language_model_keys=["model.language_model", "lm_head"],
+)
+
+
 _register_composite_model(
     model_type="minicpmo",
-    projector_key="resampler",
+    projector_keys=["resampler"],
     vision_model_keys=["vpm", "apm", "audio_avg_pooler", "audio_projection_layer", "tts"],
     language_model_keys=["llm"],
     lora_conflict_keys=["audio_projection_layer"],
@@ -310,7 +340,7 @@ _register_composite_model(

 _register_composite_model(
     model_type="mistral3",
-    projector_key="model.multi_modal_projector",
+    projector_keys=["model.multi_modal_projector"],
 )


@@ -333,7 +363,7 @@ _register_composite_model(

 _register_composite_model(
     model_type="qwen2_5_omni_thinker",
-    projector_key="visual.merger",
+    projector_keys=["visual.merger", "audio_tower.proj"],
     vision_model_keys=["visual.patch_embed", "visual.blocks", "audio_tower"],
     language_model_keys=["model", "lm_head"],
     lora_conflict_keys=["patch_embed"],
@@ -342,29 +372,25 @@ _register_composite_model(

 _register_composite_model(
     model_type="qwen2_vl",
-    projector_key="visual.merger",
+    projector_keys=["visual.merger"],
     vision_model_keys=["visual.patch_embed", "visual.blocks"],
-    language_model_keys=["language_model", "lm_head"]
-    if is_transformers_version_greater_than("4.52.0")
-    else ["model", "lm_head"],
+    language_model_keys=["language_model", "lm_head"],
     lora_conflict_keys=["patch_embed"],
 )


 _register_composite_model(
     model_type="qwen2_5_vl",
-    projector_key="visual.merger",
+    projector_keys=["visual.merger"],
     vision_model_keys=["visual.patch_embed", "visual.blocks"],
-    language_model_keys=["language_model", "lm_head"]
-    if is_transformers_version_greater_than("4.52.0")
-    else ["model", "lm_head"],
+    language_model_keys=["language_model", "lm_head"],
     lora_conflict_keys=["patch_embed"],
 )


 _register_composite_model(
     model_type="qwen3_vl",
-    projector_key="visual.merger",
+    projector_keys=["visual.merger"],
     vision_model_keys=["visual.pos_embed", "visual.patch_embed", "visual.blocks", "visual.deepstack_merger_list"],
     language_model_keys=["language_model", "lm_head"],
     lora_conflict_keys=["patch_embed"],
@@ -373,7 +399,7 @@ _register_composite_model(

 _register_composite_model(
     model_type="qwen3_vl_moe",
-    projector_key="visual.merger",
+    projector_keys=["visual.merger"],
     vision_model_keys=["visual.pos_embed", "visual.patch_embed", "visual.blocks", "visual.deepstack_merger_list"],
     language_model_keys=["language_model", "lm_head"],
     lora_conflict_keys=["patch_embed"],
@@ -382,7 +408,7 @@ _register_composite_model(

 _register_composite_model(
     model_type="qwen3_omni_moe_thinker",
-    projector_key="visual.merger",
+    projector_keys=["visual.merger", "audio_tower.proj"],
     vision_model_keys=[
         "visual.pos_embed",
         "visual.patch_embed",
@@ -390,14 +416,14 @@ _register_composite_model(
         "visual.deepstack_merger_list",
         "audio_tower",
     ],
-    language_model_keys=["model", "lm_head"],
+    language_model_keys=["language_model", "lm_head"],
     lora_conflict_keys=["patch_embed"],
 )


 _register_composite_model(
     model_type="qwen3_5",
-    projector_key="visual.merger",
+    projector_keys=["model.visual.merger"],
     vision_model_keys=["visual.pos_embed", "visual.patch_embed", "visual.blocks"],
     language_model_keys=["language_model", "lm_head"],
     lora_conflict_keys=["patch_embed"],
@@ -406,7 +432,7 @@ _register_composite_model(

 _register_composite_model(
     model_type="qwen3_5_moe",
-    projector_key="visual.merger",
+    projector_keys=["model.visual.merger"],
     vision_model_keys=["visual.pos_embed", "visual.patch_embed", "visual.blocks"],
     language_model_keys=["language_model", "lm_head"],
     lora_conflict_keys=["patch_embed"],
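The switch from a single `projector_key` to a list of `projector_keys` relies on resolving dotted attribute paths and tolerating missing ones, as `CompositeModel.get_projectors` does in the diff above. Here is a minimal sketch of that lookup pattern under stated assumptions: `resolve_dotted`, `collect_projectors`, and the toy model built from `SimpleNamespace` are hypothetical stand-ins, not repository code, and the warning log is omitted.

```python
from types import SimpleNamespace


def resolve_dotted(root, dotted_key):
    """Follow a dotted attribute path; return None as soon as a segment is missing."""
    current = root
    for part in dotted_key.split("."):
        current = getattr(current, part, None)
        if current is None:
            return None
    return current


def collect_projectors(root, projector_keys):
    """Return the objects found for each key, skipping keys that do not resolve."""
    found = []
    for key in projector_keys:
        module = resolve_dotted(root, key)
        if module is not None:
            found.append(module)
    return found


# Toy stand-in for a model that has a vision embedder but no audio embedder.
toy_model = SimpleNamespace(model=SimpleNamespace(embed_vision="vision-projector"))
projectors = collect_projectors(toy_model, ["model.embed_vision", "model.embed_audio"])
print(projectors)
```

Skipping unresolved keys lets one registration cover model variants that ship only a subset of the listed projectors.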
@@ -30,7 +30,6 @@ from .model_utils.embedding import resize_embedding_layer
 from .model_utils.kv_cache import configure_kv_cache
 from .model_utils.longlora import configure_longlora
 from .model_utils.moe import add_z3_leaf_module, configure_moe
-from .model_utils.packing import configure_packing
 from .model_utils.quantization import configure_quantization
 from .model_utils.rope import configure_rope
 from .model_utils.valuehead import prepare_valuehead_model
@@ -61,6 +60,195 @@ def patch_qwen3_omni_moe_thinker_text_sparse_moe_block():
     modeling_qwen3_omni_moe.Qwen3OmniMoeThinkerTextSparseMoeBlock = Qwen3OmniMoeThinkerTextSparseMoeBlock


+def _check_fla_dependencies() -> None:
+    """Check that the FLA dependencies required for varlen GDN forwarding are available.
+
+    Requires ``flash-linear-attention >= 0.4.1`` (which exposes the varlen
+    ``causal_conv1d`` under ``fla.modules.convolution`` and the
+    ``chunk_gated_delta_rule`` / ``fused_recurrent_gated_delta_rule`` kernels
+    under ``fla.ops.gated_delta_rule``). Raises ``ImportError`` with an
+    actionable message otherwise.
+    """
+    try:
+        from fla.modules.convolution import causal_conv1d  # noqa: F401
+        from fla.ops.gated_delta_rule import (  # noqa: F401
+            chunk_gated_delta_rule,
+            fused_recurrent_gated_delta_rule,
+        )
+    except ImportError as exc:
+        raise ImportError(
+            "Qwen3.5 packing-seq forwarding requires `flash-linear-attention>=0.4.1` "
+            "(provides `fla.modules.convolution.causal_conv1d` and "
+            "`fla.ops.gated_delta_rule.{chunk,fused_recurrent}_gated_delta_rule`). "
+            "Please install/upgrade it."
+        ) from exc
+
+
+def patch_qwen3_5_forward(model: "PreTrainedModel") -> None:
+    """Patch the forward method of Qwen3_5ForConditionalGeneration to support cu_seqlens input (training only).
+
+    Refer to: https://github.com/axolotl-ai-cloud/axolotl/blob/main/src/axolotl/monkeypatch/models/qwen3_5/modeling.py.
+    """
+    if is_transformers_version_greater_than("5.2.0"):
+        from transformers.models.qwen3_5.modeling_qwen3_5 import apply_mask_to_padding_states
+
+    from torch.nn import functional as F
+    from transformers.modeling_flash_attention_utils import prepare_fa_kwargs_from_position_ids
+
+    _check_fla_dependencies()
+    from fla.modules.convolution import causal_conv1d as fla_causal_conv1d
+    from fla.ops.gated_delta_rule import chunk_gated_delta_rule
+
+    def _patched_decoder_forward(
+        self,
+        hidden_states: torch.Tensor,
+        position_embeddings: tuple[torch.Tensor, torch.Tensor],
+        attention_mask: torch.Tensor | None = None,
+        position_ids: torch.LongTensor | None = None,
+        past_key_values=None,
+        cache_position: torch.LongTensor | None = None,
+        **kwargs,
+    ) -> torch.FloatTensor:
+        """Decoder layer forward that passes position_ids through to linear attention."""
+        residual = hidden_states
+        hidden_states = self.input_layernorm(hidden_states)
+
+        if self.layer_type == "linear_attention":
+            hidden_states = self.linear_attn(
+                hidden_states=hidden_states,
+                cache_params=past_key_values,
+                cache_position=cache_position,
+                attention_mask=attention_mask,
+                position_ids=position_ids,  # passing position_ids to linear attention
+            )
+        elif self.layer_type == "full_attention":
+            hidden_states, _ = self.self_attn(
+                hidden_states=hidden_states,
+                attention_mask=attention_mask,
+                position_ids=position_ids[None, 0],  # keep [1, B, L]
+                past_key_values=past_key_values,
+                cache_position=cache_position,
+                position_embeddings=position_embeddings,
+                **kwargs,
+            )
+
+        hidden_states = residual + hidden_states
+
+        residual = hidden_states
+        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states = self.mlp(hidden_states)
+        if isinstance(hidden_states, tuple):  # MoE returns (hidden_states, router_logits)
+            hidden_states, _ = hidden_states
+
+        hidden_states = residual + hidden_states
+
+        return hidden_states
+
+    # gdn forward (training only, cache_params is always None)
+    def _patch_gdn_forward(
+        self,
+        hidden_states: torch.Tensor,
+        cache_params=None,
+        cache_position: torch.LongTensor | None = None,
+        attention_mask: torch.Tensor | None = None,
+        position_ids: torch.LongTensor | None = None,
+    ):
+        # @kuangdd fix: here attention_mask is None
+        hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask)
+
+        batch_size, seq_len, _ = hidden_states.shape
+
+        # Qwen3.5 VL passes 3-D MRoPE position_ids ([axes, B, T]); collapse to [B, T].
+        if position_ids is not None and position_ids.ndim == 3:
+            position_ids = position_ids[0]
+
+        # cu_seqlens for the FLA varlen path is only needed when batch_size == 1:
+        # packing / neat-packing: always folded into a single sequence (bsz == 1) -> varlen
+        # non-packing, bsz == 1: single segment, equivalent to a standard single sequence
+        # non-packing, bsz > 1: not packed, use cu_seqlens=None and standard batched kernels
+        if position_ids is not None and batch_size == 1:
+            cu_seqlens = prepare_fa_kwargs_from_position_ids(position_ids)[0][0]
+        else:
+            cu_seqlens = None
+
+        # FLA varlen kernels expect [B, T, D] layout, not [B, D, T] like the
+        # standard causal-conv1d path that the upstream forward uses.
+        mixed_qkv = self.in_proj_qkv(hidden_states)
+
+        z = self.in_proj_z(hidden_states)
+        z = z.reshape(batch_size, seq_len, -1, self.head_v_dim)
+
+        b = self.in_proj_b(hidden_states)
+        a = self.in_proj_a(hidden_states)
+
+        # FLA's causal_conv1d returns (out, final_state); we don't use the state here.
+        mixed_qkv, _ = fla_causal_conv1d(
+            x=mixed_qkv,
+            weight=self.conv1d.weight.squeeze(1),
+            bias=self.conv1d.bias,
+            activation=self.activation,
+            cu_seqlens=cu_seqlens,
+        )
+
+        query, key, value = torch.split(
+            mixed_qkv,
+            [
+                self.key_dim,
+                self.key_dim,
+                self.value_dim,
+            ],
+            dim=-1,
+        )
+
+        query = query.reshape(batch_size, seq_len, -1, self.head_k_dim)
+        key = key.reshape(batch_size, seq_len, -1, self.head_k_dim)
+        value = value.reshape(batch_size, seq_len, -1, self.head_v_dim)
+
+        beta = b.sigmoid()
+        # If the model is loaded in fp16, without the .float() here, A might be -inf
+        g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias)
+        if self.num_v_heads // self.num_k_heads > 1:
+            query = query.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2)
+            key = key.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2)
+
+        core_attn_out, _ = chunk_gated_delta_rule(
+            query,
+            key,
+            value,
+            g=g,
+            beta=beta,
+            initial_state=None,
+            output_final_state=False,
+            use_qk_l2norm_in_kernel=True,
+            **({"cu_seqlens": cu_seqlens} if cu_seqlens is not None else {}),
+        )
+
+        core_attn_out = core_attn_out.reshape(-1, self.head_v_dim)
+        z = z.reshape(-1, self.head_v_dim)
+        core_attn_out = self.norm(core_attn_out, z)
+        core_attn_out = core_attn_out.reshape(batch_size, seq_len, -1)
+
+        output = self.out_proj(core_attn_out)
+
+        return output
+
+    if model.config.architectures[0] == "Qwen3_5ForConditionalGeneration":
+        from transformers.models.qwen3_5.modeling_qwen3_5 import Qwen3_5DecoderLayer, Qwen3_5GatedDeltaNet
+
+        Qwen3_5DecoderLayer.forward = _patched_decoder_forward
+        Qwen3_5GatedDeltaNet.forward = _patch_gdn_forward
+    elif model.config.architectures[0] == "Qwen3_5MoeForConditionalGeneration":
+        from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import (
+            Qwen3_5MoeDecoderLayer,
+            Qwen3_5MoeGatedDeltaNet,
+        )
+
+        Qwen3_5MoeDecoderLayer.forward = _patched_decoder_forward
+        Qwen3_5MoeGatedDeltaNet.forward = _patch_gdn_forward
+
+    logger.info_rank0("Patched Qwen3.5 decoder forward to support cu_seqlens input (training only).")
+
+
 def patch_youtu_vl_model(model: "PreTrainedModel") -> None:
     original_forward = model.forward

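The patched GDN forward derives `cu_seqlens` from `position_ids` so that the varlen kernels can tell packed sequences apart. The sketch below illustrates the idea in plain Python, assuming that every packed sequence restarts its position ids at 0; `cu_seqlens_from_position_ids` is a hypothetical helper and only approximates what `prepare_fa_kwargs_from_position_ids` is used for here, which may differ in detail.

```python
def cu_seqlens_from_position_ids(position_ids: list[int]) -> list[int]:
    """Compute cumulative sequence boundaries for one packed row.

    Assumption: a new segment starts wherever the position id resets to 0,
    and the row itself starts with position 0.
    """
    starts = [i for i, pos in enumerate(position_ids) if pos == 0]
    return starts + [len(position_ids)]


# Three sequences of lengths 3, 2 and 4 packed into a single row.
packed_positions = [0, 1, 2, 0, 1, 0, 1, 2, 3]
boundaries = cu_seqlens_from_position_ids(packed_positions)
print(boundaries)  # [0, 3, 5, 9]
```

Consecutive differences of the boundaries give the individual sequence lengths, which is the information a varlen kernel needs to avoid mixing state across packed samples.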
@@ -142,7 +330,6 @@ def patch_config(
     configure_quantization(config, tokenizer, model_args, is_trainable, init_kwargs)
     configure_moe(config, model_args, is_trainable)
     configure_visual_model(config)
-    configure_packing(model_args, is_trainable)
     configure_kv_cache(config, model_args, is_trainable)

     if getattr(config, "model_type", None) == "qwen":
@@ -234,6 +421,9 @@ def patch_model(
         autocast_projector_dtype(model, model_args)
         add_z3_leaf_module(model)

+        if getattr(model.config, "model_type", None) in ["qwen3_5", "qwen3_5_moe"] and model_args.flash_attn == "fa2":
+            patch_qwen3_5_forward(model)
+
     if not model_args.use_unsloth:
         print_attn_implementation(model.config)

@@ -12,11 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import fnmatch
 import json
 import os
 import signal
 import sys
 import time
+from collections import defaultdict
 from concurrent.futures import ThreadPoolExecutor
 from datetime import timedelta
 from typing import TYPE_CHECKING, Any, Optional
@@ -31,7 +33,7 @@ from typing_extensions import override

 from ..extras import logging
 from ..extras.constants import TRAINER_LOG, V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
-from ..extras.misc import get_peak_memory, is_env_enabled, use_ray
+from ..extras.misc import get_peak_memory, is_env_enabled, is_torch_cuda_available, is_torch_npu_available, use_ray
 from ..extras.packages import is_safetensors_available


@@ -228,7 +230,7 @@ class LogCallback(TrainerCallback):
         if (
             args.should_save
             and os.path.exists(os.path.join(args.output_dir, TRAINER_LOG))
-            and args.overwrite_output_dir
+            and getattr(args, "overwrite_output_dir", False)
         ):
             logger.warning_rank0_once("Previous trainer log in this folder will be deleted.")
             os.remove(os.path.join(args.output_dir, TRAINER_LOG))
@@ -338,6 +340,96 @@ class LogCallback(TrainerCallback):
                 self.thread_pool.submit(self._write_log, args.output_dir, logs)


+class TorchProfilerCallback(TrainerCallback):
+    r"""A callback for collecting torch.profiler traces during training.
+
+    Activated by setting ``enable_torch_profiler: true`` in the YAML config.
+
+    Configuration fields (in YAML):
+        profiler_output_dir – where to write traces (default: <output_dir>/profiler)
+        profiler_wait_steps – steps to skip at start of each cycle (default: 1)
+        profiler_warmup_steps – profiler warm-up steps per cycle (default: 1)
+        profiler_active_steps – steps to record per cycle (default: 1)
+        profiler_repeat – number of cycles; 0 = forever (default: 1)
+        profiler_record_shapes – record tensor shapes (default: true)
+        profiler_profile_memory – profile memory usage (default: true)
+        profiler_with_stack – record stack traces (default: true)
+
+    Trace files (one per rank, Chrome / TensorBoard JSON format) are written to
+    ``<profiler_output_dir>/rank_<N>/``.
+    """
+
+    def __init__(self, training_args: "TrainingArguments") -> None:
+        self.profiler = None
+        self.profiler_args = training_args
+
+    @staticmethod
+    def _get_rank() -> int:
+        import torch.distributed as dist
+
+        if dist.is_available() and dist.is_initialized():
+            return dist.get_rank()
+        return 0
+
+    @override
+    def on_train_begin(
+        self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs
+    ) -> None:
+        if self.profiler is not None:
+            self.profiler.stop()
+            self.profiler = None
+
+        pa = self.profiler_args
+        output_dir = pa.profiler_output_dir or os.path.join(args.output_dir, "profiler")
+        rank = self._get_rank()
+        trace_dir = os.path.join(output_dir, f"rank_{rank}")
+        os.makedirs(trace_dir, exist_ok=True)
+
+        activities = [torch.profiler.ProfilerActivity.CPU]
+        try:
+            if is_torch_cuda_available():
+                activities.append(torch.profiler.ProfilerActivity.CUDA)
+            if is_torch_npu_available():
+                activities.append(torch.profiler.ProfilerActivity.NPU)
+        except Exception:
+            pass
+
+        self.profiler = torch.profiler.profile(
+            activities=activities,
+            schedule=torch.profiler.schedule(
+                wait=pa.profiler_wait_steps,
+                warmup=pa.profiler_warmup_steps,
+                active=pa.profiler_active_steps,
+                repeat=pa.profiler_repeat,
+            ),
+            on_trace_ready=torch.profiler.tensorboard_trace_handler(trace_dir),
+            record_shapes=pa.profiler_record_shapes,
+            profile_memory=pa.profiler_profile_memory,
+            with_stack=pa.profiler_with_stack,
+        )
+        self.profiler.start()
+        logger.info_rank0(
+            f"TorchProfiler started — schedule: wait={pa.profiler_wait_steps}, warmup={pa.profiler_warmup_steps}, "
+            f"active={pa.profiler_active_steps}, repeat={pa.profiler_repeat}. Traces → {output_dir}"
+        )
+
+    @override
+    def on_step_end(
+        self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs
+    ) -> None:
+        if self.profiler is not None:
+            self.profiler.step()
+
+    @override
+    def on_train_end(
+        self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs
+    ) -> None:
+        if self.profiler is not None:
+            self.profiler.stop()
+            self.profiler = None
+            logger.info_rank0("TorchProfiler stopped.")
+
+
 class ReporterCallback(TrainerCallback):
     r"""A callback for reporting training status to external logger."""

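The `profiler_wait_steps`, `profiler_warmup_steps`, `profiler_active_steps` and `profiler_repeat` fields described in the `TorchProfilerCallback` docstring are forwarded to `torch.profiler.schedule`. The following pure-Python sketch shows how such a schedule maps step numbers to phases. It reflects an understanding of the wait/warmup/active/repeat semantics rather than the library's own code, and `profiler_phase` and the phase labels are hypothetical names.

```python
def profiler_phase(step: int, wait: int = 1, warmup: int = 1, active: int = 1, repeat: int = 1) -> str:
    """Return the phase a wait/warmup/active/repeat schedule assigns to a 0-based step."""
    cycle_len = wait + warmup + active
    # With repeat > 0 the profiler goes idle once that many cycles have completed.
    if repeat > 0 and step // cycle_len >= repeat:
        return "idle"
    offset = step % cycle_len
    if offset < wait:
        return "idle"
    if offset < wait + warmup:
        return "warmup"
    return "record"


# Defaults from the docstring: one idle step, one warm-up step, one recorded step, one cycle.
phases = [profiler_phase(step) for step in range(5)]
print(phases)  # ['idle', 'warmup', 'record', 'idle', 'idle']
```

With `repeat=0` the same three-step cycle would repeat for the whole run, which is why the docstring describes that value as "forever".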
@@ -371,6 +463,18 @@ class ReporterCallback(TrainerCallback):
                 }
             )

+        if "trackio" in args.report_to:
+            import trackio
+
+            trackio.config.update(
+                {
+                    "model_args": self.model_args.to_dict(),
+                    "data_args": self.data_args.to_dict(),
+                    "finetuning_args": self.finetuning_args.to_dict(),
+                    "generating_args": self.generating_args.to_dict(),
+                }
+            )
+
         if self.finetuning_args.use_swanlab:
             import swanlab  # type: ignore

@@ -382,3 +486,143 @@ class ReporterCallback(TrainerCallback):
                     "generating_args": self.generating_args.to_dict(),
                 }
             )
+
+
+class ModuleProfilerCallback(TrainerCallback):
+    r"""Profile forward/backward time of specified modules using accelerator events.
+
+    Hooks are registered on modules matching the user-provided name patterns.
+    Timing statistics are logged at each trainer logging step.
+
+    Usage in YAML config:
+        profile_modules: "*.layers.0.self_attn,*.layers.0.mlp"
+
+    Supports fnmatch wildcards:
+        profile_modules: "*.layers.*.self_attn,*.layers.*.mlp.experts"
+    """
+
+    @staticmethod
+    def _get_accelerator():
+        """Detect available accelerator and return (event_factory, synchronize_fn)."""
+        if is_torch_cuda_available():
+            return torch.cuda.Event, torch.cuda.synchronize
+        if is_torch_npu_available():
+            return torch.npu.Event, torch.npu.synchronize
+        return None, None
+
+    def __init__(self, profile_modules: str) -> None:
+        self.patterns = [p.strip() for p in profile_modules.split(",") if p.strip()]
+        self._create_event, self._synchronize = self._get_accelerator()
+        self._handles: list[Any] = []
+        self._forward_times: dict[str, list[float]] = defaultdict(list)
+        self._backward_times: dict[str, list[float]] = defaultdict(list)
+        self._pending_forward: dict[str, tuple] = {}
+        self._pending_backward: dict[str, tuple] = {}
+
+    @property
+    def enabled(self) -> bool:
+        return self._create_event is not None
+
+    def _match(self, name: str) -> bool:
+        return any(fnmatch.fnmatch(name, pat) for pat in self.patterns)
+
+    def _make_forward_pre_hook(self, name: str):
+        def hook(module, input):
+            start = self._create_event(enable_timing=True)
+            end = self._create_event(enable_timing=True)
+            start.record()
+            self._pending_forward[name] = (start, end)
+
+        return hook
+
+    def _make_forward_hook(self, name: str):
+        def hook(module, input, output):
pair = self._pending_forward.get(name)
|
||||
if pair is not None:
|
||||
pair[1].record()
|
||||
|
||||
return hook
|
||||
|
||||
def _make_backward_pre_hook(self, name: str):
|
||||
def hook(module, grad_output):
|
||||
start = self._create_event(enable_timing=True)
|
||||
end = self._create_event(enable_timing=True)
|
||||
start.record()
|
||||
self._pending_backward[name] = (start, end)
|
||||
|
||||
return hook
|
||||
|
||||
def _make_backward_hook(self, name: str):
|
||||
def hook(module, grad_input, grad_output):
|
||||
pair = self._pending_backward.get(name)
|
||||
if pair is not None:
|
||||
pair[1].record()
|
||||
|
||||
return hook
|
||||
|
||||
@override
|
||||
def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
||||
if not self.enabled:
|
||||
logger.warning_rank0("ModuleProfiler: no supported accelerator (CUDA/NPU) found, profiling disabled.")
|
||||
return
|
||||
|
||||
model = kwargs.get("model")
|
||||
if model is None:
|
||||
return
|
||||
|
||||
matched = []
|
||||
for name, module in model.named_modules():
|
||||
if not name or not self._match(name):
|
||||
continue
|
||||
self._handles.append(module.register_forward_pre_hook(self._make_forward_pre_hook(name)))
|
||||
self._handles.append(module.register_forward_hook(self._make_forward_hook(name)))
|
||||
self._handles.append(module.register_full_backward_pre_hook(self._make_backward_pre_hook(name)))
|
||||
self._handles.append(module.register_full_backward_hook(self._make_backward_hook(name)))
|
||||
matched.append(name)
|
||||
|
||||
if matched:
|
||||
logger.info_rank0(
|
||||
f"ModuleProfiler: registered hooks on {len(matched)} modules: {matched[:5]}"
|
||||
+ (f" ... (+{len(matched) - 5} more)" if len(matched) > 5 else "")
|
||||
)
|
||||
else:
|
||||
logger.warning_rank0(f"ModuleProfiler: no modules matched patterns {self.patterns}")
|
||||
|
||||
@override
|
||||
def on_step_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
||||
if not self.enabled:
|
||||
return
|
||||
|
||||
self._synchronize()
|
||||
|
||||
for name, (start, end) in self._pending_forward.items():
|
||||
self._forward_times[name].append(start.elapsed_time(end))
|
||||
self._pending_forward.clear()
|
||||
|
||||
for name, (start, end) in self._pending_backward.items():
|
||||
self._backward_times[name].append(start.elapsed_time(end))
|
||||
self._pending_backward.clear()
|
||||
|
||||
@override
|
||||
def on_log(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
||||
if not self._forward_times and not self._backward_times:
|
||||
return
|
||||
|
||||
lines = ["[ModuleProfiler] Timing (ms):"]
|
||||
all_names = sorted(set(list(self._forward_times.keys()) + list(self._backward_times.keys())))
|
||||
for name in all_names:
|
||||
fwd = self._forward_times.get(name, [])
|
||||
bwd = self._backward_times.get(name, [])
|
||||
fwd_mean = sum(fwd) / len(fwd) if fwd else 0.0
|
||||
bwd_mean = sum(bwd) / len(bwd) if bwd else 0.0
|
||||
lines.append(f" {name}: fwd={fwd_mean:.3f}, bwd={bwd_mean:.3f}, total={fwd_mean + bwd_mean:.3f}")
|
||||
|
||||
logger.info_rank0("\n".join(lines))
|
||||
self._forward_times.clear()
|
||||
self._backward_times.clear()
|
||||
|
||||
@override
|
||||
def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
||||
for handle in self._handles:
|
||||
handle.remove()
|
||||
self._handles.clear()
|
||||
|
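To check which module names a `profile_modules` value would select, the pattern parsing from `__init__` and the matching logic of `_match` can be tried on their own. This is a minimal sketch: `parse_patterns` and `matches` are illustrative helper names, and the module names are made up rather than taken from any particular model.

```python
import fnmatch


def parse_patterns(profile_modules: str) -> list[str]:
    """Split a comma-separated pattern string, dropping empty entries."""
    return [p.strip() for p in profile_modules.split(",") if p.strip()]


def matches(name: str, patterns: list[str]) -> bool:
    """Return True if the module name matches any fnmatch pattern."""
    return any(fnmatch.fnmatch(name, pat) for pat in patterns)


patterns = parse_patterns("*.layers.0.self_attn, *.layers.*.mlp,")
# Hypothetical module names, shaped like what `model.named_modules()` yields.
names = [
    "model.layers.0.self_attn",
    "model.layers.1.self_attn",
    "model.layers.1.mlp",
    "model.layers.1.mlp.gate_proj",
]
selected = [n for n in names if matches(n, patterns)]
print(selected)
```

A pattern has to match the whole module name, so submodules nested under a matched module are not selected unless the pattern itself ends in a wildcard.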
||||
@@ -1,62 +0,0 @@
|
||||
# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
|
||||
#
|
||||
# This code is inspired by the HuggingFace's TRL library.
|
||||
# https://github.com/huggingface/trl/blob/v0.8.0/trl/trainer/dpo_trainer.py
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
from ktransformers.sft.lora import KTrainer # type: ignore
|
||||
from typing_extensions import override
|
||||
|
||||
from ..trainer_utils import get_batch_logps, nested_detach
|
||||
from .trainer import CustomDPOTrainer
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import PreTrainedModel
|
||||
|
||||
|
||||
class KDPOTrainer(KTrainer, CustomDPOTrainer):
|
||||
@override
|
||||
def concatenated_forward(
|
||||
self, model: "PreTrainedModel", batch: dict[str, "torch.Tensor"], is_ref_model: bool = False
|
||||
) -> tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
|
||||
r"""Compute the sum log probabilities of the labels under given logits if loss_type is not IPO, ORPO or SimPO.
|
||||
|
||||
Otherwise the average log probabilities.
|
||||
"""
|
||||
if self.finetuning_args.use_ref_model:
|
||||
batch = nested_detach(batch, clone=True) # avoid error
|
||||
|
||||
labels = batch.pop("labels") # dpo do not need compute loss in forward
|
||||
all_logits: torch.Tensor = model(**batch, return_dict=True, use_cache=False).logits.to(torch.float32)
|
||||
all_logits = all_logits.to("cpu")
|
||||
labels = labels.to(all_logits.device)
|
||||
all_logps, valid_length = get_batch_logps(
|
||||
logits=all_logits, labels=labels, ld_alpha=(self.ld_alpha if not is_ref_model else None)
|
||||
)
|
||||
if self.loss_type in ["ipo", "orpo", "simpo"]:
|
||||
all_logps = all_logps / valid_length
|
||||
|
||||
batch_size = batch["input_ids"].size(0) // 2
|
||||
chosen_logps, rejected_logps = all_logps.split(batch_size, dim=0)
|
||||
chosen_logits, rejected_logits = all_logits.split(batch_size, dim=0)
|
||||
chosen_length, _ = valid_length.split(batch_size, dim=0)
|
||||
|
||||
if self.loss_type in ["ipo", "orpo", "simpo"]:
|
||||
return chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_logps
|
||||
else:
|
||||
return chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_logps / chosen_length
|
||||
@@ -123,10 +123,10 @@ class CustomDPOTrainer(DPOTrainer):
|
||||
self.running = RunningMoments(self.accelerator)
|
||||
|
||||
@override
|
||||
def create_optimizer(self) -> "torch.optim.Optimizer":
|
||||
def create_optimizer(self, *args, **kwargs) -> "torch.optim.Optimizer":
|
||||
if self.optimizer is None:
|
||||
self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)
|
||||
return super().create_optimizer()
|
||||
return super().create_optimizer(*args, **kwargs)
|
||||
|
||||
@override
|
||||
def create_scheduler(
|
||||
|
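The `create_optimizer` change above forwards any positional and keyword arguments to the parent class. A plausible reason, stated here as an assumption, is to stay compatible with parent trainers whose `create_optimizer` accepts extra parameters. The classes below are hypothetical stand-ins, not the real `transformers` or TRL trainers.

```python
class BaseTrainer:
    """Hypothetical parent whose create_optimizer takes an optional argument."""

    def create_optimizer(self, model=None):
        return f"optimizer(model={model})"


class ForwardingTrainer(BaseTrainer):
    def create_optimizer(self, *args, **kwargs):
        # Forward whatever the caller passed so no parent argument is dropped.
        return super().create_optimizer(*args, **kwargs)


trainer = ForwardingTrainer()
print(trainer.create_optimizer(model="m"))
```

An override declared as `create_optimizer(self)` would raise a `TypeError` on the same call, since it accepts no extra arguments.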
||||
@@ -62,15 +62,7 @@ def run_dpo(
|
||||
else:
|
||||
ref_model = None
|
||||
|
||||
if model_args.use_kt:
|
||||
from ktransformers.util.globals import GLOBAL_CONFIG # type: ignore
|
||||
|
||||
from .ktrainer import KDPOTrainer as CustomDPOTrainer
|
||||
|
||||
GLOBAL_CONFIG._config["mod"] = "sft"
|
||||
|
||||
else:
|
||||
from .trainer import CustomDPOTrainer
|
||||
from .trainer import CustomDPOTrainer
|
||||
|
||||
# Initialize our Trainer
|
||||
trainer = CustomDPOTrainer(
|
||||
|
||||
18
src/llamafactory/train/hyper_parallel/__init__.py
Normal file
@@ -0,0 +1,18 @@
|
||||
# Copyright 2025 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .workflow import run_sft
|
||||
|
||||
|
||||
__all__ = ["run_sft"]
|
||||
181
src/llamafactory/train/hyper_parallel/workflow.py
Normal file
@@ -0,0 +1,181 @@
|
||||
# Copyright 2025 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from ...data import SFTDataCollatorWith4DAttentionMask, get_dataset, get_template_and_fix_tokenizer
|
||||
from ...extras.constants import IGNORE_INDEX
|
||||
from ...extras.logging import get_logger
|
||||
from ...extras.misc import calculate_tps
|
||||
from ...extras.packages import is_hyper_parallel_available, is_transformers_version_greater_than
|
||||
from ...extras.ploting import plot_loss
|
||||
from ...model import load_model, load_tokenizer
|
||||
from ..callbacks import SaveProcessorCallback
|
||||
from ..sft.metric import ComputeAccuracy, ComputeSimilarity, eval_logit_processor
|
||||
from ..trainer_utils import asft_loss_func, create_modelcard_and_push, create_ref_model, dft_loss_func, eaft_loss_func
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import Seq2SeqTrainingArguments, TrainerCallback
|
||||
|
||||
from ...hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def run_sft(
|
||||
model_args: "ModelArguments",
|
||||
data_args: "DataArguments",
|
||||
training_args: "Seq2SeqTrainingArguments",
|
||||
finetuning_args: "FinetuningArguments",
|
||||
generating_args: "GeneratingArguments",
|
||||
callbacks: Optional[list["TrainerCallback"]] = None,
|
||||
):
|
||||
if not is_hyper_parallel_available():
|
||||
raise ImportError("hyper_parallel is not installed. Please install it with `pip install hyper_parallel`.")
|
||||
|
||||
from hyper_parallel.integration.llamafactory import ( # pylint: disable=C0415
|
||||
HyperParallelArguments,
|
||||
HyperParallelTrainer,
|
||||
)
|
||||
|
||||
tokenizer_module = load_tokenizer(model_args)
|
||||
tokenizer = tokenizer_module["tokenizer"]
|
||||
template = get_template_and_fix_tokenizer(tokenizer, data_args)
|
||||
dataset_module = get_dataset(template, model_args, data_args, training_args, stage="sft", **tokenizer_module)
|
||||
model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
|
||||
|
||||
ref_model = None
|
||||
if finetuning_args.use_asft_loss:
|
||||
ref_model = create_ref_model(model_args, finetuning_args)
|
||||
|
||||
data_collator = SFTDataCollatorWith4DAttentionMask(
|
||||
template=template,
|
||||
model=model if not training_args.predict_with_generate else None,
|
||||
pad_to_multiple_of=8 if training_args.do_train else None,
|
||||
label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id,
|
||||
block_diag_attn=model_args.block_diag_attn,
|
||||
attn_implementation=getattr(model.config, "_attn_implementation", None),
|
||||
compute_dtype=model_args.compute_dtype,
|
||||
**tokenizer_module,
|
||||
)
|
||||
|
||||
# Metric utils
|
||||
metric_module = {}
|
||||
if training_args.predict_with_generate:
|
||||
metric_module["compute_metrics"] = ComputeSimilarity(tokenizer=tokenizer)
|
||||
elif finetuning_args.compute_accuracy:
|
||||
metric_module["compute_metrics"] = ComputeAccuracy()
|
||||
metric_module["preprocess_logits_for_metrics"] = eval_logit_processor
|
||||
|
||||
# Keyword arguments for `model.generate`
|
||||
gen_kwargs = generating_args.to_dict(obey_generation_config=True)
|
||||
if is_transformers_version_greater_than("4.58.0"):
|
||||
extra_ids = getattr(tokenizer, "additional_special_tokens_ids", None)
|
||||
if not isinstance(extra_ids, list):
|
||||
extra_special_tokens = getattr(tokenizer, "_extra_special_tokens", [])
|
||||
string_tokens = [str(t) for t in extra_special_tokens]
|
||||
extra_ids = tokenizer.convert_tokens_to_ids(string_tokens)
|
||||
all_eos_ids = [tokenizer.eos_token_id] + [i for i in extra_ids if i != -1]
|
||||
gen_kwargs["eos_token_id"] = list(dict.fromkeys(all_eos_ids))
|
||||
else:
|
||||
gen_kwargs["eos_token_id"] = [tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids
|
||||
gen_kwargs["pad_token_id"] = tokenizer.pad_token_id
|
||||
|
||||
hp_args = HyperParallelArguments.from_finetuning_args(finetuning_args)
|
||||
|
||||
callbacks = list(callbacks or [])
|
||||
processor = tokenizer_module.get("processor")
|
||||
if processor is not None:
|
||||
callbacks.append(SaveProcessorCallback(processor))
|
||||
|
||||
compute_loss_func = None
|
||||
if finetuning_args.use_dft_loss:
|
||||
compute_loss_func = dft_loss_func
|
||||
elif finetuning_args.use_eaft_loss:
|
||||
compute_loss_func = lambda outputs, labels, num_items_in_batch=None: eaft_loss_func( # noqa: E731
|
||||
outputs, labels, num_items_in_batch, finetuning_args.eaft_alpha
|
||||
)
|
||||
elif finetuning_args.use_asft_loss:
|
||||
from functools import partial
|
||||
|
||||
compute_loss_func = partial(asft_loss_func, asft_alpha=finetuning_args.asft_alpha)
|
||||
|
||||
trainer = HyperParallelTrainer(
|
||||
hp_args=hp_args,
|
||||
model=model,
|
||||
args=training_args,
|
||||
finetuning_args=finetuning_args,
|
||||
data_collator=data_collator,
|
||||
callbacks=callbacks,
|
||||
gen_kwargs=gen_kwargs,
|
||||
ref_model=ref_model,
|
||||
compute_loss_func=compute_loss_func,
|
||||
**dataset_module,
|
||||
**tokenizer_module,
|
||||
**metric_module,
|
||||
)
|
||||
|
||||
if finetuning_args.use_badam:
|
||||
from types import MethodType
|
||||
|
||||
from badam import BAdamCallback, clip_grad_norm_old_version # type: ignore[import]
|
||||
|
||||
trainer.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, trainer.accelerator)
|
||||
trainer.add_callback(BAdamCallback)
|
||||
|
||||
# Training
|
||||
if training_args.do_train:
|
||||
train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
|
||||
trainer.save_model()
|
||||
if finetuning_args.include_effective_tokens_per_second:
|
||||
train_result.metrics["effective_tokens_per_sec"] = calculate_tps(
|
||||
dataset_module["train_dataset"], train_result.metrics, stage="sft"
|
||||
)
|
||||
|
||||
trainer.log_metrics("train", train_result.metrics)
|
||||
trainer.save_metrics("train", train_result.metrics)
|
||||
trainer.save_state()
|
||||
if trainer.is_world_process_zero() and finetuning_args.plot_loss:
|
||||
keys = ["loss"]
|
||||
if isinstance(dataset_module.get("eval_dataset"), dict):
|
||||
keys += sum(
|
||||
[[f"eval_{key}_loss", f"eval_{key}_accuracy"] for key in dataset_module["eval_dataset"].keys()],
|
||||
[],
|
||||
)
|
||||
else:
|
||||
keys += ["eval_loss", "eval_accuracy"]
|
||||
|
||||
plot_loss(training_args.output_dir, keys=keys)
|
||||
|
||||
if training_args.predict_with_generate:
|
||||
tokenizer.padding_side = "left"
|
||||
|
||||
# Evaluation
|
||||
if training_args.do_eval:
|
||||
metrics = trainer.evaluate(metric_key_prefix="eval", **gen_kwargs)
|
||||
trainer.log_metrics("eval", metrics)
|
||||
trainer.save_metrics("eval", metrics)
|
||||
|
||||
# Predict
|
||||
if training_args.do_predict:
|
||||
logger.warning_rank0_once("Batch generation can be very slow. Consider using `scripts/vllm_infer.py` instead.")
|
||||
predict_results = trainer.predict(dataset_module["eval_dataset"], metric_key_prefix="predict", **gen_kwargs)
|
||||
trainer.log_metrics("predict", predict_results.metrics)
|
||||
trainer.save_metrics("predict", predict_results.metrics)
|
||||
trainer.save_predictions(dataset_module["eval_dataset"], predict_results, generating_args.skip_special_tokens)
|
||||
|
||||
# Create model card
|
||||
create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args)
|
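The EOS handling in `run_sft` uses `dict.fromkeys` to drop duplicate token IDs while keeping their first-seen order. Below is a small standalone sketch of that step, with made-up token IDs standing in for a real tokenizer; `merge_eos_ids` is an illustrative name, not a function from the repository.

```python
def merge_eos_ids(eos_token_id: int, extra_ids: list[int]) -> list[int]:
    """Combine the EOS ID with extra special-token IDs, skipping -1 and duplicates."""
    all_ids = [eos_token_id] + [i for i in extra_ids if i != -1]
    # dict keys keep insertion order, so this de-duplicates without reordering.
    return list(dict.fromkeys(all_ids))


# Hypothetical IDs: 2 is the EOS token; -1 is filtered out, as in the workflow.
merged = merge_eos_ids(2, [7, 2, -1, 9, 7])
print(merged)
```

Using `set` here would also remove duplicates, but would not guarantee that the primary EOS ID stays first in the list.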
||||
@@ -120,10 +120,10 @@ class CustomKTOTrainer(KTOTrainer):
|
||||
self.add_callback(BAdamCallback)
|
||||
|
||||
@override
|
||||
def create_optimizer(self) -> "torch.optim.Optimizer":
|
||||
def create_optimizer(self, *args, **kwargs) -> "torch.optim.Optimizer":
|
||||
if self.optimizer is None:
|
||||
self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)
|
||||
return super().create_optimizer()
|
||||
return super().create_optimizer(*args, **kwargs)
|
||||
|
||||
@override
|
||||
def create_scheduler(
|
||||
|
||||
@@ -12,4 +12,62 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# TODO override the original trainer
|
||||
from typing import Any
|
||||
|
||||
import torch.nn.functional as F
|
||||
from mcore_adapter.trainer import McaTrainer
|
||||
from torch import Tensor
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
from typing_extensions import override
|
||||
|
||||
from ...extras.constants import IGNORE_INDEX
|
||||
|
||||
|
||||
class CustomMcaTrainer(McaTrainer):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
@override
|
||||
def _pad_batched_inputs(self, inputs: dict[str, Tensor | Any], seq_length: int):
|
||||
r"""Override to avoid padding error when handling 3d posids."""
|
||||
padding_inputs = {
|
||||
k: v.tolist() if v is not None and isinstance(v, Tensor) else v
|
||||
for k, v in inputs.items()
|
||||
if k in self._language_input_names
|
||||
}
|
||||
|
||||
position_ids_3d = None
|
||||
if isinstance(inputs.get("position_ids"), Tensor) and inputs["position_ids"].dim() == 3:
|
||||
position_ids_3d = inputs["position_ids"]
|
||||
padding_inputs.pop("position_ids", None)
|
||||
|
||||
if "labels" in padding_inputs:
|
||||
padding_inputs["labels"] = [
|
||||
labels + [IGNORE_INDEX] * (seq_length - len(labels)) for labels in padding_inputs["labels"]
|
||||
]
|
||||
tokenizer = (
|
||||
self.processing_class
|
||||
if isinstance(self.processing_class, PreTrainedTokenizerBase)
|
||||
else getattr(self.processing_class, "tokenizer", self.processing_class)
|
||||
)
|
||||
padding_side = getattr(tokenizer, "padding_side", "right")
|
||||
padding_inputs = tokenizer.pad(
|
||||
padding_inputs,
|
||||
padding="max_length",
|
||||
max_length=seq_length,
|
||||
return_tensors="pt",
|
||||
).to(self.args.device)
|
||||
inputs.update(padding_inputs)
|
||||
|
||||
if position_ids_3d is not None:
|
||||
current_seq_len = position_ids_3d.size(-1)
|
||||
if current_seq_len < seq_length:
|
||||
pad_len = seq_length - current_seq_len
|
||||
if padding_side == "left":
|
||||
position_ids_3d = F.pad(position_ids_3d, (pad_len, 0), value=0)
|
||||
else:
|
||||
position_ids_3d = F.pad(position_ids_3d, (0, pad_len), value=0)
|
||||
|
||||
inputs["position_ids"] = position_ids_3d.to(self.args.device)
|
||||
|
||||
return inputs
|
||||
|
||||
@@ -13,10 +13,13 @@
|
||||
# limitations under the License.
|
||||
|
||||
import functools
|
||||
import json
|
||||
import os
|
||||
from collections.abc import Sequence
|
||||
from copy import deepcopy
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
import torch
|
||||
from transformers import DataCollatorForSeq2Seq
|
||||
|
||||
from ...data import (
|
||||
@@ -41,9 +44,10 @@ if not is_mcore_adapter_available():
|
||||
|
||||
from mcore_adapter.models import AutoConfig, AutoModel
|
||||
from mcore_adapter.trainer import DPOTrainer as McaDPOTrainer
|
||||
from mcore_adapter.trainer import McaTrainer
|
||||
from mcore_adapter.trainer.dpo_config import DPOConfig
|
||||
|
||||
from .trainer import CustomMcaTrainer
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from mcore_adapter.training_args import Seq2SeqTrainingArguments as McaSeq2SeqTrainingArguments
|
||||
@@ -70,7 +74,18 @@ def _data_collator_wrapper(data_collator: Any):
|
||||
for k in ["attention_mask", "position_ids"]:
|
||||
if k in feature:
|
||||
feature[k] = feature[k][:-1]
|
||||
return data_collator(features)
|
||||
|
||||
# for qwen vl series model
|
||||
tmp_features = data_collator(features)
|
||||
tmp_features.pop("rope_deltas", None)
|
||||
position_ids = tmp_features.get("position_ids", None)
|
||||
|
||||
if position_ids is not None and position_ids.dim() == 3:
|
||||
if position_ids.shape[0] == 4:
|
||||
position_ids = position_ids[1:]
|
||||
tmp_features["position_ids"] = position_ids
|
||||
|
||||
return tmp_features
|
||||
|
||||
return wrapper
|
||||
|
||||
@@ -78,29 +93,42 @@ def _data_collator_wrapper(data_collator: Any):
|
||||
def _check_model_support(model_args: "ModelArguments"):
|
||||
from transformers import AutoConfig as HfAutoConfig
|
||||
|
||||
config = HfAutoConfig.from_pretrained(
|
||||
model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
|
||||
)
|
||||
if config.model_type not in MCA_SUPPORTED_MODELS:
|
||||
if os.path.exists(os.path.join(model_args.model_name_or_path, "mca_config.json")): # load from mcore ckpt
|
||||
mca_config = json.load(open(os.path.join(model_args.model_name_or_path, "mca_config.json")))
|
||||
model_type = mca_config.get("hf_model_type", None)
|
||||
else:
|
||||
config = HfAutoConfig.from_pretrained(
|
||||
model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
|
||||
)
|
||||
model_type = config.model_type
|
||||
|
||||
if model_type not in MCA_SUPPORTED_MODELS:
|
||||
raise ValueError(
|
||||
f"Model {config.model_type} is not supported by mcore_adapter."
|
||||
f"Model {model_type} is not supported by mcore_adapter. "
|
||||
"You can try to upgrade mcore_adapter to the latest version for more supported models."
|
||||
)
|
||||
|
||||
|
||||
def _freeze_model_parameters(model: Any, finetuning_args: "FinetuningArguments"):
|
||||
"""Freeze model parameters for qwen_vl series models based on finetuning arguments."""
|
||||
if getattr(model.config, "hf_model_type", None) not in ["qwen2_vl", "qwen2_5_vl", "qwen3_vl", "qwen3_vl_moe"]:
|
||||
if getattr(model.config, "hf_model_type", None) not in [
|
||||
"qwen2_vl",
|
||||
"qwen2_5_vl",
|
||||
"qwen3_vl",
|
||||
"qwen3_vl_moe",
|
||||
"qwen3_5",
|
||||
"qwen3_5_moe",
|
||||
]:
|
||||
return
|
||||
|
||||
params_to_freeze = []
|
||||
if finetuning_args.freeze_vision_tower:
|
||||
params_to_freeze.extend(["vision_model.blocks", "vision_model.patch_embed"])
|
||||
if getattr(model.config, "hf_model_type", None) in ["qwen3_vl", "qwen3_vl_moe"]:
|
||||
if getattr(model.config, "hf_model_type", None) in ["qwen3_vl", "qwen3_vl_moe", "qwen3_5", "qwen3_5_moe"]:
|
||||
params_to_freeze.extend(["vision_model.pos_embed"])
|
||||
|
||||
if finetuning_args.freeze_multi_modal_projector:
|
||||
params_to_freeze.extend(["multi_modal_projector"])
|
||||
params_to_freeze.extend(["vision_model.merger"])
|
||||
|
||||
if finetuning_args.freeze_language_model:
|
||||
params_to_freeze.extend(["embedding", "decoder", "output_layer"])
|
||||
@@ -111,6 +139,27 @@ def _freeze_model_parameters(model: Any, finetuning_args: "FinetuningArguments")
|
||||
p.requires_grad_(False)
|
||||
|
||||
|
||||
def _build_meta_hf_model_for_collator(model_args: "ModelArguments") -> Any | None:
|
||||
r"""Build a lightweight HF model on meta device for compatibility with collator."""
|
||||
from transformers import AutoConfig as HfAutoConfig
|
||||
from transformers import AutoModel as HfAutoModel
|
||||
from transformers import AutoModelForImageTextToText
|
||||
|
||||
try:
|
||||
config = HfAutoConfig.from_pretrained(
|
||||
model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
|
||||
)
|
||||
with torch.device("meta"):
|
||||
try:
|
||||
# Prefer multimodal auto class for VLMs (e.g. qwen2-vl), so get_rope_index is available.
|
||||
return AutoModelForImageTextToText.from_config(config)
|
||||
except Exception:
|
||||
return HfAutoModel.from_config(config)
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to build meta HF model for collator, fallback to no model. Error: %s", exc)
|
||||
return None
|
||||
|
||||
|
||||
def run_pt(
|
||||
model_args: "ModelArguments",
|
||||
data_args: "DataArguments",
|
||||
@@ -136,7 +185,7 @@ def run_pt(
|
||||
)
|
||||
data_collator = _data_collator_wrapper(data_collator)
|
||||
|
||||
trainer = McaTrainer(
|
||||
trainer = CustomMcaTrainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
tokenizer=tokenizer,
|
||||
@@ -186,6 +235,7 @@ def run_sft(
|
||||
|
||||
_check_model_support(model_args)
|
||||
model = AutoModel.from_pretrained(model_args.model_name_or_path, training_args)
|
||||
collator_model = _build_meta_hf_model_for_collator(model_args)
|
||||
|
||||
# optional freezing for qwen_vl series
|
||||
_freeze_model_parameters(model, finetuning_args)
|
||||
@@ -193,6 +243,7 @@ def run_sft(
|
||||
pad_to_max = training_args.expert_model_parallel_size is not None and training_args.expert_model_parallel_size > 1
|
||||
data_collator = SFTDataCollatorWith4DAttentionMask(
|
||||
template=template,
|
||||
model=collator_model,
|
||||
padding="max_length" if pad_to_max else "longest",
|
||||
max_length=data_args.cutoff_len if pad_to_max else None,
|
||||
pad_to_multiple_of=64,
|
||||
@@ -201,7 +252,7 @@ def run_sft(
|
||||
)
|
||||
data_collator = _data_collator_wrapper(data_collator)
|
||||
|
||||
trainer = McaTrainer(
|
||||
trainer = CustomMcaTrainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
tokenizer=tokenizer,
|
||||
@@ -240,6 +291,7 @@ def run_dpo(
|
||||
|
||||
_check_model_support(model_args)
|
||||
model = AutoModel.from_pretrained(model_args.model_name_or_path, training_args)
|
||||
collator_model = _build_meta_hf_model_for_collator(model_args)
|
||||
|
||||
_freeze_model_parameters(model, finetuning_args)
|
||||
|
||||
@@ -263,6 +315,7 @@ def run_dpo(
|
||||
)
|
||||
data_collator = PairwiseDataCollatorWithPadding(
|
||||
template=template,
|
||||
model=collator_model,
|
||||
pad_to_multiple_of=64,
|
||||
padding="max_length" if pad_to_max else "longest",
|
||||
max_length=data_args.cutoff_len if pad_to_max else None,
|
||||
|
||||
@@ -69,10 +69,10 @@ class CustomTrainer(Trainer):
|
||||
verify_fp8_status(self.accelerator, training_args)
|
||||
|
||||
@override
|
||||
def create_optimizer(self) -> "torch.optim.Optimizer":
|
||||
def create_optimizer(self, *args, **kwargs) -> "torch.optim.Optimizer":
|
||||
if self.optimizer is None:
|
||||
self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)
|
||||
return super().create_optimizer()
|
||||
return super().create_optimizer(*args, **kwargs)
|
||||
|
||||
@override
|
||||
def create_scheduler(
|
||||
|
||||
@@ -65,10 +65,10 @@ class PairwiseTrainer(Trainer):
|
||||
self.add_callback(BAdamCallback)
|
||||
|
||||
@override
|
||||
def create_optimizer(self) -> "torch.optim.Optimizer":
|
||||
def create_optimizer(self, *args, **kwargs) -> "torch.optim.Optimizer":
|
||||
if self.optimizer is None:
|
||||
self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)
|
||||
return super().create_optimizer()
|
||||
return super().create_optimizer(*args, **kwargs)
|
||||
|
||||
@override
|
||||
def create_scheduler(
|
||||
|
||||
@@ -128,10 +128,10 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
|
||||
verify_fp8_status(self.accelerator, training_args)
|
||||
|
||||
@override
|
||||
def create_optimizer(self) -> "torch.optim.Optimizer":
|
||||
def create_optimizer(self, *args, **kwargs) -> "torch.optim.Optimizer":
|
||||
if self.optimizer is None:
|
||||
self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)
|
||||
return super().create_optimizer()
|
||||
return super().create_optimizer(*args, **kwargs)
|
||||
|
||||
@override
|
||||
def create_scheduler(
|
||||
@@ -215,7 +215,13 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
|
||||
if len(pad_len): # move pad token to last
|
||||
preds[i] = np.concatenate((preds[i][pad_len[0] :], preds[i][: pad_len[0]]), axis=-1)
|
||||
|
||||
decoded_inputs = self.processing_class.batch_decode(dataset["input_ids"], skip_special_tokens=False)
|
||||
input_ids_column = dataset["input_ids"]
|
||||
try:
|
||||
input_ids_list = input_ids_column.to_pylist()
|
||||
except AttributeError:
|
||||
input_ids_list = list(input_ids_column)
|
||||
|
||||
decoded_inputs = self.processing_class.batch_decode(input_ids_list, skip_special_tokens=False)
|
||||
decoded_preds = self.processing_class.batch_decode(preds, skip_special_tokens=skip_special_tokens)
|
||||
decoded_labels = self.processing_class.batch_decode(labels, skip_special_tokens=skip_special_tokens)
|
||||
|
||||
|
||||
@@ -65,6 +65,7 @@ def run_sft(
|
||||
pad_to_multiple_of=8 if training_args.do_train else None, # for shift short attention
|
||||
label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id,
|
||||
block_diag_attn=model_args.block_diag_attn,
|
||||
neat_packing=data_args.neat_packing,
|
||||
attn_implementation=getattr(model.config, "_attn_implementation", None),
|
||||
compute_dtype=model_args.compute_dtype,
|
||||
**tokenizer_module,
|
||||
@@ -102,37 +103,18 @@ def run_sft(
     gen_kwargs["pad_token_id"] = tokenizer.pad_token_id
 
     # Initialize our Trainer
-    if model_args.use_kt:
-        from ktransformers.sft.lora import KTrainer  # type: ignore
-        from ktransformers.util.globals import GLOBAL_CONFIG  # type: ignore
-
-        GLOBAL_CONFIG._config["mod"] = "sft"
-
-        trainer = KTrainer(
-            model=model,
-            args=training_args,
-            tokenizer=tokenizer_module,
-            data_collator=data_collator,
-            callbacks=callbacks,
-            **dataset_module,
-            **metric_module,
-        )
-        trainer.model_accepts_loss_kwargs = False
-        model.config.use_cache = False
-
-    else:
-        trainer = CustomSeq2SeqTrainer(
-            model=model,
-            args=training_args,
-            finetuning_args=finetuning_args,
-            data_collator=data_collator,
-            callbacks=callbacks,
-            gen_kwargs=gen_kwargs,
-            ref_model=ref_model,
-            **dataset_module,
-            **tokenizer_module,
-            **metric_module,
-        )
+    trainer = CustomSeq2SeqTrainer(
+        model=model,
+        args=training_args,
+        finetuning_args=finetuning_args,
+        data_collator=data_collator,
+        callbacks=callbacks,
+        gen_kwargs=gen_kwargs,
+        ref_model=ref_model,
+        **dataset_module,
+        **tokenizer_module,
+        **metric_module,
+    )
 
     # Training
     if training_args.do_train:

@@ -52,6 +52,7 @@ if is_ray_available():
     import ray
     from ray.util.placement_group import PlacementGroup, placement_group
     from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
+    from ray.util.state import list_nodes
 
 
 if TYPE_CHECKING:
@@ -102,7 +103,7 @@ def create_modelcard_and_push(
         kwargs["tags"] = kwargs["tags"] + ["unsloth"]
 
     if model_args.use_kt:
-        kwargs["tags"] = kwargs["tags"] + ["ktransformers"]
+        kwargs["tags"] = kwargs["tags"] + ["kt-kernel"]
 
     if not training_args.do_train:
         pass
@@ -941,7 +942,7 @@ def get_ray_remote_config_for_worker(
 
 def get_ray_head_node_ip() -> str:
     r"""Get the IP address of the Ray head node."""
-    head_ip = next(node["NodeManagerAddress"] for node in ray.nodes() if node.get("IsHead", False))
+    head_ip = next(node["node_ip"] for node in list_nodes() if node.get("is_head_node", False))
     return head_ip
 
 

@@ -24,10 +24,21 @@ from ..data import get_template_and_fix_tokenizer
 from ..extras import logging
 from ..extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
 from ..extras.misc import find_available_port, get_device_name, get_torch_device, infer_optim_dtype
-from ..extras.packages import is_mcore_adapter_available, is_ray_available, is_transformers_version_greater_than
+from ..extras.packages import (
+    is_hyper_parallel_available,
+    is_mcore_adapter_available,
+    is_ray_available,
+    is_transformers_version_greater_than,
+)
 from ..hparams import RayArguments, get_infer_args, get_ray_args, get_train_args, read_args
 from ..model import load_model, load_tokenizer
-from .callbacks import LogCallback, PissaConvertCallback, ReporterCallback
+from .callbacks import (
+    LogCallback,
+    ModuleProfilerCallback,
+    PissaConvertCallback,
+    ReporterCallback,
+    TorchProfilerCallback,
+)
 from .dpo import run_dpo
 from .kto import run_kto
 from .ppo import run_ppo
@@ -69,9 +80,22 @@ def _training_function(config: dict[str, Any]) -> None:
     if finetuning_args.early_stopping_steps is not None:
         callbacks.append(EarlyStoppingCallback(early_stopping_patience=finetuning_args.early_stopping_steps))
 
+    if getattr(training_args, "enable_torch_profiler", False):
+        callbacks.append(TorchProfilerCallback(training_args))
+
+    if getattr(training_args, "profile_modules", None):
+        callbacks.append(ModuleProfilerCallback(training_args.profile_modules))
+
     callbacks.append(ReporterCallback(model_args, data_args, finetuning_args, generating_args))  # add to last
 
-    if finetuning_args.stage in ["pt", "sft", "dpo"] and finetuning_args.use_mca:
+    if finetuning_args.stage == "sft" and finetuning_args.use_hyper_parallel:
+        if not is_hyper_parallel_available():
+            raise ImportError("hyper_parallel is not installed. Please install it with `pip install hyper_parallel`.")
+        from .hyper_parallel import run_sft as run_sft_hp
+
+        run_sft_hp(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
+
+    elif finetuning_args.stage in ["pt", "sft", "dpo"] and finetuning_args.use_mca:
         if not is_mcore_adapter_available():
             raise ImportError("mcore_adapter is not installed. Please install it with `pip install mcore-adapter`.")
         if finetuning_args.stage == "pt":
@@ -168,7 +192,15 @@ def export_model(args: Optional[dict[str, Any]] = None) -> None:
     if not is_transformers_version_greater_than("5.0.0"):
         save_kwargs["safe_serialization"] = not model_args.export_legacy_format
 
-    model.save_pretrained(**save_kwargs)
+    try:
+        model.save_pretrained(**save_kwargs)
+    except NotImplementedError as err:
+        raise RuntimeError(
+            "Failed to export model: weight conversion reversal is not supported for this model architecture "
+            "(NotImplementedError in transformers.core_model_loading.reverse_op). "
+            "This is a known issue with transformers>=5.0 for certain model types (e.g. Mistral/Ministral). "
+            "Workarounds: (1) use transformers<5.0, or (2) report the issue to the transformers repository."
+        ) from err
 
     if model_args.export_hub_model_id is not None:
         # Prepare push arguments (safe_serialization removed in transformers v5.0.0)

@@ -123,6 +123,8 @@ class DistributedInterface:
         if self._initialized:
             return
 
+        self.dist_config = config
+
         helper.set_device_index()
         self._is_distributed = helper.is_distributed()
         self._rank = helper.get_rank()

Some files were not shown because too many files have changed in this diff.