mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-12-16 03:40:34 +08:00
support longlora for main branch
This commit is contained in:
@@ -95,7 +95,8 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||
with gr.Accordion(label="RLHF config", open=False) as rlhf_tab:
|
||||
with gr.Row():
|
||||
dpo_beta = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01, scale=1)
|
||||
reward_model = gr.Dropdown(scale=3, allow_custom_value=True)
|
||||
dpo_ftx = gr.Slider(value=0, minimum=0, maximum=10, step=0.01, scale=1)
|
||||
reward_model = gr.Dropdown(scale=2, allow_custom_value=True)
|
||||
refresh_btn = gr.Button(scale=1)
|
||||
|
||||
refresh_btn.click(
|
||||
@@ -105,8 +106,10 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||
queue=False
|
||||
)
|
||||
|
||||
input_elems.update({dpo_beta, reward_model})
|
||||
elem_dict.update(dict(rlhf_tab=rlhf_tab, dpo_beta=dpo_beta, reward_model=reward_model, refresh_btn=refresh_btn))
|
||||
input_elems.update({dpo_beta, dpo_ftx, reward_model})
|
||||
elem_dict.update(dict(
|
||||
rlhf_tab=rlhf_tab, dpo_beta=dpo_beta, dpo_ftx=dpo_ftx, reward_model=reward_model, refresh_btn=refresh_btn
|
||||
))
|
||||
|
||||
with gr.Row():
|
||||
cmd_preview_btn = gr.Button()
|
||||
|
||||
@@ -421,6 +421,16 @@ LOCALES = {
|
||||
"info": "DPO 损失函数中 beta 超参数大小。"
|
||||
}
|
||||
},
|
||||
"dpo_ftx": {
|
||||
"en": {
|
||||
"label": "DPO-ftx weight",
|
||||
"info": "The weight of SFT loss in the DPO-ftx."
|
||||
},
|
||||
"zh": {
|
||||
"label": "DPO-ftx 权重",
|
||||
"info": "DPO-ftx 中 SFT 损失的权重大小。"
|
||||
}
|
||||
},
|
||||
"reward_model": {
|
||||
"en": {
|
||||
"label": "Reward model",
|
||||
|
||||
@@ -146,6 +146,7 @@ class Runner:
|
||||
|
||||
if args["stage"] == "dpo":
|
||||
args["dpo_beta"] = get("train.dpo_beta")
|
||||
args["dpo_ftx"] = get("train.dpo_ftx")
|
||||
|
||||
if get("train.val_size") > 1e-6 and args["stage"] != "ppo":
|
||||
args["val_size"] = get("train.val_size")
|
||||
|
||||
Reference in New Issue
Block a user