9 Commits

Author SHA1 Message Date
娄宗志
589da21d32 [model] support Aeva (#10214) 2026-02-26 23:03:13 +08:00
Yaowei Zheng
122cd46084 [model] update constants (#10220) 2026-02-26 21:13:56 +08:00
浮梦
2b8b871475 [model] Adapt Qwen3.5 (#10213)
Co-authored-by: frozenleaves <frozen@Mac.local>
Co-authored-by: Yaowei Zheng <hiyouga@buaa.edu.cn>
2026-02-26 20:45:02 +08:00
Shanay Mehta
aab9b400bb [model] Add DeepSpeed Z3 leaf module for Qwen3-Next (#10194)
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-24 19:54:37 +08:00
P. Clawmogorov
50599c719b [misc] remove safe_serialization arg for transformers v5 compatibility (#10208)
Co-authored-by: P. Clawmogorov <262173731+Alm0stSurely@users.noreply.github.com>
2026-02-24 11:14:19 +08:00
Kingsley
a0f3ad0cee [mca] update supported models (#10196) 2026-02-20 22:02:49 +08:00
jiaqiw09
f80e15dbb4 [ci] fix ut huggingface hub 429 error when transformers>=5.0.0 (#10155) 2026-02-12 22:14:10 +08:00
sunyi0505
991267fd3b [v1] support quantization (#10161) 2026-02-12 20:37:41 +08:00
浮梦
5c52afa30d [v1] support deepspeed (#10181) 2026-02-12 17:24:30 +08:00
31 changed files with 877 additions and 126 deletions

View File

@@ -25,16 +25,16 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: '3.10'
- name: Install dependencies
run: |
pip install -r docs/requirements.txt
- name: Build Sphinx
run: |
sphinx-build -b html docs/zh docs/_build/html/zh
@@ -56,10 +56,10 @@ jobs:
> docs/_build/html/index.html
touch docs/_build/html/.nojekyll
- name: Setup Pages
uses: actions/configure-pages@v5
- name: Upload artifact
uses: actions/upload-pages-artifact@v3
with:

View File

@@ -61,6 +61,7 @@ jobs:
uv venv
uv pip install -e .
uv pip install -r requirements/dev.txt
uv pip install -r requirements/bitsandbytes.txt
- name: Check quality
run: |

View File

@@ -291,7 +291,7 @@ Read technical notes:
| [GPT-2](https://huggingface.co/openai-community) | 0.1B/0.4B/0.8B/1.5B | - |
| [GPT-OSS](https://huggingface.co/openai) | 20B/120B | gpt_oss |
| [Granite 3-4](https://huggingface.co/ibm-granite) | 1B/2B/3B/7B/8B | granite3/granite4 |
| [Hunyuan/Hunyuan1.5 (MT)](https://huggingface.co/tencent/) | 0.5B/1.8B/4B/7B/13B | hunyuan/hunyuan_small |
| [Hunyuan/Hunyuan1.5 (MT)](https://huggingface.co/tencent/) | 0.5B/1.8B/4B/7B/13B | hunyuan/hunyuan_small|
| [InternLM 2-3](https://huggingface.co/internlm) | 7B/8B/20B | intern2 |
| [InternVL 2.5-3.5](https://huggingface.co/OpenGVLab) | 1B/2B/4B/8B/14B/30B/38B/78B/241B | intern_vl |
| [Intern-S1-mini](https://huggingface.co/internlm/) | 8B | intern_s1 |
@@ -319,6 +319,7 @@ Read technical notes:
| [Pixtral](https://huggingface.co/mistralai) | 12B | pixtral |
| [Qwen2 (Code/Math/MoE/QwQ)](https://huggingface.co/Qwen) | 0.5B/1.5B/3B/7B/14B/32B/72B/110B | qwen |
| [Qwen3 (MoE/Instruct/Thinking/Next)](https://huggingface.co/Qwen) | 0.6B/1.7B/4B/8B/14B/32B/80B/235B | qwen3/qwen3_nothink |
| [Qwen3.5](https://huggingface.co/Qwen) | 27B/35B/122B/397B | qwen3_5 |
| [Qwen2-Audio](https://huggingface.co/Qwen) | 7B | qwen2_audio |
| [Qwen2.5-Omni](https://huggingface.co/Qwen) | 3B/7B | qwen2_omni |
| [Qwen3-Omni](https://huggingface.co/Qwen) | 30B | qwen3_omni |

View File

@@ -293,7 +293,7 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
| [GPT-2](https://huggingface.co/openai-community) | 0.1B/0.4B/0.8B/1.5B | - |
| [GPT-OSS](https://huggingface.co/openai) | 20B/120B | gpt_oss |
| [Granite 3-4](https://huggingface.co/ibm-granite) | 1B/2B/3B/7B/8B | granite3/granite4 |
| [Hunyuan/Hunyuan1.5 (MT)](https://huggingface.co/tencent/) | 0.5B/1.8B/4B/7B/13B | hunyuan/hunyuan_small |
| [Hunyuan/Hunyuan1.5 (MT)](https://huggingface.co/tencent/) | 0.5B/1.8B/4B/7B/13B | hunyuan/hunyuan_small|
| [InternLM 2-3](https://huggingface.co/internlm) | 7B/8B/20B | intern2 |
| [InternVL 2.5-3.5](https://huggingface.co/OpenGVLab) | 1B/2B/4B/8B/14B/30B/38B/78B/241B | intern_vl |
| [Intern-S1-mini](https://huggingface.co/internlm/) | 8B | intern_s1 |
@@ -321,6 +321,7 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
| [Pixtral](https://huggingface.co/mistralai) | 12B | pixtral |
| [Qwen2 (Code/Math/MoE/QwQ)](https://huggingface.co/Qwen) | 0.5B/1.5B/3B/7B/14B/32B/72B/110B | qwen |
| [Qwen3 (MoE/Instruct/Thinking/Next)](https://huggingface.co/Qwen) | 0.6B/1.7B/4B/8B/14B/32B/80B/235B | qwen3/qwen3_nothink |
| [Qwen3.5](https://huggingface.co/Qwen) | 27B/35B/122B/397B | qwen3_5 |
| [Qwen2-Audio](https://huggingface.co/Qwen) | 7B | qwen2_audio |
| [Qwen2.5-Omni](https://huggingface.co/Qwen) | 3B/7B | qwen2_omni |
| [Qwen3-Omni](https://huggingface.co/Qwen) | 30B | qwen3_omni |

View File

@@ -47,4 +47,3 @@
border-color: rgba(255, 255, 255, 0.45);
box-shadow: 0 0 0 3px rgba(255, 255, 255, 0.12);
}

View File

@@ -1,33 +1,31 @@
# Configuration file for the Sphinx documentation builder.
import os
import sys
# Define common settings here
project = 'LlamaFactory'
copyright = '2024, LlamaFactory Team'
author = 'LlamaFactory Team'
project = "LlamaFactory"
copyright = "2024, LlamaFactory Team"
author = "LlamaFactory Team"
extensions = [
'sphinx.ext.autodoc',
'sphinx.ext.viewcode',
'sphinx.ext.napoleon',
'myst_parser',
"sphinx.ext.autodoc",
"sphinx.ext.viewcode",
"sphinx.ext.napoleon",
"myst_parser",
]
templates_path = ['_templates']
exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
templates_path = ["_templates"]
exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
html_theme = 'sphinx_rtd_theme'
html_theme = "sphinx_rtd_theme"
html_static_path = ['_static']
html_static_path = ["_static"]
html_js_files = [
'js/switcher.js',
"js/switcher.js",
]
html_css_files = [
'css/lang-switcher.css',
"css/lang-switcher.css",
]
myst_enable_extensions = [

View File

@@ -1,20 +1,22 @@
import os
import sys
# Add parent dir to path to allow importing conf.py
sys.path.insert(0, os.path.abspath('..'))
from conf import *
# Add parent dir to path to allow importing conf.py
sys.path.insert(0, os.path.abspath(".."))
from conf import * # noqa: F403
# Language settings
language = 'en'
html_search_language = 'en'
language = "en"
html_search_language = "en"
# Static files
# Point to the root _static directory
html_static_path = ['../_static']
html_static_path = ["../_static"]
# Add custom JS for language switcher
html_js_files = [
'js/switcher.js',
"js/switcher.js",
]

View File

@@ -1,20 +1,22 @@
import os
import sys
# Add parent dir to path to allow importing conf.py
sys.path.insert(0, os.path.abspath('..'))
from conf import *
# Add parent dir to path to allow importing conf.py
sys.path.insert(0, os.path.abspath(".."))
from conf import * # noqa: F403
# Language settings
language = 'zh_CN'
html_search_language = 'zh'
language = "zh_CN"
html_search_language = "zh"
# Static files
# Point to the root _static directory
html_static_path = ['../_static']
html_static_path = ["../_static"]
# Add custom JS for language switcher
html_js_files = [
'js/switcher.js',
"js/switcher.js",
]

View File

@@ -0,0 +1,24 @@
model: Qwen/Qwen3-0.6B
model_class: llm
template: qwen3_nothink
kernel_config:
name: auto
include_kernels: auto
dist_config:
name: deepspeed
config_file: examples/deepspeed/ds_z3_config.json
### data
train_dataset: data/v1_sft_demo.yaml
### training
output_dir: outputs/Qwen3-0.6B-deepspeed
micro_batch_size: 1
cutoff_len: 2048
learning_rate: 1.0e-4
bf16: true
max_steps: 10

View File

@@ -0,0 +1,43 @@
model: Qwen/Qwen3-0.6B
trust_remote_code: true
model_class: llm
template: qwen3_nothink
# PEFT Configuration
peft_config:
name: lora
r: 16
lora_alpha: 32
lora_dropout: 0.05
target_modules: all
# Kernel Config
kernel_config:
name: auto
include_kernels: auto
# FSDP Config
dist_config:
name: fsdp2
dcp_path: null
# Quantization Config
quant_config:
name: bnb # choice: auto/bnb if auto is selected, the quantization method will be automatically selected based on the model and environment.
quantization_bit: 4 # choice: 8/4(bnb)
### data
train_dataset: data/v1_sft_demo.yaml
### training
output_dir: outputs/test_quantization
micro_batch_size: 1
cutoff_len: 2048
learning_rate: 1.0e-4
bf16: false
max_steps: 10
### sample
sample_backend: hf
max_new_tokens: 128

View File

@@ -40,7 +40,7 @@ dependencies = [
"torch>=2.4.0",
"torchvision>=0.19.0",
"torchaudio>=2.4.0",
"transformers>=4.51.0,<=5.0.0,!=4.52.0,!=4.57.0",
"transformers>=4.51.0,<=5.2.0,!=4.52.0,!=4.57.0",
"datasets>=2.16.0,<=4.0.0",
"accelerate>=1.3.0,<=1.11.0",
"peft>=0.18.0,<=0.18.1",

View File

@@ -15,6 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Literal, Optional
@@ -189,6 +190,16 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
"video_grid_thw": mm_inputs.get("video_grid_thw"),
"attention_mask": (features["attention_mask"] >= 1).float(),
}
if "mm_token_type_ids" in inspect.signature(self.get_rope_func).parameters:
image_token_id = getattr(self.model.config, "image_token_id", None)
video_token_id = getattr(self.model.config, "video_token_id", None)
if image_token_id is not None or video_token_id is not None:
mm_token_type_ids = torch.zeros_like(features["input_ids"])
if image_token_id is not None:
mm_token_type_ids[features["input_ids"] == image_token_id] = 1
if video_token_id is not None:
mm_token_type_ids[features["input_ids"] == video_token_id] = 2
rope_index_kwargs["mm_token_type_ids"] = mm_token_type_ids
if "second_per_grid_ts" in mm_inputs: # for qwen2vl
rope_index_kwargs["second_per_grid_ts"] = mm_inputs.get("second_per_grid_ts")
elif "video_second_per_grid" in mm_inputs: # for qwen2.5 omni
@@ -219,6 +230,7 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
"qwen2_5_vl",
"qwen2_5_omni_thinker",
"qwen3_omni_moe_thinker",
"qwen3_5",
"qwen3_vl",
"qwen3_vl_moe",
]

View File

@@ -2029,6 +2029,39 @@ register_template(
)
register_template(
name="qwen3_5",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen3_5"),
format_observation=StringFormatter(
slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
),
format_tools=ToolFormatter(tool_format="qwen3_5"),
stop_words=["<|im_end|>"],
replace_eos=True,
mm_plugin=get_mm_plugin(name="qwen3_vl", image_token="<|image_pad|>", video_token="<|video_pad|>"),
template_class=ReasoningTemplate,
)
register_template(
name="qwen3_5_nothink",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen3_5"),
format_observation=StringFormatter(
slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
),
format_tools=ToolFormatter(tool_format="qwen3_5"),
stop_words=["<|im_end|>"],
replace_eos=True,
mm_plugin=get_mm_plugin(name="qwen3_vl", image_token="<|image_pad|>", video_token="<|video_pad|>"),
)
register_template(
name="sailor",
format_user=StringFormatter(slots=["<|im_start|>question\n{{content}}<|im_end|>\n<|im_start|>answer\n"]),
@@ -2218,3 +2251,24 @@ register_template(
format_system=StringFormatter(slots=["<|system|>\n{{content}}", {"eos_token"}]),
default_system="You are Zephyr, a helpful assistant.",
)
# copied from glm4_7 template
register_template(
name="aeva",
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]),
format_assistant=StringFormatter(slots=["\n{{content}}"]),
format_system=StringFormatter(slots=["<|system|>\n{{content}}"]),
format_function=FunctionFormatter(slots=["{{content}}"], tool_format="glm4_moe"),
format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>"]),
format_tools=ToolFormatter(tool_format="glm4_moe"),
format_prefix=EmptyFormatter(slots=["[gMASK]<sop>"]),
default_system=(
"You are an AI assistant named Aeva created by Zongzhi Lou. "
"Your answer should be friendly, unbiased, faithful, informative and detailed."
),
stop_words=["<|user|>", "<|observation|>"],
thought_words=("<think>", "</think>"),
efficient_eos=True,
template_class=Glm47ReasoningTemplate,
)

View File

@@ -85,6 +85,21 @@ QWEN_TOOL_PROMPT = (
""""arguments": <args-json-object>}}\n</tool_call>"""
)
QWEN35_TOOL_PROMPT = (
"\n\n# Tools\n\nYou have access to the following functions:\n\n<tools>{tool_text}"
"\n</tools>\n\nIf you choose to call a function ONLY reply in the following format with NO suffix:\n\n"
"<tool_call>\n<function=example_function_name>\n<parameter=example_parameter_1>\nvalue_1\n</parameter>\n"
"<parameter=example_parameter_2>\nThis is the value for the second parameter\nthat can span\nmultiple lines\n"
"</parameter>\n</function>\n</tool_call>\n\n<IMPORTANT>\nReminder:\n"
"- Function calls MUST follow the specified format: "
"an inner <function=...></function> block must be nested within <tool_call></tool_call> XML tags\n"
"- Required parameters MUST be specified\n"
"- You may provide optional reasoning for your function call in natural language "
"BEFORE the function call, but NOT after\n"
"- If there is no function call available, answer the question like normal with your current knowledge "
"and do not tell the user about function calls\n</IMPORTANT>"
)
SEED_TOOL_PROMPT = (
"system\nYou are Doubao, a helpful AI assistant. You may call one or more functions to assist with the user query."
"Tool List:\nYou are authorized to use the following tools (described in JSON Schema format). Before performing "
@@ -453,6 +468,57 @@ class QwenToolUtils(ToolUtils):
return results
class Qwen35ToolUtils(ToolUtils):
r"""Qwen 3.5 tool using template."""
@override
@staticmethod
def tool_formatter(tools: list[dict[str, Any]]) -> str:
tool_text = ""
for tool in tools:
tool = tool.get("function", tool) if tool.get("type") == "function" else tool
tool_text += "\n" + json.dumps(tool, ensure_ascii=False)
return QWEN35_TOOL_PROMPT.format(tool_text=tool_text)
@override
@staticmethod
def function_formatter(functions: list["FunctionCall"]) -> str:
function_texts = []
for func in functions:
name, arguments = func.name, json.loads(func.arguments)
prompt = f"<tool_call>\n<function={name}>"
for key, value in arguments.items():
prompt += f"\n<parameter={key}>"
if not isinstance(value, str):
value = json.dumps(value, ensure_ascii=False)
prompt += f"\n{value}\n</parameter>"
prompt += "\n</function>\n</tool_call>"
function_texts.append(prompt)
return "\n".join(function_texts)
@override
@staticmethod
def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
results = []
regex = re.compile(r"<tool_call>\s*<function=\s*([^\s<>]+)\s*(.*?)\s*</function>\s*</tool_call>", re.DOTALL)
for func_name, params_block in re.findall(regex, content):
args_dict = {}
param_pattern = re.compile(r"<parameter=(.*?)>(.*?)</parameter>", re.DOTALL)
for key, raw_value in re.findall(param_pattern, params_block.strip()):
value = raw_value.strip()
try:
parsed_value = json.loads(value)
except json.JSONDecodeError:
parsed_value = raw_value.strip()
args_dict[key] = parsed_value
results.append(FunctionCall(func_name.strip(), json.dumps(args_dict, ensure_ascii=False)))
return results if results else content
class GLM4MOEToolUtils(QwenToolUtils):
r"""GLM-4-MOE tool using template."""
@@ -662,6 +728,7 @@ TOOLS = {
"minimax2": MiniMaxM2ToolUtils(),
"mistral": MistralToolUtils(),
"qwen": QwenToolUtils(),
"qwen3_5": Qwen35ToolUtils(),
"glm4_moe": GLM4MOEToolUtils(),
"seed_oss": SeedToolUtils(),
"ling": LingToolUtils(),

View File

@@ -65,6 +65,7 @@ MCA_SUPPORTED_MODELS = {
"qwen2_vl",
"qwen2_5_vl",
"qwen3_vl",
"qwen3_vl_moe",
"qwen3",
"qwen3_moe",
"qwen3_next",
@@ -2809,6 +2810,29 @@ register_model_group(
)
register_model_group(
models={
"Qwen3.5-27B": {
DownloadSource.DEFAULT: "Qwen/Qwen3.5-27B",
DownloadSource.MODELSCOPE: "Qwen/Qwen3.5-27B",
},
"Qwen3.5-35B-A3B": {
DownloadSource.DEFAULT: "Qwen/Qwen3.5-35B-A3B",
DownloadSource.MODELSCOPE: "Qwen/Qwen3.5-35B-A3B",
},
"Qwen3.5-122B-A10B": {
DownloadSource.DEFAULT: "Qwen/Qwen3.5-122B-A10B",
DownloadSource.MODELSCOPE: "Qwen/Qwen3.5-122B-A10B",
},
"Qwen3.5-397B-A17B": {
DownloadSource.DEFAULT: "Qwen/Qwen3.5-397B-A17B",
DownloadSource.MODELSCOPE: "Qwen/Qwen3.5-397B-A17B",
},
},
template="qwen3_5",
)
register_model_group(
models={
"Qwen2-Audio-7B": {
@@ -3450,3 +3474,35 @@ register_model_group(
},
template="zephyr",
)
register_model_group(
models={
"Aeva-Flash-Chat": {
DownloadSource.DEFAULT: "louzongzhi/Aeva-Flash",
DownloadSource.MODELSCOPE: "louzongktsi/Aeva-Flash",
DownloadSource.OPENMIND: "louzongzhi/Aeva-Flash",
},
"Aeva-Air-Chat": {
DownloadSource.DEFAULT: "louzongzhi/Aeva-Air",
DownloadSource.MODELSCOPE: "louzongktsi/Aeva-Air",
DownloadSource.OPENMIND: "louzongzhi/Aeva-Air",
},
"Aeva-Chat": {
DownloadSource.DEFAULT: "louzongzhi/Aeva",
DownloadSource.MODELSCOPE: "louzongktsi/Aeva",
DownloadSource.OPENMIND: "louzongzhi/Aeva",
},
"Aeva-Pro-Chat": {
DownloadSource.DEFAULT: "louzongzhi/Aeva-Pro",
DownloadSource.MODELSCOPE: "louzongktsi/Aeva-Pro",
DownloadSource.OPENMIND: "louzongzhi/Aeva-Pro",
},
"Aeva-Max-Chat": {
DownloadSource.DEFAULT: "louzongzhi/Aeva-Max",
DownloadSource.MODELSCOPE: "louzongktsi/Aeva-Max",
DownloadSource.OPENMIND: "louzongzhi/Aeva-Max",
},
},
template="aeva",
)

View File

@@ -94,7 +94,7 @@ def check_version(requirement: str, mandatory: bool = False) -> None:
def check_dependencies() -> None:
r"""Check the version of the required packages."""
check_version("transformers>=4.51.0,<=5.0.0")
check_version("transformers>=4.51.0,<=5.2.0")
check_version("datasets>=2.16.0,<=4.0.0")
check_version("accelerate>=1.3.0,<=1.11.0")
check_version("peft>=0.18.0,<=0.18.1")

View File

@@ -142,6 +142,11 @@ def add_z3_leaf_module(model: "PreTrainedModel") -> None:
_set_z3_leaf_modules(model, [Qwen3OmniMoeThinkerTextSparseMoeBlock])
if model_type == "qwen3_next":
from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextSparseMoeBlock
_set_z3_leaf_modules(model, [Qwen3NextSparseMoeBlock])
def configure_moe(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
if not is_trainable or not model_args.moe_aux_loss_coef:

View File

@@ -82,7 +82,33 @@ def _check_model_support(model_args: "ModelArguments"):
model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
)
if config.model_type not in MCA_SUPPORTED_MODELS:
raise ValueError(f"Model {config.model_type} is not supported by MCA.")
raise ValueError(
f"Model {config.model_type} is not supported by mcore_adapter."
"You can try to upgrade mcore_adapter to the latest version for more supported models."
)
def _freeze_model_parameters(model: Any, finetuning_args: "FinetuningArguments"):
"""Freeze model parameters for qwen_vl series models based on finetuning arguments."""
if getattr(model.config, "hf_model_type", None) not in ["qwen2_vl", "qwen2_5_vl", "qwen3_vl", "qwen3_vl_moe"]:
return
params_to_freeze = []
if finetuning_args.freeze_vision_tower:
params_to_freeze.extend(["vision_model.blocks", "vision_model.patch_embed"])
if getattr(model.config, "hf_model_type", None) in ["qwen3_vl", "qwen3_vl_moe"]:
params_to_freeze.extend(["vision_model.pos_embed"])
if finetuning_args.freeze_multi_modal_projector:
params_to_freeze.extend(["multi_modal_projector"])
if finetuning_args.freeze_language_model:
params_to_freeze.extend(["embedding", "decoder", "output_layer"])
if params_to_freeze:
for name, p in model.named_parameters():
if any(name.startswith(k) for k in params_to_freeze):
p.requires_grad_(False)
def run_pt(
@@ -161,22 +187,8 @@ def run_sft(
_check_model_support(model_args)
model = AutoModel.from_pretrained(model_args.model_name_or_path, training_args)
# optional freezing for qwen2_vl, qwen2_5_vl
if getattr(model.config, "hf_model_type", None) in ["qwen2_vl", "qwen2_5_vl", "qwen3_vl"]:
params_to_freeze = []
if finetuning_args.freeze_vision_tower:
params_to_freeze.extend(["vision_model.blocks", "vision_model.patch_embed"])
if finetuning_args.freeze_multi_modal_projector:
params_to_freeze.extend(["multi_modal_projector"])
if finetuning_args.freeze_language_model:
params_to_freeze.extend(["embedding", "decoder", "output_layer"])
if params_to_freeze:
for name, p in model.named_parameters():
if any(name.startswith(k) for k in params_to_freeze):
p.requires_grad_(False)
# optional freezing for qwen_vl series
_freeze_model_parameters(model, finetuning_args)
pad_to_max = training_args.expert_model_parallel_size is not None and training_args.expert_model_parallel_size > 1
data_collator = SFTDataCollatorWith4DAttentionMask(
@@ -229,6 +241,8 @@ def run_dpo(
_check_model_support(model_args)
model = AutoModel.from_pretrained(model_args.model_name_or_path, training_args)
_freeze_model_parameters(model, finetuning_args)
if finetuning_args.use_ref_model:
ref_config = AutoConfig.from_pretrained(model_args.model_name_or_path, training_args)
ref_model = AutoModel.from_config(ref_config)

View File

@@ -24,7 +24,7 @@ from ..data import get_template_and_fix_tokenizer
from ..extras import logging
from ..extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
from ..extras.misc import find_available_port, get_device_name, get_torch_device, infer_optim_dtype
from ..extras.packages import is_mcore_adapter_available, is_ray_available
from ..extras.packages import is_mcore_adapter_available, is_ray_available, is_transformers_version_greater_than
from ..hparams import RayArguments, get_infer_args, get_ray_args, get_train_args, read_args
from ..model import load_model, load_tokenizer
from .callbacks import LogCallback, PissaConvertCallback, ReporterCallback
@@ -160,17 +160,28 @@ def export_model(args: Optional[dict[str, Any]] = None) -> None:
model = model.to(output_dtype)
logger.info_rank0(f"Convert model dtype to: {output_dtype}.")
model.save_pretrained(
save_directory=model_args.export_dir,
max_shard_size=f"{model_args.export_size}GB",
safe_serialization=(not model_args.export_legacy_format),
)
# Prepare save arguments (safe_serialization removed in transformers v5.0.0)
save_kwargs = {
"save_directory": model_args.export_dir,
"max_shard_size": f"{model_args.export_size}GB",
}
if not is_transformers_version_greater_than("5.0.0"):
save_kwargs["safe_serialization"] = not model_args.export_legacy_format
model.save_pretrained(**save_kwargs)
if model_args.export_hub_model_id is not None:
# Prepare push arguments (safe_serialization removed in transformers v5.0.0)
push_kwargs = {
"max_shard_size": f"{model_args.export_size}GB",
}
if not is_transformers_version_greater_than("5.0.0"):
push_kwargs["safe_serialization"] = not model_args.export_legacy_format
model.push_to_hub(
model_args.export_hub_model_id,
token=model_args.hf_hub_token,
max_shard_size=f"{model_args.export_size}GB",
safe_serialization=(not model_args.export_legacy_format),
**push_kwargs,
)
if finetuning_args.stage == "rm":

View File

@@ -76,19 +76,28 @@ class BaseTrainer:
if self.args.enable_activation_checkpointing:
self.model.gradient_checkpointing_enable({"use_reentrant": False})
if self.args.dist_config is not None:
shard_need_optimizer = self.args.dist_config.name == "deepspeed"
else:
shard_need_optimizer = False
self._accelerate_engine = None
dist_name = self.args.dist_config.name if self.args.dist_config is not None else None
if shard_need_optimizer:
if dist_name == "deepspeed":
from ..plugins.trainer_plugins.distributed.hub import DistributedPlugin
self._deepspeed_engine = DistributedPlugin("deepspeed")(
self.model,
self.args.dist_config,
num_micro_batch=self.train_batch_generator.num_micro_batch,
micro_batch_size=self.args.micro_batch_size,
)
self._init_optimizer()
self._shard_model()
self._init_lr_scheduler()
self.model, self.optimizer, self.lr_scheduler = self._deepspeed_engine.prepare(
self.model, self.optimizer, self.lr_scheduler
)
else:
# fsdp2 / DDP / no dist
self._shard_model()
self._init_optimizer()
self._init_lr_scheduler()
self._init_lr_scheduler()
def _create_batch_generator(self) -> None:
self.train_batch_generator = BatchGenerator(
@@ -171,25 +180,35 @@ class BaseTrainer:
step_loss = 0
step_valid_tokens = compute_valid_tokens(micro_batches)
step_valid_tokens = DistributedInterface().all_reduce(step_valid_tokens, op=ReduceOp.SUM)
for micro_batch in micro_batches:
num_micro = len(micro_batches)
for i, micro_batch in enumerate(micro_batches):
loss = self.compute_loss(micro_batch)
mini_step_valid_tokens = compute_valid_tokens([micro_batch])
# fsdp uses mean reduction so we need to scale the loss by dp_size
loss = loss * mini_step_valid_tokens * self.dp_size / (step_valid_tokens + 1e-6)
loss.backward()
if self._deepspeed_engine is not None:
# deepspeed: set sync_gradients so engine.step() only fires on last micro-batch
self._deepspeed_engine.accelerator.sync_gradients = i == num_micro - 1
self._deepspeed_engine.backward(loss)
else:
loss.backward()
step_loss += loss.item()
grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.max_grad_norm).item()
# isfinite(): argument 'input' (position 1) must be Tensor, not float
if not torch.isfinite(torch.tensor(grad_norm)): # type: ignore # pyright: ignore [reportUnknownReturnType]
logger.warning_rank0(f"Gradient norm is not finite: {grad_norm}")
if self._deepspeed_engine is not None:
# deepspeed: engine.step() already ran inside backward at the sync boundary
grad_norm = self._deepspeed_engine.get_grad_norm()
else:
self.optimizer.step()
grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.max_grad_norm).item()
self.lr_scheduler.step()
self.optimizer.zero_grad()
# isfinite(): argument 'input' (position 1) must be Tensor, not float
if not torch.isfinite(torch.tensor(grad_norm)): # type: ignore # pyright: ignore [reportUnknownReturnType]
logger.warning_rank0(f"Gradient norm is not finite: {grad_norm}")
else:
self.optimizer.step()
self.lr_scheduler.step()
self.optimizer.zero_grad()
step_loss, grad_norm = DistributedInterface().all_reduce([step_loss, grad_norm])
DistributedInterface().sync()
@@ -203,17 +222,14 @@ class BaseTrainer:
def save_model(self) -> None:
"""Save the model."""
model_to_save = self.model.module if hasattr(self.model, "module") else self.model
state_dict = None
if self.args.dist_config is not None and self.args.dist_config.name == "fsdp2":
from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict
if self.args.dist_config is not None and self.args.dist_config.name in ("deepspeed", "fsdp2"):
from ..plugins.trainer_plugins.distributed.hub import DistributedPlugin
options = StateDictOptions(full_state_dict=True, cpu_offload=True)
state_dict = get_model_state_dict(self.model, options=options)
if DistributedInterface().get_rank() != 0:
return
model_to_save.save_pretrained(self.args.output_dir, state_dict=state_dict)
self.renderer.processor.save_pretrained(self.args.output_dir)
logger.info_rank0(f"Model saved to {self.args.output_dir}")
DistributedPlugin(self.args.dist_config.name).save_model(
self.model, self.args.output_dir, self.renderer.processor
)
else:
model_to_save = self.model.module if hasattr(self.model, "module") else self.model
model_to_save.save_pretrained(self.args.output_dir, max_shard_size="4GB")
self.renderer.processor.save_pretrained(self.args.output_dir, max_shard_size="4GB")
logger.info_rank0(f"Model saved to {self.args.output_dir}")

View File

@@ -90,6 +90,26 @@ class ModelEngine:
Transformers can choose the proper model init context.
https://github.com/huggingface/transformers/blob/v5.0.0rc0/src/transformers/modeling_utils.py#L3538
"""
if self.args.init_config is not None:
from ..plugins.model_plugins.initialization import InitPlugin
init_device = InitPlugin(self.args.init_config.name)()
else:
init_device = DistributedInterface().current_device
init_kwargs = {"device_map": init_device}
if self.args.quant_config is not None:
from ..plugins.model_plugins.quantization import QuantizationPlugin
init_kwargs = QuantizationPlugin(self.args.quant_config.name)(
init_kwargs=init_kwargs,
config=self.model_config,
tokenizer=self.processor,
model_args=self.args,
is_trainable=self.is_train,
)
if self.args.model_class == ModelClass.LLM:
from transformers import AutoModelForCausalLM, AutoModelForImageTextToText
@@ -107,14 +127,8 @@ class ModelEngine:
AutoClass = AutoModel
if self.args.init_config is not None:
from ..plugins.model_plugins.initialization import InitPlugin
init_device = InitPlugin(self.args.init_config.name)()
else:
init_device = DistributedInterface().current_device
if init_device.type == DeviceType.META:
assert self.args.quant_config is None, "Quantization is not supported with meta device."
with init_empty_weights():
model = AutoClass.from_config(self.model_config)
else:
@@ -122,8 +136,8 @@ class ModelEngine:
self.args.model,
config=self.model_config,
dtype="auto",
device_map=init_device,
trust_remote_code=self.args.trust_remote_code,
**init_kwargs,
)
if self.args.peft_config is None:

View File

@@ -0,0 +1,122 @@
# Copyright 2025 HuggingFace Inc., the KVCache.AI team, Approaching AI, and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Any
import torch
from transformers import BitsAndBytesConfig
from ...accelerator.helper import get_current_device
from ...config.model_args import ModelArguments
from ...utils import logging
from ...utils.packages import check_version
from ...utils.plugin import BasePlugin
if TYPE_CHECKING:
from transformers import PretrainedConfig, PreTrainedTokenizer
logger = logging.get_logger(__name__)
class QuantizationPlugin(BasePlugin):
r"""Plugin for model quantization."""
def __call__(
self,
init_kwargs: dict[str, Any] = None,
config: "PretrainedConfig" = None,
tokenizer: "PreTrainedTokenizer" = None,
model_args: "ModelArguments" = None,
is_trainable: bool = False,
) -> dict[str, Any]:
return super().__call__(
init_kwargs, config=config, tokenizer=tokenizer, model_args=model_args, is_trainable=is_trainable
)
@QuantizationPlugin("auto").register()
def quantization_auto(
init_kwargs: dict[str, Any],
**kwargs,
) -> dict[str, Any]:
"""Automatic quantization selection, only support bnb currently.
Args:
init_kwargs (dict[str, Any]): The kwargs for model initialization.
**kwargs: Keyword arguments containing the model.
Returns:
dict[str, Any]: The updated kwargs for model initialization.
"""
model_args: ModelArguments = kwargs.get("model_args", None)
quant_config = model_args.quant_config
quantization_bit = quant_config.get("quantization_bit", None)
if quantization_bit is not None:
logger.info_rank0(f"Loading {quantization_bit}-bit quantized model.")
if quantization_bit in [8, 4]:
return quantization_with_bnb(init_kwargs, **kwargs)
else:
raise ValueError(f"Unsupported quantization bit: {quantization_bit} for auto quantization.")
logger.warning_rank0("No quantization method applied.")
return init_kwargs
@QuantizationPlugin("bnb").register()
def quantization_with_bnb(
init_kwargs: dict[str, Any],
model_args: "ModelArguments" = None,
**kwargs,
) -> dict[str, Any]:
r"""Quantization with BNB."""
logger.info_rank0("Using Bitsandbytes quantization.")
quantization_bit = model_args.quant_config.get("quantization_bit", None)
if quantization_bit is None:
logger.warning_rank0("quantization_bit is not specified, default to 8-bit quantization.")
quantization_bit = 4
assert quantization_bit in [8, 4], "Bitsandbytes only accepts 4-bit or 8-bit quantization."
if quantization_bit == 8:
check_version("bitsandbytes>=0.37.0", mandatory=True)
init_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
elif quantization_bit == 4:
check_version("bitsandbytes>=0.39.0", mandatory=True)
init_kwargs["quantization_config"] = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=model_args.quant_config.get("compute_dtype", torch.float16),
bnb_4bit_use_double_quant=model_args.quant_config.get("double_quantization", True),
bnb_4bit_quant_type=model_args.quant_config.get("quantization_type", "nf4"),
bnb_4bit_quant_storage=model_args.quant_config.get(
"compute_dtype", torch.float16
), # crucial for fsdp+qlora
)
else:
raise ValueError("Bitsandbytes only accepts 4-bit or 8-bit quantization.")
# TODO: improve deepspeed zero3 and fsdp detection.
if kwargs.get("is_trainable", False):
logger.info_rank0("Detected inference mode, setting device_map for bitsandbytes quantization.")
init_kwargs["device_map"] = {"": get_current_device()} # change auto device map for inference
else:
logger.info_rank0("Detected training mode, skip setting device_map for bitsandbytes quantization.")
if model_args.quant_config.get("quantization_bit") != 4:
raise ValueError("Only 4-bit quantized model can use fsdp+qlora or auto device map.")
check_version("bitsandbytes>=0.43.0", mandatory=True)
logger.info_rank0(f"Quantizing model to {model_args.quant_config.get('quantization_bit')} bit with bitsandbytes.")
return init_kwargs

View File

@@ -0,0 +1,129 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""DeepSpeed integration via accelerate's built-in capabilities.
Instead of manually calling deepspeed.initialize() and syncing config,
this module leverages accelerate's Accelerator + DeepSpeedPlugin to handle
initialization, backward, gradient accumulation, and model saving.
"""
from typing import Any, Optional
import torch
from accelerate import Accelerator
from accelerate.utils import DeepSpeedPlugin
from ....utils.logging import get_logger
from ....utils.types import HFModel, Processor
logger = get_logger(__name__)
class DeepSpeedEngine:
"""DeepSpeed integration using accelerate's built-in capabilities.
This replaces the manual DeepSpeedConfigHelper / DeepSpeedEngine approach
with accelerate's Accelerator + DeepSpeedPlugin, which handles:
- Config syncing (auto values, batch size, lr, etc.)
- deepspeed.initialize() call
- Optimizer / LR scheduler wrapping
- Backward + gradient accumulation boundary
- ZeRO-3 parameter gathering for saving
"""
def __init__(self, dist_config: dict[str, Any], num_micro_batch: int = 1, micro_batch_size: int = 1):
config_file = dist_config.get("config_file")
if not config_file:
raise ValueError("DeepSpeed config_file is required in dist_config")
ds_plugin = DeepSpeedPlugin(hf_ds_config=config_file)
self.accelerator = Accelerator(
deepspeed_plugin=ds_plugin,
gradient_accumulation_steps=num_micro_batch,
)
# Resolve "auto" for train_micro_batch_size_per_gpu so that
# accelerate.prepare() does not require a DataLoader to infer it.
ds_config = self.accelerator.state.deepspeed_plugin.deepspeed_config
if ds_config.get("train_micro_batch_size_per_gpu") in (None, "auto"):
ds_config["train_micro_batch_size_per_gpu"] = micro_batch_size
logger.info_rank0(f"DeepSpeedEngine initialized with config: {config_file}")
def shard_model(self, model: HFModel) -> "DeepSpeedEngine":
"""No-op shard — actual model wrapping happens in prepare().
Returns self so the caller gets the engine instance via the hub interface.
"""
return self
def prepare(
self,
model: HFModel,
optimizer: torch.optim.Optimizer,
lr_scheduler: Optional[Any] = None,
) -> tuple[HFModel, torch.optim.Optimizer, Any]:
"""Prepare model, optimizer, and lr_scheduler using accelerate.
Internally calls deepspeed.initialize() and wraps the returned objects.
"""
if lr_scheduler is not None:
model, optimizer, lr_scheduler = self.accelerator.prepare(model, optimizer, lr_scheduler)
else:
model, optimizer = self.accelerator.prepare(model, optimizer)
model._accelerator = self.accelerator # type: ignore[assignment]
logger.info_rank0("Model, optimizer, and lr_scheduler prepared via accelerate")
return model, optimizer, lr_scheduler
def backward(self, loss: torch.Tensor) -> None:
"""Backward pass using accelerate.
Delegates to DeepSpeedEngineWrapper.backward() which respects
sync_gradients to control gradient accumulation boundaries.
When sync_gradients=True: engine.backward(loss) + engine.step()
When sync_gradients=False: engine.backward(loss) only
"""
self.accelerator.backward(loss)
def get_grad_norm(self) -> float:
"""Get the global gradient norm from the DeepSpeed engine."""
engine_wrapper = getattr(self.accelerator, "deepspeed_engine_wrapped", None)
if engine_wrapper is not None:
return engine_wrapper.engine.get_global_grad_norm() or 0.0
return 0.0
def save_model(model: HFModel, output_dir: str, processor: Processor) -> None:
"""Save model using accelerate's built-in ZeRO-aware utilities.
Expects model._accelerator to be set during prepare().
Handles ZeRO-3 parameter gathering automatically via
accelerator.get_state_dict().
"""
accelerator: Accelerator = model._accelerator # type: ignore[union-attr]
unwrapped_model = accelerator.unwrap_model(model)
state_dict = accelerator.get_state_dict(model)
if accelerator.is_main_process:
unwrapped_model.save_pretrained(output_dir, state_dict=state_dict, max_shard_size="4GB")
processor.save_pretrained(output_dir, max_shard_size="4GB")
accelerator.wait_for_everyone()
logger.info_rank0(f"Model saved to {output_dir}")

View File

@@ -17,24 +17,24 @@ import os
import torch
import torch.nn as nn
from peft.tuners.lora import LoraLayer
from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict, set_model_state_dict
from torch.distributed.fsdp import (
CPUOffloadPolicy,
MixedPrecisionPolicy,
fully_shard,
)
from transformers import PreTrainedModel
from peft.tuners.lora import LoraLayer
from ....accelerator.helper import get_current_accelerator
from ....accelerator.interface import DistributedInterface
from ....utils.logging import get_logger
from ....utils.types import HFModel, Processor
logger = get_logger(__name__)
def get_transformer_layer_cls(model: PreTrainedModel) -> type[nn.Module] | None:
def get_transformer_layer_cls(model: HFModel) -> type[nn.Module] | None:
no_split_modules = getattr(model, "_no_split_modules", None)
if no_split_modules:
if isinstance(no_split_modules, (list, tuple)):
@@ -50,6 +50,20 @@ def get_transformer_layer_cls(model: PreTrainedModel) -> type[nn.Module] | None:
return None
def save_model(model: HFModel, output_dir: str, processor: Processor) -> None:
if DistributedInterface().get_rank() == 0:
logger.info("Gathering state dict for saving...")
options = StateDictOptions(full_state_dict=True, cpu_offload=True)
state_dict = get_model_state_dict(model, options=options)
if DistributedInterface().get_rank() == 0:
model_to_save = model.module if hasattr(model, "module") else model
model_to_save.save_pretrained(output_dir, state_dict=state_dict, max_shard_size="4GB")
processor.save_pretrained(output_dir, max_shard_size="4GB")
logger.info(f"Model saved to {output_dir}")
class FSDP2Engine:
def __init__(self, dist_config: dict):
self.dist_interface = DistributedInterface()
@@ -94,12 +108,11 @@ class FSDP2Engine:
reduce_dtype=reduce_dtype,
cast_forward_inputs=True,
)
def is_lora_module_wrap(self, model) -> bool:
return any(isinstance(module, LoraLayer) for module in model.modules())
def prepare_model(self, model: PreTrainedModel) -> PreTrainedModel:
def prepare_model(self, model: HFModel) -> HFModel:
if self.fsdp_mesh is None:
logger.warning("No FSDP Mesh available, skipping FSDP wrapping.")
return model
@@ -115,11 +128,10 @@ class FSDP2Engine:
else:
logger.info(f"Applying per-layer FSDP to {layer_cls.__name__}")
transformer_layer_cls_to_wrap = {layer_cls}
if self.is_lora_module_wrap(model):
lora_modules = []
for module in model.modules():
if len(list(module.children())) != 0:
continue
if any(param.requires_grad for param in module.parameters(recurse=False)):
@@ -134,7 +146,7 @@ class FSDP2Engine:
offload_policy=CPUOffloadPolicy(pin_memory=self.pin_memory) if self.offload_params else None,
)
logger.info(f"Applying FSDP wrap for LoRA layer separately.")
logger.info("Applying FSDP wrap for LoRA layer separately.")
for name, module in model.named_modules():
should_wrap = False
@@ -179,8 +191,9 @@ class FSDP2Engine:
)
return model
@torch.no_grad()
def materialize_and_load(self, model: PreTrainedModel, hf_model_path: str, dcp_path: str = None):
def materialize_and_load(self, model: HFModel, hf_model_path: str, dcp_path: str = None):
if self.rank == 0:
logger.info("Materializing sharded model params...")
@@ -200,7 +213,7 @@ class FSDP2Engine:
return model
def shard_model(self, model: PreTrainedModel) -> PreTrainedModel:
def shard_model(self, model: HFModel) -> HFModel:
if model.device.type == "meta":
model = self.prepare_model(model)
model = self.materialize_and_load(model, hf_model_path=model.config.name_or_path, dcp_path=self.dcp_path)
@@ -208,7 +221,7 @@ class FSDP2Engine:
model = self.prepare_model(model)
return model
def _load_from_dcp(self, model: PreTrainedModel, dcp_path: str):
def _load_from_dcp(self, model: HFModel, dcp_path: str):
import torch.distributed.checkpoint as dcp
try:
@@ -227,7 +240,7 @@ class FSDP2Engine:
logger.error(f"Failed to load from DCP: {e}")
raise e
def _load_weights_from_hf_checkpoint(self, model, hf_model_path):
def _load_weights_from_hf_checkpoint(self, model: HFModel, hf_model_path: str):
import glob
import json

View File

@@ -12,9 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import TYPE_CHECKING
from ....config.arg_utils import PluginConfig
from ....utils.plugin import BasePlugin
from ....utils.types import HFModel
if TYPE_CHECKING:
from ....utils.types import HFModel, Processor
class DistributedPlugin(BasePlugin):
@@ -23,12 +30,32 @@ class DistributedPlugin(BasePlugin):
@DistributedPlugin("fsdp2").register()
def shard_model_fsdp2(model: HFModel, dist_config: PluginConfig) -> HFModel:
def shard_model_fsdp2(model: HFModel, dist_config: PluginConfig, **kwargs) -> HFModel:
from .fsdp2 import FSDP2Engine
return FSDP2Engine(dist_config).shard_model(model)
@DistributedPlugin("fsdp2").register("save_model")
def save_model_fsdp2(model: HFModel, output_dir: str, processor: Processor) -> None:
from .fsdp2 import save_model
return save_model(model, output_dir, processor)
@DistributedPlugin("deepspeed").register()
def shard_model_deepspeed(model: HFModel, dist_config: PluginConfig) -> HFModel:
return model
def shard_model_deepspeed(model: HFModel, dist_config: PluginConfig, **kwargs) -> HFModel:
from .deepspeed import DeepSpeedEngine
return DeepSpeedEngine(
dist_config,
num_micro_batch=kwargs.get("num_micro_batch"),
micro_batch_size=kwargs.get("micro_batch_size"),
).shard_model(model)
@DistributedPlugin("deepspeed").register("save_model")
def save_model_deepspeed(model: HFModel, output_dir: str, processor: Processor) -> None:
from .deepspeed import save_model
return save_model(model, output_dir, processor)

View File

@@ -21,6 +21,13 @@ from functools import lru_cache
from typing import TYPE_CHECKING
from packaging import version
from transformers.utils.versions import require_version
from . import logging
from .env import is_env_enabled
logger = logging.get_logger(__name__)
if TYPE_CHECKING:
@@ -41,3 +48,22 @@ def _get_package_version(name: str) -> "Version":
@lru_cache
def is_transformers_version_greater_than(content: str):
return _get_package_version("transformers") >= version.parse(content)
def check_version(requirement: str, mandatory: bool = False) -> None:
r"""Optionally check the package version."""
if is_env_enabled("DISABLE_VERSION_CHECK") and not mandatory:
logger.warning_rank0_once("Version checking has been disabled, may lead to unexpected behaviors.")
return
if "gptqmodel" in requirement or "autoawq" in requirement:
pip_command = f"pip install {requirement} --no-build-isolation"
else:
pip_command = f"pip install {requirement}"
if mandatory:
hint = f"To fix: run `{pip_command}`."
else:
hint = f"To fix: run `{pip_command}` or set `DISABLE_VERSION_CHECK=1` to skip this check."
require_version(requirement, hint)

View File

@@ -166,3 +166,33 @@ def _manage_distributed_env(request: FixtureRequest, monkeypatch: MonkeyPatch) -
def fix_valuehead_cpu_loading():
"""Fix valuehead model loading."""
patch_valuehead_model()
@pytest.fixture(scope="session", autouse=True)
def bypass_mistral_regex_check():
"""Disable Mistral regex network check.
Monkey-patch TokenizersBackend._patch_mistral_regex into a no-op.
"""
try:
from transformers.tokenization_utils_fast import TokenizersBackend
except ImportError:
# Very old transformers, nothing to patch
yield
return
if not hasattr(TokenizersBackend, "_patch_mistral_regex"):
# Method does not exist in this version
yield
return
# Backup original method
original = TokenizersBackend._patch_mistral_regex
# Replace with no-op
TokenizersBackend._patch_mistral_regex = lambda cls, tokenizer, *args, **kwargs: tokenizer
yield
# Restore original method
TokenizersBackend._patch_mistral_regex = original

View File

@@ -22,6 +22,7 @@ from transformers import AutoConfig, AutoModelForImageTextToText
from llamafactory.data import get_template_and_fix_tokenizer
from llamafactory.data.collator import MultiModalDataCollatorForSeq2Seq, prepare_4d_attention_mask
from llamafactory.extras.constants import IGNORE_INDEX
from llamafactory.extras.packages import is_transformers_version_greater_than
from llamafactory.hparams import get_infer_args
from llamafactory.model import load_tokenizer
@@ -116,14 +117,16 @@ def test_multimodal_collator():
"labels": [
[0, 1, 2, 3, q, q, q, q, q, q, q, q],
],
"position_ids": [
[[0, 1, 2, 3, 1, 1, 1, 1, 1, 1, 1, 1]],
[[0, 1, 2, 3, 1, 1, 1, 1, 1, 1, 1, 1]],
[[0, 1, 2, 3, 1, 1, 1, 1, 1, 1, 1, 1]],
],
"rope_deltas": [[-8]],
"position_ids": [[[0, 1, 2, 3, 0, 0, 0, 0, 0, 0, 0, 0]]] * 3,
"rope_deltas": [[0]],
**tokenizer_module["processor"].image_processor(fake_image),
}
if not is_transformers_version_greater_than("5.0.0"):
# adapt position_ids and rope_deltas for transformers < 5.0.0
# https://github.com/huggingface/transformers/pull/43972
expected_input["position_ids"] = [[[0, 1, 2, 3, 1, 1, 1, 1, 1, 1, 1, 1]]] * 3
expected_input["rope_deltas"] = [[-8]]
assert batch_input.keys() == expected_input.keys()
for k in batch_input.keys():
assert batch_input[k].eq(torch.tensor(expected_input[k])).all()

View File

@@ -1,2 +1,2 @@
# change if test fails or cache is outdated
0.9.5.106
0.9.5.107

View File

@@ -172,3 +172,33 @@ def _manage_distributed_env(request: FixtureRequest, monkeypatch: MonkeyPatch) -
monkeypatch.setattr(torch.cuda, "device_count", lambda: 1)
elif CURRENT_DEVICE == "npu":
monkeypatch.setattr(torch.npu, "device_count", lambda: 1)
@pytest.fixture(scope="session", autouse=True)
def bypass_mistral_regex_check():
"""Disable Mistral regex network check.
Monkey-patch TokenizersBackend._patch_mistral_regex into a no-op.
"""
try:
from transformers.tokenization_utils_fast import TokenizersBackend
except ImportError:
# Very old transformers, nothing to patch
yield
return
if not hasattr(TokenizersBackend, "_patch_mistral_regex"):
# Method does not exist in this version
yield
return
# Backup original method
original = TokenizersBackend._patch_mistral_regex
# Replace with no-op
TokenizersBackend._patch_mistral_regex = lambda cls, tokenizer, *args, **kwargs: tokenizer
yield
# Restore original method
TokenizersBackend._patch_mistral_regex = original

View File

@@ -0,0 +1,51 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
from llamafactory.v1.config.model_args import ModelArguments
from llamafactory.v1.core.model_engine import ModelEngine
bitsandbytes = pytest.importorskip("bitsandbytes")
def check_quantization_status(model):
quantized_info = {"bnb": []}
for name, module in model.named_modules():
# check BitsAndBytes quantization
if isinstance(module, bitsandbytes.nn.modules.Linear8bitLt) or isinstance(
module, bitsandbytes.nn.modules.Linear4bit
):
quantized_info["bnb"].append(name)
return quantized_info
@pytest.mark.runs_on(["cuda"])
@pytest.mark.parametrize("name, quantization_bit", [("bnb", 4), ("auto", 4)])
def test_quantization_plugin(name, quantization_bit):
model_args = ModelArguments(
model="llamafactory/tiny-random-qwen3",
quant_config={
"name": name,
"quantization_bit": quantization_bit,
},
)
model_engine = ModelEngine(model_args=model_args)
quantized_info = check_quantization_status(model_engine.model)
print(f"Quantized weights for method {name} with {quantization_bit} bit: {quantized_info}")
assert any(v for v in quantized_info.values()), "model is not quantized properly."