mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-09-12 16:12:48 +08:00
reorganize adapter code
Former-commit-id: 54cd743ebfbd296ae9eaf10c33f59e127f451785
This commit is contained in:
parent
bad35d1730
commit
4f0ce9be4e
@ -15,7 +15,12 @@ class ModelArguments:
|
|||||||
)
|
)
|
||||||
adapter_name_or_path: Optional[str] = field(
|
adapter_name_or_path: Optional[str] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={"help": "Path to the adapter weight or identifier from huggingface.co/models."},
|
metadata={
|
||||||
|
"help": (
|
||||||
|
"Path to the adapter weight or identifier from huggingface.co/models. "
|
||||||
|
"Use commas to separate multiple adapters."
|
||||||
|
)
|
||||||
|
},
|
||||||
)
|
)
|
||||||
cache_dir: Optional[str] = field(
|
cache_dir: Optional[str] = field(
|
||||||
default=None,
|
default=None,
|
||||||
@ -35,7 +40,7 @@ class ModelArguments:
|
|||||||
)
|
)
|
||||||
new_special_tokens: Optional[str] = field(
|
new_special_tokens: Optional[str] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={"help": "Special tokens to be added into the tokenizer."},
|
metadata={"help": "Special tokens to be added into the tokenizer. Use commas to separate multiple tokens."},
|
||||||
)
|
)
|
||||||
model_revision: str = field(
|
model_revision: str = field(
|
||||||
default="main",
|
default="main",
|
||||||
|
@ -21,38 +21,13 @@ if TYPE_CHECKING:
|
|||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def init_adapter(
|
def _setup_full_tuning(
|
||||||
config: "PretrainedConfig",
|
|
||||||
model: "PreTrainedModel",
|
model: "PreTrainedModel",
|
||||||
model_args: "ModelArguments",
|
model_args: "ModelArguments",
|
||||||
finetuning_args: "FinetuningArguments",
|
finetuning_args: "FinetuningArguments",
|
||||||
is_trainable: bool,
|
cast_trainable_params_to_fp32: bool,
|
||||||
) -> "PreTrainedModel":
|
) -> None:
|
||||||
r"""
|
|
||||||
Initializes the adapters.
|
|
||||||
|
|
||||||
Support full-parameter, freeze and LoRA training.
|
|
||||||
|
|
||||||
Note that the trainable parameters must be cast to float32.
|
|
||||||
"""
|
|
||||||
|
|
||||||
if (not is_trainable) and model_args.adapter_name_or_path is None:
|
|
||||||
logger.info("Adapter is not found at evaluation, load the base model.")
|
|
||||||
return model
|
|
||||||
|
|
||||||
if finetuning_args.finetuning_type != "lora" and getattr(model, "quantization_method", None):
|
|
||||||
raise ValueError("You can only use lora for quantized models.")
|
|
||||||
|
|
||||||
if is_deepspeed_zero3_enabled() or is_fsdp_enabled() or finetuning_args.pure_bf16 or finetuning_args.use_badam:
|
|
||||||
logger.info("ZeRO3/FSDP/PureBF16/BAdam detected, remaining trainable params as their original precision.")
|
|
||||||
cast_trainable_params_to_fp32 = False
|
|
||||||
else:
|
|
||||||
logger.info("Upcasting trainable params to float32.")
|
|
||||||
cast_trainable_params_to_fp32 = True
|
|
||||||
|
|
||||||
if is_trainable and finetuning_args.finetuning_type == "full":
|
|
||||||
logger.info("Fine-tuning method: Full")
|
logger.info("Fine-tuning method: Full")
|
||||||
|
|
||||||
forbidden_modules = set()
|
forbidden_modules = set()
|
||||||
if model_args.visual_inputs and finetuning_args.freeze_vision_tower:
|
if model_args.visual_inputs and finetuning_args.freeze_vision_tower:
|
||||||
forbidden_modules.add("vision_tower")
|
forbidden_modules.add("vision_tower")
|
||||||
@ -67,9 +42,14 @@ def init_adapter(
|
|||||||
else:
|
else:
|
||||||
param.requires_grad_(False)
|
param.requires_grad_(False)
|
||||||
|
|
||||||
if is_trainable and finetuning_args.finetuning_type == "freeze":
|
|
||||||
logger.info("Fine-tuning method: Freeze")
|
|
||||||
|
|
||||||
|
def _setup_freeze_tuning(
|
||||||
|
model: "PreTrainedModel",
|
||||||
|
model_args: "ModelArguments",
|
||||||
|
finetuning_args: "FinetuningArguments",
|
||||||
|
cast_trainable_params_to_fp32: bool,
|
||||||
|
) -> None:
|
||||||
|
logger.info("Fine-tuning method: Freeze")
|
||||||
if model_args.visual_inputs:
|
if model_args.visual_inputs:
|
||||||
config = model.config.text_config
|
config = model.config.text_config
|
||||||
else:
|
else:
|
||||||
@ -123,9 +103,7 @@ def init_adapter(
|
|||||||
for module_name in finetuning_args.freeze_extra_modules:
|
for module_name in finetuning_args.freeze_extra_modules:
|
||||||
if module_name not in non_hidden_modules:
|
if module_name not in non_hidden_modules:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Module {} is not found, please choose from {}".format(
|
"Module {} is not found, please choose from {}".format(module_name, ", ".join(non_hidden_modules))
|
||||||
module_name, ", ".join(non_hidden_modules)
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
trainable_layers.append(module_name)
|
trainable_layers.append(module_name)
|
||||||
@ -143,9 +121,17 @@ def init_adapter(
|
|||||||
else:
|
else:
|
||||||
param.requires_grad_(False)
|
param.requires_grad_(False)
|
||||||
|
|
||||||
logger.info("Set trainable layers: {}".format(",".join(map(str, trainable_layer_ids))))
|
logger.info("Set trainable layers: {}".format(",".join(trainable_layers)))
|
||||||
|
|
||||||
if finetuning_args.finetuning_type == "lora":
|
|
||||||
|
def _setup_lora_tuning(
|
||||||
|
config: "PretrainedConfig",
|
||||||
|
model: "PreTrainedModel",
|
||||||
|
model_args: "ModelArguments",
|
||||||
|
finetuning_args: "FinetuningArguments",
|
||||||
|
is_trainable: bool,
|
||||||
|
cast_trainable_params_to_fp32: bool,
|
||||||
|
) -> "PeftModel":
|
||||||
logger.info("Fine-tuning method: {}".format("DoRA" if finetuning_args.use_dora else "LoRA"))
|
logger.info("Fine-tuning method: {}".format("DoRA" if finetuning_args.use_dora else "LoRA"))
|
||||||
adapter_to_resume = None
|
adapter_to_resume = None
|
||||||
|
|
||||||
@ -170,9 +156,7 @@ def init_adapter(
|
|||||||
adapter_to_merge = model_args.adapter_name_or_path
|
adapter_to_merge = model_args.adapter_name_or_path
|
||||||
|
|
||||||
for adapter in adapter_to_merge:
|
for adapter in adapter_to_merge:
|
||||||
model: "LoraModel" = PeftModel.from_pretrained(
|
model: "LoraModel" = PeftModel.from_pretrained(model, adapter, offload_folder=model_args.offload_folder)
|
||||||
model, adapter, offload_folder=model_args.offload_folder
|
|
||||||
)
|
|
||||||
model = model.merge_and_unload()
|
model = model.merge_and_unload()
|
||||||
|
|
||||||
if len(adapter_to_merge) > 0:
|
if len(adapter_to_merge) > 0:
|
||||||
@ -247,3 +231,45 @@ def init_adapter(
|
|||||||
logger.info("Loaded adapter(s): {}".format(",".join(model_args.adapter_name_or_path)))
|
logger.info("Loaded adapter(s): {}".format(",".join(model_args.adapter_name_or_path)))
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def init_adapter(
|
||||||
|
config: "PretrainedConfig",
|
||||||
|
model: "PreTrainedModel",
|
||||||
|
model_args: "ModelArguments",
|
||||||
|
finetuning_args: "FinetuningArguments",
|
||||||
|
is_trainable: bool,
|
||||||
|
) -> "PreTrainedModel":
|
||||||
|
r"""
|
||||||
|
Initializes the adapters.
|
||||||
|
|
||||||
|
Support full-parameter, freeze and LoRA training.
|
||||||
|
|
||||||
|
Note that the trainable parameters must be cast to float32.
|
||||||
|
"""
|
||||||
|
if (not is_trainable) and model_args.adapter_name_or_path is None:
|
||||||
|
logger.info("Adapter is not found at evaluation, load the base model.")
|
||||||
|
return model
|
||||||
|
|
||||||
|
if finetuning_args.finetuning_type != "lora" and getattr(model, "quantization_method", None):
|
||||||
|
raise ValueError("You can only use lora for quantized models.")
|
||||||
|
|
||||||
|
if is_deepspeed_zero3_enabled() or is_fsdp_enabled() or finetuning_args.pure_bf16 or finetuning_args.use_badam:
|
||||||
|
logger.info("ZeRO3/FSDP/PureBF16/BAdam detected, remaining trainable params as their original precision.")
|
||||||
|
cast_trainable_params_to_fp32 = False
|
||||||
|
else:
|
||||||
|
logger.info("Upcasting trainable params to float32.")
|
||||||
|
cast_trainable_params_to_fp32 = True
|
||||||
|
|
||||||
|
if is_trainable and finetuning_args.finetuning_type == "full":
|
||||||
|
_setup_full_tuning(model, model_args, finetuning_args, cast_trainable_params_to_fp32)
|
||||||
|
|
||||||
|
if is_trainable and finetuning_args.finetuning_type == "freeze":
|
||||||
|
_setup_freeze_tuning(model, model_args, finetuning_args, cast_trainable_params_to_fp32)
|
||||||
|
|
||||||
|
if finetuning_args.finetuning_type == "lora":
|
||||||
|
model = _setup_lora_tuning(
|
||||||
|
config, model, model_args, finetuning_args, is_trainable, cast_trainable_params_to_fp32
|
||||||
|
)
|
||||||
|
|
||||||
|
return model
|
||||||
|
Loading…
x
Reference in New Issue
Block a user