14 Commits

Author SHA1 Message Date
Vo Van Phuc
b4e051bea4 [model] support for LiquidAI's LFM2.5 (Liquid Foundation Models) (#9726) 2026-01-07 14:14:47 +08:00
浮梦
d43e1007e8 [ci] improve cuda ci cache (#9725)
Co-authored-by: frozenleaves <frozen@Mac.local>
2026-01-07 12:34:40 +08:00
Xunpeng Xiao
f89d9367e5 [assets] update README.md (#9724) 2026-01-07 12:11:50 +08:00
Yaowei Zheng
d22de0d4bf [v1] add renderer ut (#9722) 2026-01-07 02:06:07 +08:00
Yaowei Zheng
ea0b4e2466 [v1] add cli sampler (#9721) 2026-01-06 23:31:27 +08:00
yanglele
e944dc442c [feature] add support for EAFT loss (#9720)
Co-authored-by: Yaowei Zheng <hiyouga@buaa.edu.cn>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2026-01-06 23:07:12 +08:00
Xunpeng Xiao
68119e5522 [misc] Add a PyTorch version warning for Conv3D. (#9715) 2026-01-05 13:26:29 +08:00
Yaowei Zheng
f60a6e3d01 [v1] add init plugin (#9716) 2026-01-04 20:51:46 +08:00
jiaqiw09
81b8a50aa5 [deps] Update pyproject.toml and requirements (#9714)
Co-authored-by: Yaowei Zheng <hiyouga@buaa.edu.cn>
2026-01-04 19:52:16 +08:00
Yaowei Zheng
8600530002 [misc] lint (#9710) 2026-01-04 13:47:56 +08:00
Hertz
9ae62c6fc0 [model] support Youtu-LLM-2B (#9707) 2026-01-04 13:17:57 +08:00
Xunpeng Xiao
0087bc253b [misc] Compatible with an empty architectures field in config.json (#9709) 2026-01-04 12:11:35 +08:00
Santosh Bhavani
355d5c5e5a [fix] fp8: add Transformer Engine backend support (#9705)
Co-authored-by: Yaowei Zheng <hiyouga@buaa.edu.cn>
2026-01-01 10:18:02 +08:00
Yaowei Zheng
6fe6bd290b [misc] set dev version (#9703) 2025-12-31 23:41:40 +08:00
98 changed files with 2174 additions and 761 deletions

View File

@@ -70,7 +70,8 @@ jobs:
run: |
uv venv
uv pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
uv pip install -e ".[dev]"
uv pip install -e .
uv pip install -r requirements/dev.txt
- name: Install transformers
if: ${{ matrix.transformers }}

View File

@@ -35,6 +35,11 @@ jobs:
group: ${{ github.workflow }}-${{ github.ref }}-${{ matrix.os }}-${{ matrix.python }}
cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
env:
HF_HOME: "${{ github.workspace }}/../.runner_cache/huggingface"
UV_CACHE_DIR: "${{ github.workspace }}/../.runner_cache/uv"
UV_NO_SYNC: 1
steps:
- name: Checkout
uses: actions/checkout@v4
@@ -52,37 +57,21 @@ jobs:
- name: Install dependencies
run: |
uv venv
uv pip install -e ".[dev]"
- name: Cache HuggingFace models
id: hf-hub-cache
uses: actions/cache@v4
with:
path: ${{ runner.temp }}/huggingface
key: hf-cache-${{ runner.os }}-${{ hashFiles('tests/version.txt') }}
uv pip install -e .
uv pip install -r requirements/dev.txt
- name: Check quality
run: |
make style && make quality
env:
UV_NO_SYNC: 1
- name: Check license
run: |
make license
env:
UV_NO_SYNC: 1
- name: Check build
run: |
make build
env:
UV_NO_SYNC: 1
- name: Test with pytest
run: |
make test
env:
UV_NO_SYNC: 1
HF_HOME: ${{ runner.temp }}/huggingface
HF_HUB_OFFLINE: "${{ steps.hf-hub-cache.outputs.cache-hit == 'true' && '1' || '0' }}"

View File

@@ -58,8 +58,9 @@ jobs:
- name: Install dependencies
run: |
uv venv
uv pip install torch-npu==${{matrix.pytorch_npu}}
uv pip install -e ".[dev]"
uv pip install -r requirements/npu.txt
uv pip install -e .
uv pip install -r requirements/dev.txt
- name: Install node
run: |

View File

@@ -329,6 +329,7 @@ Read technical notes:
| [StarCoder 2](https://huggingface.co/bigcode) | 3B/7B/15B | - |
| [VibeThinker-1.5B](https://huggingface.co/WeiboAI) | 1.5B | qwen3 |
| [Yi/Yi-1.5 (Code)](https://huggingface.co/01-ai) | 1.5B/6B/9B/34B | yi |
| [Youtu-LLM](https://huggingface.co/tencent/) | 2B | youtu |
| [Yuan 2](https://huggingface.co/IEITYuan) | 2B/51B/102B | yuan |
> [!NOTE]
@@ -514,12 +515,13 @@ huggingface-cli login
#### Install from Source
```bash
git clone --depth 1 https://github.com/hiyouga/LLaMA-Factory.git
cd LLaMA-Factory
pip install -e ".[metrics]"
git clone --depth 1 https://github.com/hiyouga/LlamaFactory.git
cd LlamaFactory
pip install -e .
pip install -r requirements/metrics.txt
```
Optional dependencies available: `metrics`, `deepspeed`. Install with: `pip install -e ".[metrics,deepspeed]"`
Optional dependencies available: `metrics`, `deepspeed`. Install with: `pip install -e . && pip install -r requirements/metrics.txt -r requirements/deepspeed.txt`
Additional dependencies for specific features are available in `examples/requirements/`.
@@ -577,36 +579,21 @@ To enable FlashAttention-2 on the Windows platform, please use the script from [
<details><summary>For Ascend NPU users</summary>
To install LLaMA Factory on Ascend NPU devices, please upgrade Python to version 3.10 or higher: `pip install -e . torch-npu==2.7.1`. Additionally, you need to install the **[Ascend CANN Toolkit and Kernels](https://www.hiascend.com/developer/download/community/result?module=cann)**. Please follow the [installation tutorial](https://www.hiascend.com/document/detail/en/CANNCommunityEdition/600alphaX/softwareinstall/instg/atlasdeploy_03_0031.html) or use the following commands:
To install LLaMA Factory on Ascend NPU devices, please upgrade Python to version 3.10 or higher: `pip install -r requirements/npu.txt`. Additionally, you need to install the **Ascend CANN Toolkit and Kernels**. Please follow the [installation tutorial](https://llamafactory.readthedocs.io/en/latest/advanced/npu_installation.html).
You can also download the pre-built Docker images:
```bash
# replace the url according to your CANN version and devices
# install CANN Toolkit
wget https://ascend-repo.obs.cn-east-2.myhuaweicloud.com/Milan-ASL/Milan-ASL%20V100R001C20SPC702/Ascend-cann-toolkit_8.0.0.alpha002_linux-"$(uname -i)".run
bash Ascend-cann-toolkit_8.0.0.alpha002_linux-"$(uname -i)".run --install
# Docker Hub
docker pull hiyouga/llamafactory:latest-npu-a2
docker pull hiyouga/llamafactory:latest-npu-a3
# install CANN Kernels
wget https://ascend-repo.obs.cn-east-2.myhuaweicloud.com/Milan-ASL/Milan-ASL%20V100R001C20SPC702/Ascend-cann-kernels-910b_8.0.0.alpha002_linux-"$(uname -i)".run
bash Ascend-cann-kernels-910b_8.0.0.alpha002_linux-"$(uname -i)".run --install
# set env variables
source /usr/local/Ascend/ascend-toolkit/set_env.sh
# quay.io
docker pull quay.io/ascend/llamafactory:latest-npu-a2
docker pull quay.io/ascend/llamafactory:latest-npu-a3
```
| Requirement | Minimum | Recommend |
| ------------ | ------- | -------------- |
| CANN | 8.0.RC1 | 8.0.0.alpha002 |
| torch | 2.1.0 | 2.7.1 |
| torch-npu | 2.1.0 | 2.7.1 |
| deepspeed | 0.13.2 | 0.13.2 |
| vllm-ascend | - | 0.7.3 |
Remember to use `ASCEND_RT_VISIBLE_DEVICES` instead of `CUDA_VISIBLE_DEVICES` to specify the device to use.
If you cannot infer model on NPU devices, try setting `do_sample: false` in the configurations.
Download the pre-built Docker images: [32GB](http://mirrors.cn-central-221.ovaijisuan.com/detail/130.html) | [64GB](http://mirrors.cn-central-221.ovaijisuan.com/detail/131.html)
#### Install BitsAndBytes
To use QLoRA based on bitsandbytes on Ascend NPU, please follow these 3 steps:

View File

@@ -331,6 +331,7 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
| [StarCoder 2](https://huggingface.co/bigcode) | 3B/7B/15B | - |
| [VibeThinker-1.5B](https://huggingface.co/WeiboAI) | 1.5B | qwen3 |
| [Yi/Yi-1.5 (Code)](https://huggingface.co/01-ai) | 1.5B/6B/9B/34B | yi |
| [Youtu-LLM](https://huggingface.co/tencent/) | 2B | youtu |
| [Yuan 2](https://huggingface.co/IEITYuan) | 2B/51B/102B | yuan |
> [!NOTE]
@@ -516,12 +517,13 @@ huggingface-cli login
#### 从源码安装
```bash
git clone --depth 1 https://github.com/hiyouga/LLaMA-Factory.git
cd LLaMA-Factory
pip install -e ".[metrics]"
git clone --depth 1 https://github.com/hiyouga/LlamaFactory.git
cd LlamaFactory
pip install -e .
pip install -r requirements/metrics.txt
```
可选的额外依赖项:`metrics``deepspeed`。使用 `pip install -e ".[metrics,deepspeed]"` 安装。
可选的额外依赖项:`metrics``deepspeed`。使用 `pip install -e . && pip install -r requirements/metrics.txt -r requirements/deepspeed.txt` 安装。
其他可选依赖项请参考 `examples/requirements/` 目录下的文件。
@@ -579,36 +581,20 @@ pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/downl
<details><summary>昇腾 NPU 用户指南</summary>
在昇腾 NPU 设备上安装 LLaMA Factory 时,请升级 Python 到 3.10 及以上,并需要指定额外依赖项,使用 `pip install -e . torch-npu==2.7.1` 命令安装。此外,还需要安装 **[Ascend CANN Toolkit 与 Kernels](https://www.hiascend.com/developer/download/community/result?module=cann)**,安装方法请参考[安装教程](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/80RC2alpha002/quickstart/quickstart/quickstart_18_0004.html)或使用以下命令:
在昇腾 NPU 设备上安装 LLaMA Factory 时,请升级 Python 到 3.10 及以上,并需要指定额外依赖项,使用 `pip install -r requirements/npu.txt` 命令安装。此外,还需要安装 **Ascend CANN Toolkit 与 Kernels**,安装方法请参考[安装教程](https://llamafactory.readthedocs.io/zh-cn/latest/advanced/npu_installation.html)。
您可以直接下载预安装的最新docker镜像
```bash
# 请替换 URL 为 CANN 版本和设备型号对应的 URL
# 安装 CANN Toolkit
wget https://ascend-repo.obs.cn-east-2.myhuaweicloud.com/Milan-ASL/Milan-ASL%20V100R001C17SPC701/Ascend-cann-toolkit_8.0.RC1.alpha001_linux-"$(uname -i)".run
bash Ascend-cann-toolkit_8.0.RC1.alpha001_linux-"$(uname -i)".run --install
# Docker Hub
docker pull hiyouga/llamafactory:latest-npu-a2
docker pull hiyouga/llamafactory:latest-npu-a3
# 安装 CANN Kernels
wget https://ascend-repo.obs.cn-east-2.myhuaweicloud.com/Milan-ASL/Milan-ASL%20V100R001C17SPC701/Ascend-cann-kernels-910b_8.0.RC1.alpha001_linux.run
bash Ascend-cann-kernels-910b_8.0.RC1.alpha001_linux.run --install
# 设置环境变量
source /usr/local/Ascend/ascend-toolkit/set_env.sh
# quay.io
docker pull quay.io/ascend/llamafactory:latest-npu-a2
docker pull quay.io/ascend/llamafactory:latest-npu-a3
```
| 依赖项 | 至少 | 推荐 |
| ------------ | ------- | -------------- |
| CANN | 8.0.RC1 | 8.0.0.alpha002 |
| torch | 2.1.0 | 2.7.1 |
| torch-npu | 2.1.0 | 2.7.1 |
| deepspeed | 0.13.2 | 0.13.2 |
| vllm-ascend | - | 0.7.3 |
请使用 `ASCEND_RT_VISIBLE_DEVICES` 而非 `CUDA_VISIBLE_DEVICES` 来指定运算设备。
如果遇到无法正常推理的情况,请尝试设置 `do_sample: false`
下载预构建 Docker 镜像:[32GB](http://mirrors.cn-central-221.ovaijisuan.com/detail/130.html) | [64GB](http://mirrors.cn-central-221.ovaijisuan.com/detail/131.html)
#### 安装 BitsAndBytes
如果要在 Ascend NPU 上进行基于 bitsandbytes 的 QLoRA 量化微调,请执行如下步骤:

File diff suppressed because one or more lines are too long

View File

@@ -32,7 +32,8 @@ RUN pip config set global.index-url "${PIP_INDEX}" && \
COPY . /app
# Install LLaMA Factory
RUN pip install --no-cache-dir --no-build-isolation -e ".[metrics,deepspeed]"
RUN pip install --no-cache-dir --no-build-isolation -e . && \
pip install --no-cache-dir --no-build-isolation -r requirements/metrics.txt -r requirements/deepspeed.txt
# Rebuild flash attention
RUN if [ "${INSTALL_FLASHATTN}" == "true" ]; then \

View File

@@ -60,7 +60,8 @@ WORKDIR /app
COPY . /app
# Install LLaMA Factory
RUN pip install --no-cache-dir -e ".[metrics]" --no-build-isolation
RUN pip install --no-cache-dir -e . --no-build-isolation && \
pip install --no-cache-dir -r requirements/metrics.txt --no-build-isolation
RUN pip install "git+https://github.com/alibaba/roll.git#subdirectory=mcore_adapter"

View File

@@ -35,7 +35,8 @@ COPY . /app
# Install torch-npu
RUN pip uninstall -y torch torchvision torchaudio && \
pip install --no-cache-dir "torch==2.7.1" "torch-npu==2.7.1" "torchvision==0.22.1" "torchaudio==2.7.1" --index-url "${PYTORCH_INDEX}" && \
pip install --no-cache-dir -e ".[metrics]" --no-build-isolation
pip install --no-cache-dir -e . --no-build-isolation && \
pip install --no-cache-dir -r requirements/metrics.txt --no-build-isolation
# Set up volumes
# VOLUME [ "/root/.cache/huggingface", "/app/shared_data", "/app/output" ]

View File

@@ -34,7 +34,8 @@ COPY . /app
# Reinstall pytorch rocm and install LLaMA Factory
RUN pip uninstall -y torch torchvision torchaudio && \
pip install --no-cache-dir --no-build-isolation -e --pre ".[metrics,deepspeed]" --index-url "${PYTORCH_INDEX}"
pip install --no-cache-dir --no-build-isolation -e --pre . --index-url "${PYTORCH_INDEX}" && \
pip install --no-cache-dir --no-build-isolation -r requirements/metrics.txt -r requirements/deepspeed.txt --index-url "${PYTORCH_INDEX}"
# Rebuild flash attention
RUN if [ "${INSTALL_FLASHATTN}" == "true" ]; then \

View File

@@ -0,0 +1,40 @@
### model
model_name_or_path: Qwen/Qwen2.5-0.5B-Instruct
trust_remote_code: true
### method
stage: sft
do_train: true
finetuning_type: full
use_eaft_loss: true
### dataset
dataset: identity,alpaca_en_demo
template: qwen
cutoff_len: 2048
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16
dataloader_num_workers: 4
### output
output_dir: qwen2.5-0_5b/full/sft_eaft
logging_steps: 1
save_steps: 500
plot_loss: true
overwrite_output_dir: true
save_only_model: false
report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
### train
per_device_train_batch_size: 2
gradient_accumulation_steps: 8
learning_rate: 1.0e-5
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
ddp_timeout: 180000000

View File

@@ -73,14 +73,9 @@ dependencies = [
# api
"uvicorn",
"fastapi",
"sse-starlette"
"sse-starlette",
]
[project.optional-dependencies]
dev = ["pre-commit", "ruff", "pytest", "build"]
metrics = ["nltk", "jieba", "rouge-chinese"]
deepspeed = ["deepspeed>=0.10.0,<=0.16.9"]
[project.scripts]
llamafactory-cli = "llamafactory.cli:main"
lmf = "llamafactory.cli:main"

View File

@@ -0,0 +1 @@
deepspeed>=0.10.0,<=0.16.9

4
requirements/dev.txt Normal file
View File

@@ -0,0 +1,4 @@
pre-commit
ruff
pytest
build

3
requirements/metrics.txt Normal file
View File

@@ -0,0 +1,3 @@
nltk
jieba
rouge-chinese

4
requirements/npu.txt Normal file
View File

@@ -0,0 +1,4 @@
torch==2.7.1
torch-npu==2.7.1
torchvision==0.22.1
torchaudio==2.7.1

View File

@@ -28,7 +28,7 @@ try:
jieba.setLogLevel(logging.CRITICAL)
jieba.initialize()
except ImportError:
print("Please install llamafactory with `pip install -e .[metrics]`.")
print("Please install llamafactory with `pip install -r requirements/metrics.txt`.")
raise

View File

@@ -1330,6 +1330,26 @@ register_template(
)
register_template(
name="lfm",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="lfm"),
format_observation=StringFormatter(
slots=[
"<|im_start|>tool\n<|tool_response_start|>{{content}}<|tool_response_end|><|im_end|>\n"
"<|im_start|>assistant\n"
]
),
format_tools=ToolFormatter(tool_format="lfm"),
default_system="You are a helpful AI assistant.",
stop_words=["<|im_end|>"],
tool_call_words=("<|tool_call_start|>", "<|tool_call_end|>"),
replace_eos=True,
)
register_template(
name="llama2",
format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]),
@@ -2278,6 +2298,21 @@ register_template(
)
register_template(
name="youtu",
format_user=StringFormatter(slots=["<|User|>{{content}}<|Assistant|>"]),
format_assistant=StringFormatter(slots=["{{content}}<|end_of_text|>"]),
format_system=StringFormatter(slots=["{{content}}"]),
format_function=FunctionFormatter(slots=["{{content}}"], tool_format="default"),
format_observation=StringFormatter(slots=["<tool_response>\n{{content}}\n</tool_response><|Assistant|>"]),
format_tools=ToolFormatter(tool_format="default"),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
stop_words=["<|end_of_text|>"],
replace_eos=True,
template_class=ReasoningTemplate,
)
register_template(
name="yuan",
format_user=StringFormatter(slots=["{{content}}", {"token": "<sep>"}]),

View File

@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import ast
import json
import re
from abc import ABC, abstractmethod
@@ -101,6 +102,8 @@ LING_TOOL_PROMPT = (
""""arguments": <args-json-object>}}\n</tool_call>"""
)
LFM_TOOL_PROMPT = "List of tools: <|tool_list_start|>{tool_text}<|tool_list_end|>"
@dataclass
class ToolUtils(ABC):
@@ -546,10 +549,115 @@ class LingToolUtils(QwenToolUtils):
return LING_TOOL_PROMPT.format(tool_text=tool_text) + "\n" + "detailed thinking off"
class LFMToolUtils(ToolUtils):
r"""LFM 2.5 tool using template with Pythonic function call syntax."""
@override
@staticmethod
def tool_formatter(tools: list[dict[str, Any]]) -> str:
tool_list = []
for tool in tools:
tool = tool.get("function", tool) if tool.get("type") == "function" else tool
tool_list.append(tool)
return LFM_TOOL_PROMPT.format(tool_text=json.dumps(tool_list, ensure_ascii=False))
@override
@staticmethod
def function_formatter(functions: list["FunctionCall"]) -> str:
calls = []
for name, args_json in functions:
args = json.loads(args_json)
kwargs_parts = []
for key, value in args.items():
if isinstance(value, str):
kwargs_parts.append(f'{key}="{value}"')
else:
kwargs_parts.append(f"{key}={json.dumps(value, ensure_ascii=False)}")
calls.append(f"{name}({', '.join(kwargs_parts)})")
return f"<|tool_call_start|>[{', '.join(calls)}]<|tool_call_end|>"
@staticmethod
def _ast_to_value(node: ast.AST) -> Any:
"""Convert an AST node to a Python value, handling JSON-style booleans/null."""
# Handle JSON-style true/false/null as Name nodes
if isinstance(node, ast.Name):
if node.id == "true":
return True
elif node.id == "false":
return False
elif node.id == "null":
return None
else:
raise ValueError(f"Unknown identifier: {node.id}")
# Use literal_eval for other cases (strings, numbers, lists, dicts)
return ast.literal_eval(node)
@override
@staticmethod
def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
# Extract content between tool call markers
start_marker = "<|tool_call_start|>"
end_marker = "<|tool_call_end|>"
start_idx = content.find(start_marker)
if start_idx == -1:
return content
end_idx = content.find(end_marker, start_idx)
if end_idx == -1:
return content
tool_call_str = content[start_idx + len(start_marker) : end_idx].strip()
# Parse Pythonic function call syntax using AST
try:
tree = ast.parse(tool_call_str, mode="eval")
except SyntaxError:
return content
# Handle both single call and list of calls
if isinstance(tree.body, ast.List):
call_nodes = tree.body.elts
elif isinstance(tree.body, ast.Call):
call_nodes = [tree.body]
else:
return content
results = []
for node in call_nodes:
if not isinstance(node, ast.Call):
return content
# Extract function name
if isinstance(node.func, ast.Name):
func_name = node.func.id
else:
return content
# Extract keyword arguments
args_dict = {}
for keyword in node.keywords:
key = keyword.arg
try:
value = LFMToolUtils._ast_to_value(keyword.value)
except (ValueError, SyntaxError):
return content
args_dict[key] = value
results.append(FunctionCall(func_name, json.dumps(args_dict, ensure_ascii=False)))
return results if results else content
TOOLS = {
"default": DefaultToolUtils(),
"glm4": GLM4ToolUtils(),
"llama3": Llama3ToolUtils(),
"lfm": LFMToolUtils(),
"minimax1": MiniMaxM1ToolUtils(),
"minimax2": MiniMaxM2ToolUtils(),
"mistral": MistralToolUtils(),

View File

@@ -1493,6 +1493,19 @@ register_model_group(
)
register_model_group(
models={
"LFM2.5-1.2B": {
DownloadSource.DEFAULT: "LiquidAI/LFM2.5-1.2B-Base",
},
"LFM2.5-1.2B-Instruct": {
DownloadSource.DEFAULT: "LiquidAI/LFM2.5-1.2B-Instruct",
},
},
template="lfm",
)
register_model_group(
models={
"Llama-7B": {
@@ -3846,6 +3859,21 @@ register_model_group(
)
register_model_group(
models={
"Youtu-LLM-2B-Instruct": {
DownloadSource.DEFAULT: "tencent/Youtu-LLM-2B",
DownloadSource.MODELSCOPE: "Tencent-YouTu-Research/Youtu-LLM-2B",
},
"Youtu-LLM-2B-Base": {
DownloadSource.DEFAULT: "tencent/Youtu-LLM-2B-Base",
DownloadSource.MODELSCOPE: "Tencent-YouTu-Research/Youtu-LLM-2B-Base",
},
},
template="youtu",
)
register_model_group(
models={
"Yuan2-2B-Chat": {

View File

@@ -19,7 +19,7 @@
from collections import OrderedDict
VERSION = "0.9.4"
VERSION = "0.9.5.dev0"
def print_env() -> None:

View File

@@ -490,6 +490,14 @@ class FinetuningArguments(
default=False,
metadata={"help": "Whether to use the DFT loss."},
)
use_eaft_loss: bool = field(
default=False,
metadata={"help": "Whether to use the EAFT loss."},
)
eaft_alpha: float = field(
default=1.0,
metadata={"help": "The alpha parameter for EAFT loss to control the power of adaptive weight."},
)
freeze_vision_tower: bool = field(
default=True,
metadata={"help": "Whether ot not to freeze the vision tower in MLLM training."},

View File

@@ -298,23 +298,6 @@ class QuantizationArguments:
default=None,
metadata={"help": "Device map used to infer the 4-bit quantized model, needs bitsandbytes>=0.43.0."},
)
fp8: bool = field(
default=False,
metadata={
"help": "Enable FP8 mixed precision training via HuggingFace Accelerate. "
"Requires PyTorch 2.7+ and Hopper architecture GPUs."
},
)
fp8_backend: str = field(
default="auto",
metadata={
"help": "FP8 backend to use ('auto', 'torchao', 'te', 'msamp'). 'auto' selects best available backend."
},
)
fp8_enable_fsdp_float8_all_gather: bool = field(
default=False,
metadata={"help": "Enable FP8 optimizations for FSDP2 all-gather operations."},
)
@dataclass

View File

@@ -142,14 +142,6 @@ def _verify_model_args(
logger.warning_rank0("We should use slow tokenizer for the Yi models. Change `use_fast_tokenizer` to False.")
model_args.use_fast_tokenizer = False
# Validate advanced training features
if model_args.fp8 and model_args.quantization_bit is not None:
raise ValueError("FP8 training is not compatible with quantization. Please disable one of them.")
if model_args.fp8_enable_fsdp_float8_all_gather and not model_args.fp8:
logger.warning_rank0("fp8_enable_fsdp_float8_all_gather requires fp8=True. Setting fp8=True.")
model_args.fp8 = True
def _check_extra_dependencies(
model_args: "ModelArguments",
@@ -347,6 +339,9 @@ def get_train_args(args: dict[str, Any] | list[str] | None = None) -> _TRAIN_CLS
if training_args.deepspeed is not None and (finetuning_args.use_galore or finetuning_args.use_apollo):
raise ValueError("GaLore and APOLLO are incompatible with DeepSpeed yet.")
if training_args.fp8 and training_args.quantization_bit is not None:
raise ValueError("FP8 training is not compatible with quantization. Please disable one of them.")
if model_args.infer_backend != EngineName.HF:
raise ValueError("vLLM/SGLang backend is only available for API, CLI and Web.")
@@ -363,6 +358,10 @@ def get_train_args(args: dict[str, Any] | list[str] | None = None) -> _TRAIN_CLS
_verify_model_args(model_args, data_args, finetuning_args)
_check_extra_dependencies(model_args, finetuning_args, training_args)
if training_args.fp8_enable_fsdp_float8_all_gather and not training_args.fp8:
logger.warning_rank0("fp8_enable_fsdp_float8_all_gather requires fp8=True. Setting fp8=True.")
model_args.fp8 = True
if (
training_args.do_train
and finetuning_args.finetuning_type == "lora"

View File

@@ -92,7 +92,30 @@ class RayArguments:
@dataclass
class TrainingArguments(RayArguments, BaseTrainingArguments):
class Fp8Arguments:
r"""Arguments pertaining to the FP8 training."""
fp8: bool = field(
default=False,
metadata={
"help": "Enable FP8 mixed precision training via HuggingFace Accelerate. "
"Requires PyTorch 2.7+ and Hopper architecture GPUs."
},
)
fp8_backend: str = field(
default="auto",
metadata={
"help": "FP8 backend to use ('auto', 'torchao', 'te', 'msamp'). 'auto' selects best available backend."
},
)
fp8_enable_fsdp_float8_all_gather: bool = field(
default=False,
metadata={"help": "Enable FP8 optimizations for FSDP2 all-gather operations."},
)
@dataclass
class TrainingArguments(Fp8Arguments, RayArguments, BaseTrainingArguments):
r"""Arguments pertaining to the trainer."""
overwrite_output_dir: bool = field(

View File

@@ -15,6 +15,7 @@
import os
from typing import TYPE_CHECKING, Any, Optional, TypedDict
import torch
from transformers import (
AutoConfig,
AutoModelForCausalLM,
@@ -29,6 +30,7 @@ from trl import AutoModelForCausalLMWithValueHead
from ..extras import logging
from ..extras.misc import count_parameters, skip_check_imports, try_download_model_from_other_hub
from ..extras.packages import is_torch_version_greater_than
from .adapter import init_adapter
from .model_utils.ktransformers import load_kt_pretrained_model
from .model_utils.liger_kernel import apply_liger_kernel
@@ -203,6 +205,16 @@ def load_model(
model.load_state_dict(vhead_params, strict=False)
logger.info_rank0(f"Loaded valuehead from checkpoint: {vhead_path}")
# Conv3D is not recommended when using torch 2.9.x
if is_torch_version_greater_than("2.9.0") and not is_torch_version_greater_than("2.10.0"):
if any(isinstance(m, torch.nn.Conv3d) for m in model.modules()):
raise ValueError(
"Unsupported torch version detected: torch 2.9.x with Conv3D. "
"This combination is known to cause severe performance regression. "
"Please downgrade torch to <2.9 or remove Conv3D. "
"See https://github.com/pytorch/pytorch/issues/166122"
)
if not is_trainable:
model.requires_grad_(False)
model.eval()

View File

@@ -138,13 +138,14 @@ def patch_config(
if getattr(config, "model_type", None) == "kimi_vl" and is_trainable:
setattr(config.text_config, "topk_method", "greedy")
if "InternVLChatModel" in getattr(config, "architectures", []):
architectures = getattr(config, "architectures", None)
if isinstance(architectures, list) and "InternVLChatModel" in architectures:
raise ValueError(
"Please download the internvl models in a Hugging Facecompatible format "
"(for example, https://huggingface.co/OpenGVLab/InternVL3-8B-hf)."
)
if "LlavaLlamaForCausalLM" in getattr(config, "architectures", []):
if isinstance(architectures, list) and "LlavaLlamaForCausalLM" in architectures:
raise ValueError("Please download llava models with hf-compatible format: https://huggingface.co/llava-hf")
if getattr(config, "model_type", None) == "internlm3" and not is_transformers_version_greater_than("4.47.1"):

View File

@@ -12,35 +12,45 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import types
from typing import TYPE_CHECKING, Any, Optional
from ..extras import logging
if TYPE_CHECKING:
from ..hparams import ModelArguments
from ..hparams import TrainingArguments
logger = logging.get_logger(__name__)
def create_fp8_kwargs(model_args: "ModelArguments") -> list[Any]:
def create_fp8_kwargs(training_args: "TrainingArguments") -> list[Any]:
"""Create AORecipeKwargs for FP8 training with HuggingFace Accelerate.
Args:
model_args: Model arguments containing FP8 configuration
training_args: Training arguments containing FP8 configuration
Returns:
List containing AORecipeKwargs if FP8 is enabled and supported, empty list otherwise
"""
if not model_args.fp8:
if not training_args.fp8:
return []
try:
# Check if AORecipeKwargs is available (Accelerate 1.8.0+)
from accelerate.utils import AORecipeKwargs
backend = getattr(training_args, "fp8_backend", "auto")
logger.info_rank0(f"Creating FP8 configuration with backend: {backend}")
backend = getattr(model_args, "fp8_backend", "auto")
logger.info_rank0(f"Creating FP8 configuration with backend: {backend}")
try:
# Use Transformer Engine backend (optimal for Hopper GPUs)
if backend == "te":
from accelerate.utils import FP8RecipeKwargs
logger.info_rank0("Using Transformer Engine FP8 backend")
return [FP8RecipeKwargs(backend="TE", fp8_format="HYBRID", amax_history_len=16, amax_compute_algo="max")]
# Use TorchAO backend (default)
from accelerate.utils import AORecipeKwargs
# Create Float8LinearConfig if torchao backend is used
config = None
@@ -83,7 +93,10 @@ def create_fp8_kwargs(model_args: "ModelArguments") -> list[Any]:
return True
# Map FSDP all-gather setting if available (this affects the underlying implementation)
if hasattr(model_args, "fp8_enable_fsdp_float8_all_gather") and model_args.fp8_enable_fsdp_float8_all_gather:
if (
hasattr(training_args, "fp8_enable_fsdp_float8_all_gather")
and training_args.fp8_enable_fsdp_float8_all_gather
):
logger.info_rank0("FSDP float8 all-gather optimization requested")
return [AORecipeKwargs(config=config, module_filter_func=module_filter_func)]
@@ -92,19 +105,19 @@ def create_fp8_kwargs(model_args: "ModelArguments") -> list[Any]:
return []
def get_fp8_mixed_precision(model_args: "ModelArguments") -> Optional[str]:
def get_fp8_mixed_precision(training_args: "TrainingArguments") -> Optional[str]:
"""Get the mixed precision setting for Accelerate when using FP8.
Args:
model_args: Model arguments containing FP8 configuration
training_args: Training arguments containing FP8 configuration
Returns:
"fp8" if FP8 is enabled, None otherwise
"""
return "fp8" if model_args.fp8 else None
return "fp8" if training_args.fp8 else None
def configure_fp8_environment(model_args: "ModelArguments") -> None:
def configure_fp8_environment(training_args: "TrainingArguments") -> None:
"""Configure FP8 environment for HuggingFace Accelerate.
FP8 training is handled entirely through HuggingFace Accelerate, regardless of whether
@@ -112,11 +125,9 @@ def configure_fp8_environment(model_args: "ModelArguments") -> None:
variables and validates the FP8 configuration.
Args:
model_args: Model arguments containing FP8 configuration
training_args: Training arguments containing FP8 configuration
"""
import os
if not model_args.fp8:
if not training_args.fp8:
return
# Set mixed precision to fp8 for HuggingFace Accelerate
@@ -124,38 +135,38 @@ def configure_fp8_environment(model_args: "ModelArguments") -> None:
logger.info_rank0("Set ACCELERATE_MIXED_PRECISION=fp8")
# Configure FP8 backend and options
backend = getattr(model_args, "fp8_backend", "auto")
backend = getattr(training_args, "fp8_backend", "auto")
if backend != "auto":
os.environ["FP8_BACKEND"] = backend
logger.info_rank0(f"Set FP8_BACKEND={backend}")
# Create and validate FP8 recipe kwargs (for logging/debugging)
fp8_kwargs = create_fp8_kwargs(model_args)
fp8_kwargs = create_fp8_kwargs(training_args)
logger.info_rank0(f"FP8 AORecipeKwargs created: {len(fp8_kwargs)} items")
# Enable FSDP float8 all-gather optimization if requested
if hasattr(model_args, "fp8_enable_fsdp_float8_all_gather") and model_args.fp8_enable_fsdp_float8_all_gather:
if hasattr(training_args, "fp8_enable_fsdp_float8_all_gather") and training_args.fp8_enable_fsdp_float8_all_gather:
os.environ["FP8_ENABLE_FSDP_FLOAT8_ALL_GATHER"] = "true"
logger.info_rank0("Set FP8_ENABLE_FSDP_FLOAT8_ALL_GATHER=true")
logger.info_rank0("FP8 environment configured - all FP8 training handled by HuggingFace Accelerate")
def verify_fp8_status(accelerator, model_args: "ModelArguments") -> None:
def verify_fp8_status(accelerator, training_args: "TrainingArguments") -> None:
"""Verify that FP8 training is actually working after model preparation.
Args:
accelerator: The HuggingFace Accelerator instance
model_args: Model arguments containing FP8 configuration
training_args: Training arguments containing FP8 configuration
"""
if not model_args.fp8:
if not training_args.fp8:
return
# Check Accelerate's FP8 status
fp8_enabled = getattr(accelerator, "fp8_enabled", False)
fp8_backend_type = getattr(accelerator, "fp8_backend", "UNKNOWN")
backend = getattr(model_args, "fp8_backend", "auto")
backend = getattr(training_args, "fp8_backend", "auto")
if backend == "torchao" or backend == "auto":
logger.info_rank0(
"FP8 training enabled with TorchAO backend. For optimal performance, "
@@ -169,3 +180,50 @@ def verify_fp8_status(accelerator, model_args: "ModelArguments") -> None:
if not fp8_enabled:
logger.info_rank0("WARNING: FP8 was requested but Accelerate shows fp8_enabled=False. FP8 may not be working.")
def patch_accelerator_for_fp8() -> None:
"""Patch Accelerator to inject FP8 recipe kwargs.
This is needed because HuggingFace Trainer doesn't pass kwargs_handlers to Accelerator.
We monkey-patch Accelerator.__init__ to inject the FP8 recipe and force mixed_precision='fp8'.
"""
import transformer_engine.pytorch as te
from accelerate import Accelerator
# Guard against multiple patches
if getattr(Accelerator, "_te_fp8_patched", False):
return
# Stub for Accelerate 1.12+ compatibility (te.fp8.check_mxfp8_support doesn't exist yet)
if not hasattr(te, "fp8"):
te.fp8 = types.ModuleType("fp8")
te.fp8.check_mxfp8_support = lambda: (False, "MXFP8 not supported")
try:
from accelerate.utils import TERecipeKwargs as FP8Recipe
use_te_recipe = True
except ImportError:
from accelerate.utils import FP8RecipeKwargs as FP8Recipe
use_te_recipe = False
original_init = Accelerator.__init__
def patched_init(self, *args, **kwargs):
if "kwargs_handlers" not in kwargs or not kwargs["kwargs_handlers"]:
if use_te_recipe:
kwargs["kwargs_handlers"] = [
FP8Recipe(fp8_format="HYBRID", amax_history_len=16, amax_compute_algo="max")
]
else:
kwargs["kwargs_handlers"] = [
FP8Recipe(backend="TE", fp8_format="HYBRID", amax_history_len=16, amax_compute_algo="max")
]
# Only force mixed_precision when we inject handlers
kwargs["mixed_precision"] = "fp8"
return original_init(self, *args, **kwargs)
Accelerator.__init__ = patched_init
Accelerator._te_fp8_patched = True

View File

@@ -19,16 +19,15 @@ import torch
from transformers import Trainer
from typing_extensions import override
from ...extras.packages import is_transformers_version_greater_than
from ..callbacks import SaveProcessorCallback
from ..fp8_utils import configure_fp8_environment, verify_fp8_status
from ..fp8_utils import configure_fp8_environment, patch_accelerator_for_fp8, verify_fp8_status
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
if TYPE_CHECKING:
from transformers import ProcessorMixin
from ...hparams import FinetuningArguments, ModelArguments
from ...hparams import FinetuningArguments, ModelArguments, TrainingArguments
class CustomTrainer(Trainer):
@@ -41,11 +40,13 @@ class CustomTrainer(Trainer):
model_args: Optional["ModelArguments"] = None,
**kwargs,
) -> None:
kwargs["processing_class"] = kwargs.pop("tokenizer")
# Configure FP8 environment if enabled
if model_args is not None and model_args.fp8:
configure_fp8_environment(model_args)
if is_transformers_version_greater_than("4.46"):
kwargs["processing_class"] = kwargs.pop("tokenizer")
training_args: TrainingArguments = kwargs.get("args")
if training_args.fp8:
configure_fp8_environment(training_args)
if getattr(training_args, "fp8_backend", "auto") == "te":
patch_accelerator_for_fp8()
super().__init__(**kwargs)
if processor is not None:
@@ -64,9 +65,8 @@ class CustomTrainer(Trainer):
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
self.add_callback(BAdamCallback)
# Verify FP8 status after trainer initialization (accelerator should be available)
if model_args is not None and model_args.fp8 and hasattr(self, "accelerator"):
verify_fp8_status(self.accelerator, model_args)
if training_args.fp8 and hasattr(self, "accelerator"): # verify FP8 status after trainer initialization
verify_fp8_status(self.accelerator, training_args)
@override
def create_optimizer(self) -> "torch.optim.Optimizer":

View File

@@ -27,18 +27,17 @@ from typing_extensions import override
from ...extras import logging
from ...extras.constants import IGNORE_INDEX
from ...extras.packages import is_transformers_version_greater_than
from ..callbacks import SaveProcessorCallback
from ..fp8_utils import configure_fp8_environment, verify_fp8_status
from ..fp8_utils import configure_fp8_environment, patch_accelerator_for_fp8, verify_fp8_status
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
if TYPE_CHECKING:
from torch.utils.data import Dataset
from transformers import PreTrainedTokenizer, ProcessorMixin
from transformers import ProcessorMixin
from transformers.trainer import PredictionOutput
from ...hparams import FinetuningArguments, ModelArguments
from ...hparams import FinetuningArguments, ModelArguments, TrainingArguments
logger = logging.get_logger(__name__)
@@ -55,13 +54,13 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
gen_kwargs: Optional[dict[str, Any]] = None,
**kwargs,
) -> None:
kwargs["processing_class"] = kwargs.pop("tokenizer")
# Configure FP8 environment if enabled
if model_args is not None and model_args.fp8:
configure_fp8_environment(model_args)
if is_transformers_version_greater_than("4.46"):
kwargs["processing_class"] = kwargs.pop("tokenizer")
else:
self.processing_class: PreTrainedTokenizer = kwargs.get("tokenizer")
training_args: TrainingArguments = kwargs.get("args")
if training_args.fp8:
configure_fp8_environment(training_args)
if getattr(training_args, "fp8_backend", "auto") == "te":
patch_accelerator_for_fp8()
super().__init__(**kwargs)
if processor is not None:
@@ -88,9 +87,15 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
self.compute_loss_func = dft_loss_func
# Verify FP8 status after trainer initialization (accelerator should be available)
if model_args is not None and model_args.fp8 and hasattr(self, "accelerator"):
verify_fp8_status(self.accelerator, model_args)
elif finetuning_args.use_eaft_loss:
from ..trainer_utils import eaft_loss_func
self.compute_loss_func = lambda outputs, labels, num_items_in_batch=None: eaft_loss_func(
outputs, labels, num_items_in_batch, finetuning_args.eaft_alpha
)
if training_args.fp8 and hasattr(self, "accelerator"): # verify FP8 status after trainer initialization
verify_fp8_status(self.accelerator, training_args)
@override
def create_optimizer(self) -> "torch.optim.Optimizer":

View File

@@ -634,7 +634,9 @@ def get_batch_logps(
return logps, valid_length
def dft_loss_func(outputs, labels, num_items_in_batch=None):
def dft_loss_func(
outputs: "torch.Tensor", labels: "torch.Tensor", num_items_in_batch: Optional["torch.Tensor"] = None
):
logits = outputs.get("logits")
if logits is None:
return outputs.get("loss", torch.tensor(0.0))
@@ -652,11 +654,11 @@ def dft_loss_func(outputs, labels, num_items_in_batch=None):
def _dft_cross_entropy(
source: torch.Tensor,
target: torch.Tensor,
num_items_in_batch: Optional[torch.Tensor] = None,
source: "torch.Tensor",
target: "torch.Tensor",
num_items_in_batch: Optional["torch.Tensor"] = None,
ignore_index: int = -100,
) -> torch.Tensor:
) -> "torch.Tensor":
per_token_loss = torch.nn.functional.cross_entropy(source, target, ignore_index=ignore_index, reduction="none")
valid_mask = target != ignore_index
if not valid_mask.any():
@@ -679,6 +681,67 @@ def _dft_cross_entropy(
return loss
def eaft_loss_func(
outputs: "torch.Tensor",
labels: "torch.Tensor",
num_items_in_batch: Optional["torch.Tensor"] = None,
alpha: float = 1.0,
) -> "torch.Tensor":
logits = outputs.get("logits")
if logits is None:
return outputs.get("loss", torch.tensor(0.0))
logits = logits.float()
vocab_size = logits.size(-1)
labels = torch.nn.functional.pad(labels, (0, 1), value=-100)
shift_labels = labels[..., 1:].contiguous()
logits = logits.view(-1, vocab_size)
shift_labels = shift_labels.view(-1)
shift_labels = shift_labels.to(logits.device)
loss = _eaft_cross_entropy(logits, shift_labels, num_items_in_batch, alpha)
return loss
def _eaft_cross_entropy(
source: "torch.Tensor",
target: "torch.Tensor",
num_items_in_batch: Optional["torch.Tensor"] = None,
alpha: float = 1.0,
ignore_index: int = -100,
) -> "torch.Tensor":
per_token_loss = torch.nn.functional.cross_entropy(source, target, ignore_index=ignore_index, reduction="none")
valid_mask = target != ignore_index
if not valid_mask.any():
return torch.tensor(0.0, device=source.device, dtype=source.dtype)
valid_losses = per_token_loss[valid_mask]
with torch.no_grad():
source_detached = source[valid_mask].detach()
topk_val, _ = torch.topk(source_detached, k=20, dim=-1)
logsumexp_topk = torch.logsumexp(topk_val, dim=-1, keepdim=True)
log_probs_topk = topk_val - logsumexp_topk
probs_topk = torch.exp(log_probs_topk)
entropy_approx = -(probs_topk * log_probs_topk).sum(dim=-1)
entropy_term = entropy_approx / 3.0
adaptive_weight = torch.pow(entropy_term, alpha)
weighted_losses = valid_losses * adaptive_weight
if num_items_in_batch is not None:
total_loss = weighted_losses.sum()
if torch.is_tensor(num_items_in_batch):
num_items_in_batch = num_items_in_batch.to(total_loss.device)
loss = total_loss / num_items_in_batch
else:
loss = weighted_losses.mean()
return loss
def nested_detach(
tensors: Union["torch.Tensor", list["torch.Tensor"], tuple["torch.Tensor"], dict[str, "torch.Tensor"]],
clone: bool = False,

View File

@@ -119,9 +119,19 @@ def synchronize() -> None:
@requires_accelerator
def set_device() -> None:
"""Set current accelerator."""
torch.accelerator.set_device_index(get_local_rank())
def set_device_index() -> None:
"""Set current accelerator index to local rank."""
if get_current_accelerator().type != DeviceType.CPU:
torch.accelerator.set_device_index(get_local_rank())
@requires_accelerator
def get_current_device() -> torch.device:
"""Get current accelerator device."""
if get_current_accelerator().type == DeviceType.CPU:
return torch.device(DeviceType.CPU.value)
else:
return torch.device(type=get_current_accelerator().type, index=torch.accelerator.current_device_index())
def is_torch_cuda_available():

View File

@@ -34,10 +34,14 @@ from typing import Any, Optional
from torch.distributed import barrier, destroy_process_group, init_process_group
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
from ..utils import logging
from ..utils.types import DistributedConfig, ProcessGroup, Tensor, TensorLike
from . import helper
logger = logging.get_logger(__name__)
class Dim(str, Enum):
"""Dimension names."""
@@ -119,12 +123,13 @@ class DistributedInterface:
if self._initialized:
return
helper.set_device_index()
self._is_distributed = helper.is_distributed()
self._rank = helper.get_rank()
self._world_size = helper.get_world_size()
self._local_rank = helper.get_local_rank()
self._local_world_size = helper.get_local_world_size()
self.current_accelerator = helper.get_current_accelerator()
self.current_device = helper.get_current_device()
self.device_count = helper.get_device_count()
if config is None:
@@ -140,15 +145,14 @@ class DistributedInterface:
timeout = config.get("timeout", 18000)
if self._is_distributed:
helper.set_device()
init_process_group(timeout=timedelta(seconds=timeout))
self.model_device_mesh = init_device_mesh(
device_type=self.current_accelerator.type,
device_type=self.current_device.type,
mesh_shape=self.strategy.model_mesh_shape,
mesh_dim_names=self.strategy.model_mesh_dim_names,
)
self.data_device_mesh = init_device_mesh(
device_type=self.current_accelerator.type,
device_type=self.current_device.type,
mesh_shape=self.strategy.data_mesh_shape,
mesh_dim_names=self.strategy.data_mesh_dim_names,
)
@@ -157,11 +161,12 @@ class DistributedInterface:
self.data_device_mesh = None
self._initialized = True
logger.info_rank0(f"DistributedInterface initialized: {self}.")
def __str__(self) -> str:
return (
f"DistributedInterface(strategy={self.strategy}), is_distributed={self._is_distributed}, "
f"current_accelerator={self.current_accelerator}, rank={self._rank}, world_size={self._world_size}, "
f"current_device={self.current_device}, rank={self._rank}, world_size={self._world_size}, "
f"model_device_mesh={self.model_device_mesh}, data_device_mesh={self.data_device_mesh}"
)
@@ -246,4 +251,7 @@ class DistributedInterface:
if __name__ == "__main__":
print(DistributedInterface(DistributedStrategy()))
"""
python -m llamafactory.v1.accelerator.interface
"""
print(DistributedInterface())

View File

@@ -0,0 +1,32 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .arg_parser import InputArgument, get_args
from .arg_utils import ModelClass, SampleBackend
from .data_args import DataArguments
from .model_args import ModelArguments
from .sample_args import SampleArguments
from .training_args import TrainingArguments
__all__ = [
"DataArguments",
"InputArgument",
"ModelArguments",
"ModelClass",
"SampleArguments",
"SampleBackend",
"TrainingArguments",
"get_args",
]

View File

@@ -17,7 +17,7 @@
import json
from enum import Enum, unique
from enum import StrEnum, unique
class PluginConfig(dict):
@@ -36,7 +36,7 @@ PluginArgument = PluginConfig | dict | str | None
@unique
class ModelClass(str, Enum):
class ModelClass(StrEnum):
"""Auto class for model config."""
LLM = "llm"
@@ -45,7 +45,7 @@ class ModelClass(str, Enum):
@unique
class SampleBackend(str, Enum):
class SampleBackend(StrEnum):
HF = "hf"
VLLM = "vllm"

View File

@@ -21,20 +21,25 @@ from .arg_utils import ModelClass, PluginConfig, get_plugin_config
@dataclass
class ModelArguments:
model: str = field(
default="Qwen/Qwen3-4B-Instruct-2507",
metadata={"help": "Path to the model or model identifier from Hugging Face."},
)
template: str = field(
default="qwen3_nothink",
metadata={"help": "Template for the model."},
)
trust_remote_code: bool = field(
default=False,
metadata={"help": "Trust remote code from Hugging Face."},
)
use_fast_processor: bool = field(
default=True,
metadata={"help": "Use fast processor from Hugging Face."},
)
model_class: ModelClass = field(
default=ModelClass.LLM,
metadata={"help": "Model class from Hugging Face."},
)
init_config: PluginConfig | None = field(
default=None,
metadata={"help": "Initialization configuration for the model."},
)
peft_config: PluginConfig | None = field(
default=None,
metadata={"help": "PEFT configuration for the model."},
@@ -49,6 +54,7 @@ class ModelArguments:
)
def __post_init__(self) -> None:
self.init_config = get_plugin_config(self.init_config)
self.peft_config = get_plugin_config(self.peft_config)
self.kernel_config = get_plugin_config(self.kernel_config)
self.quant_config = get_plugin_config(self.quant_config)

View File

@@ -22,7 +22,7 @@ from .arg_utils import PluginConfig, get_plugin_config
@dataclass
class TrainingArguments:
output_dir: str = field(
default=os.path.join("outputs", str(uuid4())),
default=os.path.join("outputs", str(uuid4().hex)),
metadata={"help": "Path to the output directory."},
)
micro_batch_size: int = field(

View File

@@ -0,0 +1,181 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import os
from abc import ABC, abstractmethod
from collections.abc import AsyncGenerator, Generator
from threading import Thread
import torch
from transformers import TextIteratorStreamer
from ..accelerator.interface import DistributedInterface
from ..config import ModelArguments, SampleArguments, SampleBackend
from ..utils.helper import get_tokenizer
from ..utils.types import HFModel, Message, Sample, TorchDataset
from .utils.rendering import Renderer
class BaseEngine(ABC):
@abstractmethod
def __init__(
self,
args: SampleArguments,
model_args: ModelArguments,
model: HFModel,
renderer: Renderer,
) -> None:
"""Initialize the engine.
Args:
args: Sample arguments.
model_args: Model arguments.
model: Model.
renderer: Renderer.
"""
...
@abstractmethod
async def generate(self, messages: list[Message], tools: str | None = None) -> AsyncGenerator[str, None]:
"""Generate tokens asynchronously.
Args:
messages: List of messages.
tools: Tools string.
Yields:
Generated tokens.
"""
...
@abstractmethod
async def batch_infer(self, dataset: TorchDataset) -> list[Sample]:
"""Batch infer samples.
Args:
dataset: Torch dataset.
Returns:
List of samples.
"""
...
class HuggingFaceEngine(BaseEngine):
def __init__(
self,
args: SampleArguments,
model_args: ModelArguments,
model: HFModel,
renderer: Renderer,
) -> None:
self.args = args
self.model_args = model_args
self.model = model
self.renderer = renderer
self.semaphore = asyncio.Semaphore(int(os.getenv("MAX_CONCURRENT", "1")))
@torch.inference_mode()
def get_response(self, messages: list[Message], tools: str | None = None) -> Generator[str, None, None]:
model_inputs = self.renderer.render_messages(messages, tools, is_generate=True)
streamer = TextIteratorStreamer(
tokenizer=get_tokenizer(self.renderer.processor),
skip_prompt=True,
skip_special_tokens=True, # TODO: configurable
)
device = DistributedInterface().current_device
kwargs = {
"input_ids": torch.tensor([model_inputs["input_ids"]]).to(device),
"attention_mask": torch.tensor([model_inputs["attention_mask"]]).to(device),
"max_new_tokens": self.args.max_new_tokens,
"streamer": streamer,
}
thread = Thread(target=self.model.generate, kwargs=kwargs, daemon=True)
thread.start()
def stream():
try:
return streamer.__next__()
except StopIteration:
raise StopAsyncIteration()
return stream
async def generate(self, messages: list[Message], tools: str | None = None) -> AsyncGenerator[str, None]:
async with self.semaphore:
response = self.get_response(messages, tools)
while True:
try:
yield await asyncio.to_thread(response)
except StopAsyncIteration:
break
async def batch_infer(self, dataset: TorchDataset) -> list[Sample]:
"""Batch infer samples.
Args:
dataset: Torch dataset.
Returns:
List of samples.
"""
raise NotImplementedError("Batch infer is not implemented.")
class BaseSampler:
"""Base sampler.
Args:
args: Sample arguments.
model_args: Model arguments.
model: Model.
renderer: Renderer.
"""
def __init__(
self,
args: SampleArguments,
model_args: ModelArguments,
model: HFModel,
renderer: Renderer,
) -> None:
if args.sample_backend == SampleBackend.HF:
self.engine = HuggingFaceEngine(args, model_args, model, renderer)
else:
raise ValueError(f"Unknown sample backend: {args.sample_backend}")
async def generate(self, messages: list[Message], tools: str | None = None) -> AsyncGenerator[str, None]:
"""Generate tokens asynchronously.
Args:
messages: List of messages.
tools: Tools string.
Yields:
Generated tokens.
"""
async for token in self.engine.generate(messages, tools):
yield token
async def batch_infer(self, dataset: TorchDataset) -> list[Sample]:
"""Batch infer samples.
Args:
dataset: Torch dataset.
Returns:
List of samples.
"""
return await self.engine.batch_infer(dataset)

View File

@@ -1,44 +0,0 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod
from ..config.sample_args import SampleArguments, SampleBackend
from .model_loader import ModelLoader
class BaseEngine(ABC):
@abstractmethod
def __init__(self, sample_args: SampleArguments, model_loader: ModelLoader) -> None: ...
@abstractmethod
async def generate(self):
pass
@abstractmethod
async def batch_infer(self):
pass
class HuggingFaceEngine(BaseEngine):
def __init__(self, model_loader: ModelLoader, sample_args: SampleArguments) -> None:
self.args = sample_args
class ChatSampler:
def __init__(self, model_loader: ModelLoader, sample_args: SampleArguments) -> None:
if sample_args.sample_backend == SampleBackend.HF:
self.engine = HuggingFaceEngine(model_loader, sample_args)
else:
raise ValueError(f"Unknown sample backend: {sample_args.sample_backend}")

View File

@@ -14,15 +14,23 @@
"""The definition of data engine.
Init Data engine:
How to use:
data_engine = DataEngine(data_args)
data_engine[i]: Get the sample via index.
Init workflow:
1. Parse dataset info from arguments.
2. Load datasets according to dataset info.
3. Build data index (and reweight samples if necessary).
Get Data Sample:
Get data sample:
1. Get sample from data index.
2. Convert sample to standard format.
3. Return sample.
Note:
1. The data engine is equivalent to the torch dataset.
2. The data engine is agnostic to the model used.
"""
import os
@@ -98,10 +106,10 @@ class DataEngine(Dataset):
size = self.dataset_infos[dataset_name].get("size")
weight = self.dataset_infos[dataset_name].get("weight")
if size or weight: # data index plugin
from ..plugins.data_plugins.loader import DataIndexPlugin
if size or weight:
from ..plugins.data_plugins.loader import adjust_data_index
data_index = DataIndexPlugin().adjust_data_index(data_index, size, weight)
data_index = adjust_data_index(data_index, size, weight)
self.data_index.extend(data_index)
@@ -150,9 +158,9 @@ class DataEngine(Dataset):
dataset_name, sample_index = self.data_index[index]
return self._convert_data_sample(self.datasets[dataset_name][sample_index], dataset_name)
else: # data selector plugin
from ..plugins.data_plugins.loader import DataSelectorPlugin
from ..plugins.data_plugins.loader import select_data_sample
selected_index = DataSelectorPlugin().select(self.data_index, index)
selected_index = select_data_sample(self.data_index, index)
if isinstance(selected_index, list):
return [
self._convert_data_sample(self.datasets[dataset_name][sample_index], dataset_name)

View File

@@ -12,34 +12,44 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""The definition of model loader.
"""The definition of model engine.
Init Phase:
How to use:
model_engine = ModelEngine(model_args, is_train=True)
model_engine.processor: Get the tokenizer or multi-modal processor.
model_engine.renderer: Get the renderer.
model_engine.model_config: Get the model configuration.
model_engine.model: Get the HF model.
Init workflow:
1. Init processor.
2. Init render.
2. Init model config.
3. Init model.
4. Init adapter.
"""
import torch
from accelerate import init_empty_weights
from transformers import AutoConfig, AutoProcessor
from ..accelerator.helper import DeviceType
from ..accelerator.interface import DistributedInterface
from ..config.model_args import ModelArguments, ModelClass
from ..utils import logging
from ..utils.types import HFConfig, HFModel, Processor
from .utils.rendering import Renderer
logger = logging.get_logger(__name__)
class ModelLoader:
"""Model loader.
class ModelEngine:
"""Model engine.
Args:
model_args: Model arguments.
is_trainable: Whether to train the model.
is_train: Whether to train the model.
"""
def __init__(self, model_args: ModelArguments, is_train: bool = False) -> None:
@@ -49,17 +59,22 @@ class ModelLoader:
"""Whether to train the model."""
self.processor = self._init_processor()
"""Tokenizer or multi-modal processor."""
self.renderer = Renderer(self.args.template, self.processor)
"""Renderer."""
self.model_config = self._init_model_config()
"""Model configuration."""
self.model = self._init_model()
"""HF model."""
def _init_processor(self) -> Processor:
"""Init processor."""
"""Init processor.
NOTE: Transformers v5 always use fast tokenizer.
https://github.com/huggingface/transformers/blob/v5.0.0rc1/src/transformers/models/auto/tokenization_auto.py#L642
"""
return AutoProcessor.from_pretrained(
self.args.model,
trust_remote_code=self.args.trust_remote_code,
use_fast=self.args.use_fast_processor,
)
def _init_model_config(self) -> HFConfig:
@@ -92,14 +107,24 @@ class ModelLoader:
AutoClass = AutoModel
# map the entire model to the current accelerator
model = AutoClass.from_pretrained(
self.args.model,
config=self.model_config,
dtype="auto",
device_map=DistributedInterface().current_accelerator,
trust_remote_code=self.args.trust_remote_code,
)
if self.args.init_config is not None:
from ..plugins.model_plugins.initialization import InitPlugin
init_device = InitPlugin(self.args.init_config.name)()
else:
init_device = DistributedInterface().current_device
if init_device.type == DeviceType.META:
with init_empty_weights():
model = AutoClass.from_config(self.model_config)
else:
model = AutoClass.from_pretrained(
self.args.model,
config=self.model_config,
dtype="auto",
device_map=init_device,
trust_remote_code=self.args.trust_remote_code,
)
if self.args.peft_config is None:
if self.is_train:
@@ -124,12 +149,12 @@ class ModelLoader:
if __name__ == "__main__":
"""
python -m llamafactory.v1.core.model_loader --model llamafactory/tiny-random-qwen2.5
python -m llamafactory.v1.core.model_engine --model llamafactory/tiny-random-qwen2.5
"""
from ..config.arg_parser import get_args
_, model_args, *_ = get_args()
model_loader = ModelLoader(model_args=model_args)
print(model_loader.processor)
print(model_loader.model_config)
print(model_loader.model)
model_engine = ModelEngine(model_args=model_args)
print(model_engine.processor)
print(model_engine.model_config)
print(model_engine.model)

View File

@@ -0,0 +1,99 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ...utils.constants import IGNORE_INDEX
from ...utils.helper import get_tokenizer
from ...utils.types import Message, ModelInput, Processor
def render_chatml_messages(
processor: Processor,
messages: list[Message],
tools: str | None = None,
is_generate: bool = False,
) -> ModelInput:
"""Apply chatml template to messages and convert them to model input.
See https://huggingface.co/spaces/huggingfacejs/chat-template-playground?modelId=Qwen/Qwen2-7B-Instruct
"""
tokenizer = get_tokenizer(processor)
input_ids, labels, loss_weights = [], [], []
for message in messages:
temp_str = "<|im_start|>" + message["role"] + "\n"
for content in message["content"]:
if content["type"] == "text":
temp_str += content["value"]
else:
raise ValueError(f"Unsupported content type: {content['type']}")
temp_str += "<|im_end|>\n"
temp_weight = message.get("loss_weight", 1.0 if message["role"] == "assistant" else 0.0)
temp_ids = tokenizer.encode(temp_str, add_special_tokens=False)
input_ids.extend(temp_ids)
loss_weights.extend([temp_weight] * len(temp_ids))
if temp_weight > 1e-6:
labels.extend(temp_ids)
else:
labels.extend([IGNORE_INDEX] * len(temp_ids))
if is_generate:
temp_ids = tokenizer.encode("<|im_start|>assistant\n", add_special_tokens=False)
input_ids.extend(temp_ids)
loss_weights.extend([0.0] * len(temp_ids))
labels.extend([IGNORE_INDEX] * len(temp_ids))
return ModelInput(
input_ids=input_ids,
attention_mask=[1] * len(input_ids),
labels=labels,
loss_weights=loss_weights,
)
def parse_chatml_message(generated_text: str) -> Message:
"""Parse a message in ChatML format. Supports interleaved reasoning and tool calls.
Args:
generated_text (str): The generated text in ChatML format.
Returns:
Message: The parsed message.
"""
return Message(role="assistant", content=[{"type": "text", "value": generated_text}])
class Renderer:
def __init__(self, template: str, processor: Processor):
self.template = template
self.processor = processor
def render_messages(
self, messages: list[Message], tools: str | None = None, is_generate: bool = False
) -> ModelInput:
if self.template == "chatml":
return render_chatml_messages(self.processor, messages, tools, is_generate)
else:
from ...plugins.model_plugins.rendering import RenderingPlugin
return RenderingPlugin(self.template).render_messages(self.processor, messages, tools, is_generate)
def parse_message(self, generated_text: str) -> Message:
if self.template == "chatml":
return parse_chatml_message(generated_text)
else:
from ...plugins.model_plugins.rendering import RenderingPlugin
return RenderingPlugin(self.template).parse_message(generated_text)

View File

@@ -49,6 +49,11 @@ def launch():
run_sft()
elif command == "chat":
from .samplers.cli_sampler import run_chat
run_chat()
elif command == "env":
print_env()

View File

@@ -13,11 +13,12 @@
# limitations under the License.
import json
from typing import Any, Literal, NotRequired, TypedDict
from ...utils import logging
from ...utils.plugin import BasePlugin
from ...utils.types import DPOSample, Sample, SFTSample
from ...utils.types import DPOSample, Sample, SFTSample, ToolCall
logger = logging.get_logger(__name__)
@@ -61,7 +62,7 @@ class DataConverterPlugin(BasePlugin):
return super().__call__(raw_sample)
@DataConverterPlugin("alpaca").register
@DataConverterPlugin("alpaca").register()
def alpaca_converter(raw_sample: AlpacaSample) -> SFTSample:
"""Convert Alpaca sample to SFT sample.
@@ -98,7 +99,7 @@ def alpaca_converter(raw_sample: AlpacaSample) -> SFTSample:
return {"messages": messages}
@DataConverterPlugin("sharegpt").register
@DataConverterPlugin("sharegpt").register()
def sharegpt_converter(raw_sample: SharegptSample) -> SFTSample:
"""Convert ShareGPT sample to SFT sample.
@@ -118,17 +119,32 @@ def sharegpt_converter(raw_sample: SharegptSample) -> SFTSample:
"function_call": "assistant",
}
messages = []
tools = raw_sample.get("tools", "")
tools = raw_sample.get("tools")
if tools:
try:
tools: list[dict[str, Any]] = json.loads(tools)
except json.JSONDecodeError:
logger.warning_rank0(f"Invalid tools format: {str(tools)}")
tools = []
for message in raw_sample.get("conversations", []):
tag = message["from"]
if tag not in tag_mapping:
logger.warning_rank0(f"Unsupported role tag {tag} in message: {message}")
elif tag == "function_call":
try:
tool_calls: ToolCall | list[ToolCall] = json.loads(message["value"])
except json.JSONDecodeError:
logger.warning_rank0(f"Invalid tool call format: {str(message['value'])}")
continue
if not isinstance(tool_calls, list):
tool_calls = [tool_calls]
messages.append(
{
"role": "assistant",
"content": [{"type": "tool_calls", "value": message["value"]}],
"content": [{"type": "tool_call", "value": json.dumps(tool_call)} for tool_call in tool_calls],
"loss_weight": 1.0,
}
)
@@ -142,15 +158,12 @@ def sharegpt_converter(raw_sample: SharegptSample) -> SFTSample:
)
if tools:
if messages and messages[0]["role"] == "system":
messages[0]["content"].append({"type": "tools", "value": tools})
else:
messages.insert(0, {"role": "system", "content": [{"type": "tools", "value": tools}], "loss_weight": 0.0})
return {"messages": messages}
return {"messages": messages, "tools": json.dumps(tools)}
else:
return {"messages": messages}
@DataConverterPlugin("pair").register
@DataConverterPlugin("pair").register()
def pair_converter(raw_sample: PairSample) -> DPOSample:
"""Convert Pair sample to DPO sample.

View File

@@ -49,7 +49,7 @@ def _get_builder_name(path: str) -> Literal["arrow", "csv", "json", "parquet", "
raise ValueError(f"Unknown dataset filetype: {filetype}.")
@DataLoaderPlugin("local").register
@DataLoaderPlugin("local").register()
def load_data_from_file(filepath: str, split: str, streaming: bool) -> HFDataset:
if os.path.isdir(filepath):
filetype = _get_builder_name(os.listdir(filepath)[0])
@@ -66,49 +66,43 @@ def load_data_from_file(filepath: str, split: str, streaming: bool) -> HFDataset
return dataset
class DataIndexPlugin(BasePlugin):
"""Plugin for adjusting dataset index."""
def adjust_data_index(
data_index: list[tuple[str, int]], size: int | None, weight: float | None
) -> list[tuple[str, int]]:
"""Adjust dataset index by size and weight.
def adjust_data_index(
self, data_index: list[tuple[str, int]], size: int | None, weight: float | None
) -> list[tuple[str, int]]:
"""Adjust dataset index by size and weight.
Args:
data_index (list[tuple[str, int]]): List of (dataset_name, sample_index).
size (Optional[int]): Desired dataset size.
weight (Optional[float]): Desired dataset weight.
Args:
data_index (list[tuple[str, int]]): List of (dataset_name, sample_index).
size (Optional[int]): Desired dataset size.
weight (Optional[float]): Desired dataset weight.
Returns:
list[tuple[str, int]]: Adjusted dataset index.
"""
if size is not None:
data_index = random.choices(data_index, k=size)
Returns:
list[tuple[str, int]]: Adjusted dataset index.
"""
if size is not None:
data_index = random.choices(data_index, k=size)
if weight is not None:
data_index = random.choices(data_index, k=int(len(data_index) * weight))
if weight is not None:
data_index = random.choices(data_index, k=int(len(data_index) * weight))
return data_index
return data_index
class DataSelectorPlugin(BasePlugin):
"""Plugin for selecting dataset samples."""
def select_data_sample(
data_index: list[tuple[str, int]], index: slice | list[int] | Any
) -> tuple[str, int] | list[tuple[str, int]]:
"""Select dataset samples.
def select(
self, data_index: list[tuple[str, int]], index: slice | list[int] | Any
) -> tuple[str, int] | list[tuple[str, int]]:
"""Select dataset samples.
Args:
data_index (list[tuple[str, int]]): List of (dataset_name, sample_index).
index (Union[slice, list[int], Any]): Index of dataset samples.
Args:
data_index (list[tuple[str, int]]): List of (dataset_name, sample_index).
index (Union[slice, list[int], Any]): Index of dataset samples.
Returns:
Union[tuple[str, int], list[tuple[str, int]]]: Selected dataset samples.
"""
if isinstance(index, slice):
return [data_index[i] for i in range(*index.indices(len(data_index)))]
elif isinstance(index, list):
return [data_index[i] for i in index]
else:
raise ValueError(f"Invalid index type {type(index)}.")
Returns:
Union[tuple[str, int], list[tuple[str, int]]]: Selected dataset samples.
"""
if isinstance(index, slice):
return [data_index[i] for i in range(*index.indices(len(data_index)))]
elif isinstance(index, list):
return [data_index[i] for i in index]
else:
raise ValueError(f"Invalid index type {type(index)}.")

View File

@@ -1,133 +0,0 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
@dataclass
class Template:
user_template: str
assistant_template: str
system_template: str
def render_message(self, message: dict[str, str]) -> str:
return self.user_template.format(**message)
@dataclass
class QwenTemplate:
message_template: str = "<|im_start|>{role}\n{content}<|im_end|>\n" # FIXME if role: tool
thinking_template: str = "<think>\n{content}\n</think>\n\n"
def _extract_content(self, content_data: str | list[dict[str, str]]) -> str:
if isinstance(content_data, str):
return content_data.strip()
if isinstance(content_data, list):
parts = []
for item in content_data:
if item.get("type") == "text":
parts.append(item.get("value", ""))
elif item.get("type") == "image_url":
pass
return "\n".join(parts).strip()
return ""
def render_message(self, message: dict[str, str | list[dict[str, str]]]) -> str:
role = message["role"]
content = self._extract_content(message.get("content", ""))
if role == "assistant":
reasoning_content = message.get("reasoning_content", "")
if reasoning_content:
reasoning_content = self.thinking_template.format(content=str(reasoning_content).strip())
return self.message_template.format(role="assistant", content=reasoning_content + content)
else:
return self.message_template.format(role=role, content=content)
def encode_messages(self, tokenizer, messages: list[dict[str, str]], max_seq_len: int = 8192) -> any:
"""Encode one message."""
input_ids, attention_mask, labels = [], [], []
for message in messages:
content_str = self.render_message(message)
content_ids = tokenizer.encode(content_str, add_special_tokens=False)
input_ids += content_ids
attention_mask += [1] * len(content_ids)
if hasattr(message, "loss_weight"):
loss_weight = message["loss_weight"]
else:
loss_weight = 1 if message["role"] == "assistant" else 0
if loss_weight == 1:
labels += content_ids
else:
labels += [-100] * len(content_ids)
model_inputs = {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}
model_inputs.update({"position_ids": list(range(len(input_ids)))})
model_inputs = {k: v[-max_seq_len:] for k, v in model_inputs.items()}
return model_inputs
if __name__ == "__main__":
def to_qwen3_messages(template: QwenTemplate, messages: list[dict]):
out = []
for m in messages:
role = m["role"]
content = template._extract_content(m.get("content", ""))
if role == "assistant":
reasoning = (m.get("reasoning_content") or "").strip()
if reasoning:
content = template.thinking_template.format(content=reasoning) + content
out.append({"role": role, "content": content})
return out
from transformers import AutoTokenizer
tok = AutoTokenizer.from_pretrained(
"Qwen/Qwen3-30B-A3B-Thinking-2507",
trust_remote_code=True,
)
test_messages = [
{"role": "system", "content": "You are a helpful assistant."},
{
"role": "user",
"content": [{"type": "text", "text": "1+1等于几"}, {"type": "text", "text": "2+2等于几"}],
},
{
"role": "assistant",
"reasoning_content": "这是一个简单的数学问题。1加1的结果是2。",
"content": [{"type": "text", "text": "1+1=2"}, {"type": "text", "text": "2+2=4"}],
},
]
template = QwenTemplate()
rendered_custom = "".join([template.render_message(m) for m in test_messages])
qwen3_messages = to_qwen3_messages(template, test_messages)
rendered_hf = tok.apply_chat_template(qwen3_messages, tokenize=False, add_generation_prompt=False)
print("==== custom ====")
print(rendered_custom)
print("==== hf ====")
print(rendered_hf)
assert rendered_custom.strip() == rendered_hf.strip(), "Rendered text mismatch"
ids_custom = tok.encode(rendered_custom, add_special_tokens=False)
ids_hf = tok.apply_chat_template(qwen3_messages, tokenize=True, add_generation_prompt=False)
assert ids_custom == ids_hf, f"Token ids mismatch: custom={len(ids_custom)} hf={len(ids_hf)}"

View File

@@ -0,0 +1,43 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from ...accelerator.helper import DeviceType
from ...accelerator.interface import DistributedInterface
from ...utils.plugin import BasePlugin
class InitPlugin(BasePlugin):
def __call__(self) -> torch.device:
return super().__call__()
@InitPlugin("init_on_meta").register()
def init_on_meta() -> torch.device:
return torch.device(DeviceType.META.value)
@InitPlugin("init_on_rank0").register()
def init_on_rank0() -> torch.device:
if DistributedInterface().get_rank() == 0:
return torch.device(DeviceType.CPU.value)
else:
return torch.device(DeviceType.META.value)
@InitPlugin("init_on_default").register()
def init_on_default() -> torch.device:
return DistributedInterface().current_device

View File

@@ -38,17 +38,17 @@ class BaseKernel(ABC):
@classmethod
def get_kernel_id(cls) -> str:
r"""Returns the unique identifier for the kernel."""
"""Returns the unique identifier for the kernel."""
return cls._kernel_id
@classmethod
def get_device(cls) -> str:
r"""Returns the device type associated with the kernel (e.g., "cuda", "npu", "cpu")."""
"""Returns the device type associated with the kernel (e.g., "cuda", "npu", "cpu")."""
return cls._device
@classmethod
def check_deps(cls) -> bool:
r"""Checks if the required dependencies for the kernel are available.
"""Checks if the required dependencies for the kernel are available.
Returns:
bool: ``True`` if dependencies are met, ``False`` otherwise.
@@ -65,7 +65,7 @@ class BaseKernel(ABC):
@classmethod
@abstractmethod
def apply(cls, **kwargs) -> HFModel:
r"""Applies the kernel optimization to the model.
"""Applies the kernel optimization to the model.
Args:
**kwargs: Arbitrary keyword arguments, usually containing the model instance and the kernel configuration.

View File

@@ -33,7 +33,7 @@ logger = get_logger(__name__)
def scan_all_kernels():
r"""Scan all kernels in the ``ops`` directory.
"""Scan all kernels in the ``ops`` directory.
Scans the ``ops`` directory for all ``.py`` files and attempts to import them.
Importing triggers the :func:`~registry.register_kernel` decorator, which automatically registers the kernels.
@@ -77,7 +77,7 @@ default_kernels = scan_all_kernels()
def get_default_kernels():
r"""Get a list of default registered kernel IDs.
"""Get a list of default registered kernel IDs.
Returns:
list[str]: List of kernel IDs.
@@ -86,7 +86,7 @@ def get_default_kernels():
def apply_kernel(kernel_id: str, **kwargs):
r"""Applies a specific kernel to the model.
"""Applies a specific kernel to the model.
Args:
kernel_id (str): The ID of the kernel to apply.
@@ -99,18 +99,19 @@ def apply_kernel(kernel_id: str, **kwargs):
kernel = default_kernels.get(kernel_id)
if kernel is None:
raise ValueError(f"Kernel {kernel_id} not found")
kernel.apply(**kwargs)
class KernelPlugin(BasePlugin):
r"""Plugin for managing kernel optimizations."""
"""Plugin for managing kernel optimizations."""
pass
@KernelPlugin("auto").register
@KernelPlugin("auto").register()
def apply_default_kernels(**kwargs):
r"""Applies all default registered kernels to the model.
"""Applies all default registered kernels to the model.
Args:
**kwargs: Keyword arguments passed to the kernel application function.
@@ -125,8 +126,11 @@ def apply_default_kernels(**kwargs):
use_kernels = default_kernels.keys()
else:
use_kernels = kwargs.get("include_kernels").split(",") # "kernel_id1,kernel_id2,kernel_id3"
for kernel in use_kernels:
if kernel not in default_kernels:
raise ValueError(f"Kernel {kernel} not found")
apply_kernel(kernel, **kwargs)
return kwargs.get("model")

View File

@@ -40,11 +40,11 @@ from ...registry import register_kernel
class GmmFunction(torch.autograd.Function):
r"""Custom autograd function for NPU Grouped Matrix Multiplication (GMM)."""
"""Custom autograd function for NPU Grouped Matrix Multiplication (GMM)."""
@staticmethod
def forward(ctx, x, weight, group_list):
r"""Performs the forward pass of Grouped Matrix Multiplication.
"""Performs the forward pass of Grouped Matrix Multiplication.
Args:
ctx: Context object to save tensors for backward pass.
@@ -65,7 +65,7 @@ class GmmFunction(torch.autograd.Function):
@staticmethod
def backward(ctx, grad_output):
r"""Performs the backward pass of Grouped Matrix Multiplication.
"""Performs the backward pass of Grouped Matrix Multiplication.
Args:
ctx: Context object containing saved tensors.
@@ -94,11 +94,11 @@ class GmmFunction(torch.autograd.Function):
class HybridGmmFunction(torch.autograd.Function):
r"""Custom autograd function for Hybrid Grouped Matrix Multiplication on NPU."""
"""Custom autograd function for Hybrid Grouped Matrix Multiplication on NPU."""
@staticmethod
def forward(ctx, num_experts, *args):
r"""Performs the forward pass of Hybrid GMM.
"""Performs the forward pass of Hybrid GMM.
Args:
ctx: Context object to save tensors.
@@ -124,7 +124,7 @@ class HybridGmmFunction(torch.autograd.Function):
@staticmethod
def backward(ctx, *grad_outputs):
r"""Performs the backward pass of Hybrid GMM.
"""Performs the backward pass of Hybrid GMM.
Args:
ctx: Context object containing saved tensors.
@@ -176,13 +176,13 @@ class HybridGmmFunction(torch.autograd.Function):
class NpuMoeFused:
r"""Container for NPU fused MoE forward functions."""
"""Container for NPU fused MoE forward functions."""
@staticmethod
def npu_moe_experts_forward(
self, hidden_states: torch.Tensor, routing_weights: torch.Tensor, router_indices: torch.Tensor
) -> torch.Tensor:
r"""Forward pass for MoE experts using NPU fused operations.
"""Forward pass for MoE experts using NPU fused operations.
Args:
self: The MoE layer instance.
@@ -230,11 +230,11 @@ class NpuMoeFused:
class Qwen3NpuMoeFused:
r"""Container for Qwen3 NPU fused MoE forward functions."""
"""Container for Qwen3 NPU fused MoE forward functions."""
@staticmethod
def qwen3moe_sparse_moe_block_forward(self, hidden_states: torch.Tensor):
r"""Forward pass for Qwen3 sparse MoE block using NPU fused operations.
"""Forward pass for Qwen3 sparse MoE block using NPU fused operations.
Args:
self: The Qwen3 MoE block instance.
@@ -298,14 +298,14 @@ if not is_transformers_version_greater_than("5.0.0"):
@register_kernel
class NpuFusedMoEKernel(BaseKernel):
r"""NPU Fused MoE Kernel implementation."""
"""NPU Fused MoE Kernel implementation."""
_kernel_id = "npu_fused_moe"
_device = DeviceType.NPU
@classmethod
def apply(cls, **kwargs) -> HFModel:
r"""Applies the NPU fused MoE kernel to the model.
"""Applies the NPU fused MoE kernel to the model.
Args:
**kwargs: Keyword arguments containing the model.
@@ -333,6 +333,7 @@ class NpuFusedMoEKernel(BaseKernel):
if target_moe_mapping is None:
return model
for module in model.modules():
class_name = module.__class__.__name__
if class_name in target_moe_mapping:

View File

@@ -38,7 +38,7 @@ except ImportError:
def npu_swiglu_forward(self, hidden_state):
r"""SwiGLU forward pass for NPU.
"""SwiGLU forward pass for NPU.
Args:
self: The MLP layer instance.
@@ -53,7 +53,7 @@ def npu_swiglu_forward(self, hidden_state):
def _npu_swiglu_glm4_forward(self, hidden_states):
r"""SwiGLU forward pass for GLM4 on NPU.
"""SwiGLU forward pass for GLM4 on NPU.
Args:
self: The GLM4 MLP layer instance.
@@ -68,7 +68,7 @@ def _npu_swiglu_glm4_forward(self, hidden_states):
def _npu_swiglu_gemma3ntext_forward(self, hidden_states):
r"""SwiGLU forward pass for Gemma3nText on NPU.
"""SwiGLU forward pass for Gemma3nText on NPU.
Args:
self: The Gemma3nText MLP layer instance.
@@ -88,7 +88,7 @@ def _npu_swiglu_gemma3ntext_forward(self, hidden_states):
@register_kernel
class NpuSwiGluKernel(BaseKernel):
r"""NPU Kernel for fused SwiGLU activation."""
"""NPU Kernel for fused SwiGLU activation."""
# just support apply to the following module layers
expect_modules = frozenset(
@@ -126,7 +126,7 @@ class NpuSwiGluKernel(BaseKernel):
@classmethod
def apply(cls, **kwargs) -> "HFModel":
r"""Applies the NPU fused SwiGLU kernel to the model.
"""Applies the NPU fused SwiGLU kernel to the model.
Args:
**kwargs: Keyword arguments containing the model.

View File

@@ -30,7 +30,7 @@ from ...registry import register_kernel
def npu_rms_norm_forward(self, hidden_states):
r"""NPU forward implementation for RMSNorm.
"""NPU forward implementation for RMSNorm.
Args:
self: RMSNorm module instance with `weight` and `variance_epsilon`.
@@ -46,14 +46,14 @@ def npu_rms_norm_forward(self, hidden_states):
@register_kernel
class NpuRMSNormKernel(BaseKernel):
r"""NPU kernel wrapper for RMSNorm that applies the replacement within a model."""
"""NPU kernel wrapper for RMSNorm that applies the replacement within a model."""
_kernel_id = "npu_fused_rmsnorm"
_device = DeviceType.NPU
@classmethod
def apply(cls, **kwargs) -> "HFModel":
r"""Iterate the model and apply NPU-optimized forward to matched RMSNorm modules.
"""Iterate the model and apply NPU-optimized forward to matched RMSNorm modules.
Key points:
- Match modules whose class name contains "RMSNorm" (case-insensitive).
@@ -78,6 +78,7 @@ class NpuRMSNormKernel(BaseKernel):
if not cls.check_deps():
raise RuntimeError(f"torch_npu is not available but {cls.__name__} was called.")
rms_norm_pattern = re.compile("RMSNorm", re.IGNORECASE)
for name, module in model.named_modules():

View File

@@ -40,7 +40,7 @@ except ImportError:
def _apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
r"""Applies Rotary Position Embedding to the query and key tensors using NPU optimization.
"""Applies Rotary Position Embedding to the query and key tensors using NPU optimization.
Args:
q (Tensor): Query tensor.
@@ -61,7 +61,7 @@ def _apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
def _apply_multimodal_rotary_pos_emb_qwen25_vl(q, k, cos, sin, mrope_section, unsqueeze_dim=1):
r"""Applies Rotary Position Embedding with multimodal sections (Qwen2-VL) on NPU.
"""Applies Rotary Position Embedding with multimodal sections (Qwen2-VL) on NPU.
Args:
q (Tensor): Query tensor.
@@ -89,14 +89,14 @@ def _apply_multimodal_rotary_pos_emb_qwen25_vl(q, k, cos, sin, mrope_section, un
@register_kernel
class NpuRoPEKernel(BaseKernel):
r"""NPU Kernel for Rotary Position Embedding."""
"""NPU Kernel for Rotary Position Embedding."""
_kernel_id = "npu_fused_rope"
_device = DeviceType.NPU
@classmethod
def apply(cls, **kwargs) -> "HFModel":
r"""Apply RoPE acceleration by monkey-patching `apply_rotary_pos_emb`.
"""Apply RoPE acceleration by monkey-patching `apply_rotary_pos_emb`.
This function iterates through the model's modules to find attention layers,
identifies the module where they are defined, and replaces the original
@@ -115,9 +115,11 @@ class NpuRoPEKernel(BaseKernel):
"""
if not cls.check_deps():
raise RuntimeError(f"torch_npu is not available but {cls.__name__} was called.")
model = kwargs.get("model", None)
if model is None:
raise ValueError(f"HFModel instance is required for {cls.__name__}.")
_modules = set()
for module in model.modules():
if "Attention" in module.__class__.__name__:
@@ -143,4 +145,5 @@ class NpuRoPEKernel(BaseKernel):
_modules.add(module_name)
except Exception as e:
logger.warning_rank0_once(f"Failed to apply RoPE kernel to module {module_name}: {e}")
return model

View File

@@ -30,7 +30,7 @@ __all__ = ["Registry", "register_kernel"]
class Registry:
r"""Registry for managing kernel implementations.
"""Registry for managing kernel implementations.
Storage structure: ``{ "kernel_id": Class }``
"""
@@ -38,8 +38,8 @@ class Registry:
_kernels: dict[str, type[BaseKernel]] = {}
@classmethod
def register(cls, kernel_cls: type[BaseKernel]):
r"""Decorator to register a kernel class.
def register(cls, kernel_cls: type[BaseKernel]) -> type[BaseKernel] | None:
"""Decorator to register a kernel class.
The class must inherit from :class:`BaseKernel` and specify ``_kernel_id`` and ``_device`` attributes.
@@ -47,7 +47,7 @@ class Registry:
kernel_cls (type[BaseKernel]): The kernel class to register.
Returns:
type[BaseKernel]: The registered kernel class.
type[BaseKernel] | None: The registered kernel class if the device type matches the current accelerator
Raises:
TypeError: If the class does not inherit from :class:`BaseKernel`.
@@ -55,6 +55,7 @@ class Registry:
"""
if not issubclass(kernel_cls, BaseKernel):
raise TypeError(f"Class {kernel_cls} must inherit from BaseKernel")
kernel_id = kernel_cls.get_kernel_id()
device = kernel_cls.get_device()
@@ -73,7 +74,7 @@ class Registry:
@classmethod
def get(cls, kernel_id: str) -> Optional[type[BaseKernel]]:
r"""Retrieves a registered kernel implementation by its ID.
"""Retrieves a registered kernel implementation by its ID.
Args:
kernel_id (str): The ID of the kernel to retrieve.
@@ -85,7 +86,7 @@ class Registry:
@classmethod
def get_registered_kernels(cls) -> dict[str, type[BaseKernel]]:
r"""Returns a dictionary of all registered kernels.
"""Returns a dictionary of all registered kernels.
Returns:
dict[str, type[BaseKernel]]: Dictionary mapping kernel IDs to kernel classes.

View File

@@ -45,13 +45,13 @@ class PeftPlugin(BasePlugin):
return super().__call__(model, config)
@PeftPlugin("lora").register
@PeftPlugin("lora").register()
def get_lora_model(model: HFModel, config: LoraConfigDict, is_train: bool) -> PeftModel:
peft_config = LoraConfig(**config)
model = get_peft_model(model, peft_config)
return model
@PeftPlugin("freeze").register
@PeftPlugin("freeze").register()
def get_freeze_model(model: HFModel, config: FreezeConfigDict, is_train: bool) -> HFModel:
raise NotImplementedError()

View File

@@ -0,0 +1,212 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import re
from ...utils.constants import IGNORE_INDEX
from ...utils.helper import get_tokenizer
from ...utils.plugin import BasePlugin
from ...utils.types import Message, ModelInput, Processor, ToolCall
class RenderingPlugin(BasePlugin):
pass
def _update_model_input(
processor: Processor,
input_ids: list[int],
labels: list[int],
loss_weights: list[int],
temp_str: str,
temp_weight: float,
) -> str:
"""Update model input with temporary string."""
if not temp_str:
return ""
tokenizer = get_tokenizer(processor)
temp_ids = tokenizer.encode(temp_str, add_special_tokens=False)
input_ids.extend(temp_ids)
loss_weights.extend([temp_weight] * len(temp_ids))
if temp_weight > 1e-6:
labels.extend(temp_ids)
else:
labels.extend([IGNORE_INDEX] * len(temp_ids))
return ""
@RenderingPlugin("qwen3_nothink").register("render_messages")
def render_qwen_messages(
processor: Processor,
messages: list[Message],
tools: str | None = None,
is_generate: bool = False,
) -> ModelInput:
input_ids, labels, loss_weights = [], [], []
temp_str, temp_weight = "", 0.0
if tools:
temp_str += "<|im_start|>system\n"
if messages[0]["role"] == "system":
for content in messages[0]["content"]:
if content["type"] == "text":
temp_str += content["value"]
else:
raise ValueError(f"Unsupported content type: {content['type']}")
temp_str += "\n\n"
temp_weight = messages[0].get("loss_weight", 0.0)
temp_str += (
"# Tools\n\nYou may call one or more functions to assist with the user query.\n\n"
"You are provided with function signatures within <tools></tools> XML tags:\n<tools>"
)
try:
tools = json.loads(tools)
except json.JSONDecodeError:
raise ValueError(f"Invalid tools format: {str(tools)}.")
if not isinstance(tools, list):
tools = [tools]
for tool in tools:
temp_str += "\n" + json.dumps(tool, ensure_ascii=False)
temp_str += (
"\n</tools>\n\nFor each function call, return a json object with function name "
'and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{"name": '
'<function-name>, "arguments": <args-json-object>}\n</tool_call><|im_end|>\n'
)
elif messages[0]["role"] == "system":
temp_str += "<|im_start|>system\n"
for content in messages[0]["content"]:
if content["type"] == "text":
temp_str += content["value"]
else:
raise ValueError(f"Unsupported content type: {content['type']}")
temp_str += "<|im_end|>\n"
temp_weight = messages[0].get("loss_weight", 0.0)
temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)
for turn_idx, message in enumerate(messages):
if message["role"] == "user" or (message["role"] == "system" and turn_idx != 0):
temp_str += "<|im_start|>" + message["role"] + "\n"
for content in message["content"]:
if content["type"] == "text":
temp_str += content["value"]
else:
raise ValueError(f"Unsupported content type: {content['type']}")
temp_str += "<|im_end|>\n"
temp_weight = message.get("loss_weight", 0.0)
elif message["role"] == "assistant":
temp_str += "<|im_start|>" + message["role"] + "\n"
for val_idx, content in enumerate(message["content"]):
if content["type"] == "text":
temp_str += content["value"]
elif content["type"] == "reasoning":
temp_str += "<thinking>\n" + content["value"] + "\n</thinking>\n\n" # avoid using special tokens
elif content["type"] == "tool_call":
if val_idx != 0 and message["content"][val_idx - 1]["type"] in ["text", "tool_call"]:
temp_str += "\n"
try:
tool_call: ToolCall = json.loads(content["value"])
except json.JSONDecodeError:
raise ValueError(f"Invalid tool call format: {content['value']}.")
temp_str += (
'<tool_call>\n{"name": "'
+ tool_call["name"]
+ '", "arguments": '
+ json.dumps(tool_call["arguments"], ensure_ascii=False)
+ "}\n</tool_call>"
)
else:
raise ValueError(f"Unsupported content type: {content['type']}")
temp_str += "<|im_end|>\n"
temp_weight = message.get("loss_weight", 1.0)
elif message["role"] == "tool":
if turn_idx == 0 or messages[turn_idx - 1]["role"] != "tool":
temp_str += "<|im_start|>user"
temp_str += "\n<tool_response>\n"
for content in message["content"]:
if content["type"] == "text":
temp_str += content["value"]
else:
raise ValueError(f"Unsupported content type: {content['type']}")
temp_str += "\n</tool_response>"
if turn_idx == len(messages) - 1 or messages[turn_idx + 1]["role"] != "tool":
temp_str += "<|im_end|>\n"
temp_weight = message.get("loss_weight", 0.0)
temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)
if is_generate:
temp_str += "<|im_start|>assistant\n"
temp_weight = 0.0
temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)
attention_mask = [1] * len(input_ids)
return ModelInput(
input_ids=input_ids,
attention_mask=attention_mask,
labels=labels,
loss_weights=loss_weights,
)
@RenderingPlugin("qwen3_nothink").register("parse_message")
def parse_qwen_message(generated_text: str) -> Message:
pattern = re.compile(r"<(thinking|tool_call)>\s*(.*?)\s*</\1>\s*", re.DOTALL)
content = []
last_end = 0
for match in pattern.finditer(generated_text):
start, end = match.span()
if start > last_end:
text = generated_text[last_end:start].strip()
if text:
content.append({"type": "text", "value": text})
tag_type = match.group(1)
tag_value = match.group(2).strip()
if tag_type == "thinking":
content.append({"type": "reasoning", "value": tag_value.strip()})
elif tag_type == "tool_call":
try:
json.loads(tag_value.strip())
except json.JSONDecodeError:
raise ValueError(f"Invalid tool call format: {tag_value.strip()}.")
content.append({"type": "tool_call", "value": tag_value.strip()})
last_end = end
if last_end < len(generated_text):
text = generated_text[last_end:].strip()
if text:
content.append({"type": "text", "value": text})
return Message(role="assistant", content=content)

View File

@@ -0,0 +1,125 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import os
from collections.abc import Generator
from threading import Thread
from ..config import InputArgument, ModelArguments, SampleArguments, SampleBackend, get_args
from ..core.base_sampler import BaseSampler
from ..core.data_engine import DataEngine
from ..core.model_engine import ModelEngine
from ..core.utils.rendering import Renderer
from ..utils.types import HFModel, Message, Sample, TorchDataset
class SyncSampler(BaseSampler):
def __init__(
self,
args: SampleArguments,
model_args: ModelArguments,
model: HFModel,
renderer: Renderer,
) -> None:
def _start_background_loop(loop: asyncio.AbstractEventLoop) -> None:
asyncio.set_event_loop(loop)
loop.run_forever()
super().__init__(args, model_args, model, renderer)
self._loop = asyncio.new_event_loop()
self._thread = Thread(target=_start_background_loop, args=(self._loop,), daemon=True)
self._thread.start()
def generate(self, messages: list[Message], tools: str | None = None) -> Generator[str, None, None]:
"""Generate tokens synchronously.
Args:
messages: List of messages.
tools: Tools string.
Yields:
Generated tokens.
"""
generator = super().generate(messages, tools)
while True:
try:
token = asyncio.run_coroutine_threadsafe(generator.__anext__(), self._loop).result()
yield token
except StopAsyncIteration:
break
def batch_infer(self, dataset: TorchDataset) -> list[Sample]:
"""Batch infer samples synchronously.
Args:
dataset: Torch dataset.
Returns:
List of samples.
"""
return asyncio.run_coroutine_threadsafe(super().batch_infer(dataset), self._loop).result()
def run_chat(args: InputArgument = None):
data_args, model_args, _, sample_args = get_args(args)
if sample_args.sample_backend != SampleBackend.HF:
model_args.init_plugin = {"name": "init_on_meta"}
model_engine = ModelEngine(model_args)
sampler = SyncSampler(sample_args, model_args, model_engine.model, model_engine.renderer)
if data_args.dataset is not None:
dataset = DataEngine(data_args)
sampler.batch_infer(dataset)
else:
if os.name != "nt":
try:
import readline # noqa: F401
except ImportError:
print("Install `readline` for a better experience.")
messages = []
print("Welcome to the CLI application, use `clear` to remove the history, use `exit` to exit the application.")
while True:
try:
query = input("\nUser: ")
except UnicodeDecodeError:
print("Detected decoding error at the inputs, please set the terminal encoding to utf-8.")
continue
except Exception:
raise
if query.strip() == "exit":
break
if query.strip() == "clear":
messages = []
print("History has been removed.")
continue
messages.append({"role": "user", "content": [{"type": "text", "value": query}]})
print("Assistant: ", end="", flush=True)
response = ""
for new_text in sampler.generate(messages):
print(new_text, end="", flush=True)
response += new_text
print()
messages.append(model_engine.renderer.parse_message(response))
if __name__ == "__main__":
run_chat()

View File

@@ -17,7 +17,7 @@ from ..accelerator.interface import DistributedInterface
from ..config.arg_parser import get_args
from ..core.base_trainer import BaseTrainer
from ..core.data_engine import DataEngine
from ..core.model_loader import ModelLoader
from ..core.model_engine import ModelEngine
class SFTTrainer(BaseTrainer):
@@ -28,11 +28,11 @@ def run_sft(user_args):
model_args, data_args, training_args, _ = get_args(user_args)
DistributedInterface(training_args.dist_config)
data_engine = DataEngine(data_args)
model_loader = ModelLoader(model_args)
model_engine = ModelEngine(model_args)
trainer = SFTTrainer(
args=training_args,
model=model_loader.model,
processor=model_loader.processor,
model=model_engine.model,
processor=model_engine.processor,
dataset=data_engine,
)
trainer.fit()

View File

@@ -11,3 +11,5 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
IGNORE_INDEX = -100

View File

@@ -0,0 +1,29 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from transformers import PreTrainedTokenizer
from .types import Processor
def get_tokenizer(processor: Processor) -> PreTrainedTokenizer:
"""Get tokenizer from processor.
Args:
processor: Processor.
Returns:
Tokenizer.
"""
return processor.tokenizer if hasattr(processor, "tokenizer") else processor

View File

@@ -54,7 +54,7 @@ def _get_default_logging_level() -> "logging._Level":
def _get_library_name() -> str:
return __name__.split(".")[0]
return ".".join(__name__.split(".")[:2]) # llamafactory.v1
def _get_library_root_logger() -> "_Logger":

View File

@@ -13,6 +13,7 @@
# limitations under the License.
from collections import defaultdict
from collections.abc import Callable
from . import logging
@@ -27,7 +28,7 @@ class BasePlugin:
A plugin is a callable object that can be registered and called by name.
"""
_registry: dict[str, Callable] = {}
_registry: dict[str, dict[str, Callable]] = defaultdict(dict)
def __init__(self, name: str | None = None):
"""Initialize the plugin with a name.
@@ -37,8 +38,7 @@ class BasePlugin:
"""
self.name = name
@property
def register(self):
def register(self, method_name: str = "__call__"):
"""Decorator to register a function as a plugin.
Example usage:
@@ -46,16 +46,21 @@ class BasePlugin:
@PrintPlugin("hello").register()
def print_hello():
print("Hello world!")
@PrintPlugin("hello").register("again")
def print_hello_again():
print("Hello world! Again.")
```
"""
if self.name is None:
raise ValueError("Plugin name is not specified.")
raise ValueError("Plugin name should be specified.")
if self.name in self._registry:
logger.warning_rank0_once(f"Plugin {self.name} is already registered.")
if method_name in self._registry[self.name]:
logger.warning_rank0_once(f"Method {method_name} of plugin {self.name} is already registered.")
def decorator(func: Callable) -> Callable:
self._registry[self.name] = func
self._registry[self.name][method_name] = func
return func
return decorator
@@ -68,10 +73,23 @@ class BasePlugin:
PrintPlugin("hello")()
```
"""
if self.name not in self._registry:
raise ValueError(f"Plugin {self.name} is not registered.")
if "__call__" not in self._registry[self.name]:
raise ValueError(f"Method __call__ of plugin {self.name} is not registered.")
return self._registry[self.name](*args, **kwargs)
return self._registry[self.name]["__call__"](*args, **kwargs)
def __getattr__(self, method_name: str):
"""Get the registered function with the given name.
Example usage:
```python
PrintPlugin("hello").again()
```
"""
if method_name not in self._registry[self.name]:
raise ValueError(f"Method {method_name} of plugin {self.name} is not registered.")
return self._registry[self.name][method_name]
if __name__ == "__main__":
@@ -82,8 +100,13 @@ if __name__ == "__main__":
class PrintPlugin(BasePlugin):
pass
@PrintPlugin("hello").register
@PrintPlugin("hello").register()
def print_hello():
print("Hello world!")
@PrintPlugin("hello").register("again")
def print_hello_again():
print("Hello world! Again.")
PrintPlugin("hello")()
PrintPlugin("hello").again()

View File

@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Literal, NotRequired, TypedDict, Union
from typing import TYPE_CHECKING, Any, Literal, NotRequired, TypedDict, Union
if TYPE_CHECKING:
@@ -84,27 +84,63 @@ class DistributedConfig(TypedDict, total=False):
class Content(TypedDict):
type: Literal["text", "reasoning", "tools", "tool_calls", "image_url"]
type: Literal["text", "reasoning", "tool_call", "image_url"]
"""Type of the content."""
value: str
"""Value of the content."""
class Message(TypedDict):
role: Literal["system", "user", "assistant", "tool"]
"""Role of the message."""
content: list[Content]
loss_weight: float
"""Content of the message."""
loss_weight: NotRequired[float]
"""Loss weight for this message, default to 1.0. Required in training."""
class SFTSample(TypedDict):
messages: list[Message]
"""Messages in the sample."""
tools: NotRequired[str]
"""Tools for the sample in JSON string format."""
extra_info: NotRequired[str]
"""Extra information for the sample, e.g. kto_labels."""
_dataset_name: NotRequired[str]
"""Dataset name for the sample."""
class DPOSample(TypedDict):
chosen_messages: list[Message]
"""Chosen messages in the sample."""
rejected_messages: list[Message]
"""Rejected messages in the sample."""
tools: NotRequired[str]
"""Tools for the sample in JSON string format."""
extra_info: NotRequired[str]
"""Extra information for the sample, e.g. kto_labels."""
_dataset_name: NotRequired[str]
"""Dataset name for the sample."""
Sample = Union[SFTSample, DPOSample]
class ToolCall(TypedDict):
name: str
"""Function name."""
arguments: dict[str, Any]
"""Function arguments."""
class ModelInput(TypedDict, total=False):
input_ids: list[int]
"""Input ids for the model."""
attention_mask: list[int]
"""Attention mask for the model."""
labels: list[int]
"""Labels for the model."""
loss_weights: list[float]
"""Loss weight for each token, default to 1.0."""
position_ids: NotRequired[list[int] | list[list[int]]]
"""Position ids for the model (optional)."""

View File

@@ -12,13 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""LLaMA-Factory test configuration.
"""LlamaFactory test configuration.
Contains shared fixtures, pytest configuration, and custom markers.
"""
import os
from typing import Optional
import sys
import pytest
import torch
@@ -73,7 +73,7 @@ def _handle_slow_tests(items: list[Item]):
item.add_marker(skip_slow)
def _get_visible_devices_env() -> Optional[str]:
def _get_visible_devices_env() -> str | None:
"""Return device visibility env var name."""
if CURRENT_DEVICE == "cuda":
return "CUDA_VISIBLE_DEVICES"
@@ -110,11 +110,10 @@ def _handle_device_visibility(items: list[Item]):
def pytest_collection_modifyitems(config: Config, items: list[Item]):
"""Modify test collection based on markers and environment."""
# Handle version compatibility (from HEAD)
if not is_transformers_version_greater_than("4.57.0"):
skip_bc = pytest.mark.skip(reason="Skip backward compatibility tests")
for item in items:
if "tests_v1" in str(item.fspath):
item.add_marker(skip_bc)
skip_bc = pytest.mark.skip(reason="Skip backward compatibility tests")
for item in items:
if "tests_v1" in str(item.fspath) and not is_transformers_version_greater_than("4.57.0"):
item.add_marker(skip_bc)
_handle_slow_tests(items)
_handle_runs_on(items)
@@ -150,12 +149,21 @@ def _manage_distributed_env(request: FixtureRequest, monkeypatch: MonkeyPatch) -
devices_str = ",".join(str(i) for i in range(required))
monkeypatch.setenv(env_key, devices_str)
# add project root dir to path for mp run
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
if project_root not in sys.path:
sys.path.insert(0, project_root)
os.environ["PYTHONPATH"] = project_root + os.pathsep + os.environ.get("PYTHONPATH", "")
else: # non-distributed test
if old_value:
visible_devices = [v for v in old_value.split(",") if v != ""]
monkeypatch.setenv(env_key, visible_devices[0] if visible_devices else "0")
else:
monkeypatch.setenv(env_key, "0")
if CURRENT_DEVICE == "cuda":
monkeypatch.setattr(torch.cuda, "device_count", lambda: 1)
elif CURRENT_DEVICE == "npu":

View File

@@ -292,3 +292,91 @@ def test_qwen_multi_tool_extractor():
("test_tool", """{"foo": "bar", "size": 10}"""),
("another_tool", """{"foo": "job", "size": 2}"""),
]
@pytest.mark.runs_on(["cpu", "mps"])
def test_lfm_function_formatter():
formatter = FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="lfm")
tool_calls = json.dumps(FUNCTION)
assert formatter.apply(content=tool_calls) == [
"""<|tool_call_start|>[tool_name(foo="bar", size=10)]<|tool_call_end|><|im_end|>\n"""
]
@pytest.mark.runs_on(["cpu", "mps"])
def test_lfm_multi_function_formatter():
formatter = FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="lfm")
tool_calls = json.dumps([FUNCTION] * 2)
assert formatter.apply(content=tool_calls) == [
"""<|tool_call_start|>[tool_name(foo="bar", size=10), tool_name(foo="bar", size=10)]<|tool_call_end|>"""
"<|im_end|>\n"
]
@pytest.mark.runs_on(["cpu", "mps"])
def test_lfm_tool_formatter():
formatter = ToolFormatter(tool_format="lfm")
assert formatter.apply(content=json.dumps(TOOLS)) == [
"List of tools: <|tool_list_start|>" + json.dumps(TOOLS, ensure_ascii=False) + "<|tool_list_end|>"
]
@pytest.mark.runs_on(["cpu", "mps"])
def test_lfm_tool_extractor():
formatter = ToolFormatter(tool_format="lfm")
result = """<|tool_call_start|>[test_tool(foo="bar", size=10)]<|tool_call_end|>"""
assert formatter.extract(result) == [("test_tool", """{"foo": "bar", "size": 10}""")]
@pytest.mark.runs_on(["cpu", "mps"])
def test_lfm_multi_tool_extractor():
formatter = ToolFormatter(tool_format="lfm")
result = """<|tool_call_start|>[test_tool(foo="bar", size=10), another_tool(foo="job", size=2)]<|tool_call_end|>"""
assert formatter.extract(result) == [
("test_tool", """{"foo": "bar", "size": 10}"""),
("another_tool", """{"foo": "job", "size": 2}"""),
]
@pytest.mark.runs_on(["cpu", "mps"])
def test_lfm_tool_extractor_with_nested_dict():
formatter = ToolFormatter(tool_format="lfm")
result = """<|tool_call_start|>[search(query="test", options={"limit": 10, "offset": 0})]<|tool_call_end|>"""
extracted = formatter.extract(result)
assert len(extracted) == 1
assert extracted[0][0] == "search"
args = json.loads(extracted[0][1])
assert args["query"] == "test"
assert args["options"] == {"limit": 10, "offset": 0}
@pytest.mark.runs_on(["cpu", "mps"])
def test_lfm_tool_extractor_with_list_arg():
formatter = ToolFormatter(tool_format="lfm")
result = """<|tool_call_start|>[batch_process(items=[1, 2, 3], enabled=True)]<|tool_call_end|>"""
extracted = formatter.extract(result)
assert len(extracted) == 1
assert extracted[0][0] == "batch_process"
args = json.loads(extracted[0][1])
assert args["items"] == [1, 2, 3]
assert args["enabled"] is True
@pytest.mark.runs_on(["cpu", "mps"])
def test_lfm_tool_extractor_no_match():
formatter = ToolFormatter(tool_format="lfm")
result = "This is a regular response without tool calls."
extracted = formatter.extract(result)
assert extracted == result
@pytest.mark.runs_on(["cpu", "mps"])
def test_lfm_tool_round_trip():
formatter = FunctionFormatter(slots=["{{content}}"], tool_format="lfm")
tool_formatter = ToolFormatter(tool_format="lfm")
original = {"name": "my_func", "arguments": {"arg1": "hello", "arg2": 42, "arg3": True}}
formatted = formatter.apply(content=json.dumps(original))
extracted = tool_formatter.extract(formatted[0])
assert len(extracted) == 1
assert extracted[0][0] == original["name"]
assert json.loads(extracted[0][1]) == original["arguments"]

View File

@@ -1,2 +1,2 @@
# change if test fails or cache is outdated
0.9.4.105
0.9.5.103

View File

@@ -24,7 +24,6 @@ def test_get_args_from_yaml(tmp_path: pathlib.Path):
### model
model: "llamafactory/tiny-random-qwen2.5"
trust_remote_code: true
use_fast_processor: true
model_class: "llm"
kernel_config:
name: "auto"

View File

@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""LLaMA-Factory test configuration.
"""LlamaFactory test configuration.
Contains shared fixtures, pytest configuration, and custom markers.
"""
@@ -22,6 +22,7 @@ import sys
import pytest
import torch
import torch.distributed as dist
from pytest import Config, FixtureRequest, Item, MonkeyPatch
from llamafactory.v1.accelerator.helper import get_current_accelerator, get_device_count
@@ -109,17 +110,24 @@ def _handle_device_visibility(items: list[Item]):
def pytest_collection_modifyitems(config: Config, items: list[Item]):
"""Modify test collection based on markers and environment."""
# Handle version compatibility (from HEAD)
if not is_transformers_version_greater_than("4.57.0"):
skip_bc = pytest.mark.skip(reason="Skip backward compatibility tests")
for item in items:
if "tests_v1" in str(item.fspath):
item.add_marker(skip_bc)
skip_bc = pytest.mark.skip(reason="Skip backward compatibility tests")
for item in items:
if "tests_v1" in str(item.fspath) and not is_transformers_version_greater_than("4.57.0"):
item.add_marker(skip_bc)
_handle_slow_tests(items)
_handle_runs_on(items)
_handle_device_visibility(items)
@pytest.fixture(autouse=True)
def _cleanup_distributed_state():
"""Cleanup distributed state after each test."""
yield
if dist.is_initialized():
dist.destroy_process_group()
@pytest.fixture(autouse=True)
def _manage_distributed_env(request: FixtureRequest, monkeypatch: MonkeyPatch) -> None:
"""Set environment variables for distributed tests if specific devices are requested."""
@@ -155,6 +163,7 @@ def _manage_distributed_env(request: FixtureRequest, monkeypatch: MonkeyPatch) -
monkeypatch.setenv(env_key, visible_devices[0] if visible_devices else "0")
else:
monkeypatch.setenv(env_key, "0")
if CURRENT_DEVICE == "cuda":
monkeypatch.setattr(torch.cuda, "device_count", lambda: 1)
elif CURRENT_DEVICE == "npu":

View File

@@ -1,173 +0,0 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Integration tests for DataLoader with different combinations of packing and dynamic batching.
Tests the 4 scenarios:
a) non pack + non dynamic.
b) non pack + dynamic.
c) pack + non dynamic.
d) pack + dynamic.
"""
import torch
from torch.utils.data import DataLoader as TorchDataLoader
from torch.utils.data import Dataset
from transformers import AutoTokenizer
from llamafactory.v1.config.data_args import DataArguments
from llamafactory.v1.core.data_engine import DataEngine
from llamafactory.v1.core.trainer_utils.data_collator import (
DefaultCollator,
)
from llamafactory.v1.core.trainer_utils.data_loader import DataLoader
from llamafactory.v1.plugins.data_plugins.template import QwenTemplate
from llamafactory.v1.utils.batching_queue import TextBatchingQueue
class TensorDataset(Dataset):
"""Wrapper dataset that converts DataEngine samples to tensor format."""
def __init__(self, data_engine: DataEngine, processor, template, max_samples: int = None):
self.data_engine = data_engine
self.processor = processor
self.template = template
self.max_samples = max_samples or len(data_engine)
self.tokenizer = processor.tokenizer if hasattr(processor, "tokenizer") else processor
def __len__(self):
return min(self.max_samples, len(self.data_engine))
def __getitem__(self, idx):
# Get sample from DataEngine
sample = self.data_engine[idx]
# Extract messages from sample
# DataEngine returns samples with format like {"messages": [...], ...}
# For llamafactory/v1-sft-demo, the format should have "messages" field
messages = None
if "messages" in sample:
messages = sample["messages"]
elif "conversations" in sample:
messages = sample["conversations"]
elif "conversation" in sample:
messages = sample["conversation"]
else:
# Try to find message-like fields (skip _dataset_name)
for key, value in sample.items():
if key.startswith("_"):
continue
if isinstance(value, list) and len(value) > 0:
# Check if it looks like a message list
if isinstance(value[0], dict) and "role" in value[0]:
messages = value
break
if messages is None:
raise ValueError(f"Could not find messages in sample: {list(sample.keys())}")
# Encode messages using template
encoded = self.template.encode_messages(self.tokenizer, messages)
# Convert to tensors
return {
"input_ids": torch.tensor(encoded["input_ids"], dtype=torch.long),
"attention_mask": torch.tensor(encoded["attention_mask"], dtype=torch.long),
"labels": torch.tensor(encoded["labels"], dtype=torch.long),
}
def create_real_dataset(max_samples: int = 20, batch_size: int = 4):
"""Create a real dataset using DataEngine."""
data_args = DataArguments(dataset="llamafactory/v1-sft-demo")
data_engine = DataEngine(data_args)
# Create processor and template
processor = AutoTokenizer.from_pretrained("llamafactory/tiny-random-qwen2.5")
template = QwenTemplate()
# Create tensor dataset
raw_data_dataset = TensorDataset(data_engine, processor, template, max_samples=max_samples)
# Create torch DataLoader
torch_dataloader = TorchDataLoader(
raw_data_dataset,
batch_size=batch_size,
shuffle=False,
collate_fn=lambda x: x,
)
return torch_dataloader, processor, template
class TestDataLoaderNonPackNonDynamic:
"""Test case a) non pack + non dynamic."""
def test_basic_functionality(self):
"""Test DataLoader without packing and without dynamic batching."""
# Create real dataset
torch_dataloader, processor, template = create_real_dataset(max_samples=80, batch_size=8)
# Create collator (non-packing)
collator = DefaultCollator(processor=processor, template=template)
# Create DataLoader without batching_queue (non-dynamic)
data_loader = DataLoader(
dataloader=torch_dataloader,
collate_fn=collator,
num_micro_batch=1,
batching_queue=None,
)
# Iterate and check results
batches = list(iter(data_loader))
assert len(batches) > 0
# Check first batch
one_batch = batches[0]
micro_batches = one_batch[0]
assert "input_ids" in micro_batches
assert "attention_mask" in micro_batches
assert "labels" in micro_batches
assert micro_batches["input_ids"].shape[0] == 1 # batch_size=1
assert micro_batches["input_ids"].ndim == 2 # [batch_size, seq_len]
class TestDataLoaderNonPackDynamic:
"""Test case b) non pack + dynamic."""
def test_basic_functionality(self):
"""Test DataLoader without packing but with dynamic batching."""
# Create real dataset
torch_dataloader, processor, template = create_real_dataset(max_samples=80, batch_size=8)
collator = DefaultCollator(processor=processor, template=template)
# Create batching queue for dynamic batching
batching_queue = TextBatchingQueue(
token_micro_bsz=120,
buffer_size=8,
)
data_loader = DataLoader(
dataloader=torch_dataloader,
collate_fn=collator,
num_micro_batch=4,
batching_queue=batching_queue,
)
# Iterate and check
batches = list(iter(data_loader))
micro_batch_tokens_first = [micro_batch["attention_mask"].sum() for micro_batch in batches[0]]
assert all(num_tokens <= 120 for num_tokens in micro_batch_tokens_first)
assert len(batches) > 0

View File

@@ -15,18 +15,18 @@
import torch
from llamafactory.v1.config.model_args import ModelArguments, PluginConfig
from llamafactory.v1.core.model_loader import ModelLoader
from llamafactory.v1.core.model_engine import ModelEngine
def test_tiny_qwen():
from transformers import Qwen2Config, Qwen2ForCausalLM, Qwen2TokenizerFast
model_args = ModelArguments(model="llamafactory/tiny-random-qwen2.5")
model_loader = ModelLoader(model_args)
assert isinstance(model_loader.processor, Qwen2TokenizerFast)
assert isinstance(model_loader.model.config, Qwen2Config)
assert isinstance(model_loader.model, Qwen2ForCausalLM)
assert model_loader.model.dtype == torch.bfloat16
model_engine = ModelEngine(model_args)
assert isinstance(model_engine.processor, Qwen2TokenizerFast)
assert isinstance(model_engine.model_config, Qwen2Config)
assert isinstance(model_engine.model, Qwen2ForCausalLM)
assert model_engine.model.dtype == torch.bfloat16
def test_tiny_qwen_with_kernel_plugin():
@@ -37,13 +37,14 @@ def test_tiny_qwen_with_kernel_plugin():
model_args = ModelArguments(
model="llamafactory/tiny-random-qwen2.5", kernel_config=PluginConfig(name="auto", include_kernels="auto")
)
model_loader = ModelLoader(model_args)
model_engine = ModelEngine(model_args)
# test enable apply kernel plugin
if hasattr(torch, "npu"):
assert model_loader.model.model.layers[0].input_layernorm.forward.__code__ == npu_rms_norm_forward.__code__
assert model_engine.model.model.layers[0].input_layernorm.forward.__code__ == npu_rms_norm_forward.__code__
else:
assert model_loader.model.model.layers[0].input_layernorm.forward.__code__ != npu_rms_norm_forward.__code__
assert isinstance(model_loader.model, Qwen2ForCausalLM)
assert model_engine.model.model.layers[0].input_layernorm.forward.__code__ != npu_rms_norm_forward.__code__
assert isinstance(model_engine.model, Qwen2ForCausalLM)
if __name__ == "__main__":

View File

@@ -0,0 +1,171 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Integration tests for DataLoader with different combinations of packing and dynamic batching.
Tests the 4 scenarios:
a) non pack + non dynamic.
b) non pack + dynamic.
c) pack + non dynamic.
d) pack + dynamic.
"""
# import torch
# from torch.utils.data import DataLoader as TorchDataLoader
# from torch.utils.data import Dataset
# from transformers import AutoTokenizer
# from llamafactory.v1.config.data_args import DataArguments
# from llamafactory.v1.core.data_engine import DataEngine
# from llamafactory.v1.core.utils.data_collator import DefaultCollator
# from llamafactory.v1.core.utils.data_loader import DataLoader
# from llamafactory.v1.plugins.data_plugins.rendering import QwenTemplate
# from llamafactory.v1.utils.batching_queue import TextBatchingQueue
# class TensorDataset(Dataset):
# """Wrapper dataset that converts DataEngine samples to tensor format."""
# def __init__(self, data_engine: DataEngine, processor, template, max_samples: int = None):
# self.data_engine = data_engine
# self.processor = processor
# self.template = template
# self.max_samples = max_samples or len(data_engine)
# self.tokenizer = processor.tokenizer if hasattr(processor, "tokenizer") else processor
# def __len__(self):
# return min(self.max_samples, len(self.data_engine))
# def __getitem__(self, idx):
# # Get sample from DataEngine
# sample = self.data_engine[idx]
# # Extract messages from sample
# # DataEngine returns samples with format like {"messages": [...], ...}
# # For llamafactory/v1-sft-demo, the format should have "messages" field
# messages = None
# if "messages" in sample:
# messages = sample["messages"]
# elif "conversations" in sample:
# messages = sample["conversations"]
# elif "conversation" in sample:
# messages = sample["conversation"]
# else:
# # Try to find message-like fields (skip _dataset_name)
# for key, value in sample.items():
# if key.startswith("_"):
# continue
# if isinstance(value, list) and len(value) > 0:
# # Check if it looks like a message list
# if isinstance(value[0], dict) and "role" in value[0]:
# messages = value
# break
# if messages is None:
# raise ValueError(f"Could not find messages in sample: {list(sample.keys())}")
# # Encode messages using template
# encoded = self.template.encode_messages(self.tokenizer, messages)
# # Convert to tensors
# return {
# "input_ids": torch.tensor(encoded["input_ids"], dtype=torch.long),
# "attention_mask": torch.tensor(encoded["attention_mask"], dtype=torch.long),
# "labels": torch.tensor(encoded["labels"], dtype=torch.long),
# }
# def create_real_dataset(max_samples: int = 20, batch_size: int = 4):
# """Create a real dataset using DataEngine."""
# data_args = DataArguments(dataset="llamafactory/v1-sft-demo")
# data_engine = DataEngine(data_args)
# # Create processor and template
# processor = AutoTokenizer.from_pretrained("llamafactory/tiny-random-qwen2.5")
# template = QwenTemplate()
# # Create tensor dataset
# raw_data_dataset = TensorDataset(data_engine, processor, template, max_samples=max_samples)
# # Create torch DataLoader
# torch_dataloader = TorchDataLoader(
# raw_data_dataset,
# batch_size=batch_size,
# shuffle=False,
# collate_fn=lambda x: x,
# )
# return torch_dataloader, processor, template
# class TestDataLoaderNonPackNonDynamic:
# """Test case a) non pack + non dynamic."""
# def test_basic_functionality(self):
# """Test DataLoader without packing and without dynamic batching."""
# # Create real dataset
# torch_dataloader, processor, template = create_real_dataset(max_samples=80, batch_size=8)
# # Create collator (non-packing)
# collator = DefaultCollator(processor=processor, template=template)
# # Create DataLoader without batching_queue (non-dynamic)
# data_loader = DataLoader(
# dataloader=torch_dataloader,
# collate_fn=collator,
# num_micro_batch=1,
# batching_queue=None,
# )
# # Iterate and check results
# batches = list(iter(data_loader))
# assert len(batches) > 0
# # Check first batch
# one_batch = batches[0]
# micro_batches = one_batch[0]
# assert "input_ids" in micro_batches
# assert "attention_mask" in micro_batches
# assert "labels" in micro_batches
# assert micro_batches["input_ids"].shape[0] == 1 # batch_size=1
# assert micro_batches["input_ids"].ndim == 2 # [batch_size, seq_len]
# class TestDataLoaderNonPackDynamic:
# """Test case b) non pack + dynamic."""
# def test_basic_functionality(self):
# """Test DataLoader without packing but with dynamic batching."""
# # Create real dataset
# torch_dataloader, processor, template = create_real_dataset(max_samples=80, batch_size=8)
# collator = DefaultCollator(processor=processor, template=template)
# # Create batching queue for dynamic batching
# batching_queue = TextBatchingQueue(
# token_micro_bsz=120,
# buffer_size=8,
# )
# data_loader = DataLoader(
# dataloader=torch_dataloader,
# collate_fn=collator,
# num_micro_batch=4,
# batching_queue=batching_queue,
# )
# # Iterate and check
# batches = list(iter(data_loader))
# micro_batch_tokens_first = [micro_batch["attention_mask"].sum() for micro_batch in batches[0]]
# assert all(num_tokens <= 120 for num_tokens in micro_batch_tokens_first)
# assert len(batches) > 0

View File

@@ -0,0 +1,193 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import pytest
from transformers import AutoTokenizer
from llamafactory.v1.config import DataArguments
from llamafactory.v1.core.data_engine import DataEngine
from llamafactory.v1.core.utils.rendering import Renderer
from llamafactory.v1.utils.types import Processor
HF_MESSAGES = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "What is LLM?"},
{"role": "assistant", "content": "LLM stands for Large Language Model."},
]
V1_MESSAGES = [
{"role": "system", "content": [{"type": "text", "value": "You are a helpful assistant."}]},
{"role": "user", "content": [{"type": "text", "value": "What is LLM?"}]},
{"role": "assistant", "content": [{"type": "text", "value": "LLM stands for Large Language Model."}]},
]
HF_MESSAGES_WITH_TOOLS = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "What is 6*8?"},
{
"role": "assistant",
"tool_calls": [{"type": "function", "function": {"name": "multiply", "arguments": {"a": 6, "b": 8}}}],
},
{"role": "tool", "content": "48."},
{"role": "assistant", "content": "The result of 6*8 is 48."},
]
V1_MESSAGES_WITH_TOOLS = [
{"role": "system", "content": [{"type": "text", "value": "You are a helpful assistant."}]},
{"role": "user", "content": [{"type": "text", "value": "What is 6*8?"}]},
{
"role": "assistant",
"content": [{"type": "tool_call", "value": json.dumps({"name": "multiply", "arguments": {"a": 6, "b": 8}})}],
"loss_weight": 0.0,
},
{"role": "tool", "content": [{"type": "text", "value": "48."}]},
{"role": "assistant", "content": [{"type": "text", "value": "The result of 6*8 is 48."}]},
]
V1_TOOLS = [
{
"type": "function",
"function": {
"name": "multiply",
"description": "A function that multiplies two numbers",
"parameters": {
"type": "object",
"properties": {
"a": {"type": "number", "description": "The first number to multiply"},
"b": {"type": "number", "description": "The second number to multiply"},
},
"required": ["a", "b"],
},
},
}
]
def test_chatml_rendering():
tokenizer: Processor = AutoTokenizer.from_pretrained("llamafactory/tiny-random-qwen3")
renderer = Renderer(template="chatml", processor=tokenizer)
hf_inputs = tokenizer.apply_chat_template(HF_MESSAGES[:-1], add_generation_prompt=True)
v1_inputs = renderer.render_messages(V1_MESSAGES[:-1], is_generate=True)
assert v1_inputs["input_ids"] == hf_inputs
assert v1_inputs["attention_mask"] == [1] * len(hf_inputs)
assert v1_inputs["labels"] == [-100] * len(hf_inputs)
assert v1_inputs["loss_weights"] == [0.0] * len(hf_inputs)
hf_inputs_part = tokenizer.apply_chat_template(HF_MESSAGES[:-1], add_generation_prompt=False)
hf_inputs_full = tokenizer.apply_chat_template(HF_MESSAGES, add_generation_prompt=False)
v1_inputs_full = renderer.render_messages(V1_MESSAGES, is_generate=False)
assert v1_inputs_full["input_ids"] == hf_inputs_full
assert v1_inputs_full["attention_mask"] == [1] * len(hf_inputs_full)
assert v1_inputs_full["labels"] == [-100] * len(hf_inputs_part) + hf_inputs_full[len(hf_inputs_part) :]
assert v1_inputs_full["loss_weights"] == [0.0] * len(hf_inputs_part) + [1.0] * (
len(hf_inputs_full) - len(hf_inputs_part)
)
def test_chatml_parse():
tokenizer: Processor = AutoTokenizer.from_pretrained("llamafactory/tiny-random-qwen3")
renderer = Renderer(template="chatml", processor=tokenizer)
generated_text = "LLM stands for Large Language Model."
parsed_message = renderer.parse_message(generated_text)
assert parsed_message == V1_MESSAGES[-1]
@pytest.mark.parametrize("num_samples", [16])
def test_chatml_rendering_remote(num_samples: int):
tokenizer: Processor = AutoTokenizer.from_pretrained("llamafactory/tiny-random-qwen3")
renderer = Renderer(template="chatml", processor=tokenizer)
data_args = DataArguments(dataset="llamafactory/v1-sft-demo")
data_engine = DataEngine(data_args)
for index in range(num_samples):
v1_inputs = renderer.render_messages(data_engine[index]["messages"], is_generate=True)
prefix = tokenizer.encode("<|im_start|>user\n", add_special_tokens=False)
print(tokenizer.decode(v1_inputs["input_ids"][: len(prefix)]))
assert v1_inputs["input_ids"][: len(prefix)] == prefix
def test_qwen3_nothink_rendering():
tokenizer: Processor = AutoTokenizer.from_pretrained("Qwen/Qwen3-4B-Instruct-2507")
renderer = Renderer(template="qwen3_nothink", processor=tokenizer)
hf_inputs = tokenizer.apply_chat_template(HF_MESSAGES_WITH_TOOLS[:-1], tools=V1_TOOLS, add_generation_prompt=True)
v1_inputs = renderer.render_messages(V1_MESSAGES_WITH_TOOLS[:-1], tools=json.dumps(V1_TOOLS), is_generate=True)
assert v1_inputs["input_ids"] == hf_inputs
assert v1_inputs["attention_mask"] == [1] * len(hf_inputs)
assert v1_inputs["labels"] == [-100] * len(hf_inputs)
assert v1_inputs["loss_weights"] == [0.0] * len(hf_inputs)
hf_inputs_part = tokenizer.apply_chat_template(
HF_MESSAGES_WITH_TOOLS[:-1], tools=V1_TOOLS, add_generation_prompt=False
)
hf_inputs_full = tokenizer.apply_chat_template(HF_MESSAGES_WITH_TOOLS, tools=V1_TOOLS, add_generation_prompt=False)
v1_inputs_full = renderer.render_messages(V1_MESSAGES_WITH_TOOLS, tools=json.dumps(V1_TOOLS), is_generate=False)
assert v1_inputs_full["input_ids"] == hf_inputs_full
assert v1_inputs_full["attention_mask"] == [1] * len(hf_inputs_full)
assert v1_inputs_full["labels"] == [-100] * len(hf_inputs_part) + hf_inputs_full[len(hf_inputs_part) :]
assert v1_inputs_full["loss_weights"] == [0.0] * len(hf_inputs_part) + [1.0] * (
len(hf_inputs_full) - len(hf_inputs_part)
)
def test_qwen3_nothink_parse():
tokenizer: Processor = AutoTokenizer.from_pretrained("Qwen/Qwen3-4B-Instruct-2507")
renderer = Renderer(template="qwen3_nothink", processor=tokenizer)
generated_text = (
"<thinking>I need to use the multiply function to calculate 6*8.</thinking>"
"Let me call the multiply function."
'<tool_call>{"name": "multiply", "arguments": {"a": 6, "b": 8}}</tool_call>'
)
parsed_message = renderer.parse_message(generated_text)
assert parsed_message == {
"role": "assistant",
"content": [
{"type": "reasoning", "value": "I need to use the multiply function to calculate 6*8."},
{"type": "text", "value": "Let me call the multiply function."},
{"type": "tool_call", "value": json.dumps({"name": "multiply", "arguments": {"a": 6, "b": 8}})},
],
}
@pytest.mark.parametrize("num_samples", [8])
def test_qwen3_nothink_rendering_remote(num_samples: int):
tokenizer: Processor = AutoTokenizer.from_pretrained("Qwen/Qwen3-4B-Instruct-2507")
renderer = Renderer(template="qwen3_nothink", processor=tokenizer)
data_args = DataArguments(dataset="llamafactory/reason-tool-use-demo-1500")
data_engine = DataEngine(data_args)
for index in range(num_samples):
v1_inputs = renderer.render_messages(data_engine[index]["messages"], tools=data_engine[index]["tools"])
prefix_text = (
"<|im_start|>system\nYou are a methodical and expert assistant. "
"Your primary goal is to solve user requests by leveraging a set of available tools. "
"You must reason for the best course of action in a structured manner before responding.\n\n"
"# Tools\n\nYou may call one or more functions to assist with the user query.\n\n"
"You are provided with function signatures within <tools></tools> XML tags:\n<tools>\n"
'{"type": "function", "function": {"name":'
)
prefix = tokenizer.encode(prefix_text, add_special_tokens=False)
print(tokenizer.decode(v1_inputs["input_ids"][: len(prefix)]))
assert v1_inputs["input_ids"][: len(prefix)] == prefix
if __name__ == "__main__":
test_chatml_rendering()
test_chatml_parse()
test_chatml_rendering_remote(16)
test_qwen3_nothink_rendering()
test_qwen3_nothink_parse()
test_qwen3_nothink_rendering_remote(16)

View File

@@ -54,18 +54,18 @@ def test_sharegpt_converter():
"conversations": [
{"from": "system", "value": "System"},
{"from": "human", "value": "User"},
{"from": "function_call", "value": "Tool"},
{"from": "function_call", "value": "1"},
{"from": "observation", "value": "Observation"},
{"from": "gpt", "value": "Assistant"},
]
}
expected_data = {
"messages": [
{"content": [{"type": "text", "value": "System"}], "loss_weight": 0.0, "role": "system"},
{"content": [{"type": "text", "value": "User"}], "loss_weight": 0.0, "role": "user"},
{"content": [{"type": "tool_calls", "value": "Tool"}], "loss_weight": 1.0, "role": "assistant"},
{"content": [{"type": "text", "value": "Observation"}], "loss_weight": 0.0, "role": "tool"},
{"content": [{"type": "text", "value": "Assistant"}], "loss_weight": 1.0, "role": "assistant"},
{"role": "system", "content": [{"type": "text", "value": "System"}], "loss_weight": 0.0},
{"role": "user", "content": [{"type": "text", "value": "User"}], "loss_weight": 0.0},
{"role": "assistant", "content": [{"type": "tool_call", "value": "1"}], "loss_weight": 1.0},
{"role": "tool", "content": [{"type": "text", "value": "Observation"}], "loss_weight": 0.0},
{"role": "assistant", "content": [{"type": "text", "value": "Assistant"}], "loss_weight": 1.0},
]
}
assert DataConverterPlugin("sharegpt")(example) == expected_data

View File

@@ -0,0 +1,54 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from llamafactory.v1.accelerator.interface import DistributedInterface
from llamafactory.v1.config.arg_parser import get_args
from llamafactory.v1.core.model_engine import ModelEngine
def test_init_on_meta():
_, model_args, *_ = get_args(
dict(
model="llamafactory/tiny-random-qwen2.5",
init_config={"name": "init_on_meta"},
)
)
model_engine = ModelEngine(model_args=model_args)
assert model_engine.model.device.type == "meta"
def test_init_on_rank0():
_, model_args, *_ = get_args(
dict(
model="llamafactory/tiny-random-qwen2.5",
init_config={"name": "init_on_rank0"},
)
)
model_engine = ModelEngine(model_args=model_args)
if DistributedInterface().get_rank() == 0:
assert model_engine.model.device.type == "cpu"
else:
assert model_engine.model.device.type == "meta"
def test_init_on_default():
_, model_args, *_ = get_args(
dict(
model="llamafactory/tiny-random-qwen2.5",
init_config={"name": "init_on_default"},
)
)
model_engine = ModelEngine(model_args=model_args)
assert model_engine.model.device == DistributedInterface().current_device

View File

@@ -0,0 +1,41 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
from llamafactory.v1.config import ModelArguments, SampleArguments
from llamafactory.v1.core.model_engine import ModelEngine
from llamafactory.v1.samplers.cli_sampler import SyncSampler
@pytest.mark.runs_on(["cuda", "npu"])
def test_sync_sampler():
model_args = ModelArguments(model="Qwen/Qwen3-4B-Instruct-2507", template="qwen3_nothink")
sample_args = SampleArguments()
model_engine = ModelEngine(model_args)
sampler = SyncSampler(sample_args, model_args, model_engine.model, model_engine.renderer)
messages = [{"role": "user", "content": [{"type": "text", "value": "Say 'This is a test.'"}]}]
response = ""
for new_text in sampler.generate(messages):
response += new_text
print(response)
assert model_engine.renderer.parse_message(response) == {
"role": "assistant",
"content": [{"type": "text", "value": "This is a test."}],
}
if __name__ == "__main__":
test_sync_sampler()