mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2026-06-17 20:58:54 +08:00
This commit is contained in:
@@ -194,10 +194,16 @@ def _setup_lora_tuning(
|
|||||||
logger.info_rank0(f"Merged {len(adapter_to_merge)} adapter(s).")
|
logger.info_rank0(f"Merged {len(adapter_to_merge)} adapter(s).")
|
||||||
|
|
||||||
if adapter_to_resume is not None: # resume lora training
|
if adapter_to_resume is not None: # resume lora training
|
||||||
if model_args.use_unsloth:
|
if isinstance(model, PeftModel):
|
||||||
model = load_unsloth_peft_model(config, model_args, finetuning_args, is_trainable=is_trainable)
|
pass # already loaded via load_unsloth_peft_model in loader.py
|
||||||
else:
|
else:
|
||||||
model = PeftModel.from_pretrained(model, adapter_to_resume, is_trainable=is_trainable, **init_kwargs)
|
if model_args.use_unsloth:
|
||||||
|
peft_model = load_unsloth_peft_model(config, model_args, finetuning_args, is_trainable=is_trainable)
|
||||||
|
if peft_model is not None:
|
||||||
|
model = peft_model
|
||||||
|
|
||||||
|
if not model_args.use_unsloth: # unsloth was disabled or fell back
|
||||||
|
model = PeftModel.from_pretrained(model, adapter_to_resume, is_trainable=is_trainable, **init_kwargs)
|
||||||
|
|
||||||
logger.info_rank0("Loaded adapter(s): {}".format(",".join(model_args.adapter_name_or_path)))
|
logger.info_rank0("Loaded adapter(s): {}".format(",".join(model_args.adapter_name_or_path)))
|
||||||
|
|
||||||
|
|||||||
@@ -34,7 +34,7 @@ from .adapter import init_adapter
|
|||||||
from .model_utils.liger_kernel import apply_liger_kernel
|
from .model_utils.liger_kernel import apply_liger_kernel
|
||||||
from .model_utils.misc import register_autoclass
|
from .model_utils.misc import register_autoclass
|
||||||
from .model_utils.mod import convert_pretrained_model_to_mod, load_mod_pretrained_model
|
from .model_utils.mod import convert_pretrained_model_to_mod, load_mod_pretrained_model
|
||||||
from .model_utils.unsloth import load_unsloth_pretrained_model
|
from .model_utils.unsloth import load_unsloth_pretrained_model, load_unsloth_peft_model
|
||||||
from .model_utils.valuehead import load_valuehead_params
|
from .model_utils.valuehead import load_valuehead_params
|
||||||
from .patcher import patch_config, patch_model, patch_processor, patch_tokenizer, patch_valuehead_model
|
from .patcher import patch_config, patch_model, patch_processor, patch_tokenizer, patch_valuehead_model
|
||||||
|
|
||||||
@@ -142,14 +142,13 @@ def load_model(
|
|||||||
apply_liger_kernel(config, model_args, is_trainable, require_logits=(finetuning_args.stage not in ["pt", "sft"]))
|
apply_liger_kernel(config, model_args, is_trainable, require_logits=(finetuning_args.stage not in ["pt", "sft"]))
|
||||||
|
|
||||||
model = None
|
model = None
|
||||||
lazy_load = False
|
|
||||||
if model_args.use_unsloth:
|
if model_args.use_unsloth:
|
||||||
if model_args.adapter_name_or_path is not None:
|
if model_args.adapter_name_or_path is not None:
|
||||||
lazy_load = True
|
model = load_unsloth_peft_model(config, model_args, finetuning_args, is_trainable=is_trainable)
|
||||||
elif is_trainable:
|
elif is_trainable:
|
||||||
model = load_unsloth_pretrained_model(config, model_args, finetuning_args)
|
model = load_unsloth_pretrained_model(config, model_args, finetuning_args)
|
||||||
|
|
||||||
if model is None and not lazy_load:
|
if model is None:
|
||||||
init_kwargs["config"] = config
|
init_kwargs["config"] = config
|
||||||
init_kwargs["pretrained_model_name_or_path"] = model_args.model_name_or_path
|
init_kwargs["pretrained_model_name_or_path"] = model_args.model_name_or_path
|
||||||
init_kwargs["torch_dtype"] = "auto"
|
init_kwargs["torch_dtype"] = "auto"
|
||||||
@@ -176,9 +175,8 @@ def load_model(
|
|||||||
if model_args.mixture_of_depths == "convert":
|
if model_args.mixture_of_depths == "convert":
|
||||||
model = convert_pretrained_model_to_mod(model, config, model_args)
|
model = convert_pretrained_model_to_mod(model, config, model_args)
|
||||||
|
|
||||||
if not lazy_load:
|
patch_model(model, tokenizer, model_args, is_trainable, add_valuehead)
|
||||||
patch_model(model, tokenizer, model_args, is_trainable, add_valuehead)
|
register_autoclass(config, model, tokenizer)
|
||||||
register_autoclass(config, model, tokenizer)
|
|
||||||
|
|
||||||
model = init_adapter(config, model, model_args, finetuning_args, is_trainable)
|
model = init_adapter(config, model, model_args, finetuning_args, is_trainable)
|
||||||
|
|
||||||
|
|||||||
@@ -84,8 +84,12 @@ def load_unsloth_peft_model(
|
|||||||
model_args: "ModelArguments",
|
model_args: "ModelArguments",
|
||||||
finetuning_args: "FinetuningArguments",
|
finetuning_args: "FinetuningArguments",
|
||||||
is_trainable: bool,
|
is_trainable: bool,
|
||||||
) -> "PreTrainedModel":
|
) -> Optional["PreTrainedModel"]:
|
||||||
r"""Load peft model with unsloth. Used in both training and inference."""
|
r"""Load peft model with unsloth. Used in both training and inference.
|
||||||
|
|
||||||
|
Returns None if unsloth does not support the model type, and sets
|
||||||
|
model_args.use_unsloth = False so callers can fall back to standard loading.
|
||||||
|
"""
|
||||||
from unsloth import FastLanguageModel # type: ignore
|
from unsloth import FastLanguageModel # type: ignore
|
||||||
|
|
||||||
unsloth_kwargs = _get_unsloth_kwargs(config, model_args.adapter_name_or_path[0], model_args, finetuning_args)
|
unsloth_kwargs = _get_unsloth_kwargs(config, model_args.adapter_name_or_path[0], model_args, finetuning_args)
|
||||||
@@ -95,7 +99,9 @@ def load_unsloth_peft_model(
|
|||||||
|
|
||||||
model, _ = FastLanguageModel.from_pretrained(**unsloth_kwargs)
|
model, _ = FastLanguageModel.from_pretrained(**unsloth_kwargs)
|
||||||
except NotImplementedError:
|
except NotImplementedError:
|
||||||
raise ValueError("Unsloth does not support model type {}.".format(getattr(config, "model_type", None)))
|
logger.warning_rank0("Unsloth does not support model type {}.".format(getattr(config, "model_type", None)))
|
||||||
|
model_args.use_unsloth = False
|
||||||
|
return None
|
||||||
|
|
||||||
if not is_trainable:
|
if not is_trainable:
|
||||||
FastLanguageModel.for_inference(model)
|
FastLanguageModel.for_inference(model)
|
||||||
|
|||||||
Reference in New Issue
Block a user