support llama pro #2338 , add rslora

This commit is contained in:
hiyouga
2024-02-15 02:27:36 +08:00
parent 8a1b389086
commit 7924ffc55d
24 changed files with 438 additions and 203 deletions

View File

@@ -30,12 +30,15 @@ _EVAL_ARGS = [ModelArguments, DataArguments, EvaluationArguments, FinetuningArgu
_EVAL_CLS = Tuple[ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments]
def _check_dependencies():
require_version("transformers>=4.37.2", "To fix: pip install transformers>=4.37.2")
require_version("datasets>=2.14.3", "To fix: pip install datasets>=2.14.3")
require_version("accelerate>=0.21.0", "To fix: pip install accelerate>=0.21.0")
require_version("peft>=0.7.0", "To fix: pip install peft>=0.7.0")
require_version("trl>=0.7.6", "To fix: pip install trl>=0.7.6")
def _check_dependencies(disabled: bool) -> None:
if disabled:
logger.warning("Version checking has been disabled, may lead to unexpected behaviors.")
else:
require_version("transformers>=4.37.2", "To fix: pip install transformers>=4.37.2")
require_version("datasets>=2.14.3", "To fix: pip install datasets>=2.14.3")
require_version("accelerate>=0.21.0", "To fix: pip install accelerate>=0.21.0")
require_version("peft>=0.8.2", "To fix: pip install peft>=0.8.2")
require_version("trl>=0.7.6", "To fix: pip install trl>=0.7.6")
def _parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = None) -> Tuple[Any]:
@@ -130,6 +133,13 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
if training_args.do_train and training_args.predict_with_generate:
raise ValueError("`predict_with_generate` cannot be set as True while training.")
if (
training_args.do_train
and finetuning_args.finetuning_type == "freeze"
and finetuning_args.name_module_trainable is None
):
raise ValueError("Please specify `name_module_trainable` in Freeze training.")
if training_args.do_train and finetuning_args.finetuning_type == "lora" and finetuning_args.lora_target is None:
raise ValueError("Please specify `lora_target` in LoRA training.")
@@ -137,9 +147,7 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
raise ValueError("Install Unsloth: https://github.com/unslothai/unsloth")
_verify_model_args(model_args, finetuning_args)
if not finetuning_args.disable_version_checking:
_check_dependencies()
_check_dependencies(disabled=finetuning_args.disable_version_checking)
if (
training_args.do_train
@@ -240,13 +248,11 @@ def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
_set_transformers_logging()
_verify_model_args(model_args, finetuning_args)
_check_dependencies(disabled=finetuning_args.disable_version_checking)
if data_args.template is None:
raise ValueError("Please specify which `template` to use.")
if not finetuning_args.disable_version_checking:
_check_dependencies()
return model_args, data_args, finetuning_args, generating_args
@@ -255,13 +261,11 @@ def get_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS:
_set_transformers_logging()
_verify_model_args(model_args, finetuning_args)
_check_dependencies(disabled=finetuning_args.disable_version_checking)
if data_args.template is None:
raise ValueError("Please specify which `template` to use.")
if not finetuning_args.disable_version_checking:
_check_dependencies()
transformers.set_seed(eval_args.seed)
return model_args, data_args, eval_args, finetuning_args