mirror of https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-12-17 12:20:37 +08:00
fix eval in webui
@@ -17,6 +17,7 @@ from ..extras.constants import (
     TRAINING_STAGES,
     DownloadSource,
 )
 from ..extras.logging import get_logger
 from ..extras.misc import use_modelscope
 from ..extras.packages import is_gradio_available

@@ -25,6 +26,9 @@ if is_gradio_available():
     import gradio as gr


 logger = get_logger(__name__)


 ADAPTER_NAMES = {WEIGHTS_NAME, SAFETENSORS_WEIGHTS_NAME}
 DEFAULT_CACHE_DIR = "cache"
 DEFAULT_CONFIG_DIR = "config"
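For orientation, an illustrative sketch of how an ADAPTER_NAMES-style set is typically used: a checkpoint directory counts as an adapter if it contains one of the known weight file names. This is not code from the commit, and the file names below are assumptions for the sketch, not the project's constants.

import os

ADAPTER_NAMES = {"adapter_model.bin", "adapter_model.safetensors"}  # assumed values


def is_adapter_dir(path: str) -> bool:
    # a directory counts as an adapter if any known weight file is present
    return os.path.isdir(path) and any(
        os.path.isfile(os.path.join(path, name)) for name in ADAPTER_NAMES
    )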
@@ -128,11 +132,15 @@ def list_adapters(model_name: str, finetuning_type: str) -> "gr.Dropdown":


 def load_dataset_info(dataset_dir: str) -> Dict[str, Dict[str, Any]]:
     if dataset_dir == "ONLINE":
         logger.info("dataset_dir is ONLINE, using online dataset.")
         return {}

     try:
         with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
             return json.load(f)
     except Exception as err:
-        print("Cannot open {} due to {}.".format(os.path.join(dataset_dir, DATA_CONFIG), str(err)))
+        logger.warning("Cannot open {} due to {}.".format(os.path.join(dataset_dir, DATA_CONFIG), str(err)))
         return {}
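A minimal sketch, not the repository's code, of how the returned dict feeds the dataset dropdown: the keys of the DATA_CONFIG file are the dataset names offered to the user, and an empty dict (ONLINE mode or an unreadable config) simply yields an empty list of choices.

from typing import Any, Dict

import gradio as gr


def dataset_choices(dataset_info: Dict[str, Dict[str, Any]]) -> "gr.Dropdown":
    # hypothetical helper; the webui's own list_dataset plays this role
    return gr.Dropdown(choices=list(dataset_info.keys()), value=[], multiselect=True)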
@@ -21,16 +21,16 @@ def create_eval_tab(engine: "Engine") -> Dict[str, "Component"]:

     with gr.Row():
         dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=2)
-        dataset = gr.Dropdown(multiselect=True, scale=4)
+        dataset = gr.Dropdown(multiselect=True, allow_custom_value=True, scale=4)
         preview_elems = create_preview_box(dataset_dir, dataset)

     input_elems.update({dataset_dir, dataset})
     elem_dict.update(dict(dataset_dir=dataset_dir, dataset=dataset, **preview_elems))

     with gr.Row():
-        cutoff_len = gr.Slider(value=1024, minimum=4, maximum=8192, step=1)
+        cutoff_len = gr.Slider(value=1024, minimum=4, maximum=65536, step=1)
         max_samples = gr.Textbox(value="100000")
-        batch_size = gr.Slider(value=8, minimum=1, maximum=512, step=1)
+        batch_size = gr.Slider(value=2, minimum=1, maximum=1024, step=1)
         predict = gr.Checkbox(value=True)

     input_elems.update({cutoff_len, max_samples, batch_size, predict})
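A note on the Dropdown change, sketched outside the commit: with allow_custom_value=True, Gradio accepts values that are not currently in choices, which avoids errors when the dataset list is refreshed after dataset_dir changes. Minimal example, assuming Gradio 4.x and hypothetical dataset names:

import gradio as gr

with gr.Blocks() as demo:
    # hypothetical choices; in the webui these come from the dataset config
    dataset = gr.Dropdown(
        choices=["alpaca_en", "alpaca_zh"],
        multiselect=True,
        allow_custom_value=True,  # accept values outside the current choices
    )

# demo.launch()  # uncomment to serve the sketch locally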
@@ -48,30 +48,27 @@ def create_eval_tab(engine: "Engine") -> Dict[str, "Component"]:
     with gr.Row():
         cmd_preview_btn = gr.Button()
         start_btn = gr.Button(variant="primary")
         stop_btn = gr.Button(variant="stop")

     with gr.Row():
         resume_btn = gr.Checkbox(visible=False, interactive=False)
-        process_bar = gr.Slider(visible=False, interactive=False)
+        progress_bar = gr.Slider(visible=False, interactive=False)

     with gr.Row():
         output_box = gr.Markdown()

-    output_elems = [output_box, process_bar]
+    output_elems = [output_box, progress_bar]
     elem_dict.update(
         dict(
             cmd_preview_btn=cmd_preview_btn,
             start_btn=start_btn,
             stop_btn=stop_btn,
             resume_btn=resume_btn,
-            process_bar=process_bar,
+            progress_bar=progress_bar,
             output_box=output_box,
         )
     )

     cmd_preview_btn.click(engine.runner.preview_eval, input_elems, output_elems, concurrency_limit=None)
     start_btn.click(engine.runner.run_eval, input_elems, output_elems)
     stop_btn.click(engine.runner.set_abort)
     resume_btn.change(engine.runner.monitor, outputs=output_elems, concurrency_limit=None)

     dataset_dir.change(list_dataset, [dataset_dir], [dataset], queue=False)
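For context on how these outputs are consumed, a sketch under assumptions rather than the project's Runner: Gradio lets an event handler be a generator that yields a dict keyed by output components, which is how output_box and progress_bar are streamed to the page while a job runs.

import time

import gradio as gr

with gr.Blocks() as demo:
    start_btn = gr.Button("start")
    output_box = gr.Markdown()
    progress_bar = gr.Slider(visible=False, interactive=False)

    def monitor():
        # each yielded dict maps a component to its new value or updated state
        for step in range(1, 4):
            time.sleep(0.1)
            yield {
                output_box: "step {}/3 running".format(step),
                progress_bar: gr.Slider(value=step * 33, visible=True),
            }

    start_btn.click(monitor, outputs=[output_box, progress_bar])

# demo.launch()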
@@ -27,7 +27,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
             choices=list(TRAINING_STAGES.keys()), value=list(TRAINING_STAGES.keys())[0], scale=1
         )
         dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=1)
-        dataset = gr.Dropdown(multiselect=True, scale=4)
+        dataset = gr.Dropdown(multiselect=True, allow_custom_value=True, scale=4)
         preview_elems = create_preview_box(dataset_dir, dataset)

     input_elems.update({training_stage, dataset_dir, dataset})
@@ -52,7 +52,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
     )

     with gr.Row():
-        cutoff_len = gr.Slider(value=1024, minimum=4, maximum=16384, step=1)
+        cutoff_len = gr.Slider(value=1024, minimum=4, maximum=65536, step=1)
         batch_size = gr.Slider(value=2, minimum=1, maximum=1024, step=1)
         gradient_accumulation_steps = gr.Slider(value=8, minimum=1, maximum=1024, step=1)
         val_size = gr.Slider(value=0, minimum=0, maximum=1, step=0.001)
@@ -299,12 +299,12 @@ class Runner:
                 progress_bar: gr.Slider(visible=False),
             }
         else:
-            running_log, running_progress, running_loss = get_trainer_info(output_path)
+            running_log, running_progress, running_loss = get_trainer_info(output_path, self.do_train)
             return_dict = {
                 output_box: running_log,
                 progress_bar: running_progress,
             }
-            if self.do_train and running_loss is not None:
+            if running_loss is not None:
                 return_dict[loss_viewer] = running_loss

             yield return_dict
@@ -63,7 +63,7 @@ def get_time() -> str:
     return datetime.now().strftime(r"%Y-%m-%d-%H-%M-%S")


-def get_trainer_info(output_path: os.PathLike) -> Tuple[str, "gr.Slider", Optional["gr.Plot"]]:
+def get_trainer_info(output_path: os.PathLike, do_train: bool) -> Tuple[str, "gr.Slider", Optional["gr.Plot"]]:
     running_log = ""
     running_progress = gr.Slider(visible=False)
     running_loss = None
@@ -91,7 +91,7 @@ def get_trainer_info(output_path: os.PathLike) -> Tuple[str, "gr.Slider", Option
         )
         running_progress = gr.Slider(label=label, value=percentage, visible=True)

-    if is_matplotlib_available():
+    if do_train and is_matplotlib_available():
         running_loss = gr.Plot(gen_loss_plot(trainer_log))

     return running_log, running_progress, running_loss
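The effect of the new do_train flag, sketched outside the real get_trainer_info (the log file name and fields below are assumptions for illustration): only training runs attempt to build a loss plot, so evaluation and prediction jobs monitored through the same code path no longer produce an empty or misleading plot.

import json
import os
from typing import Any, Dict, List, Optional


def read_trainer_log(output_path: str) -> List[Dict[str, Any]]:
    # "trainer_log.jsonl" is an assumed file name for this sketch
    log_file = os.path.join(output_path, "trainer_log.jsonl")
    if not os.path.isfile(log_file):
        return []
    with open(log_file, "r", encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


def maybe_loss_points(output_path: str, do_train: bool) -> Optional[List[float]]:
    # mirrors the gating added in this commit: no loss curve for eval-only runs
    if not do_train:
        return None
    return [entry["loss"] for entry in read_trainer_log(output_path) if "loss" in entry]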