From af526c3a46923440b91042436d928acb75312f9c Mon Sep 17 00:00:00 2001 From: hiyouga Date: Tue, 5 Mar 2024 20:53:30 +0800 Subject: [PATCH] fix arg dtype Former-commit-id: e0c47358f9d09ab64acbb5ebafc61b52b5b1f2af --- src/llmtuner/webui/components/train.py | 2 +- src/llmtuner/webui/runner.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/llmtuner/webui/components/train.py b/src/llmtuner/webui/components/train.py index bfec4c4b..daefa1a8 100644 --- a/src/llmtuner/webui/components/train.py +++ b/src/llmtuner/webui/components/train.py @@ -123,7 +123,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]: with gr.Accordion(label="LoRA config", open=False) as lora_tab: with gr.Row(): lora_rank = gr.Slider(value=8, minimum=1, maximum=1024, step=1, scale=1) - lora_alpha = gr.Slider(value=16, minimum=1, maximum=2048, step=0.1, scale=1) + lora_alpha = gr.Slider(value=16, minimum=1, maximum=2048, step=1, scale=1) lora_dropout = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01, scale=1) lora_target = gr.Textbox(scale=2) diff --git a/src/llmtuner/webui/runner.py b/src/llmtuner/webui/runner.py index c6fd4ae6..ba84e2c1 100644 --- a/src/llmtuner/webui/runner.py +++ b/src/llmtuner/webui/runner.py @@ -144,7 +144,7 @@ class Runner: args["name_module_trainable"] = get("train.name_module_trainable") elif args["finetuning_type"] == "lora": args["lora_rank"] = int(get("train.lora_rank")) - args["lora_alpha"] = float(get("train.lora_alpha")) + args["lora_alpha"] = int(get("train.lora_alpha")) args["lora_dropout"] = float(get("train.lora_dropout")) args["lora_target"] = get("train.lora_target") or get_module(get("top.model_name")) args["use_rslora"] = get("train.use_rslora")