mirror of
				https://github.com/hiyouga/LLaMA-Factory.git
				synced 2025-11-04 18:02:19 +08:00 
			
		
		
		
	fix test case
Former-commit-id: b075b2971c6acb2c6039b36420a296f1f4e1b91b
This commit is contained in:
		
							parent
							
								
									158e0e1f63
								
							
						
					
					
						commit
						b6810b209a
					
				@ -89,6 +89,9 @@ def get_custom_gradient_checkpointing_func(gradient_checkpointing_func):
 | 
			
		||||
 | 
			
		||||
        return gradient_checkpointing_func(func, *args, **kwargs)
 | 
			
		||||
 | 
			
		||||
    if hasattr(gradient_checkpointing_func, "__self__"):  # fix test case
 | 
			
		||||
        custom_gradient_checkpointing_func.__self__ = gradient_checkpointing_func.__self__
 | 
			
		||||
 | 
			
		||||
    return custom_gradient_checkpointing_func
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -54,7 +54,7 @@ def test_checkpointing_disable():
 | 
			
		||||
def test_unsloth_gradient_checkpointing():
 | 
			
		||||
    model = load_train_model(use_unsloth_gc=True, **TRAIN_ARGS)
 | 
			
		||||
    for module in filter(lambda m: hasattr(m, "gradient_checkpointing"), model.modules()):
 | 
			
		||||
        assert module._gradient_checkpointing_func.__self__.__name__ == "UnslothGradientCheckpointing"
 | 
			
		||||
        assert module._gradient_checkpointing_func.__self__.__name__ == "UnslothGradientCheckpointing"  # classmethod
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_upcast_layernorm():
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user