Mirror of https://github.com/hiyouga/LLaMA-Factory.git
[model] pushing FFT with unsloth (#8325)
Co-authored-by: viyer <vivek_iyer2@apple.com>
commit 32b4574094 (parent 03a93ec513)

This commit threads finetuning_args through the unsloth loading path so that
unsloth is asked for full_finetuning=True whenever finetuning_type is "full",
i.e. it enables full fine-tuning (FFT) with unsloth rather than adapter-only
training.
@@ -138,7 +138,7 @@ def load_model(
     if model_args.adapter_name_or_path is not None:
         lazy_load = True
     elif is_trainable:
-        model = load_unsloth_pretrained_model(config, model_args)
+        model = load_unsloth_pretrained_model(config, model_args, finetuning_args)

     if model is None and not lazy_load:
         init_kwargs["config"] = config
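For orientation: finetuning_args is already in scope at this call site because
load_model receives it as a parameter, so the fix is simply to forward it. A
minimal sketch of the likely caller signature (abridged and assumed; it is not
shown in this diff):

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from transformers import PreTrainedModel, PreTrainedTokenizer

    from ...hparams import FinetuningArguments, ModelArguments


def load_model(
    tokenizer: "PreTrainedTokenizer",
    model_args: "ModelArguments",
    finetuning_args: "FinetuningArguments",
    is_trainable: bool = False,
    add_valuehead: bool = False,
) -> "PreTrainedModel":
    ...  # body elided; only the unsloth branch above changes in this commit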
@@ -21,14 +21,14 @@ from ...extras.misc import get_current_device
 if TYPE_CHECKING:
     from transformers import PretrainedConfig, PreTrainedModel

-    from ...hparams import ModelArguments
+    from ...hparams import ModelArguments, FinetuningArguments


 logger = logging.get_logger(__name__)


 def _get_unsloth_kwargs(
-    config: "PretrainedConfig", model_name_or_path: str, model_args: "ModelArguments"
+    config: "PretrainedConfig", model_name_or_path: str, model_args: "ModelArguments", finetuning_args: "FinetuningArguments"
 ) -> dict[str, Any]:
     return {
         "model_name": model_name_or_path,
@@ -36,6 +36,7 @@ def _get_unsloth_kwargs(
         "dtype": model_args.compute_dtype,
         "load_in_4bit": model_args.quantization_bit == 4,
         "token": model_args.hf_hub_token,
+        "full_finetuning": finetuning_args.finetuning_type == "full",
         "device_map": {"": get_current_device()},
         "rope_scaling": getattr(config, "rope_scaling", None),
         "fix_tokenizer": False,
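With the new key in place, a "full" finetuning type is forwarded to unsloth as
its full_finetuning flag. A standalone sketch of the call that
_get_unsloth_kwargs ultimately feeds (the model id, sequence length, and dtype
here are illustrative assumptions, not values from the commit):

from unsloth import FastLanguageModel

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="unsloth/llama-3-8b",  # hypothetical model id
    max_seq_length=4096,              # assumed sequence length
    dtype=None,                       # let unsloth pick the compute dtype
    load_in_4bit=False,               # from quantization_bit == 4; off for FFT
    full_finetuning=True,             # from finetuning_args.finetuning_type == "full"
)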
@@ -45,12 +46,12 @@ def _get_unsloth_kwargs(


 def load_unsloth_pretrained_model(
-    config: "PretrainedConfig", model_args: "ModelArguments"
+    config: "PretrainedConfig", model_args: "ModelArguments", finetuning_args: "FinetuningArguments"
 ) -> Optional["PreTrainedModel"]:
     r"""Optionally load pretrained model with unsloth. Used in training."""
     from unsloth import FastLanguageModel  # type: ignore

-    unsloth_kwargs = _get_unsloth_kwargs(config, model_args.model_name_or_path, model_args)
+    unsloth_kwargs = _get_unsloth_kwargs(config, model_args.model_name_or_path, model_args, finetuning_args)
     try:
         model, _ = FastLanguageModel.from_pretrained(**unsloth_kwargs)
     except NotImplementedError:
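On the user side, this code path is reached when unsloth is enabled together
with full fine-tuning. A hypothetical minimal LLaMA-Factory training config
(key names assumed from the project's standard YAML arguments; dataset and
output settings omitted):

model_name_or_path: unsloth/llama-3-8b   # hypothetical model id
stage: sft
do_train: true
finetuning_type: full   # -> finetuning_args.finetuning_type == "full"
use_unsloth: true       # -> load_unsloth_pretrained_model is used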