mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-04 12:42:51 +08:00
fix valuehead patch
Former-commit-id: 24d8d6f224ccb98387ec72e688fa32f5f308dd07
This commit is contained in:
parent
1a86cc3078
commit
eb021ca748
@ -277,5 +277,6 @@ def patch_valuehead_model(model: "AutoModelForCausalLMWithValueHead") -> None:
|
|||||||
ignore_modules = [name for name, _ in model.named_parameters() if "pretrained_model" in name]
|
ignore_modules = [name for name, _ in model.named_parameters() if "pretrained_model" in name]
|
||||||
setattr(model, "_keys_to_ignore_on_save", ignore_modules)
|
setattr(model, "_keys_to_ignore_on_save", ignore_modules)
|
||||||
setattr(model, "_no_split_modules", getattr(model.pretrained_model, "_no_split_modules", None))
|
setattr(model, "_no_split_modules", getattr(model.pretrained_model, "_no_split_modules", None))
|
||||||
|
setattr(model, "dtype", getattr(model.pretrained_model, "dtype", None))
|
||||||
setattr(model, "tie_weights", MethodType(tie_weights, model))
|
setattr(model, "tie_weights", MethodType(tie_weights, model))
|
||||||
setattr(model, "get_input_embeddings", MethodType(get_input_embeddings, model))
|
setattr(model, "get_input_embeddings", MethodType(get_input_embeddings, model))
|
||||||
|
Loading…
x
Reference in New Issue
Block a user