diff --git a/src/llamafactory/data/mm_plugin.py b/src/llamafactory/data/mm_plugin.py
index b653f57b..e9df7e33 100644
--- a/src/llamafactory/data/mm_plugin.py
+++ b/src/llamafactory/data/mm_plugin.py
@@ -1621,30 +1621,8 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
             video_grid_thw = [None] * len(videos)
             audio_lengths = [None] * len(audios)
 
-        if self.expand_mm_tokens and use_audio_in_video:
-            if "feature_attention_mask" not in mm_inputs:
-                raise ValueError("audio_lengths should exist when use_audio_in_video is `True`.")
-
-            if "video_grid_thw" not in mm_inputs:
-                raise ValueError("video_grid_thw should exist when use_audio_in_video is `True`.")
-
-            positions_list = []
-            for message in messages:  # get multimodal index when use_audio
-                positions = []
-                for special_token in [self.audio_token, self.image_token, self.video_token]:
-                    start = 0
-                    while True:
-                        pos = message["content"].find(special_token, start)
-                        if pos == -1:
-                            break
-                        positions.append((pos, special_token))
-                        start = pos + len(special_token)
-
-                positions_list.append(positions.sort(key=lambda x: x[0]))
-
         for message in messages:
             content = message["content"]
-            # separate with audio-video
             while IMAGE_PLACEHOLDER in content:
                 if num_image_tokens >= len(images):
                     raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
@@ -1655,34 +1633,26 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
                 )
                 num_image_tokens += 1
 
-            if not use_audio_in_video:
-                while AUDIO_PLACEHOLDER in content:
+            if (
+                use_audio_in_video and len(audios) and len(videos)
+            ):  # if use the audio of video # deal video token and audio token togather
+                if len(videos) != len(audios):
+                    raise ValueError(
+                        f"Number of videos ({len(videos)}) must match number of audios ({len(audios)}) when using audio in video."
+                    )
+
+                while VIDEO_PLACEHOLDER in content:
+                    if num_video_tokens >= len(videos):
+                        raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
                     if num_audio_tokens >= len(audios):
                         raise ValueError(f"`len(audios)` is less than the number of {AUDIO_PLACEHOLDER} tokens.")
 
-                    audio_seqlen = audio_lengths[num_audio_tokens] if self.expand_mm_tokens else 1
-                    content = content.replace(
-                        AUDIO_PLACEHOLDER, f"<|audio_bos|>{self.audio_token * audio_seqlen}<|audio_eos|>", 1
-                    )
-                    num_audio_tokens += 1
-
-                # TODO handle video_input and use_audio_in_video
-                while VIDEO_PLACEHOLDER in content:
-                    if num_video_tokens >= len(videos):
-                        raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
-
-                    video_seqlen = (
-                        video_grid_thw[num_video_tokens].prod() // merge_length if self.expand_mm_tokens else 1
-                    )
-                    content = content.replace(
-                        VIDEO_PLACEHOLDER, f"<|vision_bos|>{self.video_token * video_seqlen}<|vision_eos|>", 1
-                    )
-                    num_video_tokens += 1
-
-            else:  # if use the audio of video # deal video token and audio token togather
-                while VIDEO_PLACEHOLDER in content:
-                    if num_video_tokens >= len(videos):
-                        raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
+                    video_pos = content.find(VIDEO_PLACEHOLDER)
+                    audio_pos = content.find(AUDIO_PLACEHOLDER, video_pos)
+                    if audio_pos == -1 or audio_pos < video_pos:
+                        raise ValueError(
+                            f"Each {VIDEO_PLACEHOLDER} must be followed by an {AUDIO_PLACEHOLDER} when using audio in video."
+                        )
 
                     audio_t_index = torch.arange(audio_lengths[num_audio_tokens])
                     video_t_index = (
@@ -1716,6 +1686,28 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
                     content = content.replace(AUDIO_PLACEHOLDER, "", 1)
                     num_audio_tokens += 1
                     num_video_tokens += 1
+            else:
+                while AUDIO_PLACEHOLDER in content:
+                    if num_audio_tokens >= len(audios):
+                        raise ValueError(f"`len(audios)` is less than the number of {AUDIO_PLACEHOLDER} tokens.")
+
+                    audio_seqlen = audio_lengths[num_audio_tokens] if self.expand_mm_tokens else 1
+                    content = content.replace(
+                        AUDIO_PLACEHOLDER, f"<|audio_bos|>{self.audio_token * audio_seqlen}<|audio_eos|>", 1
+                    )
+                    num_audio_tokens += 1
+
+                while VIDEO_PLACEHOLDER in content:
+                    if num_video_tokens >= len(videos):
+                        raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
+
+                    video_seqlen = (
+                        video_grid_thw[num_video_tokens].prod() // merge_length if self.expand_mm_tokens else 1
+                    )
+                    content = content.replace(
+                        VIDEO_PLACEHOLDER, f"<|vision_bos|>{self.video_token * video_seqlen}<|vision_eos|>", 1
+                    )
+                    num_video_tokens += 1
 
             message["content"] = content
 
diff --git a/tests/data/test_mm_plugin.py b/tests/data/test_mm_plugin.py
index a98e8f17..a19de12a 100644
--- a/tests/data/test_mm_plugin.py
+++ b/tests/data/test_mm_plugin.py
@@ -15,6 +15,7 @@
 import os
 from typing import TYPE_CHECKING, Any
 
+import numpy as np
 import pytest
 import torch
 from PIL import Image
@@ -43,11 +44,20 @@ MM_MESSAGES = [
     {"role": "assistant", "content": "A cat."},
 ]
 
+OMNI_MESSAGES = [
+    {"role": "user", "content": "What is in this image?"},
+    {"role": "assistant", "content": "A cat."},
+    {"role": "user", "content": "