fix valuehead model

Former-commit-id: bfdee1608f53a6334d8e73c48dbeb4160969d783
This commit is contained in:
hiyouga 2023-12-14 20:15:20 +08:00
parent 3358416e82
commit e55e32efc4

View File

@@ -203,6 +203,9 @@ def load_model_and_tokenizer(
# Prepare model with valuehead for RLHF
if add_valuehead:
model: "AutoModelForCausalLMWithValueHead" = AutoModelForCausalLMWithValueHead.from_pretrained(model)
def get_input_embeddings(self: "AutoModelForCausalLMWithValueHead") -> torch.nn.Module:
    """Delegate input-embedding lookup to the wrapped pretrained model.

    The value-head wrapper itself has no embedding layer; forward the call
    to ``self.pretrained_model`` so callers see the underlying model's
    input embeddings.
    """
    embedding_layer = self.pretrained_model.get_input_embeddings()
    return embedding_layer
# Bind the helper above as a bound method so the wrapper exposes the
# standard ``get_input_embeddings`` accessor expected by downstream code.
setattr(model, "get_input_embeddings", MethodType(get_input_embeddings, model))
ignore_modules = [name for name, _ in model.named_parameters() if "pretrained_model" in name]
setattr(model, "_keys_to_ignore_on_save", ignore_modules)
setattr(model, "tie_weights", MethodType(lambda _: None, model)) # use empty method