[data] Fix bugs of use_audio_in_video in Qwen2.5 Omni (#7638)

* cache _mm_inputs

* nit

* support for use_audio_in_video

* remove cache

* fix data

* Update mllm_video_audio_demo.json
This commit is contained in:
Kingsley 2025-04-08 18:40:10 +08:00 committed by GitHub
parent acb09fa3a3
commit 349c56c51c
5 changed files with 11 additions and 8 deletions

BIN
data/mllm_demo_data/4.mp3 Normal file

Binary file not shown.

BIN
data/mllm_demo_data/4.mp4 Normal file

Binary file not shown.

View File

@ -184,6 +184,7 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
rope_index_kwargs["second_per_grids"] = mm_inputs.get("video_second_per_grid")
if getattr(self.model.config, "model_type", None) == "qwen2_5_omni_thinker": # for qwen2omni
rope_index_kwargs["use_audio_in_video"] = getattr(self.processor, "use_audio_in_video", False)
feature_attention_mask = mm_inputs.get("feature_attention_mask", None)
if feature_attention_mask is not None:
audio_feature_lengths = torch.sum(

View File

@ -1378,6 +1378,7 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
else:
mm_inputs = {}
image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
num_audio_tokens, num_image_tokens, num_video_tokens = 0, 0, 0
use_audio_in_video = getattr(processor, "use_audio_in_video", False)
@ -1398,16 +1399,16 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
if audio_lengths is None:
raise ValueError("audio_lengths should exist when use_audio_in_video is `True`.")
if not mm_inputs.get("video_grid_thw", None):
if mm_inputs.get("video_grid_thw", None) is None:
raise ValueError("video_grid_thw should exist when use_audio_in_video is `True`.")
positions_list = []
for i, message in enumerate(messages): # get multimodal index when use_audio
for message in messages: # get multimodal index when use_audio
positions = []
for special_token in [self.audio_token, self.image_token, self.video_token]:
start = 0
while True:
pos = message[i].find(special_token, start)
pos = message["content"].find(special_token, start)
if pos == -1:
break
positions.append((pos, special_token))
@ -1453,8 +1454,8 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
.view(-1, 1, 1)
.expand(
-1,
video_grid_thw[num_video_tokens][1] // self.image_processor.merge_size,
video_grid_thw[num_video_tokens][2] // self.image_processor.merge_size,
video_grid_thw[num_video_tokens][1] // image_processor.merge_size,
video_grid_thw[num_video_tokens][2] // image_processor.merge_size,
)
.flatten()
* mm_inputs["video_second_per_grid"][num_video_tokens]
@ -1462,17 +1463,17 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
).long()
t_ntoken_per_chunk = 50 # FIXME hardcode: [25 * 2]
video_chunk_indices = processor.get_chunked_index(video_t_index, t_ntoken_per_chunk)
audio_chunk_indices = self.get_chunked_index(audio_t_index, t_ntoken_per_chunk)
audio_chunk_indices = processor.get_chunked_index(audio_t_index, t_ntoken_per_chunk)
placeholder_string = ""
placeholder_string += "<|vision_bos|>" + "<|audio_bos|>"
for j in range(max(len(video_chunk_indices), len(audio_chunk_indices))):
video_chunk_index = video_chunk_indices[j] if j < len(video_chunk_indices) else None
audio_chunk_index = audio_chunk_indices[j] if j < len(audio_chunk_indices) else None
placeholder_string = "<|vision_bos|>" + "<|audio_bos|>"
if video_chunk_index is not None:
placeholder_string += self.video_token * (video_chunk_index[1] - video_chunk_index[0])
if audio_chunk_index is not None:
placeholder_string += self.audio_token * (audio_chunk_index[1] - audio_chunk_index[0])
placeholder_string += "<|audio_eos|>" + "<|vision_eos|>"
placeholder_string += "<|audio_eos|>" + "<|vision_eos|>"
content = content.replace(VIDEO_PLACEHOLDER, placeholder_string, 1)
content = content.replace(AUDIO_PLACEHOLDER, "", 1)

View File

@ -79,6 +79,7 @@ def patch_processor(
setattr(processor, "video_fps", model_args.video_fps)
setattr(processor, "video_maxlen", model_args.video_maxlen)
setattr(processor, "audio_sampling_rate", model_args.audio_sampling_rate)
setattr(processor, "use_audio_in_video", model_args.use_audio_in_video)
def patch_config(