mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-25 07:12:50 +08:00
update webui
Former-commit-id: d0842f682897cb227cda9e9747f42a7281970463
This commit is contained in:
parent
bd52e2b404
commit
cefe7f7bcf
@ -4,7 +4,7 @@ CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
|
|||||||
--stage orpo \
|
--stage orpo \
|
||||||
--do_train \
|
--do_train \
|
||||||
--model_name_or_path meta-llama/Llama-2-7b-hf \
|
--model_name_or_path meta-llama/Llama-2-7b-hf \
|
||||||
--dataset comparison_gpt4_en \
|
--dataset orca_rlhf \
|
||||||
--dataset_dir ../../data \
|
--dataset_dir ../../data \
|
||||||
--template default \
|
--template default \
|
||||||
--finetuning_type lora \
|
--finetuning_type lora \
|
||||||
|
@ -21,10 +21,10 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
|||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
training_stage = gr.Dropdown(
|
training_stage = gr.Dropdown(
|
||||||
choices=list(TRAINING_STAGES.keys()), value=list(TRAINING_STAGES.keys())[0], scale=2
|
choices=list(TRAINING_STAGES.keys()), value=list(TRAINING_STAGES.keys())[0], scale=1
|
||||||
)
|
)
|
||||||
dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=2)
|
dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=1)
|
||||||
dataset = gr.Dropdown(multiselect=True, scale=4)
|
dataset = gr.Dropdown(multiselect=True, scale=2, allow_custom_value=True)
|
||||||
preview_elems = create_preview_box(dataset_dir, dataset)
|
preview_elems = create_preview_box(dataset_dir, dataset)
|
||||||
|
|
||||||
input_elems.update({training_stage, dataset_dir, dataset})
|
input_elems.update({training_stage, dataset_dir, dataset})
|
||||||
@ -75,11 +75,17 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
|||||||
optim = gr.Textbox(value="adamw_torch")
|
optim = gr.Textbox(value="adamw_torch")
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
resize_vocab = gr.Checkbox()
|
with gr.Column():
|
||||||
packing = gr.Checkbox()
|
resize_vocab = gr.Checkbox()
|
||||||
upcast_layernorm = gr.Checkbox()
|
packing = gr.Checkbox()
|
||||||
use_llama_pro = gr.Checkbox()
|
|
||||||
shift_attn = gr.Checkbox()
|
with gr.Column():
|
||||||
|
upcast_layernorm = gr.Checkbox()
|
||||||
|
use_llama_pro = gr.Checkbox()
|
||||||
|
|
||||||
|
with gr.Column():
|
||||||
|
shift_attn = gr.Checkbox()
|
||||||
|
report_to = gr.Checkbox()
|
||||||
|
|
||||||
input_elems.update(
|
input_elems.update(
|
||||||
{
|
{
|
||||||
@ -93,6 +99,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
|||||||
upcast_layernorm,
|
upcast_layernorm,
|
||||||
use_llama_pro,
|
use_llama_pro,
|
||||||
shift_attn,
|
shift_attn,
|
||||||
|
report_to,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
elem_dict.update(
|
elem_dict.update(
|
||||||
@ -108,6 +115,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
|||||||
upcast_layernorm=upcast_layernorm,
|
upcast_layernorm=upcast_layernorm,
|
||||||
use_llama_pro=use_llama_pro,
|
use_llama_pro=use_llama_pro,
|
||||||
shift_attn=shift_attn,
|
shift_attn=shift_attn,
|
||||||
|
report_to=report_to,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -536,6 +536,20 @@ LOCALES = {
|
|||||||
"info": "使用 LongLoRA 提出的 shift short attention。",
|
"info": "使用 LongLoRA 提出的 shift short attention。",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
"report_to": {
|
||||||
|
"en": {
|
||||||
|
"label": "Enable external logger",
|
||||||
|
"info": "Use TensorBoard or wandb to log experiment.",
|
||||||
|
},
|
||||||
|
"ru": {
|
||||||
|
"label": "Включить внешний регистратор",
|
||||||
|
"info": "Использовать TensorBoard или wandb для ведения журнала экспериментов.",
|
||||||
|
},
|
||||||
|
"zh": {
|
||||||
|
"label": "启用外部记录面板",
|
||||||
|
"info": "使用 TensorBoard 或 wandb 记录实验。",
|
||||||
|
},
|
||||||
|
},
|
||||||
"freeze_tab": {
|
"freeze_tab": {
|
||||||
"en": {
|
"en": {
|
||||||
"label": "Freeze tuning configurations",
|
"label": "Freeze tuning configurations",
|
||||||
|
@ -80,20 +80,18 @@ class Runner:
|
|||||||
if not from_preview and not is_torch_cuda_available():
|
if not from_preview and not is_torch_cuda_available():
|
||||||
gr.Warning(ALERTS["warn_no_cuda"][lang])
|
gr.Warning(ALERTS["warn_no_cuda"][lang])
|
||||||
|
|
||||||
self.aborted = False
|
|
||||||
self.logger_handler.reset()
|
self.logger_handler.reset()
|
||||||
self.trainer_callback = LogCallback(self)
|
self.trainer_callback = LogCallback(self)
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
def _finalize(self, lang: str, finish_info: str) -> str:
|
def _finalize(self, lang: str, finish_info: str) -> str:
|
||||||
|
finish_info = ALERTS["info_aborted"][lang] if self.aborted else finish_info
|
||||||
self.thread = None
|
self.thread = None
|
||||||
self.running_data = None
|
self.aborted = False
|
||||||
self.running = False
|
self.running = False
|
||||||
|
self.running_data = None
|
||||||
torch_gc()
|
torch_gc()
|
||||||
if self.aborted:
|
return finish_info
|
||||||
return ALERTS["info_aborted"][lang]
|
|
||||||
else:
|
|
||||||
return finish_info
|
|
||||||
|
|
||||||
def _parse_train_args(self, data: Dict["Component", Any]) -> Dict[str, Any]:
|
def _parse_train_args(self, data: Dict["Component", Any]) -> Dict[str, Any]:
|
||||||
get = lambda elem_id: data[self.manager.get_elem_by_id(elem_id)]
|
get = lambda elem_id: data[self.manager.get_elem_by_id(elem_id)]
|
||||||
@ -141,6 +139,7 @@ class Runner:
|
|||||||
upcast_layernorm=get("train.upcast_layernorm"),
|
upcast_layernorm=get("train.upcast_layernorm"),
|
||||||
use_llama_pro=get("train.use_llama_pro"),
|
use_llama_pro=get("train.use_llama_pro"),
|
||||||
shift_attn=get("train.shift_attn"),
|
shift_attn=get("train.shift_attn"),
|
||||||
|
report_to="all" if get("train.report_to") else "none",
|
||||||
use_galore=get("train.use_galore"),
|
use_galore=get("train.use_galore"),
|
||||||
output_dir=get_save_dir(get("top.model_name"), get("top.finetuning_type"), get("train.output_dir")),
|
output_dir=get_save_dir(get("top.model_name"), get("top.finetuning_type"), get("train.output_dir")),
|
||||||
fp16=(get("train.compute_type") == "fp16"),
|
fp16=(get("train.compute_type") == "fp16"),
|
||||||
|
Loading…
x
Reference in New Issue
Block a user