format style

Former-commit-id: 53b683531b83cd1d19de97c6565f16c1eca6f5e1
Author: hiyouga
Date: 2024-01-20 20:15:56 +08:00
parent 01c22ad7f8
commit c0e4eebf17
73 changed files with 1492 additions and 2325 deletions

View File

@@ -1,24 +1,22 @@
-import gradio as gr
-from gradio.components import Component # cannot use TYPE_CHECKING here
 from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple
+import gradio as gr
+from gradio.components import Component  # cannot use TYPE_CHECKING here
 from ..chat import ChatModel
 from ..extras.misc import torch_gc
 from ..hparams import GeneratingArguments
 from .common import get_save_dir
 from .locales import ALERTS
 if TYPE_CHECKING:
     from .manager import Manager
 class WebChatModel(ChatModel):
     def __init__(
-        self,
-        manager: "Manager",
-        demo_mode: Optional[bool] = False,
-        lazy_init: Optional[bool] = True
+        self, manager: "Manager", demo_mode: Optional[bool] = False, lazy_init: Optional[bool] = True
     ) -> None:
         self.manager = manager
         self.demo_mode = demo_mode
@@ -26,11 +24,12 @@ class WebChatModel(ChatModel):
         self.tokenizer = None
         self.generating_args = GeneratingArguments()
-        if not lazy_init: # read arguments from command line
+        if not lazy_init:  # read arguments from command line
             super().__init__()
-        if demo_mode: # load demo_config.json if exists
+        if demo_mode:  # load demo_config.json if exists
             import json
             try:
                 with open("demo_config.json", "r", encoding="utf-8") as f:
                     args = json.load(f)
@@ -38,7 +37,7 @@ class WebChatModel(ChatModel):
                 super().__init__(args)
             except AssertionError:
                 print("Please provided model name and template in `demo_config.json`.")
-            except:
+            except Exception:
                 print("Cannot find `demo_config.json` at current directory.")
     @property
@@ -64,9 +63,12 @@ class WebChatModel(ChatModel):
             return
         if get("top.adapter_path"):
-            adapter_name_or_path = ",".join([
-                get_save_dir(get("top.model_name"), get("top.finetuning_type"), adapter)
-                for adapter in get("top.adapter_path")])
+            adapter_name_or_path = ",".join(
+                [
+                    get_save_dir(get("top.model_name"), get("top.finetuning_type"), adapter)
+                    for adapter in get("top.adapter_path")
+                ]
+            )
         else:
             adapter_name_or_path = None
@@ -79,7 +81,7 @@
             template=get("top.template"),
             flash_attn=(get("top.booster") == "flash_attn"),
             use_unsloth=(get("top.booster") == "unsloth"),
-            rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None
+            rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
         )
         super().__init__(args)
@@ -108,7 +110,7 @@
         tools: str,
         max_new_tokens: int,
         top_p: float,
-        temperature: float
+        temperature: float,
     ) -> Generator[Tuple[List[Tuple[str, str]], List[Tuple[str, str]]], None, None]:
         chatbot.append([query, ""])
         response = ""

View File

@@ -1,9 +1,10 @@
-import os
 import json
-import gradio as gr
+import os
 from collections import defaultdict
 from typing import Any, Dict, Optional
-from peft.utils import WEIGHTS_NAME, SAFETENSORS_WEIGHTS_NAME
+import gradio as gr
+from peft.utils import SAFETENSORS_WEIGHTS_NAME, WEIGHTS_NAME
 from ..extras.constants import (
     DATA_CONFIG,
@@ -12,7 +13,7 @@ from ..extras.constants import (
     PEFT_METHODS,
     SUPPORTED_MODELS,
     TRAINING_STAGES,
-    DownloadSource
+    DownloadSource,
 )
 from ..extras.misc import use_modelscope
@@ -36,7 +37,7 @@ def load_config() -> Dict[str, Any]:
     try:
         with open(get_config_path(), "r", encoding="utf-8") as f:
             return json.load(f)
-    except:
+    except Exception:
         return {"lang": None, "last_model": None, "path_dict": {}, "cache_dir": None}
@@ -59,7 +60,7 @@ def get_model_path(model_name: str) -> str:
         use_modelscope()
         and path_dict.get(DownloadSource.MODELSCOPE)
         and model_path == path_dict.get(DownloadSource.DEFAULT)
-    ): # replace path
+    ):  # replace path
         model_path = path_dict.get(DownloadSource.MODELSCOPE)
     return model_path
@@ -87,9 +88,8 @@ def list_adapters(model_name: str, finetuning_type: str) -> Dict[str, Any]:
     save_dir = get_save_dir(model_name, finetuning_type)
     if save_dir and os.path.isdir(save_dir):
         for adapter in os.listdir(save_dir):
-            if (
-                os.path.isdir(os.path.join(save_dir, adapter))
-                and any([os.path.isfile(os.path.join(save_dir, adapter, name)) for name in ADAPTER_NAMES])
+            if os.path.isdir(os.path.join(save_dir, adapter)) and any(
+                os.path.isfile(os.path.join(save_dir, adapter, name)) for name in ADAPTER_NAMES
             ):
                 adapters.append(adapter)
     return gr.update(value=[], choices=adapters, interactive=True)
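
The `any([...])` → `any(...)` rewrite above is worth a second look: dropping the brackets turns a list comprehension into a generator expression, so `any` can stop at the first hit instead of building the whole list first. A standalone illustration with hypothetical file names:

    import os

    names = ["adapter_model.bin", "adapter_model.safetensors"]

    # list comprehension: every os.path.isfile call runs before any() sees a result
    found = any([os.path.isfile(name) for name in names])

    # generator expression: any() short-circuits on the first True
    found = any(os.path.isfile(name) for name in names)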

View File

@@ -1,11 +1,16 @@
+from .chatbot import create_chat_box
+from .eval import create_eval_tab
+from .export import create_export_tab
+from .infer import create_infer_tab
 from .top import create_top
 from .train import create_train_tab
-from .eval import create_eval_tab
-from .infer import create_infer_tab
-from .export import create_export_tab
-from .chatbot import create_chat_box
 __all__ = [
-    "create_top", "create_train_tab", "create_eval_tab", "create_infer_tab", "create_export_tab", "create_chat_box"
+    "create_top",
+    "create_train_tab",
+    "create_eval_tab",
+    "create_infer_tab",
+    "create_export_tab",
+    "create_chat_box",
 ]
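
The exploded `__all__` above is the "magic trailing comma" at work: once a trailing comma is present inside brackets, the formatter keeps one element per line even when everything would fit on one. A hypothetical before/after, assuming black is the formatter (the project's exact tool configuration is not shown in this commit):

    # collapsed onto one line when it fits and has no trailing comma:
    __all__ = ["create_top", "create_train_tab"]

    # kept one-per-line because of the trailing comma:
    __all__ = [
        "create_top",
        "create_train_tab",
    ]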

View File

@@ -1,6 +1,7 @@
-import gradio as gr
 from typing import TYPE_CHECKING, Dict, Optional, Tuple
+import gradio as gr
 from ..utils import check_json_schema
@@ -12,8 +13,7 @@ if TYPE_CHECKING:
 def create_chat_box(
-    engine: "Engine",
-    visible: Optional[bool] = False
+    engine: "Engine", visible: Optional[bool] = False
 ) -> Tuple["Block", "Component", "Component", Dict[str, "Component"]]:
     with gr.Box(visible=visible) as chat_box:
         chatbot = gr.Chatbot()
@@ -38,20 +38,23 @@ def create_chat_box(
             engine.chatter.predict,
             [chatbot, query, history, system, tools, max_new_tokens, top_p, temperature],
             [chatbot, history],
-            show_progress=True
-        ).then(
-            lambda: gr.update(value=""), outputs=[query]
-        )
+            show_progress=True,
+        ).then(lambda: gr.update(value=""), outputs=[query])
         clear_btn.click(lambda: ([], []), outputs=[chatbot, history], show_progress=True)
-    return chat_box, chatbot, history, dict(
-        system=system,
-        tools=tools,
-        query=query,
-        submit_btn=submit_btn,
-        clear_btn=clear_btn,
-        max_new_tokens=max_new_tokens,
-        top_p=top_p,
-        temperature=temperature
+    return (
+        chat_box,
+        chatbot,
+        history,
+        dict(
+            system=system,
+            tools=tools,
+            query=query,
+            submit_btn=submit_btn,
+            clear_btn=clear_btn,
+            max_new_tokens=max_new_tokens,
+            top_p=top_p,
+            temperature=temperature,
+        ),
     )
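
The `.click(...).then(...)` chains that get compacted in this file are gradio 3.x sequential event listeners: each `.then()` step runs after the previous one finishes. A minimal runnable sketch of the pattern (a hypothetical demo, not this repository's UI):

    import gradio as gr

    with gr.Blocks() as demo:
        chatbot = gr.Chatbot()
        query = gr.Textbox()
        submit_btn = gr.Button("Submit")

        # step 1: append the reply; step 2: clear the input box afterwards
        submit_btn.click(
            lambda history, q: history + [[q, "echo: " + q]],
            inputs=[chatbot, query],
            outputs=[chatbot],
        ).then(lambda: gr.update(value=""), outputs=[query])

    demo.launch()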

View File

@@ -1,10 +1,12 @@
-import os
 import json
-import gradio as gr
+import os
 from typing import TYPE_CHECKING, Any, Dict, Tuple
+import gradio as gr
 from ...extras.constants import DATA_CONFIG
 if TYPE_CHECKING:
     from gradio.components import Component
@@ -24,7 +26,7 @@ def can_preview(dataset_dir: str, dataset: list) -> Dict[str, Any]:
     try:
         with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
             dataset_info = json.load(f)
-    except:
+    except Exception:
         return gr.update(interactive=False)
     if (
@@ -48,7 +50,7 @@ def get_preview(dataset_dir: str, dataset: list, page_index: int) -> Tuple[int,
         elif data_file.endswith(".jsonl"):
             data = [json.loads(line) for line in f]
         else:
-            data = [line for line in f]
+            data = [line for line in f]  # noqa: C416
         return len(data), data[PAGE_SIZE * page_index : PAGE_SIZE * (page_index + 1)], gr.update(visible=True)
@@ -67,32 +69,17 @@ def create_preview_box(dataset_dir: "gr.Textbox", dataset: "gr.Dropdown") -> Dic
         with gr.Row():
             preview_samples = gr.JSON(interactive=False)
-    dataset.change(
-        can_preview, [dataset_dir, dataset], [data_preview_btn], queue=False
-    ).then(
+    dataset.change(can_preview, [dataset_dir, dataset], [data_preview_btn], queue=False).then(
         lambda: 0, outputs=[page_index], queue=False
     )
     data_preview_btn.click(
-        get_preview,
-        [dataset_dir, dataset, page_index],
-        [preview_count, preview_samples, preview_box],
-        queue=False
+        get_preview, [dataset_dir, dataset, page_index], [preview_count, preview_samples, preview_box], queue=False
     )
-    prev_btn.click(
-        prev_page, [page_index], [page_index], queue=False
-    ).then(
-        get_preview,
-        [dataset_dir, dataset, page_index],
-        [preview_count, preview_samples, preview_box],
-        queue=False
+    prev_btn.click(prev_page, [page_index], [page_index], queue=False).then(
+        get_preview, [dataset_dir, dataset, page_index], [preview_count, preview_samples, preview_box], queue=False
     )
-    next_btn.click(
-        next_page, [page_index, preview_count], [page_index], queue=False
-    ).then(
-        get_preview,
-        [dataset_dir, dataset, page_index],
-        [preview_count, preview_samples, preview_box],
-        queue=False
+    next_btn.click(next_page, [page_index, preview_count], [page_index], queue=False).then(
        get_preview, [dataset_dir, dataset, page_index], [preview_count, preview_samples, preview_box], queue=False
     )
     close_btn.click(lambda: gr.update(visible=False), outputs=[preview_box], queue=False)
     return dict(
@@ -102,5 +89,5 @@ def create_preview_box(dataset_dir: "gr.Textbox", dataset: "gr.Dropdown") -> Dic
         prev_btn=prev_btn,
         next_btn=next_btn,
         close_btn=close_btn,
-        preview_samples=preview_samples
+        preview_samples=preview_samples,
     )

View File

@@ -1,9 +1,11 @@
-import gradio as gr
 from typing import TYPE_CHECKING, Dict
-from ..common import list_dataset, DEFAULT_DATA_DIR
+import gradio as gr
+from ..common import DEFAULT_DATA_DIR, list_dataset
 from .data import create_preview_box
 if TYPE_CHECKING:
     from gradio.components import Component
@@ -31,9 +33,7 @@ def create_eval_tab(engine: "Engine") -> Dict[str, "Component"]:
         predict = gr.Checkbox(value=True)
     input_elems.update({cutoff_len, max_samples, batch_size, predict})
-    elem_dict.update(dict(
-        cutoff_len=cutoff_len, max_samples=max_samples, batch_size=batch_size, predict=predict
-    ))
+    elem_dict.update(dict(cutoff_len=cutoff_len, max_samples=max_samples, batch_size=batch_size, predict=predict))
     with gr.Row():
         max_new_tokens = gr.Slider(10, 2048, value=128, step=1)
@@ -42,9 +42,7 @@ def create_eval_tab(engine: "Engine") -> Dict[str, "Component"]:
         output_dir = gr.Textbox()
     input_elems.update({max_new_tokens, top_p, temperature, output_dir})
-    elem_dict.update(dict(
-        max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature, output_dir=output_dir
-    ))
+    elem_dict.update(dict(max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature, output_dir=output_dir))
     with gr.Row():
         cmd_preview_btn = gr.Button()
@@ -59,10 +57,16 @@ def create_eval_tab(engine: "Engine") -> Dict[str, "Component"]:
     output_box = gr.Markdown()
     output_elems = [output_box, process_bar]
-    elem_dict.update(dict(
-        cmd_preview_btn=cmd_preview_btn, start_btn=start_btn, stop_btn=stop_btn,
-        resume_btn=resume_btn, process_bar=process_bar, output_box=output_box
-    ))
+    elem_dict.update(
+        dict(
+            cmd_preview_btn=cmd_preview_btn,
+            start_btn=start_btn,
+            stop_btn=stop_btn,
+            resume_btn=resume_btn,
+            process_bar=process_bar,
+            output_box=output_box,
+        )
+    )
     cmd_preview_btn.click(engine.runner.preview_eval, input_elems, output_elems)
     start_btn.click(engine.runner.run_eval, input_elems, output_elems)

View File

@@ -1,10 +1,12 @@
-import gradio as gr
 from typing import TYPE_CHECKING, Dict, Generator, List
+import gradio as gr
 from ...train import export_model
 from ..common import get_save_dir
 from ..locales import ALERTS
 if TYPE_CHECKING:
     from gradio.components import Component
@@ -24,7 +26,7 @@ def save_model(
     max_shard_size: int,
     export_quantization_bit: int,
     export_quantization_dataset: str,
-    export_dir: str
+    export_dir: str,
 ) -> Generator[str, None, None]:
     error = ""
     if not model_name:
@@ -44,7 +46,9 @@ def save_model(
         return
     if adapter_path:
-        adapter_name_or_path = ",".join([get_save_dir(model_name, finetuning_type, adapter) for adapter in adapter_path])
+        adapter_name_or_path = ",".join(
+            [get_save_dir(model_name, finetuning_type, adapter) for adapter in adapter_path]
+        )
     else:
         adapter_name_or_path = None
@@ -56,7 +60,7 @@ def save_model(
         export_dir=export_dir,
         export_size=max_shard_size,
         export_quantization_bit=int(export_quantization_bit) if export_quantization_bit in GPTQ_BITS else None,
-        export_quantization_dataset=export_quantization_dataset
+        export_quantization_dataset=export_quantization_dataset,
     )
     yield ALERTS["info_exporting"][lang]
@@ -86,9 +90,9 @@ def create_export_tab(engine: "Engine") -> Dict[str, "Component"]:
             max_shard_size,
             export_quantization_bit,
             export_quantization_dataset,
-            export_dir
+            export_dir,
         ],
-        [info_box]
+        [info_box],
     )
     return dict(
@@ -97,5 +101,5 @@ def create_export_tab(engine: "Engine") -> Dict[str, "Component"]:
         export_quantization_dataset=export_quantization_dataset,
         export_dir=export_dir,
         export_btn=export_btn,
-        info_box=info_box
+        info_box=info_box,
     )
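
`save_model` above is typed `Generator[str, None, None]` and `yield`s alert strings: in gradio 3.x a generator handler streams each yielded value into its output component, which is how the info box shows progress from a single function. A standalone sketch of the mechanism (hypothetical task, not repository code):

    import time
    from typing import Generator

    import gradio as gr

    def long_task() -> Generator[str, None, None]:
        yield "Exporting model..."  # shown immediately
        time.sleep(1)               # stand-in for the real work
        yield "Model exported."     # replaces the previous status

    with gr.Blocks() as demo:
        info_box = gr.Textbox()
        gr.Button("Export").click(long_task, outputs=[info_box])

    demo.queue()  # gradio 3.x requires the queue for generator handlers
    demo.launch()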

View File

@@ -1,8 +1,10 @@
-import gradio as gr
 from typing import TYPE_CHECKING, Dict
+import gradio as gr
 from .chatbot import create_chat_box
 if TYPE_CHECKING:
     from gradio.components import Component
@@ -23,18 +25,12 @@ def create_infer_tab(engine: "Engine") -> Dict[str, "Component"]:
     chat_box, chatbot, history, chat_elems = create_chat_box(engine, visible=False)
     elem_dict.update(dict(chat_box=chat_box, **chat_elems))
-    load_btn.click(
-        engine.chatter.load_model, input_elems, [info_box]
-    ).then(
+    load_btn.click(engine.chatter.load_model, input_elems, [info_box]).then(
         lambda: gr.update(visible=engine.chatter.loaded), outputs=[chat_box]
     )
-    unload_btn.click(
-        engine.chatter.unload_model, input_elems, [info_box]
-    ).then(
+    unload_btn.click(engine.chatter.unload_model, input_elems, [info_box]).then(
         lambda: ([], []), outputs=[chatbot, history]
-    ).then(
-        lambda: gr.update(visible=engine.chatter.loaded), outputs=[chat_box]
-    )
+    ).then(lambda: gr.update(visible=engine.chatter.loaded), outputs=[chat_box])
     return elem_dict

View File

@@ -1,11 +1,13 @@
-import gradio as gr
 from typing import TYPE_CHECKING, Dict
+import gradio as gr
 from ...data import templates
 from ...extras.constants import METHODS, SUPPORTED_MODELS
 from ..common import get_model_path, get_template, list_adapters, save_config
 from ..utils import can_quantize
 if TYPE_CHECKING:
     from gradio.components import Component
@@ -30,25 +32,19 @@ def create_top() -> Dict[str, "Component"]:
             rope_scaling = gr.Radio(choices=["none", "linear", "dynamic"], value="none")
             booster = gr.Radio(choices=["none", "flash_attn", "unsloth"], value="none")
-    model_name.change(
-        list_adapters, [model_name, finetuning_type], [adapter_path], queue=False
-    ).then(
+    model_name.change(list_adapters, [model_name, finetuning_type], [adapter_path], queue=False).then(
         get_model_path, [model_name], [model_path], queue=False
     ).then(
         get_template, [model_name], [template], queue=False
-    ) # do not save config since the below line will save
+    )  # do not save config since the below line will save
     model_path.change(save_config, inputs=[lang, model_name, model_path], queue=False)
-    finetuning_type.change(
-        list_adapters, [model_name, finetuning_type], [adapter_path], queue=False
-    ).then(
+    finetuning_type.change(list_adapters, [model_name, finetuning_type], [adapter_path], queue=False).then(
         can_quantize, [finetuning_type], [quantization_bit], queue=False
     )
-    refresh_btn.click(
-        list_adapters, [model_name, finetuning_type], [adapter_path], queue=False
-    )
+    refresh_btn.click(list_adapters, [model_name, finetuning_type], [adapter_path], queue=False)
     return dict(
         lang=lang,
@@ -61,5 +57,5 @@ def create_top() -> Dict[str, "Component"]:
         quantization_bit=quantization_bit,
         template=template,
         rope_scaling=rope_scaling,
-        booster=booster
+        booster=booster,
     )

View File

@@ -1,12 +1,14 @@
-import gradio as gr
 from typing import TYPE_CHECKING, Dict
+import gradio as gr
 from transformers.trainer_utils import SchedulerType
 from ...extras.constants import TRAINING_STAGES
-from ..common import list_adapters, list_dataset, DEFAULT_DATA_DIR
+from ..common import DEFAULT_DATA_DIR, list_adapters, list_dataset
 from ..components.data import create_preview_box
 from ..utils import gen_plot
 if TYPE_CHECKING:
     from gradio.components import Component
@@ -29,9 +31,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
     dataset_dir.change(list_dataset, [dataset_dir, training_stage], [dataset], queue=False)
     input_elems.update({training_stage, dataset_dir, dataset})
-    elem_dict.update(dict(
-        training_stage=training_stage, dataset_dir=dataset_dir, dataset=dataset, **preview_elems
-    ))
+    elem_dict.update(dict(training_stage=training_stage, dataset_dir=dataset_dir, dataset=dataset, **preview_elems))
     with gr.Row():
         cutoff_len = gr.Slider(value=1024, minimum=4, maximum=8192, step=1)
@@ -41,25 +41,33 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
         compute_type = gr.Radio(choices=["fp16", "bf16", "fp32"], value="fp16")
     input_elems.update({cutoff_len, learning_rate, num_train_epochs, max_samples, compute_type})
-    elem_dict.update(dict(
-        cutoff_len=cutoff_len, learning_rate=learning_rate, num_train_epochs=num_train_epochs,
-        max_samples=max_samples, compute_type=compute_type
-    ))
+    elem_dict.update(
+        dict(
+            cutoff_len=cutoff_len,
+            learning_rate=learning_rate,
+            num_train_epochs=num_train_epochs,
+            max_samples=max_samples,
+            compute_type=compute_type,
+        )
+    )
     with gr.Row():
         batch_size = gr.Slider(value=4, minimum=1, maximum=512, step=1)
         gradient_accumulation_steps = gr.Slider(value=4, minimum=1, maximum=512, step=1)
-        lr_scheduler_type = gr.Dropdown(
-            choices=[scheduler.value for scheduler in SchedulerType], value="cosine"
-        )
+        lr_scheduler_type = gr.Dropdown(choices=[scheduler.value for scheduler in SchedulerType], value="cosine")
         max_grad_norm = gr.Textbox(value="1.0")
         val_size = gr.Slider(value=0, minimum=0, maximum=1, step=0.001)
     input_elems.update({batch_size, gradient_accumulation_steps, lr_scheduler_type, max_grad_norm, val_size})
-    elem_dict.update(dict(
-        batch_size=batch_size, gradient_accumulation_steps=gradient_accumulation_steps,
-        lr_scheduler_type=lr_scheduler_type, max_grad_norm=max_grad_norm, val_size=val_size
-    ))
+    elem_dict.update(
+        dict(
+            batch_size=batch_size,
+            gradient_accumulation_steps=gradient_accumulation_steps,
+            lr_scheduler_type=lr_scheduler_type,
+            max_grad_norm=max_grad_norm,
+            val_size=val_size,
+        )
+    )
     with gr.Accordion(label="Extra config", open=False) as extra_tab:
         with gr.Row():
@@ -73,10 +81,17 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
             upcast_layernorm = gr.Checkbox(value=False)
     input_elems.update({logging_steps, save_steps, warmup_steps, neftune_alpha, sft_packing, upcast_layernorm})
-    elem_dict.update(dict(
-        extra_tab=extra_tab, logging_steps=logging_steps, save_steps=save_steps, warmup_steps=warmup_steps,
-        neftune_alpha=neftune_alpha, sft_packing=sft_packing, upcast_layernorm=upcast_layernorm
-    ))
+    elem_dict.update(
+        dict(
+            extra_tab=extra_tab,
+            logging_steps=logging_steps,
+            save_steps=save_steps,
+            warmup_steps=warmup_steps,
+            neftune_alpha=neftune_alpha,
+            sft_packing=sft_packing,
+            upcast_layernorm=upcast_layernorm,
+        )
+    )
     with gr.Accordion(label="LoRA config", open=False) as lora_tab:
         with gr.Row():
@@ -87,10 +102,16 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
             create_new_adapter = gr.Checkbox(scale=1)
     input_elems.update({lora_rank, lora_dropout, lora_target, additional_target, create_new_adapter})
-    elem_dict.update(dict(
-        lora_tab=lora_tab, lora_rank=lora_rank, lora_dropout=lora_dropout, lora_target=lora_target,
-        additional_target=additional_target, create_new_adapter=create_new_adapter
-    ))
+    elem_dict.update(
+        dict(
+            lora_tab=lora_tab,
+            lora_rank=lora_rank,
+            lora_dropout=lora_dropout,
+            lora_target=lora_target,
+            additional_target=additional_target,
+            create_new_adapter=create_new_adapter,
+        )
+    )
     with gr.Accordion(label="RLHF config", open=False) as rlhf_tab:
         with gr.Row():
@@ -103,13 +124,13 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
         list_adapters,
         [engine.manager.get_elem_by_name("top.model_name"), engine.manager.get_elem_by_name("top.finetuning_type")],
         [reward_model],
-        queue=False
+        queue=False,
     )
     input_elems.update({dpo_beta, dpo_ftx, reward_model})
-    elem_dict.update(dict(
-        rlhf_tab=rlhf_tab, dpo_beta=dpo_beta, dpo_ftx=dpo_ftx, reward_model=reward_model, refresh_btn=refresh_btn
-    ))
+    elem_dict.update(
+        dict(rlhf_tab=rlhf_tab, dpo_beta=dpo_beta, dpo_ftx=dpo_ftx, reward_model=reward_model, refresh_btn=refresh_btn)
+    )
     with gr.Row():
         cmd_preview_btn = gr.Button()
@@ -139,20 +160,28 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
     stop_btn.click(engine.runner.set_abort, queue=False)
     resume_btn.change(engine.runner.monitor, outputs=output_elems)
-    elem_dict.update(dict(
-        cmd_preview_btn=cmd_preview_btn, start_btn=start_btn, stop_btn=stop_btn, output_dir=output_dir,
-        resume_btn=resume_btn, process_bar=process_bar, output_box=output_box, loss_viewer=loss_viewer
-    ))
+    elem_dict.update(
+        dict(
+            cmd_preview_btn=cmd_preview_btn,
+            start_btn=start_btn,
+            stop_btn=stop_btn,
+            output_dir=output_dir,
+            resume_btn=resume_btn,
+            process_bar=process_bar,
+            output_box=output_box,
+            loss_viewer=loss_viewer,
+        )
+    )
     output_box.change(
         gen_plot,
         [
             engine.manager.get_elem_by_name("top.model_name"),
             engine.manager.get_elem_by_name("top.finetuning_type"),
-            output_dir
+            output_dir,
         ],
         loss_viewer,
-        queue=False
+        queue=False,
     )
     return elem_dict

View File

@@ -1,7 +1,8 @@
-import gradio as gr
-from gradio.components import Component # cannot use TYPE_CHECKING here
 from typing import Any, Dict, Generator, Optional
+import gradio as gr
+from gradio.components import Component  # cannot use TYPE_CHECKING here
 from .chatter import WebChatModel
 from .common import get_model_path, list_dataset, load_config
 from .locales import LOCALES
@@ -11,7 +12,6 @@ from .utils import get_time
 class Engine:
     def __init__(self, demo_mode: Optional[bool] = False, pure_chat: Optional[bool] = False) -> None:
         self.demo_mode = demo_mode
         self.pure_chat = pure_chat
@@ -26,10 +26,7 @@ class Engine:
         user_config = load_config() if not self.demo_mode else {}
         lang = user_config.get("lang", None) or "en"
-        init_dict = {
-            "top.lang": {"value": lang},
-            "infer.chat_box": {"visible": self.chatter.loaded}
-        }
+        init_dict = {"top.lang": {"value": lang}, "infer.chat_box": {"visible": self.chatter.loaded}}
         if not self.pure_chat:
             init_dict["train.dataset"] = {"choices": list_dataset()["choices"]}
@@ -49,13 +46,17 @@ class Engine:
            else:
                yield self._form_dict({"eval.resume_btn": {"value": True}})
        else:
-            yield self._form_dict({
-                "train.output_dir": {"value": "train_" + get_time()},
-                "eval.output_dir": {"value": "eval_" + get_time()},
-            })
+            yield self._form_dict(
+                {
+                    "train.output_dir": {"value": "train_" + get_time()},
+                    "eval.output_dir": {"value": "eval_" + get_time()},
+                }
+            )
     def change_lang(self, lang: str) -> Dict[Component, Dict[str, Any]]:
         return {
             component: gr.update(**LOCALES[name][lang])
-            for elems in self.manager.all_elems.values() for name, component in elems.items() if name in LOCALES
+            for elems in self.manager.all_elems.values()
+            for name, component in elems.items()
+            if name in LOCALES
         }

View File

@@ -1,21 +1,22 @@
-import gradio as gr
 from typing import Optional
+import gradio as gr
 from transformers.utils.versions import require_version
+from .common import save_config
 from .components import (
+    create_chat_box,
+    create_eval_tab,
+    create_export_tab,
+    create_infer_tab,
     create_top,
     create_train_tab,
-    create_eval_tab,
-    create_infer_tab,
-    create_export_tab,
-    create_chat_box
 )
-from .common import save_config
 from .css import CSS
 from .engine import Engine
-require_version("gradio>=3.38.0,<4.0.0", "To fix: pip install \"gradio>=3.38.0,<4.0.0\"")
+require_version("gradio>=3.38.0,<4.0.0", 'To fix: pip install "gradio>=3.38.0,<4.0.0"')
 def create_ui(demo_mode: Optional[bool] = False) -> gr.Blocks:
@@ -23,11 +24,9 @@ def create_ui(demo_mode: Optional[bool] = False) -> gr.Blocks:
     with gr.Blocks(title="LLaMA Board", css=CSS) as demo:
         if demo_mode:
-            gr.HTML("<h1><center>LLaMA Board: A One-stop Web UI for Getting Started with LLaMA Factory</center></h1>")
+            gr.HTML(
+                "<h1><center>LLaMA Board: A One-stop Web UI for Getting Started with LLaMA Factory</center></h1>"
+            )
             gr.HTML(
-                "<h3><center>Visit <a href=\"https://github.com/hiyouga/LLaMA-Factory\" target=\"_blank\">"
+                '<h3><center>Visit <a href="https://github.com/hiyouga/LLaMA-Factory" target="_blank">'
                 "LLaMA Factory</a> for details.</center></h3>"
             )
             gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button")
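
The string changes in this file are black's quote normalization: it prefers double quotes, but switches a literal to single quotes when that removes backslash escapes. A generic illustration:

    s = "say \"hi\""   # before normalization
    s = 'say "hi"'     # after: single quotes avoid the escapes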

View File

@@ -1,726 +1,220 @@
 LOCALES = {
-    "lang": {
-        "en": {
-            "label": "Lang"
-        },
-        "zh": {
-            "label": "语言"
-        }
-    },
-    "model_name": {
-        "en": {
-            "label": "Model name"
-        },
-        "zh": {
-            "label": "模型名称"
-        }
-    },
+    "lang": {"en": {"label": "Lang"}, "zh": {"label": "语言"}},
+    "model_name": {"en": {"label": "Model name"}, "zh": {"label": "模型名称"}},
     "model_path": {
-        "en": {
-            "label": "Model path",
-            "info": "Path to pretrained model or model identifier from Hugging Face."
-        },
-        "zh": {
-            "label": "模型路径",
-            "info": "本地模型的文件路径或 Hugging Face 的模型标识符。"
-        }
-    },
-    "finetuning_type": {
-        "en": {
-            "label": "Finetuning method"
-        },
-        "zh": {
-            "label": "微调方法"
-        }
-    },
-    "adapter_path": {
-        "en": {
-            "label": "Adapter path"
-        },
-        "zh": {
-            "label": "适配器路径"
-        }
-    },
-    "refresh_btn": {
-        "en": {
-            "value": "Refresh adapters"
-        },
-        "zh": {
-            "value": "刷新适配器"
-        }
-    },
-    "advanced_tab": {
-        "en": {
-            "label": "Advanced configurations"
-        },
-        "zh": {
-            "label": "高级设置"
-        }
+        "en": {"label": "Model path", "info": "Path to pretrained model or model identifier from Hugging Face."},
+        "zh": {"label": "模型路径", "info": "本地模型的文件路径或 Hugging Face 的模型标识符。"},
     },
+    "finetuning_type": {"en": {"label": "Finetuning method"}, "zh": {"label": "微调方法"}},
+    "adapter_path": {"en": {"label": "Adapter path"}, "zh": {"label": "适配器路径"}},
+    "refresh_btn": {"en": {"value": "Refresh adapters"}, "zh": {"value": "刷新适配器"}},
+    "advanced_tab": {"en": {"label": "Advanced configurations"}, "zh": {"label": "高级设置"}},
     "quantization_bit": {
-        "en": {
-            "label": "Quantization bit",
-            "info": "Enable 4/8-bit model quantization (QLoRA)."
-        },
-        "zh": {
-            "label": "量化等级",
-            "info": "启用 4/8 比特模型量化QLoRA"
-        }
+        "en": {"label": "Quantization bit", "info": "Enable 4/8-bit model quantization (QLoRA)."},
+        "zh": {"label": "量化等级", "info": "启用 4/8 比特模型量化QLoRA"},
     },
"template": {
"en": {
"label": "Prompt template",
"info": "The template used in constructing prompts."
},
"zh": {
"label": "提示模板",
"info": "构建提示词时使用的模板"
}
},
"rope_scaling": {
"en": {
"label": "RoPE scaling"
},
"zh": {
"label": "RoPE 插值方法"
}
},
"booster": {
"en": {
"label": "Booster"
},
"zh": {
"label": "加速方式"
}
"en": {"label": "Prompt template", "info": "The template used in constructing prompts."},
"zh": {"label": "提示模板", "info": "构建提示词时使用的模板"},
},
"rope_scaling": {"en": {"label": "RoPE scaling"}, "zh": {"label": "RoPE 插值方法"}},
"booster": {"en": {"label": "Booster"}, "zh": {"label": "加速方式"}},
"training_stage": {
"en": {
"label": "Stage",
"info": "The stage to perform in training."
},
"zh": {
"label": "训练阶段",
"info": "目前采用的训练方式。"
}
"en": {"label": "Stage", "info": "The stage to perform in training."},
"zh": {"label": "训练阶段", "info": "目前采用的训练方式。"},
},
"dataset_dir": {
"en": {
"label": "Data dir",
"info": "Path to the data directory."
},
"zh": {
"label": "数据路径",
"info": "数据文件夹的路径。"
}
},
"dataset": {
"en": {
"label": "Dataset"
},
"zh": {
"label": "数据集"
}
},
"data_preview_btn": {
"en": {
"value": "Preview dataset"
},
"zh": {
"value": "预览数据集"
}
},
"preview_count": {
"en": {
"label": "Count"
},
"zh": {
"label": "数量"
}
},
"page_index": {
"en": {
"label": "Page"
},
"zh": {
"label": "页数"
}
},
"prev_btn": {
"en": {
"value": "Prev"
},
"zh": {
"value": "上一页"
}
},
"next_btn": {
"en": {
"value": "Next"
},
"zh": {
"value": "下一页"
}
},
"close_btn": {
"en": {
"value": "Close"
},
"zh": {
"value": "关闭"
}
},
"preview_samples": {
"en": {
"label": "Samples"
},
"zh": {
"label": "样例"
}
"en": {"label": "Data dir", "info": "Path to the data directory."},
"zh": {"label": "数据路径", "info": "数据文件夹的路径。"},
},
"dataset": {"en": {"label": "Dataset"}, "zh": {"label": "数据集"}},
"data_preview_btn": {"en": {"value": "Preview dataset"}, "zh": {"value": "预览数据集"}},
"preview_count": {"en": {"label": "Count"}, "zh": {"label": "数量"}},
"page_index": {"en": {"label": "Page"}, "zh": {"label": "页数"}},
"prev_btn": {"en": {"value": "Prev"}, "zh": {"value": "上一页"}},
"next_btn": {"en": {"value": "Next"}, "zh": {"value": "下一页"}},
"close_btn": {"en": {"value": "Close"}, "zh": {"value": "关闭"}},
"preview_samples": {"en": {"label": "Samples"}, "zh": {"label": "样例"}},
"cutoff_len": {
"en": {
"label": "Cutoff length",
"info": "Max tokens in input sequence."
},
"zh": {
"label": "截断长度",
"info": "输入序列分词后的最大长度。"
}
"en": {"label": "Cutoff length", "info": "Max tokens in input sequence."},
"zh": {"label": "截断长度", "info": "输入序列分词后的最大长度。"},
},
"learning_rate": {
"en": {
"label": "Learning rate",
"info": "Initial learning rate for AdamW."
},
"zh": {
"label": "学习率",
"info": "AdamW 优化器的初始学习率。"
}
"en": {"label": "Learning rate", "info": "Initial learning rate for AdamW."},
"zh": {"label": "学习率", "info": "AdamW 优化器的初始学习率。"},
},
"num_train_epochs": {
"en": {
"label": "Epochs",
"info": "Total number of training epochs to perform."
},
"zh": {
"label": "训练轮数",
"info": "需要执行的训练总轮数。"
}
"en": {"label": "Epochs", "info": "Total number of training epochs to perform."},
"zh": {"label": "训练轮数", "info": "需要执行的训练总轮数。"},
},
"max_samples": {
"en": {
"label": "Max samples",
"info": "Maximum samples per dataset."
},
"zh": {
"label": "最大样本数",
"info": "每个数据集最多使用的样本数。"
}
"en": {"label": "Max samples", "info": "Maximum samples per dataset."},
"zh": {"label": "最大样本数", "info": "每个数据集最多使用的样本数。"},
},
"compute_type": {
"en": {
"label": "Compute type",
"info": "Whether to use fp16 or bf16 mixed precision training."
},
"zh": {
"label": "计算类型",
"info": "是否启用 FP16 或 BF16 混合精度训练。"
}
"en": {"label": "Compute type", "info": "Whether to use fp16 or bf16 mixed precision training."},
"zh": {"label": "计算类型", "info": "是否启用 FP16 或 BF16 混合精度训练。"},
},
"batch_size": {
"en": {
"label": "Batch size",
"info": "Number of samples to process per GPU."
},
"zh":{
"label": "批处理大小",
"info": "每块 GPU 上处理的样本数量。"
}
"en": {"label": "Batch size", "info": "Number of samples to process per GPU."},
"zh": {"label": "批处理大小", "info": "每块 GPU 上处理的样本数量。"},
},
"gradient_accumulation_steps": {
"en": {
"label": "Gradient accumulation",
"info": "Number of gradient accumulation steps."
},
"zh": {
"label": "梯度累积",
"info": "梯度累积的步数。"
}
"en": {"label": "Gradient accumulation", "info": "Number of gradient accumulation steps."},
"zh": {"label": "梯度累积", "info": "梯度累积的步数。"},
},
"lr_scheduler_type": {
"en": {
"label": "LR Scheduler",
"info": "Name of learning rate scheduler.",
},
"zh": {
"label": "学习率调节器",
"info": "采用的学习率调节器名称。"
}
"zh": {"label": "学习率调节器", "info": "采用的学习率调节器名称。"},
},
"max_grad_norm": {
"en": {
"label": "Maximum gradient norm",
"info": "Norm for gradient clipping.."
},
"zh": {
"label": "最大梯度范数",
"info": "用于梯度裁剪的范数。"
}
"en": {"label": "Maximum gradient norm", "info": "Norm for gradient clipping.."},
"zh": {"label": "最大梯度范数", "info": "用于梯度裁剪的范数。"},
},
"val_size": {
"en": {
"label": "Val size",
"info": "Proportion of data in the dev set."
},
"zh": {
"label": "验证集比例",
"info": "验证集占全部样本的百分比。"
}
},
"extra_tab": {
"en": {
"label": "Extra configurations"
},
"zh": {
"label": "其它参数设置"
}
"en": {"label": "Val size", "info": "Proportion of data in the dev set."},
"zh": {"label": "验证集比例", "info": "验证集占全部样本的百分比。"},
},
"extra_tab": {"en": {"label": "Extra configurations"}, "zh": {"label": "其它参数设置"}},
"logging_steps": {
"en": {
"label": "Logging steps",
"info": "Number of steps between two logs."
},
"zh": {
"label": "日志间隔",
"info": "每两次日志输出间的更新步数。"
}
"en": {"label": "Logging steps", "info": "Number of steps between two logs."},
"zh": {"label": "日志间隔", "info": "每两次日志输出间的更新步数。"},
},
"save_steps": {
"en": {
"label": "Save steps",
"info": "Number of steps between two checkpoints."
},
"zh": {
"label": "保存间隔",
"info": "每两次断点保存间的更新步数。"
}
"en": {"label": "Save steps", "info": "Number of steps between two checkpoints."},
"zh": {"label": "保存间隔", "info": "每两次断点保存间的更新步数。"},
},
"warmup_steps": {
"en": {
"label": "Warmup steps",
"info": "Number of steps used for warmup."
},
"zh": {
"label": "预热步数",
"info": "学习率预热采用的步数。"
}
"en": {"label": "Warmup steps", "info": "Number of steps used for warmup."},
"zh": {"label": "预热步数", "info": "学习率预热采用的步数。"},
},
"neftune_alpha": {
"en": {
"label": "NEFTune Alpha",
"info": "Magnitude of noise adding to embedding vectors."
},
"zh": {
"label": "NEFTune 噪声参数",
"info": "嵌入向量所添加的噪声大小。"
}
"en": {"label": "NEFTune Alpha", "info": "Magnitude of noise adding to embedding vectors."},
"zh": {"label": "NEFTune 噪声参数", "info": "嵌入向量所添加的噪声大小。"},
},
"sft_packing": {
"en": {
"label": "Pack sequences",
"info": "Pack sequences into samples of fixed length in supervised fine-tuning."
"info": "Pack sequences into samples of fixed length in supervised fine-tuning.",
},
"zh": {
"label": "序列打包",
"info": "在有监督微调阶段将序列打包为相同长度的样本。"
}
"zh": {"label": "序列打包", "info": "在有监督微调阶段将序列打包为相同长度的样本。"},
},
"upcast_layernorm": {
"en": {
"label": "Upcast LayerNorm",
"info": "Upcast weights of layernorm in float32."
},
"zh": {
"label": "缩放归一化层",
"info": "将归一化层权重缩放至 32 位精度。"
}
},
"lora_tab": {
"en": {
"label": "LoRA configurations"
},
"zh": {
"label": "LoRA 参数设置"
}
"en": {"label": "Upcast LayerNorm", "info": "Upcast weights of layernorm in float32."},
"zh": {"label": "缩放归一化层", "info": "将归一化层权重缩放至 32 位精度。"},
},
"lora_tab": {"en": {"label": "LoRA configurations"}, "zh": {"label": "LoRA 参数设置"}},
"lora_rank": {
"en": {
"label": "LoRA rank",
"info": "The rank of LoRA matrices."
},
"zh": {
"label": "LoRA 秩",
"info": "LoRA 矩阵的秩。"
}
"en": {"label": "LoRA rank", "info": "The rank of LoRA matrices."},
"zh": {"label": "LoRA ", "info": "LoRA 矩阵的秩。"},
},
"lora_dropout": {
"en": {
"label": "LoRA Dropout",
"info": "Dropout ratio of LoRA weights."
},
"zh": {
"label": "LoRA 随机丢弃",
"info": "LoRA 权重随机丢弃的概率。"
}
"en": {"label": "LoRA Dropout", "info": "Dropout ratio of LoRA weights."},
"zh": {"label": "LoRA 随机丢弃", "info": "LoRA 权重随机丢弃的概率。"},
},
"lora_target": {
"en": {
"label": "LoRA modules (optional)",
"info": "Name(s) of target modules to apply LoRA. Use commas to separate multiple modules."
"info": "Name(s) of target modules to apply LoRA. Use commas to separate multiple modules.",
},
"zh": {
"label": "LoRA 作用模块(非必填)",
"info": "应用 LoRA 的目标模块名称。使用英文逗号分隔多个名称。"
}
"zh": {"label": "LoRA 作用模块(非必填)", "info": "应用 LoRA 的目标模块名称。使用英文逗号分隔多个名称。"},
},
"additional_target": {
"en": {
"label": "Additional modules (optional)",
"info": "Name(s) of modules apart from LoRA layers to be set as trainable. Use commas to separate multiple modules."
"info": "Name(s) of modules apart from LoRA layers to be set as trainable. Use commas to separate multiple modules.",
},
"zh": {
"label": "附加模块(非必填)",
"info": "除 LoRA 层以外的可训练模块名称。使用英文逗号分隔多个名称。"
}
"zh": {"label": "附加模块(非必填)", "info": "除 LoRA 层以外的可训练模块名称。使用英文逗号分隔多个名称。"},
},
"create_new_adapter": {
"en": {
"label": "Create new adapter",
"info": "Whether to create a new adapter with randomly initialized weight or not."
"info": "Whether to create a new adapter with randomly initialized weight or not.",
},
"zh": {
"label": "新建适配器",
"info": "是否创建一个经过随机初始化的新适配器。"
}
},
"rlhf_tab": {
"en": {
"label": "RLHF configurations"
},
"zh": {
"label": "RLHF 参数设置"
}
"zh": {"label": "新建适配器", "info": "是否创建一个经过随机初始化的新适配器。"},
},
"rlhf_tab": {"en": {"label": "RLHF configurations"}, "zh": {"label": "RLHF 参数设置"}},
"dpo_beta": {
"en": {
"label": "DPO beta",
"info": "Value of the beta parameter in the DPO loss."
},
"zh": {
"label": "DPO beta 参数",
"info": "DPO 损失函数中 beta 超参数大小。"
}
"en": {"label": "DPO beta", "info": "Value of the beta parameter in the DPO loss."},
"zh": {"label": "DPO beta 参数", "info": "DPO 损失函数中 beta 超参数大小。"},
},
"dpo_ftx": {
"en": {
"label": "DPO-ftx weight",
"info": "The weight of SFT loss in the DPO-ftx."
},
"zh": {
"label": "DPO-ftx 权重",
"info": "DPO-ftx 中 SFT 损失的权重大小。"
}
"en": {"label": "DPO-ftx weight", "info": "The weight of SFT loss in the DPO-ftx."},
"zh": {"label": "DPO-ftx 权重", "info": "DPO-ftx 中 SFT 损失的权重大小。"},
},
"reward_model": {
"en": {
"label": "Reward model",
"info": "Adapter of the reward model for PPO training. (Needs to refresh adapters)"
"info": "Adapter of the reward model for PPO training. (Needs to refresh adapters)",
},
"zh": {
"label": "奖励模型",
"info": "PPO 训练中奖励模型的适配器路径。(需要刷新适配器)"
}
},
"cmd_preview_btn": {
"en": {
"value": "Preview command"
},
"zh": {
"value": "预览命令"
}
},
"start_btn": {
"en": {
"value": "Start"
},
"zh": {
"value": "开始"
}
},
"stop_btn": {
"en": {
"value": "Abort"
},
"zh": {
"value": "中断"
}
"zh": {"label": "奖励模型", "info": "PPO 训练中奖励模型的适配器路径。(需要刷新适配器)"},
},
"cmd_preview_btn": {"en": {"value": "Preview command"}, "zh": {"value": "预览命令"}},
"start_btn": {"en": {"value": "Start"}, "zh": {"value": "开始"}},
"stop_btn": {"en": {"value": "Abort"}, "zh": {"value": "中断"}},
"output_dir": {
"en": {
"label": "Output dir",
"info": "Directory for saving results."
},
"zh": {
"label": "输出目录",
"info": "保存结果的路径。"
}
},
"output_box": {
"en": {
"value": "Ready."
},
"zh": {
"value": "准备就绪。"
}
},
"loss_viewer": {
"en": {
"label": "Loss"
},
"zh": {
"label": "损失"
}
},
"predict": {
"en": {
"label": "Save predictions"
},
"zh": {
"label": "保存预测结果"
}
},
"load_btn": {
"en": {
"value": "Load model"
},
"zh": {
"value": "加载模型"
}
},
"unload_btn": {
"en": {
"value": "Unload model"
},
"zh": {
"value": "卸载模型"
}
},
"info_box": {
"en": {
"value": "Model unloaded, please load a model first."
},
"zh": {
"value": "模型未加载,请先加载模型。"
}
},
"system": {
"en": {
"placeholder": "System prompt (optional)"
},
"zh": {
"placeholder": "系统提示词(非必填)"
}
},
"tools": {
"en": {
"placeholder": "Tools (optional)"
},
"zh": {
"placeholder": "工具列表(非必填)"
}
},
"query": {
"en": {
"placeholder": "Input..."
},
"zh": {
"placeholder": "输入..."
}
},
"submit_btn": {
"en": {
"value": "Submit"
},
"zh": {
"value": "提交"
}
},
"clear_btn": {
"en": {
"value": "Clear history"
},
"zh": {
"value": "清空历史"
}
},
"max_length": {
"en": {
"label": "Maximum length"
},
"zh": {
"label": "最大长度"
}
},
"max_new_tokens": {
"en": {
"label": "Maximum new tokens"
},
"zh": {
"label": "最大生成长度"
}
},
"top_p": {
"en": {
"label": "Top-p"
},
"zh": {
"label": "Top-p 采样值"
}
},
"temperature": {
"en": {
"label": "Temperature"
},
"zh": {
"label": "温度系数"
}
"en": {"label": "Output dir", "info": "Directory for saving results."},
"zh": {"label": "输出目录", "info": "保存结果的路径。"},
},
"output_box": {"en": {"value": "Ready."}, "zh": {"value": "准备就绪。"}},
"loss_viewer": {"en": {"label": "Loss"}, "zh": {"label": "损失"}},
"predict": {"en": {"label": "Save predictions"}, "zh": {"label": "保存预测结果"}},
"load_btn": {"en": {"value": "Load model"}, "zh": {"value": "加载模型"}},
"unload_btn": {"en": {"value": "Unload model"}, "zh": {"value": "卸载模型"}},
"info_box": {"en": {"value": "Model unloaded, please load a model first."}, "zh": {"value": "模型未加载,请先加载模型。"}},
"system": {"en": {"placeholder": "System prompt (optional)"}, "zh": {"placeholder": "系统提示词(非必填)"}},
"tools": {"en": {"placeholder": "Tools (optional)"}, "zh": {"placeholder": "工具列表(非必填)"}},
"query": {"en": {"placeholder": "Input..."}, "zh": {"placeholder": "输入..."}},
"submit_btn": {"en": {"value": "Submit"}, "zh": {"value": "提交"}},
"clear_btn": {"en": {"value": "Clear history"}, "zh": {"value": "清空历史"}},
"max_length": {"en": {"label": "Maximum length"}, "zh": {"label": "最大长度"}},
"max_new_tokens": {"en": {"label": "Maximum new tokens"}, "zh": {"label": "最大生成长度"}},
"top_p": {"en": {"label": "Top-p"}, "zh": {"label": "Top-p 采样值"}},
"temperature": {"en": {"label": "Temperature"}, "zh": {"label": "温度系数"}},
"max_shard_size": {
"en": {
"label": "Max shard size (GB)",
"info": "The maximum size for a model file."
},
"zh": {
"label": "最大分块大小GB",
"info": "单个模型文件的最大大小。"
}
"en": {"label": "Max shard size (GB)", "info": "The maximum size for a model file."},
"zh": {"label": "最大分块大小GB", "info": "单个模型文件的最大大小。"},
},
"export_quantization_bit": {
"en": {
"label": "Export quantization bit.",
"info": "Quantizing the exported model."
},
"zh": {
"label": "导出量化等级",
"info": "量化导出模型。"
}
"en": {"label": "Export quantization bit.", "info": "Quantizing the exported model."},
"zh": {"label": "导出量化等级", "info": "量化导出模型。"},
},
"export_quantization_dataset": {
"en": {
"label": "Export quantization dataset.",
"info": "The calibration dataset used for quantization."
},
"zh": {
"label": "导出量化数据集",
"info": "量化过程中使用的校准数据集。"
}
"en": {"label": "Export quantization dataset.", "info": "The calibration dataset used for quantization."},
"zh": {"label": "导出量化数据集", "info": "量化过程中使用的校准数据集。"},
},
"export_dir": {
"en": {
"label": "Export dir",
"info": "Directory to save exported model."
},
"zh": {
"label": "导出目录",
"info": "保存导出模型的文件夹路径。"
}
"en": {"label": "Export dir", "info": "Directory to save exported model."},
"zh": {"label": "导出目录", "info": "保存导出模型的文件夹路径。"},
},
"export_btn": {
"en": {
"value": "Export"
},
"zh": {
"value": "开始导出"
}
}
"export_btn": {"en": {"value": "Export"}, "zh": {"value": "开始导出"}},
}
 ALERTS = {
-    "err_conflict": {
-        "en": "A process is in running, please abort it firstly.",
-        "zh": "任务已存在,请先中断训练。"
-    },
-    "err_exists": {
-        "en": "You have loaded a model, please unload it first.",
-        "zh": "模型已存在,请先卸载模型。"
-    },
-    "err_no_model": {
-        "en": "Please select a model.",
-        "zh": "请选择模型。"
-    },
-    "err_no_path": {
-        "en": "Model not found.",
-        "zh": "模型未找到。"
-    },
-    "err_no_dataset": {
-        "en": "Please choose a dataset.",
-        "zh": "请选择数据集。"
-    },
-    "err_no_adapter": {
-        "en": "Please select an adapter.",
-        "zh": "请选择一个适配器。"
-    },
-    "err_no_export_dir": {
-        "en": "Please provide export dir.",
-        "zh": "请填写导出目录"
-    },
-    "err_failed": {
-        "en": "Failed.",
-        "zh": "训练出错。"
-    },
+    "err_conflict": {"en": "A process is in running, please abort it firstly.", "zh": "任务已存在,请先中断训练。"},
+    "err_exists": {"en": "You have loaded a model, please unload it first.", "zh": "模型已存在,请先卸载模型。"},
+    "err_no_model": {"en": "Please select a model.", "zh": "请选择模型。"},
+    "err_no_path": {"en": "Model not found.", "zh": "模型未找到。"},
+    "err_no_dataset": {"en": "Please choose a dataset.", "zh": "请选择数据集。"},
+    "err_no_adapter": {"en": "Please select an adapter.", "zh": "请选择一个适配器。"},
+    "err_no_export_dir": {"en": "Please provide export dir.", "zh": "请填写导出目录"},
+    "err_failed": {"en": "Failed.", "zh": "训练出错。"},
     "err_demo": {
         "en": "Training is unavailable in demo mode, duplicate the space to a private one first.",
-        "zh": "展示模式不支持训练,请先复制到私人空间。"
+        "zh": "展示模式不支持训练,请先复制到私人空间。",
     },
-    "err_device_count": {
-        "en": "Multiple GPUs are not supported yet.",
-        "zh": "尚不支持多 GPU 训练。"
-    },
-    "info_aborting": {
-        "en": "Aborted, wait for terminating...",
-        "zh": "训练中断,正在等待线程结束……"
-    },
-    "info_aborted": {
-        "en": "Ready.",
-        "zh": "准备就绪。"
-    },
-    "info_finished": {
-        "en": "Finished.",
-        "zh": "训练完毕。"
-    },
-    "info_loading": {
-        "en": "Loading model...",
-        "zh": "加载中……"
-    },
-    "info_unloading": {
-        "en": "Unloading model...",
-        "zh": "卸载中……"
-    },
-    "info_loaded": {
-        "en": "Model loaded, now you can chat with your model!",
-        "zh": "模型已加载,可以开始聊天了!"
-    },
-    "info_unloaded": {
-        "en": "Model unloaded.",
-        "zh": "模型已卸载。"
-    },
-    "info_exporting": {
-        "en": "Exporting model...",
-        "zh": "正在导出模型……"
-    },
-    "info_exported": {
-        "en": "Model exported.",
-        "zh": "模型导出完成。"
-    }
+    "err_device_count": {"en": "Multiple GPUs are not supported yet.", "zh": "尚不支持多 GPU 训练。"},
+    "info_aborting": {"en": "Aborted, wait for terminating...", "zh": "训练中断,正在等待线程结束……"},
+    "info_aborted": {"en": "Ready.", "zh": "准备就绪。"},
+    "info_finished": {"en": "Finished.", "zh": "训练完毕。"},
+    "info_loading": {"en": "Loading model...", "zh": "加载中……"},
+    "info_unloading": {"en": "Unloading model...", "zh": "卸载中……"},
+    "info_loaded": {"en": "Model loaded, now you can chat with your model!", "zh": "模型已加载,可以开始聊天了!"},
+    "info_unloaded": {"en": "Model unloaded.", "zh": "模型已卸载。"},
+    "info_exporting": {"en": "Exporting model...", "zh": "正在导出模型……"},
+    "info_exported": {"en": "Model exported.", "zh": "模型导出完成。"},
 }
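
How these tables are consumed: as the change_lang hunk in engine.py (earlier in this commit) shows, each inner dict is splatted into gr.update, so a LOCALES entry may only carry gr.update keyword arguments such as label, info, value, or placeholder. A standalone sketch of the lookup:

    import gradio as gr

    LOCALES = {"lang": {"en": {"label": "Lang"}, "zh": {"label": "语言"}}}

    def localize(name: str, lang: str):
        # mirrors engine.py: gr.update(**LOCALES[name][lang])
        return gr.update(**LOCALES[name][lang])

    print(localize("lang", "zh"))  # update payload carrying label="语言"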

View File

@@ -1,11 +1,11 @@
 from typing import TYPE_CHECKING, Dict, List, Set
 if TYPE_CHECKING:
     from gradio.components import Component
 class Manager:
     def __init__(self) -> None:
         self.all_elems: Dict[str, Dict[str, "Component"]] = {}
@@ -26,7 +26,7 @@ class Manager:
             self.all_elems["top"]["quantization_bit"],
             self.all_elems["top"]["template"],
             self.all_elems["top"]["rope_scaling"],
-            self.all_elems["top"]["booster"]
+            self.all_elems["top"]["booster"],
         }
     def list_elems(self) -> List["Component"]:

View File

@@ -1,12 +1,12 @@
+import logging
 import os
 import time
-import logging
-import gradio as gr
 from threading import Thread
-from gradio.components import Component # cannot use TYPE_CHECKING here
 from typing import TYPE_CHECKING, Any, Dict, Generator, Optional, Tuple
+import gradio as gr
 import transformers
+from gradio.components import Component  # cannot use TYPE_CHECKING here
 from transformers.trainer import TRAINING_ARGS_NAME
 from ..extras.callbacks import LogCallback
@@ -18,12 +18,12 @@ from .common import get_module, get_save_dir, load_config
 from .locales import ALERTS
 from .utils import gen_cmd, get_eval_results, update_process_bar
 if TYPE_CHECKING:
     from .manager import Manager
 class Runner:
     def __init__(self, manager: "Manager", demo_mode: Optional[bool] = False) -> None:
         self.manager = manager
         self.demo_mode = demo_mode
@@ -90,9 +90,12 @@ class Runner:
         user_config = load_config()
         if get("top.adapter_path"):
-            adapter_name_or_path = ",".join([
-                get_save_dir(get("top.model_name"), get("top.finetuning_type"), adapter)
-                for adapter in get("top.adapter_path")])
+            adapter_name_or_path = ",".join(
+                [
+                    get_save_dir(get("top.model_name"), get("top.finetuning_type"), adapter)
+                    for adapter in get("top.adapter_path")
+                ]
+            )
         else:
             adapter_name_or_path = None
@@ -131,12 +134,12 @@ class Runner:
             create_new_adapter=get("train.create_new_adapter"),
             output_dir=get_save_dir(get("top.model_name"), get("top.finetuning_type"), get("train.output_dir")),
             fp16=(get("train.compute_type") == "fp16"),
-            bf16=(get("train.compute_type") == "bf16")
+            bf16=(get("train.compute_type") == "bf16"),
         )
         args["disable_tqdm"] = True
         if TRAINING_STAGES[get("train.training_stage")] in ["rm", "ppo", "dpo"]:
-            args["create_new_adapter"] = (args["quantization_bit"] is None)
+            args["create_new_adapter"] = args["quantization_bit"] is None
         if args["stage"] == "ppo":
             args["reward_model"] = get_save_dir(
@@ -161,9 +164,12 @@ class Runner:
         user_config = load_config()
         if get("top.adapter_path"):
-            adapter_name_or_path = ",".join([
-                get_save_dir(get("top.model_name"), get("top.finetuning_type"), adapter)
-                for adapter in get("top.adapter_path")])
+            adapter_name_or_path = ",".join(
+                [
+                    get_save_dir(get("top.model_name"), get("top.finetuning_type"), adapter)
+                    for adapter in get("top.adapter_path")
+                ]
+            )
         else:
             adapter_name_or_path = None
@@ -187,7 +193,7 @@ class Runner:
             max_new_tokens=get("eval.max_new_tokens"),
             top_p=get("eval.top_p"),
             temperature=get("eval.temperature"),
-            output_dir=get_save_dir(get("top.model_name"), get("top.finetuning_type"), get("eval.output_dir"))
+            output_dir=get_save_dir(get("top.model_name"), get("top.finetuning_type"), get("eval.output_dir")),
         )
         if get("eval.predict"):
@@ -197,7 +203,9 @@ class Runner:
         return args
-    def _preview(self, data: Dict[Component, Any], do_train: bool) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
+    def _preview(
+        self, data: Dict[Component, Any], do_train: bool
+    ) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
         error = self._initialize(data, do_train, from_preview=True)
         if error:
             gr.Warning(error)
@@ -235,9 +243,11 @@ class Runner:
         get = lambda name: self.running_data[self.manager.get_elem_by_name(name)]
         self.running = True
         lang = get("top.lang")
-        output_dir = get_save_dir(get("top.model_name"), get("top.finetuning_type"), get(
-            "{}.output_dir".format("train" if self.do_train else "eval")
-        ))
+        output_dir = get_save_dir(
+            get("top.model_name"),
+            get("top.finetuning_type"),
+            get("{}.output_dir".format("train" if self.do_train else "eval")),
+        )
         while self.thread.is_alive():
             time.sleep(2)
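
One hunk above is a semantic no-op worth noting: `args["create_new_adapter"] = (args["quantization_bit"] is None)` loses its parentheses because `is` comparisons bind before assignment, so the grouping was purely visual:

    quantization_bit = None
    create_new_adapter = quantization_bit is None  # parsed as (quantization_bit is None)
    assert create_new_adapter is True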

View File

@@ -1,13 +1,15 @@
-import os
 import json
-import gradio as gr
-from typing import TYPE_CHECKING, Any, Dict
+import os
 from datetime import datetime
+from typing import TYPE_CHECKING, Any, Dict
+import gradio as gr
 from ..extras.packages import is_matplotlib_available
 from ..extras.ploting import smooth
 from .common import get_save_dir
 if TYPE_CHECKING:
     from ..extras.callbacks import LogCallback
@@ -22,16 +24,13 @@ def update_process_bar(callback: "LogCallback") -> Dict[str, Any]:
     percentage = round(100 * callback.cur_steps / callback.max_steps, 0) if callback.max_steps != 0 else 100.0
     label = "Running {:d}/{:d}: {} < {}".format(
-        callback.cur_steps,
-        callback.max_steps,
-        callback.elapsed_time,
-        callback.remaining_time
+        callback.cur_steps, callback.max_steps, callback.elapsed_time, callback.remaining_time
     )
     return gr.update(label=label, value=percentage, visible=True)
 def get_time() -> str:
-    return datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
+    return datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
 def can_quantize(finetuning_type: str) -> Dict[str, Any]: