improve KTO impl., replace datasets

Former-commit-id: c450ee87a3
This commit is contained in:
hiyouga
2024-05-18 03:44:56 +08:00
parent 97469892c3
commit 13d7b48efe
66 changed files with 46444 additions and 28125 deletions

View File

@@ -137,21 +137,21 @@ class RLHFArguments:
default=0.1,
metadata={"help": "The beta parameter for the KTO loss."},
)
kto_chosen_weight: float = field(
default=1.0,
metadata={"help": "The weight factor of the desirable losses in KTO training."},
)
kto_rejected_weight: float = field(
default=1.0,
metadata={"help": "The weight factor of the undesirable losses in KTO training."},
)
kto_ftx: float = field(
default=0.0,
metadata={"help": "The supervised fine-tuning loss coefficient in KTO training."},
)
kto_desirable_weight: float = field(
default=1.0,
metadata={"help": "The desirable weight for the KTO loss."},
)
kto_undesirable_weight: float = field(
default=1.0,
metadata={"help": "The undesirable weight for the KTO loss."},
)
orpo_beta: float = field(
default=0.1,
metadata={"help": "The beta (lambda) parameter in ORPO loss representing the weight of the SFT loss."},
metadata={"help": "The beta (lambda) parameter in the ORPO loss representing the weight of the SFT loss."},
)
ppo_buffer_size: int = field(
default=1,
@@ -307,7 +307,7 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA
default=False,
metadata={"help": "Whether or not to train model in purely bf16 precision (without AMP)."},
)
stage: Literal["pt", "sft", "rm", "ppo", "dpo", "orpo", "kto"] = field(
stage: Literal["pt", "sft", "rm", "ppo", "dpo", "kto", "orpo"] = field(
default="sft",
metadata={"help": "Which stage will be performed in training."},
)