diff --git a/pyproject.toml b/pyproject.toml
index 39dd0ad31..d14019011 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -40,7 +40,7 @@ dependencies = [
"torch>=2.4.0",
"torchvision>=0.19.0",
"torchaudio>=2.4.0",
- "transformers>=4.51.0,<=5.0.0,!=4.52.0,!=4.57.0",
+ "transformers>=4.51.0,<=5.2.0,!=4.52.0,!=4.57.0",
"datasets>=2.16.0,<=4.0.0",
"accelerate>=1.3.0,<=1.11.0",
"peft>=0.18.0,<=0.18.1",
diff --git a/src/llamafactory/data/collator.py b/src/llamafactory/data/collator.py
index d1e562d93..31809eb24 100644
--- a/src/llamafactory/data/collator.py
+++ b/src/llamafactory/data/collator.py
@@ -15,6 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import inspect
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Literal, Optional
@@ -189,6 +190,16 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
"video_grid_thw": mm_inputs.get("video_grid_thw"),
"attention_mask": (features["attention_mask"] >= 1).float(),
}
+ if "mm_token_type_ids" in inspect.signature(self.get_rope_func).parameters:
+ image_token_id = getattr(self.model.config, "image_token_id", None)
+ video_token_id = getattr(self.model.config, "video_token_id", None)
+ if image_token_id is not None or video_token_id is not None:
+ mm_token_type_ids = torch.zeros_like(features["input_ids"])
+ if image_token_id is not None:
+ mm_token_type_ids[features["input_ids"] == image_token_id] = 1
+ if video_token_id is not None:
+ mm_token_type_ids[features["input_ids"] == video_token_id] = 2
+ rope_index_kwargs["mm_token_type_ids"] = mm_token_type_ids
if "second_per_grid_ts" in mm_inputs: # for qwen2vl
rope_index_kwargs["second_per_grid_ts"] = mm_inputs.get("second_per_grid_ts")
elif "video_second_per_grid" in mm_inputs: # for qwen2.5 omni
@@ -219,6 +230,7 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
"qwen2_5_vl",
"qwen2_5_omni_thinker",
"qwen3_omni_moe_thinker",
+ "qwen3_5",
"qwen3_vl",
"qwen3_vl_moe",
]
diff --git a/src/llamafactory/data/template.py b/src/llamafactory/data/template.py
index 1132d7111..f41d74f61 100644
--- a/src/llamafactory/data/template.py
+++ b/src/llamafactory/data/template.py
@@ -2029,6 +2029,39 @@ register_template(
)
+register_template(
+ name="qwen3_5",
+ format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+ format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
+ format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
+ format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen3_5"),
+ format_observation=StringFormatter(
+ slots=["<|im_start|>user\n\n{{content}}\n<|im_end|>\n<|im_start|>assistant\n"]
+ ),
+ format_tools=ToolFormatter(tool_format="qwen3_5"),
+ stop_words=["<|im_end|>"],
+ replace_eos=True,
+ mm_plugin=get_mm_plugin(name="qwen3_vl", image_token="<|image_pad|>", video_token="<|video_pad|>"),
+ template_class=ReasoningTemplate,
+)
+
+
+register_template(
+ name="qwen3_5_nothink",
+ format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+ format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
+ format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
+ format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen3_5"),
+ format_observation=StringFormatter(
+ slots=["<|im_start|>user\n\n{{content}}\n<|im_end|>\n<|im_start|>assistant\n"]
+ ),
+ format_tools=ToolFormatter(tool_format="qwen3_5"),
+ stop_words=["<|im_end|>"],
+ replace_eos=True,
+ mm_plugin=get_mm_plugin(name="qwen3_vl", image_token="<|image_pad|>", video_token="<|video_pad|>"),
+)
+
+
register_template(
name="sailor",
format_user=StringFormatter(slots=["<|im_start|>question\n{{content}}<|im_end|>\n<|im_start|>answer\n"]),
diff --git a/src/llamafactory/data/tool_utils.py b/src/llamafactory/data/tool_utils.py
index 18c6ad2f0..61e661795 100644
--- a/src/llamafactory/data/tool_utils.py
+++ b/src/llamafactory/data/tool_utils.py
@@ -85,6 +85,21 @@ QWEN_TOOL_PROMPT = (
""""arguments": }}\n"""
)
+QWEN35_TOOL_PROMPT = (
+    "\n\n# Tools\n\nYou have access to the following functions:\n\n<tools>{tool_text}"
+    "\n</tools>\n\nIf you choose to call a function ONLY reply in the following format with NO suffix:\n\n"
+    "<tool_call>\n<function=example_function_name>\n<parameter=example_parameter_1>\nvalue_1\n</parameter>\n"
+    "<parameter=example_parameter_2>\nThis is the value for the second parameter\nthat can span\nmultiple lines\n"
+    "</parameter>\n</function>\n</tool_call>\n\n<IMPORTANT>\nReminder:\n"
+    "- Function calls MUST follow the specified format: "
+    "an inner <function=...></function> block must be nested within <tool_call></tool_call> XML tags\n"
+    "- Required parameters MUST be specified\n"
+    "- You may provide optional reasoning for your function call in natural language "
+    "BEFORE the function call, but NOT after\n"
+    "- If there is no function call available, answer the question like normal with your current knowledge "
+    "and do not tell the user about function calls\n</IMPORTANT>"
+)
+
SEED_TOOL_PROMPT = (
"system\nYou are Doubao, a helpful AI assistant. You may call one or more functions to assist with the user query."
"Tool List:\nYou are authorized to use the following tools (described in JSON Schema format). Before performing "
@@ -453,6 +468,57 @@ class QwenToolUtils(ToolUtils):
return results
+class Qwen35ToolUtils(ToolUtils):
+ r"""Qwen 3.5 tool using template."""
+
+ @override
+ @staticmethod
+ def tool_formatter(tools: list[dict[str, Any]]) -> str:
+ tool_text = ""
+ for tool in tools:
+ tool = tool.get("function", tool) if tool.get("type") == "function" else tool
+ tool_text += "\n" + json.dumps(tool, ensure_ascii=False)
+
+ return QWEN35_TOOL_PROMPT.format(tool_text=tool_text)
+
+ @override
+ @staticmethod
+ def function_formatter(functions: list["FunctionCall"]) -> str:
+ function_texts = []
+ for func in functions:
+ name, arguments = func.name, json.loads(func.arguments)
+ prompt = f"\n"
+ for key, value in arguments.items():
+ prompt += f"\n"
+ if not isinstance(value, str):
+ value = json.dumps(value, ensure_ascii=False)
+ prompt += f"\n{value}\n"
+ prompt += "\n\n"
+ function_texts.append(prompt)
+
+ return "\n".join(function_texts)
+
+ @override
+ @staticmethod
+ def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
+ results = []
+        regex = re.compile(r"<tool_call>\s*<function=([^>]+)>\s*(.*?)\s*</function>\s*</tool_call>", re.DOTALL)
+ for func_name, params_block in re.findall(regex, content):
+ args_dict = {}
+            param_pattern = re.compile(r"<parameter=([^>]+)>(.*?)</parameter>", re.DOTALL)
+ for key, raw_value in re.findall(param_pattern, params_block.strip()):
+ value = raw_value.strip()
+ try:
+ parsed_value = json.loads(value)
+ except json.JSONDecodeError:
+ parsed_value = raw_value.strip()
+ args_dict[key] = parsed_value
+
+ results.append(FunctionCall(func_name.strip(), json.dumps(args_dict, ensure_ascii=False)))
+
+ return results if results else content
+
+
class GLM4MOEToolUtils(QwenToolUtils):
r"""GLM-4-MOE tool using template."""
@@ -662,6 +728,7 @@ TOOLS = {
"minimax2": MiniMaxM2ToolUtils(),
"mistral": MistralToolUtils(),
"qwen": QwenToolUtils(),
+ "qwen3_5": Qwen35ToolUtils(),
"glm4_moe": GLM4MOEToolUtils(),
"seed_oss": SeedToolUtils(),
"ling": LingToolUtils(),
diff --git a/src/llamafactory/extras/misc.py b/src/llamafactory/extras/misc.py
index 5c077dd57..8214267ed 100644
--- a/src/llamafactory/extras/misc.py
+++ b/src/llamafactory/extras/misc.py
@@ -94,7 +94,7 @@ def check_version(requirement: str, mandatory: bool = False) -> None:
def check_dependencies() -> None:
r"""Check the version of the required packages."""
- check_version("transformers>=4.51.0,<=5.0.0")
+ check_version("transformers>=4.51.0,<=5.2.0")
check_version("datasets>=2.16.0,<=4.0.0")
check_version("accelerate>=1.3.0,<=1.11.0")
check_version("peft>=0.18.0,<=0.18.1")
diff --git a/tests/data/test_collator.py b/tests/data/test_collator.py
index 63370b1b6..7222a8658 100644
--- a/tests/data/test_collator.py
+++ b/tests/data/test_collator.py
@@ -22,6 +22,7 @@ from transformers import AutoConfig, AutoModelForImageTextToText
from llamafactory.data import get_template_and_fix_tokenizer
from llamafactory.data.collator import MultiModalDataCollatorForSeq2Seq, prepare_4d_attention_mask
from llamafactory.extras.constants import IGNORE_INDEX
+from llamafactory.extras.packages import is_transformers_version_greater_than
from llamafactory.hparams import get_infer_args
from llamafactory.model import load_tokenizer
@@ -116,14 +117,14 @@ def test_multimodal_collator():
"labels": [
[0, 1, 2, 3, q, q, q, q, q, q, q, q],
],
- "position_ids": [
- [[0, 1, 2, 3, 1, 1, 1, 1, 1, 1, 1, 1]],
- [[0, 1, 2, 3, 1, 1, 1, 1, 1, 1, 1, 1]],
- [[0, 1, 2, 3, 1, 1, 1, 1, 1, 1, 1, 1]],
- ],
- "rope_deltas": [[-8]],
+ "position_ids": [[[0, 1, 2, 3, 0, 0, 0, 0, 0, 0, 0, 0]]] * 3,
+ "rope_deltas": [[0]],
**tokenizer_module["processor"].image_processor(fake_image),
}
+ if not is_transformers_version_greater_than("5.0.0"):
+ expected_input["position_ids"] = [[[0, 1, 2, 3, 1, 1, 1, 1, 1, 1, 1, 1]]] * 3
+ expected_input["rope_deltas"] = [[-8]]
+
assert batch_input.keys() == expected_input.keys()
for k in batch_input.keys():
assert batch_input[k].eq(torch.tensor(expected_input[k])).all()
diff --git a/tests/version.txt b/tests/version.txt
index fdd7d35a4..e19c965ec 100644
--- a/tests/version.txt
+++ b/tests/version.txt
@@ -1,2 +1,2 @@
# change if test fails or cache is outdated
-0.9.5.106
+0.9.5.107