Former-commit-id: a38dbf55e32a18838eea7f254fd9022fe33bca08
This commit is contained in:
hiyouga 2023-11-30 20:03:32 +08:00
parent 1c43fb6a41
commit bb6b4823ad
4 changed files with 20 additions and 19 deletions

View File

@ -38,10 +38,11 @@ def create_eval_tab(engine: "Engine") -> Dict[str, "Component"]:
max_new_tokens = gr.Slider(10, 2048, value=128, step=1) max_new_tokens = gr.Slider(10, 2048, value=128, step=1)
top_p = gr.Slider(0.01, 1, value=0.7, step=0.01) top_p = gr.Slider(0.01, 1, value=0.7, step=0.01)
temperature = gr.Slider(0.01, 1.5, value=0.95, step=0.01) temperature = gr.Slider(0.01, 1.5, value=0.95, step=0.01)
output_dir = gr.Textbox()
input_elems.update({max_new_tokens, top_p, temperature}) input_elems.update({max_new_tokens, top_p, temperature, output_dir})
elem_dict.update(dict( elem_dict.update(dict(
max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature, output_dir=output_dir
)) ))
with gr.Row(): with gr.Row():

View File

@ -49,7 +49,10 @@ class Engine:
else: else:
yield self._form_dict({"eval.resume_btn": {"value": True}}) yield self._form_dict({"eval.resume_btn": {"value": True}})
else: else:
yield self._form_dict({"train.output_dir": {"value": get_time()}}) yield self._form_dict({
"train.output_dir": {"value": "train_" + get_time()},
"eval.output_dir": {"value": "eval_" + get_time()},
})
def change_lang(self, lang: str) -> Dict[Component, Dict[str, Any]]: def change_lang(self, lang: str) -> Dict[Component, Dict[str, Any]]:
return { return {

View File

@ -132,7 +132,7 @@ LOCALES = {
"dataset_dir": { "dataset_dir": {
"en": { "en": {
"label": "Data dir", "label": "Data dir",
"info": "Path of the data directory." "info": "Path to the data directory."
}, },
"zh": { "zh": {
"label": "数据路径", "label": "数据路径",
@ -475,12 +475,12 @@ LOCALES = {
}, },
"output_dir": { "output_dir": {
"en": { "en": {
"label": "Checkpoint name", "label": "Output dir",
"info": "Directory to save checkpoint." "info": "Directory for saving results."
}, },
"zh": { "zh": {
"label": "断点名称", "label": "输出目录",
"info": "保存模型断点的文件夹名称" "info": "保存结果的路径"
} }
}, },
"output_box": { "output_box": {

View File

@ -87,9 +87,9 @@ class Runner:
user_config = load_config() user_config = load_config()
if get("top.checkpoints"): if get("top.checkpoints"):
checkpoint_dir = ",".join([get_save_dir( checkpoint_dir = ",".join([
get("top.model_name"), get("top.finetuning_type"), ckpt get_save_dir(get("top.model_name"), get("top.finetuning_type"), ckpt) for ckpt in get("top.checkpoints")
) for ckpt in get("top.checkpoints")]) ])
else: else:
checkpoint_dir = None checkpoint_dir = None
@ -160,15 +160,11 @@ class Runner:
user_config = load_config() user_config = load_config()
if get("top.checkpoints"): if get("top.checkpoints"):
checkpoint_dir = ",".join([get_save_dir( checkpoint_dir = ",".join([
get("top.model_name"), get("top.finetuning_type"), ckpt get_save_dir(get("top.model_name"), get("top.finetuning_type"), ckpt) for ckpt in get("top.checkpoints")
) for ckpt in get("top.checkpoints")]) ])
output_dir = get_save_dir(
get("top.model_name"), get("top.finetuning_type"), "eval_" + "_".join(get("top.checkpoints"))
)
else: else:
checkpoint_dir = None checkpoint_dir = None
output_dir = get_save_dir(get("top.model_name"), get("top.finetuning_type"), "eval_base")
args = dict( args = dict(
stage="sft", stage="sft",
@ -192,7 +188,7 @@ class Runner:
max_new_tokens=get("eval.max_new_tokens"), max_new_tokens=get("eval.max_new_tokens"),
top_p=get("eval.top_p"), top_p=get("eval.top_p"),
temperature=get("eval.temperature"), temperature=get("eval.temperature"),
output_dir=output_dir output_dir=get_save_dir(get("top.model_name"), get("top.finetuning_type"), get("eval.output_dir"))
) )
if get("eval.predict"): if get("eval.predict"):
@ -242,6 +238,7 @@ class Runner:
output_dir = get_save_dir(get("top.model_name"), get("top.finetuning_type"), get( output_dir = get_save_dir(get("top.model_name"), get("top.finetuning_type"), get(
"{}.output_dir".format("train" if self.do_train else "eval") "{}.output_dir".format("train" if self.do_train else "eval")
)) ))
while self.thread.is_alive(): while self.thread.is_alive():
time.sleep(2) time.sleep(2)
if self.aborted: if self.aborted: