support longlora for main branch

hiyouga
2024-01-20 19:25:22 +08:00
parent bb92cdd0db
commit 38af076a75
7 changed files with 168 additions and 204 deletions


@@ -1,4 +1,5 @@
import torch
import inspect
from typing import TYPE_CHECKING
from transformers.integrations import is_deepspeed_zero3_enabled
from peft import PeftModel, TaskType, LoraConfig, get_peft_model
@@ -108,6 +109,9 @@ def init_adapter(
if model_args.use_unsloth:
    from unsloth import FastLlamaModel, FastMistralModel  # type: ignore
    unsloth_peft_kwargs = {"model": model, "max_seq_length": model_args.model_max_length}
    if "loftq_config" in inspect.signature(FastLlamaModel.get_peft_model).parameters:
        unsloth_peft_kwargs["loftq_config"] = {}
    if getattr(model.config, "model_type", None) == "llama":
        model = FastLlamaModel.get_peft_model(**peft_kwargs, **unsloth_peft_kwargs)
    elif getattr(model.config, "model_type", None) == "mistral":
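
The added check is a version-compatibility guard: loftq_config is only passed to unsloth's get_peft_model when the installed version actually accepts that keyword argument, which is detected by probing the function signature with inspect.signature. A minimal standalone sketch of the same pattern follows; the helper name and argument names are illustrative and not part of the commit.

import inspect

def add_loftq_if_supported(get_peft_model, peft_kwargs):
    # Illustrative helper: probe the target callable's signature and only
    # pass loftq_config when the installed version accepts that keyword.
    kwargs = dict(peft_kwargs)
    if "loftq_config" in inspect.signature(get_peft_model).parameters:
        kwargs["loftq_config"] = {}
    return kwargs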