[data] Fix bugs of use_audio_in_video in Qwen2.5 Omni (#7638)

* cache _mm_inputs

* nit

* support for use_audio_in_video

* remove cache

* fix data

* Update mllm_video_audio_demo.json
Kingsley 2025-04-08 18:40:10 +08:00 committed by GitHub
parent 8f5f4cc559
commit 7d8bee96fc
7 changed files with 83 additions and 8 deletions

data/dataset_info.json

@@ -66,6 +66,21 @@
"assistant_tag": "assistant"
}
},
"mllm_video_audio_demo": {
"file_name": "mllm_video_audio_demo.json",
"formatting": "sharegpt",
"columns": {
"messages": "messages",
"videos": "videos",
"audios": "audios"
},
"tags": {
"role_tag": "role",
"content_tag": "content",
"user_tag": "user",
"assistant_tag": "assistant"
}
},
"alpaca_en": {
"hf_hub_url": "llamafactory/alpaca_en",
"ms_hub_url": "llamafactory/alpaca_en",

BIN
data/mllm_demo_data/4.mp3 Normal file

Binary file not shown.

BIN
data/mllm_demo_data/4.mp4 Normal file

Binary file not shown.

data/mllm_video_audio_demo.json

@@ -0,0 +1,57 @@
[
{
"messages": [
{
"content": "<video><audio>What is the video describing?",
"role": "user"
},
{
"content": "A girl who is drawing a picture of a guitar and feel nervous.",
"role": "assistant"
}
],
"videos": [
"mllm_demo_data/4.mp4"
],
"audios": [
"mllm_demo_data/4.mp3"
]
},
{
"messages": [
{
"content": "<video><audio>What does this girl say?",
"role": "user"
},
{
"content": "She says: 'Hello! Take a look at what am I drawing!'",
"role": "assistant"
}
],
"videos": [
"mllm_demo_data/4.mp4"
],
"audios": [
"mllm_demo_data/4.mp3"
]
},
{
"messages": [
{
"content": "<video><audio>What is this girl drawing with?",
"role": "user"
},
{
"content": "She is drawing with an iPad.",
"role": "assistant"
}
],
"videos": [
"mllm_demo_data/4.mp4"
],
"audios": [
"mllm_demo_data/4.mp3"
]
}
]
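
Each record in the new demo file pairs one <video> and one <audio> placeholder in the user turn with exactly one path in the "videos" and "audios" lists, both pointing at the same clip (4.mp4 plus the matching 4.mp3 audio). A quick consistency check, sketched here under the assumption that the file sits at data/mllm_video_audio_demo.json relative to the repository root:

import json

# Sanity-check the sharegpt-style records: every <video>/<audio> tag in the
# user turns must correspond to one entry in "videos"/"audios".
# The path below is an assumption; adjust it to your local checkout.
with open("data/mllm_video_audio_demo.json", encoding="utf-8") as f:
    samples = json.load(f)

for idx, sample in enumerate(samples):
    user_text = "".join(m["content"] for m in sample["messages"] if m["role"] == "user")
    assert user_text.count("<video>") == len(sample.get("videos", [])), f"sample {idx}: video count mismatch"
    assert user_text.count("<audio>") == len(sample.get("audios", [])), f"sample {idx}: audio count mismatch"

print(f"checked {len(samples)} samples")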

src/llamafactory/data/collator.py

@@ -184,6 +184,7 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
rope_index_kwargs["second_per_grids"] = mm_inputs.get("video_second_per_grid")
if getattr(self.model.config, "model_type", None) == "qwen2_5_omni_thinker": # for qwen2omni
rope_index_kwargs["use_audio_in_video"] = getattr(self.processor, "use_audio_in_video", False)
feature_attention_mask = mm_inputs.get("feature_attention_mask", None)
if feature_attention_mask is not None:
audio_feature_lengths = torch.sum(
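
For orientation, the collator change above boils down to: read the flag that patch_processor (last hunk of this commit) stored on the processor, and pass it to the Qwen2.5 Omni rope-index computation together with the per-sample audio lengths derived from feature_attention_mask. A self-contained sketch with dummy inputs; in the real collator these come from the batch and the HF processor:

import torch

# Stand-in for the HF processor patched by patch_processor().
class _Processor:
    pass

processor = _Processor()
setattr(processor, "use_audio_in_video", True)  # mirrors model_args.use_audio_in_video

mm_inputs = {
    # two samples, 8 audio feature frames each; the second one has 2 padding frames
    "feature_attention_mask": torch.tensor([[1] * 8, [1] * 6 + [0] * 2]),
}

rope_index_kwargs = {"use_audio_in_video": getattr(processor, "use_audio_in_video", False)}
feature_attention_mask = mm_inputs.get("feature_attention_mask", None)
if feature_attention_mask is not None:
    # number of valid audio frames per sample, consumed later by the rope-index computation
    audio_feature_lengths = torch.sum(feature_attention_mask, dim=1)

print(rope_index_kwargs)        # {'use_audio_in_video': True}
print(audio_feature_lengths)    # tensor([8, 6])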

src/llamafactory/data/mm_plugin.py

@@ -1378,6 +1378,7 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
else:
mm_inputs = {}
image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
num_audio_tokens, num_image_tokens, num_video_tokens = 0, 0, 0
+use_audio_in_video = getattr(processor, "use_audio_in_video", False)
@@ -1398,16 +1399,16 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
if audio_lengths is None:
raise ValueError("audio_lengths should exist when use_audio_in_video is `True`.")
-if not mm_inputs.get("video_grid_thw", None):
+if mm_inputs.get("video_grid_thw", None) is None:
raise ValueError("video_grid_thw should exist when use_audio_in_video is `True`.")
positions_list = []
-for i, message in enumerate(messages): # get multimodal index when use_audio
+for message in messages: # get multimodal index when use_audio
positions = []
for special_token in [self.audio_token, self.image_token, self.video_token]:
start = 0
while True:
-pos = message[i].find(special_token, start)
+pos = message["content"].find(special_token, start)
if pos == -1:
break
positions.append((pos, special_token))
@@ -1453,8 +1454,8 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
.view(-1, 1, 1)
.expand(
-1,
-video_grid_thw[num_video_tokens][1] // self.image_processor.merge_size,
-video_grid_thw[num_video_tokens][2] // self.image_processor.merge_size,
+video_grid_thw[num_video_tokens][1] // image_processor.merge_size,
+video_grid_thw[num_video_tokens][2] // image_processor.merge_size,
)
.flatten()
* mm_inputs["video_second_per_grid"][num_video_tokens]
@@ -1462,17 +1463,17 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
).long()
t_ntoken_per_chunk = 50 # FIXME hardcode: [25 * 2]
video_chunk_indices = processor.get_chunked_index(video_t_index, t_ntoken_per_chunk)
-audio_chunk_indices = self.get_chunked_index(audio_t_index, t_ntoken_per_chunk)
+audio_chunk_indices = processor.get_chunked_index(audio_t_index, t_ntoken_per_chunk)
placeholder_string = ""
+placeholder_string += "<|vision_bos|>" + "<|audio_bos|>"
for j in range(max(len(video_chunk_indices), len(audio_chunk_indices))):
video_chunk_index = video_chunk_indices[j] if j < len(video_chunk_indices) else None
audio_chunk_index = audio_chunk_indices[j] if j < len(audio_chunk_indices) else None
placeholder_string = "<|vision_bos|>" + "<|audio_bos|>"
if video_chunk_index is not None:
placeholder_string += self.video_token * (video_chunk_index[1] - video_chunk_index[0])
if audio_chunk_index is not None:
placeholder_string += self.audio_token * (audio_chunk_index[1] - audio_chunk_index[0])
placeholder_string += "<|audio_eos|>" + "<|vision_eos|>"
placeholder_string += "<|audio_eos|>" + "<|vision_eos|>"
content = content.replace(VIDEO_PLACEHOLDER, placeholder_string, 1)
content = content.replace(AUDIO_PLACEHOLDER, "", 1)
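
The net effect of the mm_plugin changes: with use_audio_in_video enabled, one VIDEO_PLACEHOLDER expands into interleaved chunks of video and audio tokens, grouped by the 50-position time window (t_ntoken_per_chunk), and the <|vision_bos|><|audio_bos|> / <|audio_eos|><|vision_eos|> markers now wrap the whole interleaved block exactly once instead of being rebuilt inside the loop. Below is a self-contained sketch of that interleaving; chunked_index is a simplified stand-in for processor.get_chunked_index, the time indices are made up, and <VIDEO>/<AUDIO> stand in for the model's actual video/audio tokens:

from typing import List, Tuple

def chunked_index(t_index: List[int], tokens_per_chunk: int) -> List[Tuple[int, int]]:
    # Split a non-decreasing per-token time index into (start, end) slices,
    # one slice per window of `tokens_per_chunk` time units.
    chunks, start = [], 0
    for end in range(1, len(t_index) + 1):
        if end == len(t_index) or t_index[end] // tokens_per_chunk != t_index[start] // tokens_per_chunk:
            chunks.append((start, end))
            start = end
    return chunks

# Toy time indices (roughly 25 positions per second): 6 video tokens, 5 audio tokens.
video_t_index = [0, 0, 30, 30, 60, 60]
audio_t_index = [0, 20, 40, 60, 80]
t_ntoken_per_chunk = 50  # mirrors the hardcoded value in the plugin

video_chunks = chunked_index(video_t_index, t_ntoken_per_chunk)
audio_chunks = chunked_index(audio_t_index, t_ntoken_per_chunk)

placeholder = "<|vision_bos|>" + "<|audio_bos|>"  # once, before the loop
for j in range(max(len(video_chunks), len(audio_chunks))):
    if j < len(video_chunks):  # video tokens that fall in time chunk j
        start, end = video_chunks[j]
        placeholder += "<VIDEO>" * (end - start)
    if j < len(audio_chunks):  # audio tokens that fall in time chunk j
        start, end = audio_chunks[j]
        placeholder += "<AUDIO>" * (end - start)
placeholder += "<|audio_eos|>" + "<|vision_eos|>"  # once, after the loop

print(placeholder)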

src/llamafactory/model/patcher.py

@@ -79,6 +79,7 @@ def patch_processor(
setattr(processor, "video_fps", model_args.video_fps)
setattr(processor, "video_maxlen", model_args.video_maxlen)
setattr(processor, "audio_sampling_rate", model_args.audio_sampling_rate)
setattr(processor, "use_audio_in_video", model_args.use_audio_in_video)
def patch_config(