fix ppo train and dpo eval

Former-commit-id: ced863031836632cb5920e22ae6991f251372118
hiyouga
2023-11-07 22:48:51 +08:00
parent 14a38b5069
commit f5ba2190fb
5 changed files with 56 additions and 21 deletions

@@ -15,6 +15,7 @@ from transformers import (
 )
 from transformers.models.llama import modeling_llama as LlamaModule
 from transformers.utils.versions import require_version
+from peft import PeftModel
 from trl import AutoModelForCausalLMWithValueHead
 try:
@@ -55,9 +56,6 @@ def load_model_and_tokenizer(
     Support both training and inference.
     """
-    if (not is_trainable) and model_args.checkpoint_dir is None:
-        logger.warning("Checkpoint is not found at evaluation, load the original model.")
-        finetuning_args = FinetuningArguments(finetuning_type="none")
     config_kwargs = {
         "trust_remote_code": True,
@@ -212,8 +210,11 @@ def load_model_and_tokenizer(
     if stage == "ppo": # load reward model
         logger.info("Load reward model from {}".format(model_args.reward_model))
-        if getattr(model, "is_peft_model", False):
+        if isinstance(model.pretrained_model, PeftModel):
             model.pretrained_model.load_adapter(model_args.reward_model, "reward")
+            for name, param in model.named_parameters(): # https://github.com/huggingface/peft/issues/1090
+                if "default" in name:
+                    param.data = param.data.to(torch.float32) # trainable params should be in fp32
         assert load_valuehead_params(model, model_args), "Reward model is not correctly loaded."
     # Prepare model for inference
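
For context, a minimal sketch of the reward-adapter loading pattern the last hunk introduces, assuming a value-head-wrapped LoRA policy; the attach_reward_adapter helper, policy, and reward_model_dir names are illustrative and not part of the commit:

import torch
from peft import PeftModel
from trl import AutoModelForCausalLMWithValueHead


def attach_reward_adapter(policy: "AutoModelForCausalLMWithValueHead", reward_model_dir: str) -> None:
    # The value-head wrapper keeps the underlying transformer in `pretrained_model`;
    # only a PeftModel backbone can host a second ("reward") adapter next to "default".
    if isinstance(policy.pretrained_model, PeftModel):
        policy.pretrained_model.load_adapter(reward_model_dir, "reward")
        # Workaround for https://github.com/huggingface/peft/issues/1090:
        # keep the trainable ("default" adapter) parameters in fp32.
        for name, param in policy.named_parameters():
            if "default" in name:
                param.data = param.data.to(torch.float32)

The point of loading the reward model as a second adapter is that PPO can switch the same backbone between the "default" (policy) and "reward" adapters, rather than holding two full models in memory.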