mirror of https://github.com/hiyouga/LLaMA-Factory.git
support lora for llama pro
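This commit exposes freeze tuning and LLaMA Pro options in the web UI; per the new locale strings below, LLaMA Pro trains only the parameters of the expanded blocks. As rough intuition for what num_layer_trainable and name_module_trainable control downstream, here is a hypothetical sketch (illustrative only, not the repository's actual selection logic; parameter names are assumptions):

import re

# Hypothetical sketch: keep only the last N layers trainable, optionally
# restricted to named modules -- the intent behind num_layer_trainable
# and name_module_trainable. Expects (name, tensor) pairs as produced by
# torch's Module.named_parameters().
def mark_trainable(named_parameters, num_layer_trainable=3, num_layers=32, name_module_trainable="all"):
    trainable_ids = set(range(num_layers - num_layer_trainable, num_layers))
    for name, param in named_parameters:
        match = re.search(r"\.(\d+)\.", name)  # layer index, e.g. "layers.31."
        in_layer = match is not None and int(match.group(1)) in trainable_ids
        in_module = name_module_trainable == "all" or any(
            module in name for module in name_module_trainable.split(",")
        )
        param.requires_grad_(in_layer and in_module)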
@@ -108,6 +108,18 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
                 )
             )

+        with gr.Accordion(label="Freeze config", open=False) as freeze_tab:
+            with gr.Row():
+                num_layer_trainable = gr.Slider(value=3, minimum=1, maximum=128, step=1, scale=2)
+                name_module_trainable = gr.Textbox(scale=3)
+
+        input_elems.update({num_layer_trainable, name_module_trainable})
+        elem_dict.update(
+            dict(
+                freeze_tab=freeze_tab, num_layer_trainable=num_layer_trainable, name_module_trainable=name_module_trainable
+            )
+        )
+
         with gr.Accordion(label="LoRA config", open=False) as lora_tab:
             with gr.Row():
                 lora_rank = gr.Slider(value=8, minimum=1, maximum=1024, step=1)
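The new accordion simply groups two plain Gradio inputs that the Runner later reads back by component key. A minimal stand-alone sketch of the same pattern (hypothetical demo code, not part of the repository):

import gradio as gr

# Hypothetical demo: an accordion holding freeze-tuning inputs,
# echoed back on click the way the Runner collects them by key.
with gr.Blocks() as demo:
    with gr.Accordion(label="Freeze config", open=False):
        with gr.Row():
            num_layer_trainable = gr.Slider(value=3, minimum=1, maximum=128, step=1, scale=2)
            name_module_trainable = gr.Textbox(scale=3)
    preview = gr.Textbox(label="Collected arguments")
    gr.Button("Preview").click(
        lambda n, m: f"num_layer_trainable={int(n)}, name_module_trainable={m!r}",
        inputs=[num_layer_trainable, name_module_trainable],
        outputs=preview,
    )

if __name__ == "__main__":
    demo.launch()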
@@ -508,6 +508,45 @@ LOCALES = {
             "info": "仅训练块扩展后的参数。",
         },
     },
+    "freeze_tab": {
+        "en": {
+            "label": "Freeze tuning configurations",
+        },
+        "ru": {
+            "label": "конфигурации для настройки заморозки",
+        },
+        "zh": {
+            "label": "部分参数微调设置",
+        },
+    },
+    "num_layer_trainable": {
+        "en": {
+            "label": "Trainable layers",
+            "info": "The number of trainable layers.",
+        },
+        "ru": {
+            "label": "Обучаемые слои",
+            "info": "Количество обучаемых слоев.",
+        },
+        "zh": {
+            "label": "可训练层数",
+            "info": "可训练模型层的数量。",
+        },
+    },
+    "name_module_trainable": {
+        "en": {
+            "label": "Trainable modules",
+            "info": "The name of trainable modules. Use commas to separate multiple modules.",
+        },
+        "ru": {
+            "label": "Обучаемые модули",
+            "info": "Название обучаемых модулей. Используйте запятые для разделения нескольких модулей.",
+        },
+        "zh": {
+            "label": "可训练模块",
+            "info": "可训练模块的名称。使用英文逗号分隔多个名称。",
+        },
+    },
     "lora_tab": {
         "en": {
             "label": "LoRA configurations",
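The three new LOCALES entries follow the file's existing component -> language -> attribute layout. A hypothetical sketch of how such a table can be consumed (the repository's actual lookup helper may differ; get_text is an assumed name):

# Hypothetical sketch: look up a component's display text for the
# selected language, falling back to English when an entry is missing.
LOCALES = {
    "num_layer_trainable": {
        "en": {"label": "Trainable layers", "info": "The number of trainable layers."},
        "zh": {"label": "可训练层数", "info": "可训练模型层的数量。"},
    },
}

def get_text(elem_id: str, lang: str, attr: str = "label") -> str:
    locale = LOCALES.get(elem_id, {})
    return locale.get(lang, locale.get("en", {})).get(attr, elem_id)

assert get_text("num_layer_trainable", "zh") == "可训练层数"
assert get_text("num_layer_trainable", "en", "info") == "The number of trainable layers."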
@@ -129,26 +129,34 @@ class Runner:
             sft_packing=get("train.sft_packing"),
             upcast_layernorm=get("train.upcast_layernorm"),
+            use_llama_pro=get("train.use_llama_pro"),
-            lora_rank=get("train.lora_rank"),
-            lora_dropout=get("train.lora_dropout"),
-            lora_target=get("train.lora_target") or get_module(get("top.model_name")),
-            additional_target=get("train.additional_target") or None,
-            use_rslora=get("train.use_rslora"),
-            create_new_adapter=get("train.create_new_adapter"),
             output_dir=get_save_dir(get("top.model_name"), get("top.finetuning_type"), get("train.output_dir")),
             fp16=(get("train.compute_type") == "fp16"),
             bf16=(get("train.compute_type") == "bf16"),
         )
         args["disable_tqdm"] = True

-        if TRAINING_STAGES[get("train.training_stage")] in ["rm", "ppo", "dpo"]:
-            args["create_new_adapter"] = args["quantization_bit"] is None
+        if args["finetuning_type"] == "freeze":
+            args["num_layer_trainable"] = int(get("train.num_layer_trainable"))
+            args["name_module_trainable"] = get("train.name_module_trainable")
+        elif args["finetuning_type"] == "lora":
+            args["lora_rank"] = get("train.lora_rank")
+            args["lora_dropout"] = get("train.lora_dropout")
+            args["lora_target"] = get("train.lora_target") or get_module(get("top.model_name"))
+            args["additional_target"] = get("train.additional_target") or None
+            args["use_rslora"] = get("train.use_rslora")
+            if TRAINING_STAGES[get("train.training_stage")] in ["rm", "ppo", "dpo"]:
+                args["create_new_adapter"] = args["quantization_bit"] is None
+            else:
+                args["create_new_adapter"] = get("train.create_new_adapter")

+        if args["use_llama_pro"]:
+            args["num_layer_trainable"] = int(get("train.num_layer_trainable"))

         if args["stage"] == "ppo":
             args["reward_model"] = get_save_dir(
                 get("top.model_name"), get("top.finetuning_type"), get("train.reward_model")
             )
-            args["reward_model_type"] = "lora" if get("top.finetuning_type") == "lora" else "full"
+            args["reward_model_type"] = "lora" if args["finetuning_type"] == "lora" else "full"

         if args["stage"] == "dpo":
             args["dpo_beta"] = get("train.dpo_beta")
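A hypothetical condensed replay of the new branching, with a plain dict get() standing in for the web UI component lookup; it shows that freeze tuning and LLaMA Pro both consume num_layer_trainable, while LoRA options are only collected when finetuning_type is "lora" (build_args and the demo dict are illustrative, not project code):

# Hypothetical sketch of the Runner's new argument assembly.
def build_args(get):
    args = dict(
        finetuning_type=get("top.finetuning_type"),
        quantization_bit=get("top.quantization_bit"),
        use_llama_pro=get("train.use_llama_pro"),
    )
    if args["finetuning_type"] == "freeze":
        args["num_layer_trainable"] = int(get("train.num_layer_trainable"))
        args["name_module_trainable"] = get("train.name_module_trainable")
    elif args["finetuning_type"] == "lora":
        args["lora_rank"] = get("train.lora_rank")
        # A fresh adapter is forced for quantized models in rm/ppo/dpo stages.
        if get("train.training_stage") in ("rm", "ppo", "dpo"):
            args["create_new_adapter"] = args["quantization_bit"] is None
        else:
            args["create_new_adapter"] = get("train.create_new_adapter")
    # LLaMA Pro trains only the expanded blocks, so the trainable-layer
    # count applies on top of whichever branch ran above.
    if args["use_llama_pro"]:
        args["num_layer_trainable"] = int(get("train.num_layer_trainable"))
    return args

demo = {
    "top.finetuning_type": "lora", "top.quantization_bit": None,
    "train.use_llama_pro": True, "train.num_layer_trainable": 8,
    "train.lora_rank": 8, "train.training_stage": "sft",
    "train.create_new_adapter": False,
}
print(build_args(demo.get))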