Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-11-04 09:52:14 +08:00)
[tests] add visual model save test (#8248)

Co-authored-by: Yaowei Zheng <hiyouga@buaa.edu.cn>

parent ed70f8d5a2
commit 212a8006dc
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import os
+
 import pytest
 import torch
 from transformers import AutoConfig, AutoModelForVision2Seq
@@ -76,3 +78,25 @@ def test_visual_lora(freeze_vision_tower: bool, freeze_language_model: bool):
     assert (visual_param_name in trainable_params) != freeze_vision_tower
     assert (language_param_name in trainable_params) != freeze_language_model
     assert (merger_param_name in trainable_params) is False
+
+
+def test_visual_model_save_load():
+    # check VLM's state dict: https://github.com/huggingface/transformers/pull/38385
+    model_args = ModelArguments(model_name_or_path="Qwen/Qwen2-VL-2B-Instruct")
+    finetuning_args = FinetuningArguments(finetuning_type="full")
+    config = AutoConfig.from_pretrained(model_args.model_name_or_path)
+    with torch.device("meta"):
+        model = AutoModelForVision2Seq.from_config(config)
+
+    model = init_adapter(config, model, model_args, finetuning_args, is_trainable=False)
+    loaded_model_weight = dict(model.named_parameters())
+
+    model.save_pretrained(os.path.join("output", "qwen2_vl"), max_shard_size="10GB", safe_serialization=False)
+    saved_model_weight = torch.load(os.path.join("output", "qwen2_vl", "pytorch_model.bin"), weights_only=False)
+
+    if is_transformers_version_greater_than("4.52.0"):
+        assert "model.language_model.layers.0.self_attn.q_proj.weight" in loaded_model_weight
+    else:
+        assert "model.layers.0.self_attn.q_proj.weight" in loaded_model_weight
+
+    assert "model.layers.0.self_attn.q_proj.weight" in saved_model_weight