Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-08-23 06:12:50 +08:00)

[data] fix qwen2 omni plugin (#7875)

commit 00b5c05946
parent 1bd319d16c
@@ -1602,34 +1602,30 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
         processor: Optional["MMProcessor"],
     ) -> list[dict[str, str]]:
         self._validate_input(processor, images, videos, audios)
+        num_image_tokens, num_video_tokens, num_audio_tokens = 0, 0, 0
         messages = deepcopy(messages)
+        image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
+
+        merge_length = processor.image_processor.merge_size**2
+        use_audio_in_video = getattr(processor, "use_audio_in_video", False)
         if self.expand_mm_tokens:
             mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
+            image_grid_thw = mm_inputs.get("image_grid_thw", [])
+            video_grid_thw = mm_inputs.get("video_grid_thw", [])
+            if "feature_attention_mask" in mm_inputs:
+                input_lengths = (mm_inputs["feature_attention_mask"].sum(-1).numpy() - 1) // 2 + 1
+                audio_lengths = (input_lengths - 2) // 2 + 1
         else:
             mm_inputs = {}
+            image_grid_thw = [None] * len(images)
+            video_grid_thw = [None] * len(videos)
+            audio_lengths = [None] * len(audios)
 
-        image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
-        num_audio_tokens, num_image_tokens, num_video_tokens = 0, 0, 0
-        use_audio_in_video = getattr(processor, "use_audio_in_video", False)
-
-        # get length or size from mm_inputs
-        if "feature_attention_mask" in mm_inputs:
-            input_lengths = (mm_inputs["feature_attention_mask"].sum(-1).numpy() - 1) // 2 + 1
-            audio_lengths = (input_lengths - 2) // 2 + 1
-
-        if mm_inputs.get("image_grid_thw", None) is not None:
-            image_grid_thw = mm_inputs["image_grid_thw"]
-            merge_length = processor.image_processor.merge_size**2
-
-        if mm_inputs.get("video_grid_thw", None) is not None:
-            video_grid_thw = mm_inputs["video_grid_thw"]
-            merge_length = processor.image_processor.merge_size**2
-
-        if use_audio_in_video:
-            if audio_lengths is None:
+        if self.expand_mm_tokens and use_audio_in_video:
+            if "feature_attention_mask" not in mm_inputs:
                 raise ValueError("audio_lengths should exist when use_audio_in_video is `True`.")
 
-            if mm_inputs.get("video_grid_thw", None) is None:
+            if "video_grid_thw" not in mm_inputs:
                 raise ValueError("video_grid_thw should exist when use_audio_in_video is `True`.")
 
         positions_list = []
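For reference, the audio-length arithmetic kept inside the new `expand_mm_tokens` branch can be reproduced standalone: the valid feature frames are counted from the attention mask and downsampled twice to get the number of audio placeholder tokens per clip. A minimal sketch with a made-up mask, using NumPy in place of the torch tensor the processor actually returns:

import numpy as np

# Toy stand-in for mm_inputs["feature_attention_mask"]: two audio clips padded
# to 12 mel frames, with 12 and 8 valid frames respectively (values made up).
feature_attention_mask = np.array([[1] * 12, [1] * 8 + [0] * 4])

# Same arithmetic as the hunk above; each step roughly halves the length
# (the exact downsampling modules are an assumption, not stated in the diff).
input_lengths = (feature_attention_mask.sum(-1) - 1) // 2 + 1
audio_lengths = (input_lengths - 2) // 2 + 1

print(input_lengths)  # [6 4]
print(audio_lengths)  # [3 2]  -> audio tokens per clip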
@@ -1653,11 +1649,9 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
                 if num_image_tokens >= len(images):
                     raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
 
-                image_token_replace_length = image_grid_thw[num_image_tokens].prod() // merge_length
+                image_seqlen = image_grid_thw[num_image_tokens].prod() // merge_length if self.expand_mm_tokens else 1
                 content = content.replace(
-                    IMAGE_PLACEHOLDER,
-                    f"<|vision_bos|>{self.image_token * image_token_replace_length}<|vision_eos|>",
-                    1,
+                    IMAGE_PLACEHOLDER, f"<|vision_bos|>{self.image_token * image_seqlen}<|vision_eos|>", 1
                 )
                 num_image_tokens += 1
 
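The image hunk above computes the expansion length as the patch-grid product divided by `merge_size**2`; the video hunk further below does the same with `video_grid_thw`. A minimal sketch of that arithmetic with made-up grid values (the grid and merge size are illustrative, not read from a real Qwen2.5-Omni processor):

import numpy as np

# Hypothetical (t, h, w) patch grid for one image and a typical merge size of 2.
image_grid_thw = np.array([1, 16, 16])
merge_length = 2**2

# Same expression as `image_grid_thw[num_image_tokens].prod() // merge_length`.
image_seqlen = image_grid_thw.prod() // merge_length
print(image_seqlen)  # 64 image tokens for this image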
@@ -1666,11 +1660,9 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
                 if num_audio_tokens >= len(audios):
                     raise ValueError(f"`len(audios)` is less than the number of {AUDIO_PLACEHOLDER} tokens.")
 
-                audio_token_replace_length = audio_lengths[num_audio_tokens]
+                audio_seqlen = audio_lengths[num_audio_tokens] if self.expand_mm_tokens else 1
                 content = content.replace(
-                    AUDIO_PLACEHOLDER,
-                    f"<|audio_bos|>{self.audio_token * audio_token_replace_length}<|audio_eos|>",
-                    1,
+                    AUDIO_PLACEHOLDER, f"<|audio_bos|>{self.audio_token * audio_seqlen}<|audio_eos|>", 1
                 )
                 num_audio_tokens += 1
 
@@ -1679,9 +1671,9 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
                 if num_video_tokens >= len(videos):
                     raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
 
-                video_replace_length = video_grid_thw[num_video_tokens].prod() // merge_length
+                video_seqlen = video_grid_thw[num_video_tokens].prod() // merge_length if self.expand_mm_tokens else 1
                 content = content.replace(
-                    VIDEO_PLACEHOLDER, f"<|vision_bos|>{self.video_token * video_replace_length}<|vision_eos|>", 1
+                    VIDEO_PLACEHOLDER, f"<|vision_bos|>{self.video_token * video_seqlen}<|vision_eos|>", 1
                 )
                 num_video_tokens += 1
 
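Taken together, the three replacement hunks expand each placeholder into a run of modality tokens wrapped in bos/eos markers, or into a single token when `expand_mm_tokens` is disabled. A toy end-to-end sketch of that string replacement; the placeholder and token strings below are illustrative, not read from an actual tokenizer:

# Stand-ins for the plugin's constants and attributes (illustrative values).
IMAGE_PLACEHOLDER = "<image>"
image_token = "<|IMAGE|>"
expand_mm_tokens = True
image_seqlen = 4 if expand_mm_tokens else 1  # would come from the grid arithmetic above

content = f"{IMAGE_PLACEHOLDER}Describe the picture."
content = content.replace(
    IMAGE_PLACEHOLDER, f"<|vision_bos|>{image_token * image_seqlen}<|vision_eos|>", 1
)
print(content)
# <|vision_bos|><|IMAGE|><|IMAGE|><|IMAGE|><|IMAGE|><|vision_eos|>Describe the picture.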