Mirror of https://github.com/hiyouga/LLaMA-Factory.git, synced 2025-12-29 18:20:35 +08:00
fix reward model loading
Former-commit-id: 9709ca501180a1afce32e9043aedb359762b437d
@@ -104,7 +104,7 @@ def init_adapter(
 def load_valuehead_params(
     model: "PreTrainedModel",
     model_args: "ModelArguments"
-) -> None:
+) -> bool:
     kwargs = {
         "path_or_repo_id": model_args.reward_model,
         "cache_dir": model_args.cache_dir,
@@ -117,10 +117,12 @@ def load_valuehead_params(
     try:
         vhead_file = cached_file(filename=SAFE_WEIGHTS_NAME, **kwargs)
     except:
-        raise ValueError("Provided path ({}) does not contain valuehead weights.".format(model_args.reward_model))
+        logger.warning("Provided path ({}) does not contain valuehead weights.".format(model_args.reward_model))
+        return False
 
     vhead_params = torch.load(vhead_file, map_location="cpu")
     model.register_buffer("reward_head_weight", vhead_params["v_head.summary.weight"], persistent=False)
     model.register_buffer("reward_head_bias", vhead_params["v_head.summary.bias"], persistent=False)
     model.register_buffer("default_head_weight", torch.zeros_like(vhead_params["v_head.summary.weight"]), persistent=False)
     model.register_buffer("default_head_bias", torch.zeros_like(vhead_params["v_head.summary.bias"]), persistent=False)
+    return True
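For context, below is a minimal, self-contained sketch (not part of the commit) of the register_buffer(..., persistent=False) pattern the diff relies on. The dummy nn.Linear and hand-built vhead_params dict are illustrative stand-ins for the real reward model and the checkpoint dict loaded by torch.load above.

    import torch
    import torch.nn as nn

    # Illustrative stand-ins: any nn.Module, and a dict shaped like the
    # loaded value-head checkpoint (torch.load(vhead_file, map_location="cpu")).
    model = nn.Linear(4, 1)
    vhead_params = {
        "v_head.summary.weight": torch.randn(1, 4),
        "v_head.summary.bias": torch.randn(1),
    }

    # persistent=False: the buffers move with model.to(device) like normal
    # module state, but are excluded from model.state_dict(), so the reward
    # head is never written back out when the model is saved.
    model.register_buffer("reward_head_weight", vhead_params["v_head.summary.weight"], persistent=False)
    model.register_buffer("reward_head_bias", vhead_params["v_head.summary.bias"], persistent=False)
    model.register_buffer("default_head_weight", torch.zeros_like(vhead_params["v_head.summary.weight"]), persistent=False)
    model.register_buffer("default_head_bias", torch.zeros_like(vhead_params["v_head.summary.bias"]), persistent=False)

    assert "reward_head_weight" not in model.state_dict()  # non-persistent buffer

The signature change from None to bool pairs with the new failure path: instead of aborting with raise ValueError, the function now logs a warning and returns False when the checkpoint lacks value-head weights, so a caller can test the return value and fall back gracefully, which is presumably the reward-model loading fix the commit title refers to.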
||||