mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-12-16 20:00:36 +08:00
support longlora for main branch
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
import torch
|
||||
import inspect
|
||||
from typing import TYPE_CHECKING
|
||||
from transformers.integrations import is_deepspeed_zero3_enabled
|
||||
from peft import PeftModel, TaskType, LoraConfig, get_peft_model
|
||||
@@ -108,6 +109,9 @@ def init_adapter(
|
||||
if model_args.use_unsloth:
|
||||
from unsloth import FastLlamaModel, FastMistralModel # type: ignore
|
||||
unsloth_peft_kwargs = {"model": model, "max_seq_length": model_args.model_max_length}
|
||||
if "loftq_config" in inspect.signature(FastLlamaModel.get_peft_model).parameters:
|
||||
unsloth_peft_kwargs["loftq_config"] = {}
|
||||
|
||||
if getattr(model.config, "model_type", None) == "llama":
|
||||
model = FastLlamaModel.get_peft_model(**peft_kwargs, **unsloth_peft_kwargs)
|
||||
elif getattr(model.config, "model_type", None) == "mistral":
|
||||
|
||||
Reference in New Issue
Block a user