37 Commits

Author SHA1 Message Date
Copilot
eceec8ab69 [deps] goodbye python 3.9 (#9677)
Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: hiyouga <16256802+hiyouga@users.noreply.github.com>
Co-authored-by: hiyouga <hiyouga@buaa.edu.cn>
2025-12-27 02:50:44 +08:00
Yaowei Zheng
b44f651e09 [ci] fix docker (#9678) 2025-12-27 02:43:46 +08:00
Yaowei Zheng
55590f5ece [misc] fix ci with uv (#9676) 2025-12-27 01:39:13 +08:00
Copilot
a1b1931b4a [breaking] migrate from setuptools to uv (#9673)
Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: hiyouga <16256802+hiyouga@users.noreply.github.com>
2025-12-26 22:47:23 +08:00
Xunpeng Xiao
3c17f2722c [model] Update ernie_vl to adapt new version (#9665) 2025-12-26 19:57:49 +08:00
Copilot
a882e2d5fc [assets] Add GitHub Copilot instructions for repository (#9675)
Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: hiyouga <16256802+hiyouga@users.noreply.github.com>
2025-12-26 17:32:48 +08:00
Yaowei Zheng
a754604c11 [misc] fix accelerator (#9661)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2025-12-25 02:11:04 +08:00
Xunpeng Xiao
6a2eafbae3 [feat] Models trained and inferred with Mxfp4 are dequantized by default (#9652)
Co-authored-by: Yaowei Zheng <hiyouga@buaa.edu.cn>
2025-12-24 00:26:40 +08:00
Yaowei Zheng
84485406b7 [ci] disable pip cache for ci (#9654) 2025-12-23 18:37:40 +08:00
Kingsley
1c8a42d2f8 [v1&WIP] dataloader init (#9645) 2025-12-23 16:29:47 +08:00
thulyubh22
7901b2f32e [model] efficient tuning for gpt-oss (#9354)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2025-12-23 16:28:38 +08:00
Yaowei Zheng
1f1f5a7d1b [ci] remove docker cache (#9640) 2025-12-22 01:03:10 +08:00
Yaowei Zheng
6ef9854713 [misc] fix cache & pin transformers to 4.57.1 (#9638) 2025-12-22 00:20:55 +08:00
Hertz
4923f52a28 [model] support MiMo-V2-Flash model (#9637) 2025-12-21 14:38:18 +08:00
Yaowei Zheng
0894b4f37e [misc] lint (#9636) 2025-12-20 16:19:39 +08:00
ZIYI ZENG
b0d49e137f [misc] Support split eval_dataset when explict set "predict_with_generate" (#9604)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2025-12-20 01:46:00 +08:00
Xunpeng Xiao
ddd7dcc722 [data] Fix the video frame sampling issue #9620 (#9634) 2025-12-19 18:36:31 +08:00
浮梦
5204cd2bca [misc] add version check for moe (#9633) 2025-12-19 14:57:37 +08:00
Xunpeng Xiao
8c74dca76a [feat] Models trained and inferred with FP8 are dequantized by default (#9627) 2025-12-18 22:54:35 +08:00
xvxuopop
e8deda53a1 [example] add Qwen3 series examples (#9624)
Co-authored-by: UsernameFull <tohowtodoit@gmail.com>
2025-12-18 21:27:00 +08:00
mrhaoxx
a769fb94b9 [feat] support ktransformers for dpo (#9621)
Co-authored-by: poryfly <porykid@gmail.com>
2025-12-18 21:26:25 +08:00
mrhaoxx
964569751f [kt] refactor ktransformers integration (#9632) 2025-12-18 21:26:04 +08:00
Hertz
9fd4b094d4 [model] support VibeThinker models (#9616) 2025-12-16 21:50:46 +08:00
浮梦
18c21bce5a [test] add allreduce test on npu (#9619)
Co-authored-by: frozenleaves <frozen@Mac.local>
2025-12-16 21:33:30 +08:00
sunyi0505
a0179772ab [example] add deepspeed autotp config and example (#9602) 2025-12-15 15:15:26 +08:00
Yaowei Zheng
aeda079014 [v1] model loader (#9613) 2025-12-14 11:50:52 +08:00
Xunpeng Xiao
fdd24276ed [feat] support new function call value (#9610)
Co-authored-by: Yaowei Zheng <hiyouga@buaa.edu.cn>
2025-12-14 00:20:33 +08:00
Yaowei Zheng
110d21713e [v1] add dp & mp mesh (#9611) 2025-12-13 01:44:28 +08:00
Yaowei Zheng
203069e11c [v1] add accelerator (#9607) 2025-12-12 19:22:06 +08:00
tangefly
4fd94141a4 [model] Add Ministral3 (#9582)
Co-authored-by: kingsley <kingsleydodonow@gmail.com>
2025-12-10 15:57:24 +08:00
Kingsley
22d6ac29d5 [model] Rename GLMV template (#9595) 2025-12-10 13:27:47 +08:00
DoubleWheat
cff4483392 [config] Fix RoPE scaling patch for resuming from a scaled model (#9588) 2025-12-09 20:37:37 +08:00
Yaowei Zheng
5d56817e2b [misc] lint (#9593)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2025-12-09 18:00:35 +08:00
Yaowei Zheng
1bbb461f76 [assets] update readme (#9587) 2025-12-09 12:22:54 +08:00
Hertz
c1f5f8fff6 [model] support GLM4.6v (#9586) 2025-12-09 11:06:42 +08:00
Yaowei Zheng
5744f1ea94 [v1] add models & accelerator (#9579) 2025-12-08 02:30:25 +08:00
tangefly
739954910a [deps] Update for Transformers v5 (#9569) 2025-12-08 01:13:32 +08:00
177 changed files with 4288 additions and 1604 deletions

180
.github/copilot-instructions.md vendored Normal file
View File

@@ -0,0 +1,180 @@
# GitHub Copilot Instructions for LLaMA Factory
## Project Overview
LLaMA Factory is an efficient fine-tuning framework for 100+ large language models (LLMs). It provides:
- Support for various models: LLaMA, LLaVA, Mistral, Qwen, DeepSeek, Yi, Gemma, ChatGLM, Phi, etc.
- Multiple training methods: pre-training, supervised fine-tuning, reward modeling, PPO, DPO, KTO, ORPO
- Scalable resources: 16-bit full-tuning, freeze-tuning, LoRA and QLoRA variants
- Advanced algorithms: GaLore, BAdam, APOLLO, Adam-mini, Muon, OFT, DoRA, etc.
- Web UI (LLaMA Board) and CLI interfaces
### Architecture Versions
LLaMA Factory has two parallel architectures that can be switched via the `USE_V1` environment variable:
**v0 (default)** - File hierarchy:
- `api`, `webui``chat`, `eval`, `train``data`, `model``hparams``extras`
**v1** - File hierarchy:
- `trainers``core``accelerator`, `plugins`, `config``utils`
Set `USE_V1=1` to enable v1 architecture.
## Code Structure
### v0 Architecture (Default)
- `src/llamafactory/` - Main package directory
- `api/` - OpenAI-style API implementation
- `chat/` - Chat interface implementation
- `cli.py` - Command-line interface
- `data/` - Data processing and dataset handling
- `eval/` - Model evaluation utilities
- `extras/` - Additional utilities and helpers
- `hparams/` - Hyperparameter definitions
- `model/` - Model loading, patching, and utilities
- `train/` - Training pipeline implementation
- `webui/` - Gradio-based web interface
- `src/train.py` - Training entry script (delegates to `llamafactory.train.tuner`)
- `src/webui.py` - Web UI entry script (delegates to `llamafactory.webui.interface`)
- `src/api.py` - API server entry script (delegates to `llamafactory.api.app`)
- `tests/` - Test suite
- `examples/` - Example configurations for various training scenarios
- `data/` - Dataset definitions and examples
### v1 Architecture (USE_V1=1)
- `src/llamafactory/v1/` - Version 1 package directory
- `trainers/` - Training implementations
- `core/` - Core training utilities
- `accelerator/` - Acceleration and distributed training
- `plugins/` - Pluggable components (model, data, sampler, trainer)
- `config/` - Configuration management
- `utils/` - Utility functions
## Development Practices
### Code Style
- Follow the [Google Python Style Guide](https://google.github.io/styleguide/pyguide.html)
- Use ruff for linting and formatting
- Line length: 119 characters
- Indentation: 4 spaces
- Quote style: double quotes
- Use Google-style docstrings for documentation
### Import Organization
- Known first-party: `llamafactory`
- Known third-party: `accelerate`, `datasets`, `gradio`, `numpy`, `peft`, `torch`, `transformers`, `trl`
- Use 2 blank lines after imports
### Quality Checks
Before committing code, run:
```bash
make style # Auto-fix style issues
make quality # Check code quality
make test # Run test suite
```
Or use the combined command:
```bash
make commit # Run pre-commit hooks
```
### Testing
- Use pytest for testing
- Tests are located in `tests/` and `tests_v1/` directories
- Run tests with: `make test` (which runs `WANDB_DISABLED=true pytest -vv --import-mode=importlib tests/ tests_v1/`)
- Disable wandb during testing to avoid external dependencies
- **Note**: Training configurations require GPU machines, so training is typically not tested end-to-end. Use `make test` to validate file-level functionality.
### Building
Build the package with:
```bash
pip3 install build && python3 -m build
```
### License
- All source files must include the Apache 2.0 license header
- Check license headers with: `make license`
## Common Patterns
### Configuration Files
- Training configurations are typically YAML or JSON files in `examples/` directory
- Hyperparameters are defined using dataclasses in `src/llamafactory/hparams/`
### Model Support
- New model support is added through model patches in `src/llamafactory/model/`
- Visual models use the visual utilities in `src/llamafactory/model/model_utils/visual.py`
- Quantization support is in `src/llamafactory/model/model_utils/quantization.py`
### Data Processing
- Dataset definitions are in `data/dataset_info.json`
- Data templates and processors are in `src/llamafactory/data/`
### Training
- Training pipelines are in `src/llamafactory/train/`
- Support for different training methods: SFT, DPO, PPO, RM, PT, KTO, ORPO
## Key Dependencies
- Python >= 3.9.0
- PyTorch and transformers for model handling
- datasets for data processing
- peft for parameter-efficient fine-tuning
- accelerate for distributed training
- gradio for web UI
- trl for reinforcement learning
- Optional: vllm/sglang for inference, flash-attention-2, unsloth, liger-kernel
## Entry Points
- **CLI Training**: `llamafactory-cli train --config examples/train_lora/llama3_lora_sft.yaml`
- **Web UI**: `llamafactory-cli webui` or `python src/webui.py`
- **API Server**: `llamafactory-cli api` or `python src/api.py`
- **Chat Interface**: `llamafactory-cli chat --model_name_or_path MODEL_PATH`
## Environment Setup
For development:
```bash
pip install -e ".[dev]"
```
## Important Notes
- The project supports multiple backends: default PyTorch, vLLM, SGLang
- Megatron-core training is supported via mcore_adapter
- SwanLab and W&B are supported for experiment tracking
- Docker support is available with pre-built images
- Day-0/Day-1 support for latest cutting-edge models
- Multi-modal support for vision and audio understanding tasks
## Contribution Guidelines
1. Fork the repository
2. Create a development branch
3. Set up development environment with `pip install -e ".[dev]"`
4. Make changes following the style guide
5. Run quality checks: `make style && make quality`
6. Run tests: `make test`
7. Submit a pull request
## Common Commands
- `make style` - Format code
- `make quality` - Run linters
- `make test` - Run tests
- `make commit` - Install and run pre-commit hooks
- `make license` - Check license headers

View File

@@ -7,7 +7,7 @@ on:
- "main"
paths:
- "**/*.py"
- "requirements.txt"
- "pyproject.toml"
- "docker/**"
- ".github/workflows/*.yml"
pull_request:
@@ -15,7 +15,7 @@ on:
- "main"
paths:
- "**/*.py"
- "requirements.txt"
- "pyproject.toml"
- "docker/**"
- ".github/workflows/*.yml"
release:
@@ -29,16 +29,13 @@ jobs:
matrix:
include:
- device: "cuda"
npu_type: ""
- device: "npu"
npu_type: "a2"
- device: "npu"
npu_type: "a3"
- device: "npu-a2"
- device: "npu-a3"
runs-on: ubuntu-latest
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ matrix.device }}-${{ matrix.npu_type }}
group: ${{ github.workflow }}-${{ github.ref }}-${{ matrix.device }}
cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
environment:
@@ -55,16 +52,11 @@ jobs:
- name: Checkout
uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: "3.10"
- name: Get llamafactory version
id: version
run: |
if [ "${{ github.event_name }}" = "release" ]; then
echo "tag=$(python setup.py --version)" >> "$GITHUB_OUTPUT"
echo "tag=$(grep -oP 'VERSION = "\K[^"]+' src/llamafactory/extras/env.py)" >> "$GITHUB_OUTPUT"
else
echo "tag=latest" >> "$GITHUB_OUTPUT"
fi
@@ -93,16 +85,12 @@ jobs:
with:
context: .
file: ./docker/docker-cuda/Dockerfile
build-args: |
EXTRAS=metrics,deepspeed,liger-kernel
push: ${{ github.event_name != 'pull_request' }}
tags: |
docker.io/hiyouga/llamafactory:${{ steps.version.outputs.tag }}
cache-from: type=gha
cache-to: type=gha,mode=max
- name: Build and push Docker image (NPU-A2)
if: ${{ matrix.device == 'npu' && matrix.npu_type == 'a2' }}
if: ${{ matrix.device == 'npu-a2' }}
uses: docker/build-push-action@v6
with:
context: .
@@ -112,11 +100,9 @@ jobs:
tags: |
docker.io/hiyouga/llamafactory:${{ steps.version.outputs.tag }}-npu-a2
quay.io/ascend/llamafactory:${{ steps.version.outputs.tag }}-npu-a2
cache-from: type=gha
cache-to: type=gha,mode=max
- name: Build and push Docker image (NPU-A3)
if: ${{ matrix.device == 'npu' && matrix.npu_type == 'a3' }}
if: ${{ matrix.device == 'npu-a3' }}
uses: docker/build-push-action@v6
with:
context: .
@@ -128,5 +114,3 @@ jobs:
tags: |
docker.io/hiyouga/llamafactory:${{ steps.version.outputs.tag }}-npu-a3
quay.io/ascend/llamafactory:${{ steps.version.outputs.tag }}-npu-a3
cache-from: type=gha
cache-to: type=gha,mode=max

View File

@@ -23,10 +23,11 @@ jobs:
- name: Checkout
uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v5
- name: Install uv
uses: astral-sh/setup-uv@v7
with:
python-version: "3.9"
python-version: "3.11"
github-token: ${{ github.token }}
- name: Build package
run: |

View File

@@ -7,7 +7,7 @@ on:
- "main"
paths:
- "**/*.py"
- "requirements.txt"
- "pyproject.toml"
- "Makefile"
- ".github/workflows/*.yml"
pull_request:
@@ -15,7 +15,7 @@ on:
- "main"
paths:
- "**/*.py"
- "requirements.txt"
- "pyproject.toml"
- "Makefile"
- ".github/workflows/*.yml"
@@ -25,10 +25,9 @@ jobs:
fail-fast: false
matrix:
python:
- "3.9"
- "3.10"
- "3.11"
- "3.12"
# - "3.13" # enable after trl is upgraded
os:
- "ubuntu-latest"
- "windows-latest"
@@ -36,18 +35,15 @@ jobs:
transformers:
- null
include: # test backward compatibility
- python: "3.9"
- python: "3.11"
os: "ubuntu-latest"
transformers: "4.49.0"
- python: "3.9"
- python: "3.11"
os: "ubuntu-latest"
transformers: "4.51.0"
- python: "3.9"
- python: "3.11"
os: "ubuntu-latest"
transformers: "4.53.0"
exclude: # exclude python 3.9 on macos
- python: "3.9"
os: "macos-latest"
runs-on: ${{ matrix.os }}
@@ -63,22 +59,23 @@ jobs:
- name: Checkout
uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v5
- name: Install uv
uses: astral-sh/setup-uv@v7
with:
python-version: ${{ matrix.python }}
cache: "pip"
cache-dependency-path: "**/requirements*.txt"
github-token: ${{ github.token }}
enable-cache: false
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install ".[torch,dev]"
uv venv
uv pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
uv pip install -e ".[dev]"
- name: Install transformers
if: ${{ matrix.transformers }}
run: |
python -m pip install "transformers==${{ matrix.transformers }}"
uv pip install "transformers==${{ matrix.transformers }}"
- name: Cache files
id: hf-hub-cache
@@ -90,18 +87,25 @@ jobs:
- name: Check quality
run: |
make style && make quality
env:
UV_NO_SYNC: 1
- name: Check license
run: |
make license
env:
UV_NO_SYNC: 1
- name: Check build
run: |
make build
env:
UV_NO_SYNC: 1
- name: Test with pytest
run: |
make test
env:
UV_NO_SYNC: 1
HF_HOME: ${{ runner.temp }}/huggingface
HF_HUB_OFFLINE: "${{ steps.hf-hub-cache.outputs.cache-hit == 'true' && '1' || '0' }}"

View File

@@ -7,7 +7,7 @@ on:
- "main"
paths:
- "**/*.py"
- "requirements.txt"
- "pyproject.toml"
- "Makefile"
- ".github/workflows/*.yml"
pull_request:
@@ -15,7 +15,7 @@ on:
- "main"
paths:
- "**/*.py"
- "requirements.txt"
- "pyproject.toml"
- "Makefile"
- ".github/workflows/*.yml"
@@ -48,10 +48,15 @@ jobs:
- name: Checkout
uses: actions/checkout@v4
- name: Install uv
run: |
curl -LsSf https://astral.sh/uv/install.sh | sh
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install ".[torch-npu,dev]" torch-npu==${{matrix.pytorch_npu}}
uv venv
uv pip install torch-npu==${{matrix.pytorch_npu}}
uv pip install -e ".[dev]"
- name: Install node
run: |
@@ -70,18 +75,25 @@ jobs:
- name: Check quality
run: |
make style && make quality
env:
UV_NO_SYNC: 1
- name: Check license
run: |
make license
env:
UV_NO_SYNC: 1
- name: Check build
run: |
make build
env:
UV_NO_SYNC: 1
- name: Test with pytest
run: |
make test
env:
UV_NO_SYNC: 1
HF_HOME: /root/.cache/huggingface
HF_HUB_OFFLINE: "${{ steps.hf-hub-cache.outputs.cache-hit == 'true' && '1' || '0' }}"
HF_HUB_OFFLINE: "${{ steps.hf-hub-cache.outputs.cache-hit == 'true' && '1' || '0' }}"

5
.gitignore vendored
View File

@@ -85,7 +85,7 @@ ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
.python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
@@ -165,6 +165,9 @@ cython_debug/
# uv
uv.lock
# macOS
.DS_Store
# custom .gitignore
hf_cache/
ms_cache/

View File

@@ -1 +1 @@
include LICENSE requirements.txt
include LICENSE

View File

@@ -1,24 +1,28 @@
.PHONY: build commit license quality style test
check_dirs := scripts src tests tests_v1 setup.py
check_dirs := scripts src tests tests_v1
RUN := $(shell command -v uv >/dev/null 2>&1 && echo "uv run" || echo "")
BUILD := $(shell command -v uv >/dev/null 2>&1 && echo "uv build" || echo "python -m build")
TOOL := $(shell command -v uv >/dev/null 2>&1 && echo "uvx" || echo "")
build:
pip3 install build && python3 -m build
$(BUILD)
commit:
pre-commit install
pre-commit run --all-files
$(TOOL) pre-commit install
$(TOOL) pre-commit run --all-files
license:
python3 tests/check_license.py $(check_dirs)
$(RUN) python3 tests/check_license.py $(check_dirs)
quality:
ruff check $(check_dirs)
ruff format --check $(check_dirs)
$(TOOL) ruff check $(check_dirs)
$(TOOL) ruff format --check $(check_dirs)
style:
ruff check $(check_dirs) --fix
ruff format $(check_dirs)
$(TOOL) ruff check $(check_dirs) --fix
$(TOOL) ruff format $(check_dirs)
test:
CUDA_VISIBLE_DEVICES= ASCEND_RT_VISIBLE_DEVICES=0 WANDB_DISABLED=true pytest -vv --import-mode=importlib tests/ tests_v1/
WANDB_DISABLED=true $(RUN) pytest -vv --import-mode=importlib tests/ tests_v1/

View File

@@ -278,27 +278,21 @@ Read technical notes:
| Model | Model size | Template |
| ----------------------------------------------------------------- | -------------------------------- | -------------------- |
| [Baichuan 2](https://huggingface.co/baichuan-inc) | 7B/13B | baichuan2 |
| [BLOOM/BLOOMZ](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - |
| [ChatGLM3](https://huggingface.co/THUDM) | 6B | chatglm3 |
| [Command R](https://huggingface.co/CohereForAI) | 35B/104B | cohere |
| [DeepSeek (Code/MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B/236B | deepseek |
| [DeepSeek 2.5/3](https://huggingface.co/deepseek-ai) | 236B/671B | deepseek3 |
| [DeepSeek (LLM/Code/MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B/236B | deepseek |
| [DeepSeek 3-3.2](https://huggingface.co/deepseek-ai) | 236B/671B | deepseek3 |
| [DeepSeek R1 (Distill)](https://huggingface.co/deepseek-ai) | 1.5B/7B/8B/14B/32B/70B/671B | deepseekr1 |
| [ERNIE-4.5](https://huggingface.co/baidu) | 0.3B/21B/300B | ernie/ernie_nothink |
| [Falcon](https://huggingface.co/tiiuae) | 7B/11B/40B/180B | falcon |
| [Falcon-H1](https://huggingface.co/tiiuae) | 0.5B/1.5B/3B/7B/34B | falcon_h1 |
| [Falcon/Falcon H1](https://huggingface.co/tiiuae) | 0.5B/1.5B/3B/7B/11B/34B/40B/180B | falcon/falcon_h1 |
| [Gemma/Gemma 2/CodeGemma](https://huggingface.co/google) | 2B/7B/9B/27B | gemma/gemma2 |
| [Gemma 3/Gemma 3n](https://huggingface.co/google) | 270M/1B/4B/6B/8B/12B/27B | gemma3/gemma3n |
| [GLM-4/GLM-4-0414/GLM-Z1](https://huggingface.co/zai-org) | 9B/32B | glm4/glmz1 |
| [GLM-4.1V](https://huggingface.co/zai-org) | 9B | glm4v |
| [GLM-4.5/GLM-4.5V](https://huggingface.co/zai-org) | 106B/355B | glm4_moe/glm4v_moe |
| [GLM-4.5/GLM-4.5(6)V](https://huggingface.co/zai-org) | 9B/106B/355B | glm4_moe/glm4_5v |
| [GPT-2](https://huggingface.co/openai-community) | 0.1B/0.4B/0.8B/1.5B | - |
| [GPT-OSS](https://huggingface.co/openai) | 20B/120B | gpt |
| [Granite 3.0-3.3](https://huggingface.co/ibm-granite) | 1B/2B/3B/8B | granite3 |
| [Granite 4](https://huggingface.co/ibm-granite) | 7B | granite4 |
| [GPT-OSS](https://huggingface.co/openai) | 20B/120B | gpt_oss |
| [Granite 3-4](https://huggingface.co/ibm-granite) | 1B/2B/3B/7B/8B | granite3/granite4 |
| [Hunyuan (MT)](https://huggingface.co/tencent/) | 7B | hunyuan |
| [Index](https://huggingface.co/IndexTeam) | 1.9B | index |
| [InternLM 2-3](https://huggingface.co/internlm) | 7B/8B/20B | intern2 |
| [InternVL 2.5-3.5](https://huggingface.co/OpenGVLab) | 1B/2B/4B/8B/14B/30B/38B/78B/241B | intern_vl |
| [InternLM/Intern-S1-mini](https://huggingface.co/internlm/) | 8B | intern_s1 |
@@ -312,15 +306,13 @@ Read technical notes:
| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | llava |
| [LLaVA-NeXT](https://huggingface.co/llava-hf) | 7B/8B/13B/34B/72B/110B | llava_next |
| [LLaVA-NeXT-Video](https://huggingface.co/llava-hf) | 7B/34B | llava_next_video |
| [MiMo](https://huggingface.co/XiaomiMiMo) | 7B | mimo |
| [MiMo](https://huggingface.co/XiaomiMiMo) | 7B/309B | mimo/mimo_v2 |
| [MiniCPM 1-4.1](https://huggingface.co/openbmb) | 0.5B/1B/2B/4B/8B | cpm/cpm3/cpm4 |
| [MiniCPM-o-2.6/MiniCPM-V-2.6](https://huggingface.co/openbmb) | 8B | minicpm_o/minicpm_v |
| [Ministral/Mistral-Nemo](https://huggingface.co/mistralai) | 8B/12B | ministral |
| [Ministral 3](https://huggingface.co/mistralai) | 3B/8B/14B | ministral3 |
| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
| [Mistral Small](https://huggingface.co/mistralai) | 24B | mistral_small |
| [OLMo](https://huggingface.co/allenai) | 1B/7B | - |
| [PaliGemma/PaliGemma2](https://huggingface.co/google) | 3B/10B/28B | paligemma |
| [Phi-1.5/Phi-2](https://huggingface.co/microsoft) | 1.3B/2.7B | - |
| [Phi-3/Phi-3.5](https://huggingface.co/microsoft) | 4B/14B | phi |
| [Phi-3-small](https://huggingface.co/microsoft) | 7B | phi_small |
| [Phi-4](https://huggingface.co/microsoft) | 14B | phi4 |
@@ -333,12 +325,9 @@ Read technical notes:
| [Qwen2-VL/Qwen2.5-VL/QVQ](https://huggingface.co/Qwen) | 2B/3B/7B/32B/72B | qwen2_vl |
| [Qwen3-VL](https://huggingface.co/Qwen) | 2B/4B/8B/30B/32B/235B | qwen3_vl |
| [Seed (OSS/Coder)](https://huggingface.co/ByteDance-Seed) | 8B/36B | seed_oss/seed_coder |
| [Skywork o1](https://huggingface.co/Skywork) | 8B | skywork_o1 |
| [StarCoder 2](https://huggingface.co/bigcode) | 3B/7B/15B | - |
| [TeleChat2](https://huggingface.co/Tele-AI) | 3B/7B/35B/115B | telechat2 |
| [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | xverse |
| [VibeThinker-1.5B](https://huggingface.co/WeiboAI) | 1.5B | qwen3 |
| [Yi/Yi-1.5 (Code)](https://huggingface.co/01-ai) | 1.5B/6B/9B/34B | yi |
| [Yi-VL](https://huggingface.co/01-ai) | 6B/34B | yi_vl |
| [Yuan 2](https://huggingface.co/IEITYuan) | 2B/51B/102B | yuan |
> [!NOTE]
@@ -525,10 +514,12 @@ huggingface-cli login
```bash
git clone --depth 1 https://github.com/hiyouga/LLaMA-Factory.git
cd LLaMA-Factory
pip install -e ".[torch,metrics]" --no-build-isolation
pip install -e ".[metrics]" --no-build-isolation
```
Extra dependencies available: torch, torch-npu, metrics, deepspeed, liger-kernel, bitsandbytes, hqq, eetq, gptq, aqlm, vllm, sglang, galore, apollo, badam, adam-mini, qwen, minicpm_v, openmind, swanlab, dev
Optional dependencies available: `metrics`, `deepspeed`. Install with: `pip install -e ".[metrics,deepspeed]"`
Additional dependencies for specific features are available in `examples/requirements/`.
#### Install from Docker Image
@@ -547,13 +538,7 @@ Please refer to [build docker](#build-docker) to build the image yourself.
Create an isolated Python environment with [uv](https://github.com/astral-sh/uv):
```bash
uv sync --extra torch --extra metrics --prerelease=allow
```
Run LLaMA-Factory in the isolated environment:
```bash
uv run --prerelease=allow llamafactory-cli train examples/train_lora/llama3_lora_pretrain.yaml
uv run llamafactory-cli webui
```
</details>
@@ -590,7 +575,7 @@ To enable FlashAttention-2 on the Windows platform, please use the script from [
<details><summary>For Ascend NPU users</summary>
To install LLaMA Factory on Ascend NPU devices, please upgrade Python to version 3.10 or higher and specify extra dependencies: `pip install -e ".[torch-npu,metrics]"`. Additionally, you need to install the **[Ascend CANN Toolkit and Kernels](https://www.hiascend.com/developer/download/community/result?module=cann)**. Please follow the [installation tutorial](https://www.hiascend.com/document/detail/en/CANNCommunityEdition/600alphaX/softwareinstall/instg/atlasdeploy_03_0031.html) or use the following commands:
To install LLaMA Factory on Ascend NPU devices, please upgrade Python to version 3.10 or higher: `pip install -e . torch-npu==2.7.1`. Additionally, you need to install the **[Ascend CANN Toolkit and Kernels](https://www.hiascend.com/developer/download/community/result?module=cann)**. Please follow the [installation tutorial](https://www.hiascend.com/document/detail/en/CANNCommunityEdition/600alphaX/softwareinstall/instg/atlasdeploy_03_0031.html) or use the following commands:
```bash
# replace the url according to your CANN version and devices
@@ -609,8 +594,8 @@ source /usr/local/Ascend/ascend-toolkit/set_env.sh
| Requirement | Minimum | Recommend |
| ------------ | ------- | -------------- |
| CANN | 8.0.RC1 | 8.0.0.alpha002 |
| torch | 2.1.0 | 2.4.0 |
| torch-npu | 2.1.0 | 2.4.0.post2 |
| torch | 2.1.0 | 2.7.1 |
| torch-npu | 2.1.0 | 2.7.1 |
| deepspeed | 0.13.2 | 0.13.2 |
| vllm-ascend | - | 0.7.3 |
@@ -725,7 +710,6 @@ For CUDA users:
```bash
docker build -f ./docker/docker-cuda/Dockerfile \
--build-arg PIP_INDEX=https://pypi.org/simple \
--build-arg EXTRAS=metrics \
-t llamafactory:latest .
docker run -dit --ipc=host --gpus=all \
@@ -742,7 +726,6 @@ For Ascend NPU users:
```bash
docker build -f ./docker/docker-npu/Dockerfile \
--build-arg PIP_INDEX=https://pypi.org/simple \
--build-arg EXTRAS=torch-npu,metrics \
-t llamafactory:latest .
docker run -dit --ipc=host \
@@ -767,7 +750,6 @@ For AMD ROCm users:
```bash
docker build -f ./docker/docker-rocm/Dockerfile \
--build-arg PIP_INDEX=https://pypi.org/simple \
--build-arg EXTRAS=metrics \
-t llamafactory:latest .
docker run -dit --ipc=host \

View File

@@ -280,27 +280,21 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
| 模型名 | 参数量 | Template |
| ----------------------------------------------------------------- | -------------------------------- | -------------------- |
| [Baichuan 2](https://huggingface.co/baichuan-inc) | 7B/13B | baichuan2 |
| [BLOOM/BLOOMZ](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - |
| [ChatGLM3](https://huggingface.co/THUDM) | 6B | chatglm3 |
| [Command R](https://huggingface.co/CohereForAI) | 35B/104B | cohere |
| [DeepSeek (Code/MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B/236B | deepseek |
| [DeepSeek 2.5/3](https://huggingface.co/deepseek-ai) | 236B/671B | deepseek3 |
| [DeepSeek (LLM/Code/MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B/236B | deepseek |
| [DeepSeek 3-3.2](https://huggingface.co/deepseek-ai) | 236B/671B | deepseek3 |
| [DeepSeek R1 (Distill)](https://huggingface.co/deepseek-ai) | 1.5B/7B/8B/14B/32B/70B/671B | deepseekr1 |
| [ERNIE-4.5](https://huggingface.co/baidu) | 0.3B/21B/300B | ernie/ernie_nothink |
| [Falcon](https://huggingface.co/tiiuae) | 7B/11B/40B/180B | falcon |
| [Falcon-H1](https://huggingface.co/tiiuae) | 0.5B/1.5B/3B/7B/34B | falcon_h1 |
| [Falcon/Falcon H1](https://huggingface.co/tiiuae) | 0.5B/1.5B/3B/7B/11B/34B/40B/180B | falcon/falcon_h1 |
| [Gemma/Gemma 2/CodeGemma](https://huggingface.co/google) | 2B/7B/9B/27B | gemma/gemma2 |
| [Gemma 3/Gemma 3n](https://huggingface.co/google) | 270M/1B/4B/6B/8B/12B/27B | gemma3/gemma3n |
| [GLM-4/GLM-4-0414/GLM-Z1](https://huggingface.co/zai-org) | 9B/32B | glm4/glmz1 |
| [GLM-4.1V](https://huggingface.co/zai-org) | 9B | glm4v |
| [GLM-4.5/GLM-4.5V](https://huggingface.co/zai-org) | 106B/355B | glm4_moe/glm4v_moe |
| [GLM-4.5/GLM-4.5(6)V](https://huggingface.co/zai-org) | 9B/106B/355B | glm4_moe/glm4_5v |
| [GPT-2](https://huggingface.co/openai-community) | 0.1B/0.4B/0.8B/1.5B | - |
| [GPT-OSS](https://huggingface.co/openai) | 20B/120B | gpt |
| [Granite 3.0-3.3](https://huggingface.co/ibm-granite) | 1B/2B/3B/8B | granite3 |
| [Granite 4](https://huggingface.co/ibm-granite) | 7B | granite4 |
| [GPT-OSS](https://huggingface.co/openai) | 20B/120B | gpt_oss |
| [Granite 3-4](https://huggingface.co/ibm-granite) | 1B/2B/3B/7B/8B | granite3/granite4 |
| [Hunyuan (MT)](https://huggingface.co/tencent/) | 7B | hunyuan |
| [Index](https://huggingface.co/IndexTeam) | 1.9B | index |
| [InternLM 2-3](https://huggingface.co/internlm) | 7B/8B/20B | intern2 |
| [InternVL 2.5-3.5](https://huggingface.co/OpenGVLab) | 1B/2B/4B/8B/14B/30B/38B/78B/241B | intern_vl |
| [InternLM/Intern-S1-mini](https://huggingface.co/internlm/) | 8B | intern_s1 |
@@ -314,15 +308,13 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | llava |
| [LLaVA-NeXT](https://huggingface.co/llava-hf) | 7B/8B/13B/34B/72B/110B | llava_next |
| [LLaVA-NeXT-Video](https://huggingface.co/llava-hf) | 7B/34B | llava_next_video |
| [MiMo](https://huggingface.co/XiaomiMiMo) | 7B | mimo |
| [MiMo](https://huggingface.co/XiaomiMiMo) | 7B/309B | mimo/mimo_v2 |
| [MiniCPM 1-4.1](https://huggingface.co/openbmb) | 0.5B/1B/2B/4B/8B | cpm/cpm3/cpm4 |
| [MiniCPM-o-2.6/MiniCPM-V-2.6](https://huggingface.co/openbmb) | 8B | minicpm_o/minicpm_v |
| [Ministral/Mistral-Nemo](https://huggingface.co/mistralai) | 8B/12B | ministral |
| [Ministral 3](https://huggingface.co/mistralai) | 3B/8B/14B | ministral3 |
| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
| [Mistral Small](https://huggingface.co/mistralai) | 24B | mistral_small |
| [OLMo](https://huggingface.co/allenai) | 1B/7B | - |
| [PaliGemma/PaliGemma2](https://huggingface.co/google) | 3B/10B/28B | paligemma |
| [Phi-1.5/Phi-2](https://huggingface.co/microsoft) | 1.3B/2.7B | - |
| [Phi-3/Phi-3.5](https://huggingface.co/microsoft) | 4B/14B | phi |
| [Phi-3-small](https://huggingface.co/microsoft) | 7B | phi_small |
| [Phi-4](https://huggingface.co/microsoft) | 14B | phi4 |
@@ -335,12 +327,9 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
| [Qwen2-VL/Qwen2.5-VL/QVQ](https://huggingface.co/Qwen) | 2B/3B/7B/32B/72B | qwen2_vl |
| [Qwen3-VL](https://huggingface.co/Qwen) | 2B/4B/8B/30B/32B/235B | qwen3_vl |
| [Seed (OSS/Coder)](https://huggingface.co/ByteDance-Seed) | 8B/36B | seed_oss/seed_coder |
| [Skywork o1](https://huggingface.co/Skywork) | 8B | skywork_o1 |
| [StarCoder 2](https://huggingface.co/bigcode) | 3B/7B/15B | - |
| [TeleChat2](https://huggingface.co/Tele-AI) | 3B/7B/35B/115B | telechat2 |
| [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | xverse |
| [VibeThinker-1.5B](https://huggingface.co/WeiboAI) | 1.5B | qwen3 |
| [Yi/Yi-1.5 (Code)](https://huggingface.co/01-ai) | 1.5B/6B/9B/34B | yi |
| [Yi-VL](https://huggingface.co/01-ai) | 6B/34B | yi_vl |
| [Yuan 2](https://huggingface.co/IEITYuan) | 2B/51B/102B | yuan |
> [!NOTE]
@@ -527,10 +516,12 @@ huggingface-cli login
```bash
git clone --depth 1 https://github.com/hiyouga/LLaMA-Factory.git
cd LLaMA-Factory
pip install -e ".[torch,metrics]" --no-build-isolation
pip install -e ".[metrics]" --no-build-isolation
```
可选的额外依赖项:torch、torch-npu、metricsdeepspeed、liger-kernel、bitsandbytes、hqq、eetq、gptq、aqlm、vllm、sglang、galore、apollo、badam、adam-mini、qwen、minicpm_v、openmind、swanlab、dev
可选的额外依赖项:`metrics``deepspeed`。使用 `pip install -e ".[metrics,deepspeed]"` 安装。
其他可选依赖项请参考 `examples/requirements/` 目录下的文件。
#### 从镜像安装
@@ -549,13 +540,7 @@ docker run -it --rm --gpus=all --ipc=host hiyouga/llamafactory:latest
使用 [uv](https://github.com/astral-sh/uv) 创建隔离的 Python 环境:
```bash
uv sync --extra torch --extra metrics --prerelease=allow
```
在环境中运行 LLaMA-Factory
```bash
uv run --prerelease=allow llamafactory-cli train examples/train_lora/llama3_lora_pretrain.yaml
uv run llamafactory-cli webui
```
</details>
@@ -592,7 +577,7 @@ pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/downl
<details><summary>昇腾 NPU 用户指南</summary>
在昇腾 NPU 设备上安装 LLaMA Factory 时,请升级 Python 到 3.10 及以上,并需要指定额外依赖项,使用 `pip install -e ".[torch-npu,metrics]"` 命令安装。此外,还需要安装 **[Ascend CANN Toolkit 与 Kernels](https://www.hiascend.com/developer/download/community/result?module=cann)**,安装方法请参考[安装教程](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/80RC2alpha002/quickstart/quickstart/quickstart_18_0004.html)或使用以下命令:
在昇腾 NPU 设备上安装 LLaMA Factory 时,请升级 Python 到 3.10 及以上,并需要指定额外依赖项,使用 `pip install -e . torch-npu==2.7.1` 命令安装。此外,还需要安装 **[Ascend CANN Toolkit 与 Kernels](https://www.hiascend.com/developer/download/community/result?module=cann)**,安装方法请参考[安装教程](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/80RC2alpha002/quickstart/quickstart/quickstart_18_0004.html)或使用以下命令:
```bash
# 请替换 URL 为 CANN 版本和设备型号对应的 URL
@@ -611,8 +596,8 @@ source /usr/local/Ascend/ascend-toolkit/set_env.sh
| 依赖项 | 至少 | 推荐 |
| ------------ | ------- | -------------- |
| CANN | 8.0.RC1 | 8.0.0.alpha002 |
| torch | 2.1.0 | 2.4.0 |
| torch-npu | 2.1.0 | 2.4.0.post2 |
| torch | 2.1.0 | 2.7.1 |
| torch-npu | 2.1.0 | 2.7.1 |
| deepspeed | 0.13.2 | 0.13.2 |
| vllm-ascend | - | 0.7.3 |

View File

@@ -1,4 +1,4 @@
dpo_zh_demo:
hf_hub_url: HuggingFaceH4/orca_dpo_pairs
path: HuggingFaceH4/orca_dpo_pairs
split: train_prefs
converter: pair

View File

@@ -1,8 +1,9 @@
identity:
file_name: identity.json
path: data/identity.json
source: local
converter: alpaca
alpaca_en_demo:
file_name: alpaca_en_demo.json
dataset_dir: ~/data
path: data/alpaca_en_demo.json
source: local
converter: alpaca
num_samples: 500
size: 500

View File

@@ -4,7 +4,6 @@ FROM ${BASE_IMAGE}
# Installation arguments
ARG PIP_INDEX=https://pypi.org/simple
ARG EXTRAS=metrics
ARG INSTALL_FLASHATTN=false
ARG HTTP_PROXY=""
@@ -27,17 +26,13 @@ WORKDIR /app
# Change pip source
RUN pip config set global.index-url "${PIP_INDEX}" && \
pip config set global.extra-index-url "${PIP_INDEX}" && \
pip install --no-cache-dir --upgrade pip packaging wheel setuptools
pip install --no-cache-dir --upgrade pip packaging wheel setuptools editables "hatchling>=1.18.0"
# Install the requirements
COPY requirements.txt /app
RUN pip install --no-cache-dir -r requirements.txt
# Copy the rest of the application into the image
# Copy the application into the image
COPY . /app
# Install LLaMA Factory
RUN pip install --no-cache-dir -e ".[${EXTRAS}]" --no-build-isolation
RUN pip install --no-cache-dir --no-build-isolation -e ".[metrics,deepspeed]"
# Rebuild flash attention
RUN if [ "${INSTALL_FLASHATTN}" == "true" ]; then \

View File

@@ -8,7 +8,7 @@ ENV PYPI_MIRROR=https://mirrors.aliyun.com/pypi/simple/
ENV PYPI_TRUSTED_HOST=mirrors.aliyun.com
ENV APT_MIRROR=https://mirrors.tuna.tsinghua.edu.cn/ubuntu/
RUN pip install --upgrade pip setuptools wheel --trusted-host ${PYPI_TRUSTED_HOST} --index-url ${PYPI_MIRROR}
RUN pip install --upgrade pip setuptools wheel "hatchling>=1.18.0" editables --trusted-host ${PYPI_TRUSTED_HOST} --index-url ${PYPI_MIRROR}
RUN pip uninstall -y torch torchvision torch-tensorrt \
flash_attn transformer-engine \
@@ -56,14 +56,14 @@ ENV JAVA_HOME /usr/lib/jvm/java-21-openjdk-amd64
# pip install LLaMA-Factory
WORKDIR /app
COPY requirements.txt /app/
RUN pip install --no-cache-dir -r requirements.txt
# Copy the application into the image
COPY . /app
# Install LLaMA Factory
RUN pip install --no-cache-dir -e ".[metrics]" --no-build-isolation
RUN pip install "git+https://github.com/alibaba/roll.git#subdirectory=mcore_adapter"
COPY . /app/
RUN pip install -e ".[metrics]" --no-build-isolation
# Expose port 7860 for LLaMA Board
ENV GRADIO_SERVER_PORT=7860
EXPOSE 7860

View File

@@ -5,7 +5,6 @@ services:
context: ../..
args:
PIP_INDEX: https://pypi.org/simple
EXTRAS: metrics
container_name: llamafactory
ports:
- "7860:7860"

View File

@@ -5,7 +5,6 @@ FROM ${BASE_IMAGE}
# Installation arguments
ARG PIP_INDEX=https://pypi.org/simple
ARG EXTRAS=torch-npu,metrics
ARG HTTP_PROXY=""
ARG PYTORCH_INDEX=https://download.pytorch.org/whl/cpu
@@ -28,21 +27,15 @@ WORKDIR /app
# Change pip source
RUN pip config set global.index-url "${PIP_INDEX}" && \
pip config set global.extra-index-url "${PIP_INDEX}" && \
pip install --no-cache-dir --upgrade pip packaging wheel setuptools
pip install --no-cache-dir --upgrade pip packaging wheel setuptools editables "hatchling>=1.18.0"
# Copy the application into the image
COPY . /app
# Install torch-npu
RUN pip uninstall -y torch torchvision torchaudio && \
pip install --no-cache-dir "torch==2.7.1" "torch-npu==2.7.1" "torchvision==0.22.1" --index-url "${PYTORCH_INDEX}"
# Install the requirements
COPY requirements.txt /app
RUN pip install --no-cache-dir -r requirements.txt
# Copy the rest of the application into the image
COPY . /app
# Install LLaMA Factory
RUN pip install --no-cache-dir -e ".[${EXTRAS}]" --no-build-isolation
pip install --no-cache-dir "torch==2.7.1" "torch-npu==2.7.1" "torchvision==0.22.1" "torchaudio==2.7.1" --index-url "${PYTORCH_INDEX}" && \
pip install --no-cache-dir -e ".[metrics]" --no-build-isolation
# Set up volumes
# VOLUME [ "/root/.cache/huggingface", "/app/shared_data", "/app/output" ]

View File

@@ -5,7 +5,6 @@ services:
context: ../..
args:
PIP_INDEX: https://pypi.org/simple
EXTRAS: torch-npu,metrics
container_name: llamafactory-a2
image: llamafactory:npu-a2
volumes:
@@ -36,7 +35,6 @@ services:
args:
BASE_IMAGE: quay.io/ascend/cann:8.3.rc2-a3-ubuntu22.04-py3.11
PIP_INDEX: https://pypi.org/simple
EXTRAS: torch-npu,metrics
container_name: llamafactory-a3
image: llamafactory:npu-a3
volumes:

View File

@@ -4,7 +4,6 @@ FROM ${BASE_IMAGE}
# Installation arguments
ARG PIP_INDEX=https://pypi.org/simple
ARG EXTRAS=metrics
ARG INSTALL_FLASHATTN=false
ARG HTTP_PROXY=""
ARG PYTORCH_INDEX=https://download.pytorch.org/whl/rocm6.3
@@ -28,21 +27,14 @@ WORKDIR /app
# Change pip source
RUN pip config set global.index-url "${PIP_INDEX}" && \
pip config set global.extra-index-url "${PIP_INDEX}" && \
pip install --no-cache-dir --upgrade pip packaging wheel setuptools
pip install --no-cache-dir --upgrade pip packaging wheel setuptools editables "hatchling>=1.18.0"
# Reinstall pytorch rocm
RUN pip uninstall -y torch torchvision torchaudio && \
pip install --no-cache-dir --pre torch torchvision torchaudio --index-url "${PYTORCH_INDEX}"
# Install the requirements
COPY requirements.txt /app
RUN pip install --no-cache-dir -r requirements.txt
# Copy the rest of the application into the image
# Copy the application into the image
COPY . /app
# Install LLaMA Factory
RUN pip install --no-cache-dir -e ".[${EXTRAS}]" --no-build-isolation
# Reinstall pytorch rocm and install LLaMA Factory
RUN pip uninstall -y torch torchvision torchaudio && \
pip install --no-cache-dir --no-build-isolation -e --pre ".[metrics,deepspeed]" --index-url "${PYTORCH_INDEX}"
# Rebuild flash attention
RUN if [ "${INSTALL_FLASHATTN}" == "true" ]; then \

View File

@@ -5,7 +5,6 @@ services:
context: ../..
args:
PIP_INDEX: https://pypi.org/simple
EXTRAS: metrics
container_name: llamafactory
ports:
- "7860:7860"

View File

@@ -19,4 +19,4 @@ same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
use_cpu: false

View File

@@ -0,0 +1,45 @@
# Start FSDP2 fine-tuning
# accelerate launch \
# --config_file examples/accelerate/fsdp2_config.yaml \
# src/train.py examples/ascend/qwen3_full_sft_fsdp2.yaml
# Change `num_processes` in fsdp2_config.yaml to 16 in A3
### model
model_name_or_path: Qwen/Qwen3-8B
trust_remote_code: true
use_v1_kernels: true
flash_attn: fa2
### method
stage: sft
do_train: true
finetuning_type: full
### dataset
dataset: alpaca_en_demo
template: qwen3
cutoff_len: 2048
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16
dataloader_num_workers: 4
### output
output_dir: saves/Qwen3-8B/full/sft
logging_steps: 1
save_steps: 500
max_steps: 500
plot_loss: true
overwrite_output_dir: true
save_only_model: false
report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
### train
per_device_train_batch_size: 8
gradient_accumulation_steps: 1
learning_rate: 1.0e-5
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
ddp_timeout: 1800
resume_from_checkpoint: null

View File

@@ -0,0 +1,46 @@
# Start FSDP fine-tuning
# accelerate launch \
# --config_file examples/accelerate/fsdp_config.yaml \
# src/train.py examples/ascend/qwen3moe_full_sft_fsdp.yaml
# Change `num_processes` in fsdp_config.yaml to 16 in A3
### model
model_name_or_path: Qwen/Qwen3-30B-A3B-Instruct-2507
trust_remote_code: true
use_v1_kernels: true
flash_attn: fa2
### method
stage: sft
do_train: true
finetuning_type: full
disable_gradient_checkpointing: false
### dataset
dataset: alpaca_zh
template: qwen3
cutoff_len: 1024
overwrite_cache: true
preprocessing_num_workers: 16
dataloader_num_workers: 4
### output
output_dir: saves/Qwen3-30B-A3B-Instruct-2507/full/sft
logging_steps: 1
save_steps: 500
max_steps: 500
plot_loss: true
overwrite_output_dir: true
save_only_model: true
report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
### train
per_device_train_batch_size: 4
gradient_accumulation_steps: 1
learning_rate: 1.0e-4
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
ddp_timeout: 180000000
resume_from_checkpoint: null
seed: 1234

View File

@@ -0,0 +1,48 @@
# Start FSDP2 fine-tuning
# accelerate launch \
# --config_file examples/accelerate/fsdp2_config.yaml \
# src/train.py examples/ascend/qwen3vlmoe_full_sft_fsdp2.yaml
# Change `num_processes` in fsdp2_config.yaml to 16 in A3
### model
model_name_or_path: Qwen/Qwen3-VL-30B-A3B-Instruct
image_max_pixels: 262144
video_max_pixels: 16384
trust_remote_code: true
use_v1_kernels: true
flash_attn: fa2
### method
stage: sft
do_train: true
finetuning_type: full
disable_gradient_checkpointing: false
### dataset
dataset: llava_1k_en, llava_1k_zh
template: qwen3_vl
cutoff_len: 1024
overwrite_cache: true
preprocessing_num_workers: 16
dataloader_num_workers: 4
### output
output_dir: saves/Qwen3-VL-30B-A3B-Instruct/full/sft
logging_steps: 1
save_steps: 500
max_steps: 500
plot_loss: true
overwrite_output_dir: true
save_only_model: true
report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
### train
per_device_train_batch_size: 2
gradient_accumulation_steps: 1
learning_rate: 1.0e-4
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
ddp_timeout: 180000000
resume_from_checkpoint: null
seed: 1234

View File

@@ -39,4 +39,4 @@ warmup_ratio: 0.1
bf16: true
ddp_timeout: 180000000
resume_from_checkpoint: null
seed: 1234
seed: 1234

View File

@@ -0,0 +1,32 @@
{
"_comment": "suooprted model list: https://www.deepspeed.ai/tutorials/automatic-tensor-parallelism/#supported-models",
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"zero_allow_untested_optimizer": true,
"fp16": {
"enabled": "auto",
"loss_scale": 0,
"loss_scale_window": 1000,
"initial_scale_power": 16,
"hysteresis": 2,
"min_loss_scale": 1
},
"bf16": {
"enabled": "auto"
},
"zero_optimization": {
"stage": 2,
"allgather_partitions": true,
"allgather_bucket_size": 5e8,
"overlap_comm": false,
"reduce_scatter": true,
"reduce_bucket_size": 5e8,
"contiguous_gradients": true,
"round_robin_gradients": true
},
"tensor_parallel": {
"autotp_size": 2
}
}

View File

@@ -0,0 +1 @@
adam-mini

View File

@@ -0,0 +1 @@
apollo-torch

View File

@@ -0,0 +1 @@
aqlm[gpu]>=1.1.0

View File

@@ -0,0 +1 @@
badam>=1.2.1

View File

@@ -0,0 +1 @@
bitsandbytes>=0.39.0

View File

@@ -0,0 +1 @@
eetq

View File

@@ -0,0 +1,2 @@
transformer_engine[pytorch]>=2.0.0
accelerate>=1.10.0

View File

@@ -0,0 +1,2 @@
torchao>=0.8.0
accelerate>=1.10.0

View File

@@ -0,0 +1 @@
galore-torch

View File

@@ -0,0 +1,2 @@
optimum>=1.24.0
gptqmodel>=2.0.0

View File

@@ -0,0 +1 @@
hqq

View File

@@ -0,0 +1 @@
liger-kernel>=0.5.5

View File

@@ -0,0 +1,8 @@
soundfile
torchvision
torchaudio
vector_quantize_pytorch
vocos
msgpack
referencing
jsonschema_specifications

View File

@@ -0,0 +1 @@
openmind

View File

@@ -0,0 +1,2 @@
sglang[srt]>=0.4.5
transformers==4.51.1

View File

@@ -0,0 +1 @@
swanlab

View File

@@ -0,0 +1 @@
vllm>=0.4.3,<=0.11.0

View File

@@ -0,0 +1,46 @@
### model
model_name_or_path: Qwen/Qwen3-32B
trust_remote_code: true
use_v1_kernels: true
### method
stage: sft
do_train: true
finetuning_type: full
deepspeed: examples/deepspeed/ds_z2_autotp_config.json
### dataset
dataset: identity,alpaca_en_demo
template: qwen3
cutoff_len: 2048
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16
dataloader_num_workers: 4
### output
output_dir: saves/qwen3-32b/full/sft_autotp
logging_steps: 1
save_steps: 500
plot_loss: true
overwrite_output_dir: true
save_only_model: false
report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
### train
per_device_train_batch_size: 4
gradient_accumulation_steps: 1
learning_rate: 1.0e-4
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
ddp_timeout: 180000000
resume_from_checkpoint: null
### eval
# eval_dataset: alpaca_en_demo
# val_size: 0.1
# per_device_eval_batch_size: 1
# eval_strategy: steps
# eval_steps: 500

View File

@@ -1,42 +1,123 @@
[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"
requires = ["hatchling"]
build-backend = "hatchling.build"
[project]
name = "llamafactory"
requires-python = ">=3.9.0"
dynamic = [
"version",
"dependencies",
"optional-dependencies",
"scripts",
"authors",
"description",
"readme",
"license",
"keywords",
"classifiers"
dynamic = ["version"]
description = "Unified Efficient Fine-Tuning of 100+ LLMs"
readme = "README.md"
license = "Apache-2.0"
requires-python = ">=3.11.0"
authors = [
{ name = "hiyouga", email = "hiyouga@buaa.edu.cn" }
]
keywords = [
"AI",
"LLM",
"GPT",
"ChatGPT",
"Llama",
"Transformer",
"DeepSeek",
"Pytorch"
]
classifiers = [
"Development Status :: 4 - Beta",
"Intended Audience :: Developers",
"Intended Audience :: Education",
"Intended Audience :: Science/Research",
"License :: OSI Approved :: Apache Software License",
"Operating System :: OS Independent",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13",
"Topic :: Scientific/Engineering :: Artificial Intelligence"
]
dependencies = [
# core deps
"torch>=2.4.0",
"torchvision>=0.19.0",
"torchaudio>=2.4.0",
"transformers>=4.49.0,<=4.56.2,!=4.52.0; python_version < '3.10'",
"transformers>=4.49.0,<=4.57.1,!=4.52.0,!=4.57.0; python_version >= '3.10'",
"datasets>=2.16.0,<=4.0.0",
"accelerate>=1.3.0,<=1.11.0",
"peft>=0.14.0,<=0.17.1",
"trl>=0.8.6,<=0.9.6",
"torchdata>=0.10.0,<=0.11.0",
# gui
"gradio>=4.38.0,<=6.2.0",
"matplotlib>=3.7.0",
"tyro<0.9.0",
# ops
"einops",
"numpy",
"pandas",
"scipy",
# model and tokenizer
"sentencepiece",
"tiktoken",
"modelscope",
"hf-transfer",
"safetensors",
# python
"av",
"fire",
"omegaconf",
"packaging",
"protobuf",
"pyyaml",
"pydantic",
# api
"uvicorn",
"fastapi",
"sse-starlette"
]
[project.optional-dependencies]
dev = ["pre-commit", "ruff", "pytest", "build"]
metrics = ["nltk", "jieba", "rouge-chinese"]
deepspeed = ["deepspeed>=0.10.0,<=0.16.9"]
[project.scripts]
llamafactory-cli = "llamafactory.cli:main"
lmf = "llamafactory.cli:main"
[project.urls]
Homepage = "https://github.com/hiyouga/LLaMA-Factory"
Repository = "https://github.com/hiyouga/LLaMA-Factory"
[tool.hatch.build.targets.wheel]
packages = ["src/llamafactory"]
[tool.hatch.version]
path = "src/llamafactory/extras/env.py"
pattern = "VERSION = \"(?P<version>[^\"]+)\""
[tool.ruff]
target-version = "py39"
target-version = "py311"
line-length = 119
indent-width = 4
[tool.ruff.lint]
ignore = [
"C408", # collection
"C901", # complex
"E501", # line too long
"E731", # lambda function
"E741", # ambiguous var name
"D100", # no doc public module
"D101", # no doc public class
"D102", # no doc public method
"D103", # no doc public function
"D104", # no doc public package
"D105", # no doc magic method
"D107", # no doc __init__
"C408", # collection
"C901", # complex
"E501", # line too long
"E731", # lambda function
"E741", # ambiguous var name
"UP007", # no upgrade union
"UP045", # no upgrade optional
"D100", # no doc public module
"D101", # no doc public class
"D102", # no doc public method
"D103", # no doc public function
"D104", # no doc public package
"D105", # no doc magic method
"D107", # no doc __init__
]
extend-select = [
"C", # complexity
@@ -73,23 +154,3 @@ indent-style = "space"
docstring-code-format = true
skip-magic-trailing-comma = false
line-ending = "auto"
[tool.uv]
conflicts = [
[
{ extra = "torch-npu" },
{ extra = "aqlm" },
],
[
{ extra = "torch-npu" },
{ extra = "vllm" },
],
[
{ extra = "torch-npu" },
{ extra = "sglang" },
],
[
{ extra = "vllm" },
{ extra = "sglang" },
],
]

View File

@@ -1,38 +0,0 @@
# core deps
transformers>=4.49.0,<=4.56.2,!=4.52.0; python_version < '3.10'
transformers>=4.49.0,<=4.57.1,!=4.52.0,!=4.57.0; python_version >= '3.10'
datasets>=2.16.0,<=4.0.0
accelerate>=1.3.0,<=1.11.0
peft>=0.14.0,<=0.17.1
trl>=0.8.6,<=0.9.6
# gui
gradio>=4.38.0,<=5.45.0
matplotlib>=3.7.0
tyro<0.9.0
# ops
einops
numpy<2.0.0
pandas>=2.0.0
scipy
# model and tokenizer
sentencepiece
tiktoken
modelscope>=1.14.0
hf-transfer
safetensors<=0.5.3
# python
fire
omegaconf
packaging
protobuf
pyyaml
pydantic<=2.10.6
# api
uvicorn
fastapi
sse-starlette
# media
av
librosa
# yanked
propcache!=0.4.0

View File

@@ -16,7 +16,6 @@
# limitations under the License.
import os
from typing import Optional
import fire
import torch
@@ -34,7 +33,7 @@ def convert_mca_to_hf(
output_path: str = "./output",
bf16: bool = False,
fp16: bool = False,
convert_model_max_length: Optional[int] = None,
convert_model_max_length: int | None = None,
):
"""Convert megatron checkpoint to HuggingFace format.
@@ -67,11 +66,11 @@ def convert(
output_path: str = "./output",
bf16: bool = False,
fp16: bool = False,
convert_model_max_length: Optional[int] = None,
convert_model_max_length: int | None = None,
tensor_model_parallel_size: int = 1,
pipeline_model_parallel_size: int = 1,
expert_model_parallel_size: int = 1,
virtual_pipeline_model_parallel_size: Optional[int] = None,
virtual_pipeline_model_parallel_size: int | None = None,
):
"""Convert checkpoint between MCA and HuggingFace formats.

View File

@@ -14,7 +14,7 @@
import json
from dataclasses import dataclass
from typing import Any, Literal, Optional
from typing import Any, Literal
import fire
import torch
@@ -61,7 +61,7 @@ def calculate_ppl(
dataset_dir: str = "data",
template: str = "default",
cutoff_len: int = 2048,
max_samples: Optional[int] = None,
max_samples: int | None = None,
train_on_prompt: bool = False,
):
r"""Calculate the ppl on the dataset of the pre-trained models.

View File

@@ -14,7 +14,6 @@
import gc
import json
from typing import Optional
import av
import fire
@@ -49,7 +48,7 @@ def vllm_infer(
dataset_dir: str = "data",
template: str = "default",
cutoff_len: int = 2048,
max_samples: Optional[int] = None,
max_samples: int | None = None,
vllm_config: str = "{}",
save_name: str = "generated_predictions.jsonl",
temperature: float = 0.95,
@@ -58,9 +57,9 @@ def vllm_infer(
max_new_tokens: int = 1024,
repetition_penalty: float = 1.0,
skip_special_tokens: bool = True,
default_system: Optional[str] = None,
default_system: str | None = None,
enable_thinking: bool = True,
seed: Optional[int] = None,
seed: int | None = None,
pipeline_parallel_size: int = 1,
image_max_pixels: int = 768 * 768,
image_min_pixels: int = 32 * 32,

116
setup.py
View File

@@ -1,116 +0,0 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import re
from setuptools import find_packages, setup
def get_version() -> str:
with open(os.path.join("src", "llamafactory", "extras", "env.py"), encoding="utf-8") as f:
file_content = f.read()
pattern = r"{}\W*=\W*\"([^\"]+)\"".format("VERSION")
(version,) = re.findall(pattern, file_content)
return version
def get_requires() -> list[str]:
with open("requirements.txt", encoding="utf-8") as f:
file_content = f.read()
lines = [line.strip() for line in file_content.strip().split("\n") if not line.startswith("#")]
return lines
def get_console_scripts() -> list[str]:
console_scripts = ["llamafactory-cli = llamafactory.cli:main"]
if os.getenv("ENABLE_SHORT_CONSOLE", "1").lower() in ["true", "y", "1"]:
console_scripts.append("lmf = llamafactory.cli:main")
return console_scripts
extra_require = {
"torch": ["torch>=2.0.0", "torchvision>=0.15.0"],
"torch-npu": ["torch==2.7.1", "torch-npu==2.7.1", "torchvision==0.22.1", "decorator"],
"metrics": ["nltk", "jieba", "rouge-chinese"],
"deepspeed": ["deepspeed>=0.10.0,<=0.16.9"],
"liger-kernel": ["liger-kernel>=0.5.5"],
"bitsandbytes": ["bitsandbytes>=0.39.0"],
"hqq": ["hqq"],
"eetq": ["eetq"],
"gptq": ["optimum>=1.24.0", "gptqmodel>=2.0.0"],
"aqlm": ["aqlm[gpu]>=1.1.0"],
"vllm": ["vllm>=0.4.3,<=0.11.0"],
"sglang": ["sglang[srt]>=0.4.5", "transformers==4.51.1"],
"galore": ["galore-torch"],
"apollo": ["apollo-torch"],
"badam": ["badam>=1.2.1"],
"adam-mini": ["adam-mini"],
"minicpm_v": [
"soundfile",
"torchvision",
"torchaudio",
"vector_quantize_pytorch",
"vocos",
"msgpack",
"referencing",
"jsonschema_specifications",
],
"openmind": ["openmind"],
"swanlab": ["swanlab"],
"fp8": ["torchao>=0.8.0", "accelerate>=1.10.0"],
"fp8-te": ["transformer_engine[pytorch]>=2.0.0", "accelerate>=1.10.0"],
"fp8-all": ["torchao>=0.8.0", "transformer_engine[pytorch]>=2.0.0", "accelerate>=1.10.0"],
"dev": ["pre-commit", "ruff", "pytest", "build"],
}
def main():
setup(
name="llamafactory",
version=get_version(),
author="hiyouga",
author_email="hiyouga@buaa.edu.cn",
description="Unified Efficient Fine-Tuning of 100+ LLMs",
long_description=open("README.md", encoding="utf-8").read(),
long_description_content_type="text/markdown",
keywords=["AI", "LLM", "GPT", "ChatGPT", "Llama", "Transformer", "DeepSeek", "Pytorch"],
license="Apache 2.0 License",
url="https://github.com/hiyouga/LLaMA-Factory",
package_dir={"": "src"},
packages=find_packages("src"),
python_requires=">=3.9.0",
install_requires=get_requires(),
extras_require=extra_require,
entry_points={"console_scripts": get_console_scripts()},
classifiers=[
"Development Status :: 4 - Beta",
"Intended Audience :: Developers",
"Intended Audience :: Education",
"Intended Audience :: Science/Research",
"License :: OSI Approved :: Apache Software License",
"Operating System :: OS Independent",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
],
)
if __name__ == "__main__":
main()

View File

@@ -16,7 +16,7 @@ import asyncio
import os
from contextlib import asynccontextmanager
from functools import partial
from typing import Annotated, Optional
from typing import Annotated
from ..chat import ChatModel
from ..extras.constants import EngineName
@@ -79,7 +79,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
api_key = os.getenv("API_KEY")
security = HTTPBearer(auto_error=False)
async def verify_api_key(auth: Annotated[Optional[HTTPAuthorizationCredentials], Depends(security)]):
async def verify_api_key(auth: Annotated[HTTPAuthorizationCredentials | None, Depends(security)]):
if api_key and (auth is None or auth.credentials != api_key):
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key.")

View File

@@ -14,10 +14,9 @@
import time
from enum import Enum, unique
from typing import Any, Optional, Union
from typing import Any, Literal
from pydantic import BaseModel, Field
from typing_extensions import Literal
@unique
@@ -61,7 +60,7 @@ class FunctionDefinition(BaseModel):
class FunctionAvailable(BaseModel):
type: Literal["function", "code_interpreter"] = "function"
function: Optional[FunctionDefinition] = None
function: FunctionDefinition | None = None
class FunctionCall(BaseModel):
@@ -77,35 +76,35 @@ class URL(BaseModel):
class MultimodalInputItem(BaseModel):
type: Literal["text", "image_url", "video_url", "audio_url"]
text: Optional[str] = None
image_url: Optional[URL] = None
video_url: Optional[URL] = None
audio_url: Optional[URL] = None
text: str | None = None
image_url: URL | None = None
video_url: URL | None = None
audio_url: URL | None = None
class ChatMessage(BaseModel):
role: Role
content: Optional[Union[str, list[MultimodalInputItem]]] = None
tool_calls: Optional[list[FunctionCall]] = None
content: str | list[MultimodalInputItem] | None = None
tool_calls: list[FunctionCall] | None = None
class ChatCompletionMessage(BaseModel):
role: Optional[Role] = None
content: Optional[str] = None
tool_calls: Optional[list[FunctionCall]] = None
role: Role | None = None
content: str | None = None
tool_calls: list[FunctionCall] | None = None
class ChatCompletionRequest(BaseModel):
model: str
messages: list[ChatMessage]
tools: Optional[list[FunctionAvailable]] = None
do_sample: Optional[bool] = None
temperature: Optional[float] = None
top_p: Optional[float] = None
tools: list[FunctionAvailable] | None = None
do_sample: bool | None = None
temperature: float | None = None
top_p: float | None = None
n: int = 1
presence_penalty: Optional[float] = None
max_tokens: Optional[int] = None
stop: Optional[Union[str, list[str]]] = None
presence_penalty: float | None = None
max_tokens: int | None = None
stop: str | list[str] | None = None
stream: bool = False
@@ -118,7 +117,7 @@ class ChatCompletionResponseChoice(BaseModel):
class ChatCompletionStreamResponseChoice(BaseModel):
index: int
delta: ChatCompletionMessage
finish_reason: Optional[Finish] = None
finish_reason: Finish | None = None
class ChatCompletionResponseUsage(BaseModel):
@@ -147,7 +146,7 @@ class ChatCompletionStreamResponse(BaseModel):
class ScoreEvaluationRequest(BaseModel):
model: str
messages: list[str]
max_length: Optional[int] = None
max_length: int | None = None
class ScoreEvaluationResponse(BaseModel):

View File

@@ -14,9 +14,9 @@
import asyncio
import os
from collections.abc import AsyncGenerator
from collections.abc import AsyncGenerator, Callable
from threading import Thread
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
from typing import TYPE_CHECKING, Any, Optional, Union
import torch
from transformers import GenerationConfig, TextIteratorStreamer

View File

@@ -15,8 +15,8 @@
import uuid
from collections.abc import AsyncGenerator, AsyncIterator
from typing import TYPE_CHECKING, Any, Optional, Union
from packaging import version
from packaging import version
from typing_extensions import override
from ..data import get_template_and_fix_tokenizer

View File

@@ -15,7 +15,7 @@ import json
import os
from abc import abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Optional, Union
from typing import TYPE_CHECKING, Any, Union
from ..extras import logging
from .data_utils import Role
@@ -40,7 +40,7 @@ class DatasetConverter:
dataset_attr: "DatasetAttr"
data_args: "DataArguments"
def _find_medias(self, medias: Union["MediaType", list["MediaType"], None]) -> Optional[list["MediaType"]]:
def _find_medias(self, medias: Union["MediaType", list["MediaType"], None]) -> list["MediaType"] | None:
r"""Optionally concatenate media path to media dir when loading from local disk."""
if medias is None:
return None

View File

@@ -81,41 +81,48 @@ def split_dataset(
eval_dataset: Optional[Union["Dataset", "IterableDataset", dict[str, "Dataset"]]],
data_args: "DataArguments",
seed: int,
) -> "DatasetDict":
r"""Split the dataset and returns a dataset dict containing train set and validation set.
) -> tuple[dict, dict]:
r"""Split the dataset and returns two dicts containing train set and validation set.
Support both map dataset and iterable dataset.
Returns:
train_dict: Dictionary containing training data with key "train"
eval_dict: Dictionary containing evaluation data with keys "validation" or "validation_{name}"
"""
if eval_dataset is not None and data_args.val_size > 1e-6:
raise ValueError("Cannot specify `val_size` if `eval_dataset` is not None.")
dataset_dict = {}
# the train and eval better to in dict dtype and separately return for cpode clearly and good handle outside
train_dict, eval_dict = {}, {}
if dataset is not None:
if data_args.streaming:
dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=seed)
if data_args.val_size > 1e-6:
if data_args.streaming:
dataset_dict["validation"] = dataset.take(int(data_args.val_size))
dataset_dict["train"] = dataset.skip(int(data_args.val_size))
eval_dict["validation"] = dataset.take(int(data_args.val_size))
train_dict["train"] = dataset.skip(int(data_args.val_size))
else:
val_size = int(data_args.val_size) if data_args.val_size > 1 else data_args.val_size
dataset_dict = dataset.train_test_split(test_size=val_size, seed=seed)
dataset = dataset.train_test_split(test_size=val_size, seed=seed)
dataset_dict = {"train": dataset["train"], "validation": dataset["test"]}
split_result = dataset.train_test_split(test_size=val_size, seed=seed)
train_dict["train"] = split_result["train"]
eval_dict["validation"] = split_result["test"]
else:
dataset_dict["train"] = dataset
train_dict["train"] = dataset
if eval_dataset is not None:
if isinstance(eval_dataset, dict):
dataset_dict.update({f"validation_{name}": data for name, data in eval_dataset.items()})
for name, data in eval_dataset.items():
eval_dict[f"validation_{name}"] = data
else:
if data_args.streaming:
eval_dataset = eval_dataset.shuffle(buffer_size=data_args.buffer_size, seed=seed)
dataset_dict["validation"] = eval_dataset
eval_dict["validation"] = eval_dataset
return DatasetDict(dataset_dict)
return train_dict, eval_dict
def get_dataset_module(dataset: Union["Dataset", "DatasetDict"]) -> "DatasetModule":

View File

@@ -16,7 +16,6 @@ import json
import re
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Optional, Union
from typing_extensions import override
@@ -27,14 +26,14 @@ from .tool_utils import FunctionCall, get_tool_utils
@dataclass
class Formatter(ABC):
slots: SLOTS = field(default_factory=list)
tool_format: Optional[str] = None
tool_format: str | None = None
@abstractmethod
def apply(self, **kwargs) -> SLOTS:
r"""Forms a list of slots according to the inputs to encode."""
...
def extract(self, content: str) -> Union[str, list["FunctionCall"]]:
def extract(self, content: str) -> str | list["FunctionCall"]:
r"""Extract a list of tuples from the response message if using tools.
Each tuple consists of function name and function arguments.
@@ -97,31 +96,46 @@ class FunctionFormatter(StringFormatter):
@override
def apply(self, **kwargs) -> SLOTS:
content: str = kwargs.pop("content")
thought_words, thought = kwargs.pop("thought_words", None), None
if thought_words and len(thought_words) == 2:
regex = re.compile(rf"{re.escape(thought_words[0])}(.*?){re.escape(thought_words[1])}", re.DOTALL)
thought = re.search(regex, content)
thought_words = kwargs.pop("thought_words", None)
tool_call_words = kwargs.pop("tool_call_words", None)
if thought:
content = content.replace(thought.group(0), "")
def _parse_functions(json_content: str) -> list["FunctionCall"]:
try:
tool_calls = json.loads(json_content)
if not isinstance(tool_calls, list): # parallel function call
tool_calls = [tool_calls]
functions: list[FunctionCall] = []
try:
tool_calls = json.loads(content)
if not isinstance(tool_calls, list): # parallel function call
tool_calls = [tool_calls]
return [FunctionCall(tc["name"], json.dumps(tc["arguments"], ensure_ascii=False)) for tc in tool_calls]
except json.JSONDecodeError:
raise RuntimeError(f"Invalid JSON format in function message: {str([content])}.")
for tool_call in tool_calls:
functions.append(
FunctionCall(tool_call["name"], json.dumps(tool_call["arguments"], ensure_ascii=False))
)
tool_call_match = None
if tool_call_words and len(tool_call_words) == 2:
tool_call_regex = re.compile(
rf"{re.escape(tool_call_words[0])}(.*?){re.escape(tool_call_words[1])}", re.DOTALL
)
tool_call_match = re.search(tool_call_regex, content)
except json.JSONDecodeError:
raise RuntimeError(f"Invalid JSON format in function message: {str([content])}.") # flat string
if tool_call_match is None:
thought_match = None
if thought_words and len(thought_words) == 2:
regex = re.compile(rf"{re.escape(thought_words[0])}(.*?){re.escape(thought_words[1])}", re.DOTALL)
thought_match = re.search(regex, content)
function_str = self.tool_utils.function_formatter(functions)
if thought:
function_str = thought.group(0) + function_str
if thought_match:
json_part = content.replace(thought_match.group(0), "")
else:
json_part = content
functions = _parse_functions(json_part)
function_str = self.tool_utils.function_formatter(functions)
if thought_match:
function_str = thought_match.group(0) + function_str
else:
thought_content = content.replace(tool_call_match.group(0), "")
functions = _parse_functions(tool_call_match.group(1))
function_str = self.tool_utils.function_formatter(functions)
function_str = thought_content + function_str
return super().apply(content=function_str)
@@ -141,5 +155,5 @@ class ToolFormatter(Formatter):
raise RuntimeError(f"Invalid JSON format in tool description: {str([content])}.") # flat string
@override
def extract(self, content: str) -> Union[str, list["FunctionCall"]]:
def extract(self, content: str) -> str | list["FunctionCall"]:
return self.tool_utils.tool_extractor(content)

View File

@@ -16,7 +16,7 @@ import os
from typing import TYPE_CHECKING, Literal, Optional, Union
import numpy as np
from datasets import Dataset, load_dataset, load_from_disk
from datasets import Dataset, DatasetDict, load_dataset, load_from_disk
from ..extras import logging
from ..extras.constants import FILEEXT2TYPE
@@ -162,13 +162,13 @@ def _load_single_dataset(
def _get_merged_dataset(
dataset_names: Optional[list[str]],
dataset_names: list[str] | None,
model_args: "ModelArguments",
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
stage: Literal["pt", "sft", "rm", "ppo", "kto"],
return_dict: bool = False,
) -> Optional[Union["Dataset", "IterableDataset", dict[str, "Dataset"]]]:
) -> Union["Dataset", "IterableDataset", dict[str, "Dataset"]] | None:
r"""Return the merged datasets in the standard format."""
if dataset_names is None:
return None
@@ -227,7 +227,7 @@ def _get_dataset_processor(
def _get_preprocessed_dataset(
dataset: Optional[Union["Dataset", "IterableDataset"]],
dataset: Union["Dataset", "IterableDataset"] | None,
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
stage: Literal["pt", "sft", "rm", "ppo", "kto"],
@@ -235,7 +235,7 @@ def _get_preprocessed_dataset(
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"] = None,
is_eval: bool = False,
) -> Optional[Union["Dataset", "IterableDataset"]]:
) -> Union["Dataset", "IterableDataset"] | None:
r"""Preprocesses the dataset, including format checking and tokenization."""
if dataset is None:
return None
@@ -311,20 +311,22 @@ def get_dataset(
)
with training_args.main_process_first(desc="pre-process dataset", local=(not data_args.data_shared_file_system)):
dataset = _get_preprocessed_dataset(
dataset, data_args, training_args, stage, template, tokenizer, processor, is_eval=False
)
if isinstance(eval_dataset, dict):
for eval_name, eval_data in eval_dataset.items():
eval_dataset[eval_name] = _get_preprocessed_dataset(
eval_data, data_args, training_args, stage, template, tokenizer, processor, is_eval=True
)
else:
eval_dataset = _get_preprocessed_dataset(
eval_dataset, data_args, training_args, stage, template, tokenizer, processor, is_eval=True
# move front to make sure eval_dataset(if contain or split) can preprocessed appropriately
train_dict, eval_dict = split_dataset(dataset, eval_dataset, data_args, seed=training_args.seed)
if "train" in train_dict:
train_dict["train"] = _get_preprocessed_dataset(
train_dict["train"], data_args, training_args, stage, template, tokenizer, processor, is_eval=False
)
dataset_dict = split_dataset(dataset, eval_dataset, data_args, seed=training_args.seed)
for key in eval_dict:
eval_dict[key] = _get_preprocessed_dataset(
eval_dict[key], data_args, training_args, stage, template, tokenizer, processor, is_eval=True
)
# Combine train and eval dictionaries
dataset_dict = DatasetDict({**train_dict, **eval_dict})
if data_args.tokenized_path is not None: # save tokenized dataset to disk
if training_args.should_save:
dataset_dict.save_to_disk(data_args.tokenized_path)

View File

@@ -22,28 +22,20 @@ import re
from copy import deepcopy
from dataclasses import dataclass
from io import BytesIO
from typing import TYPE_CHECKING, BinaryIO, Literal, Optional, TypedDict, Union
from typing import TYPE_CHECKING, BinaryIO, Literal, NotRequired, Optional, TypedDict, Union
import numpy as np
import torch
import torchaudio
from transformers.image_utils import get_image_size, is_valid_image, to_numpy_array
from transformers.models.mllama.processing_mllama import (
convert_sparse_cross_attention_mask_to_dense,
get_cross_attention_token_mask,
)
from typing_extensions import NotRequired, override
from typing_extensions import override
from ..extras.constants import AUDIO_PLACEHOLDER, IGNORE_INDEX, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
from ..extras.packages import (
is_librosa_available,
is_pillow_available,
is_pyav_available,
is_transformers_version_greater_than,
)
if is_librosa_available():
import librosa
from ..extras.packages import is_pillow_available, is_pyav_available, is_transformers_version_greater_than
if is_pillow_available():
@@ -71,8 +63,8 @@ if TYPE_CHECKING:
from transformers.video_processing_utils import BaseVideoProcessor
class EncodedImage(TypedDict):
path: Optional[str]
bytes: Optional[bytes]
path: str | None
bytes: bytes | None
ImageInput = Union[str, bytes, EncodedImage, BinaryIO, ImageObject]
VideoInput = Union[str, BinaryIO, list[list[ImageInput]]]
@@ -152,9 +144,9 @@ def _check_video_is_nested_images(video: "VideoInput") -> bool:
@dataclass
class MMPluginMixin:
image_token: Optional[str]
video_token: Optional[str]
audio_token: Optional[str]
image_token: str | None
video_token: str | None
audio_token: str | None
expand_mm_tokens: bool = True
def _validate_input(
@@ -316,7 +308,14 @@ class MMPluginMixin:
results, sampling_rates = [], []
for audio in audios:
if not isinstance(audio, np.ndarray):
audio, sampling_rate = librosa.load(audio, sr=sampling_rate)
audio, sr = torchaudio.load(audio)
if audio.shape[0] > 1:
audio = audio.mean(dim=0, keepdim=True)
if sr != sampling_rate:
audio = torchaudio.functional.resample(audio, sr, sampling_rate)
audio = audio.squeeze(0).numpy()
results.append(audio)
sampling_rates.append(sampling_rate)
@@ -329,7 +328,7 @@ class MMPluginMixin:
videos: list["VideoInput"],
audios: list["AudioInput"],
processor: "MMProcessor",
imglens: Optional[list[int]] = None,
imglens: list[int] | None = None,
) -> dict[str, "torch.Tensor"]:
r"""Process visual inputs.
@@ -427,13 +426,13 @@ class BasePlugin(MMPluginMixin):
def process_token_ids(
self,
input_ids: list[int],
labels: Optional[list[int]],
labels: list[int] | None,
images: list["ImageInput"],
videos: list["VideoInput"],
audios: list["AudioInput"],
tokenizer: "PreTrainedTokenizer",
processor: Optional["MMProcessor"],
) -> tuple[list[int], Optional[list[int]]]:
) -> tuple[list[int], list[int] | None]:
r"""Pre-process token ids after tokenization for VLMs."""
self._validate_input(processor, images, videos, audios)
return input_ids, labels
@@ -465,6 +464,7 @@ class BasePlugin(MMPluginMixin):
self._validate_input(processor, images, videos, audios)
return self._get_mm_inputs(images, videos, audios, processor)
@dataclass
class ErnieVLPlugin(BasePlugin):
@override
@@ -479,21 +479,39 @@ class ErnieVLPlugin(BasePlugin):
self._validate_input(processor, images, videos, audios)
self._validate_messages(messages, images, videos, audios)
messages = deepcopy(messages)
image_processor: BaseImageProcessor = getattr(processor, "image_processor")
merge_length: int = getattr(image_processor, "merge_size") ** 2
if self.expand_mm_tokens:
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
image_grid_thw = mm_inputs.get("image_grid_thw", [])
video_grid_thw = mm_inputs.get("video_grid_thw", [])
else:
image_grid_thw = [None] * len(images)
video_grid_thw = [None] * len(videos)
image_idx, video_idx = 0, 0
for message in messages:
content = message["content"]
image_token = self.image_token or "<|image@placeholder|>"
video_token = self.video_token or "<|video@placeholder|>"
image_token = self.image_token or "<|IMAGE_PLACEHOLDER|>"
video_token = self.video_token or "<|VIDEO_PLACEHOLDER|>"
while IMAGE_PLACEHOLDER in content:
image_seqlen = image_grid_thw[image_idx].prod() // merge_length if self.expand_mm_tokens else 1
content = content.replace(
IMAGE_PLACEHOLDER,
f"Picture {image_idx + 1}:<|IMAGE_START|>{image_token * image_seqlen}<|IMAGE_END|>",
1,
)
image_idx += 1
content = content.replace(
IMAGE_PLACEHOLDER, f"Picture {image_idx}:<|IMAGE_START|>{image_token}<|IMAGE_END|>", 1
)
while VIDEO_PLACEHOLDER in content:
video_idx += 1
video_seqlen = video_grid_thw[video_idx].prod() // merge_length if self.expand_mm_tokens else 1
content = content.replace(
VIDEO_PLACEHOLDER, f"Video {video_idx}:<|VIDEO_START|>{video_token}<|VIDEO_END|>", 1
VIDEO_PLACEHOLDER,
f"Video {video_idx + 1}:<|VIDEO_START|>{video_token * video_seqlen}<|VIDEO_END|>",
1,
)
video_idx += 1
message["content"] = content
return messages
@@ -1287,13 +1305,13 @@ class PaliGemmaPlugin(BasePlugin):
def process_token_ids(
self,
input_ids: list[int],
labels: Optional[list[int]],
labels: list[int] | None,
images: list["ImageInput"],
videos: list["VideoInput"],
audios: list["AudioInput"],
tokenizer: "PreTrainedTokenizer",
processor: Optional["MMProcessor"],
) -> tuple[list[int], Optional[list[int]]]:
) -> tuple[list[int], list[int] | None]:
self._validate_input(processor, images, videos, audios)
num_images = len(images)
image_seqlen = processor.image_seq_length if self.expand_mm_tokens else 0 # skip mm token
@@ -1623,7 +1641,12 @@ class Qwen3VLPlugin(Qwen2VLPlugin):
for video, duration in zip(videos["videos"], videos["durations"])
]
mm_inputs.update(
video_processor(videos=videos["videos"], video_metadata=video_metadata, return_metadata=True)
video_processor(
videos=videos["videos"],
video_metadata=video_metadata,
fps=getattr(processor, "video_fps", 2.0),
return_metadata=True,
)
)
temporal_patch_size: int = getattr(image_processor, "temporal_patch_size", 2)
if "second_per_grid_ts" in processor.model_input_names:
@@ -2103,9 +2126,9 @@ def register_mm_plugin(name: str, plugin_class: type["BasePlugin"]) -> None:
def get_mm_plugin(
name: str,
image_token: Optional[str] = None,
video_token: Optional[str] = None,
audio_token: Optional[str] = None,
image_token: str | None = None,
video_token: str | None = None,
audio_token: str | None = None,
**kwargs,
) -> "BasePlugin":
r"""Get plugin for multimodal inputs."""

View File

@@ -15,7 +15,7 @@
import json
import os
from dataclasses import dataclass
from typing import Any, Literal, Optional, Union
from typing import Any, Literal
from huggingface_hub import hf_hub_download
@@ -33,40 +33,40 @@ class DatasetAttr:
formatting: Literal["alpaca", "sharegpt", "openai"] = "alpaca"
ranking: bool = False
# extra configs
subset: Optional[str] = None
subset: str | None = None
split: str = "train"
folder: Optional[str] = None
num_samples: Optional[int] = None
folder: str | None = None
num_samples: int | None = None
# common columns
system: Optional[str] = None
tools: Optional[str] = None
images: Optional[str] = None
videos: Optional[str] = None
audios: Optional[str] = None
system: str | None = None
tools: str | None = None
images: str | None = None
videos: str | None = None
audios: str | None = None
# dpo columns
chosen: Optional[str] = None
rejected: Optional[str] = None
kto_tag: Optional[str] = None
chosen: str | None = None
rejected: str | None = None
kto_tag: str | None = None
# alpaca columns
prompt: Optional[str] = "instruction"
query: Optional[str] = "input"
response: Optional[str] = "output"
history: Optional[str] = None
prompt: str | None = "instruction"
query: str | None = "input"
response: str | None = "output"
history: str | None = None
# sharegpt columns
messages: Optional[str] = "conversations"
messages: str | None = "conversations"
# sharegpt tags
role_tag: Optional[str] = "from"
content_tag: Optional[str] = "value"
user_tag: Optional[str] = "human"
assistant_tag: Optional[str] = "gpt"
observation_tag: Optional[str] = "observation"
function_tag: Optional[str] = "function_call"
system_tag: Optional[str] = "system"
role_tag: str | None = "from"
content_tag: str | None = "value"
user_tag: str | None = "human"
assistant_tag: str | None = "gpt"
observation_tag: str | None = "observation"
function_tag: str | None = "function_call"
system_tag: str | None = "system"
def __repr__(self) -> str:
return self.dataset_name
def set_attr(self, key: str, obj: dict[str, Any], default: Optional[Any] = None) -> None:
def set_attr(self, key: str, obj: dict[str, Any], default: Any | None = None) -> None:
setattr(self, key, obj.get(key, default))
def join(self, attr: dict[str, Any]) -> None:
@@ -90,7 +90,7 @@ class DatasetAttr:
self.set_attr(tag, attr["tags"])
def get_dataset_list(dataset_names: Optional[list[str]], dataset_dir: Union[str, dict]) -> list["DatasetAttr"]:
def get_dataset_list(dataset_names: list[str] | None, dataset_dir: str | dict) -> list["DatasetAttr"]:
r"""Get the attributes of the datasets."""
if dataset_names is None:
dataset_names = []

View File

@@ -49,6 +49,7 @@ class Template:
default_system: str
stop_words: list[str]
thought_words: tuple[str, str]
tool_call_words: tuple[str, str]
efficient_eos: bool
replace_eos: bool
replace_jinja_template: bool
@@ -156,7 +157,9 @@ class Template:
elif message["role"] == Role.OBSERVATION:
elements += self.format_observation.apply(content=message["content"])
elif message["role"] == Role.FUNCTION:
elements += self.format_function.apply(content=message["content"], thought_words=self.thought_words)
elements += self.format_function.apply(
content=message["content"], thought_words=self.thought_words, tool_call_words=self.tool_call_words
)
else:
raise NotImplementedError("Unexpected role: {}".format(message["role"]))
@@ -199,9 +202,12 @@ class Template:
logger.info_rank0(f"Add pad token: {tokenizer.pad_token}")
if stop_words:
num_added_tokens = tokenizer.add_special_tokens(
dict(additional_special_tokens=stop_words), replace_additional_special_tokens=False
)
try:
num_added_tokens = tokenizer.add_special_tokens(
dict(additional_special_tokens=stop_words), replace_additional_special_tokens=False
)
except TypeError:
num_added_tokens = tokenizer.add_special_tokens(dict(additional_special_tokens=stop_words))
logger.info_rank0("Add {} to stop words.".format(",".join(stop_words)))
if num_added_tokens > 0:
logger.warning_rank0("New tokens have been added, make sure `resize_vocab` is True.")
@@ -468,6 +474,7 @@ def register_template(
default_system: str = "",
stop_words: Optional[list[str]] = None,
thought_words: Optional[tuple[str, str]] = None,
tool_call_words: Optional[tuple[str, str]] = None,
efficient_eos: bool = False,
replace_eos: bool = False,
replace_jinja_template: bool = False,
@@ -519,6 +526,7 @@ def register_template(
default_system=default_system,
stop_words=stop_words or [],
thought_words=thought_words or ("<think>\n", "\n</think>\n\n"),
tool_call_words=tool_call_words or ("<tool_call>", "</tool_call>"),
efficient_eos=efficient_eos,
replace_eos=replace_eos,
replace_jinja_template=replace_jinja_template,
@@ -580,6 +588,7 @@ def parse_template(tokenizer: "PreTrainedTokenizer") -> "Template":
default_system=default_system,
stop_words=[],
thought_words=("<think>\n", "\n</think>\n\n"),
tool_call_words=("<tool_call>", "</tool_call>"),
efficient_eos=False,
replace_eos=False,
replace_jinja_template=False,
@@ -972,7 +981,7 @@ register_template(
replace_eos=True,
replace_jinja_template=True,
template_class=ReasoningTemplate,
mm_plugin=get_mm_plugin(name="ernie_vl", image_token="<|image@placeholder|>", video_token="<|video@placeholder|>"),
mm_plugin=get_mm_plugin(name="ernie_vl", image_token="<|IMAGE_PLACEHOLDER|>", video_token="<|VIDEO_PLACEHOLDER|>"),
)
@@ -1125,7 +1134,7 @@ register_template(
# copied from glm4 template
register_template(
name="glm4v_moe",
name="glm4_5v",
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]),
format_assistant=StringFormatter(slots=["\n{{content}}"]),
format_system=StringFormatter(slots=["<|system|>\n{{content}}"]),
@@ -1157,7 +1166,7 @@ register_template(
register_template(
name="gpt",
name="gpt_oss",
format_user=StringFormatter(slots=["<|start|>user<|message|>{{content}}<|end|><|start|>assistant"]),
format_assistant=StringFormatter(slots=["{{content}}<|end|>"]),
format_system=StringFormatter(slots=["<|start|>system<|message|>{{content}}<|end|>"]),
@@ -1601,6 +1610,26 @@ register_template(
template_class=ReasoningTemplate,
)
# copied from qwen template
register_template(
name="mimo_v2",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen"),
format_observation=StringFormatter(
slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
),
format_tools=ToolFormatter(tool_format="qwen"),
default_system="You are MiMo, a helpful AI assistant engineered by Xiaomi.",
stop_words=["<|im_end|>"],
replace_eos=True,
thought_words=("<think>", "</think>"),
template_class=ReasoningTemplate,
)
# copied from qwen2vl
register_template(
name="mimo_vl",
@@ -1684,6 +1713,19 @@ register_template(
)
register_template(
name="ministral3",
format_user=StringFormatter(slots=["[INST]{{content}}[/INST]"]),
format_system=StringFormatter(slots=["{{content}}\n\n"]),
format_function=FunctionFormatter(slots=["[TOOL_CALLS]{{content}}", {"eos_token"}], tool_format="mistral"),
format_observation=StringFormatter(slots=["""[TOOL_RESULTS]{"content": {{content}}}[/TOOL_RESULTS]"""]),
format_tools=ToolFormatter(tool_format="mistral"),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
template_class=Llama2Template,
mm_plugin=get_mm_plugin(name="pixtral", image_token="[IMG]"),
)
register_template(
name="olmo",
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>\n"]),

View File

@@ -15,7 +15,6 @@
import os
from collections import OrderedDict, defaultdict
from enum import Enum, unique
from typing import Optional
from peft.utils import SAFETENSORS_WEIGHTS_NAME as SAFE_ADAPTER_WEIGHTS_NAME
from peft.utils import WEIGHTS_NAME as ADAPTER_WEIGHTS_NAME
@@ -114,6 +113,7 @@ class AttentionFunction(str, Enum):
DISABLED = "disabled"
SDPA = "sdpa"
FA2 = "fa2"
FA3 = "fa3"
class EngineName(str, Enum):
@@ -141,6 +141,7 @@ class QuantizationMethod(str, Enum):
EETQ = "eetq"
HQQ = "hqq"
MXFP4 = "mxfp4"
FP8 = "fp8"
class RopeScaling(str, Enum):
@@ -152,7 +153,7 @@ class RopeScaling(str, Enum):
def register_model_group(
models: dict[str, dict[DownloadSource, str]],
template: Optional[str] = None,
template: str | None = None,
multimodal: bool = False,
) -> None:
for name, path in models.items():
@@ -1003,9 +1004,17 @@ register_model_group(
"GLM-4.5V-Air-Thinking": {
DownloadSource.DEFAULT: "zai-org/GLM-4.5V",
DownloadSource.MODELSCOPE: "ZhipuAI/GLM-4.5V",
}
},
"GLM-4.6V": {
DownloadSource.DEFAULT: "zai-org/GLM-4.6V",
DownloadSource.MODELSCOPE: "ZhipuAI/GLM-4.6V",
},
"GLM-4.6V-Flash": {
DownloadSource.DEFAULT: "zai-org/GLM-4.6V-Flash",
DownloadSource.MODELSCOPE: "ZhipuAI/GLM-4.6V-Flash",
},
},
template="glm4v_moe",
template="glm4_5v",
multimodal=True,
)
@@ -1058,7 +1067,7 @@ register_model_group(
DownloadSource.MODELSCOPE: "openai/gpt-oss-120b",
},
},
template="gpt",
template="gpt_oss",
)
@@ -1794,6 +1803,21 @@ register_model_group(
)
register_model_group(
models={
"MiMo-V2-Flash-Base": {
DownloadSource.DEFAULT: "XiaomiMiMo/MiMo-V2-Flash-Base",
DownloadSource.MODELSCOPE: "XiaomiMiMo/MiMo-V2-Flash-Base",
},
"MiMo-V2-Flash": {
DownloadSource.DEFAULT: "XiaomiMiMo/MiMo-V2-Flash",
DownloadSource.MODELSCOPE: "XiaomiMiMo/MiMo-V2-Flash",
},
},
template="mimo_v2",
)
register_model_group(
models={
"MiMo-7B-VL-RL": {
@@ -1818,7 +1842,7 @@ register_model_group(
},
"MiMo-VL-7B-SFT-2508": {
DownloadSource.DEFAULT: "XiaomiMiMo/MiMo-VL-7B-SFT-2508",
DownloadSource.DEFAULT: "XiaomiMiMo/MiMo-VL-7B-SFT-2508",
DownloadSource.MODELSCOPE: "XiaomiMiMo/MiMo-VL-7B-SFT-2508",
},
},
template="qwen2_vl",
@@ -1969,6 +1993,37 @@ register_model_group(
template="mistral",
)
register_model_group(
models={
"Ministral-3-3B-Base-2512": {
DownloadSource.DEFAULT: "mistralai/Ministral-3-3B-Base-2512",
DownloadSource.MODELSCOPE: "mistralai/Ministral-3-3B-Base-2512",
},
"Ministral-3-8B-Base-2512": {
DownloadSource.DEFAULT: "mistralai/Ministral-3-8B-Base-2512",
DownloadSource.MODELSCOPE: "mistralai/Ministral-3-8B-Base-2512",
},
"Ministral-3-14B-Base-2512": {
DownloadSource.DEFAULT: "mistralai/Ministral-3-14B-Base-2512",
DownloadSource.MODELSCOPE: "mistralai/Ministral-3-14B-Base-2512",
},
"Ministral-3-3B-Instruct-2512": {
DownloadSource.DEFAULT: "mistralai/Ministral-3-3B-Instruct-2512",
DownloadSource.MODELSCOPE: "mistralai/Ministral-3-3B-Instruct-2512",
},
"Ministral-3-8B-Instruct-2512": {
DownloadSource.DEFAULT: "mistralai/Ministral-3-8B-Instruct-2512",
DownloadSource.MODELSCOPE: "mistralai/Ministral-3-8B-Instruct-2512",
},
"Ministral-3-14B-Instruct-2512": {
DownloadSource.DEFAULT: "mistralai/Ministral-3-14B-Instruct-2512",
DownloadSource.MODELSCOPE: "mistralai/Ministral-3-14B-Instruct-2512",
},
},
template="ministral3",
multimodal=True,
)
register_model_group(
models={
@@ -3492,6 +3547,17 @@ register_model_group(
)
register_model_group(
models={
"VibeThinker-1.5B": {
DownloadSource.DEFAULT: "WeiboAI/VibeThinker-1.5B",
DownloadSource.MODELSCOPE: "WeiboAI/VibeThinker-1.5B",
},
},
template="qwen3",
)
register_model_group(
models={
"Vicuna-v1.5-7B-Chat": {

View File

@@ -117,7 +117,7 @@ def _configure_library_root_logger() -> None:
library_root_logger.propagate = False
def get_logger(name: Optional[str] = None) -> "_Logger":
def get_logger(name: str | None = None) -> "_Logger":
r"""Return a logger with the specified name. It it not supposed to be accessed externally."""
if name is None:
name = _get_library_name()

View File

@@ -332,3 +332,7 @@ def fix_proxy(ipv6_enabled: bool = False) -> None:
if ipv6_enabled:
os.environ.pop("http_proxy", None)
os.environ.pop("HTTP_PROXY", None)
os.environ.pop("https_proxy", None)
os.environ.pop("HTTPS_PROXY", None)
os.environ.pop("all_proxy", None)
os.environ.pop("ALL_PROXY", None)

View File

@@ -16,22 +16,22 @@
# limitations under the License.
from dataclasses import asdict, dataclass, field
from typing import Any, Literal, Optional
from typing import Any, Literal
@dataclass
class DataArguments:
r"""Arguments pertaining to what data we are going to input our model for training and evaluation."""
template: Optional[str] = field(
template: str | None = field(
default=None,
metadata={"help": "Which template to use for constructing prompts in training and inference."},
)
dataset: Optional[str] = field(
dataset: str | None = field(
default=None,
metadata={"help": "The name of dataset(s) to use for training. Use commas to separate multiple datasets."},
)
eval_dataset: Optional[str] = field(
eval_dataset: str | None = field(
default=None,
metadata={"help": "The name of dataset(s) to use for evaluation. Use commas to separate multiple datasets."},
)
@@ -39,7 +39,7 @@ class DataArguments:
default="data",
metadata={"help": "Path to the folder containing the datasets."},
)
media_dir: Optional[str] = field(
media_dir: str | None = field(
default=None,
metadata={"help": "Path to the folder containing the images, videos or audios. Defaults to `dataset_dir`."},
)
@@ -67,7 +67,7 @@ class DataArguments:
default="concat",
metadata={"help": "Strategy to use in dataset mixing (concat/interleave) (undersampling/oversampling)."},
)
interleave_probs: Optional[str] = field(
interleave_probs: str | None = field(
default=None,
metadata={"help": "Probabilities to sample data from datasets. Use commas to separate multiple datasets."},
)
@@ -79,15 +79,15 @@ class DataArguments:
default=1000,
metadata={"help": "The number of examples in one group in pre-processing."},
)
preprocessing_num_workers: Optional[int] = field(
preprocessing_num_workers: int | None = field(
default=None,
metadata={"help": "The number of processes to use for the pre-processing."},
)
max_samples: Optional[int] = field(
max_samples: int | None = field(
default=None,
metadata={"help": "For debugging purposes, truncate the number of examples for each dataset."},
)
eval_num_beams: Optional[int] = field(
eval_num_beams: int | None = field(
default=None,
metadata={"help": "Number of beams to use for evaluation. This argument will be passed to `model.generate`"},
)
@@ -103,7 +103,7 @@ class DataArguments:
default=False,
metadata={"help": "Whether or not to evaluate on each dataset separately."},
)
packing: Optional[bool] = field(
packing: bool | None = field(
default=None,
metadata={"help": "Enable sequences packing in training. Will automatically enable in pre-training."},
)
@@ -111,19 +111,19 @@ class DataArguments:
default=False,
metadata={"help": "Enable sequence packing without cross-attention."},
)
tool_format: Optional[str] = field(
tool_format: str | None = field(
default=None,
metadata={"help": "Tool format to use for constructing function calling examples."},
)
default_system: Optional[str] = field(
default_system: str | None = field(
default=None,
metadata={"help": "Override the default system message in the template."},
)
enable_thinking: Optional[bool] = field(
enable_thinking: bool | None = field(
default=True,
metadata={"help": "Whether or not to enable thinking mode for reasoning models."},
)
tokenized_path: Optional[str] = field(
tokenized_path: str | None = field(
default=None,
metadata={
"help": (

View File

@@ -14,7 +14,7 @@
import os
from dataclasses import dataclass, field
from typing import Literal, Optional
from typing import Literal
from datasets import DownloadMode
@@ -46,7 +46,7 @@ class EvaluationArguments:
default=5,
metadata={"help": "Number of examplars for few-shot learning."},
)
save_dir: Optional[str] = field(
save_dir: str | None = field(
default=None,
metadata={"help": "Path to save the evaluation results."},
)

View File

@@ -13,7 +13,7 @@
# limitations under the License.
from dataclasses import asdict, dataclass, field
from typing import Any, Literal, Optional
from typing import Any, Literal
@dataclass
@@ -40,7 +40,7 @@ class FreezeArguments:
)
},
)
freeze_extra_modules: Optional[str] = field(
freeze_extra_modules: str | None = field(
default=None,
metadata={
"help": (
@@ -56,7 +56,7 @@ class FreezeArguments:
class LoraArguments:
r"""Arguments pertaining to the LoRA training."""
additional_target: Optional[str] = field(
additional_target: str | None = field(
default=None,
metadata={
"help": (
@@ -66,7 +66,7 @@ class LoraArguments:
)
},
)
lora_alpha: Optional[int] = field(
lora_alpha: int | None = field(
default=None,
metadata={"help": "The scale factor for LoRA fine-tuning (default: lora_rank * 2)."},
)
@@ -88,7 +88,7 @@ class LoraArguments:
)
},
)
loraplus_lr_ratio: Optional[float] = field(
loraplus_lr_ratio: float | None = field(
default=None,
metadata={"help": "LoRA plus learning rate ratio (lr_B / lr_A)."},
)
@@ -126,7 +126,7 @@ class LoraArguments:
class OFTArguments:
r"""Arguments pertaining to the OFT training."""
additional_target: Optional[str] = field(
additional_target: str | None = field(
default=None,
metadata={
"help": (
@@ -220,27 +220,27 @@ class RLHFArguments:
default=False,
metadata={"help": "Whiten the rewards before compute advantages in PPO training."},
)
ref_model: Optional[str] = field(
ref_model: str | None = field(
default=None,
metadata={"help": "Path to the reference model used for the PPO or DPO training."},
)
ref_model_adapters: Optional[str] = field(
ref_model_adapters: str | None = field(
default=None,
metadata={"help": "Path to the adapters of the reference model."},
)
ref_model_quantization_bit: Optional[int] = field(
ref_model_quantization_bit: int | None = field(
default=None,
metadata={"help": "The number of bits to quantize the reference model."},
)
reward_model: Optional[str] = field(
reward_model: str | None = field(
default=None,
metadata={"help": "Path to the reward model used for the PPO training."},
)
reward_model_adapters: Optional[str] = field(
reward_model_adapters: str | None = field(
default=None,
metadata={"help": "Path to the adapters of the reward model."},
)
reward_model_quantization_bit: Optional[int] = field(
reward_model_quantization_bit: int | None = field(
default=None,
metadata={"help": "The number of bits to quantize the reward model."},
)
@@ -248,7 +248,7 @@ class RLHFArguments:
default="lora",
metadata={"help": "The type of the reward model in PPO training. Lora model only supports lora training."},
)
ld_alpha: Optional[float] = field(
ld_alpha: float | None = field(
default=None,
metadata={
"help": (
@@ -361,15 +361,15 @@ class BAdamArgument:
default="layer",
metadata={"help": "Whether to use layer-wise or ratio-wise BAdam optimizer."},
)
badam_start_block: Optional[int] = field(
badam_start_block: int | None = field(
default=None,
metadata={"help": "The starting block index for layer-wise BAdam."},
)
badam_switch_mode: Optional[Literal["ascending", "descending", "random", "fixed"]] = field(
badam_switch_mode: Literal["ascending", "descending", "random", "fixed"] | None = field(
default="ascending",
metadata={"help": "the strategy of picking block to update for layer-wise BAdam."},
)
badam_switch_interval: Optional[int] = field(
badam_switch_interval: int | None = field(
default=50,
metadata={
"help": "Number of steps to update the block for layer-wise BAdam. Use -1 to disable the block update."
@@ -406,15 +406,15 @@ class SwanLabArguments:
default=False,
metadata={"help": "Whether or not to use the SwanLab (an experiment tracking and visualization tool)."},
)
swanlab_project: Optional[str] = field(
swanlab_project: str | None = field(
default="llamafactory",
metadata={"help": "The project name in SwanLab."},
)
swanlab_workspace: Optional[str] = field(
swanlab_workspace: str | None = field(
default=None,
metadata={"help": "The workspace name in SwanLab."},
)
swanlab_run_name: Optional[str] = field(
swanlab_run_name: str | None = field(
default=None,
metadata={"help": "The experiment name in SwanLab."},
)
@@ -422,19 +422,19 @@ class SwanLabArguments:
default="cloud",
metadata={"help": "The mode of SwanLab."},
)
swanlab_api_key: Optional[str] = field(
swanlab_api_key: str | None = field(
default=None,
metadata={"help": "The API key for SwanLab."},
)
swanlab_logdir: Optional[str] = field(
swanlab_logdir: str | None = field(
default=None,
metadata={"help": "The log directory for SwanLab."},
)
swanlab_lark_webhook_url: Optional[str] = field(
swanlab_lark_webhook_url: str | None = field(
default=None,
metadata={"help": "The Lark(飞书) webhook URL for SwanLab."},
)
swanlab_lark_secret: Optional[str] = field(
swanlab_lark_secret: str | None = field(
default=None,
metadata={"help": "The Lark(飞书) secret for SwanLab."},
)
@@ -510,7 +510,7 @@ class FinetuningArguments(
default=False,
metadata={"help": "Whether or not to disable the shuffling of the training set."},
)
early_stopping_steps: Optional[int] = field(
early_stopping_steps: int | None = field(
default=None,
metadata={"help": "Number of steps to stop training if the `metric_for_best_model` does not improve."},
)
@@ -530,11 +530,11 @@ class FinetuningArguments(
return arg
self.freeze_trainable_modules: list[str] = split_arg(self.freeze_trainable_modules)
self.freeze_extra_modules: Optional[list[str]] = split_arg(self.freeze_extra_modules)
self.freeze_extra_modules: list[str] | None = split_arg(self.freeze_extra_modules)
self.lora_alpha: int = self.lora_alpha or self.lora_rank * 2
self.lora_target: list[str] = split_arg(self.lora_target)
self.oft_target: list[str] = split_arg(self.oft_target)
self.additional_target: Optional[list[str]] = split_arg(self.additional_target)
self.additional_target: list[str] | None = split_arg(self.additional_target)
self.galore_target: list[str] = split_arg(self.galore_target)
self.apollo_target: list[str] = split_arg(self.apollo_target)
self.use_ref_model = self.stage == "dpo" and self.pref_loss not in ["orpo", "simpo"]

View File

@@ -17,12 +17,11 @@
import json
from dataclasses import asdict, dataclass, field, fields
from typing import Any, Literal, Optional, Union
from typing import Any, Literal, Self
import torch
from omegaconf import OmegaConf
from transformers.training_args import _convert_str_dict
from typing_extensions import Self
from ..extras.constants import AttentionFunction, EngineName, QuantizationMethod, RopeScaling
from ..extras.logging import get_logger
@@ -35,13 +34,13 @@ logger = get_logger(__name__)
class BaseModelArguments:
r"""Arguments pertaining to the model."""
model_name_or_path: Optional[str] = field(
model_name_or_path: str | None = field(
default=None,
metadata={
"help": "Path to the model weight or identifier from huggingface.co/models or modelscope.cn/models."
},
)
adapter_name_or_path: Optional[str] = field(
adapter_name_or_path: str | None = field(
default=None,
metadata={
"help": (
@@ -50,11 +49,11 @@ class BaseModelArguments:
)
},
)
adapter_folder: Optional[str] = field(
adapter_folder: str | None = field(
default=None,
metadata={"help": "The folder containing the adapter weights to load."},
)
cache_dir: Optional[str] = field(
cache_dir: str | None = field(
default=None,
metadata={"help": "Where to store the pre-trained models downloaded from huggingface.co or modelscope.cn."},
)
@@ -70,17 +69,17 @@ class BaseModelArguments:
default=False,
metadata={"help": "Whether or not the special tokens should be split during the tokenization process."},
)
add_tokens: Optional[str] = field(
add_tokens: str | None = field(
default=None,
metadata={
"help": "Non-special tokens to be added into the tokenizer. Use commas to separate multiple tokens."
},
)
add_special_tokens: Optional[str] = field(
add_special_tokens: str | None = field(
default=None,
metadata={"help": "Special tokens to be added into the tokenizer. Use commas to separate multiple tokens."},
)
new_special_tokens_config: Optional[str] = field(
new_special_tokens_config: str | None = field(
default=None,
metadata={
"help": (
@@ -110,7 +109,7 @@ class BaseModelArguments:
default=True,
metadata={"help": "Whether or not to use memory-efficient model loading."},
)
rope_scaling: Optional[RopeScaling] = field(
rope_scaling: RopeScaling | None = field(
default=None,
metadata={"help": "Which scaling strategy should be adopted for the RoPE embeddings."},
)
@@ -122,7 +121,7 @@ class BaseModelArguments:
default=False,
metadata={"help": "Enable shift short attention (S^2-Attn) proposed by LongLoRA."},
)
mixture_of_depths: Optional[Literal["convert", "load"]] = field(
mixture_of_depths: Literal["convert", "load"] | None = field(
default=None,
metadata={"help": "Convert the model to mixture-of-depths (MoD) or load the MoD model."},
)
@@ -138,7 +137,7 @@ class BaseModelArguments:
default=False,
metadata={"help": "Whether or not to enable liger kernel for faster training."},
)
moe_aux_loss_coef: Optional[float] = field(
moe_aux_loss_coef: float | None = field(
default=None,
metadata={"help": "Coefficient of the auxiliary router loss in mixture-of-experts model."},
)
@@ -182,15 +181,15 @@ class BaseModelArguments:
default="auto",
metadata={"help": "Data type for model weights and activations at inference."},
)
hf_hub_token: Optional[str] = field(
hf_hub_token: str | None = field(
default=None,
metadata={"help": "Auth token to log in with Hugging Face Hub."},
)
ms_hub_token: Optional[str] = field(
ms_hub_token: str | None = field(
default=None,
metadata={"help": "Auth token to log in with ModelScope Hub."},
)
om_hub_token: Optional[str] = field(
om_hub_token: str | None = field(
default=None,
metadata={"help": "Auth token to log in with Modelers Hub."},
)
@@ -283,7 +282,7 @@ class QuantizationArguments:
default=QuantizationMethod.BNB,
metadata={"help": "Quantization method to use for on-the-fly quantization."},
)
quantization_bit: Optional[int] = field(
quantization_bit: int | None = field(
default=None,
metadata={"help": "The number of bits to quantize the model using on-the-fly quantization."},
)
@@ -295,7 +294,7 @@ class QuantizationArguments:
default=True,
metadata={"help": "Whether or not to use double quantization in bitsandbytes int4 training."},
)
quantization_device_map: Optional[Literal["auto"]] = field(
quantization_device_map: Literal["auto"] | None = field(
default=None,
metadata={"help": "Device map used to infer the 4-bit quantized model, needs bitsandbytes>=0.43.0."},
)
@@ -375,7 +374,7 @@ class ProcessorArguments:
class ExportArguments:
r"""Arguments pertaining to the model export."""
export_dir: Optional[str] = field(
export_dir: str | None = field(
default=None,
metadata={"help": "Path to the directory to save the exported model."},
)
@@ -387,11 +386,11 @@ class ExportArguments:
default="cpu",
metadata={"help": "The device used in model export, use `auto` to accelerate exporting."},
)
export_quantization_bit: Optional[int] = field(
export_quantization_bit: int | None = field(
default=None,
metadata={"help": "The number of bits to quantize the exported model."},
)
export_quantization_dataset: Optional[str] = field(
export_quantization_dataset: str | None = field(
default=None,
metadata={"help": "Path to the dataset or dataset name to use in quantizing the exported model."},
)
@@ -407,7 +406,7 @@ class ExportArguments:
default=False,
metadata={"help": "Whether or not to save the `.bin` files instead of `.safetensors`."},
)
export_hub_model_id: Optional[str] = field(
export_hub_model_id: str | None = field(
default=None,
metadata={"help": "The name of the repository if push the model to the Hugging Face hub."},
)
@@ -437,7 +436,7 @@ class VllmArguments:
default=32,
metadata={"help": "Maximum rank of all LoRAs in the vLLM engine."},
)
vllm_config: Optional[Union[dict, str]] = field(
vllm_config: dict | str | None = field(
default=None,
metadata={"help": "Config to initialize the vllm engine. Please use JSON strings."},
)
@@ -463,7 +462,7 @@ class SGLangArguments:
default=-1,
metadata={"help": "Tensor parallel size for the SGLang engine."},
)
sglang_config: Optional[Union[dict, str]] = field(
sglang_config: dict | str | None = field(
default=None,
metadata={"help": "Config to initialize the SGLang engine. Please use JSON strings."},
)
@@ -487,21 +486,21 @@ class KTransformersArguments:
default=False,
metadata={"help": "Whether To Use KTransformers Optimizations For LoRA Training."},
)
kt_optimize_rule: Optional[str] = field(
kt_optimize_rule: str | None = field(
default=None,
metadata={
"help": "Path To The KTransformers Optimize Rule; See https://github.com/kvcache-ai/ktransformers/."
},
)
cpu_infer: Optional[int] = field(
cpu_infer: int | None = field(
default=32,
metadata={"help": "Number Of CPU Cores Used For Computation."},
)
chunk_size: Optional[int] = field(
chunk_size: int | None = field(
default=8192,
metadata={"help": "Chunk Size Used For CPU Compute In KTransformers."},
)
mode: Optional[str] = field(
mode: str | None = field(
default="normal",
metadata={"help": "Normal Or Long_Context For Llama Models."},
)
@@ -539,17 +538,17 @@ class ModelArguments(
The class on the most right will be displayed first.
"""
compute_dtype: Optional[torch.dtype] = field(
compute_dtype: torch.dtype | None = field(
default=None,
init=False,
metadata={"help": "Torch data type for computing model outputs, derived from `fp/bf16`. Do not specify it."},
)
device_map: Optional[Union[str, dict[str, Any]]] = field(
device_map: str | dict[str, Any] | None = field(
default=None,
init=False,
metadata={"help": "Device map for model placement, derived from training stage. Do not specify it."},
)
model_max_length: Optional[int] = field(
model_max_length: int | None = field(
default=None,
init=False,
metadata={"help": "The maximum input length for model, derived from `cutoff_len`. Do not specify it."},

View File

@@ -18,7 +18,7 @@
import os
import sys
from pathlib import Path
from typing import Any, Optional, Union
from typing import Any, Optional
import torch
import transformers
@@ -65,7 +65,7 @@ else:
_TRAIN_MCA_CLS = tuple()
def read_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> Union[dict[str, Any], list[str]]:
def read_args(args: dict[str, Any] | list[str] | None = None) -> dict[str, Any] | list[str]:
r"""Get arguments from the command line or a config file."""
if args is not None:
return args
@@ -83,7 +83,7 @@ def read_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> Union[
def _parse_args(
parser: "HfArgumentParser", args: Optional[Union[dict[str, Any], list[str]]] = None, allow_extra_keys: bool = False
parser: "HfArgumentParser", args: dict[str, Any] | list[str] | None = None, allow_extra_keys: bool = False
) -> tuple[Any]:
args = read_args(args)
if isinstance(args, dict):
@@ -205,13 +205,13 @@ def _check_extra_dependencies(
check_version("rouge_chinese", mandatory=True)
def _parse_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _TRAIN_CLS:
def _parse_train_args(args: dict[str, Any] | list[str] | None = None) -> _TRAIN_CLS:
parser = HfArgumentParser(_TRAIN_ARGS)
allow_extra_keys = is_env_enabled("ALLOW_EXTRA_ARGS")
return _parse_args(parser, args, allow_extra_keys=allow_extra_keys)
def _parse_train_mca_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _TRAIN_MCA_CLS:
def _parse_train_mca_args(args: dict[str, Any] | list[str] | None = None) -> _TRAIN_MCA_CLS:
parser = HfArgumentParser(_TRAIN_MCA_ARGS)
allow_extra_keys = is_env_enabled("ALLOW_EXTRA_ARGS")
model_args, data_args, training_args, finetuning_args, generating_args = _parse_args(
@@ -232,25 +232,25 @@ def _configure_mca_training_args(training_args, data_args, finetuning_args) -> N
finetuning_args.use_mca = True
def _parse_infer_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _INFER_CLS:
def _parse_infer_args(args: dict[str, Any] | list[str] | None = None) -> _INFER_CLS:
parser = HfArgumentParser(_INFER_ARGS)
allow_extra_keys = is_env_enabled("ALLOW_EXTRA_ARGS")
return _parse_args(parser, args, allow_extra_keys=allow_extra_keys)
def _parse_eval_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _EVAL_CLS:
def _parse_eval_args(args: dict[str, Any] | list[str] | None = None) -> _EVAL_CLS:
parser = HfArgumentParser(_EVAL_ARGS)
allow_extra_keys = is_env_enabled("ALLOW_EXTRA_ARGS")
return _parse_args(parser, args, allow_extra_keys=allow_extra_keys)
def get_ray_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> RayArguments:
def get_ray_args(args: dict[str, Any] | list[str] | None = None) -> RayArguments:
parser = HfArgumentParser(RayArguments)
(ray_args,) = _parse_args(parser, args, allow_extra_keys=True)
return ray_args
def get_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _TRAIN_CLS:
def get_train_args(args: dict[str, Any] | list[str] | None = None) -> _TRAIN_CLS:
if is_env_enabled("USE_MCA"):
model_args, data_args, training_args, finetuning_args, generating_args = _parse_train_mca_args(args)
else:
@@ -306,18 +306,15 @@ def get_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _
if training_args.do_train and data_args.dataset is None:
raise ValueError("Please specify dataset for training.")
if (training_args.do_eval or training_args.do_predict) and (
if (training_args.do_eval or training_args.do_predict or training_args.predict_with_generate) and (
data_args.eval_dataset is None and data_args.val_size < 1e-6
):
raise ValueError("Please specify dataset for evaluation.")
raise ValueError("Please make sure eval_dataset be provided or val_size >1e-6")
if training_args.predict_with_generate:
if is_deepspeed_zero3_enabled():
raise ValueError("`predict_with_generate` is incompatible with DeepSpeed ZeRO-3.")
if data_args.eval_dataset is None:
raise ValueError("Cannot use `predict_with_generate` if `eval_dataset` is None.")
if finetuning_args.compute_accuracy:
raise ValueError("Cannot use `predict_with_generate` and `compute_accuracy` together.")
@@ -476,7 +473,7 @@ def get_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _
return model_args, data_args, training_args, finetuning_args, generating_args
def get_infer_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _INFER_CLS:
def get_infer_args(args: dict[str, Any] | list[str] | None = None) -> _INFER_CLS:
model_args, data_args, finetuning_args, generating_args = _parse_infer_args(args)
# Setup logging
@@ -511,7 +508,7 @@ def get_infer_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _
return model_args, data_args, finetuning_args, generating_args
def get_eval_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _EVAL_CLS:
def get_eval_args(args: dict[str, Any] | list[str] | None = None) -> _EVAL_CLS:
model_args, data_args, eval_args, finetuning_args = _parse_eval_args(args)
# Setup logging

View File

@@ -14,7 +14,7 @@
import json
from dataclasses import dataclass, field
from typing import Literal, Optional, Union
from typing import Literal
from transformers import Seq2SeqTrainingArguments
from transformers.training_args import _convert_str_dict
@@ -40,7 +40,7 @@ else:
class RayArguments:
r"""Arguments pertaining to the Ray training."""
ray_run_name: Optional[str] = field(
ray_run_name: str | None = field(
default=None,
metadata={"help": "The training results will be saved at `<ray_storage_path>/ray_run_name`."},
)
@@ -48,7 +48,7 @@ class RayArguments:
default="./saves",
metadata={"help": "The storage path to save training results to"},
)
ray_storage_filesystem: Optional[Literal["s3", "gs", "gcs"]] = field(
ray_storage_filesystem: Literal["s3", "gs", "gcs"] | None = field(
default=None,
metadata={"help": "The storage filesystem to use. If None specified, local filesystem will be used."},
)
@@ -56,7 +56,7 @@ class RayArguments:
default=1,
metadata={"help": "The number of workers for Ray training. Default is 1 worker."},
)
resources_per_worker: Union[dict, str] = field(
resources_per_worker: dict | str = field(
default_factory=lambda: {"GPU": 1},
metadata={"help": "The resources per worker for Ray training. Default is to use 1 GPU per worker."},
)
@@ -64,7 +64,7 @@ class RayArguments:
default="PACK",
metadata={"help": "The placement strategy for Ray training. Default is PACK."},
)
ray_init_kwargs: Optional[Union[dict, str]] = field(
ray_init_kwargs: dict | str | None = field(
default=None,
metadata={"help": "The arguments to pass to ray.init for Ray training. Default is None."},
)

View File

@@ -15,7 +15,6 @@
import os
from typing import TYPE_CHECKING, Any, Optional, TypedDict
import torch
from transformers import (
AutoConfig,
AutoModelForCausalLM,
@@ -158,6 +157,7 @@ def load_model(
if model is None and not lazy_load:
init_kwargs["config"] = config
init_kwargs["pretrained_model_name_or_path"] = model_args.model_name_or_path
init_kwargs["torch_dtype"] = "auto"
if model_args.mixture_of_depths == "load":
model = load_mod_pretrained_model(**init_kwargs)
@@ -205,10 +205,6 @@ def load_model(
if not is_trainable:
model.requires_grad_(False)
for param in model.parameters():
if param.data.dtype == torch.float32 and model_args.compute_dtype != torch.float32:
param.data = param.data.to(model_args.compute_dtype)
model.eval()
else:
model.train()

View File

@@ -31,6 +31,18 @@ logger = logging.get_logger(__name__)
def configure_attn_implementation(config: "PretrainedConfig", model_args: "ModelArguments") -> None:
from transformers.utils import is_flash_attn_2_available
if getattr(config, "model_type", None) == "gpt_oss":
from transformers.integrations.hub_kernels import load_and_register_kernel
flash_attn3_kernel = "kernels-community/vllm-flash-attn3"
load_and_register_kernel(flash_attn3_kernel)
setattr(config, "_attn_implementation", flash_attn3_kernel)
setattr(config, "_attn_implementation_internal", flash_attn3_kernel)
model_args.flash_attn = AttentionFunction.FA3
logger.info_rank0("Using FlashAttention-3 with attention sink for the gpt-oss model.")
return
if getattr(config, "model_type", None) == "gemma2":
if model_args.flash_attn == AttentionFunction.AUTO or model_args.flash_attn == AttentionFunction.FA2:
if is_flash_attn_2_available():

View File

@@ -18,11 +18,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import inspect
import os
from collections.abc import Callable
from functools import WRAPPER_ASSIGNMENTS, partial, wraps
from types import MethodType
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
from typing import TYPE_CHECKING, Any, Optional, Union
import torch
@@ -156,11 +157,9 @@ def prepare_model_for_training(model: "PreTrainedModel", model_args: "ModelArgum
if (
os.environ.get("ACCELERATE_USE_FSDP", "false").lower() == "true"
and int(os.environ.get("FSDP_VERSION", "1")) == 2
):
):
model_args.use_reentrant_gc = False
logger.warning_rank0(
"You are using fsdp2, `use_reentrant_gc` has been set to False. "
)
logger.warning_rank0("You are using fsdp2, `use_reentrant_gc` has been set to False.")
if not model_args.disable_gradient_checkpointing:
if not getattr(model, "supports_gradient_checkpointing", False):

View File

@@ -77,6 +77,12 @@ def apply_liger_kernel(
from liger_kernel.transformers import apply_liger_kernel_to_qwen3 as apply_liger_kernel
elif model_type == "qwen3_moe":
from liger_kernel.transformers import apply_liger_kernel_to_qwen3_moe as apply_liger_kernel
elif model_type == "gpt_oss":
try:
from liger_kernel.transformers import apply_liger_kernel_to_gpt_oss as apply_liger_kernel
except ImportError:
logger.warning_rank0("Please install liger-kernel from https://github.com/Comet0322/Liger-Kernel.")
return
else:
logger.warning_rank0("Current model does not support liger kernel.")
return

View File

@@ -82,6 +82,11 @@ def add_z3_leaf_module(model: "PreTrainedModel") -> None:
_set_z3_leaf_modules(model, [Glm4vMoeTextMoE])
if model_type == "gpt_oss":
from transformers.models.gpt_oss.modeling_gpt_oss import GptOssMLP
_set_z3_leaf_modules(model, [GptOssMLP])
if model_type == "jamba":
from transformers.models.jamba.modeling_jamba import JambaSparseMoeBlock

View File

@@ -83,6 +83,7 @@ def configure_quantization(
config: "PretrainedConfig",
tokenizer: "PreTrainedTokenizer",
model_args: "ModelArguments",
is_trainable: bool,
init_kwargs: dict[str, Any],
) -> None:
r"""Priority: PTQ-quantized (train/infer) > AutoGPTQ (export) > On-the-fly quantization (train/infer)."""
@@ -93,10 +94,26 @@ def configure_quantization(
quantization_config: dict[str, Any] = getattr(config, "quantization_config", None)
quant_method = quantization_config.get("quant_method", "")
if quant_method != QuantizationMethod.MXFP4 and (is_deepspeed_zero3_enabled() or is_fsdp_enabled()):
if quant_method not in (QuantizationMethod.MXFP4, QuantizationMethod.FP8) and (
is_deepspeed_zero3_enabled() or is_fsdp_enabled()
):
# mxfp4 will dequant the model weights
raise ValueError("DeepSpeed ZeRO-3 or FSDP is incompatible with PTQ-quantized models.")
if quant_method == QuantizationMethod.MXFP4:
from transformers import Mxfp4Config
quant_config = Mxfp4Config(dequantize=True)
init_kwargs["quantization_config"] = quant_config
init_kwargs["ignore_mismatched_sizes"] = True
if quant_method == QuantizationMethod.FP8:
from transformers import FineGrainedFP8Config
quant_config = FineGrainedFP8Config(dequantize=True)
init_kwargs["quantization_config"] = quant_config
init_kwargs["ignore_mismatched_sizes"] = True
if quant_method == QuantizationMethod.GPTQ:
check_version("gptqmodel>=2.0.0", mandatory=True)
quantization_config.pop("disable_exllama", None) # remove deprecated args

View File

@@ -40,7 +40,10 @@ def configure_rope(config: "PretrainedConfig", model_args: "ModelArguments") ->
logger.warning_rank0("Current model does not support RoPE scaling.")
return
if hasattr(config, "max_position_embeddings"):
rope_scaling = getattr(config, "rope_scaling", None)
if isinstance(rope_scaling, dict) and "original_max_position_embeddings" in rope_scaling:
old_max_length = rope_scaling["original_max_position_embeddings"]
elif hasattr(config, "max_position_embeddings"):
old_max_length = getattr(config, "max_position_embeddings", None)
else:
logger.warning_rank0("Cannot find the max position embeddings in the config.")

View File

@@ -301,6 +301,7 @@ _register_composite_model(
_register_composite_model(
model_type="mistral3",
projector_key="model.multi_modal_projector",
)

View File

@@ -51,9 +51,13 @@ logger = logging.get_logger(__name__)
def patch_qwen3_omni_moe_thinker_text_sparse_moe_block():
if is_transformers_version_greater_than("4.57.0"):
if is_transformers_version_greater_than("4.57.0") and not is_transformers_version_greater_than("4.58.0"):
from .model_utils.moe import Qwen3OmniMoeThinkerTextSparseMoeBlock
logger.warning_rank0(
"You are using transformers with 4.x version, the Qwen3OmniMoeThinkerTextSparseMoeBlock will have some issues about deepspeed zero2 and fsdp2 training, so that we patched this model to avoid it. Transformers v5.0.0rc0 has fixed the issue, you can also try to update the transformers to using qwen3_omni. See more information on https://github.com/hiyouga/LLaMA-Factory/issues/9628."
)
modeling_qwen3_omni_moe.Qwen3OmniMoeThinkerTextSparseMoeBlock = Qwen3OmniMoeThinkerTextSparseMoeBlock
@@ -115,7 +119,7 @@ def patch_config(
configure_attn_implementation(config, model_args)
configure_rope(config, model_args)
configure_longlora(config, model_args, is_trainable)
configure_quantization(config, tokenizer, model_args, init_kwargs)
configure_quantization(config, tokenizer, model_args, is_trainable, init_kwargs)
configure_moe(config, model_args, is_trainable)
configure_visual_model(config)
configure_packing(model_args, is_trainable)
@@ -152,16 +156,13 @@ def patch_config(
# deepspeed zero3 is not compatible with low_cpu_mem_usage
init_kwargs["low_cpu_mem_usage"] = model_args.low_cpu_mem_usage and (not is_deepspeed_zero3_enabled())
# do not cast data type of the model deepspeed zero3 without qlora
if not (is_deepspeed_zero3_enabled() and model_args.quantization_bit is None):
init_kwargs["torch_dtype"] = model_args.compute_dtype
# fsdp/deepspeed zero3 does not need device map
if not (is_deepspeed_zero3_enabled() or is_fsdp_enabled()) and init_kwargs["low_cpu_mem_usage"]:
if "device_map" not in init_kwargs and model_args.device_map:
init_kwargs["device_map"] = model_args.device_map # device map requires low_cpu_mem_usage=True
if init_kwargs["low_cpu_mem_usage"] and not is_fsdp_enabled(): # fsdp does not need device map
if "device_map" not in init_kwargs and model_args.device_map:
init_kwargs["device_map"] = model_args.device_map # device map requires low_cpu_mem_usage=True
if init_kwargs.get("device_map", None) == "auto":
init_kwargs["offload_folder"] = model_args.offload_folder
if init_kwargs.get("device_map", None) == "auto":
init_kwargs["offload_folder"] = model_args.offload_folder
def patch_model(

View File

@@ -0,0 +1,62 @@
# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's TRL library.
# https://github.com/huggingface/trl/blob/v0.8.0/trl/trainer/dpo_trainer.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
import torch
from ktransformers.sft.lora import KTrainer # type: ignore
from typing_extensions import override
from ..trainer_utils import get_batch_logps, nested_detach
from .trainer import CustomDPOTrainer
if TYPE_CHECKING:
from transformers import PreTrainedModel
class KDPOTrainer(KTrainer, CustomDPOTrainer):
@override
def concatenated_forward(
self, model: "PreTrainedModel", batch: dict[str, "torch.Tensor"], is_ref_model: bool = False
) -> tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
r"""Compute the sum log probabilities of the labels under given logits if loss_type is not IPO, ORPO or SimPO.
Otherwise the average log probabilities.
"""
if self.finetuning_args.use_ref_model:
batch = nested_detach(batch, clone=True) # avoid error
labels = batch.pop("labels") # dpo do not need compute loss in forward
all_logits: torch.Tensor = model(**batch, return_dict=True, use_cache=False).logits.to(torch.float32)
all_logits = all_logits.to("cpu")
labels = labels.to(all_logits.device)
all_logps, valid_length = get_batch_logps(
logits=all_logits, labels=labels, ld_alpha=(self.ld_alpha if not is_ref_model else None)
)
if self.loss_type in ["ipo", "orpo", "simpo"]:
all_logps = all_logps / valid_length
batch_size = batch["input_ids"].size(0) // 2
chosen_logps, rejected_logps = all_logps.split(batch_size, dim=0)
chosen_logits, rejected_logits = all_logits.split(batch_size, dim=0)
chosen_length, _ = valid_length.split(batch_size, dim=0)
if self.loss_type in ["ipo", "orpo", "simpo"]:
return chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_logps
else:
return chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_logps / chosen_length

View File

@@ -218,9 +218,10 @@ class CustomDPOTrainer(DPOTrainer):
if self.finetuning_args.use_ref_model:
batch = nested_detach(batch, clone=True) # avoid error
labels = batch.pop("labels") # dpo do not need compute loss in forward
all_logits: torch.Tensor = model(**batch, return_dict=True, use_cache=False).logits.to(torch.float32)
all_logps, valid_length = get_batch_logps(
logits=all_logits, labels=batch["labels"], ld_alpha=(self.ld_alpha if not is_ref_model else None)
logits=all_logits, labels=labels, ld_alpha=(self.ld_alpha if not is_ref_model else None)
)
if self.loss_type in ["ipo", "orpo", "simpo"]:
all_logps = all_logps / valid_length

View File

@@ -24,7 +24,6 @@ from ...extras.ploting import plot_loss
from ...hparams import ModelArguments
from ...model import load_model, load_tokenizer
from ..trainer_utils import create_modelcard_and_push, create_ref_model
from .trainer import CustomDPOTrainer
if TYPE_CHECKING:
@@ -63,6 +62,16 @@ def run_dpo(
else:
ref_model = None
if model_args.use_kt:
from ktransformers.util.globals import GLOBAL_CONFIG # type: ignore
from .ktrainer import KDPOTrainer as CustomDPOTrainer
GLOBAL_CONFIG._config["mod"] = "sft"
else:
from .trainer import CustomDPOTrainer
# Initialize our Trainer
trainer = CustomDPOTrainer(
model=model,

View File

@@ -1,18 +0,0 @@
# Copyright 2025 the KVCache.AI team, Approaching AI, and the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .workflow import run_sft
__all__ = ["run_sft"]

View File

@@ -1,113 +0,0 @@
# Copyright 2025 the KVCache.AI team, Approaching AI, and the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Optional
from ...data import SFTDataCollatorWith4DAttentionMask, get_dataset, get_template_and_fix_tokenizer
from ...extras.constants import IGNORE_INDEX
from ...extras.logging import get_logger
from ...extras.misc import calculate_tps
from ...extras.ploting import plot_loss
from ...model import load_model, load_tokenizer
from ..trainer_utils import create_modelcard_and_push
if TYPE_CHECKING:
from transformers import Seq2SeqTrainingArguments, TrainerCallback
from ...hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
logger = get_logger(__name__)
def run_sft(
model_args: "ModelArguments",
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments",
generating_args: "GeneratingArguments",
callbacks: Optional[list["TrainerCallback"]] = None,
):
tokenizer_module = load_tokenizer(model_args)
tokenizer = tokenizer_module["tokenizer"]
template = get_template_and_fix_tokenizer(tokenizer, data_args)
dataset_module = get_dataset(template, model_args, data_args, training_args, stage="sft", **tokenizer_module)
model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
from ktransformers.util.globals import GLOBAL_CONFIG
GLOBAL_CONFIG._config["mod"] = "sft"
if getattr(model, "is_quantized", False) and not training_args.do_train:
setattr(model, "_hf_peft_config_loaded", True) # hack here: make model compatible with prediction
data_collator = SFTDataCollatorWith4DAttentionMask(
template=template,
model=model if not training_args.predict_with_generate else None,
pad_to_multiple_of=8 if training_args.do_train else None, # for shift short attention
label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id,
block_diag_attn=model_args.block_diag_attn,
attn_implementation=getattr(model.config, "_attn_implementation", None),
compute_dtype=model_args.compute_dtype,
**tokenizer_module,
)
# Metric utils
metric_module = {}
if training_args.predict_with_generate:
raise NotImplementedError("`predict_with_generate` is not supported in KTransformers SFT yet.")
elif finetuning_args.compute_accuracy:
raise NotImplementedError("`compute_accuracy` is not supported in KTransformers SFT yet.")
# Initialize our Trainer
from ktransformers.sft.lora import KTrainer
trainer = KTrainer(
model=model,
args=training_args,
tokenizer=tokenizer_module,
data_collator=data_collator,
callbacks=callbacks,
**dataset_module,
**metric_module,
)
trainer.model_accepts_loss_kwargs = False
# Training
if training_args.do_train:
model.config.use_cache = False
train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
trainer.save_model()
if finetuning_args.include_effective_tokens_per_second:
train_result.metrics["effective_tokens_per_sec"] = calculate_tps(
dataset_module["train_dataset"], train_result.metrics, stage="sft"
)
trainer.log_metrics("train", train_result.metrics)
trainer.save_metrics("train", train_result.metrics)
trainer.save_state()
if trainer.is_world_process_zero() and finetuning_args.plot_loss:
keys = ["loss"]
if isinstance(dataset_module.get("eval_dataset"), dict):
keys += sum(
[[f"eval_{key}_loss", f"eval_{key}_accuracy"] for key in dataset_module["eval_dataset"].keys()], []
)
else:
keys += ["eval_loss", "eval_accuracy"]
plot_loss(training_args.output_dir, keys=keys)
# Create model card
create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args)

View File

@@ -13,7 +13,7 @@
# limitations under the License.
from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING
import numpy as np
@@ -28,7 +28,7 @@ if TYPE_CHECKING:
class ComputeAccuracy:
r"""Compute reward accuracy and support `batch_eval_metrics`."""
def _dump(self) -> Optional[dict[str, float]]:
def _dump(self) -> dict[str, float] | None:
result = None
if hasattr(self, "score_dict"):
result = {k: float(np.mean(v)) for k, v in self.score_dict.items()}
@@ -39,7 +39,7 @@ class ComputeAccuracy:
def __post_init__(self):
self._dump()
def __call__(self, eval_preds: "EvalPrediction", compute_result: bool = True) -> Optional[dict[str, float]]:
def __call__(self, eval_preds: "EvalPrediction", compute_result: bool = True) -> dict[str, float] | None:
chosen_scores, rejected_scores = numpify(eval_preds.predictions[0]), numpify(eval_preds.predictions[1])
if not chosen_scores.shape:
self.score_dict["accuracy"].append(chosen_scores > rejected_scores)

View File

@@ -21,6 +21,7 @@ from ...data import SFTDataCollatorWith4DAttentionMask, get_dataset, get_templat
from ...extras.constants import IGNORE_INDEX
from ...extras.logging import get_logger
from ...extras.misc import calculate_tps
from ...extras.packages import is_transformers_version_greater_than
from ...extras.ploting import plot_loss
from ...model import load_model, load_tokenizer
from ..trainer_utils import create_modelcard_and_push
@@ -67,6 +68,12 @@ def run_sft(
# Metric utils
metric_module = {}
if model_args.use_kt:
if training_args.predict_with_generate:
raise NotImplementedError("`predict_with_generate` is not supported in KTransformers SFT yet.")
elif finetuning_args.compute_accuracy:
raise NotImplementedError("`compute_accuracy` is not supported in KTransformers SFT yet.")
if training_args.predict_with_generate:
metric_module["compute_metrics"] = ComputeSimilarity(tokenizer=tokenizer)
elif finetuning_args.compute_accuracy:
@@ -75,21 +82,52 @@ def run_sft(
# Keyword arguments for `model.generate`
gen_kwargs = generating_args.to_dict(obey_generation_config=True)
gen_kwargs["eos_token_id"] = [tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids
# Compatible with Transformers v4 and Transformers v5
if is_transformers_version_greater_than("4.58.0"):
extra_ids = getattr(tokenizer, "additional_special_tokens_ids", None)
if not isinstance(extra_ids, list):
extra_special_tokens = getattr(tokenizer, "_extra_special_tokens", [])
string_tokens = [str(t) for t in extra_special_tokens]
extra_ids = tokenizer.convert_tokens_to_ids(string_tokens)
all_eos_ids = [tokenizer.eos_token_id] + [i for i in extra_ids if i != -1]
unique_eos_ids = list(dict.fromkeys(all_eos_ids))
gen_kwargs["eos_token_id"] = unique_eos_ids
else:
gen_kwargs["eos_token_id"] = [tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids
gen_kwargs["pad_token_id"] = tokenizer.pad_token_id
# Initialize our Trainer
trainer = CustomSeq2SeqTrainer(
model=model,
args=training_args,
finetuning_args=finetuning_args,
data_collator=data_collator,
callbacks=callbacks,
gen_kwargs=gen_kwargs,
**dataset_module,
**tokenizer_module,
**metric_module,
)
if model_args.use_kt:
from ktransformers.sft.lora import KTrainer # type: ignore
from ktransformers.util.globals import GLOBAL_CONFIG # type: ignore
GLOBAL_CONFIG._config["mod"] = "sft"
trainer = KTrainer(
model=model,
args=training_args,
tokenizer=tokenizer_module,
data_collator=data_collator,
callbacks=callbacks,
**dataset_module,
**metric_module,
)
trainer.model_accepts_loss_kwargs = False
model.config.use_cache = False
else:
trainer = CustomSeq2SeqTrainer(
model=model,
args=training_args,
finetuning_args=finetuning_args,
data_collator=data_collator,
callbacks=callbacks,
gen_kwargs=gen_kwargs,
**dataset_module,
**tokenizer_module,
**metric_module,
)
# Training
if training_args.do_train:

View File

@@ -20,7 +20,6 @@ from transformers import AutoModelForCausalLM
from trl import AutoModelForCausalLMWithValueHead
from ..data import get_dataset, get_template_and_fix_tokenizer
from ..extras.misc import get_current_device
from ..hparams import get_infer_args, get_train_args
from ..model import load_model, load_tokenizer
@@ -81,17 +80,14 @@ def load_reference_model(
is_trainable: bool = False,
add_valuehead: bool = False,
) -> Union["PreTrainedModel", "LoraModel"]:
current_device = get_current_device()
if add_valuehead:
model: AutoModelForCausalLMWithValueHead = AutoModelForCausalLMWithValueHead.from_pretrained(
model_path, torch_dtype=torch.float16, device_map=current_device
model_path, torch_dtype=torch.float16, device_map="auto"
)
if not is_trainable:
model.v_head = model.v_head.to(torch.float16)
return model
model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16, device_map=current_device)
model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16, device_map="auto")
if use_lora or use_pissa:
model = PeftModel.from_pretrained(
model, lora_path, subfolder="pissa_init" if use_pissa else None, is_trainable=is_trainable

View File

@@ -19,9 +19,9 @@
import json
import os
from collections.abc import Mapping
from collections.abc import Callable, Mapping
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
from typing import TYPE_CHECKING, Any, Optional, Union
import torch
from transformers import Trainer

View File

@@ -85,13 +85,7 @@ def _training_function(config: dict[str, Any]) -> None:
elif finetuning_args.stage == "pt":
run_pt(model_args, data_args, training_args, finetuning_args, callbacks)
elif finetuning_args.stage == "sft":
if model_args.use_kt:
from .ksft.workflow import run_sft as run_sft_kt
run_sft_kt(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
else:
run_sft(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
run_sft(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
elif finetuning_args.stage == "rm":
run_rm(model_args, data_args, training_args, finetuning_args, callbacks)
elif finetuning_args.stage == "ppo":

View File

@@ -0,0 +1,215 @@
# Copyright 2025 Bytedance Ltd. and the LlamaFactory team.
#
# This code is inspired by the Bytedance's VeOmni library.
# https://github.com/ByteDance-Seed/VeOmni/blob/v0.1.4/veomni/utils/dist_utils.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utility functions used by the distributed interface.
Including:
- Environment info (rank, world_size, local_rank, etc.)
- Accelerator info (device type, device count, etc.)
- Collective communication operations (all_gather, all_reduce, broadcast)
- Synchronize processes and ensure main-process-first execution order
"""
import os
from collections.abc import Callable
from contextlib import contextmanager
from enum import Enum, unique
from functools import lru_cache, wraps
from typing import Optional
import numpy as np
import torch
import torch.distributed as dist
from ..utils.types import ProcessGroup, Tensor, TensorLike
@unique
class DeviceType(str, Enum):
CPU = "cpu"
CUDA = "cuda"
META = "meta"
MPS = "mps"
NPU = "npu"
XPU = "xpu"
@unique
class ReduceOp(str, Enum):
SUM = "sum"
MEAN = "mean"
MAX = "max"
MIN = "min"
def requires_accelerator(fn):
"""Decorator to check if torch.accelerator is available.
Note: this api requires torch>=2.7.0, otherwise it will raise an AttributeError or RuntimeError
"""
@wraps(fn)
def wrapper(*args, **kwargs):
if not hasattr(torch, "accelerator"):
raise RuntimeError("torch.accelerator is not available, please upgrade torch to 2.7.0 or higher.")
return fn(*args, **kwargs)
return wrapper
def is_distributed() -> bool:
"""Check if distributed environment is available."""
return os.getenv("RANK") is not None
def get_rank() -> int:
"""Get rank."""
return int(os.getenv("RANK", "0"))
def get_world_size() -> int:
"""Get world size."""
return int(os.getenv("WORLD_SIZE", "1"))
def get_local_rank() -> int:
"""Get local rank."""
return int(os.getenv("LOCAL_RANK", "0"))
def get_local_world_size() -> int:
"""Get local world size."""
return int(os.getenv("LOCAL_WORLD_SIZE", "1"))
@lru_cache
@requires_accelerator
def get_current_accelerator(check_available: bool = True) -> torch.device:
"""Get current accelerator."""
accelerator = torch.accelerator.current_accelerator(check_available=check_available)
return accelerator or torch.device(DeviceType.CPU.value)
@lru_cache
@requires_accelerator
def get_device_count() -> int:
"""Get the number of available devices."""
return torch.accelerator.device_count()
@requires_accelerator
def synchronize() -> None:
"""Synchronize all processes."""
torch.accelerator.synchronize()
@requires_accelerator
def set_device() -> None:
"""Set current accelerator."""
torch.accelerator.set_device_index(get_local_rank())
def is_torch_cuda_available():
"""Check if CUDA is available."""
return get_current_accelerator().type == DeviceType.CUDA
def is_torch_mps_available():
"""Check if MPS is available."""
return get_current_accelerator().type == DeviceType.MPS
def is_torch_npu_available():
"""Check if NPU is available."""
return get_current_accelerator().type == DeviceType.NPU
def is_torch_xpu_available():
"""Check if XPU is available."""
return get_current_accelerator().type == DeviceType.XPU
def operate_tensorlike(fn: Callable[[...], Tensor], data: TensorLike, **kwargs) -> TensorLike:
"""Operate tensorlike data on current accelerator."""
device = get_current_accelerator()
is_tensor = isinstance(data, torch.Tensor)
is_ndarray = isinstance(data, np.ndarray)
if is_tensor:
orig_device = data.device
data = data.to(device=device)
elif is_ndarray:
data = torch.from_numpy(data).to(device=device, dtype=torch.float)
else:
data = torch.tensor(data, dtype=torch.float, device=device)
result = fn(data, **kwargs)
if is_tensor:
return result.to(orig_device)
elif is_ndarray:
return result.cpu().numpy()
elif result.numel() == 1:
return result.item()
else:
return result.tolist()
def all_gather(tensor: Tensor, group: Optional[ProcessGroup] = None) -> Tensor:
"""Gathers the tensor from all ranks and stacks them at the first dim."""
world_size = get_world_size()
output_tensor = torch.empty(world_size * tensor.numel(), dtype=tensor.dtype, device=tensor.device)
dist.all_gather_into_tensor(output_tensor, tensor, group=group)
return output_tensor.view(-1, *tensor.size())
def all_reduce(tensor: Tensor, op: ReduceOp = ReduceOp.MEAN, group: Optional[ProcessGroup] = None) -> Tensor:
"""Performs all reduce in the given process group."""
reduce_ops = {
ReduceOp.MEAN: dist.ReduceOp.SUM,
ReduceOp.SUM: dist.ReduceOp.SUM,
ReduceOp.MAX: dist.ReduceOp.MAX,
ReduceOp.MIN: dist.ReduceOp.MIN,
}
dist.all_reduce(tensor, op=reduce_ops[op], group=group)
if op == ReduceOp.MEAN: # ReduceOp.AVG is not supported by the NPU backend
tensor /= dist.get_world_size(group=group)
return tensor
def broadcast(tensor: Tensor, src: int = 0, group: Optional[ProcessGroup] = None) -> Tensor:
"""Broadcasts the tensor from the src process to all other processes."""
dist.broadcast(tensor, src=src, group=group)
return tensor
@contextmanager
def main_process_first(local_only: bool = True) -> None:
"""A context manager for torch distributed environment to do something on the main process firstly."""
if get_world_size() > 1:
is_main_process = get_local_rank() == 0 if local_only else get_rank() == 0
try:
if not is_main_process:
dist.barrier()
yield
finally:
if is_main_process:
dist.barrier()
else:
yield

View File

@@ -0,0 +1,249 @@
# Copyright 2025 Bytedance Ltd. and the LlamaFactory team.
#
# This code is inspired by the Bytedance's VeOmni library.
# https://github.com/ByteDance-Seed/VeOmni/blob/v0.1.4/veomni/distributed/parallel_state.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A unified interface for model parallelism and data parallelism.
Supports model parallelism types:
- mp_replicate: Replicate model across multiple devices.
- mp_shard: Shard model across multiple devices.
And data parallelism types:
- dp: Data parallelism.
- cp: Context parallelism.
"""
from dataclasses import dataclass
from datetime import timedelta
from enum import Enum
from typing import Any, Optional
from torch.distributed import barrier, destroy_process_group, init_process_group
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
from ..utils.types import DistributedConfig, ProcessGroup, Tensor, TensorLike
from . import helper
class Dim(str, Enum):
"""Dimension names."""
MP_REPLICATE = "mp_replicate"
MP_SHARD = "mp_shard"
DP = "dp"
CP = "cp"
@dataclass
class DistributedStrategy:
"""Distributed strategy."""
mp_replicate_size: int = 1
"""Model parallel replicate size, default to 1."""
mp_shard_size: int | None = None
"""Model parallel shard size, default to world_size // mp_replicate_size."""
dp_size: int | None = None
"""Data parallel size, default to world_size // cp_size."""
cp_size: int = 1
"""Context parallel size, default to 1."""
def __post_init__(self) -> None:
if not helper.is_distributed():
self.mp_shard_size = 1
elif self.mp_shard_size is None:
self.mp_shard_size = helper.get_world_size() // self.mp_replicate_size
elif self.mp_replicate_size * self.mp_shard_size != helper.get_world_size():
raise ValueError(
f"mp_replicate_size * mp_shard_size must equal to world_size, "
f"got {self.mp_replicate_size} * {self.mp_shard_size} != {helper.get_world_size()}."
)
if not helper.is_distributed():
self.dp_size = 1
elif self.dp_size is None:
self.dp_size = helper.get_world_size() // self.cp_size
elif self.dp_size * self.cp_size != helper.get_world_size():
raise ValueError(
f"dp_size * cp_size must equal to world_size, "
f"got {self.dp_size} * {self.cp_size} != {helper.get_world_size()}."
)
@property
def model_mesh_shape(self) -> tuple[int, int]:
"""Model parallel mesh shape."""
return (self.mp_replicate_size, self.mp_shard_size)
@property
def model_mesh_dim_names(self) -> tuple[str, str]:
"""Model parallel mesh dimension names."""
return (Dim.MP_REPLICATE.value, Dim.MP_SHARD.value)
@property
def data_mesh_shape(self) -> tuple[int, int]:
"""Data parallel mesh shape."""
return (self.dp_size, self.cp_size)
@property
def data_mesh_dim_names(self) -> tuple[str, str]:
"""Data parallel mesh dimension names."""
return (Dim.DP.value, Dim.CP.value)
class DistributedInterface:
"""Distributed interface."""
_instance: Optional["DistributedInterface"] = None
_initialized: bool = False
def __new__(cls, *args: Any, **kwargs: Any) -> "DistributedInterface":
"""Singleton pattern."""
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def __init__(self, config: DistributedConfig | None = None) -> None:
if self._initialized:
return
self._is_distributed = helper.is_distributed()
self._rank = helper.get_rank()
self._world_size = helper.get_world_size()
self._local_rank = helper.get_local_rank()
self._local_world_size = helper.get_local_world_size()
self.current_accelerator = helper.get_current_accelerator()
self.device_count = helper.get_device_count()
if config is None:
self.strategy = DistributedStrategy()
timeout = 18000
else:
self.strategy = DistributedStrategy(
mp_replicate_size=config.get("mp_replicate_size", 1),
mp_shard_size=config.get("mp_shard_size", None),
dp_size=config.get("dp_size", None),
cp_size=config.get("cp_size", 1),
)
timeout = config.get("timeout", 18000)
if self._is_distributed:
helper.set_device()
init_process_group(timeout=timedelta(seconds=timeout))
self.model_device_mesh = init_device_mesh(
device_type=self.current_accelerator.type,
mesh_shape=self.strategy.model_mesh_shape,
mesh_dim_names=self.strategy.model_mesh_dim_names,
)
self.data_device_mesh = init_device_mesh(
device_type=self.current_accelerator.type,
mesh_shape=self.strategy.data_mesh_shape,
mesh_dim_names=self.strategy.data_mesh_dim_names,
)
else:
self.model_device_mesh = None
self.data_device_mesh = None
self._initialized = True
def __str__(self) -> str:
return (
f"DistributedInterface(strategy={self.strategy}), is_distributed={self._is_distributed}, "
f"current_accelerator={self.current_accelerator}, rank={self._rank}, world_size={self._world_size}, "
f"model_device_mesh={self.model_device_mesh}, data_device_mesh={self.data_device_mesh}"
)
def get_device_mesh(self, dim: Dim | None = None) -> DeviceMesh | None:
"""Get device mesh for specified dimension."""
if dim is None:
raise ValueError("dim must be specified.")
elif self.model_device_mesh is None:
return None
elif dim in self.strategy.data_mesh_dim_names:
return self.data_device_mesh[dim.value]
else:
return self.model_device_mesh[dim.value]
def get_group(self, dim: Dim | None = None) -> Optional[ProcessGroup]:
"""Get process group for specified dimension."""
if self.model_device_mesh is None or dim is None:
return None
else:
return self.get_device_mesh(dim).get_group()
def get_rank(self, dim: Dim | None = None) -> int:
"""Get parallel rank for specified dimension."""
if self.model_device_mesh is None:
return 0
elif dim is None:
return self._rank
else:
return self.get_device_mesh(dim).get_local_rank()
def get_world_size(self, dim: Dim | None = None) -> int:
"""Get parallel size for specified dimension."""
if self.model_device_mesh is None:
return 1
elif dim is None:
return self._world_size
else:
return self.get_device_mesh(dim).size()
def get_local_rank(self) -> int:
"""Get parallel local rank."""
return self._local_rank
def get_local_world_size(self) -> int:
"""Get parallel local world size."""
return self._local_world_size
def all_gather(self, data: Tensor, dim: Dim | None = Dim.DP) -> Tensor:
"""Gather tensor across specified parallel group."""
if self.model_device_mesh is not None:
return helper.operate_tensorlike(helper.all_gather, data, group=self.get_group(dim))
else:
return data
def all_reduce(
self, data: TensorLike, op: helper.ReduceOp = helper.ReduceOp.MEAN, dim: Dim | None = Dim.DP
) -> TensorLike:
"""Reduce tensor across specified parallel group."""
if self.model_device_mesh is not None:
return helper.operate_tensorlike(helper.all_reduce, data, op=op, group=self.get_group(dim))
else:
return data
def broadcast(self, data: TensorLike, src: int = 0, dim: Dim | None = Dim.DP) -> TensorLike:
"""Broadcast tensor across specified parallel group."""
if self.model_device_mesh is not None:
return helper.operate_tensorlike(helper.broadcast, data, src=src, group=self.get_group(dim))
else:
return data
def sync(self) -> None:
"""Synchronize all processes."""
helper.synchronize()
def barrier(self) -> None:
"""Barrier all processes."""
barrier()
def destroy(self) -> None:
"""Destroy all processes."""
destroy_process_group()
if __name__ == "__main__":
print(DistributedInterface(DistributedStrategy()))

View File

@@ -12,11 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import sys
from pathlib import Path
from typing import Any, Optional, Union
from typing import Any
from omegaconf import OmegaConf
from transformers import HfArgumentParser
@@ -28,9 +27,25 @@ from .sample_args import SampleArguments
from .training_args import TrainingArguments
def get_args(
args: Optional[Union[dict[str, Any], list[str]]] = None,
) -> tuple[DataArguments, ModelArguments, TrainingArguments, SampleArguments]:
InputArgument = dict[str, Any] | list[str] | None
def validate_args(
data_args: DataArguments,
model_args: ModelArguments,
training_args: TrainingArguments,
sample_args: SampleArguments,
):
"""Validate arguments."""
if (
model_args.quant_config is not None
and training_args.dist_config is not None
and training_args.dist_config.name == "deepspeed"
):
raise ValueError("Quantization is not supported with deepspeed backend.")
def get_args(args: InputArgument = None) -> tuple[DataArguments, ModelArguments, TrainingArguments, SampleArguments]:
"""Parse arguments from command line or config file."""
parser = HfArgumentParser([DataArguments, ModelArguments, TrainingArguments, SampleArguments])
allow_extra_keys = is_env_enabled("ALLOW_EXTRA_KEYS")
@@ -56,6 +71,8 @@ def get_args(
print(f"Got unknown args, potentially deprecated arguments: {unknown_args}")
raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {unknown_args}")
validate_args(*parsed_args)
return tuple(parsed_args)

View File

@@ -0,0 +1,95 @@
# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v5.0.0rc0/src/transformers/training_args.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from enum import Enum, unique
class PluginConfig(dict):
"""Dictionary that allows attribute access."""
@property
def name(self) -> str:
"""Plugin name."""
if "name" not in self:
raise ValueError("Plugin configuration must have a 'name' field.")
return self["name"]
PluginArgument = PluginConfig | dict | str | None
@unique
class ModelClass(str, Enum):
"""Auto class for model config."""
LLM = "llm"
CLS = "cls"
OTHER = "other"
@unique
class SampleBackend(str, Enum):
HF = "hf"
VLLM = "vllm"
def _convert_str_dict(data: dict) -> dict:
"""Parse string representation inside the dictionary.
Args:
data: The string or dictionary to convert.
Returns:
The converted dictionary.
"""
for key, value in data.items():
if isinstance(value, dict):
data[key] = _convert_str_dict(value)
elif isinstance(value, str):
if value.lower() in ("true", "false"):
data[key] = value.lower() == "true"
elif value.isdigit():
data[key] = int(value)
elif value.replace(".", "", 1).isdigit():
data[key] = float(value)
return data
def get_plugin_config(config: PluginArgument) -> PluginConfig | None:
"""Get the plugin configuration from the argument value.
Args:
config: The argument value to get the plugin configuration from.
Returns:
The plugin configuration.
"""
if config is None:
return None
if isinstance(config, str) and config.startswith("{"):
config = json.loads(config)
config = _convert_str_dict(config)
if "name" not in config:
raise ValueError("Plugin configuration must have a 'name' field.")
return PluginConfig(config)

View File

@@ -14,19 +14,14 @@
from dataclasses import dataclass, field
from typing import Optional
@dataclass
class DataArguments:
dataset: Optional[str] = field(
dataset: str | None = field(
default=None,
metadata={"help": "Path to the dataset."},
)
dataset_dir: str = field(
default="data",
metadata={"help": "Path to the folder containing the datasets."},
)
cutoff_len: int = field(
default=2048,
metadata={"help": "Cutoff length for the dataset."},

View File

@@ -15,6 +15,8 @@
from dataclasses import dataclass, field
from .arg_utils import ModelClass, PluginConfig, get_plugin_config
@dataclass
class ModelArguments:
@@ -25,3 +27,28 @@ class ModelArguments:
default=False,
metadata={"help": "Trust remote code from Hugging Face."},
)
use_fast_processor: bool = field(
default=True,
metadata={"help": "Use fast processor from Hugging Face."},
)
model_class: ModelClass = field(
default=ModelClass.LLM,
metadata={"help": "Model class from Hugging Face."},
)
peft_config: PluginConfig | None = field(
default=None,
metadata={"help": "PEFT configuration for the model."},
)
kernel_config: PluginConfig | None = field(
default=None,
metadata={"help": "Kernel configuration for the model."},
)
quant_config: PluginConfig | None = field(
default=None,
metadata={"help": "Quantization configuration for the model."},
)
def __post_init__(self) -> None:
self.peft_config = get_plugin_config(self.peft_config)
self.kernel_config = get_plugin_config(self.kernel_config)
self.quant_config = get_plugin_config(self.quant_config)

View File

@@ -15,9 +15,15 @@
from dataclasses import dataclass, field
from .arg_utils import SampleBackend
@dataclass
class SampleArguments:
sample_backend: SampleBackend = field(
default=SampleBackend.HF,
metadata={"help": "Sampling backend, default to 'hf'."},
)
max_new_tokens: int = field(
default=128,
metadata={"help": "Maximum number of new tokens to generate."},

View File

@@ -12,14 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from dataclasses import dataclass, field
from uuid import uuid4
from .arg_utils import PluginConfig, get_plugin_config
@dataclass
class TrainingArguments:
output_dir: str = field(
default="",
default=os.path.join("outputs", str(uuid4())),
metadata={"help": "Path to the output directory."},
)
micro_batch_size: int = field(
@@ -38,3 +41,10 @@ class TrainingArguments:
default=False,
metadata={"help": "Use bf16 for training."},
)
dist_config: PluginConfig | None = field(
default=None,
metadata={"help": "Distribution configuration for training."},
)
def __post_init__(self) -> None:
self.dist_config = get_plugin_config(self.dist_config)

View File

@@ -12,44 +12,45 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
"""The definition of trainer.
Init Phase:
1. Init dataloader.
2. Init optimizer (deepspeed).
3. Shard model.
4. Init optimizer (fsdp).
5. Init scheduler.
Train Phase:
1. Train Loop
"""
from ..config.training_args import TrainingArguments
from ..extras.types import Model, Processor, Tensor, TorchDataset
class DataCollator:
"""Default Data collator."""
def __init__(self, processor: Processor) -> None:
self.processor = processor
def __call__(self, features: list[dict[str, Any]]) -> dict[str, Tensor]:
"""Collate features into a batch."""
for feature in features:
pass
# sft: messages
# dpo: chosen_messages, rejected_messages
from ..utils.types import HFModel, Processor, TorchDataset
from .trainer_utils.data_collator import DataCollator
class BaseTrainer:
def __init__(
self,
args: TrainingArguments,
model: Model,
model: HFModel,
processor: Processor,
dataset: TorchDataset,
data_collator: DataCollator,
) -> None:
self.args = args
self.model = model
self.processor = processor
self.dataset = dataset
self.data_collator = data_collator
self.data_collator = DataCollator()
self.optimizer = None
self.lr_scheduler = None
def init_model_and_optimizer(self) -> None:
pass
def create_dataloader(self) -> None:
pass

View File

@@ -12,9 +12,33 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from ..config.sample_args import SampleArguments
from abc import ABC, abstractmethod
from ..config.sample_args import SampleArguments, SampleBackend
from .model_loader import ModelLoader
class BaseEngine(ABC):
@abstractmethod
def __init__(self, sample_args: SampleArguments, model_loader: ModelLoader) -> None: ...
@abstractmethod
async def generate(self):
pass
@abstractmethod
async def batch_infer(self):
pass
class HuggingFaceEngine(BaseEngine):
def __init__(self, model_loader: ModelLoader, sample_args: SampleArguments) -> None:
self.args = sample_args
class ChatSampler:
def __init__(self, sample_args: SampleArguments) -> None:
self.args = sample_args
def __init__(self, model_loader: ModelLoader, sample_args: SampleArguments) -> None:
if sample_args.sample_backend == SampleBackend.HF:
self.engine = HuggingFaceEngine(model_loader, sample_args)
else:
raise ValueError(f"Unknown sample backend: {sample_args.sample_backend}")

Some files were not shown because too many files have changed in this diff Show More