From e34fc5fd2e9b17f1aa4cc982d779066e6f301df3 Mon Sep 17 00:00:00 2001
From: hiyouga
Date: Tue, 1 Aug 2023 00:06:48 +0800
Subject: [PATCH] fix inference

Former-commit-id: d3a0692d4d9033a3b58d68357294854144479536
---
 src/llmtuner/tuner/core/loader.py  | 3 ++-
 src/llmtuner/tuner/core/trainer.py | 2 +-
 2 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/src/llmtuner/tuner/core/loader.py b/src/llmtuner/tuner/core/loader.py
index 921dbc11..22b657e3 100644
--- a/src/llmtuner/tuner/core/loader.py
+++ b/src/llmtuner/tuner/core/loader.py
@@ -18,10 +18,11 @@ from trl import AutoModelForCausalLMWithValueHead
 from llmtuner.extras.logging import get_logger
 from llmtuner.extras.misc import count_parameters, prepare_model_for_training
 from llmtuner.extras.save_and_load import load_valuehead_params
+from llmtuner.hparams import FinetuningArguments
 from llmtuner.tuner.core.adapter import init_adapter
 
 if TYPE_CHECKING:
-    from llmtuner.hparams import ModelArguments, FinetuningArguments
+    from llmtuner.hparams import ModelArguments
 
 
 logger = get_logger(__name__)

diff --git a/src/llmtuner/tuner/core/trainer.py b/src/llmtuner/tuner/core/trainer.py
index 928b0d9b..805a3553 100644
--- a/src/llmtuner/tuner/core/trainer.py
+++ b/src/llmtuner/tuner/core/trainer.py
@@ -68,7 +68,7 @@ class PeftTrainer(Seq2SeqTrainer):
         else:
             torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
 
-        if self.tokenizer is not None:
+        if self.finetuning_args.finetuning_type == "full" and self.tokenizer is not None:
             self.tokenizer.save_pretrained(output_dir)
 
         with open(os.path.join(output_dir, TRAINING_ARGS_NAME), "w", encoding="utf-8") as f:
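
Note on the loader.py hunk: names imported under "if TYPE_CHECKING:" exist only for static type checkers and are never bound at runtime, so any code path that actually constructs FinetuningArguments (as the inference loader does) would raise NameError. Moving the import to module level fixes that. A minimal, self-contained sketch of the failure mode; make_default_args() is a hypothetical call site, since the real one in src/llmtuner/tuner/core/loader.py is outside this hunk:

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Only visible to type checkers; never executed at runtime.
    from llmtuner.hparams import FinetuningArguments

def make_default_args():
    # Hypothetical call site for illustration: before the patch,
    # referencing the class at runtime fails because the import
    # above never ran.
    return FinetuningArguments()

try:
    make_default_args()
except NameError as err:
    print(err)  # name 'FinetuningArguments' is not defined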
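
Note on the trainer.py hunk: the tokenizer is now saved only for full-parameter fine-tuning, so adapter-style checkpoints (e.g., LoRA or freeze runs) no longer carry their own tokenizer copy and keep using the base model's tokenizer at inference time. A minimal sketch of the patched condition, with stand-in types for everything outside the hunk; FinetuningArguments below is a hypothetical reduction of llmtuner.hparams.FinetuningArguments, not the real class:

from dataclasses import dataclass

@dataclass
class FinetuningArguments:
    # Stand-in modeling only the field the patched condition reads;
    # typical values are "full", "freeze", or "lora".
    finetuning_type: str = "lora"

class PeftTrainerSketch:
    def __init__(self, finetuning_args, tokenizer=None):
        self.finetuning_args = finetuning_args
        self.tokenizer = tokenizer

    def save_tokenizer(self, output_dir: str) -> None:
        # Patched behavior: write tokenizer files only alongside full
        # model weights; adapter-only checkpoints skip this step.
        if self.finetuning_args.finetuning_type == "full" and self.tokenizer is not None:
            self.tokenizer.save_pretrained(output_dir)

PeftTrainerSketch(FinetuningArguments("lora")).save_tokenizer("outputs/ckpt")  # no-op for LoRA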