Add KTO to the web UI

Former-commit-id: 9b0f4d7602
This commit is contained in:
hiyouga
2024-05-20 21:20:25 +08:00
parent 864da49139
commit 5351e3945b
3 changed files with 91 additions and 38 deletions

View File

@@ -184,14 +184,25 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
with gr.Accordion(open=False) as rlhf_tab:
with gr.Row():
dpo_beta = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.01)
dpo_ftx = gr.Slider(minimum=0, maximum=10, value=0, step=0.01)
orpo_beta = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.01)
pref_beta = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.01)
pref_ftx = gr.Slider(minimum=0, maximum=10, value=0, step=0.01)
pref_loss = gr.Dropdown(choices=["sigmoid", "hinge", "ipo", "kto_pair"], value="sigmoid")
reward_model = gr.Dropdown(multiselect=True, allow_custom_value=True)
with gr.Column():
ppo_score_norm = gr.Checkbox()
ppo_whiten_rewards = gr.Checkbox()
input_elems.update({dpo_beta, dpo_ftx, orpo_beta, reward_model})
input_elems.update({pref_beta, pref_ftx, pref_loss, reward_model, ppo_score_norm, ppo_whiten_rewards})
elem_dict.update(
dict(rlhf_tab=rlhf_tab, dpo_beta=dpo_beta, dpo_ftx=dpo_ftx, orpo_beta=orpo_beta, reward_model=reward_model)
dict(
rlhf_tab=rlhf_tab,
pref_beta=pref_beta,
pref_ftx=pref_ftx,
pref_loss=pref_loss,
reward_model=reward_model,
ppo_score_norm=ppo_score_norm,
ppo_whiten_rewards=ppo_whiten_rewards,
)
)
with gr.Accordion(open=False) as galore_tab: