This commit is contained in:
hiyouga
2024-05-16 00:35:28 +08:00
parent 44cfa9a1cd
commit 2a67ab3925
7 changed files with 133 additions and 77 deletions

View File

@@ -1,3 +1,4 @@
import re
from typing import TYPE_CHECKING
import torch
@@ -68,37 +69,52 @@ def init_adapter(
raise ValueError("Current model does not support freeze tuning.")
if finetuning_args.use_llama_pro:
if num_layers % finetuning_args.num_layer_trainable != 0:
if num_layers % finetuning_args.freeze_trainable_layers != 0:
raise ValueError(
"`num_layers` {} should be divisible by `num_layer_trainable` {}.".format(
num_layers, finetuning_args.num_layer_trainable
num_layers, finetuning_args.freeze_trainable_layers
)
)
stride = num_layers // finetuning_args.num_layer_trainable
stride = num_layers // finetuning_args.freeze_trainable_layers
trainable_layer_ids = range(stride - 1, num_layers + stride - 1, stride)
elif finetuning_args.num_layer_trainable > 0: # fine-tuning the last n layers if num_layer_trainable > 0
trainable_layer_ids = range(num_layers - finetuning_args.num_layer_trainable, num_layers)
elif finetuning_args.freeze_trainable_layers > 0: # fine-tuning the last n layers if num_layer_trainable > 0
trainable_layer_ids = range(max(0, num_layers - finetuning_args.freeze_trainable_layers), num_layers)
else: # fine-tuning the first n layers if num_layer_trainable < 0
trainable_layer_ids = range(-finetuning_args.num_layer_trainable)
trainable_layer_ids = range(min(-finetuning_args.freeze_trainable_layers, num_layers))
freeze_modules = {"all"}
for name, _ in model.named_modules():
hidden_modules = set()
non_hidden_modules = set()
for name, _ in model.named_parameters():
if ".0." in name:
freeze_modules.add(name.split(".0.")[-1].split(".")[0])
hidden_modules.add(name.split(".0.")[-1].split(".")[0])
elif ".1." in name: # MoD starts from layer 1
freeze_modules.add(name.split(".1.")[-1].split(".")[0])
hidden_modules.add(name.split(".1.")[-1].split(".")[0])
if re.search(r"\.\d+\.", name) is None:
non_hidden_modules.add(name.split(".")[-2])
trainable_layers = []
for module_name in finetuning_args.name_module_trainable:
if module_name not in freeze_modules:
for module_name in finetuning_args.freeze_trainable_modules:
if module_name != "all" and module_name not in hidden_modules:
raise ValueError(
"Module {} is not found, please choose from {}".format(module_name, ", ".join(freeze_modules))
"Module {} is not found, please choose from {}".format(module_name, ", ".join(hidden_modules))
)
for idx in trainable_layer_ids:
trainable_layers.append(".{:d}.{}".format(idx, module_name if module_name != "all" else ""))
if finetuning_args.freeze_extra_modules:
for module_name in finetuning_args.freeze_extra_modules:
if module_name not in non_hidden_modules:
raise ValueError(
"Module {} is not found, please choose from {}".format(
module_name, ", ".join(non_hidden_modules)
)
)
trainable_layers.append(module_name)
for name, param in model.named_parameters():
if any(trainable_layer in name for trainable_layer in trainable_layers):
if cast_trainable_params_to_fp32: