mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-11-06 02:42:15 +08:00
fix valuehead model
Former-commit-id: 9f628debb6510f2d1c91b00f121a721ab5d648e9
This commit is contained in:
parent
c9ff152cef
commit
511f3f68e2
@ -203,6 +203,9 @@ def load_model_and_tokenizer(
|
|||||||
# Prepare model with valuehead for RLHF
|
# Prepare model with valuehead for RLHF
|
||||||
if add_valuehead:
|
if add_valuehead:
|
||||||
model: "AutoModelForCausalLMWithValueHead" = AutoModelForCausalLMWithValueHead.from_pretrained(model)
|
model: "AutoModelForCausalLMWithValueHead" = AutoModelForCausalLMWithValueHead.from_pretrained(model)
|
||||||
|
def get_input_embeddings(self: "AutoModelForCausalLMWithValueHead") -> torch.nn.Module:
|
||||||
|
return self.pretrained_model.get_input_embeddings()
|
||||||
|
setattr(model, "get_input_embeddings", MethodType(get_input_embeddings, model))
|
||||||
ignore_modules = [name for name, _ in model.named_parameters() if "pretrained_model" in name]
|
ignore_modules = [name for name, _ in model.named_parameters() if "pretrained_model" in name]
|
||||||
setattr(model, "_keys_to_ignore_on_save", ignore_modules)
|
setattr(model, "_keys_to_ignore_on_save", ignore_modules)
|
||||||
setattr(model, "tie_weights", MethodType(lambda _: None, model)) # use empty method
|
setattr(model, "tie_weights", MethodType(lambda _: None, model)) # use empty method
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user