support lora for llama pro

This commit is contained in:
hiyouga
2024-02-21 02:17:22 +08:00
parent 02c8c55ce3
commit 9aeb404a94
7 changed files with 119 additions and 28 deletions

View File

@@ -129,26 +129,34 @@ class Runner:
sft_packing=get("train.sft_packing"),
upcast_layernorm=get("train.upcast_layernorm"),
use_llama_pro=get("train.use_llama_pro"),
lora_rank=get("train.lora_rank"),
lora_dropout=get("train.lora_dropout"),
lora_target=get("train.lora_target") or get_module(get("top.model_name")),
additional_target=get("train.additional_target") or None,
use_rslora=get("train.use_rslora"),
create_new_adapter=get("train.create_new_adapter"),
output_dir=get_save_dir(get("top.model_name"), get("top.finetuning_type"), get("train.output_dir")),
fp16=(get("train.compute_type") == "fp16"),
bf16=(get("train.compute_type") == "bf16"),
)
args["disable_tqdm"] = True
if TRAINING_STAGES[get("train.training_stage")] in ["rm", "ppo", "dpo"]:
args["create_new_adapter"] = args["quantization_bit"] is None
if args["finetuning_type"] == "freeze":
args["num_layer_trainable"] = int(get("train.num_layer_trainable"))
args["name_module_trainable"] = get("train.name_module_trainable")
elif args["finetuning_type"] == "lora":
args["lora_rank"] = get("train.lora_rank")
args["lora_dropout"] = get("train.lora_dropout")
args["lora_target"] = get("train.lora_target") or get_module(get("top.model_name"))
args["additional_target"] = get("train.additional_target") or None
args["use_rslora"] = get("train.use_rslora")
if TRAINING_STAGES[get("train.training_stage")] in ["rm", "ppo", "dpo"]:
args["create_new_adapter"] = args["quantization_bit"] is None
else:
args["create_new_adapter"] = get("train.create_new_adapter")
if args["use_llama_pro"]:
args["num_layer_trainable"] = int(get("train.num_layer_trainable"))
if args["stage"] == "ppo":
args["reward_model"] = get_save_dir(
get("top.model_name"), get("top.finetuning_type"), get("train.reward_model")
)
args["reward_model_type"] = "lora" if get("top.finetuning_type") == "lora" else "full"
args["reward_model_type"] = "lora" if args["finetuning_type"] == "lora" else "full"
if args["stage"] == "dpo":
args["dpo_beta"] = get("train.dpo_beta")