mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2026-02-26 15:56:00 +08:00
Compare commits
4 Commits
0087bc253b
...
f60a6e3d01
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f60a6e3d01 | ||
|
|
81b8a50aa5 | ||
|
|
8600530002 | ||
|
|
9ae62c6fc0 |
3
.github/workflows/tests.yml
vendored
3
.github/workflows/tests.yml
vendored
@@ -70,7 +70,8 @@ jobs:
|
|||||||
run: |
|
run: |
|
||||||
uv venv
|
uv venv
|
||||||
uv pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
|
uv pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
|
||||||
uv pip install -e ".[dev]"
|
uv pip install -e .
|
||||||
|
uv pip install -r requirements/dev.txt
|
||||||
|
|
||||||
- name: Install transformers
|
- name: Install transformers
|
||||||
if: ${{ matrix.transformers }}
|
if: ${{ matrix.transformers }}
|
||||||
|
|||||||
3
.github/workflows/tests_cuda.yml
vendored
3
.github/workflows/tests_cuda.yml
vendored
@@ -52,7 +52,8 @@ jobs:
|
|||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: |
|
run: |
|
||||||
uv venv
|
uv venv
|
||||||
uv pip install -e ".[dev]"
|
uv pip install -e .
|
||||||
|
uv pip install -r requirements/dev.txt
|
||||||
|
|
||||||
- name: Cache HuggingFace models
|
- name: Cache HuggingFace models
|
||||||
id: hf-hub-cache
|
id: hf-hub-cache
|
||||||
|
|||||||
5
.github/workflows/tests_npu.yml
vendored
5
.github/workflows/tests_npu.yml
vendored
@@ -58,8 +58,9 @@ jobs:
|
|||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: |
|
run: |
|
||||||
uv venv
|
uv venv
|
||||||
uv pip install torch-npu==${{matrix.pytorch_npu}}
|
uv pip install -r requirements/npu.txt
|
||||||
uv pip install -e ".[dev]"
|
uv pip install -e .
|
||||||
|
uv pip install -r requirements/dev.txt
|
||||||
|
|
||||||
- name: Install node
|
- name: Install node
|
||||||
run: |
|
run: |
|
||||||
|
|||||||
41
README.md
41
README.md
@@ -329,6 +329,7 @@ Read technical notes:
|
|||||||
| [StarCoder 2](https://huggingface.co/bigcode) | 3B/7B/15B | - |
|
| [StarCoder 2](https://huggingface.co/bigcode) | 3B/7B/15B | - |
|
||||||
| [VibeThinker-1.5B](https://huggingface.co/WeiboAI) | 1.5B | qwen3 |
|
| [VibeThinker-1.5B](https://huggingface.co/WeiboAI) | 1.5B | qwen3 |
|
||||||
| [Yi/Yi-1.5 (Code)](https://huggingface.co/01-ai) | 1.5B/6B/9B/34B | yi |
|
| [Yi/Yi-1.5 (Code)](https://huggingface.co/01-ai) | 1.5B/6B/9B/34B | yi |
|
||||||
|
| [Youtu-LLM](https://huggingface.co/tencent/) | 2B | youtu |
|
||||||
| [Yuan 2](https://huggingface.co/IEITYuan) | 2B/51B/102B | yuan |
|
| [Yuan 2](https://huggingface.co/IEITYuan) | 2B/51B/102B | yuan |
|
||||||
|
|
||||||
> [!NOTE]
|
> [!NOTE]
|
||||||
@@ -516,10 +517,11 @@ huggingface-cli login
|
|||||||
```bash
|
```bash
|
||||||
git clone --depth 1 https://github.com/hiyouga/LLaMA-Factory.git
|
git clone --depth 1 https://github.com/hiyouga/LLaMA-Factory.git
|
||||||
cd LLaMA-Factory
|
cd LLaMA-Factory
|
||||||
pip install -e ".[metrics]"
|
pip install -e .
|
||||||
|
pip install -r requirements/metrics.txt
|
||||||
```
|
```
|
||||||
|
|
||||||
Optional dependencies available: `metrics`, `deepspeed`. Install with: `pip install -e ".[metrics,deepspeed]"`
|
Optional dependencies available: `metrics`, `deepspeed`. Install with: `pip install -e . && pip install -r requirements/metrics.txt -r requirements/deepspeed.txt`
|
||||||
|
|
||||||
Additional dependencies for specific features are available in `examples/requirements/`.
|
Additional dependencies for specific features are available in `examples/requirements/`.
|
||||||
|
|
||||||
@@ -577,36 +579,21 @@ To enable FlashAttention-2 on the Windows platform, please use the script from [
|
|||||||
|
|
||||||
<details><summary>For Ascend NPU users</summary>
|
<details><summary>For Ascend NPU users</summary>
|
||||||
|
|
||||||
To install LLaMA Factory on Ascend NPU devices, please upgrade Python to version 3.10 or higher: `pip install -e . torch-npu==2.7.1`. Additionally, you need to install the **[Ascend CANN Toolkit and Kernels](https://www.hiascend.com/developer/download/community/result?module=cann)**. Please follow the [installation tutorial](https://www.hiascend.com/document/detail/en/CANNCommunityEdition/600alphaX/softwareinstall/instg/atlasdeploy_03_0031.html) or use the following commands:
|
To install LLaMA Factory on Ascend NPU devices, please upgrade Python to version 3.10 or higher and install the NPU dependencies with `pip install -r requirements/npu.txt`. Additionally, you need to install the **Ascend CANN Toolkit and Kernels**. Please follow the [installation tutorial](https://llamafactory.readthedocs.io/en/latest/advanced/npu_installation.html).
|
||||||
|
|
||||||
|
|
||||||
|
You can also download the pre-built Docker images:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# replace the url according to your CANN version and devices
|
# Docker Hub
|
||||||
# install CANN Toolkit
|
docker pull hiyouga/llamafactory:latest-npu-a2
|
||||||
wget https://ascend-repo.obs.cn-east-2.myhuaweicloud.com/Milan-ASL/Milan-ASL%20V100R001C20SPC702/Ascend-cann-toolkit_8.0.0.alpha002_linux-"$(uname -i)".run
|
docker pull hiyouga/llamafactory:latest-npu-a3
|
||||||
bash Ascend-cann-toolkit_8.0.0.alpha002_linux-"$(uname -i)".run --install
|
|
||||||
|
|
||||||
# install CANN Kernels
|
# quay.io
|
||||||
wget https://ascend-repo.obs.cn-east-2.myhuaweicloud.com/Milan-ASL/Milan-ASL%20V100R001C20SPC702/Ascend-cann-kernels-910b_8.0.0.alpha002_linux-"$(uname -i)".run
|
docker pull quay.io/ascend/llamafactory:latest-npu-a2
|
||||||
bash Ascend-cann-kernels-910b_8.0.0.alpha002_linux-"$(uname -i)".run --install
|
docker pull quay.io/ascend/llamafactory:latest-npu-a3
|
||||||
|
|
||||||
# set env variables
|
|
||||||
source /usr/local/Ascend/ascend-toolkit/set_env.sh
|
|
||||||
```
|
```
|
||||||
|
|
||||||
| Requirement | Minimum | Recommend |
|
|
||||||
| ------------ | ------- | -------------- |
|
|
||||||
| CANN | 8.0.RC1 | 8.0.0.alpha002 |
|
|
||||||
| torch | 2.1.0 | 2.7.1 |
|
|
||||||
| torch-npu | 2.1.0 | 2.7.1 |
|
|
||||||
| deepspeed | 0.13.2 | 0.13.2 |
|
|
||||||
| vllm-ascend | - | 0.7.3 |
|
|
||||||
|
|
||||||
Remember to use `ASCEND_RT_VISIBLE_DEVICES` instead of `CUDA_VISIBLE_DEVICES` to specify the device to use.
|
|
||||||
|
|
||||||
If you cannot infer model on NPU devices, try setting `do_sample: false` in the configurations.
|
|
||||||
|
|
||||||
Download the pre-built Docker images: [32GB](http://mirrors.cn-central-221.ovaijisuan.com/detail/130.html) | [64GB](http://mirrors.cn-central-221.ovaijisuan.com/detail/131.html)
|
|
||||||
|
|
||||||
#### Install BitsAndBytes
|
#### Install BitsAndBytes
|
||||||
|
|
||||||
To use QLoRA based on bitsandbytes on Ascend NPU, please follow these 3 steps:
|
To use QLoRA based on bitsandbytes on Ascend NPU, please follow these 3 steps:
|
||||||
|
|||||||
40
README_zh.md
40
README_zh.md
@@ -331,6 +331,7 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
|
|||||||
| [StarCoder 2](https://huggingface.co/bigcode) | 3B/7B/15B | - |
|
| [StarCoder 2](https://huggingface.co/bigcode) | 3B/7B/15B | - |
|
||||||
| [VibeThinker-1.5B](https://huggingface.co/WeiboAI) | 1.5B | qwen3 |
|
| [VibeThinker-1.5B](https://huggingface.co/WeiboAI) | 1.5B | qwen3 |
|
||||||
| [Yi/Yi-1.5 (Code)](https://huggingface.co/01-ai) | 1.5B/6B/9B/34B | yi |
|
| [Yi/Yi-1.5 (Code)](https://huggingface.co/01-ai) | 1.5B/6B/9B/34B | yi |
|
||||||
|
| [Youtu-LLM](https://huggingface.co/tencent/) | 2B | youtu |
|
||||||
| [Yuan 2](https://huggingface.co/IEITYuan) | 2B/51B/102B | yuan |
|
| [Yuan 2](https://huggingface.co/IEITYuan) | 2B/51B/102B | yuan |
|
||||||
|
|
||||||
> [!NOTE]
|
> [!NOTE]
|
||||||
@@ -518,10 +519,11 @@ huggingface-cli login
|
|||||||
```bash
|
```bash
|
||||||
git clone --depth 1 https://github.com/hiyouga/LLaMA-Factory.git
|
git clone --depth 1 https://github.com/hiyouga/LLaMA-Factory.git
|
||||||
cd LLaMA-Factory
|
cd LLaMA-Factory
|
||||||
pip install -e ".[metrics]"
|
pip install -e .
|
||||||
|
pip install -r requirements/metrics.txt
|
||||||
```
|
```
|
||||||
|
|
||||||
可选的额外依赖项:`metrics`、`deepspeed`。使用 `pip install -e ".[metrics,deepspeed]"` 安装。
|
可选的额外依赖项:`metrics`、`deepspeed`。使用 `pip install -e . && pip install -r requirements/metrics.txt -r requirements/deepspeed.txt` 安装。
|
||||||
|
|
||||||
其他可选依赖项请参考 `examples/requirements/` 目录下的文件。
|
其他可选依赖项请参考 `examples/requirements/` 目录下的文件。
|
||||||
|
|
||||||
@@ -579,36 +581,20 @@ pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/downl
|
|||||||
|
|
||||||
<details><summary>昇腾 NPU 用户指南</summary>
|
<details><summary>昇腾 NPU 用户指南</summary>
|
||||||
|
|
||||||
在昇腾 NPU 设备上安装 LLaMA Factory 时,请升级 Python 到 3.10 及以上,并需要指定额外依赖项,使用 `pip install -e . torch-npu==2.7.1` 命令安装。此外,还需要安装 **[Ascend CANN Toolkit 与 Kernels](https://www.hiascend.com/developer/download/community/result?module=cann)**,安装方法请参考[安装教程](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/80RC2alpha002/quickstart/quickstart/quickstart_18_0004.html)或使用以下命令:
|
在昇腾 NPU 设备上安装 LLaMA Factory 时,请升级 Python 到 3.10 及以上,并需要指定额外依赖项,使用 `pip install -r requirements/npu.txt` 命令安装。此外,还需要安装 **Ascend CANN Toolkit 与 Kernels**,安装方法请参考[安装教程](https://llamafactory.readthedocs.io/zh-cn/latest/advanced/npu_installation.html)。
|
||||||
|
|
||||||
|
您可以直接下载预安装的最新 Docker 镜像:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# 请替换 URL 为 CANN 版本和设备型号对应的 URL
|
# Docker Hub
|
||||||
# 安装 CANN Toolkit
|
docker pull hiyouga/llamafactory:latest-npu-a2
|
||||||
wget https://ascend-repo.obs.cn-east-2.myhuaweicloud.com/Milan-ASL/Milan-ASL%20V100R001C17SPC701/Ascend-cann-toolkit_8.0.RC1.alpha001_linux-"$(uname -i)".run
|
docker pull hiyouga/llamafactory:latest-npu-a3
|
||||||
bash Ascend-cann-toolkit_8.0.RC1.alpha001_linux-"$(uname -i)".run --install
|
|
||||||
|
|
||||||
# 安装 CANN Kernels
|
# quay.io
|
||||||
wget https://ascend-repo.obs.cn-east-2.myhuaweicloud.com/Milan-ASL/Milan-ASL%20V100R001C17SPC701/Ascend-cann-kernels-910b_8.0.RC1.alpha001_linux.run
|
docker pull quay.io/ascend/llamafactory:latest-npu-a2
|
||||||
bash Ascend-cann-kernels-910b_8.0.RC1.alpha001_linux.run --install
|
docker pull quay.io/ascend/llamafactory:latest-npu-a3
|
||||||
|
|
||||||
# 设置环境变量
|
|
||||||
source /usr/local/Ascend/ascend-toolkit/set_env.sh
|
|
||||||
```
|
```
|
||||||
|
|
||||||
| 依赖项 | 至少 | 推荐 |
|
|
||||||
| ------------ | ------- | -------------- |
|
|
||||||
| CANN | 8.0.RC1 | 8.0.0.alpha002 |
|
|
||||||
| torch | 2.1.0 | 2.7.1 |
|
|
||||||
| torch-npu | 2.1.0 | 2.7.1 |
|
|
||||||
| deepspeed | 0.13.2 | 0.13.2 |
|
|
||||||
| vllm-ascend | - | 0.7.3 |
|
|
||||||
|
|
||||||
请使用 `ASCEND_RT_VISIBLE_DEVICES` 而非 `CUDA_VISIBLE_DEVICES` 来指定运算设备。
|
|
||||||
|
|
||||||
如果遇到无法正常推理的情况,请尝试设置 `do_sample: false`。
|
|
||||||
|
|
||||||
下载预构建 Docker 镜像:[32GB](http://mirrors.cn-central-221.ovaijisuan.com/detail/130.html) | [64GB](http://mirrors.cn-central-221.ovaijisuan.com/detail/131.html)
|
|
||||||
|
|
||||||
#### 安装 BitsAndBytes
|
#### 安装 BitsAndBytes
|
||||||
|
|
||||||
如果要在 Ascend NPU 上进行基于 bitsandbytes 的 QLoRA 量化微调,请执行如下步骤:
|
如果要在 Ascend NPU 上进行基于 bitsandbytes 的 QLoRA 量化微调,请执行如下步骤:
|
||||||
|
|||||||
@@ -32,7 +32,8 @@ RUN pip config set global.index-url "${PIP_INDEX}" && \
|
|||||||
COPY . /app
|
COPY . /app
|
||||||
|
|
||||||
# Install LLaMA Factory
|
# Install LLaMA Factory
|
||||||
RUN pip install --no-cache-dir --no-build-isolation -e ".[metrics,deepspeed]"
|
RUN pip install --no-cache-dir --no-build-isolation -e . && \
|
||||||
|
pip install --no-cache-dir --no-build-isolation -r requirements/metrics.txt -r requirements/deepspeed.txt
|
||||||
|
|
||||||
# Rebuild flash attention
|
# Rebuild flash attention
|
||||||
RUN if [ "${INSTALL_FLASHATTN}" == "true" ]; then \
|
RUN if [ "${INSTALL_FLASHATTN}" == "true" ]; then \
|
||||||
|
|||||||
@@ -60,7 +60,8 @@ WORKDIR /app
|
|||||||
COPY . /app
|
COPY . /app
|
||||||
|
|
||||||
# Install LLaMA Factory
|
# Install LLaMA Factory
|
||||||
RUN pip install --no-cache-dir -e ".[metrics]" --no-build-isolation
|
RUN pip install --no-cache-dir -e . --no-build-isolation && \
|
||||||
|
pip install --no-cache-dir -r requirements/metrics.txt --no-build-isolation
|
||||||
|
|
||||||
RUN pip install "git+https://github.com/alibaba/roll.git#subdirectory=mcore_adapter"
|
RUN pip install "git+https://github.com/alibaba/roll.git#subdirectory=mcore_adapter"
|
||||||
|
|
||||||
|
|||||||
@@ -35,7 +35,8 @@ COPY . /app
|
|||||||
# Install torch-npu
|
# Install torch-npu
|
||||||
RUN pip uninstall -y torch torchvision torchaudio && \
|
RUN pip uninstall -y torch torchvision torchaudio && \
|
||||||
pip install --no-cache-dir "torch==2.7.1" "torch-npu==2.7.1" "torchvision==0.22.1" "torchaudio==2.7.1" --index-url "${PYTORCH_INDEX}" && \
|
pip install --no-cache-dir "torch==2.7.1" "torch-npu==2.7.1" "torchvision==0.22.1" "torchaudio==2.7.1" --index-url "${PYTORCH_INDEX}" && \
|
||||||
pip install --no-cache-dir -e ".[metrics]" --no-build-isolation
|
pip install --no-cache-dir -e . --no-build-isolation && \
|
||||||
|
pip install --no-cache-dir -r requirements/metrics.txt --no-build-isolation
|
||||||
|
|
||||||
# Set up volumes
|
# Set up volumes
|
||||||
# VOLUME [ "/root/.cache/huggingface", "/app/shared_data", "/app/output" ]
|
# VOLUME [ "/root/.cache/huggingface", "/app/shared_data", "/app/output" ]
|
||||||
|
|||||||
@@ -34,7 +34,8 @@ COPY . /app
|
|||||||
|
|
||||||
# Reinstall pytorch rocm and install LLaMA Factory
|
# Reinstall pytorch rocm and install LLaMA Factory
|
||||||
RUN pip uninstall -y torch torchvision torchaudio && \
|
RUN pip uninstall -y torch torchvision torchaudio && \
|
||||||
pip install --no-cache-dir --no-build-isolation -e --pre ".[metrics,deepspeed]" --index-url "${PYTORCH_INDEX}"
|
pip install --no-cache-dir --no-build-isolation -e --pre . --index-url "${PYTORCH_INDEX}" && \
|
||||||
|
pip install --no-cache-dir --no-build-isolation -r requirements/metrics.txt -r requirements/deepspeed.txt --index-url "${PYTORCH_INDEX}"
|
||||||
|
|
||||||
# Rebuild flash attention
|
# Rebuild flash attention
|
||||||
RUN if [ "${INSTALL_FLASHATTN}" == "true" ]; then \
|
RUN if [ "${INSTALL_FLASHATTN}" == "true" ]; then \
|
||||||
|
|||||||
@@ -76,11 +76,6 @@ dependencies = [
|
|||||||
"sse-starlette"
|
"sse-starlette"
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.optional-dependencies]
|
|
||||||
dev = ["pre-commit", "ruff", "pytest", "build"]
|
|
||||||
metrics = ["nltk", "jieba", "rouge-chinese"]
|
|
||||||
deepspeed = ["deepspeed>=0.10.0,<=0.16.9"]
|
|
||||||
|
|
||||||
[project.scripts]
|
[project.scripts]
|
||||||
llamafactory-cli = "llamafactory.cli:main"
|
llamafactory-cli = "llamafactory.cli:main"
|
||||||
lmf = "llamafactory.cli:main"
|
lmf = "llamafactory.cli:main"
|
||||||
|
|||||||
1
requirements/deepspeed.txt
Normal file
1
requirements/deepspeed.txt
Normal file
@@ -0,0 +1 @@
|
|||||||
|
deepspeed>=0.10.0,<=0.16.9
|
||||||
4
requirements/dev.txt
Normal file
4
requirements/dev.txt
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
pre-commit
|
||||||
|
ruff
|
||||||
|
pytest
|
||||||
|
build
|
||||||
3
requirements/metrics.txt
Normal file
3
requirements/metrics.txt
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
nltk
|
||||||
|
jieba
|
||||||
|
rouge-chinese
|
||||||
4
requirements/npu.txt
Normal file
4
requirements/npu.txt
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
torch==2.7.1
|
||||||
|
torch-npu==2.7.1
|
||||||
|
torchvision==0.22.1
|
||||||
|
torchaudio==2.7.1
|
||||||
@@ -28,7 +28,7 @@ try:
|
|||||||
jieba.setLogLevel(logging.CRITICAL)
|
jieba.setLogLevel(logging.CRITICAL)
|
||||||
jieba.initialize()
|
jieba.initialize()
|
||||||
except ImportError:
|
except ImportError:
|
||||||
print("Please install llamafactory with `pip install -e .[metrics]`.")
|
print("Please install llamafactory with `pip install -r requirements/metrics.txt`.")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -2278,6 +2278,21 @@ register_template(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
register_template(
|
||||||
|
name="youtu",
|
||||||
|
format_user=StringFormatter(slots=["<|User|>{{content}}<|Assistant|>"]),
|
||||||
|
format_assistant=StringFormatter(slots=["{{content}}<|end_of_text|>"]),
|
||||||
|
format_system=StringFormatter(slots=["{{content}}"]),
|
||||||
|
format_function=FunctionFormatter(slots=["{{content}}"], tool_format="default"),
|
||||||
|
format_observation=StringFormatter(slots=["<tool_response>\n{{content}}\n</tool_response><|Assistant|>"]),
|
||||||
|
format_tools=ToolFormatter(tool_format="default"),
|
||||||
|
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
||||||
|
stop_words=["<|end_of_text|>"],
|
||||||
|
replace_eos=True,
|
||||||
|
template_class=ReasoningTemplate,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
register_template(
|
register_template(
|
||||||
name="yuan",
|
name="yuan",
|
||||||
format_user=StringFormatter(slots=["{{content}}", {"token": "<sep>"}]),
|
format_user=StringFormatter(slots=["{{content}}", {"token": "<sep>"}]),
|
||||||
|
|||||||
@@ -3846,6 +3846,21 @@ register_model_group(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
register_model_group(
|
||||||
|
models={
|
||||||
|
"Youtu-LLM-2B-Instruct": {
|
||||||
|
DownloadSource.DEFAULT: "tencent/Youtu-LLM-2B",
|
||||||
|
DownloadSource.MODELSCOPE: "Tencent-YouTu-Research/Youtu-LLM-2B",
|
||||||
|
},
|
||||||
|
"Youtu-LLM-2B-Base": {
|
||||||
|
DownloadSource.DEFAULT: "tencent/Youtu-LLM-2B-Base",
|
||||||
|
DownloadSource.MODELSCOPE: "Tencent-YouTu-Research/Youtu-LLM-2B-Base",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
template="youtu",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
register_model_group(
|
register_model_group(
|
||||||
models={
|
models={
|
||||||
"Yuan2-2B-Chat": {
|
"Yuan2-2B-Chat": {
|
||||||
|
|||||||
@@ -142,6 +142,7 @@ def _verify_model_args(
|
|||||||
logger.warning_rank0("We should use slow tokenizer for the Yi models. Change `use_fast_tokenizer` to False.")
|
logger.warning_rank0("We should use slow tokenizer for the Yi models. Change `use_fast_tokenizer` to False.")
|
||||||
model_args.use_fast_tokenizer = False
|
model_args.use_fast_tokenizer = False
|
||||||
|
|
||||||
|
|
||||||
def _check_extra_dependencies(
|
def _check_extra_dependencies(
|
||||||
model_args: "ModelArguments",
|
model_args: "ModelArguments",
|
||||||
finetuning_args: "FinetuningArguments",
|
finetuning_args: "FinetuningArguments",
|
||||||
|
|||||||
@@ -94,6 +94,7 @@ class RayArguments:
|
|||||||
@dataclass
|
@dataclass
|
||||||
class Fp8Arguments:
|
class Fp8Arguments:
|
||||||
r"""Arguments pertaining to the FP8 training."""
|
r"""Arguments pertaining to the FP8 training."""
|
||||||
|
|
||||||
fp8: bool = field(
|
fp8: bool = field(
|
||||||
default=False,
|
default=False,
|
||||||
metadata={
|
metadata={
|
||||||
|
|||||||
@@ -139,14 +139,13 @@ def patch_config(
|
|||||||
setattr(config.text_config, "topk_method", "greedy")
|
setattr(config.text_config, "topk_method", "greedy")
|
||||||
|
|
||||||
architectures = getattr(config, "architectures", None)
|
architectures = getattr(config, "architectures", None)
|
||||||
|
if isinstance(architectures, list) and "InternVLChatModel" in architectures:
|
||||||
if isinstance(architectures, (list, tuple)) and "InternVLChatModel" in architectures:
|
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Please download the internvl models in a Hugging Face–compatible format "
|
"Please download the internvl models in a Hugging Face–compatible format "
|
||||||
"(for example, https://huggingface.co/OpenGVLab/InternVL3-8B-hf)."
|
"(for example, https://huggingface.co/OpenGVLab/InternVL3-8B-hf)."
|
||||||
)
|
)
|
||||||
|
|
||||||
if isinstance(architectures, (list, tuple)) and "LlavaLlamaForCausalLM" in architectures:
|
if isinstance(architectures, list) and "LlavaLlamaForCausalLM" in architectures:
|
||||||
raise ValueError("Please download llava models with hf-compatible format: https://huggingface.co/llava-hf")
|
raise ValueError("Please download llava models with hf-compatible format: https://huggingface.co/llava-hf")
|
||||||
|
|
||||||
if getattr(config, "model_type", None) == "internlm3" and not is_transformers_version_greater_than("4.47.1"):
|
if getattr(config, "model_type", None) == "internlm3" and not is_transformers_version_greater_than("4.47.1"):
|
||||||
|
|||||||
@@ -93,7 +93,10 @@ def create_fp8_kwargs(training_args: "TrainingArguments") -> list[Any]:
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
# Map FSDP all-gather setting if available (this affects the underlying implementation)
|
# Map FSDP all-gather setting if available (this affects the underlying implementation)
|
||||||
if hasattr(training_args, "fp8_enable_fsdp_float8_all_gather") and training_args.fp8_enable_fsdp_float8_all_gather:
|
if (
|
||||||
|
hasattr(training_args, "fp8_enable_fsdp_float8_all_gather")
|
||||||
|
and training_args.fp8_enable_fsdp_float8_all_gather
|
||||||
|
):
|
||||||
logger.info_rank0("FSDP float8 all-gather optimization requested")
|
logger.info_rank0("FSDP float8 all-gather optimization requested")
|
||||||
|
|
||||||
return [AORecipeKwargs(config=config, module_filter_func=module_filter_func)]
|
return [AORecipeKwargs(config=config, module_filter_func=module_filter_func)]
|
||||||
|
|||||||
@@ -19,7 +19,6 @@ import torch
|
|||||||
from transformers import Trainer
|
from transformers import Trainer
|
||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
|
|
||||||
from ...extras.packages import is_transformers_version_greater_than
|
|
||||||
from ..callbacks import SaveProcessorCallback
|
from ..callbacks import SaveProcessorCallback
|
||||||
from ..fp8_utils import configure_fp8_environment, patch_accelerator_for_fp8, verify_fp8_status
|
from ..fp8_utils import configure_fp8_environment, patch_accelerator_for_fp8, verify_fp8_status
|
||||||
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
|
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
|
||||||
@@ -28,7 +27,7 @@ from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
|
|||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from transformers import ProcessorMixin
|
from transformers import ProcessorMixin
|
||||||
|
|
||||||
from ...hparams import FinetuningArguments, ModelArguments
|
from ...hparams import FinetuningArguments, ModelArguments, TrainingArguments
|
||||||
|
|
||||||
|
|
||||||
class CustomTrainer(Trainer):
|
class CustomTrainer(Trainer):
|
||||||
@@ -43,7 +42,7 @@ class CustomTrainer(Trainer):
|
|||||||
) -> None:
|
) -> None:
|
||||||
kwargs["processing_class"] = kwargs.pop("tokenizer")
|
kwargs["processing_class"] = kwargs.pop("tokenizer")
|
||||||
# Configure FP8 environment if enabled
|
# Configure FP8 environment if enabled
|
||||||
training_args = kwargs.get("args")
|
training_args: TrainingArguments = kwargs.get("args")
|
||||||
if training_args.fp8:
|
if training_args.fp8:
|
||||||
configure_fp8_environment(training_args)
|
configure_fp8_environment(training_args)
|
||||||
if getattr(training_args, "fp8_backend", "auto") == "te":
|
if getattr(training_args, "fp8_backend", "auto") == "te":
|
||||||
@@ -66,7 +65,7 @@ class CustomTrainer(Trainer):
|
|||||||
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
|
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
|
||||||
self.add_callback(BAdamCallback)
|
self.add_callback(BAdamCallback)
|
||||||
|
|
||||||
if training_args.fp8 and hasattr(self, "accelerator"): # verify FP8 status after trainer initialization
|
if training_args.fp8 and hasattr(self, "accelerator"): # verify FP8 status after trainer initialization
|
||||||
verify_fp8_status(self.accelerator, training_args)
|
verify_fp8_status(self.accelerator, training_args)
|
||||||
|
|
||||||
@override
|
@override
|
||||||
|
|||||||
@@ -27,7 +27,6 @@ from typing_extensions import override
|
|||||||
|
|
||||||
from ...extras import logging
|
from ...extras import logging
|
||||||
from ...extras.constants import IGNORE_INDEX
|
from ...extras.constants import IGNORE_INDEX
|
||||||
from ...extras.packages import is_transformers_version_greater_than
|
|
||||||
from ..callbacks import SaveProcessorCallback
|
from ..callbacks import SaveProcessorCallback
|
||||||
from ..fp8_utils import configure_fp8_environment, patch_accelerator_for_fp8, verify_fp8_status
|
from ..fp8_utils import configure_fp8_environment, patch_accelerator_for_fp8, verify_fp8_status
|
||||||
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
|
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
|
||||||
@@ -35,10 +34,10 @@ from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
|
|||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from torch.utils.data import Dataset
|
from torch.utils.data import Dataset
|
||||||
from transformers import PreTrainedTokenizer, ProcessorMixin
|
from transformers import ProcessorMixin
|
||||||
from transformers.trainer import PredictionOutput
|
from transformers.trainer import PredictionOutput
|
||||||
|
|
||||||
from ...hparams import FinetuningArguments, ModelArguments
|
from ...hparams import FinetuningArguments, ModelArguments, TrainingArguments
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
@@ -57,7 +56,7 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
|
|||||||
) -> None:
|
) -> None:
|
||||||
kwargs["processing_class"] = kwargs.pop("tokenizer")
|
kwargs["processing_class"] = kwargs.pop("tokenizer")
|
||||||
# Configure FP8 environment if enabled
|
# Configure FP8 environment if enabled
|
||||||
training_args = kwargs.get("args")
|
training_args: TrainingArguments = kwargs.get("args")
|
||||||
if training_args.fp8:
|
if training_args.fp8:
|
||||||
configure_fp8_environment(training_args)
|
configure_fp8_environment(training_args)
|
||||||
if getattr(training_args, "fp8_backend", "auto") == "te":
|
if getattr(training_args, "fp8_backend", "auto") == "te":
|
||||||
@@ -88,7 +87,7 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
|
|||||||
|
|
||||||
self.compute_loss_func = dft_loss_func
|
self.compute_loss_func = dft_loss_func
|
||||||
|
|
||||||
if training_args.fp8 and hasattr(self, "accelerator"): # verify FP8 status after trainer initialization
|
if training_args.fp8 and hasattr(self, "accelerator"): # verify FP8 status after trainer initialization
|
||||||
verify_fp8_status(self.accelerator, training_args)
|
verify_fp8_status(self.accelerator, training_args)
|
||||||
|
|
||||||
@override
|
@override
|
||||||
|
|||||||
@@ -34,10 +34,14 @@ from typing import Any, Optional
|
|||||||
from torch.distributed import barrier, destroy_process_group, init_process_group
|
from torch.distributed import barrier, destroy_process_group, init_process_group
|
||||||
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
|
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
|
||||||
|
|
||||||
|
from ..utils import logging
|
||||||
from ..utils.types import DistributedConfig, ProcessGroup, Tensor, TensorLike
|
from ..utils.types import DistributedConfig, ProcessGroup, Tensor, TensorLike
|
||||||
from . import helper
|
from . import helper
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class Dim(str, Enum):
|
class Dim(str, Enum):
|
||||||
"""Dimension names."""
|
"""Dimension names."""
|
||||||
|
|
||||||
@@ -157,6 +161,7 @@ class DistributedInterface:
|
|||||||
self.data_device_mesh = None
|
self.data_device_mesh = None
|
||||||
|
|
||||||
self._initialized = True
|
self._initialized = True
|
||||||
|
logger.info_rank0(f"DistributedInterface initialized with strategy={self.strategy}.")
|
||||||
|
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
return (
|
return (
|
||||||
|
|||||||
@@ -0,0 +1,32 @@
|
|||||||
|
# Copyright 2025 the LlamaFactory team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from .arg_parser import InputArgument, get_args
|
||||||
|
from .arg_utils import ModelClass, SampleBackend
|
||||||
|
from .data_args import DataArguments
|
||||||
|
from .model_args import ModelArguments
|
||||||
|
from .sample_args import SampleArguments
|
||||||
|
from .training_args import TrainingArguments
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"DataArguments",
|
||||||
|
"InputArgument",
|
||||||
|
"ModelArguments",
|
||||||
|
"ModelClass",
|
||||||
|
"SampleArguments",
|
||||||
|
"SampleBackend",
|
||||||
|
"TrainingArguments",
|
||||||
|
"get_args",
|
||||||
|
]
|
||||||
|
|||||||
@@ -27,14 +27,14 @@ class ModelArguments:
|
|||||||
default=False,
|
default=False,
|
||||||
metadata={"help": "Trust remote code from Hugging Face."},
|
metadata={"help": "Trust remote code from Hugging Face."},
|
||||||
)
|
)
|
||||||
use_fast_processor: bool = field(
|
|
||||||
default=True,
|
|
||||||
metadata={"help": "Use fast processor from Hugging Face."},
|
|
||||||
)
|
|
||||||
model_class: ModelClass = field(
|
model_class: ModelClass = field(
|
||||||
default=ModelClass.LLM,
|
default=ModelClass.LLM,
|
||||||
metadata={"help": "Model class from Hugging Face."},
|
metadata={"help": "Model class from Hugging Face."},
|
||||||
)
|
)
|
||||||
|
init_config: PluginConfig | None = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "Initialization configuration for the model."},
|
||||||
|
)
|
||||||
peft_config: PluginConfig | None = field(
|
peft_config: PluginConfig | None = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={"help": "PEFT configuration for the model."},
|
metadata={"help": "PEFT configuration for the model."},
|
||||||
@@ -49,6 +49,7 @@ class ModelArguments:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def __post_init__(self) -> None:
|
def __post_init__(self) -> None:
|
||||||
|
self.init_config = get_plugin_config(self.init_config)
|
||||||
self.peft_config = get_plugin_config(self.peft_config)
|
self.peft_config = get_plugin_config(self.peft_config)
|
||||||
self.kernel_config = get_plugin_config(self.kernel_config)
|
self.kernel_config = get_plugin_config(self.kernel_config)
|
||||||
self.quant_config = get_plugin_config(self.quant_config)
|
self.quant_config = get_plugin_config(self.quant_config)
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ from .arg_utils import PluginConfig, get_plugin_config
|
|||||||
@dataclass
|
@dataclass
|
||||||
class TrainingArguments:
|
class TrainingArguments:
|
||||||
output_dir: str = field(
|
output_dir: str = field(
|
||||||
default=os.path.join("outputs", str(uuid4())),
|
default=os.path.join("outputs", str(uuid4().hex)),
|
||||||
metadata={"help": "Path to the output directory."},
|
metadata={"help": "Path to the output directory."},
|
||||||
)
|
)
|
||||||
micro_batch_size: int = field(
|
micro_batch_size: int = field(
|
||||||
|
|||||||
77
src/llamafactory/v1/core/base_sampler.py
Normal file
77
src/llamafactory/v1/core/base_sampler.py
Normal file
@@ -0,0 +1,77 @@
|
|||||||
|
# Copyright 2025 the LlamaFactory team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
|
from ..config import ModelArguments, SampleArguments, SampleBackend
|
||||||
|
from ..utils.types import HFModel, Processor, TorchDataset
|
||||||
|
|
||||||
|
|
||||||
|
class BaseEngine(ABC):
    """Abstract interface shared by all sampling backends.

    Concrete engines must provide a constructor plus the two coroutine entry
    points declared below.
    """

    @abstractmethod
    def __init__(
        self,
        args: SampleArguments,
        model_args: ModelArguments,
        model: HFModel = None,
        processor: Processor = None,
    ) -> None:
        """Set up the engine.

        Args:
            args: Sample arguments.
            model_args: Model arguments.
            model: Model; may be omitted (defaults to ``None``).
            processor: Processor; may be omitted (defaults to ``None``).
        """

    @abstractmethod
    async def generate(self, messages):
        """Produce a response for ``messages``; implemented by subclasses."""

    @abstractmethod
    async def batch_infer(self, data: TorchDataset) -> None:
        """Run inference over every sample in ``data``; implemented by subclasses."""
|
||||||
|
|
||||||
|
|
||||||
|
class HuggingFaceEngine(BaseEngine):
    """Sampling engine backed by an in-process Hugging Face model.

    ``BaseEngine`` declares ``generate`` and ``batch_infer`` as abstract, so a
    subclass that does not define them cannot be instantiated (``ABC`` raises
    ``TypeError``). Both are therefore defined here; until real generation is
    wired in they raise ``NotImplementedError`` when awaited, which keeps the
    class constructible while making the missing functionality explicit.
    """

    def __init__(
        self,
        args: SampleArguments,
        model_args: ModelArguments,
        model: HFModel,
        processor: Processor,
    ) -> None:
        """Keep references to everything the engine is constructed with.

        Args:
            args: Sample arguments.
            model_args: Model arguments.
            model: Model.
            processor: Processor.
        """
        self.args = args
        # Retained so the generation methods can use them; previously these
        # constructor arguments were dropped.
        self.model_args = model_args
        self.model = model
        self.processor = processor

    async def generate(self, messages):
        """Generate a response for ``messages``.

        Raises:
            NotImplementedError: Always, generation is not implemented yet.
        """
        raise NotImplementedError("HuggingFaceEngine.generate is not implemented yet.")

    async def batch_infer(self, data: TorchDataset) -> None:
        """Run inference over ``data``.

        Raises:
            NotImplementedError: Always, batch inference is not implemented yet.
        """
        raise NotImplementedError("HuggingFaceEngine.batch_infer is not implemented yet.")
|
||||||
|
|
||||||
|
|
||||||
|
class BaseSampler:
    """Facade that picks a sampling engine and forwards calls to it."""

    def __init__(
        self,
        args: SampleArguments,
        model_args: ModelArguments,
        model: HFModel,
        processor: Processor,
    ) -> None:
        """Create the engine selected by ``args.sample_backend``.

        Args:
            args: Sample arguments; ``sample_backend`` chooses the engine.
            model_args: Model arguments, forwarded to the engine.
            model: Model, forwarded to the engine.
            processor: Processor, forwarded to the engine.

        Raises:
            ValueError: If the backend is anything other than ``SampleBackend.HF``.
        """
        # Guard clause: HF is the only backend handled here.
        if args.sample_backend != SampleBackend.HF:
            raise ValueError(f"Unknown sample backend: {args.sample_backend}")

        self.engine = HuggingFaceEngine(args, model_args, model, processor)

    async def generate(self, messages):
        """Delegate to the engine's ``generate`` and return its result."""
        response = await self.engine.generate(messages)
        return response

    async def batch_infer(self, data: TorchDataset) -> None:
        """Delegate to the engine's ``batch_infer`` and return its result."""
        outcome = await self.engine.batch_infer(data)
        return outcome
|
||||||
@@ -1,44 +0,0 @@
|
|||||||
# Copyright 2025 the LlamaFactory team.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
|
|
||||||
from ..config.sample_args import SampleArguments, SampleBackend
|
|
||||||
from .model_loader import ModelLoader
|
|
||||||
|
|
||||||
|
|
||||||
class BaseEngine(ABC):
|
|
||||||
@abstractmethod
|
|
||||||
def __init__(self, sample_args: SampleArguments, model_loader: ModelLoader) -> None: ...
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
async def generate(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
async def batch_infer(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class HuggingFaceEngine(BaseEngine):
|
|
||||||
def __init__(self, model_loader: ModelLoader, sample_args: SampleArguments) -> None:
|
|
||||||
self.args = sample_args
|
|
||||||
|
|
||||||
|
|
||||||
class ChatSampler:
|
|
||||||
def __init__(self, model_loader: ModelLoader, sample_args: SampleArguments) -> None:
|
|
||||||
if sample_args.sample_backend == SampleBackend.HF:
|
|
||||||
self.engine = HuggingFaceEngine(model_loader, sample_args)
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unknown sample backend: {sample_args.sample_backend}")
|
|
||||||
@@ -14,17 +14,24 @@
|
|||||||
|
|
||||||
"""The definition of model loader.
|
"""The definition of model loader.
|
||||||
|
|
||||||
Init Phase:
|
How to use:
|
||||||
|
model_loader = ModelLoader(model_args, is_trainable=True)
|
||||||
|
model_loader.processor: Get the tokenizer or multi-modal processor.
|
||||||
|
model_loader.model_config: Get the model configuration.
|
||||||
|
model_loader.model: Get the HF model.
|
||||||
|
|
||||||
|
Init Workflow:
|
||||||
1. Init processor.
|
1. Init processor.
|
||||||
2. Init model config.
|
2. Init model config.
|
||||||
3. Init model.
|
3. Init model.
|
||||||
4. Init adapter.
|
4. Init adapter.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from accelerate import init_empty_weights
|
||||||
from transformers import AutoConfig, AutoProcessor
|
from transformers import AutoConfig, AutoProcessor
|
||||||
|
|
||||||
|
from ..accelerator.helper import DeviceType
|
||||||
from ..accelerator.interface import DistributedInterface
|
from ..accelerator.interface import DistributedInterface
|
||||||
from ..config.model_args import ModelArguments, ModelClass
|
from ..config.model_args import ModelArguments, ModelClass
|
||||||
from ..utils import logging
|
from ..utils import logging
|
||||||
@@ -55,11 +62,14 @@ class ModelLoader:
|
|||||||
"""HF model."""
|
"""HF model."""
|
||||||
|
|
||||||
def _init_processor(self) -> Processor:
|
def _init_processor(self) -> Processor:
|
||||||
"""Init processor."""
|
"""Init processor.
|
||||||
|
|
||||||
|
NOTE: Transformers v5 always uses the fast tokenizer.
|
||||||
|
https://github.com/huggingface/transformers/blob/v5.0.0rc1/src/transformers/models/auto/tokenization_auto.py#L642
|
||||||
|
"""
|
||||||
return AutoProcessor.from_pretrained(
|
return AutoProcessor.from_pretrained(
|
||||||
self.args.model,
|
self.args.model,
|
||||||
trust_remote_code=self.args.trust_remote_code,
|
trust_remote_code=self.args.trust_remote_code,
|
||||||
use_fast=self.args.use_fast_processor,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def _init_model_config(self) -> HFConfig:
|
def _init_model_config(self) -> HFConfig:
|
||||||
@@ -92,14 +102,24 @@ class ModelLoader:
|
|||||||
|
|
||||||
AutoClass = AutoModel
|
AutoClass = AutoModel
|
||||||
|
|
||||||
# map the entire model to the current accelerator
|
if self.args.init_config is not None:
|
||||||
model = AutoClass.from_pretrained(
|
from ..plugins.model_plugins.initialization import InitPlugin
|
||||||
self.args.model,
|
|
||||||
config=self.model_config,
|
init_device = InitPlugin(self.args.init_config.name)()
|
||||||
dtype="auto",
|
else:
|
||||||
device_map=DistributedInterface().current_accelerator,
|
init_device = DistributedInterface().current_accelerator
|
||||||
trust_remote_code=self.args.trust_remote_code,
|
|
||||||
)
|
if init_device.type == DeviceType.META:
|
||||||
|
with init_empty_weights():
|
||||||
|
model = AutoClass.from_config(self.model_config)
|
||||||
|
else:
|
||||||
|
model = AutoClass.from_pretrained(
|
||||||
|
self.args.model,
|
||||||
|
config=self.model_config,
|
||||||
|
dtype="auto",
|
||||||
|
device_map=init_device,
|
||||||
|
trust_remote_code=self.args.trust_remote_code,
|
||||||
|
)
|
||||||
|
|
||||||
if self.args.peft_config is None:
|
if self.args.peft_config is None:
|
||||||
if self.is_train:
|
if self.is_train:
|
||||||
|
|||||||
@@ -0,0 +1,43 @@
|
|||||||
|
# Copyright 2025 the LlamaFactory team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from ...accelerator.helper import DeviceType
|
||||||
|
from ...accelerator.interface import DistributedInterface
|
||||||
|
from ...utils.plugin import BasePlugin
|
||||||
|
|
||||||
|
|
||||||
|
class InitPlugin(BasePlugin):
    """Plugin whose call returns the device a model should be initialized on."""

    def __call__(self) -> torch.device:
        """Invoke the base plugin call and return the resulting device."""
        device = super().__call__()
        return device
|
||||||
|
|
||||||
|
|
||||||
|
@InitPlugin("init_on_meta").register
def init_on_meta() -> torch.device:
    """Return the meta device."""
    meta_name = DeviceType.META.value
    return torch.device(meta_name)
|
||||||
|
|
||||||
|
|
||||||
|
@InitPlugin("init_on_rank0").register
def init_on_rank0() -> torch.device:
    """Return the CPU device on rank 0 and the meta device on every other rank."""
    on_first_rank = DistributedInterface().get_rank() == 0
    chosen = DeviceType.CPU if on_first_rank else DeviceType.META
    return torch.device(chosen.value)
|
||||||
|
|
||||||
|
|
||||||
|
@InitPlugin("init_on_default").register
def init_on_default() -> torch.device:
    """Return the distributed interface's current accelerator device."""
    interface = DistributedInterface()
    return interface.current_accelerator
|
||||||
|
|||||||
35
src/llamafactory/v1/samplers/cli_sampler.py
Normal file
35
src/llamafactory/v1/samplers/cli_sampler.py
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
# Copyright 2025 the LlamaFactory team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
from ..config import InputArgument, SampleBackend, get_args
|
||||||
|
from ..core.base_sampler import BaseSampler
|
||||||
|
from ..core.model_loader import ModelLoader
|
||||||
|
|
||||||
|
|
||||||
|
def run_chat(args: InputArgument = None):
    """Load the model and start sampling from the command line.

    Args:
        args: Raw input arguments forwarded to ``get_args``; ``None`` leaves
            argument discovery to ``get_args``.
    """
    # Local import: the helper is not among this module's top-level imports.
    # NOTE(review): assumes it is importable from ``..config.arg_utils``, the
    # module the config classes import it from (``.arg_utils``) — confirm.
    from ..config.arg_utils import get_plugin_config

    data_args, model_args, _, sample_args = get_args(args)
    if sample_args.sample_backend != SampleBackend.HF:
        # ModelLoader reads ``model_args.init_config`` (and its ``.name``), so the
        # override must target that field and be a parsed plugin config. The
        # previous assignment to ``init_plugin`` was never read by the loader.
        model_args.init_config = get_plugin_config({"name": "init_on_meta"})

    model_loader = ModelLoader(model_args)
    sampler = BaseSampler(sample_args, model_args, model_loader.model, model_loader.processor)
    # NOTE(review): ``BaseSampler.batch_infer(data)`` and ``generate(messages)``
    # are coroutines with a required argument. The calls below pass nothing and
    # are not awaited, so they cannot work as written. Supplying the dataset /
    # messages and driving the event loop still has to be wired in.
    if data_args.dataset is not None:
        sampler.batch_infer()
    else:
        sampler.generate()


if __name__ == "__main__":
    run_chat()
|
||||||
@@ -12,7 +12,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
"""LLaMA-Factory test configuration.
|
"""LlamaFactory test configuration.
|
||||||
|
|
||||||
Contains shared fixtures, pytest configuration, and custom markers.
|
Contains shared fixtures, pytest configuration, and custom markers.
|
||||||
"""
|
"""
|
||||||
@@ -110,11 +110,10 @@ def _handle_device_visibility(items: list[Item]):
|
|||||||
def pytest_collection_modifyitems(config: Config, items: list[Item]):
|
def pytest_collection_modifyitems(config: Config, items: list[Item]):
|
||||||
"""Modify test collection based on markers and environment."""
|
"""Modify test collection based on markers and environment."""
|
||||||
# Handle version compatibility (from HEAD)
|
# Handle version compatibility (from HEAD)
|
||||||
if not is_transformers_version_greater_than("4.57.0"):
|
skip_bc = pytest.mark.skip(reason="Skip backward compatibility tests")
|
||||||
skip_bc = pytest.mark.skip(reason="Skip backward compatibility tests")
|
for item in items:
|
||||||
for item in items:
|
if "tests_v1" in str(item.fspath) and not is_transformers_version_greater_than("4.57.0"):
|
||||||
if "tests_v1" in str(item.fspath):
|
item.add_marker(skip_bc)
|
||||||
item.add_marker(skip_bc)
|
|
||||||
|
|
||||||
_handle_slow_tests(items)
|
_handle_slow_tests(items)
|
||||||
_handle_runs_on(items)
|
_handle_runs_on(items)
|
||||||
@@ -156,6 +155,7 @@ def _manage_distributed_env(request: FixtureRequest, monkeypatch: MonkeyPatch) -
|
|||||||
monkeypatch.setenv(env_key, visible_devices[0] if visible_devices else "0")
|
monkeypatch.setenv(env_key, visible_devices[0] if visible_devices else "0")
|
||||||
else:
|
else:
|
||||||
monkeypatch.setenv(env_key, "0")
|
monkeypatch.setenv(env_key, "0")
|
||||||
|
|
||||||
if CURRENT_DEVICE == "cuda":
|
if CURRENT_DEVICE == "cuda":
|
||||||
monkeypatch.setattr(torch.cuda, "device_count", lambda: 1)
|
monkeypatch.setattr(torch.cuda, "device_count", lambda: 1)
|
||||||
elif CURRENT_DEVICE == "npu":
|
elif CURRENT_DEVICE == "npu":
|
||||||
|
|||||||
@@ -24,7 +24,6 @@ def test_get_args_from_yaml(tmp_path: pathlib.Path):
|
|||||||
### model
|
### model
|
||||||
model: "llamafactory/tiny-random-qwen2.5"
|
model: "llamafactory/tiny-random-qwen2.5"
|
||||||
trust_remote_code: true
|
trust_remote_code: true
|
||||||
use_fast_processor: true
|
|
||||||
model_class: "llm"
|
model_class: "llm"
|
||||||
kernel_config:
|
kernel_config:
|
||||||
name: "auto"
|
name: "auto"
|
||||||
|
|||||||
@@ -12,7 +12,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
"""LLaMA-Factory test configuration.
|
"""LlamaFactory test configuration.
|
||||||
|
|
||||||
Contains shared fixtures, pytest configuration, and custom markers.
|
Contains shared fixtures, pytest configuration, and custom markers.
|
||||||
"""
|
"""
|
||||||
@@ -22,6 +22,7 @@ import sys
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
from pytest import Config, FixtureRequest, Item, MonkeyPatch
|
from pytest import Config, FixtureRequest, Item, MonkeyPatch
|
||||||
|
|
||||||
from llamafactory.v1.accelerator.helper import get_current_accelerator, get_device_count
|
from llamafactory.v1.accelerator.helper import get_current_accelerator, get_device_count
|
||||||
@@ -109,17 +110,24 @@ def _handle_device_visibility(items: list[Item]):
|
|||||||
def pytest_collection_modifyitems(config: Config, items: list[Item]):
|
def pytest_collection_modifyitems(config: Config, items: list[Item]):
|
||||||
"""Modify test collection based on markers and environment."""
|
"""Modify test collection based on markers and environment."""
|
||||||
# Handle version compatibility (from HEAD)
|
# Handle version compatibility (from HEAD)
|
||||||
if not is_transformers_version_greater_than("4.57.0"):
|
skip_bc = pytest.mark.skip(reason="Skip backward compatibility tests")
|
||||||
skip_bc = pytest.mark.skip(reason="Skip backward compatibility tests")
|
for item in items:
|
||||||
for item in items:
|
if "tests_v1" in str(item.fspath) and not is_transformers_version_greater_than("4.57.0"):
|
||||||
if "tests_v1" in str(item.fspath):
|
item.add_marker(skip_bc)
|
||||||
item.add_marker(skip_bc)
|
|
||||||
|
|
||||||
_handle_slow_tests(items)
|
_handle_slow_tests(items)
|
||||||
_handle_runs_on(items)
|
_handle_runs_on(items)
|
||||||
_handle_device_visibility(items)
|
_handle_device_visibility(items)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
def _cleanup_distributed_state():
    """Destroy any process group a test left initialized.

    Runs automatically around every test; all work happens after the test
    body (post-``yield``).
    """
    yield
    # ``is_initialized`` is only provided by builds with distributed support,
    # so check availability first instead of failing during teardown.
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
def _manage_distributed_env(request: FixtureRequest, monkeypatch: MonkeyPatch) -> None:
|
def _manage_distributed_env(request: FixtureRequest, monkeypatch: MonkeyPatch) -> None:
|
||||||
"""Set environment variables for distributed tests if specific devices are requested."""
|
"""Set environment variables for distributed tests if specific devices are requested."""
|
||||||
@@ -155,6 +163,7 @@ def _manage_distributed_env(request: FixtureRequest, monkeypatch: MonkeyPatch) -
|
|||||||
monkeypatch.setenv(env_key, visible_devices[0] if visible_devices else "0")
|
monkeypatch.setenv(env_key, visible_devices[0] if visible_devices else "0")
|
||||||
else:
|
else:
|
||||||
monkeypatch.setenv(env_key, "0")
|
monkeypatch.setenv(env_key, "0")
|
||||||
|
|
||||||
if CURRENT_DEVICE == "cuda":
|
if CURRENT_DEVICE == "cuda":
|
||||||
monkeypatch.setattr(torch.cuda, "device_count", lambda: 1)
|
monkeypatch.setattr(torch.cuda, "device_count", lambda: 1)
|
||||||
elif CURRENT_DEVICE == "npu":
|
elif CURRENT_DEVICE == "npu":
|
||||||
|
|||||||
56
tests_v1/plugins/model_plugins/test_init_plugin.py
Normal file
56
tests_v1/plugins/model_plugins/test_init_plugin.py
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
# Copyright 2025 the LlamaFactory team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from llamafactory.v1.accelerator.interface import DistributedInterface
|
||||||
|
from llamafactory.v1.config.arg_parser import get_args
|
||||||
|
from llamafactory.v1.core.model_loader import ModelLoader
|
||||||
|
|
||||||
|
|
||||||
|
def test_init_on_meta():
    """The ``init_on_meta`` plugin places the loaded model on the meta device."""
    cli_args = {
        "model": "llamafactory/tiny-random-qwen2.5",
        "init_config": {"name": "init_on_meta"},
    }
    _, model_args, *_ = get_args(cli_args)
    loader = ModelLoader(model_args=model_args)
    assert loader.model.device.type == "meta"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.runs_on(["cuda", "npu"])
def test_init_on_rank0():
    """The ``init_on_rank0`` plugin yields CPU on rank 0 and meta on other ranks."""
    cli_args = {
        "model": "llamafactory/tiny-random-qwen2.5",
        "init_config": {"name": "init_on_rank0"},
    }
    _, model_args, *_ = get_args(cli_args)
    loader = ModelLoader(model_args=model_args)
    expected_type = "cpu" if DistributedInterface().get_rank() == 0 else "meta"
    assert loader.model.device.type == expected_type
|
||||||
|
|
||||||
|
|
||||||
|
def test_init_on_default():
    """The ``init_on_default`` plugin places the model on the current accelerator."""
    cli_args = {
        "model": "llamafactory/tiny-random-qwen2.5",
        "init_config": {"name": "init_on_default"},
    }
    _, model_args, *_ = get_args(cli_args)
    loader = ModelLoader(model_args=model_args)
    actual_type = loader.model.device.type
    expected_type = DistributedInterface().current_accelerator.type
    assert actual_type == expected_type
|
||||||
Reference in New Issue
Block a user