mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-23 06:12:50 +08:00
[data] fix qwen2 omni plugin (#7875)
This commit is contained in:
parent
1bd319d16c
commit
00b5c05946
@ -1602,34 +1602,30 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
|
|||||||
processor: Optional["MMProcessor"],
|
processor: Optional["MMProcessor"],
|
||||||
) -> list[dict[str, str]]:
|
) -> list[dict[str, str]]:
|
||||||
self._validate_input(processor, images, videos, audios)
|
self._validate_input(processor, images, videos, audios)
|
||||||
|
num_image_tokens, num_video_tokens, num_audio_tokens = 0, 0, 0
|
||||||
messages = deepcopy(messages)
|
messages = deepcopy(messages)
|
||||||
|
image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
|
||||||
|
|
||||||
|
merge_length = processor.image_processor.merge_size**2
|
||||||
|
use_audio_in_video = getattr(processor, "use_audio_in_video", False)
|
||||||
if self.expand_mm_tokens:
|
if self.expand_mm_tokens:
|
||||||
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
|
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
|
||||||
|
image_grid_thw = mm_inputs.get("image_grid_thw", [])
|
||||||
|
video_grid_thw = mm_inputs.get("video_grid_thw", [])
|
||||||
|
if "feature_attention_mask" in mm_inputs:
|
||||||
|
input_lengths = (mm_inputs["feature_attention_mask"].sum(-1).numpy() - 1) // 2 + 1
|
||||||
|
audio_lengths = (input_lengths - 2) // 2 + 1
|
||||||
else:
|
else:
|
||||||
mm_inputs = {}
|
mm_inputs = {}
|
||||||
|
image_grid_thw = [None] * len(images)
|
||||||
|
video_grid_thw = [None] * len(videos)
|
||||||
|
audio_lengths = [None] * len(audios)
|
||||||
|
|
||||||
image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
|
if self.expand_mm_tokens and use_audio_in_video:
|
||||||
num_audio_tokens, num_image_tokens, num_video_tokens = 0, 0, 0
|
if "feature_attention_mask" not in mm_inputs:
|
||||||
use_audio_in_video = getattr(processor, "use_audio_in_video", False)
|
|
||||||
|
|
||||||
# get length or size from mm_inputs
|
|
||||||
if "feature_attention_mask" in mm_inputs:
|
|
||||||
input_lengths = (mm_inputs["feature_attention_mask"].sum(-1).numpy() - 1) // 2 + 1
|
|
||||||
audio_lengths = (input_lengths - 2) // 2 + 1
|
|
||||||
|
|
||||||
if mm_inputs.get("image_grid_thw", None) is not None:
|
|
||||||
image_grid_thw = mm_inputs["image_grid_thw"]
|
|
||||||
merge_length = processor.image_processor.merge_size**2
|
|
||||||
|
|
||||||
if mm_inputs.get("video_grid_thw", None) is not None:
|
|
||||||
video_grid_thw = mm_inputs["video_grid_thw"]
|
|
||||||
merge_length = processor.image_processor.merge_size**2
|
|
||||||
|
|
||||||
if use_audio_in_video:
|
|
||||||
if audio_lengths is None:
|
|
||||||
raise ValueError("audio_lengths should exist when use_audio_in_video is `True`.")
|
raise ValueError("audio_lengths should exist when use_audio_in_video is `True`.")
|
||||||
|
|
||||||
if mm_inputs.get("video_grid_thw", None) is None:
|
if "video_grid_thw" not in mm_inputs:
|
||||||
raise ValueError("video_grid_thw should exist when use_audio_in_video is `True`.")
|
raise ValueError("video_grid_thw should exist when use_audio_in_video is `True`.")
|
||||||
|
|
||||||
positions_list = []
|
positions_list = []
|
||||||
@ -1653,11 +1649,9 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
|
|||||||
if num_image_tokens >= len(images):
|
if num_image_tokens >= len(images):
|
||||||
raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
|
raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
|
||||||
|
|
||||||
image_token_replace_length = image_grid_thw[num_image_tokens].prod() // merge_length
|
image_seqlen = image_grid_thw[num_image_tokens].prod() // merge_length if self.expand_mm_tokens else 1
|
||||||
content = content.replace(
|
content = content.replace(
|
||||||
IMAGE_PLACEHOLDER,
|
IMAGE_PLACEHOLDER, f"<|vision_bos|>{self.image_token * image_seqlen}<|vision_eos|>", 1
|
||||||
f"<|vision_bos|>{self.image_token * image_token_replace_length}<|vision_eos|>",
|
|
||||||
1,
|
|
||||||
)
|
)
|
||||||
num_image_tokens += 1
|
num_image_tokens += 1
|
||||||
|
|
||||||
@ -1666,11 +1660,9 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
|
|||||||
if num_audio_tokens >= len(audios):
|
if num_audio_tokens >= len(audios):
|
||||||
raise ValueError(f"`len(audios)` is less than the number of {AUDIO_PLACEHOLDER} tokens.")
|
raise ValueError(f"`len(audios)` is less than the number of {AUDIO_PLACEHOLDER} tokens.")
|
||||||
|
|
||||||
audio_token_replace_length = audio_lengths[num_audio_tokens]
|
audio_seqlen = audio_lengths[num_audio_tokens] if self.expand_mm_tokens else 1
|
||||||
content = content.replace(
|
content = content.replace(
|
||||||
AUDIO_PLACEHOLDER,
|
AUDIO_PLACEHOLDER, f"<|audio_bos|>{self.audio_token * audio_seqlen}<|audio_eos|>", 1
|
||||||
f"<|audio_bos|>{self.audio_token * audio_token_replace_length}<|audio_eos|>",
|
|
||||||
1,
|
|
||||||
)
|
)
|
||||||
num_audio_tokens += 1
|
num_audio_tokens += 1
|
||||||
|
|
||||||
@ -1679,9 +1671,9 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
|
|||||||
if num_video_tokens >= len(videos):
|
if num_video_tokens >= len(videos):
|
||||||
raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
|
raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
|
||||||
|
|
||||||
video_replace_length = video_grid_thw[num_video_tokens].prod() // merge_length
|
video_seqlen = video_grid_thw[num_video_tokens].prod() // merge_length if self.expand_mm_tokens else 1
|
||||||
content = content.replace(
|
content = content.replace(
|
||||||
VIDEO_PLACEHOLDER, f"<|vision_bos|>{self.video_token * video_replace_length}<|vision_eos|>", 1
|
VIDEO_PLACEHOLDER, f"<|vision_bos|>{self.video_token * video_seqlen}<|vision_eos|>", 1
|
||||||
)
|
)
|
||||||
num_video_tokens += 1
|
num_video_tokens += 1
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user