From 9a496950aae31651b50f7b985aba2ba9189d37bf Mon Sep 17 00:00:00 2001
From: hiyouga
Date: Thu, 4 Jan 2024 22:53:03 +0800
Subject: [PATCH] fix #2067

Former-commit-id: 368b31f6b7422562e4cd471d54affe644c4f10cb
---
 src/llmtuner/extras/callbacks.py | 34 +++++++++++++++++++++-----------
 1 file changed, 23 insertions(+), 11 deletions(-)

diff --git a/src/llmtuner/extras/callbacks.py b/src/llmtuner/extras/callbacks.py
index fd78391d..44afacf0 100644
--- a/src/llmtuner/extras/callbacks.py
+++ b/src/llmtuner/extras/callbacks.py
@@ -4,9 +4,10 @@ import time
 from typing import TYPE_CHECKING
 from datetime import timedelta
 
-from transformers import TrainerCallback
+from transformers import PreTrainedModel, TrainerCallback
 from transformers.modeling_utils import custom_object_save, unwrap_model
 from transformers.trainer_utils import has_length, PREFIX_CHECKPOINT_DIR
+from peft import PeftModel
 
 from llmtuner.extras.constants import LOG_FILE_NAME
 from llmtuner.extras.logging import get_logger
@@ -19,14 +20,20 @@ if TYPE_CHECKING:
 logger = get_logger(__name__)
 
 
-def _save_model_with_valuehead(model: "AutoModelForCausalLMWithValueHead", output_dir: str) -> None:
-    model.pretrained_model.config.save_pretrained(output_dir)
-    if model.pretrained_model.can_generate():
-        model.pretrained_model.generation_config.save_pretrained(output_dir)
-    if getattr(model, "is_peft_model", False):
-        model.pretrained_model.save_pretrained(output_dir)
-    elif getattr(model.pretrained_model, "_auto_class", None): # must not a peft model
-        custom_object_save(model.pretrained_model, output_dir, config=model.pretrained_model.config)
+def _save_model_with_valuehead(
+    model: "AutoModelForCausalLMWithValueHead",
+    output_dir: str,
+    safe_serialization: bool
+) -> None:
+    if isinstance(model.pretrained_model, (PreTrainedModel, PeftModel)):
+        model.pretrained_model.config.save_pretrained(output_dir)
+        if model.pretrained_model.can_generate():
+            model.pretrained_model.generation_config.save_pretrained(output_dir)
+
+    if getattr(model, "is_peft_model", False):
+        model.pretrained_model.save_pretrained(output_dir, safe_serialization=safe_serialization)
+    elif getattr(model.pretrained_model, "_auto_class", None): # must not a peft model
+        custom_object_save(model.pretrained_model, output_dir, config=model.pretrained_model.config)
 
 
 class SavePeftModelCallback(TrainerCallback):
@@ -38,7 +45,8 @@ class SavePeftModelCallback(TrainerCallback):
         if args.should_save:
             _save_model_with_valuehead(
                 model=unwrap_model(kwargs.pop("model")),
-                output_dir=os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step))
+                output_dir=os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step)),
+                safe_serialization=args.save_safetensors
             )
 
     def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
@@ -46,7 +54,11 @@ class SavePeftModelCallback(TrainerCallback):
         Event called at the end of training.
         """
         if args.should_save:
-            _save_model_with_valuehead(model=unwrap_model(kwargs.pop("model")), output_dir=args.output_dir)
+            _save_model_with_valuehead(
+                model=unwrap_model(kwargs.pop("model")),
+                output_dir=args.output_dir,
+                safe_serialization=args.save_safetensors
+            )
 
 
 class LogCallback(TrainerCallback):
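
Note (not part of the patch): a minimal usage sketch of the reworked helper, showing how args.save_safetensors is forwarded to save_pretrained so the adapter checkpoint can be written in safetensors format. It assumes trl's AutoModelForCausalLMWithValueHead and a LoRA adapter from peft; "gpt2" and "outputs" are placeholder values, and importing the private helper directly is only for illustration, since in practice SavePeftModelCallback calls it from on_save and on_train_end.

from peft import LoraConfig
from transformers import TrainingArguments
from transformers.modeling_utils import unwrap_model
from trl import AutoModelForCausalLMWithValueHead

from llmtuner.extras.callbacks import _save_model_with_valuehead

# save_safetensors is the TrainingArguments flag the callback forwards.
args = TrainingArguments(output_dir="outputs", save_safetensors=True)

# Wrap the base model with a LoRA adapter so is_peft_model is True.
model = AutoModelForCausalLMWithValueHead.from_pretrained(
    "gpt2", peft_config=LoraConfig(task_type="CAUSAL_LM")
)

# With the patch applied: the config and generation config of the wrapped
# base model are saved, and the adapter weights are written with
# safe_serialization=True (adapter_model.safetensors rather than a .bin
# file, with recent peft versions).
_save_model_with_valuehead(
    model=unwrap_model(model),
    output_dir=args.output_dir,
    safe_serialization=args.save_safetensors,
)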