mirror of
				https://github.com/hiyouga/LLaMA-Factory.git
				synced 2025-11-04 18:02:19 +08:00 
			
		
		
		
	[data] Fix qwen_2vl with valuehead (#9078)
This commit is contained in:
		
							parent
							
								
									a22dab97fd
								
							
						
					
					
						commit
						610a3f1094
					
				@ -313,7 +313,7 @@ Choose your path:
 | 
			
		||||
| [Qwen2-Audio](https://huggingface.co/Qwen)                        | 7B                               | qwen2_audio         |
 | 
			
		||||
| [Qwen2.5-Omni](https://huggingface.co/Qwen)                       | 3B/7B                            | qwen2_omni          |
 | 
			
		||||
| [Qwen2-VL/Qwen2.5-VL/QVQ](https://huggingface.co/Qwen)            | 2B/3B/7B/32B/72B                 | qwen2_vl            |
 | 
			
		||||
| [Seed Coder](https://huggingface.co/ByteDance-Seed)               | 8B                               | seed_coder          |
 | 
			
		||||
| [Seed (Coder/OSS)](https://huggingface.co/ByteDance-Seed)         | 8B/36B                           | seed_coder/seed_oss |
 | 
			
		||||
| [Skywork o1](https://huggingface.co/Skywork)                      | 8B                               | skywork_o1          |
 | 
			
		||||
| [StarCoder 2](https://huggingface.co/bigcode)                     | 3B/7B/15B                        | -                   |
 | 
			
		||||
| [TeleChat2](https://huggingface.co/Tele-AI)                       | 3B/7B/35B/115B                   | telechat2           |
 | 
			
		||||
 | 
			
		||||
@ -315,7 +315,7 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
 | 
			
		||||
| [Qwen2-Audio](https://huggingface.co/Qwen)                        | 7B                               | qwen2_audio         |
 | 
			
		||||
| [Qwen2.5-Omni](https://huggingface.co/Qwen)                       | 3B/7B                            | qwen2_omni          |
 | 
			
		||||
| [Qwen2-VL/Qwen2.5-VL/QVQ](https://huggingface.co/Qwen)            | 2B/3B/7B/32B/72B                 | qwen2_vl            |
 | 
			
		||||
| [Seed Coder](https://huggingface.co/ByteDance-Seed)               | 8B                               | seed_coder          |
 | 
			
		||||
| [Seed (Coder/OSS)](https://huggingface.co/ByteDance-Seed)         | 8B/36B                           | seed_coder/seed_oss |
 | 
			
		||||
| [Skywork o1](https://huggingface.co/Skywork)                      | 8B                               | skywork_o1          |
 | 
			
		||||
| [StarCoder 2](https://huggingface.co/bigcode)                     | 3B/7B/15B                        | -                   |
 | 
			
		||||
| [TeleChat2](https://huggingface.co/Tele-AI)                       | 3B/7B/35B/115B                   | telechat2           |
 | 
			
		||||
 | 
			
		||||
@ -211,9 +211,23 @@ def patch_valuehead_model(model: "AutoModelForCausalLMWithValueHead") -> None:
 | 
			
		||||
        if isinstance(self.pretrained_model, PeftModel):
 | 
			
		||||
            self.pretrained_model.create_or_update_model_card(output_dir)
 | 
			
		||||
 | 
			
		||||
    def get_rope_index_func(self: "AutoModelForCausalLMWithValueHead"):
        r"""Look up the ``get_rope_index`` attribute of the wrapped pretrained model.

        The lookup is attempted in two places, in this order:

        1. directly on the base model;
        2. on the base model's ``model`` attribute.

        When ``self.pretrained_model`` is a ``PeftModel``, the base model is taken
        from ``pretrained_model.base_model.model``; otherwise ``pretrained_model``
        itself is used.

        Returns:
            The ``get_rope_index`` attribute that was found, or ``None`` when the
            base model is falsy or neither location has the attribute.
        """
        wrapped = self.pretrained_model
        # Unwrap the PEFT container so the attribute lookup runs on the inner model.
        base_model = wrapped.base_model.model if isinstance(wrapped, PeftModel) else wrapped

        if not base_model:
            return None

        if hasattr(base_model, "get_rope_index"):
            return base_model.get_rope_index

        # Fall back to the nested ``model`` attribute when it exists.
        if hasattr(base_model, "model") and hasattr(base_model.model, "get_rope_index"):
            return base_model.model.get_rope_index

        return None
 | 
			
		||||
 | 
			
		||||
    ignore_modules = [name for name, _ in model.named_parameters() if "pretrained_model" in name]
 | 
			
		||||
    setattr(model, "_keys_to_ignore_on_save", ignore_modules)
 | 
			
		||||
    setattr(model, "tie_weights", MethodType(tie_weights, model))
 | 
			
		||||
    setattr(model, "get_input_embeddings", MethodType(get_input_embeddings, model))
 | 
			
		||||
    setattr(model, "get_output_embeddings", MethodType(get_output_embeddings, model))
 | 
			
		||||
    setattr(model, "get_rope_index", get_rope_index_func(model))
 | 
			
		||||
    setattr(model, "create_or_update_model_card", MethodType(create_or_update_model_card, model))
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user