allow non-packing pretraining

This commit is contained in:
hiyouga
2024-03-09 22:21:46 +08:00
parent 412c52e325
commit bdb496644c
22 changed files with 64 additions and 67 deletions

View File

@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Dict, Optional, Tuple
from typing import TYPE_CHECKING, Dict, Tuple
import gradio as gr
@@ -14,7 +14,7 @@ if TYPE_CHECKING:
def create_chat_box(
engine: "Engine", visible: Optional[bool] = False
engine: "Engine", visible: bool = False
) -> Tuple["Block", "Component", "Component", Dict[str, "Component"]]:
with gr.Box(visible=visible) as chat_box:
chatbot = gr.Chatbot()

View File

@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Dict
from typing import TYPE_CHECKING, Dict, Tuple
import gradio as gr
@@ -12,7 +12,7 @@ if TYPE_CHECKING:
from gradio.components import Component
def create_top() -> Dict[str, "Component"]:
def create_top() -> Tuple["gr.Dropdown", Dict[str, "Component"]]:
available_models = list(SUPPORTED_MODELS.keys()) + ["Custom"]
with gr.Row():
@@ -44,7 +44,7 @@ def create_top() -> Dict[str, "Component"]:
refresh_btn.click(list_adapters, [model_name, finetuning_type], [adapter_path], queue=False)
return dict(
return lang, dict(
lang=lang,
model_name=model_name,
model_path=model_path,

View File

@@ -4,7 +4,7 @@ import gradio as gr
from transformers.trainer_utils import SchedulerType
from ...extras.constants import TRAINING_STAGES
from ..common import DEFAULT_DATA_DIR, list_adapters, list_dataset
from ..common import DEFAULT_DATA_DIR, autoset_packing, list_adapters, list_dataset
from ..components.data import create_preview_box
from ..utils import gen_plot
@@ -78,7 +78,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
with gr.Row():
resize_vocab = gr.Checkbox()
sft_packing = gr.Checkbox()
packing = gr.Checkbox()
upcast_layernorm = gr.Checkbox()
use_llama_pro = gr.Checkbox()
shift_attn = gr.Checkbox()
@@ -91,7 +91,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
neftune_alpha,
optim,
resize_vocab,
sft_packing,
packing,
upcast_layernorm,
use_llama_pro,
shift_attn,
@@ -106,7 +106,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
neftune_alpha=neftune_alpha,
optim=optim,
resize_vocab=resize_vocab,
sft_packing=sft_packing,
packing=packing,
upcast_layernorm=upcast_layernorm,
use_llama_pro=use_llama_pro,
shift_attn=shift_attn,
@@ -166,7 +166,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
[engine.manager.get_elem_by_name("top.model_name"), engine.manager.get_elem_by_name("top.finetuning_type")],
[reward_model],
queue=False,
)
).then(autoset_packing, [training_stage], [packing], queue=False)
input_elems.update({dpo_beta, dpo_ftx, reward_model})
elem_dict.update(dict(rlhf_tab=rlhf_tab, dpo_beta=dpo_beta, dpo_ftx=dpo_ftx, reward_model=reward_model))