From 9acab4949d7aca06f4a12f9d05d6fa3fc1d8347b Mon Sep 17 00:00:00 2001
From: Yaowei Zheng
Date: Sat, 7 Jun 2025 08:47:50 +0800
Subject: [PATCH] [model] fix model generate (#8327)

---
 src/llamafactory/model/patcher.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/llamafactory/model/patcher.py b/src/llamafactory/model/patcher.py
index 20228812..4bf1d21d 100644
--- a/src/llamafactory/model/patcher.py
+++ b/src/llamafactory/model/patcher.py
@@ -17,7 +17,7 @@ from typing import TYPE_CHECKING, Any
 
 import torch
 from peft import PeftModel
-from transformers import PreTrainedModel, PreTrainedTokenizerBase
+from transformers import GenerationMixin, PreTrainedModel, PreTrainedTokenizerBase
 from transformers.integrations import is_deepspeed_zero3_enabled
 from transformers.modeling_utils import is_fsdp_enabled
 
@@ -169,7 +169,7 @@ def patch_model(
     if getattr(model.config, "model_type", None) not in ["minicpmv", "minicpmo"] and "GenerationMixin" not in str(
         model.generate.__func__
     ):
-        model.generate = MethodType(PreTrainedModel.generate, model)
+        model.generate = MethodType(GenerationMixin.generate, model)
 
     if add_valuehead:
         prepare_valuehead_model(model)