fix ChatGLM lm_head #494

Former-commit-id: bf0048abdaeb2b9592d38ac991704ad014370b47
hiyouga
2023-08-14 14:14:48 +08:00
parent 0bfeed3a7e
commit bceaba551d
3 changed files with 12 additions and 8 deletions


@@ -153,6 +153,10 @@ def load_model_and_tokenizer(
     if "GenerationMixin" not in str(model.generate.__func__):
         model.generate = MethodType(PreTrainedModel.generate, model)
 
+    # Fix LM head (for ChatGLM2)
+    if not hasattr(model, "lm_head"):
+        setattr(model, "lm_head", model.transformer.output_layer)
+
     # Register auto class to save the custom code files.
     if isinstance(config, PretrainedConfig) and "AutoConfig" in getattr(config, "auto_map", {}):
         config.__class__.register_for_auto_class()
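For context on the added lines: ChatGLM2's remote code stores its output projection at transformer.output_layer instead of exposing a top-level lm_head, so downstream code that expects model.lm_head fails. Below is a minimal, self-contained sketch of the same aliasing pattern; the Toy* classes are hypothetical stand-ins for the real ChatGLM2 modules, not actual classes.

    from torch import nn

    class ToyTransformer(nn.Module):
        def __init__(self, hidden=8, vocab=32):
            super().__init__()
            # ChatGLM2 keeps its output projection here, not at model.lm_head
            self.output_layer = nn.Linear(hidden, vocab, bias=False)

    class ToyChatGLM2(nn.Module):
        def __init__(self):
            super().__init__()
            self.transformer = ToyTransformer()

    model = ToyChatGLM2()

    # The patch from the diff: alias the existing projection under the
    # attribute name the rest of the training code expects.
    if not hasattr(model, "lm_head"):
        setattr(model, "lm_head", model.transformer.output_layer)

    # An alias, not a copy: both names reference the same module and weights.
    assert model.lm_head is model.transformer.output_layer
    print(model.lm_head.weight.shape)  # torch.Size([32, 8])

Because setattr stores a reference rather than a copy, both attribute names resolve to the same nn.Linear, so any later use of model.lm_head reads and updates the original projection weights.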
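The context lines above the addition rebind generate via types.MethodType whenever a model's remote code ships its own implementation. A simplified, self-contained analog of that rebinding; Base and Custom are hypothetical stand-ins for PreTrainedModel and the remote-code model class.

    from types import MethodType

    class Base:
        def generate(self):
            return "stock generate()"

    class Custom(Base):
        # stands in for a remote-code model that ships its own generate()
        def generate(self):
            return "custom generate()"

    model = Custom()

    # Same check as in the diff: if the bound method does not come from the
    # base class, rebind the base implementation onto this instance.
    if "Base.generate" not in str(model.generate.__func__):
        model.generate = MethodType(Base.generate, model)

    print(model.generate())  # stock generate()

MethodType binds the base function to this one instance, so the patched model regains the stock behavior without touching the remote-code class itself.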