Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-12-14 19:06:26 +08:00)
@@ -114,14 +114,18 @@ class LoraArguments:
 @dataclass
 class RLHFArguments:
     r"""
-    Arguments pertaining to the PPO and DPO training.
+    Arguments pertaining to the PPO, DPO and KTO training.
     """
 
-    dpo_beta: float = field(
+    pref_beta: float = field(
         default=0.1,
-        metadata={"help": "The beta parameter for the DPO loss."},
+        metadata={"help": "The beta parameter in the preference loss."},
     )
-    dpo_loss: Literal["sigmoid", "hinge", "ipo", "kto_pair"] = field(
+    pref_ftx: float = field(
+        default=0.0,
+        metadata={"help": "The supervised fine-tuning loss coefficient in DPO training."},
+    )
+    pref_loss: Literal["sigmoid", "hinge", "ipo", "kto_pair", "orpo", "simpo"] = field(
         default="sigmoid",
         metadata={"help": "The type of DPO loss to use."},
     )
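This hunk folds the DPO-specific knobs into generic preference arguments: dpo_beta becomes pref_beta, the SFT mixing coefficient (dpo_ftx, removed in the next hunk) moves to pref_ftx, and pref_loss picks up the reference-free "orpo" and "simpo" objectives. A minimal stand-in sketch, not the project's actual class, with field names taken from the hunk above:

    from dataclasses import dataclass
    from typing import Literal

    # Stand-in for illustration only: mirrors the renamed preference fields.
    @dataclass
    class PreferenceArgsSketch:
        pref_beta: float = 0.1   # temperature of the preference loss (was dpo_beta)
        pref_ftx: float = 0.0    # SFT loss coefficient mixed into DPO training (was dpo_ftx)
        pref_loss: Literal["sigmoid", "hinge", "ipo", "kto_pair", "orpo", "simpo"] = "sigmoid"

    # Hypothetical usage: SimPO is now selected through pref_loss
    # rather than a dedicated argument set.
    args = PreferenceArgsSketch(pref_loss="simpo")
    print(args)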
@@ -129,14 +133,6 @@ class RLHFArguments:
         default=0.0,
         metadata={"help": "The robust DPO label smoothing parameter in cDPO that should be between 0 and 0.5."},
     )
-    dpo_ftx: float = field(
-        default=0.0,
-        metadata={"help": "The supervised fine-tuning loss coefficient in DPO training."},
-    )
-    kto_beta: float = field(
-        default=0.1,
-        metadata={"help": "The beta parameter for the KTO loss."},
-    )
     kto_chosen_weight: float = field(
         default=1.0,
         metadata={"help": "The weight factor of the desirable losses in KTO training."},
@@ -145,13 +141,9 @@ class RLHFArguments:
         default=1.0,
         metadata={"help": "The weight factor of the undesirable losses in KTO training."},
     )
-    kto_ftx: float = field(
-        default=0.0,
-        metadata={"help": "The supervised fine-tuning loss coefficient in KTO training."},
-    )
-    orpo_beta: float = field(
-        default=0.1,
-        metadata={"help": "The beta (lambda) parameter in the ORPO loss representing the weight of the SFT loss."},
-    )
+    simpo_gamma: float = field(
+        default=0.5,
+        metadata={"help": "The target reward margin term in SimPO loss."},
+    )
     ppo_buffer_size: int = field(
         default=1,
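The new simpo_gamma field is the target reward margin of the SimPO objective added to pref_loss above. A scalar sketch of the published SimPO loss, illustrative only and not the project's implementation, with variable names following the SimPO paper, showing where beta and gamma enter:

    import math

    def simpo_loss_sketch(chosen_logps: float, rejected_logps: float,
                          chosen_len: int, rejected_len: int,
                          beta: float, gamma: float = 0.5) -> float:
        # Length-normalized implicit rewards, scaled by beta.
        chosen_reward = beta * chosen_logps / chosen_len
        rejected_reward = beta * rejected_logps / rejected_len
        # Penalize pairs whose margin falls short of the target margin gamma.
        margin = chosen_reward - rejected_reward - gamma
        return -math.log(1.0 / (1.0 + math.exp(-margin)))  # -log(sigmoid(margin))

    # Example: a clearly preferred completion yields a small loss.
    print(simpo_loss_sketch(chosen_logps=-20.0, rejected_logps=-45.0,
                            chosen_len=10, rejected_len=12, beta=2.0))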
@@ -307,7 +299,7 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA
         default=False,
         metadata={"help": "Whether or not to train model in purely bf16 precision (without AMP)."},
     )
-    stage: Literal["pt", "sft", "rm", "ppo", "dpo", "kto", "orpo"] = field(
+    stage: Literal["pt", "sft", "rm", "ppo", "dpo", "kto"] = field(
         default="sft",
         metadata={"help": "Which stage will be performed in training."},
     )
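With ORPO and SimPO handled as preference losses, the dedicated "orpo" stage is dropped from the stage literal. A hypothetical migration helper, not part of LLaMA-Factory, illustrating how an old-style configuration maps onto the new arguments:

    # Hypothetical helper for illustration: the "orpo" stage becomes the "dpo"
    # stage with pref_loss="orpo"; SimPO likewise runs under the "dpo" stage.
    def migrate_stage_sketch(stage: str, pref_loss: str = "sigmoid") -> tuple:
        if stage == "orpo":
            return "dpo", "orpo"
        return stage, pref_loss

    print(migrate_stage_sketch("orpo"))          # ('dpo', 'orpo')
    print(migrate_stage_sketch("dpo", "simpo"))  # ('dpo', 'simpo')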
@@ -341,20 +333,22 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA
         assert self.ref_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
         assert self.reward_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
 
+        self.use_ref_model = self.pref_loss not in ["orpo", "simpo"]
+
         if self.stage == "ppo" and self.reward_model is None:
             raise ValueError("`reward_model` is necessary for PPO training.")
 
         if self.stage == "ppo" and self.reward_model_type == "lora" and self.finetuning_type != "lora":
             raise ValueError("`reward_model_type` cannot be lora for Freeze/Full PPO training.")
 
-        if self.stage == "dpo" and self.dpo_loss != "sigmoid" and self.dpo_label_smoothing > 1e-6:
+        if self.stage == "dpo" and self.pref_loss != "sigmoid" and self.dpo_label_smoothing > 1e-6:
             raise ValueError("`dpo_label_smoothing` is only valid for sigmoid loss function.")
 
         if self.use_llama_pro and self.finetuning_type == "full":
             raise ValueError("`use_llama_pro` is only valid for the Freeze or LoRA training.")
 
-        if self.use_galore and self.finetuning_type == "lora":
-            raise ValueError("Cannot use LoRA with GaLore together.")
+        if self.finetuning_type == "lora" and (self.use_galore or self.use_badam):
+            raise ValueError("Cannot use LoRA with GaLore or BAdam together.")
 
         if self.use_galore and self.use_badam:
             raise ValueError("Cannot use GaLore with BAdam together.")
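The derived use_ref_model flag records that ORPO and SimPO are reference-free, so no frozen reference model is needed for them. A small sketch of how a trainer setup could consume this flag; create_model here is a hypothetical factory, not a LLaMA-Factory API:

    # Sketch only: ORPO and SimPO skip the frozen reference model entirely.
    def setup_models_sketch(pref_loss: str, create_model):
        use_ref_model = pref_loss not in ["orpo", "simpo"]
        policy = create_model(trainable=True)
        ref_model = create_model(trainable=False) if use_ref_model else None
        return policy, ref_model

    policy, ref_model = setup_models_sketch("simpo", create_model=lambda trainable: object())
    assert ref_model is None  # reference-free losses need only the policy model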