mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-12-16 11:50:35 +08:00
Initial commit
This commit is contained in:
78
src/utils/peft_trainer.py
Normal file
78
src/utils/peft_trainer.py
Normal file
@@ -0,0 +1,78 @@
|
||||
import os
|
||||
import torch
|
||||
from typing import Dict, Optional
|
||||
|
||||
from transformers import Seq2SeqTrainer
|
||||
from transformers.trainer import TRAINING_ARGS_NAME
|
||||
from transformers.modeling_utils import unwrap_model
|
||||
|
||||
from peft.utils.other import WEIGHTS_NAME
|
||||
|
||||
from .config import FinetuningArguments
|
||||
|
||||
from .other import (
|
||||
get_logger,
|
||||
get_state_dict,
|
||||
load_trainable_params,
|
||||
load_valuehead_params,
|
||||
FINETUNING_ARGS_NAME,
|
||||
VALUE_HEAD_FILE_NAME
|
||||
)
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class PeftTrainer(Seq2SeqTrainer):
|
||||
r"""
|
||||
Inherits Seq2SeqTrainer to support parameter-efficient checkpoints.
|
||||
"""
|
||||
|
||||
def __init__(self, finetuning_args: FinetuningArguments, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.finetuning_args = finetuning_args
|
||||
|
||||
def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, torch.Tensor]] = None) -> None:
|
||||
r"""
|
||||
Saves trainable parameters as model checkpoint.
|
||||
|
||||
This function will only be executed at the process zero.
|
||||
|
||||
Subclass and override to inject custom behavior. It should not be directly used by external scripts.
|
||||
"""
|
||||
output_dir = output_dir if output_dir is not None else self.args.output_dir
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
logger.info(f"Saving model checkpoint to {output_dir}")
|
||||
model = unwrap_model(self.model)
|
||||
|
||||
if hasattr(model, "pretrained_model"): # for models with valuehead
|
||||
backbone_model = getattr(model, "pretrained_model")
|
||||
else:
|
||||
backbone_model = model
|
||||
|
||||
if hasattr(backbone_model, "peft_config"): # peft methods
|
||||
backbone_model.save_pretrained(output_dir, state_dict=get_state_dict(backbone_model)) # save lora weights
|
||||
else:
|
||||
torch.save(get_state_dict(backbone_model), os.path.join(output_dir, WEIGHTS_NAME)) # save trainable weights
|
||||
|
||||
if hasattr(model, "v_head"): # save valuehead weights
|
||||
torch.save(get_state_dict(getattr(model, "v_head")), os.path.join(output_dir, VALUE_HEAD_FILE_NAME))
|
||||
|
||||
torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
|
||||
self.finetuning_args.save_to_json(os.path.join(output_dir, FINETUNING_ARGS_NAME))
|
||||
|
||||
def _load_best_model(self):
|
||||
r"""
|
||||
Loads trainable parameters from model checkpoint.
|
||||
|
||||
Subclass and override to inject custom behavior. It should not be directly used by external scripts.
|
||||
"""
|
||||
logger.info(f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric}).")
|
||||
model = unwrap_model(self.model)
|
||||
if hasattr(model, "peft_config"): # peft methods
|
||||
model.load_adapter(self.state.best_model_checkpoint, getattr(model, "active_adapter"))
|
||||
else:
|
||||
load_trainable_params(model, self.state.best_model_checkpoint)
|
||||
|
||||
if hasattr(model, "v_head"):
|
||||
load_valuehead_params(model, self.state.best_model_checkpoint)
|
||||
Reference in New Issue
Block a user