diff --git a/src/llamafactory/model/adapter.py b/src/llamafactory/model/adapter.py
index 8b8c6fb8..9a000f41 100644
--- a/src/llamafactory/model/adapter.py
+++ b/src/llamafactory/model/adapter.py
@@ -188,7 +188,7 @@ def _setup_lora_tuning(
 
     if adapter_to_resume is not None:  # resume lora training
         if model_args.use_unsloth:
-            model = load_unsloth_peft_model(config, model_args, is_trainable=is_trainable)
+            model = load_unsloth_peft_model(config, model_args, finetuning_args, is_trainable=is_trainable)
         else:
             model = PeftModel.from_pretrained(model, adapter_to_resume, is_trainable=is_trainable, **init_kwargs)
 
diff --git a/src/llamafactory/model/model_utils/unsloth.py b/src/llamafactory/model/model_utils/unsloth.py
index 37791524..7792857a 100644
--- a/src/llamafactory/model/model_utils/unsloth.py
+++ b/src/llamafactory/model/model_utils/unsloth.py
@@ -80,12 +80,12 @@ def get_unsloth_peft_model(
 
 
 def load_unsloth_peft_model(
-    config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool
+    config: "PretrainedConfig", model_args: "ModelArguments", finetuning_args: "FinetuningArguments", is_trainable: bool
 ) -> "PreTrainedModel":
     r"""Load peft model with unsloth. Used in both training and inference."""
     from unsloth import FastLanguageModel  # type: ignore
 
-    unsloth_kwargs = _get_unsloth_kwargs(config, model_args.adapter_name_or_path[0], model_args)
+    unsloth_kwargs = _get_unsloth_kwargs(config, model_args.adapter_name_or_path[0], model_args, finetuning_args)
     try:
         if not is_trainable:
             unsloth_kwargs["use_gradient_checkpointing"] = False