Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-07-31 10:42:50 +08:00)

[data] Fix bugs of use_audio_in_video in Qwen2.5 Omni (#7638)
* cache _mm_inputs
* nit
* support for use_audio_in_video
* remove cache
* fix data
* Update mllm_video_audio_demo.json
parent 8f5f4cc559
commit 7d8bee96fc
data/dataset_info.json
@@ -66,6 +66,21 @@
       "assistant_tag": "assistant"
     }
   },
+  "mllm_video_audio_demo": {
+    "file_name": "mllm_video_audio_demo.json",
+    "formatting": "sharegpt",
+    "columns": {
+      "messages": "messages",
+      "videos": "videos",
+      "audios": "audios"
+    },
+    "tags": {
+      "role_tag": "role",
+      "content_tag": "content",
+      "user_tag": "user",
+      "assistant_tag": "assistant"
+    }
+  },
   "alpaca_en": {
     "hf_hub_url": "llamafactory/alpaca_en",
     "ms_hub_url": "llamafactory/alpaca_en",
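The entry above registers the demo dataset in data/dataset_info.json with the sharegpt format and explicit videos/audios columns. As a rough illustration of what the "columns" mapping accomplishes, here is a minimal, self-contained sketch; the load_demo_dataset helper is hypothetical, not LLaMA-Factory's actual loading path:

import json

# Sketch: a dataset_info-style entry maps raw record keys to the canonical
# column names expected downstream (messages/videos/audios).
DATASET_INFO = {
    "mllm_video_audio_demo": {
        "file_name": "mllm_video_audio_demo.json",
        "formatting": "sharegpt",
        "columns": {"messages": "messages", "videos": "videos", "audios": "audios"},
    }
}

def load_demo_dataset(name: str, data_dir: str = "data") -> list[dict]:
    info = DATASET_INFO[name]
    with open(f"{data_dir}/{info['file_name']}", encoding="utf-8") as f:
        records = json.load(f)
    columns = info["columns"]
    # Rename record keys to the canonical column names.
    return [{dst: rec.get(src, []) for dst, src in columns.items()} for rec in records]

samples = load_demo_dataset("mllm_video_audio_demo")
print(samples[0]["videos"], samples[0]["audios"])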
BIN  data/mllm_demo_data/4.mp3 (new file; binary file not shown)
BIN  data/mllm_demo_data/4.mp4 (new file; binary file not shown)
data/mllm_video_audio_demo.json (new file, 57 lines)
@@ -0,0 +1,57 @@
+[
+  {
+    "messages": [
+      {
+        "content": "<video><audio>What is the video describing?",
+        "role": "user"
+      },
+      {
+        "content": "A girl who is drawing a picture of a guitar and feel nervous.",
+        "role": "assistant"
+      }
+    ],
+    "videos": [
+      "mllm_demo_data/4.mp4"
+    ],
+    "audios": [
+      "mllm_demo_data/4.mp3"
+    ]
+  },
+  {
+    "messages": [
+      {
+        "content": "<video><audio>What does this girl say?",
+        "role": "user"
+      },
+      {
+        "content": "She says: 'Hello! Take a look at what am I drawing!'",
+        "role": "assistant"
+      }
+    ],
+    "videos": [
+      "mllm_demo_data/4.mp4"
+    ],
+    "audios": [
+      "mllm_demo_data/4.mp3"
+    ]
+  },
+  {
+    "messages": [
+      {
+        "content": "<video><audio>What is this girl drawing with?",
+        "role": "user"
+      },
+      {
+        "content": "She is drawing with an iPad.",
+        "role": "assistant"
+      }
+    ],
+    "videos": [
+      "mllm_demo_data/4.mp4"
+    ],
+    "audios": [
+      "mllm_demo_data/4.mp3"
+    ]
+  }
+
+]
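Each sample pairs the <video> and <audio> placeholders in the user turn with exactly one path in "videos" and one in "audios". A small illustrative consistency check (not part of the repository) that the placeholder counts match the media lists:

import json

def validate(path: str) -> None:
    # Every <video>/<audio> tag in the messages should be matched by
    # one entry in the sample's media lists.
    with open(path, encoding="utf-8") as f:
        samples = json.load(f)
    for i, sample in enumerate(samples):
        text = "".join(m["content"] for m in sample["messages"])
        assert text.count("<video>") == len(sample.get("videos", [])), f"sample {i}: video count mismatch"
        assert text.count("<audio>") == len(sample.get("audios", [])), f"sample {i}: audio count mismatch"
    print(f"{len(samples)} samples OK")

validate("data/mllm_video_audio_demo.json")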
src/llamafactory/data/collator.py
@@ -184,6 +184,7 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
                 rope_index_kwargs["second_per_grids"] = mm_inputs.get("video_second_per_grid")

             if getattr(self.model.config, "model_type", None) == "qwen2_5_omni_thinker":  # for qwen2omni
+                rope_index_kwargs["use_audio_in_video"] = getattr(self.processor, "use_audio_in_video", False)
                 feature_attention_mask = mm_inputs.get("feature_attention_mask", None)
                 if feature_attention_mask is not None:
                     audio_feature_lengths = torch.sum(
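The one-line fix above forwards the use_audio_in_video flag into the kwargs used for Qwen2.5-Omni's multimodal RoPE index. A condensed sketch of the surrounding logic, with stand-in objects for the model config, processor, and mm_inputs; the hunk truncates before the summed feature lengths are stored, so the "audio_feature_lengths" key below is an assumption:

import torch

def build_rope_index_kwargs(model_config, processor, mm_inputs: dict) -> dict:
    kwargs = {"second_per_grids": mm_inputs.get("video_second_per_grid")}
    if getattr(model_config, "model_type", None) == "qwen2_5_omni_thinker":  # for qwen2omni
        # Read the flag that patch_processor stored on the processor.
        kwargs["use_audio_in_video"] = getattr(processor, "use_audio_in_video", False)
        feature_attention_mask = mm_inputs.get("feature_attention_mask", None)
        if feature_attention_mask is not None:
            # Valid audio frames per sample; the downstream kwarg name is assumed.
            kwargs["audio_feature_lengths"] = torch.sum(feature_attention_mask, dim=-1)
    return kwargs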
src/llamafactory/data/mm_plugin.py
@@ -1378,6 +1378,7 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
         else:
             mm_inputs = {}

+        image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
         num_audio_tokens, num_image_tokens, num_video_tokens = 0, 0, 0
         use_audio_in_video = getattr(processor, "use_audio_in_video", False)

@@ -1398,16 +1399,16 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
             if audio_lengths is None:
                 raise ValueError("audio_lengths should exist when use_audio_in_video is `True`.")

-            if not mm_inputs.get("video_grid_thw", None):
+            if mm_inputs.get("video_grid_thw", None) is None:
                 raise ValueError("video_grid_thw should exist when use_audio_in_video is `True`.")

             positions_list = []
-            for i, message in enumerate(messages):  # get multimodal index when use_audio
+            for message in messages:  # get multimodal index when use_audio
                 positions = []
                 for special_token in [self.audio_token, self.image_token, self.video_token]:
                     start = 0
                     while True:
-                        pos = message[i].find(special_token, start)
+                        pos = message["content"].find(special_token, start)
                         if pos == -1:
                             break
                         positions.append((pos, special_token))
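The two fixes in this hunk replace a truthiness test with an explicit `is None` check (the truth value of a multi-element tensor is ambiguous and raises in PyTorch) and correct `message[i]` to `message["content"]`. A self-contained sketch of the corrected position scan; the token strings stand in for the plugin's attributes, and the `start` advance after the append is assumed, since the hunk truncates there:

AUDIO_TOKEN, IMAGE_TOKEN, VIDEO_TOKEN = "<audio>", "<image>", "<video>"

def find_special_token_positions(message: dict) -> list[tuple[int, str]]:
    # Collect every occurrence of each special token in the message content,
    # then sort so multimodal segments come back in reading order.
    positions = []
    for special_token in (AUDIO_TOKEN, IMAGE_TOKEN, VIDEO_TOKEN):
        start = 0
        while True:
            pos = message["content"].find(special_token, start)
            if pos == -1:
                break
            positions.append((pos, special_token))
            start = pos + len(special_token)  # assumed continuation of the loop
    return sorted(positions)

print(find_special_token_positions({"content": "<video><audio>What does this girl say?"}))
# [(0, '<video>'), (7, '<audio>')]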
@@ -1453,8 +1454,8 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
                         .view(-1, 1, 1)
                         .expand(
                             -1,
-                            video_grid_thw[num_video_tokens][1] // self.image_processor.merge_size,
-                            video_grid_thw[num_video_tokens][2] // self.image_processor.merge_size,
+                            video_grid_thw[num_video_tokens][1] // image_processor.merge_size,
+                            video_grid_thw[num_video_tokens][2] // image_processor.merge_size,
                         )
                         .flatten()
                         * mm_inputs["video_second_per_grid"][num_video_tokens]
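This hunk only swaps `self.image_processor` for the local `image_processor` fetched from the processor earlier (the plugin has no such attribute). For intuition, a worked sketch of the time-index expansion with made-up numbers; the trailing x25 factor (Qwen2.5-Omni position ids per second, implied by the "FIXME hardcode: [25 * 2]" comment in the next hunk) is an assumption, since this hunk cuts off before the closing parenthesis:

import torch

grid_t, grid_h, grid_w = 2, 4, 4  # one video's grid_thw (illustrative values)
merge_size = 2                    # image processor's spatial merge factor
second_per_grid = 1.0             # seconds covered by one temporal grid step

# Each temporal step is expanded over the merged spatial grid
# (grid_h // merge_size) * (grid_w // merge_size) tokens, then scaled to
# temporal position units.
video_t_index = (
    torch.arange(grid_t)
    .view(-1, 1, 1)
    .expand(-1, grid_h // merge_size, grid_w // merge_size)
    .flatten()
    * second_per_grid
    * 25  # assumed position_id_per_seconds factor
).long()
print(video_t_index)  # tensor([ 0,  0,  0,  0, 25, 25, 25, 25])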
@@ -1462,17 +1463,17 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
                     ).long()
                     t_ntoken_per_chunk = 50  # FIXME hardcode: [25 * 2]
                     video_chunk_indices = processor.get_chunked_index(video_t_index, t_ntoken_per_chunk)
-                    audio_chunk_indices = self.get_chunked_index(audio_t_index, t_ntoken_per_chunk)
+                    audio_chunk_indices = processor.get_chunked_index(audio_t_index, t_ntoken_per_chunk)
                     placeholder_string = ""
-                    placeholder_string += "<|vision_bos|>" + "<|audio_bos|>"
                     for j in range(max(len(video_chunk_indices), len(audio_chunk_indices))):
                         video_chunk_index = video_chunk_indices[j] if j < len(video_chunk_indices) else None
                         audio_chunk_index = audio_chunk_indices[j] if j < len(audio_chunk_indices) else None
+                        placeholder_string = "<|vision_bos|>" + "<|audio_bos|>"
                         if video_chunk_index is not None:
                             placeholder_string += self.video_token * (video_chunk_index[1] - video_chunk_index[0])
                         if audio_chunk_index is not None:
                             placeholder_string += self.audio_token * (audio_chunk_index[1] - audio_chunk_index[0])
-                    placeholder_string += "<|audio_eos|>" + "<|vision_eos|>"
+                        placeholder_string += "<|audio_eos|>" + "<|vision_eos|>"

                     content = content.replace(VIDEO_PLACEHOLDER, placeholder_string, 1)
                     content = content.replace(AUDIO_PLACEHOLDER, "", 1)
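The hunk above routes audio chunking through `processor.get_chunked_index` (the plugin defines no method of that name) and moves the bos/eos delimiters inside the chunk loop so every video/audio chunk pair is wrapped. A standalone sketch of the interleaving; `get_chunked_index` below is a simplified stand-in for the processor method (splitting wherever the time index crosses a multiple of `t_ntoken_per_chunk`), and per-chunk pieces are collected and joined here for clarity:

VIDEO_TOKEN, AUDIO_TOKEN = "<video_tok>", "<audio_tok>"  # illustrative token strings

def get_chunked_index(t_index: list[int], t_ntoken_per_chunk: int) -> list[tuple[int, int]]:
    # Stand-in: group consecutive tokens whose time index falls in the same
    # window of t_ntoken_per_chunk time units.
    chunks, start = [], 0
    for end in range(1, len(t_index) + 1):
        if end == len(t_index) or t_index[end] // t_ntoken_per_chunk != t_index[start] // t_ntoken_per_chunk:
            chunks.append((start, end))
            start = end
    return chunks

def build_placeholder(video_t_index: list[int], audio_t_index: list[int], t_ntoken_per_chunk: int = 50) -> str:
    video_chunks = get_chunked_index(video_t_index, t_ntoken_per_chunk)
    audio_chunks = get_chunked_index(audio_t_index, t_ntoken_per_chunk)
    pieces = []
    for j in range(max(len(video_chunks), len(audio_chunks))):
        piece = "<|vision_bos|>" + "<|audio_bos|>"
        if j < len(video_chunks):
            piece += VIDEO_TOKEN * (video_chunks[j][1] - video_chunks[j][0])
        if j < len(audio_chunks):
            piece += AUDIO_TOKEN * (audio_chunks[j][1] - audio_chunks[j][0])
        pieces.append(piece + "<|audio_eos|>" + "<|vision_eos|>")
    return "".join(pieces)

print(build_placeholder([0, 0, 25, 25, 50, 50], [0, 20, 40, 60]))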
src/llamafactory/model/patcher.py
@@ -79,6 +79,7 @@ def patch_processor(
     setattr(processor, "video_fps", model_args.video_fps)
     setattr(processor, "video_maxlen", model_args.video_maxlen)
     setattr(processor, "audio_sampling_rate", model_args.audio_sampling_rate)
+    setattr(processor, "use_audio_in_video", model_args.use_audio_in_video)


 def patch_config(
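`patch_processor` copies media settings from `model_args` onto the processor as plain attributes, so downstream stages (the plugin and collator above) can read them with `getattr(processor, ..., default)` instead of threading `model_args` through every call. A minimal sketch of the pattern with stand-in classes:

from dataclasses import dataclass

@dataclass
class ModelArgs:  # stand-in for LLaMA-Factory's model arguments
    video_fps: float = 2.0
    video_maxlen: int = 128
    audio_sampling_rate: int = 16000
    use_audio_in_video: bool = True

class DummyProcessor:  # stand-in for a transformers processor instance
    pass

def patch_processor(processor, model_args) -> None:
    # Copy each media setting onto the processor as an attribute.
    for name in ("video_fps", "video_maxlen", "audio_sampling_rate", "use_audio_in_video"):
        setattr(processor, name, getattr(model_args, name))

processor = DummyProcessor()
patch_processor(processor, ModelArgs())
print(getattr(processor, "use_audio_in_video", False))  # True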