From d24d2f04582fb5a4f1868c20b2eed223f5da0290 Mon Sep 17 00:00:00 2001
From: hiyouga
Date: Fri, 15 Dec 2023 21:46:40 +0800
Subject: [PATCH] add configurer

Former-commit-id: 2740aa9cbbcfc6dcfef82915b7db4e0f8b2c1bae
---
 src/llmtuner/model/loader.py                  | 63 ++++----------
 src/llmtuner/model/{patches.py => patcher.py} | 85 +++++++++++++------
 src/llmtuner/model/utils.py                   | 12 ++-
 3 files changed, 83 insertions(+), 77 deletions(-)
 rename src/llmtuner/model/{patches.py => patcher.py} (54%)

diff --git a/src/llmtuner/model/loader.py b/src/llmtuner/model/loader.py
index 15d4fe03..ffc03827 100644
--- a/src/llmtuner/model/loader.py
+++ b/src/llmtuner/model/loader.py
@@ -1,20 +1,20 @@
 from typing import TYPE_CHECKING, Optional, Tuple
 
-from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
+from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
 from transformers.integrations import is_deepspeed_zero3_enabled
 from transformers.utils.versions import require_version
 from trl import AutoModelForCausalLMWithValueHead
 
+import llmtuner.model.patcher as patcher
 from llmtuner.extras.logging import get_logger
-from llmtuner.extras.misc import count_parameters, get_current_device, try_download_model_from_ms
-from llmtuner.extras.packages import is_flash_attn2_available
-from llmtuner.hparams import FinetuningArguments
+from llmtuner.extras.misc import count_parameters, try_download_model_from_ms
 from llmtuner.model.adapter import init_adapter
-from llmtuner.model.patches import patch_config, patch_model, patch_valuehead_model, patch_tokenizer, register_autoclass
-from llmtuner.model.utils import load_valuehead_params, prepare_model_for_training, resize_embedding_layer
+from llmtuner.model.utils import (
+    load_valuehead_params, prepare_model_for_training, resize_embedding_layer, register_autoclass
+)
 
 if TYPE_CHECKING:
     from transformers import PreTrainedModel, PreTrainedTokenizer
-    from llmtuner.hparams import ModelArguments
+    from llmtuner.hparams import ModelArguments, FinetuningArguments
 
 logger = get_logger(__name__)
@@ -55,45 +55,15 @@ def load_model_and_tokenizer(
         padding_side="right", # training with left-padded tensors in fp16 precision may cause overflow
         **config_kwargs
     )
-    patch_tokenizer(tokenizer)
-
     config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs)
-    patch_config(config, model_args, is_trainable)
 
-    # Set FlashAttention-2
-    if model_args.flash_attn and is_flash_attn2_available():
-        config_kwargs["use_flash_attention_2"] = True
-        logger.info("Using FlashAttention-2 for faster training and inference.")
+    patcher.patch_tokenizer(tokenizer)
+    patcher.patch_config(config, model_args)
+    patcher.configure_rope(config, model_args, is_trainable)
+    patcher.configure_flashattn(config_kwargs, model_args)
+    patcher.configure_longlora(config, model_args, is_trainable)
+    patcher.configure_quantization(config, config_kwargs, model_args)
 
-    # Quantization configurations (using gptq or awq)
-    if getattr(config, "quantization_config", None):
-        model_args.quantization_bit = None # remove bnb quantization
-        config_kwargs["device_map"] = {"": get_current_device()}
-        quantization_config = getattr(config, "quantization_config", None)
-        logger.info("Loading {}-bit pre-quantized model.".format(quantization_config.get("bits", -1)))
-
-    # Quantization configurations (using bitsandbytes)
-    if model_args.quantization_bit is not None:
-        if is_deepspeed_zero3_enabled():
-            raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.")
-
-        if model_args.quantization_bit == 8:
-            require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
-            config_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
-
-        if model_args.quantization_bit == 4:
-            require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
-            config_kwargs["quantization_config"] = BitsAndBytesConfig(
-                load_in_4bit=True,
-                bnb_4bit_compute_dtype=model_args.compute_dtype,
-                bnb_4bit_use_double_quant=model_args.double_quantization,
-                bnb_4bit_quant_type=model_args.quantization_type
-            )
-
-        config_kwargs["device_map"] = {"": get_current_device()}
-        logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
-
-    # Load pre-trained models (without valuehead)
     model = AutoModelForCausalLM.from_pretrained(
         model_args.model_name_or_path,
         config=config,
@@ -101,23 +71,20 @@ def load_model_and_tokenizer(
         low_cpu_mem_usage=(not is_deepspeed_zero3_enabled()),
         **config_kwargs
     )
-    patch_model(model)
+    patcher.patch_model(model)
     register_autoclass(config, model, tokenizer)
     resize_embedding_layer(model, tokenizer)
 
-    # Initialize adapters
     model = prepare_model_for_training(model=model, finetuning_args=finetuning_args) if is_trainable else model
     model = init_adapter(model, model_args, finetuning_args, is_trainable)
 
-    # Prepare model with valuehead for RLHF
     if add_valuehead:
         model: "AutoModelForCausalLMWithValueHead" = AutoModelForCausalLMWithValueHead.from_pretrained(model)
-        patch_valuehead_model(model)
+        patcher.patch_valuehead_model(model)
         vhead_params = load_valuehead_params(model_args)
         if vhead_params is not None:
             model.load_state_dict(vhead_params, strict=False)
 
-    # Prepare model for inference
     if not is_trainable:
         model.requires_grad_(False) # fix all model params
         model = model.to(model_args.compute_dtype) if not getattr(model, "quantization_method", None) else model
diff --git a/src/llmtuner/model/patches.py b/src/llmtuner/model/patcher.py
similarity index 54%
rename from src/llmtuner/model/patches.py
rename to src/llmtuner/model/patcher.py
index 7c632ec2..e90976bd 100644
--- a/src/llmtuner/model/patches.py
+++ b/src/llmtuner/model/patcher.py
@@ -1,12 +1,15 @@
 import math
 import torch
 from types import MethodType
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any, Dict
 
-from transformers import PreTrainedModel, PreTrainedTokenizerBase
+from transformers import BitsAndBytesConfig, PreTrainedModel, PreTrainedTokenizerBase
+from transformers.integrations import is_deepspeed_zero3_enabled
+from transformers.utils.versions import require_version
 
 from llmtuner.extras.logging import get_logger
-from llmtuner.extras.misc import infer_optim_dtype
+from llmtuner.extras.misc import get_current_device, infer_optim_dtype
+from llmtuner.extras.packages import is_flash_attn2_available
 
 if TYPE_CHECKING:
     from transformers import PretrainedConfig, PreTrainedTokenizer
@@ -15,17 +18,53 @@ if TYPE_CHECKING:
 
 logger = get_logger(__name__)
 
+SUPPORTED_CLASS_FOR_S2ATTN = [] # TODO: add llama
 
-def patch_config(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool):
-    if model_args.compute_dtype is None: # priority: bf16 > fp16 > fp32
-        model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))
-    setattr(config, "torch_dtype", model_args.compute_dtype)
+def configure_flashattn(config_kwargs: Dict[str, Any], model_args: "ModelArguments"):
+    if model_args.flash_attn and is_flash_attn2_available():
+        config_kwargs["use_flash_attention_2"] = True
+        logger.info("Using FlashAttention-2 for faster training and inference.")
 
-    if getattr(config, "model_type", None) == "qwen":
-        for dtype_name, dtype in [("fp16", torch.float16), ("bf16", torch.bfloat16), ("fp32", torch.float32)]:
-            setattr(config, dtype_name, getattr(config, "torch_dtype", None) == dtype)
+def configure_longlora(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool):
+    if is_trainable and model_args.shift_attn:
+        if getattr(config, "model_type", None) in SUPPORTED_CLASS_FOR_S2ATTN:
+            setattr(config, "group_size_ratio", 0.25)
+            logger.info("Using shift short attention with group_size_ratio=1/4.")
+        else:
+            logger.warning("Current model does not support shift short attention.")
+
+
+def configure_quantization(config: "PretrainedConfig", config_kwargs: Dict[str, Any], model_args: "ModelArguments"):
+    if getattr(config, "quantization_config", None): # gptq or awq
+        model_args.quantization_bit = None # remove bnb quantization
+        config_kwargs["device_map"] = {"": get_current_device()}
+        quantization_config = getattr(config, "quantization_config", None)
+        logger.info("Loading {}-bit pre-quantized model.".format(quantization_config.get("bits", -1)))
+
+    if model_args.quantization_bit is not None: # bnb
+        if is_deepspeed_zero3_enabled():
+            raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.")
+
+        if model_args.quantization_bit == 8:
+            require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
+            config_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
+
+        if model_args.quantization_bit == 4:
+            require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
+            config_kwargs["quantization_config"] = BitsAndBytesConfig(
+                load_in_4bit=True,
+                bnb_4bit_compute_dtype=model_args.compute_dtype,
+                bnb_4bit_use_double_quant=model_args.double_quantization,
+                bnb_4bit_quant_type=model_args.quantization_type
+            )
+
+        config_kwargs["device_map"] = {"": get_current_device()}
+        logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
+
+
+def configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool):
     if model_args.rope_scaling is not None:
         if not hasattr(config, "rope_scaling"):
             logger.warning("Current model does not support RoPE scaling.")
@@ -51,14 +90,15 @@ def patch_config(config: "PretrainedConfig", model_args: "ModelArguments", is_tr
             model_args.rope_scaling, scaling_factor
         ))
 
-    # Set shift short attention (S^2-Attn)
-    if is_trainable and model_args.shift_attn:
-        logger.warning("Shift short attention is temporarily invalid due to breaking changes.")
-        # if getattr(config, "model_type", None) == "llama":
-        #     setattr(config, "group_size_ratio", 0.25)
-        #     logger.info("Using shift short attention with group_size_ratio=1/4.")
-        # else:
-        #     logger.warning("Current model does not support shift short attention.")
+
+def patch_config(config: "PretrainedConfig", model_args: "ModelArguments"):
+    if model_args.compute_dtype is None: # priority: bf16 > fp16 > fp32
+        model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))
+    setattr(config, "torch_dtype", model_args.compute_dtype)
+
+    if getattr(config, "model_type", None) == "qwen":
+        for dtype_name, dtype in [("fp16", torch.float16), ("bf16", torch.bfloat16), ("fp32", torch.float32)]:
+            setattr(config, dtype_name, getattr(config, "torch_dtype", None) == dtype)
 
 
 def patch_model(model: "PreTrainedModel"):
@@ -83,12 +123,3 @@ def patch_valuehead_model(model: "AutoModelForCausalLMWithValueHead"):
"AutoModelForCausalLMWithValueHead"): def patch_tokenizer(tokenizer: "PreTrainedTokenizer"): if "PreTrainedTokenizerBase" not in str(tokenizer._pad.__func__): tokenizer._pad = MethodType(PreTrainedTokenizerBase._pad, tokenizer) - - -def register_autoclass(config: "PretrainedConfig", model: "PreTrainedModel", tokenizer: "PreTrainedTokenizerBase"): - if "AutoConfig" in getattr(config, "auto_map", {}): - config.__class__.register_for_auto_class() - if "AutoModelForCausalLM" in getattr(config, "auto_map", {}): - model.__class__.register_for_auto_class() - if "AutoTokenizer" in tokenizer.init_kwargs.get("auto_map", {}): - tokenizer.__class__.register_for_auto_class() diff --git a/src/llmtuner/model/utils.py b/src/llmtuner/model/utils.py index d173c0d4..b173e1d9 100644 --- a/src/llmtuner/model/utils.py +++ b/src/llmtuner/model/utils.py @@ -9,8 +9,7 @@ from llmtuner.extras.logging import get_logger from llmtuner.hparams import ModelArguments, FinetuningArguments if TYPE_CHECKING: - from transformers.modeling_utils import PreTrainedModel - from transformers.tokenization_utils import PreTrainedTokenizer + from transformers import PretrainedConfig, PreTrainedModel, PreTrainedTokenizer from llmtuner.hparams import DataArguments @@ -183,3 +182,12 @@ def resize_embedding_layer(model: "PreTrainedModel", tokenizer: "PreTrainedToken model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=64) new_embedding_size = model.get_input_embeddings().weight.size(0) logger.info("Resized token embeddings from {} to {}.".format(current_embedding_size, new_embedding_size)) + + +def register_autoclass(config: "PretrainedConfig", model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer"): + if "AutoConfig" in getattr(config, "auto_map", {}): + config.__class__.register_for_auto_class() + if "AutoModelForCausalLM" in getattr(config, "auto_map", {}): + model.__class__.register_for_auto_class() + if "AutoTokenizer" in tokenizer.init_kwargs.get("auto_map", {}): + tokenizer.__class__.register_for_auto_class()