4 Commits

Author | SHA1 | Message | Date
Yaowei Zheng | f60a6e3d01 | [v1] add init plugin (#9716) | 2026-01-04 20:51:46 +08:00
jiaqiw09 | 81b8a50aa5 | [deps] Update pyproject.toml and requirements (#9714), Co-authored-by: Yaowei Zheng <hiyouga@buaa.edu.cn> | 2026-01-04 19:52:16 +08:00
Yaowei Zheng | 8600530002 | [misc] lint (#9710) | 2026-01-04 13:47:56 +08:00
Hertz | 9ae62c6fc0 | [model] support Youtu-LLM-2B (#9707) | 2026-01-04 13:17:57 +08:00
54 changed files with 407 additions and 155 deletions

View File

@@ -70,7 +70,8 @@ jobs:
run: | run: |
uv venv uv venv
uv pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu uv pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
uv pip install -e ".[dev]" uv pip install -e .
uv pip install -r requirements/dev.txt
- name: Install transformers - name: Install transformers
if: ${{ matrix.transformers }} if: ${{ matrix.transformers }}

View File

@@ -52,7 +52,8 @@ jobs:
- name: Install dependencies - name: Install dependencies
run: | run: |
uv venv uv venv
uv pip install -e ".[dev]" uv pip install -e .
uv pip install -r requirements/dev.txt
- name: Cache HuggingFace models - name: Cache HuggingFace models
id: hf-hub-cache id: hf-hub-cache

View File

@@ -58,8 +58,9 @@ jobs:
- name: Install dependencies - name: Install dependencies
run: | run: |
uv venv uv venv
uv pip install torch-npu==${{matrix.pytorch_npu}} uv pip install -r requirements/npu.txt
uv pip install -e ".[dev]" uv pip install -e .
uv pip install -r requirements/dev.txt
- name: Install node - name: Install node
run: | run: |

View File

@@ -329,6 +329,7 @@ Read technical notes:
| [StarCoder 2](https://huggingface.co/bigcode) | 3B/7B/15B | - | | [StarCoder 2](https://huggingface.co/bigcode) | 3B/7B/15B | - |
| [VibeThinker-1.5B](https://huggingface.co/WeiboAI) | 1.5B | qwen3 | | [VibeThinker-1.5B](https://huggingface.co/WeiboAI) | 1.5B | qwen3 |
| [Yi/Yi-1.5 (Code)](https://huggingface.co/01-ai) | 1.5B/6B/9B/34B | yi | | [Yi/Yi-1.5 (Code)](https://huggingface.co/01-ai) | 1.5B/6B/9B/34B | yi |
| [Youtu-LLM](https://huggingface.co/tencent/) | 2B | youtu |
| [Yuan 2](https://huggingface.co/IEITYuan) | 2B/51B/102B | yuan | | [Yuan 2](https://huggingface.co/IEITYuan) | 2B/51B/102B | yuan |
> [!NOTE] > [!NOTE]
@@ -516,10 +517,11 @@ huggingface-cli login
```bash ```bash
git clone --depth 1 https://github.com/hiyouga/LLaMA-Factory.git git clone --depth 1 https://github.com/hiyouga/LLaMA-Factory.git
cd LLaMA-Factory cd LLaMA-Factory
pip install -e ".[metrics]" pip install -e .
pip install -r requirements/metrics.txt
``` ```
Optional dependencies available: `metrics`, `deepspeed`. Install with: `pip install -e ".[metrics,deepspeed]"` Optional dependencies available: `metrics`, `deepspeed`. Install with: `pip install -e . && pip install -r requirements/metrics.txt -r requirements/deepspeed.txt`
Additional dependencies for specific features are available in `examples/requirements/`. Additional dependencies for specific features are available in `examples/requirements/`.
@@ -577,36 +579,21 @@ To enable FlashAttention-2 on the Windows platform, please use the script from [
<details><summary>For Ascend NPU users</summary> <details><summary>For Ascend NPU users</summary>
To install LLaMA Factory on Ascend NPU devices, please upgrade Python to version 3.10 or higher: `pip install -e . torch-npu==2.7.1`. Additionally, you need to install the **[Ascend CANN Toolkit and Kernels](https://www.hiascend.com/developer/download/community/result?module=cann)**. Please follow the [installation tutorial](https://www.hiascend.com/document/detail/en/CANNCommunityEdition/600alphaX/softwareinstall/instg/atlasdeploy_03_0031.html) or use the following commands: To install LLaMA Factory on Ascend NPU devices, please upgrade Python to version 3.10 or higher: `pip install -r requirements/npu.txt`. Additionally, you need to install the **Ascend CANN Toolkit and Kernels**. Please follow the [installation tutorial](https://llamafactory.readthedocs.io/en/latest/advanced/npu_installation.html).
You can also download the pre-built Docker images:
```bash ```bash
# replace the url according to your CANN version and devices # Docker Hub
# install CANN Toolkit docker pull hiyouga/llamafactory:latest-npu-a2
wget https://ascend-repo.obs.cn-east-2.myhuaweicloud.com/Milan-ASL/Milan-ASL%20V100R001C20SPC702/Ascend-cann-toolkit_8.0.0.alpha002_linux-"$(uname -i)".run docker pull hiyouga/llamafactory:latest-npu-a3
bash Ascend-cann-toolkit_8.0.0.alpha002_linux-"$(uname -i)".run --install
# install CANN Kernels # quay.io
wget https://ascend-repo.obs.cn-east-2.myhuaweicloud.com/Milan-ASL/Milan-ASL%20V100R001C20SPC702/Ascend-cann-kernels-910b_8.0.0.alpha002_linux-"$(uname -i)".run docker pull quay.io/ascend/llamafactory:latest-npu-a2
bash Ascend-cann-kernels-910b_8.0.0.alpha002_linux-"$(uname -i)".run --install docker pull quay.io/ascend/llamafactory:latest-npu-a3
# set env variables
source /usr/local/Ascend/ascend-toolkit/set_env.sh
``` ```
| Requirement | Minimum | Recommend |
| ------------ | ------- | -------------- |
| CANN | 8.0.RC1 | 8.0.0.alpha002 |
| torch | 2.1.0 | 2.7.1 |
| torch-npu | 2.1.0 | 2.7.1 |
| deepspeed | 0.13.2 | 0.13.2 |
| vllm-ascend | - | 0.7.3 |
Remember to use `ASCEND_RT_VISIBLE_DEVICES` instead of `CUDA_VISIBLE_DEVICES` to specify the device to use.
If you cannot infer model on NPU devices, try setting `do_sample: false` in the configurations.
Download the pre-built Docker images: [32GB](http://mirrors.cn-central-221.ovaijisuan.com/detail/130.html) | [64GB](http://mirrors.cn-central-221.ovaijisuan.com/detail/131.html)
#### Install BitsAndBytes #### Install BitsAndBytes
To use QLoRA based on bitsandbytes on Ascend NPU, please follow these 3 steps: To use QLoRA based on bitsandbytes on Ascend NPU, please follow these 3 steps:

View File

@@ -331,6 +331,7 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
| [StarCoder 2](https://huggingface.co/bigcode) | 3B/7B/15B | - | | [StarCoder 2](https://huggingface.co/bigcode) | 3B/7B/15B | - |
| [VibeThinker-1.5B](https://huggingface.co/WeiboAI) | 1.5B | qwen3 | | [VibeThinker-1.5B](https://huggingface.co/WeiboAI) | 1.5B | qwen3 |
| [Yi/Yi-1.5 (Code)](https://huggingface.co/01-ai) | 1.5B/6B/9B/34B | yi | | [Yi/Yi-1.5 (Code)](https://huggingface.co/01-ai) | 1.5B/6B/9B/34B | yi |
| [Youtu-LLM](https://huggingface.co/tencent/) | 2B | youtu |
| [Yuan 2](https://huggingface.co/IEITYuan) | 2B/51B/102B | yuan | | [Yuan 2](https://huggingface.co/IEITYuan) | 2B/51B/102B | yuan |
> [!NOTE] > [!NOTE]
@@ -518,10 +519,11 @@ huggingface-cli login
```bash ```bash
git clone --depth 1 https://github.com/hiyouga/LLaMA-Factory.git git clone --depth 1 https://github.com/hiyouga/LLaMA-Factory.git
cd LLaMA-Factory cd LLaMA-Factory
pip install -e ".[metrics]" pip install -e .
pip install -r requirements/metrics.txt
``` ```
可选的额外依赖项:`metrics`、`deepspeed`。使用 `pip install -e ".[metrics,deepspeed]"` 安装。 可选的额外依赖项:`metrics`、`deepspeed`。使用 `pip install -e . && pip install -r requirements/metrics.txt -r requirements/deepspeed.txt` 安装。
其他可选依赖项请参考 `examples/requirements/` 目录下的文件。 其他可选依赖项请参考 `examples/requirements/` 目录下的文件。
@@ -579,36 +581,20 @@ pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/downl
<details><summary>昇腾 NPU 用户指南</summary> <details><summary>昇腾 NPU 用户指南</summary>
在昇腾 NPU 设备上安装 LLaMA Factory 时,请升级 Python 到 3.10 及以上,并需要指定额外依赖项,使用 `pip install -e . torch-npu==2.7.1` 命令安装。此外,还需要安装 **[Ascend CANN Toolkit 与 Kernels](https://www.hiascend.com/developer/download/community/result?module=cann)**,安装方法请参考[安装教程](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/80RC2alpha002/quickstart/quickstart/quickstart_18_0004.html)或使用以下命令: 在昇腾 NPU 设备上安装 LLaMA Factory 时,请升级 Python 到 3.10 及以上,并需要指定额外依赖项,使用 `pip install -r requirements/npu.txt` 命令安装。此外,还需要安装 **Ascend CANN Toolkit 与 Kernels**,安装方法请参考[安装教程](https://llamafactory.readthedocs.io/zh-cn/latest/advanced/npu_installation.html)。
您可以直接下载预安装的最新docker镜像
```bash ```bash
# 请替换 URL 为 CANN 版本和设备型号对应的 URL # Docker Hub
# 安装 CANN Toolkit docker pull hiyouga/llamafactory:latest-npu-a2
wget https://ascend-repo.obs.cn-east-2.myhuaweicloud.com/Milan-ASL/Milan-ASL%20V100R001C17SPC701/Ascend-cann-toolkit_8.0.RC1.alpha001_linux-"$(uname -i)".run docker pull hiyouga/llamafactory:latest-npu-a3
bash Ascend-cann-toolkit_8.0.RC1.alpha001_linux-"$(uname -i)".run --install
# 安装 CANN Kernels # quay.io
wget https://ascend-repo.obs.cn-east-2.myhuaweicloud.com/Milan-ASL/Milan-ASL%20V100R001C17SPC701/Ascend-cann-kernels-910b_8.0.RC1.alpha001_linux.run docker pull quay.io/ascend/llamafactory:latest-npu-a2
bash Ascend-cann-kernels-910b_8.0.RC1.alpha001_linux.run --install docker pull quay.io/ascend/llamafactory:latest-npu-a3
# 设置环境变量
source /usr/local/Ascend/ascend-toolkit/set_env.sh
``` ```
| 依赖项 | 至少 | 推荐 |
| ------------ | ------- | -------------- |
| CANN | 8.0.RC1 | 8.0.0.alpha002 |
| torch | 2.1.0 | 2.7.1 |
| torch-npu | 2.1.0 | 2.7.1 |
| deepspeed | 0.13.2 | 0.13.2 |
| vllm-ascend | - | 0.7.3 |
请使用 `ASCEND_RT_VISIBLE_DEVICES` 而非 `CUDA_VISIBLE_DEVICES` 来指定运算设备。
如果遇到无法正常推理的情况,请尝试设置 `do_sample: false`
下载预构建 Docker 镜像:[32GB](http://mirrors.cn-central-221.ovaijisuan.com/detail/130.html) | [64GB](http://mirrors.cn-central-221.ovaijisuan.com/detail/131.html)
#### 安装 BitsAndBytes #### 安装 BitsAndBytes
如果要在 Ascend NPU 上进行基于 bitsandbytes 的 QLoRA 量化微调,请执行如下步骤: 如果要在 Ascend NPU 上进行基于 bitsandbytes 的 QLoRA 量化微调,请执行如下步骤:

View File

@@ -32,7 +32,8 @@ RUN pip config set global.index-url "${PIP_INDEX}" && \
COPY . /app COPY . /app
# Install LLaMA Factory # Install LLaMA Factory
RUN pip install --no-cache-dir --no-build-isolation -e ".[metrics,deepspeed]" RUN pip install --no-cache-dir --no-build-isolation -e . && \
pip install --no-cache-dir --no-build-isolation -r requirements/metrics.txt -r requirements/deepspeed.txt
# Rebuild flash attention # Rebuild flash attention
RUN if [ "${INSTALL_FLASHATTN}" == "true" ]; then \ RUN if [ "${INSTALL_FLASHATTN}" == "true" ]; then \

View File

@@ -60,7 +60,8 @@ WORKDIR /app
COPY . /app COPY . /app
# Install LLaMA Factory # Install LLaMA Factory
RUN pip install --no-cache-dir -e ".[metrics]" --no-build-isolation RUN pip install --no-cache-dir -e . --no-build-isolation && \
pip install --no-cache-dir -r requirements/metrics.txt --no-build-isolation
RUN pip install "git+https://github.com/alibaba/roll.git#subdirectory=mcore_adapter" RUN pip install "git+https://github.com/alibaba/roll.git#subdirectory=mcore_adapter"

View File

@@ -35,7 +35,8 @@ COPY . /app
# Install torch-npu # Install torch-npu
RUN pip uninstall -y torch torchvision torchaudio && \ RUN pip uninstall -y torch torchvision torchaudio && \
pip install --no-cache-dir "torch==2.7.1" "torch-npu==2.7.1" "torchvision==0.22.1" "torchaudio==2.7.1" --index-url "${PYTORCH_INDEX}" && \ pip install --no-cache-dir "torch==2.7.1" "torch-npu==2.7.1" "torchvision==0.22.1" "torchaudio==2.7.1" --index-url "${PYTORCH_INDEX}" && \
pip install --no-cache-dir -e ".[metrics]" --no-build-isolation pip install --no-cache-dir -e . --no-build-isolation && \
pip install --no-cache-dir -r requirements/metrics.txt --no-build-isolation
# Set up volumes # Set up volumes
# VOLUME [ "/root/.cache/huggingface", "/app/shared_data", "/app/output" ] # VOLUME [ "/root/.cache/huggingface", "/app/shared_data", "/app/output" ]

View File

@@ -34,7 +34,8 @@ COPY . /app
# Reinstall pytorch rocm and install LLaMA Factory # Reinstall pytorch rocm and install LLaMA Factory
RUN pip uninstall -y torch torchvision torchaudio && \ RUN pip uninstall -y torch torchvision torchaudio && \
pip install --no-cache-dir --no-build-isolation -e --pre ".[metrics,deepspeed]" --index-url "${PYTORCH_INDEX}" pip install --no-cache-dir --no-build-isolation -e --pre . --index-url "${PYTORCH_INDEX}" && \
pip install --no-cache-dir --no-build-isolation -r requirements/metrics.txt -r requirements/deepspeed.txt --index-url "${PYTORCH_INDEX}"
# Rebuild flash attention # Rebuild flash attention
RUN if [ "${INSTALL_FLASHATTN}" == "true" ]; then \ RUN if [ "${INSTALL_FLASHATTN}" == "true" ]; then \

View File

@@ -76,11 +76,6 @@ dependencies = [
"sse-starlette" "sse-starlette"
] ]
[project.optional-dependencies]
dev = ["pre-commit", "ruff", "pytest", "build"]
metrics = ["nltk", "jieba", "rouge-chinese"]
deepspeed = ["deepspeed>=0.10.0,<=0.16.9"]
[project.scripts] [project.scripts]
llamafactory-cli = "llamafactory.cli:main" llamafactory-cli = "llamafactory.cli:main"
lmf = "llamafactory.cli:main" lmf = "llamafactory.cli:main"

View File

@@ -0,0 +1 @@
deepspeed>=0.10.0,<=0.16.9

4
requirements/dev.txt Normal file
View File

@@ -0,0 +1,4 @@
pre-commit
ruff
pytest
build

3
requirements/metrics.txt Normal file
View File

@@ -0,0 +1,3 @@
nltk
jieba
rouge-chinese

4
requirements/npu.txt Normal file
View File

@@ -0,0 +1,4 @@
torch==2.7.1
torch-npu==2.7.1
torchvision==0.22.1
torchaudio==2.7.1

View File

@@ -28,7 +28,7 @@ try:
jieba.setLogLevel(logging.CRITICAL) jieba.setLogLevel(logging.CRITICAL)
jieba.initialize() jieba.initialize()
except ImportError: except ImportError:
print("Please install llamafactory with `pip install -e .[metrics]`.") print("Please install llamafactory with `pip install -r requirements/metrics.txt`.")
raise raise

View File

@@ -2278,6 +2278,21 @@ register_template(
) )
register_template(
name="youtu",
format_user=StringFormatter(slots=["<|User|>{{content}}<|Assistant|>"]),
format_assistant=StringFormatter(slots=["{{content}}<|end_of_text|>"]),
format_system=StringFormatter(slots=["{{content}}"]),
format_function=FunctionFormatter(slots=["{{content}}"], tool_format="default"),
format_observation=StringFormatter(slots=["<tool_response>\n{{content}}\n</tool_response><|Assistant|>"]),
format_tools=ToolFormatter(tool_format="default"),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
stop_words=["<|end_of_text|>"],
replace_eos=True,
template_class=ReasoningTemplate,
)
register_template( register_template(
name="yuan", name="yuan",
format_user=StringFormatter(slots=["{{content}}", {"token": "<sep>"}]), format_user=StringFormatter(slots=["{{content}}", {"token": "<sep>"}]),
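
To make the `youtu` template registered above more concrete, here is a minimal sketch of how its user and assistant slots would concatenate into one string. It does not use LLaMA-Factory's formatter classes: `render_youtu`, `USER_SLOT`, and `ASSISTANT_SLOT` are hypothetical names, and BOS-token, system-prompt, tool, and reasoning handling are left out.

```python
# Hypothetical stand-ins for the slot strings registered in the "youtu" template.
USER_SLOT = "<|User|>{content}<|Assistant|>"
ASSISTANT_SLOT = "{content}<|end_of_text|>"


def render_youtu(messages: list[dict[str, str]]) -> str:
    """Concatenate user/assistant turns using the slot strings above."""
    parts = []
    for message in messages:
        if message["role"] == "user":
            parts.append(USER_SLOT.format(content=message["content"]))
        elif message["role"] == "assistant":
            parts.append(ASSISTANT_SLOT.format(content=message["content"]))
        else:
            # System, tool, and observation roles are outside this sketch.
            raise ValueError(f"Unsupported role: {message['role']}")
    return "".join(parts)
```

Under these assumptions, each user turn ends with `<|Assistant|>` and each assistant turn ends with the stop word `<|end_of_text|>`.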

View File

@@ -3846,6 +3846,21 @@ register_model_group(
) )
register_model_group(
models={
"Youtu-LLM-2B-Instruct": {
DownloadSource.DEFAULT: "tencent/Youtu-LLM-2B",
DownloadSource.MODELSCOPE: "Tencent-YouTu-Research/Youtu-LLM-2B",
},
"Youtu-LLM-2B-Base": {
DownloadSource.DEFAULT: "tencent/Youtu-LLM-2B-Base",
DownloadSource.MODELSCOPE: "Tencent-YouTu-Research/Youtu-LLM-2B-Base",
},
},
template="youtu",
)
register_model_group( register_model_group(
models={ models={
"Yuan2-2B-Chat": { "Yuan2-2B-Chat": {

View File

@@ -142,6 +142,7 @@ def _verify_model_args(
logger.warning_rank0("We should use slow tokenizer for the Yi models. Change `use_fast_tokenizer` to False.") logger.warning_rank0("We should use slow tokenizer for the Yi models. Change `use_fast_tokenizer` to False.")
model_args.use_fast_tokenizer = False model_args.use_fast_tokenizer = False
def _check_extra_dependencies( def _check_extra_dependencies(
model_args: "ModelArguments", model_args: "ModelArguments",
finetuning_args: "FinetuningArguments", finetuning_args: "FinetuningArguments",

View File

@@ -94,6 +94,7 @@ class RayArguments:
@dataclass @dataclass
class Fp8Arguments: class Fp8Arguments:
r"""Arguments pertaining to the FP8 training.""" r"""Arguments pertaining to the FP8 training."""
fp8: bool = field( fp8: bool = field(
default=False, default=False,
metadata={ metadata={

View File

@@ -139,14 +139,13 @@ def patch_config(
setattr(config.text_config, "topk_method", "greedy") setattr(config.text_config, "topk_method", "greedy")
architectures = getattr(config, "architectures", None) architectures = getattr(config, "architectures", None)
if isinstance(architectures, list) and "InternVLChatModel" in architectures:
if isinstance(architectures, (list, tuple)) and "InternVLChatModel" in architectures:
raise ValueError( raise ValueError(
"Please download the internvl models in a Hugging Face-compatible format " "Please download the internvl models in a Hugging Face-compatible format "
"(for example, https://huggingface.co/OpenGVLab/InternVL3-8B-hf)." "(for example, https://huggingface.co/OpenGVLab/InternVL3-8B-hf)."
) )
if isinstance(architectures, (list, tuple)) and "LlavaLlamaForCausalLM" in architectures: if isinstance(architectures, list) and "LlavaLlamaForCausalLM" in architectures:
raise ValueError("Please download llava models with hf-compatible format: https://huggingface.co/llava-hf") raise ValueError("Please download llava models with hf-compatible format: https://huggingface.co/llava-hf")
if getattr(config, "model_type", None) == "internlm3" and not is_transformers_version_greater_than("4.47.1"): if getattr(config, "model_type", None) == "internlm3" and not is_transformers_version_greater_than("4.47.1"):

View File

@@ -93,7 +93,10 @@ def create_fp8_kwargs(training_args: "TrainingArguments") -> list[Any]:
return True return True
# Map FSDP all-gather setting if available (this affects the underlying implementation) # Map FSDP all-gather setting if available (this affects the underlying implementation)
if hasattr(training_args, "fp8_enable_fsdp_float8_all_gather") and training_args.fp8_enable_fsdp_float8_all_gather: if (
hasattr(training_args, "fp8_enable_fsdp_float8_all_gather")
and training_args.fp8_enable_fsdp_float8_all_gather
):
logger.info_rank0("FSDP float8 all-gather optimization requested") logger.info_rank0("FSDP float8 all-gather optimization requested")
return [AORecipeKwargs(config=config, module_filter_func=module_filter_func)] return [AORecipeKwargs(config=config, module_filter_func=module_filter_func)]

View File

@@ -19,7 +19,6 @@ import torch
from transformers import Trainer from transformers import Trainer
from typing_extensions import override from typing_extensions import override
from ...extras.packages import is_transformers_version_greater_than
from ..callbacks import SaveProcessorCallback from ..callbacks import SaveProcessorCallback
from ..fp8_utils import configure_fp8_environment, patch_accelerator_for_fp8, verify_fp8_status from ..fp8_utils import configure_fp8_environment, patch_accelerator_for_fp8, verify_fp8_status
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
@@ -28,7 +27,7 @@ from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
if TYPE_CHECKING: if TYPE_CHECKING:
from transformers import ProcessorMixin from transformers import ProcessorMixin
from ...hparams import FinetuningArguments, ModelArguments from ...hparams import FinetuningArguments, ModelArguments, TrainingArguments
class CustomTrainer(Trainer): class CustomTrainer(Trainer):
@@ -43,7 +42,7 @@ class CustomTrainer(Trainer):
) -> None: ) -> None:
kwargs["processing_class"] = kwargs.pop("tokenizer") kwargs["processing_class"] = kwargs.pop("tokenizer")
# Configure FP8 environment if enabled # Configure FP8 environment if enabled
training_args = kwargs.get("args") training_args: TrainingArguments = kwargs.get("args")
if training_args.fp8: if training_args.fp8:
configure_fp8_environment(training_args) configure_fp8_environment(training_args)
if getattr(training_args, "fp8_backend", "auto") == "te": if getattr(training_args, "fp8_backend", "auto") == "te":
@@ -66,7 +65,7 @@ class CustomTrainer(Trainer):
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator) self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
self.add_callback(BAdamCallback) self.add_callback(BAdamCallback)
if training_args.fp8 and hasattr(self, "accelerator"): # verify FP8 status after trainer initialization if training_args.fp8 and hasattr(self, "accelerator"): # verify FP8 status after trainer initialization
verify_fp8_status(self.accelerator, training_args) verify_fp8_status(self.accelerator, training_args)
@override @override

View File

@@ -27,7 +27,6 @@ from typing_extensions import override
from ...extras import logging from ...extras import logging
from ...extras.constants import IGNORE_INDEX from ...extras.constants import IGNORE_INDEX
from ...extras.packages import is_transformers_version_greater_than
from ..callbacks import SaveProcessorCallback from ..callbacks import SaveProcessorCallback
from ..fp8_utils import configure_fp8_environment, patch_accelerator_for_fp8, verify_fp8_status from ..fp8_utils import configure_fp8_environment, patch_accelerator_for_fp8, verify_fp8_status
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
@@ -35,10 +34,10 @@ from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
if TYPE_CHECKING: if TYPE_CHECKING:
from torch.utils.data import Dataset from torch.utils.data import Dataset
from transformers import PreTrainedTokenizer, ProcessorMixin from transformers import ProcessorMixin
from transformers.trainer import PredictionOutput from transformers.trainer import PredictionOutput
from ...hparams import FinetuningArguments, ModelArguments from ...hparams import FinetuningArguments, ModelArguments, TrainingArguments
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
@@ -57,7 +56,7 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
) -> None: ) -> None:
kwargs["processing_class"] = kwargs.pop("tokenizer") kwargs["processing_class"] = kwargs.pop("tokenizer")
# Configure FP8 environment if enabled # Configure FP8 environment if enabled
training_args = kwargs.get("args") training_args: TrainingArguments = kwargs.get("args")
if training_args.fp8: if training_args.fp8:
configure_fp8_environment(training_args) configure_fp8_environment(training_args)
if getattr(training_args, "fp8_backend", "auto") == "te": if getattr(training_args, "fp8_backend", "auto") == "te":
@@ -88,7 +87,7 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
self.compute_loss_func = dft_loss_func self.compute_loss_func = dft_loss_func
if training_args.fp8 and hasattr(self, "accelerator"): # verify FP8 status after trainer initialization if training_args.fp8 and hasattr(self, "accelerator"): # verify FP8 status after trainer initialization
verify_fp8_status(self.accelerator, training_args) verify_fp8_status(self.accelerator, training_args)
@override @override

View File

@@ -34,10 +34,14 @@ from typing import Any, Optional
from torch.distributed import barrier, destroy_process_group, init_process_group from torch.distributed import barrier, destroy_process_group, init_process_group
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
from ..utils import logging
from ..utils.types import DistributedConfig, ProcessGroup, Tensor, TensorLike from ..utils.types import DistributedConfig, ProcessGroup, Tensor, TensorLike
from . import helper from . import helper
logger = logging.get_logger(__name__)
class Dim(str, Enum): class Dim(str, Enum):
"""Dimension names.""" """Dimension names."""
@@ -157,6 +161,7 @@ class DistributedInterface:
self.data_device_mesh = None self.data_device_mesh = None
self._initialized = True self._initialized = True
logger.info_rank0(f"DistributedInterface initialized with strategy={self.strategy}.")
def __str__(self) -> str: def __str__(self) -> str:
return ( return (

View File

@@ -0,0 +1,32 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .arg_parser import InputArgument, get_args
from .arg_utils import ModelClass, SampleBackend
from .data_args import DataArguments
from .model_args import ModelArguments
from .sample_args import SampleArguments
from .training_args import TrainingArguments
__all__ = [
"DataArguments",
"InputArgument",
"ModelArguments",
"ModelClass",
"SampleArguments",
"SampleBackend",
"TrainingArguments",
"get_args",
]

View File

@@ -27,14 +27,14 @@ class ModelArguments:
default=False, default=False,
metadata={"help": "Trust remote code from Hugging Face."}, metadata={"help": "Trust remote code from Hugging Face."},
) )
use_fast_processor: bool = field(
default=True,
metadata={"help": "Use fast processor from Hugging Face."},
)
model_class: ModelClass = field( model_class: ModelClass = field(
default=ModelClass.LLM, default=ModelClass.LLM,
metadata={"help": "Model class from Hugging Face."}, metadata={"help": "Model class from Hugging Face."},
) )
init_config: PluginConfig | None = field(
default=None,
metadata={"help": "Initialization configuration for the model."},
)
peft_config: PluginConfig | None = field( peft_config: PluginConfig | None = field(
default=None, default=None,
metadata={"help": "PEFT configuration for the model."}, metadata={"help": "PEFT configuration for the model."},
@@ -49,6 +49,7 @@ class ModelArguments:
) )
def __post_init__(self) -> None: def __post_init__(self) -> None:
self.init_config = get_plugin_config(self.init_config)
self.peft_config = get_plugin_config(self.peft_config) self.peft_config = get_plugin_config(self.peft_config)
self.kernel_config = get_plugin_config(self.kernel_config) self.kernel_config = get_plugin_config(self.kernel_config)
self.quant_config = get_plugin_config(self.quant_config) self.quant_config = get_plugin_config(self.quant_config)

View File

@@ -22,7 +22,7 @@ from .arg_utils import PluginConfig, get_plugin_config
@dataclass @dataclass
class TrainingArguments: class TrainingArguments:
output_dir: str = field( output_dir: str = field(
default=os.path.join("outputs", str(uuid4())), default=os.path.join("outputs", str(uuid4().hex)),
metadata={"help": "Path to the output directory."}, metadata={"help": "Path to the output directory."},
) )
micro_batch_size: int = field( micro_batch_size: int = field(
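
The `output_dir` change above swaps the hyphenated UUID string for its compact hex form. The sketch below illustrates the difference, and also shows that a `default=` expression in a dataclass field is evaluated once when the class is created. `EagerArgs` and `LazyArgs` are hypothetical classes for illustration only; the `default_factory` variant is an alternative pattern, not something this diff introduces.

```python
import os
from dataclasses import dataclass, field
from uuid import uuid4


@dataclass
class EagerArgs:
    # Same pattern as the diff: the default is computed once, at class creation.
    output_dir: str = field(default=os.path.join("outputs", uuid4().hex))


@dataclass
class LazyArgs:
    # Hypothetical alternative: default_factory runs on every instantiation.
    output_dir: str = field(default_factory=lambda: os.path.join("outputs", uuid4().hex))


hyphenated = str(uuid4())  # 36 characters, including four hyphens
compact = uuid4().hex      # 32 hexadecimal characters, no hyphens
```

With the eager form, every instance created in one process shares the same default directory; with the factory form, each instance gets a fresh one. Which behavior is wanted depends on the caller.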

View File

@@ -0,0 +1,77 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod
from ..config import ModelArguments, SampleArguments, SampleBackend
from ..utils.types import HFModel, Processor, TorchDataset
class BaseEngine(ABC):
@abstractmethod
def __init__(
self,
args: SampleArguments,
model_args: ModelArguments,
model: HFModel = None,
processor: Processor = None,
) -> None:
"""Initialize the engine.
Args:
args: Sample arguments.
model_args: Model arguments.
model: Model.
processor: Processor.
"""
...
@abstractmethod
async def generate(self, messages):
pass
@abstractmethod
async def batch_infer(self, data: TorchDataset) -> None:
pass
class HuggingFaceEngine(BaseEngine):
def __init__(
self,
args: SampleArguments,
model_args: ModelArguments,
model: HFModel,
processor: Processor,
) -> None:
self.args = args
class BaseSampler:
def __init__(
self,
args: SampleArguments,
model_args: ModelArguments,
model: HFModel,
processor: Processor,
) -> None:
if args.sample_backend == SampleBackend.HF:
self.engine = HuggingFaceEngine(args, model_args, model, processor)
else:
raise ValueError(f"Unknown sample backend: {args.sample_backend}")
async def generate(self, messages):
return await self.engine.generate(messages)
async def batch_infer(self, data: TorchDataset) -> None:
return await self.engine.batch_infer(data)
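
The sampler above delegates to an engine whose `generate` and `batch_infer` methods are abstract. As written in this diff, `HuggingFaceEngine` overrides only `__init__`, so Python's `abc` machinery would refuse to instantiate it until those methods are implemented. The following is a minimal sketch of that behavior and of the delegation pattern. All class names here (`SketchEngine`, `StubOnlyEngine`, `EchoEngine`, `SketchSampler`) are hypothetical and not part of the repository.

```python
import asyncio
from abc import ABC, abstractmethod


class SketchEngine(ABC):
    @abstractmethod
    async def generate(self, messages: list[dict[str, str]]) -> str: ...


class StubOnlyEngine(SketchEngine):
    """Overrides nothing abstract, like an engine that defines only __init__."""


class EchoEngine(SketchEngine):
    """Hypothetical concrete engine that echoes the last message."""

    async def generate(self, messages: list[dict[str, str]]) -> str:
        await asyncio.sleep(0)  # yield control, as a real backend call would
        return messages[-1]["content"]


class SketchSampler:
    def __init__(self, engine: SketchEngine) -> None:
        self.engine = engine

    async def generate(self, messages: list[dict[str, str]]) -> str:
        # Pure delegation, mirroring the sampler in the diff.
        return await self.engine.generate(messages)


reply = asyncio.run(
    SketchSampler(EchoEngine()).generate([{"role": "user", "content": "ping"}])
)
```

Calling `StubOnlyEngine()` raises `TypeError`, which is the failure a caller would hit with any engine subclass that leaves an abstract method unimplemented.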

View File

@@ -1,44 +0,0 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod
from ..config.sample_args import SampleArguments, SampleBackend
from .model_loader import ModelLoader
class BaseEngine(ABC):
@abstractmethod
def __init__(self, sample_args: SampleArguments, model_loader: ModelLoader) -> None: ...
@abstractmethod
async def generate(self):
pass
@abstractmethod
async def batch_infer(self):
pass
class HuggingFaceEngine(BaseEngine):
def __init__(self, model_loader: ModelLoader, sample_args: SampleArguments) -> None:
self.args = sample_args
class ChatSampler:
def __init__(self, model_loader: ModelLoader, sample_args: SampleArguments) -> None:
if sample_args.sample_backend == SampleBackend.HF:
self.engine = HuggingFaceEngine(model_loader, sample_args)
else:
raise ValueError(f"Unknown sample backend: {sample_args.sample_backend}")

View File

@@ -14,17 +14,24 @@
"""The definition of model loader. """The definition of model loader.
Init Phase: How to use:
model_loader = ModelLoader(model_args, is_trainable=True)
model_loader.processor: Get the tokenizer or multi-modal processor.
model_loader.model_config: Get the model configuration.
model_loader.model: Get the HF model.
Init Workflow:
1. Init processor. 1. Init processor.
2. Init model config. 2. Init model config.
3. Init model. 3. Init model.
4. Init adapter. 4. Init adapter.
""" """
import torch import torch
from accelerate import init_empty_weights
from transformers import AutoConfig, AutoProcessor from transformers import AutoConfig, AutoProcessor
from ..accelerator.helper import DeviceType
from ..accelerator.interface import DistributedInterface from ..accelerator.interface import DistributedInterface
from ..config.model_args import ModelArguments, ModelClass from ..config.model_args import ModelArguments, ModelClass
from ..utils import logging from ..utils import logging
@@ -55,11 +62,14 @@ class ModelLoader:
"""HF model.""" """HF model."""
def _init_processor(self) -> Processor: def _init_processor(self) -> Processor:
"""Init processor.""" """Init processor.
NOTE: Transformers v5 always use fast tokenizer.
https://github.com/huggingface/transformers/blob/v5.0.0rc1/src/transformers/models/auto/tokenization_auto.py#L642
"""
return AutoProcessor.from_pretrained( return AutoProcessor.from_pretrained(
self.args.model, self.args.model,
trust_remote_code=self.args.trust_remote_code, trust_remote_code=self.args.trust_remote_code,
use_fast=self.args.use_fast_processor,
) )
def _init_model_config(self) -> HFConfig: def _init_model_config(self) -> HFConfig:
@@ -92,14 +102,24 @@ class ModelLoader:
             AutoClass = AutoModel
-        # map the entire model to the current accelerator
-        model = AutoClass.from_pretrained(
-            self.args.model,
-            config=self.model_config,
-            dtype="auto",
-            device_map=DistributedInterface().current_accelerator,
-            trust_remote_code=self.args.trust_remote_code,
-        )
+        if self.args.init_config is not None:
+            from ..plugins.model_plugins.initialization import InitPlugin
+
+            init_device = InitPlugin(self.args.init_config.name)()
+        else:
+            init_device = DistributedInterface().current_accelerator
+
+        if init_device.type == DeviceType.META:
+            with init_empty_weights():
+                model = AutoClass.from_config(self.model_config)
+        else:
+            model = AutoClass.from_pretrained(
+                self.args.model,
+                config=self.model_config,
+                dtype="auto",
+                device_map=init_device,
+                trust_remote_code=self.args.trust_remote_code,
+            )
         if self.args.peft_config is None:
             if self.is_train:


@@ -0,0 +1,43 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch

from ...accelerator.helper import DeviceType
from ...accelerator.interface import DistributedInterface
from ...utils.plugin import BasePlugin


class InitPlugin(BasePlugin):
    def __call__(self) -> torch.device:
        return super().__call__()


@InitPlugin("init_on_meta").register
def init_on_meta() -> torch.device:
    return torch.device(DeviceType.META.value)


@InitPlugin("init_on_rank0").register
def init_on_rank0() -> torch.device:
    if DistributedInterface().get_rank() == 0:
        return torch.device(DeviceType.CPU.value)
    else:
        return torch.device(DeviceType.META.value)


@InitPlugin("init_on_default").register
def init_on_default() -> torch.device:
    return DistributedInterface().current_accelerator
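
The file above relies on two usage patterns: `InitPlugin("name").register` as a decorator, and `InitPlugin("name")()` to run the registered function. `BasePlugin` itself is not shown in this diff, so the following is only a minimal sketch of one way a name-to-callable registry could support both patterns. `SketchPlugin` and `SketchInitPlugin` are hypothetical, the functions return plain strings instead of `torch.device`, and the rank is passed as an argument rather than read from `DistributedInterface`.

```python
from typing import Callable, Dict


class SketchPlugin:
    """Hypothetical stand-in for BasePlugin: each subclass keeps its own registry."""

    _registry: Dict[str, Callable] = {}

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        cls._registry = {}  # fresh table so subclasses do not share entries

    def __init__(self, name: str):
        self.name = name

    def register(self, func: Callable) -> Callable:
        """Store `func` under this plugin's name; return it so it works as a decorator."""
        type(self)._registry[self.name] = func
        return func

    def __call__(self, *args, **kwargs):
        """Look up the function registered under this name and call it."""
        try:
            func = type(self)._registry[self.name]
        except KeyError:
            raise ValueError(f"Plugin {self.name!r} is not registered.") from None
        return func(*args, **kwargs)


class SketchInitPlugin(SketchPlugin):
    pass


@SketchInitPlugin("init_on_meta").register
def sketch_init_on_meta() -> str:
    return "meta"


@SketchInitPlugin("init_on_rank0").register
def sketch_init_on_rank0(rank: int = 0) -> str:
    # Only rank 0 holds real weights (on CPU); other ranks keep an empty meta model.
    return "cpu" if rank == 0 else "meta"
```

Under these assumptions, `SketchInitPlugin("init_on_rank0")(rank=1)` resolves to the meta device, which matches the branch in `init_on_rank0` above for non-zero ranks.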


@@ -0,0 +1,35 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ..config import InputArgument, SampleBackend, get_args
from ..core.base_sampler import BaseSampler
from ..core.model_loader import ModelLoader


def run_chat(args: InputArgument = None):
    data_args, model_args, _, sample_args = get_args(args)
    if sample_args.sample_backend != SampleBackend.HF:
        model_args.init_plugin = {"name": "init_on_meta"}

    model_loader = ModelLoader(model_args)
    sampler = BaseSampler(sample_args, model_args, model_loader.model, model_loader.processor)
    if data_args.dataset is not None:
        sampler.batch_infer()
    else:
        sampler.generate()


if __name__ == "__main__":
    run_chat()


@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""LLaMA-Factory test configuration.
+"""LlamaFactory test configuration.
 Contains shared fixtures, pytest configuration, and custom markers.
 """
@@ -110,11 +110,10 @@ def _handle_device_visibility(items: list[Item]):
 def pytest_collection_modifyitems(config: Config, items: list[Item]):
     """Modify test collection based on markers and environment."""
     # Handle version compatibility (from HEAD)
-    if not is_transformers_version_greater_than("4.57.0"):
-        skip_bc = pytest.mark.skip(reason="Skip backward compatibility tests")
-        for item in items:
-            if "tests_v1" in str(item.fspath):
-                item.add_marker(skip_bc)
+    skip_bc = pytest.mark.skip(reason="Skip backward compatibility tests")
+    for item in items:
+        if "tests_v1" in str(item.fspath) and not is_transformers_version_greater_than("4.57.0"):
+            item.add_marker(skip_bc)
     _handle_slow_tests(items)
     _handle_runs_on(items)
@@ -156,6 +155,7 @@ def _manage_distributed_env(request: FixtureRequest, monkeypatch: MonkeyPatch) -
             monkeypatch.setenv(env_key, visible_devices[0] if visible_devices else "0")
         else:
             monkeypatch.setenv(env_key, "0")
+
         if CURRENT_DEVICE == "cuda":
             monkeypatch.setattr(torch.cuda, "device_count", lambda: 1)
         elif CURRENT_DEVICE == "npu":


@@ -24,7 +24,6 @@ def test_get_args_from_yaml(tmp_path: pathlib.Path):
 ### model
 model: "llamafactory/tiny-random-qwen2.5"
 trust_remote_code: true
-use_fast_processor: true
 model_class: "llm"
 kernel_config:
   name: "auto"


@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""LLaMA-Factory test configuration.
+"""LlamaFactory test configuration.
 Contains shared fixtures, pytest configuration, and custom markers.
 """
@@ -22,6 +22,7 @@ import sys
 import pytest
 import torch
+import torch.distributed as dist
 from pytest import Config, FixtureRequest, Item, MonkeyPatch
 from llamafactory.v1.accelerator.helper import get_current_accelerator, get_device_count
@@ -109,17 +110,24 @@ def _handle_device_visibility(items: list[Item]):
 def pytest_collection_modifyitems(config: Config, items: list[Item]):
     """Modify test collection based on markers and environment."""
     # Handle version compatibility (from HEAD)
-    if not is_transformers_version_greater_than("4.57.0"):
-        skip_bc = pytest.mark.skip(reason="Skip backward compatibility tests")
-        for item in items:
-            if "tests_v1" in str(item.fspath):
-                item.add_marker(skip_bc)
+    skip_bc = pytest.mark.skip(reason="Skip backward compatibility tests")
+    for item in items:
+        if "tests_v1" in str(item.fspath) and not is_transformers_version_greater_than("4.57.0"):
+            item.add_marker(skip_bc)
     _handle_slow_tests(items)
     _handle_runs_on(items)
     _handle_device_visibility(items)
+@pytest.fixture(autouse=True)
+def _cleanup_distributed_state():
+    """Clean up distributed state after each test."""
+    yield
+    if dist.is_initialized():
+        dist.destroy_process_group()
+
+
 @pytest.fixture(autouse=True)
 def _manage_distributed_env(request: FixtureRequest, monkeypatch: MonkeyPatch) -> None:
     """Set environment variables for distributed tests if specific devices are requested."""
@@ -155,6 +163,7 @@ def _manage_distributed_env(request: FixtureRequest, monkeypatch: MonkeyPatch) -
             monkeypatch.setenv(env_key, visible_devices[0] if visible_devices else "0")
         else:
             monkeypatch.setenv(env_key, "0")
+
         if CURRENT_DEVICE == "cuda":
             monkeypatch.setattr(torch.cuda, "device_count", lambda: 1)
         elif CURRENT_DEVICE == "npu":
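
The `_cleanup_distributed_state` fixture added above uses the yield-then-teardown shape: everything after `yield` runs once the test has finished, so a process group left initialized by one test does not leak into the next. The sketch below reproduces that shape with `contextlib.contextmanager` from the standard library so it runs without pytest or torch. `FakeProcessGroup` and `cleanup_distributed_state` are hypothetical stand-ins, not the real `torch.distributed` API. Unlike a pytest fixture, a context manager needs an explicit `try`/`finally` for the teardown to run when the body raises.

```python
from contextlib import contextmanager


class FakeProcessGroup:
    """Hypothetical stand-in for the global state behind torch.distributed."""

    def __init__(self) -> None:
        self._initialized = False

    def init(self) -> None:
        self._initialized = True

    def is_initialized(self) -> bool:
        return self._initialized

    def destroy(self) -> None:
        self._initialized = False


@contextmanager
def cleanup_distributed_state(group: FakeProcessGroup):
    """Run the body, then tear the group down only if the body initialized it."""
    try:
        yield
    finally:
        if group.is_initialized():
            group.destroy()


group = FakeProcessGroup()
with cleanup_distributed_state(group):
    group.init()  # the "test" initializes distributed state
    state_inside = group.is_initialized()
state_after = group.is_initialized()
```

The guard on `is_initialized()` keeps the teardown safe for tests that never touch distributed state.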


@@ -0,0 +1,56 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest

from llamafactory.v1.accelerator.interface import DistributedInterface
from llamafactory.v1.config.arg_parser import get_args
from llamafactory.v1.core.model_loader import ModelLoader


def test_init_on_meta():
    _, model_args, *_ = get_args(
        dict(
            model="llamafactory/tiny-random-qwen2.5",
            init_config={"name": "init_on_meta"},
        )
    )
    model_loader = ModelLoader(model_args=model_args)
    assert model_loader.model.device.type == "meta"


@pytest.mark.runs_on(["cuda", "npu"])
def test_init_on_rank0():
    _, model_args, *_ = get_args(
        dict(
            model="llamafactory/tiny-random-qwen2.5",
            init_config={"name": "init_on_rank0"},
        )
    )
    model_loader = ModelLoader(model_args=model_args)
    if DistributedInterface().get_rank() == 0:
        assert model_loader.model.device.type == "cpu"
    else:
        assert model_loader.model.device.type == "meta"


def test_init_on_default():
    _, model_args, *_ = get_args(
        dict(
            model="llamafactory/tiny-random-qwen2.5",
            init_config={"name": "init_on_default"},
        )
    )
    model_loader = ModelLoader(model_args=model_args)
    assert model_loader.model.device.type == DistributedInterface().current_accelerator.type