Former-commit-id: 511f6754026fbbf48bd481018015338a6a3ad92f
This commit is contained in:
hiyouga 2024-03-26 17:26:14 +08:00
parent c207b279f2
commit ec94e5e876
7 changed files with 36 additions and 31 deletions

View File

@@ -689,6 +689,8 @@ _register_template(
 _register_template(
     name="vanilla",
+    format_separator=EmptyFormatter(slots=["\n"]),
+    efficient_eos=True,
 )
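For context, a minimal sketch (not the project's Formatter API) of what a slot-based formatter such as EmptyFormatter(slots=["\n"]) amounts to: it emits its literal slots unchanged, so the vanilla template separates consecutive turns with a bare newline. The apply method name and behavior below are assumptions for illustration only.

from dataclasses import dataclass
from typing import List


@dataclass
class EmptyFormatter:
    # Holds literal slots and returns them untouched; no placeholders are filled.
    slots: List[str]

    def apply(self, **kwargs) -> List[str]:
        return list(self.slots)


# The "vanilla" template would then join turns with a single newline.
separator = EmptyFormatter(slots=["\n"])
print(separator.apply())  # ['\n']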

View File

@@ -1,14 +1,10 @@
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Dict, List, Tuple
+from typing import Dict, List, Sequence, Tuple
 from ..data import Role
 from ..extras.constants import CHOICES
-if TYPE_CHECKING:
-    from datasets import Dataset
 @dataclass
 class EvalTemplate:
     system: str
@@ -16,22 +12,29 @@ class EvalTemplate:
     answer: str
     prefix: str
-    def parse_example(self, example: Dict[str, str]) -> Tuple[str, str]:
+    def _parse_example(self, example: Dict[str, str]) -> Tuple[str, str]:
+        r"""
+        input: a dict with keys {"question", "A", "B", "C", "D", "answer"}
+        output: a tuple of (prompt, response)
+        """
         candidates = [self.choice.format(choice=ch, content=example[ch]) for ch in CHOICES if ch in example]
         return "".join([example["question"]] + candidates + [self.answer]), example["answer"]
     def format_example(
-        self, target_data: Dict[str, str], support_set: "Dataset", subject_name: str
+        self, target_data: Dict[str, str], support_set: Sequence[Dict[str, str]], subject_name: str
     ) -> List[Dict[str, str]]:
+        r"""
+        Converts dataset examples to messages.
+        """
         messages = []
         for k in range(len(support_set)):
-            prompt, response = self.parse_example(support_set[k])
-            messages.append({"role": Role.USER, "content": prompt})
-            messages.append({"role": Role.ASSISTANT, "content": response})
+            prompt, response = self._parse_example(support_set[k])
+            messages.append({"role": Role.USER.value, "content": prompt})
+            messages.append({"role": Role.ASSISTANT.value, "content": response})
-        prompt, response = self.parse_example(target_data)
-        messages.append({"role": Role.USER, "content": prompt})
-        messages.append({"role": Role.ASSISTANT, "content": response})
+        prompt, response = self._parse_example(target_data)
+        messages.append({"role": Role.USER.value, "content": prompt})
+        messages.append({"role": Role.ASSISTANT.value, "content": response})
         messages[0]["content"] = self.system.format(subject=subject_name) + messages[0]["content"]
         return messages
@@ -39,7 +42,7 @@ class EvalTemplate:
 eval_templates: Dict[str, "EvalTemplate"] = {}
-def register_eval_template(name: str, system: str, choice: str, answer: str, prefix: str) -> None:
+def _register_eval_template(name: str, system: str, choice: str, answer: str, prefix: str) -> None:
     eval_templates[name] = EvalTemplate(system=system, choice=choice, answer=answer, prefix=prefix)
@@ -49,7 +52,7 @@ def get_eval_template(name: str) -> "EvalTemplate":
     return eval_template
-register_eval_template(
+_register_eval_template(
     name="en",
     system="The following are multiple choice questions (with answers) about {subject}.\n\n",
     choice="\n{choice}. {content}",
@@ -58,10 +61,10 @@ register_eval_template(
 )
-register_eval_template(
+_register_eval_template(
     name="zh",
     system="以下是中国关于{subject}考试的单项选择题,请选出其中的正确答案。\n\n",
     choice="\n{choice}. {content}",
     answer="\n答案:",
-    prefix="\n",
+    prefix=" ",
 )
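For a concrete picture of the reworked evaluation template, here is a self-contained sketch: support_set may now be any sequence of dicts instead of a datasets.Dataset, and the message roles are plain strings (the .value of the Role enum). The answer and prefix strings of the English template below are assumptions that mirror the registration above; they are not shown in this commit's hunks.

from dataclasses import dataclass
from typing import Dict, List, Sequence, Tuple

CHOICES = ["A", "B", "C", "D"]


@dataclass
class EvalTemplate:
    system: str
    choice: str
    answer: str
    prefix: str

    def _parse_example(self, example: Dict[str, str]) -> Tuple[str, str]:
        # {"question", "A", "B", "C", "D", "answer"} -> (prompt, response)
        candidates = [self.choice.format(choice=ch, content=example[ch]) for ch in CHOICES if ch in example]
        return "".join([example["question"]] + candidates + [self.answer]), example["answer"]

    def format_example(
        self, target_data: Dict[str, str], support_set: Sequence[Dict[str, str]], subject_name: str
    ) -> List[Dict[str, str]]:
        # k-shot support examples first, then the target question with its answer.
        messages: List[Dict[str, str]] = []
        for support in support_set:
            prompt, response = self._parse_example(support)
            messages.append({"role": "user", "content": prompt})
            messages.append({"role": "assistant", "content": response})
        prompt, response = self._parse_example(target_data)
        messages.append({"role": "user", "content": prompt})
        messages.append({"role": "assistant", "content": response})
        messages[0]["content"] = self.system.format(subject=subject_name) + messages[0]["content"]
        return messages


en_template = EvalTemplate(
    system="The following are multiple choice questions (with answers) about {subject}.\n\n",
    choice="\n{choice}. {content}",
    answer="\nAnswer:",  # assumed, not part of the shown hunks
    prefix=" ",          # assumed, not part of the shown hunks
)
target = {"question": "2 + 2 =", "A": "3", "B": "4", "C": "5", "D": "6", "answer": "B"}
for message in en_template.format_example(target, support_set=[], subject_name="arithmetic"):
    print(message)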

View File

@@ -102,6 +102,10 @@ class RLHFArguments:
         default="sigmoid",
         metadata={"help": "The type of DPO loss to use."},
     )
+    dpo_label_smoothing: float = field(
+        default=0.0,
+        metadata={"help": "The robust DPO label smoothing parameter in cDPO that should be between 0 and 0.5."},
+    )
     dpo_ftx: float = field(
         default=0.0,
         metadata={"help": "The supervised fine-tuning loss coefficient in DPO training."},
@@ -248,6 +252,9 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA
         if self.stage == "ppo" and self.reward_model_type == "lora" and self.finetuning_type != "lora":
             raise ValueError("`reward_model_type` cannot be lora for Freeze/Full PPO training.")
+        if self.stage == "dpo" and self.dpo_loss != "sigmoid" and self.dpo_label_smoothing > 1e-6:
+            raise ValueError("`dpo_label_smoothing` is only valid for sigmoid loss function.")
         if self.use_llama_pro and self.finetuning_type == "full":
             raise ValueError("`use_llama_pro` is only valid for the Freeze or LoRA method.")

View File

@@ -173,7 +173,7 @@ def _configure_quantization(
     """
     if getattr(config, "quantization_config", None):  # ptq
         if is_deepspeed_zero3_enabled():
-            raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.")
+            raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantized models.")
         init_kwargs["device_map"] = {"": get_current_device()}
         quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None)
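As a hedged sketch of the guard this reworded message lives in: checkpoints that already carry a quantization_config (post-training quantized models) must be placed on a single device, which is why they are rejected when DeepSpeed ZeRO-3 parameter sharding is active. The function name below is hypothetical and the device lookup stands in for the project's get_current_device helper.

from typing import Any, Dict

import torch
from transformers.integrations import is_deepspeed_zero3_enabled


def reject_ptq_under_zero3(config: Any, init_kwargs: Dict[str, Any]) -> None:
    # Pre-quantized (PTQ) checkpoints ship a quantization_config and must be
    # loaded onto one device, so ZeRO-3 weight sharding cannot be used with them.
    if getattr(config, "quantization_config", None):
        if is_deepspeed_zero3_enabled():
            raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantized models.")

        # Stand-in for the project's get_current_device() helper.
        device = "cuda:0" if torch.cuda.is_available() else "cpu"
        init_kwargs["device_map"] = {"": device}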

View File

@@ -20,12 +20,9 @@ if TYPE_CHECKING:
 class CustomDPOTrainer(DPOTrainer):
     def __init__(
         self,
-        beta: float,
-        loss_type: Literal["sigmoid", "hinge", "ipo", "kto_pair"],
-        ftx_gamma: float,
         model: Union["PreTrainedModel", torch.nn.Module],
-        ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]],
         finetuning_args: "FinetuningArguments",
+        ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]] = None,
         disable_dropout: bool = True,
         **kwargs,
     ):
@@ -47,10 +44,10 @@ class CustomDPOTrainer(DPOTrainer):
         self._peft_has_been_casted_to_bf16 = False
         self.ref_model = ref_model
-        self.beta = beta
-        self.label_smoothing = 0
-        self.loss_type = loss_type
-        self.ftx_gamma = ftx_gamma
+        self.beta = finetuning_args.dpo_beta
+        self.label_smoothing = finetuning_args.dpo_label_smoothing
+        self.loss_type = finetuning_args.dpo_loss
+        self.ftx_gamma = finetuning_args.dpo_ftx
         self._stored_metrics = defaultdict(lambda: defaultdict(list))
         Trainer.__init__(self, model=model, **kwargs)

View File

@@ -45,13 +45,10 @@ def run_dpo(
     # Initialize our Trainer
     trainer = CustomDPOTrainer(
-        beta=finetuning_args.dpo_beta,
-        loss_type=finetuning_args.dpo_loss,
-        ftx_gamma=finetuning_args.dpo_ftx,
-        finetuning_args=finetuning_args,
         model=model,
         ref_model=ref_model,
         args=training_args,
+        finetuning_args=finetuning_args,
         tokenizer=tokenizer,
         data_collator=data_collator,
         callbacks=callbacks,
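Taken together, the last two hunks move the DPO hyperparameters out of the trainer's signature: the workflow now passes only finetuning_args, and the trainer reads beta, label smoothing, loss type, and the SFT mixing coefficient from it. A minimal sketch of the new wiring, with stand-in classes instead of the real argument and trainer classes:

from dataclasses import dataclass


@dataclass
class FinetuningArguments:  # stand-in for the real hyperparameter class
    dpo_beta: float = 0.1
    dpo_loss: str = "sigmoid"
    dpo_label_smoothing: float = 0.0
    dpo_ftx: float = 0.0


class CustomDPOTrainer:  # sketch of the constructor change only
    def __init__(self, model, finetuning_args: FinetuningArguments, ref_model=None, **kwargs):
        # Hyperparameters are pulled from finetuning_args instead of being
        # passed as individual keyword arguments by the workflow.
        self.ref_model = ref_model
        self.beta = finetuning_args.dpo_beta
        self.label_smoothing = finetuning_args.dpo_label_smoothing
        self.loss_type = finetuning_args.dpo_loss
        self.ftx_gamma = finetuning_args.dpo_ftx


trainer = CustomDPOTrainer(model=object(), finetuning_args=FinetuningArguments(), ref_model=None)
print(trainer.beta, trainer.loss_type, trainer.label_smoothing, trainer.ftx_gamma)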

View File

@@ -5,7 +5,6 @@ from transformers import Trainer
 from transformers.optimization import get_scheduler
 from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
 from transformers.trainer_pt_utils import get_parameter_names
-from transformers.utils.versions import require_version
 from ..extras.logging import get_logger
 from ..extras.packages import is_galore_available