From 37e40563f13d2ba57c58049beb3484b8f726a498 Mon Sep 17 00:00:00 2001
From: hiyouga
Date: Thu, 7 Mar 2024 16:15:53 +0800
Subject: [PATCH] fix #2735

Former-commit-id: f74f804a715dfb16bf24a056bc95db6b102f9ed7
---
 src/llmtuner/hparams/model_args.py | 51 +++++++++++++++++-------------
 src/llmtuner/hparams/parser.py     |  9 +++++-
 src/llmtuner/model/loader.py       |  5 +--
 src/llmtuner/model/patcher.py      |  6 ++--
 4 files changed, 41 insertions(+), 30 deletions(-)

diff --git a/src/llmtuner/hparams/model_args.py b/src/llmtuner/hparams/model_args.py
index 573efb21..f3972f66 100644
--- a/src/llmtuner/hparams/model_args.py
+++ b/src/llmtuner/hparams/model_args.py
@@ -5,7 +5,7 @@ from typing import Any, Dict, Literal, Optional
 @dataclass
 class ModelArguments:
     r"""
-    Arguments pertaining to which model/config/tokenizer we are going to fine-tune.
+    Arguments pertaining to which model/config/tokenizer we are going to fine-tune or infer.
     """
 
     model_name_or_path: str = field(
@@ -21,31 +21,35 @@ class ModelArguments:
         default=None,
         metadata={"help": "Where to store the pre-trained models downloaded from huggingface.co or modelscope.cn."},
     )
-    use_fast_tokenizer: Optional[bool] = field(
+    use_fast_tokenizer: bool = field(
         default=False,
         metadata={"help": "Whether or not to use one of the fast tokenizer (backed by the tokenizers library)."},
     )
-    resize_vocab: Optional[bool] = field(
+    resize_vocab: bool = field(
         default=False,
         metadata={"help": "Whether or not to resize the tokenizer vocab and the embedding layers."},
     )
-    split_special_tokens: Optional[bool] = field(
+    split_special_tokens: bool = field(
         default=False,
         metadata={"help": "Whether or not the special tokens should be split during the tokenization process."},
     )
-    model_revision: Optional[str] = field(
+    model_revision: str = field(
         default="main",
         metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
     )
+    low_cpu_mem_usage: bool = field(
+        default=True,
+        metadata={"help": "Whether or not to use memory-efficient model loading."},
+    )
     quantization_bit: Optional[int] = field(
         default=None,
-        metadata={"help": "The number of bits to quantize the model."},
+        metadata={"help": "The number of bits to quantize the model using bitsandbytes."},
     )
-    quantization_type: Optional[Literal["fp4", "nf4"]] = field(
+    quantization_type: Literal["fp4", "nf4"] = field(
         default="nf4",
         metadata={"help": "Quantization data type to use in int4 training."},
     )
-    double_quantization: Optional[bool] = field(
+    double_quantization: bool = field(
         default=True,
         metadata={"help": "Whether or not to use double quantization in int4 training."},
     )
@@ -53,30 +57,34 @@ class ModelArguments:
         default=None,
         metadata={"help": "Which scaling strategy should be adopted for the RoPE embeddings."},
     )
-    flash_attn: Optional[bool] = field(
+    flash_attn: bool = field(
         default=False,
         metadata={"help": "Enable FlashAttention-2 for faster training."},
     )
-    shift_attn: Optional[bool] = field(
+    shift_attn: bool = field(
         default=False,
         metadata={"help": "Enable shift short attention (S^2-Attn) proposed by LongLoRA."},
     )
-    use_unsloth: Optional[bool] = field(
+    use_unsloth: bool = field(
         default=False,
         metadata={"help": "Whether or not to use unsloth's optimization for the LoRA training."},
    )
-    disable_gradient_checkpointing: Optional[bool] = field(
+    disable_gradient_checkpointing: bool = field(
         default=False,
         metadata={"help": "Whether or not to disable gradient checkpointing."},
     )
-    upcast_layernorm: Optional[bool] = field(
+    upcast_layernorm: bool = field(
         default=False,
         metadata={"help": "Whether or not to upcast the layernorm weights in fp32."},
     )
-    upcast_lmhead_output: Optional[bool] = field(
+    upcast_lmhead_output: bool = field(
         default=False,
         metadata={"help": "Whether or not to upcast the output of lm_head in fp32."},
     )
+    infer_backend: Literal["hf", "vllm"] = field(
+        default="hf",
+        metadata={"help": "Backend engine used at inference."},
+    )
     hf_hub_token: Optional[str] = field(
         default=None,
         metadata={"help": "Auth token to log in with Hugging Face Hub."},
     )
@@ -89,7 +97,7 @@ class ModelArguments:
         default=None,
         metadata={"help": "Path to the directory to save the exported model."},
     )
-    export_size: Optional[int] = field(
+    export_size: int = field(
         default=1,
         metadata={"help": "The file shard size (in GB) of the exported model."},
     )
@@ -101,15 +109,15 @@ class ModelArguments:
         default=None,
         metadata={"help": "Path to the dataset or dataset name to use in quantizing the exported model."},
     )
-    export_quantization_nsamples: Optional[int] = field(
+    export_quantization_nsamples: int = field(
         default=128,
         metadata={"help": "The number of samples used for quantization."},
     )
-    export_quantization_maxlen: Optional[int] = field(
+    export_quantization_maxlen: int = field(
         default=1024,
         metadata={"help": "The maximum length of the model inputs used for quantization."},
     )
-    export_legacy_format: Optional[bool] = field(
+    export_legacy_format: bool = field(
         default=False,
         metadata={"help": "Whether or not to save the `.bin` files instead of `.safetensors`."},
     )
@@ -117,16 +125,15 @@ class ModelArguments:
         default=None,
         metadata={"help": "The name of the repository if push the model to the Hugging Face hub."},
     )
-    print_param_status: Optional[bool] = field(
+    print_param_status: bool = field(
         default=False,
         metadata={"help": "For debugging purposes, print the status of the parameters in the model."},
     )
-    aqlm_optimization: Optional[bool] = field(
-        default=False, metadata={"help": "Whether or not to optimize the training performance of AQLM models."}
-    )
 
     def __post_init__(self):
+        self.aqlm_optimization = None
         self.compute_dtype = None
+        self.device_map = None
         self.model_max_length = None
 
         if self.split_special_tokens and self.use_fast_tokenizer:
diff --git a/src/llmtuner/hparams/parser.py b/src/llmtuner/hparams/parser.py
index f8a2aa4b..cad08b17 100644
--- a/src/llmtuner/hparams/parser.py
+++ b/src/llmtuner/hparams/parser.py
@@ -10,6 +10,7 @@ from transformers.trainer_utils import get_last_checkpoint
 
 from ..extras.logging import get_logger
 from ..extras.packages import is_unsloth_available
+from ..extras.misc import check_dependencies
 from .data_args import DataArguments
 from .evaluation_args import EvaluationArguments
 from .finetuning_args import FinetuningArguments
@@ -20,6 +21,9 @@ from .model_args import ModelArguments
 logger = get_logger(__name__)
 
 
+check_dependencies()
+
+
 _TRAIN_ARGS = [ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments]
 _TRAIN_CLS = Tuple[ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments]
 _INFER_ARGS = [ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
@@ -221,7 +225,7 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
             training_args.local_rank,
             training_args.device,
             training_args.n_gpu,
-            bool(training_args.local_rank != -1),
+            training_args.parallel_mode.value == "distributed",
             str(model_args.compute_dtype),
         )
     )
@@ -236,6 +240,8 @@ def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
 
     _set_transformers_logging()
     _verify_model_args(model_args, finetuning_args)
+    model_args.aqlm_optimization = False
+    model_args.device_map = "auto"
 
     if data_args.template is None:
         raise ValueError("Please specify which `template` to use.")
@@ -249,6 +255,7 @@ def get_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS:
     _set_transformers_logging()
     _verify_model_args(model_args, finetuning_args)
     model_args.aqlm_optimization = True
+    model_args.device_map = "auto"
 
     if data_args.template is None:
         raise ValueError("Please specify which `template` to use.")
diff --git a/src/llmtuner/model/loader.py b/src/llmtuner/model/loader.py
index 588b5012..e5b3bdd1 100644
--- a/src/llmtuner/model/loader.py
+++ b/src/llmtuner/model/loader.py
@@ -5,7 +5,7 @@ from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
 from trl import AutoModelForCausalLMWithValueHead
 
 from ..extras.logging import get_logger
-from ..extras.misc import check_dependencies, count_parameters, get_current_device, try_download_model_from_ms
+from ..extras.misc import count_parameters, get_current_device, try_download_model_from_ms
 from .adapter import init_adapter
 from .patcher import patch_config, patch_model, patch_tokenizer, patch_valuehead_model
 from .utils import load_valuehead_params, register_autoclass
@@ -20,9 +20,6 @@ if TYPE_CHECKING:
 logger = get_logger(__name__)
 
 
-check_dependencies()
-
-
 def _get_init_kwargs(model_args: "ModelArguments") -> Dict[str, Any]:
     return {
         "trust_remote_code": True,
diff --git a/src/llmtuner/model/patcher.py b/src/llmtuner/model/patcher.py
index 81097257..4ecfcc86 100644
--- a/src/llmtuner/model/patcher.py
+++ b/src/llmtuner/model/patcher.py
@@ -286,9 +286,9 @@ def patch_config(
 
     init_kwargs["torch_dtype"] = model_args.compute_dtype
     if not is_deepspeed_zero3_enabled():
-        init_kwargs["low_cpu_mem_usage"] = True
-        if "device_map" not in init_kwargs:
-            init_kwargs["device_map"] = {"": get_current_device()} if is_trainable else "auto"
+        init_kwargs["low_cpu_mem_usage"] = model_args.low_cpu_mem_usage
+        if "device_map" not in init_kwargs:  # quant models cannot use auto device map
+            init_kwargs["device_map"] = model_args.device_map or {"": get_current_device()}
 
 
 def patch_model(
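
Note (not part of the patch): the sketch below illustrates the new loading behavior introduced by the final hunk in src/llmtuner/model/patcher.py. Training keeps device_map unset so the model is pinned to the current device, while get_infer_args and get_eval_args now set model_args.device_map = "auto" before loading. This is a minimal, self-contained approximation; ModelArgsStub, resolve_init_kwargs, and the get_current_device stub are hypothetical names for illustration, not the repo's API.

# Standalone sketch of the patched device_map / low_cpu_mem_usage resolution.
from dataclasses import dataclass
from typing import Any, Dict, Optional, Union


def get_current_device() -> Union[int, str]:
    """Stub for llmtuner.extras.misc.get_current_device; returns the local device index."""
    return 0


@dataclass
class ModelArgsStub:
    low_cpu_mem_usage: bool = True    # new field introduced by this patch
    device_map: Optional[str] = None  # set to "auto" by get_infer_args / get_eval_args


def resolve_init_kwargs(model_args: ModelArgsStub, deepspeed_zero3: bool = False) -> Dict[str, Any]:
    init_kwargs: Dict[str, Any] = {}
    if not deepspeed_zero3:
        init_kwargs["low_cpu_mem_usage"] = model_args.low_cpu_mem_usage
        if "device_map" not in init_kwargs:  # quant models cannot use auto device map
            init_kwargs["device_map"] = model_args.device_map or {"": get_current_device()}
    return init_kwargs


# Training path: device_map stays None, so the model is pinned to the current device.
print(resolve_init_kwargs(ModelArgsStub()))                   # {'low_cpu_mem_usage': True, 'device_map': {'': 0}}
# Inference/eval path: the parser sets device_map="auto" before the model is loaded.
print(resolve_init_kwargs(ModelArgsStub(device_map="auto")))  # {'low_cpu_mem_usage': True, 'device_map': 'auto'}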