Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-12-29 10:10:35 +08:00)
fix reward model loading
Former-commit-id: 9709ca501180a1afce32e9043aedb359762b437d
@@ -204,7 +204,7 @@ def load_model_and_tokenizer(
     reset_logging()
     if stage == "rm" and model_args.checkpoint_dir is not None: # load valuehead weights to evaluate reward model
         logger.warning("Only the last checkpoint containing valuehead will be loaded.")
-        if load_valuehead_params(model, model_args.checkpoint_dir[-1]):
+        if load_valuehead_params(model, model_args):
             model.v_head.load_state_dict({
                 "summary.weight": getattr(model, "reward_head_weight"),
                 "summary.bias": getattr(model, "reward_head_bias")
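
For context: this hunk stops passing only the last checkpoint directory and instead hands the whole `model_args` to `load_valuehead_params`, whose boolean return value the caller uses to decide whether to copy the restored tensors into `model.v_head`. Below is a minimal sketch of what such a helper could look like after the signature change; the `value_head.bin` filename, the `register_buffer` bookkeeping, and the checkpoint resolution logic are assumptions made for illustration, not the project's actual implementation.

import os

import torch

VALUE_HEAD_FILE_NAME = "value_head.bin"  # assumed filename, not confirmed by this diff


def load_valuehead_params(model, model_args) -> bool:
    """Hypothetical sketch: restore value-head weights from the last checkpoint.

    Returns True on success so call sites can branch on the result (hunk above)
    or assert on it (hunk below). A full version would presumably also consult
    model_args.reward_model for the PPO path shown in the next hunk.
    """
    checkpoint_dir = model_args.checkpoint_dir[-1]  # last checkpoint, as the old call site did
    valuehead_file = os.path.join(checkpoint_dir, VALUE_HEAD_FILE_NAME)
    if not os.path.exists(valuehead_file):
        return False

    state_dict = torch.load(valuehead_file, map_location="cpu")
    # Expose the tensors as attributes so the caller can copy them into
    # model.v_head, matching the getattr(...) calls in the diff above.
    model.register_buffer("reward_head_weight", state_dict["summary.weight"])
    model.register_buffer("reward_head_bias", state_dict["summary.bias"])
    return True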
@@ -214,7 +214,7 @@ def load_model_and_tokenizer(
         logger.info("Load reward model from {}".format(model_args.reward_model))
         if getattr(model, "is_peft_model", False):
             model.pretrained_model.load_adapter(model_args.reward_model, "reward")
-        load_valuehead_params(model, model_args)
+        assert load_valuehead_params(model, model_args), "Reward model is not correctly loaded."
 
     # Prepare model for inference
     if not is_trainable:
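
The second hunk wraps the same call in an assert, so a reward model whose value head cannot be restored now stops PPO setup immediately instead of failing silently later. The adapter registered under the name "reward" is what lets a single PEFT backbone serve both the policy and the reward model. A hedged usage sketch of that pattern follows; `set_adapter`, the `"default"` adapter name, and the reward-scoring helper are assumptions about the surrounding PPO code, not part of this commit.

import torch


@torch.no_grad()
def score_with_reward_model(model, input_ids, attention_mask):
    """Hypothetical usage sketch (not code from this commit).

    With load_adapter(model_args.reward_model, "reward") registered above, a
    PPO loop can switch the PEFT backbone to the reward adapter, take the
    value head's output at the final token as the reward, and switch back.
    A complete implementation would also swap the value-head weights to the
    restored reward_head_weight / reward_head_bias tensors and put the
    originals back afterwards; that bookkeeping is omitted here.
    """
    model.pretrained_model.set_adapter("reward")
    _, _, values = model(input_ids=input_ids, attention_mask=attention_mask)
    model.pretrained_model.set_adapter("default")  # restore the policy adapter
    return values[:, -1]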