diff --git a/src/llmtuner/model/patcher.py b/src/llmtuner/model/patcher.py
index cb55f5ed..e3d7539f 100644
--- a/src/llmtuner/model/patcher.py
+++ b/src/llmtuner/model/patcher.py
@@ -312,6 +312,15 @@ def patch_config(
 def patch_model(
     model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments", is_trainable: bool
 ) -> None:
+    # Config check and fix
+    gen_config = model.generation_config
+    if not gen_config.do_sample and (
+        (gen_config.temperature is not None and gen_config.temperature != 1.0)
+        or (gen_config.top_p is not None and gen_config.top_p != 1.0)
+        or (gen_config.typical_p is not None and gen_config.typical_p != 1.0)
+    ):
+        gen_config.do_sample = True
+
     if "GenerationMixin" not in str(model.generate.__func__):
         model.generate = MethodType(PreTrainedModel.generate, model)

diff --git a/src/llmtuner/train/tuner.py b/src/llmtuner/train/tuner.py
index a03955d5..1b8e3cb7 100644
--- a/src/llmtuner/train/tuner.py
+++ b/src/llmtuner/train/tuner.py
@@ -64,14 +64,6 @@ def export_model(args: Optional[Dict[str, Any]] = None):
     for param in model.parameters():
         param.data = param.data.to(output_dtype)

-    gen_config = model.generation_config  # check and fix generation config
-    if not gen_config.do_sample and (
-        (gen_config.temperature is not None and gen_config.temperature != 1.0)
-        or (gen_config.top_p is not None and gen_config.top_p != 1.0)
-        or (gen_config.typical_p is not None and gen_config.typical_p != 1.0)
-    ):
-        gen_config.do_sample = True
-
     model.save_pretrained(
         save_directory=model_args.export_dir,
         max_shard_size="{}GB".format(model_args.export_size),
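
The diff moves the generation-config consistency check from export time (`export_model` in `tuner.py`) into `patch_model` in `patcher.py`, so the fix is applied whenever a model is loaded rather than only when it is exported. The check guards against a config that requests greedy decoding (`do_sample=False`) while also setting sampling-only parameters (`temperature`, `top_p`, `typical_p`) to non-default values, which `transformers` flags as inconsistent. Below is a minimal standalone sketch of the same logic for illustration; the helper name `fix_generation_config` is hypothetical and not part of the patch.

```python
from transformers import GenerationConfig


def fix_generation_config(gen_config: GenerationConfig) -> GenerationConfig:
    # Hypothetical standalone version of the check in the patch:
    # if greedy decoding is selected but sampling-only knobs deviate
    # from their defaults, switch to sampling so the config is consistent.
    sampling_knobs_set = (
        (gen_config.temperature is not None and gen_config.temperature != 1.0)
        or (gen_config.top_p is not None and gen_config.top_p != 1.0)
        or (gen_config.typical_p is not None and gen_config.typical_p != 1.0)
    )
    if not gen_config.do_sample and sampling_knobs_set:
        gen_config.do_sample = True
    return gen_config


# Example: a config with do_sample=False but temperature/top_p tuned for sampling.
cfg = GenerationConfig(do_sample=False, temperature=0.7, top_p=0.9)
cfg = fix_generation_config(cfg)
assert cfg.do_sample is True
```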