[model] fix model generate (#8327)

This commit is contained in:
Yaowei Zheng 2025-06-07 08:47:50 +08:00 committed by GitHub
parent d325a1a7c7
commit 7ecc2d46ca
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -17,7 +17,7 @@ from typing import TYPE_CHECKING, Any
import torch
from peft import PeftModel
from transformers import PreTrainedModel, PreTrainedTokenizerBase
from transformers import GenerationMixin, PreTrainedModel, PreTrainedTokenizerBase
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.modeling_utils import is_fsdp_enabled
@ -169,7 +169,7 @@ def patch_model(
if getattr(model.config, "model_type", None) not in ["minicpmv", "minicpmo"] and "GenerationMixin" not in str(
model.generate.__func__
):
model.generate = MethodType(PreTrainedModel.generate, model)
model.generate = MethodType(GenerationMixin.generate, model)
if add_valuehead:
prepare_valuehead_model(model)