[model] pushing FFT with unsloth (#8325)

Co-authored-by: viyer <vivek_iyer2@apple.com>
Author: Vivek Iyer  Date: 2025-06-06 20:20:58 -04:00 (committed by GitHub)
parent 239ced076c
commit d325a1a7c7
2 changed files with 6 additions and 5 deletions


@@ -138,7 +138,7 @@ def load_model(
         if model_args.adapter_name_or_path is not None:
             lazy_load = True
         elif is_trainable:
-            model = load_unsloth_pretrained_model(config, model_args)
+            model = load_unsloth_pretrained_model(config, model_args, finetuning_args)

     if model is None and not lazy_load:
         init_kwargs["config"] = config


@@ -21,14 +21,14 @@ from ...extras.misc import get_current_device
 if TYPE_CHECKING:
     from transformers import PretrainedConfig, PreTrainedModel

-    from ...hparams import ModelArguments
+    from ...hparams import ModelArguments, FinetuningArguments


 logger = logging.get_logger(__name__)


 def _get_unsloth_kwargs(
-    config: "PretrainedConfig", model_name_or_path: str, model_args: "ModelArguments"
+    config: "PretrainedConfig", model_name_or_path: str, model_args: "ModelArguments", finetuning_args: "FinetuningArguments"
 ) -> dict[str, Any]:
     return {
         "model_name": model_name_or_path,
@@ -36,6 +36,7 @@ def _get_unsloth_kwargs(
         "dtype": model_args.compute_dtype,
         "load_in_4bit": model_args.quantization_bit == 4,
         "token": model_args.hf_hub_token,
+        "full_finetuning": finetuning_args.finetuning_type == "full",
         "device_map": {"": get_current_device()},
         "rope_scaling": getattr(config, "rope_scaling", None),
         "fix_tokenizer": False,
@@ -45,12 +46,12 @@ def _get_unsloth_kwargs(
 def load_unsloth_pretrained_model(
-    config: "PretrainedConfig", model_args: "ModelArguments"
+    config: "PretrainedConfig", model_args: "ModelArguments", finetuning_args: "FinetuningArguments"
 ) -> Optional["PreTrainedModel"]:
     r"""Optionally load pretrained model with unsloth. Used in training."""
     from unsloth import FastLanguageModel  # type: ignore

-    unsloth_kwargs = _get_unsloth_kwargs(config, model_args.model_name_or_path, model_args)
+    unsloth_kwargs = _get_unsloth_kwargs(config, model_args.model_name_or_path, model_args, finetuning_args)
     try:
         model, _ = FastLanguageModel.from_pretrained(**unsloth_kwargs)
     except NotImplementedError:
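
Downstream, the flag is forwarded unchanged to FastLanguageModel.from_pretrained via **unsloth_kwargs. A hedged sketch of what the resulting call could look like for a full fine-tuning run; the model id, the dtype/load_in_4bit values, and the assumption that the installed unsloth release accepts full_finetuning are illustrative, not taken from this commit:

# Sketch, assuming an unsloth version that supports the full_finetuning flag.
from unsloth import FastLanguageModel  # type: ignore

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="meta-llama/Meta-Llama-3-8B-Instruct",  # placeholder model id
    dtype=None,           # the diff passes model_args.compute_dtype here
    load_in_4bit=False,   # 4-bit quantization is not used for full fine-tuning
    full_finetuning=True, # equivalent to finetuning_type == "full"
)
# With full_finetuning=True all parameters remain trainable, so no LoRA adapters
# (FastLanguageModel.get_peft_model) are attached afterwards.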