mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-10-15 08:08:09 +08:00)

[webui] upgrade to gradio 5 (#6688)

Former-commit-id: 9df7721264ddef0008d7648e6ed173adef99bd74
commit 31daa6570b, parent 33525a34b6
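Summary of the changes below: the gradio requirement is widened from <5.0.0 to <6.0.0; the WebUI chat history moves from Gradio's tuple pairs to openai-style role/content message dicts; the shift_attn checkbox and its locale strings are removed from the train tab; report_to becomes a multi-select dropdown whose list value is normalized in the Runner; and gen_cmd learns to serialize list and dict argument values.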
@@ -4,7 +4,7 @@ accelerate>=0.34.0,<=1.0.1
 peft>=0.11.1,<=0.12.0
 trl>=0.8.6,<=0.9.6
 tokenizers>=0.19.0,<0.20.4
-gradio>=4.0.0,<5.0.0
+gradio>=4.0.0,<6.0.0
 pandas>=2.0.0
 scipy
 einops
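Note: raising the upper bound to <6.0.0 admits the Gradio 5 line while still allowing 4.x installs. A minimal sketch of gating behavior on the installed major version, assuming gradio is installed and using only the standard library (the gradio_major helper is hypothetical, not part of this commit):

    # Sketch: branch on the installed Gradio major version.
    from importlib.metadata import version

    def gradio_major() -> int:
        # hypothetical helper, not part of this commit
        return int(version("gradio").split(".")[0])

    if gradio_major() >= 5:
        print("Gradio 5: use openai-style message dicts in gr.Chatbot")
    else:
        print("Gradio 4: tuple-pair chat history still accepted")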
@@ -14,7 +14,7 @@

 import json
 import os
-from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Sequence, Tuple
+from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple

 from ..chat import ChatModel
 from ..data import Role
@@ -120,17 +120,17 @@ class WebChatModel(ChatModel):

     def append(
         self,
-        chatbot: List[List[Optional[str]]],
-        messages: Sequence[Dict[str, str]],
+        chatbot: List[Dict[str, str]],
+        messages: List[Dict[str, str]],
         role: str,
         query: str,
-    ) -> Tuple[List[List[Optional[str]]], List[Dict[str, str]], str]:
-        return chatbot + [[query, None]], messages + [{"role": role, "content": query}], ""
+    ) -> Tuple[List[Dict[str, str]], List[Dict[str, str]], str]:
+        return chatbot + [{"role": "user", "content": query}], messages + [{"role": role, "content": query}], ""

     def stream(
         self,
-        chatbot: List[List[Optional[str]]],
-        messages: Sequence[Dict[str, str]],
+        chatbot: List[Dict[str, str]],
+        messages: List[Dict[str, str]],
         system: str,
         tools: str,
         image: Optional[Any],
@@ -138,8 +138,8 @@ class WebChatModel(ChatModel):
         max_new_tokens: int,
         top_p: float,
         temperature: float,
-    ) -> Generator[Tuple[List[List[Optional[str]]], List[Dict[str, str]]], None, None]:
-        chatbot[-1][1] = ""
+    ) -> Generator[Tuple[List[Dict[str, str]], List[Dict[str, str]]], None, None]:
+        chatbot.append({"role": "assistant", "content": ""})
         response = ""
         for new_text in self.stream_chat(
             messages,
@@ -166,5 +166,5 @@ class WebChatModel(ChatModel):
             output_messages = messages + [{"role": Role.ASSISTANT.value, "content": result}]
             bot_text = result

-        chatbot[-1][1] = bot_text
+        chatbot[-1] = {"role": "assistant", "content": bot_text}
         yield chatbot, output_messages
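Note: the signature changes above move the chat history from Gradio's tuple format, where each entry is a [user_text, bot_text] pair and a pending reply is None, to openai-style role/content dicts, where each turn is its own entry. A standalone before/after sketch (illustrative values, not code from this commit):

    # Gradio 4 "tuples" history: one entry per exchange.
    history_tuples = [["What is LoRA?", None]]      # None marks a pending reply
    history_tuples[-1][1] = "A low-rank adaptation method."

    # Gradio 5 "messages" history: one entry per turn.
    history_messages = [{"role": "user", "content": "What is LoRA?"}]
    history_messages.append({"role": "assistant", "content": ""})  # placeholder while streaming
    history_messages[-1] = {"role": "assistant", "content": "A low-rank adaptation method."}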
@@ -33,7 +33,7 @@ def create_chat_box(
     engine: "Engine", visible: bool = False
 ) -> Tuple["Component", "Component", Dict[str, "Component"]]:
     with gr.Column(visible=visible) as chat_box:
-        chatbot = gr.Chatbot(show_copy_button=True)
+        chatbot = gr.Chatbot(type="messages", show_copy_button=True)
         messages = gr.State([])
         with gr.Row():
             with gr.Column(scale=4):
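Note: type="messages" makes gr.Chatbot consume the role/content dicts that WebChatModel.append and stream now produce. A minimal self-contained sketch of the same wiring, assuming gradio>=5 (the echo handler is illustrative only, not from this commit):

    import gradio as gr

    def echo(message, history):
        # history arrives as a list of {"role": ..., "content": ...} dicts
        history = history + [
            {"role": "user", "content": message},
            {"role": "assistant", "content": f"Echo: {message}"},
        ]
        return history, ""

    with gr.Blocks() as demo:
        chatbot = gr.Chatbot(type="messages", show_copy_button=True)
        textbox = gr.Textbox()
        textbox.submit(echo, [textbox, chatbot], [chatbot, textbox])

    # demo.launch()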
@@ -30,11 +30,10 @@ if TYPE_CHECKING:


 def create_top() -> Dict[str, "Component"]:
-    available_models = list(SUPPORTED_MODELS.keys()) + ["Custom"]
-
     with gr.Row():
-        lang = gr.Dropdown(choices=["en", "ru", "zh", "ko"], scale=1)
-        model_name = gr.Dropdown(choices=available_models, scale=3)
+        lang = gr.Dropdown(choices=["en", "ru", "zh", "ko"], value=None, scale=1)
+        available_models = list(SUPPORTED_MODELS.keys()) + ["Custom"]
+        model_name = gr.Dropdown(choices=available_models, value=None, scale=3)
         model_path = gr.Textbox(scale=3)

     with gr.Row():
@@ -39,9 +39,8 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
     elem_dict = dict()

     with gr.Row():
-        training_stage = gr.Dropdown(
-            choices=list(TRAINING_STAGES.keys()), value=list(TRAINING_STAGES.keys())[0], scale=1
-        )
+        stages = list(TRAINING_STAGES.keys())
+        training_stage = gr.Dropdown(choices=stages, value=stages[0], scale=1)
         dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=1)
         dataset = gr.Dropdown(multiselect=True, allow_custom_value=True, scale=4)
         preview_elems = create_preview_box(dataset_dir, dataset)
@@ -107,8 +106,12 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
             use_llama_pro = gr.Checkbox()

         with gr.Column():
-            shift_attn = gr.Checkbox()
-            report_to = gr.Checkbox()
+            report_to = gr.Dropdown(
+                choices=["none", "all", "wandb", "mlflow", "neptune", "tensorboard"],
+                value=["none"],
+                allow_custom_value=True,
+                multiselect=True,
+            )

     input_elems.update(
         {
@@ -123,7 +126,6 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
             mask_history,
             resize_vocab,
             use_llama_pro,
-            shift_attn,
             report_to,
         }
     )
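Note: report_to changes from a boolean checkbox to a multi-select dropdown whose value is a list of logger names, mirroring how transformers' TrainingArguments.report_to accepts "none", "all", a single integration, or a list. A hedged sketch of that downstream API (assuming a recent transformers release; output_dir is an illustrative path):

    from transformers import TrainingArguments

    args = TrainingArguments(
        output_dir="/tmp/demo-run",              # illustrative path
        report_to=["wandb", "tensorboard"],      # list form, as the WebUI now sends
    )
    print(args.report_to)                        # ['wandb', 'tensorboard']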
@@ -141,7 +143,6 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
             mask_history=mask_history,
             resize_vocab=resize_vocab,
             use_llama_pro=use_llama_pro,
-            shift_attn=shift_attn,
             report_to=report_to,
         )
     )
@@ -713,24 +713,6 @@ LOCALES = {
             "info": "확장된 블록의 매개변수를 학습 가능하게 만듭니다.",
         },
     },
-    "shift_attn": {
-        "en": {
-            "label": "Enable S^2 Attention",
-            "info": "Use shift short attention proposed by LongLoRA.",
-        },
-        "ru": {
-            "label": "Включить S^2 внимание",
-            "info": "Использовать сдвиг внимания на короткие дистанции предложенный LongLoRA.",
-        },
-        "zh": {
-            "label": "使用 S^2 Attention",
-            "info": "使用 LongLoRA 提出的 shift short attention。",
-        },
-        "ko": {
-            "label": "S^2 Attention 사용",
-            "info": "LongLoRA에서 제안한 shift short attention을 사용합니다.",
-        },
-    },
     "report_to": {
         "en": {
             "label": "Enable external logger",
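Note: each LOCALES entry maps a component name to per-language label/info strings, so deleting the "shift_attn" block here is the counterpart of removing the checkbox from the train tab. A small sketch of the lookup pattern implied by that layout (the localize helper is hypothetical; only the "en" label below appears in this diff):

    LOCALES = {
        "report_to": {
            "en": {"label": "Enable external logger"},
        },
    }

    def localize(elem_name: str, lang: str) -> dict:
        # hypothetical helper: fetch the strings for one component in one language
        return LOCALES.get(elem_name, {}).get(lang, {})

    print(localize("report_to", "en")["label"])  # Enable external logger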
@@ -144,8 +144,7 @@ class Runner:
             mask_history=get("train.mask_history"),
             resize_vocab=get("train.resize_vocab"),
             use_llama_pro=get("train.use_llama_pro"),
-            shift_attn=get("train.shift_attn"),
-            report_to="all" if get("train.report_to") else "none",
+            report_to=get("train.report_to"),
             use_galore=get("train.use_galore"),
             use_apollo=get("train.use_apollo"),
             use_badam=get("train.use_badam"),
@@ -239,6 +238,12 @@ class Runner:
         args["badam_switch_interval"] = get("train.badam_switch_interval")
         args["badam_update_ratio"] = get("train.badam_update_ratio")

+        # report_to
+        if "none" in args["report_to"]:
+            args["report_to"] = "none"
+        elif "all" in args["report_to"]:
+            args["report_to"] = "all"
+
         # swanlab config
         if get("train.use_swanlab"):
             args["swanlab_project"] = get("train.swanlab_project")
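Note: since the dropdown now yields a list, the Runner collapses the special entries before handing the value to training: "none" disables everything and takes precedence, "all" enables every integration, and any other selection passes through unchanged. The same logic as a standalone function (names mirror the diff; the demo values are illustrative):

    from typing import List, Union

    def normalize_report_to(selected: List[str]) -> Union[str, List[str]]:
        # "none" wins, then "all"; otherwise keep the explicit list.
        if "none" in selected:
            return "none"
        elif "all" in selected:
            return "all"
        return selected

    print(normalize_report_to(["none", "wandb"]))    # none
    print(normalize_report_to(["all", "mlflow"]))    # all
    print(normalize_report_to(["wandb", "mlflow"]))  # ['wandb', 'mlflow']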
@@ -111,7 +111,12 @@ def gen_cmd(args: Dict[str, Any]) -> str:
     """
     cmd_lines = ["llamafactory-cli train "]
     for k, v in clean_cmd(args).items():
-        cmd_lines.append(f" --{k} {str(v)} ")
+        if isinstance(v, dict):
+            cmd_lines.append(f" --{k} {json.dumps(v, ensure_ascii=False)} ")
+        elif isinstance(v, list):
+            cmd_lines.append(f" --{k} {' '.join(map(str, v))} ")
+        else:
+            cmd_lines.append(f" --{k} {str(v)} ")

     if os.name == "nt":
         cmd_text = "`\n".join(cmd_lines)
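Note: the old single f-string rendered a list as its Python repr (e.g. --report_to ['wandb', 'tensorboard']), which the CLI cannot parse; the new branches emit space-separated items for lists and JSON for dicts. A standalone check of the three branches (clean_cmd and the surrounding gen_cmd are elided; the argument values are illustrative):

    import json

    args = {
        "report_to": ["wandb", "tensorboard"],   # list branch
        "swanlab_config": {"project": "demo"},   # dict branch (illustrative key)
        "lora_rank": 8,                          # scalar branch
    }

    cmd_lines = ["llamafactory-cli train "]
    for k, v in args.items():
        if isinstance(v, dict):
            cmd_lines.append(f" --{k} {json.dumps(v, ensure_ascii=False)} ")
        elif isinstance(v, list):
            cmd_lines.append(f" --{k} {' '.join(map(str, v))} ")
        else:
            cmd_lines.append(f" --{k} {str(v)} ")

    print("".join(cmd_lines))
    # roughly: llamafactory-cli train --report_to wandb tensorboard --swanlab_config {"project": "demo"} --lora_rank 8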