6 Commits

Author SHA1 Message Date
Copilot
eceec8ab69 [deps] goodbye python 3.9 (#9677)
Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: hiyouga <16256802+hiyouga@users.noreply.github.com>
Co-authored-by: hiyouga <hiyouga@buaa.edu.cn>
2025-12-27 02:50:44 +08:00
Yaowei Zheng
b44f651e09 [ci] fix docker (#9678) 2025-12-27 02:43:46 +08:00
Yaowei Zheng
55590f5ece [misc] fix ci with uv (#9676) 2025-12-27 01:39:13 +08:00
Copilot
a1b1931b4a [breaking] migrate from setuptools to uv (#9673)
Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: hiyouga <16256802+hiyouga@users.noreply.github.com>
2025-12-26 22:47:23 +08:00
Xunpeng Xiao
3c17f2722c [model] Update ernie_vl to adapt new version (#9665) 2025-12-26 19:57:49 +08:00
Copilot
a882e2d5fc [assets] Add GitHub Copilot instructions for repository (#9675)
Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: hiyouga <16256802+hiyouga@users.noreply.github.com>
2025-12-26 17:32:48 +08:00
88 changed files with 719 additions and 626 deletions

.github/copilot-instructions.md (new file, 180 lines)
View File

@@ -0,0 +1,180 @@
# GitHub Copilot Instructions for LLaMA Factory
## Project Overview
LLaMA Factory is an efficient fine-tuning framework for 100+ large language models (LLMs). It provides:
- Support for various models: LLaMA, LLaVA, Mistral, Qwen, DeepSeek, Yi, Gemma, ChatGLM, Phi, etc.
- Multiple training methods: pre-training, supervised fine-tuning, reward modeling, PPO, DPO, KTO, ORPO
- Scalable resources: 16-bit full-tuning, freeze-tuning, LoRA and QLoRA variants
- Advanced algorithms: GaLore, BAdam, APOLLO, Adam-mini, Muon, OFT, DoRA, etc.
- Web UI (LLaMA Board) and CLI interfaces
### Architecture Versions
LLaMA Factory has two parallel architectures that can be switched via the `USE_V1` environment variable:
**v0 (default)** - File hierarchy:
- `api`, `webui` → `chat`, `eval`, `train` → `data`, `model` → `hparams` → `extras`
**v1** - File hierarchy:
- `trainers` → `core` → `accelerator`, `plugins`, `config` → `utils`
Set `USE_V1=1` to enable the v1 architecture.
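For example, a minimal sketch of how code can branch on this variable (illustrative only, not the repository's actual dispatch logic):

```python
import os

# USE_V1 switches between the two package layouts described above.
if os.getenv("USE_V1", "0") == "1":
    # v1 layout lives under src/llamafactory/v1/
    print("v1 architecture selected")
else:
    # default v0 layout lives under src/llamafactory/
    print("v0 architecture selected")
```

In a shell this usually means prefixing the command, e.g. `USE_V1=1 llamafactory-cli train ...`.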
## Code Structure
### v0 Architecture (Default)
- `src/llamafactory/` - Main package directory
- `api/` - OpenAI-style API implementation
- `chat/` - Chat interface implementation
- `cli.py` - Command-line interface
- `data/` - Data processing and dataset handling
- `eval/` - Model evaluation utilities
- `extras/` - Additional utilities and helpers
- `hparams/` - Hyperparameter definitions
- `model/` - Model loading, patching, and utilities
- `train/` - Training pipeline implementation
- `webui/` - Gradio-based web interface
- `src/train.py` - Training entry script (delegates to `llamafactory.train.tuner`)
- `src/webui.py` - Web UI entry script (delegates to `llamafactory.webui.interface`)
- `src/api.py` - API server entry script (delegates to `llamafactory.api.app`)
- `tests/` - Test suite
- `examples/` - Example configurations for various training scenarios
- `data/` - Dataset definitions and examples
### v1 Architecture (USE_V1=1)
- `src/llamafactory/v1/` - Version 1 package directory
- `trainers/` - Training implementations
- `core/` - Core training utilities
- `accelerator/` - Acceleration and distributed training
- `plugins/` - Pluggable components (model, data, sampler, trainer)
- `config/` - Configuration management
- `utils/` - Utility functions
## Development Practices
### Code Style
- Follow the [Google Python Style Guide](https://google.github.io/styleguide/pyguide.html)
- Use ruff for linting and formatting
- Line length: 119 characters
- Indentation: 4 spaces
- Quote style: double quotes
- Use Google-style docstrings for documentation
### Import Organization
- Known first-party: `llamafactory`
- Known third-party: `accelerate`, `datasets`, `gradio`, `numpy`, `peft`, `torch`, `transformers`, `trl`
- Use 2 blank lines after imports
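A short illustrative snippet of that grouping (the specific imports are chosen only to demonstrate the ordering):

```python
# Standard library imports come first.
from dataclasses import dataclass

# Known third-party packages (accelerate, datasets, gradio, numpy, peft, torch, transformers, trl).
import torch

# Known first-party package comes last.
from llamafactory.extras.constants import IGNORE_INDEX


# Two blank lines separate the import block from the first top-level definition.
@dataclass
class ExampleBatchConfig:
    """Toy config used only to illustrate the import layout above."""

    ignore_index: int = IGNORE_INDEX
    pin_memory: bool = torch.cuda.is_available()
```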
### Quality Checks
Before committing code, run:
```bash
make style # Auto-fix style issues
make quality # Check code quality
make test # Run test suite
```
Or use the combined command:
```bash
make commit # Run pre-commit hooks
```
### Testing
- Use pytest for testing
- Tests are located in `tests/` and `tests_v1/` directories
- Run tests with: `make test` (which runs `WANDB_DISABLED=true pytest -vv --import-mode=importlib tests/ tests_v1/`)
- Disable wandb during testing to avoid external dependencies
- **Note**: Training configurations require GPU machines, so training is typically not tested end-to-end. Use `make test` to validate file-level functionality.
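A trivial, hypothetical test in this style (shown only to illustrate the conventions; it is not an existing test file):

```python
# License header omitted here for brevity; real files must include it (see the License section below).
import pytest

from llamafactory.extras.constants import IGNORE_INDEX


@pytest.mark.parametrize("value", [IGNORE_INDEX])
def test_ignore_index_is_negative(value: int) -> None:
    # IGNORE_INDEX marks label positions excluded from the loss, so it must not be a valid token id.
    assert value < 0
```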
### Building
Build the package with:
```bash
pip3 install build && python3 -m build
```
### License
- All source files must include the Apache 2.0 license header
- Check license headers with: `make license`
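The standard header, as it appears in the repository's Python files (for example in the `setup.py` removed later in this changeset), is:

```python
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
```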
## Common Patterns
### Configuration Files
- Training configurations are typically YAML or JSON files in the `examples/` directory
- Hyperparameters are defined using dataclasses in `src/llamafactory/hparams/`
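A minimal, hypothetical sketch of that pattern (the field names here are invented for illustration and are not the actual classes under `src/llamafactory/hparams/`):

```python
from dataclasses import dataclass, field


@dataclass
class ExampleFinetuningArguments:
    """Hypothetical hyperparameter group in the dataclass style described above."""

    finetuning_type: str = field(
        default="lora",
        metadata={"help": "Which tuning method to apply: full, freeze, or lora."},
    )
    lora_rank: int = field(
        default=8,
        metadata={"help": "Rank of the LoRA adapters."},
    )
    learning_rate: float = field(
        default=1.0e-4,
        metadata={"help": "Initial learning rate."},
    )
```

Groups like this are typically parsed from the YAML/JSON configuration files mentioned above via `transformers.HfArgumentParser`-style tooling.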
### Model Support
- New model support is added through model patches in `src/llamafactory/model/`
- Visual models use the visual utilities in `src/llamafactory/model/model_utils/visual.py`
- Quantization support is in `src/llamafactory/model/model_utils/quantization.py`
### Data Processing
- Dataset definitions are in `data/dataset_info.json`
- Data templates and processors are in `src/llamafactory/data/`
### Training
- Training pipelines are in `src/llamafactory/train/`
- Support for different training methods: SFT, DPO, PPO, RM, PT, KTO, ORPO
## Key Dependencies
- Python >= 3.11
- PyTorch and transformers for model handling
- datasets for data processing
- peft for parameter-efficient fine-tuning
- accelerate for distributed training
- gradio for web UI
- trl for reinforcement learning
- Optional: vllm/sglang for inference, flash-attention-2, unsloth, liger-kernel
## Entry Points
- **CLI Training**: `llamafactory-cli train --config examples/train_lora/llama3_lora_sft.yaml`
- **Web UI**: `llamafactory-cli webui` or `python src/webui.py`
- **API Server**: `llamafactory-cli api` or `python src/api.py`
- **Chat Interface**: `llamafactory-cli chat --model_name_or_path MODEL_PATH`
## Environment Setup
For development:
```bash
pip install -e ".[dev]"
```
## Important Notes
- The project supports multiple backends: default PyTorch, vLLM, SGLang
- Megatron-core training is supported via mcore_adapter
- SwanLab and W&B are supported for experiment tracking
- Docker support is available with pre-built images
- Day-0/Day-1 support for the latest cutting-edge models
- Multi-modal support for vision and audio understanding tasks
## Contribution Guidelines
1. Fork the repository
2. Create a development branch
3. Set up development environment with `pip install -e ".[dev]"`
4. Make changes following the style guide
5. Run quality checks: `make style && make quality`
6. Run tests: `make test`
7. Submit a pull request
## Common Commands
- `make style` - Format code
- `make quality` - Run linters
- `make test` - Run tests
- `make commit` - Install and run pre-commit hooks
- `make license` - Check license headers

View File

@@ -7,7 +7,7 @@ on:
- "main" - "main"
paths: paths:
- "**/*.py" - "**/*.py"
- "requirements.txt" - "pyproject.toml"
- "docker/**" - "docker/**"
- ".github/workflows/*.yml" - ".github/workflows/*.yml"
pull_request: pull_request:
@@ -15,7 +15,7 @@ on:
- "main" - "main"
paths: paths:
- "**/*.py" - "**/*.py"
- "requirements.txt" - "pyproject.toml"
- "docker/**" - "docker/**"
- ".github/workflows/*.yml" - ".github/workflows/*.yml"
release: release:
@@ -29,16 +29,13 @@ jobs:
matrix: matrix:
include: include:
- device: "cuda" - device: "cuda"
npu_type: "" - device: "npu-a2"
- device: "npu" - device: "npu-a3"
npu_type: "a2"
- device: "npu"
npu_type: "a3"
runs-on: ubuntu-latest runs-on: ubuntu-latest
concurrency: concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ matrix.device }}-${{ matrix.npu_type }} group: ${{ github.workflow }}-${{ github.ref }}-${{ matrix.device }}
cancel-in-progress: ${{ github.ref != 'refs/heads/main' }} cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
environment: environment:
@@ -55,16 +52,11 @@ jobs:
- name: Checkout - name: Checkout
uses: actions/checkout@v4 uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: "3.10"
- name: Get llamafactory version - name: Get llamafactory version
id: version id: version
run: | run: |
if [ "${{ github.event_name }}" = "release" ]; then if [ "${{ github.event_name }}" = "release" ]; then
echo "tag=$(python setup.py --version)" >> "$GITHUB_OUTPUT" echo "tag=$(grep -oP 'VERSION = "\K[^"]+' src/llamafactory/extras/env.py)" >> "$GITHUB_OUTPUT"
else else
echo "tag=latest" >> "$GITHUB_OUTPUT" echo "tag=latest" >> "$GITHUB_OUTPUT"
fi fi
@@ -93,14 +85,12 @@ jobs:
with: with:
context: . context: .
file: ./docker/docker-cuda/Dockerfile file: ./docker/docker-cuda/Dockerfile
build-args: |
EXTRAS=metrics,deepspeed,liger-kernel
push: ${{ github.event_name != 'pull_request' }} push: ${{ github.event_name != 'pull_request' }}
tags: | tags: |
docker.io/hiyouga/llamafactory:${{ steps.version.outputs.tag }} docker.io/hiyouga/llamafactory:${{ steps.version.outputs.tag }}
- name: Build and push Docker image (NPU-A2) - name: Build and push Docker image (NPU-A2)
if: ${{ matrix.device == 'npu' && matrix.npu_type == 'a2' }} if: ${{ matrix.device == 'npu-a2' }}
uses: docker/build-push-action@v6 uses: docker/build-push-action@v6
with: with:
context: . context: .
@@ -112,7 +102,7 @@ jobs:
quay.io/ascend/llamafactory:${{ steps.version.outputs.tag }}-npu-a2 quay.io/ascend/llamafactory:${{ steps.version.outputs.tag }}-npu-a2
- name: Build and push Docker image (NPU-A3) - name: Build and push Docker image (NPU-A3)
if: ${{ matrix.device == 'npu' && matrix.npu_type == 'a3' }} if: ${{ matrix.device == 'npu-a3' }}
uses: docker/build-push-action@v6 uses: docker/build-push-action@v6
with: with:
context: . context: .

View File

@@ -23,10 +23,11 @@ jobs:
- name: Checkout - name: Checkout
uses: actions/checkout@v4 uses: actions/checkout@v4
- name: Set up Python - name: Install uv
uses: actions/setup-python@v5 uses: astral-sh/setup-uv@v7
with: with:
python-version: "3.9" python-version: "3.11"
github-token: ${{ github.token }}
- name: Build package - name: Build package
run: | run: |

View File

@@ -7,7 +7,7 @@ on:
- "main" - "main"
paths: paths:
- "**/*.py" - "**/*.py"
- "requirements.txt" - "pyproject.toml"
- "Makefile" - "Makefile"
- ".github/workflows/*.yml" - ".github/workflows/*.yml"
pull_request: pull_request:
@@ -15,7 +15,7 @@ on:
- "main" - "main"
paths: paths:
- "**/*.py" - "**/*.py"
- "requirements.txt" - "pyproject.toml"
- "Makefile" - "Makefile"
- ".github/workflows/*.yml" - ".github/workflows/*.yml"
@@ -25,10 +25,9 @@ jobs:
fail-fast: false fail-fast: false
matrix: matrix:
python: python:
- "3.9"
- "3.10"
- "3.11" - "3.11"
- "3.12" - "3.12"
# - "3.13" # enable after trl is upgraded
os: os:
- "ubuntu-latest" - "ubuntu-latest"
- "windows-latest" - "windows-latest"
@@ -36,18 +35,15 @@ jobs:
transformers: transformers:
- null - null
include: # test backward compatibility include: # test backward compatibility
- python: "3.9" - python: "3.11"
os: "ubuntu-latest" os: "ubuntu-latest"
transformers: "4.49.0" transformers: "4.49.0"
- python: "3.9" - python: "3.11"
os: "ubuntu-latest" os: "ubuntu-latest"
transformers: "4.51.0" transformers: "4.51.0"
- python: "3.9" - python: "3.11"
os: "ubuntu-latest" os: "ubuntu-latest"
transformers: "4.53.0" transformers: "4.53.0"
exclude: # exclude python 3.9 on macos
- python: "3.9"
os: "macos-latest"
runs-on: ${{ matrix.os }} runs-on: ${{ matrix.os }}
@@ -63,21 +59,23 @@ jobs:
- name: Checkout - name: Checkout
uses: actions/checkout@v4 uses: actions/checkout@v4
- name: Set up Python - name: Install uv
uses: actions/setup-python@v5 uses: astral-sh/setup-uv@v7
with: with:
python-version: ${{ matrix.python }} python-version: ${{ matrix.python }}
github-token: ${{ github.token }}
enable-cache: false
- name: Install dependencies - name: Install dependencies
run: | run: |
python -m pip install --upgrade pip uv venv
python -m pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu uv pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
python -m pip install ".[dev]" uv pip install -e ".[dev]"
- name: Install transformers - name: Install transformers
if: ${{ matrix.transformers }} if: ${{ matrix.transformers }}
run: | run: |
python -m pip install "transformers==${{ matrix.transformers }}" uv pip install "transformers==${{ matrix.transformers }}"
- name: Cache files - name: Cache files
id: hf-hub-cache id: hf-hub-cache
@@ -89,18 +87,25 @@ jobs:
- name: Check quality - name: Check quality
run: | run: |
make style && make quality make style && make quality
env:
UV_NO_SYNC: 1
- name: Check license - name: Check license
run: | run: |
make license make license
env:
UV_NO_SYNC: 1
- name: Check build - name: Check build
run: | run: |
make build make build
env:
UV_NO_SYNC: 1
- name: Test with pytest - name: Test with pytest
run: | run: |
make test make test
env: env:
UV_NO_SYNC: 1
HF_HOME: ${{ runner.temp }}/huggingface HF_HOME: ${{ runner.temp }}/huggingface
HF_HUB_OFFLINE: "${{ steps.hf-hub-cache.outputs.cache-hit == 'true' && '1' || '0' }}" HF_HUB_OFFLINE: "${{ steps.hf-hub-cache.outputs.cache-hit == 'true' && '1' || '0' }}"

View File

@@ -7,7 +7,7 @@ on:
- "main" - "main"
paths: paths:
- "**/*.py" - "**/*.py"
- "requirements.txt" - "pyproject.toml"
- "Makefile" - "Makefile"
- ".github/workflows/*.yml" - ".github/workflows/*.yml"
pull_request: pull_request:
@@ -15,7 +15,7 @@ on:
- "main" - "main"
paths: paths:
- "**/*.py" - "**/*.py"
- "requirements.txt" - "pyproject.toml"
- "Makefile" - "Makefile"
- ".github/workflows/*.yml" - ".github/workflows/*.yml"
@@ -48,10 +48,15 @@ jobs:
- name: Checkout - name: Checkout
uses: actions/checkout@v4 uses: actions/checkout@v4
- name: Install uv
run: |
curl -LsSf https://astral.sh/uv/install.sh | sh
- name: Install dependencies - name: Install dependencies
run: | run: |
python -m pip install --upgrade pip uv venv
python -m pip install ".[torch-npu,dev]" torch-npu==${{matrix.pytorch_npu}} uv pip install torch-npu==${{matrix.pytorch_npu}}
uv pip install -e ".[dev]"
- name: Install node - name: Install node
run: | run: |
@@ -70,18 +75,25 @@ jobs:
- name: Check quality - name: Check quality
run: | run: |
make style && make quality make style && make quality
env:
UV_NO_SYNC: 1
- name: Check license - name: Check license
run: | run: |
make license make license
env:
UV_NO_SYNC: 1
- name: Check build - name: Check build
run: | run: |
make build make build
env:
UV_NO_SYNC: 1
- name: Test with pytest - name: Test with pytest
run: | run: |
make test make test
env: env:
UV_NO_SYNC: 1
HF_HOME: /root/.cache/huggingface HF_HOME: /root/.cache/huggingface
HF_HUB_OFFLINE: "${{ steps.hf-hub-cache.outputs.cache-hit == 'true' && '1' || '0' }}" HF_HUB_OFFLINE: "${{ steps.hf-hub-cache.outputs.cache-hit == 'true' && '1' || '0' }}"

View File

@@ -1 +1 @@
include LICENSE requirements.txt include LICENSE

View File

@@ -1,24 +1,28 @@
.PHONY: build commit license quality style test .PHONY: build commit license quality style test
check_dirs := scripts src tests tests_v1 setup.py check_dirs := scripts src tests tests_v1
RUN := $(shell command -v uv >/dev/null 2>&1 && echo "uv run" || echo "")
BUILD := $(shell command -v uv >/dev/null 2>&1 && echo "uv build" || echo "python -m build")
TOOL := $(shell command -v uv >/dev/null 2>&1 && echo "uvx" || echo "")
build: build:
pip3 install build && python3 -m build $(BUILD)
commit: commit:
pre-commit install $(TOOL) pre-commit install
pre-commit run --all-files $(TOOL) pre-commit run --all-files
license: license:
python3 tests/check_license.py $(check_dirs) $(RUN) python3 tests/check_license.py $(check_dirs)
quality: quality:
ruff check $(check_dirs) $(TOOL) ruff check $(check_dirs)
ruff format --check $(check_dirs) $(TOOL) ruff format --check $(check_dirs)
style: style:
ruff check $(check_dirs) --fix $(TOOL) ruff check $(check_dirs) --fix
ruff format $(check_dirs) $(TOOL) ruff format $(check_dirs)
test: test:
WANDB_DISABLED=true pytest -vv --import-mode=importlib tests/ tests_v1/ WANDB_DISABLED=true $(RUN) pytest -vv --import-mode=importlib tests/ tests_v1/

View File

@@ -514,10 +514,12 @@ huggingface-cli login
```bash ```bash
git clone --depth 1 https://github.com/hiyouga/LLaMA-Factory.git git clone --depth 1 https://github.com/hiyouga/LLaMA-Factory.git
cd LLaMA-Factory cd LLaMA-Factory
pip install -e ".[torch,metrics]" --no-build-isolation pip install -e ".[metrics]" --no-build-isolation
``` ```
Extra dependencies available: torch, torch-npu, metrics, deepspeed, liger-kernel, bitsandbytes, hqq, eetq, gptq, aqlm, vllm, sglang, galore, apollo, badam, adam-mini, qwen, minicpm_v, openmind, swanlab, dev Optional dependencies available: `metrics`, `deepspeed`. Install with: `pip install -e ".[metrics,deepspeed]"`
Additional dependencies for specific features are available in `examples/requirements/`.
#### Install from Docker Image #### Install from Docker Image
@@ -536,13 +538,7 @@ Please refer to [build docker](#build-docker) to build the image yourself.
Create an isolated Python environment with [uv](https://github.com/astral-sh/uv): Create an isolated Python environment with [uv](https://github.com/astral-sh/uv):
```bash ```bash
uv sync --extra torch --extra metrics --prerelease=allow uv run llamafactory-cli webui
```
Run LLaMA-Factory in the isolated environment:
```bash
uv run --prerelease=allow llamafactory-cli train examples/train_lora/llama3_lora_pretrain.yaml
``` ```
</details> </details>
@@ -579,7 +575,7 @@ To enable FlashAttention-2 on the Windows platform, please use the script from [
<details><summary>For Ascend NPU users</summary> <details><summary>For Ascend NPU users</summary>
To install LLaMA Factory on Ascend NPU devices, please upgrade Python to version 3.10 or higher and specify extra dependencies: `pip install -e ".[torch-npu,metrics]"`. Additionally, you need to install the **[Ascend CANN Toolkit and Kernels](https://www.hiascend.com/developer/download/community/result?module=cann)**. Please follow the [installation tutorial](https://www.hiascend.com/document/detail/en/CANNCommunityEdition/600alphaX/softwareinstall/instg/atlasdeploy_03_0031.html) or use the following commands: To install LLaMA Factory on Ascend NPU devices, please upgrade Python to version 3.10 or higher: `pip install -e . torch-npu==2.7.1`. Additionally, you need to install the **[Ascend CANN Toolkit and Kernels](https://www.hiascend.com/developer/download/community/result?module=cann)**. Please follow the [installation tutorial](https://www.hiascend.com/document/detail/en/CANNCommunityEdition/600alphaX/softwareinstall/instg/atlasdeploy_03_0031.html) or use the following commands:
```bash ```bash
# replace the url according to your CANN version and devices # replace the url according to your CANN version and devices
@@ -598,8 +594,8 @@ source /usr/local/Ascend/ascend-toolkit/set_env.sh
| Requirement | Minimum | Recommend | | Requirement | Minimum | Recommend |
| ------------ | ------- | -------------- | | ------------ | ------- | -------------- |
| CANN | 8.0.RC1 | 8.0.0.alpha002 | | CANN | 8.0.RC1 | 8.0.0.alpha002 |
| torch | 2.1.0 | 2.4.0 | | torch | 2.1.0 | 2.7.1 |
| torch-npu | 2.1.0 | 2.4.0.post2 | | torch-npu | 2.1.0 | 2.7.1 |
| deepspeed | 0.13.2 | 0.13.2 | | deepspeed | 0.13.2 | 0.13.2 |
| vllm-ascend | - | 0.7.3 | | vllm-ascend | - | 0.7.3 |
@@ -714,7 +710,6 @@ For CUDA users:
```bash ```bash
docker build -f ./docker/docker-cuda/Dockerfile \ docker build -f ./docker/docker-cuda/Dockerfile \
--build-arg PIP_INDEX=https://pypi.org/simple \ --build-arg PIP_INDEX=https://pypi.org/simple \
--build-arg EXTRAS=metrics \
-t llamafactory:latest . -t llamafactory:latest .
docker run -dit --ipc=host --gpus=all \ docker run -dit --ipc=host --gpus=all \
@@ -731,7 +726,6 @@ For Ascend NPU users:
```bash ```bash
docker build -f ./docker/docker-npu/Dockerfile \ docker build -f ./docker/docker-npu/Dockerfile \
--build-arg PIP_INDEX=https://pypi.org/simple \ --build-arg PIP_INDEX=https://pypi.org/simple \
--build-arg EXTRAS=torch-npu,metrics \
-t llamafactory:latest . -t llamafactory:latest .
docker run -dit --ipc=host \ docker run -dit --ipc=host \
@@ -756,7 +750,6 @@ For AMD ROCm users:
```bash ```bash
docker build -f ./docker/docker-rocm/Dockerfile \ docker build -f ./docker/docker-rocm/Dockerfile \
--build-arg PIP_INDEX=https://pypi.org/simple \ --build-arg PIP_INDEX=https://pypi.org/simple \
--build-arg EXTRAS=metrics \
-t llamafactory:latest . -t llamafactory:latest .
docker run -dit --ipc=host \ docker run -dit --ipc=host \

View File

@@ -516,10 +516,12 @@ huggingface-cli login
```bash ```bash
git clone --depth 1 https://github.com/hiyouga/LLaMA-Factory.git git clone --depth 1 https://github.com/hiyouga/LLaMA-Factory.git
cd LLaMA-Factory cd LLaMA-Factory
pip install -e ".[torch,metrics]" --no-build-isolation pip install -e ".[metrics]" --no-build-isolation
``` ```
可选的额外依赖项:torch、torch-npu、metricsdeepspeed、liger-kernel、bitsandbytes、hqq、eetq、gptq、aqlm、vllm、sglang、galore、apollo、badam、adam-mini、qwen、minicpm_v、openmind、swanlab、dev 可选的额外依赖项:`metrics``deepspeed`。使用 `pip install -e ".[metrics,deepspeed]"` 安装。
其他可选依赖项请参考 `examples/requirements/` 目录下的文件。
#### 从镜像安装 #### 从镜像安装
@@ -538,13 +540,7 @@ docker run -it --rm --gpus=all --ipc=host hiyouga/llamafactory:latest
使用 [uv](https://github.com/astral-sh/uv) 创建隔离的 Python 环境: 使用 [uv](https://github.com/astral-sh/uv) 创建隔离的 Python 环境:
```bash ```bash
uv sync --extra torch --extra metrics --prerelease=allow uv run llamafactory-cli webui
```
在环境中运行 LLaMA-Factory
```bash
uv run --prerelease=allow llamafactory-cli train examples/train_lora/llama3_lora_pretrain.yaml
``` ```
</details> </details>
@@ -581,7 +577,7 @@ pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/downl
<details><summary>昇腾 NPU 用户指南</summary> <details><summary>昇腾 NPU 用户指南</summary>
在昇腾 NPU 设备上安装 LLaMA Factory 时,请升级 Python 到 3.10 及以上,并需要指定额外依赖项,使用 `pip install -e ".[torch-npu,metrics]"` 命令安装。此外,还需要安装 **[Ascend CANN Toolkit 与 Kernels](https://www.hiascend.com/developer/download/community/result?module=cann)**,安装方法请参考[安装教程](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/80RC2alpha002/quickstart/quickstart/quickstart_18_0004.html)或使用以下命令: 在昇腾 NPU 设备上安装 LLaMA Factory 时,请升级 Python 到 3.10 及以上,并需要指定额外依赖项,使用 `pip install -e . torch-npu==2.7.1` 命令安装。此外,还需要安装 **[Ascend CANN Toolkit 与 Kernels](https://www.hiascend.com/developer/download/community/result?module=cann)**,安装方法请参考[安装教程](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/80RC2alpha002/quickstart/quickstart/quickstart_18_0004.html)或使用以下命令:
```bash ```bash
# 请替换 URL 为 CANN 版本和设备型号对应的 URL # 请替换 URL 为 CANN 版本和设备型号对应的 URL
@@ -600,8 +596,8 @@ source /usr/local/Ascend/ascend-toolkit/set_env.sh
| 依赖项 | 至少 | 推荐 | | 依赖项 | 至少 | 推荐 |
| ------------ | ------- | -------------- | | ------------ | ------- | -------------- |
| CANN | 8.0.RC1 | 8.0.0.alpha002 | | CANN | 8.0.RC1 | 8.0.0.alpha002 |
| torch | 2.1.0 | 2.4.0 | | torch | 2.1.0 | 2.7.1 |
| torch-npu | 2.1.0 | 2.4.0.post2 | | torch-npu | 2.1.0 | 2.7.1 |
| deepspeed | 0.13.2 | 0.13.2 | | deepspeed | 0.13.2 | 0.13.2 |
| vllm-ascend | - | 0.7.3 | | vllm-ascend | - | 0.7.3 |

View File

@@ -4,7 +4,6 @@ FROM ${BASE_IMAGE}
# Installation arguments # Installation arguments
ARG PIP_INDEX=https://pypi.org/simple ARG PIP_INDEX=https://pypi.org/simple
ARG EXTRAS=metrics
ARG INSTALL_FLASHATTN=false ARG INSTALL_FLASHATTN=false
ARG HTTP_PROXY="" ARG HTTP_PROXY=""
@@ -27,17 +26,13 @@ WORKDIR /app
# Change pip source # Change pip source
RUN pip config set global.index-url "${PIP_INDEX}" && \ RUN pip config set global.index-url "${PIP_INDEX}" && \
pip config set global.extra-index-url "${PIP_INDEX}" && \ pip config set global.extra-index-url "${PIP_INDEX}" && \
pip install --no-cache-dir --upgrade pip packaging wheel setuptools pip install --no-cache-dir --upgrade pip packaging wheel setuptools editables "hatchling>=1.18.0"
# Install the requirements # Copy the application into the image
COPY requirements.txt /app
RUN pip install --no-cache-dir -r requirements.txt
# Copy the rest of the application into the image
COPY . /app COPY . /app
# Install LLaMA Factory # Install LLaMA Factory
RUN pip install --no-cache-dir -e ".[${EXTRAS}]" --no-build-isolation RUN pip install --no-cache-dir --no-build-isolation -e ".[metrics,deepspeed]"
# Rebuild flash attention # Rebuild flash attention
RUN if [ "${INSTALL_FLASHATTN}" == "true" ]; then \ RUN if [ "${INSTALL_FLASHATTN}" == "true" ]; then \

View File

@@ -8,7 +8,7 @@ ENV PYPI_MIRROR=https://mirrors.aliyun.com/pypi/simple/
ENV PYPI_TRUSTED_HOST=mirrors.aliyun.com ENV PYPI_TRUSTED_HOST=mirrors.aliyun.com
ENV APT_MIRROR=https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ ENV APT_MIRROR=https://mirrors.tuna.tsinghua.edu.cn/ubuntu/
RUN pip install --upgrade pip setuptools wheel --trusted-host ${PYPI_TRUSTED_HOST} --index-url ${PYPI_MIRROR} RUN pip install --upgrade pip setuptools wheel "hatchling>=1.18.0" editables --trusted-host ${PYPI_TRUSTED_HOST} --index-url ${PYPI_MIRROR}
RUN pip uninstall -y torch torchvision torch-tensorrt \ RUN pip uninstall -y torch torchvision torch-tensorrt \
flash_attn transformer-engine \ flash_attn transformer-engine \
@@ -56,14 +56,14 @@ ENV JAVA_HOME /usr/lib/jvm/java-21-openjdk-amd64
# pip install LLaMA-Factory # pip install LLaMA-Factory
WORKDIR /app WORKDIR /app
COPY requirements.txt /app/ # Copy the application into the image
RUN pip install --no-cache-dir -r requirements.txt COPY . /app
# Install LLaMA Factory
RUN pip install --no-cache-dir -e ".[metrics]" --no-build-isolation
RUN pip install "git+https://github.com/alibaba/roll.git#subdirectory=mcore_adapter" RUN pip install "git+https://github.com/alibaba/roll.git#subdirectory=mcore_adapter"
COPY . /app/
RUN pip install -e ".[metrics]" --no-build-isolation
# Expose port 7860 for LLaMA Board # Expose port 7860 for LLaMA Board
ENV GRADIO_SERVER_PORT=7860 ENV GRADIO_SERVER_PORT=7860
EXPOSE 7860 EXPOSE 7860

View File

@@ -5,7 +5,6 @@ services:
context: ../.. context: ../..
args: args:
PIP_INDEX: https://pypi.org/simple PIP_INDEX: https://pypi.org/simple
EXTRAS: metrics
container_name: llamafactory container_name: llamafactory
ports: ports:
- "7860:7860" - "7860:7860"

View File

@@ -5,7 +5,6 @@ FROM ${BASE_IMAGE}
# Installation arguments # Installation arguments
ARG PIP_INDEX=https://pypi.org/simple ARG PIP_INDEX=https://pypi.org/simple
ARG EXTRAS=torch-npu,metrics
ARG HTTP_PROXY="" ARG HTTP_PROXY=""
ARG PYTORCH_INDEX=https://download.pytorch.org/whl/cpu ARG PYTORCH_INDEX=https://download.pytorch.org/whl/cpu
@@ -28,21 +27,15 @@ WORKDIR /app
# Change pip source # Change pip source
RUN pip config set global.index-url "${PIP_INDEX}" && \ RUN pip config set global.index-url "${PIP_INDEX}" && \
pip config set global.extra-index-url "${PIP_INDEX}" && \ pip config set global.extra-index-url "${PIP_INDEX}" && \
pip install --no-cache-dir --upgrade pip packaging wheel setuptools pip install --no-cache-dir --upgrade pip packaging wheel setuptools editables "hatchling>=1.18.0"
# Copy the application into the image
COPY . /app
# Install torch-npu # Install torch-npu
RUN pip uninstall -y torch torchvision torchaudio && \ RUN pip uninstall -y torch torchvision torchaudio && \
pip install --no-cache-dir "torch==2.7.1" "torch-npu==2.7.1" "torchvision==0.22.1" --index-url "${PYTORCH_INDEX}" pip install --no-cache-dir "torch==2.7.1" "torch-npu==2.7.1" "torchvision==0.22.1" "torchaudio==2.7.1" --index-url "${PYTORCH_INDEX}" && \
pip install --no-cache-dir -e ".[metrics]" --no-build-isolation
# Install the requirements
COPY requirements.txt /app
RUN pip install --no-cache-dir -r requirements.txt
# Copy the rest of the application into the image
COPY . /app
# Install LLaMA Factory
RUN pip install --no-cache-dir -e ".[${EXTRAS}]" --no-build-isolation
# Set up volumes # Set up volumes
# VOLUME [ "/root/.cache/huggingface", "/app/shared_data", "/app/output" ] # VOLUME [ "/root/.cache/huggingface", "/app/shared_data", "/app/output" ]

View File

@@ -5,7 +5,6 @@ services:
context: ../.. context: ../..
args: args:
PIP_INDEX: https://pypi.org/simple PIP_INDEX: https://pypi.org/simple
EXTRAS: torch-npu,metrics
container_name: llamafactory-a2 container_name: llamafactory-a2
image: llamafactory:npu-a2 image: llamafactory:npu-a2
volumes: volumes:
@@ -36,7 +35,6 @@ services:
args: args:
BASE_IMAGE: quay.io/ascend/cann:8.3.rc2-a3-ubuntu22.04-py3.11 BASE_IMAGE: quay.io/ascend/cann:8.3.rc2-a3-ubuntu22.04-py3.11
PIP_INDEX: https://pypi.org/simple PIP_INDEX: https://pypi.org/simple
EXTRAS: torch-npu,metrics
container_name: llamafactory-a3 container_name: llamafactory-a3
image: llamafactory:npu-a3 image: llamafactory:npu-a3
volumes: volumes:

View File

@@ -4,7 +4,6 @@ FROM ${BASE_IMAGE}
# Installation arguments # Installation arguments
ARG PIP_INDEX=https://pypi.org/simple ARG PIP_INDEX=https://pypi.org/simple
ARG EXTRAS=metrics
ARG INSTALL_FLASHATTN=false ARG INSTALL_FLASHATTN=false
ARG HTTP_PROXY="" ARG HTTP_PROXY=""
ARG PYTORCH_INDEX=https://download.pytorch.org/whl/rocm6.3 ARG PYTORCH_INDEX=https://download.pytorch.org/whl/rocm6.3
@@ -28,21 +27,14 @@ WORKDIR /app
# Change pip source # Change pip source
RUN pip config set global.index-url "${PIP_INDEX}" && \ RUN pip config set global.index-url "${PIP_INDEX}" && \
pip config set global.extra-index-url "${PIP_INDEX}" && \ pip config set global.extra-index-url "${PIP_INDEX}" && \
pip install --no-cache-dir --upgrade pip packaging wheel setuptools pip install --no-cache-dir --upgrade pip packaging wheel setuptools editables "hatchling>=1.18.0"
# Reinstall pytorch rocm # Copy the application into the image
RUN pip uninstall -y torch torchvision torchaudio && \
pip install --no-cache-dir --pre torch torchvision torchaudio --index-url "${PYTORCH_INDEX}"
# Install the requirements
COPY requirements.txt /app
RUN pip install --no-cache-dir -r requirements.txt
# Copy the rest of the application into the image
COPY . /app COPY . /app
# Install LLaMA Factory # Reinstall pytorch rocm and install LLaMA Factory
RUN pip install --no-cache-dir -e ".[${EXTRAS}]" --no-build-isolation RUN pip uninstall -y torch torchvision torchaudio && \
pip install --no-cache-dir --no-build-isolation -e --pre ".[metrics,deepspeed]" --index-url "${PYTORCH_INDEX}"
# Rebuild flash attention # Rebuild flash attention
RUN if [ "${INSTALL_FLASHATTN}" == "true" ]; then \ RUN if [ "${INSTALL_FLASHATTN}" == "true" ]; then \

View File

@@ -5,7 +5,6 @@ services:
context: ../.. context: ../..
args: args:
PIP_INDEX: https://pypi.org/simple PIP_INDEX: https://pypi.org/simple
EXTRAS: metrics
container_name: llamafactory container_name: llamafactory
ports: ports:
- "7860:7860" - "7860:7860"

View File

@@ -0,0 +1 @@
adam-mini

View File

@@ -0,0 +1 @@
apollo-torch

View File

@@ -0,0 +1 @@
aqlm[gpu]>=1.1.0

View File

@@ -0,0 +1 @@
badam>=1.2.1

View File

@@ -0,0 +1 @@
bitsandbytes>=0.39.0

View File

@@ -0,0 +1 @@
eetq

View File

@@ -0,0 +1,2 @@
transformer_engine[pytorch]>=2.0.0
accelerate>=1.10.0

View File

@@ -0,0 +1,2 @@
torchao>=0.8.0
accelerate>=1.10.0

View File

@@ -0,0 +1 @@
galore-torch

View File

@@ -0,0 +1,2 @@
optimum>=1.24.0
gptqmodel>=2.0.0

View File

@@ -0,0 +1 @@
hqq

View File

@@ -0,0 +1 @@
liger-kernel>=0.5.5

View File

@@ -0,0 +1,8 @@
soundfile
torchvision
torchaudio
vector_quantize_pytorch
vocos
msgpack
referencing
jsonschema_specifications

View File

@@ -0,0 +1 @@
openmind

View File

@@ -0,0 +1,2 @@
sglang[srt]>=0.4.5
transformers==4.51.1

View File

@@ -0,0 +1 @@
swanlab

View File

@@ -0,0 +1 @@
vllm>=0.4.3,<=0.11.0

View File

@@ -1,25 +1,104 @@
[build-system] [build-system]
requires = ["setuptools>=61.0"] requires = ["hatchling"]
build-backend = "setuptools.build_meta" build-backend = "hatchling.build"
[project] [project]
name = "llamafactory" name = "llamafactory"
requires-python = ">=3.9.0" dynamic = ["version"]
dynamic = [ description = "Unified Efficient Fine-Tuning of 100+ LLMs"
"version", readme = "README.md"
"dependencies", license = "Apache-2.0"
"optional-dependencies", requires-python = ">=3.11.0"
"scripts", authors = [
"authors", { name = "hiyouga", email = "hiyouga@buaa.edu.cn" }
"description", ]
"readme", keywords = [
"license", "AI",
"keywords", "LLM",
"classifiers" "GPT",
"ChatGPT",
"Llama",
"Transformer",
"DeepSeek",
"Pytorch"
]
classifiers = [
"Development Status :: 4 - Beta",
"Intended Audience :: Developers",
"Intended Audience :: Education",
"Intended Audience :: Science/Research",
"License :: OSI Approved :: Apache Software License",
"Operating System :: OS Independent",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13",
"Topic :: Scientific/Engineering :: Artificial Intelligence"
]
dependencies = [
# core deps
"torch>=2.4.0",
"torchvision>=0.19.0",
"torchaudio>=2.4.0",
"transformers>=4.49.0,<=4.56.2,!=4.52.0; python_version < '3.10'",
"transformers>=4.49.0,<=4.57.1,!=4.52.0,!=4.57.0; python_version >= '3.10'",
"datasets>=2.16.0,<=4.0.0",
"accelerate>=1.3.0,<=1.11.0",
"peft>=0.14.0,<=0.17.1",
"trl>=0.8.6,<=0.9.6",
"torchdata>=0.10.0,<=0.11.0",
# gui
"gradio>=4.38.0,<=6.2.0",
"matplotlib>=3.7.0",
"tyro<0.9.0",
# ops
"einops",
"numpy",
"pandas",
"scipy",
# model and tokenizer
"sentencepiece",
"tiktoken",
"modelscope",
"hf-transfer",
"safetensors",
# python
"av",
"fire",
"omegaconf",
"packaging",
"protobuf",
"pyyaml",
"pydantic",
# api
"uvicorn",
"fastapi",
"sse-starlette"
] ]
[project.optional-dependencies]
dev = ["pre-commit", "ruff", "pytest", "build"]
metrics = ["nltk", "jieba", "rouge-chinese"]
deepspeed = ["deepspeed>=0.10.0,<=0.16.9"]
[project.scripts]
llamafactory-cli = "llamafactory.cli:main"
lmf = "llamafactory.cli:main"
[project.urls]
Homepage = "https://github.com/hiyouga/LLaMA-Factory"
Repository = "https://github.com/hiyouga/LLaMA-Factory"
[tool.hatch.build.targets.wheel]
packages = ["src/llamafactory"]
[tool.hatch.version]
path = "src/llamafactory/extras/env.py"
pattern = "VERSION = \"(?P<version>[^\"]+)\""
[tool.ruff] [tool.ruff]
target-version = "py39" target-version = "py311"
line-length = 119 line-length = 119
indent-width = 4 indent-width = 4
@@ -30,6 +109,8 @@ ignore = [
"E501", # line too long "E501", # line too long
"E731", # lambda function "E731", # lambda function
"E741", # ambiguous var name "E741", # ambiguous var name
"UP007", # no upgrade union
"UP045", # no upgrade optional
"D100", # no doc public module "D100", # no doc public module
"D101", # no doc public class "D101", # no doc public class
"D102", # no doc public method "D102", # no doc public method
@@ -73,23 +154,3 @@ indent-style = "space"
docstring-code-format = true docstring-code-format = true
skip-magic-trailing-comma = false skip-magic-trailing-comma = false
line-ending = "auto" line-ending = "auto"
[tool.uv]
conflicts = [
[
{ extra = "torch-npu" },
{ extra = "aqlm" },
],
[
{ extra = "torch-npu" },
{ extra = "vllm" },
],
[
{ extra = "torch-npu" },
{ extra = "sglang" },
],
[
{ extra = "vllm" },
{ extra = "sglang" },
],
]

View File

@@ -1,39 +0,0 @@
# core deps
transformers>=4.49.0,<=4.56.2,!=4.52.0; python_version < '3.10'
transformers>=4.49.0,<=4.57.1,!=4.52.0,!=4.57.0; python_version >= '3.10'
datasets>=2.16.0,<=4.0.0
accelerate>=1.3.0,<=1.11.0
peft>=0.14.0,<=0.17.1
trl>=0.8.6,<=0.9.6
torchdata
# gui
gradio>=4.38.0,<=5.45.0
matplotlib>=3.7.0
tyro<0.9.0
# ops
einops
numpy<2.0.0
pandas>=2.0.0
scipy
# model and tokenizer
sentencepiece
tiktoken
modelscope>=1.14.0
hf-transfer
safetensors<=0.5.3
# python
fire
omegaconf
packaging
protobuf
pyyaml
pydantic<=2.10.6
# api
uvicorn
fastapi
sse-starlette
# media
av
librosa
# yanked
propcache!=0.4.0

View File

@@ -16,7 +16,6 @@
# limitations under the License. # limitations under the License.
import os import os
from typing import Optional
import fire import fire
import torch import torch
@@ -34,7 +33,7 @@ def convert_mca_to_hf(
output_path: str = "./output", output_path: str = "./output",
bf16: bool = False, bf16: bool = False,
fp16: bool = False, fp16: bool = False,
convert_model_max_length: Optional[int] = None, convert_model_max_length: int | None = None,
): ):
"""Convert megatron checkpoint to HuggingFace format. """Convert megatron checkpoint to HuggingFace format.
@@ -67,11 +66,11 @@ def convert(
output_path: str = "./output", output_path: str = "./output",
bf16: bool = False, bf16: bool = False,
fp16: bool = False, fp16: bool = False,
convert_model_max_length: Optional[int] = None, convert_model_max_length: int | None = None,
tensor_model_parallel_size: int = 1, tensor_model_parallel_size: int = 1,
pipeline_model_parallel_size: int = 1, pipeline_model_parallel_size: int = 1,
expert_model_parallel_size: int = 1, expert_model_parallel_size: int = 1,
virtual_pipeline_model_parallel_size: Optional[int] = None, virtual_pipeline_model_parallel_size: int | None = None,
): ):
"""Convert checkpoint between MCA and HuggingFace formats. """Convert checkpoint between MCA and HuggingFace formats.

View File

@@ -14,7 +14,7 @@
import json import json
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Literal, Optional from typing import Any, Literal
import fire import fire
import torch import torch
@@ -61,7 +61,7 @@ def calculate_ppl(
dataset_dir: str = "data", dataset_dir: str = "data",
template: str = "default", template: str = "default",
cutoff_len: int = 2048, cutoff_len: int = 2048,
max_samples: Optional[int] = None, max_samples: int | None = None,
train_on_prompt: bool = False, train_on_prompt: bool = False,
): ):
r"""Calculate the ppl on the dataset of the pre-trained models. r"""Calculate the ppl on the dataset of the pre-trained models.

View File

@@ -14,7 +14,6 @@
import gc import gc
import json import json
from typing import Optional
import av import av
import fire import fire
@@ -49,7 +48,7 @@ def vllm_infer(
dataset_dir: str = "data", dataset_dir: str = "data",
template: str = "default", template: str = "default",
cutoff_len: int = 2048, cutoff_len: int = 2048,
max_samples: Optional[int] = None, max_samples: int | None = None,
vllm_config: str = "{}", vllm_config: str = "{}",
save_name: str = "generated_predictions.jsonl", save_name: str = "generated_predictions.jsonl",
temperature: float = 0.95, temperature: float = 0.95,
@@ -58,9 +57,9 @@ def vllm_infer(
max_new_tokens: int = 1024, max_new_tokens: int = 1024,
repetition_penalty: float = 1.0, repetition_penalty: float = 1.0,
skip_special_tokens: bool = True, skip_special_tokens: bool = True,
default_system: Optional[str] = None, default_system: str | None = None,
enable_thinking: bool = True, enable_thinking: bool = True,
seed: Optional[int] = None, seed: int | None = None,
pipeline_parallel_size: int = 1, pipeline_parallel_size: int = 1,
image_max_pixels: int = 768 * 768, image_max_pixels: int = 768 * 768,
image_min_pixels: int = 32 * 32, image_min_pixels: int = 32 * 32,

setup.py (deleted, 116 lines)
View File

@@ -1,116 +0,0 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import re
from setuptools import find_packages, setup
def get_version() -> str:
with open(os.path.join("src", "llamafactory", "extras", "env.py"), encoding="utf-8") as f:
file_content = f.read()
pattern = r"{}\W*=\W*\"([^\"]+)\"".format("VERSION")
(version,) = re.findall(pattern, file_content)
return version
def get_requires() -> list[str]:
with open("requirements.txt", encoding="utf-8") as f:
file_content = f.read()
lines = [line.strip() for line in file_content.strip().split("\n") if not line.startswith("#")]
return lines
def get_console_scripts() -> list[str]:
console_scripts = ["llamafactory-cli = llamafactory.cli:main"]
if os.getenv("ENABLE_SHORT_CONSOLE", "1").lower() in ["true", "y", "1"]:
console_scripts.append("lmf = llamafactory.cli:main")
return console_scripts
extra_require = {
"torch": ["torch>=2.0.0", "torchvision>=0.15.0"],
"torch-npu": ["torch==2.7.1", "torch-npu==2.7.1", "torchvision==0.22.1", "decorator"],
"metrics": ["nltk", "jieba", "rouge-chinese"],
"deepspeed": ["deepspeed>=0.10.0,<=0.16.9"],
"liger-kernel": ["liger-kernel>=0.5.5"],
"bitsandbytes": ["bitsandbytes>=0.39.0"],
"hqq": ["hqq"],
"eetq": ["eetq"],
"gptq": ["optimum>=1.24.0", "gptqmodel>=2.0.0"],
"aqlm": ["aqlm[gpu]>=1.1.0"],
"vllm": ["vllm>=0.4.3,<=0.11.0"],
"sglang": ["sglang[srt]>=0.4.5", "transformers==4.51.1"],
"galore": ["galore-torch"],
"apollo": ["apollo-torch"],
"badam": ["badam>=1.2.1"],
"adam-mini": ["adam-mini"],
"minicpm_v": [
"soundfile",
"torchvision",
"torchaudio",
"vector_quantize_pytorch",
"vocos",
"msgpack",
"referencing",
"jsonschema_specifications",
],
"openmind": ["openmind"],
"swanlab": ["swanlab"],
"fp8": ["torchao>=0.8.0", "accelerate>=1.10.0"],
"fp8-te": ["transformer_engine[pytorch]>=2.0.0", "accelerate>=1.10.0"],
"fp8-all": ["torchao>=0.8.0", "transformer_engine[pytorch]>=2.0.0", "accelerate>=1.10.0"],
"dev": ["pre-commit", "ruff", "pytest", "build"],
}
def main():
setup(
name="llamafactory",
version=get_version(),
author="hiyouga",
author_email="hiyouga@buaa.edu.cn",
description="Unified Efficient Fine-Tuning of 100+ LLMs",
long_description=open("README.md", encoding="utf-8").read(),
long_description_content_type="text/markdown",
keywords=["AI", "LLM", "GPT", "ChatGPT", "Llama", "Transformer", "DeepSeek", "Pytorch"],
license="Apache 2.0 License",
url="https://github.com/hiyouga/LLaMA-Factory",
package_dir={"": "src"},
packages=find_packages("src"),
python_requires=">=3.9.0",
install_requires=get_requires(),
extras_require=extra_require,
entry_points={"console_scripts": get_console_scripts()},
classifiers=[
"Development Status :: 4 - Beta",
"Intended Audience :: Developers",
"Intended Audience :: Education",
"Intended Audience :: Science/Research",
"License :: OSI Approved :: Apache Software License",
"Operating System :: OS Independent",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
],
)
if __name__ == "__main__":
main()

View File

@@ -16,7 +16,7 @@ import asyncio
import os import os
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from functools import partial from functools import partial
from typing import Annotated, Optional from typing import Annotated
from ..chat import ChatModel from ..chat import ChatModel
from ..extras.constants import EngineName from ..extras.constants import EngineName
@@ -79,7 +79,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
api_key = os.getenv("API_KEY") api_key = os.getenv("API_KEY")
security = HTTPBearer(auto_error=False) security = HTTPBearer(auto_error=False)
async def verify_api_key(auth: Annotated[Optional[HTTPAuthorizationCredentials], Depends(security)]): async def verify_api_key(auth: Annotated[HTTPAuthorizationCredentials | None, Depends(security)]):
if api_key and (auth is None or auth.credentials != api_key): if api_key and (auth is None or auth.credentials != api_key):
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key.") raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key.")

View File

@@ -14,10 +14,9 @@
import time import time
from enum import Enum, unique from enum import Enum, unique
from typing import Any, Optional, Union from typing import Any, Literal
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from typing_extensions import Literal
@unique @unique
@@ -61,7 +60,7 @@ class FunctionDefinition(BaseModel):
class FunctionAvailable(BaseModel): class FunctionAvailable(BaseModel):
type: Literal["function", "code_interpreter"] = "function" type: Literal["function", "code_interpreter"] = "function"
function: Optional[FunctionDefinition] = None function: FunctionDefinition | None = None
class FunctionCall(BaseModel): class FunctionCall(BaseModel):
@@ -77,35 +76,35 @@ class URL(BaseModel):
class MultimodalInputItem(BaseModel): class MultimodalInputItem(BaseModel):
type: Literal["text", "image_url", "video_url", "audio_url"] type: Literal["text", "image_url", "video_url", "audio_url"]
text: Optional[str] = None text: str | None = None
image_url: Optional[URL] = None image_url: URL | None = None
video_url: Optional[URL] = None video_url: URL | None = None
audio_url: Optional[URL] = None audio_url: URL | None = None
class ChatMessage(BaseModel): class ChatMessage(BaseModel):
role: Role role: Role
content: Optional[Union[str, list[MultimodalInputItem]]] = None content: str | list[MultimodalInputItem] | None = None
tool_calls: Optional[list[FunctionCall]] = None tool_calls: list[FunctionCall] | None = None
class ChatCompletionMessage(BaseModel): class ChatCompletionMessage(BaseModel):
role: Optional[Role] = None role: Role | None = None
content: Optional[str] = None content: str | None = None
tool_calls: Optional[list[FunctionCall]] = None tool_calls: list[FunctionCall] | None = None
class ChatCompletionRequest(BaseModel): class ChatCompletionRequest(BaseModel):
model: str model: str
messages: list[ChatMessage] messages: list[ChatMessage]
tools: Optional[list[FunctionAvailable]] = None tools: list[FunctionAvailable] | None = None
do_sample: Optional[bool] = None do_sample: bool | None = None
temperature: Optional[float] = None temperature: float | None = None
top_p: Optional[float] = None top_p: float | None = None
n: int = 1 n: int = 1
presence_penalty: Optional[float] = None presence_penalty: float | None = None
max_tokens: Optional[int] = None max_tokens: int | None = None
stop: Optional[Union[str, list[str]]] = None stop: str | list[str] | None = None
stream: bool = False stream: bool = False
@@ -118,7 +117,7 @@ class ChatCompletionResponseChoice(BaseModel):
class ChatCompletionStreamResponseChoice(BaseModel): class ChatCompletionStreamResponseChoice(BaseModel):
index: int index: int
delta: ChatCompletionMessage delta: ChatCompletionMessage
finish_reason: Optional[Finish] = None finish_reason: Finish | None = None
class ChatCompletionResponseUsage(BaseModel): class ChatCompletionResponseUsage(BaseModel):
@@ -147,7 +146,7 @@ class ChatCompletionStreamResponse(BaseModel):
class ScoreEvaluationRequest(BaseModel): class ScoreEvaluationRequest(BaseModel):
model: str model: str
messages: list[str] messages: list[str]
max_length: Optional[int] = None max_length: int | None = None
class ScoreEvaluationResponse(BaseModel): class ScoreEvaluationResponse(BaseModel):

View File

@@ -14,9 +14,9 @@
import asyncio import asyncio
import os import os
from collections.abc import AsyncGenerator from collections.abc import AsyncGenerator, Callable
from threading import Thread from threading import Thread
from typing import TYPE_CHECKING, Any, Callable, Optional, Union from typing import TYPE_CHECKING, Any, Optional, Union
import torch import torch
from transformers import GenerationConfig, TextIteratorStreamer from transformers import GenerationConfig, TextIteratorStreamer

View File

@@ -15,7 +15,7 @@ import json
import os import os
from abc import abstractmethod from abc import abstractmethod
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Optional, Union from typing import TYPE_CHECKING, Any, Union
from ..extras import logging from ..extras import logging
from .data_utils import Role from .data_utils import Role
@@ -40,7 +40,7 @@ class DatasetConverter:
dataset_attr: "DatasetAttr" dataset_attr: "DatasetAttr"
data_args: "DataArguments" data_args: "DataArguments"
def _find_medias(self, medias: Union["MediaType", list["MediaType"], None]) -> Optional[list["MediaType"]]: def _find_medias(self, medias: Union["MediaType", list["MediaType"], None]) -> list["MediaType"] | None:
r"""Optionally concatenate media path to media dir when loading from local disk.""" r"""Optionally concatenate media path to media dir when loading from local disk."""
if medias is None: if medias is None:
return None return None

View File

@@ -16,7 +16,6 @@ import json
import re import re
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Optional, Union
from typing_extensions import override from typing_extensions import override
@@ -27,14 +26,14 @@ from .tool_utils import FunctionCall, get_tool_utils
@dataclass @dataclass
class Formatter(ABC): class Formatter(ABC):
slots: SLOTS = field(default_factory=list) slots: SLOTS = field(default_factory=list)
tool_format: Optional[str] = None tool_format: str | None = None
@abstractmethod @abstractmethod
def apply(self, **kwargs) -> SLOTS: def apply(self, **kwargs) -> SLOTS:
r"""Forms a list of slots according to the inputs to encode.""" r"""Forms a list of slots according to the inputs to encode."""
... ...
def extract(self, content: str) -> Union[str, list["FunctionCall"]]: def extract(self, content: str) -> str | list["FunctionCall"]:
r"""Extract a list of tuples from the response message if using tools. r"""Extract a list of tuples from the response message if using tools.
Each tuple consists of function name and function arguments. Each tuple consists of function name and function arguments.
@@ -156,5 +155,5 @@ class ToolFormatter(Formatter):
raise RuntimeError(f"Invalid JSON format in tool description: {str([content])}.") # flat string raise RuntimeError(f"Invalid JSON format in tool description: {str([content])}.") # flat string
@override @override
def extract(self, content: str) -> Union[str, list["FunctionCall"]]: def extract(self, content: str) -> str | list["FunctionCall"]:
return self.tool_utils.tool_extractor(content) return self.tool_utils.tool_extractor(content)

View File

@@ -162,13 +162,13 @@ def _load_single_dataset(
def _get_merged_dataset( def _get_merged_dataset(
dataset_names: Optional[list[str]], dataset_names: list[str] | None,
model_args: "ModelArguments", model_args: "ModelArguments",
data_args: "DataArguments", data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments", training_args: "Seq2SeqTrainingArguments",
stage: Literal["pt", "sft", "rm", "ppo", "kto"], stage: Literal["pt", "sft", "rm", "ppo", "kto"],
return_dict: bool = False, return_dict: bool = False,
) -> Optional[Union["Dataset", "IterableDataset", dict[str, "Dataset"]]]: ) -> Union["Dataset", "IterableDataset", dict[str, "Dataset"]] | None:
r"""Return the merged datasets in the standard format.""" r"""Return the merged datasets in the standard format."""
if dataset_names is None: if dataset_names is None:
return None return None
@@ -227,7 +227,7 @@ def _get_dataset_processor(
def _get_preprocessed_dataset( def _get_preprocessed_dataset(
dataset: Optional[Union["Dataset", "IterableDataset"]], dataset: Union["Dataset", "IterableDataset"] | None,
data_args: "DataArguments", data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments", training_args: "Seq2SeqTrainingArguments",
stage: Literal["pt", "sft", "rm", "ppo", "kto"], stage: Literal["pt", "sft", "rm", "ppo", "kto"],
@@ -235,7 +235,7 @@ def _get_preprocessed_dataset(
tokenizer: "PreTrainedTokenizer", tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"] = None, processor: Optional["ProcessorMixin"] = None,
is_eval: bool = False, is_eval: bool = False,
) -> Optional[Union["Dataset", "IterableDataset"]]: ) -> Union["Dataset", "IterableDataset"] | None:
r"""Preprocesses the dataset, including format checking and tokenization.""" r"""Preprocesses the dataset, including format checking and tokenization."""
if dataset is None: if dataset is None:
return None return None

View File

@@ -22,28 +22,20 @@ import re
from copy import deepcopy from copy import deepcopy
from dataclasses import dataclass from dataclasses import dataclass
from io import BytesIO from io import BytesIO
from typing import TYPE_CHECKING, BinaryIO, Literal, Optional, TypedDict, Union from typing import TYPE_CHECKING, BinaryIO, Literal, NotRequired, Optional, TypedDict, Union
import numpy as np import numpy as np
import torch import torch
import torchaudio
from transformers.image_utils import get_image_size, is_valid_image, to_numpy_array from transformers.image_utils import get_image_size, is_valid_image, to_numpy_array
from transformers.models.mllama.processing_mllama import ( from transformers.models.mllama.processing_mllama import (
convert_sparse_cross_attention_mask_to_dense, convert_sparse_cross_attention_mask_to_dense,
get_cross_attention_token_mask, get_cross_attention_token_mask,
) )
from typing_extensions import NotRequired, override from typing_extensions import override
from ..extras.constants import AUDIO_PLACEHOLDER, IGNORE_INDEX, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER from ..extras.constants import AUDIO_PLACEHOLDER, IGNORE_INDEX, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
from ..extras.packages import ( from ..extras.packages import is_pillow_available, is_pyav_available, is_transformers_version_greater_than
is_librosa_available,
is_pillow_available,
is_pyav_available,
is_transformers_version_greater_than,
)
if is_librosa_available():
import librosa
if is_pillow_available(): if is_pillow_available():
@@ -71,8 +63,8 @@ if TYPE_CHECKING:
from transformers.video_processing_utils import BaseVideoProcessor from transformers.video_processing_utils import BaseVideoProcessor
class EncodedImage(TypedDict): class EncodedImage(TypedDict):
path: Optional[str] path: str | None
bytes: Optional[bytes] bytes: bytes | None
ImageInput = Union[str, bytes, EncodedImage, BinaryIO, ImageObject] ImageInput = Union[str, bytes, EncodedImage, BinaryIO, ImageObject]
VideoInput = Union[str, BinaryIO, list[list[ImageInput]]] VideoInput = Union[str, BinaryIO, list[list[ImageInput]]]
@@ -152,9 +144,9 @@ def _check_video_is_nested_images(video: "VideoInput") -> bool:
@dataclass @dataclass
class MMPluginMixin: class MMPluginMixin:
image_token: Optional[str] image_token: str | None
video_token: Optional[str] video_token: str | None
audio_token: Optional[str] audio_token: str | None
expand_mm_tokens: bool = True expand_mm_tokens: bool = True
def _validate_input( def _validate_input(
@@ -316,7 +308,14 @@ class MMPluginMixin:
results, sampling_rates = [], [] results, sampling_rates = [], []
for audio in audios: for audio in audios:
if not isinstance(audio, np.ndarray): if not isinstance(audio, np.ndarray):
audio, sampling_rate = librosa.load(audio, sr=sampling_rate) audio, sr = torchaudio.load(audio)
if audio.shape[0] > 1:
audio = audio.mean(dim=0, keepdim=True)
if sr != sampling_rate:
audio = torchaudio.functional.resample(audio, sr, sampling_rate)
audio = audio.squeeze(0).numpy()
results.append(audio) results.append(audio)
sampling_rates.append(sampling_rate) sampling_rates.append(sampling_rate)
@@ -329,7 +328,7 @@ class MMPluginMixin:
videos: list["VideoInput"], videos: list["VideoInput"],
audios: list["AudioInput"], audios: list["AudioInput"],
processor: "MMProcessor", processor: "MMProcessor",
imglens: Optional[list[int]] = None, imglens: list[int] | None = None,
) -> dict[str, "torch.Tensor"]: ) -> dict[str, "torch.Tensor"]:
r"""Process visual inputs. r"""Process visual inputs.
@@ -427,13 +426,13 @@ class BasePlugin(MMPluginMixin):
     def process_token_ids(
         self,
         input_ids: list[int],
-        labels: Optional[list[int]],
+        labels: list[int] | None,
         images: list["ImageInput"],
         videos: list["VideoInput"],
         audios: list["AudioInput"],
         tokenizer: "PreTrainedTokenizer",
         processor: Optional["MMProcessor"],
-    ) -> tuple[list[int], Optional[list[int]]]:
+    ) -> tuple[list[int], list[int] | None]:
         r"""Pre-process token ids after tokenization for VLMs."""
         self._validate_input(processor, images, videos, audios)
         return input_ids, labels
@@ -480,21 +479,39 @@ class ErnieVLPlugin(BasePlugin):
         self._validate_input(processor, images, videos, audios)
         self._validate_messages(messages, images, videos, audios)
         messages = deepcopy(messages)
+        image_processor: BaseImageProcessor = getattr(processor, "image_processor")
+        merge_length: int = getattr(image_processor, "merge_size") ** 2
+        if self.expand_mm_tokens:
+            mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
+            image_grid_thw = mm_inputs.get("image_grid_thw", [])
+            video_grid_thw = mm_inputs.get("video_grid_thw", [])
+        else:
+            image_grid_thw = [None] * len(images)
+            video_grid_thw = [None] * len(videos)
         image_idx, video_idx = 0, 0
         for message in messages:
             content = message["content"]
-            image_token = self.image_token or "<|image@placeholder|>"
-            video_token = self.video_token or "<|video@placeholder|>"
+            image_token = self.image_token or "<|IMAGE_PLACEHOLDER|>"
+            video_token = self.video_token or "<|VIDEO_PLACEHOLDER|>"
             while IMAGE_PLACEHOLDER in content:
-                image_idx += 1
-                content = content.replace(
-                    IMAGE_PLACEHOLDER, f"Picture {image_idx}:<|IMAGE_START|>{image_token}<|IMAGE_END|>", 1
-                )
+                image_seqlen = image_grid_thw[image_idx].prod() // merge_length if self.expand_mm_tokens else 1
+                content = content.replace(
+                    IMAGE_PLACEHOLDER,
+                    f"Picture {image_idx + 1}:<|IMAGE_START|>{image_token * image_seqlen}<|IMAGE_END|>",
+                    1,
+                )
+                image_idx += 1
             while VIDEO_PLACEHOLDER in content:
-                video_idx += 1
-                content = content.replace(
-                    VIDEO_PLACEHOLDER, f"Video {video_idx}:<|VIDEO_START|>{video_token}<|VIDEO_END|>", 1
-                )
+                video_seqlen = video_grid_thw[video_idx].prod() // merge_length if self.expand_mm_tokens else 1
+                content = content.replace(
+                    VIDEO_PLACEHOLDER,
+                    f"Video {video_idx + 1}:<|VIDEO_START|>{video_token * video_seqlen}<|VIDEO_END|>",
+                    1,
+                )
+                video_idx += 1
             message["content"] = content
         return messages
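The reworked ErnieVL plugin now expands each placeholder into one token per merged visual patch: `grid_thw.prod() // merge_size**2` is the patch count left after spatial merging. A toy check of that arithmetic (the grid and merge size below are made-up illustration values, not from the diff):

```python
import torch

image_grid_thw = torch.tensor([1, 32, 32])  # 1 frame, 32 x 32 patches before merging
merge_size = 2                              # assumed spatial merge factor of the vision tower
merge_length = merge_size**2

image_seqlen = image_grid_thw.prod() // merge_length
print(image_seqlen)                         # tensor(256): 256 image tokens are inserted
```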
@@ -1288,13 +1305,13 @@ class PaliGemmaPlugin(BasePlugin):
     def process_token_ids(
         self,
         input_ids: list[int],
-        labels: Optional[list[int]],
+        labels: list[int] | None,
         images: list["ImageInput"],
         videos: list["VideoInput"],
         audios: list["AudioInput"],
         tokenizer: "PreTrainedTokenizer",
         processor: Optional["MMProcessor"],
-    ) -> tuple[list[int], Optional[list[int]]]:
+    ) -> tuple[list[int], list[int] | None]:
         self._validate_input(processor, images, videos, audios)
         num_images = len(images)
         image_seqlen = processor.image_seq_length if self.expand_mm_tokens else 0  # skip mm token
@@ -2109,9 +2126,9 @@ def register_mm_plugin(name: str, plugin_class: type["BasePlugin"]) -> None:
 def get_mm_plugin(
     name: str,
-    image_token: Optional[str] = None,
-    video_token: Optional[str] = None,
-    audio_token: Optional[str] = None,
+    image_token: str | None = None,
+    video_token: str | None = None,
+    audio_token: str | None = None,
     **kwargs,
 ) -> "BasePlugin":
     r"""Get plugin for multimodal inputs."""

View File

@@ -15,7 +15,7 @@
import json import json
import os import os
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Literal, Optional, Union from typing import Any, Literal
from huggingface_hub import hf_hub_download from huggingface_hub import hf_hub_download
@@ -33,40 +33,40 @@ class DatasetAttr:
formatting: Literal["alpaca", "sharegpt", "openai"] = "alpaca" formatting: Literal["alpaca", "sharegpt", "openai"] = "alpaca"
ranking: bool = False ranking: bool = False
# extra configs # extra configs
subset: Optional[str] = None subset: str | None = None
split: str = "train" split: str = "train"
folder: Optional[str] = None folder: str | None = None
num_samples: Optional[int] = None num_samples: int | None = None
# common columns # common columns
system: Optional[str] = None system: str | None = None
tools: Optional[str] = None tools: str | None = None
images: Optional[str] = None images: str | None = None
videos: Optional[str] = None videos: str | None = None
audios: Optional[str] = None audios: str | None = None
# dpo columns # dpo columns
chosen: Optional[str] = None chosen: str | None = None
rejected: Optional[str] = None rejected: str | None = None
kto_tag: Optional[str] = None kto_tag: str | None = None
# alpaca columns # alpaca columns
prompt: Optional[str] = "instruction" prompt: str | None = "instruction"
query: Optional[str] = "input" query: str | None = "input"
response: Optional[str] = "output" response: str | None = "output"
history: Optional[str] = None history: str | None = None
# sharegpt columns # sharegpt columns
messages: Optional[str] = "conversations" messages: str | None = "conversations"
# sharegpt tags # sharegpt tags
role_tag: Optional[str] = "from" role_tag: str | None = "from"
content_tag: Optional[str] = "value" content_tag: str | None = "value"
user_tag: Optional[str] = "human" user_tag: str | None = "human"
assistant_tag: Optional[str] = "gpt" assistant_tag: str | None = "gpt"
observation_tag: Optional[str] = "observation" observation_tag: str | None = "observation"
function_tag: Optional[str] = "function_call" function_tag: str | None = "function_call"
system_tag: Optional[str] = "system" system_tag: str | None = "system"
def __repr__(self) -> str: def __repr__(self) -> str:
return self.dataset_name return self.dataset_name
def set_attr(self, key: str, obj: dict[str, Any], default: Optional[Any] = None) -> None: def set_attr(self, key: str, obj: dict[str, Any], default: Any | None = None) -> None:
setattr(self, key, obj.get(key, default)) setattr(self, key, obj.get(key, default))
def join(self, attr: dict[str, Any]) -> None: def join(self, attr: dict[str, Any]) -> None:
@@ -90,7 +90,7 @@ class DatasetAttr:
self.set_attr(tag, attr["tags"]) self.set_attr(tag, attr["tags"])
def get_dataset_list(dataset_names: Optional[list[str]], dataset_dir: Union[str, dict]) -> list["DatasetAttr"]: def get_dataset_list(dataset_names: list[str] | None, dataset_dir: str | dict) -> list["DatasetAttr"]:
r"""Get the attributes of the datasets.""" r"""Get the attributes of the datasets."""
if dataset_names is None: if dataset_names is None:
dataset_names = [] dataset_names = []
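The `Optional[...]`/`Union[...]` to `X | None` swaps in this file (and in the hparams dataclasses further down) lean on PEP 604 union syntax, which only exists as a runtime type on Python 3.10+, in line with the dropped Python 3.9 support. A minimal sketch of the resulting field style (the field names are illustrative, not from the diff):

```python
from dataclasses import dataclass, field

@dataclass
class ExampleArgs:
    # `str | None` is evaluated at runtime (e.g. by HfArgumentParser inspecting
    # dataclass field types), so it requires Python >= 3.10 rather than 3.9.
    subset: str | None = None
    num_samples: int | None = field(default=None, metadata={"help": "Truncate each dataset."})
```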

View File

@@ -981,7 +981,7 @@ register_template(
     replace_eos=True,
     replace_jinja_template=True,
     template_class=ReasoningTemplate,
-    mm_plugin=get_mm_plugin(name="ernie_vl", image_token="<|image@placeholder|>", video_token="<|video@placeholder|>"),
+    mm_plugin=get_mm_plugin(name="ernie_vl", image_token="<|IMAGE_PLACEHOLDER|>", video_token="<|VIDEO_PLACEHOLDER|>"),
 )

View File

@@ -15,7 +15,6 @@
import os import os
from collections import OrderedDict, defaultdict from collections import OrderedDict, defaultdict
from enum import Enum, unique from enum import Enum, unique
from typing import Optional
from peft.utils import SAFETENSORS_WEIGHTS_NAME as SAFE_ADAPTER_WEIGHTS_NAME from peft.utils import SAFETENSORS_WEIGHTS_NAME as SAFE_ADAPTER_WEIGHTS_NAME
from peft.utils import WEIGHTS_NAME as ADAPTER_WEIGHTS_NAME from peft.utils import WEIGHTS_NAME as ADAPTER_WEIGHTS_NAME
@@ -154,7 +153,7 @@ class RopeScaling(str, Enum):
def register_model_group( def register_model_group(
models: dict[str, dict[DownloadSource, str]], models: dict[str, dict[DownloadSource, str]],
template: Optional[str] = None, template: str | None = None,
multimodal: bool = False, multimodal: bool = False,
) -> None: ) -> None:
for name, path in models.items(): for name, path in models.items():

View File

@@ -117,7 +117,7 @@ def _configure_library_root_logger() -> None:
library_root_logger.propagate = False library_root_logger.propagate = False
def get_logger(name: Optional[str] = None) -> "_Logger": def get_logger(name: str | None = None) -> "_Logger":
r"""Return a logger with the specified name. It it not supposed to be accessed externally.""" r"""Return a logger with the specified name. It it not supposed to be accessed externally."""
if name is None: if name is None:
name = _get_library_name() name = _get_library_name()

View File

@@ -332,3 +332,7 @@ def fix_proxy(ipv6_enabled: bool = False) -> None:
     if ipv6_enabled:
         os.environ.pop("http_proxy", None)
         os.environ.pop("HTTP_PROXY", None)
+        os.environ.pop("https_proxy", None)
+        os.environ.pop("HTTPS_PROXY", None)
+        os.environ.pop("all_proxy", None)
+        os.environ.pop("ALL_PROXY", None)
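With IPv6 enabled, the helper now clears the HTTPS and catch-all proxy variables as well, so every common lowercase/uppercase pair is removed. A standalone sketch of the same idea (a hypothetical helper, not the repo function):

```python
import os

def clear_proxy_env() -> None:
    # Drop every common proxy variable in both lowercase and uppercase spellings.
    for key in ("http_proxy", "https_proxy", "all_proxy"):
        os.environ.pop(key, None)
        os.environ.pop(key.upper(), None)
```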

View File

@@ -16,22 +16,22 @@
# limitations under the License. # limitations under the License.
from dataclasses import asdict, dataclass, field from dataclasses import asdict, dataclass, field
from typing import Any, Literal, Optional from typing import Any, Literal
@dataclass @dataclass
class DataArguments: class DataArguments:
r"""Arguments pertaining to what data we are going to input our model for training and evaluation.""" r"""Arguments pertaining to what data we are going to input our model for training and evaluation."""
template: Optional[str] = field( template: str | None = field(
default=None, default=None,
metadata={"help": "Which template to use for constructing prompts in training and inference."}, metadata={"help": "Which template to use for constructing prompts in training and inference."},
) )
dataset: Optional[str] = field( dataset: str | None = field(
default=None, default=None,
metadata={"help": "The name of dataset(s) to use for training. Use commas to separate multiple datasets."}, metadata={"help": "The name of dataset(s) to use for training. Use commas to separate multiple datasets."},
) )
eval_dataset: Optional[str] = field( eval_dataset: str | None = field(
default=None, default=None,
metadata={"help": "The name of dataset(s) to use for evaluation. Use commas to separate multiple datasets."}, metadata={"help": "The name of dataset(s) to use for evaluation. Use commas to separate multiple datasets."},
) )
@@ -39,7 +39,7 @@ class DataArguments:
default="data", default="data",
metadata={"help": "Path to the folder containing the datasets."}, metadata={"help": "Path to the folder containing the datasets."},
) )
media_dir: Optional[str] = field( media_dir: str | None = field(
default=None, default=None,
metadata={"help": "Path to the folder containing the images, videos or audios. Defaults to `dataset_dir`."}, metadata={"help": "Path to the folder containing the images, videos or audios. Defaults to `dataset_dir`."},
) )
@@ -67,7 +67,7 @@ class DataArguments:
default="concat", default="concat",
metadata={"help": "Strategy to use in dataset mixing (concat/interleave) (undersampling/oversampling)."}, metadata={"help": "Strategy to use in dataset mixing (concat/interleave) (undersampling/oversampling)."},
) )
interleave_probs: Optional[str] = field( interleave_probs: str | None = field(
default=None, default=None,
metadata={"help": "Probabilities to sample data from datasets. Use commas to separate multiple datasets."}, metadata={"help": "Probabilities to sample data from datasets. Use commas to separate multiple datasets."},
) )
@@ -79,15 +79,15 @@ class DataArguments:
default=1000, default=1000,
metadata={"help": "The number of examples in one group in pre-processing."}, metadata={"help": "The number of examples in one group in pre-processing."},
) )
preprocessing_num_workers: Optional[int] = field( preprocessing_num_workers: int | None = field(
default=None, default=None,
metadata={"help": "The number of processes to use for the pre-processing."}, metadata={"help": "The number of processes to use for the pre-processing."},
) )
max_samples: Optional[int] = field( max_samples: int | None = field(
default=None, default=None,
metadata={"help": "For debugging purposes, truncate the number of examples for each dataset."}, metadata={"help": "For debugging purposes, truncate the number of examples for each dataset."},
) )
eval_num_beams: Optional[int] = field( eval_num_beams: int | None = field(
default=None, default=None,
metadata={"help": "Number of beams to use for evaluation. This argument will be passed to `model.generate`"}, metadata={"help": "Number of beams to use for evaluation. This argument will be passed to `model.generate`"},
) )
@@ -103,7 +103,7 @@ class DataArguments:
default=False, default=False,
metadata={"help": "Whether or not to evaluate on each dataset separately."}, metadata={"help": "Whether or not to evaluate on each dataset separately."},
) )
packing: Optional[bool] = field( packing: bool | None = field(
default=None, default=None,
metadata={"help": "Enable sequences packing in training. Will automatically enable in pre-training."}, metadata={"help": "Enable sequences packing in training. Will automatically enable in pre-training."},
) )
@@ -111,19 +111,19 @@ class DataArguments:
default=False, default=False,
metadata={"help": "Enable sequence packing without cross-attention."}, metadata={"help": "Enable sequence packing without cross-attention."},
) )
tool_format: Optional[str] = field( tool_format: str | None = field(
default=None, default=None,
metadata={"help": "Tool format to use for constructing function calling examples."}, metadata={"help": "Tool format to use for constructing function calling examples."},
) )
default_system: Optional[str] = field( default_system: str | None = field(
default=None, default=None,
metadata={"help": "Override the default system message in the template."}, metadata={"help": "Override the default system message in the template."},
) )
enable_thinking: Optional[bool] = field( enable_thinking: bool | None = field(
default=True, default=True,
metadata={"help": "Whether or not to enable thinking mode for reasoning models."}, metadata={"help": "Whether or not to enable thinking mode for reasoning models."},
) )
tokenized_path: Optional[str] = field( tokenized_path: str | None = field(
default=None, default=None,
metadata={ metadata={
"help": ( "help": (

View File

@@ -14,7 +14,7 @@
import os import os
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Literal, Optional from typing import Literal
from datasets import DownloadMode from datasets import DownloadMode
@@ -46,7 +46,7 @@ class EvaluationArguments:
default=5, default=5,
metadata={"help": "Number of examplars for few-shot learning."}, metadata={"help": "Number of examplars for few-shot learning."},
) )
save_dir: Optional[str] = field( save_dir: str | None = field(
default=None, default=None,
metadata={"help": "Path to save the evaluation results."}, metadata={"help": "Path to save the evaluation results."},
) )

View File

@@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
from dataclasses import asdict, dataclass, field from dataclasses import asdict, dataclass, field
from typing import Any, Literal, Optional from typing import Any, Literal
@dataclass @dataclass
@@ -40,7 +40,7 @@ class FreezeArguments:
) )
}, },
) )
freeze_extra_modules: Optional[str] = field( freeze_extra_modules: str | None = field(
default=None, default=None,
metadata={ metadata={
"help": ( "help": (
@@ -56,7 +56,7 @@ class FreezeArguments:
class LoraArguments: class LoraArguments:
r"""Arguments pertaining to the LoRA training.""" r"""Arguments pertaining to the LoRA training."""
additional_target: Optional[str] = field( additional_target: str | None = field(
default=None, default=None,
metadata={ metadata={
"help": ( "help": (
@@ -66,7 +66,7 @@ class LoraArguments:
) )
}, },
) )
lora_alpha: Optional[int] = field( lora_alpha: int | None = field(
default=None, default=None,
metadata={"help": "The scale factor for LoRA fine-tuning (default: lora_rank * 2)."}, metadata={"help": "The scale factor for LoRA fine-tuning (default: lora_rank * 2)."},
) )
@@ -88,7 +88,7 @@ class LoraArguments:
) )
}, },
) )
loraplus_lr_ratio: Optional[float] = field( loraplus_lr_ratio: float | None = field(
default=None, default=None,
metadata={"help": "LoRA plus learning rate ratio (lr_B / lr_A)."}, metadata={"help": "LoRA plus learning rate ratio (lr_B / lr_A)."},
) )
@@ -126,7 +126,7 @@ class LoraArguments:
class OFTArguments: class OFTArguments:
r"""Arguments pertaining to the OFT training.""" r"""Arguments pertaining to the OFT training."""
additional_target: Optional[str] = field( additional_target: str | None = field(
default=None, default=None,
metadata={ metadata={
"help": ( "help": (
@@ -220,27 +220,27 @@ class RLHFArguments:
default=False, default=False,
metadata={"help": "Whiten the rewards before compute advantages in PPO training."}, metadata={"help": "Whiten the rewards before compute advantages in PPO training."},
) )
ref_model: Optional[str] = field( ref_model: str | None = field(
default=None, default=None,
metadata={"help": "Path to the reference model used for the PPO or DPO training."}, metadata={"help": "Path to the reference model used for the PPO or DPO training."},
) )
ref_model_adapters: Optional[str] = field( ref_model_adapters: str | None = field(
default=None, default=None,
metadata={"help": "Path to the adapters of the reference model."}, metadata={"help": "Path to the adapters of the reference model."},
) )
ref_model_quantization_bit: Optional[int] = field( ref_model_quantization_bit: int | None = field(
default=None, default=None,
metadata={"help": "The number of bits to quantize the reference model."}, metadata={"help": "The number of bits to quantize the reference model."},
) )
reward_model: Optional[str] = field( reward_model: str | None = field(
default=None, default=None,
metadata={"help": "Path to the reward model used for the PPO training."}, metadata={"help": "Path to the reward model used for the PPO training."},
) )
reward_model_adapters: Optional[str] = field( reward_model_adapters: str | None = field(
default=None, default=None,
metadata={"help": "Path to the adapters of the reward model."}, metadata={"help": "Path to the adapters of the reward model."},
) )
reward_model_quantization_bit: Optional[int] = field( reward_model_quantization_bit: int | None = field(
default=None, default=None,
metadata={"help": "The number of bits to quantize the reward model."}, metadata={"help": "The number of bits to quantize the reward model."},
) )
@@ -248,7 +248,7 @@ class RLHFArguments:
default="lora", default="lora",
metadata={"help": "The type of the reward model in PPO training. Lora model only supports lora training."}, metadata={"help": "The type of the reward model in PPO training. Lora model only supports lora training."},
) )
ld_alpha: Optional[float] = field( ld_alpha: float | None = field(
default=None, default=None,
metadata={ metadata={
"help": ( "help": (
@@ -361,15 +361,15 @@ class BAdamArgument:
default="layer", default="layer",
metadata={"help": "Whether to use layer-wise or ratio-wise BAdam optimizer."}, metadata={"help": "Whether to use layer-wise or ratio-wise BAdam optimizer."},
) )
badam_start_block: Optional[int] = field( badam_start_block: int | None = field(
default=None, default=None,
metadata={"help": "The starting block index for layer-wise BAdam."}, metadata={"help": "The starting block index for layer-wise BAdam."},
) )
badam_switch_mode: Optional[Literal["ascending", "descending", "random", "fixed"]] = field( badam_switch_mode: Literal["ascending", "descending", "random", "fixed"] | None = field(
default="ascending", default="ascending",
metadata={"help": "the strategy of picking block to update for layer-wise BAdam."}, metadata={"help": "the strategy of picking block to update for layer-wise BAdam."},
) )
badam_switch_interval: Optional[int] = field( badam_switch_interval: int | None = field(
default=50, default=50,
metadata={ metadata={
"help": "Number of steps to update the block for layer-wise BAdam. Use -1 to disable the block update." "help": "Number of steps to update the block for layer-wise BAdam. Use -1 to disable the block update."
@@ -406,15 +406,15 @@ class SwanLabArguments:
default=False, default=False,
metadata={"help": "Whether or not to use the SwanLab (an experiment tracking and visualization tool)."}, metadata={"help": "Whether or not to use the SwanLab (an experiment tracking and visualization tool)."},
) )
swanlab_project: Optional[str] = field( swanlab_project: str | None = field(
default="llamafactory", default="llamafactory",
metadata={"help": "The project name in SwanLab."}, metadata={"help": "The project name in SwanLab."},
) )
swanlab_workspace: Optional[str] = field( swanlab_workspace: str | None = field(
default=None, default=None,
metadata={"help": "The workspace name in SwanLab."}, metadata={"help": "The workspace name in SwanLab."},
) )
swanlab_run_name: Optional[str] = field( swanlab_run_name: str | None = field(
default=None, default=None,
metadata={"help": "The experiment name in SwanLab."}, metadata={"help": "The experiment name in SwanLab."},
) )
@@ -422,19 +422,19 @@ class SwanLabArguments:
default="cloud", default="cloud",
metadata={"help": "The mode of SwanLab."}, metadata={"help": "The mode of SwanLab."},
) )
swanlab_api_key: Optional[str] = field( swanlab_api_key: str | None = field(
default=None, default=None,
metadata={"help": "The API key for SwanLab."}, metadata={"help": "The API key for SwanLab."},
) )
swanlab_logdir: Optional[str] = field( swanlab_logdir: str | None = field(
default=None, default=None,
metadata={"help": "The log directory for SwanLab."}, metadata={"help": "The log directory for SwanLab."},
) )
swanlab_lark_webhook_url: Optional[str] = field( swanlab_lark_webhook_url: str | None = field(
default=None, default=None,
metadata={"help": "The Lark(飞书) webhook URL for SwanLab."}, metadata={"help": "The Lark(飞书) webhook URL for SwanLab."},
) )
swanlab_lark_secret: Optional[str] = field( swanlab_lark_secret: str | None = field(
default=None, default=None,
metadata={"help": "The Lark(飞书) secret for SwanLab."}, metadata={"help": "The Lark(飞书) secret for SwanLab."},
) )
@@ -510,7 +510,7 @@ class FinetuningArguments(
default=False, default=False,
metadata={"help": "Whether or not to disable the shuffling of the training set."}, metadata={"help": "Whether or not to disable the shuffling of the training set."},
) )
early_stopping_steps: Optional[int] = field( early_stopping_steps: int | None = field(
default=None, default=None,
metadata={"help": "Number of steps to stop training if the `metric_for_best_model` does not improve."}, metadata={"help": "Number of steps to stop training if the `metric_for_best_model` does not improve."},
) )
@@ -530,11 +530,11 @@ class FinetuningArguments(
return arg return arg
self.freeze_trainable_modules: list[str] = split_arg(self.freeze_trainable_modules) self.freeze_trainable_modules: list[str] = split_arg(self.freeze_trainable_modules)
self.freeze_extra_modules: Optional[list[str]] = split_arg(self.freeze_extra_modules) self.freeze_extra_modules: list[str] | None = split_arg(self.freeze_extra_modules)
self.lora_alpha: int = self.lora_alpha or self.lora_rank * 2 self.lora_alpha: int = self.lora_alpha or self.lora_rank * 2
self.lora_target: list[str] = split_arg(self.lora_target) self.lora_target: list[str] = split_arg(self.lora_target)
self.oft_target: list[str] = split_arg(self.oft_target) self.oft_target: list[str] = split_arg(self.oft_target)
self.additional_target: Optional[list[str]] = split_arg(self.additional_target) self.additional_target: list[str] | None = split_arg(self.additional_target)
self.galore_target: list[str] = split_arg(self.galore_target) self.galore_target: list[str] = split_arg(self.galore_target)
self.apollo_target: list[str] = split_arg(self.apollo_target) self.apollo_target: list[str] = split_arg(self.apollo_target)
self.use_ref_model = self.stage == "dpo" and self.pref_loss not in ["orpo", "simpo"] self.use_ref_model = self.stage == "dpo" and self.pref_loss not in ["orpo", "simpo"]

View File

@@ -17,12 +17,11 @@
import json import json
from dataclasses import asdict, dataclass, field, fields from dataclasses import asdict, dataclass, field, fields
from typing import Any, Literal, Optional, Union from typing import Any, Literal, Self
import torch import torch
from omegaconf import OmegaConf from omegaconf import OmegaConf
from transformers.training_args import _convert_str_dict from transformers.training_args import _convert_str_dict
from typing_extensions import Self
from ..extras.constants import AttentionFunction, EngineName, QuantizationMethod, RopeScaling from ..extras.constants import AttentionFunction, EngineName, QuantizationMethod, RopeScaling
from ..extras.logging import get_logger from ..extras.logging import get_logger
@@ -35,13 +34,13 @@ logger = get_logger(__name__)
class BaseModelArguments: class BaseModelArguments:
r"""Arguments pertaining to the model.""" r"""Arguments pertaining to the model."""
model_name_or_path: Optional[str] = field( model_name_or_path: str | None = field(
default=None, default=None,
metadata={ metadata={
"help": "Path to the model weight or identifier from huggingface.co/models or modelscope.cn/models." "help": "Path to the model weight or identifier from huggingface.co/models or modelscope.cn/models."
}, },
) )
adapter_name_or_path: Optional[str] = field( adapter_name_or_path: str | None = field(
default=None, default=None,
metadata={ metadata={
"help": ( "help": (
@@ -50,11 +49,11 @@ class BaseModelArguments:
) )
}, },
) )
adapter_folder: Optional[str] = field( adapter_folder: str | None = field(
default=None, default=None,
metadata={"help": "The folder containing the adapter weights to load."}, metadata={"help": "The folder containing the adapter weights to load."},
) )
cache_dir: Optional[str] = field( cache_dir: str | None = field(
default=None, default=None,
metadata={"help": "Where to store the pre-trained models downloaded from huggingface.co or modelscope.cn."}, metadata={"help": "Where to store the pre-trained models downloaded from huggingface.co or modelscope.cn."},
) )
@@ -70,17 +69,17 @@ class BaseModelArguments:
default=False, default=False,
metadata={"help": "Whether or not the special tokens should be split during the tokenization process."}, metadata={"help": "Whether or not the special tokens should be split during the tokenization process."},
) )
add_tokens: Optional[str] = field( add_tokens: str | None = field(
default=None, default=None,
metadata={ metadata={
"help": "Non-special tokens to be added into the tokenizer. Use commas to separate multiple tokens." "help": "Non-special tokens to be added into the tokenizer. Use commas to separate multiple tokens."
}, },
) )
add_special_tokens: Optional[str] = field( add_special_tokens: str | None = field(
default=None, default=None,
metadata={"help": "Special tokens to be added into the tokenizer. Use commas to separate multiple tokens."}, metadata={"help": "Special tokens to be added into the tokenizer. Use commas to separate multiple tokens."},
) )
new_special_tokens_config: Optional[str] = field( new_special_tokens_config: str | None = field(
default=None, default=None,
metadata={ metadata={
"help": ( "help": (
@@ -110,7 +109,7 @@ class BaseModelArguments:
default=True, default=True,
metadata={"help": "Whether or not to use memory-efficient model loading."}, metadata={"help": "Whether or not to use memory-efficient model loading."},
) )
rope_scaling: Optional[RopeScaling] = field( rope_scaling: RopeScaling | None = field(
default=None, default=None,
metadata={"help": "Which scaling strategy should be adopted for the RoPE embeddings."}, metadata={"help": "Which scaling strategy should be adopted for the RoPE embeddings."},
) )
@@ -122,7 +121,7 @@ class BaseModelArguments:
default=False, default=False,
metadata={"help": "Enable shift short attention (S^2-Attn) proposed by LongLoRA."}, metadata={"help": "Enable shift short attention (S^2-Attn) proposed by LongLoRA."},
) )
mixture_of_depths: Optional[Literal["convert", "load"]] = field( mixture_of_depths: Literal["convert", "load"] | None = field(
default=None, default=None,
metadata={"help": "Convert the model to mixture-of-depths (MoD) or load the MoD model."}, metadata={"help": "Convert the model to mixture-of-depths (MoD) or load the MoD model."},
) )
@@ -138,7 +137,7 @@ class BaseModelArguments:
default=False, default=False,
metadata={"help": "Whether or not to enable liger kernel for faster training."}, metadata={"help": "Whether or not to enable liger kernel for faster training."},
) )
moe_aux_loss_coef: Optional[float] = field( moe_aux_loss_coef: float | None = field(
default=None, default=None,
metadata={"help": "Coefficient of the auxiliary router loss in mixture-of-experts model."}, metadata={"help": "Coefficient of the auxiliary router loss in mixture-of-experts model."},
) )
@@ -182,15 +181,15 @@ class BaseModelArguments:
default="auto", default="auto",
metadata={"help": "Data type for model weights and activations at inference."}, metadata={"help": "Data type for model weights and activations at inference."},
) )
hf_hub_token: Optional[str] = field( hf_hub_token: str | None = field(
default=None, default=None,
metadata={"help": "Auth token to log in with Hugging Face Hub."}, metadata={"help": "Auth token to log in with Hugging Face Hub."},
) )
ms_hub_token: Optional[str] = field( ms_hub_token: str | None = field(
default=None, default=None,
metadata={"help": "Auth token to log in with ModelScope Hub."}, metadata={"help": "Auth token to log in with ModelScope Hub."},
) )
om_hub_token: Optional[str] = field( om_hub_token: str | None = field(
default=None, default=None,
metadata={"help": "Auth token to log in with Modelers Hub."}, metadata={"help": "Auth token to log in with Modelers Hub."},
) )
@@ -283,7 +282,7 @@ class QuantizationArguments:
default=QuantizationMethod.BNB, default=QuantizationMethod.BNB,
metadata={"help": "Quantization method to use for on-the-fly quantization."}, metadata={"help": "Quantization method to use for on-the-fly quantization."},
) )
quantization_bit: Optional[int] = field( quantization_bit: int | None = field(
default=None, default=None,
metadata={"help": "The number of bits to quantize the model using on-the-fly quantization."}, metadata={"help": "The number of bits to quantize the model using on-the-fly quantization."},
) )
@@ -295,7 +294,7 @@ class QuantizationArguments:
default=True, default=True,
metadata={"help": "Whether or not to use double quantization in bitsandbytes int4 training."}, metadata={"help": "Whether or not to use double quantization in bitsandbytes int4 training."},
) )
quantization_device_map: Optional[Literal["auto"]] = field( quantization_device_map: Literal["auto"] | None = field(
default=None, default=None,
metadata={"help": "Device map used to infer the 4-bit quantized model, needs bitsandbytes>=0.43.0."}, metadata={"help": "Device map used to infer the 4-bit quantized model, needs bitsandbytes>=0.43.0."},
) )
@@ -375,7 +374,7 @@ class ProcessorArguments:
class ExportArguments: class ExportArguments:
r"""Arguments pertaining to the model export.""" r"""Arguments pertaining to the model export."""
export_dir: Optional[str] = field( export_dir: str | None = field(
default=None, default=None,
metadata={"help": "Path to the directory to save the exported model."}, metadata={"help": "Path to the directory to save the exported model."},
) )
@@ -387,11 +386,11 @@ class ExportArguments:
default="cpu", default="cpu",
metadata={"help": "The device used in model export, use `auto` to accelerate exporting."}, metadata={"help": "The device used in model export, use `auto` to accelerate exporting."},
) )
export_quantization_bit: Optional[int] = field( export_quantization_bit: int | None = field(
default=None, default=None,
metadata={"help": "The number of bits to quantize the exported model."}, metadata={"help": "The number of bits to quantize the exported model."},
) )
export_quantization_dataset: Optional[str] = field( export_quantization_dataset: str | None = field(
default=None, default=None,
metadata={"help": "Path to the dataset or dataset name to use in quantizing the exported model."}, metadata={"help": "Path to the dataset or dataset name to use in quantizing the exported model."},
) )
@@ -407,7 +406,7 @@ class ExportArguments:
default=False, default=False,
metadata={"help": "Whether or not to save the `.bin` files instead of `.safetensors`."}, metadata={"help": "Whether or not to save the `.bin` files instead of `.safetensors`."},
) )
export_hub_model_id: Optional[str] = field( export_hub_model_id: str | None = field(
default=None, default=None,
metadata={"help": "The name of the repository if push the model to the Hugging Face hub."}, metadata={"help": "The name of the repository if push the model to the Hugging Face hub."},
) )
@@ -437,7 +436,7 @@ class VllmArguments:
default=32, default=32,
metadata={"help": "Maximum rank of all LoRAs in the vLLM engine."}, metadata={"help": "Maximum rank of all LoRAs in the vLLM engine."},
) )
vllm_config: Optional[Union[dict, str]] = field( vllm_config: dict | str | None = field(
default=None, default=None,
metadata={"help": "Config to initialize the vllm engine. Please use JSON strings."}, metadata={"help": "Config to initialize the vllm engine. Please use JSON strings."},
) )
@@ -463,7 +462,7 @@ class SGLangArguments:
default=-1, default=-1,
metadata={"help": "Tensor parallel size for the SGLang engine."}, metadata={"help": "Tensor parallel size for the SGLang engine."},
) )
sglang_config: Optional[Union[dict, str]] = field( sglang_config: dict | str | None = field(
default=None, default=None,
metadata={"help": "Config to initialize the SGLang engine. Please use JSON strings."}, metadata={"help": "Config to initialize the SGLang engine. Please use JSON strings."},
) )
@@ -487,21 +486,21 @@ class KTransformersArguments:
default=False, default=False,
metadata={"help": "Whether To Use KTransformers Optimizations For LoRA Training."}, metadata={"help": "Whether To Use KTransformers Optimizations For LoRA Training."},
) )
kt_optimize_rule: Optional[str] = field( kt_optimize_rule: str | None = field(
default=None, default=None,
metadata={ metadata={
"help": "Path To The KTransformers Optimize Rule; See https://github.com/kvcache-ai/ktransformers/." "help": "Path To The KTransformers Optimize Rule; See https://github.com/kvcache-ai/ktransformers/."
}, },
) )
cpu_infer: Optional[int] = field( cpu_infer: int | None = field(
default=32, default=32,
metadata={"help": "Number Of CPU Cores Used For Computation."}, metadata={"help": "Number Of CPU Cores Used For Computation."},
) )
chunk_size: Optional[int] = field( chunk_size: int | None = field(
default=8192, default=8192,
metadata={"help": "Chunk Size Used For CPU Compute In KTransformers."}, metadata={"help": "Chunk Size Used For CPU Compute In KTransformers."},
) )
mode: Optional[str] = field( mode: str | None = field(
default="normal", default="normal",
metadata={"help": "Normal Or Long_Context For Llama Models."}, metadata={"help": "Normal Or Long_Context For Llama Models."},
) )
@@ -539,17 +538,17 @@ class ModelArguments(
The class on the most right will be displayed first. The class on the most right will be displayed first.
""" """
compute_dtype: Optional[torch.dtype] = field( compute_dtype: torch.dtype | None = field(
default=None, default=None,
init=False, init=False,
metadata={"help": "Torch data type for computing model outputs, derived from `fp/bf16`. Do not specify it."}, metadata={"help": "Torch data type for computing model outputs, derived from `fp/bf16`. Do not specify it."},
) )
device_map: Optional[Union[str, dict[str, Any]]] = field( device_map: str | dict[str, Any] | None = field(
default=None, default=None,
init=False, init=False,
metadata={"help": "Device map for model placement, derived from training stage. Do not specify it."}, metadata={"help": "Device map for model placement, derived from training stage. Do not specify it."},
) )
model_max_length: Optional[int] = field( model_max_length: int | None = field(
default=None, default=None,
init=False, init=False,
metadata={"help": "The maximum input length for model, derived from `cutoff_len`. Do not specify it."}, metadata={"help": "The maximum input length for model, derived from `cutoff_len`. Do not specify it."},

View File

@@ -18,7 +18,7 @@
import os import os
import sys import sys
from pathlib import Path from pathlib import Path
from typing import Any, Optional, Union from typing import Any, Optional
import torch import torch
import transformers import transformers
@@ -65,7 +65,7 @@ else:
_TRAIN_MCA_CLS = tuple() _TRAIN_MCA_CLS = tuple()
def read_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> Union[dict[str, Any], list[str]]: def read_args(args: dict[str, Any] | list[str] | None = None) -> dict[str, Any] | list[str]:
r"""Get arguments from the command line or a config file.""" r"""Get arguments from the command line or a config file."""
if args is not None: if args is not None:
return args return args
@@ -83,7 +83,7 @@ def read_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> Union[
def _parse_args( def _parse_args(
parser: "HfArgumentParser", args: Optional[Union[dict[str, Any], list[str]]] = None, allow_extra_keys: bool = False parser: "HfArgumentParser", args: dict[str, Any] | list[str] | None = None, allow_extra_keys: bool = False
) -> tuple[Any]: ) -> tuple[Any]:
args = read_args(args) args = read_args(args)
if isinstance(args, dict): if isinstance(args, dict):
@@ -205,13 +205,13 @@ def _check_extra_dependencies(
check_version("rouge_chinese", mandatory=True) check_version("rouge_chinese", mandatory=True)
def _parse_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _TRAIN_CLS: def _parse_train_args(args: dict[str, Any] | list[str] | None = None) -> _TRAIN_CLS:
parser = HfArgumentParser(_TRAIN_ARGS) parser = HfArgumentParser(_TRAIN_ARGS)
allow_extra_keys = is_env_enabled("ALLOW_EXTRA_ARGS") allow_extra_keys = is_env_enabled("ALLOW_EXTRA_ARGS")
return _parse_args(parser, args, allow_extra_keys=allow_extra_keys) return _parse_args(parser, args, allow_extra_keys=allow_extra_keys)
def _parse_train_mca_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _TRAIN_MCA_CLS: def _parse_train_mca_args(args: dict[str, Any] | list[str] | None = None) -> _TRAIN_MCA_CLS:
parser = HfArgumentParser(_TRAIN_MCA_ARGS) parser = HfArgumentParser(_TRAIN_MCA_ARGS)
allow_extra_keys = is_env_enabled("ALLOW_EXTRA_ARGS") allow_extra_keys = is_env_enabled("ALLOW_EXTRA_ARGS")
model_args, data_args, training_args, finetuning_args, generating_args = _parse_args( model_args, data_args, training_args, finetuning_args, generating_args = _parse_args(
@@ -232,25 +232,25 @@ def _configure_mca_training_args(training_args, data_args, finetuning_args) -> N
finetuning_args.use_mca = True finetuning_args.use_mca = True
def _parse_infer_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _INFER_CLS: def _parse_infer_args(args: dict[str, Any] | list[str] | None = None) -> _INFER_CLS:
parser = HfArgumentParser(_INFER_ARGS) parser = HfArgumentParser(_INFER_ARGS)
allow_extra_keys = is_env_enabled("ALLOW_EXTRA_ARGS") allow_extra_keys = is_env_enabled("ALLOW_EXTRA_ARGS")
return _parse_args(parser, args, allow_extra_keys=allow_extra_keys) return _parse_args(parser, args, allow_extra_keys=allow_extra_keys)
def _parse_eval_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _EVAL_CLS: def _parse_eval_args(args: dict[str, Any] | list[str] | None = None) -> _EVAL_CLS:
parser = HfArgumentParser(_EVAL_ARGS) parser = HfArgumentParser(_EVAL_ARGS)
allow_extra_keys = is_env_enabled("ALLOW_EXTRA_ARGS") allow_extra_keys = is_env_enabled("ALLOW_EXTRA_ARGS")
return _parse_args(parser, args, allow_extra_keys=allow_extra_keys) return _parse_args(parser, args, allow_extra_keys=allow_extra_keys)
def get_ray_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> RayArguments: def get_ray_args(args: dict[str, Any] | list[str] | None = None) -> RayArguments:
parser = HfArgumentParser(RayArguments) parser = HfArgumentParser(RayArguments)
(ray_args,) = _parse_args(parser, args, allow_extra_keys=True) (ray_args,) = _parse_args(parser, args, allow_extra_keys=True)
return ray_args return ray_args
def get_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _TRAIN_CLS: def get_train_args(args: dict[str, Any] | list[str] | None = None) -> _TRAIN_CLS:
if is_env_enabled("USE_MCA"): if is_env_enabled("USE_MCA"):
model_args, data_args, training_args, finetuning_args, generating_args = _parse_train_mca_args(args) model_args, data_args, training_args, finetuning_args, generating_args = _parse_train_mca_args(args)
else: else:
@@ -473,7 +473,7 @@ def get_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _
return model_args, data_args, training_args, finetuning_args, generating_args return model_args, data_args, training_args, finetuning_args, generating_args
def get_infer_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _INFER_CLS: def get_infer_args(args: dict[str, Any] | list[str] | None = None) -> _INFER_CLS:
model_args, data_args, finetuning_args, generating_args = _parse_infer_args(args) model_args, data_args, finetuning_args, generating_args = _parse_infer_args(args)
# Setup logging # Setup logging
@@ -508,7 +508,7 @@ def get_infer_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _
return model_args, data_args, finetuning_args, generating_args return model_args, data_args, finetuning_args, generating_args
def get_eval_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _EVAL_CLS: def get_eval_args(args: dict[str, Any] | list[str] | None = None) -> _EVAL_CLS:
model_args, data_args, eval_args, finetuning_args = _parse_eval_args(args) model_args, data_args, eval_args, finetuning_args = _parse_eval_args(args)
# Setup logging # Setup logging

View File

@@ -14,7 +14,7 @@
import json import json
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Literal, Optional, Union from typing import Literal
from transformers import Seq2SeqTrainingArguments from transformers import Seq2SeqTrainingArguments
from transformers.training_args import _convert_str_dict from transformers.training_args import _convert_str_dict
@@ -40,7 +40,7 @@ else:
class RayArguments: class RayArguments:
r"""Arguments pertaining to the Ray training.""" r"""Arguments pertaining to the Ray training."""
ray_run_name: Optional[str] = field( ray_run_name: str | None = field(
default=None, default=None,
metadata={"help": "The training results will be saved at `<ray_storage_path>/ray_run_name`."}, metadata={"help": "The training results will be saved at `<ray_storage_path>/ray_run_name`."},
) )
@@ -48,7 +48,7 @@ class RayArguments:
default="./saves", default="./saves",
metadata={"help": "The storage path to save training results to"}, metadata={"help": "The storage path to save training results to"},
) )
ray_storage_filesystem: Optional[Literal["s3", "gs", "gcs"]] = field( ray_storage_filesystem: Literal["s3", "gs", "gcs"] | None = field(
default=None, default=None,
metadata={"help": "The storage filesystem to use. If None specified, local filesystem will be used."}, metadata={"help": "The storage filesystem to use. If None specified, local filesystem will be used."},
) )
@@ -56,7 +56,7 @@ class RayArguments:
default=1, default=1,
metadata={"help": "The number of workers for Ray training. Default is 1 worker."}, metadata={"help": "The number of workers for Ray training. Default is 1 worker."},
) )
resources_per_worker: Union[dict, str] = field( resources_per_worker: dict | str = field(
default_factory=lambda: {"GPU": 1}, default_factory=lambda: {"GPU": 1},
metadata={"help": "The resources per worker for Ray training. Default is to use 1 GPU per worker."}, metadata={"help": "The resources per worker for Ray training. Default is to use 1 GPU per worker."},
) )
@@ -64,7 +64,7 @@ class RayArguments:
default="PACK", default="PACK",
metadata={"help": "The placement strategy for Ray training. Default is PACK."}, metadata={"help": "The placement strategy for Ray training. Default is PACK."},
) )
ray_init_kwargs: Optional[Union[dict, str]] = field( ray_init_kwargs: dict | str | None = field(
default=None, default=None,
metadata={"help": "The arguments to pass to ray.init for Ray training. Default is None."}, metadata={"help": "The arguments to pass to ray.init for Ray training. Default is None."},
) )

View File

@@ -15,7 +15,6 @@
 import os
 from typing import TYPE_CHECKING, Any, Optional, TypedDict
-import torch
 from transformers import (
     AutoConfig,
     AutoModelForCausalLM,
@@ -158,6 +157,7 @@ def load_model(
     if model is None and not lazy_load:
         init_kwargs["config"] = config
         init_kwargs["pretrained_model_name_or_path"] = model_args.model_name_or_path
+        init_kwargs["torch_dtype"] = "auto"
         if model_args.mixture_of_depths == "load":
             model = load_mod_pretrained_model(**init_kwargs)
@@ -205,10 +205,6 @@ def load_model(
     if not is_trainable:
         model.requires_grad_(False)
-        for param in model.parameters():
-            if param.data.dtype == torch.float32 and model_args.compute_dtype != torch.float32:
-                param.data = param.data.to(model_args.compute_dtype)
         model.eval()
     else:
         model.train()
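Loading now passes `torch_dtype="auto"` up front instead of casting float32 parameters afterwards, so weights are created directly in the dtype recorded with the checkpoint. A hedged sketch of the resulting call (the model identifier is only an example):

```python
from transformers import AutoModelForCausalLM

# "auto" tells transformers to use the dtype stored with the checkpoint
# rather than force-casting parameters after loading.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B",  # example id, not taken from the diff
    torch_dtype="auto",
    low_cpu_mem_usage=True,
)
```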

View File

@@ -20,9 +20,10 @@
import inspect import inspect
import os import os
from collections.abc import Callable
from functools import WRAPPER_ASSIGNMENTS, partial, wraps from functools import WRAPPER_ASSIGNMENTS, partial, wraps
from types import MethodType from types import MethodType
from typing import TYPE_CHECKING, Any, Callable, Optional, Union from typing import TYPE_CHECKING, Any, Optional, Union
import torch import torch

View File

@@ -156,11 +156,8 @@ def patch_config(
     # deepspeed zero3 is not compatible with low_cpu_mem_usage
     init_kwargs["low_cpu_mem_usage"] = model_args.low_cpu_mem_usage and (not is_deepspeed_zero3_enabled())
-    # do not cast data type of the model deepspeed zero3 without qlora
-    if not (is_deepspeed_zero3_enabled() and model_args.quantization_bit is None):
-        init_kwargs["torch_dtype"] = model_args.compute_dtype
-    if init_kwargs["low_cpu_mem_usage"] and not is_fsdp_enabled():  # fsdp does not need device map
+    # fsdp/deepspeed zero3 does not need device map
+    if not (is_deepspeed_zero3_enabled() or is_fsdp_enabled()) and init_kwargs["low_cpu_mem_usage"]:
         if "device_map" not in init_kwargs and model_args.device_map:
             init_kwargs["device_map"] = model_args.device_map  # device map requires low_cpu_mem_usage=True
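After this change no explicit compute dtype is written into `init_kwargs`; only the device map is set, and only when neither ZeRO-3 nor FSDP is managing parameter placement itself. A simplified, hypothetical restatement of the new gating (the real code reads these flags from `model_args` and the accelerate helpers):

```python
def build_init_kwargs(low_cpu_mem_usage: bool, device_map, zero3: bool, fsdp: bool) -> dict:
    # Sketch only: mirrors the gating order, not the full patch_config function.
    init_kwargs = {"low_cpu_mem_usage": low_cpu_mem_usage and not zero3}
    if not (zero3 or fsdp) and init_kwargs["low_cpu_mem_usage"]:
        if device_map:
            init_kwargs["device_map"] = device_map  # device map requires low_cpu_mem_usage=True
    return init_kwargs
```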

View File

@@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional from typing import TYPE_CHECKING
import numpy as np import numpy as np
@@ -28,7 +28,7 @@ if TYPE_CHECKING:
class ComputeAccuracy: class ComputeAccuracy:
r"""Compute reward accuracy and support `batch_eval_metrics`.""" r"""Compute reward accuracy and support `batch_eval_metrics`."""
def _dump(self) -> Optional[dict[str, float]]: def _dump(self) -> dict[str, float] | None:
result = None result = None
if hasattr(self, "score_dict"): if hasattr(self, "score_dict"):
result = {k: float(np.mean(v)) for k, v in self.score_dict.items()} result = {k: float(np.mean(v)) for k, v in self.score_dict.items()}
@@ -39,7 +39,7 @@ class ComputeAccuracy:
def __post_init__(self): def __post_init__(self):
self._dump() self._dump()
def __call__(self, eval_preds: "EvalPrediction", compute_result: bool = True) -> Optional[dict[str, float]]: def __call__(self, eval_preds: "EvalPrediction", compute_result: bool = True) -> dict[str, float] | None:
chosen_scores, rejected_scores = numpify(eval_preds.predictions[0]), numpify(eval_preds.predictions[1]) chosen_scores, rejected_scores = numpify(eval_preds.predictions[0]), numpify(eval_preds.predictions[1])
if not chosen_scores.shape: if not chosen_scores.shape:
self.score_dict["accuracy"].append(chosen_scores > rejected_scores) self.score_dict["accuracy"].append(chosen_scores > rejected_scores)

View File

@@ -84,8 +84,6 @@ def load_reference_model(
     model: AutoModelForCausalLMWithValueHead = AutoModelForCausalLMWithValueHead.from_pretrained(
         model_path, torch_dtype=torch.float16, device_map="auto"
     )
-    if not is_trainable:
-        model.v_head = model.v_head.to(torch.float16)
     return model

View File

@@ -19,9 +19,9 @@
import json import json
import os import os
from collections.abc import Mapping from collections.abc import Callable, Mapping
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Optional, Union from typing import TYPE_CHECKING, Any, Optional, Union
import torch import torch
from transformers import Trainer from transformers import Trainer

View File

@@ -25,10 +25,11 @@ Including:
""" """
import os import os
from collections.abc import Callable
from contextlib import contextmanager from contextlib import contextmanager
from enum import Enum, unique from enum import Enum, unique
from functools import lru_cache, wraps from functools import lru_cache, wraps
from typing import Callable, Optional from typing import Optional
import numpy as np import numpy as np
import torch import torch

View File

@@ -53,9 +53,9 @@ class DistributedStrategy:
mp_replicate_size: int = 1 mp_replicate_size: int = 1
"""Model parallel replicate size, default to 1.""" """Model parallel replicate size, default to 1."""
mp_shard_size: Optional[int] = None mp_shard_size: int | None = None
"""Model parallel shard size, default to world_size // mp_replicate_size.""" """Model parallel shard size, default to world_size // mp_replicate_size."""
dp_size: Optional[int] = None dp_size: int | None = None
"""Data parallel size, default to world_size // cp_size.""" """Data parallel size, default to world_size // cp_size."""
cp_size: int = 1 cp_size: int = 1
"""Context parallel size, default to 1.""" """Context parallel size, default to 1."""
@@ -115,7 +115,7 @@ class DistributedInterface:
return cls._instance return cls._instance
def __init__(self, config: Optional[DistributedConfig] = None) -> None: def __init__(self, config: DistributedConfig | None = None) -> None:
if self._initialized: if self._initialized:
return return
@@ -165,7 +165,7 @@ class DistributedInterface:
f"model_device_mesh={self.model_device_mesh}, data_device_mesh={self.data_device_mesh}" f"model_device_mesh={self.model_device_mesh}, data_device_mesh={self.data_device_mesh}"
) )
def get_device_mesh(self, dim: Optional[Dim] = None) -> Optional[DeviceMesh]: def get_device_mesh(self, dim: Dim | None = None) -> DeviceMesh | None:
"""Get device mesh for specified dimension.""" """Get device mesh for specified dimension."""
if dim is None: if dim is None:
raise ValueError("dim must be specified.") raise ValueError("dim must be specified.")
@@ -176,14 +176,14 @@ class DistributedInterface:
else: else:
return self.model_device_mesh[dim.value] return self.model_device_mesh[dim.value]
def get_group(self, dim: Optional[Dim] = None) -> Optional[ProcessGroup]: def get_group(self, dim: Dim | None = None) -> Optional[ProcessGroup]:
"""Get process group for specified dimension.""" """Get process group for specified dimension."""
if self.model_device_mesh is None or dim is None: if self.model_device_mesh is None or dim is None:
return None return None
else: else:
return self.get_device_mesh(dim).get_group() return self.get_device_mesh(dim).get_group()
def get_rank(self, dim: Optional[Dim] = None) -> int: def get_rank(self, dim: Dim | None = None) -> int:
"""Get parallel rank for specified dimension.""" """Get parallel rank for specified dimension."""
if self.model_device_mesh is None: if self.model_device_mesh is None:
return 0 return 0
@@ -192,7 +192,7 @@ class DistributedInterface:
else: else:
return self.get_device_mesh(dim).get_local_rank() return self.get_device_mesh(dim).get_local_rank()
def get_world_size(self, dim: Optional[Dim] = None) -> int: def get_world_size(self, dim: Dim | None = None) -> int:
"""Get parallel size for specified dimension.""" """Get parallel size for specified dimension."""
if self.model_device_mesh is None: if self.model_device_mesh is None:
return 1 return 1
@@ -209,7 +209,7 @@ class DistributedInterface:
"""Get parallel local world size.""" """Get parallel local world size."""
return self._local_world_size return self._local_world_size
def all_gather(self, data: Tensor, dim: Optional[Dim] = Dim.DP) -> Tensor: def all_gather(self, data: Tensor, dim: Dim | None = Dim.DP) -> Tensor:
"""Gather tensor across specified parallel group.""" """Gather tensor across specified parallel group."""
if self.model_device_mesh is not None: if self.model_device_mesh is not None:
return helper.operate_tensorlike(helper.all_gather, data, group=self.get_group(dim)) return helper.operate_tensorlike(helper.all_gather, data, group=self.get_group(dim))
@@ -217,7 +217,7 @@ class DistributedInterface:
return data return data
def all_reduce( def all_reduce(
self, data: TensorLike, op: helper.ReduceOp = helper.ReduceOp.MEAN, dim: Optional[Dim] = Dim.DP self, data: TensorLike, op: helper.ReduceOp = helper.ReduceOp.MEAN, dim: Dim | None = Dim.DP
) -> TensorLike: ) -> TensorLike:
"""Reduce tensor across specified parallel group.""" """Reduce tensor across specified parallel group."""
if self.model_device_mesh is not None: if self.model_device_mesh is not None:
@@ -225,7 +225,7 @@ class DistributedInterface:
else: else:
return data return data
def broadcast(self, data: TensorLike, src: int = 0, dim: Optional[Dim] = Dim.DP) -> TensorLike: def broadcast(self, data: TensorLike, src: int = 0, dim: Dim | None = Dim.DP) -> TensorLike:
"""Broadcast tensor across specified parallel group.""" """Broadcast tensor across specified parallel group."""
if self.model_device_mesh is not None: if self.model_device_mesh is not None:
return helper.operate_tensorlike(helper.broadcast, data, src=src, group=self.get_group(dim)) return helper.operate_tensorlike(helper.broadcast, data, src=src, group=self.get_group(dim))
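`DistributedInterface` resolves groups, ranks and sizes from a named `DeviceMesh`, so collectives such as `all_gather` and `broadcast` can target a single parallel dimension. A minimal sketch of the underlying PyTorch API it builds on (the mesh shape and dimension names are illustrative and assume a torchrun launch):

```python
from torch.distributed.device_mesh import init_device_mesh

# Illustrative 2-D mesh: 2-way data parallel x 4-way model parallel (8 ranks total).
mesh = init_device_mesh("cuda", (2, 4), mesh_dim_names=("dp", "mp"))
dp_group = mesh["dp"].get_group()      # process group for the data-parallel dimension
dp_rank = mesh["dp"].get_local_rank()  # this rank's index within that dimension
dp_size = mesh["dp"].size()            # number of ranks along that dimension
```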

View File

@@ -15,7 +15,7 @@
import json import json
import sys import sys
from pathlib import Path from pathlib import Path
from typing import Any, Optional, Union from typing import Any
from omegaconf import OmegaConf from omegaconf import OmegaConf
from transformers import HfArgumentParser from transformers import HfArgumentParser
@@ -27,7 +27,7 @@ from .sample_args import SampleArguments
from .training_args import TrainingArguments from .training_args import TrainingArguments
InputArgument = Optional[Union[dict[str, Any], list[str]]] InputArgument = dict[str, Any] | list[str] | None
def validate_args( def validate_args(

View File

@@ -18,7 +18,6 @@
import json import json
from enum import Enum, unique from enum import Enum, unique
from typing import Optional, Union
class PluginConfig(dict): class PluginConfig(dict):
@@ -33,7 +32,7 @@ class PluginConfig(dict):
return self["name"] return self["name"]
PluginArgument = Optional[Union[PluginConfig, dict, str]] PluginArgument = PluginConfig | dict | str | None
@unique @unique
@@ -74,7 +73,7 @@ def _convert_str_dict(data: dict) -> dict:
return data return data
def get_plugin_config(config: PluginArgument) -> Optional[PluginConfig]: def get_plugin_config(config: PluginArgument) -> PluginConfig | None:
"""Get the plugin configuration from the argument value. """Get the plugin configuration from the argument value.
Args: Args:

View File

@@ -14,12 +14,11 @@
from dataclasses import dataclass, field
-from typing import Optional


@dataclass
class DataArguments:
-    dataset: Optional[str] = field(
+    dataset: str | None = field(
        default=None,
        metadata={"help": "Path to the dataset."},
    )
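The same replacement carries over to dataclass field annotations. A minimal, self-contained sketch (the class name and field are illustrative, not the project's actual hyperparameters):

from dataclasses import dataclass, field


@dataclass
class ExampleArguments:  # hypothetical example, mirroring the style of DataArguments
    dataset: str | None = field(
        default=None,
        metadata={"help": "Path to the dataset."},
    )


print(ExampleArguments().dataset)        # None
print(ExampleArguments(dataset="demo"))  # ExampleArguments(dataset='demo')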

View File

@@ -14,7 +14,6 @@
from dataclasses import dataclass, field
-from typing import Optional

from .arg_utils import ModelClass, PluginConfig, get_plugin_config
@@ -36,15 +35,15 @@ class ModelArguments:
        default=ModelClass.LLM,
        metadata={"help": "Model class from Hugging Face."},
    )
-    peft_config: Optional[PluginConfig] = field(
+    peft_config: PluginConfig | None = field(
        default=None,
        metadata={"help": "PEFT configuration for the model."},
    )
-    kernel_config: Optional[PluginConfig] = field(
+    kernel_config: PluginConfig | None = field(
        default=None,
        metadata={"help": "Kernel configuration for the model."},
    )
-    quant_config: Optional[PluginConfig] = field(
+    quant_config: PluginConfig | None = field(
        default=None,
        metadata={"help": "Quantization configuration for the model."},
    )

View File

@@ -14,7 +14,6 @@
import os
from dataclasses import dataclass, field
-from typing import Optional
from uuid import uuid4

from .arg_utils import PluginConfig, get_plugin_config
@@ -42,7 +41,7 @@ class TrainingArguments:
        default=False,
        metadata={"help": "Use bf16 for training."},
    )
-    dist_config: Optional[PluginConfig] = field(
+    dist_config: PluginConfig | None = field(
        default=None,
        metadata={"help": "Distribution configuration for training."},
    )

View File

@@ -27,7 +27,7 @@ Get Data Sample:
import os
from collections.abc import Iterable
-from typing import Any, Union
+from typing import Any

from huggingface_hub import hf_hub_download
from omegaconf import OmegaConf
@@ -134,7 +134,7 @@ class DataEngine(Dataset):
        else:
            return len(self.data_index)

-    def __getitem__(self, index: Union[int, Any]) -> Union[Sample, list[Sample]]:
+    def __getitem__(self, index: int | Any) -> Sample | list[Sample]:
        """Get dataset item.

        Args:

View File

@@ -13,9 +13,7 @@
# limitations under the License.

-from typing import Any, Literal, TypedDict
-
-from typing_extensions import NotRequired
+from typing import Any, Literal, NotRequired, TypedDict

from ...utils import logging
from ...utils.plugin import BasePlugin
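`NotRequired` is now imported from the standard-library `typing` module; it has lived there since Python 3.11, with `typing_extensions` only needed as a backport on older interpreters. A minimal sketch of its use, not repository code:

from typing import NotRequired, TypedDict  # in the stdlib since Python 3.11


class MessageDict(TypedDict):  # hypothetical example type
    role: str
    content: str
    name: NotRequired[str]  # the "name" key may be omitted


msg: MessageDict = {"role": "user", "content": "hi"}  # valid without "name"
print(msg)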

View File

@@ -15,7 +15,7 @@
import os
import random
-from typing import Any, Literal, Optional, Union
+from typing import Any, Literal

from datasets import load_dataset
@@ -70,7 +70,7 @@ class DataIndexPlugin(BasePlugin):
    """Plugin for adjusting dataset index."""

    def adjust_data_index(
-        self, data_index: list[tuple[str, int]], size: Optional[int], weight: Optional[float]
+        self, data_index: list[tuple[str, int]], size: int | None, weight: float | None
    ) -> list[tuple[str, int]]:
        """Adjust dataset index by size and weight.
@@ -95,8 +95,8 @@ class DataSelectorPlugin(BasePlugin):
    """Plugin for selecting dataset samples."""

    def select(
-        self, data_index: list[tuple[str, int]], index: Union[slice, list[int], Any]
-    ) -> Union[tuple[str, int], list[tuple[str, int]]]:
+        self, data_index: list[tuple[str, int]], index: slice | list[int] | Any
+    ) -> tuple[str, int] | list[tuple[str, int]]:
        """Select dataset samples.

        Args:

View File

@@ -14,7 +14,6 @@
from dataclasses import dataclass
-from typing import Union


@dataclass
@@ -32,7 +31,7 @@ class QwenTemplate:
    message_template: str = "<|im_start|>{role}\n{content}<|im_end|>\n"  # FIXME if role: tool
    thinking_template: str = "<think>\n{content}\n</think>\n\n"

-    def _extract_content(self, content_data: Union[str, list[dict[str, str]]]) -> str:
+    def _extract_content(self, content_data: str | list[dict[str, str]]) -> str:
        if isinstance(content_data, str):
            return content_data.strip()
@@ -47,7 +46,7 @@ class QwenTemplate:
        return ""

-    def render_message(self, message: dict[str, Union[str, list[dict[str, str]]]]) -> str:
+    def render_message(self, message: dict[str, str | list[dict[str, str]]]) -> str:
        role = message["role"]
        content = self._extract_content(message.get("content", ""))

View File

@@ -13,7 +13,8 @@
# limitations under the License.

from abc import ABC, ABCMeta, abstractmethod
-from typing import Any, Callable, Optional, Union
+from collections.abc import Callable
+from typing import Any, Optional

from ....accelerator.helper import DeviceType, get_current_accelerator
from ....utils.types import HFModel
@@ -38,7 +39,7 @@ class KernelRegistry:
        self._initialized = True

    def register(
-        self, kernel_type: KernelType, device_type: DeviceType, kernel_impl: Optional[Callable[..., Any]]
+        self, kernel_type: KernelType, device_type: DeviceType, kernel_impl: Callable[..., Any] | None
    ) -> None:
        """Register a kernel implementation.
@@ -56,7 +57,7 @@ class KernelRegistry:
        self._registry[kernel_type][device_type] = kernel_impl
        print(f"Registered kernel {kernel_type.name} for device {device_type.name}.")

-    def get_kernel(self, kernel_type: KernelType, device_type: DeviceType) -> Optional[Callable[..., Any]]:
+    def get_kernel(self, kernel_type: KernelType, device_type: DeviceType) -> Callable[..., Any] | None:
        return self._registry.get(kernel_type, {}).get(device_type)
@@ -105,9 +106,9 @@ class MetaKernel(ABC, metaclass=AutoRegisterKernelMeta):
        auto_register: Set to False to disable automatic registration (default: True).
    """

-    type: Optional[KernelType] = None
-    device: Optional[DeviceType] = None
-    kernel: Optional[Callable] = None
+    type: KernelType | None = None
+    device: DeviceType | None = None
+    kernel: Callable | None = None

    @classmethod
    @abstractmethod
@@ -228,7 +229,7 @@ def discover_kernels(model: HFModel = None) -> list[type[MetaKernel]]:
    return discovered_kernels


-def apply_kernel(model: HFModel, kernel: Union[type[MetaKernel], Any], /, **kwargs) -> "HFModel":
+def apply_kernel(model: HFModel, kernel: type[MetaKernel] | Any, /, **kwargs) -> "HFModel":
    """Call the MetaKernel's `apply` to perform the replacement.

    Corresponding replacement logic is maintained inside each kernel; the only
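This file's hunks also move `Callable` from `typing` to `collections.abc`, which is the preferred location for new code now that the `typing` aliases are deprecated. A minimal sketch of an optional-callable registry in that style (names are illustrative, not the project's):

from collections.abc import Callable
from typing import Any

_registry: dict[str, Callable[..., Any] | None] = {}


def register(name: str, impl: Callable[..., Any] | None) -> None:
    """Store a kernel-like implementation (or None) under a name."""
    _registry[name] = impl


register("identity", lambda x: x)
register("missing", None)
print(_registry["identity"](42))     # 42
print(_registry["missing"] is None)  # True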

View File

@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-from typing import Literal, Optional, TypedDict
+from typing import Literal, TypedDict

from peft import LoraConfig, PeftModel, get_peft_model
@@ -36,7 +36,7 @@ class FreezeConfigDict(TypedDict, total=False):
    """Plugin name."""

    freeze_trainable_layers: int
    """Freeze trainable layers."""

-    freeze_trainable_modules: Optional[list[str]]
+    freeze_trainable_modules: list[str] | None
    """Freeze trainable modules."""

View File

@@ -16,7 +16,6 @@
# limitations under the License.

from contextlib import contextmanager
-from typing import Union

import torch
from transformers.utils import is_torch_bf16_available_on_device, is_torch_fp16_available_on_device
@@ -38,7 +37,7 @@ class DtypeInterface:
    _is_fp32_available = True

    @staticmethod
-    def is_available(precision: Union[str, torch.dtype]) -> bool:
+    def is_available(precision: str | torch.dtype) -> bool:
        if precision in DtypeRegistry.HALF_LIST:
            return DtypeInterface._is_fp16_available
        elif precision in DtypeRegistry.FLOAT_LIST:
@@ -49,19 +48,19 @@ class DtypeInterface:
            raise RuntimeError(f"Unexpected precision: {precision}")

    @staticmethod
-    def is_fp16(precision: Union[str, torch.dtype]) -> bool:
+    def is_fp16(precision: str | torch.dtype) -> bool:
        return precision in DtypeRegistry.HALF_LIST

    @staticmethod
-    def is_fp32(precision: Union[str, torch.dtype]) -> bool:
+    def is_fp32(precision: str | torch.dtype) -> bool:
        return precision in DtypeRegistry.FLOAT_LIST

    @staticmethod
-    def is_bf16(precision: Union[str, torch.dtype]) -> bool:
+    def is_bf16(precision: str | torch.dtype) -> bool:
        return precision in DtypeRegistry.BFLOAT_LIST

    @staticmethod
-    def to_dtype(precision: Union[str, torch.dtype]) -> torch.dtype:
+    def to_dtype(precision: str | torch.dtype) -> torch.dtype:
        if precision in DtypeRegistry.HALF_LIST:
            return torch.float16
        elif precision in DtypeRegistry.FLOAT_LIST:
@@ -83,7 +82,7 @@ class DtypeInterface:
            raise RuntimeError(f"Unexpected precision: {precision}")

    @contextmanager
-    def set_dtype(self, precision: Union[str, torch.dtype]):
+    def set_dtype(self, precision: str | torch.dtype):
        original_dtype = torch.get_default_dtype()
        torch.set_default_dtype(self.to_dtype(precision))
        try:
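A practical side effect of the `str | torch.dtype` style used above: on Python 3.10+ a PEP 604 union is a real runtime object, so it can also be passed to `isinstance`. A small sketch with plain built-in types (no torch dependency), purely illustrative:

def normalize_precision(precision: str | int) -> str:
    # isinstance accepts a PEP 604 union directly on Python >= 3.10
    if not isinstance(precision, str | int):
        raise RuntimeError(f"Unexpected precision: {precision}")
    return str(precision)


print(normalize_precision("bf16"))  # bf16
print(normalize_precision(16))      # 16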

View File

@@ -81,7 +81,7 @@ def _configure_library_root_logger() -> None:
    library_root_logger.propagate = False


-def get_logger(name: Optional[str] = None) -> "_Logger":
+def get_logger(name: str | None = None) -> "_Logger":
    """Return a logger with the specified name. It it not supposed to be accessed externally."""
    if name is None:
        name = _get_library_name()

View File

@@ -13,7 +13,7 @@
# limitations under the License.

-from typing import Callable, Optional
+from collections.abc import Callable

from . import logging
@@ -29,7 +29,7 @@ class BasePlugin:
    _registry: dict[str, Callable] = {}

-    def __init__(self, name: Optional[str] = None):
+    def __init__(self, name: str | None = None):
        """Initialize the plugin with a name.

        Args:

View File

@@ -12,9 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-from typing import TYPE_CHECKING, Literal, TypedDict, Union
-
-from typing_extensions import NotRequired
+from typing import TYPE_CHECKING, Literal, NotRequired, TypedDict, Union

if TYPE_CHECKING:

View File

@@ -16,7 +16,7 @@ import json
import os
from collections.abc import Generator
from contextlib import contextmanager
-from typing import TYPE_CHECKING, Any, Optional
+from typing import TYPE_CHECKING, Any

from transformers.utils import is_torch_npu_available
@@ -81,7 +81,7 @@ class WebChatModel(ChatModel):
    def __init__(self, manager: "Manager", demo_mode: bool = False, lazy_init: bool = True) -> None:
        self.manager = manager
        self.demo_mode = demo_mode
-        self.engine: Optional[BaseEngine] = None
+        self.engine: BaseEngine | None = None

        if not lazy_init:  # read arguments from command line
            super().__init__()
@@ -197,9 +197,9 @@ class WebChatModel(ChatModel):
        lang: str,
        system: str,
        tools: str,
-        image: Optional[Any],
-        video: Optional[Any],
-        audio: Optional[Any],
+        image: Any | None,
+        video: Any | None,
+        audio: Any | None,
        max_new_tokens: int,
        top_p: float,
        temperature: float,

View File

@@ -17,7 +17,7 @@ import os
import signal
from collections import defaultdict
from datetime import datetime
-from typing import Any, Optional, Union
+from typing import Any

from psutil import Process
from yaml import safe_dump, safe_load
@@ -71,7 +71,7 @@ def _get_config_path() -> os.PathLike:
    return os.path.join(DEFAULT_CACHE_DIR, USER_CONFIG)


-def load_config() -> dict[str, Union[str, dict[str, Any]]]:
+def load_config() -> dict[str, str | dict[str, Any]]:
    r"""Load user config if exists."""
    try:
        with open(_get_config_path(), encoding="utf-8") as f:
@@ -81,7 +81,7 @@ def load_config() -> dict[str, Union[str, dict[str, Any]]]:


def save_config(
-    lang: str, hub_name: Optional[str] = None, model_name: Optional[str] = None, model_path: Optional[str] = None
+    lang: str, hub_name: str | None = None, model_name: str | None = None, model_path: str | None = None
) -> None:
    r"""Save user config."""
    os.makedirs(DEFAULT_CACHE_DIR, exist_ok=True)
@@ -151,7 +151,7 @@ def load_dataset_info(dataset_dir: str) -> dict[str, dict[str, Any]]:
    return {}


-def load_args(config_path: str) -> Optional[dict[str, Any]]:
+def load_args(config_path: str) -> dict[str, Any] | None:
    r"""Load the training configuration from config path."""
    try:
        with open(config_path, encoding="utf-8") as f:

View File

@@ -14,7 +14,7 @@
import json
from collections.abc import Generator
-from typing import TYPE_CHECKING, Union
+from typing import TYPE_CHECKING

from ...extras.constants import PEFT_METHODS
from ...extras.misc import torch_gc
@@ -37,7 +37,7 @@ if TYPE_CHECKING:
GPTQ_BITS = ["8", "4", "3", "2"]


-def can_quantize(checkpoint_path: Union[str, list[str]]) -> "gr.Dropdown":
+def can_quantize(checkpoint_path: str | list[str]) -> "gr.Dropdown":
    if isinstance(checkpoint_path, list) and len(checkpoint_path) != 0:
        return gr.Dropdown(value="none", interactive=False)
    else:
@@ -49,7 +49,7 @@ def save_model(
    model_name: str,
    model_path: str,
    finetuning_type: str,
-    checkpoint_path: Union[str, list[str]],
+    checkpoint_path: str | list[str],
    template: str,
    export_size: int,
    export_quantization_bit: str,

View File

@@ -14,7 +14,7 @@
import json
import os
-from typing import Any, Optional
+from typing import Any

from transformers.trainer_utils import get_last_checkpoint
@@ -206,7 +206,7 @@ def list_datasets(dataset_dir: str = None, training_stage: str = list(TRAINING_S
    return gr.Dropdown(choices=datasets)


-def list_output_dirs(model_name: Optional[str], finetuning_type: str, current_time: str) -> "gr.Dropdown":
+def list_output_dirs(model_name: str | None, finetuning_type: str, current_time: str) -> "gr.Dropdown":
    r"""List all the directories that can resume from.

    Inputs: top.model_name, top.finetuning_type, train.current_time

View File

@@ -35,35 +35,40 @@ LOCALES = {
        "value": (
            "<h3><center>Visit <a href='https://github.com/hiyouga/LLaMA-Factory' target='_blank'>"
            "GitHub Page</a> <a href='https://llamafactory.readthedocs.io/en/latest/' target='_blank'>"
-            "Documentation</a></center></h3>"
+            "Documentation</a> <a href='https://blog.llamafactory.net/en/' target='_blank'>"
+            "Blog</a></center></h3>"
        ),
    },
    "ru": {
        "value": (
            "<h3><center>Посетить <a href='https://github.com/hiyouga/LLaMA-Factory' target='_blank'>"
            "страницу GitHub</a> <a href='https://llamafactory.readthedocs.io/en/latest/' target='_blank'>"
-            "Документацию</a></center></h3>"
+            "Документацию</a> <a href='https://blog.llamafactory.net/en/' target='_blank'>"
+            "Блог</a></center></h3>"
        ),
    },
    "zh": {
        "value": (
            "<h3><center>访问 <a href='https://github.com/hiyouga/LLaMA-Factory' target='_blank'>"
            "GitHub 主页</a> <a href='https://llamafactory.readthedocs.io/zh-cn/latest/' target='_blank'>"
-            "官方文档</a></center></h3>"
+            "官方文档</a> <a href='https://blog.llamafactory.net/' target='_blank'>"
+            "博客</a></center></h3>"
        ),
    },
    "ko": {
        "value": (
            "<h3><center><a href='https://github.com/hiyouga/LLaMA-Factory' target='_blank'>"
            "GitHub 페이지</a> <a href='https://llamafactory.readthedocs.io/en/latest/' target='_blank'>"
-            "공식 문서</a>를 방문하세요.</center></h3>"
+            "공식 문서</a> <a href='https://blog.llamafactory.net/en/' target='_blank'>"
+            "블로그</a>를 방문하세요.</center></h3>"
        ),
    },
    "ja": {
        "value": (
            "<h3><center><a href='https://github.com/hiyouga/LLaMA-Factory' target='_blank'>"
            "GitHub ページ</a> <a href='https://llamafactory.readthedocs.io/en/latest/' target='_blank'>"
-            "ドキュメント</a>にアクセスする</center></h3>"
+            "ドキュメント</a> <a href='https://blog.llamafactory.net/en/' target='_blank'>"
+            "ブログ</a>にアクセスする</center></h3>"
        ),
    },
},

View File

@@ -17,7 +17,7 @@ import os
from collections.abc import Generator
from copy import deepcopy
from subprocess import PIPE, Popen, TimeoutExpired
-from typing import TYPE_CHECKING, Any, Optional
+from typing import TYPE_CHECKING, Any

from transformers.utils import is_torch_npu_available
@@ -59,7 +59,7 @@ class Runner:
        self.manager = manager
        self.demo_mode = demo_mode
        """ Resume """
-        self.trainer: Optional[Popen] = None
+        self.trainer: Popen | None = None
        self.do_train = True
        self.running_data: dict[Component, Any] = None
        """ State """

View File

@@ -18,7 +18,6 @@ Contains shared fixtures, pytest configuration, and custom markers.
"""

import os
-from typing import Optional

import pytest
from pytest import Config, FixtureRequest, Item, MonkeyPatch
@@ -71,7 +70,7 @@ def _handle_slow_tests(items: list[Item]):
            item.add_marker(skip_slow)


-def _get_visible_devices_env() -> Optional[str]:
+def _get_visible_devices_env() -> str | None:
    """Return device visibility env var name."""
    if CURRENT_DEVICE == "cuda":
        return "CUDA_VISIBLE_DEVICES"

View File

@@ -18,7 +18,6 @@ Contains shared fixtures, pytest configuration, and custom markers.
"""

import os
-from typing import Optional

import pytest
from pytest import Config, FixtureRequest, Item, MonkeyPatch
@@ -71,7 +70,7 @@ def _handle_slow_tests(items: list[Item]):
            item.add_marker(skip_slow)


-def _get_visible_devices_env() -> Optional[str]:
+def _get_visible_devices_env() -> str | None:
    """Return device visibility env var name."""
    if CURRENT_DEVICE == "cuda":
        return "CUDA_VISIBLE_DEVICES"