82 Commits

Author SHA1 Message Date
Yaowei Zheng
7ef1fba34a [version] fix gradio (#9685) 2025-12-28 05:00:51 +08:00
Copilot
eceec8ab69 [deps] goodbye python 3.9 (#9677)
Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: hiyouga <16256802+hiyouga@users.noreply.github.com>
Co-authored-by: hiyouga <hiyouga@buaa.edu.cn>
2025-12-27 02:50:44 +08:00
Yaowei Zheng
b44f651e09 [ci] fix docker (#9678) 2025-12-27 02:43:46 +08:00
Yaowei Zheng
55590f5ece [misc] fix ci with uv (#9676) 2025-12-27 01:39:13 +08:00
Copilot
a1b1931b4a [breaking] migrate from setuptools to uv (#9673)
Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: hiyouga <16256802+hiyouga@users.noreply.github.com>
2025-12-26 22:47:23 +08:00
Xunpeng Xiao
3c17f2722c [model] Update ernie_vl to adapt new version (#9665) 2025-12-26 19:57:49 +08:00
Copilot
a882e2d5fc [assets] Add GitHub Copilot instructions for repository (#9675)
Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: hiyouga <16256802+hiyouga@users.noreply.github.com>
2025-12-26 17:32:48 +08:00
Yaowei Zheng
a754604c11 [misc] fix accelerator (#9661)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2025-12-25 02:11:04 +08:00
Xunpeng Xiao
6a2eafbae3 [feat] Models trained and inferred with Mxfp4 are dequantized by default (#9652)
Co-authored-by: Yaowei Zheng <hiyouga@buaa.edu.cn>
2025-12-24 00:26:40 +08:00
Yaowei Zheng
84485406b7 [ci] disable pip cache for ci (#9654) 2025-12-23 18:37:40 +08:00
Kingsley
1c8a42d2f8 [v1&WIP] dataloader init (#9645) 2025-12-23 16:29:47 +08:00
thulyubh22
7901b2f32e [model] efficient tuning for gpt-oss (#9354)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2025-12-23 16:28:38 +08:00
Yaowei Zheng
1f1f5a7d1b [ci] remove docker cache (#9640) 2025-12-22 01:03:10 +08:00
Yaowei Zheng
6ef9854713 [misc] fix cache & pin transformers to 4.57.1 (#9638) 2025-12-22 00:20:55 +08:00
Hertz
4923f52a28 [model] support MiMo-V2-Flash model (#9637) 2025-12-21 14:38:18 +08:00
Yaowei Zheng
0894b4f37e [misc] lint (#9636) 2025-12-20 16:19:39 +08:00
ZIYI ZENG
b0d49e137f [misc] Support split eval_dataset when explict set "predict_with_generate" (#9604)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2025-12-20 01:46:00 +08:00
Xunpeng Xiao
ddd7dcc722 [data] Fix the video frame sampling issue #9620 (#9634) 2025-12-19 18:36:31 +08:00
浮梦
5204cd2bca [misc] add version check for moe (#9633) 2025-12-19 14:57:37 +08:00
Xunpeng Xiao
8c74dca76a [feat] Models trained and inferred with FP8 are dequantized by default (#9627) 2025-12-18 22:54:35 +08:00
xvxuopop
e8deda53a1 [example] add Qwen3 series examples (#9624)
Co-authored-by: UsernameFull <tohowtodoit@gmail.com>
2025-12-18 21:27:00 +08:00
mrhaoxx
a769fb94b9 [feat] support ktransformers for dpo (#9621)
Co-authored-by: poryfly <porykid@gmail.com>
2025-12-18 21:26:25 +08:00
mrhaoxx
964569751f [kt] refactor ktransformers integration (#9632) 2025-12-18 21:26:04 +08:00
Hertz
9fd4b094d4 [model] support VibeThinker models (#9616) 2025-12-16 21:50:46 +08:00
浮梦
18c21bce5a [test] add allreduce test on npu (#9619)
Co-authored-by: frozenleaves <frozen@Mac.local>
2025-12-16 21:33:30 +08:00
sunyi0505
a0179772ab [example] add deepspeed autotp config and example (#9602) 2025-12-15 15:15:26 +08:00
Yaowei Zheng
aeda079014 [v1] model loader (#9613) 2025-12-14 11:50:52 +08:00
Xunpeng Xiao
fdd24276ed [feat] support new function call value (#9610)
Co-authored-by: Yaowei Zheng <hiyouga@buaa.edu.cn>
2025-12-14 00:20:33 +08:00
Yaowei Zheng
110d21713e [v1] add dp & mp mesh (#9611) 2025-12-13 01:44:28 +08:00
Yaowei Zheng
203069e11c [v1] add accelerator (#9607) 2025-12-12 19:22:06 +08:00
tangefly
4fd94141a4 [model] Add Ministral3 (#9582)
Co-authored-by: kingsley <kingsleydodonow@gmail.com>
2025-12-10 15:57:24 +08:00
Kingsley
22d6ac29d5 [model] Rename GLMV template (#9595) 2025-12-10 13:27:47 +08:00
DoubleWheat
cff4483392 [config] Fix RoPE scaling patch for resuming from a scaled model (#9588) 2025-12-09 20:37:37 +08:00
Yaowei Zheng
5d56817e2b [misc] lint (#9593)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2025-12-09 18:00:35 +08:00
Yaowei Zheng
1bbb461f76 [assets] update readme (#9587) 2025-12-09 12:22:54 +08:00
Hertz
c1f5f8fff6 [model] support GLM4.6v (#9586) 2025-12-09 11:06:42 +08:00
Yaowei Zheng
5744f1ea94 [v1] add models & accelerator (#9579) 2025-12-08 02:30:25 +08:00
tangefly
739954910a [deps] Update for Transformers v5 (#9569) 2025-12-08 01:13:32 +08:00
xvxuopop
109162dc56 [fix] fix the issue when using fsdp2 with gradient checkpointing. (#9541)
Co-authored-by: jin-yongxu <jinyongxu@h-partners.com>
2025-12-06 16:04:51 +08:00
jiaqiw09
165f3f073a [examples] add fsdp config for mutiple nodes (#9575)
Co-authored-by: Yaowei Zheng <hiyouga@buaa.edu.cn>
2025-12-05 23:22:48 +08:00
jiaqiw09
efb13b7483 [V1] Refactor ascend MoE kernel patch logic & Support Qwen3-MoE (#9557) 2025-12-02 00:22:03 +08:00
Username_Full
e43a972b25 [test] add npu test yaml and add ascend a3 docker file (#9547)
Co-authored-by: jiaqiw09 <jiaqiw960714@gmail.com>
2025-11-30 09:37:08 +08:00
Kingsley
22be45c78c [misc] fix omni thinker load (#9552)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2025-11-30 09:36:36 +08:00
浮梦
d1f585f80a [test] update test cmd (#9544)
Co-authored-by: frozenleaves <frozen@Mac.local>
Co-authored-by: Yaowei Zheng <hiyouga@buaa.edu.cn>
2025-11-27 17:59:42 +08:00
xvxuopop
955396e8a5 [example] correct the parameter errors in the examples file. (#9543) 2025-11-27 17:38:38 +08:00
xvxuopop
231756a5bf [chat] fix the error when the vLLM version is greater than 0.10.0 (#9539)
Co-authored-by: Yaowei Zheng <hiyouga@buaa.edu.cn>
2025-11-27 02:14:53 +08:00
xvxuopop
2c4fb3c97e [v1] Support fused moe kernel for qwen3vlmoe model. (#9532) 2025-11-27 02:13:33 +08:00
浮梦
2b6f16f261 [model] temporarily support npu fused options on v0, powered by v1 kernels (#9520)
Co-authored-by: frozenleaves <frozen@Mac.local>
2025-11-27 02:08:36 +08:00
浮梦
f17efde693 [v1] support automatic discovery of registered kernels. (#9509)
Co-authored-by: frozenleaves <frozen@Mac.local>
2025-11-27 01:47:22 +08:00
Hertz
591fc9ed02 [model] support ERNIE-4.5-VL Models (#9521) 2025-11-24 16:48:06 +08:00
Peilin Li
3140c242f0 [assets] add README with KT+llamafactory (#9514) 2025-11-19 16:50:45 +08:00
Peilin Li
887c562d60 [example] Add KTransformers Qwen3MoE example (#9511)
Co-authored-by: unknown <xiongchenhui@hisense.ad>
Co-authored-by: Kingsley <kingsleydodonow@gmail.com>
2025-11-19 00:53:28 +08:00
Edge-Seven
9779b1f361 [misc] fix typos in some files (#9505)
Co-authored-by: khanhkhanhlele <namkhanh20xx@gmail.com>
2025-11-18 20:36:01 +08:00
Yinlei Sun
45f0437a14 [v1] Add support for ShareGPT format. (#9486) 2025-11-18 13:44:08 +08:00
浮梦
d4e120423d [data] fix qwen3omni moe model (#9501)
Co-authored-by: frozenleaves <frozen@Mac.local>
2025-11-18 13:43:22 +08:00
Pory
10a446e373 [model] ktransformers qwen3 support (#9485)
Co-authored-by: unknown <xiongchenhui@hisense.ad>
2025-11-13 20:09:44 +08:00
jiaqiw09
0aa4a051af [test] support slow skip and device skip in Uts (#9484) 2025-11-13 20:08:22 +08:00
Yaowei Zheng
8173a88a26 [assets] update readme (#9477) 2025-11-12 16:15:41 +08:00
Kingsley
fef86fa7fe [data] fix qwen3omni audio length calculation (#9467) 2025-11-12 10:37:15 +08:00
taohongsheng
5afa851f71 [misc] Modify pip install command for huggingface_hub (#9463) 2025-11-10 23:04:00 +08:00
MyungHa Kwon
a711bce664 [data] add openai format (#9449) 2025-11-06 20:10:20 +08:00
魅影
bd24350cbf [v1] add pair data converter (#9360)
Co-authored-by: frozenleaves <frozen@Mac.local>
2025-11-06 14:05:58 +08:00
Peilin Li
bd30c0003b [train] fix denominator of ga in ksft loss (#9409) 2025-11-05 20:53:23 +08:00
魅影
8edd2622ce [docker] update npu dockerfile (#9407)
Co-authored-by: frozenleaves <frozen@Mac.local>
2025-11-05 18:28:32 +08:00
Yaowei Zheng
eaf963f67f [model] update kt code (#9406) 2025-11-05 15:27:22 +08:00
Kingsley
56f45e826f [train] fix MPO re-weight (#9405) 2025-11-04 21:10:41 +08:00
魅影
14abb75126 [model] enable using FA in npu (#9397)
Co-authored-by: frozenleaves <frozen@Mac.local>
2025-11-04 19:32:30 +08:00
한송민
5a9939050e [model] add deepstack_merger_list to Qwen3-VL vision_model_keys (#9399) 2025-11-04 19:27:34 +08:00
Peilin Li
934b3084ee [train] KTransformers SFT as backend engine for LLaMA-Factory (#9400)
Co-authored-by: jimmy128 <jimmy128@noreply.gitcode.com>
Co-authored-by: Yaowei Zheng <hiyouga@buaa.edu.cn>
2025-11-04 15:54:12 +08:00
Yaowei Zheng
3ae15da9c0 [misc] lint code (#9395) 2025-11-03 22:08:59 +08:00
魅影
215580c77d [data] fix mm pluigin for qwen omni video training (#9388)
Co-authored-by: frozenleaves <frozen@Mac.local>
2025-11-03 11:44:27 +08:00
魅影
767b344fb4 [model] remove npu sdpa patch (#9368)
Co-authored-by: frozenleaves <frozen@Mac.local>
2025-10-30 16:26:35 +08:00
Kingsley
3057db15c3 [readme] upd mcore readme (#9352) 2025-10-27 21:23:31 +08:00
Kingsley
13170577b2 [feat] support megatron-LM training by mcore_adapter (#9237)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Yaowei Zheng <hiyouga@buaa.edu.cn>
2025-10-26 16:21:30 +08:00
Xiaosu Zhu
129e918106 [data] Fix Qwen3VL plugin (#9297)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Yaowei Zheng <hiyouga@buaa.edu.cn>
Co-authored-by: kingsley <kingsleydodonow@gmail.com>
2025-10-26 16:07:04 +08:00
Yaowei Zheng
9c0d033a15 [model] add qwen3vl 2b & 32b (#9343) 2025-10-24 13:22:36 +08:00
Yaowei Zheng
2a822178de [deps] fix yanked packages (#9333) 2025-10-22 20:54:51 +08:00
Kingsley
b842457ef4 [ci] revert mac os ci setup (#9316) 2025-10-21 18:26:12 +08:00
魅影
2c6aded5d4 [v1] kernel plugin (#9274)
Co-authored-by: frozenleaves <frozen@Mac.local>
2025-10-18 18:02:14 +08:00
Yaowei Zheng
d9d67ba62d [misc] fix import error (#9299) 2025-10-17 17:46:27 +08:00
Yaowei Zheng
a442fa90ad [misc] fix import error (#9296) 2025-10-17 10:54:30 +08:00
wyfdgg
8c341cbaae [model] support hunyuan-mt model (#9284)
Co-authored-by: wyfdgg <liwenkun0812@163.com>
Co-authored-by: Yaowei Zheng <hiyouga@buaa.edu.cn>
2025-10-17 10:33:09 +08:00
209 changed files with 8679 additions and 1150 deletions

View File

@@ -15,6 +15,7 @@ LLAMAFACTORY_VERBOSITY=
 USE_MODELSCOPE_HUB=
 USE_OPENMIND_HUB=
 USE_RAY=
+USE_KT=
 RECORD_VRAM=
 OPTIM_TORCH=
 NPU_JIT_COMPILE=
@@ -35,6 +36,8 @@ GRADIO_SERVER_NAME=
 GRADIO_SERVER_PORT=
 GRADIO_ROOT_PATH=
 GRADIO_IPV6=
+# backend
+USE_MCA=
 # setup
 ENABLE_SHORT_CONSOLE=
 # reserved (do not use)

180
.github/copilot-instructions.md vendored Normal file
View File

@@ -0,0 +1,180 @@
# GitHub Copilot Instructions for LLaMA Factory
## Project Overview
LLaMA Factory is an efficient fine-tuning framework for 100+ large language models (LLMs). It provides:
- Support for various models: LLaMA, LLaVA, Mistral, Qwen, DeepSeek, Yi, Gemma, ChatGLM, Phi, etc.
- Multiple training methods: pre-training, supervised fine-tuning, reward modeling, PPO, DPO, KTO, ORPO
- Scalable resources: 16-bit full-tuning, freeze-tuning, LoRA and QLoRA variants
- Advanced algorithms: GaLore, BAdam, APOLLO, Adam-mini, Muon, OFT, DoRA, etc.
- Web UI (LLaMA Board) and CLI interfaces
### Architecture Versions
LLaMA Factory has two parallel architectures that can be switched via the `USE_V1` environment variable:
**v0 (default)** - File hierarchy:
- `api`, `webui`, `chat`, `eval`, `train`, `data`, `model`, `hparams`, `extras`
**v1** - File hierarchy:
- `trainers`, `core`, `accelerator`, `plugins`, `config`, `utils`
Set `USE_V1=1` to enable v1 architecture.
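As a rough illustration of the switch (a minimal sketch, not the project's actual dispatch code), the flag can be read like this:
```python
import os

# Sketch only: USE_V1 selects the v1 package under src/llamafactory/v1/;
# any other value falls back to the default v0 layout under src/llamafactory/.
if os.getenv("USE_V1", "0") == "1":
    print("v1 architecture: trainers, core, accelerator, plugins, config, utils")
else:
    print("v0 architecture: api, chat, data, eval, hparams, model, train, webui")
```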
## Code Structure
### v0 Architecture (Default)
- `src/llamafactory/` - Main package directory
- `api/` - OpenAI-style API implementation
- `chat/` - Chat interface implementation
- `cli.py` - Command-line interface
- `data/` - Data processing and dataset handling
- `eval/` - Model evaluation utilities
- `extras/` - Additional utilities and helpers
- `hparams/` - Hyperparameter definitions
- `model/` - Model loading, patching, and utilities
- `train/` - Training pipeline implementation
- `webui/` - Gradio-based web interface
- `src/train.py` - Training entry script (delegates to `llamafactory.train.tuner`)
- `src/webui.py` - Web UI entry script (delegates to `llamafactory.webui.interface`)
- `src/api.py` - API server entry script (delegates to `llamafactory.api.app`)
- `tests/` - Test suite
- `examples/` - Example configurations for various training scenarios
- `data/` - Dataset definitions and examples
### v1 Architecture (USE_V1=1)
- `src/llamafactory/v1/` - Version 1 package directory
- `trainers/` - Training implementations
- `core/` - Core training utilities
- `accelerator/` - Acceleration and distributed training
- `plugins/` - Pluggable components (model, data, sampler, trainer)
- `config/` - Configuration management
- `utils/` - Utility functions
## Development Practices
### Code Style
- Follow the [Google Python Style Guide](https://google.github.io/styleguide/pyguide.html)
- Use ruff for linting and formatting
- Line length: 119 characters
- Indentation: 4 spaces
- Quote style: double quotes
- Use Google-style docstrings for documentation
### Import Organization
- Known first-party: `llamafactory`
- Known third-party: `accelerate`, `datasets`, `gradio`, `numpy`, `peft`, `torch`, `transformers`, `trl`
- Use 2 blank lines after imports
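As an illustration of these rules (the specific modules are examples only, not a required set), an import block looks like:
```python
import os
from dataclasses import dataclass

import torch
from transformers import AutoConfig

from llamafactory import hparams


# Two blank lines separate the imports from the first definition.
def main() -> None:
    ...
```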
### Quality Checks
Before committing code, run:
```bash
make style # Auto-fix style issues
make quality # Check code quality
make test # Run test suite
```
Or use the combined command:
```bash
make commit # Run pre-commit hooks
```
### Testing
- Use pytest for testing
- Tests are located in `tests/` and `tests_v1/` directories
- Run tests with: `make test` (which runs `WANDB_DISABLED=true pytest -vv --import-mode=importlib tests/ tests_v1/`)
- Disable wandb during testing to avoid external dependencies
- **Note**: Training configurations require GPU machines, so training is typically not tested end-to-end. Use `make test` to validate file-level functionality.
### Building
Build the package with:
```bash
pip3 install build && python3 -m build
```
### License
- All source files must include the Apache 2.0 license header
- Check license headers with: `make license`
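For reference, the standard Apache 2.0 notice reads as follows; the exact copyright line used by this repository may differ, so treat it as a template rather than the project's verbatim header:
```python
# Copyright <year> <copyright holder>
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
```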
## Common Patterns
### Configuration Files
- Training configurations are typically YAML or JSON files in `examples/` directory
- Hyperparameters are defined using dataclasses in `src/llamafactory/hparams/`
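A hypothetical sketch of that dataclass pattern (the class and field names are illustrative, not the project's actual argument definitions):
```python
from dataclasses import dataclass, field


@dataclass
class ExampleArguments:
    """Illustrative hyperparameter group in the style used under src/llamafactory/hparams/."""

    model_name_or_path: str = field(metadata={"help": "Path to the model or its Hugging Face Hub id."})
    learning_rate: float = field(default=5.0e-5, metadata={"help": "Initial learning rate."})
    lora_rank: int = field(default=8, metadata={"help": "Rank of the LoRA adapters."})
```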
### Model Support
- New model support is added through model patches in `src/llamafactory/model/`
- Visual models use the visual utilities in `src/llamafactory/model/model_utils/visual.py`
- Quantization support is in `src/llamafactory/model/model_utils/quantization.py`
### Data Processing
- Dataset definitions are in `data/dataset_info.json`
- Data templates and processors are in `src/llamafactory/data/`
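For example, a small script can list the registered dataset names (a sketch that assumes it is run from the repository root; the exact schema of each entry is defined by the project):
```python
import json

# data/dataset_info.json maps each dataset name to its loading options.
with open("data/dataset_info.json", encoding="utf-8") as f:
    dataset_info = json.load(f)

print(f"{len(dataset_info)} datasets registered, e.g. {sorted(dataset_info)[:3]}")
```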
### Training
- Training pipelines are in `src/llamafactory/train/`
- Support for different training methods: SFT, DPO, PPO, RM, PT, KTO, ORPO
## Key Dependencies
- Python >= 3.9.0
- PyTorch and transformers for model handling
- datasets for data processing
- peft for parameter-efficient fine-tuning
- accelerate for distributed training
- gradio for web UI
- trl for reinforcement learning
- Optional: vllm/sglang for inference, flash-attention-2, unsloth, liger-kernel
## Entry Points
- **CLI Training**: `llamafactory-cli train --config examples/train_lora/llama3_lora_sft.yaml`
- **Web UI**: `llamafactory-cli webui` or `python src/webui.py`
- **API Server**: `llamafactory-cli api` or `python src/api.py`
- **Chat Interface**: `llamafactory-cli chat --model_name_or_path MODEL_PATH`
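Because the API server is OpenAI-compatible, it can be exercised with any OpenAI-style client; the base URL, port, and model name below are assumptions to adapt to your deployment:
```python
from openai import OpenAI

# Assumes `llamafactory-cli api` is serving locally on port 8000 (adjust as needed).
client = OpenAI(base_url="http://localhost:8000/v1", api_key="placeholder-token")

response = client.chat.completions.create(
    model="default",  # the served model name is deployment-specific
    messages=[{"role": "user", "content": "Hello!"}],
)
print(response.choices[0].message.content)
```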
## Environment Setup
For development:
```bash
pip install -e ".[dev]"
```
## Important Notes
- The project supports multiple backends: default PyTorch, vLLM, SGLang
- Megatron-core training is supported via mcore_adapter
- SwanLab and W&B are supported for experiment tracking
- Docker support is available with pre-built images
- Day-0/Day-1 support for latest cutting-edge models
- Multi-modal support for vision and audio understanding tasks
## Contribution Guidelines
1. Fork the repository
2. Create a development branch
3. Set up development environment with `pip install -e ".[dev]"`
4. Make changes following the style guide
5. Run quality checks: `make style && make quality`
6. Run tests: `make test`
7. Submit a pull request
## Common Commands
- `make style` - Format code
- `make quality` - Run linters
- `make test` - Run tests
- `make commit` - Install and run pre-commit hooks
- `make license` - Check license headers

View File

@@ -7,7 +7,7 @@ on:
       - "main"
     paths:
       - "**/*.py"
-      - "requirements.txt"
+      - "pyproject.toml"
       - "docker/**"
       - ".github/workflows/*.yml"
   pull_request:
@@ -15,7 +15,7 @@ on:
       - "main"
     paths:
       - "**/*.py"
-      - "requirements.txt"
+      - "pyproject.toml"
       - "docker/**"
       - ".github/workflows/*.yml"
   release:
@@ -27,9 +27,10 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        device:
-          - "cuda"
-          - "npu"
+        include:
+          - device: "cuda"
+          - device: "npu-a2"
+          - device: "npu-a3"

    runs-on: ubuntu-latest
@@ -51,16 +52,11 @@ jobs:
       - name: Checkout
         uses: actions/checkout@v4

-      - name: Set up Python
-        uses: actions/setup-python@v5
-        with:
-          python-version: "3.10"
-
       - name: Get llamafactory version
         id: version
         run: |
           if [ "${{ github.event_name }}" = "release" ]; then
-            echo "tag=$(python setup.py --version)" >> "$GITHUB_OUTPUT"
+            echo "tag=$(grep -oP 'VERSION = "\K[^"]+' src/llamafactory/extras/env.py)" >> "$GITHUB_OUTPUT"
           else
             echo "tag=latest" >> "$GITHUB_OUTPUT"
           fi
@@ -76,7 +72,7 @@ jobs:
           password: ${{ secrets.DOCKERHUB_TOKEN }}

       - name: Login to Quay
-        if: ${{ github.event_name != 'pull_request' && matrix.device == 'npu' }}
+        if: ${{ github.event_name != 'pull_request' && matrix.device == 'npu'}}
         uses: docker/login-action@v3
         with:
           registry: quay.io
@@ -89,16 +85,12 @@
         with:
           context: .
           file: ./docker/docker-cuda/Dockerfile
-          build-args: |
-            EXTRAS=metrics,deepspeed,liger-kernel
           push: ${{ github.event_name != 'pull_request' }}
           tags: |
             docker.io/hiyouga/llamafactory:${{ steps.version.outputs.tag }}
-          cache-from: type=gha
-          cache-to: type=gha,mode=max

-      - name: Build and push Docker image (NPU)
-        if: ${{ matrix.device == 'npu' }}
+      - name: Build and push Docker image (NPU-A2)
+        if: ${{ matrix.device == 'npu-a2' }}
         uses: docker/build-push-action@v6
         with:
           context: .
@@ -108,5 +100,17 @@
           tags: |
             docker.io/hiyouga/llamafactory:${{ steps.version.outputs.tag }}-npu-a2
             quay.io/ascend/llamafactory:${{ steps.version.outputs.tag }}-npu-a2
-          cache-from: type=gha
-          cache-to: type=gha,mode=max
+
+      - name: Build and push Docker image (NPU-A3)
+        if: ${{ matrix.device == 'npu-a3' }}
+        uses: docker/build-push-action@v6
+        with:
+          context: .
+          platforms: linux/amd64,linux/arm64
+          file: ./docker/docker-npu/Dockerfile
+          build-args: |
+            BASE_IMAGE=quay.io/ascend/cann:8.3.rc2-a3-ubuntu22.04-py3.11
+          push: ${{ github.event_name != 'pull_request' }}
+          tags: |
+            docker.io/hiyouga/llamafactory:${{ steps.version.outputs.tag }}-npu-a3
+            quay.io/ascend/llamafactory:${{ steps.version.outputs.tag }}-npu-a3

View File

@@ -23,10 +23,11 @@ jobs:
       - name: Checkout
         uses: actions/checkout@v4

-      - name: Set up Python
-        uses: actions/setup-python@v5
+      - name: Install uv
+        uses: astral-sh/setup-uv@v7
         with:
-          python-version: "3.9"
+          python-version: "3.11"
+          github-token: ${{ github.token }}

       - name: Build package
         run: |

View File

@@ -7,14 +7,16 @@ on:
       - "main"
     paths:
       - "**/*.py"
-      - "requirements.txt"
+      - "pyproject.toml"
+      - "Makefile"
       - ".github/workflows/*.yml"
   pull_request:
     branches:
       - "main"
     paths:
       - "**/*.py"
-      - "requirements.txt"
+      - "pyproject.toml"
+      - "Makefile"
       - ".github/workflows/*.yml"

 jobs:
@@ -23,10 +25,9 @@ jobs:
       fail-fast: false
       matrix:
         python:
-          - "3.9"
-          - "3.10"
           - "3.11"
           - "3.12"
+          # - "3.13" # enable after trl is upgraded
         os:
           - "ubuntu-latest"
           - "windows-latest"
@@ -34,18 +35,15 @@
         transformers:
           - null
         include: # test backward compatibility
-          - python: "3.9"
+          - python: "3.11"
             os: "ubuntu-latest"
             transformers: "4.49.0"
-          - python: "3.9"
+          - python: "3.11"
             os: "ubuntu-latest"
             transformers: "4.51.0"
-          - python: "3.9"
+          - python: "3.11"
             os: "ubuntu-latest"
             transformers: "4.53.0"
-        exclude: # exclude python 3.9 on macos
-          - python: "3.9"
-            os: "macos-latest"

     runs-on: ${{ matrix.os }}
@@ -61,28 +59,23 @@
       - name: Checkout
         uses: actions/checkout@v4

-      - name: Set up Python
-        uses: actions/setup-python@v5
+      - name: Install uv
+        uses: astral-sh/setup-uv@v7
         with:
           python-version: ${{ matrix.python }}
-          cache: "pip"
-          cache-dependency-path: "**/requirements*.txt"
+          github-token: ${{ github.token }}
+          enable-cache: false

       - name: Install dependencies
         run: |
-          python -m pip install --upgrade pip
-          python -m pip install ".[torch,dev]"
+          uv venv
+          uv pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
+          uv pip install -e ".[dev]"

       - name: Install transformers
         if: ${{ matrix.transformers }}
         run: |
-          python -m pip install "transformers==${{ matrix.transformers }}"
-
-      - name: Update accelerate to avoid mac os ci errors (before accelerate 1.11.0)
-        if: ${{ matrix.os == 'macos-latest' }}
-        run: |
-          python -m pip uninstall -y accelerate
-          python -m pip install "git+https://github.com/huggingface/accelerate.git"
+          uv pip install "transformers==${{ matrix.transformers }}"

       - name: Cache files
         id: hf-hub-cache
@@ -94,18 +87,25 @@
       - name: Check quality
         run: |
           make style && make quality
+        env:
+          UV_NO_SYNC: 1

       - name: Check license
         run: |
           make license
+        env:
+          UV_NO_SYNC: 1

       - name: Check build
         run: |
           make build
+        env:
+          UV_NO_SYNC: 1

       - name: Test with pytest
         run: |
           make test
         env:
+          UV_NO_SYNC: 1
           HF_HOME: ${{ runner.temp }}/huggingface
           HF_HUB_OFFLINE: "${{ steps.hf-hub-cache.outputs.cache-hit == 'true' && '1' || '0' }}"

99
.github/workflows/tests_npu.yml vendored Normal file
View File

@@ -0,0 +1,99 @@
name: tests_npu
on:
workflow_dispatch:
push:
branches:
- "main"
paths:
- "**/*.py"
- "pyproject.toml"
- "Makefile"
- ".github/workflows/*.yml"
pull_request:
branches:
- "main"
paths:
- "**/*.py"
- "pyproject.toml"
- "Makefile"
- ".github/workflows/*.yml"
jobs:
tests:
strategy:
fail-fast: false
matrix:
python:
- "3.11"
os:
- "linux-aarch64-a2-4"
pytorch_npu:
- "2.7.1"
runs-on: ${{ matrix.os }}
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ matrix.os }}-${{ matrix.python }}
cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
container:
image: ascendai/cann:8.3.rc2-910b-ubuntu22.04-py3.11
env:
HF_ENDPOINT: https://hf-mirror.com
HF_TOKEN: ${{ secrets.HF_TOKEN }}
OS_NAME: ${{ matrix.os }}
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Install uv
run: |
curl -LsSf https://astral.sh/uv/install.sh | sh
- name: Install dependencies
run: |
uv venv
uv pip install torch-npu==${{matrix.pytorch_npu}}
uv pip install -e ".[dev]"
- name: Install node
run: |
apt-get update || true
apt-get install -y curl
curl -fsSL https://deb.nodesource.com/setup_20.x | bash -
apt-get install -y nodejs
- name: Cache files
id: hf-hub-cache
uses: actions/cache@v4
with:
path: ${{ runner.temp }}/huggingface
key: huggingface-${{ matrix.os }}-${{ matrix.python }}-${{ hashFiles('tests/version.txt') }}
- name: Check quality
run: |
make style && make quality
env:
UV_NO_SYNC: 1
- name: Check license
run: |
make license
env:
UV_NO_SYNC: 1
- name: Check build
run: |
make build
env:
UV_NO_SYNC: 1
- name: Test with pytest
run: |
make test
env:
UV_NO_SYNC: 1
HF_HOME: /root/.cache/huggingface
HF_HUB_OFFLINE: "${{ steps.hf-hub-cache.outputs.cache-hit == 'true' && '1' || '0' }}"

5
.gitignore vendored
View File

@@ -85,7 +85,7 @@ ipython_config.py
 # pyenv
 # For a library or package, you might want to ignore these files since the code is
 # intended to run in multiple environments; otherwise, check them in:
-# .python-version
+.python-version

 # pipenv
 # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
@@ -165,6 +165,9 @@ cython_debug/
 # uv
 uv.lock

+# macOS
+.DS_Store
+
 # custom .gitignore
 hf_cache/
 ms_cache/

View File

@@ -1 +1 @@
-include LICENSE requirements.txt
+include LICENSE

View File

@@ -1,24 +1,28 @@
 .PHONY: build commit license quality style test

-check_dirs := scripts src tests tests_v1 setup.py
+check_dirs := scripts src tests tests_v1
+RUN := $(shell command -v uv >/dev/null 2>&1 && echo "uv run" || echo "")
+BUILD := $(shell command -v uv >/dev/null 2>&1 && echo "uv build" || echo "python -m build")
+TOOL := $(shell command -v uv >/dev/null 2>&1 && echo "uvx" || echo "")

 build:
-	pip3 install build && python3 -m build
+	$(BUILD)

 commit:
-	pre-commit install
-	pre-commit run --all-files
+	$(TOOL) pre-commit install
+	$(TOOL) pre-commit run --all-files

 license:
-	python3 tests/check_license.py $(check_dirs)
+	$(RUN) python3 tests/check_license.py $(check_dirs)

 quality:
-	ruff check $(check_dirs)
-	ruff format --check $(check_dirs)
+	$(TOOL) ruff check $(check_dirs)
+	$(TOOL) ruff format --check $(check_dirs)

 style:
-	ruff check $(check_dirs) --fix
-	ruff format $(check_dirs)
+	$(TOOL) ruff check $(check_dirs) --fix
+	$(TOOL) ruff format $(check_dirs)

 test:
-	CUDA_VISIBLE_DEVICES= WANDB_DISABLED=true pytest -vv tests/
+	WANDB_DISABLED=true $(RUN) pytest -vv --import-mode=importlib tests/ tests_v1/

View File

@@ -5,11 +5,13 @@
 [![GitHub contributors](https://img.shields.io/github/contributors/hiyouga/LLaMA-Factory?color=orange)](https://github.com/hiyouga/LLaMA-Factory/graphs/contributors)
 [![GitHub workflow](https://github.com/hiyouga/LLaMA-Factory/actions/workflows/tests.yml/badge.svg)](https://github.com/hiyouga/LLaMA-Factory/actions/workflows/tests.yml)
 [![PyPI](https://img.shields.io/pypi/v/llamafactory)](https://pypi.org/project/llamafactory/)
-[![Citation](https://img.shields.io/badge/citation-840-green)](https://scholar.google.com/scholar?cites=12620864006390196564)
+[![Citation](https://img.shields.io/badge/citation-1000+-green)](https://scholar.google.com/scholar?cites=12620864006390196564)
 [![Docker Pulls](https://img.shields.io/docker/pulls/hiyouga/llamafactory)](https://hub.docker.com/r/hiyouga/llamafactory/tags)
 [![Twitter](https://img.shields.io/twitter/follow/llamafactory_ai)](https://twitter.com/llamafactory_ai)
 [![Discord](assets/thirdparty/discord.svg)](https://discord.gg/rKfvV9r9FK)
+[![WeChat](https://img.shields.io/badge/WeChat-User%20Group-blue?logo=wechat)](https://github.com/hiyouga/llamafactory-community)
+[![Blog](https://img.shields.io/badge/Hugo-Official%20Blog-blue?logo=hugo)](https://blog.llamafactory.net/en/)
 [![Open in Colab](assets/thirdparty/colab.svg)](https://colab.research.google.com/drive/1eRTPn37ltBbYsISy9Aw2NuI2Aq5CQrD9?usp=sharing)
 [![Open in DSW](assets/thirdparty/dsw.svg)](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory)
@@ -44,16 +46,20 @@
 https://github.com/user-attachments/assets/3991a3a8-4276-4d30-9cab-4cb0c4b9b99e

-Choose your path:
+Start local training:
+
+- Please refer to [usage](#getting-started)
+
+Start cloud training:
+
+- **Colab (free)**: https://colab.research.google.com/drive/1eRTPn37ltBbYsISy9Aw2NuI2Aq5CQrD9?usp=sharing
+- **PAI-DSW (free trial)**: https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory
+- **LLaMA Factory Online**: https://www.llamafactory.com.cn/?utm_source=LLaMA-Factory
+- **Alaya NeW (cloud GPU deal)**: https://docs.alayanew.com/docs/documents/useGuide/LLaMAFactory/mutiple/?utm_source=LLaMA-Factory
+
+Read technical notes:

 - **Documentation (WIP)**: https://llamafactory.readthedocs.io/en/latest/
 - **Documentation (AMD GPU)**: https://rocm.docs.amd.com/projects/ai-developer-hub/en/latest/notebooks/fine_tune/llama_factory_llama3.html
-- **Colab (free)**: https://colab.research.google.com/drive/1eRTPn37ltBbYsISy9Aw2NuI2Aq5CQrD9?usp=sharing
-- **Local machine**: Please refer to [usage](#getting-started)
-- **PAI-DSW (free trial)**: https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory
-- **Alaya NeW (cloud GPU deal)**: https://docs.alayanew.com/docs/documents/useGuide/LLaMAFactory/mutiple/?utm_source=LLaMA-Factory
+- **Official Blog**: https://blog.llamafactory.net/en/
 - **Official Course**: https://www.lab4ai.cn/course/detail?id=7c13e60f6137474eb40f6fd3983c0f46&utm_source=LLaMA-Factory
-- **LLaMA Factory Online**: https://www.llamafactory.com.cn/?utm_source=LLaMA-Factory

 > [!NOTE]
 > Except for the above links, all other websites are unauthorized third-party websites. Please carefully use them.
@@ -90,7 +96,7 @@ Choose your path:
 - **Integrated methods**: (Continuous) pre-training, (multimodal) supervised fine-tuning, reward modeling, PPO, DPO, KTO, ORPO, etc.
 - **Scalable resources**: 16-bit full-tuning, freeze-tuning, LoRA and 2/3/4/5/6/8-bit QLoRA via AQLM/AWQ/GPTQ/LLM.int8/HQQ/EETQ.
 - **Advanced algorithms**: [GaLore](https://github.com/jiaweizzhao/GaLore), [BAdam](https://github.com/Ledzy/BAdam), [APOLLO](https://github.com/zhuhanqing/APOLLO), [Adam-mini](https://github.com/zyushun/Adam-mini), [Muon](https://github.com/KellerJordan/Muon), [OFT](https://github.com/huggingface/peft/tree/main/src/peft/tuners/oft), DoRA, LongLoRA, LLaMA Pro, Mixture-of-Depths, LoRA+, LoftQ and PiSSA.
-- **Practical tricks**: [FlashAttention-2](https://github.com/Dao-AILab/flash-attention), [Unsloth](https://github.com/unslothai/unsloth), [Liger Kernel](https://github.com/linkedin/Liger-Kernel), RoPE scaling, NEFTune and rsLoRA.
+- **Practical tricks**: [FlashAttention-2](https://github.com/Dao-AILab/flash-attention), [Unsloth](https://github.com/unslothai/unsloth), [Liger Kernel](https://github.com/linkedin/Liger-Kernel), [KTransformers](https://github.com/kvcache-ai/ktransformers/), RoPE scaling, NEFTune and rsLoRA.
 - **Wide tasks**: Multi-turn dialogue, tool using, image understanding, visual grounding, video recognition, audio understanding, etc.
 - **Experiment monitors**: LlamaBoard, TensorBoard, Wandb, MLflow, [SwanLab](https://github.com/SwanHubX/SwanLab), etc.
 - **Faster inference**: OpenAI-style API, Gradio UI and CLI with [vLLM worker](https://github.com/vllm-project/vllm) or [SGLang worker](https://github.com/sgl-project/sglang).
@@ -104,6 +110,12 @@ Choose your path:

 ## Blogs

+> [!TIP]
+> Now we have a dedicated blog for LLaMA Factory!
+>
+> Website: https://blog.llamafactory.net/en/
+
+- 💡 [KTransformers Fine-Tuning × LLaMA Factory: Fine-tuning 1000 Billion models with 2 4090-GPU + CPU](https://blog.llamafactory.net/en/posts/ktransformers/) (English)
 - 💡 [Easy Dataset × LLaMA Factory: Enabling LLMs to Efficiently Learn Domain Knowledge](https://buaa-act.feishu.cn/wiki/GVzlwYcRFiR8OLkHbL6cQpYin7g) (English)
 - [Fine-tune a mental health LLM using LLaMA-Factory](https://www.lab4ai.cn/project/detail?id=25cce32ec131497b9e06a93336a0817f&type=project&utm_source=LLaMA-Factory) (Chinese)
 - [Fine-tune GPT-OSS for Role-Playing using LLaMA-Factory](https://docs.llamafactory.com.cn/docs/documents/best-practice/gptroleplay/?utm_source=LLaMA-Factory) (Chinese)
@@ -123,6 +135,8 @@ Choose your path:

 ## Changelog

+[25/10/26] We support Megatron-core training backend with [**mcore_adapter**](https://github.com/alibaba/ROLL/tree/main/mcore_adapter). See [PR #9237](https://github.com/hiyouga/LLaMA-Factory/pull/9237) to get started.
+
 [25/08/22] We supported **[OFT](https://arxiv.org/abs/2306.07280)** and **[OFTv2](https://arxiv.org/abs/2506.19847)**. See [examples](examples/README.md) for usage.

 [25/08/20] We supported fine-tuning the **[Intern-S1-mini](https://huggingface.co/internlm/Intern-S1-mini)** models. See [PR #8976](https://github.com/hiyouga/LLaMA-Factory/pull/8976) to get started.
@@ -264,27 +278,21 @@ Choose your path:
 | Model | Model size | Template |
 | ----------------------------------------------------------------- | -------------------------------- | -------------------- |
-| [Baichuan 2](https://huggingface.co/baichuan-inc) | 7B/13B | baichuan2 |
 | [BLOOM/BLOOMZ](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - |
-| [ChatGLM3](https://huggingface.co/THUDM) | 6B | chatglm3 |
 | [Command R](https://huggingface.co/CohereForAI) | 35B/104B | cohere |
-| [DeepSeek (Code/MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B/236B | deepseek |
-| [DeepSeek 2.5/3](https://huggingface.co/deepseek-ai) | 236B/671B | deepseek3 |
+| [DeepSeek (LLM/Code/MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B/236B | deepseek |
+| [DeepSeek 3-3.2](https://huggingface.co/deepseek-ai) | 236B/671B | deepseek3 |
 | [DeepSeek R1 (Distill)](https://huggingface.co/deepseek-ai) | 1.5B/7B/8B/14B/32B/70B/671B | deepseekr1 |
 | [ERNIE-4.5](https://huggingface.co/baidu) | 0.3B/21B/300B | ernie/ernie_nothink |
-| [Falcon](https://huggingface.co/tiiuae) | 7B/11B/40B/180B | falcon |
-| [Falcon-H1](https://huggingface.co/tiiuae) | 0.5B/1.5B/3B/7B/34B | falcon_h1 |
+| [Falcon/Falcon H1](https://huggingface.co/tiiuae) | 0.5B/1.5B/3B/7B/11B/34B/40B/180B | falcon/falcon_h1 |
 | [Gemma/Gemma 2/CodeGemma](https://huggingface.co/google) | 2B/7B/9B/27B | gemma/gemma2 |
 | [Gemma 3/Gemma 3n](https://huggingface.co/google) | 270M/1B/4B/6B/8B/12B/27B | gemma3/gemma3n |
 | [GLM-4/GLM-4-0414/GLM-Z1](https://huggingface.co/zai-org) | 9B/32B | glm4/glmz1 |
-| [GLM-4.1V](https://huggingface.co/zai-org) | 9B | glm4v |
-| [GLM-4.5/GLM-4.5V](https://huggingface.co/zai-org) | 106B/355B | glm4_moe/glm4v_moe |
+| [GLM-4.5/GLM-4.5(6)V](https://huggingface.co/zai-org) | 9B/106B/355B | glm4_moe/glm4_5v |
 | [GPT-2](https://huggingface.co/openai-community) | 0.1B/0.4B/0.8B/1.5B | - |
-| [GPT-OSS](https://huggingface.co/openai) | 20B/120B | gpt |
+| [GPT-OSS](https://huggingface.co/openai) | 20B/120B | gpt_oss |
-| [Granite 3.0-3.3](https://huggingface.co/ibm-granite) | 1B/2B/3B/8B | granite3 |
-| [Granite 4](https://huggingface.co/ibm-granite) | 7B | granite4 |
+| [Granite 3-4](https://huggingface.co/ibm-granite) | 1B/2B/3B/7B/8B | granite3/granite4 |
-| [Hunyuan](https://huggingface.co/tencent/) | 7B | hunyuan |
+| [Hunyuan (MT)](https://huggingface.co/tencent/) | 7B | hunyuan |
-| [Index](https://huggingface.co/IndexTeam) | 1.9B | index |
 | [InternLM 2-3](https://huggingface.co/internlm) | 7B/8B/20B | intern2 |
 | [InternVL 2.5-3.5](https://huggingface.co/OpenGVLab) | 1B/2B/4B/8B/14B/30B/38B/78B/241B | intern_vl |
 | [InternLM/Intern-S1-mini](https://huggingface.co/internlm/) | 8B | intern_s1 |
@@ -298,15 +306,13 @@ Choose your path:
 | [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | llava |
 | [LLaVA-NeXT](https://huggingface.co/llava-hf) | 7B/8B/13B/34B/72B/110B | llava_next |
 | [LLaVA-NeXT-Video](https://huggingface.co/llava-hf) | 7B/34B | llava_next_video |
-| [MiMo](https://huggingface.co/XiaomiMiMo) | 7B | mimo |
+| [MiMo](https://huggingface.co/XiaomiMiMo) | 7B/309B | mimo/mimo_v2 |
 | [MiniCPM 1-4.1](https://huggingface.co/openbmb) | 0.5B/1B/2B/4B/8B | cpm/cpm3/cpm4 |
 | [MiniCPM-o-2.6/MiniCPM-V-2.6](https://huggingface.co/openbmb) | 8B | minicpm_o/minicpm_v |
-| [Ministral/Mistral-Nemo](https://huggingface.co/mistralai) | 8B/12B | ministral |
+| [Ministral 3](https://huggingface.co/mistralai) | 3B/8B/14B | ministral3 |
 | [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
-| [Mistral Small](https://huggingface.co/mistralai) | 24B | mistral_small |
 | [OLMo](https://huggingface.co/allenai) | 1B/7B | - |
 | [PaliGemma/PaliGemma2](https://huggingface.co/google) | 3B/10B/28B | paligemma |
-| [Phi-1.5/Phi-2](https://huggingface.co/microsoft) | 1.3B/2.7B | - |
 | [Phi-3/Phi-3.5](https://huggingface.co/microsoft) | 4B/14B | phi |
 | [Phi-3-small](https://huggingface.co/microsoft) | 7B | phi_small |
 | [Phi-4](https://huggingface.co/microsoft) | 14B | phi4 |
@@ -317,19 +323,18 @@ Choose your path:
 | [Qwen2.5-Omni](https://huggingface.co/Qwen) | 3B/7B | qwen2_omni |
 | [Qwen3-Omni](https://huggingface.co/Qwen) | 30B | qwen3_omni |
 | [Qwen2-VL/Qwen2.5-VL/QVQ](https://huggingface.co/Qwen) | 2B/3B/7B/32B/72B | qwen2_vl |
-| [Qwen3-VL](https://huggingface.co/Qwen) | 235B | qwen3_vl |
+| [Qwen3-VL](https://huggingface.co/Qwen) | 2B/4B/8B/30B/32B/235B | qwen3_vl |
 | [Seed (OSS/Coder)](https://huggingface.co/ByteDance-Seed) | 8B/36B | seed_oss/seed_coder |
-| [Skywork o1](https://huggingface.co/Skywork) | 8B | skywork_o1 |
 | [StarCoder 2](https://huggingface.co/bigcode) | 3B/7B/15B | - |
-| [TeleChat2](https://huggingface.co/Tele-AI) | 3B/7B/35B/115B | telechat2 |
-| [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | xverse |
+| [VibeThinker-1.5B](https://huggingface.co/WeiboAI) | 1.5B | qwen3 |
 | [Yi/Yi-1.5 (Code)](https://huggingface.co/01-ai) | 1.5B/6B/9B/34B | yi |
-| [Yi-VL](https://huggingface.co/01-ai) | 6B/34B | yi_vl |
 | [Yuan 2](https://huggingface.co/IEITYuan) | 2B/51B/102B | yuan |

 > [!NOTE]
 > For the "base" models, the `template` argument can be chosen from `default`, `alpaca`, `vicuna` etc. But make sure to use the **corresponding template** for the "instruct/chat" models.
 >
+> If the model has both reasoning and non-reasoning versions, please use the `_nothink` suffix to distinguish between them. For example, `qwen3` and `qwen3_nothink`.
+>
 > Remember to use the **SAME** template in training and inference.
 >
 > \*: You should install the `transformers` from main branch and use `DISABLE_VERSION_CHECK=1` to skip version check.
@@ -459,7 +464,7 @@ You also can add a custom chat template to [template.py](src/llamafactory/data/t
 Some datasets require confirmation before using them, so we recommend logging in with your Hugging Face account using these commands.

 ```bash
-pip install --upgrade huggingface_hub
+pip install "huggingface_hub<1.0.0"
 huggingface-cli login
 ```
@@ -509,10 +514,12 @@ huggingface-cli login
 ```bash
 git clone --depth 1 https://github.com/hiyouga/LLaMA-Factory.git
 cd LLaMA-Factory
-pip install -e ".[torch,metrics]" --no-build-isolation
+pip install -e ".[metrics]" --no-build-isolation
 ```

-Extra dependencies available: torch, torch-npu, metrics, deepspeed, liger-kernel, bitsandbytes, hqq, eetq, gptq, aqlm, vllm, sglang, galore, apollo, badam, adam-mini, qwen, minicpm_v, openmind, swanlab, dev
+Optional dependencies available: `metrics`, `deepspeed`. Install with: `pip install -e ".[metrics,deepspeed]"`
+
+Additional dependencies for specific features are available in `examples/requirements/`.

 #### Install from Docker Image
@@ -531,13 +538,7 @@ Please refer to [build docker](#build-docker) to build the image yourself.
 Create an isolated Python environment with [uv](https://github.com/astral-sh/uv):

 ```bash
-uv sync --extra torch --extra metrics --prerelease=allow
-```
-
-Run LLaMA-Factory in the isolated environment:
-
-```bash
-uv run --prerelease=allow llamafactory-cli train examples/train_lora/llama3_lora_pretrain.yaml
+uv run llamafactory-cli webui
 ```

 </details>
@@ -574,7 +575,7 @@ To enable FlashAttention-2 on the Windows platform, please use the script from [
 <details><summary>For Ascend NPU users</summary>

-To install LLaMA Factory on Ascend NPU devices, please upgrade Python to version 3.10 or higher and specify extra dependencies: `pip install -e ".[torch-npu,metrics]"`. Additionally, you need to install the **[Ascend CANN Toolkit and Kernels](https://www.hiascend.com/developer/download/community/result?module=cann)**. Please follow the [installation tutorial](https://www.hiascend.com/document/detail/en/CANNCommunityEdition/600alphaX/softwareinstall/instg/atlasdeploy_03_0031.html) or use the following commands:
+To install LLaMA Factory on Ascend NPU devices, please upgrade Python to version 3.10 or higher: `pip install -e . torch-npu==2.7.1`. Additionally, you need to install the **[Ascend CANN Toolkit and Kernels](https://www.hiascend.com/developer/download/community/result?module=cann)**. Please follow the [installation tutorial](https://www.hiascend.com/document/detail/en/CANNCommunityEdition/600alphaX/softwareinstall/instg/atlasdeploy_03_0031.html) or use the following commands:

 ```bash
 # replace the url according to your CANN version and devices
@@ -593,8 +594,8 @@ source /usr/local/Ascend/ascend-toolkit/set_env.sh
 | Requirement | Minimum | Recommend |
 | ------------ | ------- | -------------- |
 | CANN | 8.0.RC1 | 8.0.0.alpha002 |
-| torch | 2.1.0 | 2.4.0 |
-| torch-npu | 2.1.0 | 2.4.0.post2 |
+| torch | 2.1.0 | 2.7.1 |
+| torch-npu | 2.1.0 | 2.7.1 |
 | deepspeed | 0.13.2 | 0.13.2 |
 | vllm-ascend | - | 0.7.3 |
@@ -709,7 +710,6 @@ For CUDA users:
 ```bash
 docker build -f ./docker/docker-cuda/Dockerfile \
     --build-arg PIP_INDEX=https://pypi.org/simple \
-    --build-arg EXTRAS=metrics \
     -t llamafactory:latest .

 docker run -dit --ipc=host --gpus=all \
@@ -726,7 +726,6 @@ For Ascend NPU users:
 ```bash
 docker build -f ./docker/docker-npu/Dockerfile \
     --build-arg PIP_INDEX=https://pypi.org/simple \
-    --build-arg EXTRAS=torch-npu,metrics \
     -t llamafactory:latest .

 docker run -dit --ipc=host \
@@ -751,7 +750,6 @@ For AMD ROCm users:
 ```bash
 docker build -f ./docker/docker-rocm/Dockerfile \
     --build-arg PIP_INDEX=https://pypi.org/simple \
-    --build-arg EXTRAS=metrics \
     -t llamafactory:latest .

 docker run -dit --ipc=host \

View File

@@ -5,11 +5,13 @@
 [![GitHub contributors](https://img.shields.io/github/contributors/hiyouga/LLaMA-Factory?color=orange)](https://github.com/hiyouga/LLaMA-Factory/graphs/contributors)
 [![GitHub workflow](https://github.com/hiyouga/LLaMA-Factory/actions/workflows/tests.yml/badge.svg)](https://github.com/hiyouga/LLaMA-Factory/actions/workflows/tests.yml)
 [![PyPI](https://img.shields.io/pypi/v/llamafactory)](https://pypi.org/project/llamafactory/)
-[![Citation](https://img.shields.io/badge/citation-840-green)](https://scholar.google.com/scholar?cites=12620864006390196564)
+[![Citation](https://img.shields.io/badge/citation-1000+-green)](https://scholar.google.com/scholar?cites=12620864006390196564)
 [![Docker Pulls](https://img.shields.io/docker/pulls/hiyouga/llamafactory)](https://hub.docker.com/r/hiyouga/llamafactory/tags)
 [![Twitter](https://img.shields.io/twitter/follow/llamafactory_ai)](https://twitter.com/llamafactory_ai)
 [![Discord](assets/thirdparty/discord.svg)](https://discord.gg/rKfvV9r9FK)
+[![WeChat](https://img.shields.io/badge/WeChat-User%20Group-blue?logo=wechat)](https://github.com/hiyouga/llamafactory-community)
+[![Blog](https://img.shields.io/badge/Hugo-Official%20Blog-blue?logo=hugo)](https://blog.llamafactory.net/)
 [![Open in Colab](assets/thirdparty/colab.svg)](https://colab.research.google.com/drive/1d5KQtbemerlSDSxZIfAaWXhKr30QypiK?usp=sharing)
 [![Open in DSW](assets/thirdparty/dsw.svg)](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory)
@@ -44,18 +46,22 @@
 https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc

-选择你的打开方式:
+开始本地训练:
+
+- 请见[如何使用](#如何使用)
+
+开始云端训练:
+
+- **Colab(免费)**:https://colab.research.google.com/drive/1d5KQtbemerlSDSxZIfAaWXhKr30QypiK?usp=sharing
+- **PAI-DSW(免费试用)**:https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory
+- **LLaMA Factory Online(在线微调)**:https://www.llamafactory.com.cn/?utm_source=LLaMA-Factory
+- **九章智算云(算力优惠活动)**:https://docs.alayanew.com/docs/documents/useGuide/LLaMAFactory/mutiple/?utm_source=LLaMA-Factory
+
+阅读技术文档:

 - **入门教程**:https://zhuanlan.zhihu.com/p/695287607
 - **微调视频教程**:https://www.bilibili.com/video/BV1djgRzxEts/
 - **框架文档**:https://llamafactory.readthedocs.io/zh-cn/latest/
 - **框架文档(昇腾 NPU)**:https://ascend.github.io/docs/sources/llamafactory/
-- **Colab(免费)**:https://colab.research.google.com/drive/1d5KQtbemerlSDSxZIfAaWXhKr30QypiK?usp=sharing
-- **本地机器**:请见[如何使用](#如何使用)
-- **PAI-DSW(免费试用)**:https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory
-- **九章智算云(算力优惠活动)**:https://docs.alayanew.com/docs/documents/useGuide/LLaMAFactory/mutiple/?utm_source=LLaMA-Factory
+- **官方博客**:https://blog.llamafactory.net/
 - **官方课程**:https://www.lab4ai.cn/course/detail?id=7c13e60f6137474eb40f6fd3983c0f46&utm_source=LLaMA-Factory
-- **LLaMA Factory Online(在线微调)**:https://www.llamafactory.com.cn/?utm_source=LLaMA-Factory

 > [!NOTE]
 > 除上述链接以外的其他网站均为未经许可的第三方网站,请小心甄别。
@@ -92,7 +98,7 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
 - **集成方法**:(增量)预训练、(多模态)指令监督微调、奖励模型训练、PPO 训练、DPO 训练、KTO 训练、ORPO 训练等等。
 - **多种精度**:16 比特全参数微调、冻结微调、LoRA 微调和基于 AQLM/AWQ/GPTQ/LLM.int8/HQQ/EETQ 的 2/3/4/5/6/8 比特 QLoRA 微调。
 - **先进算法**:[GaLore](https://github.com/jiaweizzhao/GaLore)、[BAdam](https://github.com/Ledzy/BAdam)、[APOLLO](https://github.com/zhuhanqing/APOLLO)、[Adam-mini](https://github.com/zyushun/Adam-mini)、[Muon](https://github.com/KellerJordan/Muon)、[OFT](https://github.com/huggingface/peft/tree/main/src/peft/tuners/oft)、DoRA、LongLoRA、LLaMA Pro、Mixture-of-Depths、LoRA+、LoftQ 和 PiSSA。
-- **实用技巧**:[FlashAttention-2](https://github.com/Dao-AILab/flash-attention)、[Unsloth](https://github.com/unslothai/unsloth)、[Liger Kernel](https://github.com/linkedin/Liger-Kernel)、RoPE scaling、NEFTune 和 rsLoRA。
+- **实用技巧**:[FlashAttention-2](https://github.com/Dao-AILab/flash-attention)、[Unsloth](https://github.com/unslothai/unsloth)、[Liger Kernel](https://github.com/linkedin/Liger-Kernel)、[KTransformers](https://github.com/kvcache-ai/ktransformers/)、RoPE scaling、NEFTune 和 rsLoRA。
 - **广泛任务**:多轮对话、工具调用、图像理解、视觉定位、视频识别和语音理解等等。
 - **实验监控**:LlamaBoard、TensorBoard、Wandb、MLflow、[SwanLab](https://github.com/SwanHubX/SwanLab) 等等。
 - **极速推理**:基于 [vLLM](https://github.com/vllm-project/vllm) 或 [SGLang](https://github.com/sgl-project/sglang) 的 OpenAI 风格 API、浏览器界面和命令行接口。
@@ -106,6 +112,12 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
## Official Blog

> [!TIP]
> LLaMA Factory now has its own blog!
>
> Website: https://blog.llamafactory.net/

- 💡 [KTransformers Fine-Tuning × LLaMA Factory: fine-tune 1000B-scale models with two 4090-class GPUs plus CPU](https://swcil84qspu.feishu.cn/wiki/Z1sSwb2poijybxkyPEkcDG6enVc) (Chinese)
- 💡 [Easy Dataset × LLaMA Factory: let large models learn domain knowledge efficiently](https://buaa-act.feishu.cn/wiki/KY9xwTGs1iqHrRkjXBwcZP9WnL9) (Chinese)
- [Fine-tuning a mental-health LLM with LLaMA-Factory](https://www.lab4ai.cn/project/detail?id=25cce32ec131497b9e06a93336a0817f&type=project&utm_source=LLaMA-Factory) (Chinese)
- [Building a GPT-OSS role-playing model with LLaMA-Factory](https://docs.llamafactory.com.cn/docs/documents/best-practice/gptroleplay/?utm_source=LLaMA-Factory) (Chinese)
@@ -125,6 +137,8 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
## Changelog

[25/10/26] We supported **Megatron-core** as a training backend via [**mcore_adapter**](https://github.com/alibaba/ROLL/tree/main/mcore_adapter). See [PR #9237](https://github.com/hiyouga/LLaMA-Factory/pull/9237) for usage.

[25/08/22] We supported fine-tuning with **[OFT](https://arxiv.org/abs/2306.07280)** and **[OFTv2](https://arxiv.org/abs/2506.19847)**. See [examples](examples/README.md) for usage.

[25/08/20] We supported fine-tuning of the **[Intern-S1-mini](https://huggingface.co/internlm/Intern-S1-mini)** model. See [PR #8976](https://github.com/hiyouga/LLaMA-Factory/pull/8976) for usage.
@@ -266,27 +280,21 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
| Model | Model size | Template |
| ----------------------------------------------------------------- | -------------------------------- | -------------------- |
| [BLOOM/BLOOMZ](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - |
| [Command R](https://huggingface.co/CohereForAI) | 35B/104B | cohere |
| [DeepSeek (LLM/Code/MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B/236B | deepseek |
| [DeepSeek 3-3.2](https://huggingface.co/deepseek-ai) | 236B/671B | deepseek3 |
| [DeepSeek R1 (Distill)](https://huggingface.co/deepseek-ai) | 1.5B/7B/8B/14B/32B/70B/671B | deepseekr1 |
| [ERNIE-4.5](https://huggingface.co/baidu) | 0.3B/21B/300B | ernie/ernie_nothink |
| [Falcon/Falcon H1](https://huggingface.co/tiiuae) | 0.5B/1.5B/3B/7B/11B/34B/40B/180B | falcon/falcon_h1 |
| [Gemma/Gemma 2/CodeGemma](https://huggingface.co/google) | 2B/7B/9B/27B | gemma/gemma2 |
| [Gemma 3/Gemma 3n](https://huggingface.co/google) | 270M/1B/4B/6B/8B/12B/27B | gemma3/gemma3n |
| [GLM-4/GLM-4-0414/GLM-Z1](https://huggingface.co/zai-org) | 9B/32B | glm4/glmz1 |
| [GLM-4.5/GLM-4.5(6)V](https://huggingface.co/zai-org) | 9B/106B/355B | glm4_moe/glm4_5v |
| [GPT-2](https://huggingface.co/openai-community) | 0.1B/0.4B/0.8B/1.5B | - |
| [GPT-OSS](https://huggingface.co/openai) | 20B/120B | gpt_oss |
| [Granite 3-4](https://huggingface.co/ibm-granite) | 1B/2B/3B/7B/8B | granite3/granite4 |
| [Hunyuan (MT)](https://huggingface.co/tencent/) | 7B | hunyuan |
| [InternLM 2-3](https://huggingface.co/internlm) | 7B/8B/20B | intern2 |
| [InternVL 2.5-3.5](https://huggingface.co/OpenGVLab) | 1B/2B/4B/8B/14B/30B/38B/78B/241B | intern_vl |
| [InternLM/Intern-S1-mini](https://huggingface.co/internlm/) | 8B | intern_s1 |
@@ -300,15 +308,13 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | llava |
| [LLaVA-NeXT](https://huggingface.co/llava-hf) | 7B/8B/13B/34B/72B/110B | llava_next |
| [LLaVA-NeXT-Video](https://huggingface.co/llava-hf) | 7B/34B | llava_next_video |
| [MiMo](https://huggingface.co/XiaomiMiMo) | 7B/309B | mimo/mimo_v2 |
| [MiniCPM 1-4.1](https://huggingface.co/openbmb) | 0.5B/1B/2B/4B/8B | cpm/cpm3/cpm4 |
| [MiniCPM-o-2.6/MiniCPM-V-2.6](https://huggingface.co/openbmb) | 8B | minicpm_o/minicpm_v |
| [Ministral 3](https://huggingface.co/mistralai) | 3B/8B/14B | ministral3 |
| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
| [OLMo](https://huggingface.co/allenai) | 1B/7B | - |
| [PaliGemma/PaliGemma2](https://huggingface.co/google) | 3B/10B/28B | paligemma |
| [Phi-3/Phi-3.5](https://huggingface.co/microsoft) | 4B/14B | phi |
| [Phi-3-small](https://huggingface.co/microsoft) | 7B | phi_small |
| [Phi-4](https://huggingface.co/microsoft) | 14B | phi4 |
@@ -319,19 +325,18 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
| [Qwen2.5-Omni](https://huggingface.co/Qwen) | 3B/7B | qwen2_omni |
| [Qwen3-Omni](https://huggingface.co/Qwen) | 30B | qwen3_omni |
| [Qwen2-VL/Qwen2.5-VL/QVQ](https://huggingface.co/Qwen) | 2B/3B/7B/32B/72B | qwen2_vl |
| [Qwen3-VL](https://huggingface.co/Qwen) | 2B/4B/8B/30B/32B/235B | qwen3_vl |
| [Seed (OSS/Coder)](https://huggingface.co/ByteDance-Seed) | 8B/36B | seed_oss/seed_coder |
| [StarCoder 2](https://huggingface.co/bigcode) | 3B/7B/15B | - |
| [VibeThinker-1.5B](https://huggingface.co/WeiboAI) | 1.5B | qwen3 |
| [Yi/Yi-1.5 (Code)](https://huggingface.co/01-ai) | 1.5B/6B/9B/34B | yi |
| [Yuan 2](https://huggingface.co/IEITYuan) | 2B/51B/102B | yuan |
> [!NOTE]
> For all "base" models, the `template` argument can be `default`, `alpaca`, `vicuna`, etc. For "instruct/chat" models, make sure to use the **corresponding template**.
>
> If a model comes in both thinking and non-thinking variants, use the `_nothink` suffix to distinguish the templates, e.g. `qwen3` vs. `qwen3_nothink`.
>
> Always use **exactly the same** template for training and inference.
>
> \*: You need to install `transformers` from the main branch and set `DISABLE_VERSION_CHECK=1` to skip the version check.
@@ -511,10 +516,12 @@ huggingface-cli login
```bash
git clone --depth 1 https://github.com/hiyouga/LLaMA-Factory.git
cd LLaMA-Factory
pip install -e ".[metrics]" --no-build-isolation
```
Optional extras: `metrics` and `deepspeed`, installed with `pip install -e ".[metrics,deepspeed]"`.

For other optional dependencies, see the files under the `examples/requirements/` directory.
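As a quick reference, installation with extras plus one of the per-feature requirement files might look like this (a sketch; `sglang.txt` is a hypothetical file name, check `examples/requirements/` for the actual ones):

```bash
# core package with the metrics and deepspeed extras
pip install -e ".[metrics,deepspeed]" --no-build-isolation
# one optional dependency group from examples/requirements/ (hypothetical file name)
pip install -r examples/requirements/sglang.txt
```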
#### Install from Docker image
@@ -533,13 +540,7 @@ docker run -it --rm --gpus=all --ipc=host hiyouga/llamafactory:latest
Use [uv](https://github.com/astral-sh/uv) to create an isolated Python environment:
```bash
uv run llamafactory-cli webui
```
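To launch a training run inside the same environment, the pattern from the earlier revision of this section still applies, e.g. with the LoRA pre-training example config shipped in the repository:

```bash
uv run llamafactory-cli train examples/train_lora/llama3_lora_pretrain.yaml
```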
</details>
@@ -576,7 +577,7 @@ pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/downl
<details><summary>Ascend NPU user guide</summary>
To install LLaMA Factory on Ascend NPU devices, upgrade Python to 3.10 or higher and install the package together with a matching torch-npu build via `pip install -e . torch-npu==2.7.1`. You also need to install the **[Ascend CANN Toolkit and Kernels](https://www.hiascend.com/developer/download/community/result?module=cann)**; follow the [installation tutorial](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/80RC2alpha002/quickstart/quickstart/quickstart_18_0004.html) or use the following commands:
```bash
# Replace the URL with the one matching your CANN version and device model
@@ -595,8 +596,8 @@ source /usr/local/Ascend/ascend-toolkit/set_env.sh
| Requirement | Minimum | Recommended |
| ------------ | ------- | -------------- |
| CANN | 8.0.RC1 | 8.0.0.alpha002 |
| torch | 2.1.0 | 2.7.1 |
| torch-npu | 2.1.0 | 2.7.1 |
| deepspeed | 0.13.2 | 0.13.2 |
| vllm-ascend | - | 0.7.3 |
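After CANN and torch-npu are installed, a quick sanity check for NPU visibility is (a sketch; assumes the torch-npu build listed above):

```bash
python -c "import torch; import torch_npu; print(torch_npu.npu.is_available())"
```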
data/v1_dpo_demo.jsonl (new file; diff suppressed because one or more lines are too long)

data/v1_dpo_demo.yaml (new file)
@@ -0,0 +1,4 @@
dpo_zh_demo:
  path: HuggingFaceH4/orca_dpo_pairs
  split: train_prefs
  converter: pair
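For reference, a DPO training config could then point at this entry by name (a minimal excerpt, assuming the v1 demo datasets are referenced the same way as regular dataset entries; all other DPO arguments omitted):

```yaml
stage: dpo
do_train: true
dataset: dpo_zh_demo   # the entry registered above
```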
@@ -1,8 +1,9 @@
identity:
  path: data/identity.json
  source: local
  converter: alpaca
alpaca_en_demo:
  path: data/alpaca_en_demo.json
  source: local
  converter: alpaca
  size: 500
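Under the same scheme, a custom local dataset would be registered like this (a sketch; `my_dataset.json` is a placeholder for your own alpaca-format file):

```yaml
my_dataset:
  path: data/my_dataset.json
  source: local
  converter: alpaca
  size: 1000   # optional: limit to 1000 samples
```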
@@ -4,7 +4,6 @@ FROM ${BASE_IMAGE}
# Installation arguments
ARG PIP_INDEX=https://pypi.org/simple
ARG INSTALL_FLASHATTN=false
ARG HTTP_PROXY=""
@@ -27,17 +26,13 @@ WORKDIR /app
# Change pip source
RUN pip config set global.index-url "${PIP_INDEX}" && \
pip config set global.extra-index-url "${PIP_INDEX}" && \
pip install --no-cache-dir --upgrade pip packaging wheel setuptools editables "hatchling>=1.18.0"

# Copy the application into the image
COPY . /app

# Install LLaMA Factory
RUN pip install --no-cache-dir --no-build-isolation -e ".[metrics,deepspeed]"

# Rebuild flash attention
RUN if [ "${INSTALL_FLASHATTN}" == "true" ]; then \
@@ -0,0 +1,77 @@
# NVIDIA official image (ubuntu-22.04 + cuda-12.4 + python-3.10)
# https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-24-08.html
FROM nvcr.io/nvidia/pytorch:24.05-py3
ENV DEBIAN_FRONTEND=noninteractive
ENV PIP_ROOT_USER_ACTION=ignore
ENV PYPI_MIRROR=https://mirrors.aliyun.com/pypi/simple/
ENV PYPI_TRUSTED_HOST=mirrors.aliyun.com
ENV APT_MIRROR=https://mirrors.tuna.tsinghua.edu.cn/ubuntu/
RUN pip install --upgrade pip setuptools wheel "hatchling>=1.18.0" editables --trusted-host ${PYPI_TRUSTED_HOST} --index-url ${PYPI_MIRROR}
RUN pip uninstall -y torch torchvision torch-tensorrt \
flash_attn transformer-engine \
cudf dask-cuda cugraph cugraph-service-server cuml raft-dask cugraph-dgl cugraph-pyg dask-cudf
RUN pip install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 --index-url https://download.pytorch.org/whl/cu124
RUN pip uninstall -y opencv opencv-python opencv-python-headless && \
rm -rf /usr/local/lib/python3.10/dist-packages/cv2/ && \
pip install opencv-python-headless==4.11.0.86 --trusted-host ${PYPI_TRUSTED_HOST} --index-url ${PYPI_MIRROR}
RUN pip install "numpy==1.26.4" "optree>=0.13.0" "spacy==3.7.5" "weasel==0.4.1" \
transformer-engine[pytorch]==2.2.0 megatron-core==0.13.0 deepspeed==0.16.4 \
--trusted-host ${PYPI_TRUSTED_HOST} --index-url ${PYPI_MIRROR}
RUN pip install https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.2.post1/flash_attn-2.7.2.post1+cu12torch2.6cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
# RUN pip install vllm==0.8.4 \
# --trusted-host ${PYPI_TRUSTED_HOST} --index-url ${PYPI_MIRROR}
WORKDIR /build
ARG apex_url=git+https://github.com/NVIDIA/apex.git@25.04
RUN pip uninstall -y apex && \
MAX_JOBS=32 NINJA_FLAGS="-j32" NVCC_APPEND_FLAGS="--threads 32" \
pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation \
--config-settings "--build-option=--cpp_ext --cuda_ext --parallel 32" ${apex_url}
RUN rm -rf /build
WORKDIR /workspace
RUN cp /etc/apt/sources.list /etc/apt/sources.list.bak && \
{ \
echo "deb ${APT_MIRROR} jammy main restricted universe multiverse"; \
echo "deb ${APT_MIRROR} jammy-security main restricted universe multiverse"; \
echo "deb ${APT_MIRROR} jammy-updates main restricted universe multiverse"; \
echo "deb ${APT_MIRROR} jammy-backports main restricted universe multiverse"; \
} > /etc/apt/sources.list
RUN apt-get update && apt-get install -y zip
RUN apt-get install -y openjdk-21-jdk
ENV JAVA_HOME /usr/lib/jvm/java-21-openjdk-amd64
# pip install LLaMA-Factory
WORKDIR /app
# Copy the application into the image
COPY . /app
# Install LLaMA Factory
RUN pip install --no-cache-dir -e ".[metrics]" --no-build-isolation
RUN pip install "git+https://github.com/alibaba/roll.git#subdirectory=mcore_adapter"
# Expose port 7860 for LLaMA Board
ENV GRADIO_SERVER_PORT=7860
EXPOSE 7860
# Expose port 8000 for API service
ENV API_PORT=8000
EXPOSE 8000
# unset proxy
ENV http_proxy=
ENV https_proxy=
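To build and try the image defined above (a sketch; the Dockerfile path and image tag are placeholders, adjust them to where this file lives in the repository):

```bash
docker build -f docker/docker-cuda-megatron/Dockerfile -t llamafactory:megatron .
docker run -it --rm --gpus=all --ipc=host -p 7860:7860 -p 8000:8000 llamafactory:megatron
```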
@@ -5,7 +5,6 @@ services:
context: ../..
args:
PIP_INDEX: https://pypi.org/simple
container_name: llamafactory
ports:
- "7860:7860"
@@ -1,10 +1,10 @@
# https://hub.docker.com/r/ascendai/cann/tags
ARG BASE_IMAGE=quay.io/ascend/cann:8.3.rc2-910b-ubuntu22.04-py3.11
FROM ${BASE_IMAGE}

# Installation arguments
ARG PIP_INDEX=https://pypi.org/simple
ARG HTTP_PROXY=""
ARG PYTORCH_INDEX=https://download.pytorch.org/whl/cpu
@@ -27,21 +27,15 @@ WORKDIR /app
# Change pip source
RUN pip config set global.index-url "${PIP_INDEX}" && \
pip config set global.extra-index-url "${PIP_INDEX}" && \
pip install --no-cache-dir --upgrade pip packaging wheel setuptools editables "hatchling>=1.18.0"

# Copy the application into the image
COPY . /app

# Install torch-npu
RUN pip uninstall -y torch torchvision torchaudio && \
pip install --no-cache-dir "torch==2.7.1" "torch-npu==2.7.1" "torchvision==0.22.1" "torchaudio==2.7.1" --index-url "${PYTORCH_INDEX}" && \
pip install --no-cache-dir -e ".[metrics]" --no-build-isolation

# Set up volumes
# VOLUME [ "/root/.cache/huggingface", "/app/shared_data", "/app/output" ]
@@ -1,12 +1,12 @@
services:
llamafactory-a2:
build:
dockerfile: ./docker/docker-npu/Dockerfile
context: ../..
args:
PIP_INDEX: https://pypi.org/simple
container_name: llamafactory-a2
image: llamafactory:npu-a2
volumes:
- /usr/local/dcmi:/usr/local/dcmi
- /usr/local/bin/npu-smi:/usr/local/bin/npu-smi
@@ -26,3 +26,33 @@ services:
- /dev/devmm_svm
- /dev/hisi_hdc
restart: unless-stopped
llamafactory-a3:
profiles: ["a3"]
build:
dockerfile: ./docker/docker-npu/Dockerfile
context: ../..
args:
BASE_IMAGE: quay.io/ascend/cann:8.3.rc2-a3-ubuntu22.04-py3.11
PIP_INDEX: https://pypi.org/simple
container_name: llamafactory-a3
image: llamafactory:npu-a3
volumes:
- /usr/local/dcmi:/usr/local/dcmi
- /usr/local/bin/npu-smi:/usr/local/bin/npu-smi
- /usr/local/Ascend/driver:/usr/local/Ascend/driver
- /etc/ascend_install.info:/etc/ascend_install.info
ports:
- "7861:7860"
- "8001:8000"
ipc: host
tty: true
# shm_size: "16gb" # ipc: host is set
stdin_open: true
command: bash
devices:
- /dev/davinci0
- /dev/davinci_manager
- /dev/devmm_svm
- /dev/hisi_hdc
restart: unless-stopped
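With the profile-based compose file above, the A2 service starts by default and the A3 variant is opt-in, e.g. (standard docker compose usage, paths assumed relative to the repository root):

```bash
cd docker/docker-npu
docker compose up -d               # builds and starts llamafactory-a2
docker compose --profile a3 up -d  # additionally starts llamafactory-a3
```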
@@ -4,7 +4,6 @@ FROM ${BASE_IMAGE}
# Installation arguments
ARG PIP_INDEX=https://pypi.org/simple
ARG INSTALL_FLASHATTN=false
ARG HTTP_PROXY=""
ARG PYTORCH_INDEX=https://download.pytorch.org/whl/rocm6.3
@@ -28,21 +27,14 @@ WORKDIR /app
# Change pip source
RUN pip config set global.index-url "${PIP_INDEX}" && \
pip config set global.extra-index-url "${PIP_INDEX}" && \
pip install --no-cache-dir --upgrade pip packaging wheel setuptools editables "hatchling>=1.18.0"

# Copy the application into the image
COPY . /app

# Reinstall pytorch rocm and install LLaMA Factory
RUN pip uninstall -y torch torchvision torchaudio && \
pip install --no-cache-dir --no-build-isolation --pre -e ".[metrics,deepspeed]" --index-url "${PYTORCH_INDEX}"

# Rebuild flash attention
RUN if [ "${INSTALL_FLASHATTN}" == "true" ]; then \
@@ -5,7 +5,6 @@ services:
context: ../..
args:
PIP_INDEX: https://pypi.org/simple
container_name: llamafactory
ports:
- "7860:7860"
@@ -0,0 +1,22 @@
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
fsdp_config:
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_cpu_ram_efficient_loading: true
  fsdp_offload_params: false
  fsdp_reshard_after_forward: true
  fsdp_state_dict_type: FULL_STATE_DICT
  fsdp_version: 2
machine_rank: 0
main_training_function: main
mixed_precision: bf16 # or fp16
num_machines: 1 # the number of nodes
num_processes: 2 # the number of GPUs in all nodes
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
@@ -0,0 +1,34 @@
# If you want to run this example on multiple nodes, you need to set the following parameters:
# - num_machines: the number of nodes
# - num_processes: the number of GPUs in all nodes, num_machines * num_processes_per_machine
# - main_process_ip: the IP address of the main process, please keep it the same across all nodes
# - main_process_port: the port of all nodes, please keep it the same across all nodes
# - machine_rank: the rank of the current machine, starting from 0, and it should be 0 for main_process_ip
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
fsdp_config:
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_backward_prefetch: BACKWARD_PRE
  fsdp_forward_prefetch: false
  fsdp_cpu_ram_efficient_loading: true
  fsdp_offload_params: false
  fsdp_sharding_strategy: FULL_SHARD
  fsdp_state_dict_type: FULL_STATE_DICT
  fsdp_sync_module_states: true
  fsdp_use_orig_params: true
machine_rank: 0
main_training_function: main
mixed_precision: bf16 # or fp16
main_process_ip: 192.168.0.1
main_process_port: 29500
num_machines: 2 # the number of nodes
num_processes: 16 # the number of GPUs in all nodes, num_machines * num_processes_per_machine
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
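Following the comments above, a two-node launch could look like this (a sketch; it assumes this file is the `examples/accelerate/fsdp_config.yaml` referenced by the Ascend examples below, and that CLI flags override the values in the config):

```bash
# node 0 (main process)
accelerate launch --config_file examples/accelerate/fsdp_config.yaml \
  --machine_rank 0 --main_process_ip 192.168.0.1 --main_process_port 29500 \
  src/train.py examples/ascend/qwen3moe_full_sft_fsdp.yaml

# node 1
accelerate launch --config_file examples/accelerate/fsdp_config.yaml \
  --machine_rank 1 --main_process_ip 192.168.0.1 --main_process_port 29500 \
  src/train.py examples/ascend/qwen3moe_full_sft_fsdp.yaml
```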
@@ -0,0 +1,45 @@
# Start FSDP2 fine-tuning
# accelerate launch \
# --config_file examples/accelerate/fsdp2_config.yaml \
# src/train.py examples/ascend/qwen3_full_sft_fsdp2.yaml
# Change `num_processes` in fsdp2_config.yaml to 16 in A3
### model
model_name_or_path: Qwen/Qwen3-8B
trust_remote_code: true
use_v1_kernels: true
flash_attn: fa2
### method
stage: sft
do_train: true
finetuning_type: full
### dataset
dataset: alpaca_en_demo
template: qwen3
cutoff_len: 2048
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16
dataloader_num_workers: 4
### output
output_dir: saves/Qwen3-8B/full/sft
logging_steps: 1
save_steps: 500
max_steps: 500
plot_loss: true
overwrite_output_dir: true
save_only_model: false
report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
### train
per_device_train_batch_size: 8
gradient_accumulation_steps: 1
learning_rate: 1.0e-5
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
ddp_timeout: 1800
resume_from_checkpoint: null
@@ -0,0 +1,46 @@
# Start FSDP fine-tuning
# accelerate launch \
# --config_file examples/accelerate/fsdp_config.yaml \
# src/train.py examples/ascend/qwen3moe_full_sft_fsdp.yaml
# Change `num_processes` in fsdp_config.yaml to 16 in A3
### model
model_name_or_path: Qwen/Qwen3-30B-A3B-Instruct-2507
trust_remote_code: true
use_v1_kernels: true
flash_attn: fa2
### method
stage: sft
do_train: true
finetuning_type: full
disable_gradient_checkpointing: false
### dataset
dataset: alpaca_zh
template: qwen3
cutoff_len: 1024
overwrite_cache: true
preprocessing_num_workers: 16
dataloader_num_workers: 4
### output
output_dir: saves/Qwen3-30B-A3B-Instruct-2507/full/sft
logging_steps: 1
save_steps: 500
max_steps: 500
plot_loss: true
overwrite_output_dir: true
save_only_model: true
report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
### train
per_device_train_batch_size: 4
gradient_accumulation_steps: 1
learning_rate: 1.0e-4
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
ddp_timeout: 180000000
resume_from_checkpoint: null
seed: 1234
@@ -0,0 +1,48 @@
# Start FSDP2 fine-tuning
# accelerate launch \
# --config_file examples/accelerate/fsdp2_config.yaml \
# src/train.py examples/ascend/qwen3vlmoe_full_sft_fsdp2.yaml
# Change `num_processes` in fsdp2_config.yaml to 16 in A3
### model
model_name_or_path: Qwen/Qwen3-VL-30B-A3B-Instruct
image_max_pixels: 262144
video_max_pixels: 16384
trust_remote_code: true
use_v1_kernels: true
flash_attn: fa2
### method
stage: sft
do_train: true
finetuning_type: full
disable_gradient_checkpointing: false
### dataset
dataset: llava_1k_en, llava_1k_zh
template: qwen3_vl
cutoff_len: 1024
overwrite_cache: true
preprocessing_num_workers: 16
dataloader_num_workers: 4
### output
output_dir: saves/Qwen3-VL-30B-A3B-Instruct/full/sft
logging_steps: 1
save_steps: 500
max_steps: 500
plot_loss: true
overwrite_output_dir: true
save_only_model: true
report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
### train
per_device_train_batch_size: 2
gradient_accumulation_steps: 1
learning_rate: 1.0e-4
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
ddp_timeout: 180000000
resume_from_checkpoint: null
seed: 1234
@@ -0,0 +1,42 @@
### model
model_name_or_path: Qwen/Qwen3-VL-30B-A3B-Instruct
image_max_pixels: 262144
video_max_pixels: 16384
trust_remote_code: true
use_v1_kernels: true # replaced kernels: [NpuRMSNormKernel, NpuRoPEKernel, NpuQwen3VLMoEFusedMoEKernel]
### method
stage: sft
do_train: true
finetuning_type: lora
lora_rank: 8
lora_target: all
disable_gradient_checkpointing: false
flash_attn: disabled
### dataset
dataset: alpaca_zh_demo, alpaca_en_demo
template: qwen3_vl
cutoff_len: 1024
overwrite_cache: true
preprocessing_num_workers: 16
dataloader_num_workers: 4
### output
output_dir: saves/qwen3vlmoe/lora/sft
logging_steps: 1
plot_loss: true
overwrite_output_dir: true
save_only_model: true
report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
### train
per_device_train_batch_size: 8
gradient_accumulation_steps: 1
learning_rate: 1.0e-4
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
ddp_timeout: 180000000
resume_from_checkpoint: null
seed: 1234
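Unlike the FSDP examples above, this config carries no launch command in its header; a run could be started with the regular CLI (a sketch; the YAML path is a placeholder for wherever this file is saved):

```bash
llamafactory-cli train examples/ascend/qwen3vlmoe_lora_sft.yaml
```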
@@ -0,0 +1,32 @@
{
"_comment": "suooprted model list: https://www.deepspeed.ai/tutorials/automatic-tensor-parallelism/#supported-models",
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"zero_allow_untested_optimizer": true,
"fp16": {
"enabled": "auto",
"loss_scale": 0,
"loss_scale_window": 1000,
"initial_scale_power": 16,
"hysteresis": 2,
"min_loss_scale": 1
},
"bf16": {
"enabled": "auto"
},
"zero_optimization": {
"stage": 2,
"allgather_partitions": true,
"allgather_bucket_size": 5e8,
"overlap_comm": false,
"reduce_scatter": true,
"reduce_bucket_size": 5e8,
"contiguous_gradients": true,
"round_robin_gradients": true
},
"tensor_parallel": {
"autotp_size": 2
}
}
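To use this AutoTP config, a training YAML points at the JSON through the `deepspeed` argument (an excerpt; the JSON path is a placeholder for wherever this file is saved):

```yaml
### excerpt from a training config
deepspeed: examples/deepspeed/ds_z2_autotp_config.json  # placeholder path to the file above
per_device_train_batch_size: 1
gradient_accumulation_steps: 8
```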
@@ -0,0 +1,10 @@
model_name_or_path: deepseek-ai/DeepSeek-V2-Lite
adapter_name_or_path: saves/Kllama_deepseekV2
template: deepseek
infer_backend: ktransformers # choices: [huggingface, vllm, sglang, ktransformers]
trust_remote_code: true
use_kt: true # use KTransformers as the backend for LoRA SFT inference
kt_optimize_rule: examples/kt_optimize_rules/DeepSeek-V2-Lite-Chat-sft-amx.yaml
cpu_infer: 32
chunk_size: 8192
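A config like this is consumed by the usual chat entry point, e.g. (a sketch; the YAML path is a placeholder for wherever this file is saved):

```bash
llamafactory-cli chat examples/inference/deepseek2_lora_sft_kt.yaml
```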
@@ -0,0 +1,9 @@
model_name_or_path: opensourcerelease/DeepSeek-V3-bf16
template: deepseek
infer_backend: ktransformers # choices: [huggingface, vllm, sglang, ktransformers]
trust_remote_code: true
use_kt: true # use KTransformers as the backend for LoRA SFT inference
kt_optimize_rule: examples/kt_optimize_rules/DeepSeek-V3-Chat-sft-amx-multi-gpu.yaml
cpu_infer: 32
chunk_size: 8192
@@ -0,0 +1,10 @@
model_name_or_path: opensourcerelease/DeepSeek-V3-bf16
adapter_name_or_path: saves/Kllama_deepseekV3
template: deepseek
infer_backend: ktransformers # choices: [huggingface, vllm, sglang, ktransformers]
trust_remote_code: true
use_kt: true # use KTransformers as the backend for LoRA SFT inference
kt_optimize_rule: examples/kt_optimize_rules/DeepSeek-V3-Chat-sft-amx-multi-gpu.yaml
cpu_infer: 32
chunk_size: 8192
@@ -1,4 +1,4 @@
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
template: llama3
infer_backend: huggingface # choices: [huggingface, vllm, sglang, ktransformers]
trust_remote_code: true
@@ -1,4 +1,4 @@
model_name_or_path: saves/llama3-8b/full/sft
template: llama3
infer_backend: huggingface # choices: [huggingface, vllm, sglang, ktransformers]
trust_remote_code: true
@@ -1,5 +1,5 @@
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
adapter_name_or_path: saves/llama3-8b/lora/sft
template: llama3
infer_backend: huggingface # choices: [huggingface, vllm, sglang, ktransformers]
trust_remote_code: true
@@ -1,4 +1,4 @@
model_name_or_path: Qwen/Qwen2.5-VL-7B-Instruct
template: qwen2_vl
infer_backend: huggingface # choices: [huggingface, vllm, sglang, ktransformers]
trust_remote_code: true
@@ -0,0 +1,10 @@
model_name_or_path: Qwen/Qwen3-235B-A22B-Instruct-2507
adapter_name_or_path: saves/Kllama_Qwen3MoE_235bA22b
template: qwen3_nothink
infer_backend: ktransformers # choices: [huggingface, vllm, sglang, ktransformers]
trust_remote_code: true
use_kt: true # use KTransformers as the backend for LoRA SFT inference
kt_optimize_rule: examples/kt_optimize_rules/Qwen3Moe-sft-amx.yaml
cpu_infer: 32
chunk_size: 8192
@@ -0,0 +1,69 @@
- match:
class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding
replace:
class: ktransformers.operators.RoPE.YarnRotaryEmbedding
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- match:
name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression
class: torch.nn.Linear # only match modules matching name and class simultaneously
replace:
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
generate_op: "KLinearTorch"
prefill_op: "KLinearTorch"
- match:
name: "^lm_head"
class: torch.nn.Linear
replace:
class: ktransformers.operators.linear.KTransformersLinear
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
generate_op: "KLinearTorch"
prefill_op: "KLinearTorch"
- match:
name: "^model\\.layers\\..*\\.mlp$"
class: ktransformers.models.modeling_deepseek.DeepseekV2MoE
replace:
class: ktransformers.operators.experts.KDeepseekV2MoE # mlp module with custom forward function
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- match:
name: "^model\\.layers\\..*\\.mlp\\.experts$"
replace:
class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert parallelism
kwargs:
prefill_device: "cuda"
prefill_op: "KExpertsTorch"
generate_device: "cpu"
generate_op: "KSFTExpertsCPU"
out_device: "cuda"
backend: "AMXInt8" # or "AMXBF16" or "llamafile" (default)
recursive: False # don't recursively inject submodules of this module
- match:
name: "^model\\.layers\\..*\\.self_attn$"
replace:
class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- match:
name: "^model$"
replace:
class: "ktransformers.operators.models.KDeepseekV2Model"
kwargs:
per_layer_prefill_intput_threshold: 0 # 0 disables layer-wise prefill
- match:
name: "^model.embed_tokens"
replace:
class: "default"
kwargs:
generate_device: "cpu"
prefill_device: "cpu"
@@ -0,0 +1,68 @@
- match:
class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding
replace:
class: ktransformers.operators.RoPE.YarnRotaryEmbedding
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- match:
name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression
class: torch.nn.Linear # only match modules matching name and class simultaneously
replace:
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
generate_op: "KLinearMarlin"
prefill_op: "KLinearTorch"
- match:
name: "^lm_head"
class: torch.nn.Linear
replace:
class: ktransformers.operators.linear.KTransformersLinear
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
generate_op: "KLinearMarlin"
prefill_op: "KLinearTorch"
- match:
name: "^model\\.layers\\..*\\.mlp$"
class: ktransformers.models.modeling_deepseek.DeepseekV2MoE
replace:
class: ktransformers.operators.experts.KDeepseekV2MoE # mlp module with custom forward function
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- match:
name: "^model\\.layers\\..*\\.mlp\\.experts$"
replace:
class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert parallelism
kwargs:
prefill_device: "cuda"
prefill_op: "KExpertsTorch"
generate_device: "cpu"
generate_op: "KExpertsCPU"
out_device: "cuda"
recursive: False # don't recursively inject submodules of this module
- match:
name: "^model\\.layers\\..*\\.self_attn$"
replace:
class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- match:
name: "^model$"
replace:
class: "ktransformers.operators.models.KDeepseekV2Model"
kwargs:
per_layer_prefill_intput_threshold: 0 # 0 disables layer-wise prefill
- match:
name: "^model.embed_tokens"
replace:
class: "default"
kwargs:
generate_device: "cpu"
prefill_device: "cpu"
@@ -0,0 +1,139 @@
- match:
name: "^model.embed_tokens"
replace:
class: "default"
kwargs:
generate_device: "cpu"
prefill_device: "cpu"
- match:
name: "^model\\.layers\\.(0|[1-9])\\."
class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding
replace:
class: ktransformers.operators.RoPE.YarnRotaryEmbedding
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
- match:
name: "^model\\.layers\\.([12][0-9])\\."
class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding
replace:
class: ktransformers.operators.RoPE.YarnRotaryEmbedding
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
- match:
name: "^model\\.layers\\.(0|[1-9])\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression
class: torch.nn.Linear # only match modules matching name and class simultaneously
replace:
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
generate_op: "KLinearTorch"
prefill_op: "KLinearTorch"
- match:
name: "^model\\.layers\\.([12][0-9])\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression
class: torch.nn.Linear # only match modules matching name and class simultaneously
replace:
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
generate_op: "KLinearTorch"
prefill_op: "KLinearTorch"
- match:
name: "^model\\.layers\\.(0|[1-9])\\.mlp$"
class: ktransformers.models.modeling_deepseek.DeepseekV2MoE
replace:
class: ktransformers.operators.experts.KDeepseekV2MoE # mlp module with custom forward function
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
- match:
name: "^model\\.layers\\.([12][0-9])\\.mlp$"
class: ktransformers.models.modeling_deepseek.DeepseekV2MoE
replace:
class: ktransformers.operators.experts.KDeepseekV2MoE # mlp module with custom forward function
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
- match:
name: "^model\\.layers\\.(0|[1-9])\\.mlp\\.experts$"
replace:
class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert parallelism
kwargs:
prefill_device: "cuda:0"
prefill_op: "KExpertsTorch"
generate_device: "cpu"
generate_op: "KSFTExpertsCPU"
out_device: "cuda:0"
backend: "AMXInt8" # or "AMXBF16" or "llamafile" (default)
recursive: False # don't recursively inject submodules of this module
- match:
name: "^model\\.layers\\.([12][0-9])\\.mlp\\.experts$"
replace:
class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert parallelism
kwargs:
prefill_device: "cuda:1"
prefill_op: "KExpertsTorch"
generate_device: "cpu"
generate_op: "KSFTExpertsCPU"
out_device: "cuda:1"
backend: "AMXInt8" # or "AMXBF16" or "llamafile" (default)
recursive: False # don't recursively inject submodules of this module
- match:
name: "^model\\.layers\\.(0|[1-9])\\.self_attn$"
replace:
class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
- match:
name: "^model\\.layers\\.([12][0-9])\\.self_attn$"
replace:
class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
- match:
name: "^model$"
replace:
class: "ktransformers.operators.models.KDeepseekV2Model"
kwargs:
per_layer_prefill_intput_threshold: 0 # 0 disables layer-wise prefill
transfer_map:
10: "cuda:1"
- match:
name: "^model\\.layers\\.(0|[1-9])\\."
replace:
class: "default"
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
- match:
name: "^lm_head"
class: torch.nn.Linear
replace:
class: ktransformers.operators.linear.KTransformersLinear
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
generate_op: "KLinearTorch"
prefill_op: "KLinearTorch"
- match:
name: "(^model\\.layers\\.([12][0-9])\\.)|(model.norm)"
replace:
class: "default"
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
@@ -0,0 +1,69 @@
- match:
class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding
replace:
class: ktransformers.operators.RoPE.YarnRotaryEmbedding
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- match:
name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression
class: torch.nn.Linear # only match modules matching name and class simultaneously
replace:
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
generate_op: "KLinearTorch"
prefill_op: "KLinearTorch"
- match:
name: "^lm_head"
class: torch.nn.Linear
replace:
class: ktransformers.operators.linear.KTransformersLinear
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
generate_op: "KLinearTorch"
prefill_op: "KLinearTorch"
- match:
name: "^model\\.layers\\..*\\.mlp$"
class: ktransformers.models.modeling_deepseek.DeepseekV2MoE
replace:
class: ktransformers.operators.experts.KDeepseekV2MoE # mlp module with custom forward function
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- match:
name: "^model\\.layers\\..*\\.mlp\\.experts$"
replace:
class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert parallelism
kwargs:
prefill_device: "cpu"
prefill_op: "KExpertsTorch"
generate_device: "cpu"
generate_op: "KSFTExpertsCPU"
out_device: "cuda"
backend: "AMXInt8" # or "AMXBF16" or "llamafile" (default)
recursive: False # don't recursively inject submodules of this module
- match:
name: "^model\\.layers\\..*\\.self_attn$"
replace:
class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- match:
name: "^model$"
replace:
class: "ktransformers.operators.models.KDeepseekV2Model"
kwargs:
per_layer_prefill_intput_threshold: 0 # 0 disables layer-wise prefill
- match:
name: "^model.embed_tokens"
replace:
class: "default"
kwargs:
generate_device: "cpu"
prefill_device: "cpu"
@@ -0,0 +1,68 @@
- match:
class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding
replace:
class: ktransformers.operators.RoPE.YarnRotaryEmbedding
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- match:
name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression
class: torch.nn.Linear # only match modules matching name and class simultaneously
replace:
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
generate_op: "KLinearTorch"
prefill_op: "KLinearTorch"
- match:
name: "^lm_head"
class: torch.nn.Linear
replace:
class: ktransformers.operators.linear.KTransformersLinear
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
generate_op: "KLinearTorch"
prefill_op: "KLinearTorch"
- match:
name: "^model\\.layers\\..*\\.mlp$"
class: ktransformers.models.modeling_deepseek.DeepseekV2MoE
replace:
class: ktransformers.operators.experts.KDeepseekV2MoE # mlp module with custom forward function
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- match:
name: "^model\\.layers\\..*\\.mlp\\.experts$"
replace:
class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert parallelism
kwargs:
prefill_device: "cpu"
prefill_op: "KExpertsTorch"
generate_device: "cpu"
generate_op: "KSFTExpertsCPU"
out_device: "cuda"
recursive: False # don't recursively inject submodules of this module
- match:
name: "^model\\.layers\\..*\\.self_attn$"
replace:
class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- match:
name: "^model$"
replace:
class: "ktransformers.operators.models.KDeepseekV2Model"
kwargs:
per_layer_prefill_intput_threshold: 0 # 0 disables layer-wise prefill
- match:
name: "^model.embed_tokens"
replace:
class: "default"
kwargs:
generate_device: "cpu"
prefill_device: "cpu"
@@ -0,0 +1,68 @@
- match:
class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding
replace:
class: ktransformers.operators.RoPE.YarnRotaryEmbedding
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- match:
name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression
class: torch.nn.Linear # only match modules matching name and class simultaneously
replace:
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
generate_op: "KLinearMarlin"
prefill_op: "KLinearTorch"
- match:
name: "^lm_head"
class: torch.nn.Linear
replace:
class: ktransformers.operators.linear.KTransformersLinear
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
generate_op: "KLinearMarlin"
prefill_op: "KLinearTorch"
- match:
name: "^model\\.layers\\..*\\.mlp$"
class: ktransformers.models.modeling_deepseek.DeepseekV2MoE
replace:
class: ktransformers.operators.experts.KDeepseekV2MoE # mlp module with custom forward function
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- match:
name: "^model\\.layers\\..*\\.mlp\\.experts$"
replace:
class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert parallelism
kwargs:
prefill_device: "cuda"
prefill_op: "KExpertsTorch"
generate_device: "cpu"
generate_op: "KExpertsCPU"
out_device: "cuda"
recursive: False # don't recursively inject submodules of this module
- match:
name: "^model\\.layers\\..*\\.self_attn$"
replace:
class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- match:
name: "^model$"
replace:
class: "ktransformers.operators.models.KDeepseekV2Model"
kwargs:
per_layer_prefill_intput_threshold: 0 # 0 disables layer-wise prefill
- match:
name: "^model.embed_tokens"
replace:
class: "default"
kwargs:
generate_device: "cpu"
prefill_device: "cpu"
@@ -0,0 +1,77 @@
- match:
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
replace:
class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- match:
name: "^lm_head$" # regular expression
class: torch.nn.Linear # only match modules matching name and class simultaneously
replace:
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
generate_op: "KLinearMarlin"
prefill_op: "KLinearTorch"
- match:
name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression
class: torch.nn.Linear # only match modules matching name and class simultaneously
replace:
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
generate_op: "KLinearMarlin"
prefill_op: "KLinearTorch"
- match:
name: "^model\\.layers\\..*\\.mlp$"
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
replace:
class: ktransformers.operators.experts.KDeepseekV3MoE # mlp module with custom forward function
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- match:
class: ktransformers.models.modeling_deepseek_v3.MoEGate
replace:
class: ktransformers.operators.gate.KMoEGate
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
- match:
name: "^model\\.layers\\..*\\.mlp\\.experts$"
replace:
class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert parallelism
kwargs:
prefill_device: "cuda"
prefill_op: "KExpertsTorch"
generate_device: "cpu"
generate_op: "KExpertsCPU"
out_device: "cuda"
backend: "AMXInt8" # or "AMXBF16" or "llamafile" (default)
recursive: False # don't recursively inject submodules of this module
- match:
name: "^model\\.layers\\..*\\.self_attn$"
replace:
class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
absorb_for_prefill: False # change this to True to enable long context(prefill may slower).
- match:
name: "^model$"
replace:
class: "ktransformers.operators.models.KDeepseekV2Model"
kwargs:
per_layer_prefill_intput_threshold: 0 # 0 disables layer-wise prefill
- match:
name: "^model.embed_tokens"
replace:
class: "default"
kwargs:
generate_device: "cpu"
prefill_device: "cpu"
@@ -0,0 +1,392 @@
- match:
name: "^model.embed_tokens"
replace:
class: "default"
kwargs:
generate_device: "cpu"
prefill_device: "cpu"
# === Rotary Embedding Replacement ===
# GPU 0: layers 014
- match:
name: "^model\\.layers\\.([0-9]|1[0-4])\\."
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
replace:
class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
# GPU 1: layers 1529
- match:
name: "^model\\.layers\\.(1[5-9]|2[0-9])\\."
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
replace:
class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
# GPU 2: layers 3044
- match:
name: "^model\\.layers\\.(3[0-9]|4[0-4])\\."
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
replace:
class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
kwargs:
generate_device: "cuda:2"
prefill_device: "cuda:2"
# GPU 3: layers 4560
- match:
name: "^model\\.layers\\.(4[5-9]|5[0-9]|60)\\."
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
replace:
class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
kwargs:
generate_device: "cuda:3"
prefill_device: "cuda:3"
# === Linear Layers Replacement (excluding self_attn.kv_b_proj) ===
# GPU 0: layers 014
- match:
name: "^model\\.layers\\.([0-9]|1[0-4])\\.(?!self_attn\\.kv_b_proj).*$"
class: torch.nn.Linear
replace:
class: ktransformers.operators.linear.KTransformersLinear
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
generate_op: "KLinearTorch"
prefill_op: "KLinearTorch"
# GPU 1: layers 1529
- match:
name: "^model\\.layers\\.(1[5-9]|2[0-9])\\.(?!self_attn\\.kv_b_proj).*$"
class: torch.nn.Linear
replace:
class: ktransformers.operators.linear.KTransformersLinear
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
generate_op: "KLinearTorch"
prefill_op: "KLinearTorch"
# GPU 2: layers 3044
- match:
name: "^model\\.layers\\.(3[0-9]|4[0-4])\\.(?!self_attn\\.kv_b_proj).*$"
class: torch.nn.Linear
replace:
class: ktransformers.operators.linear.KTransformersLinear
kwargs:
generate_device: "cuda:2"
prefill_device: "cuda:2"
generate_op: "KLinearTorch"
prefill_op: "KLinearTorch"
# GPU 3: layers 4560
- match:
name: "^model\\.layers\\.(4[5-9]|5[0-9]|60)\\.(?!self_attn\\.kv_b_proj).*$"
class: torch.nn.Linear
replace:
class: ktransformers.operators.linear.KTransformersLinear
kwargs:
generate_device: "cuda:3"
prefill_device: "cuda:3"
generate_op: "KLinearTorch"
prefill_op: "KLinearTorch"
# === MLP (MoE) Replacement ===
# GPU 0: layers 014
- match:
name: "^model\\.layers\\.([0-9]|1[0-4])\\.mlp$"
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
replace:
class: ktransformers.operators.experts.KDeepseekV3MoE
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
# GPU 1: layers 1529
- match:
name: "^model\\.layers\\.(1[5-9]|2[0-9])\\.mlp$"
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
replace:
class: ktransformers.operators.experts.KDeepseekV3MoE
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
# GPU 2: layers 3044
- match:
name: "^model\\.layers\\.(3[0-9]|4[0-4])\\.mlp$"
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
replace:
class: ktransformers.operators.experts.KDeepseekV3MoE
kwargs:
generate_device: "cuda:2"
prefill_device: "cuda:2"
# GPU 3: layers 4560
- match:
name: "^model\\.layers\\.(4[5-9]|5[0-9]|60)\\.mlp$"
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
replace:
class: ktransformers.operators.experts.KDeepseekV3MoE
kwargs:
generate_device: "cuda:3"
prefill_device: "cuda:3"
# === MLP Gate Replacement ===
# GPU 0: layers 014
- match:
name: "^model\\.layers\\.([0-9]|1[0-4])\\.mlp\\.gate$"
class: ktransformers.models.modeling_deepseek_v3.MoEGate
replace:
class: ktransformers.operators.gate.KMoEGate
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
# GPU 1: layers 1529
- match:
name: "^model\\.layers\\.(1[5-9]|2[0-9])\\.mlp\\.gate$"
class: ktransformers.models.modeling_deepseek_v3.MoEGate
replace:
class: ktransformers.operators.gate.KMoEGate
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
# GPU 2: layers 3044
- match:
name: "^model\\.layers\\.(3[0-9]|4[0-4])\\.mlp\\.gate$"
class: ktransformers.models.modeling_deepseek_v3.MoEGate
replace:
class: ktransformers.operators.gate.KMoEGate
kwargs:
generate_device: "cuda:2"
prefill_device: "cuda:2"
# GPU 3: layers 4560
- match:
name: "^model\\.layers\\.(4[5-9]|5[0-9]|60)\\.mlp\\.gate$"
class: ktransformers.models.modeling_deepseek_v3.MoEGate
replace:
class: ktransformers.operators.gate.KMoEGate
kwargs:
generate_device: "cuda:3"
prefill_device: "cuda:3"
# === MLP Experts Replacement ===
# Replace with Marlin experts. Uncomment and adjust the layer numbers as needed.
# Each layer of Marlin experts takes about 6 GB of GPU memory.
# Remember to disable CUDA graph if you are using Marlin experts.
# KExpertsTorch is untested; we did not have enough VRAM to verify it.
# GPU 0: layers 34
# - match:
# name: "^model\\.layers\\.([3-4])\\.mlp\\.experts$"
# replace:
# class: ktransformers.operators.experts.KTransformersExperts
# kwargs:
# generate_device: "cuda:0"
# generate_op: "KExpertsMarlin"
# recursive: False
# # GPU 1: layers 1517
# - match:
# name: "^model\\.layers\\.(1[5-7])\\.mlp\\.experts$"
# replace:
# class: ktransformers.operators.experts.KTransformersExperts
# kwargs:
# generate_device: "cuda:1"
# generate_op: "KExpertsMarlin"
# recursive: False
# # GPU 2: layers 3032
# - match:
# name: "^model\\.layers\\.(3[0-2])\\.mlp\\.experts$"
# replace:
# class: ktransformers.operators.experts.KTransformersExperts
# kwargs:
# generate_device: "cuda:2"
# generate_op: "KExpertsMarlin"
# recursive: False
# # GPU 3: layers 4546
# - match:
# name: "^model\\.layers\\.(4[5-6])\\.mlp\\.experts$"
# replace:
# class: ktransformers.operators.experts.KTransformersExperts
# kwargs:
# generate_device: "cuda:3"
# generate_op: "KExpertsMarlin"
# recursive: False
# === MLP Experts Replacement ===
# GPU 0: layers 014
- match:
name: "^model\\.layers\\.([0-9]|1[0-4])\\.mlp\\.experts$"
replace:
class: ktransformers.operators.experts.KTransformersExperts
kwargs:
prefill_device: "cuda:0"
prefill_op: "KExpertsTorch"
generate_device: "cpu"
generate_op: "KSFTExpertsCPU"
out_device: "cuda:0"
backend: "AMXInt8" # or "AMXBF16" or "llamafile" (default)
recursive: False
# GPU 1: layers 15-29
- match:
name: "^model\\.layers\\.(1[5-9]|2[0-9])\\.mlp\\.experts$"
replace:
class: ktransformers.operators.experts.KTransformersExperts
kwargs:
prefill_device: "cuda:1"
prefill_op: "KExpertsTorch"
generate_device: "cpu"
generate_op: "KSFTExpertsCPU"
out_device: "cuda:1"
backend: "AMXInt8" # or "AMXBF16" or "llamafile" (default)
recursive: False
# GPU 2: layers 30-44
- match:
name: "^model\\.layers\\.(3[0-9]|4[0-4])\\.mlp\\.experts$"
replace:
class: ktransformers.operators.experts.KTransformersExperts
kwargs:
prefill_device: "cuda:2"
prefill_op: "KExpertsTorch"
generate_device: "cpu"
generate_op: "KSFTExpertsCPU"
out_device: "cuda:2"
backend: "AMXInt8" # or "AMXBF16" or "llamafile" (default)
recursive: False
# GPU 3: layers 45-60
- match:
name: "^model\\.layers\\.(4[5-9]|5[0-9]|60)\\.mlp\\.experts$"
replace:
class: ktransformers.operators.experts.KTransformersExperts
kwargs:
prefill_device: "cuda:3"
prefill_op: "KExpertsTorch"
generate_device: "cpu"
generate_op: "KSFTExpertsCPU"
out_device: "cuda:3"
backend: "AMXInt8" # or "AMXBF16" or "llamafile" (default)
recursive: False
# === Self-Attention Replacement ===
# GPU 0: layers 0-14
- match:
name: "^model\\.layers\\.([0-9]|1[0-4])\\.self_attn$"
replace:
class: ktransformers.operators.attention.KDeepseekV2Attention
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
absorb_for_prefill: False
# GPU 1: layers 15-29
- match:
name: "^model\\.layers\\.(1[5-9]|2[0-9])\\.self_attn$"
replace:
class: ktransformers.operators.attention.KDeepseekV2Attention
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
absorb_for_prefill: False
# GPU 2: layers 30-44
- match:
name: "^model\\.layers\\.(3[0-9]|4[0-4])\\.self_attn$"
replace:
class: ktransformers.operators.attention.KDeepseekV2Attention
kwargs:
generate_device: "cuda:2"
prefill_device: "cuda:2"
absorb_for_prefill: False
# GPU 3: layers 45-60
- match:
name: "^model\\.layers\\.(4[5-9]|5[0-9]|60)\\.self_attn$"
replace:
class: ktransformers.operators.attention.KDeepseekV2Attention
kwargs:
generate_device: "cuda:3"
prefill_device: "cuda:3"
absorb_for_prefill: False
# === Overall Model Replacement with Transfer Map ===
- match:
name: "^model$"
replace:
class: "ktransformers.operators.models.KDeepseekV2Model"
kwargs:
per_layer_prefill_intput_threshold: 0 # 0 disables layer-wise prefill
transfer_map:
15: "cuda:1" # Layers 15+ on GPU 1
30: "cuda:2" # Layers 30+ on GPU 2
45: "cuda:3" # Layers 45+ on GPU 3
# === Default Catch-All for Other Modules ===
# GPU 0: layers 0-14
- match:
name: "^model\\.layers\\.([0-9]|1[0-4])\\."
replace:
class: "default"
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
# GPU 1: layers 15-29
- match:
name: "^model\\.layers\\.(1[5-9]|2[0-9])\\."
replace:
class: "default"
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
# GPU 2: layers 30-44
- match:
name: "^model\\.layers\\.(3[0-9]|4[0-4])\\."
replace:
class: "default"
kwargs:
generate_device: "cuda:2"
prefill_device: "cuda:2"
- match:
name: "^lm_head"
class: torch.nn.Linear
replace:
class: ktransformers.operators.linear.KTransformersLinear
kwargs:
generate_device: "cuda:3"
prefill_device: "cuda:3"
generate_op: "KLinearTorch"
prefill_op: "KLinearTorch"
# Final modules (model.norm) are kept on GPU 3 together with the last block of layers
- match:
name: "(^model\\.layers\\.(4[5-9]|5[0-9]|60)\\.)|(^model\\.norm)"
replace:
class: "default"
kwargs:
generate_device: "cuda:3"
prefill_device: "cuda:3"

View File

@@ -0,0 +1,156 @@
- match:
name: "^model.embed_tokens"
replace:
class: "default"
kwargs:
generate_device: "cpu"
prefill_device: "cpu"
- match:
name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\."
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
replace:
class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
- match:
name: "^model\\.layers\\.([3456][0-9])\\."
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
replace:
class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
- match:
name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.(?!self_attn\\.kv_b_proj).*$" # regular expression
class: torch.nn.Linear # only match modules matching name and class simultaneously
replace:
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
generate_op: "KLinearTorch"
prefill_op: "KLinearTorch"
- match:
name: "^model\\.layers\\.([3456][0-9])\\.(?!self_attn\\.kv_b_proj).*$" # regular expression
class: torch.nn.Linear # only match modules matching name and class simultaneously
replace:
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
generate_op: "KLinearTorch"
prefill_op: "KLinearTorch"
- match:
name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp$"
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
replace:
class: ktransformers.operators.experts.KDeepseekV3MoE # mlp module with custom forward function
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
- match:
name: "^model\\.layers\\.([3456][0-9])\\.mlp$"
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
replace:
class: ktransformers.operators.experts.KDeepseekV3MoE # mlp module with custom forward function
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
- match:
name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp\\.gate$"
class: ktransformers.models.modeling_deepseek_v3.MoEGate
replace:
class: ktransformers.operators.gate.KMoEGate
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
- match:
name: "^model\\.layers\\.([3456][0-9])\\.mlp\\.gate$"
class: ktransformers.models.modeling_deepseek_v3.MoEGate
replace:
class: ktransformers.operators.gate.KMoEGate # gate module with custom forward function
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
- match:
name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp\\.experts$"
replace:
class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert parallelism
kwargs:
prefill_device: "cuda:0"
prefill_op: "KExpertsTorch"
generate_device: "cpu"
generate_op: "KSFTExpertsCPU"
out_device: "cuda:0"
backend: "AMXInt8" # or "AMXBF16" or "llamafile" (default)
recursive: False # don't recursively inject submodules of this module
- match:
name: "^model\\.layers\\.([3456][0-9])\\.mlp\\.experts$"
replace:
class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert parallelism
kwargs:
prefill_device: "cuda:1"
prefill_op: "KExpertsTorch"
generate_device: "cpu"
generate_op: "KSFTExpertsCPU"
out_device: "cuda:1"
backend: "AMXInt8" # or "AMXBF16" or "llamafile" (default)
recursive: False # don't recursively inject submodules of this module
- match:
name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.self_attn$"
replace:
class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
- match:
name: "^model\\.layers\\.([3456][0-9])\\.self_attn$"
replace:
class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
- match:
name: "^model$"
replace:
class: "ktransformers.operators.models.KDeepseekV2Model"
kwargs:
per_layer_prefill_intput_threshold: 0 # 0 disables layer-wise prefill
transfer_map:
30: "cuda:1"
- match:
name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\."
replace:
class: "default"
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
- match:
name: "^lm_head"
class: torch.nn.Linear
replace:
class: ktransformers.operators.linear.KTransformersLinear
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
generate_op: "KLinearTorch"
prefill_op: "KLinearTorch"
- match:
name: "(^model\\.layers\\.([3456][0-9])\\.)|(model.norm)"
replace:
class: "default"
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"

View File

@@ -0,0 +1,77 @@
- match:
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
replace:
class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- match:
name: "^lm_head$" # regular expression
class: torch.nn.Linear # only match modules matching name and class simultaneously
replace:
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
generate_op: "KLinearTorch"
prefill_op: "KLinearTorch"
- match:
name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression
class: torch.nn.Linear # only match modules matching name and class simultaneously
replace:
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
generate_op: "KLinearTorch"
prefill_op: "KLinearTorch"
- match:
name: "^model\\.layers\\..*\\.mlp$"
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
replace:
class: ktransformers.operators.experts.KDeepseekV3MoE # mlp module with custom forward function
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- match:
class: ktransformers.models.modeling_deepseek_v3.MoEGate
replace:
class: ktransformers.operators.gate.KMoEGate
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
- match:
name: "^model\\.layers\\..*\\.mlp\\.experts$"
replace:
class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert parallelism
kwargs:
prefill_device: "cuda"
prefill_op: "KExpertsTorch"
generate_device: "cpu"
generate_op: "KSFTExpertsCPU"
out_device: "cuda"
backend: "AMXInt8" # or "AMXBF16" or "llamafile" (default)
recursive: False # don't recursively inject submodules of this module
- match:
name: "^model\\.layers\\..*\\.self_attn$"
replace:
class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
absorb_for_prefill: False # change this to True to enable long context (prefill may be slower).
- match:
name: "^model$"
replace:
class: "ktransformers.operators.models.KDeepseekV2Model"
kwargs:
per_layer_prefill_intput_threshold: 0 # 0 disables layer-wise prefill
- match:
name: "^model.embed_tokens"
replace:
class: "default"
kwargs:
generate_device: "cpu"
prefill_device: "cpu"

View File

@@ -0,0 +1,80 @@
- match:
class: ktransformers.models.modeling_qwen2_moe.Qwen2MoeRotaryEmbedding
replace:
class: ktransformers.operators.RoPE.RotaryEmbedding
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- match:
name: "^lm_head$" # regular expression
class: torch.nn.Linear # only match modules matching name and class simultaneously
replace:
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
generate_op: "KLinearTorch"
prefill_op: "KLinearTorch"
# - match:
# name: "^model\\.layers\\..*$" # regular expression
# class: torch.nn.Linear # only match modules matching name and class simultaneously
# replace:
# class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
# kwargs:
# generate_device: "cuda"
# prefill_device: "cuda"
# generate_op: "KLinearTorch"
# prefill_op: "KLinearTorch"
- match:
name: "^model\\.layers\\.(?!.*mlp\\.shared_expert_gate).*$" # regular expression
class: torch.nn.Linear # only match modules matching name and class simultaneously
replace:
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
generate_op: "KLinearTorch"
prefill_op: "KLinearTorch"
- match:
name: "^model\\.layers\\..*\\.mlp$"
replace:
class: ktransformers.operators.experts.KQwen3MoeSparseMoeBlock # mlp module with custom forward function
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- match:
name: "^model\\.layers\\..*\\.mlp\\.experts$"
replace:
class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert parallelism
kwargs:
prefill_device: "cuda"
prefill_op: "KExpertsTorch"
generate_device: "cpu"
generate_op: "KSFTExpertsCPU"
out_device: "cuda"
backend: "AMXInt8" # or "AMXBF16" or "AMXInt8"
recursive: False # don't recursively inject submodules of this module
- match:
name: "^model\\.layers\\..*\\.self_attn$"
replace:
class: ktransformers.operators.attention.KQwen3MoeAttention # optimized attention implementation
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- match:
name: "^model.embed_tokens"
replace:
class: "default"
kwargs:
generate_device: "cpu"
prefill_device: "cpu"
- match:
name: "^model$"
replace:
class: "ktransformers.operators.models.KQwen3MoeModel"
kwargs:
per_layer_prefill_intput_threshold: 0

View File

@@ -0,0 +1,29 @@
model_name_or_path: Qwen/Qwen2-VL-7B-Instruct
image_max_pixels: 262144
video_max_pixels: 16384
do_train: true
stage: sft
finetuning_type: full # only full is supported for now
dataset: llava_1k_en
preprocessing_num_workers: 8
cutoff_len: 4096
template: qwen2_vl
output_dir: saves/mca/qwen2_vl_full
per_device_train_batch_size: 1
gradient_accumulation_steps: 2
num_train_epochs: 2
learning_rate: 2e-5
logging_steps: 1
save_steps: 100
lr_scheduler_type: cosine
bf16: true
# mcore speed up
tensor_model_parallel_size: 4
sequence_parallel: true
pipeline_model_parallel_size: 2
bias_activation_fusion: true
apply_rope_fusion: true
use_distributed_optimizer: true

View File

@@ -0,0 +1,35 @@
model_name_or_path: Qwen/Qwen3-30B-A3B-Instruct-2507
# GPU memory: 8 * 78GB
do_train: true
stage: sft
finetuning_type: full # only full is supported for now
dataset: alpaca_en_demo
preprocessing_num_workers: 8
cutoff_len: 4096
template: qwen3_nothink
# global batch size = (8 // 2 // 4) * 8 = 8
output_dir: saves/mca/qwen3_moe_full
per_device_train_batch_size: 1
gradient_accumulation_steps: 8
num_train_epochs: 2
learning_rate: 3e-6
logging_steps: 1
save_steps: 100
lr_scheduler_type: constant
bf16: true
# mcore speed up
tensor_model_parallel_size: 1
sequence_parallel: false
pipeline_model_parallel_size: 4
bias_activation_fusion: true
apply_rope_fusion: true
use_distributed_optimizer: true
overlap_param_gather: true
overlap_grad_reduce: true
moe_grouped_gemm: true
moe_token_dispatcher_type: alltoall
expert_model_parallel_size: 2
recompute_granularity: full
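
The batch-size comment above can be reproduced explicitly. A small arithmetic sketch following the same convention as that comment (the 8-GPU count comes from the memory note at the top of the config):

# Mirrors the "global batch size" comment in the config above.
world_size = 8
expert_parallel = 2      # expert_model_parallel_size
pipeline_parallel = 4    # pipeline_model_parallel_size
per_device_batch = 1     # per_device_train_batch_size
grad_accum = 8           # gradient_accumulation_steps

model_replicas = world_size // expert_parallel // pipeline_parallel  # 1, per the comment's convention
global_batch = model_replicas * per_device_batch * grad_accum
print(global_batch)  # 8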

View File

@@ -0,0 +1 @@
adam-mini

View File

@@ -0,0 +1 @@
apollo-torch

View File

@@ -0,0 +1 @@
aqlm[gpu]>=1.1.0

View File

@@ -0,0 +1 @@
badam>=1.2.1

View File

@@ -0,0 +1 @@
bitsandbytes>=0.39.0

View File

@@ -0,0 +1 @@
eetq

View File

@@ -0,0 +1,2 @@
transformer_engine[pytorch]>=2.0.0
accelerate>=1.10.0

View File

@@ -0,0 +1,2 @@
torchao>=0.8.0
accelerate>=1.10.0

View File

@@ -0,0 +1 @@
galore-torch

View File

@@ -0,0 +1,2 @@
optimum>=1.24.0
gptqmodel>=2.0.0

View File

@@ -0,0 +1 @@
hqq

View File

@@ -0,0 +1 @@
liger-kernel>=0.5.5

View File

@@ -0,0 +1,8 @@
soundfile
torchvision
torchaudio
vector_quantize_pytorch
vocos
msgpack
referencing
jsonschema_specifications

View File

@@ -0,0 +1 @@
openmind

View File

@@ -0,0 +1,2 @@
sglang[srt]>=0.4.5
transformers==4.51.1

View File

@@ -0,0 +1 @@
swanlab

View File

@@ -0,0 +1 @@
vllm>=0.4.3,<=0.11.0

View File

@@ -0,0 +1,46 @@
### model
model_name_or_path: Qwen/Qwen3-32B
trust_remote_code: true
use_v1_kernels: true
### method
stage: sft
do_train: true
finetuning_type: full
deepspeed: examples/deepspeed/ds_z2_autotp_config.json
### dataset
dataset: identity,alpaca_en_demo
template: qwen3
cutoff_len: 2048
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16
dataloader_num_workers: 4
### output
output_dir: saves/qwen3-32b/full/sft_autotp
logging_steps: 1
save_steps: 500
plot_loss: true
overwrite_output_dir: true
save_only_model: false
report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
### train
per_device_train_batch_size: 4
gradient_accumulation_steps: 1
learning_rate: 1.0e-4
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
ddp_timeout: 180000000
resume_from_checkpoint: null
### eval
# eval_dataset: alpaca_en_demo
# val_size: 0.1
# per_device_eval_batch_size: 1
# eval_strategy: steps
# eval_steps: 500

View File

@@ -0,0 +1,52 @@
### model
model_name_or_path: deepseek-ai/DeepSeek-V2-Lite
trust_remote_code: true
### method
stage: sft
do_train: true
finetuning_type: lora
lora_rank: 8
lora_target: all
### dataset
dataset: identity
template: deepseek
cutoff_len: 2048
max_samples: 100000
overwrite_cache: true
preprocessing_num_workers: 16
dataloader_num_workers: 4
### output
output_dir: saves/Kllama_deepseekV2
logging_steps: 10
save_steps: 500
plot_loss: true
overwrite_output_dir: true
save_only_model: false
report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
### train
per_device_train_batch_size: 1
gradient_accumulation_steps: 8
learning_rate: 1.0e-4
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
ddp_timeout: 180000000
resume_from_checkpoint: null
### ktransformers
use_kt: true # use KTransformers as the LoRA SFT backend
kt_optimize_rule: examples/kt_optimize_rules/DeepSeek-V2-Lite-Chat-sft-amx.yaml
cpu_infer: 32
chunk_size: 8192
### eval
# eval_dataset: alpaca_en_demo
# val_size: 0.1
# per_device_eval_batch_size: 1
# eval_strategy: steps
# eval_steps: 500
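
The kt_optimize_rule path above points at a plain YAML list of match/replace injection rules like the ones shown earlier in this diff. A minimal sketch for inspecting such a rule file from the repository root (requires PyYAML; the path is taken from the config above):

import yaml

# Print which module pattern is replaced by which ktransformers operator, and on which device.
rule_path = "examples/kt_optimize_rules/DeepSeek-V2-Lite-Chat-sft-amx.yaml"
with open(rule_path) as f:
    rules = yaml.safe_load(f)  # a list of {match: ..., replace: ...} entries

for rule in rules:
    match, replace = rule.get("match", {}), rule.get("replace", {})
    target = match.get("name") or match.get("class")
    device = replace.get("kwargs", {}).get("generate_device", "-")
    print(f"{target} -> {replace.get('class')} (generate on {device})")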

View File

@@ -0,0 +1,52 @@
### model
model_name_or_path: opensourcerelease/DeepSeek-V3-bf16
trust_remote_code: true
### method
stage: sft
do_train: true
finetuning_type: lora
lora_rank: 8
lora_target: all
### dataset
dataset: identity
template: deepseek
cutoff_len: 2048
max_samples: 100000
overwrite_cache: true
preprocessing_num_workers: 16
dataloader_num_workers: 4
### output
output_dir: saves/Kllama_deepseekV3
logging_steps: 10
save_steps: 500
plot_loss: true
overwrite_output_dir: true
save_only_model: false
report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
### train
per_device_train_batch_size: 1
gradient_accumulation_steps: 8
learning_rate: 1.0e-4
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
ddp_timeout: 180000000
resume_from_checkpoint: null
### ktransformers
use_kt: true # use KTransformers as the LoRA SFT backend
kt_optimize_rule: examples/kt_optimize_rules/DeepSeek-V3-Chat-sft-amx-multi-gpu.yaml
cpu_infer: 32
chunk_size: 8192
### eval
# eval_dataset: alpaca_en_demo
# val_size: 0.1
# per_device_eval_batch_size: 1
# eval_strategy: steps
# eval_steps: 500

View File

@@ -0,0 +1,52 @@
### model
model_name_or_path: Qwen/Qwen3-235B-A22B-Instruct-2507
trust_remote_code: true
### method
stage: sft
do_train: true
finetuning_type: lora
lora_rank: 8
lora_target: all
### dataset
dataset: identity, alpaca_en_demo
template: qwen3_nothink
cutoff_len: 2048
max_samples: 100000
overwrite_cache: true
preprocessing_num_workers: 16
dataloader_num_workers: 4
### output
output_dir: saves/Kllama_Qwen3MoE_235bA22b
logging_steps: 10
save_steps: 200
plot_loss: true
overwrite_output_dir: true
save_only_model: false
report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
### train
per_device_train_batch_size: 1
gradient_accumulation_steps: 8
learning_rate: 1.0e-4
num_train_epochs: 3
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
ddp_timeout: 180000000
resume_from_checkpoint: null
### ktransformers
use_kt: true # use KTransformers as the LoRA SFT backend
kt_optimize_rule: examples/kt_optimize_rules/Qwen3Moe-sft-amx.yaml
cpu_infer: 32
chunk_size: 8192
### eval
# eval_dataset: alpaca_en_demo
# val_size: 0.1
# per_device_eval_batch_size: 1
# eval_strategy: steps
# eval_steps: 500

View File

@@ -1,42 +1,123 @@
[build-system]
-requires = ["setuptools>=61.0"]
requires = ["hatchling"]
-build-backend = "setuptools.build_meta"
build-backend = "hatchling.build"
[project]
name = "llamafactory"
-requires-python = ">=3.9.0"
-dynamic = [
-"version",
-"dependencies",
-"optional-dependencies",
-"scripts",
-"authors",
-"description",
-"readme",
-"license",
-"keywords",
-"classifiers"
dynamic = ["version"]
description = "Unified Efficient Fine-Tuning of 100+ LLMs"
readme = "README.md"
license = "Apache-2.0"
requires-python = ">=3.11.0"
authors = [
{ name = "hiyouga", email = "hiyouga@buaa.edu.cn" }
]
keywords = [
"AI",
"LLM",
"GPT",
"ChatGPT",
"Llama",
"Transformer",
"DeepSeek",
"Pytorch"
]
classifiers = [
"Development Status :: 4 - Beta",
"Intended Audience :: Developers",
"Intended Audience :: Education",
"Intended Audience :: Science/Research",
"License :: OSI Approved :: Apache Software License",
"Operating System :: OS Independent",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13",
"Topic :: Scientific/Engineering :: Artificial Intelligence"
]
dependencies = [
# core deps
"torch>=2.4.0",
"torchvision>=0.19.0",
"torchaudio>=2.4.0",
"transformers>=4.49.0,<=4.56.2,!=4.52.0; python_version < '3.10'",
"transformers>=4.49.0,<=4.57.1,!=4.52.0,!=4.57.0; python_version >= '3.10'",
"datasets>=2.16.0,<=4.0.0",
"accelerate>=1.3.0,<=1.11.0",
"peft>=0.14.0,<=0.17.1",
"trl>=0.8.6,<=0.9.6",
"torchdata>=0.10.0,<=0.11.0",
# gui
"gradio>=4.38.0,<=5.50.0",
"matplotlib>=3.7.0",
"tyro<0.9.0",
# ops
"einops",
"numpy",
"pandas",
"scipy",
# model and tokenizer
"sentencepiece",
"tiktoken",
"modelscope",
"hf-transfer",
"safetensors",
# python
"av",
"fire",
"omegaconf",
"packaging",
"protobuf",
"pyyaml",
"pydantic",
# api
"uvicorn",
"fastapi",
"sse-starlette"
]
[project.optional-dependencies]
dev = ["pre-commit", "ruff", "pytest", "build"]
metrics = ["nltk", "jieba", "rouge-chinese"]
deepspeed = ["deepspeed>=0.10.0,<=0.16.9"]
[project.scripts]
llamafactory-cli = "llamafactory.cli:main"
lmf = "llamafactory.cli:main"
[project.urls]
Homepage = "https://github.com/hiyouga/LLaMA-Factory"
Repository = "https://github.com/hiyouga/LLaMA-Factory"
[tool.hatch.build.targets.wheel]
packages = ["src/llamafactory"]
[tool.hatch.version]
path = "src/llamafactory/extras/env.py"
pattern = "VERSION = \"(?P<version>[^\"]+)\""
[tool.ruff]
-target-version = "py39"
target-version = "py311"
line-length = 119
indent-width = 4
[tool.ruff.lint]
ignore = [
"C408", # collection
"C901", # complex
"E501", # line too long
"E731", # lambda function
"E741", # ambiguous var name
"UP007", # no upgrade union
"UP045", # no upgrade optional
"D100", # no doc public module
"D101", # no doc public class
"D102", # no doc public method
"D103", # no doc public function
"D104", # no doc public package
"D105", # no doc magic method
"D107", # no doc __init__
]
extend-select = [
"C", # complexity
@@ -73,23 +154,3 @@ indent-style = "space"
docstring-code-format = true
skip-magic-trailing-comma = false
line-ending = "auto"
[tool.uv]
conflicts = [
[
{ extra = "torch-npu" },
{ extra = "aqlm" },
],
[
{ extra = "torch-npu" },
{ extra = "vllm" },
],
[
{ extra = "torch-npu" },
{ extra = "sglang" },
],
[
{ extra = "vllm" },
{ extra = "sglang" },
],
]

View File

@@ -1,36 +0,0 @@
# core deps
transformers>=4.49.0,<=4.56.2,!=4.52.0; python_version < '3.10'
transformers>=4.49.0,<=4.57.1,!=4.52.0; python_version >= '3.10'
datasets>=2.16.0,<=4.0.0
accelerate>=1.3.0,<=1.11.0
peft>=0.14.0,<=0.17.1
trl>=0.8.6,<=0.9.6
# gui
gradio>=4.38.0,<=5.45.0
matplotlib>=3.7.0
tyro<0.9.0
# ops
einops
numpy<2.0.0
pandas>=2.0.0
scipy
# model and tokenizer
sentencepiece
tiktoken
modelscope>=1.14.0
hf-transfer
safetensors<=0.5.3
# python
fire
omegaconf
packaging
protobuf
pyyaml
pydantic<=2.10.6
# api
uvicorn
fastapi
sse-starlette
# media
av
librosa

124
scripts/megatron_merge.py Normal file
View File

@@ -0,0 +1,124 @@
# Copyright 2025 the ROLL team and the LlamaFactory team.
#
# This code is modified from the ROLL library.
# https://github.com/alibaba/ROLL/blob/main/mcore_adapter/tools/convert.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import fire
import torch
from mcore_adapter.models.converter.post_converter import convert_checkpoint_to_hf, convert_checkpoint_to_mca
from mcore_adapter.training_args import DistributingParallelArguments
from mcore_adapter.utils import get_logger
from transformers import AutoConfig
logger = get_logger(__name__)
def convert_mca_to_hf(
checkpoint_path: str,
output_path: str = "./output",
bf16: bool = False,
fp16: bool = False,
convert_model_max_length: int | None = None,
):
"""Convert megatron checkpoint to HuggingFace format.
Args:
checkpoint_path: Path to the checkpoint to convert
output_path: Path to save the converted checkpoint
bf16: Use bfloat16 precision
fp16: Use float16 precision
convert_model_max_length: Change the model_max_length in hf config.json
"""
if bf16 and fp16:
raise ValueError("bf16 and fp16 cannot be both True.")
torch_dtype = None
if bf16:
torch_dtype = torch.bfloat16
elif fp16:
torch_dtype = torch.float16
convert_checkpoint_to_hf(checkpoint_path, output_path, torch_dtype=torch_dtype)
if convert_model_max_length is not None:
config = AutoConfig.from_pretrained(output_path, trust_remote_code=True)
config.model_max_length = convert_model_max_length
config.save_pretrained(output_path)
def convert(
checkpoint_path: str,
output_path: str = "./output",
bf16: bool = False,
fp16: bool = False,
convert_model_max_length: int | None = None,
tensor_model_parallel_size: int = 1,
pipeline_model_parallel_size: int = 1,
expert_model_parallel_size: int = 1,
virtual_pipeline_model_parallel_size: int | None = None,
):
"""Convert checkpoint between MCA and HuggingFace formats.
Args:
checkpoint_path: Path to the checkpoint to convert
output_path: Path to save the converted checkpoint
bf16: Use bfloat16 precision
fp16: Use float16 precision
convert_model_max_length: Change the model_max_length in hf config.json
tensor_model_parallel_size: Tensor model parallel size
pipeline_model_parallel_size: Pipeline model parallel size
expert_model_parallel_size: Expert model parallel size
virtual_pipeline_model_parallel_size: Virtual pipeline model parallel size
"""
if bf16 and fp16:
raise ValueError("bf16 and fp16 cannot be both True.")
mca_config_path = os.path.join(checkpoint_path, "mca_config.json")
from_mca = os.path.exists(mca_config_path)
if not from_mca:
dist_args = DistributingParallelArguments(
tensor_model_parallel_size=tensor_model_parallel_size,
pipeline_model_parallel_size=pipeline_model_parallel_size,
expert_model_parallel_size=expert_model_parallel_size,
virtual_pipeline_model_parallel_size=virtual_pipeline_model_parallel_size,
)
convert_checkpoint_to_mca(
checkpoint_path,
output_path,
dist_args,
bf16=bf16,
fp16=fp16,
)
else:
convert_mca_to_hf(
checkpoint_path=checkpoint_path,
output_path=output_path,
bf16=bf16,
fp16=fp16,
convert_model_max_length=convert_model_max_length,
)
def main():
fire.Fire(convert)
if __name__ == "__main__":
main()
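
Since the script wraps convert() with fire.Fire, every keyword argument becomes a command-line flag. A hedged usage sketch with placeholder paths (run from the repository root):

import subprocess

# Convert an mcore checkpoint back to HuggingFace format; the paths below are placeholders.
subprocess.run(
    [
        "python", "scripts/megatron_merge.py",
        "--checkpoint_path", "saves/mca/qwen3_moe_full",  # mcore checkpoint dir (contains mca_config.json)
        "--output_path", "./qwen3_moe_hf",
        "--bf16", "True",
    ],
    check=True,
)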

View File

@@ -14,7 +14,7 @@
import json
from dataclasses import dataclass
-from typing import Any, Literal, Optional
from typing import Any, Literal
import fire
import torch
@@ -61,7 +61,7 @@ def calculate_ppl(
dataset_dir: str = "data",
template: str = "default",
cutoff_len: int = 2048,
-max_samples: Optional[int] = None,
max_samples: int | None = None,
train_on_prompt: bool = False,
):
r"""Calculate the ppl on the dataset of the pre-trained models.

View File

@@ -14,8 +14,8 @@
import gc
import json
-from typing import Optional
import av
import fire
from tqdm import tqdm
from transformers import Seq2SeqTrainingArguments
@@ -33,6 +33,14 @@ if is_vllm_available():
from vllm.lora.request import LoRARequest
def _need_video_kwargs(template):
NEEDED_TEMPLATE = ["qwen3_vl", "glm4v"]
if any(t in template for t in NEEDED_TEMPLATE):
return True
return False
def vllm_infer(
model_name_or_path: str,
adapter_name_or_path: str = None,
@@ -40,7 +48,7 @@ def vllm_infer(
dataset_dir: str = "data",
template: str = "default",
cutoff_len: int = 2048,
-max_samples: Optional[int] = None,
max_samples: int | None = None,
vllm_config: str = "{}",
save_name: str = "generated_predictions.jsonl",
temperature: float = 0.95,
@@ -49,9 +57,9 @@ def vllm_infer(
max_new_tokens: int = 1024,
repetition_penalty: float = 1.0,
skip_special_tokens: bool = True,
-default_system: Optional[str] = None,
default_system: str | None = None,
enable_thinking: bool = True,
-seed: Optional[int] = None,
seed: int | None = None,
pipeline_parallel_size: int = 1,
image_max_pixels: int = 768 * 768,
image_min_pixels: int = 32 * 32,
@@ -132,6 +140,7 @@ def vllm_infer(
# Store all results in these lists
all_prompts, all_preds, all_labels = [], [], []
need_video_kwargs = _need_video_kwargs(template)
# Add batch process to avoid the issue of too many files opened
for i in tqdm(range(0, len(train_dataset), batch_size), desc="Processing batched inference"):
@@ -147,6 +156,7 @@
)["images"]
}
elif batch["videos"][j] is not None:
video_metadata, video_metadata_kwargs = None, None
video = batch["videos"][j]
multi_modal_data = {
"video": template_obj.mm_plugin._regularize_videos(
@@ -157,6 +167,25 @@
video_maxlen=video_maxlen,
)["videos"]
}
if need_video_kwargs:
container = av.open(video[0], "r")
video_stream = next(stream for stream in container.streams if stream.type == "video")
sampling_indices = template_obj.mm_plugin._get_video_sample_indices(
video_stream, video_fps, video_maxlen
)
total_frames = video_stream.frames
video_metadata_kwargs = {
"fps": getattr(tokenizer_module["processor"], "video_fps", 24.0),
"do_sample_frames": False,
"total_num_frames": total_frames,
}
video_metadata = dict(
fps=video_fps,
frames_indices=sampling_indices,
total_num_frames=total_frames,
video_backend="opencv",
)
multi_modal_data["video"] = (multi_modal_data["video"], video_metadata)
elif batch["audios"][j] is not None:
audio = batch["audios"][j]
audio_data = template_obj.mm_plugin._regularize_audios(
@@ -167,7 +196,11 @@
else:
multi_modal_data = None
-vllm_inputs.append({"prompt_token_ids": batch["input_ids"][j], "multi_modal_data": multi_modal_data})
vllm_input_data = {"prompt_token_ids": batch["input_ids"][j], "multi_modal_data": multi_modal_data}
if "video_metadata_kwargs" in locals() and video_metadata_kwargs is not None:
vllm_input_data["mm_processor_kwargs"] = video_metadata_kwargs
vllm_inputs.append(vllm_input_data)
prompts.append(tokenizer.decode(batch["input_ids"][j], skip_special_tokens=skip_special_tokens))
labels.append(
tokenizer.decode(

116
setup.py
View File

@@ -1,116 +0,0 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import re
from setuptools import find_packages, setup
def get_version() -> str:
with open(os.path.join("src", "llamafactory", "extras", "env.py"), encoding="utf-8") as f:
file_content = f.read()
pattern = r"{}\W*=\W*\"([^\"]+)\"".format("VERSION")
(version,) = re.findall(pattern, file_content)
return version
def get_requires() -> list[str]:
with open("requirements.txt", encoding="utf-8") as f:
file_content = f.read()
lines = [line.strip() for line in file_content.strip().split("\n") if not line.startswith("#")]
return lines
def get_console_scripts() -> list[str]:
console_scripts = ["llamafactory-cli = llamafactory.cli:main"]
if os.getenv("ENABLE_SHORT_CONSOLE", "1").lower() in ["true", "y", "1"]:
console_scripts.append("lmf = llamafactory.cli:main")
return console_scripts
extra_require = {
"torch": ["torch>=2.0.0", "torchvision>=0.15.0"],
"torch-npu": ["torch-npu==2.5.1", "torchvision==0.20.1", "decorator"],
"metrics": ["nltk", "jieba", "rouge-chinese"],
"deepspeed": ["deepspeed>=0.10.0,<=0.16.9"],
"liger-kernel": ["liger-kernel>=0.5.5"],
"bitsandbytes": ["bitsandbytes>=0.39.0"],
"hqq": ["hqq"],
"eetq": ["eetq"],
"gptq": ["optimum>=1.24.0", "gptqmodel>=2.0.0"],
"aqlm": ["aqlm[gpu]>=1.1.0"],
"vllm": ["vllm>=0.4.3,<=0.11.0"],
"sglang": ["sglang[srt]>=0.4.5", "transformers==4.51.1"],
"galore": ["galore-torch"],
"apollo": ["apollo-torch"],
"badam": ["badam>=1.2.1"],
"adam-mini": ["adam-mini"],
"minicpm_v": [
"soundfile",
"torchvision",
"torchaudio",
"vector_quantize_pytorch",
"vocos",
"msgpack",
"referencing",
"jsonschema_specifications",
],
"openmind": ["openmind"],
"swanlab": ["swanlab"],
"fp8": ["torchao>=0.8.0", "accelerate>=1.10.0"],
"fp8-te": ["transformer_engine[pytorch]>=2.0.0", "accelerate>=1.10.0"],
"fp8-all": ["torchao>=0.8.0", "transformer_engine[pytorch]>=2.0.0", "accelerate>=1.10.0"],
"dev": ["pre-commit", "ruff", "pytest", "build"],
}
def main():
setup(
name="llamafactory",
version=get_version(),
author="hiyouga",
author_email="hiyouga@buaa.edu.cn",
description="Unified Efficient Fine-Tuning of 100+ LLMs",
long_description=open("README.md", encoding="utf-8").read(),
long_description_content_type="text/markdown",
keywords=["AI", "LLM", "GPT", "ChatGPT", "Llama", "Transformer", "DeepSeek", "Pytorch"],
license="Apache 2.0 License",
url="https://github.com/hiyouga/LLaMA-Factory",
package_dir={"": "src"},
packages=find_packages("src"),
python_requires=">=3.9.0",
install_requires=get_requires(),
extras_require=extra_require,
entry_points={"console_scripts": get_console_scripts()},
classifiers=[
"Development Status :: 4 - Beta",
"Intended Audience :: Developers",
"Intended Audience :: Education",
"Intended Audience :: Science/Research",
"License :: OSI Approved :: Apache Software License",
"Operating System :: OS Independent",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
],
)
if __name__ == "__main__":
main()

View File

@@ -16,7 +16,7 @@ import asyncio
import os
from contextlib import asynccontextmanager
from functools import partial
-from typing import Annotated, Optional
from typing import Annotated
from ..chat import ChatModel
from ..extras.constants import EngineName
@@ -79,7 +79,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
api_key = os.getenv("API_KEY")
security = HTTPBearer(auto_error=False)
-async def verify_api_key(auth: Annotated[Optional[HTTPAuthorizationCredentials], Depends(security)]):
async def verify_api_key(auth: Annotated[HTTPAuthorizationCredentials | None, Depends(security)]):
if api_key and (auth is None or auth.credentials != api_key):
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key.")

View File

@@ -14,10 +14,9 @@
import time
from enum import Enum, unique
-from typing import Any, Optional, Union
from typing import Any, Literal
from pydantic import BaseModel, Field
-from typing_extensions import Literal
@unique
@@ -61,7 +60,7 @@ class FunctionDefinition(BaseModel):
class FunctionAvailable(BaseModel):
type: Literal["function", "code_interpreter"] = "function"
-function: Optional[FunctionDefinition] = None
function: FunctionDefinition | None = None
class FunctionCall(BaseModel):
@@ -77,35 +76,35 @@ class URL(BaseModel):
class MultimodalInputItem(BaseModel):
type: Literal["text", "image_url", "video_url", "audio_url"]
-text: Optional[str] = None
text: str | None = None
-image_url: Optional[URL] = None
image_url: URL | None = None
-video_url: Optional[URL] = None
video_url: URL | None = None
-audio_url: Optional[URL] = None
audio_url: URL | None = None
class ChatMessage(BaseModel):
role: Role
-content: Optional[Union[str, list[MultimodalInputItem]]] = None
content: str | list[MultimodalInputItem] | None = None
-tool_calls: Optional[list[FunctionCall]] = None
tool_calls: list[FunctionCall] | None = None
class ChatCompletionMessage(BaseModel):
-role: Optional[Role] = None
role: Role | None = None
-content: Optional[str] = None
content: str | None = None
-tool_calls: Optional[list[FunctionCall]] = None
tool_calls: list[FunctionCall] | None = None
class ChatCompletionRequest(BaseModel):
model: str
messages: list[ChatMessage]
-tools: Optional[list[FunctionAvailable]] = None
tools: list[FunctionAvailable] | None = None
-do_sample: Optional[bool] = None
do_sample: bool | None = None
-temperature: Optional[float] = None
temperature: float | None = None
-top_p: Optional[float] = None
top_p: float | None = None
n: int = 1
-presence_penalty: Optional[float] = None
presence_penalty: float | None = None
-max_tokens: Optional[int] = None
max_tokens: int | None = None
-stop: Optional[Union[str, list[str]]] = None
stop: str | list[str] | None = None
stream: bool = False
@@ -118,7 +117,7 @@ class ChatCompletionResponseChoice(BaseModel):
class ChatCompletionStreamResponseChoice(BaseModel):
index: int
delta: ChatCompletionMessage
-finish_reason: Optional[Finish] = None
finish_reason: Finish | None = None
class ChatCompletionResponseUsage(BaseModel):
@@ -147,7 +146,7 @@ class ChatCompletionStreamResponse(BaseModel):
class ScoreEvaluationRequest(BaseModel):
model: str
messages: list[str]
-max_length: Optional[int] = None
max_length: int | None = None
class ScoreEvaluationResponse(BaseModel):

View File

@@ -71,6 +71,16 @@ class ChatModel:
"SGLang not install, you may need to run `pip install sglang[all]`\n" "SGLang not install, you may need to run `pip install sglang[all]`\n"
"or try to use HuggingFace backend: --infer_backend huggingface" "or try to use HuggingFace backend: --infer_backend huggingface"
) from e ) from e
elif model_args.infer_backend == EngineName.KT:
try:
from .kt_engine import KTransformersEngine
self.engine: BaseEngine = KTransformersEngine(model_args, data_args, finetuning_args, generating_args)
except ImportError as e:
raise ImportError(
"KTransformers not install, you may need to run `pip install ktransformers`\n"
"or try to use HuggingFace backend: --infer_backend huggingface"
) from e
else:
raise NotImplementedError(f"Unknown backend: {model_args.infer_backend}")

View File

@@ -14,9 +14,9 @@
import asyncio
import os
-from collections.abc import AsyncGenerator
from collections.abc import AsyncGenerator, Callable
from threading import Thread
-from typing import TYPE_CHECKING, Any, Callable, Optional, Union
from typing import TYPE_CHECKING, Any, Optional, Union
import torch
from transformers import GenerationConfig, TextIteratorStreamer

View File

@@ -0,0 +1,284 @@
# Copyright 2025 the KVCache.AI team, Approaching AI, and the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import os
import platform
from collections.abc import AsyncGenerator
from threading import Thread
from typing import TYPE_CHECKING, Any, Optional
import torch
from typing_extensions import override
from ..data import get_template_and_fix_tokenizer
from ..extras import logging
from ..extras.constants import EngineName
from ..model import load_model, load_tokenizer
from .base_engine import BaseEngine, Response
if TYPE_CHECKING:
from transformers import PreTrainedTokenizer
from trl import PreTrainedModelWrapper
from ..data.mm_plugin import AudioInput, ImageInput, VideoInput
from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled
from ktransformers.server.config.config import Config
from ktransformers.util.utils import (
get_compute_capability,
prefill_and_generate_capture,
)
from ktransformers.util.vendors import GPUVendor, device_manager
logger = logging.get_logger(__name__)
class KTransformersEngine(BaseEngine):
def __init__(
self,
model_args: "ModelArguments",
data_args: "DataArguments",
finetuning_args: "FinetuningArguments",
generating_args: "GeneratingArguments",
) -> None:
self.name = EngineName.KT
self.can_generate = finetuning_args.stage == "sft"
tok_mod = load_tokenizer(model_args)
self.tokenizer = tok_mod["tokenizer"]
self.tokenizer.padding_side = "left" if self.can_generate else "right"
self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args)
self.model = load_model(
self.tokenizer, model_args, finetuning_args, is_trainable=False, add_valuehead=(not self.can_generate)
)
self.generating_args = generating_args.to_dict()
self.max_new_tokens = model_args.kt_maxlen
self.use_cuda_graph = model_args.kt_use_cuda_graph
self.mode = model_args.kt_mode
self.force_think = model_args.kt_force_think
self.chunk_size = model_args.chunk_size
try:
asyncio.get_event_loop()
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
self.semaphore = asyncio.Semaphore(int(os.getenv("MAX_CONCURRENT", "1")))
@staticmethod
@torch.inference_mode()
def _get_scores(
model: "PreTrainedModelWrapper",
tokenizer: "PreTrainedTokenizer",
batch_input: list[str],
input_kwargs: Optional[dict[str, Any]] = {},
) -> list[float]:
max_length: Optional[int] = input_kwargs.pop("max_length", None)
device = getattr(model.pretrained_model, "device", "cuda")
inputs = tokenizer(
batch_input,
padding=True,
truncation=True,
max_length=max_length or getattr(model.config, "max_position_embeddings", 1024),
return_tensors="pt",
add_special_tokens=False,
).to(device)
values: torch.Tensor = model(**inputs, return_dict=True, use_cache=False)[-1]
scores = values.gather(dim=-1, index=(inputs["attention_mask"].sum(dim=-1, keepdim=True) - 1))
return scores
async def _generate(
self,
messages: list[dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
**input_kwargs,
) -> AsyncGenerator[str, None]:
paired = messages + [{"role": "assistant", "content": ""}]
prompt_ids, _ = self.template.encode_oneturn(self.tokenizer, paired, system, tools)
prompt_len = len(prompt_ids)
max_length: Optional[int] = input_kwargs.pop("max_length", None)
max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None)
if "max_new_tokens" in self.generating_args:
max_tokens = int(self.generating_args["max_new_tokens"])
elif "max_length" in self.generating_args:
gl = int(self.generating_args["max_length"])
max_tokens = gl - prompt_len if gl > prompt_len else 1
else:
max_tokens = self.max_new_tokens or 256
if max_length is not None:
max_tokens = max(max_length - prompt_len, 1)
if max_new_tokens is not None:
max_tokens = int(max_new_tokens)
max_tokens = max(1, int(max_tokens))
if self.mode == "long_context":
max_len_cfg = Config().long_context_config["max_seq_len"]
need = prompt_len + max_tokens
assert max_len_cfg > need, f"please set max_seq_len > {need} in ~/.ktransformers/config.yaml"
device = next(self.model.parameters()).device
input_tensor = torch.tensor([prompt_ids], dtype=torch.long, device=device)
if self.force_think:
think = torch.tensor(
[self.tokenizer.encode("<think>\n", add_special_tokens=False)], dtype=torch.long, device=device
)
input_tensor = torch.cat([input_tensor, think], dim=1)
use_flashinfer = (
platform.system() != "Windows"
and getattr(self.model.config, "architectures", [""])[0]
in {"DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"}
and flashinfer_enabled
and get_compute_capability() >= 8
and device_manager.gpu_vendor == GPUVendor.NVIDIA
)
def make_gen():
if use_flashinfer:
return prefill_and_generate_capture(
self.model,
self.tokenizer,
input_tensor,
max_tokens,
self.use_cuda_graph,
mode=self.mode,
force_think=self.force_think,
chunk_size=self.chunk_size,
use_flashinfer_mla=True,
num_heads=self.model.config.num_attention_heads,
head_dim_ckv=getattr(self.model.config, "kv_lora_rank", 0),
head_dim_kpe=getattr(self.model.config, "qk_rope_head_dim", 0),
q_head_dim=getattr(self.model.config, "qk_rope_head_dim", 0)
+ getattr(self.model.config, "qk_nope_head_dim", 0),
echo_stream=False,
)
else:
return prefill_and_generate_capture(
self.model,
self.tokenizer,
input_tensor,
max_tokens,
self.use_cuda_graph,
mode=self.mode,
force_think=self.force_think,
chunk_size=self.chunk_size,
echo_stream=False,
)
loop = asyncio.get_running_loop()
q: asyncio.Queue[Optional[str]] = asyncio.Queue()
def producer():
try:
gen = make_gen()
if hasattr(gen, "__aiter__"):
async def drain_async():
async for t in gen:
loop.call_soon_threadsafe(q.put_nowait, t if isinstance(t, str) else str(t))
asyncio.run(drain_async())
elif hasattr(gen, "__iter__"):
for t in gen:
loop.call_soon_threadsafe(q.put_nowait, t if isinstance(t, str) else str(t))
else:
loop.call_soon_threadsafe(q.put_nowait, gen if isinstance(gen, str) else str(gen))
finally:
loop.call_soon_threadsafe(q.put_nowait, None)
Thread(target=producer, daemon=True).start()
while True:
item = await q.get()
if item is None:
break
yield item
@override
async def chat(
self,
messages: list[dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
images: Optional[list["ImageInput"]] = None,
videos: Optional[list["VideoInput"]] = None,
audios: Optional[list["AudioInput"]] = None,
**input_kwargs,
) -> list["Response"]:
if not self.can_generate:
raise ValueError("The current model does not support `chat`.")
async with self.semaphore:
produced = ""
final_text = ""
async for t in self._generate(messages, system, tools, **input_kwargs):
delta = t
produced = produced + delta
if delta:
final_text += delta
prompt_ids, _ = self.template.encode_oneturn(
self.tokenizer, messages + [{"role": "assistant", "content": ""}], system, tools
)
return [
Response(
response_text=final_text,
response_length=len(self.tokenizer.encode(final_text, add_special_tokens=False)),
prompt_length=len(prompt_ids),
finish_reason="stop",
)
]
@override
async def stream_chat(
self,
messages: list[dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
images: Optional[list["ImageInput"]] = None,
videos: Optional[list["VideoInput"]] = None,
audios: Optional[list["AudioInput"]] = None,
**input_kwargs,
) -> AsyncGenerator[str, None]:
if not self.can_generate:
raise ValueError("The current model does not support `stream_chat`.")
async with self.semaphore:
produced = ""
async for t in self._generate(messages, system, tools, **input_kwargs):
delta = t[len(produced) :] if t.startswith(produced) else t
produced = t
if delta:
yield delta
@override
async def get_scores(
self,
batch_input: list[str],
**input_kwargs,
) -> list[float]:
if self.can_generate:
raise ValueError("Cannot get scores using an auto-regressive model.")
args = (self.model, self.tokenizer, batch_input, input_kwargs)
async with self.semaphore:
return await asyncio.to_thread(self._get_scores, *args)
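
A hedged sketch of consuming the streaming API above; the engine instance is assumed to have been built already (normally ChatModel constructs it when the KTransformers backend is selected):

import asyncio

# Stream tokens from an existing KTransformersEngine instance (construction omitted here).
async def stream_demo(engine) -> None:
    messages = [{"role": "user", "content": "Briefly explain MoE expert offloading."}]
    async for delta in engine.stream_chat(messages, max_new_tokens=128):
        print(delta, end="", flush=True)
    print()

# asyncio.run(stream_demo(engine))  # uncomment once an engine instance is available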

View File

@@ -16,6 +16,7 @@ import uuid
from collections.abc import AsyncGenerator, AsyncIterator
from typing import TYPE_CHECKING, Any, Optional, Union
from packaging import version
from typing_extensions import override
from ..data import get_template_and_fix_tokenizer
@@ -77,11 +78,18 @@ class VllmEngine(BaseEngine):
"tensor_parallel_size": get_device_count() or 1,
"gpu_memory_utilization": model_args.vllm_gpu_util,
"disable_log_stats": True,
-"disable_log_requests": True,
"enforce_eager": model_args.vllm_enforce_eager,
"enable_lora": model_args.adapter_name_or_path is not None,
"max_lora_rank": model_args.vllm_max_lora_rank,
}
import vllm
if version.parse(vllm.__version__) <= version.parse("0.10.0"):
engine_args["disable_log_requests"] = True
else:
engine_args["enable_log_requests"] = False
if self.template.mm_plugin.__class__.__name__ != "BasePlugin":
engine_args["limit_mm_per_prompt"] = {"image": 4, "video": 2, "audio": 2}

View File

@@ -15,7 +15,7 @@ import json
import os
from abc import abstractmethod
from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, Optional, Union
from typing import TYPE_CHECKING, Any, Union
from ..extras import logging
from .data_utils import Role
@@ -40,7 +40,7 @@ class DatasetConverter:
dataset_attr: "DatasetAttr"
data_args: "DataArguments"
-def _find_medias(self, medias: Union["MediaType", list["MediaType"], None]) -> Optional[list["MediaType"]]:
def _find_medias(self, medias: Union["MediaType", list["MediaType"], None]) -> list["MediaType"] | None:
r"""Optionally concatenate media path to media dir when loading from local disk."""
if medias is None:
return None

View File

@@ -81,41 +81,48 @@ def split_dataset(
eval_dataset: Optional[Union["Dataset", "IterableDataset", dict[str, "Dataset"]]], eval_dataset: Optional[Union["Dataset", "IterableDataset", dict[str, "Dataset"]]],
data_args: "DataArguments", data_args: "DataArguments",
seed: int, seed: int,
) -> "DatasetDict": ) -> tuple[dict, dict]:
r"""Split the dataset and returns a dataset dict containing train set and validation set. r"""Split the dataset and returns two dicts containing train set and validation set.
Support both map dataset and iterable dataset. Support both map dataset and iterable dataset.
Returns:
train_dict: Dictionary containing training data with key "train"
eval_dict: Dictionary containing evaluation data with keys "validation" or "validation_{name}"
""" """
if eval_dataset is not None and data_args.val_size > 1e-6: if eval_dataset is not None and data_args.val_size > 1e-6:
raise ValueError("Cannot specify `val_size` if `eval_dataset` is not None.") raise ValueError("Cannot specify `val_size` if `eval_dataset` is not None.")
dataset_dict = {} # the train and eval better to in dict dtype and separately return for cpode clearly and good handle outside
train_dict, eval_dict = {}, {}
if dataset is not None: if dataset is not None:
if data_args.streaming: if data_args.streaming:
dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=seed) dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=seed)
if data_args.val_size > 1e-6: if data_args.val_size > 1e-6:
if data_args.streaming: if data_args.streaming:
dataset_dict["validation"] = dataset.take(int(data_args.val_size)) eval_dict["validation"] = dataset.take(int(data_args.val_size))
dataset_dict["train"] = dataset.skip(int(data_args.val_size)) train_dict["train"] = dataset.skip(int(data_args.val_size))
else: else:
val_size = int(data_args.val_size) if data_args.val_size > 1 else data_args.val_size val_size = int(data_args.val_size) if data_args.val_size > 1 else data_args.val_size
dataset_dict = dataset.train_test_split(test_size=val_size, seed=seed) split_result = dataset.train_test_split(test_size=val_size, seed=seed)
dataset = dataset.train_test_split(test_size=val_size, seed=seed) train_dict["train"] = split_result["train"]
dataset_dict = {"train": dataset["train"], "validation": dataset["test"]} eval_dict["validation"] = split_result["test"]
else: else:
dataset_dict["train"] = dataset train_dict["train"] = dataset
if eval_dataset is not None: if eval_dataset is not None:
if isinstance(eval_dataset, dict): if isinstance(eval_dataset, dict):
dataset_dict.update({f"validation_{name}": data for name, data in eval_dataset.items()}) for name, data in eval_dataset.items():
eval_dict[f"validation_{name}"] = data
else: else:
if data_args.streaming: if data_args.streaming:
eval_dataset = eval_dataset.shuffle(buffer_size=data_args.buffer_size, seed=seed) eval_dataset = eval_dataset.shuffle(buffer_size=data_args.buffer_size, seed=seed)
dataset_dict["validation"] = eval_dataset eval_dict["validation"] = eval_dataset
return DatasetDict(dataset_dict) return train_dict, eval_dict
def get_dataset_module(dataset: Union["Dataset", "DatasetDict"]) -> "DatasetModule": def get_dataset_module(dataset: Union["Dataset", "DatasetDict"]) -> "DatasetModule":
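For illustration, a minimal sketch of the new two-dict contract with a map-style dataset; the helper name split_plain, the toy data, and the 0.2 validation ratio are examples, not part of the change.

from datasets import Dataset

def split_plain(dataset: Dataset, val_size: float, seed: int):
    # mirrors the non-streaming branch above: train_test_split returns a dict-like
    # object with "train"/"test" keys, re-labeled here as "train"/"validation"
    split_result = dataset.train_test_split(test_size=val_size, seed=seed)
    return {"train": split_result["train"]}, {"validation": split_result["test"]}

toy = Dataset.from_dict({"text": [f"sample {i}" for i in range(10)]})
train_dict, eval_dict = split_plain(toy, val_size=0.2, seed=42)
print(len(train_dict["train"]), len(eval_dict["validation"]))  # 8 2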


@@ -16,7 +16,6 @@ import json
import re import re
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Optional, Union
from typing_extensions import override from typing_extensions import override
@@ -27,14 +26,14 @@ from .tool_utils import FunctionCall, get_tool_utils
@dataclass @dataclass
class Formatter(ABC): class Formatter(ABC):
slots: SLOTS = field(default_factory=list) slots: SLOTS = field(default_factory=list)
tool_format: Optional[str] = None tool_format: str | None = None
@abstractmethod @abstractmethod
def apply(self, **kwargs) -> SLOTS: def apply(self, **kwargs) -> SLOTS:
r"""Forms a list of slots according to the inputs to encode.""" r"""Forms a list of slots according to the inputs to encode."""
... ...
def extract(self, content: str) -> Union[str, list["FunctionCall"]]: def extract(self, content: str) -> str | list["FunctionCall"]:
r"""Extract a list of tuples from the response message if using tools. r"""Extract a list of tuples from the response message if using tools.
Each tuple consists of function name and function arguments. Each tuple consists of function name and function arguments.
@@ -97,31 +96,46 @@ class FunctionFormatter(StringFormatter):
@override @override
def apply(self, **kwargs) -> SLOTS: def apply(self, **kwargs) -> SLOTS:
content: str = kwargs.pop("content") content: str = kwargs.pop("content")
thought_words, thought = kwargs.pop("thought_words", None), None thought_words = kwargs.pop("thought_words", None)
if thought_words and len(thought_words) == 2: tool_call_words = kwargs.pop("tool_call_words", None)
regex = re.compile(rf"{re.escape(thought_words[0])}(.*?){re.escape(thought_words[1])}", re.DOTALL)
thought = re.search(regex, content)
if thought: def _parse_functions(json_content: str) -> list["FunctionCall"]:
content = content.replace(thought.group(0), "") try:
tool_calls = json.loads(json_content)
if not isinstance(tool_calls, list): # parallel function call
tool_calls = [tool_calls]
functions: list[FunctionCall] = [] return [FunctionCall(tc["name"], json.dumps(tc["arguments"], ensure_ascii=False)) for tc in tool_calls]
try: except json.JSONDecodeError:
tool_calls = json.loads(content) raise RuntimeError(f"Invalid JSON format in function message: {str([content])}.")
if not isinstance(tool_calls, list): # parallel function call
tool_calls = [tool_calls]
for tool_call in tool_calls: tool_call_match = None
functions.append( if tool_call_words and len(tool_call_words) == 2:
FunctionCall(tool_call["name"], json.dumps(tool_call["arguments"], ensure_ascii=False)) tool_call_regex = re.compile(
) rf"{re.escape(tool_call_words[0])}(.*?){re.escape(tool_call_words[1])}", re.DOTALL
)
tool_call_match = re.search(tool_call_regex, content)
except json.JSONDecodeError: if tool_call_match is None:
raise RuntimeError(f"Invalid JSON format in function message: {str([content])}.") # flat string thought_match = None
if thought_words and len(thought_words) == 2:
regex = re.compile(rf"{re.escape(thought_words[0])}(.*?){re.escape(thought_words[1])}", re.DOTALL)
thought_match = re.search(regex, content)
function_str = self.tool_utils.function_formatter(functions) if thought_match:
if thought: json_part = content.replace(thought_match.group(0), "")
function_str = thought.group(0) + function_str else:
json_part = content
functions = _parse_functions(json_part)
function_str = self.tool_utils.function_formatter(functions)
if thought_match:
function_str = thought_match.group(0) + function_str
else:
thought_content = content.replace(tool_call_match.group(0), "")
functions = _parse_functions(tool_call_match.group(1))
function_str = self.tool_utils.function_formatter(functions)
function_str = thought_content + function_str
return super().apply(content=function_str) return super().apply(content=function_str)
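A self-contained sketch of the two-stage parsing the reworked apply() performs, using the default `<tool_call>`/`<think>` markers shown elsewhere in the diff; the sample content string is made up.

import json
import re

content = (
    "<think>\nuse the weather tool\n</think>\n\n"
    '<tool_call>{"name": "get_weather", "arguments": {"city": "Paris"}}</tool_call>'
)

tool_call_re = re.compile(r"<tool_call>(.*?)</tool_call>", re.DOTALL)
match = tool_call_re.search(content)
if match is not None:
    # tool-call markers present: text outside them is kept as thought content
    thought = content.replace(match.group(0), "")
    calls = json.loads(match.group(1))
else:
    # no markers: strip <think>...</think> and parse the remainder as JSON
    thought = ""
    calls = json.loads(re.sub(r"<think>\n.*?\n</think>\n\n", "", content, flags=re.DOTALL))

calls = calls if isinstance(calls, list) else [calls]
print([(call["name"], call["arguments"]) for call in calls])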
@@ -141,5 +155,5 @@ class ToolFormatter(Formatter):
raise RuntimeError(f"Invalid JSON format in tool description: {str([content])}.") # flat string raise RuntimeError(f"Invalid JSON format in tool description: {str([content])}.") # flat string
@override @override
def extract(self, content: str) -> Union[str, list["FunctionCall"]]: def extract(self, content: str) -> str | list["FunctionCall"]:
return self.tool_utils.tool_extractor(content) return self.tool_utils.tool_extractor(content)


@@ -16,7 +16,7 @@ import os
from typing import TYPE_CHECKING, Literal, Optional, Union from typing import TYPE_CHECKING, Literal, Optional, Union
import numpy as np import numpy as np
from datasets import Dataset, load_dataset, load_from_disk from datasets import Dataset, DatasetDict, load_dataset, load_from_disk
from ..extras import logging from ..extras import logging
from ..extras.constants import FILEEXT2TYPE from ..extras.constants import FILEEXT2TYPE
@@ -137,7 +137,6 @@ def _load_single_dataset(
cache_dir=model_args.cache_dir, cache_dir=model_args.cache_dir,
token=model_args.hf_hub_token, token=model_args.hf_hub_token,
num_proc=data_args.preprocessing_num_workers, num_proc=data_args.preprocessing_num_workers,
trust_remote_code=model_args.trust_remote_code,
streaming=data_args.streaming and dataset_attr.load_from != "file", streaming=data_args.streaming and dataset_attr.load_from != "file",
) )
if data_args.streaming and dataset_attr.load_from == "file": if data_args.streaming and dataset_attr.load_from == "file":
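A minimal sketch of a file-based load_dataset call with `trust_remote_code` dropped, as in the hunk above; cache_dir, token, and num_proc are omitted for brevity, and the JSON file is created on the fly so the snippet is self-contained.

import json
from datasets import load_dataset

# create a tiny throwaway file so the call below actually runs
with open("demo.json", "w") as f:
    json.dump([{"instruction": "hi", "output": "hello"}], f)

dataset = load_dataset("json", data_files="demo.json", split="train", streaming=False)
print(dataset[0])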
@@ -163,13 +162,13 @@ def _load_single_dataset(
def _get_merged_dataset( def _get_merged_dataset(
dataset_names: Optional[list[str]], dataset_names: list[str] | None,
model_args: "ModelArguments", model_args: "ModelArguments",
data_args: "DataArguments", data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments", training_args: "Seq2SeqTrainingArguments",
stage: Literal["pt", "sft", "rm", "ppo", "kto"], stage: Literal["pt", "sft", "rm", "ppo", "kto"],
return_dict: bool = False, return_dict: bool = False,
) -> Optional[Union["Dataset", "IterableDataset", dict[str, "Dataset"]]]: ) -> Union["Dataset", "IterableDataset", dict[str, "Dataset"]] | None:
r"""Return the merged datasets in the standard format.""" r"""Return the merged datasets in the standard format."""
if dataset_names is None: if dataset_names is None:
return None return None
@@ -228,7 +227,7 @@ def _get_dataset_processor(
def _get_preprocessed_dataset( def _get_preprocessed_dataset(
dataset: Optional[Union["Dataset", "IterableDataset"]], dataset: Union["Dataset", "IterableDataset"] | None,
data_args: "DataArguments", data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments", training_args: "Seq2SeqTrainingArguments",
stage: Literal["pt", "sft", "rm", "ppo", "kto"], stage: Literal["pt", "sft", "rm", "ppo", "kto"],
@@ -236,7 +235,7 @@ def _get_preprocessed_dataset(
tokenizer: "PreTrainedTokenizer", tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"] = None, processor: Optional["ProcessorMixin"] = None,
is_eval: bool = False, is_eval: bool = False,
) -> Optional[Union["Dataset", "IterableDataset"]]: ) -> Union["Dataset", "IterableDataset"] | None:
r"""Preprocesses the dataset, including format checking and tokenization.""" r"""Preprocesses the dataset, including format checking and tokenization."""
if dataset is None: if dataset is None:
return None return None
@@ -312,20 +311,22 @@ def get_dataset(
) )
with training_args.main_process_first(desc="pre-process dataset", local=(not data_args.data_shared_file_system)): with training_args.main_process_first(desc="pre-process dataset", local=(not data_args.data_shared_file_system)):
dataset = _get_preprocessed_dataset( # split first so that eval_dataset (whether provided or split off) can be preprocessed appropriately
dataset, data_args, training_args, stage, template, tokenizer, processor, is_eval=False train_dict, eval_dict = split_dataset(dataset, eval_dataset, data_args, seed=training_args.seed)
)
if isinstance(eval_dataset, dict): if "train" in train_dict:
for eval_name, eval_data in eval_dataset.items(): train_dict["train"] = _get_preprocessed_dataset(
eval_dataset[eval_name] = _get_preprocessed_dataset( train_dict["train"], data_args, training_args, stage, template, tokenizer, processor, is_eval=False
eval_data, data_args, training_args, stage, template, tokenizer, processor, is_eval=True
)
else:
eval_dataset = _get_preprocessed_dataset(
eval_dataset, data_args, training_args, stage, template, tokenizer, processor, is_eval=True
) )
dataset_dict = split_dataset(dataset, eval_dataset, data_args, seed=training_args.seed) for key in eval_dict:
eval_dict[key] = _get_preprocessed_dataset(
eval_dict[key], data_args, training_args, stage, template, tokenizer, processor, is_eval=True
)
# Combine train and eval dictionaries
dataset_dict = DatasetDict({**train_dict, **eval_dict})
if data_args.tokenized_path is not None: # save tokenized dataset to disk if data_args.tokenized_path is not None: # save tokenized dataset to disk
if training_args.should_save: if training_args.should_save:
dataset_dict.save_to_disk(data_args.tokenized_path) dataset_dict.save_to_disk(data_args.tokenized_path)
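A minimal sketch of how the two dicts are merged back into a DatasetDict and persisted, assuming map-style datasets; the path "tokenized_demo" is illustrative.

from datasets import Dataset, DatasetDict, load_from_disk

train_dict = {"train": Dataset.from_dict({"text": ["a", "b", "c"]})}
eval_dict = {"validation": Dataset.from_dict({"text": ["d"]})}

# same merge as above: train split first, then every validation split
dataset_dict = DatasetDict({**train_dict, **eval_dict})
dataset_dict.save_to_disk("tokenized_demo")  # illustrative path
print(list(load_from_disk("tokenized_demo").keys()))  # ['train', 'validation']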


@@ -22,10 +22,11 @@ import re
from copy import deepcopy from copy import deepcopy
from dataclasses import dataclass from dataclasses import dataclass
from io import BytesIO from io import BytesIO
from typing import TYPE_CHECKING, BinaryIO, Literal, Optional, TypedDict, Union from typing import TYPE_CHECKING, BinaryIO, Literal, NotRequired, Optional, TypedDict, Union
import numpy as np import numpy as np
import torch import torch
import torchaudio
from transformers.image_utils import get_image_size, is_valid_image, to_numpy_array from transformers.image_utils import get_image_size, is_valid_image, to_numpy_array
from transformers.models.mllama.processing_mllama import ( from transformers.models.mllama.processing_mllama import (
convert_sparse_cross_attention_mask_to_dense, convert_sparse_cross_attention_mask_to_dense,
@@ -34,16 +35,7 @@ from transformers.models.mllama.processing_mllama import (
from typing_extensions import override from typing_extensions import override
from ..extras.constants import AUDIO_PLACEHOLDER, IGNORE_INDEX, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER from ..extras.constants import AUDIO_PLACEHOLDER, IGNORE_INDEX, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
from ..extras.packages import ( from ..extras.packages import is_pillow_available, is_pyav_available, is_transformers_version_greater_than
is_librosa_available,
is_pillow_available,
is_pyav_available,
is_transformers_version_greater_than,
)
if is_librosa_available():
import librosa
if is_pillow_available(): if is_pillow_available():
@@ -68,15 +60,28 @@ if TYPE_CHECKING:
from transformers import PreTrainedTokenizer, ProcessorMixin from transformers import PreTrainedTokenizer, ProcessorMixin
from transformers.feature_extraction_sequence_utils import SequenceFeatureExtractor from transformers.feature_extraction_sequence_utils import SequenceFeatureExtractor
from transformers.image_processing_utils import BaseImageProcessor from transformers.image_processing_utils import BaseImageProcessor
from transformers.video_processing_utils import BaseVideoProcessor
class EncodedImage(TypedDict): class EncodedImage(TypedDict):
path: Optional[str] path: str | None
bytes: Optional[bytes] bytes: bytes | None
ImageInput = Union[str, bytes, EncodedImage, BinaryIO, ImageObject] ImageInput = Union[str, bytes, EncodedImage, BinaryIO, ImageObject]
VideoInput = Union[str, BinaryIO, list[list[ImageInput]]] VideoInput = Union[str, BinaryIO, list[list[ImageInput]]]
AudioInput = Union[str, BinaryIO, NDArray] AudioInput = Union[str, BinaryIO, NDArray]
class RegularizedImageOutput(TypedDict):
images: list[ImageObject]
class RegularizedVideoOutput(TypedDict):
videos: list[list[ImageObject]]
durations: list[float]
fps_per_video: NotRequired[list[float]]
class RegularizedAudioOutput(TypedDict):
audios: list[NDArray]
sampling_rates: list[float]
class MMProcessor(ProcessorMixin): class MMProcessor(ProcessorMixin):
patch_size: int patch_size: int
image_seq_length: int image_seq_length: int
@@ -139,9 +144,9 @@ def _check_video_is_nested_images(video: "VideoInput") -> bool:
@dataclass @dataclass
class MMPluginMixin: class MMPluginMixin:
image_token: Optional[str] image_token: str | None
video_token: Optional[str] video_token: str | None
audio_token: Optional[str] audio_token: str | None
expand_mm_tokens: bool = True expand_mm_tokens: bool = True
def _validate_input( def _validate_input(
@@ -244,7 +249,7 @@ class MMPluginMixin:
sample_frames = min(total_frames, video_maxlen, sample_frames) sample_frames = min(total_frames, video_maxlen, sample_frames)
return np.linspace(0, total_frames - 1, sample_frames).astype(np.int32) return np.linspace(0, total_frames - 1, sample_frames).astype(np.int32)
def _regularize_images(self, images: list["ImageInput"], **kwargs) -> dict[str, list["ImageObject"]]: def _regularize_images(self, images: list["ImageInput"], **kwargs) -> "RegularizedImageOutput":
r"""Regularize images to avoid error. Including reading and pre-processing.""" r"""Regularize images to avoid error. Including reading and pre-processing."""
results = [] results = []
for image in images: for image in images:
@@ -265,9 +270,10 @@ class MMPluginMixin:
return {"images": results} return {"images": results}
def _regularize_videos(self, videos: list["VideoInput"], **kwargs) -> dict[str, list[list["ImageObject"]]]: def _regularize_videos(self, videos: list["VideoInput"], **kwargs) -> "RegularizedVideoOutput":
r"""Regularizes videos to avoid error. Including reading, resizing and converting.""" r"""Regularizes videos to avoid error. Including reading, resizing and converting."""
results = [] results = []
durations = []
for video in videos: for video in videos:
frames: list[ImageObject] = [] frames: list[ImageObject] = []
if _check_video_is_nested_images(video): if _check_video_is_nested_images(video):
@@ -275,6 +281,7 @@ class MMPluginMixin:
if not is_valid_image(frame) and not isinstance(frame, dict) and not os.path.exists(frame): if not is_valid_image(frame) and not isinstance(frame, dict) and not os.path.exists(frame):
raise ValueError("Invalid image found in video frames.") raise ValueError("Invalid image found in video frames.")
frames = video frames = video
durations.append(len(frames) / kwargs.get("video_fps", 2.0))
else: else:
container = av.open(video, "r") container = av.open(video, "r")
video_stream = next(stream for stream in container.streams if stream.type == "video") video_stream = next(stream for stream in container.streams if stream.type == "video")
@@ -284,19 +291,31 @@ class MMPluginMixin:
if frame_idx in sample_indices: if frame_idx in sample_indices:
frames.append(frame.to_image()) frames.append(frame.to_image())
if video_stream.duration is None:
durations.append(len(frames) / kwargs.get("video_fps", 2.0))
else:
durations.append(float(video_stream.duration * video_stream.time_base))
frames = self._regularize_images(frames, **kwargs)["images"] frames = self._regularize_images(frames, **kwargs)["images"]
results.append(frames) results.append(frames)
return {"videos": results} return {"videos": results, "durations": durations}
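A sketch of how a clip's duration can be probed with PyAV, matching the fallback logic added above; probe_duration and the fallback fps are illustrative names, not project code.

import av

def probe_duration(path: str, fallback_fps: float = 2.0) -> float:
    container = av.open(path, "r")
    video_stream = next(s for s in container.streams if s.type == "video")
    if video_stream.duration is not None:
        # stream duration is expressed in time_base units; multiply to get seconds
        return float(video_stream.duration * video_stream.time_base)
    # some containers omit the stream duration; fall back to frame count / assumed fps
    return video_stream.frames / fallback_fps

# probe_duration("clip.mp4")  # hypothetical file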
def _regularize_audios( def _regularize_audios(
self, audios: list["AudioInput"], sampling_rate: float, **kwargs self, audios: list["AudioInput"], sampling_rate: float, **kwargs
) -> dict[str, Union[list["NDArray"], list[float]]]: ) -> "RegularizedAudioOutput":
r"""Regularizes audios to avoid error. Including reading and resampling.""" r"""Regularizes audios to avoid error. Including reading and resampling."""
results, sampling_rates = [], [] results, sampling_rates = [], []
for audio in audios: for audio in audios:
if not isinstance(audio, np.ndarray): if not isinstance(audio, np.ndarray):
audio, sampling_rate = librosa.load(audio, sr=sampling_rate) audio, sr = torchaudio.load(audio)
if audio.shape[0] > 1:
audio = audio.mean(dim=0, keepdim=True)
if sr != sampling_rate:
audio = torchaudio.functional.resample(audio, sr, sampling_rate)
audio = audio.squeeze(0).numpy()
results.append(audio) results.append(audio)
sampling_rates.append(sampling_rate) sampling_rates.append(sampling_rate)
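A sketch of the torchaudio-based loading path that replaces librosa here: load, downmix to mono, resample, convert to NumPy. The function name and the 16 kHz target are assumptions made for the example.

import torchaudio

def load_mono_audio(path: str, target_sr: int = 16000):
    waveform, sr = torchaudio.load(path)  # shape: (channels, frames)
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)  # downmix to mono
    if sr != target_sr:
        waveform = torchaudio.functional.resample(waveform, sr, target_sr)
    return waveform.squeeze(0).numpy(), target_sr

# audio, sr = load_mono_audio("speech.wav")  # hypothetical file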
@@ -309,7 +328,7 @@ class MMPluginMixin:
videos: list["VideoInput"], videos: list["VideoInput"],
audios: list["AudioInput"], audios: list["AudioInput"],
processor: "MMProcessor", processor: "MMProcessor",
imglens: Optional[list[int]] = None, imglens: list[int] | None = None,
) -> dict[str, "torch.Tensor"]: ) -> dict[str, "torch.Tensor"]:
r"""Process visual inputs. r"""Process visual inputs.
@@ -407,13 +426,13 @@ class BasePlugin(MMPluginMixin):
def process_token_ids( def process_token_ids(
self, self,
input_ids: list[int], input_ids: list[int],
labels: Optional[list[int]], labels: list[int] | None,
images: list["ImageInput"], images: list["ImageInput"],
videos: list["VideoInput"], videos: list["VideoInput"],
audios: list["AudioInput"], audios: list["AudioInput"],
tokenizer: "PreTrainedTokenizer", tokenizer: "PreTrainedTokenizer",
processor: Optional["MMProcessor"], processor: Optional["MMProcessor"],
) -> tuple[list[int], Optional[list[int]]]: ) -> tuple[list[int], list[int] | None]:
r"""Pre-process token ids after tokenization for VLMs.""" r"""Pre-process token ids after tokenization for VLMs."""
self._validate_input(processor, images, videos, audios) self._validate_input(processor, images, videos, audios)
return input_ids, labels return input_ids, labels
@@ -446,6 +465,57 @@ class BasePlugin(MMPluginMixin):
return self._get_mm_inputs(images, videos, audios, processor) return self._get_mm_inputs(images, videos, audios, processor)
@dataclass
class ErnieVLPlugin(BasePlugin):
@override
def process_messages(
self,
messages: list[dict[str, str]],
images: list["ImageInput"],
videos: list["VideoInput"],
audios: list["AudioInput"],
processor: Optional["MMProcessor"],
) -> list[dict[str, str]]:
self._validate_input(processor, images, videos, audios)
self._validate_messages(messages, images, videos, audios)
messages = deepcopy(messages)
image_processor: BaseImageProcessor = getattr(processor, "image_processor")
merge_length: int = getattr(image_processor, "merge_size") ** 2
if self.expand_mm_tokens:
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
image_grid_thw = mm_inputs.get("image_grid_thw", [])
video_grid_thw = mm_inputs.get("video_grid_thw", [])
else:
image_grid_thw = [None] * len(images)
video_grid_thw = [None] * len(videos)
image_idx, video_idx = 0, 0
for message in messages:
content = message["content"]
image_token = self.image_token or "<|IMAGE_PLACEHOLDER|>"
video_token = self.video_token or "<|VIDEO_PLACEHOLDER|>"
while IMAGE_PLACEHOLDER in content:
image_seqlen = image_grid_thw[image_idx].prod() // merge_length if self.expand_mm_tokens else 1
content = content.replace(
IMAGE_PLACEHOLDER,
f"Picture {image_idx + 1}:<|IMAGE_START|>{image_token * image_seqlen}<|IMAGE_END|>",
1,
)
image_idx += 1
while VIDEO_PLACEHOLDER in content:
video_seqlen = video_grid_thw[video_idx].prod() // merge_length if self.expand_mm_tokens else 1
content = content.replace(
VIDEO_PLACEHOLDER,
f"Video {video_idx + 1}:<|VIDEO_START|>{video_token * video_seqlen}<|VIDEO_END|>",
1,
)
video_idx += 1
message["content"] = content
return messages
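A standalone sketch of the placeholder expansion performed above, with a toy (t, h, w) grid and a merge_size of 2; the placeholder strings are local stand-ins, not the project's constants.

import numpy as np

image_placeholder = "<image>"           # dataset-side placeholder (stand-in)
image_token = "<|IMAGE_PLACEHOLDER|>"   # model-side image token

content = f"Describe this: {image_placeholder}"
grid_thw = np.array([1, 4, 4])          # toy (t, h, w) grid for one image
merge_length = 2 ** 2                   # merge_size ** 2

image_seqlen = int(grid_thw.prod()) // merge_length  # 16 // 4 = 4 tokens
content = content.replace(
    image_placeholder,
    f"Picture 1:<|IMAGE_START|>{image_token * image_seqlen}<|IMAGE_END|>",
    1,
)
print(content)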
@dataclass @dataclass
class Gemma3Plugin(BasePlugin): class Gemma3Plugin(BasePlugin):
@override @override
@@ -1235,13 +1305,13 @@ class PaliGemmaPlugin(BasePlugin):
def process_token_ids( def process_token_ids(
self, self,
input_ids: list[int], input_ids: list[int],
labels: Optional[list[int]], labels: list[int] | None,
images: list["ImageInput"], images: list["ImageInput"],
videos: list["VideoInput"], videos: list["VideoInput"],
audios: list["AudioInput"], audios: list["AudioInput"],
tokenizer: "PreTrainedTokenizer", tokenizer: "PreTrainedTokenizer",
processor: Optional["MMProcessor"], processor: Optional["MMProcessor"],
) -> tuple[list[int], Optional[list[int]]]: ) -> tuple[list[int], list[int] | None]:
self._validate_input(processor, images, videos, audios) self._validate_input(processor, images, videos, audios)
num_images = len(images) num_images = len(images)
image_seqlen = processor.image_seq_length if self.expand_mm_tokens else 0 # skip mm token image_seqlen = processor.image_seq_length if self.expand_mm_tokens else 0 # skip mm token
@@ -1418,10 +1488,8 @@ class Qwen2VLPlugin(BasePlugin):
return image return image
@override @override
def _regularize_videos( def _regularize_videos(self, videos: list["VideoInput"], **kwargs) -> "RegularizedVideoOutput":
self, videos: list["VideoInput"], **kwargs results, fps_per_video, durations = [], [], []
) -> dict[str, Union[list[list["ImageObject"]], list[float]]]:
results, fps_per_video = [], []
for video in videos: for video in videos:
frames: list[ImageObject] = [] frames: list[ImageObject] = []
if _check_video_is_nested_images(video): if _check_video_is_nested_images(video):
@@ -1431,6 +1499,7 @@ class Qwen2VLPlugin(BasePlugin):
frames = video frames = video
fps_per_video.append(kwargs.get("video_fps", 2.0)) fps_per_video.append(kwargs.get("video_fps", 2.0))
durations.append(len(frames) / kwargs.get("video_fps", 2.0))
else: else:
container = av.open(video, "r") container = av.open(video, "r")
video_stream = next(stream for stream in container.streams if stream.type == "video") video_stream = next(stream for stream in container.streams if stream.type == "video")
@@ -1442,8 +1511,10 @@ class Qwen2VLPlugin(BasePlugin):
if video_stream.duration is None: if video_stream.duration is None:
fps_per_video.append(kwargs.get("video_fps", 2.0)) fps_per_video.append(kwargs.get("video_fps", 2.0))
durations.append(len(frames) / kwargs.get("video_fps", 2.0))
else: else:
fps_per_video.append(len(sample_indices) / float(video_stream.duration * video_stream.time_base)) fps_per_video.append(len(sample_indices) / float(video_stream.duration * video_stream.time_base))
durations.append(float(video_stream.duration * video_stream.time_base))
if len(frames) % 2 != 0: if len(frames) % 2 != 0:
frames.append(frames[-1]) frames.append(frames[-1])
@@ -1451,7 +1522,7 @@ class Qwen2VLPlugin(BasePlugin):
frames = self._regularize_images(frames, **kwargs)["images"] frames = self._regularize_images(frames, **kwargs)["images"]
results.append(frames) results.append(frames)
return {"videos": results, "fps_per_video": fps_per_video} return {"videos": results, "fps_per_video": fps_per_video, "durations": durations}
@override @override
def _get_mm_inputs( def _get_mm_inputs(
@@ -1462,6 +1533,7 @@ class Qwen2VLPlugin(BasePlugin):
processor: "MMProcessor", processor: "MMProcessor",
) -> dict[str, "torch.Tensor"]: ) -> dict[str, "torch.Tensor"]:
image_processor: BaseImageProcessor = getattr(processor, "image_processor", None) image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
video_processor: BaseVideoProcessor = getattr(processor, "video_processor", None)
mm_inputs = {} mm_inputs = {}
if len(images) != 0: if len(images) != 0:
images = self._regularize_images( images = self._regularize_images(
@@ -1479,7 +1551,7 @@ class Qwen2VLPlugin(BasePlugin):
video_fps=getattr(processor, "video_fps", 2.0), video_fps=getattr(processor, "video_fps", 2.0),
video_maxlen=getattr(processor, "video_maxlen", 128), video_maxlen=getattr(processor, "video_maxlen", 128),
) )
mm_inputs.update(image_processor(images=None, videos=video_data["videos"], return_tensors="pt")) mm_inputs.update(video_processor(videos=video_data["videos"], return_tensors="pt"))
temporal_patch_size: int = getattr(image_processor, "temporal_patch_size", 2) temporal_patch_size: int = getattr(image_processor, "temporal_patch_size", 2)
if "second_per_grid_ts" in processor.model_input_names: if "second_per_grid_ts" in processor.model_input_names:
mm_inputs["second_per_grid_ts"] = [temporal_patch_size / fps for fps in video_data["fps_per_video"]] mm_inputs["second_per_grid_ts"] = [temporal_patch_size / fps for fps in video_data["fps_per_video"]]
@@ -1565,11 +1637,16 @@ class Qwen3VLPlugin(Qwen2VLPlugin):
video_maxlen=getattr(processor, "video_maxlen", 128), video_maxlen=getattr(processor, "video_maxlen", 128),
) )
video_metadata = [ video_metadata = [
{"fps": getattr(processor, "video_fps", 24.0), "duration": len(video), "total_num_frames": len(video)} {"fps": getattr(processor, "video_fps", 24.0), "duration": duration, "total_num_frames": len(video)}
for video in videos["videos"] for video, duration in zip(videos["videos"], videos["durations"])
] ]
mm_inputs.update( mm_inputs.update(
video_processor(videos=videos["videos"], video_metadata=video_metadata, return_metadata=True) video_processor(
videos=videos["videos"],
video_metadata=video_metadata,
fps=getattr(processor, "video_fps", 2.0),
return_metadata=True,
)
) )
temporal_patch_size: int = getattr(image_processor, "temporal_patch_size", 2) temporal_patch_size: int = getattr(image_processor, "temporal_patch_size", 2)
if "second_per_grid_ts" in processor.model_input_names: if "second_per_grid_ts" in processor.model_input_names:
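A minimal sketch of the per-video metadata now built from the regularizer's durations rather than the frame count; the frame lists and the 24.0 fps value are toy inputs.

video_data = {
    "videos": [["frame"] * 8, ["frame"] * 16],  # toy frame lists
    "durations": [4.0, 8.0],                    # seconds, from _regularize_videos
}
video_metadata = [
    {"fps": 24.0, "duration": duration, "total_num_frames": len(video)}
    for video, duration in zip(video_data["videos"], video_data["durations"])
]
print(video_metadata)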
@@ -1622,27 +1699,27 @@ class Qwen3VLPlugin(Qwen2VLPlugin):
num_image_tokens += 1 num_image_tokens += 1
while VIDEO_PLACEHOLDER in content: while VIDEO_PLACEHOLDER in content:
metadata = video_metadata[idx] if self.expand_mm_tokens:
timestamps = processor._calculate_timestamps( metadata = video_metadata[idx]
metadata.frames_indices, timestamps = processor._calculate_timestamps(
metadata.fps, metadata.frames_indices,
video_processor.merge_size, metadata.fps,
) video_processor.merge_size,
video_structure = ""
for frame_index in range(num_frames):
video_seqlen = (
video_grid_thw[num_video_tokens][1:].prod() // video_merge_length
if self.expand_mm_tokens
else 1
) )
timestamp_sec = timestamps[frame_index] video_structure = ""
frame_structure = ( for frame_index in range(num_frames):
f"<{timestamp_sec:.1f} seconds>" video_seqlen = (
f"{self.vision_bos_token}{self.video_token * video_seqlen}{self.vision_eos_token}" video_grid_thw[num_video_tokens][1:].prod() // video_merge_length
) if self.expand_mm_tokens
video_structure += frame_structure else 1
)
if not self.expand_mm_tokens: timestamp_sec = timestamps[frame_index]
frame_structure = (
f"<{timestamp_sec:.1f} seconds>"
f"{self.vision_bos_token}{self.video_token * video_seqlen}{self.vision_eos_token}"
)
video_structure += frame_structure
else:
video_structure = f"{self.vision_bos_token}{self.video_token}{self.vision_eos_token}" video_structure = f"{self.vision_bos_token}{self.video_token}{self.vision_eos_token}"
content = content.replace(VIDEO_PLACEHOLDER, video_structure, 1) content = content.replace(VIDEO_PLACEHOLDER, video_structure, 1)
@@ -1684,7 +1761,8 @@ class GLM4VPlugin(Qwen2VLPlugin):
) )
# prepare video metadata # prepare video metadata
video_metadata = [ video_metadata = [
{"fps": 2, "duration": len(video), "total_frames": len(video)} for video in video_data["videos"] {"fps": 2, "duration": duration, "total_frames": len(video)}
for video, duration in zip(video_data["videos"], video_data["durations"])
] ]
mm_inputs.update(video_processor(images=None, videos=video_data["videos"], video_metadata=video_metadata)) mm_inputs.update(video_processor(images=None, videos=video_data["videos"], video_metadata=video_metadata))
@@ -1797,6 +1875,7 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
processor: "MMProcessor", processor: "MMProcessor",
) -> dict[str, "torch.Tensor"]: ) -> dict[str, "torch.Tensor"]:
image_processor: BaseImageProcessor = getattr(processor, "image_processor", None) image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
video_processor: BaseVideoProcessor = getattr(processor, "video_processor", None)
feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None) feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None)
mm_inputs = {} mm_inputs = {}
if len(images) != 0: if len(images) != 0:
@@ -1815,7 +1894,7 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
video_fps=getattr(processor, "video_fps", 2.0), video_fps=getattr(processor, "video_fps", 2.0),
video_maxlen=getattr(processor, "video_maxlen", 128), video_maxlen=getattr(processor, "video_maxlen", 128),
) )
mm_inputs.update(image_processor(images=None, videos=video_dict["videos"], return_tensors="pt")) mm_inputs.update(video_processor(videos=video_dict["videos"], return_tensors="pt"))
temporal_patch_size: int = getattr(image_processor, "temporal_patch_size", 2) temporal_patch_size: int = getattr(image_processor, "temporal_patch_size", 2)
mm_inputs["video_second_per_grid"] = torch.tensor( mm_inputs["video_second_per_grid"] = torch.tensor(
[temporal_patch_size / fps for fps in video_dict["fps_per_video"]] [temporal_patch_size / fps for fps in video_dict["fps_per_video"]]
@@ -1861,8 +1940,14 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
image_grid_thw = mm_inputs.get("image_grid_thw", []) image_grid_thw = mm_inputs.get("image_grid_thw", [])
video_grid_thw = mm_inputs.get("video_grid_thw", []) video_grid_thw = mm_inputs.get("video_grid_thw", [])
if "feature_attention_mask" in mm_inputs: if "feature_attention_mask" in mm_inputs:
input_lengths = (mm_inputs["feature_attention_mask"].sum(-1).numpy() - 1) // 2 + 1 if processor.__class__.__name__ == "Qwen3OmniMoeProcessor": # for qwen3omni
audio_lengths = (input_lengths - 2) // 2 + 1 input_lengths = mm_inputs["feature_attention_mask"].sum(-1)
input_lengths_leave = input_lengths % 100
feature_lengths = (input_lengths_leave - 1) // 2 + 1
audio_lengths = ((feature_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (input_lengths // 100) * 13
else:
input_lengths = (mm_inputs["feature_attention_mask"].sum(-1).numpy() - 1) // 2 + 1
audio_lengths = (input_lengths - 2) // 2 + 1
else: else:
mm_inputs = {} mm_inputs = {}
image_grid_thw = [None] * len(images) image_grid_thw = [None] * len(images)
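A worked example of the two audio-length formulas above, using a toy 250-frame feature mask; the numbers are illustrative only.

import torch

feature_attention_mask = torch.ones(1, 250, dtype=torch.long)  # toy mask: 250 mel frames
input_lengths = feature_attention_mask.sum(-1)                  # tensor([250])

# Qwen2-Omni style: two successive halvings of the feature length
q2_features = (input_lengths - 1) // 2 + 1
qwen2_audio_lengths = (q2_features - 2) // 2 + 1                # tensor([62])

# Qwen3-Omni-MoE style: 13 tokens per full block of 100 frames plus the remainder
input_lengths_leave = input_lengths % 100
feature_lengths = (input_lengths_leave - 1) // 2 + 1
qwen3_audio_lengths = ((feature_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (input_lengths // 100) * 13
print(qwen2_audio_lengths.item(), qwen3_audio_lengths.item())   # 62 33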
@@ -2009,6 +2094,7 @@ class VideoLlavaPlugin(BasePlugin):
PLUGINS = { PLUGINS = {
"base": BasePlugin, "base": BasePlugin,
"ernie_vl": ErnieVLPlugin,
"gemma3": Gemma3Plugin, "gemma3": Gemma3Plugin,
"glm4v": GLM4VPlugin, "glm4v": GLM4VPlugin,
"gemma3n": Gemma3nPlugin, "gemma3n": Gemma3nPlugin,
@@ -2040,9 +2126,9 @@ def register_mm_plugin(name: str, plugin_class: type["BasePlugin"]) -> None:
def get_mm_plugin( def get_mm_plugin(
name: str, name: str,
image_token: Optional[str] = None, image_token: str | None = None,
video_token: Optional[str] = None, video_token: str | None = None,
audio_token: Optional[str] = None, audio_token: str | None = None,
**kwargs, **kwargs,
) -> "BasePlugin": ) -> "BasePlugin":
r"""Get plugin for multimodal inputs.""" r"""Get plugin for multimodal inputs."""


@@ -15,7 +15,7 @@
import json import json
import os import os
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Literal, Optional, Union from typing import Any, Literal
from huggingface_hub import hf_hub_download from huggingface_hub import hf_hub_download
@@ -30,43 +30,43 @@ class DatasetAttr:
# basic configs # basic configs
load_from: Literal["hf_hub", "ms_hub", "om_hub", "script", "file"] load_from: Literal["hf_hub", "ms_hub", "om_hub", "script", "file"]
dataset_name: str dataset_name: str
formatting: Literal["alpaca", "sharegpt"] = "alpaca" formatting: Literal["alpaca", "sharegpt", "openai"] = "alpaca"
ranking: bool = False ranking: bool = False
# extra configs # extra configs
subset: Optional[str] = None subset: str | None = None
split: str = "train" split: str = "train"
folder: Optional[str] = None folder: str | None = None
num_samples: Optional[int] = None num_samples: int | None = None
# common columns # common columns
system: Optional[str] = None system: str | None = None
tools: Optional[str] = None tools: str | None = None
images: Optional[str] = None images: str | None = None
videos: Optional[str] = None videos: str | None = None
audios: Optional[str] = None audios: str | None = None
# dpo columns # dpo columns
chosen: Optional[str] = None chosen: str | None = None
rejected: Optional[str] = None rejected: str | None = None
kto_tag: Optional[str] = None kto_tag: str | None = None
# alpaca columns # alpaca columns
prompt: Optional[str] = "instruction" prompt: str | None = "instruction"
query: Optional[str] = "input" query: str | None = "input"
response: Optional[str] = "output" response: str | None = "output"
history: Optional[str] = None history: str | None = None
# sharegpt columns # sharegpt columns
messages: Optional[str] = "conversations" messages: str | None = "conversations"
# sharegpt tags # sharegpt tags
role_tag: Optional[str] = "from" role_tag: str | None = "from"
content_tag: Optional[str] = "value" content_tag: str | None = "value"
user_tag: Optional[str] = "human" user_tag: str | None = "human"
assistant_tag: Optional[str] = "gpt" assistant_tag: str | None = "gpt"
observation_tag: Optional[str] = "observation" observation_tag: str | None = "observation"
function_tag: Optional[str] = "function_call" function_tag: str | None = "function_call"
system_tag: Optional[str] = "system" system_tag: str | None = "system"
def __repr__(self) -> str: def __repr__(self) -> str:
return self.dataset_name return self.dataset_name
def set_attr(self, key: str, obj: dict[str, Any], default: Optional[Any] = None) -> None: def set_attr(self, key: str, obj: dict[str, Any], default: Any | None = None) -> None:
setattr(self, key, obj.get(key, default)) setattr(self, key, obj.get(key, default))
def join(self, attr: dict[str, Any]) -> None: def join(self, attr: dict[str, Any]) -> None:
@@ -90,7 +90,7 @@ class DatasetAttr:
self.set_attr(tag, attr["tags"]) self.set_attr(tag, attr["tags"])
def get_dataset_list(dataset_names: Optional[list[str]], dataset_dir: Union[str, dict]) -> list["DatasetAttr"]: def get_dataset_list(dataset_names: list[str] | None, dataset_dir: str | dict) -> list["DatasetAttr"]:
r"""Get the attributes of the datasets.""" r"""Get the attributes of the datasets."""
if dataset_names is None: if dataset_names is None:
dataset_names = [] dataset_names = []


@@ -49,6 +49,7 @@ class Template:
default_system: str default_system: str
stop_words: list[str] stop_words: list[str]
thought_words: tuple[str, str] thought_words: tuple[str, str]
tool_call_words: tuple[str, str]
efficient_eos: bool efficient_eos: bool
replace_eos: bool replace_eos: bool
replace_jinja_template: bool replace_jinja_template: bool
@@ -156,7 +157,9 @@ class Template:
elif message["role"] == Role.OBSERVATION: elif message["role"] == Role.OBSERVATION:
elements += self.format_observation.apply(content=message["content"]) elements += self.format_observation.apply(content=message["content"])
elif message["role"] == Role.FUNCTION: elif message["role"] == Role.FUNCTION:
elements += self.format_function.apply(content=message["content"], thought_words=self.thought_words) elements += self.format_function.apply(
content=message["content"], thought_words=self.thought_words, tool_call_words=self.tool_call_words
)
else: else:
raise NotImplementedError("Unexpected role: {}".format(message["role"])) raise NotImplementedError("Unexpected role: {}".format(message["role"]))
@@ -199,9 +202,12 @@ class Template:
logger.info_rank0(f"Add pad token: {tokenizer.pad_token}") logger.info_rank0(f"Add pad token: {tokenizer.pad_token}")
if stop_words: if stop_words:
num_added_tokens = tokenizer.add_special_tokens( try:
dict(additional_special_tokens=stop_words), replace_additional_special_tokens=False num_added_tokens = tokenizer.add_special_tokens(
) dict(additional_special_tokens=stop_words), replace_additional_special_tokens=False
)
except TypeError:
num_added_tokens = tokenizer.add_special_tokens(dict(additional_special_tokens=stop_words))
logger.info_rank0("Add {} to stop words.".format(",".join(stop_words))) logger.info_rank0("Add {} to stop words.".format(",".join(stop_words)))
if num_added_tokens > 0: if num_added_tokens > 0:
logger.warning_rank0("New tokens have been added, make sure `resize_vocab` is True.") logger.warning_rank0("New tokens have been added, make sure `resize_vocab` is True.")
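A sketch of the backwards-compatible special-token registration, assuming older tokenizer versions may not accept `replace_additional_special_tokens`; the gpt2 checkpoint and the stop word are only examples.

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # illustrative checkpoint
stop_words = ["<|im_end|>"]

try:
    # newer tokenizers keep previously registered additional special tokens
    num_added_tokens = tokenizer.add_special_tokens(
        dict(additional_special_tokens=stop_words), replace_additional_special_tokens=False
    )
except TypeError:
    # older signature without the keyword argument
    num_added_tokens = tokenizer.add_special_tokens(dict(additional_special_tokens=stop_words))

print(num_added_tokens)  # 1 if the stop word was new to the vocabulary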
@@ -468,6 +474,7 @@ def register_template(
default_system: str = "", default_system: str = "",
stop_words: Optional[list[str]] = None, stop_words: Optional[list[str]] = None,
thought_words: Optional[tuple[str, str]] = None, thought_words: Optional[tuple[str, str]] = None,
tool_call_words: Optional[tuple[str, str]] = None,
efficient_eos: bool = False, efficient_eos: bool = False,
replace_eos: bool = False, replace_eos: bool = False,
replace_jinja_template: bool = False, replace_jinja_template: bool = False,
@@ -519,6 +526,7 @@ def register_template(
default_system=default_system, default_system=default_system,
stop_words=stop_words or [], stop_words=stop_words or [],
thought_words=thought_words or ("<think>\n", "\n</think>\n\n"), thought_words=thought_words or ("<think>\n", "\n</think>\n\n"),
tool_call_words=tool_call_words or ("<tool_call>", "</tool_call>"),
efficient_eos=efficient_eos, efficient_eos=efficient_eos,
replace_eos=replace_eos, replace_eos=replace_eos,
replace_jinja_template=replace_jinja_template, replace_jinja_template=replace_jinja_template,
@@ -580,6 +588,7 @@ def parse_template(tokenizer: "PreTrainedTokenizer") -> "Template":
default_system=default_system, default_system=default_system,
stop_words=[], stop_words=[],
thought_words=("<think>\n", "\n</think>\n\n"), thought_words=("<think>\n", "\n</think>\n\n"),
tool_call_words=("<tool_call>", "</tool_call>"),
efficient_eos=False, efficient_eos=False,
replace_eos=False, replace_eos=False,
replace_jinja_template=False, replace_jinja_template=False,
@@ -616,7 +625,14 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args:
logger.info_rank0(f"Using default system message: {data_args.default_system}.") logger.info_rank0(f"Using default system message: {data_args.default_system}.")
template.default_system = data_args.default_system template.default_system = data_args.default_system
template.enable_thinking = data_args.enable_thinking if isinstance(template, ReasoningTemplate):
logger.warning_rank0(
"You are using a reasoning template; "
"please add the `_nothink` suffix if the model is not a reasoning model, "
"e.g., qwen3_vl_nothink."
)
template.enable_thinking = data_args.enable_thinking
template.fix_special_tokens(tokenizer) template.fix_special_tokens(tokenizer)
template.fix_jinja_template(tokenizer) template.fix_jinja_template(tokenizer)
return template return template
@@ -956,6 +972,19 @@ register_template(
) )
register_template(
name="ernie_vl",
format_user=StringFormatter(slots=["User: {{content}}"]),
format_assistant=StringFormatter(slots=["\nAssistant: {{content}}<|end_of_sentence|>"]),
format_system=StringFormatter(slots=["{{content}}\n"]),
stop_words=["<|end_of_sentence|>"],
replace_eos=True,
replace_jinja_template=True,
template_class=ReasoningTemplate,
mm_plugin=get_mm_plugin(name="ernie_vl", image_token="<|IMAGE_PLACEHOLDER|>", video_token="<|VIDEO_PLACEHOLDER|>"),
)
register_template( register_template(
name="exaone", name="exaone",
format_user=StringFormatter(slots=["[|user|]{{content}}\n[|assistant|]"]), format_user=StringFormatter(slots=["[|user|]{{content}}\n[|assistant|]"]),
@@ -1105,7 +1134,7 @@ register_template(
# copied from glm4 template # copied from glm4 template
register_template( register_template(
name="glm4v_moe", name="glm4_5v",
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]), format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]),
format_assistant=StringFormatter(slots=["\n{{content}}"]), format_assistant=StringFormatter(slots=["\n{{content}}"]),
format_system=StringFormatter(slots=["<|system|>\n{{content}}"]), format_system=StringFormatter(slots=["<|system|>\n{{content}}"]),
@@ -1137,7 +1166,7 @@ register_template(
register_template( register_template(
name="gpt", name="gpt_oss",
format_user=StringFormatter(slots=["<|start|>user<|message|>{{content}}<|end|><|start|>assistant"]), format_user=StringFormatter(slots=["<|start|>user<|message|>{{content}}<|end|><|start|>assistant"]),
format_assistant=StringFormatter(slots=["{{content}}<|end|>"]), format_assistant=StringFormatter(slots=["{{content}}<|end|>"]),
format_system=StringFormatter(slots=["<|start|>system<|message|>{{content}}<|end|>"]), format_system=StringFormatter(slots=["<|start|>system<|message|>{{content}}<|end|>"]),
@@ -1201,10 +1230,10 @@ register_template(
register_template( register_template(
name="hunyuan", name="hunyuan",
format_user=StringFormatter(slots=["<|bos|>user\n{{content}}<|eos|>\n<|bos|>assistant\n"]), format_user=StringFormatter(slots=["{{content}}<|extra_0|>"]),
format_assistant=StringFormatter(slots=["{{content}}<|eos|>\n"]), format_assistant=StringFormatter(slots=["{{content}}<|eos|>"]),
format_system=StringFormatter(slots=["<|bos|>system\n{{content}}<|eos|>\n"]), format_system=StringFormatter(slots=["{{content}}<|extra_4|>"]),
format_prefix=EmptyFormatter(slots=["<|bos|>"]), format_prefix=EmptyFormatter(slots=["<|startoftext|>"]),
stop_words=["<|eos|>"], stop_words=["<|eos|>"],
) )
@@ -1581,6 +1610,26 @@ register_template(
template_class=ReasoningTemplate, template_class=ReasoningTemplate,
) )
# copied from qwen template
register_template(
name="mimo_v2",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen"),
format_observation=StringFormatter(
slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
),
format_tools=ToolFormatter(tool_format="qwen"),
default_system="You are MiMo, a helpful AI assistant engineered by Xiaomi.",
stop_words=["<|im_end|>"],
replace_eos=True,
thought_words=("<think>", "</think>"),
template_class=ReasoningTemplate,
)
# copied from qwen2vl # copied from qwen2vl
register_template( register_template(
name="mimo_vl", name="mimo_vl",
@@ -1664,6 +1713,19 @@ register_template(
) )
register_template(
name="ministral3",
format_user=StringFormatter(slots=["[INST]{{content}}[/INST]"]),
format_system=StringFormatter(slots=["{{content}}\n\n"]),
format_function=FunctionFormatter(slots=["[TOOL_CALLS]{{content}}", {"eos_token"}], tool_format="mistral"),
format_observation=StringFormatter(slots=["""[TOOL_RESULTS]{"content": {{content}}}[/TOOL_RESULTS]"""]),
format_tools=ToolFormatter(tool_format="mistral"),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
template_class=Llama2Template,
mm_plugin=get_mm_plugin(name="pixtral", image_token="[IMG]"),
)
register_template( register_template(
name="olmo", name="olmo",
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>\n"]), format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>\n"]),


@@ -15,7 +15,6 @@
import os import os
from collections import OrderedDict, defaultdict from collections import OrderedDict, defaultdict
from enum import Enum, unique from enum import Enum, unique
from typing import Optional
from peft.utils import SAFETENSORS_WEIGHTS_NAME as SAFE_ADAPTER_WEIGHTS_NAME from peft.utils import SAFETENSORS_WEIGHTS_NAME as SAFE_ADAPTER_WEIGHTS_NAME
from peft.utils import WEIGHTS_NAME as ADAPTER_WEIGHTS_NAME from peft.utils import WEIGHTS_NAME as ADAPTER_WEIGHTS_NAME
@@ -56,6 +55,19 @@ LAYERNORM_NAMES = {"norm", "ln"}
LLAMABOARD_CONFIG = "llamaboard_config.yaml" LLAMABOARD_CONFIG = "llamaboard_config.yaml"
MCA_SUPPORTED_MODELS = {
"deepseek_v3",
"llama",
"mistral",
"mixtral",
"qwen2",
"qwen2_vl",
"qwen2_5_vl",
"qwen3",
"qwen3_moe",
"qwen3_next",
}
METHODS = ["full", "freeze", "lora", "oft"] METHODS = ["full", "freeze", "lora", "oft"]
MOD_SUPPORTED_MODELS = {"bloom", "falcon", "gemma", "llama", "mistral", "mixtral", "phi", "starcoder2"} MOD_SUPPORTED_MODELS = {"bloom", "falcon", "gemma", "llama", "mistral", "mixtral", "phi", "starcoder2"}
@@ -101,12 +113,14 @@ class AttentionFunction(str, Enum):
DISABLED = "disabled" DISABLED = "disabled"
SDPA = "sdpa" SDPA = "sdpa"
FA2 = "fa2" FA2 = "fa2"
FA3 = "fa3"
class EngineName(str, Enum): class EngineName(str, Enum):
HF = "huggingface" HF = "huggingface"
VLLM = "vllm" VLLM = "vllm"
SGLANG = "sglang" SGLANG = "sglang"
KT = "ktransformers"
class DownloadSource(str, Enum): class DownloadSource(str, Enum):
@@ -127,6 +141,7 @@ class QuantizationMethod(str, Enum):
EETQ = "eetq" EETQ = "eetq"
HQQ = "hqq" HQQ = "hqq"
MXFP4 = "mxfp4" MXFP4 = "mxfp4"
FP8 = "fp8"
class RopeScaling(str, Enum): class RopeScaling(str, Enum):
@@ -138,7 +153,7 @@ class RopeScaling(str, Enum):
def register_model_group( def register_model_group(
models: dict[str, dict[DownloadSource, str]], models: dict[str, dict[DownloadSource, str]],
template: Optional[str] = None, template: str | None = None,
multimodal: bool = False, multimodal: bool = False,
) -> None: ) -> None:
for name, path in models.items(): for name, path in models.items():
@@ -643,6 +658,26 @@ register_model_group(
) )
register_model_group(
models={
"ERNIE-4.5-VL-28B-A3B-PT": {
DownloadSource.DEFAULT: "baidu/ERNIE-4.5-VL-28B-A3B-PT",
DownloadSource.MODELSCOPE: "PaddlePaddle/ERNIE-4.5-VL-28B-A3B-PT",
},
"ERNIE-4.5-VL-28B-A3B-Thinking": {
DownloadSource.DEFAULT: "baidu/ERNIE-4.5-VL-28B-A3B-Thinking",
DownloadSource.MODELSCOPE: "PaddlePaddle/ERNIE-4.5-VL-28B-A3B-Thinking",
},
"ERNIE-4.5-VL-424B-A47B-Base-PT": {
DownloadSource.DEFAULT: "baidu/ERNIE-4.5-VL-424B-A47B-PT",
DownloadSource.MODELSCOPE: "PaddlePaddle/ERNIE-4.5-VL-424B-A47B-PT",
},
},
template="ernie_vl",
multimodal=True,
)
register_model_group( register_model_group(
models={ models={
"EXAONE-3.0-7.8B-Instruct": { "EXAONE-3.0-7.8B-Instruct": {
@@ -969,9 +1004,17 @@ register_model_group(
"GLM-4.5V-Air-Thinking": { "GLM-4.5V-Air-Thinking": {
DownloadSource.DEFAULT: "zai-org/GLM-4.5V", DownloadSource.DEFAULT: "zai-org/GLM-4.5V",
DownloadSource.MODELSCOPE: "ZhipuAI/GLM-4.5V", DownloadSource.MODELSCOPE: "ZhipuAI/GLM-4.5V",
} },
"GLM-4.6V": {
DownloadSource.DEFAULT: "zai-org/GLM-4.6V",
DownloadSource.MODELSCOPE: "ZhipuAI/GLM-4.6V",
},
"GLM-4.6V-Flash": {
DownloadSource.DEFAULT: "zai-org/GLM-4.6V-Flash",
DownloadSource.MODELSCOPE: "ZhipuAI/GLM-4.6V-Flash",
},
}, },
template="glm4v_moe", template="glm4_5v",
multimodal=True, multimodal=True,
) )
@@ -1024,7 +1067,7 @@ register_model_group(
DownloadSource.MODELSCOPE: "openai/gpt-oss-120b", DownloadSource.MODELSCOPE: "openai/gpt-oss-120b",
}, },
}, },
template="gpt", template="gpt_oss",
) )
@@ -1152,6 +1195,10 @@ register_model_group(
DownloadSource.DEFAULT: "tencent/Hunyuan-7B-Instruct", DownloadSource.DEFAULT: "tencent/Hunyuan-7B-Instruct",
DownloadSource.MODELSCOPE: "AI-ModelScope/Hunyuan-7B-Instruct", DownloadSource.MODELSCOPE: "AI-ModelScope/Hunyuan-7B-Instruct",
}, },
"Hunyuan-MT-7B-Instruct": {
DownloadSource.DEFAULT: "tencent/Hunyuan-MT-7B",
DownloadSource.MODELSCOPE: "Tencent-Hunyuan/Hunyuan-MT-7B",
},
}, },
template="hunyuan", template="hunyuan",
) )
@@ -1756,6 +1803,21 @@ register_model_group(
) )
register_model_group(
models={
"MiMo-V2-Flash-Base": {
DownloadSource.DEFAULT: "XiaomiMiMo/MiMo-V2-Flash-Base",
DownloadSource.MODELSCOPE: "XiaomiMiMo/MiMo-V2-Flash-Base",
},
"MiMo-V2-Flash": {
DownloadSource.DEFAULT: "XiaomiMiMo/MiMo-V2-Flash",
DownloadSource.MODELSCOPE: "XiaomiMiMo/MiMo-V2-Flash",
},
},
template="mimo_v2",
)
register_model_group( register_model_group(
models={ models={
"MiMo-7B-VL-RL": { "MiMo-7B-VL-RL": {
@@ -1780,7 +1842,7 @@ register_model_group(
}, },
"MiMo-VL-7B-SFT-2508": { "MiMo-VL-7B-SFT-2508": {
DownloadSource.DEFAULT: "XiaomiMiMo/MiMo-VL-7B-SFT-2508", DownloadSource.DEFAULT: "XiaomiMiMo/MiMo-VL-7B-SFT-2508",
DownloadSource.DEFAULT: "XiaomiMiMo/MiMo-VL-7B-SFT-2508", DownloadSource.MODELSCOPE: "XiaomiMiMo/MiMo-VL-7B-SFT-2508",
}, },
}, },
template="qwen2_vl", template="qwen2_vl",
@@ -1931,6 +1993,37 @@ register_model_group(
template="mistral", template="mistral",
) )
register_model_group(
models={
"Ministral-3-3B-Base-2512": {
DownloadSource.DEFAULT: "mistralai/Ministral-3-3B-Base-2512",
DownloadSource.MODELSCOPE: "mistralai/Ministral-3-3B-Base-2512",
},
"Ministral-3-8B-Base-2512": {
DownloadSource.DEFAULT: "mistralai/Ministral-3-8B-Base-2512",
DownloadSource.MODELSCOPE: "mistralai/Ministral-3-8B-Base-2512",
},
"Ministral-3-14B-Base-2512": {
DownloadSource.DEFAULT: "mistralai/Ministral-3-14B-Base-2512",
DownloadSource.MODELSCOPE: "mistralai/Ministral-3-14B-Base-2512",
},
"Ministral-3-3B-Instruct-2512": {
DownloadSource.DEFAULT: "mistralai/Ministral-3-3B-Instruct-2512",
DownloadSource.MODELSCOPE: "mistralai/Ministral-3-3B-Instruct-2512",
},
"Ministral-3-8B-Instruct-2512": {
DownloadSource.DEFAULT: "mistralai/Ministral-3-8B-Instruct-2512",
DownloadSource.MODELSCOPE: "mistralai/Ministral-3-8B-Instruct-2512",
},
"Ministral-3-14B-Instruct-2512": {
DownloadSource.DEFAULT: "mistralai/Ministral-3-14B-Instruct-2512",
DownloadSource.MODELSCOPE: "mistralai/Ministral-3-14B-Instruct-2512",
},
},
template="ministral3",
multimodal=True,
)
register_model_group( register_model_group(
models={ models={
@@ -3193,6 +3286,10 @@ register_model_group(
register_model_group( register_model_group(
models={ models={
"Qwen3-VL-2B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen3-VL-2B-Instruct",
DownloadSource.MODELSCOPE: "Qwen/Qwen3-VL-2B-Instruct",
},
"Qwen3-VL-4B-Instruct": { "Qwen3-VL-4B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen3-VL-4B-Instruct", DownloadSource.DEFAULT: "Qwen/Qwen3-VL-4B-Instruct",
DownloadSource.MODELSCOPE: "Qwen/Qwen3-VL-4B-Instruct", DownloadSource.MODELSCOPE: "Qwen/Qwen3-VL-4B-Instruct",
@@ -3201,6 +3298,10 @@ register_model_group(
DownloadSource.DEFAULT: "Qwen/Qwen3-VL-8B-Instruct", DownloadSource.DEFAULT: "Qwen/Qwen3-VL-8B-Instruct",
DownloadSource.MODELSCOPE: "Qwen/Qwen3-VL-8B-Instruct", DownloadSource.MODELSCOPE: "Qwen/Qwen3-VL-8B-Instruct",
}, },
"Qwen3-VL-32B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen3-VL-32B-Instruct",
DownloadSource.MODELSCOPE: "Qwen/Qwen3-VL-32B-Instruct",
},
"Qwen3-VL-30B-A3B-Instruct": { "Qwen3-VL-30B-A3B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen3-VL-30B-A3B-Instruct", DownloadSource.DEFAULT: "Qwen/Qwen3-VL-30B-A3B-Instruct",
DownloadSource.MODELSCOPE: "Qwen/Qwen3-VL-30B-A3B-Instruct", DownloadSource.MODELSCOPE: "Qwen/Qwen3-VL-30B-A3B-Instruct",
@@ -3217,6 +3318,10 @@ register_model_group(
register_model_group( register_model_group(
models={ models={
"Qwen3-VL-2B-Thinking": {
DownloadSource.DEFAULT: "Qwen/Qwen3-VL-2B-Thinking",
DownloadSource.MODELSCOPE: "Qwen/Qwen3-VL-2B-Thinking",
},
"Qwen3-VL-4B-Thinking": { "Qwen3-VL-4B-Thinking": {
DownloadSource.DEFAULT: "Qwen/Qwen3-VL-4B-Thinking", DownloadSource.DEFAULT: "Qwen/Qwen3-VL-4B-Thinking",
DownloadSource.MODELSCOPE: "Qwen/Qwen3-VL-4B-Thinking", DownloadSource.MODELSCOPE: "Qwen/Qwen3-VL-4B-Thinking",
@@ -3225,6 +3330,10 @@ register_model_group(
DownloadSource.DEFAULT: "Qwen/Qwen3-VL-8B-Thinking", DownloadSource.DEFAULT: "Qwen/Qwen3-VL-8B-Thinking",
DownloadSource.MODELSCOPE: "Qwen/Qwen3-VL-8B-Thinking", DownloadSource.MODELSCOPE: "Qwen/Qwen3-VL-8B-Thinking",
}, },
"Qwen3-VL-32B-Thinking": {
DownloadSource.DEFAULT: "Qwen/Qwen3-VL-32B-Thinking",
DownloadSource.MODELSCOPE: "Qwen/Qwen3-VL-32B-Thinking",
},
"Qwen3-VL-30B-A3B-Thinking": { "Qwen3-VL-30B-A3B-Thinking": {
DownloadSource.DEFAULT: "Qwen/Qwen3-VL-30B-A3B-Thinking", DownloadSource.DEFAULT: "Qwen/Qwen3-VL-30B-A3B-Thinking",
DownloadSource.MODELSCOPE: "Qwen/Qwen3-VL-30B-A3B-Thinking", DownloadSource.MODELSCOPE: "Qwen/Qwen3-VL-30B-A3B-Thinking",
@@ -3438,6 +3547,17 @@ register_model_group(
) )
register_model_group(
models={
"VibeThinker-1.5B": {
DownloadSource.DEFAULT: "WeiboAI/VibeThinker-1.5B",
DownloadSource.MODELSCOPE: "WeiboAI/VibeThinker-1.5B",
},
},
template="qwen3",
)
register_model_group( register_model_group(
models={ models={
"Vicuna-v1.5-7B-Chat": { "Vicuna-v1.5-7B-Chat": {


@@ -117,7 +117,7 @@ def _configure_library_root_logger() -> None:
library_root_logger.propagate = False library_root_logger.propagate = False
def get_logger(name: Optional[str] = None) -> "_Logger": def get_logger(name: str | None = None) -> "_Logger":
r"""Return a logger with the specified name. It is not supposed to be accessed externally.""" r"""Return a logger with the specified name. It is not supposed to be accessed externally."""
if name is None: if name is None:
name = _get_library_name() name = _get_library_name()


@@ -313,6 +313,10 @@ def use_ray() -> bool:
return is_env_enabled("USE_RAY") return is_env_enabled("USE_RAY")
def use_kt() -> bool:
return is_env_enabled("USE_KT")
def find_available_port() -> int: def find_available_port() -> int:
r"""Find an available port on the local machine.""" r"""Find an available port on the local machine."""
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
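For reference, a sketch of the usual bind-to-port-0 pattern behind such a helper; the body below is an assumption, since the diff only shows the socket creation, and the function name is a stand-in.

import socket

def find_free_port() -> int:
    # bind to port 0 and let the OS pick an unused port
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    sock.bind(("", 0))
    port = sock.getsockname()[1]
    sock.close()
    return port

print(find_free_port())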
@@ -328,3 +332,7 @@ def fix_proxy(ipv6_enabled: bool = False) -> None:
if ipv6_enabled: if ipv6_enabled:
os.environ.pop("http_proxy", None) os.environ.pop("http_proxy", None)
os.environ.pop("HTTP_PROXY", None) os.environ.pop("HTTP_PROXY", None)
os.environ.pop("https_proxy", None)
os.environ.pop("HTTPS_PROXY", None)
os.environ.pop("all_proxy", None)
os.environ.pop("ALL_PROXY", None)


@@ -70,6 +70,10 @@ def is_matplotlib_available():
return _is_package_available("matplotlib") return _is_package_available("matplotlib")
def is_mcore_adapter_available():
return _is_package_available("mcore_adapter")
def is_pillow_available(): def is_pillow_available():
return _is_package_available("PIL") return _is_package_available("PIL")
@@ -78,6 +82,10 @@ def is_ray_available():
return _is_package_available("ray") return _is_package_available("ray")
def is_kt_available():
return _is_package_available("ktransformers")
def is_requests_available(): def is_requests_available():
return _is_package_available("requests") return _is_package_available("requests")
@@ -86,6 +94,14 @@ def is_rouge_available():
return _is_package_available("rouge_chinese") return _is_package_available("rouge_chinese")
def is_safetensors_available():
return _is_package_available("safetensors")
def is_sglang_available():
return _is_package_available("sglang")
def is_starlette_available(): def is_starlette_available():
return _is_package_available("sse_starlette") return _is_package_available("sse_starlette")
@@ -95,13 +111,14 @@ def is_transformers_version_greater_than(content: str):
return _get_package_version("transformers") >= version.parse(content) return _get_package_version("transformers") >= version.parse(content)
@lru_cache
def is_torch_version_greater_than(content: str):
return _get_package_version("torch") >= version.parse(content)
def is_uvicorn_available(): def is_uvicorn_available():
return _is_package_available("uvicorn") return _is_package_available("uvicorn")
def is_vllm_available(): def is_vllm_available():
return _is_package_available("vllm") return _is_package_available("vllm")
def is_sglang_available():
return _is_package_available("sglang")
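A self-contained approximation of the availability/version helpers referenced here, built on importlib.metadata and packaging; the real implementations in the project may differ.

from functools import lru_cache
import importlib.metadata
import importlib.util

from packaging import version

def _is_package_available(name: str) -> bool:
    return importlib.util.find_spec(name) is not None

@lru_cache
def _get_package_version(name: str) -> "version.Version":
    try:
        return version.parse(importlib.metadata.version(name))
    except importlib.metadata.PackageNotFoundError:
        return version.parse("0.0.0")

@lru_cache
def is_torch_version_greater_than(content: str) -> bool:
    return _get_package_version("torch") >= version.parse(content)

print(_is_package_available("torch"), is_torch_version_greater_than("2.0.0"))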


@@ -16,22 +16,22 @@
 # limitations under the License.
 from dataclasses import asdict, dataclass, field
-from typing import Any, Literal, Optional
+from typing import Any, Literal
 @dataclass
 class DataArguments:
     r"""Arguments pertaining to what data we are going to input our model for training and evaluation."""
-    template: Optional[str] = field(
+    template: str | None = field(
         default=None,
         metadata={"help": "Which template to use for constructing prompts in training and inference."},
     )
-    dataset: Optional[str] = field(
+    dataset: str | None = field(
         default=None,
         metadata={"help": "The name of dataset(s) to use for training. Use commas to separate multiple datasets."},
     )
-    eval_dataset: Optional[str] = field(
+    eval_dataset: str | None = field(
         default=None,
         metadata={"help": "The name of dataset(s) to use for evaluation. Use commas to separate multiple datasets."},
     )
@@ -39,7 +39,7 @@ class DataArguments:
         default="data",
         metadata={"help": "Path to the folder containing the datasets."},
     )
-    media_dir: Optional[str] = field(
+    media_dir: str | None = field(
         default=None,
         metadata={"help": "Path to the folder containing the images, videos or audios. Defaults to `dataset_dir`."},
     )
@@ -67,7 +67,7 @@ class DataArguments:
         default="concat",
         metadata={"help": "Strategy to use in dataset mixing (concat/interleave) (undersampling/oversampling)."},
     )
-    interleave_probs: Optional[str] = field(
+    interleave_probs: str | None = field(
         default=None,
         metadata={"help": "Probabilities to sample data from datasets. Use commas to separate multiple datasets."},
     )
@@ -79,15 +79,15 @@ class DataArguments:
         default=1000,
         metadata={"help": "The number of examples in one group in pre-processing."},
     )
-    preprocessing_num_workers: Optional[int] = field(
+    preprocessing_num_workers: int | None = field(
         default=None,
         metadata={"help": "The number of processes to use for the pre-processing."},
     )
-    max_samples: Optional[int] = field(
+    max_samples: int | None = field(
         default=None,
         metadata={"help": "For debugging purposes, truncate the number of examples for each dataset."},
     )
-    eval_num_beams: Optional[int] = field(
+    eval_num_beams: int | None = field(
         default=None,
         metadata={"help": "Number of beams to use for evaluation. This argument will be passed to `model.generate`"},
     )
@@ -103,7 +103,7 @@ class DataArguments:
         default=False,
         metadata={"help": "Whether or not to evaluate on each dataset separately."},
     )
-    packing: Optional[bool] = field(
+    packing: bool | None = field(
         default=None,
         metadata={"help": "Enable sequences packing in training. Will automatically enable in pre-training."},
     )
@@ -111,19 +111,19 @@ class DataArguments:
         default=False,
         metadata={"help": "Enable sequence packing without cross-attention."},
     )
-    tool_format: Optional[str] = field(
+    tool_format: str | None = field(
         default=None,
         metadata={"help": "Tool format to use for constructing function calling examples."},
     )
-    default_system: Optional[str] = field(
+    default_system: str | None = field(
         default=None,
         metadata={"help": "Override the default system message in the template."},
     )
-    enable_thinking: Optional[bool] = field(
+    enable_thinking: bool | None = field(
         default=True,
         metadata={"help": "Whether or not to enable thinking mode for reasoning models."},
     )
-    tokenized_path: Optional[str] = field(
+    tokenized_path: str | None = field(
         default=None,
         metadata={
             "help": (

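Every change in this file is the same mechanical rewrite: Optional[X] (and Union[...]) becomes the PEP 604 form X | None, which is only valid as an eagerly evaluated annotation on Python 3.10 or newer. A self-contained before/after illustration (not the project's code):

from dataclasses import dataclass, field
from typing import Optional


@dataclass
class Before:
    # Pre-3.10 spelling: Optional[str] is shorthand for Union[str, None].
    template: Optional[str] = field(default=None)


@dataclass
class After:
    # PEP 604 spelling: the `|` operator on type objects needs Python 3.10+
    # when the annotation is evaluated eagerly, as it is here.
    template: str | None = field(default=None)


print(Before().template, After().template)  # None None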
View File

@@ -14,7 +14,7 @@
 import os
 from dataclasses import dataclass, field
-from typing import Literal, Optional
+from typing import Literal
 from datasets import DownloadMode
@@ -46,7 +46,7 @@ class EvaluationArguments:
         default=5,
         metadata={"help": "Number of examplars for few-shot learning."},
     )
-    save_dir: Optional[str] = field(
+    save_dir: str | None = field(
         default=None,
         metadata={"help": "Path to save the evaluation results."},
     )

View File

@@ -13,7 +13,7 @@
 # limitations under the License.
 from dataclasses import asdict, dataclass, field
-from typing import Any, Literal, Optional
+from typing import Any, Literal
 @dataclass
@@ -40,7 +40,7 @@ class FreezeArguments:
             )
         },
     )
-    freeze_extra_modules: Optional[str] = field(
+    freeze_extra_modules: str | None = field(
         default=None,
         metadata={
             "help": (
@@ -56,7 +56,7 @@ class FreezeArguments:
 class LoraArguments:
     r"""Arguments pertaining to the LoRA training."""
-    additional_target: Optional[str] = field(
+    additional_target: str | None = field(
         default=None,
         metadata={
             "help": (
@@ -66,7 +66,7 @@ class LoraArguments:
             )
         },
     )
-    lora_alpha: Optional[int] = field(
+    lora_alpha: int | None = field(
         default=None,
         metadata={"help": "The scale factor for LoRA fine-tuning (default: lora_rank * 2)."},
     )
@@ -88,7 +88,7 @@ class LoraArguments:
             )
         },
     )
-    loraplus_lr_ratio: Optional[float] = field(
+    loraplus_lr_ratio: float | None = field(
         default=None,
         metadata={"help": "LoRA plus learning rate ratio (lr_B / lr_A)."},
     )
@@ -126,7 +126,7 @@ class LoraArguments:
 class OFTArguments:
     r"""Arguments pertaining to the OFT training."""
-    additional_target: Optional[str] = field(
+    additional_target: str | None = field(
         default=None,
         metadata={
             "help": (
@@ -220,27 +220,27 @@ class RLHFArguments:
         default=False,
         metadata={"help": "Whiten the rewards before compute advantages in PPO training."},
     )
-    ref_model: Optional[str] = field(
+    ref_model: str | None = field(
         default=None,
         metadata={"help": "Path to the reference model used for the PPO or DPO training."},
     )
-    ref_model_adapters: Optional[str] = field(
+    ref_model_adapters: str | None = field(
         default=None,
         metadata={"help": "Path to the adapters of the reference model."},
     )
-    ref_model_quantization_bit: Optional[int] = field(
+    ref_model_quantization_bit: int | None = field(
         default=None,
         metadata={"help": "The number of bits to quantize the reference model."},
     )
-    reward_model: Optional[str] = field(
+    reward_model: str | None = field(
         default=None,
         metadata={"help": "Path to the reward model used for the PPO training."},
     )
-    reward_model_adapters: Optional[str] = field(
+    reward_model_adapters: str | None = field(
         default=None,
         metadata={"help": "Path to the adapters of the reward model."},
     )
-    reward_model_quantization_bit: Optional[int] = field(
+    reward_model_quantization_bit: int | None = field(
         default=None,
         metadata={"help": "The number of bits to quantize the reward model."},
     )
@@ -248,7 +248,7 @@ class RLHFArguments:
         default="lora",
         metadata={"help": "The type of the reward model in PPO training. Lora model only supports lora training."},
     )
-    ld_alpha: Optional[float] = field(
+    ld_alpha: float | None = field(
         default=None,
         metadata={
             "help": (
@@ -361,15 +361,15 @@ class BAdamArgument:
         default="layer",
         metadata={"help": "Whether to use layer-wise or ratio-wise BAdam optimizer."},
     )
-    badam_start_block: Optional[int] = field(
+    badam_start_block: int | None = field(
         default=None,
         metadata={"help": "The starting block index for layer-wise BAdam."},
     )
-    badam_switch_mode: Optional[Literal["ascending", "descending", "random", "fixed"]] = field(
+    badam_switch_mode: Literal["ascending", "descending", "random", "fixed"] | None = field(
         default="ascending",
         metadata={"help": "the strategy of picking block to update for layer-wise BAdam."},
     )
-    badam_switch_interval: Optional[int] = field(
+    badam_switch_interval: int | None = field(
         default=50,
         metadata={
             "help": "Number of steps to update the block for layer-wise BAdam. Use -1 to disable the block update."
@@ -406,15 +406,15 @@ class SwanLabArguments:
         default=False,
         metadata={"help": "Whether or not to use the SwanLab (an experiment tracking and visualization tool)."},
     )
-    swanlab_project: Optional[str] = field(
+    swanlab_project: str | None = field(
         default="llamafactory",
         metadata={"help": "The project name in SwanLab."},
     )
-    swanlab_workspace: Optional[str] = field(
+    swanlab_workspace: str | None = field(
         default=None,
         metadata={"help": "The workspace name in SwanLab."},
     )
-    swanlab_run_name: Optional[str] = field(
+    swanlab_run_name: str | None = field(
         default=None,
         metadata={"help": "The experiment name in SwanLab."},
     )
@@ -422,19 +422,19 @@ class SwanLabArguments:
         default="cloud",
         metadata={"help": "The mode of SwanLab."},
     )
-    swanlab_api_key: Optional[str] = field(
+    swanlab_api_key: str | None = field(
         default=None,
         metadata={"help": "The API key for SwanLab."},
     )
-    swanlab_logdir: Optional[str] = field(
+    swanlab_logdir: str | None = field(
         default=None,
         metadata={"help": "The log directory for SwanLab."},
     )
-    swanlab_lark_webhook_url: Optional[str] = field(
+    swanlab_lark_webhook_url: str | None = field(
         default=None,
         metadata={"help": "The Lark(飞书) webhook URL for SwanLab."},
     )
-    swanlab_lark_secret: Optional[str] = field(
+    swanlab_lark_secret: str | None = field(
         default=None,
         metadata={"help": "The Lark(飞书) secret for SwanLab."},
     )
@@ -461,7 +461,7 @@ class FinetuningArguments(
         default="sft",
         metadata={"help": "Which stage will be performed in training."},
     )
-    finetuning_type: Literal["lora", "freeze", "full"] = field(
+    finetuning_type: Literal["lora", "oft", "freeze", "full"] = field(
         default="lora",
         metadata={"help": "Which fine-tuning method to use."},
     )
@@ -473,6 +473,15 @@ class FinetuningArguments(
         default=False,
         metadata={"help": "Whether or not to use the Adam-mini optimizer."},
     )
+    use_mca: bool = field(
+        default=False,
+        metadata={
+            "help": (
+                "Whether or not to use MCA (Megatron Core Adapter) training. "
+                "Controlled by USE_MCA environment variable."
+            )
+        },
+    )
     use_muon: bool = field(
         default=False,
         metadata={"help": "Whether or not to use the Muon optimizer."},
@@ -501,7 +510,7 @@ class FinetuningArguments(
         default=False,
         metadata={"help": "Whether or not to disable the shuffling of the training set."},
     )
-    early_stopping_steps: Optional[int] = field(
+    early_stopping_steps: int | None = field(
         default=None,
         metadata={"help": "Number of steps to stop training if the `metric_for_best_model` does not improve."},
     )
@@ -521,11 +530,11 @@ class FinetuningArguments(
             return arg
         self.freeze_trainable_modules: list[str] = split_arg(self.freeze_trainable_modules)
-        self.freeze_extra_modules: Optional[list[str]] = split_arg(self.freeze_extra_modules)
+        self.freeze_extra_modules: list[str] | None = split_arg(self.freeze_extra_modules)
         self.lora_alpha: int = self.lora_alpha or self.lora_rank * 2
         self.lora_target: list[str] = split_arg(self.lora_target)
         self.oft_target: list[str] = split_arg(self.oft_target)
-        self.additional_target: Optional[list[str]] = split_arg(self.additional_target)
+        self.additional_target: list[str] | None = split_arg(self.additional_target)
         self.galore_target: list[str] = split_arg(self.galore_target)
         self.apollo_target: list[str] = split_arg(self.apollo_target)
         self.use_ref_model = self.stage == "dpo" and self.pref_loss not in ["orpo", "simpo"]

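The __post_init__ hunk leans on a nested split_arg helper whose body sits outside this diff; judging from the surrounding help texts ("Use commas to separate multiple ..."), it presumably turns a comma-separated string into a list and passes lists or None through unchanged. A hypothetical sketch of that behaviour:

def split_arg(arg):
    # Assumed behaviour: comma-separated string -> list of stripped names,
    # anything else (None or an already-split list) is returned as-is.
    if isinstance(arg, str):
        return [item.strip() for item in arg.split(",")]
    return arg


print(split_arg("q_proj, v_proj"))  # ['q_proj', 'v_proj']
print(split_arg(None))              # None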
View File

@@ -1,4 +1,4 @@
-# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
+# Copyright 2025 HuggingFace Inc., the KVCache.AI team, Approaching AI, and the LlamaFactory team.
 #
 # This code is inspired by the HuggingFace's transformers library.
 # https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py
@@ -17,29 +17,30 @@
 import json
 from dataclasses import asdict, dataclass, field, fields
-from typing import Any, Literal, Optional, Union
+from typing import Any, Literal, Self
 import torch
-from transformers.training_args import _convert_str_dict
-from typing_extensions import Self
 from omegaconf import OmegaConf
+from transformers.training_args import _convert_str_dict
 from ..extras.constants import AttentionFunction, EngineName, QuantizationMethod, RopeScaling
 from ..extras.logging import get_logger
 logger = get_logger(__name__)
 @dataclass
 class BaseModelArguments:
     r"""Arguments pertaining to the model."""
-    model_name_or_path: Optional[str] = field(
+    model_name_or_path: str | None = field(
         default=None,
         metadata={
             "help": "Path to the model weight or identifier from huggingface.co/models or modelscope.cn/models."
         },
     )
-    adapter_name_or_path: Optional[str] = field(
+    adapter_name_or_path: str | None = field(
         default=None,
         metadata={
             "help": (
@@ -48,11 +49,11 @@ class BaseModelArguments:
             )
         },
     )
-    adapter_folder: Optional[str] = field(
+    adapter_folder: str | None = field(
         default=None,
         metadata={"help": "The folder containing the adapter weights to load."},
     )
-    cache_dir: Optional[str] = field(
+    cache_dir: str | None = field(
         default=None,
         metadata={"help": "Where to store the pre-trained models downloaded from huggingface.co or modelscope.cn."},
     )
@@ -68,17 +69,17 @@ class BaseModelArguments:
         default=False,
         metadata={"help": "Whether or not the special tokens should be split during the tokenization process."},
     )
-    add_tokens: Optional[str] = field(
+    add_tokens: str | None = field(
         default=None,
         metadata={
             "help": "Non-special tokens to be added into the tokenizer. Use commas to separate multiple tokens."
         },
     )
-    add_special_tokens: Optional[str] = field(
+    add_special_tokens: str | None = field(
         default=None,
         metadata={"help": "Special tokens to be added into the tokenizer. Use commas to separate multiple tokens."},
     )
-    new_special_tokens_config: Optional[str] = field(
+    new_special_tokens_config: str | None = field(
         default=None,
         metadata={
             "help": (
@@ -108,7 +109,7 @@ class BaseModelArguments:
         default=True,
         metadata={"help": "Whether or not to use memory-efficient model loading."},
     )
-    rope_scaling: Optional[RopeScaling] = field(
+    rope_scaling: RopeScaling | None = field(
         default=None,
         metadata={"help": "Which scaling strategy should be adopted for the RoPE embeddings."},
     )
@@ -120,7 +121,7 @@ class BaseModelArguments:
         default=False,
         metadata={"help": "Enable shift short attention (S^2-Attn) proposed by LongLoRA."},
     )
-    mixture_of_depths: Optional[Literal["convert", "load"]] = field(
+    mixture_of_depths: Literal["convert", "load"] | None = field(
         default=None,
         metadata={"help": "Convert the model to mixture-of-depths (MoD) or load the MoD model."},
     )
@@ -136,7 +137,7 @@ class BaseModelArguments:
         default=False,
         metadata={"help": "Whether or not to enable liger kernel for faster training."},
     )
-    moe_aux_loss_coef: Optional[float] = field(
+    moe_aux_loss_coef: float | None = field(
         default=None,
         metadata={"help": "Coefficient of the auxiliary router loss in mixture-of-experts model."},
     )
@@ -168,23 +169,27 @@ class BaseModelArguments:
         default="offload",
         metadata={"help": "Path to offload model weights."},
     )
-    use_cache: bool = field(
+    use_kv_cache: bool = field(
         default=True,
         metadata={"help": "Whether or not to use KV cache in generation."},
     )
+    use_v1_kernels: bool = field(
+        default=False,
+        metadata={"help": "Whether or not to use high-performance kernels in training."},
+    )
     infer_dtype: Literal["auto", "float16", "bfloat16", "float32"] = field(
         default="auto",
         metadata={"help": "Data type for model weights and activations at inference."},
     )
-    hf_hub_token: Optional[str] = field(
+    hf_hub_token: str | None = field(
         default=None,
         metadata={"help": "Auth token to log in with Hugging Face Hub."},
     )
-    ms_hub_token: Optional[str] = field(
+    ms_hub_token: str | None = field(
         default=None,
         metadata={"help": "Auth token to log in with ModelScope Hub."},
    )
-    om_hub_token: Optional[str] = field(
+    om_hub_token: str | None = field(
         default=None,
         metadata={"help": "Auth token to log in with Modelers Hub."},
     )
@@ -277,7 +282,7 @@ class QuantizationArguments:
         default=QuantizationMethod.BNB,
         metadata={"help": "Quantization method to use for on-the-fly quantization."},
     )
-    quantization_bit: Optional[int] = field(
+    quantization_bit: int | None = field(
         default=None,
         metadata={"help": "The number of bits to quantize the model using on-the-fly quantization."},
     )
@@ -289,7 +294,7 @@ class QuantizationArguments:
         default=True,
         metadata={"help": "Whether or not to use double quantization in bitsandbytes int4 training."},
     )
-    quantization_device_map: Optional[Literal["auto"]] = field(
+    quantization_device_map: Literal["auto"] | None = field(
         default=None,
         metadata={"help": "Device map used to infer the 4-bit quantized model, needs bitsandbytes>=0.43.0."},
     )
@@ -369,7 +374,7 @@ class ProcessorArguments:
 class ExportArguments:
     r"""Arguments pertaining to the model export."""
-    export_dir: Optional[str] = field(
+    export_dir: str | None = field(
         default=None,
         metadata={"help": "Path to the directory to save the exported model."},
     )
@@ -381,11 +386,11 @@ class ExportArguments:
         default="cpu",
         metadata={"help": "The device used in model export, use `auto` to accelerate exporting."},
     )
-    export_quantization_bit: Optional[int] = field(
+    export_quantization_bit: int | None = field(
         default=None,
         metadata={"help": "The number of bits to quantize the exported model."},
     )
-    export_quantization_dataset: Optional[str] = field(
+    export_quantization_dataset: str | None = field(
         default=None,
         metadata={"help": "Path to the dataset or dataset name to use in quantizing the exported model."},
     )
@@ -401,7 +406,7 @@ class ExportArguments:
         default=False,
         metadata={"help": "Whether or not to save the `.bin` files instead of `.safetensors`."},
     )
-    export_hub_model_id: Optional[str] = field(
+    export_hub_model_id: str | None = field(
         default=None,
         metadata={"help": "The name of the repository if push the model to the Hugging Face hub."},
     )
@@ -431,7 +436,7 @@ class VllmArguments:
         default=32,
         metadata={"help": "Maximum rank of all LoRAs in the vLLM engine."},
     )
-    vllm_config: Optional[Union[dict, str]] = field(
+    vllm_config: dict | str | None = field(
         default=None,
         metadata={"help": "Config to initialize the vllm engine. Please use JSON strings."},
     )
@@ -457,7 +462,7 @@ class SGLangArguments:
         default=-1,
         metadata={"help": "Tensor parallel size for the SGLang engine."},
     )
-    sglang_config: Optional[Union[dict, str]] = field(
+    sglang_config: dict | str | None = field(
         default=None,
         metadata={"help": "Config to initialize the SGLang engine. Please use JSON strings."},
     )
@@ -473,26 +478,77 @@ class SGLangArguments:
             self.sglang_config = _convert_str_dict(json.loads(self.sglang_config))
+@dataclass
+class KTransformersArguments:
+    r"""Arguments pertaining to the KT training."""
+    use_kt: bool = field(
+        default=False,
+        metadata={"help": "Whether To Use KTransformers Optimizations For LoRA Training."},
+    )
+    kt_optimize_rule: str | None = field(
+        default=None,
+        metadata={
+            "help": "Path To The KTransformers Optimize Rule; See https://github.com/kvcache-ai/ktransformers/."
+        },
+    )
+    cpu_infer: int | None = field(
+        default=32,
+        metadata={"help": "Number Of CPU Cores Used For Computation."},
+    )
+    chunk_size: int | None = field(
+        default=8192,
+        metadata={"help": "Chunk Size Used For CPU Compute In KTransformers."},
+    )
+    mode: str | None = field(
+        default="normal",
+        metadata={"help": "Normal Or Long_Context For Llama Models."},
+    )
+    kt_maxlen: int = field(
+        default=4096,
+        metadata={"help": "Maximum Sequence (Prompt + Response) Length Of The KT Engine."},
+    )
+    kt_use_cuda_graph: bool = field(
+        default=True,
+        metadata={"help": "Whether To Use CUDA Graphs For The KT Engine."},
+    )
+    kt_mode: str = field(
+        default="normal",
+        metadata={"help": "Normal Or Long_Context Mode For The KT Engine."},
+    )
+    kt_force_think: bool = field(
+        default=False,
+        metadata={"help": "Force-Think Toggle For The KT Engine."},
+    )
 @dataclass
 class ModelArguments(
-    SGLangArguments, VllmArguments, ExportArguments, ProcessorArguments, QuantizationArguments, BaseModelArguments
+    SGLangArguments,
+    VllmArguments,
+    KTransformersArguments,
+    ExportArguments,
+    ProcessorArguments,
+    QuantizationArguments,
+    BaseModelArguments,
 ):
     r"""Arguments pertaining to which model/config/tokenizer we are going to fine-tune or infer.
     The class on the most right will be displayed first.
     """
-    compute_dtype: Optional[torch.dtype] = field(
+    compute_dtype: torch.dtype | None = field(
         default=None,
         init=False,
         metadata={"help": "Torch data type for computing model outputs, derived from `fp/bf16`. Do not specify it."},
     )
-    device_map: Optional[Union[str, dict[str, Any]]] = field(
+    device_map: str | dict[str, Any] | None = field(
         default=None,
         init=False,
         metadata={"help": "Device map for model placement, derived from training stage. Do not specify it."},
     )
-    model_max_length: Optional[int] = field(
+    model_max_length: int | None = field(
         default=None,
         init=False,
         metadata={"help": "The maximum input length for model, derived from `cutoff_len`. Do not specify it."},

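Taken together, the new KTransformersArguments mixin means the KT-specific switches live on the same ModelArguments object as everything else. A hypothetical construction showing the new fields (the import path follows the hparams layout implied by the diff, and the model id and rule path are placeholders):

from llamafactory.hparams.model_args import ModelArguments  # path assumed from the diff

args = ModelArguments(
    model_name_or_path="deepseek-ai/DeepSeek-V2-Lite",  # placeholder model id
    use_kt=True,                                         # turn on KTransformers LoRA optimizations
    kt_optimize_rule="rules/deepseek-v2-lora.yaml",      # placeholder optimize-rule file
    cpu_infer=32,                                        # CPU cores for the CPU-side expert compute
    kt_maxlen=4096,                                      # prompt + response budget for the KT engine
    use_kv_cache=True,                                   # note: renamed from `use_cache` in this change
)
print(args.use_kt, args.chunk_size)  # True 8192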
View File

@@ -18,7 +18,7 @@
 import os
 import sys
 from pathlib import Path
-from typing import Any, Optional, Union
+from typing import Any, Optional
 import torch
 import transformers
@@ -32,7 +32,7 @@ from transformers.utils import is_torch_bf16_gpu_available, is_torch_npu_availab
 from ..extras import logging
 from ..extras.constants import CHECKPOINT_NAMES, EngineName
 from ..extras.misc import check_dependencies, check_version, get_current_device, is_env_enabled
-from ..extras.packages import is_transformers_version_greater_than
+from ..extras.packages import is_mcore_adapter_available, is_transformers_version_greater_than
 from .data_args import DataArguments
 from .evaluation_args import EvaluationArguments
 from .finetuning_args import FinetuningArguments
@@ -53,8 +53,19 @@ _INFER_CLS = tuple[ModelArguments, DataArguments, FinetuningArguments, Generatin
 _EVAL_ARGS = [ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments]
 _EVAL_CLS = tuple[ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments]
+if is_mcore_adapter_available() and is_env_enabled("USE_MCA"):
+    from mcore_adapter import TrainingArguments as McaTrainingArguments
+    _TRAIN_MCA_ARGS = [ModelArguments, DataArguments, McaTrainingArguments, FinetuningArguments, GeneratingArguments]
+    _TRAIN_MCA_CLS = tuple[
+        ModelArguments, DataArguments, McaTrainingArguments, FinetuningArguments, GeneratingArguments
+    ]
+else:
+    _TRAIN_MCA_ARGS = []
+    _TRAIN_MCA_CLS = tuple()
-def read_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> Union[dict[str, Any], list[str]]:
+def read_args(args: dict[str, Any] | list[str] | None = None) -> dict[str, Any] | list[str]:
     r"""Get arguments from the command line or a config file."""
     if args is not None:
         return args
@@ -72,7 +83,7 @@ def read_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> Union[
 def _parse_args(
-    parser: "HfArgumentParser", args: Optional[Union[dict[str, Any], list[str]]] = None, allow_extra_keys: bool = False
+    parser: "HfArgumentParser", args: dict[str, Any] | list[str] | None = None, allow_extra_keys: bool = False
 ) -> tuple[Any]:
     args = read_args(args)
     if isinstance(args, dict):
@@ -145,6 +156,9 @@ def _check_extra_dependencies(
     finetuning_args: "FinetuningArguments",
     training_args: Optional["TrainingArguments"] = None,
 ) -> None:
+    if model_args.use_kt:
+        check_version("ktransformers", mandatory=True)
     if model_args.use_unsloth:
         check_version("unsloth", mandatory=True)
@@ -191,32 +205,57 @@ def _check_extra_dependencies(
         check_version("rouge_chinese", mandatory=True)
-def _parse_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _TRAIN_CLS:
+def _parse_train_args(args: dict[str, Any] | list[str] | None = None) -> _TRAIN_CLS:
     parser = HfArgumentParser(_TRAIN_ARGS)
     allow_extra_keys = is_env_enabled("ALLOW_EXTRA_ARGS")
     return _parse_args(parser, args, allow_extra_keys=allow_extra_keys)
+def _parse_train_mca_args(args: dict[str, Any] | list[str] | None = None) -> _TRAIN_MCA_CLS:
+    parser = HfArgumentParser(_TRAIN_MCA_ARGS)
+    allow_extra_keys = is_env_enabled("ALLOW_EXTRA_ARGS")
+    model_args, data_args, training_args, finetuning_args, generating_args = _parse_args(
+        parser, args, allow_extra_keys=allow_extra_keys
+    )
+    _configure_mca_training_args(training_args, data_args, finetuning_args)
+    return model_args, data_args, training_args, finetuning_args, generating_args
+def _configure_mca_training_args(training_args, data_args, finetuning_args) -> None:
+    """Patch training args to avoid args checking errors and sync MCA settings."""
+    training_args.predict_with_generate = False
+    training_args.generation_max_length = data_args.cutoff_len
+    training_args.generation_num_beams = 1
+    training_args.use_mca = True
+    finetuning_args.use_mca = True
-def _parse_infer_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _INFER_CLS:
+def _parse_infer_args(args: dict[str, Any] | list[str] | None = None) -> _INFER_CLS:
     parser = HfArgumentParser(_INFER_ARGS)
     allow_extra_keys = is_env_enabled("ALLOW_EXTRA_ARGS")
     return _parse_args(parser, args, allow_extra_keys=allow_extra_keys)
-def _parse_eval_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _EVAL_CLS:
+def _parse_eval_args(args: dict[str, Any] | list[str] | None = None) -> _EVAL_CLS:
     parser = HfArgumentParser(_EVAL_ARGS)
     allow_extra_keys = is_env_enabled("ALLOW_EXTRA_ARGS")
     return _parse_args(parser, args, allow_extra_keys=allow_extra_keys)
-def get_ray_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> RayArguments:
+def get_ray_args(args: dict[str, Any] | list[str] | None = None) -> RayArguments:
     parser = HfArgumentParser(RayArguments)
     (ray_args,) = _parse_args(parser, args, allow_extra_keys=True)
     return ray_args
-def get_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _TRAIN_CLS:
-    model_args, data_args, training_args, finetuning_args, generating_args = _parse_train_args(args)
+def get_train_args(args: dict[str, Any] | list[str] | None = None) -> _TRAIN_CLS:
+    if is_env_enabled("USE_MCA"):
+        model_args, data_args, training_args, finetuning_args, generating_args = _parse_train_mca_args(args)
+    else:
+        model_args, data_args, training_args, finetuning_args, generating_args = _parse_train_args(args)
+        finetuning_args.use_mca = False
     # Setup logging
     if training_args.should_log:
@@ -246,13 +285,16 @@ def get_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _
         if model_args.shift_attn:
             raise ValueError("PPO training is incompatible with S^2-Attn.")
+        if finetuning_args.reward_model_type == "lora" and model_args.use_kt:
+            raise ValueError("KTransformers does not support lora reward model.")
         if finetuning_args.reward_model_type == "lora" and model_args.use_unsloth:
             raise ValueError("Unsloth does not support lora reward model.")
         if training_args.report_to and training_args.report_to[0] not in ["wandb", "tensorboard"]:
             raise ValueError("PPO only accepts wandb or tensorboard logger.")
-    if training_args.parallel_mode == ParallelMode.NOT_DISTRIBUTED:
+    if not model_args.use_kt and training_args.parallel_mode == ParallelMode.NOT_DISTRIBUTED:
         raise ValueError("Please launch distributed training with `llamafactory-cli` or `torchrun`.")
     if training_args.deepspeed and training_args.parallel_mode != ParallelMode.DISTRIBUTED:
@@ -264,18 +306,15 @@ def get_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _
     if training_args.do_train and data_args.dataset is None:
         raise ValueError("Please specify dataset for training.")
-    if (training_args.do_eval or training_args.do_predict) and (
+    if (training_args.do_eval or training_args.do_predict or training_args.predict_with_generate) and (
         data_args.eval_dataset is None and data_args.val_size < 1e-6
     ):
-        raise ValueError("Please specify dataset for evaluation.")
+        raise ValueError("Please make sure eval_dataset be provided or val_size >1e-6")
     if training_args.predict_with_generate:
         if is_deepspeed_zero3_enabled():
             raise ValueError("`predict_with_generate` is incompatible with DeepSpeed ZeRO-3.")
-        if data_args.eval_dataset is None:
-            raise ValueError("Cannot use `predict_with_generate` if `eval_dataset` is None.")
         if finetuning_args.compute_accuracy:
             raise ValueError("Cannot use `predict_with_generate` and `compute_accuracy` together.")
@@ -314,6 +353,9 @@ def get_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _
     if model_args.use_unsloth and is_deepspeed_zero3_enabled():
         raise ValueError("Unsloth is incompatible with DeepSpeed ZeRO-3.")
+    if model_args.use_kt and is_deepspeed_zero3_enabled():
+        raise ValueError("KTransformers is incompatible with DeepSpeed ZeRO-3.")
     if data_args.neat_packing and is_transformers_version_greater_than("4.53.0"):
         raise ValueError("Neat packing is incompatible with transformers>=4.53.0.")
@@ -431,7 +473,7 @@ def get_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _
     return model_args, data_args, training_args, finetuning_args, generating_args
-def get_infer_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _INFER_CLS:
+def get_infer_args(args: dict[str, Any] | list[str] | None = None) -> _INFER_CLS:
     model_args, data_args, finetuning_args, generating_args = _parse_infer_args(args)
     # Setup logging
@@ -466,7 +508,7 @@ def get_infer_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _
     return model_args, data_args, finetuning_args, generating_args
-def get_eval_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _EVAL_CLS:
+def get_eval_args(args: dict[str, Any] | list[str] | None = None) -> _EVAL_CLS:
     model_args, data_args, eval_args, finetuning_args = _parse_eval_args(args)
     # Setup logging

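Whether the MCA (mcore_adapter) argument classes are used is decided purely by the USE_MCA environment variable, so the same dict-based entry point drives either branch. A hedged sketch (assumes mcore_adapter is installed and the process was started through llamafactory-cli/torchrun, as the distributed-training check expects; the config keys are illustrative, not a complete recipe):

import os

# Must be set before importing the hparams modules, because the base training
# arguments class is chosen at import time.
os.environ["USE_MCA"] = "1"

from llamafactory.hparams.parser import get_train_args  # import path assumed from the diff

model_args, data_args, training_args, finetuning_args, generating_args = get_train_args(
    {
        "model_name_or_path": "meta-llama/Meta-Llama-3-8B-Instruct",  # placeholder
        "stage": "sft",
        "do_train": True,
        "finetuning_type": "lora",
        "dataset": "identity",
        "template": "llama3",
        "output_dir": "saves/llama3-mca-sft",
        "cutoff_len": 2048,
    }
)
# With USE_MCA=1, _configure_mca_training_args() has already synced the flags:
# training_args.use_mca and finetuning_args.use_mca are True,
# and predict_with_generate has been forced off.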
View File

@@ -14,19 +14,33 @@
 import json
 from dataclasses import dataclass, field
-from typing import Literal, Optional, Union
+from typing import Literal
 from transformers import Seq2SeqTrainingArguments
 from transformers.training_args import _convert_str_dict
-from ..extras.misc import use_ray
+from ..extras.misc import is_env_enabled, use_ray
+from ..extras.packages import is_mcore_adapter_available
+if is_env_enabled("USE_MCA"):
+    if not is_mcore_adapter_available():
+        raise ImportError(
+            "mcore_adapter is required when USE_MCA=1. Please install `mcore_adapter` and its dependencies."
+        )
+    from mcore_adapter import Seq2SeqTrainingArguments as McaSeq2SeqTrainingArguments
+    BaseTrainingArguments = McaSeq2SeqTrainingArguments
+else:
+    BaseTrainingArguments = Seq2SeqTrainingArguments
 @dataclass
 class RayArguments:
     r"""Arguments pertaining to the Ray training."""
-    ray_run_name: Optional[str] = field(
+    ray_run_name: str | None = field(
         default=None,
         metadata={"help": "The training results will be saved at `<ray_storage_path>/ray_run_name`."},
     )
@@ -34,7 +48,7 @@ class RayArguments:
         default="./saves",
         metadata={"help": "The storage path to save training results to"},
     )
-    ray_storage_filesystem: Optional[Literal["s3", "gs", "gcs"]] = field(
+    ray_storage_filesystem: Literal["s3", "gs", "gcs"] | None = field(
         default=None,
         metadata={"help": "The storage filesystem to use. If None specified, local filesystem will be used."},
     )
@@ -42,7 +56,7 @@ class RayArguments:
         default=1,
         metadata={"help": "The number of workers for Ray training. Default is 1 worker."},
     )
-    resources_per_worker: Union[dict, str] = field(
+    resources_per_worker: dict | str = field(
         default_factory=lambda: {"GPU": 1},
         metadata={"help": "The resources per worker for Ray training. Default is to use 1 GPU per worker."},
     )
@@ -50,7 +64,7 @@ class RayArguments:
         default="PACK",
         metadata={"help": "The placement strategy for Ray training. Default is PACK."},
     )
-    ray_init_kwargs: Optional[Union[dict, str]] = field(
+    ray_init_kwargs: dict | str | None = field(
         default=None,
         metadata={"help": "The arguments to pass to ray.init for Ray training. Default is None."},
     )
@@ -78,9 +92,14 @@ class RayArguments:
 @dataclass
-class TrainingArguments(RayArguments, Seq2SeqTrainingArguments):
+class TrainingArguments(RayArguments, BaseTrainingArguments):
     r"""Arguments pertaining to the trainer."""
+    overwrite_output_dir: bool = field(
+        default=False,
+        metadata={"help": "deprecated"},
+    )
     def __post_init__(self):
-        Seq2SeqTrainingArguments.__post_init__(self)
         RayArguments.__post_init__(self)
+        BaseTrainingArguments.__post_init__(self)

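The __post_init__ rewrite matters because dataclasses never chain __post_init__ across base classes automatically; each parent hook has to be invoked explicitly, and the explicit call now targets whichever class BaseTrainingArguments resolved to. A generic illustration with made-up class names (not the project's classes):

from dataclasses import dataclass


@dataclass
class Base:
    def __post_init__(self):
        print("Base.__post_init__")


@dataclass
class Mixin:
    def __post_init__(self):
        print("Mixin.__post_init__")


@dataclass
class Combined(Mixin, Base):
    def __post_init__(self):
        # Without these explicit calls, neither parent hook would run from this override.
        Mixin.__post_init__(self)
        Base.__post_init__(self)


Combined()  # prints Mixin.__post_init__ then Base.__post_init__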
View File

@@ -38,7 +38,7 @@ USAGE = (
 def launch():
     from .extras import logging
     from .extras.env import VERSION, print_env
-    from .extras.misc import find_available_port, get_device_count, is_env_enabled, use_ray
+    from .extras.misc import find_available_port, get_device_count, is_env_enabled, use_kt, use_ray
     logger = logging.get_logger(__name__)
     WELCOME = (
@@ -54,7 +54,12 @@ def launch():
     )
     command = sys.argv.pop(1) if len(sys.argv) > 1 else "help"
-    if command == "train" and (is_env_enabled("FORCE_TORCHRUN") or (get_device_count() > 1 and not use_ray())):
+    if is_env_enabled("USE_MCA"):  # force use torchrun
+        os.environ["FORCE_TORCHRUN"] = "1"
+    if command == "train" and (
+        is_env_enabled("FORCE_TORCHRUN") or (get_device_count() > 1 and not use_ray() and not use_kt())
+    ):
         # launch distributed training
         nnodes = os.getenv("NNODES", "1")
         node_rank = os.getenv("NODE_RANK", "0")

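The launch decision now has three inputs: FORCE_TORCHRUN, USE_MCA (which simply forces FORCE_TORCHRUN=1), and the device count combined with the Ray and KTransformers opt-outs. A stand-alone re-implementation of that gate, for illustration only (the real use_ray()/use_kt() helpers live in extras.misc; here they are replaced by plain booleans, and the env parsing is a simplified stand-in for is_env_enabled()):

import os


def _env_enabled(name: str) -> bool:
    # Simplified stand-in for is_env_enabled(); accepts "1"/"true"/"y".
    return os.getenv(name, "0").strip().lower() in ("1", "true", "y")


def should_use_torchrun(command: str, device_count: int, ray: bool, kt: bool) -> bool:
    if _env_enabled("USE_MCA"):  # MCA always goes through torchrun
        os.environ["FORCE_TORCHRUN"] = "1"
    force = _env_enabled("FORCE_TORCHRUN")
    return command == "train" and (force or (device_count > 1 and not ray and not kt))


print(should_use_torchrun("train", device_count=4, ray=False, kt=False))  # True
print(should_use_torchrun("train", device_count=4, ray=False, kt=True))   # False: KT runs in a single process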
Some files were not shown because too many files have changed in this diff.