mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-12-16 03:40:34 +08:00
16
src/llamafactory/webui/components/__init__.py
Normal file
16
src/llamafactory/webui/components/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
from .chatbot import create_chat_box
|
||||
from .eval import create_eval_tab
|
||||
from .export import create_export_tab
|
||||
from .infer import create_infer_tab
|
||||
from .top import create_top
|
||||
from .train import create_train_tab
|
||||
|
||||
|
||||
__all__ = [
|
||||
"create_chat_box",
|
||||
"create_eval_tab",
|
||||
"create_export_tab",
|
||||
"create_infer_tab",
|
||||
"create_top",
|
||||
"create_train_tab",
|
||||
]
|
||||
74
src/llamafactory/webui/components/chatbot.py
Normal file
74
src/llamafactory/webui/components/chatbot.py
Normal file
@@ -0,0 +1,74 @@
|
||||
from typing import TYPE_CHECKING, Dict, Tuple
|
||||
|
||||
from ...data import Role
|
||||
from ...extras.packages import is_gradio_available
|
||||
from ..utils import check_json_schema
|
||||
|
||||
|
||||
if is_gradio_available():
|
||||
import gradio as gr
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from gradio.components import Component
|
||||
|
||||
from ..engine import Engine
|
||||
|
||||
|
||||
def create_chat_box(
|
||||
engine: "Engine", visible: bool = False
|
||||
) -> Tuple["Component", "Component", Dict[str, "Component"]]:
|
||||
with gr.Column(visible=visible) as chat_box:
|
||||
chatbot = gr.Chatbot(show_copy_button=True)
|
||||
messages = gr.State([])
|
||||
with gr.Row():
|
||||
with gr.Column(scale=4):
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
role = gr.Dropdown(choices=[Role.USER.value, Role.OBSERVATION.value], value=Role.USER.value)
|
||||
system = gr.Textbox(show_label=False)
|
||||
tools = gr.Textbox(show_label=False, lines=3)
|
||||
|
||||
with gr.Column() as image_box:
|
||||
image = gr.Image(sources=["upload"], type="numpy")
|
||||
|
||||
query = gr.Textbox(show_label=False, lines=8)
|
||||
submit_btn = gr.Button(variant="primary")
|
||||
|
||||
with gr.Column(scale=1):
|
||||
max_new_tokens = gr.Slider(minimum=8, maximum=4096, value=512, step=1)
|
||||
top_p = gr.Slider(minimum=0.01, maximum=1.0, value=0.7, step=0.01)
|
||||
temperature = gr.Slider(minimum=0.01, maximum=1.5, value=0.95, step=0.01)
|
||||
clear_btn = gr.Button()
|
||||
|
||||
tools.input(check_json_schema, inputs=[tools, engine.manager.get_elem_by_id("top.lang")])
|
||||
|
||||
submit_btn.click(
|
||||
engine.chatter.append,
|
||||
[chatbot, messages, role, query],
|
||||
[chatbot, messages, query],
|
||||
).then(
|
||||
engine.chatter.stream,
|
||||
[chatbot, messages, system, tools, image, max_new_tokens, top_p, temperature],
|
||||
[chatbot, messages],
|
||||
)
|
||||
clear_btn.click(lambda: ([], []), outputs=[chatbot, messages])
|
||||
|
||||
return (
|
||||
chatbot,
|
||||
messages,
|
||||
dict(
|
||||
chat_box=chat_box,
|
||||
role=role,
|
||||
system=system,
|
||||
tools=tools,
|
||||
image_box=image_box,
|
||||
image=image,
|
||||
query=query,
|
||||
submit_btn=submit_btn,
|
||||
max_new_tokens=max_new_tokens,
|
||||
top_p=top_p,
|
||||
temperature=temperature,
|
||||
clear_btn=clear_btn,
|
||||
),
|
||||
)
|
||||
106
src/llamafactory/webui/components/data.py
Normal file
106
src/llamafactory/webui/components/data.py
Normal file
@@ -0,0 +1,106 @@
|
||||
import json
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Tuple
|
||||
|
||||
from ...extras.constants import DATA_CONFIG
|
||||
from ...extras.packages import is_gradio_available
|
||||
|
||||
|
||||
if is_gradio_available():
|
||||
import gradio as gr
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from gradio.components import Component
|
||||
|
||||
|
||||
PAGE_SIZE = 2
|
||||
|
||||
|
||||
def prev_page(page_index: int) -> int:
|
||||
return page_index - 1 if page_index > 0 else page_index
|
||||
|
||||
|
||||
def next_page(page_index: int, total_num: int) -> int:
|
||||
return page_index + 1 if (page_index + 1) * PAGE_SIZE < total_num else page_index
|
||||
|
||||
|
||||
def can_preview(dataset_dir: str, dataset: list) -> "gr.Button":
|
||||
try:
|
||||
with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
|
||||
dataset_info = json.load(f)
|
||||
except Exception:
|
||||
return gr.Button(interactive=False)
|
||||
|
||||
if len(dataset) == 0 or "file_name" not in dataset_info[dataset[0]]:
|
||||
return gr.Button(interactive=False)
|
||||
|
||||
data_path = os.path.join(dataset_dir, dataset_info[dataset[0]]["file_name"])
|
||||
if os.path.isfile(data_path) or (os.path.isdir(data_path) and os.listdir(data_path)):
|
||||
return gr.Button(interactive=True)
|
||||
else:
|
||||
return gr.Button(interactive=False)
|
||||
|
||||
|
||||
def _load_data_file(file_path: str) -> List[Any]:
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
if file_path.endswith(".json"):
|
||||
return json.load(f)
|
||||
elif file_path.endswith(".jsonl"):
|
||||
return [json.loads(line) for line in f]
|
||||
else:
|
||||
return list(f)
|
||||
|
||||
|
||||
def get_preview(dataset_dir: str, dataset: list, page_index: int) -> Tuple[int, list, "gr.Column"]:
|
||||
with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
|
||||
dataset_info = json.load(f)
|
||||
|
||||
data_path = os.path.join(dataset_dir, dataset_info[dataset[0]]["file_name"])
|
||||
if os.path.isfile(data_path):
|
||||
data = _load_data_file(data_path)
|
||||
else:
|
||||
data = []
|
||||
for file_name in os.listdir(data_path):
|
||||
data.extend(_load_data_file(os.path.join(data_path, file_name)))
|
||||
|
||||
return len(data), data[PAGE_SIZE * page_index : PAGE_SIZE * (page_index + 1)], gr.Column(visible=True)
|
||||
|
||||
|
||||
def create_preview_box(dataset_dir: "gr.Textbox", dataset: "gr.Dropdown") -> Dict[str, "Component"]:
|
||||
data_preview_btn = gr.Button(interactive=False, scale=1)
|
||||
with gr.Column(visible=False, elem_classes="modal-box") as preview_box:
|
||||
with gr.Row():
|
||||
preview_count = gr.Number(value=0, interactive=False, precision=0)
|
||||
page_index = gr.Number(value=0, interactive=False, precision=0)
|
||||
|
||||
with gr.Row():
|
||||
prev_btn = gr.Button()
|
||||
next_btn = gr.Button()
|
||||
close_btn = gr.Button()
|
||||
|
||||
with gr.Row():
|
||||
preview_samples = gr.JSON()
|
||||
|
||||
dataset.change(can_preview, [dataset_dir, dataset], [data_preview_btn], queue=False).then(
|
||||
lambda: 0, outputs=[page_index], queue=False
|
||||
)
|
||||
data_preview_btn.click(
|
||||
get_preview, [dataset_dir, dataset, page_index], [preview_count, preview_samples, preview_box], queue=False
|
||||
)
|
||||
prev_btn.click(prev_page, [page_index], [page_index], queue=False).then(
|
||||
get_preview, [dataset_dir, dataset, page_index], [preview_count, preview_samples, preview_box], queue=False
|
||||
)
|
||||
next_btn.click(next_page, [page_index, preview_count], [page_index], queue=False).then(
|
||||
get_preview, [dataset_dir, dataset, page_index], [preview_count, preview_samples, preview_box], queue=False
|
||||
)
|
||||
close_btn.click(lambda: gr.Column(visible=False), outputs=[preview_box], queue=False)
|
||||
return dict(
|
||||
data_preview_btn=data_preview_btn,
|
||||
preview_count=preview_count,
|
||||
page_index=page_index,
|
||||
prev_btn=prev_btn,
|
||||
next_btn=next_btn,
|
||||
close_btn=close_btn,
|
||||
preview_samples=preview_samples,
|
||||
)
|
||||
79
src/llamafactory/webui/components/eval.py
Normal file
79
src/llamafactory/webui/components/eval.py
Normal file
@@ -0,0 +1,79 @@
|
||||
from typing import TYPE_CHECKING, Dict
|
||||
|
||||
from ...extras.packages import is_gradio_available
|
||||
from ..common import DEFAULT_DATA_DIR, list_dataset
|
||||
from .data import create_preview_box
|
||||
|
||||
|
||||
if is_gradio_available():
|
||||
import gradio as gr
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from gradio.components import Component
|
||||
|
||||
from ..engine import Engine
|
||||
|
||||
|
||||
def create_eval_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||
input_elems = engine.manager.get_base_elems()
|
||||
elem_dict = dict()
|
||||
|
||||
with gr.Row():
|
||||
dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=2)
|
||||
dataset = gr.Dropdown(multiselect=True, allow_custom_value=True, scale=4)
|
||||
preview_elems = create_preview_box(dataset_dir, dataset)
|
||||
|
||||
input_elems.update({dataset_dir, dataset})
|
||||
elem_dict.update(dict(dataset_dir=dataset_dir, dataset=dataset, **preview_elems))
|
||||
|
||||
with gr.Row():
|
||||
cutoff_len = gr.Slider(minimum=4, maximum=65536, value=1024, step=1)
|
||||
max_samples = gr.Textbox(value="100000")
|
||||
batch_size = gr.Slider(minimum=1, maximum=1024, value=2, step=1)
|
||||
predict = gr.Checkbox(value=True)
|
||||
|
||||
input_elems.update({cutoff_len, max_samples, batch_size, predict})
|
||||
elem_dict.update(dict(cutoff_len=cutoff_len, max_samples=max_samples, batch_size=batch_size, predict=predict))
|
||||
|
||||
with gr.Row():
|
||||
max_new_tokens = gr.Slider(minimum=8, maximum=4096, value=512, step=1)
|
||||
top_p = gr.Slider(minimum=0.01, maximum=1, value=0.7, step=0.01)
|
||||
temperature = gr.Slider(minimum=0.01, maximum=1.5, value=0.95, step=0.01)
|
||||
output_dir = gr.Textbox()
|
||||
|
||||
input_elems.update({max_new_tokens, top_p, temperature, output_dir})
|
||||
elem_dict.update(dict(max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature, output_dir=output_dir))
|
||||
|
||||
with gr.Row():
|
||||
cmd_preview_btn = gr.Button()
|
||||
start_btn = gr.Button(variant="primary")
|
||||
stop_btn = gr.Button(variant="stop")
|
||||
|
||||
with gr.Row():
|
||||
resume_btn = gr.Checkbox(visible=False, interactive=False)
|
||||
progress_bar = gr.Slider(visible=False, interactive=False)
|
||||
|
||||
with gr.Row():
|
||||
output_box = gr.Markdown()
|
||||
|
||||
output_elems = [output_box, progress_bar]
|
||||
elem_dict.update(
|
||||
dict(
|
||||
cmd_preview_btn=cmd_preview_btn,
|
||||
start_btn=start_btn,
|
||||
stop_btn=stop_btn,
|
||||
resume_btn=resume_btn,
|
||||
progress_bar=progress_bar,
|
||||
output_box=output_box,
|
||||
)
|
||||
)
|
||||
|
||||
cmd_preview_btn.click(engine.runner.preview_eval, input_elems, output_elems, concurrency_limit=None)
|
||||
start_btn.click(engine.runner.run_eval, input_elems, output_elems)
|
||||
stop_btn.click(engine.runner.set_abort)
|
||||
resume_btn.change(engine.runner.monitor, outputs=output_elems, concurrency_limit=None)
|
||||
|
||||
dataset_dir.change(list_dataset, [dataset_dir], [dataset], queue=False)
|
||||
|
||||
return elem_dict
|
||||
132
src/llamafactory/webui/components/export.py
Normal file
132
src/llamafactory/webui/components/export.py
Normal file
@@ -0,0 +1,132 @@
|
||||
from typing import TYPE_CHECKING, Dict, Generator, List
|
||||
|
||||
from ...extras.misc import torch_gc
|
||||
from ...extras.packages import is_gradio_available
|
||||
from ...train.tuner import export_model
|
||||
from ..common import get_save_dir
|
||||
from ..locales import ALERTS
|
||||
|
||||
|
||||
if is_gradio_available():
|
||||
import gradio as gr
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from gradio.components import Component
|
||||
|
||||
from ..engine import Engine
|
||||
|
||||
|
||||
GPTQ_BITS = ["8", "4", "3", "2"]
|
||||
|
||||
|
||||
def save_model(
|
||||
lang: str,
|
||||
model_name: str,
|
||||
model_path: str,
|
||||
adapter_path: List[str],
|
||||
finetuning_type: str,
|
||||
template: str,
|
||||
visual_inputs: bool,
|
||||
export_size: int,
|
||||
export_quantization_bit: int,
|
||||
export_quantization_dataset: str,
|
||||
export_device: str,
|
||||
export_legacy_format: bool,
|
||||
export_dir: str,
|
||||
export_hub_model_id: str,
|
||||
) -> Generator[str, None, None]:
|
||||
error = ""
|
||||
if not model_name:
|
||||
error = ALERTS["err_no_model"][lang]
|
||||
elif not model_path:
|
||||
error = ALERTS["err_no_path"][lang]
|
||||
elif not export_dir:
|
||||
error = ALERTS["err_no_export_dir"][lang]
|
||||
elif export_quantization_bit in GPTQ_BITS and not export_quantization_dataset:
|
||||
error = ALERTS["err_no_dataset"][lang]
|
||||
elif export_quantization_bit not in GPTQ_BITS and not adapter_path:
|
||||
error = ALERTS["err_no_adapter"][lang]
|
||||
elif export_quantization_bit in GPTQ_BITS and adapter_path:
|
||||
error = ALERTS["err_gptq_lora"][lang]
|
||||
|
||||
if error:
|
||||
gr.Warning(error)
|
||||
yield error
|
||||
return
|
||||
|
||||
if adapter_path:
|
||||
adapter_name_or_path = ",".join(
|
||||
[get_save_dir(model_name, finetuning_type, adapter) for adapter in adapter_path]
|
||||
)
|
||||
else:
|
||||
adapter_name_or_path = None
|
||||
|
||||
args = dict(
|
||||
model_name_or_path=model_path,
|
||||
adapter_name_or_path=adapter_name_or_path,
|
||||
finetuning_type=finetuning_type,
|
||||
template=template,
|
||||
visual_inputs=visual_inputs,
|
||||
export_dir=export_dir,
|
||||
export_hub_model_id=export_hub_model_id or None,
|
||||
export_size=export_size,
|
||||
export_quantization_bit=int(export_quantization_bit) if export_quantization_bit in GPTQ_BITS else None,
|
||||
export_quantization_dataset=export_quantization_dataset,
|
||||
export_device=export_device,
|
||||
export_legacy_format=export_legacy_format,
|
||||
)
|
||||
|
||||
yield ALERTS["info_exporting"][lang]
|
||||
export_model(args)
|
||||
torch_gc()
|
||||
yield ALERTS["info_exported"][lang]
|
||||
|
||||
|
||||
def create_export_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||
with gr.Row():
|
||||
export_size = gr.Slider(minimum=1, maximum=100, value=1, step=1)
|
||||
export_quantization_bit = gr.Dropdown(choices=["none", "8", "4", "3", "2"], value="none")
|
||||
export_quantization_dataset = gr.Textbox(value="data/c4_demo.json")
|
||||
export_device = gr.Radio(choices=["cpu", "cuda"], value="cpu")
|
||||
export_legacy_format = gr.Checkbox()
|
||||
|
||||
with gr.Row():
|
||||
export_dir = gr.Textbox()
|
||||
export_hub_model_id = gr.Textbox()
|
||||
|
||||
export_btn = gr.Button()
|
||||
info_box = gr.Textbox(show_label=False, interactive=False)
|
||||
|
||||
export_btn.click(
|
||||
save_model,
|
||||
[
|
||||
engine.manager.get_elem_by_id("top.lang"),
|
||||
engine.manager.get_elem_by_id("top.model_name"),
|
||||
engine.manager.get_elem_by_id("top.model_path"),
|
||||
engine.manager.get_elem_by_id("top.adapter_path"),
|
||||
engine.manager.get_elem_by_id("top.finetuning_type"),
|
||||
engine.manager.get_elem_by_id("top.template"),
|
||||
engine.manager.get_elem_by_id("top.visual_inputs"),
|
||||
export_size,
|
||||
export_quantization_bit,
|
||||
export_quantization_dataset,
|
||||
export_device,
|
||||
export_legacy_format,
|
||||
export_dir,
|
||||
export_hub_model_id,
|
||||
],
|
||||
[info_box],
|
||||
)
|
||||
|
||||
return dict(
|
||||
export_size=export_size,
|
||||
export_quantization_bit=export_quantization_bit,
|
||||
export_quantization_dataset=export_quantization_dataset,
|
||||
export_device=export_device,
|
||||
export_legacy_format=export_legacy_format,
|
||||
export_dir=export_dir,
|
||||
export_hub_model_id=export_hub_model_id,
|
||||
export_btn=export_btn,
|
||||
info_box=info_box,
|
||||
)
|
||||
48
src/llamafactory/webui/components/infer.py
Normal file
48
src/llamafactory/webui/components/infer.py
Normal file
@@ -0,0 +1,48 @@
|
||||
from typing import TYPE_CHECKING, Dict
|
||||
|
||||
from ...extras.packages import is_gradio_available
|
||||
from .chatbot import create_chat_box
|
||||
|
||||
|
||||
if is_gradio_available():
|
||||
import gradio as gr
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from gradio.components import Component
|
||||
|
||||
from ..engine import Engine
|
||||
|
||||
|
||||
def create_infer_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||
input_elems = engine.manager.get_base_elems()
|
||||
elem_dict = dict()
|
||||
|
||||
infer_backend = gr.Dropdown(choices=["huggingface", "vllm"], value="huggingface")
|
||||
with gr.Row():
|
||||
load_btn = gr.Button()
|
||||
unload_btn = gr.Button()
|
||||
|
||||
info_box = gr.Textbox(show_label=False, interactive=False)
|
||||
|
||||
input_elems.update({infer_backend})
|
||||
elem_dict.update(dict(infer_backend=infer_backend, load_btn=load_btn, unload_btn=unload_btn, info_box=info_box))
|
||||
|
||||
chatbot, messages, chat_elems = create_chat_box(engine, visible=False)
|
||||
elem_dict.update(chat_elems)
|
||||
|
||||
load_btn.click(engine.chatter.load_model, input_elems, [info_box]).then(
|
||||
lambda: gr.Column(visible=engine.chatter.loaded), outputs=[chat_elems["chat_box"]]
|
||||
)
|
||||
|
||||
unload_btn.click(engine.chatter.unload_model, input_elems, [info_box]).then(
|
||||
lambda: ([], []), outputs=[chatbot, messages]
|
||||
).then(lambda: gr.Column(visible=engine.chatter.loaded), outputs=[chat_elems["chat_box"]])
|
||||
|
||||
engine.manager.get_elem_by_id("top.visual_inputs").change(
|
||||
lambda enabled: gr.Column(visible=enabled),
|
||||
[engine.manager.get_elem_by_id("top.visual_inputs")],
|
||||
[chat_elems["image_box"]],
|
||||
)
|
||||
|
||||
return elem_dict
|
||||
66
src/llamafactory/webui/components/top.py
Normal file
66
src/llamafactory/webui/components/top.py
Normal file
@@ -0,0 +1,66 @@
|
||||
from typing import TYPE_CHECKING, Dict
|
||||
|
||||
from ...data import templates
|
||||
from ...extras.constants import METHODS, SUPPORTED_MODELS
|
||||
from ...extras.packages import is_gradio_available
|
||||
from ..common import get_model_path, get_template, get_visual, list_adapters, save_config
|
||||
from ..utils import can_quantize
|
||||
|
||||
|
||||
if is_gradio_available():
|
||||
import gradio as gr
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from gradio.components import Component
|
||||
|
||||
|
||||
def create_top() -> Dict[str, "Component"]:
|
||||
available_models = list(SUPPORTED_MODELS.keys()) + ["Custom"]
|
||||
|
||||
with gr.Row():
|
||||
lang = gr.Dropdown(choices=["en", "ru", "zh"], scale=1)
|
||||
model_name = gr.Dropdown(choices=available_models, scale=3)
|
||||
model_path = gr.Textbox(scale=3)
|
||||
|
||||
with gr.Row():
|
||||
finetuning_type = gr.Dropdown(choices=METHODS, value="lora", scale=1)
|
||||
adapter_path = gr.Dropdown(multiselect=True, allow_custom_value=True, scale=5)
|
||||
refresh_btn = gr.Button(scale=1)
|
||||
|
||||
with gr.Accordion(open=False) as advanced_tab:
|
||||
with gr.Row():
|
||||
quantization_bit = gr.Dropdown(choices=["none", "8", "4"], value="none", scale=2)
|
||||
template = gr.Dropdown(choices=list(templates.keys()), value="default", scale=2)
|
||||
rope_scaling = gr.Radio(choices=["none", "linear", "dynamic"], value="none", scale=3)
|
||||
booster = gr.Radio(choices=["none", "flashattn2", "unsloth"], value="none", scale=3)
|
||||
visual_inputs = gr.Checkbox(scale=1)
|
||||
|
||||
model_name.change(list_adapters, [model_name, finetuning_type], [adapter_path], queue=False).then(
|
||||
get_model_path, [model_name], [model_path], queue=False
|
||||
).then(get_template, [model_name], [template], queue=False).then(
|
||||
get_visual, [model_name], [visual_inputs], queue=False
|
||||
) # do not save config since the below line will save
|
||||
|
||||
model_path.change(save_config, inputs=[lang, model_name, model_path], queue=False)
|
||||
|
||||
finetuning_type.change(list_adapters, [model_name, finetuning_type], [adapter_path], queue=False).then(
|
||||
can_quantize, [finetuning_type], [quantization_bit], queue=False
|
||||
)
|
||||
|
||||
refresh_btn.click(list_adapters, [model_name, finetuning_type], [adapter_path], queue=False)
|
||||
|
||||
return dict(
|
||||
lang=lang,
|
||||
model_name=model_name,
|
||||
model_path=model_path,
|
||||
finetuning_type=finetuning_type,
|
||||
adapter_path=adapter_path,
|
||||
refresh_btn=refresh_btn,
|
||||
advanced_tab=advanced_tab,
|
||||
quantization_bit=quantization_bit,
|
||||
template=template,
|
||||
rope_scaling=rope_scaling,
|
||||
booster=booster,
|
||||
visual_inputs=visual_inputs,
|
||||
)
|
||||
299
src/llamafactory/webui/components/train.py
Normal file
299
src/llamafactory/webui/components/train.py
Normal file
@@ -0,0 +1,299 @@
|
||||
from typing import TYPE_CHECKING, Dict
|
||||
|
||||
from transformers.trainer_utils import SchedulerType
|
||||
|
||||
from ...extras.constants import TRAINING_STAGES
|
||||
from ...extras.packages import is_gradio_available
|
||||
from ..common import DEFAULT_DATA_DIR, autoset_packing, list_adapters, list_dataset
|
||||
from ..components.data import create_preview_box
|
||||
|
||||
|
||||
if is_gradio_available():
|
||||
import gradio as gr
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from gradio.components import Component
|
||||
|
||||
from ..engine import Engine
|
||||
|
||||
|
||||
def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||
input_elems = engine.manager.get_base_elems()
|
||||
elem_dict = dict()
|
||||
|
||||
with gr.Row():
|
||||
training_stage = gr.Dropdown(
|
||||
choices=list(TRAINING_STAGES.keys()), value=list(TRAINING_STAGES.keys())[0], scale=1
|
||||
)
|
||||
dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=1)
|
||||
dataset = gr.Dropdown(multiselect=True, allow_custom_value=True, scale=4)
|
||||
preview_elems = create_preview_box(dataset_dir, dataset)
|
||||
|
||||
input_elems.update({training_stage, dataset_dir, dataset})
|
||||
elem_dict.update(dict(training_stage=training_stage, dataset_dir=dataset_dir, dataset=dataset, **preview_elems))
|
||||
|
||||
with gr.Row():
|
||||
learning_rate = gr.Textbox(value="5e-5")
|
||||
num_train_epochs = gr.Textbox(value="3.0")
|
||||
max_grad_norm = gr.Textbox(value="1.0")
|
||||
max_samples = gr.Textbox(value="100000")
|
||||
compute_type = gr.Dropdown(choices=["fp16", "bf16", "fp32", "pure_bf16"], value="fp16")
|
||||
|
||||
input_elems.update({learning_rate, num_train_epochs, max_grad_norm, max_samples, compute_type})
|
||||
elem_dict.update(
|
||||
dict(
|
||||
learning_rate=learning_rate,
|
||||
num_train_epochs=num_train_epochs,
|
||||
max_grad_norm=max_grad_norm,
|
||||
max_samples=max_samples,
|
||||
compute_type=compute_type,
|
||||
)
|
||||
)
|
||||
|
||||
with gr.Row():
|
||||
cutoff_len = gr.Slider(minimum=4, maximum=65536, value=1024, step=1)
|
||||
batch_size = gr.Slider(minimum=1, maximum=1024, value=2, step=1)
|
||||
gradient_accumulation_steps = gr.Slider(minimum=1, maximum=1024, value=8, step=1)
|
||||
val_size = gr.Slider(minimum=0, maximum=1, value=0, step=0.001)
|
||||
lr_scheduler_type = gr.Dropdown(choices=[scheduler.value for scheduler in SchedulerType], value="cosine")
|
||||
|
||||
input_elems.update({cutoff_len, batch_size, gradient_accumulation_steps, val_size, lr_scheduler_type})
|
||||
elem_dict.update(
|
||||
dict(
|
||||
cutoff_len=cutoff_len,
|
||||
batch_size=batch_size,
|
||||
gradient_accumulation_steps=gradient_accumulation_steps,
|
||||
val_size=val_size,
|
||||
lr_scheduler_type=lr_scheduler_type,
|
||||
)
|
||||
)
|
||||
|
||||
with gr.Accordion(open=False) as extra_tab:
|
||||
with gr.Row():
|
||||
logging_steps = gr.Slider(minimum=1, maximum=1000, value=5, step=5)
|
||||
save_steps = gr.Slider(minimum=10, maximum=5000, value=100, step=10)
|
||||
warmup_steps = gr.Slider(minimum=0, maximum=5000, value=0, step=1)
|
||||
neftune_alpha = gr.Slider(minimum=0, maximum=10, value=0, step=0.1)
|
||||
optim = gr.Textbox(value="adamw_torch")
|
||||
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
resize_vocab = gr.Checkbox()
|
||||
packing = gr.Checkbox()
|
||||
|
||||
with gr.Column():
|
||||
upcast_layernorm = gr.Checkbox()
|
||||
use_llama_pro = gr.Checkbox()
|
||||
|
||||
with gr.Column():
|
||||
shift_attn = gr.Checkbox()
|
||||
report_to = gr.Checkbox()
|
||||
|
||||
input_elems.update(
|
||||
{
|
||||
logging_steps,
|
||||
save_steps,
|
||||
warmup_steps,
|
||||
neftune_alpha,
|
||||
optim,
|
||||
resize_vocab,
|
||||
packing,
|
||||
upcast_layernorm,
|
||||
use_llama_pro,
|
||||
shift_attn,
|
||||
report_to,
|
||||
}
|
||||
)
|
||||
elem_dict.update(
|
||||
dict(
|
||||
extra_tab=extra_tab,
|
||||
logging_steps=logging_steps,
|
||||
save_steps=save_steps,
|
||||
warmup_steps=warmup_steps,
|
||||
neftune_alpha=neftune_alpha,
|
||||
optim=optim,
|
||||
resize_vocab=resize_vocab,
|
||||
packing=packing,
|
||||
upcast_layernorm=upcast_layernorm,
|
||||
use_llama_pro=use_llama_pro,
|
||||
shift_attn=shift_attn,
|
||||
report_to=report_to,
|
||||
)
|
||||
)
|
||||
|
||||
with gr.Accordion(open=False) as freeze_tab:
|
||||
with gr.Row():
|
||||
freeze_trainable_layers = gr.Slider(minimum=-128, maximum=128, value=2, step=1)
|
||||
freeze_trainable_modules = gr.Textbox(value="all")
|
||||
freeze_extra_modules = gr.Textbox()
|
||||
|
||||
input_elems.update({freeze_trainable_layers, freeze_trainable_modules, freeze_extra_modules})
|
||||
elem_dict.update(
|
||||
dict(
|
||||
freeze_tab=freeze_tab,
|
||||
freeze_trainable_layers=freeze_trainable_layers,
|
||||
freeze_trainable_modules=freeze_trainable_modules,
|
||||
freeze_extra_modules=freeze_extra_modules,
|
||||
)
|
||||
)
|
||||
|
||||
with gr.Accordion(open=False) as lora_tab:
|
||||
with gr.Row():
|
||||
lora_rank = gr.Slider(minimum=1, maximum=1024, value=8, step=1)
|
||||
lora_alpha = gr.Slider(minimum=1, maximum=2048, value=16, step=1)
|
||||
lora_dropout = gr.Slider(minimum=0, maximum=1, value=0, step=0.01)
|
||||
loraplus_lr_ratio = gr.Slider(minimum=0, maximum=64, value=0, step=0.01)
|
||||
create_new_adapter = gr.Checkbox()
|
||||
|
||||
with gr.Row():
|
||||
with gr.Column(scale=1):
|
||||
use_rslora = gr.Checkbox()
|
||||
use_dora = gr.Checkbox()
|
||||
|
||||
lora_target = gr.Textbox(scale=2)
|
||||
additional_target = gr.Textbox(scale=2)
|
||||
|
||||
input_elems.update(
|
||||
{
|
||||
lora_rank,
|
||||
lora_alpha,
|
||||
lora_dropout,
|
||||
loraplus_lr_ratio,
|
||||
create_new_adapter,
|
||||
use_rslora,
|
||||
use_dora,
|
||||
lora_target,
|
||||
additional_target,
|
||||
}
|
||||
)
|
||||
elem_dict.update(
|
||||
dict(
|
||||
lora_tab=lora_tab,
|
||||
lora_rank=lora_rank,
|
||||
lora_alpha=lora_alpha,
|
||||
lora_dropout=lora_dropout,
|
||||
loraplus_lr_ratio=loraplus_lr_ratio,
|
||||
create_new_adapter=create_new_adapter,
|
||||
use_rslora=use_rslora,
|
||||
use_dora=use_dora,
|
||||
lora_target=lora_target,
|
||||
additional_target=additional_target,
|
||||
)
|
||||
)
|
||||
|
||||
with gr.Accordion(open=False) as rlhf_tab:
|
||||
with gr.Row():
|
||||
dpo_beta = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.01)
|
||||
dpo_ftx = gr.Slider(minimum=0, maximum=10, value=0, step=0.01)
|
||||
orpo_beta = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.01)
|
||||
reward_model = gr.Dropdown(multiselect=True, allow_custom_value=True)
|
||||
|
||||
input_elems.update({dpo_beta, dpo_ftx, orpo_beta, reward_model})
|
||||
elem_dict.update(
|
||||
dict(rlhf_tab=rlhf_tab, dpo_beta=dpo_beta, dpo_ftx=dpo_ftx, orpo_beta=orpo_beta, reward_model=reward_model)
|
||||
)
|
||||
|
||||
with gr.Accordion(open=False) as galore_tab:
|
||||
with gr.Row():
|
||||
use_galore = gr.Checkbox()
|
||||
galore_rank = gr.Slider(minimum=1, maximum=1024, value=16, step=1)
|
||||
galore_update_interval = gr.Slider(minimum=1, maximum=1024, value=200, step=1)
|
||||
galore_scale = gr.Slider(minimum=0, maximum=1, value=0.25, step=0.01)
|
||||
galore_target = gr.Textbox(value="all")
|
||||
|
||||
input_elems.update({use_galore, galore_rank, galore_update_interval, galore_scale, galore_target})
|
||||
elem_dict.update(
|
||||
dict(
|
||||
galore_tab=galore_tab,
|
||||
use_galore=use_galore,
|
||||
galore_rank=galore_rank,
|
||||
galore_update_interval=galore_update_interval,
|
||||
galore_scale=galore_scale,
|
||||
galore_target=galore_target,
|
||||
)
|
||||
)
|
||||
|
||||
with gr.Accordion(open=False) as badam_tab:
|
||||
with gr.Row():
|
||||
use_badam = gr.Checkbox()
|
||||
badam_mode = gr.Dropdown(choices=["layer", "ratio"], value="layer")
|
||||
badam_switch_mode = gr.Dropdown(choices=["ascending", "descending", "random", "fixed"], value="ascending")
|
||||
badam_switch_interval = gr.Slider(minimum=1, maximum=1024, value=50, step=1)
|
||||
badam_update_ratio = gr.Slider(minimum=0, maximum=1, value=0.05, step=0.01)
|
||||
|
||||
input_elems.update({use_badam, badam_mode, badam_switch_mode, badam_switch_interval, badam_update_ratio})
|
||||
elem_dict.update(
|
||||
dict(
|
||||
badam_tab=badam_tab,
|
||||
use_badam=use_badam,
|
||||
badam_mode=badam_mode,
|
||||
badam_switch_mode=badam_switch_mode,
|
||||
badam_switch_interval=badam_switch_interval,
|
||||
badam_update_ratio=badam_update_ratio,
|
||||
)
|
||||
)
|
||||
|
||||
with gr.Row():
|
||||
cmd_preview_btn = gr.Button()
|
||||
arg_save_btn = gr.Button()
|
||||
arg_load_btn = gr.Button()
|
||||
start_btn = gr.Button(variant="primary")
|
||||
stop_btn = gr.Button(variant="stop")
|
||||
|
||||
with gr.Row():
|
||||
with gr.Column(scale=3):
|
||||
with gr.Row():
|
||||
output_dir = gr.Textbox()
|
||||
config_path = gr.Textbox()
|
||||
|
||||
with gr.Row():
|
||||
resume_btn = gr.Checkbox(visible=False, interactive=False)
|
||||
progress_bar = gr.Slider(visible=False, interactive=False)
|
||||
|
||||
with gr.Row():
|
||||
output_box = gr.Markdown()
|
||||
|
||||
with gr.Column(scale=1):
|
||||
loss_viewer = gr.Plot()
|
||||
|
||||
elem_dict.update(
|
||||
dict(
|
||||
cmd_preview_btn=cmd_preview_btn,
|
||||
arg_save_btn=arg_save_btn,
|
||||
arg_load_btn=arg_load_btn,
|
||||
start_btn=start_btn,
|
||||
stop_btn=stop_btn,
|
||||
output_dir=output_dir,
|
||||
config_path=config_path,
|
||||
resume_btn=resume_btn,
|
||||
progress_bar=progress_bar,
|
||||
output_box=output_box,
|
||||
loss_viewer=loss_viewer,
|
||||
)
|
||||
)
|
||||
|
||||
input_elems.update({output_dir, config_path})
|
||||
output_elems = [output_box, progress_bar, loss_viewer]
|
||||
|
||||
cmd_preview_btn.click(engine.runner.preview_train, input_elems, output_elems, concurrency_limit=None)
|
||||
arg_save_btn.click(engine.runner.save_args, input_elems, output_elems, concurrency_limit=None)
|
||||
arg_load_btn.click(
|
||||
engine.runner.load_args,
|
||||
[engine.manager.get_elem_by_id("top.lang"), config_path],
|
||||
list(input_elems) + [output_box],
|
||||
concurrency_limit=None,
|
||||
)
|
||||
start_btn.click(engine.runner.run_train, input_elems, output_elems)
|
||||
stop_btn.click(engine.runner.set_abort)
|
||||
resume_btn.change(engine.runner.monitor, outputs=output_elems, concurrency_limit=None)
|
||||
|
||||
dataset_dir.change(list_dataset, [dataset_dir, training_stage], [dataset], queue=False)
|
||||
training_stage.change(list_dataset, [dataset_dir, training_stage], [dataset], queue=False).then(
|
||||
list_adapters,
|
||||
[engine.manager.get_elem_by_id("top.model_name"), engine.manager.get_elem_by_id("top.finetuning_type")],
|
||||
[reward_model],
|
||||
queue=False,
|
||||
).then(autoset_packing, [training_stage], [packing], queue=False)
|
||||
|
||||
return elem_dict
|
||||
Reference in New Issue
Block a user