Fix reward model loading

Former-commit-id: 9709ca501180a1afce32e9043aedb359762b437d
This commit is contained in:
hiyouga
2023-11-07 17:20:51 +08:00
parent 857696ed9c
commit f23e5b602a
6 changed files with 34 additions and 24 deletions

View File

@@ -204,7 +204,7 @@ def load_model_and_tokenizer(
reset_logging()
if stage == "rm" and model_args.checkpoint_dir is not None: # load valuehead weights to evaluate reward model
logger.warning("Only the last checkpoint containing valuehead will be loaded.")
if load_valuehead_params(model, model_args.checkpoint_dir[-1]):
if load_valuehead_params(model, model_args):
model.v_head.load_state_dict({
"summary.weight": getattr(model, "reward_head_weight"),
"summary.bias": getattr(model, "reward_head_bias")
@@ -214,7 +214,7 @@ def load_model_and_tokenizer(
logger.info("Load reward model from {}".format(model_args.reward_model))
if getattr(model, "is_peft_model", False):
model.pretrained_model.load_adapter(model_args.reward_model, "reward")
load_valuehead_params(model, model_args)
assert load_valuehead_params(model, model_args), "Reward model is not correctly loaded."
# Prepare model for inference
if not is_trainable: