[model] handle unsloth model loading fallback during checkpoint resume (#7156) (#10551)

This commit is contained in:
Co-Cl2
2026-06-09 01:01:01 +08:00
committed by GitHub
parent 0b7aaf8f6a
commit 9ca4026efe
3 changed files with 23 additions and 13 deletions

View File

@@ -194,10 +194,16 @@ def _setup_lora_tuning(
logger.info_rank0(f"Merged {len(adapter_to_merge)} adapter(s).")
if adapter_to_resume is not None: # resume lora training
if model_args.use_unsloth:
model = load_unsloth_peft_model(config, model_args, finetuning_args, is_trainable=is_trainable)
if isinstance(model, PeftModel):
pass # already loaded via load_unsloth_peft_model in loader.py
else:
model = PeftModel.from_pretrained(model, adapter_to_resume, is_trainable=is_trainable, **init_kwargs)
if model_args.use_unsloth:
peft_model = load_unsloth_peft_model(config, model_args, finetuning_args, is_trainable=is_trainable)
if peft_model is not None:
model = peft_model
if not model_args.use_unsloth: # unsloth was disabled or fell back
model = PeftModel.from_pretrained(model, adapter_to_resume, is_trainable=is_trainable, **init_kwargs)
logger.info_rank0("Loaded adapter(s): {}".format(",".join(model_args.adapter_name_or_path)))

View File

@@ -34,7 +34,7 @@ from .adapter import init_adapter
from .model_utils.liger_kernel import apply_liger_kernel
from .model_utils.misc import register_autoclass
from .model_utils.mod import convert_pretrained_model_to_mod, load_mod_pretrained_model
from .model_utils.unsloth import load_unsloth_pretrained_model
from .model_utils.unsloth import load_unsloth_pretrained_model, load_unsloth_peft_model
from .model_utils.valuehead import load_valuehead_params
from .patcher import patch_config, patch_model, patch_processor, patch_tokenizer, patch_valuehead_model
@@ -142,14 +142,13 @@ def load_model(
apply_liger_kernel(config, model_args, is_trainable, require_logits=(finetuning_args.stage not in ["pt", "sft"]))
model = None
lazy_load = False
if model_args.use_unsloth:
if model_args.adapter_name_or_path is not None:
lazy_load = True
model = load_unsloth_peft_model(config, model_args, finetuning_args, is_trainable=is_trainable)
elif is_trainable:
model = load_unsloth_pretrained_model(config, model_args, finetuning_args)
if model is None and not lazy_load:
if model is None:
init_kwargs["config"] = config
init_kwargs["pretrained_model_name_or_path"] = model_args.model_name_or_path
init_kwargs["torch_dtype"] = "auto"
@@ -176,9 +175,8 @@ def load_model(
if model_args.mixture_of_depths == "convert":
model = convert_pretrained_model_to_mod(model, config, model_args)
if not lazy_load:
patch_model(model, tokenizer, model_args, is_trainable, add_valuehead)
register_autoclass(config, model, tokenizer)
patch_model(model, tokenizer, model_args, is_trainable, add_valuehead)
register_autoclass(config, model, tokenizer)
model = init_adapter(config, model, model_args, finetuning_args, is_trainable)

View File

@@ -84,8 +84,12 @@ def load_unsloth_peft_model(
model_args: "ModelArguments",
finetuning_args: "FinetuningArguments",
is_trainable: bool,
) -> "PreTrainedModel":
r"""Load peft model with unsloth. Used in both training and inference."""
) -> Optional["PreTrainedModel"]:
r"""Load peft model with unsloth. Used in both training and inference.
Returns None if unsloth does not support the model type, and sets
model_args.use_unsloth = False so callers can fall back to standard loading.
"""
from unsloth import FastLanguageModel # type: ignore
unsloth_kwargs = _get_unsloth_kwargs(config, model_args.adapter_name_or_path[0], model_args, finetuning_args)
@@ -95,7 +99,9 @@ def load_unsloth_peft_model(
model, _ = FastLanguageModel.from_pretrained(**unsloth_kwargs)
except NotImplementedError:
raise ValueError("Unsloth does not support model type {}.".format(getattr(config, "model_type", None)))
logger.warning_rank0("Unsloth does not support model type {}.".format(getattr(config, "model_type", None)))
model_args.use_unsloth = False
return None
if not is_trainable:
FastLanguageModel.for_inference(model)