mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-23 22:32:54 +08:00
fix inference
Former-commit-id: d3a0692d4d9033a3b58d68357294854144479536
This commit is contained in:
parent
d5d3b2a42f
commit
e34fc5fd2e
@ -18,10 +18,11 @@ from trl import AutoModelForCausalLMWithValueHead
|
|||||||
from llmtuner.extras.logging import get_logger
|
from llmtuner.extras.logging import get_logger
|
||||||
from llmtuner.extras.misc import count_parameters, prepare_model_for_training
|
from llmtuner.extras.misc import count_parameters, prepare_model_for_training
|
||||||
from llmtuner.extras.save_and_load import load_valuehead_params
|
from llmtuner.extras.save_and_load import load_valuehead_params
|
||||||
|
from llmtuner.hparams import FinetuningArguments
|
||||||
from llmtuner.tuner.core.adapter import init_adapter
|
from llmtuner.tuner.core.adapter import init_adapter
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from llmtuner.hparams import ModelArguments, FinetuningArguments
|
from llmtuner.hparams import ModelArguments
|
||||||
|
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
@ -68,7 +68,7 @@ class PeftTrainer(Seq2SeqTrainer):
|
|||||||
else:
|
else:
|
||||||
torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
|
torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
|
||||||
|
|
||||||
if self.tokenizer is not None:
|
if self.finetuning_args.finetuning_type == "full" and self.tokenizer is not None:
|
||||||
self.tokenizer.save_pretrained(output_dir)
|
self.tokenizer.save_pretrained(output_dir)
|
||||||
|
|
||||||
with open(os.path.join(output_dir, TRAINING_ARGS_NAME), "w", encoding="utf-8") as f:
|
with open(os.path.join(output_dir, TRAINING_ARGS_NAME), "w", encoding="utf-8") as f:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user