mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2026-03-04 10:46:00 +08:00
Compare commits
14 Commits
v0.9.4
...
b4e051bea4
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b4e051bea4 | ||
|
|
d43e1007e8 | ||
|
|
f89d9367e5 | ||
|
|
d22de0d4bf | ||
|
|
ea0b4e2466 | ||
|
|
e944dc442c | ||
|
|
68119e5522 | ||
|
|
f60a6e3d01 | ||
|
|
81b8a50aa5 | ||
|
|
8600530002 | ||
|
|
9ae62c6fc0 | ||
|
|
0087bc253b | ||
|
|
355d5c5e5a | ||
|
|
6fe6bd290b |
3
.github/workflows/tests.yml
vendored
3
.github/workflows/tests.yml
vendored
@@ -70,7 +70,8 @@ jobs:
|
|||||||
run: |
|
run: |
|
||||||
uv venv
|
uv venv
|
||||||
uv pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
|
uv pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
|
||||||
uv pip install -e ".[dev]"
|
uv pip install -e .
|
||||||
|
uv pip install -r requirements/dev.txt
|
||||||
|
|
||||||
- name: Install transformers
|
- name: Install transformers
|
||||||
if: ${{ matrix.transformers }}
|
if: ${{ matrix.transformers }}
|
||||||
|
|||||||
25
.github/workflows/tests_cuda.yml
vendored
25
.github/workflows/tests_cuda.yml
vendored
@@ -35,6 +35,11 @@ jobs:
|
|||||||
group: ${{ github.workflow }}-${{ github.ref }}-${{ matrix.os }}-${{ matrix.python }}
|
group: ${{ github.workflow }}-${{ github.ref }}-${{ matrix.os }}-${{ matrix.python }}
|
||||||
cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
|
cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
|
||||||
|
|
||||||
|
env:
|
||||||
|
HF_HOME: "${{ github.workspace }}/../.runner_cache/huggingface"
|
||||||
|
UV_CACHE_DIR: "${{ github.workspace }}/../.runner_cache/uv"
|
||||||
|
UV_NO_SYNC: 1
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
@@ -52,37 +57,21 @@ jobs:
|
|||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: |
|
run: |
|
||||||
uv venv
|
uv venv
|
||||||
uv pip install -e ".[dev]"
|
uv pip install -e .
|
||||||
|
uv pip install -r requirements/dev.txt
|
||||||
- name: Cache HuggingFace models
|
|
||||||
id: hf-hub-cache
|
|
||||||
uses: actions/cache@v4
|
|
||||||
with:
|
|
||||||
path: ${{ runner.temp }}/huggingface
|
|
||||||
key: hf-cache-${{ runner.os }}-${{ hashFiles('tests/version.txt') }}
|
|
||||||
|
|
||||||
- name: Check quality
|
- name: Check quality
|
||||||
run: |
|
run: |
|
||||||
make style && make quality
|
make style && make quality
|
||||||
env:
|
|
||||||
UV_NO_SYNC: 1
|
|
||||||
|
|
||||||
- name: Check license
|
- name: Check license
|
||||||
run: |
|
run: |
|
||||||
make license
|
make license
|
||||||
env:
|
|
||||||
UV_NO_SYNC: 1
|
|
||||||
|
|
||||||
- name: Check build
|
- name: Check build
|
||||||
run: |
|
run: |
|
||||||
make build
|
make build
|
||||||
env:
|
|
||||||
UV_NO_SYNC: 1
|
|
||||||
|
|
||||||
- name: Test with pytest
|
- name: Test with pytest
|
||||||
run: |
|
run: |
|
||||||
make test
|
make test
|
||||||
env:
|
|
||||||
UV_NO_SYNC: 1
|
|
||||||
HF_HOME: ${{ runner.temp }}/huggingface
|
|
||||||
HF_HUB_OFFLINE: "${{ steps.hf-hub-cache.outputs.cache-hit == 'true' && '1' || '0' }}"
|
|
||||||
|
|||||||
5
.github/workflows/tests_npu.yml
vendored
5
.github/workflows/tests_npu.yml
vendored
@@ -58,8 +58,9 @@ jobs:
|
|||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: |
|
run: |
|
||||||
uv venv
|
uv venv
|
||||||
uv pip install torch-npu==${{matrix.pytorch_npu}}
|
uv pip install -r requirements/npu.txt
|
||||||
uv pip install -e ".[dev]"
|
uv pip install -e .
|
||||||
|
uv pip install -r requirements/dev.txt
|
||||||
|
|
||||||
- name: Install node
|
- name: Install node
|
||||||
run: |
|
run: |
|
||||||
|
|||||||
45
README.md
45
README.md
@@ -329,6 +329,7 @@ Read technical notes:
|
|||||||
| [StarCoder 2](https://huggingface.co/bigcode) | 3B/7B/15B | - |
|
| [StarCoder 2](https://huggingface.co/bigcode) | 3B/7B/15B | - |
|
||||||
| [VibeThinker-1.5B](https://huggingface.co/WeiboAI) | 1.5B | qwen3 |
|
| [VibeThinker-1.5B](https://huggingface.co/WeiboAI) | 1.5B | qwen3 |
|
||||||
| [Yi/Yi-1.5 (Code)](https://huggingface.co/01-ai) | 1.5B/6B/9B/34B | yi |
|
| [Yi/Yi-1.5 (Code)](https://huggingface.co/01-ai) | 1.5B/6B/9B/34B | yi |
|
||||||
|
| [Youtu-LLM](https://huggingface.co/tencent/) | 2B | youtu |
|
||||||
| [Yuan 2](https://huggingface.co/IEITYuan) | 2B/51B/102B | yuan |
|
| [Yuan 2](https://huggingface.co/IEITYuan) | 2B/51B/102B | yuan |
|
||||||
|
|
||||||
> [!NOTE]
|
> [!NOTE]
|
||||||
@@ -514,12 +515,13 @@ huggingface-cli login
|
|||||||
#### Install from Source
|
#### Install from Source
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
git clone --depth 1 https://github.com/hiyouga/LLaMA-Factory.git
|
git clone --depth 1 https://github.com/hiyouga/LlamaFactory.git
|
||||||
cd LLaMA-Factory
|
cd LlamaFactory
|
||||||
pip install -e ".[metrics]"
|
pip install -e .
|
||||||
|
pip install -r requirements/metrics.txt
|
||||||
```
|
```
|
||||||
|
|
||||||
Optional dependencies available: `metrics`, `deepspeed`. Install with: `pip install -e ".[metrics,deepspeed]"`
|
Optional dependencies available: `metrics`, `deepspeed`. Install with: `pip install -e . && pip install -r requirements/metrics.txt -r requirements/deepspeed.txt`
|
||||||
|
|
||||||
Additional dependencies for specific features are available in `examples/requirements/`.
|
Additional dependencies for specific features are available in `examples/requirements/`.
|
||||||
|
|
||||||
@@ -577,36 +579,21 @@ To enable FlashAttention-2 on the Windows platform, please use the script from [
|
|||||||
|
|
||||||
<details><summary>For Ascend NPU users</summary>
|
<details><summary>For Ascend NPU users</summary>
|
||||||
|
|
||||||
To install LLaMA Factory on Ascend NPU devices, please upgrade Python to version 3.10 or higher: `pip install -e . torch-npu==2.7.1`. Additionally, you need to install the **[Ascend CANN Toolkit and Kernels](https://www.hiascend.com/developer/download/community/result?module=cann)**. Please follow the [installation tutorial](https://www.hiascend.com/document/detail/en/CANNCommunityEdition/600alphaX/softwareinstall/instg/atlasdeploy_03_0031.html) or use the following commands:
|
To install LLaMA Factory on Ascend NPU devices, please upgrade Python to version 3.10 or higher: `pip install -r requirements/npu.txt`. Additionally, you need to install the **Ascend CANN Toolkit and Kernels**. Please follow the [installation tutorial](https://llamafactory.readthedocs.io/en/latest/advanced/npu_installation.html).
|
||||||
|
|
||||||
|
|
||||||
|
You can also download the pre-built Docker images:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# replace the url according to your CANN version and devices
|
# Docker Hub
|
||||||
# install CANN Toolkit
|
docker pull hiyouga/llamafactory:latest-npu-a2
|
||||||
wget https://ascend-repo.obs.cn-east-2.myhuaweicloud.com/Milan-ASL/Milan-ASL%20V100R001C20SPC702/Ascend-cann-toolkit_8.0.0.alpha002_linux-"$(uname -i)".run
|
docker pull hiyouga/llamafactory:latest-npu-a3
|
||||||
bash Ascend-cann-toolkit_8.0.0.alpha002_linux-"$(uname -i)".run --install
|
|
||||||
|
|
||||||
# install CANN Kernels
|
# quay.io
|
||||||
wget https://ascend-repo.obs.cn-east-2.myhuaweicloud.com/Milan-ASL/Milan-ASL%20V100R001C20SPC702/Ascend-cann-kernels-910b_8.0.0.alpha002_linux-"$(uname -i)".run
|
docker pull quay.io/ascend/llamafactory:latest-npu-a2
|
||||||
bash Ascend-cann-kernels-910b_8.0.0.alpha002_linux-"$(uname -i)".run --install
|
docker pull quay.io/ascend/llamafactory:latest-npu-a3
|
||||||
|
|
||||||
# set env variables
|
|
||||||
source /usr/local/Ascend/ascend-toolkit/set_env.sh
|
|
||||||
```
|
```
|
||||||
|
|
||||||
| Requirement | Minimum | Recommend |
|
|
||||||
| ------------ | ------- | -------------- |
|
|
||||||
| CANN | 8.0.RC1 | 8.0.0.alpha002 |
|
|
||||||
| torch | 2.1.0 | 2.7.1 |
|
|
||||||
| torch-npu | 2.1.0 | 2.7.1 |
|
|
||||||
| deepspeed | 0.13.2 | 0.13.2 |
|
|
||||||
| vllm-ascend | - | 0.7.3 |
|
|
||||||
|
|
||||||
Remember to use `ASCEND_RT_VISIBLE_DEVICES` instead of `CUDA_VISIBLE_DEVICES` to specify the device to use.
|
|
||||||
|
|
||||||
If you cannot infer model on NPU devices, try setting `do_sample: false` in the configurations.
|
|
||||||
|
|
||||||
Download the pre-built Docker images: [32GB](http://mirrors.cn-central-221.ovaijisuan.com/detail/130.html) | [64GB](http://mirrors.cn-central-221.ovaijisuan.com/detail/131.html)
|
|
||||||
|
|
||||||
#### Install BitsAndBytes
|
#### Install BitsAndBytes
|
||||||
|
|
||||||
To use QLoRA based on bitsandbytes on Ascend NPU, please follow these 3 steps:
|
To use QLoRA based on bitsandbytes on Ascend NPU, please follow these 3 steps:
|
||||||
|
|||||||
44
README_zh.md
44
README_zh.md
@@ -331,6 +331,7 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
|
|||||||
| [StarCoder 2](https://huggingface.co/bigcode) | 3B/7B/15B | - |
|
| [StarCoder 2](https://huggingface.co/bigcode) | 3B/7B/15B | - |
|
||||||
| [VibeThinker-1.5B](https://huggingface.co/WeiboAI) | 1.5B | qwen3 |
|
| [VibeThinker-1.5B](https://huggingface.co/WeiboAI) | 1.5B | qwen3 |
|
||||||
| [Yi/Yi-1.5 (Code)](https://huggingface.co/01-ai) | 1.5B/6B/9B/34B | yi |
|
| [Yi/Yi-1.5 (Code)](https://huggingface.co/01-ai) | 1.5B/6B/9B/34B | yi |
|
||||||
|
| [Youtu-LLM](https://huggingface.co/tencent/) | 2B | youtu |
|
||||||
| [Yuan 2](https://huggingface.co/IEITYuan) | 2B/51B/102B | yuan |
|
| [Yuan 2](https://huggingface.co/IEITYuan) | 2B/51B/102B | yuan |
|
||||||
|
|
||||||
> [!NOTE]
|
> [!NOTE]
|
||||||
@@ -516,12 +517,13 @@ huggingface-cli login
|
|||||||
#### 从源码安装
|
#### 从源码安装
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
git clone --depth 1 https://github.com/hiyouga/LLaMA-Factory.git
|
git clone --depth 1 https://github.com/hiyouga/LlamaFactory.git
|
||||||
cd LLaMA-Factory
|
cd LlamaFactory
|
||||||
pip install -e ".[metrics]"
|
pip install -e .
|
||||||
|
pip install -r requirements/metrics.txt
|
||||||
```
|
```
|
||||||
|
|
||||||
可选的额外依赖项:`metrics`、`deepspeed`。使用 `pip install -e ".[metrics,deepspeed]"` 安装。
|
可选的额外依赖项:`metrics`、`deepspeed`。使用 `pip install -e . && pip install -r requirements/metrics.txt -r requirements/deepspeed.txt` 安装。
|
||||||
|
|
||||||
其他可选依赖项请参考 `examples/requirements/` 目录下的文件。
|
其他可选依赖项请参考 `examples/requirements/` 目录下的文件。
|
||||||
|
|
||||||
@@ -579,36 +581,20 @@ pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/downl
|
|||||||
|
|
||||||
<details><summary>昇腾 NPU 用户指南</summary>
|
<details><summary>昇腾 NPU 用户指南</summary>
|
||||||
|
|
||||||
在昇腾 NPU 设备上安装 LLaMA Factory 时,请升级 Python 到 3.10 及以上,并需要指定额外依赖项,使用 `pip install -e . torch-npu==2.7.1` 命令安装。此外,还需要安装 **[Ascend CANN Toolkit 与 Kernels](https://www.hiascend.com/developer/download/community/result?module=cann)**,安装方法请参考[安装教程](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/80RC2alpha002/quickstart/quickstart/quickstart_18_0004.html)或使用以下命令:
|
在昇腾 NPU 设备上安装 LLaMA Factory 时,请升级 Python 到 3.10 及以上,并需要指定额外依赖项,使用 `pip install -r requirements/npu.txt` 命令安装。此外,还需要安装 **Ascend CANN Toolkit 与 Kernels**,安装方法请参考[安装教程](https://llamafactory.readthedocs.io/zh-cn/latest/advanced/npu_installation.html)。
|
||||||
|
|
||||||
|
您可以直接下载预安装的最新docker镜像:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# 请替换 URL 为 CANN 版本和设备型号对应的 URL
|
# Docker Hub
|
||||||
# 安装 CANN Toolkit
|
docker pull hiyouga/llamafactory:latest-npu-a2
|
||||||
wget https://ascend-repo.obs.cn-east-2.myhuaweicloud.com/Milan-ASL/Milan-ASL%20V100R001C17SPC701/Ascend-cann-toolkit_8.0.RC1.alpha001_linux-"$(uname -i)".run
|
docker pull hiyouga/llamafactory:latest-npu-a3
|
||||||
bash Ascend-cann-toolkit_8.0.RC1.alpha001_linux-"$(uname -i)".run --install
|
|
||||||
|
|
||||||
# 安装 CANN Kernels
|
# quay.io
|
||||||
wget https://ascend-repo.obs.cn-east-2.myhuaweicloud.com/Milan-ASL/Milan-ASL%20V100R001C17SPC701/Ascend-cann-kernels-910b_8.0.RC1.alpha001_linux.run
|
docker pull quay.io/ascend/llamafactory:latest-npu-a2
|
||||||
bash Ascend-cann-kernels-910b_8.0.RC1.alpha001_linux.run --install
|
docker pull quay.io/ascend/llamafactory:latest-npu-a3
|
||||||
|
|
||||||
# 设置环境变量
|
|
||||||
source /usr/local/Ascend/ascend-toolkit/set_env.sh
|
|
||||||
```
|
```
|
||||||
|
|
||||||
| 依赖项 | 至少 | 推荐 |
|
|
||||||
| ------------ | ------- | -------------- |
|
|
||||||
| CANN | 8.0.RC1 | 8.0.0.alpha002 |
|
|
||||||
| torch | 2.1.0 | 2.7.1 |
|
|
||||||
| torch-npu | 2.1.0 | 2.7.1 |
|
|
||||||
| deepspeed | 0.13.2 | 0.13.2 |
|
|
||||||
| vllm-ascend | - | 0.7.3 |
|
|
||||||
|
|
||||||
请使用 `ASCEND_RT_VISIBLE_DEVICES` 而非 `CUDA_VISIBLE_DEVICES` 来指定运算设备。
|
|
||||||
|
|
||||||
如果遇到无法正常推理的情况,请尝试设置 `do_sample: false`。
|
|
||||||
|
|
||||||
下载预构建 Docker 镜像:[32GB](http://mirrors.cn-central-221.ovaijisuan.com/detail/130.html) | [64GB](http://mirrors.cn-central-221.ovaijisuan.com/detail/131.html)
|
|
||||||
|
|
||||||
#### 安装 BitsAndBytes
|
#### 安装 BitsAndBytes
|
||||||
|
|
||||||
如果要在 Ascend NPU 上进行基于 bitsandbytes 的 QLoRA 量化微调,请执行如下步骤:
|
如果要在 Ascend NPU 上进行基于 bitsandbytes 的 QLoRA 量化微调,请执行如下步骤:
|
||||||
|
|||||||
File diff suppressed because one or more lines are too long
@@ -32,7 +32,8 @@ RUN pip config set global.index-url "${PIP_INDEX}" && \
|
|||||||
COPY . /app
|
COPY . /app
|
||||||
|
|
||||||
# Install LLaMA Factory
|
# Install LLaMA Factory
|
||||||
RUN pip install --no-cache-dir --no-build-isolation -e ".[metrics,deepspeed]"
|
RUN pip install --no-cache-dir --no-build-isolation -e . && \
|
||||||
|
pip install --no-cache-dir --no-build-isolation -r requirements/metrics.txt -r requirements/deepspeed.txt
|
||||||
|
|
||||||
# Rebuild flash attention
|
# Rebuild flash attention
|
||||||
RUN if [ "${INSTALL_FLASHATTN}" == "true" ]; then \
|
RUN if [ "${INSTALL_FLASHATTN}" == "true" ]; then \
|
||||||
|
|||||||
@@ -60,7 +60,8 @@ WORKDIR /app
|
|||||||
COPY . /app
|
COPY . /app
|
||||||
|
|
||||||
# Install LLaMA Factory
|
# Install LLaMA Factory
|
||||||
RUN pip install --no-cache-dir -e ".[metrics]" --no-build-isolation
|
RUN pip install --no-cache-dir -e . --no-build-isolation && \
|
||||||
|
pip install --no-cache-dir -r requirements/metrics.txt --no-build-isolation
|
||||||
|
|
||||||
RUN pip install "git+https://github.com/alibaba/roll.git#subdirectory=mcore_adapter"
|
RUN pip install "git+https://github.com/alibaba/roll.git#subdirectory=mcore_adapter"
|
||||||
|
|
||||||
|
|||||||
@@ -35,7 +35,8 @@ COPY . /app
|
|||||||
# Install torch-npu
|
# Install torch-npu
|
||||||
RUN pip uninstall -y torch torchvision torchaudio && \
|
RUN pip uninstall -y torch torchvision torchaudio && \
|
||||||
pip install --no-cache-dir "torch==2.7.1" "torch-npu==2.7.1" "torchvision==0.22.1" "torchaudio==2.7.1" --index-url "${PYTORCH_INDEX}" && \
|
pip install --no-cache-dir "torch==2.7.1" "torch-npu==2.7.1" "torchvision==0.22.1" "torchaudio==2.7.1" --index-url "${PYTORCH_INDEX}" && \
|
||||||
pip install --no-cache-dir -e ".[metrics]" --no-build-isolation
|
pip install --no-cache-dir -e . --no-build-isolation && \
|
||||||
|
pip install --no-cache-dir -r requirements/metrics.txt --no-build-isolation
|
||||||
|
|
||||||
# Set up volumes
|
# Set up volumes
|
||||||
# VOLUME [ "/root/.cache/huggingface", "/app/shared_data", "/app/output" ]
|
# VOLUME [ "/root/.cache/huggingface", "/app/shared_data", "/app/output" ]
|
||||||
|
|||||||
@@ -34,7 +34,8 @@ COPY . /app
|
|||||||
|
|
||||||
# Reinstall pytorch rocm and install LLaMA Factory
|
# Reinstall pytorch rocm and install LLaMA Factory
|
||||||
RUN pip uninstall -y torch torchvision torchaudio && \
|
RUN pip uninstall -y torch torchvision torchaudio && \
|
||||||
pip install --no-cache-dir --no-build-isolation -e --pre ".[metrics,deepspeed]" --index-url "${PYTORCH_INDEX}"
|
pip install --no-cache-dir --no-build-isolation -e --pre . --index-url "${PYTORCH_INDEX}" && \
|
||||||
|
pip install --no-cache-dir --no-build-isolation -r requirements/metrics.txt -r requirements/deepspeed.txt --index-url "${PYTORCH_INDEX}"
|
||||||
|
|
||||||
# Rebuild flash attention
|
# Rebuild flash attention
|
||||||
RUN if [ "${INSTALL_FLASHATTN}" == "true" ]; then \
|
RUN if [ "${INSTALL_FLASHATTN}" == "true" ]; then \
|
||||||
|
|||||||
40
examples/extras/eaft/qwen25_05b_eaft_full.yaml
Normal file
40
examples/extras/eaft/qwen25_05b_eaft_full.yaml
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
### model
|
||||||
|
model_name_or_path: Qwen/Qwen2.5-0.5B-Instruct
|
||||||
|
trust_remote_code: true
|
||||||
|
|
||||||
|
### method
|
||||||
|
stage: sft
|
||||||
|
do_train: true
|
||||||
|
finetuning_type: full
|
||||||
|
use_eaft_loss: true
|
||||||
|
|
||||||
|
### dataset
|
||||||
|
dataset: identity,alpaca_en_demo
|
||||||
|
template: qwen
|
||||||
|
cutoff_len: 2048
|
||||||
|
max_samples: 1000
|
||||||
|
overwrite_cache: true
|
||||||
|
preprocessing_num_workers: 16
|
||||||
|
dataloader_num_workers: 4
|
||||||
|
|
||||||
|
### output
|
||||||
|
output_dir: qwen2.5-0_5b/full/sft_eaft
|
||||||
|
logging_steps: 1
|
||||||
|
save_steps: 500
|
||||||
|
plot_loss: true
|
||||||
|
overwrite_output_dir: true
|
||||||
|
save_only_model: false
|
||||||
|
report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
|
||||||
|
|
||||||
|
|
||||||
|
### train
|
||||||
|
per_device_train_batch_size: 2
|
||||||
|
gradient_accumulation_steps: 8
|
||||||
|
learning_rate: 1.0e-5
|
||||||
|
num_train_epochs: 3.0
|
||||||
|
lr_scheduler_type: cosine
|
||||||
|
warmup_ratio: 0.1
|
||||||
|
bf16: true
|
||||||
|
ddp_timeout: 180000000
|
||||||
|
|
||||||
|
|
||||||
@@ -73,14 +73,9 @@ dependencies = [
|
|||||||
# api
|
# api
|
||||||
"uvicorn",
|
"uvicorn",
|
||||||
"fastapi",
|
"fastapi",
|
||||||
"sse-starlette"
|
"sse-starlette",
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.optional-dependencies]
|
|
||||||
dev = ["pre-commit", "ruff", "pytest", "build"]
|
|
||||||
metrics = ["nltk", "jieba", "rouge-chinese"]
|
|
||||||
deepspeed = ["deepspeed>=0.10.0,<=0.16.9"]
|
|
||||||
|
|
||||||
[project.scripts]
|
[project.scripts]
|
||||||
llamafactory-cli = "llamafactory.cli:main"
|
llamafactory-cli = "llamafactory.cli:main"
|
||||||
lmf = "llamafactory.cli:main"
|
lmf = "llamafactory.cli:main"
|
||||||
|
|||||||
1
requirements/deepspeed.txt
Normal file
1
requirements/deepspeed.txt
Normal file
@@ -0,0 +1 @@
|
|||||||
|
deepspeed>=0.10.0,<=0.16.9
|
||||||
4
requirements/dev.txt
Normal file
4
requirements/dev.txt
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
pre-commit
|
||||||
|
ruff
|
||||||
|
pytest
|
||||||
|
build
|
||||||
3
requirements/metrics.txt
Normal file
3
requirements/metrics.txt
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
nltk
|
||||||
|
jieba
|
||||||
|
rouge-chinese
|
||||||
4
requirements/npu.txt
Normal file
4
requirements/npu.txt
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
torch==2.7.1
|
||||||
|
torch-npu==2.7.1
|
||||||
|
torchvision==0.22.1
|
||||||
|
torchaudio==2.7.1
|
||||||
@@ -28,7 +28,7 @@ try:
|
|||||||
jieba.setLogLevel(logging.CRITICAL)
|
jieba.setLogLevel(logging.CRITICAL)
|
||||||
jieba.initialize()
|
jieba.initialize()
|
||||||
except ImportError:
|
except ImportError:
|
||||||
print("Please install llamafactory with `pip install -e .[metrics]`.")
|
print("Please install llamafactory with `pip install -r requirements/metrics.txt`.")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1330,6 +1330,26 @@ register_template(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
register_template(
|
||||||
|
name="lfm",
|
||||||
|
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||||
|
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
|
||||||
|
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
|
||||||
|
format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="lfm"),
|
||||||
|
format_observation=StringFormatter(
|
||||||
|
slots=[
|
||||||
|
"<|im_start|>tool\n<|tool_response_start|>{{content}}<|tool_response_end|><|im_end|>\n"
|
||||||
|
"<|im_start|>assistant\n"
|
||||||
|
]
|
||||||
|
),
|
||||||
|
format_tools=ToolFormatter(tool_format="lfm"),
|
||||||
|
default_system="You are a helpful AI assistant.",
|
||||||
|
stop_words=["<|im_end|>"],
|
||||||
|
tool_call_words=("<|tool_call_start|>", "<|tool_call_end|>"),
|
||||||
|
replace_eos=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
register_template(
|
register_template(
|
||||||
name="llama2",
|
name="llama2",
|
||||||
format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]),
|
format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]),
|
||||||
@@ -2278,6 +2298,21 @@ register_template(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
register_template(
|
||||||
|
name="youtu",
|
||||||
|
format_user=StringFormatter(slots=["<|User|>{{content}}<|Assistant|>"]),
|
||||||
|
format_assistant=StringFormatter(slots=["{{content}}<|end_of_text|>"]),
|
||||||
|
format_system=StringFormatter(slots=["{{content}}"]),
|
||||||
|
format_function=FunctionFormatter(slots=["{{content}}"], tool_format="default"),
|
||||||
|
format_observation=StringFormatter(slots=["<tool_response>\n{{content}}\n</tool_response><|Assistant|>"]),
|
||||||
|
format_tools=ToolFormatter(tool_format="default"),
|
||||||
|
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
||||||
|
stop_words=["<|end_of_text|>"],
|
||||||
|
replace_eos=True,
|
||||||
|
template_class=ReasoningTemplate,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
register_template(
|
register_template(
|
||||||
name="yuan",
|
name="yuan",
|
||||||
format_user=StringFormatter(slots=["{{content}}", {"token": "<sep>"}]),
|
format_user=StringFormatter(slots=["{{content}}", {"token": "<sep>"}]),
|
||||||
|
|||||||
@@ -12,6 +12,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import ast
|
||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
@@ -101,6 +102,8 @@ LING_TOOL_PROMPT = (
|
|||||||
""""arguments": <args-json-object>}}\n</tool_call>"""
|
""""arguments": <args-json-object>}}\n</tool_call>"""
|
||||||
)
|
)
|
||||||
|
|
||||||
|
LFM_TOOL_PROMPT = "List of tools: <|tool_list_start|>{tool_text}<|tool_list_end|>"
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ToolUtils(ABC):
|
class ToolUtils(ABC):
|
||||||
@@ -546,10 +549,115 @@ class LingToolUtils(QwenToolUtils):
|
|||||||
return LING_TOOL_PROMPT.format(tool_text=tool_text) + "\n" + "detailed thinking off"
|
return LING_TOOL_PROMPT.format(tool_text=tool_text) + "\n" + "detailed thinking off"
|
||||||
|
|
||||||
|
|
||||||
|
class LFMToolUtils(ToolUtils):
|
||||||
|
r"""LFM 2.5 tool using template with Pythonic function call syntax."""
|
||||||
|
|
||||||
|
@override
|
||||||
|
@staticmethod
|
||||||
|
def tool_formatter(tools: list[dict[str, Any]]) -> str:
|
||||||
|
tool_list = []
|
||||||
|
for tool in tools:
|
||||||
|
tool = tool.get("function", tool) if tool.get("type") == "function" else tool
|
||||||
|
tool_list.append(tool)
|
||||||
|
|
||||||
|
return LFM_TOOL_PROMPT.format(tool_text=json.dumps(tool_list, ensure_ascii=False))
|
||||||
|
|
||||||
|
@override
|
||||||
|
@staticmethod
|
||||||
|
def function_formatter(functions: list["FunctionCall"]) -> str:
|
||||||
|
calls = []
|
||||||
|
for name, args_json in functions:
|
||||||
|
args = json.loads(args_json)
|
||||||
|
kwargs_parts = []
|
||||||
|
for key, value in args.items():
|
||||||
|
if isinstance(value, str):
|
||||||
|
kwargs_parts.append(f'{key}="{value}"')
|
||||||
|
else:
|
||||||
|
kwargs_parts.append(f"{key}={json.dumps(value, ensure_ascii=False)}")
|
||||||
|
|
||||||
|
calls.append(f"{name}({', '.join(kwargs_parts)})")
|
||||||
|
|
||||||
|
return f"<|tool_call_start|>[{', '.join(calls)}]<|tool_call_end|>"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _ast_to_value(node: ast.AST) -> Any:
|
||||||
|
"""Convert an AST node to a Python value, handling JSON-style booleans/null."""
|
||||||
|
# Handle JSON-style true/false/null as Name nodes
|
||||||
|
if isinstance(node, ast.Name):
|
||||||
|
if node.id == "true":
|
||||||
|
return True
|
||||||
|
elif node.id == "false":
|
||||||
|
return False
|
||||||
|
elif node.id == "null":
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown identifier: {node.id}")
|
||||||
|
|
||||||
|
# Use literal_eval for other cases (strings, numbers, lists, dicts)
|
||||||
|
return ast.literal_eval(node)
|
||||||
|
|
||||||
|
@override
|
||||||
|
@staticmethod
|
||||||
|
def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
|
||||||
|
# Extract content between tool call markers
|
||||||
|
start_marker = "<|tool_call_start|>"
|
||||||
|
end_marker = "<|tool_call_end|>"
|
||||||
|
|
||||||
|
start_idx = content.find(start_marker)
|
||||||
|
if start_idx == -1:
|
||||||
|
return content
|
||||||
|
|
||||||
|
end_idx = content.find(end_marker, start_idx)
|
||||||
|
if end_idx == -1:
|
||||||
|
return content
|
||||||
|
|
||||||
|
tool_call_str = content[start_idx + len(start_marker) : end_idx].strip()
|
||||||
|
|
||||||
|
# Parse Pythonic function call syntax using AST
|
||||||
|
try:
|
||||||
|
tree = ast.parse(tool_call_str, mode="eval")
|
||||||
|
except SyntaxError:
|
||||||
|
return content
|
||||||
|
|
||||||
|
# Handle both single call and list of calls
|
||||||
|
if isinstance(tree.body, ast.List):
|
||||||
|
call_nodes = tree.body.elts
|
||||||
|
elif isinstance(tree.body, ast.Call):
|
||||||
|
call_nodes = [tree.body]
|
||||||
|
else:
|
||||||
|
return content
|
||||||
|
|
||||||
|
results = []
|
||||||
|
for node in call_nodes:
|
||||||
|
if not isinstance(node, ast.Call):
|
||||||
|
return content
|
||||||
|
|
||||||
|
# Extract function name
|
||||||
|
if isinstance(node.func, ast.Name):
|
||||||
|
func_name = node.func.id
|
||||||
|
else:
|
||||||
|
return content
|
||||||
|
|
||||||
|
# Extract keyword arguments
|
||||||
|
args_dict = {}
|
||||||
|
for keyword in node.keywords:
|
||||||
|
key = keyword.arg
|
||||||
|
try:
|
||||||
|
value = LFMToolUtils._ast_to_value(keyword.value)
|
||||||
|
except (ValueError, SyntaxError):
|
||||||
|
return content
|
||||||
|
args_dict[key] = value
|
||||||
|
|
||||||
|
results.append(FunctionCall(func_name, json.dumps(args_dict, ensure_ascii=False)))
|
||||||
|
|
||||||
|
return results if results else content
|
||||||
|
|
||||||
|
|
||||||
TOOLS = {
|
TOOLS = {
|
||||||
"default": DefaultToolUtils(),
|
"default": DefaultToolUtils(),
|
||||||
"glm4": GLM4ToolUtils(),
|
"glm4": GLM4ToolUtils(),
|
||||||
"llama3": Llama3ToolUtils(),
|
"llama3": Llama3ToolUtils(),
|
||||||
|
"lfm": LFMToolUtils(),
|
||||||
"minimax1": MiniMaxM1ToolUtils(),
|
"minimax1": MiniMaxM1ToolUtils(),
|
||||||
"minimax2": MiniMaxM2ToolUtils(),
|
"minimax2": MiniMaxM2ToolUtils(),
|
||||||
"mistral": MistralToolUtils(),
|
"mistral": MistralToolUtils(),
|
||||||
|
|||||||
@@ -1493,6 +1493,19 @@ register_model_group(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
register_model_group(
|
||||||
|
models={
|
||||||
|
"LFM2.5-1.2B": {
|
||||||
|
DownloadSource.DEFAULT: "LiquidAI/LFM2.5-1.2B-Base",
|
||||||
|
},
|
||||||
|
"LFM2.5-1.2B-Instruct": {
|
||||||
|
DownloadSource.DEFAULT: "LiquidAI/LFM2.5-1.2B-Instruct",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
template="lfm",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
register_model_group(
|
register_model_group(
|
||||||
models={
|
models={
|
||||||
"Llama-7B": {
|
"Llama-7B": {
|
||||||
@@ -3846,6 +3859,21 @@ register_model_group(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
register_model_group(
|
||||||
|
models={
|
||||||
|
"Youtu-LLM-2B-Instruct": {
|
||||||
|
DownloadSource.DEFAULT: "tencent/Youtu-LLM-2B",
|
||||||
|
DownloadSource.MODELSCOPE: "Tencent-YouTu-Research/Youtu-LLM-2B",
|
||||||
|
},
|
||||||
|
"Youtu-LLM-2B-Base": {
|
||||||
|
DownloadSource.DEFAULT: "tencent/Youtu-LLM-2B-Base",
|
||||||
|
DownloadSource.MODELSCOPE: "Tencent-YouTu-Research/Youtu-LLM-2B-Base",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
template="youtu",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
register_model_group(
|
register_model_group(
|
||||||
models={
|
models={
|
||||||
"Yuan2-2B-Chat": {
|
"Yuan2-2B-Chat": {
|
||||||
|
|||||||
@@ -19,7 +19,7 @@
|
|||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
|
|
||||||
|
|
||||||
VERSION = "0.9.4"
|
VERSION = "0.9.5.dev0"
|
||||||
|
|
||||||
|
|
||||||
def print_env() -> None:
|
def print_env() -> None:
|
||||||
|
|||||||
@@ -490,6 +490,14 @@ class FinetuningArguments(
|
|||||||
default=False,
|
default=False,
|
||||||
metadata={"help": "Whether to use the DFT loss."},
|
metadata={"help": "Whether to use the DFT loss."},
|
||||||
)
|
)
|
||||||
|
use_eaft_loss: bool = field(
|
||||||
|
default=False,
|
||||||
|
metadata={"help": "Whether to use the EAFT loss."},
|
||||||
|
)
|
||||||
|
eaft_alpha: float = field(
|
||||||
|
default=1.0,
|
||||||
|
metadata={"help": "The alpha parameter for EAFT loss to control the power of adaptive weight."},
|
||||||
|
)
|
||||||
freeze_vision_tower: bool = field(
|
freeze_vision_tower: bool = field(
|
||||||
default=True,
|
default=True,
|
||||||
metadata={"help": "Whether ot not to freeze the vision tower in MLLM training."},
|
metadata={"help": "Whether ot not to freeze the vision tower in MLLM training."},
|
||||||
|
|||||||
@@ -298,23 +298,6 @@ class QuantizationArguments:
|
|||||||
default=None,
|
default=None,
|
||||||
metadata={"help": "Device map used to infer the 4-bit quantized model, needs bitsandbytes>=0.43.0."},
|
metadata={"help": "Device map used to infer the 4-bit quantized model, needs bitsandbytes>=0.43.0."},
|
||||||
)
|
)
|
||||||
fp8: bool = field(
|
|
||||||
default=False,
|
|
||||||
metadata={
|
|
||||||
"help": "Enable FP8 mixed precision training via HuggingFace Accelerate. "
|
|
||||||
"Requires PyTorch 2.7+ and Hopper architecture GPUs."
|
|
||||||
},
|
|
||||||
)
|
|
||||||
fp8_backend: str = field(
|
|
||||||
default="auto",
|
|
||||||
metadata={
|
|
||||||
"help": "FP8 backend to use ('auto', 'torchao', 'te', 'msamp'). 'auto' selects best available backend."
|
|
||||||
},
|
|
||||||
)
|
|
||||||
fp8_enable_fsdp_float8_all_gather: bool = field(
|
|
||||||
default=False,
|
|
||||||
metadata={"help": "Enable FP8 optimizations for FSDP2 all-gather operations."},
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|||||||
@@ -142,14 +142,6 @@ def _verify_model_args(
|
|||||||
logger.warning_rank0("We should use slow tokenizer for the Yi models. Change `use_fast_tokenizer` to False.")
|
logger.warning_rank0("We should use slow tokenizer for the Yi models. Change `use_fast_tokenizer` to False.")
|
||||||
model_args.use_fast_tokenizer = False
|
model_args.use_fast_tokenizer = False
|
||||||
|
|
||||||
# Validate advanced training features
|
|
||||||
if model_args.fp8 and model_args.quantization_bit is not None:
|
|
||||||
raise ValueError("FP8 training is not compatible with quantization. Please disable one of them.")
|
|
||||||
|
|
||||||
if model_args.fp8_enable_fsdp_float8_all_gather and not model_args.fp8:
|
|
||||||
logger.warning_rank0("fp8_enable_fsdp_float8_all_gather requires fp8=True. Setting fp8=True.")
|
|
||||||
model_args.fp8 = True
|
|
||||||
|
|
||||||
|
|
||||||
def _check_extra_dependencies(
|
def _check_extra_dependencies(
|
||||||
model_args: "ModelArguments",
|
model_args: "ModelArguments",
|
||||||
@@ -347,6 +339,9 @@ def get_train_args(args: dict[str, Any] | list[str] | None = None) -> _TRAIN_CLS
|
|||||||
if training_args.deepspeed is not None and (finetuning_args.use_galore or finetuning_args.use_apollo):
|
if training_args.deepspeed is not None and (finetuning_args.use_galore or finetuning_args.use_apollo):
|
||||||
raise ValueError("GaLore and APOLLO are incompatible with DeepSpeed yet.")
|
raise ValueError("GaLore and APOLLO are incompatible with DeepSpeed yet.")
|
||||||
|
|
||||||
|
if training_args.fp8 and training_args.quantization_bit is not None:
|
||||||
|
raise ValueError("FP8 training is not compatible with quantization. Please disable one of them.")
|
||||||
|
|
||||||
if model_args.infer_backend != EngineName.HF:
|
if model_args.infer_backend != EngineName.HF:
|
||||||
raise ValueError("vLLM/SGLang backend is only available for API, CLI and Web.")
|
raise ValueError("vLLM/SGLang backend is only available for API, CLI and Web.")
|
||||||
|
|
||||||
@@ -363,6 +358,10 @@ def get_train_args(args: dict[str, Any] | list[str] | None = None) -> _TRAIN_CLS
|
|||||||
_verify_model_args(model_args, data_args, finetuning_args)
|
_verify_model_args(model_args, data_args, finetuning_args)
|
||||||
_check_extra_dependencies(model_args, finetuning_args, training_args)
|
_check_extra_dependencies(model_args, finetuning_args, training_args)
|
||||||
|
|
||||||
|
if training_args.fp8_enable_fsdp_float8_all_gather and not training_args.fp8:
|
||||||
|
logger.warning_rank0("fp8_enable_fsdp_float8_all_gather requires fp8=True. Setting fp8=True.")
|
||||||
|
model_args.fp8 = True
|
||||||
|
|
||||||
if (
|
if (
|
||||||
training_args.do_train
|
training_args.do_train
|
||||||
and finetuning_args.finetuning_type == "lora"
|
and finetuning_args.finetuning_type == "lora"
|
||||||
|
|||||||
@@ -92,7 +92,30 @@ class RayArguments:
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class TrainingArguments(RayArguments, BaseTrainingArguments):
|
class Fp8Arguments:
|
||||||
|
r"""Arguments pertaining to the FP8 training."""
|
||||||
|
|
||||||
|
fp8: bool = field(
|
||||||
|
default=False,
|
||||||
|
metadata={
|
||||||
|
"help": "Enable FP8 mixed precision training via HuggingFace Accelerate. "
|
||||||
|
"Requires PyTorch 2.7+ and Hopper architecture GPUs."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
fp8_backend: str = field(
|
||||||
|
default="auto",
|
||||||
|
metadata={
|
||||||
|
"help": "FP8 backend to use ('auto', 'torchao', 'te', 'msamp'). 'auto' selects best available backend."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
fp8_enable_fsdp_float8_all_gather: bool = field(
|
||||||
|
default=False,
|
||||||
|
metadata={"help": "Enable FP8 optimizations for FSDP2 all-gather operations."},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TrainingArguments(Fp8Arguments, RayArguments, BaseTrainingArguments):
|
||||||
r"""Arguments pertaining to the trainer."""
|
r"""Arguments pertaining to the trainer."""
|
||||||
|
|
||||||
overwrite_output_dir: bool = field(
|
overwrite_output_dir: bool = field(
|
||||||
|
|||||||
@@ -15,6 +15,7 @@
|
|||||||
import os
|
import os
|
||||||
from typing import TYPE_CHECKING, Any, Optional, TypedDict
|
from typing import TYPE_CHECKING, Any, Optional, TypedDict
|
||||||
|
|
||||||
|
import torch
|
||||||
from transformers import (
|
from transformers import (
|
||||||
AutoConfig,
|
AutoConfig,
|
||||||
AutoModelForCausalLM,
|
AutoModelForCausalLM,
|
||||||
@@ -29,6 +30,7 @@ from trl import AutoModelForCausalLMWithValueHead
|
|||||||
|
|
||||||
from ..extras import logging
|
from ..extras import logging
|
||||||
from ..extras.misc import count_parameters, skip_check_imports, try_download_model_from_other_hub
|
from ..extras.misc import count_parameters, skip_check_imports, try_download_model_from_other_hub
|
||||||
|
from ..extras.packages import is_torch_version_greater_than
|
||||||
from .adapter import init_adapter
|
from .adapter import init_adapter
|
||||||
from .model_utils.ktransformers import load_kt_pretrained_model
|
from .model_utils.ktransformers import load_kt_pretrained_model
|
||||||
from .model_utils.liger_kernel import apply_liger_kernel
|
from .model_utils.liger_kernel import apply_liger_kernel
|
||||||
@@ -203,6 +205,16 @@ def load_model(
|
|||||||
model.load_state_dict(vhead_params, strict=False)
|
model.load_state_dict(vhead_params, strict=False)
|
||||||
logger.info_rank0(f"Loaded valuehead from checkpoint: {vhead_path}")
|
logger.info_rank0(f"Loaded valuehead from checkpoint: {vhead_path}")
|
||||||
|
|
||||||
|
# Conv3D is not recommended when using torch 2.9.x
|
||||||
|
if is_torch_version_greater_than("2.9.0") and not is_torch_version_greater_than("2.10.0"):
|
||||||
|
if any(isinstance(m, torch.nn.Conv3d) for m in model.modules()):
|
||||||
|
raise ValueError(
|
||||||
|
"Unsupported torch version detected: torch 2.9.x with Conv3D. "
|
||||||
|
"This combination is known to cause severe performance regression. "
|
||||||
|
"Please downgrade torch to <2.9 or remove Conv3D. "
|
||||||
|
"See https://github.com/pytorch/pytorch/issues/166122"
|
||||||
|
)
|
||||||
|
|
||||||
if not is_trainable:
|
if not is_trainable:
|
||||||
model.requires_grad_(False)
|
model.requires_grad_(False)
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|||||||
@@ -138,13 +138,14 @@ def patch_config(
|
|||||||
if getattr(config, "model_type", None) == "kimi_vl" and is_trainable:
|
if getattr(config, "model_type", None) == "kimi_vl" and is_trainable:
|
||||||
setattr(config.text_config, "topk_method", "greedy")
|
setattr(config.text_config, "topk_method", "greedy")
|
||||||
|
|
||||||
if "InternVLChatModel" in getattr(config, "architectures", []):
|
architectures = getattr(config, "architectures", None)
|
||||||
|
if isinstance(architectures, list) and "InternVLChatModel" in architectures:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Please download the internvl models in a Hugging Face–compatible format "
|
"Please download the internvl models in a Hugging Face–compatible format "
|
||||||
"(for example, https://huggingface.co/OpenGVLab/InternVL3-8B-hf)."
|
"(for example, https://huggingface.co/OpenGVLab/InternVL3-8B-hf)."
|
||||||
)
|
)
|
||||||
|
|
||||||
if "LlavaLlamaForCausalLM" in getattr(config, "architectures", []):
|
if isinstance(architectures, list) and "LlavaLlamaForCausalLM" in architectures:
|
||||||
raise ValueError("Please download llava models with hf-compatible format: https://huggingface.co/llava-hf")
|
raise ValueError("Please download llava models with hf-compatible format: https://huggingface.co/llava-hf")
|
||||||
|
|
||||||
if getattr(config, "model_type", None) == "internlm3" and not is_transformers_version_greater_than("4.47.1"):
|
if getattr(config, "model_type", None) == "internlm3" and not is_transformers_version_greater_than("4.47.1"):
|
||||||
|
|||||||
@@ -12,35 +12,45 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import os
|
||||||
|
import types
|
||||||
from typing import TYPE_CHECKING, Any, Optional
|
from typing import TYPE_CHECKING, Any, Optional
|
||||||
|
|
||||||
from ..extras import logging
|
from ..extras import logging
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from ..hparams import ModelArguments
|
from ..hparams import TrainingArguments
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def create_fp8_kwargs(model_args: "ModelArguments") -> list[Any]:
|
def create_fp8_kwargs(training_args: "TrainingArguments") -> list[Any]:
|
||||||
"""Create AORecipeKwargs for FP8 training with HuggingFace Accelerate.
|
"""Create AORecipeKwargs for FP8 training with HuggingFace Accelerate.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model_args: Model arguments containing FP8 configuration
|
training_args: Training arguments containing FP8 configuration
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List containing AORecipeKwargs if FP8 is enabled and supported, empty list otherwise
|
List containing AORecipeKwargs if FP8 is enabled and supported, empty list otherwise
|
||||||
"""
|
"""
|
||||||
if not model_args.fp8:
|
if not training_args.fp8:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
try:
|
backend = getattr(training_args, "fp8_backend", "auto")
|
||||||
# Check if AORecipeKwargs is available (Accelerate 1.8.0+)
|
logger.info_rank0(f"Creating FP8 configuration with backend: {backend}")
|
||||||
from accelerate.utils import AORecipeKwargs
|
|
||||||
|
|
||||||
backend = getattr(model_args, "fp8_backend", "auto")
|
try:
|
||||||
logger.info_rank0(f"Creating FP8 configuration with backend: {backend}")
|
# Use Transformer Engine backend (optimal for Hopper GPUs)
|
||||||
|
if backend == "te":
|
||||||
|
from accelerate.utils import FP8RecipeKwargs
|
||||||
|
|
||||||
|
logger.info_rank0("Using Transformer Engine FP8 backend")
|
||||||
|
return [FP8RecipeKwargs(backend="TE", fp8_format="HYBRID", amax_history_len=16, amax_compute_algo="max")]
|
||||||
|
|
||||||
|
# Use TorchAO backend (default)
|
||||||
|
from accelerate.utils import AORecipeKwargs
|
||||||
|
|
||||||
# Create Float8LinearConfig if torchao backend is used
|
# Create Float8LinearConfig if torchao backend is used
|
||||||
config = None
|
config = None
|
||||||
@@ -83,7 +93,10 @@ def create_fp8_kwargs(model_args: "ModelArguments") -> list[Any]:
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
# Map FSDP all-gather setting if available (this affects the underlying implementation)
|
# Map FSDP all-gather setting if available (this affects the underlying implementation)
|
||||||
if hasattr(model_args, "fp8_enable_fsdp_float8_all_gather") and model_args.fp8_enable_fsdp_float8_all_gather:
|
if (
|
||||||
|
hasattr(training_args, "fp8_enable_fsdp_float8_all_gather")
|
||||||
|
and training_args.fp8_enable_fsdp_float8_all_gather
|
||||||
|
):
|
||||||
logger.info_rank0("FSDP float8 all-gather optimization requested")
|
logger.info_rank0("FSDP float8 all-gather optimization requested")
|
||||||
|
|
||||||
return [AORecipeKwargs(config=config, module_filter_func=module_filter_func)]
|
return [AORecipeKwargs(config=config, module_filter_func=module_filter_func)]
|
||||||
@@ -92,19 +105,19 @@ def create_fp8_kwargs(model_args: "ModelArguments") -> list[Any]:
|
|||||||
return []
|
return []
|
||||||
|
|
||||||
|
|
||||||
def get_fp8_mixed_precision(model_args: "ModelArguments") -> Optional[str]:
|
def get_fp8_mixed_precision(training_args: "TrainingArguments") -> Optional[str]:
|
||||||
"""Get the mixed precision setting for Accelerate when using FP8.
|
"""Get the mixed precision setting for Accelerate when using FP8.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model_args: Model arguments containing FP8 configuration
|
training_args: Training arguments containing FP8 configuration
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
"fp8" if FP8 is enabled, None otherwise
|
"fp8" if FP8 is enabled, None otherwise
|
||||||
"""
|
"""
|
||||||
return "fp8" if model_args.fp8 else None
|
return "fp8" if training_args.fp8 else None
|
||||||
|
|
||||||
|
|
||||||
def configure_fp8_environment(model_args: "ModelArguments") -> None:
|
def configure_fp8_environment(training_args: "TrainingArguments") -> None:
|
||||||
"""Configure FP8 environment for HuggingFace Accelerate.
|
"""Configure FP8 environment for HuggingFace Accelerate.
|
||||||
|
|
||||||
FP8 training is handled entirely through HuggingFace Accelerate, regardless of whether
|
FP8 training is handled entirely through HuggingFace Accelerate, regardless of whether
|
||||||
@@ -112,11 +125,9 @@ def configure_fp8_environment(model_args: "ModelArguments") -> None:
|
|||||||
variables and validates the FP8 configuration.
|
variables and validates the FP8 configuration.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model_args: Model arguments containing FP8 configuration
|
training_args: Training arguments containing FP8 configuration
|
||||||
"""
|
"""
|
||||||
import os
|
if not training_args.fp8:
|
||||||
|
|
||||||
if not model_args.fp8:
|
|
||||||
return
|
return
|
||||||
|
|
||||||
# Set mixed precision to fp8 for HuggingFace Accelerate
|
# Set mixed precision to fp8 for HuggingFace Accelerate
|
||||||
@@ -124,38 +135,38 @@ def configure_fp8_environment(model_args: "ModelArguments") -> None:
|
|||||||
logger.info_rank0("Set ACCELERATE_MIXED_PRECISION=fp8")
|
logger.info_rank0("Set ACCELERATE_MIXED_PRECISION=fp8")
|
||||||
|
|
||||||
# Configure FP8 backend and options
|
# Configure FP8 backend and options
|
||||||
backend = getattr(model_args, "fp8_backend", "auto")
|
backend = getattr(training_args, "fp8_backend", "auto")
|
||||||
if backend != "auto":
|
if backend != "auto":
|
||||||
os.environ["FP8_BACKEND"] = backend
|
os.environ["FP8_BACKEND"] = backend
|
||||||
logger.info_rank0(f"Set FP8_BACKEND={backend}")
|
logger.info_rank0(f"Set FP8_BACKEND={backend}")
|
||||||
|
|
||||||
# Create and validate FP8 recipe kwargs (for logging/debugging)
|
# Create and validate FP8 recipe kwargs (for logging/debugging)
|
||||||
fp8_kwargs = create_fp8_kwargs(model_args)
|
fp8_kwargs = create_fp8_kwargs(training_args)
|
||||||
logger.info_rank0(f"FP8 AORecipeKwargs created: {len(fp8_kwargs)} items")
|
logger.info_rank0(f"FP8 AORecipeKwargs created: {len(fp8_kwargs)} items")
|
||||||
|
|
||||||
# Enable FSDP float8 all-gather optimization if requested
|
# Enable FSDP float8 all-gather optimization if requested
|
||||||
if hasattr(model_args, "fp8_enable_fsdp_float8_all_gather") and model_args.fp8_enable_fsdp_float8_all_gather:
|
if hasattr(training_args, "fp8_enable_fsdp_float8_all_gather") and training_args.fp8_enable_fsdp_float8_all_gather:
|
||||||
os.environ["FP8_ENABLE_FSDP_FLOAT8_ALL_GATHER"] = "true"
|
os.environ["FP8_ENABLE_FSDP_FLOAT8_ALL_GATHER"] = "true"
|
||||||
logger.info_rank0("Set FP8_ENABLE_FSDP_FLOAT8_ALL_GATHER=true")
|
logger.info_rank0("Set FP8_ENABLE_FSDP_FLOAT8_ALL_GATHER=true")
|
||||||
|
|
||||||
logger.info_rank0("FP8 environment configured - all FP8 training handled by HuggingFace Accelerate")
|
logger.info_rank0("FP8 environment configured - all FP8 training handled by HuggingFace Accelerate")
|
||||||
|
|
||||||
|
|
||||||
def verify_fp8_status(accelerator, model_args: "ModelArguments") -> None:
|
def verify_fp8_status(accelerator, training_args: "TrainingArguments") -> None:
|
||||||
"""Verify that FP8 training is actually working after model preparation.
|
"""Verify that FP8 training is actually working after model preparation.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
accelerator: The HuggingFace Accelerator instance
|
accelerator: The HuggingFace Accelerator instance
|
||||||
model_args: Model arguments containing FP8 configuration
|
training_args: Training arguments containing FP8 configuration
|
||||||
"""
|
"""
|
||||||
if not model_args.fp8:
|
if not training_args.fp8:
|
||||||
return
|
return
|
||||||
|
|
||||||
# Check Accelerate's FP8 status
|
# Check Accelerate's FP8 status
|
||||||
fp8_enabled = getattr(accelerator, "fp8_enabled", False)
|
fp8_enabled = getattr(accelerator, "fp8_enabled", False)
|
||||||
fp8_backend_type = getattr(accelerator, "fp8_backend", "UNKNOWN")
|
fp8_backend_type = getattr(accelerator, "fp8_backend", "UNKNOWN")
|
||||||
|
|
||||||
backend = getattr(model_args, "fp8_backend", "auto")
|
backend = getattr(training_args, "fp8_backend", "auto")
|
||||||
if backend == "torchao" or backend == "auto":
|
if backend == "torchao" or backend == "auto":
|
||||||
logger.info_rank0(
|
logger.info_rank0(
|
||||||
"FP8 training enabled with TorchAO backend. For optimal performance, "
|
"FP8 training enabled with TorchAO backend. For optimal performance, "
|
||||||
@@ -169,3 +180,50 @@ def verify_fp8_status(accelerator, model_args: "ModelArguments") -> None:
|
|||||||
|
|
||||||
if not fp8_enabled:
|
if not fp8_enabled:
|
||||||
logger.info_rank0("WARNING: FP8 was requested but Accelerate shows fp8_enabled=False. FP8 may not be working.")
|
logger.info_rank0("WARNING: FP8 was requested but Accelerate shows fp8_enabled=False. FP8 may not be working.")
|
||||||
|
|
||||||
|
|
||||||
|
def patch_accelerator_for_fp8() -> None:
|
||||||
|
"""Patch Accelerator to inject FP8 recipe kwargs.
|
||||||
|
|
||||||
|
This is needed because HuggingFace Trainer doesn't pass kwargs_handlers to Accelerator.
|
||||||
|
We monkey-patch Accelerator.__init__ to inject the FP8 recipe and force mixed_precision='fp8'.
|
||||||
|
"""
|
||||||
|
import transformer_engine.pytorch as te
|
||||||
|
from accelerate import Accelerator
|
||||||
|
|
||||||
|
# Guard against multiple patches
|
||||||
|
if getattr(Accelerator, "_te_fp8_patched", False):
|
||||||
|
return
|
||||||
|
|
||||||
|
# Stub for Accelerate 1.12+ compatibility (te.fp8.check_mxfp8_support doesn't exist yet)
|
||||||
|
if not hasattr(te, "fp8"):
|
||||||
|
te.fp8 = types.ModuleType("fp8")
|
||||||
|
te.fp8.check_mxfp8_support = lambda: (False, "MXFP8 not supported")
|
||||||
|
|
||||||
|
try:
|
||||||
|
from accelerate.utils import TERecipeKwargs as FP8Recipe
|
||||||
|
|
||||||
|
use_te_recipe = True
|
||||||
|
except ImportError:
|
||||||
|
from accelerate.utils import FP8RecipeKwargs as FP8Recipe
|
||||||
|
|
||||||
|
use_te_recipe = False
|
||||||
|
|
||||||
|
original_init = Accelerator.__init__
|
||||||
|
|
||||||
|
def patched_init(self, *args, **kwargs):
|
||||||
|
if "kwargs_handlers" not in kwargs or not kwargs["kwargs_handlers"]:
|
||||||
|
if use_te_recipe:
|
||||||
|
kwargs["kwargs_handlers"] = [
|
||||||
|
FP8Recipe(fp8_format="HYBRID", amax_history_len=16, amax_compute_algo="max")
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
kwargs["kwargs_handlers"] = [
|
||||||
|
FP8Recipe(backend="TE", fp8_format="HYBRID", amax_history_len=16, amax_compute_algo="max")
|
||||||
|
]
|
||||||
|
# Only force mixed_precision when we inject handlers
|
||||||
|
kwargs["mixed_precision"] = "fp8"
|
||||||
|
return original_init(self, *args, **kwargs)
|
||||||
|
|
||||||
|
Accelerator.__init__ = patched_init
|
||||||
|
Accelerator._te_fp8_patched = True
|
||||||
|
|||||||
@@ -19,16 +19,15 @@ import torch
|
|||||||
from transformers import Trainer
|
from transformers import Trainer
|
||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
|
|
||||||
from ...extras.packages import is_transformers_version_greater_than
|
|
||||||
from ..callbacks import SaveProcessorCallback
|
from ..callbacks import SaveProcessorCallback
|
||||||
from ..fp8_utils import configure_fp8_environment, verify_fp8_status
|
from ..fp8_utils import configure_fp8_environment, patch_accelerator_for_fp8, verify_fp8_status
|
||||||
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
|
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from transformers import ProcessorMixin
|
from transformers import ProcessorMixin
|
||||||
|
|
||||||
from ...hparams import FinetuningArguments, ModelArguments
|
from ...hparams import FinetuningArguments, ModelArguments, TrainingArguments
|
||||||
|
|
||||||
|
|
||||||
class CustomTrainer(Trainer):
|
class CustomTrainer(Trainer):
|
||||||
@@ -41,11 +40,13 @@ class CustomTrainer(Trainer):
|
|||||||
model_args: Optional["ModelArguments"] = None,
|
model_args: Optional["ModelArguments"] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
kwargs["processing_class"] = kwargs.pop("tokenizer")
|
||||||
# Configure FP8 environment if enabled
|
# Configure FP8 environment if enabled
|
||||||
if model_args is not None and model_args.fp8:
|
training_args: TrainingArguments = kwargs.get("args")
|
||||||
configure_fp8_environment(model_args)
|
if training_args.fp8:
|
||||||
if is_transformers_version_greater_than("4.46"):
|
configure_fp8_environment(training_args)
|
||||||
kwargs["processing_class"] = kwargs.pop("tokenizer")
|
if getattr(training_args, "fp8_backend", "auto") == "te":
|
||||||
|
patch_accelerator_for_fp8()
|
||||||
|
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
if processor is not None:
|
if processor is not None:
|
||||||
@@ -64,9 +65,8 @@ class CustomTrainer(Trainer):
|
|||||||
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
|
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
|
||||||
self.add_callback(BAdamCallback)
|
self.add_callback(BAdamCallback)
|
||||||
|
|
||||||
# Verify FP8 status after trainer initialization (accelerator should be available)
|
if training_args.fp8 and hasattr(self, "accelerator"): # verify FP8 status after trainer initialization
|
||||||
if model_args is not None and model_args.fp8 and hasattr(self, "accelerator"):
|
verify_fp8_status(self.accelerator, training_args)
|
||||||
verify_fp8_status(self.accelerator, model_args)
|
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def create_optimizer(self) -> "torch.optim.Optimizer":
|
def create_optimizer(self) -> "torch.optim.Optimizer":
|
||||||
|
|||||||
@@ -27,18 +27,17 @@ from typing_extensions import override
|
|||||||
|
|
||||||
from ...extras import logging
|
from ...extras import logging
|
||||||
from ...extras.constants import IGNORE_INDEX
|
from ...extras.constants import IGNORE_INDEX
|
||||||
from ...extras.packages import is_transformers_version_greater_than
|
|
||||||
from ..callbacks import SaveProcessorCallback
|
from ..callbacks import SaveProcessorCallback
|
||||||
from ..fp8_utils import configure_fp8_environment, verify_fp8_status
|
from ..fp8_utils import configure_fp8_environment, patch_accelerator_for_fp8, verify_fp8_status
|
||||||
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
|
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from torch.utils.data import Dataset
|
from torch.utils.data import Dataset
|
||||||
from transformers import PreTrainedTokenizer, ProcessorMixin
|
from transformers import ProcessorMixin
|
||||||
from transformers.trainer import PredictionOutput
|
from transformers.trainer import PredictionOutput
|
||||||
|
|
||||||
from ...hparams import FinetuningArguments, ModelArguments
|
from ...hparams import FinetuningArguments, ModelArguments, TrainingArguments
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
@@ -55,13 +54,13 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
|
|||||||
gen_kwargs: Optional[dict[str, Any]] = None,
|
gen_kwargs: Optional[dict[str, Any]] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
kwargs["processing_class"] = kwargs.pop("tokenizer")
|
||||||
# Configure FP8 environment if enabled
|
# Configure FP8 environment if enabled
|
||||||
if model_args is not None and model_args.fp8:
|
training_args: TrainingArguments = kwargs.get("args")
|
||||||
configure_fp8_environment(model_args)
|
if training_args.fp8:
|
||||||
if is_transformers_version_greater_than("4.46"):
|
configure_fp8_environment(training_args)
|
||||||
kwargs["processing_class"] = kwargs.pop("tokenizer")
|
if getattr(training_args, "fp8_backend", "auto") == "te":
|
||||||
else:
|
patch_accelerator_for_fp8()
|
||||||
self.processing_class: PreTrainedTokenizer = kwargs.get("tokenizer")
|
|
||||||
|
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
if processor is not None:
|
if processor is not None:
|
||||||
@@ -88,9 +87,15 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
|
|||||||
|
|
||||||
self.compute_loss_func = dft_loss_func
|
self.compute_loss_func = dft_loss_func
|
||||||
|
|
||||||
# Verify FP8 status after trainer initialization (accelerator should be available)
|
elif finetuning_args.use_eaft_loss:
|
||||||
if model_args is not None and model_args.fp8 and hasattr(self, "accelerator"):
|
from ..trainer_utils import eaft_loss_func
|
||||||
verify_fp8_status(self.accelerator, model_args)
|
|
||||||
|
self.compute_loss_func = lambda outputs, labels, num_items_in_batch=None: eaft_loss_func(
|
||||||
|
outputs, labels, num_items_in_batch, finetuning_args.eaft_alpha
|
||||||
|
)
|
||||||
|
|
||||||
|
if training_args.fp8 and hasattr(self, "accelerator"): # verify FP8 status after trainer initialization
|
||||||
|
verify_fp8_status(self.accelerator, training_args)
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def create_optimizer(self) -> "torch.optim.Optimizer":
|
def create_optimizer(self) -> "torch.optim.Optimizer":
|
||||||
|
|||||||
@@ -634,7 +634,9 @@ def get_batch_logps(
|
|||||||
return logps, valid_length
|
return logps, valid_length
|
||||||
|
|
||||||
|
|
||||||
def dft_loss_func(outputs, labels, num_items_in_batch=None):
|
def dft_loss_func(
|
||||||
|
outputs: "torch.Tensor", labels: "torch.Tensor", num_items_in_batch: Optional["torch.Tensor"] = None
|
||||||
|
):
|
||||||
logits = outputs.get("logits")
|
logits = outputs.get("logits")
|
||||||
if logits is None:
|
if logits is None:
|
||||||
return outputs.get("loss", torch.tensor(0.0))
|
return outputs.get("loss", torch.tensor(0.0))
|
||||||
@@ -652,11 +654,11 @@ def dft_loss_func(outputs, labels, num_items_in_batch=None):
|
|||||||
|
|
||||||
|
|
||||||
def _dft_cross_entropy(
|
def _dft_cross_entropy(
|
||||||
source: torch.Tensor,
|
source: "torch.Tensor",
|
||||||
target: torch.Tensor,
|
target: "torch.Tensor",
|
||||||
num_items_in_batch: Optional[torch.Tensor] = None,
|
num_items_in_batch: Optional["torch.Tensor"] = None,
|
||||||
ignore_index: int = -100,
|
ignore_index: int = -100,
|
||||||
) -> torch.Tensor:
|
) -> "torch.Tensor":
|
||||||
per_token_loss = torch.nn.functional.cross_entropy(source, target, ignore_index=ignore_index, reduction="none")
|
per_token_loss = torch.nn.functional.cross_entropy(source, target, ignore_index=ignore_index, reduction="none")
|
||||||
valid_mask = target != ignore_index
|
valid_mask = target != ignore_index
|
||||||
if not valid_mask.any():
|
if not valid_mask.any():
|
||||||
@@ -679,6 +681,67 @@ def _dft_cross_entropy(
|
|||||||
return loss
|
return loss
|
||||||
|
|
||||||
|
|
||||||
|
def eaft_loss_func(
|
||||||
|
outputs: "torch.Tensor",
|
||||||
|
labels: "torch.Tensor",
|
||||||
|
num_items_in_batch: Optional["torch.Tensor"] = None,
|
||||||
|
alpha: float = 1.0,
|
||||||
|
) -> "torch.Tensor":
|
||||||
|
logits = outputs.get("logits")
|
||||||
|
if logits is None:
|
||||||
|
return outputs.get("loss", torch.tensor(0.0))
|
||||||
|
|
||||||
|
logits = logits.float()
|
||||||
|
vocab_size = logits.size(-1)
|
||||||
|
labels = torch.nn.functional.pad(labels, (0, 1), value=-100)
|
||||||
|
shift_labels = labels[..., 1:].contiguous()
|
||||||
|
logits = logits.view(-1, vocab_size)
|
||||||
|
shift_labels = shift_labels.view(-1)
|
||||||
|
shift_labels = shift_labels.to(logits.device)
|
||||||
|
|
||||||
|
loss = _eaft_cross_entropy(logits, shift_labels, num_items_in_batch, alpha)
|
||||||
|
return loss
|
||||||
|
|
||||||
|
|
||||||
|
def _eaft_cross_entropy(
|
||||||
|
source: "torch.Tensor",
|
||||||
|
target: "torch.Tensor",
|
||||||
|
num_items_in_batch: Optional["torch.Tensor"] = None,
|
||||||
|
alpha: float = 1.0,
|
||||||
|
ignore_index: int = -100,
|
||||||
|
) -> "torch.Tensor":
|
||||||
|
per_token_loss = torch.nn.functional.cross_entropy(source, target, ignore_index=ignore_index, reduction="none")
|
||||||
|
valid_mask = target != ignore_index
|
||||||
|
if not valid_mask.any():
|
||||||
|
return torch.tensor(0.0, device=source.device, dtype=source.dtype)
|
||||||
|
|
||||||
|
valid_losses = per_token_loss[valid_mask]
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
source_detached = source[valid_mask].detach()
|
||||||
|
|
||||||
|
topk_val, _ = torch.topk(source_detached, k=20, dim=-1)
|
||||||
|
logsumexp_topk = torch.logsumexp(topk_val, dim=-1, keepdim=True)
|
||||||
|
log_probs_topk = topk_val - logsumexp_topk
|
||||||
|
probs_topk = torch.exp(log_probs_topk)
|
||||||
|
entropy_approx = -(probs_topk * log_probs_topk).sum(dim=-1)
|
||||||
|
|
||||||
|
entropy_term = entropy_approx / 3.0
|
||||||
|
adaptive_weight = torch.pow(entropy_term, alpha)
|
||||||
|
|
||||||
|
weighted_losses = valid_losses * adaptive_weight
|
||||||
|
|
||||||
|
if num_items_in_batch is not None:
|
||||||
|
total_loss = weighted_losses.sum()
|
||||||
|
if torch.is_tensor(num_items_in_batch):
|
||||||
|
num_items_in_batch = num_items_in_batch.to(total_loss.device)
|
||||||
|
loss = total_loss / num_items_in_batch
|
||||||
|
else:
|
||||||
|
loss = weighted_losses.mean()
|
||||||
|
|
||||||
|
return loss
|
||||||
|
|
||||||
|
|
||||||
def nested_detach(
|
def nested_detach(
|
||||||
tensors: Union["torch.Tensor", list["torch.Tensor"], tuple["torch.Tensor"], dict[str, "torch.Tensor"]],
|
tensors: Union["torch.Tensor", list["torch.Tensor"], tuple["torch.Tensor"], dict[str, "torch.Tensor"]],
|
||||||
clone: bool = False,
|
clone: bool = False,
|
||||||
|
|||||||
@@ -119,9 +119,19 @@ def synchronize() -> None:
|
|||||||
|
|
||||||
|
|
||||||
@requires_accelerator
|
@requires_accelerator
|
||||||
def set_device() -> None:
|
def set_device_index() -> None:
|
||||||
"""Set current accelerator."""
|
"""Set current accelerator index to local rank."""
|
||||||
torch.accelerator.set_device_index(get_local_rank())
|
if get_current_accelerator().type != DeviceType.CPU:
|
||||||
|
torch.accelerator.set_device_index(get_local_rank())
|
||||||
|
|
||||||
|
|
||||||
|
@requires_accelerator
|
||||||
|
def get_current_device() -> torch.device:
|
||||||
|
"""Get current accelerator device."""
|
||||||
|
if get_current_accelerator().type == DeviceType.CPU:
|
||||||
|
return torch.device(DeviceType.CPU.value)
|
||||||
|
else:
|
||||||
|
return torch.device(type=get_current_accelerator().type, index=torch.accelerator.current_device_index())
|
||||||
|
|
||||||
|
|
||||||
def is_torch_cuda_available():
|
def is_torch_cuda_available():
|
||||||
|
|||||||
@@ -34,10 +34,14 @@ from typing import Any, Optional
|
|||||||
from torch.distributed import barrier, destroy_process_group, init_process_group
|
from torch.distributed import barrier, destroy_process_group, init_process_group
|
||||||
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
|
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
|
||||||
|
|
||||||
|
from ..utils import logging
|
||||||
from ..utils.types import DistributedConfig, ProcessGroup, Tensor, TensorLike
|
from ..utils.types import DistributedConfig, ProcessGroup, Tensor, TensorLike
|
||||||
from . import helper
|
from . import helper
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class Dim(str, Enum):
|
class Dim(str, Enum):
|
||||||
"""Dimension names."""
|
"""Dimension names."""
|
||||||
|
|
||||||
@@ -119,12 +123,13 @@ class DistributedInterface:
|
|||||||
if self._initialized:
|
if self._initialized:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
helper.set_device_index()
|
||||||
self._is_distributed = helper.is_distributed()
|
self._is_distributed = helper.is_distributed()
|
||||||
self._rank = helper.get_rank()
|
self._rank = helper.get_rank()
|
||||||
self._world_size = helper.get_world_size()
|
self._world_size = helper.get_world_size()
|
||||||
self._local_rank = helper.get_local_rank()
|
self._local_rank = helper.get_local_rank()
|
||||||
self._local_world_size = helper.get_local_world_size()
|
self._local_world_size = helper.get_local_world_size()
|
||||||
self.current_accelerator = helper.get_current_accelerator()
|
self.current_device = helper.get_current_device()
|
||||||
self.device_count = helper.get_device_count()
|
self.device_count = helper.get_device_count()
|
||||||
|
|
||||||
if config is None:
|
if config is None:
|
||||||
@@ -140,15 +145,14 @@ class DistributedInterface:
|
|||||||
timeout = config.get("timeout", 18000)
|
timeout = config.get("timeout", 18000)
|
||||||
|
|
||||||
if self._is_distributed:
|
if self._is_distributed:
|
||||||
helper.set_device()
|
|
||||||
init_process_group(timeout=timedelta(seconds=timeout))
|
init_process_group(timeout=timedelta(seconds=timeout))
|
||||||
self.model_device_mesh = init_device_mesh(
|
self.model_device_mesh = init_device_mesh(
|
||||||
device_type=self.current_accelerator.type,
|
device_type=self.current_device.type,
|
||||||
mesh_shape=self.strategy.model_mesh_shape,
|
mesh_shape=self.strategy.model_mesh_shape,
|
||||||
mesh_dim_names=self.strategy.model_mesh_dim_names,
|
mesh_dim_names=self.strategy.model_mesh_dim_names,
|
||||||
)
|
)
|
||||||
self.data_device_mesh = init_device_mesh(
|
self.data_device_mesh = init_device_mesh(
|
||||||
device_type=self.current_accelerator.type,
|
device_type=self.current_device.type,
|
||||||
mesh_shape=self.strategy.data_mesh_shape,
|
mesh_shape=self.strategy.data_mesh_shape,
|
||||||
mesh_dim_names=self.strategy.data_mesh_dim_names,
|
mesh_dim_names=self.strategy.data_mesh_dim_names,
|
||||||
)
|
)
|
||||||
@@ -157,11 +161,12 @@ class DistributedInterface:
|
|||||||
self.data_device_mesh = None
|
self.data_device_mesh = None
|
||||||
|
|
||||||
self._initialized = True
|
self._initialized = True
|
||||||
|
logger.info_rank0(f"DistributedInterface initialized: {self}.")
|
||||||
|
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
return (
|
return (
|
||||||
f"DistributedInterface(strategy={self.strategy}), is_distributed={self._is_distributed}, "
|
f"DistributedInterface(strategy={self.strategy}), is_distributed={self._is_distributed}, "
|
||||||
f"current_accelerator={self.current_accelerator}, rank={self._rank}, world_size={self._world_size}, "
|
f"current_device={self.current_device}, rank={self._rank}, world_size={self._world_size}, "
|
||||||
f"model_device_mesh={self.model_device_mesh}, data_device_mesh={self.data_device_mesh}"
|
f"model_device_mesh={self.model_device_mesh}, data_device_mesh={self.data_device_mesh}"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -246,4 +251,7 @@ class DistributedInterface:
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
print(DistributedInterface(DistributedStrategy()))
|
"""
|
||||||
|
python -m llamafactory.v1.accelerator.interface
|
||||||
|
"""
|
||||||
|
print(DistributedInterface())
|
||||||
|
|||||||
@@ -0,0 +1,32 @@
|
|||||||
|
# Copyright 2025 the LlamaFactory team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from .arg_parser import InputArgument, get_args
|
||||||
|
from .arg_utils import ModelClass, SampleBackend
|
||||||
|
from .data_args import DataArguments
|
||||||
|
from .model_args import ModelArguments
|
||||||
|
from .sample_args import SampleArguments
|
||||||
|
from .training_args import TrainingArguments
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"DataArguments",
|
||||||
|
"InputArgument",
|
||||||
|
"ModelArguments",
|
||||||
|
"ModelClass",
|
||||||
|
"SampleArguments",
|
||||||
|
"SampleBackend",
|
||||||
|
"TrainingArguments",
|
||||||
|
"get_args",
|
||||||
|
]
|
||||||
|
|||||||
@@ -17,7 +17,7 @@
|
|||||||
|
|
||||||
|
|
||||||
import json
|
import json
|
||||||
from enum import Enum, unique
|
from enum import StrEnum, unique
|
||||||
|
|
||||||
|
|
||||||
class PluginConfig(dict):
|
class PluginConfig(dict):
|
||||||
@@ -36,7 +36,7 @@ PluginArgument = PluginConfig | dict | str | None
|
|||||||
|
|
||||||
|
|
||||||
@unique
|
@unique
|
||||||
class ModelClass(str, Enum):
|
class ModelClass(StrEnum):
|
||||||
"""Auto class for model config."""
|
"""Auto class for model config."""
|
||||||
|
|
||||||
LLM = "llm"
|
LLM = "llm"
|
||||||
@@ -45,7 +45,7 @@ class ModelClass(str, Enum):
|
|||||||
|
|
||||||
|
|
||||||
@unique
|
@unique
|
||||||
class SampleBackend(str, Enum):
|
class SampleBackend(StrEnum):
|
||||||
HF = "hf"
|
HF = "hf"
|
||||||
VLLM = "vllm"
|
VLLM = "vllm"
|
||||||
|
|
||||||
|
|||||||
@@ -21,20 +21,25 @@ from .arg_utils import ModelClass, PluginConfig, get_plugin_config
|
|||||||
@dataclass
|
@dataclass
|
||||||
class ModelArguments:
|
class ModelArguments:
|
||||||
model: str = field(
|
model: str = field(
|
||||||
|
default="Qwen/Qwen3-4B-Instruct-2507",
|
||||||
metadata={"help": "Path to the model or model identifier from Hugging Face."},
|
metadata={"help": "Path to the model or model identifier from Hugging Face."},
|
||||||
)
|
)
|
||||||
|
template: str = field(
|
||||||
|
default="qwen3_nothink",
|
||||||
|
metadata={"help": "Template for the model."},
|
||||||
|
)
|
||||||
trust_remote_code: bool = field(
|
trust_remote_code: bool = field(
|
||||||
default=False,
|
default=False,
|
||||||
metadata={"help": "Trust remote code from Hugging Face."},
|
metadata={"help": "Trust remote code from Hugging Face."},
|
||||||
)
|
)
|
||||||
use_fast_processor: bool = field(
|
|
||||||
default=True,
|
|
||||||
metadata={"help": "Use fast processor from Hugging Face."},
|
|
||||||
)
|
|
||||||
model_class: ModelClass = field(
|
model_class: ModelClass = field(
|
||||||
default=ModelClass.LLM,
|
default=ModelClass.LLM,
|
||||||
metadata={"help": "Model class from Hugging Face."},
|
metadata={"help": "Model class from Hugging Face."},
|
||||||
)
|
)
|
||||||
|
init_config: PluginConfig | None = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "Initialization configuration for the model."},
|
||||||
|
)
|
||||||
peft_config: PluginConfig | None = field(
|
peft_config: PluginConfig | None = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={"help": "PEFT configuration for the model."},
|
metadata={"help": "PEFT configuration for the model."},
|
||||||
@@ -49,6 +54,7 @@ class ModelArguments:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def __post_init__(self) -> None:
|
def __post_init__(self) -> None:
|
||||||
|
self.init_config = get_plugin_config(self.init_config)
|
||||||
self.peft_config = get_plugin_config(self.peft_config)
|
self.peft_config = get_plugin_config(self.peft_config)
|
||||||
self.kernel_config = get_plugin_config(self.kernel_config)
|
self.kernel_config = get_plugin_config(self.kernel_config)
|
||||||
self.quant_config = get_plugin_config(self.quant_config)
|
self.quant_config = get_plugin_config(self.quant_config)
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ from .arg_utils import PluginConfig, get_plugin_config
|
|||||||
@dataclass
|
@dataclass
|
||||||
class TrainingArguments:
|
class TrainingArguments:
|
||||||
output_dir: str = field(
|
output_dir: str = field(
|
||||||
default=os.path.join("outputs", str(uuid4())),
|
default=os.path.join("outputs", str(uuid4().hex)),
|
||||||
metadata={"help": "Path to the output directory."},
|
metadata={"help": "Path to the output directory."},
|
||||||
)
|
)
|
||||||
micro_batch_size: int = field(
|
micro_batch_size: int = field(
|
||||||
|
|||||||
181
src/llamafactory/v1/core/base_sampler.py
Normal file
181
src/llamafactory/v1/core/base_sampler.py
Normal file
@@ -0,0 +1,181 @@
|
|||||||
|
# Copyright 2025 the LlamaFactory team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from collections.abc import AsyncGenerator, Generator
|
||||||
|
from threading import Thread
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from transformers import TextIteratorStreamer
|
||||||
|
|
||||||
|
from ..accelerator.interface import DistributedInterface
|
||||||
|
from ..config import ModelArguments, SampleArguments, SampleBackend
|
||||||
|
from ..utils.helper import get_tokenizer
|
||||||
|
from ..utils.types import HFModel, Message, Sample, TorchDataset
|
||||||
|
from .utils.rendering import Renderer
|
||||||
|
|
||||||
|
|
||||||
|
class BaseEngine(ABC):
|
||||||
|
@abstractmethod
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
args: SampleArguments,
|
||||||
|
model_args: ModelArguments,
|
||||||
|
model: HFModel,
|
||||||
|
renderer: Renderer,
|
||||||
|
) -> None:
|
||||||
|
"""Initialize the engine.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
args: Sample arguments.
|
||||||
|
model_args: Model arguments.
|
||||||
|
model: Model.
|
||||||
|
renderer: Renderer.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def generate(self, messages: list[Message], tools: str | None = None) -> AsyncGenerator[str, None]:
|
||||||
|
"""Generate tokens asynchronously.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
messages: List of messages.
|
||||||
|
tools: Tools string.
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
Generated tokens.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def batch_infer(self, dataset: TorchDataset) -> list[Sample]:
|
||||||
|
"""Batch infer samples.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dataset: Torch dataset.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of samples.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
class HuggingFaceEngine(BaseEngine):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
args: SampleArguments,
|
||||||
|
model_args: ModelArguments,
|
||||||
|
model: HFModel,
|
||||||
|
renderer: Renderer,
|
||||||
|
) -> None:
|
||||||
|
self.args = args
|
||||||
|
self.model_args = model_args
|
||||||
|
self.model = model
|
||||||
|
self.renderer = renderer
|
||||||
|
self.semaphore = asyncio.Semaphore(int(os.getenv("MAX_CONCURRENT", "1")))
|
||||||
|
|
||||||
|
@torch.inference_mode()
|
||||||
|
def get_response(self, messages: list[Message], tools: str | None = None) -> Generator[str, None, None]:
|
||||||
|
model_inputs = self.renderer.render_messages(messages, tools, is_generate=True)
|
||||||
|
streamer = TextIteratorStreamer(
|
||||||
|
tokenizer=get_tokenizer(self.renderer.processor),
|
||||||
|
skip_prompt=True,
|
||||||
|
skip_special_tokens=True, # TODO: configurable
|
||||||
|
)
|
||||||
|
device = DistributedInterface().current_device
|
||||||
|
kwargs = {
|
||||||
|
"input_ids": torch.tensor([model_inputs["input_ids"]]).to(device),
|
||||||
|
"attention_mask": torch.tensor([model_inputs["attention_mask"]]).to(device),
|
||||||
|
"max_new_tokens": self.args.max_new_tokens,
|
||||||
|
"streamer": streamer,
|
||||||
|
}
|
||||||
|
thread = Thread(target=self.model.generate, kwargs=kwargs, daemon=True)
|
||||||
|
thread.start()
|
||||||
|
|
||||||
|
def stream():
|
||||||
|
try:
|
||||||
|
return streamer.__next__()
|
||||||
|
except StopIteration:
|
||||||
|
raise StopAsyncIteration()
|
||||||
|
|
||||||
|
return stream
|
||||||
|
|
||||||
|
async def generate(self, messages: list[Message], tools: str | None = None) -> AsyncGenerator[str, None]:
|
||||||
|
async with self.semaphore:
|
||||||
|
response = self.get_response(messages, tools)
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
yield await asyncio.to_thread(response)
|
||||||
|
except StopAsyncIteration:
|
||||||
|
break
|
||||||
|
|
||||||
|
async def batch_infer(self, dataset: TorchDataset) -> list[Sample]:
|
||||||
|
"""Batch infer samples.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dataset: Torch dataset.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of samples.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError("Batch infer is not implemented.")
|
||||||
|
|
||||||
|
|
||||||
|
class BaseSampler:
|
||||||
|
"""Base sampler.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
args: Sample arguments.
|
||||||
|
model_args: Model arguments.
|
||||||
|
model: Model.
|
||||||
|
renderer: Renderer.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
args: SampleArguments,
|
||||||
|
model_args: ModelArguments,
|
||||||
|
model: HFModel,
|
||||||
|
renderer: Renderer,
|
||||||
|
) -> None:
|
||||||
|
if args.sample_backend == SampleBackend.HF:
|
||||||
|
self.engine = HuggingFaceEngine(args, model_args, model, renderer)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown sample backend: {args.sample_backend}")
|
||||||
|
|
||||||
|
async def generate(self, messages: list[Message], tools: str | None = None) -> AsyncGenerator[str, None]:
|
||||||
|
"""Generate tokens asynchronously.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
messages: List of messages.
|
||||||
|
tools: Tools string.
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
Generated tokens.
|
||||||
|
"""
|
||||||
|
async for token in self.engine.generate(messages, tools):
|
||||||
|
yield token
|
||||||
|
|
||||||
|
async def batch_infer(self, dataset: TorchDataset) -> list[Sample]:
|
||||||
|
"""Batch infer samples.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dataset: Torch dataset.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of samples.
|
||||||
|
"""
|
||||||
|
return await self.engine.batch_infer(dataset)
|
||||||
@@ -1,44 +0,0 @@
|
|||||||
# Copyright 2025 the LlamaFactory team.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
|
|
||||||
from ..config.sample_args import SampleArguments, SampleBackend
|
|
||||||
from .model_loader import ModelLoader
|
|
||||||
|
|
||||||
|
|
||||||
class BaseEngine(ABC):
|
|
||||||
@abstractmethod
|
|
||||||
def __init__(self, sample_args: SampleArguments, model_loader: ModelLoader) -> None: ...
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
async def generate(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
async def batch_infer(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class HuggingFaceEngine(BaseEngine):
|
|
||||||
def __init__(self, model_loader: ModelLoader, sample_args: SampleArguments) -> None:
|
|
||||||
self.args = sample_args
|
|
||||||
|
|
||||||
|
|
||||||
class ChatSampler:
|
|
||||||
def __init__(self, model_loader: ModelLoader, sample_args: SampleArguments) -> None:
|
|
||||||
if sample_args.sample_backend == SampleBackend.HF:
|
|
||||||
self.engine = HuggingFaceEngine(model_loader, sample_args)
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unknown sample backend: {sample_args.sample_backend}")
|
|
||||||
@@ -14,15 +14,23 @@
|
|||||||
|
|
||||||
"""The definition of data engine.
|
"""The definition of data engine.
|
||||||
|
|
||||||
Init Data engine:
|
How to use:
|
||||||
|
data_engine = DataEngine(data_args)
|
||||||
|
data_engine[i]: Get the sample via index.
|
||||||
|
|
||||||
|
Init workflow:
|
||||||
1. Parse dataset info from arguments.
|
1. Parse dataset info from arguments.
|
||||||
2. Load datasets according to dataset info.
|
2. Load datasets according to dataset info.
|
||||||
3. Build data index (and reweight samples if necessary).
|
3. Build data index (and reweight samples if necessary).
|
||||||
|
|
||||||
Get Data Sample:
|
Get data sample:
|
||||||
1. Get sample from data index.
|
1. Get sample from data index.
|
||||||
2. Convert sample to standard format.
|
2. Convert sample to standard format.
|
||||||
3. Return sample.
|
3. Return sample.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
1. The data engine is equivalent to the torch dataset.
|
||||||
|
2. The data engine is agnostic to the model used.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
@@ -98,10 +106,10 @@ class DataEngine(Dataset):
|
|||||||
|
|
||||||
size = self.dataset_infos[dataset_name].get("size")
|
size = self.dataset_infos[dataset_name].get("size")
|
||||||
weight = self.dataset_infos[dataset_name].get("weight")
|
weight = self.dataset_infos[dataset_name].get("weight")
|
||||||
if size or weight: # data index plugin
|
if size or weight:
|
||||||
from ..plugins.data_plugins.loader import DataIndexPlugin
|
from ..plugins.data_plugins.loader import adjust_data_index
|
||||||
|
|
||||||
data_index = DataIndexPlugin().adjust_data_index(data_index, size, weight)
|
data_index = adjust_data_index(data_index, size, weight)
|
||||||
|
|
||||||
self.data_index.extend(data_index)
|
self.data_index.extend(data_index)
|
||||||
|
|
||||||
@@ -150,9 +158,9 @@ class DataEngine(Dataset):
|
|||||||
dataset_name, sample_index = self.data_index[index]
|
dataset_name, sample_index = self.data_index[index]
|
||||||
return self._convert_data_sample(self.datasets[dataset_name][sample_index], dataset_name)
|
return self._convert_data_sample(self.datasets[dataset_name][sample_index], dataset_name)
|
||||||
else: # data selector plugin
|
else: # data selector plugin
|
||||||
from ..plugins.data_plugins.loader import DataSelectorPlugin
|
from ..plugins.data_plugins.loader import select_data_sample
|
||||||
|
|
||||||
selected_index = DataSelectorPlugin().select(self.data_index, index)
|
selected_index = select_data_sample(self.data_index, index)
|
||||||
if isinstance(selected_index, list):
|
if isinstance(selected_index, list):
|
||||||
return [
|
return [
|
||||||
self._convert_data_sample(self.datasets[dataset_name][sample_index], dataset_name)
|
self._convert_data_sample(self.datasets[dataset_name][sample_index], dataset_name)
|
||||||
|
|||||||
@@ -12,34 +12,44 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
"""The definition of model loader.
|
"""The definition of model engine.
|
||||||
|
|
||||||
Init Phase:
|
How to use:
|
||||||
|
model_engine = ModelEngine(model_args, is_train=True)
|
||||||
|
model_engine.processor: Get the tokenizer or multi-modal processor.
|
||||||
|
model_engine.renderer: Get the renderer.
|
||||||
|
model_engine.model_config: Get the model configuration.
|
||||||
|
model_engine.model: Get the HF model.
|
||||||
|
|
||||||
|
Init workflow:
|
||||||
1. Init processor.
|
1. Init processor.
|
||||||
|
2. Init render.
|
||||||
2. Init model config.
|
2. Init model config.
|
||||||
3. Init model.
|
3. Init model.
|
||||||
4. Init adapter.
|
4. Init adapter.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from accelerate import init_empty_weights
|
||||||
from transformers import AutoConfig, AutoProcessor
|
from transformers import AutoConfig, AutoProcessor
|
||||||
|
|
||||||
|
from ..accelerator.helper import DeviceType
|
||||||
from ..accelerator.interface import DistributedInterface
|
from ..accelerator.interface import DistributedInterface
|
||||||
from ..config.model_args import ModelArguments, ModelClass
|
from ..config.model_args import ModelArguments, ModelClass
|
||||||
from ..utils import logging
|
from ..utils import logging
|
||||||
from ..utils.types import HFConfig, HFModel, Processor
|
from ..utils.types import HFConfig, HFModel, Processor
|
||||||
|
from .utils.rendering import Renderer
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class ModelLoader:
|
class ModelEngine:
|
||||||
"""Model loader.
|
"""Model engine.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model_args: Model arguments.
|
model_args: Model arguments.
|
||||||
is_trainable: Whether to train the model.
|
is_train: Whether to train the model.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, model_args: ModelArguments, is_train: bool = False) -> None:
|
def __init__(self, model_args: ModelArguments, is_train: bool = False) -> None:
|
||||||
@@ -49,17 +59,22 @@ class ModelLoader:
|
|||||||
"""Whether to train the model."""
|
"""Whether to train the model."""
|
||||||
self.processor = self._init_processor()
|
self.processor = self._init_processor()
|
||||||
"""Tokenizer or multi-modal processor."""
|
"""Tokenizer or multi-modal processor."""
|
||||||
|
self.renderer = Renderer(self.args.template, self.processor)
|
||||||
|
"""Renderer."""
|
||||||
self.model_config = self._init_model_config()
|
self.model_config = self._init_model_config()
|
||||||
"""Model configuration."""
|
"""Model configuration."""
|
||||||
self.model = self._init_model()
|
self.model = self._init_model()
|
||||||
"""HF model."""
|
"""HF model."""
|
||||||
|
|
||||||
def _init_processor(self) -> Processor:
|
def _init_processor(self) -> Processor:
|
||||||
"""Init processor."""
|
"""Init processor.
|
||||||
|
|
||||||
|
NOTE: Transformers v5 always use fast tokenizer.
|
||||||
|
https://github.com/huggingface/transformers/blob/v5.0.0rc1/src/transformers/models/auto/tokenization_auto.py#L642
|
||||||
|
"""
|
||||||
return AutoProcessor.from_pretrained(
|
return AutoProcessor.from_pretrained(
|
||||||
self.args.model,
|
self.args.model,
|
||||||
trust_remote_code=self.args.trust_remote_code,
|
trust_remote_code=self.args.trust_remote_code,
|
||||||
use_fast=self.args.use_fast_processor,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def _init_model_config(self) -> HFConfig:
|
def _init_model_config(self) -> HFConfig:
|
||||||
@@ -92,14 +107,24 @@ class ModelLoader:
|
|||||||
|
|
||||||
AutoClass = AutoModel
|
AutoClass = AutoModel
|
||||||
|
|
||||||
# map the entire model to the current accelerator
|
if self.args.init_config is not None:
|
||||||
model = AutoClass.from_pretrained(
|
from ..plugins.model_plugins.initialization import InitPlugin
|
||||||
self.args.model,
|
|
||||||
config=self.model_config,
|
init_device = InitPlugin(self.args.init_config.name)()
|
||||||
dtype="auto",
|
else:
|
||||||
device_map=DistributedInterface().current_accelerator,
|
init_device = DistributedInterface().current_device
|
||||||
trust_remote_code=self.args.trust_remote_code,
|
|
||||||
)
|
if init_device.type == DeviceType.META:
|
||||||
|
with init_empty_weights():
|
||||||
|
model = AutoClass.from_config(self.model_config)
|
||||||
|
else:
|
||||||
|
model = AutoClass.from_pretrained(
|
||||||
|
self.args.model,
|
||||||
|
config=self.model_config,
|
||||||
|
dtype="auto",
|
||||||
|
device_map=init_device,
|
||||||
|
trust_remote_code=self.args.trust_remote_code,
|
||||||
|
)
|
||||||
|
|
||||||
if self.args.peft_config is None:
|
if self.args.peft_config is None:
|
||||||
if self.is_train:
|
if self.is_train:
|
||||||
@@ -124,12 +149,12 @@ class ModelLoader:
|
|||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
"""
|
"""
|
||||||
python -m llamafactory.v1.core.model_loader --model llamafactory/tiny-random-qwen2.5
|
python -m llamafactory.v1.core.model_engine --model llamafactory/tiny-random-qwen2.5
|
||||||
"""
|
"""
|
||||||
from ..config.arg_parser import get_args
|
from ..config.arg_parser import get_args
|
||||||
|
|
||||||
_, model_args, *_ = get_args()
|
_, model_args, *_ = get_args()
|
||||||
model_loader = ModelLoader(model_args=model_args)
|
model_engine = ModelEngine(model_args=model_args)
|
||||||
print(model_loader.processor)
|
print(model_engine.processor)
|
||||||
print(model_loader.model_config)
|
print(model_engine.model_config)
|
||||||
print(model_loader.model)
|
print(model_engine.model)
|
||||||
99
src/llamafactory/v1/core/utils/rendering.py
Normal file
99
src/llamafactory/v1/core/utils/rendering.py
Normal file
@@ -0,0 +1,99 @@
|
|||||||
|
# Copyright 2025 the LlamaFactory team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
from ...utils.constants import IGNORE_INDEX
|
||||||
|
from ...utils.helper import get_tokenizer
|
||||||
|
from ...utils.types import Message, ModelInput, Processor
|
||||||
|
|
||||||
|
|
||||||
|
def render_chatml_messages(
|
||||||
|
processor: Processor,
|
||||||
|
messages: list[Message],
|
||||||
|
tools: str | None = None,
|
||||||
|
is_generate: bool = False,
|
||||||
|
) -> ModelInput:
|
||||||
|
"""Apply chatml template to messages and convert them to model input.
|
||||||
|
|
||||||
|
See https://huggingface.co/spaces/huggingfacejs/chat-template-playground?modelId=Qwen/Qwen2-7B-Instruct
|
||||||
|
"""
|
||||||
|
tokenizer = get_tokenizer(processor)
|
||||||
|
input_ids, labels, loss_weights = [], [], []
|
||||||
|
|
||||||
|
for message in messages:
|
||||||
|
temp_str = "<|im_start|>" + message["role"] + "\n"
|
||||||
|
for content in message["content"]:
|
||||||
|
if content["type"] == "text":
|
||||||
|
temp_str += content["value"]
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported content type: {content['type']}")
|
||||||
|
|
||||||
|
temp_str += "<|im_end|>\n"
|
||||||
|
temp_weight = message.get("loss_weight", 1.0 if message["role"] == "assistant" else 0.0)
|
||||||
|
temp_ids = tokenizer.encode(temp_str, add_special_tokens=False)
|
||||||
|
input_ids.extend(temp_ids)
|
||||||
|
loss_weights.extend([temp_weight] * len(temp_ids))
|
||||||
|
if temp_weight > 1e-6:
|
||||||
|
labels.extend(temp_ids)
|
||||||
|
else:
|
||||||
|
labels.extend([IGNORE_INDEX] * len(temp_ids))
|
||||||
|
|
||||||
|
if is_generate:
|
||||||
|
temp_ids = tokenizer.encode("<|im_start|>assistant\n", add_special_tokens=False)
|
||||||
|
input_ids.extend(temp_ids)
|
||||||
|
loss_weights.extend([0.0] * len(temp_ids))
|
||||||
|
labels.extend([IGNORE_INDEX] * len(temp_ids))
|
||||||
|
|
||||||
|
return ModelInput(
|
||||||
|
input_ids=input_ids,
|
||||||
|
attention_mask=[1] * len(input_ids),
|
||||||
|
labels=labels,
|
||||||
|
loss_weights=loss_weights,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def parse_chatml_message(generated_text: str) -> Message:
|
||||||
|
"""Parse a message in ChatML format. Supports interleaved reasoning and tool calls.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
generated_text (str): The generated text in ChatML format.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Message: The parsed message.
|
||||||
|
"""
|
||||||
|
return Message(role="assistant", content=[{"type": "text", "value": generated_text}])
|
||||||
|
|
||||||
|
|
||||||
|
class Renderer:
|
||||||
|
def __init__(self, template: str, processor: Processor):
|
||||||
|
self.template = template
|
||||||
|
self.processor = processor
|
||||||
|
|
||||||
|
def render_messages(
|
||||||
|
self, messages: list[Message], tools: str | None = None, is_generate: bool = False
|
||||||
|
) -> ModelInput:
|
||||||
|
if self.template == "chatml":
|
||||||
|
return render_chatml_messages(self.processor, messages, tools, is_generate)
|
||||||
|
else:
|
||||||
|
from ...plugins.model_plugins.rendering import RenderingPlugin
|
||||||
|
|
||||||
|
return RenderingPlugin(self.template).render_messages(self.processor, messages, tools, is_generate)
|
||||||
|
|
||||||
|
def parse_message(self, generated_text: str) -> Message:
|
||||||
|
if self.template == "chatml":
|
||||||
|
return parse_chatml_message(generated_text)
|
||||||
|
else:
|
||||||
|
from ...plugins.model_plugins.rendering import RenderingPlugin
|
||||||
|
|
||||||
|
return RenderingPlugin(self.template).parse_message(generated_text)
|
||||||
@@ -49,6 +49,11 @@ def launch():
|
|||||||
|
|
||||||
run_sft()
|
run_sft()
|
||||||
|
|
||||||
|
elif command == "chat":
|
||||||
|
from .samplers.cli_sampler import run_chat
|
||||||
|
|
||||||
|
run_chat()
|
||||||
|
|
||||||
elif command == "env":
|
elif command == "env":
|
||||||
print_env()
|
print_env()
|
||||||
|
|
||||||
|
|||||||
@@ -13,11 +13,12 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
import json
|
||||||
from typing import Any, Literal, NotRequired, TypedDict
|
from typing import Any, Literal, NotRequired, TypedDict
|
||||||
|
|
||||||
from ...utils import logging
|
from ...utils import logging
|
||||||
from ...utils.plugin import BasePlugin
|
from ...utils.plugin import BasePlugin
|
||||||
from ...utils.types import DPOSample, Sample, SFTSample
|
from ...utils.types import DPOSample, Sample, SFTSample, ToolCall
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
@@ -61,7 +62,7 @@ class DataConverterPlugin(BasePlugin):
|
|||||||
return super().__call__(raw_sample)
|
return super().__call__(raw_sample)
|
||||||
|
|
||||||
|
|
||||||
@DataConverterPlugin("alpaca").register
|
@DataConverterPlugin("alpaca").register()
|
||||||
def alpaca_converter(raw_sample: AlpacaSample) -> SFTSample:
|
def alpaca_converter(raw_sample: AlpacaSample) -> SFTSample:
|
||||||
"""Convert Alpaca sample to SFT sample.
|
"""Convert Alpaca sample to SFT sample.
|
||||||
|
|
||||||
@@ -98,7 +99,7 @@ def alpaca_converter(raw_sample: AlpacaSample) -> SFTSample:
|
|||||||
return {"messages": messages}
|
return {"messages": messages}
|
||||||
|
|
||||||
|
|
||||||
@DataConverterPlugin("sharegpt").register
|
@DataConverterPlugin("sharegpt").register()
|
||||||
def sharegpt_converter(raw_sample: SharegptSample) -> SFTSample:
|
def sharegpt_converter(raw_sample: SharegptSample) -> SFTSample:
|
||||||
"""Convert ShareGPT sample to SFT sample.
|
"""Convert ShareGPT sample to SFT sample.
|
||||||
|
|
||||||
@@ -118,17 +119,32 @@ def sharegpt_converter(raw_sample: SharegptSample) -> SFTSample:
|
|||||||
"function_call": "assistant",
|
"function_call": "assistant",
|
||||||
}
|
}
|
||||||
messages = []
|
messages = []
|
||||||
tools = raw_sample.get("tools", "")
|
tools = raw_sample.get("tools")
|
||||||
|
if tools:
|
||||||
|
try:
|
||||||
|
tools: list[dict[str, Any]] = json.loads(tools)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
logger.warning_rank0(f"Invalid tools format: {str(tools)}")
|
||||||
|
tools = []
|
||||||
|
|
||||||
for message in raw_sample.get("conversations", []):
|
for message in raw_sample.get("conversations", []):
|
||||||
tag = message["from"]
|
tag = message["from"]
|
||||||
if tag not in tag_mapping:
|
if tag not in tag_mapping:
|
||||||
logger.warning_rank0(f"Unsupported role tag {tag} in message: {message}")
|
logger.warning_rank0(f"Unsupported role tag {tag} in message: {message}")
|
||||||
elif tag == "function_call":
|
elif tag == "function_call":
|
||||||
|
try:
|
||||||
|
tool_calls: ToolCall | list[ToolCall] = json.loads(message["value"])
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
logger.warning_rank0(f"Invalid tool call format: {str(message['value'])}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not isinstance(tool_calls, list):
|
||||||
|
tool_calls = [tool_calls]
|
||||||
|
|
||||||
messages.append(
|
messages.append(
|
||||||
{
|
{
|
||||||
"role": "assistant",
|
"role": "assistant",
|
||||||
"content": [{"type": "tool_calls", "value": message["value"]}],
|
"content": [{"type": "tool_call", "value": json.dumps(tool_call)} for tool_call in tool_calls],
|
||||||
"loss_weight": 1.0,
|
"loss_weight": 1.0,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
@@ -142,15 +158,12 @@ def sharegpt_converter(raw_sample: SharegptSample) -> SFTSample:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if tools:
|
if tools:
|
||||||
if messages and messages[0]["role"] == "system":
|
return {"messages": messages, "tools": json.dumps(tools)}
|
||||||
messages[0]["content"].append({"type": "tools", "value": tools})
|
else:
|
||||||
else:
|
return {"messages": messages}
|
||||||
messages.insert(0, {"role": "system", "content": [{"type": "tools", "value": tools}], "loss_weight": 0.0})
|
|
||||||
|
|
||||||
return {"messages": messages}
|
|
||||||
|
|
||||||
|
|
||||||
@DataConverterPlugin("pair").register
|
@DataConverterPlugin("pair").register()
|
||||||
def pair_converter(raw_sample: PairSample) -> DPOSample:
|
def pair_converter(raw_sample: PairSample) -> DPOSample:
|
||||||
"""Convert Pair sample to DPO sample.
|
"""Convert Pair sample to DPO sample.
|
||||||
|
|
||||||
|
|||||||
@@ -49,7 +49,7 @@ def _get_builder_name(path: str) -> Literal["arrow", "csv", "json", "parquet", "
|
|||||||
raise ValueError(f"Unknown dataset filetype: {filetype}.")
|
raise ValueError(f"Unknown dataset filetype: {filetype}.")
|
||||||
|
|
||||||
|
|
||||||
@DataLoaderPlugin("local").register
|
@DataLoaderPlugin("local").register()
|
||||||
def load_data_from_file(filepath: str, split: str, streaming: bool) -> HFDataset:
|
def load_data_from_file(filepath: str, split: str, streaming: bool) -> HFDataset:
|
||||||
if os.path.isdir(filepath):
|
if os.path.isdir(filepath):
|
||||||
filetype = _get_builder_name(os.listdir(filepath)[0])
|
filetype = _get_builder_name(os.listdir(filepath)[0])
|
||||||
@@ -66,49 +66,43 @@ def load_data_from_file(filepath: str, split: str, streaming: bool) -> HFDataset
|
|||||||
return dataset
|
return dataset
|
||||||
|
|
||||||
|
|
||||||
class DataIndexPlugin(BasePlugin):
|
def adjust_data_index(
|
||||||
"""Plugin for adjusting dataset index."""
|
data_index: list[tuple[str, int]], size: int | None, weight: float | None
|
||||||
|
) -> list[tuple[str, int]]:
|
||||||
|
"""Adjust dataset index by size and weight.
|
||||||
|
|
||||||
def adjust_data_index(
|
Args:
|
||||||
self, data_index: list[tuple[str, int]], size: int | None, weight: float | None
|
data_index (list[tuple[str, int]]): List of (dataset_name, sample_index).
|
||||||
) -> list[tuple[str, int]]:
|
size (Optional[int]): Desired dataset size.
|
||||||
"""Adjust dataset index by size and weight.
|
weight (Optional[float]): Desired dataset weight.
|
||||||
|
|
||||||
Args:
|
Returns:
|
||||||
data_index (list[tuple[str, int]]): List of (dataset_name, sample_index).
|
list[tuple[str, int]]: Adjusted dataset index.
|
||||||
size (Optional[int]): Desired dataset size.
|
"""
|
||||||
weight (Optional[float]): Desired dataset weight.
|
if size is not None:
|
||||||
|
data_index = random.choices(data_index, k=size)
|
||||||
|
|
||||||
Returns:
|
if weight is not None:
|
||||||
list[tuple[str, int]]: Adjusted dataset index.
|
data_index = random.choices(data_index, k=int(len(data_index) * weight))
|
||||||
"""
|
|
||||||
if size is not None:
|
|
||||||
data_index = random.choices(data_index, k=size)
|
|
||||||
|
|
||||||
if weight is not None:
|
return data_index
|
||||||
data_index = random.choices(data_index, k=int(len(data_index) * weight))
|
|
||||||
|
|
||||||
return data_index
|
|
||||||
|
|
||||||
|
|
||||||
class DataSelectorPlugin(BasePlugin):
|
def select_data_sample(
|
||||||
"""Plugin for selecting dataset samples."""
|
data_index: list[tuple[str, int]], index: slice | list[int] | Any
|
||||||
|
) -> tuple[str, int] | list[tuple[str, int]]:
|
||||||
|
"""Select dataset samples.
|
||||||
|
|
||||||
def select(
|
Args:
|
||||||
self, data_index: list[tuple[str, int]], index: slice | list[int] | Any
|
data_index (list[tuple[str, int]]): List of (dataset_name, sample_index).
|
||||||
) -> tuple[str, int] | list[tuple[str, int]]:
|
index (Union[slice, list[int], Any]): Index of dataset samples.
|
||||||
"""Select dataset samples.
|
|
||||||
|
|
||||||
Args:
|
Returns:
|
||||||
data_index (list[tuple[str, int]]): List of (dataset_name, sample_index).
|
Union[tuple[str, int], list[tuple[str, int]]]: Selected dataset samples.
|
||||||
index (Union[slice, list[int], Any]): Index of dataset samples.
|
"""
|
||||||
|
if isinstance(index, slice):
|
||||||
Returns:
|
return [data_index[i] for i in range(*index.indices(len(data_index)))]
|
||||||
Union[tuple[str, int], list[tuple[str, int]]]: Selected dataset samples.
|
elif isinstance(index, list):
|
||||||
"""
|
return [data_index[i] for i in index]
|
||||||
if isinstance(index, slice):
|
else:
|
||||||
return [data_index[i] for i in range(*index.indices(len(data_index)))]
|
raise ValueError(f"Invalid index type {type(index)}.")
|
||||||
elif isinstance(index, list):
|
|
||||||
return [data_index[i] for i in index]
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Invalid index type {type(index)}.")
|
|
||||||
|
|||||||
@@ -1,133 +0,0 @@
|
|||||||
# Copyright 2025 the LlamaFactory team.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
|
|
||||||
from dataclasses import dataclass
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class Template:
|
|
||||||
user_template: str
|
|
||||||
assistant_template: str
|
|
||||||
system_template: str
|
|
||||||
|
|
||||||
def render_message(self, message: dict[str, str]) -> str:
|
|
||||||
return self.user_template.format(**message)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class QwenTemplate:
|
|
||||||
message_template: str = "<|im_start|>{role}\n{content}<|im_end|>\n" # FIXME if role: tool
|
|
||||||
thinking_template: str = "<think>\n{content}\n</think>\n\n"
|
|
||||||
|
|
||||||
def _extract_content(self, content_data: str | list[dict[str, str]]) -> str:
|
|
||||||
if isinstance(content_data, str):
|
|
||||||
return content_data.strip()
|
|
||||||
|
|
||||||
if isinstance(content_data, list):
|
|
||||||
parts = []
|
|
||||||
for item in content_data:
|
|
||||||
if item.get("type") == "text":
|
|
||||||
parts.append(item.get("value", ""))
|
|
||||||
elif item.get("type") == "image_url":
|
|
||||||
pass
|
|
||||||
return "\n".join(parts).strip()
|
|
||||||
|
|
||||||
return ""
|
|
||||||
|
|
||||||
def render_message(self, message: dict[str, str | list[dict[str, str]]]) -> str:
|
|
||||||
role = message["role"]
|
|
||||||
content = self._extract_content(message.get("content", ""))
|
|
||||||
|
|
||||||
if role == "assistant":
|
|
||||||
reasoning_content = message.get("reasoning_content", "")
|
|
||||||
if reasoning_content:
|
|
||||||
reasoning_content = self.thinking_template.format(content=str(reasoning_content).strip())
|
|
||||||
return self.message_template.format(role="assistant", content=reasoning_content + content)
|
|
||||||
else:
|
|
||||||
return self.message_template.format(role=role, content=content)
|
|
||||||
|
|
||||||
def encode_messages(self, tokenizer, messages: list[dict[str, str]], max_seq_len: int = 8192) -> any:
|
|
||||||
"""Encode one message."""
|
|
||||||
input_ids, attention_mask, labels = [], [], []
|
|
||||||
for message in messages:
|
|
||||||
content_str = self.render_message(message)
|
|
||||||
content_ids = tokenizer.encode(content_str, add_special_tokens=False)
|
|
||||||
input_ids += content_ids
|
|
||||||
attention_mask += [1] * len(content_ids)
|
|
||||||
|
|
||||||
if hasattr(message, "loss_weight"):
|
|
||||||
loss_weight = message["loss_weight"]
|
|
||||||
else:
|
|
||||||
loss_weight = 1 if message["role"] == "assistant" else 0
|
|
||||||
if loss_weight == 1:
|
|
||||||
labels += content_ids
|
|
||||||
else:
|
|
||||||
labels += [-100] * len(content_ids)
|
|
||||||
model_inputs = {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}
|
|
||||||
model_inputs.update({"position_ids": list(range(len(input_ids)))})
|
|
||||||
model_inputs = {k: v[-max_seq_len:] for k, v in model_inputs.items()}
|
|
||||||
return model_inputs
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
|
|
||||||
def to_qwen3_messages(template: QwenTemplate, messages: list[dict]):
|
|
||||||
out = []
|
|
||||||
for m in messages:
|
|
||||||
role = m["role"]
|
|
||||||
content = template._extract_content(m.get("content", ""))
|
|
||||||
if role == "assistant":
|
|
||||||
reasoning = (m.get("reasoning_content") or "").strip()
|
|
||||||
if reasoning:
|
|
||||||
content = template.thinking_template.format(content=reasoning) + content
|
|
||||||
out.append({"role": role, "content": content})
|
|
||||||
return out
|
|
||||||
|
|
||||||
from transformers import AutoTokenizer
|
|
||||||
|
|
||||||
tok = AutoTokenizer.from_pretrained(
|
|
||||||
"Qwen/Qwen3-30B-A3B-Thinking-2507",
|
|
||||||
trust_remote_code=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
test_messages = [
|
|
||||||
{"role": "system", "content": "You are a helpful assistant."},
|
|
||||||
{
|
|
||||||
"role": "user",
|
|
||||||
"content": [{"type": "text", "text": "1+1等于几?"}, {"type": "text", "text": "2+2等于几?"}],
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"role": "assistant",
|
|
||||||
"reasoning_content": "这是一个简单的数学问题。1加1的结果是2。",
|
|
||||||
"content": [{"type": "text", "text": "1+1=2"}, {"type": "text", "text": "2+2=4"}],
|
|
||||||
},
|
|
||||||
]
|
|
||||||
|
|
||||||
template = QwenTemplate()
|
|
||||||
rendered_custom = "".join([template.render_message(m) for m in test_messages])
|
|
||||||
|
|
||||||
qwen3_messages = to_qwen3_messages(template, test_messages)
|
|
||||||
rendered_hf = tok.apply_chat_template(qwen3_messages, tokenize=False, add_generation_prompt=False)
|
|
||||||
|
|
||||||
print("==== custom ====")
|
|
||||||
print(rendered_custom)
|
|
||||||
print("==== hf ====")
|
|
||||||
print(rendered_hf)
|
|
||||||
|
|
||||||
assert rendered_custom.strip() == rendered_hf.strip(), "Rendered text mismatch"
|
|
||||||
|
|
||||||
ids_custom = tok.encode(rendered_custom, add_special_tokens=False)
|
|
||||||
ids_hf = tok.apply_chat_template(qwen3_messages, tokenize=True, add_generation_prompt=False)
|
|
||||||
assert ids_custom == ids_hf, f"Token ids mismatch: custom={len(ids_custom)} hf={len(ids_hf)}"
|
|
||||||
@@ -0,0 +1,43 @@
|
|||||||
|
# Copyright 2025 the LlamaFactory team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from ...accelerator.helper import DeviceType
|
||||||
|
from ...accelerator.interface import DistributedInterface
|
||||||
|
from ...utils.plugin import BasePlugin
|
||||||
|
|
||||||
|
|
||||||
|
class InitPlugin(BasePlugin):
|
||||||
|
def __call__(self) -> torch.device:
|
||||||
|
return super().__call__()
|
||||||
|
|
||||||
|
|
||||||
|
@InitPlugin("init_on_meta").register()
|
||||||
|
def init_on_meta() -> torch.device:
|
||||||
|
return torch.device(DeviceType.META.value)
|
||||||
|
|
||||||
|
|
||||||
|
@InitPlugin("init_on_rank0").register()
|
||||||
|
def init_on_rank0() -> torch.device:
|
||||||
|
if DistributedInterface().get_rank() == 0:
|
||||||
|
return torch.device(DeviceType.CPU.value)
|
||||||
|
else:
|
||||||
|
return torch.device(DeviceType.META.value)
|
||||||
|
|
||||||
|
|
||||||
|
@InitPlugin("init_on_default").register()
|
||||||
|
def init_on_default() -> torch.device:
|
||||||
|
return DistributedInterface().current_device
|
||||||
|
|||||||
@@ -38,17 +38,17 @@ class BaseKernel(ABC):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_kernel_id(cls) -> str:
|
def get_kernel_id(cls) -> str:
|
||||||
r"""Returns the unique identifier for the kernel."""
|
"""Returns the unique identifier for the kernel."""
|
||||||
return cls._kernel_id
|
return cls._kernel_id
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_device(cls) -> str:
|
def get_device(cls) -> str:
|
||||||
r"""Returns the device type associated with the kernel (e.g., "cuda", "npu", "cpu")."""
|
"""Returns the device type associated with the kernel (e.g., "cuda", "npu", "cpu")."""
|
||||||
return cls._device
|
return cls._device
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_deps(cls) -> bool:
|
def check_deps(cls) -> bool:
|
||||||
r"""Checks if the required dependencies for the kernel are available.
|
"""Checks if the required dependencies for the kernel are available.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
bool: ``True`` if dependencies are met, ``False`` otherwise.
|
bool: ``True`` if dependencies are met, ``False`` otherwise.
|
||||||
@@ -65,7 +65,7 @@ class BaseKernel(ABC):
|
|||||||
@classmethod
|
@classmethod
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def apply(cls, **kwargs) -> HFModel:
|
def apply(cls, **kwargs) -> HFModel:
|
||||||
r"""Applies the kernel optimization to the model.
|
"""Applies the kernel optimization to the model.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
**kwargs: Arbitrary keyword arguments, usually containing the model instance and the kernel configuration.
|
**kwargs: Arbitrary keyword arguments, usually containing the model instance and the kernel configuration.
|
||||||
|
|||||||
@@ -33,7 +33,7 @@ logger = get_logger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
def scan_all_kernels():
|
def scan_all_kernels():
|
||||||
r"""Scan all kernels in the ``ops`` directory.
|
"""Scan all kernels in the ``ops`` directory.
|
||||||
|
|
||||||
Scans the ``ops`` directory for all ``.py`` files and attempts to import them.
|
Scans the ``ops`` directory for all ``.py`` files and attempts to import them.
|
||||||
Importing triggers the :func:`~registry.register_kernel` decorator, which automatically registers the kernels.
|
Importing triggers the :func:`~registry.register_kernel` decorator, which automatically registers the kernels.
|
||||||
@@ -77,7 +77,7 @@ default_kernels = scan_all_kernels()
|
|||||||
|
|
||||||
|
|
||||||
def get_default_kernels():
|
def get_default_kernels():
|
||||||
r"""Get a list of default registered kernel IDs.
|
"""Get a list of default registered kernel IDs.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
list[str]: List of kernel IDs.
|
list[str]: List of kernel IDs.
|
||||||
@@ -86,7 +86,7 @@ def get_default_kernels():
|
|||||||
|
|
||||||
|
|
||||||
def apply_kernel(kernel_id: str, **kwargs):
|
def apply_kernel(kernel_id: str, **kwargs):
|
||||||
r"""Applies a specific kernel to the model.
|
"""Applies a specific kernel to the model.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
kernel_id (str): The ID of the kernel to apply.
|
kernel_id (str): The ID of the kernel to apply.
|
||||||
@@ -99,18 +99,19 @@ def apply_kernel(kernel_id: str, **kwargs):
|
|||||||
kernel = default_kernels.get(kernel_id)
|
kernel = default_kernels.get(kernel_id)
|
||||||
if kernel is None:
|
if kernel is None:
|
||||||
raise ValueError(f"Kernel {kernel_id} not found")
|
raise ValueError(f"Kernel {kernel_id} not found")
|
||||||
|
|
||||||
kernel.apply(**kwargs)
|
kernel.apply(**kwargs)
|
||||||
|
|
||||||
|
|
||||||
class KernelPlugin(BasePlugin):
|
class KernelPlugin(BasePlugin):
|
||||||
r"""Plugin for managing kernel optimizations."""
|
"""Plugin for managing kernel optimizations."""
|
||||||
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@KernelPlugin("auto").register
|
@KernelPlugin("auto").register()
|
||||||
def apply_default_kernels(**kwargs):
|
def apply_default_kernels(**kwargs):
|
||||||
r"""Applies all default registered kernels to the model.
|
"""Applies all default registered kernels to the model.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
**kwargs: Keyword arguments passed to the kernel application function.
|
**kwargs: Keyword arguments passed to the kernel application function.
|
||||||
@@ -125,8 +126,11 @@ def apply_default_kernels(**kwargs):
|
|||||||
use_kernels = default_kernels.keys()
|
use_kernels = default_kernels.keys()
|
||||||
else:
|
else:
|
||||||
use_kernels = kwargs.get("include_kernels").split(",") # "kernel_id1,kernel_id2,kernel_id3"
|
use_kernels = kwargs.get("include_kernels").split(",") # "kernel_id1,kernel_id2,kernel_id3"
|
||||||
|
|
||||||
for kernel in use_kernels:
|
for kernel in use_kernels:
|
||||||
if kernel not in default_kernels:
|
if kernel not in default_kernels:
|
||||||
raise ValueError(f"Kernel {kernel} not found")
|
raise ValueError(f"Kernel {kernel} not found")
|
||||||
|
|
||||||
apply_kernel(kernel, **kwargs)
|
apply_kernel(kernel, **kwargs)
|
||||||
|
|
||||||
return kwargs.get("model")
|
return kwargs.get("model")
|
||||||
|
|||||||
@@ -40,11 +40,11 @@ from ...registry import register_kernel
|
|||||||
|
|
||||||
|
|
||||||
class GmmFunction(torch.autograd.Function):
|
class GmmFunction(torch.autograd.Function):
|
||||||
r"""Custom autograd function for NPU Grouped Matrix Multiplication (GMM)."""
|
"""Custom autograd function for NPU Grouped Matrix Multiplication (GMM)."""
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def forward(ctx, x, weight, group_list):
|
def forward(ctx, x, weight, group_list):
|
||||||
r"""Performs the forward pass of Grouped Matrix Multiplication.
|
"""Performs the forward pass of Grouped Matrix Multiplication.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
ctx: Context object to save tensors for backward pass.
|
ctx: Context object to save tensors for backward pass.
|
||||||
@@ -65,7 +65,7 @@ class GmmFunction(torch.autograd.Function):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def backward(ctx, grad_output):
|
def backward(ctx, grad_output):
|
||||||
r"""Performs the backward pass of Grouped Matrix Multiplication.
|
"""Performs the backward pass of Grouped Matrix Multiplication.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
ctx: Context object containing saved tensors.
|
ctx: Context object containing saved tensors.
|
||||||
@@ -94,11 +94,11 @@ class GmmFunction(torch.autograd.Function):
|
|||||||
|
|
||||||
|
|
||||||
class HybridGmmFunction(torch.autograd.Function):
|
class HybridGmmFunction(torch.autograd.Function):
|
||||||
r"""Custom autograd function for Hybrid Grouped Matrix Multiplication on NPU."""
|
"""Custom autograd function for Hybrid Grouped Matrix Multiplication on NPU."""
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def forward(ctx, num_experts, *args):
|
def forward(ctx, num_experts, *args):
|
||||||
r"""Performs the forward pass of Hybrid GMM.
|
"""Performs the forward pass of Hybrid GMM.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
ctx: Context object to save tensors.
|
ctx: Context object to save tensors.
|
||||||
@@ -124,7 +124,7 @@ class HybridGmmFunction(torch.autograd.Function):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def backward(ctx, *grad_outputs):
|
def backward(ctx, *grad_outputs):
|
||||||
r"""Performs the backward pass of Hybrid GMM.
|
"""Performs the backward pass of Hybrid GMM.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
ctx: Context object containing saved tensors.
|
ctx: Context object containing saved tensors.
|
||||||
@@ -176,13 +176,13 @@ class HybridGmmFunction(torch.autograd.Function):
|
|||||||
|
|
||||||
|
|
||||||
class NpuMoeFused:
|
class NpuMoeFused:
|
||||||
r"""Container for NPU fused MoE forward functions."""
|
"""Container for NPU fused MoE forward functions."""
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def npu_moe_experts_forward(
|
def npu_moe_experts_forward(
|
||||||
self, hidden_states: torch.Tensor, routing_weights: torch.Tensor, router_indices: torch.Tensor
|
self, hidden_states: torch.Tensor, routing_weights: torch.Tensor, router_indices: torch.Tensor
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
r"""Forward pass for MoE experts using NPU fused operations.
|
"""Forward pass for MoE experts using NPU fused operations.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
self: The MoE layer instance.
|
self: The MoE layer instance.
|
||||||
@@ -230,11 +230,11 @@ class NpuMoeFused:
|
|||||||
|
|
||||||
|
|
||||||
class Qwen3NpuMoeFused:
|
class Qwen3NpuMoeFused:
|
||||||
r"""Container for Qwen3 NPU fused MoE forward functions."""
|
"""Container for Qwen3 NPU fused MoE forward functions."""
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def qwen3moe_sparse_moe_block_forward(self, hidden_states: torch.Tensor):
|
def qwen3moe_sparse_moe_block_forward(self, hidden_states: torch.Tensor):
|
||||||
r"""Forward pass for Qwen3 sparse MoE block using NPU fused operations.
|
"""Forward pass for Qwen3 sparse MoE block using NPU fused operations.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
self: The Qwen3 MoE block instance.
|
self: The Qwen3 MoE block instance.
|
||||||
@@ -298,14 +298,14 @@ if not is_transformers_version_greater_than("5.0.0"):
|
|||||||
|
|
||||||
@register_kernel
|
@register_kernel
|
||||||
class NpuFusedMoEKernel(BaseKernel):
|
class NpuFusedMoEKernel(BaseKernel):
|
||||||
r"""NPU Fused MoE Kernel implementation."""
|
"""NPU Fused MoE Kernel implementation."""
|
||||||
|
|
||||||
_kernel_id = "npu_fused_moe"
|
_kernel_id = "npu_fused_moe"
|
||||||
_device = DeviceType.NPU
|
_device = DeviceType.NPU
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def apply(cls, **kwargs) -> HFModel:
|
def apply(cls, **kwargs) -> HFModel:
|
||||||
r"""Applies the NPU fused MoE kernel to the model.
|
"""Applies the NPU fused MoE kernel to the model.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
**kwargs: Keyword arguments containing the model.
|
**kwargs: Keyword arguments containing the model.
|
||||||
@@ -333,6 +333,7 @@ class NpuFusedMoEKernel(BaseKernel):
|
|||||||
|
|
||||||
if target_moe_mapping is None:
|
if target_moe_mapping is None:
|
||||||
return model
|
return model
|
||||||
|
|
||||||
for module in model.modules():
|
for module in model.modules():
|
||||||
class_name = module.__class__.__name__
|
class_name = module.__class__.__name__
|
||||||
if class_name in target_moe_mapping:
|
if class_name in target_moe_mapping:
|
||||||
|
|||||||
@@ -38,7 +38,7 @@ except ImportError:
|
|||||||
|
|
||||||
|
|
||||||
def npu_swiglu_forward(self, hidden_state):
|
def npu_swiglu_forward(self, hidden_state):
|
||||||
r"""SwiGLU forward pass for NPU.
|
"""SwiGLU forward pass for NPU.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
self: The MLP layer instance.
|
self: The MLP layer instance.
|
||||||
@@ -53,7 +53,7 @@ def npu_swiglu_forward(self, hidden_state):
|
|||||||
|
|
||||||
|
|
||||||
def _npu_swiglu_glm4_forward(self, hidden_states):
|
def _npu_swiglu_glm4_forward(self, hidden_states):
|
||||||
r"""SwiGLU forward pass for GLM4 on NPU.
|
"""SwiGLU forward pass for GLM4 on NPU.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
self: The GLM4 MLP layer instance.
|
self: The GLM4 MLP layer instance.
|
||||||
@@ -68,7 +68,7 @@ def _npu_swiglu_glm4_forward(self, hidden_states):
|
|||||||
|
|
||||||
|
|
||||||
def _npu_swiglu_gemma3ntext_forward(self, hidden_states):
|
def _npu_swiglu_gemma3ntext_forward(self, hidden_states):
|
||||||
r"""SwiGLU forward pass for Gemma3nText on NPU.
|
"""SwiGLU forward pass for Gemma3nText on NPU.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
self: The Gemma3nText MLP layer instance.
|
self: The Gemma3nText MLP layer instance.
|
||||||
@@ -88,7 +88,7 @@ def _npu_swiglu_gemma3ntext_forward(self, hidden_states):
|
|||||||
|
|
||||||
@register_kernel
|
@register_kernel
|
||||||
class NpuSwiGluKernel(BaseKernel):
|
class NpuSwiGluKernel(BaseKernel):
|
||||||
r"""NPU Kernel for fused SwiGLU activation."""
|
"""NPU Kernel for fused SwiGLU activation."""
|
||||||
|
|
||||||
# just support apply to the following module layers
|
# just support apply to the following module layers
|
||||||
expect_modules = frozenset(
|
expect_modules = frozenset(
|
||||||
@@ -126,7 +126,7 @@ class NpuSwiGluKernel(BaseKernel):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def apply(cls, **kwargs) -> "HFModel":
|
def apply(cls, **kwargs) -> "HFModel":
|
||||||
r"""Applies the NPU fused SwiGLU kernel to the model.
|
"""Applies the NPU fused SwiGLU kernel to the model.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
**kwargs: Keyword arguments containing the model.
|
**kwargs: Keyword arguments containing the model.
|
||||||
|
|||||||
@@ -30,7 +30,7 @@ from ...registry import register_kernel
|
|||||||
|
|
||||||
|
|
||||||
def npu_rms_norm_forward(self, hidden_states):
|
def npu_rms_norm_forward(self, hidden_states):
|
||||||
r"""NPU forward implementation for RMSNorm.
|
"""NPU forward implementation for RMSNorm.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
self: RMSNorm module instance with `weight` and `variance_epsilon`.
|
self: RMSNorm module instance with `weight` and `variance_epsilon`.
|
||||||
@@ -46,14 +46,14 @@ def npu_rms_norm_forward(self, hidden_states):
|
|||||||
|
|
||||||
@register_kernel
|
@register_kernel
|
||||||
class NpuRMSNormKernel(BaseKernel):
|
class NpuRMSNormKernel(BaseKernel):
|
||||||
r"""NPU kernel wrapper for RMSNorm that applies the replacement within a model."""
|
"""NPU kernel wrapper for RMSNorm that applies the replacement within a model."""
|
||||||
|
|
||||||
_kernel_id = "npu_fused_rmsnorm"
|
_kernel_id = "npu_fused_rmsnorm"
|
||||||
_device = DeviceType.NPU
|
_device = DeviceType.NPU
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def apply(cls, **kwargs) -> "HFModel":
|
def apply(cls, **kwargs) -> "HFModel":
|
||||||
r"""Iterate the model and apply NPU-optimized forward to matched RMSNorm modules.
|
"""Iterate the model and apply NPU-optimized forward to matched RMSNorm modules.
|
||||||
|
|
||||||
Key points:
|
Key points:
|
||||||
- Match modules whose class name contains "RMSNorm" (case-insensitive).
|
- Match modules whose class name contains "RMSNorm" (case-insensitive).
|
||||||
@@ -78,6 +78,7 @@ class NpuRMSNormKernel(BaseKernel):
|
|||||||
|
|
||||||
if not cls.check_deps():
|
if not cls.check_deps():
|
||||||
raise RuntimeError(f"torch_npu is not available but {cls.__name__} was called.")
|
raise RuntimeError(f"torch_npu is not available but {cls.__name__} was called.")
|
||||||
|
|
||||||
rms_norm_pattern = re.compile("RMSNorm", re.IGNORECASE)
|
rms_norm_pattern = re.compile("RMSNorm", re.IGNORECASE)
|
||||||
|
|
||||||
for name, module in model.named_modules():
|
for name, module in model.named_modules():
|
||||||
|
|||||||
@@ -40,7 +40,7 @@ except ImportError:
|
|||||||
|
|
||||||
|
|
||||||
def _apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
|
def _apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
|
||||||
r"""Applies Rotary Position Embedding to the query and key tensors using NPU optimization.
|
"""Applies Rotary Position Embedding to the query and key tensors using NPU optimization.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
q (Tensor): Query tensor.
|
q (Tensor): Query tensor.
|
||||||
@@ -61,7 +61,7 @@ def _apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
|
|||||||
|
|
||||||
|
|
||||||
def _apply_multimodal_rotary_pos_emb_qwen25_vl(q, k, cos, sin, mrope_section, unsqueeze_dim=1):
|
def _apply_multimodal_rotary_pos_emb_qwen25_vl(q, k, cos, sin, mrope_section, unsqueeze_dim=1):
|
||||||
r"""Applies Rotary Position Embedding with multimodal sections (Qwen2-VL) on NPU.
|
"""Applies Rotary Position Embedding with multimodal sections (Qwen2-VL) on NPU.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
q (Tensor): Query tensor.
|
q (Tensor): Query tensor.
|
||||||
@@ -89,14 +89,14 @@ def _apply_multimodal_rotary_pos_emb_qwen25_vl(q, k, cos, sin, mrope_section, un
|
|||||||
|
|
||||||
@register_kernel
|
@register_kernel
|
||||||
class NpuRoPEKernel(BaseKernel):
|
class NpuRoPEKernel(BaseKernel):
|
||||||
r"""NPU Kernel for Rotary Position Embedding."""
|
"""NPU Kernel for Rotary Position Embedding."""
|
||||||
|
|
||||||
_kernel_id = "npu_fused_rope"
|
_kernel_id = "npu_fused_rope"
|
||||||
_device = DeviceType.NPU
|
_device = DeviceType.NPU
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def apply(cls, **kwargs) -> "HFModel":
|
def apply(cls, **kwargs) -> "HFModel":
|
||||||
r"""Apply RoPE acceleration by monkey-patching `apply_rotary_pos_emb`.
|
"""Apply RoPE acceleration by monkey-patching `apply_rotary_pos_emb`.
|
||||||
|
|
||||||
This function iterates through the model's modules to find attention layers,
|
This function iterates through the model's modules to find attention layers,
|
||||||
identifies the module where they are defined, and replaces the original
|
identifies the module where they are defined, and replaces the original
|
||||||
@@ -115,9 +115,11 @@ class NpuRoPEKernel(BaseKernel):
|
|||||||
"""
|
"""
|
||||||
if not cls.check_deps():
|
if not cls.check_deps():
|
||||||
raise RuntimeError(f"torch_npu is not available but {cls.__name__} was called.")
|
raise RuntimeError(f"torch_npu is not available but {cls.__name__} was called.")
|
||||||
|
|
||||||
model = kwargs.get("model", None)
|
model = kwargs.get("model", None)
|
||||||
if model is None:
|
if model is None:
|
||||||
raise ValueError(f"HFModel instance is required for {cls.__name__}.")
|
raise ValueError(f"HFModel instance is required for {cls.__name__}.")
|
||||||
|
|
||||||
_modules = set()
|
_modules = set()
|
||||||
for module in model.modules():
|
for module in model.modules():
|
||||||
if "Attention" in module.__class__.__name__:
|
if "Attention" in module.__class__.__name__:
|
||||||
@@ -143,4 +145,5 @@ class NpuRoPEKernel(BaseKernel):
|
|||||||
_modules.add(module_name)
|
_modules.add(module_name)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning_rank0_once(f"Failed to apply RoPE kernel to module {module_name}: {e}")
|
logger.warning_rank0_once(f"Failed to apply RoPE kernel to module {module_name}: {e}")
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|||||||
@@ -30,7 +30,7 @@ __all__ = ["Registry", "register_kernel"]
|
|||||||
|
|
||||||
|
|
||||||
class Registry:
|
class Registry:
|
||||||
r"""Registry for managing kernel implementations.
|
"""Registry for managing kernel implementations.
|
||||||
|
|
||||||
Storage structure: ``{ "kernel_id": Class }``
|
Storage structure: ``{ "kernel_id": Class }``
|
||||||
"""
|
"""
|
||||||
@@ -38,8 +38,8 @@ class Registry:
|
|||||||
_kernels: dict[str, type[BaseKernel]] = {}
|
_kernels: dict[str, type[BaseKernel]] = {}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def register(cls, kernel_cls: type[BaseKernel]):
|
def register(cls, kernel_cls: type[BaseKernel]) -> type[BaseKernel] | None:
|
||||||
r"""Decorator to register a kernel class.
|
"""Decorator to register a kernel class.
|
||||||
|
|
||||||
The class must inherit from :class:`BaseKernel` and specify ``_kernel_id`` and ``_device`` attributes.
|
The class must inherit from :class:`BaseKernel` and specify ``_kernel_id`` and ``_device`` attributes.
|
||||||
|
|
||||||
@@ -47,7 +47,7 @@ class Registry:
|
|||||||
kernel_cls (type[BaseKernel]): The kernel class to register.
|
kernel_cls (type[BaseKernel]): The kernel class to register.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
type[BaseKernel]: The registered kernel class.
|
type[BaseKernel] | None: The registered kernel class if the device type matches the current accelerator
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
TypeError: If the class does not inherit from :class:`BaseKernel`.
|
TypeError: If the class does not inherit from :class:`BaseKernel`.
|
||||||
@@ -55,6 +55,7 @@ class Registry:
|
|||||||
"""
|
"""
|
||||||
if not issubclass(kernel_cls, BaseKernel):
|
if not issubclass(kernel_cls, BaseKernel):
|
||||||
raise TypeError(f"Class {kernel_cls} must inherit from BaseKernel")
|
raise TypeError(f"Class {kernel_cls} must inherit from BaseKernel")
|
||||||
|
|
||||||
kernel_id = kernel_cls.get_kernel_id()
|
kernel_id = kernel_cls.get_kernel_id()
|
||||||
device = kernel_cls.get_device()
|
device = kernel_cls.get_device()
|
||||||
|
|
||||||
@@ -73,7 +74,7 @@ class Registry:
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get(cls, kernel_id: str) -> Optional[type[BaseKernel]]:
|
def get(cls, kernel_id: str) -> Optional[type[BaseKernel]]:
|
||||||
r"""Retrieves a registered kernel implementation by its ID.
|
"""Retrieves a registered kernel implementation by its ID.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
kernel_id (str): The ID of the kernel to retrieve.
|
kernel_id (str): The ID of the kernel to retrieve.
|
||||||
@@ -85,7 +86,7 @@ class Registry:
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_registered_kernels(cls) -> dict[str, type[BaseKernel]]:
|
def get_registered_kernels(cls) -> dict[str, type[BaseKernel]]:
|
||||||
r"""Returns a dictionary of all registered kernels.
|
"""Returns a dictionary of all registered kernels.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
dict[str, type[BaseKernel]]: Dictionary mapping kernel IDs to kernel classes.
|
dict[str, type[BaseKernel]]: Dictionary mapping kernel IDs to kernel classes.
|
||||||
|
|||||||
@@ -45,13 +45,13 @@ class PeftPlugin(BasePlugin):
|
|||||||
return super().__call__(model, config)
|
return super().__call__(model, config)
|
||||||
|
|
||||||
|
|
||||||
@PeftPlugin("lora").register
|
@PeftPlugin("lora").register()
|
||||||
def get_lora_model(model: HFModel, config: LoraConfigDict, is_train: bool) -> PeftModel:
|
def get_lora_model(model: HFModel, config: LoraConfigDict, is_train: bool) -> PeftModel:
|
||||||
peft_config = LoraConfig(**config)
|
peft_config = LoraConfig(**config)
|
||||||
model = get_peft_model(model, peft_config)
|
model = get_peft_model(model, peft_config)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
@PeftPlugin("freeze").register
|
@PeftPlugin("freeze").register()
|
||||||
def get_freeze_model(model: HFModel, config: FreezeConfigDict, is_train: bool) -> HFModel:
|
def get_freeze_model(model: HFModel, config: FreezeConfigDict, is_train: bool) -> HFModel:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|||||||
212
src/llamafactory/v1/plugins/model_plugins/rendering.py
Normal file
212
src/llamafactory/v1/plugins/model_plugins/rendering.py
Normal file
@@ -0,0 +1,212 @@
|
|||||||
|
# Copyright 2025 the LlamaFactory team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
import json
|
||||||
|
import re
|
||||||
|
|
||||||
|
from ...utils.constants import IGNORE_INDEX
|
||||||
|
from ...utils.helper import get_tokenizer
|
||||||
|
from ...utils.plugin import BasePlugin
|
||||||
|
from ...utils.types import Message, ModelInput, Processor, ToolCall
|
||||||
|
|
||||||
|
|
||||||
|
class RenderingPlugin(BasePlugin):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def _update_model_input(
|
||||||
|
processor: Processor,
|
||||||
|
input_ids: list[int],
|
||||||
|
labels: list[int],
|
||||||
|
loss_weights: list[int],
|
||||||
|
temp_str: str,
|
||||||
|
temp_weight: float,
|
||||||
|
) -> str:
|
||||||
|
"""Update model input with temporary string."""
|
||||||
|
if not temp_str:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
tokenizer = get_tokenizer(processor)
|
||||||
|
temp_ids = tokenizer.encode(temp_str, add_special_tokens=False)
|
||||||
|
input_ids.extend(temp_ids)
|
||||||
|
loss_weights.extend([temp_weight] * len(temp_ids))
|
||||||
|
if temp_weight > 1e-6:
|
||||||
|
labels.extend(temp_ids)
|
||||||
|
else:
|
||||||
|
labels.extend([IGNORE_INDEX] * len(temp_ids))
|
||||||
|
|
||||||
|
return ""
|
||||||
|
|
||||||
|
|
||||||
|
@RenderingPlugin("qwen3_nothink").register("render_messages")
|
||||||
|
def render_qwen_messages(
|
||||||
|
processor: Processor,
|
||||||
|
messages: list[Message],
|
||||||
|
tools: str | None = None,
|
||||||
|
is_generate: bool = False,
|
||||||
|
) -> ModelInput:
|
||||||
|
input_ids, labels, loss_weights = [], [], []
|
||||||
|
temp_str, temp_weight = "", 0.0
|
||||||
|
if tools:
|
||||||
|
temp_str += "<|im_start|>system\n"
|
||||||
|
if messages[0]["role"] == "system":
|
||||||
|
for content in messages[0]["content"]:
|
||||||
|
if content["type"] == "text":
|
||||||
|
temp_str += content["value"]
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported content type: {content['type']}")
|
||||||
|
|
||||||
|
temp_str += "\n\n"
|
||||||
|
temp_weight = messages[0].get("loss_weight", 0.0)
|
||||||
|
|
||||||
|
temp_str += (
|
||||||
|
"# Tools\n\nYou may call one or more functions to assist with the user query.\n\n"
|
||||||
|
"You are provided with function signatures within <tools></tools> XML tags:\n<tools>"
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
tools = json.loads(tools)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
raise ValueError(f"Invalid tools format: {str(tools)}.")
|
||||||
|
|
||||||
|
if not isinstance(tools, list):
|
||||||
|
tools = [tools]
|
||||||
|
|
||||||
|
for tool in tools:
|
||||||
|
temp_str += "\n" + json.dumps(tool, ensure_ascii=False)
|
||||||
|
|
||||||
|
temp_str += (
|
||||||
|
"\n</tools>\n\nFor each function call, return a json object with function name "
|
||||||
|
'and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{"name": '
|
||||||
|
'<function-name>, "arguments": <args-json-object>}\n</tool_call><|im_end|>\n'
|
||||||
|
)
|
||||||
|
elif messages[0]["role"] == "system":
|
||||||
|
temp_str += "<|im_start|>system\n"
|
||||||
|
for content in messages[0]["content"]:
|
||||||
|
if content["type"] == "text":
|
||||||
|
temp_str += content["value"]
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported content type: {content['type']}")
|
||||||
|
|
||||||
|
temp_str += "<|im_end|>\n"
|
||||||
|
temp_weight = messages[0].get("loss_weight", 0.0)
|
||||||
|
|
||||||
|
temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)
|
||||||
|
|
||||||
|
for turn_idx, message in enumerate(messages):
|
||||||
|
if message["role"] == "user" or (message["role"] == "system" and turn_idx != 0):
|
||||||
|
temp_str += "<|im_start|>" + message["role"] + "\n"
|
||||||
|
for content in message["content"]:
|
||||||
|
if content["type"] == "text":
|
||||||
|
temp_str += content["value"]
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported content type: {content['type']}")
|
||||||
|
|
||||||
|
temp_str += "<|im_end|>\n"
|
||||||
|
temp_weight = message.get("loss_weight", 0.0)
|
||||||
|
elif message["role"] == "assistant":
|
||||||
|
temp_str += "<|im_start|>" + message["role"] + "\n"
|
||||||
|
for val_idx, content in enumerate(message["content"]):
|
||||||
|
if content["type"] == "text":
|
||||||
|
temp_str += content["value"]
|
||||||
|
elif content["type"] == "reasoning":
|
||||||
|
temp_str += "<thinking>\n" + content["value"] + "\n</thinking>\n\n" # avoid using special tokens
|
||||||
|
elif content["type"] == "tool_call":
|
||||||
|
if val_idx != 0 and message["content"][val_idx - 1]["type"] in ["text", "tool_call"]:
|
||||||
|
temp_str += "\n"
|
||||||
|
|
||||||
|
try:
|
||||||
|
tool_call: ToolCall = json.loads(content["value"])
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
raise ValueError(f"Invalid tool call format: {content['value']}.")
|
||||||
|
|
||||||
|
temp_str += (
|
||||||
|
'<tool_call>\n{"name": "'
|
||||||
|
+ tool_call["name"]
|
||||||
|
+ '", "arguments": '
|
||||||
|
+ json.dumps(tool_call["arguments"], ensure_ascii=False)
|
||||||
|
+ "}\n</tool_call>"
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported content type: {content['type']}")
|
||||||
|
|
||||||
|
temp_str += "<|im_end|>\n"
|
||||||
|
temp_weight = message.get("loss_weight", 1.0)
|
||||||
|
elif message["role"] == "tool":
|
||||||
|
if turn_idx == 0 or messages[turn_idx - 1]["role"] != "tool":
|
||||||
|
temp_str += "<|im_start|>user"
|
||||||
|
|
||||||
|
temp_str += "\n<tool_response>\n"
|
||||||
|
for content in message["content"]:
|
||||||
|
if content["type"] == "text":
|
||||||
|
temp_str += content["value"]
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported content type: {content['type']}")
|
||||||
|
|
||||||
|
temp_str += "\n</tool_response>"
|
||||||
|
if turn_idx == len(messages) - 1 or messages[turn_idx + 1]["role"] != "tool":
|
||||||
|
temp_str += "<|im_end|>\n"
|
||||||
|
|
||||||
|
temp_weight = message.get("loss_weight", 0.0)
|
||||||
|
|
||||||
|
temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)
|
||||||
|
|
||||||
|
if is_generate:
|
||||||
|
temp_str += "<|im_start|>assistant\n"
|
||||||
|
temp_weight = 0.0
|
||||||
|
|
||||||
|
temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)
|
||||||
|
|
||||||
|
attention_mask = [1] * len(input_ids)
|
||||||
|
return ModelInput(
|
||||||
|
input_ids=input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
labels=labels,
|
||||||
|
loss_weights=loss_weights,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@RenderingPlugin("qwen3_nothink").register("parse_message")
|
||||||
|
def parse_qwen_message(generated_text: str) -> Message:
|
||||||
|
pattern = re.compile(r"<(thinking|tool_call)>\s*(.*?)\s*</\1>\s*", re.DOTALL)
|
||||||
|
content = []
|
||||||
|
last_end = 0
|
||||||
|
for match in pattern.finditer(generated_text):
|
||||||
|
start, end = match.span()
|
||||||
|
if start > last_end:
|
||||||
|
text = generated_text[last_end:start].strip()
|
||||||
|
if text:
|
||||||
|
content.append({"type": "text", "value": text})
|
||||||
|
|
||||||
|
tag_type = match.group(1)
|
||||||
|
tag_value = match.group(2).strip()
|
||||||
|
if tag_type == "thinking":
|
||||||
|
content.append({"type": "reasoning", "value": tag_value.strip()})
|
||||||
|
elif tag_type == "tool_call":
|
||||||
|
try:
|
||||||
|
json.loads(tag_value.strip())
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
raise ValueError(f"Invalid tool call format: {tag_value.strip()}.")
|
||||||
|
|
||||||
|
content.append({"type": "tool_call", "value": tag_value.strip()})
|
||||||
|
|
||||||
|
last_end = end
|
||||||
|
|
||||||
|
if last_end < len(generated_text):
|
||||||
|
text = generated_text[last_end:].strip()
|
||||||
|
if text:
|
||||||
|
content.append({"type": "text", "value": text})
|
||||||
|
|
||||||
|
return Message(role="assistant", content=content)
|
||||||
125
src/llamafactory/v1/samplers/cli_sampler.py
Normal file
125
src/llamafactory/v1/samplers/cli_sampler.py
Normal file
@@ -0,0 +1,125 @@
|
|||||||
|
# Copyright 2025 the LlamaFactory team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
from collections.abc import Generator
|
||||||
|
from threading import Thread
|
||||||
|
|
||||||
|
from ..config import InputArgument, ModelArguments, SampleArguments, SampleBackend, get_args
|
||||||
|
from ..core.base_sampler import BaseSampler
|
||||||
|
from ..core.data_engine import DataEngine
|
||||||
|
from ..core.model_engine import ModelEngine
|
||||||
|
from ..core.utils.rendering import Renderer
|
||||||
|
from ..utils.types import HFModel, Message, Sample, TorchDataset
|
||||||
|
|
||||||
|
|
||||||
|
class SyncSampler(BaseSampler):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
args: SampleArguments,
|
||||||
|
model_args: ModelArguments,
|
||||||
|
model: HFModel,
|
||||||
|
renderer: Renderer,
|
||||||
|
) -> None:
|
||||||
|
def _start_background_loop(loop: asyncio.AbstractEventLoop) -> None:
|
||||||
|
asyncio.set_event_loop(loop)
|
||||||
|
loop.run_forever()
|
||||||
|
|
||||||
|
super().__init__(args, model_args, model, renderer)
|
||||||
|
self._loop = asyncio.new_event_loop()
|
||||||
|
self._thread = Thread(target=_start_background_loop, args=(self._loop,), daemon=True)
|
||||||
|
self._thread.start()
|
||||||
|
|
||||||
|
def generate(self, messages: list[Message], tools: str | None = None) -> Generator[str, None, None]:
|
||||||
|
"""Generate tokens synchronously.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
messages: List of messages.
|
||||||
|
tools: Tools string.
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
Generated tokens.
|
||||||
|
"""
|
||||||
|
generator = super().generate(messages, tools)
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
token = asyncio.run_coroutine_threadsafe(generator.__anext__(), self._loop).result()
|
||||||
|
yield token
|
||||||
|
except StopAsyncIteration:
|
||||||
|
break
|
||||||
|
|
||||||
|
def batch_infer(self, dataset: TorchDataset) -> list[Sample]:
|
||||||
|
"""Batch infer samples synchronously.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dataset: Torch dataset.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of samples.
|
||||||
|
"""
|
||||||
|
return asyncio.run_coroutine_threadsafe(super().batch_infer(dataset), self._loop).result()
|
||||||
|
|
||||||
|
|
||||||
|
def run_chat(args: InputArgument = None):
|
||||||
|
data_args, model_args, _, sample_args = get_args(args)
|
||||||
|
if sample_args.sample_backend != SampleBackend.HF:
|
||||||
|
model_args.init_plugin = {"name": "init_on_meta"}
|
||||||
|
|
||||||
|
model_engine = ModelEngine(model_args)
|
||||||
|
sampler = SyncSampler(sample_args, model_args, model_engine.model, model_engine.renderer)
|
||||||
|
if data_args.dataset is not None:
|
||||||
|
dataset = DataEngine(data_args)
|
||||||
|
sampler.batch_infer(dataset)
|
||||||
|
else:
|
||||||
|
if os.name != "nt":
|
||||||
|
try:
|
||||||
|
import readline # noqa: F401
|
||||||
|
except ImportError:
|
||||||
|
print("Install `readline` for a better experience.")
|
||||||
|
|
||||||
|
messages = []
|
||||||
|
print("Welcome to the CLI application, use `clear` to remove the history, use `exit` to exit the application.")
|
||||||
|
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
query = input("\nUser: ")
|
||||||
|
except UnicodeDecodeError:
|
||||||
|
print("Detected decoding error at the inputs, please set the terminal encoding to utf-8.")
|
||||||
|
continue
|
||||||
|
except Exception:
|
||||||
|
raise
|
||||||
|
|
||||||
|
if query.strip() == "exit":
|
||||||
|
break
|
||||||
|
|
||||||
|
if query.strip() == "clear":
|
||||||
|
messages = []
|
||||||
|
print("History has been removed.")
|
||||||
|
continue
|
||||||
|
|
||||||
|
messages.append({"role": "user", "content": [{"type": "text", "value": query}]})
|
||||||
|
print("Assistant: ", end="", flush=True)
|
||||||
|
|
||||||
|
response = ""
|
||||||
|
for new_text in sampler.generate(messages):
|
||||||
|
print(new_text, end="", flush=True)
|
||||||
|
response += new_text
|
||||||
|
|
||||||
|
print()
|
||||||
|
messages.append(model_engine.renderer.parse_message(response))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
run_chat()
|
||||||
@@ -17,7 +17,7 @@ from ..accelerator.interface import DistributedInterface
|
|||||||
from ..config.arg_parser import get_args
|
from ..config.arg_parser import get_args
|
||||||
from ..core.base_trainer import BaseTrainer
|
from ..core.base_trainer import BaseTrainer
|
||||||
from ..core.data_engine import DataEngine
|
from ..core.data_engine import DataEngine
|
||||||
from ..core.model_loader import ModelLoader
|
from ..core.model_engine import ModelEngine
|
||||||
|
|
||||||
|
|
||||||
class SFTTrainer(BaseTrainer):
|
class SFTTrainer(BaseTrainer):
|
||||||
@@ -28,11 +28,11 @@ def run_sft(user_args):
|
|||||||
model_args, data_args, training_args, _ = get_args(user_args)
|
model_args, data_args, training_args, _ = get_args(user_args)
|
||||||
DistributedInterface(training_args.dist_config)
|
DistributedInterface(training_args.dist_config)
|
||||||
data_engine = DataEngine(data_args)
|
data_engine = DataEngine(data_args)
|
||||||
model_loader = ModelLoader(model_args)
|
model_engine = ModelEngine(model_args)
|
||||||
trainer = SFTTrainer(
|
trainer = SFTTrainer(
|
||||||
args=training_args,
|
args=training_args,
|
||||||
model=model_loader.model,
|
model=model_engine.model,
|
||||||
processor=model_loader.processor,
|
processor=model_engine.processor,
|
||||||
dataset=data_engine,
|
dataset=data_engine,
|
||||||
)
|
)
|
||||||
trainer.fit()
|
trainer.fit()
|
||||||
|
|||||||
@@ -11,3 +11,5 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
IGNORE_INDEX = -100
|
||||||
|
|||||||
29
src/llamafactory/v1/utils/helper.py
Normal file
29
src/llamafactory/v1/utils/helper.py
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
# Copyright 2025 the LlamaFactory team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from transformers import PreTrainedTokenizer
|
||||||
|
|
||||||
|
from .types import Processor
|
||||||
|
|
||||||
|
|
||||||
|
def get_tokenizer(processor: Processor) -> PreTrainedTokenizer:
|
||||||
|
"""Get tokenizer from processor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
processor: Processor.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tokenizer.
|
||||||
|
"""
|
||||||
|
return processor.tokenizer if hasattr(processor, "tokenizer") else processor
|
||||||
@@ -54,7 +54,7 @@ def _get_default_logging_level() -> "logging._Level":
|
|||||||
|
|
||||||
|
|
||||||
def _get_library_name() -> str:
|
def _get_library_name() -> str:
|
||||||
return __name__.split(".")[0]
|
return ".".join(__name__.split(".")[:2]) # llamafactory.v1
|
||||||
|
|
||||||
|
|
||||||
def _get_library_root_logger() -> "_Logger":
|
def _get_library_root_logger() -> "_Logger":
|
||||||
|
|||||||
@@ -13,6 +13,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
from collections import defaultdict
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
|
|
||||||
from . import logging
|
from . import logging
|
||||||
@@ -27,7 +28,7 @@ class BasePlugin:
|
|||||||
A plugin is a callable object that can be registered and called by name.
|
A plugin is a callable object that can be registered and called by name.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_registry: dict[str, Callable] = {}
|
_registry: dict[str, dict[str, Callable]] = defaultdict(dict)
|
||||||
|
|
||||||
def __init__(self, name: str | None = None):
|
def __init__(self, name: str | None = None):
|
||||||
"""Initialize the plugin with a name.
|
"""Initialize the plugin with a name.
|
||||||
@@ -37,8 +38,7 @@ class BasePlugin:
|
|||||||
"""
|
"""
|
||||||
self.name = name
|
self.name = name
|
||||||
|
|
||||||
@property
|
def register(self, method_name: str = "__call__"):
|
||||||
def register(self):
|
|
||||||
"""Decorator to register a function as a plugin.
|
"""Decorator to register a function as a plugin.
|
||||||
|
|
||||||
Example usage:
|
Example usage:
|
||||||
@@ -46,16 +46,21 @@ class BasePlugin:
|
|||||||
@PrintPlugin("hello").register()
|
@PrintPlugin("hello").register()
|
||||||
def print_hello():
|
def print_hello():
|
||||||
print("Hello world!")
|
print("Hello world!")
|
||||||
|
|
||||||
|
|
||||||
|
@PrintPlugin("hello").register("again")
|
||||||
|
def print_hello_again():
|
||||||
|
print("Hello world! Again.")
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
if self.name is None:
|
if self.name is None:
|
||||||
raise ValueError("Plugin name is not specified.")
|
raise ValueError("Plugin name should be specified.")
|
||||||
|
|
||||||
if self.name in self._registry:
|
if method_name in self._registry[self.name]:
|
||||||
logger.warning_rank0_once(f"Plugin {self.name} is already registered.")
|
logger.warning_rank0_once(f"Method {method_name} of plugin {self.name} is already registered.")
|
||||||
|
|
||||||
def decorator(func: Callable) -> Callable:
|
def decorator(func: Callable) -> Callable:
|
||||||
self._registry[self.name] = func
|
self._registry[self.name][method_name] = func
|
||||||
return func
|
return func
|
||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
@@ -68,10 +73,23 @@ class BasePlugin:
|
|||||||
PrintPlugin("hello")()
|
PrintPlugin("hello")()
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
if self.name not in self._registry:
|
if "__call__" not in self._registry[self.name]:
|
||||||
raise ValueError(f"Plugin {self.name} is not registered.")
|
raise ValueError(f"Method __call__ of plugin {self.name} is not registered.")
|
||||||
|
|
||||||
return self._registry[self.name](*args, **kwargs)
|
return self._registry[self.name]["__call__"](*args, **kwargs)
|
||||||
|
|
||||||
|
def __getattr__(self, method_name: str):
|
||||||
|
"""Get the registered function with the given name.
|
||||||
|
|
||||||
|
Example usage:
|
||||||
|
```python
|
||||||
|
PrintPlugin("hello").again()
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
if method_name not in self._registry[self.name]:
|
||||||
|
raise ValueError(f"Method {method_name} of plugin {self.name} is not registered.")
|
||||||
|
|
||||||
|
return self._registry[self.name][method_name]
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
@@ -82,8 +100,13 @@ if __name__ == "__main__":
|
|||||||
class PrintPlugin(BasePlugin):
|
class PrintPlugin(BasePlugin):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@PrintPlugin("hello").register
|
@PrintPlugin("hello").register()
|
||||||
def print_hello():
|
def print_hello():
|
||||||
print("Hello world!")
|
print("Hello world!")
|
||||||
|
|
||||||
|
@PrintPlugin("hello").register("again")
|
||||||
|
def print_hello_again():
|
||||||
|
print("Hello world! Again.")
|
||||||
|
|
||||||
PrintPlugin("hello")()
|
PrintPlugin("hello")()
|
||||||
|
PrintPlugin("hello").again()
|
||||||
|
|||||||
@@ -12,7 +12,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from typing import TYPE_CHECKING, Literal, NotRequired, TypedDict, Union
|
from typing import TYPE_CHECKING, Any, Literal, NotRequired, TypedDict, Union
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@@ -84,27 +84,63 @@ class DistributedConfig(TypedDict, total=False):
|
|||||||
|
|
||||||
|
|
||||||
class Content(TypedDict):
|
class Content(TypedDict):
|
||||||
type: Literal["text", "reasoning", "tools", "tool_calls", "image_url"]
|
type: Literal["text", "reasoning", "tool_call", "image_url"]
|
||||||
|
"""Type of the content."""
|
||||||
value: str
|
value: str
|
||||||
|
"""Value of the content."""
|
||||||
|
|
||||||
|
|
||||||
class Message(TypedDict):
|
class Message(TypedDict):
|
||||||
role: Literal["system", "user", "assistant", "tool"]
|
role: Literal["system", "user", "assistant", "tool"]
|
||||||
|
"""Role of the message."""
|
||||||
content: list[Content]
|
content: list[Content]
|
||||||
loss_weight: float
|
"""Content of the message."""
|
||||||
|
loss_weight: NotRequired[float]
|
||||||
|
"""Loss weight for this message, default to 1.0. Required in training."""
|
||||||
|
|
||||||
|
|
||||||
class SFTSample(TypedDict):
|
class SFTSample(TypedDict):
|
||||||
messages: list[Message]
|
messages: list[Message]
|
||||||
|
"""Messages in the sample."""
|
||||||
|
tools: NotRequired[str]
|
||||||
|
"""Tools for the sample in JSON string format."""
|
||||||
extra_info: NotRequired[str]
|
extra_info: NotRequired[str]
|
||||||
|
"""Extra information for the sample, e.g. kto_labels."""
|
||||||
_dataset_name: NotRequired[str]
|
_dataset_name: NotRequired[str]
|
||||||
|
"""Dataset name for the sample."""
|
||||||
|
|
||||||
|
|
||||||
class DPOSample(TypedDict):
|
class DPOSample(TypedDict):
|
||||||
chosen_messages: list[Message]
|
chosen_messages: list[Message]
|
||||||
|
"""Chosen messages in the sample."""
|
||||||
rejected_messages: list[Message]
|
rejected_messages: list[Message]
|
||||||
|
"""Rejected messages in the sample."""
|
||||||
|
tools: NotRequired[str]
|
||||||
|
"""Tools for the sample in JSON string format."""
|
||||||
extra_info: NotRequired[str]
|
extra_info: NotRequired[str]
|
||||||
|
"""Extra information for the sample, e.g. kto_labels."""
|
||||||
_dataset_name: NotRequired[str]
|
_dataset_name: NotRequired[str]
|
||||||
|
"""Dataset name for the sample."""
|
||||||
|
|
||||||
|
|
||||||
Sample = Union[SFTSample, DPOSample]
|
Sample = Union[SFTSample, DPOSample]
|
||||||
|
|
||||||
|
|
||||||
|
class ToolCall(TypedDict):
|
||||||
|
name: str
|
||||||
|
"""Function name."""
|
||||||
|
arguments: dict[str, Any]
|
||||||
|
"""Function arguments."""
|
||||||
|
|
||||||
|
|
||||||
|
class ModelInput(TypedDict, total=False):
|
||||||
|
input_ids: list[int]
|
||||||
|
"""Input ids for the model."""
|
||||||
|
attention_mask: list[int]
|
||||||
|
"""Attention mask for the model."""
|
||||||
|
labels: list[int]
|
||||||
|
"""Labels for the model."""
|
||||||
|
loss_weights: list[float]
|
||||||
|
"""Loss weight for each token, default to 1.0."""
|
||||||
|
position_ids: NotRequired[list[int] | list[list[int]]]
|
||||||
|
"""Position ids for the model (optional)."""
|
||||||
|
|||||||
@@ -12,13 +12,13 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
"""LLaMA-Factory test configuration.
|
"""LlamaFactory test configuration.
|
||||||
|
|
||||||
Contains shared fixtures, pytest configuration, and custom markers.
|
Contains shared fixtures, pytest configuration, and custom markers.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
from typing import Optional
|
import sys
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
@@ -73,7 +73,7 @@ def _handle_slow_tests(items: list[Item]):
|
|||||||
item.add_marker(skip_slow)
|
item.add_marker(skip_slow)
|
||||||
|
|
||||||
|
|
||||||
def _get_visible_devices_env() -> Optional[str]:
|
def _get_visible_devices_env() -> str | None:
|
||||||
"""Return device visibility env var name."""
|
"""Return device visibility env var name."""
|
||||||
if CURRENT_DEVICE == "cuda":
|
if CURRENT_DEVICE == "cuda":
|
||||||
return "CUDA_VISIBLE_DEVICES"
|
return "CUDA_VISIBLE_DEVICES"
|
||||||
@@ -110,11 +110,10 @@ def _handle_device_visibility(items: list[Item]):
|
|||||||
def pytest_collection_modifyitems(config: Config, items: list[Item]):
|
def pytest_collection_modifyitems(config: Config, items: list[Item]):
|
||||||
"""Modify test collection based on markers and environment."""
|
"""Modify test collection based on markers and environment."""
|
||||||
# Handle version compatibility (from HEAD)
|
# Handle version compatibility (from HEAD)
|
||||||
if not is_transformers_version_greater_than("4.57.0"):
|
skip_bc = pytest.mark.skip(reason="Skip backward compatibility tests")
|
||||||
skip_bc = pytest.mark.skip(reason="Skip backward compatibility tests")
|
for item in items:
|
||||||
for item in items:
|
if "tests_v1" in str(item.fspath) and not is_transformers_version_greater_than("4.57.0"):
|
||||||
if "tests_v1" in str(item.fspath):
|
item.add_marker(skip_bc)
|
||||||
item.add_marker(skip_bc)
|
|
||||||
|
|
||||||
_handle_slow_tests(items)
|
_handle_slow_tests(items)
|
||||||
_handle_runs_on(items)
|
_handle_runs_on(items)
|
||||||
@@ -150,12 +149,21 @@ def _manage_distributed_env(request: FixtureRequest, monkeypatch: MonkeyPatch) -
|
|||||||
devices_str = ",".join(str(i) for i in range(required))
|
devices_str = ",".join(str(i) for i in range(required))
|
||||||
|
|
||||||
monkeypatch.setenv(env_key, devices_str)
|
monkeypatch.setenv(env_key, devices_str)
|
||||||
|
|
||||||
|
# add project root dir to path for mp run
|
||||||
|
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||||
|
if project_root not in sys.path:
|
||||||
|
sys.path.insert(0, project_root)
|
||||||
|
|
||||||
|
os.environ["PYTHONPATH"] = project_root + os.pathsep + os.environ.get("PYTHONPATH", "")
|
||||||
|
|
||||||
else: # non-distributed test
|
else: # non-distributed test
|
||||||
if old_value:
|
if old_value:
|
||||||
visible_devices = [v for v in old_value.split(",") if v != ""]
|
visible_devices = [v for v in old_value.split(",") if v != ""]
|
||||||
monkeypatch.setenv(env_key, visible_devices[0] if visible_devices else "0")
|
monkeypatch.setenv(env_key, visible_devices[0] if visible_devices else "0")
|
||||||
else:
|
else:
|
||||||
monkeypatch.setenv(env_key, "0")
|
monkeypatch.setenv(env_key, "0")
|
||||||
|
|
||||||
if CURRENT_DEVICE == "cuda":
|
if CURRENT_DEVICE == "cuda":
|
||||||
monkeypatch.setattr(torch.cuda, "device_count", lambda: 1)
|
monkeypatch.setattr(torch.cuda, "device_count", lambda: 1)
|
||||||
elif CURRENT_DEVICE == "npu":
|
elif CURRENT_DEVICE == "npu":
|
||||||
|
|||||||
@@ -292,3 +292,91 @@ def test_qwen_multi_tool_extractor():
|
|||||||
("test_tool", """{"foo": "bar", "size": 10}"""),
|
("test_tool", """{"foo": "bar", "size": 10}"""),
|
||||||
("another_tool", """{"foo": "job", "size": 2}"""),
|
("another_tool", """{"foo": "job", "size": 2}"""),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.runs_on(["cpu", "mps"])
|
||||||
|
def test_lfm_function_formatter():
|
||||||
|
formatter = FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="lfm")
|
||||||
|
tool_calls = json.dumps(FUNCTION)
|
||||||
|
assert formatter.apply(content=tool_calls) == [
|
||||||
|
"""<|tool_call_start|>[tool_name(foo="bar", size=10)]<|tool_call_end|><|im_end|>\n"""
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.runs_on(["cpu", "mps"])
|
||||||
|
def test_lfm_multi_function_formatter():
|
||||||
|
formatter = FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="lfm")
|
||||||
|
tool_calls = json.dumps([FUNCTION] * 2)
|
||||||
|
assert formatter.apply(content=tool_calls) == [
|
||||||
|
"""<|tool_call_start|>[tool_name(foo="bar", size=10), tool_name(foo="bar", size=10)]<|tool_call_end|>"""
|
||||||
|
"<|im_end|>\n"
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.runs_on(["cpu", "mps"])
|
||||||
|
def test_lfm_tool_formatter():
|
||||||
|
formatter = ToolFormatter(tool_format="lfm")
|
||||||
|
assert formatter.apply(content=json.dumps(TOOLS)) == [
|
||||||
|
"List of tools: <|tool_list_start|>" + json.dumps(TOOLS, ensure_ascii=False) + "<|tool_list_end|>"
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.runs_on(["cpu", "mps"])
|
||||||
|
def test_lfm_tool_extractor():
|
||||||
|
formatter = ToolFormatter(tool_format="lfm")
|
||||||
|
result = """<|tool_call_start|>[test_tool(foo="bar", size=10)]<|tool_call_end|>"""
|
||||||
|
assert formatter.extract(result) == [("test_tool", """{"foo": "bar", "size": 10}""")]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.runs_on(["cpu", "mps"])
|
||||||
|
def test_lfm_multi_tool_extractor():
|
||||||
|
formatter = ToolFormatter(tool_format="lfm")
|
||||||
|
result = """<|tool_call_start|>[test_tool(foo="bar", size=10), another_tool(foo="job", size=2)]<|tool_call_end|>"""
|
||||||
|
assert formatter.extract(result) == [
|
||||||
|
("test_tool", """{"foo": "bar", "size": 10}"""),
|
||||||
|
("another_tool", """{"foo": "job", "size": 2}"""),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.runs_on(["cpu", "mps"])
|
||||||
|
def test_lfm_tool_extractor_with_nested_dict():
|
||||||
|
formatter = ToolFormatter(tool_format="lfm")
|
||||||
|
result = """<|tool_call_start|>[search(query="test", options={"limit": 10, "offset": 0})]<|tool_call_end|>"""
|
||||||
|
extracted = formatter.extract(result)
|
||||||
|
assert len(extracted) == 1
|
||||||
|
assert extracted[0][0] == "search"
|
||||||
|
args = json.loads(extracted[0][1])
|
||||||
|
assert args["query"] == "test"
|
||||||
|
assert args["options"] == {"limit": 10, "offset": 0}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.runs_on(["cpu", "mps"])
|
||||||
|
def test_lfm_tool_extractor_with_list_arg():
|
||||||
|
formatter = ToolFormatter(tool_format="lfm")
|
||||||
|
result = """<|tool_call_start|>[batch_process(items=[1, 2, 3], enabled=True)]<|tool_call_end|>"""
|
||||||
|
extracted = formatter.extract(result)
|
||||||
|
assert len(extracted) == 1
|
||||||
|
assert extracted[0][0] == "batch_process"
|
||||||
|
args = json.loads(extracted[0][1])
|
||||||
|
assert args["items"] == [1, 2, 3]
|
||||||
|
assert args["enabled"] is True
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.runs_on(["cpu", "mps"])
|
||||||
|
def test_lfm_tool_extractor_no_match():
|
||||||
|
formatter = ToolFormatter(tool_format="lfm")
|
||||||
|
result = "This is a regular response without tool calls."
|
||||||
|
extracted = formatter.extract(result)
|
||||||
|
assert extracted == result
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.runs_on(["cpu", "mps"])
|
||||||
|
def test_lfm_tool_round_trip():
|
||||||
|
formatter = FunctionFormatter(slots=["{{content}}"], tool_format="lfm")
|
||||||
|
tool_formatter = ToolFormatter(tool_format="lfm")
|
||||||
|
original = {"name": "my_func", "arguments": {"arg1": "hello", "arg2": 42, "arg3": True}}
|
||||||
|
formatted = formatter.apply(content=json.dumps(original))
|
||||||
|
extracted = tool_formatter.extract(formatted[0])
|
||||||
|
assert len(extracted) == 1
|
||||||
|
assert extracted[0][0] == original["name"]
|
||||||
|
assert json.loads(extracted[0][1]) == original["arguments"]
|
||||||
|
|||||||
@@ -1,2 +1,2 @@
|
|||||||
# change if test fails or cache is outdated
|
# change if test fails or cache is outdated
|
||||||
0.9.4.105
|
0.9.5.103
|
||||||
|
|||||||
@@ -24,7 +24,6 @@ def test_get_args_from_yaml(tmp_path: pathlib.Path):
|
|||||||
### model
|
### model
|
||||||
model: "llamafactory/tiny-random-qwen2.5"
|
model: "llamafactory/tiny-random-qwen2.5"
|
||||||
trust_remote_code: true
|
trust_remote_code: true
|
||||||
use_fast_processor: true
|
|
||||||
model_class: "llm"
|
model_class: "llm"
|
||||||
kernel_config:
|
kernel_config:
|
||||||
name: "auto"
|
name: "auto"
|
||||||
|
|||||||
@@ -12,7 +12,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
"""LLaMA-Factory test configuration.
|
"""LlamaFactory test configuration.
|
||||||
|
|
||||||
Contains shared fixtures, pytest configuration, and custom markers.
|
Contains shared fixtures, pytest configuration, and custom markers.
|
||||||
"""
|
"""
|
||||||
@@ -22,6 +22,7 @@ import sys
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
from pytest import Config, FixtureRequest, Item, MonkeyPatch
|
from pytest import Config, FixtureRequest, Item, MonkeyPatch
|
||||||
|
|
||||||
from llamafactory.v1.accelerator.helper import get_current_accelerator, get_device_count
|
from llamafactory.v1.accelerator.helper import get_current_accelerator, get_device_count
|
||||||
@@ -109,17 +110,24 @@ def _handle_device_visibility(items: list[Item]):
|
|||||||
def pytest_collection_modifyitems(config: Config, items: list[Item]):
|
def pytest_collection_modifyitems(config: Config, items: list[Item]):
|
||||||
"""Modify test collection based on markers and environment."""
|
"""Modify test collection based on markers and environment."""
|
||||||
# Handle version compatibility (from HEAD)
|
# Handle version compatibility (from HEAD)
|
||||||
if not is_transformers_version_greater_than("4.57.0"):
|
skip_bc = pytest.mark.skip(reason="Skip backward compatibility tests")
|
||||||
skip_bc = pytest.mark.skip(reason="Skip backward compatibility tests")
|
for item in items:
|
||||||
for item in items:
|
if "tests_v1" in str(item.fspath) and not is_transformers_version_greater_than("4.57.0"):
|
||||||
if "tests_v1" in str(item.fspath):
|
item.add_marker(skip_bc)
|
||||||
item.add_marker(skip_bc)
|
|
||||||
|
|
||||||
_handle_slow_tests(items)
|
_handle_slow_tests(items)
|
||||||
_handle_runs_on(items)
|
_handle_runs_on(items)
|
||||||
_handle_device_visibility(items)
|
_handle_device_visibility(items)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def _cleanup_distributed_state():
|
||||||
|
"""Cleanup distributed state after each test."""
|
||||||
|
yield
|
||||||
|
if dist.is_initialized():
|
||||||
|
dist.destroy_process_group()
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
def _manage_distributed_env(request: FixtureRequest, monkeypatch: MonkeyPatch) -> None:
|
def _manage_distributed_env(request: FixtureRequest, monkeypatch: MonkeyPatch) -> None:
|
||||||
"""Set environment variables for distributed tests if specific devices are requested."""
|
"""Set environment variables for distributed tests if specific devices are requested."""
|
||||||
@@ -155,6 +163,7 @@ def _manage_distributed_env(request: FixtureRequest, monkeypatch: MonkeyPatch) -
|
|||||||
monkeypatch.setenv(env_key, visible_devices[0] if visible_devices else "0")
|
monkeypatch.setenv(env_key, visible_devices[0] if visible_devices else "0")
|
||||||
else:
|
else:
|
||||||
monkeypatch.setenv(env_key, "0")
|
monkeypatch.setenv(env_key, "0")
|
||||||
|
|
||||||
if CURRENT_DEVICE == "cuda":
|
if CURRENT_DEVICE == "cuda":
|
||||||
monkeypatch.setattr(torch.cuda, "device_count", lambda: 1)
|
monkeypatch.setattr(torch.cuda, "device_count", lambda: 1)
|
||||||
elif CURRENT_DEVICE == "npu":
|
elif CURRENT_DEVICE == "npu":
|
||||||
|
|||||||
@@ -1,173 +0,0 @@
|
|||||||
# Copyright 2025 the LlamaFactory team.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
"""Integration tests for DataLoader with different combinations of packing and dynamic batching.
|
|
||||||
|
|
||||||
Tests the 4 scenarios:
|
|
||||||
a) non pack + non dynamic.
|
|
||||||
b) non pack + dynamic.
|
|
||||||
c) pack + non dynamic.
|
|
||||||
d) pack + dynamic.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from torch.utils.data import DataLoader as TorchDataLoader
|
|
||||||
from torch.utils.data import Dataset
|
|
||||||
from transformers import AutoTokenizer
|
|
||||||
|
|
||||||
from llamafactory.v1.config.data_args import DataArguments
|
|
||||||
from llamafactory.v1.core.data_engine import DataEngine
|
|
||||||
from llamafactory.v1.core.trainer_utils.data_collator import (
|
|
||||||
DefaultCollator,
|
|
||||||
)
|
|
||||||
from llamafactory.v1.core.trainer_utils.data_loader import DataLoader
|
|
||||||
from llamafactory.v1.plugins.data_plugins.template import QwenTemplate
|
|
||||||
from llamafactory.v1.utils.batching_queue import TextBatchingQueue
|
|
||||||
|
|
||||||
|
|
||||||
class TensorDataset(Dataset):
|
|
||||||
"""Wrapper dataset that converts DataEngine samples to tensor format."""
|
|
||||||
|
|
||||||
def __init__(self, data_engine: DataEngine, processor, template, max_samples: int = None):
|
|
||||||
self.data_engine = data_engine
|
|
||||||
self.processor = processor
|
|
||||||
self.template = template
|
|
||||||
self.max_samples = max_samples or len(data_engine)
|
|
||||||
self.tokenizer = processor.tokenizer if hasattr(processor, "tokenizer") else processor
|
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
return min(self.max_samples, len(self.data_engine))
|
|
||||||
|
|
||||||
def __getitem__(self, idx):
|
|
||||||
# Get sample from DataEngine
|
|
||||||
sample = self.data_engine[idx]
|
|
||||||
|
|
||||||
# Extract messages from sample
|
|
||||||
# DataEngine returns samples with format like {"messages": [...], ...}
|
|
||||||
# For llamafactory/v1-sft-demo, the format should have "messages" field
|
|
||||||
messages = None
|
|
||||||
if "messages" in sample:
|
|
||||||
messages = sample["messages"]
|
|
||||||
elif "conversations" in sample:
|
|
||||||
messages = sample["conversations"]
|
|
||||||
elif "conversation" in sample:
|
|
||||||
messages = sample["conversation"]
|
|
||||||
else:
|
|
||||||
# Try to find message-like fields (skip _dataset_name)
|
|
||||||
for key, value in sample.items():
|
|
||||||
if key.startswith("_"):
|
|
||||||
continue
|
|
||||||
if isinstance(value, list) and len(value) > 0:
|
|
||||||
# Check if it looks like a message list
|
|
||||||
if isinstance(value[0], dict) and "role" in value[0]:
|
|
||||||
messages = value
|
|
||||||
break
|
|
||||||
|
|
||||||
if messages is None:
|
|
||||||
raise ValueError(f"Could not find messages in sample: {list(sample.keys())}")
|
|
||||||
|
|
||||||
# Encode messages using template
|
|
||||||
encoded = self.template.encode_messages(self.tokenizer, messages)
|
|
||||||
|
|
||||||
# Convert to tensors
|
|
||||||
return {
|
|
||||||
"input_ids": torch.tensor(encoded["input_ids"], dtype=torch.long),
|
|
||||||
"attention_mask": torch.tensor(encoded["attention_mask"], dtype=torch.long),
|
|
||||||
"labels": torch.tensor(encoded["labels"], dtype=torch.long),
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def create_real_dataset(max_samples: int = 20, batch_size: int = 4):
|
|
||||||
"""Create a real dataset using DataEngine."""
|
|
||||||
data_args = DataArguments(dataset="llamafactory/v1-sft-demo")
|
|
||||||
data_engine = DataEngine(data_args)
|
|
||||||
|
|
||||||
# Create processor and template
|
|
||||||
processor = AutoTokenizer.from_pretrained("llamafactory/tiny-random-qwen2.5")
|
|
||||||
template = QwenTemplate()
|
|
||||||
|
|
||||||
# Create tensor dataset
|
|
||||||
raw_data_dataset = TensorDataset(data_engine, processor, template, max_samples=max_samples)
|
|
||||||
|
|
||||||
# Create torch DataLoader
|
|
||||||
torch_dataloader = TorchDataLoader(
|
|
||||||
raw_data_dataset,
|
|
||||||
batch_size=batch_size,
|
|
||||||
shuffle=False,
|
|
||||||
collate_fn=lambda x: x,
|
|
||||||
)
|
|
||||||
|
|
||||||
return torch_dataloader, processor, template
|
|
||||||
|
|
||||||
|
|
||||||
class TestDataLoaderNonPackNonDynamic:
|
|
||||||
"""Test case a) non pack + non dynamic."""
|
|
||||||
|
|
||||||
def test_basic_functionality(self):
|
|
||||||
"""Test DataLoader without packing and without dynamic batching."""
|
|
||||||
# Create real dataset
|
|
||||||
torch_dataloader, processor, template = create_real_dataset(max_samples=80, batch_size=8)
|
|
||||||
|
|
||||||
# Create collator (non-packing)
|
|
||||||
collator = DefaultCollator(processor=processor, template=template)
|
|
||||||
|
|
||||||
# Create DataLoader without batching_queue (non-dynamic)
|
|
||||||
data_loader = DataLoader(
|
|
||||||
dataloader=torch_dataloader,
|
|
||||||
collate_fn=collator,
|
|
||||||
num_micro_batch=1,
|
|
||||||
batching_queue=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Iterate and check results
|
|
||||||
batches = list(iter(data_loader))
|
|
||||||
assert len(batches) > 0
|
|
||||||
|
|
||||||
# Check first batch
|
|
||||||
one_batch = batches[0]
|
|
||||||
micro_batches = one_batch[0]
|
|
||||||
assert "input_ids" in micro_batches
|
|
||||||
assert "attention_mask" in micro_batches
|
|
||||||
assert "labels" in micro_batches
|
|
||||||
assert micro_batches["input_ids"].shape[0] == 1 # batch_size=1
|
|
||||||
assert micro_batches["input_ids"].ndim == 2 # [batch_size, seq_len]
|
|
||||||
|
|
||||||
|
|
||||||
class TestDataLoaderNonPackDynamic:
|
|
||||||
"""Test case b) non pack + dynamic."""
|
|
||||||
|
|
||||||
def test_basic_functionality(self):
|
|
||||||
"""Test DataLoader without packing but with dynamic batching."""
|
|
||||||
# Create real dataset
|
|
||||||
torch_dataloader, processor, template = create_real_dataset(max_samples=80, batch_size=8)
|
|
||||||
collator = DefaultCollator(processor=processor, template=template)
|
|
||||||
|
|
||||||
# Create batching queue for dynamic batching
|
|
||||||
batching_queue = TextBatchingQueue(
|
|
||||||
token_micro_bsz=120,
|
|
||||||
buffer_size=8,
|
|
||||||
)
|
|
||||||
|
|
||||||
data_loader = DataLoader(
|
|
||||||
dataloader=torch_dataloader,
|
|
||||||
collate_fn=collator,
|
|
||||||
num_micro_batch=4,
|
|
||||||
batching_queue=batching_queue,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Iterate and check
|
|
||||||
batches = list(iter(data_loader))
|
|
||||||
micro_batch_tokens_first = [micro_batch["attention_mask"].sum() for micro_batch in batches[0]]
|
|
||||||
assert all(num_tokens <= 120 for num_tokens in micro_batch_tokens_first)
|
|
||||||
assert len(batches) > 0
|
|
||||||
@@ -15,18 +15,18 @@
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from llamafactory.v1.config.model_args import ModelArguments, PluginConfig
|
from llamafactory.v1.config.model_args import ModelArguments, PluginConfig
|
||||||
from llamafactory.v1.core.model_loader import ModelLoader
|
from llamafactory.v1.core.model_engine import ModelEngine
|
||||||
|
|
||||||
|
|
||||||
def test_tiny_qwen():
|
def test_tiny_qwen():
|
||||||
from transformers import Qwen2Config, Qwen2ForCausalLM, Qwen2TokenizerFast
|
from transformers import Qwen2Config, Qwen2ForCausalLM, Qwen2TokenizerFast
|
||||||
|
|
||||||
model_args = ModelArguments(model="llamafactory/tiny-random-qwen2.5")
|
model_args = ModelArguments(model="llamafactory/tiny-random-qwen2.5")
|
||||||
model_loader = ModelLoader(model_args)
|
model_engine = ModelEngine(model_args)
|
||||||
assert isinstance(model_loader.processor, Qwen2TokenizerFast)
|
assert isinstance(model_engine.processor, Qwen2TokenizerFast)
|
||||||
assert isinstance(model_loader.model.config, Qwen2Config)
|
assert isinstance(model_engine.model_config, Qwen2Config)
|
||||||
assert isinstance(model_loader.model, Qwen2ForCausalLM)
|
assert isinstance(model_engine.model, Qwen2ForCausalLM)
|
||||||
assert model_loader.model.dtype == torch.bfloat16
|
assert model_engine.model.dtype == torch.bfloat16
|
||||||
|
|
||||||
|
|
||||||
def test_tiny_qwen_with_kernel_plugin():
|
def test_tiny_qwen_with_kernel_plugin():
|
||||||
@@ -37,13 +37,14 @@ def test_tiny_qwen_with_kernel_plugin():
|
|||||||
model_args = ModelArguments(
|
model_args = ModelArguments(
|
||||||
model="llamafactory/tiny-random-qwen2.5", kernel_config=PluginConfig(name="auto", include_kernels="auto")
|
model="llamafactory/tiny-random-qwen2.5", kernel_config=PluginConfig(name="auto", include_kernels="auto")
|
||||||
)
|
)
|
||||||
model_loader = ModelLoader(model_args)
|
model_engine = ModelEngine(model_args)
|
||||||
# test enable apply kernel plugin
|
# test enable apply kernel plugin
|
||||||
if hasattr(torch, "npu"):
|
if hasattr(torch, "npu"):
|
||||||
assert model_loader.model.model.layers[0].input_layernorm.forward.__code__ == npu_rms_norm_forward.__code__
|
assert model_engine.model.model.layers[0].input_layernorm.forward.__code__ == npu_rms_norm_forward.__code__
|
||||||
else:
|
else:
|
||||||
assert model_loader.model.model.layers[0].input_layernorm.forward.__code__ != npu_rms_norm_forward.__code__
|
assert model_engine.model.model.layers[0].input_layernorm.forward.__code__ != npu_rms_norm_forward.__code__
|
||||||
assert isinstance(model_loader.model, Qwen2ForCausalLM)
|
|
||||||
|
assert isinstance(model_engine.model, Qwen2ForCausalLM)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
171
tests_v1/core/utils/test_data_loader.py
Normal file
171
tests_v1/core/utils/test_data_loader.py
Normal file
@@ -0,0 +1,171 @@
|
|||||||
|
# Copyright 2025 the LlamaFactory team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""Integration tests for DataLoader with different combinations of packing and dynamic batching.
|
||||||
|
|
||||||
|
Tests the 4 scenarios:
|
||||||
|
a) non pack + non dynamic.
|
||||||
|
b) non pack + dynamic.
|
||||||
|
c) pack + non dynamic.
|
||||||
|
d) pack + dynamic.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# import torch
|
||||||
|
# from torch.utils.data import DataLoader as TorchDataLoader
|
||||||
|
# from torch.utils.data import Dataset
|
||||||
|
# from transformers import AutoTokenizer
|
||||||
|
|
||||||
|
# from llamafactory.v1.config.data_args import DataArguments
|
||||||
|
# from llamafactory.v1.core.data_engine import DataEngine
|
||||||
|
# from llamafactory.v1.core.utils.data_collator import DefaultCollator
|
||||||
|
# from llamafactory.v1.core.utils.data_loader import DataLoader
|
||||||
|
# from llamafactory.v1.plugins.data_plugins.rendering import QwenTemplate
|
||||||
|
# from llamafactory.v1.utils.batching_queue import TextBatchingQueue
|
||||||
|
|
||||||
|
|
||||||
|
# class TensorDataset(Dataset):
|
||||||
|
# """Wrapper dataset that converts DataEngine samples to tensor format."""
|
||||||
|
|
||||||
|
# def __init__(self, data_engine: DataEngine, processor, template, max_samples: int = None):
|
||||||
|
# self.data_engine = data_engine
|
||||||
|
# self.processor = processor
|
||||||
|
# self.template = template
|
||||||
|
# self.max_samples = max_samples or len(data_engine)
|
||||||
|
# self.tokenizer = processor.tokenizer if hasattr(processor, "tokenizer") else processor
|
||||||
|
|
||||||
|
# def __len__(self):
|
||||||
|
# return min(self.max_samples, len(self.data_engine))
|
||||||
|
|
||||||
|
# def __getitem__(self, idx):
|
||||||
|
# # Get sample from DataEngine
|
||||||
|
# sample = self.data_engine[idx]
|
||||||
|
|
||||||
|
# # Extract messages from sample
|
||||||
|
# # DataEngine returns samples with format like {"messages": [...], ...}
|
||||||
|
# # For llamafactory/v1-sft-demo, the format should have "messages" field
|
||||||
|
# messages = None
|
||||||
|
# if "messages" in sample:
|
||||||
|
# messages = sample["messages"]
|
||||||
|
# elif "conversations" in sample:
|
||||||
|
# messages = sample["conversations"]
|
||||||
|
# elif "conversation" in sample:
|
||||||
|
# messages = sample["conversation"]
|
||||||
|
# else:
|
||||||
|
# # Try to find message-like fields (skip _dataset_name)
|
||||||
|
# for key, value in sample.items():
|
||||||
|
# if key.startswith("_"):
|
||||||
|
# continue
|
||||||
|
# if isinstance(value, list) and len(value) > 0:
|
||||||
|
# # Check if it looks like a message list
|
||||||
|
# if isinstance(value[0], dict) and "role" in value[0]:
|
||||||
|
# messages = value
|
||||||
|
# break
|
||||||
|
|
||||||
|
# if messages is None:
|
||||||
|
# raise ValueError(f"Could not find messages in sample: {list(sample.keys())}")
|
||||||
|
|
||||||
|
# # Encode messages using template
|
||||||
|
# encoded = self.template.encode_messages(self.tokenizer, messages)
|
||||||
|
|
||||||
|
# # Convert to tensors
|
||||||
|
# return {
|
||||||
|
# "input_ids": torch.tensor(encoded["input_ids"], dtype=torch.long),
|
||||||
|
# "attention_mask": torch.tensor(encoded["attention_mask"], dtype=torch.long),
|
||||||
|
# "labels": torch.tensor(encoded["labels"], dtype=torch.long),
|
||||||
|
# }
|
||||||
|
|
||||||
|
|
||||||
|
# def create_real_dataset(max_samples: int = 20, batch_size: int = 4):
|
||||||
|
# """Create a real dataset using DataEngine."""
|
||||||
|
# data_args = DataArguments(dataset="llamafactory/v1-sft-demo")
|
||||||
|
# data_engine = DataEngine(data_args)
|
||||||
|
|
||||||
|
# # Create processor and template
|
||||||
|
# processor = AutoTokenizer.from_pretrained("llamafactory/tiny-random-qwen2.5")
|
||||||
|
# template = QwenTemplate()
|
||||||
|
|
||||||
|
# # Create tensor dataset
|
||||||
|
# raw_data_dataset = TensorDataset(data_engine, processor, template, max_samples=max_samples)
|
||||||
|
|
||||||
|
# # Create torch DataLoader
|
||||||
|
# torch_dataloader = TorchDataLoader(
|
||||||
|
# raw_data_dataset,
|
||||||
|
# batch_size=batch_size,
|
||||||
|
# shuffle=False,
|
||||||
|
# collate_fn=lambda x: x,
|
||||||
|
# )
|
||||||
|
|
||||||
|
# return torch_dataloader, processor, template
|
||||||
|
|
||||||
|
|
||||||
|
# class TestDataLoaderNonPackNonDynamic:
|
||||||
|
# """Test case a) non pack + non dynamic."""
|
||||||
|
|
||||||
|
# def test_basic_functionality(self):
|
||||||
|
# """Test DataLoader without packing and without dynamic batching."""
|
||||||
|
# # Create real dataset
|
||||||
|
# torch_dataloader, processor, template = create_real_dataset(max_samples=80, batch_size=8)
|
||||||
|
|
||||||
|
# # Create collator (non-packing)
|
||||||
|
# collator = DefaultCollator(processor=processor, template=template)
|
||||||
|
|
||||||
|
# # Create DataLoader without batching_queue (non-dynamic)
|
||||||
|
# data_loader = DataLoader(
|
||||||
|
# dataloader=torch_dataloader,
|
||||||
|
# collate_fn=collator,
|
||||||
|
# num_micro_batch=1,
|
||||||
|
# batching_queue=None,
|
||||||
|
# )
|
||||||
|
|
||||||
|
# # Iterate and check results
|
||||||
|
# batches = list(iter(data_loader))
|
||||||
|
# assert len(batches) > 0
|
||||||
|
|
||||||
|
# # Check first batch
|
||||||
|
# one_batch = batches[0]
|
||||||
|
# micro_batches = one_batch[0]
|
||||||
|
# assert "input_ids" in micro_batches
|
||||||
|
# assert "attention_mask" in micro_batches
|
||||||
|
# assert "labels" in micro_batches
|
||||||
|
# assert micro_batches["input_ids"].shape[0] == 1 # batch_size=1
|
||||||
|
# assert micro_batches["input_ids"].ndim == 2 # [batch_size, seq_len]
|
||||||
|
|
||||||
|
|
||||||
|
# class TestDataLoaderNonPackDynamic:
|
||||||
|
# """Test case b) non pack + dynamic."""
|
||||||
|
|
||||||
|
# def test_basic_functionality(self):
|
||||||
|
# """Test DataLoader without packing but with dynamic batching."""
|
||||||
|
# # Create real dataset
|
||||||
|
# torch_dataloader, processor, template = create_real_dataset(max_samples=80, batch_size=8)
|
||||||
|
# collator = DefaultCollator(processor=processor, template=template)
|
||||||
|
|
||||||
|
# # Create batching queue for dynamic batching
|
||||||
|
# batching_queue = TextBatchingQueue(
|
||||||
|
# token_micro_bsz=120,
|
||||||
|
# buffer_size=8,
|
||||||
|
# )
|
||||||
|
|
||||||
|
# data_loader = DataLoader(
|
||||||
|
# dataloader=torch_dataloader,
|
||||||
|
# collate_fn=collator,
|
||||||
|
# num_micro_batch=4,
|
||||||
|
# batching_queue=batching_queue,
|
||||||
|
# )
|
||||||
|
|
||||||
|
# # Iterate and check
|
||||||
|
# batches = list(iter(data_loader))
|
||||||
|
# micro_batch_tokens_first = [micro_batch["attention_mask"].sum() for micro_batch in batches[0]]
|
||||||
|
# assert all(num_tokens <= 120 for num_tokens in micro_batch_tokens_first)
|
||||||
|
# assert len(batches) > 0
|
||||||
193
tests_v1/core/utils/test_rendering.py
Normal file
193
tests_v1/core/utils/test_rendering.py
Normal file
@@ -0,0 +1,193 @@
|
|||||||
|
# Copyright 2025 the LlamaFactory team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import json
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
|
from llamafactory.v1.config import DataArguments
|
||||||
|
from llamafactory.v1.core.data_engine import DataEngine
|
||||||
|
from llamafactory.v1.core.utils.rendering import Renderer
|
||||||
|
from llamafactory.v1.utils.types import Processor
|
||||||
|
|
||||||
|
|
||||||
|
HF_MESSAGES = [
|
||||||
|
{"role": "system", "content": "You are a helpful assistant."},
|
||||||
|
{"role": "user", "content": "What is LLM?"},
|
||||||
|
{"role": "assistant", "content": "LLM stands for Large Language Model."},
|
||||||
|
]
|
||||||
|
|
||||||
|
V1_MESSAGES = [
|
||||||
|
{"role": "system", "content": [{"type": "text", "value": "You are a helpful assistant."}]},
|
||||||
|
{"role": "user", "content": [{"type": "text", "value": "What is LLM?"}]},
|
||||||
|
{"role": "assistant", "content": [{"type": "text", "value": "LLM stands for Large Language Model."}]},
|
||||||
|
]
|
||||||
|
|
||||||
|
HF_MESSAGES_WITH_TOOLS = [
|
||||||
|
{"role": "system", "content": "You are a helpful assistant."},
|
||||||
|
{"role": "user", "content": "What is 6*8?"},
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"tool_calls": [{"type": "function", "function": {"name": "multiply", "arguments": {"a": 6, "b": 8}}}],
|
||||||
|
},
|
||||||
|
{"role": "tool", "content": "48."},
|
||||||
|
{"role": "assistant", "content": "The result of 6*8 is 48."},
|
||||||
|
]
|
||||||
|
|
||||||
|
V1_MESSAGES_WITH_TOOLS = [
|
||||||
|
{"role": "system", "content": [{"type": "text", "value": "You are a helpful assistant."}]},
|
||||||
|
{"role": "user", "content": [{"type": "text", "value": "What is 6*8?"}]},
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": [{"type": "tool_call", "value": json.dumps({"name": "multiply", "arguments": {"a": 6, "b": 8}})}],
|
||||||
|
"loss_weight": 0.0,
|
||||||
|
},
|
||||||
|
{"role": "tool", "content": [{"type": "text", "value": "48."}]},
|
||||||
|
{"role": "assistant", "content": [{"type": "text", "value": "The result of 6*8 is 48."}]},
|
||||||
|
]
|
||||||
|
|
||||||
|
V1_TOOLS = [
|
||||||
|
{
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "multiply",
|
||||||
|
"description": "A function that multiplies two numbers",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"a": {"type": "number", "description": "The first number to multiply"},
|
||||||
|
"b": {"type": "number", "description": "The second number to multiply"},
|
||||||
|
},
|
||||||
|
"required": ["a", "b"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def test_chatml_rendering():
    """The chatml renderer must reproduce HF ``apply_chat_template`` token ids."""
    tokenizer: Processor = AutoTokenizer.from_pretrained("llamafactory/tiny-random-qwen3")
    renderer = Renderer(template="chatml", processor=tokenizer)

    # Generation case: last assistant turn dropped, generation prompt appended,
    # nothing contributes to the loss.
    expected = tokenizer.apply_chat_template(HF_MESSAGES[:-1], add_generation_prompt=True)
    rendered = renderer.render_messages(V1_MESSAGES[:-1], is_generate=True)
    n = len(expected)
    assert rendered["input_ids"] == expected
    assert rendered["attention_mask"] == [1] * n
    assert rendered["labels"] == [-100] * n
    assert rendered["loss_weights"] == [0.0] * n

    # Training case: only the final assistant turn is supervised.
    prompt_ids = tokenizer.apply_chat_template(HF_MESSAGES[:-1], add_generation_prompt=False)
    full_ids = tokenizer.apply_chat_template(HF_MESSAGES, add_generation_prompt=False)
    rendered_full = renderer.render_messages(V1_MESSAGES, is_generate=False)
    prompt_len, full_len = len(prompt_ids), len(full_ids)
    assert rendered_full["input_ids"] == full_ids
    assert rendered_full["attention_mask"] == [1] * full_len
    assert rendered_full["labels"] == [-100] * prompt_len + full_ids[prompt_len:]
    assert rendered_full["loss_weights"] == [0.0] * prompt_len + [1.0] * (full_len - prompt_len)
|
||||||
|
|
||||||
|
|
||||||
|
def test_chatml_parse():
    """A plain-text completion parses back into the final V1 assistant message."""
    tokenizer: Processor = AutoTokenizer.from_pretrained("llamafactory/tiny-random-qwen3")
    renderer = Renderer(template="chatml", processor=tokenizer)
    assert renderer.parse_message("LLM stands for Large Language Model.") == V1_MESSAGES[-1]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("num_samples", [16])
def test_chatml_rendering_remote(num_samples: int):
    """Render samples from a remote dataset and verify the chatml user prefix."""
    tokenizer: Processor = AutoTokenizer.from_pretrained("llamafactory/tiny-random-qwen3")
    renderer = Renderer(template="chatml", processor=tokenizer)
    data_engine = DataEngine(DataArguments(dataset="llamafactory/v1-sft-demo"))
    # The expected prefix is loop-invariant, so encode it once.
    expected_prefix = tokenizer.encode("<|im_start|>user\n", add_special_tokens=False)
    for index in range(num_samples):
        rendered = renderer.render_messages(data_engine[index]["messages"], is_generate=True)
        head = rendered["input_ids"][: len(expected_prefix)]
        print(tokenizer.decode(head))
        assert head == expected_prefix
|
||||||
|
|
||||||
|
|
||||||
|
def test_qwen3_nothink_rendering():
    """qwen3_nothink rendering with tools must match HF chat-template output."""
    tokenizer: Processor = AutoTokenizer.from_pretrained("Qwen/Qwen3-4B-Instruct-2507")
    renderer = Renderer(template="qwen3_nothink", processor=tokenizer)
    tools_json = json.dumps(V1_TOOLS)

    # Generation case: fully unsupervised prompt ending in a generation cue.
    expected = tokenizer.apply_chat_template(HF_MESSAGES_WITH_TOOLS[:-1], tools=V1_TOOLS, add_generation_prompt=True)
    rendered = renderer.render_messages(V1_MESSAGES_WITH_TOOLS[:-1], tools=tools_json, is_generate=True)
    n = len(expected)
    assert rendered["input_ids"] == expected
    assert rendered["attention_mask"] == [1] * n
    assert rendered["labels"] == [-100] * n
    assert rendered["loss_weights"] == [0.0] * n

    # Training case: only the final assistant turn carries labels / loss weight.
    prompt_ids = tokenizer.apply_chat_template(
        HF_MESSAGES_WITH_TOOLS[:-1], tools=V1_TOOLS, add_generation_prompt=False
    )
    full_ids = tokenizer.apply_chat_template(HF_MESSAGES_WITH_TOOLS, tools=V1_TOOLS, add_generation_prompt=False)
    rendered_full = renderer.render_messages(V1_MESSAGES_WITH_TOOLS, tools=tools_json, is_generate=False)
    prompt_len, full_len = len(prompt_ids), len(full_ids)
    assert rendered_full["input_ids"] == full_ids
    assert rendered_full["attention_mask"] == [1] * full_len
    assert rendered_full["labels"] == [-100] * prompt_len + full_ids[prompt_len:]
    assert rendered_full["loss_weights"] == [0.0] * prompt_len + [1.0] * (full_len - prompt_len)
|
||||||
|
|
||||||
|
|
||||||
|
def test_qwen3_nothink_parse():
    """Generated text splits into reasoning, text, and tool_call content parts."""
    tokenizer: Processor = AutoTokenizer.from_pretrained("Qwen/Qwen3-4B-Instruct-2507")
    renderer = Renderer(template="qwen3_nothink", processor=tokenizer)
    generated_text = (
        "<thinking>I need to use the multiply function to calculate 6*8.</thinking>"
        "Let me call the multiply function."
        '<tool_call>{"name": "multiply", "arguments": {"a": 6, "b": 8}}</tool_call>'
    )
    expected = {
        "role": "assistant",
        "content": [
            {"type": "reasoning", "value": "I need to use the multiply function to calculate 6*8."},
            {"type": "text", "value": "Let me call the multiply function."},
            {"type": "tool_call", "value": json.dumps({"name": "multiply", "arguments": {"a": 6, "b": 8}})},
        ],
    }
    assert renderer.parse_message(generated_text) == expected
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("num_samples", [8])
def test_qwen3_nothink_rendering_remote(num_samples: int):
    """Render remote tool-use samples and verify the rendered system prefix."""
    tokenizer: Processor = AutoTokenizer.from_pretrained("Qwen/Qwen3-4B-Instruct-2507")
    renderer = Renderer(template="qwen3_nothink", processor=tokenizer)
    data_engine = DataEngine(DataArguments(dataset="llamafactory/reason-tool-use-demo-1500"))
    prefix_text = (
        "<|im_start|>system\nYou are a methodical and expert assistant. "
        "Your primary goal is to solve user requests by leveraging a set of available tools. "
        "You must reason for the best course of action in a structured manner before responding.\n\n"
        "# Tools\n\nYou may call one or more functions to assist with the user query.\n\n"
        "You are provided with function signatures within <tools></tools> XML tags:\n<tools>\n"
        '{"type": "function", "function": {"name":'
    )
    # Loop-invariant, so encode the expected prefix once up front.
    expected_prefix = tokenizer.encode(prefix_text, add_special_tokens=False)
    for index in range(num_samples):
        sample = data_engine[index]
        rendered = renderer.render_messages(sample["messages"], tools=sample["tools"])
        head = rendered["input_ids"][: len(expected_prefix)]
        print(tokenizer.decode(head))
        assert head == expected_prefix
|
||||||
|
|
||||||
|
|
||||||
|
# Allow running the suite directly without pytest. Sample counts mirror the
# ``@pytest.mark.parametrize`` values above (16 for chatml, 8 for qwen3_nothink);
# the original called the qwen3 remote test with 16, inconsistent with its
# parametrization.
if __name__ == "__main__":
    test_chatml_rendering()
    test_chatml_parse()
    test_chatml_rendering_remote(16)
    test_qwen3_nothink_rendering()
    test_qwen3_nothink_parse()
    test_qwen3_nothink_rendering_remote(8)
|
||||||
@@ -54,18 +54,18 @@ def test_sharegpt_converter():
|
|||||||
"conversations": [
|
"conversations": [
|
||||||
{"from": "system", "value": "System"},
|
{"from": "system", "value": "System"},
|
||||||
{"from": "human", "value": "User"},
|
{"from": "human", "value": "User"},
|
||||||
{"from": "function_call", "value": "Tool"},
|
{"from": "function_call", "value": "1"},
|
||||||
{"from": "observation", "value": "Observation"},
|
{"from": "observation", "value": "Observation"},
|
||||||
{"from": "gpt", "value": "Assistant"},
|
{"from": "gpt", "value": "Assistant"},
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
expected_data = {
|
expected_data = {
|
||||||
"messages": [
|
"messages": [
|
||||||
{"content": [{"type": "text", "value": "System"}], "loss_weight": 0.0, "role": "system"},
|
{"role": "system", "content": [{"type": "text", "value": "System"}], "loss_weight": 0.0},
|
||||||
{"content": [{"type": "text", "value": "User"}], "loss_weight": 0.0, "role": "user"},
|
{"role": "user", "content": [{"type": "text", "value": "User"}], "loss_weight": 0.0},
|
||||||
{"content": [{"type": "tool_calls", "value": "Tool"}], "loss_weight": 1.0, "role": "assistant"},
|
{"role": "assistant", "content": [{"type": "tool_call", "value": "1"}], "loss_weight": 1.0},
|
||||||
{"content": [{"type": "text", "value": "Observation"}], "loss_weight": 0.0, "role": "tool"},
|
{"role": "tool", "content": [{"type": "text", "value": "Observation"}], "loss_weight": 0.0},
|
||||||
{"content": [{"type": "text", "value": "Assistant"}], "loss_weight": 1.0, "role": "assistant"},
|
{"role": "assistant", "content": [{"type": "text", "value": "Assistant"}], "loss_weight": 1.0},
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
assert DataConverterPlugin("sharegpt")(example) == expected_data
|
assert DataConverterPlugin("sharegpt")(example) == expected_data
|
||||||
|
|||||||
54
tests_v1/plugins/model_plugins/test_init_plugin.py
Normal file
54
tests_v1/plugins/model_plugins/test_init_plugin.py
Normal file
@@ -0,0 +1,54 @@
|
|||||||
|
# Copyright 2025 the LlamaFactory team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
from llamafactory.v1.accelerator.interface import DistributedInterface
|
||||||
|
from llamafactory.v1.config.arg_parser import get_args
|
||||||
|
from llamafactory.v1.core.model_engine import ModelEngine
|
||||||
|
|
||||||
|
|
||||||
|
def test_init_on_meta():
    """``init_on_meta`` leaves the model parameters on the meta device."""
    _, model_args, *_ = get_args(
        {"model": "llamafactory/tiny-random-qwen2.5", "init_config": {"name": "init_on_meta"}}
    )
    engine = ModelEngine(model_args=model_args)
    assert engine.model.device.type == "meta"
|
||||||
|
|
||||||
|
|
||||||
|
def test_init_on_rank0():
    """``init_on_rank0`` materializes weights on CPU for rank 0, meta elsewhere."""
    _, model_args, *_ = get_args(
        {"model": "llamafactory/tiny-random-qwen2.5", "init_config": {"name": "init_on_rank0"}}
    )
    engine = ModelEngine(model_args=model_args)
    expected_device = "cpu" if DistributedInterface().get_rank() == 0 else "meta"
    assert engine.model.device.type == expected_device
|
||||||
|
|
||||||
|
|
||||||
|
def test_init_on_default():
    """``init_on_default`` places the model on the accelerator's current device."""
    _, model_args, *_ = get_args(
        {"model": "llamafactory/tiny-random-qwen2.5", "init_config": {"name": "init_on_default"}}
    )
    engine = ModelEngine(model_args=model_args)
    assert engine.model.device == DistributedInterface().current_device
|
||||||
41
tests_v1/sampler/test_cli_sampler.py
Normal file
41
tests_v1/sampler/test_cli_sampler.py
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
# Copyright 2025 the LlamaFactory team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from llamafactory.v1.config import ModelArguments, SampleArguments
|
||||||
|
from llamafactory.v1.core.model_engine import ModelEngine
|
||||||
|
from llamafactory.v1.samplers.cli_sampler import SyncSampler
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.runs_on(["cuda", "npu"])
def test_sync_sampler():
    """End-to-end check of the synchronous CLI sampler on a real model."""
    model_args = ModelArguments(model="Qwen/Qwen3-4B-Instruct-2507", template="qwen3_nothink")
    sample_args = SampleArguments()
    model_engine = ModelEngine(model_args)
    sampler = SyncSampler(sample_args, model_args, model_engine.model, model_engine.renderer)
    messages = [{"role": "user", "content": [{"type": "text", "value": "Say 'This is a test.'"}]}]
    # ``generate`` yields text deltas; concatenate them into the full response.
    response = "".join(sampler.generate(messages))

    print(response)
    assert model_engine.renderer.parse_message(response) == {
        "role": "assistant",
        "content": [{"type": "text", "value": "This is a test."}],
    }
|
||||||
|
|
||||||
|
|
||||||
|
# Allow running this test directly without pytest.
if __name__ == "__main__":
    test_sync_sampler()
|
||||||
Reference in New Issue
Block a user