fix #3694
@@ -1,3 +1,4 @@
+import re
 from typing import TYPE_CHECKING
 
 import torch
@@ -68,37 +69,52 @@ def init_adapter(
             raise ValueError("Current model does not support freeze tuning.")
 
         if finetuning_args.use_llama_pro:
-            if num_layers % finetuning_args.num_layer_trainable != 0:
+            if num_layers % finetuning_args.freeze_trainable_layers != 0:
                 raise ValueError(
                     "`num_layers` {} should be divisible by `num_layer_trainable` {}.".format(
-                        num_layers, finetuning_args.num_layer_trainable
+                        num_layers, finetuning_args.freeze_trainable_layers
                     )
                 )
 
-            stride = num_layers // finetuning_args.num_layer_trainable
+            stride = num_layers // finetuning_args.freeze_trainable_layers
             trainable_layer_ids = range(stride - 1, num_layers + stride - 1, stride)
-        elif finetuning_args.num_layer_trainable > 0:  # fine-tuning the last n layers if num_layer_trainable > 0
-            trainable_layer_ids = range(num_layers - finetuning_args.num_layer_trainable, num_layers)
+        elif finetuning_args.freeze_trainable_layers > 0:  # fine-tuning the last n layers if num_layer_trainable > 0
+            trainable_layer_ids = range(max(0, num_layers - finetuning_args.freeze_trainable_layers), num_layers)
         else:  # fine-tuning the first n layers if num_layer_trainable < 0
-            trainable_layer_ids = range(-finetuning_args.num_layer_trainable)
+            trainable_layer_ids = range(min(-finetuning_args.freeze_trainable_layers, num_layers))
 
-        freeze_modules = {"all"}
-        for name, _ in model.named_modules():
+        hidden_modules = set()
+        non_hidden_modules = set()
+        for name, _ in model.named_parameters():
             if ".0." in name:
-                freeze_modules.add(name.split(".0.")[-1].split(".")[0])
+                hidden_modules.add(name.split(".0.")[-1].split(".")[0])
             elif ".1." in name:  # MoD starts from layer 1
-                freeze_modules.add(name.split(".1.")[-1].split(".")[0])
+                hidden_modules.add(name.split(".1.")[-1].split(".")[0])
+
+            if re.search(r"\.\d+\.", name) is None:
+                non_hidden_modules.add(name.split(".")[-2])
 
         trainable_layers = []
-        for module_name in finetuning_args.name_module_trainable:
-            if module_name not in freeze_modules:
+        for module_name in finetuning_args.freeze_trainable_modules:
+            if module_name != "all" and module_name not in hidden_modules:
                 raise ValueError(
-                    "Module {} is not found, please choose from {}".format(module_name, ", ".join(freeze_modules))
+                    "Module {} is not found, please choose from {}".format(module_name, ", ".join(hidden_modules))
                 )
 
             for idx in trainable_layer_ids:
                 trainable_layers.append(".{:d}.{}".format(idx, module_name if module_name != "all" else ""))
 
+        if finetuning_args.freeze_extra_modules:
+            for module_name in finetuning_args.freeze_extra_modules:
+                if module_name not in non_hidden_modules:
+                    raise ValueError(
+                        "Module {} is not found, please choose from {}".format(
+                            module_name, ", ".join(non_hidden_modules)
+                        )
+                    )
+
+                trainable_layers.append(module_name)
+
         for name, param in model.named_parameters():
             if any(trainable_layer in name for trainable_layer in trainable_layers):
                 if cast_trainable_params_to_fp32:
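Note: below is a minimal, standalone sketch (not part of the commit) of how the renamed freeze_trainable_layers option now picks trainable layer ids. The helper name select_trainable_layer_ids and the 32-layer model are illustrative assumptions only.

# Illustrative sketch only (not from the repository): reproduces the layer-id
# selection logic shown in the diff above, for a hypothetical 32-layer model.
def select_trainable_layer_ids(num_layers, freeze_trainable_layers, use_llama_pro=False):
    if use_llama_pro:
        if num_layers % freeze_trainable_layers != 0:
            raise ValueError("`num_layers` should be divisible by `freeze_trainable_layers`.")

        stride = num_layers // freeze_trainable_layers
        return range(stride - 1, num_layers + stride - 1, stride)  # one trainable layer per expanded block
    elif freeze_trainable_layers > 0:  # positive value: train the last n layers, clamped at 0
        return range(max(0, num_layers - freeze_trainable_layers), num_layers)
    else:  # non-positive value: train the first n layers, clamped at num_layers
        return range(min(-freeze_trainable_layers, num_layers))


print(list(select_trainable_layer_ids(32, 2)))        # [30, 31]            -> last two layers
print(list(select_trainable_layer_ids(32, -2)))       # [0, 1]              -> first two layers
print(list(select_trainable_layer_ids(32, 8, True)))  # [3, 7, 11, ..., 31] -> evenly strided (LLaMA Pro)

The max()/min() clamps added in this commit keep the ranges valid when the requested layer count exceeds num_layers.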
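For context, here is a standalone sketch of how the new discovery loop separates per-layer ("hidden") module names from non-layer modules. The parameter names below are made-up examples in a LLaMA-style naming scheme, not output from the repository.

import re

# Illustrative sketch only: mirrors the named_parameters() scan in the diff above.
param_names = [
    "model.embed_tokens.weight",
    "model.layers.0.self_attn.q_proj.weight",
    "model.layers.0.mlp.gate_proj.weight",
    "model.layers.1.self_attn.k_proj.weight",
    "model.norm.weight",
    "lm_head.weight",
]

hidden_modules, non_hidden_modules = set(), set()
for name in param_names:
    if ".0." in name:
        hidden_modules.add(name.split(".0.")[-1].split(".")[0])
    elif ".1." in name:  # MoD models start from layer 1
        hidden_modules.add(name.split(".1.")[-1].split(".")[0])

    if re.search(r"\.\d+\.", name) is None:  # no layer index -> not inside a hidden layer
        non_hidden_modules.add(name.split(".")[-2])

print(hidden_modules)      # {'self_attn', 'mlp'}
print(non_hidden_modules)  # {'embed_tokens', 'norm', 'lm_head'}

The resulting hidden_modules set is what freeze_trainable_modules is validated against, while non_hidden_modules backs the new freeze_extra_modules option.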