mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2026-02-26 15:56:00 +08:00
Compare commits
14 Commits
a1b1931b4a
...
v0.9.4
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
95ac3f2373 | ||
|
|
000526908a | ||
|
|
c8d7e85b3e | ||
|
|
16735b9e35 | ||
|
|
4e1d69579a | ||
|
|
1857fbdd6b | ||
|
|
bb1ba31005 | ||
|
|
e97d0474fb | ||
|
|
3f0c3dc84d | ||
|
|
c107cc22d0 | ||
|
|
7ef1fba34a | ||
|
|
eceec8ab69 | ||
|
|
b44f651e09 | ||
|
|
55590f5ece |
20
.github/workflows/docker.yml
vendored
20
.github/workflows/docker.yml
vendored
@@ -29,16 +29,13 @@ jobs:
|
||||
matrix:
|
||||
include:
|
||||
- device: "cuda"
|
||||
npu_type: ""
|
||||
- device: "npu"
|
||||
npu_type: "a2"
|
||||
- device: "npu"
|
||||
npu_type: "a3"
|
||||
- device: "npu-a2"
|
||||
- device: "npu-a3"
|
||||
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}-${{ matrix.device }}-${{ matrix.npu_type }}
|
||||
group: ${{ github.workflow }}-${{ github.ref }}-${{ matrix.device }}
|
||||
cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
|
||||
|
||||
environment:
|
||||
@@ -55,11 +52,6 @@ jobs:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.10"
|
||||
|
||||
- name: Get llamafactory version
|
||||
id: version
|
||||
run: |
|
||||
@@ -80,7 +72,7 @@ jobs:
|
||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
|
||||
- name: Login to Quay
|
||||
if: ${{ github.event_name != 'pull_request' && matrix.device == 'npu'}}
|
||||
if: ${{ github.event_name != 'pull_request' && startsWith(matrix.device, 'npu') }}
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
registry: quay.io
|
||||
@@ -98,7 +90,7 @@ jobs:
|
||||
docker.io/hiyouga/llamafactory:${{ steps.version.outputs.tag }}
|
||||
|
||||
- name: Build and push Docker image (NPU-A2)
|
||||
if: ${{ matrix.device == 'npu' && matrix.npu_type == 'a2' }}
|
||||
if: ${{ matrix.device == 'npu-a2' }}
|
||||
uses: docker/build-push-action@v6
|
||||
with:
|
||||
context: .
|
||||
@@ -110,7 +102,7 @@ jobs:
|
||||
quay.io/ascend/llamafactory:${{ steps.version.outputs.tag }}-npu-a2
|
||||
|
||||
- name: Build and push Docker image (NPU-A3)
|
||||
if: ${{ matrix.device == 'npu' && matrix.npu_type == 'a3' }}
|
||||
if: ${{ matrix.device == 'npu-a3' }}
|
||||
uses: docker/build-push-action@v6
|
||||
with:
|
||||
context: .
|
||||
|
||||
7
.github/workflows/publish.yml
vendored
7
.github/workflows/publish.yml
vendored
@@ -23,10 +23,11 @@ jobs:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v7
|
||||
with:
|
||||
python-version: "3.9"
|
||||
python-version: "3.11"
|
||||
github-token: ${{ github.token }}
|
||||
|
||||
- name: Build package
|
||||
run: |
|
||||
|
||||
42
.github/workflows/tests.yml
vendored
42
.github/workflows/tests.yml
vendored
@@ -25,29 +25,25 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python:
|
||||
- "3.9"
|
||||
- "3.10"
|
||||
- "3.11"
|
||||
- "3.12"
|
||||
- "3.13"
|
||||
os:
|
||||
- "ubuntu-latest"
|
||||
- "windows-latest"
|
||||
- "macos-latest"
|
||||
transformers:
|
||||
- null
|
||||
- ""
|
||||
include: # test backward compatibility
|
||||
- python: "3.9"
|
||||
os: "ubuntu-latest"
|
||||
transformers: "4.49.0"
|
||||
- python: "3.9"
|
||||
- python: "3.11"
|
||||
os: "ubuntu-latest"
|
||||
transformers: "4.51.0"
|
||||
- python: "3.9"
|
||||
- python: "3.11"
|
||||
os: "ubuntu-latest"
|
||||
transformers: "4.53.0"
|
||||
exclude: # exclude python 3.9 on macos
|
||||
- python: "3.9"
|
||||
os: "macos-latest"
|
||||
- python: "3.11"
|
||||
os: "ubuntu-latest"
|
||||
transformers: "4.55.0"
|
||||
|
||||
runs-on: ${{ matrix.os }}
|
||||
|
||||
@@ -63,24 +59,23 @@ jobs:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v7
|
||||
with:
|
||||
python-version: ${{ matrix.python }}
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v5
|
||||
github-token: ${{ github.token }}
|
||||
enable-cache: false
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
uv pip install --system torch torchvision --index-url https://download.pytorch.org/whl/cpu
|
||||
uv pip install --system -e "."
|
||||
uv pip install --system -r examples/requirements/dev.txt
|
||||
uv venv
|
||||
uv pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
|
||||
uv pip install -e ".[dev]"
|
||||
|
||||
- name: Install transformers
|
||||
if: ${{ matrix.transformers }}
|
||||
run: |
|
||||
uv pip install --system "transformers==${{ matrix.transformers }}"
|
||||
uv pip install "transformers==${{ matrix.transformers }}"
|
||||
|
||||
- name: Cache files
|
||||
id: hf-hub-cache
|
||||
@@ -92,18 +87,25 @@ jobs:
|
||||
- name: Check quality
|
||||
run: |
|
||||
make style && make quality
|
||||
env:
|
||||
UV_NO_SYNC: 1
|
||||
|
||||
- name: Check license
|
||||
run: |
|
||||
make license
|
||||
env:
|
||||
UV_NO_SYNC: 1
|
||||
|
||||
- name: Check build
|
||||
run: |
|
||||
make build
|
||||
env:
|
||||
UV_NO_SYNC: 1
|
||||
|
||||
- name: Test with pytest
|
||||
run: |
|
||||
make test
|
||||
env:
|
||||
UV_NO_SYNC: 1
|
||||
HF_HOME: ${{ runner.temp }}/huggingface
|
||||
HF_HUB_OFFLINE: "${{ steps.hf-hub-cache.outputs.cache-hit == 'true' && '1' || '0' }}"
|
||||
|
||||
88
.github/workflows/tests_cuda.yml
vendored
Normal file
88
.github/workflows/tests_cuda.yml
vendored
Normal file
@@ -0,0 +1,88 @@
|
||||
name: tests_cuda
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
push:
|
||||
branches:
|
||||
- "main"
|
||||
paths:
|
||||
- "**/*.py"
|
||||
- "pyproject.toml"
|
||||
- "Makefile"
|
||||
- ".github/workflows/*.yml"
|
||||
pull_request:
|
||||
branches:
|
||||
- "main"
|
||||
paths:
|
||||
- "**/*.py"
|
||||
- "pyproject.toml"
|
||||
- "Makefile"
|
||||
- ".github/workflows/*.yml"
|
||||
|
||||
jobs:
|
||||
tests:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python:
|
||||
- "3.11"
|
||||
os:
|
||||
- "linux-x86_64-gpu-2"
|
||||
|
||||
runs-on: ${{ matrix.os }}
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}-${{ matrix.os }}-${{ matrix.python }}
|
||||
cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
|
||||
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v7
|
||||
with:
|
||||
python-version: ${{ matrix.python }}
|
||||
github-token: ${{ github.token }}
|
||||
enable-cache: false
|
||||
|
||||
- name: Check GPU Status
|
||||
run: nvidia-smi
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
uv venv
|
||||
uv pip install -e ".[dev]"
|
||||
|
||||
- name: Cache HuggingFace models
|
||||
id: hf-hub-cache
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ${{ runner.temp }}/huggingface
|
||||
key: hf-cache-${{ runner.os }}-${{ hashFiles('tests/version.txt') }}
|
||||
|
||||
- name: Check quality
|
||||
run: |
|
||||
make style && make quality
|
||||
env:
|
||||
UV_NO_SYNC: 1
|
||||
|
||||
- name: Check license
|
||||
run: |
|
||||
make license
|
||||
env:
|
||||
UV_NO_SYNC: 1
|
||||
|
||||
- name: Check build
|
||||
run: |
|
||||
make build
|
||||
env:
|
||||
UV_NO_SYNC: 1
|
||||
|
||||
- name: Test with pytest
|
||||
run: |
|
||||
make test
|
||||
env:
|
||||
UV_NO_SYNC: 1
|
||||
HF_HOME: ${{ runner.temp }}/huggingface
|
||||
HF_HUB_OFFLINE: "${{ steps.hf-hub-cache.outputs.cache-hit == 'true' && '1' || '0' }}"
|
||||
19
.github/workflows/tests_npu.yml
vendored
19
.github/workflows/tests_npu.yml
vendored
@@ -49,13 +49,17 @@ jobs:
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Install uv
|
||||
run: |
|
||||
curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
uses: astral-sh/setup-uv@v7
|
||||
with:
|
||||
python-version: ${{ matrix.python }}
|
||||
github-token: ${{ github.token }}
|
||||
enable-cache: false
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
uv pip install --system -e "." torch-npu==${{matrix.pytorch_npu}}
|
||||
uv pip install --system -r examples/requirements/dev.txt
|
||||
uv venv
|
||||
uv pip install torch-npu==${{matrix.pytorch_npu}}
|
||||
uv pip install -e ".[dev]"
|
||||
|
||||
- name: Install node
|
||||
run: |
|
||||
@@ -74,18 +78,25 @@ jobs:
|
||||
- name: Check quality
|
||||
run: |
|
||||
make style && make quality
|
||||
env:
|
||||
UV_NO_SYNC: 1
|
||||
|
||||
- name: Check license
|
||||
run: |
|
||||
make license
|
||||
env:
|
||||
UV_NO_SYNC: 1
|
||||
|
||||
- name: Check build
|
||||
run: |
|
||||
make build
|
||||
env:
|
||||
UV_NO_SYNC: 1
|
||||
|
||||
- name: Test with pytest
|
||||
run: |
|
||||
make test
|
||||
env:
|
||||
UV_NO_SYNC: 1
|
||||
HF_HOME: /root/.cache/huggingface
|
||||
HF_HUB_OFFLINE: "${{ steps.hf-hub-cache.outputs.cache-hit == 'true' && '1' || '0' }}"
|
||||
|
||||
22
Makefile
22
Makefile
@@ -2,23 +2,27 @@
|
||||
|
||||
check_dirs := scripts src tests tests_v1
|
||||
|
||||
RUN := $(shell command -v uv >/dev/null 2>&1 && echo "uv run" || echo "")
|
||||
BUILD := $(shell command -v uv >/dev/null 2>&1 && echo "uv build" || echo "python -m build")
|
||||
TOOL := $(shell command -v uv >/dev/null 2>&1 && echo "uvx" || echo "")
|
||||
|
||||
build:
|
||||
uv build
|
||||
$(BUILD)
|
||||
|
||||
commit:
|
||||
uv run pre-commit install
|
||||
uv run pre-commit run --all-files
|
||||
$(TOOL) pre-commit install
|
||||
$(TOOL) pre-commit run --all-files
|
||||
|
||||
license:
|
||||
uv run python tests/check_license.py $(check_dirs)
|
||||
$(RUN) python3 tests/check_license.py $(check_dirs)
|
||||
|
||||
quality:
|
||||
uv run ruff check $(check_dirs)
|
||||
uv run ruff format --check $(check_dirs)
|
||||
$(TOOL) ruff check $(check_dirs)
|
||||
$(TOOL) ruff format --check $(check_dirs)
|
||||
|
||||
style:
|
||||
uv run ruff check $(check_dirs) --fix
|
||||
uv run ruff format $(check_dirs)
|
||||
$(TOOL) ruff check $(check_dirs) --fix
|
||||
$(TOOL) ruff format $(check_dirs)
|
||||
|
||||
test:
|
||||
WANDB_DISABLED=true uv run pytest -vv --import-mode=importlib tests/ tests_v1/
|
||||
WANDB_DISABLED=true $(RUN) pytest -vv --import-mode=importlib tests/ tests_v1/
|
||||
|
||||
30
README.md
30
README.md
@@ -309,6 +309,7 @@ Read technical notes:
|
||||
| [MiMo](https://huggingface.co/XiaomiMiMo) | 7B/309B | mimo/mimo_v2 |
|
||||
| [MiniCPM 1-4.1](https://huggingface.co/openbmb) | 0.5B/1B/2B/4B/8B | cpm/cpm3/cpm4 |
|
||||
| [MiniCPM-o-2.6/MiniCPM-V-2.6](https://huggingface.co/openbmb) | 8B | minicpm_o/minicpm_v |
|
||||
| [MiniMax-M1/MiniMax-M2](https://huggingface.co/MiniMaxAI/models) | 229B/456B | minimax1/minimax2 |
|
||||
| [Ministral 3](https://huggingface.co/mistralai) | 3B/8B/14B | ministral3 |
|
||||
| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
|
||||
| [OLMo](https://huggingface.co/allenai) | 1B/7B | - |
|
||||
@@ -433,6 +434,7 @@ You also can add a custom chat template to [template.py](src/llamafactory/data/t
|
||||
- [Chinese-DeepSeek-R1-Distill (zh)](https://huggingface.co/datasets/Congliu/Chinese-DeepSeek-R1-Distill-data-110k-SFT)
|
||||
- [LLaVA mixed (en&zh)](https://huggingface.co/datasets/BUAADreamer/llava-en-zh-300k)
|
||||
- [Pokemon-gpt4o-captions (en&zh)](https://huggingface.co/datasets/jugg1024/pokemon-gpt4o-captions)
|
||||
- [DLR-Web (en)](https://huggingface.co/datasets/Attention1115/DLR-Web)
|
||||
- [Open Assistant (de)](https://huggingface.co/datasets/mayflowergmbh/oasst_de)
|
||||
- [Dolly 15k (de)](https://huggingface.co/datasets/mayflowergmbh/dolly-15k_de)
|
||||
- [Alpaca GPT4 (de)](https://huggingface.co/datasets/mayflowergmbh/alpaca-gpt4_de)
|
||||
@@ -514,7 +516,7 @@ huggingface-cli login
|
||||
```bash
|
||||
git clone --depth 1 https://github.com/hiyouga/LLaMA-Factory.git
|
||||
cd LLaMA-Factory
|
||||
pip install -e "." --no-build-isolation
|
||||
pip install -e ".[metrics]"
|
||||
```
|
||||
|
||||
Optional dependencies available: `metrics`, `deepspeed`. Install with: `pip install -e ".[metrics,deepspeed]"`
|
||||
@@ -538,13 +540,7 @@ Please refer to [build docker](#build-docker) to build the image yourself.
|
||||
Create an isolated Python environment with [uv](https://github.com/astral-sh/uv):
|
||||
|
||||
```bash
|
||||
uv sync --extra torch --extra metrics --prerelease=allow
|
||||
```
|
||||
|
||||
Run LLaMA-Factory in the isolated environment:
|
||||
|
||||
```bash
|
||||
uv run --prerelease=allow llamafactory-cli train examples/train_lora/llama3_lora_pretrain.yaml
|
||||
uv run llamafactory-cli webui
|
||||
```
|
||||
|
||||
</details>
|
||||
@@ -581,7 +577,7 @@ To enable FlashAttention-2 on the Windows platform, please use the script from [
|
||||
|
||||
<details><summary>For Ascend NPU users</summary>
|
||||
|
||||
To install LLaMA Factory on Ascend NPU devices, please upgrade Python to version 3.10 or higher: `pip install -e "."`. Additionally, you need to install the **[Ascend CANN Toolkit and Kernels](https://www.hiascend.com/developer/download/community/result?module=cann)**. Please follow the [installation tutorial](https://www.hiascend.com/document/detail/en/CANNCommunityEdition/600alphaX/softwareinstall/instg/atlasdeploy_03_0031.html) or use the following commands:
|
||||
To install LLaMA Factory on Ascend NPU devices, please upgrade Python to version 3.10 or higher, then install with `pip install -e . torch-npu==2.7.1`. Additionally, you need to install the **[Ascend CANN Toolkit and Kernels](https://www.hiascend.com/developer/download/community/result?module=cann)**. Please follow the [installation tutorial](https://www.hiascend.com/document/detail/en/CANNCommunityEdition/600alphaX/softwareinstall/instg/atlasdeploy_03_0031.html) or use the following commands:
|
||||
|
||||
```bash
|
||||
# replace the url according to your CANN version and devices
|
||||
@@ -600,8 +596,8 @@ source /usr/local/Ascend/ascend-toolkit/set_env.sh
|
||||
| Requirement | Minimum | Recommend |
|
||||
| ------------ | ------- | -------------- |
|
||||
| CANN | 8.0.RC1 | 8.0.0.alpha002 |
|
||||
| torch | 2.1.0 | 2.4.0 |
|
||||
| torch-npu | 2.1.0 | 2.4.0.post2 |
|
||||
| torch | 2.1.0 | 2.7.1 |
|
||||
| torch-npu | 2.1.0 | 2.7.1 |
|
||||
| deepspeed | 0.13.2 | 0.13.2 |
|
||||
| vllm-ascend | - | 0.7.3 |
|
||||
|
||||
@@ -643,7 +639,7 @@ cd transformers
|
||||
pip install .
|
||||
```
|
||||
|
||||
3. Set `double_quantization: false` in the configuration. You can refer to the [example](examples/train_qlora/llama3_lora_sft_bnb_npu.yaml).
|
||||
3. Set `double_quantization: false` in the configuration. You can refer to the [example](examples/train_qlora/qwen3_lora_sft_bnb_npu.yaml).
|
||||
|
||||
</details>
|
||||
|
||||
@@ -658,12 +654,12 @@ You can also use **[Easy Dataset](https://github.com/ConardLi/easy-dataset)**, *
|
||||
|
||||
### Quickstart
|
||||
|
||||
Use the following 3 commands to run LoRA **fine-tuning**, **inference** and **merging** of the Llama3-8B-Instruct model, respectively.
|
||||
Use the following 3 commands to run LoRA **fine-tuning**, **inference** and **merging** of the Qwen3-4B-Instruct model, respectively.
|
||||
|
||||
```bash
|
||||
llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
|
||||
llamafactory-cli chat examples/inference/llama3_lora_sft.yaml
|
||||
llamafactory-cli export examples/merge_lora/llama3_lora_sft.yaml
|
||||
llamafactory-cli train examples/train_lora/qwen3_lora_sft.yaml
|
||||
llamafactory-cli chat examples/inference/qwen3_lora_sft.yaml
|
||||
llamafactory-cli export examples/merge_lora/qwen3_lora_sft.yaml
|
||||
```
|
||||
|
||||
See [examples/README.md](examples/README.md) for advanced usage (including distributed training).
|
||||
@@ -786,7 +782,7 @@ When building the Docker image, use `-v ./hf_cache:/root/.cache/huggingface` arg
|
||||
### Deploy with OpenAI-style API and vLLM
|
||||
|
||||
```bash
|
||||
API_PORT=8000 llamafactory-cli api examples/inference/llama3.yaml infer_backend=vllm vllm_enforce_eager=true
|
||||
API_PORT=8000 llamafactory-cli api examples/inference/qwen3.yaml infer_backend=vllm vllm_enforce_eager=true
|
||||
```
|
||||
|
||||
> [!TIP]
|
||||
|
||||
34
README_zh.md
34
README_zh.md
@@ -311,6 +311,7 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
|
||||
| [MiMo](https://huggingface.co/XiaomiMiMo) | 7B/309B | mimo/mimo_v2 |
|
||||
| [MiniCPM 1-4.1](https://huggingface.co/openbmb) | 0.5B/1B/2B/4B/8B | cpm/cpm3/cpm4 |
|
||||
| [MiniCPM-o-2.6/MiniCPM-V-2.6](https://huggingface.co/openbmb) | 8B | minicpm_o/minicpm_v |
|
||||
| [MiniMax-M1/MiniMax-M2](https://huggingface.co/MiniMaxAI/models) | 229B/456B | minimax1/minimax2 |
|
||||
| [Ministral 3](https://huggingface.co/mistralai) | 3B/8B/14B | ministral3 |
|
||||
| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
|
||||
| [OLMo](https://huggingface.co/allenai) | 1B/7B | - |
|
||||
@@ -435,6 +436,7 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
|
||||
- [Chinese-DeepSeek-R1-Distill (zh)](https://huggingface.co/datasets/Congliu/Chinese-DeepSeek-R1-Distill-data-110k-SFT)
|
||||
- [LLaVA mixed (en&zh)](https://huggingface.co/datasets/BUAADreamer/llava-en-zh-300k)
|
||||
- [Pokemon-gpt4o-captions (en&zh)](https://huggingface.co/datasets/jugg1024/pokemon-gpt4o-captions)
|
||||
- [DLR-Web (en)](https://huggingface.co/datasets/Attention1115/DLR-Web)
|
||||
- [Open Assistant (de)](https://huggingface.co/datasets/mayflowergmbh/oasst_de)
|
||||
- [Dolly 15k (de)](https://huggingface.co/datasets/mayflowergmbh/dolly-15k_de)
|
||||
- [Alpaca GPT4 (de)](https://huggingface.co/datasets/mayflowergmbh/alpaca-gpt4_de)
|
||||
@@ -516,10 +518,12 @@ huggingface-cli login
|
||||
```bash
|
||||
git clone --depth 1 https://github.com/hiyouga/LLaMA-Factory.git
|
||||
cd LLaMA-Factory
|
||||
pip install -e ".[torch,metrics]" --no-build-isolation
|
||||
pip install -e ".[metrics]"
|
||||
```
|
||||
|
||||
可选的额外依赖项:torch、torch-npu、metrics、deepspeed、liger-kernel、bitsandbytes、hqq、eetq、gptq、aqlm、vllm、sglang、galore、apollo、badam、adam-mini、qwen、minicpm_v、openmind、swanlab、dev
|
||||
可选的额外依赖项:`metrics`、`deepspeed`。使用 `pip install -e ".[metrics,deepspeed]"` 安装。
|
||||
|
||||
其他可选依赖项请参考 `examples/requirements/` 目录下的文件。
|
||||
|
||||
#### 从镜像安装
|
||||
|
||||
@@ -538,13 +542,7 @@ docker run -it --rm --gpus=all --ipc=host hiyouga/llamafactory:latest
|
||||
使用 [uv](https://github.com/astral-sh/uv) 创建隔离的 Python 环境:
|
||||
|
||||
```bash
|
||||
uv sync --extra torch --extra metrics --prerelease=allow
|
||||
```
|
||||
|
||||
在环境中运行 LLaMA-Factory:
|
||||
|
||||
```bash
|
||||
uv run --prerelease=allow llamafactory-cli train examples/train_lora/llama3_lora_pretrain.yaml
|
||||
uv run llamafactory-cli webui
|
||||
```
|
||||
|
||||
</details>
|
||||
@@ -581,7 +579,7 @@ pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/downl
|
||||
|
||||
<details><summary>昇腾 NPU 用户指南</summary>
|
||||
|
||||
在昇腾 NPU 设备上安装 LLaMA Factory 时,请升级 Python 到 3.10 及以上,并需要指定额外依赖项,使用 `pip install -e ".[torch-npu,metrics]"` 命令安装。此外,还需要安装 **[Ascend CANN Toolkit 与 Kernels](https://www.hiascend.com/developer/download/community/result?module=cann)**,安装方法请参考[安装教程](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/80RC2alpha002/quickstart/quickstart/quickstart_18_0004.html)或使用以下命令:
|
||||
在昇腾 NPU 设备上安装 LLaMA Factory 时,请升级 Python 到 3.10 及以上,并需要指定额外依赖项,使用 `pip install -e . torch-npu==2.7.1` 命令安装。此外,还需要安装 **[Ascend CANN Toolkit 与 Kernels](https://www.hiascend.com/developer/download/community/result?module=cann)**,安装方法请参考[安装教程](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/80RC2alpha002/quickstart/quickstart/quickstart_18_0004.html)或使用以下命令:
|
||||
|
||||
```bash
|
||||
# 请替换 URL 为 CANN 版本和设备型号对应的 URL
|
||||
@@ -600,8 +598,8 @@ source /usr/local/Ascend/ascend-toolkit/set_env.sh
|
||||
| 依赖项 | 至少 | 推荐 |
|
||||
| ------------ | ------- | -------------- |
|
||||
| CANN | 8.0.RC1 | 8.0.0.alpha002 |
|
||||
| torch | 2.1.0 | 2.4.0 |
|
||||
| torch-npu | 2.1.0 | 2.4.0.post2 |
|
||||
| torch | 2.1.0 | 2.7.1 |
|
||||
| torch-npu | 2.1.0 | 2.7.1 |
|
||||
| deepspeed | 0.13.2 | 0.13.2 |
|
||||
| vllm-ascend | - | 0.7.3 |
|
||||
|
||||
@@ -643,7 +641,7 @@ cd transformers
|
||||
pip install .
|
||||
```
|
||||
|
||||
3. 在训练参数中设置 `double_quantization: false`,可参考[示例](examples/train_qlora/llama3_lora_sft_bnb_npu.yaml)。
|
||||
3. 在训练参数中设置 `double_quantization: false`,可参考[示例](examples/train_qlora/qwen3_lora_sft_bnb_npu.yaml)。
|
||||
|
||||
</details>
|
||||
|
||||
@@ -658,12 +656,12 @@ pip install .
|
||||
|
||||
### 快速开始
|
||||
|
||||
下面三行命令分别对 Llama3-8B-Instruct 模型进行 LoRA **微调**、**推理**和**合并**。
|
||||
下面三行命令分别对 Qwen3-4B-Instruct 模型进行 LoRA **微调**、**推理**和**合并**。
|
||||
|
||||
```bash
|
||||
llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
|
||||
llamafactory-cli chat examples/inference/llama3_lora_sft.yaml
|
||||
llamafactory-cli export examples/merge_lora/llama3_lora_sft.yaml
|
||||
llamafactory-cli train examples/train_lora/qwen3_lora_sft.yaml
|
||||
llamafactory-cli chat examples/inference/qwen3_lora_sft.yaml
|
||||
llamafactory-cli export examples/merge_lora/qwen3_lora_sft.yaml
|
||||
```
|
||||
|
||||
高级用法请参考 [examples/README_zh.md](examples/README_zh.md)(包括多 GPU 微调)。
|
||||
@@ -789,7 +787,7 @@ docker exec -it llamafactory bash
|
||||
### 利用 vLLM 部署 OpenAI API
|
||||
|
||||
```bash
|
||||
API_PORT=8000 llamafactory-cli api examples/inference/llama3.yaml infer_backend=vllm vllm_enforce_eager=true
|
||||
API_PORT=8000 llamafactory-cli api examples/inference/qwen3.yaml infer_backend=vllm vllm_enforce_eager=true
|
||||
```
|
||||
|
||||
> [!TIP]
|
||||
|
||||
@@ -471,6 +471,14 @@
|
||||
"ultrachat_de": {
|
||||
"hf_hub_url": "mayflowergmbh/ultra-chat_de"
|
||||
},
|
||||
"dlr_web": {
|
||||
"hf_hub_url": "Attention1115/DLR-Web",
|
||||
"split": "full",
|
||||
"columns": {
|
||||
"prompt": "question",
|
||||
"response": "response"
|
||||
}
|
||||
},
|
||||
"dpo_en_demo": {
|
||||
"file_name": "dpo_en_demo.json",
|
||||
"ranking": true,
|
||||
|
||||
@@ -26,13 +26,13 @@ WORKDIR /app
|
||||
# Change pip source
|
||||
RUN pip config set global.index-url "${PIP_INDEX}" && \
|
||||
pip config set global.extra-index-url "${PIP_INDEX}" && \
|
||||
pip install --no-cache-dir --upgrade pip packaging wheel setuptools "hatchling>=1.18.0" editables
|
||||
pip install --no-cache-dir --upgrade pip packaging wheel setuptools editables "hatchling>=1.18.0"
|
||||
|
||||
# Copy the application into the image
|
||||
COPY . /app
|
||||
|
||||
# Install LLaMA Factory
|
||||
RUN pip install --no-cache-dir -e "." --no-build-isolation
|
||||
RUN pip install --no-cache-dir --no-build-isolation -e ".[metrics,deepspeed]"
|
||||
|
||||
# Rebuild flash attention
|
||||
RUN if [ "${INSTALL_FLASHATTN}" == "true" ]; then \
|
||||
|
||||
@@ -60,7 +60,7 @@ WORKDIR /app
|
||||
COPY . /app
|
||||
|
||||
# Install LLaMA Factory
|
||||
RUN pip install --no-cache-dir -e "." --no-build-isolation
|
||||
RUN pip install --no-cache-dir -e ".[metrics]" --no-build-isolation
|
||||
|
||||
RUN pip install "git+https://github.com/alibaba/roll.git#subdirectory=mcore_adapter"
|
||||
|
||||
|
||||
@@ -5,7 +5,6 @@ services:
|
||||
context: ../..
|
||||
args:
|
||||
PIP_INDEX: https://pypi.org/simple
|
||||
EXTRAS: metrics
|
||||
container_name: llamafactory
|
||||
ports:
|
||||
- "7860:7860"
|
||||
|
||||
@@ -27,17 +27,15 @@ WORKDIR /app
|
||||
# Change pip source
|
||||
RUN pip config set global.index-url "${PIP_INDEX}" && \
|
||||
pip config set global.extra-index-url "${PIP_INDEX}" && \
|
||||
pip install --no-cache-dir --upgrade pip packaging wheel setuptools "hatchling>=1.18.0" editables
|
||||
|
||||
# Install torch-npu
|
||||
RUN pip uninstall -y torch torchvision torchaudio && \
|
||||
pip install --no-cache-dir "torch==2.7.1" "torch-npu==2.7.1" "torchvision==0.22.1" --index-url "${PYTORCH_INDEX}"
|
||||
pip install --no-cache-dir --upgrade pip packaging wheel setuptools editables "hatchling>=1.18.0"
|
||||
|
||||
# Copy the application into the image
|
||||
COPY . /app
|
||||
|
||||
# Install LLaMA Factory
|
||||
RUN pip install --no-cache-dir -e "." --no-build-isolation
|
||||
# Install torch-npu
|
||||
RUN pip uninstall -y torch torchvision torchaudio && \
|
||||
pip install --no-cache-dir "torch==2.7.1" "torch-npu==2.7.1" "torchvision==0.22.1" "torchaudio==2.7.1" --index-url "${PYTORCH_INDEX}" && \
|
||||
pip install --no-cache-dir -e ".[metrics]" --no-build-isolation
|
||||
|
||||
# Set up volumes
|
||||
# VOLUME [ "/root/.cache/huggingface", "/app/shared_data", "/app/output" ]
|
||||
|
||||
@@ -5,7 +5,6 @@ services:
|
||||
context: ../..
|
||||
args:
|
||||
PIP_INDEX: https://pypi.org/simple
|
||||
EXTRAS: torch-npu,metrics
|
||||
container_name: llamafactory-a2
|
||||
image: llamafactory:npu-a2
|
||||
volumes:
|
||||
@@ -36,7 +35,6 @@ services:
|
||||
args:
|
||||
BASE_IMAGE: quay.io/ascend/cann:8.3.rc2-a3-ubuntu22.04-py3.11
|
||||
PIP_INDEX: https://pypi.org/simple
|
||||
EXTRAS: torch-npu,metrics
|
||||
container_name: llamafactory-a3
|
||||
image: llamafactory:npu-a3
|
||||
volumes:
|
||||
|
||||
@@ -27,17 +27,14 @@ WORKDIR /app
|
||||
# Change pip source
|
||||
RUN pip config set global.index-url "${PIP_INDEX}" && \
|
||||
pip config set global.extra-index-url "${PIP_INDEX}" && \
|
||||
pip install --no-cache-dir --upgrade pip packaging wheel setuptools "hatchling>=1.18.0" editables
|
||||
|
||||
# Reinstall pytorch rocm
|
||||
RUN pip uninstall -y torch torchvision torchaudio && \
|
||||
pip install --no-cache-dir --pre torch torchvision torchaudio --index-url "${PYTORCH_INDEX}"
|
||||
pip install --no-cache-dir --upgrade pip packaging wheel setuptools editables "hatchling>=1.18.0"
|
||||
|
||||
# Copy the application into the image
|
||||
COPY . /app
|
||||
|
||||
# Install LLaMA Factory
|
||||
RUN pip install --no-cache-dir -e "." --no-build-isolation
|
||||
# Reinstall pytorch rocm and install LLaMA Factory
|
||||
RUN pip uninstall -y torch torchvision torchaudio && \
|
||||
pip install --no-cache-dir --no-build-isolation -e --pre ".[metrics,deepspeed]" --index-url "${PYTORCH_INDEX}"
|
||||
|
||||
# Rebuild flash attention
|
||||
RUN if [ "${INSTALL_FLASHATTN}" == "true" ]; then \
|
||||
|
||||
@@ -5,7 +5,6 @@ services:
|
||||
context: ../..
|
||||
args:
|
||||
PIP_INDEX: https://pypi.org/simple
|
||||
EXTRAS: metrics
|
||||
container_name: llamafactory
|
||||
ports:
|
||||
- "7860:7860"
|
||||
|
||||
@@ -18,19 +18,19 @@ By default, LLaMA-Factory uses all visible computing devices.
|
||||
Basic usage:
|
||||
|
||||
```bash
|
||||
llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
|
||||
llamafactory-cli train examples/train_lora/qwen3_lora_sft.yaml
|
||||
```
|
||||
|
||||
Advanced usage:
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0,1 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml \
|
||||
CUDA_VISIBLE_DEVICES=0,1 llamafactory-cli train examples/train_lora/qwen3_lora_sft.yaml \
|
||||
learning_rate=1e-5 \
|
||||
logging_steps=1
|
||||
```
|
||||
|
||||
```bash
|
||||
bash examples/train_lora/llama3_lora_sft.sh
|
||||
bash examples/train_lora/qwen3_lora_sft.sh
|
||||
```
|
||||
|
||||
## Examples
|
||||
@@ -40,49 +40,43 @@ bash examples/train_lora/llama3_lora_sft.sh
|
||||
#### (Continuous) Pre-Training
|
||||
|
||||
```bash
|
||||
llamafactory-cli train examples/train_lora/llama3_lora_pretrain.yaml
|
||||
llamafactory-cli train examples/train_lora/qwen3_lora_pretrain.yaml
|
||||
```
|
||||
|
||||
#### Supervised Fine-Tuning
|
||||
|
||||
```bash
|
||||
llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
|
||||
llamafactory-cli train examples/train_lora/qwen3_lora_sft.yaml
|
||||
```
|
||||
|
||||
#### Multimodal Supervised Fine-Tuning
|
||||
|
||||
```bash
|
||||
llamafactory-cli train examples/train_lora/qwen2_5vl_lora_sft.yaml
|
||||
llamafactory-cli train examples/train_lora/qwen3vl_lora_sft.yaml
|
||||
```
|
||||
|
||||
#### DPO/ORPO/SimPO Training
|
||||
|
||||
```bash
|
||||
llamafactory-cli train examples/train_lora/llama3_lora_dpo.yaml
|
||||
llamafactory-cli train examples/train_lora/qwen3_lora_dpo.yaml
|
||||
```
|
||||
|
||||
#### Multimodal DPO/ORPO/SimPO Training
|
||||
|
||||
```bash
|
||||
llamafactory-cli train examples/train_lora/qwen2_5vl_lora_dpo.yaml
|
||||
llamafactory-cli train examples/train_lora/qwen3vl_lora_dpo.yaml
|
||||
```
|
||||
|
||||
#### Reward Modeling
|
||||
|
||||
```bash
|
||||
llamafactory-cli train examples/train_lora/llama3_lora_reward.yaml
|
||||
```
|
||||
|
||||
#### PPO Training
|
||||
|
||||
```bash
|
||||
llamafactory-cli train examples/train_lora/llama3_lora_ppo.yaml
|
||||
llamafactory-cli train examples/train_lora/qwen3_lora_reward.yaml
|
||||
```
|
||||
|
||||
#### KTO Training
|
||||
|
||||
```bash
|
||||
llamafactory-cli train examples/train_lora/llama3_lora_kto.yaml
|
||||
llamafactory-cli train examples/train_lora/qwen3_lora_kto.yaml
|
||||
```
|
||||
|
||||
#### Preprocess Dataset
|
||||
@@ -90,32 +84,26 @@ llamafactory-cli train examples/train_lora/llama3_lora_kto.yaml
|
||||
This is useful for large datasets; set `tokenized_path` in the config to load the preprocessed dataset.
|
||||
|
||||
```bash
|
||||
llamafactory-cli train examples/train_lora/llama3_preprocess.yaml
|
||||
```
|
||||
|
||||
#### Evaluating on MMLU/CMMLU/C-Eval Benchmarks
|
||||
|
||||
```bash
|
||||
llamafactory-cli eval examples/train_lora/llama3_lora_eval.yaml
|
||||
llamafactory-cli train examples/train_lora/qwen3_preprocess.yaml
|
||||
```
|
||||
|
||||
#### Supervised Fine-Tuning on Multiple Nodes
|
||||
|
||||
```bash
|
||||
FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
|
||||
FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
|
||||
FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/qwen3_lora_sft.yaml
|
||||
FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/qwen3_lora_sft.yaml
|
||||
```
|
||||
|
||||
#### Supervised Fine-Tuning with DeepSpeed ZeRO-3 (Weight Sharding)
|
||||
|
||||
```bash
|
||||
FORCE_TORCHRUN=1 llamafactory-cli train examples/train_lora/llama3_lora_sft_ds3.yaml
|
||||
FORCE_TORCHRUN=1 llamafactory-cli train examples/train_lora/qwen3_lora_sft_ds3.yaml
|
||||
```
|
||||
|
||||
#### Supervised Fine-Tuning with Ray on 4 GPUs
|
||||
|
||||
```bash
|
||||
USE_RAY=1 llamafactory-cli train examples/train_lora/llama3_lora_sft_ray.yaml
|
||||
USE_RAY=1 llamafactory-cli train examples/train_lora/qwen3_lora_sft_ray.yaml
|
||||
```
|
||||
|
||||
### QLoRA Fine-Tuning
|
||||
@@ -123,13 +111,13 @@ USE_RAY=1 llamafactory-cli train examples/train_lora/llama3_lora_sft_ray.yaml
|
||||
#### Supervised Fine-Tuning with 4/8-bit Bitsandbytes/HQQ/EETQ Quantization (Recommended)
|
||||
|
||||
```bash
|
||||
llamafactory-cli train examples/train_qlora/llama3_lora_sft_otfq.yaml
|
||||
llamafactory-cli train examples/train_qlora/qwen3_lora_sft_otfq.yaml
|
||||
```
|
||||
|
||||
#### Supervised Fine-Tuning with 4-bit Bitsandbytes Quantization on Ascend NPU
|
||||
|
||||
```bash
|
||||
llamafactory-cli train examples/train_qlora/llama3_lora_sft_bnb_npu.yaml
|
||||
llamafactory-cli train examples/train_qlora/qwen3_lora_sft_bnb_npu.yaml
|
||||
```
|
||||
|
||||
#### Supervised Fine-Tuning with 4/8-bit GPTQ Quantization
|
||||
@@ -155,14 +143,14 @@ llamafactory-cli train examples/train_qlora/llama3_lora_sft_aqlm.yaml
|
||||
#### Supervised Fine-Tuning on Single Node
|
||||
|
||||
```bash
|
||||
FORCE_TORCHRUN=1 llamafactory-cli train examples/train_full/llama3_full_sft.yaml
|
||||
FORCE_TORCHRUN=1 llamafactory-cli train examples/train_full/qwen3_full_sft.yaml
|
||||
```
|
||||
|
||||
#### Supervised Fine-Tuning on Multiple Nodes
|
||||
|
||||
```bash
|
||||
FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/llama3_full_sft.yaml
|
||||
FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/llama3_full_sft.yaml
|
||||
FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/qwen3_full_sft.yaml
|
||||
FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/qwen3_full_sft.yaml
|
||||
```
|
||||
|
||||
#### Elastic and Fault-Tolerant Supervised Fine-Tuning on Multiple Nodes
|
||||
@@ -170,13 +158,13 @@ FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500
|
||||
To launch an elastic job that is restarted up to `MAX_RESTARTS` times on failure, run the following command on at least `MIN_NNODES` and at most `MAX_NNODES` nodes. `RDZV_ID` should be set to a unique job ID (shared by all nodes participating in the job). See also [torchrun](https://docs.pytorch.org/docs/stable/elastic/run.html).
|
||||
|
||||
```bash
|
||||
FORCE_TORCHRUN=1 MIN_NNODES=1 MAX_NNODES=3 MAX_RESTARTS=3 RDZV_ID=llamafactory MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/llama3_full_sft.yaml
|
||||
FORCE_TORCHRUN=1 MIN_NNODES=1 MAX_NNODES=3 MAX_RESTARTS=3 RDZV_ID=llamafactory MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/qwen3_full_sft.yaml
|
||||
```
|
||||
|
||||
#### Multimodal Supervised Fine-Tuning
|
||||
|
||||
```bash
|
||||
FORCE_TORCHRUN=1 llamafactory-cli train examples/train_full/qwen2_5vl_full_sft.yaml
|
||||
FORCE_TORCHRUN=1 llamafactory-cli train examples/train_full/qwen3vl_full_sft.yaml
|
||||
```
|
||||
|
||||
### Merging LoRA Adapters and Quantization
|
||||
@@ -186,19 +174,19 @@ FORCE_TORCHRUN=1 llamafactory-cli train examples/train_full/qwen2_5vl_full_sft.y
|
||||
Note: DO NOT use quantized model or `quantization_bit` when merging LoRA adapters.
|
||||
|
||||
```bash
|
||||
llamafactory-cli export examples/merge_lora/llama3_lora_sft.yaml
|
||||
llamafactory-cli export examples/merge_lora/qwen3_lora_sft.yaml
|
||||
```
|
||||
|
||||
#### Quantizing Model using AutoGPTQ
|
||||
|
||||
```bash
|
||||
llamafactory-cli export examples/merge_lora/llama3_gptq.yaml
|
||||
llamafactory-cli export examples/merge_lora/qwen3_gptq.yaml
|
||||
```
|
||||
|
||||
### Save Ollama modelfile
|
||||
|
||||
```bash
|
||||
llamafactory-cli export examples/merge_lora/llama3_full_sft.yaml
|
||||
llamafactory-cli export examples/merge_lora/qwen3_full_sft.yaml
|
||||
```
|
||||
|
||||
### Inferring LoRA Fine-Tuned Models
|
||||
@@ -206,26 +194,26 @@ llamafactory-cli export examples/merge_lora/llama3_full_sft.yaml
|
||||
#### Evaluation using vLLM's Multi-GPU Inference
|
||||
|
||||
```
|
||||
python scripts/vllm_infer.py --model_name_or_path meta-llama/Meta-Llama-3-8B-Instruct --template llama3 --dataset alpaca_en_demo
|
||||
python scripts/vllm_infer.py --model_name_or_path Qwen/Qwen3-4B-Instruct-2507 --template qwen3_nothink --dataset alpaca_en_demo
|
||||
python scripts/eval_bleu_rouge.py generated_predictions.jsonl
|
||||
```
|
||||
|
||||
#### Use CLI ChatBox
|
||||
|
||||
```bash
|
||||
llamafactory-cli chat examples/inference/llama3_lora_sft.yaml
|
||||
llamafactory-cli chat examples/inference/qwen3_lora_sft.yaml
|
||||
```
|
||||
|
||||
#### Use Web UI ChatBox
|
||||
|
||||
```bash
|
||||
llamafactory-cli webchat examples/inference/llama3_lora_sft.yaml
|
||||
llamafactory-cli webchat examples/inference/qwen3_lora_sft.yaml
|
||||
```
|
||||
|
||||
#### Launch OpenAI-style API
|
||||
|
||||
```bash
|
||||
llamafactory-cli api examples/inference/llama3_lora_sft.yaml
|
||||
llamafactory-cli api examples/inference/qwen3_lora_sft.yaml
|
||||
```
|
||||
|
||||
### Extras
|
||||
|
||||
@@ -18,19 +18,19 @@ LLaMA-Factory 默认使用所有可见的计算设备。
|
||||
基础用法:
|
||||
|
||||
```bash
|
||||
llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
|
||||
llamafactory-cli train examples/train_lora/qwen3_lora_sft.yaml
|
||||
```
|
||||
|
||||
高级用法:
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0,1 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml \
|
||||
CUDA_VISIBLE_DEVICES=0,1 llamafactory-cli train examples/train_lora/qwen3_lora_sft.yaml \
|
||||
learning_rate=1e-5 \
|
||||
logging_steps=1
|
||||
```
|
||||
|
||||
```bash
|
||||
bash examples/train_lora/llama3_lora_sft.sh
|
||||
bash examples/train_lora/qwen3_lora_sft.sh
|
||||
```
|
||||
|
||||
## 示例
|
||||
@@ -40,49 +40,43 @@ bash examples/train_lora/llama3_lora_sft.sh
|
||||
#### (增量)预训练
|
||||
|
||||
```bash
|
||||
llamafactory-cli train examples/train_lora/llama3_lora_pretrain.yaml
|
||||
llamafactory-cli train examples/train_lora/qwen3_lora_pretrain.yaml
|
||||
```
|
||||
|
||||
#### 指令监督微调
|
||||
|
||||
```bash
|
||||
llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
|
||||
llamafactory-cli train examples/train_lora/qwen3_lora_sft.yaml
|
||||
```
|
||||
|
||||
#### 多模态指令监督微调
|
||||
|
||||
```bash
|
||||
llamafactory-cli train examples/train_lora/qwen2_5vl_lora_sft.yaml
|
||||
llamafactory-cli train examples/train_lora/qwen3vl_lora_sft.yaml
|
||||
```
|
||||
|
||||
#### DPO/ORPO/SimPO 训练
|
||||
|
||||
```bash
|
||||
llamafactory-cli train examples/train_lora/llama3_lora_dpo.yaml
|
||||
llamafactory-cli train examples/train_lora/qwen3_lora_dpo.yaml
|
||||
```
|
||||
|
||||
#### 多模态 DPO/ORPO/SimPO 训练
|
||||
|
||||
```bash
|
||||
llamafactory-cli train examples/train_lora/qwen2_5vl_lora_dpo.yaml
|
||||
llamafactory-cli train examples/train_lora/qwen3vl_lora_dpo.yaml
|
||||
```
|
||||
|
||||
#### 奖励模型训练
|
||||
|
||||
```bash
|
||||
llamafactory-cli train examples/train_lora/llama3_lora_reward.yaml
|
||||
```
|
||||
|
||||
#### PPO 训练
|
||||
|
||||
```bash
|
||||
llamafactory-cli train examples/train_lora/llama3_lora_ppo.yaml
|
||||
llamafactory-cli train examples/train_lora/qwen3_lora_reward.yaml
|
||||
```
|
||||
|
||||
#### KTO 训练
|
||||
|
||||
```bash
|
||||
llamafactory-cli train examples/train_lora/llama3_lora_kto.yaml
|
||||
llamafactory-cli train examples/train_lora/qwen3_lora_kto.yaml
|
||||
```
|
||||
|
||||
#### 预处理数据集
|
||||
@@ -90,20 +84,14 @@ llamafactory-cli train examples/train_lora/llama3_lora_kto.yaml
|
||||
对于大数据集有帮助,在配置中使用 `tokenized_path` 以加载预处理后的数据集。
|
||||
|
||||
```bash
|
||||
llamafactory-cli train examples/train_lora/llama3_preprocess.yaml
|
||||
```
|
||||
|
||||
#### 在 MMLU/CMMLU/C-Eval 上评估
|
||||
|
||||
```bash
|
||||
llamafactory-cli eval examples/train_lora/llama3_lora_eval.yaml
|
||||
llamafactory-cli train examples/train_lora/qwen3_preprocess.yaml
|
||||
```
|
||||
|
||||
#### 多机指令监督微调
|
||||
|
||||
```bash
|
||||
FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
|
||||
FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
|
||||
FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/qwen3_lora_sft.yaml
|
||||
FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/qwen3_lora_sft.yaml
|
||||
```
|
||||
|
||||
#### 支持弹性和容错的多机指令监督微调
|
||||
@@ -111,19 +99,19 @@ FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500
|
||||
要启动一个支持弹性节点和容错的多机指令微调,在每个节点上执行以下命令。弹性节点数量范围为 `MIN_NNODES:MAX_NNODES`,每个节点最多允许因为错误重启 `MAX_RESTARTS` 次。`RDZV_ID` 应设置为一个唯一的作业 ID(由参与该作业的所有节点共享)。更多信息可以参考官方文档 [torchrun](https://docs.pytorch.org/docs/stable/elastic/run.html)。
|
||||
|
||||
```bash
|
||||
FORCE_TORCHRUN=1 MIN_NNODES=1 MAX_NNODES=3 MAX_RESTARTS=3 RDZV_ID=llamafactory MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/llama3_full_sft.yaml
|
||||
FORCE_TORCHRUN=1 MIN_NNODES=1 MAX_NNODES=3 MAX_RESTARTS=3 RDZV_ID=llamafactory MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/qwen3_full_sft.yaml
|
||||
```
|
||||
|
||||
#### 使用 DeepSpeed ZeRO-3 平均分配显存
|
||||
|
||||
```bash
|
||||
FORCE_TORCHRUN=1 llamafactory-cli train examples/train_lora/llama3_lora_sft_ds3.yaml
|
||||
FORCE_TORCHRUN=1 llamafactory-cli train examples/train_lora/qwen3_lora_sft_ds3.yaml
|
||||
```
|
||||
|
||||
#### 使用 Ray 在 4 张 GPU 上微调
|
||||
|
||||
```bash
|
||||
USE_RAY=1 llamafactory-cli train examples/train_lora/llama3_lora_sft_ray.yaml
|
||||
USE_RAY=1 llamafactory-cli train examples/train_lora/qwen3_lora_sft_ray.yaml
|
||||
```
|
||||
|
||||
### QLoRA 微调
|
||||
@@ -131,13 +119,13 @@ USE_RAY=1 llamafactory-cli train examples/train_lora/llama3_lora_sft_ray.yaml
|
||||
#### 基于 4/8 比特 Bitsandbytes/HQQ/EETQ 量化进行指令监督微调(推荐)
|
||||
|
||||
```bash
|
||||
llamafactory-cli train examples/train_qlora/llama3_lora_sft_otfq.yaml
|
||||
llamafactory-cli train examples/train_qlora/qwen3_lora_sft_otfq.yaml
|
||||
```
|
||||
|
||||
#### 在 NPU 上基于 4 比特 Bitsandbytes 量化进行指令监督微调
|
||||
|
||||
```bash
|
||||
llamafactory-cli train examples/train_qlora/llama3_lora_sft_bnb_npu.yaml
|
||||
llamafactory-cli train examples/train_qlora/qwen3_lora_sft_bnb_npu.yaml
|
||||
```
|
||||
|
||||
#### 基于 4/8 比特 GPTQ 量化进行指令监督微调
|
||||
@@ -163,20 +151,20 @@ llamafactory-cli train examples/train_qlora/llama3_lora_sft_aqlm.yaml
|
||||
#### 在单机上进行指令监督微调
|
||||
|
||||
```bash
|
||||
FORCE_TORCHRUN=1 llamafactory-cli train examples/train_full/llama3_full_sft.yaml
|
||||
FORCE_TORCHRUN=1 llamafactory-cli train examples/train_full/qwen3_full_sft.yaml
|
||||
```
|
||||
|
||||
#### 在多机上进行指令监督微调
|
||||
|
||||
```bash
|
||||
FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/llama3_full_sft.yaml
|
||||
FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/llama3_full_sft.yaml
|
||||
FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/qwen3_full_sft.yaml
|
||||
FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/qwen3_full_sft.yaml
|
||||
```
|
||||
|
||||
#### 多模态指令监督微调
|
||||
|
||||
```bash
|
||||
FORCE_TORCHRUN=1 llamafactory-cli train examples/train_full/qwen2_5vl_full_sft.yaml
|
||||
FORCE_TORCHRUN=1 llamafactory-cli train examples/train_full/qwen3vl_full_sft.yaml
|
||||
```
|
||||
|
||||
### 合并 LoRA 适配器与模型量化
|
||||
@@ -186,19 +174,19 @@ FORCE_TORCHRUN=1 llamafactory-cli train examples/train_full/qwen2_5vl_full_sft.y
|
||||
注:请勿使用量化后的模型或 `quantization_bit` 参数来合并 LoRA 适配器。
|
||||
|
||||
```bash
|
||||
llamafactory-cli export examples/merge_lora/llama3_lora_sft.yaml
|
||||
llamafactory-cli export examples/merge_lora/qwen3_lora_sft.yaml
|
||||
```
|
||||
|
||||
#### 使用 AutoGPTQ 量化模型
|
||||
|
||||
```bash
|
||||
llamafactory-cli export examples/merge_lora/llama3_gptq.yaml
|
||||
llamafactory-cli export examples/merge_lora/qwen3_gptq.yaml
|
||||
```
|
||||
|
||||
### 保存 Ollama 配置文件
|
||||
|
||||
```bash
|
||||
llamafactory-cli export examples/merge_lora/llama3_full_sft.yaml
|
||||
llamafactory-cli export examples/merge_lora/qwen3_full_sft.yaml
|
||||
```
|
||||
|
||||
### 推理 LoRA 模型
|
||||
@@ -206,26 +194,26 @@ llamafactory-cli export examples/merge_lora/llama3_full_sft.yaml
|
||||
#### 使用 vLLM 多卡推理评估
|
||||
|
||||
```
|
||||
python scripts/vllm_infer.py --model_name_or_path meta-llama/Meta-Llama-3-8B-Instruct --template llama3 --dataset alpaca_en_demo
|
||||
python scripts/vllm_infer.py --model_name_or_path Qwen/Qwen3-4B-Instruct-2507 --template qwen3_nothink --dataset alpaca_en_demo
|
||||
python scripts/eval_bleu_rouge.py generated_predictions.jsonl
|
||||
```
|
||||
|
||||
#### 使用命令行对话框
|
||||
|
||||
```bash
|
||||
llamafactory-cli chat examples/inference/llama3_lora_sft.yaml
|
||||
llamafactory-cli chat examples/inference/qwen3_lora_sft.yaml
|
||||
```
|
||||
|
||||
#### 使用浏览器对话框
|
||||
|
||||
```bash
|
||||
llamafactory-cli webchat examples/inference/llama3_lora_sft.yaml
|
||||
llamafactory-cli webchat examples/inference/qwen3_lora_sft.yaml
|
||||
```
|
||||
|
||||
#### 启动 OpenAI 风格 API
|
||||
|
||||
```bash
|
||||
llamafactory-cli api examples/inference/llama3_lora_sft.yaml
|
||||
llamafactory-cli api examples/inference/qwen3_lora_sft.yaml
|
||||
```
|
||||
|
||||
### 杂项
|
||||
|
||||
@@ -1,5 +0,0 @@
|
||||
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
|
||||
adapter_name_or_path: saves/llama3-8b/lora/sft
|
||||
template: llama3
|
||||
infer_backend: huggingface # choices: [huggingface, vllm, sglang, ktransformers]
|
||||
trust_remote_code: true
|
||||
@@ -1,4 +1,4 @@
|
||||
model_name_or_path: Qwen/Qwen2.5-VL-7B-Instruct
|
||||
template: qwen2_vl
|
||||
model_name_or_path: Qwen/Qwen3-4B-Instruct-2507
|
||||
template: qwen3_nothink
|
||||
infer_backend: huggingface # choices: [huggingface, vllm, sglang, ktransformers]
|
||||
trust_remote_code: true
|
||||
@@ -1,4 +1,4 @@
|
||||
model_name_or_path: saves/llama3-8b/full/sft
|
||||
template: llama3
|
||||
model_name_or_path: saves/qwen3-4b/full/sft
|
||||
template: qwen3_nothink
|
||||
infer_backend: huggingface # choices: [huggingface, vllm, sglang, ktransformers]
|
||||
trust_remote_code: true
|
||||
5
examples/inference/qwen3_lora_sft.yaml
Normal file
5
examples/inference/qwen3_lora_sft.yaml
Normal file
@@ -0,0 +1,5 @@
|
||||
model_name_or_path: Qwen/Qwen3-4B-Instruct-2507
|
||||
adapter_name_or_path: saves/qwen3-4b/lora/sft
|
||||
template: qwen3_nothink
|
||||
infer_backend: huggingface # choices: [huggingface, vllm, sglang, ktransformers]
|
||||
trust_remote_code: true
|
||||
@@ -1,4 +1,4 @@
|
||||
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
|
||||
template: llama3
|
||||
model_name_or_path: Qwen/Qwen3-VL-4B-Instruct
|
||||
template: qwen3_vl_nothink
|
||||
infer_backend: huggingface # choices: [huggingface, vllm, sglang, ktransformers]
|
||||
trust_remote_code: true
|
||||
@@ -1,10 +1,10 @@
|
||||
### model
|
||||
model_name_or_path: saves/llama3-8b/full/sft
|
||||
template: llama3
|
||||
model_name_or_path: saves/qwen3-4b/full/sft
|
||||
template: qwen3_nothink
|
||||
trust_remote_code: true
|
||||
|
||||
### export
|
||||
export_dir: output/llama3_full_sft
|
||||
export_dir: saves/qwen3_sft_merged
|
||||
export_size: 5
|
||||
export_device: cpu # choices: [cpu, auto]
|
||||
export_legacy_format: false
|
||||
@@ -1,10 +1,10 @@
|
||||
### model
|
||||
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
|
||||
template: llama3
|
||||
model_name_or_path: Qwen/Qwen3-4B-Instruct-2507
|
||||
template: qwen3_nothink
|
||||
trust_remote_code: true
|
||||
|
||||
### export
|
||||
export_dir: output/llama3_gptq
|
||||
export_dir: saves/qwen3_gptq
|
||||
export_quantization_bit: 4
|
||||
export_quantization_dataset: data/c4_demo.jsonl
|
||||
export_size: 5
|
||||
@@ -1,13 +1,13 @@
|
||||
### Note: DO NOT use quantized model or quantization_bit when merging lora adapters
|
||||
|
||||
### model
|
||||
model_name_or_path: Qwen/Qwen2.5-VL-7B-Instruct
|
||||
adapter_name_or_path: saves/qwen2_5vl-7b/lora/sft
|
||||
template: qwen2_vl
|
||||
model_name_or_path: Qwen/Qwen3-4B-Instruct-2507
|
||||
adapter_name_or_path: saves/qwen3-4b/lora/sft
|
||||
template: qwen3_nothink
|
||||
trust_remote_code: true
|
||||
|
||||
### export
|
||||
export_dir: output/qwen2_5vl_lora_sft
|
||||
export_dir: saves/qwen3_sft_merged
|
||||
export_size: 5
|
||||
export_device: cpu # choices: [cpu, auto]
|
||||
export_legacy_format: false
|
||||
@@ -1,13 +1,13 @@
|
||||
### Note: DO NOT use quantized model or quantization_bit when merging lora adapters
|
||||
|
||||
### model
|
||||
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
|
||||
adapter_name_or_path: saves/llama3-8b/lora/sft
|
||||
template: llama3
|
||||
model_name_or_path: Qwen/Qwen3-VL-4B-Instruct
|
||||
adapter_name_or_path: saves/qwen3-vl-4b/lora/sft
|
||||
template: qwen3_vl_nothink
|
||||
trust_remote_code: true
|
||||
|
||||
### export
|
||||
export_dir: output/llama3_lora_sft
|
||||
export_dir: saves/qwen3_vl_sft_merged
|
||||
export_size: 5
|
||||
export_device: cpu # choices: [cpu, auto]
|
||||
export_legacy_format: false
|
||||
@@ -1,4 +0,0 @@
|
||||
pre-commit
|
||||
ruff
|
||||
pytest
|
||||
build
|
||||
@@ -1,5 +1,5 @@
|
||||
### model
|
||||
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
|
||||
model_name_or_path: Qwen/Qwen3-4B-Instruct-2507
|
||||
trust_remote_code: true
|
||||
|
||||
### method
|
||||
@@ -10,15 +10,14 @@ deepspeed: examples/deepspeed/ds_z3_config.json # choices: [ds_z0_config.json,
|
||||
|
||||
### dataset
|
||||
dataset: identity,alpaca_en_demo
|
||||
template: llama3
|
||||
template: qwen3_nothink
|
||||
cutoff_len: 2048
|
||||
max_samples: 1000
|
||||
overwrite_cache: true
|
||||
preprocessing_num_workers: 16
|
||||
dataloader_num_workers: 4
|
||||
|
||||
### output
|
||||
output_dir: saves/llama3-8b/full/sft
|
||||
output_dir: saves/qwen3-4b/full/sft
|
||||
logging_steps: 10
|
||||
save_steps: 500
|
||||
plot_loss: true
|
||||
@@ -1,46 +0,0 @@
|
||||
### model
|
||||
model_name_or_path: Qwen/Qwen3-32B
|
||||
trust_remote_code: true
|
||||
use_v1_kernels: true
|
||||
|
||||
### method
|
||||
stage: sft
|
||||
do_train: true
|
||||
finetuning_type: full
|
||||
deepspeed: examples/deepspeed/ds_z2_autotp_config.json
|
||||
|
||||
### dataset
|
||||
dataset: identity,alpaca_en_demo
|
||||
template: qwen3
|
||||
cutoff_len: 2048
|
||||
max_samples: 1000
|
||||
overwrite_cache: true
|
||||
preprocessing_num_workers: 16
|
||||
dataloader_num_workers: 4
|
||||
|
||||
### output
|
||||
output_dir: saves/qwen3-32b/full/sft_autotp
|
||||
logging_steps: 1
|
||||
save_steps: 500
|
||||
plot_loss: true
|
||||
overwrite_output_dir: true
|
||||
save_only_model: false
|
||||
report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
|
||||
|
||||
### train
|
||||
per_device_train_batch_size: 4
|
||||
gradient_accumulation_steps: 1
|
||||
learning_rate: 1.0e-4
|
||||
num_train_epochs: 3.0
|
||||
lr_scheduler_type: cosine
|
||||
warmup_ratio: 0.1
|
||||
bf16: true
|
||||
ddp_timeout: 180000000
|
||||
resume_from_checkpoint: null
|
||||
|
||||
### eval
|
||||
# eval_dataset: alpaca_en_demo
|
||||
# val_size: 0.1
|
||||
# per_device_eval_batch_size: 1
|
||||
# eval_strategy: steps
|
||||
# eval_steps: 500
|
||||
@@ -1,5 +1,5 @@
|
||||
### model
|
||||
model_name_or_path: Qwen/Qwen2.5-VL-7B-Instruct
|
||||
model_name_or_path: Qwen/Qwen3-VL-4B-Instruct
|
||||
image_max_pixels: 262144
|
||||
video_max_pixels: 16384
|
||||
trust_remote_code: true
|
||||
@@ -15,15 +15,14 @@ deepspeed: examples/deepspeed/ds_z3_config.json
|
||||
|
||||
### dataset
|
||||
dataset: mllm_demo,identity,alpaca_en_demo
|
||||
template: qwen2_vl
|
||||
template: qwen3_vl_nothink
|
||||
cutoff_len: 2048
|
||||
max_samples: 1000
|
||||
overwrite_cache: true
|
||||
preprocessing_num_workers: 16
|
||||
dataloader_num_workers: 4
|
||||
|
||||
### output
|
||||
output_dir: saves/qwen2_5vl-7b/full/sft
|
||||
output_dir: saves/qwen3-vl-4b/full/sft
|
||||
logging_steps: 10
|
||||
save_steps: 500
|
||||
plot_loss: true
|
||||
@@ -1,19 +0,0 @@
|
||||
### model
|
||||
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
|
||||
adapter_name_or_path: saves/llama3-8b/lora/sft
|
||||
trust_remote_code: true
|
||||
|
||||
### method
|
||||
finetuning_type: lora
|
||||
|
||||
### dataset
|
||||
task: mmlu_test # choices: [mmlu_test, ceval_validation, cmmlu_test]
|
||||
template: fewshot
|
||||
lang: en
|
||||
n_shot: 5
|
||||
|
||||
### output
|
||||
save_dir: saves/llama3-8b/lora/eval
|
||||
|
||||
### eval
|
||||
batch_size: 4
|
||||
@@ -1,43 +0,0 @@
|
||||
### model
|
||||
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
|
||||
reward_model: saves/llama3-8b/lora/reward
|
||||
trust_remote_code: true
|
||||
|
||||
### method
|
||||
stage: ppo
|
||||
do_train: true
|
||||
finetuning_type: lora
|
||||
lora_rank: 8
|
||||
lora_target: all
|
||||
|
||||
### dataset
|
||||
dataset: identity,alpaca_en_demo
|
||||
template: llama3
|
||||
cutoff_len: 2048
|
||||
max_samples: 1000
|
||||
overwrite_cache: true
|
||||
preprocessing_num_workers: 16
|
||||
dataloader_num_workers: 4
|
||||
|
||||
### output
|
||||
output_dir: saves/llama3-8b/lora/ppo
|
||||
logging_steps: 10
|
||||
save_steps: 500
|
||||
plot_loss: true
|
||||
overwrite_output_dir: true
|
||||
report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
|
||||
|
||||
### train
|
||||
per_device_train_batch_size: 1
|
||||
gradient_accumulation_steps: 8
|
||||
learning_rate: 1.0e-5
|
||||
num_train_epochs: 3.0
|
||||
lr_scheduler_type: cosine
|
||||
warmup_ratio: 0.1
|
||||
bf16: true
|
||||
ddp_timeout: 180000000
|
||||
|
||||
### generate
|
||||
max_new_tokens: 512
|
||||
top_k: 0
|
||||
top_p: 0.9
|
||||
@@ -1,46 +0,0 @@
|
||||
### model
|
||||
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
|
||||
trust_remote_code: true
|
||||
|
||||
### method
|
||||
stage: sft
|
||||
do_train: true
|
||||
finetuning_type: lora
|
||||
lora_rank: 8
|
||||
lora_target: all
|
||||
|
||||
### dataset
|
||||
dataset: identity,alpaca_en_demo
|
||||
template: llama3
|
||||
cutoff_len: 2048
|
||||
max_samples: 1000
|
||||
overwrite_cache: true
|
||||
preprocessing_num_workers: 16
|
||||
dataloader_num_workers: 4
|
||||
|
||||
### output
|
||||
output_dir: saves/llama3-8b/lora/sft
|
||||
logging_steps: 10
|
||||
save_steps: 500
|
||||
plot_loss: true
|
||||
overwrite_output_dir: true
|
||||
save_only_model: false
|
||||
report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
|
||||
|
||||
### train
|
||||
per_device_train_batch_size: 1
|
||||
gradient_accumulation_steps: 8
|
||||
learning_rate: 1.0e-4
|
||||
num_train_epochs: 3.0
|
||||
lr_scheduler_type: cosine
|
||||
warmup_ratio: 0.1
|
||||
bf16: true
|
||||
ddp_timeout: 180000000
|
||||
resume_from_checkpoint: null
|
||||
|
||||
### eval
|
||||
# eval_dataset: alpaca_en_demo
|
||||
# val_size: 0.1
|
||||
# per_device_eval_batch_size: 1
|
||||
# eval_strategy: steps
|
||||
# eval_steps: 500
|
||||
@@ -1,49 +0,0 @@
|
||||
# pip install git+https://github.com/hiyouga/transformers.git@llama4_train
|
||||
|
||||
### model
|
||||
model_name_or_path: meta-llama/Llama-4-Scout-17B-16E-Instruct
|
||||
trust_remote_code: true
|
||||
|
||||
### method
|
||||
stage: sft
|
||||
do_train: true
|
||||
finetuning_type: lora
|
||||
lora_rank: 8
|
||||
lora_target: all
|
||||
deepspeed: examples/deepspeed/ds_z3_config.json # choices: [ds_z0_config.json, ds_z2_config.json, ds_z3_config.json]
|
||||
|
||||
### dataset
|
||||
dataset: mllm_demo,identity,alpaca_en_demo
|
||||
template: llama4
|
||||
cutoff_len: 2048
|
||||
max_samples: 1000
|
||||
overwrite_cache: true
|
||||
preprocessing_num_workers: 16
|
||||
dataloader_num_workers: 4
|
||||
|
||||
### output
|
||||
output_dir: saves/llama4-8b/lora/sft
|
||||
logging_steps: 10
|
||||
save_steps: 500
|
||||
plot_loss: true
|
||||
overwrite_output_dir: true
|
||||
save_only_model: false
|
||||
report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
|
||||
|
||||
### train
|
||||
per_device_train_batch_size: 1
|
||||
gradient_accumulation_steps: 2
|
||||
learning_rate: 1.0e-4
|
||||
num_train_epochs: 3.0
|
||||
lr_scheduler_type: cosine
|
||||
warmup_ratio: 0.1
|
||||
bf16: true
|
||||
ddp_timeout: 180000000
|
||||
resume_from_checkpoint: null
|
||||
|
||||
### eval
|
||||
# eval_dataset: alpaca_en_demo
|
||||
# val_size: 0.1
|
||||
# per_device_eval_batch_size: 1
|
||||
# eval_strategy: steps
|
||||
# eval_steps: 500
|
||||
@@ -1,5 +1,5 @@
|
||||
### model
|
||||
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
|
||||
model_name_or_path: Qwen/Qwen3-4B-Instruct-2507
|
||||
trust_remote_code: true
|
||||
|
||||
### method
|
||||
@@ -13,15 +13,14 @@ pref_loss: sigmoid # choices: [sigmoid (dpo), orpo, simpo]
|
||||
|
||||
### dataset
|
||||
dataset: dpo_en_demo
|
||||
template: llama3
|
||||
template: qwen3_nothink
|
||||
cutoff_len: 2048
|
||||
max_samples: 1000
|
||||
overwrite_cache: true
|
||||
preprocessing_num_workers: 16
|
||||
dataloader_num_workers: 4
|
||||
|
||||
### output
|
||||
output_dir: saves/llama3-8b/lora/dpo
|
||||
output_dir: saves/qwen3-4b/lora/dpo
|
||||
logging_steps: 10
|
||||
save_steps: 500
|
||||
plot_loss: true
|
||||
@@ -1,5 +1,5 @@
|
||||
### model
|
||||
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
|
||||
model_name_or_path: Qwen/Qwen3-4B-Instruct-2507
|
||||
trust_remote_code: true
|
||||
|
||||
### method
|
||||
@@ -12,15 +12,14 @@ pref_beta: 0.1
|
||||
|
||||
### dataset
|
||||
dataset: kto_en_demo
|
||||
template: llama3
|
||||
template: qwen3_nothink
|
||||
cutoff_len: 2048
|
||||
max_samples: 1000
|
||||
overwrite_cache: true
|
||||
preprocessing_num_workers: 16
|
||||
dataloader_num_workers: 4
|
||||
|
||||
### output
|
||||
output_dir: saves/llama3-8b/lora/kto
|
||||
output_dir: saves/qwen3-4b/lora/kto
|
||||
logging_steps: 10
|
||||
save_steps: 500
|
||||
plot_loss: true
|
||||
@@ -1,5 +1,5 @@
|
||||
### model
|
||||
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
|
||||
model_name_or_path: Qwen/Qwen3-4B-Instruct-2507
|
||||
trust_remote_code: true
|
||||
|
||||
### method
|
||||
@@ -13,12 +13,11 @@ lora_target: all
|
||||
dataset: c4_demo
|
||||
cutoff_len: 2048
|
||||
max_samples: 1000
|
||||
overwrite_cache: true
|
||||
preprocessing_num_workers: 16
|
||||
dataloader_num_workers: 4
|
||||
|
||||
### output
|
||||
output_dir: saves/llama3-8b/lora/pretrain
|
||||
output_dir: saves/qwen3-4b/lora/pretrain
|
||||
logging_steps: 10
|
||||
save_steps: 500
|
||||
plot_loss: true
|
||||
@@ -1,5 +1,5 @@
|
||||
### model
|
||||
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
|
||||
model_name_or_path: Qwen/Qwen3-4B-Instruct-2507
|
||||
trust_remote_code: true
|
||||
|
||||
### method
|
||||
@@ -11,15 +11,14 @@ lora_target: all
|
||||
|
||||
### dataset
|
||||
dataset: dpo_en_demo
|
||||
template: llama3
|
||||
template: qwen3_nothink
|
||||
cutoff_len: 2048
|
||||
max_samples: 1000
|
||||
overwrite_cache: true
|
||||
preprocessing_num_workers: 16
|
||||
dataloader_num_workers: 4
|
||||
|
||||
### output
|
||||
output_dir: saves/llama3-8b/lora/reward
|
||||
output_dir: saves/qwen3-4b/lora/reward
|
||||
logging_steps: 10
|
||||
save_steps: 500
|
||||
plot_loss: true
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
set -x
|
||||
|
||||
MODEL_PATH=meta-llama/Meta-Llama-3-8B-Instruct
|
||||
MODEL_PATH=Qwen/Qwen3-4B-Instruct-2507
|
||||
|
||||
llamafactory-cli train \
|
||||
--model_name_or_path ${MODEL_PATH} \
|
||||
@@ -13,13 +13,12 @@ llamafactory-cli train \
|
||||
--lora_rank 8 \
|
||||
--lora_target all \
|
||||
--dataset identity,alpaca_en_demo \
|
||||
--template llama3 \
|
||||
--template qwen3_nothink \
|
||||
--cutoff_len 2048 \
|
||||
--max_samples 1000 \
|
||||
--overwrite_cache \
|
||||
--preprocessing_num_workers 16 \
|
||||
--dataloader_num_workers 4 \
|
||||
--output_dir saves/llama3-8b/lora/sft \
|
||||
--output_dir saves/qwen3-4b/lora/sft \
|
||||
--logging_steps 10 \
|
||||
--save_steps 500 \
|
||||
--plot_loss \
|
||||
@@ -1,5 +1,5 @@
|
||||
### model
|
||||
model_name_or_path: openai/gpt-oss-20b
|
||||
model_name_or_path: Qwen/Qwen3-4B-Instruct-2507
|
||||
trust_remote_code: true
|
||||
|
||||
### method
|
||||
@@ -11,15 +11,14 @@ lora_target: all
|
||||
|
||||
### dataset
|
||||
dataset: identity,alpaca_en_demo
|
||||
template: gpt
|
||||
template: qwen3_nothink
|
||||
cutoff_len: 2048
|
||||
max_samples: 1000
|
||||
overwrite_cache: true
|
||||
preprocessing_num_workers: 16
|
||||
dataloader_num_workers: 4
|
||||
|
||||
### output
|
||||
output_dir: saves/gpt-20b/lora/sft
|
||||
output_dir: saves/qwen3-4b/lora/sft
|
||||
logging_steps: 10
|
||||
save_steps: 500
|
||||
plot_loss: true
|
||||
@@ -1,5 +1,5 @@
|
||||
### model
|
||||
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
|
||||
model_name_or_path: Qwen/Qwen3-4B-Instruct-2507
|
||||
trust_remote_code: true
|
||||
|
||||
### method
|
||||
@@ -12,15 +12,14 @@ deepspeed: examples/deepspeed/ds_z3_config.json # choices: [ds_z0_config.json,
|
||||
|
||||
### dataset
|
||||
dataset: identity,alpaca_en_demo
|
||||
template: llama3
|
||||
template: qwen3_nothink
|
||||
cutoff_len: 2048
|
||||
max_samples: 1000
|
||||
overwrite_cache: true
|
||||
preprocessing_num_workers: 16
|
||||
dataloader_num_workers: 4
|
||||
|
||||
### output
|
||||
output_dir: saves/llama3-8b/lora/sft
|
||||
output_dir: saves/qwen3-4b/lora/sft
|
||||
logging_steps: 10
|
||||
save_steps: 500
|
||||
plot_loss: true
|
||||
@@ -1,5 +1,5 @@
|
||||
### model
|
||||
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct # or use local absolute path
|
||||
model_name_or_path: Qwen/Qwen3-4B-Instruct-2507 # or use local absolute path
|
||||
trust_remote_code: true
|
||||
|
||||
### method
|
||||
@@ -12,10 +12,9 @@ lora_target: all
|
||||
### dataset
|
||||
dataset: identity,alpaca_en_demo
|
||||
dataset_dir: REMOTE:llamafactory/demo_data # or use local absolute path
|
||||
template: llama3
|
||||
template: qwen3_nothink
|
||||
cutoff_len: 2048
|
||||
max_samples: 1000
|
||||
overwrite_cache: true
|
||||
preprocessing_num_workers: 16
|
||||
dataloader_num_workers: 4
|
||||
|
||||
@@ -29,7 +28,7 @@ save_only_model: false
|
||||
report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
|
||||
|
||||
### ray
|
||||
ray_run_name: llama3_8b_sft_lora
|
||||
ray_run_name: qwen3_4b_sft_lora
|
||||
ray_storage_path: ./saves
|
||||
ray_num_workers: 4 # Number of GPUs to use.
|
||||
placement_strategy: PACK
|
||||
@@ -1,5 +1,5 @@
|
||||
### model
|
||||
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
|
||||
model_name_or_path: Qwen/Qwen3-4B-Instruct-2507
|
||||
trust_remote_code: true
|
||||
|
||||
### method
|
||||
@@ -11,13 +11,12 @@ lora_target: all
|
||||
|
||||
### dataset
|
||||
dataset: identity,alpaca_en_demo
|
||||
template: llama3
|
||||
template: qwen3_nothink
|
||||
cutoff_len: 2048
|
||||
max_samples: 1000
|
||||
overwrite_cache: true
|
||||
preprocessing_num_workers: 16
|
||||
tokenized_path: saves/llama3-8b/dataset/sft
|
||||
tokenized_path: saves/qwen3-4b/dataset/sft
|
||||
|
||||
### output
|
||||
output_dir: saves/llama3-8b/lora/sft
|
||||
### output (not used)
|
||||
output_dir: saves/qwen3-4b/lora/sft
|
||||
overwrite_output_dir: true
|
||||
@@ -1,5 +1,5 @@
|
||||
### model
|
||||
model_name_or_path: Qwen/Qwen2.5-VL-7B-Instruct
|
||||
model_name_or_path: Qwen/Qwen3-VL-4B-Instruct
|
||||
image_max_pixels: 262144
|
||||
video_max_pixels: 16384
|
||||
trust_remote_code: true
|
||||
@@ -15,15 +15,14 @@ pref_loss: sigmoid # choices: [sigmoid (dpo), orpo, simpo]
|
||||
|
||||
### dataset
|
||||
dataset: rlhf_v
|
||||
template: qwen2_vl
|
||||
template: qwen3_vl_nothink
|
||||
cutoff_len: 2048
|
||||
max_samples: 1000
|
||||
overwrite_cache: true
|
||||
preprocessing_num_workers: 16
|
||||
dataloader_num_workers: 4
|
||||
|
||||
### output
|
||||
output_dir: saves/qwen2_5vl-7b/lora/dpo
|
||||
output_dir: saves/qwen3-vl-4b/lora/dpo
|
||||
logging_steps: 10
|
||||
save_steps: 500
|
||||
plot_loss: true
|
||||
@@ -1,5 +1,5 @@
|
||||
### model
|
||||
model_name_or_path: Qwen/Qwen2.5-VL-7B-Instruct
|
||||
model_name_or_path: Qwen/Qwen3-VL-4B-Instruct
|
||||
image_max_pixels: 262144
|
||||
video_max_pixels: 16384
|
||||
trust_remote_code: true
|
||||
@@ -13,15 +13,14 @@ lora_target: all
|
||||
|
||||
### dataset
|
||||
dataset: mllm_demo,identity,alpaca_en_demo # video: mllm_video_demo
|
||||
template: qwen2_vl
|
||||
template: qwen3_vl_nothink
|
||||
cutoff_len: 2048
|
||||
max_samples: 1000
|
||||
overwrite_cache: true
|
||||
preprocessing_num_workers: 16
|
||||
dataloader_num_workers: 4
|
||||
|
||||
### output
|
||||
output_dir: saves/qwen2_5vl-7b/lora/sft
|
||||
output_dir: saves/qwen3-vl-4b/lora/sft
|
||||
logging_steps: 10
|
||||
save_steps: 500
|
||||
plot_loss: true
|
||||
@@ -14,7 +14,6 @@ dataset: identity,alpaca_en_demo
|
||||
template: llama3
|
||||
cutoff_len: 2048
|
||||
max_samples: 1000
|
||||
overwrite_cache: true
|
||||
preprocessing_num_workers: 16
|
||||
dataloader_num_workers: 4
|
||||
|
||||
|
||||
@@ -14,7 +14,6 @@ dataset: identity,alpaca_en_demo
|
||||
template: llama3
|
||||
cutoff_len: 2048
|
||||
max_samples: 1000
|
||||
overwrite_cache: true
|
||||
preprocessing_num_workers: 16
|
||||
dataloader_num_workers: 4
|
||||
|
||||
|
||||
@@ -14,7 +14,6 @@ dataset: identity,alpaca_en_demo
|
||||
template: llama3
|
||||
cutoff_len: 2048
|
||||
max_samples: 1000
|
||||
overwrite_cache: true
|
||||
preprocessing_num_workers: 16
|
||||
dataloader_num_workers: 4
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
### model
|
||||
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
|
||||
model_name_or_path: Qwen/Qwen3-4B-Instruct-2507
|
||||
quantization_bit: 4
|
||||
quantization_method: bnb
|
||||
double_quantization: false
|
||||
@@ -14,15 +14,14 @@ lora_target: all
|
||||
|
||||
### dataset
|
||||
dataset: identity,alpaca_en_demo
|
||||
template: llama3
|
||||
template: qwen3_nothink
|
||||
cutoff_len: 2048
|
||||
max_samples: 1000
|
||||
overwrite_cache: true
|
||||
preprocessing_num_workers: 16
|
||||
dataloader_num_workers: 4
|
||||
|
||||
### output
|
||||
output_dir: saves/llama3-8b/lora/sft
|
||||
output_dir: saves/qwen3-4b/lora/sft
|
||||
logging_steps: 10
|
||||
save_steps: 500
|
||||
plot_loss: true
|
||||
@@ -1,5 +1,5 @@
|
||||
### model
|
||||
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
|
||||
model_name_or_path: Qwen/Qwen3-4B-Instruct-2507
|
||||
quantization_bit: 4 # choices: [8 (bnb/hqq/eetq), 4 (bnb/hqq), 3 (hqq), 2 (hqq)]
|
||||
quantization_method: bnb # choices: [bnb, hqq, eetq]
|
||||
trust_remote_code: true
|
||||
@@ -13,15 +13,14 @@ lora_target: all
|
||||
|
||||
### dataset
|
||||
dataset: identity,alpaca_en_demo
|
||||
template: llama3
|
||||
template: qwen3_nothink
|
||||
cutoff_len: 2048
|
||||
max_samples: 1000
|
||||
overwrite_cache: true
|
||||
preprocessing_num_workers: 16
|
||||
dataloader_num_workers: 4
|
||||
|
||||
### output
|
||||
output_dir: saves/llama3-8b/lora/sft
|
||||
output_dir: saves/qwen3-4b/lora/sft
|
||||
logging_steps: 10
|
||||
save_steps: 500
|
||||
plot_loss: true
|
||||
@@ -8,7 +8,7 @@ dynamic = ["version"]
|
||||
description = "Unified Efficient Fine-Tuning of 100+ LLMs"
|
||||
readme = "README.md"
|
||||
license = "Apache-2.0"
|
||||
requires-python = ">=3.9.0"
|
||||
requires-python = ">=3.11.0"
|
||||
authors = [
|
||||
{ name = "hiyouga", email = "hiyouga@buaa.edu.cn" }
|
||||
]
|
||||
@@ -30,58 +30,54 @@ classifiers = [
|
||||
"License :: OSI Approved :: Apache Software License",
|
||||
"Operating System :: OS Independent",
|
||||
"Programming Language :: Python :: 3",
|
||||
"Programming Language :: Python :: 3.9",
|
||||
"Programming Language :: Python :: 3.10",
|
||||
"Programming Language :: Python :: 3.11",
|
||||
"Programming Language :: Python :: 3.12",
|
||||
"Programming Language :: Python :: 3.13",
|
||||
"Topic :: Scientific/Engineering :: Artificial Intelligence"
|
||||
]
|
||||
dependencies = [
|
||||
# core deps
|
||||
"transformers>=4.49.0,<=4.56.2,!=4.52.0; python_version < '3.10'",
|
||||
"transformers>=4.49.0,<=4.57.1,!=4.52.0,!=4.57.0; python_version >= '3.10'",
|
||||
"torch>=2.4.0",
|
||||
"torchvision>=0.19.0",
|
||||
"torchaudio>=2.4.0",
|
||||
"transformers>=4.51.0,<=4.57.1,!=4.52.0,!=4.57.0",
|
||||
"datasets>=2.16.0,<=4.0.0",
|
||||
"accelerate>=1.3.0,<=1.11.0",
|
||||
"peft>=0.14.0,<=0.17.1",
|
||||
"trl>=0.8.6,<=0.9.6",
|
||||
"torchdata",
|
||||
# torch
|
||||
"torch>=2.0.0",
|
||||
"torchvision>=0.15.0",
|
||||
"trl>=0.18.0,<=0.24.0",
|
||||
"torchdata>=0.10.0,<=0.11.0",
|
||||
# gui
|
||||
"gradio>=4.38.0,<=5.45.0",
|
||||
"gradio>=4.38.0,<=5.50.0",
|
||||
"matplotlib>=3.7.0",
|
||||
"tyro<0.9.0",
|
||||
# ops
|
||||
"einops",
|
||||
"numpy<2.0.0",
|
||||
"pandas>=2.0.0",
|
||||
"numpy",
|
||||
"pandas",
|
||||
"scipy",
|
||||
# model and tokenizer
|
||||
"sentencepiece",
|
||||
"tiktoken",
|
||||
"modelscope>=1.14.0",
|
||||
"modelscope",
|
||||
"hf-transfer",
|
||||
"safetensors<=0.5.3",
|
||||
"safetensors",
|
||||
# python
|
||||
"av",
|
||||
"fire",
|
||||
"omegaconf",
|
||||
"packaging",
|
||||
"protobuf",
|
||||
"pyyaml",
|
||||
"pydantic<=2.10.6",
|
||||
"pydantic",
|
||||
# api
|
||||
"uvicorn",
|
||||
"fastapi",
|
||||
"sse-starlette",
|
||||
# media
|
||||
"av",
|
||||
"librosa",
|
||||
# yanked
|
||||
"propcache!=0.4.0"
|
||||
"sse-starlette"
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
dev = ["pre-commit", "ruff", "pytest", "build"]
|
||||
metrics = ["nltk", "jieba", "rouge-chinese"]
|
||||
deepspeed = ["deepspeed>=0.10.0,<=0.16.9"]
|
||||
|
||||
@@ -101,24 +97,26 @@ path = "src/llamafactory/extras/env.py"
|
||||
pattern = "VERSION = \"(?P<version>[^\"]+)\""
|
||||
|
||||
[tool.ruff]
|
||||
target-version = "py39"
|
||||
target-version = "py311"
|
||||
line-length = 119
|
||||
indent-width = 4
|
||||
|
||||
[tool.ruff.lint]
|
||||
ignore = [
|
||||
"C408", # collection
|
||||
"C901", # complex
|
||||
"E501", # line too long
|
||||
"E731", # lambda function
|
||||
"E741", # ambiguous var name
|
||||
"D100", # no doc public module
|
||||
"D101", # no doc public class
|
||||
"D102", # no doc public method
|
||||
"D103", # no doc public function
|
||||
"D104", # no doc public package
|
||||
"D105", # no doc magic method
|
||||
"D107", # no doc __init__
|
||||
"C408", # collection
|
||||
"C901", # complex
|
||||
"E501", # line too long
|
||||
"E731", # lambda function
|
||||
"E741", # ambiguous var name
|
||||
"UP007", # no upgrade union
|
||||
"UP045", # no upgrade optional
|
||||
"D100", # no doc public module
|
||||
"D101", # no doc public class
|
||||
"D102", # no doc public method
|
||||
"D103", # no doc public function
|
||||
"D104", # no doc public package
|
||||
"D105", # no doc magic method
|
||||
"D107", # no doc __init__
|
||||
]
|
||||
extend-select = [
|
||||
"C", # complexity
|
||||
|
||||
@@ -16,7 +16,6 @@
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
import fire
|
||||
import torch
|
||||
@@ -34,7 +33,7 @@ def convert_mca_to_hf(
|
||||
output_path: str = "./output",
|
||||
bf16: bool = False,
|
||||
fp16: bool = False,
|
||||
convert_model_max_length: Optional[int] = None,
|
||||
convert_model_max_length: int | None = None,
|
||||
):
|
||||
"""Convert megatron checkpoint to HuggingFace format.
|
||||
|
||||
@@ -67,11 +66,11 @@ def convert(
|
||||
output_path: str = "./output",
|
||||
bf16: bool = False,
|
||||
fp16: bool = False,
|
||||
convert_model_max_length: Optional[int] = None,
|
||||
convert_model_max_length: int | None = None,
|
||||
tensor_model_parallel_size: int = 1,
|
||||
pipeline_model_parallel_size: int = 1,
|
||||
expert_model_parallel_size: int = 1,
|
||||
virtual_pipeline_model_parallel_size: Optional[int] = None,
|
||||
virtual_pipeline_model_parallel_size: int | None = None,
|
||||
):
|
||||
"""Convert checkpoint between MCA and HuggingFace formats.
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Literal, Optional
|
||||
from typing import Any, Literal
|
||||
|
||||
import fire
|
||||
import torch
|
||||
@@ -61,7 +61,7 @@ def calculate_ppl(
|
||||
dataset_dir: str = "data",
|
||||
template: str = "default",
|
||||
cutoff_len: int = 2048,
|
||||
max_samples: Optional[int] = None,
|
||||
max_samples: int | None = None,
|
||||
train_on_prompt: bool = False,
|
||||
):
|
||||
r"""Calculate the ppl on the dataset of the pre-trained models.
|
||||
|
||||
@@ -14,10 +14,12 @@
|
||||
|
||||
import gc
|
||||
import json
|
||||
from typing import Optional
|
||||
import time
|
||||
|
||||
import av
|
||||
import fire
|
||||
from datasets import load_dataset
|
||||
from eval_bleu_rouge import compute_metrics
|
||||
from tqdm import tqdm
|
||||
from transformers import Seq2SeqTrainingArguments
|
||||
|
||||
@@ -49,18 +51,19 @@ def vllm_infer(
|
||||
dataset_dir: str = "data",
|
||||
template: str = "default",
|
||||
cutoff_len: int = 2048,
|
||||
max_samples: Optional[int] = None,
|
||||
max_samples: int | None = None,
|
||||
vllm_config: str = "{}",
|
||||
save_name: str = "generated_predictions.jsonl",
|
||||
matrix_save_name: str = None,
|
||||
temperature: float = 0.95,
|
||||
top_p: float = 0.7,
|
||||
top_k: int = 50,
|
||||
max_new_tokens: int = 1024,
|
||||
repetition_penalty: float = 1.0,
|
||||
skip_special_tokens: bool = True,
|
||||
default_system: Optional[str] = None,
|
||||
default_system: str | None = None,
|
||||
enable_thinking: bool = True,
|
||||
seed: Optional[int] = None,
|
||||
seed: int | None = None,
|
||||
pipeline_parallel_size: int = 1,
|
||||
image_max_pixels: int = 768 * 768,
|
||||
image_min_pixels: int = 32 * 32,
|
||||
@@ -118,6 +121,7 @@ def vllm_infer(
|
||||
if isinstance(model_args.vllm_config, dict):
|
||||
engine_args.update(model_args.vllm_config)
|
||||
|
||||
model_preparation_start_time = time.time()
|
||||
llm = LLM(**engine_args)
|
||||
|
||||
# load datasets
|
||||
@@ -143,6 +147,7 @@ def vllm_infer(
|
||||
all_prompts, all_preds, all_labels = [], [], []
|
||||
need_video_kwargs = _need_video_kwargs(template)
|
||||
|
||||
model_predict_start_time = time.time()
|
||||
# Add batch process to avoid the issue of too many files opened
|
||||
for i in tqdm(range(0, len(train_dataset), batch_size), desc="Processing batched inference"):
|
||||
vllm_inputs, prompts, labels = [], [], []
|
||||
@@ -219,6 +224,7 @@ def vllm_infer(
|
||||
all_labels.extend(labels)
|
||||
gc.collect()
|
||||
|
||||
model_predict_end_time = time.time()
|
||||
# Write all results at once outside the loop
|
||||
with open(save_name, "w", encoding="utf-8") as f:
|
||||
for text, pred, label in zip(all_prompts, all_preds, all_labels):
|
||||
@@ -228,6 +234,49 @@ def vllm_infer(
|
||||
print(f"{len(all_prompts)} total generated results have been saved at {save_name}.")
|
||||
print("*" * 70)
|
||||
|
||||
# Write all matrix results when matrix_save_name is not None,
|
||||
# The result matrix is referencing src.llamafactory.train.sft.workflow.run_sft # 127~132
|
||||
# trainer.save_metrics("predict", predict_results.metrics)
|
||||
#
|
||||
# {
|
||||
# "predict_bleu-4": 4.349975,
|
||||
# "predict_model_preparation_time": 0.0128,
|
||||
# "predict_rouge-1": 21.873359375,
|
||||
# "predict_rouge-2": 4.144340625,
|
||||
# "predict_rouge-l": 10.83949375,
|
||||
# "predict_runtime": 131.664,
|
||||
# "predict_samples_per_second": 0.076,
|
||||
# "predict_steps_per_second": 0.008
|
||||
# }
|
||||
#
|
||||
if matrix_save_name is not None:
|
||||
predict_time = model_predict_end_time - model_predict_start_time
|
||||
preparation_time = model_predict_start_time - model_preparation_start_time
|
||||
|
||||
start_time = time.time()
|
||||
dataset = load_dataset("json", data_files=save_name, split="train")
|
||||
dataset = dataset.map(compute_metrics, num_proc=8, remove_columns=dataset.column_names)
|
||||
score_dict = dataset.to_dict()
|
||||
|
||||
average_score = {}
|
||||
for task, scores in sorted(score_dict.items(), key=lambda x: x[0]):
|
||||
score = sum(scores) / len(scores) if scores else 0.0
|
||||
print(f"predict_{task}: {score:.4f}")
|
||||
average_score["predict_" + task] = score
|
||||
|
||||
average_score["predict_model_preparation_time"] = preparation_time
|
||||
average_score["predict_runtime"] = predict_time
|
||||
num_steps = len(range(0, len(train_dataset), batch_size))
|
||||
average_score["predict_samples_per_second"] = len(dataset) / predict_time if predict_time > 0 else 0.0
|
||||
average_score["predict_steps_per_second"] = num_steps / predict_time if predict_time > 0 else 0.0
|
||||
|
||||
with open(matrix_save_name, "w", encoding="utf-8") as f:
|
||||
json.dump(average_score, f, indent=4)
|
||||
|
||||
print("*" * 70)
|
||||
print(f"\nDone in {time.time() - start_time:.3f}s.\nScore file saved to {matrix_save_name}.")
|
||||
print("*" * 70)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(vllm_infer)
|
||||
|
||||
@@ -16,7 +16,7 @@ import asyncio
|
||||
import os
|
||||
from contextlib import asynccontextmanager
|
||||
from functools import partial
|
||||
from typing import Annotated, Optional
|
||||
from typing import Annotated
|
||||
|
||||
from ..chat import ChatModel
|
||||
from ..extras.constants import EngineName
|
||||
@@ -79,7 +79,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
|
||||
api_key = os.getenv("API_KEY")
|
||||
security = HTTPBearer(auto_error=False)
|
||||
|
||||
async def verify_api_key(auth: Annotated[Optional[HTTPAuthorizationCredentials], Depends(security)]):
|
||||
async def verify_api_key(auth: Annotated[HTTPAuthorizationCredentials | None, Depends(security)]):
|
||||
if api_key and (auth is None or auth.credentials != api_key):
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key.")
|
||||
|
||||
|
||||
@@ -14,10 +14,9 @@
|
||||
|
||||
import time
|
||||
from enum import Enum, unique
|
||||
from typing import Any, Optional, Union
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import Literal
|
||||
|
||||
|
||||
@unique
|
||||
@@ -61,7 +60,7 @@ class FunctionDefinition(BaseModel):
|
||||
|
||||
class FunctionAvailable(BaseModel):
|
||||
type: Literal["function", "code_interpreter"] = "function"
|
||||
function: Optional[FunctionDefinition] = None
|
||||
function: FunctionDefinition | None = None
|
||||
|
||||
|
||||
class FunctionCall(BaseModel):
|
||||
@@ -77,35 +76,35 @@ class URL(BaseModel):
|
||||
|
||||
class MultimodalInputItem(BaseModel):
|
||||
type: Literal["text", "image_url", "video_url", "audio_url"]
|
||||
text: Optional[str] = None
|
||||
image_url: Optional[URL] = None
|
||||
video_url: Optional[URL] = None
|
||||
audio_url: Optional[URL] = None
|
||||
text: str | None = None
|
||||
image_url: URL | None = None
|
||||
video_url: URL | None = None
|
||||
audio_url: URL | None = None
|
||||
|
||||
|
||||
class ChatMessage(BaseModel):
|
||||
role: Role
|
||||
content: Optional[Union[str, list[MultimodalInputItem]]] = None
|
||||
tool_calls: Optional[list[FunctionCall]] = None
|
||||
content: str | list[MultimodalInputItem] | None = None
|
||||
tool_calls: list[FunctionCall] | None = None
|
||||
|
||||
|
||||
class ChatCompletionMessage(BaseModel):
|
||||
role: Optional[Role] = None
|
||||
content: Optional[str] = None
|
||||
tool_calls: Optional[list[FunctionCall]] = None
|
||||
role: Role | None = None
|
||||
content: str | None = None
|
||||
tool_calls: list[FunctionCall] | None = None
|
||||
|
||||
|
||||
class ChatCompletionRequest(BaseModel):
|
||||
model: str
|
||||
messages: list[ChatMessage]
|
||||
tools: Optional[list[FunctionAvailable]] = None
|
||||
do_sample: Optional[bool] = None
|
||||
temperature: Optional[float] = None
|
||||
top_p: Optional[float] = None
|
||||
tools: list[FunctionAvailable] | None = None
|
||||
do_sample: bool | None = None
|
||||
temperature: float | None = None
|
||||
top_p: float | None = None
|
||||
n: int = 1
|
||||
presence_penalty: Optional[float] = None
|
||||
max_tokens: Optional[int] = None
|
||||
stop: Optional[Union[str, list[str]]] = None
|
||||
presence_penalty: float | None = None
|
||||
max_tokens: int | None = None
|
||||
stop: str | list[str] | None = None
|
||||
stream: bool = False
|
||||
|
||||
|
||||
@@ -118,7 +117,7 @@ class ChatCompletionResponseChoice(BaseModel):
|
||||
class ChatCompletionStreamResponseChoice(BaseModel):
|
||||
index: int
|
||||
delta: ChatCompletionMessage
|
||||
finish_reason: Optional[Finish] = None
|
||||
finish_reason: Finish | None = None
|
||||
|
||||
|
||||
class ChatCompletionResponseUsage(BaseModel):
|
||||
@@ -147,7 +146,7 @@ class ChatCompletionStreamResponse(BaseModel):
|
||||
class ScoreEvaluationRequest(BaseModel):
|
||||
model: str
|
||||
messages: list[str]
|
||||
max_length: Optional[int] = None
|
||||
max_length: int | None = None
|
||||
|
||||
|
||||
class ScoreEvaluationResponse(BaseModel):
|
||||
|
||||
@@ -14,9 +14,9 @@
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from collections.abc import AsyncGenerator
|
||||
from collections.abc import AsyncGenerator, Callable
|
||||
from threading import Thread
|
||||
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
|
||||
import torch
|
||||
from transformers import GenerationConfig, TextIteratorStreamer
|
||||
|
||||
@@ -15,7 +15,7 @@ import json
|
||||
import os
|
||||
from abc import abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
from typing import TYPE_CHECKING, Any, Union
|
||||
|
||||
from ..extras import logging
|
||||
from .data_utils import Role
|
||||
@@ -40,7 +40,7 @@ class DatasetConverter:
|
||||
dataset_attr: "DatasetAttr"
|
||||
data_args: "DataArguments"
|
||||
|
||||
def _find_medias(self, medias: Union["MediaType", list["MediaType"], None]) -> Optional[list["MediaType"]]:
|
||||
def _find_medias(self, medias: Union["MediaType", list["MediaType"], None]) -> list["MediaType"] | None:
|
||||
r"""Optionally concatenate media path to media dir when loading from local disk."""
|
||||
if medias is None:
|
||||
return None
|
||||
|
||||
@@ -16,7 +16,6 @@ import json
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional, Union
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
@@ -27,14 +26,14 @@ from .tool_utils import FunctionCall, get_tool_utils
|
||||
@dataclass
|
||||
class Formatter(ABC):
|
||||
slots: SLOTS = field(default_factory=list)
|
||||
tool_format: Optional[str] = None
|
||||
tool_format: str | None = None
|
||||
|
||||
@abstractmethod
|
||||
def apply(self, **kwargs) -> SLOTS:
|
||||
r"""Forms a list of slots according to the inputs to encode."""
|
||||
...
|
||||
|
||||
def extract(self, content: str) -> Union[str, list["FunctionCall"]]:
|
||||
def extract(self, content: str) -> str | list["FunctionCall"]:
|
||||
r"""Extract a list of tuples from the response message if using tools.
|
||||
|
||||
Each tuple consists of function name and function arguments.
|
||||
@@ -156,5 +155,5 @@ class ToolFormatter(Formatter):
|
||||
raise RuntimeError(f"Invalid JSON format in tool description: {str([content])}.") # flat string
|
||||
|
||||
@override
|
||||
def extract(self, content: str) -> Union[str, list["FunctionCall"]]:
|
||||
def extract(self, content: str) -> str | list["FunctionCall"]:
|
||||
return self.tool_utils.tool_extractor(content)
|
||||
|
||||
@@ -162,13 +162,13 @@ def _load_single_dataset(
|
||||
|
||||
|
||||
def _get_merged_dataset(
|
||||
dataset_names: Optional[list[str]],
|
||||
dataset_names: list[str] | None,
|
||||
model_args: "ModelArguments",
|
||||
data_args: "DataArguments",
|
||||
training_args: "Seq2SeqTrainingArguments",
|
||||
stage: Literal["pt", "sft", "rm", "ppo", "kto"],
|
||||
return_dict: bool = False,
|
||||
) -> Optional[Union["Dataset", "IterableDataset", dict[str, "Dataset"]]]:
|
||||
) -> Union["Dataset", "IterableDataset", dict[str, "Dataset"]] | None:
|
||||
r"""Return the merged datasets in the standard format."""
|
||||
if dataset_names is None:
|
||||
return None
|
||||
@@ -227,7 +227,7 @@ def _get_dataset_processor(
|
||||
|
||||
|
||||
def _get_preprocessed_dataset(
|
||||
dataset: Optional[Union["Dataset", "IterableDataset"]],
|
||||
dataset: Union["Dataset", "IterableDataset"] | None,
|
||||
data_args: "DataArguments",
|
||||
training_args: "Seq2SeqTrainingArguments",
|
||||
stage: Literal["pt", "sft", "rm", "ppo", "kto"],
|
||||
@@ -235,7 +235,7 @@ def _get_preprocessed_dataset(
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
processor: Optional["ProcessorMixin"] = None,
|
||||
is_eval: bool = False,
|
||||
) -> Optional[Union["Dataset", "IterableDataset"]]:
|
||||
) -> Union["Dataset", "IterableDataset"] | None:
|
||||
r"""Preprocesses the dataset, including format checking and tokenization."""
|
||||
if dataset is None:
|
||||
return None
|
||||
|
||||
@@ -22,28 +22,20 @@ import re
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass
|
||||
from io import BytesIO
|
||||
from typing import TYPE_CHECKING, BinaryIO, Literal, Optional, TypedDict, Union
|
||||
from typing import TYPE_CHECKING, BinaryIO, Literal, NotRequired, Optional, TypedDict, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchaudio
|
||||
from transformers.image_utils import get_image_size, is_valid_image, to_numpy_array
|
||||
from transformers.models.mllama.processing_mllama import (
|
||||
convert_sparse_cross_attention_mask_to_dense,
|
||||
get_cross_attention_token_mask,
|
||||
)
|
||||
from typing_extensions import NotRequired, override
|
||||
from typing_extensions import override
|
||||
|
||||
from ..extras.constants import AUDIO_PLACEHOLDER, IGNORE_INDEX, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
|
||||
from ..extras.packages import (
|
||||
is_librosa_available,
|
||||
is_pillow_available,
|
||||
is_pyav_available,
|
||||
is_transformers_version_greater_than,
|
||||
)
|
||||
|
||||
|
||||
if is_librosa_available():
|
||||
import librosa
|
||||
from ..extras.packages import is_pillow_available, is_pyav_available, is_transformers_version_greater_than
|
||||
|
||||
|
||||
if is_pillow_available():
|
||||
@@ -71,8 +63,8 @@ if TYPE_CHECKING:
|
||||
from transformers.video_processing_utils import BaseVideoProcessor
|
||||
|
||||
class EncodedImage(TypedDict):
|
||||
path: Optional[str]
|
||||
bytes: Optional[bytes]
|
||||
path: str | None
|
||||
bytes: bytes | None
|
||||
|
||||
ImageInput = Union[str, bytes, EncodedImage, BinaryIO, ImageObject]
|
||||
VideoInput = Union[str, BinaryIO, list[list[ImageInput]]]
|
||||
@@ -152,9 +144,9 @@ def _check_video_is_nested_images(video: "VideoInput") -> bool:
|
||||
|
||||
@dataclass
|
||||
class MMPluginMixin:
|
||||
image_token: Optional[str]
|
||||
video_token: Optional[str]
|
||||
audio_token: Optional[str]
|
||||
image_token: str | None
|
||||
video_token: str | None
|
||||
audio_token: str | None
|
||||
expand_mm_tokens: bool = True
|
||||
|
||||
def _validate_input(
|
||||
@@ -316,7 +308,14 @@ class MMPluginMixin:
|
||||
results, sampling_rates = [], []
|
||||
for audio in audios:
|
||||
if not isinstance(audio, np.ndarray):
|
||||
audio, sampling_rate = librosa.load(audio, sr=sampling_rate)
|
||||
audio, sr = torchaudio.load(audio)
|
||||
if audio.shape[0] > 1:
|
||||
audio = audio.mean(dim=0, keepdim=True)
|
||||
|
||||
if sr != sampling_rate:
|
||||
audio = torchaudio.functional.resample(audio, sr, sampling_rate)
|
||||
|
||||
audio = audio.squeeze(0).numpy()
|
||||
|
||||
results.append(audio)
|
||||
sampling_rates.append(sampling_rate)
|
||||
@@ -329,7 +328,7 @@ class MMPluginMixin:
|
||||
videos: list["VideoInput"],
|
||||
audios: list["AudioInput"],
|
||||
processor: "MMProcessor",
|
||||
imglens: Optional[list[int]] = None,
|
||||
imglens: list[int] | None = None,
|
||||
) -> dict[str, "torch.Tensor"]:
|
||||
r"""Process visual inputs.
|
||||
|
||||
@@ -427,13 +426,13 @@ class BasePlugin(MMPluginMixin):
|
||||
def process_token_ids(
|
||||
self,
|
||||
input_ids: list[int],
|
||||
labels: Optional[list[int]],
|
||||
labels: list[int] | None,
|
||||
images: list["ImageInput"],
|
||||
videos: list["VideoInput"],
|
||||
audios: list["AudioInput"],
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
processor: Optional["MMProcessor"],
|
||||
) -> tuple[list[int], Optional[list[int]]]:
|
||||
) -> tuple[list[int], list[int] | None]:
|
||||
r"""Pre-process token ids after tokenization for VLMs."""
|
||||
self._validate_input(processor, images, videos, audios)
|
||||
return input_ids, labels
|
||||
@@ -500,13 +499,17 @@ class ErnieVLPlugin(BasePlugin):
|
||||
while IMAGE_PLACEHOLDER in content:
|
||||
image_seqlen = image_grid_thw[image_idx].prod() // merge_length if self.expand_mm_tokens else 1
|
||||
content = content.replace(
|
||||
IMAGE_PLACEHOLDER, f"Picture {image_idx + 1}:<|IMAGE_START|>{image_token * image_seqlen}<|IMAGE_END|>", 1
|
||||
IMAGE_PLACEHOLDER,
|
||||
f"Picture {image_idx + 1}:<|IMAGE_START|>{image_token * image_seqlen}<|IMAGE_END|>",
|
||||
1,
|
||||
)
|
||||
image_idx += 1
|
||||
while VIDEO_PLACEHOLDER in content:
|
||||
video_seqlen = video_grid_thw[video_idx].prod() // merge_length if self.expand_mm_tokens else 1
|
||||
content = content.replace(
|
||||
VIDEO_PLACEHOLDER, f"Video {video_idx + 1}:<|VIDEO_START|>{video_token * video_seqlen}<|VIDEO_END|>", 1
|
||||
VIDEO_PLACEHOLDER,
|
||||
f"Video {video_idx + 1}:<|VIDEO_START|>{video_token * video_seqlen}<|VIDEO_END|>",
|
||||
1,
|
||||
)
|
||||
video_idx += 1
|
||||
message["content"] = content
|
||||
@@ -1302,13 +1305,13 @@ class PaliGemmaPlugin(BasePlugin):
|
||||
def process_token_ids(
|
||||
self,
|
||||
input_ids: list[int],
|
||||
labels: Optional[list[int]],
|
||||
labels: list[int] | None,
|
||||
images: list["ImageInput"],
|
||||
videos: list["VideoInput"],
|
||||
audios: list["AudioInput"],
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
processor: Optional["MMProcessor"],
|
||||
) -> tuple[list[int], Optional[list[int]]]:
|
||||
) -> tuple[list[int], list[int] | None]:
|
||||
self._validate_input(processor, images, videos, audios)
|
||||
num_images = len(images)
|
||||
image_seqlen = processor.image_seq_length if self.expand_mm_tokens else 0 # skip mm token
|
||||
@@ -2123,9 +2126,9 @@ def register_mm_plugin(name: str, plugin_class: type["BasePlugin"]) -> None:
|
||||
|
||||
def get_mm_plugin(
|
||||
name: str,
|
||||
image_token: Optional[str] = None,
|
||||
video_token: Optional[str] = None,
|
||||
audio_token: Optional[str] = None,
|
||||
image_token: str | None = None,
|
||||
video_token: str | None = None,
|
||||
audio_token: str | None = None,
|
||||
**kwargs,
|
||||
) -> "BasePlugin":
|
||||
r"""Get plugin for multimodal inputs."""
|
||||
|
||||
@@ -15,7 +15,7 @@
|
||||
import json
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Literal, Optional, Union
|
||||
from typing import Any, Literal
|
||||
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
@@ -33,40 +33,40 @@ class DatasetAttr:
|
||||
formatting: Literal["alpaca", "sharegpt", "openai"] = "alpaca"
|
||||
ranking: bool = False
|
||||
# extra configs
|
||||
subset: Optional[str] = None
|
||||
subset: str | None = None
|
||||
split: str = "train"
|
||||
folder: Optional[str] = None
|
||||
num_samples: Optional[int] = None
|
||||
folder: str | None = None
|
||||
num_samples: int | None = None
|
||||
# common columns
|
||||
system: Optional[str] = None
|
||||
tools: Optional[str] = None
|
||||
images: Optional[str] = None
|
||||
videos: Optional[str] = None
|
||||
audios: Optional[str] = None
|
||||
system: str | None = None
|
||||
tools: str | None = None
|
||||
images: str | None = None
|
||||
videos: str | None = None
|
||||
audios: str | None = None
|
||||
# dpo columns
|
||||
chosen: Optional[str] = None
|
||||
rejected: Optional[str] = None
|
||||
kto_tag: Optional[str] = None
|
||||
chosen: str | None = None
|
||||
rejected: str | None = None
|
||||
kto_tag: str | None = None
|
||||
# alpaca columns
|
||||
prompt: Optional[str] = "instruction"
|
||||
query: Optional[str] = "input"
|
||||
response: Optional[str] = "output"
|
||||
history: Optional[str] = None
|
||||
prompt: str | None = "instruction"
|
||||
query: str | None = "input"
|
||||
response: str | None = "output"
|
||||
history: str | None = None
|
||||
# sharegpt columns
|
||||
messages: Optional[str] = "conversations"
|
||||
messages: str | None = "conversations"
|
||||
# sharegpt tags
|
||||
role_tag: Optional[str] = "from"
|
||||
content_tag: Optional[str] = "value"
|
||||
user_tag: Optional[str] = "human"
|
||||
assistant_tag: Optional[str] = "gpt"
|
||||
observation_tag: Optional[str] = "observation"
|
||||
function_tag: Optional[str] = "function_call"
|
||||
system_tag: Optional[str] = "system"
|
||||
role_tag: str | None = "from"
|
||||
content_tag: str | None = "value"
|
||||
user_tag: str | None = "human"
|
||||
assistant_tag: str | None = "gpt"
|
||||
observation_tag: str | None = "observation"
|
||||
function_tag: str | None = "function_call"
|
||||
system_tag: str | None = "system"
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return self.dataset_name
|
||||
|
||||
def set_attr(self, key: str, obj: dict[str, Any], default: Optional[Any] = None) -> None:
|
||||
def set_attr(self, key: str, obj: dict[str, Any], default: Any | None = None) -> None:
|
||||
setattr(self, key, obj.get(key, default))
|
||||
|
||||
def join(self, attr: dict[str, Any]) -> None:
|
||||
@@ -90,7 +90,7 @@ class DatasetAttr:
|
||||
self.set_attr(tag, attr["tags"])
|
||||
|
||||
|
||||
def get_dataset_list(dataset_names: Optional[list[str]], dataset_dir: Union[str, dict]) -> list["DatasetAttr"]:
|
||||
def get_dataset_list(dataset_names: list[str] | None, dataset_dir: str | dict) -> list["DatasetAttr"]:
|
||||
r"""Get the attributes of the datasets."""
|
||||
if dataset_names is None:
|
||||
dataset_names = []
|
||||
|
||||
@@ -1673,6 +1673,43 @@ register_template(
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="minimax1",
|
||||
format_user=StringFormatter(
|
||||
slots=[
|
||||
"<beginning_of_sentence>user name=user\n{{content}}<end_of_sentence>\n<beginning_of_sentence>ai name=assistant\n"
|
||||
]
|
||||
),
|
||||
format_assistant=StringFormatter(slots=["{{content}}<end_of_sentence>\n"]),
|
||||
format_system=StringFormatter(
|
||||
slots=["<beginning_of_sentence>system ai_setting=assistant\n{{content}}<end_of_sentence>\n"]
|
||||
),
|
||||
format_function=FunctionFormatter(slots=["{{content}}<end_of_sentence>\n"], tool_format="minimax1"),
|
||||
format_observation=StringFormatter(
|
||||
slots=[
|
||||
"<beginning_of_sentence>tool name=tools\n{{content}}<end_of_sentence>\n<beginning_of_sentence>ai name=assistant\n"
|
||||
]
|
||||
),
|
||||
format_tools=ToolFormatter(tool_format="minimax1"),
|
||||
default_system="You are a helpful assistant.",
|
||||
stop_words=["<end_of_sentence>"],
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="minimax2",
|
||||
format_user=StringFormatter(slots=["]~b]user\n{{content}}[e~[\n]~b]ai\n"]),
|
||||
format_assistant=StringFormatter(slots=["{{content}}[e~[\n"]),
|
||||
format_system=StringFormatter(slots=["]~!b[]~b]system\n{{content}}[e~[\n"]),
|
||||
format_function=FunctionFormatter(slots=["{{content}}[e~[\n"], tool_format="minimax2"),
|
||||
format_observation=StringFormatter(slots=["]~b]tool\n<response>{{content}}</response>[e~[\n]~b]ai\n"]),
|
||||
format_tools=ToolFormatter(tool_format="minimax2"),
|
||||
default_system="You are a helpful assistant. Your name is MiniMax-M2.1 and is built by MiniMax.",
|
||||
stop_words=["[e~["],
|
||||
template_class=ReasoningTemplate,
|
||||
)
|
||||
|
||||
|
||||
# mistral tokenizer v3 tekken
|
||||
register_template(
|
||||
name="ministral",
|
||||
|
||||
@@ -61,6 +61,21 @@ LLAMA3_TOOL_PROMPT = (
|
||||
"Do not use variables.\n\n{tool_text}"
|
||||
)
|
||||
|
||||
# System-prompt fragment for the "minimax1" tool format. It is a str.format template:
# {tool_text} receives the serialized tool definitions, and the doubled braces in the
# last line become literal "{" / "}" in the rendered example call.
MINIMAX_M1_TOOL_PROMPT = (
    "You are provided with these tools:\n<tools>\n{tool_text}</tools>\n\n"
    "If you need to call tools, please respond with <tool_calls></tool_calls> XML tags, and provide tool-name and "
    "json-object of arguments, following the format below:\n<tool_calls>\n"
    """{{"name": <tool-name-1>, "arguments": <args-json-object-1>}}\n...\n</tool_calls>"""
)
|
||||
|
||||
# System-prompt fragment for the "minimax2" tool format. It is a str.format template:
# {tool_text} receives the serialized tool definitions. The trailing example shows the
# XML call syntax (<minimax:tool_call> / <invoke> / <parameter>) the model should emit.
MINIMAX_M2_TOOL_PROMPT = (
    "\n\n# Tools\n\nYou may call one or more tools to assist with the user query.\n"
    "Here are the tools available in JSONSchema format:\n\n<tools>\n{tool_text}</tools>\n\n"
    "When making tool calls, use XML format to invoke tools and pass parameters:\n"
    """\n<minimax:tool_call>\n<invoke name="tool-name-1">\n<parameter name="param-key-1">param-value-1</parameter>\n"""
    """<parameter name="param-key-2">param-value-2</parameter>\n...\n</invoke>\n</minimax:tool_call>"""
)
|
||||
|
||||
QWEN_TOOL_PROMPT = (
|
||||
"\n\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\n"
|
||||
"You are provided with function signatures within <tools></tools> XML tags:\n<tools>{tool_text}"
|
||||
@@ -253,6 +268,109 @@ class Llama3ToolUtils(ToolUtils):
|
||||
return content
|
||||
|
||||
|
||||
class MiniMaxM1ToolUtils(ToolUtils):
    r"""MiniMax-M1 tool using template.

    Tool definitions and tool calls are exchanged as one JSON object per line;
    tool calls are wrapped in a ``<tool_calls>...</tool_calls>`` block.
    """

    @override
    @staticmethod
    def tool_formatter(tools: list[dict[str, Any]]) -> str:
        r"""Render the tool definitions into the MiniMax-M1 tool prompt.

        Entries shaped like ``{"type": "function", "function": {...}}`` are unwrapped to
        their ``function`` payload; any other entry is serialized as given.
        """
        tool_text = ""
        for tool in tools:
            tool = tool.get("function", "") if tool.get("type") == "function" else tool
            tool_text += json.dumps(tool, ensure_ascii=False) + "\n"

        return MINIMAX_M1_TOOL_PROMPT.format(tool_text=tool_text)

    @override
    @staticmethod
    def function_formatter(functions: list["FunctionCall"]) -> str:
        r"""Serialize function calls as JSON lines inside a ``<tool_calls>`` block.

        ``func.arguments`` is expected to be a JSON string; it is decoded so that the
        arguments are embedded as a JSON value rather than as an escaped string.
        """
        function_texts = []
        for func in functions:
            name, arguments = func.name, json.loads(func.arguments)
            function_texts.append(json.dumps({"name": name, "arguments": arguments}, ensure_ascii=False))

        return "<tool_calls>\n" + "\n".join(function_texts) + "\n</tool_calls>"

    @override
    @staticmethod
    def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
        r"""Extract function calls from the first ``<tool_calls>`` block in ``content``.

        Returns ``content`` unchanged when no block is present. Otherwise returns the
        calls parsed from the block, one per non-empty line; malformed lines are skipped.
        """
        regex = re.compile(r"<tool_calls>\s*(.+?)\s*</tool_calls>", re.DOTALL)
        tool_match = re.search(regex, content)
        if not tool_match:
            return content

        tool_calls_content = tool_match.group(1)
        results = []
        for line in tool_calls_content.split("\n"):
            line = line.strip()
            if not line:
                continue

            try:
                tool_call = json.loads(line)
                name, arguments = tool_call["name"], tool_call["arguments"]
            except (json.JSONDecodeError, KeyError, TypeError):
                # Not JSON, or valid JSON that is not a {"name": ..., "arguments": ...}
                # object (e.g. a bare number or a dict missing a key): skip the line
                # instead of crashing on model output we cannot interpret.
                continue

            results.append(FunctionCall(name, json.dumps(arguments, ensure_ascii=False)))

        return results
|
||||
|
||||
|
||||
class MiniMaxM2ToolUtils(ToolUtils):
    r"""MiniMax-M2 tool using template.

    Tool definitions are JSON objects wrapped in ``<tool>`` tags; tool calls use the
    XML form ``<minimax:tool_call><invoke name=...><parameter name=...>``.
    """

    @override
    @staticmethod
    def tool_formatter(tools: list[dict[str, Any]]) -> str:
        r"""Render the tool definitions into the MiniMax-M2 tool prompt.

        Entries shaped like ``{"type": "function", "function": {...}}`` are unwrapped to
        their ``function`` payload; any other entry is serialized as given.
        """
        tool_text = ""
        for tool in tools:
            tool = tool.get("function", "") if tool.get("type") == "function" else tool
            tool_text += "<tool>" + json.dumps(tool, ensure_ascii=False) + "</tool>\n"

        return MINIMAX_M2_TOOL_PROMPT.format(tool_text=tool_text)

    @override
    @staticmethod
    def function_formatter(functions: list["FunctionCall"]) -> str:
        r"""Serialize function calls as ``<invoke>`` elements in one ``<minimax:tool_call>`` block.

        ``func.arguments`` is expected to be a JSON string encoding an object. String
        values are written verbatim; every other value is written as JSON.
        """
        function_texts = []
        for func in functions:
            name, arguments = func.name, json.loads(func.arguments)
            prompt = f'<invoke name="{name}">'
            for key, value in arguments.items():
                prompt += f'\n<parameter name="{key}">'
                if not isinstance(value, str):
                    value = json.dumps(value, ensure_ascii=False)
                prompt += value + "</parameter>"
            prompt += "\n</invoke>"
            function_texts.append(prompt)

        # Wrap the calls in the envelope that MINIMAX_M2_TOOL_PROMPT advertises and that
        # tool_extractor parses; without this return the method would yield None.
        return "<minimax:tool_call>\n" + "\n".join(function_texts) + "\n</minimax:tool_call>"

    @override
    @staticmethod
    def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
        r"""Extract function calls from the first ``<minimax:tool_call>`` block in ``content``.

        Returns ``content`` unchanged when no block is present. Each ``<invoke>`` element
        becomes one call whose arguments are collected from its ``<parameter>`` children.
        """
        regex = re.compile(r"<minimax:tool_call>\s*(.+?)\s*</minimax:tool_call>", re.DOTALL)
        tool_match = re.search(regex, content)
        if not tool_match:
            return content

        tool_calls_content = tool_match.group(1)
        invoke_regex = re.compile(r"<invoke name=\"(.*?)\">(.*?)</invoke>", re.DOTALL)
        # Loop-invariant, so it is built once rather than once per <invoke>.
        param_pattern = re.compile(r"<parameter name=\"(.*?)\">(.*?)</parameter>", re.DOTALL)
        results = []

        for func_name, params_block in re.findall(invoke_regex, tool_calls_content):
            args_dict = {}
            for key, raw_value in re.findall(param_pattern, params_block):
                value = raw_value.strip()
                try:
                    parsed_value = json.loads(value)
                except json.JSONDecodeError:
                    # Not JSON: keep the parameter text exactly as written (unstripped).
                    parsed_value = raw_value
                args_dict[key] = parsed_value

            results.append(FunctionCall(func_name.strip(), json.dumps(args_dict, ensure_ascii=False)))

        return results
|
||||
|
||||
|
||||
class MistralToolUtils(ToolUtils):
|
||||
r"""Mistral v0.3 tool using template."""
|
||||
|
||||
@@ -432,6 +550,8 @@ TOOLS = {
|
||||
"default": DefaultToolUtils(),
|
||||
"glm4": GLM4ToolUtils(),
|
||||
"llama3": Llama3ToolUtils(),
|
||||
"minimax1": MiniMaxM1ToolUtils(),
|
||||
"minimax2": MiniMaxM2ToolUtils(),
|
||||
"mistral": MistralToolUtils(),
|
||||
"qwen": QwenToolUtils(),
|
||||
"glm4_moe": GLM4MOEToolUtils(),
|
||||
|
||||
@@ -15,7 +15,6 @@
|
||||
import os
|
||||
from collections import OrderedDict, defaultdict
|
||||
from enum import Enum, unique
|
||||
from typing import Optional
|
||||
|
||||
from peft.utils import SAFETENSORS_WEIGHTS_NAME as SAFE_ADAPTER_WEIGHTS_NAME
|
||||
from peft.utils import WEIGHTS_NAME as ADAPTER_WEIGHTS_NAME
|
||||
@@ -64,6 +63,7 @@ MCA_SUPPORTED_MODELS = {
|
||||
"qwen2",
|
||||
"qwen2_vl",
|
||||
"qwen2_5_vl",
|
||||
"qwen3_vl",
|
||||
"qwen3",
|
||||
"qwen3_moe",
|
||||
"qwen3_next",
|
||||
@@ -154,7 +154,7 @@ class RopeScaling(str, Enum):
|
||||
|
||||
def register_model_group(
|
||||
models: dict[str, dict[DownloadSource, str]],
|
||||
template: Optional[str] = None,
|
||||
template: str | None = None,
|
||||
multimodal: bool = False,
|
||||
) -> None:
|
||||
for name, path in models.items():
|
||||
@@ -1072,6 +1072,40 @@ register_model_group(
|
||||
)
|
||||
|
||||
|
||||
# MiniMax-Text-01 and MiniMax-M1 checkpoints, all bound to the "minimax1" template.
# Each entry maps a display name to its repository id per download source.
register_model_group(
    models={
        "MiniMax-Text-01-Instruct": {
            DownloadSource.DEFAULT: "MiniMaxAI/MiniMax-Text-01-hf",
            DownloadSource.MODELSCOPE: "MiniMaxAI/MiniMax-Text-01",
        },
        "MiniMax-M1-40k-Thinking": {
            DownloadSource.DEFAULT: "MiniMaxAI/MiniMax-M1-40k-hf",
            DownloadSource.MODELSCOPE: "MiniMaxAI/MiniMax-M1-40k-hf",
        },
        "MiniMax-M1-80k-Thinking": {
            DownloadSource.DEFAULT: "MiniMaxAI/MiniMax-M1-80k-hf",
            DownloadSource.MODELSCOPE: "MiniMaxAI/MiniMax-M1-80k-hf",
        },
    },
    template="minimax1",
)
|
||||
|
||||
|
||||
# MiniMax-M2 family checkpoints, bound to the "minimax2" template.
# Each entry maps a display name to its repository id per download source.
register_model_group(
    models={
        "MiniMax-M2-Thinking": {
            DownloadSource.DEFAULT: "MiniMaxAI/MiniMax-M2",
            DownloadSource.MODELSCOPE: "MiniMaxAI/MiniMax-M2",
        },
        "MiniMax-M2.1-Thinking": {
            DownloadSource.DEFAULT: "MiniMaxAI/MiniMax-M2.1",
            DownloadSource.MODELSCOPE: "MiniMaxAI/MiniMax-M2.1",
        },
    },
    template="minimax2",
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Granite-3.0-1B-A400M-Base": {
|
||||
|
||||
@@ -19,7 +19,7 @@
|
||||
from collections import OrderedDict
|
||||
|
||||
|
||||
VERSION = "0.9.4.dev0"
|
||||
VERSION = "0.9.4"
|
||||
|
||||
|
||||
def print_env() -> None:
|
||||
|
||||
@@ -117,7 +117,7 @@ def _configure_library_root_logger() -> None:
|
||||
library_root_logger.propagate = False
|
||||
|
||||
|
||||
def get_logger(name: Optional[str] = None) -> "_Logger":
|
||||
def get_logger(name: str | None = None) -> "_Logger":
|
||||
r"""Return a logger with the specified name. It is not supposed to be accessed externally."""
|
||||
if name is None:
|
||||
name = _get_library_name()
|
||||
|
||||
@@ -94,11 +94,11 @@ def check_version(requirement: str, mandatory: bool = False) -> None:
|
||||
|
||||
def check_dependencies() -> None:
    r"""Check the version of the required packages.

    Each requirement string is handed to ``check_version`` with its default
    ``mandatory=False``. Only one range is checked per package: the superseded
    ``transformers`` and ``trl`` pins are dropped, because the old and new ``trl``
    ranges do not overlap and could never both be satisfied.
    """
    check_version("transformers>=4.51.0,<=4.57.1")
    check_version("datasets>=2.16.0,<=4.0.0")
    check_version("accelerate>=1.3.0,<=1.11.0")
    check_version("peft>=0.14.0,<=0.17.1")
    check_version("trl>=0.18.0,<=0.24.0")
|
||||
|
||||
|
||||
def calculate_tps(dataset: list[dict[str, Any]], metrics: dict[str, float], stage: Literal["sft", "rm"]) -> float:
|
||||
@@ -332,3 +332,7 @@ def fix_proxy(ipv6_enabled: bool = False) -> None:
|
||||
if ipv6_enabled:
|
||||
os.environ.pop("http_proxy", None)
|
||||
os.environ.pop("HTTP_PROXY", None)
|
||||
os.environ.pop("https_proxy", None)
|
||||
os.environ.pop("HTTPS_PROXY", None)
|
||||
os.environ.pop("all_proxy", None)
|
||||
os.environ.pop("ALL_PROXY", None)
|
||||
|
||||
@@ -16,22 +16,22 @@
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from typing import Any, Literal, Optional
|
||||
from typing import Any, Literal
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataArguments:
|
||||
r"""Arguments pertaining to what data we are going to input our model for training and evaluation."""
|
||||
|
||||
template: Optional[str] = field(
|
||||
template: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Which template to use for constructing prompts in training and inference."},
|
||||
)
|
||||
dataset: Optional[str] = field(
|
||||
dataset: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "The name of dataset(s) to use for training. Use commas to separate multiple datasets."},
|
||||
)
|
||||
eval_dataset: Optional[str] = field(
|
||||
eval_dataset: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "The name of dataset(s) to use for evaluation. Use commas to separate multiple datasets."},
|
||||
)
|
||||
@@ -39,7 +39,7 @@ class DataArguments:
|
||||
default="data",
|
||||
metadata={"help": "Path to the folder containing the datasets."},
|
||||
)
|
||||
media_dir: Optional[str] = field(
|
||||
media_dir: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Path to the folder containing the images, videos or audios. Defaults to `dataset_dir`."},
|
||||
)
|
||||
@@ -67,7 +67,7 @@ class DataArguments:
|
||||
default="concat",
|
||||
metadata={"help": "Strategy to use in dataset mixing (concat/interleave) (undersampling/oversampling)."},
|
||||
)
|
||||
interleave_probs: Optional[str] = field(
|
||||
interleave_probs: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Probabilities to sample data from datasets. Use commas to separate multiple datasets."},
|
||||
)
|
||||
@@ -79,15 +79,15 @@ class DataArguments:
|
||||
default=1000,
|
||||
metadata={"help": "The number of examples in one group in pre-processing."},
|
||||
)
|
||||
preprocessing_num_workers: Optional[int] = field(
|
||||
preprocessing_num_workers: int | None = field(
|
||||
default=None,
|
||||
metadata={"help": "The number of processes to use for the pre-processing."},
|
||||
)
|
||||
max_samples: Optional[int] = field(
|
||||
max_samples: int | None = field(
|
||||
default=None,
|
||||
metadata={"help": "For debugging purposes, truncate the number of examples for each dataset."},
|
||||
)
|
||||
eval_num_beams: Optional[int] = field(
|
||||
eval_num_beams: int | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Number of beams to use for evaluation. This argument will be passed to `model.generate`"},
|
||||
)
|
||||
@@ -103,7 +103,7 @@ class DataArguments:
|
||||
default=False,
|
||||
metadata={"help": "Whether or not to evaluate on each dataset separately."},
|
||||
)
|
||||
packing: Optional[bool] = field(
|
||||
packing: bool | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Enable sequences packing in training. Will automatically enable in pre-training."},
|
||||
)
|
||||
@@ -111,19 +111,19 @@ class DataArguments:
|
||||
default=False,
|
||||
metadata={"help": "Enable sequence packing without cross-attention."},
|
||||
)
|
||||
tool_format: Optional[str] = field(
|
||||
tool_format: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Tool format to use for constructing function calling examples."},
|
||||
)
|
||||
default_system: Optional[str] = field(
|
||||
default_system: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Override the default system message in the template."},
|
||||
)
|
||||
enable_thinking: Optional[bool] = field(
|
||||
enable_thinking: bool | None = field(
|
||||
default=True,
|
||||
metadata={"help": "Whether or not to enable thinking mode for reasoning models."},
|
||||
)
|
||||
tokenized_path: Optional[str] = field(
|
||||
tokenized_path: str | None = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": (
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Literal, Optional
|
||||
from typing import Literal
|
||||
|
||||
from datasets import DownloadMode
|
||||
|
||||
@@ -46,7 +46,7 @@ class EvaluationArguments:
|
||||
default=5,
|
||||
metadata={"help": "Number of examplars for few-shot learning."},
|
||||
)
|
||||
save_dir: Optional[str] = field(
|
||||
save_dir: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Path to save the evaluation results."},
|
||||
)
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from typing import Any, Literal, Optional
|
||||
from typing import Any, Literal
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -40,7 +40,7 @@ class FreezeArguments:
|
||||
)
|
||||
},
|
||||
)
|
||||
freeze_extra_modules: Optional[str] = field(
|
||||
freeze_extra_modules: str | None = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": (
|
||||
@@ -56,7 +56,7 @@ class FreezeArguments:
|
||||
class LoraArguments:
|
||||
r"""Arguments pertaining to the LoRA training."""
|
||||
|
||||
additional_target: Optional[str] = field(
|
||||
additional_target: str | None = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": (
|
||||
@@ -66,7 +66,7 @@ class LoraArguments:
|
||||
)
|
||||
},
|
||||
)
|
||||
lora_alpha: Optional[int] = field(
|
||||
lora_alpha: int | None = field(
|
||||
default=None,
|
||||
metadata={"help": "The scale factor for LoRA fine-tuning (default: lora_rank * 2)."},
|
||||
)
|
||||
@@ -88,7 +88,7 @@ class LoraArguments:
|
||||
)
|
||||
},
|
||||
)
|
||||
loraplus_lr_ratio: Optional[float] = field(
|
||||
loraplus_lr_ratio: float | None = field(
|
||||
default=None,
|
||||
metadata={"help": "LoRA plus learning rate ratio (lr_B / lr_A)."},
|
||||
)
|
||||
@@ -126,7 +126,7 @@ class LoraArguments:
|
||||
class OFTArguments:
|
||||
r"""Arguments pertaining to the OFT training."""
|
||||
|
||||
additional_target: Optional[str] = field(
|
||||
additional_target: str | None = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": (
|
||||
@@ -220,27 +220,27 @@ class RLHFArguments:
|
||||
default=False,
|
||||
metadata={"help": "Whiten the rewards before compute advantages in PPO training."},
|
||||
)
|
||||
ref_model: Optional[str] = field(
|
||||
ref_model: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Path to the reference model used for the PPO or DPO training."},
|
||||
)
|
||||
ref_model_adapters: Optional[str] = field(
|
||||
ref_model_adapters: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Path to the adapters of the reference model."},
|
||||
)
|
||||
ref_model_quantization_bit: Optional[int] = field(
|
||||
ref_model_quantization_bit: int | None = field(
|
||||
default=None,
|
||||
metadata={"help": "The number of bits to quantize the reference model."},
|
||||
)
|
||||
reward_model: Optional[str] = field(
|
||||
reward_model: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Path to the reward model used for the PPO training."},
|
||||
)
|
||||
reward_model_adapters: Optional[str] = field(
|
||||
reward_model_adapters: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Path to the adapters of the reward model."},
|
||||
)
|
||||
reward_model_quantization_bit: Optional[int] = field(
|
||||
reward_model_quantization_bit: int | None = field(
|
||||
default=None,
|
||||
metadata={"help": "The number of bits to quantize the reward model."},
|
||||
)
|
||||
@@ -248,7 +248,7 @@ class RLHFArguments:
|
||||
default="lora",
|
||||
metadata={"help": "The type of the reward model in PPO training. Lora model only supports lora training."},
|
||||
)
|
||||
ld_alpha: Optional[float] = field(
|
||||
ld_alpha: float | None = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": (
|
||||
@@ -361,15 +361,15 @@ class BAdamArgument:
|
||||
default="layer",
|
||||
metadata={"help": "Whether to use layer-wise or ratio-wise BAdam optimizer."},
|
||||
)
|
||||
badam_start_block: Optional[int] = field(
|
||||
badam_start_block: int | None = field(
|
||||
default=None,
|
||||
metadata={"help": "The starting block index for layer-wise BAdam."},
|
||||
)
|
||||
badam_switch_mode: Optional[Literal["ascending", "descending", "random", "fixed"]] = field(
|
||||
badam_switch_mode: Literal["ascending", "descending", "random", "fixed"] | None = field(
|
||||
default="ascending",
|
||||
metadata={"help": "the strategy of picking block to update for layer-wise BAdam."},
|
||||
)
|
||||
badam_switch_interval: Optional[int] = field(
|
||||
badam_switch_interval: int | None = field(
|
||||
default=50,
|
||||
metadata={
|
||||
"help": "Number of steps to update the block for layer-wise BAdam. Use -1 to disable the block update."
|
||||
@@ -406,15 +406,15 @@ class SwanLabArguments:
|
||||
default=False,
|
||||
metadata={"help": "Whether or not to use the SwanLab (an experiment tracking and visualization tool)."},
|
||||
)
|
||||
swanlab_project: Optional[str] = field(
|
||||
swanlab_project: str | None = field(
|
||||
default="llamafactory",
|
||||
metadata={"help": "The project name in SwanLab."},
|
||||
)
|
||||
swanlab_workspace: Optional[str] = field(
|
||||
swanlab_workspace: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "The workspace name in SwanLab."},
|
||||
)
|
||||
swanlab_run_name: Optional[str] = field(
|
||||
swanlab_run_name: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "The experiment name in SwanLab."},
|
||||
)
|
||||
@@ -422,19 +422,19 @@ class SwanLabArguments:
|
||||
default="cloud",
|
||||
metadata={"help": "The mode of SwanLab."},
|
||||
)
|
||||
swanlab_api_key: Optional[str] = field(
|
||||
swanlab_api_key: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "The API key for SwanLab."},
|
||||
)
|
||||
swanlab_logdir: Optional[str] = field(
|
||||
swanlab_logdir: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "The log directory for SwanLab."},
|
||||
)
|
||||
swanlab_lark_webhook_url: Optional[str] = field(
|
||||
swanlab_lark_webhook_url: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "The Lark(飞书) webhook URL for SwanLab."},
|
||||
)
|
||||
swanlab_lark_secret: Optional[str] = field(
|
||||
swanlab_lark_secret: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "The Lark(飞书) secret for SwanLab."},
|
||||
)
|
||||
@@ -510,7 +510,7 @@ class FinetuningArguments(
|
||||
default=False,
|
||||
metadata={"help": "Whether or not to disable the shuffling of the training set."},
|
||||
)
|
||||
early_stopping_steps: Optional[int] = field(
|
||||
early_stopping_steps: int | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Number of steps to stop training if the `metric_for_best_model` does not improve."},
|
||||
)
|
||||
@@ -530,11 +530,11 @@ class FinetuningArguments(
|
||||
return arg
|
||||
|
||||
self.freeze_trainable_modules: list[str] = split_arg(self.freeze_trainable_modules)
|
||||
self.freeze_extra_modules: Optional[list[str]] = split_arg(self.freeze_extra_modules)
|
||||
self.freeze_extra_modules: list[str] | None = split_arg(self.freeze_extra_modules)
|
||||
self.lora_alpha: int = self.lora_alpha or self.lora_rank * 2
|
||||
self.lora_target: list[str] = split_arg(self.lora_target)
|
||||
self.oft_target: list[str] = split_arg(self.oft_target)
|
||||
self.additional_target: Optional[list[str]] = split_arg(self.additional_target)
|
||||
self.additional_target: list[str] | None = split_arg(self.additional_target)
|
||||
self.galore_target: list[str] = split_arg(self.galore_target)
|
||||
self.apollo_target: list[str] = split_arg(self.apollo_target)
|
||||
self.use_ref_model = self.stage == "dpo" and self.pref_loss not in ["orpo", "simpo"]
|
||||
|
||||
@@ -17,12 +17,11 @@
|
||||
|
||||
import json
|
||||
from dataclasses import asdict, dataclass, field, fields
|
||||
from typing import Any, Literal, Optional, Union
|
||||
from typing import Any, Literal, Self
|
||||
|
||||
import torch
|
||||
from omegaconf import OmegaConf
|
||||
from transformers.training_args import _convert_str_dict
|
||||
from typing_extensions import Self
|
||||
|
||||
from ..extras.constants import AttentionFunction, EngineName, QuantizationMethod, RopeScaling
|
||||
from ..extras.logging import get_logger
|
||||
@@ -35,13 +34,13 @@ logger = get_logger(__name__)
|
||||
class BaseModelArguments:
|
||||
r"""Arguments pertaining to the model."""
|
||||
|
||||
model_name_or_path: Optional[str] = field(
|
||||
model_name_or_path: str | None = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Path to the model weight or identifier from huggingface.co/models or modelscope.cn/models."
|
||||
},
|
||||
)
|
||||
adapter_name_or_path: Optional[str] = field(
|
||||
adapter_name_or_path: str | None = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": (
|
||||
@@ -50,11 +49,11 @@ class BaseModelArguments:
|
||||
)
|
||||
},
|
||||
)
|
||||
adapter_folder: Optional[str] = field(
|
||||
adapter_folder: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "The folder containing the adapter weights to load."},
|
||||
)
|
||||
cache_dir: Optional[str] = field(
|
||||
cache_dir: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Where to store the pre-trained models downloaded from huggingface.co or modelscope.cn."},
|
||||
)
|
||||
@@ -70,17 +69,17 @@ class BaseModelArguments:
|
||||
default=False,
|
||||
metadata={"help": "Whether or not the special tokens should be split during the tokenization process."},
|
||||
)
|
||||
add_tokens: Optional[str] = field(
|
||||
add_tokens: str | None = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Non-special tokens to be added into the tokenizer. Use commas to separate multiple tokens."
|
||||
},
|
||||
)
|
||||
add_special_tokens: Optional[str] = field(
|
||||
add_special_tokens: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Special tokens to be added into the tokenizer. Use commas to separate multiple tokens."},
|
||||
)
|
||||
new_special_tokens_config: Optional[str] = field(
|
||||
new_special_tokens_config: str | None = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": (
|
||||
@@ -110,7 +109,7 @@ class BaseModelArguments:
|
||||
default=True,
|
||||
metadata={"help": "Whether or not to use memory-efficient model loading."},
|
||||
)
|
||||
rope_scaling: Optional[RopeScaling] = field(
|
||||
rope_scaling: RopeScaling | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Which scaling strategy should be adopted for the RoPE embeddings."},
|
||||
)
|
||||
@@ -122,7 +121,7 @@ class BaseModelArguments:
|
||||
default=False,
|
||||
metadata={"help": "Enable shift short attention (S^2-Attn) proposed by LongLoRA."},
|
||||
)
|
||||
mixture_of_depths: Optional[Literal["convert", "load"]] = field(
|
||||
mixture_of_depths: Literal["convert", "load"] | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Convert the model to mixture-of-depths (MoD) or load the MoD model."},
|
||||
)
|
||||
@@ -138,7 +137,7 @@ class BaseModelArguments:
|
||||
default=False,
|
||||
metadata={"help": "Whether or not to enable liger kernel for faster training."},
|
||||
)
|
||||
moe_aux_loss_coef: Optional[float] = field(
|
||||
moe_aux_loss_coef: float | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Coefficient of the auxiliary router loss in mixture-of-experts model."},
|
||||
)
|
||||
@@ -174,7 +173,7 @@ class BaseModelArguments:
|
||||
default=True,
|
||||
metadata={"help": "Whether or not to use KV cache in generation."},
|
||||
)
|
||||
use_v1_kernels: bool = field(
|
||||
use_v1_kernels: bool | None = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether or not to use high-performance kernels in training."},
|
||||
)
|
||||
@@ -182,15 +181,15 @@ class BaseModelArguments:
|
||||
default="auto",
|
||||
metadata={"help": "Data type for model weights and activations at inference."},
|
||||
)
|
||||
hf_hub_token: Optional[str] = field(
|
||||
hf_hub_token: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Auth token to log in with Hugging Face Hub."},
|
||||
)
|
||||
ms_hub_token: Optional[str] = field(
|
||||
ms_hub_token: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Auth token to log in with ModelScope Hub."},
|
||||
)
|
||||
om_hub_token: Optional[str] = field(
|
||||
om_hub_token: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Auth token to log in with Modelers Hub."},
|
||||
)
|
||||
@@ -283,7 +282,7 @@ class QuantizationArguments:
|
||||
default=QuantizationMethod.BNB,
|
||||
metadata={"help": "Quantization method to use for on-the-fly quantization."},
|
||||
)
|
||||
quantization_bit: Optional[int] = field(
|
||||
quantization_bit: int | None = field(
|
||||
default=None,
|
||||
metadata={"help": "The number of bits to quantize the model using on-the-fly quantization."},
|
||||
)
|
||||
@@ -295,7 +294,7 @@ class QuantizationArguments:
|
||||
default=True,
|
||||
metadata={"help": "Whether or not to use double quantization in bitsandbytes int4 training."},
|
||||
)
|
||||
quantization_device_map: Optional[Literal["auto"]] = field(
|
||||
quantization_device_map: Literal["auto"] | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Device map used to infer the 4-bit quantized model, needs bitsandbytes>=0.43.0."},
|
||||
)
|
||||
@@ -375,7 +374,7 @@ class ProcessorArguments:
|
||||
class ExportArguments:
|
||||
r"""Arguments pertaining to the model export."""
|
||||
|
||||
export_dir: Optional[str] = field(
|
||||
export_dir: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Path to the directory to save the exported model."},
|
||||
)
|
||||
@@ -387,11 +386,11 @@ class ExportArguments:
|
||||
default="cpu",
|
||||
metadata={"help": "The device used in model export, use `auto` to accelerate exporting."},
|
||||
)
|
||||
export_quantization_bit: Optional[int] = field(
|
||||
export_quantization_bit: int | None = field(
|
||||
default=None,
|
||||
metadata={"help": "The number of bits to quantize the exported model."},
|
||||
)
|
||||
export_quantization_dataset: Optional[str] = field(
|
||||
export_quantization_dataset: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Path to the dataset or dataset name to use in quantizing the exported model."},
|
||||
)
|
||||
@@ -407,7 +406,7 @@ class ExportArguments:
|
||||
default=False,
|
||||
metadata={"help": "Whether or not to save the `.bin` files instead of `.safetensors`."},
|
||||
)
|
||||
export_hub_model_id: Optional[str] = field(
|
||||
export_hub_model_id: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "The name of the repository if push the model to the Hugging Face hub."},
|
||||
)
|
||||
@@ -437,7 +436,7 @@ class VllmArguments:
|
||||
default=32,
|
||||
metadata={"help": "Maximum rank of all LoRAs in the vLLM engine."},
|
||||
)
|
||||
vllm_config: Optional[Union[dict, str]] = field(
|
||||
vllm_config: dict | str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Config to initialize the vllm engine. Please use JSON strings."},
|
||||
)
|
||||
@@ -463,7 +462,7 @@ class SGLangArguments:
|
||||
default=-1,
|
||||
metadata={"help": "Tensor parallel size for the SGLang engine."},
|
||||
)
|
||||
sglang_config: Optional[Union[dict, str]] = field(
|
||||
sglang_config: dict | str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Config to initialize the SGLang engine. Please use JSON strings."},
|
||||
)
|
||||
@@ -487,21 +486,21 @@ class KTransformersArguments:
|
||||
default=False,
|
||||
metadata={"help": "Whether To Use KTransformers Optimizations For LoRA Training."},
|
||||
)
|
||||
kt_optimize_rule: Optional[str] = field(
|
||||
kt_optimize_rule: str | None = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Path To The KTransformers Optimize Rule; See https://github.com/kvcache-ai/ktransformers/."
|
||||
},
|
||||
)
|
||||
cpu_infer: Optional[int] = field(
|
||||
cpu_infer: int | None = field(
|
||||
default=32,
|
||||
metadata={"help": "Number Of CPU Cores Used For Computation."},
|
||||
)
|
||||
chunk_size: Optional[int] = field(
|
||||
chunk_size: int | None = field(
|
||||
default=8192,
|
||||
metadata={"help": "Chunk Size Used For CPU Compute In KTransformers."},
|
||||
)
|
||||
mode: Optional[str] = field(
|
||||
mode: str | None = field(
|
||||
default="normal",
|
||||
metadata={"help": "Normal Or Long_Context For Llama Models."},
|
||||
)
|
||||
@@ -539,17 +538,17 @@ class ModelArguments(
|
||||
The class on the most right will be displayed first.
|
||||
"""
|
||||
|
||||
compute_dtype: Optional[torch.dtype] = field(
|
||||
compute_dtype: torch.dtype | None = field(
|
||||
default=None,
|
||||
init=False,
|
||||
metadata={"help": "Torch data type for computing model outputs, derived from `fp/bf16`. Do not specify it."},
|
||||
)
|
||||
device_map: Optional[Union[str, dict[str, Any]]] = field(
|
||||
device_map: str | dict[str, Any] | None = field(
|
||||
default=None,
|
||||
init=False,
|
||||
metadata={"help": "Device map for model placement, derived from training stage. Do not specify it."},
|
||||
)
|
||||
model_max_length: Optional[int] = field(
|
||||
model_max_length: int | None = field(
|
||||
default=None,
|
||||
init=False,
|
||||
metadata={"help": "The maximum input length for model, derived from `cutoff_len`. Do not specify it."},
|
||||
|
||||
@@ -18,7 +18,7 @@
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional, Union
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
import transformers
|
||||
@@ -65,7 +65,7 @@ else:
|
||||
_TRAIN_MCA_CLS = tuple()
|
||||
|
||||
|
||||
def read_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> Union[dict[str, Any], list[str]]:
|
||||
def read_args(args: dict[str, Any] | list[str] | None = None) -> dict[str, Any] | list[str]:
|
||||
r"""Get arguments from the command line or a config file."""
|
||||
if args is not None:
|
||||
return args
|
||||
@@ -83,7 +83,7 @@ def read_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> Union[
|
||||
|
||||
|
||||
def _parse_args(
|
||||
parser: "HfArgumentParser", args: Optional[Union[dict[str, Any], list[str]]] = None, allow_extra_keys: bool = False
|
||||
parser: "HfArgumentParser", args: dict[str, Any] | list[str] | None = None, allow_extra_keys: bool = False
|
||||
) -> tuple[Any]:
|
||||
args = read_args(args)
|
||||
if isinstance(args, dict):
|
||||
@@ -205,13 +205,13 @@ def _check_extra_dependencies(
|
||||
check_version("rouge_chinese", mandatory=True)
|
||||
|
||||
|
||||
def _parse_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _TRAIN_CLS:
|
||||
def _parse_train_args(args: dict[str, Any] | list[str] | None = None) -> _TRAIN_CLS:
|
||||
parser = HfArgumentParser(_TRAIN_ARGS)
|
||||
allow_extra_keys = is_env_enabled("ALLOW_EXTRA_ARGS")
|
||||
return _parse_args(parser, args, allow_extra_keys=allow_extra_keys)
|
||||
|
||||
|
||||
def _parse_train_mca_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _TRAIN_MCA_CLS:
|
||||
def _parse_train_mca_args(args: dict[str, Any] | list[str] | None = None) -> _TRAIN_MCA_CLS:
|
||||
parser = HfArgumentParser(_TRAIN_MCA_ARGS)
|
||||
allow_extra_keys = is_env_enabled("ALLOW_EXTRA_ARGS")
|
||||
model_args, data_args, training_args, finetuning_args, generating_args = _parse_args(
|
||||
@@ -232,25 +232,25 @@ def _configure_mca_training_args(training_args, data_args, finetuning_args) -> N
|
||||
finetuning_args.use_mca = True
|
||||
|
||||
|
||||
def _parse_infer_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _INFER_CLS:
|
||||
def _parse_infer_args(args: dict[str, Any] | list[str] | None = None) -> _INFER_CLS:
|
||||
parser = HfArgumentParser(_INFER_ARGS)
|
||||
allow_extra_keys = is_env_enabled("ALLOW_EXTRA_ARGS")
|
||||
return _parse_args(parser, args, allow_extra_keys=allow_extra_keys)
|
||||
|
||||
|
||||
def _parse_eval_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _EVAL_CLS:
|
||||
def _parse_eval_args(args: dict[str, Any] | list[str] | None = None) -> _EVAL_CLS:
|
||||
parser = HfArgumentParser(_EVAL_ARGS)
|
||||
allow_extra_keys = is_env_enabled("ALLOW_EXTRA_ARGS")
|
||||
return _parse_args(parser, args, allow_extra_keys=allow_extra_keys)
|
||||
|
||||
|
||||
def get_ray_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> RayArguments:
|
||||
def get_ray_args(args: dict[str, Any] | list[str] | None = None) -> RayArguments:
|
||||
parser = HfArgumentParser(RayArguments)
|
||||
(ray_args,) = _parse_args(parser, args, allow_extra_keys=True)
|
||||
return ray_args
|
||||
|
||||
|
||||
def get_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _TRAIN_CLS:
|
||||
def get_train_args(args: dict[str, Any] | list[str] | None = None) -> _TRAIN_CLS:
|
||||
if is_env_enabled("USE_MCA"):
|
||||
model_args, data_args, training_args, finetuning_args, generating_args = _parse_train_mca_args(args)
|
||||
else:
|
||||
@@ -473,7 +473,7 @@ def get_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _
|
||||
return model_args, data_args, training_args, finetuning_args, generating_args
|
||||
|
||||
|
||||
def get_infer_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _INFER_CLS:
|
||||
def get_infer_args(args: dict[str, Any] | list[str] | None = None) -> _INFER_CLS:
|
||||
model_args, data_args, finetuning_args, generating_args = _parse_infer_args(args)
|
||||
|
||||
# Setup logging
|
||||
@@ -508,7 +508,7 @@ def get_infer_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _
|
||||
return model_args, data_args, finetuning_args, generating_args
|
||||
|
||||
|
||||
def get_eval_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _EVAL_CLS:
|
||||
def get_eval_args(args: dict[str, Any] | list[str] | None = None) -> _EVAL_CLS:
|
||||
model_args, data_args, eval_args, finetuning_args = _parse_eval_args(args)
|
||||
|
||||
# Setup logging
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
|
||||
import json
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Literal, Optional, Union
|
||||
from typing import Literal
|
||||
|
||||
from transformers import Seq2SeqTrainingArguments
|
||||
from transformers.training_args import _convert_str_dict
|
||||
@@ -40,7 +40,7 @@ else:
|
||||
class RayArguments:
|
||||
r"""Arguments pertaining to the Ray training."""
|
||||
|
||||
ray_run_name: Optional[str] = field(
|
||||
ray_run_name: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "The training results will be saved at `<ray_storage_path>/ray_run_name`."},
|
||||
)
|
||||
@@ -48,7 +48,7 @@ class RayArguments:
|
||||
default="./saves",
|
||||
metadata={"help": "The storage path to save training results to"},
|
||||
)
|
||||
ray_storage_filesystem: Optional[Literal["s3", "gs", "gcs"]] = field(
|
||||
ray_storage_filesystem: Literal["s3", "gs", "gcs"] | None = field(
|
||||
default=None,
|
||||
metadata={"help": "The storage filesystem to use. If None specified, local filesystem will be used."},
|
||||
)
|
||||
@@ -56,7 +56,7 @@ class RayArguments:
|
||||
default=1,
|
||||
metadata={"help": "The number of workers for Ray training. Default is 1 worker."},
|
||||
)
|
||||
resources_per_worker: Union[dict, str] = field(
|
||||
resources_per_worker: dict | str = field(
|
||||
default_factory=lambda: {"GPU": 1},
|
||||
metadata={"help": "The resources per worker for Ray training. Default is to use 1 GPU per worker."},
|
||||
)
|
||||
@@ -64,7 +64,7 @@ class RayArguments:
|
||||
default="PACK",
|
||||
metadata={"help": "The placement strategy for Ray training. Default is PACK."},
|
||||
)
|
||||
ray_init_kwargs: Optional[Union[dict, str]] = field(
|
||||
ray_init_kwargs: dict | str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "The arguments to pass to ray.init for Ray training. Default is None."},
|
||||
)
|
||||
|
||||
@@ -15,7 +15,6 @@
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Any, Optional, TypedDict
|
||||
|
||||
import torch
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoModelForCausalLM,
|
||||
@@ -158,6 +157,7 @@ def load_model(
|
||||
if model is None and not lazy_load:
|
||||
init_kwargs["config"] = config
|
||||
init_kwargs["pretrained_model_name_or_path"] = model_args.model_name_or_path
|
||||
init_kwargs["torch_dtype"] = "auto"
|
||||
|
||||
if model_args.mixture_of_depths == "load":
|
||||
model = load_mod_pretrained_model(**init_kwargs)
|
||||
@@ -216,9 +216,9 @@ def load_model(
|
||||
"You are try to using future feature about kernels, please note that this feature "
|
||||
"is not supported for all models. If get any error, please disable this feature, or report the issue."
|
||||
)
|
||||
from ..v1.plugins.model_plugins.kernels.registry import apply_available_kernels
|
||||
from ..v1.plugins.model_plugins.kernels.interface import apply_default_kernels
|
||||
|
||||
model = apply_available_kernels(model)
|
||||
model = apply_default_kernels(model=model, include_kernels=model_args.use_v1_kernels)
|
||||
|
||||
trainable_params, all_param = count_parameters(model)
|
||||
if is_trainable:
|
||||
|
||||
@@ -20,9 +20,10 @@
|
||||
|
||||
import inspect
|
||||
import os
|
||||
from collections.abc import Callable
|
||||
from functools import WRAPPER_ASSIGNMENTS, partial, wraps
|
||||
from types import MethodType
|
||||
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
@@ -156,16 +156,13 @@ def patch_config(
|
||||
# deepspeed zero3 is not compatible with low_cpu_mem_usage
|
||||
init_kwargs["low_cpu_mem_usage"] = model_args.low_cpu_mem_usage and (not is_deepspeed_zero3_enabled())
|
||||
|
||||
# do not cast data type of the model deepspeed zero3 without qlora
|
||||
if not (is_deepspeed_zero3_enabled() and model_args.quantization_bit is None):
|
||||
init_kwargs["torch_dtype"] = "auto"
|
||||
# fsdp/deepspeed zero3 does not need device map
|
||||
if not (is_deepspeed_zero3_enabled() or is_fsdp_enabled()) and init_kwargs["low_cpu_mem_usage"]:
|
||||
if "device_map" not in init_kwargs and model_args.device_map:
|
||||
init_kwargs["device_map"] = model_args.device_map # device map requires low_cpu_mem_usage=True
|
||||
|
||||
if init_kwargs["low_cpu_mem_usage"] and not is_fsdp_enabled(): # fsdp does not need device map
|
||||
if "device_map" not in init_kwargs and model_args.device_map:
|
||||
init_kwargs["device_map"] = model_args.device_map # device map requires low_cpu_mem_usage=True
|
||||
|
||||
if init_kwargs.get("device_map", None) == "auto":
|
||||
init_kwargs["offload_folder"] = model_args.offload_folder
|
||||
if init_kwargs.get("device_map", None) == "auto":
|
||||
init_kwargs["offload_folder"] = model_args.offload_folder
|
||||
|
||||
|
||||
def patch_model(
|
||||
|
||||
@@ -26,6 +26,7 @@ import torch.nn.functional as F
|
||||
from transformers import Trainer
|
||||
from trl import DPOTrainer
|
||||
from trl.trainer import disable_dropout_in_model
|
||||
from trl.trainer.utils import prepare_deepspeed
|
||||
from typing_extensions import override
|
||||
|
||||
from ...extras.constants import IGNORE_INDEX
|
||||
@@ -95,7 +96,7 @@ class CustomDPOTrainer(DPOTrainer):
|
||||
if not (
|
||||
getattr(ref_model, "is_loaded_in_8bit", False) or getattr(ref_model, "is_loaded_in_4bit", False)
|
||||
): # quantized models are already set on the correct device
|
||||
self.ref_model = self._prepare_deepspeed(self.ref_model)
|
||||
self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
|
||||
else:
|
||||
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
|
||||
self.ref_model.eval()
|
||||
@@ -210,7 +211,7 @@ class CustomDPOTrainer(DPOTrainer):
|
||||
@override
|
||||
def concatenated_forward(
|
||||
self, model: "PreTrainedModel", batch: dict[str, "torch.Tensor"], is_ref_model: bool = False
|
||||
) -> tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
|
||||
) -> dict[str, "torch.Tensor"]:
|
||||
r"""Compute the sum log probabilities of the labels under given logits if loss_type is not IPO, ORPO or SimPO.
|
||||
|
||||
Otherwise the average log probabilities.
|
||||
@@ -230,11 +231,18 @@ class CustomDPOTrainer(DPOTrainer):
|
||||
chosen_logps, rejected_logps = all_logps.split(batch_size, dim=0)
|
||||
chosen_logits, rejected_logits = all_logits.split(batch_size, dim=0)
|
||||
chosen_length, _ = valid_length.split(batch_size, dim=0)
|
||||
|
||||
if self.loss_type in ["ipo", "orpo", "simpo"]:
|
||||
return chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_logps
|
||||
chosen_logps_avg = chosen_logps
|
||||
else:
|
||||
return chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_logps / chosen_length
|
||||
chosen_logps_avg = chosen_logps / chosen_length
|
||||
|
||||
return {
|
||||
"chosen_logps": chosen_logps,
|
||||
"rejected_logps": rejected_logps,
|
||||
"chosen_logits": chosen_logits,
|
||||
"rejected_logits": rejected_logits,
|
||||
"chosen_logps_avg": chosen_logps_avg,
|
||||
}
|
||||
|
||||
@override
|
||||
def compute_reference_log_probs(
|
||||
@@ -252,9 +260,9 @@ class CustomDPOTrainer(DPOTrainer):
|
||||
ref_context = nullcontext()
|
||||
|
||||
with torch.no_grad(), ref_context:
|
||||
reference_chosen_logps, reference_rejected_logps, *_ = self.concatenated_forward(
|
||||
ref_model, batch, is_ref_model=True
|
||||
)
|
||||
ref_output = self.concatenated_forward(ref_model, batch, is_ref_model=True)
|
||||
reference_chosen_logps = ref_output["chosen_logps"]
|
||||
reference_rejected_logps = ref_output["rejected_logps"]
|
||||
|
||||
return reference_chosen_logps, reference_rejected_logps
|
||||
|
||||
@@ -267,13 +275,13 @@ class CustomDPOTrainer(DPOTrainer):
|
||||
) -> tuple["torch.Tensor", dict[str, "torch.Tensor"]]:
|
||||
r"""Compute the DPO loss and other metrics for the given batch of inputs for train or test."""
|
||||
metrics = {}
|
||||
(
|
||||
policy_chosen_logps,
|
||||
policy_rejected_logps,
|
||||
policy_chosen_logits,
|
||||
policy_rejected_logits,
|
||||
policy_chosen_logps_avg,
|
||||
) = self.concatenated_forward(model, batch)
|
||||
|
||||
model_output = self.concatenated_forward(model, batch)
|
||||
policy_chosen_logps = model_output["chosen_logps"]
|
||||
policy_rejected_logps = model_output["rejected_logps"]
|
||||
policy_chosen_logits = model_output["chosen_logits"]
|
||||
policy_rejected_logits = model_output["rejected_logits"]
|
||||
policy_chosen_logps_avg = model_output["chosen_logps_avg"]
|
||||
|
||||
reference_chosen_logps, reference_rejected_logps = self.compute_reference_log_probs(model, batch)
|
||||
losses, chosen_rewards, rejected_rewards = self.compute_preference_loss(
|
||||
|
||||
@@ -25,6 +25,7 @@ import torch
|
||||
from transformers import Trainer
|
||||
from trl import KTOTrainer
|
||||
from trl.trainer import disable_dropout_in_model
|
||||
from trl.trainer.utils import prepare_deepspeed
|
||||
from typing_extensions import override
|
||||
|
||||
from ...extras.constants import IGNORE_INDEX
|
||||
@@ -77,6 +78,13 @@ class CustomKTOTrainer(KTOTrainer):
|
||||
self.desirable_weight = finetuning_args.kto_chosen_weight
|
||||
self.undesirable_weight = finetuning_args.kto_rejected_weight
|
||||
self.ftx_gamma = finetuning_args.pref_ftx
|
||||
# trl
|
||||
# Not all losses require a KL calculation
|
||||
self.calculate_KL = True
|
||||
if hasattr(self, "loss_type") and self.loss_type in ["apo_zero_unpaired"]:
|
||||
self.calculate_KL = False
|
||||
else:
|
||||
self.loss_type = "kto"
|
||||
|
||||
Trainer.__init__(self, model=model, **kwargs)
|
||||
self.model_accepts_loss_kwargs = False # overwrite trainer's default behavior
|
||||
@@ -90,7 +98,7 @@ class CustomKTOTrainer(KTOTrainer):
|
||||
if not (
|
||||
getattr(ref_model, "is_loaded_in_8bit", False) or getattr(ref_model, "is_loaded_in_4bit", False)
|
||||
): # quantized models are already set on the correct device
|
||||
self.ref_model = self._prepare_deepspeed(self.ref_model)
|
||||
self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
|
||||
else:
|
||||
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
|
||||
self.ref_model.eval()
|
||||
|
||||
@@ -11,14 +11,13 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""MCA (mcore_adapter) workflows for PT/SFT/DPO stages, aligned with LLaMA-Factory's workflow style."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
from collections.abc import Sequence
|
||||
from copy import deepcopy
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
from transformers import DataCollatorForSeq2Seq
|
||||
|
||||
from ...data import (
|
||||
SFTDataCollatorWith4DAttentionMask,
|
||||
@@ -44,11 +43,11 @@ from mcore_adapter.models import AutoConfig, AutoModel
|
||||
from mcore_adapter.trainer import DPOTrainer as McaDPOTrainer
|
||||
from mcore_adapter.trainer import McaTrainer
|
||||
from mcore_adapter.trainer.dpo_config import DPOConfig
|
||||
from mcore_adapter.training_args import Seq2SeqTrainingArguments as McaSeq2SeqTrainingArguments
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import DataCollatorForSeq2Seq, TrainerCallback
|
||||
from mcore_adapter.training_args import Seq2SeqTrainingArguments as McaSeq2SeqTrainingArguments
|
||||
from transformers import TrainerCallback
|
||||
|
||||
from ...hparams import DataArguments, FinetuningArguments, ModelArguments
|
||||
|
||||
@@ -76,7 +75,7 @@ def _data_collator_wrapper(data_collator: Any):
|
||||
return wrapper
|
||||
|
||||
|
||||
def _check_model_support(model_args: ModelArguments):
|
||||
def _check_model_support(model_args: "ModelArguments"):
|
||||
from transformers import AutoConfig as HfAutoConfig
|
||||
|
||||
config = HfAutoConfig.from_pretrained(
|
||||
@@ -87,11 +86,11 @@ def _check_model_support(model_args: ModelArguments):
|
||||
|
||||
|
||||
def run_pt(
|
||||
model_args: ModelArguments,
|
||||
data_args: DataArguments,
|
||||
training_args: McaSeq2SeqTrainingArguments,
|
||||
finetuning_args: FinetuningArguments,
|
||||
callbacks: list[TrainerCallback] | None = None,
|
||||
model_args: "ModelArguments",
|
||||
data_args: "DataArguments",
|
||||
training_args: "McaSeq2SeqTrainingArguments",
|
||||
finetuning_args: "FinetuningArguments",
|
||||
callbacks: Optional[list["TrainerCallback"]] = None,
|
||||
):
|
||||
tokenizer_module = load_tokenizer(model_args)
|
||||
tokenizer = tokenizer_module["tokenizer"]
|
||||
@@ -104,10 +103,7 @@ def run_pt(
|
||||
|
||||
_check_model_support(model_args)
|
||||
model = AutoModel.from_pretrained(model_args.model_name_or_path, training_args)
|
||||
|
||||
from transformers import DataCollatorForSeq2Seq
|
||||
|
||||
data_collator: DataCollatorForSeq2Seq = DataCollatorForSeq2Seq(
|
||||
data_collator = DataCollatorForSeq2Seq(
|
||||
tokenizer=tokenizer,
|
||||
pad_to_multiple_of=8,
|
||||
label_pad_token_id=IGNORE_INDEX,
|
||||
@@ -142,11 +138,11 @@ def run_pt(
|
||||
|
||||
|
||||
def run_sft(
|
||||
model_args: ModelArguments,
|
||||
data_args: DataArguments,
|
||||
training_args: McaSeq2SeqTrainingArguments,
|
||||
finetuning_args: FinetuningArguments,
|
||||
callbacks: list[TrainerCallback] | None = None,
|
||||
model_args: "ModelArguments",
|
||||
data_args: "DataArguments",
|
||||
training_args: "McaSeq2SeqTrainingArguments",
|
||||
finetuning_args: "FinetuningArguments",
|
||||
callbacks: Optional[list["TrainerCallback"]] = None,
|
||||
):
|
||||
# align packing flags
|
||||
# TODO: FIX SequencePacking
|
||||
@@ -166,7 +162,7 @@ def run_sft(
|
||||
model = AutoModel.from_pretrained(model_args.model_name_or_path, training_args)
|
||||
|
||||
# optional freezing for qwen2_vl, qwen2_5_vl
|
||||
if getattr(model.config, "hf_model_type", None) in ["qwen2_vl", "qwen2_5_vl"]:
|
||||
if getattr(model.config, "hf_model_type", None) in ["qwen2_vl", "qwen2_5_vl", "qwen3_vl"]:
|
||||
params_to_freeze = []
|
||||
if finetuning_args.freeze_vision_tower:
|
||||
params_to_freeze.extend(["vision_model.blocks", "vision_model.patch_embed"])
|
||||
@@ -220,11 +216,11 @@ def run_sft(
|
||||
|
||||
|
||||
def run_dpo(
|
||||
model_args: ModelArguments,
|
||||
data_args: DataArguments,
|
||||
training_args: McaSeq2SeqTrainingArguments,
|
||||
finetuning_args: FinetuningArguments,
|
||||
callbacks: list[TrainerCallback] | None = None,
|
||||
model_args: "ModelArguments",
|
||||
data_args: "DataArguments",
|
||||
training_args: "McaSeq2SeqTrainingArguments",
|
||||
finetuning_args: "FinetuningArguments",
|
||||
callbacks: Optional[list["TrainerCallback"]] = None,
|
||||
):
|
||||
tokenizer_module = load_tokenizer(model_args)
|
||||
tokenizer = tokenizer_module["tokenizer"]
|
||||
|
||||
@@ -33,12 +33,12 @@ from transformers.trainer_pt_utils import remove_dummy_checkpoint
|
||||
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
|
||||
from transformers.utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME
|
||||
from trl import PPOConfig, PPOTrainer
|
||||
from trl.core import PPODecorators, logprobs_from_logits
|
||||
from trl import __version__ as trl_version
|
||||
from trl.models.utils import unwrap_model_for_generation
|
||||
from typing_extensions import override
|
||||
|
||||
from ...extras import logging
|
||||
from ...extras.misc import AverageMeter, count_parameters, get_current_device, get_logits_processor
|
||||
from ...extras.misc import AverageMeter, count_parameters, get_current_device, get_logits_processor, torch_gc
|
||||
from ..callbacks import FixValueHeadModelCallback, SaveProcessorCallback
|
||||
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
|
||||
from .ppo_utils import dump_layernorm, get_rewards_from_server, replace_model, restore_layernorm
|
||||
@@ -83,6 +83,19 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||
if eval_dataset is not None:
|
||||
raise NotImplementedError("PPOTrainer does not support eval dataset yet.")
|
||||
|
||||
# Check if TRL version is compatible (0.8.6 <= version <= 0.9.6)
|
||||
try:
|
||||
from transformers.utils.versions import require_version
|
||||
|
||||
require_version(
|
||||
"trl>=0.8.6,<=0.9.6",
|
||||
"Incompatible TRL version detected. LLaMA-Factory ppo requires TRL version >=0.8.6,<=0.9.6. "
|
||||
f"Found version {trl_version}. Please install the correct version with: `pip install trl>=0.8.6,<=0.9.6`\n"
|
||||
"To fix: run `DISABLE_VERSION_CHECK=1 llamafactory-cli train example_ppo.yaml`\n",
|
||||
)
|
||||
except ImportError as e:
|
||||
raise e
|
||||
|
||||
backward_batch_size = training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps
|
||||
ppo_config = PPOConfig(
|
||||
model_name=model_args.model_name_or_path,
|
||||
@@ -406,7 +419,6 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||
return rewards.float().detach() # use fp32 type
|
||||
|
||||
@override
|
||||
@PPODecorators.empty_device_cache()
|
||||
def batched_forward_pass(
|
||||
self,
|
||||
model: "AutoModelForCausalLMWithValueHead",
|
||||
@@ -420,6 +432,9 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||
|
||||
Subclass and override to inject custom behavior.
|
||||
"""
|
||||
from trl.core import logprobs_from_logits
|
||||
|
||||
torch_gc()
|
||||
bs = len(queries)
|
||||
fbs = self.config.mini_batch_size
|
||||
all_logprobs = []
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user