fix ppo train and dpo eval

Former-commit-id: ced863031836632cb5920e22ae6991f251372118
Author: hiyouga
Date: 2023-11-07 22:48:51 +08:00
parent 14a38b5069
commit f5ba2190fb
5 changed files with 56 additions and 21 deletions


@@ -36,8 +36,8 @@ def init_adapter(
     Note that the trainable parameters must be cast to float32.
     """

-    if finetuning_args.finetuning_type == "none" and is_trainable:
-        raise ValueError("You cannot use finetuning_type=none while training.")
+    if (not is_trainable) and model_args.checkpoint_dir is None:
+        logger.info("Checkpoint is not found at evaluation, load the original model.")

     if finetuning_args.finetuning_type == "full" and is_trainable:
         logger.info("Fine-tuning method: Full")
@@ -60,11 +60,11 @@ def init_adapter(
     if finetuning_args.finetuning_type == "lora":
         logger.info("Fine-tuning method: LoRA")
-        latest_checkpoint = None
+        checkpoint_to_resume = None

         if model_args.checkpoint_dir is not None:
-            if is_trainable and finetuning_args.resume_lora_training: # continually fine-tuning
-                checkpoints_to_merge, latest_checkpoint = model_args.checkpoint_dir[:-1], model_args.checkpoint_dir[-1]
+            if is_trainable and finetuning_args.resume_lora_training:
+                checkpoints_to_merge, checkpoint_to_resume = model_args.checkpoint_dir[:-1], model_args.checkpoint_dir[-1]
             else:
                 checkpoints_to_merge = model_args.checkpoint_dir
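
To make the rename easier to follow, here is a standalone sketch of how a list of checkpoint directories is split when resume_lora_training is enabled; the paths are made-up placeholders, only the slicing mirrors the diff above.

```python
# Hypothetical checkpoint list passed via --checkpoint_dir.
checkpoint_dir = ["saves/lora-step-1000", "saves/lora-step-2000", "saves/lora-step-3000"]

resume_lora_training = True
if resume_lora_training:
    # All but the last checkpoint are merged into the base weights;
    # the last one is the adapter whose training is resumed.
    checkpoints_to_merge, checkpoint_to_resume = checkpoint_dir[:-1], checkpoint_dir[-1]
else:
    # Otherwise every checkpoint is merged and a fresh adapter is created later.
    checkpoints_to_merge, checkpoint_to_resume = checkpoint_dir, None

print(checkpoints_to_merge)  # ['saves/lora-step-1000', 'saves/lora-step-2000']
print(checkpoint_to_resume)  # 'saves/lora-step-3000'
```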
@@ -75,10 +75,10 @@ def init_adapter(
             if len(checkpoints_to_merge) > 0:
                 logger.info("Merged {} model checkpoint(s).".format(len(checkpoints_to_merge)))

-            if latest_checkpoint is not None: # resume lora training or quantized inference
-                model = PeftModel.from_pretrained(model, latest_checkpoint, is_trainable=is_trainable)
+            if checkpoint_to_resume is not None: # resume lora training
+                model = PeftModel.from_pretrained(model, checkpoint_to_resume, is_trainable=is_trainable)

-        if is_trainable and latest_checkpoint is None: # create new lora weights while training
+        if is_trainable and checkpoint_to_resume is None: # create new lora weights while training
             if len(finetuning_args.lora_target) == 1 and finetuning_args.lora_target[0] == "all":
                 target_modules = find_all_linear_modules(model, model_args.quantization_bit)
             else:
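
For context on the PeftModel.from_pretrained call above, a self-contained sketch of the two branches (resume an existing LoRA adapter vs. create new LoRA weights) using the peft library; the base model id, adapter path, and LoRA hyperparameters are assumptions for illustration, not values from this repository.

```python
# Sketch only: resuming a saved LoRA adapter vs. creating a fresh one with peft.
# The model id, adapter path, and hyperparameters below are assumptions.
from peft import LoraConfig, PeftModel, TaskType, get_peft_model
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")  # placeholder base model

checkpoint_to_resume = None  # e.g. "saves/lora-step-3000" to resume training instead

if checkpoint_to_resume is not None:
    # Load the saved adapter; is_trainable=True keeps its weights trainable for resuming.
    model = PeftModel.from_pretrained(model, checkpoint_to_resume, is_trainable=True)
else:
    # Create new LoRA weights on assumed target modules.
    lora_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        r=8,
        lora_alpha=16,
        lora_dropout=0.05,
        target_modules=["c_attn"],  # assumed target for the placeholder GPT-2 model
    )
    model = get_peft_model(model, lora_config)

model.print_trainable_parameters()  # peft utility to inspect trainable parameter counts
```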