diff --git a/src/llamafactory/model/adapter.py b/src/llamafactory/model/adapter.py index b8c3900d7..61fa3c2ea 100644 --- a/src/llamafactory/model/adapter.py +++ b/src/llamafactory/model/adapter.py @@ -194,10 +194,16 @@ def _setup_lora_tuning( logger.info_rank0(f"Merged {len(adapter_to_merge)} adapter(s).") if adapter_to_resume is not None: # resume lora training - if model_args.use_unsloth: - model = load_unsloth_peft_model(config, model_args, finetuning_args, is_trainable=is_trainable) + if isinstance(model, PeftModel): + pass # already loaded via load_unsloth_peft_model in loader.py else: - model = PeftModel.from_pretrained(model, adapter_to_resume, is_trainable=is_trainable, **init_kwargs) + if model_args.use_unsloth: + peft_model = load_unsloth_peft_model(config, model_args, finetuning_args, is_trainable=is_trainable) + if peft_model is not None: + model = peft_model + + if not model_args.use_unsloth: # unsloth was disabled or fell back + model = PeftModel.from_pretrained(model, adapter_to_resume, is_trainable=is_trainable, **init_kwargs) logger.info_rank0("Loaded adapter(s): {}".format(",".join(model_args.adapter_name_or_path))) diff --git a/src/llamafactory/model/loader.py b/src/llamafactory/model/loader.py index 7a209ee11..6ed2d6b9c 100644 --- a/src/llamafactory/model/loader.py +++ b/src/llamafactory/model/loader.py @@ -34,7 +34,7 @@ from .adapter import init_adapter from .model_utils.liger_kernel import apply_liger_kernel from .model_utils.misc import register_autoclass from .model_utils.mod import convert_pretrained_model_to_mod, load_mod_pretrained_model -from .model_utils.unsloth import load_unsloth_pretrained_model +from .model_utils.unsloth import load_unsloth_pretrained_model, load_unsloth_peft_model from .model_utils.valuehead import load_valuehead_params from .patcher import patch_config, patch_model, patch_processor, patch_tokenizer, patch_valuehead_model @@ -142,14 +142,13 @@ def load_model( apply_liger_kernel(config, model_args, is_trainable, require_logits=(finetuning_args.stage not in ["pt", "sft"])) model = None - lazy_load = False if model_args.use_unsloth: if model_args.adapter_name_or_path is not None: - lazy_load = True + model = load_unsloth_peft_model(config, model_args, finetuning_args, is_trainable=is_trainable) elif is_trainable: model = load_unsloth_pretrained_model(config, model_args, finetuning_args) - if model is None and not lazy_load: + if model is None: init_kwargs["config"] = config init_kwargs["pretrained_model_name_or_path"] = model_args.model_name_or_path init_kwargs["torch_dtype"] = "auto" @@ -176,9 +175,8 @@ def load_model( if model_args.mixture_of_depths == "convert": model = convert_pretrained_model_to_mod(model, config, model_args) - if not lazy_load: - patch_model(model, tokenizer, model_args, is_trainable, add_valuehead) - register_autoclass(config, model, tokenizer) + patch_model(model, tokenizer, model_args, is_trainable, add_valuehead) + register_autoclass(config, model, tokenizer) model = init_adapter(config, model, model_args, finetuning_args, is_trainable) diff --git a/src/llamafactory/model/model_utils/unsloth.py b/src/llamafactory/model/model_utils/unsloth.py index 91e18dac9..7340d751d 100644 --- a/src/llamafactory/model/model_utils/unsloth.py +++ b/src/llamafactory/model/model_utils/unsloth.py @@ -84,8 +84,12 @@ def load_unsloth_peft_model( model_args: "ModelArguments", finetuning_args: "FinetuningArguments", is_trainable: bool, -) -> "PreTrainedModel": - r"""Load peft model with unsloth. Used in both training and inference.""" +) -> Optional["PreTrainedModel"]: + r"""Load peft model with unsloth. Used in both training and inference. + + Returns None if unsloth does not support the model type, and sets + model_args.use_unsloth = False so callers can fall back to standard loading. + """ from unsloth import FastLanguageModel # type: ignore unsloth_kwargs = _get_unsloth_kwargs(config, model_args.adapter_name_or_path[0], model_args, finetuning_args) @@ -95,7 +99,9 @@ def load_unsloth_peft_model( model, _ = FastLanguageModel.from_pretrained(**unsloth_kwargs) except NotImplementedError: - raise ValueError("Unsloth does not support model type {}.".format(getattr(config, "model_type", None))) + logger.warning_rank0("Unsloth does not support model type {}.".format(getattr(config, "model_type", None))) + model_args.use_unsloth = False + return None if not is_trainable: FastLanguageModel.for_inference(model)