diff --git a/src/llmtuner/data/template.py b/src/llmtuner/data/template.py
index af80272b..cd645cf7 100644
--- a/src/llmtuner/data/template.py
+++ b/src/llmtuner/data/template.py
@@ -689,6 +689,8 @@ _register_template(
 
 _register_template(
     name="vanilla",
+    format_separator=EmptyFormatter(slots=["\n"]),
+    efficient_eos=True,
 )
 
 
diff --git a/src/llmtuner/eval/template.py b/src/llmtuner/eval/template.py
index b17f7084..a4a6ef0e 100644
--- a/src/llmtuner/eval/template.py
+++ b/src/llmtuner/eval/template.py
@@ -1,14 +1,10 @@
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Dict, List, Tuple
+from typing import Dict, List, Sequence, Tuple
 
 from ..data import Role
 from ..extras.constants import CHOICES
 
 
-if TYPE_CHECKING:
-    from datasets import Dataset
-
-
 @dataclass
 class EvalTemplate:
     system: str
@@ -16,22 +12,29 @@ class EvalTemplate:
     answer: str
     prefix: str
 
-    def parse_example(self, example: Dict[str, str]) -> Tuple[str, str]:
+    def _parse_example(self, example: Dict[str, str]) -> Tuple[str, str]:
+        r"""
+        input: a dict with keys {"question", "A", "B", "C", "D", "answer"}
+        output: a tuple of (prompt, response)
+        """
         candidates = [self.choice.format(choice=ch, content=example[ch]) for ch in CHOICES if ch in example]
         return "".join([example["question"]] + candidates + [self.answer]), example["answer"]
 
     def format_example(
-        self, target_data: Dict[str, str], support_set: "Dataset", subject_name: str
+        self, target_data: Dict[str, str], support_set: Sequence[Dict[str, str]], subject_name: str
     ) -> List[Dict[str, str]]:
+        r"""
+        Converts dataset examples to messages.
+        """
         messages = []
         for k in range(len(support_set)):
-            prompt, response = self.parse_example(support_set[k])
-            messages.append({"role": Role.USER, "content": prompt})
-            messages.append({"role": Role.ASSISTANT, "content": response})
+            prompt, response = self._parse_example(support_set[k])
+            messages.append({"role": Role.USER.value, "content": prompt})
+            messages.append({"role": Role.ASSISTANT.value, "content": response})
 
-        prompt, response = self.parse_example(target_data)
-        messages.append({"role": Role.USER, "content": prompt})
-        messages.append({"role": Role.ASSISTANT, "content": response})
+        prompt, response = self._parse_example(target_data)
+        messages.append({"role": Role.USER.value, "content": prompt})
+        messages.append({"role": Role.ASSISTANT.value, "content": response})
 
         messages[0]["content"] = self.system.format(subject=subject_name) + messages[0]["content"]
         return messages
@@ -39,7 +42,7 @@ class EvalTemplate:
 eval_templates: Dict[str, "EvalTemplate"] = {}
 
 
-def register_eval_template(name: str, system: str, choice: str, answer: str, prefix: str) -> None:
+def _register_eval_template(name: str, system: str, choice: str, answer: str, prefix: str) -> None:
     eval_templates[name] = EvalTemplate(system=system, choice=choice, answer=answer, prefix=prefix)
 
 
@@ -49,7 +52,7 @@ def get_eval_template(name: str) -> "EvalTemplate":
     return eval_template
 
 
-register_eval_template(
+_register_eval_template(
    name="en",
    system="The following are multiple choice questions (with answers) about {subject}.\n\n",
    choice="\n{choice}. {content}",
@@ -58,10 +61,10 @@ register_eval_template(
 )
 
 
-register_eval_template(
+_register_eval_template(
    name="zh",
    system="以下是中国关于{subject}考试的单项选择题,请选出其中的正确答案。\n\n",
    choice="\n{choice}. {content}",
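
For orientation, here is a minimal usage sketch of the refactored evaluation template. It is not part of the patch and assumes only what the diff above shows (`get_eval_template`, the `en` template, and the new `Sequence[Dict[str, str]]` support set); the question data is made up. With an empty support set, `format_example` yields a single user/assistant pair and prepends the subject-specific system prompt to the first message:

```python
# Hypothetical usage sketch (not from the patch): zero-shot prompt construction
# with the refactored EvalTemplate API, which now accepts plain dicts instead of
# a datasets.Dataset for the few-shot support set.
from llmtuner.eval.template import get_eval_template

template = get_eval_template("en")
target = {
    "question": "Which gas makes up most of Earth's atmosphere?",
    "A": "Oxygen",
    "B": "Nitrogen",
    "C": "Argon",
    "D": "Carbon dioxide",
    "answer": "B",
}

# An empty support set means zero-shot; pass k in-domain examples for k-shot.
messages = template.format_example(target, support_set=[], subject_name="earth science")
for message in messages:
    print(message["role"], repr(message["content"]))
```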
{content}", answer="\n答案:", - prefix="\n", + prefix=" ", ) diff --git a/src/llmtuner/hparams/finetuning_args.py b/src/llmtuner/hparams/finetuning_args.py index 8188fdcc..6fe951f1 100644 --- a/src/llmtuner/hparams/finetuning_args.py +++ b/src/llmtuner/hparams/finetuning_args.py @@ -102,6 +102,10 @@ class RLHFArguments: default="sigmoid", metadata={"help": "The type of DPO loss to use."}, ) + dpo_label_smoothing = field( + default=0.0, + metadata={"help": "The robust DPO label smoothing parameter in cDPO that should be between 0 and 0.5."}, + ) dpo_ftx: float = field( default=0.0, metadata={"help": "The supervised fine-tuning loss coefficient in DPO training."}, @@ -248,6 +252,9 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA if self.stage == "ppo" and self.reward_model_type == "lora" and self.finetuning_type != "lora": raise ValueError("`reward_model_type` cannot be lora for Freeze/Full PPO training.") + if self.stage == "dpo" and self.dpo_loss != "sigmoid" and self.dpo_label_smoothing > 1e-6: + raise ValueError("`dpo_label_smoothing` is only valid for sigmoid loss function.") + if self.use_llama_pro and self.finetuning_type == "full": raise ValueError("`use_llama_pro` is only valid for the Freeze or LoRA method.") diff --git a/src/llmtuner/model/patcher.py b/src/llmtuner/model/patcher.py index 7fd5f573..cb55f5ed 100644 --- a/src/llmtuner/model/patcher.py +++ b/src/llmtuner/model/patcher.py @@ -173,7 +173,7 @@ def _configure_quantization( """ if getattr(config, "quantization_config", None): # ptq if is_deepspeed_zero3_enabled(): - raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.") + raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantized models.") init_kwargs["device_map"] = {"": get_current_device()} quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None) diff --git a/src/llmtuner/train/dpo/trainer.py b/src/llmtuner/train/dpo/trainer.py index 957e79ff..ec1cb94f 100644 --- a/src/llmtuner/train/dpo/trainer.py +++ b/src/llmtuner/train/dpo/trainer.py @@ -20,12 +20,9 @@ if TYPE_CHECKING: class CustomDPOTrainer(DPOTrainer): def __init__( self, - beta: float, - loss_type: Literal["sigmoid", "hinge", "ipo", "kto_pair"], - ftx_gamma: float, model: Union["PreTrainedModel", torch.nn.Module], + ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]], finetuning_args: "FinetuningArguments", - ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]] = None, disable_dropout: bool = True, **kwargs, ): @@ -47,10 +44,10 @@ class CustomDPOTrainer(DPOTrainer): self._peft_has_been_casted_to_bf16 = False self.ref_model = ref_model - self.beta = beta - self.label_smoothing = 0 - self.loss_type = loss_type - self.ftx_gamma = ftx_gamma + self.beta = finetuning_args.dpo_beta + self.label_smoothing = finetuning_args.dpo_label_smoothing + self.loss_type = finetuning_args.dpo_loss + self.ftx_gamma = finetuning_args.dpo_ftx self._stored_metrics = defaultdict(lambda: defaultdict(list)) Trainer.__init__(self, model=model, **kwargs) diff --git a/src/llmtuner/train/dpo/workflow.py b/src/llmtuner/train/dpo/workflow.py index 26424ee1..7014177a 100644 --- a/src/llmtuner/train/dpo/workflow.py +++ b/src/llmtuner/train/dpo/workflow.py @@ -45,13 +45,10 @@ def run_dpo( # Initialize our Trainer trainer = CustomDPOTrainer( - beta=finetuning_args.dpo_beta, - loss_type=finetuning_args.dpo_loss, - ftx_gamma=finetuning_args.dpo_ftx, - finetuning_args=finetuning_args, model=model, ref_model=ref_model, args=training_args, + 
diff --git a/src/llmtuner/train/dpo/workflow.py b/src/llmtuner/train/dpo/workflow.py
index 26424ee1..7014177a 100644
--- a/src/llmtuner/train/dpo/workflow.py
+++ b/src/llmtuner/train/dpo/workflow.py
@@ -45,13 +45,10 @@ def run_dpo(
 
     # Initialize our Trainer
     trainer = CustomDPOTrainer(
-        beta=finetuning_args.dpo_beta,
-        loss_type=finetuning_args.dpo_loss,
-        ftx_gamma=finetuning_args.dpo_ftx,
-        finetuning_args=finetuning_args,
         model=model,
         ref_model=ref_model,
         args=training_args,
+        finetuning_args=finetuning_args,
         tokenizer=tokenizer,
         data_collator=data_collator,
         callbacks=callbacks,
diff --git a/src/llmtuner/train/utils.py b/src/llmtuner/train/utils.py
index 0093f4a4..49c42d4e 100644
--- a/src/llmtuner/train/utils.py
+++ b/src/llmtuner/train/utils.py
@@ -5,7 +5,6 @@ from transformers import Trainer
 from transformers.optimization import get_scheduler
 from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
 from transformers.trainer_pt_utils import get_parameter_names
-from transformers.utils.versions import require_version
 
 from ..extras.logging import get_logger
 from ..extras.packages import is_galore_available
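
Taken together, the DPO hyperparameters now travel through `FinetuningArguments` instead of being threaded into the trainer one by one, and the new `__post_init__` check rejects label smoothing for non-sigmoid losses. A small illustrative sketch (hypothetical values; it assumes `FinetuningArguments` can be constructed standalone with its defaults):

```python
# Illustrative sketch (not from the patch): the DPO knobs now live on
# FinetuningArguments and are read inside CustomDPOTrainer.__init__.
from llmtuner.hparams import FinetuningArguments

args = FinetuningArguments(
    stage="dpo",
    dpo_beta=0.1,
    dpo_loss="sigmoid",
    dpo_label_smoothing=0.1,  # new cDPO parameter introduced by this patch
    dpo_ftx=0.0,
)

# The consistency check added in __post_init__ rejects label smoothing
# combined with a non-sigmoid DPO loss.
try:
    FinetuningArguments(stage="dpo", dpo_loss="ipo", dpo_label_smoothing=0.1)
except ValueError as err:
    print(err)  # `dpo_label_smoothing` is only valid for sigmoid loss function.
```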