import os
import torch
from typing import Dict, Optional

from transformers import Seq2SeqTrainer
from transformers.trainer import TRAINING_ARGS_NAME
from transformers.modeling_utils import unwrap_model

from llmtuner.extras.constants import FINETUNING_ARGS_NAME, VALUE_HEAD_FILE_NAME
from llmtuner.extras.logging import get_logger
from llmtuner.extras.save_and_load import get_state_dict, load_trainable_params, load_valuehead_params
from llmtuner.hparams import FinetuningArguments


logger = get_logger(__name__)


class PeftTrainer(Seq2SeqTrainer):
    r"""
    Inherits Seq2SeqTrainer to support parameter-efficient checkpoints.
    """

    def __init__(self, finetuning_args: FinetuningArguments, **kwargs):
        super().__init__(**kwargs)
        self.finetuning_args = finetuning_args
        if self.is_world_process_zero() and os.path.exists(os.path.join(self.args.output_dir, "trainer_log.jsonl")):
            logger.warning("Previous log file in this folder will be deleted.")
            os.remove(os.path.join(self.args.output_dir, "trainer_log.jsonl"))

    def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, torch.Tensor]] = None) -> None:
        r"""
        Saves trainable parameters as a model checkpoint.

        This function will only be executed on process zero.

        Subclass and override to inject custom behavior. It should not be directly used by external scripts.
        """
        output_dir = output_dir if output_dir is not None else self.args.output_dir
        os.makedirs(output_dir, exist_ok=True)
        logger.info(f"Saving model checkpoint to {output_dir}")
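        # Unwrap DDP/Accelerate wrappers to reach the underlying module
        # (which may be a value-head model wrapping the backbone).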
        model = unwrap_model(self.model)

        if hasattr(model, "pretrained_model"): # for models with valuehead (currently using LoRA only)
            backbone_model = getattr(model, "pretrained_model")
            torch.save(get_state_dict(getattr(model, "v_head")), os.path.join(output_dir, VALUE_HEAD_FILE_NAME))
        else:
            backbone_model = model

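        # LoRA checkpoints keep only the adapter weights; freeze/full tuning
        # saves the backbone's trainable state dict directly.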
        if self.finetuning_args.finetuning_type == "lora":
            backbone_model.save_pretrained(output_dir, state_dict=get_state_dict(backbone_model))
        else: # freeze/full tuning
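            # `use_cache` is typically disabled during training (e.g. when gradient
            # checkpointing is on); enable it in the exported config so the saved
            # checkpoint is inference-ready, then restore the training setting.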
            backbone_model.config.use_cache = True
            backbone_model.save_pretrained(
                output_dir,
                state_dict=get_state_dict(backbone_model),
                safe_serialization=self.args.save_safetensors
            )
            backbone_model.config.use_cache = False
            if self.tokenizer is not None:
                self.tokenizer.save_pretrained(output_dir)

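        # Persist the training and finetuning arguments next to the weights so
        # the run can be reproduced from this checkpoint.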
        with open(os.path.join(output_dir, TRAINING_ARGS_NAME), "w", encoding="utf-8") as f:
            f.write(self.args.to_json_string() + "\n")
        self.finetuning_args.save_to_json(os.path.join(output_dir, FINETUNING_ARGS_NAME))

    def _load_best_model(self):
        r"""
        Loads trainable parameters from a model checkpoint.

        Subclass and override to inject custom behavior. It should not be directly used by external scripts.
        """
        logger.info(f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric}).")

        model = unwrap_model(self.model)
        backbone_model = getattr(model, "pretrained_model") if hasattr(model, "pretrained_model") else model

        if self.finetuning_args.finetuning_type == "lora":
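            # Reload the best adapter weights into the currently active adapter.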
            backbone_model.load_adapter(self.state.best_model_checkpoint, getattr(backbone_model, "active_adapter"))
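            # `load_valuehead_params` stashes the saved value-head tensors on the
            # model as `reward_head_weight` / `reward_head_bias`; copy them into
            # the live `v_head` module.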
            if hasattr(model, "v_head") and load_valuehead_params(model, self.state.best_model_checkpoint):
                model.v_head.load_state_dict({
                    "summary.weight": getattr(model, "reward_head_weight"),
                    "summary.bias": getattr(model, "reward_head_bias")
                })
        else: # freeze/full-tuning
            load_trainable_params(backbone_model, self.state.best_model_checkpoint)
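

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of this module).
# `model`, `tokenizer`, and `dataset` are hypothetical placeholders assumed to
# be prepared elsewhere (e.g. by llmtuner's model loading utilities):
#
#     from transformers import Seq2SeqTrainingArguments
#     from llmtuner.hparams import FinetuningArguments
#
#     training_args = Seq2SeqTrainingArguments(output_dir="outputs")
#     finetuning_args = FinetuningArguments(finetuning_type="lora")
#     trainer = PeftTrainer(
#         finetuning_args=finetuning_args,
#         model=model,
#         args=training_args,
#         tokenizer=tokenizer,
#         train_dataset=dataset,
#     )
#     trainer.train()
#     trainer.save_model()  # routes through the overridden _save() above
# ---------------------------------------------------------------------------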