This commit is contained in:
hiyouga
2023-12-16 20:50:45 +08:00
parent 7ae6919b9b
commit 870426ff70
3 changed files with 22 additions and 17 deletions

View File

@@ -81,9 +81,16 @@ def load_model_and_tokenizer(
if add_valuehead:
model: "AutoModelForCausalLMWithValueHead" = AutoModelForCausalLMWithValueHead.from_pretrained(model)
patcher.patch_valuehead_model(model)
vhead_params = load_valuehead_params(model_args)
if model_args.adapter_name_or_path is not None:
vhead_path = model_args.adapter_name_or_path[-1]
else:
vhead_path = model_args.model_name_or_path
vhead_params = load_valuehead_params(vhead_path, model_args)
if vhead_params is not None:
model.load_state_dict(vhead_params, strict=False)
logger.info("Loaded valuehead from checkpoint: {}".format(vhead_path))
if not is_trainable:
model.requires_grad_(False) # fix all model params