mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-04 20:52:59 +08:00
parent
1c43fb6a41
commit
bb6b4823ad
@ -38,10 +38,11 @@ def create_eval_tab(engine: "Engine") -> Dict[str, "Component"]:
|
|||||||
max_new_tokens = gr.Slider(10, 2048, value=128, step=1)
|
max_new_tokens = gr.Slider(10, 2048, value=128, step=1)
|
||||||
top_p = gr.Slider(0.01, 1, value=0.7, step=0.01)
|
top_p = gr.Slider(0.01, 1, value=0.7, step=0.01)
|
||||||
temperature = gr.Slider(0.01, 1.5, value=0.95, step=0.01)
|
temperature = gr.Slider(0.01, 1.5, value=0.95, step=0.01)
|
||||||
|
output_dir = gr.Textbox()
|
||||||
|
|
||||||
input_elems.update({max_new_tokens, top_p, temperature})
|
input_elems.update({max_new_tokens, top_p, temperature, output_dir})
|
||||||
elem_dict.update(dict(
|
elem_dict.update(dict(
|
||||||
max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
|
max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature, output_dir=output_dir
|
||||||
))
|
))
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
|
@ -49,7 +49,10 @@ class Engine:
|
|||||||
else:
|
else:
|
||||||
yield self._form_dict({"eval.resume_btn": {"value": True}})
|
yield self._form_dict({"eval.resume_btn": {"value": True}})
|
||||||
else:
|
else:
|
||||||
yield self._form_dict({"train.output_dir": {"value": get_time()}})
|
yield self._form_dict({
|
||||||
|
"train.output_dir": {"value": "train_" + get_time()},
|
||||||
|
"eval.output_dir": {"value": "eval_" + get_time()},
|
||||||
|
})
|
||||||
|
|
||||||
def change_lang(self, lang: str) -> Dict[Component, Dict[str, Any]]:
|
def change_lang(self, lang: str) -> Dict[Component, Dict[str, Any]]:
|
||||||
return {
|
return {
|
||||||
|
@ -132,7 +132,7 @@ LOCALES = {
|
|||||||
"dataset_dir": {
|
"dataset_dir": {
|
||||||
"en": {
|
"en": {
|
||||||
"label": "Data dir",
|
"label": "Data dir",
|
||||||
"info": "Path of the data directory."
|
"info": "Path to the data directory."
|
||||||
},
|
},
|
||||||
"zh": {
|
"zh": {
|
||||||
"label": "数据路径",
|
"label": "数据路径",
|
||||||
@ -475,12 +475,12 @@ LOCALES = {
|
|||||||
},
|
},
|
||||||
"output_dir": {
|
"output_dir": {
|
||||||
"en": {
|
"en": {
|
||||||
"label": "Checkpoint name",
|
"label": "Output dir",
|
||||||
"info": "Directory to save checkpoint."
|
"info": "Directory for saving results."
|
||||||
},
|
},
|
||||||
"zh": {
|
"zh": {
|
||||||
"label": "断点名称",
|
"label": "输出目录",
|
||||||
"info": "保存模型断点的文件夹名称。"
|
"info": "保存结果的路径。"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"output_box": {
|
"output_box": {
|
||||||
|
@ -87,9 +87,9 @@ class Runner:
|
|||||||
user_config = load_config()
|
user_config = load_config()
|
||||||
|
|
||||||
if get("top.checkpoints"):
|
if get("top.checkpoints"):
|
||||||
checkpoint_dir = ",".join([get_save_dir(
|
checkpoint_dir = ",".join([
|
||||||
get("top.model_name"), get("top.finetuning_type"), ckpt
|
get_save_dir(get("top.model_name"), get("top.finetuning_type"), ckpt) for ckpt in get("top.checkpoints")
|
||||||
) for ckpt in get("top.checkpoints")])
|
])
|
||||||
else:
|
else:
|
||||||
checkpoint_dir = None
|
checkpoint_dir = None
|
||||||
|
|
||||||
@ -160,15 +160,11 @@ class Runner:
|
|||||||
user_config = load_config()
|
user_config = load_config()
|
||||||
|
|
||||||
if get("top.checkpoints"):
|
if get("top.checkpoints"):
|
||||||
checkpoint_dir = ",".join([get_save_dir(
|
checkpoint_dir = ",".join([
|
||||||
get("top.model_name"), get("top.finetuning_type"), ckpt
|
get_save_dir(get("top.model_name"), get("top.finetuning_type"), ckpt) for ckpt in get("top.checkpoints")
|
||||||
) for ckpt in get("top.checkpoints")])
|
])
|
||||||
output_dir = get_save_dir(
|
|
||||||
get("top.model_name"), get("top.finetuning_type"), "eval_" + "_".join(get("top.checkpoints"))
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
checkpoint_dir = None
|
checkpoint_dir = None
|
||||||
output_dir = get_save_dir(get("top.model_name"), get("top.finetuning_type"), "eval_base")
|
|
||||||
|
|
||||||
args = dict(
|
args = dict(
|
||||||
stage="sft",
|
stage="sft",
|
||||||
@ -192,7 +188,7 @@ class Runner:
|
|||||||
max_new_tokens=get("eval.max_new_tokens"),
|
max_new_tokens=get("eval.max_new_tokens"),
|
||||||
top_p=get("eval.top_p"),
|
top_p=get("eval.top_p"),
|
||||||
temperature=get("eval.temperature"),
|
temperature=get("eval.temperature"),
|
||||||
output_dir=output_dir
|
output_dir=get_save_dir(get("top.model_name"), get("top.finetuning_type"), get("eval.output_dir"))
|
||||||
)
|
)
|
||||||
|
|
||||||
if get("eval.predict"):
|
if get("eval.predict"):
|
||||||
@ -242,6 +238,7 @@ class Runner:
|
|||||||
output_dir = get_save_dir(get("top.model_name"), get("top.finetuning_type"), get(
|
output_dir = get_save_dir(get("top.model_name"), get("top.finetuning_type"), get(
|
||||||
"{}.output_dir".format("train" if self.do_train else "eval")
|
"{}.output_dir".format("train" if self.do_train else "eval")
|
||||||
))
|
))
|
||||||
|
|
||||||
while self.thread.is_alive():
|
while self.thread.is_alive():
|
||||||
time.sleep(2)
|
time.sleep(2)
|
||||||
if self.aborted:
|
if self.aborted:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user