mirror of https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-04 12:42:51 +08:00

Update utils.py

Former-commit-id: 38a56706e0f52297501d351d38b51bee73e881dc
This commit is contained in:
parent 48fb0be1b9
commit b92f690190
@@ -1,5 +1,6 @@
 from enum import Enum, unique
-from typing import TYPE_CHECKING, Dict, List
+from functools import partial
+from typing import TYPE_CHECKING, Any, Dict, List, Optional
 
 import torch
 from transformers import PreTrainedModel
@@ -100,6 +101,37 @@ def find_expanded_modules(model: "PreTrainedModel", target_modules: List[str], n
     return module_names
 
 
+def gradient_checkpointing_enable(
+    self: "PreTrainedModel", gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None
+) -> None:
+    r"""
+    Activates gradient checkpointing for the current model.
+
+    Modification of the original method to enable gradient checkpointing for block-wise optimizer.
+    """
+    from torch.utils.checkpoint import checkpoint
+
+    if not self.supports_gradient_checkpointing:
+        raise ValueError("{} does not support gradient checkpointing.".format(self.__class__.__name__))
+
+    if gradient_checkpointing_kwargs is None:
+        gradient_checkpointing_kwargs = {"use_reentrant": True}
+
+    gradient_checkpointing_func = partial(checkpoint, **gradient_checkpointing_kwargs)
+
+    def custom_gradient_checkpointing_func(func, *args, **kwargs):
+        module: "torch.nn.Module" = func.__self__
+
+        if any(param.requires_grad for param in module.parameters()):
+            for arg in args:
+                if torch.is_tensor(arg) and torch.is_floating_point(arg):
+                    arg.requires_grad_(True)
+
+        return gradient_checkpointing_func(func, *args, **kwargs)
+
+    self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=custom_gradient_checkpointing_func)
+
+
 def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") -> Dict[str, torch.Tensor]:
     r"""
     Loads value head parameters from Hugging Face Hub or local disk.
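The new helper takes self explicitly, so it is written to be bound onto a model instance rather than called as a free function; the call site is not part of this diff. Below is a minimal sketch of how such a binding could look, assuming the helper is importable and using an arbitrary illustrative checkpoint ("gpt2"):

    from types import MethodType

    from transformers import AutoModelForCausalLM

    # gradient_checkpointing_enable is the module-level helper added in this commit;
    # the model name is only an example, any model supporting gradient checkpointing works.
    model = AutoModelForCausalLM.from_pretrained("gpt2")
    model.gradient_checkpointing_enable = MethodType(gradient_checkpointing_enable, model)

    # The usual Hugging Face call now goes through the patched method.
    model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": True})

Binding this way keeps the standard gradient_checkpointing_enable API surface unchanged while swapping in the wrapper that re-enables grads on checkpointed inputs.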
@@ -135,39 +167,3 @@ def register_autoclass(config: "PretrainedConfig", model: "PreTrainedModel", tok
         model.__class__.register_for_auto_class()
     if "AutoTokenizer" in tokenizer.init_kwargs.get("auto_map", {}):
         tokenizer.__class__.register_for_auto_class()
-
-def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
-    """
-    Modification of the original method to enable gradient checkpointing for block-wise optimizer.
-
-    Activates gradient checkpointing for the current model.
-
-    We pass the `__call__` method of the modules instead of `forward` because `__call__` attaches all the hooks of
-    the module. https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2
-
-    Args:
-        gradient_checkpointing_kwargs (dict, *optional*):
-            Additional keyword arguments passed along to the `torch.utils.checkpoint.checkpoint` function.
-    """
-    from torch.utils.checkpoint import checkpoint
-    import functools
-
-    if not self.supports_gradient_checkpointing:
-        raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
-
-    if gradient_checkpointing_kwargs is None:
-        gradient_checkpointing_kwargs = {"use_reentrant": True}
-
-    checkpoint = functools.partial(checkpoint, **gradient_checkpointing_kwargs)
-
-    def gradient_checkpointing_func(func, *args, **kwargs):
-        module = func.__self__
-
-        if any(p.requires_grad for p in module.parameters()):
-            for arg in args:
-                if torch.is_tensor(arg) and torch.is_floating_point(arg):
-                    arg.requires_grad_(True)
-
-        return checkpoint(func, *args, **kwargs)
-
-    self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=gradient_checkpointing_func)
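Why both the removed and the re-added helper force requires_grad_(True) on floating-point inputs: with reentrant checkpointing, a checkpointed segment whose inputs do not require grad never gets a backward pass, so its own trainable parameters receive no gradients; this matters once earlier blocks are frozen, as a block-wise optimizer does. The following is a standalone illustration of that PyTorch behavior with made-up modules, not code from this repository:

    from functools import partial

    import torch
    from torch import nn
    from torch.utils.checkpoint import checkpoint

    frozen = nn.Linear(4, 4).requires_grad_(False)   # stands in for frozen earlier layers
    block = nn.Linear(4, 4)                          # the currently trainable block
    checkpoint_fn = partial(checkpoint, use_reentrant=True)

    hidden = frozen(torch.randn(2, 4))               # requires_grad is False at this point
    if any(param.requires_grad for param in block.parameters()):
        hidden.requires_grad_(True)                  # the trick applied by the wrapper above

    # Passing __call__ (a bound method) keeps module hooks and is what makes
    # func.__self__ work inside the wrapper.
    out = checkpoint_fn(block.__call__, hidden)
    out.sum().backward()
    print(block.weight.grad is not None)             # True; without the requires_grad_ call the segment is cut off from autograd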