mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-10-14 15:52:49 +08:00
[model] fix model generate (#8327)
This commit is contained in:
parent
32b4574094
commit
9acab4949d
@ -17,7 +17,7 @@ from typing import TYPE_CHECKING, Any
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
from peft import PeftModel
|
from peft import PeftModel
|
||||||
from transformers import PreTrainedModel, PreTrainedTokenizerBase
|
from transformers import GenerationMixin, PreTrainedModel, PreTrainedTokenizerBase
|
||||||
from transformers.integrations import is_deepspeed_zero3_enabled
|
from transformers.integrations import is_deepspeed_zero3_enabled
|
||||||
from transformers.modeling_utils import is_fsdp_enabled
|
from transformers.modeling_utils import is_fsdp_enabled
|
||||||
|
|
||||||
@ -169,7 +169,7 @@ def patch_model(
|
|||||||
if getattr(model.config, "model_type", None) not in ["minicpmv", "minicpmo"] and "GenerationMixin" not in str(
|
if getattr(model.config, "model_type", None) not in ["minicpmv", "minicpmo"] and "GenerationMixin" not in str(
|
||||||
model.generate.__func__
|
model.generate.__func__
|
||||||
):
|
):
|
||||||
model.generate = MethodType(PreTrainedModel.generate, model)
|
model.generate = MethodType(GenerationMixin.generate, model)
|
||||||
|
|
||||||
if add_valuehead:
|
if add_valuehead:
|
||||||
prepare_valuehead_model(model)
|
prepare_valuehead_model(model)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user