import torch
from typing import TYPE_CHECKING, Literal, Optional, Union

from llmtuner.extras.logging import get_logger
from llmtuner.hparams import ModelArguments, FinetuningArguments
from llmtuner.model import load_model_and_tokenizer, load_valuehead_params

if TYPE_CHECKING:
    from transformers.modeling_utils import PreTrainedModel
    from trl import AutoModelForCausalLMWithValueHead


logger = get_logger(__name__)


def create_ref_model(
    model_args: "ModelArguments",
    finetuning_args: "FinetuningArguments",
    stage: Literal["ppo", "dpo"]
) -> Optional[Union["PreTrainedModel", "AutoModelForCausalLMWithValueHead"]]:
    r"""
    Creates a reference model for PPO/DPO training. Evaluation mode is not supported.

    Returns None when training with LoRA and no explicit reference model is given.
    The value head parameters are randomly initialized since they are useless for PPO training.
    """
    if finetuning_args.ref_model is not None:
        # load the user-specified reference model with its own checkpoint and quantization settings
        ref_model_args_dict = model_args.to_dict()
        ref_model_args_dict.update(dict(
            model_name_or_path=finetuning_args.ref_model,
            checkpoint_dir=finetuning_args.ref_model_checkpoint,
            quantization_bit=finetuning_args.ref_model_quantization_bit
        ))
        ref_model_args = ModelArguments(**ref_model_args_dict)
        ref_finetuning_args = FinetuningArguments(finetuning_type="lora")
        ref_model, _ = load_model_and_tokenizer(ref_model_args, ref_finetuning_args, is_trainable=False, stage=stage)
        logger.info("Created reference model from {}".format(finetuning_args.ref_model))
    else:
        if finetuning_args.finetuning_type == "lora":
            # with LoRA, no separate reference model is needed: the trainer can recover
            # reference outputs by disabling the adapter on the policy model
            ref_model = None
        else:
            ref_model, _ = load_model_and_tokenizer(model_args, finetuning_args, is_trainable=False, stage=stage)
            logger.info("Created reference model from the model itself.")

    return ref_model
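
# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal example of how `create_ref_model` might be wired into a DPO run. The
# argument values below are assumptions for illustration; in the project they come
# from the command-line argument parser rather than being constructed by hand.
#
#     model_args = ModelArguments(model_name_or_path="meta-llama/Llama-2-7b-hf")
#     finetuning_args = FinetuningArguments(finetuning_type="lora")
#     ref_model = create_ref_model(model_args, finetuning_args, stage="dpo")
#     # ref_model is None here, so the trainer falls back to the policy model
#     # with its LoRA adapter disabled as the implicit reference.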


def create_reward_model(
    model: "AutoModelForCausalLMWithValueHead",
    model_args: "ModelArguments",
    finetuning_args: "FinetuningArguments"
) -> Optional["AutoModelForCausalLMWithValueHead"]:
    r"""
    Creates a reward model for PPO training.

    Returns None when the reward model is given as a LoRA adapter, since the adapter and
    its value head are attached to the policy model instead of a standalone model.
    """
    if finetuning_args.reward_model_type == "lora":
        # attach the reward model as an extra LoRA adapter named "reward" on the policy model
        model.pretrained_model.load_adapter(finetuning_args.reward_model, "reward")
        for name, param in model.named_parameters():  # https://github.com/huggingface/peft/issues/1090
            if "default" in name:
                param.data = param.data.to(torch.float32)  # trainable params should be in fp32
        vhead_params = load_valuehead_params(finetuning_args.reward_model, model_args)
        assert vhead_params is not None, "Reward model is not correctly loaded."
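        # The reward model's value-head weights are kept on the policy model as non-persistent
        # buffers ("reward_head_*"), alongside zero-initialized "default_head_*" placeholders.
        # Presumably the PPO trainer swaps these into `model.v_head` whenever it needs reward
        # scores, so no standalone reward model has to be kept in memory.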
        model.register_buffer("reward_head_weight", vhead_params["v_head.summary.weight"], persistent=False)
        model.register_buffer("reward_head_bias", vhead_params["v_head.summary.bias"], persistent=False)
        model.register_buffer("default_head_weight", torch.zeros_like(vhead_params["v_head.summary.weight"]), persistent=False)
        model.register_buffer("default_head_bias", torch.zeros_like(vhead_params["v_head.summary.bias"]), persistent=False)
        logger.info("Loaded adapter weights of reward model from {}".format(finetuning_args.reward_model))
        return None
    else:
        # load a full standalone reward model with its own checkpoint and quantization settings
        reward_model_args_dict = model_args.to_dict()
        reward_model_args_dict.update(dict(
            model_name_or_path=finetuning_args.reward_model,
            checkpoint_dir=finetuning_args.reward_model_checkpoint,
            quantization_bit=finetuning_args.reward_model_quantization_bit
        ))
        reward_model_args = ModelArguments(**reward_model_args_dict)
        reward_finetuning_args = FinetuningArguments(finetuning_type="lora")
        reward_model, _ = load_model_and_tokenizer(reward_model_args, reward_finetuning_args, is_trainable=False, stage="ppo")
        logger.info("Loaded full weights of reward model from {}".format(finetuning_args.reward_model))
        logger.warning("Please ensure the PPO model and reward model share the SAME tokenizer and vocabulary.")
        return reward_model
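
# --- Usage sketch (illustrative, not part of the original module) ---
# When reward_model_type == "lora", the reward adapter and value-head buffers live on the
# policy model itself, so a PPO step can score responses by switching adapters and swapping
# head weights. A hypothetical helper for that swap might look roughly like this; the real
# logic lives in the project's PPO trainer utilities and may differ.
#
#     def switch_value_head(model, target):  # target is "reward" or "default"
#         model.pretrained_model.set_adapter(target)
#         model.v_head.summary.weight.data = getattr(model, "{}_head_weight".format(target)).detach().clone()
#         model.v_head.summary.bias.data = getattr(model, "{}_head_bias".format(target)).detach().clone()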