mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-01 11:12:50 +08:00
[webui] upgrade to gradio 5 (#6688)
Former-commit-id: 4d0f662dbe227ab0da11a1e109f7a2c5ab8f70b9
This commit is contained in:
parent
788accb601
commit
770433fa33
@ -4,7 +4,7 @@ accelerate>=0.34.0,<=1.0.1
|
||||
peft>=0.11.1,<=0.12.0
|
||||
trl>=0.8.6,<=0.9.6
|
||||
tokenizers>=0.19.0,<0.20.4
|
||||
gradio>=4.0.0,<5.0.0
|
||||
gradio>=4.0.0,<6.0.0
|
||||
pandas>=2.0.0
|
||||
scipy
|
||||
einops
|
||||
|
@ -14,7 +14,7 @@
|
||||
|
||||
import json
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Sequence, Tuple
|
||||
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple
|
||||
|
||||
from ..chat import ChatModel
|
||||
from ..data import Role
|
||||
@ -120,17 +120,17 @@ class WebChatModel(ChatModel):
|
||||
|
||||
def append(
|
||||
self,
|
||||
chatbot: List[List[Optional[str]]],
|
||||
messages: Sequence[Dict[str, str]],
|
||||
chatbot: List[Dict[str, str]],
|
||||
messages: List[Dict[str, str]],
|
||||
role: str,
|
||||
query: str,
|
||||
) -> Tuple[List[List[Optional[str]]], List[Dict[str, str]], str]:
|
||||
return chatbot + [[query, None]], messages + [{"role": role, "content": query}], ""
|
||||
) -> Tuple[List[Dict[str, str]], List[Dict[str, str]], str]:
|
||||
return chatbot + [{"role": "user", "content": query}], messages + [{"role": role, "content": query}], ""
|
||||
|
||||
def stream(
|
||||
self,
|
||||
chatbot: List[List[Optional[str]]],
|
||||
messages: Sequence[Dict[str, str]],
|
||||
chatbot: List[Dict[str, str]],
|
||||
messages: List[Dict[str, str]],
|
||||
system: str,
|
||||
tools: str,
|
||||
image: Optional[Any],
|
||||
@ -138,8 +138,8 @@ class WebChatModel(ChatModel):
|
||||
max_new_tokens: int,
|
||||
top_p: float,
|
||||
temperature: float,
|
||||
) -> Generator[Tuple[List[List[Optional[str]]], List[Dict[str, str]]], None, None]:
|
||||
chatbot[-1][1] = ""
|
||||
) -> Generator[Tuple[List[Dict[str, str]], List[Dict[str, str]]], None, None]:
|
||||
chatbot.append({"role": "assistant", "content": ""})
|
||||
response = ""
|
||||
for new_text in self.stream_chat(
|
||||
messages,
|
||||
@ -166,5 +166,5 @@ class WebChatModel(ChatModel):
|
||||
output_messages = messages + [{"role": Role.ASSISTANT.value, "content": result}]
|
||||
bot_text = result
|
||||
|
||||
chatbot[-1][1] = bot_text
|
||||
chatbot[-1] = {"role": "assistant", "content": bot_text}
|
||||
yield chatbot, output_messages
|
||||
|
@ -33,7 +33,7 @@ def create_chat_box(
|
||||
engine: "Engine", visible: bool = False
|
||||
) -> Tuple["Component", "Component", Dict[str, "Component"]]:
|
||||
with gr.Column(visible=visible) as chat_box:
|
||||
chatbot = gr.Chatbot(show_copy_button=True)
|
||||
chatbot = gr.Chatbot(type="messages", show_copy_button=True)
|
||||
messages = gr.State([])
|
||||
with gr.Row():
|
||||
with gr.Column(scale=4):
|
||||
|
@ -30,11 +30,10 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
def create_top() -> Dict[str, "Component"]:
|
||||
available_models = list(SUPPORTED_MODELS.keys()) + ["Custom"]
|
||||
|
||||
with gr.Row():
|
||||
lang = gr.Dropdown(choices=["en", "ru", "zh", "ko"], scale=1)
|
||||
model_name = gr.Dropdown(choices=available_models, scale=3)
|
||||
lang = gr.Dropdown(choices=["en", "ru", "zh", "ko"], value=None, scale=1)
|
||||
available_models = list(SUPPORTED_MODELS.keys()) + ["Custom"]
|
||||
model_name = gr.Dropdown(choices=available_models, value=None, scale=3)
|
||||
model_path = gr.Textbox(scale=3)
|
||||
|
||||
with gr.Row():
|
||||
|
@ -39,9 +39,8 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||
elem_dict = dict()
|
||||
|
||||
with gr.Row():
|
||||
training_stage = gr.Dropdown(
|
||||
choices=list(TRAINING_STAGES.keys()), value=list(TRAINING_STAGES.keys())[0], scale=1
|
||||
)
|
||||
stages = list(TRAINING_STAGES.keys())
|
||||
training_stage = gr.Dropdown(choices=stages, value=stages[0], scale=1)
|
||||
dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=1)
|
||||
dataset = gr.Dropdown(multiselect=True, allow_custom_value=True, scale=4)
|
||||
preview_elems = create_preview_box(dataset_dir, dataset)
|
||||
@ -107,8 +106,12 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||
use_llama_pro = gr.Checkbox()
|
||||
|
||||
with gr.Column():
|
||||
shift_attn = gr.Checkbox()
|
||||
report_to = gr.Checkbox()
|
||||
report_to = gr.Dropdown(
|
||||
choices=["none", "all", "wandb", "mlflow", "neptune", "tensorboard"],
|
||||
value=["none"],
|
||||
allow_custom_value=True,
|
||||
multiselect=True,
|
||||
)
|
||||
|
||||
input_elems.update(
|
||||
{
|
||||
@ -123,7 +126,6 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||
mask_history,
|
||||
resize_vocab,
|
||||
use_llama_pro,
|
||||
shift_attn,
|
||||
report_to,
|
||||
}
|
||||
)
|
||||
@ -141,7 +143,6 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||
mask_history=mask_history,
|
||||
resize_vocab=resize_vocab,
|
||||
use_llama_pro=use_llama_pro,
|
||||
shift_attn=shift_attn,
|
||||
report_to=report_to,
|
||||
)
|
||||
)
|
||||
|
@ -713,24 +713,6 @@ LOCALES = {
|
||||
"info": "확장된 블록의 매개변수를 학습 가능하게 만듭니다.",
|
||||
},
|
||||
},
|
||||
"shift_attn": {
|
||||
"en": {
|
||||
"label": "Enable S^2 Attention",
|
||||
"info": "Use shift short attention proposed by LongLoRA.",
|
||||
},
|
||||
"ru": {
|
||||
"label": "Включить S^2 внимание",
|
||||
"info": "Использовать сдвиг внимания на короткие дистанции предложенный LongLoRA.",
|
||||
},
|
||||
"zh": {
|
||||
"label": "使用 S^2 Attention",
|
||||
"info": "使用 LongLoRA 提出的 shift short attention。",
|
||||
},
|
||||
"ko": {
|
||||
"label": "S^2 Attention 사용",
|
||||
"info": "LongLoRA에서 제안한 shift short attention을 사용합니다.",
|
||||
},
|
||||
},
|
||||
"report_to": {
|
||||
"en": {
|
||||
"label": "Enable external logger",
|
||||
|
@ -144,8 +144,7 @@ class Runner:
|
||||
mask_history=get("train.mask_history"),
|
||||
resize_vocab=get("train.resize_vocab"),
|
||||
use_llama_pro=get("train.use_llama_pro"),
|
||||
shift_attn=get("train.shift_attn"),
|
||||
report_to="all" if get("train.report_to") else "none",
|
||||
report_to=get("train.report_to"),
|
||||
use_galore=get("train.use_galore"),
|
||||
use_apollo=get("train.use_apollo"),
|
||||
use_badam=get("train.use_badam"),
|
||||
@ -239,6 +238,12 @@ class Runner:
|
||||
args["badam_switch_interval"] = get("train.badam_switch_interval")
|
||||
args["badam_update_ratio"] = get("train.badam_update_ratio")
|
||||
|
||||
# report_to
|
||||
if "none" in args["report_to"]:
|
||||
args["report_to"] = "none"
|
||||
elif "all" in args["report_to"]:
|
||||
args["report_to"] = "all"
|
||||
|
||||
# swanlab config
|
||||
if get("train.use_swanlab"):
|
||||
args["swanlab_project"] = get("train.swanlab_project")
|
||||
|
@ -111,7 +111,12 @@ def gen_cmd(args: Dict[str, Any]) -> str:
|
||||
"""
|
||||
cmd_lines = ["llamafactory-cli train "]
|
||||
for k, v in clean_cmd(args).items():
|
||||
cmd_lines.append(f" --{k} {str(v)} ")
|
||||
if isinstance(v, dict):
|
||||
cmd_lines.append(f" --{k} {json.dumps(v, ensure_ascii=False)} ")
|
||||
elif isinstance(v, list):
|
||||
cmd_lines.append(f" --{k} {' '.join(map(str, v))} ")
|
||||
else:
|
||||
cmd_lines.append(f" --{k} {str(v)} ")
|
||||
|
||||
if os.name == "nt":
|
||||
cmd_text = "`\n".join(cmd_lines)
|
||||
|
Loading…
x
Reference in New Issue
Block a user