mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-03 04:02:49 +08:00
[webui] add extra args to export (#8178)
This commit is contained in:
parent
519ac92803
commit
4ecf4daeb2
@ -12,6 +12,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import json
|
||||||
from collections.abc import Generator
|
from collections.abc import Generator
|
||||||
from typing import TYPE_CHECKING, Union
|
from typing import TYPE_CHECKING, Union
|
||||||
|
|
||||||
@ -57,6 +58,7 @@ def save_model(
|
|||||||
export_legacy_format: bool,
|
export_legacy_format: bool,
|
||||||
export_dir: str,
|
export_dir: str,
|
||||||
export_hub_model_id: str,
|
export_hub_model_id: str,
|
||||||
|
extra_args: str,
|
||||||
) -> Generator[str, None, None]:
|
) -> Generator[str, None, None]:
|
||||||
user_config = load_config()
|
user_config = load_config()
|
||||||
error = ""
|
error = ""
|
||||||
@ -73,6 +75,11 @@ def save_model(
|
|||||||
elif export_quantization_bit in GPTQ_BITS and checkpoint_path and isinstance(checkpoint_path, list):
|
elif export_quantization_bit in GPTQ_BITS and checkpoint_path and isinstance(checkpoint_path, list):
|
||||||
error = ALERTS["err_gptq_lora"][lang]
|
error = ALERTS["err_gptq_lora"][lang]
|
||||||
|
|
||||||
|
try:
|
||||||
|
json.loads(extra_args)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
error = ALERTS["err_json_schema"][lang]
|
||||||
|
|
||||||
if error:
|
if error:
|
||||||
gr.Warning(error)
|
gr.Warning(error)
|
||||||
yield error
|
yield error
|
||||||
@ -92,6 +99,7 @@ def save_model(
|
|||||||
export_legacy_format=export_legacy_format,
|
export_legacy_format=export_legacy_format,
|
||||||
trust_remote_code=True,
|
trust_remote_code=True,
|
||||||
)
|
)
|
||||||
|
args.update(json.loads(extra_args))
|
||||||
|
|
||||||
if checkpoint_path:
|
if checkpoint_path:
|
||||||
if finetuning_type in PEFT_METHODS: # list
|
if finetuning_type in PEFT_METHODS: # list
|
||||||
@ -118,6 +126,7 @@ def create_export_tab(engine: "Engine") -> dict[str, "Component"]:
|
|||||||
with gr.Row():
|
with gr.Row():
|
||||||
export_dir = gr.Textbox()
|
export_dir = gr.Textbox()
|
||||||
export_hub_model_id = gr.Textbox()
|
export_hub_model_id = gr.Textbox()
|
||||||
|
extra_args = gr.Textbox(value="{}")
|
||||||
|
|
||||||
checkpoint_path: gr.Dropdown = engine.manager.get_elem_by_id("top.checkpoint_path")
|
checkpoint_path: gr.Dropdown = engine.manager.get_elem_by_id("top.checkpoint_path")
|
||||||
checkpoint_path.change(can_quantize, [checkpoint_path], [export_quantization_bit], queue=False)
|
checkpoint_path.change(can_quantize, [checkpoint_path], [export_quantization_bit], queue=False)
|
||||||
@ -141,6 +150,7 @@ def create_export_tab(engine: "Engine") -> dict[str, "Component"]:
|
|||||||
export_legacy_format,
|
export_legacy_format,
|
||||||
export_dir,
|
export_dir,
|
||||||
export_hub_model_id,
|
export_hub_model_id,
|
||||||
|
extra_args,
|
||||||
],
|
],
|
||||||
[info_box],
|
[info_box],
|
||||||
)
|
)
|
||||||
@ -153,6 +163,7 @@ def create_export_tab(engine: "Engine") -> dict[str, "Component"]:
|
|||||||
export_legacy_format=export_legacy_format,
|
export_legacy_format=export_legacy_format,
|
||||||
export_dir=export_dir,
|
export_dir=export_dir,
|
||||||
export_hub_model_id=export_hub_model_id,
|
export_hub_model_id=export_hub_model_id,
|
||||||
|
extra_args=extra_args,
|
||||||
export_btn=export_btn,
|
export_btn=export_btn,
|
||||||
info_box=info_box,
|
info_box=info_box,
|
||||||
)
|
)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user