From 31daa6570b093004b8f7525fd9f16503185d83df Mon Sep 17 00:00:00 2001 From: hoshi-hiyouga Date: Fri, 17 Jan 2025 20:15:42 +0800 Subject: [PATCH] [webui] upgrade to gradio 5 (#6688) Former-commit-id: 9df7721264ddef0008d7648e6ed173adef99bd74 --- requirements.txt | 2 +- src/llamafactory/webui/chatter.py | 20 ++++++++++---------- src/llamafactory/webui/components/chatbot.py | 2 +- src/llamafactory/webui/components/top.py | 7 +++---- src/llamafactory/webui/components/train.py | 15 ++++++++------- src/llamafactory/webui/locales.py | 18 ------------------ src/llamafactory/webui/runner.py | 9 +++++++-- src/llamafactory/webui/utils.py | 7 ++++++- 8 files changed, 36 insertions(+), 44 deletions(-) diff --git a/requirements.txt b/requirements.txt index b47643e3..04d395d6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,7 +4,7 @@ accelerate>=0.34.0,<=1.0.1 peft>=0.11.1,<=0.12.0 trl>=0.8.6,<=0.9.6 tokenizers>=0.19.0,<0.20.4 -gradio>=4.0.0,<5.0.0 +gradio>=4.0.0,<6.0.0 pandas>=2.0.0 scipy einops diff --git a/src/llamafactory/webui/chatter.py b/src/llamafactory/webui/chatter.py index 9d0e4b20..7b360be6 100644 --- a/src/llamafactory/webui/chatter.py +++ b/src/llamafactory/webui/chatter.py @@ -14,7 +14,7 @@ import json import os -from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Sequence, Tuple +from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple from ..chat import ChatModel from ..data import Role @@ -120,17 +120,17 @@ class WebChatModel(ChatModel): def append( self, - chatbot: List[List[Optional[str]]], - messages: Sequence[Dict[str, str]], + chatbot: List[Dict[str, str]], + messages: List[Dict[str, str]], role: str, query: str, - ) -> Tuple[List[List[Optional[str]]], List[Dict[str, str]], str]: - return chatbot + [[query, None]], messages + [{"role": role, "content": query}], "" + ) -> Tuple[List[Dict[str, str]], List[Dict[str, str]], str]: + return chatbot + [{"role": "user", "content": query}], messages + [{"role": role, "content": query}], "" def stream( self, - chatbot: List[List[Optional[str]]], - messages: Sequence[Dict[str, str]], + chatbot: List[Dict[str, str]], + messages: List[Dict[str, str]], system: str, tools: str, image: Optional[Any], @@ -138,8 +138,8 @@ class WebChatModel(ChatModel): max_new_tokens: int, top_p: float, temperature: float, - ) -> Generator[Tuple[List[List[Optional[str]]], List[Dict[str, str]]], None, None]: - chatbot[-1][1] = "" + ) -> Generator[Tuple[List[Dict[str, str]], List[Dict[str, str]]], None, None]: + chatbot.append({"role": "assistant", "content": ""}) response = "" for new_text in self.stream_chat( messages, @@ -166,5 +166,5 @@ class WebChatModel(ChatModel): output_messages = messages + [{"role": Role.ASSISTANT.value, "content": result}] bot_text = result - chatbot[-1][1] = bot_text + chatbot[-1] = {"role": "assistant", "content": bot_text} yield chatbot, output_messages diff --git a/src/llamafactory/webui/components/chatbot.py b/src/llamafactory/webui/components/chatbot.py index b71e8bb3..53e41b93 100644 --- a/src/llamafactory/webui/components/chatbot.py +++ b/src/llamafactory/webui/components/chatbot.py @@ -33,7 +33,7 @@ def create_chat_box( engine: "Engine", visible: bool = False ) -> Tuple["Component", "Component", Dict[str, "Component"]]: with gr.Column(visible=visible) as chat_box: - chatbot = gr.Chatbot(show_copy_button=True) + chatbot = gr.Chatbot(type="messages", show_copy_button=True) messages = gr.State([]) with gr.Row(): with gr.Column(scale=4): diff --git a/src/llamafactory/webui/components/top.py b/src/llamafactory/webui/components/top.py index bec6c507..528ee908 100644 --- a/src/llamafactory/webui/components/top.py +++ b/src/llamafactory/webui/components/top.py @@ -30,11 +30,10 @@ if TYPE_CHECKING: def create_top() -> Dict[str, "Component"]: - available_models = list(SUPPORTED_MODELS.keys()) + ["Custom"] - with gr.Row(): - lang = gr.Dropdown(choices=["en", "ru", "zh", "ko"], scale=1) - model_name = gr.Dropdown(choices=available_models, scale=3) + lang = gr.Dropdown(choices=["en", "ru", "zh", "ko"], value=None, scale=1) + available_models = list(SUPPORTED_MODELS.keys()) + ["Custom"] + model_name = gr.Dropdown(choices=available_models, value=None, scale=3) model_path = gr.Textbox(scale=3) with gr.Row(): diff --git a/src/llamafactory/webui/components/train.py b/src/llamafactory/webui/components/train.py index ae3c416c..b4c0bb2a 100644 --- a/src/llamafactory/webui/components/train.py +++ b/src/llamafactory/webui/components/train.py @@ -39,9 +39,8 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]: elem_dict = dict() with gr.Row(): - training_stage = gr.Dropdown( - choices=list(TRAINING_STAGES.keys()), value=list(TRAINING_STAGES.keys())[0], scale=1 - ) + stages = list(TRAINING_STAGES.keys()) + training_stage = gr.Dropdown(choices=stages, value=stages[0], scale=1) dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=1) dataset = gr.Dropdown(multiselect=True, allow_custom_value=True, scale=4) preview_elems = create_preview_box(dataset_dir, dataset) @@ -107,8 +106,12 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]: use_llama_pro = gr.Checkbox() with gr.Column(): - shift_attn = gr.Checkbox() - report_to = gr.Checkbox() + report_to = gr.Dropdown( + choices=["none", "all", "wandb", "mlflow", "neptune", "tensorboard"], + value=["none"], + allow_custom_value=True, + multiselect=True, + ) input_elems.update( { @@ -123,7 +126,6 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]: mask_history, resize_vocab, use_llama_pro, - shift_attn, report_to, } ) @@ -141,7 +143,6 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]: mask_history=mask_history, resize_vocab=resize_vocab, use_llama_pro=use_llama_pro, - shift_attn=shift_attn, report_to=report_to, ) ) diff --git a/src/llamafactory/webui/locales.py b/src/llamafactory/webui/locales.py index b8940d9d..154f1794 100644 --- a/src/llamafactory/webui/locales.py +++ b/src/llamafactory/webui/locales.py @@ -713,24 +713,6 @@ LOCALES = { "info": "확장된 블록의 매개변수를 학습 가능하게 만듭니다.", }, }, - "shift_attn": { - "en": { - "label": "Enable S^2 Attention", - "info": "Use shift short attention proposed by LongLoRA.", - }, - "ru": { - "label": "Включить S^2 внимание", - "info": "Использовать сдвиг внимания на короткие дистанции предложенный LongLoRA.", - }, - "zh": { - "label": "使用 S^2 Attention", - "info": "使用 LongLoRA 提出的 shift short attention。", - }, - "ko": { - "label": "S^2 Attention 사용", - "info": "LongLoRA에서 제안한 shift short attention을 사용합니다.", - }, - }, "report_to": { "en": { "label": "Enable external logger", diff --git a/src/llamafactory/webui/runner.py b/src/llamafactory/webui/runner.py index c397416d..d2195ea4 100644 --- a/src/llamafactory/webui/runner.py +++ b/src/llamafactory/webui/runner.py @@ -144,8 +144,7 @@ class Runner: mask_history=get("train.mask_history"), resize_vocab=get("train.resize_vocab"), use_llama_pro=get("train.use_llama_pro"), - shift_attn=get("train.shift_attn"), - report_to="all" if get("train.report_to") else "none", + report_to=get("train.report_to"), use_galore=get("train.use_galore"), use_apollo=get("train.use_apollo"), use_badam=get("train.use_badam"), @@ -239,6 +238,12 @@ class Runner: args["badam_switch_interval"] = get("train.badam_switch_interval") args["badam_update_ratio"] = get("train.badam_update_ratio") + # report_to + if "none" in args["report_to"]: + args["report_to"] = "none" + elif "all" in args["report_to"]: + args["report_to"] = "all" + # swanlab config if get("train.use_swanlab"): args["swanlab_project"] = get("train.swanlab_project") diff --git a/src/llamafactory/webui/utils.py b/src/llamafactory/webui/utils.py index b014e192..e7b6aa01 100644 --- a/src/llamafactory/webui/utils.py +++ b/src/llamafactory/webui/utils.py @@ -111,7 +111,12 @@ def gen_cmd(args: Dict[str, Any]) -> str: """ cmd_lines = ["llamafactory-cli train "] for k, v in clean_cmd(args).items(): - cmd_lines.append(f" --{k} {str(v)} ") + if isinstance(v, dict): + cmd_lines.append(f" --{k} {json.dumps(v, ensure_ascii=False)} ") + elif isinstance(v, list): + cmd_lines.append(f" --{k} {' '.join(map(str, v))} ") + else: + cmd_lines.append(f" --{k} {str(v)} ") if os.name == "nt": cmd_text = "`\n".join(cmd_lines)