[model] handle unsloth model loading fallback during checkpoint resume (#7156) (#10551)

This commit is contained in:
Co-Cl2
2026-06-09 01:01:01 +08:00
committed by GitHub
parent 0b7aaf8f6a
commit 9ca4026efe
3 changed files with 23 additions and 13 deletions

View File

@@ -194,10 +194,16 @@ def _setup_lora_tuning(
logger.info_rank0(f"Merged {len(adapter_to_merge)} adapter(s).") logger.info_rank0(f"Merged {len(adapter_to_merge)} adapter(s).")
if adapter_to_resume is not None: # resume lora training if adapter_to_resume is not None: # resume lora training
if model_args.use_unsloth: if isinstance(model, PeftModel):
model = load_unsloth_peft_model(config, model_args, finetuning_args, is_trainable=is_trainable) pass # already loaded via load_unsloth_peft_model in loader.py
else: else:
model = PeftModel.from_pretrained(model, adapter_to_resume, is_trainable=is_trainable, **init_kwargs) if model_args.use_unsloth:
peft_model = load_unsloth_peft_model(config, model_args, finetuning_args, is_trainable=is_trainable)
if peft_model is not None:
model = peft_model
if not model_args.use_unsloth: # unsloth was disabled or fell back
model = PeftModel.from_pretrained(model, adapter_to_resume, is_trainable=is_trainable, **init_kwargs)
logger.info_rank0("Loaded adapter(s): {}".format(",".join(model_args.adapter_name_or_path))) logger.info_rank0("Loaded adapter(s): {}".format(",".join(model_args.adapter_name_or_path)))

View File

@@ -34,7 +34,7 @@ from .adapter import init_adapter
from .model_utils.liger_kernel import apply_liger_kernel from .model_utils.liger_kernel import apply_liger_kernel
from .model_utils.misc import register_autoclass from .model_utils.misc import register_autoclass
from .model_utils.mod import convert_pretrained_model_to_mod, load_mod_pretrained_model from .model_utils.mod import convert_pretrained_model_to_mod, load_mod_pretrained_model
from .model_utils.unsloth import load_unsloth_pretrained_model from .model_utils.unsloth import load_unsloth_pretrained_model, load_unsloth_peft_model
from .model_utils.valuehead import load_valuehead_params from .model_utils.valuehead import load_valuehead_params
from .patcher import patch_config, patch_model, patch_processor, patch_tokenizer, patch_valuehead_model from .patcher import patch_config, patch_model, patch_processor, patch_tokenizer, patch_valuehead_model
@@ -142,14 +142,13 @@ def load_model(
apply_liger_kernel(config, model_args, is_trainable, require_logits=(finetuning_args.stage not in ["pt", "sft"])) apply_liger_kernel(config, model_args, is_trainable, require_logits=(finetuning_args.stage not in ["pt", "sft"]))
model = None model = None
lazy_load = False
if model_args.use_unsloth: if model_args.use_unsloth:
if model_args.adapter_name_or_path is not None: if model_args.adapter_name_or_path is not None:
lazy_load = True model = load_unsloth_peft_model(config, model_args, finetuning_args, is_trainable=is_trainable)
elif is_trainable: elif is_trainable:
model = load_unsloth_pretrained_model(config, model_args, finetuning_args) model = load_unsloth_pretrained_model(config, model_args, finetuning_args)
if model is None and not lazy_load: if model is None:
init_kwargs["config"] = config init_kwargs["config"] = config
init_kwargs["pretrained_model_name_or_path"] = model_args.model_name_or_path init_kwargs["pretrained_model_name_or_path"] = model_args.model_name_or_path
init_kwargs["torch_dtype"] = "auto" init_kwargs["torch_dtype"] = "auto"
@@ -176,9 +175,8 @@ def load_model(
if model_args.mixture_of_depths == "convert": if model_args.mixture_of_depths == "convert":
model = convert_pretrained_model_to_mod(model, config, model_args) model = convert_pretrained_model_to_mod(model, config, model_args)
if not lazy_load: patch_model(model, tokenizer, model_args, is_trainable, add_valuehead)
patch_model(model, tokenizer, model_args, is_trainable, add_valuehead) register_autoclass(config, model, tokenizer)
register_autoclass(config, model, tokenizer)
model = init_adapter(config, model, model_args, finetuning_args, is_trainable) model = init_adapter(config, model, model_args, finetuning_args, is_trainable)

View File

@@ -84,8 +84,12 @@ def load_unsloth_peft_model(
model_args: "ModelArguments", model_args: "ModelArguments",
finetuning_args: "FinetuningArguments", finetuning_args: "FinetuningArguments",
is_trainable: bool, is_trainable: bool,
) -> "PreTrainedModel": ) -> Optional["PreTrainedModel"]:
r"""Load peft model with unsloth. Used in both training and inference.""" r"""Load peft model with unsloth. Used in both training and inference.
Returns None if unsloth does not support the model type, and sets
model_args.use_unsloth = False so callers can fall back to standard loading.
"""
from unsloth import FastLanguageModel # type: ignore from unsloth import FastLanguageModel # type: ignore
unsloth_kwargs = _get_unsloth_kwargs(config, model_args.adapter_name_or_path[0], model_args, finetuning_args) unsloth_kwargs = _get_unsloth_kwargs(config, model_args.adapter_name_or_path[0], model_args, finetuning_args)
@@ -95,7 +99,9 @@ def load_unsloth_peft_model(
model, _ = FastLanguageModel.from_pretrained(**unsloth_kwargs) model, _ = FastLanguageModel.from_pretrained(**unsloth_kwargs)
except NotImplementedError: except NotImplementedError:
raise ValueError("Unsloth does not support model type {}.".format(getattr(config, "model_type", None))) logger.warning_rank0("Unsloth does not support model type {}.".format(getattr(config, "model_type", None)))
model_args.use_unsloth = False
return None
if not is_trainable: if not is_trainable:
FastLanguageModel.for_inference(model) FastLanguageModel.for_inference(model)