Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-12-29 10:10:35 +08:00)
fix ppo train and dpo eval
Former-commit-id: ced863031836632cb5920e22ae6991f251372118
@@ -15,6 +15,7 @@ from transformers import (
 )
 from transformers.models.llama import modeling_llama as LlamaModule
 from transformers.utils.versions import require_version
+from peft import PeftModel
 from trl import AutoModelForCausalLMWithValueHead

 try:
@@ -55,9 +56,6 @@ def load_model_and_tokenizer(

     Support both training and inference.
     """
-    if (not is_trainable) and model_args.checkpoint_dir is None:
-        logger.warning("Checkpoint is not found at evaluation, load the original model.")
-        finetuning_args = FinetuningArguments(finetuning_type="none")

     config_kwargs = {
         "trust_remote_code": True,
@@ -212,8 +210,11 @@ def load_model_and_tokenizer(

     if stage == "ppo": # load reward model
         logger.info("Load reward model from {}".format(model_args.reward_model))
-        if getattr(model, "is_peft_model", False):
+        if isinstance(model.pretrained_model, PeftModel):
             model.pretrained_model.load_adapter(model_args.reward_model, "reward")
+            for name, param in model.named_parameters(): # https://github.com/huggingface/peft/issues/1090
+                if "default" in name:
+                    param.data = param.data.to(torch.float32) # trainable params should in fp32
         assert load_valuehead_params(model, model_args), "Reward model is not correctly loaded."

     # Prepare model for inference
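The reward-model hunk relies on a PEFT pattern: the PPO policy is a value-head model wrapping a LoRA PeftModel, the reward model's LoRA weights are attached as a second adapter named "reward" via load_adapter, and the trainable "default" adapter is kept in fp32 (see the linked peft issue). Below is a minimal sketch of that pattern, assuming a small placeholder base model, a placeholder reward-adapter path, and a trl version whose from_pretrained accepts a PEFT model instance, as this repository's loader does; it is an illustration, not the project's actual loading code.

import torch
from peft import LoraConfig, PeftModel, get_peft_model
from transformers import AutoModelForCausalLM
from trl import AutoModelForCausalLMWithValueHead

# Placeholder base model; any causal LM with q_proj/v_proj attention modules works here.
base = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")

# Policy model: a LoRA adapter named "default" on top of the base model.
policy = get_peft_model(base, LoraConfig(task_type="CAUSAL_LM", target_modules=["q_proj", "v_proj"]))

# trl wraps the PEFT model and attaches the value head used for PPO.
model = AutoModelForCausalLMWithValueHead.from_pretrained(policy)

# Attach the reward model's LoRA weights as a second adapter named "reward",
# mirroring the diff; "path/to/reward_adapter" is a placeholder, not a real path.
if isinstance(model.pretrained_model, PeftModel):
    model.pretrained_model.load_adapter("path/to/reward_adapter", "reward")
    # Keep the trainable ("default") adapter parameters in fp32,
    # see https://github.com/huggingface/peft/issues/1090
    for name, param in model.named_parameters():
        if "default" in name:
            param.data = param.data.to(torch.float32)

With both adapters on one backbone, the PPO stage can switch between "default" (policy) and "reward" (scoring) instead of keeping a separate reward model in memory.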