Mirror of https://github.com/hiyouga/LLaMA-Factory.git, synced 2025-08-01 11:12:50 +08:00
[model] fix model generate (#8327)
This commit is contained in:
parent d325a1a7c7
commit 7ecc2d46ca
@@ -17,7 +17,7 @@ from typing import TYPE_CHECKING, Any
 import torch
 from peft import PeftModel
-from transformers import PreTrainedModel, PreTrainedTokenizerBase
+from transformers import GenerationMixin, PreTrainedModel, PreTrainedTokenizerBase
 from transformers.integrations import is_deepspeed_zero3_enabled
 from transformers.modeling_utils import is_fsdp_enabled
@@ -169,7 +169,7 @@ def patch_model(
     if getattr(model.config, "model_type", None) not in ["minicpmv", "minicpmo"] and "GenerationMixin" not in str(
         model.generate.__func__
     ):
-        model.generate = MethodType(PreTrainedModel.generate, model)
+        model.generate = MethodType(GenerationMixin.generate, model)

     if add_valuehead:
         prepare_valuehead_model(model)
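For context, a minimal sketch of the rebinding this patch performs, assuming a stock causal LM loaded with transformers; the checkpoint name and the standalone script are illustrative, while the __func__ check and the MethodType rebinding follow the hunk above:

from types import MethodType

from transformers import AutoModelForCausalLM, GenerationMixin

# Illustrative checkpoint; any causal LM from the Hub would do.
model = AutoModelForCausalLM.from_pretrained("gpt2")

# If the model's generate() is a custom override rather than the stock
# GenerationMixin implementation, rebind the canonical generate() onto the
# instance so downstream generation utilities behave consistently.
if "GenerationMixin" not in str(model.generate.__func__):
    model.generate = MethodType(GenerationMixin.generate, model)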