add template matching and training stage selection in webui

Former-commit-id: 79c68e552722079faf2ab0858870b481844d66ae
codemayq 2023-08-14 20:42:59 +08:00
parent 6c9b035c0e
commit ee7da14f81
6 changed files with 77 additions and 14 deletions

View File

@ -10,6 +10,8 @@ LAYERNORM_NAMES = ["norm", "ln_f", "ln_attn", "ln_mlp"]
METHODS = ["full", "freeze", "lora"] METHODS = ["full", "freeze", "lora"]
STAGES = ["Supervised Finetuning", "Reward Modeling", "PPO", "DPO", "Pretraining"]
SUPPORTED_MODELS = { SUPPORTED_MODELS = {
"LLaMA-7B": "huggyllama/llama-7b", "LLaMA-7B": "huggyllama/llama-7b",
"LLaMA-13B": "huggyllama/llama-13b", "LLaMA-13B": "huggyllama/llama-13b",
@ -54,3 +56,31 @@ DEFAULT_MODULE = {
"XVERSE": "q_proj,v_proj", "XVERSE": "q_proj,v_proj",
"ChatGLM2": "query_key_value" "ChatGLM2": "query_key_value"
} }
DEFAULT_TEMPLATE = {
"LLaMA2": "llama2",
"Baichuan": "baichuan",
"InternLM": "intern",
"Qwen": "chatml",
"ChatGLM2": "chatglm2"
}
# Hugging Face model name prefix -> template
DEFAULT_TEMPLATE_WITH_CUSTOM_MODEL = {
"Llama-2": "llama2",
"chinese-alpaca-2": "llama2_zh",
"alpaca-7b-wdiff": "alpaca",
"vicuna": "vicuna",
"BELLE": "belle",
"Chinese-LLaMA-2": "linly",
"BiLLa": "billa",
"Ziya": "ziya",
"aquilachat": "aquila",
"internlm": "intern",
"aquilachat": "aquila",
"internlm": "intern",
"Baichuan":"baichuan",
"starchat":"starchat",
"Qwen":"chatml",
"chatglm2":"chatglm2"
}
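
As a rough reading aid (not part of the commit), this is how the two tables are meant to be consumed; the sample model names below are assumptions for illustration:

# Built-in models are keyed by the part of the webui model name before the first "-":
DEFAULT_TEMPLATE.get("Qwen-7B-Chat".split("-")[0], "default")  # -> "chatml"
# Custom checkpoints are matched by directory-name prefix instead:
basename = "chinese-alpaca-2-7b"
next((v for k, v in DEFAULT_TEMPLATE_WITH_CUSTOM_MODEL.items() if basename.startswith(k)), "default")  # -> "llama2_zh"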

View File

@ -273,7 +273,8 @@ register_template(
r""" r"""
Supports: https://github.com/ymcui/Chinese-LLaMA-Alpaca-2 Supports: https://huggingface.co/ziqingyang/chinese-alpaca-2-7b
https://github.com/ymcui/Chinese-LLaMA-Alpaca-2
""" """
register_template( register_template(
name="llama2_zh", name="llama2_zh",

View File

@ -6,7 +6,7 @@ import gradio as gr
from peft.utils import WEIGHTS_NAME as PEFT_WEIGHTS_NAME
from transformers.trainer import WEIGHTS_NAME, WEIGHTS_INDEX_NAME
from llmtuner.extras.constants import SUPPORTED_MODELS
from llmtuner.extras.constants import SUPPORTED_MODELS, DEFAULT_TEMPLATE_WITH_CUSTOM_MODEL, DEFAULT_TEMPLATE
DEFAULT_CACHE_DIR = "cache"
@ -48,6 +48,25 @@ def get_model_path(model_name: str) -> str:
return user_config["path_dict"].get(model_name, SUPPORTED_MODELS.get(model_name, ""))
def get_template(
    template: str,
    model_name: str,
) -> str:
    if template and template != "default":
        return template
    if model_name == "Custom":
        model_name_or_path = get_model_path(model_name)
        # get last dir
        basename = os.path.basename(model_name_or_path)
        # prefix match
        for k, v in DEFAULT_TEMPLATE_WITH_CUSTOM_MODEL.items():
            if basename.startswith(k):
                return v
        return "default"
    return DEFAULT_TEMPLATE.get(model_name.split("-")[0], "default")
def list_checkpoint(model_name: str, finetuning_type: str) -> Dict[str, Any]:
    checkpoints = []
    save_dir = os.path.join(get_save_dir(model_name), finetuning_type)
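
For reference, a minimal sketch (not repository output) of how get_template resolves a template; the model names, paths, and results below are illustrative assumptions:

get_template("llama2", "LLaMA2-7B")     # explicit non-default choice wins -> "llama2"
get_template("default", "ChatGLM2-6B")  # known name, prefix before "-" -> "chatglm2"
# For "Custom", the basename of the user-configured path is prefix-matched,
# e.g. a path ending in "chinese-alpaca-2-7b" would resolve to "llama2_zh":
get_template("default", "Custom")
# Anything unrecognized falls back to "default".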

View File

@ -6,6 +6,7 @@ import gradio as gr
from llmtuner.webui.common import list_checkpoint, list_dataset, DEFAULT_DATA_DIR
from llmtuner.webui.components.data import create_preview_box
from llmtuner.webui.utils import can_preview, get_preview, gen_plot
from llmtuner.extras.constants import STAGES
if TYPE_CHECKING:
from gradio.components import Component
@ -14,6 +15,9 @@ if TYPE_CHECKING:
def create_train_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[str, "Component"]:
with gr.Row():
stage = gr.Dropdown(choices=STAGES,
value="Supervised Finetuning", scale=2)
dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=2)
dataset = gr.Dropdown(multiselect=True, scale=4)
data_preview_btn = gr.Button(interactive=False, scale=1)
@ -62,7 +66,6 @@ def create_train_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dic
with gr.Accordion(label="RLHF config", open=False) as rlhf_tab:
with gr.Row():
rlhf_method = gr.Dropdown(choices=["None", "Reward Modeling", "PPO", "DPO"], value="None", scale=1)
dpo_beta = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01, scale=2)
reward_model = gr.Dropdown(scale=2)
refresh_btn = gr.Button(scale=1)
@ -101,6 +104,7 @@ def create_train_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dic
top_elems["quantization_bit"], top_elems["quantization_bit"],
top_elems["template"], top_elems["template"],
top_elems["source_prefix"], top_elems["source_prefix"],
stage,
dataset_dir, dataset_dir,
dataset, dataset,
max_source_length, max_source_length,
@ -122,7 +126,6 @@ def create_train_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dic
lora_dropout,
lora_target,
resume_lora_training,
rlhf_method,
dpo_beta,
reward_model,
output_dir
@ -142,6 +145,7 @@ def create_train_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dic
)
return dict(
stage=stage,
dataset_dir=dataset_dir,
dataset=dataset,
data_preview_btn=data_preview_btn,
@ -170,7 +174,6 @@ def create_train_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dic
lora_target=lora_target,
resume_lora_training=resume_lora_training,
rlhf_tab=rlhf_tab,
rlhf_method=rlhf_method,
dpo_beta=dpo_beta,
reward_model=reward_model,
refresh_btn=refresh_btn,

View File

@ -546,7 +546,15 @@ LOCALES = {
"zh": { "zh": {
"value": "开始导出" "value": "开始导出"
} }
},
"stage": {
"en": {
"label": "train stage"
},
"zh": {
"label": "训练阶段"
} }
},
} }

View File

@ -8,11 +8,11 @@ from transformers.trainer import TRAINING_ARGS_NAME
from typing import Any, Dict, Generator, List, Tuple
from llmtuner.extras.callbacks import LogCallback
from llmtuner.extras.constants import DEFAULT_MODULE
from llmtuner.extras.constants import DEFAULT_MODULE, DEFAULT_TEMPLATE, DEFAULT_TEMPLATE_WITH_CUSTOM_MODEL
from llmtuner.extras.logging import LoggerHandler
from llmtuner.extras.misc import torch_gc
from llmtuner.tuner import run_exp
from llmtuner.webui.common import get_model_path, get_save_dir
from llmtuner.webui.common import get_model_path, get_save_dir, get_template
from llmtuner.webui.locales import ALERTS
from llmtuner.webui.utils import gen_cmd, get_eval_results, update_process_bar
@ -70,6 +70,7 @@ class Runner:
quantization_bit: str,
template: str,
source_prefix: str,
stage: str,
dataset_dir: str,
dataset: List[str],
max_source_length: int,
@ -91,7 +92,6 @@ class Runner:
lora_dropout: float,
lora_target: str,
resume_lora_training: bool,
rlhf_method: str,
dpo_beta: float,
reward_model: str,
output_dir: str
@ -113,7 +113,7 @@ class Runner:
checkpoint_dir=checkpoint_dir,
finetuning_type=finetuning_type,
quantization_bit=int(quantization_bit) if quantization_bit != "None" else None,
template=template,
template=get_template(template, model_name),
source_prefix=source_prefix,
dataset_dir=dataset_dir,
dataset=",".join(dataset),
@ -138,16 +138,18 @@ class Runner:
)
args[compute_type] = True
if rlhf_method == "Reward Modeling":
if stage == "Pretraining":
args["stage"] = "pt"
if stage == "Reward Modeling":
args["stage"] = "rm"
args["resume_lora_training"] = False
elif rlhf_method == "PPO":
elif stage == "PPO":
args["stage"] = "ppo"
args["resume_lora_training"] = False
args["reward_model"] = reward_model
args["padding_side"] = "left"
val_size = 0
elif rlhf_method == "DPO":
elif stage == "DPO":
args["stage"] = "dpo"
args["resume_lora_training"] = False
args["dpo_beta"] = dpo_beta
@ -195,7 +197,7 @@ class Runner:
checkpoint_dir=checkpoint_dir,
finetuning_type=finetuning_type,
quantization_bit=int(quantization_bit) if quantization_bit != "None" else None,
template=template,
template=get_template(template, model_name),
source_prefix=source_prefix,
dataset_dir=dataset_dir,
dataset=",".join(dataset),