Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-12-28 09:40:34 +08:00)

Compare commits: 82 commits, 47a7dc1698...main
Commits in this comparison (SHA1):
7ef1fba34a, eceec8ab69, b44f651e09, 55590f5ece, a1b1931b4a, 3c17f2722c, a882e2d5fc, a754604c11,
6a2eafbae3, 84485406b7, 1c8a42d2f8, 7901b2f32e, 1f1f5a7d1b, 6ef9854713, 4923f52a28, 0894b4f37e,
b0d49e137f, ddd7dcc722, 5204cd2bca, 8c74dca76a, e8deda53a1, a769fb94b9, 964569751f, 9fd4b094d4,
18c21bce5a, a0179772ab, aeda079014, fdd24276ed, 110d21713e, 203069e11c, 4fd94141a4, 22d6ac29d5,
cff4483392, 5d56817e2b, 1bbb461f76, c1f5f8fff6, 5744f1ea94, 739954910a, 109162dc56, 165f3f073a,
efb13b7483, e43a972b25, 22be45c78c, d1f585f80a, 955396e8a5, 231756a5bf, 2c4fb3c97e, 2b6f16f261,
f17efde693, 591fc9ed02, 3140c242f0, 887c562d60, 9779b1f361, 45f0437a14, d4e120423d, 10a446e373,
0aa4a051af, 8173a88a26, fef86fa7fe, 5afa851f71, a711bce664, bd24350cbf, bd30c0003b, 8edd2622ce,
eaf963f67f, 56f45e826f, 14abb75126, 5a9939050e, 934b3084ee, 3ae15da9c0, 215580c77d, 767b344fb4,
3057db15c3, 13170577b2, 129e918106, 9c0d033a15, 2a822178de, b842457ef4, 2c6aded5d4, d9d67ba62d,
a442fa90ad, 8c341cbaae
@@ -15,6 +15,7 @@ LLAMAFACTORY_VERBOSITY=
 USE_MODELSCOPE_HUB=
 USE_OPENMIND_HUB=
 USE_RAY=
+USE_KT=
 RECORD_VRAM=
 OPTIM_TORCH=
 NPU_JIT_COMPILE=
@@ -35,6 +36,8 @@ GRADIO_SERVER_NAME=
 GRADIO_SERVER_PORT=
 GRADIO_ROOT_PATH=
 GRADIO_IPV6=
+# backend
+USE_MCA=
 # setup
 ENABLE_SHORT_CONSOLE=
 # reserved (do not use)
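The two new switches appear to correspond to the KTransformers and Megatron-core (mcore_adapter) integrations added elsewhere in this comparison. A usage sketch, assuming they accept `1`/empty like the neighboring flags (the accepted values are not spelled out here):

```bash
# Hypothetical: enable the KTransformers path or the Megatron-core backend for one run;
# leave the variables unset to keep the default PyTorch backend.
USE_KT=1 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
USE_MCA=1 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
```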
.github/copilot-instructions.md (vendored, new file, 180 lines)
@@ -0,0 +1,180 @@

# GitHub Copilot Instructions for LLaMA Factory

## Project Overview

LLaMA Factory is an efficient fine-tuning framework for 100+ large language models (LLMs). It provides:
- Support for various models: LLaMA, LLaVA, Mistral, Qwen, DeepSeek, Yi, Gemma, ChatGLM, Phi, etc.
- Multiple training methods: pre-training, supervised fine-tuning, reward modeling, PPO, DPO, KTO, ORPO
- Scalable resources: 16-bit full-tuning, freeze-tuning, LoRA and QLoRA variants
- Advanced algorithms: GaLore, BAdam, APOLLO, Adam-mini, Muon, OFT, DoRA, etc.
- Web UI (LLaMA Board) and CLI interfaces

### Architecture Versions

LLaMA Factory has two parallel architectures that can be switched via the `USE_V1` environment variable:

**v0 (default)** - File hierarchy:
- `api`, `webui` → `chat`, `eval`, `train` → `data`, `model` → `hparams` → `extras`

**v1** - File hierarchy:
- `trainers` → `core` → `accelerator`, `plugins`, `config` → `utils`

Set `USE_V1=1` to enable v1 architecture.
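A usage sketch (assuming `USE_V1` is read from the environment at startup, as described above; the config path is one of the shipped examples):

```bash
# Default v0 code path
llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml

# Opt in to the v1 architecture for a single run
USE_V1=1 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
```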

## Code Structure

### v0 Architecture (Default)

- `src/llamafactory/` - Main package directory
  - `api/` - OpenAI-style API implementation
  - `chat/` - Chat interface implementation
  - `cli.py` - Command-line interface
  - `data/` - Data processing and dataset handling
  - `eval/` - Model evaluation utilities
  - `extras/` - Additional utilities and helpers
  - `hparams/` - Hyperparameter definitions
  - `model/` - Model loading, patching, and utilities
  - `train/` - Training pipeline implementation
  - `webui/` - Gradio-based web interface
- `src/train.py` - Training entry script (delegates to `llamafactory.train.tuner`)
- `src/webui.py` - Web UI entry script (delegates to `llamafactory.webui.interface`)
- `src/api.py` - API server entry script (delegates to `llamafactory.api.app`)
- `tests/` - Test suite
- `examples/` - Example configurations for various training scenarios
- `data/` - Dataset definitions and examples

### v1 Architecture (USE_V1=1)

- `src/llamafactory/v1/` - Version 1 package directory
  - `trainers/` - Training implementations
  - `core/` - Core training utilities
  - `accelerator/` - Acceleration and distributed training
  - `plugins/` - Pluggable components (model, data, sampler, trainer)
  - `config/` - Configuration management
  - `utils/` - Utility functions

## Development Practices

### Code Style

- Follow the [Google Python Style Guide](https://google.github.io/styleguide/pyguide.html)
- Use ruff for linting and formatting
- Line length: 119 characters
- Indentation: 4 spaces
- Quote style: double quotes
- Use Google-style docstrings for documentation

### Import Organization

- Known first-party: `llamafactory`
- Known third-party: `accelerate`, `datasets`, `gradio`, `numpy`, `peft`, `torch`, `transformers`, `trl`
- Use 2 blank lines after imports

### Quality Checks

Before committing code, run:
```bash
make style    # Auto-fix style issues
make quality  # Check code quality
make test     # Run test suite
```

Or use the combined command:
```bash
make commit   # Run pre-commit hooks
```

### Testing

- Use pytest for testing
- Tests are located in `tests/` and `tests_v1/` directories
- Run tests with: `make test` (which runs `WANDB_DISABLED=true pytest -vv --import-mode=importlib tests/ tests_v1/`)
- Disable wandb during testing to avoid external dependencies
- **Note**: Training configurations require GPU machines, so training is typically not tested end-to-end. Use `make test` to validate file-level functionality.
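To iterate on a single test file rather than the whole suite, the pytest invocation behind `make test` can be narrowed (the test path below is illustrative, not a guaranteed file in the repo):

```bash
# Same flags as `make test`, restricted to one file
WANDB_DISABLED=true pytest -vv --import-mode=importlib tests/data/test_template.py
```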

### Building

Build the package with:
```bash
pip3 install build && python3 -m build
```

### License

- All source files must include the Apache 2.0 license header
- Check license headers with: `make license`

## Common Patterns

### Configuration Files

- Training configurations are typically YAML or JSON files in `examples/` directory
- Hyperparameters are defined using dataclasses in `src/llamafactory/hparams/`
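A minimal sketch of such a config, written inline and then launched; the keys mirror the shipped `examples/train_lora/llama3_lora_sft.yaml`, but treat the exact field names and values as assumptions and check the example files and the dataclasses under `hparams/`:

```bash
# Sketch only: write a small LoRA SFT config and run it.
cat > /tmp/my_lora_sft.yaml << 'EOF'
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
stage: sft
do_train: true
finetuning_type: lora
lora_target: all
dataset: identity
template: llama3
output_dir: saves/llama3-8b/lora/sft
per_device_train_batch_size: 1
learning_rate: 1.0e-4
num_train_epochs: 3.0
EOF

llamafactory-cli train /tmp/my_lora_sft.yaml
```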

### Model Support

- New model support is added through model patches in `src/llamafactory/model/`
- Visual models use the visual utilities in `src/llamafactory/model/model_utils/visual.py`
- Quantization support is in `src/llamafactory/model/model_utils/quantization.py`

### Data Processing

- Dataset definitions are in `data/dataset_info.json`
- Data templates and processors are in `src/llamafactory/data/`
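For reference, a registry entry for a local alpaca-style file looks roughly like the snippet below (printed only; the column mapping follows the common alpaca format, but confirm the exact keys against `data/dataset_info.json`):

```bash
# Sketch of a dataset_info.json entry for a file placed under data/
cat << 'EOF'
"my_dataset": {
  "file_name": "my_dataset.json",
  "columns": {"prompt": "instruction", "query": "input", "response": "output"}
}
EOF
```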

### Training

- Training pipelines are in `src/llamafactory/train/`
- Support for different training methods: SFT, DPO, PPO, RM, PT, KTO, ORPO

## Key Dependencies

- Python >= 3.9.0
- PyTorch and transformers for model handling
- datasets for data processing
- peft for parameter-efficient fine-tuning
- accelerate for distributed training
- gradio for web UI
- trl for reinforcement learning
- Optional: vllm/sglang for inference, flash-attention-2, unsloth, liger-kernel

## Entry Points

- **CLI Training**: `llamafactory-cli train --config examples/train_lora/llama3_lora_sft.yaml`
- **Web UI**: `llamafactory-cli webui` or `python src/webui.py`
- **API Server**: `llamafactory-cli api` or `python src/api.py`
- **Chat Interface**: `llamafactory-cli chat --model_name_or_path MODEL_PATH`
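Since the API server exposes an OpenAI-style interface, a quick smoke test can look like the following; the port variable and the request payload shape are assumptions based on that compatibility, not something specified here:

```bash
# Start the OpenAI-style API server, then send a chat completion request.
API_PORT=8000 llamafactory-cli api --model_name_or_path MODEL_PATH --template llama3 &

curl -s http://localhost:8000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{"model": "default", "messages": [{"role": "user", "content": "Hello"}]}'
```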

## Environment Setup

For development:
```bash
pip install -e ".[dev]"
```

## Important Notes

- The project supports multiple backends: default PyTorch, vLLM, SGLang
- Megatron-core training is supported via mcore_adapter
- SwanLab and W&B are supported for experiment tracking
- Docker support is available with pre-built images
- Day-0/Day-1 support for latest cutting-edge models
- Multi-modal support for vision and audio understanding tasks

## Contribution Guidelines

1. Fork the repository
2. Create a development branch
3. Set up development environment with `pip install -e ".[dev]"`
4. Make changes following the style guide
5. Run quality checks: `make style && make quality`
6. Run tests: `make test`
7. Submit a pull request

## Common Commands

- `make style` - Format code
- `make quality` - Run linters
- `make test` - Run tests
- `make commit` - Install and run pre-commit hooks
- `make license` - Check license headers
.github/workflows/docker.yml (vendored, 44 changes)
@@ -7,7 +7,7 @@ on:
       - "main"
     paths:
       - "**/*.py"
-      - "requirements.txt"
+      - "pyproject.toml"
       - "docker/**"
       - ".github/workflows/*.yml"
   pull_request:
@@ -15,7 +15,7 @@ on:
       - "main"
     paths:
       - "**/*.py"
-      - "requirements.txt"
+      - "pyproject.toml"
       - "docker/**"
       - ".github/workflows/*.yml"
   release:
@@ -27,9 +27,10 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        device:
-          - "cuda"
-          - "npu"
+        include:
+          - device: "cuda"
+          - device: "npu-a2"
+          - device: "npu-a3"

     runs-on: ubuntu-latest

@@ -51,16 +52,11 @@ jobs:
       - name: Checkout
         uses: actions/checkout@v4

-      - name: Set up Python
-        uses: actions/setup-python@v5
-        with:
-          python-version: "3.10"
-
       - name: Get llamafactory version
         id: version
         run: |
           if [ "${{ github.event_name }}" = "release" ]; then
-            echo "tag=$(python setup.py --version)" >> "$GITHUB_OUTPUT"
+            echo "tag=$(grep -oP 'VERSION = "\K[^"]+' src/llamafactory/extras/env.py)" >> "$GITHUB_OUTPUT"
           else
             echo "tag=latest" >> "$GITHUB_OUTPUT"
           fi
@@ -76,7 +72,7 @@ jobs:
           password: ${{ secrets.DOCKERHUB_TOKEN }}

       - name: Login to Quay
-        if: ${{ github.event_name != 'pull_request' && matrix.device == 'npu' }}
+        if: ${{ github.event_name != 'pull_request' && matrix.device == 'npu'}}
         uses: docker/login-action@v3
         with:
           registry: quay.io
@@ -89,16 +85,12 @@ jobs:
         with:
           context: .
           file: ./docker/docker-cuda/Dockerfile
-          build-args: |
-            EXTRAS=metrics,deepspeed,liger-kernel
           push: ${{ github.event_name != 'pull_request' }}
           tags: |
             docker.io/hiyouga/llamafactory:${{ steps.version.outputs.tag }}
-          cache-from: type=gha
-          cache-to: type=gha,mode=max

-      - name: Build and push Docker image (NPU)
-        if: ${{ matrix.device == 'npu' }}
+      - name: Build and push Docker image (NPU-A2)
+        if: ${{ matrix.device == 'npu-a2' }}
         uses: docker/build-push-action@v6
         with:
           context: .
@@ -108,5 +100,17 @@ jobs:
           tags: |
             docker.io/hiyouga/llamafactory:${{ steps.version.outputs.tag }}-npu-a2
             quay.io/ascend/llamafactory:${{ steps.version.outputs.tag }}-npu-a2
-          cache-from: type=gha
-          cache-to: type=gha,mode=max
+
+      - name: Build and push Docker image (NPU-A3)
+        if: ${{ matrix.device == 'npu-a3' }}
+        uses: docker/build-push-action@v6
+        with:
+          context: .
+          platforms: linux/amd64,linux/arm64
+          file: ./docker/docker-npu/Dockerfile
+          build-args: |
+            BASE_IMAGE=quay.io/ascend/cann:8.3.rc2-a3-ubuntu22.04-py3.11
+          push: ${{ github.event_name != 'pull_request' }}
+          tags: |
+            docker.io/hiyouga/llamafactory:${{ steps.version.outputs.tag }}-npu-a3
+            quay.io/ascend/llamafactory:${{ steps.version.outputs.tag }}-npu-a3
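The new NPU-A3 job reuses the existing NPU Dockerfile and only swaps the base image. The equivalent local build would look roughly like this (the image tag is a placeholder):

```bash
docker build -f ./docker/docker-npu/Dockerfile \
    --build-arg BASE_IMAGE=quay.io/ascend/cann:8.3.rc2-a3-ubuntu22.04-py3.11 \
    -t llamafactory:npu-a3 .
```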
.github/workflows/publish.yml (vendored, 7 changes)
@@ -23,10 +23,11 @@ jobs:
       - name: Checkout
         uses: actions/checkout@v4

-      - name: Set up Python
-        uses: actions/setup-python@v5
+      - name: Install uv
+        uses: astral-sh/setup-uv@v7
         with:
-          python-version: "3.9"
+          python-version: "3.11"
+          github-token: ${{ github.token }}

       - name: Build package
         run: |
.github/workflows/tests.yml (vendored, 46 changes)
@@ -7,14 +7,16 @@ on:
       - "main"
     paths:
       - "**/*.py"
-      - "requirements.txt"
+      - "pyproject.toml"
+      - "Makefile"
       - ".github/workflows/*.yml"
   pull_request:
     branches:
       - "main"
     paths:
       - "**/*.py"
-      - "requirements.txt"
+      - "pyproject.toml"
+      - "Makefile"
       - ".github/workflows/*.yml"

 jobs:
@@ -23,10 +25,9 @@ jobs:
       fail-fast: false
       matrix:
         python:
-          - "3.9"
-          - "3.10"
           - "3.11"
           - "3.12"
+          # - "3.13" # enable after trl is upgraded
         os:
           - "ubuntu-latest"
           - "windows-latest"
@@ -34,18 +35,15 @@ jobs:
         transformers:
           - null
         include: # test backward compatibility
-          - python: "3.9"
+          - python: "3.11"
             os: "ubuntu-latest"
             transformers: "4.49.0"
-          - python: "3.9"
+          - python: "3.11"
             os: "ubuntu-latest"
             transformers: "4.51.0"
-          - python: "3.9"
+          - python: "3.11"
             os: "ubuntu-latest"
             transformers: "4.53.0"
-        exclude: # exclude python 3.9 on macos
-          - python: "3.9"
-            os: "macos-latest"

     runs-on: ${{ matrix.os }}

@@ -61,28 +59,23 @@ jobs:
       - name: Checkout
         uses: actions/checkout@v4

-      - name: Set up Python
-        uses: actions/setup-python@v5
+      - name: Install uv
+        uses: astral-sh/setup-uv@v7
         with:
           python-version: ${{ matrix.python }}
-          cache: "pip"
-          cache-dependency-path: "**/requirements*.txt"
+          github-token: ${{ github.token }}
+          enable-cache: false

       - name: Install dependencies
         run: |
-          python -m pip install --upgrade pip
-          python -m pip install ".[torch,dev]"
+          uv venv
+          uv pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
+          uv pip install -e ".[dev]"

       - name: Install transformers
         if: ${{ matrix.transformers }}
         run: |
-          python -m pip install "transformers==${{ matrix.transformers }}"
-
-      - name: Update accelerate to avoid mac os ci errors (before accelerate 1.11.0)
-        if: ${{ matrix.os == 'macos-latest' }}
-        run: |
-          python -m pip uninstall -y accelerate
-          python -m pip install "git+https://github.com/huggingface/accelerate.git"
+          uv pip install "transformers==${{ matrix.transformers }}"

       - name: Cache files
         id: hf-hub-cache
@@ -94,18 +87,25 @@ jobs:
       - name: Check quality
         run: |
           make style && make quality
+        env:
+          UV_NO_SYNC: 1

       - name: Check license
         run: |
           make license
+        env:
+          UV_NO_SYNC: 1

       - name: Check build
         run: |
           make build
+        env:
+          UV_NO_SYNC: 1

       - name: Test with pytest
         run: |
           make test
         env:
+          UV_NO_SYNC: 1
           HF_HOME: ${{ runner.temp }}/huggingface
           HF_HUB_OFFLINE: "${{ steps.hf-hub-cache.outputs.cache-hit == 'true' && '1' || '0' }}"
.github/workflows/tests_npu.yml (vendored, new file, 99 lines)
@@ -0,0 +1,99 @@
name: tests_npu

on:
  workflow_dispatch:
  push:
    branches:
      - "main"
    paths:
      - "**/*.py"
      - "pyproject.toml"
      - "Makefile"
      - ".github/workflows/*.yml"
  pull_request:
    branches:
      - "main"
    paths:
      - "**/*.py"
      - "pyproject.toml"
      - "Makefile"
      - ".github/workflows/*.yml"

jobs:
  tests:
    strategy:
      fail-fast: false
      matrix:
        python:
          - "3.11"
        os:
          - "linux-aarch64-a2-4"
        pytorch_npu:
          - "2.7.1"

    runs-on: ${{ matrix.os }}

    concurrency:
      group: ${{ github.workflow }}-${{ github.ref }}-${{ matrix.os }}-${{ matrix.python }}
      cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}

    container:
      image: ascendai/cann:8.3.rc2-910b-ubuntu22.04-py3.11
      env:
        HF_ENDPOINT: https://hf-mirror.com
        HF_TOKEN: ${{ secrets.HF_TOKEN }}
        OS_NAME: ${{ matrix.os }}

    steps:
      - name: Checkout
        uses: actions/checkout@v4

      - name: Install uv
        run: |
          curl -LsSf https://astral.sh/uv/install.sh | sh

      - name: Install dependencies
        run: |
          uv venv
          uv pip install torch-npu==${{matrix.pytorch_npu}}
          uv pip install -e ".[dev]"

      - name: Install node
        run: |
          apt-get update || true
          apt-get install -y curl
          curl -fsSL https://deb.nodesource.com/setup_20.x | bash -
          apt-get install -y nodejs

      - name: Cache files
        id: hf-hub-cache
        uses: actions/cache@v4
        with:
          path: ${{ runner.temp }}/huggingface
          key: huggingface-${{ matrix.os }}-${{ matrix.python }}-${{ hashFiles('tests/version.txt') }}

      - name: Check quality
        run: |
          make style && make quality
        env:
          UV_NO_SYNC: 1

      - name: Check license
        run: |
          make license
        env:
          UV_NO_SYNC: 1

      - name: Check build
        run: |
          make build
        env:
          UV_NO_SYNC: 1

      - name: Test with pytest
        run: |
          make test
        env:
          UV_NO_SYNC: 1
          HF_HOME: /root/.cache/huggingface
          HF_HUB_OFFLINE: "${{ steps.hf-hub-cache.outputs.cache-hit == 'true' && '1' || '0' }}"
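The install sequence this workflow runs inside its CANN container can be reproduced on an Ascend machine that already has the CANN toolkit set up (a sketch; the versions follow the workflow matrix above):

```bash
# Same steps as the tests_npu workflow, run locally
curl -LsSf https://astral.sh/uv/install.sh | sh
uv venv
uv pip install torch-npu==2.7.1
uv pip install -e ".[dev]"
UV_NO_SYNC=1 make test
```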
.gitignore (vendored, 5 changes)
@@ -85,7 +85,7 @@ ipython_config.py
 # pyenv
 # For a library or package, you might want to ignore these files since the code is
 # intended to run in multiple environments; otherwise, check them in:
-# .python-version
+.python-version

 # pipenv
 # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
@@ -165,6 +165,9 @@ cython_debug/
 # uv
 uv.lock

+# macOS
+.DS_Store
+
 # custom .gitignore
 hf_cache/
 ms_cache/
@@ -1 +1 @@
-include LICENSE requirements.txt
+include LICENSE
Makefile (24 changes)
@@ -1,24 +1,28 @@
 .PHONY: build commit license quality style test

-check_dirs := scripts src tests tests_v1 setup.py
+check_dirs := scripts src tests tests_v1

+RUN := $(shell command -v uv >/dev/null 2>&1 && echo "uv run" || echo "")
+BUILD := $(shell command -v uv >/dev/null 2>&1 && echo "uv build" || echo "python -m build")
+TOOL := $(shell command -v uv >/dev/null 2>&1 && echo "uvx" || echo "")
+
 build:
-    pip3 install build && python3 -m build
+    $(BUILD)

 commit:
-    pre-commit install
-    pre-commit run --all-files
+    $(TOOL) pre-commit install
+    $(TOOL) pre-commit run --all-files

 license:
-    python3 tests/check_license.py $(check_dirs)
+    $(RUN) python3 tests/check_license.py $(check_dirs)

 quality:
-    ruff check $(check_dirs)
-    ruff format --check $(check_dirs)
+    $(TOOL) ruff check $(check_dirs)
+    $(TOOL) ruff format --check $(check_dirs)

 style:
-    ruff check $(check_dirs) --fix
-    ruff format $(check_dirs)
+    $(TOOL) ruff check $(check_dirs) --fix
+    $(TOOL) ruff format $(check_dirs)

 test:
-    CUDA_VISIBLE_DEVICES= WANDB_DISABLED=true pytest -vv tests/
+    WANDB_DISABLED=true $(RUN) pytest -vv --import-mode=importlib tests/ tests_v1/
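The three new variables probe for uv once and fall back to the plain tool-chain when it is absent, so the public targets behave the same either way. A quick way to see what a target will expand to, using the same `command -v` probe the Makefile uses:

```bash
# With uv installed, BUILD becomes "uv build"; otherwise it is "python -m build".
command -v uv >/dev/null 2>&1 && echo "uv build" || echo "python -m build"

# Either way, the entry point stays the same:
make build
```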
README.md (90 changes)
@@ -5,11 +5,13 @@
 [](https://github.com/hiyouga/LLaMA-Factory/graphs/contributors)
 [](https://github.com/hiyouga/LLaMA-Factory/actions/workflows/tests.yml)
 [](https://pypi.org/project/llamafactory/)
 [](https://scholar.google.com/scholar?cites=12620864006390196564)
 [](https://hub.docker.com/r/hiyouga/llamafactory/tags)

 [](https://twitter.com/llamafactory_ai)
 [](https://discord.gg/rKfvV9r9FK)
+[](https://github.com/hiyouga/llamafactory-community)
+[](https://blog.llamafactory.net/en/)

 [](https://colab.research.google.com/drive/1eRTPn37ltBbYsISy9Aw2NuI2Aq5CQrD9?usp=sharing)
 [](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory)
@@ -44,16 +46,20 @@

 https://github.com/user-attachments/assets/3991a3a8-4276-4d30-9cab-4cb0c4b9b99e

-Choose your path:
+Start local training:
+- Please refer to [usage](#getting-started)
+
+Start cloud training:
+- **Colab (free)**: https://colab.research.google.com/drive/1eRTPn37ltBbYsISy9Aw2NuI2Aq5CQrD9?usp=sharing
+- **PAI-DSW (free trial)**: https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory
+- **LLaMA Factory Online**: https://www.llamafactory.com.cn/?utm_source=LLaMA-Factory
+- **Alaya NeW (cloud GPU deal)**: https://docs.alayanew.com/docs/documents/useGuide/LLaMAFactory/mutiple/?utm_source=LLaMA-Factory
+
+Read technical notes:
 - **Documentation (WIP)**: https://llamafactory.readthedocs.io/en/latest/
 - **Documentation (AMD GPU)**: https://rocm.docs.amd.com/projects/ai-developer-hub/en/latest/notebooks/fine_tune/llama_factory_llama3.html
-- **Colab (free)**: https://colab.research.google.com/drive/1eRTPn37ltBbYsISy9Aw2NuI2Aq5CQrD9?usp=sharing
+- **Official Blog**: https://blog.llamafactory.net/en/
-- **Local machine**: Please refer to [usage](#getting-started)
-- **PAI-DSW (free trial)**: https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory
-- **Alaya NeW (cloud GPU deal)**: https://docs.alayanew.com/docs/documents/useGuide/LLaMAFactory/mutiple/?utm_source=LLaMA-Factory
 - **Official Course**: https://www.lab4ai.cn/course/detail?id=7c13e60f6137474eb40f6fd3983c0f46&utm_source=LLaMA-Factory
-- **LLaMA Factory Online**: https://www.llamafactory.com.cn/?utm_source=LLaMA-Factory

 > [!NOTE]
 > Except for the above links, all other websites are unauthorized third-party websites. Please carefully use them.
@@ -90,7 +96,7 @@ Choose your path:
 - **Integrated methods**: (Continuous) pre-training, (multimodal) supervised fine-tuning, reward modeling, PPO, DPO, KTO, ORPO, etc.
 - **Scalable resources**: 16-bit full-tuning, freeze-tuning, LoRA and 2/3/4/5/6/8-bit QLoRA via AQLM/AWQ/GPTQ/LLM.int8/HQQ/EETQ.
 - **Advanced algorithms**: [GaLore](https://github.com/jiaweizzhao/GaLore), [BAdam](https://github.com/Ledzy/BAdam), [APOLLO](https://github.com/zhuhanqing/APOLLO), [Adam-mini](https://github.com/zyushun/Adam-mini), [Muon](https://github.com/KellerJordan/Muon), [OFT](https://github.com/huggingface/peft/tree/main/src/peft/tuners/oft), DoRA, LongLoRA, LLaMA Pro, Mixture-of-Depths, LoRA+, LoftQ and PiSSA.
-- **Practical tricks**: [FlashAttention-2](https://github.com/Dao-AILab/flash-attention), [Unsloth](https://github.com/unslothai/unsloth), [Liger Kernel](https://github.com/linkedin/Liger-Kernel), RoPE scaling, NEFTune and rsLoRA.
+- **Practical tricks**: [FlashAttention-2](https://github.com/Dao-AILab/flash-attention), [Unsloth](https://github.com/unslothai/unsloth), [Liger Kernel](https://github.com/linkedin/Liger-Kernel), [KTransformers](https://github.com/kvcache-ai/ktransformers/), RoPE scaling, NEFTune and rsLoRA.
 - **Wide tasks**: Multi-turn dialogue, tool using, image understanding, visual grounding, video recognition, audio understanding, etc.
 - **Experiment monitors**: LlamaBoard, TensorBoard, Wandb, MLflow, [SwanLab](https://github.com/SwanHubX/SwanLab), etc.
 - **Faster inference**: OpenAI-style API, Gradio UI and CLI with [vLLM worker](https://github.com/vllm-project/vllm) or [SGLang worker](https://github.com/sgl-project/sglang).
@@ -104,6 +110,12 @@ Choose your path:

 ## Blogs

+> [!TIP]
+> Now we have a dedicated blog for LLaMA Factory!
+>
+> Website: https://blog.llamafactory.net/en/
+
+- 💡 [KTransformers Fine-Tuning × LLaMA Factory: Fine-tuning 1000 Billion models with 2 4090-GPU + CPU](https://blog.llamafactory.net/en/posts/ktransformers/) (English)
 - 💡 [Easy Dataset × LLaMA Factory: Enabling LLMs to Efficiently Learn Domain Knowledge](https://buaa-act.feishu.cn/wiki/GVzlwYcRFiR8OLkHbL6cQpYin7g) (English)
 - [Fine-tune a mental health LLM using LLaMA-Factory](https://www.lab4ai.cn/project/detail?id=25cce32ec131497b9e06a93336a0817f&type=project&utm_source=LLaMA-Factory) (Chinese)
 - [Fine-tune GPT-OSS for Role-Playing using LLaMA-Factory](https://docs.llamafactory.com.cn/docs/documents/best-practice/gptroleplay/?utm_source=LLaMA-Factory) (Chinese)
@@ -123,6 +135,8 @@ Choose your path:

 ## Changelog

+[25/10/26] We support Megatron-core training backend with [**mcore_adapter**](https://github.com/alibaba/ROLL/tree/main/mcore_adapter). See [PR #9237](https://github.com/hiyouga/LLaMA-Factory/pull/9237) to get started.
+
 [25/08/22] We supported **[OFT](https://arxiv.org/abs/2306.07280)** and **[OFTv2](https://arxiv.org/abs/2506.19847)**. See [examples](examples/README.md) for usage.

 [25/08/20] We supported fine-tuning the **[Intern-S1-mini](https://huggingface.co/internlm/Intern-S1-mini)** models. See [PR #8976](https://github.com/hiyouga/LLaMA-Factory/pull/8976) to get started.
@@ -264,27 +278,21 @@ Choose your path:

 | Model | Model size | Template |
 | ----------------------------------------------------------------- | -------------------------------- | -------------------- |
-| [Baichuan 2](https://huggingface.co/baichuan-inc) | 7B/13B | baichuan2 |
 | [BLOOM/BLOOMZ](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - |
-| [ChatGLM3](https://huggingface.co/THUDM) | 6B | chatglm3 |
 | [Command R](https://huggingface.co/CohereForAI) | 35B/104B | cohere |
-| [DeepSeek (Code/MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B/236B | deepseek |
+| [DeepSeek (LLM/Code/MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B/236B | deepseek |
-| [DeepSeek 2.5/3](https://huggingface.co/deepseek-ai) | 236B/671B | deepseek3 |
+| [DeepSeek 3-3.2](https://huggingface.co/deepseek-ai) | 236B/671B | deepseek3 |
 | [DeepSeek R1 (Distill)](https://huggingface.co/deepseek-ai) | 1.5B/7B/8B/14B/32B/70B/671B | deepseekr1 |
 | [ERNIE-4.5](https://huggingface.co/baidu) | 0.3B/21B/300B | ernie/ernie_nothink |
-| [Falcon](https://huggingface.co/tiiuae) | 7B/11B/40B/180B | falcon |
+| [Falcon/Falcon H1](https://huggingface.co/tiiuae) | 0.5B/1.5B/3B/7B/11B/34B/40B/180B | falcon/falcon_h1 |
-| [Falcon-H1](https://huggingface.co/tiiuae) | 0.5B/1.5B/3B/7B/34B | falcon_h1 |
 | [Gemma/Gemma 2/CodeGemma](https://huggingface.co/google) | 2B/7B/9B/27B | gemma/gemma2 |
 | [Gemma 3/Gemma 3n](https://huggingface.co/google) | 270M/1B/4B/6B/8B/12B/27B | gemma3/gemma3n |
 | [GLM-4/GLM-4-0414/GLM-Z1](https://huggingface.co/zai-org) | 9B/32B | glm4/glmz1 |
-| [GLM-4.1V](https://huggingface.co/zai-org) | 9B | glm4v |
+| [GLM-4.5/GLM-4.5(6)V](https://huggingface.co/zai-org) | 9B/106B/355B | glm4_moe/glm4_5v |
-| [GLM-4.5/GLM-4.5V](https://huggingface.co/zai-org) | 106B/355B | glm4_moe/glm4v_moe |
 | [GPT-2](https://huggingface.co/openai-community) | 0.1B/0.4B/0.8B/1.5B | - |
-| [GPT-OSS](https://huggingface.co/openai) | 20B/120B | gpt |
+| [GPT-OSS](https://huggingface.co/openai) | 20B/120B | gpt_oss |
-| [Granite 3.0-3.3](https://huggingface.co/ibm-granite) | 1B/2B/3B/8B | granite3 |
+| [Granite 3-4](https://huggingface.co/ibm-granite) | 1B/2B/3B/7B/8B | granite3/granite4 |
-| [Granite 4](https://huggingface.co/ibm-granite) | 7B | granite4 |
+| [Hunyuan (MT)](https://huggingface.co/tencent/) | 7B | hunyuan |
-| [Hunyuan](https://huggingface.co/tencent/) | 7B | hunyuan |
-| [Index](https://huggingface.co/IndexTeam) | 1.9B | index |
 | [InternLM 2-3](https://huggingface.co/internlm) | 7B/8B/20B | intern2 |
 | [InternVL 2.5-3.5](https://huggingface.co/OpenGVLab) | 1B/2B/4B/8B/14B/30B/38B/78B/241B | intern_vl |
 | [InternLM/Intern-S1-mini](https://huggingface.co/internlm/) | 8B | intern_s1 |
@@ -298,15 +306,13 @@ Choose your path:
 | [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | llava |
 | [LLaVA-NeXT](https://huggingface.co/llava-hf) | 7B/8B/13B/34B/72B/110B | llava_next |
 | [LLaVA-NeXT-Video](https://huggingface.co/llava-hf) | 7B/34B | llava_next_video |
-| [MiMo](https://huggingface.co/XiaomiMiMo) | 7B | mimo |
+| [MiMo](https://huggingface.co/XiaomiMiMo) | 7B/309B | mimo/mimo_v2 |
 | [MiniCPM 1-4.1](https://huggingface.co/openbmb) | 0.5B/1B/2B/4B/8B | cpm/cpm3/cpm4 |
 | [MiniCPM-o-2.6/MiniCPM-V-2.6](https://huggingface.co/openbmb) | 8B | minicpm_o/minicpm_v |
-| [Ministral/Mistral-Nemo](https://huggingface.co/mistralai) | 8B/12B | ministral |
+| [Ministral 3](https://huggingface.co/mistralai) | 3B/8B/14B | ministral3 |
 | [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
-| [Mistral Small](https://huggingface.co/mistralai) | 24B | mistral_small |
 | [OLMo](https://huggingface.co/allenai) | 1B/7B | - |
 | [PaliGemma/PaliGemma2](https://huggingface.co/google) | 3B/10B/28B | paligemma |
-| [Phi-1.5/Phi-2](https://huggingface.co/microsoft) | 1.3B/2.7B | - |
 | [Phi-3/Phi-3.5](https://huggingface.co/microsoft) | 4B/14B | phi |
 | [Phi-3-small](https://huggingface.co/microsoft) | 7B | phi_small |
 | [Phi-4](https://huggingface.co/microsoft) | 14B | phi4 |
@@ -317,19 +323,18 @@ Choose your path:
 | [Qwen2.5-Omni](https://huggingface.co/Qwen) | 3B/7B | qwen2_omni |
 | [Qwen3-Omni](https://huggingface.co/Qwen) | 30B | qwen3_omni |
 | [Qwen2-VL/Qwen2.5-VL/QVQ](https://huggingface.co/Qwen) | 2B/3B/7B/32B/72B | qwen2_vl |
-| [Qwen3-VL](https://huggingface.co/Qwen) | 235B | qwen3_vl |
+| [Qwen3-VL](https://huggingface.co/Qwen) | 2B/4B/8B/30B/32B/235B | qwen3_vl |
 | [Seed (OSS/Coder)](https://huggingface.co/ByteDance-Seed) | 8B/36B | seed_oss/seed_coder |
-| [Skywork o1](https://huggingface.co/Skywork) | 8B | skywork_o1 |
 | [StarCoder 2](https://huggingface.co/bigcode) | 3B/7B/15B | - |
-| [TeleChat2](https://huggingface.co/Tele-AI) | 3B/7B/35B/115B | telechat2 |
+| [VibeThinker-1.5B](https://huggingface.co/WeiboAI) | 1.5B | qwen3 |
-| [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | xverse |
 | [Yi/Yi-1.5 (Code)](https://huggingface.co/01-ai) | 1.5B/6B/9B/34B | yi |
-| [Yi-VL](https://huggingface.co/01-ai) | 6B/34B | yi_vl |
 | [Yuan 2](https://huggingface.co/IEITYuan) | 2B/51B/102B | yuan |

 > [!NOTE]
 > For the "base" models, the `template` argument can be chosen from `default`, `alpaca`, `vicuna` etc. But make sure to use the **corresponding template** for the "instruct/chat" models.
 >
+> If the model has both reasoning and non-reasoning versions, please use the `_nothink` suffix to distinguish between them. For example, `qwen3` and `qwen3_nothink`.
+>
 > Remember to use the **SAME** template in training and inference.
 >
 > \*: You should install the `transformers` from main branch and use `DISABLE_VERSION_CHECK=1` to skip version check.
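As a quick illustration of the `_nothink` convention (the model path is a placeholder; pass whichever checkpoint you actually use, and keep the same template for training and inference):

```bash
# Reasoning variant of the template
llamafactory-cli chat --model_name_or_path Qwen/Qwen3-8B --template qwen3

# Non-reasoning variant of the same model family
llamafactory-cli chat --model_name_or_path Qwen/Qwen3-8B --template qwen3_nothink
```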
@@ -459,7 +464,7 @@ You also can add a custom chat template to [template.py](src/llamafactory/data/t
 Some datasets require confirmation before using them, so we recommend logging in with your Hugging Face account using these commands.

 ```bash
-pip install --upgrade huggingface_hub
+pip install "huggingface_hub<1.0.0"
 huggingface-cli login
 ```

@@ -509,10 +514,12 @@ huggingface-cli login
 ```bash
 git clone --depth 1 https://github.com/hiyouga/LLaMA-Factory.git
 cd LLaMA-Factory
-pip install -e ".[torch,metrics]" --no-build-isolation
+pip install -e ".[metrics]" --no-build-isolation
 ```

-Extra dependencies available: torch, torch-npu, metrics, deepspeed, liger-kernel, bitsandbytes, hqq, eetq, gptq, aqlm, vllm, sglang, galore, apollo, badam, adam-mini, qwen, minicpm_v, openmind, swanlab, dev
+Optional dependencies available: `metrics`, `deepspeed`. Install with: `pip install -e ".[metrics,deepspeed]"`
+
+Additional dependencies for specific features are available in `examples/requirements/`.

 #### Install from Docker Image

@@ -531,13 +538,7 @@ Please refer to [build docker](#build-docker) to build the image yourself.
 Create an isolated Python environment with [uv](https://github.com/astral-sh/uv):

 ```bash
-uv sync --extra torch --extra metrics --prerelease=allow
-```
-
-Run LLaMA-Factory in the isolated environment:
-
-```bash
-uv run --prerelease=allow llamafactory-cli train examples/train_lora/llama3_lora_pretrain.yaml
+uv run llamafactory-cli webui
 ```

 </details>
@@ -574,7 +575,7 @@ To enable FlashAttention-2 on the Windows platform, please use the script from [

 <details><summary>For Ascend NPU users</summary>

-To install LLaMA Factory on Ascend NPU devices, please upgrade Python to version 3.10 or higher and specify extra dependencies: `pip install -e ".[torch-npu,metrics]"`. Additionally, you need to install the **[Ascend CANN Toolkit and Kernels](https://www.hiascend.com/developer/download/community/result?module=cann)**. Please follow the [installation tutorial](https://www.hiascend.com/document/detail/en/CANNCommunityEdition/600alphaX/softwareinstall/instg/atlasdeploy_03_0031.html) or use the following commands:
+To install LLaMA Factory on Ascend NPU devices, please upgrade Python to version 3.10 or higher: `pip install -e . torch-npu==2.7.1`. Additionally, you need to install the **[Ascend CANN Toolkit and Kernels](https://www.hiascend.com/developer/download/community/result?module=cann)**. Please follow the [installation tutorial](https://www.hiascend.com/document/detail/en/CANNCommunityEdition/600alphaX/softwareinstall/instg/atlasdeploy_03_0031.html) or use the following commands:

 ```bash
 # replace the url according to your CANN version and devices
@@ -593,8 +594,8 @@ source /usr/local/Ascend/ascend-toolkit/set_env.sh
 | Requirement | Minimum | Recommend |
 | ------------ | ------- | -------------- |
 | CANN | 8.0.RC1 | 8.0.0.alpha002 |
-| torch | 2.1.0 | 2.4.0 |
+| torch | 2.1.0 | 2.7.1 |
-| torch-npu | 2.1.0 | 2.4.0.post2 |
+| torch-npu | 2.1.0 | 2.7.1 |
 | deepspeed | 0.13.2 | 0.13.2 |
 | vllm-ascend | - | 0.7.3 |

@@ -709,7 +710,6 @@ For CUDA users:
 ```bash
 docker build -f ./docker/docker-cuda/Dockerfile \
     --build-arg PIP_INDEX=https://pypi.org/simple \
-    --build-arg EXTRAS=metrics \
     -t llamafactory:latest .

 docker run -dit --ipc=host --gpus=all \
@@ -726,7 +726,6 @@ For Ascend NPU users:
 ```bash
 docker build -f ./docker/docker-npu/Dockerfile \
     --build-arg PIP_INDEX=https://pypi.org/simple \
-    --build-arg EXTRAS=torch-npu,metrics \
     -t llamafactory:latest .

 docker run -dit --ipc=host \
@@ -751,7 +750,6 @@ For AMD ROCm users:
 ```bash
 docker build -f ./docker/docker-rocm/Dockerfile \
     --build-arg PIP_INDEX=https://pypi.org/simple \
-    --build-arg EXTRAS=metrics \
     -t llamafactory:latest .

 docker run -dit --ipc=host \
85
README_zh.md
85
README_zh.md
@@ -5,11 +5,13 @@
|
|||||||
[](https://github.com/hiyouga/LLaMA-Factory/graphs/contributors)
|
[](https://github.com/hiyouga/LLaMA-Factory/graphs/contributors)
|
||||||
[](https://github.com/hiyouga/LLaMA-Factory/actions/workflows/tests.yml)
|
[](https://github.com/hiyouga/LLaMA-Factory/actions/workflows/tests.yml)
|
||||||
[](https://pypi.org/project/llamafactory/)
|
[](https://pypi.org/project/llamafactory/)
|
||||||
[](https://scholar.google.com/scholar?cites=12620864006390196564)
|
[](https://scholar.google.com/scholar?cites=12620864006390196564)
|
||||||
[](https://hub.docker.com/r/hiyouga/llamafactory/tags)
|
[](https://hub.docker.com/r/hiyouga/llamafactory/tags)
|
||||||
|
|
||||||
[](https://twitter.com/llamafactory_ai)
|
[](https://twitter.com/llamafactory_ai)
|
||||||
[](https://discord.gg/rKfvV9r9FK)
|
[](https://discord.gg/rKfvV9r9FK)
|
||||||
|
[](https://github.com/hiyouga/llamafactory-community)
|
||||||
|
[](https://blog.llamafactory.net/)
|
||||||
|
|
||||||
[](https://colab.research.google.com/drive/1d5KQtbemerlSDSxZIfAaWXhKr30QypiK?usp=sharing)
|
[](https://colab.research.google.com/drive/1d5KQtbemerlSDSxZIfAaWXhKr30QypiK?usp=sharing)
|
||||||
[](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory)
|
[](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory)
|
||||||
@@ -44,18 +46,22 @@
|
|||||||
|
|
||||||
https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
|
https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
|
||||||
|
|
||||||
选择你的打开方式:
|
开始本地训练:
|
||||||
|
- 请见[如何使用](#如何使用)
|
||||||
|
|
||||||
|
开始云端训练:
|
||||||
|
- **Colab(免费)**:https://colab.research.google.com/drive/1d5KQtbemerlSDSxZIfAaWXhKr30QypiK?usp=sharing
|
||||||
|
- **PAI-DSW(免费试用)**:https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory
|
||||||
|
- **LLaMA Factory Online(在线微调)**:https://www.llamafactory.com.cn/?utm_source=LLaMA-Factory
|
||||||
|
- **九章智算云(算力优惠活动)**:https://docs.alayanew.com/docs/documents/useGuide/LLaMAFactory/mutiple/?utm_source=LLaMA-Factory
|
||||||
|
|
||||||
|
阅读技术文档:
|
||||||
- **入门教程**:https://zhuanlan.zhihu.com/p/695287607
|
- **入门教程**:https://zhuanlan.zhihu.com/p/695287607
|
||||||
- **微调视频教程**:https://www.bilibili.com/video/BV1djgRzxEts/
|
- **微调视频教程**:https://www.bilibili.com/video/BV1djgRzxEts/
|
||||||
- **框架文档**:https://llamafactory.readthedocs.io/zh-cn/latest/
|
- **框架文档**:https://llamafactory.readthedocs.io/zh-cn/latest/
|
||||||
- **框架文档(昇腾 NPU)**:https://ascend.github.io/docs/sources/llamafactory/
|
- **框架文档(昇腾 NPU)**:https://ascend.github.io/docs/sources/llamafactory/
|
||||||
- **Colab(免费)**:https://colab.research.google.com/drive/1d5KQtbemerlSDSxZIfAaWXhKr30QypiK?usp=sharing
|
- **官方博客**:https://blog.llamafactory.net/
|
||||||
- **本地机器**:请见[如何使用](#如何使用)
|
|
||||||
- **PAI-DSW(免费试用)**:https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory
|
|
||||||
- **九章智算云(算力优惠活动)**:https://docs.alayanew.com/docs/documents/useGuide/LLaMAFactory/mutiple/?utm_source=LLaMA-Factory
|
|
||||||
- **官方课程**:https://www.lab4ai.cn/course/detail?id=7c13e60f6137474eb40f6fd3983c0f46&utm_source=LLaMA-Factory
|
- **官方课程**:https://www.lab4ai.cn/course/detail?id=7c13e60f6137474eb40f6fd3983c0f46&utm_source=LLaMA-Factory
|
||||||
- **LLaMA Factory Online(在线微调)**:https://www.llamafactory.com.cn/?utm_source=LLaMA-Factory
|
|
||||||
|
|
||||||
> [!NOTE]
|
> [!NOTE]
|
||||||
> 除上述链接以外的其他网站均为未经许可的第三方网站,请小心甄别。
|
> 除上述链接以外的其他网站均为未经许可的第三方网站,请小心甄别。
|
||||||
@@ -92,7 +98,7 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
|
|||||||
- **集成方法**:(增量)预训练、(多模态)指令监督微调、奖励模型训练、PPO 训练、DPO 训练、KTO 训练、ORPO 训练等等。
|
- **集成方法**:(增量)预训练、(多模态)指令监督微调、奖励模型训练、PPO 训练、DPO 训练、KTO 训练、ORPO 训练等等。
|
||||||
- **多种精度**:16 比特全参数微调、冻结微调、LoRA 微调和基于 AQLM/AWQ/GPTQ/LLM.int8/HQQ/EETQ 的 2/3/4/5/6/8 比特 QLoRA 微调。
|
- **多种精度**:16 比特全参数微调、冻结微调、LoRA 微调和基于 AQLM/AWQ/GPTQ/LLM.int8/HQQ/EETQ 的 2/3/4/5/6/8 比特 QLoRA 微调。
|
||||||
- **先进算法**:[GaLore](https://github.com/jiaweizzhao/GaLore)、[BAdam](https://github.com/Ledzy/BAdam)、[APOLLO](https://github.com/zhuhanqing/APOLLO)、[Adam-mini](https://github.com/zyushun/Adam-mini)、[Muon](https://github.com/KellerJordan/Muon)、[OFT](https://github.com/huggingface/peft/tree/main/src/peft/tuners/oft)、DoRA、LongLoRA、LLaMA Pro、Mixture-of-Depths、LoRA+、LoftQ 和 PiSSA。
|
- **先进算法**:[GaLore](https://github.com/jiaweizzhao/GaLore)、[BAdam](https://github.com/Ledzy/BAdam)、[APOLLO](https://github.com/zhuhanqing/APOLLO)、[Adam-mini](https://github.com/zyushun/Adam-mini)、[Muon](https://github.com/KellerJordan/Muon)、[OFT](https://github.com/huggingface/peft/tree/main/src/peft/tuners/oft)、DoRA、LongLoRA、LLaMA Pro、Mixture-of-Depths、LoRA+、LoftQ 和 PiSSA。
|
||||||
- **实用技巧**:[FlashAttention-2](https://github.com/Dao-AILab/flash-attention)、[Unsloth](https://github.com/unslothai/unsloth)、[Liger Kernel](https://github.com/linkedin/Liger-Kernel)、RoPE scaling、NEFTune 和 rsLoRA。
|
- **实用技巧**:[FlashAttention-2](https://github.com/Dao-AILab/flash-attention)、[Unsloth](https://github.com/unslothai/unsloth)、[Liger Kernel](https://github.com/linkedin/Liger-Kernel)、[KTransformers](https://github.com/kvcache-ai/ktransformers/)、RoPE scaling、NEFTune 和 rsLoRA。
|
||||||
- **广泛任务**:多轮对话、工具调用、图像理解、视觉定位、视频识别和语音理解等等。
|
- **广泛任务**:多轮对话、工具调用、图像理解、视觉定位、视频识别和语音理解等等。
|
||||||
- **实验监控**:LlamaBoard、TensorBoard、Wandb、MLflow、[SwanLab](https://github.com/SwanHubX/SwanLab) 等等。
|
- **实验监控**:LlamaBoard、TensorBoard、Wandb、MLflow、[SwanLab](https://github.com/SwanHubX/SwanLab) 等等。
|
||||||
- **极速推理**:基于 [vLLM](https://github.com/vllm-project/vllm) 或 [SGLang](https://github.com/sgl-project/sglang) 的 OpenAI 风格 API、浏览器界面和命令行接口。
|
- **极速推理**:基于 [vLLM](https://github.com/vllm-project/vllm) 或 [SGLang](https://github.com/sgl-project/sglang) 的 OpenAI 风格 API、浏览器界面和命令行接口。
|
||||||
@@ -106,6 +112,12 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
 ## 官方博客
 
+> [!TIP]
+> 我们现在拥有了 LLaMA Factory 的专属博客!
+>
+> 网站地址:https://blog.llamafactory.net/
+
+- 💡 [KTransformers Fine-Tuning × LLaMA Factory: 用2张4090级的GPU+CPU 微调 1000B规模的超大模型](https://swcil84qspu.feishu.cn/wiki/Z1sSwb2poijybxkyPEkcDG6enVc) (中文)
 - 💡 [Easy Dataset × LLaMA Factory: 让大模型高效学习领域知识](https://buaa-act.feishu.cn/wiki/KY9xwTGs1iqHrRkjXBwcZP9WnL9)(中文)
 - [使用 LLaMA-Factory 微调心理健康大模型](https://www.lab4ai.cn/project/detail?id=25cce32ec131497b9e06a93336a0817f&type=project&utm_source=LLaMA-Factory)(中文)
 - [使用 LLaMA-Factory 构建 GPT-OSS 角色扮演模型](https://docs.llamafactory.com.cn/docs/documents/best-practice/gptroleplay/?utm_source=LLaMA-Factory)(中文)
@@ -125,6 +137,8 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
 ## 更新日志
 
+[25/10/26] 我们支持了Megatron-core作为训练后端和适配了[**mcore_adapter**](https://github.com/alibaba/ROLL/tree/main/mcore_adapter)。查看[PR #9237](https://github.com/hiyouga/LLaMA-Factory/pull/9237)以使用。
+
 [25/08/22] 我们支持了 **[OFT](https://arxiv.org/abs/2306.07280)** 和 **[OFTv2](https://arxiv.org/abs/2506.19847)** 模型的微调。查看 [examples](examples/README.md) 以使用。
 
 [25/08/20] 我们支持了 **[Intern-S1-mini](https://huggingface.co/internlm/Intern-S1-mini)** 模型的微调。查看 [PR #8976](https://github.com/hiyouga/LLaMA-Factory/pull/8976) 以使用。
@@ -266,27 +280,21 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
 | 模型名 | 参数量 | Template |
 | ------ | ------ | -------- |
-| [Baichuan 2](https://huggingface.co/baichuan-inc) | 7B/13B | baichuan2 |
 | [BLOOM/BLOOMZ](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - |
-| [ChatGLM3](https://huggingface.co/THUDM) | 6B | chatglm3 |
 | [Command R](https://huggingface.co/CohereForAI) | 35B/104B | cohere |
-| [DeepSeek (Code/MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B/236B | deepseek |
+| [DeepSeek (LLM/Code/MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B/236B | deepseek |
-| [DeepSeek 2.5/3](https://huggingface.co/deepseek-ai) | 236B/671B | deepseek3 |
+| [DeepSeek 3-3.2](https://huggingface.co/deepseek-ai) | 236B/671B | deepseek3 |
 | [DeepSeek R1 (Distill)](https://huggingface.co/deepseek-ai) | 1.5B/7B/8B/14B/32B/70B/671B | deepseekr1 |
 | [ERNIE-4.5](https://huggingface.co/baidu) | 0.3B/21B/300B | ernie/ernie_nothink |
-| [Falcon](https://huggingface.co/tiiuae) | 7B/11B/40B/180B | falcon |
+| [Falcon/Falcon H1](https://huggingface.co/tiiuae) | 0.5B/1.5B/3B/7B/11B/34B/40B/180B | falcon/falcon_h1 |
-| [Falcon-H1](https://huggingface.co/tiiuae) | 0.5B/1.5B/3B/7B/34B | falcon_h1 |
 | [Gemma/Gemma 2/CodeGemma](https://huggingface.co/google) | 2B/7B/9B/27B | gemma/gemma2 |
 | [Gemma 3/Gemma 3n](https://huggingface.co/google) | 270M/1B/4B/6B/8B/12B/27B | gemma3/gemma3n |
 | [GLM-4/GLM-4-0414/GLM-Z1](https://huggingface.co/zai-org) | 9B/32B | glm4/glmz1 |
-| [GLM-4.1V](https://huggingface.co/zai-org) | 9B | glm4v |
+| [GLM-4.5/GLM-4.5(6)V](https://huggingface.co/zai-org) | 9B/106B/355B | glm4_moe/glm4_5v |
-| [GLM-4.5/GLM-4.5V](https://huggingface.co/zai-org) | 106B/355B | glm4_moe/glm4v_moe |
 | [GPT-2](https://huggingface.co/openai-community) | 0.1B/0.4B/0.8B/1.5B | - |
-| [GPT-OSS](https://huggingface.co/openai) | 20B/120B | gpt |
+| [GPT-OSS](https://huggingface.co/openai) | 20B/120B | gpt_oss |
-| [Granite 3.0-3.3](https://huggingface.co/ibm-granite) | 1B/2B/3B/8B | granite3 |
+| [Granite 3-4](https://huggingface.co/ibm-granite) | 1B/2B/3B/7B/8B | granite3/granite4 |
-| [Granite 4](https://huggingface.co/ibm-granite) | 7B | granite4 |
+| [Hunyuan (MT)](https://huggingface.co/tencent/) | 7B | hunyuan |
-| [Hunyuan](https://huggingface.co/tencent/) | 7B | hunyuan |
-| [Index](https://huggingface.co/IndexTeam) | 1.9B | index |
 | [InternLM 2-3](https://huggingface.co/internlm) | 7B/8B/20B | intern2 |
 | [InternVL 2.5-3.5](https://huggingface.co/OpenGVLab) | 1B/2B/4B/8B/14B/30B/38B/78B/241B | intern_vl |
 | [InternLM/Intern-S1-mini](https://huggingface.co/internlm/) | 8B | intern_s1 |
@@ -300,15 +308,13 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
 | [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | llava |
 | [LLaVA-NeXT](https://huggingface.co/llava-hf) | 7B/8B/13B/34B/72B/110B | llava_next |
 | [LLaVA-NeXT-Video](https://huggingface.co/llava-hf) | 7B/34B | llava_next_video |
-| [MiMo](https://huggingface.co/XiaomiMiMo) | 7B | mimo |
+| [MiMo](https://huggingface.co/XiaomiMiMo) | 7B/309B | mimo/mimo_v2 |
 | [MiniCPM 1-4.1](https://huggingface.co/openbmb) | 0.5B/1B/2B/4B/8B | cpm/cpm3/cpm4 |
 | [MiniCPM-o-2.6/MiniCPM-V-2.6](https://huggingface.co/openbmb) | 8B | minicpm_o/minicpm_v |
-| [Ministral/Mistral-Nemo](https://huggingface.co/mistralai) | 8B/12B | ministral |
+| [Ministral 3](https://huggingface.co/mistralai) | 3B/8B/14B | ministral3 |
 | [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
-| [Mistral Small](https://huggingface.co/mistralai) | 24B | mistral_small |
 | [OLMo](https://huggingface.co/allenai) | 1B/7B | - |
 | [PaliGemma/PaliGemma2](https://huggingface.co/google) | 3B/10B/28B | paligemma |
-| [Phi-1.5/Phi-2](https://huggingface.co/microsoft) | 1.3B/2.7B | - |
 | [Phi-3/Phi-3.5](https://huggingface.co/microsoft) | 4B/14B | phi |
 | [Phi-3-small](https://huggingface.co/microsoft) | 7B | phi_small |
 | [Phi-4](https://huggingface.co/microsoft) | 14B | phi4 |
@@ -319,19 +325,18 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
 | [Qwen2.5-Omni](https://huggingface.co/Qwen) | 3B/7B | qwen2_omni |
 | [Qwen3-Omni](https://huggingface.co/Qwen) | 30B | qwen3_omni |
 | [Qwen2-VL/Qwen2.5-VL/QVQ](https://huggingface.co/Qwen) | 2B/3B/7B/32B/72B | qwen2_vl |
-| [Qwen3-VL](https://huggingface.co/Qwen) | 235B | qwen3_vl |
+| [Qwen3-VL](https://huggingface.co/Qwen) | 2B/4B/8B/30B/32B/235B | qwen3_vl |
 | [Seed (OSS/Coder)](https://huggingface.co/ByteDance-Seed) | 8B/36B | seed_oss/seed_coder |
-| [Skywork o1](https://huggingface.co/Skywork) | 8B | skywork_o1 |
 | [StarCoder 2](https://huggingface.co/bigcode) | 3B/7B/15B | - |
-| [TeleChat2](https://huggingface.co/Tele-AI) | 3B/7B/35B/115B | telechat2 |
+| [VibeThinker-1.5B](https://huggingface.co/WeiboAI) | 1.5B | qwen3 |
-| [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | xverse |
 | [Yi/Yi-1.5 (Code)](https://huggingface.co/01-ai) | 1.5B/6B/9B/34B | yi |
-| [Yi-VL](https://huggingface.co/01-ai) | 6B/34B | yi_vl |
 | [Yuan 2](https://huggingface.co/IEITYuan) | 2B/51B/102B | yuan |
 
 > [!NOTE]
 > 对于所有“基座”(Base)模型,`template` 参数可以是 `default`, `alpaca`, `vicuna` 等任意值。但“对话”(Instruct/Chat)模型请务必使用**对应的模板**。
 >
+> 如果模型有推理 / 非推理两个版本,请使用 `_nothink` 后缀来区分不同的模板。例如 `qwen3` 和 `qwen3_nothink`。
+>
 > 请务必在训练和推理时采用**完全一致**的模板。
 >
 > \*:您需要从 main 分支安装 `transformers` 并使用 `DISABLE_VERSION_CHECK=1` 来跳过版本检查。
@@ -511,10 +516,12 @@ huggingface-cli login
 ```bash
 git clone --depth 1 https://github.com/hiyouga/LLaMA-Factory.git
 cd LLaMA-Factory
-pip install -e ".[torch,metrics]" --no-build-isolation
+pip install -e ".[metrics]" --no-build-isolation
 ```
 
-可选的额外依赖项:torch、torch-npu、metrics、deepspeed、liger-kernel、bitsandbytes、hqq、eetq、gptq、aqlm、vllm、sglang、galore、apollo、badam、adam-mini、qwen、minicpm_v、openmind、swanlab、dev
+可选的额外依赖项:`metrics`、`deepspeed`。使用 `pip install -e ".[metrics,deepspeed]"` 安装。
+
+其他可选依赖项请参考 `examples/requirements/` 目录下的文件。
 
 #### 从镜像安装
 
@@ -533,13 +540,7 @@ docker run -it --rm --gpus=all --ipc=host hiyouga/llamafactory:latest
 使用 [uv](https://github.com/astral-sh/uv) 创建隔离的 Python 环境:
 
 ```bash
-uv sync --extra torch --extra metrics --prerelease=allow
-```
-
-在环境中运行 LLaMA-Factory:
-
-```bash
-uv run --prerelease=allow llamafactory-cli train examples/train_lora/llama3_lora_pretrain.yaml
+uv run llamafactory-cli webui
 ```
 
 </details>
@@ -576,7 +577,7 @@ pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/downl
 
 <details><summary>昇腾 NPU 用户指南</summary>
 
-在昇腾 NPU 设备上安装 LLaMA Factory 时,请升级 Python 到 3.10 及以上,并需要指定额外依赖项,使用 `pip install -e ".[torch-npu,metrics]"` 命令安装。此外,还需要安装 **[Ascend CANN Toolkit 与 Kernels](https://www.hiascend.com/developer/download/community/result?module=cann)**,安装方法请参考[安装教程](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/80RC2alpha002/quickstart/quickstart/quickstart_18_0004.html)或使用以下命令:
+在昇腾 NPU 设备上安装 LLaMA Factory 时,请升级 Python 到 3.10 及以上,并需要指定额外依赖项,使用 `pip install -e . torch-npu==2.7.1` 命令安装。此外,还需要安装 **[Ascend CANN Toolkit 与 Kernels](https://www.hiascend.com/developer/download/community/result?module=cann)**,安装方法请参考[安装教程](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/80RC2alpha002/quickstart/quickstart/quickstart_18_0004.html)或使用以下命令:
 
 ```bash
 # 请替换 URL 为 CANN 版本和设备型号对应的 URL
@@ -595,8 +596,8 @@ source /usr/local/Ascend/ascend-toolkit/set_env.sh
 | 依赖项       | 至少    | 推荐           |
 | ------------ | ------- | -------------- |
 | CANN         | 8.0.RC1 | 8.0.0.alpha002 |
-| torch        | 2.1.0   | 2.4.0          |
+| torch        | 2.1.0   | 2.7.1          |
-| torch-npu    | 2.1.0   | 2.4.0.post2    |
+| torch-npu    | 2.1.0   | 2.7.1          |
 | deepspeed    | 0.13.2  | 0.13.2         |
 | vllm-ascend  | -       | 0.7.3          |
 
10 data/v1_dpo_demo.jsonl Normal file
File diff suppressed because one or more lines are too long

4 data/v1_dpo_demo.yaml Normal file
@@ -0,0 +1,4 @@
dpo_zh_demo:
  path: HuggingFaceH4/orca_dpo_pairs
  split: train_prefs
  converter: pair
@@ -1,8 +1,9 @@
 identity:
-  file_name: identity.json
+  path: data/identity.json
+  source: local
   converter: alpaca
 alpaca_en_demo:
-  file_name: alpaca_en_demo.json
-  dataset_dir: ~/data
+  path: data/alpaca_en_demo.json
+  source: local
   converter: alpaca
-  num_samples: 500
+  size: 500
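For orientation, a custom entry in this v1-style dataset format might look like the sketch below. This is only an illustration: the dataset name and file path are placeholders, and it assumes nothing beyond the `path`/`source`/`converter`/`size` fields visible in the hunks above.

```yaml
# Hypothetical v1-format entry for a user-supplied local file (sketch only)
my_alpaca_demo:                    # placeholder dataset name
  path: data/my_alpaca_demo.json   # placeholder local JSON file
  source: local                    # local file rather than a hub dataset
  converter: alpaca                # reuse the alpaca converter shown above
  size: 500                        # optional sample cap (the former num_samples)
```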
@@ -4,7 +4,6 @@ FROM ${BASE_IMAGE}
 
 # Installation arguments
 ARG PIP_INDEX=https://pypi.org/simple
-ARG EXTRAS=metrics
 ARG INSTALL_FLASHATTN=false
 ARG HTTP_PROXY=""
 
@@ -27,17 +26,13 @@ WORKDIR /app
 # Change pip source
 RUN pip config set global.index-url "${PIP_INDEX}" && \
     pip config set global.extra-index-url "${PIP_INDEX}" && \
-    pip install --no-cache-dir --upgrade pip packaging wheel setuptools
+    pip install --no-cache-dir --upgrade pip packaging wheel setuptools editables "hatchling>=1.18.0"
 
-# Install the requirements
-COPY requirements.txt /app
-RUN pip install --no-cache-dir -r requirements.txt
-
-# Copy the rest of the application into the image
+# Copy the application into the image
 COPY . /app
 
 # Install LLaMA Factory
-RUN pip install --no-cache-dir -e ".[${EXTRAS}]" --no-build-isolation
+RUN pip install --no-cache-dir --no-build-isolation -e ".[metrics,deepspeed]"
 
 # Rebuild flash attention
 RUN if [ "${INSTALL_FLASHATTN}" == "true" ]; then \
77 docker/docker-cuda/Dockerfile.megatron Normal file
@@ -0,0 +1,77 @@
# NVIDIA official image (ubuntu-22.04 + cuda-12.4 + python-3.10)
# https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-24-08.html
FROM nvcr.io/nvidia/pytorch:24.05-py3

ENV DEBIAN_FRONTEND=noninteractive
ENV PIP_ROOT_USER_ACTION=ignore
ENV PYPI_MIRROR=https://mirrors.aliyun.com/pypi/simple/
ENV PYPI_TRUSTED_HOST=mirrors.aliyun.com
ENV APT_MIRROR=https://mirrors.tuna.tsinghua.edu.cn/ubuntu/

RUN pip install --upgrade pip setuptools wheel "hatchling>=1.18.0" editables --trusted-host ${PYPI_TRUSTED_HOST} --index-url ${PYPI_MIRROR}

RUN pip uninstall -y torch torchvision torch-tensorrt \
    flash_attn transformer-engine \
    cudf dask-cuda cugraph cugraph-service-server cuml raft-dask cugraph-dgl cugraph-pyg dask-cudf

RUN pip install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 --index-url https://download.pytorch.org/whl/cu124

RUN pip uninstall -y opencv opencv-python opencv-python-headless && \
    rm -rf /usr/local/lib/python3.10/dist-packages/cv2/ && \
    pip install opencv-python-headless==4.11.0.86 --trusted-host ${PYPI_TRUSTED_HOST} --index-url ${PYPI_MIRROR}

RUN pip install "numpy==1.26.4" "optree>=0.13.0" "spacy==3.7.5" "weasel==0.4.1" \
    transformer-engine[pytorch]==2.2.0 megatron-core==0.13.0 deepspeed==0.16.4 \
    --trusted-host ${PYPI_TRUSTED_HOST} --index-url ${PYPI_MIRROR}

RUN pip install https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.2.post1/flash_attn-2.7.2.post1+cu12torch2.6cxx11abiFALSE-cp310-cp310-linux_x86_64.whl

# RUN pip install vllm==0.8.4 \
#     --trusted-host ${PYPI_TRUSTED_HOST} --index-url ${PYPI_MIRROR}

WORKDIR /build

ARG apex_url=git+https://github.com/NVIDIA/apex.git@25.04
RUN pip uninstall -y apex && \
    MAX_JOBS=32 NINJA_FLAGS="-j32" NVCC_APPEND_FLAGS="--threads 32" \
    pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation \
    --config-settings "--build-option=--cpp_ext --cuda_ext --parallel 32" ${apex_url}

RUN rm -rf /build
WORKDIR /workspace

RUN cp /etc/apt/sources.list /etc/apt/sources.list.bak && \
    { \
      echo "deb ${APT_MIRROR} jammy main restricted universe multiverse"; \
      echo "deb ${APT_MIRROR} jammy-security main restricted universe multiverse"; \
      echo "deb ${APT_MIRROR} jammy-updates main restricted universe multiverse"; \
      echo "deb ${APT_MIRROR} jammy-backports main restricted universe multiverse"; \
    } > /etc/apt/sources.list

RUN apt-get update && apt-get install -y zip

RUN apt-get install -y openjdk-21-jdk
ENV JAVA_HOME /usr/lib/jvm/java-21-openjdk-amd64

# pip install LLaMA-Factory
WORKDIR /app

# Copy the application into the image
COPY . /app

# Install LLaMA Factory
RUN pip install --no-cache-dir -e ".[metrics]" --no-build-isolation

RUN pip install "git+https://github.com/alibaba/roll.git#subdirectory=mcore_adapter"

# Expose port 7860 for LLaMA Board
ENV GRADIO_SERVER_PORT=7860
EXPOSE 7860

# Expose port 8000 for API service
ENV API_PORT=8000
EXPOSE 8000

# unset proxy
ENV http_proxy=
ENV https_proxy=
@@ -5,7 +5,6 @@ services:
       context: ../..
       args:
         PIP_INDEX: https://pypi.org/simple
-        EXTRAS: metrics
     container_name: llamafactory
     ports:
       - "7860:7860"
@@ -1,10 +1,10 @@
 # https://hub.docker.com/r/ascendai/cann/tags
-ARG BASE_IMAGE=ascendai/cann:8.1.rc1-910b-ubuntu22.04-py3.11
+ARG BASE_IMAGE=quay.io/ascend/cann:8.3.rc2-910b-ubuntu22.04-py3.11
 FROM ${BASE_IMAGE}
 
 # Installation arguments
 ARG PIP_INDEX=https://pypi.org/simple
-ARG EXTRAS=torch-npu,metrics
 ARG HTTP_PROXY=""
 ARG PYTORCH_INDEX=https://download.pytorch.org/whl/cpu
 
@@ -27,21 +27,15 @@ WORKDIR /app
 # Change pip source
 RUN pip config set global.index-url "${PIP_INDEX}" && \
     pip config set global.extra-index-url "${PIP_INDEX}" && \
-    pip install --no-cache-dir --upgrade pip packaging wheel setuptools
+    pip install --no-cache-dir --upgrade pip packaging wheel setuptools editables "hatchling>=1.18.0"
 
+# Copy the application into the image
+COPY . /app
+
 # Install torch-npu
 RUN pip uninstall -y torch torchvision torchaudio && \
-    pip install --no-cache-dir "torch-npu==2.5.1" "torchvision==0.20.1" --index-url "${PYTORCH_INDEX}"
+    pip install --no-cache-dir "torch==2.7.1" "torch-npu==2.7.1" "torchvision==0.22.1" "torchaudio==2.7.1" --index-url "${PYTORCH_INDEX}" && \
+    pip install --no-cache-dir -e ".[metrics]" --no-build-isolation
 
-# Install the requirements
-COPY requirements.txt /app
-RUN pip install --no-cache-dir -r requirements.txt
-
-# Copy the rest of the application into the image
-COPY . /app
-
-# Install LLaMA Factory
-RUN pip install --no-cache-dir -e ".[${EXTRAS}]" --no-build-isolation
 
 # Set up volumes
 # VOLUME [ "/root/.cache/huggingface", "/app/shared_data", "/app/output" ]
@@ -1,12 +1,12 @@
 services:
-  llamafactory:
+  llamafactory-a2:
     build:
       dockerfile: ./docker/docker-npu/Dockerfile
       context: ../..
       args:
         PIP_INDEX: https://pypi.org/simple
-        EXTRAS: torch-npu,metrics
-    container_name: llamafactory
+    container_name: llamafactory-a2
+    image: llamafactory:npu-a2
     volumes:
       - /usr/local/dcmi:/usr/local/dcmi
       - /usr/local/bin/npu-smi:/usr/local/bin/npu-smi
@@ -26,3 +26,33 @@ services:
       - /dev/devmm_svm
       - /dev/hisi_hdc
     restart: unless-stopped
+
+  llamafactory-a3:
+    profiles: ["a3"]
+    build:
+      dockerfile: ./docker/docker-npu/Dockerfile
+      context: ../..
+      args:
+        BASE_IMAGE: quay.io/ascend/cann:8.3.rc2-a3-ubuntu22.04-py3.11
+        PIP_INDEX: https://pypi.org/simple
+    container_name: llamafactory-a3
+    image: llamafactory:npu-a3
+    volumes:
+      - /usr/local/dcmi:/usr/local/dcmi
+      - /usr/local/bin/npu-smi:/usr/local/bin/npu-smi
+      - /usr/local/Ascend/driver:/usr/local/Ascend/driver
+      - /etc/ascend_install.info:/etc/ascend_install.info
+    ports:
+      - "7861:7860"
+      - "8001:8000"
+    ipc: host
+    tty: true
+    # shm_size: "16gb" # ipc: host is set
+    stdin_open: true
+    command: bash
+    devices:
+      - /dev/davinci0
+      - /dev/davinci_manager
+      - /dev/devmm_svm
+      - /dev/hisi_hdc
+    restart: unless-stopped
@@ -4,7 +4,6 @@ FROM ${BASE_IMAGE}
 
 # Installation arguments
 ARG PIP_INDEX=https://pypi.org/simple
-ARG EXTRAS=metrics
 ARG INSTALL_FLASHATTN=false
 ARG HTTP_PROXY=""
 ARG PYTORCH_INDEX=https://download.pytorch.org/whl/rocm6.3
@@ -28,21 +27,14 @@ WORKDIR /app
 # Change pip source
 RUN pip config set global.index-url "${PIP_INDEX}" && \
     pip config set global.extra-index-url "${PIP_INDEX}" && \
-    pip install --no-cache-dir --upgrade pip packaging wheel setuptools
+    pip install --no-cache-dir --upgrade pip packaging wheel setuptools editables "hatchling>=1.18.0"
 
-# Reinstall pytorch rocm
-RUN pip uninstall -y torch torchvision torchaudio && \
-    pip install --no-cache-dir --pre torch torchvision torchaudio --index-url "${PYTORCH_INDEX}"
-
-# Install the requirements
-COPY requirements.txt /app
-RUN pip install --no-cache-dir -r requirements.txt
-
-# Copy the rest of the application into the image
+# Copy the application into the image
 COPY . /app
 
-# Install LLaMA Factory
-RUN pip install --no-cache-dir -e ".[${EXTRAS}]" --no-build-isolation
+# Reinstall pytorch rocm and install LLaMA Factory
+RUN pip uninstall -y torch torchvision torchaudio && \
+    pip install --no-cache-dir --no-build-isolation -e --pre ".[metrics,deepspeed]" --index-url "${PYTORCH_INDEX}"
 
 # Rebuild flash attention
 RUN if [ "${INSTALL_FLASHATTN}" == "true" ]; then \
@@ -5,7 +5,6 @@ services:
       context: ../..
       args:
         PIP_INDEX: https://pypi.org/simple
-        EXTRAS: metrics
     container_name: llamafactory
     ports:
       - "7860:7860"
22 examples/accelerate/fsdp2_config.yaml Normal file
@@ -0,0 +1,22 @@
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
fsdp_config:
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_cpu_ram_efficient_loading: true
  fsdp_offload_params: false
  fsdp_reshard_after_forward: true
  fsdp_state_dict_type: FULL_STATE_DICT
  fsdp_version: 2
machine_rank: 0
main_training_function: main
mixed_precision: bf16 # or fp16
num_machines: 1 # the number of nodes
num_processes: 2 # the number of GPUs in all nodes
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
34 examples/accelerate/fsdp_config_multiple_nodes.yaml Normal file
@@ -0,0 +1,34 @@
# If you want to run this example on multiple nodes, you need to set the following parameters:
# - num_machines: the number of nodes
# - num_processes: the number of GPUs in all nodes, num_machines * num_processes_per_machine
# - main_process_ip: the IP address of the main process, please keep it the same across all nodes
# - main_process_port: the port of all nodes, please keep it the same across all nodes
# - machine_rank: the rank of the current machine, starting from 0, and it should be 0 for main_process_ip

compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
fsdp_config:
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_backward_prefetch: BACKWARD_PRE
  fsdp_forward_prefetch: false
  fsdp_cpu_ram_efficient_loading: true
  fsdp_offload_params: false
  fsdp_sharding_strategy: FULL_SHARD
  fsdp_state_dict_type: FULL_STATE_DICT
  fsdp_sync_module_states: true
  fsdp_use_orig_params: true
machine_rank: 0
main_training_function: main
mixed_precision: bf16 # or fp16
main_process_ip: 192.168.0.1
main_process_port: 29500
num_machines: 2 # the number of nodes
num_processes: 16 # the number of GPUs in all nodes, num_machines * num_processes_per_machine
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
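The header comments above list the per-node knobs; a rough two-node launch could look like the sketch below. The training recipe path is a placeholder, and the only assumption is the `accelerate launch` pattern already used in the Ascend examples later in this diff.

```yaml
# Two-node launch sketch (training YAML path is a placeholder)
# On node 0 (192.168.0.1), keep machine_rank: 0 in this file and run:
#   accelerate launch --config_file examples/accelerate/fsdp_config_multiple_nodes.yaml \
#     src/train.py examples/train_full/llama3_full_sft.yaml
# On node 1, set machine_rank: 1 (same main_process_ip / main_process_port) and run the same command.
```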
45 examples/ascend/qwen3_full_sft_fsdp2.yaml Normal file
@@ -0,0 +1,45 @@
# Start FSDP2 fine-tuning
# accelerate launch \
#   --config_file examples/accelerate/fsdp2_config.yaml \
#   src/train.py examples/ascend/qwen3_full_sft_fsdp2.yaml
# Change `num_processes` in fsdp2_config.yaml to 16 in A3

### model
model_name_or_path: Qwen/Qwen3-8B
trust_remote_code: true
use_v1_kernels: true
flash_attn: fa2

### method
stage: sft
do_train: true
finetuning_type: full

### dataset
dataset: alpaca_en_demo
template: qwen3
cutoff_len: 2048
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16
dataloader_num_workers: 4

### output
output_dir: saves/Qwen3-8B/full/sft
logging_steps: 1
save_steps: 500
max_steps: 500
plot_loss: true
overwrite_output_dir: true
save_only_model: false
report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]

### train
per_device_train_batch_size: 8
gradient_accumulation_steps: 1
learning_rate: 1.0e-5
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
ddp_timeout: 1800
resume_from_checkpoint: null
46 examples/ascend/qwen3moe_full_sft_fsdp.yaml Normal file
@@ -0,0 +1,46 @@
# Start FSDP fine-tuning
# accelerate launch \
#   --config_file examples/accelerate/fsdp_config.yaml \
#   src/train.py examples/ascend/qwen3moe_full_sft_fsdp.yaml
# Change `num_processes` in fsdp_config.yaml to 16 in A3

### model
model_name_or_path: Qwen/Qwen3-30B-A3B-Instruct-2507
trust_remote_code: true
use_v1_kernels: true
flash_attn: fa2

### method
stage: sft
do_train: true
finetuning_type: full
disable_gradient_checkpointing: false

### dataset
dataset: alpaca_zh
template: qwen3
cutoff_len: 1024
overwrite_cache: true
preprocessing_num_workers: 16
dataloader_num_workers: 4

### output
output_dir: saves/Qwen3-30B-A3B-Instruct-2507/full/sft
logging_steps: 1
save_steps: 500
max_steps: 500
plot_loss: true
overwrite_output_dir: true
save_only_model: true
report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]

### train
per_device_train_batch_size: 4
gradient_accumulation_steps: 1
learning_rate: 1.0e-4
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
ddp_timeout: 180000000
resume_from_checkpoint: null
seed: 1234
48 examples/ascend/qwen3vlmoe_full_sft_fsdp2.yaml Normal file
@@ -0,0 +1,48 @@
# Start FSDP2 fine-tuning
# accelerate launch \
#   --config_file examples/accelerate/fsdp2_config.yaml \
#   src/train.py examples/ascend/qwen3vlmoe_full_sft_fsdp2.yaml
# Change `num_processes` in fsdp2_config.yaml to 16 in A3

### model
model_name_or_path: Qwen/Qwen3-VL-30B-A3B-Instruct
image_max_pixels: 262144
video_max_pixels: 16384
trust_remote_code: true
use_v1_kernels: true
flash_attn: fa2

### method
stage: sft
do_train: true
finetuning_type: full
disable_gradient_checkpointing: false

### dataset
dataset: llava_1k_en, llava_1k_zh
template: qwen3_vl
cutoff_len: 1024
overwrite_cache: true
preprocessing_num_workers: 16
dataloader_num_workers: 4

### output
output_dir: saves/Qwen3-VL-30B-A3B-Instruct/full/sft
logging_steps: 1
save_steps: 500
max_steps: 500
plot_loss: true
overwrite_output_dir: true
save_only_model: true
report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]

### train
per_device_train_batch_size: 2
gradient_accumulation_steps: 1
learning_rate: 1.0e-4
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
ddp_timeout: 180000000
resume_from_checkpoint: null
seed: 1234
42 examples/ascend/qwen3vlmoe_lora_sft_fsdp.yaml Normal file
@@ -0,0 +1,42 @@
### model
model_name_or_path: Qwen/Qwen3-VL-30B-A3B-Instruct
image_max_pixels: 262144
video_max_pixels: 16384
trust_remote_code: true
use_v1_kernels: true # replaced kernels: [NpuRMSNormKernel, NpuRoPEKernel, NpuQwen3VLMoEFusedMoEKernel]

### method
stage: sft
do_train: true
finetuning_type: lora
lora_rank: 8
lora_target: all
disable_gradient_checkpointing: false
flash_attn: disabled

### dataset
dataset: alpaca_zh_demo, alpaca_en_demo
template: qwen3_vl
cutoff_len: 1024
overwrite_cache: true
preprocessing_num_workers: 16
dataloader_num_workers: 4

### output
output_dir: saves/qwen3vlmoe/lora/sft
logging_steps: 1
plot_loss: true
overwrite_output_dir: true
save_only_model: true
report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]

### train
per_device_train_batch_size: 8
gradient_accumulation_steps: 1
learning_rate: 1.0e-4
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
ddp_timeout: 180000000
resume_from_checkpoint: null
seed: 1234
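Unlike its sibling Ascend recipes, this LoRA config does not carry a launch comment; assuming it follows the same pattern as the full-SFT examples above, a launch sketch would be:

```yaml
# Launch sketch (assumption, mirroring the sibling Ascend examples above)
# accelerate launch \
#   --config_file examples/accelerate/fsdp_config.yaml \
#   src/train.py examples/ascend/qwen3vlmoe_lora_sft_fsdp.yaml
```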
32 examples/deepspeed/ds_z2_autotp_config.json Normal file
@@ -0,0 +1,32 @@
{
  "_comment": "supported model list: https://www.deepspeed.ai/tutorials/automatic-tensor-parallelism/#supported-models",
  "train_batch_size": "auto",
  "train_micro_batch_size_per_gpu": "auto",
  "gradient_accumulation_steps": "auto",
  "gradient_clipping": "auto",
  "zero_allow_untested_optimizer": true,
  "fp16": {
    "enabled": "auto",
    "loss_scale": 0,
    "loss_scale_window": 1000,
    "initial_scale_power": 16,
    "hysteresis": 2,
    "min_loss_scale": 1
  },
  "bf16": {
    "enabled": "auto"
  },
  "zero_optimization": {
    "stage": 2,
    "allgather_partitions": true,
    "allgather_bucket_size": 5e8,
    "overlap_comm": false,
    "reduce_scatter": true,
    "reduce_bucket_size": 5e8,
    "contiguous_gradients": true,
    "round_robin_gradients": true
  },
  "tensor_parallel": {
    "autotp_size": 2
  }
}
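Like the other DeepSpeed presets shipped in `examples/deepspeed/`, this file is referenced from a training YAML through the `deepspeed` field. A minimal sketch (model, dataset, and output path are placeholders):

```yaml
# Sketch: pointing a training recipe at the ZeRO-2 + AutoTP preset above
model_name_or_path: Qwen/Qwen3-8B            # placeholder model
stage: sft
do_train: true
finetuning_type: full
dataset: alpaca_en_demo                      # placeholder dataset
template: qwen3
deepspeed: examples/deepspeed/ds_z2_autotp_config.json
output_dir: saves/qwen3-8b/full/sft          # placeholder output dir
bf16: true
```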
10 examples/inference/deepseek2_lora_sft_kt.yaml Normal file
@@ -0,0 +1,10 @@
model_name_or_path: deepseek-ai/DeepSeek-V2-Lite
adapter_name_or_path: saves/Kllama_deepseekV2
template: deepseek
infer_backend: ktransformers # choices: [huggingface, vllm, sglang, ktransformers]
trust_remote_code: true

use_kt: true # use KTransformers as LoRA sft backend to inference
kt_optimize_rule: examples/kt_optimize_rules/DeepSeek-V2-Lite-Chat-sft-amx.yaml
cpu_infer: 32
chunk_size: 8192
9 examples/inference/deepseek3_kt.yaml Normal file
@@ -0,0 +1,9 @@
model_name_or_path: opensourcerelease/DeepSeek-V3-bf16
template: deepseek
infer_backend: ktransformers # choices: [huggingface, vllm, sglang, ktransformers]
trust_remote_code: true

use_kt: true # use KTransformers as LoRA sft backend to inference
kt_optimize_rule: examples/kt_optimize_rules/DeepSeek-V3-Chat-sft-amx-multi-gpu.yaml
cpu_infer: 32
chunk_size: 8192
10 examples/inference/deepseek3_lora_sft_kt.yaml Normal file
@@ -0,0 +1,10 @@
model_name_or_path: opensourcerelease/DeepSeek-V3-bf16
adapter_name_or_path: saves/Kllama_deepseekV3
template: deepseek
infer_backend: ktransformers # choices: [huggingface, vllm, sglang, ktransformers]
trust_remote_code: true

use_kt: true # use KTransformers as LoRA sft backend to inference
kt_optimize_rule: examples/kt_optimize_rules/DeepSeek-V3-Chat-sft-amx-multi-gpu.yaml
cpu_infer: 32
chunk_size: 8192
@@ -1,4 +1,4 @@
 model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
 template: llama3
-infer_backend: huggingface # choices: [huggingface, vllm, sglang]
+infer_backend: huggingface # choices: [huggingface, vllm, sglang, ktransformers]
 trust_remote_code: true
@@ -1,4 +1,4 @@
 model_name_or_path: saves/llama3-8b/full/sft
 template: llama3
-infer_backend: huggingface # choices: [huggingface, vllm, sglang]
+infer_backend: huggingface # choices: [huggingface, vllm, sglang, ktransformers]
 trust_remote_code: true
@@ -1,5 +1,5 @@
 model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
 adapter_name_or_path: saves/llama3-8b/lora/sft
 template: llama3
-infer_backend: huggingface # choices: [huggingface, vllm, sglang]
+infer_backend: huggingface # choices: [huggingface, vllm, sglang, ktransformers]
 trust_remote_code: true
@@ -1,4 +1,4 @@
 model_name_or_path: Qwen/Qwen2.5-VL-7B-Instruct
 template: qwen2_vl
-infer_backend: huggingface # choices: [huggingface, vllm, sglang]
+infer_backend: huggingface # choices: [huggingface, vllm, sglang, ktransformers]
 trust_remote_code: true
10 examples/inference/qwen3moe_lora_sft_kt.yaml Normal file
@@ -0,0 +1,10 @@
model_name_or_path: Qwen/Qwen3-235B-A22B-Instruct-2507
adapter_name_or_path: saves/Kllama_Qwen3MoE_235bA22b
template: qwen3_nothink
infer_backend: ktransformers # choices: [huggingface, vllm, sglang, ktransformers]
trust_remote_code: true

use_kt: true # use KTransformers as LoRA sft backend to inference
kt_optimize_rule: examples/kt_optimize_rules/Qwen3Moe-sft-amx.yaml
cpu_infer: 32
chunk_size: 8192
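These KTransformers inference configs are ordinary LLaMA Factory inference YAMLs, so presumably they are passed to the standard CLI entry points; a usage sketch:

```yaml
# Usage sketch (standard CLI commands; the config path is the file above)
#   llamafactory-cli chat examples/inference/qwen3moe_lora_sft_kt.yaml
#   llamafactory-cli api examples/inference/qwen3moe_lora_sft_kt.yaml
```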
69 examples/kt_optimize_rules/DeepSeek-V2-Chat-sft-amx.yaml Normal file
@@ -0,0 +1,69 @@
|
|||||||
|
- match:
|
||||||
|
class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding
|
||||||
|
replace:
|
||||||
|
class: ktransformers.operators.RoPE.YarnRotaryEmbedding
|
||||||
|
kwargs:
|
||||||
|
generate_device: "cuda"
|
||||||
|
prefill_device: "cuda"
|
||||||
|
- match:
|
||||||
|
name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression
|
||||||
|
class: torch.nn.Linear # only match modules matching name and class simultaneously
|
||||||
|
replace:
|
||||||
|
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
|
||||||
|
kwargs:
|
||||||
|
generate_device: "cuda"
|
||||||
|
prefill_device: "cuda"
|
||||||
|
generate_op: "KLinearTorch"
|
||||||
|
prefill_op: "KLinearTorch"
|
||||||
|
|
||||||
|
- match:
|
||||||
|
name: "^lm_head"
|
||||||
|
class: torch.nn.Linear
|
||||||
|
replace:
|
||||||
|
class: ktransformers.operators.linear.KTransformersLinear
|
||||||
|
kwargs:
|
||||||
|
generate_device: "cuda"
|
||||||
|
prefill_device: "cuda"
|
||||||
|
generate_op: "KLinearTorch"
|
||||||
|
prefill_op: "KLinearTorch"
|
||||||
|
|
||||||
|
- match:
|
||||||
|
name: "^model\\.layers\\..*\\.mlp$"
|
||||||
|
class: ktransformers.models.modeling_deepseek.DeepseekV2MoE
|
||||||
|
replace:
|
||||||
|
class: ktransformers.operators.experts.KDeepseekV2MoE # mlp module with custom forward function
|
||||||
|
kwargs:
|
||||||
|
generate_device: "cuda"
|
||||||
|
prefill_device: "cuda"
|
||||||
|
- match:
|
||||||
|
name: "^model\\.layers\\..*\\.mlp\\.experts$"
|
||||||
|
replace:
|
||||||
|
class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert paralleism
|
||||||
|
kwargs:
|
||||||
|
prefill_device: "cuda"
|
||||||
|
prefill_op: "KExpertsTorch"
|
||||||
|
generate_device: "cpu"
|
||||||
|
generate_op: "KSFTExpertsCPU"
|
||||||
|
out_device: "cuda"
|
||||||
|
backend: "AMXInt8" # or "AMXBF16" or "llamafile" (default)
|
||||||
|
recursive: False # don't recursively inject submodules of this module
|
||||||
|
- match:
|
||||||
|
name: "^model\\.layers\\..*\\.self_attn$"
|
||||||
|
replace:
|
||||||
|
class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation
|
||||||
|
kwargs:
|
||||||
|
generate_device: "cuda"
|
||||||
|
prefill_device: "cuda"
|
||||||
|
- match:
|
||||||
|
name: "^model$"
|
||||||
|
replace:
|
||||||
|
class: "ktransformers.operators.models.KDeepseekV2Model"
|
||||||
|
kwargs:
|
||||||
|
per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill
|
||||||
|
- match:
|
||||||
|
name: "^model.embed_tokens"
|
||||||
|
replace:
|
||||||
|
class: "default"
|
||||||
|
kwargs:
|
||||||
|
generate_device: "cpu"
|
||||||
|
prefill_device: "cpu"
|
||||||
68 examples/kt_optimize_rules/DeepSeek-V2-Chat.yaml Normal file
@@ -0,0 +1,68 @@
|
|||||||
|
- match:
|
||||||
|
class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding
|
||||||
|
replace:
|
||||||
|
class: ktransformers.operators.RoPE.YarnRotaryEmbedding
|
||||||
|
kwargs:
|
||||||
|
generate_device: "cuda"
|
||||||
|
prefill_device: "cuda"
|
||||||
|
- match:
|
||||||
|
name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression
|
||||||
|
class: torch.nn.Linear # only match modules matching name and class simultaneously
|
||||||
|
replace:
|
||||||
|
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
|
||||||
|
kwargs:
|
||||||
|
generate_device: "cuda"
|
||||||
|
prefill_device: "cuda"
|
||||||
|
generate_op: "KLinearMarlin"
|
||||||
|
prefill_op: "KLinearTorch"
|
||||||
|
|
||||||
|
- match:
|
||||||
|
name: "^lm_head"
|
||||||
|
class: torch.nn.Linear
|
||||||
|
replace:
|
||||||
|
class: ktransformers.operators.linear.KTransformersLinear
|
||||||
|
kwargs:
|
||||||
|
generate_device: "cuda"
|
||||||
|
prefill_device: "cuda"
|
||||||
|
generate_op: "KLinearMarlin"
|
||||||
|
prefill_op: "KLinearTorch"
|
||||||
|
|
||||||
|
- match:
|
||||||
|
name: "^model\\.layers\\..*\\.mlp$"
|
||||||
|
class: ktransformers.models.modeling_deepseek.DeepseekV2MoE
|
||||||
|
replace:
|
||||||
|
class: ktransformers.operators.experts.KDeepseekV2MoE # mlp module with custom forward function
|
||||||
|
kwargs:
|
||||||
|
generate_device: "cuda"
|
||||||
|
prefill_device: "cuda"
|
||||||
|
- match:
|
||||||
|
name: "^model\\.layers\\..*\\.mlp\\.experts$"
|
||||||
|
replace:
|
||||||
|
class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert paralleism
|
||||||
|
kwargs:
|
||||||
|
prefill_device: "cuda"
|
||||||
|
prefill_op: "KExpertsTorch"
|
||||||
|
generate_device: "cpu"
|
||||||
|
generate_op: "KExpertsCPU"
|
||||||
|
out_device: "cuda"
|
||||||
|
recursive: False # don't recursively inject submodules of this module
|
||||||
|
- match:
|
||||||
|
name: "^model\\.layers\\..*\\.self_attn$"
|
||||||
|
replace:
|
||||||
|
class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation
|
||||||
|
kwargs:
|
||||||
|
generate_device: "cuda"
|
||||||
|
prefill_device: "cuda"
|
||||||
|
- match:
|
||||||
|
name: "^model$"
|
||||||
|
replace:
|
||||||
|
class: "ktransformers.operators.models.KDeepseekV2Model"
|
||||||
|
kwargs:
|
||||||
|
per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill
|
||||||
|
- match:
|
||||||
|
name: "^model.embed_tokens"
|
||||||
|
replace:
|
||||||
|
class: "default"
|
||||||
|
kwargs:
|
||||||
|
generate_device: "cpu"
|
||||||
|
prefill_device: "cpu"
|
||||||
@@ -0,0 +1,139 @@
|
|||||||
|
- match:
|
||||||
|
name: "^model.embed_tokens"
|
||||||
|
replace:
|
||||||
|
class: "default"
|
||||||
|
kwargs:
|
||||||
|
generate_device: "cpu"
|
||||||
|
prefill_device: "cpu"
|
||||||
|
|
||||||
|
- match:
|
||||||
|
name: "^model\\.layers\\.(0|[1-9])\\."
|
||||||
|
class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding
|
||||||
|
replace:
|
||||||
|
class: ktransformers.operators.RoPE.YarnRotaryEmbedding
|
||||||
|
kwargs:
|
||||||
|
generate_device: "cuda:0"
|
||||||
|
prefill_device: "cuda:0"
|
||||||
|
- match:
|
||||||
|
name: "^model\\.layers\\.([12][0-9])\\."
|
||||||
|
class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding
|
||||||
|
replace:
|
||||||
|
class: ktransformers.operators.RoPE.YarnRotaryEmbedding
|
||||||
|
kwargs:
|
||||||
|
generate_device: "cuda:1"
|
||||||
|
prefill_device: "cuda:1"
|
||||||
|
|
||||||
|
- match:
|
||||||
|
name: "^model\\.layers\\.(0|[1-9])\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression
|
||||||
|
class: torch.nn.Linear # only match modules matching name and class simultaneously
|
||||||
|
replace:
|
||||||
|
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
|
||||||
|
kwargs:
|
||||||
|
generate_device: "cuda:0"
|
||||||
|
prefill_device: "cuda:0"
|
||||||
|
generate_op: "KLinearTorch"
|
||||||
|
prefill_op: "KLinearTorch"
|
||||||
|
|
||||||
|
- match:
|
||||||
|
name: "^model\\.layers\\.([12][0-9])\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression
|
||||||
|
class: torch.nn.Linear # only match modules matching name and class simultaneously
|
||||||
|
replace:
|
||||||
|
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
|
||||||
|
kwargs:
|
||||||
|
generate_device: "cuda:1"
|
||||||
|
prefill_device: "cuda:1"
|
||||||
|
generate_op: "KLinearTorch"
|
||||||
|
prefill_op: "KLinearTorch"
|
||||||
|
|
||||||
|
- match:
|
||||||
|
name: "^model\\.layers\\.(0|[1-9])\\.mlp$"
|
||||||
|
class: ktransformers.models.modeling_deepseek.DeepseekV2MoE
|
||||||
|
replace:
|
||||||
|
class: ktransformers.operators.experts.KDeepseekV2MoE # mlp module with custom forward function
|
||||||
|
kwargs:
|
||||||
|
generate_device: "cuda:0"
|
||||||
|
prefill_device: "cuda:0"
|
||||||
|
- match:
|
||||||
|
name: "^model\\.layers\\.([12][0-9])\\.mlp$"
|
||||||
|
class: ktransformers.models.modeling_deepseek.DeepseekV2MoE
|
||||||
|
replace:
|
||||||
|
class: ktransformers.operators.experts.KDeepseekV2MoE # mlp module with custom forward function
|
||||||
|
kwargs:
|
||||||
|
generate_device: "cuda:1"
|
||||||
|
prefill_device: "cuda:1"
|
||||||
|
|
||||||
|
- match:
|
||||||
|
name: "^model\\.layers\\.(0|[1-9])\\.mlp\\.experts$"
|
||||||
|
replace:
|
||||||
|
class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert paralleism
|
||||||
|
kwargs:
|
||||||
|
prefill_device: "cuda:0"
|
||||||
|
prefill_op: "KExpertsTorch"
|
||||||
|
generate_device: "cpu"
|
||||||
|
generate_op: "KSFTExpertsCPU"
|
||||||
|
out_device: "cuda:0"
|
||||||
|
backend: "AMXInt8" # or "AMXBF16" or "llamafile" (default)
|
||||||
|
recursive: False # don't recursively inject submodules of this module
|
||||||
|
|
||||||
|
- match:
|
||||||
|
name: "^model\\.layers\\.([12][0-9])\\.mlp\\.experts$"
|
||||||
|
replace:
|
||||||
|
class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert paralleism
|
||||||
|
kwargs:
|
||||||
|
prefill_device: "cuda:1"
|
||||||
|
prefill_op: "KExpertsTorch"
|
||||||
|
generate_device: "cpu"
|
||||||
|
generate_op: "KSFTExpertsCPU"
|
||||||
|
out_device: "cuda:1"
|
||||||
|
backend: "AMXInt8" # or "AMXBF16" or "llamafile" (default)
|
||||||
|
recursive: False # don't recursively inject submodules of this module
|
||||||
|
|
||||||
|
- match:
|
||||||
|
name: "^model\\.layers\\.(0|[1-9])\\.self_attn$"
|
||||||
|
replace:
|
||||||
|
class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation
|
||||||
|
kwargs:
|
||||||
|
generate_device: "cuda:0"
|
||||||
|
prefill_device: "cuda:0"
|
||||||
|
- match:
|
||||||
|
name: "^model\\.layers\\.([12][0-9])\\.self_attn$"
|
||||||
|
replace:
|
||||||
|
class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation
|
||||||
|
kwargs:
|
||||||
|
generate_device: "cuda:1"
|
||||||
|
prefill_device: "cuda:1"
|
||||||
|
- match:
|
||||||
|
name: "^model$"
|
||||||
|
replace:
|
||||||
|
class: "ktransformers.operators.models.KDeepseekV2Model"
|
||||||
|
kwargs:
|
||||||
|
per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill
|
||||||
|
transfer_map:
|
||||||
|
10: "cuda:1"
|
||||||
|
|
||||||
|
- match:
|
||||||
|
name: "^model\\.layers\\.(0|[1-9])\\."
|
||||||
|
replace:
|
||||||
|
class: "default"
|
||||||
|
kwargs:
|
||||||
|
generate_device: "cuda:0"
|
||||||
|
prefill_device: "cuda:0"
|
||||||
|
|
||||||
|
- match:
|
||||||
|
name: "^lm_head"
|
||||||
|
class: torch.nn.Linear
|
||||||
|
replace:
|
||||||
|
class: ktransformers.operators.linear.KTransformersLinear
|
||||||
|
kwargs:
|
||||||
|
generate_device: "cuda:1"
|
||||||
|
prefill_device: "cuda:1"
|
||||||
|
generate_op: "KLinearTorch"
|
||||||
|
prefill_op: "KLinearTorch"
|
||||||
|
|
||||||
|
- match:
|
||||||
|
name: "(^model\\.layers\\.([12][0-9])\\.)|(model.norm)"
|
||||||
|
replace:
|
||||||
|
class: "default"
|
||||||
|
kwargs:
|
||||||
|
generate_device: "cuda:1"
|
||||||
|
prefill_device: "cuda:1"
|
||||||
@@ -0,0 +1,69 @@
|
|||||||
|
- match:
|
||||||
|
class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding
|
||||||
|
replace:
|
||||||
|
class: ktransformers.operators.RoPE.YarnRotaryEmbedding
|
||||||
|
kwargs:
|
||||||
|
generate_device: "cuda"
|
||||||
|
prefill_device: "cuda"
|
||||||
|
- match:
|
||||||
|
name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression
|
||||||
|
class: torch.nn.Linear # only match modules matching name and class simultaneously
|
||||||
|
replace:
|
||||||
|
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
|
||||||
|
kwargs:
|
||||||
|
generate_device: "cuda"
|
||||||
|
prefill_device: "cuda"
|
||||||
|
generate_op: "KLinearTorch"
|
||||||
|
prefill_op: "KLinearTorch"
|
||||||
|
|
||||||
|
- match:
|
||||||
|
name: "^lm_head"
|
||||||
|
class: torch.nn.Linear
|
||||||
|
replace:
|
||||||
|
class: ktransformers.operators.linear.KTransformersLinear
|
||||||
|
kwargs:
|
||||||
|
generate_device: "cuda"
|
||||||
|
prefill_device: "cuda"
|
||||||
|
generate_op: "KLinearTorch"
|
||||||
|
prefill_op: "KLinearTorch"
|
||||||
|
|
||||||
|
- match:
|
||||||
|
name: "^model\\.layers\\..*\\.mlp$"
|
||||||
|
class: ktransformers.models.modeling_deepseek.DeepseekV2MoE
|
||||||
|
replace:
|
||||||
|
class: ktransformers.operators.experts.KDeepseekV2MoE # mlp module with custom forward function
|
||||||
|
kwargs:
|
||||||
|
generate_device: "cuda"
|
||||||
|
prefill_device: "cuda"
|
||||||
|
- match:
|
||||||
|
name: "^model\\.layers\\..*\\.mlp\\.experts$"
|
||||||
|
replace:
|
||||||
|
class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert paralleism
|
||||||
|
kwargs:
|
||||||
|
prefill_device: "cpu"
|
||||||
|
prefill_op: "KExpertsTorch"
|
||||||
|
generate_device: "cpu"
|
||||||
|
generate_op: "KSFTExpertsCPU"
|
||||||
|
out_device: "cuda"
|
||||||
|
backend: "AMXInt8" # or "AMXBF16" or "llamafile" (default)
|
||||||
|
recursive: False # don't recursively inject submodules of this module
|
||||||
|
- match:
|
||||||
|
name: "^model\\.layers\\..*\\.self_attn$"
|
||||||
|
replace:
|
||||||
|
class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation
|
||||||
|
kwargs:
|
||||||
|
generate_device: "cuda"
|
||||||
|
prefill_device: "cuda"
|
||||||
|
- match:
|
||||||
|
name: "^model$"
|
||||||
|
replace:
|
||||||
|
class: "ktransformers.operators.models.KDeepseekV2Model"
|
||||||
|
kwargs:
|
||||||
|
per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill
|
||||||
|
- match:
|
||||||
|
name: "^model.embed_tokens"
|
||||||
|
replace:
|
||||||
|
class: "default"
|
||||||
|
kwargs:
|
||||||
|
generate_device: "cpu"
|
||||||
|
prefill_device: "cpu"
|
||||||
68 examples/kt_optimize_rules/DeepSeek-V2-Lite-Chat-sft.yaml Normal file
@@ -0,0 +1,68 @@
|
|||||||
|
- match:
|
||||||
|
class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding
|
||||||
|
replace:
|
||||||
|
class: ktransformers.operators.RoPE.YarnRotaryEmbedding
|
||||||
|
kwargs:
|
||||||
|
generate_device: "cuda"
|
||||||
|
prefill_device: "cuda"
|
||||||
|
- match:
|
||||||
|
name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression
|
||||||
|
class: torch.nn.Linear # only match modules matching name and class simultaneously
|
||||||
|
replace:
|
||||||
|
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
|
||||||
|
kwargs:
|
||||||
|
generate_device: "cuda"
|
||||||
|
prefill_device: "cuda"
|
||||||
|
generate_op: "KLinearTorch"
|
||||||
|
prefill_op: "KLinearTorch"
|
||||||
|
|
||||||
|
- match:
|
||||||
|
name: "^lm_head"
|
||||||
|
class: torch.nn.Linear
|
||||||
|
replace:
|
||||||
|
class: ktransformers.operators.linear.KTransformersLinear
|
||||||
|
kwargs:
|
||||||
|
generate_device: "cuda"
|
||||||
|
prefill_device: "cuda"
|
||||||
|
generate_op: "KLinearTorch"
|
||||||
|
prefill_op: "KLinearTorch"
|
||||||
|
|
||||||
|
- match:
|
||||||
|
name: "^model\\.layers\\..*\\.mlp$"
|
||||||
|
class: ktransformers.models.modeling_deepseek.DeepseekV2MoE
|
||||||
|
replace:
|
||||||
|
class: ktransformers.operators.experts.KDeepseekV2MoE # mlp module with custom forward function
|
||||||
|
kwargs:
|
||||||
|
generate_device: "cuda"
|
||||||
|
prefill_device: "cuda"
|
||||||
|
- match:
|
||||||
|
name: "^model\\.layers\\..*\\.mlp\\.experts$"
|
||||||
|
replace:
|
||||||
|
class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert paralleism
|
||||||
|
kwargs:
|
||||||
|
prefill_device: "cpu"
|
||||||
|
prefill_op: "KExpertsTorch"
|
||||||
|
generate_device: "cpu"
|
||||||
|
generate_op: "KSFTExpertsCPU"
|
||||||
|
out_device: "cuda"
|
||||||
|
recursive: False # don't recursively inject submodules of this module
|
||||||
|
- match:
|
||||||
|
name: "^model\\.layers\\..*\\.self_attn$"
|
||||||
|
replace:
|
||||||
|
class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation
|
||||||
|
kwargs:
|
||||||
|
generate_device: "cuda"
|
||||||
|
prefill_device: "cuda"
|
||||||
|
- match:
|
||||||
|
name: "^model$"
|
||||||
|
replace:
|
||||||
|
class: "ktransformers.operators.models.KDeepseekV2Model"
|
||||||
|
kwargs:
|
||||||
|
per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill
|
||||||
|
- match:
|
||||||
|
name: "^model.embed_tokens"
|
||||||
|
replace:
|
||||||
|
class: "default"
|
||||||
|
kwargs:
|
||||||
|
generate_device: "cpu"
|
||||||
|
prefill_device: "cpu"
|
||||||
examples/kt_optimize_rules/DeepSeek-V2-Lite-Chat.yaml (new file, 68 lines)
@@ -0,0 +1,68 @@
- match:
    class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding
  replace:
    class: ktransformers.operators.RoPE.YarnRotaryEmbedding
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
- match:
    name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression
    class: torch.nn.Linear # only match modules matching name and class simultaneously
  replace:
    class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
      generate_op: "KLinearMarlin"
      prefill_op: "KLinearTorch"

- match:
    name: "^lm_head"
    class: torch.nn.Linear
  replace:
    class: ktransformers.operators.linear.KTransformersLinear
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
      generate_op: "KLinearMarlin"
      prefill_op: "KLinearTorch"

- match:
    name: "^model\\.layers\\..*\\.mlp$"
    class: ktransformers.models.modeling_deepseek.DeepseekV2MoE
  replace:
    class: ktransformers.operators.experts.KDeepseekV2MoE # mlp module with custom forward function
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
- match:
    name: "^model\\.layers\\..*\\.mlp\\.experts$"
  replace:
    class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert paralleism
    kwargs:
      prefill_device: "cuda"
      prefill_op: "KExpertsTorch"
      generate_device: "cpu"
      generate_op: "KExpertsCPU"
      out_device: "cuda"
      recursive: False # don't recursively inject submodules of this module
- match:
    name: "^model\\.layers\\..*\\.self_attn$"
  replace:
    class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
- match:
    name: "^model$"
  replace:
    class: "ktransformers.operators.models.KDeepseekV2Model"
    kwargs:
      per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill
- match:
    name: "^model.embed_tokens"
  replace:
    class: "default"
    kwargs:
      generate_device: "cpu"
      prefill_device: "cpu"
examples/kt_optimize_rules/DeepSeek-V3-Chat-amx.yaml (new file, 77 lines)
@@ -0,0 +1,77 @@
- match:
    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
  replace:
    class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"

- match:
    name: "^lm_head$" # regular expression
    class: torch.nn.Linear # only match modules matching name and class simultaneously
  replace:
    class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
      generate_op: "KLinearMarlin"
      prefill_op: "KLinearTorch"

- match:
    name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression
    class: torch.nn.Linear # only match modules matching name and class simultaneously
  replace:
    class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
      generate_op: "KLinearMarlin"
      prefill_op: "KLinearTorch"
- match:
    name: "^model\\.layers\\..*\\.mlp$"
    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
  replace:
    class: ktransformers.operators.experts.KDeepseekV3MoE # mlp module with custom forward function
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
- match:
    class: ktransformers.models.modeling_deepseek_v3.MoEGate
  replace:
    class: ktransformers.operators.gate.KMoEGate
    kwargs:
      generate_device: "cuda:0"
      prefill_device: "cuda:0"
- match:
    name: "^model\\.layers\\..*\\.mlp\\.experts$"
  replace:
    class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert paralleism
    kwargs:
      prefill_device: "cuda"
      prefill_op: "KExpertsTorch"
      generate_device: "cpu"
      generate_op: "KExpertsCPU"
      out_device: "cuda"
      backend: "AMXInt8" # or "AMXBF16" or "llamafile" (default)
      recursive: False # don't recursively inject submodules of this module
- match:
    name: "^model\\.layers\\..*\\.self_attn$"
  replace:
    class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
      absorb_for_prefill: False # change this to True to enable long context(prefill may slower).
- match:
    name: "^model$"
  replace:
    class: "ktransformers.operators.models.KDeepseekV2Model"
    kwargs:
      per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill
- match:
    name: "^model.embed_tokens"
  replace:
    class: "default"
    kwargs:
      generate_device: "cpu"
      prefill_device: "cpu"
@@ -0,0 +1,392 @@
- match:
    name: "^model.embed_tokens"
  replace:
    class: "default"
    kwargs:
      generate_device: "cpu"
      prefill_device: "cpu"

# === Rotary Embedding Replacement ===

# GPU 0: layers 0–14
- match:
    name: "^model\\.layers\\.([0-9]|1[0-4])\\."
    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
  replace:
    class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
    kwargs:
      generate_device: "cuda:0"
      prefill_device: "cuda:0"

# GPU 1: layers 15–29
- match:
    name: "^model\\.layers\\.(1[5-9]|2[0-9])\\."
    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
  replace:
    class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
    kwargs:
      generate_device: "cuda:1"
      prefill_device: "cuda:1"

# GPU 2: layers 30–44
- match:
    name: "^model\\.layers\\.(3[0-9]|4[0-4])\\."
    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
  replace:
    class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
    kwargs:
      generate_device: "cuda:2"
      prefill_device: "cuda:2"

# GPU 3: layers 45–60
- match:
    name: "^model\\.layers\\.(4[5-9]|5[0-9]|60)\\."
    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
  replace:
    class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
    kwargs:
      generate_device: "cuda:3"
      prefill_device: "cuda:3"

# === Linear Layers Replacement (excluding self_attn.kv_b_proj) ===

# GPU 0: layers 0–14
- match:
    name: "^model\\.layers\\.([0-9]|1[0-4])\\.(?!self_attn\\.kv_b_proj).*$"
    class: torch.nn.Linear
  replace:
    class: ktransformers.operators.linear.KTransformersLinear
    kwargs:
      generate_device: "cuda:0"
      prefill_device: "cuda:0"
      generate_op: "KLinearTorch"
      prefill_op: "KLinearTorch"

# GPU 1: layers 15–29
- match:
    name: "^model\\.layers\\.(1[5-9]|2[0-9])\\.(?!self_attn\\.kv_b_proj).*$"
    class: torch.nn.Linear
  replace:
    class: ktransformers.operators.linear.KTransformersLinear
    kwargs:
      generate_device: "cuda:1"
      prefill_device: "cuda:1"
      generate_op: "KLinearTorch"
      prefill_op: "KLinearTorch"

# GPU 2: layers 30–44
- match:
    name: "^model\\.layers\\.(3[0-9]|4[0-4])\\.(?!self_attn\\.kv_b_proj).*$"
    class: torch.nn.Linear
  replace:
    class: ktransformers.operators.linear.KTransformersLinear
    kwargs:
      generate_device: "cuda:2"
      prefill_device: "cuda:2"
      generate_op: "KLinearTorch"
      prefill_op: "KLinearTorch"

# GPU 3: layers 45–60
- match:
    name: "^model\\.layers\\.(4[5-9]|5[0-9]|60)\\.(?!self_attn\\.kv_b_proj).*$"
    class: torch.nn.Linear
  replace:
    class: ktransformers.operators.linear.KTransformersLinear
    kwargs:
      generate_device: "cuda:3"
      prefill_device: "cuda:3"
      generate_op: "KLinearTorch"
      prefill_op: "KLinearTorch"

# === MLP (MoE) Replacement ===

# GPU 0: layers 0–14
- match:
    name: "^model\\.layers\\.([0-9]|1[0-4])\\.mlp$"
    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
  replace:
    class: ktransformers.operators.experts.KDeepseekV3MoE
    kwargs:
      generate_device: "cuda:0"
      prefill_device: "cuda:0"

# GPU 1: layers 15–29
- match:
    name: "^model\\.layers\\.(1[5-9]|2[0-9])\\.mlp$"
    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
  replace:
    class: ktransformers.operators.experts.KDeepseekV3MoE
    kwargs:
      generate_device: "cuda:1"
      prefill_device: "cuda:1"

# GPU 2: layers 30–44
- match:
    name: "^model\\.layers\\.(3[0-9]|4[0-4])\\.mlp$"
    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
  replace:
    class: ktransformers.operators.experts.KDeepseekV3MoE
    kwargs:
      generate_device: "cuda:2"
      prefill_device: "cuda:2"

# GPU 3: layers 45–60
- match:
    name: "^model\\.layers\\.(4[5-9]|5[0-9]|60)\\.mlp$"
    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
  replace:
    class: ktransformers.operators.experts.KDeepseekV3MoE
    kwargs:
      generate_device: "cuda:3"
      prefill_device: "cuda:3"

# === MLP Gate Replacement ===

# GPU 0: layers 0–14
- match:
    name: "^model\\.layers\\.([0-9]|1[0-4])\\.mlp\\.gate$"
    class: ktransformers.models.modeling_deepseek_v3.MoEGate
  replace:
    class: ktransformers.operators.gate.KMoEGate
    kwargs:
      generate_device: "cuda:0"
      prefill_device: "cuda:0"

# GPU 1: layers 15–29
- match:
    name: "^model\\.layers\\.(1[5-9]|2[0-9])\\.mlp\\.gate$"
    class: ktransformers.models.modeling_deepseek_v3.MoEGate
  replace:
    class: ktransformers.operators.gate.KMoEGate
    kwargs:
      generate_device: "cuda:1"
      prefill_device: "cuda:1"

# GPU 2: layers 30–44
- match:
    name: "^model\\.layers\\.(3[0-9]|4[0-4])\\.mlp\\.gate$"
    class: ktransformers.models.modeling_deepseek_v3.MoEGate
  replace:
    class: ktransformers.operators.gate.KMoEGate
    kwargs:
      generate_device: "cuda:2"
      prefill_device: "cuda:2"

# GPU 3: layers 45–60
- match:
    name: "^model\\.layers\\.(4[5-9]|5[0-9]|60)\\.mlp\\.gate$"
    class: ktransformers.models.modeling_deepseek_v3.MoEGate
  replace:
    class: ktransformers.operators.gate.KMoEGate
    kwargs:
      generate_device: "cuda:3"
      prefill_device: "cuda:3"

# === MLP Experts Replacement ===
# replace with marlin expert. Open and modify layer-num as needed.
# Each layer of malin experts takes about 6GB of GPU memory.
# !!!Do remember 'close' cuda graph if you are using marlin expert.!!!
# !!!KExpertsTorch is untested, we don't have enough VRAM.!!!

# GPU 0: layers 3–4
# - match:
#     name: "^model\\.layers\\.([3-4])\\.mlp\\.experts$"
#   replace:
#     class: ktransformers.operators.experts.KTransformersExperts
#     kwargs:
#       generate_device: "cuda:0"
#       generate_op: "KExpertsMarlin"
#       recursive: False

# # GPU 1: layers 15–17
# - match:
#     name: "^model\\.layers\\.(1[5-7])\\.mlp\\.experts$"
#   replace:
#     class: ktransformers.operators.experts.KTransformersExperts
#     kwargs:
#       generate_device: "cuda:1"
#       generate_op: "KExpertsMarlin"
#       recursive: False

# # GPU 2: layers 30–32
# - match:
#     name: "^model\\.layers\\.(3[0-2])\\.mlp\\.experts$"
#   replace:
#     class: ktransformers.operators.experts.KTransformersExperts
#     kwargs:
#       generate_device: "cuda:2"
#       generate_op: "KExpertsMarlin"
#       recursive: False

# # GPU 3: layers 45–46
# - match:
#     name: "^model\\.layers\\.(4[5-6])\\.mlp\\.experts$"
#   replace:
#     class: ktransformers.operators.experts.KTransformersExperts
#     kwargs:
#       generate_device: "cuda:3"
#       generate_op: "KExpertsMarlin"
#       recursive: False

# === MLP Experts Replacement ===

# GPU 0: layers 0–14
- match:
    name: "^model\\.layers\\.([0-9]|1[0-4])\\.mlp\\.experts$"
  replace:
    class: ktransformers.operators.experts.KTransformersExperts
    kwargs:
      prefill_device: "cuda:0"
      prefill_op: "KExpertsTorch"
      generate_device: "cpu"
      generate_op: "KSFTExpertsCPU"
      out_device: "cuda:0"
      backend: "AMXInt8" # or "AMXBF16" or "llamafile" (default)
      recursive: False

# GPU 1: layers 15–29
- match:
    name: "^model\\.layers\\.(1[5-9]|2[0-9])\\.mlp\\.experts$"
  replace:
    class: ktransformers.operators.experts.KTransformersExperts
    kwargs:
      prefill_device: "cuda:1"
      prefill_op: "KExpertsTorch"
      generate_device: "cpu"
      generate_op: "KSFTExpertsCPU"
      out_device: "cuda:1"
      backend: "AMXInt8" # or "AMXBF16" or "llamafile" (default)
      recursive: False

# GPU 2: layers 30–44
- match:
    name: "^model\\.layers\\.(3[0-9]|4[0-4])\\.mlp\\.experts$"
  replace:
    class: ktransformers.operators.experts.KTransformersExperts
    kwargs:
      prefill_device: "cuda:2"
      prefill_op: "KExpertsTorch"
      generate_device: "cpu"
      generate_op: "KSFTExpertsCPU"
      out_device: "cuda:2"
      backend: "AMXInt8" # or "AMXBF16" or "llamafile" (default)
      recursive: False

# GPU 3: layers 45–60
- match:
    name: "^model\\.layers\\.(4[5-9]|5[0-9]|60)\\.mlp\\.experts$"
  replace:
    class: ktransformers.operators.experts.KTransformersExperts
    kwargs:
      prefill_device: "cuda:3"
      prefill_op: "KExpertsTorch"
      generate_device: "cpu"
      generate_op: "KSFTExpertsCPU"
      out_device: "cuda:3"
      backend: "AMXInt8" # or "AMXBF16" or "llamafile" (default)
      recursive: False

# === Self-Attention Replacement ===

# GPU 0: layers 0–14
- match:
    name: "^model\\.layers\\.([0-9]|1[0-4])\\.self_attn$"
  replace:
    class: ktransformers.operators.attention.KDeepseekV2Attention
    kwargs:
      generate_device: "cuda:0"
      prefill_device: "cuda:0"
      absorb_for_prefill: False

# GPU 1: layers 15–29
- match:
    name: "^model\\.layers\\.(1[5-9]|2[0-9])\\.self_attn$"
  replace:
    class: ktransformers.operators.attention.KDeepseekV2Attention
    kwargs:
      generate_device: "cuda:1"
      prefill_device: "cuda:1"
      absorb_for_prefill: False

# GPU 2: layers 30–44
- match:
    name: "^model\\.layers\\.(3[0-9]|4[0-4])\\.self_attn$"
  replace:
    class: ktransformers.operators.attention.KDeepseekV2Attention
    kwargs:
      generate_device: "cuda:2"
      prefill_device: "cuda:2"
      absorb_for_prefill: False

# GPU 3: layers 45–60
- match:
    name: "^model\\.layers\\.(4[5-9]|5[0-9]|60)\\.self_attn$"
  replace:
    class: ktransformers.operators.attention.KDeepseekV2Attention
    kwargs:
      generate_device: "cuda:3"
      prefill_device: "cuda:3"
      absorb_for_prefill: False

# === Overall Model Replacement with Transfer Map ===

- match:
    name: "^model$"
  replace:
    class: "ktransformers.operators.models.KDeepseekV2Model"
    kwargs:
      per_layer_prefill_intput_threshold: 0 # 0 means close layer-wise prefill
      transfer_map:
        15: "cuda:1" # Layers 15+ on GPU 1
        30: "cuda:2" # Layers 30+ on GPU 2
        45: "cuda:3" # Layers 45+ on GPU 3

# === Default Catch-All for Other Modules ===

# GPU 0: layers 0–14
- match:
    name: "^model\\.layers\\.([0-9]|1[0-4])\\."
  replace:
    class: "default"
    kwargs:
      generate_device: "cuda:0"
      prefill_device: "cuda:0"

# GPU 1: layers 15–29
- match:
    name: "^model\\.layers\\.(1[5-9]|2[0-9])\\."
  replace:
    class: "default"
    kwargs:
      generate_device: "cuda:1"
      prefill_device: "cuda:1"

# GPU 2: layers 30–44
- match:
    name: "^model\\.layers\\.(3[0-9]|4[0-4])\\."
  replace:
    class: "default"
    kwargs:
      generate_device: "cuda:2"
      prefill_device: "cuda:2"

- match:
    name: "^lm_head"
    class: torch.nn.Linear
  replace:
    class: ktransformers.operators.linear.KTransformersLinear
    kwargs:
      generate_device: "cuda:3"
      prefill_device: "cuda:3"
      generate_op: "KLinearTorch"
      prefill_op: "KLinearTorch"

# For final modules (model.norm), ensure they are on GPU 3 (as in your original config)
- match:
    name: "(^model\\.layers\\.(4[5-9]|5[0-9]|60)\\.)|(^model\\.norm)"
  replace:
    class: "default"
    kwargs:
      generate_device: "cuda:3"
      prefill_device: "cuda:3"
@@ -0,0 +1,156 @@
- match:
    name: "^model.embed_tokens"
  replace:
    class: "default"
    kwargs:
      generate_device: "cpu"
      prefill_device: "cpu"

- match:
    name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\."
    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
  replace:
    class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
    kwargs:
      generate_device: "cuda:0"
      prefill_device: "cuda:0"
- match:
    name: "^model\\.layers\\.([3456][0-9])\\."
    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
  replace:
    class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
    kwargs:
      generate_device: "cuda:1"
      prefill_device: "cuda:1"

- match:
    name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.(?!self_attn\\.kv_b_proj).*$" # regular expression
    class: torch.nn.Linear # only match modules matching name and class simultaneously
  replace:
    class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
    kwargs:
      generate_device: "cuda:0"
      prefill_device: "cuda:0"
      generate_op: "KLinearTorch"
      prefill_op: "KLinearTorch"

- match:
    name: "^model\\.layers\\.([3456][0-9])\\.(?!self_attn\\.kv_b_proj).*$" # regular expression
    class: torch.nn.Linear # only match modules matching name and class simultaneously
  replace:
    class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
    kwargs:
      generate_device: "cuda:1"
      prefill_device: "cuda:1"
      generate_op: "KLinearTorch"
      prefill_op: "KLinearTorch"

- match:
    name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp$"
    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
  replace:
    class: ktransformers.operators.experts.KDeepseekV3MoE # mlp module with custom forward function
    kwargs:
      generate_device: "cuda:0"
      prefill_device: "cuda:0"
- match:
    name: "^model\\.layers\\.([3456][0-9])\\.mlp$"
    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
  replace:
    class: ktransformers.operators.experts.KDeepseekV3MoE # mlp module with custom forward function
    kwargs:
      generate_device: "cuda:1"
      prefill_device: "cuda:1"

- match:
    name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp\\.gate$"
    class: ktransformers.models.modeling_deepseek_v3.MoEGate
  replace:
    class: ktransformers.operators.gate.KMoEGate
    kwargs:
      generate_device: "cuda:0"
      prefill_device: "cuda:0"
- match:
    name: "^model\\.layers\\.([3456][0-9])\\.mlp\\.gate$"
    class: ktransformers.models.modeling_deepseek_v3.MoEGate
  replace:
    class: ktransformers.operators.gate.KMoEGate # mlp module with custom forward function
    kwargs:
      generate_device: "cuda:1"
      prefill_device: "cuda:1"

- match:
    name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp\\.experts$"
  replace:
    class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert paralleism
    kwargs:
      prefill_device: "cuda:0"
      prefill_op: "KExpertsTorch"
      generate_device: "cpu"
      generate_op: "KSFTExpertsCPU"
      out_device: "cuda:0"
      backend: "AMXInt8" # or "AMXBF16" or "llamafile" (default)
      recursive: False # don't recursively inject submodules of this module

- match:
    name: "^model\\.layers\\.([3456][0-9])\\.mlp\\.experts$"
  replace:
    class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert paralleism
    kwargs:
      prefill_device: "cuda:1"
      prefill_op: "KExpertsTorch"
      generate_device: "cpu"
      generate_op: "KSFTExpertsCPU"
      out_device: "cuda:1"
      backend: "AMXInt8" # or "AMXBF16" or "llamafile" (default)
      recursive: False # don't recursively inject submodules of this module

- match:
    name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.self_attn$"
  replace:
    class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation
    kwargs:
      generate_device: "cuda:0"
      prefill_device: "cuda:0"
- match:
    name: "^model\\.layers\\.([3456][0-9])\\.self_attn$"
  replace:
    class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation
    kwargs:
      generate_device: "cuda:1"
      prefill_device: "cuda:1"
- match:
    name: "^model$"
  replace:
    class: "ktransformers.operators.models.KDeepseekV2Model"
    kwargs:
      per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill
      transfer_map:
        30: "cuda:1"

- match:
    name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\."
  replace:
    class: "default"
    kwargs:
      generate_device: "cuda:0"
      prefill_device: "cuda:0"

- match:
    name: "^lm_head"
    class: torch.nn.Linear
  replace:
    class: ktransformers.operators.linear.KTransformersLinear
    kwargs:
      generate_device: "cuda:1"
      prefill_device: "cuda:1"
      generate_op: "KLinearTorch"
      prefill_op: "KLinearTorch"

- match:
    name: "(^model\\.layers\\.([3456][0-9])\\.)|(model.norm)"
  replace:
    class: "default"
    kwargs:
      generate_device: "cuda:1"
      prefill_device: "cuda:1"
examples/kt_optimize_rules/DeepSeek-V3-Chat-sft-amx.yaml (new file, 77 lines)
@@ -0,0 +1,77 @@
- match:
    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
  replace:
    class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"

- match:
    name: "^lm_head$" # regular expression
    class: torch.nn.Linear # only match modules matching name and class simultaneously
  replace:
    class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
      generate_op: "KLinearTorch"
      prefill_op: "KLinearTorch"

- match:
    name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression
    class: torch.nn.Linear # only match modules matching name and class simultaneously
  replace:
    class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
      generate_op: "KLinearTorch"
      prefill_op: "KLinearTorch"
- match:
    name: "^model\\.layers\\..*\\.mlp$"
    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
  replace:
    class: ktransformers.operators.experts.KDeepseekV3MoE # mlp module with custom forward function
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
- match:
    class: ktransformers.models.modeling_deepseek_v3.MoEGate
  replace:
    class: ktransformers.operators.gate.KMoEGate
    kwargs:
      generate_device: "cuda:0"
      prefill_device: "cuda:0"
- match:
    name: "^model\\.layers\\..*\\.mlp\\.experts$"
  replace:
    class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert paralleism
    kwargs:
      prefill_device: "cuda"
      prefill_op: "KExpertsTorch"
      generate_device: "cpu"
      generate_op: "KSFTExpertsCPU"
      out_device: "cuda"
      backend: "AMXInt8" # or "AMXBF16" or "llamafile" (default)
      recursive: False # don't recursively inject submodules of this module
- match:
    name: "^model\\.layers\\..*\\.self_attn$"
  replace:
    class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
      absorb_for_prefill: False # change this to True to enable long context(prefill may slower).
- match:
    name: "^model$"
  replace:
    class: "ktransformers.operators.models.KDeepseekV2Model"
    kwargs:
      per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill
- match:
    name: "^model.embed_tokens"
  replace:
    class: "default"
    kwargs:
      generate_device: "cpu"
      prefill_device: "cpu"
examples/kt_optimize_rules/Qwen3Moe-sft-amx.yaml (new file, 80 lines)
@@ -0,0 +1,80 @@
- match:
    class: ktransformers.models.modeling_qwen2_moe.Qwen2MoeRotaryEmbedding
  replace:
    class: ktransformers.operators.RoPE.RotaryEmbedding
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"

- match:
    name: "^lm_head$" # regular expression
    class: torch.nn.Linear # only match modules matching name and class simultaneously
  replace:
    class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
      generate_op: "KLinearTorch"
      prefill_op: "KLinearTorch"

# - match:
#     name: "^model\\.layers\\..*$" # regular expression
#     class: torch.nn.Linear # only match modules matching name and class simultaneously
#   replace:
#     class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
#     kwargs:
#       generate_device: "cuda"
#       prefill_device: "cuda"
#       generate_op: "KLinearTorch"
#       prefill_op: "KLinearTorch"
- match:
    name: "^model\\.layers\\.(?!.*mlp\\.shared_expert_gate).*$" # regular expression
    class: torch.nn.Linear # only match modules matching name and class simultaneously
  replace:
    class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
      generate_op: "KLinearTorch"
      prefill_op: "KLinearTorch"
- match:
    name: "^model\\.layers\\..*\\.mlp$"
  replace:
    class: ktransformers.operators.experts.KQwen3MoeSparseMoeBlock # mlp module with custom forward function
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"

- match:
    name: "^model\\.layers\\..*\\.mlp\\.experts$"
  replace:
    class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert paralleism
    kwargs:
      prefill_device: "cuda"
      prefill_op: "KExpertsTorch"
      generate_device: "cpu"
      generate_op: "KSFTExpertsCPU"
      out_device: "cuda"
      backend: "AMXInt8" # or "AMXBF16" or "AMXInt8"
      recursive: False # don't recursively inject submodules of this module
- match:
    name: "^model\\.layers\\..*\\.self_attn$"
  replace:
    class: ktransformers.operators.attention.KQwen3MoeAttention # optimized MLA implementation
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
- match:
    name: "^model.embed_tokens"
  replace:
    class: "default"
    kwargs:
      generate_device: "cpu"
      prefill_device: "cpu"

- match:
    name: "^model$"
  replace:
    class: "ktransformers.operators.models.KQwen3MoeModel"
    kwargs:
      per_layer_prefill_intput_threshold: 0
examples/megatron/qwen2_vl_full.yaml (new file, 29 lines)
@@ -0,0 +1,29 @@
model_name_or_path: Qwen/Qwen2-VL-7B-Instruct
image_max_pixels: 262144
video_max_pixels: 16384

do_train: true
stage: sft
finetuning_type: full # only support full for now
dataset: llava_1k_en
preprocessing_num_workers: 8
cutoff_len: 4096
template: qwen2_vl

output_dir: saves/mca/qwen2_vl_full
per_device_train_batch_size: 1
gradient_accumulation_steps: 2
num_train_epochs: 2
learning_rate: 2e-5
logging_steps: 1
save_steps: 100
lr_scheduler_type: cosine
bf16: true

# mcore speed up
tensor_model_parallel_size: 4
sequence_parallel: true
pipeline_model_parallel_size: 2
bias_activation_fusion: true
apply_rope_fusion: true
use_distributed_optimizer: true
examples/megatron/qwen3_moe_full.yaml (new file, 35 lines)
@@ -0,0 +1,35 @@
model_name_or_path: Qwen/Qwen3-30B-A3B-Instruct-2507

# GPU memory: 8 * 78GB
do_train: true
stage: sft
finetuning_type: full # only support full for now
dataset: alpaca_en_demo
preprocessing_num_workers: 8
cutoff_len: 4096
template: qwen3_nothink

# global batchsize = (8 // 2 // 4) * 8 = 8
output_dir: saves/mca/qwen3_moe_full
per_device_train_batch_size: 1
gradient_accumulation_steps: 8
num_train_epochs: 2
learning_rate: 3e-6
logging_steps: 1
save_steps: 100
lr_scheduler_type: constant
bf16: true

# mcore speed up
tensor_model_parallel_size: 1
sequence_parallel: false
pipeline_model_parallel_size: 4
bias_activation_fusion: true
apply_rope_fusion: true
use_distributed_optimizer: true
overlap_param_gather: true
overlap_grad_reduce: true
moe_grouped_gemm: true
moe_token_dispatcher_type: alltoall
expert_model_parallel_size: 2
recompute_granularity: full
examples/requirements/adam-mini.txt (new file, 1 line)
@@ -0,0 +1 @@
adam-mini

examples/requirements/apollo.txt (new file, 1 line)
@@ -0,0 +1 @@
apollo-torch

examples/requirements/aqlm.txt (new file, 1 line)
@@ -0,0 +1 @@
aqlm[gpu]>=1.1.0

examples/requirements/badam.txt (new file, 1 line)
@@ -0,0 +1 @@
badam>=1.2.1

examples/requirements/bitsandbytes.txt (new file, 1 line)
@@ -0,0 +1 @@
bitsandbytes>=0.39.0

examples/requirements/eetq.txt (new file, 1 line)
@@ -0,0 +1 @@
eetq

examples/requirements/fp8-te.txt (new file, 2 lines)
@@ -0,0 +1,2 @@
transformer_engine[pytorch]>=2.0.0
accelerate>=1.10.0

examples/requirements/fp8.txt (new file, 2 lines)
@@ -0,0 +1,2 @@
torchao>=0.8.0
accelerate>=1.10.0

examples/requirements/galore.txt (new file, 1 line)
@@ -0,0 +1 @@
galore-torch

examples/requirements/gptq.txt (new file, 2 lines)
@@ -0,0 +1,2 @@
optimum>=1.24.0
gptqmodel>=2.0.0

examples/requirements/hqq.txt (new file, 1 line)
@@ -0,0 +1 @@
hqq

examples/requirements/liger-kernel.txt (new file, 1 line)
@@ -0,0 +1 @@
liger-kernel>=0.5.5

examples/requirements/minicpm-v.txt (new file, 8 lines)
@@ -0,0 +1,8 @@
soundfile
torchvision
torchaudio
vector_quantize_pytorch
vocos
msgpack
referencing
jsonschema_specifications

examples/requirements/openmind.txt (new file, 1 line)
@@ -0,0 +1 @@
openmind

examples/requirements/sglang.txt (new file, 2 lines)
@@ -0,0 +1,2 @@
sglang[srt]>=0.4.5
transformers==4.51.1

examples/requirements/swanlab.txt (new file, 1 line)
@@ -0,0 +1 @@
swanlab

examples/requirements/vllm.txt (new file, 1 line)
@@ -0,0 +1 @@
vllm>=0.4.3,<=0.11.0
examples/train_full/qwen3_full_sft_autotp.yaml (new file, 46 lines)
@@ -0,0 +1,46 @@
### model
model_name_or_path: Qwen/Qwen3-32B
trust_remote_code: true
use_v1_kernels: true

### method
stage: sft
do_train: true
finetuning_type: full
deepspeed: examples/deepspeed/ds_z2_autotp_config.json

### dataset
dataset: identity,alpaca_en_demo
template: qwen3
cutoff_len: 2048
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16
dataloader_num_workers: 4

### output
output_dir: saves/qwen3-32b/full/sft_autotp
logging_steps: 1
save_steps: 500
plot_loss: true
overwrite_output_dir: true
save_only_model: false
report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]

### train
per_device_train_batch_size: 4
gradient_accumulation_steps: 1
learning_rate: 1.0e-4
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
ddp_timeout: 180000000
resume_from_checkpoint: null

### eval
# eval_dataset: alpaca_en_demo
# val_size: 0.1
# per_device_eval_batch_size: 1
# eval_strategy: steps
# eval_steps: 500
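A recipe like the one above is normally launched through the project's own CLI; a minimal sketch, assuming the file is saved at the path shown and the node has the GPUs the DeepSpeed AutoTP config expects, would be:

    FORCE_TORCHRUN=1 llamafactory-cli train examples/train_full/qwen3_full_sft_autotp.yaml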
examples/train_lora/deepseek2_lora_sft_kt.yaml (new file, 52 lines)
@@ -0,0 +1,52 @@
### model
model_name_or_path: deepseek-ai/DeepSeek-V2-Lite
trust_remote_code: true

### method
stage: sft
do_train: true
finetuning_type: lora
lora_rank: 8
lora_target: all

### dataset
dataset: identity
template: deepseek
cutoff_len: 2048
max_samples: 100000
overwrite_cache: true
preprocessing_num_workers: 16
dataloader_num_workers: 4

### output
output_dir: saves/Kllama_deepseekV2
logging_steps: 10
save_steps: 500
plot_loss: true
overwrite_output_dir: true
save_only_model: false
report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]

### train
per_device_train_batch_size: 1
gradient_accumulation_steps: 8
learning_rate: 1.0e-4
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
ddp_timeout: 180000000
resume_from_checkpoint: null

### ktransformers
use_kt: true # use KTransformers as LoRA sft backend
kt_optimize_rule: examples/kt_optimize_rules/DeepSeek-V2-Lite-Chat-sft-amx.yaml
cpu_infer: 32
chunk_size: 8192

### eval
# eval_dataset: alpaca_en_demo
# val_size: 0.1
# per_device_eval_batch_size: 1
# eval_strategy: steps
# eval_steps: 500
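A hedged usage sketch for the KTransformers-backed LoRA recipe above (it assumes the ktransformers backend and its CPU/AMX expert kernels are installed in the environment; the command itself is the standard single-config launch):

    llamafactory-cli train examples/train_lora/deepseek2_lora_sft_kt.yaml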
examples/train_lora/deepseek3_lora_sft_kt.yaml (new file, 52 lines)
@@ -0,0 +1,52 @@
### model
model_name_or_path: opensourcerelease/DeepSeek-V3-bf16
trust_remote_code: true

### method
stage: sft
do_train: true
finetuning_type: lora
lora_rank: 8
lora_target: all

### dataset
dataset: identity
template: deepseek
cutoff_len: 2048
max_samples: 100000
overwrite_cache: true
preprocessing_num_workers: 16
dataloader_num_workers: 4

### output
output_dir: saves/Kllama_deepseekV3
logging_steps: 10
save_steps: 500
plot_loss: true
overwrite_output_dir: true
save_only_model: false
report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]

### train
per_device_train_batch_size: 1
gradient_accumulation_steps: 8
learning_rate: 1.0e-4
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
ddp_timeout: 180000000
resume_from_checkpoint: null

### ktransformers
use_kt: true # use KTransformers as LoRA sft backend
kt_optimize_rule: examples/kt_optimize_rules/DeepSeek-V3-Chat-sft-amx-multi-gpu.yaml
cpu_infer: 32
chunk_size: 8192

### eval
# eval_dataset: alpaca_en_demo
# val_size: 0.1
# per_device_eval_batch_size: 1
# eval_strategy: steps
# eval_steps: 500
examples/train_lora/qwen3moe_lora_sft_kt.yaml (new file, 52 lines)
@@ -0,0 +1,52 @@
### model
model_name_or_path: Qwen/Qwen3-235B-A22B-Instruct-2507
trust_remote_code: true

### method
stage: sft
do_train: true
finetuning_type: lora
lora_rank: 8
lora_target: all

### dataset
dataset: identity, alpaca_en_demo
template: qwen3_nothink
cutoff_len: 2048
max_samples: 100000
overwrite_cache: true
preprocessing_num_workers: 16
dataloader_num_workers: 4

### output
output_dir: saves/Kllama_Qwen3MoE_235bA22b
logging_steps: 10
save_steps: 200
plot_loss: true
overwrite_output_dir: true
save_only_model: false
report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]

### train
per_device_train_batch_size: 1
gradient_accumulation_steps: 8
learning_rate: 1.0e-4
num_train_epochs: 3
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
ddp_timeout: 180000000
resume_from_checkpoint: null

### ktransformers
use_kt: true # use KTransformers as LoRA sft backend
kt_optimize_rule: examples/kt_optimize_rules/Qwen3Moe-sft-amx.yaml
cpu_infer: 32
chunk_size: 8192

### eval
# eval_dataset: alpaca_en_demo
# val_size: 0.1
# per_device_eval_batch_size: 1
# eval_strategy: steps
# eval_steps: 500
pyproject.toml (155 changed lines)
@@ -1,42 +1,123 @@
 [build-system]
-requires = ["setuptools>=61.0"]
-build-backend = "setuptools.build_meta"
+requires = ["hatchling"]
+build-backend = "hatchling.build"

 [project]
 name = "llamafactory"
-requires-python = ">=3.9.0"
-dynamic = [
-    "version",
-    "dependencies",
-    "optional-dependencies",
-    "scripts",
-    "authors",
-    "description",
-    "readme",
-    "license",
-    "keywords",
-    "classifiers"
-]
+dynamic = ["version"]
+description = "Unified Efficient Fine-Tuning of 100+ LLMs"
+readme = "README.md"
+license = "Apache-2.0"
+requires-python = ">=3.11.0"
+authors = [
+    { name = "hiyouga", email = "hiyouga@buaa.edu.cn" }
+]
+keywords = [
+    "AI",
+    "LLM",
+    "GPT",
+    "ChatGPT",
+    "Llama",
+    "Transformer",
+    "DeepSeek",
+    "Pytorch"
+]
+classifiers = [
+    "Development Status :: 4 - Beta",
+    "Intended Audience :: Developers",
+    "Intended Audience :: Education",
+    "Intended Audience :: Science/Research",
+    "License :: OSI Approved :: Apache Software License",
+    "Operating System :: OS Independent",
+    "Programming Language :: Python :: 3",
+    "Programming Language :: Python :: 3.10",
+    "Programming Language :: Python :: 3.11",
+    "Programming Language :: Python :: 3.12",
+    "Programming Language :: Python :: 3.13",
+    "Topic :: Scientific/Engineering :: Artificial Intelligence"
+]
+dependencies = [
+    # core deps
+    "torch>=2.4.0",
+    "torchvision>=0.19.0",
+    "torchaudio>=2.4.0",
+    "transformers>=4.49.0,<=4.56.2,!=4.52.0; python_version < '3.10'",
+    "transformers>=4.49.0,<=4.57.1,!=4.52.0,!=4.57.0; python_version >= '3.10'",
+    "datasets>=2.16.0,<=4.0.0",
+    "accelerate>=1.3.0,<=1.11.0",
+    "peft>=0.14.0,<=0.17.1",
+    "trl>=0.8.6,<=0.9.6",
+    "torchdata>=0.10.0,<=0.11.0",
+    # gui
+    "gradio>=4.38.0,<=5.50.0",
+    "matplotlib>=3.7.0",
+    "tyro<0.9.0",
+    # ops
+    "einops",
+    "numpy",
+    "pandas",
+    "scipy",
+    # model and tokenizer
+    "sentencepiece",
+    "tiktoken",
+    "modelscope",
+    "hf-transfer",
+    "safetensors",
+    # python
+    "av",
+    "fire",
+    "omegaconf",
+    "packaging",
+    "protobuf",
+    "pyyaml",
+    "pydantic",
+    # api
+    "uvicorn",
+    "fastapi",
+    "sse-starlette"
+]
+
+[project.optional-dependencies]
+dev = ["pre-commit", "ruff", "pytest", "build"]
+metrics = ["nltk", "jieba", "rouge-chinese"]
+deepspeed = ["deepspeed>=0.10.0,<=0.16.9"]
+
+[project.scripts]
+llamafactory-cli = "llamafactory.cli:main"
+lmf = "llamafactory.cli:main"
+
+[project.urls]
+Homepage = "https://github.com/hiyouga/LLaMA-Factory"
+Repository = "https://github.com/hiyouga/LLaMA-Factory"
+
+[tool.hatch.build.targets.wheel]
+packages = ["src/llamafactory"]
+
+[tool.hatch.version]
+path = "src/llamafactory/extras/env.py"
+pattern = "VERSION = \"(?P<version>[^\"]+)\""

 [tool.ruff]
-target-version = "py39"
+target-version = "py311"
 line-length = 119
 indent-width = 4

 [tool.ruff.lint]
 ignore = [
     "C408", # collection
     "C901", # complex
     "E501", # line too long
     "E731", # lambda function
     "E741", # ambiguous var name
+    "UP007", # no upgrade union
+    "UP045", # no upgrade optional
     "D100", # no doc public module
     "D101", # no doc public class
     "D102", # no doc public method
     "D103", # no doc public function
     "D104", # no doc public package
     "D105", # no doc magic method
     "D107", # no doc __init__
 ]
 extend-select = [
     "C", # complexity
@@ -73,23 +154,3 @@ indent-style = "space"
 docstring-code-format = true
 skip-magic-trailing-comma = false
 line-ending = "auto"
-
-[tool.uv]
-conflicts = [
-    [
-        { extra = "torch-npu" },
-        { extra = "aqlm" },
-    ],
-    [
-        { extra = "torch-npu" },
-        { extra = "vllm" },
-    ],
-    [
-        { extra = "torch-npu" },
-        { extra = "sglang" },
-    ],
-    [
-        { extra = "vllm" },
-        { extra = "sglang" },
-    ],
-]
|
|
||||||
|
|||||||
@@ -1,36 +0,0 @@
|
|||||||
# core deps
|
|
||||||
transformers>=4.49.0,<=4.56.2,!=4.52.0; python_version < '3.10'
|
|
||||||
transformers>=4.49.0,<=4.57.1,!=4.52.0; python_version >= '3.10'
|
|
||||||
datasets>=2.16.0,<=4.0.0
|
|
||||||
accelerate>=1.3.0,<=1.11.0
|
|
||||||
peft>=0.14.0,<=0.17.1
|
|
||||||
trl>=0.8.6,<=0.9.6
|
|
||||||
# gui
|
|
||||||
gradio>=4.38.0,<=5.45.0
|
|
||||||
matplotlib>=3.7.0
|
|
||||||
tyro<0.9.0
|
|
||||||
# ops
|
|
||||||
einops
|
|
||||||
numpy<2.0.0
|
|
||||||
pandas>=2.0.0
|
|
||||||
scipy
|
|
||||||
# model and tokenizer
|
|
||||||
sentencepiece
|
|
||||||
tiktoken
|
|
||||||
modelscope>=1.14.0
|
|
||||||
hf-transfer
|
|
||||||
safetensors<=0.5.3
|
|
||||||
# python
|
|
||||||
fire
|
|
||||||
omegaconf
|
|
||||||
packaging
|
|
||||||
protobuf
|
|
||||||
pyyaml
|
|
||||||
pydantic<=2.10.6
|
|
||||||
# api
|
|
||||||
uvicorn
|
|
||||||
fastapi
|
|
||||||
sse-starlette
|
|
||||||
# media
|
|
||||||
av
|
|
||||||
librosa
|
|
||||||
scripts/megatron_merge.py (new file, 124 lines)
@@ -0,0 +1,124 @@
# Copyright 2025 the ROLL team and the LlamaFactory team.
#
# This code is modified from the ROLL library.
# https://github.com/alibaba/ROLL/blob/main/mcore_adapter/tools/convert.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import fire
import torch
from mcore_adapter.models.converter.post_converter import convert_checkpoint_to_hf, convert_checkpoint_to_mca
from mcore_adapter.training_args import DistributingParallelArguments
from mcore_adapter.utils import get_logger
from transformers import AutoConfig


logger = get_logger(__name__)


def convert_mca_to_hf(
    checkpoint_path: str,
    output_path: str = "./output",
    bf16: bool = False,
    fp16: bool = False,
    convert_model_max_length: int | None = None,
):
    """Convert megatron checkpoint to HuggingFace format.

    Args:
        checkpoint_path: Path to the checkpoint to convert
        output_path: Path to save the converted checkpoint
        bf16: Use bfloat16 precision
        fp16: Use float16 precision
        convert_model_max_length: Change the model_max_length in hf config.json
    """
    if bf16 and fp16:
        raise ValueError("bf16 and fp16 cannot be both True.")

    torch_dtype = None
    if bf16:
        torch_dtype = torch.bfloat16
    elif fp16:
        torch_dtype = torch.float16

    convert_checkpoint_to_hf(checkpoint_path, output_path, torch_dtype=torch_dtype)

    if convert_model_max_length is not None:
        config = AutoConfig.from_pretrained(output_path, trust_remote_code=True)
        config.model_max_length = convert_model_max_length
        config.save_pretrained(output_path)


def convert(
    checkpoint_path: str,
    output_path: str = "./output",
    bf16: bool = False,
    fp16: bool = False,
    convert_model_max_length: int | None = None,
    tensor_model_parallel_size: int = 1,
    pipeline_model_parallel_size: int = 1,
    expert_model_parallel_size: int = 1,
    virtual_pipeline_model_parallel_size: int | None = None,
):
    """Convert checkpoint between MCA and HuggingFace formats.

    Args:
        checkpoint_path: Path to the checkpoint to convert
        output_path: Path to save the converted checkpoint
        bf16: Use bfloat16 precision
        fp16: Use float16 precision
        convert_model_max_length: Change the model_max_length in hf config.json
        tensor_model_parallel_size: Tensor model parallel size
        pipeline_model_parallel_size: Pipeline model parallel size
        expert_model_parallel_size: Expert model parallel size
        virtual_pipeline_model_parallel_size: Virtual pipeline model parallel size
    """
    if bf16 and fp16:
        raise ValueError("bf16 and fp16 cannot be both True.")

    mca_config_path = os.path.join(checkpoint_path, "mca_config.json")
    from_mca = os.path.exists(mca_config_path)

    if not from_mca:
        dist_args = DistributingParallelArguments(
            tensor_model_parallel_size=tensor_model_parallel_size,
            pipeline_model_parallel_size=pipeline_model_parallel_size,
            expert_model_parallel_size=expert_model_parallel_size,
            virtual_pipeline_model_parallel_size=virtual_pipeline_model_parallel_size,
        )

        convert_checkpoint_to_mca(
            checkpoint_path,
            output_path,
            dist_args,
            bf16=bf16,
            fp16=fp16,
        )
    else:
        convert_mca_to_hf(
            checkpoint_path=checkpoint_path,
            output_path=output_path,
            bf16=bf16,
            fp16=fp16,
            convert_model_max_length=convert_model_max_length,
        )


def main():
    fire.Fire(convert)


if __name__ == "__main__":
    main()
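A minimal usage sketch of the converter above (the checkpoint paths are illustrative and not taken from the repository; the script is normally driven through `fire` on the command line):

    # Hypothetical programmatic invocation; assumes scripts/ is on PYTHONPATH.
    # The direction of conversion is decided by whether mca_config.json exists
    # inside checkpoint_path (MCA -> HF) or not (HF -> MCA).
    from megatron_merge import convert

    convert(
        checkpoint_path="saves/deepseek-mcore",  # illustrative path
        output_path="saves/deepseek-hf",
        bf16=True,
    )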
@@ -14,7 +14,7 @@
 import json
 from dataclasses import dataclass
-from typing import Any, Literal, Optional
+from typing import Any, Literal

 import fire
 import torch
@@ -61,7 +61,7 @@ def calculate_ppl(
     dataset_dir: str = "data",
     template: str = "default",
     cutoff_len: int = 2048,
-    max_samples: Optional[int] = None,
+    max_samples: int | None = None,
     train_on_prompt: bool = False,
 ):
     r"""Calculate the ppl on the dataset of the pre-trained models.
@@ -14,8 +14,8 @@
 import gc
 import json
-from typing import Optional

+import av
 import fire
 from tqdm import tqdm
 from transformers import Seq2SeqTrainingArguments
@@ -33,6 +33,14 @@ if is_vllm_available():
     from vllm.lora.request import LoRARequest


+def _need_video_kwargs(template):
+    NEEDED_TEMPLATE = ["qwen3_vl", "glm4v"]
+    if any(t in template for t in NEEDED_TEMPLATE):
+        return True
+
+    return False
+
+
 def vllm_infer(
     model_name_or_path: str,
     adapter_name_or_path: str = None,
@@ -40,7 +48,7 @@ def vllm_infer(
     dataset_dir: str = "data",
     template: str = "default",
     cutoff_len: int = 2048,
-    max_samples: Optional[int] = None,
+    max_samples: int | None = None,
     vllm_config: str = "{}",
     save_name: str = "generated_predictions.jsonl",
     temperature: float = 0.95,
@@ -49,9 +57,9 @@ def vllm_infer(
     max_new_tokens: int = 1024,
     repetition_penalty: float = 1.0,
     skip_special_tokens: bool = True,
-    default_system: Optional[str] = None,
+    default_system: str | None = None,
     enable_thinking: bool = True,
-    seed: Optional[int] = None,
+    seed: int | None = None,
     pipeline_parallel_size: int = 1,
     image_max_pixels: int = 768 * 768,
     image_min_pixels: int = 32 * 32,
@@ -132,6 +140,7 @@ def vllm_infer(

     # Store all results in these lists
     all_prompts, all_preds, all_labels = [], [], []
+    need_video_kwargs = _need_video_kwargs(template)

     # Add batch process to avoid the issue of too many files opened
     for i in tqdm(range(0, len(train_dataset), batch_size), desc="Processing batched inference"):
@@ -147,6 +156,7 @@ def vllm_infer(
                     )["images"]
                 }
             elif batch["videos"][j] is not None:
+                video_metadata, video_metadata_kwargs = None, None
                 video = batch["videos"][j]
                 multi_modal_data = {
                     "video": template_obj.mm_plugin._regularize_videos(
@@ -157,6 +167,25 @@ def vllm_infer(
                         video_maxlen=video_maxlen,
                     )["videos"]
                 }
+                if need_video_kwargs:
+                    container = av.open(video[0], "r")
+                    video_stream = next(stream for stream in container.streams if stream.type == "video")
+                    sampling_indices = template_obj.mm_plugin._get_video_sample_indices(
+                        video_stream, video_fps, video_maxlen
+                    )
+                    total_frames = video_stream.frames
+                    video_metadata_kwargs = {
+                        "fps": getattr(tokenizer_module["processor"], "video_fps", 24.0),
+                        "do_sample_frames": False,
+                        "total_num_frames": total_frames,
+                    }
+                    video_metadata = dict(
+                        fps=video_fps,
+                        frames_indices=sampling_indices,
+                        total_num_frames=total_frames,
+                        video_backend="opencv",
+                    )
+                    multi_modal_data["video"] = (multi_modal_data["video"], video_metadata)
             elif batch["audios"][j] is not None:
                 audio = batch["audios"][j]
                 audio_data = template_obj.mm_plugin._regularize_audios(
@@ -167,7 +196,11 @@ def vllm_infer(
             else:
                 multi_modal_data = None

-            vllm_inputs.append({"prompt_token_ids": batch["input_ids"][j], "multi_modal_data": multi_modal_data})
+            vllm_input_data = {"prompt_token_ids": batch["input_ids"][j], "multi_modal_data": multi_modal_data}
+            if "video_metadata_kwargs" in locals() and video_metadata_kwargs is not None:
+                vllm_input_data["mm_processor_kwargs"] = video_metadata_kwargs
+
+            vllm_inputs.append(vllm_input_data)
             prompts.append(tokenizer.decode(batch["input_ids"][j], skip_special_tokens=skip_special_tokens))
             labels.append(
                 tokenizer.decode(
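For templates that need explicit video metadata (the `_need_video_kwargs` check above), each item appended to `vllm_inputs` ends up shaped roughly like this sketch; all concrete values below are illustrative placeholders, not repository output:

    # Illustrative structure only; the real values come from the regularized frames
    # and the sampled frame indices computed in the loop above.
    frames = []  # list of PIL images returned by _regularize_videos(...)["videos"][0]
    vllm_input_data = {
        "prompt_token_ids": [151644, 872],  # tokenized prompt (placeholder ids)
        "multi_modal_data": {
            "video": (frames, {"fps": 2.0, "frames_indices": [0, 8, 16], "total_num_frames": 48, "video_backend": "opencv"}),
        },
        "mm_processor_kwargs": {"fps": 24.0, "do_sample_frames": False, "total_num_frames": 48},
    }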
setup.py (deleted, 116 lines)
@@ -1,116 +0,0 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import re

from setuptools import find_packages, setup


def get_version() -> str:
    with open(os.path.join("src", "llamafactory", "extras", "env.py"), encoding="utf-8") as f:
        file_content = f.read()
    pattern = r"{}\W*=\W*\"([^\"]+)\"".format("VERSION")
    (version,) = re.findall(pattern, file_content)
    return version


def get_requires() -> list[str]:
    with open("requirements.txt", encoding="utf-8") as f:
        file_content = f.read()
    lines = [line.strip() for line in file_content.strip().split("\n") if not line.startswith("#")]
    return lines


def get_console_scripts() -> list[str]:
    console_scripts = ["llamafactory-cli = llamafactory.cli:main"]
    if os.getenv("ENABLE_SHORT_CONSOLE", "1").lower() in ["true", "y", "1"]:
        console_scripts.append("lmf = llamafactory.cli:main")

    return console_scripts


extra_require = {
    "torch": ["torch>=2.0.0", "torchvision>=0.15.0"],
    "torch-npu": ["torch-npu==2.5.1", "torchvision==0.20.1", "decorator"],
    "metrics": ["nltk", "jieba", "rouge-chinese"],
    "deepspeed": ["deepspeed>=0.10.0,<=0.16.9"],
    "liger-kernel": ["liger-kernel>=0.5.5"],
    "bitsandbytes": ["bitsandbytes>=0.39.0"],
    "hqq": ["hqq"],
    "eetq": ["eetq"],
    "gptq": ["optimum>=1.24.0", "gptqmodel>=2.0.0"],
    "aqlm": ["aqlm[gpu]>=1.1.0"],
    "vllm": ["vllm>=0.4.3,<=0.11.0"],
    "sglang": ["sglang[srt]>=0.4.5", "transformers==4.51.1"],
    "galore": ["galore-torch"],
    "apollo": ["apollo-torch"],
    "badam": ["badam>=1.2.1"],
    "adam-mini": ["adam-mini"],
    "minicpm_v": [
        "soundfile",
        "torchvision",
        "torchaudio",
        "vector_quantize_pytorch",
        "vocos",
        "msgpack",
        "referencing",
        "jsonschema_specifications",
    ],
    "openmind": ["openmind"],
    "swanlab": ["swanlab"],
    "fp8": ["torchao>=0.8.0", "accelerate>=1.10.0"],
    "fp8-te": ["transformer_engine[pytorch]>=2.0.0", "accelerate>=1.10.0"],
    "fp8-all": ["torchao>=0.8.0", "transformer_engine[pytorch]>=2.0.0", "accelerate>=1.10.0"],
    "dev": ["pre-commit", "ruff", "pytest", "build"],
}


def main():
    setup(
        name="llamafactory",
        version=get_version(),
        author="hiyouga",
        author_email="hiyouga@buaa.edu.cn",
        description="Unified Efficient Fine-Tuning of 100+ LLMs",
        long_description=open("README.md", encoding="utf-8").read(),
        long_description_content_type="text/markdown",
        keywords=["AI", "LLM", "GPT", "ChatGPT", "Llama", "Transformer", "DeepSeek", "Pytorch"],
        license="Apache 2.0 License",
        url="https://github.com/hiyouga/LLaMA-Factory",
        package_dir={"": "src"},
        packages=find_packages("src"),
        python_requires=">=3.9.0",
        install_requires=get_requires(),
        extras_require=extra_require,
        entry_points={"console_scripts": get_console_scripts()},
        classifiers=[
            "Development Status :: 4 - Beta",
            "Intended Audience :: Developers",
            "Intended Audience :: Education",
            "Intended Audience :: Science/Research",
            "License :: OSI Approved :: Apache Software License",
            "Operating System :: OS Independent",
            "Programming Language :: Python :: 3",
            "Programming Language :: Python :: 3.9",
            "Programming Language :: Python :: 3.10",
            "Programming Language :: Python :: 3.11",
            "Programming Language :: Python :: 3.12",
            "Topic :: Scientific/Engineering :: Artificial Intelligence",
        ],
    )


if __name__ == "__main__":
    main()
@@ -16,7 +16,7 @@ import asyncio
 import os
 from contextlib import asynccontextmanager
 from functools import partial
-from typing import Annotated, Optional
+from typing import Annotated

 from ..chat import ChatModel
 from ..extras.constants import EngineName
@@ -79,7 +79,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
     api_key = os.getenv("API_KEY")
     security = HTTPBearer(auto_error=False)

-    async def verify_api_key(auth: Annotated[Optional[HTTPAuthorizationCredentials], Depends(security)]):
+    async def verify_api_key(auth: Annotated[HTTPAuthorizationCredentials | None, Depends(security)]):
         if api_key and (auth is None or auth.credentials != api_key):
             raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key.")
@@ -14,10 +14,9 @@
 import time
 from enum import Enum, unique
-from typing import Any, Optional, Union
+from typing import Any, Literal

 from pydantic import BaseModel, Field
-from typing_extensions import Literal


 @unique
@@ -61,7 +60,7 @@ class FunctionDefinition(BaseModel):

 class FunctionAvailable(BaseModel):
     type: Literal["function", "code_interpreter"] = "function"
-    function: Optional[FunctionDefinition] = None
+    function: FunctionDefinition | None = None


 class FunctionCall(BaseModel):
@@ -77,35 +76,35 @@ class URL(BaseModel):

 class MultimodalInputItem(BaseModel):
     type: Literal["text", "image_url", "video_url", "audio_url"]
-    text: Optional[str] = None
-    image_url: Optional[URL] = None
-    video_url: Optional[URL] = None
-    audio_url: Optional[URL] = None
+    text: str | None = None
+    image_url: URL | None = None
+    video_url: URL | None = None
+    audio_url: URL | None = None


 class ChatMessage(BaseModel):
     role: Role
-    content: Optional[Union[str, list[MultimodalInputItem]]] = None
-    tool_calls: Optional[list[FunctionCall]] = None
+    content: str | list[MultimodalInputItem] | None = None
+    tool_calls: list[FunctionCall] | None = None


 class ChatCompletionMessage(BaseModel):
-    role: Optional[Role] = None
-    content: Optional[str] = None
-    tool_calls: Optional[list[FunctionCall]] = None
+    role: Role | None = None
+    content: str | None = None
+    tool_calls: list[FunctionCall] | None = None


 class ChatCompletionRequest(BaseModel):
     model: str
     messages: list[ChatMessage]
-    tools: Optional[list[FunctionAvailable]] = None
-    do_sample: Optional[bool] = None
-    temperature: Optional[float] = None
-    top_p: Optional[float] = None
+    tools: list[FunctionAvailable] | None = None
+    do_sample: bool | None = None
+    temperature: float | None = None
+    top_p: float | None = None
     n: int = 1
-    presence_penalty: Optional[float] = None
-    max_tokens: Optional[int] = None
-    stop: Optional[Union[str, list[str]]] = None
+    presence_penalty: float | None = None
+    max_tokens: int | None = None
+    stop: str | list[str] | None = None
     stream: bool = False


@@ -118,7 +117,7 @@ class ChatCompletionResponseChoice(BaseModel):
 class ChatCompletionStreamResponseChoice(BaseModel):
     index: int
     delta: ChatCompletionMessage
-    finish_reason: Optional[Finish] = None
+    finish_reason: Finish | None = None


 class ChatCompletionResponseUsage(BaseModel):
@@ -147,7 +146,7 @@ class ChatCompletionStreamResponse(BaseModel):
 class ScoreEvaluationRequest(BaseModel):
     model: str
     messages: list[str]
-    max_length: Optional[int] = None
+    max_length: int | None = None


 class ScoreEvaluationResponse(BaseModel):
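With the pydantic models above, a minimal chat completion request can be built as in the sketch below; the field values are illustrative, and `Role.USER` is assumed to be a member of the `Role` enum defined earlier in this module:

    # Sketch only; assumes the definitions from this protocol module are imported.
    request = ChatCompletionRequest(
        model="llamafactory-model",  # placeholder model name
        messages=[ChatMessage(role=Role.USER, content="What is the weather in Paris?")],
        temperature=0.7,
        max_tokens=256,
    )
    print(request.model_dump_json())  # pydantic v2 serialization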
@@ -71,6 +71,16 @@ class ChatModel:
                     "SGLang not install, you may need to run `pip install sglang[all]`\n"
                     "or try to use HuggingFace backend: --infer_backend huggingface"
                 ) from e
+        elif model_args.infer_backend == EngineName.KT:
+            try:
+                from .kt_engine import KTransformersEngine
+
+                self.engine: BaseEngine = KTransformersEngine(model_args, data_args, finetuning_args, generating_args)
+            except ImportError as e:
+                raise ImportError(
+                    "KTransformers not install, you may need to run `pip install ktransformers`\n"
+                    "or try to use HuggingFace backend: --infer_backend huggingface"
+                ) from e
         else:
             raise NotImplementedError(f"Unknown backend: {model_args.infer_backend}")
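A hedged sketch of how the new branch is reached; the argument dict mirrors the usual `ChatModel(args)` entry point, the model path is a placeholder, and the concrete string value behind `EngineName.KT` lives in `extras.constants`:

    # Illustrative only: construct a ChatModel with the KTransformers backend selected.
    from llamafactory.chat import ChatModel
    from llamafactory.extras.constants import EngineName

    chat_model = ChatModel({
        "model_name_or_path": "deepseek-ai/DeepSeek-V2-Lite-Chat",  # placeholder model
        "template": "deepseek",                                     # placeholder template
        "infer_backend": EngineName.KT,
    })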
@@ -14,9 +14,9 @@
 import asyncio
 import os
-from collections.abc import AsyncGenerator
+from collections.abc import AsyncGenerator, Callable
 from threading import Thread
-from typing import TYPE_CHECKING, Any, Callable, Optional, Union
+from typing import TYPE_CHECKING, Any, Optional, Union

 import torch
 from transformers import GenerationConfig, TextIteratorStreamer
src/llamafactory/chat/kt_engine.py (new file, 284 lines)
@@ -0,0 +1,284 @@
# Copyright 2025 the KVCache.AI team, Approaching AI, and the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import os
import platform
from collections.abc import AsyncGenerator
from threading import Thread
from typing import TYPE_CHECKING, Any, Optional

import torch
from typing_extensions import override

from ..data import get_template_and_fix_tokenizer
from ..extras import logging
from ..extras.constants import EngineName
from ..model import load_model, load_tokenizer
from .base_engine import BaseEngine, Response


if TYPE_CHECKING:
    from transformers import PreTrainedTokenizer
    from trl import PreTrainedModelWrapper

    from ..data.mm_plugin import AudioInput, ImageInput, VideoInput
    from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments

from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled
from ktransformers.server.config.config import Config
from ktransformers.util.utils import (
    get_compute_capability,
    prefill_and_generate_capture,
)
from ktransformers.util.vendors import GPUVendor, device_manager


logger = logging.get_logger(__name__)


class KTransformersEngine(BaseEngine):
    def __init__(
        self,
        model_args: "ModelArguments",
        data_args: "DataArguments",
        finetuning_args: "FinetuningArguments",
        generating_args: "GeneratingArguments",
    ) -> None:
        self.name = EngineName.KT
        self.can_generate = finetuning_args.stage == "sft"

        tok_mod = load_tokenizer(model_args)
        self.tokenizer = tok_mod["tokenizer"]
        self.tokenizer.padding_side = "left" if self.can_generate else "right"
        self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args)

        self.model = load_model(
            self.tokenizer, model_args, finetuning_args, is_trainable=False, add_valuehead=(not self.can_generate)
        )

        self.generating_args = generating_args.to_dict()
        self.max_new_tokens = model_args.kt_maxlen
        self.use_cuda_graph = model_args.kt_use_cuda_graph
        self.mode = model_args.kt_mode
        self.force_think = model_args.kt_force_think
        self.chunk_size = model_args.chunk_size

        try:
            asyncio.get_event_loop()
        except RuntimeError:
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)

        self.semaphore = asyncio.Semaphore(int(os.getenv("MAX_CONCURRENT", "1")))

    @staticmethod
    @torch.inference_mode()
    def _get_scores(
        model: "PreTrainedModelWrapper",
        tokenizer: "PreTrainedTokenizer",
        batch_input: list[str],
        input_kwargs: Optional[dict[str, Any]] = {},
    ) -> list[float]:
        max_length: Optional[int] = input_kwargs.pop("max_length", None)
        device = getattr(model.pretrained_model, "device", "cuda")
        inputs = tokenizer(
            batch_input,
            padding=True,
            truncation=True,
            max_length=max_length or getattr(model.config, "max_position_embeddings", 1024),
            return_tensors="pt",
            add_special_tokens=False,
        ).to(device)
        values: torch.Tensor = model(**inputs, return_dict=True, use_cache=False)[-1]
        scores = values.gather(dim=-1, index=(inputs["attention_mask"].sum(dim=-1, keepdim=True) - 1))
        return scores

    async def _generate(
        self,
        messages: list[dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        **input_kwargs,
    ) -> AsyncGenerator[str, None]:
        paired = messages + [{"role": "assistant", "content": ""}]
        prompt_ids, _ = self.template.encode_oneturn(self.tokenizer, paired, system, tools)
        prompt_len = len(prompt_ids)

        max_length: Optional[int] = input_kwargs.pop("max_length", None)
        max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None)

        if "max_new_tokens" in self.generating_args:
            max_tokens = int(self.generating_args["max_new_tokens"])
        elif "max_length" in self.generating_args:
            gl = int(self.generating_args["max_length"])
            max_tokens = gl - prompt_len if gl > prompt_len else 1
        else:
            max_tokens = self.max_new_tokens or 256

        if max_length is not None:
            max_tokens = max(max_length - prompt_len, 1)
        if max_new_tokens is not None:
            max_tokens = int(max_new_tokens)
        max_tokens = max(1, int(max_tokens))

        if self.mode == "long_context":
            max_len_cfg = Config().long_context_config["max_seq_len"]
            need = prompt_len + max_tokens
            assert max_len_cfg > need, f"please set max_seq_len > {need} in ~/.ktransformers/config.yaml"

        device = next(self.model.parameters()).device
        input_tensor = torch.tensor([prompt_ids], dtype=torch.long, device=device)
        if self.force_think:
            think = torch.tensor(
                [self.tokenizer.encode("<think>\n", add_special_tokens=False)], dtype=torch.long, device=device
            )
            input_tensor = torch.cat([input_tensor, think], dim=1)

        use_flashinfer = (
            platform.system() != "Windows"
            and getattr(self.model.config, "architectures", [""])[0]
            in {"DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"}
            and flashinfer_enabled
            and get_compute_capability() >= 8
            and device_manager.gpu_vendor == GPUVendor.NVIDIA
        )

        def make_gen():
            if use_flashinfer:
                return prefill_and_generate_capture(
                    self.model,
                    self.tokenizer,
                    input_tensor,
                    max_tokens,
                    self.use_cuda_graph,
                    mode=self.mode,
                    force_think=self.force_think,
                    chunk_size=self.chunk_size,
                    use_flashinfer_mla=True,
                    num_heads=self.model.config.num_attention_heads,
                    head_dim_ckv=getattr(self.model.config, "kv_lora_rank", 0),
                    head_dim_kpe=getattr(self.model.config, "qk_rope_head_dim", 0),
                    q_head_dim=getattr(self.model.config, "qk_rope_head_dim", 0)
                    + getattr(self.model.config, "qk_nope_head_dim", 0),
                    echo_stream=False,
                )
            else:
                return prefill_and_generate_capture(
                    self.model,
                    self.tokenizer,
                    input_tensor,
                    max_tokens,
                    self.use_cuda_graph,
                    mode=self.mode,
                    force_think=self.force_think,
                    chunk_size=self.chunk_size,
                    echo_stream=False,
                )

        loop = asyncio.get_running_loop()
        q: asyncio.Queue[Optional[str]] = asyncio.Queue()

        def producer():
            try:
                gen = make_gen()
                if hasattr(gen, "__aiter__"):

                    async def drain_async():
                        async for t in gen:
                            loop.call_soon_threadsafe(q.put_nowait, t if isinstance(t, str) else str(t))

                    asyncio.run(drain_async())
                elif hasattr(gen, "__iter__"):
                    for t in gen:
                        loop.call_soon_threadsafe(q.put_nowait, t if isinstance(t, str) else str(t))
                else:
                    loop.call_soon_threadsafe(q.put_nowait, gen if isinstance(gen, str) else str(gen))
            finally:
                loop.call_soon_threadsafe(q.put_nowait, None)

        Thread(target=producer, daemon=True).start()

        while True:
            item = await q.get()
            if item is None:
                break
            yield item

    @override
    async def chat(
        self,
        messages: list[dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        images: Optional[list["ImageInput"]] = None,
        videos: Optional[list["VideoInput"]] = None,
        audios: Optional[list["AudioInput"]] = None,
        **input_kwargs,
    ) -> list["Response"]:
        if not self.can_generate:
            raise ValueError("The current model does not support `chat`.")
        async with self.semaphore:
            produced = ""
            final_text = ""
            async for t in self._generate(messages, system, tools, **input_kwargs):
                delta = t
                produced = produced + delta
                if delta:
                    final_text += delta

            prompt_ids, _ = self.template.encode_oneturn(
                self.tokenizer, messages + [{"role": "assistant", "content": ""}], system, tools
            )
            return [
                Response(
                    response_text=final_text,
                    response_length=len(self.tokenizer.encode(final_text, add_special_tokens=False)),
                    prompt_length=len(prompt_ids),
                    finish_reason="stop",
                )
            ]

    @override
    async def stream_chat(
        self,
        messages: list[dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        images: Optional[list["ImageInput"]] = None,
        videos: Optional[list["VideoInput"]] = None,
        audios: Optional[list["AudioInput"]] = None,
        **input_kwargs,
    ) -> AsyncGenerator[str, None]:
        if not self.can_generate:
            raise ValueError("The current model does not support `stream_chat`.")
        async with self.semaphore:
            produced = ""
            async for t in self._generate(messages, system, tools, **input_kwargs):
                delta = t[len(produced) :] if t.startswith(produced) else t
                produced = t
                if delta:
                    yield delta

    @override
    async def get_scores(
        self,
        batch_input: list[str],
        **input_kwargs,
    ) -> list[float]:
        if self.can_generate:
            raise ValueError("Cannot get scores using an auto-regressive model.")
        args = (self.model, self.tokenizer, batch_input, input_kwargs)
        async with self.semaphore:
            return await asyncio.to_thread(self._get_scores, *args)
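A minimal sketch of consuming the streaming API defined above; it assumes `engine` is an already-constructed `KTransformersEngine` and omits the model-loading arguments:

    # Illustrative consumption of KTransformersEngine.stream_chat.
    import asyncio

    async def demo(engine):
        messages = [{"role": "user", "content": "Briefly explain MoE offloading."}]
        async for delta in engine.stream_chat(messages):
            print(delta, end="", flush=True)

    # asyncio.run(demo(engine))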
@@ -16,6 +16,7 @@ import uuid
 from collections.abc import AsyncGenerator, AsyncIterator
 from typing import TYPE_CHECKING, Any, Optional, Union

+from packaging import version
 from typing_extensions import override

 from ..data import get_template_and_fix_tokenizer
@@ -77,11 +78,18 @@ class VllmEngine(BaseEngine):
             "tensor_parallel_size": get_device_count() or 1,
             "gpu_memory_utilization": model_args.vllm_gpu_util,
             "disable_log_stats": True,
-            "disable_log_requests": True,
             "enforce_eager": model_args.vllm_enforce_eager,
             "enable_lora": model_args.adapter_name_or_path is not None,
             "max_lora_rank": model_args.vllm_max_lora_rank,
         }

+        import vllm
+
+        if version.parse(vllm.__version__) <= version.parse("0.10.0"):
+            engine_args["disable_log_requests"] = True
+        else:
+            engine_args["enable_log_requests"] = False
+
         if self.template.mm_plugin.__class__.__name__ != "BasePlugin":
             engine_args["limit_mm_per_prompt"] = {"image": 4, "video": 2, "audio": 2}
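The version gate can be checked in isolation; a small self-contained sketch of the same kwarg selection (vLLM itself does not need to be installed to exercise the pattern):

    # Generic sketch of the version gate used above: per the diff, vLLM <= 0.10.0
    # expects `disable_log_requests`, while newer releases use `enable_log_requests`.
    from packaging import version

    def logging_kwargs(vllm_version: str) -> dict:
        if version.parse(vllm_version) <= version.parse("0.10.0"):
            return {"disable_log_requests": True}
        return {"enable_log_requests": False}

    assert logging_kwargs("0.9.2") == {"disable_log_requests": True}
    assert logging_kwargs("0.11.0") == {"enable_log_requests": False}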
@@ -15,7 +15,7 @@ import json
 import os
 from abc import abstractmethod
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, Optional, Union
+from typing import TYPE_CHECKING, Any, Union

 from ..extras import logging
 from .data_utils import Role
@@ -40,7 +40,7 @@ class DatasetConverter:
     dataset_attr: "DatasetAttr"
     data_args: "DataArguments"

-    def _find_medias(self, medias: Union["MediaType", list["MediaType"], None]) -> Optional[list["MediaType"]]:
+    def _find_medias(self, medias: Union["MediaType", list["MediaType"], None]) -> list["MediaType"] | None:
         r"""Optionally concatenate media path to media dir when loading from local disk."""
         if medias is None:
             return None
@@ -81,41 +81,48 @@ def split_dataset(
     eval_dataset: Optional[Union["Dataset", "IterableDataset", dict[str, "Dataset"]]],
     data_args: "DataArguments",
     seed: int,
-) -> "DatasetDict":
-    r"""Split the dataset and returns a dataset dict containing train set and validation set.
+) -> tuple[dict, dict]:
+    r"""Split the dataset and returns two dicts containing train set and validation set.

     Support both map dataset and iterable dataset.
+
+    Returns:
+        train_dict: Dictionary containing training data with key "train"
+        eval_dict: Dictionary containing evaluation data with keys "validation" or "validation_{name}"
     """
     if eval_dataset is not None and data_args.val_size > 1e-6:
         raise ValueError("Cannot specify `val_size` if `eval_dataset` is not None.")

-    dataset_dict = {}
+    # keep the train and eval splits in separate dicts so they can be returned
+    # separately and handled independently by the caller
+    train_dict, eval_dict = {}, {}
+
     if dataset is not None:
         if data_args.streaming:
             dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=seed)

         if data_args.val_size > 1e-6:
             if data_args.streaming:
-                dataset_dict["validation"] = dataset.take(int(data_args.val_size))
-                dataset_dict["train"] = dataset.skip(int(data_args.val_size))
+                eval_dict["validation"] = dataset.take(int(data_args.val_size))
+                train_dict["train"] = dataset.skip(int(data_args.val_size))
             else:
                 val_size = int(data_args.val_size) if data_args.val_size > 1 else data_args.val_size
-                dataset = dataset.train_test_split(test_size=val_size, seed=seed)
-                dataset_dict = {"train": dataset["train"], "validation": dataset["test"]}
+                split_result = dataset.train_test_split(test_size=val_size, seed=seed)
+                train_dict["train"] = split_result["train"]
+                eval_dict["validation"] = split_result["test"]
         else:
-            dataset_dict["train"] = dataset
+            train_dict["train"] = dataset

     if eval_dataset is not None:
         if isinstance(eval_dataset, dict):
-            dataset_dict.update({f"validation_{name}": data for name, data in eval_dataset.items()})
+            for name, data in eval_dataset.items():
+                eval_dict[f"validation_{name}"] = data
         else:
             if data_args.streaming:
                 eval_dataset = eval_dataset.shuffle(buffer_size=data_args.buffer_size, seed=seed)

-            dataset_dict["validation"] = eval_dataset
+            eval_dict["validation"] = eval_dataset

-    return DatasetDict(dataset_dict)
+    return train_dict, eval_dict


 def get_dataset_module(dataset: Union["Dataset", "DatasetDict"]) -> "DatasetModule":
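Under the new contract, callers receive two plain dicts instead of a single `DatasetDict`; a small sketch of what that looks like (the `dataset` and `data_args` objects are assumed to exist in the caller's context):

    # Illustrative: with val_size=0.1 and a map-style dataset, the caller gets
    # train_dict == {"train": <Dataset>} and eval_dict == {"validation": <Dataset>};
    # with named eval datasets, eval_dict uses keys like "validation_{name}".
    train_dict, eval_dict = split_dataset(dataset, None, data_args, seed=42)
    dataset_dict = DatasetDict({**train_dict, **eval_dict})  # how get_dataset() recombines them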
@@ -16,7 +16,6 @@ import json
 import re
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, field
-from typing import Optional, Union

 from typing_extensions import override

@@ -27,14 +26,14 @@ from .tool_utils import FunctionCall, get_tool_utils
 @dataclass
 class Formatter(ABC):
     slots: SLOTS = field(default_factory=list)
-    tool_format: Optional[str] = None
+    tool_format: str | None = None

     @abstractmethod
     def apply(self, **kwargs) -> SLOTS:
         r"""Forms a list of slots according to the inputs to encode."""
         ...

-    def extract(self, content: str) -> Union[str, list["FunctionCall"]]:
+    def extract(self, content: str) -> str | list["FunctionCall"]:
         r"""Extract a list of tuples from the response message if using tools.

         Each tuple consists of function name and function arguments.
@@ -97,31 +96,46 @@ class FunctionFormatter(StringFormatter):
     @override
     def apply(self, **kwargs) -> SLOTS:
         content: str = kwargs.pop("content")
-        thought_words, thought = kwargs.pop("thought_words", None), None
-        if thought_words and len(thought_words) == 2:
-            regex = re.compile(rf"{re.escape(thought_words[0])}(.*?){re.escape(thought_words[1])}", re.DOTALL)
-            thought = re.search(regex, content)
-
-            if thought:
-                content = content.replace(thought.group(0), "")
-
-        functions: list[FunctionCall] = []
-        try:
-            tool_calls = json.loads(content)
-            if not isinstance(tool_calls, list):  # parallel function call
-                tool_calls = [tool_calls]
-
-            for tool_call in tool_calls:
-                functions.append(
-                    FunctionCall(tool_call["name"], json.dumps(tool_call["arguments"], ensure_ascii=False))
-                )
-
-        except json.JSONDecodeError:
-            raise RuntimeError(f"Invalid JSON format in function message: {str([content])}.")  # flat string
-
-        function_str = self.tool_utils.function_formatter(functions)
-        if thought:
-            function_str = thought.group(0) + function_str
+        thought_words = kwargs.pop("thought_words", None)
+        tool_call_words = kwargs.pop("tool_call_words", None)
+
+        def _parse_functions(json_content: str) -> list["FunctionCall"]:
+            try:
+                tool_calls = json.loads(json_content)
+                if not isinstance(tool_calls, list):  # parallel function call
+                    tool_calls = [tool_calls]
+
+                return [FunctionCall(tc["name"], json.dumps(tc["arguments"], ensure_ascii=False)) for tc in tool_calls]
+            except json.JSONDecodeError:
+                raise RuntimeError(f"Invalid JSON format in function message: {str([content])}.")
+
+        tool_call_match = None
+        if tool_call_words and len(tool_call_words) == 2:
+            tool_call_regex = re.compile(
+                rf"{re.escape(tool_call_words[0])}(.*?){re.escape(tool_call_words[1])}", re.DOTALL
+            )
+            tool_call_match = re.search(tool_call_regex, content)
+
+        if tool_call_match is None:
+            thought_match = None
+            if thought_words and len(thought_words) == 2:
+                regex = re.compile(rf"{re.escape(thought_words[0])}(.*?){re.escape(thought_words[1])}", re.DOTALL)
+                thought_match = re.search(regex, content)
+
+            if thought_match:
+                json_part = content.replace(thought_match.group(0), "")
+            else:
+                json_part = content
+
+            functions = _parse_functions(json_part)
+            function_str = self.tool_utils.function_formatter(functions)
+            if thought_match:
+                function_str = thought_match.group(0) + function_str
+        else:
+            thought_content = content.replace(tool_call_match.group(0), "")
+            functions = _parse_functions(tool_call_match.group(1))
+            function_str = self.tool_utils.function_formatter(functions)
+            function_str = thought_content + function_str

         return super().apply(content=function_str)

@@ -141,5 +155,5 @@ class ToolFormatter(Formatter):
             raise RuntimeError(f"Invalid JSON format in tool description: {str([content])}.")  # flat string

     @override
-    def extract(self, content: str) -> Union[str, list["FunctionCall"]]:
+    def extract(self, content: str) -> str | list["FunctionCall"]:
         return self.tool_utils.tool_extractor(content)
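A sketch of the new behavior: when the template supplies `tool_call_words`, the JSON payload is read from inside those markers and any surrounding text is preserved. The marker strings below are assumptions for illustration, not taken from a specific template:

    # Illustrative input to FunctionFormatter.apply(); <tool_call> ... </tool_call> is an assumed marker pair.
    content = 'Let me check that.\n<tool_call>{"name": "get_weather", "arguments": {"city": "Paris"}}</tool_call>'
    # With tool_call_words=("<tool_call>", "</tool_call>"), _parse_functions() receives only the JSON
    # between the markers, and "Let me check that.\n" is prepended to the formatted function string.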
@@ -16,7 +16,7 @@ import os
 from typing import TYPE_CHECKING, Literal, Optional, Union

 import numpy as np
-from datasets import Dataset, load_dataset, load_from_disk
+from datasets import Dataset, DatasetDict, load_dataset, load_from_disk

 from ..extras import logging
 from ..extras.constants import FILEEXT2TYPE
@@ -137,7 +137,6 @@ def _load_single_dataset(
             cache_dir=model_args.cache_dir,
             token=model_args.hf_hub_token,
             num_proc=data_args.preprocessing_num_workers,
-            trust_remote_code=model_args.trust_remote_code,
             streaming=data_args.streaming and dataset_attr.load_from != "file",
         )
         if data_args.streaming and dataset_attr.load_from == "file":
@@ -163,13 +162,13 @@ def _load_single_dataset(


 def _get_merged_dataset(
-    dataset_names: Optional[list[str]],
+    dataset_names: list[str] | None,
     model_args: "ModelArguments",
     data_args: "DataArguments",
     training_args: "Seq2SeqTrainingArguments",
     stage: Literal["pt", "sft", "rm", "ppo", "kto"],
     return_dict: bool = False,
-) -> Optional[Union["Dataset", "IterableDataset", dict[str, "Dataset"]]]:
+) -> Union["Dataset", "IterableDataset", dict[str, "Dataset"]] | None:
     r"""Return the merged datasets in the standard format."""
     if dataset_names is None:
         return None
@@ -228,7 +227,7 @@ def _get_dataset_processor(


 def _get_preprocessed_dataset(
-    dataset: Optional[Union["Dataset", "IterableDataset"]],
+    dataset: Union["Dataset", "IterableDataset"] | None,
     data_args: "DataArguments",
     training_args: "Seq2SeqTrainingArguments",
     stage: Literal["pt", "sft", "rm", "ppo", "kto"],
@@ -236,7 +235,7 @@ def _get_preprocessed_dataset(
     tokenizer: "PreTrainedTokenizer",
     processor: Optional["ProcessorMixin"] = None,
     is_eval: bool = False,
-) -> Optional[Union["Dataset", "IterableDataset"]]:
+) -> Union["Dataset", "IterableDataset"] | None:
     r"""Preprocesses the dataset, including format checking and tokenization."""
     if dataset is None:
         return None
@@ -312,20 +311,22 @@ def get_dataset(
     )

     with training_args.main_process_first(desc="pre-process dataset", local=(not data_args.data_shared_file_system)):
-        dataset = _get_preprocessed_dataset(
-            dataset, data_args, training_args, stage, template, tokenizer, processor, is_eval=False
-        )
-        if isinstance(eval_dataset, dict):
-            for eval_name, eval_data in eval_dataset.items():
-                eval_dataset[eval_name] = _get_preprocessed_dataset(
-                    eval_data, data_args, training_args, stage, template, tokenizer, processor, is_eval=True
-                )
-        else:
-            eval_dataset = _get_preprocessed_dataset(
-                eval_dataset, data_args, training_args, stage, template, tokenizer, processor, is_eval=True
-            )
-
-        dataset_dict = split_dataset(dataset, eval_dataset, data_args, seed=training_args.seed)
+        # split first so that the eval dataset (whether passed in or split off) is preprocessed appropriately
+        train_dict, eval_dict = split_dataset(dataset, eval_dataset, data_args, seed=training_args.seed)
+
+        if "train" in train_dict:
+            train_dict["train"] = _get_preprocessed_dataset(
+                train_dict["train"], data_args, training_args, stage, template, tokenizer, processor, is_eval=False
+            )
+
+        for key in eval_dict:
+            eval_dict[key] = _get_preprocessed_dataset(
+                eval_dict[key], data_args, training_args, stage, template, tokenizer, processor, is_eval=True
+            )
+
+        # Combine train and eval dictionaries
+        dataset_dict = DatasetDict({**train_dict, **eval_dict})

         if data_args.tokenized_path is not None:  # save tokenized dataset to disk
             if training_args.should_save:
                 dataset_dict.save_to_disk(data_args.tokenized_path)
@@ -22,10 +22,11 @@ import re
 from copy import deepcopy
 from dataclasses import dataclass
 from io import BytesIO
-from typing import TYPE_CHECKING, BinaryIO, Literal, Optional, TypedDict, Union
+from typing import TYPE_CHECKING, BinaryIO, Literal, NotRequired, Optional, TypedDict, Union

 import numpy as np
 import torch
+import torchaudio
 from transformers.image_utils import get_image_size, is_valid_image, to_numpy_array
 from transformers.models.mllama.processing_mllama import (
     convert_sparse_cross_attention_mask_to_dense,
@@ -34,16 +35,7 @@ from transformers.models.mllama.processing_mllama import (
 from typing_extensions import override

 from ..extras.constants import AUDIO_PLACEHOLDER, IGNORE_INDEX, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
-from ..extras.packages import (
-    is_librosa_available,
-    is_pillow_available,
-    is_pyav_available,
-    is_transformers_version_greater_than,
-)
-
-
-if is_librosa_available():
-    import librosa
+from ..extras.packages import is_pillow_available, is_pyav_available, is_transformers_version_greater_than


 if is_pillow_available():
@@ -68,15 +60,28 @@ if TYPE_CHECKING:
     from transformers import PreTrainedTokenizer, ProcessorMixin
     from transformers.feature_extraction_sequence_utils import SequenceFeatureExtractor
     from transformers.image_processing_utils import BaseImageProcessor
+    from transformers.video_processing_utils import BaseVideoProcessor

     class EncodedImage(TypedDict):
-        path: Optional[str]
-        bytes: Optional[bytes]
+        path: str | None
+        bytes: bytes | None

     ImageInput = Union[str, bytes, EncodedImage, BinaryIO, ImageObject]
     VideoInput = Union[str, BinaryIO, list[list[ImageInput]]]
     AudioInput = Union[str, BinaryIO, NDArray]

+    class RegularizedImageOutput(TypedDict):
+        images: list[ImageObject]
+
+    class RegularizedVideoOutput(TypedDict):
+        videos: list[list[ImageObject]]
+        durations: list[float]
+        fps_per_video: NotRequired[list[float]]
+
+    class RegularizedAudioOutput(TypedDict):
+        audios: list[NDArray]
+        sampling_rates: list[float]
+
     class MMProcessor(ProcessorMixin):
         patch_size: int
         image_seq_length: int
@@ -139,9 +144,9 @@ def _check_video_is_nested_images(video: "VideoInput") -> bool:

 @dataclass
 class MMPluginMixin:
-    image_token: Optional[str]
-    video_token: Optional[str]
-    audio_token: Optional[str]
+    image_token: str | None
+    video_token: str | None
+    audio_token: str | None
     expand_mm_tokens: bool = True

     def _validate_input(
@@ -244,7 +249,7 @@ class MMPluginMixin:
         sample_frames = min(total_frames, video_maxlen, sample_frames)
         return np.linspace(0, total_frames - 1, sample_frames).astype(np.int32)

-    def _regularize_images(self, images: list["ImageInput"], **kwargs) -> dict[str, list["ImageObject"]]:
+    def _regularize_images(self, images: list["ImageInput"], **kwargs) -> "RegularizedImageOutput":
         r"""Regularize images to avoid error. Including reading and pre-processing."""
         results = []
         for image in images:
@@ -265,9 +270,10 @@ class MMPluginMixin:

         return {"images": results}

-    def _regularize_videos(self, videos: list["VideoInput"], **kwargs) -> dict[str, list[list["ImageObject"]]]:
+    def _regularize_videos(self, videos: list["VideoInput"], **kwargs) -> "RegularizedVideoOutput":
         r"""Regularizes videos to avoid error. Including reading, resizing and converting."""
         results = []
+        durations = []
         for video in videos:
             frames: list[ImageObject] = []
             if _check_video_is_nested_images(video):
@@ -275,6 +281,7 @@ class MMPluginMixin:
                 if not is_valid_image(frame) and not isinstance(frame, dict) and not os.path.exists(frame):
                     raise ValueError("Invalid image found in video frames.")
                 frames = video
+                durations.append(len(frames) / kwargs.get("video_fps", 2.0))
             else:
                 container = av.open(video, "r")
                 video_stream = next(stream for stream in container.streams if stream.type == "video")
@@ -284,19 +291,31 @@ class MMPluginMixin:
                     if frame_idx in sample_indices:
                         frames.append(frame.to_image())

+                if video_stream.duration is None:
+                    durations.append(len(frames) / kwargs.get("video_fps", 2.0))
+                else:
+                    durations.append(float(video_stream.duration * video_stream.time_base))
+
             frames = self._regularize_images(frames, **kwargs)["images"]
             results.append(frames)

-        return {"videos": results}
+        return {"videos": results, "durations": durations}

     def _regularize_audios(
         self, audios: list["AudioInput"], sampling_rate: float, **kwargs
-    ) -> dict[str, Union[list["NDArray"], list[float]]]:
+    ) -> "RegularizedAudioOutput":
         r"""Regularizes audios to avoid error. Including reading and resampling."""
         results, sampling_rates = [], []
         for audio in audios:
             if not isinstance(audio, np.ndarray):
-                audio, sampling_rate = librosa.load(audio, sr=sampling_rate)
+                audio, sr = torchaudio.load(audio)
+                if audio.shape[0] > 1:
+                    audio = audio.mean(dim=0, keepdim=True)
+
+                if sr != sampling_rate:
+                    audio = torchaudio.functional.resample(audio, sr, sampling_rate)
+
+                audio = audio.squeeze(0).numpy()
+
             results.append(audio)
             sampling_rates.append(sampling_rate)
@@ -309,7 +328,7 @@ class MMPluginMixin:
|
|||||||
videos: list["VideoInput"],
|
videos: list["VideoInput"],
|
||||||
audios: list["AudioInput"],
|
audios: list["AudioInput"],
|
||||||
processor: "MMProcessor",
|
processor: "MMProcessor",
|
||||||
imglens: Optional[list[int]] = None,
|
imglens: list[int] | None = None,
|
||||||
) -> dict[str, "torch.Tensor"]:
|
) -> dict[str, "torch.Tensor"]:
|
||||||
r"""Process visual inputs.
|
r"""Process visual inputs.
|
||||||
|
|
||||||
@@ -407,13 +426,13 @@ class BasePlugin(MMPluginMixin):
|
|||||||
def process_token_ids(
|
def process_token_ids(
|
||||||
self,
|
self,
|
||||||
input_ids: list[int],
|
input_ids: list[int],
|
||||||
labels: Optional[list[int]],
|
labels: list[int] | None,
|
||||||
images: list["ImageInput"],
|
images: list["ImageInput"],
|
||||||
videos: list["VideoInput"],
|
videos: list["VideoInput"],
|
||||||
audios: list["AudioInput"],
|
audios: list["AudioInput"],
|
||||||
tokenizer: "PreTrainedTokenizer",
|
tokenizer: "PreTrainedTokenizer",
|
||||||
processor: Optional["MMProcessor"],
|
processor: Optional["MMProcessor"],
|
||||||
) -> tuple[list[int], Optional[list[int]]]:
|
) -> tuple[list[int], list[int] | None]:
|
||||||
r"""Pre-process token ids after tokenization for VLMs."""
|
r"""Pre-process token ids after tokenization for VLMs."""
|
||||||
self._validate_input(processor, images, videos, audios)
|
self._validate_input(processor, images, videos, audios)
|
||||||
return input_ids, labels
|
return input_ids, labels
|
||||||
@@ -446,6 +465,57 @@ class BasePlugin(MMPluginMixin):
         return self._get_mm_inputs(images, videos, audios, processor)


+@dataclass
+class ErnieVLPlugin(BasePlugin):
+    @override
+    def process_messages(
+        self,
+        messages: list[dict[str, str]],
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
+        processor: Optional["MMProcessor"],
+    ) -> list[dict[str, str]]:
+        self._validate_input(processor, images, videos, audios)
+        self._validate_messages(messages, images, videos, audios)
+        messages = deepcopy(messages)
+
+        image_processor: BaseImageProcessor = getattr(processor, "image_processor")
+
+        merge_length: int = getattr(image_processor, "merge_size") ** 2
+        if self.expand_mm_tokens:
+            mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
+            image_grid_thw = mm_inputs.get("image_grid_thw", [])
+            video_grid_thw = mm_inputs.get("video_grid_thw", [])
+        else:
+            image_grid_thw = [None] * len(images)
+            video_grid_thw = [None] * len(videos)
+
+        image_idx, video_idx = 0, 0
+        for message in messages:
+            content = message["content"]
+            image_token = self.image_token or "<|IMAGE_PLACEHOLDER|>"
+            video_token = self.video_token or "<|VIDEO_PLACEHOLDER|>"
+            while IMAGE_PLACEHOLDER in content:
+                image_seqlen = image_grid_thw[image_idx].prod() // merge_length if self.expand_mm_tokens else 1
+                content = content.replace(
+                    IMAGE_PLACEHOLDER,
+                    f"Picture {image_idx + 1}:<|IMAGE_START|>{image_token * image_seqlen}<|IMAGE_END|>",
+                    1,
+                )
+                image_idx += 1
+            while VIDEO_PLACEHOLDER in content:
+                video_seqlen = video_grid_thw[video_idx].prod() // merge_length if self.expand_mm_tokens else 1
+                content = content.replace(
+                    VIDEO_PLACEHOLDER,
+                    f"Video {video_idx + 1}:<|VIDEO_START|>{video_token * video_seqlen}<|VIDEO_END|>",
+                    1,
+                )
+                video_idx += 1
+            message["content"] = content
+        return messages
+
+
 @dataclass
 class Gemma3Plugin(BasePlugin):
     @override
@@ -1235,13 +1305,13 @@ class PaliGemmaPlugin(BasePlugin):
     def process_token_ids(
         self,
         input_ids: list[int],
-        labels: Optional[list[int]],
+        labels: list[int] | None,
         images: list["ImageInput"],
         videos: list["VideoInput"],
         audios: list["AudioInput"],
         tokenizer: "PreTrainedTokenizer",
         processor: Optional["MMProcessor"],
-    ) -> tuple[list[int], Optional[list[int]]]:
+    ) -> tuple[list[int], list[int] | None]:
         self._validate_input(processor, images, videos, audios)
         num_images = len(images)
         image_seqlen = processor.image_seq_length if self.expand_mm_tokens else 0  # skip mm token
@@ -1418,10 +1488,8 @@ class Qwen2VLPlugin(BasePlugin):
         return image

     @override
-    def _regularize_videos(
-        self, videos: list["VideoInput"], **kwargs
-    ) -> dict[str, Union[list[list["ImageObject"]], list[float]]]:
-        results, fps_per_video = [], []
+    def _regularize_videos(self, videos: list["VideoInput"], **kwargs) -> "RegularizedVideoOutput":
+        results, fps_per_video, durations = [], [], []
         for video in videos:
             frames: list[ImageObject] = []
             if _check_video_is_nested_images(video):
@@ -1431,6 +1499,7 @@ class Qwen2VLPlugin(BasePlugin):

                 frames = video
                 fps_per_video.append(kwargs.get("video_fps", 2.0))
+                durations.append(len(frames) / kwargs.get("video_fps", 2.0))
             else:
                 container = av.open(video, "r")
                 video_stream = next(stream for stream in container.streams if stream.type == "video")
@@ -1442,8 +1511,10 @@ class Qwen2VLPlugin(BasePlugin):

                 if video_stream.duration is None:
                     fps_per_video.append(kwargs.get("video_fps", 2.0))
+                    durations.append(len(frames) / kwargs.get("video_fps", 2.0))
                 else:
                     fps_per_video.append(len(sample_indices) / float(video_stream.duration * video_stream.time_base))
+                    durations.append(float(video_stream.duration * video_stream.time_base))

             if len(frames) % 2 != 0:
                 frames.append(frames[-1])
@@ -1451,7 +1522,7 @@ class Qwen2VLPlugin(BasePlugin):
             frames = self._regularize_images(frames, **kwargs)["images"]
             results.append(frames)

-        return {"videos": results, "fps_per_video": fps_per_video}
+        return {"videos": results, "fps_per_video": fps_per_video, "durations": durations}

     @override
     def _get_mm_inputs(
@@ -1462,6 +1533,7 @@ class Qwen2VLPlugin(BasePlugin):
         processor: "MMProcessor",
     ) -> dict[str, "torch.Tensor"]:
         image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
+        video_processor: BaseVideoProcessor = getattr(processor, "video_processor", None)
         mm_inputs = {}
         if len(images) != 0:
             images = self._regularize_images(
@@ -1479,7 +1551,7 @@ class Qwen2VLPlugin(BasePlugin):
                 video_fps=getattr(processor, "video_fps", 2.0),
                 video_maxlen=getattr(processor, "video_maxlen", 128),
             )
-            mm_inputs.update(image_processor(images=None, videos=video_data["videos"], return_tensors="pt"))
+            mm_inputs.update(video_processor(videos=video_data["videos"], return_tensors="pt"))
             temporal_patch_size: int = getattr(image_processor, "temporal_patch_size", 2)
             if "second_per_grid_ts" in processor.model_input_names:
                 mm_inputs["second_per_grid_ts"] = [temporal_patch_size / fps for fps in video_data["fps_per_video"]]
@@ -1565,11 +1637,16 @@ class Qwen3VLPlugin(Qwen2VLPlugin):
                 video_maxlen=getattr(processor, "video_maxlen", 128),
             )
             video_metadata = [
-                {"fps": getattr(processor, "video_fps", 24.0), "duration": len(video), "total_num_frames": len(video)}
-                for video in videos["videos"]
+                {"fps": getattr(processor, "video_fps", 24.0), "duration": duration, "total_num_frames": len(video)}
+                for video, duration in zip(videos["videos"], videos["durations"])
             ]
             mm_inputs.update(
-                video_processor(videos=videos["videos"], video_metadata=video_metadata, return_metadata=True)
+                video_processor(
+                    videos=videos["videos"],
+                    video_metadata=video_metadata,
+                    fps=getattr(processor, "video_fps", 2.0),
+                    return_metadata=True,
+                )
             )
             temporal_patch_size: int = getattr(image_processor, "temporal_patch_size", 2)
             if "second_per_grid_ts" in processor.model_input_names:
@@ -1622,27 +1699,27 @@ class Qwen3VLPlugin(Qwen2VLPlugin):
                     num_image_tokens += 1

                 while VIDEO_PLACEHOLDER in content:
-                    metadata = video_metadata[idx]
-                    timestamps = processor._calculate_timestamps(
-                        metadata.frames_indices,
-                        metadata.fps,
-                        video_processor.merge_size,
-                    )
-                    video_structure = ""
-                    for frame_index in range(num_frames):
-                        video_seqlen = (
-                            video_grid_thw[num_video_tokens][1:].prod() // video_merge_length
-                            if self.expand_mm_tokens
-                            else 1
-                        )
-                        timestamp_sec = timestamps[frame_index]
-                        frame_structure = (
-                            f"<{timestamp_sec:.1f} seconds>"
-                            f"{self.vision_bos_token}{self.video_token * video_seqlen}{self.vision_eos_token}"
-                        )
-                        video_structure += frame_structure
-
-                    if not self.expand_mm_tokens:
-                        video_structure = f"{self.vision_bos_token}{self.video_token}{self.vision_eos_token}"
+                    if self.expand_mm_tokens:
+                        metadata = video_metadata[idx]
+                        timestamps = processor._calculate_timestamps(
+                            metadata.frames_indices,
+                            metadata.fps,
+                            video_processor.merge_size,
+                        )
+                        video_structure = ""
+                        for frame_index in range(num_frames):
+                            video_seqlen = (
+                                video_grid_thw[num_video_tokens][1:].prod() // video_merge_length
+                                if self.expand_mm_tokens
+                                else 1
+                            )
+                            timestamp_sec = timestamps[frame_index]
+                            frame_structure = (
+                                f"<{timestamp_sec:.1f} seconds>"
+                                f"{self.vision_bos_token}{self.video_token * video_seqlen}{self.vision_eos_token}"
+                            )
+                            video_structure += frame_structure
+                    else:
+                        video_structure = f"{self.vision_bos_token}{self.video_token}{self.vision_eos_token}"

                     content = content.replace(VIDEO_PLACEHOLDER, video_structure, 1)
@@ -1684,7 +1761,8 @@ class GLM4VPlugin(Qwen2VLPlugin):
             )
             # prepare video metadata
             video_metadata = [
-                {"fps": 2, "duration": len(video), "total_frames": len(video)} for video in video_data["videos"]
+                {"fps": 2, "duration": duration, "total_frames": len(video)}
+                for video, duration in zip(video_data["videos"], video_data["durations"])
             ]
             mm_inputs.update(video_processor(images=None, videos=video_data["videos"], video_metadata=video_metadata))

@@ -1797,6 +1875,7 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
         processor: "MMProcessor",
     ) -> dict[str, "torch.Tensor"]:
         image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
+        video_processor: BaseVideoProcessor = getattr(processor, "video_processor", None)
         feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None)
         mm_inputs = {}
         if len(images) != 0:
@@ -1815,7 +1894,7 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
                 video_fps=getattr(processor, "video_fps", 2.0),
                 video_maxlen=getattr(processor, "video_maxlen", 128),
             )
-            mm_inputs.update(image_processor(images=None, videos=video_dict["videos"], return_tensors="pt"))
+            mm_inputs.update(video_processor(videos=video_dict["videos"], return_tensors="pt"))
             temporal_patch_size: int = getattr(image_processor, "temporal_patch_size", 2)
             mm_inputs["video_second_per_grid"] = torch.tensor(
                 [temporal_patch_size / fps for fps in video_dict["fps_per_video"]]
@@ -1861,8 +1940,14 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
             image_grid_thw = mm_inputs.get("image_grid_thw", [])
             video_grid_thw = mm_inputs.get("video_grid_thw", [])
             if "feature_attention_mask" in mm_inputs:
-                input_lengths = (mm_inputs["feature_attention_mask"].sum(-1).numpy() - 1) // 2 + 1
-                audio_lengths = (input_lengths - 2) // 2 + 1
+                if processor.__class__.__name__ == "Qwen3OmniMoeProcessor":  # for qwen3omni
+                    input_lengths = mm_inputs["feature_attention_mask"].sum(-1)
+                    input_lengths_leave = input_lengths % 100
+                    feature_lengths = (input_lengths_leave - 1) // 2 + 1
+                    audio_lengths = ((feature_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (input_lengths // 100) * 13
+                else:
+                    input_lengths = (mm_inputs["feature_attention_mask"].sum(-1).numpy() - 1) // 2 + 1
+                    audio_lengths = (input_lengths - 2) // 2 + 1
         else:
             mm_inputs = {}
             image_grid_thw = [None] * len(images)
@@ -2009,6 +2094,7 @@ class VideoLlavaPlugin(BasePlugin):

 PLUGINS = {
     "base": BasePlugin,
+    "ernie_vl": ErnieVLPlugin,
     "gemma3": Gemma3Plugin,
     "glm4v": GLM4VPlugin,
     "gemma3n": Gemma3nPlugin,
@@ -2040,9 +2126,9 @@ def register_mm_plugin(name: str, plugin_class: type["BasePlugin"]) -> None:

 def get_mm_plugin(
     name: str,
-    image_token: Optional[str] = None,
-    video_token: Optional[str] = None,
-    audio_token: Optional[str] = None,
+    image_token: str | None = None,
+    video_token: str | None = None,
+    audio_token: str | None = None,
     **kwargs,
 ) -> "BasePlugin":
     r"""Get plugin for multimodal inputs."""
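Note: the audio path above switches from librosa to torchaudio for loading and resampling. A minimal standalone sketch of the same load → downmix → resample flow (the helper name and file path below are illustrative, not part of the repository):

import numpy as np
import torchaudio


def load_mono_audio(path: str, target_sr: int = 16000) -> np.ndarray:
    """Sketch of the torchaudio-based loading shown in the diff above."""
    waveform, sr = torchaudio.load(path)       # shape: (channels, samples)
    if waveform.shape[0] > 1:                  # downmix stereo/multi-channel to mono
        waveform = waveform.mean(dim=0, keepdim=True)
    if sr != target_sr:                        # match the processor's sampling rate
        waveform = torchaudio.functional.resample(waveform, sr, target_sr)
    return waveform.squeeze(0).numpy()         # 1-D float array, as the plugin expects


# example (hypothetical file): audio = load_mono_audio("sample.wav", target_sr=16000)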
@@ -15,7 +15,7 @@
 import json
 import os
 from dataclasses import dataclass
-from typing import Any, Literal, Optional
+from typing import Any, Literal

 from huggingface_hub import hf_hub_download

@@ -30,43 +30,43 @@ class DatasetAttr:
     # basic configs
     load_from: Literal["hf_hub", "ms_hub", "om_hub", "script", "file"]
     dataset_name: str
-    formatting: Literal["alpaca", "sharegpt"] = "alpaca"
+    formatting: Literal["alpaca", "sharegpt", "openai"] = "alpaca"
     ranking: bool = False
     # extra configs
-    subset: Optional[str] = None
+    subset: str | None = None
     split: str = "train"
-    folder: Optional[str] = None
-    num_samples: Optional[int] = None
+    folder: str | None = None
+    num_samples: int | None = None
     # common columns
-    system: Optional[str] = None
-    tools: Optional[str] = None
-    images: Optional[str] = None
-    videos: Optional[str] = None
-    audios: Optional[str] = None
+    system: str | None = None
+    tools: str | None = None
+    images: str | None = None
+    videos: str | None = None
+    audios: str | None = None
     # dpo columns
-    chosen: Optional[str] = None
-    rejected: Optional[str] = None
-    kto_tag: Optional[str] = None
+    chosen: str | None = None
+    rejected: str | None = None
+    kto_tag: str | None = None
     # alpaca columns
-    prompt: Optional[str] = "instruction"
-    query: Optional[str] = "input"
-    response: Optional[str] = "output"
-    history: Optional[str] = None
+    prompt: str | None = "instruction"
+    query: str | None = "input"
+    response: str | None = "output"
+    history: str | None = None
     # sharegpt columns
-    messages: Optional[str] = "conversations"
+    messages: str | None = "conversations"
     # sharegpt tags
-    role_tag: Optional[str] = "from"
-    content_tag: Optional[str] = "value"
-    user_tag: Optional[str] = "human"
-    assistant_tag: Optional[str] = "gpt"
-    observation_tag: Optional[str] = "observation"
-    function_tag: Optional[str] = "function_call"
-    system_tag: Optional[str] = "system"
+    role_tag: str | None = "from"
+    content_tag: str | None = "value"
+    user_tag: str | None = "human"
+    assistant_tag: str | None = "gpt"
+    observation_tag: str | None = "observation"
+    function_tag: str | None = "function_call"
+    system_tag: str | None = "system"

     def __repr__(self) -> str:
         return self.dataset_name

-    def set_attr(self, key: str, obj: dict[str, Any], default: Optional[Any] = None) -> None:
+    def set_attr(self, key: str, obj: dict[str, Any], default: Any | None = None) -> None:
         setattr(self, key, obj.get(key, default))

     def join(self, attr: dict[str, Any]) -> None:
@@ -90,7 +90,7 @@ class DatasetAttr:
             self.set_attr(tag, attr["tags"])


-def get_dataset_list(dataset_names: Optional[list[str]], dataset_dir: Union[str, dict]) -> list["DatasetAttr"]:
+def get_dataset_list(dataset_names: list[str] | None, dataset_dir: str | dict) -> list["DatasetAttr"]:
     r"""Get the attributes of the datasets."""
     if dataset_names is None:
         dataset_names = []
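For context, the loosely-typed columns above are populated through the `set_attr` pattern shown in the diff. A small self-contained sketch (the `Attr` class is a stand-in, not the repository's class):

from dataclasses import dataclass
from typing import Any


@dataclass
class Attr:
    subset: str | None = None
    split: str = "train"

    def set_attr(self, key: str, obj: dict[str, Any], default: Any | None = None) -> None:
        # copy a value from a config dict onto the dataclass, falling back to a default
        setattr(self, key, obj.get(key, default))


attr = Attr()
attr.set_attr("subset", {"subset": "en"})  # attr.subset == "en"
attr.set_attr("split", {})                 # missing key: default (None) overrides "train"
print(attr)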
@@ -49,6 +49,7 @@ class Template:
     default_system: str
     stop_words: list[str]
     thought_words: tuple[str, str]
+    tool_call_words: tuple[str, str]
     efficient_eos: bool
     replace_eos: bool
     replace_jinja_template: bool
@@ -156,7 +157,9 @@ class Template:
             elif message["role"] == Role.OBSERVATION:
                 elements += self.format_observation.apply(content=message["content"])
             elif message["role"] == Role.FUNCTION:
-                elements += self.format_function.apply(content=message["content"], thought_words=self.thought_words)
+                elements += self.format_function.apply(
+                    content=message["content"], thought_words=self.thought_words, tool_call_words=self.tool_call_words
+                )
             else:
                 raise NotImplementedError("Unexpected role: {}".format(message["role"]))

@@ -199,9 +202,12 @@ class Template:
             logger.info_rank0(f"Add pad token: {tokenizer.pad_token}")

         if stop_words:
-            num_added_tokens = tokenizer.add_special_tokens(
-                dict(additional_special_tokens=stop_words), replace_additional_special_tokens=False
-            )
+            try:
+                num_added_tokens = tokenizer.add_special_tokens(
+                    dict(additional_special_tokens=stop_words), replace_additional_special_tokens=False
+                )
+            except TypeError:
+                num_added_tokens = tokenizer.add_special_tokens(dict(additional_special_tokens=stop_words))
             logger.info_rank0("Add {} to stop words.".format(",".join(stop_words)))
             if num_added_tokens > 0:
                 logger.warning_rank0("New tokens have been added, make sure `resize_vocab` is True.")
@@ -468,6 +474,7 @@ def register_template(
     default_system: str = "",
     stop_words: Optional[list[str]] = None,
    thought_words: Optional[tuple[str, str]] = None,
+    tool_call_words: Optional[tuple[str, str]] = None,
     efficient_eos: bool = False,
     replace_eos: bool = False,
     replace_jinja_template: bool = False,
@@ -519,6 +526,7 @@ def register_template(
         default_system=default_system,
         stop_words=stop_words or [],
         thought_words=thought_words or ("<think>\n", "\n</think>\n\n"),
+        tool_call_words=tool_call_words or ("<tool_call>", "</tool_call>"),
         efficient_eos=efficient_eos,
         replace_eos=replace_eos,
         replace_jinja_template=replace_jinja_template,
@@ -580,6 +588,7 @@ def parse_template(tokenizer: "PreTrainedTokenizer") -> "Template":
         default_system=default_system,
         stop_words=[],
         thought_words=("<think>\n", "\n</think>\n\n"),
+        tool_call_words=("<tool_call>", "</tool_call>"),
         efficient_eos=False,
         replace_eos=False,
         replace_jinja_template=False,
@@ -616,7 +625,14 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args:
         logger.info_rank0(f"Using default system message: {data_args.default_system}.")
         template.default_system = data_args.default_system

-    template.enable_thinking = data_args.enable_thinking
+    if isinstance(template, ReasoningTemplate):
+        logger.warning_rank0(
+            "You are using reasoning template, "
+            "please add `_nothink` suffix if the model is not a reasoning model. "
+            "e.g., qwen3_vl_nothink"
+        )
+        template.enable_thinking = data_args.enable_thinking
+
     template.fix_special_tokens(tokenizer)
     template.fix_jinja_template(tokenizer)
     return template
@@ -956,6 +972,19 @@ register_template(
 )


+register_template(
+    name="ernie_vl",
+    format_user=StringFormatter(slots=["User: {{content}}"]),
+    format_assistant=StringFormatter(slots=["\nAssistant: {{content}}<|end_of_sentence|>"]),
+    format_system=StringFormatter(slots=["{{content}}\n"]),
+    stop_words=["<|end_of_sentence|>"],
+    replace_eos=True,
+    replace_jinja_template=True,
+    template_class=ReasoningTemplate,
+    mm_plugin=get_mm_plugin(name="ernie_vl", image_token="<|IMAGE_PLACEHOLDER|>", video_token="<|VIDEO_PLACEHOLDER|>"),
+)
+
+
 register_template(
     name="exaone",
     format_user=StringFormatter(slots=["[|user|]{{content}}\n[|assistant|]"]),
@@ -1105,7 +1134,7 @@ register_template(

 # copied from glm4 template
 register_template(
-    name="glm4v_moe",
+    name="glm4_5v",
     format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]),
     format_assistant=StringFormatter(slots=["\n{{content}}"]),
     format_system=StringFormatter(slots=["<|system|>\n{{content}}"]),
@@ -1137,7 +1166,7 @@ register_template(


 register_template(
-    name="gpt",
+    name="gpt_oss",
     format_user=StringFormatter(slots=["<|start|>user<|message|>{{content}}<|end|><|start|>assistant"]),
     format_assistant=StringFormatter(slots=["{{content}}<|end|>"]),
     format_system=StringFormatter(slots=["<|start|>system<|message|>{{content}}<|end|>"]),
@@ -1201,10 +1230,10 @@ register_template(

 register_template(
     name="hunyuan",
-    format_user=StringFormatter(slots=["<|bos|>user\n{{content}}<|eos|>\n<|bos|>assistant\n"]),
-    format_assistant=StringFormatter(slots=["{{content}}<|eos|>\n"]),
-    format_system=StringFormatter(slots=["<|bos|>system\n{{content}}<|eos|>\n"]),
-    format_prefix=EmptyFormatter(slots=["<|bos|>"]),
+    format_user=StringFormatter(slots=["{{content}}<|extra_0|>"]),
+    format_assistant=StringFormatter(slots=["{{content}}<|eos|>"]),
+    format_system=StringFormatter(slots=["{{content}}<|extra_4|>"]),
+    format_prefix=EmptyFormatter(slots=["<|startoftext|>"]),
     stop_words=["<|eos|>"],
 )

@@ -1581,6 +1610,26 @@ register_template(
     template_class=ReasoningTemplate,
 )


+# copied from qwen template
+register_template(
+    name="mimo_v2",
+    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
+    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
+    format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen"),
+    format_observation=StringFormatter(
+        slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
+    ),
+    format_tools=ToolFormatter(tool_format="qwen"),
+    default_system="You are MiMo, a helpful AI assistant engineered by Xiaomi.",
+    stop_words=["<|im_end|>"],
+    replace_eos=True,
+    thought_words=("<think>", "</think>"),
+    template_class=ReasoningTemplate,
+)
+
+
 # copied from qwen2vl
 register_template(
     name="mimo_vl",
@@ -1664,6 +1713,19 @@ register_template(
 )


+register_template(
+    name="ministral3",
+    format_user=StringFormatter(slots=["[INST]{{content}}[/INST]"]),
+    format_system=StringFormatter(slots=["{{content}}\n\n"]),
+    format_function=FunctionFormatter(slots=["[TOOL_CALLS]{{content}}", {"eos_token"}], tool_format="mistral"),
+    format_observation=StringFormatter(slots=["""[TOOL_RESULTS]{"content": {{content}}}[/TOOL_RESULTS]"""]),
+    format_tools=ToolFormatter(tool_format="mistral"),
+    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+    template_class=Llama2Template,
+    mm_plugin=get_mm_plugin(name="pixtral", image_token="[IMG]"),
+)
+
+
 register_template(
     name="olmo",
     format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>\n"]),
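The stop-word handling above now falls back gracefully when an older `transformers` release does not accept `replace_additional_special_tokens`. A standalone sketch of the same fallback (the model name and stop word are chosen only for illustration):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # any tokenizer works here
stop_words = ["<|eot|>"]
try:
    # newer transformers: keep the previously registered additional special tokens
    num_added = tokenizer.add_special_tokens(
        dict(additional_special_tokens=stop_words), replace_additional_special_tokens=False
    )
except TypeError:
    # older transformers without the keyword argument
    num_added = tokenizer.add_special_tokens(dict(additional_special_tokens=stop_words))

print(num_added)  # 1 if "<|eot|>" was newly added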
@@ -15,7 +15,6 @@
 import os
 from collections import OrderedDict, defaultdict
 from enum import Enum, unique
-from typing import Optional

 from peft.utils import SAFETENSORS_WEIGHTS_NAME as SAFE_ADAPTER_WEIGHTS_NAME
 from peft.utils import WEIGHTS_NAME as ADAPTER_WEIGHTS_NAME
@@ -56,6 +55,19 @@ LAYERNORM_NAMES = {"norm", "ln"}

 LLAMABOARD_CONFIG = "llamaboard_config.yaml"

+MCA_SUPPORTED_MODELS = {
+    "deepseek_v3",
+    "llama",
+    "mistral",
+    "mixtral",
+    "qwen2",
+    "qwen2_vl",
+    "qwen2_5_vl",
+    "qwen3",
+    "qwen3_moe",
+    "qwen3_next",
+}
+
 METHODS = ["full", "freeze", "lora", "oft"]

 MOD_SUPPORTED_MODELS = {"bloom", "falcon", "gemma", "llama", "mistral", "mixtral", "phi", "starcoder2"}
@@ -101,12 +113,14 @@ class AttentionFunction(str, Enum):
     DISABLED = "disabled"
     SDPA = "sdpa"
     FA2 = "fa2"
+    FA3 = "fa3"


 class EngineName(str, Enum):
     HF = "huggingface"
     VLLM = "vllm"
     SGLANG = "sglang"
+    KT = "ktransformers"


 class DownloadSource(str, Enum):
@@ -127,6 +141,7 @@ class QuantizationMethod(str, Enum):
     EETQ = "eetq"
     HQQ = "hqq"
     MXFP4 = "mxfp4"
+    FP8 = "fp8"


 class RopeScaling(str, Enum):
@@ -138,7 +153,7 @@ class RopeScaling(str, Enum):

 def register_model_group(
     models: dict[str, dict[DownloadSource, str]],
-    template: Optional[str] = None,
+    template: str | None = None,
     multimodal: bool = False,
 ) -> None:
     for name, path in models.items():
@@ -643,6 +658,26 @@ register_model_group(
 )


+register_model_group(
+    models={
+        "ERNIE-4.5-VL-28B-A3B-PT": {
+            DownloadSource.DEFAULT: "baidu/ERNIE-4.5-VL-28B-A3B-PT",
+            DownloadSource.MODELSCOPE: "PaddlePaddle/ERNIE-4.5-VL-28B-A3B-PT",
+        },
+        "ERNIE-4.5-VL-28B-A3B-Thinking": {
+            DownloadSource.DEFAULT: "baidu/ERNIE-4.5-VL-28B-A3B-Thinking",
+            DownloadSource.MODELSCOPE: "PaddlePaddle/ERNIE-4.5-VL-28B-A3B-Thinking",
+        },
+        "ERNIE-4.5-VL-424B-A47B-Base-PT": {
+            DownloadSource.DEFAULT: "baidu/ERNIE-4.5-VL-424B-A47B-PT",
+            DownloadSource.MODELSCOPE: "PaddlePaddle/ERNIE-4.5-VL-424B-A47B-PT",
+        },
+    },
+    template="ernie_vl",
+    multimodal=True,
+)
+
+
 register_model_group(
     models={
         "EXAONE-3.0-7.8B-Instruct": {
@@ -969,9 +1004,17 @@ register_model_group(
         "GLM-4.5V-Air-Thinking": {
             DownloadSource.DEFAULT: "zai-org/GLM-4.5V",
             DownloadSource.MODELSCOPE: "ZhipuAI/GLM-4.5V",
-        }
+        },
+        "GLM-4.6V": {
+            DownloadSource.DEFAULT: "zai-org/GLM-4.6V",
+            DownloadSource.MODELSCOPE: "ZhipuAI/GLM-4.6V",
+        },
+        "GLM-4.6V-Flash": {
+            DownloadSource.DEFAULT: "zai-org/GLM-4.6V-Flash",
+            DownloadSource.MODELSCOPE: "ZhipuAI/GLM-4.6V-Flash",
+        },
     },
-    template="glm4v_moe",
+    template="glm4_5v",
     multimodal=True,
 )

@@ -1024,7 +1067,7 @@ register_model_group(
             DownloadSource.MODELSCOPE: "openai/gpt-oss-120b",
         },
     },
-    template="gpt",
+    template="gpt_oss",
 )


@@ -1152,6 +1195,10 @@ register_model_group(
             DownloadSource.DEFAULT: "tencent/Hunyuan-7B-Instruct",
             DownloadSource.MODELSCOPE: "AI-ModelScope/Hunyuan-7B-Instruct",
         },
+        "Hunyuan-MT-7B-Instruct": {
+            DownloadSource.DEFAULT: "tencent/Hunyuan-MT-7B",
+            DownloadSource.MODELSCOPE: "Tencent-Hunyuan/Hunyuan-MT-7B",
+        },
     },
     template="hunyuan",
 )
@@ -1756,6 +1803,21 @@ register_model_group(
 )


+register_model_group(
+    models={
+        "MiMo-V2-Flash-Base": {
+            DownloadSource.DEFAULT: "XiaomiMiMo/MiMo-V2-Flash-Base",
+            DownloadSource.MODELSCOPE: "XiaomiMiMo/MiMo-V2-Flash-Base",
+        },
+        "MiMo-V2-Flash": {
+            DownloadSource.DEFAULT: "XiaomiMiMo/MiMo-V2-Flash",
+            DownloadSource.MODELSCOPE: "XiaomiMiMo/MiMo-V2-Flash",
+        },
+    },
+    template="mimo_v2",
+)
+
+
 register_model_group(
     models={
         "MiMo-7B-VL-RL": {
@@ -1780,7 +1842,7 @@ register_model_group(
         },
         "MiMo-VL-7B-SFT-2508": {
             DownloadSource.DEFAULT: "XiaomiMiMo/MiMo-VL-7B-SFT-2508",
-            DownloadSource.DEFAULT: "XiaomiMiMo/MiMo-VL-7B-SFT-2508",
+            DownloadSource.MODELSCOPE: "XiaomiMiMo/MiMo-VL-7B-SFT-2508",
         },
     },
     template="qwen2_vl",
@@ -1931,6 +1993,37 @@ register_model_group(
     template="mistral",
 )

+register_model_group(
+    models={
+        "Ministral-3-3B-Base-2512": {
+            DownloadSource.DEFAULT: "mistralai/Ministral-3-3B-Base-2512",
+            DownloadSource.MODELSCOPE: "mistralai/Ministral-3-3B-Base-2512",
+        },
+        "Ministral-3-8B-Base-2512": {
+            DownloadSource.DEFAULT: "mistralai/Ministral-3-8B-Base-2512",
+            DownloadSource.MODELSCOPE: "mistralai/Ministral-3-8B-Base-2512",
+        },
+        "Ministral-3-14B-Base-2512": {
+            DownloadSource.DEFAULT: "mistralai/Ministral-3-14B-Base-2512",
+            DownloadSource.MODELSCOPE: "mistralai/Ministral-3-14B-Base-2512",
+        },
+        "Ministral-3-3B-Instruct-2512": {
+            DownloadSource.DEFAULT: "mistralai/Ministral-3-3B-Instruct-2512",
+            DownloadSource.MODELSCOPE: "mistralai/Ministral-3-3B-Instruct-2512",
+        },
+        "Ministral-3-8B-Instruct-2512": {
+            DownloadSource.DEFAULT: "mistralai/Ministral-3-8B-Instruct-2512",
+            DownloadSource.MODELSCOPE: "mistralai/Ministral-3-8B-Instruct-2512",
+        },
+        "Ministral-3-14B-Instruct-2512": {
+            DownloadSource.DEFAULT: "mistralai/Ministral-3-14B-Instruct-2512",
+            DownloadSource.MODELSCOPE: "mistralai/Ministral-3-14B-Instruct-2512",
+        },
+    },
+    template="ministral3",
+    multimodal=True,
+)
+
+
 register_model_group(
     models={
@@ -3193,6 +3286,10 @@ register_model_group(

 register_model_group(
     models={
+        "Qwen3-VL-2B-Instruct": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3-VL-2B-Instruct",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3-VL-2B-Instruct",
+        },
         "Qwen3-VL-4B-Instruct": {
             DownloadSource.DEFAULT: "Qwen/Qwen3-VL-4B-Instruct",
             DownloadSource.MODELSCOPE: "Qwen/Qwen3-VL-4B-Instruct",
@@ -3201,6 +3298,10 @@ register_model_group(
             DownloadSource.DEFAULT: "Qwen/Qwen3-VL-8B-Instruct",
             DownloadSource.MODELSCOPE: "Qwen/Qwen3-VL-8B-Instruct",
         },
+        "Qwen3-VL-32B-Instruct": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3-VL-32B-Instruct",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3-VL-32B-Instruct",
+        },
         "Qwen3-VL-30B-A3B-Instruct": {
             DownloadSource.DEFAULT: "Qwen/Qwen3-VL-30B-A3B-Instruct",
             DownloadSource.MODELSCOPE: "Qwen/Qwen3-VL-30B-A3B-Instruct",
@@ -3217,6 +3318,10 @@ register_model_group(

 register_model_group(
     models={
+        "Qwen3-VL-2B-Thinking": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3-VL-2B-Thinking",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3-VL-2B-Thinking",
+        },
         "Qwen3-VL-4B-Thinking": {
             DownloadSource.DEFAULT: "Qwen/Qwen3-VL-4B-Thinking",
             DownloadSource.MODELSCOPE: "Qwen/Qwen3-VL-4B-Thinking",
@@ -3225,6 +3330,10 @@ register_model_group(
             DownloadSource.DEFAULT: "Qwen/Qwen3-VL-8B-Thinking",
             DownloadSource.MODELSCOPE: "Qwen/Qwen3-VL-8B-Thinking",
         },
+        "Qwen3-VL-32B-Thinking": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3-VL-32B-Thinking",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3-VL-32B-Thinking",
+        },
         "Qwen3-VL-30B-A3B-Thinking": {
             DownloadSource.DEFAULT: "Qwen/Qwen3-VL-30B-A3B-Thinking",
             DownloadSource.MODELSCOPE: "Qwen/Qwen3-VL-30B-A3B-Thinking",
@@ -3438,6 +3547,17 @@ register_model_group(
 )


+register_model_group(
+    models={
+        "VibeThinker-1.5B": {
+            DownloadSource.DEFAULT: "WeiboAI/VibeThinker-1.5B",
+            DownloadSource.MODELSCOPE: "WeiboAI/VibeThinker-1.5B",
+        },
+    },
+    template="qwen3",
+)
+
+
 register_model_group(
     models={
         "Vicuna-v1.5-7B-Chat": {
@@ -117,7 +117,7 @@ def _configure_library_root_logger() -> None:
     library_root_logger.propagate = False


-def get_logger(name: Optional[str] = None) -> "_Logger":
+def get_logger(name: str | None = None) -> "_Logger":
     r"""Return a logger with the specified name. It it not supposed to be accessed externally."""
     if name is None:
         name = _get_library_name()
@@ -313,6 +313,10 @@ def use_ray() -> bool:
    return is_env_enabled("USE_RAY")


+def use_kt() -> bool:
+    return is_env_enabled("USE_KT")
+
+
 def find_available_port() -> int:
     r"""Find an available port on the local machine."""
     sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
@@ -328,3 +332,7 @@ def fix_proxy(ipv6_enabled: bool = False) -> None:
     if ipv6_enabled:
         os.environ.pop("http_proxy", None)
         os.environ.pop("HTTP_PROXY", None)
+        os.environ.pop("https_proxy", None)
+        os.environ.pop("HTTPS_PROXY", None)
+        os.environ.pop("all_proxy", None)
+        os.environ.pop("ALL_PROXY", None)
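The new `use_kt()` helper is just an environment-flag check. A minimal sketch of the same pattern (`is_env_enabled` is re-implemented here for illustration; its accepted truthy values are an assumption, not copied from the repository):

import os


def is_env_enabled(env_var: str, default: str = "0") -> bool:
    # assumed behaviour: "1", "true" and "y" (case-insensitive) count as enabled
    return os.getenv(env_var, default).lower() in ("true", "y", "1")


def use_kt() -> bool:
    return is_env_enabled("USE_KT")


os.environ["USE_KT"] = "1"  # opt in to the ktransformers backend
print(use_kt())  # True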
@@ -70,6 +70,10 @@ def is_matplotlib_available():
     return _is_package_available("matplotlib")


+def is_mcore_adapter_available():
+    return _is_package_available("mcore_adapter")
+
+
 def is_pillow_available():
     return _is_package_available("PIL")

@@ -78,6 +82,10 @@ def is_ray_available():
     return _is_package_available("ray")


+def is_kt_available():
+    return _is_package_available("ktransformers")
+
+
 def is_requests_available():
     return _is_package_available("requests")

@@ -86,6 +94,14 @@ def is_rouge_available():
     return _is_package_available("rouge_chinese")


+def is_safetensors_available():
+    return _is_package_available("safetensors")
+
+
+def is_sglang_available():
+    return _is_package_available("sglang")
+
+
 def is_starlette_available():
     return _is_package_available("sse_starlette")

@@ -95,13 +111,14 @@ def is_transformers_version_greater_than(content: str):
     return _get_package_version("transformers") >= version.parse(content)


+@lru_cache
+def is_torch_version_greater_than(content: str):
+    return _get_package_version("torch") >= version.parse(content)
+
+
 def is_uvicorn_available():
     return _is_package_available("uvicorn")


 def is_vllm_available():
     return _is_package_available("vllm")
-
-
-def is_sglang_available():
-    return _is_package_available("sglang")
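The new cached torch version check follows a common pattern. A self-contained sketch (using `importlib.metadata` in place of the repository's internal `_get_package_version` helper, which is an assumption for the sake of a runnable example):

from functools import lru_cache
from importlib.metadata import version as get_version

from packaging import version


@lru_cache
def is_torch_version_greater_than(content: str) -> bool:
    # cached so repeated checks do not re-parse version strings
    return version.parse(get_version("torch")) >= version.parse(content)


# example:
# if is_torch_version_greater_than("2.4.0"):
#     ...enable a feature that needs a newer torch...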
@@ -16,22 +16,22 @@
 # limitations under the License.

 from dataclasses import asdict, dataclass, field
-from typing import Any, Literal, Optional
+from typing import Any, Literal


 @dataclass
 class DataArguments:
     r"""Arguments pertaining to what data we are going to input our model for training and evaluation."""

-    template: Optional[str] = field(
+    template: str | None = field(
         default=None,
         metadata={"help": "Which template to use for constructing prompts in training and inference."},
     )
-    dataset: Optional[str] = field(
+    dataset: str | None = field(
         default=None,
         metadata={"help": "The name of dataset(s) to use for training. Use commas to separate multiple datasets."},
     )
-    eval_dataset: Optional[str] = field(
+    eval_dataset: str | None = field(
         default=None,
         metadata={"help": "The name of dataset(s) to use for evaluation. Use commas to separate multiple datasets."},
     )
@@ -39,7 +39,7 @@ class DataArguments:
         default="data",
         metadata={"help": "Path to the folder containing the datasets."},
     )
-    media_dir: Optional[str] = field(
+    media_dir: str | None = field(
         default=None,
         metadata={"help": "Path to the folder containing the images, videos or audios. Defaults to `dataset_dir`."},
     )
@@ -67,7 +67,7 @@ class DataArguments:
         default="concat",
         metadata={"help": "Strategy to use in dataset mixing (concat/interleave) (undersampling/oversampling)."},
     )
-    interleave_probs: Optional[str] = field(
+    interleave_probs: str | None = field(
         default=None,
         metadata={"help": "Probabilities to sample data from datasets. Use commas to separate multiple datasets."},
     )
@@ -79,15 +79,15 @@ class DataArguments:
         default=1000,
         metadata={"help": "The number of examples in one group in pre-processing."},
     )
-    preprocessing_num_workers: Optional[int] = field(
+    preprocessing_num_workers: int | None = field(
         default=None,
         metadata={"help": "The number of processes to use for the pre-processing."},
     )
-    max_samples: Optional[int] = field(
+    max_samples: int | None = field(
         default=None,
         metadata={"help": "For debugging purposes, truncate the number of examples for each dataset."},
     )
-    eval_num_beams: Optional[int] = field(
+    eval_num_beams: int | None = field(
         default=None,
         metadata={"help": "Number of beams to use for evaluation. This argument will be passed to `model.generate`"},
     )
@@ -103,7 +103,7 @@ class DataArguments:
         default=False,
         metadata={"help": "Whether or not to evaluate on each dataset separately."},
     )
-    packing: Optional[bool] = field(
+    packing: bool | None = field(
         default=None,
         metadata={"help": "Enable sequences packing in training. Will automatically enable in pre-training."},
     )
@@ -111,19 +111,19 @@ class DataArguments:
         default=False,
         metadata={"help": "Enable sequence packing without cross-attention."},
     )
-    tool_format: Optional[str] = field(
+    tool_format: str | None = field(
         default=None,
         metadata={"help": "Tool format to use for constructing function calling examples."},
     )
-    default_system: Optional[str] = field(
+    default_system: str | None = field(
         default=None,
         metadata={"help": "Override the default system message in the template."},
     )
-    enable_thinking: Optional[bool] = field(
+    enable_thinking: bool | None = field(
         default=True,
         metadata={"help": "Whether or not to enable thinking mode for reasoning models."},
     )
-    tokenized_path: Optional[str] = field(
+    tokenized_path: str | None = field(
         default=None,
         metadata={
             "help": (
@@ -14,7 +14,7 @@

 import os
 from dataclasses import dataclass, field
-from typing import Literal, Optional
+from typing import Literal

 from datasets import DownloadMode

@@ -46,7 +46,7 @@ class EvaluationArguments:
         default=5,
         metadata={"help": "Number of examplars for few-shot learning."},
     )
-    save_dir: Optional[str] = field(
+    save_dir: str | None = field(
         default=None,
         metadata={"help": "Path to save the evaluation results."},
     )
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 from dataclasses import asdict, dataclass, field
-from typing import Any, Literal, Optional
+from typing import Any, Literal
 
 
 @dataclass
@@ -40,7 +40,7 @@ class FreezeArguments:
             )
         },
     )
-    freeze_extra_modules: Optional[str] = field(
+    freeze_extra_modules: str | None = field(
         default=None,
         metadata={
             "help": (
@@ -56,7 +56,7 @@ class FreezeArguments:
 class LoraArguments:
     r"""Arguments pertaining to the LoRA training."""
 
-    additional_target: Optional[str] = field(
+    additional_target: str | None = field(
         default=None,
         metadata={
             "help": (
@@ -66,7 +66,7 @@ class LoraArguments:
             )
         },
     )
-    lora_alpha: Optional[int] = field(
+    lora_alpha: int | None = field(
         default=None,
         metadata={"help": "The scale factor for LoRA fine-tuning (default: lora_rank * 2)."},
     )
@@ -88,7 +88,7 @@ class LoraArguments:
             )
         },
     )
-    loraplus_lr_ratio: Optional[float] = field(
+    loraplus_lr_ratio: float | None = field(
         default=None,
         metadata={"help": "LoRA plus learning rate ratio (lr_B / lr_A)."},
     )
@@ -126,7 +126,7 @@ class LoraArguments:
 class OFTArguments:
     r"""Arguments pertaining to the OFT training."""
 
-    additional_target: Optional[str] = field(
+    additional_target: str | None = field(
         default=None,
         metadata={
             "help": (
@@ -220,27 +220,27 @@ class RLHFArguments:
         default=False,
         metadata={"help": "Whiten the rewards before compute advantages in PPO training."},
     )
-    ref_model: Optional[str] = field(
+    ref_model: str | None = field(
         default=None,
         metadata={"help": "Path to the reference model used for the PPO or DPO training."},
     )
-    ref_model_adapters: Optional[str] = field(
+    ref_model_adapters: str | None = field(
         default=None,
         metadata={"help": "Path to the adapters of the reference model."},
     )
-    ref_model_quantization_bit: Optional[int] = field(
+    ref_model_quantization_bit: int | None = field(
         default=None,
         metadata={"help": "The number of bits to quantize the reference model."},
     )
-    reward_model: Optional[str] = field(
+    reward_model: str | None = field(
         default=None,
         metadata={"help": "Path to the reward model used for the PPO training."},
     )
-    reward_model_adapters: Optional[str] = field(
+    reward_model_adapters: str | None = field(
         default=None,
         metadata={"help": "Path to the adapters of the reward model."},
     )
-    reward_model_quantization_bit: Optional[int] = field(
+    reward_model_quantization_bit: int | None = field(
         default=None,
         metadata={"help": "The number of bits to quantize the reward model."},
     )
@@ -248,7 +248,7 @@ class RLHFArguments:
         default="lora",
         metadata={"help": "The type of the reward model in PPO training. Lora model only supports lora training."},
     )
-    ld_alpha: Optional[float] = field(
+    ld_alpha: float | None = field(
         default=None,
         metadata={
             "help": (
@@ -361,15 +361,15 @@ class BAdamArgument:
         default="layer",
         metadata={"help": "Whether to use layer-wise or ratio-wise BAdam optimizer."},
     )
-    badam_start_block: Optional[int] = field(
+    badam_start_block: int | None = field(
         default=None,
         metadata={"help": "The starting block index for layer-wise BAdam."},
     )
-    badam_switch_mode: Optional[Literal["ascending", "descending", "random", "fixed"]] = field(
+    badam_switch_mode: Literal["ascending", "descending", "random", "fixed"] | None = field(
         default="ascending",
         metadata={"help": "the strategy of picking block to update for layer-wise BAdam."},
     )
-    badam_switch_interval: Optional[int] = field(
+    badam_switch_interval: int | None = field(
         default=50,
         metadata={
             "help": "Number of steps to update the block for layer-wise BAdam. Use -1 to disable the block update."
@@ -406,15 +406,15 @@ class SwanLabArguments:
         default=False,
         metadata={"help": "Whether or not to use the SwanLab (an experiment tracking and visualization tool)."},
     )
-    swanlab_project: Optional[str] = field(
+    swanlab_project: str | None = field(
         default="llamafactory",
         metadata={"help": "The project name in SwanLab."},
     )
-    swanlab_workspace: Optional[str] = field(
+    swanlab_workspace: str | None = field(
         default=None,
         metadata={"help": "The workspace name in SwanLab."},
    )
-    swanlab_run_name: Optional[str] = field(
+    swanlab_run_name: str | None = field(
         default=None,
         metadata={"help": "The experiment name in SwanLab."},
     )
@@ -422,19 +422,19 @@ class SwanLabArguments:
         default="cloud",
         metadata={"help": "The mode of SwanLab."},
     )
-    swanlab_api_key: Optional[str] = field(
+    swanlab_api_key: str | None = field(
         default=None,
         metadata={"help": "The API key for SwanLab."},
     )
-    swanlab_logdir: Optional[str] = field(
+    swanlab_logdir: str | None = field(
         default=None,
         metadata={"help": "The log directory for SwanLab."},
     )
-    swanlab_lark_webhook_url: Optional[str] = field(
+    swanlab_lark_webhook_url: str | None = field(
         default=None,
         metadata={"help": "The Lark(飞书) webhook URL for SwanLab."},
     )
-    swanlab_lark_secret: Optional[str] = field(
+    swanlab_lark_secret: str | None = field(
         default=None,
         metadata={"help": "The Lark(飞书) secret for SwanLab."},
     )
@@ -461,7 +461,7 @@ class FinetuningArguments(
         default="sft",
         metadata={"help": "Which stage will be performed in training."},
     )
-    finetuning_type: Literal["lora", "freeze", "full"] = field(
+    finetuning_type: Literal["lora", "oft", "freeze", "full"] = field(
         default="lora",
         metadata={"help": "Which fine-tuning method to use."},
     )
@@ -473,6 +473,15 @@ class FinetuningArguments(
         default=False,
         metadata={"help": "Whether or not to use the Adam-mini optimizer."},
     )
+    use_mca: bool = field(
+        default=False,
+        metadata={
+            "help": (
+                "Whether or not to use MCA (Megatron Core Adapter) training. "
+                "Controlled by USE_MCA environment variable."
+            )
+        },
+    )
     use_muon: bool = field(
         default=False,
         metadata={"help": "Whether or not to use the Muon optimizer."},
@@ -501,7 +510,7 @@ class FinetuningArguments(
         default=False,
         metadata={"help": "Whether or not to disable the shuffling of the training set."},
     )
-    early_stopping_steps: Optional[int] = field(
+    early_stopping_steps: int | None = field(
         default=None,
         metadata={"help": "Number of steps to stop training if the `metric_for_best_model` does not improve."},
     )
@@ -521,11 +530,11 @@ class FinetuningArguments(
             return arg
 
         self.freeze_trainable_modules: list[str] = split_arg(self.freeze_trainable_modules)
-        self.freeze_extra_modules: Optional[list[str]] = split_arg(self.freeze_extra_modules)
+        self.freeze_extra_modules: list[str] | None = split_arg(self.freeze_extra_modules)
         self.lora_alpha: int = self.lora_alpha or self.lora_rank * 2
         self.lora_target: list[str] = split_arg(self.lora_target)
         self.oft_target: list[str] = split_arg(self.oft_target)
-        self.additional_target: Optional[list[str]] = split_arg(self.additional_target)
+        self.additional_target: list[str] | None = split_arg(self.additional_target)
         self.galore_target: list[str] = split_arg(self.galore_target)
         self.apollo_target: list[str] = split_arg(self.apollo_target)
         self.use_ref_model = self.stage == "dpo" and self.pref_loss not in ["orpo", "simpo"]
@@ -1,4 +1,4 @@
-# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
+# Copyright 2025 HuggingFace Inc., the KVCache.AI team, Approaching AI, and the LlamaFactory team.
 #
 # This code is inspired by the HuggingFace's transformers library.
 # https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py
@@ -17,29 +17,30 @@
 
 import json
 from dataclasses import asdict, dataclass, field, fields
-from typing import Any, Literal, Optional, Union
+from typing import Any, Literal, Self
 
 import torch
-from transformers.training_args import _convert_str_dict
-from typing_extensions import Self
 from omegaconf import OmegaConf
+from transformers.training_args import _convert_str_dict
 
 from ..extras.constants import AttentionFunction, EngineName, QuantizationMethod, RopeScaling
 from ..extras.logging import get_logger
 
 
 logger = get_logger(__name__)
 
 
 @dataclass
 class BaseModelArguments:
     r"""Arguments pertaining to the model."""
 
-    model_name_or_path: Optional[str] = field(
+    model_name_or_path: str | None = field(
         default=None,
         metadata={
             "help": "Path to the model weight or identifier from huggingface.co/models or modelscope.cn/models."
         },
     )
-    adapter_name_or_path: Optional[str] = field(
+    adapter_name_or_path: str | None = field(
         default=None,
         metadata={
             "help": (
@@ -48,11 +49,11 @@ class BaseModelArguments:
             )
         },
     )
-    adapter_folder: Optional[str] = field(
+    adapter_folder: str | None = field(
         default=None,
         metadata={"help": "The folder containing the adapter weights to load."},
     )
-    cache_dir: Optional[str] = field(
+    cache_dir: str | None = field(
         default=None,
         metadata={"help": "Where to store the pre-trained models downloaded from huggingface.co or modelscope.cn."},
     )
@@ -68,17 +69,17 @@ class BaseModelArguments:
         default=False,
         metadata={"help": "Whether or not the special tokens should be split during the tokenization process."},
     )
-    add_tokens: Optional[str] = field(
+    add_tokens: str | None = field(
         default=None,
         metadata={
             "help": "Non-special tokens to be added into the tokenizer. Use commas to separate multiple tokens."
         },
     )
-    add_special_tokens: Optional[str] = field(
+    add_special_tokens: str | None = field(
         default=None,
         metadata={"help": "Special tokens to be added into the tokenizer. Use commas to separate multiple tokens."},
     )
-    new_special_tokens_config: Optional[str] = field(
+    new_special_tokens_config: str | None = field(
         default=None,
         metadata={
             "help": (
@@ -108,7 +109,7 @@ class BaseModelArguments:
         default=True,
         metadata={"help": "Whether or not to use memory-efficient model loading."},
     )
-    rope_scaling: Optional[RopeScaling] = field(
+    rope_scaling: RopeScaling | None = field(
         default=None,
         metadata={"help": "Which scaling strategy should be adopted for the RoPE embeddings."},
     )
@@ -120,7 +121,7 @@ class BaseModelArguments:
         default=False,
         metadata={"help": "Enable shift short attention (S^2-Attn) proposed by LongLoRA."},
     )
-    mixture_of_depths: Optional[Literal["convert", "load"]] = field(
+    mixture_of_depths: Literal["convert", "load"] | None = field(
         default=None,
         metadata={"help": "Convert the model to mixture-of-depths (MoD) or load the MoD model."},
     )
@@ -136,7 +137,7 @@ class BaseModelArguments:
         default=False,
         metadata={"help": "Whether or not to enable liger kernel for faster training."},
     )
-    moe_aux_loss_coef: Optional[float] = field(
+    moe_aux_loss_coef: float | None = field(
         default=None,
         metadata={"help": "Coefficient of the auxiliary router loss in mixture-of-experts model."},
     )
@@ -168,23 +169,27 @@ class BaseModelArguments:
         default="offload",
         metadata={"help": "Path to offload model weights."},
     )
-    use_cache: bool = field(
+    use_kv_cache: bool = field(
         default=True,
         metadata={"help": "Whether or not to use KV cache in generation."},
     )
+    use_v1_kernels: bool = field(
+        default=False,
+        metadata={"help": "Whether or not to use high-performance kernels in training."},
+    )
     infer_dtype: Literal["auto", "float16", "bfloat16", "float32"] = field(
         default="auto",
         metadata={"help": "Data type for model weights and activations at inference."},
     )
-    hf_hub_token: Optional[str] = field(
+    hf_hub_token: str | None = field(
         default=None,
         metadata={"help": "Auth token to log in with Hugging Face Hub."},
     )
-    ms_hub_token: Optional[str] = field(
+    ms_hub_token: str | None = field(
         default=None,
         metadata={"help": "Auth token to log in with ModelScope Hub."},
     )
-    om_hub_token: Optional[str] = field(
+    om_hub_token: str | None = field(
         default=None,
         metadata={"help": "Auth token to log in with Modelers Hub."},
     )
@@ -277,7 +282,7 @@ class QuantizationArguments:
         default=QuantizationMethod.BNB,
         metadata={"help": "Quantization method to use for on-the-fly quantization."},
     )
-    quantization_bit: Optional[int] = field(
+    quantization_bit: int | None = field(
         default=None,
         metadata={"help": "The number of bits to quantize the model using on-the-fly quantization."},
     )
@@ -289,7 +294,7 @@ class QuantizationArguments:
         default=True,
         metadata={"help": "Whether or not to use double quantization in bitsandbytes int4 training."},
     )
-    quantization_device_map: Optional[Literal["auto"]] = field(
+    quantization_device_map: Literal["auto"] | None = field(
         default=None,
         metadata={"help": "Device map used to infer the 4-bit quantized model, needs bitsandbytes>=0.43.0."},
     )
@@ -369,7 +374,7 @@ class ProcessorArguments:
 class ExportArguments:
     r"""Arguments pertaining to the model export."""
 
-    export_dir: Optional[str] = field(
+    export_dir: str | None = field(
         default=None,
         metadata={"help": "Path to the directory to save the exported model."},
     )
@@ -381,11 +386,11 @@ class ExportArguments:
         default="cpu",
         metadata={"help": "The device used in model export, use `auto` to accelerate exporting."},
     )
-    export_quantization_bit: Optional[int] = field(
+    export_quantization_bit: int | None = field(
         default=None,
         metadata={"help": "The number of bits to quantize the exported model."},
     )
-    export_quantization_dataset: Optional[str] = field(
+    export_quantization_dataset: str | None = field(
         default=None,
         metadata={"help": "Path to the dataset or dataset name to use in quantizing the exported model."},
     )
@@ -401,7 +406,7 @@ class ExportArguments:
         default=False,
         metadata={"help": "Whether or not to save the `.bin` files instead of `.safetensors`."},
     )
-    export_hub_model_id: Optional[str] = field(
+    export_hub_model_id: str | None = field(
         default=None,
         metadata={"help": "The name of the repository if push the model to the Hugging Face hub."},
     )
@@ -431,7 +436,7 @@ class VllmArguments:
         default=32,
         metadata={"help": "Maximum rank of all LoRAs in the vLLM engine."},
     )
-    vllm_config: Optional[Union[dict, str]] = field(
+    vllm_config: dict | str | None = field(
         default=None,
         metadata={"help": "Config to initialize the vllm engine. Please use JSON strings."},
     )
@@ -457,7 +462,7 @@ class SGLangArguments:
         default=-1,
         metadata={"help": "Tensor parallel size for the SGLang engine."},
     )
-    sglang_config: Optional[Union[dict, str]] = field(
+    sglang_config: dict | str | None = field(
         default=None,
         metadata={"help": "Config to initialize the SGLang engine. Please use JSON strings."},
     )
@@ -473,26 +478,77 @@ class SGLangArguments:
             self.sglang_config = _convert_str_dict(json.loads(self.sglang_config))
 
 
+@dataclass
+class KTransformersArguments:
+    r"""Arguments pertaining to the KT training."""
+
+    use_kt: bool = field(
+        default=False,
+        metadata={"help": "Whether To Use KTransformers Optimizations For LoRA Training."},
+    )
+    kt_optimize_rule: str | None = field(
+        default=None,
+        metadata={
+            "help": "Path To The KTransformers Optimize Rule; See https://github.com/kvcache-ai/ktransformers/."
+        },
+    )
+    cpu_infer: int | None = field(
+        default=32,
+        metadata={"help": "Number Of CPU Cores Used For Computation."},
+    )
+    chunk_size: int | None = field(
+        default=8192,
+        metadata={"help": "Chunk Size Used For CPU Compute In KTransformers."},
+    )
+    mode: str | None = field(
+        default="normal",
+        metadata={"help": "Normal Or Long_Context For Llama Models."},
+    )
+
+    kt_maxlen: int = field(
+        default=4096,
+        metadata={"help": "Maximum Sequence (Prompt + Response) Length Of The KT Engine."},
+    )
+    kt_use_cuda_graph: bool = field(
+        default=True,
+        metadata={"help": "Whether To Use CUDA Graphs For The KT Engine."},
+    )
+    kt_mode: str = field(
+        default="normal",
+        metadata={"help": "Normal Or Long_Context Mode For The KT Engine."},
+    )
+    kt_force_think: bool = field(
+        default=False,
+        metadata={"help": "Force-Think Toggle For The KT Engine."},
+    )
+
+
 @dataclass
 class ModelArguments(
-    SGLangArguments, VllmArguments, ExportArguments, ProcessorArguments, QuantizationArguments, BaseModelArguments
+    SGLangArguments,
+    VllmArguments,
+    KTransformersArguments,
+    ExportArguments,
+    ProcessorArguments,
+    QuantizationArguments,
+    BaseModelArguments,
 ):
     r"""Arguments pertaining to which model/config/tokenizer we are going to fine-tune or infer.
 
     The class on the most right will be displayed first.
     """
 
-    compute_dtype: Optional[torch.dtype] = field(
+    compute_dtype: torch.dtype | None = field(
         default=None,
         init=False,
         metadata={"help": "Torch data type for computing model outputs, derived from `fp/bf16`. Do not specify it."},
     )
-    device_map: Optional[Union[str, dict[str, Any]]] = field(
+    device_map: str | dict[str, Any] | None = field(
         default=None,
         init=False,
         metadata={"help": "Device map for model placement, derived from training stage. Do not specify it."},
     )
-    model_max_length: Optional[int] = field(
+    model_max_length: int | None = field(
         default=None,
         init=False,
         metadata={"help": "The maximum input length for model, derived from `cutoff_len`. Do not specify it."},
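The new `KTransformersArguments` mixin surfaces the KT-engine switches (`use_kt`, `kt_optimize_rule`, `cpu_infer`, `chunk_size`, and the `kt_*` generation knobs) on `ModelArguments`. As a hedged sketch of how they might be exercised, the dict below mirrors a LoRA SFT configuration; the model id, dataset name, rule path, and output directory are placeholders for illustration, not project defaults:

```python
# Hypothetical LoRA SFT configuration exercising the new KTransformers fields.
kt_lora_config = {
    "model_name_or_path": "deepseek-ai/DeepSeek-V2-Lite",  # placeholder model id
    "stage": "sft",
    "do_train": True,
    "finetuning_type": "lora",
    "use_kt": True,                            # enable KTransformers optimizations for LoRA training
    "kt_optimize_rule": "rules/my-rule.yaml",  # placeholder path to a KT optimize rule
    "cpu_infer": 32,                           # CPU cores used for the CPU-side compute
    "chunk_size": 8192,                        # chunk size for CPU compute in KTransformers
    "dataset": "identity",                     # placeholder dataset name
    "output_dir": "saves/kt-lora",             # placeholder output directory
}

# A dict like this could be handed to the HfArgumentParser-based pipeline, e.g.:
#   from llamafactory.hparams import get_train_args
#   model_args, data_args, training_args, finetuning_args, generating_args = get_train_args(kt_lora_config)
```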
@@ -18,7 +18,7 @@
 import os
 import sys
 from pathlib import Path
-from typing import Any, Optional, Union
+from typing import Any, Optional
 
 import torch
 import transformers
@@ -32,7 +32,7 @@ from transformers.utils import is_torch_bf16_gpu_available, is_torch_npu_availab
 from ..extras import logging
 from ..extras.constants import CHECKPOINT_NAMES, EngineName
 from ..extras.misc import check_dependencies, check_version, get_current_device, is_env_enabled
-from ..extras.packages import is_transformers_version_greater_than
+from ..extras.packages import is_mcore_adapter_available, is_transformers_version_greater_than
 from .data_args import DataArguments
 from .evaluation_args import EvaluationArguments
 from .finetuning_args import FinetuningArguments
@@ -53,8 +53,19 @@ _INFER_CLS = tuple[ModelArguments, DataArguments, FinetuningArguments, Generatin
 _EVAL_ARGS = [ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments]
 _EVAL_CLS = tuple[ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments]
 
-def read_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> Union[dict[str, Any], list[str]]:
+if is_mcore_adapter_available() and is_env_enabled("USE_MCA"):
+    from mcore_adapter import TrainingArguments as McaTrainingArguments
+
+    _TRAIN_MCA_ARGS = [ModelArguments, DataArguments, McaTrainingArguments, FinetuningArguments, GeneratingArguments]
+    _TRAIN_MCA_CLS = tuple[
+        ModelArguments, DataArguments, McaTrainingArguments, FinetuningArguments, GeneratingArguments
+    ]
+else:
+    _TRAIN_MCA_ARGS = []
+    _TRAIN_MCA_CLS = tuple()
+
+
+def read_args(args: dict[str, Any] | list[str] | None = None) -> dict[str, Any] | list[str]:
     r"""Get arguments from the command line or a config file."""
     if args is not None:
         return args
@@ -72,7 +83,7 @@ def read_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> Union[
 
 
 def _parse_args(
-    parser: "HfArgumentParser", args: Optional[Union[dict[str, Any], list[str]]] = None, allow_extra_keys: bool = False
+    parser: "HfArgumentParser", args: dict[str, Any] | list[str] | None = None, allow_extra_keys: bool = False
 ) -> tuple[Any]:
     args = read_args(args)
     if isinstance(args, dict):
@@ -145,6 +156,9 @@ def _check_extra_dependencies(
     finetuning_args: "FinetuningArguments",
     training_args: Optional["TrainingArguments"] = None,
 ) -> None:
+    if model_args.use_kt:
+        check_version("ktransformers", mandatory=True)
+
     if model_args.use_unsloth:
         check_version("unsloth", mandatory=True)
 
@@ -191,32 +205,57 @@ def _check_extra_dependencies(
         check_version("rouge_chinese", mandatory=True)
 
 
-def _parse_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _TRAIN_CLS:
+def _parse_train_args(args: dict[str, Any] | list[str] | None = None) -> _TRAIN_CLS:
     parser = HfArgumentParser(_TRAIN_ARGS)
     allow_extra_keys = is_env_enabled("ALLOW_EXTRA_ARGS")
     return _parse_args(parser, args, allow_extra_keys=allow_extra_keys)
 
 
-def _parse_infer_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _INFER_CLS:
+def _parse_train_mca_args(args: dict[str, Any] | list[str] | None = None) -> _TRAIN_MCA_CLS:
+    parser = HfArgumentParser(_TRAIN_MCA_ARGS)
+    allow_extra_keys = is_env_enabled("ALLOW_EXTRA_ARGS")
+    model_args, data_args, training_args, finetuning_args, generating_args = _parse_args(
+        parser, args, allow_extra_keys=allow_extra_keys
+    )
+
+    _configure_mca_training_args(training_args, data_args, finetuning_args)
+
+    return model_args, data_args, training_args, finetuning_args, generating_args
+
+
+def _configure_mca_training_args(training_args, data_args, finetuning_args) -> None:
+    """Patch training args to avoid args checking errors and sync MCA settings."""
+    training_args.predict_with_generate = False
+    training_args.generation_max_length = data_args.cutoff_len
+    training_args.generation_num_beams = 1
+    training_args.use_mca = True
+    finetuning_args.use_mca = True
+
+
+def _parse_infer_args(args: dict[str, Any] | list[str] | None = None) -> _INFER_CLS:
     parser = HfArgumentParser(_INFER_ARGS)
     allow_extra_keys = is_env_enabled("ALLOW_EXTRA_ARGS")
     return _parse_args(parser, args, allow_extra_keys=allow_extra_keys)
 
 
-def _parse_eval_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _EVAL_CLS:
+def _parse_eval_args(args: dict[str, Any] | list[str] | None = None) -> _EVAL_CLS:
     parser = HfArgumentParser(_EVAL_ARGS)
     allow_extra_keys = is_env_enabled("ALLOW_EXTRA_ARGS")
     return _parse_args(parser, args, allow_extra_keys=allow_extra_keys)
 
 
-def get_ray_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> RayArguments:
+def get_ray_args(args: dict[str, Any] | list[str] | None = None) -> RayArguments:
     parser = HfArgumentParser(RayArguments)
     (ray_args,) = _parse_args(parser, args, allow_extra_keys=True)
     return ray_args
 
 
-def get_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _TRAIN_CLS:
-    model_args, data_args, training_args, finetuning_args, generating_args = _parse_train_args(args)
+def get_train_args(args: dict[str, Any] | list[str] | None = None) -> _TRAIN_CLS:
+    if is_env_enabled("USE_MCA"):
+        model_args, data_args, training_args, finetuning_args, generating_args = _parse_train_mca_args(args)
+    else:
+        model_args, data_args, training_args, finetuning_args, generating_args = _parse_train_args(args)
+        finetuning_args.use_mca = False
 
     # Setup logging
     if training_args.should_log:
@@ -246,13 +285,16 @@ def get_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _
         if model_args.shift_attn:
             raise ValueError("PPO training is incompatible with S^2-Attn.")
 
+        if finetuning_args.reward_model_type == "lora" and model_args.use_kt:
+            raise ValueError("KTransformers does not support lora reward model.")
+
         if finetuning_args.reward_model_type == "lora" and model_args.use_unsloth:
             raise ValueError("Unsloth does not support lora reward model.")
 
         if training_args.report_to and training_args.report_to[0] not in ["wandb", "tensorboard"]:
             raise ValueError("PPO only accepts wandb or tensorboard logger.")
 
-    if training_args.parallel_mode == ParallelMode.NOT_DISTRIBUTED:
+    if not model_args.use_kt and training_args.parallel_mode == ParallelMode.NOT_DISTRIBUTED:
         raise ValueError("Please launch distributed training with `llamafactory-cli` or `torchrun`.")
 
     if training_args.deepspeed and training_args.parallel_mode != ParallelMode.DISTRIBUTED:
@@ -264,18 +306,15 @@ def get_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _
     if training_args.do_train and data_args.dataset is None:
         raise ValueError("Please specify dataset for training.")
 
-    if (training_args.do_eval or training_args.do_predict) and (
+    if (training_args.do_eval or training_args.do_predict or training_args.predict_with_generate) and (
         data_args.eval_dataset is None and data_args.val_size < 1e-6
     ):
-        raise ValueError("Please specify dataset for evaluation.")
+        raise ValueError("Please make sure eval_dataset be provided or val_size >1e-6")
 
     if training_args.predict_with_generate:
         if is_deepspeed_zero3_enabled():
             raise ValueError("`predict_with_generate` is incompatible with DeepSpeed ZeRO-3.")
 
-        if data_args.eval_dataset is None:
-            raise ValueError("Cannot use `predict_with_generate` if `eval_dataset` is None.")
-
         if finetuning_args.compute_accuracy:
             raise ValueError("Cannot use `predict_with_generate` and `compute_accuracy` together.")
 
@@ -314,6 +353,9 @@ def get_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _
     if model_args.use_unsloth and is_deepspeed_zero3_enabled():
         raise ValueError("Unsloth is incompatible with DeepSpeed ZeRO-3.")
 
+    if model_args.use_kt and is_deepspeed_zero3_enabled():
+        raise ValueError("KTransformers is incompatible with DeepSpeed ZeRO-3.")
+
     if data_args.neat_packing and is_transformers_version_greater_than("4.53.0"):
         raise ValueError("Neat packing is incompatible with transformers>=4.53.0.")
 
@@ -431,7 +473,7 @@ def get_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _
     return model_args, data_args, training_args, finetuning_args, generating_args
 
 
-def get_infer_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _INFER_CLS:
+def get_infer_args(args: dict[str, Any] | list[str] | None = None) -> _INFER_CLS:
     model_args, data_args, finetuning_args, generating_args = _parse_infer_args(args)
 
     # Setup logging
@@ -466,7 +508,7 @@ def get_infer_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _
     return model_args, data_args, finetuning_args, generating_args
 
 
-def get_eval_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _EVAL_CLS:
+def get_eval_args(args: dict[str, Any] | list[str] | None = None) -> _EVAL_CLS:
     model_args, data_args, eval_args, finetuning_args = _parse_eval_args(args)
 
     # Setup logging
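Taken together, these parser changes route training-argument parsing through `_parse_train_mca_args` only when the `USE_MCA` environment variable is enabled, and `_configure_mca_training_args` then disables `predict_with_generate` and pins the generation settings before training starts. A reduced, illustrative rendering of that control flow follows; the helper names below are stand-ins, not the project's API:

```python
import os


def _is_env_enabled(name: str) -> bool:
    # Simplified stand-in for an "is this env flag on" helper.
    return os.getenv(name, "0").lower() in ("true", "y", "1")


def select_train_parser() -> str:
    """Return which parse path would be taken (illustrative only)."""
    if _is_env_enabled("USE_MCA"):
        # mcore_adapter's TrainingArguments replaces the HF Seq2Seq variant,
        # and predict_with_generate / beam settings are patched afterwards.
        return "_parse_train_mca_args"
    return "_parse_train_args"


if __name__ == "__main__":
    os.environ["USE_MCA"] = "1"
    print(select_train_parser())  # _parse_train_mca_args
```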
@@ -14,19 +14,33 @@
 
 import json
 from dataclasses import dataclass, field
-from typing import Literal, Optional, Union
+from typing import Literal
 
 from transformers import Seq2SeqTrainingArguments
 from transformers.training_args import _convert_str_dict
 
-from ..extras.misc import use_ray
+from ..extras.misc import is_env_enabled, use_ray
+from ..extras.packages import is_mcore_adapter_available
+
+
+if is_env_enabled("USE_MCA"):
+    if not is_mcore_adapter_available():
+        raise ImportError(
+            "mcore_adapter is required when USE_MCA=1. Please install `mcore_adapter` and its dependencies."
+        )
+
+    from mcore_adapter import Seq2SeqTrainingArguments as McaSeq2SeqTrainingArguments
+
+    BaseTrainingArguments = McaSeq2SeqTrainingArguments
+else:
+    BaseTrainingArguments = Seq2SeqTrainingArguments
 
 
 @dataclass
 class RayArguments:
     r"""Arguments pertaining to the Ray training."""
 
-    ray_run_name: Optional[str] = field(
+    ray_run_name: str | None = field(
         default=None,
         metadata={"help": "The training results will be saved at `<ray_storage_path>/ray_run_name`."},
     )
@@ -34,7 +48,7 @@ class RayArguments:
         default="./saves",
         metadata={"help": "The storage path to save training results to"},
     )
-    ray_storage_filesystem: Optional[Literal["s3", "gs", "gcs"]] = field(
+    ray_storage_filesystem: Literal["s3", "gs", "gcs"] | None = field(
         default=None,
         metadata={"help": "The storage filesystem to use. If None specified, local filesystem will be used."},
     )
@@ -42,7 +56,7 @@ class RayArguments:
         default=1,
         metadata={"help": "The number of workers for Ray training. Default is 1 worker."},
     )
-    resources_per_worker: Union[dict, str] = field(
+    resources_per_worker: dict | str = field(
         default_factory=lambda: {"GPU": 1},
         metadata={"help": "The resources per worker for Ray training. Default is to use 1 GPU per worker."},
     )
@@ -50,7 +64,7 @@ class RayArguments:
         default="PACK",
         metadata={"help": "The placement strategy for Ray training. Default is PACK."},
     )
-    ray_init_kwargs: Optional[Union[dict, str]] = field(
+    ray_init_kwargs: dict | str | None = field(
         default=None,
         metadata={"help": "The arguments to pass to ray.init for Ray training. Default is None."},
     )
@@ -78,9 +92,14 @@ class RayArguments:
 
 
 @dataclass
-class TrainingArguments(RayArguments, Seq2SeqTrainingArguments):
+class TrainingArguments(RayArguments, BaseTrainingArguments):
     r"""Arguments pertaining to the trainer."""
 
+    overwrite_output_dir: bool = field(
+        default=False,
+        metadata={"help": "deprecated"},
+    )
+
     def __post_init__(self):
-        Seq2SeqTrainingArguments.__post_init__(self)
         RayArguments.__post_init__(self)
+        BaseTrainingArguments.__post_init__(self)
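The `__post_init__` change above exists because dataclasses do not chain `__post_init__` across base classes automatically; with the base now chosen at import time, the hard-coded `Seq2SeqTrainingArguments.__post_init__` call has to become `BaseTrainingArguments.__post_init__`. A small self-contained sketch of the pattern, not taken from the repository:

```python
from dataclasses import dataclass


@dataclass
class BaseA:
    def __post_init__(self):
        print("BaseA.__post_init__")


@dataclass
class BaseB:
    def __post_init__(self):
        print("BaseB.__post_init__")


@dataclass
class Combined(BaseA, BaseB):
    def __post_init__(self):
        # dataclasses do not chain __post_init__ automatically, so each base
        # must be invoked explicitly, mirroring RayArguments / BaseTrainingArguments above.
        BaseA.__post_init__(self)
        BaseB.__post_init__(self)


Combined()  # prints both base messages
```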
@@ -38,7 +38,7 @@ USAGE = (
 def launch():
     from .extras import logging
     from .extras.env import VERSION, print_env
-    from .extras.misc import find_available_port, get_device_count, is_env_enabled, use_ray
+    from .extras.misc import find_available_port, get_device_count, is_env_enabled, use_kt, use_ray
 
     logger = logging.get_logger(__name__)
     WELCOME = (
@@ -54,7 +54,12 @@ def launch():
     )
 
     command = sys.argv.pop(1) if len(sys.argv) > 1 else "help"
-    if command == "train" and (is_env_enabled("FORCE_TORCHRUN") or (get_device_count() > 1 and not use_ray())):
+    if is_env_enabled("USE_MCA"):  # force use torchrun
+        os.environ["FORCE_TORCHRUN"] = "1"
+
+    if command == "train" and (
+        is_env_enabled("FORCE_TORCHRUN") or (get_device_count() > 1 and not use_ray() and not use_kt())
+    ):
         # launch distributed training
         nnodes = os.getenv("NNODES", "1")
         node_rank = os.getenv("NODE_RANK", "0")
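In the launcher, `USE_MCA` now forces `FORCE_TORCHRUN`, while an active KTransformers run opts out of the automatic multi-GPU `torchrun` path. The snippet below is an illustrative reduction of that decision, not the project's launcher code:

```python
import os


def should_use_torchrun(command: str, device_count: int, use_ray: bool, use_kt: bool) -> bool:
    """Illustrative reduction of the launch-time decision shown in the diff above."""
    if os.getenv("USE_MCA", "0") == "1":
        os.environ["FORCE_TORCHRUN"] = "1"  # MCA training always goes through torchrun
    force = os.getenv("FORCE_TORCHRUN", "0").lower() in ("true", "y", "1")
    return command == "train" and (force or (device_count > 1 and not use_ray and not use_kt))


# e.g. two visible GPUs, no Ray, KTransformers enabled -> single-process launch
print(should_use_torchrun("train", device_count=2, use_ray=False, use_kt=True))  # False
```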
Some files were not shown because too many files have changed in this diff.