mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2026-06-17 04:38:53 +08:00
This commit is contained in:
@@ -194,10 +194,16 @@ def _setup_lora_tuning(
|
||||
logger.info_rank0(f"Merged {len(adapter_to_merge)} adapter(s).")
|
||||
|
||||
if adapter_to_resume is not None: # resume lora training
|
||||
if model_args.use_unsloth:
|
||||
model = load_unsloth_peft_model(config, model_args, finetuning_args, is_trainable=is_trainable)
|
||||
if isinstance(model, PeftModel):
|
||||
pass # already loaded via load_unsloth_peft_model in loader.py
|
||||
else:
|
||||
model = PeftModel.from_pretrained(model, adapter_to_resume, is_trainable=is_trainable, **init_kwargs)
|
||||
if model_args.use_unsloth:
|
||||
peft_model = load_unsloth_peft_model(config, model_args, finetuning_args, is_trainable=is_trainable)
|
||||
if peft_model is not None:
|
||||
model = peft_model
|
||||
|
||||
if not model_args.use_unsloth: # unsloth was disabled or fell back
|
||||
model = PeftModel.from_pretrained(model, adapter_to_resume, is_trainable=is_trainable, **init_kwargs)
|
||||
|
||||
logger.info_rank0("Loaded adapter(s): {}".format(",".join(model_args.adapter_name_or_path)))
|
||||
|
||||
|
||||
@@ -34,7 +34,7 @@ from .adapter import init_adapter
|
||||
from .model_utils.liger_kernel import apply_liger_kernel
|
||||
from .model_utils.misc import register_autoclass
|
||||
from .model_utils.mod import convert_pretrained_model_to_mod, load_mod_pretrained_model
|
||||
from .model_utils.unsloth import load_unsloth_pretrained_model
|
||||
from .model_utils.unsloth import load_unsloth_pretrained_model, load_unsloth_peft_model
|
||||
from .model_utils.valuehead import load_valuehead_params
|
||||
from .patcher import patch_config, patch_model, patch_processor, patch_tokenizer, patch_valuehead_model
|
||||
|
||||
@@ -142,14 +142,13 @@ def load_model(
|
||||
apply_liger_kernel(config, model_args, is_trainable, require_logits=(finetuning_args.stage not in ["pt", "sft"]))
|
||||
|
||||
model = None
|
||||
lazy_load = False
|
||||
if model_args.use_unsloth:
|
||||
if model_args.adapter_name_or_path is not None:
|
||||
lazy_load = True
|
||||
model = load_unsloth_peft_model(config, model_args, finetuning_args, is_trainable=is_trainable)
|
||||
elif is_trainable:
|
||||
model = load_unsloth_pretrained_model(config, model_args, finetuning_args)
|
||||
|
||||
if model is None and not lazy_load:
|
||||
if model is None:
|
||||
init_kwargs["config"] = config
|
||||
init_kwargs["pretrained_model_name_or_path"] = model_args.model_name_or_path
|
||||
init_kwargs["torch_dtype"] = "auto"
|
||||
@@ -176,9 +175,8 @@ def load_model(
|
||||
if model_args.mixture_of_depths == "convert":
|
||||
model = convert_pretrained_model_to_mod(model, config, model_args)
|
||||
|
||||
if not lazy_load:
|
||||
patch_model(model, tokenizer, model_args, is_trainable, add_valuehead)
|
||||
register_autoclass(config, model, tokenizer)
|
||||
patch_model(model, tokenizer, model_args, is_trainable, add_valuehead)
|
||||
register_autoclass(config, model, tokenizer)
|
||||
|
||||
model = init_adapter(config, model, model_args, finetuning_args, is_trainable)
|
||||
|
||||
|
||||
@@ -84,8 +84,12 @@ def load_unsloth_peft_model(
|
||||
model_args: "ModelArguments",
|
||||
finetuning_args: "FinetuningArguments",
|
||||
is_trainable: bool,
|
||||
) -> "PreTrainedModel":
|
||||
r"""Load peft model with unsloth. Used in both training and inference."""
|
||||
) -> Optional["PreTrainedModel"]:
|
||||
r"""Load peft model with unsloth. Used in both training and inference.
|
||||
|
||||
Returns None if unsloth does not support the model type, and sets
|
||||
model_args.use_unsloth = False so callers can fall back to standard loading.
|
||||
"""
|
||||
from unsloth import FastLanguageModel # type: ignore
|
||||
|
||||
unsloth_kwargs = _get_unsloth_kwargs(config, model_args.adapter_name_or_path[0], model_args, finetuning_args)
|
||||
@@ -95,7 +99,9 @@ def load_unsloth_peft_model(
|
||||
|
||||
model, _ = FastLanguageModel.from_pretrained(**unsloth_kwargs)
|
||||
except NotImplementedError:
|
||||
raise ValueError("Unsloth does not support model type {}.".format(getattr(config, "model_type", None)))
|
||||
logger.warning_rank0("Unsloth does not support model type {}.".format(getattr(config, "model_type", None)))
|
||||
model_args.use_unsloth = False
|
||||
return None
|
||||
|
||||
if not is_trainable:
|
||||
FastLanguageModel.for_inference(model)
|
||||
|
||||
Reference in New Issue
Block a user