mirror of https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-01 11:12:50 +08:00

[model] pushing FFT with unsloth (#8325)

Co-authored-by: viyer <vivek_iyer2@apple.com>

commit d325a1a7c7
parent 239ced076c
@@ -138,7 +138,7 @@ def load_model(
     if model_args.adapter_name_or_path is not None:
         lazy_load = True
     elif is_trainable:
-        model = load_unsloth_pretrained_model(config, model_args)
+        model = load_unsloth_pretrained_model(config, model_args, finetuning_args)

     if model is None and not lazy_load:
         init_kwargs["config"] = config
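The branch above only consults Unsloth for a trainable base model with no adapters attached; if the loader declines (returns None) and lazy loading is off, the standard Transformers path takes over. A toy, self-contained sketch of where the three arguments come from, using hypothetical stand-in dataclasses rather than LLaMA-Factory's real hparams classes:

from dataclasses import dataclass
from typing import Optional

# Hypothetical stand-ins for ...hparams.ModelArguments / FinetuningArguments;
# only the fields used by the hunk above are modeled, values are examples.
@dataclass
class ModelArguments:
    model_name_or_path: str = "Qwen/Qwen2-7B"
    adapter_name_or_path: Optional[str] = None

@dataclass
class FinetuningArguments:
    finetuning_type: str = "full"  # "full" is what enables full_finetuning below

model_args = ModelArguments()
finetuning_args = FinetuningArguments()
is_trainable = True

# Mirrors the branch above: Unsloth is only tried for a trainable model
# without adapters; otherwise loading is deferred or handled normally.
try_unsloth = model_args.adapter_name_or_path is None and is_trainable
print(try_unsloth)  # True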
@@ -21,14 +21,14 @@ from ...extras.misc import get_current_device
 if TYPE_CHECKING:
     from transformers import PretrainedConfig, PreTrainedModel

-    from ...hparams import ModelArguments
+    from ...hparams import ModelArguments, FinetuningArguments


 logger = logging.get_logger(__name__)


 def _get_unsloth_kwargs(
-    config: "PretrainedConfig", model_name_or_path: str, model_args: "ModelArguments"
+    config: "PretrainedConfig", model_name_or_path: str, model_args: "ModelArguments", finetuning_args: "FinetuningArguments"
 ) -> dict[str, Any]:
     return {
         "model_name": model_name_or_path,
@@ -36,6 +36,7 @@ def _get_unsloth_kwargs(
         "dtype": model_args.compute_dtype,
         "load_in_4bit": model_args.quantization_bit == 4,
         "token": model_args.hf_hub_token,
+        "full_finetuning": finetuning_args.finetuning_type == "full",
         "device_map": {"": get_current_device()},
         "rope_scaling": getattr(config, "rope_scaling", None),
         "fix_tokenizer": False,
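The functional core of the change is the new "full_finetuning" key, which tells Unsloth's FastLanguageModel.from_pretrained to load the model for full-parameter training instead of adapter training. As an illustration only (the hunk elides some keys, and every value below is a placeholder), the dict built by _get_unsloth_kwargs for a bf16 full fine-tune could look like:

import torch

# Placeholder snapshot of _get_unsloth_kwargs(...) output; keys mirror the
# diff, values are invented for the example, and elided keys are omitted.
unsloth_kwargs = {
    "model_name": "meta-llama/Meta-Llama-3-8B-Instruct",  # model_name_or_path
    "dtype": torch.bfloat16,   # model_args.compute_dtype
    "load_in_4bit": False,     # model_args.quantization_bit == 4
    "token": None,             # model_args.hf_hub_token
    "full_finetuning": True,   # finetuning_args.finetuning_type == "full"
    "device_map": {"": 0},     # get_current_device() on a single GPU
    "rope_scaling": None,      # getattr(config, "rope_scaling", None)
    "fix_tokenizer": False,
}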
@@ -45,12 +46,12 @@ def _get_unsloth_kwargs(


 def load_unsloth_pretrained_model(
-    config: "PretrainedConfig", model_args: "ModelArguments"
+    config: "PretrainedConfig", model_args: "ModelArguments", finetuning_args: "FinetuningArguments"
 ) -> Optional["PreTrainedModel"]:
     r"""Optionally load pretrained model with unsloth. Used in training."""
     from unsloth import FastLanguageModel  # type: ignore

-    unsloth_kwargs = _get_unsloth_kwargs(config, model_args.model_name_or_path, model_args)
+    unsloth_kwargs = _get_unsloth_kwargs(config, model_args.model_name_or_path, model_args, finetuning_args)
     try:
         model, _ = FastLanguageModel.from_pretrained(**unsloth_kwargs)
     except NotImplementedError:
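End to end, a minimal sketch of exercising the new path; the absolute module paths are inferred from the diff's relative imports (...hparams resolving to llamafactory.hparams) and the dataclass defaults are assumed, so treat this as orientation rather than the project's documented API:

from transformers import AutoConfig

from llamafactory.hparams import FinetuningArguments, ModelArguments
from llamafactory.model.model_utils.unsloth import load_unsloth_pretrained_model

model_args = ModelArguments(model_name_or_path="meta-llama/Meta-Llama-3-8B-Instruct")
finetuning_args = FinetuningArguments(finetuning_type="full")  # -> full_finetuning=True

config = AutoConfig.from_pretrained(model_args.model_name_or_path)

# Per the hunks above: returns the Unsloth-loaded model, or None when
# FastLanguageModel raises NotImplementedError, in which case load_model
# falls back to the standard Transformers loader (first hunk).
model = load_unsloth_pretrained_model(config, model_args, finetuning_args)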