diff --git a/src/llmtuner/model/utils.py b/src/llmtuner/model/utils.py
index fd587efd..7e4430d1 100644
--- a/src/llmtuner/model/utils.py
+++ b/src/llmtuner/model/utils.py
@@ -1,5 +1,6 @@
 from enum import Enum, unique
-from typing import TYPE_CHECKING, Dict, List
+from functools import partial
+from typing import TYPE_CHECKING, Any, Dict, List, Optional
 
 import torch
 from transformers import PreTrainedModel
@@ -100,6 +101,37 @@ def find_expanded_modules(model: "PreTrainedModel", target_modules: List[str], n
     return module_names
 
 
+def gradient_checkpointing_enable(
+    self: "PreTrainedModel", gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None
+) -> None:
+    r"""
+    Activates gradient checkpointing for the current model.
+
+    Modification of the original method to enable gradient checkpointing for block-wise optimizer.
+    """
+    from torch.utils.checkpoint import checkpoint
+
+    if not self.supports_gradient_checkpointing:
+        raise ValueError("{} does not support gradient checkpointing.".format(self.__class__.__name__))
+
+    if gradient_checkpointing_kwargs is None:
+        gradient_checkpointing_kwargs = {"use_reentrant": True}
+
+    gradient_checkpointing_func = partial(checkpoint, **gradient_checkpointing_kwargs)
+
+    def custom_gradient_checkpointing_func(func, *args, **kwargs):
+        module: "torch.nn.Module" = func.__self__
+
+        if any(param.requires_grad for param in module.parameters()):
+            for arg in args:
+                if torch.is_tensor(arg) and torch.is_floating_point(arg):
+                    arg.requires_grad_(True)
+
+        return gradient_checkpointing_func(func, *args, **kwargs)
+
+    self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=custom_gradient_checkpointing_func)
+
+
 def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") -> Dict[str, torch.Tensor]:
     r"""
     Loads value head parameters from Hugging Face Hub or local disk.
@@ -135,39 +167,3 @@ def register_autoclass(config: "PretrainedConfig", model: "PreTrainedModel", tok
         model.__class__.register_for_auto_class()
     if "AutoTokenizer" in tokenizer.init_kwargs.get("auto_map", {}):
         tokenizer.__class__.register_for_auto_class()
-
-def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
-    """
-    Modification of the original method to enable gradient checkpointing for block-wise optimizer.
-
-    Activates gradient checkpointing for the current model.
-
-    We pass the `__call__` method of the modules instead of `forward` because `__call__` attaches all the hooks of
-    the module. https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2
-
-    Args:
-        gradient_checkpointing_kwargs (dict, *optional*):
-            Additional keyword arguments passed along to the `torch.utils.checkpoint.checkpoint` function.
-    """
-    from torch.utils.checkpoint import checkpoint
-    import functools
-
-    if not self.supports_gradient_checkpointing:
-        raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
-
-    if gradient_checkpointing_kwargs is None:
-        gradient_checkpointing_kwargs = {"use_reentrant": True}
-
-    checkpoint = functools.partial(checkpoint, **gradient_checkpointing_kwargs)
-
-    def gradient_checkpointing_func(func, *args, **kwargs):
-        module = func.__self__
-
-        if any(p.requires_grad for p in module.parameters()):
-            for arg in args:
-                if torch.is_tensor(arg) and torch.is_floating_point(arg):
-                    arg.requires_grad_(True)
-
-        return checkpoint(func, *args, **kwargs)
-
-    self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=gradient_checkpointing_func)
\ No newline at end of file