diff --git a/src/llmtuner/model/patcher.py b/src/llmtuner/model/patcher.py index e3d7539f..03ca0096 100644 --- a/src/llmtuner/model/patcher.py +++ b/src/llmtuner/model/patcher.py @@ -320,7 +320,7 @@ def patch_model( or (gen_config.typical_p is not None and gen_config.typical_p != 1.0) ): gen_config.do_sample = True - + if "GenerationMixin" not in str(model.generate.__func__): model.generate = MethodType(PreTrainedModel.generate, model)