Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-08-03 20:22:49 +08:00)
parent c207b279f2
commit ec94e5e876
@@ -689,6 +689,8 @@ _register_template(

 _register_template(
     name="vanilla",
+    format_separator=EmptyFormatter(slots=["\n"]),
+    efficient_eos=True,
 )


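For context, here is a minimal, hypothetical sketch of what a slot-based separator such as format_separator=EmptyFormatter(slots=["\n"]) amounts to when a template is rendered. The class below is illustrative only, not the repository's EmptyFormatter implementation:

from dataclasses import dataclass, field
from typing import List


@dataclass
class SlotSeparator:
    # Illustrative stand-in for a formatter constructed with slots=["\n"].
    slots: List[str] = field(default_factory=lambda: ["\n"])

    def apply(self) -> List[str]:
        # An "empty" formatter has no placeholders to fill; it simply emits its slots.
        return list(self.slots)


def join_turns(turns: List[str], separator: SlotSeparator) -> str:
    # The rendered separator is placed between consecutive turns of a conversation.
    sep = "".join(separator.apply())
    return sep.join(turns)


print(join_turns(["Hello", "Hi there", "How are you?"], SlotSeparator()))
# -> "Hello\nHi there\nHow are you?"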
@@ -1,14 +1,10 @@
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Dict, List, Tuple
+from typing import Dict, List, Sequence, Tuple

 from ..data import Role
 from ..extras.constants import CHOICES


-if TYPE_CHECKING:
-    from datasets import Dataset
-
-
 @dataclass
 class EvalTemplate:
     system: str
@@ -16,22 +12,29 @@ class EvalTemplate:
     answer: str
     prefix: str

-    def parse_example(self, example: Dict[str, str]) -> Tuple[str, str]:
+    def _parse_example(self, example: Dict[str, str]) -> Tuple[str, str]:
+        r"""
+        input: a dict with keys {"question", "A", "B", "C", "D", "answer"}
+        output: a tuple of (prompt, response)
+        """
         candidates = [self.choice.format(choice=ch, content=example[ch]) for ch in CHOICES if ch in example]
         return "".join([example["question"]] + candidates + [self.answer]), example["answer"]

     def format_example(
-        self, target_data: Dict[str, str], support_set: "Dataset", subject_name: str
+        self, target_data: Dict[str, str], support_set: Sequence[Dict[str, str]], subject_name: str
     ) -> List[Dict[str, str]]:
+        r"""
+        Converts dataset examples to messages.
+        """
         messages = []
         for k in range(len(support_set)):
-            prompt, response = self.parse_example(support_set[k])
-            messages.append({"role": Role.USER, "content": prompt})
-            messages.append({"role": Role.ASSISTANT, "content": response})
+            prompt, response = self._parse_example(support_set[k])
+            messages.append({"role": Role.USER.value, "content": prompt})
+            messages.append({"role": Role.ASSISTANT.value, "content": response})

-        prompt, response = self.parse_example(target_data)
-        messages.append({"role": Role.USER, "content": prompt})
-        messages.append({"role": Role.ASSISTANT, "content": response})
+        prompt, response = self._parse_example(target_data)
+        messages.append({"role": Role.USER.value, "content": prompt})
+        messages.append({"role": Role.ASSISTANT.value, "content": response})
         messages[0]["content"] = self.system.format(subject=subject_name) + messages[0]["content"]
         return messages

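To make the message construction concrete, here is a self-contained sketch that mirrors the logic above on a toy MMLU-style record. It inlines CHOICES, uses plain "user"/"assistant" strings in place of the Role enum values, and borrows the "en" system/choice strings registered later in this file; the answer/prefix values are illustrative assumptions:

from dataclasses import dataclass
from typing import Dict, List, Sequence, Tuple

CHOICES = ["A", "B", "C", "D"]  # stand-in for ..extras.constants.CHOICES


@dataclass
class EvalTemplate:
    system: str
    choice: str
    answer: str
    prefix: str

    def _parse_example(self, example: Dict[str, str]) -> Tuple[str, str]:
        candidates = [self.choice.format(choice=ch, content=example[ch]) for ch in CHOICES if ch in example]
        return "".join([example["question"]] + candidates + [self.answer]), example["answer"]

    def format_example(
        self, target_data: Dict[str, str], support_set: Sequence[Dict[str, str]], subject_name: str
    ) -> List[Dict[str, str]]:
        messages = []
        for k in range(len(support_set)):
            prompt, response = self._parse_example(support_set[k])
            messages.append({"role": "user", "content": prompt})         # Role.USER.value
            messages.append({"role": "assistant", "content": response})  # Role.ASSISTANT.value
        prompt, response = self._parse_example(target_data)
        messages.append({"role": "user", "content": prompt})
        messages.append({"role": "assistant", "content": response})
        # The subject-specific system text is prepended to the very first user turn.
        messages[0]["content"] = self.system.format(subject=subject_name) + messages[0]["content"]
        return messages


template = EvalTemplate(
    system="The following are multiple choice questions (with answers) about {subject}.\n\n",
    choice="\n{choice}. {content}",
    answer="\nAnswer:",  # illustrative
    prefix=" ",          # illustrative
)
shot = {"question": "2 + 2 =", "A": "3", "B": "4", "C": "5", "D": "6", "answer": "B"}
target = {"question": "3 * 3 =", "A": "6", "B": "8", "C": "9", "D": "12", "answer": "C"}
for msg in template.format_example(target, support_set=[shot], subject_name="elementary mathematics"):
    print(msg["role"], repr(msg["content"]))

Storing Role.USER.value rather than the enum member keeps each message a plain string-to-string dict, which is what downstream tokenization and serialization expect.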
@@ -39,7 +42,7 @@ class EvalTemplate:
 eval_templates: Dict[str, "EvalTemplate"] = {}


-def register_eval_template(name: str, system: str, choice: str, answer: str, prefix: str) -> None:
+def _register_eval_template(name: str, system: str, choice: str, answer: str, prefix: str) -> None:
     eval_templates[name] = EvalTemplate(system=system, choice=choice, answer=answer, prefix=prefix)


@@ -49,7 +52,7 @@ def get_eval_template(name: str) -> "EvalTemplate":
     return eval_template


-register_eval_template(
+_register_eval_template(
     name="en",
     system="The following are multiple choice questions (with answers) about {subject}.\n\n",
     choice="\n{choice}. {content}",
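The rename to a leading-underscore helper signals that registration is module-private. A quick, self-contained sketch of the registry pattern it implements (the lookup behaviour of get_eval_template is an assumption, since its body is not shown in this diff):

from typing import Dict


class EvalTemplateStub:
    # Minimal stand-in for EvalTemplate, enough to show the registry mechanics.
    def __init__(self, system: str) -> None:
        self.system = system


eval_templates: Dict[str, EvalTemplateStub] = {}


def _register_eval_template(name: str, system: str) -> None:
    # Private by convention: templates are registered here at import time,
    # not by callers of the package.
    eval_templates[name] = EvalTemplateStub(system=system)


def get_eval_template(name: str) -> EvalTemplateStub:
    template = eval_templates.get(name, None)
    if template is None:  # assumed error handling
        raise ValueError(f"Template {name} does not exist.")
    return template


_register_eval_template(
    name="en",
    system="The following are multiple choice questions (with answers) about {subject}.\n\n",
)
print(get_eval_template("en").system.format(subject="physics"))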
@@ -58,10 +61,10 @@ register_eval_template(
 )


-register_eval_template(
+_register_eval_template(
     name="zh",
     system="以下是中国关于{subject}考试的单项选择题,请选出其中的正确答案。\n\n",
     choice="\n{choice}. {content}",
     answer="\n答案:",
-    prefix="\n",
+    prefix=" ",
 )
@@ -102,6 +102,10 @@ class RLHFArguments:
         default="sigmoid",
         metadata={"help": "The type of DPO loss to use."},
     )
+    dpo_label_smoothing: float = field(
+        default=0.0,
+        metadata={"help": "The robust DPO label smoothing parameter in cDPO that should be between 0 and 0.5."},
+    )
     dpo_ftx: float = field(
         default=0.0,
         metadata={"help": "The supervised fine-tuning loss coefficient in DPO training."},
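The new dpo_label_smoothing option corresponds to the conservative DPO (cDPO) variant of the sigmoid loss, where a small probability epsilon accounts for noisy preference labels. A minimal sketch of how the smoothing enters the loss, modelled on the usual TRL-style formulation rather than copied from this repository:

import torch
import torch.nn.functional as F


def dpo_sigmoid_loss(
    policy_chosen_logps: torch.Tensor,
    policy_rejected_logps: torch.Tensor,
    ref_chosen_logps: torch.Tensor,
    ref_rejected_logps: torch.Tensor,
    beta: float = 0.1,
    label_smoothing: float = 0.0,  # 0.0 recovers plain DPO; cDPO uses 0 < eps < 0.5
) -> torch.Tensor:
    # Log-ratio margin between the chosen and rejected completions.
    logits = (policy_chosen_logps - ref_chosen_logps) - (policy_rejected_logps - ref_rejected_logps)
    # With probability (1 - eps) the preference label is correct, with probability eps it is flipped.
    return (
        -F.logsigmoid(beta * logits) * (1.0 - label_smoothing)
        - F.logsigmoid(-beta * logits) * label_smoothing
    )


# Example: a batch of two preference pairs.
loss = dpo_sigmoid_loss(
    policy_chosen_logps=torch.tensor([-12.0, -9.5]),
    policy_rejected_logps=torch.tensor([-14.0, -9.0]),
    ref_chosen_logps=torch.tensor([-12.5, -9.6]),
    ref_rejected_logps=torch.tensor([-13.0, -9.4]),
    label_smoothing=0.1,
).mean()
print(loss)

The validation added in the next hunk rejects a non-zero smoothing value when dpo_loss is not "sigmoid", since the other loss types do not use it.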
@@ -248,6 +252,9 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreArguments):
         if self.stage == "ppo" and self.reward_model_type == "lora" and self.finetuning_type != "lora":
             raise ValueError("`reward_model_type` cannot be lora for Freeze/Full PPO training.")

+        if self.stage == "dpo" and self.dpo_loss != "sigmoid" and self.dpo_label_smoothing > 1e-6:
+            raise ValueError("`dpo_label_smoothing` is only valid for sigmoid loss function.")
+
         if self.use_llama_pro and self.finetuning_type == "full":
             raise ValueError("`use_llama_pro` is only valid for the Freeze or LoRA method.")

@@ -173,7 +173,7 @@ def _configure_quantization(
     """
     if getattr(config, "quantization_config", None):  # ptq
         if is_deepspeed_zero3_enabled():
-            raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.")
+            raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantized models.")

         init_kwargs["device_map"] = {"": get_current_device()}
         quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None)
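For orientation, a rough standalone sketch of the guard this reworded error message belongs to: when a checkpoint's config already carries a quantization_config (a post-training-quantized model), ZeRO-3 parameter sharding cannot be applied to it, so loading fails fast and the model is instead pinned to a single device. Everything except the names visible in the hunk is an assumption:

from typing import Any, Dict, Optional


def configure_ptq_loading(
    config: Any,
    init_kwargs: Dict[str, Any],
    zero3_enabled: bool,
    current_device: str = "cuda:0",  # illustrative; the real code queries the current accelerator
) -> Optional[Dict[str, Any]]:
    quantization_config: Optional[Dict[str, Any]] = getattr(config, "quantization_config", None)
    if quantization_config is None:
        return None  # not a pre-quantized (PTQ) checkpoint
    if zero3_enabled:
        # ZeRO-3 shards parameters across ranks, which does not work with already-quantized weights.
        raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantized models.")
    # Keep the whole quantized model on one device instead of letting it be dispatched automatically.
    init_kwargs["device_map"] = {"": current_device}
    return quantization_config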
@@ -20,12 +20,9 @@ if TYPE_CHECKING:
 class CustomDPOTrainer(DPOTrainer):
     def __init__(
         self,
-        beta: float,
-        loss_type: Literal["sigmoid", "hinge", "ipo", "kto_pair"],
-        ftx_gamma: float,
         model: Union["PreTrainedModel", torch.nn.Module],
+        ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]],
         finetuning_args: "FinetuningArguments",
-        ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]] = None,
         disable_dropout: bool = True,
         **kwargs,
     ):
@@ -47,10 +44,10 @@ class CustomDPOTrainer(DPOTrainer):
         self._peft_has_been_casted_to_bf16 = False

         self.ref_model = ref_model
-        self.beta = beta
-        self.label_smoothing = 0
-        self.loss_type = loss_type
-        self.ftx_gamma = ftx_gamma
+        self.beta = finetuning_args.dpo_beta
+        self.label_smoothing = finetuning_args.dpo_label_smoothing
+        self.loss_type = finetuning_args.dpo_loss
+        self.ftx_gamma = finetuning_args.dpo_ftx
         self._stored_metrics = defaultdict(lambda: defaultdict(list))

         Trainer.__init__(self, model=model, **kwargs)
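The refactor above replaces a handful of scalar constructor arguments with a single hyperparameter object that the trainer reads in __init__. A small self-contained sketch of the same pattern; the class names here are illustrative, not the repository's classes:

from dataclasses import dataclass


@dataclass
class DPOHyperParams:
    # Mirrors the FinetuningArguments fields the trainer consumes.
    dpo_beta: float = 0.1
    dpo_loss: str = "sigmoid"
    dpo_label_smoothing: float = 0.0
    dpo_ftx: float = 0.0


class ToyDPOTrainer:
    def __init__(self, hparams: DPOHyperParams) -> None:
        # One object travels through the call chain; the trainer unpacks what it needs.
        self.beta = hparams.dpo_beta
        self.loss_type = hparams.dpo_loss
        self.label_smoothing = hparams.dpo_label_smoothing
        self.ftx_gamma = hparams.dpo_ftx


trainer = ToyDPOTrainer(DPOHyperParams(dpo_beta=0.2, dpo_label_smoothing=0.1))
print(trainer.beta, trainer.label_smoothing)

Adding a new DPO-related option (such as dpo_label_smoothing above) then only touches the argument dataclass and the trainer, not every call site in the workflow, as the next hunk shows.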
@@ -45,13 +45,10 @@ def run_dpo(

     # Initialize our Trainer
     trainer = CustomDPOTrainer(
-        beta=finetuning_args.dpo_beta,
-        loss_type=finetuning_args.dpo_loss,
-        ftx_gamma=finetuning_args.dpo_ftx,
-        finetuning_args=finetuning_args,
         model=model,
         ref_model=ref_model,
         args=training_args,
+        finetuning_args=finetuning_args,
         tokenizer=tokenizer,
         data_collator=data_collator,
         callbacks=callbacks,
@@ -5,7 +5,6 @@ from transformers import Trainer
 from transformers.optimization import get_scheduler
 from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
 from transformers.trainer_pt_utils import get_parameter_names
-from transformers.utils.versions import require_version

 from ..extras.logging import get_logger
 from ..extras.packages import is_galore_available
|
Loading…
x
Reference in New Issue
Block a user