[model] add GLM-4.1V (#8462)

Kingsley 2025-06-30 01:09:41 +08:00 committed by GitHub
parent 0a004904bd
commit 0b188ca00c
5 changed files with 172 additions and 1 deletion


@@ -210,7 +210,8 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
         if (
             self.model is not None
-            and getattr(self.model.config, "model_type", None) in ["qwen2_vl", "qwen2_5_vl", "qwen2_5_omni_thinker"]
+            and getattr(self.model.config, "model_type", None)
+            in ["glm4v", "qwen2_vl", "qwen2_5_vl", "qwen2_5_omni_thinker"]
             and ("position_ids" not in features or features["position_ids"].dim() != 3)
         ):
             raise ValueError("Qwen2-VL/Qwen2.5-Omni model requires 3D position ids for mrope.")
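
For context on this check: mrope (multimodal rotary position embedding) indexes each token along temporal, height, and width axes, so these models expect position ids of shape (3, batch_size, seq_len) rather than the usual 2D layout. A minimal sketch of the shape contract, assuming the text-only case where all three axes carry the same indices:

import torch

batch_size, seq_len = 2, 16
# for pure text the temporal/height/width indices coincide
position_ids = torch.arange(seq_len).expand(3, batch_size, seq_len)  # (3, bsz, seq_len)
assert position_ids.dim() == 3  # the condition the collator enforces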


@@ -1498,6 +1498,135 @@ class Qwen2VLPlugin(BasePlugin):
         return messages
 
 
+@dataclass
+class GLM4VPlugin(Qwen2VLPlugin):
+    @override
+    def _get_mm_inputs(
+        self,
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
+        processor: "MMProcessor",
+    ) -> dict[str, "torch.Tensor"]:
+        image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
+        video_processor: BaseImageProcessor = getattr(processor, "video_processor", None)
+        mm_inputs = {}
+        if len(images) != 0:
+            images = self._regularize_images(
+                images,
+                image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768),
+                image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
+            )["images"]
+            mm_inputs.update(image_processor(images, return_tensors="pt"))
+
+        if len(videos) != 0:
+            video_data = self._regularize_videos(
+                videos,
+                image_max_pixels=getattr(processor, "video_max_pixels", 256 * 256),
+                image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16),
+                video_fps=getattr(processor, "video_fps", 2.0),
+                video_maxlen=getattr(processor, "video_maxlen", 128),
+            )
+            # prepare video metadata
+            video_metadata = [
+                {"fps": 2, "duration": len(video), "total_frames": len(video)} for video in video_data["videos"]
+            ]
+            mm_inputs.update(
+                video_processor(
+                    images=None, videos=video_data["videos"], video_metadata=video_metadata, return_tensors="pt"
+                )
+            )
+
+        return mm_inputs
+
+    @override
+    def process_messages(
+        self,
+        messages: list[dict[str, str]],
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
+        processor: Optional["MMProcessor"],
+    ) -> list[dict[str, str]]:
+        self._validate_input(processor, images, videos, audios)
+        self._validate_messages(messages, images, videos, audios)
+        num_image_tokens, num_video_tokens = 0, 0
+        messages = deepcopy(messages)
+        image_processor: BaseImageProcessor = getattr(processor, "image_processor")
+        merge_length: int = getattr(image_processor, "merge_size") ** 2
+        if self.expand_mm_tokens:
+            mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
+            image_grid_thw = mm_inputs.get("image_grid_thw", [])
+            video_grid_thw = mm_inputs.get("video_grid_thw", [])
+            num_frames = len(video_grid_thw)
+            timestamps = mm_inputs.get("timestamps", [])
+            if hasattr(timestamps, "tolist"):
+                timestamps = timestamps.tolist()
+
+            if not timestamps:
+                timestamps_list = []
+            elif isinstance(timestamps[0], list):
+                timestamps_list = timestamps[0]
+            else:
+                timestamps_list = timestamps
+
+            unique_timestamps = timestamps_list.copy()
+            selected_timestamps = unique_timestamps[:num_frames]
+            while len(selected_timestamps) < num_frames:
+                selected_timestamps.append(selected_timestamps[-1] if selected_timestamps else 0)
+        else:
+            image_grid_thw = [None] * len(images)
+            video_grid_thw = [None] * len(videos)
+            num_frames = 0
+            selected_timestamps = [0]
+
+        for message in messages:
+            content = message["content"]
+            while IMAGE_PLACEHOLDER in content:
+                image_seqlen = image_grid_thw[num_image_tokens].prod() // merge_length if self.expand_mm_tokens else 1
+                content = content.replace(
+                    IMAGE_PLACEHOLDER, f"<|begin_of_image|>{self.image_token * image_seqlen}<|end_of_image|>", 1
+                )
+                num_image_tokens += 1
+
+            # TODO: video inputs are not supported until the next PR
+            while VIDEO_PLACEHOLDER in content:
+                video_structure = ""
+                for frame_index in range(num_frames):
+                    video_seqlen = video_grid_thw[frame_index].prod() // merge_length if self.expand_mm_tokens else 1
+                    timestamp_sec = selected_timestamps[frame_index]
+                    frame_structure = (
+                        f"<|begin_of_image|>{self.image_token * video_seqlen}<|end_of_image|>{timestamp_sec}"
+                    )
+                    video_structure += frame_structure
+
+                content = content.replace(VIDEO_PLACEHOLDER, video_structure, 1)
+                num_video_tokens += 1  # FIXME: num_video_tokens is not used
+
+            message["content"] = content
+
+        return messages
+
+    @override
+    def get_mm_inputs(
+        self,
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
+        imglens: list[int],
+        vidlens: list[int],
+        audlens: list[int],
+        batch_ids: list[list[int]],
+        processor: Optional["ProcessorMixin"],
+    ) -> dict[str, Union[list[int], "torch.Tensor"]]:
+        self._validate_input(processor, images, videos, audios)
+        mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
+        mm_inputs.pop("timestamps", None)
+        return mm_inputs
+
+
 class Qwen2OmniPlugin(Qwen2VLPlugin):
     @override
     def _get_mm_inputs(
@@ -1715,6 +1844,7 @@ class VideoLlavaPlugin(BasePlugin):
 PLUGINS = {
     "base": BasePlugin,
     "gemma3": Gemma3Plugin,
+    "glm4v": GLM4VPlugin,
     "intern_vl": InternVLPlugin,
     "kimi_vl": KimiVLPlugin,
     "llama4": Llama4Plugin,


@@ -998,6 +998,24 @@ register_template(
 )
 
 
+# part copied from glm4 template
+register_template(
+    name="glm4v",
+    format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]),
+    format_assistant=StringFormatter(slots=["\n{{content}}"]),
+    format_system=StringFormatter(slots=["<|system|>\n{{content}}"]),
+    format_function=FunctionFormatter(slots=["{{content}}"], tool_format="glm4"),
+    format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>"]),
+    format_tools=ToolFormatter(tool_format="glm4"),
+    format_prefix=EmptyFormatter(slots=["[gMASK]<sop>"]),
+    stop_words=["<|user|>", "<|observation|>"],
+    efficient_eos=True,
+    mm_plugin=get_mm_plugin(name="glm4v", image_token="<|image|>", video_token="<|video|>"),
+    template_class=ReasoningTemplate,
+    thought_words=("<think>", "</think>"),
+)
+
+
 # copied from glm4 template
 register_template(
     name="glmz1",


@@ -834,6 +834,18 @@ register_model_group(
 )
 
 
+register_model_group(
+    models={
+        "GLM-4.1V-9B-Thinking": {
+            DownloadSource.DEFAULT: "THUDM/GLM-4.1V-9B-Thinking",
+            DownloadSource.MODELSCOPE: "ZhipuAI/GLM-4.1V-9B-Thinking",
+        }
+    },
+    template="glm4v",
+    multimodal=True,
+)
+
+
 register_model_group(
     models={
         "GLM-Z1-0414-9B-Chat": {


@@ -204,6 +204,16 @@ _register_composite_model(
 )
 
 
+# copied from qwen2vl
+_register_composite_model(
+    model_type="glm4v",
+    projector_key="visual.merger",
+    vision_model_keys=["visual.patch_embed", "visual.blocks"],
+    language_model_keys=["language_model", "lm_head"],
+    lora_conflict_keys=["patch_embed"],
+)
+
+
 _register_composite_model(
     model_type="internvl",
 )
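
These key prefixes let the trainer address the vision tower, the projector, and the language model as separate units (for example, freezing the tower or filtering LoRA targets, with lora_conflict_keys excluding modules such as patch_embed). A rough sketch of how such prefixes are typically consumed; split_params_by_prefix is a hypothetical helper, not LLaMA-Factory API:

def split_params_by_prefix(named_parameters, vision_keys, language_keys):
    # hypothetical helper: bucket parameters by the registered key prefixes
    vision, language, other = [], [], []
    for name, param in named_parameters:
        if any(name.startswith(key) for key in vision_keys):
            vision.append((name, param))
        elif any(name.startswith(key) for key in language_keys):
            language.append((name, param))
        else:
            other.append((name, param))  # e.g. the projector under "visual.merger"
    return vision, language, other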