Mirror of https://github.com/hiyouga/LLaMA-Factory.git
Synced 2025-11-04 18:02:19 +08:00

[data] Fix bugs of use_audio_in_video in Qwen2.5 Omni (#7638)
* cache _mm_inputs
* nit
* support for use_audio_in_video
* remove cache
* fix data
* Update mllm_video_audio_demo.json
This commit is contained in:
parent: acb09fa3a3
commit: 349c56c51c

Binary files added (new file, contents not shown):
    data/mllm_demo_data/4.mp3
    data/mllm_demo_data/4.mp4
src/llamafactory/data/collator.py
@@ -184,6 +184,7 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
                 rope_index_kwargs["second_per_grids"] = mm_inputs.get("video_second_per_grid")
 
             if getattr(self.model.config, "model_type", None) == "qwen2_5_omni_thinker":  # for qwen2omni
+                rope_index_kwargs["use_audio_in_video"] = getattr(self.processor, "use_audio_in_video", False)
                 feature_attention_mask = mm_inputs.get("feature_attention_mask", None)
                 if feature_attention_mask is not None:
                     audio_feature_lengths = torch.sum(
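The last context line above is truncated at the open parenthesis of torch.sum(...). As a rough sketch of the pattern (illustrative values, not the repository's exact call), the collator reduces the audio padding mask to per-sample feature lengths before handing them to get_rope_index:

import torch

# Illustrative batch: two audio clips padded to five feature frames.
feature_attention_mask = torch.tensor([[1, 1, 1, 0, 0],
                                       [1, 1, 1, 1, 1]])

# Summing the mask over the time axis recovers each clip's true length.
audio_feature_lengths = torch.sum(feature_attention_mask, dim=1)
print(audio_feature_lengths)  # tensor([3, 5])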
src/llamafactory/data/mm_plugin.py
@@ -1378,6 +1378,7 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
         else:
             mm_inputs = {}
 
+        image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
         num_audio_tokens, num_image_tokens, num_video_tokens = 0, 0, 0
         use_audio_in_video = getattr(processor, "use_audio_in_video", False)
 
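The added line resolves the image processor from the processor argument; a later hunk switches self.image_processor to this local variable. A minimal sketch of the pattern, assuming a Qwen2.5-Omni checkpoint id for illustration:

from transformers import AutoProcessor

# Checkpoint id is an assumption for this sketch; any Qwen2.5-Omni processor works.
processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-Omni-7B")

# Fetch attributes defensively from the processor that was passed in.
image_processor = getattr(processor, "image_processor", None)
use_audio_in_video = getattr(processor, "use_audio_in_video", False)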
@@ -1398,16 +1399,16 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
             if audio_lengths is None:
                 raise ValueError("audio_lengths should exist when use_audio_in_video is `True`.")
 
-            if not mm_inputs.get("video_grid_thw", None):
+            if mm_inputs.get("video_grid_thw", None) is None:
                 raise ValueError("video_grid_thw should exist when use_audio_in_video is `True`.")
 
             positions_list = []
-            for i, message in enumerate(messages):  # get multimodal index when use_audio
+            for message in messages:  # get multimodal index when use_audio
                 positions = []
                 for special_token in [self.audio_token, self.image_token, self.video_token]:
                     start = 0
                     while True:
-                        pos = message[i].find(special_token, start)
+                        pos = message["content"].find(special_token, start)
                         if pos == -1:
                             break
                         positions.append((pos, special_token))
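The corrected loop indexes each message dict by its "content" key instead of the removed loop counter. A self-contained sketch of the scan (the token strings are assumptions; the plugin uses self.audio_token, self.image_token, and self.video_token):

# Assumed token strings for illustration.
audio_token, image_token, video_token = "<|AUDIO|>", "<|IMAGE|>", "<|VIDEO|>"

message = {"role": "user", "content": "<|VIDEO|><|AUDIO|> describe this clip"}
positions = []
for special_token in [audio_token, image_token, video_token]:
    start = 0
    while True:
        pos = message["content"].find(special_token, start)
        if pos == -1:
            break  # no more occurrences of this token
        positions.append((pos, special_token))
        start = pos + len(special_token)

positions.sort(key=lambda x: x[0])  # multimodal inputs in order of appearance
print(positions)  # [(0, '<|VIDEO|>'), (9, '<|AUDIO|>')]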
@@ -1453,8 +1454,8 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
                         .view(-1, 1, 1)
                         .expand(
                             -1,
-                            video_grid_thw[num_video_tokens][1] // self.image_processor.merge_size,
-                            video_grid_thw[num_video_tokens][2] // self.image_processor.merge_size,
+                            video_grid_thw[num_video_tokens][1] // image_processor.merge_size,
+                            video_grid_thw[num_video_tokens][2] // image_processor.merge_size,
                         )
                         .flatten()
                         * mm_inputs["video_second_per_grid"][num_video_tokens]
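The surrounding expression builds a per-token time index for the video grid, dividing the spatial dimensions by the patch merge size. A toy example of the view/expand/flatten broadcast with assumed grid sizes:

import torch

# Assumed sizes: 3 temporal grid steps over a 4x4 spatial grid after merging.
t, h, w = 3, 4, 4
temporal_ids = torch.arange(t)

# One time value per temporal step, repeated across every spatial position.
video_t_index = temporal_ids.view(-1, 1, 1).expand(-1, h, w).flatten()
print(video_t_index.shape)  # torch.Size([48]): 16 zeros, 16 ones, 16 twos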
@@ -1462,17 +1463,17 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
                     ).long()
                     t_ntoken_per_chunk = 50  # FIXME hardcode: [25 * 2]
                     video_chunk_indices = processor.get_chunked_index(video_t_index, t_ntoken_per_chunk)
-                    audio_chunk_indices = self.get_chunked_index(audio_t_index, t_ntoken_per_chunk)
+                    audio_chunk_indices = processor.get_chunked_index(audio_t_index, t_ntoken_per_chunk)
                     placeholder_string = ""
+                    placeholder_string += "<|vision_bos|>" + "<|audio_bos|>"
                     for j in range(max(len(video_chunk_indices), len(audio_chunk_indices))):
                         video_chunk_index = video_chunk_indices[j] if j < len(video_chunk_indices) else None
                         audio_chunk_index = audio_chunk_indices[j] if j < len(audio_chunk_indices) else None
-                        placeholder_string = "<|vision_bos|>" + "<|audio_bos|>"
                         if video_chunk_index is not None:
                             placeholder_string += self.video_token * (video_chunk_index[1] - video_chunk_index[0])
                         if audio_chunk_index is not None:
                             placeholder_string += self.audio_token * (audio_chunk_index[1] - audio_chunk_index[0])
-                        placeholder_string += "<|audio_eos|>" + "<|vision_eos|>"
+                    placeholder_string += "<|audio_eos|>" + "<|vision_eos|>"
 
                     content = content.replace(VIDEO_PLACEHOLDER, placeholder_string, 1)
                     content = content.replace(AUDIO_PLACEHOLDER, "", 1)
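As reconstructed above, the fix moves the <|vision_bos|><|audio_bos|> and <|audio_eos|><|vision_eos|> markers out of the chunk loop, so the placeholder wraps all interleaved chunks once instead of being reset on every iteration. A standalone sketch of the fixed logic (the helper name and single-character tokens are illustrative):

def build_interleaved_placeholder(video_chunks, audio_chunks, video_token, audio_token):
    # Open the vision/audio span once, interleave chunks, close it once.
    s = "<|vision_bos|>" + "<|audio_bos|>"
    for j in range(max(len(video_chunks), len(audio_chunks))):
        video_chunk = video_chunks[j] if j < len(video_chunks) else None
        audio_chunk = audio_chunks[j] if j < len(audio_chunks) else None
        if video_chunk is not None:
            s += video_token * (video_chunk[1] - video_chunk[0])
        if audio_chunk is not None:
            s += audio_token * (audio_chunk[1] - audio_chunk[0])
    return s + "<|audio_eos|>" + "<|vision_eos|>"

# Chunk indices are (start, end) token offsets, e.g. from get_chunked_index.
print(build_interleaved_placeholder([(0, 2), (2, 4)], [(0, 3)], "v", "a"))
# <|vision_bos|><|audio_bos|>vvaaavv<|audio_eos|><|vision_eos|>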
src/llamafactory/model/patcher.py
@@ -79,6 +79,7 @@ def patch_processor(
     setattr(processor, "video_fps", model_args.video_fps)
     setattr(processor, "video_maxlen", model_args.video_maxlen)
     setattr(processor, "audio_sampling_rate", model_args.audio_sampling_rate)
+    setattr(processor, "use_audio_in_video", model_args.use_audio_in_video)
 
 
 def patch_config(
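The new setattr threads model_args.use_audio_in_video onto the processor, where the collator and plugin read it back with a safe default. A minimal sketch of that round trip (the stand-in class is hypothetical):

class DummyProcessor:
    """Stand-in for a transformers processor; attributes are patched on at runtime."""

# Patch side (patch_processor): copy the user flag onto the processor.
processor = DummyProcessor()
setattr(processor, "use_audio_in_video", True)

# Read side (collator and plugin): fall back to False if the patch never ran.
print(getattr(processor, "use_audio_in_video", False))  # True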