From 35b3516812e40eaf123678606b2f456eaa2fe35b Mon Sep 17 00:00:00 2001 From: hiyouga Date: Sun, 31 Mar 2024 18:34:59 +0800 Subject: [PATCH] support orpo in webui Former-commit-id: 5195add324194d2583db40365522e5e2916592b6 --- src/llmtuner/webui/components/train.py | 7 +++++-- src/llmtuner/webui/locales.py | 14 ++++++++++++++ src/llmtuner/webui/runner.py | 5 +++-- 3 files changed, 22 insertions(+), 4 deletions(-) diff --git a/src/llmtuner/webui/components/train.py b/src/llmtuner/webui/components/train.py index 52c8fdb6..9c9f143e 100644 --- a/src/llmtuner/webui/components/train.py +++ b/src/llmtuner/webui/components/train.py @@ -169,10 +169,13 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]: with gr.Row(): dpo_beta = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01, scale=1) dpo_ftx = gr.Slider(value=0, minimum=0, maximum=10, step=0.01, scale=1) + orpo_beta = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01, scale=1) reward_model = gr.Dropdown(multiselect=True, allow_custom_value=True, scale=2) - input_elems.update({dpo_beta, dpo_ftx, reward_model}) - elem_dict.update(dict(rlhf_tab=rlhf_tab, dpo_beta=dpo_beta, dpo_ftx=dpo_ftx, reward_model=reward_model)) + input_elems.update({dpo_beta, dpo_ftx, orpo_beta, reward_model}) + elem_dict.update( + dict(rlhf_tab=rlhf_tab, dpo_beta=dpo_beta, dpo_ftx=dpo_ftx, orpo_beta=orpo_beta, reward_model=reward_model) + ) with gr.Accordion(open=False) as galore_tab: with gr.Row(): diff --git a/src/llmtuner/webui/locales.py b/src/llmtuner/webui/locales.py index f6d6d421..be2841e8 100644 --- a/src/llmtuner/webui/locales.py +++ b/src/llmtuner/webui/locales.py @@ -757,6 +757,20 @@ LOCALES = { "info": "DPO-ftx 中 SFT 损失的权重大小。", }, }, + "orpo_beta": { + "en": { + "label": "ORPO beta", + "info": "Value of the beta parameter in the ORPO loss.", + }, + "ru": { + "label": "ORPO бета", + "info": "Значение параметра бета в функции потерь ORPO.", + }, + "zh": { + "label": "ORPO beta 参数", + "info": "ORPO 损失函数中 beta 超参数大小。", + }, + }, "reward_model": { "en": { "label": "Reward model", diff --git a/src/llmtuner/webui/runner.py b/src/llmtuner/webui/runner.py index ab646051..891a2e4b 100644 --- a/src/llmtuner/webui/runner.py +++ b/src/llmtuner/webui/runner.py @@ -174,10 +174,11 @@ class Runner: ] ) args["reward_model_type"] = "lora" if args["finetuning_type"] == "lora" else "full" - - if args["stage"] == "dpo": + elif args["stage"] == "dpo": args["dpo_beta"] = get("train.dpo_beta") args["dpo_ftx"] = get("train.dpo_ftx") + elif args["stage"] == "orpo": + args["orpo_beta"] = get("train.orpo_beta") if get("train.val_size") > 1e-6 and args["stage"] != "ppo": args["val_size"] = get("train.val_size")