fix value head model resuming

Former-commit-id: 2a36fd5064f028f394ac07c25440fd5e965a07b8
This commit is contained in:
hiyouga 2023-11-20 19:01:37 +08:00
parent 682d81caa9
commit a7b1632ace

View File

@ -202,6 +202,7 @@ def load_model_and_tokenizer(
if stage in ["rm", "ppo"]:
model: "AutoModelForCausalLMWithValueHead" = AutoModelForCausalLMWithValueHead.from_pretrained(model)
setattr(model, "_keys_to_ignore_on_save", [name for name, _ in model.named_parameters() if "pretrained_model" in name])
setattr(model, "tie_weights", MethodType(lambda _: None, model))
vhead_path = (
model_args.checkpoint_dir[-1] if model_args.checkpoint_dir is not None else model_args.model_name_or_path
)