# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by HuggingFace's Transformers and PEFT libraries.
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/modeling_utils.py
# https://github.com/huggingface/peft/blob/v0.10.0/src/peft/utils/other.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
from functools import partial
from types import MethodType
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple

import torch

from ...extras.constants import LAYERNORM_NAMES
from ...extras.logging import get_logger


if TYPE_CHECKING:
    from transformers import PreTrainedModel

    from ...hparams import ModelArguments


logger = get_logger(__name__)


def _gradient_checkpointing_enable(
    self: "PreTrainedModel", gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None
) -> None:
    r"""
    Activates gradient checkpointing for the current model.

    Modification of the original method to enable gradient checkpointing for block-wise optimizers.
    """
    from torch.utils.checkpoint import checkpoint

    if not self.supports_gradient_checkpointing:
        raise ValueError("{} does not support gradient checkpointing.".format(self.__class__.__name__))

    if gradient_checkpointing_kwargs is None:
        gradient_checkpointing_kwargs = {"use_reentrant": True}

    gradient_checkpointing_func = partial(checkpoint, **gradient_checkpointing_kwargs)

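    # Commentary (added, not in the stock Transformers implementation): with frozen
    # layers in front (e.g. LoRA adapters or BAdam's block-wise updates), the hidden
    # states reaching a trainable block may not require grad, in which case
    # torch.utils.checkpoint would not propagate gradients through it. The wrapper
    # below therefore marks the floating-point tensor inputs of any module that
    # holds trainable parameters as requiring grad before delegating to the real
    # checkpoint function.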
    def custom_gradient_checkpointing_func(func, *args, **kwargs):
        module: "torch.nn.Module" = func.__self__

        if any(param.requires_grad for param in module.parameters()):
            for arg in args:
                if torch.is_tensor(arg) and torch.is_floating_point(arg):
                    arg.requires_grad_(True)

        return gradient_checkpointing_func(func, *args, **kwargs)

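    # Commentary (added): Transformers has exposed two signatures for this hook over
    # time. The legacy one, `_set_gradient_checkpointing(module, value)`, must be
    # applied recursively via `self.apply` and offers no way to swap the checkpoint
    # function; the newer one accepts `enable` plus a `gradient_checkpointing_func`,
    # which lets us inject the custom wrapper defined above.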
if "value" in inspect.signature(self._set_gradient_checkpointing).parameters: # old GC format
|
|
self.apply(partial(self._set_gradient_checkpointing, value=True))
|
|
self.enable_input_require_grads()
|
|
logger.warning("You are using the old GC format, some features (e.g. BAdam) will be invalid.")
|
|
else: # have already enabled input require gradients
|
|
self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=custom_gradient_checkpointing_func)
|
|
|
|
|
|
def _fp32_forward_post_hook(
    module: "torch.nn.Module", args: Tuple["torch.Tensor"], output: "torch.Tensor"
) -> "torch.Tensor":
    return output.to(torch.float32)

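# Commentary (added): the post-hook above casts only the module *output* to
# float32 while the weights stay in their original (possibly half) precision.
# Registered on the lm_head below, it yields fp32 logits so the subsequent
# cross-entropy loss is computed with better numerical stability.

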
def prepare_model_for_training(model: "PreTrainedModel", model_args: "ModelArguments") -> None:
    r"""
    Includes:
        (1) cast the layernorm in fp32
        (2) make output embedding layer require grads
        (3) add the upcasting of the lm_head in fp32
    """
    if model_args.upcast_layernorm:
        logger.info("Upcasting layernorm weights in float32.")
        for name, param in model.named_parameters():
            if param.ndim == 1 and any(ln_name in name for ln_name in LAYERNORM_NAMES):
                param.data = param.data.to(torch.float32)

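    # Commentary (added): the `ndim == 1` check above targets the 1-D weight and
    # bias vectors of LayerNorm/RMSNorm layers (matched by name via LAYERNORM_NAMES),
    # so only the norm parameters are upcast while the large 2-D weight matrices
    # keep their original dtype.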
    if not model_args.disable_gradient_checkpointing:
        if not getattr(model, "supports_gradient_checkpointing", False):
            logger.warning("Current model does not support gradient checkpointing.")
        else:
            # use_reentrant=False might increase VRAM usage (this has not been empirically verified yet)
            # according to: https://github.com/huggingface/transformers/issues/28339
            model.gradient_checkpointing_enable = MethodType(_gradient_checkpointing_enable, model)
            model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": True})
            setattr(model.config, "use_cache", False)  # turn off when gradient checkpointing is enabled
            logger.info("Gradient checkpointing enabled.")

    if model_args.upcast_lmhead_output:
        output_layer = model.get_output_embeddings()
        if isinstance(output_layer, torch.nn.Linear) and output_layer.weight.dtype != torch.float32:
            logger.info("Upcasting lm_head outputs in float32.")
            output_layer.register_forward_hook(_fp32_forward_post_hook)
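

if __name__ == "__main__":
    # Illustrative sketch (added, not part of the original module): exercise
    # `prepare_model_for_training` on a tiny, randomly initialized GPT-2 so no
    # download is needed. `SimpleNamespace` merely stands in for the real
    # `ModelArguments` dataclass, which has many more fields. Note that the
    # relative imports at the top of this file require running it as a module
    # inside the package (python -m ...), not as a standalone script.
    from types import SimpleNamespace

    from transformers import GPT2Config, GPT2LMHeadModel

    tiny_model = GPT2LMHeadModel(GPT2Config(n_layer=2, n_head=2, n_embd=64))
    stub_args = SimpleNamespace(
        upcast_layernorm=True,
        disable_gradient_checkpointing=False,
        upcast_lmhead_output=True,
    )
    prepare_model_for_training(tiny_model, stub_args)
    assert tiny_model.is_gradient_checkpointing  # the patched method took effect
    assert not tiny_model.config.use_cache  # cache is disabled alongside checkpointing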