mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-03 20:22:49 +08:00
fix arg dtype
Former-commit-id: e0c47358f9d09ab64acbb5ebafc61b52b5b1f2af
This commit is contained in:
parent
9561809ce9
commit
af526c3a46
@ -123,7 +123,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||
with gr.Accordion(label="LoRA config", open=False) as lora_tab:
|
||||
with gr.Row():
|
||||
lora_rank = gr.Slider(value=8, minimum=1, maximum=1024, step=1, scale=1)
|
||||
lora_alpha = gr.Slider(value=16, minimum=1, maximum=2048, step=0.1, scale=1)
|
||||
lora_alpha = gr.Slider(value=16, minimum=1, maximum=2048, step=1, scale=1)
|
||||
lora_dropout = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01, scale=1)
|
||||
lora_target = gr.Textbox(scale=2)
|
||||
|
||||
|
@ -144,7 +144,7 @@ class Runner:
|
||||
args["name_module_trainable"] = get("train.name_module_trainable")
|
||||
elif args["finetuning_type"] == "lora":
|
||||
args["lora_rank"] = int(get("train.lora_rank"))
|
||||
args["lora_alpha"] = float(get("train.lora_alpha"))
|
||||
args["lora_alpha"] = int(get("train.lora_alpha"))
|
||||
args["lora_dropout"] = float(get("train.lora_dropout"))
|
||||
args["lora_target"] = get("train.lora_target") or get_module(get("top.model_name"))
|
||||
args["use_rslora"] = get("train.use_rslora")
|
||||
|
Loading…
x
Reference in New Issue
Block a user