commit 37e40563f1
parent 90e66c8d94
Author: hiyouga
Date:   2024-03-07 16:15:53 +08:00
Former-commit-id: f74f804a715dfb16bf24a056bc95db6b102f9ed7

4 changed files with 41 additions and 30 deletions

View File

@@ -5,7 +5,7 @@ from typing import Any, Dict, Literal, Optional
 @dataclass
 class ModelArguments:
     r"""
-    Arguments pertaining to which model/config/tokenizer we are going to fine-tune.
+    Arguments pertaining to which model/config/tokenizer we are going to fine-tune or infer.
     """

     model_name_or_path: str = field(
@@ -21,31 +21,35 @@ class ModelArguments:
         default=None,
         metadata={"help": "Where to store the pre-trained models downloaded from huggingface.co or modelscope.cn."},
     )
-    use_fast_tokenizer: Optional[bool] = field(
+    use_fast_tokenizer: bool = field(
         default=False,
         metadata={"help": "Whether or not to use one of the fast tokenizer (backed by the tokenizers library)."},
     )
-    resize_vocab: Optional[bool] = field(
+    resize_vocab: bool = field(
         default=False,
         metadata={"help": "Whether or not to resize the tokenizer vocab and the embedding layers."},
     )
-    split_special_tokens: Optional[bool] = field(
+    split_special_tokens: bool = field(
         default=False,
         metadata={"help": "Whether or not the special tokens should be split during the tokenization process."},
     )
-    model_revision: Optional[str] = field(
+    model_revision: str = field(
         default="main",
         metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
     )
+    low_cpu_mem_usage: bool = field(
+        default=True,
+        metadata={"help": "Whether or not to use memory-efficient model loading."},
+    )
     quantization_bit: Optional[int] = field(
         default=None,
-        metadata={"help": "The number of bits to quantize the model."},
+        metadata={"help": "The number of bits to quantize the model using bitsandbytes."},
     )
-    quantization_type: Optional[Literal["fp4", "nf4"]] = field(
+    quantization_type: Literal["fp4", "nf4"] = field(
         default="nf4",
         metadata={"help": "Quantization data type to use in int4 training."},
     )
-    double_quantization: Optional[bool] = field(
+    double_quantization: bool = field(
         default=True,
         metadata={"help": "Whether or not to use double quantization in int4 training."},
     )
@@ -53,30 +57,34 @@ class ModelArguments:
         default=None,
         metadata={"help": "Which scaling strategy should be adopted for the RoPE embeddings."},
     )
-    flash_attn: Optional[bool] = field(
+    flash_attn: bool = field(
         default=False,
         metadata={"help": "Enable FlashAttention-2 for faster training."},
     )
-    shift_attn: Optional[bool] = field(
+    shift_attn: bool = field(
         default=False,
         metadata={"help": "Enable shift short attention (S^2-Attn) proposed by LongLoRA."},
     )
-    use_unsloth: Optional[bool] = field(
+    use_unsloth: bool = field(
         default=False,
         metadata={"help": "Whether or not to use unsloth's optimization for the LoRA training."},
     )
-    disable_gradient_checkpointing: Optional[bool] = field(
+    disable_gradient_checkpointing: bool = field(
         default=False,
         metadata={"help": "Whether or not to disable gradient checkpointing."},
     )
-    upcast_layernorm: Optional[bool] = field(
+    upcast_layernorm: bool = field(
         default=False,
         metadata={"help": "Whether or not to upcast the layernorm weights in fp32."},
     )
-    upcast_lmhead_output: Optional[bool] = field(
+    upcast_lmhead_output: bool = field(
         default=False,
         metadata={"help": "Whether or not to upcast the output of lm_head in fp32."},
     )
+    infer_backend: Literal["hf", "vllm"] = field(
+        default="hf",
+        metadata={"help": "Backend engine used at inference."},
+    )
     hf_hub_token: Optional[str] = field(
         default=None,
         metadata={"help": "Auth token to log in with Hugging Face Hub."},
@@ -89,7 +97,7 @@ class ModelArguments:
         default=None,
         metadata={"help": "Path to the directory to save the exported model."},
     )
-    export_size: Optional[int] = field(
+    export_size: int = field(
         default=1,
         metadata={"help": "The file shard size (in GB) of the exported model."},
     )
@@ -101,15 +109,15 @@ class ModelArguments:
         default=None,
         metadata={"help": "Path to the dataset or dataset name to use in quantizing the exported model."},
     )
-    export_quantization_nsamples: Optional[int] = field(
+    export_quantization_nsamples: int = field(
         default=128,
         metadata={"help": "The number of samples used for quantization."},
     )
-    export_quantization_maxlen: Optional[int] = field(
+    export_quantization_maxlen: int = field(
         default=1024,
         metadata={"help": "The maximum length of the model inputs used for quantization."},
     )
-    export_legacy_format: Optional[bool] = field(
+    export_legacy_format: bool = field(
         default=False,
         metadata={"help": "Whether or not to save the `.bin` files instead of `.safetensors`."},
     )
@@ -117,16 +125,15 @@ class ModelArguments:
         default=None,
         metadata={"help": "The name of the repository if push the model to the Hugging Face hub."},
     )
-    print_param_status: Optional[bool] = field(
+    print_param_status: bool = field(
         default=False,
         metadata={"help": "For debugging purposes, print the status of the parameters in the model."},
     )
-    aqlm_optimization: Optional[bool] = field(
-        default=False, metadata={"help": "Whether or not to optimize the training performance of AQLM models."}
-    )

     def __post_init__(self):
+        self.aqlm_optimization = None
         self.compute_dtype = None
+        self.device_map = None
         self.model_max_length = None

         if self.split_special_tokens and self.use_fast_tokenizer:
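Note on the type changes above: these dataclass fields are turned into CLI options by transformers' HfArgumentParser, which derives argument types and choices from the annotations, so `bool` with a default is the natural annotation for on/off flags while `Optional[int]` is kept for values that may legitimately stay `None` (such as `quantization_bit`). The sketch below illustrates that parsing behaviour with a trimmed, hypothetical stand-in dataclass; the class name, help strings, and argument values are illustrative only and not part of this commit.

from dataclasses import dataclass, field
from typing import Literal, Optional

from transformers import HfArgumentParser


@dataclass
class TinyModelArguments:
    """Trimmed, hypothetical stand-in for ModelArguments (illustration only)."""

    model_name_or_path: str = field(metadata={"help": "Path or hub id of the model."})
    use_fast_tokenizer: bool = field(default=False)  # plain bool: --use_fast_tokenizer or --use_fast_tokenizer True
    quantization_bit: Optional[int] = field(default=None)  # stays None unless explicitly passed
    infer_backend: Literal["hf", "vllm"] = field(default="hf")  # Literal values become CLI choices


if __name__ == "__main__":
    parser = HfArgumentParser(TinyModelArguments)
    (model_args,) = parser.parse_args_into_dataclasses(
        args=["--model_name_or_path", "org/model", "--quantization_bit", "4", "--infer_backend", "vllm"]
    )
    print(model_args.quantization_bit, model_args.infer_backend)  # -> 4 vllm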

View File

@@ -10,6 +10,7 @@ from transformers.trainer_utils import get_last_checkpoint
 from ..extras.logging import get_logger
 from ..extras.packages import is_unsloth_available
+from ..extras.misc import check_dependencies
 from .data_args import DataArguments
 from .evaluation_args import EvaluationArguments
 from .finetuning_args import FinetuningArguments
@@ -20,6 +21,9 @@ from .model_args import ModelArguments

 logger = get_logger(__name__)


+check_dependencies()
+
+
 _TRAIN_ARGS = [ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments]
 _TRAIN_CLS = Tuple[ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments]
 _INFER_ARGS = [ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
@@ -221,7 +225,7 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
             training_args.local_rank,
             training_args.device,
             training_args.n_gpu,
-            bool(training_args.local_rank != -1),
+            training_args.parallel_mode.value == "distributed",
             str(model_args.compute_dtype),
         )
     )
@@ -236,6 +240,8 @@ def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
     _set_transformers_logging()
     _verify_model_args(model_args, finetuning_args)
+    model_args.aqlm_optimization = False
+    model_args.device_map = "auto"

     if data_args.template is None:
         raise ValueError("Please specify which `template` to use.")
@@ -249,6 +255,7 @@ def get_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS:
     _set_transformers_logging()
     _verify_model_args(model_args, finetuning_args)
     model_args.aqlm_optimization = True
+    model_args.device_map = "auto"

     if data_args.template is None:
         raise ValueError("Please specify which `template` to use.")

View File

@@ -5,7 +5,7 @@ from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
 from trl import AutoModelForCausalLMWithValueHead

 from ..extras.logging import get_logger
-from ..extras.misc import check_dependencies, count_parameters, get_current_device, try_download_model_from_ms
+from ..extras.misc import count_parameters, get_current_device, try_download_model_from_ms
 from .adapter import init_adapter
 from .patcher import patch_config, patch_model, patch_tokenizer, patch_valuehead_model
 from .utils import load_valuehead_params, register_autoclass
@@ -20,9 +20,6 @@ if TYPE_CHECKING:

 logger = get_logger(__name__)


-check_dependencies()
-
-
 def _get_init_kwargs(model_args: "ModelArguments") -> Dict[str, Any]:
     return {
         "trust_remote_code": True,

View File

@@ -286,9 +286,9 @@ def patch_config(
     init_kwargs["torch_dtype"] = model_args.compute_dtype

     if not is_deepspeed_zero3_enabled():
-        init_kwargs["low_cpu_mem_usage"] = True
-        if "device_map" not in init_kwargs:
-            init_kwargs["device_map"] = {"": get_current_device()} if is_trainable else "auto"
+        init_kwargs["low_cpu_mem_usage"] = model_args.low_cpu_mem_usage
+        if "device_map" not in init_kwargs:  # quant models cannot use auto device map
+            init_kwargs["device_map"] = model_args.device_map or {"": get_current_device()}


 def patch_model(
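Taken together with the parser changes, the patched expression `model_args.device_map or {"": get_current_device()}` resolves differently per entry point: training leaves `device_map` at the `None` set in `__post_init__`, so the whole model is pinned to the current device, while `get_infer_args`/`get_eval_args` set it to "auto" so accelerate can shard the model. The standalone sketch below mirrors that branch; `get_current_device` here is a hypothetical stand-in, not the project's implementation in extras.misc.

import os
from typing import Dict, Optional, Union

DeviceMap = Union[str, Dict[str, Union[int, str]]]


def get_current_device() -> Union[int, str]:
    # Hypothetical stand-in: local rank on GPU runs, CPU otherwise.
    return int(os.environ.get("LOCAL_RANK", "0")) if os.environ.get("CUDA_VISIBLE_DEVICES") else "cpu"


def resolve_device_map(device_map: Optional[str], zero3_enabled: bool = False) -> Optional[DeviceMap]:
    # Mirrors the patched branch: DeepSpeed ZeRO-3 manages parameter placement itself.
    if zero3_enabled:
        return None
    # None (training path) -> pin the model to the current device;
    # "auto" (inference/eval path) -> let accelerate place the weights.
    return device_map or {"": get_current_device()}


print(resolve_device_map(None))    # e.g. {'': 'cpu'} or {'': 0}
print(resolve_device_map("auto"))  # 'auto'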