mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-08-04 20:52:59 +08:00)
parent 3fff67e3c7
commit 9a496950aa
@@ -4,9 +4,10 @@ import time
 from typing import TYPE_CHECKING
 from datetime import timedelta
 
-from transformers import TrainerCallback
+from transformers import PreTrainedModel, TrainerCallback
 from transformers.modeling_utils import custom_object_save, unwrap_model
 from transformers.trainer_utils import has_length, PREFIX_CHECKPOINT_DIR
+from peft import PeftModel
 
 from llmtuner.extras.constants import LOG_FILE_NAME
 from llmtuner.extras.logging import get_logger
@@ -19,14 +20,20 @@ if TYPE_CHECKING:
 logger = get_logger(__name__)
 
 
-def _save_model_with_valuehead(model: "AutoModelForCausalLMWithValueHead", output_dir: str) -> None:
-    model.pretrained_model.config.save_pretrained(output_dir)
-    if model.pretrained_model.can_generate():
-        model.pretrained_model.generation_config.save_pretrained(output_dir)
-    if getattr(model, "is_peft_model", False):
-        model.pretrained_model.save_pretrained(output_dir)
-    elif getattr(model.pretrained_model, "_auto_class", None): # must not a peft model
-        custom_object_save(model.pretrained_model, output_dir, config=model.pretrained_model.config)
+def _save_model_with_valuehead(
+    model: "AutoModelForCausalLMWithValueHead",
+    output_dir: str,
+    safe_serialization: bool
+) -> None:
+    if isinstance(model.pretrained_model, (PreTrainedModel, PeftModel)):
+        model.pretrained_model.config.save_pretrained(output_dir)
+        if model.pretrained_model.can_generate():
+            model.pretrained_model.generation_config.save_pretrained(output_dir)
+
+    if getattr(model, "is_peft_model", False):
+        model.pretrained_model.save_pretrained(output_dir, safe_serialization=safe_serialization)
+    elif getattr(model.pretrained_model, "_auto_class", None): # must not a peft model
+        custom_object_save(model.pretrained_model, output_dir, config=model.pretrained_model.config)
 
 
 class SavePeftModelCallback(TrainerCallback):
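
Note: the refactored helper threads the new safe_serialization flag through to save_pretrained, which selects the checkpoint weight format. A minimal sketch of that effect outside this repository, assuming a tiny test model id chosen only for illustration:

    # Hedged sketch (not part of the commit): what safe_serialization changes
    # when weights are written to disk. The model id below is an assumption
    # used only so the snippet runs quickly.
    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2")

    # True  -> weights stored as model.safetensors (safetensors format)
    model.save_pretrained("ckpt_safetensors", safe_serialization=True)

    # False -> weights stored as pytorch_model.bin (pickle-based torch format)
    model.save_pretrained("ckpt_pickle", safe_serialization=False)
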
@@ -38,7 +45,8 @@ class SavePeftModelCallback(TrainerCallback):
         if args.should_save:
             _save_model_with_valuehead(
                 model=unwrap_model(kwargs.pop("model")),
-                output_dir=os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step))
+                output_dir=os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step)),
+                safe_serialization=args.save_safetensors
             )
 
     def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
@@ -46,7 +54,11 @@ class SavePeftModelCallback(TrainerCallback):
         Event called at the end of training.
         """
         if args.should_save:
-            _save_model_with_valuehead(model=unwrap_model(kwargs.pop("model")), output_dir=args.output_dir)
+            _save_model_with_valuehead(
+                model=unwrap_model(kwargs.pop("model")),
+                output_dir=args.output_dir,
+                safe_serialization=args.save_safetensors
+            )
 
 
 class LogCallback(TrainerCallback):
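
Note: in both callback hooks the flag comes from args.save_safetensors, a standard transformers TrainingArguments field, so value-head checkpoints follow the same serialization setting as the trainer's own saves. A minimal sketch of the wiring, with the trainer construction elided since it depends on the training stage:

    # Hedged sketch: setting save_safetensors on TrainingArguments so that
    # SavePeftModelCallback forwards it to _save_model_with_valuehead.
    from transformers import TrainingArguments

    training_args = TrainingArguments(output_dir="outputs", save_safetensors=True)

    # trainer = <stage-specific trainer>(args=training_args, callbacks=[SavePeftModelCallback()], ...)
    # The callback's save hooks will then write safetensors checkpoints.
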