[data] fix qwen2 omni plugin (#7875)

This commit is contained in:
hoshi-hiyouga 2025-04-28 14:22:41 +08:00 committed by GitHub
parent 1bd319d16c
commit 00b5c05946
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1602,34 +1602,30 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
processor: Optional["MMProcessor"],
) -> list[dict[str, str]]:
self._validate_input(processor, images, videos, audios)
num_image_tokens, num_video_tokens, num_audio_tokens = 0, 0, 0
messages = deepcopy(messages)
image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
merge_length = processor.image_processor.merge_size**2
use_audio_in_video = getattr(processor, "use_audio_in_video", False)
if self.expand_mm_tokens:
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
image_grid_thw = mm_inputs.get("image_grid_thw", [])
video_grid_thw = mm_inputs.get("video_grid_thw", [])
if "feature_attention_mask" in mm_inputs:
input_lengths = (mm_inputs["feature_attention_mask"].sum(-1).numpy() - 1) // 2 + 1
audio_lengths = (input_lengths - 2) // 2 + 1
else:
mm_inputs = {}
image_grid_thw = [None] * len(images)
video_grid_thw = [None] * len(videos)
audio_lengths = [None] * len(audios)
image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
num_audio_tokens, num_image_tokens, num_video_tokens = 0, 0, 0
use_audio_in_video = getattr(processor, "use_audio_in_video", False)
# get length or size from mm_inputs
if "feature_attention_mask" in mm_inputs:
input_lengths = (mm_inputs["feature_attention_mask"].sum(-1).numpy() - 1) // 2 + 1
audio_lengths = (input_lengths - 2) // 2 + 1
if mm_inputs.get("image_grid_thw", None) is not None:
image_grid_thw = mm_inputs["image_grid_thw"]
merge_length = processor.image_processor.merge_size**2
if mm_inputs.get("video_grid_thw", None) is not None:
video_grid_thw = mm_inputs["video_grid_thw"]
merge_length = processor.image_processor.merge_size**2
if use_audio_in_video:
if audio_lengths is None:
if self.expand_mm_tokens and use_audio_in_video:
if "feature_attention_mask" not in mm_inputs:
raise ValueError("audio_lengths should exist when use_audio_in_video is `True`.")
if mm_inputs.get("video_grid_thw", None) is None:
if "video_grid_thw" not in mm_inputs:
raise ValueError("video_grid_thw should exist when use_audio_in_video is `True`.")
positions_list = []
@ -1653,11 +1649,9 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
if num_image_tokens >= len(images):
raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
image_token_replace_length = image_grid_thw[num_image_tokens].prod() // merge_length
image_seqlen = image_grid_thw[num_image_tokens].prod() // merge_length if self.expand_mm_tokens else 1
content = content.replace(
IMAGE_PLACEHOLDER,
f"<|vision_bos|>{self.image_token * image_token_replace_length}<|vision_eos|>",
1,
IMAGE_PLACEHOLDER, f"<|vision_bos|>{self.image_token * image_seqlen}<|vision_eos|>", 1
)
num_image_tokens += 1
@ -1666,11 +1660,9 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
if num_audio_tokens >= len(audios):
raise ValueError(f"`len(audios)` is less than the number of {AUDIO_PLACEHOLDER} tokens.")
audio_token_replace_length = audio_lengths[num_audio_tokens]
audio_seqlen = audio_lengths[num_audio_tokens] if self.expand_mm_tokens else 1
content = content.replace(
AUDIO_PLACEHOLDER,
f"<|audio_bos|>{self.audio_token * audio_token_replace_length}<|audio_eos|>",
1,
AUDIO_PLACEHOLDER, f"<|audio_bos|>{self.audio_token * audio_seqlen}<|audio_eos|>", 1
)
num_audio_tokens += 1
@ -1679,9 +1671,9 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
if num_video_tokens >= len(videos):
raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
video_replace_length = video_grid_thw[num_video_tokens].prod() // merge_length
video_seqlen = video_grid_thw[num_video_tokens].prod() // merge_length if self.expand_mm_tokens else 1
content = content.replace(
VIDEO_PLACEHOLDER, f"<|vision_bos|>{self.video_token * video_replace_length}<|vision_eos|>", 1
VIDEO_PLACEHOLDER, f"<|vision_bos|>{self.video_token * video_seqlen}<|vision_eos|>", 1
)
num_video_tokens += 1