fix reward model loading

Former-commit-id: 9709ca501180a1afce32e9043aedb359762b437d
This commit is contained in:
hiyouga
2023-11-07 17:20:51 +08:00
parent 857696ed9c
commit f23e5b602a
6 changed files with 34 additions and 24 deletions

View File

@@ -104,7 +104,7 @@ def init_adapter(
def load_valuehead_params(
model: "PreTrainedModel",
model_args: "ModelArguments"
) -> None:
) -> bool:
kwargs = {
"path_or_repo_id": model_args.reward_model,
"cache_dir": model_args.cache_dir,
@@ -117,10 +117,12 @@ def load_valuehead_params(
try:
vhead_file = cached_file(filename=SAFE_WEIGHTS_NAME, **kwargs)
except:
raise ValueError("Provided path ({}) does not contain valuehead weights.".format(model_args.reward_model))
logger.warning("Provided path ({}) does not contain valuehead weights.".format(model_args.reward_model))
return False
vhead_params = torch.load(vhead_file, map_location="cpu")
model.register_buffer("reward_head_weight", vhead_params["v_head.summary.weight"], persistent=False)
model.register_buffer("reward_head_bias", vhead_params["v_head.summary.bias"], persistent=False)
model.register_buffer("default_head_weight", torch.zeros_like(vhead_params["v_head.summary.weight"]), persistent=False)
model.register_buffer("default_head_bias", torch.zeros_like(vhead_params["v_head.summary.bias"]), persistent=False)
return True