refactor adapter hparam

This commit is contained in:
hiyouga
2023-12-15 20:53:11 +08:00
parent d4c351f1ec
commit 0716f5e470
21 changed files with 302 additions and 311 deletions

View File

@@ -41,15 +41,19 @@ _EVAL_CLS = Tuple[
def _verify_model_args(model_args: "ModelArguments", finetuning_args: "FinetuningArguments") -> None:
if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora":
raise ValueError("Quantization is only compatible with the LoRA method.")
if model_args.checkpoint_dir is not None and len(model_args.checkpoint_dir) != 1:
if model_args.quantization_bit is not None:
if finetuning_args.finetuning_type != "lora":
raise ValueError("Multiple checkpoints are only available for LoRA tuning.")
raise ValueError("Quantization is only compatible with the LoRA method.")
if finetuning_args.create_new_adapter:
raise ValueError("Cannot create new adapter upon a quantized model.")
if model_args.adapter_name_or_path is not None and len(model_args.adapter_name_or_path) != 1:
if finetuning_args.finetuning_type != "lora":
raise ValueError("Multiple adapters are only available for LoRA tuning.")
if model_args.quantization_bit is not None:
raise ValueError("Quantized model only accepts a single checkpoint. Merge them first.")
raise ValueError("Quantized model only accepts a single adapter. Merge them first.")
def parse_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
@@ -139,11 +143,17 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
training_args_dict.update(dict(ddp_find_unused_parameters=False))
training_args = Seq2SeqTrainingArguments(**training_args_dict)
if finetuning_args.stage in ["rm", "ppo"] and finetuning_args.finetuning_type in ["full", "freeze"]:
can_resume_from_checkpoint = False
else:
can_resume_from_checkpoint = True
if (
training_args.resume_from_checkpoint is None
and training_args.do_train
and os.path.isdir(training_args.output_dir)
and not training_args.overwrite_output_dir
and can_resume_from_checkpoint
):
last_checkpoint = get_last_checkpoint(training_args.output_dir)
if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
@@ -158,7 +168,7 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
))
if finetuning_args.stage in ["rm", "ppo"] and training_args.resume_from_checkpoint is not None:
logger.warning("Add {} to `checkpoint_dir` to resume training from checkpoint.".format(
logger.warning("Add {} to `adapter_name_or_path` to resume training from checkpoint.".format(
training_args.resume_from_checkpoint
))