Former-commit-id: a6a04be2e6700050c64e59c00a86698f1098e7cc
hiyouga 2023-10-15 18:28:45 +08:00
parent e5e740b54d
commit eafde5b73f
9 changed files with 40 additions and 57 deletions

View File

@@ -12,7 +12,7 @@ fire
 jieba
 rouge-chinese
 nltk
-gradio>=3.36.0
+gradio==3.38.0
 uvicorn
 pydantic==1.10.11
 fastapi==0.95.1

View File

@@ -186,7 +186,7 @@ def get_train_args(
     # postprocess model_args
     model_args.compute_dtype = (
-        torch.bfloat16 if training_args.bf16 else (torch.float16 if training_args.fp16 else torch.float32)
+        torch.bfloat16 if training_args.bf16 else (torch.float16 if training_args.fp16 else None)
     )
     model_args.model_max_length = data_args.cutoff_len
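Note: the fallback changes from torch.float32 to None here. A minimal sketch of the resulting selection logic, under the assumption (mine, not stated in the diff) that a None compute dtype defers the choice to the model loader instead of forcing float32:

import torch
from typing import Optional

# Hypothetical helper mirroring the ternary above; the name is illustrative.
def resolve_compute_dtype(bf16: bool, fp16: bool) -> Optional[torch.dtype]:
    if bf16:
        return torch.bfloat16
    if fp16:
        return torch.float16
    return None  # was torch.float32; None now means "decide later"

assert resolve_compute_dtype(bf16=True, fp16=False) is torch.bfloat16
assert resolve_compute_dtype(bf16=False, fp16=False) is None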

View File

@@ -1,7 +1,7 @@
 import os
 import json
 import gradio as gr
-from typing import Any, Dict, Optional, Union
+from typing import Any, Dict, Optional
 from transformers.utils import (
     WEIGHTS_NAME,
     WEIGHTS_INDEX_NAME,
@@ -27,7 +27,6 @@ CKPT_NAMES = [
     ADAPTER_WEIGHTS_NAME,
     ADAPTER_SAFE_WEIGHTS_NAME
 ]
 
-CONFIG_CLASS = Dict[str, Union[str, Dict[str, str]]]
 
 def get_save_dir(*args) -> os.PathLike:
@@ -38,7 +37,7 @@ def get_config_path() -> os.PathLike:
     return os.path.join(DEFAULT_CACHE_DIR, USER_CONFIG)
 
 
-def load_config() -> CONFIG_CLASS:
+def load_config() -> Dict[str, Any]:
     try:
         with open(get_config_path(), "r", encoding="utf-8") as f:
             return json.load(f)
@@ -46,20 +45,20 @@ def load_config() -> CONFIG_CLASS:
         return {"lang": None, "last_model": None, "path_dict": {}, "cache_dir": None}
 
 
-def save_config(
-    config: CONFIG_CLASS, lang: str, model_name: Optional[str] = None, model_path: Optional[str] = None
-) -> None:
+def save_config(lang: str, model_name: Optional[str] = None, model_path: Optional[str] = None) -> None:
     os.makedirs(DEFAULT_CACHE_DIR, exist_ok=True)
-    config["lang"] = lang or config["lang"]
+    user_config = load_config()
+    user_config["lang"] = lang or user_config["lang"]
     if model_name:
-        config["last_model"] = model_name
-        config["path_dict"][model_name] = model_path
+        user_config["last_model"] = model_name
+        user_config["path_dict"][model_name] = model_path
     with open(get_config_path(), "w", encoding="utf-8") as f:
-        json.dump(config, f, indent=2, ensure_ascii=False)
+        json.dump(user_config, f, indent=2, ensure_ascii=False)
 
 
-def get_model_path(config: Dict[str, Any], model_name: str) -> str:
-    return config["path_dict"].get(model_name, None) or SUPPORTED_MODELS.get(model_name, "")
+def get_model_path(model_name: str) -> str:
+    user_config = load_config()
+    return user_config["path_dict"].get(model_name, None) or SUPPORTED_MODELS.get(model_name, "")
 
 
 def get_module(model_name: str) -> str:
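With CONFIG_CLASS and the config parameter gone, every helper follows a load-merge-save round trip against the JSON file on disk. A self-contained sketch of the pattern, with CACHE_DIR and CONFIG_FILE as stand-ins for DEFAULT_CACHE_DIR and USER_CONFIG:

import json
import os
from typing import Any, Dict, Optional

CACHE_DIR = os.path.join(os.path.expanduser("~"), ".cache", "demo")  # stand-in path
CONFIG_FILE = os.path.join(CACHE_DIR, "user.config")                 # stand-in name

def load_config() -> Dict[str, Any]:
    try:
        with open(CONFIG_FILE, "r", encoding="utf-8") as f:
            return json.load(f)
    except Exception:  # missing or unreadable file falls back to defaults
        return {"lang": None, "last_model": None, "path_dict": {}, "cache_dir": None}

def save_config(lang: str, model_name: Optional[str] = None, model_path: Optional[str] = None) -> None:
    os.makedirs(CACHE_DIR, exist_ok=True)
    user_config = load_config()  # re-read current state from disk, then merge
    user_config["lang"] = lang or user_config["lang"]
    if model_name:
        user_config["last_model"] = model_name
        user_config["path_dict"][model_name] = model_path
    with open(CONFIG_FILE, "w", encoding="utf-8") as f:
        json.dump(user_config, f, indent=2, ensure_ascii=False)

The disk file, rather than a gr.State object, becomes the single source of truth, so settings survive page reloads and separate browser sessions.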

View File

@@ -17,10 +17,7 @@ def create_infer_tab(engine: "Engine") -> Dict[str, "Component"]:
             unload_btn = gr.Button()
         info_box = gr.Textbox(show_label=False, interactive=False)
 
-    elem_dict.update(dict(
-        info_box=info_box, load_btn=load_btn, unload_btn=unload_btn
-    ))
+    elem_dict.update(dict(load_btn=load_btn, unload_btn=unload_btn, info_box=info_box))
 
     chat_box, chatbot, history, chat_elems = create_chat_box(engine, visible=False)
     elem_dict.update(dict(chat_box=chat_box, **chat_elems))

View File

@@ -3,7 +3,7 @@ from typing import TYPE_CHECKING, Dict
 from llmtuner.extras.constants import METHODS, SUPPORTED_MODELS
 from llmtuner.extras.template import templates
-from llmtuner.webui.common import get_model_path, get_template, list_checkpoint, load_config, save_config
+from llmtuner.webui.common import get_model_path, get_template, list_checkpoint, save_config
 from llmtuner.webui.utils import can_quantize
 
 if TYPE_CHECKING:
@@ -12,7 +12,6 @@ if TYPE_CHECKING:
 def create_top() -> Dict[str, "Component"]:
     available_models = list(SUPPORTED_MODELS.keys()) + ["Custom"]
-    config = gr.State(value=load_config())
 
     with gr.Row():
         lang = gr.Dropdown(choices=["en", "zh"], scale=1)
@@ -39,17 +38,17 @@ def create_top() -> Dict[str, "Component"]:
     model_name.change(
         list_checkpoint, [model_name, finetuning_type], [checkpoints], queue=False
     ).then(
-        get_model_path, [config, model_name], [model_path], queue=False
+        get_model_path, [model_name], [model_path], queue=False
     ).then(
         get_template, [model_name], [template], queue=False
     ) # do not save config since the below line will save
 
-    model_path.change(save_config, inputs=[config, lang, model_name, model_path])
+    model_path.change(save_config, inputs=[lang, model_name, model_path], queue=False)
 
     finetuning_type.change(
-        list_checkpoint, [model_name, finetuning_type], [checkpoints]
+        list_checkpoint, [model_name, finetuning_type], [checkpoints], queue=False
     ).then(
-        can_quantize, [finetuning_type], [quantization_bit]
+        can_quantize, [finetuning_type], [quantization_bit], queue=False
     )
 
     refresh_btn.click(
@@ -57,7 +56,6 @@ def create_top() -> Dict[str, "Component"]:
     )
 
     return dict(
-        config=config,
         lang=lang,
         model_name=model_name,
         model_path=model_path,
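Because get_model_path and save_config now resolve the user config themselves, the event chain needs no gr.State among its inputs. A minimal sketch of the wiring (the lookup table and components are illustrative, not from the repo):

import gradio as gr

def get_model_path(model_name: str) -> str:
    # hypothetical lookup standing in for the path_dict/SUPPORTED_MODELS chain
    return {"LLaMA-7B": "huggyllama/llama-7b"}.get(model_name, "")

with gr.Blocks() as demo:
    model_name = gr.Dropdown(choices=["LLaMA-7B", "Custom"], label="model_name")
    model_path = gr.Textbox(label="model_path")
    # queue=False keeps this lightweight UI update off the request queue
    model_name.change(get_model_path, [model_name], [model_path], queue=False)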

View File

@@ -3,7 +3,7 @@ from gradio.components import Component # cannot use TYPE_CHECKING here
 from typing import Any, Dict, Generator, Optional
 from llmtuner.webui.chatter import WebChatModel
-from llmtuner.webui.common import get_model_path, list_dataset, CONFIG_CLASS
+from llmtuner.webui.common import get_model_path, list_dataset, load_config
 from llmtuner.webui.locales import LOCALES
 from llmtuner.webui.manager import Manager
 from llmtuner.webui.runner import Runner
@@ -21,8 +21,9 @@ class Engine:
     def _form_dict(self, resume_dict: Dict[str, Dict[str, Any]]):
         return {self.manager.get_elem(k): gr.update(**v) for k, v in resume_dict.items()}
 
-    def resume(self, config: CONFIG_CLASS) -> Generator[Dict[Component, Dict[str, Any]], None, None]:
-        lang = config.get("lang", None) or "en"
+    def resume(self) -> Generator[Dict[Component, Dict[str, Any]], None, None]:
+        user_config = load_config()
+        lang = user_config.get("lang", None) or "en"
 
         resume_dict = {
             "top.lang": {"value": lang},
@@ -33,9 +34,9 @@ class Engine:
             resume_dict["train.dataset"] = {"choices": list_dataset()["choices"]}
             resume_dict["eval.dataset"] = {"choices": list_dataset()["choices"]}
 
-        if config.get("last_model", None):
-            resume_dict["top.model_name"] = {"value": config["last_model"]}
-            resume_dict["top.model_path"] = {"value": get_model_path(config, config["last_model"])}
+        if user_config.get("last_model", None):
+            resume_dict["top.model_name"] = {"value": user_config["last_model"]}
+            resume_dict["top.model_path"] = {"value": get_model_path(user_config["last_model"])}
 
         yield self._form_dict(resume_dict)
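resume() is now parameterless: it pulls the persisted config at call time and yields a component-to-update mapping. A reduced sketch of the pattern, where load_config and the elems registry are simplified stand-ins for the Engine/Manager plumbing:

import gradio as gr
from typing import Any, Dict, Generator

def resume(elems: Dict[str, Any], load_config) -> Generator[Dict[Any, Dict[str, Any]], None, None]:
    user_config = load_config()  # read from disk instead of a gr.State input
    lang = user_config.get("lang", None) or "en"
    resume_dict = {"top.lang": {"value": lang}}
    if user_config.get("last_model", None):
        resume_dict["top.model_name"] = {"value": user_config["last_model"]}
    # map registry keys to live components and emit gr.update payloads
    yield {elems[k]: gr.update(**v) for k, v in resume_dict.items()}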

View File

@@ -9,12 +9,12 @@ from llmtuner.webui.components import (
     create_export_tab,
     create_chat_box
 )
-from llmtuner.webui.common import load_config, save_config
+from llmtuner.webui.common import save_config
 from llmtuner.webui.css import CSS
 from llmtuner.webui.engine import Engine
 
-require_version("gradio>=3.36.0", "To fix: pip install gradio>=3.36.0")
+require_version("gradio==3.38.0", "To fix: pip install gradio==3.38.0")
 
 
 def create_ui() -> gr.Blocks:
@@ -23,9 +23,6 @@ def create_ui() -> gr.Blocks:
     with gr.Blocks(title="Web Tuner", css=CSS) as demo:
         engine.manager.all_elems["top"] = create_top()
         lang: "gr.Dropdown" = engine.manager.get_elem("top.lang")
-        config = engine.manager.get_elem("top.config")
-        model_name = engine.manager.get_elem("top.model_name")
-        model_path = engine.manager.get_elem("top.model_path")
 
         with gr.Tab("Train"):
             engine.manager.all_elems["train"] = create_train_tab(engine)
@@ -39,13 +36,9 @@ def create_ui() -> gr.Blocks:
         with gr.Tab("Export"):
             engine.manager.all_elems["export"] = create_export_tab(engine)
 
-        demo.load(engine.resume, [config], engine.manager.list_elems())
-
-        lang.change(
-            engine.change_lang, [lang], engine.manager.list_elems(), queue=False
-        ).then(
-            save_config, inputs=[config, lang, model_name, model_path]
-        )
+        demo.load(engine.resume, outputs=engine.manager.list_elems())
+
+        lang.change(engine.change_lang, [lang], engine.manager.list_elems(), queue=False)
+        lang.input(save_config, inputs=[lang], queue=False)
 
     return demo
@@ -54,21 +47,15 @@ def create_web_demo() -> gr.Blocks:
     engine = Engine(pure_chat=True)
 
     with gr.Blocks(title="Web Demo", css=CSS) as demo:
-        config = gr.State(value=load_config())
         lang = gr.Dropdown(choices=["en", "zh"])
-        engine.manager.all_elems["top"] = dict(config=config, lang=lang)
+        engine.manager.all_elems["top"] = dict(lang=lang)
 
         chat_box, _, _, chat_elems = create_chat_box(engine, visible=True)
         engine.manager.all_elems["infer"] = dict(chat_box=chat_box, **chat_elems)
 
-        demo.load(engine.resume, [config], engine.manager.list_elems())
-
-        lang.change(
-            engine.change_lang, [lang], engine.manager.list_elems(), queue=False
-        ).then(
-            save_config, inputs=[config, lang]
-        )
+        demo.load(engine.resume, outputs=engine.manager.list_elems())
+
+        lang.change(engine.change_lang, [lang], engine.manager.list_elems(), queue=False)
+        lang.input(save_config, inputs=[lang], queue=False)
 
     return demo
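The split between lang.change and lang.input matters here: .change also fires when resume() restores the saved value programmatically on load, while .input fires only on direct user interaction, so the config is written exactly when the user picks a language. A toy sketch of the distinction (both callbacks are illustrative):

import gradio as gr

def relabel(lang: str) -> str:
    return {"en": "Language", "zh": "语言"}.get(lang, "Language")

def persist(lang: str) -> None:
    print(f"saving lang={lang}")  # stands in for save_config(lang)

with gr.Blocks() as demo:
    lang = gr.Dropdown(choices=["en", "zh"], value="en")
    label = gr.Textbox()
    lang.change(relabel, [lang], [label], queue=False)  # programmatic + user changes
    lang.input(persist, inputs=[lang], queue=False)     # user-initiated changes only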

View File

@@ -18,7 +18,6 @@ class Manager:
     def get_base_elems(self):
         return {
-            self.all_elems["top"]["config"],
             self.all_elems["top"]["lang"],
             self.all_elems["top"]["model_name"],
             self.all_elems["top"]["model_path"],

View File

@@ -14,7 +14,7 @@ from llmtuner.extras.constants import TRAINING_STAGES
 from llmtuner.extras.logging import LoggerHandler
 from llmtuner.extras.misc import torch_gc
 from llmtuner.tuner import run_exp
-from llmtuner.webui.common import get_module, get_save_dir
+from llmtuner.webui.common import get_module, get_save_dir, load_config
 from llmtuner.webui.locales import ALERTS
 from llmtuner.webui.utils import gen_cmd, get_eval_results, update_process_bar
@@ -74,6 +74,7 @@ class Runner:
     def _parse_train_args(self, data: Dict[Component, Any]) -> Tuple[str, str, str, List[str], str, Dict[str, Any]]:
         get = lambda name: data[self.manager.get_elem(name)]
+        user_config = load_config()
 
         if get("top.checkpoints"):
             checkpoint_dir = ",".join([
@@ -89,7 +90,7 @@ class Runner:
             model_name_or_path=get("top.model_path"),
             do_train=True,
             overwrite_cache=False,
-            cache_dir=get("top.config").get("cache_dir", None),
+            cache_dir=user_config.get("cache_dir", None),
             checkpoint_dir=checkpoint_dir,
             finetuning_type=get("top.finetuning_type"),
             quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None,
@@ -142,6 +143,7 @@ class Runner:
     def _parse_eval_args(self, data: Dict[Component, Any]) -> Tuple[str, str, str, List[str], str, Dict[str, Any]]:
         get = lambda name: data[self.manager.get_elem(name)]
+        user_config = load_config()
 
         if get("top.checkpoints"):
             checkpoint_dir = ",".join([
@@ -160,7 +162,7 @@ class Runner:
             do_eval=True,
             overwrite_cache=False,
             predict_with_generate=True,
-            cache_dir=get("top.config").get("cache_dir", None),
+            cache_dir=user_config.get("cache_dir", None),
             checkpoint_dir=checkpoint_dir,
             finetuning_type=get("top.finetuning_type"),
             quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None,
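Both parsers now call load_config() once at the top and read cache_dir from the returned dict, replacing the get("top.config") component lookup. A compact sketch of the assembly, with get and load_config passed in to keep the example self-contained:

from typing import Any, Callable, Dict

def build_args(get: Callable[[str], Any], load_config: Callable[[], Dict[str, Any]]) -> Dict[str, Any]:
    user_config = load_config()  # replaces get("top.config")
    return dict(
        model_name_or_path=get("top.model_path"),
        cache_dir=user_config.get("cache_dir", None),
        finetuning_type=get("top.finetuning_type"),
        quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None,
    )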