diff --git a/src/llamafactory/model/model_utils/checkpointing.py b/src/llamafactory/model/model_utils/checkpointing.py
index 337b137a..3130d6d2 100644
--- a/src/llamafactory/model/model_utils/checkpointing.py
+++ b/src/llamafactory/model/model_utils/checkpointing.py
@@ -19,7 +19,7 @@
 # limitations under the License.
 
 import inspect
-from functools import partial
+from functools import partial, wraps
 from types import MethodType
 from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
 
@@ -73,6 +73,25 @@ class UnslothGradientCheckpointing(torch.autograd.Function):
         return (None, hidden_states.grad) + (None,) * len(ctx.args)
 
 
+def get_custom_gradient_checkpointing_func(gradient_checkpointing_func):
+    r"""
+    Only applies gradient checkpointing to trainable layers.
+    """
+
+    @wraps(gradient_checkpointing_func)
+    def custom_gradient_checkpointing_func(func, *args: Union["torch.Tensor", Any], **kwargs):
+        module: "torch.nn.Module" = func.__self__
+
+        if any(param.requires_grad for param in module.parameters()):
+            for arg in args:
+                if torch.is_tensor(arg) and torch.is_floating_point(arg):
+                    arg.requires_grad_(True)
+
+        return gradient_checkpointing_func(func, *args, **kwargs)
+
+    return custom_gradient_checkpointing_func
+
+
 def _gradient_checkpointing_enable(
     self: "PreTrainedModel",
     gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None,
@@ -96,22 +115,13 @@ def _gradient_checkpointing_enable(
     else:
         gradient_checkpointing_func = partial(checkpoint, **gradient_checkpointing_kwargs)
 
-    def custom_gradient_checkpointing_func(func, *args: Union["torch.Tensor", Any], **kwargs):
-        module: "torch.nn.Module" = func.__self__
-
-        if any(param.requires_grad for param in module.parameters()):
-            for arg in args:
-                if torch.is_tensor(arg) and torch.is_floating_point(arg):
-                    arg.requires_grad_(True)
-
-        return gradient_checkpointing_func(func, *args, **kwargs)
-
+    gradient_checkpointing_func = get_custom_gradient_checkpointing_func(gradient_checkpointing_func)
     if "value" in inspect.signature(self._set_gradient_checkpointing).parameters:  # old GC format
         self.apply(partial(self._set_gradient_checkpointing, value=True))
         self.enable_input_require_grads()
         logger.warning("You are using the old GC format, some features (e.g. BAdam) will be invalid.")
     else:  # have already enabled input require gradients
-        self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=custom_gradient_checkpointing_func)
+        self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=gradient_checkpointing_func)
 
 
 def _fp32_forward_post_hook(
diff --git a/tests/model/model_utils/test_checkpointing.py b/tests/model/model_utils/test_checkpointing.py
index 23ada691..ac500df2 100644
--- a/tests/model/model_utils/test_checkpointing.py
+++ b/tests/model/model_utils/test_checkpointing.py
@@ -51,6 +51,12 @@ def test_checkpointing_disable():
         assert getattr(module, "gradient_checkpointing") is False
 
 
+def test_unsloth_gradient_checkpointing():
+    model = load_train_model(use_unsloth_gc=True, **TRAIN_ARGS)
+    for module in filter(lambda m: hasattr(m, "gradient_checkpointing"), model.modules()):
+        assert module._gradient_checkpointing_func.__wrapped__.__self__.__name__ == "UnslothGradientCheckpointing"
+
+
 def test_upcast_layernorm():
     model = load_train_model(upcast_layernorm=True, **TRAIN_ARGS)
     for name, param in model.named_parameters():
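
For reviewers who want to poke at the idea outside the repo, here is a minimal, self-contained sketch of what the new helper does: it wraps a gradient-checkpointing function so that floating-point inputs are marked as requiring gradients only when the owning module actually has trainable parameters, and `functools.wraps` exposes the original function via `__wrapped__`, which is what the new unit test inspects. The `TinyBlock` module, the `use_reentrant=True` choice, and the `__main__` demo are illustrative assumptions for this sketch, not part of the patch.

```python
# Minimal sketch (not part of the patch): same wrapping idea, runnable standalone.
from functools import partial, wraps

import torch
from torch.utils.checkpoint import checkpoint


def get_custom_gradient_checkpointing_func(gradient_checkpointing_func):
    r"""Only marks inputs as requiring grad for modules that have trainable parameters."""

    @wraps(gradient_checkpointing_func)
    def custom_gradient_checkpointing_func(func, *args, **kwargs):
        module: "torch.nn.Module" = func.__self__  # bound method -> the owning module

        if any(param.requires_grad for param in module.parameters()):
            for arg in args:
                if torch.is_tensor(arg) and torch.is_floating_point(arg):
                    arg.requires_grad_(True)  # needed so the recomputed forward produces grads

        return gradient_checkpointing_func(func, *args, **kwargs)

    return custom_gradient_checkpointing_func


class TinyBlock(torch.nn.Module):  # hypothetical toy layer, only for this demo
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 8)

    def forward(self, x):
        return torch.relu(self.linear(x))


if __name__ == "__main__":
    block = TinyBlock()
    gc_func = get_custom_gradient_checkpointing_func(partial(checkpoint, use_reentrant=True))

    # functools.wraps sets __wrapped__, which the new test uses to reach the original function.
    assert gc_func.__wrapped__.func is checkpoint

    x = torch.randn(2, 8)  # requires_grad is False here on purpose
    y = gc_func(block.forward, x)  # the wrapper flips requires_grad because TinyBlock is trainable
    y.sum().backward()
    print(block.linear.weight.grad is not None)  # True: gradients reach the checkpointed layer
```

The design mirrors the patch's docstring: only modules with trainable parameters get their inputs promoted to `requires_grad`, so checkpointed recomputation still yields gradients for partially trainable (e.g. LoRA-style) setups while fully frozen blocks are left untouched.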