Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-08-22 13:42:51 +08:00)
commit 37e40563f1 (parent 90e66c8d94)
@@ -5,7 +5,7 @@ from typing import Any, Dict, Literal, Optional
 @dataclass
 class ModelArguments:
     r"""
-    Arguments pertaining to which model/config/tokenizer we are going to fine-tune.
+    Arguments pertaining to which model/config/tokenizer we are going to fine-tune or infer.
     """

     model_name_or_path: str = field(
@@ -21,31 +21,35 @@ class ModelArguments:
         default=None,
         metadata={"help": "Where to store the pre-trained models downloaded from huggingface.co or modelscope.cn."},
     )
-    use_fast_tokenizer: Optional[bool] = field(
+    use_fast_tokenizer: bool = field(
         default=False,
         metadata={"help": "Whether or not to use one of the fast tokenizer (backed by the tokenizers library)."},
     )
-    resize_vocab: Optional[bool] = field(
+    resize_vocab: bool = field(
         default=False,
         metadata={"help": "Whether or not to resize the tokenizer vocab and the embedding layers."},
     )
-    split_special_tokens: Optional[bool] = field(
+    split_special_tokens: bool = field(
         default=False,
         metadata={"help": "Whether or not the special tokens should be split during the tokenization process."},
     )
-    model_revision: Optional[str] = field(
+    model_revision: str = field(
         default="main",
         metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
     )
+    low_cpu_mem_usage: bool = field(
+        default=True,
+        metadata={"help": "Whether or not to use memory-efficient model loading."},
+    )
     quantization_bit: Optional[int] = field(
         default=None,
-        metadata={"help": "The number of bits to quantize the model."},
+        metadata={"help": "The number of bits to quantize the model using bitsandbytes."},
     )
-    quantization_type: Optional[Literal["fp4", "nf4"]] = field(
+    quantization_type: Literal["fp4", "nf4"] = field(
         default="nf4",
         metadata={"help": "Quantization data type to use in int4 training."},
     )
-    double_quantization: Optional[bool] = field(
+    double_quantization: bool = field(
         default=True,
         metadata={"help": "Whether or not to use double quantization in int4 training."},
     )
@@ -53,30 +57,34 @@ class ModelArguments:
         default=None,
         metadata={"help": "Which scaling strategy should be adopted for the RoPE embeddings."},
     )
-    flash_attn: Optional[bool] = field(
+    flash_attn: bool = field(
         default=False,
         metadata={"help": "Enable FlashAttention-2 for faster training."},
     )
-    shift_attn: Optional[bool] = field(
+    shift_attn: bool = field(
         default=False,
         metadata={"help": "Enable shift short attention (S^2-Attn) proposed by LongLoRA."},
     )
-    use_unsloth: Optional[bool] = field(
+    use_unsloth: bool = field(
         default=False,
         metadata={"help": "Whether or not to use unsloth's optimization for the LoRA training."},
     )
-    disable_gradient_checkpointing: Optional[bool] = field(
+    disable_gradient_checkpointing: bool = field(
         default=False,
         metadata={"help": "Whether or not to disable gradient checkpointing."},
     )
-    upcast_layernorm: Optional[bool] = field(
+    upcast_layernorm: bool = field(
         default=False,
         metadata={"help": "Whether or not to upcast the layernorm weights in fp32."},
     )
-    upcast_lmhead_output: Optional[bool] = field(
+    upcast_lmhead_output: bool = field(
         default=False,
         metadata={"help": "Whether or not to upcast the output of lm_head in fp32."},
     )
+    infer_backend: Literal["hf", "vllm"] = field(
+        default="hf",
+        metadata={"help": "Backend engine used at inference."},
+    )
     hf_hub_token: Optional[str] = field(
         default=None,
         metadata={"help": "Auth token to log in with Hugging Face Hub."},
@@ -89,7 +97,7 @@ class ModelArguments:
         default=None,
         metadata={"help": "Path to the directory to save the exported model."},
     )
-    export_size: Optional[int] = field(
+    export_size: int = field(
         default=1,
         metadata={"help": "The file shard size (in GB) of the exported model."},
     )
@@ -101,15 +109,15 @@ class ModelArguments:
         default=None,
         metadata={"help": "Path to the dataset or dataset name to use in quantizing the exported model."},
     )
-    export_quantization_nsamples: Optional[int] = field(
+    export_quantization_nsamples: int = field(
         default=128,
         metadata={"help": "The number of samples used for quantization."},
     )
-    export_quantization_maxlen: Optional[int] = field(
+    export_quantization_maxlen: int = field(
         default=1024,
         metadata={"help": "The maximum length of the model inputs used for quantization."},
     )
-    export_legacy_format: Optional[bool] = field(
+    export_legacy_format: bool = field(
         default=False,
         metadata={"help": "Whether or not to save the `.bin` files instead of `.safetensors`."},
     )
@@ -117,16 +125,15 @@ class ModelArguments:
         default=None,
         metadata={"help": "The name of the repository if push the model to the Hugging Face hub."},
     )
-    print_param_status: Optional[bool] = field(
+    print_param_status: bool = field(
         default=False,
         metadata={"help": "For debugging purposes, print the status of the parameters in the model."},
     )
-    aqlm_optimization: Optional[bool] = field(
-        default=False, metadata={"help": "Whether or not to optimize the training performance of AQLM models."}
-    )

     def __post_init__(self):
+        self.aqlm_optimization = None
         self.compute_dtype = None
+        self.device_map = None
         self.model_max_length = None

         if self.split_special_tokens and self.use_fast_tokenizer:
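Note on the ModelArguments changes above: dropping `Optional[...]` from flags that already have concrete defaults (e.g. `use_fast_tokenizer: bool = False`) tightens the type hints without changing behavior, while `low_cpu_mem_usage` and `infer_backend` are new knobs. Below is a minimal sketch, not taken from the repository, of how such dataclass fields are typically parsed with transformers' HfArgumentParser; the field subset, script name, and example checkpoint are illustrative only.

# Illustrative only -- a pared-down ModelArguments, parsed the way
# transformers dataclass arguments are usually consumed.
from dataclasses import dataclass, field
from typing import Literal, Optional

from transformers import HfArgumentParser


@dataclass
class ModelArguments:
    model_name_or_path: str = field(metadata={"help": "Path to the model or model identifier."})
    low_cpu_mem_usage: bool = field(default=True)               # new field in this commit
    infer_backend: Literal["hf", "vllm"] = field(default="hf")  # new field in this commit
    quantization_bit: Optional[int] = field(default=None)       # stays Optional: None means no quantization


if __name__ == "__main__":
    # e.g. python demo.py --model_name_or_path meta-llama/Llama-2-7b-hf --infer_backend vllm
    (model_args,) = HfArgumentParser(ModelArguments).parse_args_into_dataclasses()
    print(model_args)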
@@ -10,6 +10,7 @@ from transformers.trainer_utils import get_last_checkpoint

 from ..extras.logging import get_logger
 from ..extras.packages import is_unsloth_available
+from ..extras.misc import check_dependencies
 from .data_args import DataArguments
 from .evaluation_args import EvaluationArguments
 from .finetuning_args import FinetuningArguments
@@ -20,6 +21,9 @@ from .model_args import ModelArguments
 logger = get_logger(__name__)


+check_dependencies()
+
+
 _TRAIN_ARGS = [ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments]
 _TRAIN_CLS = Tuple[ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments]
 _INFER_ARGS = [ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
@@ -221,7 +225,7 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
             training_args.local_rank,
             training_args.device,
             training_args.n_gpu,
-            bool(training_args.local_rank != -1),
+            training_args.parallel_mode.value == "distributed",
             str(model_args.compute_dtype),
         )
     )
@@ -236,6 +240,8 @@ def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:

     _set_transformers_logging()
     _verify_model_args(model_args, finetuning_args)
+    model_args.aqlm_optimization = False
+    model_args.device_map = "auto"

     if data_args.template is None:
         raise ValueError("Please specify which `template` to use.")
@@ -249,6 +255,7 @@ def get_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS:
     _set_transformers_logging()
     _verify_model_args(model_args, finetuning_args)
     model_args.aqlm_optimization = True
+    model_args.device_map = "auto"

     if data_args.template is None:
         raise ValueError("Please specify which `template` to use.")
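Note on the get_train_args change above: the distributed-training check now reads `training_args.parallel_mode` (a ParallelMode enum provided by transformers) instead of testing `local_rank != -1` directly. A rough illustration, not repository code:

# Rough illustration, not repository code.
from transformers import TrainingArguments
from transformers.training_args import ParallelMode

training_args = TrainingArguments(output_dir="tmp_out")

# Old check in the diff: infer distributed training from the raw local_rank value.
is_distributed_old = bool(training_args.local_rank != -1)

# New check: compare the ParallelMode enum's string value, as the diff does.
is_distributed_new = training_args.parallel_mode.value == "distributed"
# Equivalent enum comparison:
assert is_distributed_new == (training_args.parallel_mode == ParallelMode.DISTRIBUTED)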
@@ -5,7 +5,7 @@ from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
 from trl import AutoModelForCausalLMWithValueHead

 from ..extras.logging import get_logger
-from ..extras.misc import check_dependencies, count_parameters, get_current_device, try_download_model_from_ms
+from ..extras.misc import count_parameters, get_current_device, try_download_model_from_ms
 from .adapter import init_adapter
 from .patcher import patch_config, patch_model, patch_tokenizer, patch_valuehead_model
 from .utils import load_valuehead_params, register_autoclass
@@ -20,9 +20,6 @@ if TYPE_CHECKING:
 logger = get_logger(__name__)


-check_dependencies()
-
-
 def _get_init_kwargs(model_args: "ModelArguments") -> Dict[str, Any]:
     return {
         "trust_remote_code": True,
@@ -286,9 +286,9 @@ def patch_config(

     init_kwargs["torch_dtype"] = model_args.compute_dtype
     if not is_deepspeed_zero3_enabled():
-        init_kwargs["low_cpu_mem_usage"] = True
-        if "device_map" not in init_kwargs:
-            init_kwargs["device_map"] = {"": get_current_device()} if is_trainable else "auto"
+        init_kwargs["low_cpu_mem_usage"] = model_args.low_cpu_mem_usage
+        if "device_map" not in init_kwargs:  # quant models cannot use auto device map
+            init_kwargs["device_map"] = model_args.device_map or {"": get_current_device()}


 def patch_model(
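Note on the patch_config change above: `low_cpu_mem_usage` and `device_map` are no longer hard-coded; they come from the new ModelArguments fields, with the device map falling back to the current device when unset. A hedged sketch, not repository code, of how such init_kwargs typically reach the Hugging Face loader (the checkpoint name and dtype are examples):

# Hedged sketch, not repository code; checkpoint name is an example.
import torch
from transformers import AutoModelForCausalLM

init_kwargs = {
    "torch_dtype": torch.bfloat16,   # model_args.compute_dtype
    "low_cpu_mem_usage": True,       # model_args.low_cpu_mem_usage (new field)
    "device_map": "auto",            # model_args.device_map or {"": current device}
    "trust_remote_code": True,
}

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", **init_kwargs)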