diff --git a/src/llmtuner/tuner/tune.py b/src/llmtuner/tuner/tune.py index f0917f37..4eb7f78f 100644 --- a/src/llmtuner/tuner/tune.py +++ b/src/llmtuner/tuner/tune.py @@ -38,10 +38,10 @@ def export_model(args: Optional[Dict[str, Any]] = None, max_shard_size: Optional model_args, _, finetuning_args, _ = get_infer_args(args) model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args) model.config.use_cache = True - tokenizer.padding_side = "left" # restore padding side - tokenizer.init_kwargs["padding_side"] = "left" model.save_pretrained(model_args.export_dir, max_shard_size=max_shard_size) try: + tokenizer.padding_side = "left" # restore padding side + tokenizer.init_kwargs["padding_side"] = "left" tokenizer.save_pretrained(model_args.export_dir) except: logger.warning("Cannot save tokenizer, please copy the files manually.") diff --git a/src/llmtuner/webui/chatter.py b/src/llmtuner/webui/chatter.py index 40f04d18..57eadb01 100644 --- a/src/llmtuner/webui/chatter.py +++ b/src/llmtuner/webui/chatter.py @@ -1,3 +1,4 @@ +import gradio as gr from gradio.components import Component # cannot use TYPE_CHECKING here from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple @@ -28,16 +29,17 @@ class WebChatModel(ChatModel): def load_model(self, data: Dict[Component, Any]) -> Generator[str, None, None]: get = lambda name: data[self.manager.get_elem_by_name(name)] lang = get("top.lang") + error = "" if self.loaded: - yield ALERTS["err_exists"][lang] - return + error = ALERTS["err_exists"][lang] + elif not get("top.model_name"): + error = ALERTS["err_no_model"][lang] + elif not get("top.model_path"): + error = ALERTS["err_no_path"][lang] - if not get("top.model_name"): - yield ALERTS["err_no_model"][lang] - return - - if not get("top.model_path"): - yield ALERTS["err_no_path"][lang] + if error: + gr.Warning(error) + yield error return if get("top.checkpoints"): diff --git a/src/llmtuner/webui/components/chatbot.py b/src/llmtuner/webui/components/chatbot.py index 57f14d4a..13e2dd4d 100644 --- a/src/llmtuner/webui/components/chatbot.py +++ b/src/llmtuner/webui/components/chatbot.py @@ -11,11 +11,9 @@ def create_chat_box( engine: "Engine", visible: Optional[bool] = False ) -> Tuple["Block", "Component", "Component", Dict[str, "Component"]]: - elem_dict = dict() - with gr.Box(visible=visible) as chat_box: chatbot = gr.Chatbot() - + history = gr.State([]) with gr.Row(): with gr.Column(scale=4): system = gr.Textbox(show_label=False) @@ -29,13 +27,6 @@ def create_chat_box( top_p = gr.Slider(0.01, 1, value=gen_kwargs.top_p, step=0.01) temperature = gr.Slider(0.01, 1.5, value=gen_kwargs.temperature, step=0.01) - elem_dict.update(dict( - system=system, query=query, submit_btn=submit_btn, clear_btn=clear_btn, - max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature - )) - - history = gr.State([]) - submit_btn.click( engine.chatter.predict, [chatbot, query, history, system, max_new_tokens, top_p, temperature], @@ -47,4 +38,12 @@ def create_chat_box( clear_btn.click(lambda: ([], []), outputs=[chatbot, history], show_progress=True) - return chat_box, chatbot, history, elem_dict + return chat_box, chatbot, history, dict( + system=system, + query=query, + submit_btn=submit_btn, + clear_btn=clear_btn, + max_new_tokens=max_new_tokens, + top_p=top_p, + temperature=temperature + ) diff --git a/src/llmtuner/webui/components/data.py b/src/llmtuner/webui/components/data.py index 745f7648..effa39da 100644 --- a/src/llmtuner/webui/components/data.py +++ b/src/llmtuner/webui/components/data.py @@ -1,17 +1,103 @@ +import os +import json import gradio as gr -from typing import TYPE_CHECKING, Tuple +from typing import TYPE_CHECKING, Any, Dict, Tuple + +from llmtuner.webui.common import DATA_CONFIG if TYPE_CHECKING: - from gradio.blocks import Block from gradio.components import Component -def create_preview_box() -> Tuple["Block", "Component", "Component", "Component"]: +PAGE_SIZE = 2 + + +def prev_page(page_index: int) -> int: + return page_index - 1 if page_index > 0 else page_index + + +def next_page(page_index: int, total_num: int) -> int: + return page_index + 1 if (page_index + 1) * PAGE_SIZE < total_num else page_index + + +def can_preview(dataset_dir: str, dataset: list) -> Dict[str, Any]: + with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f: + dataset_info = json.load(f) + + if ( + len(dataset) > 0 + and "file_name" in dataset_info[dataset[0]] + and os.path.isfile(os.path.join(dataset_dir, dataset_info[dataset[0]]["file_name"])) + ): + return gr.update(interactive=True) + else: + return gr.update(interactive=False) + + +def get_preview(dataset_dir: str, dataset: list, page_index: int) -> Tuple[int, list, Dict[str, Any]]: + with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f: + dataset_info = json.load(f) + + data_file: str = dataset_info[dataset[0]]["file_name"] + with open(os.path.join(dataset_dir, data_file), "r", encoding="utf-8") as f: + if data_file.endswith(".json"): + data = json.load(f) + elif data_file.endswith(".jsonl"): + data = [json.loads(line) for line in f] + else: + data = [line for line in f] + return len(data), data[PAGE_SIZE * page_index : PAGE_SIZE * (page_index + 1)], gr.update(visible=True) + + +def create_preview_box(dataset_dir: "gr.Textbox", dataset: "gr.Dropdown") -> Dict[str, "Component"]: + data_preview_btn = gr.Button(interactive=False, scale=1) with gr.Column(visible=False, elem_classes="modal-box") as preview_box: - preview_count = gr.Number(interactive=False) - preview_samples = gr.JSON(interactive=False) - close_btn = gr.Button() + with gr.Row(): + preview_count = gr.Number(value=0, interactive=False, precision=0) + page_index = gr.Number(value=0, interactive=False, precision=0) + with gr.Row(): + prev_btn = gr.Button() + next_btn = gr.Button() + close_btn = gr.Button() + + with gr.Row(): + preview_samples = gr.JSON(interactive=False) + + dataset.change( + can_preview, [dataset_dir, dataset], [data_preview_btn], queue=False + ).then( + lambda: 0, outputs=[page_index], queue=False + ) + data_preview_btn.click( + get_preview, + [dataset_dir, dataset, page_index], + [preview_count, preview_samples, preview_box], + queue=False + ) + prev_btn.click( + prev_page, [page_index], [page_index], queue=False + ).then( + get_preview, + [dataset_dir, dataset, page_index], + [preview_count, preview_samples, preview_box], + queue=False + ) + next_btn.click( + next_page, [page_index, preview_count], [page_index], queue=False + ).then( + get_preview, + [dataset_dir, dataset, page_index], + [preview_count, preview_samples, preview_box], + queue=False + ) close_btn.click(lambda: gr.update(visible=False), outputs=[preview_box], queue=False) - - return preview_box, preview_count, preview_samples, close_btn + return dict( + data_preview_btn=data_preview_btn, + preview_count=preview_count, + page_index=page_index, + prev_btn=prev_btn, + next_btn=next_btn, + close_btn=close_btn, + preview_samples=preview_samples + ) diff --git a/src/llmtuner/webui/components/eval.py b/src/llmtuner/webui/components/eval.py index c8d1530e..36c994a6 100644 --- a/src/llmtuner/webui/components/eval.py +++ b/src/llmtuner/webui/components/eval.py @@ -3,7 +3,6 @@ from typing import TYPE_CHECKING, Dict from llmtuner.webui.common import list_dataset, DEFAULT_DATA_DIR from llmtuner.webui.components.data import create_preview_box -from llmtuner.webui.utils import can_preview, get_preview if TYPE_CHECKING: from gradio.components import Component @@ -17,28 +16,12 @@ def create_eval_tab(engine: "Engine") -> Dict[str, "Component"]: with gr.Row(): dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=2) dataset = gr.Dropdown(multiselect=True, scale=4) - data_preview_btn = gr.Button(interactive=False, scale=1) + preview_elems = create_preview_box(dataset_dir, dataset) dataset_dir.change(list_dataset, [dataset_dir], [dataset], queue=False) - dataset.change(can_preview, [dataset_dir, dataset], [data_preview_btn], queue=False) input_elems.update({dataset_dir, dataset}) - elem_dict.update(dict( - dataset_dir=dataset_dir, dataset=dataset, data_preview_btn=data_preview_btn - )) - - preview_box, preview_count, preview_samples, close_btn = create_preview_box() - - data_preview_btn.click( - get_preview, - [dataset_dir, dataset], - [preview_count, preview_samples, preview_box], - queue=False - ) - - elem_dict.update(dict( - preview_count=preview_count, preview_samples=preview_samples, close_btn=close_btn - )) + elem_dict.update(dict(dataset_dir=dataset_dir, dataset=dataset, **preview_elems)) with gr.Row(): cutoff_len = gr.Slider(value=1024, minimum=4, maximum=8192, step=1) diff --git a/src/llmtuner/webui/components/export.py b/src/llmtuner/webui/components/export.py index 75493d4a..d16fa3d1 100644 --- a/src/llmtuner/webui/components/export.py +++ b/src/llmtuner/webui/components/export.py @@ -1,16 +1,54 @@ import gradio as gr -from typing import TYPE_CHECKING, Dict +from typing import TYPE_CHECKING, Dict, Generator, List -from llmtuner.webui.utils import save_model +from llmtuner.tuner import export_model +from llmtuner.webui.common import get_save_dir +from llmtuner.webui.locales import ALERTS if TYPE_CHECKING: from gradio.components import Component from llmtuner.webui.engine import Engine -def create_export_tab(engine: "Engine") -> Dict[str, "Component"]: - elem_dict = dict() +def save_model( + lang: str, + model_name: str, + model_path: str, + checkpoints: List[str], + finetuning_type: str, + template: str, + max_shard_size: int, + export_dir: str +) -> Generator[str, None, None]: + error = "" + if not model_name: + error = ALERTS["err_no_model"][lang] + elif not model_path: + error = ALERTS["err_no_path"][lang] + elif not checkpoints: + error = ALERTS["err_no_checkpoint"][lang] + elif not export_dir: + error = ALERTS["err_no_export_dir"][lang] + if error: + gr.Warning(error) + yield error + return + + args = dict( + model_name_or_path=model_path, + checkpoint_dir=",".join([get_save_dir(model_name, finetuning_type, ckpt) for ckpt in checkpoints]), + finetuning_type=finetuning_type, + template=template, + export_dir=export_dir + ) + + yield ALERTS["info_exporting"][lang] + export_model(args, max_shard_size="{}GB".format(max_shard_size)) + yield ALERTS["info_exported"][lang] + + +def create_export_tab(engine: "Engine") -> Dict[str, "Component"]: with gr.Row(): export_dir = gr.Textbox() max_shard_size = gr.Slider(value=10, minimum=1, maximum=100) @@ -33,11 +71,9 @@ def create_export_tab(engine: "Engine") -> Dict[str, "Component"]: [info_box] ) - elem_dict.update(dict( + return dict( export_dir=export_dir, max_shard_size=max_shard_size, export_btn=export_btn, info_box=info_box - )) - - return elem_dict + ) diff --git a/src/llmtuner/webui/components/train.py b/src/llmtuner/webui/components/train.py index 5c45268d..11109c97 100644 --- a/src/llmtuner/webui/components/train.py +++ b/src/llmtuner/webui/components/train.py @@ -5,7 +5,7 @@ from transformers.trainer_utils import SchedulerType from llmtuner.extras.constants import TRAINING_STAGES from llmtuner.webui.common import list_checkpoint, list_dataset, DEFAULT_DATA_DIR from llmtuner.webui.components.data import create_preview_box -from llmtuner.webui.utils import can_preview, get_preview, gen_plot +from llmtuner.webui.utils import gen_plot if TYPE_CHECKING: from gradio.components import Component @@ -22,28 +22,14 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]: ) dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=2) dataset = gr.Dropdown(multiselect=True, scale=4) - data_preview_btn = gr.Button(interactive=False, scale=1) + preview_elems = create_preview_box(dataset_dir, dataset) training_stage.change(list_dataset, [dataset_dir, training_stage], [dataset], queue=False) dataset_dir.change(list_dataset, [dataset_dir, training_stage], [dataset], queue=False) - dataset.change(can_preview, [dataset_dir, dataset], [data_preview_btn], queue=False) input_elems.update({training_stage, dataset_dir, dataset}) elem_dict.update(dict( - training_stage=training_stage, dataset_dir=dataset_dir, dataset=dataset, data_preview_btn=data_preview_btn - )) - - preview_box, preview_count, preview_samples, close_btn = create_preview_box() - - data_preview_btn.click( - get_preview, - [dataset_dir, dataset], - [preview_count, preview_samples, preview_box], - queue=False - ) - - elem_dict.update(dict( - preview_count=preview_count, preview_samples=preview_samples, close_btn=close_btn + training_stage=training_stage, dataset_dir=dataset_dir, dataset=dataset, **preview_elems )) with gr.Row(): @@ -143,16 +129,17 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]: input_elems.add(output_dir) output_elems = [output_box, process_bar] - elem_dict.update(dict( - cmd_preview_btn=cmd_preview_btn, start_btn=start_btn, stop_btn=stop_btn, output_dir=output_dir, - resume_btn=resume_btn, process_bar=process_bar, output_box=output_box, loss_viewer=loss_viewer - )) cmd_preview_btn.click(engine.runner.preview_train, input_elems, output_elems) start_btn.click(engine.runner.run_train, input_elems, output_elems) stop_btn.click(engine.runner.set_abort, queue=False) resume_btn.change(engine.runner.monitor, outputs=output_elems) + elem_dict.update(dict( + cmd_preview_btn=cmd_preview_btn, start_btn=start_btn, stop_btn=stop_btn, output_dir=output_dir, + resume_btn=resume_btn, process_bar=process_bar, output_box=output_box, loss_viewer=loss_viewer + )) + output_box.change( gen_plot, [ diff --git a/src/llmtuner/webui/css.py b/src/llmtuner/webui/css.py index 6dab6ffa..c86fb96b 100644 --- a/src/llmtuner/webui/css.py +++ b/src/llmtuner/webui/css.py @@ -6,7 +6,9 @@ CSS = r""" transform: translate(-50%, -50%); /* center horizontally */ max-width: 1000px; max-height: 750px; + overflow-y: auto; background-color: var(--input-background-fill); + flex-wrap: nowrap !important; border: 2px solid black !important; z-index: 1000; padding: 10px; diff --git a/src/llmtuner/webui/locales.py b/src/llmtuner/webui/locales.py index 780e616f..cc2a1842 100644 --- a/src/llmtuner/webui/locales.py +++ b/src/llmtuner/webui/locales.py @@ -163,12 +163,28 @@ LOCALES = { "label": "数量" } }, - "preview_samples": { + "page_index": { "en": { - "label": "Samples" + "label": "Page" }, "zh": { - "label": "样例" + "label": "页数" + } + }, + "prev_btn": { + "en": { + "value": "Prev" + }, + "zh": { + "value": "上一页" + } + }, + "next_btn": { + "en": { + "value": "Next" + }, + "zh": { + "value": "下一页" } }, "close_btn": { @@ -179,6 +195,14 @@ LOCALES = { "value": "关闭" } }, + "preview_samples": { + "en": { + "label": "Samples" + }, + "zh": { + "label": "样例" + } + }, "cutoff_len": { "en": { "label": "Cutoff length", diff --git a/src/llmtuner/webui/utils.py b/src/llmtuner/webui/utils.py index 78322356..933d951d 100644 --- a/src/llmtuner/webui/utils.py +++ b/src/llmtuner/webui/utils.py @@ -3,13 +3,11 @@ import json import gradio as gr import matplotlib.figure import matplotlib.pyplot as plt -from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, Dict from datetime import datetime from llmtuner.extras.ploting import smooth -from llmtuner.tuner import export_model -from llmtuner.webui.common import get_save_dir, DATA_CONFIG -from llmtuner.webui.locales import ALERTS +from llmtuner.webui.common import get_save_dir if TYPE_CHECKING: from llmtuner.extras.callbacks import LogCallback @@ -33,37 +31,6 @@ def get_time() -> str: return datetime.now().strftime('%Y-%m-%d-%H-%M-%S') -def can_preview(dataset_dir: str, dataset: list) -> Dict[str, Any]: - with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f: - dataset_info = json.load(f) - - if ( - len(dataset) > 0 - and "file_name" in dataset_info[dataset[0]] - and os.path.isfile(os.path.join(dataset_dir, dataset_info[dataset[0]]["file_name"])) - ): - return gr.update(interactive=True) - else: - return gr.update(interactive=False) - - -def get_preview( - dataset_dir: str, dataset: list, start: Optional[int] = 0, end: Optional[int] = 2 -) -> Tuple[int, list, Dict[str, Any]]: - with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f: - dataset_info = json.load(f) - - data_file: str = dataset_info[dataset[0]]["file_name"] - with open(os.path.join(dataset_dir, data_file), "r", encoding="utf-8") as f: - if data_file.endswith(".json"): - data = json.load(f) - elif data_file.endswith(".jsonl"): - data = [json.loads(line) for line in f] - else: - data = [line for line in f] - return len(data), data[start:end], gr.update(visible=True) - - def can_quantize(finetuning_type: str) -> Dict[str, Any]: if finetuning_type != "lora": return gr.update(value="None", interactive=False) @@ -116,42 +83,3 @@ def gen_plot(base_model: str, finetuning_type: str, output_dir: str) -> matplotl ax.set_xlabel("step") ax.set_ylabel("loss") return fig - - -def save_model( - lang: str, - model_name: str, - model_path: str, - checkpoints: List[str], - finetuning_type: str, - template: str, - max_shard_size: int, - export_dir: str -) -> Generator[str, None, None]: - if not model_name: - yield ALERTS["err_no_model"][lang] - return - - if not model_path: - yield ALERTS["err_no_path"][lang] - return - - if not checkpoints: - yield ALERTS["err_no_checkpoint"][lang] - return - - if not export_dir: - yield ALERTS["err_no_export_dir"][lang] - return - - args = dict( - model_name_or_path=model_path, - checkpoint_dir=",".join([get_save_dir(model_name, finetuning_type, ckpt) for ckpt in checkpoints]), - finetuning_type=finetuning_type, - template=template, - export_dir=export_dir - ) - - yield ALERTS["info_exporting"][lang] - export_model(args, max_shard_size="{}GB".format(max_shard_size)) - yield ALERTS["info_exported"][lang]