update webUI, fix #179

This commit is contained in:
hiyouga
2023-07-18 15:35:17 +08:00
parent b9fe83fb75
commit 12d8a8633f
9 changed files with 247 additions and 154 deletions

View File

@@ -12,8 +12,8 @@ from llmtuner.webui.utils import can_preview, get_preview, gen_plot
def create_sft_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str, Component]:
with gr.Row():
dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, interactive=True, scale=2)
dataset = gr.Dropdown(multiselect=True, interactive=True, scale=4)
dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=2)
dataset = gr.Dropdown(multiselect=True, scale=4)
preview_btn = gr.Button(interactive=False, scale=1)
preview_box, preview_count, preview_samples, close_btn = create_preview_box()
@@ -23,22 +23,21 @@ def create_sft_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str,
preview_btn.click(get_preview, [dataset_dir, dataset], [preview_count, preview_samples, preview_box])
with gr.Row():
learning_rate = gr.Textbox(value="5e-5", interactive=True)
num_train_epochs = gr.Textbox(value="3.0", interactive=True)
max_samples = gr.Textbox(value="100000", interactive=True)
quantization_bit = gr.Dropdown([8, 4])
learning_rate = gr.Textbox(value="5e-5")
num_train_epochs = gr.Textbox(value="3.0")
max_samples = gr.Textbox(value="100000")
with gr.Row():
batch_size = gr.Slider(value=4, minimum=1, maximum=128, step=1, interactive=True)
gradient_accumulation_steps = gr.Slider(value=4, minimum=1, maximum=32, step=1, interactive=True)
batch_size = gr.Slider(value=4, minimum=1, maximum=128, step=1)
gradient_accumulation_steps = gr.Slider(value=4, minimum=1, maximum=32, step=1)
lr_scheduler_type = gr.Dropdown(
value="cosine", choices=[scheduler.value for scheduler in SchedulerType], interactive=True
value="cosine", choices=[scheduler.value for scheduler in SchedulerType]
)
fp16 = gr.Checkbox(value=True)
with gr.Row():
logging_steps = gr.Slider(value=5, minimum=5, maximum=1000, step=5, interactive=True)
save_steps = gr.Slider(value=100, minimum=10, maximum=2000, step=10, interactive=True)
logging_steps = gr.Slider(value=5, minimum=5, maximum=1000, step=5)
save_steps = gr.Slider(value=100, minimum=10, maximum=2000, step=10)
with gr.Row():
start_btn = gr.Button()
@@ -55,11 +54,25 @@ def create_sft_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str,
start_btn.click(
runner.run_train,
[
top_elems["lang"], top_elems["model_name"], top_elems["checkpoints"],
top_elems["finetuning_type"], top_elems["template"],
dataset, dataset_dir, learning_rate, num_train_epochs, max_samples,
fp16, quantization_bit, batch_size, gradient_accumulation_steps,
lr_scheduler_type, logging_steps, save_steps, output_dir
top_elems["lang"],
top_elems["model_name"],
top_elems["checkpoints"],
top_elems["finetuning_type"],
top_elems["quantization_bit"],
top_elems["template"],
top_elems["source_prefix"],
dataset_dir,
dataset,
learning_rate,
num_train_epochs,
max_samples,
batch_size,
gradient_accumulation_steps,
lr_scheduler_type,
fp16,
logging_steps,
save_steps,
output_dir
],
[output_box]
)
@@ -79,7 +92,6 @@ def create_sft_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str,
learning_rate=learning_rate,
num_train_epochs=num_train_epochs,
max_samples=max_samples,
quantization_bit=quantization_bit,
batch_size=batch_size,
gradient_accumulation_steps=gradient_accumulation_steps,
lr_scheduler_type=lr_scheduler_type,