update loader

This commit is contained in:
hiyouga
2023-12-24 19:10:23 +08:00
parent e44b82ee24
commit 6629087e12
6 changed files with 67 additions and 68 deletions

View File

@@ -3,14 +3,14 @@ import math
import torch
import random
from types import MethodType
from typing import TYPE_CHECKING, Any, Dict, List
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple
from datasets import load_dataset
from transformers import BitsAndBytesConfig, GPTQConfig, PreTrainedModel, PreTrainedTokenizerBase
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.utils.versions import require_version
from llmtuner.extras.constants import FILEEXT2TYPE
from llmtuner.extras.constants import FILEEXT2TYPE, LAYERNORM_NAMES
from llmtuner.extras.logging import get_logger
from llmtuner.extras.misc import get_current_device, infer_optim_dtype
from llmtuner.extras.packages import is_flash_attn2_available
@@ -180,6 +180,42 @@ def _configure_quantization(
logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
def _prepare_model_for_training(
model: "PreTrainedModel",
model_args: "ModelArguments",
output_layer_name: Optional[str] = "lm_head"
) -> None:
r"""
Includes:
(1) cast the layernorm in fp32
(2) make output embedding layer require grads
(3) add the upcasting of the lm_head in fp32
Inspired by: https://github.com/huggingface/peft/blob/v0.7.1/src/peft/utils/other.py#L72
"""
if model_args.upcast_layernorm:
for name, param in model.named_parameters():
if param.ndim == 1 and any(ln_name in name for ln_name in LAYERNORM_NAMES):
param.data = param.data.to(torch.float32)
logger.info("Upcasting layernorm weights in float32.")
if not model_args.disable_gradient_checkpointing:
if getattr(model, "supports_gradient_checkpointing", False):
logger.warning("Current model does not support gradient checkpointing.")
else:
model.enable_input_require_grads()
model.gradient_checkpointing_enable()
model.config.use_cache = False # turn off when gradient checkpointing is enabled
logger.info("Gradient checkpointing enabled.")
if hasattr(model, output_layer_name):
def fp32_forward_post_hook(module: torch.nn.Module, args: Tuple[torch.Tensor], output: torch.Tensor):
return output.to(torch.float32)
output_layer = getattr(model, output_layer_name)
if isinstance(output_layer, torch.nn.Linear) and output_layer.weight.dtype != torch.float32:
output_layer.register_forward_hook(fp32_forward_post_hook)
def patch_tokenizer(tokenizer: "PreTrainedTokenizer") -> None:
if "PreTrainedTokenizerBase" not in str(tokenizer._pad.__func__):
tokenizer._pad = MethodType(PreTrainedTokenizerBase._pad, tokenizer)
@@ -206,7 +242,12 @@ def patch_config(
_configure_quantization(config, tokenizer, model_args, config_kwargs)
def patch_model(model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments") -> None:
def patch_model(
model: "PreTrainedModel",
tokenizer: "PreTrainedTokenizer",
model_args: "ModelArguments",
is_trainable: bool
) -> None:
if "GenerationMixin" not in str(model.generate.__func__):
model.generate = MethodType(PreTrainedModel.generate, model)
@@ -220,6 +261,10 @@ def patch_model(model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer", mode
_resize_embedding_layer(model, tokenizer)
if is_trainable:
_prepare_model_for_training(model, model_args)
def patch_valuehead_model(model: "AutoModelForCausalLMWithValueHead") -> None:
def tie_weights(self: "AutoModelForCausalLMWithValueHead") -> None:
if isinstance(self.pretrained_model, PreTrainedModel):