diff --git a/data/mllm_demo_data/4.mp3 b/data/mllm_demo_data/4.mp3
new file mode 100644
index 00000000..17a8a845
Binary files /dev/null and b/data/mllm_demo_data/4.mp3 differ
diff --git a/data/mllm_demo_data/4.mp4 b/data/mllm_demo_data/4.mp4
new file mode 100644
index 00000000..fdf859dd
Binary files /dev/null and b/data/mllm_demo_data/4.mp4 differ
diff --git a/src/llamafactory/data/collator.py b/src/llamafactory/data/collator.py
index 8b64f9a0..0301d965 100644
--- a/src/llamafactory/data/collator.py
+++ b/src/llamafactory/data/collator.py
@@ -184,6 +184,7 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
                 rope_index_kwargs["second_per_grids"] = mm_inputs.get("video_second_per_grid")

             if getattr(self.model.config, "model_type", None) == "qwen2_5_omni_thinker":  # for qwen2omni
+                rope_index_kwargs["use_audio_in_video"] = getattr(self.processor, "use_audio_in_video", False)
                 feature_attention_mask = mm_inputs.get("feature_attention_mask", None)
                 if feature_attention_mask is not None:
                     audio_feature_lengths = torch.sum(
diff --git a/src/llamafactory/data/mm_plugin.py b/src/llamafactory/data/mm_plugin.py
index 887dc08f..e56f5dd2 100644
--- a/src/llamafactory/data/mm_plugin.py
+++ b/src/llamafactory/data/mm_plugin.py
@@ -1378,6 +1378,7 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
         else:
             mm_inputs = {}

+        image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
         num_audio_tokens, num_image_tokens, num_video_tokens = 0, 0, 0
         use_audio_in_video = getattr(processor, "use_audio_in_video", False)

@@ -1398,16 +1399,16 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
             if audio_lengths is None:
                 raise ValueError("audio_lengths should exist when use_audio_in_video is `True`.")

-            if not mm_inputs.get("video_grid_thw", None):
+            if mm_inputs.get("video_grid_thw", None) is None:
                 raise ValueError("video_grid_thw should exist when use_audio_in_video is `True`.")

             positions_list = []
-            for i, message in enumerate(messages):  # get multimodal index when use_audio
+            for message in messages:  # get multimodal index when use_audio
                 positions = []
                 for special_token in [self.audio_token, self.image_token, self.video_token]:
                     start = 0
                     while True:
-                        pos = message[i].find(special_token, start)
+                        pos = message["content"].find(special_token, start)
                         if pos == -1:
                             break
                         positions.append((pos, special_token))
@@ -1453,8 +1454,8 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
                         .view(-1, 1, 1)
                         .expand(
                             -1,
-                            video_grid_thw[num_video_tokens][1] // self.image_processor.merge_size,
-                            video_grid_thw[num_video_tokens][2] // self.image_processor.merge_size,
+                            video_grid_thw[num_video_tokens][1] // image_processor.merge_size,
+                            video_grid_thw[num_video_tokens][2] // image_processor.merge_size,
                         )
                         .flatten()
                         * mm_inputs["video_second_per_grid"][num_video_tokens]
@@ -1462,17 +1463,17 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
                     ).long()
                     t_ntoken_per_chunk = 50  # FIXME hardcode: [25 * 2]
                     video_chunk_indices = processor.get_chunked_index(video_t_index, t_ntoken_per_chunk)
-                    audio_chunk_indices = self.get_chunked_index(audio_t_index, t_ntoken_per_chunk)
+                    audio_chunk_indices = processor.get_chunked_index(audio_t_index, t_ntoken_per_chunk)
                     placeholder_string = ""
+                    placeholder_string += "<|vision_bos|>" + "<|audio_bos|>"
                     for j in range(max(len(video_chunk_indices), len(audio_chunk_indices))):
                         video_chunk_index = video_chunk_indices[j] if j < len(video_chunk_indices) else None
                         audio_chunk_index = audio_chunk_indices[j] if j < len(audio_chunk_indices) else None
-                        placeholder_string = "<|vision_bos|>" + "<|audio_bos|>"
                         if video_chunk_index is not None:
                             placeholder_string += self.video_token * (video_chunk_index[1] - video_chunk_index[0])
                         if audio_chunk_index is not None:
                             placeholder_string += self.audio_token * (audio_chunk_index[1] - audio_chunk_index[0])

-                        placeholder_string += "<|audio_eos|>" + "<|vision_eos|>"
+                    placeholder_string += "<|audio_eos|>" + "<|vision_eos|>"
                     content = content.replace(VIDEO_PLACEHOLDER, placeholder_string, 1)
                     content = content.replace(AUDIO_PLACEHOLDER, "", 1)
diff --git a/src/llamafactory/model/patcher.py b/src/llamafactory/model/patcher.py
index 28cb599f..de0bd299 100644
--- a/src/llamafactory/model/patcher.py
+++ b/src/llamafactory/model/patcher.py
@@ -79,6 +79,7 @@ def patch_processor(
     setattr(processor, "video_fps", model_args.video_fps)
     setattr(processor, "video_maxlen", model_args.video_maxlen)
     setattr(processor, "audio_sampling_rate", model_args.audio_sampling_rate)
+    setattr(processor, "use_audio_in_video", model_args.use_audio_in_video)


 def patch_config(
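
Note on the placeholder change in mm_plugin.py: when use_audio_in_video is enabled, video and audio pad tokens are interleaved chunk by chunk along the shared time axis, and the <|vision_bos|>/<|audio_bos|> and <|audio_eos|>/<|vision_eos|> markers should wrap the whole interleaved sequence once, not be re-emitted on every loop iteration. A minimal standalone sketch of that construction follows; the helper name, pad-token strings, and sample chunk indices are illustrative, not taken from the repository.

# Sketch of the interleaving logic fixed above; not the repository implementation.
VIDEO_TOKEN = "<|video_pad|>"  # stand-in for self.video_token
AUDIO_TOKEN = "<|audio_pad|>"  # stand-in for self.audio_token


def build_placeholder(video_chunks: list[tuple[int, int]], audio_chunks: list[tuple[int, int]]) -> str:
    """Interleave video/audio pad tokens chunk by chunk inside one pair of bos/eos markers."""
    placeholder = "<|vision_bos|>" + "<|audio_bos|>"  # emitted once, before the loop
    for j in range(max(len(video_chunks), len(audio_chunks))):
        video_chunk = video_chunks[j] if j < len(video_chunks) else None
        audio_chunk = audio_chunks[j] if j < len(audio_chunks) else None
        if video_chunk is not None:
            placeholder += VIDEO_TOKEN * (video_chunk[1] - video_chunk[0])
        if audio_chunk is not None:
            placeholder += AUDIO_TOKEN * (audio_chunk[1] - audio_chunk[0])

    placeholder += "<|audio_eos|>" + "<|vision_eos|>"  # emitted once, after the loop
    return placeholder


if __name__ == "__main__":
    # Hypothetical (start, end) chunk index pairs over the time axis.
    print(build_placeholder(video_chunks=[(0, 2), (2, 4)], audio_chunks=[(0, 3)]))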