refactor webui

Former-commit-id: 7ed1fa6fe9025a179bbd2a23a0d50213f53ffba2
hiyouga 2023-10-15 03:06:21 +08:00
parent 089785c71b
commit a902ce4dc7
14 changed files with 440 additions and 501 deletions

View File

@@ -1,69 +1,73 @@
-from typing import Any, Dict, Generator, List, Optional, Tuple
+from gradio.components import Component  # cannot use TYPE_CHECKING here
+from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple

 from llmtuner.chat.stream_chat import ChatModel
 from llmtuner.extras.misc import torch_gc
 from llmtuner.hparams import GeneratingArguments
-from llmtuner.webui.common import get_model_path, get_save_dir
+from llmtuner.webui.common import get_save_dir
 from llmtuner.webui.locales import ALERTS

+if TYPE_CHECKING:
+    from llmtuner.webui.manager import Manager
+

 class WebChatModel(ChatModel):

-    def __init__(self, args: Optional[Dict[str, Any]] = None, lazy_init: Optional[bool] = True) -> None:
+    def __init__(self, manager: "Manager", lazy_init: Optional[bool] = True) -> None:
+        self.manager = manager
         self.model = None
         self.tokenizer = None
         self.generating_args = GeneratingArguments()
         if not lazy_init:
-            super().__init__(args)
+            super().__init__()

-    def load_model(
-        self,
-        lang: str,
-        model_name: str,
-        checkpoints: List[str],
-        finetuning_type: str,
-        quantization_bit: str,
-        template: str,
-        system_prompt: str,
-        flash_attn: bool,
-        shift_attn: bool,
-        rope_scaling: str
-    ) -> Generator[str, None, None]:
-        if self.model is not None:
+    @property
+    def loaded(self) -> bool:
+        return self.model is not None
+
+    def load_model(self, data: Dict[Component, Any]) -> Generator[str, None, None]:
+        get = lambda name: data[self.manager.get_elem(name)]
+        lang = get("top.lang")
+
+        if self.loaded:
             yield ALERTS["err_exists"][lang]
             return

-        if not model_name:
+        if not get("top.model_name"):
             yield ALERTS["err_no_model"][lang]
             return

-        model_name_or_path = get_model_path(model_name)
-        if not model_name_or_path:
+        if not get("top.model_path"):
             yield ALERTS["err_no_path"][lang]
             return

-        if checkpoints:
-            checkpoint_dir = ",".join([get_save_dir(model_name, finetuning_type, ckpt) for ckpt in checkpoints])
+        if get("top.checkpoints"):
+            checkpoint_dir = ",".join([
+                get_save_dir(get("top.model_name"), get("top.finetuning_type"), ckpt) for ckpt in get("top.checkpoints")
+            ])
         else:
             checkpoint_dir = None

         yield ALERTS["info_loading"][lang]
         args = dict(
-            model_name_or_path=model_name_or_path,
+            model_name_or_path=get("top.model_path"),
             checkpoint_dir=checkpoint_dir,
-            finetuning_type=finetuning_type,
-            quantization_bit=int(quantization_bit) if quantization_bit in ["8", "4"] else None,
-            template=template,
-            system_prompt=system_prompt,
-            flash_attn=flash_attn,
-            shift_attn=shift_attn,
-            rope_scaling=rope_scaling if rope_scaling in ["linear", "dynamic"] else None
+            finetuning_type=get("top.finetuning_type"),
+            quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None,
+            template=get("top.template"),
+            system_prompt=get("top.system_prompt"),
+            flash_attn=get("top.flash_attn"),
+            shift_attn=get("top.shift_attn"),
+            rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None
         )
         super().__init__(args)

         yield ALERTS["info_loaded"][lang]

-    def unload_model(self, lang: str) -> Generator[str, None, None]:
+    def unload_model(self, data: Dict[Component, Any]) -> Generator[str, None, None]:
+        get = lambda name: data[self.manager.get_elem(name)]
+        lang = get("top.lang")
+
         yield ALERTS["info_unloading"][lang]
         self.model = None
         self.tokenizer = None
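
The new `load_model`/`unload_model` signatures rely on Gradio's set-style inputs: when an event's `inputs` is a Python set of components, Gradio calls the handler with a single dict mapping each component object to its current value, which is what `get = lambda name: data[self.manager.get_elem(name)]` indexes into. A minimal, self-contained sketch of the pattern (the element names are hypothetical, not from this commit):

import gradio as gr

with gr.Blocks() as demo:
    model_name = gr.Textbox(label="model_name")  # hypothetical elements
    quantization_bit = gr.Dropdown(choices=["none", "8", "4"], value="none")
    info_box = gr.Textbox()

    def load_model(data):
        # with set-style inputs, `data` is {component_object: current_value}
        return "loading {} ({}-bit)".format(data[model_name], data[quantization_bit])

    gr.Button("Load").click(load_model, {model_name, quantization_bit}, [info_box])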

View File

@@ -1,7 +1,7 @@
 import os
 import json
 import gradio as gr
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Optional, Union
 from transformers.utils import (
     WEIGHTS_NAME,
     WEIGHTS_INDEX_NAME,
@@ -11,7 +11,7 @@ from transformers.utils import (
     ADAPTER_SAFE_WEIGHTS_NAME
 )

-from llmtuner.extras.constants import DEFAULT_TEMPLATE, SUPPORTED_MODELS, TRAINING_STAGES
+from llmtuner.extras.constants import DEFAULT_MODULE, DEFAULT_TEMPLATE, SUPPORTED_MODELS, TRAINING_STAGES

 DEFAULT_CACHE_DIR = "cache"
@@ -27,6 +27,7 @@ CKPT_NAMES = [
     ADAPTER_WEIGHTS_NAME,
     ADAPTER_SAFE_WEIGHTS_NAME
 ]
+CONFIG_CLASS = Dict[str, Union[str, Dict[str, str]]]


 def get_save_dir(*args) -> os.PathLike:
@@ -37,7 +38,7 @@ def get_config_path() -> os.PathLike:
     return os.path.join(DEFAULT_CACHE_DIR, USER_CONFIG)


-def load_config() -> Dict[str, Any]:
+def load_config() -> CONFIG_CLASS:
     try:
         with open(get_config_path(), "r", encoding="utf-8") as f:
             return json.load(f)
@@ -45,20 +46,24 @@ def load_config() -> CONFIG_CLASS:
         return {"lang": None, "last_model": None, "path_dict": {}, "cache_dir": None}


-def save_config(lang: str, model_name: str, model_path: str) -> None:
+def save_config(
+    config: CONFIG_CLASS, lang: str, model_name: Optional[str] = None, model_path: Optional[str] = None
+) -> None:
     os.makedirs(DEFAULT_CACHE_DIR, exist_ok=True)
-    user_config = load_config()
-    user_config["lang"] = lang or user_config["lang"]
+    config["lang"] = lang or config["lang"]
     if model_name:
-        user_config["last_model"] = model_name
-        user_config["path_dict"][model_name] = model_path
+        config["last_model"] = model_name
+        config["path_dict"][model_name] = model_path
     with open(get_config_path(), "w", encoding="utf-8") as f:
-        json.dump(user_config, f, indent=2, ensure_ascii=False)
+        json.dump(config, f, indent=2, ensure_ascii=False)


-def get_model_path(model_name: str) -> str:
-    user_config = load_config()
-    return user_config["path_dict"].get(model_name, SUPPORTED_MODELS.get(model_name, ""))
+def get_model_path(config: Dict[str, Any], model_name: str) -> str:
+    return config["path_dict"].get(model_name, None) or SUPPORTED_MODELS.get(model_name, "")
+
+
+def get_module(model_name: str) -> str:
+    return DEFAULT_MODULE.get(model_name.split("-")[0], "q_proj,v_proj")


 def get_template(model_name: str) -> str:
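
Taken together, these `common.py` changes move the module from read-the-file-on-every-call to pass-the-config-around. A sketch of the new call contract, with a hypothetical model entry (names are illustrative, not from this commit):

from llmtuner.webui.common import get_model_path, get_module, load_config, save_config

config = load_config()  # CONFIG_CLASS: {"lang": ..., "last_model": ..., "path_dict": {...}, "cache_dir": ...}

# callers now thread the dict they already hold instead of re-reading disk on each call
save_config(config, lang="en", model_name="LLaMA2-7B", model_path="meta-llama/Llama-2-7b-hf")
get_model_path(config, "LLaMA2-7B")  # -> "meta-llama/Llama-2-7b-hf" (path_dict wins over SUPPORTED_MODELS)
get_module("LLaMA2-7B")              # -> "q_proj,v_proj" unless DEFAULT_MODULE knows the "LLaMA2" family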

View File

@@ -4,13 +4,15 @@ from typing import TYPE_CHECKING, Dict, Optional, Tuple

 if TYPE_CHECKING:
     from gradio.blocks import Block
     from gradio.components import Component
-    from llmtuner.webui.chat import WebChatModel
+    from llmtuner.webui.engine import Engine


 def create_chat_box(
-    chat_model: "WebChatModel",
+    engine: "Engine",
     visible: Optional[bool] = False
 ) -> Tuple["Block", "Component", "Component", Dict[str, "Component"]]:
+    elem_dict = dict()
+
     with gr.Box(visible=visible) as chat_box:
         chatbot = gr.Chatbot()
@@ -22,14 +24,20 @@ def create_chat_box(
             with gr.Column(scale=1):
                 clear_btn = gr.Button()
-                max_new_tokens = gr.Slider(10, 2048, value=chat_model.generating_args.max_new_tokens, step=1)
-                top_p = gr.Slider(0.01, 1, value=chat_model.generating_args.top_p, step=0.01)
-                temperature = gr.Slider(0.01, 1.5, value=chat_model.generating_args.temperature, step=0.01)
+                gen_kwargs = engine.chatter.generating_args
+                max_new_tokens = gr.Slider(10, 2048, value=gen_kwargs.max_new_tokens, step=1)
+                top_p = gr.Slider(0.01, 1, value=gen_kwargs.top_p, step=0.01)
+                temperature = gr.Slider(0.01, 1.5, value=gen_kwargs.temperature, step=0.01)
+
+    elem_dict.update(dict(
+        system=system, query=query, submit_btn=submit_btn, clear_btn=clear_btn,
+        max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
+    ))

     history = gr.State([])

     submit_btn.click(
-        chat_model.predict,
+        engine.chatter.predict,
         [chatbot, query, history, system, max_new_tokens, top_p, temperature],
         [chatbot, history],
         show_progress=True
@@ -39,12 +47,4 @@ def create_chat_box(

     clear_btn.click(lambda: ([], []), outputs=[chatbot, history], show_progress=True)

-    return chat_box, chatbot, history, dict(
-        system=system,
-        query=query,
-        submit_btn=submit_btn,
-        clear_btn=clear_btn,
-        max_new_tokens=max_new_tokens,
-        top_p=top_p,
-        temperature=temperature
-    )
+    return chat_box, chatbot, history, elem_dict

View File

@@ -7,19 +7,28 @@ from llmtuner.webui.utils import can_preview, get_preview

 if TYPE_CHECKING:
     from gradio.components import Component
-    from llmtuner.webui.runner import Runner
+    from llmtuner.webui.engine import Engine


-def create_eval_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[str, "Component"]:
+def create_eval_tab(engine: "Engine") -> Dict[str, "Component"]:
+    input_elems = engine.manager.get_base_elems()
+    elem_dict = dict()
+
     with gr.Row():
         dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=2)
         dataset = gr.Dropdown(multiselect=True, scale=4)
         data_preview_btn = gr.Button(interactive=False, scale=1)

-    preview_box, preview_count, preview_samples, close_btn = create_preview_box()
-
     dataset_dir.change(list_dataset, [dataset_dir], [dataset])
     dataset.change(can_preview, [dataset_dir, dataset], [data_preview_btn])

+    input_elems.update({dataset_dir, dataset})
+    elem_dict.update(dict(
+        dataset_dir=dataset_dir, dataset=dataset, data_preview_btn=data_preview_btn
+    ))
+
+    preview_box, preview_count, preview_samples, close_btn = create_preview_box()
+
     data_preview_btn.click(
         get_preview,
         [dataset_dir, dataset],
@@ -27,17 +36,31 @@ def create_eval_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[str, "Component"]:
         queue=False
     )

+    elem_dict.update(dict(
+        preview_count=preview_count, preview_samples=preview_samples, close_btn=close_btn
+    ))
+
     with gr.Row():
         cutoff_len = gr.Slider(value=1024, minimum=4, maximum=8192, step=1)
         max_samples = gr.Textbox(value="100000")
         batch_size = gr.Slider(value=8, minimum=1, maximum=512, step=1)
         predict = gr.Checkbox(value=True)

+    input_elems.update({cutoff_len, max_samples, batch_size, predict})
+    elem_dict.update(dict(
+        cutoff_len=cutoff_len, max_samples=max_samples, batch_size=batch_size, predict=predict
+    ))
+
     with gr.Row():
         max_new_tokens = gr.Slider(10, 2048, value=128, step=1)
         top_p = gr.Slider(0.01, 1, value=0.7, step=0.01)
         temperature = gr.Slider(0.01, 1.5, value=0.95, step=0.01)

+    input_elems.update({max_new_tokens, top_p, temperature})
+    elem_dict.update(dict(
+        max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
+    ))
+
     with gr.Row():
         cmd_preview_btn = gr.Button()
         start_btn = gr.Button()
@@ -49,53 +72,13 @@ def create_eval_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[str, "Component"]:
         with gr.Box():
             output_box = gr.Markdown()

-    input_components = [
-        top_elems["lang"],
-        top_elems["model_name"],
-        top_elems["checkpoints"],
-        top_elems["finetuning_type"],
-        top_elems["quantization_bit"],
-        top_elems["template"],
-        top_elems["system_prompt"],
-        top_elems["flash_attn"],
-        top_elems["shift_attn"],
-        top_elems["rope_scaling"],
-        dataset_dir,
-        dataset,
-        cutoff_len,
-        max_samples,
-        batch_size,
-        predict,
-        max_new_tokens,
-        top_p,
-        temperature
-    ]
-
-    output_components = [
-        output_box,
-        process_bar
-    ]
-
-    cmd_preview_btn.click(runner.preview_eval, input_components, output_components)
-    start_btn.click(runner.run_eval, input_components, output_components)
-    stop_btn.click(runner.set_abort, queue=False)
+    output_elems = [output_box, process_bar]
+    elem_dict.update(dict(
+        cmd_preview_btn=cmd_preview_btn, start_btn=start_btn, stop_btn=stop_btn, output_box=output_box
+    ))
+
+    cmd_preview_btn.click(engine.runner.preview_eval, input_elems, output_elems)
+    start_btn.click(engine.runner.run_eval, input_elems, output_elems)
+    stop_btn.click(engine.runner.set_abort, queue=False)

-    return dict(
-        dataset_dir=dataset_dir,
-        dataset=dataset,
-        data_preview_btn=data_preview_btn,
-        preview_count=preview_count,
-        preview_samples=preview_samples,
-        close_btn=close_btn,
-        cutoff_len=cutoff_len,
-        max_samples=max_samples,
-        batch_size=batch_size,
-        predict=predict,
-        max_new_tokens=max_new_tokens,
-        top_p=top_p,
-        temperature=temperature,
-        cmd_preview_btn=cmd_preview_btn,
-        start_btn=start_btn,
-        stop_btn=stop_btn,
-        output_box=output_box
-    )
+    return elem_dict

View File

@@ -5,9 +5,12 @@ from llmtuner.webui.utils import save_model

 if TYPE_CHECKING:
     from gradio.components import Component
+    from llmtuner.webui.engine import Engine


-def create_export_tab(top_elems: Dict[str, "Component"]) -> Dict[str, "Component"]:
+def create_export_tab(engine: "Engine") -> Dict[str, "Component"]:
+    elem_dict = dict()
+
     with gr.Row():
         save_dir = gr.Textbox()
         max_shard_size = gr.Slider(value=10, minimum=1, maximum=100)
@@ -18,20 +21,23 @@ def create_export_tab(top_elems: Dict[str, "Component"]) -> Dict[str, "Component"]:
     export_btn.click(
         save_model,
         [
-            top_elems["lang"],
-            top_elems["model_name"],
-            top_elems["checkpoints"],
-            top_elems["finetuning_type"],
-            top_elems["template"],
+            engine.manager.get_elem("top.lang"),
+            engine.manager.get_elem("top.model_name"),
+            engine.manager.get_elem("top.model_path"),
+            engine.manager.get_elem("top.checkpoints"),
+            engine.manager.get_elem("top.finetuning_type"),
+            engine.manager.get_elem("top.template"),
             max_shard_size,
             save_dir
         ],
         [info_box]
     )

-    return dict(
+    elem_dict.update(dict(
         save_dir=save_dir,
         max_shard_size=max_shard_size,
         export_btn=export_btn,
         info_box=info_box
-    )
+    ))
+
+    return elem_dict

View File

@@ -1,53 +1,42 @@
 import gradio as gr
 from typing import TYPE_CHECKING, Dict

-from llmtuner.webui.chat import WebChatModel
 from llmtuner.webui.components.chatbot import create_chat_box

 if TYPE_CHECKING:
     from gradio.components import Component
+    from llmtuner.webui.engine import Engine


-def create_infer_tab(top_elems: Dict[str, "Component"]) -> Dict[str, "Component"]:
+def create_infer_tab(engine: "Engine") -> Dict[str, "Component"]:
+    input_elems = engine.manager.get_base_elems()
+    elem_dict = dict()
+
     with gr.Row():
         load_btn = gr.Button()
         unload_btn = gr.Button()

     info_box = gr.Textbox(show_label=False, interactive=False)

-    chat_model = WebChatModel(lazy_init=True)
-    chat_box, chatbot, history, chat_elems = create_chat_box(chat_model)
+    elem_dict.update(dict(
+        info_box=info_box, load_btn=load_btn, unload_btn=unload_btn
+    ))
+
+    chat_box, chatbot, history, chat_elems = create_chat_box(engine, visible=False)
+    elem_dict.update(dict(chat_box=chat_box, **chat_elems))

     load_btn.click(
-        chat_model.load_model,
-        [
-            top_elems["lang"],
-            top_elems["model_name"],
-            top_elems["checkpoints"],
-            top_elems["finetuning_type"],
-            top_elems["quantization_bit"],
-            top_elems["template"],
-            top_elems["system_prompt"],
-            top_elems["flash_attn"],
-            top_elems["shift_attn"],
-            top_elems["rope_scaling"]
-        ],
-        [info_box]
+        engine.chatter.load_model, input_elems, [info_box]
     ).then(
-        lambda: gr.update(visible=(chat_model.model is not None)), outputs=[chat_box]
+        lambda: gr.update(visible=engine.chatter.loaded), outputs=[chat_box]
     )

     unload_btn.click(
-        chat_model.unload_model, [top_elems["lang"]], [info_box]
+        engine.chatter.unload_model, input_elems, [info_box]
     ).then(
         lambda: ([], []), outputs=[chatbot, history]
     ).then(
-        lambda: gr.update(visible=(chat_model.model is not None)), outputs=[chat_box]
+        lambda: gr.update(visible=engine.chatter.loaded), outputs=[chat_box]
     )

-    return dict(
-        info_box=info_box,
-        load_btn=load_btn,
-        unload_btn=unload_btn,
-        **chat_elems
-    )
+    return elem_dict

View File

@@ -3,15 +3,17 @@ from typing import TYPE_CHECKING, Dict

 from llmtuner.extras.constants import METHODS, SUPPORTED_MODELS
 from llmtuner.extras.template import templates
-from llmtuner.webui.common import list_checkpoint, get_model_path, get_template, save_config
+from llmtuner.webui.common import get_model_path, get_template, list_checkpoint, load_config, save_config
 from llmtuner.webui.utils import can_quantize

 if TYPE_CHECKING:
     from gradio.components import Component
+    from llmtuner.webui.engine import Engine


-def create_top() -> Dict[str, "Component"]:
+def create_top(engine: "Engine") -> Dict[str, "Component"]:
     available_models = list(SUPPORTED_MODELS.keys()) + ["Custom"]
+    config = gr.State(value=load_config())

     with gr.Row():
         lang = gr.Dropdown(choices=["en", "zh"], scale=1)
@@ -35,17 +37,21 @@ def create_top() -> Dict[str, "Component"]:
         shift_attn = gr.Checkbox(value=False)
         rope_scaling = gr.Dropdown(choices=["none", "linear", "dynamic"], value="none")

-    lang.change(save_config, [lang, model_name, model_path])
+    lang.change(
+        engine.change_lang, [lang], engine.manager.list_elems(), queue=False
+    ).then(
+        save_config, inputs=[config, lang, model_name, model_path]
+    )

     model_name.change(
         list_checkpoint, [model_name, finetuning_type], [checkpoints]
     ).then(
-        get_model_path, [model_name], [model_path]
+        get_model_path, [config, model_name], [model_path]
     ).then(
         get_template, [model_name], [template]
     )  # do not save config since the below line will save

-    model_path.change(save_config, [lang, model_name, model_path])
+    model_path.change(save_config, inputs=[config, lang, model_name, model_path])

     finetuning_type.change(
         list_checkpoint, [model_name, finetuning_type], [checkpoints]
@@ -58,6 +64,7 @@ def create_top() -> Dict[str, "Component"]:
     )

     return dict(
+        config=config,
         lang=lang,
         model_name=model_name,
         model_path=model_path,
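
The `config = gr.State(value=load_config())` line is the pivot of this hunk: the user config becomes a session-scoped value passed to handlers as an ordinary input, which is why `save_config` and `get_model_path` now take a leading `config` argument. A reduced sketch of the same state-threading pattern (hypothetical names, gradio 3.x):

import gradio as gr

with gr.Blocks() as demo:
    config = gr.State(value={"lang": None})  # per-session dict, like the one above
    lang = gr.Dropdown(choices=["en", "zh"])
    echo = gr.Textbox()

    def on_lang(cfg, chosen):
        cfg["lang"] = chosen or cfg["lang"]  # mirrors save_config's merge rule
        return str(cfg)

    lang.change(on_lang, [config, lang], [echo])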

View File

@@ -9,10 +9,13 @@ from llmtuner.webui.utils import can_preview, get_preview, gen_plot

 if TYPE_CHECKING:
     from gradio.components import Component
-    from llmtuner.webui.runner import Runner
+    from llmtuner.webui.engine import Engine


-def create_train_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[str, "Component"]:
+def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
+    input_elems = engine.manager.get_base_elems()
+    elem_dict = dict()
+
     with gr.Row():
         training_stage = gr.Dropdown(
             choices=list(TRAINING_STAGES.keys()), value=list(TRAINING_STAGES.keys())[0], scale=2
@@ -21,11 +24,17 @@ def create_train_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[str, "Component"]:
         dataset = gr.Dropdown(multiselect=True, scale=4)
         data_preview_btn = gr.Button(interactive=False, scale=1)

-    preview_box, preview_count, preview_samples, close_btn = create_preview_box()
-
     training_stage.change(list_dataset, [dataset_dir, training_stage], [dataset])
     dataset_dir.change(list_dataset, [dataset_dir, training_stage], [dataset])
     dataset.change(can_preview, [dataset_dir, dataset], [data_preview_btn])

+    input_elems.update({training_stage, dataset_dir, dataset})
+    elem_dict.update(dict(
+        training_stage=training_stage, dataset_dir=dataset_dir, dataset=dataset, data_preview_btn=data_preview_btn
+    ))
+
+    preview_box, preview_count, preview_samples, close_btn = create_preview_box()
+
     data_preview_btn.click(
         get_preview,
         [dataset_dir, dataset],
@@ -33,6 +42,10 @@ def create_train_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[str, "Component"]:
         queue=False
     )

+    elem_dict.update(dict(
+        preview_count=preview_count, preview_samples=preview_samples, close_btn=close_btn
+    ))
+
     with gr.Row():
         cutoff_len = gr.Slider(value=1024, minimum=4, maximum=8192, step=1)
         learning_rate = gr.Textbox(value="5e-5")
@@ -40,6 +53,12 @@ def create_train_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[str, "Component"]:
         max_samples = gr.Textbox(value="100000")
         compute_type = gr.Radio(choices=["fp16", "bf16"], value="fp16")

+    input_elems.update({cutoff_len, learning_rate, num_train_epochs, max_samples, compute_type})
+    elem_dict.update(dict(
+        cutoff_len=cutoff_len, learning_rate=learning_rate, num_train_epochs=num_train_epochs,
+        max_samples=max_samples, compute_type=compute_type
+    ))
+
     with gr.Row():
         batch_size = gr.Slider(value=4, minimum=1, maximum=512, step=1)
         gradient_accumulation_steps = gr.Slider(value=4, minimum=1, maximum=512, step=1)
@@ -49,12 +68,23 @@ def create_train_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[str, "Component"]:
         max_grad_norm = gr.Textbox(value="1.0")
         val_size = gr.Slider(value=0, minimum=0, maximum=1, step=0.001)

+    input_elems.update({batch_size, gradient_accumulation_steps, lr_scheduler_type, max_grad_norm, val_size})
+    elem_dict.update(dict(
+        batch_size=batch_size, gradient_accumulation_steps=gradient_accumulation_steps,
+        lr_scheduler_type=lr_scheduler_type, max_grad_norm=max_grad_norm, val_size=val_size
+    ))
+
     with gr.Accordion(label="Advanced config", open=False) as advanced_tab:
         with gr.Row():
             logging_steps = gr.Slider(value=5, minimum=5, maximum=1000, step=5)
             save_steps = gr.Slider(value=100, minimum=10, maximum=5000, step=10)
             warmup_steps = gr.Slider(value=0, minimum=0, maximum=5000, step=1)

+    input_elems.update({logging_steps, save_steps, warmup_steps})
+    elem_dict.update(dict(
+        advanced_tab=advanced_tab, logging_steps=logging_steps, save_steps=save_steps, warmup_steps=warmup_steps
+    ))
+
     with gr.Accordion(label="LoRA config", open=False) as lora_tab:
         with gr.Row():
             lora_rank = gr.Slider(value=8, minimum=1, maximum=1024, step=1, scale=1)
@@ -62,6 +92,15 @@ def create_train_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[str, "Component"]:
             lora_target = gr.Textbox(scale=2)
             resume_lora_training = gr.Checkbox(value=True, scale=1)

+    input_elems.update({lora_rank, lora_dropout, lora_target, resume_lora_training})
+    elem_dict.update(dict(
+        lora_tab=lora_tab,
+        lora_rank=lora_rank,
+        lora_dropout=lora_dropout,
+        lora_target=lora_target,
+        resume_lora_training=resume_lora_training,
+    ))
+
     with gr.Accordion(label="RLHF config", open=False) as rlhf_tab:
         with gr.Row():
             dpo_beta = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01, scale=1)
@@ -70,11 +109,14 @@ def create_train_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[str, "Component"]:
     refresh_btn.click(
         list_checkpoint,
-        [top_elems["model_name"], top_elems["finetuning_type"]],
+        [engine.manager.get_elem("top.model_name"), engine.manager.get_elem("top.finetuning_type")],
         [reward_model],
         queue=False
     )

+    input_elems.update({dpo_beta, reward_model})
+    elem_dict.update(dict(rlhf_tab=rlhf_tab, dpo_beta=dpo_beta, reward_model=reward_model, refresh_btn=refresh_btn))
+
     with gr.Row():
         cmd_preview_btn = gr.Button()
         start_btn = gr.Button()
@@ -94,90 +136,22 @@ def create_train_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[str, "Component"]:
         with gr.Column(scale=1):
             loss_viewer = gr.Plot()

-    input_components = [
-        top_elems["lang"],
-        top_elems["model_name"],
-        top_elems["checkpoints"],
-        top_elems["finetuning_type"],
-        top_elems["quantization_bit"],
-        top_elems["template"],
-        top_elems["system_prompt"],
-        top_elems["flash_attn"],
-        top_elems["shift_attn"],
-        top_elems["rope_scaling"],
-        training_stage,
-        dataset_dir,
-        dataset,
-        cutoff_len,
-        learning_rate,
-        num_train_epochs,
-        max_samples,
-        compute_type,
-        batch_size,
-        gradient_accumulation_steps,
-        lr_scheduler_type,
-        max_grad_norm,
-        val_size,
-        logging_steps,
-        save_steps,
-        warmup_steps,
-        lora_rank,
-        lora_dropout,
-        lora_target,
-        resume_lora_training,
-        dpo_beta,
-        reward_model,
-        output_dir
-    ]
-
-    output_components = [
-        output_box,
-        process_bar
-    ]
-
-    cmd_preview_btn.click(runner.preview_train, input_components, output_components)
-    start_btn.click(runner.run_train, input_components, output_components)
-    stop_btn.click(runner.set_abort, queue=False)
+    input_elems.add(output_dir)
+    output_elems = [output_box, process_bar]
+    elem_dict.update(dict(
+        cmd_preview_btn=cmd_preview_btn, start_btn=start_btn, stop_btn=stop_btn,
+        output_dir=output_dir, output_box=output_box, loss_viewer=loss_viewer
+    ))
+
+    cmd_preview_btn.click(engine.runner.preview_train, input_elems, output_elems)
+    start_btn.click(engine.runner.run_train, input_elems, output_elems)
+    stop_btn.click(engine.runner.set_abort, queue=False)

     process_bar.change(
-        gen_plot, [top_elems["model_name"], top_elems["finetuning_type"], output_dir], loss_viewer, queue=False
+        gen_plot,
+        [engine.manager.get_elem("top.model_name"), engine.manager.get_elem("top.finetuning_type"), output_dir],
+        loss_viewer,
+        queue=False
     )

-    return dict(
-        training_stage=training_stage,
-        dataset_dir=dataset_dir,
-        dataset=dataset,
-        data_preview_btn=data_preview_btn,
-        preview_count=preview_count,
-        preview_samples=preview_samples,
-        close_btn=close_btn,
-        cutoff_len=cutoff_len,
-        learning_rate=learning_rate,
-        num_train_epochs=num_train_epochs,
-        max_samples=max_samples,
-        compute_type=compute_type,
-        batch_size=batch_size,
-        gradient_accumulation_steps=gradient_accumulation_steps,
-        lr_scheduler_type=lr_scheduler_type,
-        max_grad_norm=max_grad_norm,
-        val_size=val_size,
-        advanced_tab=advanced_tab,
-        logging_steps=logging_steps,
-        save_steps=save_steps,
-        warmup_steps=warmup_steps,
-        lora_tab=lora_tab,
-        lora_rank=lora_rank,
-        lora_dropout=lora_dropout,
-        lora_target=lora_target,
-        resume_lora_training=resume_lora_training,
-        rlhf_tab=rlhf_tab,
-        dpo_beta=dpo_beta,
-        reward_model=reward_model,
-        refresh_btn=refresh_btn,
-        cmd_preview_btn=cmd_preview_btn,
-        start_btn=start_btn,
-        stop_btn=stop_btn,
-        output_dir=output_dir,
-        output_box=output_box,
-        loss_viewer=loss_viewer
-    )
+    return elem_dict

View File

@@ -0,0 +1,46 @@
+import gradio as gr
+from gradio.components import Component  # cannot use TYPE_CHECKING here
+from typing import Any, Dict, Generator, List, Optional, Tuple
+
+from llmtuner.webui.chatter import WebChatModel
+from llmtuner.webui.common import get_model_path, list_dataset, CONFIG_CLASS
+from llmtuner.webui.locales import LOCALES
+from llmtuner.webui.manager import Manager
+from llmtuner.webui.runner import Runner
+from llmtuner.webui.utils import get_time
+
+
+class Engine:
+
+    def __init__(self, init_chat: Optional[bool] = False) -> None:
+        self.manager: "Manager" = Manager()
+        self.runner: "Runner" = Runner(self.manager)
+        self.chatter: "WebChatModel" = WebChatModel(manager=self.manager, lazy_init=(not init_chat))
+
+    def resume(self, config: CONFIG_CLASS) -> Generator[Dict[Component, Dict[str, Any]], None, None]:
+        lang = config.get("lang", None) or "en"
+
+        resume_dict = {
+            "top.config": {"value": config},
+            "top.lang": {"value": lang},
+            "train.dataset": {"choices": list_dataset()["choices"]},
+            "eval.dataset": {"choices": list_dataset()["choices"]},
+            "infer.chat_box": {"visible": self.chatter.loaded}
+        }
+
+        if config.get("last_model", None):
+            resume_dict["top.model_name"] = {"value": config["last_model"]}
+            resume_dict["top.model_path"] = {"value": get_model_path(config, config["last_model"])}
+
+        yield {self.manager.get_elem(k): gr.update(**v) for k, v in resume_dict.items()}
+
+        if self.runner.alive:
+            pass  # TODO: restore training
+        else:
+            resume_dict = {"train.output_dir": {"value": get_time()}}  # TODO: xxx
+
+    def change_lang(self, lang: str) -> Dict[Component, Dict[str, Any]]:
+        return {
+            component: gr.update(**LOCALES[name][lang])
+            for elems in self.manager.all_elems.values() for name, component in elems.items()
+        }
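
`resume` and `change_lang` return `{component: gr.update(...)}` mappings rather than positional tuples, so they can be wired with the full `engine.manager.list_elems()` as outputs and let Gradio match each update to its component by identity. A reduced, self-contained sketch of a dict-yielding loader (the element is hypothetical):

import gradio as gr

with gr.Blocks() as demo:
    greeting = gr.Textbox()

    def resume():
        # dict yield: target outputs by component object instead of by position
        yield {greeting: gr.update(value="restored from user config")}

    demo.load(resume, outputs=[greeting])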

View File

@@ -9,65 +9,54 @@ from llmtuner.webui.components import (
     create_export_tab,
     create_chat_box
 )
-from llmtuner.webui.chat import WebChatModel
+from llmtuner.webui.common import load_config, save_config
 from llmtuner.webui.css import CSS
-from llmtuner.webui.manager import Manager
-from llmtuner.webui.runner import Runner
+from llmtuner.webui.engine import Engine

 require_version("gradio>=3.36.0", "To fix: pip install gradio>=3.36.0")


 def create_ui() -> gr.Blocks:
-    runner = Runner()
+    engine = Engine(init_chat=False)

     with gr.Blocks(title="Web Tuner", css=CSS) as demo:
-        top_elems = create_top()
+        engine.manager.all_elems["top"] = create_top(engine)

         with gr.Tab("Train"):
-            train_elems = create_train_tab(top_elems, runner)
+            engine.manager.all_elems["train"] = create_train_tab(engine)

         with gr.Tab("Evaluate"):
-            eval_elems = create_eval_tab(top_elems, runner)
+            engine.manager.all_elems["eval"] = create_eval_tab(engine)

         with gr.Tab("Chat"):
-            infer_elems = create_infer_tab(top_elems)
+            engine.manager.all_elems["infer"] = create_infer_tab(engine)

         with gr.Tab("Export"):
-            export_elems = create_export_tab(top_elems)
+            engine.manager.all_elems["export"] = create_export_tab(engine)

-        elem_list = [top_elems, train_elems, eval_elems, infer_elems, export_elems]
-        manager = Manager(elem_list)
-
-        demo.load(
-            manager.gen_label,
-            [top_elems["lang"]],
-            [elem for elems in elem_list for elem in elems.values()],
-        )
-
-        top_elems["lang"].change(
-            manager.gen_label,
-            [top_elems["lang"]],
-            [elem for elems in elem_list for elem in elems.values()],
-            queue=False
-        )
+        demo.load(engine.resume, [engine.manager.get_elem("top.config")], engine.manager.list_elems())

     return demo


 def create_web_demo() -> gr.Blocks:
-    chat_model = WebChatModel(lazy_init=False)
+    engine = Engine(init_chat=True)

     with gr.Blocks(title="Web Demo", css=CSS) as demo:
-        lang = gr.Dropdown(choices=["en", "zh"], value="en")
+        lang = gr.Dropdown(choices=["en", "zh"])
+        config = gr.State(value=load_config())

-        _, _, _, chat_elems = create_chat_box(chat_model, visible=True)
-        manager = Manager([{"lang": lang}, chat_elems])
+        engine.manager.all_elems["top"] = dict(lang=lang)
+        _, _, _, engine.manager.all_elems["infer"] = create_chat_box(engine, visible=True)

-        demo.load(manager.gen_label, [lang], [lang] + list(chat_elems.values()))
-        lang.select(manager.gen_label, [lang], [lang] + list(chat_elems.values()), queue=False)
+        lang.change(
+            engine.change_lang, [lang], engine.manager.list_elems(), queue=False
+        ).then(
+            save_config, inputs=[config, lang]
+        )
+
+        demo.load(engine.resume, [config], engine.manager.list_elems())

     return demo

View File

@@ -1,4 +1,8 @@
 LOCALES = {
+    "config": {
+        "en": {},
+        "zh": {}
+    },
     "lang": {
         "en": {
             "label": "Lang"
@@ -443,6 +447,10 @@ LOCALES = {
             "label": "保存预测结果"
         }
     },
+    "chat_box": {
+        "en": {},
+        "zh": {}
+    },
     "load_btn": {
         "en": {
             "value": "Load model"

View File

@@ -1,46 +1,36 @@
-import gradio as gr
-from gradio.components import Component
-from typing import Any, Dict, List
+from typing import TYPE_CHECKING, Dict, List

-from llmtuner.webui.common import get_model_path, list_dataset, load_config
-from llmtuner.webui.locales import LOCALES
-from llmtuner.webui.utils import get_time
+if TYPE_CHECKING:
+    from gradio.components import Component


 class Manager:

-    def __init__(self, elem_list: List[Dict[str, Component]]):
-        self.elem_list = elem_list
+    def __init__(self) -> None:
+        self.all_elems: Dict[str, Dict[str, "Component"]] = {}

-    def gen_refresh(self, lang: str) -> Dict[str, Any]:
-        refresh_dict = {
-            "dataset": {"choices": list_dataset()["choices"]},
-            "output_dir": {"value": get_time()}
-        }
+    def get_elem(self, name: str) -> "Component":
+        r"""
+        Example: top.lang, train.dataset
+        """
+        tab_name, elem_name = name.split(".")
+        return self.all_elems[tab_name][elem_name]

-        user_config = load_config()
-        if not lang:
-            if user_config.get("lang", None):
-                lang = user_config["lang"]
-            else:
-                lang = "en"
-
-        refresh_dict["lang"] = {"value": lang}
-
-        if user_config.get("last_model", None):
-            refresh_dict["model_name"] = {"value": user_config["last_model"]}
-            refresh_dict["model_path"] = {"value": get_model_path(user_config["last_model"])}
-
-        return refresh_dict
-
-    def gen_label(self, lang: str) -> Dict[Component, Dict[str, Any]]:  # cannot use TYPE_CHECKING
-        update_dict = {}
-        refresh_dict = self.gen_refresh(lang)
-
-        for elems in self.elem_list:
-            for name, component in elems.items():
-                update_dict[component] = gr.update(
-                    **LOCALES[name][refresh_dict["lang"]["value"]], **refresh_dict.get(name, {})
-                )
-
-        return update_dict
+    def get_base_elems(self):
+        return {
+            self.all_elems["top"]["config"],
+            self.all_elems["top"]["lang"],
+            self.all_elems["top"]["model_name"],
+            self.all_elems["top"]["model_path"],
+            self.all_elems["top"]["checkpoints"],
+            self.all_elems["top"]["finetuning_type"],
+            self.all_elems["top"]["quantization_bit"],
+            self.all_elems["top"]["template"],
+            self.all_elems["top"]["system_prompt"],
+            self.all_elems["top"]["flash_attn"],
+            self.all_elems["top"]["shift_attn"],
+            self.all_elems["top"]["rope_scaling"]
+        }
+
+    def list_elems(self) -> List["Component"]:
+        return [elem for elems in self.all_elems.values() for elem in elems.values()]

View File

@@ -1,26 +1,32 @@
 import os
 import time
 import logging
-import threading
 import gradio as gr
-from typing import Any, Dict, Generator, List, Tuple
+from threading import Thread
+from gradio.components import Component  # cannot use TYPE_CHECKING here
+from typing import TYPE_CHECKING, Any, Dict, Generator, List, Tuple

 import transformers
 from transformers.trainer import TRAINING_ARGS_NAME

 from llmtuner.extras.callbacks import LogCallback
-from llmtuner.extras.constants import DEFAULT_MODULE, TRAINING_STAGES
+from llmtuner.extras.constants import TRAINING_STAGES
 from llmtuner.extras.logging import LoggerHandler
 from llmtuner.extras.misc import torch_gc
 from llmtuner.tuner import run_exp
-from llmtuner.webui.common import get_model_path, get_save_dir, load_config
+from llmtuner.webui.common import get_module, get_save_dir
 from llmtuner.webui.locales import ALERTS
 from llmtuner.webui.utils import gen_cmd, get_eval_results, update_process_bar

+if TYPE_CHECKING:
+    from llmtuner.webui.manager import Manager
+

 class Runner:

-    def __init__(self):
+    def __init__(self, manager: "Manager") -> None:
+        self.manager = manager
+        self.thread: "Thread" = None
         self.aborted = False
         self.running = False
         self.logger_handler = LoggerHandler()
@@ -28,20 +34,22 @@ class Runner:
         logging.root.addHandler(self.logger_handler)
         transformers.logging.add_handler(self.logger_handler)

-    def set_abort(self):
+    @property
+    def alive(self) -> bool:
+        return self.thread is not None
+
+    def set_abort(self) -> None:
         self.aborted = True
         self.running = False

-    def _initialize(
-        self, lang: str, model_name: str, dataset: List[str]
-    ) -> str:
+    def _initialize(self, lang: str, model_name: str, model_path: str, dataset: List[str]) -> str:
         if self.running:
             return ALERTS["err_conflict"][lang]

         if not model_name:
             return ALERTS["err_no_model"][lang]

-        if not get_model_path(model_name):
+        if not model_path:
             return ALERTS["err_no_path"][lang]

         if len(dataset) == 0:
@@ -52,9 +60,8 @@ class Runner:
         self.trainer_callback = LogCallback(self)
         return ""

-    def _finalize(
-        self, lang: str, finish_info: str
-    ) -> str:
+    def _finalize(self, lang: str, finish_info: str) -> str:
+        self.thread = None
         self.running = False
         torch_gc()
         if self.aborted:
@@ -62,236 +69,171 @@ class Runner:
         else:
             return finish_info

-    def _parse_train_args(
-        self,
-        lang: str,
-        model_name: str,
-        checkpoints: List[str],
-        finetuning_type: str,
-        quantization_bit: str,
-        template: str,
-        system_prompt: str,
-        flash_attn: bool,
-        shift_attn: bool,
-        rope_scaling: str,
-        training_stage: str,
-        dataset_dir: str,
-        dataset: List[str],
-        cutoff_len: int,
-        learning_rate: str,
-        num_train_epochs: str,
-        max_samples: str,
-        compute_type: str,
-        batch_size: int,
-        gradient_accumulation_steps: int,
-        lr_scheduler_type: str,
-        max_grad_norm: str,
-        val_size: float,
-        logging_steps: int,
-        save_steps: int,
-        warmup_steps: int,
-        lora_rank: int,
-        lora_dropout: float,
-        lora_target: str,
-        resume_lora_training: bool,
-        dpo_beta: float,
-        reward_model: str,
-        output_dir: str
-    ) -> Tuple[str, str, List[str], str, Dict[str, Any]]:
-        if checkpoints:
-            checkpoint_dir = ",".join([get_save_dir(model_name, finetuning_type, ckpt) for ckpt in checkpoints])
+    def _parse_train_args(self, data: Dict[Component, Any]) -> Tuple[str, str, str, List[str], str, Dict[str, Any]]:
+        get = lambda name: data[self.manager.get_elem(name)]
+
+        if get("top.checkpoints"):
+            checkpoint_dir = ",".join([
+                get_save_dir(get("top.model_name"), get("top.finetuning_type"), ckpt) for ckpt in get("top.checkpoints")
+            ])
         else:
             checkpoint_dir = None

-        output_dir = get_save_dir(model_name, finetuning_type, output_dir)
-
-        user_config = load_config()
-        cache_dir = user_config.get("cache_dir", None)
+        output_dir = get_save_dir(get("top.model_name"), get("top.finetuning_type"), get("train.output_dir"))

         args = dict(
-            stage=TRAINING_STAGES[training_stage],
-            model_name_or_path=get_model_path(model_name),
+            stage=TRAINING_STAGES[get("train.training_stage")],
+            model_name_or_path=get("top.model_path"),
             do_train=True,
             overwrite_cache=False,
-            cache_dir=cache_dir,
+            cache_dir=get("top.config").get("cache_dir", None),
             checkpoint_dir=checkpoint_dir,
-            finetuning_type=finetuning_type,
-            quantization_bit=int(quantization_bit) if quantization_bit in ["8", "4"] else None,
-            template=template,
-            system_prompt=system_prompt,
-            flash_attn=flash_attn,
-            shift_attn=shift_attn,
-            rope_scaling=rope_scaling if rope_scaling in ["linear", "dynamic"] else None,
-            dataset_dir=dataset_dir,
-            dataset=",".join(dataset),
-            cutoff_len=cutoff_len,
-            learning_rate=float(learning_rate),
-            num_train_epochs=float(num_train_epochs),
-            max_samples=int(max_samples),
-            per_device_train_batch_size=batch_size,
-            gradient_accumulation_steps=gradient_accumulation_steps,
-            lr_scheduler_type=lr_scheduler_type,
-            max_grad_norm=float(max_grad_norm),
-            logging_steps=logging_steps,
-            save_steps=save_steps,
-            warmup_steps=warmup_steps,
-            lora_rank=lora_rank,
-            lora_dropout=lora_dropout,
-            lora_target=lora_target or DEFAULT_MODULE.get(model_name.split("-")[0], "q_proj,v_proj"),
-            resume_lora_training=resume_lora_training,
+            finetuning_type=get("top.finetuning_type"),
+            quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None,
+            template=get("top.template"),
+            system_prompt=get("top.system_prompt"),
+            flash_attn=get("top.flash_attn"),
+            shift_attn=get("top.shift_attn"),
+            rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
+            dataset_dir=get("train.dataset_dir"),
+            dataset=",".join(get("train.dataset")),
+            cutoff_len=get("train.cutoff_len"),
+            learning_rate=float(get("train.learning_rate")),
+            num_train_epochs=float(get("train.num_train_epochs")),
+            max_samples=int(get("train.max_samples")),
+            per_device_train_batch_size=get("train.batch_size"),
+            gradient_accumulation_steps=get("train.gradient_accumulation_steps"),
+            lr_scheduler_type=get("train.lr_scheduler_type"),
+            max_grad_norm=float(get("train.max_grad_norm")),
+            logging_steps=get("train.logging_steps"),
+            save_steps=get("train.save_steps"),
+            warmup_steps=get("train.warmup_steps"),
+            lora_rank=get("train.lora_rank"),
+            lora_dropout=get("train.lora_dropout"),
+            lora_target=get("train.lora_target") or get_module(get("top.model_name")),
+            resume_lora_training=get("train.resume_lora_training"),
             output_dir=output_dir
         )
-        args[compute_type] = True
+        args[get("train.compute_type")] = True

-        if TRAINING_STAGES[training_stage] in ["rm", "ppo", "dpo"] and args["quantization_bit"] is None:
-            args["resume_lora_training"] = False
+        if TRAINING_STAGES[get("train.training_stage")] in ["rm", "ppo", "dpo"]:
+            args["resume_lora_training"] = (args["quantization_bit"] is not None)

         if args["quantization_bit"] is not None:
             args["upcast_layernorm"] = True

         if args["stage"] == "ppo":
-            args["reward_model"] = reward_model
-            val_size = 0
+            args["reward_model"] = get("train.reward_model")

         if args["stage"] == "dpo":
-            args["dpo_beta"] = dpo_beta
+            args["dpo_beta"] = get("train.dpo_beta")

-        if val_size > 1e-6:
-            args["val_size"] = val_size
+        if get("train.val_size") > 1e-6 and args["stage"] != "ppo":
+            args["val_size"] = get("train.val_size")
             args["evaluation_strategy"] = "steps"
-            args["eval_steps"] = save_steps
+            args["eval_steps"] = get("train.save_steps")
             args["load_best_model_at_end"] = True

-        return lang, model_name, dataset, output_dir, args
+        return get("top.lang"), get("top.model_name"), get("top.model_path"), get("train.dataset"), output_dir, args

-    def _parse_eval_args(
-        self,
-        lang: str,
-        model_name: str,
-        checkpoints: List[str],
-        finetuning_type: str,
-        quantization_bit: str,
-        template: str,
-        system_prompt: str,
-        flash_attn: bool,
-        shift_attn: bool,
-        rope_scaling: str,
-        dataset_dir: str,
-        dataset: List[str],
-        cutoff_len: int,
-        max_samples: str,
-        batch_size: int,
-        predict: bool,
-        max_new_tokens: int,
-        top_p: float,
-        temperature: float
-    ) -> Tuple[str, str, List[str], str, Dict[str, Any]]:
-        if checkpoints:
-            checkpoint_dir = ",".join([get_save_dir(model_name, finetuning_type, ckpt) for ckpt in checkpoints])
-            output_dir = get_save_dir(model_name, finetuning_type, "eval_" + "_".join(checkpoints))
+    def _parse_eval_args(self, data: Dict[Component, Any]) -> Tuple[str, str, str, List[str], str, Dict[str, Any]]:
+        get = lambda name: data[self.manager.get_elem(name)]
+
+        if get("top.checkpoints"):
+            checkpoint_dir = ",".join([
+                get_save_dir(get("top.model_name"), get("top.finetuning_type"), ckpt) for ckpt in get("top.checkpoints")
+            ])
+            output_dir = get_save_dir(
+                get("top.model_name"), get("top.finetuning_type"), "eval_" + "_".join(get("top.checkpoints"))
+            )
         else:
             checkpoint_dir = None
-            output_dir = get_save_dir(model_name, finetuning_type, "eval_base")
-
-        user_config = load_config()
-        cache_dir = user_config.get("cache_dir", None)
+            output_dir = get_save_dir(get("top.model_name"), get("top.finetuning_type"), "eval_base")

         args = dict(
             stage="sft",
-            model_name_or_path=get_model_path(model_name),
+            model_name_or_path=get("top.model_path"),
             do_eval=True,
             overwrite_cache=False,
             predict_with_generate=True,
-            cache_dir=cache_dir,
+            cache_dir=get("top.config").get("cache_dir", None),
             checkpoint_dir=checkpoint_dir,
-            finetuning_type=finetuning_type,
-            quantization_bit=int(quantization_bit) if quantization_bit in ["8", "4"] else None,
-            template=template,
-            system_prompt=system_prompt,
-            flash_attn=flash_attn,
-            shift_attn=shift_attn,
-            rope_scaling=rope_scaling if rope_scaling in ["linear", "dynamic"] else None,
-            dataset_dir=dataset_dir,
-            dataset=",".join(dataset),
-            cutoff_len=cutoff_len,
-            max_samples=int(max_samples),
-            per_device_eval_batch_size=batch_size,
-            max_new_tokens=max_new_tokens,
-            top_p=top_p,
-            temperature=temperature,
+            finetuning_type=get("top.finetuning_type"),
+            quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None,
+            template=get("top.template"),
+            system_prompt=get("top.system_prompt"),
+            flash_attn=get("top.flash_attn"),
+            shift_attn=get("top.shift_attn"),
+            rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
+            dataset_dir=get("eval.dataset_dir"),
+            dataset=",".join(get("eval.dataset")),
+            cutoff_len=get("eval.cutoff_len"),
+            max_samples=int(get("eval.max_samples")),
+            per_device_eval_batch_size=get("eval.batch_size"),
+            max_new_tokens=get("eval.max_new_tokens"),
+            top_p=get("eval.top_p"),
+            temperature=get("eval.temperature"),
             output_dir=output_dir
         )

-        if predict:
+        if get("eval.predict"):
             args.pop("do_eval", None)
             args["do_predict"] = True

-        return lang, model_name, dataset, output_dir, args
+        return get("top.lang"), get("top.model_name"), get("top.model_path"), get("eval.dataset"), output_dir, args

-    def preview_train(self, *args) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
-        lang, model_name, dataset, _, args = self._parse_train_args(*args)
-        error = self._initialize(lang, model_name, dataset)
+    def preview_train(self, data: Dict[Component, Any]) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
+        lang, model_name, model_path, dataset, _, args = self._parse_train_args(data)
+        error = self._initialize(lang, model_name, model_path, dataset)
         if error:
             yield error, gr.update(visible=False)
         else:
             yield gen_cmd(args), gr.update(visible=False)

-    def preview_eval(self, *args) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
-        lang, model_name, dataset, _, args = self._parse_eval_args(*args)
-        error = self._initialize(lang, model_name, dataset)
+    def preview_eval(self, data: Dict[Component, Any]) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
+        lang, model_name, model_path, dataset, _, args = self._parse_eval_args(data)
+        error = self._initialize(lang, model_name, model_path, dataset)
         if error:
             yield error, gr.update(visible=False)
         else:
             yield gen_cmd(args), gr.update(visible=False)

-    def run_train(self, *args) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
-        lang, model_name, dataset, output_dir, args = self._parse_train_args(*args)
-        error = self._initialize(lang, model_name, dataset)
-        if error:
-            yield error, gr.update(visible=False)
-            return
-
-        self.running = True
-        run_kwargs = dict(args=args, callbacks=[self.trainer_callback])
-        thread = threading.Thread(target=run_exp, kwargs=run_kwargs)
-        thread.start()
-
-        while thread.is_alive():
-            time.sleep(2)
-            if self.aborted:
-                yield ALERTS["info_aborting"][lang], gr.update(visible=False)
-            else:
-                yield self.logger_handler.log, update_process_bar(self.trainer_callback)
-
-        if os.path.exists(os.path.join(output_dir, TRAINING_ARGS_NAME)):
-            finish_info = ALERTS["info_finished"][lang]
-        else:
-            finish_info = ALERTS["err_failed"][lang]
-
-        yield self._finalize(lang, finish_info), gr.update(visible=False)
-
-    def run_eval(self, *args) -> Generator[str, None, None]:
-        lang, model_name, dataset, output_dir, args = self._parse_eval_args(*args)
-        error = self._initialize(lang, model_name, dataset)
-        if error:
-            yield error, gr.update(visible=False)
-            return
-
-        self.running = True
-        run_kwargs = dict(args=args, callbacks=[self.trainer_callback])
-        thread = threading.Thread(target=run_exp, kwargs=run_kwargs)
-        thread.start()
-
-        while thread.is_alive():
-            time.sleep(2)
-            if self.aborted:
-                yield ALERTS["info_aborting"][lang], gr.update(visible=False)
-            else:
-                yield self.logger_handler.log, update_process_bar(self.trainer_callback)
-
-        if os.path.exists(os.path.join(output_dir, "all_results.json")):
-            finish_info = get_eval_results(os.path.join(output_dir, "all_results.json"))
-        else:
-            finish_info = ALERTS["err_failed"][lang]
-
-        yield self._finalize(lang, finish_info), gr.update(visible=False)
+    def run_train(self, data: Dict[Component, Any]) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
+        yield from self.prepare(data, is_training=True)
+
+    def run_eval(self, data: Dict[Component, Any]) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
+        yield from self.prepare(data, is_training=False)
+
+    def prepare(self, data: Dict[Component, Any], is_training: bool) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
+        parse_func = self._parse_train_args if is_training else self._parse_eval_args
+        lang, model_name, model_path, dataset, output_dir, args = parse_func(data)
+        error = self._initialize(lang, model_name, model_path, dataset)
+        if error:
+            yield error, gr.update(visible=False)
+        else:
+            self.running = True
+            run_kwargs = dict(args=args, callbacks=[self.trainer_callback])
+            self.thread = Thread(target=run_exp, kwargs=run_kwargs)
+            self.thread.start()
+            yield from self.monitor(lang, output_dir, is_training)
+
+    def monitor(self, lang: str, output_dir: str, is_training: bool) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
+        while self.thread.is_alive():
+            time.sleep(2)
+            if self.aborted:
+                yield ALERTS["info_aborting"][lang], gr.update(visible=False)
+            else:
+                yield self.logger_handler.log, update_process_bar(self.trainer_callback)
+
+        if is_training:
+            if os.path.exists(os.path.join(output_dir, TRAINING_ARGS_NAME)):
+                finish_info = ALERTS["info_finished"][lang]
+            else:
+                finish_info = ALERTS["err_failed"][lang]
+        else:
+            if os.path.exists(os.path.join(output_dir, "all_results.json")):
+                finish_info = get_eval_results(os.path.join(output_dir, "all_results.json"))
+            else:
+                finish_info = ALERTS["err_failed"][lang]

         yield self._finalize(lang, finish_info), gr.update(visible=False)
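
The run path above is a thread-plus-polling-generator pattern: `run_exp` executes on a background `Thread` while `monitor` polls shared state every two seconds and streams the accumulated log to the output box. Stripped of the llmtuner specifics, a minimal sketch:

import time
from threading import Thread
from typing import Generator, Optional


class MiniRunner:
    def __init__(self) -> None:
        self.thread: Optional[Thread] = None
        self.log = ""

    def _work(self) -> None:  # stands in for run_exp
        for step in range(3):
            time.sleep(2)
            self.log += "step {}\n".format(step)

    def run(self) -> Generator[str, None, None]:  # stands in for prepare + monitor
        self.thread = Thread(target=self._work)
        self.thread.start()
        while self.thread.is_alive():  # poll, as monitor() does
            time.sleep(2)
            yield self.log  # each yield streams a fresh snapshot to the UI
        self.thread = None  # cleared again, as _finalize() does
        yield self.log + "finished"


for update in MiniRunner().run():
    print(update)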

View File

@@ -8,7 +8,7 @@ from datetime import datetime

 from llmtuner.extras.ploting import smooth
 from llmtuner.tuner import export_model
-from llmtuner.webui.common import get_model_path, get_save_dir, DATA_CONFIG
+from llmtuner.webui.common import get_save_dir, DATA_CONFIG
 from llmtuner.webui.locales import ALERTS

 if TYPE_CHECKING:
@@ -119,6 +119,7 @@ def gen_plot(base_model: str, finetuning_type: str, output_dir: str) -> matplotlib.figure.Figure:
 def save_model(
     lang: str,
     model_name: str,
+    model_path: str,
     checkpoints: List[str],
     finetuning_type: str,
     template: str,
@@ -129,8 +130,7 @@ def save_model(
         yield ALERTS["err_no_model"][lang]
         return

-    model_name_or_path = get_model_path(model_name)
-    if not model_name_or_path:
+    if not model_path:
         yield ALERTS["err_no_path"][lang]
         return

@@ -138,17 +138,13 @@ def save_model(
         yield ALERTS["err_no_checkpoint"][lang]
         return

-    checkpoint_dir = ",".join(
-        [get_save_dir(model_name, finetuning_type, ckpt) for ckpt in checkpoints]
-    )
-
     if not save_dir:
         yield ALERTS["err_no_save_dir"][lang]
         return

     args = dict(
-        model_name_or_path=model_name_or_path,
-        checkpoint_dir=checkpoint_dir,
+        model_name_or_path=model_path,
+        checkpoint_dir=",".join([get_save_dir(model_name, finetuning_type, ckpt) for ckpt in checkpoints]),
         finetuning_type=finetuning_type,
         template=template,
         output_dir=save_dir