Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-07-31 10:42:50 +08:00)

[model] add GLM-4.1V (#8462)

This commit is contained in:
parent 0a004904bd
commit 0b188ca00c
@@ -210,7 +210,8 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
         if (
             self.model is not None
-            and getattr(self.model.config, "model_type", None) in ["qwen2_vl", "qwen2_5_vl", "qwen2_5_omni_thinker"]
+            and getattr(self.model.config, "model_type", None)
+            in ["glm4v", "qwen2_vl", "qwen2_5_vl", "qwen2_5_omni_thinker"]
             and ("position_ids" not in features or features["position_ids"].dim() != 3)
         ):
             raise ValueError("Qwen2-VL/Qwen2.5-Omni model requires 3D position ids for mrope.")
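Note on the check above: mrope-style models (now including glm4v) need position ids with separate temporal, height, and width components. A minimal sketch of the shape the collator expects, assuming the (3, batch_size, seq_len) layout used by Qwen2-VL-style mrope; the values below are placeholders and are not part of the diff.

import torch

# Illustrative only: three position components per token, so the tensor is 3-D
# and features["position_ids"].dim() == 3 satisfies the check above.
batch_size, seq_len = 2, 16
position_ids = torch.arange(seq_len).repeat(3, batch_size, 1)  # shape (3, 2, 16)
assert position_ids.dim() == 3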
@@ -1498,6 +1498,135 @@ class Qwen2VLPlugin(BasePlugin):
         return messages


+@dataclass
+class GLM4VPlugin(Qwen2VLPlugin):
+    @override
+    def _get_mm_inputs(
+        self,
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
+        processor: "MMProcessor",
+    ) -> dict[str, "torch.Tensor"]:
+        image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
+        video_processor: BaseImageProcessor = getattr(processor, "video_processor", None)
+        mm_inputs = {}
+        if len(images) != 0:
+            images = self._regularize_images(
+                images,
+                image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768),
+                image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
+            )["images"]
+            mm_inputs.update(image_processor(images, return_tensors="pt"))
+
+        if len(videos) != 0:
+            video_data = self._regularize_videos(
+                videos,
+                image_max_pixels=getattr(processor, "video_max_pixels", 256 * 256),
+                image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16),
+                video_fps=getattr(processor, "video_fps", 2.0),
+                video_maxlen=getattr(processor, "video_maxlen", 128),
+            )
+            # prepare video metadata
+            video_metadata = [
+                {"fps": 2, "duration": len(video), "total_frames": len(video)} for video in video_data["videos"]
+            ]
+            mm_inputs.update(
+                video_processor(
+                    images=None, videos=video_data["videos"], video_metadata=video_metadata, return_tensors="pt"
+                )
+            )
+
+        return mm_inputs
+
+    @override
+    def process_messages(
+        self,
+        messages: list[dict[str, str]],
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
+        processor: Optional["MMProcessor"],
+    ) -> list[dict[str, str]]:
+        self._validate_input(processor, images, videos, audios)
+        self._validate_messages(messages, images, videos, audios)
+        num_image_tokens, num_video_tokens = 0, 0
+        messages = deepcopy(messages)
+        image_processor: BaseImageProcessor = getattr(processor, "image_processor")
+
+        merge_length: int = getattr(image_processor, "merge_size") ** 2
+        if self.expand_mm_tokens:
+            mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
+            image_grid_thw = mm_inputs.get("image_grid_thw", [])
+            video_grid_thw = mm_inputs.get("video_grid_thw", [])
+            num_frames = len(video_grid_thw)
+            timestamps = mm_inputs.get("timestamps", [])
+            if hasattr(timestamps, "tolist"):
+                timestamps = timestamps.tolist()
+
+            if not timestamps:
+                timestamps_list = []
+            elif isinstance(timestamps[0], list):
+                timestamps_list = timestamps[0]
+            else:
+                timestamps_list = timestamps
+
+            unique_timestamps = timestamps_list.copy()
+            selected_timestamps = unique_timestamps[:num_frames]
+            while len(selected_timestamps) < num_frames:
+                selected_timestamps.append(selected_timestamps[-1] if selected_timestamps else 0)
+
+        else:
+            image_grid_thw = [None] * len(images)
+            video_grid_thw = [None] * len(videos)
+            num_frames = 0
+            selected_timestamps = [0]
+
+        for message in messages:
+            content = message["content"]
+            while IMAGE_PLACEHOLDER in content:
+                image_seqlen = image_grid_thw[num_image_tokens].prod() // merge_length if self.expand_mm_tokens else 1
+                content = content.replace(
+                    IMAGE_PLACEHOLDER, f"<|begin_of_image|>{self.image_token * image_seqlen}<|end_of_image|>", 1
+                )
+                num_image_tokens += 1
+
+            # TODO: video inputs are not supported until the next PR
+            while VIDEO_PLACEHOLDER in content:
+                video_structure = ""
+                for frame_index in range(num_frames):
+                    video_seqlen = video_grid_thw[frame_index].prod() // merge_length if self.expand_mm_tokens else 1
+                    timestamp_sec = selected_timestamps[frame_index]
+                    frame_structure = (
+                        f"<|begin_of_image|>{self.image_token * video_seqlen}<|end_of_image|>{timestamp_sec}"
+                    )
+                    video_structure += frame_structure
+
+                content = content.replace(VIDEO_PLACEHOLDER, video_structure, 1)
+                num_video_tokens += 1  # FIXME: num_video_tokens is not used
+
+            message["content"] = content
+
+        return messages
+
+    @override
+    def get_mm_inputs(
+        self,
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
+        imglens: list[int],
+        vidlens: list[int],
+        audlens: list[int],
+        batch_ids: list[list[int]],
+        processor: Optional["ProcessorMixin"],
+    ) -> dict[str, Union[list[int], "torch.Tensor"]]:
+        self._validate_input(processor, images, videos, audios)
+        mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
+        mm_inputs.pop("timestamps", None)
+        return mm_inputs
+
+
 class Qwen2OmniPlugin(Qwen2VLPlugin):
     @override
     def _get_mm_inputs(
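For reference, a hedged illustration of the image placeholder expansion performed by GLM4VPlugin.process_messages above. The grid and merge size are invented example values; the boundary tokens and the image token come from the diff itself.

# Assume a patch grid of (t, h, w) = (1, 24, 24) and merge_size = 2 (example values).
grid_thw = (1, 24, 24)
merge_length = 2 ** 2  # merge_size ** 2, as in the plugin
image_seqlen = (grid_thw[0] * grid_thw[1] * grid_thw[2]) // merge_length  # 144
expanded = "<|begin_of_image|>" + "<|image|>" * image_seqlen + "<|end_of_image|>"
# Each IMAGE_PLACEHOLDER in the message content is replaced by one such block.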
@@ -1715,6 +1844,7 @@ class VideoLlavaPlugin(BasePlugin):
 PLUGINS = {
     "base": BasePlugin,
     "gemma3": Gemma3Plugin,
+    "glm4v": GLM4VPlugin,
     "intern_vl": InternVLPlugin,
     "kimi_vl": KimiVLPlugin,
     "llama4": Llama4Plugin,
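A short sketch of how the new registry entry is resolved. The get_mm_plugin call mirrors the one in the glm4v template registration below; the import path is an assumption about the module layout, not something stated in the diff.

from llamafactory.data.mm_plugin import get_mm_plugin  # assumed import path

plugin = get_mm_plugin(name="glm4v", image_token="<|image|>", video_token="<|video|>")
# Expected to resolve through PLUGINS["glm4v"] to a GLM4VPlugin instance.
print(type(plugin).__name__)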
@@ -998,6 +998,24 @@ register_template(
 )


+# part copied from glm4 template
+register_template(
+    name="glm4v",
+    format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]),
+    format_assistant=StringFormatter(slots=["\n{{content}}"]),
+    format_system=StringFormatter(slots=["<|system|>\n{{content}}"]),
+    format_function=FunctionFormatter(slots=["{{content}}"], tool_format="glm4"),
+    format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>"]),
+    format_tools=ToolFormatter(tool_format="glm4"),
+    format_prefix=EmptyFormatter(slots=["[gMASK]<sop>"]),
+    stop_words=["<|user|>", "<|observation|>"],
+    efficient_eos=True,
+    mm_plugin=get_mm_plugin(name="glm4v", image_token="<|image|>", video_token="<|video|>"),
+    template_class=ReasoningTemplate,
+    thought_words=("<think>", "</think>"),
+)
+
+
 # copied from glm4 template
 register_template(
     name="glmz1",
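As a reading aid, a hedged sketch of the prompt string these formatters produce for one user/assistant turn. The question and answer text are invented and the expanded image tokens are abbreviated with an ellipsis; only the special tokens are taken from the registration above.

prefix = "[gMASK]<sop>"  # format_prefix
user_turn = "<|user|>\n" + "<|begin_of_image|>...<|end_of_image|>Describe this image." + "<|assistant|>"
assistant_turn = "\n" + "<think>reasoning goes here</think>The image shows a cat."  # thought_words wrap the reasoning
prompt = prefix + user_turn + assistant_turn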
@@ -834,6 +834,18 @@ register_model_group(
 )


+register_model_group(
+    models={
+        "GLM-4.1V-9B-Thinking": {
+            DownloadSource.DEFAULT: "THUDM/GLM-4.1V-9B-Thinking",
+            DownloadSource.MODELSCOPE: "ZhipuAI/GLM-4.1V-9B-Thinking",
+        }
+    },
+    template="glm4v",
+    multimodal=True,
+)
+
+
 register_model_group(
     models={
         "GLM-Z1-0414-9B-Chat": {
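For orientation, a hedged loading sketch for the checkpoint registered above. The hub id comes from DownloadSource.DEFAULT; the Auto classes assume a Transformers release that already ships GLM-4.1V support, which this diff does not guarantee.

from transformers import AutoModelForImageTextToText, AutoProcessor

model_id = "THUDM/GLM-4.1V-9B-Thinking"  # DownloadSource.DEFAULT above
processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForImageTextToText.from_pretrained(model_id)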
@@ -204,6 +204,16 @@ _register_composite_model(
 )


+# copied from qwen2vl
+_register_composite_model(
+    model_type="glm4v",
+    projector_key="visual.merger",
+    vision_model_keys=["visual.patch_embed", "visual.blocks"],
+    language_model_keys=["language_model", "lm_head"],
+    lora_conflict_keys=["patch_embed"],
+)
+
+
 _register_composite_model(
     model_type="internvl",
 )
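A hedged sketch of what the key groups above are typically used for: selecting parameter subsets by name prefix, e.g. freezing the vision tower while tuning the language model. The helper below is illustrative only, not LLaMA-Factory's actual implementation.

def freeze_vision_tower(model, vision_model_keys=("visual.patch_embed", "visual.blocks")):
    """Illustrative only: freeze every parameter whose name starts with a vision key."""
    for name, param in model.named_parameters():
        if any(name.startswith(key) for key in vision_model_keys):
            param.requires_grad_(False)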