From 5d440f978e96f4b9b50f81e6b72e3e8fd186a6ef Mon Sep 17 00:00:00 2001
From: hiyouga
Date: Sat, 23 Dec 2023 15:24:27 +0800
Subject: [PATCH] update patcher

Former-commit-id: e44b82ee245a7ee99057c7b58b1edef5c222dc1f
---
 README.md                          |   2 +-
 README_zh.md                       |   2 +-
 src/llmtuner/hparams/model_args.py |   4 +
 src/llmtuner/model/loader.py       |   8 +-
 src/llmtuner/model/patcher.py      | 222 +++++++++++++++++------
 src/llmtuner/model/utils.py        |  27 ----
 6 files changed, 135 insertions(+), 130 deletions(-)

diff --git a/README.md b/README.md
index fff29b6a..3cb26ae8 100644
--- a/README.md
+++ b/README.md
@@ -214,7 +214,7 @@ huggingface-cli login
 
 | Method | Bits | 7B | 13B | 30B | 65B | 8x7B |
 | ------ | ---- | ----- | ----- | ----- | ------ | ------ |
-| Full | 16 | 160GB | 320GB | 600GB | 1200GB | 1000GB |
+| Full | 16 | 160GB | 320GB | 600GB | 1200GB | 900GB |
 | Freeze | 16 | 20GB | 40GB | 120GB | 240GB | 200GB |
 | LoRA | 16 | 16GB | 32GB | 80GB | 160GB | 120GB |
 | QLoRA | 8 | 10GB | 16GB | 40GB | 80GB | 80GB |
diff --git a/README_zh.md b/README_zh.md
index 5e39e5f4..ac47fbec 100644
--- a/README_zh.md
+++ b/README_zh.md
@@ -214,7 +214,7 @@ huggingface-cli login
 
 | 训练方法 | 精度 | 7B | 13B | 30B | 65B | 8x7B |
 | ------- | ---- | ----- | ----- | ----- | ------ | ------ |
-| 全参数 | 16 | 160GB | 320GB | 600GB | 1200GB | 1000GB |
+| 全参数 | 16 | 160GB | 320GB | 600GB | 1200GB | 900GB |
 | 部分参数 | 16 | 20GB | 40GB | 120GB | 240GB | 200GB |
 | LoRA | 16 | 16GB | 32GB | 80GB | 160GB | 120GB |
 | QLoRA | 8 | 10GB | 16GB | 40GB | 80GB | 80GB |
diff --git a/src/llmtuner/hparams/model_args.py b/src/llmtuner/hparams/model_args.py
index 24ca5dc1..4c999a01 100644
--- a/src/llmtuner/hparams/model_args.py
+++ b/src/llmtuner/hparams/model_args.py
@@ -22,6 +22,10 @@ class ModelArguments:
         default=False,
         metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}
     )
+    resize_vocab: Optional[bool] = field(
+        default=False,
+        metadata={"help": "Whether to resize the tokenizer vocab and the embedding layers."}
+    )
     split_special_tokens: Optional[bool] = field(
         default=False,
         metadata={"help": "Whether or not the special tokens should be split during the tokenization process."}
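The new resize_vocab option is an ordinary ModelArguments field, so it behaves like the existing flags: it can be set in code or parsed from the command line. A minimal sketch, assuming llmtuner is importable and using a placeholder model path:

    from transformers import HfArgumentParser

    from llmtuner.hparams.model_args import ModelArguments

    # HfArgumentParser derives a --resize_vocab flag from the new dataclass field.
    parser = HfArgumentParser(ModelArguments)
    (model_args,) = parser.parse_args_into_dataclasses(
        args=["--model_name_or_path", "meta-llama/Llama-2-7b-hf", "--resize_vocab", "True"]
    )
    assert model_args.resize_vocab is True  # stays False when the flag is omitted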
diff --git a/src/llmtuner/model/loader.py b/src/llmtuner/model/loader.py
index 4ab707a6..6d96b674 100644
--- a/src/llmtuner/model/loader.py
+++ b/src/llmtuner/model/loader.py
@@ -8,9 +8,7 @@ from llmtuner.extras.logging import get_logger
 from llmtuner.extras.misc import count_parameters, get_current_device, try_download_model_from_ms
 from llmtuner.model.adapter import init_adapter
 from llmtuner.model.patcher import patch_config, patch_tokenizer, patch_model, patch_valuehead_model
-from llmtuner.model.utils import (
-    load_valuehead_params, prepare_model_for_training, resize_embedding_layer, register_autoclass
-)
+from llmtuner.model.utils import load_valuehead_params, prepare_model_for_training, register_autoclass
 
 if TYPE_CHECKING:
     from transformers import PreTrainedModel, PreTrainedTokenizer
@@ -94,10 +92,8 @@ def load_model_and_tokenizer(
     )
 
     model = model.to(model_args.compute_dtype) if not getattr(model, "quantization_method", None) else model
-    patch_model(model)
+    patch_model(model, tokenizer, model_args)
     register_autoclass(config, model, tokenizer)
-    if not is_deepspeed_zero3_enabled():
-        resize_embedding_layer(model, tokenizer)
 
     model = prepare_model_for_training(model=model, finetuning_args=finetuning_args) if is_trainable else model
     model = init_adapter(model, model_args, finetuning_args, is_trainable)
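After this change the loader delegates embedding resizing to patch_model, and resize_embedding_layer is no longer part of the public utils surface. A quick check of the resulting module layout, assuming llmtuner and its dependencies are installed at this revision:

    import llmtuner.model.utils as utils
    from llmtuner.model.patcher import patch_model  # now takes (model, tokenizer, model_args)

    # The helper moved into patcher.py as a private function, so utils no longer exports it.
    assert not hasattr(utils, "resize_embedding_layer")
    assert hasattr(utils, "register_autoclass") and hasattr(utils, "prepare_model_for_training")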
""" - if getattr(config, "quantization_config", None): # gptq or awq - if is_deepspeed_zero3_enabled(): - raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.") + current_embedding_size = model.get_input_embeddings().weight.size(0) + if len(tokenizer) > current_embedding_size: + if not isinstance(model.get_output_embeddings(), torch.nn.Linear): + logger.warning("Current model does not support resizing token embeddings.") + return - config_kwargs["device_map"] = {"": get_current_device()} - quantization_config = getattr(config, "quantization_config", None) - logger.info("Loading {}-bit pre-quantized model.".format(quantization_config.get("bits", -1))) + model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=64) + new_embedding_size = model.get_input_embeddings().weight.size(0) + num_new_tokens = new_embedding_size - current_embedding_size + _noisy_mean_initialization(model.get_input_embeddings().weight.data, num_new_tokens) + _noisy_mean_initialization(model.get_output_embeddings().weight.data, num_new_tokens) - elif model_args.export_quantization_bit is not None: # gptq - require_version("optimum>=1.16.0", "To fix: pip install optimum>=1.16.0") - require_version("auto_gptq>=0.5.0", "To fix: pip install auto_gptq>=0.5.0") - from accelerate.utils import get_max_memory - - if getattr(config, "model_type", None) == "chatglm": - raise ValueError("ChatGLM model is not supported.") - - config_kwargs["quantization_config"] = GPTQConfig( - bits=model_args.export_quantization_bit, - tokenizer=tokenizer, - dataset=get_quantization_dataset(tokenizer, model_args) - ) - config_kwargs["device_map"] = "auto" - config_kwargs["max_memory"] = get_max_memory() - logger.info("Quantizing model to {} bit.".format(model_args.export_quantization_bit)) - - elif model_args.quantization_bit is not None: # bnb - if is_deepspeed_zero3_enabled(): - raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.") - - if model_args.quantization_bit == 8: - require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0") - config_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True) - - elif model_args.quantization_bit == 4: - require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0") - config_kwargs["quantization_config"] = BitsAndBytesConfig( - load_in_4bit=True, - bnb_4bit_compute_dtype=model_args.compute_dtype, - bnb_4bit_use_double_quant=model_args.double_quantization, - bnb_4bit_quant_type=model_args.quantization_type - ) - - config_kwargs["device_map"] = {"": get_current_device()} - logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit)) + logger.info("Resized token embeddings from {} to {}.".format(current_embedding_size, new_embedding_size)) -def get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments") -> List[str]: +def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments") -> List[str]: r""" Inspired by: https://github.com/huggingface/optimum/blob/v1.16.0/optimum/gptq/data.py#L133 TODO: remove tokenizer.decode() https://github.com/huggingface/optimum/pull/1600 @@ -153,7 +82,105 @@ def get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "Mode return samples -def patch_tokenizer(tokenizer: "PreTrainedTokenizer"): +def _configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None: + if model_args.rope_scaling is not None: + if not hasattr(config, "rope_scaling"): + 
logger.warning("Current model does not support RoPE scaling.") + else: + if is_trainable: + if model_args.rope_scaling == "dynamic": + logger.warning( + "Dynamic NTK scaling may not work well with fine-tuning. " + "See: https://github.com/huggingface/transformers/pull/24653" + ) + + current_max_length = getattr(config, "max_position_embeddings", None) + if current_max_length and model_args.model_max_length > current_max_length: + scaling_factor = float(math.ceil(model_args.model_max_length / current_max_length)) + else: + logger.warning("Input length is smaller than max length. Consider increase input length.") + scaling_factor = 1.0 + else: + scaling_factor = 2.0 + + setattr(config, "rope_scaling", {"type": model_args.rope_scaling, "factor": scaling_factor}) + logger.info("Using {} scaling strategy and setting scaling factor to {}".format( + model_args.rope_scaling, scaling_factor + )) + + +def _configure_flashattn(model_args: "ModelArguments", config_kwargs: Dict[str, Any]) -> None: + if model_args.flash_attn and is_flash_attn2_available(): + config_kwargs["use_flash_attention_2"] = True + config_kwargs["torch_dtype"] = model_args.compute_dtype + logger.info("Using FlashAttention-2 for faster training and inference.") + + +def _configure_longlora(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None: + if is_trainable and model_args.shift_attn: + if getattr(config, "model_type", None) in SUPPORTED_CLASS_FOR_S2ATTN: + setattr(config, "group_size_ratio", 0.25) + logger.info("Using shift short attention with group_size_ratio=1/4.") + else: + logger.warning("Current model does not support shift short attention.") + + +def _configure_quantization( + config: "PretrainedConfig", + tokenizer: "PreTrainedTokenizer", + model_args: "ModelArguments", + config_kwargs: Dict[str, Any] +) -> None: + r""" + Priority: Pre-quantized (training) > AutoGPTQ (export) > Bitsandbytes (training) + """ + if getattr(config, "quantization_config", None): # gptq or awq + if is_deepspeed_zero3_enabled(): + raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.") + + config_kwargs["device_map"] = {"": get_current_device()} + quantization_config = getattr(config, "quantization_config", None) + logger.info("Loading {}-bit pre-quantized model.".format(quantization_config.get("bits", -1))) + + elif model_args.export_quantization_bit is not None: # gptq + require_version("optimum>=1.16.0", "To fix: pip install optimum>=1.16.0") + require_version("auto_gptq>=0.5.0", "To fix: pip install auto_gptq>=0.5.0") + from accelerate.utils import get_max_memory + + if getattr(config, "model_type", None) == "chatglm": + raise ValueError("ChatGLM model is not supported.") + + config_kwargs["quantization_config"] = GPTQConfig( + bits=model_args.export_quantization_bit, + tokenizer=tokenizer, + dataset=_get_quantization_dataset(tokenizer, model_args) + ) + config_kwargs["device_map"] = "auto" + config_kwargs["max_memory"] = get_max_memory() + logger.info("Quantizing model to {} bit.".format(model_args.export_quantization_bit)) + + elif model_args.quantization_bit is not None: # bnb + if is_deepspeed_zero3_enabled(): + raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.") + + if model_args.quantization_bit == 8: + require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0") + config_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True) + + elif model_args.quantization_bit == 4: + require_version("bitsandbytes>=0.39.0", "To fix: pip 
@@ -164,7 +191,7 @@ def patch_config(
     model_args: "ModelArguments",
     config_kwargs: Dict[str, Any],
     is_trainable: bool
-):
+) -> None:
     if model_args.compute_dtype is None:  # priority: bf16 > fp16 > fp32
         model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))
     setattr(config, "torch_dtype", model_args.compute_dtype)
@@ -179,7 +206,7 @@ def patch_config(
     _configure_quantization(config, tokenizer, model_args, config_kwargs)
 
 
-def patch_model(model: "PreTrainedModel"):
+def patch_model(model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments") -> None:
     if "GenerationMixin" not in str(model.generate.__func__):
         model.generate = MethodType(PreTrainedModel.generate, model)
 
@@ -187,8 +214,13 @@
         setattr(model, "lm_head", model.transformer.output_layer)
         setattr(model, "_keys_to_ignore_on_save", ["lm_head.weight"])
 
+    if model_args.resize_vocab:
+        if is_deepspeed_zero3_enabled():
+            raise ValueError("DeepSpeed ZeRO-3 is incompatible with vocab resizing.")
 
-def patch_valuehead_model(model: "AutoModelForCausalLMWithValueHead"):
+        _resize_embedding_layer(model, tokenizer)
+
+def patch_valuehead_model(model: "AutoModelForCausalLMWithValueHead") -> None:
     def tie_weights(self: "AutoModelForCausalLMWithValueHead") -> None:
         if isinstance(self.pretrained_model, PreTrainedModel):
             self.pretrained_model.tie_weights()
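The effect of the new resize_vocab branch in patch_model can be reproduced with plain transformers: after tokens are added, the embedding matrix grows to the tokenizer size padded up to a multiple of 64, and the output projection grows with it. A hedged sketch using a tiny public checkpoint (the model id and the added token are only examples):

    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_id = "sshleifer/tiny-gpt2"  # any small causal LM with an nn.Linear output head works
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id)

    tokenizer.add_special_tokens({"additional_special_tokens": ["<|tool|>"]})

    old_size = model.get_input_embeddings().weight.size(0)
    # Same call the patcher makes: pad the new vocab size up to a multiple of 64.
    model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=64)
    new_size = model.get_input_embeddings().weight.size(0)
    print(old_size, "->", new_size)  # new_size >= len(tokenizer) and new_size % 64 == 0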
- """ - current_embedding_size = model.get_input_embeddings().weight.size(0) - if len(tokenizer) > current_embedding_size: - if not isinstance(model.get_output_embeddings(), torch.nn.Linear): - logger.warning("Current model does not support resizing token embeddings.") - return - - model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=64) - new_embedding_size = model.get_input_embeddings().weight.size(0) - num_new_tokens = new_embedding_size - current_embedding_size - noisy_mean_initialization(model.get_input_embeddings().weight.data, num_new_tokens) - noisy_mean_initialization(model.get_output_embeddings().weight.data, num_new_tokens) - - logger.info("Resized token embeddings from {} to {}.".format(current_embedding_size, new_embedding_size)) - - def register_autoclass(config: "PretrainedConfig", model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer"): if "AutoConfig" in getattr(config, "auto_map", {}): config.__class__.register_for_auto_class()