remove conflicts

BUAADreamer committed 2024-04-25 00:34:22 +08:00
32 changed files with 965 additions and 471 deletions


@@ -1,25 +1,30 @@
-from typing import TYPE_CHECKING, Union
+from typing import TYPE_CHECKING
 
 import torch
 from peft import LoraConfig, LoraModel, PeftModel, TaskType, get_peft_model
-from transformers import AutoModelForVision2Seq
 from transformers.integrations import is_deepspeed_zero3_enabled
 
 from ..extras.logging import get_logger
-from .utils import QuantizationMethod, find_all_linear_modules, find_expanded_modules
+from .utils.misc import find_all_linear_modules, find_expanded_modules
+from .utils.quantization import QuantizationMethod
+from .utils.unsloth import get_unsloth_peft_model, load_unsloth_peft_model
 
 
 if TYPE_CHECKING:
-    from transformers.modeling_utils import PreTrainedModel, AutoModelForVision2Seq
+    from transformers import PretrainedConfig, PreTrainedModel, AutoModelForVision2Seq
 
     from ..hparams import FinetuningArguments, ModelArguments
 
 
 logger = get_logger(__name__)
 
 
 def init_adapter(
-    model: "PreTrainedModel", model_args: "ModelArguments", finetuning_args: "FinetuningArguments", is_trainable: bool
+    config: "PretrainedConfig",
+    model: "PreTrainedModel",
+    model_args: "ModelArguments",
+    finetuning_args: "FinetuningArguments",
+    is_trainable: bool,
 ) -> "PreTrainedModel":
     r"""
     Initializes the adapters.
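With the extra `config` parameter, every caller of `init_adapter` has to be updated. A minimal sketch of a call site under the new signature; the `AutoConfig`/`AutoModelForCausalLM` loading shown here is an assumption for illustration (the project's own loader is not part of this diff), and `model_args`/`finetuning_args` stand for the parsed hyper-parameter objects from `..hparams`:

```python
# Hypothetical call site for the new five-argument signature; the loading code
# around init_adapter is an assumption, not code from this commit.
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained(model_args.model_name_or_path)
model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path, config=config)
model = init_adapter(config, model, model_args, finetuning_args, is_trainable=True)
```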
@@ -106,6 +111,10 @@ def init_adapter(
             assert len(model_args.adapter_name_or_path) == 1, "Cannot use multiple adapters in DeepSpeed ZeRO-3."
             is_mergeable = False
 
+        if model_args.use_unsloth:
+            assert len(model_args.adapter_name_or_path) == 1, "Unsloth model only accepts a single adapter."
+            is_mergeable = False
+
         if (is_trainable and not finetuning_args.create_new_adapter) or (not is_mergeable):
             adapter_to_merge = model_args.adapter_name_or_path[:-1]
             adapter_to_resume = model_args.adapter_name_or_path[-1]
@@ -122,9 +131,15 @@ def init_adapter(
             logger.info("Merged {} adapter(s).".format(len(adapter_to_merge)))
 
         if adapter_to_resume is not None:  # resume lora training
-            model = PeftModel.from_pretrained(
-                model, adapter_to_resume, is_trainable=is_trainable, offload_folder=model_args.offload_folder
-            )
+            if model_args.use_unsloth:
+                model = load_unsloth_peft_model(config, model_args, is_trainable=is_trainable)
+            else:
+                model = PeftModel.from_pretrained(
+                    model,
+                    adapter_to_resume,
+                    is_trainable=is_trainable,
+                    offload_folder=model_args.offload_folder,
+                )
 
     if is_trainable and adapter_to_resume is None:  # create new lora weights while training
         if len(finetuning_args.lora_target) == 1 and finetuning_args.lora_target[0] == "all":
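`load_unsloth_peft_model` lives in the new `.utils.unsloth` module, which this page does not show. A plausible sketch of what the call above expects, assuming unsloth's public `FastLanguageModel.from_pretrained` and `for_inference` APIs; everything else is inferred from the call site:

```python
# Plausible sketch only: the real body is in .utils.unsloth and is not part of
# this diff. It relies on the assert added above guaranteeing a single adapter path.
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from transformers import PretrainedConfig, PreTrainedModel

    from ..hparams import ModelArguments


def load_unsloth_peft_model(
    config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool
) -> "PreTrainedModel":
    from unsloth import FastLanguageModel  # type: ignore

    # config is available for architecture checks; unused in this sketch.
    # unsloth loads the base model together with the (single) LoRA adapter.
    model, _ = FastLanguageModel.from_pretrained(
        model_name=model_args.adapter_name_or_path[0],
        max_seq_length=model_args.model_max_length,
    )
    if not is_trainable:
        FastLanguageModel.for_inference(model)  # switch to inference-only mode

    return model
```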
@@ -142,6 +157,17 @@ def init_adapter(
         ):
             raise ValueError("DoRA is not compatible with PTQ-quantized models.")
 
+        if model_args.resize_vocab and finetuning_args.additional_target is None:
+            input_embeddings = model.get_input_embeddings()
+            output_embeddings = model.get_output_embeddings()
+            module_names = set()
+            for name, module in model.named_modules():
+                if module in [input_embeddings, output_embeddings]:
+                    module_names.add(name.split(".")[-1])
+
+            finetuning_args.additional_target = module_names
+            logger.warning("Vocab has been resized, add {} to trainable params.".format(",".join(module_names)))
+
         peft_kwargs = {
             "r": finetuning_args.lora_rank,
             "target_modules": target_modules,
@@ -152,14 +178,7 @@ def init_adapter(
         }
 
         if model_args.use_unsloth:
-            from unsloth import FastLanguageModel  # type: ignore
-
-            unsloth_peft_kwargs = {
-                "model": model,
-                "max_seq_length": model_args.model_max_length,
-                "use_gradient_checkpointing": "unsloth",
-            }
-            model = FastLanguageModel.get_peft_model(**peft_kwargs, **unsloth_peft_kwargs)
+            model = get_unsloth_peft_model(model, model_args, peft_kwargs)
 else:
             lora_config = LoraConfig(
                 task_type=TaskType.CAUSAL_LM,
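The inline block this hunk deletes tells us roughly what the new `get_unsloth_peft_model` helper must do. A sketch reconstructed from that removed code; the actual body in `.utils.unsloth` may differ in details:

```python
# Reconstructed from the inline code deleted above; treat as a sketch of the
# helper, not the repository's exact implementation.
from typing import Any, Dict


def get_unsloth_peft_model(
    model: "PreTrainedModel", model_args: "ModelArguments", peft_kwargs: Dict[str, Any]
) -> "PreTrainedModel":
    from unsloth import FastLanguageModel  # type: ignore

    # unsloth-specific options layered on top of the shared LoRA kwargs
    unsloth_peft_kwargs = {
        "model": model,
        "max_seq_length": model_args.model_max_length,
        "use_gradient_checkpointing": "unsloth",
    }
    return FastLanguageModel.get_peft_model(**peft_kwargs, **unsloth_peft_kwargs)
```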