diff --git a/src/llamafactory/model/patcher.py b/src/llamafactory/model/patcher.py index 20228812..4bf1d21d 100644 --- a/src/llamafactory/model/patcher.py +++ b/src/llamafactory/model/patcher.py @@ -17,7 +17,7 @@ from typing import TYPE_CHECKING, Any import torch from peft import PeftModel -from transformers import PreTrainedModel, PreTrainedTokenizerBase +from transformers import GenerationMixin, PreTrainedModel, PreTrainedTokenizerBase from transformers.integrations import is_deepspeed_zero3_enabled from transformers.modeling_utils import is_fsdp_enabled @@ -169,7 +169,7 @@ def patch_model( if getattr(model.config, "model_type", None) not in ["minicpmv", "minicpmo"] and "GenerationMixin" not in str( model.generate.__func__ ): - model.generate = MethodType(PreTrainedModel.generate, model) + model.generate = MethodType(GenerationMixin.generate, model) if add_valuehead: prepare_valuehead_model(model)