reorganize adapter code

Former-commit-id: 54cd743ebfbd296ae9eaf10c33f59e127f451785
hiyouga 2024-06-08 00:47:23 +08:00
parent bad35d1730
commit 4f0ce9be4e
2 changed files with 224 additions and 193 deletions

@@ -15,7 +15,12 @@ class ModelArguments:
     )
     adapter_name_or_path: Optional[str] = field(
         default=None,
-        metadata={"help": "Path to the adapter weight or identifier from huggingface.co/models."},
+        metadata={
+            "help": (
+                "Path to the adapter weight or identifier from huggingface.co/models. "
+                "Use commas to separate multiple adapters."
+            )
+        },
     )
     cache_dir: Optional[str] = field(
         default=None,
@@ -35,7 +40,7 @@ class ModelArguments:
     )
     new_special_tokens: Optional[str] = field(
         default=None,
-        metadata={"help": "Special tokens to be added into the tokenizer."},
+        metadata={"help": "Special tokens to be added into the tokenizer. Use commas to separate multiple tokens."},
     )
     model_revision: str = field(
         default="main",

@@ -21,38 +21,13 @@ if TYPE_CHECKING:
 logger = get_logger(__name__)

-def init_adapter(
-    config: "PretrainedConfig",
+def _setup_full_tuning(
     model: "PreTrainedModel",
     model_args: "ModelArguments",
     finetuning_args: "FinetuningArguments",
-    is_trainable: bool,
-) -> "PreTrainedModel":
+    cast_trainable_params_to_fp32: bool,
+) -> None:
-    r"""
-    Initializes the adapters.
-
-    Support full-parameter, freeze and LoRA training.
-
-    Note that the trainable parameters must be cast to float32.
-    """
-    if (not is_trainable) and model_args.adapter_name_or_path is None:
-        logger.info("Adapter is not found at evaluation, load the base model.")
-        return model
-
-    if finetuning_args.finetuning_type != "lora" and getattr(model, "quantization_method", None):
-        raise ValueError("You can only use lora for quantized models.")
-
-    if is_deepspeed_zero3_enabled() or is_fsdp_enabled() or finetuning_args.pure_bf16 or finetuning_args.use_badam:
-        logger.info("ZeRO3/FSDP/PureBF16/BAdam detected, remaining trainable params as their original precision.")
-        cast_trainable_params_to_fp32 = False
-    else:
-        logger.info("Upcasting trainable params to float32.")
-        cast_trainable_params_to_fp32 = True
-
-    if is_trainable and finetuning_args.finetuning_type == "full":
     logger.info("Fine-tuning method: Full")
     forbidden_modules = set()
     if model_args.visual_inputs and finetuning_args.freeze_vision_tower:
         forbidden_modules.add("vision_tower")
@@ -67,9 +42,14 @@ def init_adapter(
         else:
             param.requires_grad_(False)

-    if is_trainable and finetuning_args.finetuning_type == "freeze":
-        logger.info("Fine-tuning method: Freeze")
+def _setup_freeze_tuning(
+    model: "PreTrainedModel",
+    model_args: "ModelArguments",
+    finetuning_args: "FinetuningArguments",
+    cast_trainable_params_to_fp32: bool,
+) -> None:
+    logger.info("Fine-tuning method: Freeze")
     if model_args.visual_inputs:
         config = model.config.text_config
     else:
@@ -123,9 +103,7 @@ def init_adapter(
     for module_name in finetuning_args.freeze_extra_modules:
         if module_name not in non_hidden_modules:
             raise ValueError(
-                "Module {} is not found, please choose from {}".format(
-                    module_name, ", ".join(non_hidden_modules)
-                )
+                "Module {} is not found, please choose from {}".format(module_name, ", ".join(non_hidden_modules))
             )
         trainable_layers.append(module_name)
@@ -143,9 +121,17 @@ def init_adapter(
         else:
             param.requires_grad_(False)

-        logger.info("Set trainable layers: {}".format(",".join(map(str, trainable_layer_ids))))
-
-    if finetuning_args.finetuning_type == "lora":
+    logger.info("Set trainable layers: {}".format(",".join(trainable_layers)))
+
+
+def _setup_lora_tuning(
+    config: "PretrainedConfig",
+    model: "PreTrainedModel",
+    model_args: "ModelArguments",
+    finetuning_args: "FinetuningArguments",
+    is_trainable: bool,
+    cast_trainable_params_to_fp32: bool,
+) -> "PeftModel":
     logger.info("Fine-tuning method: {}".format("DoRA" if finetuning_args.use_dora else "LoRA"))
     adapter_to_resume = None
@@ -170,9 +156,7 @@ def init_adapter(
             adapter_to_merge = model_args.adapter_name_or_path
         for adapter in adapter_to_merge:
-            model: "LoraModel" = PeftModel.from_pretrained(
-                model, adapter, offload_folder=model_args.offload_folder
-            )
+            model: "LoraModel" = PeftModel.from_pretrained(model, adapter, offload_folder=model_args.offload_folder)
             model = model.merge_and_unload()
         if len(adapter_to_merge) > 0:
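
The loop in the hunk above folds every adapter slated for merging into the base weights via peft. The same pattern, stripped of the surrounding arguments, looks roughly like the sketch below; the model identifier and adapter paths are placeholders for illustration only.

# Merge several LoRA adapters into the base model one after another with peft.
from peft import PeftModel
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B")

adapter_to_merge = ["saves/lora-a", "saves/lora-b"]
for adapter in adapter_to_merge:
    model = PeftModel.from_pretrained(model, adapter)  # wrap the current model with this adapter
    model = model.merge_and_unload()  # bake the adapter into the base weights and unwrap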
@@ -247,3 +231,45 @@ def init_adapter(
         logger.info("Loaded adapter(s): {}".format(",".join(model_args.adapter_name_or_path)))

     return model
+
+
+def init_adapter(
+    config: "PretrainedConfig",
+    model: "PreTrainedModel",
+    model_args: "ModelArguments",
+    finetuning_args: "FinetuningArguments",
+    is_trainable: bool,
+) -> "PreTrainedModel":
+    r"""
+    Initializes the adapters.
+
+    Support full-parameter, freeze and LoRA training.
+
+    Note that the trainable parameters must be cast to float32.
+    """
+    if (not is_trainable) and model_args.adapter_name_or_path is None:
+        logger.info("Adapter is not found at evaluation, load the base model.")
+        return model
+
+    if finetuning_args.finetuning_type != "lora" and getattr(model, "quantization_method", None):
+        raise ValueError("You can only use lora for quantized models.")
+
+    if is_deepspeed_zero3_enabled() or is_fsdp_enabled() or finetuning_args.pure_bf16 or finetuning_args.use_badam:
+        logger.info("ZeRO3/FSDP/PureBF16/BAdam detected, remaining trainable params as their original precision.")
+        cast_trainable_params_to_fp32 = False
+    else:
+        logger.info("Upcasting trainable params to float32.")
+        cast_trainable_params_to_fp32 = True
+
+    if is_trainable and finetuning_args.finetuning_type == "full":
+        _setup_full_tuning(model, model_args, finetuning_args, cast_trainable_params_to_fp32)
+
+    if is_trainable and finetuning_args.finetuning_type == "freeze":
+        _setup_freeze_tuning(model, model_args, finetuning_args, cast_trainable_params_to_fp32)
+
+    if finetuning_args.finetuning_type == "lora":
+        model = _setup_lora_tuning(
+            config, model, model_args, finetuning_args, is_trainable, cast_trainable_params_to_fp32
+        )
+
+    return model
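
After the reorganization, init_adapter is a thin dispatcher: full and freeze tuning modify the model in place through _setup_full_tuning and _setup_freeze_tuning, while _setup_lora_tuning may wrap the model in a PeftModel and therefore returns it. Below is a rough sketch of how a caller could wire this up; the import paths assume the llamafactory package layout, and the model identifier and direct dataclass construction are illustrative assumptions rather than this repository's actual loader.

from transformers import AutoConfig, AutoModelForCausalLM

from llamafactory.hparams import FinetuningArguments, ModelArguments
from llamafactory.model.adapter import init_adapter

model_args = ModelArguments(model_name_or_path="Qwen/Qwen2-0.5B")
finetuning_args = FinetuningArguments(finetuning_type="lora")

config = AutoConfig.from_pretrained(model_args.model_name_or_path)
model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path)

# is_trainable=True sets up the trainable parameters; keep the return value,
# since LoRA tuning may replace the model with a PeftModel wrapper.
model = init_adapter(config, model, model_args, finetuning_args, is_trainable=True)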