[model] pushing FFT with unsloth (#8325)

Co-authored-by: viyer <vivek_iyer2@apple.com>
Vivek Iyer, 2025-06-06 20:20:58 -04:00 (committed by GitHub)
parent 03a93ec513
commit 32b4574094
2 changed files with 6 additions and 5 deletions


@@ -138,7 +138,7 @@ def load_model(
         if model_args.adapter_name_or_path is not None:
             lazy_load = True
         elif is_trainable:
-            model = load_unsloth_pretrained_model(config, model_args)
+            model = load_unsloth_pretrained_model(config, model_args, finetuning_args)
 
     if model is None and not lazy_load:
         init_kwargs["config"] = config


@@ -21,14 +21,14 @@ from ...extras.misc import get_current_device
 if TYPE_CHECKING:
     from transformers import PretrainedConfig, PreTrainedModel
 
-    from ...hparams import ModelArguments
+    from ...hparams import ModelArguments, FinetuningArguments
 
 
 logger = logging.get_logger(__name__)
 
 
 def _get_unsloth_kwargs(
-    config: "PretrainedConfig", model_name_or_path: str, model_args: "ModelArguments"
+    config: "PretrainedConfig", model_name_or_path: str, model_args: "ModelArguments", finetuning_args: "FinetuningArguments"
 ) -> dict[str, Any]:
     return {
         "model_name": model_name_or_path,
@@ -36,6 +36,7 @@ def _get_unsloth_kwargs(
         "dtype": model_args.compute_dtype,
         "load_in_4bit": model_args.quantization_bit == 4,
         "token": model_args.hf_hub_token,
+        "full_finetuning": finetuning_args.finetuning_type == "full",
         "device_map": {"": get_current_device()},
         "rope_scaling": getattr(config, "rope_scaling", None),
         "fix_tokenizer": False,
@@ -45,12 +46,12 @@ def _get_unsloth_kwargs(
 def load_unsloth_pretrained_model(
-    config: "PretrainedConfig", model_args: "ModelArguments"
+    config: "PretrainedConfig", model_args: "ModelArguments", finetuning_args: "FinetuningArguments"
 ) -> Optional["PreTrainedModel"]:
     r"""Optionally load pretrained model with unsloth. Used in training."""
     from unsloth import FastLanguageModel  # type: ignore
 
-    unsloth_kwargs = _get_unsloth_kwargs(config, model_args.model_name_or_path, model_args)
+    unsloth_kwargs = _get_unsloth_kwargs(config, model_args.model_name_or_path, model_args, finetuning_args)
     try:
         model, _ = FastLanguageModel.from_pretrained(**unsloth_kwargs)
     except NotImplementedError:
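
Since load_unsloth_pretrained_model unpacks these kwargs straight into FastLanguageModel.from_pretrained, a run with finetuning_type set to "full" now reaches unsloth with full_finetuning=True. A rough usage sketch of the downstream call, assuming an installed unsloth build that exposes the full_finetuning flag and a CUDA device; the model name and sequence length are placeholders, not values taken from this commit.

import torch
from unsloth import FastLanguageModel  # type: ignore

# Approximately what the wrapper ends up calling when finetuning_type == "full".
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="meta-llama/Meta-Llama-3-8B",
    max_seq_length=4096,
    dtype=torch.bfloat16,
    load_in_4bit=False,    # 4-bit loading is a (Q)LoRA path, not full fine-tuning
    full_finetuning=True,  # the switch wired up by this commit
)
# With full_finetuning=True the returned model is trained end to end,
# rather than being prepared for a LoRA adapter.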