This commit is contained in:
hiyouga
2023-11-20 18:46:36 +08:00
parent 00baaa990e
commit 99a3f06377
5 changed files with 34 additions and 36 deletions

View File

@@ -1,4 +1,3 @@
import os
import math
import torch
from types import MethodType
@@ -202,6 +201,7 @@ def load_model_and_tokenizer(
# Prepare model with valuehead for RLHF
if stage in ["rm", "ppo"]:
model: "AutoModelForCausalLMWithValueHead" = AutoModelForCausalLMWithValueHead.from_pretrained(model)
setattr(model, "_keys_to_ignore_on_save", [name for name, _ in model.named_parameters() if "pretrained_model" in name])
vhead_path = (
model_args.checkpoint_dir[-1] if model_args.checkpoint_dir is not None else model_args.model_name_or_path
)