[webui] upgrade to gradio 5 (#6688)

Former-commit-id: 4d0f662dbe227ab0da11a1e109f7a2c5ab8f70b9
hoshi-hiyouga 2025-01-17 20:15:42 +08:00 committed by GitHub
parent 788accb601
commit 770433fa33
8 changed files with 36 additions and 44 deletions

View File

@@ -4,7 +4,7 @@ accelerate>=0.34.0,<=1.0.1
 peft>=0.11.1,<=0.12.0
 trl>=0.8.6,<=0.9.6
 tokenizers>=0.19.0,<0.20.4
-gradio>=4.0.0,<5.0.0
+gradio>=4.0.0,<6.0.0
 pandas>=2.0.0
 scipy
 einops
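
Note: the widened pin now spans two major versions, so the installed major release determines which Gradio behavior the UI gets. A minimal environment check, assuming only the standard library:

    # Sketch: confirm the installed gradio falls inside the new range.
    from importlib.metadata import version

    gradio_version = version("gradio")  # e.g. "5.12.0"
    major = int(gradio_version.split(".")[0])
    assert 4 <= major < 6, f"gradio {gradio_version} is outside the supported range"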

View File

@@ -14,7 +14,7 @@
 import json
 import os
-from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Sequence, Tuple
+from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple
 
 from ..chat import ChatModel
 from ..data import Role
@@ -120,17 +120,17 @@ class WebChatModel(ChatModel):
     def append(
         self,
-        chatbot: List[List[Optional[str]]],
-        messages: Sequence[Dict[str, str]],
+        chatbot: List[Dict[str, str]],
+        messages: List[Dict[str, str]],
         role: str,
         query: str,
-    ) -> Tuple[List[List[Optional[str]]], List[Dict[str, str]], str]:
-        return chatbot + [[query, None]], messages + [{"role": role, "content": query}], ""
+    ) -> Tuple[List[Dict[str, str]], List[Dict[str, str]], str]:
+        return chatbot + [{"role": "user", "content": query}], messages + [{"role": role, "content": query}], ""
 
     def stream(
         self,
-        chatbot: List[List[Optional[str]]],
-        messages: Sequence[Dict[str, str]],
+        chatbot: List[Dict[str, str]],
+        messages: List[Dict[str, str]],
         system: str,
         tools: str,
         image: Optional[Any],
@@ -138,8 +138,8 @@ class WebChatModel(ChatModel):
         max_new_tokens: int,
         top_p: float,
         temperature: float,
-    ) -> Generator[Tuple[List[List[Optional[str]]], List[Dict[str, str]]], None, None]:
-        chatbot[-1][1] = ""
+    ) -> Generator[Tuple[List[Dict[str, str]], List[Dict[str, str]]], None, None]:
+        chatbot.append({"role": "assistant", "content": ""})
         response = ""
         for new_text in self.stream_chat(
             messages,
@@ -166,5 +166,5 @@ class WebChatModel(ChatModel):
             output_messages = messages + [{"role": Role.ASSISTANT.value, "content": result}]
             bot_text = result
 
-            chatbot[-1][1] = bot_text
+            chatbot[-1] = {"role": "assistant", "content": bot_text}
             yield chatbot, output_messages
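
Note: the changes above move the chatbot history from Gradio 4's tuple format to the openai-style "messages" format that Gradio 5's gr.Chatbot(type="messages") consumes. A minimal sketch of the two shapes (contents illustrative):

    # Gradio 4 "tuples" history: one [user, assistant] pair per turn,
    # with None marking a reply that is still streaming.
    old_history = [["Hello", None]]
    old_history[-1][1] = "Hi there!"  # fill in the finished reply

    # Gradio 5 "messages" history: one dict per message.
    new_history = [{"role": "user", "content": "Hello"}]
    new_history.append({"role": "assistant", "content": ""})  # streaming placeholder
    new_history[-1] = {"role": "assistant", "content": "Hi there!"}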

View File

@@ -33,7 +33,7 @@ def create_chat_box(
     engine: "Engine", visible: bool = False
 ) -> Tuple["Component", "Component", Dict[str, "Component"]]:
     with gr.Column(visible=visible) as chat_box:
-        chatbot = gr.Chatbot(show_copy_button=True)
+        chatbot = gr.Chatbot(type="messages", show_copy_button=True)
         messages = gr.State([])
         with gr.Row():
             with gr.Column(scale=4):

View File

@@ -30,11 +30,10 @@ if TYPE_CHECKING:
 def create_top() -> Dict[str, "Component"]:
-    available_models = list(SUPPORTED_MODELS.keys()) + ["Custom"]
-
     with gr.Row():
-        lang = gr.Dropdown(choices=["en", "ru", "zh", "ko"], scale=1)
-        model_name = gr.Dropdown(choices=available_models, scale=3)
+        lang = gr.Dropdown(choices=["en", "ru", "zh", "ko"], value=None, scale=1)
+        available_models = list(SUPPORTED_MODELS.keys()) + ["Custom"]
+        model_name = gr.Dropdown(choices=available_models, value=None, scale=3)
         model_path = gr.Textbox(scale=3)
 
     with gr.Row():
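
Note: the explicit value=None keeps both dropdowns unselected on load, since a Gradio 5 dropdown otherwise tends to start on its first choice; the actual values are filled in later by the app (this reading of the motivation is an assumption, not stated in the commit). A sketch:

    import gradio as gr

    # Assumption: without an explicit value, a Gradio 5 dropdown may start
    # on its first choice ("en" here); value=None keeps it empty instead.
    lang = gr.Dropdown(choices=["en", "ru", "zh", "ko"], value=None)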

View File

@@ -39,9 +39,8 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
     elem_dict = dict()
 
     with gr.Row():
-        training_stage = gr.Dropdown(
-            choices=list(TRAINING_STAGES.keys()), value=list(TRAINING_STAGES.keys())[0], scale=1
-        )
+        stages = list(TRAINING_STAGES.keys())
+        training_stage = gr.Dropdown(choices=stages, value=stages[0], scale=1)
         dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=1)
         dataset = gr.Dropdown(multiselect=True, allow_custom_value=True, scale=4)
         preview_elems = create_preview_box(dataset_dir, dataset)
@@ -107,8 +106,12 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
                 use_llama_pro = gr.Checkbox()
 
             with gr.Column():
-                shift_attn = gr.Checkbox()
-                report_to = gr.Checkbox()
+                report_to = gr.Dropdown(
+                    choices=["none", "all", "wandb", "mlflow", "neptune", "tensorboard"],
+                    value=["none"],
+                    allow_custom_value=True,
+                    multiselect=True,
+                )
 
     input_elems.update(
         {
@@ -123,7 +126,6 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
             mask_history,
             resize_vocab,
             use_llama_pro,
-            shift_attn,
             report_to,
         }
     )
@@ -141,7 +143,6 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
             mask_history=mask_history,
             resize_vocab=resize_vocab,
             use_llama_pro=use_llama_pro,
-            shift_attn=shift_attn,
             report_to=report_to,
         )
     )
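
Note: with multiselect=True the report_to component now yields a list of strings instead of a checkbox bool, and allow_custom_value lets users type loggers beyond the preset choices. Illustrative values only:

    report_to_default = ["none"]                  # initial dropdown value
    report_to_picked = ["wandb", "tensorboard"]   # two loggers selected
    report_to_custom = ["clearml"]                # typed via allow_custom_value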

View File

@@ -713,24 +713,6 @@ LOCALES = {
             "info": "확장된 블록의 매개변수를 학습 가능하게 만듭니다.",
         },
     },
-    "shift_attn": {
-        "en": {
-            "label": "Enable S^2 Attention",
-            "info": "Use shift short attention proposed by LongLoRA.",
-        },
-        "ru": {
-            "label": "Включить S^2 внимание",
-            "info": "Использовать сдвиг внимания на короткие дистанции предложенный LongLoRA.",
-        },
-        "zh": {
-            "label": "使用 S^2 Attention",
-            "info": "使用 LongLoRA 提出的 shift short attention。",
-        },
-        "ko": {
-            "label": "S^2 Attention 사용",
-            "info": "LongLoRA에서 제안한 shift short attention을 사용합니다.",
-        },
-    },
     "report_to": {
         "en": {
             "label": "Enable external logger",

View File

@@ -144,8 +144,7 @@ class Runner:
             mask_history=get("train.mask_history"),
             resize_vocab=get("train.resize_vocab"),
             use_llama_pro=get("train.use_llama_pro"),
-            shift_attn=get("train.shift_attn"),
-            report_to="all" if get("train.report_to") else "none",
+            report_to=get("train.report_to"),
             use_galore=get("train.use_galore"),
             use_apollo=get("train.use_apollo"),
             use_badam=get("train.use_badam"),
@@ -239,6 +238,12 @@ class Runner:
             args["badam_switch_interval"] = get("train.badam_switch_interval")
             args["badam_update_ratio"] = get("train.badam_update_ratio")
 
+        # report_to
+        if "none" in args["report_to"]:
+            args["report_to"] = "none"
+        elif "all" in args["report_to"]:
+            args["report_to"] = "all"
+
         # swanlab config
         if get("train.use_swanlab"):
            args["swanlab_project"] = get("train.swanlab_project")

View File

@@ -111,7 +111,12 @@ def gen_cmd(args: Dict[str, Any]) -> str:
     """
     cmd_lines = ["llamafactory-cli train "]
     for k, v in clean_cmd(args).items():
-        cmd_lines.append(f"    --{k} {str(v)} ")
+        if isinstance(v, dict):
+            cmd_lines.append(f"    --{k} {json.dumps(v, ensure_ascii=False)} ")
+        elif isinstance(v, list):
+            cmd_lines.append(f"    --{k} {' '.join(map(str, v))} ")
+        else:
+            cmd_lines.append(f"    --{k} {str(v)} ")
 
     if os.name == "nt":
         cmd_text = "`\n".join(cmd_lines)
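
Note: with the new branches, list values (such as the multiselect report_to) render space-separated and dict values render as JSON. An illustrative input and the lines it would produce (values made up):

    args = {"report_to": ["wandb", "tensorboard"], "learning_rate": 5e-5}
    # gen_cmd(args) now emits, one flag per line:
    #     llamafactory-cli train
    #         --report_to wandb tensorboard
    #         --learning_rate 5e-05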