diff --git a/src/llamafactory/model/loader.py b/src/llamafactory/model/loader.py
index d327eecf..cbcc6b28 100644
--- a/src/llamafactory/model/loader.py
+++ b/src/llamafactory/model/loader.py
@@ -138,7 +138,7 @@ def load_model(
         if model_args.adapter_name_or_path is not None:
             lazy_load = True
         elif is_trainable:
-            model = load_unsloth_pretrained_model(config, model_args)
+            model = load_unsloth_pretrained_model(config, model_args, finetuning_args)

     if model is None and not lazy_load:
         init_kwargs["config"] = config
diff --git a/src/llamafactory/model/model_utils/unsloth.py b/src/llamafactory/model/model_utils/unsloth.py
index 8bb6aa64..352ef048 100644
--- a/src/llamafactory/model/model_utils/unsloth.py
+++ b/src/llamafactory/model/model_utils/unsloth.py
@@ -21,14 +21,14 @@ from ...extras.misc import get_current_device
 if TYPE_CHECKING:
     from transformers import PretrainedConfig, PreTrainedModel

-    from ...hparams import ModelArguments
+    from ...hparams import ModelArguments, FinetuningArguments


 logger = logging.get_logger(__name__)


 def _get_unsloth_kwargs(
-    config: "PretrainedConfig", model_name_or_path: str, model_args: "ModelArguments"
+    config: "PretrainedConfig", model_name_or_path: str, model_args: "ModelArguments", finetuning_args: "FinetuningArguments"
 ) -> dict[str, Any]:
     return {
         "model_name": model_name_or_path,
@@ -36,6 +36,7 @@ def _get_unsloth_kwargs(
         "dtype": model_args.compute_dtype,
         "load_in_4bit": model_args.quantization_bit == 4,
         "token": model_args.hf_hub_token,
+        "full_finetuning": finetuning_args.finetuning_type == "full",
         "device_map": {"": get_current_device()},
         "rope_scaling": getattr(config, "rope_scaling", None),
         "fix_tokenizer": False,
@@ -45,12 +46,12 @@ def _get_unsloth_kwargs(


 def load_unsloth_pretrained_model(
-    config: "PretrainedConfig", model_args: "ModelArguments"
+    config: "PretrainedConfig", model_args: "ModelArguments", finetuning_args: "FinetuningArguments"
 ) -> Optional["PreTrainedModel"]:
     r"""Optionally load pretrained model with unsloth. Used in training."""
     from unsloth import FastLanguageModel  # type: ignore

-    unsloth_kwargs = _get_unsloth_kwargs(config, model_args.model_name_or_path, model_args)
+    unsloth_kwargs = _get_unsloth_kwargs(config, model_args.model_name_or_path, model_args, finetuning_args)
     try:
         model, _ = FastLanguageModel.from_pretrained(**unsloth_kwargs)
     except NotImplementedError:
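
For reference, the net effect of the patch is that finetuning_type == "full" now reaches unsloth as a full_finetuning flag on FastLanguageModel.from_pretrained (the kwargs dict built by _get_unsloth_kwargs is passed through unchanged). Below is a minimal sketch of the resulting call, not part of the patch; the model id, dtype, token, and device values are illustrative stand-ins for what _get_unsloth_kwargs computes from model_args.

# Sketch only (not part of the patch): the kwargs _get_unsloth_kwargs builds
# after this change, with placeholder values in place of the real arguments.
from unsloth import FastLanguageModel

unsloth_kwargs = {
    "model_name": "meta-llama/Meta-Llama-3-8B",  # illustrative model id
    "dtype": None,                               # model_args.compute_dtype
    "load_in_4bit": False,                       # model_args.quantization_bit == 4
    "token": None,                               # model_args.hf_hub_token
    "full_finetuning": True,                     # new: finetuning_args.finetuning_type == "full"
    "device_map": {"": 0},                       # get_current_device()
    "rope_scaling": None,                        # getattr(config, "rope_scaling", None)
    "fix_tokenizer": False,
}
model, _ = FastLanguageModel.from_pretrained(**unsloth_kwargs)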