From 4ecf4daeb20050a03b62ff9a592c5a4b0e71ccfe Mon Sep 17 00:00:00 2001 From: hoshi-hiyouga Date: Tue, 27 May 2025 18:25:31 +0800 Subject: [PATCH] [webui] add extra args to export (#8178) --- src/llamafactory/webui/components/export.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/src/llamafactory/webui/components/export.py b/src/llamafactory/webui/components/export.py index bb458e6d..d153ffa6 100644 --- a/src/llamafactory/webui/components/export.py +++ b/src/llamafactory/webui/components/export.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import json from collections.abc import Generator from typing import TYPE_CHECKING, Union @@ -57,6 +58,7 @@ def save_model( export_legacy_format: bool, export_dir: str, export_hub_model_id: str, + extra_args: str, ) -> Generator[str, None, None]: user_config = load_config() error = "" @@ -73,6 +75,11 @@ def save_model( elif export_quantization_bit in GPTQ_BITS and checkpoint_path and isinstance(checkpoint_path, list): error = ALERTS["err_gptq_lora"][lang] + try: + json.loads(extra_args) + except json.JSONDecodeError: + error = ALERTS["err_json_schema"][lang] + if error: gr.Warning(error) yield error @@ -92,6 +99,7 @@ def save_model( export_legacy_format=export_legacy_format, trust_remote_code=True, ) + args.update(json.loads(extra_args)) if checkpoint_path: if finetuning_type in PEFT_METHODS: # list @@ -118,6 +126,7 @@ def create_export_tab(engine: "Engine") -> dict[str, "Component"]: with gr.Row(): export_dir = gr.Textbox() export_hub_model_id = gr.Textbox() + extra_args = gr.Textbox(value="{}") checkpoint_path: gr.Dropdown = engine.manager.get_elem_by_id("top.checkpoint_path") checkpoint_path.change(can_quantize, [checkpoint_path], [export_quantization_bit], queue=False) @@ -141,6 +150,7 @@ def create_export_tab(engine: "Engine") -> dict[str, "Component"]: export_legacy_format, export_dir, export_hub_model_id, + extra_args, ], [info_box], ) @@ -153,6 +163,7 @@ def create_export_tab(engine: "Engine") -> dict[str, "Component"]: export_legacy_format=export_legacy_format, export_dir=export_dir, export_hub_model_id=export_hub_model_id, + extra_args=extra_args, export_btn=export_btn, info_box=info_box, )