mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-01 11:12:50 +08:00
[data] support glm4.1v video training (#8571)
This commit is contained in:
parent
6a8d88826e
commit
766884fa5c
@ -1558,11 +1558,7 @@ class GLM4VPlugin(Qwen2VLPlugin):
|
|||||||
video_metadata = [
|
video_metadata = [
|
||||||
{"fps": 2, "duration": len(video), "total_frames": len(video)} for video in video_data["videos"]
|
{"fps": 2, "duration": len(video), "total_frames": len(video)} for video in video_data["videos"]
|
||||||
]
|
]
|
||||||
mm_inputs.update(
|
mm_inputs.update(video_processor(images=None, videos=video_data["videos"], video_metadata=video_metadata))
|
||||||
video_processor(
|
|
||||||
images=None, videos=video_data["videos"], video_metadata=video_metadata, return_tensors="pt"
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
return mm_inputs
|
return mm_inputs
|
||||||
|
|
||||||
@ -1586,8 +1582,9 @@ class GLM4VPlugin(Qwen2VLPlugin):
|
|||||||
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
|
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
|
||||||
image_grid_thw = mm_inputs.get("image_grid_thw", [])
|
image_grid_thw = mm_inputs.get("image_grid_thw", [])
|
||||||
video_grid_thw = mm_inputs.get("video_grid_thw", [])
|
video_grid_thw = mm_inputs.get("video_grid_thw", [])
|
||||||
num_frames = len(video_grid_thw)
|
num_frames = video_grid_thw[0][0] if len(video_grid_thw) > 0 else 0 # hard code for now
|
||||||
timestamps = mm_inputs.get("timestamps", [])
|
timestamps = mm_inputs.get("timestamps", [])
|
||||||
|
|
||||||
if hasattr(timestamps, "tolist"):
|
if hasattr(timestamps, "tolist"):
|
||||||
timestamps = timestamps.tolist()
|
timestamps = timestamps.tolist()
|
||||||
|
|
||||||
@ -1618,19 +1615,20 @@ class GLM4VPlugin(Qwen2VLPlugin):
|
|||||||
)
|
)
|
||||||
num_image_tokens += 1
|
num_image_tokens += 1
|
||||||
|
|
||||||
# TODO: DO NOT SUPPORT VIDEO UNTIL NEXT PR
|
|
||||||
while VIDEO_PLACEHOLDER in content:
|
while VIDEO_PLACEHOLDER in content:
|
||||||
video_structure = ""
|
video_structure = ""
|
||||||
for frame_index in range(num_frames):
|
for frame_index in range(num_frames):
|
||||||
video_seqlen = video_grid_thw[frame_index].prod() // merge_length if self.expand_mm_tokens else 1
|
video_seqlen = (
|
||||||
|
video_grid_thw[num_video_tokens][1:].prod() // merge_length if self.expand_mm_tokens else 1
|
||||||
|
)
|
||||||
timestamp_sec = selected_timestamps[frame_index]
|
timestamp_sec = selected_timestamps[frame_index]
|
||||||
frame_structure = (
|
frame_structure = (
|
||||||
f"<|begin_of_image|>{self.image_token * video_seqlen}<|end_of_image|>{timestamp_sec}"
|
f"<|begin_of_image|>{self.image_token * video_seqlen}<|end_of_image|>{timestamp_sec}"
|
||||||
)
|
)
|
||||||
video_structure += frame_structure
|
video_structure += frame_structure
|
||||||
|
|
||||||
content = content.replace(VIDEO_PLACEHOLDER, video_structure, 1)
|
content = content.replace(VIDEO_PLACEHOLDER, f"<|begin_of_video|>{video_structure}<|end_of_video|>", 1)
|
||||||
num_video_tokens += 1 # FIXME: num_video_tokens is not used
|
num_video_tokens += 1
|
||||||
|
|
||||||
message["content"] = content
|
message["content"] = content
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user