refactor finetuning Args

Former-commit-id: 620efe1d8d2429f4bc3fa8009900ec43e1b5ef4b
hiyouga 2023-09-27 22:28:06 +08:00
parent 650a2a2e01
commit 927ff702ff
2 changed files with 18 additions and 22 deletions


@@ -12,18 +12,6 @@ class FinetuningArguments:
         default="lora",
         metadata={"help": "Which fine-tuning method to use."}
     )
-    num_hidden_layers: Optional[int] = field(
-        default=32,
-        metadata={"help": "Number of decoder blocks in the model for partial-parameter (freeze) fine-tuning. \
-                  LLaMA choices: [\"32\", \"40\", \"60\", \"80\"], \
-                  LLaMA-2 choices: [\"32\", \"40\", \"80\"], \
-                  BLOOM choices: [\"24\", \"30\", \"70\"], \
-                  Falcon choices: [\"32\", \"60\"], \
-                  Baichuan choices: [\"32\", \"40\"] \
-                  Qwen choices: [\"32\"], \
-                  XVERSE choices: [\"40\"], \
-                  ChatGLM2 choices: [\"28\"]"}
-    )
     num_layer_trainable: Optional[int] = field(
         default=3,
         metadata={"help": "Number of trainable layers for partial-parameter (freeze) fine-tuning."}
@@ -33,9 +21,9 @@ class FinetuningArguments:
         metadata={"help": "Name of trainable modules for partial-parameter (freeze) fine-tuning. \
                   LLaMA choices: [\"mlp\", \"self_attn\"], \
                   BLOOM & Falcon & ChatGLM2 choices: [\"mlp\", \"self_attention\"], \
-                  Baichuan choices: [\"mlp\", \"self_attn\"], \
                   Qwen choices: [\"mlp\", \"attn\"], \
-                  LLaMA-2, InternLM, XVERSE choices: the same as LLaMA."}
+                  Phi-1.5 choices: [\"mlp\", \"mixer\"], \
+                  LLaMA-2, Baichuan, InternLM, XVERSE choices: the same as LLaMA."}
     )
     lora_rank: Optional[int] = field(
         default=8,
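These fields only become command-line flags because the repository parses its dataclasses with HfArgumentParser, which also turns the metadata help strings into --help text. A sketch under that assumption, using a stripped-down stand-in rather than the real FinetuningArguments:

from dataclasses import dataclass, field
from typing import Optional

from transformers import HfArgumentParser

@dataclass
class FreezeArguments:  # hypothetical subset of FinetuningArguments
    num_layer_trainable: Optional[int] = field(
        default=3,
        metadata={"help": "Number of trainable layers for partial-parameter (freeze) fine-tuning."}
    )
    name_module_trainable: Optional[str] = field(
        default="mlp",
        metadata={"help": "Name of trainable modules for partial-parameter (freeze) fine-tuning."}
    )

parser = HfArgumentParser(FreezeArguments)
(args,) = parser.parse_args_into_dataclasses(["--num_layer_trainable", "2", "--name_module_trainable", "self_attn"])
print(args.num_layer_trainable, args.name_module_trainable)  # 2 self_attn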
@@ -56,8 +44,13 @@
                   BLOOM & Falcon & ChatGLM2 choices: [\"query_key_value\", \"self_attention.dense\", \"mlp.dense\"], \
                   Baichuan choices: [\"W_pack\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \
                   Qwen choices: [\"c_attn\", \"attn.c_proj\", \"w1\", \"w2\", \"mlp.c_proj\"], \
+                  Phi-1.5 choices: [\"Wqkv\", \"out_proj\", \"fc1\", \"fc2\"], \
                   LLaMA-2, InternLM, XVERSE choices: the same as LLaMA."}
     )
+    additional_target: Optional[str] = field(
+        default=None,
+        metadata={"help": "Name(s) of modules apart from LoRA layers to be set as trainable and saved in the final checkpoint."}
+    )
     resume_lora_training: Optional[bool] = field(
         default=True,
         metadata={"help": "Whether to resume training from the last LoRA weights or create new weights after merging them."}
@@ -75,12 +68,8 @@ class FinetuningArguments:
         if isinstance(self.lora_target, str): # support custom target modules/layers of LoRA
             self.lora_target = [target.strip() for target in self.lora_target.split(",")]
-        if self.num_layer_trainable > 0: # fine-tuning the last n layers if num_layer_trainable > 0
-            trainable_layer_ids = [self.num_hidden_layers - k - 1 for k in range(self.num_layer_trainable)]
-        else: # fine-tuning the first n layers if num_layer_trainable < 0
-            trainable_layer_ids = [k for k in range(-self.num_layer_trainable)]
-        self.trainable_layers = ["{:d}.{}".format(idx, self.name_module_trainable) for idx in trainable_layer_ids]
+        if isinstance(self.additional_target, str):
+            self.additional_target = [target.strip() for target in self.additional_target.split(",")]
         assert self.finetuning_type in ["lora", "freeze", "full", "none"], "Invalid fine-tuning method."
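With the layer bookkeeping moved into init_adapter (second file below), __post_init__ now only normalizes the comma-separated module strings, lora_target and the new additional_target. A self-contained sketch of that normalization, using an illustrative stand-in dataclass with the same split-and-strip idiom:

from dataclasses import dataclass
from typing import List, Optional, Union

@dataclass
class LoraTargets:  # illustrative stand-in for the relevant FinetuningArguments fields
    lora_target: Union[str, List[str]] = "q_proj,v_proj"
    additional_target: Optional[Union[str, List[str]]] = None

    def __post_init__(self):
        if isinstance(self.lora_target, str):  # support custom target modules of LoRA
            self.lora_target = [target.strip() for target in self.lora_target.split(",")]
        if isinstance(self.additional_target, str):
            self.additional_target = [target.strip() for target in self.additional_target.split(",")]

args = LoraTargets(lora_target="q_proj, v_proj", additional_target="embed_tokens, lm_head")
print(args.lora_target)        # ['q_proj', 'v_proj']
print(args.additional_target)  # ['embed_tokens', 'lm_head']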


@@ -45,9 +45,15 @@ def init_adapter(
     if finetuning_args.finetuning_type == "freeze":
         logger.info("Fine-tuning method: Freeze")
+        num_layers = getattr(model.config, "num_layers")
+        if finetuning_args.num_layer_trainable > 0: # fine-tuning the last n layers if num_layer_trainable > 0
+            trainable_layer_ids = [num_layers - k - 1 for k in range(finetuning_args.num_layer_trainable)]
+        else: # fine-tuning the first n layers if num_layer_trainable < 0
+            trainable_layer_ids = [k for k in range(-finetuning_args.num_layer_trainable)]
+        trainable_layers = ["{:d}.{}".format(idx, finetuning_args.name_module_trainable) for idx in trainable_layer_ids]
         for name, param in model.named_parameters():
-            if not any(trainable_layer in name for trainable_layer in finetuning_args.trainable_layers):
+            if not any(trainable_layer in name for trainable_layer in trainable_layers):
                 param.requires_grad_(False)
             else:
                 param.data = param.data.to(torch.float32)
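The selection logic itself is unchanged, only relocated and now fed by model.config. Worked through as a standalone sketch (num_layers = 32 and name_module_trainable = "mlp" are assumed example values): a positive num_layer_trainable picks the last n decoder blocks, a negative one the first n, and the resulting "<id>.<module>" substrings are matched against parameter names:

def select_trainable_layers(num_layers: int, num_layer_trainable: int, module_name: str):
    if num_layer_trainable > 0:  # fine-tune the last n decoder blocks
        trainable_layer_ids = [num_layers - k - 1 for k in range(num_layer_trainable)]
    else:  # fine-tune the first n decoder blocks
        trainable_layer_ids = [k for k in range(-num_layer_trainable)]
    return ["{:d}.{}".format(idx, module_name) for idx in trainable_layer_ids]

print(select_trainable_layers(32, 3, "mlp"))   # ['31.mlp', '30.mlp', '29.mlp']
print(select_trainable_layers(32, -3, "mlp"))  # ['0.mlp', '1.mlp', '2.mlp']

# A LLaMA-style parameter name such as "model.layers.31.mlp.down_proj.weight"
# contains the substring "31.mlp", so it stays trainable; everything else is frozen.
patterns = select_trainable_layers(32, 3, "mlp")
name = "model.layers.31.mlp.down_proj.weight"
print(any(pattern in name for pattern in patterns))  # True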
@@ -89,7 +95,8 @@
             r=finetuning_args.lora_rank,
             lora_alpha=finetuning_args.lora_alpha,
             lora_dropout=finetuning_args.lora_dropout,
-            target_modules=target_modules
+            target_modules=target_modules,
+            modules_to_save=finetuning_args.additional_target
         )
         model = get_peft_model(model, lora_config)
         if id(model.peft_config) != id(model.base_model.peft_config): # https://github.com/huggingface/peft/issues/923
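modules_to_save is standard peft behavior: each named module is replaced by a fully trainable copy that is saved alongside the LoRA weights, which is what additional_target is for (e.g. an embedding layer or LM head that must also be updated). A minimal sketch with a small public checkpoint; the model name and module names are illustrative stand-ins for whatever lora_target and additional_target resolve to:

from peft import LoraConfig, TaskType, get_peft_model
from transformers import AutoModelForCausalLM

# Tiny public model used purely for illustration.
model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")

lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=8,
    lora_alpha=32,
    lora_dropout=0.1,
    target_modules=["c_attn"],   # GPT-2's fused attention projection
    fan_in_fan_out=True,         # needed because GPT-2 stores c_attn as a Conv1D layer
    modules_to_save=["lm_head"]  # trained in full and stored with the adapter checkpoint
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()  # LoRA params plus the full lm_head copy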