Former-commit-id: 779aae83d2
This commit is contained in:
hiyouga
2024-07-18 22:06:12 +08:00
parent c8e77c11d1
commit 34f16cc635
7 changed files with 56 additions and 36 deletions

View File

@@ -44,11 +44,10 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
)
dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=1)
dataset = gr.Dropdown(multiselect=True, allow_custom_value=True, scale=4)
train_last_turn_only = gr.Checkbox()
preview_elems = create_preview_box(dataset_dir, dataset)
input_elems.update({training_stage, dataset_dir, dataset,train_last_turn_only})
elem_dict.update(dict(training_stage=training_stage, dataset_dir=dataset_dir, dataset=dataset,train_last_turn_only=train_last_turn_only, **preview_elems))
input_elems.update({training_stage, dataset_dir, dataset})
elem_dict.update(dict(training_stage=training_stage, dataset_dir=dataset_dir, dataset=dataset, **preview_elems))
with gr.Row():
learning_rate = gr.Textbox(value="5e-5")
@@ -99,6 +98,10 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
packing = gr.Checkbox()
neat_packing = gr.Checkbox()
with gr.Column():
train_on_prompt = gr.Checkbox()
mask_history = gr.Checkbox()
with gr.Column():
resize_vocab = gr.Checkbox()
use_llama_pro = gr.Checkbox()
@@ -116,6 +119,8 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
optim,
packing,
neat_packing,
train_on_prompt,
mask_history,
resize_vocab,
use_llama_pro,
shift_attn,
@@ -132,6 +137,8 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
optim=optim,
packing=packing,
neat_packing=neat_packing,
train_on_prompt=train_on_prompt,
mask_history=mask_history,
resize_vocab=resize_vocab,
use_llama_pro=use_llama_pro,
shift_attn=shift_attn,