8 Commits

| Author | SHA1 | Message | Date |
| ------ | ---- | ------- | ---- |
| Xunpeng Xiao | 68119e5522 | [misc] Add a PyTorch version warning for Conv3D. (#9715) | 2026-01-05 13:26:29 +08:00 |
| Yaowei Zheng | f60a6e3d01 | [v1] add init plugin (#9716) | 2026-01-04 20:51:46 +08:00 |
| jiaqiw09 | 81b8a50aa5 | [deps] Update pyproject.toml and requirements (#9714) (Co-authored-by: Yaowei Zheng <hiyouga@buaa.edu.cn>) | 2026-01-04 19:52:16 +08:00 |
| Yaowei Zheng | 8600530002 | [misc] lint (#9710) | 2026-01-04 13:47:56 +08:00 |
| Hertz | 9ae62c6fc0 | [model] support Youtu-LLM-2B (#9707) | 2026-01-04 13:17:57 +08:00 |
| Xunpeng Xiao | 0087bc253b | [misc] Compatible with an empty architectures field in config.json (#9709) | 2026-01-04 12:11:35 +08:00 |
| Santosh Bhavani | 355d5c5e5a | [fix] fp8: add Transformer Engine backend support (#9705) (Co-authored-by: Yaowei Zheng <hiyouga@buaa.edu.cn>) | 2026-01-01 10:18:02 +08:00 |
| Yaowei Zheng | 6fe6bd290b | [misc] set dev version (#9703) | 2025-12-31 23:41:40 +08:00 |
57 changed files with 546 additions and 219 deletions

View File

@@ -70,7 +70,8 @@ jobs:
       run: |
         uv venv
         uv pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
-        uv pip install -e ".[dev]"
+        uv pip install -e .
+        uv pip install -r requirements/dev.txt
     - name: Install transformers
       if: ${{ matrix.transformers }}

View File

@@ -52,7 +52,8 @@ jobs:
     - name: Install dependencies
       run: |
         uv venv
-        uv pip install -e ".[dev]"
+        uv pip install -e .
+        uv pip install -r requirements/dev.txt
     - name: Cache HuggingFace models
       id: hf-hub-cache

View File

@@ -58,8 +58,9 @@ jobs:
     - name: Install dependencies
       run: |
         uv venv
-        uv pip install torch-npu==${{matrix.pytorch_npu}}
-        uv pip install -e ".[dev]"
+        uv pip install -r requirements/npu.txt
+        uv pip install -e .
+        uv pip install -r requirements/dev.txt
     - name: Install node
       run: |

View File

@@ -329,6 +329,7 @@ Read technical notes:
 | [StarCoder 2](https://huggingface.co/bigcode) | 3B/7B/15B | - |
 | [VibeThinker-1.5B](https://huggingface.co/WeiboAI) | 1.5B | qwen3 |
 | [Yi/Yi-1.5 (Code)](https://huggingface.co/01-ai) | 1.5B/6B/9B/34B | yi |
+| [Youtu-LLM](https://huggingface.co/tencent/) | 2B | youtu |
 | [Yuan 2](https://huggingface.co/IEITYuan) | 2B/51B/102B | yuan |

 > [!NOTE]

@@ -516,10 +517,11 @@ huggingface-cli login
 ```bash
 git clone --depth 1 https://github.com/hiyouga/LLaMA-Factory.git
 cd LLaMA-Factory
-pip install -e ".[metrics]"
+pip install -e .
+pip install -r requirements/metrics.txt
 ```

-Optional dependencies available: `metrics`, `deepspeed`. Install with: `pip install -e ".[metrics,deepspeed]"`
+Optional dependencies available: `metrics`, `deepspeed`. Install with: `pip install -e . && pip install -r requirements/metrics.txt -r requirements/deepspeed.txt`

 Additional dependencies for specific features are available in `examples/requirements/`.
@@ -577,36 +579,21 @@ To enable FlashAttention-2 on the Windows platform, please use the script from [

 <details><summary>For Ascend NPU users</summary>

-To install LLaMA Factory on Ascend NPU devices, please upgrade Python to version 3.10 or higher: `pip install -e . torch-npu==2.7.1`. Additionally, you need to install the **[Ascend CANN Toolkit and Kernels](https://www.hiascend.com/developer/download/community/result?module=cann)**. Please follow the [installation tutorial](https://www.hiascend.com/document/detail/en/CANNCommunityEdition/600alphaX/softwareinstall/instg/atlasdeploy_03_0031.html) or use the following commands:
+To install LLaMA Factory on Ascend NPU devices, please upgrade Python to version 3.10 or higher: `pip install -r requirements/npu.txt`. Additionally, you need to install the **Ascend CANN Toolkit and Kernels**. Please follow the [installation tutorial](https://llamafactory.readthedocs.io/en/latest/advanced/npu_installation.html).
+
+You can also download the pre-built Docker images:

 ```bash
-# replace the url according to your CANN version and devices
-# install CANN Toolkit
-wget https://ascend-repo.obs.cn-east-2.myhuaweicloud.com/Milan-ASL/Milan-ASL%20V100R001C20SPC702/Ascend-cann-toolkit_8.0.0.alpha002_linux-"$(uname -i)".run
-bash Ascend-cann-toolkit_8.0.0.alpha002_linux-"$(uname -i)".run --install
-
-# install CANN Kernels
-wget https://ascend-repo.obs.cn-east-2.myhuaweicloud.com/Milan-ASL/Milan-ASL%20V100R001C20SPC702/Ascend-cann-kernels-910b_8.0.0.alpha002_linux-"$(uname -i)".run
-bash Ascend-cann-kernels-910b_8.0.0.alpha002_linux-"$(uname -i)".run --install
-
-# set env variables
-source /usr/local/Ascend/ascend-toolkit/set_env.sh
+# Docker Hub
+docker pull hiyouga/llamafactory:latest-npu-a2
+docker pull hiyouga/llamafactory:latest-npu-a3
+
+# quay.io
+docker pull quay.io/ascend/llamafactory:latest-npu-a2
+docker pull quay.io/ascend/llamafactory:latest-npu-a3
 ```

-| Requirement | Minimum | Recommend |
-| ------------ | ------- | -------------- |
-| CANN | 8.0.RC1 | 8.0.0.alpha002 |
-| torch | 2.1.0 | 2.7.1 |
-| torch-npu | 2.1.0 | 2.7.1 |
-| deepspeed | 0.13.2 | 0.13.2 |
-| vllm-ascend | - | 0.7.3 |
-
 Remember to use `ASCEND_RT_VISIBLE_DEVICES` instead of `CUDA_VISIBLE_DEVICES` to specify the device to use.

 If you cannot infer model on NPU devices, try setting `do_sample: false` in the configurations.

-Download the pre-built Docker images: [32GB](http://mirrors.cn-central-221.ovaijisuan.com/detail/130.html) | [64GB](http://mirrors.cn-central-221.ovaijisuan.com/detail/131.html)
-
 #### Install BitsAndBytes

 To use QLoRA based on bitsandbytes on Ascend NPU, please follow these 3 steps:

View File

@@ -331,6 +331,7 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
 | [StarCoder 2](https://huggingface.co/bigcode) | 3B/7B/15B | - |
 | [VibeThinker-1.5B](https://huggingface.co/WeiboAI) | 1.5B | qwen3 |
 | [Yi/Yi-1.5 (Code)](https://huggingface.co/01-ai) | 1.5B/6B/9B/34B | yi |
+| [Youtu-LLM](https://huggingface.co/tencent/) | 2B | youtu |
 | [Yuan 2](https://huggingface.co/IEITYuan) | 2B/51B/102B | yuan |

 > [!NOTE]

@@ -518,10 +519,11 @@ huggingface-cli login
 ```bash
 git clone --depth 1 https://github.com/hiyouga/LLaMA-Factory.git
 cd LLaMA-Factory
-pip install -e ".[metrics]"
+pip install -e .
+pip install -r requirements/metrics.txt
 ```

-Optional extra dependencies: `metrics`, `deepspeed`. Install with `pip install -e ".[metrics,deepspeed]"`.
+Optional extra dependencies: `metrics`, `deepspeed`. Install with `pip install -e . && pip install -r requirements/metrics.txt -r requirements/deepspeed.txt`.

 For other optional dependencies, see the files under `examples/requirements/`.
@@ -579,36 +581,20 @@ pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/downl

 <details><summary>For Ascend NPU users</summary>

-To install LLaMA Factory on Ascend NPU devices, please upgrade Python to 3.10 or higher and install the extra dependencies with `pip install -e . torch-npu==2.7.1`. You also need to install the **[Ascend CANN Toolkit and Kernels](https://www.hiascend.com/developer/download/community/result?module=cann)**; see the [installation tutorial](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/80RC2alpha002/quickstart/quickstart/quickstart_18_0004.html) or use the following commands:
+To install LLaMA Factory on Ascend NPU devices, please upgrade Python to 3.10 or higher and install the extra dependencies with `pip install -r requirements/npu.txt`. You also need to install the **Ascend CANN Toolkit and Kernels**; see the [installation tutorial](https://llamafactory.readthedocs.io/zh-cn/latest/advanced/npu_installation.html).
+
+You can also directly download the latest pre-built Docker images:

 ```bash
-# replace the URL with the one matching your CANN version and device
-# install CANN Toolkit
-wget https://ascend-repo.obs.cn-east-2.myhuaweicloud.com/Milan-ASL/Milan-ASL%20V100R001C17SPC701/Ascend-cann-toolkit_8.0.RC1.alpha001_linux-"$(uname -i)".run
-bash Ascend-cann-toolkit_8.0.RC1.alpha001_linux-"$(uname -i)".run --install
-
-# install CANN Kernels
-wget https://ascend-repo.obs.cn-east-2.myhuaweicloud.com/Milan-ASL/Milan-ASL%20V100R001C17SPC701/Ascend-cann-kernels-910b_8.0.RC1.alpha001_linux.run
-bash Ascend-cann-kernels-910b_8.0.RC1.alpha001_linux.run --install
-
-# set environment variables
-source /usr/local/Ascend/ascend-toolkit/set_env.sh
+# Docker Hub
+docker pull hiyouga/llamafactory:latest-npu-a2
+docker pull hiyouga/llamafactory:latest-npu-a3
+
+# quay.io
+docker pull quay.io/ascend/llamafactory:latest-npu-a2
+docker pull quay.io/ascend/llamafactory:latest-npu-a3
 ```

-| Requirement | Minimum | Recommended |
-| ------------ | ------- | -------------- |
-| CANN | 8.0.RC1 | 8.0.0.alpha002 |
-| torch | 2.1.0 | 2.7.1 |
-| torch-npu | 2.1.0 | 2.7.1 |
-| deepspeed | 0.13.2 | 0.13.2 |
-| vllm-ascend | - | 0.7.3 |
-
 Please use `ASCEND_RT_VISIBLE_DEVICES` instead of `CUDA_VISIBLE_DEVICES` to specify the compute device.

 If inference fails on NPU devices, try setting `do_sample: false` in the configurations.

-Download the pre-built Docker images: [32GB](http://mirrors.cn-central-221.ovaijisuan.com/detail/130.html) | [64GB](http://mirrors.cn-central-221.ovaijisuan.com/detail/131.html)
-
 #### Install BitsAndBytes

 To run bitsandbytes-based QLoRA fine-tuning on Ascend NPU, follow these steps:

View File

@@ -32,7 +32,8 @@ RUN pip config set global.index-url "${PIP_INDEX}" && \
 COPY . /app

 # Install LLaMA Factory
-RUN pip install --no-cache-dir --no-build-isolation -e ".[metrics,deepspeed]"
+RUN pip install --no-cache-dir --no-build-isolation -e . && \
+    pip install --no-cache-dir --no-build-isolation -r requirements/metrics.txt -r requirements/deepspeed.txt

 # Rebuild flash attention
 RUN if [ "${INSTALL_FLASHATTN}" == "true" ]; then \

View File

@@ -60,7 +60,8 @@ WORKDIR /app
 COPY . /app

 # Install LLaMA Factory
-RUN pip install --no-cache-dir -e ".[metrics]" --no-build-isolation
+RUN pip install --no-cache-dir -e . --no-build-isolation && \
+    pip install --no-cache-dir -r requirements/metrics.txt --no-build-isolation

 RUN pip install "git+https://github.com/alibaba/roll.git#subdirectory=mcore_adapter"

View File

@@ -35,7 +35,8 @@ COPY . /app
 # Install torch-npu
 RUN pip uninstall -y torch torchvision torchaudio && \
     pip install --no-cache-dir "torch==2.7.1" "torch-npu==2.7.1" "torchvision==0.22.1" "torchaudio==2.7.1" --index-url "${PYTORCH_INDEX}" && \
-    pip install --no-cache-dir -e ".[metrics]" --no-build-isolation
+    pip install --no-cache-dir -e . --no-build-isolation && \
+    pip install --no-cache-dir -r requirements/metrics.txt --no-build-isolation

 # Set up volumes
 # VOLUME [ "/root/.cache/huggingface", "/app/shared_data", "/app/output" ]

View File

@@ -34,7 +34,8 @@ COPY . /app
 # Reinstall pytorch rocm and install LLaMA Factory
 RUN pip uninstall -y torch torchvision torchaudio && \
-    pip install --no-cache-dir --no-build-isolation -e --pre ".[metrics,deepspeed]" --index-url "${PYTORCH_INDEX}"
+    pip install --no-cache-dir --no-build-isolation --pre -e . --index-url "${PYTORCH_INDEX}" && \
+    pip install --no-cache-dir --no-build-isolation -r requirements/metrics.txt -r requirements/deepspeed.txt --index-url "${PYTORCH_INDEX}"

 # Rebuild flash attention
 RUN if [ "${INSTALL_FLASHATTN}" == "true" ]; then \

View File

@@ -76,11 +76,6 @@ dependencies = [
     "sse-starlette"
 ]

-[project.optional-dependencies]
-dev = ["pre-commit", "ruff", "pytest", "build"]
-metrics = ["nltk", "jieba", "rouge-chinese"]
-deepspeed = ["deepspeed>=0.10.0,<=0.16.9"]
-
 [project.scripts]
 llamafactory-cli = "llamafactory.cli:main"
 lmf = "llamafactory.cli:main"

View File

@@ -0,0 +1 @@
+deepspeed>=0.10.0,<=0.16.9

requirements/dev.txt Normal file
View File

@@ -0,0 +1,4 @@
+pre-commit
+ruff
+pytest
+build

requirements/metrics.txt Normal file
View File

@@ -0,0 +1,3 @@
+nltk
+jieba
+rouge-chinese

requirements/npu.txt Normal file
View File

@@ -0,0 +1,4 @@
+torch==2.7.1
+torch-npu==2.7.1
+torchvision==0.22.1
+torchaudio==2.7.1

View File

@@ -28,7 +28,7 @@ try:
     jieba.setLogLevel(logging.CRITICAL)
     jieba.initialize()
 except ImportError:
-    print("Please install llamafactory with `pip install -e .[metrics]`.")
+    print("Please install llamafactory with `pip install -r requirements/metrics.txt`.")
     raise

View File

@@ -2278,6 +2278,21 @@ register_template(
 )


+register_template(
+    name="youtu",
+    format_user=StringFormatter(slots=["<|User|>{{content}}<|Assistant|>"]),
+    format_assistant=StringFormatter(slots=["{{content}}<|end_of_text|>"]),
+    format_system=StringFormatter(slots=["{{content}}"]),
+    format_function=FunctionFormatter(slots=["{{content}}"], tool_format="default"),
+    format_observation=StringFormatter(slots=["<tool_response>\n{{content}}\n</tool_response><|Assistant|>"]),
+    format_tools=ToolFormatter(tool_format="default"),
+    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+    stop_words=["<|end_of_text|>"],
+    replace_eos=True,
+    template_class=ReasoningTemplate,
+)
+
+
 register_template(
     name="yuan",
     format_user=StringFormatter(slots=["{{content}}", {"token": "<sep>"}]),
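Based purely on the slot strings in the `youtu` registration above, a single turn would plausibly serialize as sketched below. `render_turn` and `BOS` are illustrative helpers, not LLaMA Factory APIs, and the real bos token comes from the tokenizer:

```python
# Hypothetical rendering of one "youtu" turn from the template slots above.
BOS = "<s>"  # assumed bos_token; the actual value is tokenizer-specific

def render_turn(system: str, user: str, assistant: str) -> str:
    prefix = BOS                                # format_prefix: [{"bos_token"}]
    sys_part = system                           # format_system: "{{content}}"
    user_part = f"<|User|>{user}<|Assistant|>"  # format_user slot
    asst_part = f"{assistant}<|end_of_text|>"   # format_assistant slot
    return prefix + sys_part + user_part + asst_part

print(render_turn("You are helpful.", "Hi", "Hello!"))
# <s>You are helpful.<|User|>Hi<|Assistant|>Hello!<|end_of_text|>
```

Note that `<|end_of_text|>` is also registered as a stop word with `replace_eos=True`, so generation halts exactly where the assistant slot ends.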

View File

@@ -3846,6 +3846,21 @@ register_model_group(
 )


+register_model_group(
+    models={
+        "Youtu-LLM-2B-Instruct": {
+            DownloadSource.DEFAULT: "tencent/Youtu-LLM-2B",
+            DownloadSource.MODELSCOPE: "Tencent-YouTu-Research/Youtu-LLM-2B",
+        },
+        "Youtu-LLM-2B-Base": {
+            DownloadSource.DEFAULT: "tencent/Youtu-LLM-2B-Base",
+            DownloadSource.MODELSCOPE: "Tencent-YouTu-Research/Youtu-LLM-2B-Base",
+        },
+    },
+    template="youtu",
+)
+
+
 register_model_group(
     models={
         "Yuan2-2B-Chat": {

View File

@@ -19,7 +19,7 @@
 from collections import OrderedDict

-VERSION = "0.9.4"
+VERSION = "0.9.5.dev0"


 def print_env() -> None:

View File

@@ -298,23 +298,6 @@ class QuantizationArguments:
         default=None,
         metadata={"help": "Device map used to infer the 4-bit quantized model, needs bitsandbytes>=0.43.0."},
     )
-    fp8: bool = field(
-        default=False,
-        metadata={
-            "help": "Enable FP8 mixed precision training via HuggingFace Accelerate. "
-            "Requires PyTorch 2.7+ and Hopper architecture GPUs."
-        },
-    )
-    fp8_backend: str = field(
-        default="auto",
-        metadata={
-            "help": "FP8 backend to use ('auto', 'torchao', 'te', 'msamp'). 'auto' selects best available backend."
-        },
-    )
-    fp8_enable_fsdp_float8_all_gather: bool = field(
-        default=False,
-        metadata={"help": "Enable FP8 optimizations for FSDP2 all-gather operations."},
-    )


 @dataclass

View File

@@ -142,14 +142,6 @@ def _verify_model_args(
         logger.warning_rank0("We should use slow tokenizer for the Yi models. Change `use_fast_tokenizer` to False.")
         model_args.use_fast_tokenizer = False

-    # Validate advanced training features
-    if model_args.fp8 and model_args.quantization_bit is not None:
-        raise ValueError("FP8 training is not compatible with quantization. Please disable one of them.")
-
-    if model_args.fp8_enable_fsdp_float8_all_gather and not model_args.fp8:
-        logger.warning_rank0("fp8_enable_fsdp_float8_all_gather requires fp8=True. Setting fp8=True.")
-        model_args.fp8 = True
-

 def _check_extra_dependencies(
     model_args: "ModelArguments",

@@ -347,6 +339,9 @@ def get_train_args(args: dict[str, Any] | list[str] | None = None) -> _TRAIN_CLS
     if training_args.deepspeed is not None and (finetuning_args.use_galore or finetuning_args.use_apollo):
         raise ValueError("GaLore and APOLLO are incompatible with DeepSpeed yet.")

+    if training_args.fp8 and training_args.quantization_bit is not None:
+        raise ValueError("FP8 training is not compatible with quantization. Please disable one of them.")
+
     if model_args.infer_backend != EngineName.HF:
         raise ValueError("vLLM/SGLang backend is only available for API, CLI and Web.")

@@ -363,6 +358,10 @@ def get_train_args(args: dict[str, Any] | list[str] | None = None) -> _TRAIN_CLS
     _verify_model_args(model_args, data_args, finetuning_args)
     _check_extra_dependencies(model_args, finetuning_args, training_args)

+    if training_args.fp8_enable_fsdp_float8_all_gather and not training_args.fp8:
+        logger.warning_rank0("fp8_enable_fsdp_float8_all_gather requires fp8=True. Setting fp8=True.")
+        model_args.fp8 = True
+
     if (
         training_args.do_train
         and finetuning_args.finetuning_type == "lora"
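The relocated checks illustrate two validation styles used in the parser: mutually exclusive flags raise immediately, while a merely dependent flag is repaired with a warning. A hedged distillation (`check_fp8_flags` is a hypothetical helper, not project code):

```python
import logging
from typing import Optional

logger = logging.getLogger(__name__)

def check_fp8_flags(fp8: bool, quantization_bit: Optional[int], fp8_all_gather: bool) -> bool:
    """Return the effective fp8 setting after validation (illustrative only)."""
    # Incompatible options fail fast with a hard error.
    if fp8 and quantization_bit is not None:
        raise ValueError("FP8 training is not compatible with quantization. Please disable one of them.")

    # A dependent flag is auto-repaired with a warning instead of an error.
    if fp8_all_gather and not fp8:
        logger.warning("fp8_enable_fsdp_float8_all_gather requires fp8=True. Setting fp8=True.")
        fp8 = True

    return fp8
```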

View File

@@ -92,7 +92,30 @@ class RayArguments:

 @dataclass
-class TrainingArguments(RayArguments, BaseTrainingArguments):
+class Fp8Arguments:
+    r"""Arguments pertaining to the FP8 training."""
+
+    fp8: bool = field(
+        default=False,
+        metadata={
+            "help": "Enable FP8 mixed precision training via HuggingFace Accelerate. "
+            "Requires PyTorch 2.7+ and Hopper architecture GPUs."
+        },
+    )
+    fp8_backend: str = field(
+        default="auto",
+        metadata={
+            "help": "FP8 backend to use ('auto', 'torchao', 'te', 'msamp'). 'auto' selects best available backend."
+        },
+    )
+    fp8_enable_fsdp_float8_all_gather: bool = field(
+        default=False,
+        metadata={"help": "Enable FP8 optimizations for FSDP2 all-gather operations."},
+    )
+
+
+@dataclass
+class TrainingArguments(Fp8Arguments, RayArguments, BaseTrainingArguments):
     r"""Arguments pertaining to the trainer."""

     overwrite_output_dir: bool = field(
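The diff above extracts the FP8 flags into a standalone dataclass mixin and threads it into `TrainingArguments` via multiple inheritance. A minimal sketch of that composition pattern (the `Demo*` classes are illustrative stand-ins, not the real classes):

```python
from dataclasses import dataclass, field, fields

@dataclass
class DemoFp8Args:  # stand-in for the Fp8Arguments mixin
    fp8: bool = field(default=False, metadata={"help": "Enable FP8 mixed precision training."})

@dataclass
class DemoBaseArgs:  # stand-in for the base training arguments
    output_dir: str = field(default="out", metadata={"help": "Output directory."})

@dataclass
class DemoTrainingArgs(DemoFp8Args, DemoBaseArgs):
    """Fields from every dataclass in the MRO are merged into one flat namespace."""

args = DemoTrainingArgs()
print(sorted(f.name for f in fields(args)))  # fields from both mixins are present
```

Because the merged class is a single flat namespace, `HfArgumentParser`-style tooling can expose `--fp8` alongside the base trainer flags without the caller knowing which mixin contributed it.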

View File

@@ -25,10 +25,14 @@ from transformers import (
     AutoProcessor,
     AutoTokenizer,
 )
+from packaging import version
+from torch import nn
 from trl import AutoModelForCausalLMWithValueHead
+import warnings

 from ..extras import logging
 from ..extras.misc import count_parameters, skip_check_imports, try_download_model_from_other_hub
+from ..extras.packages import _get_package_version
 from .adapter import init_adapter
 from .model_utils.ktransformers import load_kt_pretrained_model
 from .model_utils.liger_kernel import apply_liger_kernel

@@ -202,6 +206,17 @@ def load_model(
         if vhead_params is not None:
             model.load_state_dict(vhead_params, strict=False)
             logger.info_rank0(f"Loaded valuehead from checkpoint: {vhead_path}")

+    # Conv3D is not recommended when using torch 2.9.x
+    torch_version = _get_package_version("torch")
+    if version.parse("2.9.0") <= torch_version < version.parse("2.10.0"):
+        if any(isinstance(m, nn.Conv3d) for m in model.modules()):
+            raise ValueError(
+                "Unsupported torch version detected: torch 2.9.x with Conv3D. "
+                "This combination is known to cause severe performance regression. "
+                "Please downgrade torch to <2.9 or remove Conv3D. "
+                "See https://github.com/pytorch/pytorch/issues/166122"
+            )
+
     if not is_trainable:
         model.requires_grad_(False)
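The new guard rejects one specific torch minor series. Assuming only that `packaging` is installed (`_get_package_version` is LLaMA Factory's internal helper), the version-range test reduces to:

```python
# Sketch of the half-open version range used above: [2.9.0, 2.10.0).
from packaging import version

def is_torch_29(version_str: str) -> bool:
    """True only for torch 2.9.x, where Conv3d is reported to regress badly."""
    v = version.parse(version_str)
    return version.parse("2.9.0") <= v < version.parse("2.10.0")

print(is_torch_29("2.9.1"))   # inside the range
print(is_torch_29("2.10.0"))  # upper bound is exclusive
```

Using `packaging.version` rather than string comparison matters here: a plain string check would sort `"2.10.0"` before `"2.9.0"` lexicographically and misclassify both boundaries.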

View File

@@ -138,13 +138,14 @@ def patch_config(
     if getattr(config, "model_type", None) == "kimi_vl" and is_trainable:
         setattr(config.text_config, "topk_method", "greedy")

-    if "InternVLChatModel" in getattr(config, "architectures", []):
+    architectures = getattr(config, "architectures", None)
+    if isinstance(architectures, list) and "InternVLChatModel" in architectures:
         raise ValueError(
             "Please download the internvl models in a Hugging Face-compatible format "
             "(for example, https://huggingface.co/OpenGVLab/InternVL3-8B-hf)."
         )

-    if "LlavaLlamaForCausalLM" in getattr(config, "architectures", []):
+    if isinstance(architectures, list) and "LlavaLlamaForCausalLM" in architectures:
         raise ValueError("Please download llava models with hf-compatible format: https://huggingface.co/llava-hf")

     if getattr(config, "model_type", None) == "internlm3" and not is_transformers_version_greater_than("4.47.1"):
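Why the `isinstance` guard matters: `config.json` may serialize `"architectures": null`, in which case the attribute exists and `getattr`'s `[]` default is never used, so the old membership test receives `None` and raises `TypeError`. A minimal reproduction (`DummyConfig` is a stand-in for `transformers.PretrainedConfig`):

```python
class DummyConfig:
    architectures = None  # empty field serialized as null in config.json

config = DummyConfig()

# Old form: the default [] is NOT used because the attribute exists.
try:
    "InternVLChatModel" in getattr(config, "architectures", [])
except TypeError as e:
    print(f"old check fails: {e}")

# New form: the isinstance guard short-circuits safely.
architectures = getattr(config, "architectures", None)
ok = isinstance(architectures, list) and "InternVLChatModel" in architectures
print(ok)  # False, no exception
```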

View File

@@ -12,35 +12,45 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import os
import types
from typing import TYPE_CHECKING, Any, Optional from typing import TYPE_CHECKING, Any, Optional
from ..extras import logging from ..extras import logging
if TYPE_CHECKING: if TYPE_CHECKING:
from ..hparams import ModelArguments from ..hparams import TrainingArguments
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
def create_fp8_kwargs(model_args: "ModelArguments") -> list[Any]: def create_fp8_kwargs(training_args: "TrainingArguments") -> list[Any]:
"""Create AORecipeKwargs for FP8 training with HuggingFace Accelerate. """Create AORecipeKwargs for FP8 training with HuggingFace Accelerate.
Args: Args:
model_args: Model arguments containing FP8 configuration training_args: Training arguments containing FP8 configuration
Returns: Returns:
List containing AORecipeKwargs if FP8 is enabled and supported, empty list otherwise List containing AORecipeKwargs if FP8 is enabled and supported, empty list otherwise
""" """
if not model_args.fp8: if not training_args.fp8:
return [] return []
try: backend = getattr(training_args, "fp8_backend", "auto")
# Check if AORecipeKwargs is available (Accelerate 1.8.0+) logger.info_rank0(f"Creating FP8 configuration with backend: {backend}")
from accelerate.utils import AORecipeKwargs
backend = getattr(model_args, "fp8_backend", "auto") try:
logger.info_rank0(f"Creating FP8 configuration with backend: {backend}") # Use Transformer Engine backend (optimal for Hopper GPUs)
if backend == "te":
from accelerate.utils import FP8RecipeKwargs
logger.info_rank0("Using Transformer Engine FP8 backend")
return [FP8RecipeKwargs(backend="TE", fp8_format="HYBRID", amax_history_len=16, amax_compute_algo="max")]
# Use TorchAO backend (default)
from accelerate.utils import AORecipeKwargs
# Create Float8LinearConfig if torchao backend is used # Create Float8LinearConfig if torchao backend is used
config = None config = None
@@ -83,7 +93,10 @@ def create_fp8_kwargs(model_args: "ModelArguments") -> list[Any]:
return True return True
# Map FSDP all-gather setting if available (this affects the underlying implementation) # Map FSDP all-gather setting if available (this affects the underlying implementation)
if hasattr(model_args, "fp8_enable_fsdp_float8_all_gather") and model_args.fp8_enable_fsdp_float8_all_gather: if (
hasattr(training_args, "fp8_enable_fsdp_float8_all_gather")
and training_args.fp8_enable_fsdp_float8_all_gather
):
logger.info_rank0("FSDP float8 all-gather optimization requested") logger.info_rank0("FSDP float8 all-gather optimization requested")
return [AORecipeKwargs(config=config, module_filter_func=module_filter_func)] return [AORecipeKwargs(config=config, module_filter_func=module_filter_func)]
@@ -92,19 +105,19 @@ def create_fp8_kwargs(model_args: "ModelArguments") -> list[Any]:
return [] return []
def get_fp8_mixed_precision(model_args: "ModelArguments") -> Optional[str]: def get_fp8_mixed_precision(training_args: "TrainingArguments") -> Optional[str]:
"""Get the mixed precision setting for Accelerate when using FP8. """Get the mixed precision setting for Accelerate when using FP8.
Args: Args:
model_args: Model arguments containing FP8 configuration training_args: Training arguments containing FP8 configuration
Returns: Returns:
"fp8" if FP8 is enabled, None otherwise "fp8" if FP8 is enabled, None otherwise
""" """
return "fp8" if model_args.fp8 else None return "fp8" if training_args.fp8 else None
def configure_fp8_environment(model_args: "ModelArguments") -> None: def configure_fp8_environment(training_args: "TrainingArguments") -> None:
"""Configure FP8 environment for HuggingFace Accelerate. """Configure FP8 environment for HuggingFace Accelerate.
FP8 training is handled entirely through HuggingFace Accelerate, regardless of whether FP8 training is handled entirely through HuggingFace Accelerate, regardless of whether
@@ -112,11 +125,9 @@ def configure_fp8_environment(model_args: "ModelArguments") -> None:
variables and validates the FP8 configuration. variables and validates the FP8 configuration.
Args: Args:
model_args: Model arguments containing FP8 configuration training_args: Training arguments containing FP8 configuration
""" """
import os if not training_args.fp8:
if not model_args.fp8:
return return
# Set mixed precision to fp8 for HuggingFace Accelerate # Set mixed precision to fp8 for HuggingFace Accelerate
@@ -124,38 +135,38 @@ def configure_fp8_environment(model_args: "ModelArguments") -> None:
logger.info_rank0("Set ACCELERATE_MIXED_PRECISION=fp8") logger.info_rank0("Set ACCELERATE_MIXED_PRECISION=fp8")
# Configure FP8 backend and options # Configure FP8 backend and options
backend = getattr(model_args, "fp8_backend", "auto") backend = getattr(training_args, "fp8_backend", "auto")
if backend != "auto": if backend != "auto":
os.environ["FP8_BACKEND"] = backend os.environ["FP8_BACKEND"] = backend
logger.info_rank0(f"Set FP8_BACKEND={backend}") logger.info_rank0(f"Set FP8_BACKEND={backend}")
# Create and validate FP8 recipe kwargs (for logging/debugging) # Create and validate FP8 recipe kwargs (for logging/debugging)
fp8_kwargs = create_fp8_kwargs(model_args) fp8_kwargs = create_fp8_kwargs(training_args)
logger.info_rank0(f"FP8 AORecipeKwargs created: {len(fp8_kwargs)} items") logger.info_rank0(f"FP8 AORecipeKwargs created: {len(fp8_kwargs)} items")
# Enable FSDP float8 all-gather optimization if requested # Enable FSDP float8 all-gather optimization if requested
if hasattr(model_args, "fp8_enable_fsdp_float8_all_gather") and model_args.fp8_enable_fsdp_float8_all_gather: if hasattr(training_args, "fp8_enable_fsdp_float8_all_gather") and training_args.fp8_enable_fsdp_float8_all_gather:
os.environ["FP8_ENABLE_FSDP_FLOAT8_ALL_GATHER"] = "true" os.environ["FP8_ENABLE_FSDP_FLOAT8_ALL_GATHER"] = "true"
logger.info_rank0("Set FP8_ENABLE_FSDP_FLOAT8_ALL_GATHER=true") logger.info_rank0("Set FP8_ENABLE_FSDP_FLOAT8_ALL_GATHER=true")
logger.info_rank0("FP8 environment configured - all FP8 training handled by HuggingFace Accelerate") logger.info_rank0("FP8 environment configured - all FP8 training handled by HuggingFace Accelerate")
def verify_fp8_status(accelerator, model_args: "ModelArguments") -> None: def verify_fp8_status(accelerator, training_args: "TrainingArguments") -> None:
"""Verify that FP8 training is actually working after model preparation. """Verify that FP8 training is actually working after model preparation.
Args: Args:
accelerator: The HuggingFace Accelerator instance accelerator: The HuggingFace Accelerator instance
model_args: Model arguments containing FP8 configuration training_args: Training arguments containing FP8 configuration
""" """
if not model_args.fp8: if not training_args.fp8:
return return
# Check Accelerate's FP8 status # Check Accelerate's FP8 status
fp8_enabled = getattr(accelerator, "fp8_enabled", False) fp8_enabled = getattr(accelerator, "fp8_enabled", False)
     fp8_backend_type = getattr(accelerator, "fp8_backend", "UNKNOWN")
-    backend = getattr(model_args, "fp8_backend", "auto")
+    backend = getattr(training_args, "fp8_backend", "auto")
     if backend == "torchao" or backend == "auto":
         logger.info_rank0(
             "FP8 training enabled with TorchAO backend. For optimal performance, "
@@ -169,3 +180,50 @@ def verify_fp8_status(accelerator, model_args: "ModelArguments") -> None:
     if not fp8_enabled:
         logger.info_rank0("WARNING: FP8 was requested but Accelerate shows fp8_enabled=False. FP8 may not be working.")
+
+
+def patch_accelerator_for_fp8() -> None:
+    """Patch Accelerator to inject FP8 recipe kwargs.
+
+    This is needed because HuggingFace Trainer doesn't pass kwargs_handlers to Accelerator.
+    We monkey-patch Accelerator.__init__ to inject the FP8 recipe and force mixed_precision='fp8'.
+    """
+    import transformer_engine.pytorch as te
+    from accelerate import Accelerator
+
+    # Guard against multiple patches
+    if getattr(Accelerator, "_te_fp8_patched", False):
+        return
+
+    # Stub for Accelerate 1.12+ compatibility (te.fp8.check_mxfp8_support doesn't exist yet)
+    if not hasattr(te, "fp8"):
+        te.fp8 = types.ModuleType("fp8")
+        te.fp8.check_mxfp8_support = lambda: (False, "MXFP8 not supported")
+
+    try:
+        from accelerate.utils import TERecipeKwargs as FP8Recipe
+
+        use_te_recipe = True
+    except ImportError:
+        from accelerate.utils import FP8RecipeKwargs as FP8Recipe
+
+        use_te_recipe = False
+
+    original_init = Accelerator.__init__
+
+    def patched_init(self, *args, **kwargs):
+        if "kwargs_handlers" not in kwargs or not kwargs["kwargs_handlers"]:
+            if use_te_recipe:
+                kwargs["kwargs_handlers"] = [
+                    FP8Recipe(fp8_format="HYBRID", amax_history_len=16, amax_compute_algo="max")
+                ]
+            else:
+                kwargs["kwargs_handlers"] = [
+                    FP8Recipe(backend="TE", fp8_format="HYBRID", amax_history_len=16, amax_compute_algo="max")
+                ]
+            # Only force mixed_precision when we inject handlers
+            kwargs["mixed_precision"] = "fp8"
+
+        return original_init(self, *args, **kwargs)
+
+    Accelerator.__init__ = patched_init
+    Accelerator._te_fp8_patched = True
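The patch above relies on a common idiom: wrap a class's `__init__` to inject default keyword arguments, with an attribute flag guarding against double patching. A minimal standalone sketch of that idiom (the `Recorder` class and `handlers` kwarg are invented for illustration, not part of Accelerate):

```python
# Hypothetical class standing in for accelerate.Accelerator.
class Recorder:
    def __init__(self, **kwargs):
        self.kwargs = kwargs


def patch_recorder():
    # Idempotency guard: a second call must be a no-op.
    if getattr(Recorder, "_patched", False):
        return

    original_init = Recorder.__init__

    def patched_init(self, *args, **kwargs):
        # Inject a default only when the caller supplied none.
        if not kwargs.get("handlers"):
            kwargs["handlers"] = ["default"]
        return original_init(self, *args, **kwargs)

    Recorder.__init__ = patched_init
    Recorder._patched = True


patch_recorder()
patch_recorder()  # no-op thanks to the guard
print(Recorder().kwargs["handlers"])  # ['default']
```

Caller-provided kwargs still win: `Recorder(handlers=["mine"])` is left untouched by the patch.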

View File

@@ -19,16 +19,15 @@ import torch
 from transformers import Trainer
 from typing_extensions import override

-from ...extras.packages import is_transformers_version_greater_than
 from ..callbacks import SaveProcessorCallback
-from ..fp8_utils import configure_fp8_environment, verify_fp8_status
+from ..fp8_utils import configure_fp8_environment, patch_accelerator_for_fp8, verify_fp8_status
 from ..trainer_utils import create_custom_optimizer, create_custom_scheduler


 if TYPE_CHECKING:
     from transformers import ProcessorMixin

-    from ...hparams import FinetuningArguments, ModelArguments
+    from ...hparams import FinetuningArguments, ModelArguments, TrainingArguments


 class CustomTrainer(Trainer):
@@ -41,11 +40,13 @@ class CustomTrainer(Trainer):
         model_args: Optional["ModelArguments"] = None,
         **kwargs,
     ) -> None:
+        kwargs["processing_class"] = kwargs.pop("tokenizer")
         # Configure FP8 environment if enabled
-        if model_args is not None and model_args.fp8:
-            configure_fp8_environment(model_args)
-        if is_transformers_version_greater_than("4.46"):
-            kwargs["processing_class"] = kwargs.pop("tokenizer")
+        training_args: TrainingArguments = kwargs.get("args")
+        if training_args.fp8:
+            configure_fp8_environment(training_args)
+            if getattr(training_args, "fp8_backend", "auto") == "te":
+                patch_accelerator_for_fp8()

         super().__init__(**kwargs)

         if processor is not None:
@@ -64,9 +65,8 @@ class CustomTrainer(Trainer):
             self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
             self.add_callback(BAdamCallback)

-        # Verify FP8 status after trainer initialization (accelerator should be available)
-        if model_args is not None and model_args.fp8 and hasattr(self, "accelerator"):
-            verify_fp8_status(self.accelerator, model_args)
+        if training_args.fp8 and hasattr(self, "accelerator"):  # verify FP8 status after trainer initialization
+            verify_fp8_status(self.accelerator, training_args)

     @override
     def create_optimizer(self) -> "torch.optim.Optimizer":

View File

@@ -27,18 +27,17 @@ from typing_extensions import override

 from ...extras import logging
 from ...extras.constants import IGNORE_INDEX
-from ...extras.packages import is_transformers_version_greater_than
 from ..callbacks import SaveProcessorCallback
-from ..fp8_utils import configure_fp8_environment, verify_fp8_status
+from ..fp8_utils import configure_fp8_environment, patch_accelerator_for_fp8, verify_fp8_status
 from ..trainer_utils import create_custom_optimizer, create_custom_scheduler


 if TYPE_CHECKING:
     from torch.utils.data import Dataset
-    from transformers import PreTrainedTokenizer, ProcessorMixin
+    from transformers import ProcessorMixin
     from transformers.trainer import PredictionOutput

-    from ...hparams import FinetuningArguments, ModelArguments
+    from ...hparams import FinetuningArguments, ModelArguments, TrainingArguments


 logger = logging.get_logger(__name__)
@@ -55,13 +54,13 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
         gen_kwargs: Optional[dict[str, Any]] = None,
         **kwargs,
     ) -> None:
+        kwargs["processing_class"] = kwargs.pop("tokenizer")
         # Configure FP8 environment if enabled
-        if model_args is not None and model_args.fp8:
-            configure_fp8_environment(model_args)
-        if is_transformers_version_greater_than("4.46"):
-            kwargs["processing_class"] = kwargs.pop("tokenizer")
-        else:
-            self.processing_class: PreTrainedTokenizer = kwargs.get("tokenizer")
+        training_args: TrainingArguments = kwargs.get("args")
+        if training_args.fp8:
+            configure_fp8_environment(training_args)
+            if getattr(training_args, "fp8_backend", "auto") == "te":
+                patch_accelerator_for_fp8()

         super().__init__(**kwargs)

         if processor is not None:
@@ -88,9 +87,8 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
             self.compute_loss_func = dft_loss_func

-        # Verify FP8 status after trainer initialization (accelerator should be available)
-        if model_args is not None and model_args.fp8 and hasattr(self, "accelerator"):
-            verify_fp8_status(self.accelerator, model_args)
+        if training_args.fp8 and hasattr(self, "accelerator"):  # verify FP8 status after trainer initialization
+            verify_fp8_status(self.accelerator, training_args)

     @override
     def create_optimizer(self) -> "torch.optim.Optimizer":

View File

@@ -34,10 +34,14 @@ from typing import Any, Optional

 from torch.distributed import barrier, destroy_process_group, init_process_group
 from torch.distributed.device_mesh import DeviceMesh, init_device_mesh

+from ..utils import logging
 from ..utils.types import DistributedConfig, ProcessGroup, Tensor, TensorLike
 from . import helper


+logger = logging.get_logger(__name__)
+
+
 class Dim(str, Enum):
     """Dimension names."""
@@ -157,6 +161,7 @@ class DistributedInterface:
                 self.data_device_mesh = None

         self._initialized = True
+        logger.info_rank0(f"DistributedInterface initialized with strategy={self.strategy}.")

     def __str__(self) -> str:
         return (

View File

@@ -0,0 +1,32 @@
+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .arg_parser import InputArgument, get_args
+from .arg_utils import ModelClass, SampleBackend
+from .data_args import DataArguments
+from .model_args import ModelArguments
+from .sample_args import SampleArguments
+from .training_args import TrainingArguments
+
+
+__all__ = [
+    "DataArguments",
+    "InputArgument",
+    "ModelArguments",
+    "ModelClass",
+    "SampleArguments",
+    "SampleBackend",
+    "TrainingArguments",
+    "get_args",
+]

View File

@@ -27,14 +27,14 @@ class ModelArguments:
         default=False,
         metadata={"help": "Trust remote code from Hugging Face."},
     )
-    use_fast_processor: bool = field(
-        default=True,
-        metadata={"help": "Use fast processor from Hugging Face."},
-    )
     model_class: ModelClass = field(
         default=ModelClass.LLM,
        metadata={"help": "Model class from Hugging Face."},
     )
+    init_config: PluginConfig | None = field(
+        default=None,
+        metadata={"help": "Initialization configuration for the model."},
+    )
     peft_config: PluginConfig | None = field(
         default=None,
         metadata={"help": "PEFT configuration for the model."},
@@ -49,6 +49,7 @@ class ModelArguments:
     )

     def __post_init__(self) -> None:
+        self.init_config = get_plugin_config(self.init_config)
         self.peft_config = get_plugin_config(self.peft_config)
         self.kernel_config = get_plugin_config(self.kernel_config)
         self.quant_config = get_plugin_config(self.quant_config)

View File

@@ -22,7 +22,7 @@ from .arg_utils import PluginConfig, get_plugin_config

 @dataclass
 class TrainingArguments:
     output_dir: str = field(
-        default=os.path.join("outputs", str(uuid4())),
+        default=os.path.join("outputs", str(uuid4().hex)),
         metadata={"help": "Path to the output directory."},
     )
     micro_batch_size: int = field(
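The change above swaps `str(uuid4())` for `str(uuid4().hex)`: the same random 128 bits, but rendered as 32 hex characters without dashes, which makes for a shorter default run directory name. A quick check of the difference:

```python
import os
from uuid import uuid4

# str(uuid4()) keeps the canonical dashed form; .hex drops the dashes.
dashed = str(uuid4())   # 36 chars, e.g. '3f2b1c0e-....-....-....-............'
compact = uuid4().hex   # 32 hex chars, no dashes

assert len(dashed) == 36 and dashed.count("-") == 4
assert len(compact) == 32 and "-" not in compact

output_dir = os.path.join("outputs", compact)
print(output_dir)  # e.g. outputs/3f2b1c0e...
```

Note that `.hex` is already a `str`, so the surrounding `str(...)` in the diff is redundant but harmless.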

View File

@@ -0,0 +1,77 @@
+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from abc import ABC, abstractmethod
+
+from ..config import ModelArguments, SampleArguments, SampleBackend
+from ..utils.types import HFModel, Processor, TorchDataset
+
+
+class BaseEngine(ABC):
+    @abstractmethod
+    def __init__(
+        self,
+        args: SampleArguments,
+        model_args: ModelArguments,
+        model: HFModel = None,
+        processor: Processor = None,
+    ) -> None:
+        """Initialize the engine.
+
+        Args:
+            args: Sample arguments.
+            model_args: Model arguments.
+            model: Model.
+            processor: Processor.
+        """
+        ...
+
+    @abstractmethod
+    async def generate(self, messages):
+        pass
+
+    @abstractmethod
+    async def batch_infer(self, data: TorchDataset) -> None:
+        pass
+
+
+class HuggingFaceEngine(BaseEngine):
+    def __init__(
+        self,
+        args: SampleArguments,
+        model_args: ModelArguments,
+        model: HFModel,
+        processor: Processor,
+    ) -> None:
+        self.args = args
+
+
+class BaseSampler:
+    def __init__(
+        self,
+        args: SampleArguments,
+        model_args: ModelArguments,
+        model: HFModel,
+        processor: Processor,
+    ) -> None:
+        if args.sample_backend == SampleBackend.HF:
+            self.engine = HuggingFaceEngine(args, model_args, model, processor)
+        else:
+            raise ValueError(f"Unknown sample backend: {args.sample_backend}")
+
+    async def generate(self, messages):
+        return await self.engine.generate(messages)
+
+    async def batch_infer(self, data: TorchDataset) -> None:
+        return await self.engine.batch_infer(data)
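`BaseSampler` is a thin facade: it picks an engine from the backend enum at construction time and forwards async calls to it. The same pattern in miniature (the `Backend` enum and `EchoEngine` below are invented for illustration; they stand in for `SampleBackend` and `HuggingFaceEngine`):

```python
import asyncio
from abc import ABC, abstractmethod
from enum import Enum


class Backend(str, Enum):
    ECHO = "echo"


class Engine(ABC):
    @abstractmethod
    async def generate(self, messages): ...


class EchoEngine(Engine):
    async def generate(self, messages):
        # Trivial stand-in for a real inference backend.
        return messages[-1]["content"]


class Sampler:
    def __init__(self, backend: Backend) -> None:
        # Dispatch on the enum; unknown backends fail loudly.
        if backend == Backend.ECHO:
            self.engine = EchoEngine()
        else:
            raise ValueError(f"Unknown sample backend: {backend}")

    async def generate(self, messages):
        # Facade: forward to the selected engine.
        return await self.engine.generate(messages)


result = asyncio.run(Sampler(Backend.ECHO).generate([{"role": "user", "content": "hi"}]))
print(result)  # hi
```

Keeping the dispatch in the facade (rather than in callers) means new backends only touch the `__init__` branch.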

View File

@@ -1,44 +0,0 @@
-# Copyright 2025 the LlamaFactory team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from abc import ABC, abstractmethod
-
-from ..config.sample_args import SampleArguments, SampleBackend
-from .model_loader import ModelLoader
-
-
-class BaseEngine(ABC):
-    @abstractmethod
-    def __init__(self, sample_args: SampleArguments, model_loader: ModelLoader) -> None: ...
-
-    @abstractmethod
-    async def generate(self):
-        pass
-
-    @abstractmethod
-    async def batch_infer(self):
-        pass
-
-
-class HuggingFaceEngine(BaseEngine):
-    def __init__(self, model_loader: ModelLoader, sample_args: SampleArguments) -> None:
-        self.args = sample_args
-
-
-class ChatSampler:
-    def __init__(self, model_loader: ModelLoader, sample_args: SampleArguments) -> None:
-        if sample_args.sample_backend == SampleBackend.HF:
-            self.engine = HuggingFaceEngine(model_loader, sample_args)
-        else:
-            raise ValueError(f"Unknown sample backend: {sample_args.sample_backend}")
View File

@@ -14,17 +14,24 @@
 """The definition of model loader.

-Init Phase:
+How to use:
+    model_loader = ModelLoader(model_args, is_trainable=True)
+    model_loader.processor: Get the tokenizer or multi-modal processor.
+    model_loader.model_config: Get the model configuration.
+    model_loader.model: Get the HF model.
+
+Init Workflow:
     1. Init processor.
     2. Init model config.
     3. Init model.
     4. Init adapter.
 """

 import torch
+from accelerate import init_empty_weights
 from transformers import AutoConfig, AutoProcessor

+from ..accelerator.helper import DeviceType
 from ..accelerator.interface import DistributedInterface
 from ..config.model_args import ModelArguments, ModelClass
 from ..utils import logging
@@ -55,11 +62,14 @@ class ModelLoader:
         """HF model."""

     def _init_processor(self) -> Processor:
-        """Init processor."""
+        """Init processor.
+
+        NOTE: Transformers v5 always use fast tokenizer.
+        https://github.com/huggingface/transformers/blob/v5.0.0rc1/src/transformers/models/auto/tokenization_auto.py#L642
+        """
         return AutoProcessor.from_pretrained(
             self.args.model,
             trust_remote_code=self.args.trust_remote_code,
-            use_fast=self.args.use_fast_processor,
         )

     def _init_model_config(self) -> HFConfig:
@@ -92,14 +102,24 @@ class ModelLoader:
             AutoClass = AutoModel

-        # map the entire model to the current accelerator
-        model = AutoClass.from_pretrained(
-            self.args.model,
-            config=self.model_config,
-            dtype="auto",
-            device_map=DistributedInterface().current_accelerator,
-            trust_remote_code=self.args.trust_remote_code,
-        )
+        if self.args.init_config is not None:
+            from ..plugins.model_plugins.initialization import InitPlugin
+
+            init_device = InitPlugin(self.args.init_config.name)()
+        else:
+            init_device = DistributedInterface().current_accelerator
+
+        if init_device.type == DeviceType.META:
+            with init_empty_weights():
+                model = AutoClass.from_config(self.model_config)
+        else:
+            model = AutoClass.from_pretrained(
+                self.args.model,
+                config=self.model_config,
+                dtype="auto",
+                device_map=init_device,
+                trust_remote_code=self.args.trust_remote_code,
+            )

         if self.args.peft_config is None:
             if self.is_train:

View File

@@ -0,0 +1,43 @@
+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+
+from ...accelerator.helper import DeviceType
+from ...accelerator.interface import DistributedInterface
+from ...utils.plugin import BasePlugin
+
+
+class InitPlugin(BasePlugin):
+    def __call__(self) -> torch.device:
+        return super().__call__()
+
+
+@InitPlugin("init_on_meta").register
+def init_on_meta() -> torch.device:
+    return torch.device(DeviceType.META.value)
+
+
+@InitPlugin("init_on_rank0").register
+def init_on_rank0() -> torch.device:
+    if DistributedInterface().get_rank() == 0:
+        return torch.device(DeviceType.CPU.value)
+    else:
+        return torch.device(DeviceType.META.value)
+
+
+@InitPlugin("init_on_default").register
+def init_on_default() -> torch.device:
+    return DistributedInterface().current_accelerator
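`BasePlugin`'s internals are not shown in this diff, but the `@InitPlugin("name").register` usage implies a name-keyed registry of callables, selected when the plugin is invoked. A rough self-contained sketch of that pattern (the `Plugin` base class below is an assumption for illustration, not LlamaFactory's actual `BasePlugin`):

```python
class Plugin:
    """Name-keyed registry of callables, looked up at call time."""

    _registry = {}  # maps plugin name -> implementation

    def __init__(self, name: str) -> None:
        self.name = name

    def register(self, func):
        # Used as a decorator: @Plugin("some_name").register
        type(self)._registry[self.name] = func
        return func

    def __call__(self, *args, **kwargs):
        return type(self)._registry[self.name](*args, **kwargs)


class InitPlugin(Plugin):
    _registry = {}  # each plugin family keeps its own registry


@InitPlugin("init_on_meta").register
def init_on_meta() -> str:
    return "meta"  # stands in for torch.device("meta")


@InitPlugin("init_on_default").register
def init_on_default() -> str:
    return "cpu"  # stands in for the current accelerator


print(InitPlugin("init_on_meta")())  # meta
```

The payoff is that `ModelLoader` can resolve `InitPlugin(self.args.init_config.name)()` without knowing which strategies exist; new strategies register themselves at import time.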

View File

@@ -0,0 +1,35 @@
+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ..config import InputArgument, SampleBackend, get_args
+from ..core.base_sampler import BaseSampler
+from ..core.model_loader import ModelLoader
+
+
+def run_chat(args: InputArgument = None):
+    data_args, model_args, _, sample_args = get_args(args)
+    if sample_args.sample_backend != SampleBackend.HF:
+        model_args.init_plugin = {"name": "init_on_meta"}
+
+    model_loader = ModelLoader(model_args)
+    sampler = BaseSampler(sample_args, model_args, model_loader.model, model_loader.processor)
+    if data_args.dataset is not None:
+        sampler.batch_infer()
+    else:
+        sampler.generate()
+
+
+if __name__ == "__main__":
+    run_chat()

View File

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-"""LLaMA-Factory test configuration.
+"""LlamaFactory test configuration.

 Contains shared fixtures, pytest configuration, and custom markers.
 """
@@ -110,11 +110,10 @@ def _handle_device_visibility(items: list[Item]):
 def pytest_collection_modifyitems(config: Config, items: list[Item]):
     """Modify test collection based on markers and environment."""
     # Handle version compatibility (from HEAD)
-    if not is_transformers_version_greater_than("4.57.0"):
-        skip_bc = pytest.mark.skip(reason="Skip backward compatibility tests")
-        for item in items:
-            if "tests_v1" in str(item.fspath):
-                item.add_marker(skip_bc)
+    skip_bc = pytest.mark.skip(reason="Skip backward compatibility tests")
+    for item in items:
+        if "tests_v1" in str(item.fspath) and not is_transformers_version_greater_than("4.57.0"):
+            item.add_marker(skip_bc)

     _handle_slow_tests(items)
     _handle_runs_on(items)
@@ -156,6 +155,7 @@ def _manage_distributed_env(request: FixtureRequest, monkeypatch: MonkeyPatch) -
             monkeypatch.setenv(env_key, visible_devices[0] if visible_devices else "0")
         else:
             monkeypatch.setenv(env_key, "0")
+
     if CURRENT_DEVICE == "cuda":
         monkeypatch.setattr(torch.cuda, "device_count", lambda: 1)
     elif CURRENT_DEVICE == "npu":

View File

@@ -24,7 +24,6 @@ def test_get_args_from_yaml(tmp_path: pathlib.Path):
     ### model
     model: "llamafactory/tiny-random-qwen2.5"
     trust_remote_code: true
-    use_fast_processor: true
     model_class: "llm"
     kernel_config:
       name: "auto"

View File

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-"""LLaMA-Factory test configuration.
+"""LlamaFactory test configuration.

 Contains shared fixtures, pytest configuration, and custom markers.
 """
@@ -22,6 +22,7 @@ import sys

 import pytest
 import torch
+import torch.distributed as dist
 from pytest import Config, FixtureRequest, Item, MonkeyPatch

 from llamafactory.v1.accelerator.helper import get_current_accelerator, get_device_count
@@ -109,17 +110,24 @@ def _handle_device_visibility(items: list[Item]):
 def pytest_collection_modifyitems(config: Config, items: list[Item]):
     """Modify test collection based on markers and environment."""
     # Handle version compatibility (from HEAD)
-    if not is_transformers_version_greater_than("4.57.0"):
-        skip_bc = pytest.mark.skip(reason="Skip backward compatibility tests")
-        for item in items:
-            if "tests_v1" in str(item.fspath):
-                item.add_marker(skip_bc)
+    skip_bc = pytest.mark.skip(reason="Skip backward compatibility tests")
+    for item in items:
+        if "tests_v1" in str(item.fspath) and not is_transformers_version_greater_than("4.57.0"):
+            item.add_marker(skip_bc)

     _handle_slow_tests(items)
     _handle_runs_on(items)
     _handle_device_visibility(items)


+@pytest.fixture(autouse=True)
+def _cleanup_distributed_state():
+    """Cleanup distributed state after each test."""
+    yield
+    if dist.is_initialized():
+        dist.destroy_process_group()
+
+
 @pytest.fixture(autouse=True)
 def _manage_distributed_env(request: FixtureRequest, monkeypatch: MonkeyPatch) -> None:
     """Set environment variables for distributed tests if specific devices are requested."""
@@ -155,6 +163,7 @@ def _manage_distributed_env(request: FixtureRequest, monkeypatch: MonkeyPatch) -
             monkeypatch.setenv(env_key, visible_devices[0] if visible_devices else "0")
         else:
             monkeypatch.setenv(env_key, "0")
+
     if CURRENT_DEVICE == "cuda":
         monkeypatch.setattr(torch.cuda, "device_count", lambda: 1)
     elif CURRENT_DEVICE == "npu":

View File

@@ -0,0 +1,56 @@
+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+
+from llamafactory.v1.accelerator.interface import DistributedInterface
+from llamafactory.v1.config.arg_parser import get_args
+from llamafactory.v1.core.model_loader import ModelLoader
+
+
+def test_init_on_meta():
+    _, model_args, *_ = get_args(
+        dict(
+            model="llamafactory/tiny-random-qwen2.5",
+            init_config={"name": "init_on_meta"},
+        )
+    )
+    model_loader = ModelLoader(model_args=model_args)
+    assert model_loader.model.device.type == "meta"
+
+
+@pytest.mark.runs_on(["cuda", "npu"])
+def test_init_on_rank0():
+    _, model_args, *_ = get_args(
+        dict(
+            model="llamafactory/tiny-random-qwen2.5",
+            init_config={"name": "init_on_rank0"},
+        )
+    )
+    model_loader = ModelLoader(model_args=model_args)
+    if DistributedInterface().get_rank() == 0:
+        assert model_loader.model.device.type == "cpu"
+    else:
+        assert model_loader.model.device.type == "meta"
+
+
+def test_init_on_default():
+    _, model_args, *_ = get_args(
+        dict(
+            model="llamafactory/tiny-random-qwen2.5",
+            init_config={"name": "init_on_default"},
+        )
+    )
+    model_loader = ModelLoader(model_args=model_args)
+    assert model_loader.model.device.type == DistributedInterface().current_accelerator.type