mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-22 13:42:51 +08:00
support pagination in webui preview
Former-commit-id: c1edb0cf1b2a4d52506fc9e15353dfbe513e5d8f
This commit is contained in:
parent
c3fab5307b
commit
89c1a80920
@ -38,10 +38,10 @@ def export_model(args: Optional[Dict[str, Any]] = None, max_shard_size: Optional
|
|||||||
model_args, _, finetuning_args, _ = get_infer_args(args)
|
model_args, _, finetuning_args, _ = get_infer_args(args)
|
||||||
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
|
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
|
||||||
model.config.use_cache = True
|
model.config.use_cache = True
|
||||||
tokenizer.padding_side = "left" # restore padding side
|
|
||||||
tokenizer.init_kwargs["padding_side"] = "left"
|
|
||||||
model.save_pretrained(model_args.export_dir, max_shard_size=max_shard_size)
|
model.save_pretrained(model_args.export_dir, max_shard_size=max_shard_size)
|
||||||
try:
|
try:
|
||||||
|
tokenizer.padding_side = "left" # restore padding side
|
||||||
|
tokenizer.init_kwargs["padding_side"] = "left"
|
||||||
tokenizer.save_pretrained(model_args.export_dir)
|
tokenizer.save_pretrained(model_args.export_dir)
|
||||||
except:
|
except:
|
||||||
logger.warning("Cannot save tokenizer, please copy the files manually.")
|
logger.warning("Cannot save tokenizer, please copy the files manually.")
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
import gradio as gr
|
||||||
from gradio.components import Component # cannot use TYPE_CHECKING here
|
from gradio.components import Component # cannot use TYPE_CHECKING here
|
||||||
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple
|
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple
|
||||||
|
|
||||||
@ -28,16 +29,17 @@ class WebChatModel(ChatModel):
|
|||||||
def load_model(self, data: Dict[Component, Any]) -> Generator[str, None, None]:
|
def load_model(self, data: Dict[Component, Any]) -> Generator[str, None, None]:
|
||||||
get = lambda name: data[self.manager.get_elem_by_name(name)]
|
get = lambda name: data[self.manager.get_elem_by_name(name)]
|
||||||
lang = get("top.lang")
|
lang = get("top.lang")
|
||||||
|
error = ""
|
||||||
if self.loaded:
|
if self.loaded:
|
||||||
yield ALERTS["err_exists"][lang]
|
error = ALERTS["err_exists"][lang]
|
||||||
return
|
elif not get("top.model_name"):
|
||||||
|
error = ALERTS["err_no_model"][lang]
|
||||||
|
elif not get("top.model_path"):
|
||||||
|
error = ALERTS["err_no_path"][lang]
|
||||||
|
|
||||||
if not get("top.model_name"):
|
if error:
|
||||||
yield ALERTS["err_no_model"][lang]
|
gr.Warning(error)
|
||||||
return
|
yield error
|
||||||
|
|
||||||
if not get("top.model_path"):
|
|
||||||
yield ALERTS["err_no_path"][lang]
|
|
||||||
return
|
return
|
||||||
|
|
||||||
if get("top.checkpoints"):
|
if get("top.checkpoints"):
|
||||||
|
@ -11,11 +11,9 @@ def create_chat_box(
|
|||||||
engine: "Engine",
|
engine: "Engine",
|
||||||
visible: Optional[bool] = False
|
visible: Optional[bool] = False
|
||||||
) -> Tuple["Block", "Component", "Component", Dict[str, "Component"]]:
|
) -> Tuple["Block", "Component", "Component", Dict[str, "Component"]]:
|
||||||
elem_dict = dict()
|
|
||||||
|
|
||||||
with gr.Box(visible=visible) as chat_box:
|
with gr.Box(visible=visible) as chat_box:
|
||||||
chatbot = gr.Chatbot()
|
chatbot = gr.Chatbot()
|
||||||
|
history = gr.State([])
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
with gr.Column(scale=4):
|
with gr.Column(scale=4):
|
||||||
system = gr.Textbox(show_label=False)
|
system = gr.Textbox(show_label=False)
|
||||||
@ -29,13 +27,6 @@ def create_chat_box(
|
|||||||
top_p = gr.Slider(0.01, 1, value=gen_kwargs.top_p, step=0.01)
|
top_p = gr.Slider(0.01, 1, value=gen_kwargs.top_p, step=0.01)
|
||||||
temperature = gr.Slider(0.01, 1.5, value=gen_kwargs.temperature, step=0.01)
|
temperature = gr.Slider(0.01, 1.5, value=gen_kwargs.temperature, step=0.01)
|
||||||
|
|
||||||
elem_dict.update(dict(
|
|
||||||
system=system, query=query, submit_btn=submit_btn, clear_btn=clear_btn,
|
|
||||||
max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
|
|
||||||
))
|
|
||||||
|
|
||||||
history = gr.State([])
|
|
||||||
|
|
||||||
submit_btn.click(
|
submit_btn.click(
|
||||||
engine.chatter.predict,
|
engine.chatter.predict,
|
||||||
[chatbot, query, history, system, max_new_tokens, top_p, temperature],
|
[chatbot, query, history, system, max_new_tokens, top_p, temperature],
|
||||||
@ -47,4 +38,12 @@ def create_chat_box(
|
|||||||
|
|
||||||
clear_btn.click(lambda: ([], []), outputs=[chatbot, history], show_progress=True)
|
clear_btn.click(lambda: ([], []), outputs=[chatbot, history], show_progress=True)
|
||||||
|
|
||||||
return chat_box, chatbot, history, elem_dict
|
return chat_box, chatbot, history, dict(
|
||||||
|
system=system,
|
||||||
|
query=query,
|
||||||
|
submit_btn=submit_btn,
|
||||||
|
clear_btn=clear_btn,
|
||||||
|
max_new_tokens=max_new_tokens,
|
||||||
|
top_p=top_p,
|
||||||
|
temperature=temperature
|
||||||
|
)
|
||||||
|
@ -1,17 +1,103 @@
|
|||||||
|
import os
|
||||||
|
import json
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
from typing import TYPE_CHECKING, Tuple
|
from typing import TYPE_CHECKING, Any, Dict, Tuple
|
||||||
|
|
||||||
|
from llmtuner.webui.common import DATA_CONFIG
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from gradio.blocks import Block
|
|
||||||
from gradio.components import Component
|
from gradio.components import Component
|
||||||
|
|
||||||
|
|
||||||
def create_preview_box() -> Tuple["Block", "Component", "Component", "Component"]:
|
PAGE_SIZE = 2
|
||||||
|
|
||||||
|
|
||||||
|
def prev_page(page_index: int) -> int:
|
||||||
|
return page_index - 1 if page_index > 0 else page_index
|
||||||
|
|
||||||
|
|
||||||
|
def next_page(page_index: int, total_num: int) -> int:
|
||||||
|
return page_index + 1 if (page_index + 1) * PAGE_SIZE < total_num else page_index
|
||||||
|
|
||||||
|
|
||||||
|
def can_preview(dataset_dir: str, dataset: list) -> Dict[str, Any]:
|
||||||
|
with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
|
||||||
|
dataset_info = json.load(f)
|
||||||
|
|
||||||
|
if (
|
||||||
|
len(dataset) > 0
|
||||||
|
and "file_name" in dataset_info[dataset[0]]
|
||||||
|
and os.path.isfile(os.path.join(dataset_dir, dataset_info[dataset[0]]["file_name"]))
|
||||||
|
):
|
||||||
|
return gr.update(interactive=True)
|
||||||
|
else:
|
||||||
|
return gr.update(interactive=False)
|
||||||
|
|
||||||
|
|
||||||
|
def get_preview(dataset_dir: str, dataset: list, page_index: int) -> Tuple[int, list, Dict[str, Any]]:
|
||||||
|
with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
|
||||||
|
dataset_info = json.load(f)
|
||||||
|
|
||||||
|
data_file: str = dataset_info[dataset[0]]["file_name"]
|
||||||
|
with open(os.path.join(dataset_dir, data_file), "r", encoding="utf-8") as f:
|
||||||
|
if data_file.endswith(".json"):
|
||||||
|
data = json.load(f)
|
||||||
|
elif data_file.endswith(".jsonl"):
|
||||||
|
data = [json.loads(line) for line in f]
|
||||||
|
else:
|
||||||
|
data = [line for line in f]
|
||||||
|
return len(data), data[PAGE_SIZE * page_index : PAGE_SIZE * (page_index + 1)], gr.update(visible=True)
|
||||||
|
|
||||||
|
|
||||||
|
def create_preview_box(dataset_dir: "gr.Textbox", dataset: "gr.Dropdown") -> Dict[str, "Component"]:
|
||||||
|
data_preview_btn = gr.Button(interactive=False, scale=1)
|
||||||
with gr.Column(visible=False, elem_classes="modal-box") as preview_box:
|
with gr.Column(visible=False, elem_classes="modal-box") as preview_box:
|
||||||
preview_count = gr.Number(interactive=False)
|
with gr.Row():
|
||||||
preview_samples = gr.JSON(interactive=False)
|
preview_count = gr.Number(value=0, interactive=False, precision=0)
|
||||||
close_btn = gr.Button()
|
page_index = gr.Number(value=0, interactive=False, precision=0)
|
||||||
|
|
||||||
|
with gr.Row():
|
||||||
|
prev_btn = gr.Button()
|
||||||
|
next_btn = gr.Button()
|
||||||
|
close_btn = gr.Button()
|
||||||
|
|
||||||
|
with gr.Row():
|
||||||
|
preview_samples = gr.JSON(interactive=False)
|
||||||
|
|
||||||
|
dataset.change(
|
||||||
|
can_preview, [dataset_dir, dataset], [data_preview_btn], queue=False
|
||||||
|
).then(
|
||||||
|
lambda: 0, outputs=[page_index], queue=False
|
||||||
|
)
|
||||||
|
data_preview_btn.click(
|
||||||
|
get_preview,
|
||||||
|
[dataset_dir, dataset, page_index],
|
||||||
|
[preview_count, preview_samples, preview_box],
|
||||||
|
queue=False
|
||||||
|
)
|
||||||
|
prev_btn.click(
|
||||||
|
prev_page, [page_index], [page_index], queue=False
|
||||||
|
).then(
|
||||||
|
get_preview,
|
||||||
|
[dataset_dir, dataset, page_index],
|
||||||
|
[preview_count, preview_samples, preview_box],
|
||||||
|
queue=False
|
||||||
|
)
|
||||||
|
next_btn.click(
|
||||||
|
next_page, [page_index, preview_count], [page_index], queue=False
|
||||||
|
).then(
|
||||||
|
get_preview,
|
||||||
|
[dataset_dir, dataset, page_index],
|
||||||
|
[preview_count, preview_samples, preview_box],
|
||||||
|
queue=False
|
||||||
|
)
|
||||||
close_btn.click(lambda: gr.update(visible=False), outputs=[preview_box], queue=False)
|
close_btn.click(lambda: gr.update(visible=False), outputs=[preview_box], queue=False)
|
||||||
|
return dict(
|
||||||
return preview_box, preview_count, preview_samples, close_btn
|
data_preview_btn=data_preview_btn,
|
||||||
|
preview_count=preview_count,
|
||||||
|
page_index=page_index,
|
||||||
|
prev_btn=prev_btn,
|
||||||
|
next_btn=next_btn,
|
||||||
|
close_btn=close_btn,
|
||||||
|
preview_samples=preview_samples
|
||||||
|
)
|
||||||
|
@ -3,7 +3,6 @@ from typing import TYPE_CHECKING, Dict
|
|||||||
|
|
||||||
from llmtuner.webui.common import list_dataset, DEFAULT_DATA_DIR
|
from llmtuner.webui.common import list_dataset, DEFAULT_DATA_DIR
|
||||||
from llmtuner.webui.components.data import create_preview_box
|
from llmtuner.webui.components.data import create_preview_box
|
||||||
from llmtuner.webui.utils import can_preview, get_preview
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from gradio.components import Component
|
from gradio.components import Component
|
||||||
@ -17,28 +16,12 @@ def create_eval_tab(engine: "Engine") -> Dict[str, "Component"]:
|
|||||||
with gr.Row():
|
with gr.Row():
|
||||||
dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=2)
|
dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=2)
|
||||||
dataset = gr.Dropdown(multiselect=True, scale=4)
|
dataset = gr.Dropdown(multiselect=True, scale=4)
|
||||||
data_preview_btn = gr.Button(interactive=False, scale=1)
|
preview_elems = create_preview_box(dataset_dir, dataset)
|
||||||
|
|
||||||
dataset_dir.change(list_dataset, [dataset_dir], [dataset], queue=False)
|
dataset_dir.change(list_dataset, [dataset_dir], [dataset], queue=False)
|
||||||
dataset.change(can_preview, [dataset_dir, dataset], [data_preview_btn], queue=False)
|
|
||||||
|
|
||||||
input_elems.update({dataset_dir, dataset})
|
input_elems.update({dataset_dir, dataset})
|
||||||
elem_dict.update(dict(
|
elem_dict.update(dict(dataset_dir=dataset_dir, dataset=dataset, **preview_elems))
|
||||||
dataset_dir=dataset_dir, dataset=dataset, data_preview_btn=data_preview_btn
|
|
||||||
))
|
|
||||||
|
|
||||||
preview_box, preview_count, preview_samples, close_btn = create_preview_box()
|
|
||||||
|
|
||||||
data_preview_btn.click(
|
|
||||||
get_preview,
|
|
||||||
[dataset_dir, dataset],
|
|
||||||
[preview_count, preview_samples, preview_box],
|
|
||||||
queue=False
|
|
||||||
)
|
|
||||||
|
|
||||||
elem_dict.update(dict(
|
|
||||||
preview_count=preview_count, preview_samples=preview_samples, close_btn=close_btn
|
|
||||||
))
|
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
cutoff_len = gr.Slider(value=1024, minimum=4, maximum=8192, step=1)
|
cutoff_len = gr.Slider(value=1024, minimum=4, maximum=8192, step=1)
|
||||||
|
@ -1,16 +1,54 @@
|
|||||||
import gradio as gr
|
import gradio as gr
|
||||||
from typing import TYPE_CHECKING, Dict
|
from typing import TYPE_CHECKING, Dict, Generator, List
|
||||||
|
|
||||||
from llmtuner.webui.utils import save_model
|
from llmtuner.tuner import export_model
|
||||||
|
from llmtuner.webui.common import get_save_dir
|
||||||
|
from llmtuner.webui.locales import ALERTS
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from gradio.components import Component
|
from gradio.components import Component
|
||||||
from llmtuner.webui.engine import Engine
|
from llmtuner.webui.engine import Engine
|
||||||
|
|
||||||
|
|
||||||
def create_export_tab(engine: "Engine") -> Dict[str, "Component"]:
|
def save_model(
|
||||||
elem_dict = dict()
|
lang: str,
|
||||||
|
model_name: str,
|
||||||
|
model_path: str,
|
||||||
|
checkpoints: List[str],
|
||||||
|
finetuning_type: str,
|
||||||
|
template: str,
|
||||||
|
max_shard_size: int,
|
||||||
|
export_dir: str
|
||||||
|
) -> Generator[str, None, None]:
|
||||||
|
error = ""
|
||||||
|
if not model_name:
|
||||||
|
error = ALERTS["err_no_model"][lang]
|
||||||
|
elif not model_path:
|
||||||
|
error = ALERTS["err_no_path"][lang]
|
||||||
|
elif not checkpoints:
|
||||||
|
error = ALERTS["err_no_checkpoint"][lang]
|
||||||
|
elif not export_dir:
|
||||||
|
error = ALERTS["err_no_export_dir"][lang]
|
||||||
|
|
||||||
|
if error:
|
||||||
|
gr.Warning(error)
|
||||||
|
yield error
|
||||||
|
return
|
||||||
|
|
||||||
|
args = dict(
|
||||||
|
model_name_or_path=model_path,
|
||||||
|
checkpoint_dir=",".join([get_save_dir(model_name, finetuning_type, ckpt) for ckpt in checkpoints]),
|
||||||
|
finetuning_type=finetuning_type,
|
||||||
|
template=template,
|
||||||
|
export_dir=export_dir
|
||||||
|
)
|
||||||
|
|
||||||
|
yield ALERTS["info_exporting"][lang]
|
||||||
|
export_model(args, max_shard_size="{}GB".format(max_shard_size))
|
||||||
|
yield ALERTS["info_exported"][lang]
|
||||||
|
|
||||||
|
|
||||||
|
def create_export_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
export_dir = gr.Textbox()
|
export_dir = gr.Textbox()
|
||||||
max_shard_size = gr.Slider(value=10, minimum=1, maximum=100)
|
max_shard_size = gr.Slider(value=10, minimum=1, maximum=100)
|
||||||
@ -33,11 +71,9 @@ def create_export_tab(engine: "Engine") -> Dict[str, "Component"]:
|
|||||||
[info_box]
|
[info_box]
|
||||||
)
|
)
|
||||||
|
|
||||||
elem_dict.update(dict(
|
return dict(
|
||||||
export_dir=export_dir,
|
export_dir=export_dir,
|
||||||
max_shard_size=max_shard_size,
|
max_shard_size=max_shard_size,
|
||||||
export_btn=export_btn,
|
export_btn=export_btn,
|
||||||
info_box=info_box
|
info_box=info_box
|
||||||
))
|
)
|
||||||
|
|
||||||
return elem_dict
|
|
||||||
|
@ -5,7 +5,7 @@ from transformers.trainer_utils import SchedulerType
|
|||||||
from llmtuner.extras.constants import TRAINING_STAGES
|
from llmtuner.extras.constants import TRAINING_STAGES
|
||||||
from llmtuner.webui.common import list_checkpoint, list_dataset, DEFAULT_DATA_DIR
|
from llmtuner.webui.common import list_checkpoint, list_dataset, DEFAULT_DATA_DIR
|
||||||
from llmtuner.webui.components.data import create_preview_box
|
from llmtuner.webui.components.data import create_preview_box
|
||||||
from llmtuner.webui.utils import can_preview, get_preview, gen_plot
|
from llmtuner.webui.utils import gen_plot
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from gradio.components import Component
|
from gradio.components import Component
|
||||||
@ -22,28 +22,14 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
|||||||
)
|
)
|
||||||
dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=2)
|
dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=2)
|
||||||
dataset = gr.Dropdown(multiselect=True, scale=4)
|
dataset = gr.Dropdown(multiselect=True, scale=4)
|
||||||
data_preview_btn = gr.Button(interactive=False, scale=1)
|
preview_elems = create_preview_box(dataset_dir, dataset)
|
||||||
|
|
||||||
training_stage.change(list_dataset, [dataset_dir, training_stage], [dataset], queue=False)
|
training_stage.change(list_dataset, [dataset_dir, training_stage], [dataset], queue=False)
|
||||||
dataset_dir.change(list_dataset, [dataset_dir, training_stage], [dataset], queue=False)
|
dataset_dir.change(list_dataset, [dataset_dir, training_stage], [dataset], queue=False)
|
||||||
dataset.change(can_preview, [dataset_dir, dataset], [data_preview_btn], queue=False)
|
|
||||||
|
|
||||||
input_elems.update({training_stage, dataset_dir, dataset})
|
input_elems.update({training_stage, dataset_dir, dataset})
|
||||||
elem_dict.update(dict(
|
elem_dict.update(dict(
|
||||||
training_stage=training_stage, dataset_dir=dataset_dir, dataset=dataset, data_preview_btn=data_preview_btn
|
training_stage=training_stage, dataset_dir=dataset_dir, dataset=dataset, **preview_elems
|
||||||
))
|
|
||||||
|
|
||||||
preview_box, preview_count, preview_samples, close_btn = create_preview_box()
|
|
||||||
|
|
||||||
data_preview_btn.click(
|
|
||||||
get_preview,
|
|
||||||
[dataset_dir, dataset],
|
|
||||||
[preview_count, preview_samples, preview_box],
|
|
||||||
queue=False
|
|
||||||
)
|
|
||||||
|
|
||||||
elem_dict.update(dict(
|
|
||||||
preview_count=preview_count, preview_samples=preview_samples, close_btn=close_btn
|
|
||||||
))
|
))
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
@ -143,16 +129,17 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
|||||||
|
|
||||||
input_elems.add(output_dir)
|
input_elems.add(output_dir)
|
||||||
output_elems = [output_box, process_bar]
|
output_elems = [output_box, process_bar]
|
||||||
elem_dict.update(dict(
|
|
||||||
cmd_preview_btn=cmd_preview_btn, start_btn=start_btn, stop_btn=stop_btn, output_dir=output_dir,
|
|
||||||
resume_btn=resume_btn, process_bar=process_bar, output_box=output_box, loss_viewer=loss_viewer
|
|
||||||
))
|
|
||||||
|
|
||||||
cmd_preview_btn.click(engine.runner.preview_train, input_elems, output_elems)
|
cmd_preview_btn.click(engine.runner.preview_train, input_elems, output_elems)
|
||||||
start_btn.click(engine.runner.run_train, input_elems, output_elems)
|
start_btn.click(engine.runner.run_train, input_elems, output_elems)
|
||||||
stop_btn.click(engine.runner.set_abort, queue=False)
|
stop_btn.click(engine.runner.set_abort, queue=False)
|
||||||
resume_btn.change(engine.runner.monitor, outputs=output_elems)
|
resume_btn.change(engine.runner.monitor, outputs=output_elems)
|
||||||
|
|
||||||
|
elem_dict.update(dict(
|
||||||
|
cmd_preview_btn=cmd_preview_btn, start_btn=start_btn, stop_btn=stop_btn, output_dir=output_dir,
|
||||||
|
resume_btn=resume_btn, process_bar=process_bar, output_box=output_box, loss_viewer=loss_viewer
|
||||||
|
))
|
||||||
|
|
||||||
output_box.change(
|
output_box.change(
|
||||||
gen_plot,
|
gen_plot,
|
||||||
[
|
[
|
||||||
|
@ -6,7 +6,9 @@ CSS = r"""
|
|||||||
transform: translate(-50%, -50%); /* center horizontally */
|
transform: translate(-50%, -50%); /* center horizontally */
|
||||||
max-width: 1000px;
|
max-width: 1000px;
|
||||||
max-height: 750px;
|
max-height: 750px;
|
||||||
|
overflow-y: auto;
|
||||||
background-color: var(--input-background-fill);
|
background-color: var(--input-background-fill);
|
||||||
|
flex-wrap: nowrap !important;
|
||||||
border: 2px solid black !important;
|
border: 2px solid black !important;
|
||||||
z-index: 1000;
|
z-index: 1000;
|
||||||
padding: 10px;
|
padding: 10px;
|
||||||
|
@ -163,12 +163,28 @@ LOCALES = {
|
|||||||
"label": "数量"
|
"label": "数量"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"preview_samples": {
|
"page_index": {
|
||||||
"en": {
|
"en": {
|
||||||
"label": "Samples"
|
"label": "Page"
|
||||||
},
|
},
|
||||||
"zh": {
|
"zh": {
|
||||||
"label": "样例"
|
"label": "页数"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"prev_btn": {
|
||||||
|
"en": {
|
||||||
|
"value": "Prev"
|
||||||
|
},
|
||||||
|
"zh": {
|
||||||
|
"value": "上一页"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"next_btn": {
|
||||||
|
"en": {
|
||||||
|
"value": "Next"
|
||||||
|
},
|
||||||
|
"zh": {
|
||||||
|
"value": "下一页"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"close_btn": {
|
"close_btn": {
|
||||||
@ -179,6 +195,14 @@ LOCALES = {
|
|||||||
"value": "关闭"
|
"value": "关闭"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"preview_samples": {
|
||||||
|
"en": {
|
||||||
|
"label": "Samples"
|
||||||
|
},
|
||||||
|
"zh": {
|
||||||
|
"label": "样例"
|
||||||
|
}
|
||||||
|
},
|
||||||
"cutoff_len": {
|
"cutoff_len": {
|
||||||
"en": {
|
"en": {
|
||||||
"label": "Cutoff length",
|
"label": "Cutoff length",
|
||||||
|
@ -3,13 +3,11 @@ import json
|
|||||||
import gradio as gr
|
import gradio as gr
|
||||||
import matplotlib.figure
|
import matplotlib.figure
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple
|
from typing import TYPE_CHECKING, Any, Dict
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
from llmtuner.extras.ploting import smooth
|
from llmtuner.extras.ploting import smooth
|
||||||
from llmtuner.tuner import export_model
|
from llmtuner.webui.common import get_save_dir
|
||||||
from llmtuner.webui.common import get_save_dir, DATA_CONFIG
|
|
||||||
from llmtuner.webui.locales import ALERTS
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from llmtuner.extras.callbacks import LogCallback
|
from llmtuner.extras.callbacks import LogCallback
|
||||||
@ -33,37 +31,6 @@ def get_time() -> str:
|
|||||||
return datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
|
return datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
|
||||||
|
|
||||||
|
|
||||||
def can_preview(dataset_dir: str, dataset: list) -> Dict[str, Any]:
|
|
||||||
with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
|
|
||||||
dataset_info = json.load(f)
|
|
||||||
|
|
||||||
if (
|
|
||||||
len(dataset) > 0
|
|
||||||
and "file_name" in dataset_info[dataset[0]]
|
|
||||||
and os.path.isfile(os.path.join(dataset_dir, dataset_info[dataset[0]]["file_name"]))
|
|
||||||
):
|
|
||||||
return gr.update(interactive=True)
|
|
||||||
else:
|
|
||||||
return gr.update(interactive=False)
|
|
||||||
|
|
||||||
|
|
||||||
def get_preview(
|
|
||||||
dataset_dir: str, dataset: list, start: Optional[int] = 0, end: Optional[int] = 2
|
|
||||||
) -> Tuple[int, list, Dict[str, Any]]:
|
|
||||||
with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
|
|
||||||
dataset_info = json.load(f)
|
|
||||||
|
|
||||||
data_file: str = dataset_info[dataset[0]]["file_name"]
|
|
||||||
with open(os.path.join(dataset_dir, data_file), "r", encoding="utf-8") as f:
|
|
||||||
if data_file.endswith(".json"):
|
|
||||||
data = json.load(f)
|
|
||||||
elif data_file.endswith(".jsonl"):
|
|
||||||
data = [json.loads(line) for line in f]
|
|
||||||
else:
|
|
||||||
data = [line for line in f]
|
|
||||||
return len(data), data[start:end], gr.update(visible=True)
|
|
||||||
|
|
||||||
|
|
||||||
def can_quantize(finetuning_type: str) -> Dict[str, Any]:
|
def can_quantize(finetuning_type: str) -> Dict[str, Any]:
|
||||||
if finetuning_type != "lora":
|
if finetuning_type != "lora":
|
||||||
return gr.update(value="None", interactive=False)
|
return gr.update(value="None", interactive=False)
|
||||||
@ -116,42 +83,3 @@ def gen_plot(base_model: str, finetuning_type: str, output_dir: str) -> matplotl
|
|||||||
ax.set_xlabel("step")
|
ax.set_xlabel("step")
|
||||||
ax.set_ylabel("loss")
|
ax.set_ylabel("loss")
|
||||||
return fig
|
return fig
|
||||||
|
|
||||||
|
|
||||||
def save_model(
|
|
||||||
lang: str,
|
|
||||||
model_name: str,
|
|
||||||
model_path: str,
|
|
||||||
checkpoints: List[str],
|
|
||||||
finetuning_type: str,
|
|
||||||
template: str,
|
|
||||||
max_shard_size: int,
|
|
||||||
export_dir: str
|
|
||||||
) -> Generator[str, None, None]:
|
|
||||||
if not model_name:
|
|
||||||
yield ALERTS["err_no_model"][lang]
|
|
||||||
return
|
|
||||||
|
|
||||||
if not model_path:
|
|
||||||
yield ALERTS["err_no_path"][lang]
|
|
||||||
return
|
|
||||||
|
|
||||||
if not checkpoints:
|
|
||||||
yield ALERTS["err_no_checkpoint"][lang]
|
|
||||||
return
|
|
||||||
|
|
||||||
if not export_dir:
|
|
||||||
yield ALERTS["err_no_export_dir"][lang]
|
|
||||||
return
|
|
||||||
|
|
||||||
args = dict(
|
|
||||||
model_name_or_path=model_path,
|
|
||||||
checkpoint_dir=",".join([get_save_dir(model_name, finetuning_type, ckpt) for ckpt in checkpoints]),
|
|
||||||
finetuning_type=finetuning_type,
|
|
||||||
template=template,
|
|
||||||
export_dir=export_dir
|
|
||||||
)
|
|
||||||
|
|
||||||
yield ALERTS["info_exporting"][lang]
|
|
||||||
export_model(args, max_shard_size="{}GB".format(max_shard_size))
|
|
||||||
yield ALERTS["info_exported"][lang]
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user