From 0b188ca00c9de9efee63807e72e444ea74195da5 Mon Sep 17 00:00:00 2001 From: Kingsley Date: Mon, 30 Jun 2025 01:09:41 +0800 Subject: [PATCH] [model] add GLM-4.1V (#8462) --- src/llamafactory/data/collator.py | 3 +- src/llamafactory/data/mm_plugin.py | 130 +++++++++++++++++++ src/llamafactory/data/template.py | 18 +++ src/llamafactory/extras/constants.py | 12 ++ src/llamafactory/model/model_utils/visual.py | 10 ++ 5 files changed, 172 insertions(+), 1 deletion(-) diff --git a/src/llamafactory/data/collator.py b/src/llamafactory/data/collator.py index b749aaef..45275392 100644 --- a/src/llamafactory/data/collator.py +++ b/src/llamafactory/data/collator.py @@ -210,7 +210,8 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq): if ( self.model is not None - and getattr(self.model.config, "model_type", None) in ["qwen2_vl", "qwen2_5_vl", "qwen2_5_omni_thinker"] + and getattr(self.model.config, "model_type", None) + in ["glm4v", "qwen2_vl", "qwen2_5_vl", "qwen2_5_omni_thinker"] and ("position_ids" not in features or features["position_ids"].dim() != 3) ): raise ValueError("Qwen2-VL/Qwen2.5-Omni model requires 3D position ids for mrope.") diff --git a/src/llamafactory/data/mm_plugin.py b/src/llamafactory/data/mm_plugin.py index 7bf17f79..21ff1304 100644 --- a/src/llamafactory/data/mm_plugin.py +++ b/src/llamafactory/data/mm_plugin.py @@ -1498,6 +1498,135 @@ class Qwen2VLPlugin(BasePlugin): return messages +@dataclass +class GLM4VPlugin(Qwen2VLPlugin): + @override + def _get_mm_inputs( + self, + images: list["ImageInput"], + videos: list["VideoInput"], + audios: list["AudioInput"], + processor: "MMProcessor", + ) -> dict[str, "torch.Tensor"]: + image_processor: BaseImageProcessor = getattr(processor, "image_processor", None) + video_processor: BaseImageProcessor = getattr(processor, "video_processor", None) + mm_inputs = {} + if len(images) != 0: + images = self._regularize_images( + images, + image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768), + image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32), + )["images"] + mm_inputs.update(image_processor(images, return_tensors="pt")) + + if len(videos) != 0: + video_data = self._regularize_videos( + videos, + image_max_pixels=getattr(processor, "video_max_pixels", 256 * 256), + image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16), + video_fps=getattr(processor, "video_fps", 2.0), + video_maxlen=getattr(processor, "video_maxlen", 128), + ) + # prepare video metadata + video_metadata = [ + {"fps": 2, "duration": len(video), "total_frames": len(video)} for video in video_data["videos"] + ] + mm_inputs.update( + video_processor( + images=None, videos=video_data["videos"], video_metadata=video_metadata, return_tensors="pt" + ) + ) + + return mm_inputs + + @override + def process_messages( + self, + messages: list[dict[str, str]], + images: list["ImageInput"], + videos: list["VideoInput"], + audios: list["AudioInput"], + processor: Optional["MMProcessor"], + ) -> list[dict[str, str]]: + self._validate_input(processor, images, videos, audios) + self._validate_messages(messages, images, videos, audios) + num_image_tokens, num_video_tokens = 0, 0 + messages = deepcopy(messages) + image_processor: BaseImageProcessor = getattr(processor, "image_processor") + + merge_length: int = getattr(image_processor, "merge_size") ** 2 + if self.expand_mm_tokens: + mm_inputs = self._get_mm_inputs(images, videos, audios, processor) + image_grid_thw = mm_inputs.get("image_grid_thw", []) + video_grid_thw = 
mm_inputs.get("video_grid_thw", [])
+            num_frames = len(video_grid_thw)
+            timestamps = mm_inputs.get("timestamps", [])
+            if hasattr(timestamps, "tolist"):
+                timestamps = timestamps.tolist()
+
+            if not timestamps:
+                timestamps_list = []
+            elif isinstance(timestamps[0], list):
+                timestamps_list = timestamps[0]
+            else:
+                timestamps_list = timestamps
+
+            unique_timestamps = timestamps_list.copy()
+            selected_timestamps = unique_timestamps[:num_frames]
+            while len(selected_timestamps) < num_frames:
+                selected_timestamps.append(selected_timestamps[-1] if selected_timestamps else 0)
+
+        else:
+            image_grid_thw = [None] * len(images)
+            video_grid_thw = [None] * len(videos)
+            num_frames = 0
+            selected_timestamps = [0]
+
+        for message in messages:
+            content = message["content"]
+            while IMAGE_PLACEHOLDER in content:
+                image_seqlen = image_grid_thw[num_image_tokens].prod() // merge_length if self.expand_mm_tokens else 1
+                content = content.replace(
+                    IMAGE_PLACEHOLDER, f"<|begin_of_image|>{self.image_token * image_seqlen}<|end_of_image|>", 1
+                )
+                num_image_tokens += 1
+
+            # TODO: video inputs are not supported until the next PR
+            while VIDEO_PLACEHOLDER in content:
+                video_structure = ""
+                for frame_index in range(num_frames):
+                    video_seqlen = video_grid_thw[frame_index].prod() // merge_length if self.expand_mm_tokens else 1
+                    timestamp_sec = selected_timestamps[frame_index]
+                    frame_structure = (
+                        f"<|begin_of_image|>{self.image_token * video_seqlen}<|end_of_image|>{timestamp_sec}"
+                    )
+                    video_structure += frame_structure
+
+                content = content.replace(VIDEO_PLACEHOLDER, video_structure, 1)
+                num_video_tokens += 1  # FIXME: num_video_tokens is not used
+
+            message["content"] = content
+
+        return messages
+
+    @override
+    def get_mm_inputs(
+        self,
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
+        imglens: list[int],
+        vidlens: list[int],
+        audlens: list[int],
+        batch_ids: list[list[int]],
+        processor: Optional["ProcessorMixin"],
+    ) -> dict[str, Union[list[int], "torch.Tensor"]]:
+        self._validate_input(processor, images, videos, audios)
+        mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
+        mm_inputs.pop("timestamps", None)
+        return mm_inputs
+
+
 class Qwen2OmniPlugin(Qwen2VLPlugin):
     @override
     def _get_mm_inputs(
@@ -1715,6 +1844,7 @@ class VideoLlavaPlugin(BasePlugin):
 PLUGINS = {
     "base": BasePlugin,
     "gemma3": Gemma3Plugin,
+    "glm4v": GLM4VPlugin,
     "intern_vl": InternVLPlugin,
     "kimi_vl": KimiVLPlugin,
     "llama4": Llama4Plugin,
diff --git a/src/llamafactory/data/template.py b/src/llamafactory/data/template.py
index b3e2132b..5ffa186f 100644
--- a/src/llamafactory/data/template.py
+++ b/src/llamafactory/data/template.py
@@ -998,6 +998,24 @@ register_template(
 )
 
 
+# partially copied from the glm4 template
+register_template(
+    name="glm4v",
+    format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]),
+    format_assistant=StringFormatter(slots=["\n{{content}}"]),
+    format_system=StringFormatter(slots=["<|system|>\n{{content}}"]),
+    format_function=FunctionFormatter(slots=["{{content}}"], tool_format="glm4"),
+    format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>"]),
+    format_tools=ToolFormatter(tool_format="glm4"),
+    format_prefix=EmptyFormatter(slots=["[gMASK]<sop>"]),
+    stop_words=["<|user|>", "<|observation|>"],
+    efficient_eos=True,
+    mm_plugin=get_mm_plugin(name="glm4v", image_token="<|image|>", video_token="<|video|>"),
+    template_class=ReasoningTemplate,
+    thought_words=("<think>", "</think>"),
+)
+
+
 # copied from glm4 template
register_template( name="glmz1", diff --git a/src/llamafactory/extras/constants.py b/src/llamafactory/extras/constants.py index 6459168d..1729dcab 100644 --- a/src/llamafactory/extras/constants.py +++ b/src/llamafactory/extras/constants.py @@ -834,6 +834,18 @@ register_model_group( ) +register_model_group( + models={ + "GLM-4.1V-9B-Thinking": { + DownloadSource.DEFAULT: "THUDM/GLM-4.1V-9B-Thinking", + DownloadSource.MODELSCOPE: "ZhipuAI/GLM-4.1V-9B-Thinking", + } + }, + template="glm4v", + multimodal=True, +) + + register_model_group( models={ "GLM-Z1-0414-9B-Chat": { diff --git a/src/llamafactory/model/model_utils/visual.py b/src/llamafactory/model/model_utils/visual.py index ba2bf5c9..f12a56ec 100644 --- a/src/llamafactory/model/model_utils/visual.py +++ b/src/llamafactory/model/model_utils/visual.py @@ -204,6 +204,16 @@ _register_composite_model( ) +# copied from qwen2vl +_register_composite_model( + model_type="glm4v", + projector_key="visual.merger", + vision_model_keys=["visual.patch_embed", "visual.blocks"], + language_model_keys=["language_model", "lm_head"], + lora_conflict_keys=["patch_embed"], +) + + _register_composite_model( model_type="internvl", )
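
For reviewers, here is a minimal, self-contained sketch of the placeholder-expansion arithmetic that `GLM4VPlugin.process_messages` performs in the patch above: each image placeholder is replaced by `image_grid_thw.prod() // merge_size**2` copies of the image token, wrapped in `<|begin_of_image|>` / `<|end_of_image|>`. The grid values, the `merge_size`, and the literal `"<image>"` placeholder below are illustrative assumptions, not outputs of a real processor run.

```python
import torch

# Invented example: one image whose processor reports a (t, h, w) patch grid
# of 1 x 28 x 28, with 2 x 2 patches merged into a single token.
image_grid_thw = torch.tensor([1, 28, 28])
merge_size = 2  # assumption; the plugin reads merge_size from the image processor
merge_length = merge_size**2

# Same arithmetic as the plugin: one image token per merged patch group.
image_seqlen = int(image_grid_thw.prod()) // merge_length  # 784 // 4 = 196

# "<image>" stands in for IMAGE_PLACEHOLDER and "<|image|>" for the template's image_token.
content = "Describe this picture: <image>"
expanded = content.replace(
    "<image>",
    "<|begin_of_image|>" + "<|image|>" * image_seqlen + "<|end_of_image|>",
    1,
)

print(image_seqlen)                  # 196
print(expanded.count("<|image|>"))   # 196
```

Video placeholders follow the same per-frame pattern with a trailing timestamp appended after each frame's `<|end_of_image|>`, but per the TODO in the plugin, video inputs are deferred to a follow-up PR.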