diff --git a/src/llamafactory/webui/chatter.py b/src/llamafactory/webui/chatter.py index 9b1bcd67..feeeaf3d 100644 --- a/src/llamafactory/webui/chatter.py +++ b/src/llamafactory/webui/chatter.py @@ -114,6 +114,11 @@ class WebChatModel(ChatModel): elif self.demo_mode: error = ALERTS["err_demo"][lang] + try: + json.loads(get("infer.extra_args")) + except json.JSONDecodeError: + error = ALERTS["err_json_schema"][lang] + if error: gr.Warning(error) yield error @@ -131,9 +136,9 @@ class WebChatModel(ChatModel): enable_liger_kernel=(get("top.booster") == "liger_kernel"), infer_backend=get("infer.infer_backend"), infer_dtype=get("infer.infer_dtype"), - vllm_enforce_eager=True, trust_remote_code=True, ) + args.update(json.loads(get("infer.extra_args"))) # checkpoints if checkpoint_path: diff --git a/src/llamafactory/webui/components/infer.py b/src/llamafactory/webui/components/infer.py index 677036b9..ef508cdf 100644 --- a/src/llamafactory/webui/components/infer.py +++ b/src/llamafactory/webui/components/infer.py @@ -36,6 +36,7 @@ def create_infer_tab(engine: "Engine") -> dict[str, "Component"]: with gr.Row(): infer_backend = gr.Dropdown(choices=["huggingface", "vllm", "sglang"], value="huggingface") infer_dtype = gr.Dropdown(choices=["auto", "float16", "bfloat16", "float32"], value="auto") + extra_args = gr.Textbox(value='{"vllm_enforce_eager": true}') with gr.Row(): load_btn = gr.Button() @@ -43,11 +44,12 @@ def create_infer_tab(engine: "Engine") -> dict[str, "Component"]: info_box = gr.Textbox(show_label=False, interactive=False) - input_elems.update({infer_backend, infer_dtype}) + input_elems.update({infer_backend, infer_dtype, extra_args}) elem_dict.update( dict( infer_backend=infer_backend, infer_dtype=infer_dtype, + extra_args=extra_args, load_btn=load_btn, unload_btn=unload_btn, info_box=info_box,