mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-12-20 22:00:36 +08:00
upgrade gradio to 4.21.0
This commit is contained in:
@@ -7,7 +7,6 @@ from ..utils import check_json_schema
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from gradio.blocks import Block
|
||||
from gradio.components import Component
|
||||
|
||||
from ..engine import Engine
|
||||
@@ -15,9 +14,9 @@ if TYPE_CHECKING:
|
||||
|
||||
def create_chat_box(
|
||||
engine: "Engine", visible: bool = False
|
||||
) -> Tuple["Block", "Component", "Component", Dict[str, "Component"]]:
|
||||
with gr.Box(visible=visible) as chat_box:
|
||||
chatbot = gr.Chatbot()
|
||||
) -> Tuple["gr.Column", "Component", "Component", Dict[str, "Component"]]:
|
||||
with gr.Column(visible=visible) as chat_box:
|
||||
chatbot = gr.Chatbot(show_copy_button=True)
|
||||
messages = gr.State([])
|
||||
with gr.Row():
|
||||
with gr.Column(scale=4):
|
||||
@@ -33,14 +32,14 @@ def create_chat_box(
|
||||
temperature = gr.Slider(0.01, 1.5, value=0.95, step=0.01)
|
||||
clear_btn = gr.Button()
|
||||
|
||||
tools.input(check_json_schema, [tools, engine.manager.get_elem_by_name("top.lang")])
|
||||
tools.input(check_json_schema, inputs=[tools, engine.manager.get_elem_by_id("top.lang")])
|
||||
|
||||
submit_btn.click(
|
||||
engine.chatter.predict,
|
||||
[chatbot, role, query, messages, system, tools, max_new_tokens, top_p, temperature],
|
||||
[chatbot, messages],
|
||||
show_progress=True,
|
||||
).then(lambda: gr.update(value=""), outputs=[query])
|
||||
).then(lambda: "", outputs=[query])
|
||||
|
||||
clear_btn.click(lambda: ([], []), outputs=[chatbot, messages], show_progress=True)
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import json
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Any, Dict, Tuple
|
||||
from typing import TYPE_CHECKING, Dict, Tuple
|
||||
|
||||
import gradio as gr
|
||||
|
||||
@@ -22,24 +22,24 @@ def next_page(page_index: int, total_num: int) -> int:
|
||||
return page_index + 1 if (page_index + 1) * PAGE_SIZE < total_num else page_index
|
||||
|
||||
|
||||
def can_preview(dataset_dir: str, dataset: list) -> Dict[str, Any]:
|
||||
def can_preview(dataset_dir: str, dataset: list) -> "gr.Button":
|
||||
try:
|
||||
with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
|
||||
dataset_info = json.load(f)
|
||||
except Exception:
|
||||
return gr.update(interactive=False)
|
||||
return gr.Button(interactive=False)
|
||||
|
||||
if (
|
||||
len(dataset) > 0
|
||||
and "file_name" in dataset_info[dataset[0]]
|
||||
and os.path.isfile(os.path.join(dataset_dir, dataset_info[dataset[0]]["file_name"]))
|
||||
):
|
||||
return gr.update(interactive=True)
|
||||
return gr.Button(interactive=True)
|
||||
else:
|
||||
return gr.update(interactive=False)
|
||||
return gr.Button(interactive=False)
|
||||
|
||||
|
||||
def get_preview(dataset_dir: str, dataset: list, page_index: int) -> Tuple[int, list, Dict[str, Any]]:
|
||||
def get_preview(dataset_dir: str, dataset: list, page_index: int) -> Tuple[int, list, "gr.Column"]:
|
||||
with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
|
||||
dataset_info = json.load(f)
|
||||
|
||||
@@ -51,7 +51,7 @@ def get_preview(dataset_dir: str, dataset: list, page_index: int) -> Tuple[int,
|
||||
data = [json.loads(line) for line in f]
|
||||
else:
|
||||
data = [line for line in f] # noqa: C416
|
||||
return len(data), data[PAGE_SIZE * page_index : PAGE_SIZE * (page_index + 1)], gr.update(visible=True)
|
||||
return len(data), data[PAGE_SIZE * page_index : PAGE_SIZE * (page_index + 1)], gr.Column(visible=True)
|
||||
|
||||
|
||||
def create_preview_box(dataset_dir: "gr.Textbox", dataset: "gr.Dropdown") -> Dict[str, "Component"]:
|
||||
@@ -67,7 +67,7 @@ def create_preview_box(dataset_dir: "gr.Textbox", dataset: "gr.Dropdown") -> Dic
|
||||
close_btn = gr.Button()
|
||||
|
||||
with gr.Row():
|
||||
preview_samples = gr.JSON(interactive=False)
|
||||
preview_samples = gr.JSON()
|
||||
|
||||
dataset.change(can_preview, [dataset_dir, dataset], [data_preview_btn], queue=False).then(
|
||||
lambda: 0, outputs=[page_index], queue=False
|
||||
@@ -81,7 +81,7 @@ def create_preview_box(dataset_dir: "gr.Textbox", dataset: "gr.Dropdown") -> Dic
|
||||
next_btn.click(next_page, [page_index, preview_count], [page_index], queue=False).then(
|
||||
get_preview, [dataset_dir, dataset, page_index], [preview_count, preview_samples, preview_box], queue=False
|
||||
)
|
||||
close_btn.click(lambda: gr.update(visible=False), outputs=[preview_box], queue=False)
|
||||
close_btn.click(lambda: gr.Column(visible=False), outputs=[preview_box], queue=False)
|
||||
return dict(
|
||||
data_preview_btn=data_preview_btn,
|
||||
preview_count=preview_count,
|
||||
|
||||
@@ -53,7 +53,7 @@ def create_eval_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||
resume_btn = gr.Checkbox(visible=False, interactive=False, value=False)
|
||||
process_bar = gr.Slider(visible=False, interactive=False)
|
||||
|
||||
with gr.Box():
|
||||
with gr.Row():
|
||||
output_box = gr.Markdown()
|
||||
|
||||
output_elems = [output_box, process_bar]
|
||||
@@ -68,9 +68,9 @@ def create_eval_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||
)
|
||||
)
|
||||
|
||||
cmd_preview_btn.click(engine.runner.preview_eval, input_elems, output_elems)
|
||||
cmd_preview_btn.click(engine.runner.preview_eval, input_elems, output_elems, concurrency_limit=None)
|
||||
start_btn.click(engine.runner.run_eval, input_elems, output_elems)
|
||||
stop_btn.click(engine.runner.set_abort, queue=False)
|
||||
resume_btn.change(engine.runner.monitor, outputs=output_elems)
|
||||
resume_btn.change(engine.runner.monitor, outputs=output_elems, concurrency_limit=None)
|
||||
|
||||
return elem_dict
|
||||
|
||||
@@ -74,7 +74,7 @@ def save_model(
|
||||
|
||||
def create_export_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||
with gr.Row():
|
||||
max_shard_size = gr.Slider(value=1, minimum=1, maximum=100)
|
||||
max_shard_size = gr.Slider(value=1, minimum=1, maximum=100, step=1)
|
||||
export_quantization_bit = gr.Dropdown(choices=["none", "8", "4", "3", "2"], value="none")
|
||||
export_quantization_dataset = gr.Textbox(value="data/c4_demo.json")
|
||||
export_legacy_format = gr.Checkbox()
|
||||
@@ -89,12 +89,12 @@ def create_export_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||
export_btn.click(
|
||||
save_model,
|
||||
[
|
||||
engine.manager.get_elem_by_name("top.lang"),
|
||||
engine.manager.get_elem_by_name("top.model_name"),
|
||||
engine.manager.get_elem_by_name("top.model_path"),
|
||||
engine.manager.get_elem_by_name("top.adapter_path"),
|
||||
engine.manager.get_elem_by_name("top.finetuning_type"),
|
||||
engine.manager.get_elem_by_name("top.template"),
|
||||
engine.manager.get_elem_by_id("top.lang"),
|
||||
engine.manager.get_elem_by_id("top.model_name"),
|
||||
engine.manager.get_elem_by_id("top.model_path"),
|
||||
engine.manager.get_elem_by_id("top.adapter_path"),
|
||||
engine.manager.get_elem_by_id("top.finetuning_type"),
|
||||
engine.manager.get_elem_by_id("top.template"),
|
||||
max_shard_size,
|
||||
export_quantization_bit,
|
||||
export_quantization_dataset,
|
||||
|
||||
@@ -29,11 +29,11 @@ def create_infer_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||
elem_dict.update(dict(chat_box=chat_box, **chat_elems))
|
||||
|
||||
load_btn.click(engine.chatter.load_model, input_elems, [info_box]).then(
|
||||
lambda: gr.update(visible=engine.chatter.loaded), outputs=[chat_box]
|
||||
lambda: gr.Column(visible=engine.chatter.loaded), outputs=[chat_box]
|
||||
)
|
||||
|
||||
unload_btn.click(engine.chatter.unload_model, input_elems, [info_box]).then(
|
||||
lambda: ([], []), outputs=[chatbot, history]
|
||||
).then(lambda: gr.update(visible=engine.chatter.loaded), outputs=[chat_box])
|
||||
).then(lambda: gr.Column(visible=engine.chatter.loaded), outputs=[chat_box])
|
||||
|
||||
return elem_dict
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import TYPE_CHECKING, Dict, Tuple
|
||||
from typing import TYPE_CHECKING, Dict
|
||||
|
||||
import gradio as gr
|
||||
|
||||
@@ -12,7 +12,7 @@ if TYPE_CHECKING:
|
||||
from gradio.components import Component
|
||||
|
||||
|
||||
def create_top() -> Tuple["gr.Dropdown", Dict[str, "Component"]]:
|
||||
def create_top() -> Dict[str, "Component"]:
|
||||
available_models = list(SUPPORTED_MODELS.keys()) + ["Custom"]
|
||||
|
||||
with gr.Row():
|
||||
@@ -25,7 +25,7 @@ def create_top() -> Tuple["gr.Dropdown", Dict[str, "Component"]]:
|
||||
adapter_path = gr.Dropdown(multiselect=True, allow_custom_value=True, scale=5)
|
||||
refresh_btn = gr.Button(scale=1)
|
||||
|
||||
with gr.Accordion(label="Advanced config", open=False) as advanced_tab:
|
||||
with gr.Accordion(open=False) as advanced_tab:
|
||||
with gr.Row():
|
||||
quantization_bit = gr.Dropdown(choices=["none", "8", "4"], value="none")
|
||||
template = gr.Dropdown(choices=list(templates.keys()), value="default")
|
||||
@@ -44,7 +44,7 @@ def create_top() -> Tuple["gr.Dropdown", Dict[str, "Component"]]:
|
||||
|
||||
refresh_btn.click(list_adapters, [model_name, finetuning_type], [adapter_path], queue=False)
|
||||
|
||||
return lang, dict(
|
||||
return dict(
|
||||
lang=lang,
|
||||
model_name=model_name,
|
||||
model_path=model_path,
|
||||
|
||||
@@ -68,7 +68,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||
)
|
||||
)
|
||||
|
||||
with gr.Accordion(label="Extra config", open=False) as extra_tab:
|
||||
with gr.Accordion(open=False) as extra_tab:
|
||||
with gr.Row():
|
||||
logging_steps = gr.Slider(value=5, minimum=5, maximum=1000, step=5)
|
||||
save_steps = gr.Slider(value=100, minimum=10, maximum=5000, step=10)
|
||||
@@ -113,7 +113,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||
)
|
||||
)
|
||||
|
||||
with gr.Accordion(label="Freeze config", open=False) as freeze_tab:
|
||||
with gr.Accordion(open=False) as freeze_tab:
|
||||
with gr.Row():
|
||||
num_layer_trainable = gr.Slider(value=3, minimum=1, maximum=128, step=1, scale=2)
|
||||
name_module_trainable = gr.Textbox(value="all", scale=3)
|
||||
@@ -125,7 +125,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||
)
|
||||
)
|
||||
|
||||
with gr.Accordion(label="LoRA config", open=False) as lora_tab:
|
||||
with gr.Accordion(open=False) as lora_tab:
|
||||
with gr.Row():
|
||||
lora_rank = gr.Slider(value=8, minimum=1, maximum=1024, step=1, scale=1)
|
||||
lora_alpha = gr.Slider(value=16, minimum=1, maximum=2048, step=1, scale=1)
|
||||
@@ -155,7 +155,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||
)
|
||||
)
|
||||
|
||||
with gr.Accordion(label="RLHF config", open=False) as rlhf_tab:
|
||||
with gr.Accordion(open=False) as rlhf_tab:
|
||||
with gr.Row():
|
||||
dpo_beta = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01, scale=1)
|
||||
dpo_ftx = gr.Slider(value=0, minimum=0, maximum=10, step=0.01, scale=1)
|
||||
@@ -163,7 +163,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||
|
||||
training_stage.change(list_dataset, [dataset_dir, training_stage], [dataset], queue=False).then(
|
||||
list_adapters,
|
||||
[engine.manager.get_elem_by_name("top.model_name"), engine.manager.get_elem_by_name("top.finetuning_type")],
|
||||
[engine.manager.get_elem_by_id("top.model_name"), engine.manager.get_elem_by_id("top.finetuning_type")],
|
||||
[reward_model],
|
||||
queue=False,
|
||||
).then(autoset_packing, [training_stage], [packing], queue=False)
|
||||
@@ -171,7 +171,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||
input_elems.update({dpo_beta, dpo_ftx, reward_model})
|
||||
elem_dict.update(dict(rlhf_tab=rlhf_tab, dpo_beta=dpo_beta, dpo_ftx=dpo_ftx, reward_model=reward_model))
|
||||
|
||||
with gr.Accordion(label="GaLore config", open=False) as galore_tab:
|
||||
with gr.Accordion(open=False) as galore_tab:
|
||||
with gr.Row():
|
||||
use_galore = gr.Checkbox(scale=1)
|
||||
galore_rank = gr.Slider(value=16, minimum=1, maximum=1024, step=1, scale=2)
|
||||
@@ -205,7 +205,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||
resume_btn = gr.Checkbox(visible=False, interactive=False)
|
||||
process_bar = gr.Slider(visible=False, interactive=False)
|
||||
|
||||
with gr.Box():
|
||||
with gr.Row():
|
||||
output_box = gr.Markdown()
|
||||
|
||||
with gr.Column(scale=1):
|
||||
@@ -214,10 +214,10 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||
input_elems.add(output_dir)
|
||||
output_elems = [output_box, process_bar]
|
||||
|
||||
cmd_preview_btn.click(engine.runner.preview_train, input_elems, output_elems)
|
||||
cmd_preview_btn.click(engine.runner.preview_train, input_elems, output_elems, concurrency_limit=None)
|
||||
start_btn.click(engine.runner.run_train, input_elems, output_elems)
|
||||
stop_btn.click(engine.runner.set_abort, queue=False)
|
||||
resume_btn.change(engine.runner.monitor, outputs=output_elems)
|
||||
resume_btn.change(engine.runner.monitor, outputs=output_elems, concurrency_limit=None)
|
||||
|
||||
elem_dict.update(
|
||||
dict(
|
||||
@@ -235,8 +235,8 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||
output_box.change(
|
||||
gen_plot,
|
||||
[
|
||||
engine.manager.get_elem_by_name("top.model_name"),
|
||||
engine.manager.get_elem_by_name("top.finetuning_type"),
|
||||
engine.manager.get_elem_by_id("top.model_name"),
|
||||
engine.manager.get_elem_by_id("top.finetuning_type"),
|
||||
output_dir,
|
||||
],
|
||||
loss_viewer,
|
||||
|
||||
Reference in New Issue
Block a user