fix ppo train and dpo eval

Former-commit-id: 01260d975477ebb8570933a1bd7f547b4dba607f
hiyouga 2023-11-07 22:48:51 +08:00
parent 100dc4c458
commit 91f406cc99
5 changed files with 56 additions and 21 deletions

View File

@@ -75,6 +75,14 @@ class FinetuningArguments:
         default=0.1,
         metadata={"help": "The beta parameter for the DPO loss."}
     )
+    dpo_ref_model: Optional[str] = field(
+        default=None,
+        metadata={"help": "Path to the reference model used for the DPO training."}
+    )
+    dpo_ref_model_checkpoint: Optional[str] = field(
+        default=None,
+        metadata={"help": "Path to the directory(s) containing the model checkpoints of the reference model."}
+    )
     upcast_layernorm: Optional[bool] = field(
         default=False,
         metadata={"help": "Whether to upcast the layernorm weights in fp32."}
@@ -91,7 +99,7 @@ class FinetuningArguments:
         if isinstance(self.additional_target, str):
             self.additional_target = [target.strip() for target in self.additional_target.split(",")]

-        assert self.finetuning_type in ["lora", "freeze", "full", "none"], "Invalid fine-tuning method."
+        assert self.finetuning_type in ["lora", "freeze", "full"], "Invalid fine-tuning method."

     def save_to_json(self, json_path: str):
         r"""Saves the content of this instance in JSON format inside `json_path`."""

View File

@ -1,5 +1,5 @@
from typing import Literal, Optional from typing import Any, Dict, Literal, Optional
from dataclasses import dataclass, field from dataclasses import asdict, dataclass, field
@dataclass @dataclass
@ -44,7 +44,7 @@ class ModelArguments:
) )
checkpoint_dir: Optional[str] = field( checkpoint_dir: Optional[str] = field(
default=None, default=None,
metadata={"help": "Path to the directory(s) containing the delta model checkpoints as well as the configurations."} metadata={"help": "Path to the directory(s) containing the model checkpoints as well as the configurations."}
) )
flash_attn: Optional[bool] = field( flash_attn: Optional[bool] = field(
default=False, default=False,
@ -83,3 +83,6 @@ class ModelArguments:
if self.quantization_bit is not None: if self.quantization_bit is not None:
assert self.quantization_bit in [4, 8], "We only accept 4-bit or 8-bit quantization." assert self.quantization_bit in [4, 8], "We only accept 4-bit or 8-bit quantization."
def to_dict(self) -> Dict[str, Any]:
return asdict(self)
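
The new `to_dict()` helper lets callers clone `ModelArguments` with a few overrides, which is how the DPO workflow below builds the arguments for its reference model. A rough illustration (the model names are made up):

    from llmtuner.hparams import ModelArguments

    model_args = ModelArguments(model_name_or_path="base-model")  # hypothetical name
    args_dict = model_args.to_dict()                          # plain dict via dataclasses.asdict
    args_dict.update(model_name_or_path="reference-model")    # hypothetical override
    ref_model_args = ModelArguments(**args_dict)              # independent copy with the override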

View File

@@ -36,8 +36,8 @@ def init_adapter(
     Note that the trainable parameters must be cast to float32.
     """

-    if finetuning_args.finetuning_type == "none" and is_trainable:
-        raise ValueError("You cannot use finetuning_type=none while training.")
+    if (not is_trainable) and model_args.checkpoint_dir is None:
+        logger.info("Checkpoint is not found at evaluation, load the original model.")

     if finetuning_args.finetuning_type == "full" and is_trainable:
         logger.info("Fine-tuning method: Full")
@@ -60,11 +60,11 @@ def init_adapter(
     if finetuning_args.finetuning_type == "lora":
         logger.info("Fine-tuning method: LoRA")
-        latest_checkpoint = None
+        checkpoint_to_resume = None

         if model_args.checkpoint_dir is not None:
-            if is_trainable and finetuning_args.resume_lora_training: # continually fine-tuning
-                checkpoints_to_merge, latest_checkpoint = model_args.checkpoint_dir[:-1], model_args.checkpoint_dir[-1]
+            if is_trainable and finetuning_args.resume_lora_training:
+                checkpoints_to_merge, checkpoint_to_resume = model_args.checkpoint_dir[:-1], model_args.checkpoint_dir[-1]
             else:
                 checkpoints_to_merge = model_args.checkpoint_dir
@@ -75,10 +75,10 @@ def init_adapter(
             if len(checkpoints_to_merge) > 0:
                 logger.info("Merged {} model checkpoint(s).".format(len(checkpoints_to_merge)))

-            if latest_checkpoint is not None: # resume lora training or quantized inference
-                model = PeftModel.from_pretrained(model, latest_checkpoint, is_trainable=is_trainable)
+            if checkpoint_to_resume is not None: # resume lora training
+                model = PeftModel.from_pretrained(model, checkpoint_to_resume, is_trainable=is_trainable)

-        if is_trainable and latest_checkpoint is None: # create new lora weights while training
+        if is_trainable and checkpoint_to_resume is None: # create new lora weights while training
             if len(finetuning_args.lora_target) == 1 and finetuning_args.lora_target[0] == "all":
                 target_modules = find_all_linear_modules(model, model_args.quantization_bit)
             else:
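
The rename from `latest_checkpoint` to `checkpoint_to_resume` matches what the split actually does: every checkpoint directory except the last is merged into the base model, and only the last one is attached as the trainable adapter. A toy illustration of that convention (directory names are made up):

    checkpoint_dir = ["lora_stage1", "lora_stage2", "lora_latest"]   # hypothetical directories
    checkpoints_to_merge, checkpoint_to_resume = checkpoint_dir[:-1], checkpoint_dir[-1]
    assert checkpoints_to_merge == ["lora_stage1", "lora_stage2"]    # merged into the base model
    assert checkpoint_to_resume == "lora_latest"                     # loaded as the resumable adapter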

View File

@@ -15,6 +15,7 @@ from transformers import (
 )
 from transformers.models.llama import modeling_llama as LlamaModule
 from transformers.utils.versions import require_version
+from peft import PeftModel
 from trl import AutoModelForCausalLMWithValueHead

 try:
@@ -55,9 +56,6 @@ def load_model_and_tokenizer(
     Support both training and inference.
     """

-    if (not is_trainable) and model_args.checkpoint_dir is None:
-        logger.warning("Checkpoint is not found at evaluation, load the original model.")
-        finetuning_args = FinetuningArguments(finetuning_type="none")

     config_kwargs = {
         "trust_remote_code": True,
@@ -212,8 +210,11 @@ def load_model_and_tokenizer(
     if stage == "ppo": # load reward model
         logger.info("Load reward model from {}".format(model_args.reward_model))
-        if getattr(model, "is_peft_model", False):
+        if isinstance(model.pretrained_model, PeftModel):
             model.pretrained_model.load_adapter(model_args.reward_model, "reward")
+            for name, param in model.named_parameters(): # https://github.com/huggingface/peft/issues/1090
+                if "default" in name:
+                    param.data = param.data.to(torch.float32) # trainable params should be in fp32
         assert load_valuehead_params(model, model_args), "Reward model is not correctly loaded."

     # Prepare model for inference
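
The fp32 upcast relies on PEFT's parameter naming: each adapter's weights carry the adapter name, so after `load_adapter(model_args.reward_model, "reward")` the trainable policy adapter still appears as "default" while the frozen reward adapter appears as "reward", and the loop above touches only the former. A quick way to inspect this, assuming a value-head model has already been loaded (illustrative only):

    for name, param in model.named_parameters():
        if "lora_" in name:
            # e.g. "...lora_A.default.weight" (policy adapter) vs. "...lora_A.reward.weight" (reward adapter)
            print(name, param.dtype, param.requires_grad)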

View File

@@ -1,20 +1,24 @@
 # Inspired by: https://github.com/huggingface/trl/blob/main/examples/research_projects/stack_llama_2/scripts/dpo_llama2.py

-from copy import deepcopy
 from peft import PeftModel
 from typing import TYPE_CHECKING, Optional, List
 from transformers import Seq2SeqTrainingArguments

 from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
 from llmtuner.extras.constants import IGNORE_INDEX
+from llmtuner.extras.logging import get_logger
 from llmtuner.extras.ploting import plot_loss
+from llmtuner.hparams import ModelArguments
 from llmtuner.tuner.core import generate_model_card, load_model_and_tokenizer
 from llmtuner.tuner.dpo.collator import DPODataCollatorWithPadding
 from llmtuner.tuner.dpo.trainer import CustomDPOTrainer

 if TYPE_CHECKING:
     from transformers import TrainerCallback
-    from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
+    from llmtuner.hparams import DataArguments, FinetuningArguments


+logger = get_logger(__name__)


 def run_dpo(
@@ -34,9 +38,23 @@ def run_dpo(
     )

     # Create reference model
-    ref_model = None
-    if not isinstance(model, PeftModel):
-        ref_model, _ = load_model_and_tokenizer(model_args, finetuning_args, is_trainable=False, stage="sft")
+    if finetuning_args.dpo_ref_model is not None:
+        ref_model_args_dict = model_args.to_dict()
+        ref_model_args_dict.update(dict(
+            model_name_or_path=finetuning_args.dpo_ref_model,
+            checkpoint_dir=finetuning_args.dpo_ref_model_checkpoint
+        ))
+        ref_model_args = ModelArguments(**ref_model_args_dict)
+        ref_model, _ = load_model_and_tokenizer(ref_model_args, finetuning_args, is_trainable=False, stage="sft")
+        logger.info("Created reference model from {}".format(finetuning_args.dpo_ref_model))
+    elif training_args.do_train:
+        if isinstance(model, PeftModel):
+            ref_model = None
+        else:
+            ref_model, _ = load_model_and_tokenizer(model_args, finetuning_args, is_trainable=False, stage="sft")
+            logger.info("Created reference model from the model itself.")
+    else:
+        ref_model = model

     # Update arguments
     training_args_dict = training_args.to_dict()
@@ -68,6 +86,11 @@ def run_dpo(
     # Evaluation
     if training_args.do_eval:
         metrics = trainer.evaluate(metric_key_prefix="eval")
+        if id(model) == id(ref_model): # unable to compute rewards without a reference model
+            logger.warning("Pass `dpo_ref_model` for computing rewards at evaluation.")
+            remove_keys = [key for key in metrics.keys() if "rewards" in key]
+            for key in remove_keys:
+                metrics.pop(key)
         trainer.log_metrics("eval", metrics)
         trainer.save_metrics("eval", metrics)
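
When the model serves as its own reference (`ref_model = model`), the implicit DPO rewards are not meaningful, so the reward entries are stripped before logging. A toy run of the filtering above, with made-up values in roughly the shape TRL's DPO trainer reports:

    metrics = {"eval_loss": 0.62, "eval_rewards/chosen": 0.0, "eval_rewards/rejected": 0.0}
    remove_keys = [key for key in metrics.keys() if "rewards" in key]
    for key in remove_keys:
        metrics.pop(key)
    assert metrics == {"eval_loss": 0.62}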