From 7168392a51c0113b9d5c24a6cb3be961378fdf78 Mon Sep 17 00:00:00 2001 From: hiyouga Date: Wed, 3 Jan 2024 16:19:23 +0800 Subject: [PATCH] fix valuehead patch Former-commit-id: d9cb98362b58b28ae0ee207e7c07e75e5d810876 --- src/llmtuner/model/patcher.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/llmtuner/model/patcher.py b/src/llmtuner/model/patcher.py index 788f6e60..e63d9477 100644 --- a/src/llmtuner/model/patcher.py +++ b/src/llmtuner/model/patcher.py @@ -277,5 +277,6 @@ def patch_valuehead_model(model: "AutoModelForCausalLMWithValueHead") -> None: ignore_modules = [name for name, _ in model.named_parameters() if "pretrained_model" in name] setattr(model, "_keys_to_ignore_on_save", ignore_modules) setattr(model, "_no_split_modules", getattr(model.pretrained_model, "_no_split_modules", None)) + setattr(model, "dtype", getattr(model.pretrained_model, "dtype", None)) setattr(model, "tie_weights", MethodType(tie_weights, model)) setattr(model, "get_input_embeddings", MethodType(get_input_embeddings, model))