From bb6b4823add4cd4818587ac1a2f427ad075adbce Mon Sep 17 00:00:00 2001
From: hiyouga
Date: Thu, 30 Nov 2023 20:03:32 +0800
Subject: [PATCH] fix #1682

Former-commit-id: a38dbf55e32a18838eea7f254fd9022fe33bca08
---
 src/llmtuner/webui/components/eval.py |  5 +++--
 src/llmtuner/webui/engine.py          |  5 ++++-
 src/llmtuner/webui/locales.py         | 10 +++++-----
 src/llmtuner/webui/runner.py          | 19 ++++++++-----------
 4 files changed, 20 insertions(+), 19 deletions(-)

diff --git a/src/llmtuner/webui/components/eval.py b/src/llmtuner/webui/components/eval.py
index 36c994a6..0718c63e 100644
--- a/src/llmtuner/webui/components/eval.py
+++ b/src/llmtuner/webui/components/eval.py
@@ -38,10 +38,11 @@ def create_eval_tab(engine: "Engine") -> Dict[str, "Component"]:
         max_new_tokens = gr.Slider(10, 2048, value=128, step=1)
         top_p = gr.Slider(0.01, 1, value=0.7, step=0.01)
         temperature = gr.Slider(0.01, 1.5, value=0.95, step=0.01)
+        output_dir = gr.Textbox()
 
-    input_elems.update({max_new_tokens, top_p, temperature})
+    input_elems.update({max_new_tokens, top_p, temperature, output_dir})
     elem_dict.update(dict(
-        max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
+        max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature, output_dir=output_dir
     ))
 
     with gr.Row():
diff --git a/src/llmtuner/webui/engine.py b/src/llmtuner/webui/engine.py
index bdac09dd..991b281c 100644
--- a/src/llmtuner/webui/engine.py
+++ b/src/llmtuner/webui/engine.py
@@ -49,7 +49,10 @@ class Engine:
                 else:
                     yield self._form_dict({"eval.resume_btn": {"value": True}})
             else:
-                yield self._form_dict({"train.output_dir": {"value": get_time()}})
+                yield self._form_dict({
+                    "train.output_dir": {"value": "train_" + get_time()},
+                    "eval.output_dir": {"value": "eval_" + get_time()},
+                })
 
     def change_lang(self, lang: str) -> Dict[Component, Dict[str, Any]]:
         return {
diff --git a/src/llmtuner/webui/locales.py b/src/llmtuner/webui/locales.py
index 769cf15d..5f5609d8 100644
--- a/src/llmtuner/webui/locales.py
+++ b/src/llmtuner/webui/locales.py
@@ -132,7 +132,7 @@ LOCALES = {
     "dataset_dir": {
         "en": {
             "label": "Data dir",
-            "info": "Path of the data directory."
+            "info": "Path to the data directory."
         },
         "zh": {
             "label": "数据路径",
@@ -475,12 +475,12 @@ LOCALES = {
     },
     "output_dir": {
         "en": {
-            "label": "Checkpoint name",
-            "info": "Directory to save checkpoint."
+            "label": "Output dir",
+            "info": "Directory for saving results."
         },
         "zh": {
-            "label": "断点名称",
-            "info": "保存模型断点的文件夹名称。"
+            "label": "输出目录",
+            "info": "保存结果的路径。"
         }
     },
     "output_box": {
diff --git a/src/llmtuner/webui/runner.py b/src/llmtuner/webui/runner.py
index 7789fc4d..664f3354 100644
--- a/src/llmtuner/webui/runner.py
+++ b/src/llmtuner/webui/runner.py
@@ -87,9 +87,9 @@ class Runner:
         user_config = load_config()
 
         if get("top.checkpoints"):
-            checkpoint_dir = ",".join([get_save_dir(
-                get("top.model_name"), get("top.finetuning_type"), ckpt
-            ) for ckpt in get("top.checkpoints")])
+            checkpoint_dir = ",".join([
+                get_save_dir(get("top.model_name"), get("top.finetuning_type"), ckpt) for ckpt in get("top.checkpoints")
+            ])
         else:
             checkpoint_dir = None
 
@@ -160,15 +160,11 @@ class Runner:
         user_config = load_config()
 
         if get("top.checkpoints"):
-            checkpoint_dir = ",".join([get_save_dir(
-                get("top.model_name"), get("top.finetuning_type"), ckpt
-            ) for ckpt in get("top.checkpoints")])
-            output_dir = get_save_dir(
-                get("top.model_name"), get("top.finetuning_type"), "eval_" + "_".join(get("top.checkpoints"))
-            )
+            checkpoint_dir = ",".join([
+                get_save_dir(get("top.model_name"), get("top.finetuning_type"), ckpt) for ckpt in get("top.checkpoints")
+            ])
         else:
             checkpoint_dir = None
-            output_dir = get_save_dir(get("top.model_name"), get("top.finetuning_type"), "eval_base")
 
         args = dict(
             stage="sft",
@@ -192,7 +188,7 @@ class Runner:
             max_new_tokens=get("eval.max_new_tokens"),
             top_p=get("eval.top_p"),
             temperature=get("eval.temperature"),
-            output_dir=output_dir
+            output_dir=get_save_dir(get("top.model_name"), get("top.finetuning_type"), get("eval.output_dir"))
        )
 
         if get("eval.predict"):
@@ -242,6 +238,7 @@ class Runner:
         output_dir = get_save_dir(get("top.model_name"), get("top.finetuning_type"), get(
             "{}.output_dir".format("train" if self.do_train else "eval")
         ))
+
         while self.thread.is_alive():
             time.sleep(2)
             if self.aborted:
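
Note: the sketch below illustrates how the new user-editable "eval.output_dir" textbox value (pre-filled with "eval_" + timestamp by Engine.resume) is resolved into an on-disk path before being passed as output_dir. The get_save_dir stand-in here is an assumption for illustration only; its base directory and join order are not taken from the project's actual helper.

import os

# Assumed base directory for saved runs; illustrative only, not the project's constant.
DEFAULT_SAVE_DIR = "saves"


def get_save_dir(model_name: str, finetuning_type: str, output_dir: str) -> str:
    # Simplified stand-in: join the base save dir, model name, finetuning type,
    # and the folder name entered in the webui textbox.
    return os.path.join(DEFAULT_SAVE_DIR, model_name, finetuning_type, output_dir)


# With this patch, the evaluation output folder comes from the "eval.output_dir"
# textbox instead of a name derived from the selected checkpoints.
print(get_save_dir("LLaMA2-7B", "lora", "eval_2023-11-30-20-03-32"))
# -> saves/LLaMA2-7B/lora/eval_2023-11-30-20-03-32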