diff --git a/pyproject.toml b/pyproject.toml index 39dd0ad31..d14019011 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,7 +40,7 @@ dependencies = [ "torch>=2.4.0", "torchvision>=0.19.0", "torchaudio>=2.4.0", - "transformers>=4.51.0,<=5.0.0,!=4.52.0,!=4.57.0", + "transformers>=4.51.0,<=5.2.0,!=4.52.0,!=4.57.0", "datasets>=2.16.0,<=4.0.0", "accelerate>=1.3.0,<=1.11.0", "peft>=0.18.0,<=0.18.1", diff --git a/src/llamafactory/data/collator.py b/src/llamafactory/data/collator.py index d1e562d93..31809eb24 100644 --- a/src/llamafactory/data/collator.py +++ b/src/llamafactory/data/collator.py @@ -15,6 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import inspect from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Literal, Optional @@ -189,6 +190,16 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq): "video_grid_thw": mm_inputs.get("video_grid_thw"), "attention_mask": (features["attention_mask"] >= 1).float(), } + if "mm_token_type_ids" in inspect.signature(self.get_rope_func).parameters: + image_token_id = getattr(self.model.config, "image_token_id", None) + video_token_id = getattr(self.model.config, "video_token_id", None) + if image_token_id is not None or video_token_id is not None: + mm_token_type_ids = torch.zeros_like(features["input_ids"]) + if image_token_id is not None: + mm_token_type_ids[features["input_ids"] == image_token_id] = 1 + if video_token_id is not None: + mm_token_type_ids[features["input_ids"] == video_token_id] = 2 + rope_index_kwargs["mm_token_type_ids"] = mm_token_type_ids if "second_per_grid_ts" in mm_inputs: # for qwen2vl rope_index_kwargs["second_per_grid_ts"] = mm_inputs.get("second_per_grid_ts") elif "video_second_per_grid" in mm_inputs: # for qwen2.5 omni @@ -219,6 +230,7 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq): "qwen2_5_vl", "qwen2_5_omni_thinker", "qwen3_omni_moe_thinker", + "qwen3_5", "qwen3_vl", 
"qwen3_vl_moe", ] diff --git a/src/llamafactory/data/template.py b/src/llamafactory/data/template.py index 1132d7111..f41d74f61 100644 --- a/src/llamafactory/data/template.py +++ b/src/llamafactory/data/template.py @@ -2029,6 +2029,39 @@ register_template( ) +register_template( + name="qwen3_5", + format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), + format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]), + format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]), + format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen3_5"), + format_observation=StringFormatter( + slots=["<|im_start|>user\n\n{{content}}\n<|im_end|>\n<|im_start|>assistant\n"] + ), + format_tools=ToolFormatter(tool_format="qwen3_5"), + stop_words=["<|im_end|>"], + replace_eos=True, + mm_plugin=get_mm_plugin(name="qwen3_vl", image_token="<|image_pad|>", video_token="<|video_pad|>"), + template_class=ReasoningTemplate, +) + + +register_template( + name="qwen3_5_nothink", + format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), + format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]), + format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]), + format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen3_5"), + format_observation=StringFormatter( + slots=["<|im_start|>user\n\n{{content}}\n<|im_end|>\n<|im_start|>assistant\n"] + ), + format_tools=ToolFormatter(tool_format="qwen3_5"), + stop_words=["<|im_end|>"], + replace_eos=True, + mm_plugin=get_mm_plugin(name="qwen3_vl", image_token="<|image_pad|>", video_token="<|video_pad|>"), +) + + register_template( name="sailor", format_user=StringFormatter(slots=["<|im_start|>question\n{{content}}<|im_end|>\n<|im_start|>answer\n"]), diff --git a/src/llamafactory/data/tool_utils.py b/src/llamafactory/data/tool_utils.py 
index 18c6ad2f0..61e661795 100644 --- a/src/llamafactory/data/tool_utils.py +++ b/src/llamafactory/data/tool_utils.py @@ -85,6 +85,21 @@ QWEN_TOOL_PROMPT = ( """"arguments": <args-json-object>}}\n</tool_call>""" ) +QWEN35_TOOL_PROMPT = ( + "\n\n# Tools\n\nYou have access to the following functions:\n\n<tools>{tool_text}" + "\n</tools>\n\nIf you choose to call a function ONLY reply in the following format with NO suffix:\n\n" + "<tool_call>\n<function=example_function_name>\n<parameter=example_parameter_1>\nvalue_1\n</parameter>\n" + "<parameter=example_parameter_2>\nThis is the value for the second parameter\nthat can span\nmultiple lines\n" + "</parameter>\n</function>\n</tool_call>\n\n<IMPORTANT>\nReminder:\n" + "- Function calls MUST follow the specified format: " + "an inner <function=...></function> block must be nested within <tool_call></tool_call> XML tags\n" + "- Required parameters MUST be specified\n" + "- You may provide optional reasoning for your function call in natural language " + "BEFORE the function call, but NOT after\n" + "- If there is no function call available, answer the question like normal with your current knowledge " + "and do not tell the user about function calls\n</IMPORTANT>" +) + SEED_TOOL_PROMPT = ( "system\nYou are Doubao, a helpful AI assistant. You may call one or more functions to assist with the user query." "Tool List:\nYou are authorized to use the following tools (described in JSON Schema format). 
Before performing " @@ -453,6 +468,57 @@ class QwenToolUtils(ToolUtils): return results +class Qwen35ToolUtils(ToolUtils): + r"""Qwen 3.5 tool using template.""" + + @override + @staticmethod + def tool_formatter(tools: list[dict[str, Any]]) -> str: + tool_text = "" + for tool in tools: + tool = tool.get("function", tool) if tool.get("type") == "function" else tool + tool_text += "\n" + json.dumps(tool, ensure_ascii=False) + + return QWEN35_TOOL_PROMPT.format(tool_text=tool_text) + + @override + @staticmethod + def function_formatter(functions: list["FunctionCall"]) -> str: + function_texts = [] + for func in functions: + name, arguments = func.name, json.loads(func.arguments) + prompt = f"<tool_call>\n<function={name}>" + for key, value in arguments.items(): + prompt += f"\n<parameter={key}>" + if not isinstance(value, str): + value = json.dumps(value, ensure_ascii=False) + prompt += f"\n{value}\n</parameter>" + prompt += "\n</function>\n</tool_call>" + function_texts.append(prompt) + + return "\n".join(function_texts) + + @override + @staticmethod + def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]: + results = [] + regex = re.compile(r"<tool_call>\s*<function=([^>]+)>\s*(.*?)\s*</function>\s*</tool_call>\s*", re.DOTALL) + for func_name, params_block in re.findall(regex, content): + args_dict = {} + param_pattern = re.compile(r"<parameter=(\w+)>(.*?)</parameter>", re.DOTALL) + for key, raw_value in re.findall(param_pattern, params_block.strip()): + value = raw_value.strip() + try: + parsed_value = json.loads(value) + except json.JSONDecodeError: + parsed_value = raw_value.strip() + args_dict[key] = parsed_value + + results.append(FunctionCall(func_name.strip(), json.dumps(args_dict, ensure_ascii=False))) + + return results if results else content + + class GLM4MOEToolUtils(QwenToolUtils): r"""GLM-4-MOE tool using template.""" @@ -662,6 +728,7 @@ TOOLS = { "minimax2": MiniMaxM2ToolUtils(), "mistral": MistralToolUtils(), "qwen": QwenToolUtils(), + "qwen3_5": Qwen35ToolUtils(), "glm4_moe": GLM4MOEToolUtils(), "seed_oss": SeedToolUtils(), "ling": LingToolUtils(), diff --git 
a/src/llamafactory/extras/misc.py b/src/llamafactory/extras/misc.py index 5c077dd57..8214267ed 100644 --- a/src/llamafactory/extras/misc.py +++ b/src/llamafactory/extras/misc.py @@ -94,7 +94,7 @@ def check_version(requirement: str, mandatory: bool = False) -> None: def check_dependencies() -> None: r"""Check the version of the required packages.""" - check_version("transformers>=4.51.0,<=5.0.0") + check_version("transformers>=4.51.0,<=5.2.0") check_version("datasets>=2.16.0,<=4.0.0") check_version("accelerate>=1.3.0,<=1.11.0") check_version("peft>=0.18.0,<=0.18.1") diff --git a/tests/data/test_collator.py b/tests/data/test_collator.py index 63370b1b6..7222a8658 100644 --- a/tests/data/test_collator.py +++ b/tests/data/test_collator.py @@ -22,6 +22,7 @@ from transformers import AutoConfig, AutoModelForImageTextToText from llamafactory.data import get_template_and_fix_tokenizer from llamafactory.data.collator import MultiModalDataCollatorForSeq2Seq, prepare_4d_attention_mask from llamafactory.extras.constants import IGNORE_INDEX +from llamafactory.extras.packages import is_transformers_version_greater_than from llamafactory.hparams import get_infer_args from llamafactory.model import load_tokenizer @@ -116,14 +117,14 @@ def test_multimodal_collator(): "labels": [ [0, 1, 2, 3, q, q, q, q, q, q, q, q], ], - "position_ids": [ - [[0, 1, 2, 3, 1, 1, 1, 1, 1, 1, 1, 1]], - [[0, 1, 2, 3, 1, 1, 1, 1, 1, 1, 1, 1]], - [[0, 1, 2, 3, 1, 1, 1, 1, 1, 1, 1, 1]], - ], - "rope_deltas": [[-8]], + "position_ids": [[[0, 1, 2, 3, 0, 0, 0, 0, 0, 0, 0, 0]]] * 3, + "rope_deltas": [[0]], **tokenizer_module["processor"].image_processor(fake_image), } + if not is_transformers_version_greater_than("5.0.0"): + expected_input["position_ids"] = [[[0, 1, 2, 3, 1, 1, 1, 1, 1, 1, 1, 1]]] * 3 + expected_input["rope_deltas"] = [[-8]] + assert batch_input.keys() == expected_input.keys() for k in batch_input.keys(): assert batch_input[k].eq(torch.tensor(expected_input[k])).all() diff --git 
a/tests/version.txt b/tests/version.txt index fdd7d35a4..e19c965ec 100644 --- a/tests/version.txt +++ b/tests/version.txt @@ -1,2 +1,2 @@ # change if test fails or cache is outdated -0.9.5.106 +0.9.5.107