fix ChatGLM lm_head #494

This commit is contained in:
hiyouga
2023-08-14 14:14:48 +08:00
parent 20a29297b1
commit d019956808
3 changed files with 12 additions and 8 deletions

View File

@@ -153,6 +153,10 @@ def load_model_and_tokenizer(
if "GenerationMixin" not in str(model.generate.__func__):
model.generate = MethodType(PreTrainedModel.generate, model)
# Fix LM head (for ChatGLM2)
if not hasattr(model, "lm_head"):
setattr(model, "lm_head", model.transformer.output_layer)
# Register auto class to save the custom code files.
if isinstance(config, PretrainedConfig) and "AutoConfig" in getattr(config, "auto_map", {}):
config.__class__.register_for_auto_class()