Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-11-04 18:02:19 +08:00)
	[data] fix minicpmv plugin (#6801)
* fix template name
* tiny fix
* support minicpm-o-2.6
* support inference of minicpmv
* update readme
* support dpo of minicpmv
* update init audio
* update init audio
* [model] fix image process in minicpmo

Former-commit-id: 8f704c8b6228ef50f828014f85dce67fda868660

This commit is contained in:
parent 34746d6151
commit cfb926fb84
@@ -543,7 +543,7 @@ class MiniCPMVPlugin(BasePlugin):
         self._validate_input(images, videos)
         image_bounds_list = []
         valid_image_nums_ls = []
-        for input_ids in batch_ids:
+        for i, input_ids in enumerate(batch_ids):
             input_ids_ = torch.tensor(input_ids)
             start_cond = (input_ids_ == processor.tokenizer.im_start_id) | (
                 input_ids_ == processor.tokenizer.slice_start_id
@@ -552,12 +552,11 @@ class MiniCPMVPlugin(BasePlugin):
             image_start_tokens = torch.where(start_cond)[0]
             image_start_tokens += 1
             image_end_tokens = torch.where(end_cond)[0]
-            valid_image_nums = max(len(image_start_tokens), len(image_end_tokens))
-            valid_image_nums_ls.append(valid_image_nums)
+            valid_image_nums_ls.append(imglens[i])
             image_bounds = torch.hstack(
                 [
-                    image_start_tokens[:valid_image_nums].unsqueeze(-1),
-                    image_end_tokens[:valid_image_nums].unsqueeze(-1),
+                    image_start_tokens.unsqueeze(-1),
+                    image_end_tokens.unsqueeze(-1),
                 ]
             )
             image_bounds_list.append(image_bounds)
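The second hunk carries the actual fix: instead of inferring the per-sample image count by counting start/end markers (max(len(image_start_tokens), len(image_end_tokens))) and truncating the bound tensors to that count, the plugin now records the caller-supplied imglens[i] and pairs every start marker with every end marker. Below is a minimal, self-contained sketch of that bounds computation; the marker token ids are made up and stand in for the real MiniCPM-V tokenizer attributes (im_start_id, im_end_id, slice_start_id, slice_end_id), so this is an illustration of the technique, not the plugin's code.

import torch

# Hypothetical marker token ids; the real values come from the MiniCPM-V tokenizer.
IM_START_ID, IM_END_ID = 101, 102
SLICE_START_ID, SLICE_END_ID = 103, 104


def image_bounds_for(input_ids: list) -> torch.Tensor:
    """Return an (n, 2) tensor of [start, end) index pairs for image regions."""
    input_ids_ = torch.tensor(input_ids)
    start_cond = (input_ids_ == IM_START_ID) | (input_ids_ == SLICE_START_ID)
    end_cond = (input_ids_ == IM_END_ID) | (input_ids_ == SLICE_END_ID)
    image_start_tokens = torch.where(start_cond)[0] + 1  # first token after each start marker
    image_end_tokens = torch.where(end_cond)[0]
    return torch.hstack([image_start_tokens.unsqueeze(-1), image_end_tokens.unsqueeze(-1)])


# Toy sequence with one image region and one slice region:
toy_ids = [1, 101, 7, 7, 102, 2, 103, 8, 104, 3]
print(image_bounds_for(toy_ids))  # tensor([[2, 4], [7, 8]])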
@@ -125,7 +125,7 @@ def patch_config(
         setattr(config, "use_cache", False)  # qwen2 does not support use_cache when using flash attn
 
     if getattr(config, "model_type", None) == "minicpmo":
-        setattr(config, "init_audio", False)
+        setattr(config, "init_audio", True)
         setattr(config, "init_tts", False)
 
     if "LlavaLlamaForCausalLM" in getattr(config, "architectures", []):
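The hunk above only flips attributes on the Hugging Face config before the model is built, so MiniCPM-o now initializes its audio tower (init_audio=True) while TTS stays disabled. A rough standalone sketch of that override pattern, with a SimpleNamespace standing in for the real transformers config object (the helper name is illustrative; only the attribute names and values come from the diff):

from types import SimpleNamespace


def patch_minicpmo_config(config) -> None:
    # Same override pattern as the diff: enable the audio tower, keep TTS off.
    if getattr(config, "model_type", None) == "minicpmo":
        setattr(config, "init_audio", True)  # was False before this commit
        setattr(config, "init_tts", False)


config = SimpleNamespace(model_type="minicpmo")  # stand-in for a transformers PretrainedConfig
patch_minicpmo_config(config)
print(config.init_audio, config.init_tts)  # True False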
@@ -152,6 +152,7 @@ def test_llama3_template(use_fast: bool):
     _check_template("meta-llama/Meta-Llama-3-8B-Instruct", "llama3", prompt_str, answer_str, use_fast)
 
 
+@pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
 @pytest.mark.parametrize(
     "use_fast", [True, pytest.param(False, marks=pytest.mark.xfail(reason="Phi-4 slow tokenizer is broken."))]
 )
@@ -166,6 +167,7 @@ def test_phi4_template(use_fast: bool):
     _check_template("microsoft/phi-4", "phi4", prompt_str, answer_str, use_fast)
 
 
+@pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
 @pytest.mark.parametrize("use_fast", [True, False])
 def test_qwen_template(use_fast: bool):
     prompt_str = (
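The two test hunks each add one skipif marker so the phi4 and qwen template checks only run when a Hugging Face token is configured. A reduced sketch of that gating pattern, assuming the test module reads HF_TOKEN from the environment as its other gated-model tests do (the test body here is a placeholder, not the repository's _check_template logic):

import os

import pytest

HF_TOKEN = os.getenv("HF_TOKEN")


@pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
@pytest.mark.parametrize("use_fast", [True, False])
def test_template_smoke(use_fast: bool):
    # Placeholder assertion; the real tests download a hub tokenizer and compare encodings.
    assert isinstance(use_fast, bool)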