Mirror of https://github.com/hiyouga/LLaMA-Factory.git — synced 2025-12-28 01:30:36 +08:00

Compare commits: 9fd4b094d4 ... main (22 commits)
Commits: eceec8ab69, b44f651e09, 55590f5ece, a1b1931b4a, 3c17f2722c, a882e2d5fc, a754604c11, 6a2eafbae3, 84485406b7, 1c8a42d2f8, 7901b2f32e, 1f1f5a7d1b, 6ef9854713, 4923f52a28, 0894b4f37e, b0d49e137f, ddd7dcc722, 5204cd2bca, 8c74dca76a, e8deda53a1, a769fb94b9, 964569751f
.github/copilot-instructions.md (new file, 180 lines)

@@ -0,0 +1,180 @@
# GitHub Copilot Instructions for LLaMA Factory

## Project Overview

LLaMA Factory is an efficient fine-tuning framework for 100+ large language models (LLMs). It provides:

- Support for various models: LLaMA, LLaVA, Mistral, Qwen, DeepSeek, Yi, Gemma, ChatGLM, Phi, etc.
- Multiple training methods: pre-training, supervised fine-tuning, reward modeling, PPO, DPO, KTO, ORPO
- Scalable resources: 16-bit full-tuning, freeze-tuning, LoRA and QLoRA variants
- Advanced algorithms: GaLore, BAdam, APOLLO, Adam-mini, Muon, OFT, DoRA, etc.
- Web UI (LLaMA Board) and CLI interfaces

### Architecture Versions

LLaMA Factory has two parallel architectures that can be switched via the `USE_V1` environment variable:

**v0 (default)** - File hierarchy:

- `api`, `webui` → `chat`, `eval`, `train` → `data`, `model` → `hparams` → `extras`

**v1** - File hierarchy:

- `trainers` → `core` → `accelerator`, `plugins`, `config` → `utils`

Set `USE_V1=1` to enable the v1 architecture.
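For example, the same run can target either architecture from the command line (assuming the CLI entry point is shared by both; the config file is the LoRA SFT example referenced under Entry Points below):

```bash
# v0 (default) architecture
llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml

# same run against the v1 architecture
USE_V1=1 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
```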
## Code Structure

### v0 Architecture (Default)

- `src/llamafactory/` - Main package directory
  - `api/` - OpenAI-style API implementation
  - `chat/` - Chat interface implementation
  - `cli.py` - Command-line interface
  - `data/` - Data processing and dataset handling
  - `eval/` - Model evaluation utilities
  - `extras/` - Additional utilities and helpers
  - `hparams/` - Hyperparameter definitions
  - `model/` - Model loading, patching, and utilities
  - `train/` - Training pipeline implementation
  - `webui/` - Gradio-based web interface
- `src/train.py` - Training entry script (delegates to `llamafactory.train.tuner`)
- `src/webui.py` - Web UI entry script (delegates to `llamafactory.webui.interface`)
- `src/api.py` - API server entry script (delegates to `llamafactory.api.app`)
- `tests/` - Test suite
- `examples/` - Example configurations for various training scenarios
- `data/` - Dataset definitions and examples

### v1 Architecture (USE_V1=1)

- `src/llamafactory/v1/` - Version 1 package directory
  - `trainers/` - Training implementations
  - `core/` - Core training utilities
  - `accelerator/` - Acceleration and distributed training
  - `plugins/` - Pluggable components (model, data, sampler, trainer)
  - `config/` - Configuration management
  - `utils/` - Utility functions

## Development Practices

### Code Style

- Follow the [Google Python Style Guide](https://google.github.io/styleguide/pyguide.html)
- Use ruff for linting and formatting
- Line length: 119 characters
- Indentation: 4 spaces
- Quote style: double quotes
- Use Google-style docstrings for documentation

### Import Organization

- Known first-party: `llamafactory`
- Known third-party: `accelerate`, `datasets`, `gradio`, `numpy`, `peft`, `torch`, `transformers`, `trl`
- Use 2 blank lines after imports

### Quality Checks

Before committing code, run:

```bash
make style    # Auto-fix style issues
make quality  # Check code quality
make test     # Run test suite
```

Or use the combined command:

```bash
make commit   # Run pre-commit hooks
```

### Testing

- Use pytest for testing
- Tests are located in the `tests/` and `tests_v1/` directories
- Run tests with `make test` (which runs `WANDB_DISABLED=true pytest -vv --import-mode=importlib tests/ tests_v1/`)
- Disable wandb during testing to avoid external dependencies
- **Note**: Training configurations require GPU machines, so training is typically not tested end-to-end. Use `make test` to validate file-level functionality.
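During development it is often handy to run only part of the suite; a minimal sketch using standard pytest filtering (the `-k` expression is illustrative, not a test name from the repository):

```bash
# run a subset of the tests, keeping wandb disabled just as `make test` does
WANDB_DISABLED=true pytest -vv --import-mode=importlib tests/ -k "template"
```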
### Building

Build the package with:

```bash
pip3 install build && python3 -m build
```

### License

- All source files must include the Apache 2.0 license header
- Check license headers with: `make license`

## Common Patterns

### Configuration Files

- Training configurations are typically YAML or JSON files in the `examples/` directory
- Hyperparameters are defined using dataclasses in `src/llamafactory/hparams/`

### Model Support

- New model support is added through model patches in `src/llamafactory/model/`
- Visual models use the visual utilities in `src/llamafactory/model/model_utils/visual.py`
- Quantization support is in `src/llamafactory/model/model_utils/quantization.py`

### Data Processing

- Dataset definitions are in `data/dataset_info.json`
- Data templates and processors are in `src/llamafactory/data/`

### Training

- Training pipelines are in `src/llamafactory/train/`
- Support for different training methods: SFT, DPO, PPO, RM, PT, KTO, ORPO

## Key Dependencies

- Python >= 3.9.0
- PyTorch and transformers for model handling
- datasets for data processing
- peft for parameter-efficient fine-tuning
- accelerate for distributed training
- gradio for web UI
- trl for reinforcement learning
- Optional: vllm/sglang for inference, flash-attention-2, unsloth, liger-kernel
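To pull in an optional feature, install the matching extra from `pyproject.toml` or the corresponding pin file under `examples/requirements/`; for instance (the `pip install -r` form is one way to consume those pin files, shown here as an assumption):

```bash
pip install -e ".[metrics,deepspeed]"           # optional extras declared in pyproject.toml
pip install -r examples/requirements/vllm.txt   # feature-specific pins, e.g. vLLM
```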
## Entry Points

- **CLI Training**: `llamafactory-cli train --config examples/train_lora/llama3_lora_sft.yaml`
- **Web UI**: `llamafactory-cli webui` or `python src/webui.py`
- **API Server**: `llamafactory-cli api` or `python src/api.py`
- **Chat Interface**: `llamafactory-cli chat --model_name_or_path MODEL_PATH`
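A typical loop chains these entry points: fine-tune with one of the example configs, then chat with the trained weights. A rough sketch (the adapter path and the `--adapter_name_or_path`/`--template` flags are illustrative assumptions, not taken from this changeset):

```bash
llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
# load the base model together with the adapter produced above (paths are placeholders)
llamafactory-cli chat --model_name_or_path MODEL_PATH \
    --adapter_name_or_path saves/llama3-8b/lora/sft --template llama3
```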
## Environment Setup

For development:

```bash
pip install -e ".[dev]"
```

## Important Notes

- The project supports multiple backends: default PyTorch, vLLM, SGLang
- Megatron-core training is supported via mcore_adapter
- SwanLab and W&B are supported for experiment tracking
- Docker support is available with pre-built images
- Day-0/Day-1 support for latest cutting-edge models
- Multi-modal support for vision and audio understanding tasks

## Contribution Guidelines

1. Fork the repository
2. Create a development branch
3. Set up the development environment with `pip install -e ".[dev]"`
4. Make changes following the style guide
5. Run quality checks: `make style && make quality`
6. Run tests: `make test`
7. Submit a pull request
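Put together, a typical contribution pass looks roughly like this (the branch name is a placeholder):

```bash
git checkout -b my-feature
pip install -e ".[dev]"
# iterate on the change, then validate before opening the PR
make style && make quality
make test
```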
## Common Commands

- `make style` - Format code
- `make quality` - Run linters
- `make test` - Run tests
- `make commit` - Install and run pre-commit hooks
- `make license` - Check license headers
.github/workflows/docker.yml (32)

@@ -7,7 +7,7 @@ on:
- "main"
paths:
- "**/*.py"
- "requirements.txt"
- "pyproject.toml"
- "docker/**"
- ".github/workflows/*.yml"
pull_request:

@@ -15,7 +15,7 @@ on:
- "main"
paths:
- "**/*.py"
- "requirements.txt"
- "pyproject.toml"
- "docker/**"
- ".github/workflows/*.yml"
release:

@@ -29,16 +29,13 @@ jobs:
matrix:
include:
- device: "cuda"
npu_type: ""
- device: "npu"
npu_type: "a2"
- device: "npu"
npu_type: "a3"
- device: "npu-a2"
- device: "npu-a3"

runs-on: ubuntu-latest

concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ matrix.device }}-${{ matrix.npu_type }}
group: ${{ github.workflow }}-${{ github.ref }}-${{ matrix.device }}
cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}

environment:

@@ -55,16 +52,11 @@ jobs:
- name: Checkout
uses: actions/checkout@v4

- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: "3.10"

- name: Get llamafactory version
id: version
run: |
if [ "${{ github.event_name }}" = "release" ]; then
echo "tag=$(python setup.py --version)" >> "$GITHUB_OUTPUT"
echo "tag=$(grep -oP 'VERSION = "\K[^"]+' src/llamafactory/extras/env.py)" >> "$GITHUB_OUTPUT"
else
echo "tag=latest" >> "$GITHUB_OUTPUT"
fi

@@ -93,16 +85,12 @@ jobs:
with:
context: .
file: ./docker/docker-cuda/Dockerfile
build-args: |
EXTRAS=metrics,deepspeed,liger-kernel
push: ${{ github.event_name != 'pull_request' }}
tags: |
docker.io/hiyouga/llamafactory:${{ steps.version.outputs.tag }}
cache-from: type=gha
cache-to: type=gha,mode=max

- name: Build and push Docker image (NPU-A2)
if: ${{ matrix.device == 'npu' && matrix.npu_type == 'a2' }}
if: ${{ matrix.device == 'npu-a2' }}
uses: docker/build-push-action@v6
with:
context: .

@@ -112,11 +100,9 @@ jobs:
tags: |
docker.io/hiyouga/llamafactory:${{ steps.version.outputs.tag }}-npu-a2
quay.io/ascend/llamafactory:${{ steps.version.outputs.tag }}-npu-a2
cache-from: type=gha
cache-to: type=gha,mode=max

- name: Build and push Docker image (NPU-A3)
if: ${{ matrix.device == 'npu' && matrix.npu_type == 'a3' }}
if: ${{ matrix.device == 'npu-a3' }}
uses: docker/build-push-action@v6
with:
context: .

@@ -128,5 +114,3 @@ jobs:
tags: |
docker.io/hiyouga/llamafactory:${{ steps.version.outputs.tag }}-npu-a3
quay.io/ascend/llamafactory:${{ steps.version.outputs.tag }}-npu-a3
cache-from: type=gha
cache-to: type=gha,mode=max
.github/workflows/publish.yml (7)

@@ -23,10 +23,11 @@ jobs:
- name: Checkout
uses: actions/checkout@v4

- name: Set up Python
uses: actions/setup-python@v5
- name: Install uv
uses: astral-sh/setup-uv@v7
with:
python-version: "3.9"
python-version: "3.11"
github-token: ${{ github.token }}

- name: Build package
run: |
.github/workflows/tests.yml (38)

@@ -7,7 +7,7 @@ on:
- "main"
paths:
- "**/*.py"
- "requirements.txt"
- "pyproject.toml"
- "Makefile"
- ".github/workflows/*.yml"
pull_request:

@@ -15,7 +15,7 @@ on:
- "main"
paths:
- "**/*.py"
- "requirements.txt"
- "pyproject.toml"
- "Makefile"
- ".github/workflows/*.yml"

@@ -25,10 +25,9 @@ jobs:
fail-fast: false
matrix:
python:
- "3.9"
- "3.10"
- "3.11"
- "3.12"
# - "3.13" # enable after trl is upgraded
os:
- "ubuntu-latest"
- "windows-latest"

@@ -36,18 +35,15 @@ jobs:
transformers:
- null
include: # test backward compatibility
- python: "3.9"
- python: "3.11"
os: "ubuntu-latest"
transformers: "4.49.0"
- python: "3.9"
- python: "3.11"
os: "ubuntu-latest"
transformers: "4.51.0"
- python: "3.9"
- python: "3.11"
os: "ubuntu-latest"
transformers: "4.53.0"
exclude: # exclude python 3.9 on macos
- python: "3.9"
os: "macos-latest"

runs-on: ${{ matrix.os }}

@@ -63,22 +59,23 @@ jobs:
- name: Checkout
uses: actions/checkout@v4

- name: Set up Python
uses: actions/setup-python@v5
- name: Install uv
uses: astral-sh/setup-uv@v7
with:
python-version: ${{ matrix.python }}
cache: "pip"
cache-dependency-path: "**/requirements*.txt"
github-token: ${{ github.token }}
enable-cache: false

- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install ".[torch,dev]"
uv venv
uv pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
uv pip install -e ".[dev]"

- name: Install transformers
if: ${{ matrix.transformers }}
run: |
python -m pip install "transformers==${{ matrix.transformers }}"
uv pip install "transformers==${{ matrix.transformers }}"

- name: Cache files
id: hf-hub-cache

@@ -90,18 +87,25 @@ jobs:
- name: Check quality
run: |
make style && make quality
env:
UV_NO_SYNC: 1

- name: Check license
run: |
make license
env:
UV_NO_SYNC: 1

- name: Check build
run: |
make build
env:
UV_NO_SYNC: 1

- name: Test with pytest
run: |
make test
env:
UV_NO_SYNC: 1
HF_HOME: ${{ runner.temp }}/huggingface
HF_HUB_OFFLINE: "${{ steps.hf-hub-cache.outputs.cache-hit == 'true' && '1' || '0' }}"
.github/workflows/tests_npu.yml (20)

@@ -7,7 +7,7 @@ on:
- "main"
paths:
- "**/*.py"
- "requirements.txt"
- "pyproject.toml"
- "Makefile"
- ".github/workflows/*.yml"
pull_request:

@@ -15,7 +15,7 @@ on:
- "main"
paths:
- "**/*.py"
- "requirements.txt"
- "pyproject.toml"
- "Makefile"
- ".github/workflows/*.yml"

@@ -48,10 +48,15 @@ jobs:
- name: Checkout
uses: actions/checkout@v4

- name: Install uv
run: |
curl -LsSf https://astral.sh/uv/install.sh | sh

- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install ".[torch-npu,dev]" torch-npu==${{matrix.pytorch_npu}}
uv venv
uv pip install torch-npu==${{matrix.pytorch_npu}}
uv pip install -e ".[dev]"

- name: Install node
run: |

@@ -70,18 +75,25 @@ jobs:
- name: Check quality
run: |
make style && make quality
env:
UV_NO_SYNC: 1

- name: Check license
run: |
make license
env:
UV_NO_SYNC: 1

- name: Check build
run: |
make build
env:
UV_NO_SYNC: 1

- name: Test with pytest
run: |
make test
env:
UV_NO_SYNC: 1
HF_HOME: /root/.cache/huggingface
HF_HUB_OFFLINE: "${{ steps.hf-hub-cache.outputs.cache-hit == 'true' && '1' || '0' }}"
.gitignore (2)

@@ -85,7 +85,7 @@ ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.

@@ -1 +1 @@
include LICENSE requirements.txt
include LICENSE
Makefile (24)

@@ -1,24 +1,28 @@
.PHONY: build commit license quality style test

check_dirs := scripts src tests tests_v1 setup.py
check_dirs := scripts src tests tests_v1

RUN := $(shell command -v uv >/dev/null 2>&1 && echo "uv run" || echo "")
BUILD := $(shell command -v uv >/dev/null 2>&1 && echo "uv build" || echo "python -m build")
TOOL := $(shell command -v uv >/dev/null 2>&1 && echo "uvx" || echo "")

build:
	pip3 install build && python3 -m build
	$(BUILD)

commit:
	pre-commit install
	pre-commit run --all-files
	$(TOOL) pre-commit install
	$(TOOL) pre-commit run --all-files

license:
	python3 tests/check_license.py $(check_dirs)
	$(RUN) python3 tests/check_license.py $(check_dirs)

quality:
	ruff check $(check_dirs)
	ruff format --check $(check_dirs)
	$(TOOL) ruff check $(check_dirs)
	$(TOOL) ruff format --check $(check_dirs)

style:
	ruff check $(check_dirs) --fix
	ruff format $(check_dirs)
	$(TOOL) ruff check $(check_dirs) --fix
	$(TOOL) ruff format $(check_dirs)

test:
	WANDB_DISABLED=true pytest -vv --import-mode=importlib tests/ tests_v1/
	WANDB_DISABLED=true $(RUN) pytest -vv --import-mode=importlib tests/ tests_v1/
README.md (49)

@@ -278,27 +278,21 @@ Read technical notes:

| Model | Model size | Template |
| --- | --- | --- |
| [Baichuan 2](https://huggingface.co/baichuan-inc) | 7B/13B | baichuan2 |
| [BLOOM/BLOOMZ](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - |
| [ChatGLM3](https://huggingface.co/THUDM) | 6B | chatglm3 |
| [Command R](https://huggingface.co/CohereForAI) | 35B/104B | cohere |
| [DeepSeek (Code/MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B/236B | deepseek |
| [DeepSeek 2.5/3](https://huggingface.co/deepseek-ai) | 236B/671B | deepseek3 |
| [DeepSeek (LLM/Code/MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B/236B | deepseek |
| [DeepSeek 3-3.2](https://huggingface.co/deepseek-ai) | 236B/671B | deepseek3 |
| [DeepSeek R1 (Distill)](https://huggingface.co/deepseek-ai) | 1.5B/7B/8B/14B/32B/70B/671B | deepseekr1 |
| [ERNIE-4.5](https://huggingface.co/baidu) | 0.3B/21B/300B | ernie/ernie_nothink |
| [Falcon](https://huggingface.co/tiiuae) | 7B/11B/40B/180B | falcon |
| [Falcon-H1](https://huggingface.co/tiiuae) | 0.5B/1.5B/3B/7B/34B | falcon_h1 |
| [Falcon/Falcon H1](https://huggingface.co/tiiuae) | 0.5B/1.5B/3B/7B/11B/34B/40B/180B | falcon/falcon_h1 |
| [Gemma/Gemma 2/CodeGemma](https://huggingface.co/google) | 2B/7B/9B/27B | gemma/gemma2 |
| [Gemma 3/Gemma 3n](https://huggingface.co/google) | 270M/1B/4B/6B/8B/12B/27B | gemma3/gemma3n |
| [GLM-4/GLM-4-0414/GLM-Z1](https://huggingface.co/zai-org) | 9B/32B | glm4/glmz1 |
| [GLM-4.1V](https://huggingface.co/zai-org) | 9B | glm4v |
| [GLM-4.5/GLM-4.5(6)V](https://huggingface.co/zai-org) | 9B/106B/355B | glm4_moe/glm4_5v |
| [GPT-2](https://huggingface.co/openai-community) | 0.1B/0.4B/0.8B/1.5B | - |
| [GPT-OSS](https://huggingface.co/openai) | 20B/120B | gpt |
| [Granite 3.0-3.3](https://huggingface.co/ibm-granite) | 1B/2B/3B/8B | granite3 |
| [Granite 4](https://huggingface.co/ibm-granite) | 7B | granite4 |
| [GPT-OSS](https://huggingface.co/openai) | 20B/120B | gpt_oss |
| [Granite 3-4](https://huggingface.co/ibm-granite) | 1B/2B/3B/7B/8B | granite3/granite4 |
| [Hunyuan (MT)](https://huggingface.co/tencent/) | 7B | hunyuan |
| [Index](https://huggingface.co/IndexTeam) | 1.9B | index |
| [InternLM 2-3](https://huggingface.co/internlm) | 7B/8B/20B | intern2 |
| [InternVL 2.5-3.5](https://huggingface.co/OpenGVLab) | 1B/2B/4B/8B/14B/30B/38B/78B/241B | intern_vl |
| [InternLM/Intern-S1-mini](https://huggingface.co/internlm/) | 8B | intern_s1 |

@@ -312,15 +306,13 @@ Read technical notes:
| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | llava |
| [LLaVA-NeXT](https://huggingface.co/llava-hf) | 7B/8B/13B/34B/72B/110B | llava_next |
| [LLaVA-NeXT-Video](https://huggingface.co/llava-hf) | 7B/34B | llava_next_video |
| [MiMo](https://huggingface.co/XiaomiMiMo) | 7B | mimo |
| [MiMo](https://huggingface.co/XiaomiMiMo) | 7B/309B | mimo/mimo_v2 |
| [MiniCPM 1-4.1](https://huggingface.co/openbmb) | 0.5B/1B/2B/4B/8B | cpm/cpm3/cpm4 |
| [MiniCPM-o-2.6/MiniCPM-V-2.6](https://huggingface.co/openbmb) | 8B | minicpm_o/minicpm_v |
| [Ministral(3)/Mistral-Nemo](https://huggingface.co/mistralai) | 3B/8B/12B/14B | ministral/ministral3 |
| [Ministral 3](https://huggingface.co/mistralai) | 3B/8B/14B | ministral3 |
| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
| [Mistral Small](https://huggingface.co/mistralai) | 24B | mistral_small |
| [OLMo](https://huggingface.co/allenai) | 1B/7B | - |
| [PaliGemma/PaliGemma2](https://huggingface.co/google) | 3B/10B/28B | paligemma |
| [Phi-1.5/Phi-2](https://huggingface.co/microsoft) | 1.3B/2.7B | - |
| [Phi-3/Phi-3.5](https://huggingface.co/microsoft) | 4B/14B | phi |
| [Phi-3-small](https://huggingface.co/microsoft) | 7B | phi_small |
| [Phi-4](https://huggingface.co/microsoft) | 14B | phi4 |

@@ -333,13 +325,9 @@ Read technical notes:
| [Qwen2-VL/Qwen2.5-VL/QVQ](https://huggingface.co/Qwen) | 2B/3B/7B/32B/72B | qwen2_vl |
| [Qwen3-VL](https://huggingface.co/Qwen) | 2B/4B/8B/30B/32B/235B | qwen3_vl |
| [Seed (OSS/Coder)](https://huggingface.co/ByteDance-Seed) | 8B/36B | seed_oss/seed_coder |
| [Skywork o1](https://huggingface.co/Skywork) | 8B | skywork_o1 |
| [StarCoder 2](https://huggingface.co/bigcode) | 3B/7B/15B | - |
| [TeleChat2](https://huggingface.co/Tele-AI) | 3B/7B/35B/115B | telechat2 |
| [VibeThinker-1.5B](https://huggingface.co/WeiboAI) | 1.5B | qwen3 |
| [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | xverse |
| [Yi/Yi-1.5 (Code)](https://huggingface.co/01-ai) | 1.5B/6B/9B/34B | yi |
| [Yi-VL](https://huggingface.co/01-ai) | 6B/34B | yi_vl |
| [Yuan 2](https://huggingface.co/IEITYuan) | 2B/51B/102B | yuan |

> [!NOTE]

@@ -526,10 +514,12 @@ huggingface-cli login
```bash
git clone --depth 1 https://github.com/hiyouga/LLaMA-Factory.git
cd LLaMA-Factory
pip install -e ".[torch,metrics]" --no-build-isolation
pip install -e ".[metrics]" --no-build-isolation
```

Extra dependencies available: torch, torch-npu, metrics, deepspeed, liger-kernel, bitsandbytes, hqq, eetq, gptq, aqlm, vllm, sglang, galore, apollo, badam, adam-mini, qwen, minicpm_v, openmind, swanlab, dev
Optional dependencies available: `metrics`, `deepspeed`. Install with: `pip install -e ".[metrics,deepspeed]"`

Additional dependencies for specific features are available in `examples/requirements/`.

#### Install from Docker Image

@@ -548,13 +538,7 @@ Please refer to [build docker](#build-docker) to build the image yourself.
Create an isolated Python environment with [uv](https://github.com/astral-sh/uv):

```bash
uv sync --extra torch --extra metrics --prerelease=allow
```

Run LLaMA-Factory in the isolated environment:

```bash
uv run --prerelease=allow llamafactory-cli train examples/train_lora/llama3_lora_pretrain.yaml
uv run llamafactory-cli webui
```

</details>

@@ -591,7 +575,7 @@ To enable FlashAttention-2 on the Windows platform, please use the script from [

<details><summary>For Ascend NPU users</summary>

To install LLaMA Factory on Ascend NPU devices, please upgrade Python to version 3.10 or higher and specify extra dependencies: `pip install -e ".[torch-npu,metrics]"`. Additionally, you need to install the **[Ascend CANN Toolkit and Kernels](https://www.hiascend.com/developer/download/community/result?module=cann)**. Please follow the [installation tutorial](https://www.hiascend.com/document/detail/en/CANNCommunityEdition/600alphaX/softwareinstall/instg/atlasdeploy_03_0031.html) or use the following commands:
To install LLaMA Factory on Ascend NPU devices, please upgrade Python to version 3.10 or higher: `pip install -e . torch-npu==2.7.1`. Additionally, you need to install the **[Ascend CANN Toolkit and Kernels](https://www.hiascend.com/developer/download/community/result?module=cann)**. Please follow the [installation tutorial](https://www.hiascend.com/document/detail/en/CANNCommunityEdition/600alphaX/softwareinstall/instg/atlasdeploy_03_0031.html) or use the following commands:

```bash
# replace the url according to your CANN version and devices
```

@@ -610,8 +594,8 @@ source /usr/local/Ascend/ascend-toolkit/set_env.sh
| Requirement | Minimum | Recommend |
| --- | --- | --- |
| CANN | 8.0.RC1 | 8.0.0.alpha002 |
| torch | 2.1.0 | 2.4.0 |
| torch-npu | 2.1.0 | 2.4.0.post2 |
| torch | 2.1.0 | 2.7.1 |
| torch-npu | 2.1.0 | 2.7.1 |
| deepspeed | 0.13.2 | 0.13.2 |
| vllm-ascend | - | 0.7.3 |

@@ -726,7 +710,6 @@ For CUDA users:
```bash
docker build -f ./docker/docker-cuda/Dockerfile \
    --build-arg PIP_INDEX=https://pypi.org/simple \
    --build-arg EXTRAS=metrics \
    -t llamafactory:latest .

docker run -dit --ipc=host --gpus=all \
```

@@ -743,7 +726,6 @@ For Ascend NPU users:
```bash
docker build -f ./docker/docker-npu/Dockerfile \
    --build-arg PIP_INDEX=https://pypi.org/simple \
    --build-arg EXTRAS=torch-npu,metrics \
    -t llamafactory:latest .

docker run -dit --ipc=host \
```

@@ -768,7 +750,6 @@ For AMD ROCm users:
```bash
docker build -f ./docker/docker-rocm/Dockerfile \
    --build-arg PIP_INDEX=https://pypi.org/simple \
    --build-arg EXTRAS=metrics \
    -t llamafactory:latest .

docker run -dit --ipc=host \
```
README_zh.md (46)

@@ -280,27 +280,21 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc

| Model | Model size | Template |
| --- | --- | --- |
| [Baichuan 2](https://huggingface.co/baichuan-inc) | 7B/13B | baichuan2 |
| [BLOOM/BLOOMZ](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - |
| [ChatGLM3](https://huggingface.co/THUDM) | 6B | chatglm3 |
| [Command R](https://huggingface.co/CohereForAI) | 35B/104B | cohere |
| [DeepSeek (Code/MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B/236B | deepseek |
| [DeepSeek 2.5/3](https://huggingface.co/deepseek-ai) | 236B/671B | deepseek3 |
| [DeepSeek (LLM/Code/MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B/236B | deepseek |
| [DeepSeek 3-3.2](https://huggingface.co/deepseek-ai) | 236B/671B | deepseek3 |
| [DeepSeek R1 (Distill)](https://huggingface.co/deepseek-ai) | 1.5B/7B/8B/14B/32B/70B/671B | deepseekr1 |
| [ERNIE-4.5](https://huggingface.co/baidu) | 0.3B/21B/300B | ernie/ernie_nothink |
| [Falcon](https://huggingface.co/tiiuae) | 7B/11B/40B/180B | falcon |
| [Falcon-H1](https://huggingface.co/tiiuae) | 0.5B/1.5B/3B/7B/34B | falcon_h1 |
| [Falcon/Falcon H1](https://huggingface.co/tiiuae) | 0.5B/1.5B/3B/7B/11B/34B/40B/180B | falcon/falcon_h1 |
| [Gemma/Gemma 2/CodeGemma](https://huggingface.co/google) | 2B/7B/9B/27B | gemma/gemma2 |
| [Gemma 3/Gemma 3n](https://huggingface.co/google) | 270M/1B/4B/6B/8B/12B/27B | gemma3/gemma3n |
| [GLM-4/GLM-4-0414/GLM-Z1](https://huggingface.co/zai-org) | 9B/32B | glm4/glmz1 |
| [GLM-4.1V](https://huggingface.co/zai-org) | 9B | glm4v |
| [GLM-4.5/GLM-4.5(6)V](https://huggingface.co/zai-org) | 9B/106B/355B | glm4_moe/glm4_5v |
| [GPT-2](https://huggingface.co/openai-community) | 0.1B/0.4B/0.8B/1.5B | - |
| [GPT-OSS](https://huggingface.co/openai) | 20B/120B | gpt |
| [Granite 3.0-3.3](https://huggingface.co/ibm-granite) | 1B/2B/3B/8B | granite3 |
| [Granite 4](https://huggingface.co/ibm-granite) | 7B | granite4 |
| [GPT-OSS](https://huggingface.co/openai) | 20B/120B | gpt_oss |
| [Granite 3-4](https://huggingface.co/ibm-granite) | 1B/2B/3B/7B/8B | granite3/granite4 |
| [Hunyuan (MT)](https://huggingface.co/tencent/) | 7B | hunyuan |
| [Index](https://huggingface.co/IndexTeam) | 1.9B | index |
| [InternLM 2-3](https://huggingface.co/internlm) | 7B/8B/20B | intern2 |
| [InternVL 2.5-3.5](https://huggingface.co/OpenGVLab) | 1B/2B/4B/8B/14B/30B/38B/78B/241B | intern_vl |
| [InternLM/Intern-S1-mini](https://huggingface.co/internlm/) | 8B | intern_s1 |

@@ -314,15 +308,13 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | llava |
| [LLaVA-NeXT](https://huggingface.co/llava-hf) | 7B/8B/13B/34B/72B/110B | llava_next |
| [LLaVA-NeXT-Video](https://huggingface.co/llava-hf) | 7B/34B | llava_next_video |
| [MiMo](https://huggingface.co/XiaomiMiMo) | 7B | mimo |
| [MiMo](https://huggingface.co/XiaomiMiMo) | 7B/309B | mimo/mimo_v2 |
| [MiniCPM 1-4.1](https://huggingface.co/openbmb) | 0.5B/1B/2B/4B/8B | cpm/cpm3/cpm4 |
| [MiniCPM-o-2.6/MiniCPM-V-2.6](https://huggingface.co/openbmb) | 8B | minicpm_o/minicpm_v |
| [Ministral(3)/Mistral-Nemo](https://huggingface.co/mistralai) | 3B/8B/12B/14B | ministral/ministral3 |
| [Ministral 3](https://huggingface.co/mistralai) | 3B/8B/14B | ministral3 |
| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
| [Mistral Small](https://huggingface.co/mistralai) | 24B | mistral_small |
| [OLMo](https://huggingface.co/allenai) | 1B/7B | - |
| [PaliGemma/PaliGemma2](https://huggingface.co/google) | 3B/10B/28B | paligemma |
| [Phi-1.5/Phi-2](https://huggingface.co/microsoft) | 1.3B/2.7B | - |
| [Phi-3/Phi-3.5](https://huggingface.co/microsoft) | 4B/14B | phi |
| [Phi-3-small](https://huggingface.co/microsoft) | 7B | phi_small |
| [Phi-4](https://huggingface.co/microsoft) | 14B | phi4 |

@@ -335,13 +327,9 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
| [Qwen2-VL/Qwen2.5-VL/QVQ](https://huggingface.co/Qwen) | 2B/3B/7B/32B/72B | qwen2_vl |
| [Qwen3-VL](https://huggingface.co/Qwen) | 2B/4B/8B/30B/32B/235B | qwen3_vl |
| [Seed (OSS/Coder)](https://huggingface.co/ByteDance-Seed) | 8B/36B | seed_oss/seed_coder |
| [Skywork o1](https://huggingface.co/Skywork) | 8B | skywork_o1 |
| [StarCoder 2](https://huggingface.co/bigcode) | 3B/7B/15B | - |
| [TeleChat2](https://huggingface.co/Tele-AI) | 3B/7B/35B/115B | telechat2 |
| [VibeThinker-1.5B](https://huggingface.co/WeiboAI) | 1.5B | qwen3 |
| [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | xverse |
| [Yi/Yi-1.5 (Code)](https://huggingface.co/01-ai) | 1.5B/6B/9B/34B | yi |
| [Yi-VL](https://huggingface.co/01-ai) | 6B/34B | yi_vl |
| [Yuan 2](https://huggingface.co/IEITYuan) | 2B/51B/102B | yuan |

> [!NOTE]

@@ -528,10 +516,12 @@ huggingface-cli login
```bash
git clone --depth 1 https://github.com/hiyouga/LLaMA-Factory.git
cd LLaMA-Factory
pip install -e ".[torch,metrics]" --no-build-isolation
pip install -e ".[metrics]" --no-build-isolation
```

Extra dependencies available: torch, torch-npu, metrics, deepspeed, liger-kernel, bitsandbytes, hqq, eetq, gptq, aqlm, vllm, sglang, galore, apollo, badam, adam-mini, qwen, minicpm_v, openmind, swanlab, dev
Optional extra dependencies: `metrics`, `deepspeed`. Install with `pip install -e ".[metrics,deepspeed]"`.

For other optional dependencies, see the files under `examples/requirements/`.

#### Install from Docker Image

@@ -550,13 +540,7 @@ docker run -it --rm --gpus=all --ipc=host hiyouga/llamafactory:latest
Create an isolated Python environment with [uv](https://github.com/astral-sh/uv):

```bash
uv sync --extra torch --extra metrics --prerelease=allow
```

Run LLaMA-Factory in the isolated environment:

```bash
uv run --prerelease=allow llamafactory-cli train examples/train_lora/llama3_lora_pretrain.yaml
uv run llamafactory-cli webui
```

</details>

@@ -593,7 +577,7 @@ pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/downl

<details><summary>For Ascend NPU users</summary>

To install LLaMA Factory on Ascend NPU devices, please upgrade Python to version 3.10 or higher and specify extra dependencies with `pip install -e ".[torch-npu,metrics]"`. In addition, you need to install the **[Ascend CANN Toolkit and Kernels](https://www.hiascend.com/developer/download/community/result?module=cann)**; please follow the [installation tutorial](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/80RC2alpha002/quickstart/quickstart/quickstart_18_0004.html) or use the following commands:
To install LLaMA Factory on Ascend NPU devices, please upgrade Python to version 3.10 or higher and install with `pip install -e . torch-npu==2.7.1`. In addition, you need to install the **[Ascend CANN Toolkit and Kernels](https://www.hiascend.com/developer/download/community/result?module=cann)**; please follow the [installation tutorial](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/80RC2alpha002/quickstart/quickstart/quickstart_18_0004.html) or use the following commands:

```bash
# replace the URL with the one matching your CANN version and device
```

@@ -612,8 +596,8 @@ source /usr/local/Ascend/ascend-toolkit/set_env.sh
| Requirement | Minimum | Recommended |
| --- | --- | --- |
| CANN | 8.0.RC1 | 8.0.0.alpha002 |
| torch | 2.1.0 | 2.4.0 |
| torch-npu | 2.1.0 | 2.4.0.post2 |
| torch | 2.1.0 | 2.7.1 |
| torch-npu | 2.1.0 | 2.7.1 |
| deepspeed | 0.13.2 | 0.13.2 |
| vllm-ascend | - | 0.7.3 |
@@ -4,7 +4,6 @@ FROM ${BASE_IMAGE}
# Installation arguments
ARG PIP_INDEX=https://pypi.org/simple
ARG EXTRAS=metrics
ARG INSTALL_FLASHATTN=false
ARG HTTP_PROXY=""

@@ -27,17 +26,13 @@ WORKDIR /app
# Change pip source
RUN pip config set global.index-url "${PIP_INDEX}" && \
    pip config set global.extra-index-url "${PIP_INDEX}" && \
    pip install --no-cache-dir --upgrade pip packaging wheel setuptools
    pip install --no-cache-dir --upgrade pip packaging wheel setuptools editables "hatchling>=1.18.0"

# Install the requirements
COPY requirements.txt /app
RUN pip install --no-cache-dir -r requirements.txt

# Copy the rest of the application into the image
# Copy the application into the image
COPY . /app

# Install LLaMA Factory
RUN pip install --no-cache-dir -e ".[${EXTRAS}]" --no-build-isolation
RUN pip install --no-cache-dir --no-build-isolation -e ".[metrics,deepspeed]"

# Rebuild flash attention
RUN if [ "${INSTALL_FLASHATTN}" == "true" ]; then \

@@ -8,7 +8,7 @@ ENV PYPI_MIRROR=https://mirrors.aliyun.com/pypi/simple/
ENV PYPI_TRUSTED_HOST=mirrors.aliyun.com
ENV APT_MIRROR=https://mirrors.tuna.tsinghua.edu.cn/ubuntu/

RUN pip install --upgrade pip setuptools wheel --trusted-host ${PYPI_TRUSTED_HOST} --index-url ${PYPI_MIRROR}
RUN pip install --upgrade pip setuptools wheel "hatchling>=1.18.0" editables --trusted-host ${PYPI_TRUSTED_HOST} --index-url ${PYPI_MIRROR}

RUN pip uninstall -y torch torchvision torch-tensorrt \
    flash_attn transformer-engine \

@@ -56,14 +56,14 @@ ENV JAVA_HOME /usr/lib/jvm/java-21-openjdk-amd64
# pip install LLaMA-Factory
WORKDIR /app

COPY requirements.txt /app/
RUN pip install --no-cache-dir -r requirements.txt
# Copy the application into the image
COPY . /app

# Install LLaMA Factory
RUN pip install --no-cache-dir -e ".[metrics]" --no-build-isolation

RUN pip install "git+https://github.com/alibaba/roll.git#subdirectory=mcore_adapter"

COPY . /app/
RUN pip install -e ".[metrics]" --no-build-isolation

# Expose port 7860 for LLaMA Board
ENV GRADIO_SERVER_PORT=7860
EXPOSE 7860

@@ -5,7 +5,6 @@ services:
context: ../..
args:
PIP_INDEX: https://pypi.org/simple
EXTRAS: metrics
container_name: llamafactory
ports:
- "7860:7860"

@@ -5,7 +5,6 @@ FROM ${BASE_IMAGE}
# Installation arguments
ARG PIP_INDEX=https://pypi.org/simple
ARG EXTRAS=torch-npu,metrics
ARG HTTP_PROXY=""
ARG PYTORCH_INDEX=https://download.pytorch.org/whl/cpu

@@ -28,21 +27,15 @@ WORKDIR /app
# Change pip source
RUN pip config set global.index-url "${PIP_INDEX}" && \
    pip config set global.extra-index-url "${PIP_INDEX}" && \
    pip install --no-cache-dir --upgrade pip packaging wheel setuptools
    pip install --no-cache-dir --upgrade pip packaging wheel setuptools editables "hatchling>=1.18.0"

# Copy the application into the image
COPY . /app

# Install torch-npu
RUN pip uninstall -y torch torchvision torchaudio && \
    pip install --no-cache-dir "torch==2.7.1" "torch-npu==2.7.1" "torchvision==0.22.1" --index-url "${PYTORCH_INDEX}"

# Install the requirements
COPY requirements.txt /app
RUN pip install --no-cache-dir -r requirements.txt

# Copy the rest of the application into the image
COPY . /app

# Install LLaMA Factory
RUN pip install --no-cache-dir -e ".[${EXTRAS}]" --no-build-isolation
    pip install --no-cache-dir "torch==2.7.1" "torch-npu==2.7.1" "torchvision==0.22.1" "torchaudio==2.7.1" --index-url "${PYTORCH_INDEX}" && \
    pip install --no-cache-dir -e ".[metrics]" --no-build-isolation

# Set up volumes
# VOLUME [ "/root/.cache/huggingface", "/app/shared_data", "/app/output" ]

@@ -5,7 +5,6 @@ services:
context: ../..
args:
PIP_INDEX: https://pypi.org/simple
EXTRAS: torch-npu,metrics
container_name: llamafactory-a2
image: llamafactory:npu-a2
volumes:

@@ -36,7 +35,6 @@ services:
args:
BASE_IMAGE: quay.io/ascend/cann:8.3.rc2-a3-ubuntu22.04-py3.11
PIP_INDEX: https://pypi.org/simple
EXTRAS: torch-npu,metrics
container_name: llamafactory-a3
image: llamafactory:npu-a3
volumes:

@@ -4,7 +4,6 @@ FROM ${BASE_IMAGE}
# Installation arguments
ARG PIP_INDEX=https://pypi.org/simple
ARG EXTRAS=metrics
ARG INSTALL_FLASHATTN=false
ARG HTTP_PROXY=""
ARG PYTORCH_INDEX=https://download.pytorch.org/whl/rocm6.3

@@ -28,21 +27,14 @@ WORKDIR /app
# Change pip source
RUN pip config set global.index-url "${PIP_INDEX}" && \
    pip config set global.extra-index-url "${PIP_INDEX}" && \
    pip install --no-cache-dir --upgrade pip packaging wheel setuptools
    pip install --no-cache-dir --upgrade pip packaging wheel setuptools editables "hatchling>=1.18.0"

# Reinstall pytorch rocm
RUN pip uninstall -y torch torchvision torchaudio && \
    pip install --no-cache-dir --pre torch torchvision torchaudio --index-url "${PYTORCH_INDEX}"

# Install the requirements
COPY requirements.txt /app
RUN pip install --no-cache-dir -r requirements.txt

# Copy the rest of the application into the image
# Copy the application into the image
COPY . /app

# Install LLaMA Factory
RUN pip install --no-cache-dir -e ".[${EXTRAS}]" --no-build-isolation
# Reinstall pytorch rocm and install LLaMA Factory
RUN pip uninstall -y torch torchvision torchaudio && \
    pip install --no-cache-dir --no-build-isolation -e --pre ".[metrics,deepspeed]" --index-url "${PYTORCH_INDEX}"

# Rebuild flash attention
RUN if [ "${INSTALL_FLASHATTN}" == "true" ]; then \

@@ -5,7 +5,6 @@ services:
context: ../..
args:
PIP_INDEX: https://pypi.org/simple
EXTRAS: metrics
container_name: llamafactory
ports:
- "7860:7860"
examples/ascend/qwen3_full_sft_fsdp2.yaml (new file, 45 lines)

@@ -0,0 +1,45 @@
# Start FSDP2 fine-tuning
# accelerate launch \
#   --config_file examples/accelerate/fsdp2_config.yaml \
#   src/train.py examples/ascend/qwen3_full_sft_fsdp2.yaml
# Change `num_processes` in fsdp2_config.yaml to 16 in A3

### model
model_name_or_path: Qwen/Qwen3-8B
trust_remote_code: true
use_v1_kernels: true
flash_attn: fa2

### method
stage: sft
do_train: true
finetuning_type: full

### dataset
dataset: alpaca_en_demo
template: qwen3
cutoff_len: 2048
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16
dataloader_num_workers: 4

### output
output_dir: saves/Qwen3-8B/full/sft
logging_steps: 1
save_steps: 500
max_steps: 500
plot_loss: true
overwrite_output_dir: true
save_only_model: false
report_to: none  # choices: [none, wandb, tensorboard, swanlab, mlflow]

### train
per_device_train_batch_size: 8
gradient_accumulation_steps: 1
learning_rate: 1.0e-5
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
ddp_timeout: 1800
resume_from_checkpoint: null

examples/ascend/qwen3moe_full_sft_fsdp.yaml (new file, 46 lines)

@@ -0,0 +1,46 @@
# Start FSDP fine-tuning
# accelerate launch \
#   --config_file examples/accelerate/fsdp_config.yaml \
#   src/train.py examples/ascend/qwen3moe_full_sft_fsdp.yaml
# Change `num_processes` in fsdp_config.yaml to 16 in A3

### model
model_name_or_path: Qwen/Qwen3-30B-A3B-Instruct-2507
trust_remote_code: true
use_v1_kernels: true
flash_attn: fa2

### method
stage: sft
do_train: true
finetuning_type: full
disable_gradient_checkpointing: false

### dataset
dataset: alpaca_zh
template: qwen3
cutoff_len: 1024
overwrite_cache: true
preprocessing_num_workers: 16
dataloader_num_workers: 4

### output
output_dir: saves/Qwen3-30B-A3B-Instruct-2507/full/sft
logging_steps: 1
save_steps: 500
max_steps: 500
plot_loss: true
overwrite_output_dir: true
save_only_model: true
report_to: none  # choices: [none, wandb, tensorboard, swanlab, mlflow]

### train
per_device_train_batch_size: 4
gradient_accumulation_steps: 1
learning_rate: 1.0e-4
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
ddp_timeout: 180000000
resume_from_checkpoint: null
seed: 1234

examples/ascend/qwen3vlmoe_full_sft_fsdp2.yaml (new file, 48 lines)

@@ -0,0 +1,48 @@
# Start FSDP2 fine-tuning
# accelerate launch \
#   --config_file examples/accelerate/fsdp2_config.yaml \
#   src/train.py examples/ascend/qwen3vlmoe_full_sft_fsdp2.yaml
# Change `num_processes` in fsdp2_config.yaml to 16 in A3

### model
model_name_or_path: Qwen/Qwen3-VL-30B-A3B-Instruct
image_max_pixels: 262144
video_max_pixels: 16384
trust_remote_code: true
use_v1_kernels: true
flash_attn: fa2

### method
stage: sft
do_train: true
finetuning_type: full
disable_gradient_checkpointing: false

### dataset
dataset: llava_1k_en, llava_1k_zh
template: qwen3_vl
cutoff_len: 1024
overwrite_cache: true
preprocessing_num_workers: 16
dataloader_num_workers: 4

### output
output_dir: saves/Qwen3-VL-30B-A3B-Instruct/full/sft
logging_steps: 1
save_steps: 500
max_steps: 500
plot_loss: true
overwrite_output_dir: true
save_only_model: true
report_to: none  # choices: [none, wandb, tensorboard, swanlab, mlflow]

### train
per_device_train_batch_size: 2
gradient_accumulation_steps: 1
learning_rate: 1.0e-4
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
ddp_timeout: 180000000
resume_from_checkpoint: null
seed: 1234

New pin files added under examples/requirements/ (each is a new file containing only the listed requirements):

- adam-mini.txt: adam-mini
- apollo.txt: apollo-torch
- aqlm.txt: aqlm[gpu]>=1.1.0
- badam.txt: badam>=1.2.1
- bitsandbytes.txt: bitsandbytes>=0.39.0
- eetq.txt: eetq
- fp8-te.txt: transformer_engine[pytorch]>=2.0.0, accelerate>=1.10.0
- fp8.txt: torchao>=0.8.0, accelerate>=1.10.0
- galore.txt: galore-torch
- gptq.txt: optimum>=1.24.0, gptqmodel>=2.0.0
- hqq.txt: hqq
- liger-kernel.txt: liger-kernel>=0.5.5
- minicpm-v.txt: soundfile, torchvision, torchaudio, vector_quantize_pytorch, vocos, msgpack, referencing, jsonschema_specifications
- openmind.txt: openmind
- sglang.txt: sglang[srt]>=0.4.5, transformers==4.51.1
- swanlab.txt: swanlab
- vllm.txt: vllm>=0.4.3,<=0.11.0
pyproject.toml (131)

@@ -1,25 +1,104 @@
[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"
requires = ["hatchling"]
build-backend = "hatchling.build"

[project]
name = "llamafactory"
requires-python = ">=3.9.0"
dynamic = [
    "version",
    "dependencies",
    "optional-dependencies",
    "scripts",
    "authors",
    "description",
    "readme",
    "license",
    "keywords",
    "classifiers"
dynamic = ["version"]
description = "Unified Efficient Fine-Tuning of 100+ LLMs"
readme = "README.md"
license = "Apache-2.0"
requires-python = ">=3.11.0"
authors = [
    { name = "hiyouga", email = "hiyouga@buaa.edu.cn" }
]
keywords = [
    "AI",
    "LLM",
    "GPT",
    "ChatGPT",
    "Llama",
    "Transformer",
    "DeepSeek",
    "Pytorch"
]
classifiers = [
    "Development Status :: 4 - Beta",
    "Intended Audience :: Developers",
    "Intended Audience :: Education",
    "Intended Audience :: Science/Research",
    "License :: OSI Approved :: Apache Software License",
    "Operating System :: OS Independent",
    "Programming Language :: Python :: 3",
    "Programming Language :: Python :: 3.10",
    "Programming Language :: Python :: 3.11",
    "Programming Language :: Python :: 3.12",
    "Programming Language :: Python :: 3.13",
    "Topic :: Scientific/Engineering :: Artificial Intelligence"
]
dependencies = [
    # core deps
    "torch>=2.4.0",
    "torchvision>=0.19.0",
    "torchaudio>=2.4.0",
    "transformers>=4.49.0,<=4.56.2,!=4.52.0; python_version < '3.10'",
    "transformers>=4.49.0,<=4.57.1,!=4.52.0,!=4.57.0; python_version >= '3.10'",
    "datasets>=2.16.0,<=4.0.0",
    "accelerate>=1.3.0,<=1.11.0",
    "peft>=0.14.0,<=0.17.1",
    "trl>=0.8.6,<=0.9.6",
    "torchdata>=0.10.0,<=0.11.0",
    # gui
    "gradio>=4.38.0,<=6.2.0",
    "matplotlib>=3.7.0",
    "tyro<0.9.0",
    # ops
    "einops",
    "numpy",
    "pandas",
    "scipy",
    # model and tokenizer
    "sentencepiece",
    "tiktoken",
    "modelscope",
    "hf-transfer",
    "safetensors",
    # python
    "av",
    "fire",
    "omegaconf",
    "packaging",
    "protobuf",
    "pyyaml",
    "pydantic",
    # api
    "uvicorn",
    "fastapi",
    "sse-starlette"
]

[project.optional-dependencies]
dev = ["pre-commit", "ruff", "pytest", "build"]
metrics = ["nltk", "jieba", "rouge-chinese"]
deepspeed = ["deepspeed>=0.10.0,<=0.16.9"]

[project.scripts]
llamafactory-cli = "llamafactory.cli:main"
lmf = "llamafactory.cli:main"

[project.urls]
Homepage = "https://github.com/hiyouga/LLaMA-Factory"
Repository = "https://github.com/hiyouga/LLaMA-Factory"

[tool.hatch.build.targets.wheel]
packages = ["src/llamafactory"]

[tool.hatch.version]
path = "src/llamafactory/extras/env.py"
pattern = "VERSION = \"(?P<version>[^\"]+)\""

[tool.ruff]
target-version = "py39"
target-version = "py311"
line-length = 119
indent-width = 4

@@ -30,6 +109,8 @@ ignore = [
"E501", # line too long
"E731", # lambda function
"E741", # ambiguous var name
"UP007", # no upgrade union
"UP045", # no upgrade optional
"D100", # no doc public module
"D101", # no doc public class
"D102", # no doc public method

@@ -73,23 +154,3 @@ indent-style = "space"
docstring-code-format = true
skip-magic-trailing-comma = false
line-ending = "auto"

[tool.uv]
conflicts = [
    [
        { extra = "torch-npu" },
        { extra = "aqlm" },
    ],
    [
        { extra = "torch-npu" },
        { extra = "vllm" },
    ],
    [
        { extra = "torch-npu" },
        { extra = "sglang" },
    ],
    [
        { extra = "vllm" },
        { extra = "sglang" },
    ],
]
@@ -1,38 +0,0 @@
# core deps
transformers>=4.49.0,<=4.56.2,!=4.52.0; python_version < '3.10'
transformers>=4.49.0,<=4.57.3,!=4.52.0,!=4.57.0; python_version >= '3.10'
datasets>=2.16.0,<=4.0.0
accelerate>=1.3.0,<=1.11.0
peft>=0.14.0,<=0.17.1
trl>=0.8.6,<=0.9.6
# gui
gradio>=4.38.0,<=5.45.0
matplotlib>=3.7.0
tyro<0.9.0
# ops
einops
numpy<2.0.0
pandas>=2.0.0
scipy
# model and tokenizer
sentencepiece
tiktoken
modelscope>=1.14.0
hf-transfer
safetensors<=0.5.3
# python
fire
omegaconf
packaging
protobuf
pyyaml
pydantic<=2.10.6
# api
uvicorn
fastapi
sse-starlette
# media
av
librosa
# yanked
propcache!=0.4.0

@@ -16,7 +16,6 @@
# limitations under the License.

import os
from typing import Optional

import fire
import torch

@@ -34,7 +33,7 @@ def convert_mca_to_hf(
output_path: str = "./output",
bf16: bool = False,
fp16: bool = False,
convert_model_max_length: Optional[int] = None,
convert_model_max_length: int | None = None,
):
"""Convert megatron checkpoint to HuggingFace format.

@@ -67,11 +66,11 @@ def convert(
output_path: str = "./output",
bf16: bool = False,
fp16: bool = False,
convert_model_max_length: Optional[int] = None,
convert_model_max_length: int | None = None,
tensor_model_parallel_size: int = 1,
pipeline_model_parallel_size: int = 1,
expert_model_parallel_size: int = 1,
virtual_pipeline_model_parallel_size: Optional[int] = None,
virtual_pipeline_model_parallel_size: int | None = None,
):
"""Convert checkpoint between MCA and HuggingFace formats.

@@ -14,7 +14,7 @@
import json
from dataclasses import dataclass
from typing import Any, Literal, Optional
from typing import Any, Literal

import fire
import torch

@@ -61,7 +61,7 @@ def calculate_ppl(
dataset_dir: str = "data",
template: str = "default",
cutoff_len: int = 2048,
max_samples: Optional[int] = None,
max_samples: int | None = None,
train_on_prompt: bool = False,
):
r"""Calculate the ppl on the dataset of the pre-trained models.

@@ -14,7 +14,6 @@
import gc
import json
from typing import Optional

import av
import fire

@@ -49,7 +48,7 @@ def vllm_infer(
dataset_dir: str = "data",
template: str = "default",
cutoff_len: int = 2048,
max_samples: Optional[int] = None,
max_samples: int | None = None,
vllm_config: str = "{}",
save_name: str = "generated_predictions.jsonl",
temperature: float = 0.95,

@@ -58,9 +57,9 @@ def vllm_infer(
max_new_tokens: int = 1024,
repetition_penalty: float = 1.0,
skip_special_tokens: bool = True,
default_system: Optional[str] = None,
default_system: str | None = None,
enable_thinking: bool = True,
seed: Optional[int] = None,
seed: int | None = None,
pipeline_parallel_size: int = 1,
image_max_pixels: int = 768 * 768,
image_min_pixels: int = 32 * 32,
setup.py (116, removed)

@@ -1,116 +0,0 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import re

from setuptools import find_packages, setup


def get_version() -> str:
    with open(os.path.join("src", "llamafactory", "extras", "env.py"), encoding="utf-8") as f:
        file_content = f.read()
        pattern = r"{}\W*=\W*\"([^\"]+)\"".format("VERSION")
        (version,) = re.findall(pattern, file_content)
        return version


def get_requires() -> list[str]:
    with open("requirements.txt", encoding="utf-8") as f:
        file_content = f.read()
        lines = [line.strip() for line in file_content.strip().split("\n") if not line.startswith("#")]
        return lines


def get_console_scripts() -> list[str]:
    console_scripts = ["llamafactory-cli = llamafactory.cli:main"]
    if os.getenv("ENABLE_SHORT_CONSOLE", "1").lower() in ["true", "y", "1"]:
        console_scripts.append("lmf = llamafactory.cli:main")

    return console_scripts


extra_require = {
    "torch": ["torch>=2.0.0", "torchvision>=0.15.0"],
    "torch-npu": ["torch==2.7.1", "torch-npu==2.7.1", "torchvision==0.22.1", "decorator"],
    "metrics": ["nltk", "jieba", "rouge-chinese"],
    "deepspeed": ["deepspeed>=0.10.0,<=0.16.9"],
    "liger-kernel": ["liger-kernel>=0.5.5"],
    "bitsandbytes": ["bitsandbytes>=0.39.0"],
    "hqq": ["hqq"],
    "eetq": ["eetq"],
    "gptq": ["optimum>=1.24.0", "gptqmodel>=2.0.0"],
    "aqlm": ["aqlm[gpu]>=1.1.0"],
    "vllm": ["vllm>=0.4.3,<=0.11.0"],
    "sglang": ["sglang[srt]>=0.4.5", "transformers==4.51.1"],
    "galore": ["galore-torch"],
    "apollo": ["apollo-torch"],
    "badam": ["badam>=1.2.1"],
    "adam-mini": ["adam-mini"],
    "minicpm_v": [
        "soundfile",
        "torchvision",
        "torchaudio",
        "vector_quantize_pytorch",
        "vocos",
        "msgpack",
        "referencing",
        "jsonschema_specifications",
    ],
    "openmind": ["openmind"],
    "swanlab": ["swanlab"],
    "fp8": ["torchao>=0.8.0", "accelerate>=1.10.0"],
    "fp8-te": ["transformer_engine[pytorch]>=2.0.0", "accelerate>=1.10.0"],
    "fp8-all": ["torchao>=0.8.0", "transformer_engine[pytorch]>=2.0.0", "accelerate>=1.10.0"],
    "dev": ["pre-commit", "ruff", "pytest", "build"],
}


def main():
    setup(
        name="llamafactory",
        version=get_version(),
        author="hiyouga",
        author_email="hiyouga@buaa.edu.cn",
        description="Unified Efficient Fine-Tuning of 100+ LLMs",
        long_description=open("README.md", encoding="utf-8").read(),
        long_description_content_type="text/markdown",
        keywords=["AI", "LLM", "GPT", "ChatGPT", "Llama", "Transformer", "DeepSeek", "Pytorch"],
        license="Apache 2.0 License",
        url="https://github.com/hiyouga/LLaMA-Factory",
        package_dir={"": "src"},
        packages=find_packages("src"),
        python_requires=">=3.9.0",
        install_requires=get_requires(),
        extras_require=extra_require,
        entry_points={"console_scripts": get_console_scripts()},
|
||||
classifiers=[
|
||||
"Development Status :: 4 - Beta",
|
||||
"Intended Audience :: Developers",
|
||||
"Intended Audience :: Education",
|
||||
"Intended Audience :: Science/Research",
|
||||
"License :: OSI Approved :: Apache Software License",
|
||||
"Operating System :: OS Independent",
|
||||
"Programming Language :: Python :: 3",
|
||||
"Programming Language :: Python :: 3.9",
|
||||
"Programming Language :: Python :: 3.10",
|
||||
"Programming Language :: Python :: 3.11",
|
||||
"Programming Language :: Python :: 3.12",
|
||||
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -16,7 +16,7 @@ import asyncio
import os
from contextlib import asynccontextmanager
from functools import partial
from typing import Annotated, Optional
from typing import Annotated

from ..chat import ChatModel
from ..extras.constants import EngineName
@@ -79,7 +79,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
api_key = os.getenv("API_KEY")
security = HTTPBearer(auto_error=False)

async def verify_api_key(auth: Annotated[Optional[HTTPAuthorizationCredentials], Depends(security)]):
async def verify_api_key(auth: Annotated[HTTPAuthorizationCredentials | None, Depends(security)]):
if api_key and (auth is None or auth.credentials != api_key):
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key.")
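The `verify_api_key` hunk keeps the same bearer-token check and only modernizes the annotation. A self-contained sketch of the pattern, assuming an illustrative route and environment variable name:

```python
import os
from typing import Annotated

from fastapi import Depends, FastAPI, HTTPException, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer

app = FastAPI()
api_key = os.getenv("API_KEY")  # if unset, authentication is effectively disabled
security = HTTPBearer(auto_error=False)  # do not auto-reject requests that lack a header


async def verify_api_key(auth: Annotated[HTTPAuthorizationCredentials | None, Depends(security)]):
    # reject only when an API key is configured and the caller's token does not match
    if api_key and (auth is None or auth.credentials != api_key):
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key.")


@app.get("/v1/models", dependencies=[Depends(verify_api_key)])
async def list_models():
    return {"data": []}
```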
|
||||
|
||||
|
||||
@@ -14,10 +14,9 @@
|
||||
|
||||
import time
|
||||
from enum import Enum, unique
|
||||
from typing import Any, Optional, Union
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import Literal
|
||||
|
||||
|
||||
@unique
|
||||
@@ -61,7 +60,7 @@ class FunctionDefinition(BaseModel):
|
||||
|
||||
class FunctionAvailable(BaseModel):
|
||||
type: Literal["function", "code_interpreter"] = "function"
|
||||
function: Optional[FunctionDefinition] = None
|
||||
function: FunctionDefinition | None = None
|
||||
|
||||
|
||||
class FunctionCall(BaseModel):
|
||||
@@ -77,35 +76,35 @@ class URL(BaseModel):
|
||||
|
||||
class MultimodalInputItem(BaseModel):
|
||||
type: Literal["text", "image_url", "video_url", "audio_url"]
|
||||
text: Optional[str] = None
|
||||
image_url: Optional[URL] = None
|
||||
video_url: Optional[URL] = None
|
||||
audio_url: Optional[URL] = None
|
||||
text: str | None = None
|
||||
image_url: URL | None = None
|
||||
video_url: URL | None = None
|
||||
audio_url: URL | None = None
|
||||
|
||||
|
||||
class ChatMessage(BaseModel):
|
||||
role: Role
|
||||
content: Optional[Union[str, list[MultimodalInputItem]]] = None
|
||||
tool_calls: Optional[list[FunctionCall]] = None
|
||||
content: str | list[MultimodalInputItem] | None = None
|
||||
tool_calls: list[FunctionCall] | None = None
|
||||
|
||||
|
||||
class ChatCompletionMessage(BaseModel):
|
||||
role: Optional[Role] = None
|
||||
content: Optional[str] = None
|
||||
tool_calls: Optional[list[FunctionCall]] = None
|
||||
role: Role | None = None
|
||||
content: str | None = None
|
||||
tool_calls: list[FunctionCall] | None = None
|
||||
|
||||
|
||||
class ChatCompletionRequest(BaseModel):
|
||||
model: str
|
||||
messages: list[ChatMessage]
|
||||
tools: Optional[list[FunctionAvailable]] = None
|
||||
do_sample: Optional[bool] = None
|
||||
temperature: Optional[float] = None
|
||||
top_p: Optional[float] = None
|
||||
tools: list[FunctionAvailable] | None = None
|
||||
do_sample: bool | None = None
|
||||
temperature: float | None = None
|
||||
top_p: float | None = None
|
||||
n: int = 1
|
||||
presence_penalty: Optional[float] = None
|
||||
max_tokens: Optional[int] = None
|
||||
stop: Optional[Union[str, list[str]]] = None
|
||||
presence_penalty: float | None = None
|
||||
max_tokens: int | None = None
|
||||
stop: str | list[str] | None = None
|
||||
stream: bool = False
|
||||
|
||||
|
||||
@@ -118,7 +117,7 @@ class ChatCompletionResponseChoice(BaseModel):
|
||||
class ChatCompletionStreamResponseChoice(BaseModel):
|
||||
index: int
|
||||
delta: ChatCompletionMessage
|
||||
finish_reason: Optional[Finish] = None
|
||||
finish_reason: Finish | None = None
|
||||
|
||||
|
||||
class ChatCompletionResponseUsage(BaseModel):
|
||||
@@ -147,7 +146,7 @@ class ChatCompletionStreamResponse(BaseModel):
|
||||
class ScoreEvaluationRequest(BaseModel):
|
||||
model: str
|
||||
messages: list[str]
|
||||
max_length: Optional[int] = None
|
||||
max_length: int | None = None
|
||||
|
||||
|
||||
class ScoreEvaluationResponse(BaseModel):
|
||||
|
||||
@@ -14,9 +14,9 @@
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from collections.abc import AsyncGenerator
|
||||
from collections.abc import AsyncGenerator, Callable
|
||||
from threading import Thread
|
||||
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
|
||||
import torch
|
||||
from transformers import GenerationConfig, TextIteratorStreamer
|
||||
|
||||
@@ -15,7 +15,7 @@ import json
|
||||
import os
|
||||
from abc import abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
from typing import TYPE_CHECKING, Any, Union
|
||||
|
||||
from ..extras import logging
|
||||
from .data_utils import Role
|
||||
@@ -40,7 +40,7 @@ class DatasetConverter:
|
||||
dataset_attr: "DatasetAttr"
|
||||
data_args: "DataArguments"
|
||||
|
||||
def _find_medias(self, medias: Union["MediaType", list["MediaType"], None]) -> Optional[list["MediaType"]]:
|
||||
def _find_medias(self, medias: Union["MediaType", list["MediaType"], None]) -> list["MediaType"] | None:
|
||||
r"""Optionally concatenate media path to media dir when loading from local disk."""
|
||||
if medias is None:
|
||||
return None
|
||||
|
||||
@@ -81,41 +81,48 @@ def split_dataset(
eval_dataset: Optional[Union["Dataset", "IterableDataset", dict[str, "Dataset"]]],
data_args: "DataArguments",
seed: int,
) -> "DatasetDict":
r"""Split the dataset and returns a dataset dict containing train set and validation set.
) -> tuple[dict, dict]:
r"""Split the dataset and return two dicts containing the train set and validation set.

Support both map dataset and iterable dataset.

Returns:
train_dict: Dictionary containing training data with key "train"
eval_dict: Dictionary containing evaluation data with keys "validation" or "validation_{name}"
"""
if eval_dataset is not None and data_args.val_size > 1e-6:
raise ValueError("Cannot specify `val_size` if `eval_dataset` is not None.")

dataset_dict = {}
# keep the train and eval splits in separate dicts so they can be returned and handled independently by the caller
train_dict, eval_dict = {}, {}

if dataset is not None:
if data_args.streaming:
dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=seed)

if data_args.val_size > 1e-6:
if data_args.streaming:
dataset_dict["validation"] = dataset.take(int(data_args.val_size))
dataset_dict["train"] = dataset.skip(int(data_args.val_size))
eval_dict["validation"] = dataset.take(int(data_args.val_size))
train_dict["train"] = dataset.skip(int(data_args.val_size))
else:
val_size = int(data_args.val_size) if data_args.val_size > 1 else data_args.val_size
dataset_dict = dataset.train_test_split(test_size=val_size, seed=seed)
dataset = dataset.train_test_split(test_size=val_size, seed=seed)
dataset_dict = {"train": dataset["train"], "validation": dataset["test"]}
split_result = dataset.train_test_split(test_size=val_size, seed=seed)
train_dict["train"] = split_result["train"]
eval_dict["validation"] = split_result["test"]
else:
dataset_dict["train"] = dataset
train_dict["train"] = dataset

if eval_dataset is not None:
if isinstance(eval_dataset, dict):
dataset_dict.update({f"validation_{name}": data for name, data in eval_dataset.items()})
for name, data in eval_dataset.items():
eval_dict[f"validation_{name}"] = data
else:
if data_args.streaming:
eval_dataset = eval_dataset.shuffle(buffer_size=data_args.buffer_size, seed=seed)

dataset_dict["validation"] = eval_dataset
eval_dict["validation"] = eval_dataset

return DatasetDict(dataset_dict)
return train_dict, eval_dict
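For orientation, a minimal sketch of how a caller can consume the new two-dict return and rebuild a single `DatasetDict`, mirroring what `get_dataset` does further below (surrounding variables such as `dataset`, `eval_dataset`, `data_args`, and `training_args` are assumed to exist in the caller):

```python
from datasets import DatasetDict

train_dict, eval_dict = split_dataset(dataset, eval_dataset, data_args, seed=training_args.seed)

# preprocess each split independently, then merge back into one DatasetDict
dataset_dict = DatasetDict({**train_dict, **eval_dict})
train_data = dataset_dict.get("train")
eval_data = {name: data for name, data in dataset_dict.items() if name.startswith("validation")}
```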
|
||||
|
||||
|
||||
def get_dataset_module(dataset: Union["Dataset", "DatasetDict"]) -> "DatasetModule":
|
||||
|
||||
@@ -16,7 +16,6 @@ import json
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional, Union
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
@@ -27,14 +26,14 @@ from .tool_utils import FunctionCall, get_tool_utils
|
||||
@dataclass
|
||||
class Formatter(ABC):
|
||||
slots: SLOTS = field(default_factory=list)
|
||||
tool_format: Optional[str] = None
|
||||
tool_format: str | None = None
|
||||
|
||||
@abstractmethod
|
||||
def apply(self, **kwargs) -> SLOTS:
|
||||
r"""Forms a list of slots according to the inputs to encode."""
|
||||
...
|
||||
|
||||
def extract(self, content: str) -> Union[str, list["FunctionCall"]]:
|
||||
def extract(self, content: str) -> str | list["FunctionCall"]:
|
||||
r"""Extract a list of tuples from the response message if using tools.
|
||||
|
||||
Each tuple consists of function name and function arguments.
|
||||
@@ -156,5 +155,5 @@ class ToolFormatter(Formatter):
|
||||
raise RuntimeError(f"Invalid JSON format in tool description: {str([content])}.") # flat string
|
||||
|
||||
@override
|
||||
def extract(self, content: str) -> Union[str, list["FunctionCall"]]:
|
||||
def extract(self, content: str) -> str | list["FunctionCall"]:
|
||||
return self.tool_utils.tool_extractor(content)
|
||||
|
||||
@@ -16,7 +16,7 @@ import os
|
||||
from typing import TYPE_CHECKING, Literal, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
from datasets import Dataset, load_dataset, load_from_disk
|
||||
from datasets import Dataset, DatasetDict, load_dataset, load_from_disk
|
||||
|
||||
from ..extras import logging
|
||||
from ..extras.constants import FILEEXT2TYPE
|
||||
@@ -162,13 +162,13 @@ def _load_single_dataset(
|
||||
|
||||
|
||||
def _get_merged_dataset(
|
||||
dataset_names: Optional[list[str]],
|
||||
dataset_names: list[str] | None,
|
||||
model_args: "ModelArguments",
|
||||
data_args: "DataArguments",
|
||||
training_args: "Seq2SeqTrainingArguments",
|
||||
stage: Literal["pt", "sft", "rm", "ppo", "kto"],
|
||||
return_dict: bool = False,
|
||||
) -> Optional[Union["Dataset", "IterableDataset", dict[str, "Dataset"]]]:
|
||||
) -> Union["Dataset", "IterableDataset", dict[str, "Dataset"]] | None:
|
||||
r"""Return the merged datasets in the standard format."""
|
||||
if dataset_names is None:
|
||||
return None
|
||||
@@ -227,7 +227,7 @@ def _get_dataset_processor(
|
||||
|
||||
|
||||
def _get_preprocessed_dataset(
|
||||
dataset: Optional[Union["Dataset", "IterableDataset"]],
|
||||
dataset: Union["Dataset", "IterableDataset"] | None,
|
||||
data_args: "DataArguments",
|
||||
training_args: "Seq2SeqTrainingArguments",
|
||||
stage: Literal["pt", "sft", "rm", "ppo", "kto"],
|
||||
@@ -235,7 +235,7 @@ def _get_preprocessed_dataset(
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
processor: Optional["ProcessorMixin"] = None,
|
||||
is_eval: bool = False,
|
||||
) -> Optional[Union["Dataset", "IterableDataset"]]:
|
||||
) -> Union["Dataset", "IterableDataset"] | None:
|
||||
r"""Preprocesses the dataset, including format checking and tokenization."""
|
||||
if dataset is None:
|
||||
return None
|
||||
@@ -311,20 +311,22 @@ def get_dataset(
)

with training_args.main_process_first(desc="pre-process dataset", local=(not data_args.data_shared_file_system)):
dataset = _get_preprocessed_dataset(
dataset, data_args, training_args, stage, template, tokenizer, processor, is_eval=False
)
if isinstance(eval_dataset, dict):
for eval_name, eval_data in eval_dataset.items():
eval_dataset[eval_name] = _get_preprocessed_dataset(
eval_data, data_args, training_args, stage, template, tokenizer, processor, is_eval=True
)
else:
eval_dataset = _get_preprocessed_dataset(
eval_dataset, data_args, training_args, stage, template, tokenizer, processor, is_eval=True
# split first so that the eval dataset (whether passed in separately or split from the train set) is preprocessed correctly
train_dict, eval_dict = split_dataset(dataset, eval_dataset, data_args, seed=training_args.seed)

if "train" in train_dict:
train_dict["train"] = _get_preprocessed_dataset(
train_dict["train"], data_args, training_args, stage, template, tokenizer, processor, is_eval=False
)

dataset_dict = split_dataset(dataset, eval_dataset, data_args, seed=training_args.seed)
for key in eval_dict:
eval_dict[key] = _get_preprocessed_dataset(
eval_dict[key], data_args, training_args, stage, template, tokenizer, processor, is_eval=True
)

# Combine train and eval dictionaries
dataset_dict = DatasetDict({**train_dict, **eval_dict})

if data_args.tokenized_path is not None: # save tokenized dataset to disk
if training_args.should_save:
dataset_dict.save_to_disk(data_args.tokenized_path)
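A short sketch of how the saved tokenized dataset can be reused across runs; the path is illustrative, and `load_from_disk` is the `datasets` helper already imported in this file per the hunk above:

```python
from datasets import load_from_disk

# first run: get_dataset(...) tokenized the data and saved it via save_to_disk
# later runs: skip preprocessing and reload the tokenized splits directly
dataset_dict = load_from_disk("saves/tokenized/alpaca_en_demo")
print(dataset_dict)  # DatasetDict with "train" and any "validation*" splits
```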
|
||||
|
||||
@@ -22,28 +22,20 @@ import re
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass
|
||||
from io import BytesIO
|
||||
from typing import TYPE_CHECKING, BinaryIO, Literal, Optional, TypedDict, Union
|
||||
from typing import TYPE_CHECKING, BinaryIO, Literal, NotRequired, Optional, TypedDict, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchaudio
|
||||
from transformers.image_utils import get_image_size, is_valid_image, to_numpy_array
|
||||
from transformers.models.mllama.processing_mllama import (
|
||||
convert_sparse_cross_attention_mask_to_dense,
|
||||
get_cross_attention_token_mask,
|
||||
)
|
||||
from typing_extensions import NotRequired, override
|
||||
from typing_extensions import override
|
||||
|
||||
from ..extras.constants import AUDIO_PLACEHOLDER, IGNORE_INDEX, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
|
||||
from ..extras.packages import (
|
||||
is_librosa_available,
|
||||
is_pillow_available,
|
||||
is_pyav_available,
|
||||
is_transformers_version_greater_than,
|
||||
)
|
||||
|
||||
|
||||
if is_librosa_available():
|
||||
import librosa
|
||||
from ..extras.packages import is_pillow_available, is_pyav_available, is_transformers_version_greater_than
|
||||
|
||||
|
||||
if is_pillow_available():
|
||||
@@ -71,8 +63,8 @@ if TYPE_CHECKING:
|
||||
from transformers.video_processing_utils import BaseVideoProcessor
|
||||
|
||||
class EncodedImage(TypedDict):
|
||||
path: Optional[str]
|
||||
bytes: Optional[bytes]
|
||||
path: str | None
|
||||
bytes: bytes | None
|
||||
|
||||
ImageInput = Union[str, bytes, EncodedImage, BinaryIO, ImageObject]
|
||||
VideoInput = Union[str, BinaryIO, list[list[ImageInput]]]
|
||||
@@ -152,9 +144,9 @@ def _check_video_is_nested_images(video: "VideoInput") -> bool:
|
||||
|
||||
@dataclass
|
||||
class MMPluginMixin:
|
||||
image_token: Optional[str]
|
||||
video_token: Optional[str]
|
||||
audio_token: Optional[str]
|
||||
image_token: str | None
|
||||
video_token: str | None
|
||||
audio_token: str | None
|
||||
expand_mm_tokens: bool = True
|
||||
|
||||
def _validate_input(
|
||||
@@ -316,7 +308,14 @@ class MMPluginMixin:
results, sampling_rates = [], []
for audio in audios:
if not isinstance(audio, np.ndarray):
audio, sampling_rate = librosa.load(audio, sr=sampling_rate)
audio, sr = torchaudio.load(audio)
if audio.shape[0] > 1:
audio = audio.mean(dim=0, keepdim=True)

if sr != sampling_rate:
audio = torchaudio.functional.resample(audio, sr, sampling_rate)

audio = audio.squeeze(0).numpy()

results.append(audio)
sampling_rates.append(sampling_rate)
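The hunk above swaps `librosa.load` for `torchaudio`: load the file, downmix multi-channel audio to mono, resample to the target rate, and hand a NumPy array to the feature extractor. A standalone sketch of that pattern (the file path and target rate are illustrative):

```python
import torchaudio

target_sr = 16000
waveform, sr = torchaudio.load("sample.wav")  # (channels, num_frames) tensor plus native sample rate

if waveform.shape[0] > 1:  # downmix stereo/multi-channel to mono
    waveform = waveform.mean(dim=0, keepdim=True)

if sr != target_sr:  # resample to the rate expected by the audio processor
    waveform = torchaudio.functional.resample(waveform, sr, target_sr)

audio = waveform.squeeze(0).numpy()  # 1-D float array, ready for the feature extractor
```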
|
||||
@@ -329,7 +328,7 @@ class MMPluginMixin:
|
||||
videos: list["VideoInput"],
|
||||
audios: list["AudioInput"],
|
||||
processor: "MMProcessor",
|
||||
imglens: Optional[list[int]] = None,
|
||||
imglens: list[int] | None = None,
|
||||
) -> dict[str, "torch.Tensor"]:
|
||||
r"""Process visual inputs.
|
||||
|
||||
@@ -427,13 +426,13 @@ class BasePlugin(MMPluginMixin):
|
||||
def process_token_ids(
|
||||
self,
|
||||
input_ids: list[int],
|
||||
labels: Optional[list[int]],
|
||||
labels: list[int] | None,
|
||||
images: list["ImageInput"],
|
||||
videos: list["VideoInput"],
|
||||
audios: list["AudioInput"],
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
processor: Optional["MMProcessor"],
|
||||
) -> tuple[list[int], Optional[list[int]]]:
|
||||
) -> tuple[list[int], list[int] | None]:
|
||||
r"""Pre-process token ids after tokenization for VLMs."""
|
||||
self._validate_input(processor, images, videos, audios)
|
||||
return input_ids, labels
|
||||
@@ -480,21 +479,39 @@ class ErnieVLPlugin(BasePlugin):
|
||||
self._validate_input(processor, images, videos, audios)
|
||||
self._validate_messages(messages, images, videos, audios)
|
||||
messages = deepcopy(messages)
|
||||
|
||||
image_processor: BaseImageProcessor = getattr(processor, "image_processor")
|
||||
|
||||
merge_length: int = getattr(image_processor, "merge_size") ** 2
|
||||
if self.expand_mm_tokens:
|
||||
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
|
||||
image_grid_thw = mm_inputs.get("image_grid_thw", [])
|
||||
video_grid_thw = mm_inputs.get("video_grid_thw", [])
|
||||
else:
|
||||
image_grid_thw = [None] * len(images)
|
||||
video_grid_thw = [None] * len(videos)
|
||||
|
||||
image_idx, video_idx = 0, 0
|
||||
for message in messages:
|
||||
content = message["content"]
|
||||
image_token = self.image_token or "<|image@placeholder|>"
|
||||
video_token = self.video_token or "<|video@placeholder|>"
|
||||
image_token = self.image_token or "<|IMAGE_PLACEHOLDER|>"
|
||||
video_token = self.video_token or "<|VIDEO_PLACEHOLDER|>"
|
||||
while IMAGE_PLACEHOLDER in content:
|
||||
image_seqlen = image_grid_thw[image_idx].prod() // merge_length if self.expand_mm_tokens else 1
|
||||
content = content.replace(
|
||||
IMAGE_PLACEHOLDER,
|
||||
f"Picture {image_idx + 1}:<|IMAGE_START|>{image_token * image_seqlen}<|IMAGE_END|>",
|
||||
1,
|
||||
)
|
||||
image_idx += 1
|
||||
content = content.replace(
|
||||
IMAGE_PLACEHOLDER, f"Picture {image_idx}:<|IMAGE_START|>{image_token}<|IMAGE_END|>", 1
|
||||
)
|
||||
while VIDEO_PLACEHOLDER in content:
|
||||
video_idx += 1
|
||||
video_seqlen = video_grid_thw[video_idx].prod() // merge_length if self.expand_mm_tokens else 1
|
||||
content = content.replace(
|
||||
VIDEO_PLACEHOLDER, f"Video {video_idx}:<|VIDEO_START|>{video_token}<|VIDEO_END|>", 1
|
||||
VIDEO_PLACEHOLDER,
|
||||
f"Video {video_idx + 1}:<|VIDEO_START|>{video_token * video_seqlen}<|VIDEO_END|>",
|
||||
1,
|
||||
)
|
||||
video_idx += 1
|
||||
message["content"] = content
|
||||
return messages
|
||||
|
||||
@@ -1288,13 +1305,13 @@ class PaliGemmaPlugin(BasePlugin):
|
||||
def process_token_ids(
|
||||
self,
|
||||
input_ids: list[int],
|
||||
labels: Optional[list[int]],
|
||||
labels: list[int] | None,
|
||||
images: list["ImageInput"],
|
||||
videos: list["VideoInput"],
|
||||
audios: list["AudioInput"],
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
processor: Optional["MMProcessor"],
|
||||
) -> tuple[list[int], Optional[list[int]]]:
|
||||
) -> tuple[list[int], list[int] | None]:
|
||||
self._validate_input(processor, images, videos, audios)
|
||||
num_images = len(images)
|
||||
image_seqlen = processor.image_seq_length if self.expand_mm_tokens else 0 # skip mm token
|
||||
@@ -1624,7 +1641,12 @@ class Qwen3VLPlugin(Qwen2VLPlugin):
|
||||
for video, duration in zip(videos["videos"], videos["durations"])
|
||||
]
|
||||
mm_inputs.update(
|
||||
video_processor(videos=videos["videos"], video_metadata=video_metadata, return_metadata=True)
|
||||
video_processor(
|
||||
videos=videos["videos"],
|
||||
video_metadata=video_metadata,
|
||||
fps=getattr(processor, "video_fps", 2.0),
|
||||
return_metadata=True,
|
||||
)
|
||||
)
|
||||
temporal_patch_size: int = getattr(image_processor, "temporal_patch_size", 2)
|
||||
if "second_per_grid_ts" in processor.model_input_names:
|
||||
@@ -2104,9 +2126,9 @@ def register_mm_plugin(name: str, plugin_class: type["BasePlugin"]) -> None:
|
||||
|
||||
def get_mm_plugin(
|
||||
name: str,
|
||||
image_token: Optional[str] = None,
|
||||
video_token: Optional[str] = None,
|
||||
audio_token: Optional[str] = None,
|
||||
image_token: str | None = None,
|
||||
video_token: str | None = None,
|
||||
audio_token: str | None = None,
|
||||
**kwargs,
|
||||
) -> "BasePlugin":
|
||||
r"""Get plugin for multimodal inputs."""
|
||||
|
||||
@@ -15,7 +15,7 @@
|
||||
import json
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Literal, Optional, Union
|
||||
from typing import Any, Literal
|
||||
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
@@ -33,40 +33,40 @@ class DatasetAttr:
|
||||
formatting: Literal["alpaca", "sharegpt", "openai"] = "alpaca"
|
||||
ranking: bool = False
|
||||
# extra configs
|
||||
subset: Optional[str] = None
|
||||
subset: str | None = None
|
||||
split: str = "train"
|
||||
folder: Optional[str] = None
|
||||
num_samples: Optional[int] = None
|
||||
folder: str | None = None
|
||||
num_samples: int | None = None
|
||||
# common columns
|
||||
system: Optional[str] = None
|
||||
tools: Optional[str] = None
|
||||
images: Optional[str] = None
|
||||
videos: Optional[str] = None
|
||||
audios: Optional[str] = None
|
||||
system: str | None = None
|
||||
tools: str | None = None
|
||||
images: str | None = None
|
||||
videos: str | None = None
|
||||
audios: str | None = None
|
||||
# dpo columns
|
||||
chosen: Optional[str] = None
|
||||
rejected: Optional[str] = None
|
||||
kto_tag: Optional[str] = None
|
||||
chosen: str | None = None
|
||||
rejected: str | None = None
|
||||
kto_tag: str | None = None
|
||||
# alpaca columns
|
||||
prompt: Optional[str] = "instruction"
|
||||
query: Optional[str] = "input"
|
||||
response: Optional[str] = "output"
|
||||
history: Optional[str] = None
|
||||
prompt: str | None = "instruction"
|
||||
query: str | None = "input"
|
||||
response: str | None = "output"
|
||||
history: str | None = None
|
||||
# sharegpt columns
|
||||
messages: Optional[str] = "conversations"
|
||||
messages: str | None = "conversations"
|
||||
# sharegpt tags
|
||||
role_tag: Optional[str] = "from"
|
||||
content_tag: Optional[str] = "value"
|
||||
user_tag: Optional[str] = "human"
|
||||
assistant_tag: Optional[str] = "gpt"
|
||||
observation_tag: Optional[str] = "observation"
|
||||
function_tag: Optional[str] = "function_call"
|
||||
system_tag: Optional[str] = "system"
|
||||
role_tag: str | None = "from"
|
||||
content_tag: str | None = "value"
|
||||
user_tag: str | None = "human"
|
||||
assistant_tag: str | None = "gpt"
|
||||
observation_tag: str | None = "observation"
|
||||
function_tag: str | None = "function_call"
|
||||
system_tag: str | None = "system"
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return self.dataset_name
|
||||
|
||||
def set_attr(self, key: str, obj: dict[str, Any], default: Optional[Any] = None) -> None:
|
||||
def set_attr(self, key: str, obj: dict[str, Any], default: Any | None = None) -> None:
|
||||
setattr(self, key, obj.get(key, default))
|
||||
|
||||
def join(self, attr: dict[str, Any]) -> None:
|
||||
@@ -90,7 +90,7 @@ class DatasetAttr:
|
||||
self.set_attr(tag, attr["tags"])
|
||||
|
||||
|
||||
def get_dataset_list(dataset_names: Optional[list[str]], dataset_dir: Union[str, dict]) -> list["DatasetAttr"]:
|
||||
def get_dataset_list(dataset_names: list[str] | None, dataset_dir: str | dict) -> list["DatasetAttr"]:
|
||||
r"""Get the attributes of the datasets."""
|
||||
if dataset_names is None:
|
||||
dataset_names = []
|
||||
|
||||
@@ -981,7 +981,7 @@ register_template(
replace_eos=True,
replace_jinja_template=True,
template_class=ReasoningTemplate,
mm_plugin=get_mm_plugin(name="ernie_vl", image_token="<|image@placeholder|>", video_token="<|video@placeholder|>"),
mm_plugin=get_mm_plugin(name="ernie_vl", image_token="<|IMAGE_PLACEHOLDER|>", video_token="<|VIDEO_PLACEHOLDER|>"),
)
|
||||
|
||||
|
||||
@@ -1166,7 +1166,7 @@ register_template(
|
||||
|
||||
|
||||
register_template(
|
||||
name="gpt",
|
||||
name="gpt_oss",
|
||||
format_user=StringFormatter(slots=["<|start|>user<|message|>{{content}}<|end|><|start|>assistant"]),
|
||||
format_assistant=StringFormatter(slots=["{{content}}<|end|>"]),
|
||||
format_system=StringFormatter(slots=["<|start|>system<|message|>{{content}}<|end|>"]),
|
||||
@@ -1610,6 +1610,26 @@ register_template(
|
||||
template_class=ReasoningTemplate,
|
||||
)
|
||||
|
||||
|
||||
# copied from qwen template
|
||||
register_template(
|
||||
name="mimo_v2",
|
||||
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
|
||||
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
|
||||
format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen"),
|
||||
format_observation=StringFormatter(
|
||||
slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
|
||||
),
|
||||
format_tools=ToolFormatter(tool_format="qwen"),
|
||||
default_system="You are MiMo, a helpful AI assistant engineered by Xiaomi.",
|
||||
stop_words=["<|im_end|>"],
|
||||
replace_eos=True,
|
||||
thought_words=("<think>", "</think>"),
|
||||
template_class=ReasoningTemplate,
|
||||
)
|
||||
|
||||
|
||||
# copied from qwen2vl
|
||||
register_template(
|
||||
name="mimo_vl",
|
||||
|
||||
@@ -15,7 +15,6 @@
|
||||
import os
|
||||
from collections import OrderedDict, defaultdict
|
||||
from enum import Enum, unique
|
||||
from typing import Optional
|
||||
|
||||
from peft.utils import SAFETENSORS_WEIGHTS_NAME as SAFE_ADAPTER_WEIGHTS_NAME
|
||||
from peft.utils import WEIGHTS_NAME as ADAPTER_WEIGHTS_NAME
|
||||
@@ -114,6 +113,7 @@ class AttentionFunction(str, Enum):
|
||||
DISABLED = "disabled"
|
||||
SDPA = "sdpa"
|
||||
FA2 = "fa2"
|
||||
FA3 = "fa3"
|
||||
|
||||
|
||||
class EngineName(str, Enum):
|
||||
@@ -153,7 +153,7 @@ class RopeScaling(str, Enum):
|
||||
|
||||
def register_model_group(
|
||||
models: dict[str, dict[DownloadSource, str]],
|
||||
template: Optional[str] = None,
|
||||
template: str | None = None,
|
||||
multimodal: bool = False,
|
||||
) -> None:
|
||||
for name, path in models.items():
|
||||
@@ -1067,7 +1067,7 @@ register_model_group(
|
||||
DownloadSource.MODELSCOPE: "openai/gpt-oss-120b",
|
||||
},
|
||||
},
|
||||
template="gpt",
|
||||
template="gpt_oss",
|
||||
)
|
||||
|
||||
|
||||
@@ -1803,6 +1803,21 @@ register_model_group(
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"MiMo-V2-Flash-Base": {
|
||||
DownloadSource.DEFAULT: "XiaomiMiMo/MiMo-V2-Flash-Base",
|
||||
DownloadSource.MODELSCOPE: "XiaomiMiMo/MiMo-V2-Flash-Base",
|
||||
},
|
||||
"MiMo-V2-Flash": {
|
||||
DownloadSource.DEFAULT: "XiaomiMiMo/MiMo-V2-Flash",
|
||||
DownloadSource.MODELSCOPE: "XiaomiMiMo/MiMo-V2-Flash",
|
||||
},
|
||||
},
|
||||
template="mimo_v2",
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"MiMo-7B-VL-RL": {
|
||||
@@ -1827,7 +1842,7 @@ register_model_group(
|
||||
},
|
||||
"MiMo-VL-7B-SFT-2508": {
|
||||
DownloadSource.DEFAULT: "XiaomiMiMo/MiMo-VL-7B-SFT-2508",
|
||||
DownloadSource.DEFAULT: "XiaomiMiMo/MiMo-VL-7B-SFT-2508",
|
||||
DownloadSource.MODELSCOPE: "XiaomiMiMo/MiMo-VL-7B-SFT-2508",
|
||||
},
|
||||
},
|
||||
template="qwen2_vl",
|
||||
@@ -1980,6 +1995,18 @@ register_model_group(
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Ministral-3-3B-Base-2512": {
|
||||
DownloadSource.DEFAULT: "mistralai/Ministral-3-3B-Base-2512",
|
||||
DownloadSource.MODELSCOPE: "mistralai/Ministral-3-3B-Base-2512",
|
||||
},
|
||||
"Ministral-3-8B-Base-2512": {
|
||||
DownloadSource.DEFAULT: "mistralai/Ministral-3-8B-Base-2512",
|
||||
DownloadSource.MODELSCOPE: "mistralai/Ministral-3-8B-Base-2512",
|
||||
},
|
||||
"Ministral-3-14B-Base-2512": {
|
||||
DownloadSource.DEFAULT: "mistralai/Ministral-3-14B-Base-2512",
|
||||
DownloadSource.MODELSCOPE: "mistralai/Ministral-3-14B-Base-2512",
|
||||
},
|
||||
"Ministral-3-3B-Instruct-2512": {
|
||||
DownloadSource.DEFAULT: "mistralai/Ministral-3-3B-Instruct-2512",
|
||||
DownloadSource.MODELSCOPE: "mistralai/Ministral-3-3B-Instruct-2512",
|
||||
|
||||
@@ -117,7 +117,7 @@ def _configure_library_root_logger() -> None:
|
||||
library_root_logger.propagate = False
|
||||
|
||||
|
||||
def get_logger(name: Optional[str] = None) -> "_Logger":
|
||||
def get_logger(name: str | None = None) -> "_Logger":
|
||||
r"""Return a logger with the specified name. It it not supposed to be accessed externally."""
|
||||
if name is None:
|
||||
name = _get_library_name()
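Typical usage of this helper inside the package looks along these lines (the log message is illustrative; the module path assumes the `extras.logging` module the hunk above edits):

```python
from llamafactory.extras import logging

logger = logging.get_logger(__name__)
logger.info("Template and tokenizer loaded.")
```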
|
||||
|
||||
@@ -94,7 +94,7 @@ def check_version(requirement: str, mandatory: bool = False) -> None:

def check_dependencies() -> None:
r"""Check the version of the required packages."""
check_version("transformers>=4.49.0,<=4.57.3")
check_version("transformers>=4.49.0,<=4.57.1")
check_version("datasets>=2.16.0,<=4.0.0")
check_version("accelerate>=1.3.0,<=1.11.0")
check_version("peft>=0.14.0,<=0.17.1")
@@ -332,3 +332,7 @@ def fix_proxy(ipv6_enabled: bool = False) -> None:
if ipv6_enabled:
os.environ.pop("http_proxy", None)
os.environ.pop("HTTP_PROXY", None)
os.environ.pop("https_proxy", None)
os.environ.pop("HTTPS_PROXY", None)
os.environ.pop("all_proxy", None)
os.environ.pop("ALL_PROXY", None)
|
||||
|
||||
@@ -16,22 +16,22 @@
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from typing import Any, Literal, Optional
|
||||
from typing import Any, Literal
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataArguments:
|
||||
r"""Arguments pertaining to what data we are going to input our model for training and evaluation."""
|
||||
|
||||
template: Optional[str] = field(
|
||||
template: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Which template to use for constructing prompts in training and inference."},
|
||||
)
|
||||
dataset: Optional[str] = field(
|
||||
dataset: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "The name of dataset(s) to use for training. Use commas to separate multiple datasets."},
|
||||
)
|
||||
eval_dataset: Optional[str] = field(
|
||||
eval_dataset: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "The name of dataset(s) to use for evaluation. Use commas to separate multiple datasets."},
|
||||
)
|
||||
@@ -39,7 +39,7 @@ class DataArguments:
|
||||
default="data",
|
||||
metadata={"help": "Path to the folder containing the datasets."},
|
||||
)
|
||||
media_dir: Optional[str] = field(
|
||||
media_dir: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Path to the folder containing the images, videos or audios. Defaults to `dataset_dir`."},
|
||||
)
|
||||
@@ -67,7 +67,7 @@ class DataArguments:
|
||||
default="concat",
|
||||
metadata={"help": "Strategy to use in dataset mixing (concat/interleave) (undersampling/oversampling)."},
|
||||
)
|
||||
interleave_probs: Optional[str] = field(
|
||||
interleave_probs: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Probabilities to sample data from datasets. Use commas to separate multiple datasets."},
|
||||
)
|
||||
@@ -79,15 +79,15 @@ class DataArguments:
|
||||
default=1000,
|
||||
metadata={"help": "The number of examples in one group in pre-processing."},
|
||||
)
|
||||
preprocessing_num_workers: Optional[int] = field(
|
||||
preprocessing_num_workers: int | None = field(
|
||||
default=None,
|
||||
metadata={"help": "The number of processes to use for the pre-processing."},
|
||||
)
|
||||
max_samples: Optional[int] = field(
|
||||
max_samples: int | None = field(
|
||||
default=None,
|
||||
metadata={"help": "For debugging purposes, truncate the number of examples for each dataset."},
|
||||
)
|
||||
eval_num_beams: Optional[int] = field(
|
||||
eval_num_beams: int | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Number of beams to use for evaluation. This argument will be passed to `model.generate`"},
|
||||
)
|
||||
@@ -103,7 +103,7 @@ class DataArguments:
|
||||
default=False,
|
||||
metadata={"help": "Whether or not to evaluate on each dataset separately."},
|
||||
)
|
||||
packing: Optional[bool] = field(
|
||||
packing: bool | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Enable sequences packing in training. Will automatically enable in pre-training."},
|
||||
)
|
||||
@@ -111,19 +111,19 @@ class DataArguments:
|
||||
default=False,
|
||||
metadata={"help": "Enable sequence packing without cross-attention."},
|
||||
)
|
||||
tool_format: Optional[str] = field(
|
||||
tool_format: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Tool format to use for constructing function calling examples."},
|
||||
)
|
||||
default_system: Optional[str] = field(
|
||||
default_system: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Override the default system message in the template."},
|
||||
)
|
||||
enable_thinking: Optional[bool] = field(
|
||||
enable_thinking: bool | None = field(
|
||||
default=True,
|
||||
metadata={"help": "Whether or not to enable thinking mode for reasoning models."},
|
||||
)
|
||||
tokenized_path: Optional[str] = field(
|
||||
tokenized_path: str | None = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": (
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Literal, Optional
|
||||
from typing import Literal
|
||||
|
||||
from datasets import DownloadMode
|
||||
|
||||
@@ -46,7 +46,7 @@ class EvaluationArguments:
|
||||
default=5,
|
||||
metadata={"help": "Number of examplars for few-shot learning."},
|
||||
)
|
||||
save_dir: Optional[str] = field(
|
||||
save_dir: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Path to save the evaluation results."},
|
||||
)
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from typing import Any, Literal, Optional
|
||||
from typing import Any, Literal
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -40,7 +40,7 @@ class FreezeArguments:
|
||||
)
|
||||
},
|
||||
)
|
||||
freeze_extra_modules: Optional[str] = field(
|
||||
freeze_extra_modules: str | None = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": (
|
||||
@@ -56,7 +56,7 @@ class FreezeArguments:
|
||||
class LoraArguments:
|
||||
r"""Arguments pertaining to the LoRA training."""
|
||||
|
||||
additional_target: Optional[str] = field(
|
||||
additional_target: str | None = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": (
|
||||
@@ -66,7 +66,7 @@ class LoraArguments:
|
||||
)
|
||||
},
|
||||
)
|
||||
lora_alpha: Optional[int] = field(
|
||||
lora_alpha: int | None = field(
|
||||
default=None,
|
||||
metadata={"help": "The scale factor for LoRA fine-tuning (default: lora_rank * 2)."},
|
||||
)
|
||||
@@ -88,7 +88,7 @@ class LoraArguments:
|
||||
)
|
||||
},
|
||||
)
|
||||
loraplus_lr_ratio: Optional[float] = field(
|
||||
loraplus_lr_ratio: float | None = field(
|
||||
default=None,
|
||||
metadata={"help": "LoRA plus learning rate ratio (lr_B / lr_A)."},
|
||||
)
|
||||
@@ -126,7 +126,7 @@ class LoraArguments:
|
||||
class OFTArguments:
|
||||
r"""Arguments pertaining to the OFT training."""
|
||||
|
||||
additional_target: Optional[str] = field(
|
||||
additional_target: str | None = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": (
|
||||
@@ -220,27 +220,27 @@ class RLHFArguments:
|
||||
default=False,
|
||||
metadata={"help": "Whiten the rewards before compute advantages in PPO training."},
|
||||
)
|
||||
ref_model: Optional[str] = field(
|
||||
ref_model: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Path to the reference model used for the PPO or DPO training."},
|
||||
)
|
||||
ref_model_adapters: Optional[str] = field(
|
||||
ref_model_adapters: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Path to the adapters of the reference model."},
|
||||
)
|
||||
ref_model_quantization_bit: Optional[int] = field(
|
||||
ref_model_quantization_bit: int | None = field(
|
||||
default=None,
|
||||
metadata={"help": "The number of bits to quantize the reference model."},
|
||||
)
|
||||
reward_model: Optional[str] = field(
|
||||
reward_model: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Path to the reward model used for the PPO training."},
|
||||
)
|
||||
reward_model_adapters: Optional[str] = field(
|
||||
reward_model_adapters: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Path to the adapters of the reward model."},
|
||||
)
|
||||
reward_model_quantization_bit: Optional[int] = field(
|
||||
reward_model_quantization_bit: int | None = field(
|
||||
default=None,
|
||||
metadata={"help": "The number of bits to quantize the reward model."},
|
||||
)
|
||||
@@ -248,7 +248,7 @@ class RLHFArguments:
|
||||
default="lora",
|
||||
metadata={"help": "The type of the reward model in PPO training. Lora model only supports lora training."},
|
||||
)
|
||||
ld_alpha: Optional[float] = field(
|
||||
ld_alpha: float | None = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": (
|
||||
@@ -361,15 +361,15 @@ class BAdamArgument:
|
||||
default="layer",
|
||||
metadata={"help": "Whether to use layer-wise or ratio-wise BAdam optimizer."},
|
||||
)
|
||||
badam_start_block: Optional[int] = field(
|
||||
badam_start_block: int | None = field(
|
||||
default=None,
|
||||
metadata={"help": "The starting block index for layer-wise BAdam."},
|
||||
)
|
||||
badam_switch_mode: Optional[Literal["ascending", "descending", "random", "fixed"]] = field(
|
||||
badam_switch_mode: Literal["ascending", "descending", "random", "fixed"] | None = field(
|
||||
default="ascending",
|
||||
metadata={"help": "the strategy of picking block to update for layer-wise BAdam."},
|
||||
)
|
||||
badam_switch_interval: Optional[int] = field(
|
||||
badam_switch_interval: int | None = field(
|
||||
default=50,
|
||||
metadata={
|
||||
"help": "Number of steps to update the block for layer-wise BAdam. Use -1 to disable the block update."
|
||||
@@ -406,15 +406,15 @@ class SwanLabArguments:
|
||||
default=False,
|
||||
metadata={"help": "Whether or not to use the SwanLab (an experiment tracking and visualization tool)."},
|
||||
)
|
||||
swanlab_project: Optional[str] = field(
|
||||
swanlab_project: str | None = field(
|
||||
default="llamafactory",
|
||||
metadata={"help": "The project name in SwanLab."},
|
||||
)
|
||||
swanlab_workspace: Optional[str] = field(
|
||||
swanlab_workspace: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "The workspace name in SwanLab."},
|
||||
)
|
||||
swanlab_run_name: Optional[str] = field(
|
||||
swanlab_run_name: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "The experiment name in SwanLab."},
|
||||
)
|
||||
@@ -422,19 +422,19 @@ class SwanLabArguments:
|
||||
default="cloud",
|
||||
metadata={"help": "The mode of SwanLab."},
|
||||
)
|
||||
swanlab_api_key: Optional[str] = field(
|
||||
swanlab_api_key: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "The API key for SwanLab."},
|
||||
)
|
||||
swanlab_logdir: Optional[str] = field(
|
||||
swanlab_logdir: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "The log directory for SwanLab."},
|
||||
)
|
||||
swanlab_lark_webhook_url: Optional[str] = field(
|
||||
swanlab_lark_webhook_url: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "The Lark(飞书) webhook URL for SwanLab."},
|
||||
)
|
||||
swanlab_lark_secret: Optional[str] = field(
|
||||
swanlab_lark_secret: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "The Lark(飞书) secret for SwanLab."},
|
||||
)
|
||||
@@ -510,7 +510,7 @@ class FinetuningArguments(
|
||||
default=False,
|
||||
metadata={"help": "Whether or not to disable the shuffling of the training set."},
|
||||
)
|
||||
early_stopping_steps: Optional[int] = field(
|
||||
early_stopping_steps: int | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Number of steps to stop training if the `metric_for_best_model` does not improve."},
|
||||
)
|
||||
@@ -530,11 +530,11 @@ class FinetuningArguments(
|
||||
return arg
|
||||
|
||||
self.freeze_trainable_modules: list[str] = split_arg(self.freeze_trainable_modules)
|
||||
self.freeze_extra_modules: Optional[list[str]] = split_arg(self.freeze_extra_modules)
|
||||
self.freeze_extra_modules: list[str] | None = split_arg(self.freeze_extra_modules)
|
||||
self.lora_alpha: int = self.lora_alpha or self.lora_rank * 2
|
||||
self.lora_target: list[str] = split_arg(self.lora_target)
|
||||
self.oft_target: list[str] = split_arg(self.oft_target)
|
||||
self.additional_target: Optional[list[str]] = split_arg(self.additional_target)
|
||||
self.additional_target: list[str] | None = split_arg(self.additional_target)
|
||||
self.galore_target: list[str] = split_arg(self.galore_target)
|
||||
self.apollo_target: list[str] = split_arg(self.apollo_target)
|
||||
self.use_ref_model = self.stage == "dpo" and self.pref_loss not in ["orpo", "simpo"]
|
||||
|
||||
@@ -17,12 +17,11 @@
|
||||
|
||||
import json
|
||||
from dataclasses import asdict, dataclass, field, fields
|
||||
from typing import Any, Literal, Optional, Union
|
||||
from typing import Any, Literal, Self
|
||||
|
||||
import torch
|
||||
from omegaconf import OmegaConf
|
||||
from transformers.training_args import _convert_str_dict
|
||||
from typing_extensions import Self
|
||||
|
||||
from ..extras.constants import AttentionFunction, EngineName, QuantizationMethod, RopeScaling
|
||||
from ..extras.logging import get_logger
|
||||
@@ -35,13 +34,13 @@ logger = get_logger(__name__)
|
||||
class BaseModelArguments:
|
||||
r"""Arguments pertaining to the model."""
|
||||
|
||||
model_name_or_path: Optional[str] = field(
|
||||
model_name_or_path: str | None = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Path to the model weight or identifier from huggingface.co/models or modelscope.cn/models."
|
||||
},
|
||||
)
|
||||
adapter_name_or_path: Optional[str] = field(
|
||||
adapter_name_or_path: str | None = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": (
|
||||
@@ -50,11 +49,11 @@ class BaseModelArguments:
|
||||
)
|
||||
},
|
||||
)
|
||||
adapter_folder: Optional[str] = field(
|
||||
adapter_folder: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "The folder containing the adapter weights to load."},
|
||||
)
|
||||
cache_dir: Optional[str] = field(
|
||||
cache_dir: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Where to store the pre-trained models downloaded from huggingface.co or modelscope.cn."},
|
||||
)
|
||||
@@ -70,17 +69,17 @@ class BaseModelArguments:
|
||||
default=False,
|
||||
metadata={"help": "Whether or not the special tokens should be split during the tokenization process."},
|
||||
)
|
||||
add_tokens: Optional[str] = field(
|
||||
add_tokens: str | None = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Non-special tokens to be added into the tokenizer. Use commas to separate multiple tokens."
|
||||
},
|
||||
)
|
||||
add_special_tokens: Optional[str] = field(
|
||||
add_special_tokens: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Special tokens to be added into the tokenizer. Use commas to separate multiple tokens."},
|
||||
)
|
||||
new_special_tokens_config: Optional[str] = field(
|
||||
new_special_tokens_config: str | None = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": (
|
||||
@@ -110,7 +109,7 @@ class BaseModelArguments:
|
||||
default=True,
|
||||
metadata={"help": "Whether or not to use memory-efficient model loading."},
|
||||
)
|
||||
rope_scaling: Optional[RopeScaling] = field(
|
||||
rope_scaling: RopeScaling | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Which scaling strategy should be adopted for the RoPE embeddings."},
|
||||
)
|
||||
@@ -122,7 +121,7 @@ class BaseModelArguments:
|
||||
default=False,
|
||||
metadata={"help": "Enable shift short attention (S^2-Attn) proposed by LongLoRA."},
|
||||
)
|
||||
mixture_of_depths: Optional[Literal["convert", "load"]] = field(
|
||||
mixture_of_depths: Literal["convert", "load"] | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Convert the model to mixture-of-depths (MoD) or load the MoD model."},
|
||||
)
|
||||
@@ -138,7 +137,7 @@ class BaseModelArguments:
|
||||
default=False,
|
||||
metadata={"help": "Whether or not to enable liger kernel for faster training."},
|
||||
)
|
||||
moe_aux_loss_coef: Optional[float] = field(
|
||||
moe_aux_loss_coef: float | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Coefficient of the auxiliary router loss in mixture-of-experts model."},
|
||||
)
|
||||
@@ -182,15 +181,15 @@ class BaseModelArguments:
|
||||
default="auto",
|
||||
metadata={"help": "Data type for model weights and activations at inference."},
|
||||
)
|
||||
hf_hub_token: Optional[str] = field(
|
||||
hf_hub_token: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Auth token to log in with Hugging Face Hub."},
|
||||
)
|
||||
ms_hub_token: Optional[str] = field(
|
||||
ms_hub_token: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Auth token to log in with ModelScope Hub."},
|
||||
)
|
||||
om_hub_token: Optional[str] = field(
|
||||
om_hub_token: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Auth token to log in with Modelers Hub."},
|
||||
)
|
||||
@@ -283,7 +282,7 @@ class QuantizationArguments:
|
||||
default=QuantizationMethod.BNB,
|
||||
metadata={"help": "Quantization method to use for on-the-fly quantization."},
|
||||
)
|
||||
quantization_bit: Optional[int] = field(
|
||||
quantization_bit: int | None = field(
|
||||
default=None,
|
||||
metadata={"help": "The number of bits to quantize the model using on-the-fly quantization."},
|
||||
)
|
||||
@@ -295,7 +294,7 @@ class QuantizationArguments:
|
||||
default=True,
|
||||
metadata={"help": "Whether or not to use double quantization in bitsandbytes int4 training."},
|
||||
)
|
||||
quantization_device_map: Optional[Literal["auto"]] = field(
|
||||
quantization_device_map: Literal["auto"] | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Device map used to infer the 4-bit quantized model, needs bitsandbytes>=0.43.0."},
|
||||
)
|
||||
@@ -375,7 +374,7 @@ class ProcessorArguments:
|
||||
class ExportArguments:
|
||||
r"""Arguments pertaining to the model export."""
|
||||
|
||||
export_dir: Optional[str] = field(
|
||||
export_dir: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Path to the directory to save the exported model."},
|
||||
)
|
||||
@@ -387,11 +386,11 @@ class ExportArguments:
|
||||
default="cpu",
|
||||
metadata={"help": "The device used in model export, use `auto` to accelerate exporting."},
|
||||
)
|
||||
export_quantization_bit: Optional[int] = field(
|
||||
export_quantization_bit: int | None = field(
|
||||
default=None,
|
||||
metadata={"help": "The number of bits to quantize the exported model."},
|
||||
)
|
||||
export_quantization_dataset: Optional[str] = field(
|
||||
export_quantization_dataset: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Path to the dataset or dataset name to use in quantizing the exported model."},
|
||||
)
|
||||
@@ -407,7 +406,7 @@ class ExportArguments:
|
||||
default=False,
|
||||
metadata={"help": "Whether or not to save the `.bin` files instead of `.safetensors`."},
|
||||
)
|
||||
export_hub_model_id: Optional[str] = field(
|
||||
export_hub_model_id: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "The name of the repository if push the model to the Hugging Face hub."},
|
||||
)
|
||||
@@ -437,7 +436,7 @@ class VllmArguments:
|
||||
default=32,
|
||||
metadata={"help": "Maximum rank of all LoRAs in the vLLM engine."},
|
||||
)
|
||||
vllm_config: Optional[Union[dict, str]] = field(
|
||||
vllm_config: dict | str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Config to initialize the vllm engine. Please use JSON strings."},
|
||||
)
|
||||
@@ -463,7 +462,7 @@ class SGLangArguments:
|
||||
default=-1,
|
||||
metadata={"help": "Tensor parallel size for the SGLang engine."},
|
||||
)
|
||||
sglang_config: Optional[Union[dict, str]] = field(
|
||||
sglang_config: dict | str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Config to initialize the SGLang engine. Please use JSON strings."},
|
||||
)
|
||||
@@ -487,21 +486,21 @@ class KTransformersArguments:
|
||||
default=False,
|
||||
metadata={"help": "Whether To Use KTransformers Optimizations For LoRA Training."},
|
||||
)
|
||||
kt_optimize_rule: Optional[str] = field(
|
||||
kt_optimize_rule: str | None = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Path To The KTransformers Optimize Rule; See https://github.com/kvcache-ai/ktransformers/."
|
||||
},
|
||||
)
|
||||
cpu_infer: Optional[int] = field(
|
||||
cpu_infer: int | None = field(
|
||||
default=32,
|
||||
metadata={"help": "Number Of CPU Cores Used For Computation."},
|
||||
)
|
||||
chunk_size: Optional[int] = field(
|
||||
chunk_size: int | None = field(
|
||||
default=8192,
|
||||
metadata={"help": "Chunk Size Used For CPU Compute In KTransformers."},
|
||||
)
|
||||
mode: Optional[str] = field(
|
||||
mode: str | None = field(
|
||||
default="normal",
|
||||
metadata={"help": "Normal Or Long_Context For Llama Models."},
|
||||
)
|
||||
@@ -539,17 +538,17 @@ class ModelArguments(
|
||||
The class on the most right will be displayed first.
|
||||
"""
|
||||
|
||||
compute_dtype: Optional[torch.dtype] = field(
|
||||
compute_dtype: torch.dtype | None = field(
|
||||
default=None,
|
||||
init=False,
|
||||
metadata={"help": "Torch data type for computing model outputs, derived from `fp/bf16`. Do not specify it."},
|
||||
)
|
||||
device_map: Optional[Union[str, dict[str, Any]]] = field(
|
||||
device_map: str | dict[str, Any] | None = field(
|
||||
default=None,
|
||||
init=False,
|
||||
metadata={"help": "Device map for model placement, derived from training stage. Do not specify it."},
|
||||
)
|
||||
model_max_length: Optional[int] = field(
|
||||
model_max_length: int | None = field(
|
||||
default=None,
|
||||
init=False,
|
||||
metadata={"help": "The maximum input length for model, derived from `cutoff_len`. Do not specify it."},
|
||||
|
||||
@@ -18,7 +18,7 @@
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional, Union
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
import transformers
|
||||
@@ -65,7 +65,7 @@ else:
|
||||
_TRAIN_MCA_CLS = tuple()
|
||||
|
||||
|
||||
def read_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> Union[dict[str, Any], list[str]]:
|
||||
def read_args(args: dict[str, Any] | list[str] | None = None) -> dict[str, Any] | list[str]:
|
||||
r"""Get arguments from the command line or a config file."""
|
||||
if args is not None:
|
||||
return args
|
||||
@@ -83,7 +83,7 @@ def read_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> Union[
|
||||
|
||||
|
||||
def _parse_args(
|
||||
parser: "HfArgumentParser", args: Optional[Union[dict[str, Any], list[str]]] = None, allow_extra_keys: bool = False
|
||||
parser: "HfArgumentParser", args: dict[str, Any] | list[str] | None = None, allow_extra_keys: bool = False
|
||||
) -> tuple[Any]:
|
||||
args = read_args(args)
|
||||
if isinstance(args, dict):
|
||||
@@ -205,13 +205,13 @@ def _check_extra_dependencies(
|
||||
check_version("rouge_chinese", mandatory=True)
|
||||
|
||||
|
||||
def _parse_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _TRAIN_CLS:
|
||||
def _parse_train_args(args: dict[str, Any] | list[str] | None = None) -> _TRAIN_CLS:
|
||||
parser = HfArgumentParser(_TRAIN_ARGS)
|
||||
allow_extra_keys = is_env_enabled("ALLOW_EXTRA_ARGS")
|
||||
return _parse_args(parser, args, allow_extra_keys=allow_extra_keys)
|
||||
|
||||
|
||||
def _parse_train_mca_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _TRAIN_MCA_CLS:
|
||||
def _parse_train_mca_args(args: dict[str, Any] | list[str] | None = None) -> _TRAIN_MCA_CLS:
|
||||
parser = HfArgumentParser(_TRAIN_MCA_ARGS)
|
||||
allow_extra_keys = is_env_enabled("ALLOW_EXTRA_ARGS")
|
||||
model_args, data_args, training_args, finetuning_args, generating_args = _parse_args(
|
||||
@@ -232,25 +232,25 @@ def _configure_mca_training_args(training_args, data_args, finetuning_args) -> N
|
||||
finetuning_args.use_mca = True
|
||||
|
||||
|
||||
def _parse_infer_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _INFER_CLS:
|
||||
def _parse_infer_args(args: dict[str, Any] | list[str] | None = None) -> _INFER_CLS:
|
||||
parser = HfArgumentParser(_INFER_ARGS)
|
||||
allow_extra_keys = is_env_enabled("ALLOW_EXTRA_ARGS")
|
||||
return _parse_args(parser, args, allow_extra_keys=allow_extra_keys)
|
||||
|
||||
|
||||
def _parse_eval_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _EVAL_CLS:
|
||||
def _parse_eval_args(args: dict[str, Any] | list[str] | None = None) -> _EVAL_CLS:
|
||||
parser = HfArgumentParser(_EVAL_ARGS)
|
||||
allow_extra_keys = is_env_enabled("ALLOW_EXTRA_ARGS")
|
||||
return _parse_args(parser, args, allow_extra_keys=allow_extra_keys)
|
||||
|
||||
|
||||
def get_ray_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> RayArguments:
|
||||
def get_ray_args(args: dict[str, Any] | list[str] | None = None) -> RayArguments:
|
||||
parser = HfArgumentParser(RayArguments)
|
||||
(ray_args,) = _parse_args(parser, args, allow_extra_keys=True)
|
||||
return ray_args
|
||||
|
||||
|
||||
def get_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _TRAIN_CLS:
|
||||
def get_train_args(args: dict[str, Any] | list[str] | None = None) -> _TRAIN_CLS:
|
||||
if is_env_enabled("USE_MCA"):
|
||||
model_args, data_args, training_args, finetuning_args, generating_args = _parse_train_mca_args(args)
|
||||
else:
|
||||
@@ -306,18 +306,15 @@ def get_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _
     if training_args.do_train and data_args.dataset is None:
         raise ValueError("Please specify dataset for training.")

-    if (training_args.do_eval or training_args.do_predict) and (
+    if (training_args.do_eval or training_args.do_predict or training_args.predict_with_generate) and (
         data_args.eval_dataset is None and data_args.val_size < 1e-6
     ):
-        raise ValueError("Please specify dataset for evaluation.")
+        raise ValueError("Please make sure eval_dataset be provided or val_size >1e-6")

     if training_args.predict_with_generate:
         if is_deepspeed_zero3_enabled():
             raise ValueError("`predict_with_generate` is incompatible with DeepSpeed ZeRO-3.")

-        if data_args.eval_dataset is None:
-            raise ValueError("Cannot use `predict_with_generate` if `eval_dataset` is None.")
-
         if finetuning_args.compute_accuracy:
             raise ValueError("Cannot use `predict_with_generate` and `compute_accuracy` together.")

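A practical consequence of the loosened check above: `predict_with_generate` alone now triggers the evaluation-split requirement, even without `do_eval` or `do_predict`. A hedged configuration sketch (dataset names are placeholders):

# hypothetical argument dict passed to get_train_args(); dataset names are illustrative
train_args = {
    "stage": "sft",
    "do_train": True,
    "predict_with_generate": True,
    "dataset": "my_train_set",      # placeholder
    "eval_dataset": "my_eval_set",  # required here; omitting it with val_size == 0 now raises ValueError
}
# alternatively, drop "eval_dataset" and set "val_size": 0.1 to carve the eval split from "dataset"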
@@ -476,7 +473,7 @@ def get_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _
|
||||
return model_args, data_args, training_args, finetuning_args, generating_args
|
||||
|
||||
|
||||
def get_infer_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _INFER_CLS:
|
||||
def get_infer_args(args: dict[str, Any] | list[str] | None = None) -> _INFER_CLS:
|
||||
model_args, data_args, finetuning_args, generating_args = _parse_infer_args(args)
|
||||
|
||||
# Setup logging
|
||||
@@ -511,7 +508,7 @@ def get_infer_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _
|
||||
return model_args, data_args, finetuning_args, generating_args
|
||||
|
||||
|
||||
def get_eval_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _EVAL_CLS:
|
||||
def get_eval_args(args: dict[str, Any] | list[str] | None = None) -> _EVAL_CLS:
|
||||
model_args, data_args, eval_args, finetuning_args = _parse_eval_args(args)
|
||||
|
||||
# Setup logging
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
|
||||
import json
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Literal, Optional, Union
|
||||
from typing import Literal
|
||||
|
||||
from transformers import Seq2SeqTrainingArguments
|
||||
from transformers.training_args import _convert_str_dict
|
||||
@@ -40,7 +40,7 @@ else:
|
||||
class RayArguments:
|
||||
r"""Arguments pertaining to the Ray training."""
|
||||
|
||||
ray_run_name: Optional[str] = field(
|
||||
ray_run_name: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "The training results will be saved at `<ray_storage_path>/ray_run_name`."},
|
||||
)
|
||||
@@ -48,7 +48,7 @@ class RayArguments:
|
||||
default="./saves",
|
||||
metadata={"help": "The storage path to save training results to"},
|
||||
)
|
||||
ray_storage_filesystem: Optional[Literal["s3", "gs", "gcs"]] = field(
|
||||
ray_storage_filesystem: Literal["s3", "gs", "gcs"] | None = field(
|
||||
default=None,
|
||||
metadata={"help": "The storage filesystem to use. If None specified, local filesystem will be used."},
|
||||
)
|
||||
@@ -56,7 +56,7 @@ class RayArguments:
|
||||
default=1,
|
||||
metadata={"help": "The number of workers for Ray training. Default is 1 worker."},
|
||||
)
|
||||
resources_per_worker: Union[dict, str] = field(
|
||||
resources_per_worker: dict | str = field(
|
||||
default_factory=lambda: {"GPU": 1},
|
||||
metadata={"help": "The resources per worker for Ray training. Default is to use 1 GPU per worker."},
|
||||
)
|
||||
@@ -64,7 +64,7 @@ class RayArguments:
|
||||
default="PACK",
|
||||
metadata={"help": "The placement strategy for Ray training. Default is PACK."},
|
||||
)
|
||||
ray_init_kwargs: Optional[Union[dict, str]] = field(
|
||||
ray_init_kwargs: dict | str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "The arguments to pass to ray.init for Ray training. Default is None."},
|
||||
)
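For readers configuring the Ray path, a hedged sketch of an argument dict built only from the field names visible in the hunks above (values are illustrative):

ray_args = {
    "ray_run_name": "sft_demo_run",      # results land at <ray_storage_path>/<ray_run_name>
    "ray_storage_path": "./saves",
    "ray_storage_filesystem": "s3",      # or "gs"/"gcs"; leave unset to use the local filesystem
    "resources_per_worker": {"GPU": 1},
    "ray_init_kwargs": {"runtime_env": {"env_vars": {"TOKENIZERS_PARALLELISM": "false"}}},
}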
|
||||
|
||||
@@ -15,7 +15,6 @@
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Any, Optional, TypedDict
|
||||
|
||||
import torch
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoModelForCausalLM,
|
||||
@@ -158,6 +157,7 @@ def load_model(
|
||||
if model is None and not lazy_load:
|
||||
init_kwargs["config"] = config
|
||||
init_kwargs["pretrained_model_name_or_path"] = model_args.model_name_or_path
|
||||
init_kwargs["torch_dtype"] = "auto"
|
||||
|
||||
if model_args.mixture_of_depths == "load":
|
||||
model = load_mod_pretrained_model(**init_kwargs)
|
||||
@@ -205,10 +205,6 @@ def load_model(
|
||||
|
||||
if not is_trainable:
|
||||
model.requires_grad_(False)
|
||||
for param in model.parameters():
|
||||
if param.data.dtype == torch.float32 and model_args.compute_dtype != torch.float32:
|
||||
param.data = param.data.to(model_args.compute_dtype)
|
||||
|
||||
model.eval()
|
||||
else:
|
||||
model.train()
|
||||
|
||||
@@ -31,6 +31,18 @@ logger = logging.get_logger(__name__)
|
||||
def configure_attn_implementation(config: "PretrainedConfig", model_args: "ModelArguments") -> None:
|
||||
from transformers.utils import is_flash_attn_2_available
|
||||
|
||||
if getattr(config, "model_type", None) == "gpt_oss":
|
||||
from transformers.integrations.hub_kernels import load_and_register_kernel
|
||||
|
||||
flash_attn3_kernel = "kernels-community/vllm-flash-attn3"
|
||||
load_and_register_kernel(flash_attn3_kernel)
|
||||
setattr(config, "_attn_implementation", flash_attn3_kernel)
|
||||
setattr(config, "_attn_implementation_internal", flash_attn3_kernel)
|
||||
model_args.flash_attn = AttentionFunction.FA3
|
||||
|
||||
logger.info_rank0("Using FlashAttention-3 with attention sink for the gpt-oss model.")
|
||||
return
|
||||
|
||||
if getattr(config, "model_type", None) == "gemma2":
|
||||
if model_args.flash_attn == AttentionFunction.AUTO or model_args.flash_attn == AttentionFunction.FA2:
|
||||
if is_flash_attn_2_available():
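A minimal sketch of the FlashAttention-3 kernel hookup performed above for gpt-oss; the import path and kernel id are taken from the hunk, and availability depends on the installed transformers version and a compatible GPU (assumptions):

from transformers.integrations.hub_kernels import load_and_register_kernel

kernel_id = "kernels-community/vllm-flash-attn3"
load_and_register_kernel(kernel_id)
# after registration, the model config can point at the kernel as its attention implementation:
# config._attn_implementation = kernel_id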
|
||||
|
||||
@@ -20,9 +20,10 @@
|
||||
|
||||
import inspect
|
||||
import os
|
||||
from collections.abc import Callable
|
||||
from functools import WRAPPER_ASSIGNMENTS, partial, wraps
|
||||
from types import MethodType
|
||||
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
@@ -77,6 +77,12 @@ def apply_liger_kernel(
|
||||
from liger_kernel.transformers import apply_liger_kernel_to_qwen3 as apply_liger_kernel
|
||||
elif model_type == "qwen3_moe":
|
||||
from liger_kernel.transformers import apply_liger_kernel_to_qwen3_moe as apply_liger_kernel
|
||||
elif model_type == "gpt_oss":
|
||||
try:
|
||||
from liger_kernel.transformers import apply_liger_kernel_to_gpt_oss as apply_liger_kernel
|
||||
except ImportError:
|
||||
logger.warning_rank0("Please install liger-kernel from https://github.com/Comet0322/Liger-Kernel.")
|
||||
return
|
||||
else:
|
||||
logger.warning_rank0("Current model does not support liger kernel.")
|
||||
return
|
||||
|
||||
@@ -82,6 +82,11 @@ def add_z3_leaf_module(model: "PreTrainedModel") -> None:
|
||||
|
||||
_set_z3_leaf_modules(model, [Glm4vMoeTextMoE])
|
||||
|
||||
if model_type == "gpt_oss":
|
||||
from transformers.models.gpt_oss.modeling_gpt_oss import GptOssMLP
|
||||
|
||||
_set_z3_leaf_modules(model, [GptOssMLP])
|
||||
|
||||
if model_type == "jamba":
|
||||
from transformers.models.jamba.modeling_jamba import JambaSparseMoeBlock
|
||||
|
||||
|
||||
@@ -22,7 +22,7 @@ from typing import TYPE_CHECKING, Any
|
||||
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from transformers import BitsAndBytesConfig, EetqConfig, FineGrainedFP8Config, GPTQConfig, HqqConfig
|
||||
from transformers import BitsAndBytesConfig, EetqConfig, GPTQConfig, HqqConfig
|
||||
from transformers.integrations import is_deepspeed_zero3_enabled
|
||||
from transformers.modeling_utils import is_fsdp_enabled
|
||||
|
||||
@@ -94,10 +94,26 @@ def configure_quantization(
|
||||
quantization_config: dict[str, Any] = getattr(config, "quantization_config", None)
|
||||
quant_method = quantization_config.get("quant_method", "")
|
||||
|
||||
if quant_method != QuantizationMethod.MXFP4 and (is_deepspeed_zero3_enabled() or is_fsdp_enabled()):
|
||||
if quant_method not in (QuantizationMethod.MXFP4, QuantizationMethod.FP8) and (
|
||||
is_deepspeed_zero3_enabled() or is_fsdp_enabled()
|
||||
):
|
||||
# mxfp4 will dequant the model weights
|
||||
raise ValueError("DeepSpeed ZeRO-3 or FSDP is incompatible with PTQ-quantized models.")
|
||||
|
||||
if quant_method == QuantizationMethod.MXFP4:
|
||||
from transformers import Mxfp4Config
|
||||
|
||||
quant_config = Mxfp4Config(dequantize=True)
|
||||
init_kwargs["quantization_config"] = quant_config
|
||||
init_kwargs["ignore_mismatched_sizes"] = True
|
||||
|
||||
if quant_method == QuantizationMethod.FP8:
|
||||
from transformers import FineGrainedFP8Config
|
||||
|
||||
quant_config = FineGrainedFP8Config(dequantize=True)
|
||||
init_kwargs["quantization_config"] = quant_config
|
||||
init_kwargs["ignore_mismatched_sizes"] = True
|
||||
|
||||
if quant_method == QuantizationMethod.GPTQ:
|
||||
check_version("gptqmodel>=2.0.0", mandatory=True)
|
||||
quantization_config.pop("disable_exllama", None) # remove deprecated args
|
||||
@@ -110,10 +126,6 @@ def configure_quantization(
|
||||
check_version("aqlm>=1.1.0", mandatory=True)
|
||||
quantization_config["bits"] = 2
|
||||
|
||||
if quant_method == QuantizationMethod.FP8 and is_trainable:
|
||||
quant_config = FineGrainedFP8Config(dequantize=True)
|
||||
init_kwargs["quantization_config"] = quant_config
|
||||
|
||||
quant_bits = quantization_config.get("bits", "?")
|
||||
logger.info_rank0(f"Loading {quant_bits}-bit {quant_method.upper()}-quantized model.")
|
||||
|
||||
|
||||
@@ -51,9 +51,13 @@ logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def patch_qwen3_omni_moe_thinker_text_sparse_moe_block():
|
||||
if is_transformers_version_greater_than("4.57.0"):
|
||||
if is_transformers_version_greater_than("4.57.0") and not is_transformers_version_greater_than("4.58.0"):
|
||||
from .model_utils.moe import Qwen3OmniMoeThinkerTextSparseMoeBlock
|
||||
|
||||
logger.warning_rank0(
|
||||
"You are using transformers with 4.x version, the Qwen3OmniMoeThinkerTextSparseMoeBlock will have some issues about deepspeed zero2 and fsdp2 training, so that we patched this model to avoid it. Transformers v5.0.0rc0 has fixed the issue, you can also try to update the transformers to using qwen3_omni. See more information on https://github.com/hiyouga/LLaMA-Factory/issues/9628."
|
||||
)
|
||||
|
||||
modeling_qwen3_omni_moe.Qwen3OmniMoeThinkerTextSparseMoeBlock = Qwen3OmniMoeThinkerTextSparseMoeBlock
|
||||
|
||||
|
||||
@@ -152,11 +156,8 @@ def patch_config(
|
||||
# deepspeed zero3 is not compatible with low_cpu_mem_usage
|
||||
init_kwargs["low_cpu_mem_usage"] = model_args.low_cpu_mem_usage and (not is_deepspeed_zero3_enabled())
|
||||
|
||||
# do not cast data type of the model deepspeed zero3 without qlora
|
||||
if not (is_deepspeed_zero3_enabled() and model_args.quantization_bit is None):
|
||||
init_kwargs["torch_dtype"] = model_args.compute_dtype
|
||||
|
||||
if init_kwargs["low_cpu_mem_usage"] and not is_fsdp_enabled(): # fsdp does not need device map
|
||||
# fsdp/deepspeed zero3 does not need device map
|
||||
if not (is_deepspeed_zero3_enabled() or is_fsdp_enabled()) and init_kwargs["low_cpu_mem_usage"]:
|
||||
if "device_map" not in init_kwargs and model_args.device_map:
|
||||
init_kwargs["device_map"] = model_args.device_map # device map requires low_cpu_mem_usage=True
|
||||
|
||||
|
||||
src/llamafactory/train/dpo/ktrainer.py (new file, 62 lines)
@@ -0,0 +1,62 @@
|
||||
# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
|
||||
#
|
||||
# This code is inspired by the HuggingFace's TRL library.
|
||||
# https://github.com/huggingface/trl/blob/v0.8.0/trl/trainer/dpo_trainer.py
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
from ktransformers.sft.lora import KTrainer # type: ignore
|
||||
from typing_extensions import override
|
||||
|
||||
from ..trainer_utils import get_batch_logps, nested_detach
|
||||
from .trainer import CustomDPOTrainer
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import PreTrainedModel
|
||||
|
||||
|
||||
class KDPOTrainer(KTrainer, CustomDPOTrainer):
|
||||
@override
|
||||
def concatenated_forward(
|
||||
self, model: "PreTrainedModel", batch: dict[str, "torch.Tensor"], is_ref_model: bool = False
|
||||
) -> tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
|
||||
r"""Compute the sum log probabilities of the labels under given logits if loss_type is not IPO, ORPO or SimPO.
|
||||
|
||||
Otherwise the average log probabilities.
|
||||
"""
|
||||
if self.finetuning_args.use_ref_model:
|
||||
batch = nested_detach(batch, clone=True) # avoid error
|
||||
|
||||
labels = batch.pop("labels") # dpo do not need compute loss in forward
|
||||
all_logits: torch.Tensor = model(**batch, return_dict=True, use_cache=False).logits.to(torch.float32)
|
||||
all_logits = all_logits.to("cpu")
|
||||
labels = labels.to(all_logits.device)
|
||||
all_logps, valid_length = get_batch_logps(
|
||||
logits=all_logits, labels=labels, ld_alpha=(self.ld_alpha if not is_ref_model else None)
|
||||
)
|
||||
if self.loss_type in ["ipo", "orpo", "simpo"]:
|
||||
all_logps = all_logps / valid_length
|
||||
|
||||
batch_size = batch["input_ids"].size(0) // 2
|
||||
chosen_logps, rejected_logps = all_logps.split(batch_size, dim=0)
|
||||
chosen_logits, rejected_logits = all_logits.split(batch_size, dim=0)
|
||||
chosen_length, _ = valid_length.split(batch_size, dim=0)
|
||||
|
||||
if self.loss_type in ["ipo", "orpo", "simpo"]:
|
||||
return chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_logps
|
||||
else:
|
||||
return chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_logps / chosen_length
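To make the tensor bookkeeping above concrete: the forward pass packs chosen examples first and rejected ones second, so one split along the batch dimension recovers both halves. A self-contained toy (values are random placeholders):

import torch

all_logps = torch.randn(8)           # 4 chosen + 4 rejected sequence log-probabilities
batch_size = all_logps.size(0) // 2  # mirrors batch["input_ids"].size(0) // 2 above
chosen_logps, rejected_logps = all_logps.split(batch_size, dim=0)
assert chosen_logps.shape == rejected_logps.shape == torch.Size([4])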
|
||||
@@ -218,9 +218,10 @@ class CustomDPOTrainer(DPOTrainer):
|
||||
if self.finetuning_args.use_ref_model:
|
||||
batch = nested_detach(batch, clone=True) # avoid error
|
||||
|
||||
labels = batch.pop("labels") # dpo do not need compute loss in forward
|
||||
all_logits: torch.Tensor = model(**batch, return_dict=True, use_cache=False).logits.to(torch.float32)
|
||||
all_logps, valid_length = get_batch_logps(
|
||||
logits=all_logits, labels=batch["labels"], ld_alpha=(self.ld_alpha if not is_ref_model else None)
|
||||
logits=all_logits, labels=labels, ld_alpha=(self.ld_alpha if not is_ref_model else None)
|
||||
)
|
||||
if self.loss_type in ["ipo", "orpo", "simpo"]:
|
||||
all_logps = all_logps / valid_length
|
||||
|
||||
@@ -24,7 +24,6 @@ from ...extras.ploting import plot_loss
|
||||
from ...hparams import ModelArguments
|
||||
from ...model import load_model, load_tokenizer
|
||||
from ..trainer_utils import create_modelcard_and_push, create_ref_model
|
||||
from .trainer import CustomDPOTrainer
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -63,6 +62,16 @@ def run_dpo(
|
||||
else:
|
||||
ref_model = None
|
||||
|
||||
if model_args.use_kt:
|
||||
from ktransformers.util.globals import GLOBAL_CONFIG # type: ignore
|
||||
|
||||
from .ktrainer import KDPOTrainer as CustomDPOTrainer
|
||||
|
||||
GLOBAL_CONFIG._config["mod"] = "sft"
|
||||
|
||||
else:
|
||||
from .trainer import CustomDPOTrainer
|
||||
|
||||
# Initialize our Trainer
|
||||
trainer = CustomDPOTrainer(
|
||||
model=model,
|
||||
|
||||
@@ -1,18 +0,0 @@
|
||||
# Copyright 2025 the KVCache.AI team, Approaching AI, and the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .workflow import run_sft
|
||||
|
||||
|
||||
__all__ = ["run_sft"]
|
||||
@@ -1,113 +0,0 @@
|
||||
# Copyright 2025 the KVCache.AI team, Approaching AI, and the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from ...data import SFTDataCollatorWith4DAttentionMask, get_dataset, get_template_and_fix_tokenizer
|
||||
from ...extras.constants import IGNORE_INDEX
|
||||
from ...extras.logging import get_logger
|
||||
from ...extras.misc import calculate_tps
|
||||
from ...extras.ploting import plot_loss
|
||||
from ...model import load_model, load_tokenizer
|
||||
from ..trainer_utils import create_modelcard_and_push
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import Seq2SeqTrainingArguments, TrainerCallback
|
||||
|
||||
from ...hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def run_sft(
|
||||
model_args: "ModelArguments",
|
||||
data_args: "DataArguments",
|
||||
training_args: "Seq2SeqTrainingArguments",
|
||||
finetuning_args: "FinetuningArguments",
|
||||
generating_args: "GeneratingArguments",
|
||||
callbacks: Optional[list["TrainerCallback"]] = None,
|
||||
):
|
||||
tokenizer_module = load_tokenizer(model_args)
|
||||
tokenizer = tokenizer_module["tokenizer"]
|
||||
template = get_template_and_fix_tokenizer(tokenizer, data_args)
|
||||
dataset_module = get_dataset(template, model_args, data_args, training_args, stage="sft", **tokenizer_module)
|
||||
model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
|
||||
|
||||
from ktransformers.util.globals import GLOBAL_CONFIG
|
||||
|
||||
GLOBAL_CONFIG._config["mod"] = "sft"
|
||||
|
||||
if getattr(model, "is_quantized", False) and not training_args.do_train:
|
||||
setattr(model, "_hf_peft_config_loaded", True) # hack here: make model compatible with prediction
|
||||
|
||||
data_collator = SFTDataCollatorWith4DAttentionMask(
|
||||
template=template,
|
||||
model=model if not training_args.predict_with_generate else None,
|
||||
pad_to_multiple_of=8 if training_args.do_train else None, # for shift short attention
|
||||
label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id,
|
||||
block_diag_attn=model_args.block_diag_attn,
|
||||
attn_implementation=getattr(model.config, "_attn_implementation", None),
|
||||
compute_dtype=model_args.compute_dtype,
|
||||
**tokenizer_module,
|
||||
)
|
||||
|
||||
# Metric utils
|
||||
metric_module = {}
|
||||
if training_args.predict_with_generate:
|
||||
raise NotImplementedError("`predict_with_generate` is not supported in KTransformers SFT yet.")
|
||||
elif finetuning_args.compute_accuracy:
|
||||
raise NotImplementedError("`compute_accuracy` is not supported in KTransformers SFT yet.")
|
||||
|
||||
# Initialize our Trainer
|
||||
from ktransformers.sft.lora import KTrainer
|
||||
|
||||
trainer = KTrainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
tokenizer=tokenizer_module,
|
||||
data_collator=data_collator,
|
||||
callbacks=callbacks,
|
||||
**dataset_module,
|
||||
**metric_module,
|
||||
)
|
||||
trainer.model_accepts_loss_kwargs = False
|
||||
|
||||
# Training
|
||||
if training_args.do_train:
|
||||
model.config.use_cache = False
|
||||
train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
|
||||
trainer.save_model()
|
||||
if finetuning_args.include_effective_tokens_per_second:
|
||||
train_result.metrics["effective_tokens_per_sec"] = calculate_tps(
|
||||
dataset_module["train_dataset"], train_result.metrics, stage="sft"
|
||||
)
|
||||
|
||||
trainer.log_metrics("train", train_result.metrics)
|
||||
trainer.save_metrics("train", train_result.metrics)
|
||||
trainer.save_state()
|
||||
if trainer.is_world_process_zero() and finetuning_args.plot_loss:
|
||||
keys = ["loss"]
|
||||
if isinstance(dataset_module.get("eval_dataset"), dict):
|
||||
keys += sum(
|
||||
[[f"eval_{key}_loss", f"eval_{key}_accuracy"] for key in dataset_module["eval_dataset"].keys()], []
|
||||
)
|
||||
else:
|
||||
keys += ["eval_loss", "eval_accuracy"]
|
||||
|
||||
plot_loss(training_args.output_dir, keys=keys)
|
||||
|
||||
# Create model card
|
||||
create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args)
|
||||
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -28,7 +28,7 @@ if TYPE_CHECKING:
|
||||
class ComputeAccuracy:
|
||||
r"""Compute reward accuracy and support `batch_eval_metrics`."""
|
||||
|
||||
def _dump(self) -> Optional[dict[str, float]]:
|
||||
def _dump(self) -> dict[str, float] | None:
|
||||
result = None
|
||||
if hasattr(self, "score_dict"):
|
||||
result = {k: float(np.mean(v)) for k, v in self.score_dict.items()}
|
||||
@@ -39,7 +39,7 @@ class ComputeAccuracy:
|
||||
def __post_init__(self):
|
||||
self._dump()
|
||||
|
||||
def __call__(self, eval_preds: "EvalPrediction", compute_result: bool = True) -> Optional[dict[str, float]]:
|
||||
def __call__(self, eval_preds: "EvalPrediction", compute_result: bool = True) -> dict[str, float] | None:
|
||||
chosen_scores, rejected_scores = numpify(eval_preds.predictions[0]), numpify(eval_preds.predictions[1])
|
||||
if not chosen_scores.shape:
|
||||
self.score_dict["accuracy"].append(chosen_scores > rejected_scores)
|
||||
|
||||
@@ -68,6 +68,12 @@ def run_sft(
|
||||
|
||||
# Metric utils
|
||||
metric_module = {}
|
||||
if model_args.use_kt:
|
||||
if training_args.predict_with_generate:
|
||||
raise NotImplementedError("`predict_with_generate` is not supported in KTransformers SFT yet.")
|
||||
elif finetuning_args.compute_accuracy:
|
||||
raise NotImplementedError("`compute_accuracy` is not supported in KTransformers SFT yet.")
|
||||
|
||||
if training_args.predict_with_generate:
|
||||
metric_module["compute_metrics"] = ComputeSimilarity(tokenizer=tokenizer)
|
||||
elif finetuning_args.compute_accuracy:
|
||||
@@ -92,6 +98,25 @@ def run_sft(
|
||||
gen_kwargs["pad_token_id"] = tokenizer.pad_token_id
|
||||
|
||||
# Initialize our Trainer
|
||||
if model_args.use_kt:
|
||||
from ktransformers.sft.lora import KTrainer # type: ignore
|
||||
from ktransformers.util.globals import GLOBAL_CONFIG # type: ignore
|
||||
|
||||
GLOBAL_CONFIG._config["mod"] = "sft"
|
||||
|
||||
trainer = KTrainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
tokenizer=tokenizer_module,
|
||||
data_collator=data_collator,
|
||||
callbacks=callbacks,
|
||||
**dataset_module,
|
||||
**metric_module,
|
||||
)
|
||||
trainer.model_accepts_loss_kwargs = False
|
||||
model.config.use_cache = False
|
||||
|
||||
else:
|
||||
trainer = CustomSeq2SeqTrainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
|
||||
@@ -84,8 +84,6 @@ def load_reference_model(
|
||||
model: AutoModelForCausalLMWithValueHead = AutoModelForCausalLMWithValueHead.from_pretrained(
|
||||
model_path, torch_dtype=torch.float16, device_map="auto"
|
||||
)
|
||||
if not is_trainable:
|
||||
model.v_head = model.v_head.to(torch.float16)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
@@ -19,9 +19,9 @@
|
||||
|
||||
import json
|
||||
import os
|
||||
from collections.abc import Mapping
|
||||
from collections.abc import Callable, Mapping
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
|
||||
import torch
|
||||
from transformers import Trainer
|
||||
|
||||
@@ -85,13 +85,7 @@ def _training_function(config: dict[str, Any]) -> None:
|
||||
elif finetuning_args.stage == "pt":
|
||||
run_pt(model_args, data_args, training_args, finetuning_args, callbacks)
|
||||
elif finetuning_args.stage == "sft":
|
||||
if model_args.use_kt:
|
||||
from .ksft.workflow import run_sft as run_sft_kt
|
||||
|
||||
run_sft_kt(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
|
||||
else:
|
||||
run_sft(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
|
||||
|
||||
elif finetuning_args.stage == "rm":
|
||||
run_rm(model_args, data_args, training_args, finetuning_args, callbacks)
|
||||
elif finetuning_args.stage == "ppo":
|
||||
|
||||
@@ -15,10 +15,20 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Utility functions used by the distributed interface.
|
||||
|
||||
Including:
|
||||
- Environment info (rank, world_size, local_rank, etc.)
|
||||
- Accelerator info (device type, device count, etc.)
|
||||
- Collective communication operations (all_gather, all_reduce, broadcast)
|
||||
- Synchronize processes and ensure main-process-first execution order
|
||||
"""
|
||||
|
||||
import os
|
||||
from collections.abc import Callable
|
||||
from contextlib import contextmanager
|
||||
from enum import Enum, unique
|
||||
from functools import lru_cache
|
||||
from functools import lru_cache, wraps
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
@@ -46,6 +56,22 @@ class ReduceOp(str, Enum):
|
||||
MIN = "min"
|
||||
|
||||
|
||||
def requires_accelerator(fn):
|
||||
"""Decorator to check if torch.accelerator is available.
|
||||
|
||||
Note: this api requires torch>=2.7.0, otherwise it will raise an AttributeError or RuntimeError
|
||||
"""
|
||||
|
||||
@wraps(fn)
|
||||
def wrapper(*args, **kwargs):
|
||||
if not hasattr(torch, "accelerator"):
|
||||
raise RuntimeError("torch.accelerator is not available, please upgrade torch to 2.7.0 or higher.")
|
||||
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
return wrapper
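Usage sketch for the decorator above (assuming `requires_accelerator` is in scope from this helper module; `torch.accelerator` needs torch >= 2.7):

import torch

@requires_accelerator
def current_device_type() -> str:
    # resolves to "cuda", "xpu", "npu", "mps", or falls back to "cpu"
    accelerator = torch.accelerator.current_accelerator(check_available=True)
    return accelerator.type if accelerator is not None else "cpu"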
|
||||
|
||||
|
||||
def is_distributed() -> bool:
|
||||
"""Check if distributed environment is available."""
|
||||
return os.getenv("RANK") is not None
|
||||
@@ -72,105 +98,105 @@ def get_local_world_size() -> int:
|
||||
|
||||
|
||||
@lru_cache
|
||||
@requires_accelerator
|
||||
def get_current_accelerator(check_available: bool = True) -> torch.device:
|
||||
"""Get current accelerator.
|
||||
|
||||
Note: this api requires torch>=2.7.0, otherwise it will raise an AttributeError or RuntimeError
|
||||
"""
|
||||
if not hasattr(torch, "accelerator"):
|
||||
raise RuntimeError("torch.accelerator is not available, please upgrade torch to 2.7.0 or higher.")
|
||||
|
||||
"""Get current accelerator."""
|
||||
accelerator = torch.accelerator.current_accelerator(check_available=check_available)
|
||||
if accelerator is None:
|
||||
return torch.device(DeviceType.CPU.value)
|
||||
return accelerator or torch.device(DeviceType.CPU.value)
|
||||
|
||||
return accelerator
|
||||
|
||||
@lru_cache
|
||||
@requires_accelerator
|
||||
def get_device_count() -> int:
|
||||
"""Get the number of available devices."""
|
||||
return torch.accelerator.device_count()
|
||||
|
||||
|
||||
@requires_accelerator
|
||||
def synchronize() -> None:
|
||||
"""Synchronize all processes."""
|
||||
torch.accelerator.synchronize()
|
||||
|
||||
|
||||
@requires_accelerator
|
||||
def set_device() -> None:
|
||||
"""Set current accelerator."""
|
||||
torch.accelerator.set_device_index(get_local_rank())
|
||||
|
||||
|
||||
def is_torch_cuda_available():
|
||||
"""Check if CUDA is available."""
|
||||
return get_current_accelerator().type == DeviceType.CUDA
|
||||
|
||||
|
||||
def is_torch_mps_available():
|
||||
"""Check if MPS is available."""
|
||||
return get_current_accelerator().type == DeviceType.MPS
|
||||
|
||||
|
||||
def is_torch_npu_available():
|
||||
"""Check if NPU is available."""
|
||||
return get_current_accelerator().type == DeviceType.NPU
|
||||
|
||||
|
||||
def is_torch_xpu_available():
|
||||
"""Check if XPU is available."""
|
||||
return get_current_accelerator().type == DeviceType.XPU
|
||||
|
||||
|
||||
def get_current_device() -> "torch.device":
|
||||
r"""Get the current available device."""
|
||||
if is_torch_xpu_available():
|
||||
device = "xpu:{}".format(os.getenv("LOCAL_RANK", "0"))
|
||||
elif is_torch_npu_available():
|
||||
device = "npu:{}".format(os.getenv("LOCAL_RANK", "0"))
|
||||
elif is_torch_mps_available():
|
||||
device = "mps:{}".format(os.getenv("LOCAL_RANK", "0"))
|
||||
elif is_torch_cuda_available():
|
||||
device = "cuda:{}".format(os.getenv("LOCAL_RANK", "0"))
|
||||
def operate_tensorlike(fn: Callable[[...], Tensor], data: TensorLike, **kwargs) -> TensorLike:
|
||||
"""Operate tensorlike data on current accelerator."""
|
||||
device = get_current_accelerator()
|
||||
is_tensor = isinstance(data, torch.Tensor)
|
||||
is_ndarray = isinstance(data, np.ndarray)
|
||||
|
||||
if is_tensor:
|
||||
orig_device = data.device
|
||||
data = data.to(device=device)
|
||||
elif is_ndarray:
|
||||
data = torch.from_numpy(data).to(device=device, dtype=torch.float)
|
||||
else:
|
||||
device = "cpu"
|
||||
data = torch.tensor(data, dtype=torch.float, device=device)
|
||||
|
||||
return torch.device(device)
|
||||
result = fn(data, **kwargs)
|
||||
|
||||
|
||||
def get_device_count() -> int:
|
||||
r"""Get the number of available devices."""
|
||||
if is_torch_xpu_available():
|
||||
return torch.xpu.device_count()
|
||||
elif is_torch_npu_available():
|
||||
return torch.npu.device_count()
|
||||
elif is_torch_mps_available():
|
||||
return torch.mps.device_count()
|
||||
elif is_torch_cuda_available():
|
||||
return torch.cuda.device_count()
|
||||
if is_tensor:
|
||||
return result.to(orig_device)
|
||||
elif is_ndarray:
|
||||
return result.cpu().numpy()
|
||||
elif result.numel() == 1:
|
||||
return result.item()
|
||||
else:
|
||||
return 0
|
||||
return result.tolist()
|
||||
|
||||
|
||||
def all_gather(tensor: Tensor, group: Optional[ProcessGroup] = None) -> Tensor:
|
||||
"""Gathers the tensor from all ranks and concats them along the first dim."""
|
||||
"""Gathers the tensor from all ranks and stacks them at the first dim."""
|
||||
world_size = get_world_size()
|
||||
device = get_current_accelerator()
|
||||
output_tensor = torch.empty(world_size * tensor.numel(), dtype=tensor.dtype, device=device)
|
||||
output_tensor = torch.empty(world_size * tensor.numel(), dtype=tensor.dtype, device=tensor.device)
|
||||
dist.all_gather_into_tensor(output_tensor, tensor, group=group)
|
||||
return output_tensor.view(-1, *tensor.size()[1:])
|
||||
return output_tensor.view(-1, *tensor.size())
|
||||
|
||||
|
||||
def all_reduce(data: TensorLike, op: ReduceOp = ReduceOp.MEAN, group: Optional[ProcessGroup] = None) -> TensorLike:
|
||||
def all_reduce(tensor: Tensor, op: ReduceOp = ReduceOp.MEAN, group: Optional[ProcessGroup] = None) -> Tensor:
|
||||
"""Performs all reduce in the given process group."""
|
||||
device = get_current_accelerator()
|
||||
is_ndarray = isinstance(data, np.ndarray)
|
||||
is_tensor = isinstance(data, torch.Tensor)
|
||||
|
||||
if is_ndarray:
|
||||
data = torch.from_numpy(data).to(device=device, dtype=torch.float)
|
||||
elif not is_tensor:
|
||||
data = torch.tensor(data, dtype=torch.float, device=device)
|
||||
|
||||
reduce_ops = {
|
||||
ReduceOp.MEAN: dist.ReduceOp.SUM,
|
||||
ReduceOp.SUM: dist.ReduceOp.SUM,
|
||||
ReduceOp.MAX: dist.ReduceOp.MAX,
|
||||
ReduceOp.MIN: dist.ReduceOp.MIN,
|
||||
}
|
||||
dist.all_reduce(data, op=reduce_ops[op], group=group)
|
||||
dist.all_reduce(tensor, op=reduce_ops[op], group=group)
|
||||
if op == ReduceOp.MEAN: # ReduceOp.AVG is not supported by the NPU backend
|
||||
data /= dist.get_world_size(group=group)
|
||||
tensor /= dist.get_world_size(group=group)
|
||||
|
||||
if is_tensor:
|
||||
return data
|
||||
elif is_ndarray:
|
||||
return data.cpu().numpy()
|
||||
elif data.numel() == 1:
|
||||
return data.item()
|
||||
else:
|
||||
return data.tolist()
|
||||
return tensor
|
||||
|
||||
|
||||
def broadcast(tensor: Tensor, src: int = 0, group: Optional[ProcessGroup] = None) -> Tensor:
|
||||
"""Broadcasts the tensor from the src process to all other processes."""
|
||||
dist.broadcast(tensor, src=src, group=group)
|
||||
return tensor
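A usage sketch for the new `operate_tensorlike` wrapper (assuming it is importable from this helper module); in a single-process run the wrapped collective can be a plain identity, which makes the round-tripping between Python scalars, NumPy arrays, and tensors visible:

import numpy as np
import torch

def identity(data: torch.Tensor) -> torch.Tensor:
    return data  # stand-in for all_reduce/all_gather/broadcast when not distributed

losses = operate_tensorlike(identity, np.array([0.5, 0.7]))  # comes back as an np.ndarray
scalar = operate_tensorlike(identity, 3.0)                   # plain floats round-trip via .item()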
|
||||
|
||||
|
||||
@contextmanager
|
||||
|
||||
@@ -15,26 +15,27 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""A unified interface for model parallelism and data parallelism.
|
||||
|
||||
Supports model parallelism types:
|
||||
- mp_replicate: Replicate model across multiple devices.
|
||||
- mp_shard: Shard model across multiple devices.
|
||||
|
||||
And data parallelism types:
|
||||
- dp: Data parallelism.
|
||||
- cp: Context parallelism.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import timedelta
|
||||
from enum import Enum
|
||||
from typing import Any, Optional
|
||||
|
||||
from torch.distributed import init_process_group
|
||||
from torch.distributed import barrier, destroy_process_group, init_process_group
|
||||
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
|
||||
|
||||
from ..utils.types import DistributedConfig, ProcessGroup, Tensor, TensorLike
|
||||
from .helper import (
|
||||
ReduceOp,
|
||||
all_gather,
|
||||
all_reduce,
|
||||
get_current_accelerator,
|
||||
get_local_rank,
|
||||
get_local_world_size,
|
||||
get_rank,
|
||||
get_world_size,
|
||||
is_distributed,
|
||||
)
|
||||
from . import helper
|
||||
|
||||
|
||||
class Dim(str, Enum):
|
||||
@@ -52,32 +53,32 @@ class DistributedStrategy:
|
||||
|
||||
mp_replicate_size: int = 1
|
||||
"""Model parallel replicate size, default to 1."""
|
||||
mp_shard_size: Optional[int] = None
|
||||
mp_shard_size: int | None = None
|
||||
"""Model parallel shard size, default to world_size // mp_replicate_size."""
|
||||
dp_size: Optional[int] = None
|
||||
dp_size: int | None = None
|
||||
"""Data parallel size, default to world_size // cp_size."""
|
||||
cp_size: int = 1
|
||||
"""Context parallel size, default to 1."""
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if not is_distributed():
|
||||
if not helper.is_distributed():
|
||||
self.mp_shard_size = 1
|
||||
elif self.mp_shard_size is None:
|
||||
self.mp_shard_size = get_world_size() // self.mp_replicate_size
|
||||
elif self.mp_replicate_size * self.mp_shard_size != get_world_size():
|
||||
self.mp_shard_size = helper.get_world_size() // self.mp_replicate_size
|
||||
elif self.mp_replicate_size * self.mp_shard_size != helper.get_world_size():
|
||||
raise ValueError(
|
||||
f"mp_replicate_size * mp_shard_size must equal to world_size, "
|
||||
f"got {self.mp_replicate_size} * {self.mp_shard_size} != {get_world_size()}."
|
||||
f"got {self.mp_replicate_size} * {self.mp_shard_size} != {helper.get_world_size()}."
|
||||
)
|
||||
|
||||
if not is_distributed():
|
||||
if not helper.is_distributed():
|
||||
self.dp_size = 1
|
||||
elif self.dp_size is None:
|
||||
self.dp_size = get_world_size() // self.cp_size
|
||||
elif self.dp_size * self.cp_size != get_world_size():
|
||||
self.dp_size = helper.get_world_size() // self.cp_size
|
||||
elif self.dp_size * self.cp_size != helper.get_world_size():
|
||||
raise ValueError(
|
||||
f"dp_size * cp_size must equal to world_size, "
|
||||
f"got {self.dp_size} * {self.cp_size} != {get_world_size()}."
|
||||
f"got {self.dp_size} * {self.cp_size} != {helper.get_world_size()}."
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -106,20 +107,6 @@ class DistributedInterface:
|
||||
|
||||
_instance: Optional["DistributedInterface"] = None
|
||||
_initialized: bool = False
|
||||
_is_distributed = is_distributed()
|
||||
_rank = get_rank()
|
||||
_world_size = get_world_size()
|
||||
_local_rank = get_local_rank()
|
||||
_local_world_size = get_local_world_size()
|
||||
|
||||
strategy: Optional[DistributedStrategy] = None
|
||||
"""Distributed strategy."""
|
||||
model_device_mesh: Optional[DeviceMesh] = None
|
||||
"""Model parallel device mesh."""
|
||||
data_device_mesh: Optional[DeviceMesh] = None
|
||||
"""Data parallel device mesh."""
|
||||
current_accelerator = get_current_accelerator()
|
||||
"""Current accelerator."""
|
||||
|
||||
def __new__(cls, *args: Any, **kwargs: Any) -> "DistributedInterface":
|
||||
"""Singleton pattern."""
|
||||
@@ -128,10 +115,18 @@ class DistributedInterface:
|
||||
|
||||
return cls._instance
|
||||
|
||||
def __init__(self, config: Optional[DistributedConfig] = None) -> None:
|
||||
def __init__(self, config: DistributedConfig | None = None) -> None:
|
||||
if self._initialized:
|
||||
return
|
||||
|
||||
self._is_distributed = helper.is_distributed()
|
||||
self._rank = helper.get_rank()
|
||||
self._world_size = helper.get_world_size()
|
||||
self._local_rank = helper.get_local_rank()
|
||||
self._local_world_size = helper.get_local_world_size()
|
||||
self.current_accelerator = helper.get_current_accelerator()
|
||||
self.device_count = helper.get_device_count()
|
||||
|
||||
if config is None:
|
||||
self.strategy = DistributedStrategy()
|
||||
timeout = 18000
|
||||
@@ -145,6 +140,7 @@ class DistributedInterface:
|
||||
timeout = config.get("timeout", 18000)
|
||||
|
||||
if self._is_distributed:
|
||||
helper.set_device()
|
||||
init_process_group(timeout=timedelta(seconds=timeout))
|
||||
self.model_device_mesh = init_device_mesh(
|
||||
device_type=self.current_accelerator.type,
|
||||
@@ -169,65 +165,84 @@ class DistributedInterface:
|
||||
f"model_device_mesh={self.model_device_mesh}, data_device_mesh={self.data_device_mesh}"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_device_mesh(cls, dim: Optional[Dim] = None) -> Optional[DeviceMesh]:
|
||||
def get_device_mesh(self, dim: Dim | None = None) -> DeviceMesh | None:
|
||||
"""Get device mesh for specified dimension."""
|
||||
if dim is None:
|
||||
raise ValueError("dim must be specified.")
|
||||
elif cls.model_device_mesh is None:
|
||||
elif self.model_device_mesh is None:
|
||||
return None
|
||||
elif dim in cls.strategy.data_mesh_dim_names:
|
||||
return cls.data_device_mesh[dim.value]
|
||||
elif dim in self.strategy.data_mesh_dim_names:
|
||||
return self.data_device_mesh[dim.value]
|
||||
else:
|
||||
return cls.model_device_mesh[dim.value]
|
||||
return self.model_device_mesh[dim.value]
|
||||
|
||||
@classmethod
|
||||
def get_group(cls, dim: Optional[Dim] = None) -> Optional[ProcessGroup]:
|
||||
def get_group(self, dim: Dim | None = None) -> Optional[ProcessGroup]:
|
||||
"""Get process group for specified dimension."""
|
||||
if cls.model_device_mesh is None or dim is None:
|
||||
if self.model_device_mesh is None or dim is None:
|
||||
return None
|
||||
else:
|
||||
return cls.get_device_mesh(dim).get_group()
|
||||
return self.get_device_mesh(dim).get_group()
|
||||
|
||||
@classmethod
|
||||
def get_rank(cls, dim: Optional[Dim] = None) -> int:
|
||||
def get_rank(self, dim: Dim | None = None) -> int:
|
||||
"""Get parallel rank for specified dimension."""
|
||||
if cls.model_device_mesh is None:
|
||||
if self.model_device_mesh is None:
|
||||
return 0
|
||||
elif dim is None:
|
||||
return cls._rank
|
||||
return self._rank
|
||||
else:
|
||||
return cls.get_device_mesh(dim).get_local_rank()
|
||||
return self.get_device_mesh(dim).get_local_rank()
|
||||
|
||||
@classmethod
|
||||
def get_world_size(cls, dim: Optional[Dim] = None) -> int:
|
||||
def get_world_size(self, dim: Dim | None = None) -> int:
|
||||
"""Get parallel size for specified dimension."""
|
||||
if cls.model_device_mesh is None:
|
||||
if self.model_device_mesh is None:
|
||||
return 1
|
||||
elif dim is None:
|
||||
return cls._world_size
|
||||
return self._world_size
|
||||
else:
|
||||
return cls.get_device_mesh(dim).size()
|
||||
return self.get_device_mesh(dim).size()
|
||||
|
||||
@classmethod
|
||||
def get_local_rank(cls) -> int:
|
||||
def get_local_rank(self) -> int:
|
||||
"""Get parallel local rank."""
|
||||
return cls._local_rank
|
||||
return self._local_rank
|
||||
|
||||
@classmethod
|
||||
def get_local_world_size(cls) -> int:
|
||||
def get_local_world_size(self) -> int:
|
||||
"""Get parallel local world size."""
|
||||
return cls._local_world_size
|
||||
return self._local_world_size
|
||||
|
||||
@classmethod
|
||||
def all_gather(cls, data: Tensor, dim: Optional[Dim] = Dim.DP) -> Tensor:
|
||||
def all_gather(self, data: Tensor, dim: Dim | None = Dim.DP) -> Tensor:
|
||||
"""Gather tensor across specified parallel group."""
|
||||
return all_gather(data, cls.get_group(dim)) if cls.model_device_mesh is not None else data
|
||||
if self.model_device_mesh is not None:
|
||||
return helper.operate_tensorlike(helper.all_gather, data, group=self.get_group(dim))
|
||||
else:
|
||||
return data
|
||||
|
||||
@classmethod
|
||||
def all_reduce(cls, data: TensorLike, op: ReduceOp = ReduceOp.MEAN, dim: Optional[Dim] = Dim.DP) -> TensorLike:
|
||||
def all_reduce(
|
||||
self, data: TensorLike, op: helper.ReduceOp = helper.ReduceOp.MEAN, dim: Dim | None = Dim.DP
|
||||
) -> TensorLike:
|
||||
"""Reduce tensor across specified parallel group."""
|
||||
return all_reduce(data, op, cls.get_group(dim)) if cls.model_device_mesh is not None else data
|
||||
if self.model_device_mesh is not None:
|
||||
return helper.operate_tensorlike(helper.all_reduce, data, op=op, group=self.get_group(dim))
|
||||
else:
|
||||
return data
|
||||
|
||||
def broadcast(self, data: TensorLike, src: int = 0, dim: Dim | None = Dim.DP) -> TensorLike:
|
||||
"""Broadcast tensor across specified parallel group."""
|
||||
if self.model_device_mesh is not None:
|
||||
return helper.operate_tensorlike(helper.broadcast, data, src=src, group=self.get_group(dim))
|
||||
else:
|
||||
return data
|
||||
|
||||
def sync(self) -> None:
|
||||
"""Synchronize all processes."""
|
||||
helper.synchronize()
|
||||
|
||||
def barrier(self) -> None:
|
||||
"""Barrier all processes."""
|
||||
barrier()
|
||||
|
||||
def destroy(self) -> None:
|
||||
"""Destroy all processes."""
|
||||
destroy_process_group()
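A hedged usage sketch of the refactored interface above: parallel state now lives on the singleton instance rather than on the class, so callers construct it once and reuse it. In a single-process run the collectives fall through and return their inputs unchanged:

dist_iface = DistributedInterface()      # singleton: repeated construction returns the same object
avg_loss = dist_iface.all_reduce(0.25)   # averaged over the DP group when distributed, else returned as-is
if dist_iface.get_rank() == 0:
    print(f"world_size={dist_iface.get_world_size()}, device={dist_iface.current_accelerator}")
# dist_iface.destroy() tears down the process group at the end of a distributed run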
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -15,7 +15,7 @@
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional, Union
|
||||
from typing import Any
|
||||
|
||||
from omegaconf import OmegaConf
|
||||
from transformers import HfArgumentParser
|
||||
@@ -27,7 +27,7 @@ from .sample_args import SampleArguments
|
||||
from .training_args import TrainingArguments
|
||||
|
||||
|
||||
InputArgument = Optional[Union[dict[str, Any], list[str]]]
|
||||
InputArgument = dict[str, Any] | list[str] | None
|
||||
|
||||
|
||||
def validate_args(
|
||||
|
||||
@@ -18,7 +18,6 @@
|
||||
|
||||
import json
|
||||
from enum import Enum, unique
|
||||
from typing import Optional, Union
|
||||
|
||||
|
||||
class PluginConfig(dict):
|
||||
@@ -33,7 +32,7 @@ class PluginConfig(dict):
|
||||
return self["name"]
|
||||
|
||||
|
||||
PluginArgument = Optional[Union[PluginConfig, dict, str]]
|
||||
PluginArgument = PluginConfig | dict | str | None
|
||||
|
||||
|
||||
@unique
|
||||
@@ -74,7 +73,7 @@ def _convert_str_dict(data: dict) -> dict:
|
||||
return data
|
||||
|
||||
|
||||
def get_plugin_config(config: PluginArgument) -> Optional[PluginConfig]:
|
||||
def get_plugin_config(config: PluginArgument) -> PluginConfig | None:
|
||||
"""Get the plugin configuration from the argument value.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -14,12 +14,11 @@
|
||||
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataArguments:
|
||||
dataset: Optional[str] = field(
|
||||
dataset: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Path to the dataset."},
|
||||
)
|
||||
|
||||
@@ -14,7 +14,6 @@
|
||||
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
from .arg_utils import ModelClass, PluginConfig, get_plugin_config
|
||||
|
||||
@@ -36,15 +35,15 @@ class ModelArguments:
|
||||
default=ModelClass.LLM,
|
||||
metadata={"help": "Model class from Hugging Face."},
|
||||
)
|
||||
peft_config: Optional[PluginConfig] = field(
|
||||
peft_config: PluginConfig | None = field(
|
||||
default=None,
|
||||
metadata={"help": "PEFT configuration for the model."},
|
||||
)
|
||||
kernel_config: Optional[PluginConfig] = field(
|
||||
kernel_config: PluginConfig | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Kernel configuration for the model."},
|
||||
)
|
||||
quant_config: Optional[PluginConfig] = field(
|
||||
quant_config: PluginConfig | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Quantization configuration for the model."},
|
||||
)
|
||||
|
||||
@@ -14,7 +14,6 @@
|
||||
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
from uuid import uuid4
|
||||
|
||||
from .arg_utils import PluginConfig, get_plugin_config
|
||||
@@ -42,7 +41,7 @@ class TrainingArguments:
|
||||
default=False,
|
||||
metadata={"help": "Use bf16 for training."},
|
||||
)
|
||||
dist_config: Optional[PluginConfig] = field(
|
||||
dist_config: PluginConfig | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Distribution configuration for training."},
|
||||
)
|
||||
|
||||
@@ -27,7 +27,7 @@ Get Data Sample:
|
||||
|
||||
import os
|
||||
from collections.abc import Iterable
|
||||
from typing import Any, Union
|
||||
from typing import Any
|
||||
|
||||
from huggingface_hub import hf_hub_download
|
||||
from omegaconf import OmegaConf
|
||||
@@ -134,7 +134,7 @@ class DataEngine(Dataset):
|
||||
else:
|
||||
return len(self.data_index)
|
||||
|
||||
def __getitem__(self, index: Union[int, Any]) -> Union[Sample, list[Sample]]:
|
||||
def __getitem__(self, index: int | Any) -> Sample | list[Sample]:
|
||||
"""Get dataset item.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -97,7 +97,7 @@ class ModelLoader:
|
||||
self.args.model,
|
||||
config=self.model_config,
|
||||
dtype="auto",
|
||||
device_map=DistributedInterface.current_accelerator,
|
||||
device_map=DistributedInterface().current_accelerator,
|
||||
trust_remote_code=self.args.trust_remote_code,
|
||||
)
|
||||
|
||||
|
||||
@@ -12,36 +12,108 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from collections import defaultdict
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from ...utils.types import Processor, Tensor, TorchDataset
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
from torch.utils.data._utils.collate import default_collate
|
||||
|
||||
from ....extras.constants import IGNORE_INDEX
|
||||
from ...plugins.data_plugins.template import Template
|
||||
from ...utils.types import Processor, Tensor
|
||||
|
||||
|
||||
def len2culen(seqlens: "torch.Tensor") -> "torch.Tensor": # FIXME move to utils
|
||||
"""Convert sequence lengths to cumulative sequence lengths."""
|
||||
return F.pad(torch.cumsum(seqlens, dim=0), (1, 0)).type(torch.int32)
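A quick worked example of the helper above, since varlen attention kernels consume cumulative offsets rather than per-sequence lengths:

import torch
import torch.nn.functional as F

seqlens = torch.tensor([3, 5, 2])
cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0), (1, 0)).type(torch.int32)
print(cu_seqlens)  # tensor([ 0,  3,  8, 10], dtype=torch.int32)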
|
||||
|
||||
|
||||
class DataCollator:
|
||||
"""Default Data collator."""
|
||||
|
||||
def __init__(self, processor: Processor) -> None:
|
||||
self.processor = processor
|
||||
processor: "Processor" # processor name -> map to encode_messages function
|
||||
|
||||
def __post_init__(self):
|
||||
# callback for text tokenizer
|
||||
self.tokenizer = self.processor.tokenizer if hasattr(self.processor, "tokenizer") else self.processor
|
||||
|
||||
def __call__(self, features: list[dict[str, Any]]) -> dict[str, Tensor]:
|
||||
"""Collate features into a batch."""
|
||||
for feature in features:
|
||||
pass
|
||||
batch = defaultdict(list)
|
||||
|
||||
# batching features
|
||||
for feature in features:
|
||||
for key in feature.keys():
|
||||
batch[key].append(feature[key])
|
||||
|
||||
for key in batch.keys():
|
||||
# process padding features
|
||||
if key in ["input_ids", "attention_mask", "position_ids"]:
|
||||
padding_value = self.tokenizer.pad_token_id if key == "input_ids" else 0
|
||||
batch[key] = pad_sequence(batch[key], batch_first=True, padding_value=padding_value)
|
||||
elif key in ["labels"]:
|
||||
batch[key] = pad_sequence(batch[key], batch_first=True, padding_value=IGNORE_INDEX)
|
||||
else:
|
||||
batch[key] = default_collate(batch[key])
|
||||
|
||||
return batch
|
||||
# sft: messages
|
||||
# dpo: chosen_messages, rejected_messages
|
||||
|
||||
|
||||
class DataLoader:
|
||||
"""Default DataLoader."""
|
||||
@dataclass
|
||||
class DefaultCollator(DataCollator):
|
||||
"""Example for now."""
|
||||
|
||||
def __init__(self, dataset: TorchDataset) -> None:
|
||||
self.dataset = dataset
|
||||
# 1. Init stateful dataloader (tokenize)
|
||||
# 2. Add to buffer (2 * max seq len per device)
|
||||
# 3. Yield batch indexes (micro batch * grad acc)
|
||||
# a ) non pack + non dynamic
|
||||
# b ) non pack + dynamic
|
||||
# c ) pack + non dynamic
|
||||
# d ) pack + dynamic
|
||||
processor: "Processor" # processor name -> map to encode_messages function
|
||||
template: "Template"
|
||||
|
||||
def __call__(self, messages: list[list[dict[str, Any]]]) -> dict[str, Tensor]:
|
||||
features = []
|
||||
|
||||
# Check if data is already tokenized (contains input_ids)
|
||||
if messages and isinstance(messages[0], dict) and "input_ids" in messages[0]:
|
||||
for feature in messages:
|
||||
if not isinstance(feature, dict):
|
||||
raise ValueError(f"Expected dict but got {type(feature)}")
|
||||
tensor_feature = {
|
||||
k: torch.tensor(v, dtype=torch.long) if not isinstance(v, torch.Tensor) else v
|
||||
for k, v in feature.items()
|
||||
}
|
||||
features.append(tensor_feature)
|
||||
else:
|
||||
# raw messages need to be encoded
|
||||
for message in messages:
|
||||
encoded_message = self.template.encode_messages(self.tokenizer, message)
|
||||
encoded_message = {k: torch.tensor(v, dtype=torch.long) for k, v in encoded_message.items()}
|
||||
features.append(encoded_message)
|
||||
|
||||
return super().__call__(features)
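To make the padding rules behind these collators concrete, a self-contained toy (pad token id 0 is a placeholder; IGNORE_INDEX mirrors the -100 constant used for loss masking):

import torch
from torch.nn.utils.rnn import pad_sequence

IGNORE_INDEX = -100
input_ids = [torch.tensor([1, 2, 3]), torch.tensor([4, 5])]
labels = [torch.tensor([1, 2, 3]), torch.tensor([4, 5])]

padded_inputs = pad_sequence(input_ids, batch_first=True, padding_value=0)
padded_labels = pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
# padded_inputs -> [[1, 2, 3], [4, 5, 0]]; padded_labels -> [[1, 2, 3], [4, 5, -100]]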
|
||||
|
||||
|
||||
@dataclass
|
||||
class PairwiseCollator(DataCollator):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataCollatorWithPacking(DefaultCollator):
|
||||
"""Data collator with packing."""
|
||||
|
||||
processor: "Processor"
|
||||
template: "Template"
|
||||
|
||||
def __call__(self, features: Sequence[dict[str, "torch.Tensor"]]) -> dict[str, "torch.Tensor"]:
|
||||
seqlens = torch.tensor([len(feature["input_ids"]) for feature in features], dtype=torch.long)
|
||||
batch = {"cu_seqlens": len2culen(seqlens)}
|
||||
for input_name in features[0].keys():
|
||||
if input_name in ("input_ids", "attention_mask", "labels"):
|
||||
batch[input_name] = torch.cat([feature[input_name] for feature in features])
|
||||
else:
|
||||
batch[input_name] = default_collate([feature[input_name] for feature in features])
|
||||
|
||||
return batch
|
||||
|
||||
src/llamafactory/v1/core/trainer_utils/data_loader.py (new file, 277 lines)
@@ -0,0 +1,277 @@
|
||||
# Copyright 2025 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import copy
|
||||
import sys
|
||||
from collections.abc import Generator, Iterator
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
from torchdata.stateful_dataloader import StatefulDataLoader
|
||||
from torchdata.stateful_dataloader.sampler import StatefulDistributedSampler
|
||||
|
||||
from ...utils.batching_queue import BaseBatchingQueue
|
||||
from ...utils.logging import get_logger
|
||||
from ...utils.types import Processor, TorchDataset
|
||||
from .data_collator import DataCollator
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
# base dataloader
class DistributedDataloader(StatefulDataLoader):
    """Base Distributed DataLoader."""

    dataset: "TorchDataset"
    sampler: "StatefulDistributedSampler"

    def set_epoch(self, epoch: int) -> None:
        if self.sampler is not None and hasattr(self.sampler, "set_epoch"):
            self.sampler.set_epoch(epoch)
        elif hasattr(self.dataset, "set_epoch"):
            self.dataset.set_epoch(epoch)


@dataclass
class BaseDataLoader:
    """Base DataLoader."""

    processor: Processor

    def __init__(self, dataset: TorchDataset) -> None:
        self.dataset = dataset
        # guidelines: fetch until we get a fixed batch size.
        # save state_dict for the buffer.
        # resume with state.

        # 1. Init stateful dataloader (tokenize)
        # 2. Add to buffer (2 * max seq len per device)
        # 3. Yield batch indexes (micro batch * grad acc)
        #    a) non pack + non dynamic
        #    b) non pack + dynamic
        #    c) pack + non dynamic
        #    d) pack + dynamic

    def init_dataloader(self) -> None:
        # init dataloader
        pass

    def __iter__(self) -> Iterator:
        pass

    def __next__(self) -> Any:
        pass

@dataclass
class DataLoader:
    """Default DataLoader."""

    processor: "Processor"
    dataloader: "DistributedDataloader"
    batching_queue: "BaseBatchingQueue"
    collate_fn: "DataCollator"
    num_micro_batch: int = 1
    length: int = 0
    drop_last: bool = True

    def __init__(
        self,
        dataloader: Any,
        collate_fn: "DataCollator",
        num_micro_batch: int = 1,
        length: int = 0,
        drop_last: bool = True,
        batching_queue: Optional["BaseBatchingQueue"] = None,
    ) -> None:
        self.batching_queue = batching_queue
        self.num_micro_batch = num_micro_batch
        self.step = 0
        self._collate_fn = collate_fn
        self._dataloader = dataloader
        self._drop_last = drop_last
        self._data_iter: Iterator
        self._resume = False
        self._batch_data_iter: Generator

        if length > 0:
            self._length = length
        elif length == -1:
            self._length = sys.maxsize
        else:
            self._length = len(self._dataloader)

    def __len__(self):
        return self._length

    def __iter__(self) -> Iterator:
        if not self._resume:
            self.step = 0
            self._data_iter = iter(self._dataloader)
            self._batch_data_iter = self.batch_data_generator()
        self._resume = False
        return self

    def __next__(self):
        return next(self._batch_data_iter)  # FIXME maybe we can move origin_batch_data_generator to here

    def origin_batch_data_generator(self):
        """Standard pass-through generator used when no batching queue is configured."""
        while True:
            if self._length > 0 and self.step >= self._length:
                return

            try:
                batch = []
                data = next(self._data_iter)
                # split data into micro batches
                for i in range(0, len(data), self.num_micro_batch):
                    micro_batch = data[i : i + self.num_micro_batch]
                    if self._collate_fn:
                        micro_batch = self._collate_fn(micro_batch)
                    batch.append(micro_batch)
                yield batch
                self.step += 1
            except StopIteration:
                if self.step < self._length:
                    # Restart iterator to fill the requested length
                    self._data_iter = iter(self._dataloader)
                    try:
                        batch = []
                        data = next(self._data_iter)
                        for i in range(0, len(data), self.num_micro_batch):
                            micro_batch = data[i : i + self.num_micro_batch]
                            if self._collate_fn:
                                micro_batch = self._collate_fn(micro_batch)
                            batch.append(micro_batch)
                        yield batch
                        self.step += 1
                    except StopIteration:
                        return
                else:
                    return
            except Exception as e:
                logger.error(f"DataLoader origin_batch_data_generator exception: {e}")
                raise

    def batch_data_generator(self):
        if self.batching_queue is None:
            yield from self.origin_batch_data_generator()
            return

        batch = []

        while True:
            if self._length and self.step >= self._length:
                return

            if self.batching_queue.is_full_filled():
                micro_batch = self.batching_queue.get_micro_batch(self.step)
                if self._collate_fn:
                    micro_batch = self._collate_fn(micro_batch)
                batch.append(micro_batch)
                if len(batch) == self.num_micro_batch:
                    yield batch
                    self.step += 1
                    batch = []

            try:
                processing_item = next(self._data_iter)
            except Exception as e:
                if isinstance(e, StopIteration):
                    if self.step < self._length:
                        # restart the iterator until the requested length is reached
                        self._data_iter = iter(self._dataloader)
                        processing_item = next(self._data_iter)
                    elif not self._drop_last and not self.batching_queue.empty():
                        while not self.batching_queue.empty():
                            micro_batch = self.batching_queue.get_micro_batch(self.step)
                            if self._collate_fn:
                                micro_batch = self._collate_fn(micro_batch)
                            batch.append(micro_batch)
                            if len(batch) == self.num_micro_batch:
                                yield batch
                                self.step += 1
                                batch = []

                        while len(batch) < self.num_micro_batch:
                            padding_batch = copy.deepcopy(micro_batch)
                            padding_batch["is_padded"] = True
                            batch.append(padding_batch)
                        yield batch
                        self.step += 1
                        return
                    else:
                        return
                else:
                    logger.error(f"DataLoader iter data exception: {e}")
                    raise

            # put processing_item to buffer
            if isinstance(processing_item, dict):
                processing_item = [processing_item]

            for item in processing_item:
                self.batching_queue.put_item(item)

    def state_dict(self):
        # save state
        state = self.__dict__.copy()
        # remove internal fields
        for k in list(state.keys()):
            if k.startswith("_"):
                del state[k]

        # save dataloader state
        if hasattr(self._dataloader, "state_dict"):
            state["dataloader_state"] = self._dataloader.state_dict()
        elif hasattr(self._dataloader, "__getstate__"):
            state["dataloader_state"] = self._dataloader.__getstate__()

        batching_strategy = getattr(self, "batching_strategy", None)
        if batching_strategy and hasattr(batching_strategy, "state_dict"):
            state["batching_strategy_state"] = batching_strategy.state_dict()
        if "batching_strategy" in state:
            del state["batching_strategy"]

        return copy.deepcopy(state)

    def load_state_dict(self, state: dict[str, Any]):
        if state["num_micro_batch"] != self.num_micro_batch:
            logger.warning(
                f"num_micro_batch changed: [ {state['num_micro_batch']} -> {self.num_micro_batch} ], will clear prefetch buffer"
            )
            del state["num_micro_batch"]
        self.__dict__.update(state)
        self._resume = True

        if hasattr(self._dataloader, "load_state_dict"):
            self._dataloader.load_state_dict(state["dataloader_state"])
        elif hasattr(self._dataloader, "__setstate__"):
            self._dataloader.__setstate__(state["dataloader_state"])

        if "batching_strategy_state" in state:
            batching_strategy = getattr(self, "batching_strategy", None)
            if batching_strategy:
                batching_strategy.load_state_dict(state["batching_strategy_state"])
            del state["batching_strategy_state"]

        self._data_iter = iter(self._dataloader)
        self._batch_data_iter = self.batch_data_generator()

    def set_epoch(self, epoch: int) -> None:
        if hasattr(self._dataloader, "set_epoch"):
            self._dataloader.set_epoch(epoch)
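A minimal sketch of how this wrapper is meant to be driven with checkpointing, assuming an inner loader that exposes `state_dict`/`load_state_dict`; the `ToyLoader` below is a hypothetical stand-in, not the repository's `DistributedDataloader`, and the import path is taken from the file header above:

```python
# Assumed import path, derived from src/llamafactory/v1/core/trainer_utils/data_loader.py.
from llamafactory.v1.core.trainer_utils.data_loader import DataLoader


class ToyLoader:
    """Hypothetical inner loader exposing the state protocol the wrapper probes for."""

    def __init__(self, num_batches: int = 8, batch_size: int = 4):
        self.num_batches, self.batch_size, self._pos = num_batches, batch_size, 0

    def __len__(self) -> int:
        return self.num_batches

    def __iter__(self):
        while self._pos < self.num_batches:
            i = self._pos
            self._pos += 1
            yield [{"sample": i * self.batch_size + j} for j in range(self.batch_size)]

    def state_dict(self) -> dict:
        return {"pos": self._pos}

    def load_state_dict(self, state: dict) -> None:
        self._pos = state["pos"]


loader = DataLoader(dataloader=ToyLoader(), collate_fn=None, num_micro_batch=2)
for step, batch in enumerate(loader):  # each `batch` is a list of micro-batches for gradient accumulation
    if step == 1:
        ckpt = loader.state_dict()  # captures the step counter, num_micro_batch and the inner loader state
        break

resumed = DataLoader(dataloader=ToyLoader(), collate_fn=None, num_micro_batch=2)
resumed.load_state_dict(ckpt)  # restores the inner position and step counter before iteration continues
remaining = list(resumed)
```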
@@ -13,9 +13,7 @@
# limitations under the License.


from typing import Any, Literal, TypedDict

from typing_extensions import NotRequired
from typing import Any, Literal, NotRequired, TypedDict

from ...utils import logging
from ...utils.plugin import BasePlugin

@@ -15,7 +15,7 @@

import os
import random
from typing import Any, Literal, Optional, Union
from typing import Any, Literal

from datasets import load_dataset

@@ -70,7 +70,7 @@ class DataIndexPlugin(BasePlugin):
    """Plugin for adjusting dataset index."""

    def adjust_data_index(
        self, data_index: list[tuple[str, int]], size: Optional[int], weight: Optional[float]
        self, data_index: list[tuple[str, int]], size: int | None, weight: float | None
    ) -> list[tuple[str, int]]:
        """Adjust dataset index by size and weight.

@@ -95,8 +95,8 @@ class DataSelectorPlugin(BasePlugin):
    """Plugin for selecting dataset samples."""

    def select(
        self, data_index: list[tuple[str, int]], index: Union[slice, list[int], Any]
    ) -> Union[tuple[str, int], list[tuple[str, int]]]:
        self, data_index: list[tuple[str, int]], index: slice | list[int] | Any
    ) -> tuple[str, int] | list[tuple[str, int]]:
        """Select dataset samples.

        Args:

@@ -22,5 +22,112 @@ class Template:
    assistant_template: str
    system_template: str

    def render_message(self, message: "dict[str, str]") -> str:
    def render_message(self, message: dict[str, str]) -> str:
        return self.user_template.format(**message)


@dataclass
class QwenTemplate:
    message_template: str = "<|im_start|>{role}\n{content}<|im_end|>\n"  # FIXME if role: tool
    thinking_template: str = "<think>\n{content}\n</think>\n\n"

    def _extract_content(self, content_data: str | list[dict[str, str]]) -> str:
        if isinstance(content_data, str):
            return content_data.strip()

        if isinstance(content_data, list):
            parts = []
            for item in content_data:
                if item.get("type") == "text":
                    parts.append(item.get("text", ""))
                elif item.get("type") == "image_url":
                    pass
            return "\n".join(parts).strip()

        return ""

    def render_message(self, message: dict[str, str | list[dict[str, str]]]) -> str:
        role = message["role"]
        content = self._extract_content(message.get("content", ""))

        if role == "assistant":
            reasoning_content = message.get("reasoning_content", "")
            if reasoning_content:
                reasoning_content = self.thinking_template.format(content=str(reasoning_content).strip())
            return self.message_template.format(role="assistant", content=reasoning_content + content)
        else:
            return self.message_template.format(role=role, content=content)

    def encode_messages(self, tokenizer, messages: list[dict[str, str]], max_seq_len: int = 8192) -> dict[str, list[int]]:
        """Encode a list of messages."""
        input_ids, attention_mask, labels = [], [], []
        for message in messages:
            content_str = self.render_message(message)
            content_ids = tokenizer.encode(content_str, add_special_tokens=False)
            input_ids += content_ids
            attention_mask += [1] * len(content_ids)

            if "loss_weight" in message:
                loss_weight = message["loss_weight"]
            else:
                loss_weight = 1 if message["role"] == "assistant" else 0
            if loss_weight == 1:
                labels += content_ids
            else:
                labels += [-100] * len(content_ids)
        model_inputs = {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}
        model_inputs.update({"position_ids": list(range(len(input_ids)))})
        model_inputs = {k: v[-max_seq_len:] for k, v in model_inputs.items()}
        return model_inputs


if __name__ == "__main__":

    def to_qwen3_messages(template: QwenTemplate, messages: list[dict]):
        out = []
        for m in messages:
            role = m["role"]
            content = template._extract_content(m.get("content", ""))
            if role == "assistant":
                reasoning = (m.get("reasoning_content") or "").strip()
                if reasoning:
                    content = template.thinking_template.format(content=reasoning) + content
            out.append({"role": role, "content": content})
        return out

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained(
        "Qwen/Qwen3-30B-A3B-Thinking-2507",
        trust_remote_code=True,
    )

    test_messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {
            "role": "user",
            "content": [{"type": "text", "text": "1+1等于几?"}, {"type": "text", "text": "2+2等于几?"}],
        },
        {
            "role": "assistant",
            "reasoning_content": "这是一个简单的数学问题。1加1的结果是2。",
            "content": [{"type": "text", "text": "1+1=2"}, {"type": "text", "text": "2+2=4"}],
        },
    ]

    template = QwenTemplate()
    rendered_custom = "".join([template.render_message(m) for m in test_messages])

    qwen3_messages = to_qwen3_messages(template, test_messages)
    rendered_hf = tok.apply_chat_template(qwen3_messages, tokenize=False, add_generation_prompt=False)

    print("==== custom ====")
    print(rendered_custom)
    print("==== hf ====")
    print(rendered_hf)

    assert rendered_custom.strip() == rendered_hf.strip(), "Rendered text mismatch"

    ids_custom = tok.encode(rendered_custom, add_special_tokens=False)
    ids_hf = tok.apply_chat_template(qwen3_messages, tokenize=True, add_generation_prompt=False)
    assert ids_custom == ids_hf, f"Token ids mismatch: custom={len(ids_custom)} hf={len(ids_hf)}"

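A minimal sketch of the label masking that `encode_messages` produces, assuming (as the code above suggests) that only assistant turns contribute to the loss unless a message carries an explicit `loss_weight`; the messages and checks below are illustrative, not part of the repository's test suite:

```python
# QwenTemplate as defined in the template plugin above (its module path is not shown in this diff).
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("Qwen/Qwen3-30B-A3B-Thinking-2507", trust_remote_code=True)
template = QwenTemplate()

messages = [
    {"role": "user", "content": "What is 1+1?"},
    {"role": "assistant", "content": "1+1=2"},
]
inputs = template.encode_messages(tok, messages)

# Each message is rendered and tokenized independently, so the user prefix length is recoverable.
user_len = len(tok.encode(template.render_message(messages[0]), add_special_tokens=False))
assert all(label == -100 for label in inputs["labels"][:user_len])    # user tokens are ignored by the loss
assert inputs["labels"][user_len:] == inputs["input_ids"][user_len:]  # assistant tokens are supervised
assert inputs["position_ids"] == list(range(len(inputs["input_ids"])))
```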
@@ -13,7 +13,8 @@
# limitations under the License.

from abc import ABC, ABCMeta, abstractmethod
from typing import Any, Callable, Optional, Union
from collections.abc import Callable
from typing import Any, Optional

from ....accelerator.helper import DeviceType, get_current_accelerator
from ....utils.types import HFModel

@@ -38,7 +39,7 @@ class KernelRegistry:
        self._initialized = True

    def register(
        self, kernel_type: KernelType, device_type: DeviceType, kernel_impl: Optional[Callable[..., Any]]
        self, kernel_type: KernelType, device_type: DeviceType, kernel_impl: Callable[..., Any] | None
    ) -> None:
        """Register a kernel implementation.

@@ -56,7 +57,7 @@ class KernelRegistry:
        self._registry[kernel_type][device_type] = kernel_impl
        print(f"Registered kernel {kernel_type.name} for device {device_type.name}.")

    def get_kernel(self, kernel_type: KernelType, device_type: DeviceType) -> Optional[Callable[..., Any]]:
    def get_kernel(self, kernel_type: KernelType, device_type: DeviceType) -> Callable[..., Any] | None:
        return self._registry.get(kernel_type, {}).get(device_type)


@@ -105,9 +106,9 @@ class MetaKernel(ABC, metaclass=AutoRegisterKernelMeta):
        auto_register: Set to False to disable automatic registration (default: True).
    """

    type: Optional[KernelType] = None
    device: Optional[DeviceType] = None
    kernel: Optional[Callable] = None
    type: KernelType | None = None
    device: DeviceType | None = None
    kernel: Callable | None = None

    @classmethod
    @abstractmethod

@@ -228,7 +229,7 @@ def discover_kernels(model: HFModel = None) -> list[type[MetaKernel]]:
    return discovered_kernels


def apply_kernel(model: HFModel, kernel: Union[type[MetaKernel], Any], /, **kwargs) -> "HFModel":
def apply_kernel(model: HFModel, kernel: type[MetaKernel] | Any, /, **kwargs) -> "HFModel":
    """Call the MetaKernel's `apply` to perform the replacement.

    Corresponding replacement logic is maintained inside each kernel; the only

@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Literal, Optional, TypedDict
from typing import Literal, TypedDict

from peft import LoraConfig, PeftModel, get_peft_model

@@ -36,7 +36,7 @@ class FreezeConfigDict(TypedDict, total=False):
    """Plugin name."""
    freeze_trainable_layers: int
    """Freeze trainable layers."""
    freeze_trainable_modules: Optional[list[str]]
    freeze_trainable_modules: list[str] | None
    """Freeze trainable modules."""

220  src/llamafactory/v1/utils/batching_queue.py  Normal file
@@ -0,0 +1,220 @@
# Copyright 2025 Bytedance Ltd. and the LlamaFactory team.
#
# This code is inspired by Bytedance's VeOmni library.
# https://github.com/ByteDance-Seed/VeOmni/blob/v0.1.4/veomni/data/dynamic_batching.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC, abstractmethod
from typing import Any


class DynamicBatchSizeBuffer:
    """A buffer to store samples for dynamic batch size."""

    def __init__(self):
        self._buffer: list[dict[str, Any]] = []
        self._buffer_sample_lengths: list[int] = []
        self._deleted_indices: set[int] = set()
        self._current_index: int = 0
        self._total_token_count: int = 0

    def append(self, item: dict[str, Any]) -> None:
        """Append a sample to the buffer.

        Args:
            item: A sample to append to the buffer.
                The sample should be a dict with the following keys:
                - input_ids: torch.Tensor of shape (seq_len, )
                - attention_mask: torch.Tensor of shape (seq_len, )
        """
        self._buffer.append(item)
        sample_length = int(item["attention_mask"].sum().item())
        self._buffer_sample_lengths.append(sample_length)
        self._total_token_count += sample_length

    def get_samples(self, max_tokens_per_iteration: int, force: bool = True) -> list[dict[str, Any]]:
        """Get samples from the buffer that fit within the token budget.

        Args:
            max_tokens_per_iteration: Maximum number of tokens to retrieve.
            force: If True, the first available sample will be returned even
                if it exceeds the token budget.

        Returns:
            A list of samples that fit within the token budget.

        Raises:
            AssertionError: If no samples are found (should not happen in normal operation).
        """
        cum_seq_len = 0
        samples = []

        while self._current_index < len(self._buffer) and cum_seq_len < max_tokens_per_iteration:
            if self._current_index in self._deleted_indices:
                self._current_index += 1
                continue

            seq_len = self._buffer_sample_lengths[self._current_index]
            remaining_tokens = max_tokens_per_iteration - cum_seq_len

            # Check if we can add this sample
            can_add = (force and cum_seq_len == 0) or (seq_len <= remaining_tokens)

            if can_add:
                cum_seq_len += seq_len
                samples.append(self._buffer[self._current_index])
                self._deleted_indices.add(self._current_index)

            self._current_index += 1

        assert len(samples) > 0, "No samples found in buffer"
        return samples

    def __len__(self) -> int:
        """Return the number of samples in the buffer."""
        return len(self._buffer)

    @property
    def total_token_count(self) -> int:
        """Return the total number of tokens in the buffer."""
        return self._total_token_count

    def flush(self) -> None:
        """Drop the samples already handed out and reset the scan index."""
        tokens_to_remove = sum(self._buffer_sample_lengths[idx] for idx in self._deleted_indices)
        self._total_token_count -= tokens_to_remove

        buffer_length = len(self._buffer)
        self._buffer = [self._buffer[idx] for idx in range(buffer_length) if idx not in self._deleted_indices]
        self._buffer_sample_lengths = [
            self._buffer_sample_lengths[idx] for idx in range(buffer_length) if idx not in self._deleted_indices
        ]

        self._current_index = 0
        self._deleted_indices.clear()

class BaseBatchingQueue(ABC):
    """Base class for batching queue."""

    @abstractmethod
    def is_full_filled(self) -> bool:
        raise NotImplementedError("Subclasses must implement `is_full_filled`")

    @abstractmethod
    def put_item(self, item: dict[str, Any]) -> None:
        raise NotImplementedError("Subclasses must implement `put_item`")

    @abstractmethod
    def get_micro_batch(self, step: int) -> list[dict[str, Any]]:
        raise NotImplementedError("Subclasses must implement `get_micro_batch`")

    @abstractmethod
    def empty(self) -> bool:
        raise NotImplementedError("Subclasses must implement `empty`")


class IdentityPacker:
    def __init__(self, token_micro_bsz, bsz_warmup_steps, bsz_warmup_init_mbtoken):
        self.token_micro_bsz = token_micro_bsz
        self.bsz_warmup_steps = bsz_warmup_steps
        self.bsz_warmup_init_mbtoken = bsz_warmup_init_mbtoken

    def __call__(self, samples):
        return samples

    def get_token_num_to_request(self, cur_step, warmup):
        return (
            (self.token_micro_bsz - self.bsz_warmup_init_mbtoken) * cur_step // self.bsz_warmup_steps
            + self.bsz_warmup_init_mbtoken
            if warmup
            else self.token_micro_bsz
        )


class TextBatchingQueue(BaseBatchingQueue):
    """Batching text queue for text data."""

    def __init__(
        self,
        token_micro_bsz,
        buffer_size: int = 500,
        bsz_warmup_steps: int = -1,
        bsz_warmup_init_mbtoken: int = 200,
    ) -> None:
        super().__init__()
        self._step = 0
        self.token_micro_bsz = token_micro_bsz
        self.bsz_warmup_steps = bsz_warmup_steps
        self.buffer_size = buffer_size  # minimum samples in buffer
        self.buffer = DynamicBatchSizeBuffer()
        self.bsz_warmup_init_mbtoken = bsz_warmup_init_mbtoken  # training warmup args
        assert self.bsz_warmup_init_mbtoken >= 0

        self.packer = IdentityPacker(
            token_micro_bsz=token_micro_bsz,
            bsz_warmup_steps=bsz_warmup_steps,
            bsz_warmup_init_mbtoken=bsz_warmup_init_mbtoken,
        )

    def is_full_filled(self) -> bool:
        return len(self.buffer) >= self.buffer_size and self.buffer.total_token_count >= self.token_micro_bsz

    def put_item(self, item: dict[str, Any]):
        if len(item["input_ids"]) == 1:
            print("WARNING: EMPTY STRING.")
            return
        self.buffer.append(item)

    def get_token_num_to_request(self):
        if self.packer is not None:
            warmup = self._step <= self.bsz_warmup_steps and self.bsz_warmup_steps > 0
            return self.packer.get_token_num_to_request(self._step, warmup=warmup)
        else:
            return self.get_cur_token_micro_bsz()

    def get_cur_token_micro_bsz(self):
        warmup = self._step <= self.bsz_warmup_steps and self.bsz_warmup_steps > 0
        if warmup:
            return (
                self.token_micro_bsz - self.bsz_warmup_init_mbtoken
            ) * self._step // self.bsz_warmup_steps + self.bsz_warmup_init_mbtoken
        else:
            return self.token_micro_bsz

    def get_micro_batch(self, step) -> list[dict[str, Any]]:
        """Get a micro batch from the buffer according to the current step.

        Args:
            step: the current step.

        Returns:
            data: a list of samples.
        """
        self._step = step
        n_token_per_iter = self.get_token_num_to_request()
        cur_token_micro_bsz = self.get_cur_token_micro_bsz()
        assert cur_token_micro_bsz % n_token_per_iter == 0, (
            "The per-step token budget should be divisible by the token num requested per iteration."
        )
        n_iter = int(cur_token_micro_bsz // n_token_per_iter)
        data = []
        for _ in range(n_iter):
            samples = self.buffer.get_samples(n_token_per_iter)
            if self.packer:
                samples = self.packer(samples)  # maybe packed into one sample, but wrapped in a list.
            data.extend(samples)
        self.buffer.flush()  # remove the selected samples.
        return data

    def empty(self) -> bool:
        return len(self.buffer) == 0
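A minimal usage sketch of the queue above. With the default `bsz_warmup_steps=-1` the per-step budget is simply `token_micro_bsz`; when warmup is enabled the budget ramps linearly, e.g. with `token_micro_bsz=1000`, `bsz_warmup_init_mbtoken=200` and `bsz_warmup_steps=100`, step 50 requests `(1000 - 200) * 50 // 100 + 200 = 600` tokens. The tensors below are illustrative assumptions about the expected sample layout:

```python
# Assumed import path, derived from src/llamafactory/v1/utils/batching_queue.py.
import torch

from llamafactory.v1.utils.batching_queue import TextBatchingQueue

queue = TextBatchingQueue(token_micro_bsz=8, buffer_size=4)

for seq_len in (3, 5, 2, 4):  # 14 tokens across 4 samples
    queue.put_item(
        {
            "input_ids": torch.arange(seq_len),
            "attention_mask": torch.ones(seq_len, dtype=torch.long),
        }
    )

assert queue.is_full_filled()  # at least 4 samples buffered and at least 8 tokens available

micro_batch = queue.get_micro_batch(step=0)
# The greedy 8-token budget picks the 3-token and 5-token samples; the rest stay buffered.
assert [len(s["input_ids"]) for s in micro_batch] == [3, 5]
assert len(queue.buffer) == 2 and queue.buffer.total_token_count == 6
```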
@@ -16,7 +16,6 @@
# limitations under the License.

from contextlib import contextmanager
from typing import Union

import torch
from transformers.utils import is_torch_bf16_available_on_device, is_torch_fp16_available_on_device

@@ -38,7 +37,7 @@ class DtypeInterface:
    _is_fp32_available = True

    @staticmethod
    def is_available(precision: Union[str, torch.dtype]) -> bool:
    def is_available(precision: str | torch.dtype) -> bool:
        if precision in DtypeRegistry.HALF_LIST:
            return DtypeInterface._is_fp16_available
        elif precision in DtypeRegistry.FLOAT_LIST:

@@ -49,19 +48,19 @@ class DtypeInterface:
        raise RuntimeError(f"Unexpected precision: {precision}")

    @staticmethod
    def is_fp16(precision: Union[str, torch.dtype]) -> bool:
    def is_fp16(precision: str | torch.dtype) -> bool:
        return precision in DtypeRegistry.HALF_LIST

    @staticmethod
    def is_fp32(precision: Union[str, torch.dtype]) -> bool:
    def is_fp32(precision: str | torch.dtype) -> bool:
        return precision in DtypeRegistry.FLOAT_LIST

    @staticmethod
    def is_bf16(precision: Union[str, torch.dtype]) -> bool:
    def is_bf16(precision: str | torch.dtype) -> bool:
        return precision in DtypeRegistry.BFLOAT_LIST

    @staticmethod
    def to_dtype(precision: Union[str, torch.dtype]) -> torch.dtype:
    def to_dtype(precision: str | torch.dtype) -> torch.dtype:
        if precision in DtypeRegistry.HALF_LIST:
            return torch.float16
        elif precision in DtypeRegistry.FLOAT_LIST:

@@ -83,7 +82,7 @@ class DtypeInterface:
        raise RuntimeError(f"Unexpected precision: {precision}")

    @contextmanager
    def set_dtype(self, precision: Union[str, torch.dtype]):
    def set_dtype(self, precision: str | torch.dtype):
        original_dtype = torch.get_default_dtype()
        torch.set_default_dtype(self.to_dtype(precision))
        try:

@@ -17,7 +17,7 @@ import socket


def find_available_port() -> int:
    r"""Find an available port on the local machine."""
    """Find an available port on the local machine."""
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    sock.bind(("", 0))
    port = sock.getsockname()[1]

@@ -26,9 +26,5 @@ def find_available_port() -> int:


def is_env_enabled(env_var: str, default: str = "0") -> bool:
    r"""Check if the environment variable is enabled."""
    return os.getenv(env_var, default).lower() in ["true", "y", "1"]


if __name__ == "__main__":
    print(find_available_port())
    """Check if the environment variable is enabled."""
    return os.getenv(env_var, default).lower() in ["true", "yes", "on", "t", "y", "1"]