Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-12-18 12:50:38 +08:00)
format style
Former-commit-id: 53b683531b83cd1d19de97c6565f16c1eca6f5e1
@@ -1,24 +1,22 @@
import gradio as gr
from gradio.components import Component # cannot use TYPE_CHECKING here
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple

import gradio as gr
from gradio.components import Component # cannot use TYPE_CHECKING here

from ..chat import ChatModel
from ..extras.misc import torch_gc
from ..hparams import GeneratingArguments
from .common import get_save_dir
from .locales import ALERTS


if TYPE_CHECKING:
from .manager import Manager


class WebChatModel(ChatModel):

def __init__(
self,
manager: "Manager",
demo_mode: Optional[bool] = False,
lazy_init: Optional[bool] = True
self, manager: "Manager", demo_mode: Optional[bool] = False, lazy_init: Optional[bool] = True
) -> None:
self.manager = manager
self.demo_mode = demo_mode
@@ -26,11 +24,12 @@ class WebChatModel(ChatModel):
self.tokenizer = None
self.generating_args = GeneratingArguments()

if not lazy_init: # read arguments from command line
if not lazy_init:  # read arguments from command line
super().__init__()

if demo_mode: # load demo_config.json if exists
if demo_mode:  # load demo_config.json if exists
import json

try:
with open("demo_config.json", "r", encoding="utf-8") as f:
args = json.load(f)
@@ -38,7 +37,7 @@ class WebChatModel(ChatModel):
super().__init__(args)
except AssertionError:
print("Please provided model name and template in `demo_config.json`.")
except:
except Exception:
print("Cannot find `demo_config.json` at current directory.")

@property
@@ -64,9 +63,12 @@ class WebChatModel(ChatModel):
return

if get("top.adapter_path"):
adapter_name_or_path = ",".join([
get_save_dir(get("top.model_name"), get("top.finetuning_type"), adapter)
for adapter in get("top.adapter_path")])
adapter_name_or_path = ",".join(
[
get_save_dir(get("top.model_name"), get("top.finetuning_type"), adapter)
for adapter in get("top.adapter_path")
]
)
else:
adapter_name_or_path = None

@@ -79,7 +81,7 @@ class WebChatModel(ChatModel):
template=get("top.template"),
flash_attn=(get("top.booster") == "flash_attn"),
use_unsloth=(get("top.booster") == "unsloth"),
rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None
rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
)
super().__init__(args)

@@ -108,7 +110,7 @@ class WebChatModel(ChatModel):
tools: str,
max_new_tokens: int,
top_p: float,
temperature: float
temperature: float,
) -> Generator[Tuple[List[Tuple[str, str]], List[Tuple[str, str]]], None, None]:
chatbot.append([query, ""])
response = ""

@@ -1,9 +1,10 @@
import os
import json
import gradio as gr
import os
from collections import defaultdict
from typing import Any, Dict, Optional
from peft.utils import WEIGHTS_NAME, SAFETENSORS_WEIGHTS_NAME

import gradio as gr
from peft.utils import SAFETENSORS_WEIGHTS_NAME, WEIGHTS_NAME

from ..extras.constants import (
DATA_CONFIG,
@@ -12,7 +13,7 @@ from ..extras.constants import (
PEFT_METHODS,
SUPPORTED_MODELS,
TRAINING_STAGES,
DownloadSource
DownloadSource,
)
from ..extras.misc import use_modelscope

@@ -36,7 +37,7 @@ def load_config() -> Dict[str, Any]:
try:
with open(get_config_path(), "r", encoding="utf-8") as f:
return json.load(f)
except:
except Exception:
return {"lang": None, "last_model": None, "path_dict": {}, "cache_dir": None}


@@ -59,7 +60,7 @@ def get_model_path(model_name: str) -> str:
use_modelscope()
and path_dict.get(DownloadSource.MODELSCOPE)
and model_path == path_dict.get(DownloadSource.DEFAULT)
): # replace path
):  # replace path
model_path = path_dict.get(DownloadSource.MODELSCOPE)
return model_path

@@ -87,9 +88,8 @@ def list_adapters(model_name: str, finetuning_type: str) -> Dict[str, Any]:
save_dir = get_save_dir(model_name, finetuning_type)
if save_dir and os.path.isdir(save_dir):
for adapter in os.listdir(save_dir):
if (
os.path.isdir(os.path.join(save_dir, adapter))
and any([os.path.isfile(os.path.join(save_dir, adapter, name)) for name in ADAPTER_NAMES])
if os.path.isdir(os.path.join(save_dir, adapter)) and any(
os.path.isfile(os.path.join(save_dir, adapter, name)) for name in ADAPTER_NAMES
):
adapters.append(adapter)
return gr.update(value=[], choices=adapters, interactive=True)

@@ -1,11 +1,16 @@
from .chatbot import create_chat_box
from .eval import create_eval_tab
from .export import create_export_tab
from .infer import create_infer_tab
from .top import create_top
from .train import create_train_tab
from .eval import create_eval_tab
from .infer import create_infer_tab
from .export import create_export_tab
from .chatbot import create_chat_box


__all__ = [
"create_top", "create_train_tab", "create_eval_tab", "create_infer_tab", "create_export_tab", "create_chat_box"
"create_top",
"create_train_tab",
"create_eval_tab",
"create_infer_tab",
"create_export_tab",
"create_chat_box",
]

@@ -1,6 +1,7 @@
import gradio as gr
from typing import TYPE_CHECKING, Dict, Optional, Tuple

import gradio as gr

from ..utils import check_json_schema


@@ -12,8 +13,7 @@ if TYPE_CHECKING:


def create_chat_box(
engine: "Engine",
visible: Optional[bool] = False
engine: "Engine", visible: Optional[bool] = False
) -> Tuple["Block", "Component", "Component", Dict[str, "Component"]]:
with gr.Box(visible=visible) as chat_box:
chatbot = gr.Chatbot()
@@ -38,20 +38,23 @@ def create_chat_box(
engine.chatter.predict,
[chatbot, query, history, system, tools, max_new_tokens, top_p, temperature],
[chatbot, history],
show_progress=True
).then(
lambda: gr.update(value=""), outputs=[query]
)
show_progress=True,
).then(lambda: gr.update(value=""), outputs=[query])

clear_btn.click(lambda: ([], []), outputs=[chatbot, history], show_progress=True)

return chat_box, chatbot, history, dict(
system=system,
tools=tools,
query=query,
submit_btn=submit_btn,
clear_btn=clear_btn,
max_new_tokens=max_new_tokens,
top_p=top_p,
temperature=temperature
return (
chat_box,
chatbot,
history,
dict(
system=system,
tools=tools,
query=query,
submit_btn=submit_btn,
clear_btn=clear_btn,
max_new_tokens=max_new_tokens,
top_p=top_p,
temperature=temperature,
),
)

@@ -1,10 +1,12 @@
import os
import json
import gradio as gr
import os
from typing import TYPE_CHECKING, Any, Dict, Tuple

import gradio as gr

from ...extras.constants import DATA_CONFIG


if TYPE_CHECKING:
from gradio.components import Component

@@ -24,7 +26,7 @@ def can_preview(dataset_dir: str, dataset: list) -> Dict[str, Any]:
try:
with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
dataset_info = json.load(f)
except:
except Exception:
return gr.update(interactive=False)

if (
@@ -48,7 +50,7 @@ def get_preview(dataset_dir: str, dataset: list, page_index: int) -> Tuple[int,
elif data_file.endswith(".jsonl"):
data = [json.loads(line) for line in f]
else:
data = [line for line in f]
data = [line for line in f]  # noqa: C416
return len(data), data[PAGE_SIZE * page_index : PAGE_SIZE * (page_index + 1)], gr.update(visible=True)


@@ -67,32 +69,17 @@ def create_preview_box(dataset_dir: "gr.Textbox", dataset: "gr.Dropdown") -> Dic
with gr.Row():
preview_samples = gr.JSON(interactive=False)

dataset.change(
can_preview, [dataset_dir, dataset], [data_preview_btn], queue=False
).then(
dataset.change(can_preview, [dataset_dir, dataset], [data_preview_btn], queue=False).then(
lambda: 0, outputs=[page_index], queue=False
)
data_preview_btn.click(
get_preview,
[dataset_dir, dataset, page_index],
[preview_count, preview_samples, preview_box],
queue=False
get_preview, [dataset_dir, dataset, page_index], [preview_count, preview_samples, preview_box], queue=False
)
prev_btn.click(
prev_page, [page_index], [page_index], queue=False
).then(
get_preview,
[dataset_dir, dataset, page_index],
[preview_count, preview_samples, preview_box],
queue=False
prev_btn.click(prev_page, [page_index], [page_index], queue=False).then(
get_preview, [dataset_dir, dataset, page_index], [preview_count, preview_samples, preview_box], queue=False
)
next_btn.click(
next_page, [page_index, preview_count], [page_index], queue=False
).then(
get_preview,
[dataset_dir, dataset, page_index],
[preview_count, preview_samples, preview_box],
queue=False
next_btn.click(next_page, [page_index, preview_count], [page_index], queue=False).then(
get_preview, [dataset_dir, dataset, page_index], [preview_count, preview_samples, preview_box], queue=False
)
close_btn.click(lambda: gr.update(visible=False), outputs=[preview_box], queue=False)
return dict(
@@ -102,5 +89,5 @@ def create_preview_box(dataset_dir: "gr.Textbox", dataset: "gr.Dropdown") -> Dic
prev_btn=prev_btn,
next_btn=next_btn,
close_btn=close_btn,
preview_samples=preview_samples
preview_samples=preview_samples,
)

@@ -1,9 +1,11 @@
import gradio as gr
from typing import TYPE_CHECKING, Dict

from ..common import list_dataset, DEFAULT_DATA_DIR
import gradio as gr

from ..common import DEFAULT_DATA_DIR, list_dataset
from .data import create_preview_box


if TYPE_CHECKING:
from gradio.components import Component

@@ -31,9 +33,7 @@ def create_eval_tab(engine: "Engine") -> Dict[str, "Component"]:
predict = gr.Checkbox(value=True)

input_elems.update({cutoff_len, max_samples, batch_size, predict})
elem_dict.update(dict(
cutoff_len=cutoff_len, max_samples=max_samples, batch_size=batch_size, predict=predict
))
elem_dict.update(dict(cutoff_len=cutoff_len, max_samples=max_samples, batch_size=batch_size, predict=predict))

with gr.Row():
max_new_tokens = gr.Slider(10, 2048, value=128, step=1)
@@ -42,9 +42,7 @@ def create_eval_tab(engine: "Engine") -> Dict[str, "Component"]:
output_dir = gr.Textbox()

input_elems.update({max_new_tokens, top_p, temperature, output_dir})
elem_dict.update(dict(
max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature, output_dir=output_dir
))
elem_dict.update(dict(max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature, output_dir=output_dir))

with gr.Row():
cmd_preview_btn = gr.Button()
@@ -59,10 +57,16 @@ def create_eval_tab(engine: "Engine") -> Dict[str, "Component"]:
output_box = gr.Markdown()

output_elems = [output_box, process_bar]
elem_dict.update(dict(
cmd_preview_btn=cmd_preview_btn, start_btn=start_btn, stop_btn=stop_btn,
resume_btn=resume_btn, process_bar=process_bar, output_box=output_box
))
elem_dict.update(
dict(
cmd_preview_btn=cmd_preview_btn,
start_btn=start_btn,
stop_btn=stop_btn,
resume_btn=resume_btn,
process_bar=process_bar,
output_box=output_box,
)
)

cmd_preview_btn.click(engine.runner.preview_eval, input_elems, output_elems)
start_btn.click(engine.runner.run_eval, input_elems, output_elems)

@@ -1,10 +1,12 @@
import gradio as gr
from typing import TYPE_CHECKING, Dict, Generator, List

import gradio as gr

from ...train import export_model
from ..common import get_save_dir
from ..locales import ALERTS


if TYPE_CHECKING:
from gradio.components import Component

@@ -24,7 +26,7 @@ def save_model(
max_shard_size: int,
export_quantization_bit: int,
export_quantization_dataset: str,
export_dir: str
export_dir: str,
) -> Generator[str, None, None]:
error = ""
if not model_name:
@@ -44,7 +46,9 @@ def save_model(
return

if adapter_path:
adapter_name_or_path = ",".join([get_save_dir(model_name, finetuning_type, adapter) for adapter in adapter_path])
adapter_name_or_path = ",".join(
[get_save_dir(model_name, finetuning_type, adapter) for adapter in adapter_path]
)
else:
adapter_name_or_path = None

@@ -56,7 +60,7 @@ def save_model(
export_dir=export_dir,
export_size=max_shard_size,
export_quantization_bit=int(export_quantization_bit) if export_quantization_bit in GPTQ_BITS else None,
export_quantization_dataset=export_quantization_dataset
export_quantization_dataset=export_quantization_dataset,
)

yield ALERTS["info_exporting"][lang]
@@ -86,9 +90,9 @@ def create_export_tab(engine: "Engine") -> Dict[str, "Component"]:
max_shard_size,
export_quantization_bit,
export_quantization_dataset,
export_dir
export_dir,
],
[info_box]
[info_box],
)

return dict(
@@ -97,5 +101,5 @@ def create_export_tab(engine: "Engine") -> Dict[str, "Component"]:
export_quantization_dataset=export_quantization_dataset,
export_dir=export_dir,
export_btn=export_btn,
info_box=info_box
info_box=info_box,
)

@@ -1,8 +1,10 @@
import gradio as gr
from typing import TYPE_CHECKING, Dict

import gradio as gr

from .chatbot import create_chat_box


if TYPE_CHECKING:
from gradio.components import Component

@@ -23,18 +25,12 @@ def create_infer_tab(engine: "Engine") -> Dict[str, "Component"]:
chat_box, chatbot, history, chat_elems = create_chat_box(engine, visible=False)
elem_dict.update(dict(chat_box=chat_box, **chat_elems))

load_btn.click(
engine.chatter.load_model, input_elems, [info_box]
).then(
load_btn.click(engine.chatter.load_model, input_elems, [info_box]).then(
lambda: gr.update(visible=engine.chatter.loaded), outputs=[chat_box]
)

unload_btn.click(
engine.chatter.unload_model, input_elems, [info_box]
).then(
unload_btn.click(engine.chatter.unload_model, input_elems, [info_box]).then(
lambda: ([], []), outputs=[chatbot, history]
).then(
lambda: gr.update(visible=engine.chatter.loaded), outputs=[chat_box]
)
).then(lambda: gr.update(visible=engine.chatter.loaded), outputs=[chat_box])

return elem_dict

@@ -1,11 +1,13 @@
import gradio as gr
from typing import TYPE_CHECKING, Dict

import gradio as gr

from ...data import templates
from ...extras.constants import METHODS, SUPPORTED_MODELS
from ..common import get_model_path, get_template, list_adapters, save_config
from ..utils import can_quantize


if TYPE_CHECKING:
from gradio.components import Component

@@ -30,25 +32,19 @@ def create_top() -> Dict[str, "Component"]:
rope_scaling = gr.Radio(choices=["none", "linear", "dynamic"], value="none")
booster = gr.Radio(choices=["none", "flash_attn", "unsloth"], value="none")

model_name.change(
list_adapters, [model_name, finetuning_type], [adapter_path], queue=False
).then(
model_name.change(list_adapters, [model_name, finetuning_type], [adapter_path], queue=False).then(
get_model_path, [model_name], [model_path], queue=False
).then(
get_template, [model_name], [template], queue=False
) # do not save config since the below line will save
)  # do not save config since the below line will save

model_path.change(save_config, inputs=[lang, model_name, model_path], queue=False)

finetuning_type.change(
list_adapters, [model_name, finetuning_type], [adapter_path], queue=False
).then(
finetuning_type.change(list_adapters, [model_name, finetuning_type], [adapter_path], queue=False).then(
can_quantize, [finetuning_type], [quantization_bit], queue=False
)

refresh_btn.click(
list_adapters, [model_name, finetuning_type], [adapter_path], queue=False
)
refresh_btn.click(list_adapters, [model_name, finetuning_type], [adapter_path], queue=False)

return dict(
lang=lang,
@@ -61,5 +57,5 @@ def create_top() -> Dict[str, "Component"]:
quantization_bit=quantization_bit,
template=template,
rope_scaling=rope_scaling,
booster=booster
booster=booster,
)

@@ -1,12 +1,14 @@
import gradio as gr
from typing import TYPE_CHECKING, Dict

import gradio as gr
from transformers.trainer_utils import SchedulerType

from ...extras.constants import TRAINING_STAGES
from ..common import list_adapters, list_dataset, DEFAULT_DATA_DIR
from ..common import DEFAULT_DATA_DIR, list_adapters, list_dataset
from ..components.data import create_preview_box
from ..utils import gen_plot


if TYPE_CHECKING:
from gradio.components import Component

@@ -29,9 +31,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
dataset_dir.change(list_dataset, [dataset_dir, training_stage], [dataset], queue=False)

input_elems.update({training_stage, dataset_dir, dataset})
elem_dict.update(dict(
training_stage=training_stage, dataset_dir=dataset_dir, dataset=dataset, **preview_elems
))
elem_dict.update(dict(training_stage=training_stage, dataset_dir=dataset_dir, dataset=dataset, **preview_elems))

with gr.Row():
cutoff_len = gr.Slider(value=1024, minimum=4, maximum=8192, step=1)
@@ -41,25 +41,33 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
compute_type = gr.Radio(choices=["fp16", "bf16", "fp32"], value="fp16")

input_elems.update({cutoff_len, learning_rate, num_train_epochs, max_samples, compute_type})
elem_dict.update(dict(
cutoff_len=cutoff_len, learning_rate=learning_rate, num_train_epochs=num_train_epochs,
max_samples=max_samples, compute_type=compute_type
))
elem_dict.update(
dict(
cutoff_len=cutoff_len,
learning_rate=learning_rate,
num_train_epochs=num_train_epochs,
max_samples=max_samples,
compute_type=compute_type,
)
)

with gr.Row():
batch_size = gr.Slider(value=4, minimum=1, maximum=512, step=1)
gradient_accumulation_steps = gr.Slider(value=4, minimum=1, maximum=512, step=1)
lr_scheduler_type = gr.Dropdown(
choices=[scheduler.value for scheduler in SchedulerType], value="cosine"
)
lr_scheduler_type = gr.Dropdown(choices=[scheduler.value for scheduler in SchedulerType], value="cosine")
max_grad_norm = gr.Textbox(value="1.0")
val_size = gr.Slider(value=0, minimum=0, maximum=1, step=0.001)

input_elems.update({batch_size, gradient_accumulation_steps, lr_scheduler_type, max_grad_norm, val_size})
elem_dict.update(dict(
batch_size=batch_size, gradient_accumulation_steps=gradient_accumulation_steps,
lr_scheduler_type=lr_scheduler_type, max_grad_norm=max_grad_norm, val_size=val_size
))
elem_dict.update(
dict(
batch_size=batch_size,
gradient_accumulation_steps=gradient_accumulation_steps,
lr_scheduler_type=lr_scheduler_type,
max_grad_norm=max_grad_norm,
val_size=val_size,
)
)

with gr.Accordion(label="Extra config", open=False) as extra_tab:
with gr.Row():
@@ -73,10 +81,17 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
upcast_layernorm = gr.Checkbox(value=False)

input_elems.update({logging_steps, save_steps, warmup_steps, neftune_alpha, sft_packing, upcast_layernorm})
elem_dict.update(dict(
extra_tab=extra_tab, logging_steps=logging_steps, save_steps=save_steps, warmup_steps=warmup_steps,
neftune_alpha=neftune_alpha, sft_packing=sft_packing, upcast_layernorm=upcast_layernorm
))
elem_dict.update(
dict(
extra_tab=extra_tab,
logging_steps=logging_steps,
save_steps=save_steps,
warmup_steps=warmup_steps,
neftune_alpha=neftune_alpha,
sft_packing=sft_packing,
upcast_layernorm=upcast_layernorm,
)
)

with gr.Accordion(label="LoRA config", open=False) as lora_tab:
with gr.Row():
@@ -87,10 +102,16 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
create_new_adapter = gr.Checkbox(scale=1)

input_elems.update({lora_rank, lora_dropout, lora_target, additional_target, create_new_adapter})
elem_dict.update(dict(
lora_tab=lora_tab, lora_rank=lora_rank, lora_dropout=lora_dropout, lora_target=lora_target,
additional_target=additional_target, create_new_adapter=create_new_adapter
))
elem_dict.update(
dict(
lora_tab=lora_tab,
lora_rank=lora_rank,
lora_dropout=lora_dropout,
lora_target=lora_target,
additional_target=additional_target,
create_new_adapter=create_new_adapter,
)
)

with gr.Accordion(label="RLHF config", open=False) as rlhf_tab:
with gr.Row():
@@ -103,13 +124,13 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
list_adapters,
[engine.manager.get_elem_by_name("top.model_name"), engine.manager.get_elem_by_name("top.finetuning_type")],
[reward_model],
queue=False
queue=False,
)

input_elems.update({dpo_beta, dpo_ftx, reward_model})
elem_dict.update(dict(
rlhf_tab=rlhf_tab, dpo_beta=dpo_beta, dpo_ftx=dpo_ftx, reward_model=reward_model, refresh_btn=refresh_btn
))
elem_dict.update(
dict(rlhf_tab=rlhf_tab, dpo_beta=dpo_beta, dpo_ftx=dpo_ftx, reward_model=reward_model, refresh_btn=refresh_btn)
)

with gr.Row():
cmd_preview_btn = gr.Button()
@@ -139,20 +160,28 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
stop_btn.click(engine.runner.set_abort, queue=False)
resume_btn.change(engine.runner.monitor, outputs=output_elems)

elem_dict.update(dict(
cmd_preview_btn=cmd_preview_btn, start_btn=start_btn, stop_btn=stop_btn, output_dir=output_dir,
resume_btn=resume_btn, process_bar=process_bar, output_box=output_box, loss_viewer=loss_viewer
))
elem_dict.update(
dict(
cmd_preview_btn=cmd_preview_btn,
start_btn=start_btn,
stop_btn=stop_btn,
output_dir=output_dir,
resume_btn=resume_btn,
process_bar=process_bar,
output_box=output_box,
loss_viewer=loss_viewer,
)
)

output_box.change(
gen_plot,
[
engine.manager.get_elem_by_name("top.model_name"),
engine.manager.get_elem_by_name("top.finetuning_type"),
output_dir
output_dir,
],
loss_viewer,
queue=False
queue=False,
)

return elem_dict

@@ -1,7 +1,8 @@
import gradio as gr
from gradio.components import Component # cannot use TYPE_CHECKING here
from typing import Any, Dict, Generator, Optional

import gradio as gr
from gradio.components import Component # cannot use TYPE_CHECKING here

from .chatter import WebChatModel
from .common import get_model_path, list_dataset, load_config
from .locales import LOCALES
@@ -11,7 +12,6 @@ from .utils import get_time


class Engine:

def __init__(self, demo_mode: Optional[bool] = False, pure_chat: Optional[bool] = False) -> None:
self.demo_mode = demo_mode
self.pure_chat = pure_chat
@@ -26,10 +26,7 @@ class Engine:
user_config = load_config() if not self.demo_mode else {}
lang = user_config.get("lang", None) or "en"

init_dict = {
"top.lang": {"value": lang},
"infer.chat_box": {"visible": self.chatter.loaded}
}
init_dict = {"top.lang": {"value": lang}, "infer.chat_box": {"visible": self.chatter.loaded}}

if not self.pure_chat:
init_dict["train.dataset"] = {"choices": list_dataset()["choices"]}
@@ -49,13 +46,17 @@ class Engine:
else:
yield self._form_dict({"eval.resume_btn": {"value": True}})
else:
yield self._form_dict({
"train.output_dir": {"value": "train_" + get_time()},
"eval.output_dir": {"value": "eval_" + get_time()},
})
yield self._form_dict(
{
"train.output_dir": {"value": "train_" + get_time()},
"eval.output_dir": {"value": "eval_" + get_time()},
}
)

def change_lang(self, lang: str) -> Dict[Component, Dict[str, Any]]:
return {
component: gr.update(**LOCALES[name][lang])
for elems in self.manager.all_elems.values() for name, component in elems.items() if name in LOCALES
for elems in self.manager.all_elems.values()
for name, component in elems.items()
if name in LOCALES
}

@@ -1,21 +1,22 @@
import gradio as gr
from typing import Optional

import gradio as gr
from transformers.utils.versions import require_version

from .common import save_config
from .components import (
create_chat_box,
create_eval_tab,
create_export_tab,
create_infer_tab,
create_top,
create_train_tab,
create_eval_tab,
create_infer_tab,
create_export_tab,
create_chat_box
)
from .common import save_config
from .css import CSS
from .engine import Engine


require_version("gradio>=3.38.0,<4.0.0", "To fix: pip install \"gradio>=3.38.0,<4.0.0\"")
require_version("gradio>=3.38.0,<4.0.0", 'To fix: pip install "gradio>=3.38.0,<4.0.0"')


def create_ui(demo_mode: Optional[bool] = False) -> gr.Blocks:
@@ -23,11 +24,9 @@ def create_ui(demo_mode: Optional[bool] = False) -> gr.Blocks:

with gr.Blocks(title="LLaMA Board", css=CSS) as demo:
if demo_mode:
gr.HTML("<h1><center>LLaMA Board: A One-stop Web UI for Getting Started with LLaMA Factory</center></h1>")
gr.HTML(
"<h1><center>LLaMA Board: A One-stop Web UI for Getting Started with LLaMA Factory</center></h1>"
)
gr.HTML(
"<h3><center>Visit <a href=\"https://github.com/hiyouga/LLaMA-Factory\" target=\"_blank\">"
'<h3><center>Visit <a href="https://github.com/hiyouga/LLaMA-Factory" target="_blank">'
"LLaMA Factory</a> for details.</center></h3>"
)
gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button")

@@ -1,726 +1,220 @@
LOCALES = {
"lang": {
"en": {
"label": "Lang"
},
"zh": {
"label": "语言"
}
},
"model_name": {
"en": {
"label": "Model name"
},
"zh": {
"label": "模型名称"
}
},
"lang": {"en": {"label": "Lang"}, "zh": {"label": "语言"}},
"model_name": {"en": {"label": "Model name"}, "zh": {"label": "模型名称"}},
"model_path": {
"en": {
"label": "Model path",
"info": "Path to pretrained model or model identifier from Hugging Face."
},
"zh": {
"label": "模型路径",
"info": "本地模型的文件路径或 Hugging Face 的模型标识符。"
}
},
"finetuning_type": {
"en": {
"label": "Finetuning method"
},
"zh": {
"label": "微调方法"
}
},
"adapter_path": {
"en": {
"label": "Adapter path"
},
"zh": {
"label": "适配器路径"
}
},
"refresh_btn": {
"en": {
"value": "Refresh adapters"
},
"zh": {
"value": "刷新适配器"
}
},
"advanced_tab": {
"en": {
"label": "Advanced configurations"
},
"zh": {
"label": "高级设置"
}
"en": {"label": "Model path", "info": "Path to pretrained model or model identifier from Hugging Face."},
"zh": {"label": "模型路径", "info": "本地模型的文件路径或 Hugging Face 的模型标识符。"},
},
"finetuning_type": {"en": {"label": "Finetuning method"}, "zh": {"label": "微调方法"}},
"adapter_path": {"en": {"label": "Adapter path"}, "zh": {"label": "适配器路径"}},
"refresh_btn": {"en": {"value": "Refresh adapters"}, "zh": {"value": "刷新适配器"}},
"advanced_tab": {"en": {"label": "Advanced configurations"}, "zh": {"label": "高级设置"}},
"quantization_bit": {
"en": {
"label": "Quantization bit",
"info": "Enable 4/8-bit model quantization (QLoRA)."
},
"zh": {
"label": "量化等级",
"info": "启用 4/8 比特模型量化(QLoRA)。"
}
"en": {"label": "Quantization bit", "info": "Enable 4/8-bit model quantization (QLoRA)."},
"zh": {"label": "量化等级", "info": "启用 4/8 比特模型量化(QLoRA)。"},
},
"template": {
"en": {
"label": "Prompt template",
"info": "The template used in constructing prompts."
},
"zh": {
"label": "提示模板",
"info": "构建提示词时使用的模板"
}
},
"rope_scaling": {
"en": {
"label": "RoPE scaling"
},
"zh": {
"label": "RoPE 插值方法"
}
},
"booster": {
"en": {
"label": "Booster"
},
"zh": {
"label": "加速方式"
}
"en": {"label": "Prompt template", "info": "The template used in constructing prompts."},
"zh": {"label": "提示模板", "info": "构建提示词时使用的模板"},
},
"rope_scaling": {"en": {"label": "RoPE scaling"}, "zh": {"label": "RoPE 插值方法"}},
"booster": {"en": {"label": "Booster"}, "zh": {"label": "加速方式"}},
"training_stage": {
"en": {
"label": "Stage",
"info": "The stage to perform in training."
},
"zh": {
"label": "训练阶段",
"info": "目前采用的训练方式。"
}
"en": {"label": "Stage", "info": "The stage to perform in training."},
"zh": {"label": "训练阶段", "info": "目前采用的训练方式。"},
},
"dataset_dir": {
"en": {
"label": "Data dir",
"info": "Path to the data directory."
},
"zh": {
"label": "数据路径",
"info": "数据文件夹的路径。"
}
},
"dataset": {
"en": {
"label": "Dataset"
},
"zh": {
"label": "数据集"
}
},
"data_preview_btn": {
"en": {
"value": "Preview dataset"
},
"zh": {
"value": "预览数据集"
}
},
"preview_count": {
"en": {
"label": "Count"
},
"zh": {
"label": "数量"
}
},
"page_index": {
"en": {
"label": "Page"
},
"zh": {
"label": "页数"
}
},
"prev_btn": {
"en": {
"value": "Prev"
},
"zh": {
"value": "上一页"
}
},
"next_btn": {
"en": {
"value": "Next"
},
"zh": {
"value": "下一页"
}
},
"close_btn": {
"en": {
"value": "Close"
},
"zh": {
"value": "关闭"
}
},
"preview_samples": {
"en": {
"label": "Samples"
},
"zh": {
"label": "样例"
}
"en": {"label": "Data dir", "info": "Path to the data directory."},
"zh": {"label": "数据路径", "info": "数据文件夹的路径。"},
},
"dataset": {"en": {"label": "Dataset"}, "zh": {"label": "数据集"}},
"data_preview_btn": {"en": {"value": "Preview dataset"}, "zh": {"value": "预览数据集"}},
"preview_count": {"en": {"label": "Count"}, "zh": {"label": "数量"}},
"page_index": {"en": {"label": "Page"}, "zh": {"label": "页数"}},
"prev_btn": {"en": {"value": "Prev"}, "zh": {"value": "上一页"}},
"next_btn": {"en": {"value": "Next"}, "zh": {"value": "下一页"}},
"close_btn": {"en": {"value": "Close"}, "zh": {"value": "关闭"}},
"preview_samples": {"en": {"label": "Samples"}, "zh": {"label": "样例"}},
"cutoff_len": {
|
||||
"en": {
|
||||
"label": "Cutoff length",
|
||||
"info": "Max tokens in input sequence."
|
||||
},
|
||||
"zh": {
|
||||
"label": "截断长度",
|
||||
"info": "输入序列分词后的最大长度。"
|
||||
}
|
||||
"en": {"label": "Cutoff length", "info": "Max tokens in input sequence."},
|
||||
"zh": {"label": "截断长度", "info": "输入序列分词后的最大长度。"},
|
||||
},
|
||||
"learning_rate": {
|
||||
"en": {
|
||||
"label": "Learning rate",
|
||||
"info": "Initial learning rate for AdamW."
|
||||
},
|
||||
"zh": {
|
||||
"label": "学习率",
|
||||
"info": "AdamW 优化器的初始学习率。"
|
||||
}
|
||||
"en": {"label": "Learning rate", "info": "Initial learning rate for AdamW."},
|
||||
"zh": {"label": "学习率", "info": "AdamW 优化器的初始学习率。"},
|
||||
},
|
||||
"num_train_epochs": {
|
||||
"en": {
|
||||
"label": "Epochs",
|
||||
"info": "Total number of training epochs to perform."
|
||||
},
|
||||
"zh": {
|
||||
"label": "训练轮数",
|
||||
"info": "需要执行的训练总轮数。"
|
||||
}
|
||||
"en": {"label": "Epochs", "info": "Total number of training epochs to perform."},
|
||||
"zh": {"label": "训练轮数", "info": "需要执行的训练总轮数。"},
|
||||
},
|
||||
"max_samples": {
|
||||
"en": {
|
||||
"label": "Max samples",
|
||||
"info": "Maximum samples per dataset."
|
||||
},
|
||||
"zh": {
|
||||
"label": "最大样本数",
|
||||
"info": "每个数据集最多使用的样本数。"
|
||||
}
|
||||
"en": {"label": "Max samples", "info": "Maximum samples per dataset."},
|
||||
"zh": {"label": "最大样本数", "info": "每个数据集最多使用的样本数。"},
|
||||
},
|
||||
"compute_type": {
|
||||
"en": {
|
||||
"label": "Compute type",
|
||||
"info": "Whether to use fp16 or bf16 mixed precision training."
|
||||
},
|
||||
"zh": {
|
||||
"label": "计算类型",
|
||||
"info": "是否启用 FP16 或 BF16 混合精度训练。"
|
||||
}
|
||||
"en": {"label": "Compute type", "info": "Whether to use fp16 or bf16 mixed precision training."},
|
||||
"zh": {"label": "计算类型", "info": "是否启用 FP16 或 BF16 混合精度训练。"},
|
||||
},
|
||||
"batch_size": {
|
||||
"en": {
|
||||
"label": "Batch size",
|
||||
"info": "Number of samples to process per GPU."
|
||||
},
|
||||
"zh":{
|
||||
"label": "批处理大小",
|
||||
"info": "每块 GPU 上处理的样本数量。"
|
||||
}
|
||||
"en": {"label": "Batch size", "info": "Number of samples to process per GPU."},
|
||||
"zh": {"label": "批处理大小", "info": "每块 GPU 上处理的样本数量。"},
|
||||
},
|
||||
"gradient_accumulation_steps": {
|
||||
"en": {
|
||||
"label": "Gradient accumulation",
|
||||
"info": "Number of gradient accumulation steps."
|
||||
},
|
||||
"zh": {
|
||||
"label": "梯度累积",
|
||||
"info": "梯度累积的步数。"
|
||||
}
|
||||
"en": {"label": "Gradient accumulation", "info": "Number of gradient accumulation steps."},
|
||||
"zh": {"label": "梯度累积", "info": "梯度累积的步数。"},
|
||||
},
|
||||
"lr_scheduler_type": {
|
||||
"en": {
|
||||
"label": "LR Scheduler",
|
||||
"info": "Name of learning rate scheduler.",
|
||||
},
|
||||
"zh": {
|
||||
"label": "学习率调节器",
|
||||
"info": "采用的学习率调节器名称。"
|
||||
}
|
||||
"zh": {"label": "学习率调节器", "info": "采用的学习率调节器名称。"},
|
||||
},
|
||||
"max_grad_norm": {
|
||||
"en": {
|
||||
"label": "Maximum gradient norm",
|
||||
"info": "Norm for gradient clipping.."
|
||||
},
|
||||
"zh": {
|
||||
"label": "最大梯度范数",
|
||||
"info": "用于梯度裁剪的范数。"
|
||||
}
|
||||
"en": {"label": "Maximum gradient norm", "info": "Norm for gradient clipping.."},
|
||||
"zh": {"label": "最大梯度范数", "info": "用于梯度裁剪的范数。"},
|
||||
},
|
||||
"val_size": {
|
||||
"en": {
|
||||
"label": "Val size",
|
||||
"info": "Proportion of data in the dev set."
|
||||
},
|
||||
"zh": {
|
||||
"label": "验证集比例",
|
||||
"info": "验证集占全部样本的百分比。"
|
||||
}
|
||||
},
|
||||
"extra_tab": {
|
||||
"en": {
|
||||
"label": "Extra configurations"
|
||||
},
|
||||
"zh": {
|
||||
"label": "其它参数设置"
|
||||
}
|
||||
"en": {"label": "Val size", "info": "Proportion of data in the dev set."},
|
||||
"zh": {"label": "验证集比例", "info": "验证集占全部样本的百分比。"},
|
||||
},
|
||||
"extra_tab": {"en": {"label": "Extra configurations"}, "zh": {"label": "其它参数设置"}},
|
||||
"logging_steps": {
|
||||
"en": {
|
||||
"label": "Logging steps",
|
||||
"info": "Number of steps between two logs."
|
||||
},
|
||||
"zh": {
|
||||
"label": "日志间隔",
|
||||
"info": "每两次日志输出间的更新步数。"
|
||||
}
|
||||
"en": {"label": "Logging steps", "info": "Number of steps between two logs."},
|
||||
"zh": {"label": "日志间隔", "info": "每两次日志输出间的更新步数。"},
|
||||
},
|
||||
"save_steps": {
|
||||
"en": {
|
||||
"label": "Save steps",
|
||||
"info": "Number of steps between two checkpoints."
|
||||
},
|
||||
"zh": {
|
||||
"label": "保存间隔",
|
||||
"info": "每两次断点保存间的更新步数。"
|
||||
}
|
||||
"en": {"label": "Save steps", "info": "Number of steps between two checkpoints."},
|
||||
"zh": {"label": "保存间隔", "info": "每两次断点保存间的更新步数。"},
|
||||
},
|
||||
"warmup_steps": {
|
||||
"en": {
|
||||
"label": "Warmup steps",
|
||||
"info": "Number of steps used for warmup."
|
||||
},
|
||||
"zh": {
|
||||
"label": "预热步数",
|
||||
"info": "学习率预热采用的步数。"
|
||||
}
|
||||
"en": {"label": "Warmup steps", "info": "Number of steps used for warmup."},
|
||||
"zh": {"label": "预热步数", "info": "学习率预热采用的步数。"},
|
||||
},
|
||||
"neftune_alpha": {
|
||||
"en": {
|
||||
"label": "NEFTune Alpha",
|
||||
"info": "Magnitude of noise adding to embedding vectors."
|
||||
},
|
||||
"zh": {
|
||||
"label": "NEFTune 噪声参数",
|
||||
"info": "嵌入向量所添加的噪声大小。"
|
||||
}
|
||||
"en": {"label": "NEFTune Alpha", "info": "Magnitude of noise adding to embedding vectors."},
|
||||
"zh": {"label": "NEFTune 噪声参数", "info": "嵌入向量所添加的噪声大小。"},
|
||||
},
|
||||
"sft_packing": {
|
||||
"en": {
|
||||
"label": "Pack sequences",
|
||||
"info": "Pack sequences into samples of fixed length in supervised fine-tuning."
|
||||
"info": "Pack sequences into samples of fixed length in supervised fine-tuning.",
|
||||
},
|
||||
"zh": {
|
||||
"label": "序列打包",
|
||||
"info": "在有监督微调阶段将序列打包为相同长度的样本。"
|
||||
}
|
||||
"zh": {"label": "序列打包", "info": "在有监督微调阶段将序列打包为相同长度的样本。"},
|
||||
},
|
||||
"upcast_layernorm": {
|
||||
"en": {
|
||||
"label": "Upcast LayerNorm",
|
||||
"info": "Upcast weights of layernorm in float32."
|
||||
},
|
||||
"zh": {
|
||||
"label": "缩放归一化层",
|
||||
"info": "将归一化层权重缩放至 32 位精度。"
|
||||
}
|
||||
},
|
||||
"lora_tab": {
|
||||
"en": {
|
||||
"label": "LoRA configurations"
|
||||
},
|
||||
"zh": {
|
||||
"label": "LoRA 参数设置"
|
||||
}
|
||||
"en": {"label": "Upcast LayerNorm", "info": "Upcast weights of layernorm in float32."},
|
||||
"zh": {"label": "缩放归一化层", "info": "将归一化层权重缩放至 32 位精度。"},
|
||||
},
|
||||
"lora_tab": {"en": {"label": "LoRA configurations"}, "zh": {"label": "LoRA 参数设置"}},
|
||||
"lora_rank": {
|
||||
"en": {
|
||||
"label": "LoRA rank",
|
||||
"info": "The rank of LoRA matrices."
|
||||
},
|
||||
"zh": {
|
||||
"label": "LoRA 秩",
|
||||
"info": "LoRA 矩阵的秩。"
|
||||
}
|
||||
"en": {"label": "LoRA rank", "info": "The rank of LoRA matrices."},
|
||||
"zh": {"label": "LoRA 秩", "info": "LoRA 矩阵的秩。"},
|
||||
},
|
||||
"lora_dropout": {
|
||||
"en": {
|
||||
"label": "LoRA Dropout",
|
||||
"info": "Dropout ratio of LoRA weights."
|
||||
},
|
||||
"zh": {
|
||||
"label": "LoRA 随机丢弃",
|
||||
"info": "LoRA 权重随机丢弃的概率。"
|
||||
}
|
||||
"en": {"label": "LoRA Dropout", "info": "Dropout ratio of LoRA weights."},
|
||||
"zh": {"label": "LoRA 随机丢弃", "info": "LoRA 权重随机丢弃的概率。"},
|
||||
},
|
||||
"lora_target": {
|
||||
"en": {
|
||||
"label": "LoRA modules (optional)",
|
||||
"info": "Name(s) of target modules to apply LoRA. Use commas to separate multiple modules."
|
||||
"info": "Name(s) of target modules to apply LoRA. Use commas to separate multiple modules.",
|
||||
},
|
||||
"zh": {
|
||||
"label": "LoRA 作用模块(非必填)",
|
||||
"info": "应用 LoRA 的目标模块名称。使用英文逗号分隔多个名称。"
|
||||
}
|
||||
"zh": {"label": "LoRA 作用模块(非必填)", "info": "应用 LoRA 的目标模块名称。使用英文逗号分隔多个名称。"},
|
||||
},
|
||||
"additional_target": {
|
||||
"en": {
|
||||
"label": "Additional modules (optional)",
|
||||
"info": "Name(s) of modules apart from LoRA layers to be set as trainable. Use commas to separate multiple modules."
|
||||
"info": "Name(s) of modules apart from LoRA layers to be set as trainable. Use commas to separate multiple modules.",
|
||||
},
|
||||
"zh": {
|
||||
"label": "附加模块(非必填)",
|
||||
"info": "除 LoRA 层以外的可训练模块名称。使用英文逗号分隔多个名称。"
|
||||
}
|
||||
"zh": {"label": "附加模块(非必填)", "info": "除 LoRA 层以外的可训练模块名称。使用英文逗号分隔多个名称。"},
|
||||
},
|
||||
"create_new_adapter": {
|
||||
"en": {
|
||||
"label": "Create new adapter",
|
||||
"info": "Whether to create a new adapter with randomly initialized weight or not."
|
||||
"info": "Whether to create a new adapter with randomly initialized weight or not.",
|
||||
},
|
||||
"zh": {
|
||||
"label": "新建适配器",
|
||||
"info": "是否创建一个经过随机初始化的新适配器。"
|
||||
}
|
||||
},
|
||||
"rlhf_tab": {
|
||||
"en": {
|
||||
"label": "RLHF configurations"
|
||||
},
|
||||
"zh": {
|
||||
"label": "RLHF 参数设置"
|
||||
}
|
||||
"zh": {"label": "新建适配器", "info": "是否创建一个经过随机初始化的新适配器。"},
|
||||
},
|
||||
"rlhf_tab": {"en": {"label": "RLHF configurations"}, "zh": {"label": "RLHF 参数设置"}},
|
||||
"dpo_beta": {
|
||||
"en": {
|
||||
"label": "DPO beta",
|
||||
"info": "Value of the beta parameter in the DPO loss."
|
||||
},
|
||||
"zh": {
|
||||
"label": "DPO beta 参数",
|
||||
"info": "DPO 损失函数中 beta 超参数大小。"
|
||||
}
|
||||
"en": {"label": "DPO beta", "info": "Value of the beta parameter in the DPO loss."},
|
||||
"zh": {"label": "DPO beta 参数", "info": "DPO 损失函数中 beta 超参数大小。"},
|
||||
},
|
||||
"dpo_ftx": {
|
||||
"en": {
|
||||
"label": "DPO-ftx weight",
|
||||
"info": "The weight of SFT loss in the DPO-ftx."
|
||||
},
|
||||
"zh": {
|
||||
"label": "DPO-ftx 权重",
|
||||
"info": "DPO-ftx 中 SFT 损失的权重大小。"
|
||||
}
|
||||
"en": {"label": "DPO-ftx weight", "info": "The weight of SFT loss in the DPO-ftx."},
|
||||
"zh": {"label": "DPO-ftx 权重", "info": "DPO-ftx 中 SFT 损失的权重大小。"},
|
||||
},
|
||||
"reward_model": {
|
||||
"en": {
|
||||
"label": "Reward model",
|
||||
"info": "Adapter of the reward model for PPO training. (Needs to refresh adapters)"
|
||||
"info": "Adapter of the reward model for PPO training. (Needs to refresh adapters)",
|
||||
},
|
||||
"zh": {
|
||||
"label": "奖励模型",
|
||||
"info": "PPO 训练中奖励模型的适配器路径。(需要刷新适配器)"
|
||||
}
|
||||
},
|
||||
"cmd_preview_btn": {
|
||||
"en": {
|
||||
"value": "Preview command"
|
||||
},
|
||||
"zh": {
|
||||
"value": "预览命令"
|
||||
}
|
||||
},
|
||||
"start_btn": {
|
||||
"en": {
|
||||
"value": "Start"
|
||||
},
|
||||
"zh": {
|
||||
"value": "开始"
|
||||
}
|
||||
},
|
||||
"stop_btn": {
|
||||
"en": {
|
||||
"value": "Abort"
|
||||
},
|
||||
"zh": {
|
||||
"value": "中断"
|
||||
}
|
||||
"zh": {"label": "奖励模型", "info": "PPO 训练中奖励模型的适配器路径。(需要刷新适配器)"},
|
||||
},
|
||||
"cmd_preview_btn": {"en": {"value": "Preview command"}, "zh": {"value": "预览命令"}},
|
||||
"start_btn": {"en": {"value": "Start"}, "zh": {"value": "开始"}},
|
||||
"stop_btn": {"en": {"value": "Abort"}, "zh": {"value": "中断"}},
|
||||
"output_dir": {
|
||||
"en": {
|
||||
"label": "Output dir",
|
||||
"info": "Directory for saving results."
|
||||
},
|
||||
"zh": {
|
||||
"label": "输出目录",
|
||||
"info": "保存结果的路径。"
|
||||
}
|
||||
},
|
||||
"output_box": {
|
||||
"en": {
|
||||
"value": "Ready."
|
||||
},
|
||||
"zh": {
|
||||
"value": "准备就绪。"
|
||||
}
|
||||
},
|
||||
"loss_viewer": {
|
||||
"en": {
|
||||
"label": "Loss"
|
||||
},
|
||||
"zh": {
|
||||
"label": "损失"
|
||||
}
|
||||
},
|
||||
"predict": {
|
||||
"en": {
|
||||
"label": "Save predictions"
|
||||
},
|
||||
"zh": {
|
||||
"label": "保存预测结果"
|
||||
}
|
||||
},
|
||||
"load_btn": {
|
||||
"en": {
|
||||
"value": "Load model"
|
||||
},
|
||||
"zh": {
|
||||
"value": "加载模型"
|
||||
}
|
||||
},
|
||||
"unload_btn": {
|
||||
"en": {
|
||||
"value": "Unload model"
|
||||
},
|
||||
"zh": {
|
||||
"value": "卸载模型"
|
||||
}
|
||||
},
|
||||
"info_box": {
|
||||
"en": {
|
||||
"value": "Model unloaded, please load a model first."
|
||||
},
|
||||
"zh": {
|
||||
"value": "模型未加载,请先加载模型。"
|
||||
}
|
||||
},
|
||||
"system": {
|
||||
"en": {
|
||||
"placeholder": "System prompt (optional)"
|
||||
},
|
||||
"zh": {
|
||||
"placeholder": "系统提示词(非必填)"
|
||||
}
|
||||
},
|
||||
"tools": {
|
||||
"en": {
|
||||
"placeholder": "Tools (optional)"
|
||||
},
|
||||
"zh": {
|
||||
"placeholder": "工具列表(非必填)"
|
||||
}
|
||||
},
|
||||
"query": {
|
||||
"en": {
|
||||
"placeholder": "Input..."
|
||||
},
|
||||
"zh": {
|
||||
"placeholder": "输入..."
|
||||
}
|
||||
},
|
||||
"submit_btn": {
|
||||
"en": {
|
||||
"value": "Submit"
|
||||
},
|
||||
"zh": {
|
||||
"value": "提交"
|
||||
}
|
||||
},
|
||||
"clear_btn": {
|
||||
"en": {
|
||||
"value": "Clear history"
|
||||
},
|
||||
"zh": {
|
||||
"value": "清空历史"
|
||||
}
|
||||
},
|
||||
"max_length": {
|
||||
"en": {
|
||||
"label": "Maximum length"
|
||||
},
|
||||
"zh": {
|
||||
"label": "最大长度"
|
||||
}
|
||||
},
|
||||
"max_new_tokens": {
|
||||
"en": {
|
||||
"label": "Maximum new tokens"
|
||||
},
|
||||
"zh": {
|
||||
"label": "最大生成长度"
|
||||
}
|
||||
},
|
||||
"top_p": {
|
||||
"en": {
|
||||
"label": "Top-p"
|
||||
},
|
||||
"zh": {
|
||||
"label": "Top-p 采样值"
|
||||
}
|
||||
},
|
||||
"temperature": {
|
||||
"en": {
|
||||
"label": "Temperature"
|
||||
},
|
||||
"zh": {
|
||||
"label": "温度系数"
|
||||
}
|
||||
"en": {"label": "Output dir", "info": "Directory for saving results."},
|
||||
"zh": {"label": "输出目录", "info": "保存结果的路径。"},
|
||||
},
|
||||
"output_box": {"en": {"value": "Ready."}, "zh": {"value": "准备就绪。"}},
|
||||
"loss_viewer": {"en": {"label": "Loss"}, "zh": {"label": "损失"}},
|
||||
"predict": {"en": {"label": "Save predictions"}, "zh": {"label": "保存预测结果"}},
|
||||
"load_btn": {"en": {"value": "Load model"}, "zh": {"value": "加载模型"}},
|
||||
"unload_btn": {"en": {"value": "Unload model"}, "zh": {"value": "卸载模型"}},
|
||||
"info_box": {"en": {"value": "Model unloaded, please load a model first."}, "zh": {"value": "模型未加载,请先加载模型。"}},
|
||||
"system": {"en": {"placeholder": "System prompt (optional)"}, "zh": {"placeholder": "系统提示词(非必填)"}},
|
||||
"tools": {"en": {"placeholder": "Tools (optional)"}, "zh": {"placeholder": "工具列表(非必填)"}},
|
||||
"query": {"en": {"placeholder": "Input..."}, "zh": {"placeholder": "输入..."}},
|
||||
"submit_btn": {"en": {"value": "Submit"}, "zh": {"value": "提交"}},
|
||||
"clear_btn": {"en": {"value": "Clear history"}, "zh": {"value": "清空历史"}},
|
||||
"max_length": {"en": {"label": "Maximum length"}, "zh": {"label": "最大长度"}},
|
||||
"max_new_tokens": {"en": {"label": "Maximum new tokens"}, "zh": {"label": "最大生成长度"}},
|
||||
"top_p": {"en": {"label": "Top-p"}, "zh": {"label": "Top-p 采样值"}},
|
||||
"temperature": {"en": {"label": "Temperature"}, "zh": {"label": "温度系数"}},
|
||||
"max_shard_size": {
|
||||
"en": {
|
||||
"label": "Max shard size (GB)",
|
||||
"info": "The maximum size for a model file."
|
||||
},
|
||||
"zh": {
|
||||
"label": "最大分块大小(GB)",
|
||||
"info": "单个模型文件的最大大小。"
|
||||
}
|
||||
"en": {"label": "Max shard size (GB)", "info": "The maximum size for a model file."},
|
||||
"zh": {"label": "最大分块大小(GB)", "info": "单个模型文件的最大大小。"},
|
||||
},
|
||||
"export_quantization_bit": {
|
||||
"en": {
|
||||
"label": "Export quantization bit.",
|
||||
"info": "Quantizing the exported model."
|
||||
},
|
||||
"zh": {
|
||||
"label": "导出量化等级",
|
||||
"info": "量化导出模型。"
|
||||
}
|
||||
"en": {"label": "Export quantization bit.", "info": "Quantizing the exported model."},
|
||||
"zh": {"label": "导出量化等级", "info": "量化导出模型。"},
|
||||
},
|
||||
"export_quantization_dataset": {
|
||||
"en": {
|
||||
"label": "Export quantization dataset.",
|
||||
"info": "The calibration dataset used for quantization."
|
||||
},
|
||||
"zh": {
|
||||
"label": "导出量化数据集",
|
||||
"info": "量化过程中使用的校准数据集。"
|
||||
}
|
||||
"en": {"label": "Export quantization dataset.", "info": "The calibration dataset used for quantization."},
|
||||
"zh": {"label": "导出量化数据集", "info": "量化过程中使用的校准数据集。"},
|
||||
},
|
||||
"export_dir": {
|
||||
"en": {
|
||||
"label": "Export dir",
|
||||
"info": "Directory to save exported model."
|
||||
},
|
||||
"zh": {
|
||||
"label": "导出目录",
|
||||
"info": "保存导出模型的文件夹路径。"
|
||||
}
|
||||
"en": {"label": "Export dir", "info": "Directory to save exported model."},
|
||||
"zh": {"label": "导出目录", "info": "保存导出模型的文件夹路径。"},
|
||||
},
|
||||
"export_btn": {
|
||||
"en": {
|
||||
"value": "Export"
|
||||
},
|
||||
"zh": {
|
||||
"value": "开始导出"
|
||||
}
|
||||
}
|
||||
"export_btn": {"en": {"value": "Export"}, "zh": {"value": "开始导出"}},
|
||||
}
|
||||
|
||||
|
||||
ALERTS = {
|
||||
"err_conflict": {
|
||||
"en": "A process is in running, please abort it firstly.",
|
||||
"zh": "任务已存在,请先中断训练。"
|
||||
},
|
||||
"err_exists": {
|
||||
"en": "You have loaded a model, please unload it first.",
|
||||
"zh": "模型已存在,请先卸载模型。"
|
||||
},
|
||||
"err_no_model": {
|
||||
"en": "Please select a model.",
|
||||
"zh": "请选择模型。"
|
||||
},
|
||||
"err_no_path": {
|
||||
"en": "Model not found.",
|
||||
"zh": "模型未找到。"
|
||||
},
|
||||
"err_no_dataset": {
|
||||
"en": "Please choose a dataset.",
|
||||
"zh": "请选择数据集。"
|
||||
},
|
||||
"err_no_adapter": {
|
||||
"en": "Please select an adapter.",
|
||||
"zh": "请选择一个适配器。"
|
||||
},
|
||||
"err_no_export_dir": {
|
||||
"en": "Please provide export dir.",
|
||||
"zh": "请填写导出目录"
|
||||
},
|
||||
"err_failed": {
|
||||
"en": "Failed.",
|
||||
"zh": "训练出错。"
|
||||
},
|
||||
"err_conflict": {"en": "A process is in running, please abort it firstly.", "zh": "任务已存在,请先中断训练。"},
|
||||
"err_exists": {"en": "You have loaded a model, please unload it first.", "zh": "模型已存在,请先卸载模型。"},
|
||||
"err_no_model": {"en": "Please select a model.", "zh": "请选择模型。"},
|
||||
"err_no_path": {"en": "Model not found.", "zh": "模型未找到。"},
|
||||
"err_no_dataset": {"en": "Please choose a dataset.", "zh": "请选择数据集。"},
|
||||
"err_no_adapter": {"en": "Please select an adapter.", "zh": "请选择一个适配器。"},
|
||||
"err_no_export_dir": {"en": "Please provide export dir.", "zh": "请填写导出目录"},
|
||||
"err_failed": {"en": "Failed.", "zh": "训练出错。"},
|
||||
"err_demo": {
|
||||
"en": "Training is unavailable in demo mode, duplicate the space to a private one first.",
|
||||
"zh": "展示模式不支持训练,请先复制到私人空间。"
|
||||
"zh": "展示模式不支持训练,请先复制到私人空间。",
|
||||
},
|
||||
"err_device_count": {
|
||||
"en": "Multiple GPUs are not supported yet.",
|
||||
"zh": "尚不支持多 GPU 训练。"
|
||||
},
|
||||
"info_aborting": {
|
||||
"en": "Aborted, wait for terminating...",
|
||||
"zh": "训练中断,正在等待线程结束……"
|
||||
},
|
||||
"info_aborted": {
|
||||
"en": "Ready.",
|
||||
"zh": "准备就绪。"
|
||||
},
|
||||
"info_finished": {
|
||||
"en": "Finished.",
|
||||
"zh": "训练完毕。"
|
||||
},
|
||||
"info_loading": {
|
||||
"en": "Loading model...",
|
||||
"zh": "加载中……"
|
||||
},
|
||||
"info_unloading": {
|
||||
"en": "Unloading model...",
|
||||
"zh": "卸载中……"
|
||||
},
|
||||
"info_loaded": {
|
||||
"en": "Model loaded, now you can chat with your model!",
|
||||
"zh": "模型已加载,可以开始聊天了!"
|
||||
},
|
||||
"info_unloaded": {
|
||||
"en": "Model unloaded.",
|
||||
"zh": "模型已卸载。"
|
||||
},
|
||||
"info_exporting": {
|
||||
"en": "Exporting model...",
|
||||
"zh": "正在导出模型……"
|
||||
},
|
||||
"info_exported": {
|
||||
"en": "Model exported.",
|
||||
"zh": "模型导出完成。"
|
||||
}
|
||||
"err_device_count": {"en": "Multiple GPUs are not supported yet.", "zh": "尚不支持多 GPU 训练。"},
|
||||
"info_aborting": {"en": "Aborted, wait for terminating...", "zh": "训练中断,正在等待线程结束……"},
|
||||
"info_aborted": {"en": "Ready.", "zh": "准备就绪。"},
|
||||
"info_finished": {"en": "Finished.", "zh": "训练完毕。"},
|
||||
"info_loading": {"en": "Loading model...", "zh": "加载中……"},
|
||||
"info_unloading": {"en": "Unloading model...", "zh": "卸载中……"},
|
||||
"info_loaded": {"en": "Model loaded, now you can chat with your model!", "zh": "模型已加载,可以开始聊天了!"},
|
||||
"info_unloaded": {"en": "Model unloaded.", "zh": "模型已卸载。"},
|
||||
"info_exporting": {"en": "Exporting model...", "zh": "正在导出模型……"},
|
||||
"info_exported": {"en": "Model exported.", "zh": "模型导出完成。"},
|
||||
}
|
||||
|
||||
@@ -1,11 +1,11 @@
from typing import TYPE_CHECKING, Dict, List, Set


if TYPE_CHECKING:
from gradio.components import Component


class Manager:

def __init__(self) -> None:
self.all_elems: Dict[str, Dict[str, "Component"]] = {}

@@ -26,7 +26,7 @@ class Manager:
self.all_elems["top"]["quantization_bit"],
self.all_elems["top"]["template"],
self.all_elems["top"]["rope_scaling"],
self.all_elems["top"]["booster"]
self.all_elems["top"]["booster"],
}

def list_elems(self) -> List["Component"]:

@@ -1,12 +1,12 @@
import logging
import os
import time
import logging
import gradio as gr
from threading import Thread
from gradio.components import Component # cannot use TYPE_CHECKING here
from typing import TYPE_CHECKING, Any, Dict, Generator, Optional, Tuple

import gradio as gr
import transformers
from gradio.components import Component # cannot use TYPE_CHECKING here
from transformers.trainer import TRAINING_ARGS_NAME

from ..extras.callbacks import LogCallback
@@ -18,12 +18,12 @@ from .common import get_module, get_save_dir, load_config
from .locales import ALERTS
from .utils import gen_cmd, get_eval_results, update_process_bar


if TYPE_CHECKING:
from .manager import Manager


class Runner:

def __init__(self, manager: "Manager", demo_mode: Optional[bool] = False) -> None:
self.manager = manager
self.demo_mode = demo_mode
@@ -90,9 +90,12 @@ class Runner:
user_config = load_config()

if get("top.adapter_path"):
adapter_name_or_path = ",".join([
get_save_dir(get("top.model_name"), get("top.finetuning_type"), adapter)
for adapter in get("top.adapter_path")])
adapter_name_or_path = ",".join(
[
get_save_dir(get("top.model_name"), get("top.finetuning_type"), adapter)
for adapter in get("top.adapter_path")
]
)
else:
adapter_name_or_path = None

@@ -131,12 +134,12 @@ class Runner:
create_new_adapter=get("train.create_new_adapter"),
output_dir=get_save_dir(get("top.model_name"), get("top.finetuning_type"), get("train.output_dir")),
fp16=(get("train.compute_type") == "fp16"),
bf16=(get("train.compute_type") == "bf16")
bf16=(get("train.compute_type") == "bf16"),
)
args["disable_tqdm"] = True

if TRAINING_STAGES[get("train.training_stage")] in ["rm", "ppo", "dpo"]:
args["create_new_adapter"] = (args["quantization_bit"] is None)
args["create_new_adapter"] = args["quantization_bit"] is None

if args["stage"] == "ppo":
args["reward_model"] = get_save_dir(
@@ -161,9 +164,12 @@ class Runner:
user_config = load_config()

if get("top.adapter_path"):
adapter_name_or_path = ",".join([
get_save_dir(get("top.model_name"), get("top.finetuning_type"), adapter)
for adapter in get("top.adapter_path")])
adapter_name_or_path = ",".join(
[
get_save_dir(get("top.model_name"), get("top.finetuning_type"), adapter)
for adapter in get("top.adapter_path")
]
)
else:
adapter_name_or_path = None

@@ -187,7 +193,7 @@ class Runner:
max_new_tokens=get("eval.max_new_tokens"),
top_p=get("eval.top_p"),
temperature=get("eval.temperature"),
output_dir=get_save_dir(get("top.model_name"), get("top.finetuning_type"), get("eval.output_dir"))
output_dir=get_save_dir(get("top.model_name"), get("top.finetuning_type"), get("eval.output_dir")),
)

if get("eval.predict"):
@@ -197,7 +203,9 @@ class Runner:

return args

def _preview(self, data: Dict[Component, Any], do_train: bool) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
def _preview(
self, data: Dict[Component, Any], do_train: bool
) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
error = self._initialize(data, do_train, from_preview=True)
if error:
gr.Warning(error)
@@ -235,9 +243,11 @@ class Runner:
get = lambda name: self.running_data[self.manager.get_elem_by_name(name)]
self.running = True
lang = get("top.lang")
output_dir = get_save_dir(get("top.model_name"), get("top.finetuning_type"), get(
"{}.output_dir".format("train" if self.do_train else "eval")
))
output_dir = get_save_dir(
get("top.model_name"),
get("top.finetuning_type"),
get("{}.output_dir".format("train" if self.do_train else "eval")),
)

while self.thread.is_alive():
time.sleep(2)

@@ -1,13 +1,15 @@
import os
import json
import gradio as gr
from typing import TYPE_CHECKING, Any, Dict
import os
from datetime import datetime
from typing import TYPE_CHECKING, Any, Dict

import gradio as gr

from ..extras.packages import is_matplotlib_available
from ..extras.ploting import smooth
from .common import get_save_dir


if TYPE_CHECKING:
from ..extras.callbacks import LogCallback

@@ -22,16 +24,13 @@ def update_process_bar(callback: "LogCallback") -> Dict[str, Any]:

percentage = round(100 * callback.cur_steps / callback.max_steps, 0) if callback.max_steps != 0 else 100.0
label = "Running {:d}/{:d}: {} < {}".format(
callback.cur_steps,
callback.max_steps,
callback.elapsed_time,
callback.remaining_time
callback.cur_steps, callback.max_steps, callback.elapsed_time, callback.remaining_time
)
return gr.update(label=label, value=percentage, visible=True)


def get_time() -> str:
return datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
return datetime.now().strftime("%Y-%m-%d-%H-%M-%S")


def can_quantize(finetuning_type: str) -> Dict[str, Any]: