diff --git a/src/llamafactory/data/mm_plugin.py b/src/llamafactory/data/mm_plugin.py index 2cb3e320..ffa029a7 100644 --- a/src/llamafactory/data/mm_plugin.py +++ b/src/llamafactory/data/mm_plugin.py @@ -1558,11 +1558,7 @@ class GLM4VPlugin(Qwen2VLPlugin): video_metadata = [ {"fps": 2, "duration": len(video), "total_frames": len(video)} for video in video_data["videos"] ] - mm_inputs.update( - video_processor( - images=None, videos=video_data["videos"], video_metadata=video_metadata, return_tensors="pt" - ) - ) + mm_inputs.update(video_processor(images=None, videos=video_data["videos"], video_metadata=video_metadata)) return mm_inputs @@ -1586,8 +1582,9 @@ class GLM4VPlugin(Qwen2VLPlugin): mm_inputs = self._get_mm_inputs(images, videos, audios, processor) image_grid_thw = mm_inputs.get("image_grid_thw", []) video_grid_thw = mm_inputs.get("video_grid_thw", []) - num_frames = len(video_grid_thw) + num_frames = video_grid_thw[0][0] if len(video_grid_thw) > 0 else 0 # hard code for now timestamps = mm_inputs.get("timestamps", []) + if hasattr(timestamps, "tolist"): timestamps = timestamps.tolist() @@ -1618,19 +1615,20 @@ class GLM4VPlugin(Qwen2VLPlugin): ) num_image_tokens += 1 - # TODO: DO NOT SUPPORT VIDEO UNTIL NEXT PR while VIDEO_PLACEHOLDER in content: video_structure = "" for frame_index in range(num_frames): - video_seqlen = video_grid_thw[frame_index].prod() // merge_length if self.expand_mm_tokens else 1 + video_seqlen = ( + video_grid_thw[num_video_tokens][1:].prod() // merge_length if self.expand_mm_tokens else 1 + ) timestamp_sec = selected_timestamps[frame_index] frame_structure = ( f"<|begin_of_image|>{self.image_token * video_seqlen}<|end_of_image|>{timestamp_sec}" ) video_structure += frame_structure - content = content.replace(VIDEO_PLACEHOLDER, video_structure, 1) - num_video_tokens += 1 # FIXME: num_video_tokens is not used + content = content.replace(VIDEO_PLACEHOLDER, f"<|begin_of_video|>{video_structure}<|end_of_video|>", 1) + num_video_tokens += 1 message["content"] = content