mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-12-15 19:30:36 +08:00
fix ChatGLM lm_head #494
This commit is contained in:
@@ -153,6 +153,10 @@ def load_model_and_tokenizer(
|
||||
if "GenerationMixin" not in str(model.generate.__func__):
|
||||
model.generate = MethodType(PreTrainedModel.generate, model)
|
||||
|
||||
# Fix LM head (for ChatGLM2)
|
||||
if not hasattr(model, "lm_head"):
|
||||
setattr(model, "lm_head", model.transformer.output_layer)
|
||||
|
||||
# Register auto class to save the custom code files.
|
||||
if isinstance(config, PretrainedConfig) and "AutoConfig" in getattr(config, "auto_map", {}):
|
||||
config.__class__.register_for_auto_class()
|
||||
|
||||
Reference in New Issue
Block a user