Mirror of https://github.com/hiyouga/LLaMA-Factory.git, synced 2025-08-03 20:22:49 +08:00
update webui
Former-commit-id: 9d0f6214b68a653c0a67632437b227ab8f589bed
This commit is contained in:
  parent d021d31a9c
  commit 02a61b08b1

@@ -10,7 +10,13 @@ LAYERNORM_NAMES = ["norm", "ln_f", "ln_attn", "ln_mlp"]
 
 METHODS = ["full", "freeze", "lora"]
 
-STAGES = ["Supervised Finetuning", "Reward Modeling", "PPO", "DPO", "Pretraining"]
+STAGES = [
+    "SFT",
+    "Reward Modeling",
+    "PPO",
+    "DPO",
+    "Pre-Training"
+]
 
 SUPPORTED_MODELS = {
     "LLaMA-7B": "huggyllama/llama-7b",

@@ -23,6 +29,10 @@ SUPPORTED_MODELS = {
     "LLaMA2-7B-Chat": "meta-llama/Llama-2-7b-chat-hf",
     "LLaMA2-13B-Chat": "meta-llama/Llama-2-13b-chat-hf",
     "LLaMA2-70B-Chat": "meta-llama/Llama-2-70b-chat-hf",
+    "ChineseLLaMA2-7B": "ziqingyang/chinese-llama-2-7b",
+    "ChineseLLaMA2-13B": "ziqingyang/chinese-llama-2-13b",
+    "ChineseLLaMA2-7B-Chat": "ziqingyang/chinese-alpaca-2-7b",
+    "ChineseLLaMA2-13B-Chat": "ziqingyang/chinese-alpaca-2-13b",
     "BLOOM-560M": "bigscience/bloom-560m",
     "BLOOM-3B": "bigscience/bloom-3b",
     "BLOOM-7B1": "bigscience/bloom-7b1",

@@ -41,12 +51,13 @@ SUPPORTED_MODELS = {
     "Qwen-7B": "Qwen/Qwen-7B",
     "Qwen-7B-Chat": "Qwen/Qwen-7B-Chat",
     "XVERSE-13B": "xverse/XVERSE-13B",
-    "ChatGLM2-6B": "THUDM/chatglm2-6b"
+    "ChatGLM2-6B-Chat": "THUDM/chatglm2-6b"
 }
 
 DEFAULT_MODULE = {
     "LLaMA": "q_proj,v_proj",
     "LLaMA2": "q_proj,v_proj",
+    "ChineseLLaMA2": "q_proj,v_proj",
     "BLOOM": "query_key_value",
     "BLOOMZ": "query_key_value",
     "Falcon": "query_key_value",

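Note: DEFAULT_MODULE values are comma-separated module names used as LoRA injection targets. A minimal sketch, assuming such an entry feeds peft's LoraConfig; the splitting and the hyperparameters here are illustrative, not this repo's exact wiring:

from peft import LoraConfig

lora_target = "q_proj,v_proj"  # e.g. DEFAULT_MODULE["ChineseLLaMA2"]
lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=8,            # illustrative rank
    lora_alpha=32,  # illustrative scaling factor
    target_modules=[name.strip() for name in lora_target.split(",")],
)
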
@@ -59,28 +70,9 @@ DEFAULT_MODULE = {
 
 DEFAULT_TEMPLATE = {
     "LLaMA2": "llama2",
+    "ChineseLLaMA2": "llama2_zh",
     "Baichuan": "baichuan",
     "InternLM": "intern",
     "Qwen": "chatml",
     "ChatGLM2": "chatglm2"
 }
-
-# huggingface model name prefix 2 template
-DEFAULT_TEMPLATE_WITH_CUSTOM_MODEL = {
-    "Llama-2": "llama2",
-    "chinese-alpaca-2": "llama2_zh",
-    "alpaca-7b-wdiff": "alpaca",
-    "vicuna": "vicuna",
-    "BELLE": "belle",
-    "Chinese-LLaMA-2": "linly",
-    "BiLLa": "billa",
-    "Ziya": "ziya",
-    "aquilachat": "aquila",
-    "internlm": "intern",
-    "aquilachat": "aquila",
-    "internlm": "intern",
-    "Baichuan":"baichuan",
-    "starchat":"starchat",
-    "Qwen":"chatml",
-    "chatglm2":"chatglm2"
-}

@@ -95,7 +95,6 @@ def prepare_model_for_training(
     use_gradient_checkpointing: Optional[bool] = True,
     layer_norm_names: Optional[List[str]] = LAYERNORM_NAMES
 ) -> "PreTrainedModel":
-
     for name, param in model.named_parameters():
         if param.ndim == 1 and any(layer_norm_name in name for layer_norm_name in layer_norm_names):
             param.data = param.data.to(torch.float32)

@@ -112,9 +111,6 @@ def prepare_model_for_training(
         model.config.use_cache = False # turn off when gradient checkpointing is enabled
 
     if finetuning_type != "full" and hasattr(model, output_layer_name):
-        if hasattr(model, "config") and hasattr(model.config, "pretraining_tp"):
-            model.config.pretraining_tp = 1 # disable TP for LoRA (https://github.com/huggingface/peft/pull/728)
-
         output_layer: torch.nn.Linear = getattr(model, output_layer_name)
         input_dtype = output_layer.weight.dtype
 

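Note: a hedged usage sketch of prepare_model_for_training as it appears above: it upcasts 1-D norm parameters to float32 and, after this commit, no longer resets config.pretraining_tp. The finetuning_type keyword is inferred from the body of the hunk; the model choice is illustrative.

import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "huggyllama/llama-7b", torch_dtype=torch.float16
)
# norm layers are cast to fp32 for training stability before any
# PEFT adapters are attached (assumed call pattern)
model = prepare_model_for_training(model, finetuning_type="lora")
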
@@ -273,8 +273,8 @@ register_template(
 
 
 r"""
-Supports: https://huggingface.co/ziqingyang/chinese-alpaca-2-7b
-          https://github.com/ymcui/Chinese-LLaMA-Alpaca-2
+Supports: https://github.com/ymcui/Chinese-LLaMA-Alpaca-2
+          https://huggingface.co/ziqingyang/chinese-alpaca-2-7b
 """
 register_template(
     name="llama2_zh",

@@ -6,7 +6,7 @@ import gradio as gr
 from peft.utils import WEIGHTS_NAME as PEFT_WEIGHTS_NAME
 from transformers.trainer import WEIGHTS_NAME, WEIGHTS_INDEX_NAME
 
-from llmtuner.extras.constants import SUPPORTED_MODELS, DEFAULT_TEMPLATE_WITH_CUSTOM_MODEL, DEFAULT_TEMPLATE
+from llmtuner.extras.constants import DEFAULT_TEMPLATE, SUPPORTED_MODELS
 
 
 DEFAULT_CACHE_DIR = "cache"

@ -48,20 +48,10 @@ def get_model_path(model_name: str) -> str:
|
|||||||
return user_config["path_dict"].get(model_name, SUPPORTED_MODELS.get(model_name, ""))
|
return user_config["path_dict"].get(model_name, SUPPORTED_MODELS.get(model_name, ""))
|
||||||
|
|
||||||
|
|
||||||
def get_template(
|
def get_template(model_name: str) -> str:
|
||||||
model_name: str,
|
if model_name.endswith("Chat") and model_name.split("-")[0] in DEFAULT_TEMPLATE:
|
||||||
) -> str:
|
return DEFAULT_TEMPLATE[model_name.split("-")[0]]
|
||||||
if model_name == "Custom":
|
return "default"
|
||||||
model_name_or_path = get_model_path(model_name)
|
|
||||||
# get last dir
|
|
||||||
basename = os.path.basename(model_name_or_path)
|
|
||||||
# prefix match
|
|
||||||
for k, v in DEFAULT_TEMPLATE_WITH_CUSTOM_MODEL.items():
|
|
||||||
if basename.startswith(k):
|
|
||||||
return v
|
|
||||||
return "default"
|
|
||||||
|
|
||||||
return DEFAULT_TEMPLATE.get(model_name.split("-")[0], "default")
|
|
||||||
|
|
||||||
|
|
||||||
def list_checkpoint(model_name: str, finetuning_type: str) -> Dict[str, Any]:
|
def list_checkpoint(model_name: str, finetuning_type: str) -> Dict[str, Any]:
|
||||||
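Note: the rewritten get_template resolves the template purely from the display name: a model ending in "Chat" whose prefix is registered in DEFAULT_TEMPLATE gets that template, and everything else falls back to "default". Illustrative expectations, given the DEFAULT_TEMPLATE entries above:

assert get_template("LLaMA2-7B-Chat") == "llama2"            # prefix "LLaMA2" is registered
assert get_template("ChineseLLaMA2-7B-Chat") == "llama2_zh"  # entry added in this commit
assert get_template("Qwen-7B-Chat") == "chatml"
assert get_template("LLaMA2-7B") == "default"                # base model, no "Chat" suffix
assert get_template("Custom") == "default"                   # unknown name
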
@@ -4,7 +4,7 @@ import gradio as gr
 
 from llmtuner.extras.constants import METHODS, SUPPORTED_MODELS
 from llmtuner.extras.template import templates
-from llmtuner.webui.common import list_checkpoint, get_model_path, save_config, get_template
+from llmtuner.webui.common import list_checkpoint, get_model_path, get_template, save_config
 from llmtuner.webui.utils import can_quantize
 
 if TYPE_CHECKING:

@@ -36,10 +36,11 @@ def create_top() -> Dict[str, "Component"]:
         list_checkpoint, [model_name, finetuning_type], [checkpoints]
     ).then(
         get_model_path, [model_name], [model_path]
+    ).then(
+        get_template, [model_name], [template]
     ) # do not save config since the below line will save
 
     model_path.change(save_config, [lang, model_name, model_path])
-    model_path.change(get_template, [model_name], [template])
 
     finetuning_type.change(
         list_checkpoint, [model_name, finetuning_type], [checkpoints]

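Note: get_template now runs inside the same event chain as get_model_path instead of a separate model_path.change listener, so the template update is ordered after the path update. A self-contained sketch of Gradio's .then() chaining (placeholder components, Gradio 3.x API):

import gradio as gr

def inc(x): return x + 1
def dbl(x): return x * 2

with gr.Blocks() as demo:
    n = gr.Number(value=0)
    a = gr.Number(label="after inc")
    b = gr.Number(label="after dbl")
    # .then() queues the second callback after the first finishes,
    # guaranteeing a deterministic update order for dependent components
    n.change(inc, [n], [a]).then(dbl, [a], [b])
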
@@ -3,10 +3,10 @@ from transformers.trainer_utils import SchedulerType
 
 import gradio as gr
 
+from llmtuner.extras.constants import STAGES
 from llmtuner.webui.common import list_checkpoint, list_dataset, DEFAULT_DATA_DIR
 from llmtuner.webui.components.data import create_preview_box
 from llmtuner.webui.utils import can_preview, get_preview, gen_plot
-from llmtuner.extras.constants import STAGES
 
 if TYPE_CHECKING:
     from gradio.components import Component

@@ -15,9 +15,7 @@ if TYPE_CHECKING:
 
 def create_train_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[str, "Component"]:
     with gr.Row():
-        stage = gr.Dropdown(choices=STAGES,
-                            value="Supervised Finetuning", scale=2)
-
+        training_stage = gr.Dropdown(choices=STAGES, value=STAGES[0], scale=2)
         dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=2)
         dataset = gr.Dropdown(multiselect=True, scale=4)
         data_preview_btn = gr.Button(interactive=False, scale=1)

@@ -104,7 +102,7 @@ def create_train_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[str, "Component"]:
                 top_elems["quantization_bit"],
                 top_elems["template"],
                 top_elems["source_prefix"],
-                stage,
+                training_stage,
                 dataset_dir,
                 dataset,
                 max_source_length,

@@ -145,7 +143,7 @@ def create_train_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[str, "Component"]:
     )
 
     return dict(
-        stage=stage,
+        training_stage=training_stage,
         dataset_dir=dataset_dir,
         dataset=dataset,
         data_preview_btn=data_preview_btn,

@@ -87,6 +87,16 @@ LOCALES = {
             "info": "默认使用的系统提示词"
         }
     },
+    "training_stage": {
+        "en": {
+            "label": "Stage",
+            "info": "The stage to perform in training."
+        },
+        "zh": {
+            "label": "训练阶段",
+            "info": "目前采用的训练方式。"
+        }
+    },
     "dataset_dir": {
         "en": {
             "label": "Data dir",

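Note: LOCALES entries are keyed first by component name, then by language code, with gr.update-style kwargs as values. A hedged sketch of the lookup; the helper below is illustrative, not the repo's actual code:

def localize(locales: dict, component: str, lang: str) -> dict:
    # e.g. {"label": "Stage", "info": "The stage to perform in training."}
    return locales.get(component, {}).get(lang, {})

label = localize(LOCALES, "training_stage", "zh")["label"]  # "训练阶段"
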
@@ -343,16 +353,6 @@ LOCALES = {
             "label": "RLHF 参数设置"
         }
     },
-    "rlhf_method": {
-        "en": {
-            "label": "RLHF method",
-            "info": "The RLHF algorithm to adopt."
-        },
-        "zh": {
-            "label": "RLHF 方法",
-            "info": "RLHF 阶段使用的算法。"
-        }
-    },
     "dpo_beta": {
         "en": {
             "label": "DPO beta",

@@ -546,15 +546,7 @@ LOCALES = {
         "zh": {
             "value": "开始导出"
         }
-    },
-    "stage": {
-        "en": {
-            "label": "train stage"
-        },
-        "zh": {
-            "label": "训练阶段"
-        }
-    },
+    }
 }
 
 

@@ -8,11 +8,11 @@ from transformers.trainer import TRAINING_ARGS_NAME
 from typing import Any, Dict, Generator, List, Tuple
 
 from llmtuner.extras.callbacks import LogCallback
-from llmtuner.extras.constants import DEFAULT_MODULE, DEFAULT_TEMPLATE, DEFAULT_TEMPLATE_WITH_CUSTOM_MODEL
+from llmtuner.extras.constants import DEFAULT_MODULE
 from llmtuner.extras.logging import LoggerHandler
 from llmtuner.extras.misc import torch_gc
 from llmtuner.tuner import run_exp
-from llmtuner.webui.common import get_model_path, get_save_dir, get_template
+from llmtuner.webui.common import get_model_path, get_save_dir
 from llmtuner.webui.locales import ALERTS
 from llmtuner.webui.utils import gen_cmd, get_eval_results, update_process_bar
 

@@ -70,7 +70,7 @@ class Runner:
         quantization_bit: str,
         template: str,
         source_prefix: str,
-        stage: str,
+        training_stage: str,
         dataset_dir: str,
         dataset: List[str],
         max_source_length: int,

@@ -138,21 +138,21 @@ class Runner:
         )
         args[compute_type] = True
 
-        if stage == "Pretraining":
-            args["stage"] = "pt"
-        if stage == "Reward Modeling":
+        if training_stage == "Reward Modeling":
             args["stage"] = "rm"
             args["resume_lora_training"] = False
-        elif stage == "PPO":
+        elif training_stage == "PPO":
             args["stage"] = "ppo"
             args["resume_lora_training"] = False
             args["reward_model"] = reward_model
             args["padding_side"] = "left"
             val_size = 0
-        elif stage == "DPO":
+        elif training_stage == "DPO":
             args["stage"] = "dpo"
             args["resume_lora_training"] = False
             args["dpo_beta"] = dpo_beta
+        elif training_stage == "Pre-Training":
+            args["stage"] = "pt"
 
         if val_size > 1e-6:
             args["val_size"] = val_size

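Note: taken together, the branch above maps the WebUI stage labels onto the CLI --stage values. Summarized as a table; the "SFT" -> "sft" default is an assumption, since that branch falls outside this hunk:

STAGE_ARG = {
    "SFT": "sft",              # assumed default branch, not shown here
    "Reward Modeling": "rm",   # also disables resume_lora_training
    "PPO": "ppo",              # adds reward_model, left padding, val_size = 0
    "DPO": "dpo",              # adds dpo_beta
    "Pre-Training": "pt",
}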