Mirror of https://github.com/hiyouga/LLaMA-Factory.git
	add test case
Former-commit-id: c452d65e1551074dddd1d87517c0d44dc014c6aa
This commit is contained in:
parent 294a103ead
commit 158e0e1f63
@@ -19,7 +19,7 @@
 # limitations under the License.
 
 import inspect
-from functools import partial
+from functools import partial, wraps
 from types import MethodType
 from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
 
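The only change in this hunk is pulling in functools.wraps for the wrapper introduced below. A minimal sketch of what wraps buys here (illustrative only; the stand-in function is not part of the repository):

from functools import wraps

def original_checkpoint(func, *args, **kwargs):
    """Stand-in for torch.utils.checkpoint.checkpoint (assumption for this sketch)."""
    return func(*args, **kwargs)

@wraps(original_checkpoint)
def wrapper(func, *args, **kwargs):
    # The wrapper keeps the original callable's __name__ and __doc__ and gains
    # __wrapped__, so downstream introspection still sees the real checkpoint.
    return original_checkpoint(func, *args, **kwargs)

print(wrapper.__name__)                             # original_checkpoint
print(wrapper.__wrapped__ is original_checkpoint)   # True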
@@ -73,6 +73,25 @@ class UnslothGradientCheckpointing(torch.autograd.Function):
         return (None, hidden_states.grad) + (None,) * len(ctx.args)
 
 
+def get_custom_gradient_checkpointing_func(gradient_checkpointing_func):
+    r"""
+    Only applies gradient checkpointing to trainable layers.
+    """
+
+    @wraps(gradient_checkpointing_func)
+    def custom_gradient_checkpointing_func(func, *args: Union["torch.Tensor", Any], **kwargs):
+        module: "torch.nn.Module" = func.__self__
+
+        if any(param.requires_grad for param in module.parameters()):
+            for arg in args:
+                if torch.is_tensor(arg) and torch.is_floating_point(arg):
+                    arg.requires_grad_(True)
+
+        return gradient_checkpointing_func(func, *args, **kwargs)
+
+    return custom_gradient_checkpointing_func
+
+
 def _gradient_checkpointing_enable(
     self: "PreTrainedModel",
     gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None,
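The new helper decides, per layer, whether the checkpointed inputs should carry gradients: if the layer has no trainable parameters, the input flags are left untouched. A minimal sketch of that behaviour follows; the helper body is reproduced from the hunk above, and a pass-through stands in for torch.utils.checkpoint.checkpoint so only the requires_grad logic is exercised (everything else is an assumption of this sketch, not repository code):

import torch
from functools import wraps

def passthrough_checkpoint(func, *args, **kwargs):
    # Stand-in for torch.utils.checkpoint.checkpoint (assumption for this sketch).
    return func(*args, **kwargs)

def get_custom_gradient_checkpointing_func(gradient_checkpointing_func):
    @wraps(gradient_checkpointing_func)
    def custom_gradient_checkpointing_func(func, *args, **kwargs):
        module = func.__self__  # the bound layer whose forward is being checkpointed
        if any(param.requires_grad for param in module.parameters()):
            for arg in args:
                if torch.is_tensor(arg) and torch.is_floating_point(arg):
                    arg.requires_grad_(True)
        return gradient_checkpointing_func(func, *args, **kwargs)

    return custom_gradient_checkpointing_func

gc_func = get_custom_gradient_checkpointing_func(passthrough_checkpoint)

frozen, trainable = torch.nn.Linear(4, 4), torch.nn.Linear(4, 4)
frozen.requires_grad_(False)

x = torch.zeros(1, 4)
gc_func(frozen.forward, x)
print(x.requires_grad)   # False: frozen layer, input flag untouched
gc_func(trainable.forward, x)
print(x.requires_grad)   # True: trainable layer, float input flipped in place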
@@ -96,22 +115,13 @@ def _gradient_checkpointing_enable(
     else:
         gradient_checkpointing_func = partial(checkpoint, **gradient_checkpointing_kwargs)
 
-    def custom_gradient_checkpointing_func(func, *args: Union["torch.Tensor", Any], **kwargs):
-        module: "torch.nn.Module" = func.__self__
-
-        if any(param.requires_grad for param in module.parameters()):
-            for arg in args:
-                if torch.is_tensor(arg) and torch.is_floating_point(arg):
-                    arg.requires_grad_(True)
-
-        return gradient_checkpointing_func(func, *args, **kwargs)
-
+    gradient_checkpointing_func = get_custom_gradient_checkpointing_func(gradient_checkpointing_func)
     if "value" in inspect.signature(self._set_gradient_checkpointing).parameters:  # old GC format
         self.apply(partial(self._set_gradient_checkpointing, value=True))
         self.enable_input_require_grads()
         logger.warning("You are using the old GC format, some features (e.g. BAdam) will be invalid.")
     else:  # have already enabled input require gradients
-        self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=custom_gradient_checkpointing_func)
+        self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=gradient_checkpointing_func)
 
 
 def _fp32_forward_post_hook(
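The refactor keeps the existing branch on whether the model exposes the old `_set_gradient_checkpointing(value=...)` interface or the newer `enable=...` / `gradient_checkpointing_func=...` one. A small sketch of that signature probe; the two classes are hypothetical stand-ins for old- and new-style PreTrainedModel implementations:

import inspect

class OldStyleModel:
    def _set_gradient_checkpointing(self, module, value=False):
        pass

class NewStyleModel:
    def _set_gradient_checkpointing(self, enable=True, gradient_checkpointing_func=None):
        pass

for cls in (OldStyleModel, NewStyleModel):
    # Same probe as above: the old interface is recognised by its "value" parameter.
    is_old = "value" in inspect.signature(cls._set_gradient_checkpointing).parameters
    print(cls.__name__, "-> old GC format" if is_old else "-> new GC format")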
@@ -51,6 +51,12 @@ def test_checkpointing_disable():
         assert getattr(module, "gradient_checkpointing") is False
 
 
+def test_unsloth_gradient_checkpointing():
+    model = load_train_model(use_unsloth_gc=True, **TRAIN_ARGS)
+    for module in filter(lambda m: hasattr(m, "gradient_checkpointing"), model.modules()):
+        assert module._gradient_checkpointing_func.__self__.__name__ == "UnslothGradientCheckpointing"
+
+
 def test_upcast_layernorm():
     model = load_train_model(upcast_layernorm=True, **TRAIN_ARGS)
     for name, param in model.named_parameters():
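One way to run just the added test locally, assuming pytest is installed and the command is issued from the repository root (the `-k` expression matches the test name added above; the test file path itself is not shown in this diff):

import pytest

# Select the new test by name; -q keeps the output short.
raise SystemExit(pytest.main(["-k", "test_unsloth_gradient_checkpointing", "-q"]))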