Mirror of https://github.com/hiyouga/LLaMA-Factory.git
@@ -85,7 +85,7 @@ def init_adapter(
         logger.info("Set trainable layers: {}".format(",".join(map(str, trainable_layer_ids))))
 
     if finetuning_args.finetuning_type == "lora":
-        logger.info("Fine-tuning method: LoRA")
+        logger.info("Fine-tuning method: {}".format("DoRA" if finetuning_args.use_dora else "LoRA"))
         adapter_to_resume = None
 
         if model_args.adapter_name_or_path is not None:
@@ -123,6 +123,10 @@ def init_adapter(
         if finetuning_args.use_llama_pro:
             target_modules = find_expanded_modules(model, target_modules, finetuning_args.num_layer_trainable)
 
+        if finetuning_args.use_dora:
+            if getattr(model, "quantization_method", None):
+                raise ValueError("DoRA is currently not compatible with quantized models.")
+
         peft_kwargs = {
             "r": finetuning_args.lora_rank,
             "target_modules": target_modules,
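The guard added above reflects a limitation stated in its own error message: at the time of this change, DoRA could not be layered on top of quantized base weights, so the adapter setup fails fast instead of crashing later. A minimal stand-alone sketch of the same pre-flight check follows; the helper name and the `is_quantized` fallback attribute are assumptions for illustration, not LLaMA-Factory code (the diff itself reads a `quantization_method` attribute that the project sets on quantized models).

```python
def assert_dora_compatible(model, use_dora: bool) -> None:
    # Hypothetical helper, not the project's. Treat the model as quantized if
    # either the project-specific tag or transformers' own flag is present.
    quantized = getattr(model, "quantization_method", None) or getattr(model, "is_quantized", False)
    if use_dora and quantized:
        raise ValueError("DoRA is currently not compatible with quantized models.")
```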
@@ -141,6 +145,7 @@ def init_adapter(
             task_type=TaskType.CAUSAL_LM,
             inference_mode=False,
             modules_to_save=finetuning_args.additional_target,
+            use_dora=finetuning_args.use_dora,
             **peft_kwargs,
         )
         model = get_peft_model(model, lora_config)
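For context, the `use_dora` argument threaded through here is the standard PEFT switch: DoRA is enabled by a single flag on `LoraConfig`, and the rest of the adapter setup is identical to plain LoRA. A minimal, self-contained sketch assuming peft >= 0.9.0; the checkpoint name and hyperparameters are illustrative only, not values from this commit.

```python
from peft import LoraConfig, TaskType, get_peft_model
from transformers import AutoModelForCausalLM

# Any causal LM works here; opt-125m is just a small example checkpoint.
model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")

lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    inference_mode=False,
    r=8,                                 # plays the role of finetuning_args.lora_rank above
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=["q_proj", "v_proj"],
    use_dora=True,                       # the flag this commit exposes
)

model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
```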
@@ -51,7 +51,7 @@ def load_model_and_tokenizer(
     patch_tokenizer(tokenizer)
 
     config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs)
-    patch_config(config, tokenizer, model_args, config_kwargs, is_trainable)
+    patch_config(config, tokenizer, model_args, finetuning_args, config_kwargs, is_trainable)
 
     model = None
     if is_trainable and model_args.use_unsloth:
@@ -24,7 +24,7 @@ if TYPE_CHECKING:
     from transformers import PretrainedConfig, PreTrainedTokenizer
     from trl import AutoModelForCausalLMWithValueHead
 
-    from ..hparams import ModelArguments
+    from ..hparams import FinetuningArguments, ModelArguments
 
 
 logger = get_logger(__name__)
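The new import sits under `if TYPE_CHECKING:` because it is only needed for the string annotation added to `patch_config` below; it never executes at runtime and therefore cannot introduce a circular import. A generic sketch of the pattern, with illustrative names rather than the project's:

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Resolved by type checkers only; never imported at runtime.
    from mypackage.hparams import FinetuningArguments


def wants_dora(finetuning_args: "FinetuningArguments") -> bool:
    # The quoted annotation defers evaluation, so no runtime import is required.
    return bool(finetuning_args.use_dora)
```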
@@ -253,6 +253,7 @@ def patch_config(
     config: "PretrainedConfig",
     tokenizer: "PreTrainedTokenizer",
     model_args: "ModelArguments",
+    finetuning_args: "FinetuningArguments",
     config_kwargs: Dict[str, Any],
     is_trainable: bool,
 ) -> None:
@@ -273,6 +274,9 @@ def patch_config(
 
     _configure_quantization(config, tokenizer, model_args, config_kwargs)
 
+    if finetuning_args.use_dora:
+        config_kwargs["device_map"] = {"": get_current_device()}
+
 
 def patch_model(
     model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments", is_trainable: bool
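Mechanically, the last hunk overrides any configured device map when `use_dora` is set: `device_map={"": device}` assigns the root module (the empty key), and therefore the whole model, to the single device returned by `get_current_device()`, instead of letting accelerate shard or offload it. A rough sketch of what resolving that device can look like; this is an assumed stand-in for `get_current_device`, not its actual implementation.

```python
import os

import torch


def current_device() -> "torch.device":
    # Hypothetical stand-in: pick the GPU assigned to this process
    # (LOCAL_RANK under torchrun), otherwise fall back to CPU.
    if torch.cuda.is_available():
        return torch.device("cuda", int(os.environ.get("LOCAL_RANK", "0")))
    return torch.device("cpu")


# Passing device_map={"": current_device()} to from_pretrained then places the
# root module, and hence every submodule, on that one device, e.g.:
# model = AutoModelForCausalLM.from_pretrained(name, device_map={"": current_device()})
```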